Mirror of https://github.com/Wan-Video/Wan2.1.git, synced 2025-11-04 14:16:57 +00:00

Merge pull request #61 from AmericanPresidentJimmyCarter/slg

Add skip layer guidance

This commit is contained in:
commit 5f07a7092d
@@ -413,6 +413,9 @@ def parse_args():
     parser.add_argument("--teacache", type=float, default=0.25, help="TeaCache multiplier, e.g. 0.5, 2.0, etc.")
     parser.add_argument("--teacache-start", type=float, default=0.1, help="Teacache start step percentage [0..100]")
     parser.add_argument("--seed", type=int, default=-1, help="Random seed. -1 means random each time.")
+    parser.add_argument("--slg-layers", type=str, default=None, help="Which layers to use for skip layer guidance")
+    parser.add_argument("--slg-start", type=float, default=0.0, help="Percentage in to start SLG")
+    parser.add_argument("--slg-end", type=float, default=1.0, help="Percentage in to end SLG")

     # LoRA usage
     parser.add_argument("--loras-choices", type=str, default="", help="Comma-separated list of chosen LoRA indices or preset names to load. Usually you only use the preset.")
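These three flags are the user-facing surface of skip layer guidance (SLG): which transformer blocks to skip on the unconditional pass, and over which slice of the sampling schedule to skip them. Note that despite the word "Percentage" in the help strings, the values are consumed as 0-1 fractions of sampling_steps (see the WanI2V.generate() loop further down). A hypothetical invocation (the script name and layer indices are illustrative, not from this diff):

python generate_i2v.py --slg-layers 9,10 --slg-start 0.1 --slg-end 0.9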
@@ -540,6 +543,12 @@ def main():
         except:
             raise ValueError(f"Invalid resolution: '{resolution_str}'")

+    # Parse slg_layers from comma-separated string to a Python list of ints (or None if not provided)
+    if args.slg_layers:
+        slg_list = [int(x) for x in args.slg_layers.split(",")]
+    else:
+        slg_list = None
+
     # Additional checks (from your original code).
     if "480p" in args.transformer_file:
         # Then we cannot exceed certain area for 480p model
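For reference, the same parsing step as a self-contained sketch (the helper name is ours, not the repo's):

def parse_slg_layers(spec):
    # "9,10" -> [9, 10]; None or "" -> None
    if not spec:
        return None
    return [int(x) for x in spec.split(",")]

assert parse_slg_layers("9,10") == [9, 10]
assert parse_slg_layers("") is None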
@@ -628,6 +637,10 @@ def main():
             callback=None,  # or define your own callback if you want
             enable_RIFLEx=enable_riflex,
             VAE_tile_size=VAE_tile_size,
+            joint_pass=slg_list is None,  # set if you want a small speed improvement without SLG
+            slg_layers=slg_list,
+            slg_start=args.slg_start,
+            slg_end=args.slg_end,
         )
     except Exception as e:
         offloadobj.unload_all()
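The joint_pass=slg_list is None wiring is deliberate: joint_pass batches the conditional and unconditional predictions into one forward call, while SLG needs the unconditional branch to run on its own so the listed blocks can be skipped only there. The loop in WanI2V.generate() below enforces the exclusivity with a ValueError, so the call site reduces to:

joint_pass = slg_list is None  # True: one batched pass, no SLG; False: split passes, SLG allowed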
@@ -132,22 +132,25 @@ class WanI2V:
         self.sample_neg_prompt = config.sample_neg_prompt

     def generate(self,
                  input_prompt,
                  img,
                  max_area=720 * 1280,
                  frame_num=81,
                  shift=5.0,
                  sample_solver='unipc',
                  sampling_steps=40,
                  guide_scale=5.0,
                  n_prompt="",
                  seed=-1,
                  offload_model=True,
                  callback = None,
                  enable_RIFLEx = False,
                  VAE_tile_size= 0,
                  joint_pass = False,
-                 ):
+                 slg_layers = None,
+                 slg_start = 0.0,
+                 slg_end = 1.0,
+                 ):
         r"""
         Generates video frames from input image and text prompt using diffusion process.
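With the widened signature, enabling SLG from a caller might look like this (the object name and values are illustrative; only the keyword names come from this diff):

video = wan_i2v.generate(
    input_prompt,
    img,
    sampling_steps=40,
    joint_pass=False,       # SLG and joint_pass are mutually exclusive
    slg_layers=[9, 10],     # transformer blocks to skip on the uncond pass
    slg_start=0.1,          # start skipping at 10% of the steps ...
    slg_end=0.9,            # ... and stop at 90%
)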
@@ -331,25 +334,42 @@ class WanI2V:
            callback(-1, None)

        for i, t in enumerate(tqdm(timesteps)):
+            slg_layers_local = None
+            if int(slg_start * sampling_steps) <= i < int(slg_end * sampling_steps):
+                slg_layers_local = slg_layers
+
            offload.set_step_no_for_lora(i)
            latent_model_input = [latent.to(self.device)]
            timestep = [t]

            timestep = torch.stack(timestep).to(self.device)
            if joint_pass:
+                if slg_layers is not None:
+                    raise ValueError('Can not use SLG and joint-pass')
                noise_pred_cond, noise_pred_uncond = self.model(
                    latent_model_input, t=timestep, current_step=i, **arg_both)
                if self._interrupt:
                    return None
            else:
                noise_pred_cond = self.model(
-                    latent_model_input, t=timestep, current_step=i, is_uncond = False, **arg_c)[0]
+                    latent_model_input,
+                    t=timestep,
+                    current_step=i,
+                    is_uncond=False,
+                    **arg_c,
+                )[0]
                if self._interrupt:
                    return None
                if offload_model:
                    torch.cuda.empty_cache()
                noise_pred_uncond = self.model(
-                    latent_model_input, t=timestep, current_step=i, is_uncond = True, **arg_null)[0]
+                    latent_model_input,
+                    t=timestep,
+                    current_step=i,
+                    is_uncond=True,
+                    slg_layers=slg_layers_local,
+                    **arg_null,
+                )[0]
                if self._interrupt:
                    return None
            del latent_model_input
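The step window is worth making concrete: with sampling_steps=40, slg_start=0.1, and slg_end=0.9, the guard int(0.1 * 40) <= i < int(0.9 * 40) activates SLG for steps i = 4 through 35, so the unconditional pass skips the listed blocks only in that middle stretch of denoising. The defaults (slg_start=0.0, slg_end=1.0) cover every step.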
@@ -716,7 +716,8 @@ class WanModel(ModelMixin, ConfigMixin):
                pipeline = None,
                current_step = 0,
                context2 = None,
-               is_uncond=False
+               is_uncond=False,
+               slg_layers=None,
        ):
        r"""
        Forward pass through the diffusion model
@@ -843,7 +844,9 @@ class WanModel(ModelMixin, ConfigMixin):
                # context=context,
                context_lens=context_lens)

-        for block in self.blocks:
+        for block_idx, block in enumerate(self.blocks):
+            if slg_layers is not None and block_idx in slg_layers and is_uncond:
+                continue
            if pipeline._interrupt:
                if joint_pass:
                    return None, None
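This loop is the core of skip layer guidance: on the unconditional branch only, the listed transformer blocks are bypassed, which degrades the unconditional prediction in a targeted way; the usual classifier-free guidance update then steers the sample away from that degraded prediction. Schematically (a sketch of the standard CFG combination, not code from this diff):

# uncond pass ran with the blocks in slg_layers skipped (is_uncond=True above);
# cond pass ran the full network
noise_pred = noise_pred_uncond + guide_scale * (noise_pred_cond - noise_pred_uncond)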