commit 7d2bb5308c
DeepBeepMeep, 2025-04-09 22:11:18 +02:00

wgp.py (451 changed lines)

@@ -31,6 +31,7 @@ from PIL import Image
 import zipfile
 import tempfile
 import atexit
+import shutil
 global_queue_ref = []
 AUTOSAVE_FILENAME = "queue.zip"
 PROMPT_VARS_MAX = 10
@@ -367,120 +368,127 @@ def save_queue_action(state):
     gen = get_gen_info(state)
     queue = gen.get("queue", [])
-    if not queue or len(queue) <=1 : # Check if queue is empty or only has the placeholder
+    if not queue or len(queue) <=1 :
         gr.Info("Queue is empty. Nothing to save.")
-        return None # Return None if nothing to save
+        return ""
-    # Use an in-memory buffer for the zip file
     zip_buffer = io.BytesIO()
-    # Still use a temporary directory *only* for storing images before zipping
     with tempfile.TemporaryDirectory() as tmpdir:
         queue_manifest = []
-        image_paths_in_zip = {} # Tracks image PIL object ID -> filename in zip
+        file_paths_in_zip = {}
         for task_index, task in enumerate(queue):
-            # Skip the placeholder item if it exists
-            if task is None or not isinstance(task, dict) or task_index == 0: continue
+            if task is None or not isinstance(task, dict) or task.get('id') is None: continue
             params_copy = task.get('params', {}).copy()
-            task_id_s = task.get('id', f"task_{task_index}") # Use a different var name
+            task_id_s = task.get('id', f"task_{task_index}")
             image_keys = ["image_start", "image_end", "image_refs"]
+            video_keys = ["video_guide", "video_mask"]
             for key in image_keys:
                 images_pil = params_copy.get(key)
                 if images_pil is None:
                     continue
-                # Ensure images_pil is always a list for processing
                 is_originally_list = isinstance(images_pil, list)
                 if not is_originally_list:
                     images_pil = [images_pil]
                 image_filenames_for_json = []
                 for img_index, pil_image in enumerate(images_pil):
-                    # Ensure it's actually a PIL Image object before proceeding
                     if not isinstance(pil_image, Image.Image):
                         print(f"Warning: Expected PIL Image for key '{key}' in task {task_id_s}, got {type(pil_image)}. Skipping image.")
                         continue
-                    # Use object ID to check if this specific image instance is already saved
                     img_id = id(pil_image)
-                    if img_id in image_paths_in_zip:
-                        # If already saved, just add its filename to the list
-                        image_filenames_for_json.append(image_paths_in_zip[img_id])
-                        continue # Move to the next image in the list
-                    # Image not saved yet, create filename and save path
+                    if img_id in file_paths_in_zip:
+                        image_filenames_for_json.append(file_paths_in_zip[img_id])
+                        continue
                     img_filename_in_zip = f"task{task_id_s}_{key}_{img_index}.png"
                     img_save_path = os.path.join(tmpdir, img_filename_in_zip)
                     try:
-                        # Save the image to the temporary directory
                         pil_image.save(img_save_path, "PNG")
                         image_filenames_for_json.append(img_filename_in_zip)
-                        # Store the mapping from image ID to its filename in the zip
-                        image_paths_in_zip[img_id] = img_filename_in_zip
+                        file_paths_in_zip[img_id] = img_filename_in_zip
+                        print(f"Saved image: {img_filename_in_zip}")
                     except Exception as e:
                         print(f"Error saving image {img_filename_in_zip} for task {task_id_s}: {e}")
-                        # Optionally decide if you want to continue or fail here
-                # Update the params_copy with the list of filenames (or single filename)
                 if image_filenames_for_json:
                     params_copy[key] = image_filenames_for_json if is_originally_list else image_filenames_for_json[0]
                 else:
-                    # If no images were successfully processed for this key, remove it
                     params_copy.pop(key, None)
+            for key in video_keys:
+                video_path_orig = params_copy.get(key)
+                if video_path_orig is None or not isinstance(video_path_orig, str):
+                    continue
+                if video_path_orig in file_paths_in_zip:
+                    params_copy[key] = file_paths_in_zip[video_path_orig]
+                    continue
+                if not os.path.isfile(video_path_orig):
+                    print(f"Warning: Video file not found for key '{key}' in task {task_id_s}: {video_path_orig}. Skipping video.")
+                    params_copy.pop(key, None)
+                    continue
+                _, extension = os.path.splitext(video_path_orig)
+                vid_filename_in_zip = f"task{task_id_s}_{key}{extension if extension else '.mp4'}"
+                vid_save_path = os.path.join(tmpdir, vid_filename_in_zip)
+                try:
+                    shutil.copy2(video_path_orig, vid_save_path)
+                    params_copy[key] = vid_filename_in_zip
+                    file_paths_in_zip[video_path_orig] = vid_filename_in_zip
+                    print(f"Copied video: {video_path_orig} -> {vid_filename_in_zip}")
+                except Exception as e:
+                    print(f"Error copying video {video_path_orig} to {vid_filename_in_zip} for task {task_id_s}: {e}")
+                    params_copy.pop(key, None)
-            # Clean up parameters before adding to manifest
             params_copy.pop('state', None)
-            params_copy.pop('start_image_data_base64', None) # Don't need base64 in saved queue
+            params_copy.pop('start_image_data_base64', None)
             params_copy.pop('end_image_data_base64', None)
-            # Also remove the actual PIL data if it somehow remained
             params_copy.pop('start_image_data', None)
             params_copy.pop('end_image_data', None)
+            task.pop('start_image_data', None)
+            task.pop('end_image_data', None)
             manifest_entry = {
                 "id": task.get('id'),
                 "params": params_copy,
-                # Keep other necessary top-level task info if needed, like repeats etc.
-                # Example: "repeats": task.get('repeats', 1)
             }
+            manifest_entry = {k: v for k, v in manifest_entry.items() if v is not None}
             queue_manifest.append(manifest_entry)
-        # --- Create queue.json content ---
         manifest_path = os.path.join(tmpdir, "queue.json")
         try:
             with open(manifest_path, 'w', encoding='utf-8') as f:
-                # Dump only the relevant manifest data
                 json.dump(queue_manifest, f, indent=4)
         except Exception as e:
             print(f"Error writing queue.json: {e}")
             gr.Warning("Failed to create queue manifest.")
-            return None # Return None on failure
+            return None
-        # --- Create the zip file in memory ---
         try:
             with zipfile.ZipFile(zip_buffer, 'w', zipfile.ZIP_DEFLATED) as zf:
-                # Add queue.json
                 zf.write(manifest_path, arcname="queue.json")
-                # Add all unique images that were saved to the temp dir
-                for saved_img_rel_path in image_paths_in_zip.values():
-                    saved_img_abs_path = os.path.join(tmpdir, saved_img_rel_path)
-                    if os.path.exists(saved_img_abs_path):
-                        zf.write(saved_img_abs_path, arcname=saved_img_rel_path)
-                    else:
-                        # This shouldn't happen if saving was successful, but good to check
-                        print(f"Warning: Image file {saved_img_rel_path} not found during zipping.")
+                for file_id, saved_file_rel_path in file_paths_in_zip.items():
+                    saved_file_abs_path = os.path.join(tmpdir, saved_file_rel_path)
+                    if os.path.exists(saved_file_abs_path):
+                        zf.write(saved_file_abs_path, arcname=saved_file_rel_path)
+                        print(f"Adding to zip: {saved_file_rel_path}")
+                    else:
+                        print(f"Warning: File {saved_file_rel_path} (ID: {file_id}) not found during zipping.")
-            # --- Prepare for return ---
-            # Move buffer position to the beginning
             zip_buffer.seek(0)
-            # Read the binary content
             zip_binary_content = zip_buffer.getvalue()
-            # Encode as base64 string
             zip_base64 = base64.b64encode(zip_binary_content).decode('utf-8')
             print(f"Queue successfully prepared as base64 string ({len(zip_base64)} chars).")
             return zip_base64
@@ -488,18 +496,19 @@ def save_queue_action(state):
         except Exception as e:
             print(f"Error creating zip file in memory: {e}")
             gr.Warning("Failed to create zip data for download.")
-            return None # Return None on failure
+            return None
         finally:
             zip_buffer.close()
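
Note: the save path above builds the archive entirely in memory and hands it to the browser as a base64 string (decoded client-side by the `trigger_zip_download_js` hook further down). A minimal, self-contained sketch of that pattern, with illustrative names not taken from wgp.py:

```python
import base64
import io
import json
import zipfile

def pack_queue_to_base64(manifest: list, files: dict[str, bytes]) -> str:
    """Zip a JSON manifest plus named file payloads in memory and
    return the archive base64-encoded, ready for a JS download trigger."""
    buffer = io.BytesIO()
    with zipfile.ZipFile(buffer, "w", zipfile.ZIP_DEFLATED) as zf:
        zf.writestr("queue.json", json.dumps(manifest, indent=4))
        for arcname, payload in files.items():
            zf.writestr(arcname, payload)
    buffer.seek(0)
    return base64.b64encode(buffer.getvalue()).decode("utf-8")
```
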
 def load_queue_action(filepath, state):
     global task_id
     gen = get_gen_info(state)
-    original_queue = gen.get("queue", []) # Store original queue for error case
+    original_queue = gen.get("queue", [])
+    save_path_base = server_config.get("save_path", "outputs")
+    loaded_cache_dir = os.path.join(save_path_base, "_loaded_queue_cache")
     if not filepath or not hasattr(filepath, 'name') or not Path(filepath.name).is_file():
         print("[load_queue_action] Warning: No valid file selected or file not found.")
-        # Return the current state of the DataFrame
         return update_queue_data(original_queue)
     newly_loaded_queue = []
@@ -509,6 +518,9 @@ def load_queue_action(filepath, state):
     try:
         print(f"[load_queue_action] Attempting to load queue from: {filepath.name}")
+        os.makedirs(loaded_cache_dir, exist_ok=True)
+        print(f"[load_queue_action] Using cache directory: {loaded_cache_dir}")
         with tempfile.TemporaryDirectory() as tmpdir:
             with zipfile.ZipFile(filepath.name, 'r') as zf:
                 if "queue.json" not in zf.namelist(): raise ValueError("queue.json not found in zip file")
@@ -523,7 +535,6 @@ def load_queue_action(filepath, state):
             print(f"[load_queue_action] Manifest loaded. Processing {len(loaded_manifest)} tasks.")
             for task_index, task_data in enumerate(loaded_manifest):
-                # (Keep the existing task processing logic here...)
                 if task_data is None or not isinstance(task_data, dict):
                     print(f"[load_queue_action] Skipping invalid task data at index {task_index}")
                     continue
@@ -531,144 +542,186 @@ def load_queue_action(filepath, state):
                 params = task_data.get('params', {})
                 task_id_loaded = task_data.get('id', 0)
                 max_id_in_file = max(max_id_in_file, task_id_loaded)
-                loaded_pil_images = {}
+                params['state'] = state
                 image_keys = ["image_start", "image_end", "image_refs"]
-                params['state'] = state # Add state back temporarily for consistency if needed by internal logic, but it's removed before saving
+                video_keys = ["video_guide", "video_mask"]
+                loaded_pil_images = {}
+                loaded_video_paths = {}
                 for key in image_keys:
                     image_filenames = params.get(key)
                     if image_filenames is None: continue
                     is_list = isinstance(image_filenames, list)
                     if not is_list: image_filenames = [image_filenames]
                     loaded_pils = []
                     for img_filename_in_zip in image_filenames:
-                        if not isinstance(img_filename_in_zip, str): continue
+                        if not isinstance(img_filename_in_zip, str):
+                            print(f"[load_queue_action] Warning: Non-string filename found for image key '{key}'. Skipping.")
+                            continue
                         img_load_path = os.path.join(tmpdir, img_filename_in_zip)
                         if not os.path.exists(img_load_path):
-                            print(f"[load_queue_action] Image file not found during load: {img_load_path}")
+                            print(f"[load_queue_action] Image file not found in extracted data: {img_load_path}. Skipping.")
                             continue
                         try:
                             pil_image = Image.open(img_load_path)
-                            # Ensure the image data is loaded into memory before the temp dir is cleaned up
                             pil_image.load()
-                            # Convert image right after loading
                             converted_image = convert_image(pil_image)
                             loaded_pils.append(converted_image)
-                            pil_image.close() # Close the file handle
+                            pil_image.close()
+                            print(f"Loaded image: {img_filename_in_zip} for key {key}")
                         except Exception as img_e:
                             print(f"[load_queue_action] Error loading image {img_filename_in_zip}: {img_e}")
                     if loaded_pils:
                         params[key] = loaded_pils if is_list else loaded_pils[0]
-                        loaded_pil_images[key] = params[key] # Store loaded PILs for preview generation
-                    else: params.pop(key, None)
+                        loaded_pil_images[key] = params[key]
+                    else:
+                        params.pop(key, None)
+                for key in video_keys:
+                    video_filename_in_zip = params.get(key)
+                    if video_filename_in_zip is None or not isinstance(video_filename_in_zip, str):
+                        continue
+                    video_load_path = os.path.join(tmpdir, video_filename_in_zip)
+                    if not os.path.exists(video_load_path):
+                        print(f"[load_queue_action] Video file not found in extracted data: {video_load_path}. Skipping.")
+                        params.pop(key, None)
+                        continue
+                    persistent_video_path = os.path.join(loaded_cache_dir, video_filename_in_zip)
+                    try:
+                        shutil.copy2(video_load_path, persistent_video_path)
+                        params[key] = persistent_video_path
+                        loaded_video_paths[key] = persistent_video_path
+                        print(f"Loaded video: {video_filename_in_zip} -> {persistent_video_path}")
+                    except Exception as vid_e:
+                        print(f"[load_queue_action] Error copying video {video_filename_in_zip} to cache: {vid_e}")
+                        params.pop(key, None)
-                # Generate preview base64 strings
-                primary_preview_pil, secondary_preview_pil = None, None
-                start_prev_pil_list = loaded_pil_images.get("image_start")
-                end_prev_pil_list = loaded_pil_images.get("image_end")
-                ref_prev_pil_list = loaded_pil_images.get("image_refs")
-                # Extract first image for preview if available
-                if start_prev_pil_list:
-                    primary_preview_pil = start_prev_pil_list[0] if isinstance(start_prev_pil_list, list) and start_prev_pil_list else start_prev_pil_list if not isinstance(start_prev_pil_list, list) else None
-                if end_prev_pil_list:
-                    secondary_preview_pil = end_prev_pil_list[0] if isinstance(end_prev_pil_list, list) and end_prev_pil_list else end_prev_pil_list if not isinstance(end_prev_pil_list, list) else None
-                elif ref_prev_pil_list and isinstance(ref_prev_pil_list, list) and ref_prev_pil_list:
-                    primary_preview_pil = ref_prev_pil_list[0]
+                primary_preview_pil_list = loaded_pil_images.get("image_start") or loaded_pil_images.get("image_refs")
+                secondary_preview_pil_list = loaded_pil_images.get("image_end")
+                primary_preview_pil = None
+                if primary_preview_pil_list:
+                    primary_preview_pil = primary_preview_pil_list[0] if isinstance(primary_preview_pil_list, list) else primary_preview_pil_list
+                secondary_preview_pil = None
+                if secondary_preview_pil_list:
+                    secondary_preview_pil = secondary_preview_pil_list[0] if isinstance(secondary_preview_pil_list, list) else secondary_preview_pil_list
-                # Generate base64 only if PIL image exists
                 start_b64 = [pil_to_base64_uri(primary_preview_pil, format="jpeg", quality=70)] if primary_preview_pil else None
                 end_b64 = [pil_to_base64_uri(secondary_preview_pil, format="jpeg", quality=70)] if secondary_preview_pil else None
-                # Get top-level image data (PIL objects) for runtime task
-                top_level_start_image = loaded_pil_images.get("image_start")
-                top_level_end_image = loaded_pil_images.get("image_end")
-                # Construct the runtime task dictionary
+                top_level_start_image = params.get("image_start") or params.get("image_refs")
+                top_level_end_image = params.get("image_end")
                 runtime_task = {
                     "id": task_id_loaded,
-                    "params": params.copy(), # Use a copy of params
-                    # Extract necessary params for top level if they exist
+                    "params": params.copy(),
                     "repeats": params.get('repeat_generation', 1),
                     "length": params.get('video_length'),
                     "steps": params.get('num_inference_steps'),
                     "prompt": params.get('prompt'),
-                    # Store the actual loaded PIL image data here
                     "start_image_data": top_level_start_image,
                     "end_image_data": top_level_end_image,
-                    # Store base64 previews generated above
                     "start_image_data_base64": start_b64,
                     "end_image_data_base64": end_b64,
                 }
                 newly_loaded_queue.append(runtime_task)
-                print(f"[load_queue_action] Processed task {task_index+1}/{len(loaded_manifest)}, ID: {task_id_loaded}")
+                print(f"[load_queue_action] Reconstructed task {task_index+1}/{len(loaded_manifest)}, ID: {task_id_loaded}")
-        # --- State Update ---
         with lock:
             print("[load_queue_action] Acquiring lock to update state...")
-            gen["queue"] = newly_loaded_queue[:] # Replace the queue in the state
-            local_queue_copy_for_global_ref = gen["queue"][:] # Copy for global ref update
-            current_max_id_in_new_queue = max([t['id'] for t in newly_loaded_queue if 'id' in t] + [0]) # Safer max ID calculation
-            # Update global task ID only if the loaded max ID is higher
-            if current_max_id_in_new_queue > task_id:
-                print(f"[load_queue_action] Updating global task_id from {task_id} to {current_max_id_in_new_queue + 1}")
-                task_id = current_max_id_in_new_queue + 1 # Ensure next ID is unique
+            gen["queue"] = newly_loaded_queue[:]
+            local_queue_copy_for_global_ref = gen["queue"][:]
+            current_max_id_in_new_queue = max([t['id'] for t in newly_loaded_queue if 'id' in t] + [0])
+            if current_max_id_in_new_queue >= task_id:
+                new_task_id = current_max_id_in_new_queue + 1
+                print(f"[load_queue_action] Updating global task_id from {task_id} to {new_task_id}")
+                task_id = new_task_id
             else:
-                print(f"[load_queue_action] Global task_id ({task_id}) is >= max in file ({current_max_id_in_new_queue}). Not changing task_id.")
+                print(f"[load_queue_action] Global task_id ({task_id}) is > max in file ({current_max_id_in_new_queue}). Not changing task_id.")
             gen["prompts_max"] = len(newly_loaded_queue)
             print("[load_queue_action] State update complete. Releasing lock.")
-        # --- Global Reference Update ---
         if local_queue_copy_for_global_ref is not None:
             print("[load_queue_action] Updating global queue reference...")
             update_global_queue_ref(local_queue_copy_for_global_ref)
         else:
-            # This case should ideally not be reached if state update happens
             print("[load_queue_action] Warning: Skipping global ref update as local copy is None.")
         print(f"[load_queue_action] Queue load successful. Returning DataFrame update for {len(newly_loaded_queue)} tasks.")
-        # *** Return the DataFrame update object ***
         return update_queue_data(newly_loaded_queue)
     except (ValueError, zipfile.BadZipFile, FileNotFoundError, Exception) as e:
         error_message = f"Error during queue load: {e}"
         print(f"[load_queue_action] Caught error: {error_message}")
         traceback.print_exc()
-        # Optionally show a Gradio warning/error to the user
-        gr.Warning(f"Failed to load queue: {error_message[:200]}") # Show truncated error
-        # *** Return the DataFrame update for the original queue ***
+        gr.Warning(f"Failed to load queue: {error_message[:200]}")
         print("[load_queue_action] Load failed. Returning DataFrame update for original queue.")
         return update_queue_data(original_queue)
     finally:
-        # Clean up the uploaded file object if it exists and has a path
         if filepath and hasattr(filepath, 'name') and filepath.name and os.path.exists(filepath.name):
+            if tempfile.gettempdir() in os.path.abspath(filepath.name):
                 try:
-                    # Gradio often uses temp files, attempting removal is good practice
-                    # os.remove(filepath.name)
-                    # print(f"[load_queue_action] Cleaned up temporary upload file: {filepath.name}")
-                    pass # Let Gradio manage its temp files unless specifically needed
+                    os.remove(filepath.name)
+                    print(f"[load_queue_action] Removed temporary upload file: {filepath.name}")
                 except OSError as e:
-                    # Ignore errors like "file not found" if already cleaned up
                     print(f"[load_queue_action] Info: Could not remove temp file {filepath.name}: {e}")
-                    pass
+            else:
+                print(f"[load_queue_action] Info: Did not remove non-temporary file: {filepath.name}")
 def clear_queue_action(state):
     gen = get_gen_info(state)
     queue = gen.get("queue", [])
-    if not queue:
-        gr.Info("Queue is already empty.")
-        return update_queue_data([])
+    aborted_current = False
+    cleared_pending = False
     with lock:
+        if "in_progress" in gen and gen["in_progress"]:
+            print("Clear Queue: Signalling abort for in-progress task.")
+            gen["abort"] = True
+            gen["extra_orders"] = 0
+            if wan_model is not None:
+                wan_model._interrupt = True
+            aborted_current = True
+        if queue:
+            if len(queue) > 1 or (len(queue) == 1 and queue[0] is not None and queue[0].get('id') is not None):
+                print(f"Clear Queue: Clearing {len(queue)} tasks from queue.")
                 queue.clear()
+                cleared_pending = True
+            else:
+                pass
+    if aborted_current or cleared_pending:
         gen["prompts_max"] = 0
+    if aborted_current and cleared_pending:
+        gr.Info("Queue cleared and current generation aborted.")
+    elif aborted_current:
+        gr.Info("Current generation aborted.")
+    elif cleared_pending:
         gr.Info("Queue cleared.")
+    else:
+        gr.Info("Queue is already empty or only contains the active task (which wasn't aborted now).")
     return update_queue_data([])
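
Note: the abort path relies on cooperative cancellation. `gen["abort"]` and `wan_model._interrupt` are plain flags, so the generation loop has to poll them between steps for the abort to take effect. A toy illustration of the pattern (the `Worker` class here is hypothetical):

```python
import threading
import time

class Worker:
    def __init__(self):
        self._interrupt = False

    def run(self, steps: int):
        for step in range(steps):
            if self._interrupt:      # polled between steps, like a denoising loop
                print(f"aborted at step {step}")
                return
            time.sleep(0.05)         # stand-in for one unit of real work

worker = Worker()
thread = threading.Thread(target=worker.run, args=(100,))
thread.start()
worker._interrupt = True             # flag flips; the loop exits at its next check
thread.join()
```
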
+def quit_application():
+    print("Save and Quit requested...")
+    autosave_queue()
+    import signal
+    os.kill(os.getpid(), signal.SIGINT)
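
Note: `quit_application` saves the queue and then delivers SIGINT to its own process, which on POSIX is equivalent to pressing Ctrl+C: Python raises `KeyboardInterrupt` in the main thread and `atexit` handlers still run, making this a graceful exit rather than a hard kill. In condensed form:

```python
import os
import signal

def request_shutdown():
    # POSIX semantics: the signal becomes KeyboardInterrupt in the main
    # thread, so atexit handlers and finally blocks run before exit.
    os.kill(os.getpid(), signal.SIGINT)
```
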
 def autosave_queue():
     global global_queue_ref
     if not global_queue_ref:
@@ -682,14 +735,20 @@ def autosave_queue():
     def _save_queue_to_file(queue_to_save, output_filename):
         if not queue_to_save: return None
         with tempfile.TemporaryDirectory() as tmpdir:
             queue_manifest = []
-            image_paths_in_zip = {}
+            file_paths_in_zip = {}
             for task_index, task in enumerate(queue_to_save):
-                if task is None or not isinstance(task, dict): continue
+                if task is None or not isinstance(task, dict) or task.get('id') is None: continue
                 params_copy = task.get('params', {}).copy()
                 task_id_s = task.get('id', f"task_{task_index}")
                 image_keys = ["image_start", "image_end", "image_refs"]
+                video_keys = ["video_guide", "video_mask"]
                 for key in image_keys:
                     images_pil = params_copy.get(key)
                     if images_pil is None: continue
@@ -699,38 +758,72 @@ def autosave_queue():
                     for img_index, pil_image in enumerate(images_pil):
                         if not isinstance(pil_image, Image.Image): continue
                         img_id = id(pil_image)
-                        if img_id in image_paths_in_zip:
-                            image_filenames_for_json.append(image_paths_in_zip[img_id])
+                        if img_id in file_paths_in_zip:
+                            image_filenames_for_json.append(file_paths_in_zip[img_id])
                             continue
                         img_filename_in_zip = f"task{task_id_s}_{key}_{img_index}.png"
                         img_save_path = os.path.join(tmpdir, img_filename_in_zip)
                         try:
                             pil_image.save(img_save_path, "PNG")
                             image_filenames_for_json.append(img_filename_in_zip)
-                            image_paths_in_zip[img_id] = img_filename_in_zip
+                            file_paths_in_zip[img_id] = img_filename_in_zip
                         except Exception as e:
                             print(f"Autosave error saving image {img_filename_in_zip}: {e}")
                     if image_filenames_for_json:
                         params_copy[key] = image_filenames_for_json if is_list else image_filenames_for_json[0]
                     else:
                         params_copy.pop(key, None)
+                for key in video_keys:
+                    video_path_orig = params_copy.get(key)
+                    if video_path_orig is None or not isinstance(video_path_orig, str):
+                        continue
+                    if video_path_orig in file_paths_in_zip:
+                        params_copy[key] = file_paths_in_zip[video_path_orig]
+                        continue
+                    if not os.path.isfile(video_path_orig):
+                        print(f"Warning (Autosave): Video file not found for key '{key}' in task {task_id_s}: {video_path_orig}. Skipping.")
+                        params_copy.pop(key, None)
+                        continue
+                    _, extension = os.path.splitext(video_path_orig)
+                    vid_filename_in_zip = f"task{task_id_s}_{key}{extension if extension else '.mp4'}"
+                    vid_save_path = os.path.join(tmpdir, vid_filename_in_zip)
+                    try:
+                        shutil.copy2(video_path_orig, vid_save_path)
+                        params_copy[key] = vid_filename_in_zip
+                        file_paths_in_zip[video_path_orig] = vid_filename_in_zip
+                    except Exception as e:
+                        print(f"Error (Autosave) copying video {video_path_orig} to {vid_filename_in_zip} for task {task_id_s}: {e}")
+                        params_copy.pop(key, None)
                 params_copy.pop('state', None)
                 params_copy.pop('start_image_data_base64', None)
                 params_copy.pop('end_image_data_base64', None)
+                params_copy.pop('start_image_data', None)
+                params_copy.pop('end_image_data', None)
                 manifest_entry = {
-                    "id": task.get('id'), "params": params_copy,
+                    "id": task.get('id'),
+                    "params": params_copy,
                 }
+                manifest_entry = {k: v for k, v in manifest_entry.items() if v is not None}
                 queue_manifest.append(manifest_entry)
             manifest_path = os.path.join(tmpdir, "queue.json")
             with open(manifest_path, 'w', encoding='utf-8') as f: json.dump(queue_manifest, f, indent=4)
             with zipfile.ZipFile(output_filename, 'w', zipfile.ZIP_DEFLATED) as zf:
                 zf.write(manifest_path, arcname="queue.json")
-                for saved_img_rel_path in image_paths_in_zip.values():
-                    saved_img_abs_path = os.path.join(tmpdir, saved_img_rel_path)
-                    if os.path.exists(saved_img_abs_path):
-                        zf.write(saved_img_abs_path, arcname=saved_img_rel_path)
+                for saved_file_rel_path in file_paths_in_zip.values():
+                    saved_file_abs_path = os.path.join(tmpdir, saved_file_rel_path)
+                    if os.path.exists(saved_file_abs_path):
+                        zf.write(saved_file_abs_path, arcname=saved_file_rel_path)
+                    else:
+                        print(f"Warning (Autosave): File {saved_file_rel_path} not found during zipping.")
             return output_filename
-        return None # Should not happen if queue has items
+        return None
     saved_path = _save_queue_to_file(global_queue_ref, AUTOSAVE_FILENAME)
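
Note: both save paths deduplicate through `file_paths_in_zip`, keyed by `id(pil_image)` for in-memory images and by source path for videos, so an object referenced by several tasks is written to the archive exactly once. A reduced sketch of the image case (names illustrative; `id()` is only stable while the object stays alive, which holds here because the queue keeps the references):

```python
import os
from PIL import Image

def save_images_once(images: list[Image.Image], tmpdir: str) -> dict[int, str]:
    """Save each distinct PIL Image object once; repeated references
    reuse the filename recorded under the object's id()."""
    saved: dict[int, str] = {}
    for index, image in enumerate(images):
        if id(image) in saved:
            continue                      # already written for an earlier task
        filename = f"image_{index}.png"
        image.save(os.path.join(tmpdir, filename), "PNG")
        saved[id(image)] = filename
    return saved
```
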
@@ -745,17 +838,15 @@ def autosave_queue():
 def autoload_queue(state):
     global task_id
-    # Initial check using the original state
     try:
-        gen = get_gen_info(state) # Make sure initial state is a dict
+        gen = get_gen_info(state)
         original_queue = gen.get("queue", [])
     except AttributeError:
         print("[autoload_queue] Error: Initial state is not a dictionary. Cannot autoload.")
-        # Return default values indicating no load occurred and the state is unchanged
-        return gr.update(visible=False), False, state # Return an empty DF update
+        return gr.update(visible=False), False, state
     loaded_flag = False
-    dataframe_update = update_queue_data(original_queue) # Default update is the original queue
+    dataframe_update = update_queue_data(original_queue)
     if not original_queue and Path(AUTOSAVE_FILENAME).is_file():
         print(f"Autoloading queue from {AUTOSAVE_FILENAME}...")
@@ -763,40 +854,56 @@ def autoload_queue(state):
             def __init__(self, name):
                 self.name = name
         mock_filepath = MockFile(AUTOSAVE_FILENAME)
-        # Call load_queue_action, it modifies 'state' internally and returns a DataFrame update
         dataframe_update = load_queue_action(mock_filepath, state)
-        # Now check the 'state' dictionary which should have been modified by load_queue_action
-        gen = get_gen_info(state) # Use the (potentially) modified state dictionary
+        gen = get_gen_info(state)
         loaded_queue_after_action = gen.get("queue", [])
-        if loaded_queue_after_action: # Check if the queue in the state is now populated
+        if loaded_queue_after_action:
             print(f"Autoload successful. Loaded {len(loaded_queue_after_action)} tasks into state.")
             loaded_flag = True
-            # Global ref update was already done inside load_queue_action if successful
         else:
             print("Autoload attempted but queue in state remains empty (file might be empty or invalid).")
-            # Ensure state reflects empty queue if load failed but file existed
             with lock:
                 gen["queue"] = []
                 gen["prompts_max"] = 0
             update_global_queue_ref([])
-            dataframe_update = update_queue_data([]) # Ensure UI shows empty queue
+            dataframe_update = update_queue_data([])
-    else: # Handle cases where autoload shouldn't happen
+    else:
         if original_queue:
             print("Autoload skipped: Queue is not empty.")
-            update_global_queue_ref(original_queue) # Ensure global ref matches current state
-            dataframe_update = update_queue_data(original_queue) # UI should show current queue
+            update_global_queue_ref(original_queue)
+            dataframe_update = update_queue_data(original_queue)
         else:
             print(f"Autoload skipped: {AUTOSAVE_FILENAME} not found.")
-            update_global_queue_ref([]) # Ensure global ref is empty
-            dataframe_update = update_queue_data([]) # UI should show empty queue
+            update_global_queue_ref([])
+            dataframe_update = update_queue_data([])
-    # Return the DataFrame update needed for the UI, the flag, and the final state dictionary
     return dataframe_update, loaded_flag, state
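
Note: the `MockFile` shim works because `load_queue_action` only ever reads `filepath.name`, so autoload can reuse the upload handler by wrapping a plain path in anything that exposes that attribute:

```python
class MockFile:
    """Duck-types the file object Gradio's UploadButton passes to handlers;
    only the .name attribute (a filesystem path) is actually consumed."""
    def __init__(self, name: str):
        self.name = name

# Reusing the upload handler for the autosave file:
# dataframe_update = load_queue_action(MockFile(AUTOSAVE_FILENAME), state)
```
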
+def run_autoload_and_prepare_ui(current_state):
+    df_update, loaded_flag, modified_state = autoload_queue(current_state)
+    should_start_processing = loaded_flag
+    accordion_update = gr.Accordion(open=True) if loaded_flag else gr.update()
+    return df_update, gr.update(visible=loaded_flag), accordion_update, should_start_processing, modified_state
+
+def start_processing_if_needed(should_start, current_state):
+    if not isinstance(current_state, dict) or 'gen' not in current_state:
+        yield "Error: Invalid state received before processing."
+        return
+    if should_start:
+        yield from process_tasks(current_state)
+    else:
+        yield None
+
+def finalize_generation_with_state(current_state):
+    if not isinstance(current_state, dict) or 'gen' not in current_state:
+        return gr.update(), gr.update(interactive=True), gr.update(visible=True), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False, value=""), gr.update(), current_state
+    gallery_update, abort_btn_update, gen_btn_update, add_queue_btn_update, current_gen_col_update, gen_info_update = finalize_generation(current_state)
+    accordion_update = gr.Accordion(open=False) if len(get_gen_info(current_state).get("queue", [])) <= 1 else gr.update()
+    return gallery_update, abort_btn_update, gen_btn_update, add_queue_btn_update, current_gen_col_update, gen_info_update, accordion_update, current_state
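
Note: `start_processing_if_needed` is a generator, which is how Gradio streams progress: every `yield` pushes an intermediate value to the bound output component, and `yield from process_tasks(...)` forwards the inner generator's updates unchanged. A self-contained sketch of a streaming handler (assuming a recent Gradio 4.x API):

```python
import time
import gradio as gr

def long_task():
    # Each yield updates the bound output before the event finishes,
    # mirroring how process_tasks streams status into gen_status.
    for step in range(1, 6):
        time.sleep(0.2)
        yield f"step {step}/5"

with gr.Blocks() as demo:
    status = gr.Textbox(label="status")
    gr.Button("Run").click(fn=long_task, inputs=None, outputs=status)

# demo.launch()
```
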
 def get_queue_table(queue):
     data = []
@@ -1179,7 +1286,6 @@ text_encoder_choices = ["ckpts/models_t5_umt5-xxl-enc-bf16.safetensors", "ckpts/
 server_config_filename = "wgp_config.json"
 if not os.path.isfile(server_config_filename) and os.path.isfile("gradio_config.json"):
-    import shutil
     shutil.move("gradio_config.json", server_config_filename)
 if not Path(server_config_filename).is_file():
@@ -2984,7 +3090,6 @@ def download_loras():
     lora_dir = get_lora_dir(get_model_filename("i2v"), quantizeTransformer)
     log_path = os.path.join(lora_dir, "log.txt")
     if not os.path.isfile(log_path):
-        import shutil
         tmp_path = os.path.join(lora_dir, "tmp_lora_dowload")
         import glob
         snapshot_download(repo_id="DeepBeepMeep/Wan2.1", allow_patterns="loras_i2v/*", local_dir= tmp_path)
@@ -3484,7 +3589,7 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
             with gr.Row():
                 onemore_btn = gr.Button("One More Sample Please !")
                 abort_btn = gr.Button("Abort")
+            with gr.Accordion("Queue Management", open=False) as queue_accordion:
                 queue_df = gr.DataFrame(
                     headers=["Qty","Prompt", "Length","Steps","", "", "", "", ""],
                     datatype=[ "str","markdown","str", "markdown", "markdown", "markdown", "str", "str", "str"],
@@ -3494,7 +3599,7 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
                     wrap=True,
                     value=[],
                     line_breaks= True,
-                    visible= False,
+                    visible= True,
                     elem_id="queue_df"
                 )
                 with gr.Row():
@@ -3502,6 +3607,7 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
                     save_queue_btn = gr.DownloadButton("Save Queue", size="sm")
                     load_queue_btn = gr.UploadButton("Load Queue", file_types=[".zip"], size="sm")
                     clear_queue_btn = gr.Button("Clear Queue", size="sm", variant="stop")
+                    quit_button = gr.Button("Save and Quit", size="sm", variant="secondary")
             trigger_zip_download_js = """
             (base64String) => {
                 if (!base64String) {
@@ -3544,14 +3650,37 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
                 js=trigger_zip_download_js
             )
+            should_start_flag = gr.State(False)
             load_queue_btn.upload(
                 fn=load_queue_action,
                 inputs=[load_queue_btn, state],
                 outputs=[queue_df]
             ).then(
-                fn=lambda s: gr.update(visible=bool(get_gen_info(s).get("queue",[]))),
+                fn=lambda s: (gr.update(visible=bool(get_gen_info(s).get("queue",[]))), gr.Accordion(open=True)) if bool(get_gen_info(s).get("queue",[])) else (gr.update(visible=False), gr.update()),
                 inputs=[state],
-                outputs=[current_gen_column]
+                outputs=[current_gen_column, queue_accordion]
+            ).then(
+                fn=lambda s: (
+                    (gr.Button(visible=False), gr.Button(visible=True), gr.Column(visible=True), True)
+                    if bool(get_gen_info(s).get("queue",[]))
+                    else (gr.Button(visible=True), gr.Button(visible=False), gr.Column(visible=False), False)
+                ),
+                inputs=[state],
+                outputs=[generate_btn, add_to_queue_btn, current_gen_column, should_start_flag]
+            ).then(
+                fn=start_processing_if_needed,
+                inputs=[should_start_flag, state],
+                outputs=[gen_status],
+                trigger_mode="once"
+            ).then(
+                fn=finalize_generation_with_state,
+                inputs=[state],
+                outputs=[output, abort_btn, generate_btn, add_to_queue_btn, current_gen_column, gen_info, queue_accordion, state],
+                trigger_mode="always_last"
+            ).then(
+                unload_model_if_needed,
+                inputs= [state],
+                outputs= []
             )
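
Note: the upload handler above is now a six-stage `.then()` chain. Each stage starts only after the previous one finishes, with its own inputs and outputs, and `trigger_mode` controls re-entrancy ("once" ignores re-triggers while a run is in flight; "always_last" keeps only the latest pending trigger). A stripped-down sketch of the chaining shape:

```python
import gradio as gr

with gr.Blocks() as demo:
    box = gr.Textbox(label="log")
    btn = gr.Button("Go")
    # Stage 2 runs only after stage 1 completes; each stage declares
    # its own inputs/outputs, just like the load_queue chain above.
    btn.click(fn=lambda: "stage 1 done", inputs=None, outputs=box).then(
        fn=lambda s: s + " -> stage 2 done", inputs=box, outputs=box
    )
```
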
             clear_queue_btn.click(
@@ -3559,9 +3688,14 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
                 inputs=[state],
                 outputs=[queue_df]
             ).then(
-                fn=lambda: gr.update(visible=False),
+                fn=lambda: (gr.update(visible=False), gr.Accordion(open=False)),
                 inputs=None,
-                outputs=[current_gen_column]
+                outputs=[current_gen_column, queue_accordion]
+            )
+
+            quit_button.click(
+                fn=quit_application,
+                inputs=[],
+                outputs=[]
             )
     extra_inputs = prompt_vars + [wizard_prompt, wizard_variables_var, wizard_prompt_activated_var, video_prompt_column, image_prompt_column,
@@ -3644,6 +3778,10 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
             ).then(finalize_generation,
                 inputs= [state],
                 outputs= [output, abort_btn, generate_btn, add_to_queue_btn, current_gen_column, gen_info]
+            ).then(
+                fn=lambda s: gr.Accordion(open=False) if len(get_gen_info(s).get("queue", [])) <= 1 else gr.update(),
+                inputs=[state],
+                outputs=[queue_accordion]
             ).then(unload_model_if_needed,
                 inputs= [state],
                 outputs= []
@@ -3658,6 +3796,10 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
             ).then(fn=process_prompt_and_add_tasks,
                 inputs = [state, model_choice],
                 outputs=queue_df
+            ).then(
+                fn=lambda s: gr.Accordion(open=True) if len(get_gen_info(s).get("queue", [])) > 1 else gr.update(), # Expand if queue has items (len > 1 assumes placeholder)
+                inputs=[state],
+                outputs=[queue_accordion]
             ).then(
                 fn=update_status,
                 inputs = [state],
@@ -3675,7 +3817,7 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
             gen_info,
             prompt, wizard_prompt, wizard_prompt_activated_var, wizard_variables_var,
             prompt_column_advanced, prompt_column_wizard, prompt_column_wizard_vars,
-            advanced_row, image_prompt_column, video_prompt_column,
+            advanced_row, image_prompt_column, video_prompt_column, queue_accordion,
             *prompt_vars
     )
@@ -4158,7 +4300,7 @@ def create_demo():
             gen_info,
             prompt, wizard_prompt, wizard_prompt_activated_var, wizard_variables_var,
             prompt_column_advanced, prompt_column_wizard, prompt_column_wizard_vars,
-            advanced_row, image_prompt_column, video_prompt_column,
+            advanced_row, image_prompt_column, video_prompt_column, queue_accordion,
             *prompt_vars_outputs
         ) = generate_video_tab(model_choice=model_choice, header=header)
         with gr.Tab("Informations"):
@@ -4172,30 +4314,11 @@ def create_demo():
             generate_about_tab()
         should_start_flag = gr.State(False)
-        def run_autoload_and_prepare_ui(current_state):
-            df_update, loaded_flag, modified_state = autoload_queue(current_state)
-            should_start_processing = loaded_flag
-            return df_update, gr.update(visible=loaded_flag), should_start_processing, modified_state
-
-        def start_processing_if_needed(should_start, current_state):
-            if not isinstance(current_state, dict) or 'gen' not in current_state:
-                yield "Error: Invalid state received before processing."
-                return
-            if should_start:
-                yield from process_tasks(current_state)
-            else:
-                yield "Autoload complete. Processing not started."
-
-        def finalize_generation_with_state(current_state):
-            if not isinstance(current_state, dict) or 'gen' not in current_state:
-                return gr.update(), gr.update(interactive=True), gr.update(visible=True), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False, value=""), current_state
-            gallery_update, abort_btn_update, gen_btn_update, add_queue_btn_update, current_gen_col_update, gen_info_update = finalize_generation(current_state)
-            return gallery_update, abort_btn_update, gen_btn_update, add_queue_btn_update, current_gen_col_update, gen_info_update, current_state
-
         demo.load(
             fn=run_autoload_and_prepare_ui,
             inputs=[state],
-            outputs=[queue_df, current_gen_column, should_start_flag, state]
+            outputs=[queue_df, current_gen_column, queue_accordion, should_start_flag, state]
         ).then(
             fn=start_processing_if_needed,
             inputs=[should_start_flag, state],
@@ -4204,7 +4327,7 @@ def create_demo():
         ).then(
             fn=finalize_generation_with_state,
             inputs=[state],
-            outputs=[output, abort_btn, generate_btn, add_to_queue_btn, current_gen_column, gen_info, state],
+            outputs=[output, abort_btn, generate_btn, add_to_queue_btn, current_gen_column, gen_info, queue_accordion, state],
             trigger_mode="always_last"
         )