add queues

Chris Malone 2025-03-24 17:20:53 +11:00
parent f2c5a06626
commit f7c5193eaa


@@ -1,5 +1,7 @@
import os
import time
import threading
from typing import Dict, Any
import argparse
from mmgp import offload, safetensors2, profile_type
try:
@@ -21,6 +23,7 @@ import traceback
import math
import asyncio
from wan.utils import prompt_parser
from typing import List
PROMPT_VARS_MAX = 10
target_mmgp_version = "3.3.0"
@@ -29,6 +32,194 @@ mmgp_version = version("mmgp")
if mmgp_version != target_mmgp_version:
print(f"Incorrect version of mmgp ({mmgp_version}), version {target_mmgp_version} is needed. Please upgrade with the command 'pip install -r requirements.txt'")
exit()
queue = []
lock = threading.Lock()
current_task_id = None
task_id = 0
progress_tracker = {}
tracker_lock = threading.Lock()
file_list = []
def runner():
global current_task_id
while True:
with lock:
for item in queue:
task_id = item['id']
with tracker_lock:
progress = progress_tracker.get(task_id, {})
if item['status'] == "Processing":
current_step = progress.get('current_step', 0)
total_steps = progress.get('total_steps', 0)
elapsed = time.time() - progress.get('start_time', time.time())
item.update({
'progress': f"{((current_step/total_steps)*100 if total_steps > 0 else 0):.1f}%",
'steps': f"{current_step}/{total_steps}",
'time': f"{elapsed:.1f}s"
})
if not any(item['status'] == "Processing" for item in queue):
for item in queue:
if item['status'] == "Queued":
item['status'] = "Processing"
current_task_id = item['id']
threading.Thread(target=process_task, args=(item,)).start()
break
time.sleep(1)
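The runner above is a single-consumer scheduler: it polls the shared list once per second, refreshes the progress columns from progress_tracker, and only promotes the next "Queued" entry when nothing is "Processing". A minimal self-contained sketch of the same pattern, with a hypothetical worker standing in for process_task/generate_video:

import threading
import time

tasks, tasks_lock = [], threading.Lock()

def fake_worker(item):
    time.sleep(2)                      # stand-in for the real generation work
    with tasks_lock:
        tasks.remove(item)             # finished tasks leave the list

def scheduler():
    while True:
        with tasks_lock:
            if not any(t["status"] == "Processing" for t in tasks):
                for t in tasks:
                    if t["status"] == "Queued":
                        t["status"] = "Processing"
                        threading.Thread(target=fake_worker, args=(t,), daemon=True).start()
                        break
        time.sleep(1)

threading.Thread(target=scheduler, daemon=True).start()
with tasks_lock:
    tasks.append({"id": 1, "status": "Queued"})
time.sleep(5)                          # keep the main thread alive long enough to see the task finish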
def process_prompt_and_add_tasks(
prompt,
negative_prompt,
resolution,
video_length,
seed,
num_inference_steps,
guidance_scale,
flow_shift,
embedded_guidance_scale,
repeat_generation,
multi_images_gen_type,
tea_cache,
tea_cache_start_step_perc,
loras_choices,
loras_mult_choices,
image_to_continue,
image_to_end,
video_to_continue,
max_frames,
RIFLEx_setting,
slg_switch,
slg_layers,
slg_start,
slg_end,
state_arg,
image2video
):
if len(prompt) ==0:
return
prompt, errors = prompt_parser.process_template(prompt)
if len(errors) > 0:
gr.Info("Error processing prompt template: " + errors)
return
prompts = prompt.replace("\r", "").split("\n")
prompts = [prompt.strip() for prompt in prompts if len(prompt.strip())>0 and not prompt.startswith("#")]
if len(prompts) ==0:
return
for single_prompt in prompts:
task_params = (
single_prompt,
negative_prompt,
resolution,
video_length,
seed,
num_inference_steps,
guidance_scale,
flow_shift,
embedded_guidance_scale,
repeat_generation,
multi_images_gen_type,
tea_cache,
tea_cache_start_step_perc,
loras_choices,
loras_mult_choices,
image_to_continue,
image_to_end,
video_to_continue,
max_frames,
RIFLEx_setting,
slg_switch,
slg_layers,
slg_start,
slg_end,
state_arg,
image2video
)
add_video_task(*task_params)
return update_queue_data()
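For reference, this is how the prompt box fans out into queue entries: one task per non-empty line that does not start with "#". A quick illustration of the splitting rule used above:

text = "a red fox in the snow\n# disabled idea\n\na blue whale at dusk\n"
lines = [p.strip() for p in text.replace("\r", "").split("\n")
         if len(p.strip()) > 0 and not p.startswith("#")]
print(lines)   # ['a red fox in the snow', 'a blue whale at dusk'] -> two tasks queued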
def process_task(task):
try:
task_id, *params = task['params']
generate_video(task_id, *params)
finally:
with lock:
queue[:] = [item for item in queue if item['id'] != task['id']]
with tracker_lock:
if task['id'] in progress_tracker:
del progress_tracker[task['id']]
def add_video_task(*params):
global task_id
with lock:
task_id += 1
current_task_id = task_id
queue.append({
"id": current_task_id,
"params": (current_task_id,) + params,
"status": "Queued",
"progress": "0.0%",
"steps": f"0/{params[5]}",
"time": "--",
"prompt": params[0]
})
return
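Note: in the params tuple packed by process_prompt_and_add_tasks, index 0 is the prompt and index 5 is num_inference_steps, which is what the initial "prompt" field and the "steps" field (f"0/{params[5]}") rely on.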
def move_up(selected_indices):
if not selected_indices or len(selected_indices) == 0:
return update_queue_data()
idx = selected_indices[0]
if isinstance(idx, list):
idx = idx[0]
idx = int(idx)
with lock:
if idx > 0:
queue[idx], queue[idx-1] = queue[idx-1], queue[idx]
return update_queue_data()
def move_down(selected_indices):
if not selected_indices or len(selected_indices) == 0:
return update_queue_data()
idx = selected_indices[0]
if isinstance(idx, list):
idx = idx[0]
idx = int(idx)
with lock:
if idx < len(queue)-1:
queue[idx], queue[idx+1] = queue[idx+1], queue[idx]
return update_queue_data()
def remove_task(selected_indices):
if not selected_indices or len(selected_indices) == 0:
return update_queue_data()
idx = selected_indices[0]
if isinstance(idx, list):
idx = idx[0]
idx = int(idx)
with lock:
if idx < len(queue):
if idx == 0:
wan_model._interrupt = True
del queue[idx]
return update_queue_data()
def update_queue_data():
with lock:
data = []
for item in queue:
data.append([
str(item['id']),
item['status'],
item.get('progress', "0.0%"),
item.get('steps', ''),
item.get('time', '--'),
(item['prompt'][:47] + '...') if len(item['prompt']) > 50 else item['prompt'],
"",
"",
""
])
return data
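A hypothetical, UI-free exercise of the helpers above (assuming the runner thread has not been started and generate_video is never reached): the selection indices coming from a DataFrame cell click are [row, col] lists, which move_up/move_down/remove_task unwrap before swapping or deleting entries under the lock.

queue.extend([
    {"id": 1, "status": "Queued", "prompt": "a red fox"},
    {"id": 2, "status": "Queued", "prompt": "a blue whale"},
])
move_down([[0, 0]])                              # cell click on row 0 -> swap rows 0 and 1
print([row[0] for row in update_queue_data()])   # ['2', '1']
remove_task([[1, 0]])                            # delete what is now row 1 (id 1)
print([row[0] for row in update_queue_data()])   # ['2']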
def _parse_args():
parser = argparse.ArgumentParser(
description="Generate a video from a text prompt or image using Gradio")
@@ -788,53 +979,22 @@ def save_video(final_frames, output_path, fps=24):
final_frames = (final_frames * 255).astype(np.uint8)
ImageSequenceClip(list(final_frames), fps=fps).write_videofile(output_path, verbose= False, logger = None)
def build_callback(state, pipe, progress, status, num_inference_steps):
def callback(step_idx, latents, read_state = False):
status = state["progress_status"]
if read_state:
phase, step_idx = state["progress_phase"]
else:
step_idx += 1
if state.get("abort", False):
# pipe._interrupt = True
phase = " - Aborting"
elif step_idx == num_inference_steps:
phase = " - VAE Decoding"
else:
phase = " - Denoising"
state["progress_phase"] = (phase, step_idx)
status_msg = status + phase
if step_idx >= 0:
progress( (step_idx , num_inference_steps) , status_msg , num_inference_steps)
else:
progress(0, status_msg)
def build_callback(task_id, total_steps):
start_time = time.time()
def update_progress(step, _):
with tracker_lock:
elapsed = time.time() - start_time
progress_tracker[task_id] = {
'current_step': step + 1,
'total_steps': total_steps,
'start_time': start_time,
'last_update': time.time()
}
return update_progress
return callback
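The replacement build_callback above no longer drives the Gradio progress bar directly; it only records the latest step in progress_tracker, and the runner thread turns that into the "Progress"/"Steps"/"Time" columns once a second. A standalone sketch of that hand-off, using the new build_callback(task_id, total_steps) signature and a fake four-step pipeline:

cb = build_callback(task_id=42, total_steps=4)
for step in range(4):
    cb(step, None)                     # the denoising loop calls this once per step
with tracker_lock:
    info = progress_tracker[42]
pct = (info['current_step'] / info['total_steps']) * 100 if info['total_steps'] else 0
print(f"{pct:.1f}%  {info['current_step']}/{info['total_steps']}")   # 100.0%  4/4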
def abort_generation(state):
if "in_progress" in state:
state["abort"] = True
state["extra_orders"] = 0
wan_model._interrupt= True
return gr.Button(interactive= False)
else:
return gr.Button(interactive= True)
def refresh_gallery(state, txt):
def refresh_gallery(state):
file_list = state.get("file_list", None)
prompt = state.get("prompt", "")
if len(prompt) == 0:
return file_list, gr.Text(visible= False, value="")
else:
prompts_max = state.get("prompts_max",0)
prompt_no = state.get("prompt_no",0)
if prompts_max >1 :
label = f"Current Prompt ({prompt_no+1}/{prompts_max})"
else:
label = f"Current Prompt"
return file_list, gr.Text(visible= True, value=prompt, label=label)
return file_list
def finalize_gallery(state):
choice = 0
if "in_progress" in state:
@@ -845,7 +1005,7 @@ def finalize_gallery(state):
time.sleep(0.2)
global gen_in_progress
gen_in_progress = False
return gr.Gallery(selected_index=choice), gr.Button(interactive= True), gr.Button(visible= True), gr.Checkbox(visible= False), gr.Text(visible= False, value="")
return gr.Gallery(selected_index=choice), gr.Button(interactive=True), gr.Button(visible=False), gr.Checkbox(visible=False), gr.Text(visible=False, value="")
def select_video(state , event_data: gr.EventData):
data= event_data._data
@@ -862,36 +1022,8 @@ def expand_slist(slist, num_inference_steps ):
pos += inc
return new_slist
def one_more_video(state):
extra_orders = state.get("extra_orders", 0)
extra_orders += 1
state["extra_orders"] = extra_orders
prompts_max = state.get("prompts_max",0)
if prompts_max == 0:
return state
prompt_no = state["prompt_no"]
video_no = state["video_no"]
total_video = state["total_video"]
# total_video += (prompts_max- prompt_no)
total_video += 1
total_generation = state["total_generation"] + extra_orders
state["total_video"] = total_video
state["progress_status"] = f"Video {video_no}/{total_video}"
offload.shared_state["refresh"] = 1
# if (prompts_max - prompt_no) > 1:
# gr.Info(f"An extra video generation is planned for a total of {total_generation} videos for the next {prompts_max - prompt_no} prompts")
# else:
gr.Info(f"An extra video generation is planned for a total of {total_generation} videos for this prompt")
return state
def prepare_generate_video():
return gr.Button(visible= False), gr.Checkbox(visible= True)
def generate_video(
task_id,
prompt,
negative_prompt,
resolution,
@@ -921,7 +1053,6 @@ def generate_video(
progress=gr.Progress() #track_tqdm= True
):
global wan_model, offloadobj
reload_needed = state.get("_reload_needed", False)
file_model_needed = model_needed(image2video)
@@ -930,9 +1061,9 @@
offloadobj.release()
offloadobj = None
wan_model = None
yield f"Loading model {get_model_name(file_model_needed)}..."
print(f"Loading model {get_model_name(file_model_needed)}...")
wan_model, offloadobj, trans = load_models(image2video)
yield f"Model loaded"
print(f"Model loaded")
state["_reload_needed"] = False
from PIL import Image
@@ -949,8 +1080,8 @@ def generate_video(
gr.Info(f"You have selected attention mode '{attention_mode}'. However it is not installed on your system. You should either install it or switch to the default 'sdpa' attention.")
return
if state.get("validate_success",0) != 1:
return
#if state.get("validate_success",0) != 1:
# return
width, height = resolution.split("x")
width, height = int(width), int(height)
@@ -997,71 +1128,11 @@ def generate_video(
global gen_in_progress
gen_in_progress = True
temp_filename = None
if len(prompt) ==0:
return
prompt, errors = prompt_parser.process_template(prompt)
if len(errors) > 0:
gr.Info(f"Error processing prompt template: " + errors)
prompts = prompt.replace("\r", "").split("\n")
prompts = [prompt.strip() for prompt in prompts if len(prompt.strip())>0 and not prompt.startswith("#")]
if len(prompts) ==0:
return
if image2video:
if image_to_continue is not None:
if isinstance(image_to_continue, list):
image_to_continue = [ tup[0] for tup in image_to_continue ]
else:
image_to_continue = [image_to_continue]
if image_to_end != None:
if isinstance(image_to_end , list):
image_to_end = [ tup[0] for tup in image_to_end ]
else:
image_to_end = [image_to_end ]
if len(image_to_continue) != len(image_to_end):
gr.Info("The number of start and end images should be the same ")
return
if multi_images_gen_type == 0:
new_prompts = []
new_image_to_continue = []
new_image_to_end = []
for i in range(len(prompts) * len(image_to_continue) ):
new_prompts.append( prompts[ i % len(prompts)] )
new_image_to_continue.append(image_to_continue[i // len(prompts)] )
if image_to_end != None:
new_image_to_end.append(image_to_end[i // len(prompts)] )
prompts = new_prompts
image_to_continue = new_image_to_continue
if image_to_end != None:
image_to_end = new_image_to_end
else:
if len(prompts) >= len(image_to_continue):
if len(prompts) % len(image_to_continue) !=0:
raise gr.Error("If there are more text prompts than input images the number of text prompts should be dividable by the number of images")
rep = len(prompts) // len(image_to_continue)
new_image_to_continue = []
new_image_to_end = []
for i, _ in enumerate(prompts):
new_image_to_continue.append(image_to_continue[i//rep] )
if image_to_end != None:
new_image_to_end.append(image_to_end[i//rep] )
image_to_continue = new_image_to_continue
if image_to_end != None:
image_to_end = new_image_to_end
else:
if len(image_to_continue) % len(prompts) !=0:
raise gr.Error("If there are more input images than text prompts the number of images should be dividable by the number of text prompts")
rep = len(image_to_continue) // len(prompts)
new_prompts = []
for i, _ in enumerate(image_to_continue):
new_prompts.append( prompts[ i//rep] )
prompts = new_prompts
elif video_to_continue != None and len(video_to_continue) >0 :
if video_to_continue != None and len(video_to_continue) >0 :
input_image_or_video_path = video_to_continue
# pipeline.num_input_frames = max_frames
# pipeline.max_frames = max_frames
else:
return
else:
input_image_or_video_path = None
@@ -1148,195 +1219,167 @@ def generate_video(
if seed == None or seed <0:
seed = random.randint(0, 999999999)
file_list = []
global file_list
state["file_list"] = file_list
global save_path
os.makedirs(save_path, exist_ok=True)
video_no = 0
total_video = repeat_generation * len(prompts)
state["total_video"] = total_video
extra_generation = 0
abort = False
start_time = time.time()
state["prompts_max"] = len(prompts)
for no, prompt in enumerate(prompts):
state["prompt"] = prompt
repeat_no = 0
state["prompt_no"] = no
extra_generation = 0
yield f"Prompt No{no}"
while True:
extra_orders = state.get("extra_orders",0)
state["extra_orders"] = 0
extra_generation += extra_orders
state["total_generation"] = repeat_generation + extra_generation
# total_video += (len(prompts)- no) * extra_orders
total_video += extra_orders
if abort or repeat_no >= (repeat_generation + extra_generation):
break
with tracker_lock:
progress_tracker[task_id] = {
'current_step': 0,
'total_steps': num_inference_steps,
'start_time': time.time(),
'last_update': time.time()
}
if trans.enable_teacache:
trans.teacache_counter = 0
trans.num_steps = num_inference_steps
trans.teacache_skipped_steps = 0
trans.previous_residual_uncond = None
trans.previous_residual_cond = None
callback = build_callback(task_id, num_inference_steps)
offload.shared_state["callback"] = callback
gc.collect()
torch.cuda.empty_cache()
wan_model._interrupt = False
state["progress_status"] = "Starting"
try:
if image2video:
samples = wan_model.generate(
prompt,
image_to_continue.convert('RGB'),
image_to_end.convert('RGB') if image_to_end != None else None,
frame_num=(video_length // 4)* 4 + 1,
max_area=MAX_AREA_CONFIGS[resolution],
shift=flow_shift,
sampling_steps=num_inference_steps,
guide_scale=guidance_scale,
n_prompt=negative_prompt,
seed=seed,
offload_model=False,
callback=callback,
enable_RIFLEx = enable_RIFLEx,
VAE_tile_size = VAE_tile_size,
joint_pass = joint_pass,
slg_layers = slg_layers,
slg_start = slg_start/100,
slg_end = slg_end/100,
)
if trans.enable_teacache:
trans.teacache_counter = 0
trans.num_steps = num_inference_steps
trans.teacache_skipped_steps = 0
trans.previous_residual_uncond = None
trans.previous_residual_cond = None
else:
samples = wan_model.generate(
prompt,
frame_num=(video_length // 4)* 4 + 1,
size=(width, height),
shift=flow_shift,
sampling_steps=num_inference_steps,
guide_scale=guidance_scale,
n_prompt=negative_prompt,
seed=seed,
offload_model=False,
callback=callback,
enable_RIFLEx = enable_RIFLEx,
VAE_tile_size = VAE_tile_size,
joint_pass = joint_pass,
slg_layers = slg_layers,
slg_start = slg_start/100,
slg_end = slg_end/100,
)
except Exception as e:
gen_in_progress = False
if temp_filename!= None and os.path.isfile(temp_filename):
os.remove(temp_filename)
offload.last_offload_obj.unload_all()
offload.unload_loras_from_model(trans)
# if compile:
# cache_size = torch._dynamo.config.cache_size_limit
# torch.compiler.reset()
# torch._dynamo.config.cache_size_limit = cache_size
video_no += 1
status = f"Video {video_no}/{total_video}"
state["video_no"] = video_no
state["progress_status"] = status
state["progress_phase"] = (" - Encoding Prompt", -1 )
progress(0, desc=status + " - Encoding Prompt" )
callback = build_callback(state, trans, progress, status, num_inference_steps)
offload.shared_state["callback"] = callback
gc.collect()
torch.cuda.empty_cache()
wan_model._interrupt = False
try:
if image2video:
samples = wan_model.generate(
prompt,
image_to_continue[no].convert('RGB'),
image_to_end[no].convert('RGB') if image_to_end != None else None,
frame_num=(video_length // 4)* 4 + 1,
max_area=MAX_AREA_CONFIGS[resolution],
shift=flow_shift,
sampling_steps=num_inference_steps,
guide_scale=guidance_scale,
n_prompt=negative_prompt,
seed=seed,
offload_model=False,
callback=callback,
enable_RIFLEx = enable_RIFLEx,
VAE_tile_size = VAE_tile_size,
joint_pass = joint_pass,
slg_layers = slg_layers,
slg_start = slg_start/100,
slg_end = slg_end/100,
)
else:
samples = wan_model.generate(
prompt,
frame_num=(video_length // 4)* 4 + 1,
size=(width, height),
shift=flow_shift,
sampling_steps=num_inference_steps,
guide_scale=guidance_scale,
n_prompt=negative_prompt,
seed=seed,
offload_model=False,
callback=callback,
enable_RIFLEx = enable_RIFLEx,
VAE_tile_size = VAE_tile_size,
joint_pass = joint_pass,
slg_layers = slg_layers,
slg_start = slg_start/100,
slg_end = slg_end/100,
)
except Exception as e:
gen_in_progress = False
if temp_filename!= None and os.path.isfile(temp_filename):
os.remove(temp_filename)
offload.last_offload_obj.unload_all()
offload.unload_loras_from_model(trans)
# if compile:
# cache_size = torch._dynamo.config.cache_size_limit
# torch.compiler.reset()
# torch._dynamo.config.cache_size_limit = cache_size
gc.collect()
torch.cuda.empty_cache()
s = str(e)
keyword_list = ["vram", "VRAM", "memory","allocat"]
VRAM_crash= False
if any( keyword in s for keyword in keyword_list):
gc.collect()
torch.cuda.empty_cache()
s = str(e)
keyword_list = ["vram", "VRAM", "memory","allocat"]
VRAM_crash= False
if any( keyword in s for keyword in keyword_list):
VRAM_crash = True
else:
stack = traceback.extract_stack(f=None, limit=5)
for frame in stack:
if any( keyword in frame.name for keyword in keyword_list):
VRAM_crash = True
else:
stack = traceback.extract_stack(f=None, limit=5)
for frame in stack:
if any( keyword in frame.name for keyword in keyword_list):
VRAM_crash = True
break
state["prompt"] = ""
if VRAM_crash:
raise gr.Error("The generation of the video has encountered an error: it is likely that you have unsufficient VRAM and you should therefore reduce the video resolution or its number of frames.")
else:
raise gr.Error(f"The generation of the video has encountered an error, please check your terminal for more information. '{s}'")
break
state["prompt"] = ""
if VRAM_crash:
raise gr.Error("The generation of the video has encountered an error: it is likely that you have unsufficient VRAM and you should therefore reduce the video resolution or its number of frames.")
else:
raise gr.Error(f"The generation of the video has encountered an error, please check your terminal for more information. '{s}'")
finally:
with tracker_lock:
if task_id in progress_tracker:
del progress_tracker[task_id]
if trans.enable_teacache:
print(f"Teacache Skipped Steps:{trans.teacache_skipped_steps}/{num_inference_steps}" )
trans.previous_residual_uncond = None
trans.previous_residual_cond = None
if trans.enable_teacache:
print(f"Teacache Skipped Steps:{trans.teacache_skipped_steps}/{num_inference_steps}" )
trans.previous_residual_uncond = None
trans.previous_residual_cond = None
if samples != None:
samples = samples.to("cpu")
offload.last_offload_obj.unload_all()
gc.collect()
torch.cuda.empty_cache()
if samples != None:
samples = samples.to("cpu")
offload.last_offload_obj.unload_all()
gc.collect()
torch.cuda.empty_cache()
if samples == None:
end_time = time.time()
abort = True
state["prompt"] = ""
yield f"Video generation was aborted. Total Generation Time: {end_time-start_time:.1f}s"
else:
sample = samples.cpu()
# video = rearrange(sample.cpu().numpy(), "c t h w -> t h w c")
if samples == None:
end_time = time.time()
abort = True
state["prompt"] = ""
else:
sample = samples.cpu()
# video = rearrange(sample.cpu().numpy(), "c t h w -> t h w c")
time_flag = datetime.fromtimestamp(time.time()).strftime("%Y-%m-%d-%Hh%Mm%Ss")
if os.name == 'nt':
file_name = f"{time_flag}_seed{seed}_{sanitize_file_name(prompt[:50]).strip()}.mp4"
else:
file_name = f"{time_flag}_seed{seed}_{sanitize_file_name(prompt[:100]).strip()}.mp4"
video_path = os.path.join(save_path, file_name)
cache_video(
tensor=sample[None],
save_file=video_path,
fps=16,
nrow=1,
normalize=True,
value_range=(-1, 1))
time_flag = datetime.fromtimestamp(time.time()).strftime("%Y-%m-%d-%Hh%Mm%Ss")
if os.name == 'nt':
file_name = f"{time_flag}_seed{seed}_{sanitize_file_name(prompt[:50]).strip()}.mp4"
else:
file_name = f"{time_flag}_seed{seed}_{sanitize_file_name(prompt[:100]).strip()}.mp4"
video_path = os.path.join(save_path, file_name)
cache_video(
tensor=sample[None],
save_file=video_path,
fps=16,
nrow=1,
normalize=True,
value_range=(-1, 1))
configs = {
'prompt': prompt,
'negative_prompt': negative_prompt,
'resolution': resolution,
'video_length': video_length,
'seed': seed,
'num_inference_steps': num_inference_steps,
}
configs = {
'prompt': prompt,
'negative_prompt': negative_prompt,
'resolution': resolution,
'video_length': video_length,
'seed': seed,
'num_inference_steps': num_inference_steps,
}
metadata_choice = server_config["metadata_choice"]
if metadata_choice == "json":
with open(video_path.replace('.mp4', '.json'), 'w') as f:
json.dump(configs, f, indent=4)
elif metadata_choice == "metadata":
from mutagen.mp4 import MP4
file = MP4(video_path)
file.tags['©cmt'] = [json.dumps(configs)]
file.save()
metadata_choice = server_config["metadata_choice"]
if metadata_choice == "json":
with open(video_path.replace('.mp4', '.json'), 'w') as f:
json.dump(configs, f, indent=4)
elif metadata_choice == "metadata":
from mutagen.mp4 import MP4
file = MP4(video_path)
file.tags['©cmt'] = [json.dumps(configs)]
file.save()
print(f"New video saved to Path: "+video_path)
file_list.append(video_path)
if video_no < total_video:
yield status
else:
end_time = time.time()
state["prompt"] = ""
yield f"Total Generation Time: {end_time-start_time:.1f}s"
seed += 1
repeat_no += 1
print(f"New video saved to Path: "+video_path)
file_list.append(video_path)
seed += 1
if temp_filename!= None and os.path.isfile(temp_filename):
os.remove(temp_filename)
gen_in_progress = False
offload.unload_loras_from_model(trans)
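When server_config["metadata_choice"] is "metadata", the generation settings end up in the MP4 comment atom written above. A small sketch for reading them back (assumes mutagen is installed; the file path below is hypothetical):

import json
from mutagen.mp4 import MP4

def read_embedded_configs(video_path):
    tags = MP4(video_path).tags or {}
    comments = tags.get("\xa9cmt", [])      # same atom as file.tags['©cmt'] above
    return json.loads(comments[0]) if comments else None

# read_embedded_configs("outputs/2025-03-24-17h20m53s_seed42_a_red_fox.mp4")
# -> {'prompt': ..., 'negative_prompt': ..., 'seed': 42, ...}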
def get_new_preset_msg(advanced = True):
if advanced:
return "Enter here a Name for a Lora Preset or Choose one in the List"
@@ -1952,16 +1995,40 @@ def generate_video_tab(image2video=False):
show_advanced.change(fn=switch_advanced, inputs=[state, show_advanced, lset_name], outputs=[advanced_row, preset_buttons_rows, refresh_lora_btn, refresh2_row ,lset_name ]).then(
fn=switch_prompt_type, inputs = [state, wizard_prompt_activated_var, wizard_variables_var, prompt, wizard_prompt, *prompt_vars], outputs = [wizard_prompt_activated_var, wizard_variables_var, prompt, wizard_prompt, prompt_column_advanced, prompt_column_wizard, prompt_column_wizard_vars, *prompt_vars])
with gr.Column():
gen_status = gr.Text(label="Status", interactive= False)
output = gr.Gallery(
label="Generated videos", show_label=False, elem_id="gallery"
, columns=[3], rows=[1], object_fit="contain", height=450, selected_index=0, interactive= False)
generate_btn = gr.Button("Generate")
onemore_btn = gr.Button("One More Please !", visible= False)
abort_btn = gr.Button("Abort")
gen_info = gr.Text(label="Current prompt", visible= False , interactive= False)
queue_df = gr.DataFrame(
headers=["ID", "Status", "Progress", "Steps", "Time", "Prompt", "", "", ""],
datatype=["str", "str", "str", "str", "str", "str", "str", "str", "str"],
interactive=False,
col_count=(9, "fixed"),
wrap=True,
value=update_queue_data,
every=1,
elem_id="queue_df"
)
def handle_selection(evt: gr.SelectData):
cell_value = evt.value
selected_index = evt.index
if cell_value == "":
return move_up([selected_index])
elif cell_value == "":
return move_down([selected_index])
elif cell_value == "":
return remove_task([selected_index])
return queue_df
selected_indices = gr.State([])
queue_df.select(
fn=handle_selection,
outputs=selected_indices
)
queue_df.change(
fn=refresh_gallery,
inputs=[state],
outputs=[output]
)
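The queue table relies on two Gradio behaviours: a callable value that is re-evaluated on a timer (every=1), and the select event, whose gr.SelectData carries the clicked cell's index ([row, col]) and value. A minimal standalone version of that pattern (assumes a recent Gradio release; the Blocks queue must be enabled for the polling to run):

import time
import gradio as gr

def snapshot():
    # Re-run every second while the page is open
    return [[time.strftime("%H:%M:%S"), "Queued", "0.0%"]]

with gr.Blocks() as demo:
    table = gr.DataFrame(headers=["Time", "Status", "Progress"],
                         value=snapshot, every=1, interactive=False)
    clicked = gr.Textbox(label="Last click")

    def on_select(evt: gr.SelectData):
        return f"index={evt.index} value={evt.value!r}"

    table.select(fn=on_select, outputs=clicked)

demo.queue().launch()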
save_settings_btn.click( fn=validate_wizard_prompt, inputs =[state, wizard_prompt_activated_var, wizard_variables_var, prompt, wizard_prompt, *prompt_vars] , outputs= [prompt]).then(
save_settings, inputs = [state, prompt, image_prompt_type, video_length, resolution, num_inference_steps, seed, repeat_generation, multi_images_gen_type, guidance_scale, flow_shift, negative_prompt,
loras_choices, loras_mult_choices, tea_cache_setting, tea_cache_start_step_perc, RIFLEx_setting, slg_switch, slg_layers,
@@ -1978,48 +2045,43 @@ def generate_video_tab(image2video=False):
refresh_lora_btn.click(refresh_lora_list, inputs=[state, lset_name,loras_choices], outputs=[lset_name, loras_choices])
refresh_lora_btn2.click(refresh_lora_list, inputs=[state, lset_name,loras_choices], outputs=[lset_name, loras_choices])
download_loras_btn.click(fn=download_loras, inputs=[], outputs=[download_status_row, download_status, presets_column, loras_column]).then(fn=refresh_lora_list, inputs=[state, lset_name,loras_choices], outputs=[lset_name, loras_choices])
gen_status.change(refresh_gallery, inputs = [state, gen_info], outputs = [output, gen_info] )
abort_btn.click(abort_generation,state,abort_btn )
output.select(select_video, state, None )
onemore_btn.click(fn=one_more_video,inputs=[state], outputs= [state])
generate_btn.click(fn=prepare_generate_video,inputs=[], outputs= [generate_btn, onemore_btn]
).then(
fn=validate_wizard_prompt, inputs =[state, wizard_prompt_activated_var, wizard_variables_var, prompt, wizard_prompt, *prompt_vars] , outputs= [prompt]
).then(
fn=generate_video,
inputs=[
prompt,
negative_prompt,
resolution,
video_length,
seed,
num_inference_steps,
guidance_scale,
flow_shift,
embedded_guidance_scale,
repeat_generation,
multi_images_gen_type,
tea_cache_setting,
tea_cache_start_step_perc,
loras_choices,
loras_mult_choices,
image_to_continue,
image_to_end,
video_to_continue,
max_frames,
RIFLEx_setting,
slg_switch,
slg_layers,
slg_start_perc,
slg_end_perc,
state,
gr.State(image2video)
],
outputs= [gen_status]
).then(
finalize_gallery,
[state],
[output , abort_btn, generate_btn, onemore_btn, gen_info]
original_inputs = [
prompt,
negative_prompt,
resolution,
video_length,
seed,
num_inference_steps,
guidance_scale,
flow_shift,
embedded_guidance_scale,
repeat_generation,
multi_images_gen_type,
tea_cache_setting,
tea_cache_start_step_perc,
loras_choices,
loras_mult_choices,
image_to_continue,
image_to_end,
video_to_continue,
max_frames,
RIFLEx_setting,
slg_switch,
slg_layers,
slg_start_perc,
slg_end_perc,
state,
gr.State(image2video)
]
#generate_btn.click(
# fn=validate_wizard_prompt, inputs =[state, wizard_prompt_activated_var, wizard_variables_var, prompt, wizard_prompt, *prompt_vars] , outputs= [prompt]
#).then(
generate_btn.click(
fn=process_prompt_and_add_tasks,
inputs=original_inputs,
outputs=queue_df
)
return loras_choices, lset_name, header, state
@@ -2226,6 +2288,59 @@ def create_demo():
margin: 0 20px;
white-space: nowrap;
}
.queue-item {
border: 1px solid #ccc;
padding: 10px;
margin: 5px 0;
border-radius: 5px;
}
.current {
background: #f8f9fa;
border-left: 4px solid #007bff;
}
.task-header {
display: flex;
justify-content: space-between;
margin-bottom: 5px;
}
.progress-container {
height: 10px;
background: #e9ecef;
border-radius: 5px;
overflow: hidden;
}
.progress-bar {
height: 100%;
background: #007bff;
transition: width 0.3s ease;
}
.task-details {
display: flex;
justify-content: space-between;
font-size: 0.9em;
color: #6c757d;
margin-top: 5px;
}
.task-prompt {
font-size: 0.8em;
color: #868e96;
margin-top: 5px;
white-space: nowrap;
overflow: hidden;
text-overflow: ellipsis;
}
#queue_df td:nth-child(7),
#queue_df td:nth-child(8),
#queue_df td:nth-child(9) {
cursor: pointer;
text-align: center;
font-weight: bold;
}
#queue_df td:nth-child(7):hover,
#queue_df td:nth-child(8):hover,
#queue_df td:nth-child(9):hover {
background-color: #e0e0e0;
}
"""
with gr.Blocks(css=css, theme=gr.themes.Soft(primary_hue="sky", neutral_hue="slate", text_size="md")) as demo:
gr.Markdown("<div align=center><H1>Wan 2.1<SUP>GP</SUP> v3.0 <FONT SIZE=4>by <I>DeepBeepMeep</I></FONT> <FONT SIZE=3> (<A HREF='https://github.com/deepbeepmeep/Wan2GP'>Updates</A>)</FONT SIZE=3></H1></div>")
@@ -2260,6 +2375,7 @@ def create_demo():
return demo
if __name__ == "__main__":
threading.Thread(target=runner, daemon=True).start()
os.environ["GRADIO_ANALYTICS_ENABLED"] = "False"
server_port = int(args.server_port)
if os.name == "nt":