From 452d246c8819c4859f1de94148a96de5c51c86d8 Mon Sep 17 00:00:00 2001 From: Chris Malone Date: Thu, 10 Apr 2025 04:22:35 +1000 Subject: [PATCH 1/4] move load/save/clear buttons to accordion --- wgp.py | 49 ++++++++++++++++++++++++++++++------------------- 1 file changed, 30 insertions(+), 19 deletions(-) diff --git a/wgp.py b/wgp.py index 97cd830..785fff2 100644 --- a/wgp.py +++ b/wgp.py @@ -3479,7 +3479,7 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non with gr.Row(): onemore_btn = gr.Button("One More Sample Please !") abort_btn = gr.Button("Abort") - + with gr.Accordion("Queue Management", open=False) as queue_accordion: queue_df = gr.DataFrame( headers=["Qty","Prompt", "Length","Steps","", "", "", "", ""], datatype=[ "str","markdown","str", "markdown", "markdown", "markdown", "str", "str", "str"], @@ -3489,14 +3489,14 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non wrap=True, value=[], line_breaks= True, - visible= False, + visible= True, elem_id="queue_df" ) - with gr.Row(): - queue_zip_base64_output = gr.Text(visible=False) - save_queue_btn = gr.DownloadButton("Save Queue", size="sm") - load_queue_btn = gr.UploadButton("Load Queue", file_types=[".zip"], size="sm") - clear_queue_btn = gr.Button("Clear Queue", size="sm", variant="stop") + with gr.Row(): + queue_zip_base64_output = gr.Text(visible=False) + save_queue_btn = gr.DownloadButton("Save Queue", size="sm") + load_queue_btn = gr.UploadButton("Load Queue", file_types=[".zip"], size="sm") + clear_queue_btn = gr.Button("Clear Queue", size="sm", variant="stop") trigger_zip_download_js = """ (base64String) => { if (!base64String) { @@ -3544,9 +3544,9 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non inputs=[load_queue_btn, state], outputs=[queue_df] ).then( - fn=lambda s: gr.update(visible=bool(get_gen_info(s).get("queue",[]))), + fn=lambda s: (gr.update(visible=bool(get_gen_info(s).get("queue",[]))), gr.Accordion(open=True)) if bool(get_gen_info(s).get("queue",[])) else (gr.update(visible=False), gr.update()), inputs=[state], - outputs=[current_gen_column] + outputs=[current_gen_column, queue_accordion] ) clear_queue_btn.click( @@ -3554,9 +3554,9 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non inputs=[state], outputs=[queue_df] ).then( - fn=lambda: gr.update(visible=False), + fn=lambda: (gr.update(visible=False), gr.Accordion(open=False)), inputs=None, - outputs=[current_gen_column] + outputs=[current_gen_column, queue_accordion] ) extra_inputs = prompt_vars + [wizard_prompt, wizard_variables_var, wizard_prompt_activated_var, video_prompt_column, image_prompt_column, @@ -3639,6 +3639,10 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non ).then(finalize_generation, inputs= [state], outputs= [output, abort_btn, generate_btn, add_to_queue_btn, current_gen_column, gen_info] + ).then( + fn=lambda s: gr.Accordion(open=False) if len(get_gen_info(s).get("queue", [])) <= 1 else gr.update(), + inputs=[state], + outputs=[queue_accordion] ).then(unload_model_if_needed, inputs= [state], outputs= [] @@ -3653,6 +3657,10 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non ).then(fn=process_prompt_and_add_tasks, inputs = [state, model_choice], outputs=queue_df + ).then( + fn=lambda s: gr.Accordion(open=True) if len(get_gen_info(s).get("queue", [])) > 1 else gr.update(), # Expand if queue has items (len > 1 assumes placeholder) + inputs=[state], + outputs=[queue_accordion] ).then( fn=update_status, inputs = [state], @@ -3670,7 +3678,7 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non gen_info, prompt, wizard_prompt, wizard_prompt_activated_var, wizard_variables_var, prompt_column_advanced, prompt_column_wizard, prompt_column_wizard_vars, - advanced_row, image_prompt_column, video_prompt_column, + advanced_row, image_prompt_column, video_prompt_column, queue_accordion, *prompt_vars ) @@ -4153,7 +4161,7 @@ def create_demo(): gen_info, prompt, wizard_prompt, wizard_prompt_activated_var, wizard_variables_var, prompt_column_advanced, prompt_column_wizard, prompt_column_wizard_vars, - advanced_row, image_prompt_column, video_prompt_column, + advanced_row, image_prompt_column, video_prompt_column, queue_accordion, *prompt_vars_outputs ) = generate_video_tab(model_choice=model_choice, header=header) with gr.Tab("Informations"): @@ -4170,7 +4178,8 @@ def create_demo(): def run_autoload_and_prepare_ui(current_state): df_update, loaded_flag, modified_state = autoload_queue(current_state) should_start_processing = loaded_flag - return df_update, gr.update(visible=loaded_flag), should_start_processing, modified_state + accordion_update = gr.Accordion(open=True) if loaded_flag else gr.update() + return df_update, gr.update(visible=loaded_flag), accordion_update, should_start_processing, modified_state def start_processing_if_needed(should_start, current_state): if not isinstance(current_state, dict) or 'gen' not in current_state: @@ -4179,18 +4188,20 @@ def create_demo(): if should_start: yield from process_tasks(current_state) else: - yield "Autoload complete. Processing not started." + yield None def finalize_generation_with_state(current_state): if not isinstance(current_state, dict) or 'gen' not in current_state: - return gr.update(), gr.update(interactive=True), gr.update(visible=True), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False, value=""), current_state + return gr.update(), gr.update(interactive=True), gr.update(visible=True), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False, value=""), gr.update(), current_state + gallery_update, abort_btn_update, gen_btn_update, add_queue_btn_update, current_gen_col_update, gen_info_update = finalize_generation(current_state) - return gallery_update, abort_btn_update, gen_btn_update, add_queue_btn_update, current_gen_col_update, gen_info_update, current_state + accordion_update = gr.Accordion(open=False) if len(get_gen_info(current_state).get("queue", [])) <= 1 else gr.update() + return gallery_update, abort_btn_update, gen_btn_update, add_queue_btn_update, current_gen_col_update, gen_info_update, accordion_update, current_state demo.load( fn=run_autoload_and_prepare_ui, inputs=[state], - outputs=[queue_df, current_gen_column, should_start_flag, state] + outputs=[queue_df, current_gen_column, queue_accordion, should_start_flag, state] ).then( fn=start_processing_if_needed, inputs=[should_start_flag, state], @@ -4199,7 +4210,7 @@ def create_demo(): ).then( fn=finalize_generation_with_state, inputs=[state], - outputs=[output, abort_btn, generate_btn, add_to_queue_btn, current_gen_column, gen_info, state], + outputs=[output, abort_btn, generate_btn, add_to_queue_btn, current_gen_column, gen_info, queue_accordion, state], trigger_mode="always_last" ) From 965267d74004d71bd2b23a0ec87ce7aaff81b79c Mon Sep 17 00:00:00 2001 From: Chris Malone Date: Thu, 10 Apr 2025 04:37:27 +1000 Subject: [PATCH 2/4] clear queue aborts currently generating item --- wgp.py | 156 ++++++++++++++++++++++----------------------------------- 1 file changed, 60 insertions(+), 96 deletions(-) diff --git a/wgp.py b/wgp.py index 785fff2..a2a468f 100644 --- a/wgp.py +++ b/wgp.py @@ -362,24 +362,21 @@ def save_queue_action(state): gen = get_gen_info(state) queue = gen.get("queue", []) - if not queue or len(queue) <=1 : # Check if queue is empty or only has the placeholder + if not queue or len(queue) <=1 : gr.Info("Queue is empty. Nothing to save.") - return None # Return None if nothing to save + return None - # Use an in-memory buffer for the zip file zip_buffer = io.BytesIO() - # Still use a temporary directory *only* for storing images before zipping with tempfile.TemporaryDirectory() as tmpdir: queue_manifest = [] - image_paths_in_zip = {} # Tracks image PIL object ID -> filename in zip + image_paths_in_zip = {} for task_index, task in enumerate(queue): - # Skip the placeholder item if it exists if task is None or not isinstance(task, dict) or task_index == 0: continue params_copy = task.get('params', {}).copy() - task_id_s = task.get('id', f"task_{task_index}") # Use a different var name + task_id_s = task.get('id', f"task_{task_index}") image_keys = ["image_start", "image_end", "image_refs"] for key in image_keys: @@ -387,95 +384,71 @@ def save_queue_action(state): if images_pil is None: continue - # Ensure images_pil is always a list for processing is_originally_list = isinstance(images_pil, list) if not is_originally_list: images_pil = [images_pil] image_filenames_for_json = [] for img_index, pil_image in enumerate(images_pil): - # Ensure it's actually a PIL Image object before proceeding if not isinstance(pil_image, Image.Image): print(f"Warning: Expected PIL Image for key '{key}' in task {task_id_s}, got {type(pil_image)}. Skipping image.") continue - # Use object ID to check if this specific image instance is already saved img_id = id(pil_image) if img_id in image_paths_in_zip: - # If already saved, just add its filename to the list image_filenames_for_json.append(image_paths_in_zip[img_id]) - continue # Move to the next image in the list + continue - # Image not saved yet, create filename and save path img_filename_in_zip = f"task{task_id_s}_{key}_{img_index}.png" img_save_path = os.path.join(tmpdir, img_filename_in_zip) try: - # Save the image to the temporary directory pil_image.save(img_save_path, "PNG") image_filenames_for_json.append(img_filename_in_zip) - # Store the mapping from image ID to its filename in the zip image_paths_in_zip[img_id] = img_filename_in_zip except Exception as e: print(f"Error saving image {img_filename_in_zip} for task {task_id_s}: {e}") - # Optionally decide if you want to continue or fail here - # Update the params_copy with the list of filenames (or single filename) if image_filenames_for_json: params_copy[key] = image_filenames_for_json if is_originally_list else image_filenames_for_json[0] else: - # If no images were successfully processed for this key, remove it params_copy.pop(key, None) - # Clean up parameters before adding to manifest params_copy.pop('state', None) - params_copy.pop('start_image_data_base64', None) # Don't need base64 in saved queue + params_copy.pop('start_image_data_base64', None) params_copy.pop('end_image_data_base64', None) - # Also remove the actual PIL data if it somehow remained params_copy.pop('start_image_data', None) params_copy.pop('end_image_data', None) manifest_entry = { "id": task.get('id'), "params": params_copy, - # Keep other necessary top-level task info if needed, like repeats etc. - # Example: "repeats": task.get('repeats', 1) } queue_manifest.append(manifest_entry) - # --- Create queue.json content --- manifest_path = os.path.join(tmpdir, "queue.json") try: with open(manifest_path, 'w', encoding='utf-8') as f: - # Dump only the relevant manifest data json.dump(queue_manifest, f, indent=4) except Exception as e: print(f"Error writing queue.json: {e}") gr.Warning("Failed to create queue manifest.") - return None # Return None on failure + return None - # --- Create the zip file in memory --- try: with zipfile.ZipFile(zip_buffer, 'w', zipfile.ZIP_DEFLATED) as zf: - # Add queue.json zf.write(manifest_path, arcname="queue.json") - # Add all unique images that were saved to the temp dir for saved_img_rel_path in image_paths_in_zip.values(): saved_img_abs_path = os.path.join(tmpdir, saved_img_rel_path) if os.path.exists(saved_img_abs_path): zf.write(saved_img_abs_path, arcname=saved_img_rel_path) else: - # This shouldn't happen if saving was successful, but good to check print(f"Warning: Image file {saved_img_rel_path} not found during zipping.") - # --- Prepare for return --- - # Move buffer position to the beginning zip_buffer.seek(0) - # Read the binary content zip_binary_content = zip_buffer.getvalue() - # Encode as base64 string zip_base64 = base64.b64encode(zip_binary_content).decode('utf-8') print(f"Queue successfully prepared as base64 string ({len(zip_base64)} chars).") return zip_base64 @@ -483,18 +456,17 @@ def save_queue_action(state): except Exception as e: print(f"Error creating zip file in memory: {e}") gr.Warning("Failed to create zip data for download.") - return None # Return None on failure + return None finally: zip_buffer.close() def load_queue_action(filepath, state): global task_id gen = get_gen_info(state) - original_queue = gen.get("queue", []) # Store original queue for error case + original_queue = gen.get("queue", []) if not filepath or not hasattr(filepath, 'name') or not Path(filepath.name).is_file(): print("[load_queue_action] Warning: No valid file selected or file not found.") - # Return the current state of the DataFrame return update_queue_data(original_queue) newly_loaded_queue = [] @@ -518,7 +490,6 @@ def load_queue_action(filepath, state): print(f"[load_queue_action] Manifest loaded. Processing {len(loaded_manifest)} tasks.") for task_index, task_data in enumerate(loaded_manifest): - # (Keep the existing task processing logic here...) if task_data is None or not isinstance(task_data, dict): print(f"[load_queue_action] Skipping invalid task data at index {task_index}") continue @@ -528,7 +499,7 @@ def load_queue_action(filepath, state): max_id_in_file = max(max_id_in_file, task_id_loaded) loaded_pil_images = {} image_keys = ["image_start", "image_end", "image_refs"] - params['state'] = state # Add state back temporarily for consistency if needed by internal logic, but it's removed before saving + params['state'] = state for key in image_keys: image_filenames = params.get(key) @@ -544,26 +515,22 @@ def load_queue_action(filepath, state): continue try: pil_image = Image.open(img_load_path) - # Ensure the image data is loaded into memory before the temp dir is cleaned up pil_image.load() - # Convert image right after loading converted_image = convert_image(pil_image) loaded_pils.append(converted_image) - pil_image.close() # Close the file handle + pil_image.close() except Exception as img_e: print(f"[load_queue_action] Error loading image {img_filename_in_zip}: {img_e}") if loaded_pils: params[key] = loaded_pils if is_list else loaded_pils[0] - loaded_pil_images[key] = params[key] # Store loaded PILs for preview generation + loaded_pil_images[key] = params[key] else: params.pop(key, None) - # Generate preview base64 strings primary_preview_pil, secondary_preview_pil = None, None start_prev_pil_list = loaded_pil_images.get("image_start") end_prev_pil_list = loaded_pil_images.get("image_end") ref_prev_pil_list = loaded_pil_images.get("image_refs") - # Extract first image for preview if available if start_prev_pil_list: primary_preview_pil = start_prev_pil_list[0] if isinstance(start_prev_pil_list, list) and start_prev_pil_list else start_prev_pil_list if not isinstance(start_prev_pil_list, list) else None if end_prev_pil_list: @@ -571,97 +538,102 @@ def load_queue_action(filepath, state): elif ref_prev_pil_list and isinstance(ref_prev_pil_list, list) and ref_prev_pil_list: primary_preview_pil = ref_prev_pil_list[0] - # Generate base64 only if PIL image exists start_b64 = [pil_to_base64_uri(primary_preview_pil, format="jpeg", quality=70)] if primary_preview_pil else None end_b64 = [pil_to_base64_uri(secondary_preview_pil, format="jpeg", quality=70)] if secondary_preview_pil else None - # Get top-level image data (PIL objects) for runtime task top_level_start_image = loaded_pil_images.get("image_start") top_level_end_image = loaded_pil_images.get("image_end") - # Construct the runtime task dictionary runtime_task = { "id": task_id_loaded, - "params": params.copy(), # Use a copy of params - # Extract necessary params for top level if they exist + "params": params.copy(), "repeats": params.get('repeat_generation', 1), "length": params.get('video_length'), "steps": params.get('num_inference_steps'), "prompt": params.get('prompt'), - # Store the actual loaded PIL image data here "start_image_data": top_level_start_image, "end_image_data": top_level_end_image, - # Store base64 previews generated above "start_image_data_base64": start_b64, "end_image_data_base64": end_b64, } newly_loaded_queue.append(runtime_task) print(f"[load_queue_action] Processed task {task_index+1}/{len(loaded_manifest)}, ID: {task_id_loaded}") - # --- State Update --- with lock: print("[load_queue_action] Acquiring lock to update state...") - gen["queue"] = newly_loaded_queue[:] # Replace the queue in the state - local_queue_copy_for_global_ref = gen["queue"][:] # Copy for global ref update - current_max_id_in_new_queue = max([t['id'] for t in newly_loaded_queue if 'id' in t] + [0]) # Safer max ID calculation + gen["queue"] = newly_loaded_queue[:] + local_queue_copy_for_global_ref = gen["queue"][:] + current_max_id_in_new_queue = max([t['id'] for t in newly_loaded_queue if 'id' in t] + [0]) - # Update global task ID only if the loaded max ID is higher if current_max_id_in_new_queue > task_id: print(f"[load_queue_action] Updating global task_id from {task_id} to {current_max_id_in_new_queue + 1}") - task_id = current_max_id_in_new_queue + 1 # Ensure next ID is unique + task_id = current_max_id_in_new_queue + 1 else: print(f"[load_queue_action] Global task_id ({task_id}) is >= max in file ({current_max_id_in_new_queue}). Not changing task_id.") gen["prompts_max"] = len(newly_loaded_queue) print("[load_queue_action] State update complete. Releasing lock.") - # --- Global Reference Update --- if local_queue_copy_for_global_ref is not None: print("[load_queue_action] Updating global queue reference...") update_global_queue_ref(local_queue_copy_for_global_ref) else: - # This case should ideally not be reached if state update happens print("[load_queue_action] Warning: Skipping global ref update as local copy is None.") print(f"[load_queue_action] Queue load successful. Returning DataFrame update for {len(newly_loaded_queue)} tasks.") - # *** Return the DataFrame update object *** return update_queue_data(newly_loaded_queue) except (ValueError, zipfile.BadZipFile, FileNotFoundError, Exception) as e: error_message = f"Error during queue load: {e}" print(f"[load_queue_action] Caught error: {error_message}") traceback.print_exc() - # Optionally show a Gradio warning/error to the user - gr.Warning(f"Failed to load queue: {error_message[:200]}") # Show truncated error + gr.Warning(f"Failed to load queue: {error_message[:200]}") - # *** Return the DataFrame update for the original queue *** print("[load_queue_action] Load failed. Returning DataFrame update for original queue.") return update_queue_data(original_queue) finally: - # Clean up the uploaded file object if it exists and has a path if filepath and hasattr(filepath, 'name') and filepath.name and os.path.exists(filepath.name): try: - # Gradio often uses temp files, attempting removal is good practice - # os.remove(filepath.name) - # print(f"[load_queue_action] Cleaned up temporary upload file: {filepath.name}") - pass # Let Gradio manage its temp files unless specifically needed + pass except OSError as e: - # Ignore errors like "file not found" if already cleaned up print(f"[load_queue_action] Info: Could not remove temp file {filepath.name}: {e}") pass def clear_queue_action(state): gen = get_gen_info(state) queue = gen.get("queue", []) - if not queue: - gr.Info("Queue is already empty.") - return update_queue_data([]) + aborted_current = False + cleared_pending = False with lock: - queue.clear() - gen["prompts_max"] = 0 + if "in_progress" in gen and gen["in_progress"]: + print("Clear Queue: Signalling abort for in-progress task.") + gen["abort"] = True + gen["extra_orders"] = 0 + if wan_model is not None: + wan_model._interrupt = True + aborted_current = True + + if queue: + if len(queue) > 1 or (len(queue) == 1 and queue[0] is not None and queue[0].get('id') is not None): + print(f"Clear Queue: Clearing {len(queue)} tasks from queue.") + queue.clear() + cleared_pending = True + else: + pass + + if aborted_current or cleared_pending: + gen["prompts_max"] = 0 + + if aborted_current and cleared_pending: + gr.Info("Queue cleared and current generation aborted.") + elif aborted_current: + gr.Info("Current generation aborted.") + elif cleared_pending: + gr.Info("Queue cleared.") + else: + gr.Info("Queue is already empty or only contains the active task (which wasn't aborted now).") - gr.Info("Queue cleared.") return update_queue_data([]) def autosave_queue(): @@ -725,7 +697,7 @@ def autosave_queue(): if os.path.exists(saved_img_abs_path): zf.write(saved_img_abs_path, arcname=saved_img_rel_path) return output_filename - return None # Should not happen if queue has items + return None saved_path = _save_queue_to_file(global_queue_ref, AUTOSAVE_FILENAME) @@ -740,17 +712,15 @@ def autosave_queue(): def autoload_queue(state): global task_id - # Initial check using the original state try: - gen = get_gen_info(state) # Make sure initial state is a dict + gen = get_gen_info(state) original_queue = gen.get("queue", []) except AttributeError: print("[autoload_queue] Error: Initial state is not a dictionary. Cannot autoload.") - # Return default values indicating no load occurred and the state is unchanged - return gr.update(visible=False), False, state # Return an empty DF update + return gr.update(visible=False), False, state loaded_flag = False - dataframe_update = update_queue_data(original_queue) # Default update is the original queue + dataframe_update = update_queue_data(original_queue) if not original_queue and Path(AUTOSAVE_FILENAME).is_file(): print(f"Autoloading queue from {AUTOSAVE_FILENAME}...") @@ -758,38 +728,32 @@ def autoload_queue(state): def __init__(self, name): self.name = name mock_filepath = MockFile(AUTOSAVE_FILENAME) - - # Call load_queue_action, it modifies 'state' internally and returns a DataFrame update dataframe_update = load_queue_action(mock_filepath, state) - # Now check the 'state' dictionary which should have been modified by load_queue_action - gen = get_gen_info(state) # Use the (potentially) modified state dictionary + gen = get_gen_info(state) loaded_queue_after_action = gen.get("queue", []) - if loaded_queue_after_action: # Check if the queue in the state is now populated + if loaded_queue_after_action: print(f"Autoload successful. Loaded {len(loaded_queue_after_action)} tasks into state.") loaded_flag = True - # Global ref update was already done inside load_queue_action if successful else: print("Autoload attempted but queue in state remains empty (file might be empty or invalid).") - # Ensure state reflects empty queue if load failed but file existed with lock: gen["queue"] = [] gen["prompts_max"] = 0 update_global_queue_ref([]) - dataframe_update = update_queue_data([]) # Ensure UI shows empty queue + dataframe_update = update_queue_data([]) - else: # Handle cases where autoload shouldn't happen + else: if original_queue: print("Autoload skipped: Queue is not empty.") - update_global_queue_ref(original_queue) # Ensure global ref matches current state - dataframe_update = update_queue_data(original_queue) # UI should show current queue + update_global_queue_ref(original_queue) + dataframe_update = update_queue_data(original_queue) else: print(f"Autoload skipped: {AUTOSAVE_FILENAME} not found.") - update_global_queue_ref([]) # Ensure global ref is empty - dataframe_update = update_queue_data([]) # UI should show empty queue + update_global_queue_ref([]) + dataframe_update = update_queue_data([]) - # Return the DataFrame update needed for the UI, the flag, and the final state dictionary return dataframe_update, loaded_flag, state From 87614b8216173caac19b54a674fc66d85b153ede Mon Sep 17 00:00:00 2001 From: Chris Malone Date: Thu, 10 Apr 2025 04:53:18 +1000 Subject: [PATCH 3/4] fix manual load queue --- wgp.py | 67 +++++++++++++++++++++++++++++++++++++++------------------- 1 file changed, 45 insertions(+), 22 deletions(-) diff --git a/wgp.py b/wgp.py index a2a468f..5dd3633 100644 --- a/wgp.py +++ b/wgp.py @@ -756,6 +756,28 @@ def autoload_queue(state): return dataframe_update, loaded_flag, state +def run_autoload_and_prepare_ui(current_state): + df_update, loaded_flag, modified_state = autoload_queue(current_state) + should_start_processing = loaded_flag + accordion_update = gr.Accordion(open=True) if loaded_flag else gr.update() + return df_update, gr.update(visible=loaded_flag), accordion_update, should_start_processing, modified_state + +def start_processing_if_needed(should_start, current_state): + if not isinstance(current_state, dict) or 'gen' not in current_state: + yield "Error: Invalid state received before processing." + return + if should_start: + yield from process_tasks(current_state) + else: + yield None + +def finalize_generation_with_state(current_state): + if not isinstance(current_state, dict) or 'gen' not in current_state: + return gr.update(), gr.update(interactive=True), gr.update(visible=True), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False, value=""), gr.update(), current_state + + gallery_update, abort_btn_update, gen_btn_update, add_queue_btn_update, current_gen_col_update, gen_info_update = finalize_generation(current_state) + accordion_update = gr.Accordion(open=False) if len(get_gen_info(current_state).get("queue", [])) <= 1 else gr.update() + return gallery_update, abort_btn_update, gen_btn_update, add_queue_btn_update, current_gen_col_update, gen_info_update, accordion_update, current_state def get_queue_table(queue): data = [] @@ -3503,6 +3525,7 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non js=trigger_zip_download_js ) + should_start_flag = gr.State(False) load_queue_btn.upload( fn=load_queue_action, inputs=[load_queue_btn, state], @@ -3511,6 +3534,28 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non fn=lambda s: (gr.update(visible=bool(get_gen_info(s).get("queue",[]))), gr.Accordion(open=True)) if bool(get_gen_info(s).get("queue",[])) else (gr.update(visible=False), gr.update()), inputs=[state], outputs=[current_gen_column, queue_accordion] + ).then( + fn=lambda s: ( + (gr.Button(visible=False), gr.Button(visible=True), gr.Column(visible=True), True) + if bool(get_gen_info(s).get("queue",[])) + else (gr.Button(visible=True), gr.Button(visible=False), gr.Column(visible=False), False) + ), + inputs=[state], + outputs=[generate_btn, add_to_queue_btn, current_gen_column, should_start_flag] + ).then( + fn=start_processing_if_needed, + inputs=[should_start_flag, state], + outputs=[gen_status], + trigger_mode="once" + ).then( + fn=finalize_generation_with_state, + inputs=[state], + outputs=[output, abort_btn, generate_btn, add_to_queue_btn, current_gen_column, gen_info, queue_accordion, state], + trigger_mode="always_last" + ).then( + unload_model_if_needed, + inputs= [state], + outputs= [] ) clear_queue_btn.click( @@ -4139,28 +4184,6 @@ def create_demo(): generate_about_tab() should_start_flag = gr.State(False) - def run_autoload_and_prepare_ui(current_state): - df_update, loaded_flag, modified_state = autoload_queue(current_state) - should_start_processing = loaded_flag - accordion_update = gr.Accordion(open=True) if loaded_flag else gr.update() - return df_update, gr.update(visible=loaded_flag), accordion_update, should_start_processing, modified_state - - def start_processing_if_needed(should_start, current_state): - if not isinstance(current_state, dict) or 'gen' not in current_state: - yield "Error: Invalid state received before processing." - return - if should_start: - yield from process_tasks(current_state) - else: - yield None - - def finalize_generation_with_state(current_state): - if not isinstance(current_state, dict) or 'gen' not in current_state: - return gr.update(), gr.update(interactive=True), gr.update(visible=True), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False, value=""), gr.update(), current_state - - gallery_update, abort_btn_update, gen_btn_update, add_queue_btn_update, current_gen_col_update, gen_info_update = finalize_generation(current_state) - accordion_update = gr.Accordion(open=False) if len(get_gen_info(current_state).get("queue", [])) <= 1 else gr.update() - return gallery_update, abort_btn_update, gen_btn_update, add_queue_btn_update, current_gen_col_update, gen_info_update, accordion_update, current_state demo.load( fn=run_autoload_and_prepare_ui, From 2e026c67c26eca0adb34bb28ae0a9f1d3af30766 Mon Sep 17 00:00:00 2001 From: Chris Malone Date: Thu, 10 Apr 2025 06:05:45 +1000 Subject: [PATCH 4/4] add save and exit, fix vace attachments --- wgp.py | 227 ++++++++++++++++++++++++++++++++++++++++++++------------- 1 file changed, 176 insertions(+), 51 deletions(-) diff --git a/wgp.py b/wgp.py index 5dd3633..5c4643d 100644 --- a/wgp.py +++ b/wgp.py @@ -31,6 +31,7 @@ from PIL import Image import zipfile import tempfile import atexit +import shutil global_queue_ref = [] AUTOSAVE_FILENAME = "queue.zip" PROMPT_VARS_MAX = 10 @@ -364,21 +365,23 @@ def save_queue_action(state): if not queue or len(queue) <=1 : gr.Info("Queue is empty. Nothing to save.") - return None + return "" zip_buffer = io.BytesIO() with tempfile.TemporaryDirectory() as tmpdir: queue_manifest = [] - image_paths_in_zip = {} + file_paths_in_zip = {} for task_index, task in enumerate(queue): - if task is None or not isinstance(task, dict) or task_index == 0: continue + if task is None or not isinstance(task, dict) or task.get('id') is None: continue params_copy = task.get('params', {}).copy() task_id_s = task.get('id', f"task_{task_index}") image_keys = ["image_start", "image_end", "image_refs"] + video_keys = ["video_guide", "video_mask"] + for key in image_keys: images_pil = params_copy.get(key) if images_pil is None: @@ -395,8 +398,8 @@ def save_queue_action(state): continue img_id = id(pil_image) - if img_id in image_paths_in_zip: - image_filenames_for_json.append(image_paths_in_zip[img_id]) + if img_id in file_paths_in_zip: + image_filenames_for_json.append(file_paths_in_zip[img_id]) continue img_filename_in_zip = f"task{task_id_s}_{key}_{img_index}.png" @@ -405,7 +408,8 @@ def save_queue_action(state): try: pil_image.save(img_save_path, "PNG") image_filenames_for_json.append(img_filename_in_zip) - image_paths_in_zip[img_id] = img_filename_in_zip + file_paths_in_zip[img_id] = img_filename_in_zip + print(f"Saved image: {img_filename_in_zip}") except Exception as e: print(f"Error saving image {img_filename_in_zip} for task {task_id_s}: {e}") @@ -414,17 +418,47 @@ def save_queue_action(state): else: params_copy.pop(key, None) + for key in video_keys: + video_path_orig = params_copy.get(key) + if video_path_orig is None or not isinstance(video_path_orig, str): + continue + + if video_path_orig in file_paths_in_zip: + params_copy[key] = file_paths_in_zip[video_path_orig] + continue + + if not os.path.isfile(video_path_orig): + print(f"Warning: Video file not found for key '{key}' in task {task_id_s}: {video_path_orig}. Skipping video.") + params_copy.pop(key, None) + continue + + _, extension = os.path.splitext(video_path_orig) + vid_filename_in_zip = f"task{task_id_s}_{key}{extension if extension else '.mp4'}" + vid_save_path = os.path.join(tmpdir, vid_filename_in_zip) + + try: + shutil.copy2(video_path_orig, vid_save_path) + params_copy[key] = vid_filename_in_zip + file_paths_in_zip[video_path_orig] = vid_filename_in_zip + print(f"Copied video: {video_path_orig} -> {vid_filename_in_zip}") + except Exception as e: + print(f"Error copying video {video_path_orig} to {vid_filename_in_zip} for task {task_id_s}: {e}") + params_copy.pop(key, None) + params_copy.pop('state', None) params_copy.pop('start_image_data_base64', None) params_copy.pop('end_image_data_base64', None) params_copy.pop('start_image_data', None) params_copy.pop('end_image_data', None) + task.pop('start_image_data', None) + task.pop('end_image_data', None) manifest_entry = { "id": task.get('id'), "params": params_copy, } + manifest_entry = {k: v for k, v in manifest_entry.items() if v is not None} queue_manifest.append(manifest_entry) manifest_path = os.path.join(tmpdir, "queue.json") @@ -440,12 +474,13 @@ def save_queue_action(state): with zipfile.ZipFile(zip_buffer, 'w', zipfile.ZIP_DEFLATED) as zf: zf.write(manifest_path, arcname="queue.json") - for saved_img_rel_path in image_paths_in_zip.values(): - saved_img_abs_path = os.path.join(tmpdir, saved_img_rel_path) - if os.path.exists(saved_img_abs_path): - zf.write(saved_img_abs_path, arcname=saved_img_rel_path) + for file_id, saved_file_rel_path in file_paths_in_zip.items(): + saved_file_abs_path = os.path.join(tmpdir, saved_file_rel_path) + if os.path.exists(saved_file_abs_path): + zf.write(saved_file_abs_path, arcname=saved_file_rel_path) + print(f"Adding to zip: {saved_file_rel_path}") else: - print(f"Warning: Image file {saved_img_rel_path} not found during zipping.") + print(f"Warning: File {saved_file_rel_path} (ID: {file_id}) not found during zipping.") zip_buffer.seek(0) zip_binary_content = zip_buffer.getvalue() @@ -464,6 +499,8 @@ def load_queue_action(filepath, state): global task_id gen = get_gen_info(state) original_queue = gen.get("queue", []) + save_path_base = server_config.get("save_path", "outputs") + loaded_cache_dir = os.path.join(save_path_base, "_loaded_queue_cache") if not filepath or not hasattr(filepath, 'name') or not Path(filepath.name).is_file(): print("[load_queue_action] Warning: No valid file selected or file not found.") @@ -476,6 +513,9 @@ def load_queue_action(filepath, state): try: print(f"[load_queue_action] Attempting to load queue from: {filepath.name}") + os.makedirs(loaded_cache_dir, exist_ok=True) + print(f"[load_queue_action] Using cache directory: {loaded_cache_dir}") + with tempfile.TemporaryDirectory() as tmpdir: with zipfile.ZipFile(filepath.name, 'r') as zf: if "queue.json" not in zf.namelist(): raise ValueError("queue.json not found in zip file") @@ -497,21 +537,29 @@ def load_queue_action(filepath, state): params = task_data.get('params', {}) task_id_loaded = task_data.get('id', 0) max_id_in_file = max(max_id_in_file, task_id_loaded) - loaded_pil_images = {} - image_keys = ["image_start", "image_end", "image_refs"] params['state'] = state + image_keys = ["image_start", "image_end", "image_refs"] + video_keys = ["video_guide", "video_mask"] + + loaded_pil_images = {} + loaded_video_paths = {} + for key in image_keys: image_filenames = params.get(key) if image_filenames is None: continue + is_list = isinstance(image_filenames, list) if not is_list: image_filenames = [image_filenames] + loaded_pils = [] for img_filename_in_zip in image_filenames: - if not isinstance(img_filename_in_zip, str): continue + if not isinstance(img_filename_in_zip, str): + print(f"[load_queue_action] Warning: Non-string filename found for image key '{key}'. Skipping.") + continue img_load_path = os.path.join(tmpdir, img_filename_in_zip) if not os.path.exists(img_load_path): - print(f"[load_queue_action] Image file not found during load: {img_load_path}") + print(f"[load_queue_action] Image file not found in extracted data: {img_load_path}. Skipping.") continue try: pil_image = Image.open(img_load_path) @@ -519,30 +567,53 @@ def load_queue_action(filepath, state): converted_image = convert_image(pil_image) loaded_pils.append(converted_image) pil_image.close() + print(f"Loaded image: {img_filename_in_zip} for key {key}") except Exception as img_e: print(f"[load_queue_action] Error loading image {img_filename_in_zip}: {img_e}") if loaded_pils: params[key] = loaded_pils if is_list else loaded_pils[0] loaded_pil_images[key] = params[key] - else: params.pop(key, None) + else: + params.pop(key, None) - primary_preview_pil, secondary_preview_pil = None, None - start_prev_pil_list = loaded_pil_images.get("image_start") - end_prev_pil_list = loaded_pil_images.get("image_end") - ref_prev_pil_list = loaded_pil_images.get("image_refs") + for key in video_keys: + video_filename_in_zip = params.get(key) + if video_filename_in_zip is None or not isinstance(video_filename_in_zip, str): + continue - if start_prev_pil_list: - primary_preview_pil = start_prev_pil_list[0] if isinstance(start_prev_pil_list, list) and start_prev_pil_list else start_prev_pil_list if not isinstance(start_prev_pil_list, list) else None - if end_prev_pil_list: - secondary_preview_pil = end_prev_pil_list[0] if isinstance(end_prev_pil_list, list) and end_prev_pil_list else end_prev_pil_list if not isinstance(end_prev_pil_list, list) else None - elif ref_prev_pil_list and isinstance(ref_prev_pil_list, list) and ref_prev_pil_list: - primary_preview_pil = ref_prev_pil_list[0] + video_load_path = os.path.join(tmpdir, video_filename_in_zip) + if not os.path.exists(video_load_path): + print(f"[load_queue_action] Video file not found in extracted data: {video_load_path}. Skipping.") + params.pop(key, None) + continue + + persistent_video_path = os.path.join(loaded_cache_dir, video_filename_in_zip) + try: + shutil.copy2(video_load_path, persistent_video_path) + params[key] = persistent_video_path + loaded_video_paths[key] = persistent_video_path + print(f"Loaded video: {video_filename_in_zip} -> {persistent_video_path}") + except Exception as vid_e: + print(f"[load_queue_action] Error copying video {video_filename_in_zip} to cache: {vid_e}") + params.pop(key, None) + + + primary_preview_pil_list = loaded_pil_images.get("image_start") or loaded_pil_images.get("image_refs") + secondary_preview_pil_list = loaded_pil_images.get("image_end") + + primary_preview_pil = None + if primary_preview_pil_list: + primary_preview_pil = primary_preview_pil_list[0] if isinstance(primary_preview_pil_list, list) else primary_preview_pil_list + + secondary_preview_pil = None + if secondary_preview_pil_list: + secondary_preview_pil = secondary_preview_pil_list[0] if isinstance(secondary_preview_pil_list, list) else secondary_preview_pil_list start_b64 = [pil_to_base64_uri(primary_preview_pil, format="jpeg", quality=70)] if primary_preview_pil else None end_b64 = [pil_to_base64_uri(secondary_preview_pil, format="jpeg", quality=70)] if secondary_preview_pil else None - top_level_start_image = loaded_pil_images.get("image_start") - top_level_end_image = loaded_pil_images.get("image_end") + top_level_start_image = params.get("image_start") or params.get("image_refs") + top_level_end_image = params.get("image_end") runtime_task = { "id": task_id_loaded, @@ -557,19 +628,20 @@ def load_queue_action(filepath, state): "end_image_data_base64": end_b64, } newly_loaded_queue.append(runtime_task) - print(f"[load_queue_action] Processed task {task_index+1}/{len(loaded_manifest)}, ID: {task_id_loaded}") + print(f"[load_queue_action] Reconstructed task {task_index+1}/{len(loaded_manifest)}, ID: {task_id_loaded}") with lock: print("[load_queue_action] Acquiring lock to update state...") gen["queue"] = newly_loaded_queue[:] local_queue_copy_for_global_ref = gen["queue"][:] - current_max_id_in_new_queue = max([t['id'] for t in newly_loaded_queue if 'id' in t] + [0]) - if current_max_id_in_new_queue > task_id: - print(f"[load_queue_action] Updating global task_id from {task_id} to {current_max_id_in_new_queue + 1}") - task_id = current_max_id_in_new_queue + 1 + current_max_id_in_new_queue = max([t['id'] for t in newly_loaded_queue if 'id' in t] + [0]) + if current_max_id_in_new_queue >= task_id: + new_task_id = current_max_id_in_new_queue + 1 + print(f"[load_queue_action] Updating global task_id from {task_id} to {new_task_id}") + task_id = new_task_id else: - print(f"[load_queue_action] Global task_id ({task_id}) is >= max in file ({current_max_id_in_new_queue}). Not changing task_id.") + print(f"[load_queue_action] Global task_id ({task_id}) is > max in file ({current_max_id_in_new_queue}). Not changing task_id.") gen["prompts_max"] = len(newly_loaded_queue) print("[load_queue_action] State update complete. Releasing lock.") @@ -593,11 +665,14 @@ def load_queue_action(filepath, state): return update_queue_data(original_queue) finally: if filepath and hasattr(filepath, 'name') and filepath.name and os.path.exists(filepath.name): - try: - pass - except OSError as e: - print(f"[load_queue_action] Info: Could not remove temp file {filepath.name}: {e}") - pass + if tempfile.gettempdir() in os.path.abspath(filepath.name): + try: + os.remove(filepath.name) + print(f"[load_queue_action] Removed temporary upload file: {filepath.name}") + except OSError as e: + print(f"[load_queue_action] Info: Could not remove temp file {filepath.name}: {e}") + else: + print(f"[load_queue_action] Info: Did not remove non-temporary file: {filepath.name}") def clear_queue_action(state): gen = get_gen_info(state) @@ -636,6 +711,12 @@ def clear_queue_action(state): return update_queue_data([]) +def quit_application(): + print("Save and Quit requested...") + autosave_queue() + import signal + os.kill(os.getpid(), signal.SIGINT) + def autosave_queue(): global global_queue_ref if not global_queue_ref: @@ -649,14 +730,20 @@ def autosave_queue(): def _save_queue_to_file(queue_to_save, output_filename): if not queue_to_save: return None + with tempfile.TemporaryDirectory() as tmpdir: queue_manifest = [] - image_paths_in_zip = {} + file_paths_in_zip = {} + for task_index, task in enumerate(queue_to_save): - if task is None or not isinstance(task, dict): continue + if task is None or not isinstance(task, dict) or task.get('id') is None: continue + params_copy = task.get('params', {}).copy() task_id_s = task.get('id', f"task_{task_index}") + image_keys = ["image_start", "image_end", "image_refs"] + video_keys = ["video_guide", "video_mask"] + for key in image_keys: images_pil = params_copy.get(key) if images_pil is None: continue @@ -666,36 +753,70 @@ def autosave_queue(): for img_index, pil_image in enumerate(images_pil): if not isinstance(pil_image, Image.Image): continue img_id = id(pil_image) - if img_id in image_paths_in_zip: - image_filenames_for_json.append(image_paths_in_zip[img_id]) + if img_id in file_paths_in_zip: + image_filenames_for_json.append(file_paths_in_zip[img_id]) continue img_filename_in_zip = f"task{task_id_s}_{key}_{img_index}.png" img_save_path = os.path.join(tmpdir, img_filename_in_zip) try: pil_image.save(img_save_path, "PNG") image_filenames_for_json.append(img_filename_in_zip) - image_paths_in_zip[img_id] = img_filename_in_zip + file_paths_in_zip[img_id] = img_filename_in_zip except Exception as e: print(f"Autosave error saving image {img_filename_in_zip}: {e}") if image_filenames_for_json: params_copy[key] = image_filenames_for_json if is_list else image_filenames_for_json[0] else: params_copy.pop(key, None) + + for key in video_keys: + video_path_orig = params_copy.get(key) + if video_path_orig is None or not isinstance(video_path_orig, str): + continue + + if video_path_orig in file_paths_in_zip: + params_copy[key] = file_paths_in_zip[video_path_orig] + continue + + if not os.path.isfile(video_path_orig): + print(f"Warning (Autosave): Video file not found for key '{key}' in task {task_id_s}: {video_path_orig}. Skipping.") + params_copy.pop(key, None) + continue + + _, extension = os.path.splitext(video_path_orig) + vid_filename_in_zip = f"task{task_id_s}_{key}{extension if extension else '.mp4'}" + vid_save_path = os.path.join(tmpdir, vid_filename_in_zip) + + try: + shutil.copy2(video_path_orig, vid_save_path) + params_copy[key] = vid_filename_in_zip + file_paths_in_zip[video_path_orig] = vid_filename_in_zip + except Exception as e: + print(f"Error (Autosave) copying video {video_path_orig} to {vid_filename_in_zip} for task {task_id_s}: {e}") + params_copy.pop(key, None) params_copy.pop('state', None) params_copy.pop('start_image_data_base64', None) params_copy.pop('end_image_data_base64', None) + params_copy.pop('start_image_data', None) + params_copy.pop('end_image_data', None) + manifest_entry = { - "id": task.get('id'), "params": params_copy, + "id": task.get('id'), + "params": params_copy, } + manifest_entry = {k: v for k, v in manifest_entry.items() if v is not None} queue_manifest.append(manifest_entry) + manifest_path = os.path.join(tmpdir, "queue.json") with open(manifest_path, 'w', encoding='utf-8') as f: json.dump(queue_manifest, f, indent=4) with zipfile.ZipFile(output_filename, 'w', zipfile.ZIP_DEFLATED) as zf: zf.write(manifest_path, arcname="queue.json") - for saved_img_rel_path in image_paths_in_zip.values(): - saved_img_abs_path = os.path.join(tmpdir, saved_img_rel_path) - if os.path.exists(saved_img_abs_path): - zf.write(saved_img_abs_path, arcname=saved_img_rel_path) + for saved_file_rel_path in file_paths_in_zip.values(): + saved_file_abs_path = os.path.join(tmpdir, saved_file_rel_path) + if os.path.exists(saved_file_abs_path): + zf.write(saved_file_abs_path, arcname=saved_file_rel_path) + else: + print(f"Warning (Autosave): File {saved_file_rel_path} not found during zipping.") return output_filename return None @@ -1160,7 +1281,6 @@ text_encoder_choices = ["ckpts/models_t5_umt5-xxl-enc-bf16.safetensors", "ckpts/ server_config_filename = "wgp_config.json" if not os.path.isfile(server_config_filename) and os.path.isfile("gradio_config.json"): - import shutil shutil.move("gradio_config.json", server_config_filename) if not Path(server_config_filename).is_file(): @@ -2965,7 +3085,6 @@ def download_loras(): lora_dir = get_lora_dir(get_model_filename("i2v"), quantizeTransformer) log_path = os.path.join(lora_dir, "log.txt") if not os.path.isfile(log_path): - import shutil tmp_path = os.path.join(lora_dir, "tmp_lora_dowload") import glob snapshot_download(repo_id="DeepBeepMeep/Wan2.1", allow_patterns="loras_i2v/*", local_dir= tmp_path) @@ -3483,6 +3602,7 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non save_queue_btn = gr.DownloadButton("Save Queue", size="sm") load_queue_btn = gr.UploadButton("Load Queue", file_types=[".zip"], size="sm") clear_queue_btn = gr.Button("Clear Queue", size="sm", variant="stop") + quit_button = gr.Button("Save and Quit", size="sm", variant="secondary") trigger_zip_download_js = """ (base64String) => { if (!base64String) { @@ -3567,6 +3687,11 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non inputs=None, outputs=[current_gen_column, queue_accordion] ) + quit_button.click( + fn=quit_application, + inputs=[], + outputs=[] + ) extra_inputs = prompt_vars + [wizard_prompt, wizard_variables_var, wizard_prompt_activated_var, video_prompt_column, image_prompt_column, prompt_column_advanced, prompt_column_wizard_vars, prompt_column_wizard, lset_name, advanced_row] # show_advanced presets_column,