From cfa7c776c68f4a25ea8c0db471f2f4e747c87896 Mon Sep 17 00:00:00 2001
From: Chris Malone
Date: Mon, 8 Sep 2025 04:30:58 +1000
Subject: [PATCH] initial commit

---
 wgp.py | 284 +++++++++++++++++++++++++++++++++++++++++++++++----------
 1 file changed, 236 insertions(+), 48 deletions(-)

diff --git a/wgp.py b/wgp.py
index 9bbf1ff..5592a27 100644
--- a/wgp.py
+++ b/wgp.py
@@ -195,7 +195,116 @@ def clean_image_list(gradio_list):
     gradio_list = [ convert_image( Image.open(img) if isinstance(img, str) else img ) for img in gradio_list ]
     return gradio_list
 
+def edit_task_in_queue(
+    lset_name,
+    image_mode,
+    prompt,
+    negative_prompt,
+    resolution,
+    video_length,
+    batch_size,
+    seed,
+    force_fps,
+    num_inference_steps,
+    guidance_scale,
+    guidance2_scale,
+    guidance3_scale,
+    switch_threshold,
+    switch_threshold2,
+    guidance_phases,
+    model_switch_phase,
+    audio_guidance_scale,
+    flow_shift,
+    sample_solver,
+    embedded_guidance_scale,
+    repeat_generation,
+    multi_prompts_gen_type,
+    multi_images_gen_type,
+    skip_steps_cache_type,
+    skip_steps_multiplier,
+    skip_steps_start_step_perc,
+    loras_choices,
+    loras_multipliers,
+    image_prompt_type,
+    image_start,
+    image_end,
+    model_mode,
+    video_source,
+    keep_frames_video_source,
+    video_guide_outpainting,
+    video_prompt_type,
+    image_refs,
+    frames_positions,
+    video_guide,
+    image_guide,
+    keep_frames_video_guide,
+    denoising_strength,
+    video_mask,
+    image_mask,
+    control_net_weight,
+    control_net_weight2,
+    mask_expand,
+    audio_guide,
+    audio_guide2,
+    audio_source,
+    audio_prompt_type,
+    speakers_locations,
+    sliding_window_size,
+    sliding_window_overlap,
+    sliding_window_color_correction_strength,
+    sliding_window_overlap_noise,
+    sliding_window_discard_last_frames,
+    image_refs_relative_size,
+    remove_background_images_ref,
+    temporal_upsampling,
+    spatial_upsampling,
+    film_grain_intensity,
+    film_grain_saturation,
+    MMAudio_setting,
+    MMAudio_prompt,
+    MMAudio_neg_prompt,
+    RIFLEx_setting,
+    NAG_scale,
+    NAG_tau,
+    NAG_alpha,
+    slg_switch,
+    slg_layers,
+    slg_start_perc,
+    slg_end_perc,
+    apg_switch,
+    cfg_star_switch,
+    cfg_zero_step,
+    prompt_enhancer,
+    min_frames_if_references,
+    override_profile,
+    mode,
+    state,
+):
+    inputs = get_function_arguments(edit_task_in_queue, locals())
+
+    gen = get_gen_info(state)
+    queue = gen.get("queue", [])
+    editing_task_index = state.get("editing_task_index", None)
+    if editing_task_index is None:
+        gr.Warning("No task selected for editing.")
+        return update_queue_data(queue), gr.Tabs(selected="video_gen")
+
+    task_to_edit_index = editing_task_index + 1
+    if task_to_edit_index >= len(queue):
+        gr.Warning("Task index out of bounds. Cannot edit.")
+        return update_queue_data(queue), gr.Tabs(selected="video_gen")
+
+    task_to_edit = queue[task_to_edit_index]
+    inputs['state'] = state
+    task_to_edit['params'] = inputs.copy()
+    task_to_edit['prompt'] = inputs.get('prompt')
+    task_to_edit['length'] = inputs.get('video_length')
+    task_to_edit['steps'] = inputs.get('num_inference_steps')
+    update_task_thumbnails(task_to_edit, inputs)
+    gr.Info(f"Task ID {task_to_edit['id']} has been updated successfully.")
+    state["editing_task_index"] = None
+    return update_queue_data(queue), gr.Tabs(selected="video_gen")
 
 
 def process_prompt_and_add_tasks(state, model_choice):
@@ -6485,27 +6594,41 @@ def handle_celll_selection(state, evt: gr.SelectData):
     queue = gen.get("queue", [])
     if evt.index is None:
