diff --git a/wgp.py b/wgp.py
index 083ee1e..414cc80 100644
--- a/wgp.py
+++ b/wgp.py
@@ -28,6 +28,12 @@ from wan.utils import prompt_parser
 import base64
 import io
 from PIL import Image
+import zipfile
+import tempfile
+import shutil
+import atexit
+global_queue_ref = []
+AUTOSAVE_FILENAME = "queue.zip"
 PROMPT_VARS_MAX = 10
 
 target_mmgp_version = "3.3.4"
@@ -98,10 +104,14 @@ def process_prompt_and_add_tasks(state, model_choice):
     inputs["state"] = state
     inputs.pop("lset_name")
     if inputs == None:
-        return
+        gr.Warning("Internal state error: Could not retrieve inputs for the model.")
+        return update_queue_data(get_gen_info(state).get("queue", []))
     prompt = inputs["prompt"]
     if len(prompt) ==0:
-        return
+        gr.Info("Prompt cannot be empty.")
+        gen = get_gen_info(state)
+        queue = gen.get("queue", [])
+        return update_queue_data(queue)
     prompt, errors = prompt_parser.process_template(prompt)
     if len(errors) > 0:
         gr.Info("Error processing prompt template: " + errors)
@@ -111,7 +121,10 @@ def process_prompt_and_add_tasks(state, model_choice):
     prompts = prompt.replace("\r", "").split("\n")
     prompts = [prompt.strip() for prompt in prompts if len(prompt.strip())>0 and not prompt.startswith("#")]
     if len(prompts) ==0:
-        return
+        gr.Info("Prompt cannot be empty.")
+        gen = get_gen_info(state)
+        queue = gen.get("queue", [])
+        return update_queue_data(queue)
 
     resolution = inputs["resolution"]
     width, height = resolution.split("x")
@@ -250,9 +263,6 @@ def process_prompt_and_add_tasks(state, model_choice):
     queue= gen.get("queue", [])
     return update_queue_data(queue)
 
-
-
-
 def add_video_task(**inputs):
     global task_id
     state = inputs["state"]
@@ -327,6 +337,444 @@ def remove_task(queue, selected_indices):
             del queue[idx]
     return update_queue_data(queue)
 
+def update_global_queue_ref(queue):
+    global global_queue_ref
+    with lock:
+        global_queue_ref = queue[:]
+
+def save_queue_action(state):
+    gen = get_gen_info(state)
+    queue = gen.get("queue", [])
+
+    if not queue or len(queue) <= 1:  # Queue is empty or only holds the placeholder
+        gr.Info("Queue is empty. Nothing to save.")
+        return None  # Return None if nothing to save
+
+    # Use an in-memory buffer for the zip file
+    zip_buffer = io.BytesIO()
+
+    # Still use a temporary directory *only* for storing images before zipping
+    with tempfile.TemporaryDirectory() as tmpdir:
+        queue_manifest = []
+        image_paths_in_zip = {}  # Tracks image PIL object ID -> filename in zip
+
+        for task_index, task in enumerate(queue):
+            # Skip the placeholder item if it exists
+            if task is None or not isinstance(task, dict) or task_index == 0:
+                continue
+
+            params_copy = task.get('params', {}).copy()
+            task_id_s = task.get('id', f"task_{task_index}")  # Use a different var name
+
+            image_keys = ["image_start", "image_end", "image_refs"]
+            for key in image_keys:
+                images_pil = params_copy.get(key)
+                if images_pil is None:
+                    continue
+
+                # Ensure images_pil is always a list for processing
+                is_originally_list = isinstance(images_pil, list)
+                if not is_originally_list:
+                    images_pil = [images_pil]
+
+                image_filenames_for_json = []
+                for img_index, pil_image in enumerate(images_pil):
+                    # Ensure it's actually a PIL Image object before proceeding
+                    if not isinstance(pil_image, Image.Image):
+                        print(f"Warning: Expected PIL Image for key '{key}' in task {task_id_s}, got {type(pil_image)}. Skipping image.")
+                        continue
+
+                    # Use the object ID to check if this specific image instance was already saved
+                    img_id = id(pil_image)
+                    if img_id in image_paths_in_zip:
+                        # Already saved: just add its filename to the list
+                        image_filenames_for_json.append(image_paths_in_zip[img_id])
+                        continue  # Move to the next image in the list
+
+                    # Image not saved yet: create a filename and save path
+                    img_filename_in_zip = f"task{task_id_s}_{key}_{img_index}.png"
+                    img_save_path = os.path.join(tmpdir, img_filename_in_zip)
+
+                    try:
+                        # Save the image to the temporary directory
+                        pil_image.save(img_save_path, "PNG")
+                        image_filenames_for_json.append(img_filename_in_zip)
+                        # Store the mapping from image ID to its filename in the zip
+                        image_paths_in_zip[img_id] = img_filename_in_zip
+                    except Exception as e:
+                        print(f"Error saving image {img_filename_in_zip} for task {task_id_s}: {e}")
+                        # Optionally decide if you want to continue or fail here
+
+                # Update params_copy with the list of filenames (or single filename)
+                if image_filenames_for_json:
+                    params_copy[key] = image_filenames_for_json if is_originally_list else image_filenames_for_json[0]
+                else:
+                    # If no images were successfully processed for this key, remove it
+                    params_copy.pop(key, None)
+
+            # Clean up parameters before adding to the manifest
+            params_copy.pop('state', None)
+            params_copy.pop('start_image_data_base64', None)  # Don't need base64 in the saved queue
+            params_copy.pop('end_image_data_base64', None)
+            # Also remove the actual PIL data if it somehow remained
+            params_copy.pop('start_image_data', None)
+            params_copy.pop('end_image_data', None)
+
+            manifest_entry = {
+                "id": task.get('id'),
+                "params": params_copy,
+                # Keep other necessary top-level task info if needed, like repeats etc.
+                # Example: "repeats": task.get('repeats', 1)
+            }
+            queue_manifest.append(manifest_entry)
+
+        # --- Create queue.json content ---
+        manifest_path = os.path.join(tmpdir, "queue.json")
+        try:
+            with open(manifest_path, 'w', encoding='utf-8') as f:
+                # Dump only the relevant manifest data
+                json.dump(queue_manifest, f, indent=4)
+        except Exception as e:
+            print(f"Error writing queue.json: {e}")
+            gr.Warning("Failed to create queue manifest.")
+            return None  # Return None on failure
+
+        # --- Create the zip file in memory ---
+        try:
+            with zipfile.ZipFile(zip_buffer, 'w', zipfile.ZIP_DEFLATED) as zf:
+                # Add queue.json
+                zf.write(manifest_path, arcname="queue.json")
+
+                # Add all unique images that were saved to the temp dir
+                for saved_img_rel_path in image_paths_in_zip.values():
+                    saved_img_abs_path = os.path.join(tmpdir, saved_img_rel_path)
+                    if os.path.exists(saved_img_abs_path):
+                        zf.write(saved_img_abs_path, arcname=saved_img_rel_path)
+                    else:
+                        # This shouldn't happen if saving was successful, but good to check
+                        print(f"Warning: Image file {saved_img_rel_path} not found during zipping.")
+
+            # --- Prepare for return ---
+            # Move the buffer position to the beginning and read the binary content
+            zip_buffer.seek(0)
+            zip_binary_content = zip_buffer.getvalue()
+            # Encode as a base64 string
+            zip_base64 = base64.b64encode(zip_binary_content).decode('utf-8')
+            print(f"Queue successfully prepared as base64 string ({len(zip_base64)} chars).")
+            return zip_base64
+
+        except Exception as e:
+            print(f"Error creating zip file in memory: {e}")
+            gr.Warning("Failed to create zip data for download.")
+            return None  # Return None on failure
+        finally:
+            zip_buffer.close()
+
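+# Archive layout shared by save_queue_action, autosave_queue and
+# load_queue_action: queue.json holds the task manifest, and every referenced
+# image sits next to it as task{id}_{key}_{index}.png.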
+def load_queue_action(filepath, state):
+    global task_id
+    gen = get_gen_info(state)
+    original_queue = gen.get("queue", [])  # Store the original queue for the error case
+
+    if not filepath or not hasattr(filepath, 'name') or not Path(filepath.name).is_file():
+        print("[load_queue_action] Warning: No valid file selected or file not found.")
+        # Return the current state of the DataFrame
+        return update_queue_data(original_queue)
+
+    newly_loaded_queue = []
+    max_id_in_file = 0
+    error_message = ""
+    local_queue_copy_for_global_ref = None
+
+    try:
+        print(f"[load_queue_action] Attempting to load queue from: {filepath.name}")
+        with tempfile.TemporaryDirectory() as tmpdir:
+            with zipfile.ZipFile(filepath.name, 'r') as zf:
+                if "queue.json" not in zf.namelist():
+                    raise ValueError("queue.json not found in zip file")
+                print(f"[load_queue_action] Extracting {filepath.name} to {tmpdir}")
+                zf.extractall(tmpdir)
+                print("[load_queue_action] Extraction complete.")
+
+            manifest_path = os.path.join(tmpdir, "queue.json")
+            print(f"[load_queue_action] Reading manifest: {manifest_path}")
+            with open(manifest_path, 'r', encoding='utf-8') as f:
+                loaded_manifest = json.load(f)
+            print(f"[load_queue_action] Manifest loaded. Processing {len(loaded_manifest)} tasks.")
+
+            for task_index, task_data in enumerate(loaded_manifest):
+                if task_data is None or not isinstance(task_data, dict):
+                    print(f"[load_queue_action] Skipping invalid task data at index {task_index}")
+                    continue
+
+                params = task_data.get('params', {})
+                task_id_loaded = task_data.get('id', 0)
+                max_id_in_file = max(max_id_in_file, task_id_loaded)
+                loaded_pil_images = {}
+                image_keys = ["image_start", "image_end", "image_refs"]
+                # Add state back temporarily; it is stripped again before any save
+                params['state'] = state
+
+                for key in image_keys:
+                    image_filenames = params.get(key)
+                    if image_filenames is None:
+                        continue
+                    is_list = isinstance(image_filenames, list)
+                    if not is_list:
+                        image_filenames = [image_filenames]
+                    loaded_pils = []
+                    for img_filename_in_zip in image_filenames:
+                        if not isinstance(img_filename_in_zip, str):
+                            continue
+                        img_load_path = os.path.join(tmpdir, img_filename_in_zip)
+                        if not os.path.exists(img_load_path):
+                            print(f"[load_queue_action] Image file not found during load: {img_load_path}")
+                            continue
+                        try:
+                            pil_image = Image.open(img_load_path)
+                            # Ensure the image data is loaded into memory before the temp dir is cleaned up
+                            pil_image.load()
+                            # Convert the image right after loading
+                            converted_image = convert_image(pil_image)
+                            loaded_pils.append(converted_image)
+                            pil_image.close()  # Close the file handle
+                        except Exception as img_e:
+                            print(f"[load_queue_action] Error loading image {img_filename_in_zip}: {img_e}")
+                    if loaded_pils:
+                        params[key] = loaded_pils if is_list else loaded_pils[0]
+                        loaded_pil_images[key] = params[key]  # Store loaded PILs for preview generation
+                    else:
+                        params.pop(key, None)
+
+                # Generate preview base64 strings
+                primary_preview_pil, secondary_preview_pil = None, None
+                start_prev_pil_list = loaded_pil_images.get("image_start")
+                end_prev_pil_list = loaded_pil_images.get("image_end")
+                ref_prev_pil_list = loaded_pil_images.get("image_refs")
+
+                # Extract the first image for the preview if available
+                if start_prev_pil_list:
+                    primary_preview_pil = start_prev_pil_list[0] if isinstance(start_prev_pil_list, list) and start_prev_pil_list else start_prev_pil_list if not isinstance(start_prev_pil_list, list) else None
+                    if end_prev_pil_list:
+                        secondary_preview_pil = end_prev_pil_list[0] if isinstance(end_prev_pil_list, list) and end_prev_pil_list else end_prev_pil_list if not isinstance(end_prev_pil_list, list) else None
+                elif ref_prev_pil_list and isinstance(ref_prev_pil_list, list) and ref_prev_pil_list:
+                    primary_preview_pil = ref_prev_pil_list[0]
+
+                # Generate base64 only if a PIL image exists
+                start_b64 = [pil_to_base64_uri(primary_preview_pil, format="jpeg", quality=70)] if primary_preview_pil else None
+                end_b64 = [pil_to_base64_uri(secondary_preview_pil, format="jpeg", quality=70)] if secondary_preview_pil else None
+
+                # Get top-level image data (PIL objects) for the runtime task
+                top_level_start_image = loaded_pil_images.get("image_start")
+                top_level_end_image = loaded_pil_images.get("image_end")
+
+                # Construct the runtime task dictionary
+                runtime_task = {
+                    "id": task_id_loaded,
+                    "params": params.copy(),  # Use a copy of params
+                    # Extract necessary params for the top level if they exist
+                    "repeats": params.get('repeat_generation', 1),
+                    "length": params.get('video_length'),
+                    "steps": params.get('num_inference_steps'),
+                    "prompt": params.get('prompt'),
+                    # Store the actual loaded PIL image data here
+                    "start_image_data": top_level_start_image,
+                    "end_image_data": top_level_end_image,
+                    # Store the base64 previews generated above
+                    "start_image_data_base64": start_b64,
+                    "end_image_data_base64": end_b64,
+                }
+                newly_loaded_queue.append(runtime_task)
+                print(f"[load_queue_action] Processed task {task_index+1}/{len(loaded_manifest)}, ID: {task_id_loaded}")
+
+        # --- State Update ---
+        with lock:
+            print("[load_queue_action] Acquiring lock to update state...")
+            gen["queue"] = newly_loaded_queue[:]  # Replace the queue in the state
+            local_queue_copy_for_global_ref = gen["queue"][:]  # Copy for the global ref update
+            current_max_id_in_new_queue = max([t['id'] for t in newly_loaded_queue if 'id' in t] + [0])  # Safer max ID calculation
+
+            # Update the global task ID only if the loaded max ID is higher
+            if current_max_id_in_new_queue > task_id:
+                print(f"[load_queue_action] Updating global task_id from {task_id} to {current_max_id_in_new_queue + 1}")
+                task_id = current_max_id_in_new_queue + 1  # Ensure the next ID is unique
+            else:
+                print(f"[load_queue_action] Global task_id ({task_id}) is >= max in file ({current_max_id_in_new_queue}). Not changing task_id.")
+
+            gen["prompts_max"] = len(newly_loaded_queue)
+            print("[load_queue_action] State update complete. Releasing lock.")
+
+        # --- Global Reference Update ---
+        if local_queue_copy_for_global_ref is not None:
+            print("[load_queue_action] Updating global queue reference...")
+            update_global_queue_ref(local_queue_copy_for_global_ref)
+        else:
+            # This case should ideally not be reached if the state update happens
+            print("[load_queue_action] Warning: Skipping global ref update as local copy is None.")
+
+        print(f"[load_queue_action] Queue load successful. Returning DataFrame update for {len(newly_loaded_queue)} tasks.")
+        return update_queue_data(newly_loaded_queue)
+
+    except Exception as e:
+        # Covers ValueError, zipfile.BadZipFile, FileNotFoundError, JSON errors, etc.
+        error_message = f"Error during queue load: {e}"
+        print(f"[load_queue_action] Caught error: {error_message}")
+        traceback.print_exc()
+        gr.Warning(f"Failed to load queue: {error_message[:200]}")  # Show a truncated error
+
+        print("[load_queue_action] Load failed. Returning DataFrame update for original queue.")
+        return update_queue_data(original_queue)
+    finally:
+        # Gradio manages its own temporary upload files; removal of
+        # filepath.name is deliberately left to Gradio.
+        pass
+
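+# clear_queue_action empties the queue held in the Gradio state and resets
+# prompts_max so the progress counter does not report stale totals.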
+def clear_queue_action(state):
+    gen = get_gen_info(state)
+    queue = gen.get("queue", [])
+    if not queue:
+        gr.Info("Queue is already empty.")
+        return update_queue_data([])
+
+    with lock:
+        queue.clear()
+        gen["prompts_max"] = 0
+
+    gr.Info("Queue cleared.")
+    return update_queue_data([])
+
+def autosave_queue():
+    global global_queue_ref
+    if not global_queue_ref:
+        print("Autosave: Queue is empty, nothing to save.")
+        return
+
+    print(f"Autosaving queue ({len(global_queue_ref)} items) to {AUTOSAVE_FILENAME}...")
+    try:
+        def _save_queue_to_file(queue_to_save, output_filename):
+            if not queue_to_save:
+                return None
+            with tempfile.TemporaryDirectory() as tmpdir:
+                queue_manifest = []
+                image_paths_in_zip = {}
+                for task_index, task in enumerate(queue_to_save):
+                    if task is None or not isinstance(task, dict):
+                        continue
+                    params_copy = task.get('params', {}).copy()
+                    task_id_s = task.get('id', f"task_{task_index}")
+                    image_keys = ["image_start", "image_end", "image_refs"]
+                    for key in image_keys:
+                        images_pil = params_copy.get(key)
+                        if images_pil is None:
+                            continue
+                        is_list = isinstance(images_pil, list)
+                        if not is_list:
+                            images_pil = [images_pil]
+                        image_filenames_for_json = []
+                        for img_index, pil_image in enumerate(images_pil):
+                            if not isinstance(pil_image, Image.Image):
+                                continue
+                            img_id = id(pil_image)
+                            if img_id in image_paths_in_zip:
+                                image_filenames_for_json.append(image_paths_in_zip[img_id])
+                                continue
+                            img_filename_in_zip = f"task{task_id_s}_{key}_{img_index}.png"
+                            img_save_path = os.path.join(tmpdir, img_filename_in_zip)
+                            try:
+                                pil_image.save(img_save_path, "PNG")
+                                image_filenames_for_json.append(img_filename_in_zip)
+                                image_paths_in_zip[img_id] = img_filename_in_zip
+                            except Exception as e:
+                                print(f"Autosave error saving image {img_filename_in_zip}: {e}")
+                        if image_filenames_for_json:
+                            params_copy[key] = image_filenames_for_json if is_list else image_filenames_for_json[0]
+                        else:
+                            params_copy.pop(key, None)
+                    params_copy.pop('state', None)
+                    params_copy.pop('start_image_data_base64', None)
+                    params_copy.pop('end_image_data_base64', None)
+                    manifest_entry = {
+                        "id": task.get('id'),
+                        "params": params_copy,
+                    }
+                    queue_manifest.append(manifest_entry)
+                manifest_path = os.path.join(tmpdir, "queue.json")
+                with open(manifest_path, 'w', encoding='utf-8') as f:
+                    json.dump(queue_manifest, f, indent=4)
+                with zipfile.ZipFile(output_filename, 'w', zipfile.ZIP_DEFLATED) as zf:
+                    zf.write(manifest_path, arcname="queue.json")
+                    for saved_img_rel_path in image_paths_in_zip.values():
+                        saved_img_abs_path = os.path.join(tmpdir, saved_img_rel_path)
+                        if os.path.exists(saved_img_abs_path):
+                            zf.write(saved_img_abs_path, arcname=saved_img_rel_path)
+                return output_filename
+
+        saved_path = _save_queue_to_file(global_queue_ref, AUTOSAVE_FILENAME)
+
+        if saved_path:
+            print(f"Queue autosaved successfully to {saved_path}")
+        else:
+            print("Autosave failed.")
+    except Exception as e:
+        print(f"Error during autosave: {e}")
+        traceback.print_exc()
+
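+# autoload_queue is the startup counterpart of the atexit-registered
+# autosave_queue: wired through demo.load() below, it restores
+# AUTOSAVE_FILENAME into an empty queue and reports whether anything loaded.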
+def autoload_queue(state):
+    global task_id
+    # Initial check using the original state
+    try:
+        gen = get_gen_info(state)  # Make sure the initial state is a dict
+        original_queue = gen.get("queue", [])
+    except AttributeError:
+        print("[autoload_queue] Error: Initial state is not a dictionary. Cannot autoload.")
+        # Return defaults: hidden column, no load performed, state unchanged
+        return gr.update(visible=False), False, state
+
+    loaded_flag = False
+    dataframe_update = update_queue_data(original_queue)  # The default update is the original queue
+
+    if not original_queue and Path(AUTOSAVE_FILENAME).is_file():
+        print(f"Autoloading queue from {AUTOSAVE_FILENAME}...")
+        class MockFile:
+            def __init__(self, name):
+                self.name = name
+        mock_filepath = MockFile(AUTOSAVE_FILENAME)
+
+        # load_queue_action modifies 'state' in place and returns a DataFrame update
+        dataframe_update = load_queue_action(mock_filepath, state)
+
+        # Check the state dictionary, which load_queue_action should have modified
+        gen = get_gen_info(state)
+        loaded_queue_after_action = gen.get("queue", [])
+
+        if loaded_queue_after_action:  # Check if the queue in the state is now populated
+            print(f"Autoload successful. Loaded {len(loaded_queue_after_action)} tasks into state.")
+            loaded_flag = True
+            # The global ref update was already done inside load_queue_action if successful
+        else:
+            print("Autoload attempted but queue in state remains empty (file might be empty or invalid).")
+            # Ensure the state reflects an empty queue if the load failed but the file existed
+            with lock:
+                gen["queue"] = []
+                gen["prompts_max"] = 0
+            update_global_queue_ref([])
+            dataframe_update = update_queue_data([])  # Ensure the UI shows an empty queue
+
+    else:  # Handle cases where autoload shouldn't happen
+        if original_queue:
+            print("Autoload skipped: Queue is not empty.")
+            update_global_queue_ref(original_queue)  # Ensure the global ref matches the current state
+            dataframe_update = update_queue_data(original_queue)  # The UI should show the current queue
+        else:
+            print(f"Autoload skipped: {AUTOSAVE_FILENAME} not found.")
+            update_global_queue_ref([])  # Ensure the global ref is empty
+            dataframe_update = update_queue_data([])  # The UI should show an empty queue
+
+    # Return the DataFrame update for the UI, the flag, and the final state dictionary
+    return dataframe_update, loaded_flag, state
 
 
 def get_queue_table(queue):
@@ -390,7 +838,7 @@ def get_queue_table(queue):
                 ])
     return data
 def update_queue_data(queue):
-
+    update_global_queue_ref(queue)
     data = get_queue_table(queue)
 
     # if len(data) == 0:
@@ -1993,6 +2441,7 @@ def process_tasks(state, progress=gr.Progress()):
             yield status
 
         queue[:] = [item for item in queue if item['id'] != task['id']]
+        update_global_queue_ref(queue)
 
     gen["prompts_max"] = 0
     gen["prompt"] = ""
@@ -2716,7 +3165,7 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
                     wizard_variables = "\n".join(variables)
                 for _ in range( PROMPT_VARS_MAX - len(prompt_vars)):
                     prompt_vars.append(gr.Textbox(visible= False, min_width=80, show_label= False))
-
+
                 with gr.Column(not advanced_prompt) as prompt_column_wizard:
                     wizard_prompt = gr.Textbox(visible = not advanced_prompt, label="Prompts (each new line of prompt will generate a new video, # lines = comments)", value=default_wizard_prompt, lines=3)
                     wizard_prompt_activated_var = gr.Text(wizard_prompt_activated, visible= False)
@@ -2902,7 +3351,7 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
                 queue_df = gr.DataFrame(
                     headers=["Qty","Prompt", "Length","Steps","", "", "", "", ""],
                     datatype=[ "str","markdown","str", "markdown", "markdown", "markdown", "str", "str", "str"],
-                    column_widths= ["50","", "65","55", "60", "60", "30", "30", "35"],
+                    column_widths= ["5%", None, "7%", "7%", "10%", "10%", "3%", "3%", "3%"],
                     interactive=False,
                     col_count=(9, "fixed"),
                     wrap=True,
@@ -2911,6 +3360,72 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
                     visible= False,
                     elem_id="queue_df"
                 )
+                with gr.Row():
+                    queue_zip_base64_output = gr.Text(visible=False)
+                    save_queue_btn = gr.DownloadButton("Save Queue", size="sm")
+                    load_queue_btn = gr.UploadButton("Load Queue", file_types=[".zip"], size="sm")
+                    clear_queue_btn = gr.Button("Clear Queue", size="sm", variant="stop")
+
+                # save_queue_action hands the zip to the browser as a base64 string;
+                # this JS decodes it into a Blob and triggers a client-side download.
+                trigger_zip_download_js = """
+                (base64String) => {
+                    if (!base64String) {
+                        console.log("No base64 zip data received, skipping download.");
+                        return;
+                    }
+                    try {
+                        const byteCharacters = atob(base64String);
+                        const byteNumbers = new Array(byteCharacters.length);
+                        for (let i = 0; i < byteCharacters.length; i++) {
+                            byteNumbers[i] = byteCharacters.charCodeAt(i);
+                        }
+                        const byteArray = new Uint8Array(byteNumbers);
+                        const blob = new Blob([byteArray], { type: 'application/zip' });
+
+                        const url = URL.createObjectURL(blob);
+                        const a = document.createElement('a');
+                        a.style.display = 'none';
+                        a.href = url;
+                        a.download = 'queue.zip';
+                        document.body.appendChild(a);
+                        a.click();
+
+                        window.URL.revokeObjectURL(url);
+                        document.body.removeChild(a);
+                        console.log("Zip download triggered.");
+                    } catch (e) {
+                        console.error("Error processing base64 data or triggering download:", e);
+                    }
+                }
+                """
+
+                save_queue_btn.click(
+                    fn=save_queue_action,
+                    inputs=[state],
+                    outputs=[queue_zip_base64_output]
+                ).then(
+                    fn=None,
+                    inputs=[queue_zip_base64_output],
+                    outputs=None,
+                    js=trigger_zip_download_js
+                )
+
+                load_queue_btn.upload(
+                    fn=load_queue_action,
+                    inputs=[load_queue_btn, state],
+                    outputs=[queue_df]
+                ).then(
+                    fn=lambda s: gr.update(visible=bool(get_gen_info(s).get("queue",[]))),
+                    inputs=[state],
+                    outputs=[current_gen_column]
+                )
+
+                clear_queue_btn.click(
+                    fn=clear_queue_action,
+                    inputs=[state],
+                    outputs=[queue_df]
+                ).then(
+                    fn=lambda: gr.update(visible=False),
+                    inputs=None,
+                    outputs=[current_gen_column]
+                )
 
 
     extra_inputs = prompt_vars + [wizard_prompt, wizard_variables_var, wizard_prompt_activated_var, video_prompt_column, image_prompt_column, prompt_column_advanced, prompt_column_wizard_vars, prompt_column_wizard, lset_name, advanced_row]  # show_advanced presets_column,
@@ -3014,7 +3529,16 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
         outputs=[modal_container]
     )
 
-    return loras_choices, lset_name, state
+    return (
+        loras_choices, lset_name, state, queue_df, current_gen_column,
+        gen_status, output, abort_btn, generate_btn, add_to_queue_btn,
+        gen_info,
+        prompt, wizard_prompt, wizard_prompt_activated_var, wizard_variables_var,
+        prompt_column_advanced, prompt_column_wizard, prompt_column_wizard_vars,
+        advanced_row, image_prompt_column, video_prompt_column,
+        *prompt_vars
+    )
+
 def generate_download_tab(lset_name,loras_choices, state):
     with gr.Row():
@@ -3479,8 +4003,15 @@ def create_demo():
         with gr.Row():
            header = gr.Markdown(generate_header(transformer_filename, compile, attention_mode), visible= True)
         with gr.Row():
-
-            loras_choices, lset_name, state = generate_video_tab(model_choice = model_choice, header = header)
+            (
+                loras_choices, lset_name, state, queue_df, current_gen_column,
+                gen_status, output, abort_btn, generate_btn, add_to_queue_btn,
+                gen_info,
+                prompt, wizard_prompt, wizard_prompt_activated_var, wizard_variables_var,
+                prompt_column_advanced, prompt_column_wizard, prompt_column_wizard_vars,
+                advanced_row, image_prompt_column, video_prompt_column,
+                *prompt_vars_outputs
+            ) = generate_video_tab(model_choice=model_choice, header=header)
         with gr.Tab("Informations"):
             generate_info_tab()
         if not args.lock_config:
@@ -3491,9 +4022,47 @@ def create_demo():
         with gr.Tab("About"):
             generate_about_tab()
 
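+        # Startup wiring: restore any autosaved queue first, resume processing
+        # only if something was actually restored, then put the generation UI
+        # back into its idle state once processing ends.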
+        should_start_flag = gr.State(False)
+
+        def run_autoload_and_prepare_ui(current_state):
+            df_update, loaded_flag, modified_state = autoload_queue(current_state)
+            # loaded_flag doubles as the "should start processing" signal
+            return df_update, gr.update(visible=loaded_flag), loaded_flag, modified_state
+
+        def start_processing_if_needed(should_start, current_state):
+            if not isinstance(current_state, dict) or 'gen' not in current_state:
+                yield "Error: Invalid state received before processing."
+                return
+            if should_start:
+                yield from process_tasks(current_state)
+            else:
+                yield "Autoload complete. Processing not started."
+
+        def finalize_generation_with_state(current_state):
+            if not isinstance(current_state, dict) or 'gen' not in current_state:
+                return gr.update(), gr.update(interactive=True), gr.update(visible=True), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False, value=""), current_state
+            gallery_update, abort_btn_update, gen_btn_update, add_queue_btn_update, current_gen_col_update, gen_info_update = finalize_generation(current_state)
+            return gallery_update, abort_btn_update, gen_btn_update, add_queue_btn_update, current_gen_col_update, gen_info_update, current_state
+
+        demo.load(
+            fn=run_autoload_and_prepare_ui,
+            inputs=[state],
+            outputs=[queue_df, current_gen_column, should_start_flag, state]
+        ).then(
+            fn=start_processing_if_needed,
+            inputs=[should_start_flag, state],
+            outputs=[gen_status],
+            trigger_mode="once"
+        ).then(
+            fn=finalize_generation_with_state,
+            inputs=[state],
+            outputs=[output, abort_btn, generate_btn, add_to_queue_btn, current_gen_column, gen_info, state],
+            trigger_mode="always_last"
+        )
 
     return demo
 
 if __name__ == "__main__":
+    atexit.register(autosave_queue)
     # threading.Thread(target=runner, daemon=True).start()
     os.environ["GRADIO_ANALYTICS_ENABLED"] = "False"
     server_port = int(args.server_port)