From f8d9edeb5005924c6a436901c1f1cd4dc09b9f5d Mon Sep 17 00:00:00 2001 From: DeepBeepMeep Date: Mon, 10 Mar 2025 23:26:42 +0100 Subject: [PATCH] Added 10% boost, improved Loras and Teacache --- README.md | 35 ++++-- gradio_server.py | 254 ++++++++++++++++++++++++++++++------------- wan/image2video.py | 42 +++++-- wan/modules/model.py | 177 ++++++++++++++++++++---------- wan/text2video.py | 29 +++-- 5 files changed, 372 insertions(+), 165 deletions(-) diff --git a/README.md b/README.md index c086297..438de90 100644 --- a/README.md +++ b/README.md @@ -19,10 +19,11 @@ In this repository, we present **Wan2.1**, a comprehensive and open suite of vid ## 🔥 Latest News!! -* Mar 03, 2025: 👋 Wan2.1GP v1.4: Fix Pytorch compilation, now it is really 20% faster when activated -* Mar 03, 2025: 👋 Wan2.1GP v1.3: Support for Image to Video with multiples images for different images / prompts combinations (requires *--multiple-images* switch), and added command line *--preload x* to preload in VRAM x MB of the main diffusion model if you find there is too much unused VRAM and you want to (slightly) accelerate the generation process. +* Mar 10, 2025: 👋 Wan2.1GP v1.5: Official Teachache support + Smart Teacache (find automatically best parameters for a requested speed multiplier), 10% speed boost with no quality loss, improved lora presets (they can now include prompts and comments to guide the user) +* Mar 07, 2025: 👋 Wan2.1GP v1.4: Fix Pytorch compilation, now it is really 20% faster when activated +* Mar 04, 2025: 👋 Wan2.1GP v1.3: Support for Image to Video with multiples images for different images / prompts combinations (requires *--multiple-images* switch), and added command line *--preload x* to preload in VRAM x MB of the main diffusion model if you find there is too much unused VRAM and you want to (slightly) accelerate the generation process. If you upgrade you will need to do a 'pip install -r requirements.txt' again. -* Mar 03, 2025: 👋 Wan2.1GP v1.2: Implemented tiling on VAE encoding and decoding. No more VRAM peaks at the beginning and at the end +* Mar 04, 2025: 👋 Wan2.1GP v1.2: Implemented tiling on VAE encoding and decoding. No more VRAM peaks at the beginning and at the end * Mar 03, 2025: 👋 Wan2.1GP v1.1: added Tea Cache support for faster generations: optimization of kijai's implementation (https://github.com/kijai/ComfyUI-WanVideoWrapper/) of teacache (https://github.com/ali-vilab/TeaCache) * Mar 02, 2025: 👋 Wan2.1GP by DeepBeepMeep v1 brings: - Support for all Wan including the Image to Video model @@ -152,15 +153,29 @@ python gradio_server.py --attention sdpa ### Loras support --- Ready to be used but theoretical as no lora for Wan have been released as of today. --- -Every lora stored in the subfoler 'loras' will be automatically loaded. You will be then able to activate / desactive any of them when running the application. +Every lora stored in the subfoler 'loras' for t2v and 'loras_i2v' will be automatically loaded. You will be then able to activate / desactive any of them when running the application by selecting them in the area below "Activated Loras" . -For each activated Lora, you may specify a *multiplier* that is one float number that corresponds to its weight (default is 1.0), alternatively you may specify a list of floats multipliers separated by a "," that gives the evolution of this Lora's multiplier over the steps. For instance let's assume there are 30 denoising steps and the multiplier is *0.9,0.8,0.7* then for the steps ranges 0-9, 10-19 and 20-29 the Lora multiplier will be respectively 0.9, 0.8 and 0.7. +For each activated Lora, you may specify a *multiplier* that is one float number that corresponds to its weight (default is 1.0) .The multipliers for each Lora shoud be separated by a space character or a carriage return. For instance:\ +*1.2 0.8* means that the first lora will have a 1.2 multiplier and the second one will have 0.8. -You can edit, save or delete Loras presets (combinations of loras with their corresponding multipliers) directly from the gradio interface. Each preset, is a file with ".lset" extension stored in the loras directory and can be shared with other users +Alternatively for each Lora's multiplier you may specify a list of float numbers multipliers separated by a "," (no space) that gives the evolution of this Lora's multiplier over the steps. For instance let's assume there are 30 denoising steps and the multiplier is *0.9,0.8,0.7* then for the steps ranges 0-9, 10-19 and 20-29 the Lora multiplier will be respectively 0.9, 0.8 and 0.7. -Then you can pre activate loras corresponding to a preset when launching the gradio server: +If multiple Loras are defined, remember that each multiplier associated to different Loras should be separated by a space or a carriage return, so we can specify the evolution of multipliers for multiple Loras. For instance for two Loras (press Shift Return to force a carriage return): + +``` +0.9,0.8,0.7 +1.2,1.1,1.0 +``` +You can edit, save or delete Loras presets (combinations of loras with their corresponding multipliers) directly from the gradio Web interface. These presets will save the *comment* part of the prompt that should contain some instructions how to use the corresponding the loras (for instance by specifying a trigger word or providing an example).A comment in the prompt is a line that starts that a #. It will be ignored by the video generator. For instance: + +``` +# use they keyword ohnvx to trigger the Lora* +A ohnvx is driving a car +``` +Each preset, is a file with ".lset" extension stored in the loras directory and can be shared with other users + +Last but not least you can pre activate Loras corresponding and prefill a prompt (comments only or full prompt) by specifying a preset when launching the gradio server: ```bash python gradio_server.py --lora-preset mylorapreset.lset # where 'mylorapreset.lset' is a preset stored in the 'loras' folder ``` @@ -180,11 +195,11 @@ You will find prebuilt Loras on https://civitai.com/ or you will be able to buil --open-browser : open automatically Browser when launching Gradio Server\ --lock-config : prevent modifying the video engine configuration from the interface\ --share : create a shareable URL on huggingface so that your server can be accessed remotely\ ---multiple-images : Images as a starting point for new videos\ +--multiple-images : allow the users to choose multiple images as different starting points for new videos\ --compile : turn on pytorch compilation\ --attention mode: force attention mode among, sdpa, flash, sage, sage2\ --profile no : default (4) : no of profile between 1 and 5\ ---preload no : number in Megabytes to preload partially the diffusion model in VRAM , may offer slight speed gains especially on older hardware +--preload no : number in Megabytes to preload partially the diffusion model in VRAM , may offer slight speed gains especially on older hardware. Works only with profile 2 and 4. ### Profiles (for power users only) You can choose between 5 profiles, but two are really relevant here : diff --git a/gradio_server.py b/gradio_server.py index 8055f52..662ea50 100644 --- a/gradio_server.py +++ b/gradio_server.py @@ -57,18 +57,18 @@ def _parse_args(): ) - parser.add_argument( - "--lora-dir-i2v", - type=str, - default="loras_i2v", - help="Path to a directory that contains Loras for i2v" - ) + # parser.add_argument( + # "--lora-dir-i2v", + # type=str, + # default="loras_i2v", + # help="Path to a directory that contains Loras for i2v" + # ) parser.add_argument( "--lora-dir", type=str, - default="loras", + default="", help="Path to a directory that contains Loras" ) @@ -80,12 +80,12 @@ def _parse_args(): help="Lora preset to preload" ) - parser.add_argument( - "--lora-preset-i2v", - type=str, - default="", - help="Lora preset to preload for i2v" - ) + # parser.add_argument( + # "--lora-preset-i2v", + # type=str, + # default="", + # help="Lora preset to preload for i2v" + # ) parser.add_argument( "--profile", @@ -198,6 +198,7 @@ if not Path(server_config_filename).is_file(): "text_encoder_filename" : text_encoder_choices[1], "compile" : "", "default_ui": "t2v", + "boost" : 1, "vae_config": 0, "profile" : profile_type.LowRAM_LowVRAM } @@ -223,6 +224,7 @@ if len(args.attention)> 0: profile = force_profile_no if force_profile_no >=0 else server_config["profile"] compile = server_config.get("compile", "") +boost = server_config.get("boost", 1) vae_config = server_config.get("vae_config", 0) if len(args.vae_config) > 0: vae_config = int(args.vae_config) @@ -234,13 +236,14 @@ if args.t2v: if args.i2v: use_image2video = True -if use_image2video: - lora_dir =args.lora_dir_i2v - lora_preselected_preset = args.lora_preset_i2v -else: - lora_dir =args.lora_dir - lora_preselected_preset = args.lora_preset - +# if use_image2video: +# lora_dir =args.lora_dir_i2v +# lora_preselected_preset = args.lora_preset_i2v +# else: +lora_dir =args.lora_dir +if len(lora_dir) ==0: + lora_dir = "loras_i2v" if use_image2video else "loras" +lora_preselected_preset = args.lora_preset default_tea_cache = 0 # if args.fast : #or args.fastest # transformer_filename_t2v = transformer_choices_t2v[2] @@ -321,8 +324,16 @@ def extract_preset(lset_name, loras): raise gr.Error(f"Unable to apply Lora preset '{lset_name} because the following Loras files are missing: {missing_loras}") loras_mult_choices = lset["loras_mult"] - return loras_choices, loras_mult_choices + prompt = lset.get("prompt", "") + return loras_choices, loras_mult_choices, prompt, lset.get("full_prompt", False) +def get_default_prompt(i2v): + if i2v: + return "Several giant wooly mammoths approach treading through a snowy meadow, their long wooly fur lightly blows in the wind as they walk, snow covered trees and dramatic snow capped mountains in the distance, mid afternoon light with wispy clouds and a sun high in the distance creates a warm glow, the low camera view is stunning capturing the large furry mammal with beautiful photography, depth of field." + else: + return "A large orange octopus is seen resting on the bottom of the ocean floor, blending in with the sandy and rocky terrain. Its tentacles are spread out around its body, and its eyes are closed. The octopus is unaware of a king crab that is crawling towards it from behind a rock, its claws raised and ready to attack. The crab is brown and spiny, with long legs and antennae. The scene is captured from a wide angle, showing the vastness and depth of the ocean. The water is clear and blue, with rays of sunlight filtering through. The shot is sharp and crisp, with a high dynamic range. The octopus and the crab are in focus, while the background is slightly blurred, creating a depth of field effect." + + def setup_loras(pipe, lora_dir, lora_preselected_preset, split_linear_modules_map = None): loras =[] loras_names = [] @@ -337,7 +348,7 @@ def setup_loras(pipe, lora_dir, lora_preselected_preset, split_linear_modules_m raise Exception("--lora-dir should be a path to a directory that contains Loras") default_lora_preset = "" - + default_prompt = "" if lora_dir != None: import glob dir_loras = glob.glob( os.path.join(lora_dir , "*.sft") ) + glob.glob( os.path.join(lora_dir , "*.safetensors") ) @@ -350,15 +361,16 @@ def setup_loras(pipe, lora_dir, lora_preselected_preset, split_linear_modules_m if len(loras) > 0: loras_names = [ Path(lora).stem for lora in loras ] - offload.load_loras_into_model(pipe["transformer"], loras, activate_all_loras=False, split_linear_modules_map = split_linear_modules_map) #lora_multiplier, + offload.load_loras_into_model(pipe["transformer"], loras, activate_all_loras=False, split_linear_modules_map = split_linear_modules_map) #lora_multiplier, if len(lora_preselected_preset) > 0: if not os.path.isfile(os.path.join(lora_dir, lora_preselected_preset + ".lset")): raise Exception(f"Unknown preset '{lora_preselected_preset}'") default_lora_preset = lora_preselected_preset - default_loras_choices, default_loras_multis_str= extract_preset(default_lora_preset, loras) - - return loras, loras_names, default_loras_choices, default_loras_multis_str, default_lora_preset, loras_presets + default_loras_choices, default_loras_multis_str, default_prompt, _ = extract_preset(default_lora_preset, loras) + if len(default_prompt) == 0: + default_prompt = get_default_prompt(use_image2video) + return loras, loras_names, default_loras_choices, default_loras_multis_str, default_prompt, default_lora_preset, loras_presets def load_t2v_model(model_filename, value): @@ -439,13 +451,13 @@ def load_models(i2v, lora_dir, lora_preselected_preset ): kwargs["budgets"] = { "*" : "70%" } - loras, loras_names, default_loras_choices, default_loras_multis_str, default_lora_preset, loras_presets = setup_loras(pipe, lora_dir, lora_preselected_preset, None) + loras, loras_names, default_loras_choices, default_loras_multis_str, default_prompt, default_lora_preset, loras_presets = setup_loras(pipe, lora_dir, lora_preselected_preset, None) offloadobj = offload.profile(pipe, profile_no= profile, compile = compile, quantizeTransformer = quantizeTransformer, **kwargs) - return wan_model, offloadobj, loras, loras_names, default_loras_choices, default_loras_multis_str, default_lora_preset, loras_presets + return wan_model, offloadobj, loras, loras_names, default_loras_choices, default_loras_multis_str, default_prompt, default_lora_preset, loras_presets -wan_model, offloadobj, loras, loras_names, default_loras_choices, default_loras_multis_str, default_lora_preset, loras_presets = load_models(use_image2video, lora_dir, lora_preselected_preset ) +wan_model, offloadobj, loras, loras_names, default_loras_choices, default_loras_multis_str, default_prompt, default_lora_preset, loras_presets = load_models(use_image2video, lora_dir, lora_preselected_preset ) gen_in_progress = False def get_auto_attention(): @@ -487,13 +499,14 @@ def apply_changes( state, profile_choice, vae_config_choice, default_ui_choice ="t2v", + boost_choice = 1 ): if args.lock_config: return if gen_in_progress: yield "
Unable to change config when a generation is in progress
" return - global offloadobj, wan_model, loras, loras_names, default_loras_choices, default_loras_multis_str, default_lora_preset, loras_presets + global offloadobj, wan_model, loras, loras_names, default_loras_choices, default_loras_multis_str, default_prompt, default_lora_preset, loras_presets server_config = {"attention_mode" : attention_choice, "transformer_filename": transformer_choices_t2v[transformer_t2v_choice], "transformer_filename_i2v": transformer_choices_i2v[transformer_i2v_choice], ########## @@ -502,6 +515,7 @@ def apply_changes( state, "profile" : profile_choice, "vae_config" : vae_config_choice, "default_ui" : default_ui_choice, + "boost" : boost_choice, } if Path(server_config_filename).is_file(): @@ -529,7 +543,7 @@ def apply_changes( state, state["config_new"] = server_config state["config_old"] = old_server_config - global attention_mode, profile, compile, transformer_filename_t2v, transformer_filename_i2v, text_encoder_filename, vae_config + global attention_mode, profile, compile, transformer_filename_t2v, transformer_filename_i2v, text_encoder_filename, vae_config, boost attention_mode = server_config["attention_mode"] profile = server_config["profile"] compile = server_config["compile"] @@ -537,8 +551,8 @@ def apply_changes( state, transformer_filename_i2v = server_config["transformer_filename_i2v"] text_encoder_filename = server_config["text_encoder_filename"] vae_config = server_config["vae_config"] - - if all(change in ["attention_mode", "vae_config", "default_ui"] for change in changes ): + boost = server_config["boost"] + if all(change in ["attention_mode", "vae_config", "default_ui", "boost"] for change in changes ): if "attention_mode" in changes: pass @@ -548,7 +562,7 @@ def apply_changes( state, offloadobj = None yield "
Please wait while the new configuration is being applied
" - wan_model, offloadobj, loras, loras_names, default_loras_choices, default_loras_multis_str, default_lora_preset, loras_presets = load_models(use_image2video, lora_dir, lora_preselected_preset ) + wan_model, offloadobj, loras, loras_names, default_loras_choices, default_loras_multis_str, default_prompt, default_lora_preset, loras_presets = load_models(use_image2video, lora_dir, lora_preselected_preset ) yield "
The new configuration has been succesfully applied
" @@ -727,7 +741,9 @@ def generate_video( if len(prompt) ==0: return prompts = prompt.replace("\r", "").split("\n") - + prompts = [prompt for prompt in prompts if len(prompt)>0 and not prompt.startswith("#")] + if len(prompts) ==0: + return if use_image2video: if image_to_continue is not None: if isinstance(image_to_continue, list): @@ -772,6 +788,9 @@ def generate_video( return False list_mult_choices_nums = [] if len(loras_mult_choices) > 0: + loras_mult_choices_list = loras_mult_choices.replace("\r", "").split("\n") + loras_mult_choices_list = [multi for multi in loras_mult_choices_list if len(multi)>0 and not multi.startswith("#")] + loras_mult_choices = " ".join(loras_mult_choices_list) list_mult_choices_str = loras_mult_choices.split(" ") for i, mult in enumerate(list_mult_choices_str): mult = mult.strip() @@ -805,18 +824,36 @@ def generate_video( # VAE Tiling device_mem_capacity = torch.cuda.get_device_properties(0).total_memory / 1048576 - + joint_pass = boost ==1 # TeaCache trans = wan_model.model trans.enable_teacache = tea_cache > 0 - + if trans.enable_teacache: + if use_image2video: + if '480p' in transformer_filename_i2v: + # teacache_thresholds = [0.13, .19, 0.26] + trans.coefficients = [-3.02331670e+02, 2.23948934e+02, -5.25463970e+01, 5.87348440e+00, -2.01973289e-01] + elif '720p' in transformer_filename_i2v: + teacache_thresholds = [0.18, 0.2 , 0.3] + trans.coefficients = [-114.36346466, 65.26524496, -18.82220707, 4.91518089, -0.23412683] + else: + raise gr.Error("Teacache not supported for this model") + else: + if '1.3B' in transformer_filename_t2v: + # teacache_thresholds= [0.05, 0.07, 0.08] + trans.coefficients = [2.39676752e+03, -1.31110545e+03, 2.01331979e+02, -8.29855975e+00, 1.37887774e-01] + elif '14B' in transformer_filename_t2v: + # teacache_thresholds = [0.14, 0.15, 0.2] + trans.coefficients = [-5784.54975374, 5449.50911966, -1811.16591783, 256.27178429, -13.02252404] + else: + raise gr.Error("Teacache not supported for this model") + import random if seed == None or seed <0: seed = random.randint(0, 999999999) file_list = [] state["file_list"] = file_list - from einops import rearrange save_path = os.path.join(os.getcwd(), "gradio_outputs") os.makedirs(save_path, exist_ok=True) video_no = 0 @@ -830,14 +867,12 @@ def generate_video( if trans.enable_teacache: trans.teacache_counter = 0 - trans.rel_l1_thresh = tea_cache - trans.teacache_start_step = max(math.ceil(tea_cache_start_step_perc*num_inference_steps/100),2) + trans.teacache_multiplier = tea_cache + trans.teacache_start_step = int(tea_cache_start_step_perc*num_inference_steps/100) + trans.num_steps = num_inference_steps + trans.teacache_skipped_steps = 0 trans.previous_residual_uncond = None - trans.previous_modulated_input_uncond = None trans.previous_residual_cond = None - trans.previous_modulated_input_cond= None - - trans.teacache_cache_device = "cuda" if profile==3 or profile==1 else "cpu" video_no += 1 status = f"Video {video_no}/{total_video}" @@ -853,7 +888,7 @@ def generate_video( if use_image2video: samples = wan_model.generate( prompt, - image_to_continue[ (video_no-1) % len(image_to_continue)], + image_to_continue[ (video_no-1) % len(image_to_continue)].convert('RGB'), frame_num=(video_length // 4)* 4 + 1, max_area=MAX_AREA_CONFIGS[resolution], shift=flow_shift, @@ -864,7 +899,8 @@ def generate_video( offload_model=False, callback=callback, enable_RIFLEx = enable_RIFLEx, - VAE_tile_size = VAE_tile_size + VAE_tile_size = VAE_tile_size, + joint_pass = joint_pass, ) else: @@ -880,7 +916,8 @@ def generate_video( offload_model=False, callback=callback, enable_RIFLEx = enable_RIFLEx, - VAE_tile_size = VAE_tile_size + VAE_tile_size = VAE_tile_size, + joint_pass = joint_pass, ) except Exception as e: gen_in_progress = False @@ -911,6 +948,7 @@ def generate_video( raise gr.Error(f"The generation of the video has encountered an error, please check your terminal for more information. '{s}'") if trans.enable_teacache: + print(f"Teacache Skipped Steps:{trans.teacache_skipped_steps}/{num_inference_steps}" ) trans.previous_residual_uncond = None trans.previous_residual_cond = None @@ -957,7 +995,25 @@ def generate_video( new_preset_msg = "Enter a Name for a Lora Preset or Choose One Above" -def save_lset(lset_name, loras_choices, loras_mult_choices): + +def validate_delete_lset(lset_name): + if len(lset_name) == 0 or lset_name == new_preset_msg: + gr.Info(f"Choose a Preset to delete") + return gr.Button(visible= True), gr.Checkbox(visible= True), gr.Button(visible= True), gr.Button(visible= False), gr.Button(visible= False) + else: + return gr.Button(visible= False), gr.Checkbox(visible= False), gr.Button(visible= False), gr.Button(visible= True), gr.Button(visible= True) + +def validate_save_lset(lset_name): + if len(lset_name) == 0 or lset_name == new_preset_msg: + gr.Info("Please enter a name for the preset") + return gr.Button(visible= True), gr.Checkbox(visible= True), gr.Button(visible= True), gr.Button(visible= False), gr.Button(visible= False),gr.Checkbox(visible= False) + else: + return gr.Button(visible= False), gr.Button(visible= False), gr.Button(visible= False), gr.Button(visible= True), gr.Button(visible= True),gr.Checkbox(visible= True) + +def cancel_lset(): + return gr.Button(visible= True), gr.Button(visible= True), gr.Button(visible= True), gr.Button(visible= False), gr.Button(visible= False), gr.Button(visible= False), gr.Checkbox(visible= False) + +def save_lset(lset_name, loras_choices, loras_mult_choices, prompt, save_lset_prompt_cbox): global loras_presets if len(lset_name) == 0 or lset_name== new_preset_msg: @@ -968,6 +1024,16 @@ def save_lset(lset_name, loras_choices, loras_mult_choices): loras_choices_files = [ Path(loras[int(choice_no)]).parts[-1] for choice_no in loras_choices ] lset = {"loras" : loras_choices_files, "loras_mult" : loras_mult_choices} + if save_lset_prompt_cbox!=1: + prompts = prompt.replace("\r", "").split("\n") + prompts = [prompt for prompt in prompts if len(prompt)> 0 and prompt.startswith("#")] + prompt = "\n".join(prompts) + + if len(prompt) > 0: + lset["prompt"] = prompt + lset["full_prompt"] = save_lset_prompt_cbox ==1 + + lset_name_filename = lset_name + ".lset" full_lset_name_filename = os.path.join(lora_dir, lset_name_filename) @@ -982,7 +1048,7 @@ def save_lset(lset_name, loras_choices, loras_mult_choices): lset_choices = [ ( preset, preset) for preset in loras_presets ] lset_choices.append( (new_preset_msg, "")) - return gr.Dropdown(choices=lset_choices, value= lset_name) + return gr.Dropdown(choices=lset_choices, value= lset_name), gr.Button(visible= True), gr.Button(visible= True), gr.Button(visible= True), gr.Button(visible= False), gr.Button(visible= False), gr.Checkbox(visible= False) def delete_lset(lset_name): global loras_presets @@ -1000,23 +1066,31 @@ def delete_lset(lset_name): lset_choices = [ (preset, preset) for preset in loras_presets] lset_choices.append((new_preset_msg, "")) - return gr.Dropdown(choices=lset_choices, value= lset_choices[pos][1]) + return gr.Dropdown(choices=lset_choices, value= lset_choices[pos][1]), gr.Button(visible= True), gr.Button(visible= True), gr.Button(visible= True), gr.Button(visible= False), gr.Checkbox(visible= False) -def apply_lset(lset_name, loras_choices, loras_mult_choices): + +def apply_lset(lset_name, loras_choices, loras_mult_choices, prompt): if len(lset_name) == 0 or lset_name== new_preset_msg: gr.Info("Please choose a preset in the list or create one") else: - loras_choices, loras_mult_choices= extract_preset(lset_name, loras) + loras_choices, loras_mult_choices, preset_prompt, full_prompt = extract_preset(lset_name, loras) + if full_prompt: + prompt = preset_prompt + elif len(preset_prompt) > 0: + prompts = prompt.replace("\r", "").split("\n") + prompts = [prompt for prompt in prompts if len(prompt)>0 and not prompt.startswith("#")] + prompt = "\n".join(prompts) + prompt = preset_prompt + '\n' + prompt gr.Info(f"Lora Preset '{lset_name}' has been applied") - return loras_choices, loras_mult_choices + return loras_choices, loras_mult_choices, prompt def create_demo(): default_inference_steps = 30 + default_flow_shift = get_default_flow(transformer_filename_i2v if use_image2video else transformer_filename_t2v) - with gr.Blocks() as demo: state = gr.State({}) @@ -1130,6 +1204,16 @@ def create_demo(): label="VAE Tiling - reduce the high VRAM requirements for VAE decoding and VAE encoding (if enabled it will be slower)" ) + boost_choice = gr.Dropdown( + choices=[ + # ("Auto (ON if Video longer than 5s)", 0), + ("ON", 1), + ("OFF", 2), + ], + value=boost, + label="Boost: Give a 10% speed speedup without losing quality at the cost of a litle VRAM (up to 1GB for max frames and resolution)" + ) + profile_choice = gr.Dropdown( choices=[ ("HighRAM_HighVRAM, profile 1: at least 48 GB of RAM and 24 GB of VRAM, the fastest for short videos a RTX 3090 / RTX 4090", 1), @@ -1161,16 +1245,12 @@ def create_demo(): video_to_continue = gr.Video(label= "Video to continue", visible= use_image2video and False) ####### if args.multiple_images: image_to_continue = gr.Gallery( - label="Images as a starting point for new videos", type ="numpy", #file_types= "image", + label="Images as a starting point for new videos", type ="pil", #file_types= "image", columns=[3], rows=[1], object_fit="contain", height="auto", selected_index=0, interactive= True, visible=use_image2video) else: image_to_continue = gr.Image(label= "Image as a starting point for a new video", visible=use_image2video) - if use_image2video: - prompt = gr.Textbox(label="Prompts (multiple prompts separated by carriage returns will generate multiple videos)", value="Several giant wooly mammoths approach treading through a snowy meadow, their long wooly fur lightly blows in the wind as they walk, snow covered trees and dramatic snow capped mountains in the distance, mid afternoon light with wispy clouds and a sun high in the distance creates a warm glow, the low camera view is stunning capturing the large furry mammal with beautiful photography, depth of field.", lines=3) - else: - prompt = gr.Textbox(label="Prompts (multiple prompts separated by carriage returns will generate multiple videos)", value="A large orange octopus is seen resting on the bottom of the ocean floor, blending in with the sandy and rocky terrain. Its tentacles are spread out around its body, and its eyes are closed. The octopus is unaware of a king crab that is crawling towards it from behind a rock, its claws raised and ready to attack. The crab is brown and spiny, with long legs and antennae. The scene is captured from a wide angle, showing the vastness and depth of the ocean. The water is clear and blue, with rays of sunlight filtering through. The shot is sharp and crisp, with a high dynamic range. The octopus and the crab are in focus, while the background is slightly blurred, creating a depth of field effect.", lines=3) - + prompt = gr.Textbox(label="Prompts (multiple prompts separated by carriage returns will generate multiple videos, lines that starts with # are ignored)", value=default_prompt, lines=3) with gr.Row(): if use_image2video: @@ -1223,9 +1303,21 @@ def create_demo(): # with gr.Column(): with gr.Row(height=17): apply_lset_btn = gr.Button("Apply Lora Preset", size="sm", min_width= 1) + # save_lset_prompt_cbox = gr.Checkbox(label="Save Prompt Comments in Preset", value=False, visible= False) + save_lset_prompt_drop= gr.Dropdown( + choices=[ + ("Save Prompt Comments Only", 0), + ("Save Full Prompt", 1) + ], show_label= False, container=False, visible= False + ) + + with gr.Row(height=17): + confirm_save_lset_btn = gr.Button("Go Ahead Save it !", size="sm", min_width= 1, visible=False) + confirm_delete_lset_btn = gr.Button("Go Ahead Delete it !", size="sm", min_width= 1, visible=False) save_lset_btn = gr.Button("Save", size="sm", min_width= 1) delete_lset_btn = gr.Button("Delete", size="sm", min_width= 1) + cancel_lset_btn = gr.Button("Don't do it !", size="sm", min_width= 1 , visible=False) loras_choices = gr.Dropdown( @@ -1237,7 +1329,7 @@ def create_demo(): visible= len(loras)>0, label="Activated Loras" ) - loras_mult_choices = gr.Textbox(label="Loras Multipliers (1.0 by default) separated by space characters or carriage returns", value=default_loras_multis_str, visible= len(loras)>0 ) + loras_mult_choices = gr.Textbox(label="Loras Multipliers (1.0 by default) separated by space characters or carriage returns, line that starts with # are ignored", value=default_loras_multis_str, visible= len(loras)>0 ) show_advanced = gr.Checkbox(label="Show Advanced Options", value=False) with gr.Row(visible=False) as advanced_row: @@ -1250,18 +1342,23 @@ def create_demo(): guidance_scale = gr.Slider(1.0, 20.0, value=5.0, step=0.5, label="Guidance Scale", visible=True) embedded_guidance_scale = gr.Slider(1.0, 20.0, value=6.0, step=0.5, label="Embedded Guidance Scale", visible=False) flow_shift = gr.Slider(0.0, 25.0, value= default_flow_shift, step=0.1, label="Shift Scale") - tea_cache_setting = gr.Dropdown( - choices=[ - ("Tea Cache Disabled", 0), - ("0.03 (around x1.6 speed up)", 0.03), - ("0.05 (around x2 speed up)", 0.05), - ("0.10 (around x3 speed up)", 0.1), - ], - value=default_tea_cache, - visible=True, - label="Tea Cache Threshold to Skip Steps (the higher, the more steps are skipped but the lower the quality of the video (Tea Cache Consumes VRAM)" - ) - tea_cache_start_step_perc = gr.Slider(2, 100, value=20, step=1, label="Tea Cache starting moment in percentage of generation (the later, the higher the quality but also the lower the speed gain)") + with gr.Row(): + gr.Markdown("Tea Cache accelerates by skipping intelligently some steps, the more steps are skipped the lower the quality of the video (Tea Cache consumes also VRAM)") + with gr.Row(): + tea_cache_setting = gr.Dropdown( + choices=[ + ("Tea Cache Disabled", 0), + ("around x1.5 speed up", 1.5), + ("around x1.75 speed up", 1.75), + ("around x2 speed up", 2.0), + ("around x2.25 speed up", 2.25), + ("around x2.5 speed up", 2.5), + ], + value=default_tea_cache, + visible=True, + label="Tea Cache Global Acceleration" + ) + tea_cache_start_step_perc = gr.Slider(0, 100, value=0, step=1, label="Tea Cache starting moment in % of generation") RIFLEx_setting = gr.Dropdown( choices=[ @@ -1283,9 +1380,13 @@ def create_demo(): generate_btn = gr.Button("Generate") abort_btn = gr.Button("Abort") - save_lset_btn.click(save_lset, inputs=[lset_name, loras_choices, loras_mult_choices], outputs=[lset_name]) - delete_lset_btn.click(delete_lset, inputs=[lset_name], outputs=[lset_name]) - apply_lset_btn.click(apply_lset, inputs=[lset_name,loras_choices, loras_mult_choices], outputs=[loras_choices, loras_mult_choices]) + save_lset_btn.click(validate_save_lset, inputs=[lset_name], outputs=[apply_lset_btn, delete_lset_btn, save_lset_btn,confirm_save_lset_btn, cancel_lset_btn, save_lset_prompt_drop]) + confirm_save_lset_btn.click(save_lset, inputs=[lset_name, loras_choices, loras_mult_choices, prompt, save_lset_prompt_drop], outputs=[lset_name, apply_lset_btn, delete_lset_btn, save_lset_btn, confirm_save_lset_btn, cancel_lset_btn, save_lset_prompt_drop]) + delete_lset_btn.click(validate_delete_lset, inputs=[lset_name], outputs=[apply_lset_btn, delete_lset_btn, save_lset_btn,confirm_delete_lset_btn, cancel_lset_btn ]) + confirm_delete_lset_btn.click(delete_lset, inputs=[lset_name], outputs=[lset_name, apply_lset_btn, delete_lset_btn, save_lset_btn,confirm_delete_lset_btn, cancel_lset_btn ]) + cancel_lset_btn.click(cancel_lset, inputs=[], outputs=[apply_lset_btn, delete_lset_btn, save_lset_btn, confirm_delete_lset_btn,confirm_save_lset_btn, cancel_lset_btn,save_lset_prompt_drop ]) + + apply_lset_btn.click(apply_lset, inputs=[lset_name,loras_choices, loras_mult_choices, prompt], outputs=[loras_choices, loras_mult_choices, prompt]) gen_status.change(refresh_gallery, inputs = [state], outputs = output ) @@ -1335,6 +1436,7 @@ def create_demo(): profile_choice, vae_config_choice, default_ui_choice, + boost_choice, ], outputs= msg ).then( diff --git a/wan/image2video.py b/wan/image2video.py index 5428c68..7ffcb78 100644 --- a/wan/image2video.py +++ b/wan/image2video.py @@ -146,7 +146,7 @@ class WanI2V: callback = None, enable_RIFLEx = False, VAE_tile_size= 0, - + joint_pass = False, ): r""" Generates video frames from input image and text prompt using diffusion process. @@ -310,9 +310,22 @@ class WanI2V: 'pipeline' : self } + arg_both= { + 'context': [context[0]], + 'context2': context_null, + 'clip_fea': clip_context, + 'seq_len': max_seq_len, + 'y': [y], + 'freqs' : freqs, + 'pipeline' : self + } + if offload_model: torch.cuda.empty_cache() + if self.model.enable_teacache: + self.model.compute_teacache_threshold(self.model.teacache_start_step, timesteps, self.model.teacache_multiplier) + # self.model.to(self.device) if callback != None: callback(-1, None) @@ -323,17 +336,22 @@ class WanI2V: timestep = [t] timestep = torch.stack(timestep).to(self.device) - - noise_pred_cond = self.model( - latent_model_input, t=timestep, current_step=i, is_uncond = False, **arg_c)[0] - if self._interrupt: - return None - if offload_model: - torch.cuda.empty_cache() - noise_pred_uncond = self.model( - latent_model_input, t=timestep, current_step=i, is_uncond = True, **arg_null)[0] - if self._interrupt: - return None + if joint_pass: + noise_pred_cond, noise_pred_uncond = self.model( + latent_model_input, t=timestep, current_step=i, **arg_both) + if self._interrupt: + return None + else: + noise_pred_cond = self.model( + latent_model_input, t=timestep, current_step=i, is_uncond = False, **arg_c)[0] + if self._interrupt: + return None + if offload_model: + torch.cuda.empty_cache() + noise_pred_uncond = self.model( + latent_model_input, t=timestep, current_step=i, is_uncond = True, **arg_null)[0] + if self._interrupt: + return None del latent_model_input if offload_model: torch.cuda.empty_cache() diff --git a/wan/modules/model.py b/wan/modules/model.py index 11b57ab..59f31f8 100644 --- a/wan/modules/model.py +++ b/wan/modules/model.py @@ -667,7 +667,43 @@ class WanModel(ModelMixin, ConfigMixin): return (torch.cat([c1,c2,c3],dim=1).to(device) , torch.cat([s1,s2,s3],dim=1).to(device)) - + def compute_teacache_threshold(self, start_step, timesteps = None, speed_factor =0): + rescale_func = np.poly1d(self.coefficients) + e_list = [] + for t in timesteps: + t = torch.stack([t]) + e_list.append(self.time_embedding( sinusoidal_embedding_1d(self.freq_dim, t))) + + best_threshold = 0.01 + best_diff = 1000 + target_nb_steps= int(len(timesteps) / speed_factor) + threshold = 0.01 + while threshold <= 0.6: + accumulated_rel_l1_distance =0 + nb_steps = 0 + diff = 1000 + for i, t in enumerate(timesteps): + skip = False + if not (i<=start_step or i== len(timesteps)): + accumulated_rel_l1_distance += rescale_func(((e_list[i]-previous_modulated_input).abs().mean() / previous_modulated_input.abs().mean()).cpu().item()) + if accumulated_rel_l1_distance < threshold: + skip = True + else: + accumulated_rel_l1_distance = 0 + previous_modulated_input = e_list[i] + if not skip: + nb_steps += 1 + diff = abs(target_nb_steps - nb_steps) + if diff < best_diff: + best_threshold = threshold + best_diff = diff + elif diff > best_diff: + break + threshold += 0.01 + self.rel_l1_thresh = best_threshold + print(f"Tea Cache, best threshold found:{best_threshold} with gain x{len(timesteps)/(len(timesteps) - best_diff):0.1f} for a target of x{speed_factor}") + return best_threshold + def forward( self, x, @@ -679,6 +715,7 @@ class WanModel(ModelMixin, ConfigMixin): freqs = None, pipeline = None, current_step = 0, + context2 = None, is_uncond=False ): r""" @@ -722,10 +759,13 @@ class WanModel(ModelMixin, ConfigMixin): x = [u.flatten(2).transpose(1, 2) for u in x] seq_lens = torch.tensor([u.size(1) for u in x], dtype=torch.long) assert seq_lens.max() <= seq_len - x = torch.cat([ - torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))], - dim=1) for u in x - ]) + if len(x)==1 and seq_len == x[0].size(1): + x = x[0] + else: + x = torch.cat([ + torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))], + dim=1) for u in x + ]) # time embeddings e = self.time_embedding( @@ -740,82 +780,105 @@ class WanModel(ModelMixin, ConfigMixin): [u, u.new_zeros(self.text_len - u.size(0), u.size(1))]) for u in context ])) + if context2!=None: + context2 = self.text_embedding( + torch.stack([ + torch.cat( + [u, u.new_zeros(self.text_len - u.size(0), u.size(1))]) + for u in context2 + ])) if clip_fea is not None: context_clip = self.img_emb(clip_fea) # bs x 257 x dim context = torch.concat([context_clip, context], dim=1) - # deepbeepmeep optimization of kijai's implementation (https://github.com/kijai/ComfyUI-WanVideoWrapper/) of teacache (https://github.com/ali-vilab/TeaCache) + if context2 != None: + context2 = torch.concat([context_clip, context2], dim=1) + + joint_pass = context2 != None + if joint_pass: + x_list = [x, x.clone()] + context_list = [context, context2] + is_uncond = False + else: + x_list = [x] + context_list = [context] + del x should_calc = True - if self.enable_teacache and current_step >= self.teacache_start_step: - if current_step == self.teacache_start_step: - self.accumulated_rel_l1_distance_cond = 0 - self.accumulated_rel_l1_distance_uncond = 0 - self.teacache_skipped_cond_steps = 0 - self.teacache_skipped_uncond_steps = 0 + if self.enable_teacache: + if is_uncond: + should_calc = self.should_calc else: - prev_input = self.previous_modulated_input_uncond if is_uncond else self.previous_modulated_input_cond - acc_distance_attr = 'accumulated_rel_l1_distance_uncond' if is_uncond else 'accumulated_rel_l1_distance_cond' - - temb_relative_l1 = relative_l1_distance(prev_input, e0) - setattr(self, acc_distance_attr, getattr(self, acc_distance_attr) + temb_relative_l1) - - if getattr(self, acc_distance_attr) < self.rel_l1_thresh: - should_calc = False - self.teacache_counter += 1 - else: + if current_step <= self.teacache_start_step or current_step == self.num_steps-1: should_calc = True - setattr(self, acc_distance_attr, 0) - - if is_uncond: - self.previous_modulated_input_uncond = e0.clone() - if should_calc: - self.previous_residual_uncond = None + self.accumulated_rel_l1_distance = 0 else: - x += self.previous_residual_uncond - self.teacache_skipped_cond_steps += 1 - # print(f"Skipped uncond:{self.teacache_skipped_cond_steps}/{current_step}" ) - else: - self.previous_modulated_input_cond = e0.clone() - if should_calc: - self.previous_residual_cond = None - else: - x += self.previous_residual_cond - self.teacache_skipped_uncond_steps += 1 - # print(f"Skipped uncond:{self.teacache_skipped_uncond_steps}/{current_step}" ) + rescale_func = np.poly1d(self.coefficients) + self.accumulated_rel_l1_distance += rescale_func(((e-self.previous_modulated_input).abs().mean() / self.previous_modulated_input.abs().mean()).cpu().item()) + if self.accumulated_rel_l1_distance < self.rel_l1_thresh: + should_calc = False + self.teacache_skipped_steps += 1 + # print(f"Teacache Skipped Step:{self.teacache_skipped_steps}/{current_step}" ) + else: + should_calc = True + self.accumulated_rel_l1_distance = 0 + self.previous_modulated_input = e + self.should_calc = should_calc - if should_calc: + if not should_calc: + for i, x in enumerate(x_list): + x += self.previous_residual_uncond if i==1 or is_uncond else self.previous_residual_cond + else: if self.enable_teacache: - ori_hidden_states = x.clone() + if joint_pass or is_uncond: + self.previous_residual_uncond = None + if joint_pass or not is_uncond: + self.previous_residual_cond = None + ori_hidden_states = x_list[0].clone() # arguments kwargs = dict( - e=e0, + # e=e0, seq_lens=seq_lens, grid_sizes=grid_sizes, freqs=freqs, - context=context, + # context=context, context_lens=context_lens) for block in self.blocks: if pipeline._interrupt: - return [None] - - x = block(x, **kwargs) + if joint_pass: + return None, None + else: + return [None] + for i, (x, context) in enumerate(zip(x_list, context_list)): + x_list[i] = block(x, context = context, e= e0, **kwargs) + del x if self.enable_teacache: - residual = ori_hidden_states # just to have a readable code - torch.sub(x, ori_hidden_states, out=residual) - if is_uncond: - self.previous_residual_uncond = residual + if joint_pass: + self.previous_residual_cond = torch.sub(x_list[0], ori_hidden_states) + self.previous_residual_uncond = ori_hidden_states + torch.sub(x_list[1], ori_hidden_states, out=self.previous_residual_uncond) else: - self.previous_residual_cond = residual - del residual, ori_hidden_states + residual = ori_hidden_states # just to have a readable code + torch.sub(x_list[0], ori_hidden_states, out=residual) + if i==1 or is_uncond: + self.previous_residual_uncond = residual + else: + self.previous_residual_cond = residual + residual, ori_hidden_states = None, None - # head - x = self.head(x, e) + for i, x in enumerate(x_list): + # head + x = self.head(x, e) - # unpatchify - x = self.unpatchify(x, grid_sizes) - return [u.float() for u in x] + # unpatchify + x_list[i] = self.unpatchify(x, grid_sizes) + del x + + if joint_pass: + return x_list[0][0], x_list[1][0] + else: + return [u.float() for u in x_list[0]] def unpatchify(self, x, grid_sizes): r""" diff --git a/wan/text2video.py b/wan/text2video.py index 0b3bc2d..88eff60 100644 --- a/wan/text2video.py +++ b/wan/text2video.py @@ -131,7 +131,8 @@ class WanT2V: offload_model=True, callback = None, enable_RIFLEx = None, - VAE_tile_size = 0 + VAE_tile_size = 0, + joint_pass = False, ): r""" Generates video frames from text prompt using diffusion process. @@ -240,8 +241,10 @@ class WanT2V: freqs = get_rotary_pos_embed(frame_num, size[1], size[0], enable_RIFLEx= enable_RIFLEx) arg_c = {'context': context, 'seq_len': seq_len, 'freqs': freqs, 'pipeline': self} arg_null = {'context': context_null, 'seq_len': seq_len, 'freqs': freqs, 'pipeline': self} + arg_both = {'context': context, 'context2': context_null, 'seq_len': seq_len, 'freqs': freqs, 'pipeline': self} - + if self.model.enable_teacache: + self.model.compute_teacache_threshold(self.model.teacache_start_step, timesteps, self.model.teacache_multiplier) if callback != None: callback(-1, None) for i, t in enumerate(tqdm(timesteps)): @@ -251,14 +254,20 @@ class WanT2V: timestep = torch.stack(timestep) # self.model.to(self.device) - noise_pred_cond = self.model( - latent_model_input, t=timestep,current_step=i, is_uncond = False, **arg_c)[0] - if self._interrupt: - return None - noise_pred_uncond = self.model( - latent_model_input, t=timestep,current_step=i, is_uncond = True, **arg_null)[0] - if self._interrupt: - return None + if joint_pass: + noise_pred_cond, noise_pred_uncond = self.model( + latent_model_input, t=timestep,current_step=i, **arg_both) + if self._interrupt: + return None + else: + noise_pred_cond = self.model( + latent_model_input, t=timestep,current_step=i, is_uncond = False, **arg_c)[0] + if self._interrupt: + return None + noise_pred_uncond = self.model( + latent_model_input, t=timestep,current_step=i, is_uncond = True, **arg_null)[0] + if self._interrupt: + return None del latent_model_input noise_pred = noise_pred_uncond + guide_scale * (