initial commit

This commit is contained in:
Chris Malone 2025-09-08 04:30:58 +10:00
parent 66a07db16b
commit cfa7c776c6

258
wgp.py
View File

@ -195,7 +195,116 @@ def clean_image_list(gradio_list):
gradio_list = [ convert_image( Image.open(img) if isinstance(img, str) else img ) for img in gradio_list ] gradio_list = [ convert_image( Image.open(img) if isinstance(img, str) else img ) for img in gradio_list ]
return gradio_list return gradio_list
def edit_task_in_queue(
lset_name,
image_mode,
prompt,
negative_prompt,
resolution,
video_length,
batch_size,
seed,
force_fps,
num_inference_steps,
guidance_scale,
guidance2_scale,
guidance3_scale,
switch_threshold,
switch_threshold2,
guidance_phases,
model_switch_phase,
audio_guidance_scale,
flow_shift,
sample_solver,
embedded_guidance_scale,
repeat_generation,
multi_prompts_gen_type,
multi_images_gen_type,
skip_steps_cache_type,
skip_steps_multiplier,
skip_steps_start_step_perc,
loras_choices,
loras_multipliers,
image_prompt_type,
image_start,
image_end,
model_mode,
video_source,
keep_frames_video_source,
video_guide_outpainting,
video_prompt_type,
image_refs,
frames_positions,
video_guide,
image_guide,
keep_frames_video_guide,
denoising_strength,
video_mask,
image_mask,
control_net_weight,
control_net_weight2,
mask_expand,
audio_guide,
audio_guide2,
audio_source,
audio_prompt_type,
speakers_locations,
sliding_window_size,
sliding_window_overlap,
sliding_window_color_correction_strength,
sliding_window_overlap_noise,
sliding_window_discard_last_frames,
image_refs_relative_size,
remove_background_images_ref,
temporal_upsampling,
spatial_upsampling,
film_grain_intensity,
film_grain_saturation,
MMAudio_setting,
MMAudio_prompt,
MMAudio_neg_prompt,
RIFLEx_setting,
NAG_scale,
NAG_tau,
NAG_alpha,
slg_switch,
slg_layers,
slg_start_perc,
slg_end_perc,
apg_switch,
cfg_star_switch,
cfg_zero_step,
prompt_enhancer,
min_frames_if_references,
override_profile,
mode,
state,
):
inputs = get_function_arguments(edit_task_in_queue, locals())
gen = get_gen_info(state)
queue = gen.get("queue", [])
editing_task_index = state.get("editing_task_index", None)
if editing_task_index is None:
gr.Warning("No task selected for editing.")
return update_queue_data(queue), gr.Tabs(selected="video_gen")
task_to_edit_index = editing_task_index + 1
if task_to_edit_index >= len(queue):
gr.Warning("Task index out of bounds. Cannot edit.")
return update_queue_data(queue), gr.Tabs(selected="video_gen")
task_to_edit = queue[task_to_edit_index]
inputs['state'] = state
task_to_edit['params'] = inputs.copy()
task_to_edit['prompt'] = inputs.get('prompt')
task_to_edit['length'] = inputs.get('video_length')
task_to_edit['steps'] = inputs.get('num_inference_steps')
update_task_thumbnails(task_to_edit, inputs)
gr.Info(f"Task ID {task_to_edit['id']} has been updated successfully.")
state["editing_task_index"] = None
return update_queue_data(queue), gr.Tabs(selected="video_gen")
def process_prompt_and_add_tasks(state, model_choice): def process_prompt_and_add_tasks(state, model_choice):
@ -6485,27 +6594,41 @@ def handle_celll_selection(state, evt: gr.SelectData):
queue = gen.get("queue", []) queue = gen.get("queue", [])
if evt.index is None: if evt.index is None:
return gr.update(), gr.update(), gr.update(visible=False) return gr.update(), gr.update(), gr.update(visible=False), gr.update()
row_index, col_index = evt.index row_index, col_index = evt.index
cell_value = None
if col_index == 1:
state["editing_task_index"] = row_index
task_to_edit_index = row_index + 1
if task_to_edit_index < len(queue):
task_data = queue[task_to_edit_index]
gr.Info(f"Loading task '{task_data['prompt'][:50]}...' for editing.")
return gr.update(), gr.update(), gr.update(visible=False), gr.Tabs(selected="edit")
else:
gr.Warning("Task index out of bounds.")
return gr.update(), gr.update(), gr.update(visible=False), gr.update()
if col_index in [6, 7, 8]: if col_index in [6, 7, 8]:
if col_index == 6: cell_value = "" if col_index == 6: cell_value = ""
elif col_index == 7: cell_value = "" elif col_index == 7: cell_value = ""
elif col_index == 8: cell_value = "" elif col_index == 8: cell_value = ""
if col_index == 6: if col_index == 6:
new_df_data = move_up(queue, [row_index]) new_df_data = move_up(queue, [row_index])
return new_df_data, gr.update(), gr.update(visible=False) return new_df_data, gr.update(), gr.update(visible=False), gr.update()
elif col_index == 7: elif col_index == 7:
new_df_data = move_down(queue, [row_index]) new_df_data = move_down(queue, [row_index])
return new_df_data, gr.update(), gr.update(visible=False) return new_df_data, gr.update(), gr.update(visible=False), gr.update()
elif col_index == 8: elif col_index == 8:
new_df_data = remove_task(queue, [row_index]) new_df_data = remove_task(queue, [row_index])
gen["prompts_max"] = gen.get("prompts_max",0) - 1 gen["prompts_max"] = gen.get("prompts_max",0) - 1
update_status(state) update_status(state)
return new_df_data, gr.update(), gr.update(visible=False) return new_df_data, gr.update(), gr.update(visible=False), gr.update()
start_img_col_idx = 4 start_img_col_idx = 4
end_img_col_idx = 5 end_img_col_idx = 5
image_data_to_show = None image_data_to_show = None
names = []
if col_index == start_img_col_idx: if col_index == start_img_col_idx:
with lock: with lock:
row_index += 1 row_index += 1
@ -6518,13 +6641,11 @@ def handle_celll_selection(state, evt: gr.SelectData):
if row_index < len(queue): if row_index < len(queue):
image_data_to_show = queue[row_index].get('end_image_data_base64') image_data_to_show = queue[row_index].get('end_image_data_base64')
names = queue[row_index].get('end_image_labels') names = queue[row_index].get('end_image_labels')
if image_data_to_show: if image_data_to_show:
value = get_modal_image( image_data_to_show[0], names[0]) value = get_modal_image( image_data_to_show[0], names[0])
return gr.update(), gr.update(value=value), gr.update(visible=True) return gr.update(), gr.update(value=value), gr.update(visible=True), gr.update()
else: else:
return gr.update(), gr.update(), gr.update(visible=False) return gr.update(), gr.update(), gr.update(visible=False), gr.update()
def change_model(state, model_choice): def change_model(state, model_choice):
if model_choice == None: if model_choice == None:
@ -6922,7 +7043,7 @@ def refresh_video_length_label(state, current_video_length, force_fps, video_gui
computed_fps = get_computed_fps(force_fps, base_model_type , video_guide, video_source ) computed_fps = get_computed_fps(force_fps, base_model_type , video_guide, video_source )
return gr.update(label= compute_video_length_label(computed_fps, current_video_length)) return gr.update(label= compute_video_length_label(computed_fps, current_video_length))
def generate_video_tab(update_form = False, state_dict = None, ui_defaults = None, model_family = None, model_choice = None, header = None, main = None, main_tabs= None): def generate_video_tab(update_form = False, state_dict = None, ui_defaults = None, model_family = None, model_choice = None, header = None, main = None, main_tabs= None, tab_id='generate'):
global inputs_names #, advanced global inputs_names #, advanced
if update_form: if update_form:
@ -7898,9 +8019,15 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
video_info_add_videos_btn = gr.Button("Add Videos / Images", size ="sm") video_info_add_videos_btn = gr.Button("Add Videos / Images", size ="sm")
if not update_form: if not update_form:
generate_btn = gr.Button("Generate") if tab_id == 'edit':
generate_trigger = gr.Text(visible = False) edit_btn = gr.Button("Edit")
generate_btn = gr.Button("Generate", visible=False)
add_to_queue_btn = gr.Button("Add New Prompt To Queue", visible=False) add_to_queue_btn = gr.Button("Add New Prompt To Queue", visible=False)
else:
generate_btn = gr.Button("Generate")
add_to_queue_btn = gr.Button("Add New Prompt To Queue", visible=False)
edit_btn = gr.Button("Save Edit", visible=False)
generate_trigger = gr.Text(visible = False)
add_to_queue_trigger = gr.Text(visible = False) add_to_queue_trigger = gr.Text(visible = False)
with gr.Column(visible= False) as current_gen_column: with gr.Column(visible= False) as current_gen_column:
@ -7985,7 +8112,7 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
video_guide_outpainting_checkbox.input(fn=refresh_video_guide_outpainting_row, inputs=[video_guide_outpainting_checkbox, video_guide_outpainting], outputs= [video_guide_outpainting_row,video_guide_outpainting]) video_guide_outpainting_checkbox.input(fn=refresh_video_guide_outpainting_row, inputs=[video_guide_outpainting_checkbox, video_guide_outpainting], outputs= [video_guide_outpainting_row,video_guide_outpainting])
show_advanced.change(fn=switch_advanced, inputs=[state, show_advanced, lset_name], outputs=[advanced_row, preset_buttons_rows, refresh_lora_btn, refresh2_row ,lset_name]).then( show_advanced.change(fn=switch_advanced, inputs=[state, show_advanced, lset_name], outputs=[advanced_row, preset_buttons_rows, refresh_lora_btn, refresh2_row ,lset_name]).then(
fn=switch_prompt_type, inputs = [state, wizard_prompt_activated_var, wizard_variables_var, prompt, wizard_prompt, *prompt_vars], outputs = [wizard_prompt_activated_var, wizard_variables_var, prompt, wizard_prompt, prompt_column_advanced, prompt_column_wizard, prompt_column_wizard_vars, *prompt_vars]) fn=switch_prompt_type, inputs = [state, wizard_prompt_activated_var, wizard_variables_var, prompt, wizard_prompt, *prompt_vars], outputs = [wizard_prompt_activated_var, wizard_variables_var, prompt, wizard_prompt, prompt_column_advanced, prompt_column_wizard, prompt_column_wizard_vars, *prompt_vars])
queue_df.select( fn=handle_celll_selection, inputs=state, outputs=[queue_df, modal_image_display, modal_container]) queue_df.select( fn=handle_celll_selection, inputs=state, outputs=[queue_df, modal_image_display, modal_container, main_tabs])
gr.on( triggers=[output.change, output.select], fn=select_video, inputs=[state, output], outputs=[last_choice, video_info, video_buttons_row, image_buttons_row, video_postprocessing_tab, audio_remuxing_tab]) gr.on( triggers=[output.change, output.select], fn=select_video, inputs=[state, output], outputs=[last_choice, video_info, video_buttons_row, image_buttons_row, video_postprocessing_tab, audio_remuxing_tab])
preview_trigger.change(refresh_preview, inputs= [state], outputs= [preview]) preview_trigger.change(refresh_preview, inputs= [state], outputs= [preview])
PP_MMAudio_setting.change(fn = lambda value : [gr.update(visible = value == 1), gr.update(visible = value == 0)] , inputs = [PP_MMAudio_setting], outputs = [PP_MMAudio_row, PP_custom_audio_row] ) PP_MMAudio_setting.change(fn = lambda value : [gr.update(visible = value == 1), gr.update(visible = value == 0)] , inputs = [PP_MMAudio_setting], outputs = [PP_MMAudio_row, PP_custom_audio_row] )
@ -8125,6 +8252,18 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
fn = fill_wizard_prompt, inputs = [state, wizard_prompt_activated_var, prompt, wizard_prompt], outputs = [ wizard_prompt_activated_var, wizard_variables_var, prompt, wizard_prompt, prompt_column_advanced, prompt_column_wizard, prompt_column_wizard_vars, *prompt_vars] fn = fill_wizard_prompt, inputs = [state, wizard_prompt_activated_var, prompt, wizard_prompt], outputs = [ wizard_prompt_activated_var, wizard_variables_var, prompt, wizard_prompt, prompt_column_advanced, prompt_column_wizard, prompt_column_wizard_vars, *prompt_vars]
) )
if tab_id == 'edit':
edit_inputs_names = list(inspect.signature(edit_task_in_queue).parameters)[:-1]
edit_inputs_components = [locals_dict[k] for k in edit_inputs_names]
edit_btn.click(
fn=validate_wizard_prompt,
inputs=[state, wizard_prompt_activated_var, wizard_variables_var, prompt, wizard_prompt, *prompt_vars],
outputs=[prompt]
).then(
fn=edit_task_in_queue,
inputs=edit_inputs_components + [state],
outputs=[queue_df, main_tabs]
)
refresh_form_trigger.change(fn= fill_inputs, refresh_form_trigger.change(fn= fill_inputs,
inputs=[state], inputs=[state],
@ -8152,7 +8291,27 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
inputs=[state], inputs=[state],
outputs=[gen_status]) outputs=[gen_status])
if tab_id == 'generate':
generate_btn.click(fn = init_generate, inputs = [state, output, last_choice], outputs=[generate_trigger, mode]) generate_btn.click(fn = init_generate, inputs = [state, output, last_choice], outputs=[generate_trigger, mode])
add_to_queue_btn.click(fn = lambda : (get_unique_id(), ""), inputs = None, outputs=[add_to_queue_trigger, mode])
# gr.on(triggers=[add_to_queue_btn.click, add_to_queue_trigger.change],fn=validate_wizard_prompt,
add_to_queue_trigger.change(fn=validate_wizard_prompt,
inputs =[state, wizard_prompt_activated_var, wizard_variables_var, prompt, wizard_prompt, *prompt_vars] ,
outputs= [prompt]
).then(fn=save_inputs,
inputs =[target_state] + gen_inputs,
outputs= None
).then(fn=process_prompt_and_add_tasks,
inputs = [state, model_choice],
outputs=queue_df
).then(
fn=lambda s: gr.Accordion(open=True) if len(get_gen_info(s).get("queue", [])) > 1 else gr.update(),
inputs=[state],
outputs=[queue_accordion]
).then(
fn=update_status,
inputs = [state],
)
generate_trigger.change(fn=validate_wizard_prompt, generate_trigger.change(fn=validate_wizard_prompt,
inputs= [state, wizard_prompt_activated_var, wizard_variables_var, prompt, wizard_prompt, *prompt_vars] , inputs= [state, wizard_prompt_activated_var, wizard_variables_var, prompt, wizard_prompt, *prompt_vars] ,
@ -8277,33 +8436,17 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
outputs=[current_gen_column, queue_accordion] outputs=[current_gen_column, queue_accordion]
) )
add_to_queue_btn.click(fn = lambda : (get_unique_id(), ""), inputs = None, outputs=[add_to_queue_trigger, mode])
# gr.on(triggers=[add_to_queue_btn.click, add_to_queue_trigger.change],fn=validate_wizard_prompt,
add_to_queue_trigger.change(fn=validate_wizard_prompt,
inputs =[state, wizard_prompt_activated_var, wizard_variables_var, prompt, wizard_prompt, *prompt_vars] ,
outputs= [prompt]
).then(fn=save_inputs,
inputs =[target_state] + gen_inputs,
outputs= None
).then(fn=process_prompt_and_add_tasks,
inputs = [state, model_choice],
outputs=queue_df
).then(
fn=lambda s: gr.Accordion(open=True) if len(get_gen_info(s).get("queue", [])) > 1 else gr.update(),
inputs=[state],
outputs=[queue_accordion]
).then(
fn=update_status,
inputs = [state],
)
close_modal_button.click( close_modal_button.click(
lambda: gr.update(visible=False), lambda: gr.update(visible=False),
inputs=[], inputs=[],
outputs=[modal_container] outputs=[modal_container]
) )
if tab_id == 'edit':
locals_dict = locals()
gen_inputs = [locals_dict[k] for k in inputs_names] + [state] + extra_inputs
return gen_inputs
else:
return ( state, loras_choices, lset_name, resolution, return ( state, loras_choices, lset_name, resolution,
video_guide, image_guide, video_mask, image_mask, image_refs, refresh_form_trigger video_guide, image_guide, video_mask, image_mask, image_refs, refresh_form_trigger
) )
@ -9268,7 +9411,52 @@ def create_ui():
with gr.Row(): with gr.Row():
( state, loras_choices, lset_name, resolution, ( state, loras_choices, lset_name, resolution,
video_guide, image_guide, video_mask, image_mask, image_refs, refresh_form_trigger video_guide, image_guide, video_mask, image_mask, image_refs, refresh_form_trigger
) = generate_video_tab(model_family=model_family, model_choice=model_choice, header=header, main = main, main_tabs =main_tabs) ) = generate_video_tab(model_family=model_family, model_choice=model_choice, header=header, main = main, main_tabs =main_tabs, tab_id='generate')
with gr.Tab("Edit", id="edit") as edit_tab:
gr.Markdown("<div align=center><H2>Edit Task</H2></div>")
edit_tab_components = generate_video_tab(
update_form=False,
state_dict=state.value,
ui_defaults=get_default_settings(transformer_type),
model_family=model_family,
model_choice=model_choice,
header=header,
main=main,
main_tabs=main_tabs,
tab_id='edit'
)
def fill_inputs_for_edit(state):
editing_task_index = state.get("editing_task_index", None)
if editing_task_index is None:
return [gr.update()] * len(edit_tab_components)
gen = get_gen_info(state)
queue = gen.get("queue", [])
task_to_edit_index = editing_task_index + 1
if task_to_edit_index >= len(queue):
gr.Warning("Task to edit not found in queue.")
state["editing_task_index"] = None
return [gr.update()] * len(edit_tab_components)
task = queue[task_to_edit_index]
ui_defaults = task['params']
image_list_keys = ['image_start', 'image_end', 'image_refs']
for key in image_list_keys:
value = ui_defaults.get(key)
if value is not None and not isinstance(value, list):
ui_defaults[key] = [value]
if ui_defaults.get('model_type') != state["model_type"]:
gr.Warning(f"Editing a task for a different model ({ui_defaults.get('model_type')}). Some settings may not apply.")
return generate_video_tab(update_form=True, state_dict=state, ui_defaults=ui_defaults)
edit_tab.select(
fn=fill_inputs_for_edit,
inputs=[state],
outputs=edit_tab_components
)
with gr.Tab("Guides", id="info") as info_tab: with gr.Tab("Guides", id="info") as info_tab:
generate_info_tab() generate_info_tab()
with gr.Tab("Video Mask Creator", id="video_mask_creator") as video_mask_creator: with gr.Tab("Video Mask Creator", id="video_mask_creator") as video_mask_creator: