From 4ef20d2299cb95801ac869002d3b3a9afb400abc Mon Sep 17 00:00:00 2001
From: Chris Malone
Date: Fri, 28 Mar 2025 02:08:38 +1100
Subject: [PATCH] add image thumbnails and previews to queue items

---
 gradio_server.py | 169 +++++++++++++++++++++++++++++++++++++++--------
 1 file changed, 141 insertions(+), 28 deletions(-)

diff --git a/gradio_server.py b/gradio_server.py
index 0ebb534..237b9dd 100644
--- a/gradio_server.py
+++ b/gradio_server.py
@@ -22,6 +22,9 @@ import traceback
 import math
 import asyncio
 from wan.utils import prompt_parser
+import base64
+import io
+from PIL import Image
 
 PROMPT_VARS_MAX = 10
 target_mmgp_version = "3.3.4"
@@ -50,6 +53,29 @@ def format_time(seconds):
     minutes = int((seconds % 3600) // 60)
     return f"{hours}h {minutes}m"
 
+def pil_to_base64_uri(pil_image, format="png", quality=75):
+    if pil_image is None:
+        return None
+    buffer = io.BytesIO()
+    try:
+        img_to_save = pil_image
+        if format.lower() == 'jpeg' and pil_image.mode == 'RGBA':
+            img_to_save = pil_image.convert('RGB')
+        elif format.lower() == 'png' and pil_image.mode not in ['RGB', 'RGBA', 'L', 'P']:
+            img_to_save = pil_image.convert('RGBA')
+        elif pil_image.mode == 'P':
+            img_to_save = pil_image.convert('RGBA' if 'transparency' in pil_image.info else 'RGB')
+        if format.lower() == 'jpeg':
+            img_to_save.save(buffer, format=format, quality=quality)
+        else:
+            img_to_save.save(buffer, format=format)
+        img_bytes = buffer.getvalue()
+        encoded_string = base64.b64encode(img_bytes).decode("utf-8")
+        return f"data:image/{format.lower()};base64,{encoded_string}"
+    except Exception as e:
+        print(f"Error converting PIL to base64: {e}")
+        return None
+
 def runner():
     global current_task_id
     while True:
@@ -175,6 +201,9 @@ def add_video_task(*params):
     with lock:
         task_id += 1
         current_task_id = task_id
+        start_image_data = params[16] if len(params) > 16 else None
+        end_image_data = params[17] if len(params) > 17 else None
+
         queue.append({
             "id": current_task_id,
             "params": (current_task_id,) + params,
@@ -184,9 +213,11 @@ def add_video_task(*params):
             "progress": "0.0%",
             "steps": f"0/{params[5]}",
             "time": "--",
-            "prompt": params[0]
+            "prompt": params[0],
+            "start_image_data": start_image_data,
+            "end_image_data": end_image_data
         })
-    return
+    return update_queue_data()
 
 def move_up(selected_indices):
     if not selected_indices or len(selected_indices) == 0:
@@ -233,6 +264,15 @@ def update_queue_data():
         truncated_prompt = (item['prompt'][:97] + '...') if len(item['prompt']) > 100 else item['prompt']
         full_prompt = item['prompt'].replace('"', '&quot;')
         prompt_cell = f'<span title="{full_prompt}">{truncated_prompt}</span>'
+        start_img_uri = pil_to_base64_uri(item.get('start_image_data'), format="jpeg", quality=70)
+        end_img_uri = pil_to_base64_uri(item.get('end_image_data'), format="jpeg", quality=70)
+        thumbnail_size = "50px"
+        start_img_md = ""
+        end_img_md = ""
+        if start_img_uri:
+            start_img_md = f'<img src="{start_img_uri}" alt="Start" style="max-width:{thumbnail_size}; max-height:{thumbnail_size};" />'
+        if end_img_uri:
+            end_img_md = f'<img src="{end_img_uri}" alt="End" style="max-width:{thumbnail_size}; max-height:{thumbnail_size};" />'
         data.append([
             item.get('status', "Starting"),
             item.get('repeats', "0/0"),
@@ -240,6 +280,8 @@ def update_queue_data():
             item.get('steps', ''),
             item.get('time', '--'),
             prompt_cell,
+            start_img_md,
+            end_img_md,
             "↑",
             "↓",
             "✖"
         ])
@@ -1143,7 +1185,6 @@ def generate_video(
         print(f"Model loaded")
         reload_needed= False
 
-    from PIL import Image
     import numpy as np
     import tempfile
 
@@ -1905,6 +1946,10 @@ def generate_video_tab(image2video=False):
             download_status = gr.Markdown()
     with gr.Row():
         with gr.Column():
+            with gr.Column(visible=False, elem_id="image-modal-container") as modal_container:
+                with gr.Row(elem_id="image-modal-close-button-row"):
+                    close_modal_button = gr.Button("❌", size="sm")
+                modal_image_display = gr.Image(label="Full Resolution Image", interactive=False, show_label=False)
             gallery_update_trigger = gr.Textbox(value="0", visible=False, label="_gallery_trigger")
             with gr.Row(visible= len(loras)>0) as presets_column:
                 lset_choices = [ (preset, preset) for preset in loras_presets ] + [(get_new_preset_msg(advanced), "")]
@@ -2134,25 +2179,49 @@ def generate_video_tab(image2video=False):
                     , columns=[3], rows=[1], object_fit="contain", height=450, selected_index=0, interactive= False)
                 generate_btn = gr.Button("Generate")
                 queue_df = gr.DataFrame(
-                    headers=["Status", "Completed", "Progress", "Steps", "Time", "Prompt", "", "", ""],
-                    datatype=["str", "str", "str", "str", "str", "markdown", "str", "str", "str"],
+                    headers=["Status", "Completed", "Progress", "Steps", "Time", "Prompt", "Start", "End", "", "", ""],
+                    datatype=["str", "str", "str", "str", "str", "markdown", "markdown", "markdown", "str", "str", "str"],
                     interactive=False,
-                    col_count=(9, "fixed"),
+                    col_count=(11, "fixed"),
                     wrap=True,
                     value=update_queue_data,
                     every=1,
                     elem_id="queue_df"
                 )
                 def handle_selection(evt: gr.SelectData):
-                    cell_value = evt.value
-                    selected_index = evt.index
-                    if cell_value == "↑":
-                        return move_up([selected_index])
-                    elif cell_value == "↓":
-                        return move_down([selected_index])
-                    elif cell_value == "✖":
-                        return remove_task([selected_index])
-                    return queue_df
+                    if evt.index is None:
+                        return gr.update(), gr.update(), gr.update(visible=False)
+                    row_index, col_index = evt.index
+                    cell_value = None
+                    if col_index in [8, 9, 10]:
+                        if col_index == 8: cell_value = "↑"
+                        elif col_index == 9: cell_value = "↓"
+                        elif col_index == 10: cell_value = "✖"
+                    if col_index == 8:
+                        new_df_data = move_up([row_index])
+                        return new_df_data, gr.update(), gr.update(visible=False)
+                    elif col_index == 9:
+                        new_df_data = move_down([row_index])
+                        return new_df_data, gr.update(), gr.update(visible=False)
+                    elif col_index == 10:
+                        new_df_data = remove_task([row_index])
+                        return new_df_data, gr.update(), gr.update(visible=False)
+                    start_img_col_idx = 6
+                    end_img_col_idx = 7
+                    image_data_to_show = None
+                    if col_index == start_img_col_idx:
+                        with lock:
+                            if row_index < len(queue):
+                                image_data_to_show = queue[row_index].get('start_image_data')
+                    elif col_index == end_img_col_idx:
+                        with lock:
+                            if row_index < len(queue):
+                                image_data_to_show = queue[row_index].get('end_image_data')
+
+                    if image_data_to_show:
+                        return gr.update(), gr.update(value=image_data_to_show), gr.update(visible=True)
+                    else:
+                        return gr.update(), gr.update(), gr.update(visible=False)
                 def refresh_gallery_on_trigger(state):
                     if(state.get("update_gallery", False)):
                         state['update_gallery'] = False
@@ -2160,7 +2229,8 @@ def generate_video_tab(image2video=False):
                 selected_indices = gr.State([])
                 queue_df.select(
                     fn=handle_selection,
-                    outputs=selected_indices
+                    inputs=None,
+                    outputs=[queue_df, modal_image_display, modal_container],
                 )
                 gallery_update_trigger.change(
                     fn=refresh_gallery_on_trigger,
@@ -2229,6 +2299,11 @@ def generate_video_tab(image2video=False):
                     inputs=original_inputs,
                     outputs=queue_df
                 )
+                close_modal_button.click(
+                    lambda: gr.update(visible=False),
+                    inputs=[],
+                    outputs=[modal_container]
+                )
     return loras_choices, lset_name, header, state
 
 def generate_configuration_tab():
@@ -2524,16 +2599,9 @@ def create_demo():
         #queue_df th {
             pointer-events: none;
         }
