Lora festival part 2, new macros, new user interface

This commit is contained in:
DeepBeepMeep 2025-03-17 23:43:34 +01:00
parent e554e1a3d6
commit a15060267a
4 changed files with 815 additions and 128 deletions


@ -19,6 +19,13 @@ In this repository, we present **Wan2.1**, a comprehensive and open suite of vid
## 🔥 Latest News!!
* Mar 17, 2025: 👋 Wan2.1GP v2.0: The Lora festival continues:
- Clearer user interface
- Download 30 Loras in one click to try them all (expand the info section)
- Easier Lora usage: Lora presets can now prefill the subject (or other needed terms) of the Lora, so you don't have to edit the prompt manually
- Added a basic macro prompt language to prefill prompts with different values. With one prompt template, you can generate multiple prompts.
- New multiple images prompts: you can now combine any number of images with any number of text prompts (launch the app with --multiple-images)
- New command-line options to launch directly the 1.3B t2v model or the 14B t2v model
* Mar 14, 2025: 👋 Wan2.1GP v1.7:
- Lora Fest special edition: very fast loading / unloading of Loras for those Lora collectors around. You can also now add / remove Loras in the Lora folder without restarting the app. You will need to refresh the requirements *pip install -r requirements.txt*
- Added experimental Skip Layer Guidance (advanced settings), that should improve the image quality at no extra cost. Many thanks to *AmericanPresidentJimmyCarter* for the original implementation
@ -121,7 +128,11 @@ To run the text to video generator (in Low VRAM mode):
```bash
python gradio_server.py
#or
python gradio_server.py --t2v #launch the default text 2 video model
#or
python gradio_server.py --t2v-14B #for the 14B model
#or
python gradio_server.py --t2v-1-3B #for the 1.3B model
```
@ -191,10 +202,27 @@ python gradio_server.py --lora-preset mylorapreset.lset # where 'mylorapreset.l
You will find prebuilt Loras on https://civitai.com/ or you will be able to build them with tools such as kohya or onetrainer.
### Macros (basic)
In *Advanced Mode*, you can start prompt lines with a "!", for instance:\
```
! {Subject}="cat","woman","man", {Location}="forest","lake","city", {Possessive}="its", "her", "his"
In the video, a {Subject} is presented. The {Subject} is in a {Location} and looks at {Possessive} watch.
```
This will automatically create 3 prompts and generate 3 videos:
```
In the video, a cat is presented. The cat is in a forest and looks at its watch.
In the video, a woman is presented. The woman is in a lake and looks at her watch.
In the video, a man is presented. The man is in a city and looks at his watch.
```
You can define multiple lines of macros. If there is only one macro line, the app will generate a simple user interface to enter the macro variables when you switch back to *Normal Mode* (advanced mode turned off).
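The substitution is positional: the i-th value of each variable is used to build the i-th prompt. A minimal sketch of this behavior (not the project's actual parser, which lives in `wan.utils.prompt_parser`; the function name and regex are illustrative assumptions):

```python
import re

def expand_macro(macro_line: str, template: str) -> list[str]:
    """Expand a '!' macro line positionally against a prompt template (illustrative sketch)."""
    variables = {}
    # Capture each {Name}="v1","v2",... group from the macro line
    for name, raw_values in re.findall(r'\{(\w+)\}=((?:"[^"]*",?\s*)+)', macro_line):
        variables[name] = re.findall(r'"([^"]*)"', raw_values)
    count = max(len(v) for v in variables.values())
    prompts = []
    for i in range(count):
        prompt = template
        for name, values in variables.items():
            # i-th value of every variable goes into the i-th prompt
            prompt = prompt.replace("{" + name + "}", values[i % len(values)])
        prompts.append(prompt)
    return prompts
```

With the macro and template above, this yields the three cat / woman / man prompts in order.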
### Command line parameters for Gradio Server
--i2v : launch the image to video generator\
--t2v : launch the text to video generator (default defined in the configuration)\
--t2v-14B : launch the 14B model text to video generator\
--t2v-1-3B : launch the 1.3B model text to video generator\
--quantize-transformer bool: (default True) : enable / disable on the fly transformer quantization\
--lora-dir path : Path of directory that contains Loras in diffusers / safetensor format\
--lora-preset preset : name of preset file (without the extension) to preload
@ -208,7 +236,12 @@ You will find prebuilt Loras on https://civitai.com/ or you will be able to buil
--compile : turn on pytorch compilation\
--attention mode: force attention mode among sdpa, flash, sage, sage2\
--profile no : default (4) : number of a profile between 1 and 5\
--preload no : number in Megabytes to preload partially the diffusion model in VRAM, may offer slight speed gains especially on older hardware. Works only with profiles 2 and 4.\
--seed no : set default seed value\
--frames no : set the default number of frames to generate\
--steps no : set the default number of denoising steps\
--check-loras : filter loras that are incompatible (will take a few seconds while refreshing the lora list or while starting the app)\
--advanced : turn on the advanced mode while launching the app
### Profiles (for power users only)
You can choose between 5 profiles, but two are really relevant here:


@ -20,6 +20,9 @@ import gc
import traceback
import math
import asyncio
from wan.utils import prompt_parser
PROMPT_VARS_MAX = 10
def _parse_args():
parser = argparse.ArgumentParser(
@ -60,7 +63,7 @@ def _parse_args():
parser.add_argument(
"--lora-dir-i2v",
type=str,
default="",
help="Path to a directory that contains Loras for i2v"
)
@ -72,6 +75,13 @@ def _parse_args():
help="Path to a directory that contains Loras"
)
parser.add_argument(
"--check-loras",
type=int,
default=0,
help="Filter Loras that are not valid"
)
parser.add_argument(
"--lora-preset",
@ -101,6 +111,34 @@ def _parse_args():
help="Verbose level"
)
parser.add_argument(
"--steps",
type=int,
default=0,
help="default denoising steps"
)
parser.add_argument(
"--frames",
type=int,
default=0,
help="default number of frames"
)
parser.add_argument(
"--seed",
type=int,
default=-1,
help="default generation seed"
)
parser.add_argument(
"--advanced",
action="store_true",
help="Access advanced options by default"
)
parser.add_argument(
"--server-port",
type=str,
@ -133,6 +171,18 @@ def _parse_args():
help="image to video mode"
)
parser.add_argument(
"--t2v-14B",
action="store_true",
help="text to video mode 14B model"
)
parser.add_argument(
"--t2v-1-3B",
action="store_true",
help="text to video mode 1.3B model"
)
parser.add_argument(
"--compile",
action="store_true",
@ -196,6 +246,13 @@ preload =int(args.preload)
force_profile_no = int(args.profile)
verbose_level = int(args.verbose)
quantizeTransformer = args.quantize_transformer
default_seed = args.seed
default_number_frames = int(args.frames)
if default_number_frames > 0:
default_number_frames = ((default_number_frames - 1) // 4) * 4 + 1
default_inference_steps = args.steps
check_loras = args.check_loras ==1
advanced = args.advanced
transformer_choices_t2v=["ckpts/wan2.1_text2video_1.3B_bf16.safetensors", "ckpts/wan2.1_text2video_14B_bf16.safetensors", "ckpts/wan2.1_text2video_14B_quanto_int8.safetensors"]
transformer_choices_i2v=["ckpts/wan2.1_image2video_480p_14B_bf16.safetensors", "ckpts/wan2.1_image2video_480p_14B_quanto_int8.safetensors", "ckpts/wan2.1_image2video_720p_14B_bf16.safetensors", "ckpts/wan2.1_image2video_720p_14B_quanto_int8.safetensors"]
@ -247,12 +304,26 @@ if args.t2v:
use_image2video = False
if args.i2v:
use_image2video = True
if args.t2v_14B:
use_image2video = False
if not "14B" in transformer_filename_t2v:
transformer_filename_t2v = transformer_choices_t2v[2]
lock_ui_transformer = False
if args.t2v_1_3B:
transformer_filename_t2v = transformer_choices_t2v[0]
use_image2video = False
lock_ui_transformer = False
only_allow_edit_in_advanced = False
lora_dir = args.lora_dir
if use_image2video and len(lora_dir)==0:
lora_dir = args.lora_dir_i2v
if len(lora_dir) ==0:
root_lora_dir = "loras_i2v" if use_image2video else "loras"
else:
root_lora_dir = lora_dir
lora_dir = get_lora_dir(root_lora_dir)
lora_preselected_preset = args.lora_preset
default_tea_cache = 0
@ -378,6 +449,7 @@ def setup_loras(transformer, lora_dir, lora_preselected_preset, split_linear_mo
dir_presets.sort()
loras_presets = [ Path(Path(file_path).parts[-1]).stem for file_path in dir_presets]
if check_loras:
loras = offload.load_loras_into_model(transformer, loras, activate_all_loras=False, check_only= True, split_linear_modules_map = split_linear_modules_map) #lora_multiplier,
if len(loras) > 0:
@ -492,7 +564,9 @@ def get_default_flow(model_filename):
return 3.0 if "480p" in model_filename else 5.0
def generate_header(model_filename, compile, attention_mode):
header = "<div class='title-with-lines'><div class=line></div><h2>"
if "image" in model_filename:
model_name = "Wan2.1 image2video"
@ -508,7 +582,8 @@ def generate_header(model_filename, compile, attention_mode):
if compile:
header += ", pytorch compilation ON"
header += ") </h2><div class=line></div> "
return header
@ -591,19 +666,19 @@ def apply_changes( state,
# return "<DIV ALIGN=CENTER>New Config file created. Please restart the Gradio Server</DIV>"
def update_defaults(state, num_inference_steps, flow_shift, lset_name, loras_choices):
if "config_changes" not in state:
return get_default_flow("")
changes = state["config_changes"]
server_config = state["config_new"]
old_server_config = state["config_old"]
t2v_changed = False
if not use_image2video:
old_is_14B = "14B" in old_server_config["transformer_filename"]
new_is_14B = "14B" in server_config["transformer_filename"]
trans_file = server_config["transformer_filename"]
t2v_changed = old_is_14B != new_is_14B
# num_inference_steps, flow_shift = get_default_flow(trans_file)
else:
old_is_720P = "720P" in server_config["transformer_filename_i2v"]
@ -615,9 +690,11 @@ def update_defaults(state, num_inference_steps,flow_shift):
header = generate_header(trans_file, server_config["compile"], server_config["attention_mode"] )
new_loras_choices = [ (loras_name, str(i)) for i,loras_name in enumerate(loras_names)]
lset_choices = [ (preset, preset) for preset in loras_presets]
lset_choices.append( (get_new_preset_msg(advanced), ""))
if t2v_changed:
return num_inference_steps, flow_shift, header, gr.Dropdown(choices=lset_choices, value= ""), gr.Dropdown(choices=new_loras_choices, value= [])
else:
return num_inference_steps, flow_shift, header, lset_name, loras_choices
from moviepy.editor import ImageSequenceClip
@ -661,9 +738,20 @@ def abort_generation(state):
else:
return gr.Button(interactive= True)
def refresh_gallery(state, txt):
file_list = state.get("file_list", None)
prompt = state.get("prompt", "")
if len(prompt) == 0:
return file_list, gr.Text(visible= False, value="")
else:
prompts_max = state.get("prompts_max",0)
prompt_no = state.get("prompt_no",0)
if prompts_max >1 :
label = f"Current Prompt ({prompt_no+1}/{prompts_max})"
else:
label = f"Current Prompt"
return file_list, gr.Text(visible= True, value=prompt, label=label)
def finalize_gallery(state):
choice = 0
@ -675,7 +763,7 @@ def finalize_gallery(state):
time.sleep(0.2)
global gen_in_progress
gen_in_progress = False
return gr.Gallery(selected_index=choice), gr.Button(interactive= True), gr.Button(visible= True), gr.Checkbox(visible= False), gr.Text(visible= False, value="")
def select_video(state , event_data: gr.EventData):
data= event_data._data
@ -697,7 +785,9 @@ def one_more_video(state):
extra_orders = state.get("extra_orders", 0)
extra_orders += 1
state["extra_orders"] = extra_orders
prompts_max = state.get("prompts_max",0)
if prompts_max == 0:
return state
prompt_no = state["prompt_no"]
video_no = state["video_no"]
total_video = state["total_video"]
@ -730,6 +820,7 @@ def generate_video(
flow_shift,
embedded_guidance_scale,
repeat_generation,
multi_images_gen_type,
tea_cache,
tea_cache_start_step_perc,
loras_choices,
@ -759,8 +850,11 @@ def generate_video(
elif attention_mode in attention_modes_supported:
attn = attention_mode
else:
gr.Info(f"You have selected attention mode '{attention_mode}'. However it is not installed on your system. You should either install it or switch to the default 'sdpa' attention.")
return
if state.get("validate_success",0) != 1:
return
width, height = resolution.split("x")
width, height = int(width), int(height)
@ -768,17 +862,18 @@ def generate_video(
slg_layers = None
if use_image2video:
if "480p" in transformer_filename_i2v and width * height > 848*480:
gr.Info("You must use the 720P image to video model to generate videos with a resolution equivalent to 720P")
return
resolution = str(width) + "*" + str(height)
if resolution not in ['720*1280', '1280*720', '480*832', '832*480']:
gr.Info(f"Resolution {resolution} not supported by image 2 video")
return
else:
if "1.3B" in transformer_filename_t2v and width * height > 848*480:
gr.Info("You must use the 14B text to video model to generate videos with a resolution equivalent to 720P")
return
offload.shared_state["_attention"] = attn
@ -808,6 +903,9 @@ def generate_video(
temp_filename = None
if len(prompt) ==0:
return
prompt, errors = prompt_parser.process_template(prompt)
if len(errors) > 0:
gr.Info("Error processing prompt template: " + errors)
prompts = prompt.replace("\r", "").split("\n")
prompts = [prompt.strip() for prompt in prompts if len(prompt.strip())>0 and not prompt.startswith("#")]
if len(prompts) ==0:
@ -818,6 +916,15 @@ def generate_video(
image_to_continue = [ tup[0] for tup in image_to_continue ]
else:
image_to_continue = [image_to_continue]
if multi_images_gen_type == 0:
new_prompts = []
new_image_to_continue = []
for i in range(len(prompts) * len(image_to_continue) ):
new_prompts.append( prompts[ i % len(prompts)] )
new_image_to_continue.append(image_to_continue[i // len(prompts)] )
prompts = new_prompts
image_to_continue = new_image_to_continue
else:
if len(prompts) >= len(image_to_continue):
if len(prompts) % len(image_to_continue) !=0:
raise gr.Error("If there are more text prompts than input images, the number of text prompts should be divisible by the number of images")
@ -900,6 +1007,10 @@ def generate_video(
# TeaCache
trans.enable_teacache = tea_cache > 0
if trans.enable_teacache:
trans.teacache_multiplier = tea_cache
trans.rel_l1_thresh = 0
trans.teacache_start_step = int(tea_cache_start_step_perc*num_inference_steps/100)
if use_image2video:
if '480p' in transformer_filename_i2v:
# teacache_thresholds = [0.13, .19, 0.26]
@ -935,9 +1046,11 @@ def generate_video(
start_time = time.time()
state["prompts_max"] = len(prompts)
for no, prompt in enumerate(prompts):
state["prompt"] = prompt
repeat_no = 0
state["prompt_no"] = no
extra_generation = 0
yield f"Prompt No{no}"
while True:
extra_orders = state.get("extra_orders",0)
state["extra_orders"] = 0
@ -950,8 +1063,6 @@ def generate_video(
if trans.enable_teacache:
trans.teacache_counter = 0
trans.num_steps = num_inference_steps
trans.teacache_skipped_steps = 0
trans.previous_residual_uncond = None
@ -1035,6 +1146,7 @@ def generate_video(
if any( keyword in frame.name for keyword in keyword_list):
VRAM_crash = True
break
state["prompt"] = ""
if VRAM_crash:
raise gr.Error("The generation of the video has encountered an error: it is likely that you have insufficient VRAM and you should therefore reduce the video resolution or its number of frames.")
else:
@ -1054,6 +1166,7 @@ def generate_video(
if samples == None:
end_time = time.time()
abort = True
state["prompt"] = ""
yield f"Video generation was aborted. Total Generation Time: {end_time-start_time:.1f}s"
else:
sample = samples.cpu()
@ -1079,6 +1192,7 @@ def generate_video(
yield status
else:
end_time = time.time()
state["prompt"] = ""
yield f"Total Generation Time: {end_time-start_time:.1f}s"
seed += 1
repeat_no += 1
@ -1089,18 +1203,22 @@ def generate_video(
offload.unload_loras_from_model(trans)
def get_new_preset_msg(advanced = True):
if advanced:
return "Enter here a Name for a Lora Preset or Choose one in the List"
else:
return "Choose a Lora Preset in this List to Apply a Special Effect"
def validate_delete_lset(lset_name):
if len(lset_name) == 0 or lset_name == get_new_preset_msg(True) or lset_name == get_new_preset_msg(False):
gr.Info("Choose a Preset to delete")
return gr.Button(visible= True), gr.Checkbox(visible= True), gr.Button(visible= True), gr.Button(visible= True), gr.Button(visible= False), gr.Button(visible= False)
else:
return gr.Button(visible= False), gr.Checkbox(visible= False), gr.Button(visible= False), gr.Button(visible= False), gr.Button(visible= True), gr.Button(visible= True)
def validate_save_lset(lset_name):
if len(lset_name) == 0 or lset_name == get_new_preset_msg(True) or lset_name == get_new_preset_msg(False):
gr.Info("Please enter a name for the preset")
return gr.Button(visible= True), gr.Checkbox(visible= True), gr.Button(visible= True), gr.Button(visible= True), gr.Button(visible= False), gr.Button(visible= False),gr.Checkbox(visible= False)
else:
@ -1109,10 +1227,12 @@ def validate_save_lset(lset_name):
def cancel_lset():
return gr.Button(visible= True), gr.Button(visible= True), gr.Button(visible= True), gr.Button(visible= True), gr.Button(visible= False), gr.Button(visible= False), gr.Button(visible= False), gr.Checkbox(visible= False)
def save_lset(state, lset_name, loras_choices, loras_mult_choices, prompt, save_lset_prompt_cbox):
global loras_presets
if state.get("validate_success",0) == 0:
pass
if len(lset_name) == 0 or lset_name == get_new_preset_msg(True) or lset_name == get_new_preset_msg(False):
gr.Info("Please enter a name for the preset")
lset_choices =[("Please enter a name for a Lora Preset","")]
else:
@ -1142,14 +1262,14 @@ def save_lset(lset_name, loras_choices, loras_mult_choices, prompt, save_lset_pr
gr.Info(f"Lora Preset '{lset_name}' has been created")
loras_presets.append(Path(Path(lset_name_filename).parts[-1]).stem )
lset_choices = [ ( preset, preset) for preset in loras_presets ]
lset_choices.append( (get_new_preset_msg(), ""))
return gr.Dropdown(choices=lset_choices, value= lset_name), gr.Button(visible= True), gr.Button(visible= True), gr.Button(visible= True), gr.Button(visible= True), gr.Button(visible= False), gr.Button(visible= False), gr.Checkbox(visible= False)
def delete_lset(lset_name):
global loras_presets
lset_name_filename = os.path.join(lora_dir, sanitize_file_name(lset_name) + ".lset" )
if len(lset_name) > 0 and lset_name != get_new_preset_msg(True) and lset_name != get_new_preset_msg(False):
if not os.path.isfile(lset_name_filename):
raise gr.Error(f"Preset '{lset_name}' not found ")
os.remove(lset_name_filename)
@ -1161,7 +1281,7 @@ def delete_lset(lset_name):
gr.Info("Choose a Preset to delete")
lset_choices = [ (preset, preset) for preset in loras_presets]
lset_choices.append((get_new_preset_msg(), ""))
return gr.Dropdown(choices=lset_choices, value= lset_choices[pos][1]), gr.Button(visible= True), gr.Button(visible= True), gr.Button(visible= True), gr.Button(visible= True), gr.Button(visible= False), gr.Checkbox(visible= False)
def refresh_lora_list(lset_name, loras_choices):
@ -1179,15 +1299,15 @@ def refresh_lora_list(lset_name, loras_choices):
lora_names_selected.append(lora_id)
lset_choices = [ (preset, preset) for preset in loras_presets]
lset_choices.append((get_new_preset_msg(advanced), ""))
if lset_name in loras_presets:
pos = loras_presets.index(lset_name)
else:
pos = len(loras_presets)
lset_name = ""
errors = getattr(wan_model.model, "_loras_errors", "")
if errors != None and len(errors) > 0:
error_files = [path for path, _ in errors]
gr.Info("Error while refreshing Lora List, invalid Lora files: " + ", ".join(error_files))
else:
@ -1196,9 +1316,11 @@ def refresh_lora_list(lset_name, loras_choices):
return gr.Dropdown(choices=lset_choices, value= lset_choices[pos][1]), gr.Dropdown(choices=new_loras_choices, value= lora_names_selected)
def apply_lset(state, lset_name, loras_choices, loras_mult_choices, prompt):
state["apply_success"] = 0
if len(lset_name) == 0 or lset_name == get_new_preset_msg(True) or lset_name == get_new_preset_msg(False):
gr.Info("Please choose a preset in the list or create one")
else:
loras_choices, loras_mult_choices, preset_prompt, full_prompt, error = extract_preset(lset_name, loras)
@ -1213,24 +1335,203 @@ def apply_lset(lset_name, loras_choices, loras_mult_choices, prompt):
prompt = "\n".join(prompts)
prompt = preset_prompt + '\n' + prompt
gr.Info(f"Lora Preset '{lset_name}' has been applied")
state["apply_success"] = 1
state["wizard_prompt"] = 0
return loras_choices, loras_mult_choices, prompt
def extract_prompt_from_wizard(state, prompt, wizard_prompt, allow_null_values, *args):
prompts = wizard_prompt.replace("\r" ,"").split("\n")
new_prompts = []
macro_already_written = False
for prompt in prompts:
if not macro_already_written and not prompt.startswith("#") and "{" in prompt and "}" in prompt:
variables = state["variables"]
values = args[:len(variables)]
macro = "! "
for i, (variable, value) in enumerate(zip(variables, values)):
if len(value) == 0 and not allow_null_values:
return prompt, "You need to provide a value for '" + variable + "'"
sub_values= [ "\"" + sub_value + "\"" for sub_value in value.split("\n") ]
value = ",".join(sub_values)
if i>0:
macro += " : "
macro += "{" + variable + "}"+ f"={value}"
if len(variables) > 0:
macro_already_written = True
new_prompts.append(macro)
new_prompts.append(prompt)
else:
new_prompts.append(prompt)
prompt = "\n".join(new_prompts)
return prompt, ""
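As an aside, the macro line assembled by the loop above can be sketched standalone (the variable names below are illustrative, not taken from the app):

```python
# Standalone sketch of the "!" macro line built by extract_prompt_from_wizard:
# each wizard textbox may contain several lines, and every line becomes one
# quoted, comma-separated value for its variable.
variables = ["subject", "style"]
values = ["a cat\na dog", "watercolor"]

macro = "! "
for i, (variable, value) in enumerate(zip(variables, values)):
    # quote each sub-value coming from a separate textbox line
    quoted = ",".join('"' + sub_value + '"' for sub_value in value.split("\n"))
    if i > 0:
        macro += " : "
    macro += "{" + variable + "}=" + quoted

print(macro)
# ! {subject}="a cat","a dog" : {style}="watercolor"
```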
def validate_wizard_prompt(state, prompt, wizard_prompt, *args):
state["validate_success"] = 0
if state.get("wizard_prompt",0) != 1:
state["validate_success"] = 1
return prompt
prompt, errors = extract_prompt_from_wizard(state, prompt, wizard_prompt, False, *args)
if len(errors) > 0:
gr.Info(errors)
return prompt
state["validate_success"] = 1
return prompt
def fill_prompt_from_wizard(state, prompt, wizard_prompt, *args):
if state.get("wizard_prompt",0) == 1:
prompt, errors = extract_prompt_from_wizard(state, prompt, wizard_prompt, True, *args)
if len(errors) > 0:
gr.Info(errors)
state["wizard_prompt"] = 0
return gr.Textbox(visible= True, value =prompt) , gr.Textbox(visible= False), gr.Column(visible = True), *[gr.Column(visible = False)] * 2, *[gr.Textbox(visible= False)] * PROMPT_VARS_MAX
def extract_wizard_prompt(prompt):
variables = []
values = {}
prompts = prompt.replace("\r" ,"").split("\n")
if sum(prompt.startswith("!") for prompt in prompts) > 1:
return "", variables, values, "Prompt is too complex for basic Prompt editor, switching to Advanced Prompt"
new_prompts = []
errors = ""
for prompt in prompts:
if prompt.startswith("!"):
variables, errors = prompt_parser.extract_variable_names(prompt)
if len(errors) > 0:
return "", variables, values, "Error parsing Prompt template: " + errors
if len(variables) > PROMPT_VARS_MAX:
return "", variables, values, "Prompt is too complex for basic Prompt editor, switching to Advanced Prompt"
values, errors = prompt_parser.extract_variable_values(prompt)
if len(errors) > 0:
return "", variables, values, "Error parsing Prompt template: " + errors
else:
variables_extra, errors = prompt_parser.extract_variable_names(prompt)
if len(errors) > 0:
return "", variables, values, "Error parsing Prompt template: " + errors
variables += variables_extra
variables = [var for pos, var in enumerate(variables) if var not in variables[:pos]]
if len(variables) > PROMPT_VARS_MAX:
return "", variables, values, "Prompt is too complex for basic Prompt editor, switching to Advanced Prompt"
new_prompts.append(prompt)
wizard_prompt = "\n".join(new_prompts)
return wizard_prompt, variables, values, errors
def fill_wizard_prompt(state, prompt, wizard_prompt):
def get_hidden_textboxes(num = PROMPT_VARS_MAX ):
return [gr.Textbox(value="", visible=False)] * num
hidden_column = gr.Column(visible = False)
visible_column = gr.Column(visible = True)
if advanced or state.get("apply_success") != 1:
return prompt, wizard_prompt, gr.Column(), gr.Column(), hidden_column, *get_hidden_textboxes()
prompt_parts= []
state["wizard_prompt"] = 0
wizard_prompt, variables, values, errors = extract_wizard_prompt(prompt)
if len(errors) > 0:
gr.Info( errors )
return gr.Textbox(prompt, visible=True), gr.Textbox(wizard_prompt, visible=False), visible_column, *[hidden_column] * 2, *get_hidden_textboxes()
for variable in variables:
value = values.get(variable, "")
prompt_parts.append(gr.Textbox( placeholder=variable, info= variable, visible= True, value= "\n".join(value) ))
any_macro = len(variables) > 0
prompt_parts += get_hidden_textboxes(PROMPT_VARS_MAX-len(prompt_parts))
state["variables"] = variables
state["wizard_prompt"] = 1
return gr.Textbox(prompt, visible = False), gr.Textbox(wizard_prompt, visible = True), hidden_column, visible_column, visible_column if any_macro else hidden_column, *prompt_parts
def switch_prompt_type(state, prompt, wizard_prompt, *prompt_vars):
if advanced:
return fill_prompt_from_wizard(state, prompt, wizard_prompt, *prompt_vars)
else:
state["apply_success"] = 1
return fill_wizard_prompt(state, prompt, wizard_prompt)
visible= False
def switch_advanced(new_advanced, lset_name):
global advanced
advanced= new_advanced
lset_choices = [ (preset, preset) for preset in loras_presets]
lset_choices.append((get_new_preset_msg(advanced), ""))
if lset_name== get_new_preset_msg(True) or lset_name== get_new_preset_msg(False) or lset_name=="":
lset_name = get_new_preset_msg(advanced)
if only_allow_edit_in_advanced:
return gr.Row(visible=new_advanced), gr.Row(visible=new_advanced), gr.Button(visible=new_advanced), gr.Row(visible= not new_advanced), gr.Dropdown(choices=lset_choices, value= lset_name)
else:
return gr.Row(visible=new_advanced), gr.Row(visible=True), gr.Button(visible=True), gr.Row(visible= False), gr.Dropdown(choices=lset_choices, value= lset_name)
def download_loras():
from huggingface_hub import snapshot_download
yield "<B><FONT SIZE=3>Please wait while the Loras are being downloaded</FONT></B>", *[gr.Column(visible=False)] * 2
log_path = os.path.join(lora_dir, "log.txt")
if not os.path.isfile(log_path):
import shutil, glob
tmp_path = os.path.join(lora_dir, "tmp_lora_download")
snapshot_download(repo_id="DeepBeepMeep/Wan2.1", allow_patterns="loras_i2v/*", local_dir= tmp_path)
[shutil.move(f, lora_dir) for f in glob.glob(os.path.join(tmp_path, "loras_i2v", "*.*")) if "README.txt" not in f ]
yield "<B><FONT SIZE=3>Loras have been completely downloaded</FONT></B>", *[gr.Column(visible=True)] * 2
from datetime import datetime
dt = datetime.today().strftime('%Y-%m-%d')
with open( log_path, "w", encoding="utf-8") as writer:
writer.write(f"Loras downloaded on {dt}")
return
def create_demo():
css= """
.title-with-lines {
display: flex;
align-items: center;
margin: 30px 0;
}
.line {
flex-grow: 1;
height: 1px;
background-color: #333;
}
h2 {
margin: 0 20px;
white-space: nowrap;
}
"""
default_flow_shift = get_default_flow(transformer_filename_i2v if use_image2video else transformer_filename_t2v)
with gr.Blocks(css=css, theme=gr.themes.Soft(primary_hue="emerald", neutral_hue="slate", text_size= "md")) as demo:
state_dict = {}
if use_image2video:
gr.Markdown("<div align=center><H1>Wan 2.1<SUP>GP</SUP> v2.0 - Image To Video <FONT SIZE=4>by <I>DeepBeepMeep</I></FONT> <FONT SIZE=3>(<A HREF='https://github.com/deepbeepmeep/Wan2GP'>Updates</A> / <A HREF='https://github.com/Wan-Video/Wan2.1'>Original by Alibaba</A>)</FONT></H1></div>")
else:
gr.Markdown("<div align=center><H1>Wan 2.1<SUP>GP</SUP> v2.0 - Text To Video <FONT SIZE=4>by <I>DeepBeepMeep</I></FONT> <FONT SIZE=3>(<A HREF='https://github.com/deepbeepmeep/Wan2GP'>Updates</A> / <A HREF='https://github.com/Wan-Video/Wan2.1'>Original by Alibaba</A>)</FONT></H1></div>")
gr.Markdown("<FONT SIZE=3>Welcome to Wan 2.1GP, a super fast and low VRAM AI Video Generator!</FONT>")
with gr.Accordion("Click here for some Info on how to use Wan2GP and to download 20+ Loras", open = False):
if use_image2video and False:
pass
else:
@ -1241,6 +1542,14 @@ def create_demo():
gr.Markdown("It is not recommended to generate a video longer than 8s (128 frames) even if there is still some VRAM left, as some artifacts may appear")
gr.Markdown("Please note that if you turn on compilation, the first denoising step of the first video generation will be slow due to the compilation. Therefore all your tests should be done with compilation turned off.")
if use_image2video:
with gr.Row():
with gr.Row(scale =3):
gr.Markdown("<I>Wan2GP's Lora Festival ! Press the following button to download the i2v <B>Remade</B> Loras collection (and bonus Loras). Don't forget to make a backup of your Loras first, just in case.</I>")
with gr.Row(scale =1):
download_loras_btn = gr.Button("---> Let the Lora's Festival Start !", scale =1)
with gr.Row():
download_status = gr.Markdown()
# css = """<STYLE>
# h2 { width: 100%; text-align: center; border-bottom: 1px solid #000; line-height: 0.1em; margin: 10px 0 20px; }
@ -1370,8 +1679,36 @@ def create_demo():
apply_btn = gr.Button("Apply Changes")
with gr.Row():
with gr.Column():
with gr.Row(visible= len(loras)>0) as presets_column:
lset_choices = [ (preset, preset) for preset in loras_presets ] + [(get_new_preset_msg(advanced), "")]
with gr.Column(scale=6):
lset_name = gr.Dropdown(show_label=False, allow_custom_value= True, scale=5, filterable=True, choices= lset_choices, value=default_lora_preset)
with gr.Column(scale=1):
# with gr.Column():
with gr.Row(height=17):
apply_lset_btn = gr.Button("Apply Lora Preset", size="sm", min_width= 1)
refresh_lora_btn = gr.Button("Refresh", size="sm", min_width= 1, visible=advanced or not only_allow_edit_in_advanced)
# save_lset_prompt_cbox = gr.Checkbox(label="Save Prompt Comments in Preset", value=False, visible= False)
save_lset_prompt_drop= gr.Dropdown(
choices=[
("Save Prompt Comments Only", 0),
("Save Full Prompt", 1)
], show_label= False, container=False, value =1, visible= False
)
with gr.Row(height=17, visible=False) as refresh2_row:
refresh_lora_btn2 = gr.Button("Refresh", size="sm", min_width= 1)
with gr.Row(height=17, visible=advanced or not only_allow_edit_in_advanced) as preset_buttons_rows:
confirm_save_lset_btn = gr.Button("Go Ahead Save it !", size="sm", min_width= 1, visible=False)
confirm_delete_lset_btn = gr.Button("Go Ahead Delete it !", size="sm", min_width= 1, visible=False)
save_lset_btn = gr.Button("Save", size="sm", min_width= 1)
delete_lset_btn = gr.Button("Delete", size="sm", min_width= 1)
cancel_lset_btn = gr.Button("Don't do it !", size="sm", min_width= 1 , visible=False)
video_to_continue = gr.Video(label= "Video to continue", visible= use_image2video and False) #######
if args.multiple_images:
image_to_continue = gr.Gallery(
@ -1380,7 +1717,29 @@ def create_demo():
else:
image_to_continue = gr.Image(label= "Image as a starting point for a new video", type ="pil", visible=use_image2video)
advanced_prompt = advanced
prompt_vars=[]
if not advanced_prompt:
default_wizard_prompt, variables, values, errors = extract_wizard_prompt(default_prompt)
advanced_prompt = len(errors) > 0
with gr.Column(visible= advanced_prompt) as prompt_column_advanced: #visible= False
prompt = gr.Textbox( visible= advanced_prompt, label="Prompts (each new line of prompt will generate a new video, # lines = comments, ! lines = macros)", value=default_prompt, lines=3)
with gr.Column(visible=not advanced_prompt and len(variables) > 0) as prompt_column_wizard_vars: #visible= False
gr.Markdown("<B>Please fill in the following fields to automatically adapt the Prompt:</B>")
with gr.Row(): #visible= not advanced_prompt and len(variables) > 0
if not advanced_prompt:
for variable in variables:
value = values.get(variable, "")
prompt_vars.append(gr.Textbox( placeholder=variable, min_width=80, show_label= False, info= variable, visible= True, value= "\n".join(value) ))
state_dict["wizard_prompt"] = 1
state_dict["variables"] = variables
for _ in range( PROMPT_VARS_MAX - len(prompt_vars)):
prompt_vars.append(gr.Textbox(visible= False, min_width=80, show_label= False))
with gr.Column(visible = not advanced_prompt) as prompt_column_wizard:
wizard_prompt = gr.Textbox(visible = not advanced_prompt, label="Prompts (each new line of prompt will generate a new video, # lines = comments)", value=default_wizard_prompt, lines=3)
state = gr.State(state_dict)
with gr.Row():
if use_image2video:
@ -1417,40 +1776,37 @@ def create_demo():
with gr.Row():
with gr.Column():
video_length = gr.Slider(5, 193, value=default_number_frames if default_number_frames > 0 else 81, step=4, label="Number of frames (16 = 1s)")
with gr.Column():
num_inference_steps = gr.Slider(1, 100, value= default_inference_steps if default_inference_steps > 0 else 30, step=1, label="Number of Inference Steps")
with gr.Row():
max_frames = gr.Slider(1, 100, value=9, step=1, label="Number of input frames to use for Video2World prediction", visible=use_image2video and False) #########
show_advanced = gr.Checkbox(label="Advanced Mode", value=advanced)
with gr.Row(visible=advanced) as advanced_row:
with gr.Column():
seed = gr.Slider(-1, 999999999, value=default_seed, step=1, label="Seed (-1 for random)")
with gr.Row():
repeat_generation = gr.Slider(1, 25.0, value=1.0, step=1, label="Default Number of Generated Videos per Prompt")
multi_images_gen_type = gr.Dropdown(
choices=[
("Generate every combination of images and texts prompts", 0),
("Match images and text prompts", 1),
], visible= args.multiple_images, label= "Multiple Images as Prompts"
)
with gr.Row():
guidance_scale = gr.Slider(1.0, 20.0, value=5.0, step=0.5, label="Guidance Scale", visible=True)
embedded_guidance_scale = gr.Slider(1.0, 20.0, value=6.0, step=0.5, label="Embedded Guidance Scale", visible=False)
flow_shift = gr.Slider(0.0, 25.0, value= default_flow_shift, step=0.1, label="Shift Scale")
with gr.Row():
negative_prompt = gr.Textbox(label="Negative Prompt", value="")
with gr.Row():
gr.Markdown("<B>Loras can be used to create special effects on the video by mentioning a trigger word in the Prompt. You can save Lora combinations in presets.</B>")
with gr.Column() as loras_column:
loras_choices = gr.Dropdown(
choices=[
(lora_name, str(i) ) for i, lora_name in enumerate(loras_names)
@ -1462,19 +1818,9 @@ def create_demo():
)
loras_mult_choices = gr.Textbox(label="Loras Multipliers (1.0 by default) separated by space characters or carriage returns, lines that start with # are ignored", value=default_loras_multis_str, visible= len(loras)>0 )
with gr.Row():
gr.Markdown("<B>Tea Cache speeds up generation by intelligently skipping some steps; the more steps are skipped, the lower the video quality (Tea Cache also consumes VRAM)</B>")
with gr.Row():
tea_cache_setting = gr.Dropdown(
choices=[
@ -1491,6 +1837,7 @@ def create_demo():
)
tea_cache_start_step_perc = gr.Slider(0, 100, value=0, step=1, label="Tea Cache starting moment in % of generation")
gr.Markdown("<B>With Riflex you can generate videos longer than 5s, which is the default duration of videos used to train the model</B>")
RIFLEx_setting = gr.Dropdown(
choices=[
("Auto (ON if Video longer than 5s)", 0),
@ -1503,7 +1850,7 @@ def create_demo():
with gr.Row():
gr.Markdown("<B>Experimental: Skip Layer guidance, should improve video quality</B>")
with gr.Row():
slg_switch = gr.Dropdown(
choices=[
@ -1529,33 +1876,43 @@ def create_demo():
slg_end_perc = gr.Slider(0, 100, value=90, step=1, label="Denoising Steps % end")
show_advanced.change(fn=switch_advanced, inputs=[show_advanced, lset_name], outputs=[advanced_row, preset_buttons_rows, refresh_lora_btn, refresh2_row ,lset_name ]).then(
fn=switch_prompt_type, inputs = [state, prompt, wizard_prompt, *prompt_vars], outputs = [prompt, wizard_prompt, prompt_column_advanced, prompt_column_wizard, prompt_column_wizard_vars, *prompt_vars])
with gr.Column():
gen_status = gr.Text(label="Status", interactive= False)
output = gr.Gallery(
label="Generated videos", show_label=False, elem_id="gallery"
, columns=[3], rows=[1], object_fit="contain", height=450, selected_index=0, interactive= False)
generate_btn = gr.Button("Generate")
onemore_btn = gr.Button("One More Please !", visible= False)
abort_btn = gr.Button("Abort")
gen_info = gr.Text(label="Current prompt", visible= False , interactive= False) #gr.Markdown("Current prompt") #, ,
save_lset_btn.click(validate_save_lset, inputs=[lset_name], outputs=[apply_lset_btn, refresh_lora_btn, delete_lset_btn, save_lset_btn,confirm_save_lset_btn, cancel_lset_btn, save_lset_prompt_drop])
confirm_save_lset_btn.click(fn=validate_wizard_prompt, inputs =[state, prompt, wizard_prompt, *prompt_vars] , outputs= [prompt]).then(
save_lset, inputs=[state, lset_name, loras_choices, loras_mult_choices, prompt, save_lset_prompt_drop], outputs=[lset_name, apply_lset_btn,refresh_lora_btn, delete_lset_btn, save_lset_btn, confirm_save_lset_btn, cancel_lset_btn, save_lset_prompt_drop])
delete_lset_btn.click(validate_delete_lset, inputs=[lset_name], outputs=[apply_lset_btn, refresh_lora_btn, delete_lset_btn, save_lset_btn,confirm_delete_lset_btn, cancel_lset_btn ])
confirm_delete_lset_btn.click(delete_lset, inputs=[lset_name], outputs=[lset_name, apply_lset_btn, refresh_lora_btn, delete_lset_btn, save_lset_btn,confirm_delete_lset_btn, cancel_lset_btn ])
cancel_lset_btn.click(cancel_lset, inputs=[], outputs=[apply_lset_btn, refresh_lora_btn, delete_lset_btn, save_lset_btn, confirm_delete_lset_btn,confirm_save_lset_btn, cancel_lset_btn,save_lset_prompt_drop ])
apply_lset_btn.click(apply_lset, inputs=[state, lset_name,loras_choices, loras_mult_choices, prompt], outputs=[loras_choices, loras_mult_choices, prompt]).then(
fn = fill_wizard_prompt, inputs = [state, prompt, wizard_prompt], outputs = [ prompt, wizard_prompt, prompt_column_advanced, prompt_column_wizard, prompt_column_wizard_vars, *prompt_vars]
)
refresh_lora_btn.click(refresh_lora_list, inputs=[lset_name,loras_choices], outputs=[lset_name, loras_choices])
refresh_lora_btn2.click(refresh_lora_list, inputs=[lset_name,loras_choices], outputs=[lset_name, loras_choices])
download_loras_btn.click(fn=download_loras, inputs=[], outputs=[download_status, presets_column, loras_column]).then(fn=refresh_lora_list, inputs=[lset_name,loras_choices], outputs=[lset_name, loras_choices])
gen_status.change(refresh_gallery, inputs = [state, gen_info], outputs = [output, gen_info] )
abort_btn.click(abort_generation,state,abort_btn )
output.select(select_video, state, None )
onemore_btn.click(fn=one_more_video,inputs=[state], outputs= [state])
generate_btn.click(fn=prepare_generate_video,inputs=[], outputs= [generate_btn, onemore_btn]
).then(
fn=validate_wizard_prompt, inputs =[state, prompt, wizard_prompt, *prompt_vars] , outputs= [prompt]
).then(
fn=generate_video,
inputs=[
prompt,
@ -1568,6 +1925,7 @@ def create_demo():
flow_shift,
embedded_guidance_scale,
repeat_generation,
multi_images_gen_type,
tea_cache_setting,
tea_cache_start_step_perc,
loras_choices,
@ -1587,7 +1945,7 @@ def create_demo():
).then(
finalize_gallery,
[state],
[output , abort_btn, generate_btn, onemore_btn, gen_info]
)
apply_btn.click(
@ -1607,7 +1965,7 @@ def create_demo():
outputs= msg
).then(
update_defaults,
[state, num_inference_steps, flow_shift,lset_name , loras_choices],
[num_inference_steps, flow_shift, header, lset_name , loras_choices ]
)


@ -676,6 +676,7 @@ class WanModel(ModelMixin, ConfigMixin):
best_threshold = 0.01
best_diff = 1000
best_signed_diff = 1000
target_nb_steps= int(len(timesteps) / speed_factor)
threshold = 0.01
while threshold <= 0.6:
@ -686,6 +687,8 @@ class WanModel(ModelMixin, ConfigMixin):
skip = False
if not (i<=start_step or i== len(timesteps)):
accumulated_rel_l1_distance += rescale_func(((e_list[i]-previous_modulated_input).abs().mean() / previous_modulated_input.abs().mean()).cpu().item())
# self.accumulated_rel_l1_distance_even += rescale_func(((e_list[i]-self.previous_e0_even).abs().mean() / self.previous_e0_even.abs().mean()).cpu().item())
if accumulated_rel_l1_distance < threshold:
skip = True
else:
@ -693,15 +696,17 @@ class WanModel(ModelMixin, ConfigMixin):
previous_modulated_input = e_list[i]
if not skip:
nb_steps += 1
signed_diff = target_nb_steps - nb_steps
diff = abs(signed_diff)
if diff < best_diff:
best_threshold = threshold
best_diff = diff
best_signed_diff = signed_diff
elif diff > best_diff:
break
threshold += 0.01
self.rel_l1_thresh = best_threshold
print(f"Tea Cache, best threshold found:{best_threshold:0.2f} with gain x{len(timesteps)/(target_nb_steps - best_signed_diff):0.2f} for a target of x{speed_factor}")
return best_threshold
def forward(

wan/utils/prompt_parser.py (new file, 291 lines)

@ -0,0 +1,291 @@
import re
def process_template(input_text):
"""
Process a text template with macro instructions and variable substitution.
Supports multiple values for variables to generate multiple output versions.
Each section between macro lines is treated as a separate template.
Args:
input_text (str): The input template text
Returns:
tuple: (output_text, error_message)
- output_text: Processed output with variables substituted, or empty string if error
- error_message: Error description and problematic line, or empty string if no error
"""
lines = input_text.strip().split('\n')
current_variables = {}
current_template_lines = []
all_output_lines = []
error_message = ""
# Process the input line by line
line_number = 0
while line_number < len(lines):
orig_line = lines[line_number]
line = orig_line.strip()
line_number += 1
# Skip empty lines or comments
if not line or line.startswith('#'):
continue
# Handle macro instructions
if line.startswith('!'):
# Process any accumulated template lines before starting a new macro
if current_template_lines:
# Process the current template with current variables
template_output, err = process_current_template(current_template_lines, current_variables)
if err:
return "", err
all_output_lines.extend(template_output)
current_template_lines = [] # Reset template lines
# Reset variables for the new macro
current_variables = {}
# Parse the macro line
macro_line = line[1:].strip()
# Check for unmatched braces in the whole line
open_braces = macro_line.count('{')
close_braces = macro_line.count('}')
if open_braces != close_braces:
error_message = f"Unmatched braces: {open_braces} opening '{{' and {close_braces} closing '}}' braces\nLine: '{orig_line}'"
return "", error_message
# Check for unclosed quotes
if macro_line.count('"') % 2 != 0:
error_message = f"Unclosed double quotes\nLine: '{orig_line}'"
return "", error_message
# Split by optional colon separator
var_sections = re.split(r'\s*:\s*', macro_line)
for section in var_sections:
section = section.strip()
if not section:
continue
# Extract variable name
var_match = re.search(r'\{([^}]+)\}', section)
if not var_match:
if '{' in section or '}' in section:
error_message = f"Malformed variable declaration\nLine: '{orig_line}'"
return "", error_message
continue
var_name = var_match.group(1).strip()
if not var_name:
error_message = f"Empty variable name\nLine: '{orig_line}'"
return "", error_message
# Check variable value format
value_part = section[section.find('}')+1:].strip()
if not value_part.startswith('='):
error_message = f"Missing '=' after variable '{{{var_name}}}'\nLine: '{orig_line}'"
return "", error_message
# Extract all quoted values
var_values = re.findall(r'"([^"]*)"', value_part)
# Check if there are values specified
if not var_values:
error_message = f"No quoted values found for variable '{{{var_name}}}'\nLine: '{orig_line}'"
return "", error_message
# Check for missing commas between values
# Look for patterns like "value""value" (missing comma)
if re.search(r'"[^,]*"[^,]*"', value_part):
error_message = f"Missing comma between values for variable '{{{var_name}}}'\nLine: '{orig_line}'"
return "", error_message
# Store the variable values
current_variables[var_name] = var_values
# Handle template lines
else:
# Check for unknown variables in template line
var_references = re.findall(r'\{([^}]+)\}', line)
for var_ref in var_references:
if var_ref not in current_variables:
error_message = f"Unknown variable '{{{var_ref}}}' in template\nLine: '{orig_line}'"
return "", error_message
# Add to current template lines
current_template_lines.append(line)
# Process any remaining template lines
if current_template_lines:
template_output, err = process_current_template(current_template_lines, current_variables)
if err:
return "", err
all_output_lines.extend(template_output)
return '\n'.join(all_output_lines), ""
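The declaration-line sanity checks above (brace balance, quote parity) are cheap string scans that run before any regex parsing; a minimal standalone sketch, with illustrative input lines:

```python
def lint_macro_line(macro_line):
    """Return an error string for obviously malformed declarations, else ''."""
    # Every '{' must have a matching '}'
    if macro_line.count('{') != macro_line.count('}'):
        return "Unmatched braces"
    # Double quotes must come in pairs
    if macro_line.count('"') % 2 != 0:
        return "Unclosed double quotes"
    return ""

print(lint_macro_line('{subject}="cat","dog"'))   # → '' (well formed)
print(lint_macro_line('{subject="cat"'))          # → 'Unmatched braces'
print(lint_macro_line('{subject}="cat'))          # → 'Unclosed double quotes'
```

Counting characters cannot catch every malformation (e.g. `}{` balances), which is why the section-level regex checks follow; the counts just fail fast on the common typos.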
def process_current_template(template_lines, variables):
"""
Process a set of template lines with the current variables.
Args:
template_lines (list): List of template lines to process
variables (dict): Dictionary of variable names to lists of values
Returns:
tuple: (output_lines, error_message)
"""
if not variables or not template_lines:
return template_lines, ""
output_lines = []
# Find the maximum number of values for any variable
max_values = max(len(values) for values in variables.values())
# Generate each combination
for i in range(max_values):
for template in template_lines:
output_line = template
for var_name, var_values in variables.items():
# Use modulo to cycle through values if needed
value_index = i % len(var_values)
var_value = var_values[value_index]
output_line = output_line.replace(f"{{{var_name}}}", var_value)
output_lines.append(output_line)
return output_lines, ""
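The expansion rule implemented above can be exercised standalone: every template line is emitted once per value index, and shorter value lists cycle via modulo against the longest list. A sketch with illustrative variable names:

```python
def expand(template_lines, variables):
    """Expand {name} placeholders against value lists, cycling shorter lists."""
    if not variables or not template_lines:
        return list(template_lines)
    output = []
    # The longest value list sets the number of generated combinations
    max_values = max(len(v) for v in variables.values())
    for i in range(max_values):
        for template in template_lines:
            line = template
            for name, values in variables.items():
                # Modulo lets a short list repeat against a longer one
                line = line.replace(f"{{{name}}}", values[i % len(values)])
            output.append(line)
    return output

# Two values for {animal}, one for {color}: the single color repeats.
print(expand(["a {color} {animal}"], {"animal": ["cat", "dog"], "color": ["red"]}))
# → ['a red cat', 'a red dog']
```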
def extract_variable_names(macro_line):
"""
Extract all variable names from a macro line.
Args:
macro_line (str): A macro line (with or without the leading '!')
Returns:
tuple: (variable_names, error_message)
- variable_names: List of variable names found in the macro
- error_message: Error description if any, empty string if no error
"""
# Remove leading '!' if present
if macro_line.startswith('!'):
macro_line = macro_line[1:].strip()
variable_names = []
# Check for unmatched braces
open_braces = macro_line.count('{')
close_braces = macro_line.count('}')
if open_braces != close_braces:
return [], f"Unmatched braces: {open_braces} opening '{{' and {close_braces} closing '}}' braces"
# Split by optional colon separator
var_sections = re.split(r'\s*:\s*', macro_line)
for section in var_sections:
section = section.strip()
if not section:
continue
# Extract variable name
var_matches = re.findall(r'\{([^}]+)\}', section)
for var_name in var_matches:
new_var = var_name.strip()
            if new_var not in variable_names:
variable_names.append(new_var)
return variable_names, ""
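Name extraction boils down to one regex pass with order-preserving deduplication; the colon splitting above does not change which names are found, so a self-contained sketch (macro text is illustrative) can skip it:

```python
import re

def names_in(macro_line):
    """Collect unique {variable} names in first-appearance order."""
    if macro_line.startswith('!'):
        macro_line = macro_line[1:].strip()
    names = []
    # One pass over the whole line finds every {...} group
    for raw in re.findall(r'\{([^}]+)\}', macro_line):
        name = raw.strip()
        if name not in names:
            names.append(name)
    return names

print(names_in('! {subject}="cat" : {place}="roof","yard" : {subject}="dog"'))
# → ['subject', 'place']
```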
def extract_variable_values(macro_line):
"""
Extract all variable names and their values from a macro line.
Args:
macro_line (str): A macro line (with or without the leading '!')
Returns:
tuple: (variables_dict, error_message)
- variables_dict: Dictionary mapping variable names to their values
- error_message: Error description if any, empty string if no error
"""
# Remove leading '!' if present
if macro_line.startswith('!'):
macro_line = macro_line[1:].strip()
variables = {}
# Check for unmatched braces
open_braces = macro_line.count('{')
close_braces = macro_line.count('}')
if open_braces != close_braces:
return {}, f"Unmatched braces: {open_braces} opening '{{' and {close_braces} closing '}}' braces"
# Check for unclosed quotes
if macro_line.count('"') % 2 != 0:
return {}, "Unclosed double quotes"
# Split by optional colon separator
var_sections = re.split(r'\s*:\s*', macro_line)
for section in var_sections:
section = section.strip()
if not section:
continue
# Extract variable name
var_match = re.search(r'\{([^}]+)\}', section)
if not var_match:
if '{' in section or '}' in section:
return {}, "Malformed variable declaration"
continue
var_name = var_match.group(1).strip()
if not var_name:
return {}, "Empty variable name"
# Check variable value format
value_part = section[section.find('}')+1:].strip()
if not value_part.startswith('='):
return {}, f"Missing '=' after variable '{{{var_name}}}'"
# Extract all quoted values
var_values = re.findall(r'"([^"]*)"', value_part)
# Check if there are values specified
if not var_values:
return {}, f"No quoted values found for variable '{{{var_name}}}'"
# Check for missing commas between values
if re.search(r'"[^,]*"[^,]*"', value_part):
return {}, f"Missing comma between values for variable '{{{var_name}}}'"
variables[var_name] = var_values
return variables, ""
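The two regexes doing the heavy lifting above can be checked in isolation: one pulls quoted values out of the part after `=`, the other probes for two quoted values with no comma anywhere between them. A sketch with illustrative input:

```python
import re

value_part = '= "a cat", "a dog"'
# Quoted values are extracted regardless of surrounding whitespace
print(re.findall(r'"([^"]*)"', value_part))   # → ['a cat', 'a dog']

# The missing-comma probe matches only when two quoted values are
# adjacent with no comma separating them ([^,]* cannot cross a comma)
assert re.search(r'"[^,]*"[^,]*"', '= "a""b"')
assert not re.search(r'"[^,]*"[^,]*"', '= "a","b"')
```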
def generate_macro_line(variables_dict):
"""
Generate a macro line from a dictionary of variable names and their values.
Args:
variables_dict (dict): Dictionary mapping variable names to lists of values
Returns:
str: A formatted macro line (including the leading '!')
"""
sections = []
for var_name, values in variables_dict.items():
# Format each value with quotes
quoted_values = [f'"{value}"' for value in values]
# Join values with commas
values_str = ','.join(quoted_values)
# Create the variable assignment
section = f"{{{var_name}}}={values_str}"
sections.append(section)
# Join sections with a colon and space for readability
return "! " + " : ".join(sections)
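For reference, the serialization shape produced by the generator above, sketched standalone (the variable names are illustrative); its output parses back with the extraction logic earlier in the file:

```python
def make_macro(variables_dict):
    """Serialize {name: [values]} back into the '!' macro syntax."""
    sections = []
    for name, values in variables_dict.items():
        # Each value is double-quoted and comma-joined
        values_str = ','.join(f'"{v}"' for v in values)
        sections.append(f"{{{name}}}={values_str}")
    # Sections are joined with ' : ' after the leading '!'
    return "! " + " : ".join(sections)

print(make_macro({"animal": ["cat", "dog"], "color": ["red"]}))
# → ! {animal}="cat","dog" : {color}="red"
```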