# demo3.py (4.5 KB)

  1. import os
  2. import numpy as np
  3. import gradio as gr
  4. css = '''
  5. code {white-space: pre-wrap !important;}
  6. .gradio-container {max-width: none !important;}
  7. .outer_parent {flex: 1;}
  8. .inner_parent {flex: 1;}
  9. footer {display: none !important; visibility: hidden !important;}
  10. .translucent {display: none !important; visibility: hidden !important;}
  11. '''
  12. from gradio.themes.utils import colors
  13. with gr.Blocks(
  14. fill_height=True, css=css,
  15. theme=gr.themes.Default(primary_hue=colors.blue, secondary_hue=colors.cyan, neutral_hue=colors.gray)
  16. ) as demo:
  17. with gr.Row(elem_classes='outer_parent'):
  18. with gr.Column(scale=25):
  19. with gr.Row():
  20. clear_btn = gr.Button("➕ New Chat", variant="secondary", size="sm", min_width=60)
  21. retry_btn = gr.Button("Retry", variant="secondary", size="sm", min_width=60, visible=False)
  22. undo_btn = gr.Button("✏️️ Edit Last Input", variant="secondary", size="sm", min_width=60, interactive=False)
  23. seed = gr.Number(label="Random Seed", value=12345, precision=0)
  24. with gr.Accordion(open=True, label='Language Model'):
  25. with gr.Group():
  26. with gr.Row():
  27. temperature = gr.Slider(
  28. minimum=0.0,
  29. maximum=2.0,
  30. step=0.01,
  31. value=0.6,
  32. label="Temperature")
  33. top_p = gr.Slider(
  34. minimum=0.0,
  35. maximum=1.0,
  36. step=0.01,
  37. value=0.9,
  38. label="Top P")
  39. max_new_tokens = gr.Slider(
  40. minimum=128,
  41. maximum=4096,
  42. step=1,
  43. value=4096,
  44. label="Max New Tokens")
  45. with gr.Accordion(open=True, label='Image Diffusion Model'):
  46. with gr.Group():
  47. with gr.Row():
  48. image_width = gr.Slider(label="Image Width", minimum=256, maximum=2048, value=896, step=64)
  49. image_height = gr.Slider(label="Image Height", minimum=256, maximum=2048, value=1152, step=64)
  50. with gr.Row():
  51. num_samples = gr.Slider(label="Image Number", minimum=1, maximum=12, value=1, step=1)
  52. steps = gr.Slider(label="Sampling Steps", minimum=1, maximum=100, value=25, step=1)
  53. with gr.Accordion(open=False, label='Advanced'):
  54. cfg = gr.Slider(label="CFG Scale", minimum=1.0, maximum=32.0, value=5.0, step=0.01)
  55. highres_scale = gr.Slider(label="HR-fix Scale (\"1\" is disabled)", minimum=1.0, maximum=2.0, value=1.0, step=0.01)
  56. highres_steps = gr.Slider(label="Highres Fix Steps", minimum=1, maximum=100, value=20, step=1)
  57. highres_denoise = gr.Slider(label="Highres Fix Denoise", minimum=0.1, maximum=1.0, value=0.4, step=0.01)
  58. n_prompt = gr.Textbox(label="Negative Prompt", value='lowres, bad anatomy, bad hands, cropped, worst quality')
  59. render_button = gr.Button("Render the Image!", size='lg', variant="primary", visible=False)
  60. examples = gr.Dataset(
  61. samples=[
  62. ['generate an image of the fierce battle of warriors and a dragon'],
  63. ['change the dragon to a dinosaur']
  64. ],
  65. components=[gr.Textbox(visible=False)],
  66. label='Quick Prompts'
  67. )
  68. with gr.Column(scale=75, elem_classes='inner_parent'):
  69. canvas_state = gr.State(None)
  70. chatbot = gr.Chatbot(label='Omost', scale=1, show_copy_button=True, layout="panel", render=False)
  71. def diffusion_fn(chatbot, canvas_outputs, num_samples, seed, image_width, image_height,
  72. highres_scale, steps, cfg, highres_steps, highres_denoise, negative_prompt):
  73. pass
  74. render_button.click(
  75. fn=diffusion_fn, inputs=[
  76. chatInterface.chatbot, canvas_state,
  77. num_samples, seed, image_width, image_height, highres_scale,
  78. steps, cfg, highres_steps, highres_denoise, n_prompt
  79. ], outputs=[chatInterface.chatbot]).then(
  80. fn=lambda x: x, inputs=[
  81. chatInterface.chatbot
  82. ], outputs=[chatInterface.chatbot_state])
  83. if __name__ == "__main__":
  84. demo.queue().launch(inbrowser=True, server_name='0.0.0.0')