Added 10% boost, improved Loras and Teacache

This commit is contained in:
DeepBeepMeep 2025-03-10 23:26:42 +01:00
parent f9ce97a1ba
commit f8d9edeb50
5 changed files with 372 additions and 165 deletions

View File

@ -19,10 +19,11 @@ In this repository, we present **Wan2.1**, a comprehensive and open suite of vid
## 🔥 Latest News!!
* Mar 03, 2025: 👋 Wan2.1GP v1.4: Fix Pytorch compilation, now it is really 20% faster when activated
* Mar 03, 2025: 👋 Wan2.1GP v1.3: Support for Image to Video with multiples images for different images / prompts combinations (requires *--multiple-images* switch), and added command line *--preload x* to preload in VRAM x MB of the main diffusion model if you find there is too much unused VRAM and you want to (slightly) accelerate the generation process.
* Mar 10, 2025: 👋 Wan2.1GP v1.5: Official Teachache support + Smart Teacache (find automatically best parameters for a requested speed multiplier), 10% speed boost with no quality loss, improved lora presets (they can now include prompts and comments to guide the user)
* Mar 07, 2025: 👋 Wan2.1GP v1.4: Fix Pytorch compilation, now it is really 20% faster when activated
* Mar 04, 2025: 👋 Wan2.1GP v1.3: Support for Image to Video with multiples images for different images / prompts combinations (requires *--multiple-images* switch), and added command line *--preload x* to preload in VRAM x MB of the main diffusion model if you find there is too much unused VRAM and you want to (slightly) accelerate the generation process.
If you upgrade you will need to do a 'pip install -r requirements.txt' again.
* Mar 03, 2025: 👋 Wan2.1GP v1.2: Implemented tiling on VAE encoding and decoding. No more VRAM peaks at the beginning and at the end
* Mar 04, 2025: 👋 Wan2.1GP v1.2: Implemented tiling on VAE encoding and decoding. No more VRAM peaks at the beginning and at the end
* Mar 03, 2025: 👋 Wan2.1GP v1.1: added Tea Cache support for faster generations: optimization of kijai's implementation (https://github.com/kijai/ComfyUI-WanVideoWrapper/) of teacache (https://github.com/ali-vilab/TeaCache)
* Mar 02, 2025: 👋 Wan2.1GP by DeepBeepMeep v1 brings:
- Support for all Wan including the Image to Video model
@ -152,15 +153,29 @@ python gradio_server.py --attention sdpa
### Loras support
-- Ready to be used but theoretical as no lora for Wan have been released as of today. ---
Every lora stored in the subfoler 'loras' will be automatically loaded. You will be then able to activate / desactive any of them when running the application.
Every lora stored in the subfoler 'loras' for t2v and 'loras_i2v' will be automatically loaded. You will be then able to activate / desactive any of them when running the application by selecting them in the area below "Activated Loras" .
For each activated Lora, you may specify a *multiplier* that is one float number that corresponds to its weight (default is 1.0), alternatively you may specify a list of floats multipliers separated by a "," that gives the evolution of this Lora's multiplier over the steps. For instance let's assume there are 30 denoising steps and the multiplier is *0.9,0.8,0.7* then for the steps ranges 0-9, 10-19 and 20-29 the Lora multiplier will be respectively 0.9, 0.8 and 0.7.
For each activated Lora, you may specify a *multiplier* that is one float number that corresponds to its weight (default is 1.0) .The multipliers for each Lora shoud be separated by a space character or a carriage return. For instance:\
*1.2 0.8* means that the first lora will have a 1.2 multiplier and the second one will have 0.8.
You can edit, save or delete Loras presets (combinations of loras with their corresponding multipliers) directly from the gradio interface. Each preset, is a file with ".lset" extension stored in the loras directory and can be shared with other users
Alternatively for each Lora's multiplier you may specify a list of float numbers multipliers separated by a "," (no space) that gives the evolution of this Lora's multiplier over the steps. For instance let's assume there are 30 denoising steps and the multiplier is *0.9,0.8,0.7* then for the steps ranges 0-9, 10-19 and 20-29 the Lora multiplier will be respectively 0.9, 0.8 and 0.7.
Then you can pre activate loras corresponding to a preset when launching the gradio server:
If multiple Loras are defined, remember that each multiplier associated to different Loras should be separated by a space or a carriage return, so we can specify the evolution of multipliers for multiple Loras. For instance for two Loras (press Shift Return to force a carriage return):
```
0.9,0.8,0.7
1.2,1.1,1.0
```
You can edit, save or delete Loras presets (combinations of loras with their corresponding multipliers) directly from the gradio Web interface. These presets will save the *comment* part of the prompt that should contain some instructions how to use the corresponding the loras (for instance by specifying a trigger word or providing an example).A comment in the prompt is a line that starts that a #. It will be ignored by the video generator. For instance:
```
# use they keyword ohnvx to trigger the Lora*
A ohnvx is driving a car
```
Each preset, is a file with ".lset" extension stored in the loras directory and can be shared with other users
Last but not least you can pre activate Loras corresponding and prefill a prompt (comments only or full prompt) by specifying a preset when launching the gradio server:
```bash
python gradio_server.py --lora-preset mylorapreset.lset # where 'mylorapreset.lset' is a preset stored in the 'loras' folder
```
@ -180,11 +195,11 @@ You will find prebuilt Loras on https://civitai.com/ or you will be able to buil
--open-browser : open automatically Browser when launching Gradio Server\
--lock-config : prevent modifying the video engine configuration from the interface\
--share : create a shareable URL on huggingface so that your server can be accessed remotely\
--multiple-images : Images as a starting point for new videos\
--multiple-images : allow the users to choose multiple images as different starting points for new videos\
--compile : turn on pytorch compilation\
--attention mode: force attention mode among, sdpa, flash, sage, sage2\
--profile no : default (4) : no of profile between 1 and 5\
--preload no : number in Megabytes to preload partially the diffusion model in VRAM , may offer slight speed gains especially on older hardware
--preload no : number in Megabytes to preload partially the diffusion model in VRAM , may offer slight speed gains especially on older hardware. Works only with profile 2 and 4.
### Profiles (for power users only)
You can choose between 5 profiles, but two are really relevant here :