-        #queue_df .tabulator-col {
-            pointer-events: none;
-        }
-        #queue_df .tabulator-col .tabulator-arrow {
-            display: none;
-        }
         #queue_df table {
             overflow: hidden !important;
         }
-        #queue_df::-webkit-scrollbar {
             display: none !important;
         }
@@ -2545,7 +2613,8 @@ def create_demo():
             width: 100px;
         }
         #queue_df td:nth-child(6) {
-            width: 300px;
+            width: auto;
+            min-width: 200px;
         }
         #queue_df td:nth-child(7),
         #queue_df td:nth-child(8),
@@ -2553,12 +2622,56 @@ def create_demo():
            cursor: pointer;
            text-align: center;
            font-weight: bold;
+           width: 60px;
+           text-align: center;
+           padding: 2px !important;
+           cursor: pointer;
        }
-       #queue_df td:nth-child(7):hover,
-       #queue_df td:nth-child(8):hover,
-       #queue_df td:nth-child(9):hover {
-           background-color: #e0e0e0;
+       #queue_df td:nth-child(10) img,
+       #queue_df td:nth-child(11) img {
+           max-width: 50px;
+           max-height: 50px;
+           object-fit: contain;
+           display: block;
+           margin: auto;
        }
+       #image-modal-container {
+           position: fixed;
+           top: 0;
+           left: 0;
+           width: 100%;
+           height: 100%;
+           background-color: rgba(0, 0, 0, 0.7);
+           justify-content: center;
+           align-items: center;
+           z-index: 1000;
+           padding: 20px;
+           box-sizing: border-box;
+       }
+       #image-modal-container > div {
+           background-color: white;
+           padding: 15px;
+           border-radius: 8px;
+           max-width: 90%;
+           max-height: 90%;
+           overflow: auto;
+           position: relative;
+           display: flex;
+           flex-direction: column;
+       }
+       #image-modal-container img {
+           max-width: 100%;
+           max-height: 80vh;
+           object-fit: contain;
+           margin-top: 10px;
+       }
+       #image-modal-close-button-row {
+           display: flex;
+           justify-content: flex-end;
+       }
+       #image-modal-close-button-row button {
+           cursor: pointer;
+       }
    """
    with gr.Blocks(css=css, theme=gr.themes.Soft(primary_hue="sky", neutral_hue="slate", text_size="md")) as demo:
        gr.Markdown("Wan 2.1GP v3.2 by DeepBeepMeep (Updates)")