mirror of
https://github.com/Wan-Video/Wan2.1.git
synced 2025-11-04 14:16:57 +00:00
add queues
This commit is contained in:
parent
f2c5a06626
commit
f7c5193eaa
836
gradio_server.py
836
gradio_server.py
@ -1,5 +1,7 @@
|
||||
import os
|
||||
import time
|
||||
import threading
|
||||
from typing import Dict, Any
|
||||
import argparse
|
||||
from mmgp import offload, safetensors2, profile_type
|
||||
try:
|
||||
@ -21,6 +23,7 @@ import traceback
|
||||
import math
|
||||
import asyncio
|
||||
from wan.utils import prompt_parser
|
||||
from typing import List
|
||||
PROMPT_VARS_MAX = 10
|
||||
|
||||
target_mmgp_version = "3.3.0"
|
||||
@ -29,6 +32,194 @@ mmgp_version = version("mmgp")
|
||||
if mmgp_version != target_mmgp_version:
|
||||
print(f"Incorrect version of mmgp ({mmgp_version}), version {target_mmgp_version} is needed. Please upgrade with the command 'pip install -r requirements.txt'")
|
||||
exit()
|
||||
queue = []
|
||||
lock = threading.Lock()
|
||||
current_task_id = None
|
||||
task_id = 0
|
||||
progress_tracker = {}
|
||||
tracker_lock = threading.Lock()
|
||||
file_list = []
|
||||
|
||||
def runner():
|
||||
global current_task_id
|
||||
while True:
|
||||
with lock:
|
||||
for item in queue:
|
||||
task_id = item['id']
|
||||
with tracker_lock:
|
||||
progress = progress_tracker.get(task_id, {})
|
||||
|
||||
if item['status'] == "Processing":
|
||||
current_step = progress.get('current_step', 0)
|
||||
total_steps = progress.get('total_steps', 0)
|
||||
elapsed = time.time() - progress.get('start_time', time.time())
|
||||
item.update({
|
||||
'progress': f"{((current_step/total_steps)*100 if total_steps > 0 else 0):.1f}%",
|
||||
'steps': f"{current_step}/{total_steps}",
|
||||
'time': f"{elapsed:.1f}s"
|
||||
})
|
||||
if not any(item['status'] == "Processing" for item in queue):
|
||||
for item in queue:
|
||||
if item['status'] == "Queued":
|
||||
item['status'] = "Processing"
|
||||
current_task_id = item['id']
|
||||
threading.Thread(target=process_task, args=(item,)).start()
|
||||
break
|
||||
time.sleep(1)
|
||||
|
||||
def process_prompt_and_add_tasks(
|
||||
prompt,
|
||||
negative_prompt,
|
||||
resolution,
|
||||
video_length,
|
||||
seed,
|
||||
num_inference_steps,
|
||||
guidance_scale,
|
||||
flow_shift,
|
||||
embedded_guidance_scale,
|
||||
repeat_generation,
|
||||
multi_images_gen_type,
|
||||
tea_cache,
|
||||
tea_cache_start_step_perc,
|
||||
loras_choices,
|
||||
loras_mult_choices,
|
||||
image_to_continue,
|
||||
image_to_end,
|
||||
video_to_continue,
|
||||
max_frames,
|
||||
RIFLEx_setting,
|
||||
slg_switch,
|
||||
slg_layers,
|
||||
slg_start,
|
||||
slg_end,
|
||||
state_arg,
|
||||
image2video
|
||||
):
|
||||
if len(prompt) ==0:
|
||||
return
|
||||
prompt, errors = prompt_parser.process_template(prompt)
|
||||
if len(errors) > 0:
|
||||
gr.Info("Error processing prompt template: " + errors)
|
||||
return
|
||||
prompts = prompt.replace("\r", "").split("\n")
|
||||
prompts = [prompt.strip() for prompt in prompts if len(prompt.strip())>0 and not prompt.startswith("#")]
|
||||
if len(prompts) ==0:
|
||||
return
|
||||
|
||||
for single_prompt in prompts:
|
||||
task_params = (
|
||||
single_prompt,
|
||||
negative_prompt,
|
||||
resolution,
|
||||
video_length,
|
||||
seed,
|
||||
num_inference_steps,
|
||||
guidance_scale,
|
||||
flow_shift,
|
||||
embedded_guidance_scale,
|
||||
repeat_generation,
|
||||
multi_images_gen_type,
|
||||
tea_cache,
|
||||
tea_cache_start_step_perc,
|
||||
loras_choices,
|
||||
loras_mult_choices,
|
||||
image_to_continue,
|
||||
image_to_end,
|
||||
video_to_continue,
|
||||
max_frames,
|
||||
RIFLEx_setting,
|
||||
slg_switch,
|
||||
slg_layers,
|
||||
slg_start,
|
||||
slg_end,
|
||||
state_arg,
|
||||
image2video
|
||||
)
|
||||
add_video_task(*task_params)
|
||||
return update_queue_data()
|
||||
|
||||
def process_task(task):
|
||||
try:
|
||||
task_id, *params = task['params']
|
||||
generate_video(task_id, *params)
|
||||
finally:
|
||||
with lock:
|
||||
queue[:] = [item for item in queue if item['id'] != task['id']]
|
||||
with tracker_lock:
|
||||
if task['id'] in progress_tracker:
|
||||
del progress_tracker[task['id']]
|
||||
|
||||
def add_video_task(*params):
|
||||
global task_id
|
||||
with lock:
|
||||
task_id += 1
|
||||
current_task_id = task_id
|
||||
queue.append({
|
||||
"id": current_task_id,
|
||||
"params": (current_task_id,) + params,
|
||||
"status": "Queued",
|
||||
"progress": "0.0%",
|
||||
"steps": f"0/{params[5]}",
|
||||
"time": "--",
|
||||
"prompt": params[0]
|
||||
})
|
||||
return
|
||||
|
||||
def move_up(selected_indices):
|
||||
if not selected_indices or len(selected_indices) == 0:
|
||||
return update_queue_data()
|
||||
idx = selected_indices[0]
|
||||
if isinstance(idx, list):
|
||||
idx = idx[0]
|
||||
idx = int(idx)
|
||||
with lock:
|
||||
if idx > 0:
|
||||
queue[idx], queue[idx-1] = queue[idx-1], queue[idx]
|
||||
return update_queue_data()
|
||||
|
||||
def move_down(selected_indices):
|
||||
if not selected_indices or len(selected_indices) == 0:
|
||||
return update_queue_data()
|
||||
idx = selected_indices[0]
|
||||
if isinstance(idx, list):
|
||||
idx = idx[0]
|
||||
idx = int(idx)
|
||||
with lock:
|
||||
if idx < len(queue)-1:
|
||||
queue[idx], queue[idx+1] = queue[idx+1], queue[idx]
|
||||
return update_queue_data()
|
||||
|
||||
def remove_task(selected_indices):
|
||||
if not selected_indices or len(selected_indices) == 0:
|
||||
return update_queue_data()
|
||||
idx = selected_indices[0]
|
||||
if isinstance(idx, list):
|
||||
idx = idx[0]
|
||||
idx = int(idx)
|
||||
with lock:
|
||||
if idx < len(queue):
|
||||
if idx == 0:
|
||||
wan_model._interrupt = True
|
||||
del queue[idx]
|
||||
return update_queue_data()
|
||||
|
||||
def update_queue_data():
|
||||
with lock:
|
||||
data = []
|
||||
for item in queue:
|
||||
data.append([
|
||||
str(item['id']),
|
||||
item['status'],
|
||||
item.get('progress', "0.0%"),
|
||||
item.get('steps', ''),
|
||||
item.get('time', '--'),
|
||||
(item['prompt'][:47] + '...') if len(item['prompt']) > 50 else item['prompt'],
|
||||
"↑",
|
||||
"↓",
|
||||
"✖"
|
||||
])
|
||||
return data
|
||||
|
||||
def _parse_args():
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Generate a video from a text prompt or image using Gradio")
|
||||
@ -788,53 +979,22 @@ def save_video(final_frames, output_path, fps=24):
|
||||
final_frames = (final_frames * 255).astype(np.uint8)
|
||||
ImageSequenceClip(list(final_frames), fps=fps).write_videofile(output_path, verbose= False, logger = None)
|
||||
|
||||
def build_callback(state, pipe, progress, status, num_inference_steps):
|
||||
def callback(step_idx, latents, read_state = False):
|
||||
status = state["progress_status"]
|
||||
if read_state:
|
||||
phase, step_idx = state["progress_phase"]
|
||||
else:
|
||||
step_idx += 1
|
||||
if state.get("abort", False):
|
||||
# pipe._interrupt = True
|
||||
phase = " - Aborting"
|
||||
elif step_idx == num_inference_steps:
|
||||
phase = " - VAE Decoding"
|
||||
else:
|
||||
phase = " - Denoising"
|
||||
state["progress_phase"] = (phase, step_idx)
|
||||
status_msg = status + phase
|
||||
if step_idx >= 0:
|
||||
progress( (step_idx , num_inference_steps) , status_msg , num_inference_steps)
|
||||
else:
|
||||
progress(0, status_msg)
|
||||
|
||||
return callback
|
||||
def build_callback(task_id, total_steps):
|
||||
start_time = time.time()
|
||||
def update_progress(step, _):
|
||||
with tracker_lock:
|
||||
elapsed = time.time() - start_time
|
||||
progress_tracker[task_id] = {
|
||||
'current_step': step + 1,
|
||||
'total_steps': total_steps,
|
||||
'start_time': start_time,
|
||||
'last_update': time.time()
|
||||
}
|
||||
return update_progress
|
||||
|
||||
def abort_generation(state):
|
||||
if "in_progress" in state:
|
||||
state["abort"] = True
|
||||
state["extra_orders"] = 0
|
||||
wan_model._interrupt= True
|
||||
return gr.Button(interactive= False)
|
||||
else:
|
||||
return gr.Button(interactive= True)
|
||||
|
||||
def refresh_gallery(state, txt):
|
||||
file_list = state.get("file_list", None)
|
||||
prompt = state.get("prompt", "")
|
||||
if len(prompt) == 0:
|
||||
return file_list, gr.Text(visible= False, value="")
|
||||
else:
|
||||
prompts_max = state.get("prompts_max",0)
|
||||
prompt_no = state.get("prompt_no",0)
|
||||
if prompts_max >1 :
|
||||
label = f"Current Prompt ({prompt_no+1}/{prompts_max})"
|
||||
else:
|
||||
label = f"Current Prompt"
|
||||
return file_list, gr.Text(visible= True, value=prompt, label=label)
|
||||
|
||||
|
||||
def refresh_gallery(state):
|
||||
file_list = state.get("file_list", None)
|
||||
return file_list
|
||||
def finalize_gallery(state):
|
||||
choice = 0
|
||||
if "in_progress" in state:
|
||||
@ -845,7 +1005,7 @@ def finalize_gallery(state):
|
||||
time.sleep(0.2)
|
||||
global gen_in_progress
|
||||
gen_in_progress = False
|
||||
return gr.Gallery(selected_index=choice), gr.Button(interactive= True), gr.Button(visible= True), gr.Checkbox(visible= False), gr.Text(visible= False, value="")
|
||||
return gr.Gallery(selected_index=choice), gr.Button(interactive=True), gr.Button(visible=False), gr.Checkbox(visible=False), gr.Text(visible=False, value="")
|
||||
|
||||
def select_video(state , event_data: gr.EventData):
|
||||
data= event_data._data
|
||||
@ -862,36 +1022,8 @@ def expand_slist(slist, num_inference_steps ):
|
||||
pos += inc
|
||||
return new_slist
|
||||
|
||||
|
||||
def one_more_video(state):
|
||||
extra_orders = state.get("extra_orders", 0)
|
||||
extra_orders += 1
|
||||
state["extra_orders"] = extra_orders
|
||||
prompts_max = state.get("prompts_max",0)
|
||||
if prompts_max == 0:
|
||||
return state
|
||||
prompt_no = state["prompt_no"]
|
||||
video_no = state["video_no"]
|
||||
total_video = state["total_video"]
|
||||
# total_video += (prompts_max- prompt_no)
|
||||
total_video += 1
|
||||
total_generation = state["total_generation"] + extra_orders
|
||||
state["total_video"] = total_video
|
||||
|
||||
state["progress_status"] = f"Video {video_no}/{total_video}"
|
||||
offload.shared_state["refresh"] = 1
|
||||
# if (prompts_max - prompt_no) > 1:
|
||||
# gr.Info(f"An extra video generation is planned for a total of {total_generation} videos for the next {prompts_max - prompt_no} prompts")
|
||||
# else:
|
||||
gr.Info(f"An extra video generation is planned for a total of {total_generation} videos for this prompt")
|
||||
|
||||
return state
|
||||
|
||||
def prepare_generate_video():
|
||||
|
||||
return gr.Button(visible= False), gr.Checkbox(visible= True)
|
||||
|
||||
def generate_video(
|
||||
task_id,
|
||||
prompt,
|
||||
negative_prompt,
|
||||
resolution,
|
||||
@ -921,7 +1053,6 @@ def generate_video(
|
||||
progress=gr.Progress() #track_tqdm= True
|
||||
|
||||
):
|
||||
|
||||
global wan_model, offloadobj
|
||||
reload_needed = state.get("_reload_needed", False)
|
||||
file_model_needed = model_needed(image2video)
|
||||
@ -930,9 +1061,9 @@ def generate_video(
|
||||
offloadobj.release()
|
||||
offloadobj = None
|
||||
wan_model = None
|
||||
yield f"Loading model {get_model_name(file_model_needed)}..."
|
||||
print(f"Loading model {get_model_name(file_model_needed)}...")
|
||||
wan_model, offloadobj, trans = load_models(image2video)
|
||||
yield f"Model loaded"
|
||||
print(f"Model loaded")
|
||||
state["_reload_needed"] = False
|
||||
|
||||
from PIL import Image
|
||||
@ -949,8 +1080,8 @@ def generate_video(
|
||||
gr.Info(f"You have selected attention mode '{attention_mode}'. However it is not installed on your system. You should either install it or switch to the default 'sdpa' attention.")
|
||||
return
|
||||
|
||||
if state.get("validate_success",0) != 1:
|
||||
return
|
||||
#if state.get("validate_success",0) != 1:
|
||||
# return
|
||||
width, height = resolution.split("x")
|
||||
width, height = int(width), int(height)
|
||||
|
||||
@ -997,71 +1128,11 @@ def generate_video(
|
||||
global gen_in_progress
|
||||
gen_in_progress = True
|
||||
temp_filename = None
|
||||
if len(prompt) ==0:
|
||||
return
|
||||
prompt, errors = prompt_parser.process_template(prompt)
|
||||
if len(errors) > 0:
|
||||
gr.Info(f"Error processing prompt template: " + errors)
|
||||
prompts = prompt.replace("\r", "").split("\n")
|
||||
prompts = [prompt.strip() for prompt in prompts if len(prompt.strip())>0 and not prompt.startswith("#")]
|
||||
if len(prompts) ==0:
|
||||
return
|
||||
if image2video:
|
||||
if image_to_continue is not None:
|
||||
if isinstance(image_to_continue, list):
|
||||
image_to_continue = [ tup[0] for tup in image_to_continue ]
|
||||
else:
|
||||
image_to_continue = [image_to_continue]
|
||||
if image_to_end != None:
|
||||
if isinstance(image_to_end , list):
|
||||
image_to_end = [ tup[0] for tup in image_to_end ]
|
||||
else:
|
||||
image_to_end = [image_to_end ]
|
||||
if len(image_to_continue) != len(image_to_end):
|
||||
gr.Info("The number of start and end images should be the same ")
|
||||
return
|
||||
if multi_images_gen_type == 0:
|
||||
new_prompts = []
|
||||
new_image_to_continue = []
|
||||
new_image_to_end = []
|
||||
for i in range(len(prompts) * len(image_to_continue) ):
|
||||
new_prompts.append( prompts[ i % len(prompts)] )
|
||||
new_image_to_continue.append(image_to_continue[i // len(prompts)] )
|
||||
if image_to_end != None:
|
||||
new_image_to_end.append(image_to_end[i // len(prompts)] )
|
||||
prompts = new_prompts
|
||||
image_to_continue = new_image_to_continue
|
||||
if image_to_end != None:
|
||||
image_to_end = new_image_to_end
|
||||
else:
|
||||
if len(prompts) >= len(image_to_continue):
|
||||
if len(prompts) % len(image_to_continue) !=0:
|
||||
raise gr.Error("If there are more text prompts than input images the number of text prompts should be dividable by the number of images")
|
||||
rep = len(prompts) // len(image_to_continue)
|
||||
new_image_to_continue = []
|
||||
new_image_to_end = []
|
||||
for i, _ in enumerate(prompts):
|
||||
new_image_to_continue.append(image_to_continue[i//rep] )
|
||||
if image_to_end != None:
|
||||
new_image_to_end.append(image_to_end[i//rep] )
|
||||
image_to_continue = new_image_to_continue
|
||||
if image_to_end != None:
|
||||
image_to_end = new_image_to_end
|
||||
else:
|
||||
if len(image_to_continue) % len(prompts) !=0:
|
||||
raise gr.Error("If there are more input images than text prompts the number of images should be dividable by the number of text prompts")
|
||||
rep = len(image_to_continue) // len(prompts)
|
||||
new_prompts = []
|
||||
for i, _ in enumerate(image_to_continue):
|
||||
new_prompts.append( prompts[ i//rep] )
|
||||
prompts = new_prompts
|
||||
|
||||
elif video_to_continue != None and len(video_to_continue) >0 :
|
||||
if video_to_continue != None and len(video_to_continue) >0 :
|
||||
input_image_or_video_path = video_to_continue
|
||||
# pipeline.num_input_frames = max_frames
|
||||
# pipeline.max_frames = max_frames
|
||||
else:
|
||||
return
|
||||
else:
|
||||
input_image_or_video_path = None
|
||||
|
||||
@ -1148,195 +1219,167 @@ def generate_video(
|
||||
if seed == None or seed <0:
|
||||
seed = random.randint(0, 999999999)
|
||||
|
||||
file_list = []
|
||||
state["file_list"] = file_list
|
||||
global file_list
|
||||
state["file_list"] = file_list
|
||||
global save_path
|
||||
os.makedirs(save_path, exist_ok=True)
|
||||
video_no = 0
|
||||
total_video = repeat_generation * len(prompts)
|
||||
state["total_video"] = total_video
|
||||
extra_generation = 0
|
||||
abort = False
|
||||
start_time = time.time()
|
||||
state["prompts_max"] = len(prompts)
|
||||
for no, prompt in enumerate(prompts):
|
||||
state["prompt"] = prompt
|
||||
repeat_no = 0
|
||||
state["prompt_no"] = no
|
||||
extra_generation = 0
|
||||
yield f"Prompt No{no}"
|
||||
while True:
|
||||
extra_orders = state.get("extra_orders",0)
|
||||
state["extra_orders"] = 0
|
||||
extra_generation += extra_orders
|
||||
state["total_generation"] = repeat_generation + extra_generation
|
||||
# total_video += (len(prompts)- no) * extra_orders
|
||||
total_video += extra_orders
|
||||
if abort or repeat_no >= (repeat_generation + extra_generation):
|
||||
break
|
||||
with tracker_lock:
|
||||
progress_tracker[task_id] = {
|
||||
'current_step': 0,
|
||||
'total_steps': num_inference_steps,
|
||||
'start_time': time.time(),
|
||||
'last_update': time.time()
|
||||
}
|
||||
if trans.enable_teacache:
|
||||
trans.teacache_counter = 0
|
||||
trans.num_steps = num_inference_steps
|
||||
trans.teacache_skipped_steps = 0
|
||||
trans.previous_residual_uncond = None
|
||||
trans.previous_residual_cond = None
|
||||
callback = build_callback(task_id, num_inference_steps)
|
||||
offload.shared_state["callback"] = callback
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
wan_model._interrupt = False
|
||||
state["progress_status"] = "Starting"
|
||||
try:
|
||||
if image2video:
|
||||
samples = wan_model.generate(
|
||||
prompt,
|
||||
image_to_continue.convert('RGB'),
|
||||
image_to_end.convert('RGB') if image_to_end != None else None,
|
||||
frame_num=(video_length // 4)* 4 + 1,
|
||||
max_area=MAX_AREA_CONFIGS[resolution],
|
||||
shift=flow_shift,
|
||||
sampling_steps=num_inference_steps,
|
||||
guide_scale=guidance_scale,
|
||||
n_prompt=negative_prompt,
|
||||
seed=seed,
|
||||
offload_model=False,
|
||||
callback=callback,
|
||||
enable_RIFLEx = enable_RIFLEx,
|
||||
VAE_tile_size = VAE_tile_size,
|
||||
joint_pass = joint_pass,
|
||||
slg_layers = slg_layers,
|
||||
slg_start = slg_start/100,
|
||||
slg_end = slg_end/100,
|
||||
)
|
||||
|
||||
if trans.enable_teacache:
|
||||
trans.teacache_counter = 0
|
||||
trans.num_steps = num_inference_steps
|
||||
trans.teacache_skipped_steps = 0
|
||||
trans.previous_residual_uncond = None
|
||||
trans.previous_residual_cond = None
|
||||
else:
|
||||
samples = wan_model.generate(
|
||||
prompt,
|
||||
frame_num=(video_length // 4)* 4 + 1,
|
||||
size=(width, height),
|
||||
shift=flow_shift,
|
||||
sampling_steps=num_inference_steps,
|
||||
guide_scale=guidance_scale,
|
||||
n_prompt=negative_prompt,
|
||||
seed=seed,
|
||||
offload_model=False,
|
||||
callback=callback,
|
||||
enable_RIFLEx = enable_RIFLEx,
|
||||
VAE_tile_size = VAE_tile_size,
|
||||
joint_pass = joint_pass,
|
||||
slg_layers = slg_layers,
|
||||
slg_start = slg_start/100,
|
||||
slg_end = slg_end/100,
|
||||
)
|
||||
except Exception as e:
|
||||
gen_in_progress = False
|
||||
if temp_filename!= None and os.path.isfile(temp_filename):
|
||||
os.remove(temp_filename)
|
||||
offload.last_offload_obj.unload_all()
|
||||
offload.unload_loras_from_model(trans)
|
||||
# if compile:
|
||||
# cache_size = torch._dynamo.config.cache_size_limit
|
||||
# torch.compiler.reset()
|
||||
# torch._dynamo.config.cache_size_limit = cache_size
|
||||
|
||||
video_no += 1
|
||||
status = f"Video {video_no}/{total_video}"
|
||||
state["video_no"] = video_no
|
||||
state["progress_status"] = status
|
||||
state["progress_phase"] = (" - Encoding Prompt", -1 )
|
||||
progress(0, desc=status + " - Encoding Prompt" )
|
||||
callback = build_callback(state, trans, progress, status, num_inference_steps)
|
||||
offload.shared_state["callback"] = callback
|
||||
|
||||
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
wan_model._interrupt = False
|
||||
try:
|
||||
if image2video:
|
||||
samples = wan_model.generate(
|
||||
prompt,
|
||||
image_to_continue[no].convert('RGB'),
|
||||
image_to_end[no].convert('RGB') if image_to_end != None else None,
|
||||
frame_num=(video_length // 4)* 4 + 1,
|
||||
max_area=MAX_AREA_CONFIGS[resolution],
|
||||
shift=flow_shift,
|
||||
sampling_steps=num_inference_steps,
|
||||
guide_scale=guidance_scale,
|
||||
n_prompt=negative_prompt,
|
||||
seed=seed,
|
||||
offload_model=False,
|
||||
callback=callback,
|
||||
enable_RIFLEx = enable_RIFLEx,
|
||||
VAE_tile_size = VAE_tile_size,
|
||||
joint_pass = joint_pass,
|
||||
slg_layers = slg_layers,
|
||||
slg_start = slg_start/100,
|
||||
slg_end = slg_end/100,
|
||||
)
|
||||
|
||||
else:
|
||||
samples = wan_model.generate(
|
||||
prompt,
|
||||
frame_num=(video_length // 4)* 4 + 1,
|
||||
size=(width, height),
|
||||
shift=flow_shift,
|
||||
sampling_steps=num_inference_steps,
|
||||
guide_scale=guidance_scale,
|
||||
n_prompt=negative_prompt,
|
||||
seed=seed,
|
||||
offload_model=False,
|
||||
callback=callback,
|
||||
enable_RIFLEx = enable_RIFLEx,
|
||||
VAE_tile_size = VAE_tile_size,
|
||||
joint_pass = joint_pass,
|
||||
slg_layers = slg_layers,
|
||||
slg_start = slg_start/100,
|
||||
slg_end = slg_end/100,
|
||||
)
|
||||
except Exception as e:
|
||||
gen_in_progress = False
|
||||
if temp_filename!= None and os.path.isfile(temp_filename):
|
||||
os.remove(temp_filename)
|
||||
offload.last_offload_obj.unload_all()
|
||||
offload.unload_loras_from_model(trans)
|
||||
# if compile:
|
||||
# cache_size = torch._dynamo.config.cache_size_limit
|
||||
# torch.compiler.reset()
|
||||
# torch._dynamo.config.cache_size_limit = cache_size
|
||||
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
s = str(e)
|
||||
keyword_list = ["vram", "VRAM", "memory","allocat"]
|
||||
VRAM_crash= False
|
||||
if any( keyword in s for keyword in keyword_list):
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
s = str(e)
|
||||
keyword_list = ["vram", "VRAM", "memory","allocat"]
|
||||
VRAM_crash= False
|
||||
if any( keyword in s for keyword in keyword_list):
|
||||
VRAM_crash = True
|
||||
else:
|
||||
stack = traceback.extract_stack(f=None, limit=5)
|
||||
for frame in stack:
|
||||
if any( keyword in frame.name for keyword in keyword_list):
|
||||
VRAM_crash = True
|
||||
else:
|
||||
stack = traceback.extract_stack(f=None, limit=5)
|
||||
for frame in stack:
|
||||
if any( keyword in frame.name for keyword in keyword_list):
|
||||
VRAM_crash = True
|
||||
break
|
||||
state["prompt"] = ""
|
||||
if VRAM_crash:
|
||||
raise gr.Error("The generation of the video has encountered an error: it is likely that you have unsufficient VRAM and you should therefore reduce the video resolution or its number of frames.")
|
||||
else:
|
||||
raise gr.Error(f"The generation of the video has encountered an error, please check your terminal for more information. '{s}'")
|
||||
break
|
||||
state["prompt"] = ""
|
||||
if VRAM_crash:
|
||||
raise gr.Error("The generation of the video has encountered an error: it is likely that you have unsufficient VRAM and you should therefore reduce the video resolution or its number of frames.")
|
||||
else:
|
||||
raise gr.Error(f"The generation of the video has encountered an error, please check your terminal for more information. '{s}'")
|
||||
finally:
|
||||
with tracker_lock:
|
||||
if task_id in progress_tracker:
|
||||
del progress_tracker[task_id]
|
||||
|
||||
if trans.enable_teacache:
|
||||
print(f"Teacache Skipped Steps:{trans.teacache_skipped_steps}/{num_inference_steps}" )
|
||||
trans.previous_residual_uncond = None
|
||||
trans.previous_residual_cond = None
|
||||
if trans.enable_teacache:
|
||||
print(f"Teacache Skipped Steps:{trans.teacache_skipped_steps}/{num_inference_steps}" )
|
||||
trans.previous_residual_uncond = None
|
||||
trans.previous_residual_cond = None
|
||||
|
||||
if samples != None:
|
||||
samples = samples.to("cpu")
|
||||
offload.last_offload_obj.unload_all()
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
if samples != None:
|
||||
samples = samples.to("cpu")
|
||||
offload.last_offload_obj.unload_all()
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
if samples == None:
|
||||
end_time = time.time()
|
||||
abort = True
|
||||
state["prompt"] = ""
|
||||
yield f"Video generation was aborted. Total Generation Time: {end_time-start_time:.1f}s"
|
||||
else:
|
||||
sample = samples.cpu()
|
||||
# video = rearrange(sample.cpu().numpy(), "c t h w -> t h w c")
|
||||
if samples == None:
|
||||
end_time = time.time()
|
||||
abort = True
|
||||
state["prompt"] = ""
|
||||
else:
|
||||
sample = samples.cpu()
|
||||
# video = rearrange(sample.cpu().numpy(), "c t h w -> t h w c")
|
||||
|
||||
time_flag = datetime.fromtimestamp(time.time()).strftime("%Y-%m-%d-%Hh%Mm%Ss")
|
||||
if os.name == 'nt':
|
||||
file_name = f"{time_flag}_seed{seed}_{sanitize_file_name(prompt[:50]).strip()}.mp4"
|
||||
else:
|
||||
file_name = f"{time_flag}_seed{seed}_{sanitize_file_name(prompt[:100]).strip()}.mp4"
|
||||
video_path = os.path.join(save_path, file_name)
|
||||
cache_video(
|
||||
tensor=sample[None],
|
||||
save_file=video_path,
|
||||
fps=16,
|
||||
nrow=1,
|
||||
normalize=True,
|
||||
value_range=(-1, 1))
|
||||
|
||||
configs = {
|
||||
'prompt': prompt,
|
||||
'negative_prompt': negative_prompt,
|
||||
'resolution': resolution,
|
||||
'video_length': video_length,
|
||||
'seed': seed,
|
||||
'num_inference_steps': num_inference_steps,
|
||||
}
|
||||
time_flag = datetime.fromtimestamp(time.time()).strftime("%Y-%m-%d-%Hh%Mm%Ss")
|
||||
if os.name == 'nt':
|
||||
file_name = f"{time_flag}_seed{seed}_{sanitize_file_name(prompt[:50]).strip()}.mp4"
|
||||
else:
|
||||
file_name = f"{time_flag}_seed{seed}_{sanitize_file_name(prompt[:100]).strip()}.mp4"
|
||||
video_path = os.path.join(save_path, file_name)
|
||||
cache_video(
|
||||
tensor=sample[None],
|
||||
save_file=video_path,
|
||||
fps=16,
|
||||
nrow=1,
|
||||
normalize=True,
|
||||
value_range=(-1, 1))
|
||||
|
||||
configs = {
|
||||
'prompt': prompt,
|
||||
'negative_prompt': negative_prompt,
|
||||
'resolution': resolution,
|
||||
'video_length': video_length,
|
||||
'seed': seed,
|
||||
'num_inference_steps': num_inference_steps,
|
||||
}
|
||||
|
||||
metadata_choice = server_config["metadata_choice"]
|
||||
if metadata_choice == "json":
|
||||
with open(video_path.replace('.mp4', '.json'), 'w') as f:
|
||||
json.dump(configs, f, indent=4)
|
||||
elif metadata_choice == "metadata":
|
||||
from mutagen.mp4 import MP4
|
||||
file = MP4(video_path)
|
||||
file.tags['©cmt'] = [json.dumps(configs)]
|
||||
file.save()
|
||||
metadata_choice = server_config["metadata_choice"]
|
||||
if metadata_choice == "json":
|
||||
with open(video_path.replace('.mp4', '.json'), 'w') as f:
|
||||
json.dump(configs, f, indent=4)
|
||||
elif metadata_choice == "metadata":
|
||||
from mutagen.mp4 import MP4
|
||||
file = MP4(video_path)
|
||||
file.tags['©cmt'] = [json.dumps(configs)]
|
||||
file.save()
|
||||
|
||||
print(f"New video saved to Path: "+video_path)
|
||||
file_list.append(video_path)
|
||||
if video_no < total_video:
|
||||
yield status
|
||||
else:
|
||||
end_time = time.time()
|
||||
state["prompt"] = ""
|
||||
yield f"Total Generation Time: {end_time-start_time:.1f}s"
|
||||
seed += 1
|
||||
repeat_no += 1
|
||||
print(f"New video saved to Path: "+video_path)
|
||||
file_list.append(video_path)
|
||||
seed += 1
|
||||
|
||||
if temp_filename!= None and os.path.isfile(temp_filename):
|
||||
os.remove(temp_filename)
|
||||
gen_in_progress = False
|
||||
offload.unload_loras_from_model(trans)
|
||||
|
||||
|
||||
def get_new_preset_msg(advanced = True):
|
||||
if advanced:
|
||||
return "Enter here a Name for a Lora Preset or Choose one in the List"
|
||||
@ -1952,16 +1995,40 @@ def generate_video_tab(image2video=False):
|
||||
show_advanced.change(fn=switch_advanced, inputs=[state, show_advanced, lset_name], outputs=[advanced_row, preset_buttons_rows, refresh_lora_btn, refresh2_row ,lset_name ]).then(
|
||||
fn=switch_prompt_type, inputs = [state, wizard_prompt_activated_var, wizard_variables_var, prompt, wizard_prompt, *prompt_vars], outputs = [wizard_prompt_activated_var, wizard_variables_var, prompt, wizard_prompt, prompt_column_advanced, prompt_column_wizard, prompt_column_wizard_vars, *prompt_vars])
|
||||
with gr.Column():
|
||||
gen_status = gr.Text(label="Status", interactive= False)
|
||||
output = gr.Gallery(
|
||||
label="Generated videos", show_label=False, elem_id="gallery"
|
||||
, columns=[3], rows=[1], object_fit="contain", height=450, selected_index=0, interactive= False)
|
||||
generate_btn = gr.Button("Generate")
|
||||
onemore_btn = gr.Button("One More Please !", visible= False)
|
||||
abort_btn = gr.Button("Abort")
|
||||
gen_info = gr.Text(label="Current prompt", visible= False , interactive= False)
|
||||
|
||||
|
||||
queue_df = gr.DataFrame(
|
||||
headers=["ID", "Status", "Progress", "Steps", "Time", "Prompt", "", "", ""],
|
||||
datatype=["str", "str", "str", "str", "str", "str", "str", "str", "str"],
|
||||
interactive=False,
|
||||
col_count=(9, "fixed"),
|
||||
wrap=True,
|
||||
value=update_queue_data,
|
||||
every=1,
|
||||
elem_id="queue_df"
|
||||
)
|
||||
def handle_selection(evt: gr.SelectData):
|
||||
cell_value = evt.value
|
||||
selected_index = evt.index
|
||||
if cell_value == "↑":
|
||||
return move_up([selected_index])
|
||||
elif cell_value == "↓":
|
||||
return move_down([selected_index])
|
||||
elif cell_value == "✖":
|
||||
return remove_task([selected_index])
|
||||
return queue_df
|
||||
selected_indices = gr.State([])
|
||||
queue_df.select(
|
||||
fn=handle_selection,
|
||||
outputs=selected_indices
|
||||
)
|
||||
queue_df.change(
|
||||
fn=refresh_gallery,
|
||||
inputs=[state],
|
||||
outputs=[output]
|
||||
)
|
||||
save_settings_btn.click( fn=validate_wizard_prompt, inputs =[state, wizard_prompt_activated_var, wizard_variables_var, prompt, wizard_prompt, *prompt_vars] , outputs= [prompt]).then(
|
||||
save_settings, inputs = [state, prompt, image_prompt_type, video_length, resolution, num_inference_steps, seed, repeat_generation, multi_images_gen_type, guidance_scale, flow_shift, negative_prompt,
|
||||
loras_choices, loras_mult_choices, tea_cache_setting, tea_cache_start_step_perc, RIFLEx_setting, slg_switch, slg_layers,
|
||||
@ -1978,48 +2045,43 @@ def generate_video_tab(image2video=False):
|
||||
refresh_lora_btn.click(refresh_lora_list, inputs=[state, lset_name,loras_choices], outputs=[lset_name, loras_choices])
|
||||
refresh_lora_btn2.click(refresh_lora_list, inputs=[state, lset_name,loras_choices], outputs=[lset_name, loras_choices])
|
||||
download_loras_btn.click(fn=download_loras, inputs=[], outputs=[download_status_row, download_status, presets_column, loras_column]).then(fn=refresh_lora_list, inputs=[state, lset_name,loras_choices], outputs=[lset_name, loras_choices])
|
||||
gen_status.change(refresh_gallery, inputs = [state, gen_info], outputs = [output, gen_info] )
|
||||
abort_btn.click(abort_generation,state,abort_btn )
|
||||
output.select(select_video, state, None )
|
||||
onemore_btn.click(fn=one_more_video,inputs=[state], outputs= [state])
|
||||
generate_btn.click(fn=prepare_generate_video,inputs=[], outputs= [generate_btn, onemore_btn]
|
||||
).then(
|
||||
fn=validate_wizard_prompt, inputs =[state, wizard_prompt_activated_var, wizard_variables_var, prompt, wizard_prompt, *prompt_vars] , outputs= [prompt]
|
||||
).then(
|
||||
fn=generate_video,
|
||||
inputs=[
|
||||
prompt,
|
||||
negative_prompt,
|
||||
resolution,
|
||||
video_length,
|
||||
seed,
|
||||
num_inference_steps,
|
||||
guidance_scale,
|
||||
flow_shift,
|
||||
embedded_guidance_scale,
|
||||
repeat_generation,
|
||||
multi_images_gen_type,
|
||||
tea_cache_setting,
|
||||
tea_cache_start_step_perc,
|
||||
loras_choices,
|
||||
loras_mult_choices,
|
||||
image_to_continue,
|
||||
image_to_end,
|
||||
video_to_continue,
|
||||
max_frames,
|
||||
RIFLEx_setting,
|
||||
slg_switch,
|
||||
slg_layers,
|
||||
slg_start_perc,
|
||||
slg_end_perc,
|
||||
state,
|
||||
gr.State(image2video)
|
||||
],
|
||||
outputs= [gen_status]
|
||||
).then(
|
||||
finalize_gallery,
|
||||
[state],
|
||||
[output , abort_btn, generate_btn, onemore_btn, gen_info]
|
||||
original_inputs = [
|
||||
prompt,
|
||||
negative_prompt,
|
||||
resolution,
|
||||
video_length,
|
||||
seed,
|
||||
num_inference_steps,
|
||||
guidance_scale,
|
||||
flow_shift,
|
||||
embedded_guidance_scale,
|
||||
repeat_generation,
|
||||
multi_images_gen_type,
|
||||
tea_cache_setting,
|
||||
tea_cache_start_step_perc,
|
||||
loras_choices,
|
||||
loras_mult_choices,
|
||||
image_to_continue,
|
||||
image_to_end,
|
||||
video_to_continue,
|
||||
max_frames,
|
||||
RIFLEx_setting,
|
||||
slg_switch,
|
||||
slg_layers,
|
||||
slg_start_perc,
|
||||
slg_end_perc,
|
||||
state,
|
||||
gr.State(image2video)
|
||||
]
|
||||
|
||||
#generate_btn.click(
|
||||
# fn=validate_wizard_prompt, inputs =[state, wizard_prompt_activated_var, wizard_variables_var, prompt, wizard_prompt, *prompt_vars] , outputs= [prompt]
|
||||
#).then(
|
||||
generate_btn.click(
|
||||
fn=process_prompt_and_add_tasks,
|
||||
inputs=original_inputs,
|
||||
outputs=queue_df
|
||||
)
|
||||
return loras_choices, lset_name, header, state
|
||||
|
||||
@ -2226,6 +2288,59 @@ def create_demo():
|
||||
margin: 0 20px;
|
||||
white-space: nowrap;
|
||||
}
|
||||
.queue-item {
|
||||
border: 1px solid #ccc;
|
||||
padding: 10px;
|
||||
margin: 5px 0;
|
||||
border-radius: 5px;
|
||||
}
|
||||
.current {
|
||||
background: #f8f9fa;
|
||||
border-left: 4px solid #007bff;
|
||||
}
|
||||
.task-header {
|
||||
display: flex;
|
||||
justify-content: space-between;
|
||||
margin-bottom: 5px;
|
||||
}
|
||||
.progress-container {
|
||||
height: 10px;
|
||||
background: #e9ecef;
|
||||
border-radius: 5px;
|
||||
overflow: hidden;
|
||||
}
|
||||
.progress-bar {
|
||||
height: 100%;
|
||||
background: #007bff;
|
||||
transition: width 0.3s ease;
|
||||
}
|
||||
.task-details {
|
||||
display: flex;
|
||||
justify-content: space-between;
|
||||
font-size: 0.9em;
|
||||
color: #6c757d;
|
||||
margin-top: 5px;
|
||||
}
|
||||
.task-prompt {
|
||||
font-size: 0.8em;
|
||||
color: #868e96;
|
||||
margin-top: 5px;
|
||||
white-space: nowrap;
|
||||
overflow: hidden;
|
||||
text-overflow: ellipsis;
|
||||
}
|
||||
#queue_df td:nth-child(7),
|
||||
#queue_df td:nth-child(8),
|
||||
#queue_df td:nth-child(9) {
|
||||
cursor: pointer;
|
||||
text-align: center;
|
||||
font-weight: bold;
|
||||
}
|
||||
#queue_df td:nth-child(7):hover,
|
||||
#queue_df td:nth-child(8):hover,
|
||||
#queue_df td:nth-child(9):hover {
|
||||
background-color: #e0e0e0;
|
||||
}
|
||||
"""
|
||||
with gr.Blocks(css=css, theme=gr.themes.Soft(primary_hue="sky", neutral_hue="slate", text_size="md")) as demo:
|
||||
gr.Markdown("<div align=center><H1>Wan 2.1<SUP>GP</SUP> v3.0 <FONT SIZE=4>by <I>DeepBeepMeep</I></FONT> <FONT SIZE=3> (<A HREF='https://github.com/deepbeepmeep/Wan2GP'>Updates</A>)</FONT SIZE=3></H1></div>")
|
||||
@ -2260,6 +2375,7 @@ def create_demo():
|
||||
return demo
|
||||
|
||||
if __name__ == "__main__":
|
||||
threading.Thread(target=runner, daemon=True).start()
|
||||
os.environ["GRADIO_ANALYTICS_ENABLED"] = "False"
|
||||
server_port = int(args.server_port)
|
||||
if os.name == "nt":
|
||||
|
||||
Loading…
Reference in New Issue
Block a user