View File

@ -57,18 +57,18 @@ def _parse_args():
)
parser.add_argument(
"--lora-dir-i2v",
type=str,
default="loras_i2v",
help="Path to a directory that contains Loras for i2v"
)
# parser.add_argument(
# "--lora-dir-i2v",
# type=str,
# default="loras_i2v",
# help="Path to a directory that contains Loras for i2v"
# )
parser.add_argument(
"--lora-dir",
type=str,
default="loras",
default="",
help="Path to a directory that contains Loras"
)
@ -80,12 +80,12 @@ def _parse_args():
help="Lora preset to preload"
)
parser.add_argument(
"--lora-preset-i2v",
type=str,
default="",
help="Lora preset to preload for i2v"
)
# parser.add_argument(
# "--lora-preset-i2v",
# type=str,
# default="",
# help="Lora preset to preload for i2v"
# )
parser.add_argument(
"--profile",
@ -198,6 +198,7 @@ if not Path(server_config_filename).is_file():
"text_encoder_filename" : text_encoder_choices[1],
"compile" : "",
"default_ui": "t2v",
"boost" : 1,
"vae_config": 0,
"profile" : profile_type.LowRAM_LowVRAM }
@ -223,6 +224,7 @@ if len(args.attention)> 0:
profile = force_profile_no if force_profile_no >=0 else server_config["profile"]
compile = server_config.get("compile", "")
boost = server_config.get("boost", 1)
vae_config = server_config.get("vae_config", 0)
if len(args.vae_config) > 0:
vae_config = int(args.vae_config)
@ -234,13 +236,14 @@ if args.t2v:
if args.i2v:
use_image2video = True
if use_image2video:
lora_dir =args.lora_dir_i2v
lora_preselected_preset = args.lora_preset_i2v
else:
lora_dir =args.lora_dir
lora_preselected_preset = args.lora_preset
# if use_image2video:
# lora_dir =args.lora_dir_i2v
# lora_preselected_preset = args.lora_preset_i2v
# else:
lora_dir =args.lora_dir
if len(lora_dir) ==0:
lora_dir = "loras_i2v" if use_image2video else "loras"
lora_preselected_preset = args.lora_preset
default_tea_cache = 0
# if args.fast : #or args.fastest
# transformer_filename_t2v = transformer_choices_t2v[2]
@ -321,8 +324,16 @@ def extract_preset(lset_name, loras):
raise gr.Error(f"Unable to apply Lora preset '{lset_name} because the following Loras files are missing: {missing_loras}")
loras_mult_choices = lset["loras_mult"]
return loras_choices, loras_mult_choices
prompt = lset.get("prompt", "")
return loras_choices, loras_mult_choices, prompt, lset.get("full_prompt", False)
def get_default_prompt(i2v):
if i2v:
return "Several giant wooly mammoths approach treading through a snowy meadow, their long wooly fur lightly blows in the wind as they walk, snow covered trees and dramatic snow capped mountains in the distance, mid afternoon light with wispy clouds and a sun high in the distance creates a warm glow, the low camera view is stunning capturing the large furry mammal with beautiful photography, depth of field."
else:
return "A large orange octopus is seen resting on the bottom of the ocean floor, blending in with the sandy and rocky terrain. Its tentacles are spread out around its body, and its eyes are closed. The octopus is unaware of a king crab that is crawling towards it from behind a rock, its claws raised and ready to attack. The crab is brown and spiny, with long legs and antennae. The scene is captured from a wide angle, showing the vastness and depth of the ocean. The water is clear and blue, with rays of sunlight filtering through. The shot is sharp and crisp, with a high dynamic range. The octopus and the crab are in focus, while the background is slightly blurred, creating a depth of field effect."
def setup_loras(pipe, lora_dir, lora_preselected_preset, split_linear_modules_map = None):
loras =[]
loras_names = []
@ -337,7 +348,7 @@ def setup_loras(pipe, lora_dir, lora_preselected_preset, split_linear_modules_m
raise Exception("--lora-dir should be a path to a directory that contains Loras")
default_lora_preset = ""
default_prompt = ""
if lora_dir != None:
import glob
dir_loras = glob.glob( os.path.join(lora_dir , "*.sft") ) + glob.glob( os.path.join(lora_dir , "*.safetensors") )
@ -350,15 +361,16 @@ def setup_loras(pipe, lora_dir, lora_preselected_preset, split_linear_modules_m
if len(loras) > 0:
loras_names = [ Path(lora).stem for lora in loras ]
offload.load_loras_into_model(pipe["transformer"], loras, activate_all_loras=False, split_linear_modules_map = split_linear_modules_map) #lora_multiplier,
offload.load_loras_into_model(pipe["transformer"], loras, activate_all_loras=False, split_linear_modules_map = split_linear_modules_map) #lora_multiplier,
if len(lora_preselected_preset) > 0:
if not os.path.isfile(os.path.join(lora_dir, lora_preselected_preset + ".lset")):
raise Exception(f"Unknown preset '{lora_preselected_preset}'")
default_lora_preset = lora_preselected_preset
default_loras_choices, default_loras_multis_str= extract_preset(default_lora_preset, loras)
return loras, loras_names, default_loras_choices, default_loras_multis_str, default_lora_preset, loras_presets
default_loras_choices, default_loras_multis_str, default_prompt, _ = extract_preset(default_lora_preset, loras)
if len(default_prompt) == 0:
default_prompt = get_default_prompt(use_image2video)
return loras, loras_names, default_loras_choices, default_loras_multis_str, default_prompt, default_lora_preset, loras_presets
def load_t2v_model(model_filename, value):
@ -439,13 +451,13 @@ def load_models(i2v, lora_dir, lora_preselected_preset ):
kwargs["budgets"] = { "*" : "70%" }
loras, loras_names, default_loras_choices, default_loras_multis_str, default_lora_preset, loras_presets = setup_loras(pipe, lora_dir, lora_preselected_preset, None)
loras, loras_names, default_loras_choices, default_loras_multis_str, default_prompt, default_lora_preset, loras_presets = setup_loras(pipe, lora_dir, lora_preselected_preset, None)
offloadobj = offload.profile(pipe, profile_no= profile, compile = compile, quantizeTransformer = quantizeTransformer, **kwargs)
return wan_model, offloadobj, loras, loras_names, default_loras_choices, default_loras_multis_str, default_lora_preset, loras_presets
return wan_model, offloadobj, loras, loras_names, default_loras_choices, default_loras_multis_str, default_prompt, default_lora_preset, loras_presets
wan_model, offloadobj, loras, loras_names, default_loras_choices, default_loras_multis_str, default_lora_preset, loras_presets = load_models(use_image2video, lora_dir, lora_preselected_preset )
wan_model, offloadobj, loras, loras_names, default_loras_choices, default_loras_multis_str, default_prompt, default_lora_preset, loras_presets = load_models(use_image2video, lora_dir, lora_preselected_preset )
gen_in_progress = False
def get_auto_attention():
@ -487,13 +499,14 @@ def apply_changes( state,
profile_choice,
vae_config_choice,
default_ui_choice ="t2v",
boost_choice = 1
):
if args.lock_config:
return
if gen_in_progress:
yield "<DIV ALIGN=CENTER>Unable to change config when a generation is in progress</DIV>"
return
global offloadobj, wan_model, loras, loras_names, default_loras_choices, default_loras_multis_str, default_lora_preset, loras_presets
global offloadobj, wan_model, loras, loras_names, default_loras_choices, default_loras_multis_str, default_prompt, default_lora_preset, loras_presets
server_config = {"attention_mode" : attention_choice,
"transformer_filename": transformer_choices_t2v[transformer_t2v_choice],
"transformer_filename_i2v": transformer_choices_i2v[transformer_i2v_choice], ##########
@ -502,6 +515,7 @@ def apply_changes( state,
"profile" : profile_choice,
"vae_config" : vae_config_choice,
"default_ui" : default_ui_choice,
"boost" : boost_choice,
}
if Path(server_config_filename).is_file():
@ -529,7 +543,7 @@ def apply_changes( state,
state["config_new"] = server_config
state["config_old"] = old_server_config
global attention_mode, profile, compile, transformer_filename_t2v, transformer_filename_i2v, text_encoder_filename, vae_config
global attention_mode, profile, compile, transformer_filename_t2v, transformer_filename_i2v, text_encoder_filename, vae_config, boost
attention_mode = server_config["attention_mode"]
profile = server_config["profile"]
compile = server_config["compile"]
@ -537,8 +551,8 @@ def apply_changes( state,
transformer_filename_i2v = server_config["transformer_filename_i2v"]
text_encoder_filename = server_config["text_encoder_filename"]
vae_config = server_config["vae_config"]
if all(change in ["attention_mode", "vae_config", "default_ui"] for change in changes ):
boost = server_config["boost"]
if all(change in ["attention_mode", "vae_config", "default_ui", "boost"] for change in changes ):
if "attention_mode" in changes:
pass
@ -548,7 +562,7 @@ def apply_changes( state,
offloadobj = None
yield "<DIV ALIGN=CENTER>Please wait while the new configuration is being applied</DIV>"
wan_model, offloadobj, loras, loras_names, default_loras_choices, default_loras_multis_str, default_lora_preset, loras_presets = load_models(use_image2video, lora_dir, lora_preselected_preset )
wan_model, offloadobj, loras, loras_names, default_loras_choices, default_loras_multis_str, default_prompt, default_lora_preset, loras_presets = load_models(use_image2video, lora_dir, lora_preselected_preset )
yield "<DIV ALIGN=CENTER>The new configuration has been succesfully applied</DIV>"
@ -727,7 +741,9 @@ def generate_video(
if len(prompt) ==0:
return
prompts = prompt.replace("\r", "").split("\n")
prompts = [prompt for prompt in prompts if len(prompt)>0 and not prompt.startswith("#")]
if len(prompts) ==0:
return
if use_image2video:
if image_to_continue is not None:
if isinstance(image_to_continue, list):
@ -772,6 +788,9 @@ def generate_video(
return False
list_mult_choices_nums = []
if len(loras_mult_choices) > 0:
loras_mult_choices_list = loras_mult_choices.replace("\r", "").split("\n")
loras_mult_choices_list = [multi for multi in loras_mult_choices_list if len(multi)>0 and not multi.startswith("#")]
loras_mult_choices = " ".join(loras_mult_choices_list)
list_mult_choices_str = loras_mult_choices.split(" ")
for i, mult in enumerate(list_mult_choices_str):
mult = mult.strip()
@ -805,18 +824,36 @@ def generate_video(
# VAE Tiling
device_mem_capacity = torch.cuda.get_device_properties(0).total_memory / 1048576
joint_pass = boost ==1
# TeaCache
trans = wan_model.model
trans.enable_teacache = tea_cache > 0
if trans.enable_teacache:
if use_image2video:
if '480p' in transformer_filename_i2v:
# teacache_thresholds = [0.13, .19, 0.26]
trans.coefficients = [-3.02331670e+02, 2.23948934e+02, -5.25463970e+01, 5.87348440e+00, -2.01973289e-01]
elif '720p' in transformer_filename_i2v:
teacache_thresholds = [0.18, 0.2 , 0.3]
trans.coefficients = [-114.36346466, 65.26524496, -18.82220707, 4.91518089, -0.23412683]
else:
raise gr.Error("Teacache not supported for this model")
else:
if '1.3B' in transformer_filename_t2v:
# teacache_thresholds= [0.05, 0.07, 0.08]
trans.coefficients = [2.39676752e+03, -1.31110545e+03, 2.01331979e+02, -8.29855975e+00, 1.37887774e-01]
elif '14B' in transformer_filename_t2v:
# teacache_thresholds = [0.14, 0.15, 0.2]
trans.coefficients = [-5784.54975374, 5449.50911966, -1811.16591783, 256.27178429, -13.02252404]
else:
raise gr.Error("Teacache not supported for this model")
import random
if seed == None or seed <0:
seed = random.randint(0, 999999999)
file_list = []
state["file_list"] = file_list
from einops import rearrange
save_path = os.path.join(os.getcwd(), "gradio_outputs")
os.makedirs(save_path, exist_ok=True)
video_no = 0
@ -830,14 +867,12 @@ def generate_video(
if trans.enable_teacache:
trans.teacache_counter = 0
trans.rel_l1_thresh = tea_cache
trans.teacache_start_step = max(math.ceil(tea_cache_start_step_perc*num_inference_steps/100),2)
trans.teacache_multiplier = tea_cache
trans.teacache_start_step = int(tea_cache_start_step_perc*num_inference_steps/100)
trans.num_steps = num_inference_steps
trans.teacache_skipped_steps = 0
trans.previous_residual_uncond = None
trans.previous_modulated_input_uncond = None
trans.previous_residual_cond = None
trans.previous_modulated_input_cond= None
trans.teacache_cache_device = "cuda" if profile==3 or profile==1 else "cpu"
video_no += 1
status = f"Video {video_no}/{total_video}"
@ -853,7 +888,7 @@ def generate_video(
if use_image2video:
samples = wan_model.generate(
prompt,
image_to_continue[ (video_no-1) % len(image_to_continue)],
image_to_continue[ (video_no-1) % len(image_to_continue)].convert('RGB'),
frame_num=(video_length // 4)* 4 + 1,
max_area=MAX_AREA_CONFIGS[resolution],
shift=flow_shift,
@ -864,7 +899,8 @@ def generate_video(
offload_model=False,
callback=callback,
enable_RIFLEx = enable_RIFLEx,
VAE_tile_size = VAE_tile_size
VAE_tile_size = VAE_tile_size,
joint_pass = joint_pass,
)
else:
@ -880,7 +916,8 @@ def generate_video(
offload_model=False,
callback=callback,
enable_RIFLEx = enable_RIFLEx,
VAE_tile_size = VAE_tile_size
VAE_tile_size = VAE_tile_size,
joint_pass = joint_pass,
)
except Exception as e:
gen_in_progress = False
@ -911,6 +948,7 @@ def generate_video(
raise gr.Error(f"The generation of the video has encountered an error, please check your terminal for more information. '{s}'")
if trans.enable_teacache:
print(f"Teacache Skipped Steps:{trans.teacache_skipped_steps}/{num_inference_steps}" )
trans.previous_residual_uncond = None
trans.previous_residual_cond = None
@ -957,7 +995,25 @@ def generate_video(
new_preset_msg = "Enter a Name for a Lora Preset or Choose One Above"
def save_lset(lset_name, loras_choices, loras_mult_choices):
def validate_delete_lset(lset_name):
if len(lset_name) == 0 or lset_name == new_preset_msg:
gr.Info(f"Choose a Preset to delete")
return gr.Button(visible= True), gr.Checkbox(visible= True), gr.Button(visible= True), gr.Button(visible= False), gr.Button(visible= False)
else:
return gr.Button(visible= False), gr.Checkbox(visible= False), gr.Button(visible= False), gr.Button(visible= True), gr.Button(visible= True)
def validate_save_lset(lset_name):
if len(lset_name) == 0 or lset_name == new_preset_msg:
gr.Info("Please enter a name for the preset")
return gr.Button(visible= True), gr.Checkbox(visible= True), gr.Button(visible= True), gr.Button(visible= False), gr.Button(visible= False),gr.Checkbox(visible= False)
else:
return gr.Button(visible= False), gr.Button(visible= False), gr.Button(visible= False), gr.Button(visible= True), gr.Button(visible= True),gr.Checkbox(visible= True)
def cancel_lset():
return gr.Button(visible= True), gr.Button(visible= True), gr.Button(visible= True), gr.Button(visible= False), gr.Button(visible= False), gr.Button(visible= False), gr.Checkbox(visible= False)
def save_lset(lset_name, loras_choices, loras_mult_choices, prompt, save_lset_prompt_cbox):
global loras_presets
if len(lset_name) == 0 or lset_name== new_preset_msg:
@ -968,6 +1024,16 @@ def save_lset(lset_name, loras_choices, loras_mult_choices):
loras_choices_files = [ Path(loras[int(choice_no)]).parts[-1] for choice_no in loras_choices ]
lset = {"loras" : loras_choices_files, "loras_mult" : loras_mult_choices}
if save_lset_prompt_cbox!=1:
prompts = prompt.replace("\r", "").split("\n")
prompts = [prompt for prompt in prompts if len(prompt)> 0 and prompt.startswith("#")]
prompt = "\n".join(prompts)
if len(prompt) > 0:
lset["prompt"] = prompt
lset["full_prompt"] = save_lset_prompt_cbox ==1
lset_name_filename = lset_name + ".lset"
full_lset_name_filename = os.path.join(lora_dir, lset_name_filename)
@ -982,7 +1048,7 @@ def save_lset(lset_name, loras_choices, loras_mult_choices):
lset_choices = [ ( preset, preset) for preset in loras_presets ]
lset_choices.append( (new_preset_msg, ""))
return gr.Dropdown(choices=lset_choices, value= lset_name)
return gr.Dropdown(choices=lset_choices, value= lset_name), gr.Button(visible= True), gr.Button(visible= True), gr.Button(visible= True), gr.Button(visible= False), gr.Button(visible= False), gr.Checkbox(visible= False)
def delete_lset(lset_name):
global loras_presets
@ -1000,23 +1066,31 @@ def delete_lset(lset_name):
lset_choices = [ (preset, preset) for preset in loras_presets]
lset_choices.append((new_preset_msg, ""))
return gr.Dropdown(choices=lset_choices, value= lset_choices[pos][1])
return gr.Dropdown(choices=lset_choices, value= lset_choices[pos][1]), gr.Button(visible= True), gr.Button(visible= True), gr.Button(visible= True), gr.Button(visible= False), gr.Checkbox(visible= False)
def apply_lset(lset_name, loras_choices, loras_mult_choices):
def apply_lset(lset_name, loras_choices, loras_mult_choices, prompt):
if len(lset_name) == 0 or lset_name== new_preset_msg:
gr.Info("Please choose a preset in the list or create one")
else:
loras_choices, loras_mult_choices= extract_preset(lset_name, loras)
loras_choices, loras_mult_choices, preset_prompt, full_prompt = extract_preset(lset_name, loras)
if full_prompt:
prompt = preset_prompt
elif len(preset_prompt) > 0:
prompts = prompt.replace("\r", "").split("\n")
prompts = [prompt for prompt in prompts if len(prompt)>0 and not prompt.startswith("#")]
prompt = "\n".join(prompts)
prompt = preset_prompt + '\n' + prompt
gr.Info(f"Lora Preset '{lset_name}' has been applied")
return loras_choices, loras_mult_choices
return loras_choices, loras_mult_choices, prompt
def create_demo():
default_inference_steps = 30
default_flow_shift = get_default_flow(transformer_filename_i2v if use_image2video else transformer_filename_t2v)
with gr.Blocks() as demo:
state = gr.State({})
@ -1130,6 +1204,16 @@ def create_demo():
label="VAE Tiling - reduce the high VRAM requirements for VAE decoding and VAE encoding (if enabled it will be slower)"
)
boost_choice = gr.Dropdown(
choices=[
# ("Auto (ON if Video longer than 5s)", 0),
("ON", 1),
("OFF", 2),
],
value=boost,
label="Boost: Give a 10% speed speedup without losing quality at the cost of a litle VRAM (up to 1GB for max frames and resolution)"
)
profile_choice = gr.Dropdown(
choices=[
("HighRAM_HighVRAM, profile 1: at least 48 GB of RAM and 24 GB of VRAM, the fastest for short videos a RTX 3090 / RTX 4090", 1),
@ -1161,16 +1245,12 @@ def create_demo():
video_to_continue = gr.Video(label= "Video to continue", visible= use_image2video and False) #######
if args.multiple_images:
image_to_continue = gr.Gallery(
label="Images as a starting point for new videos", type ="numpy", #file_types= "image",
label="Images as a starting point for new videos", type ="pil", #file_types= "image",
columns=[3], rows=[1], object_fit="contain", height="auto", selected_index=0, interactive= True, visible=use_image2video)
else:
image_to_continue = gr.Image(label= "Image as a starting point for a new video", visible=use_image2video)
if use_image2video:
prompt = gr.Textbox(label="Prompts (multiple prompts separated by carriage returns will generate multiple videos)", value="Several giant wooly mammoths approach treading through a snowy meadow, their long wooly fur lightly blows in the wind as they walk, snow covered trees and dramatic snow capped mountains in the distance, mid afternoon light with wispy clouds and a sun high in the distance creates a warm glow, the low camera view is stunning capturing the large furry mammal with beautiful photography, depth of field.", lines=3)
else:
prompt = gr.Textbox(label="Prompts (multiple prompts separated by carriage returns will generate multiple videos)", value="A large orange octopus is seen resting on the bottom of the ocean floor, blending in with the sandy and rocky terrain. Its tentacles are spread out around its body, and its eyes are closed. The octopus is unaware of a king crab that is crawling towards it from behind a rock, its claws raised and ready to attack. The crab is brown and spiny, with long legs and antennae. The scene is captured from a wide angle, showing the vastness and depth of the ocean. The water is clear and blue, with rays of sunlight filtering through. The shot is sharp and crisp, with a high dynamic range. The octopus and the crab are in focus, while the background is slightly blurred, creating a depth of field effect.", lines=3)
prompt = gr.Textbox(label="Prompts (multiple prompts separated by carriage returns will generate multiple videos, lines that starts with # are ignored)", value=default_prompt, lines=3)
with gr.Row():
if use_image2video:
@ -1223,9 +1303,21 @@ def create_demo():
# with gr.Column():
with gr.Row(height=17):
apply_lset_btn = gr.Button("Apply Lora Preset", size="sm", min_width= 1)
# save_lset_prompt_cbox = gr.Checkbox(label="Save Prompt Comments in Preset", value=False, visible= False)
save_lset_prompt_drop= gr.Dropdown(
choices=[
("Save Prompt Comments Only", 0),
("Save Full Prompt", 1)
], show_label= False, container=False, visible= False
)
with gr.Row(height=17):
confirm_save_lset_btn = gr.Button("Go Ahead Save it !", size="sm", min_width= 1, visible=False)
confirm_delete_lset_btn = gr.Button("Go Ahead Delete it !", size="sm", min_width= 1, visible=False)
save_lset_btn = gr.Button("Save", size="sm", min_width= 1)
delete_lset_btn = gr.Button("Delete", size="sm", min_width= 1)
cancel_lset_btn = gr.Button("Don't do it !", size="sm", min_width= 1 , visible=False)
loras_choices = gr.Dropdown(
@ -1237,7 +1329,7 @@ def create_demo():
visible= len(loras)>0,
label="Activated Loras"
)
loras_mult_choices = gr.Textbox(label="Loras Multipliers (1.0 by default) separated by space characters or carriage returns", value=default_loras_multis_str, visible= len(loras)>0 )
loras_mult_choices = gr.Textbox(label="Loras Multipliers (1.0 by default) separated by space characters or carriage returns, line that starts with # are ignored", value=default_loras_multis_str, visible= len(loras)>0 )
show_advanced = gr.Checkbox(label="Show Advanced Options", value=False)
with gr.Row(visible=False) as advanced_row:
@ -1250,18 +1342,23 @@ def create_demo():
guidance_scale = gr.Slider(1.0, 20.0, value=5.0, step=0.5, label="Guidance Scale", visible=True)
embedded_guidance_scale = gr.Slider(1.0, 20.0, value=6.0, step=0.5, label="Embedded Guidance Scale", visible=False)
flow_shift = gr.Slider(0.0, 25.0, value= default_flow_shift, step=0.1, label="Shift Scale")
tea_cache_setting = gr.Dropdown(
choices=[
("Tea Cache Disabled", 0),
("0.03 (around x1.6 speed up)", 0.03),
("0.05 (around x2 speed up)", 0.05),
("0.10 (around x3 speed up)", 0.1),
],
value=default_tea_cache,
visible=True,
label="Tea Cache Threshold to Skip Steps (the higher, the more steps are skipped but the lower the quality of the video (Tea Cache Consumes VRAM)"
)
tea_cache_start_step_perc = gr.Slider(2, 100, value=20, step=1, label="Tea Cache starting moment in percentage of generation (the later, the higher the quality but also the lower the speed gain)")
with gr.Row():
gr.Markdown("Tea Cache accelerates by skipping intelligently some steps, the more steps are skipped the lower the quality of the video (Tea Cache consumes also VRAM)")
with gr.Row():
tea_cache_setting = gr.Dropdown(
choices=[
("Tea Cache Disabled", 0),
("around x1.5 speed up", 1.5),
("around x1.75 speed up", 1.75),
("around x2 speed up", 2.0),
("around x2.25 speed up", 2.25),
("around x2.5 speed up", 2.5),
],
value=default_tea_cache,
visible=True,
label="Tea Cache Global Acceleration"
)
tea_cache_start_step_perc = gr.Slider(0, 100, value=0, step=1, label="Tea Cache starting moment in % of generation")
RIFLEx_setting = gr.Dropdown(
choices=[
@ -1283,9 +1380,13 @@ def create_demo():
generate_btn = gr.Button("Generate")
abort_btn = gr.Button("Abort")
save_lset_btn.click(save_lset, inputs=[lset_name, loras_choices, loras_mult_choices], outputs=[lset_name])
delete_lset_btn.click(delete_lset, inputs=[lset_name], outputs=[lset_name])
apply_lset_btn.click(apply_lset, inputs=[lset_name,loras_choices, loras_mult_choices], outputs=[loras_choices, loras_mult_choices])
save_lset_btn.click(validate_save_lset, inputs=[lset_name], outputs=[apply_lset_btn, delete_lset_btn, save_lset_btn,confirm_save_lset_btn, cancel_lset_btn, save_lset_prompt_drop])
confirm_save_lset_btn.click(save_lset, inputs=[lset_name, loras_choices, loras_mult_choices, prompt, save_lset_prompt_drop], outputs=[lset_name, apply_lset_btn, delete_lset_btn, save_lset_btn, confirm_save_lset_btn, cancel_lset_btn, save_lset_prompt_drop])
delete_lset_btn.click(validate_delete_lset, inputs=[lset_name], outputs=[apply_lset_btn, delete_lset_btn, save_lset_btn,confirm_delete_lset_btn, cancel_lset_btn ])
confirm_delete_lset_btn.click(delete_lset, inputs=[lset_name], outputs=[lset_name, apply_lset_btn, delete_lset_btn, save_lset_btn,confirm_delete_lset_btn, cancel_lset_btn ])
cancel_lset_btn.click(cancel_lset, inputs=[], outputs=[apply_lset_btn, delete_lset_btn, save_lset_btn, confirm_delete_lset_btn,confirm_save_lset_btn, cancel_lset_btn,save_lset_prompt_drop ])
apply_lset_btn.click(apply_lset, inputs=[lset_name,loras_choices, loras_mult_choices, prompt], outputs=[loras_choices, loras_mult_choices, prompt])
gen_status.change(refresh_gallery, inputs = [state], outputs = output )
@ -1335,6 +1436,7 @@ def create_demo():
profile_choice,
vae_config_choice,
default_ui_choice,
boost_choice,
],
outputs= msg
).then(

View File

@ -146,7 +146,7 @@ class WanI2V:
callback = None,
enable_RIFLEx = False,
VAE_tile_size= 0,
joint_pass = False,
):
r"""
Generates video frames from input image and text prompt using diffusion process.
@ -310,9 +310,22 @@ class WanI2V:
'pipeline' : self
}
arg_both= {
'context': [context[0]],
'context2': context_null,
'clip_fea': clip_context,
'seq_len': max_seq_len,
'y': [y],
'freqs' : freqs,
'pipeline' : self
}
if offload_model:
torch.cuda.empty_cache()
if self.model.enable_teacache:
self.model.compute_teacache_threshold(self.model.teacache_start_step, timesteps, self.model.teacache_multiplier)
# self.model.to(self.device)
if callback != None:
callback(-1, None)
@ -323,17 +336,22 @@ class WanI2V:
timestep = [t]
timestep = torch.stack(timestep).to(self.device)
noise_pred_cond = self.model(
latent_model_input, t=timestep, current_step=i, is_uncond = False, **arg_c)[0]
if self._interrupt:
return None
if offload_model:
torch.cuda.empty_cache()
noise_pred_uncond = self.model(
latent_model_input, t=timestep, current_step=i, is_uncond = True, **arg_null)[0]
if self._interrupt:
return None
if joint_pass:
noise_pred_cond, noise_pred_uncond = self.model(
latent_model_input, t=timestep, current_step=i, **arg_both)
if self._interrupt:
return None
else:
noise_pred_cond = self.model(
latent_model_input, t=timestep, current_step=i, is_uncond = False, **arg_c)[0]
if self._interrupt:
return None
if offload_model:
torch.cuda.empty_cache()
noise_pred_uncond = self.model(
latent_model_input, t=timestep, current_step=i, is_uncond = True, **arg_null)[0]
if self._interrupt:
return None
del latent_model_input
if offload_model:
torch.cuda.empty_cache()

View File

@ -667,7 +667,43 @@ class WanModel(ModelMixin, ConfigMixin):
return (torch.cat([c1,c2,c3],dim=1).to(device) , torch.cat([s1,s2,s3],dim=1).to(device))
def compute_teacache_threshold(self, start_step, timesteps = None, speed_factor =0):
rescale_func = np.poly1d(self.coefficients)
e_list = []
for t in timesteps:
t = torch.stack([t])
e_list.append(self.time_embedding( sinusoidal_embedding_1d(self.freq_dim, t)))
best_threshold = 0.01
best_diff = 1000
target_nb_steps= int(len(timesteps) / speed_factor)
threshold = 0.01
while threshold <= 0.6:
accumulated_rel_l1_distance =0
nb_steps = 0
diff = 1000
for i, t in enumerate(timesteps):
skip = False
if not (i<=start_step or i== len(timesteps)):
accumulated_rel_l1_distance += rescale_func(((e_list[i]-previous_modulated_input).abs().mean() / previous_modulated_input.abs().mean()).cpu().item())
if accumulated_rel_l1_distance < threshold:
skip = True
else:
accumulated_rel_l1_distance = 0
previous_modulated_input = e_list[i]
if not skip:
nb_steps += 1
diff = abs(target_nb_steps - nb_steps)
if diff < best_diff:
best_threshold = threshold
best_diff = diff
elif diff > best_diff:
break
threshold += 0.01
self.rel_l1_thresh = best_threshold
print(f"Tea Cache, best threshold found:{best_threshold} with gain x{len(timesteps)/(len(timesteps) - best_diff):0.1f} for a target of x{speed_factor}")
return best_threshold
def forward(
self,
x,
@ -679,6 +715,7 @@ class WanModel(ModelMixin, ConfigMixin):
freqs = None,
pipeline = None,
current_step = 0,
context2 = None,
is_uncond=False
):
r"""
@ -722,10 +759,13 @@ class WanModel(ModelMixin, ConfigMixin):
x = [u.flatten(2).transpose(1, 2) for u in x]
seq_lens = torch.tensor([u.size(1) for u in x], dtype=torch.long)
assert seq_lens.max() <= seq_len
x = torch.cat([
torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))],
dim=1) for u in x
])
if len(x)==1 and seq_len == x[0].size(1):
x = x[0]
else:
x = torch.cat([
torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))],
dim=1) for u in x
])
# time embeddings
e = self.time_embedding(
@ -740,82 +780,105 @@ class WanModel(ModelMixin, ConfigMixin):
[u, u.new_zeros(self.text_len - u.size(0), u.size(1))])
for u in context
]))
if context2!=None:
context2 = self.text_embedding(
torch.stack([
torch.cat(
[u, u.new_zeros(self.text_len - u.size(0), u.size(1))])
for u in context2
]))
if clip_fea is not None:
context_clip = self.img_emb(clip_fea) # bs x 257 x dim
context = torch.concat([context_clip, context], dim=1)
# deepbeepmeep optimization of kijai's implementation (https://github.com/kijai/ComfyUI-WanVideoWrapper/) of teacache (https://github.com/ali-vilab/TeaCache)
if context2 != None:
context2 = torch.concat([context_clip, context2], dim=1)
joint_pass = context2 != None
if joint_pass:
x_list = [x, x.clone()]
context_list = [context, context2]
is_uncond = False
else:
x_list = [x]
context_list = [context]
del x
should_calc = True
if self.enable_teacache and current_step >= self.teacache_start_step:
if current_step == self.teacache_start_step:
self.accumulated_rel_l1_distance_cond = 0
self.accumulated_rel_l1_distance_uncond = 0
self.teacache_skipped_cond_steps = 0
self.teacache_skipped_uncond_steps = 0
if self.enable_teacache:
if is_uncond:
should_calc = self.should_calc
else:
prev_input = self.previous_modulated_input_uncond if is_uncond else self.previous_modulated_input_cond
acc_distance_attr = 'accumulated_rel_l1_distance_uncond' if is_uncond else 'accumulated_rel_l1_distance_cond'
temb_relative_l1 = relative_l1_distance(prev_input, e0)
setattr(self, acc_distance_attr, getattr(self, acc_distance_attr) + temb_relative_l1)
if getattr(self, acc_distance_attr) < self.rel_l1_thresh:
should_calc = False
self.teacache_counter += 1
else:
if current_step <= self.teacache_start_step or current_step == self.num_steps-1:
should_calc = True
setattr(self, acc_distance_attr, 0)
if is_uncond:
self.previous_modulated_input_uncond = e0.clone()
if should_calc:
self.previous_residual_uncond = None
self.accumulated_rel_l1_distance = 0
else:
x += self.previous_residual_uncond
self.teacache_skipped_cond_steps += 1
# print(f"Skipped uncond:{self.teacache_skipped_cond_steps}/{current_step}" )
else:
self.previous_modulated_input_cond = e0.clone()
if should_calc:
self.previous_residual_cond = None
else:
x += self.previous_residual_cond
self.teacache_skipped_uncond_steps += 1
# print(f"Skipped uncond:{self.teacache_skipped_uncond_steps}/{current_step}" )
rescale_func = np.poly1d(self.coefficients)
self.accumulated_rel_l1_distance += rescale_func(((e-self.previous_modulated_input).abs().mean() / self.previous_modulated_input.abs().mean()).cpu().item())
if self.accumulated_rel_l1_distance < self.rel_l1_thresh:
should_calc = False
self.teacache_skipped_steps += 1
# print(f"Teacache Skipped Step:{self.teacache_skipped_steps}/{current_step}" )
else:
should_calc = True
self.accumulated_rel_l1_distance = 0
self.previous_modulated_input = e
self.should_calc = should_calc
if should_calc:
if not should_calc:
for i, x in enumerate(x_list):
x += self.previous_residual_uncond if i==1 or is_uncond else self.previous_residual_cond
else:
if self.enable_teacache:
ori_hidden_states = x.clone()
if joint_pass or is_uncond:
self.previous_residual_uncond = None
if joint_pass or not is_uncond:
self.previous_residual_cond = None
ori_hidden_states = x_list[0].clone()
# arguments
kwargs = dict(
e=e0,
# e=e0,
seq_lens=seq_lens,
grid_sizes=grid_sizes,
freqs=freqs,
context=context,
# context=context,
context_lens=context_lens)
for block in self.blocks:
if pipeline._interrupt:
return [None]
x = block(x, **kwargs)
if joint_pass:
return None, None
else:
return [None]
for i, (x, context) in enumerate(zip(x_list, context_list)):
x_list[i] = block(x, context = context, e= e0, **kwargs)
del x
if self.enable_teacache:
residual = ori_hidden_states # just to have a readable code
torch.sub(x, ori_hidden_states, out=residual)
if is_uncond:
self.previous_residual_uncond = residual
if joint_pass:
self.previous_residual_cond = torch.sub(x_list[0], ori_hidden_states)
self.previous_residual_uncond = ori_hidden_states
torch.sub(x_list[1], ori_hidden_states, out=self.previous_residual_uncond)
else:
self.previous_residual_cond = residual
del residual, ori_hidden_states
residual = ori_hidden_states # just to have a readable code
torch.sub(x_list[0], ori_hidden_states, out=residual)
if i==1 or is_uncond:
self.previous_residual_uncond = residual
else:
self.previous_residual_cond = residual
residual, ori_hidden_states = None, None
# head
x = self.head(x, e)
for i, x in enumerate(x_list):
# head
x = self.head(x, e)
# unpatchify
x = self.unpatchify(x, grid_sizes)
return [u.float() for u in x]
# unpatchify
x_list[i] = self.unpatchify(x, grid_sizes)
del x
if joint_pass:
return x_list[0][0], x_list[1][0]
else:
return [u.float() for u in x_list[0]]
def unpatchify(self, x, grid_sizes):
r"""

View File

@ -131,7 +131,8 @@ class WanT2V:
offload_model=True,
callback = None,
enable_RIFLEx = None,
VAE_tile_size = 0
VAE_tile_size = 0,
joint_pass = False,
):
r"""
Generates video frames from text prompt using diffusion process.
@ -240,8 +241,10 @@ class WanT2V:
freqs = get_rotary_pos_embed(frame_num, size[1], size[0], enable_RIFLEx= enable_RIFLEx)
arg_c = {'context': context, 'seq_len': seq_len, 'freqs': freqs, 'pipeline': self}
arg_null = {'context': context_null, 'seq_len': seq_len, 'freqs': freqs, 'pipeline': self}
arg_both = {'context': context, 'context2': context_null, 'seq_len': seq_len, 'freqs': freqs, 'pipeline': self}
if self.model.enable_teacache:
self.model.compute_teacache_threshold(self.model.teacache_start_step, timesteps, self.model.teacache_multiplier)
if callback != None:
callback(-1, None)
for i, t in enumerate(tqdm(timesteps)):
@ -251,14 +254,20 @@ class WanT2V:
timestep = torch.stack(timestep)
# self.model.to(self.device)
noise_pred_cond = self.model(
latent_model_input, t=timestep,current_step=i, is_uncond = False, **arg_c)[0]
if self._interrupt:
return None
noise_pred_uncond = self.model(
latent_model_input, t=timestep,current_step=i, is_uncond = True, **arg_null)[0]
if self._interrupt:
return None
if joint_pass:
noise_pred_cond, noise_pred_uncond = self.model(
latent_model_input, t=timestep,current_step=i, **arg_both)
if self._interrupt:
return None
else:
noise_pred_cond = self.model(
latent_model_input, t=timestep,current_step=i, is_uncond = False, **arg_c)[0]
if self._interrupt:
return None
noise_pred_uncond = self.model(
latent_model_input, t=timestep,current_step=i, is_uncond = True, **arg_null)[0]
if self._interrupt:
return None
del latent_model_input
noise_pred = noise_pred_uncond + guide_scale * (