mirror of
https://github.com/Wan-Video/Wan2.1.git
synced 2025-11-04 14:16:57 +00:00
Add skip layer guidance
This commit is contained in:
parent
48b777ae6c
commit
936db03daa
@ -413,6 +413,9 @@ def parse_args():
|
||||
parser.add_argument("--teacache", type=float, default=0.25, help="TeaCache multiplier, e.g. 0.5, 2.0, etc.")
|
||||
parser.add_argument("--teacache-start", type=float, default=0.1, help="Teacache start step percentage [0..100]")
|
||||
parser.add_argument("--seed", type=int, default=-1, help="Random seed. -1 means random each time.")
|
||||
parser.add_argument("--slg-layers", type=str, default=None, help="Which layers to use for skip layer guidance")
|
||||
parser.add_argument("--slg-start", type=float, default=0.0, help="Percentage in to start SLG")
|
||||
parser.add_argument("--slg-end", type=float, default=1.0, help="Percentage in to end SLG")
|
||||
|
||||
# LoRA usage
|
||||
parser.add_argument("--loras-choices", type=str, default="", help="Comma-separated list of chosen LoRA indices or preset names to load. Usually you only use the preset.")
|
||||
@ -540,6 +543,12 @@ def main():
|
||||
except:
|
||||
raise ValueError(f"Invalid resolution: '{resolution_str}'")
|
||||
|
||||
# Parse slg_layers from comma-separated string to a Python list of ints (or None if not provided)
|
||||
if args.slg_layers:
|
||||
slg_list = [int(x) for x in args.slg_layers.split(",")]
|
||||
else:
|
||||
slg_list = None
|
||||
|
||||
# Additional checks (from your original code).
|
||||
if "480p" in args.transformer_file:
|
||||
# Then we cannot exceed certain area for 480p model
|
||||
@ -628,6 +637,10 @@ def main():
|
||||
callback=None, # or define your own callback if you want
|
||||
enable_RIFLEx=enable_riflex,
|
||||
VAE_tile_size=VAE_tile_size,
|
||||
joint_pass=slg_list is None, # set if you want a small speed improvement without SLG
|
||||
slg_layers=slg_list,
|
||||
slg_start=args.slg_start,
|
||||
slg_end=args.slg_end,
|
||||
)
|
||||
except Exception as e:
|
||||
offloadobj.unload_all()
|
||||
|
||||
@ -132,22 +132,25 @@ class WanI2V:
|
||||
self.sample_neg_prompt = config.sample_neg_prompt
|
||||
|
||||
def generate(self,
|
||||
input_prompt,
|
||||
img,
|
||||
max_area=720 * 1280,
|
||||
frame_num=81,
|
||||
shift=5.0,
|
||||
sample_solver='unipc',
|
||||
sampling_steps=40,
|
||||
guide_scale=5.0,
|
||||
n_prompt="",
|
||||
seed=-1,
|
||||
offload_model=True,
|
||||
callback = None,
|
||||
enable_RIFLEx = False,
|
||||
VAE_tile_size= 0,
|
||||
joint_pass = False,
|
||||
):
|
||||
input_prompt,
|
||||
img,
|
||||
max_area=720 * 1280,
|
||||
frame_num=81,
|
||||
shift=5.0,
|
||||
sample_solver='unipc',
|
||||
sampling_steps=40,
|
||||
guide_scale=5.0,
|
||||
n_prompt="",
|
||||
seed=-1,
|
||||
offload_model=True,
|
||||
callback = None,
|
||||
enable_RIFLEx = False,
|
||||
VAE_tile_size= 0,
|
||||
joint_pass = False,
|
||||
slg_layers = None,
|
||||
slg_start = 0.0,
|
||||
slg_end = 1.0,
|
||||
):
|
||||
r"""
|
||||
Generates video frames from input image and text prompt using diffusion process.
|
||||
|
||||
@ -331,25 +334,42 @@ class WanI2V:
|
||||
callback(-1, None)
|
||||
|
||||
for i, t in enumerate(tqdm(timesteps)):
|
||||
slg_layers_local = None
|
||||
if int(slg_start * sampling_steps) <= i < int(slg_end * sampling_steps):
|
||||
slg_layers_local = slg_layers
|
||||
|
||||
offload.set_step_no_for_lora(i)
|
||||
latent_model_input = [latent.to(self.device)]
|
||||
timestep = [t]
|
||||
|
||||
timestep = torch.stack(timestep).to(self.device)
|
||||
if joint_pass:
|
||||
if slg_layers is not None:
|
||||
raise ValueError('Can not use SLG and joint-pass')
|
||||
noise_pred_cond, noise_pred_uncond = self.model(
|
||||
latent_model_input, t=timestep, current_step=i, **arg_both)
|
||||
if self._interrupt:
|
||||
return None
|
||||
else:
|
||||
noise_pred_cond = self.model(
|
||||
latent_model_input, t=timestep, current_step=i, is_uncond = False, **arg_c)[0]
|
||||
latent_model_input,
|
||||
t=timestep,
|
||||
current_step=i,
|
||||
is_uncond=False,
|
||||
**arg_c,
|
||||
)[0]
|
||||
if self._interrupt:
|
||||
return None
|
||||
if offload_model:
|
||||
torch.cuda.empty_cache()
|
||||
noise_pred_uncond = self.model(
|
||||
latent_model_input, t=timestep, current_step=i, is_uncond = True, **arg_null)[0]
|
||||
latent_model_input,
|
||||
t=timestep,
|
||||
current_step=i,
|
||||
is_uncond=True,
|
||||
slg_layers=slg_layers_local,
|
||||
**arg_null,
|
||||
)[0]
|
||||
if self._interrupt:
|
||||
return None
|
||||
del latent_model_input
|
||||
|
||||
@ -716,7 +716,8 @@ class WanModel(ModelMixin, ConfigMixin):
|
||||
pipeline = None,
|
||||
current_step = 0,
|
||||
context2 = None,
|
||||
is_uncond=False
|
||||
is_uncond=False,
|
||||
slg_layers=None,
|
||||
):
|
||||
r"""
|
||||
Forward pass through the diffusion model
|
||||
@ -843,7 +844,9 @@ class WanModel(ModelMixin, ConfigMixin):
|
||||
# context=context,
|
||||
context_lens=context_lens)
|
||||
|
||||
for block in self.blocks:
|
||||
for block_idx, block in enumerate(self.blocks):
|
||||
if slg_layers is not None and block_idx in slg_layers and is_uncond:
|
||||
continue
|
||||
if pipeline._interrupt:
|
||||
if joint_pass:
|
||||
return None, None
|
||||
|
||||
Loading…
Reference in New Issue
Block a user