From 936db03daa55a849a8ff73ee7072de86edbabe81 Mon Sep 17 00:00:00 2001
From: Jimmy <39@🇺🇸.com>
Date: Thu, 13 Mar 2025 09:33:09 -0400
Subject: [PATCH] Add skip layer guidance

Skip layer guidance (SLG) skips a configurable set of transformer blocks
during the unconditional pass of classifier-free guidance, within a
configurable window of the sampling steps. Exposed on the CLI via
--slg-layers, --slg-start, and --slg-end.
---
 i2v_inference.py     | 13 ++++++++++
 wan/image2video.py   | 56 ++++++++++++++++++++++++++++++--------------
 wan/modules/model.py |  7 ++++--
 3 files changed, 56 insertions(+), 20 deletions(-)

diff --git a/i2v_inference.py b/i2v_inference.py
index 7a9fe5a..9b5a5cc 100644
--- a/i2v_inference.py
+++ b/i2v_inference.py
@@ -413,6 +413,9 @@ def parse_args():
     parser.add_argument("--teacache", type=float, default=0.25, help="TeaCache multiplier, e.g. 0.5, 2.0, etc.")
     parser.add_argument("--teacache-start", type=float, default=0.1, help="Teacache start step percentage [0..100]")
     parser.add_argument("--seed", type=int, default=-1, help="Random seed. -1 means random each time.")
+    parser.add_argument("--slg-layers", type=str, default=None, help="Comma-separated indices of transformer blocks to skip for skip layer guidance")
+    parser.add_argument("--slg-start", type=float, default=0.0, help="Fraction of sampling steps at which SLG starts [0..1]")
+    parser.add_argument("--slg-end", type=float, default=1.0, help="Fraction of sampling steps at which SLG ends [0..1]")
 
     # LoRA usage
     parser.add_argument("--loras-choices", type=str, default="", help="Comma-separated list of chosen LoRA indices or preset names to load. Usually you only use the preset.")
@@ -540,6 +543,12 @@ def main():
         except:
             raise ValueError(f"Invalid resolution: '{resolution_str}'")
 
+    # Parse slg_layers from a comma-separated string into a list of ints (or None if not provided)
+    if args.slg_layers:
+        slg_list = [int(x) for x in args.slg_layers.split(",")]
+    else:
+        slg_list = None
+
     # Additional checks (from your original code).
     if "480p" in args.transformer_file:
         # Then we cannot exceed certain area for 480p model
@@ -628,6 +637,10 @@ def main():
             callback=None,  # or define your own callback if you want
             enable_RIFLEx=enable_riflex,
             VAE_tile_size=VAE_tile_size,
+            joint_pass=slg_list is None,  # joint pass is a small speed improvement, but incompatible with SLG
+            slg_layers=slg_list,
+            slg_start=args.slg_start,
+            slg_end=args.slg_end,
         )
     except Exception as e:
         offloadobj.unload_all()

diff --git a/wan/image2video.py b/wan/image2video.py
index 7ffcb78..bb2f6f4 100644
--- a/wan/image2video.py
+++ b/wan/image2video.py
@@ -132,22 +132,25 @@ class WanI2V:
         self.sample_neg_prompt = config.sample_neg_prompt
 
     def generate(self,
-                 input_prompt,
-                 img,
-                 max_area=720 * 1280,
-                 frame_num=81,
-                 shift=5.0,
-                 sample_solver='unipc',
-                 sampling_steps=40,
-                 guide_scale=5.0,
-                 n_prompt="",
-                 seed=-1,
-                 offload_model=True,
-                 callback = None,
-                 enable_RIFLEx = False,
-                 VAE_tile_size= 0,
-                 joint_pass = False,
-                 ):
+        input_prompt,
+        img,
+        max_area=720 * 1280,
+        frame_num=81,
+        shift=5.0,
+        sample_solver='unipc',
+        sampling_steps=40,
+        guide_scale=5.0,
+        n_prompt="",
+        seed=-1,
+        offload_model=True,
+        callback = None,
+        enable_RIFLEx = False,
+        VAE_tile_size= 0,
+        joint_pass = False,
+        slg_layers = None,
+        slg_start = 0.0,
+        slg_end = 1.0,
+    ):
         r"""
         Generates video frames from input image and text prompt using diffusion process.
 
@@ -331,25 +334,42 @@ class WanI2V:
             callback(-1, None)
 
             for i, t in enumerate(tqdm(timesteps)):
+                slg_layers_local = None
+                if int(slg_start * sampling_steps) <= i < int(slg_end * sampling_steps):
+                    slg_layers_local = slg_layers
+
                 offload.set_step_no_for_lora(i)
                 latent_model_input = [latent.to(self.device)]
                 timestep = [t]
 
                 timestep = torch.stack(timestep).to(self.device)
                 if joint_pass:
+                    if slg_layers is not None:
+                        raise ValueError('Cannot use SLG together with joint pass')
                     noise_pred_cond, noise_pred_uncond = self.model(
                         latent_model_input, t=timestep, current_step=i, **arg_both)
                     if self._interrupt:
                         return None
                 else:
                     noise_pred_cond = self.model(
-                        latent_model_input, t=timestep, current_step=i, is_uncond = False, **arg_c)[0]
+                        latent_model_input,
+                        t=timestep,
+                        current_step=i,
+                        is_uncond=False,
+                        **arg_c,
+                    )[0]
                     if self._interrupt:
                         return None
                     if offload_model:
                         torch.cuda.empty_cache()
                     noise_pred_uncond = self.model(
-                        latent_model_input, t=timestep, current_step=i, is_uncond = True, **arg_null)[0]
+                        latent_model_input,
+                        t=timestep,
+                        current_step=i,
+                        is_uncond=True,
+                        slg_layers=slg_layers_local,
+                        **arg_null,
+                    )[0]
                     if self._interrupt:
                         return None
                 del latent_model_input

diff --git a/wan/modules/model.py b/wan/modules/model.py
index 59f31f8..82b71f4 100644
--- a/wan/modules/model.py
+++ b/wan/modules/model.py
@@ -716,7 +716,8 @@ class WanModel(ModelMixin, ConfigMixin):
         pipeline = None,
         current_step = 0,
         context2 = None,
-        is_uncond=False
+        is_uncond=False,
+        slg_layers=None,
     ):
         r"""
         Forward pass through the diffusion model
@@ -843,7 +844,9 @@ class WanModel(ModelMixin, ConfigMixin):
 
         # context=context, context_lens=context_lens)
 
-        for block in self.blocks:
+        for block_idx, block in enumerate(self.blocks):
+            if slg_layers is not None and block_idx in slg_layers and is_uncond:
+                continue
             if pipeline._interrupt:
                 if joint_pass:
                     return None, None
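-- 
Note, outside the patch proper (everything after the "-- " delimiter is
dropped by git am): a minimal, self-contained sketch of what skip layer
guidance does, to make the change above easier to review. TinyDiT, cfg_step,
and all parameter values here are illustrative assumptions, not this repo's
API. The two load-bearing ideas mirror the diff: blocks listed in slg_layers
are skipped only on the unconditional pass, and only inside the
[slg_start, slg_end) fraction of the sampling steps.

    import torch
    import torch.nn as nn

    class TinyDiT(nn.Module):
        # Stand-in transformer: a plain stack of residual blocks,
        # playing the role of WanModel.blocks.
        def __init__(self, dim=16, n_blocks=8):
            super().__init__()
            self.blocks = nn.ModuleList(nn.Linear(dim, dim) for _ in range(n_blocks))

        def forward(self, x, is_uncond=False, slg_layers=None):
            for block_idx, block in enumerate(self.blocks):
                # The SLG trick: skip the configured blocks, but only on
                # the unconditional branch of classifier-free guidance.
                if slg_layers is not None and block_idx in slg_layers and is_uncond:
                    continue
                x = x + torch.tanh(block(x))
            return x

    def cfg_step(model, x, guide_scale, step, total_steps,
                 slg_layers=None, slg_start=0.0, slg_end=1.0):
        # SLG is gated to the [slg_start, slg_end) window of steps,
        # matching the slg_layers_local logic in wan/image2video.py.
        in_window = int(slg_start * total_steps) <= step < int(slg_end * total_steps)
        cond = model(x, is_uncond=False)
        uncond = model(x, is_uncond=True,
                       slg_layers=slg_layers if in_window else None)
        # Standard classifier-free guidance combination; weakening the
        # unconditional branch by skipping layers strengthens the guidance.
        return uncond + guide_scale * (cond - uncond)

    model = TinyDiT()
    x = torch.randn(1, 16)
    for step in range(40):
        x = cfg_step(model, x, guide_scale=5.0, step=step, total_steps=40,
                     slg_layers=[4, 5], slg_start=0.1, slg_end=0.9)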