-        return gr.update(), gr.update(), gr.update(visible=False)
+        return gr.update(), gr.update(), gr.update(visible=False), gr.update()
     row_index, col_index = evt.index
-    cell_value = None
+
+    if col_index == 1:
+        state["editing_task_index"] = row_index
+        task_to_edit_index = row_index + 1
+
+        if task_to_edit_index < len(queue):
+            task_data = queue[task_to_edit_index]
+            gr.Info(f"Loading task '{task_data['prompt'][:50]}...' for editing.")
+            return gr.update(), gr.update(), gr.update(visible=False), gr.Tabs(selected="edit")
+        else:
+            gr.Warning("Task index out of bounds.")
+            return gr.update(), gr.update(), gr.update(visible=False), gr.update()
+
     if col_index in [6, 7, 8]:
         if col_index == 6: cell_value = "↑"
         elif col_index == 7: cell_value = "↓"
         elif col_index == 8: cell_value = "✖"
-        if col_index == 6:
-            new_df_data = move_up(queue, [row_index])
-            return new_df_data, gr.update(), gr.update(visible=False)
-        elif col_index == 7:
-            new_df_data = move_down(queue, [row_index])
-            return new_df_data, gr.update(), gr.update(visible=False)
-        elif col_index == 8:
-            new_df_data = remove_task(queue, [row_index])
-            gen["prompts_max"] = gen.get("prompts_max",0) - 1
-            update_status(state)
-            return new_df_data, gr.update(), gr.update(visible=False)
+        if col_index == 6:
+            new_df_data = move_up(queue, [row_index])
+            return new_df_data, gr.update(), gr.update(visible=False), gr.update()
+        elif col_index == 7:
+            new_df_data = move_down(queue, [row_index])
+            return new_df_data, gr.update(), gr.update(visible=False), gr.update()
+        elif col_index == 8:
+            new_df_data = remove_task(queue, [row_index])
+            gen["prompts_max"] = gen.get("prompts_max",0) - 1
+            update_status(state)
+            return new_df_data, gr.update(), gr.update(visible=False), gr.update()
+
     start_img_col_idx = 4
     end_img_col_idx = 5
     image_data_to_show = None
+    names = []
     if col_index == start_img_col_idx:
         with lock:
             row_index += 1
@@ -6518,13 +6641,11 @@ def handle_celll_selection(state, evt: gr.SelectData):
             if row_index < len(queue):
                 image_data_to_show = queue[row_index].get('end_image_data_base64')
                 names = queue[row_index].get('end_image_labels')
-
     if image_data_to_show:
         value = get_modal_image( image_data_to_show[0], names[0])
-        return gr.update(), gr.update(value=value), gr.update(visible=True)
+        return gr.update(), gr.update(value=value), gr.update(visible=True), gr.update()
     else:
-        return gr.update(), gr.update(), gr.update(visible=False)
-
+        return gr.update(), gr.update(), gr.update(visible=False), gr.update()
 
 def change_model(state, model_choice):
     if model_choice == None:
