Mirror of https://github.com/Wan-Video/Wan2.1.git (synced 2025-11-05 06:29:14 +00:00)

Commit f8d9edeb50 (parent f9ce97a1ba): Added 10% boost, improved Loras and Teacache

README.md (35 lines changed)
@@ -19,10 +19,11 @@ In this repository, we present **Wan2.1**, a comprehensive and open suite of vid

 ## 🔥 Latest News!!

-* Mar 03, 2025: 👋 Wan2.1GP v1.4: Fix Pytorch compilation, now it is really 20% faster when activated
-* Mar 03, 2025: 👋 Wan2.1GP v1.3: Support for Image to Video with multiple images for different image / prompt combinations (requires the *--multiple-images* switch), and added the command line option *--preload x* to preload x MB of the main diffusion model in VRAM if you find there is too much unused VRAM and you want to (slightly) accelerate the generation process.
+* Mar 10, 2025: 👋 Wan2.1GP v1.5: Official TeaCache support + Smart TeaCache (automatically finds the best parameters for a requested speed multiplier), 10% speed boost with no quality loss, improved lora presets (they can now include prompts and comments to guide the user)
+* Mar 07, 2025: 👋 Wan2.1GP v1.4: Fix Pytorch compilation, now it is really 20% faster when activated
+* Mar 04, 2025: 👋 Wan2.1GP v1.3: Support for Image to Video with multiple images for different image / prompt combinations (requires the *--multiple-images* switch), and added the command line option *--preload x* to preload x MB of the main diffusion model in VRAM if you find there is too much unused VRAM and you want to (slightly) accelerate the generation process.
 If you upgrade you will need to do a 'pip install -r requirements.txt' again.
-* Mar 03, 2025: 👋 Wan2.1GP v1.2: Implemented tiling on VAE encoding and decoding. No more VRAM peaks at the beginning and at the end
+* Mar 04, 2025: 👋 Wan2.1GP v1.2: Implemented tiling on VAE encoding and decoding. No more VRAM peaks at the beginning and at the end
 * Mar 03, 2025: 👋 Wan2.1GP v1.1: added Tea Cache support for faster generations: optimization of kijai's implementation (https://github.com/kijai/ComfyUI-WanVideoWrapper/) of teacache (https://github.com/ali-vilab/TeaCache)
 * Mar 02, 2025: 👋 Wan2.1GP by DeepBeepMeep v1 brings:
 - Support for all Wan including the Image to Video model
@@ -152,15 +153,29 @@ python gradio_server.py --attention sdpa

 ### Loras support

---- Ready to be used but theoretical as no lora for Wan have been released as of today. ---

-Every lora stored in the subfoler 'loras' will be automatically loaded. You will be then able to activate / desactive any of them when running the application.
+Every lora stored in the subfolder 'loras' (for t2v) or 'loras_i2v' (for i2v) will be automatically loaded. You will then be able to activate / deactivate any of them when running the application by selecting them in the area below "Activated Loras".

-For each activated Lora, you may specify a *multiplier* that is one float number that corresponds to its weight (default is 1.0), alternatively you may specify a list of floats multipliers separated by a "," that gives the evolution of this Lora's multiplier over the steps. For instance let's assume there are 30 denoising steps and the multiplier is *0.9,0.8,0.7* then for the steps ranges 0-9, 10-19 and 20-29 the Lora multiplier will be respectively 0.9, 0.8 and 0.7.
+For each activated Lora, you may specify a *multiplier*, that is one float number that corresponds to its weight (default is 1.0). The multipliers for the different Loras should be separated by a space character or a carriage return. For instance:\
+*1.2 0.8* means that the first lora will have a 1.2 multiplier and the second one will have 0.8.

-You can edit, save or delete Loras presets (combinations of loras with their corresponding multipliers) directly from the gradio interface. Each preset, is a file with ".lset" extension stored in the loras directory and can be shared with other users
+Alternatively, for each Lora's multiplier you may specify a list of float multipliers separated by a "," (no space) that gives the evolution of this Lora's multiplier over the steps. For instance, let's assume there are 30 denoising steps and the multiplier is *0.9,0.8,0.7*: then for the step ranges 0-9, 10-19 and 20-29 the Lora multiplier will be respectively 0.9, 0.8 and 0.7.

-Then you can pre activate loras corresponding to a preset when launching the gradio server:
+If multiple Loras are defined, remember that the multipliers associated to different Loras should be separated by a space or a carriage return, so you can specify the evolution of multipliers for multiple Loras. For instance, for two Loras (press Shift+Return to force a carriage return):
+
+```
+0.9,0.8,0.7
+1.2,1.1,1.0
+```
+
+You can edit, save or delete Loras presets (combinations of loras with their corresponding multipliers) directly from the gradio Web interface. These presets also save the *comment* part of the prompt, which should contain instructions on how to use the corresponding loras (for instance by specifying a trigger word or providing an example). A comment in the prompt is a line that starts with a #. It will be ignored by the video generator. For instance:
+
+```
+# use the keyword ohnvx to trigger the Lora
+A ohnvx is driving a car
+```
+
+Each preset is a file with the ".lset" extension stored in the loras directory and can be shared with other users.
+
+Last but not least, you can pre-activate the Loras of a preset and prefill the prompt (comments only or full prompt) by specifying a preset when launching the gradio server:

 ```bash
 python gradio_server.py --lora-preset mylorapreset.lset # where 'mylorapreset.lset' is a preset stored in the 'loras' folder
 ```
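To make the multiplier syntax described above concrete, here is a small illustrative sketch. It is not part of the commit; the helper name and the even (ceiling-division) split of the steps are assumptions, only the string format comes from the README text:

```python
def lora_weight_at_step(schedule: str, step: int, num_steps: int) -> float:
    """Multiplier of one Lora at a given denoising step.

    'schedule' is either a single float ("1.2") or a comma-separated list
    ("0.9,0.8,0.7") that is spread evenly over the denoising steps.
    """
    values = [float(v) for v in schedule.split(",")]
    # each value covers an equal share of the steps: with 30 steps and 3
    # values, steps 0-9 use values[0], 10-19 values[1], 20-29 values[2]
    segment = max(1, -(-num_steps // len(values)))  # ceiling division
    return values[min(step // segment, len(values) - 1)]

# two activated Loras, 30 denoising steps
multipliers = "0.9,0.8,0.7 1.2".split(" ")
print(lora_weight_at_step(multipliers[0], 15, 30))  # -> 0.8 (second segment)
print(lora_weight_at_step(multipliers[1], 15, 30))  # -> 1.2 (constant weight)
```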
@@ -180,11 +195,11 @@ You will find prebuilt Loras on https://civitai.com/ or you will be able to buil

 --open-browser : open automatically Browser when launching Gradio Server\
 --lock-config : prevent modifying the video engine configuration from the interface\
 --share : create a shareable URL on huggingface so that your server can be accessed remotely\
---multiple-images : Images as a starting point for new videos\
+--multiple-images : allow the users to choose multiple images as different starting points for new videos\
 --compile : turn on pytorch compilation\
 --attention mode: force attention mode among, sdpa, flash, sage, sage2\
 --profile no : default (4) : no of profile between 1 and 5\
---preload no : number in Megabytes to preload partially the diffusion model in VRAM , may offer slight speed gains especially on older hardware
+--preload no : number in Megabytes to preload partially the diffusion model in VRAM, may offer slight speed gains especially on older hardware. Works only with profiles 2 and 4.

 ### Profiles (for power users only)
 You can choose between 5 profiles, but two are really relevant here :
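For reference, a preset written by the new save_lset code is a small dictionary whose keys are "loras", "loras_mult" and, optionally, "prompt" and "full_prompt". The sketch below only illustrates what such a .lset file might contain; the key names come from gradio_server.py in this commit, while the serialization as JSON and all file and Lora names are assumptions:

```python
import json

# Illustrative only: the actual writing is done by save_lset in gradio_server.py.
# In practice the file would live in the 'loras' (or 'loras_i2v') folder.
preset = {
    "loras": ["ohnvx_character.safetensors", "film_grain.safetensors"],  # hypothetical names
    "loras_mult": "0.9,0.8,0.7 1.2",
    "prompt": "# use the keyword ohnvx to trigger the Lora",
    "full_prompt": False,  # True when "Save Full Prompt" was selected
}

with open("mylorapreset.lset", "w", encoding="utf-8") as f:
    json.dump(preset, f, indent=4)
```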
gradio_server.py (254 lines changed)
@@ -57,18 +57,18 @@ def _parse_args():
     )

-    parser.add_argument(
-        "--lora-dir-i2v",
-        type=str,
-        default="loras_i2v",
-        help="Path to a directory that contains Loras for i2v"
-    )
+    # parser.add_argument(
+    #     "--lora-dir-i2v",
+    #     type=str,
+    #     default="loras_i2v",
+    #     help="Path to a directory that contains Loras for i2v"
+    # )

     parser.add_argument(
         "--lora-dir",
         type=str,
-        default="loras",
+        default="",
         help="Path to a directory that contains Loras"
     )

@@ -80,12 +80,12 @@ def _parse_args():
         help="Lora preset to preload"
     )

-    parser.add_argument(
-        "--lora-preset-i2v",
-        type=str,
-        default="",
-        help="Lora preset to preload for i2v"
-    )
+    # parser.add_argument(
+    #     "--lora-preset-i2v",
+    #     type=str,
+    #     default="",
+    #     help="Lora preset to preload for i2v"
+    # )

     parser.add_argument(
         "--profile",
@@ -198,6 +198,7 @@ if not Path(server_config_filename).is_file():
                      "text_encoder_filename" : text_encoder_choices[1],
                      "compile" : "",
                      "default_ui": "t2v",
+                     "boost" : 1,
                      "vae_config": 0,
                      "profile" : profile_type.LowRAM_LowVRAM }

@@ -223,6 +224,7 @@ if len(args.attention)> 0:

 profile = force_profile_no if force_profile_no >=0 else server_config["profile"]
 compile = server_config.get("compile", "")
+boost = server_config.get("boost", 1)
 vae_config = server_config.get("vae_config", 0)
 if len(args.vae_config) > 0:
     vae_config = int(args.vae_config)
@@ -234,13 +236,14 @@ if args.t2v:
 if args.i2v:
     use_image2video = True

-if use_image2video:
-    lora_dir =args.lora_dir_i2v
-    lora_preselected_preset = args.lora_preset_i2v
-else:
+# if use_image2video:
+#     lora_dir =args.lora_dir_i2v
+#     lora_preselected_preset = args.lora_preset_i2v
+# else:
 lora_dir =args.lora_dir
-lora_preselected_preset = args.lora_preset
+if len(lora_dir) ==0:
+    lora_dir = "loras_i2v" if use_image2video else "loras"
+lora_preselected_preset = args.lora_preset
 default_tea_cache = 0
 # if args.fast : #or args.fastest
 #     transformer_filename_t2v = transformer_choices_t2v[2]
@@ -321,8 +324,16 @@ def extract_preset(lset_name, loras):
         raise gr.Error(f"Unable to apply Lora preset '{lset_name} because the following Loras files are missing: {missing_loras}")

     loras_mult_choices = lset["loras_mult"]
-    return loras_choices, loras_mult_choices
+    prompt = lset.get("prompt", "")
+    return loras_choices, loras_mult_choices, prompt, lset.get("full_prompt", False)
+
+def get_default_prompt(i2v):
+    if i2v:
+        return "Several giant wooly mammoths approach treading through a snowy meadow, their long wooly fur lightly blows in the wind as they walk, snow covered trees and dramatic snow capped mountains in the distance, mid afternoon light with wispy clouds and a sun high in the distance creates a warm glow, the low camera view is stunning capturing the large furry mammal with beautiful photography, depth of field."
+    else:
+        return "A large orange octopus is seen resting on the bottom of the ocean floor, blending in with the sandy and rocky terrain. Its tentacles are spread out around its body, and its eyes are closed. The octopus is unaware of a king crab that is crawling towards it from behind a rock, its claws raised and ready to attack. The crab is brown and spiny, with long legs and antennae. The scene is captured from a wide angle, showing the vastness and depth of the ocean. The water is clear and blue, with rays of sunlight filtering through. The shot is sharp and crisp, with a high dynamic range. The octopus and the crab are in focus, while the background is slightly blurred, creating a depth of field effect."

 def setup_loras(pipe, lora_dir, lora_preselected_preset, split_linear_modules_map = None):
     loras =[]
     loras_names = []
@@ -337,7 +348,7 @@ def setup_loras(pipe, lora_dir, lora_preselected_preset, split_linear_modules_m
         raise Exception("--lora-dir should be a path to a directory that contains Loras")

     default_lora_preset = ""
+    default_prompt = ""
     if lora_dir != None:
         import glob
         dir_loras = glob.glob( os.path.join(lora_dir , "*.sft") ) + glob.glob( os.path.join(lora_dir , "*.safetensors") )

@@ -350,15 +361,16 @@ def setup_loras(pipe, lora_dir, lora_preselected_preset, split_linear_modules_m
     if len(loras) > 0:
         loras_names = [ Path(lora).stem for lora in loras ]
         offload.load_loras_into_model(pipe["transformer"], loras, activate_all_loras=False, split_linear_modules_map = split_linear_modules_map) #lora_multiplier,

     if len(lora_preselected_preset) > 0:
         if not os.path.isfile(os.path.join(lora_dir, lora_preselected_preset + ".lset")):
             raise Exception(f"Unknown preset '{lora_preselected_preset}'")
         default_lora_preset = lora_preselected_preset
-        default_loras_choices, default_loras_multis_str= extract_preset(default_lora_preset, loras)
-    return loras, loras_names, default_loras_choices, default_loras_multis_str, default_lora_preset, loras_presets
+        default_loras_choices, default_loras_multis_str, default_prompt, _ = extract_preset(default_lora_preset, loras)
+    if len(default_prompt) == 0:
+        default_prompt = get_default_prompt(use_image2video)
+    return loras, loras_names, default_loras_choices, default_loras_multis_str, default_prompt, default_lora_preset, loras_presets

 def load_t2v_model(model_filename, value):
@@ -439,13 +451,13 @@ def load_models(i2v, lora_dir, lora_preselected_preset ):
         kwargs["budgets"] = { "*" : "70%" }

-    loras, loras_names, default_loras_choices, default_loras_multis_str, default_lora_preset, loras_presets = setup_loras(pipe, lora_dir, lora_preselected_preset, None)
+    loras, loras_names, default_loras_choices, default_loras_multis_str, default_prompt, default_lora_preset, loras_presets = setup_loras(pipe, lora_dir, lora_preselected_preset, None)
     offloadobj = offload.profile(pipe, profile_no= profile, compile = compile, quantizeTransformer = quantizeTransformer, **kwargs)

-    return wan_model, offloadobj, loras, loras_names, default_loras_choices, default_loras_multis_str, default_lora_preset, loras_presets
+    return wan_model, offloadobj, loras, loras_names, default_loras_choices, default_loras_multis_str, default_prompt, default_lora_preset, loras_presets

-wan_model, offloadobj, loras, loras_names, default_loras_choices, default_loras_multis_str, default_lora_preset, loras_presets = load_models(use_image2video, lora_dir, lora_preselected_preset )
+wan_model, offloadobj, loras, loras_names, default_loras_choices, default_loras_multis_str, default_prompt, default_lora_preset, loras_presets = load_models(use_image2video, lora_dir, lora_preselected_preset )
 gen_in_progress = False

 def get_auto_attention():
@@ -487,13 +499,14 @@ def apply_changes( state,
                     profile_choice,
                     vae_config_choice,
                     default_ui_choice ="t2v",
+                    boost_choice = 1
 ):
     if args.lock_config:
         return
     if gen_in_progress:
         yield "<DIV ALIGN=CENTER>Unable to change config when a generation is in progress</DIV>"
         return
-    global offloadobj, wan_model, loras, loras_names, default_loras_choices, default_loras_multis_str, default_lora_preset, loras_presets
+    global offloadobj, wan_model, loras, loras_names, default_loras_choices, default_loras_multis_str, default_prompt, default_lora_preset, loras_presets
     server_config = {"attention_mode" : attention_choice,
                      "transformer_filename": transformer_choices_t2v[transformer_t2v_choice],
                      "transformer_filename_i2v": transformer_choices_i2v[transformer_i2v_choice], ##########
@@ -502,6 +515,7 @@ def apply_changes( state,
                      "profile" : profile_choice,
                      "vae_config" : vae_config_choice,
                      "default_ui" : default_ui_choice,
+                     "boost" : boost_choice,
                      }

     if Path(server_config_filename).is_file():

@@ -529,7 +543,7 @@ def apply_changes( state,
     state["config_new"] = server_config
     state["config_old"] = old_server_config

-    global attention_mode, profile, compile, transformer_filename_t2v, transformer_filename_i2v, text_encoder_filename, vae_config
+    global attention_mode, profile, compile, transformer_filename_t2v, transformer_filename_i2v, text_encoder_filename, vae_config, boost
     attention_mode = server_config["attention_mode"]
     profile = server_config["profile"]
     compile = server_config["compile"]
@@ -537,8 +551,8 @@ def apply_changes( state,
     transformer_filename_i2v = server_config["transformer_filename_i2v"]
     text_encoder_filename = server_config["text_encoder_filename"]
     vae_config = server_config["vae_config"]
-    if all(change in ["attention_mode", "vae_config", "default_ui"] for change in changes ):
+    boost = server_config["boost"]
+    if all(change in ["attention_mode", "vae_config", "default_ui", "boost"] for change in changes ):
         if "attention_mode" in changes:
             pass

@@ -548,7 +562,7 @@ def apply_changes( state,
         offloadobj = None
         yield "<DIV ALIGN=CENTER>Please wait while the new configuration is being applied</DIV>"

-        wan_model, offloadobj, loras, loras_names, default_loras_choices, default_loras_multis_str, default_lora_preset, loras_presets = load_models(use_image2video, lora_dir, lora_preselected_preset )
+        wan_model, offloadobj, loras, loras_names, default_loras_choices, default_loras_multis_str, default_prompt, default_lora_preset, loras_presets = load_models(use_image2video, lora_dir, lora_preselected_preset )

     yield "<DIV ALIGN=CENTER>The new configuration has been succesfully applied</DIV>"
@@ -727,7 +741,9 @@ def generate_video(
     if len(prompt) ==0:
         return
     prompts = prompt.replace("\r", "").split("\n")
+    prompts = [prompt for prompt in prompts if len(prompt)>0 and not prompt.startswith("#")]
+    if len(prompts) ==0:
+        return
     if use_image2video:
         if image_to_continue is not None:
             if isinstance(image_to_continue, list):

@@ -772,6 +788,9 @@ def generate_video(
             return False
     list_mult_choices_nums = []
     if len(loras_mult_choices) > 0:
+        loras_mult_choices_list = loras_mult_choices.replace("\r", "").split("\n")
+        loras_mult_choices_list = [multi for multi in loras_mult_choices_list if len(multi)>0 and not multi.startswith("#")]
+        loras_mult_choices = " ".join(loras_mult_choices_list)
         list_mult_choices_str = loras_mult_choices.split(" ")
         for i, mult in enumerate(list_mult_choices_str):
             mult = mult.strip()
@@ -805,18 +824,36 @@ def generate_video(
     # VAE Tiling
     device_mem_capacity = torch.cuda.get_device_properties(0).total_memory / 1048576

+    joint_pass = boost ==1
     # TeaCache
     trans = wan_model.model
     trans.enable_teacache = tea_cache > 0
+    if trans.enable_teacache:
+        if use_image2video:
+            if '480p' in transformer_filename_i2v:
+                # teacache_thresholds = [0.13, .19, 0.26]
+                trans.coefficients = [-3.02331670e+02, 2.23948934e+02, -5.25463970e+01, 5.87348440e+00, -2.01973289e-01]
+            elif '720p' in transformer_filename_i2v:
+                teacache_thresholds = [0.18, 0.2 , 0.3]
+                trans.coefficients = [-114.36346466, 65.26524496, -18.82220707, 4.91518089, -0.23412683]
+            else:
+                raise gr.Error("Teacache not supported for this model")
+        else:
+            if '1.3B' in transformer_filename_t2v:
+                # teacache_thresholds= [0.05, 0.07, 0.08]
+                trans.coefficients = [2.39676752e+03, -1.31110545e+03, 2.01331979e+02, -8.29855975e+00, 1.37887774e-01]
+            elif '14B' in transformer_filename_t2v:
+                # teacache_thresholds = [0.14, 0.15, 0.2]
+                trans.coefficients = [-5784.54975374, 5449.50911966, -1811.16591783, 256.27178429, -13.02252404]
+            else:
+                raise gr.Error("Teacache not supported for this model")

     import random
     if seed == None or seed <0:
         seed = random.randint(0, 999999999)

     file_list = []
     state["file_list"] = file_list
-    from einops import rearrange
     save_path = os.path.join(os.getcwd(), "gradio_outputs")
     os.makedirs(save_path, exist_ok=True)
     video_no = 0
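The hunk above is where the new "Boost" setting becomes the `joint_pass` flag: when it is on, the conditional and unconditional branches of classifier-free guidance are evaluated in a single transformer call instead of two sequential ones, which is where the roughly 10% speed gain claimed by the commit comes from, at the cost of keeping a second residual in VRAM. A minimal, self-contained sketch of the idea; the `fake_model` interface is hypothetical and stands in for the actual Wan pipeline:

```python
import torch

def fake_model(latents, t, contexts):
    # stand-in for the diffusion transformer: one noise prediction per (latent, context) pair
    return [torch.randn_like(l) for l in latents]

def cfg_step(model, latent, t, ctx_cond, ctx_uncond, guide_scale, joint_pass=True):
    """Illustrative classifier-free guidance step.

    With joint_pass=True the two context branches are evaluated in one batched
    call (the idea behind the commit's arg_both path); otherwise the model is
    called twice, as before the commit.
    """
    if joint_pass:
        cond, uncond = model([latent, latent], t, [ctx_cond, ctx_uncond])
    else:
        cond = model([latent], t, [ctx_cond])[0]
        uncond = model([latent], t, [ctx_uncond])[0]
    return uncond + guide_scale * (cond - uncond)

latent = torch.randn(16, 4, 32, 32)
out = cfg_step(fake_model, latent, t=0, ctx_cond="a prompt", ctx_uncond="", guide_scale=5.0)
```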
@@ -830,14 +867,12 @@ def generate_video(

         if trans.enable_teacache:
             trans.teacache_counter = 0
-            trans.rel_l1_thresh = tea_cache
-            trans.teacache_start_step = max(math.ceil(tea_cache_start_step_perc*num_inference_steps/100),2)
+            trans.teacache_multiplier = tea_cache
+            trans.teacache_start_step = int(tea_cache_start_step_perc*num_inference_steps/100)
+            trans.num_steps = num_inference_steps
+            trans.teacache_skipped_steps = 0
             trans.previous_residual_uncond = None
-            trans.previous_modulated_input_uncond = None
             trans.previous_residual_cond = None
-            trans.previous_modulated_input_cond= None
-            trans.teacache_cache_device = "cuda" if profile==3 or profile==1 else "cpu"

         video_no += 1
         status = f"Video {video_no}/{total_video}"
@@ -853,7 +888,7 @@ def generate_video(
             if use_image2video:
                 samples = wan_model.generate(
                     prompt,
-                    image_to_continue[ (video_no-1) % len(image_to_continue)],
+                    image_to_continue[ (video_no-1) % len(image_to_continue)].convert('RGB'),
                     frame_num=(video_length // 4)* 4 + 1,
                     max_area=MAX_AREA_CONFIGS[resolution],
                     shift=flow_shift,

@@ -864,7 +899,8 @@ def generate_video(
                     offload_model=False,
                     callback=callback,
                     enable_RIFLEx = enable_RIFLEx,
-                    VAE_tile_size = VAE_tile_size
+                    VAE_tile_size = VAE_tile_size,
+                    joint_pass = joint_pass,
                 )

             else:

@@ -880,7 +916,8 @@ def generate_video(
                     offload_model=False,
                     callback=callback,
                     enable_RIFLEx = enable_RIFLEx,
-                    VAE_tile_size = VAE_tile_size
+                    VAE_tile_size = VAE_tile_size,
+                    joint_pass = joint_pass,
                 )
         except Exception as e:
             gen_in_progress = False
@@ -911,6 +948,7 @@ def generate_video(
             raise gr.Error(f"The generation of the video has encountered an error, please check your terminal for more information. '{s}'")

         if trans.enable_teacache:
+            print(f"Teacache Skipped Steps:{trans.teacache_skipped_steps}/{num_inference_steps}" )
             trans.previous_residual_uncond = None
             trans.previous_residual_cond = None
@@ -957,7 +995,25 @@ def generate_video(

 new_preset_msg = "Enter a Name for a Lora Preset or Choose One Above"

-def save_lset(lset_name, loras_choices, loras_mult_choices):
+def validate_delete_lset(lset_name):
+    if len(lset_name) == 0 or lset_name == new_preset_msg:
+        gr.Info(f"Choose a Preset to delete")
+        return gr.Button(visible= True), gr.Checkbox(visible= True), gr.Button(visible= True), gr.Button(visible= False), gr.Button(visible= False)
+    else:
+        return gr.Button(visible= False), gr.Checkbox(visible= False), gr.Button(visible= False), gr.Button(visible= True), gr.Button(visible= True)
+
+def validate_save_lset(lset_name):
+    if len(lset_name) == 0 or lset_name == new_preset_msg:
+        gr.Info("Please enter a name for the preset")
+        return gr.Button(visible= True), gr.Checkbox(visible= True), gr.Button(visible= True), gr.Button(visible= False), gr.Button(visible= False),gr.Checkbox(visible= False)
+    else:
+        return gr.Button(visible= False), gr.Button(visible= False), gr.Button(visible= False), gr.Button(visible= True), gr.Button(visible= True),gr.Checkbox(visible= True)
+
+def cancel_lset():
+    return gr.Button(visible= True), gr.Button(visible= True), gr.Button(visible= True), gr.Button(visible= False), gr.Button(visible= False), gr.Button(visible= False), gr.Checkbox(visible= False)
+
+def save_lset(lset_name, loras_choices, loras_mult_choices, prompt, save_lset_prompt_cbox):
     global loras_presets

     if len(lset_name) == 0 or lset_name== new_preset_msg:
@@ -968,6 +1024,16 @@ def save_lset(lset_name, loras_choices, loras_mult_choices):

         loras_choices_files = [ Path(loras[int(choice_no)]).parts[-1] for choice_no in loras_choices ]
         lset = {"loras" : loras_choices_files, "loras_mult" : loras_mult_choices}
+        if save_lset_prompt_cbox!=1:
+            prompts = prompt.replace("\r", "").split("\n")
+            prompts = [prompt for prompt in prompts if len(prompt)> 0 and prompt.startswith("#")]
+            prompt = "\n".join(prompts)
+
+        if len(prompt) > 0:
+            lset["prompt"] = prompt
+        lset["full_prompt"] = save_lset_prompt_cbox ==1
+
         lset_name_filename = lset_name + ".lset"
         full_lset_name_filename = os.path.join(lora_dir, lset_name_filename)
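The branch added above is what implements the "Save Prompt Comments Only" choice: unless the full prompt is requested, only the lines starting with # are kept before the preset is written. A tiny, standalone illustration of that filtering:

```python
prompt = "# use the keyword ohnvx to trigger the Lora\nA ohnvx is driving a car"

# keep only the comment lines, as save_lset does when full_prompt is not selected
comments = [line for line in prompt.replace("\r", "").split("\n")
            if line and line.startswith("#")]
print("\n".join(comments))  # -> "# use the keyword ohnvx to trigger the Lora"
```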
@@ -982,7 +1048,7 @@ def save_lset(lset_name, loras_choices, loras_mult_choices):
     lset_choices = [ ( preset, preset) for preset in loras_presets ]
     lset_choices.append( (new_preset_msg, ""))

-    return gr.Dropdown(choices=lset_choices, value= lset_name)
+    return gr.Dropdown(choices=lset_choices, value= lset_name), gr.Button(visible= True), gr.Button(visible= True), gr.Button(visible= True), gr.Button(visible= False), gr.Button(visible= False), gr.Checkbox(visible= False)

 def delete_lset(lset_name):
     global loras_presets

@@ -1000,23 +1066,31 @@ def delete_lset(lset_name):

     lset_choices = [ (preset, preset) for preset in loras_presets]
     lset_choices.append((new_preset_msg, ""))
-    return gr.Dropdown(choices=lset_choices, value= lset_choices[pos][1])
+    return gr.Dropdown(choices=lset_choices, value= lset_choices[pos][1]), gr.Button(visible= True), gr.Button(visible= True), gr.Button(visible= True), gr.Button(visible= False), gr.Checkbox(visible= False)

-def apply_lset(lset_name, loras_choices, loras_mult_choices):
+def apply_lset(lset_name, loras_choices, loras_mult_choices, prompt):

     if len(lset_name) == 0 or lset_name== new_preset_msg:
         gr.Info("Please choose a preset in the list or create one")
     else:
-        loras_choices, loras_mult_choices= extract_preset(lset_name, loras)
+        loras_choices, loras_mult_choices, preset_prompt, full_prompt = extract_preset(lset_name, loras)
+        if full_prompt:
+            prompt = preset_prompt
+        elif len(preset_prompt) > 0:
+            prompts = prompt.replace("\r", "").split("\n")
+            prompts = [prompt for prompt in prompts if len(prompt)>0 and not prompt.startswith("#")]
+            prompt = "\n".join(prompts)
+            prompt = preset_prompt + '\n' + prompt
         gr.Info(f"Lora Preset '{lset_name}' has been applied")

-    return loras_choices, loras_mult_choices
+    return loras_choices, loras_mult_choices, prompt

 def create_demo():

     default_inference_steps = 30

     default_flow_shift = get_default_flow(transformer_filename_i2v if use_image2video else transformer_filename_t2v)

     with gr.Blocks() as demo:
         state = gr.State({})
@@ -1130,6 +1204,16 @@ def create_demo():
                     label="VAE Tiling - reduce the high VRAM requirements for VAE decoding and VAE encoding (if enabled it will be slower)"
                 )

+                boost_choice = gr.Dropdown(
+                    choices=[
+                        # ("Auto (ON if Video longer than 5s)", 0),
+                        ("ON", 1),
+                        ("OFF", 2),
+                    ],
+                    value=boost,
+                    label="Boost: Give a 10% speed speedup without losing quality at the cost of a litle VRAM (up to 1GB for max frames and resolution)"
+                )
+
                 profile_choice = gr.Dropdown(
                     choices=[
                         ("HighRAM_HighVRAM, profile 1: at least 48 GB of RAM and 24 GB of VRAM, the fastest for short videos a RTX 3090 / RTX 4090", 1),
@@ -1161,16 +1245,12 @@ def create_demo():
                 video_to_continue = gr.Video(label= "Video to continue", visible= use_image2video and False) #######
                 if args.multiple_images:
                     image_to_continue = gr.Gallery(
-                            label="Images as a starting point for new videos", type ="numpy", #file_types= "image",
+                            label="Images as a starting point for new videos", type ="pil", #file_types= "image",
                             columns=[3], rows=[1], object_fit="contain", height="auto", selected_index=0, interactive= True, visible=use_image2video)
                 else:
                     image_to_continue = gr.Image(label= "Image as a starting point for a new video", visible=use_image2video)

-                if use_image2video:
-                    prompt = gr.Textbox(label="Prompts (multiple prompts separated by carriage returns will generate multiple videos)", value="Several giant wooly mammoths approach treading through a snowy meadow, their long wooly fur lightly blows in the wind as they walk, snow covered trees and dramatic snow capped mountains in the distance, mid afternoon light with wispy clouds and a sun high in the distance creates a warm glow, the low camera view is stunning capturing the large furry mammal with beautiful photography, depth of field.", lines=3)
-                else:
-                    prompt = gr.Textbox(label="Prompts (multiple prompts separated by carriage returns will generate multiple videos)", value="A large orange octopus is seen resting on the bottom of the ocean floor, blending in with the sandy and rocky terrain. Its tentacles are spread out around its body, and its eyes are closed. The octopus is unaware of a king crab that is crawling towards it from behind a rock, its claws raised and ready to attack. The crab is brown and spiny, with long legs and antennae. The scene is captured from a wide angle, showing the vastness and depth of the ocean. The water is clear and blue, with rays of sunlight filtering through. The shot is sharp and crisp, with a high dynamic range. The octopus and the crab are in focus, while the background is slightly blurred, creating a depth of field effect.", lines=3)
+                prompt = gr.Textbox(label="Prompts (multiple prompts separated by carriage returns will generate multiple videos, lines that starts with # are ignored)", value=default_prompt, lines=3)

                 with gr.Row():
                     if use_image2video:
@@ -1223,9 +1303,21 @@ def create_demo():
                 # with gr.Column():
                 with gr.Row(height=17):
                     apply_lset_btn = gr.Button("Apply Lora Preset", size="sm", min_width= 1)
+                    # save_lset_prompt_cbox = gr.Checkbox(label="Save Prompt Comments in Preset", value=False, visible= False)
+                    save_lset_prompt_drop= gr.Dropdown(
+                        choices=[
+                            ("Save Prompt Comments Only", 0),
+                            ("Save Full Prompt", 1)
+                        ], show_label= False, container=False, visible= False
+                    )
+
                 with gr.Row(height=17):
+                    confirm_save_lset_btn = gr.Button("Go Ahead Save it !", size="sm", min_width= 1, visible=False)
+                    confirm_delete_lset_btn = gr.Button("Go Ahead Delete it !", size="sm", min_width= 1, visible=False)
                     save_lset_btn = gr.Button("Save", size="sm", min_width= 1)
                     delete_lset_btn = gr.Button("Delete", size="sm", min_width= 1)
+                    cancel_lset_btn = gr.Button("Don't do it !", size="sm", min_width= 1 , visible=False)

                 loras_choices = gr.Dropdown(

@@ -1237,7 +1329,7 @@ def create_demo():
                     visible= len(loras)>0,
                     label="Activated Loras"
                 )
-                loras_mult_choices = gr.Textbox(label="Loras Multipliers (1.0 by default) separated by space characters or carriage returns", value=default_loras_multis_str, visible= len(loras)>0 )
+                loras_mult_choices = gr.Textbox(label="Loras Multipliers (1.0 by default) separated by space characters or carriage returns, line that starts with # are ignored", value=default_loras_multis_str, visible= len(loras)>0 )

                 show_advanced = gr.Checkbox(label="Show Advanced Options", value=False)
                 with gr.Row(visible=False) as advanced_row:
@@ -1250,18 +1342,23 @@ def create_demo():
                     guidance_scale = gr.Slider(1.0, 20.0, value=5.0, step=0.5, label="Guidance Scale", visible=True)
                     embedded_guidance_scale = gr.Slider(1.0, 20.0, value=6.0, step=0.5, label="Embedded Guidance Scale", visible=False)
                     flow_shift = gr.Slider(0.0, 25.0, value= default_flow_shift, step=0.1, label="Shift Scale")
-                    tea_cache_setting = gr.Dropdown(
-                        choices=[
-                            ("Tea Cache Disabled", 0),
-                            ("0.03 (around x1.6 speed up)", 0.03),
-                            ("0.05 (around x2 speed up)", 0.05),
-                            ("0.10 (around x3 speed up)", 0.1),
-                        ],
-                        value=default_tea_cache,
-                        visible=True,
-                        label="Tea Cache Threshold to Skip Steps (the higher, the more steps are skipped but the lower the quality of the video (Tea Cache Consumes VRAM)"
-                    )
-                    tea_cache_start_step_perc = gr.Slider(2, 100, value=20, step=1, label="Tea Cache starting moment in percentage of generation (the later, the higher the quality but also the lower the speed gain)")
+                    with gr.Row():
+                        gr.Markdown("Tea Cache accelerates by skipping intelligently some steps, the more steps are skipped the lower the quality of the video (Tea Cache consumes also VRAM)")
+                    with gr.Row():
+                        tea_cache_setting = gr.Dropdown(
+                            choices=[
+                                ("Tea Cache Disabled", 0),
+                                ("around x1.5 speed up", 1.5),
+                                ("around x1.75 speed up", 1.75),
+                                ("around x2 speed up", 2.0),
+                                ("around x2.25 speed up", 2.25),
+                                ("around x2.5 speed up", 2.5),
+                            ],
+                            value=default_tea_cache,
+                            visible=True,
+                            label="Tea Cache Global Acceleration"
+                        )
+                        tea_cache_start_step_perc = gr.Slider(0, 100, value=0, step=1, label="Tea Cache starting moment in % of generation")

                     RIFLEx_setting = gr.Dropdown(
                         choices=[
@@ -1283,9 +1380,13 @@ def create_demo():
                 generate_btn = gr.Button("Generate")
                 abort_btn = gr.Button("Abort")

-        save_lset_btn.click(save_lset, inputs=[lset_name, loras_choices, loras_mult_choices], outputs=[lset_name])
-        delete_lset_btn.click(delete_lset, inputs=[lset_name], outputs=[lset_name])
-        apply_lset_btn.click(apply_lset, inputs=[lset_name,loras_choices, loras_mult_choices], outputs=[loras_choices, loras_mult_choices])
+        save_lset_btn.click(validate_save_lset, inputs=[lset_name], outputs=[apply_lset_btn, delete_lset_btn, save_lset_btn,confirm_save_lset_btn, cancel_lset_btn, save_lset_prompt_drop])
+        confirm_save_lset_btn.click(save_lset, inputs=[lset_name, loras_choices, loras_mult_choices, prompt, save_lset_prompt_drop], outputs=[lset_name, apply_lset_btn, delete_lset_btn, save_lset_btn, confirm_save_lset_btn, cancel_lset_btn, save_lset_prompt_drop])
+        delete_lset_btn.click(validate_delete_lset, inputs=[lset_name], outputs=[apply_lset_btn, delete_lset_btn, save_lset_btn,confirm_delete_lset_btn, cancel_lset_btn ])
+        confirm_delete_lset_btn.click(delete_lset, inputs=[lset_name], outputs=[lset_name, apply_lset_btn, delete_lset_btn, save_lset_btn,confirm_delete_lset_btn, cancel_lset_btn ])
+        cancel_lset_btn.click(cancel_lset, inputs=[], outputs=[apply_lset_btn, delete_lset_btn, save_lset_btn, confirm_delete_lset_btn,confirm_save_lset_btn, cancel_lset_btn,save_lset_prompt_drop ])
+
+        apply_lset_btn.click(apply_lset, inputs=[lset_name,loras_choices, loras_mult_choices, prompt], outputs=[loras_choices, loras_mult_choices, prompt])

         gen_status.change(refresh_gallery, inputs = [state], outputs = output )

@@ -1335,6 +1436,7 @@ def create_demo():
                 profile_choice,
                 vae_config_choice,
                 default_ui_choice,
+                boost_choice,
             ],
             outputs= msg
         ).then(
@@ -146,7 +146,7 @@ class WanI2V:
                  callback = None,
                  enable_RIFLEx = False,
                  VAE_tile_size= 0,
+                 joint_pass = False,
     ):
         r"""
         Generates video frames from input image and text prompt using diffusion process.

@@ -310,9 +310,22 @@ class WanI2V:
             'pipeline' : self
         }

+        arg_both= {
+            'context': [context[0]],
+            'context2': context_null,
+            'clip_fea': clip_context,
+            'seq_len': max_seq_len,
+            'y': [y],
+            'freqs' : freqs,
+            'pipeline' : self
+        }

         if offload_model:
             torch.cuda.empty_cache()

+        if self.model.enable_teacache:
+            self.model.compute_teacache_threshold(self.model.teacache_start_step, timesteps, self.model.teacache_multiplier)

         # self.model.to(self.device)
         if callback != None:
             callback(-1, None)

@@ -323,17 +336,22 @@ class WanI2V:
             timestep = [t]

             timestep = torch.stack(timestep).to(self.device)
-            noise_pred_cond = self.model(
-                latent_model_input, t=timestep, current_step=i, is_uncond = False, **arg_c)[0]
-            if self._interrupt:
-                return None
-            if offload_model:
-                torch.cuda.empty_cache()
-            noise_pred_uncond = self.model(
-                latent_model_input, t=timestep, current_step=i, is_uncond = True, **arg_null)[0]
-            if self._interrupt:
-                return None
+            if joint_pass:
+                noise_pred_cond, noise_pred_uncond = self.model(
+                    latent_model_input, t=timestep, current_step=i, **arg_both)
+                if self._interrupt:
+                    return None
+            else:
+                noise_pred_cond = self.model(
+                    latent_model_input, t=timestep, current_step=i, is_uncond = False, **arg_c)[0]
+                if self._interrupt:
+                    return None
+                if offload_model:
+                    torch.cuda.empty_cache()
+                noise_pred_uncond = self.model(
+                    latent_model_input, t=timestep, current_step=i, is_uncond = True, **arg_null)[0]
+                if self._interrupt:
+                    return None
             del latent_model_input
             if offload_model:
                 torch.cuda.empty_cache()
@@ -667,7 +667,43 @@ class WanModel(ModelMixin, ConfigMixin):

         return (torch.cat([c1,c2,c3],dim=1).to(device) , torch.cat([s1,s2,s3],dim=1).to(device))

+    def compute_teacache_threshold(self, start_step, timesteps = None, speed_factor =0):
+        rescale_func = np.poly1d(self.coefficients)
+        e_list = []
+        for t in timesteps:
+            t = torch.stack([t])
+            e_list.append(self.time_embedding( sinusoidal_embedding_1d(self.freq_dim, t)))
+
+        best_threshold = 0.01
+        best_diff = 1000
+        target_nb_steps= int(len(timesteps) / speed_factor)
+        threshold = 0.01
+        while threshold <= 0.6:
+            accumulated_rel_l1_distance =0
+            nb_steps = 0
+            diff = 1000
+            for i, t in enumerate(timesteps):
+                skip = False
+                if not (i<=start_step or i== len(timesteps)):
+                    accumulated_rel_l1_distance += rescale_func(((e_list[i]-previous_modulated_input).abs().mean() / previous_modulated_input.abs().mean()).cpu().item())
+                    if accumulated_rel_l1_distance < threshold:
+                        skip = True
+                    else:
+                        accumulated_rel_l1_distance = 0
+                previous_modulated_input = e_list[i]
+                if not skip:
+                    nb_steps += 1
+            diff = abs(target_nb_steps - nb_steps)
+            if diff < best_diff:
+                best_threshold = threshold
+                best_diff = diff
+            elif diff > best_diff:
+                break
+            threshold += 0.01
+        self.rel_l1_thresh = best_threshold
+        print(f"Tea Cache, best threshold found:{best_threshold} with gain x{len(timesteps)/(len(timesteps) - best_diff):0.1f} for a target of x{speed_factor}")
+        return best_threshold
+
     def forward(
         self,
         x,
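For readers unfamiliar with TeaCache, the threshold found by the search above is later used in forward() as a skip rule: the change of the time-modulated input between steps is rescaled by a model-specific polynomial and accumulated, and while the accumulated value stays under the threshold the transformer pass is skipped and the cached residual is reused. A condensed, illustrative sketch of that decision rule (not the actual class, and the state dictionary is an assumption used to keep the example self-contained):

```python
import numpy as np
import torch

def should_run_transformer(e, state, coefficients, rel_l1_thresh):
    """Return True when the full transformer pass must be recomputed,
    False when the cached residual can be reused (TeaCache skip)."""
    rescale = np.poly1d(coefficients)
    prev = state.get("previous_modulated_input")
    if prev is None:
        run = True
        state["accumulated"] = 0.0
    else:
        # rescaled relative L1 distance between consecutive time embeddings
        state["accumulated"] = state.get("accumulated", 0.0) + rescale(
            ((e - prev).abs().mean() / prev.abs().mean()).cpu().item())
        run = state["accumulated"] >= rel_l1_thresh
        if run:
            state["accumulated"] = 0.0
    state["previous_modulated_input"] = e
    return run

state = {}
for step in range(5):
    e = torch.randn(1, 256)  # stand-in for the time embedding at this step
    print(step, should_run_transformer(e, state, [1.0, 0.0], rel_l1_thresh=0.5))
```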
@@ -679,6 +715,7 @@ class WanModel(ModelMixin, ConfigMixin):
         freqs = None,
         pipeline = None,
         current_step = 0,
+        context2 = None,
         is_uncond=False
     ):
         r"""

@@ -722,10 +759,13 @@ class WanModel(ModelMixin, ConfigMixin):
         x = [u.flatten(2).transpose(1, 2) for u in x]
         seq_lens = torch.tensor([u.size(1) for u in x], dtype=torch.long)
         assert seq_lens.max() <= seq_len
-        x = torch.cat([
-            torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))],
-                      dim=1) for u in x
-        ])
+        if len(x)==1 and seq_len == x[0].size(1):
+            x = x[0]
+        else:
+            x = torch.cat([
+                torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))],
+                          dim=1) for u in x
+            ])

         # time embeddings
         e = self.time_embedding(
@@ -740,82 +780,105 @@ class WanModel(ModelMixin, ConfigMixin):
                     [u, u.new_zeros(self.text_len - u.size(0), u.size(1))])
                 for u in context
             ]))
+        if context2!=None:
+            context2 = self.text_embedding(
+                torch.stack([
+                    torch.cat(
+                        [u, u.new_zeros(self.text_len - u.size(0), u.size(1))])
+                    for u in context2
+                ]))
+
         if clip_fea is not None:
             context_clip = self.img_emb(clip_fea) # bs x 257 x dim
             context = torch.concat([context_clip, context], dim=1)
-        # deepbeepmeep optimization of kijai's implementation (https://github.com/kijai/ComfyUI-WanVideoWrapper/) of teacache (https://github.com/ali-vilab/TeaCache)
+            if context2 != None:
+                context2 = torch.concat([context_clip, context2], dim=1)
+
+        joint_pass = context2 != None
+        if joint_pass:
+            x_list = [x, x.clone()]
+            context_list = [context, context2]
+            is_uncond = False
+        else:
+            x_list = [x]
+            context_list = [context]
+        del x
         should_calc = True
-        if self.enable_teacache and current_step >= self.teacache_start_step:
-            if current_step == self.teacache_start_step:
-                self.accumulated_rel_l1_distance_cond = 0
-                self.accumulated_rel_l1_distance_uncond = 0
-                self.teacache_skipped_cond_steps = 0
-                self.teacache_skipped_uncond_steps = 0
-            else:
-                prev_input = self.previous_modulated_input_uncond if is_uncond else self.previous_modulated_input_cond
-                acc_distance_attr = 'accumulated_rel_l1_distance_uncond' if is_uncond else 'accumulated_rel_l1_distance_cond'
-                temb_relative_l1 = relative_l1_distance(prev_input, e0)
-                setattr(self, acc_distance_attr, getattr(self, acc_distance_attr) + temb_relative_l1)
-                if getattr(self, acc_distance_attr) < self.rel_l1_thresh:
-                    should_calc = False
-                    self.teacache_counter += 1
-                else:
-                    should_calc = True
-                    setattr(self, acc_distance_attr, 0)
-            if is_uncond:
-                self.previous_modulated_input_uncond = e0.clone()
-                if should_calc:
-                    self.previous_residual_uncond = None
-                else:
-                    x += self.previous_residual_uncond
-                    self.teacache_skipped_cond_steps += 1
-                    # print(f"Skipped uncond:{self.teacache_skipped_cond_steps}/{current_step}" )
-            else:
-                self.previous_modulated_input_cond = e0.clone()
-                if should_calc:
-                    self.previous_residual_cond = None
-                else:
-                    x += self.previous_residual_cond
-                    self.teacache_skipped_uncond_steps += 1
-                    # print(f"Skipped uncond:{self.teacache_skipped_uncond_steps}/{current_step}" )
-
-        if should_calc:
-            if self.enable_teacache:
-                ori_hidden_states = x.clone()
+        if self.enable_teacache:
+            if is_uncond:
+                should_calc = self.should_calc
+            else:
+                if current_step <= self.teacache_start_step or current_step == self.num_steps-1:
+                    should_calc = True
+                    self.accumulated_rel_l1_distance = 0
+                else:
+                    rescale_func = np.poly1d(self.coefficients)
+                    self.accumulated_rel_l1_distance += rescale_func(((e-self.previous_modulated_input).abs().mean() / self.previous_modulated_input.abs().mean()).cpu().item())
+                    if self.accumulated_rel_l1_distance < self.rel_l1_thresh:
+                        should_calc = False
+                        self.teacache_skipped_steps += 1
+                        # print(f"Teacache Skipped Step:{self.teacache_skipped_steps}/{current_step}" )
+                    else:
+                        should_calc = True
+                        self.accumulated_rel_l1_distance = 0
+                self.previous_modulated_input = e
+                self.should_calc = should_calc
+
+        if not should_calc:
+            for i, x in enumerate(x_list):
+                x += self.previous_residual_uncond if i==1 or is_uncond else self.previous_residual_cond
+        else:
+            if self.enable_teacache:
+                if joint_pass or is_uncond:
+                    self.previous_residual_uncond = None
+                if joint_pass or not is_uncond:
+                    self.previous_residual_cond = None
+                ori_hidden_states = x_list[0].clone()
             # arguments
             kwargs = dict(
-                e=e0,
+                # e=e0,
                 seq_lens=seq_lens,
                 grid_sizes=grid_sizes,
                 freqs=freqs,
-                context=context,
+                # context=context,
                 context_lens=context_lens)

             for block in self.blocks:
                 if pipeline._interrupt:
-                    return [None]
-                x = block(x, **kwargs)
+                    if joint_pass:
+                        return None, None
+                    else:
+                        return [None]
+                for i, (x, context) in enumerate(zip(x_list, context_list)):
+                    x_list[i] = block(x, context = context, e= e0, **kwargs)
+                del x

             if self.enable_teacache:
-                residual = ori_hidden_states # just to have a readable code
-                torch.sub(x, ori_hidden_states, out=residual)
-                if is_uncond:
-                    self.previous_residual_uncond = residual
-                else:
-                    self.previous_residual_cond = residual
-                del residual, ori_hidden_states
+                if joint_pass:
+                    self.previous_residual_cond = torch.sub(x_list[0], ori_hidden_states)
+                    self.previous_residual_uncond = ori_hidden_states
+                    torch.sub(x_list[1], ori_hidden_states, out=self.previous_residual_uncond)
+                else:
+                    residual = ori_hidden_states # just to have a readable code
+                    torch.sub(x_list[0], ori_hidden_states, out=residual)
+                    if i==1 or is_uncond:
+                        self.previous_residual_uncond = residual
+                    else:
+                        self.previous_residual_cond = residual
+                    residual, ori_hidden_states = None, None

-        # head
-        x = self.head(x, e)
-
-        # unpatchify
-        x = self.unpatchify(x, grid_sizes)
-        return [u.float() for u in x]
+        for i, x in enumerate(x_list):
+            # head
+            x = self.head(x, e)
+
+            # unpatchify
+            x_list[i] = self.unpatchify(x, grid_sizes)
+        del x
+
+        if joint_pass:
+            return x_list[0][0], x_list[1][0]
+        else:
+            return [u.float() for u in x_list[0]]

     def unpatchify(self, x, grid_sizes):
         r"""
@@ -131,7 +131,8 @@ class WanT2V:
                  offload_model=True,
                  callback = None,
                  enable_RIFLEx = None,
-                 VAE_tile_size = 0
+                 VAE_tile_size = 0,
+                 joint_pass = False,
     ):
         r"""
         Generates video frames from text prompt using diffusion process.

@@ -240,8 +241,10 @@ class WanT2V:
         freqs = get_rotary_pos_embed(frame_num, size[1], size[0], enable_RIFLEx= enable_RIFLEx)
         arg_c = {'context': context, 'seq_len': seq_len, 'freqs': freqs, 'pipeline': self}
         arg_null = {'context': context_null, 'seq_len': seq_len, 'freqs': freqs, 'pipeline': self}
+        arg_both = {'context': context, 'context2': context_null, 'seq_len': seq_len, 'freqs': freqs, 'pipeline': self}
+
+        if self.model.enable_teacache:
+            self.model.compute_teacache_threshold(self.model.teacache_start_step, timesteps, self.model.teacache_multiplier)
         if callback != None:
             callback(-1, None)
         for i, t in enumerate(tqdm(timesteps)):

@@ -251,14 +254,20 @@ class WanT2V:
             timestep = torch.stack(timestep)

             # self.model.to(self.device)
-            noise_pred_cond = self.model(
-                latent_model_input, t=timestep,current_step=i, is_uncond = False, **arg_c)[0]
-            if self._interrupt:
-                return None
-            noise_pred_uncond = self.model(
-                latent_model_input, t=timestep,current_step=i, is_uncond = True, **arg_null)[0]
-            if self._interrupt:
-                return None
+            if joint_pass:
+                noise_pred_cond, noise_pred_uncond = self.model(
+                    latent_model_input, t=timestep,current_step=i, **arg_both)
+                if self._interrupt:
+                    return None
+            else:
+                noise_pred_cond = self.model(
+                    latent_model_input, t=timestep,current_step=i, is_uncond = False, **arg_c)[0]
+                if self._interrupt:
+                    return None
+                noise_pred_uncond = self.model(
+                    latent_model_input, t=timestep,current_step=i, is_uncond = True, **arg_null)[0]
+                if self._interrupt:
+                    return None

             del latent_model_input
             noise_pred = noise_pred_uncond + guide_scale * (