diff --git a/wgp.py b/wgp.py index a2a468f..5dd3633 100644 --- a/wgp.py +++ b/wgp.py @@ -756,6 +756,28 @@ def autoload_queue(state): return dataframe_update, loaded_flag, state +def run_autoload_and_prepare_ui(current_state): + df_update, loaded_flag, modified_state = autoload_queue(current_state) + should_start_processing = loaded_flag + accordion_update = gr.Accordion(open=True) if loaded_flag else gr.update() + return df_update, gr.update(visible=loaded_flag), accordion_update, should_start_processing, modified_state + +def start_processing_if_needed(should_start, current_state): + if not isinstance(current_state, dict) or 'gen' not in current_state: + yield "Error: Invalid state received before processing." + return + if should_start: + yield from process_tasks(current_state) + else: + yield None + +def finalize_generation_with_state(current_state): + if not isinstance(current_state, dict) or 'gen' not in current_state: + return gr.update(), gr.update(interactive=True), gr.update(visible=True), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False, value=""), gr.update(), current_state + + gallery_update, abort_btn_update, gen_btn_update, add_queue_btn_update, current_gen_col_update, gen_info_update = finalize_generation(current_state) + accordion_update = gr.Accordion(open=False) if len(get_gen_info(current_state).get("queue", [])) <= 1 else gr.update() + return gallery_update, abort_btn_update, gen_btn_update, add_queue_btn_update, current_gen_col_update, gen_info_update, accordion_update, current_state def get_queue_table(queue): data = [] @@ -3503,6 +3525,7 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non js=trigger_zip_download_js ) + should_start_flag = gr.State(False) load_queue_btn.upload( fn=load_queue_action, inputs=[load_queue_btn, state], @@ -3511,6 +3534,28 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non fn=lambda s: (gr.update(visible=bool(get_gen_info(s).get("queue",[]))), gr.Accordion(open=True)) if bool(get_gen_info(s).get("queue",[])) else (gr.update(visible=False), gr.update()), inputs=[state], outputs=[current_gen_column, queue_accordion] + ).then( + fn=lambda s: ( + (gr.Button(visible=False), gr.Button(visible=True), gr.Column(visible=True), True) + if bool(get_gen_info(s).get("queue",[])) + else (gr.Button(visible=True), gr.Button(visible=False), gr.Column(visible=False), False) + ), + inputs=[state], + outputs=[generate_btn, add_to_queue_btn, current_gen_column, should_start_flag] + ).then( + fn=start_processing_if_needed, + inputs=[should_start_flag, state], + outputs=[gen_status], + trigger_mode="once" + ).then( + fn=finalize_generation_with_state, + inputs=[state], + outputs=[output, abort_btn, generate_btn, add_to_queue_btn, current_gen_column, gen_info, queue_accordion, state], + trigger_mode="always_last" + ).then( + unload_model_if_needed, + inputs= [state], + outputs= [] ) clear_queue_btn.click( @@ -4139,28 +4184,6 @@ def create_demo(): generate_about_tab() should_start_flag = gr.State(False) - def run_autoload_and_prepare_ui(current_state): - df_update, loaded_flag, modified_state = autoload_queue(current_state) - should_start_processing = loaded_flag - accordion_update = gr.Accordion(open=True) if loaded_flag else gr.update() - return df_update, gr.update(visible=loaded_flag), accordion_update, should_start_processing, modified_state - - def start_processing_if_needed(should_start, current_state): - if not isinstance(current_state, dict) or 'gen' not in current_state: - yield "Error: Invalid state received before processing." - return - if should_start: - yield from process_tasks(current_state) - else: - yield None - - def finalize_generation_with_state(current_state): - if not isinstance(current_state, dict) or 'gen' not in current_state: - return gr.update(), gr.update(interactive=True), gr.update(visible=True), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False, value=""), gr.update(), current_state - - gallery_update, abort_btn_update, gen_btn_update, add_queue_btn_update, current_gen_col_update, gen_info_update = finalize_generation(current_state) - accordion_update = gr.Accordion(open=False) if len(get_gen_info(current_state).get("queue", [])) <= 1 else gr.update() - return gallery_update, abort_btn_update, gen_btn_update, add_queue_btn_update, current_gen_col_update, gen_info_update, accordion_update, current_state demo.load( fn=run_autoload_and_prepare_ui,