@@ -6922,7 +7043,7 @@ def refresh_video_length_label(state, current_video_length, force_fps, video_gui
     computed_fps = get_computed_fps(force_fps, base_model_type , video_guide, video_source )
     return gr.update(label= compute_video_length_label(computed_fps, current_video_length))
 
-def generate_video_tab(update_form = False, state_dict = None, ui_defaults = None, model_family = None, model_choice = None, header = None, main = None, main_tabs= None):
+def generate_video_tab(update_form = False, state_dict = None, ui_defaults = None, model_family = None, model_choice = None, header = None, main = None, main_tabs= None, tab_id='generate'):
     global inputs_names #, advanced
 
     if update_form:
@@ -7898,9 +8019,15 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
                         video_info_add_videos_btn = gr.Button("Add Videos / Images", size ="sm")
 
             if not update_form:
-                generate_btn = gr.Button("Generate")
-                generate_trigger = gr.Text(visible = False)
-                add_to_queue_btn = gr.Button("Add New Prompt To Queue", visible = False)
+                if tab_id == 'edit':
+                    edit_btn = gr.Button("Edit")
+                    generate_btn = gr.Button("Generate", visible=False)
+                    add_to_queue_btn = gr.Button("Add New Prompt To Queue", visible=False)
+                else:
+                    generate_btn = gr.Button("Generate")
+                    add_to_queue_btn = gr.Button("Add New Prompt To Queue", visible=False)
+                    edit_btn = gr.Button("Save Edit", visible=False)
+                generate_trigger = gr.Text(visible = False)
                 add_to_queue_trigger = gr.Text(visible = False)
 
         with gr.Column(visible= False) as current_gen_column:
@@ -7985,7 +8112,7 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
         video_guide_outpainting_checkbox.input(fn=refresh_video_guide_outpainting_row, inputs=[video_guide_outpainting_checkbox, video_guide_outpainting], outputs= [video_guide_outpainting_row,video_guide_outpainting])
         show_advanced.change(fn=switch_advanced, inputs=[state, show_advanced, lset_name], outputs=[advanced_row, preset_buttons_rows, refresh_lora_btn, refresh2_row ,lset_name]).then(
             fn=switch_prompt_type, inputs = [state, wizard_prompt_activated_var, wizard_variables_var, prompt, wizard_prompt, *prompt_vars], outputs = [wizard_prompt_activated_var, wizard_variables_var, prompt, wizard_prompt, prompt_column_advanced, prompt_column_wizard, prompt_column_wizard_vars, *prompt_vars])
-        queue_df.select( fn=handle_celll_selection, inputs=state, outputs=[queue_df, modal_image_display, modal_container])
+        queue_df.select( fn=handle_celll_selection, inputs=state, outputs=[queue_df, modal_image_display, modal_container, main_tabs])
         gr.on( triggers=[output.change, output.select], fn=select_video, inputs=[state, output], outputs=[last_choice, video_info, video_buttons_row, image_buttons_row, video_postprocessing_tab, audio_remuxing_tab])
         preview_trigger.change(refresh_preview, inputs= [state], outputs= [preview])
         PP_MMAudio_setting.change(fn = lambda value : [gr.update(visible = value == 1), gr.update(visible = value == 0)] , inputs = [PP_MMAudio_setting], outputs = [PP_MMAudio_row, PP_custom_audio_row] )
@@ -8125,6 +8252,18 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
            fn = fill_wizard_prompt,
            inputs = [state, wizard_prompt_activated_var, prompt, wizard_prompt],
            outputs = [ wizard_prompt_activated_var, wizard_variables_var, prompt, wizard_prompt, prompt_column_advanced, prompt_column_wizard, prompt_column_wizard_vars, *prompt_vars]
        )
+        if tab_id == 'edit':
+            edit_inputs_names = list(inspect.signature(edit_task_in_queue).parameters)[:-1]
+            edit_inputs_components = [locals_dict[k] for k in edit_inputs_names]
+            edit_btn.click(
+                fn=validate_wizard_prompt,
+                inputs=[state, wizard_prompt_activated_var, wizard_variables_var, prompt, wizard_prompt, *prompt_vars],
+                outputs=[prompt]
+            ).then(
+                fn=edit_task_in_queue,
+                inputs=edit_inputs_components + [state],
+                outputs=[queue_df, main_tabs]
+            )
        refresh_form_trigger.change(fn= fill_inputs,
            inputs=[state],
@@ -8152,7 +8291,27 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
            inputs=[state],
            outputs=[gen_status])
-        generate_btn.click(fn = init_generate, inputs = [state, output, last_choice], outputs=[generate_trigger, mode])
+        if tab_id == 'generate':
+            generate_btn.click(fn = init_generate, inputs = [state, output, last_choice], outputs=[generate_trigger, mode])
+            add_to_queue_btn.click(fn = lambda : (get_unique_id(), ""), inputs = None, outputs=[add_to_queue_trigger, mode])
+            # gr.on(triggers=[add_to_queue_btn.click, add_to_queue_trigger.change],fn=validate_wizard_prompt,
+            add_to_queue_trigger.change(fn=validate_wizard_prompt,
+                inputs =[state, wizard_prompt_activated_var, wizard_variables_var, prompt, wizard_prompt, *prompt_vars] ,
+                outputs= [prompt]
+            ).then(fn=save_inputs,
+                inputs =[target_state] + gen_inputs,
+                outputs= None
+            ).then(fn=process_prompt_and_add_tasks,
+                inputs = [state, model_choice],
+                outputs=queue_df
+            ).then(
+                fn=lambda s: gr.Accordion(open=True) if len(get_gen_info(s).get("queue", [])) > 1 else gr.update(),
+                inputs=[state],
+                outputs=[queue_accordion]
+            ).then(
+                fn=update_status,
+                inputs = [state],
+            )
 
        generate_trigger.change(fn=validate_wizard_prompt,
            inputs= [state, wizard_prompt_activated_var, wizard_variables_var, prompt, wizard_prompt, *prompt_vars] ,
            outputs= [prompt]
@@ -8277,36 +8436,20 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
            outputs=[current_gen_column, queue_accordion]
        )
