diff --git a/gradio_server.py b/gradio_server.py
index 0ebb534..237b9dd 100644
--- a/gradio_server.py
+++ b/gradio_server.py
@@ -22,6 +22,9 @@ import traceback
import math
import asyncio
from wan.utils import prompt_parser
+import base64
+import io
+from PIL import Image
PROMPT_VARS_MAX = 10
target_mmgp_version = "3.3.4"
@@ -50,6 +53,29 @@ def format_time(seconds):
minutes = int((seconds % 3600) // 60)
return f"{hours}h {minutes}m"
+def pil_to_base64_uri(pil_image, format="png", quality=75):
+ if pil_image is None:
+ return None
+ buffer = io.BytesIO()
+ try:
+ img_to_save = pil_image
+ if format.lower() == 'jpeg' and pil_image.mode == 'RGBA':
+ img_to_save = pil_image.convert('RGB')
+ elif format.lower() == 'png' and pil_image.mode not in ['RGB', 'RGBA', 'L', 'P']:
+ img_to_save = pil_image.convert('RGBA')
+ elif pil_image.mode == 'P':
+ img_to_save = pil_image.convert('RGBA' if 'transparency' in pil_image.info else 'RGB')
+ if format.lower() == 'jpeg':
+ img_to_save.save(buffer, format=format, quality=quality)
+ else:
+ img_to_save.save(buffer, format=format)
+ img_bytes = buffer.getvalue()
+ encoded_string = base64.b64encode(img_bytes).decode("utf-8")
+ return f"data:image/{format.lower()};base64,{encoded_string}"
+ except Exception as e:
+ print(f"Error converting PIL to base64: {e}")
+ return None
+
def runner():
global current_task_id
while True:
@@ -175,6 +201,9 @@ def add_video_task(*params):
with lock:
task_id += 1
current_task_id = task_id
+ start_image_data = params[16] if len(params) > 16 else None
+ end_image_data = params[17] if len(params) > 17 else None
+
queue.append({
"id": current_task_id,
"params": (current_task_id,) + params,
@@ -184,9 +213,11 @@ def add_video_task(*params):
"progress": "0.0%",
"steps": f"0/{params[5]}",
"time": "--",
- "prompt": params[0]
+ "prompt": params[0],
+ "start_image_data": start_image_data,
+ "end_image_data": end_image_data
})
- return
+ return update_queue_data()
def move_up(selected_indices):
if not selected_indices or len(selected_indices) == 0:
@@ -233,6 +264,15 @@ def update_queue_data():
truncated_prompt = (item['prompt'][:97] + '...') if len(item['prompt']) > 100 else item['prompt']
full_prompt = item['prompt'].replace('"', '"')
prompt_cell = f'{truncated_prompt}'
+ start_img_uri = pil_to_base64_uri(item.get('start_image_data'), format="jpeg", quality=70)
+ end_img_uri = pil_to_base64_uri(item.get('end_image_data'), format="jpeg", quality=70)
+ thumbnail_size = "50px"
+ start_img_md = ""
+ end_img_md = ""
+ if start_img_uri:
+ start_img_md = f''
+ if end_img_uri:
+ end_img_md = f'
'
data.append([
item.get('status', "Starting"),
item.get('repeats', "0/0"),
@@ -240,6 +280,8 @@ def update_queue_data():
item.get('steps', ''),
item.get('time', '--'),
prompt_cell,
+ start_img_md,
+ end_img_md,
"↑",
"↓",
"✖"
@@ -1143,7 +1185,6 @@ def generate_video(
print(f"Model loaded")
reload_needed= False
- from PIL import Image
import numpy as np
import tempfile
@@ -1905,6 +1946,10 @@ def generate_video_tab(image2video=False):
download_status = gr.Markdown()
with gr.Row():
with gr.Column():
+ with gr.Column(visible=False, elem_id="image-modal-container") as modal_container:
+ with gr.Row(elem_id="image-modal-close-button-row"):
+ close_modal_button = gr.Button("❌", size="sm")
+ modal_image_display = gr.Image(label="Full Resolution Image", interactive=False, show_label=False)
gallery_update_trigger = gr.Textbox(value="0", visible=False, label="_gallery_trigger")
with gr.Row(visible= len(loras)>0) as presets_column:
lset_choices = [ (preset, preset) for preset in loras_presets ] + [(get_new_preset_msg(advanced), "")]
@@ -2134,25 +2179,49 @@ def generate_video_tab(image2video=False):
, columns=[3], rows=[1], object_fit="contain", height=450, selected_index=0, interactive= False)
generate_btn = gr.Button("Generate")
queue_df = gr.DataFrame(
- headers=["Status", "Completed", "Progress", "Steps", "Time", "Prompt", "", "", ""],
- datatype=["str", "str", "str", "str", "str", "markdown", "str", "str", "str"],
+ headers=["Status", "Completed", "Progress", "Steps", "Time", "Prompt", "Start", "End", "", "", ""],
+ datatype=["str", "str", "str", "str", "str", "markdown", "markdown", "markdown", "str", "str", "str"],
interactive=False,
- col_count=(9, "fixed"),
+ col_count=(11, "fixed"),
wrap=True,
value=update_queue_data,
every=1,
elem_id="queue_df"
)
def handle_selection(evt: gr.SelectData):
- cell_value = evt.value
- selected_index = evt.index
- if cell_value == "↑":
- return move_up([selected_index])
- elif cell_value == "↓":
- return move_down([selected_index])
- elif cell_value == "✖":
- return remove_task([selected_index])
- return queue_df
+ if evt.index is None:
+ return gr.update(), gr.update(), gr.update(visible=False)
+ row_index, col_index = evt.index
+ cell_value = None
+ if col_index in [8, 9, 10]:
+ if col_index == 8: cell_value = "↑"
+ elif col_index == 9: cell_value = "↓"
+ elif col_index == 10: cell_value = "✖"
+ if col_index == 8:
+ new_df_data = move_up([row_index])
+ return new_df_data, gr.update(), gr.update(visible=False)
+ elif col_index == 9:
+ new_df_data = move_down([row_index])
+ return new_df_data, gr.update(), gr.update(visible=False)
+ elif col_index == 10:
+ new_df_data = remove_task([row_index])
+ return new_df_data, gr.update(), gr.update(visible=False)
+ start_img_col_idx = 6
+ end_img_col_idx = 7
+ image_data_to_show = None
+ if col_index == start_img_col_idx:
+ with lock:
+ if row_index < len(queue):
+ image_data_to_show = queue[row_index].get('start_image_data')
+ elif col_index == end_img_col_idx:
+ with lock:
+ if row_index < len(queue):
+ image_data_to_show = queue[row_index].get('end_image_data')
+
+ if image_data_to_show:
+ return gr.update(), gr.update(value=image_data_to_show), gr.update(visible=True)
+ else:
+ return gr.update(), gr.update(), gr.update(visible=False)
def refresh_gallery_on_trigger(state):
if(state.get("update_gallery", False)):
state['update_gallery'] = False
@@ -2160,7 +2229,8 @@ def generate_video_tab(image2video=False):
selected_indices = gr.State([])
queue_df.select(
fn=handle_selection,
- outputs=selected_indices
+ inputs=None,
+ outputs=[queue_df, modal_image_display, modal_container],
)
gallery_update_trigger.change(
fn=refresh_gallery_on_trigger,
@@ -2229,6 +2299,11 @@ def generate_video_tab(image2video=False):
inputs=original_inputs,
outputs=queue_df
)
+ close_modal_button.click(
+ lambda: gr.update(visible=False),
+ inputs=[],
+ outputs=[modal_container]
+ )
return loras_choices, lset_name, header, state
def generate_configuration_tab():
@@ -2524,16 +2599,9 @@ def create_demo():
#queue_df th {
pointer-events: none;
}
- #queue_df .tabulator-col {
- pointer-events: none;
- }
- #queue_df .tabulator-col .tabulator-arrow {
- display: none;
- }
#queue_df table {
overflow: hidden !important;
}
-
#queue_df::-webkit-scrollbar {
display: none !important;
}
@@ -2545,7 +2613,8 @@ def create_demo():
width: 100px;
}
#queue_df td:nth-child(6) {
- width: 300px;
+ width: auto;
+ min-width: 200px;
}
#queue_df td:nth-child(7),
#queue_df td:nth-child(8),
@@ -2553,12 +2622,56 @@ def create_demo():
cursor: pointer;
text-align: center;
font-weight: bold;
+ width: 60px;
+ text-align: center;
+ padding: 2px !important;
+ cursor: pointer;
}
- #queue_df td:nth-child(7):hover,
- #queue_df td:nth-child(8):hover,
- #queue_df td:nth-child(9):hover {
- background-color: #e0e0e0;
+ #queue_df td:nth-child(10) img,
+ #queue_df td:nth-child(11) img {
+ max-width: 50px;
+ max-height: 50px;
+ object-fit: contain;
+ display: block;
+ margin: auto;
}
+ #image-modal-container {
+ position: fixed;
+ top: 0;
+ left: 0;
+ width: 100%;
+ height: 100%;
+ background-color: rgba(0, 0, 0, 0.7);
+ justify-content: center;
+ align-items: center;
+ z-index: 1000;
+ padding: 20px;
+ box-sizing: border-box;
+ }
+ #image-modal-container > div {
+ background-color: white;
+ padding: 15px;
+ border-radius: 8px;
+ max-width: 90%;
+ max-height: 90%;
+ overflow: auto;
+ position: relative;
+ display: flex;
+ flex-direction: column;
+ }
+ #image-modal-container img {
+ max-width: 100%;
+ max-height: 80vh;
+ object-fit: contain;
+ margin-top: 10px;
+ }
+ #image-modal-close-button-row {
+ display: flex;
+ justify-content: flex-end;
+ }
+ #image-modal-close-button-row button {
+ cursor: pointer;
+ }
"""
with gr.Blocks(css=css, theme=gr.themes.Soft(primary_hue="sky", neutral_hue="slate", text_size="md")) as demo:
gr.Markdown("