Lora festival part 2, new macros, new user interface

This commit is contained in:
DeepBeepMeep 2025-03-17 23:43:34 +01:00
parent e554e1a3d6
commit a15060267a
4 changed files with 815 additions and 128 deletions

View File

@ -19,6 +19,13 @@ In this repository, we present **Wan2.1**, a comprehensive and open suite of vid
## 🔥 Latest News!!
* Mar 17, 2025: 👋 Wan2.1GP v2.0: The Lora festival continues:
- Clearer user interface
- Download 30 Loras in one click to try them all (expand the info section)
- Easier to use Loras: a Lora preset can now prefill the subject (or other needed terms) of the Lora, so you don't have to modify a prompt manually
- Added a basic macro prompt language to prefill prompts with different values. With one prompt template, you can generate multiple prompts.
- New multiple image prompts: you can now combine any number of images with any number of text prompts (you need to launch the app with --multiple-images)
- New command-line options to directly launch the 1.3B t2v model or the 14B t2v model
* Mar 14, 2025: 👋 Wan2.1GP v1.7:
- Lora Fest special edition: very fast loading / unloading of Loras for Lora collectors. You can also now add / remove Loras in the Lora folder without restarting the app. You will need to refresh the requirements: *pip install -r requirements.txt*
- Added experimental Skip Layer Guidance (advanced settings), that should improve the image quality at no extra cost. Many thanks to the *AmericanPresidentJimmyCarter* for the original implementation
@ -121,7 +128,11 @@ To run the text to video generator (in Low VRAM mode):
```bash
python gradio_server.py
#or
python gradio_server.py --t2v
python gradio_server.py --t2v #launch the default text 2 video model
#or
python gradio_server.py --t2v-14B #for the 14B model
#or
python gradio_server.py --t2v-1-3B #for the 1.3B model
```
@ -191,10 +202,27 @@ python gradio_server.py --lora-preset mylorapreset.lset # where 'mylorapreset.l
You will find prebuilt Loras on https://civitai.com/ or you will be able to build them with tools such as kohya or onetrainer.
### Macros (basic)
In *Advanced Mode*, you can start prompt lines with a "!", for instance:\
```
! {Subject}="cat","woman","man", {Location}="forest","lake","city", {Possessive}="its", "her", "his"
In the video, a {Subject} is presented. The {Subject} is in a {Location} and looks at {Possessive} watch.
```
This will automatically create 3 prompts and generate 3 videos:
```
In the video, a cat is presented. The cat is in a forest and looks at its watch.
In the video, a woman is presented. The woman is in a lake and looks at her watch.
In the video, a man is presented. The man is in a city and looks at his watch.
```
You can define multiple lines of macros. If there is only one macro line, the app will generate a simple user interface to enter the macro variables when you switch back to *Normal Mode* (advanced mode turned off).
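The expansion above can be sketched in a few lines of Python. This is a simplified stand-in for the parser in `wan.utils.prompt_parser` (the real syntax handling may differ); values are matched by position, so the i-th value of each variable fills prompt i:

```python
import re

def expand_macro(macro_line, template_line):
    """Expand a '!' macro line plus a template line into concrete prompts.

    Simplified sketch: values are matched by index, so the i-th value of
    every variable produces prompt number i.
    """
    # Collect {Name}="v1","v2",... assignments from the macro line
    assignments = {}
    for name, raw in re.findall(r'\{(\w+)\}\s*=\s*((?:"[^"]*"\s*,?\s*)+)', macro_line):
        assignments[name] = re.findall(r'"([^"]*)"', raw)
    # One prompt per value position (stop at the shortest value list)
    n = min(len(v) for v in assignments.values())
    prompts = []
    for i in range(n):
        prompt = template_line
        for name, values in assignments.items():
            prompt = prompt.replace("{" + name + "}", values[i])
        prompts.append(prompt)
    return prompts
```

For the example above, `expand_macro` yields the three prompts listed, one per value position.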
### Command line parameters for Gradio Server
--i2v : launch the image to video generator\
--t2v : launch the text to video generator (default defined in the configuration)\
--t2v-14B : launch the 14B model text to video generator\
--t2v-1-3B : launch the 1.3B model text to video generator\
--quantize-transformer bool: (default True) : enable / disable on the fly transformer quantization\
--lora-dir path : Path of directory that contains Loras in diffusers / safetensor format\
--lora-preset preset : name of preset file (without the extension) to preload
@ -208,7 +236,12 @@ You will find prebuilt Loras on https://civitai.com/ or you will be able to buil
--compile : turn on pytorch compilation\
--attention mode: force attention mode among sdpa, flash, sage, sage2\
--profile no : default (4) : profile number between 1 and 5\
--preload no : number in Megabytes to preload partially the diffusion model in VRAM , may offer slight speed gains especially on older hardware. Works only with profile 2 and 4.\
--seed no : set default seed value\
--frames no : set the default number of frames to generate\
--steps no : set the default number of denoising steps\
--check-loras : filter loras that are incompatible (will take a few seconds while refreshing the lora list or while starting the app)\
--advanced : turn on the advanced mode while launching the app
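As an illustration, several of these options can be combined on one command line (the values below are arbitrary examples, not recommended settings):

```bash
# Launch the 1.3B t2v model with sage attention, profile 2,
# a fixed seed and shorter generation defaults (example values only)
python gradio_server.py --t2v-1-3B --attention sage --profile 2 \
    --seed 42 --frames 49 --steps 20 --advanced
```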
### Profiles (for power users only)
You can choose between 5 profiles, but two are really relevant here :

View File

@ -20,6 +20,9 @@ import gc
import traceback
import math
import asyncio
from wan.utils import prompt_parser
PROMPT_VARS_MAX = 10
def _parse_args():
parser = argparse.ArgumentParser(
@ -60,7 +63,7 @@ def _parse_args():
parser.add_argument(
"--lora-dir-i2v",
type=str,
default="loras_i2v",
default="",
help="Path to a directory that contains Loras for i2v"
)
@ -72,6 +75,13 @@ def _parse_args():
help="Path to a directory that contains Loras"
)
parser.add_argument(
"--check-loras",
type=int,
default=0,
help="Filter Loras that are not valid"
)
parser.add_argument(
"--lora-preset",
@ -101,6 +111,34 @@ def _parse_args():
help="Verbose level"
)
parser.add_argument(
"--steps",
type=int,
default=0,
help="default denoising steps"
)
parser.add_argument(
"--frames",
type=int,
default=0,
help="default number of frames"
)
parser.add_argument(
"--seed",
type=int,
default=-1,
help="default generation seed"
)
parser.add_argument(
"--advanced",
action="store_true",
help="Access advanced options by default"
)
parser.add_argument(
"--server-port",
type=str,
@ -133,6 +171,18 @@ def _parse_args():
help="image to video mode"
)
parser.add_argument(
"--t2v-14B",
action="store_true",
help="text to video mode 14B model"
)
parser.add_argument(
"--t2v-1-3B",
action="store_true",
help="text to video mode 1.3B model"
)
parser.add_argument(
"--compile",
action="store_true",
@ -196,6 +246,13 @@ preload =int(args.preload)
force_profile_no = int(args.profile)
verbose_level = int(args.verbose)
quantizeTransformer = args.quantize_transformer
default_seed = args.seed
default_number_frames = int(args.frames)
if default_number_frames > 0:
default_number_frames = ((default_number_frames - 1) // 4) * 4 + 1
default_inference_steps = args.steps
check_loras = args.check_loras ==1
advanced = args.advanced
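The `--frames` normalization above snaps the requested count down to the nearest value of the form 4k+1, the only frame counts the model generates. As a standalone sketch (function name invented for illustration):

```python
def normalize_frames(n: int) -> int:
    # Wan2.1 frame counts have the form 4k + 1 (e.g. 49, 77, 81);
    # snap any other request down to the nearest valid count below it.
    return ((n - 1) // 4) * 4 + 1
```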
transformer_choices_t2v=["ckpts/wan2.1_text2video_1.3B_bf16.safetensors", "ckpts/wan2.1_text2video_14B_bf16.safetensors", "ckpts/wan2.1_text2video_14B_quanto_int8.safetensors"]
transformer_choices_i2v=["ckpts/wan2.1_image2video_480p_14B_bf16.safetensors", "ckpts/wan2.1_image2video_480p_14B_quanto_int8.safetensors", "ckpts/wan2.1_image2video_720p_14B_bf16.safetensors", "ckpts/wan2.1_image2video_720p_14B_quanto_int8.safetensors"]
@ -247,12 +304,26 @@ if args.t2v:
use_image2video = False
if args.i2v:
use_image2video = True
if args.t2v_14B:
use_image2video = False
if not "14B" in transformer_filename_t2v:
transformer_filename_t2v = transformer_choices_t2v[2]
lock_ui_transformer = False
if args.t2v_1_3B:
transformer_filename_t2v = transformer_choices_t2v[0]
use_image2video = False
lock_ui_transformer = False
only_allow_edit_in_advanced = False
lora_dir =args.lora_dir
if use_image2video and len(lora_dir)==0:
root_lora_dir =args.lora_dir_i2v
lora_dir =args.lora_dir_i2v
if len(lora_dir) ==0:
root_lora_dir = "loras_i2v" if use_image2video else "loras"
else:
root_lora_dir = lora_dir
lora_dir = get_lora_dir(root_lora_dir)
lora_preselected_preset = args.lora_preset
default_tea_cache = 0
@ -378,7 +449,8 @@ def setup_loras(transformer, lora_dir, lora_preselected_preset, split_linear_mo
dir_presets.sort()
loras_presets = [ Path(Path(file_path).parts[-1]).stem for file_path in dir_presets]
if check_loras:
loras = offload.load_loras_into_model(transformer, loras, activate_all_loras=False, check_only= True, split_linear_modules_map = split_linear_modules_map) #lora_multiplier,
if len(loras) > 0:
loras_names = [ Path(lora).stem for lora in loras ]
@ -492,7 +564,9 @@ def get_default_flow(model_filename):
return 3.0 if "480p" in model_filename else 5.0
def generate_header(model_filename, compile, attention_mode):
header = "<div class='title-with-lines'><div class=line></div><h2>"
if "image" in model_filename:
model_name = "Wan2.1 image2video"
@ -508,7 +582,8 @@ def generate_header(model_filename, compile, attention_mode):
if compile:
header += ", pytorch compilation ON"
header += ") </h2><div class=line></div> "
return header
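Pieced together, the new header builder looks roughly like this. This is a sketch assembled from the visible hunks; the elided middle of the function is simplified:

```python
def generate_header(model_filename: str, compile: bool, attention_mode: str) -> str:
    # New layout: an <h2> flanked by two decorative lines, styled via the
    # .title-with-lines / .line CSS rules added in create_demo
    header = "<div class='title-with-lines'><div class=line></div><h2>"
    if "image" in model_filename:
        header += "Wan2.1 image2video"
    else:
        header += "Wan2.1 text2video"
    # Simplified: the real function also reports resolution/quantization here
    header += f" (attention: {attention_mode}"
    if compile:
        header += ", pytorch compilation ON"
    header += ") </h2><div class=line></div> "
    return header
```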
@ -591,19 +666,19 @@ def apply_changes( state,
# return "<DIV ALIGN=CENTER>New Config file created. Please restart the Gradio Server</DIV>"
def update_defaults(state, num_inference_steps,flow_shift, lset_name , loras_choices):
if "config_changes" not in state:
return get_default_flow("")
changes = state["config_changes"]
server_config = state["config_new"]
old_server_config = state["config_old"]
t2v_changed = False
if not use_image2video:
old_is_14B = "14B" in server_config["transformer_filename"]
new_is_14B = "14B" in old_server_config["transformer_filename"]
trans_file = server_config["transformer_filename"]
# if old_is_14B != new_is_14B:
t2v_changed = old_is_14B != new_is_14B
# num_inference_steps, flow_shift = get_default_flow(trans_file)
else:
old_is_720P = "720P" in server_config["transformer_filename_i2v"]
@ -615,9 +690,11 @@ def update_defaults(state, num_inference_steps,flow_shift):
header = generate_header(trans_file, server_config["compile"], server_config["attention_mode"] )
new_loras_choices = [ (loras_name, str(i)) for i,loras_name in enumerate(loras_names)]
lset_choices = [ (preset, preset) for preset in loras_presets]
lset_choices.append( (get_new_preset_msg(advanced), ""))
if t2v_changed:
return num_inference_steps, flow_shift, header, gr.Dropdown(choices=lset_choices, value= ""), gr.Dropdown(choices=new_loras_choices, value= [])
else:
return num_inference_steps, flow_shift, header, lset_name , loras_choices
from moviepy.editor import ImageSequenceClip
@ -661,9 +738,20 @@ def abort_generation(state):
else:
return gr.Button(interactive= True)
def refresh_gallery(state, txt):
file_list = state.get("file_list", None)
prompt = state.get("prompt", "")
if len(prompt) == 0:
return file_list, gr.Text(visible= False, value="")
else:
prompts_max = state.get("prompts_max",0)
prompt_no = state.get("prompt_no",0)
if prompts_max >1 :
label = f"Current Prompt ({prompt_no+1}/{prompts_max})"
else:
label = f"Current Prompt"
return file_list, gr.Text(visible= True, value=prompt, label=label)
def finalize_gallery(state):
choice = 0
@ -675,7 +763,7 @@ def finalize_gallery(state):
time.sleep(0.2)
global gen_in_progress
gen_in_progress = False
return gr.Gallery(selected_index=choice), gr.Button(interactive= True), gr.Button(visible= True), gr.Checkbox(visible= False), gr.Text(visible= False, value="")
def select_video(state , event_data: gr.EventData):
data= event_data._data
@ -697,7 +785,9 @@ def one_more_video(state):
extra_orders = state.get("extra_orders", 0)
extra_orders += 1
state["extra_orders"] = extra_orders
prompts_max = state.get("prompts_max",0)
if prompts_max == 0:
return state
prompt_no = state["prompt_no"]
video_no = state["video_no"]
total_video = state["total_video"]
@ -730,6 +820,7 @@ def generate_video(
flow_shift,
embedded_guidance_scale,
repeat_generation,
multi_images_gen_type,
tea_cache,
tea_cache_start_step_perc,
loras_choices,
@ -759,8 +850,11 @@ def generate_video(
elif attention_mode in attention_modes_supported:
attn = attention_mode
else:
gr.Info(f"You have selected attention mode '{attention_mode}'. However it is not installed on your system. You should either install it or switch to the default 'sdpa' attention.")
return
if state.get("validate_success",0) != 1:
return
width, height = resolution.split("x")
width, height = int(width), int(height)
@ -768,17 +862,18 @@ def generate_video(
slg_layers = None
if use_image2video:
if "480p" in transformer_filename_i2v and width * height > 848*480:
gr.Info("You must use the 720P image to video model to generate videos with a resolution equivalent to 720P")
return
resolution = str(width) + "*" + str(height)
if resolution not in ['720*1280', '1280*720', '480*832', '832*480']:
gr.Info(f"Resolution {resolution} not supported by image 2 video")
return
else:
if "1.3B" in transformer_filename_t2v and width * height > 848*480:
gr.Info("You must use the 14B text to video model to generate videos with a resolution equivalent to 720P")
return
offload.shared_state["_attention"] = attn
@ -808,6 +903,9 @@ def generate_video(
temp_filename = None
if len(prompt) ==0:
return
prompt, errors = prompt_parser.process_template(prompt)
if len(errors) > 0:
gr.Info("Error processing prompt template: " + errors)
prompts = prompt.replace("\r", "").split("\n")
prompts = [prompt.strip() for prompt in prompts if len(prompt.strip())>0 and not prompt.startswith("#")]
if len(prompts) ==0:
@ -818,22 +916,31 @@ def generate_video(
image_to_continue = [ tup[0] for tup in image_to_continue ]
else:
image_to_continue = [image_to_continue]
if multi_images_gen_type == 0:
    new_prompts = []
    new_image_to_continue = []
    for i in range(len(prompts) * len(image_to_continue)):
        new_prompts.append(prompts[i % len(prompts)])
        new_image_to_continue.append(image_to_continue[i // len(prompts)])
    prompts = new_prompts
    image_to_continue = new_image_to_continue
else:
    if len(prompts) >= len(image_to_continue):
        if len(prompts) % len(image_to_continue) != 0:
            raise gr.Error("If there are more text prompts than input images, the number of text prompts should be divisible by the number of images")
        rep = len(prompts) // len(image_to_continue)
        new_image_to_continue = []
        for i, _ in enumerate(prompts):
            new_image_to_continue.append(image_to_continue[i // rep])
        image_to_continue = new_image_to_continue
    else:
        if len(image_to_continue) % len(prompts) != 0:
            raise gr.Error("If there are more input images than text prompts, the number of images should be divisible by the number of text prompts")
        rep = len(image_to_continue) // len(prompts)
        new_prompts = []
        for i, _ in enumerate(image_to_continue):
            new_prompts.append(prompts[i // rep])
        prompts = new_prompts
elif video_to_continue != None and len(video_to_continue) >0 :
input_image_or_video_path = video_to_continue
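The image/prompt matching above boils down to two rules: with `multi_images_gen_type == 0` every prompt is paired with every image (a cartesian product); otherwise the longer list must be an exact multiple of the shorter and is distributed evenly over it. A sketch of that rule (helper name invented for illustration):

```python
def pair_prompts_and_images(prompts, images, cartesian):
    """Return (prompts, images) lists of equal length, one entry per generation."""
    if cartesian:
        # Every prompt combined with every image, image-major order
        pairs = [(p, img) for img in images for p in prompts]
        return [p for p, _ in pairs], [img for _, img in pairs]
    if len(prompts) >= len(images):
        if len(prompts) % len(images):
            raise ValueError("prompt count must be divisible by image count")
        rep = len(prompts) // len(images)
        # Each image serves `rep` consecutive prompts
        return prompts, [images[i // rep] for i in range(len(prompts))]
    if len(images) % len(prompts):
        raise ValueError("image count must be divisible by prompt count")
    rep = len(images) // len(prompts)
    # Each prompt serves `rep` consecutive images
    return [prompts[i // rep] for i in range(len(images))], images
```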
@ -900,6 +1007,10 @@ def generate_video(
# TeaCache
trans.enable_teacache = tea_cache > 0
if trans.enable_teacache:
trans.teacache_multiplier = tea_cache
trans.rel_l1_thresh = 0
trans.teacache_start_step = int(tea_cache_start_step_perc*num_inference_steps/100)
if use_image2video:
if '480p' in transformer_filename_i2v:
# teacache_thresholds = [0.13, .19, 0.26]
@ -935,9 +1046,11 @@ def generate_video(
start_time = time.time()
state["prompts_max"] = len(prompts)
for no, prompt in enumerate(prompts):
state["prompt"] = prompt
repeat_no = 0
state["prompt_no"] = no
extra_generation = 0
yield f"Prompt No{no}"
while True:
extra_orders = state.get("extra_orders",0)
state["extra_orders"] = 0
@ -950,8 +1063,6 @@ def generate_video(
if trans.enable_teacache:
trans.teacache_counter = 0
trans.num_steps = num_inference_steps
trans.teacache_skipped_steps = 0
trans.previous_residual_uncond = None
@ -1035,6 +1146,7 @@ def generate_video(
if any( keyword in frame.name for keyword in keyword_list):
VRAM_crash = True
break
state["prompt"] = ""
if VRAM_crash:
raise gr.Error("The generation of the video has encountered an error: it is likely that you have insufficient VRAM and you should therefore reduce the video resolution or its number of frames.")
else:
@ -1054,6 +1166,7 @@ def generate_video(
if samples == None:
end_time = time.time()
abort = True
state["prompt"] = ""
yield f"Video generation was aborted. Total Generation Time: {end_time-start_time:.1f}s"
else:
sample = samples.cpu()
@ -1076,9 +1189,10 @@ def generate_video(
print(f"New video saved to Path: "+video_path)
file_list.append(video_path)
if video_no < total_video:
yield status
else:
end_time = time.time()
state["prompt"] = ""
yield f"Total Generation Time: {end_time-start_time:.1f}s"
seed += 1
repeat_no += 1
@ -1089,18 +1203,22 @@ def generate_video(
offload.unload_loras_from_model(trans)
def get_new_preset_msg(advanced = True):
if advanced:
return "Enter here a Name for a Lora Preset or Choose one in the List"
else:
return "Choose a Lora Preset in this List to Apply a Special Effect"
def validate_delete_lset(lset_name):
if len(lset_name) == 0 or lset_name == get_new_preset_msg(True) or lset_name == get_new_preset_msg(False):
gr.Info(f"Choose a Preset to delete")
return gr.Button(visible= True), gr.Checkbox(visible= True), gr.Button(visible= True), gr.Button(visible= True), gr.Button(visible= False), gr.Button(visible= False)
else:
return gr.Button(visible= False), gr.Checkbox(visible= False), gr.Button(visible= False), gr.Button(visible= False), gr.Button(visible= True), gr.Button(visible= True)
def validate_save_lset(lset_name):
if len(lset_name) == 0 or lset_name == get_new_preset_msg(True) or lset_name == get_new_preset_msg(False):
gr.Info("Please enter a name for the preset")
return gr.Button(visible= True), gr.Checkbox(visible= True), gr.Button(visible= True), gr.Button(visible= True), gr.Button(visible= False), gr.Button(visible= False),gr.Checkbox(visible= False)
else:
@ -1109,10 +1227,12 @@ def validate_save_lset(lset_name):
def cancel_lset():
return gr.Button(visible= True), gr.Button(visible= True), gr.Button(visible= True), gr.Button(visible= True), gr.Button(visible= False), gr.Button(visible= False), gr.Button(visible= False), gr.Checkbox(visible= False)
def save_lset(state, lset_name, loras_choices, loras_mult_choices, prompt, save_lset_prompt_cbox):
global loras_presets
if state.get("validate_success",0) == 0:
pass
if len(lset_name) == 0 or lset_name == get_new_preset_msg(True) or lset_name == get_new_preset_msg(False):
gr.Info("Please enter a name for the preset")
lset_choices =[("Please enter a name for a Lora Preset","")]
else:
@ -1142,14 +1262,14 @@ def save_lset(lset_name, loras_choices, loras_mult_choices, prompt, save_lset_pr
gr.Info(f"Lora Preset '{lset_name}' has been created")
loras_presets.append(Path(Path(lset_name_filename).parts[-1]).stem )
lset_choices = [ ( preset, preset) for preset in loras_presets ]
lset_choices.append( (get_new_preset_msg(), ""))
return gr.Dropdown(choices=lset_choices, value= lset_name), gr.Button(visible= True), gr.Button(visible= True), gr.Button(visible= True), gr.Button(visible= True), gr.Button(visible= False), gr.Button(visible= False), gr.Checkbox(visible= False)
def delete_lset(lset_name):
global loras_presets
lset_name_filename = os.path.join(lora_dir, sanitize_file_name(lset_name) + ".lset" )
if len(lset_name) > 0 and lset_name != get_new_preset_msg(True) and lset_name != get_new_preset_msg(False):
if not os.path.isfile(lset_name_filename):
raise gr.Error(f"Preset '{lset_name}' not found ")
os.remove(lset_name_filename)
@ -1161,7 +1281,7 @@ def delete_lset(lset_name):
gr.Info(f"Choose a Preset to delete")
lset_choices = [ (preset, preset) for preset in loras_presets]
lset_choices.append((get_new_preset_msg(), ""))
return gr.Dropdown(choices=lset_choices, value= lset_choices[pos][1]), gr.Button(visible= True), gr.Button(visible= True), gr.Button(visible= True), gr.Button(visible= True), gr.Button(visible= False), gr.Checkbox(visible= False)
def refresh_lora_list(lset_name, loras_choices):
@ -1179,15 +1299,15 @@ def refresh_lora_list(lset_name, loras_choices):
lora_names_selected.append(lora_id)
lset_choices = [ (preset, preset) for preset in loras_presets]
lset_choices.append((get_new_preset_msg(advanced), ""))
if lset_name in loras_presets:
pos = loras_presets.index(lset_name)
else:
pos = len(loras_presets)
lset_name =""
errors = getattr(wan_model.model, "_loras_errors", "")
if errors !=None and len(errors) > 0:
error_files = [path for path, _ in errors]
gr.Info("Error while refreshing Lora List, invalid Lora files: " + ", ".join(error_files))
else:
@ -1196,9 +1316,11 @@ def refresh_lora_list(lset_name, loras_choices):
return gr.Dropdown(choices=lset_choices, value= lset_choices[pos][1]), gr.Dropdown(choices=new_loras_choices, value= lora_names_selected)
def apply_lset(state, lset_name, loras_choices, loras_mult_choices, prompt):
state["apply_success"] = 0
if len(lset_name) == 0 or lset_name== get_new_preset_msg(True) or lset_name== get_new_preset_msg(False):
gr.Info("Please choose a preset in the list or create one")
else:
loras_choices, loras_mult_choices, preset_prompt, full_prompt, error = extract_preset(lset_name, loras)
@ -1213,34 +1335,221 @@ def apply_lset(lset_name, loras_choices, loras_mult_choices, prompt):
prompt = "\n".join(prompts)
prompt = preset_prompt + '\n' + prompt
gr.Info(f"Lora Preset '{lset_name}' has been applied")
state["apply_success"] = 1
state["wizard_prompt"] = 0
return loras_choices, loras_mult_choices, prompt
def extract_prompt_from_wizard(state, prompt, wizard_prompt, allow_null_values, *args):
prompts = wizard_prompt.replace("\r" ,"").split("\n")
new_prompts = []
macro_already_written = False
for prompt in prompts:
if not macro_already_written and not prompt.startswith("#") and "{" in prompt and "}" in prompt:
variables = state["variables"]
values = args[:len(variables)]
macro = "! "
for i, (variable, value) in enumerate(zip(variables, values)):
if len(value) == 0 and not allow_null_values:
return prompt, "You need to provide a value for '" + variable + "'"
sub_values= [ "\"" + sub_value + "\"" for sub_value in value.split("\n") ]
value = ",".join(sub_values)
if i>0:
macro += " : "
macro += "{" + variable + "}"+ f"={value}"
if len(variables) > 0:
macro_already_written = True
new_prompts.append(macro)
new_prompts.append(prompt)
else:
new_prompts.append(prompt)
prompt = "\n".join(new_prompts)
return prompt, ""
def validate_wizard_prompt(state, prompt, wizard_prompt, *args):
state["validate_success"] = 0
if state.get("wizard_prompt",0) != 1:
state["validate_success"] = 1
return prompt
prompt, errors = extract_prompt_from_wizard(state, prompt, wizard_prompt, False, *args)
if len(errors) > 0:
gr.Info(errors)
return prompt
state["validate_success"] = 1
return prompt
def fill_prompt_from_wizard(state, prompt, wizard_prompt, *args):
if state.get("wizard_prompt",0) == 1:
prompt, errors = extract_prompt_from_wizard(state, prompt, wizard_prompt, True, *args)
if len(errors) > 0:
gr.Info(errors)
state["wizard_prompt"] = 0
return gr.Textbox(visible= True, value =prompt) , gr.Textbox(visible= False), gr.Column(visible = True), *[gr.Column(visible = False)] * 2, *[gr.Textbox(visible= False)] * PROMPT_VARS_MAX
def extract_wizard_prompt(prompt):
variables = []
values = {}
prompts = prompt.replace("\r" ,"").split("\n")
if sum(prompt.startswith("!") for prompt in prompts) > 1:
return "", variables, values, "Prompt is too complex for basic Prompt editor, switching to Advanced Prompt"
new_prompts = []
errors = ""
for prompt in prompts:
if prompt.startswith("!"):
variables, errors = prompt_parser.extract_variable_names(prompt)
if len(errors) > 0:
return "", variables, values, "Error parsing Prompt template: " + errors
if len(variables) > PROMPT_VARS_MAX:
return "", variables, values, "Prompt is too complex for basic Prompt editor, switching to Advanced Prompt"
values, errors = prompt_parser.extract_variable_values(prompt)
if len(errors) > 0:
return "", variables, values, "Error parsing Prompt template: " + errors
else:
variables_extra, errors = prompt_parser.extract_variable_names(prompt)
if len(errors) > 0:
return "", variables, values, "Error parsing Prompt template: " + errors
variables += variables_extra
variables = [var for pos, var in enumerate(variables) if var not in variables[:pos]]
if len(variables) > PROMPT_VARS_MAX:
return "", variables, values, "Prompt is too complex for basic Prompt editor, switching to Advanced Prompt"
new_prompts.append(prompt)
wizard_prompt = "\n".join(new_prompts)
return wizard_prompt, variables, values, errors
def fill_wizard_prompt(state, prompt, wizard_prompt):
def get_hidden_textboxes(num = PROMPT_VARS_MAX ):
return [gr.Textbox(value="", visible=False)] * num
hidden_column = gr.Column(visible = False)
visible_column = gr.Column(visible = True)
if advanced or state.get("apply_success") != 1:
return prompt, wizard_prompt, gr.Column(), gr.Column(), hidden_column, *get_hidden_textboxes()
prompt_parts= []
state["wizard_prompt"] = 0
wizard_prompt, variables, values, errors = extract_wizard_prompt(prompt)
if len(errors) > 0:
gr.Info( errors )
return gr.Textbox(prompt, visible=True), gr.Textbox(wizard_prompt, visible=False), visible_column, *[hidden_column] * 2, *get_hidden_textboxes()
for variable in variables:
value = values.get(variable, "")
prompt_parts.append(gr.Textbox( placeholder=variable, info= variable, visible= True, value= "\n".join(value) ))
any_macro = len(variables) > 0
prompt_parts += get_hidden_textboxes(PROMPT_VARS_MAX-len(prompt_parts))
state["variables"] = variables
state["wizard_prompt"] = 1
return gr.Textbox(prompt, visible = False), gr.Textbox(wizard_prompt, visible = True), hidden_column, visible_column, visible_column if any_macro else hidden_column, *prompt_parts
def switch_prompt_type(state, prompt, wizard_prompt, *prompt_vars):
if advanced:
return fill_prompt_from_wizard(state, prompt, wizard_prompt, *prompt_vars)
else:
state["apply_success"] = 1
return fill_wizard_prompt(state, prompt, wizard_prompt)
def switch_advanced(new_advanced, lset_name):
global advanced
advanced= new_advanced
lset_choices = [ (preset, preset) for preset in loras_presets]
lset_choices.append((get_new_preset_msg(advanced), ""))
if lset_name== get_new_preset_msg(True) or lset_name== get_new_preset_msg(False) or lset_name=="":
lset_name = get_new_preset_msg(advanced)
if only_allow_edit_in_advanced:
return gr.Row(visible=new_advanced), gr.Row(visible=new_advanced), gr.Button(visible=new_advanced), gr.Row(visible= not new_advanced), gr.Dropdown(choices=lset_choices, value= lset_name)
else:
return gr.Row(visible=new_advanced), gr.Row(visible=True), gr.Button(visible=True), gr.Row(visible= False), gr.Dropdown(choices=lset_choices, value= lset_name)
def download_loras():
from huggingface_hub import snapshot_download
yield "<B><FONT SIZE=3>Please wait while the Loras are being downloaded</B></FONT>", *[gr.Column(visible=False)] * 2
log_path = os.path.join(lora_dir, "log.txt")
if not os.path.isfile(log_path):
tmp_path = os.path.join(lora_dir, "tmp_lora_dowload")
import shutil, glob
snapshot_download(repo_id="DeepBeepMeep/Wan2.1", allow_patterns="loras_i2v/*", local_dir= tmp_path)
[shutil.move(f, lora_dir) for f in glob.glob(os.path.join(tmp_path, "loras_i2v", "*.*")) if not "README.txt" in f ]
yield "<B><FONT SIZE=3>Loras have been completely downloaded</B></FONT>", *[gr.Column(visible=True)] * 2
from datetime import datetime
dt = datetime.today().strftime('%Y-%m-%d')
with open( log_path, "w", encoding="utf-8") as writer:
writer.write(f"Loras downloaded on {dt}")
return
def create_demo():
css= """
.title-with-lines {
display: flex;
align-items: center;
margin: 30px 0;
}
.line {
flex-grow: 1;
height: 1px;
background-color: #333;
}
h2 {
margin: 0 20px;
white-space: nowrap;
}
"""
default_flow_shift = get_default_flow(transformer_filename_i2v if use_image2video else transformer_filename_t2v)
with gr.Blocks(css=css, theme=gr.themes.Soft(primary_hue="emerald", neutral_hue="slate", text_size= "md")) as demo:
state_dict = {}
if use_image2video:
gr.Markdown("<div align=center><H1>Wan 2.1<SUP>GP</SUP> v2.0 - Image To Video <FONT SIZE=4>by <I>DeepBeepMeep</I></FONT> <FONT SIZE=3> (<A HREF='https://github.com/deepbeepmeep/Wan2GP'>Updates</A> / <A HREF='https://github.com/Wan-Video/Wan2.1'>Original by Alibaba</A>)</FONT SIZE=3></H1></div>")
else:
gr.Markdown("<div align=center><H1>Wan 2.1<SUP>GP</SUP> v2.0 - Text To Video <FONT SIZE=4>by <I>DeepBeepMeep</I></FONT> <FONT SIZE=3> (<A HREF='https://github.com/deepbeepmeep/Wan2GP'>Updates</A> / <A HREF='https://github.com/Wan-Video/Wan2.1'>Original by Alibaba</A>)</FONT SIZE=3></H1></div>")
gr.Markdown("<FONT SIZE=3>Welcome to Wan 2.1GP a super fast and low VRAM AI Video Generator !</FONT>")
with gr.Accordion("Click here for some Info on how to use Wan2GP and to download 20+ Loras", open = False):
if use_image2video and False:
pass
else:
gr.Markdown("The VRAM requirements will depend greatly on the resolution and the duration of the video, for instance:")
gr.Markdown("- 848 x 480 with a 14B model: 80 frames (5s) : 8 GB of VRAM")
gr.Markdown("- 848 x 480 with the 1.3B model: 80 frames (5s) : 5 GB of VRAM")
gr.Markdown("- 1280 x 720 with a 14B model: 80 frames (5s): 11 GB of VRAM")
gr.Markdown("It is not recommended to generate a video longer than 8s (128 frames) even if there is still some VRAM left, as some artifacts may appear")
gr.Markdown("Please note that if you turn on compilation, the first denoising step of the first video generation will be slow due to the compilation. Therefore all your tests should be done with compilation turned off.")
if use_image2video:
with gr.Row():
with gr.Row(scale =3):
gr.Markdown("<I>Wan2GP's Lora Festival ! Press the following button to download the i2v <B>Remade</B> Loras collection (and bonus Loras). Don't forget to make a backup of your Loras first, just in case.</I>")
with gr.Row(scale =1):
download_loras_btn = gr.Button("---> Let the Lora's Festival Start !", scale =1)
with gr.Row():
download_status = gr.Markdown()
# css = """<STYLE>
# h2 { width: 100%; text-align: center; border-bottom: 1px solid #000; line-height: 0.1em; margin: 10px 0 20px; }
@@ -1370,8 +1679,36 @@ def create_demo():
apply_btn = gr.Button("Apply Changes")
with gr.Row():
with gr.Column():
with gr.Row(visible= len(loras)>0) as presets_column:
lset_choices = [ (preset, preset) for preset in loras_presets ] + [(get_new_preset_msg(advanced), "")]
with gr.Column(scale=6):
lset_name = gr.Dropdown(show_label=False, allow_custom_value= True, scale=5, filterable=True, choices= lset_choices, value=default_lora_preset)
with gr.Column(scale=1):
# with gr.Column():
with gr.Row(height=17):
apply_lset_btn = gr.Button("Apply Lora Preset", size="sm", min_width= 1)
refresh_lora_btn = gr.Button("Refresh", size="sm", min_width= 1, visible=advanced or not only_allow_edit_in_advanced)
# save_lset_prompt_cbox = gr.Checkbox(label="Save Prompt Comments in Preset", value=False, visible= False)
save_lset_prompt_drop= gr.Dropdown(
choices=[
("Save Prompt Comments Only", 0),
("Save Full Prompt", 1)
], show_label= False, container=False, value =1, visible= False
)
with gr.Row(height=17, visible=False) as refresh2_row:
refresh_lora_btn2 = gr.Button("Refresh", size="sm", min_width= 1)
with gr.Row(height=17, visible=advanced or not only_allow_edit_in_advanced) as preset_buttons_rows:
confirm_save_lset_btn = gr.Button("Go Ahead Save it !", size="sm", min_width= 1, visible=False)
confirm_delete_lset_btn = gr.Button("Go Ahead Delete it !", size="sm", min_width= 1, visible=False)
save_lset_btn = gr.Button("Save", size="sm", min_width= 1)
delete_lset_btn = gr.Button("Delete", size="sm", min_width= 1)
cancel_lset_btn = gr.Button("Don't do it !", size="sm", min_width= 1 , visible=False)
video_to_continue = gr.Video(label= "Video to continue", visible= use_image2video and False) #######
if args.multiple_images:
image_to_continue = gr.Gallery(
@@ -1380,8 +1717,30 @@ def create_demo():
else:
image_to_continue = gr.Image(label= "Image as a starting point for a new video", type ="pil", visible=use_image2video)
prompt = gr.Textbox(label="Prompts (multiple prompts separated by carriage returns will generate multiple videos, lines that start with # are ignored)", value=default_prompt, lines=3)
advanced_prompt = advanced
prompt_vars=[]
if not advanced_prompt:
default_wizard_prompt, variables, values, errors = extract_wizard_prompt(default_prompt)
advanced_prompt = len(errors) > 0
with gr.Column(visible= advanced_prompt) as prompt_column_advanced: #visible= False
prompt = gr.Textbox( visible= advanced_prompt, label="Prompts (each new line of the prompt will generate a new video, # lines = comments, ! lines = macros)", value=default_prompt, lines=3)
with gr.Column(visible=not advanced_prompt and len(variables) > 0) as prompt_column_wizard_vars: #visible= False
gr.Markdown("<B>Please fill the following input fields to automatically adapt the Prompt:</B>")
with gr.Row(): #visible= not advanced_prompt and len(variables) > 0
if not advanced_prompt:
for variable in variables:
value = values.get(variable, "")
prompt_vars.append(gr.Textbox( placeholder=variable, min_width=80, show_label= False, info= variable, visible= True, value= "\n".join(value) ))
state_dict["wizard_prompt"] = 1
state_dict["variables"] = variables
for _ in range( PROMPT_VARS_MAX - len(prompt_vars)):
prompt_vars.append(gr.Textbox(visible= False, min_width=80, show_label= False))
with gr.Column(visible= not advanced_prompt) as prompt_column_wizard:
wizard_prompt = gr.Textbox(visible = not advanced_prompt, label="Prompts (each new line of the prompt will generate a new video, # lines = comments)", value=default_wizard_prompt, lines=3)
state = gr.State(state_dict)
with gr.Row():
if use_image2video:
resolution = gr.Dropdown(
@@ -1417,64 +1776,51 @@ def create_demo():
with gr.Row():
with gr.Column():
video_length = gr.Slider(5, 193, value=81, step=4, label="Number of frames (16 = 1s)")
video_length = gr.Slider(5, 193, value=default_number_frames if default_number_frames > 0 else 81, step=4, label="Number of frames (16 = 1s)")
with gr.Column():
num_inference_steps = gr.Slider(1, 100, value= default_inference_steps, step=1, label="Number of Inference Steps")
num_inference_steps = gr.Slider(1, 100, value= default_inference_steps if default_inference_steps > 0 else 30, step=1, label="Number of Inference Steps")
with gr.Row():
max_frames = gr.Slider(1, 100, value=9, step=1, label="Number of input frames to use for Video2World prediction", visible=use_image2video and False) #########
with gr.Row(visible= len(loras)>0):
lset_choices = [ (preset, preset) for preset in loras_presets ] + [(new_preset_msg, "")]
with gr.Column(scale=5):
lset_name = gr.Dropdown(show_label=False, allow_custom_value= True, scale=5, filterable=True, choices= lset_choices, value=default_lora_preset)
with gr.Column(scale=1):
# with gr.Column():
with gr.Row(height=17):
apply_lset_btn = gr.Button("Apply Lora Preset", size="sm", min_width= 1)
refresh_lora_btn = gr.Button("Refresh Lora List", size="sm", min_width= 1)
# save_lset_prompt_cbox = gr.Checkbox(label="Save Prompt Comments in Preset", value=False, visible= False)
save_lset_prompt_drop= gr.Dropdown(
choices=[
("Save Prompt Comments Only", 0),
("Save Full Prompt", 1)
], show_label= False, container=False, visible= False
)
with gr.Row(height=17):
confirm_save_lset_btn = gr.Button("Go Ahead Save it !", size="sm", min_width= 1, visible=False)
confirm_delete_lset_btn = gr.Button("Go Ahead Delete it !", size="sm", min_width= 1, visible=False)
save_lset_btn = gr.Button("Save", size="sm", min_width= 1)
delete_lset_btn = gr.Button("Delete", size="sm", min_width= 1)
cancel_lset_btn = gr.Button("Don't do it !", size="sm", min_width= 1 , visible=False)
loras_choices = gr.Dropdown(
choices=[
(lora_name, str(i) ) for i, lora_name in enumerate(loras_names)
],
value= default_loras_choices,
multiselect= True,
visible= len(loras)>0,
label="Activated Loras"
)
loras_mult_choices = gr.Textbox(label="Loras Multipliers (1.0 by default) separated by space characters or carriage returns, lines that start with # are ignored", value=default_loras_multis_str, visible= len(loras)>0 )
show_advanced = gr.Checkbox(label="Show Advanced Options", value=False)
with gr.Row(visible=False) as advanced_row:
show_advanced = gr.Checkbox(label="Advanced Mode", value=advanced)
with gr.Row(visible=advanced) as advanced_row:
with gr.Column():
seed = gr.Slider(-1, 999999999, value=-1, step=1, label="Seed (-1 for random)")
repeat_generation = gr.Slider(1, 25.0, value=1.0, step=1, label="Default Number of Generated Videos per Prompt")
seed = gr.Slider(-1, 999999999, value=default_seed, step=1, label="Seed (-1 for random)")
with gr.Row():
negative_prompt = gr.Textbox(label="Negative Prompt", value="")
repeat_generation = gr.Slider(1, 25.0, value=1.0, step=1, label="Default Number of Generated Videos per Prompt")
multi_images_gen_type = gr.Dropdown(
choices=[
("Generate every combination of images and texts prompts", 0),
("Match images and text prompts", 1),
], visible= args.multiple_images, label= "Multiple Images as Prompts"
)
with gr.Row():
guidance_scale = gr.Slider(1.0, 20.0, value=5.0, step=0.5, label="Guidance Scale", visible=True)
embedded_guidance_scale = gr.Slider(1.0, 20.0, value=6.0, step=0.5, label="Embedded Guidance Scale", visible=False)
flow_shift = gr.Slider(0.0, 25.0, value= default_flow_shift, step=0.1, label="Shift Scale")
with gr.Row():
gr.Markdown("Tea Cache accelerates generation by intelligently skipping some steps; the more steps are skipped, the lower the quality of the video (Tea Cache also consumes VRAM)")
negative_prompt = gr.Textbox(label="Negative Prompt", value="")
with gr.Row():
gr.Markdown("<B>Loras can be used to create special effects on the video by mentioning a trigger word in the Prompt. You can save Lora combinations in presets.</B>")
with gr.Column() as loras_column:
loras_choices = gr.Dropdown(
choices=[
(lora_name, str(i) ) for i, lora_name in enumerate(loras_names)
],
value= default_loras_choices,
multiselect= True,
visible= len(loras)>0,
label="Activated Loras"
)
loras_mult_choices = gr.Textbox(label="Loras Multipliers (1.0 by default) separated by space characters or carriage returns, lines that start with # are ignored", value=default_loras_multis_str, visible= len(loras)>0 )
with gr.Row():
gr.Markdown("<B>Tea Cache accelerates generation by intelligently skipping some steps; the more steps are skipped, the lower the quality of the video (Tea Cache also consumes VRAM)</B>")
with gr.Row():
tea_cache_setting = gr.Dropdown(
choices=[
@@ -1491,6 +1837,7 @@ def create_demo():
)
tea_cache_start_step_perc = gr.Slider(0, 100, value=0, step=1, label="Tea Cache starting moment in % of generation")
gr.Markdown("<B>With Riflex you can generate videos longer than 5s, which is the default duration of the videos used to train the model</B>")
RIFLEx_setting = gr.Dropdown(
choices=[
("Auto (ON if Video longer than 5s)", 0),
@@ -1503,7 +1850,7 @@ def create_demo():
with gr.Row():
gr.Markdown("Experimental: Skip Layer Guidance, which should improve video quality")
gr.Markdown("<B>Experimental: Skip Layer Guidance, which should improve video quality</B>")
with gr.Row():
slg_switch = gr.Dropdown(
choices=[
@@ -1529,33 +1876,43 @@ def create_demo():
slg_end_perc = gr.Slider(0, 100, value=90, step=1, label="Denoising Steps % end")
show_advanced.change(fn=lambda x: gr.Row(visible=x), inputs=[show_advanced], outputs=[advanced_row])
show_advanced.change(fn=switch_advanced, inputs=[show_advanced, lset_name], outputs=[advanced_row, preset_buttons_rows, refresh_lora_btn, refresh2_row ,lset_name ]).then(
fn=switch_prompt_type, inputs = [state, prompt, wizard_prompt, *prompt_vars], outputs = [prompt, wizard_prompt, prompt_column_advanced, prompt_column_wizard, prompt_column_wizard_vars, *prompt_vars])
with gr.Column():
gen_status = gr.Text(label="Status", interactive= False)
output = gr.Gallery(
label="Generated videos", show_label=False, elem_id="gallery"
, columns=[3], rows=[1], object_fit="contain", height="auto", selected_index=0, interactive= False)
, columns=[3], rows=[1], object_fit="contain", height=450, selected_index=0, interactive= False)
generate_btn = gr.Button("Generate")
onemore_btn = gr.Button("One More Please !", visible= False)
abort_btn = gr.Button("Abort")
gen_info = gr.Text(label="Current prompt", visible= False , interactive= False) #gr.Markdown("Current prompt")
save_lset_btn.click(validate_save_lset, inputs=[lset_name], outputs=[apply_lset_btn, refresh_lora_btn, delete_lset_btn, save_lset_btn,confirm_save_lset_btn, cancel_lset_btn, save_lset_prompt_drop])
confirm_save_lset_btn.click(save_lset, inputs=[lset_name, loras_choices, loras_mult_choices, prompt, save_lset_prompt_drop], outputs=[lset_name, apply_lset_btn,refresh_lora_btn, delete_lset_btn, save_lset_btn, confirm_save_lset_btn, cancel_lset_btn, save_lset_prompt_drop])
confirm_save_lset_btn.click(fn=validate_wizard_prompt, inputs =[state, prompt, wizard_prompt, *prompt_vars] , outputs= [prompt]).then(
save_lset, inputs=[state, lset_name, loras_choices, loras_mult_choices, prompt, save_lset_prompt_drop], outputs=[lset_name, apply_lset_btn,refresh_lora_btn, delete_lset_btn, save_lset_btn, confirm_save_lset_btn, cancel_lset_btn, save_lset_prompt_drop])
delete_lset_btn.click(validate_delete_lset, inputs=[lset_name], outputs=[apply_lset_btn, refresh_lora_btn, delete_lset_btn, save_lset_btn,confirm_delete_lset_btn, cancel_lset_btn ])
confirm_delete_lset_btn.click(delete_lset, inputs=[lset_name], outputs=[lset_name, apply_lset_btn, refresh_lora_btn, delete_lset_btn, save_lset_btn,confirm_delete_lset_btn, cancel_lset_btn ])
cancel_lset_btn.click(cancel_lset, inputs=[], outputs=[apply_lset_btn, refresh_lora_btn, delete_lset_btn, save_lset_btn, confirm_delete_lset_btn,confirm_save_lset_btn, cancel_lset_btn,save_lset_prompt_drop ])
apply_lset_btn.click(apply_lset, inputs=[lset_name,loras_choices, loras_mult_choices, prompt], outputs=[loras_choices, loras_mult_choices, prompt])
apply_lset_btn.click(apply_lset, inputs=[state, lset_name,loras_choices, loras_mult_choices, prompt], outputs=[loras_choices, loras_mult_choices, prompt]).then(
fn = fill_wizard_prompt, inputs = [state, prompt, wizard_prompt], outputs = [ prompt, wizard_prompt, prompt_column_advanced, prompt_column_wizard, prompt_column_wizard_vars, *prompt_vars]
)
refresh_lora_btn.click(refresh_lora_list, inputs=[lset_name,loras_choices], outputs=[lset_name, loras_choices])
refresh_lora_btn2.click(refresh_lora_list, inputs=[lset_name,loras_choices], outputs=[lset_name, loras_choices])
download_loras_btn.click(fn=download_loras, inputs=[], outputs=[download_status, presets_column, loras_column]).then(fn=refresh_lora_list, inputs=[lset_name,loras_choices], outputs=[lset_name, loras_choices])
gen_status.change(refresh_gallery, inputs = [state], outputs = output )
gen_status.change(refresh_gallery, inputs = [state, gen_info], outputs = [output, gen_info] )
abort_btn.click(abort_generation,state,abort_btn )
output.select(select_video, state, None )
onemore_btn.click(fn=one_more_video,inputs=[state], outputs= [state])
generate_btn.click(fn=prepare_generate_video,inputs=[], outputs= [generate_btn, onemore_btn]).then(
generate_btn.click(fn=prepare_generate_video,inputs=[], outputs= [generate_btn, onemore_btn]
).then(
fn=validate_wizard_prompt, inputs =[state, prompt, wizard_prompt, *prompt_vars] , outputs= [prompt]
).then(
fn=generate_video,
inputs=[
prompt,
@@ -1568,6 +1925,7 @@ def create_demo():
flow_shift,
embedded_guidance_scale,
repeat_generation,
multi_images_gen_type,
tea_cache_setting,
tea_cache_start_step_perc,
loras_choices,
@@ -1587,7 +1945,7 @@ def create_demo():
).then(
finalize_gallery,
[state],
[output , abort_btn, generate_btn, onemore_btn]
[output , abort_btn, generate_btn, onemore_btn, gen_info]
)
apply_btn.click(
@@ -1607,7 +1965,7 @@ def create_demo():
outputs= msg
).then(
update_defaults,
[state, num_inference_steps, flow_shift],
[state, num_inference_steps, flow_shift,lset_name , loras_choices],
[num_inference_steps, flow_shift, header, lset_name , loras_choices ]
)


@@ -676,6 +676,7 @@ class WanModel(ModelMixin, ConfigMixin):
best_threshold = 0.01
best_diff = 1000
best_signed_diff = 1000
target_nb_steps= int(len(timesteps) / speed_factor)
threshold = 0.01
while threshold <= 0.6:
@@ -686,6 +687,8 @@ class WanModel(ModelMixin, ConfigMixin):
skip = False
if not (i<=start_step or i== len(timesteps)):
accumulated_rel_l1_distance += rescale_func(((e_list[i]-previous_modulated_input).abs().mean() / previous_modulated_input.abs().mean()).cpu().item())
# self.accumulated_rel_l1_distance_even += rescale_func(((e_list[i]-self.previous_e0_even).abs().mean() / self.previous_e0_even.abs().mean()).cpu().item())
if accumulated_rel_l1_distance < threshold:
skip = True
else:
@@ -693,15 +696,17 @@ class WanModel(ModelMixin, ConfigMixin):
previous_modulated_input = e_list[i]
if not skip:
nb_steps += 1
diff = abs(target_nb_steps - nb_steps)
signed_diff = target_nb_steps - nb_steps
diff = abs(signed_diff)
if diff < best_diff:
best_threshold = threshold
best_diff = diff
best_signed_diff = signed_diff
elif diff > best_diff:
break
threshold += 0.01
self.rel_l1_thresh = best_threshold
print(f"Tea Cache, best threshold found:{best_threshold} with gain x{len(timesteps)/(len(timesteps) - best_diff):0.1f} for a target of x{speed_factor}")
print(f"Tea Cache, best threshold found:{best_threshold:0.2f} with gain x{len(timesteps)/(target_nb_steps - best_signed_diff):0.2f} for a target of x{speed_factor}")
return best_threshold
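The calibration above scans candidate thresholds from 0.01 to 0.60 and keeps the one whose simulated number of kept denoising steps is closest to `len(timesteps) / speed_factor`, stopping early once the gap to the target starts growing again. A stripped-down sketch of that search pattern (with a toy `count_steps` stand-in for the real rescaled relative-L1 skip simulation, which it is not):

```python
def best_threshold_search(timesteps, speed_factor, count_steps):
    """Scan thresholds 0.01..0.60 and keep the one whose simulated
    number of kept steps is closest to len(timesteps) / speed_factor."""
    target_nb_steps = int(len(timesteps) / speed_factor)
    best_threshold, best_diff = 0.01, float("inf")
    threshold = 0.01
    while threshold <= 0.6:
        nb_steps = count_steps(threshold)  # stand-in for the simulated skip loop
        diff = abs(target_nb_steps - nb_steps)
        if diff < best_diff:
            best_threshold, best_diff = threshold, diff
        elif diff > best_diff:
            break  # the gap is growing again: we passed the optimum
        threshold += 0.01
    return best_threshold

# Toy model: each +0.01 of threshold skips one more of 30 steps;
# targeting a 2x speed-up (15 kept steps) lands near threshold 0.15.
found = best_threshold_search(list(range(30)), 2, lambda th: round(30 - th * 100))
```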
def forward(

wan/utils/prompt_parser.py Normal file

@ -0,0 +1,291 @@
import re
def process_template(input_text):
"""
Process a text template with macro instructions and variable substitution.
Supports multiple values for variables to generate multiple output versions.
Each section between macro lines is treated as a separate template.
Args:
input_text (str): The input template text
Returns:
tuple: (output_text, error_message)
- output_text: Processed output with variables substituted, or empty string if error
- error_message: Error description and problematic line, or empty string if no error
"""
lines = input_text.strip().split('\n')
current_variables = {}
current_template_lines = []
all_output_lines = []
error_message = ""
# Process the input line by line
line_number = 0
while line_number < len(lines):
orig_line = lines[line_number]
line = orig_line.strip()
line_number += 1
# Skip empty lines or comments
if not line or line.startswith('#'):
continue
# Handle macro instructions
if line.startswith('!'):
# Process any accumulated template lines before starting a new macro
if current_template_lines:
# Process the current template with current variables
template_output, err = process_current_template(current_template_lines, current_variables)
if err:
return "", err
all_output_lines.extend(template_output)
current_template_lines = [] # Reset template lines
# Reset variables for the new macro
current_variables = {}
# Parse the macro line
macro_line = line[1:].strip()
# Check for unmatched braces in the whole line
open_braces = macro_line.count('{')
close_braces = macro_line.count('}')
if open_braces != close_braces:
error_message = f"Unmatched braces: {open_braces} opening '{{' and {close_braces} closing '}}' braces\nLine: '{orig_line}'"
return "", error_message
# Check for unclosed quotes
if macro_line.count('"') % 2 != 0:
error_message = f"Unclosed double quotes\nLine: '{orig_line}'"
return "", error_message
# Split by optional colon separator
var_sections = re.split(r'\s*:\s*', macro_line)
for section in var_sections:
section = section.strip()
if not section:
continue
# Extract variable name
var_match = re.search(r'\{([^}]+)\}', section)
if not var_match:
if '{' in section or '}' in section:
error_message = f"Malformed variable declaration\nLine: '{orig_line}'"
return "", error_message
continue
var_name = var_match.group(1).strip()
if not var_name:
error_message = f"Empty variable name\nLine: '{orig_line}'"
return "", error_message
# Check variable value format
value_part = section[section.find('}')+1:].strip()
if not value_part.startswith('='):
error_message = f"Missing '=' after variable '{{{var_name}}}'\nLine: '{orig_line}'"
return "", error_message
# Extract all quoted values
var_values = re.findall(r'"([^"]*)"', value_part)
# Check if there are values specified
if not var_values:
error_message = f"No quoted values found for variable '{{{var_name}}}'\nLine: '{orig_line}'"
return "", error_message
# Check for missing commas between values
# Look for patterns like "value""value" (missing comma)
if re.search(r'"[^,]*"[^,]*"', value_part):
error_message = f"Missing comma between values for variable '{{{var_name}}}'\nLine: '{orig_line}'"
return "", error_message
# Store the variable values
current_variables[var_name] = var_values
# Handle template lines
else:
# Check for unknown variables in template line
var_references = re.findall(r'\{([^}]+)\}', line)
for var_ref in var_references:
if var_ref not in current_variables:
error_message = f"Unknown variable '{{{var_ref}}}' in template\nLine: '{orig_line}'"
return "", error_message
# Add to current template lines
current_template_lines.append(line)
# Process any remaining template lines
if current_template_lines:
template_output, err = process_current_template(current_template_lines, current_variables)
if err:
return "", err
all_output_lines.extend(template_output)
return '\n'.join(all_output_lines), ""
def process_current_template(template_lines, variables):
"""
Process a set of template lines with the current variables.
Args:
template_lines (list): List of template lines to process
variables (dict): Dictionary of variable names to lists of values
Returns:
tuple: (output_lines, error_message)
"""
if not variables or not template_lines:
return template_lines, ""
output_lines = []
# Find the maximum number of values for any variable
max_values = max(len(values) for values in variables.values())
# Generate each combination
for i in range(max_values):
for template in template_lines:
output_line = template
for var_name, var_values in variables.items():
# Use modulo to cycle through values if needed
value_index = i % len(var_values)
var_value = var_values[value_index]
output_line = output_line.replace(f"{{{var_name}}}", var_value)
output_lines.append(output_line)
return output_lines, ""
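The modulo cycling above means the variable with the most values drives how many versions are generated, while shorter value lists wrap around. A self-contained illustration of that rule (mirroring the loop above, not importing it):

```python
variables = {"subject": ["a cat", "a dog", "a fox"],
             "place": ["the beach", "a forest"]}
template = "{subject} runs through {place}"

# The longest value list (3 subjects) sets the number of outputs;
# the shorter list of places cycles via i % len(values).
max_values = max(len(values) for values in variables.values())
outputs = []
for i in range(max_values):
    line = template
    for name, values in variables.items():
        line = line.replace("{" + name + "}", values[i % len(values)])
    outputs.append(line)
# The third output reuses the first place, since 2 % 2 == 0.
```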
def extract_variable_names(macro_line):
"""
Extract all variable names from a macro line.
Args:
macro_line (str): A macro line (with or without the leading '!')
Returns:
tuple: (variable_names, error_message)
- variable_names: List of variable names found in the macro
- error_message: Error description if any, empty string if no error
"""
# Remove leading '!' if present
if macro_line.startswith('!'):
macro_line = macro_line[1:].strip()
variable_names = []
# Check for unmatched braces
open_braces = macro_line.count('{')
close_braces = macro_line.count('}')
if open_braces != close_braces:
return [], f"Unmatched braces: {open_braces} opening '{{' and {close_braces} closing '}}' braces"
# Split by optional colon separator
var_sections = re.split(r'\s*:\s*', macro_line)
for section in var_sections:
section = section.strip()
if not section:
continue
# Extract variable name
var_matches = re.findall(r'\{([^}]+)\}', section)
for var_name in var_matches:
new_var = var_name.strip()
if new_var not in variable_names:
variable_names.append(new_var)
return variable_names, ""
def extract_variable_values(macro_line):
"""
Extract all variable names and their values from a macro line.
Args:
macro_line (str): A macro line (with or without the leading '!')
Returns:
tuple: (variables_dict, error_message)
- variables_dict: Dictionary mapping variable names to their values
- error_message: Error description if any, empty string if no error
"""
# Remove leading '!' if present
if macro_line.startswith('!'):
macro_line = macro_line[1:].strip()
variables = {}
# Check for unmatched braces
open_braces = macro_line.count('{')
close_braces = macro_line.count('}')
if open_braces != close_braces:
return {}, f"Unmatched braces: {open_braces} opening '{{' and {close_braces} closing '}}' braces"
# Check for unclosed quotes
if macro_line.count('"') % 2 != 0:
return {}, "Unclosed double quotes"
# Split by optional colon separator
var_sections = re.split(r'\s*:\s*', macro_line)
for section in var_sections:
section = section.strip()
if not section:
continue
# Extract variable name
var_match = re.search(r'\{([^}]+)\}', section)
if not var_match:
if '{' in section or '}' in section:
return {}, "Malformed variable declaration"
continue
var_name = var_match.group(1).strip()
if not var_name:
return {}, "Empty variable name"
# Check variable value format
value_part = section[section.find('}')+1:].strip()
if not value_part.startswith('='):
return {}, f"Missing '=' after variable '{{{var_name}}}'"
# Extract all quoted values
var_values = re.findall(r'"([^"]*)"', value_part)
# Check if there are values specified
if not var_values:
return {}, f"No quoted values found for variable '{{{var_name}}}'"
# Check for missing commas between values
if re.search(r'"[^,]*"[^,]*"', value_part):
return {}, f"Missing comma between values for variable '{{{var_name}}}'"
variables[var_name] = var_values
return variables, ""
def generate_macro_line(variables_dict):
"""
Generate a macro line from a dictionary of variable names and their values.
Args:
variables_dict (dict): Dictionary mapping variable names to lists of values
Returns:
str: A formatted macro line (including the leading '!')
"""
sections = []
for var_name, values in variables_dict.items():
# Format each value with quotes
quoted_values = [f'"{value}"' for value in values]
# Join values with commas
values_str = ','.join(quoted_values)
# Create the variable assignment
section = f"{{{var_name}}}={values_str}"
sections.append(section)
# Join sections with a colon and space for readability
return "! " + " : ".join(sections)
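Putting the pieces together, here is a compact, self-contained sketch of how a macro line plus a template line expand into multiple prompts (a simplification of `process_template` above: a single macro block, no error handling):

```python
import re

def expand(text):
    """Toy sketch of the macro prompt language: a line starting with '!'
    declares variables as {name}="v1","v2" (sections separated by ':'),
    '#' lines are comments, and the remaining lines are templates expanded
    once per value index, shorter value lists cycling via modulo."""
    variables, templates, out = {}, [], []
    for raw in text.strip().split("\n"):
        line = raw.strip()
        if not line or line.startswith("#"):
            continue
        if line.startswith("!"):
            for section in re.split(r"\s*:\s*", line[1:].strip()):
                m = re.search(r"\{([^}]+)\}", section)
                if m:  # quoted values after the '=' sign
                    variables[m.group(1).strip()] = re.findall(
                        r'"([^"]*)"', section[section.find("}") + 1:])
        else:
            templates.append(line)
    n = max((len(v) for v in variables.values()), default=1)
    for i in range(n):
        for template in templates:
            result = template
            for name, values in variables.items():
                result = result.replace("{" + name + "}", values[i % len(values)])
            out.append(result)
    return "\n".join(out)

prompts = expand('! {subject}="a cat","a dog" : {place}="the beach"\n'
                 '{subject} is walking on {place}')
# One template line and two subject values yield two prompts.
```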