Add skip layer guidance

This commit is contained in:
Jimmy 2025-03-13 09:33:09 -04:00
parent 48b777ae6c
commit 936db03daa
3 changed files with 56 additions and 20 deletions

View File

@ -413,6 +413,9 @@ def parse_args():
parser.add_argument("--teacache", type=float, default=0.25, help="TeaCache multiplier, e.g. 0.5, 2.0, etc.")
parser.add_argument("--teacache-start", type=float, default=0.1, help="Teacache start step percentage [0..100]")
parser.add_argument("--seed", type=int, default=-1, help="Random seed. -1 means random each time.")
parser.add_argument("--slg-layers", type=str, default=None, help="Which layers to use for skip layer guidance")
parser.add_argument("--slg-start", type=float, default=0.0, help="Percentage in to start SLG")
parser.add_argument("--slg-end", type=float, default=1.0, help="Percentage in to end SLG")
# LoRA usage
parser.add_argument("--loras-choices", type=str, default="", help="Comma-separated list of chosen LoRA indices or preset names to load. Usually you only use the preset.")
@ -540,6 +543,12 @@ def main():
except:
raise ValueError(f"Invalid resolution: '{resolution_str}'")
# Parse slg_layers from comma-separated string to a Python list of ints (or None if not provided)
if args.slg_layers:
slg_list = [int(x) for x in args.slg_layers.split(",")]
else:
slg_list = None
# Additional checks (from your original code).
if "480p" in args.transformer_file:
# Then we cannot exceed certain area for 480p model
@ -628,6 +637,10 @@ def main():
callback=None, # or define your own callback if you want
enable_RIFLEx=enable_riflex,
VAE_tile_size=VAE_tile_size,
joint_pass=slg_list is None, # set if you want a small speed improvement without SLG
slg_layers=slg_list,
slg_start=args.slg_start,
slg_end=args.slg_end,
)
except Exception as e:
offloadobj.unload_all()

View File

@ -132,22 +132,25 @@ class WanI2V:
self.sample_neg_prompt = config.sample_neg_prompt
def generate(self,
input_prompt,
img,
max_area=720 * 1280,
frame_num=81,
shift=5.0,
sample_solver='unipc',
sampling_steps=40,
guide_scale=5.0,
n_prompt="",
seed=-1,
offload_model=True,
callback = None,
enable_RIFLEx = False,
VAE_tile_size= 0,
joint_pass = False,
):
input_prompt,
img,
max_area=720 * 1280,
frame_num=81,
shift=5.0,
sample_solver='unipc',
sampling_steps=40,
guide_scale=5.0,
n_prompt="",
seed=-1,
offload_model=True,
callback = None,
enable_RIFLEx = False,
VAE_tile_size= 0,
joint_pass = False,
slg_layers = None,
slg_start = 0.0,
slg_end = 1.0,
):
r"""
Generates video frames from input image and text prompt using diffusion process.
@ -331,25 +334,42 @@ class WanI2V:
callback(-1, None)
for i, t in enumerate(tqdm(timesteps)):
slg_layers_local = None
if int(slg_start * sampling_steps) <= i < int(slg_end * sampling_steps):
slg_layers_local = slg_layers
offload.set_step_no_for_lora(i)
latent_model_input = [latent.to(self.device)]
timestep = [t]
timestep = torch.stack(timestep).to(self.device)
if joint_pass:
if slg_layers is not None:
raise ValueError('Can not use SLG and joint-pass')
noise_pred_cond, noise_pred_uncond = self.model(
latent_model_input, t=timestep, current_step=i, **arg_both)
if self._interrupt:
return None
else:
noise_pred_cond = self.model(
latent_model_input, t=timestep, current_step=i, is_uncond = False, **arg_c)[0]
latent_model_input,
t=timestep,
current_step=i,
is_uncond=False,
**arg_c,
)[0]
if self._interrupt:
return None
if offload_model:
torch.cuda.empty_cache()
noise_pred_uncond = self.model(
latent_model_input, t=timestep, current_step=i, is_uncond = True, **arg_null)[0]
latent_model_input,
t=timestep,
current_step=i,
is_uncond=True,
slg_layers=slg_layers_local,
**arg_null,
)[0]
if self._interrupt:
return None
del latent_model_input

View File

@ -716,7 +716,8 @@ class WanModel(ModelMixin, ConfigMixin):
pipeline = None,
current_step = 0,
context2 = None,
is_uncond=False
is_uncond=False,
slg_layers=None,
):
r"""
Forward pass through the diffusion model
@ -843,7 +844,9 @@ class WanModel(ModelMixin, ConfigMixin):
# context=context,
context_lens=context_lens)
for block in self.blocks:
for block_idx, block in enumerate(self.blocks):
if slg_layers is not None and block_idx in slg_layers and is_uncond:
continue
if pipeline._interrupt:
if joint_pass:
return None, None