-
-        add_to_queue_btn.click(fn = lambda : (get_unique_id(), ""), inputs = None, outputs=[add_to_queue_trigger, mode])
-        # gr.on(triggers=[add_to_queue_btn.click, add_to_queue_trigger.change],fn=validate_wizard_prompt,
-        add_to_queue_trigger.change(fn=validate_wizard_prompt,
-            inputs =[state, wizard_prompt_activated_var, wizard_variables_var, prompt, wizard_prompt, *prompt_vars] ,
-            outputs= [prompt]
-        ).then(fn=save_inputs,
-            inputs =[target_state] + gen_inputs,
-            outputs= None
-        ).then(fn=process_prompt_and_add_tasks,
-            inputs = [state, model_choice],
-            outputs=queue_df
-        ).then(
-            fn=lambda s: gr.Accordion(open=True) if len(get_gen_info(s).get("queue", [])) > 1 else gr.update(),
-            inputs=[state],
-            outputs=[queue_accordion]
-        ).then(
-            fn=update_status,
-            inputs = [state],
-        )
-
        close_modal_button.click(
            lambda: gr.update(visible=False),
            inputs=[],
            outputs=[modal_container]
        )
-        return ( state, loras_choices, lset_name, resolution,
-                 video_guide, image_guide, video_mask, image_mask, image_refs, refresh_form_trigger
-               )
+        if tab_id == 'edit':
+            locals_dict = locals()
+            gen_inputs = [locals_dict[k] for k in inputs_names] + [state] + extra_inputs
+            return gen_inputs
+        else:
+            return ( state, loras_choices, lset_name, resolution,
+                     video_guide, image_guide, video_mask, image_mask, image_refs, refresh_form_trigger
+                   )
 
 
 def generate_download_tab(lset_name,loras_choices, state):
@@ -9268,7 +9411,52 @@ def create_ui():
            with gr.Row():
                ( state, loras_choices, lset_name, resolution,
                  video_guide, image_guide, video_mask, image_mask, image_refs, refresh_form_trigger
-                ) = generate_video_tab(model_family=model_family, model_choice=model_choice, header=header, main = main, main_tabs =main_tabs)
+                ) = generate_video_tab(model_family=model_family, model_choice=model_choice, header=header, main = main, main_tabs =main_tabs, tab_id='generate')
+            with gr.Tab("Edit", id="edit") as edit_tab:
+                gr.Markdown("Edit Task")
+                edit_tab_components = generate_video_tab(
+                    update_form=False,
+                    state_dict=state.value,
+                    ui_defaults=get_default_settings(transformer_type),
+                    model_family=model_family,
+                    model_choice=model_choice,
+                    header=header,
+                    main=main,
+                    main_tabs=main_tabs,
+                    tab_id='edit'
+                )
+
+                def fill_inputs_for_edit(state):
+                    editing_task_index = state.get("editing_task_index", None)
+                    if editing_task_index is None:
+                        return [gr.update()] * len(edit_tab_components)
+
+                    gen = get_gen_info(state)
+                    queue = gen.get("queue", [])
+                    task_to_edit_index = editing_task_index + 1
+
+                    if task_to_edit_index >= len(queue):
+                        gr.Warning("Task to edit not found in queue.")
+                        state["editing_task_index"] = None
+                        return [gr.update()] * len(edit_tab_components)
+
+                    task = queue[task_to_edit_index]
+                    ui_defaults = task['params']
+
+                    image_list_keys = ['image_start', 'image_end', 'image_refs']
+                    for key in image_list_keys:
+                        value = ui_defaults.get(key)
+                        if value is not None and not isinstance(value, list):
+                            ui_defaults[key] = [value]
+
+                    if ui_defaults.get('model_type') != state["model_type"]:
+                        gr.Warning(f"Editing a task for a different model ({ui_defaults.get('model_type')}). Some settings may not apply.")
+
+                    return generate_video_tab(update_form=True, state_dict=state, ui_defaults=ui_defaults)
+
+                edit_tab.select(
+                    fn=fill_inputs_for_edit,
+                    inputs=[state],
+                    outputs=edit_tab_components
+                )
            with gr.Tab("Guides", id="info") as info_tab:
                generate_info_tab()
            with gr.Tab("Video Mask Creator", id="video_mask_creator") as video_mask_creator: