add queues

This commit is contained in:
Chris Malone 2025-03-24 17:20:53 +11:00
parent f2c5a06626
commit f7c5193eaa


@@ -1,5 +1,7 @@
import os
import time
import threading
from typing import Dict, Any
import argparse
from mmgp import offload, safetensors2, profile_type
try:
@@ -21,6 +23,7 @@ import traceback
import math
import asyncio
from wan.utils import prompt_parser
from typing import List
PROMPT_VARS_MAX = 10
target_mmgp_version = "3.3.0"
@@ -29,6 +32,194 @@ mmgp_version = version("mmgp")
if mmgp_version != target_mmgp_version:
print(f"Incorrect version of mmgp ({mmgp_version}), version {target_mmgp_version} is needed. Please upgrade with the command 'pip install -r requirements.txt'")
exit()
queue = []
lock = threading.Lock()
current_task_id = None
task_id = 0
progress_tracker = {}
tracker_lock = threading.Lock()
file_list = []
def runner():
global current_task_id
while True:
with lock:
for item in queue:
task_id = item['id']
with tracker_lock:
progress = progress_tracker.get(task_id, {})
if item['status'] == "Processing":
current_step = progress.get('current_step', 0)
total_steps = progress.get('total_steps', 0)
elapsed = time.time() - progress.get('start_time', time.time())
item.update({
'progress': f"{((current_step/total_steps)*100 if total_steps > 0 else 0):.1f}%",
'steps': f"{current_step}/{total_steps}",
'time': f"{elapsed:.1f}s"
})
if not any(item['status'] == "Processing" for item in queue):
for item in queue:
if item['status'] == "Queued":
item['status'] = "Processing"
current_task_id = item['id']
threading.Thread(target=process_task, args=(item,)).start()
break
time.sleep(1)
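A note on the design: runner() wakes once per second and rescans the whole list under a single lock. A hedged sketch of an alternative, using only the standard library, wakes the worker as soon as something is enqueued via a threading.Event; the names `pending`, `enqueue` and `worker` are illustrative and not part of this commit.

```python
# Illustrative alternative, not part of this commit: wake the worker with an
# Event instead of sleeping a fixed second between scans.
import threading
import time

pending = []                      # plays the role of the global `queue` list
pending_lock = threading.Lock()
work_available = threading.Event()

def enqueue(task):
    with pending_lock:
        pending.append({**task, "status": "Queued"})
    work_available.set()          # wake the worker immediately

def worker(process):
    while True:
        work_available.wait(timeout=1.0)   # falls back to 1s polling
        work_available.clear()
        with pending_lock:
            task = next((t for t in pending if t["status"] == "Queued"), None)
            if task is not None:
                task["status"] = "Processing"
        if task is not None:
            process(task)                   # user-supplied processing function
            with pending_lock:
                pending.remove(task)

threading.Thread(target=worker, args=(lambda t: time.sleep(0.1),), daemon=True).start()
enqueue({"id": 1, "prompt": "a cat on a bike"})
```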
def process_prompt_and_add_tasks(
prompt,
negative_prompt,
resolution,
video_length,
seed,
num_inference_steps,
guidance_scale,
flow_shift,
embedded_guidance_scale,
repeat_generation,
multi_images_gen_type,
tea_cache,
tea_cache_start_step_perc,
loras_choices,
loras_mult_choices,
image_to_continue,
image_to_end,
video_to_continue,
max_frames,
RIFLEx_setting,
slg_switch,
slg_layers,
slg_start,
slg_end,
state_arg,
image2video
):
if len(prompt) ==0:
return
prompt, errors = prompt_parser.process_template(prompt)
if len(errors) > 0:
gr.Info("Error processing prompt template: " + errors)
return
prompts = prompt.replace("\r", "").split("\n")
prompts = [prompt.strip() for prompt in prompts if len(prompt.strip())>0 and not prompt.startswith("#")]
if len(prompts) ==0:
return
for single_prompt in prompts:
task_params = (
single_prompt,
negative_prompt,
resolution,
video_length,
seed,
num_inference_steps,
guidance_scale,
flow_shift,
embedded_guidance_scale,
repeat_generation,
multi_images_gen_type,
tea_cache,
tea_cache_start_step_perc,
loras_choices,
loras_mult_choices,
image_to_continue,
image_to_end,
video_to_continue,
max_frames,
RIFLEx_setting,
slg_switch,
slg_layers,
slg_start,
slg_end,
state_arg,
image2video
)
add_video_task(*task_params)
return update_queue_data()
def process_task(task):
try:
task_id, *params = task['params']
generate_video(task_id, *params)
finally:
with lock:
queue[:] = [item for item in queue if item['id'] != task['id']]
with tracker_lock:
if task['id'] in progress_tracker:
del progress_tracker[task['id']]
def add_video_task(*params):
global task_id
with lock:
task_id += 1
current_task_id = task_id
queue.append({
"id": current_task_id,
"params": (current_task_id,) + params,
"status": "Queued",
"progress": "0.0%",
"steps": f"0/{params[5]}",
"time": "--",
"prompt": params[0]
})
return
def move_up(selected_indices):
if not selected_indices or len(selected_indices) == 0:
return update_queue_data()
idx = selected_indices[0]
if isinstance(idx, list):
idx = idx[0]
idx = int(idx)
with lock:
if idx > 0:
queue[idx], queue[idx-1] = queue[idx-1], queue[idx]
return update_queue_data()
def move_down(selected_indices):
if not selected_indices or len(selected_indices) == 0:
return update_queue_data()
idx = selected_indices[0]
if isinstance(idx, list):
idx = idx[0]
idx = int(idx)
with lock:
if idx < len(queue)-1:
queue[idx], queue[idx+1] = queue[idx+1], queue[idx]
return update_queue_data()
def remove_task(selected_indices):
if not selected_indices or len(selected_indices) == 0:
return update_queue_data()
idx = selected_indices[0]
if isinstance(idx, list):
idx = idx[0]
idx = int(idx)
with lock:
if idx < len(queue):
if idx == 0:
wan_model._interrupt = True
del queue[idx]
return update_queue_data()
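remove_task only deletes rows; when the selected row is the one currently running (index 0) it additionally sets wan_model._interrupt and relies on the generation loop to notice the flag. A minimal sketch of that cooperative-cancellation pattern, with a hypothetical Generator class standing in for the real model code:

```python
import threading
import time

class Generator:
    """Hypothetical stand-in for the model; only the flag-checking pattern matters."""
    def __init__(self):
        self._interrupt = False

    def generate(self, steps, callback=None):
        for step in range(steps):
            if self._interrupt:          # checked once per denoising step
                return None              # abort quietly instead of raising
            time.sleep(0.05)             # stand-in for one diffusion step
            if callback is not None:
                callback(step, None)
        return "finished"

gen = Generator()
t = threading.Thread(target=lambda: print(gen.generate(100)))
t.start()
time.sleep(0.5)
gen._interrupt = True                    # what removing the running task effectively does
t.join()
```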
def update_queue_data():
with lock:
data = []
for item in queue:
data.append([
str(item['id']),
item['status'],
item.get('progress', "0.0%"),
item.get('steps', ''),
item.get('time', '--'),
(item['prompt'][:47] + '...') if len(item['prompt']) > 50 else item['prompt'],
"",
"",
""
])
return data
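Taken together, these functions form a small producer/consumer pipeline: the UI appends task dicts to the shared list, the daemon runner promotes and executes them, and update_queue_data folds the per-task progress_tracker entries back into table rows. A compressed, self-contained sketch of that flow; `fake_generate` and the field names mirror the code above but are illustrative only.

```python
import threading
import time

queue, tracker = [], {}
lock, tracker_lock = threading.Lock(), threading.Lock()

def fake_generate(task):
    for step in range(task["total_steps"]):
        time.sleep(0.01)                       # stand-in for one denoising step
        with tracker_lock:
            tracker[task["id"]] = {"current_step": step + 1,
                                   "total_steps": task["total_steps"]}
    with lock:
        queue.remove(task)

def table_rows():
    with lock:
        rows = []
        for t in queue:
            with tracker_lock:
                p = tracker.get(t["id"], {})
            done, total = p.get("current_step", 0), p.get("total_steps", 1)
            rows.append([t["id"], t["status"], f"{100 * done / total:.0f}%"])
        return rows

task = {"id": 1, "status": "Processing", "total_steps": 20}
with lock:
    queue.append(task)
threading.Thread(target=fake_generate, args=(task,), daemon=True).start()
time.sleep(0.1)
print(table_rows())                            # e.g. [[1, 'Processing', '45%']]
```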
def _parse_args():
parser = argparse.ArgumentParser(
description="Generate a video from a text prompt or image using Gradio")
@@ -788,53 +979,22 @@ def save_video(final_frames, output_path, fps=24):
final_frames = (final_frames * 255).astype(np.uint8)
ImageSequenceClip(list(final_frames), fps=fps).write_videofile(output_path, verbose= False, logger = None)
def build_callback(state, pipe, progress, status, num_inference_steps):
def callback(step_idx, latents, read_state = False):
status = state["progress_status"]
if read_state:
phase, step_idx = state["progress_phase"]
else:
step_idx += 1
if state.get("abort", False):
# pipe._interrupt = True
phase = " - Aborting"
elif step_idx == num_inference_steps:
phase = " - VAE Decoding"
else:
phase = " - Denoising"
state["progress_phase"] = (phase, step_idx)
status_msg = status + phase
if step_idx >= 0:
progress( (step_idx , num_inference_steps) , status_msg , num_inference_steps)
else:
progress(0, status_msg)
return callback
def build_callback(task_id, total_steps):
start_time = time.time()
def update_progress(step, _):
with tracker_lock:
elapsed = time.time() - start_time
progress_tracker[task_id] = {
'current_step': step + 1,
'total_steps': total_steps,
'start_time': start_time,
'last_update': time.time()
}
return update_progress
def abort_generation(state):
if "in_progress" in state:
state["abort"] = True
state["extra_orders"] = 0
wan_model._interrupt= True
return gr.Button(interactive= False)
else:
return gr.Button(interactive= True)
def refresh_gallery(state, txt):
file_list = state.get("file_list", None)
prompt = state.get("prompt", "")
if len(prompt) == 0:
return file_list, gr.Text(visible= False, value="")
else:
prompts_max = state.get("prompts_max",0)
prompt_no = state.get("prompt_no",0)
if prompts_max >1 :
label = f"Current Prompt ({prompt_no+1}/{prompts_max})"
else:
label = f"Current Prompt"
return file_list, gr.Text(visible= True, value=prompt, label=label)
def refresh_gallery(state):
file_list = state.get("file_list", None)
return file_list
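The new build_callback no longer drives gr.Progress directly; it only records current_step, total_steps and start_time under tracker_lock, and the runner turns that record into the strings shown in the queue table. A small helper like the following could do the same formatting; the ETA column is an illustrative extension that the commit itself does not compute.

```python
import time

def format_progress(entry):
    """Turn a progress_tracker record into display strings (ETA is an extension)."""
    step = entry.get("current_step", 0)
    total = entry.get("total_steps", 0)
    elapsed = time.time() - entry.get("start_time", time.time())
    percent = (step / total * 100) if total else 0.0
    eta = (elapsed / step) * (total - step) if step else float("inf")
    return {"progress": f"{percent:.1f}%",
            "steps": f"{step}/{total}",
            "time": f"{elapsed:.1f}s",
            "eta": "--" if eta == float("inf") else f"{eta:.0f}s"}

print(format_progress({"current_step": 12, "total_steps": 30,
                       "start_time": time.time() - 36.0}))
```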
def finalize_gallery(state):
choice = 0
if "in_progress" in state:
@@ -845,7 +1005,7 @@ def finalize_gallery(state):
time.sleep(0.2)
global gen_in_progress
gen_in_progress = False
return gr.Gallery(selected_index=choice), gr.Button(interactive= True), gr.Button(visible= True), gr.Checkbox(visible= False), gr.Text(visible= False, value="")
return gr.Gallery(selected_index=choice), gr.Button(interactive=True), gr.Button(visible=False), gr.Checkbox(visible=False), gr.Text(visible=False, value="")
def select_video(state , event_data: gr.EventData):
data= event_data._data
@@ -862,36 +1022,8 @@ def expand_slist(slist, num_inference_steps ):
pos += inc
return new_slist
def one_more_video(state):
extra_orders = state.get("extra_orders", 0)
extra_orders += 1
state["extra_orders"] = extra_orders
prompts_max = state.get("prompts_max",0)
if prompts_max == 0:
return state
prompt_no = state["prompt_no"]
video_no = state["video_no"]
total_video = state["total_video"]
# total_video += (prompts_max- prompt_no)
total_video += 1
total_generation = state["total_generation"] + extra_orders
state["total_video"] = total_video
state["progress_status"] = f"Video {video_no}/{total_video}"
offload.shared_state["refresh"] = 1
# if (prompts_max - prompt_no) > 1:
# gr.Info(f"An extra video generation is planned for a total of {total_generation} videos for the next {prompts_max - prompt_no} prompts")
# else:
gr.Info(f"An extra video generation is planned for a total of {total_generation} videos for this prompt")
return state
def prepare_generate_video():
return gr.Button(visible= False), gr.Checkbox(visible= True)
def generate_video(
task_id,
prompt,
negative_prompt,
resolution,
@@ -921,7 +1053,6 @@ def generate_video(
progress=gr.Progress() #track_tqdm= True
):
global wan_model, offloadobj
reload_needed = state.get("_reload_needed", False)
file_model_needed = model_needed(image2video)
@@ -930,9 +1061,9 @@
offloadobj.release()
offloadobj = None
wan_model = None
yield f"Loading model {get_model_name(file_model_needed)}..."
print(f"Loading model {get_model_name(file_model_needed)}...")
wan_model, offloadobj, trans = load_models(image2video)
yield f"Model loaded"
print(f"Model loaded")
state["_reload_needed"] = False
from PIL import Image
@@ -949,8 +1080,8 @@
gr.Info(f"You have selected attention mode '{attention_mode}'. However it is not installed on your system. You should either install it or switch to the default 'sdpa' attention.")
return
if state.get("validate_success",0) != 1:
return
#if state.get("validate_success",0) != 1:
# return
width, height = resolution.split("x")
width, height = int(width), int(height)
@@ -997,71 +1128,11 @@
global gen_in_progress
gen_in_progress = True
temp_filename = None
if len(prompt) ==0:
return
prompt, errors = prompt_parser.process_template(prompt)
if len(errors) > 0:
gr.Info(f"Error processing prompt template: " + errors)
prompts = prompt.replace("\r", "").split("\n")
prompts = [prompt.strip() for prompt in prompts if len(prompt.strip())>0 and not prompt.startswith("#")]
if len(prompts) ==0:
return
if image2video: if image2video:
if image_to_continue is not None: if video_to_continue != None and len(video_to_continue) >0 :
if isinstance(image_to_continue, list):
image_to_continue = [ tup[0] for tup in image_to_continue ]
else:
image_to_continue = [image_to_continue]
if image_to_end != None:
if isinstance(image_to_end , list):
image_to_end = [ tup[0] for tup in image_to_end ]
else:
image_to_end = [image_to_end ]
if len(image_to_continue) != len(image_to_end):
gr.Info("The number of start and end images should be the same ")
return
if multi_images_gen_type == 0:
new_prompts = []
new_image_to_continue = []
new_image_to_end = []
for i in range(len(prompts) * len(image_to_continue) ):
new_prompts.append( prompts[ i % len(prompts)] )
new_image_to_continue.append(image_to_continue[i // len(prompts)] )
if image_to_end != None:
new_image_to_end.append(image_to_end[i // len(prompts)] )
prompts = new_prompts
image_to_continue = new_image_to_continue
if image_to_end != None:
image_to_end = new_image_to_end
else:
if len(prompts) >= len(image_to_continue):
if len(prompts) % len(image_to_continue) !=0:
raise gr.Error("If there are more text prompts than input images the number of text prompts should be dividable by the number of images")
rep = len(prompts) // len(image_to_continue)
new_image_to_continue = []
new_image_to_end = []
for i, _ in enumerate(prompts):
new_image_to_continue.append(image_to_continue[i//rep] )
if image_to_end != None:
new_image_to_end.append(image_to_end[i//rep] )
image_to_continue = new_image_to_continue
if image_to_end != None:
image_to_end = new_image_to_end
else:
if len(image_to_continue) % len(prompts) !=0:
raise gr.Error("If there are more input images than text prompts the number of images should be dividable by the number of text prompts")
rep = len(image_to_continue) // len(prompts)
new_prompts = []
for i, _ in enumerate(image_to_continue):
new_prompts.append( prompts[ i//rep] )
prompts = new_prompts
elif video_to_continue != None and len(video_to_continue) >0 :
input_image_or_video_path = video_to_continue
# pipeline.num_input_frames = max_frames
# pipeline.max_frames = max_frames
else:
return
else:
input_image_or_video_path = None
@@ -1148,195 +1219,167 @@ def generate_video(
if seed == None or seed <0: if seed == None or seed <0:
seed = random.randint(0, 999999999) seed = random.randint(0, 999999999)
file_list = [] global file_list
state["file_list"] = file_list state["file_list"] = file_list
global save_path global save_path
os.makedirs(save_path, exist_ok=True) os.makedirs(save_path, exist_ok=True)
video_no = 0
total_video = repeat_generation * len(prompts)
state["total_video"] = total_video
extra_generation = 0
abort = False
start_time = time.time() with tracker_lock:
state["prompts_max"] = len(prompts) progress_tracker[task_id] = {
for no, prompt in enumerate(prompts): 'current_step': 0,
state["prompt"] = prompt 'total_steps': num_inference_steps,
repeat_no = 0 'start_time': time.time(),
state["prompt_no"] = no 'last_update': time.time()
extra_generation = 0 }
yield f"Prompt No{no}" if trans.enable_teacache:
while True: trans.teacache_counter = 0
extra_orders = state.get("extra_orders",0) trans.num_steps = num_inference_steps
state["extra_orders"] = 0 trans.teacache_skipped_steps = 0
extra_generation += extra_orders trans.previous_residual_uncond = None
state["total_generation"] = repeat_generation + extra_generation trans.previous_residual_cond = None
# total_video += (len(prompts)- no) * extra_orders callback = build_callback(task_id, num_inference_steps)
total_video += extra_orders offload.shared_state["callback"] = callback
if abort or repeat_no >= (repeat_generation + extra_generation): gc.collect()
break torch.cuda.empty_cache()
wan_model._interrupt = False
state["progress_status"] = "Starting"
try:
if image2video:
samples = wan_model.generate(
prompt,
image_to_continue.convert('RGB'),
image_to_end.convert('RGB') if image_to_end != None else None,
frame_num=(video_length // 4)* 4 + 1,
max_area=MAX_AREA_CONFIGS[resolution],
shift=flow_shift,
sampling_steps=num_inference_steps,
guide_scale=guidance_scale,
n_prompt=negative_prompt,
seed=seed,
offload_model=False,
callback=callback,
enable_RIFLEx = enable_RIFLEx,
VAE_tile_size = VAE_tile_size,
joint_pass = joint_pass,
slg_layers = slg_layers,
slg_start = slg_start/100,
slg_end = slg_end/100,
)
if trans.enable_teacache: else:
trans.teacache_counter = 0 samples = wan_model.generate(
trans.num_steps = num_inference_steps prompt,
trans.teacache_skipped_steps = 0 frame_num=(video_length // 4)* 4 + 1,
trans.previous_residual_uncond = None size=(width, height),
trans.previous_residual_cond = None shift=flow_shift,
sampling_steps=num_inference_steps,
guide_scale=guidance_scale,
n_prompt=negative_prompt,
seed=seed,
offload_model=False,
callback=callback,
enable_RIFLEx = enable_RIFLEx,
VAE_tile_size = VAE_tile_size,
joint_pass = joint_pass,
slg_layers = slg_layers,
slg_start = slg_start/100,
slg_end = slg_end/100,
)
except Exception as e:
gen_in_progress = False
if temp_filename!= None and os.path.isfile(temp_filename):
os.remove(temp_filename)
offload.last_offload_obj.unload_all()
offload.unload_loras_from_model(trans)
# if compile:
# cache_size = torch._dynamo.config.cache_size_limit
# torch.compiler.reset()
# torch._dynamo.config.cache_size_limit = cache_size
video_no += 1 gc.collect()
status = f"Video {video_no}/{total_video}" torch.cuda.empty_cache()
state["video_no"] = video_no s = str(e)
state["progress_status"] = status keyword_list = ["vram", "VRAM", "memory","allocat"]
state["progress_phase"] = (" - Encoding Prompt", -1 ) VRAM_crash= False
progress(0, desc=status + " - Encoding Prompt" ) if any( keyword in s for keyword in keyword_list):
callback = build_callback(state, trans, progress, status, num_inference_steps) VRAM_crash = True
offload.shared_state["callback"] = callback else:
stack = traceback.extract_stack(f=None, limit=5)
for frame in stack:
gc.collect() if any( keyword in frame.name for keyword in keyword_list):
torch.cuda.empty_cache()
wan_model._interrupt = False
try:
if image2video:
samples = wan_model.generate(
prompt,
image_to_continue[no].convert('RGB'),
image_to_end[no].convert('RGB') if image_to_end != None else None,
frame_num=(video_length // 4)* 4 + 1,
max_area=MAX_AREA_CONFIGS[resolution],
shift=flow_shift,
sampling_steps=num_inference_steps,
guide_scale=guidance_scale,
n_prompt=negative_prompt,
seed=seed,
offload_model=False,
callback=callback,
enable_RIFLEx = enable_RIFLEx,
VAE_tile_size = VAE_tile_size,
joint_pass = joint_pass,
slg_layers = slg_layers,
slg_start = slg_start/100,
slg_end = slg_end/100,
)
else:
samples = wan_model.generate(
prompt,
frame_num=(video_length // 4)* 4 + 1,
size=(width, height),
shift=flow_shift,
sampling_steps=num_inference_steps,
guide_scale=guidance_scale,
n_prompt=negative_prompt,
seed=seed,
offload_model=False,
callback=callback,
enable_RIFLEx = enable_RIFLEx,
VAE_tile_size = VAE_tile_size,
joint_pass = joint_pass,
slg_layers = slg_layers,
slg_start = slg_start/100,
slg_end = slg_end/100,
)
except Exception as e:
gen_in_progress = False
if temp_filename!= None and os.path.isfile(temp_filename):
os.remove(temp_filename)
offload.last_offload_obj.unload_all()
offload.unload_loras_from_model(trans)
# if compile:
# cache_size = torch._dynamo.config.cache_size_limit
# torch.compiler.reset()
# torch._dynamo.config.cache_size_limit = cache_size
gc.collect()
torch.cuda.empty_cache()
s = str(e)
keyword_list = ["vram", "VRAM", "memory","allocat"]
VRAM_crash= False
if any( keyword in s for keyword in keyword_list):
VRAM_crash = True VRAM_crash = True
else: break
stack = traceback.extract_stack(f=None, limit=5) state["prompt"] = ""
for frame in stack: if VRAM_crash:
if any( keyword in frame.name for keyword in keyword_list): raise gr.Error("The generation of the video has encountered an error: it is likely that you have unsufficient VRAM and you should therefore reduce the video resolution or its number of frames.")
VRAM_crash = True else:
break raise gr.Error(f"The generation of the video has encountered an error, please check your terminal for more information. '{s}'")
state["prompt"] = "" finally:
if VRAM_crash: with tracker_lock:
raise gr.Error("The generation of the video has encountered an error: it is likely that you have unsufficient VRAM and you should therefore reduce the video resolution or its number of frames.") if task_id in progress_tracker:
else: del progress_tracker[task_id]
raise gr.Error(f"The generation of the video has encountered an error, please check your terminal for more information. '{s}'")
if trans.enable_teacache:
print(f"Teacache Skipped Steps:{trans.teacache_skipped_steps}/{num_inference_steps}" )
trans.previous_residual_uncond = None
trans.previous_residual_cond = None
if samples != None:
samples = samples.to("cpu")
offload.last_offload_obj.unload_all()
gc.collect()
torch.cuda.empty_cache()
if samples == None:
end_time = time.time()
abort = True
state["prompt"] = ""
yield f"Video generation was aborted. Total Generation Time: {end_time-start_time:.1f}s" else:
else: sample = samples.cpu()
sample = samples.cpu() # video = rearrange(sample.cpu().numpy(), "c t h w -> t h w c")
# video = rearrange(sample.cpu().numpy(), "c t h w -> t h w c")
time_flag = datetime.fromtimestamp(time.time()).strftime("%Y-%m-%d-%Hh%Mm%Ss")
if os.name == 'nt':
file_name = f"{time_flag}_seed{seed}_{sanitize_file_name(prompt[:50]).strip()}.mp4"
else:
file_name = f"{time_flag}_seed{seed}_{sanitize_file_name(prompt[:100]).strip()}.mp4"
video_path = os.path.join(save_path, file_name)
cache_video(
tensor=sample[None],
save_file=video_path,
fps=16,
nrow=1,
normalize=True,
value_range=(-1, 1))
configs = {
'prompt': prompt,
'negative_prompt': negative_prompt,
'resolution': resolution,
'video_length': video_length,
'seed': seed,
'num_inference_steps': num_inference_steps,
}
metadata_choice = server_config["metadata_choice"]
if metadata_choice == "json":
with open(video_path.replace('.mp4', '.json'), 'w') as f:
json.dump(configs, f, indent=4)
elif metadata_choice == "metadata":
from mutagen.mp4 import MP4
file = MP4(video_path)
file.tags['©cmt'] = [json.dumps(configs)]
file.save()
print(f"New video saved to Path: "+video_path)
file_list.append(video_path)
if video_no < total_video: seed += 1
yield status
else:
end_time = time.time()
state["prompt"] = ""
yield f"Total Generation Time: {end_time-start_time:.1f}s"
seed += 1
repeat_no += 1
if temp_filename!= None and os.path.isfile(temp_filename):
os.remove(temp_filename)
gen_in_progress = False
offload.unload_loras_from_model(trans)
def get_new_preset_msg(advanced = True):
if advanced:
return "Enter here a Name for a Lora Preset or Choose one in the List"
@@ -1952,16 +1995,40 @@ def generate_video_tab(image2video=False):
show_advanced.change(fn=switch_advanced, inputs=[state, show_advanced, lset_name], outputs=[advanced_row, preset_buttons_rows, refresh_lora_btn, refresh2_row ,lset_name ]).then(
fn=switch_prompt_type, inputs = [state, wizard_prompt_activated_var, wizard_variables_var, prompt, wizard_prompt, *prompt_vars], outputs = [wizard_prompt_activated_var, wizard_variables_var, prompt, wizard_prompt, prompt_column_advanced, prompt_column_wizard, prompt_column_wizard_vars, *prompt_vars])
with gr.Column():
gen_status = gr.Text(label="Status", interactive= False)
output = gr.Gallery(
label="Generated videos", show_label=False, elem_id="gallery"
, columns=[3], rows=[1], object_fit="contain", height=450, selected_index=0, interactive= False)
generate_btn = gr.Button("Generate")
onemore_btn = gr.Button("One More Please !", visible= False)
abort_btn = gr.Button("Abort")
gen_info = gr.Text(label="Current prompt", visible= False , interactive= False)
queue_df = gr.DataFrame(
headers=["ID", "Status", "Progress", "Steps", "Time", "Prompt", "", "", ""],
datatype=["str", "str", "str", "str", "str", "str", "str", "str", "str"],
interactive=False,
col_count=(9, "fixed"),
wrap=True,
value=update_queue_data,
every=1,
elem_id="queue_df"
)
def handle_selection(evt: gr.SelectData):
cell_value = evt.value
selected_index = evt.index
if cell_value == "":
return move_up([selected_index])
elif cell_value == "":
return move_down([selected_index])
elif cell_value == "":
return remove_task([selected_index])
return queue_df
selected_indices = gr.State([])
queue_df.select(
fn=handle_selection,
outputs=selected_indices
)
queue_df.change(
fn=refresh_gallery,
inputs=[state],
outputs=[output]
)
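queue_df is refreshed by passing update_queue_data as a callable value with every=1, and clicks are routed by comparing the clicked cell's value in handle_selection. The three trailing columns look empty in this view because the action glyphs did not survive extraction; the sketch below uses placeholder arrows and a cross, so those symbols are assumptions rather than the commit's actual icons.

```python
# Hedged sketch of the same polling + dispatch idea; "↑", "↓" and "✖" stand in
# for whatever icons the commit actually places in the three action columns.
import gradio as gr

def queue_rows():
    # stand-in for update_queue_data(); one row per queued task
    return [["1", "Processing", "40.0%", "12/30", "36.2s", "a cat on a bike", "↑", "↓", "✖"]]

def on_select(evt: gr.SelectData):
    row, col = evt.index            # (row, column) of the clicked cell
    if evt.value == "↑":
        print(f"move row {row} up")
    elif evt.value == "↓":
        print(f"move row {row} down")
    elif evt.value == "✖":
        print(f"remove row {row}")
    return queue_rows()

with gr.Blocks() as demo:
    table = gr.DataFrame(value=queue_rows, every=1, interactive=False,
                         headers=["ID", "Status", "Progress", "Steps", "Time",
                                  "Prompt", "", "", ""])
    table.select(fn=on_select, outputs=table)

# demo.launch()  # uncomment to run the demo
```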
save_settings_btn.click( fn=validate_wizard_prompt, inputs =[state, wizard_prompt_activated_var, wizard_variables_var, prompt, wizard_prompt, *prompt_vars] , outputs= [prompt]).then(
save_settings, inputs = [state, prompt, image_prompt_type, video_length, resolution, num_inference_steps, seed, repeat_generation, multi_images_gen_type, guidance_scale, flow_shift, negative_prompt,
loras_choices, loras_mult_choices, tea_cache_setting, tea_cache_start_step_perc, RIFLEx_setting, slg_switch, slg_layers,
@@ -1978,48 +2045,43 @@ def generate_video_tab(image2video=False):
refresh_lora_btn.click(refresh_lora_list, inputs=[state, lset_name,loras_choices], outputs=[lset_name, loras_choices])
refresh_lora_btn2.click(refresh_lora_list, inputs=[state, lset_name,loras_choices], outputs=[lset_name, loras_choices])
download_loras_btn.click(fn=download_loras, inputs=[], outputs=[download_status_row, download_status, presets_column, loras_column]).then(fn=refresh_lora_list, inputs=[state, lset_name,loras_choices], outputs=[lset_name, loras_choices])
gen_status.change(refresh_gallery, inputs = [state, gen_info], outputs = [output, gen_info] )
abort_btn.click(abort_generation,state,abort_btn )
output.select(select_video, state, None )
onemore_btn.click(fn=one_more_video,inputs=[state], outputs= [state])
generate_btn.click(fn=prepare_generate_video,inputs=[], outputs= [generate_btn, onemore_btn]
).then(
fn=validate_wizard_prompt, inputs =[state, wizard_prompt_activated_var, wizard_variables_var, prompt, wizard_prompt, *prompt_vars] , outputs= [prompt]
).then(
fn=generate_video,
inputs=[
prompt,
negative_prompt,
resolution,
video_length,
seed,
num_inference_steps,
guidance_scale,
flow_shift,
embedded_guidance_scale,
repeat_generation,
multi_images_gen_type,
tea_cache_setting,
tea_cache_start_step_perc,
loras_choices,
loras_mult_choices,
image_to_continue,
image_to_end,
video_to_continue,
max_frames,
RIFLEx_setting,
slg_switch,
slg_layers,
slg_start_perc,
slg_end_perc,
state,
gr.State(image2video)
],
outputs= [gen_status]
).then(
finalize_gallery,
[state],
[output , abort_btn, generate_btn, onemore_btn, gen_info]
)
original_inputs = [
prompt,
negative_prompt,
resolution,
video_length,
seed,
num_inference_steps,
guidance_scale,
flow_shift,
embedded_guidance_scale,
repeat_generation,
multi_images_gen_type,
tea_cache_setting,
tea_cache_start_step_perc,
loras_choices,
loras_mult_choices,
image_to_continue,
image_to_end,
video_to_continue,
max_frames,
RIFLEx_setting,
slg_switch,
slg_layers,
slg_start_perc,
slg_end_perc,
state,
gr.State(image2video)
]
#generate_btn.click(
# fn=validate_wizard_prompt, inputs =[state, wizard_prompt_activated_var, wizard_variables_var, prompt, wizard_prompt, *prompt_vars] , outputs= [prompt]
#).then(
generate_btn.click(
fn=process_prompt_and_add_tasks,
inputs=original_inputs,
outputs=queue_df
)
return loras_choices, lset_name, header, state
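The Generate button no longer drives the long generate_video chain directly: it now calls process_prompt_and_add_tasks with the same input list, which validates the prompt, appends tasks to the queue, and returns fresh table rows, while the daemon runner does the heavy lifting. A minimal sketch of that decoupled wiring, with hypothetical names in place of the real handler and its full parameter list:

```python
# Minimal shape of the new wiring (hypothetical names; the real handler takes
# the full parameter list captured in `original_inputs` above).
import gradio as gr

tasks = []

def enqueue(prompt):
    if prompt.strip():
        tasks.append({"id": len(tasks) + 1, "prompt": prompt, "status": "Queued"})
    # return rows for the queue table; a worker thread picks tasks up later
    return [[t["id"], t["status"], t["prompt"]] for t in tasks]

with gr.Blocks() as demo:
    prompt = gr.Textbox(label="Prompt")
    table = gr.DataFrame(headers=["ID", "Status", "Prompt"], interactive=False)
    gr.Button("Generate").click(fn=enqueue, inputs=prompt, outputs=table)

# demo.launch()
```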
@@ -2226,6 +2288,59 @@ def create_demo():
margin: 0 20px;
white-space: nowrap;
}
.queue-item {
border: 1px solid #ccc;
padding: 10px;
margin: 5px 0;
border-radius: 5px;
}
.current {
background: #f8f9fa;
border-left: 4px solid #007bff;
}
.task-header {
display: flex;
justify-content: space-between;
margin-bottom: 5px;
}
.progress-container {
height: 10px;
background: #e9ecef;
border-radius: 5px;
overflow: hidden;
}
.progress-bar {
height: 100%;
background: #007bff;
transition: width 0.3s ease;
}
.task-details {
display: flex;
justify-content: space-between;
font-size: 0.9em;
color: #6c757d;
margin-top: 5px;
}
.task-prompt {
font-size: 0.8em;
color: #868e96;
margin-top: 5px;
white-space: nowrap;
overflow: hidden;
text-overflow: ellipsis;
}
#queue_df td:nth-child(7),
#queue_df td:nth-child(8),
#queue_df td:nth-child(9) {
cursor: pointer;
text-align: center;
font-weight: bold;
}
#queue_df td:nth-child(7):hover,
#queue_df td:nth-child(8):hover,
#queue_df td:nth-child(9):hover {
background-color: #e0e0e0;
}
""" """
with gr.Blocks(css=css, theme=gr.themes.Soft(primary_hue="sky", neutral_hue="slate", text_size="md")) as demo: with gr.Blocks(css=css, theme=gr.themes.Soft(primary_hue="sky", neutral_hue="slate", text_size="md")) as demo:
gr.Markdown("<div align=center><H1>Wan 2.1<SUP>GP</SUP> v3.0 <FONT SIZE=4>by <I>DeepBeepMeep</I></FONT> <FONT SIZE=3> (<A HREF='https://github.com/deepbeepmeep/Wan2GP'>Updates</A>)</FONT SIZE=3></H1></div>") gr.Markdown("<div align=center><H1>Wan 2.1<SUP>GP</SUP> v3.0 <FONT SIZE=4>by <I>DeepBeepMeep</I></FONT> <FONT SIZE=3> (<A HREF='https://github.com/deepbeepmeep/Wan2GP'>Updates</A>)</FONT SIZE=3></H1></div>")
@@ -2260,6 +2375,7 @@ def create_demo():
return demo
if __name__ == "__main__":
threading.Thread(target=runner, daemon=True).start()
os.environ["GRADIO_ANALYTICS_ENABLED"] = "False"
server_port = int(args.server_port)
if os.name == "nt":