Mirror of https://github.com/Wan-Video/Wan2.1.git, synced 2025-11-04 14:16:57 +00:00

Lora Fest + Skip Layer Guidance

This commit is contained in:
commit e554e1a3d6
@@ -19,7 +19,9 @@ In this repository, we present **Wan2.1**, a comprehensive and open suite of video foundation models

 ## 🔥 Latest News!!

-* Mar 14, 2025: 👋 Wan2.1GP v1.7: Lora Fest special edition: very fast loading / unloading of LoRAs, for all the LoRA collectors out there. You will need to refresh the requirements: *pip install -r requirements.txt*
+* Mar 14, 2025: 👋 Wan2.1GP v1.7:
+    - Lora Fest special edition: very fast loading / unloading of LoRAs, for all the LoRA collectors out there. You can also now add / remove LoRAs in the Lora folder without restarting the app. You will need to refresh the requirements: *pip install -r requirements.txt*
+    - Added experimental Skip Layer Guidance (advanced settings), which should improve image quality at no extra cost. Many thanks to *AmericanPresidentJimmyCarter* for the original implementation.
 * Mar 13, 2025: 👋 Wan2.1GP v1.6: Better LoRA support, faster LoRA loading. You will need to refresh the requirements: *pip install -r requirements.txt*
 * Mar 10, 2025: 👋 Wan2.1GP v1.5: Official TeaCache support + Smart TeaCache (automatically finds the best parameters for a requested speed multiplier), a 10% speed boost with no quality loss, and improved LoRA presets (they can now include prompts and comments to guide the user)
 * Mar 07, 2025: 👋 Wan2.1GP v1.4: Fixed PyTorch compilation; it is now really 20% faster when activated
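For context on the Skip Layer Guidance entry above: the idea, as implemented later in this diff, is to compute the unconditional branch of classifier-free guidance with a few transformer blocks skipped, then combine the two predictions as usual. Below is a minimal, self-contained sketch of that idea; the toy module and all names are illustrative, not the repository's code.

```python
import torch
import torch.nn as nn

class TinyDiT(nn.Module):
    """Toy stand-in for a diffusion transformer: a stack of residual blocks."""
    def __init__(self, dim=64, n_blocks=12):
        super().__init__()
        self.blocks = nn.ModuleList(nn.Linear(dim, dim) for _ in range(n_blocks))

    def forward(self, x, skip_layers=None):
        for idx, block in enumerate(self.blocks):
            if skip_layers is not None and idx in skip_layers:
                continue  # skip this block (only done on the unconditional pass)
            x = x + torch.tanh(block(x))
        return x

model = TinyDiT()
x = torch.randn(1, 64)     # toy "latent"
guide_scale = 5.0
slg_layers = [9]           # layers skipped on the unconditional branch

noise_cond = model(x)                              # conditional prediction: full network
noise_uncond = model(x, skip_layers=slg_layers)    # unconditional prediction: layers skipped

# Standard classifier-free guidance combination
noise_pred = noise_uncond + guide_scale * (noise_cond - noise_uncond)
```

In the real model the two branches also differ in their text conditioning; the toy ignores that, but the placement of the layer skipping is the same.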
@@ -738,6 +738,10 @@ def generate_video(
     video_to_continue,
     max_frames,
     RIFLEx_setting,
+    slg_switch,
+    slg_layers,
+    slg_start,
+    slg_end,
     state,
     progress=gr.Progress() #track_tqdm= True
@@ -760,7 +764,8 @@ def generate_video(
     width, height = resolution.split("x")
     width, height = int(width), int(height)

+    if slg_switch == 0:
+        slg_layers = None
     if use_image2video:
         if "480p" in transformer_filename_i2v and width * height > 848*480:
             raise gr.Error("You must use the 720P image to video model to generate videos with a resolution equivalent to 720P")
@@ -982,6 +987,9 @@ def generate_video(
             enable_RIFLEx = enable_RIFLEx,
             VAE_tile_size = VAE_tile_size,
             joint_pass = joint_pass,
+            slg_layers = slg_layers,
+            slg_start = slg_start/100,
+            slg_end = slg_end/100,
         )

     else:
@@ -999,6 +1007,9 @@ def generate_video(
             enable_RIFLEx = enable_RIFLEx,
             VAE_tile_size = VAE_tile_size,
             joint_pass = joint_pass,
+            slg_layers = slg_layers,
+            slg_start = slg_start/100,
+            slg_end = slg_end/100,
         )
     except Exception as e:
         gen_in_progress = False
@@ -1490,6 +1501,34 @@ def create_demo():
                     label="RIFLEx positional embedding to generate long video"
                 )

+                with gr.Row():
+                    gr.Markdown("Experimental: Skip Layer guidance,should improve video quality")
+                with gr.Row():
+                    slg_switch = gr.Dropdown(
+                        choices=[
+                            ("OFF", 0),
+                            ("ON", 1),
+                        ],
+                        value= 0,
+                        visible=True,
+                        scale = 1,
+                        label="Skip Layer guidance"
+                    )
+                    slg_layers = gr.Dropdown(
+                        choices=[
+                            (str(i), i ) for i in range(40)
+                        ],
+                        value= [9],
+                        multiselect= True,
+                        label="Skip Layers",
+                        scale= 3
+                    )
+                with gr.Row():
+                    slg_start_perc = gr.Slider(0, 100, value=10, step=1, label="Denoising Steps % start")
+                    slg_end_perc = gr.Slider(0, 100, value=90, step=1, label="Denoising Steps % end")
+
             show_advanced.change(fn=lambda x: gr.Row(visible=x), inputs=[show_advanced], outputs=[advanced_row])

         with gr.Column():
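The percent sliders added above are divided by 100 before being handed to the sampler (see the slg_start/100 and slg_end/100 lines earlier in this diff). A standalone sketch of that wiring, with a hypothetical run_generation callback standing in for generate_video:

```python
import gradio as gr

def run_generation(slg_switch, slg_layers, slg_start_perc, slg_end_perc):
    # Hypothetical stand-in for generate_video: convert UI values to sampler arguments.
    slg_layers = None if slg_switch == 0 else slg_layers
    slg_start = slg_start_perc / 100   # the sampler expects fractions in [0, 1]
    slg_end = slg_end_perc / 100
    return f"SLG layers={slg_layers}, active from {slg_start:.0%} to {slg_end:.0%} of the steps"

with gr.Blocks() as demo:
    slg_switch = gr.Dropdown(choices=[("OFF", 0), ("ON", 1)], value=0, label="Skip Layer guidance")
    slg_layers = gr.Dropdown(choices=[(str(i), i) for i in range(40)], value=[9],
                             multiselect=True, label="Skip Layers")
    slg_start_perc = gr.Slider(0, 100, value=10, step=1, label="Denoising Steps % start")
    slg_end_perc = gr.Slider(0, 100, value=90, step=1, label="Denoising Steps % end")
    out = gr.Textbox(label="Resolved settings")
    gr.Button("Apply").click(run_generation,
                             inputs=[slg_switch, slg_layers, slg_start_perc, slg_end_perc],
                             outputs=[out])

# demo.launch()
```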
@@ -1537,7 +1576,11 @@ def create_demo():
                 video_to_continue,
                 max_frames,
                 RIFLEx_setting,
-                state
+                slg_switch,
+                slg_layers,
+                slg_start_perc,
+                slg_end_perc,
+                state,
             ],
             outputs= [gen_status] #,state
@@ -413,6 +413,9 @@ def parse_args():
     parser.add_argument("--teacache", type=float, default=0.25, help="TeaCache multiplier, e.g. 0.5, 2.0, etc.")
     parser.add_argument("--teacache-start", type=float, default=0.1, help="Teacache start step percentage [0..100]")
     parser.add_argument("--seed", type=int, default=-1, help="Random seed. -1 means random each time.")
+    parser.add_argument("--slg-layers", type=str, default=None, help="Which layers to use for skip layer guidance")
+    parser.add_argument("--slg-start", type=float, default=0.0, help="Percentage in to start SLG")
+    parser.add_argument("--slg-end", type=float, default=1.0, help="Percentage in to end SLG")

     # LoRA usage
     parser.add_argument("--loras-choices", type=str, default="", help="Comma-separated list of chosen LoRA indices or preset names to load. Usually you only use the preset.")
@@ -540,6 +543,12 @@ def main():
     except:
         raise ValueError(f"Invalid resolution: '{resolution_str}'")

+    # Parse slg_layers from comma-separated string to a Python list of ints (or None if not provided)
+    if args.slg_layers:
+        slg_list = [int(x) for x in args.slg_layers.split(",")]
+    else:
+        slg_list = None
+
     # Additional checks (from your original code).
     if "480p" in args.transformer_file:
         # Then we cannot exceed certain area for 480p model
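The hunk above parses the comma-separated --slg-layers string into a list of ints. A slightly more defensive variant (a sketch, not the committed code) that tolerates stray spaces and trailing commas, and range-checks the indices against the 40-layer choices the UI exposes (the 40 is an assumption taken from that UI), might look like this:

```python
from typing import List, Optional

def parse_slg_layers(arg: Optional[str], num_layers: int = 40) -> Optional[List[int]]:
    """Turn a string such as '9, 10' into [9, 10]; return None when SLG is disabled."""
    if not arg:
        return None
    layers = [int(tok) for tok in arg.split(",") if tok.strip()]
    bad = [i for i in layers if not 0 <= i < num_layers]
    if bad:
        raise ValueError(f"--slg-layers indices out of range 0..{num_layers - 1}: {bad}")
    return layers or None

assert parse_slg_layers(None) is None
assert parse_slg_layers("9") == [9]
assert parse_slg_layers("9, 10,") == [9, 10]
```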
@@ -628,6 +637,10 @@ def main():
             callback=None,  # or define your own callback if you want
             enable_RIFLEx=enable_riflex,
             VAE_tile_size=VAE_tile_size,
+            joint_pass=slg_list is None,  # set if you want a small speed improvement without SLG
+            slg_layers=slg_list,
+            slg_start=args.slg_start,
+            slg_end=args.slg_end,
         )
     except Exception as e:
         offloadobj.unload_all()
@@ -132,22 +132,25 @@ class WanI2V:
         self.sample_neg_prompt = config.sample_neg_prompt

     def generate(self,
                  input_prompt,
                  img,
                  max_area=720 * 1280,
                  frame_num=81,
                  shift=5.0,
                  sample_solver='unipc',
                  sampling_steps=40,
                  guide_scale=5.0,
                  n_prompt="",
                  seed=-1,
                  offload_model=True,
                  callback = None,
                  enable_RIFLEx = False,
                  VAE_tile_size= 0,
                  joint_pass = False,
+                 slg_layers = None,
+                 slg_start = 0.0,
+                 slg_end = 1.0,
                  ):
         r"""
         Generates video frames from input image and text prompt using diffusion process.
@@ -332,24 +335,41 @@ class WanI2V:

         for i, t in enumerate(tqdm(timesteps)):
             offload.set_step_no_for_lora(self.model, i)
+            slg_layers_local = None
+            if int(slg_start * sampling_steps) <= i < int(slg_end * sampling_steps):
+                slg_layers_local = slg_layers
+
             latent_model_input = [latent.to(self.device)]
             timestep = [t]

             timestep = torch.stack(timestep).to(self.device)
             if joint_pass:
+                # if slg_layers is not None:
+                #     raise ValueError('Can not use SLG and joint-pass')
                 noise_pred_cond, noise_pred_uncond = self.model(
-                    latent_model_input, t=timestep, current_step=i, **arg_both)
+                    latent_model_input, t=timestep, current_step=i, slg_layers=slg_layers_local, **arg_both)
                 if self._interrupt:
                     return None
             else:
                 noise_pred_cond = self.model(
-                    latent_model_input, t=timestep, current_step=i, is_uncond = False, **arg_c)[0]
+                    latent_model_input,
+                    t=timestep,
+                    current_step=i,
+                    is_uncond=False,
+                    **arg_c,
+                )[0]
                 if self._interrupt:
                     return None
                 if offload_model:
                     torch.cuda.empty_cache()
                 noise_pred_uncond = self.model(
-                    latent_model_input, t=timestep, current_step=i, is_uncond = True, **arg_null)[0]
+                    latent_model_input,
+                    t=timestep,
+                    current_step=i,
+                    is_uncond=True,
+                    slg_layers=slg_layers_local,
+                    **arg_null,
+                )[0]
                 if self._interrupt:
                     return None
             del latent_model_input
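The gating added in the loop above turns SLG on only for a window of the denoising schedule: step i is affected when int(slg_start * sampling_steps) <= i < int(slg_end * sampling_steps). A quick standalone check of that arithmetic with the UI defaults (10% to 90% of 40 steps):

```python
sampling_steps = 40
slg_start, slg_end = 0.1, 0.9   # UI defaults of 10% / 90%, already divided by 100
slg_layers = [9]

active = [i for i in range(sampling_steps)
          if int(slg_start * sampling_steps) <= i < int(slg_end * sampling_steps)]
print(active[0], active[-1], len(active))   # 4 35 32 -> SLG applies to steps 4..35
```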
@@ -717,8 +717,8 @@ class WanModel(ModelMixin, ConfigMixin):
                 current_step = 0,
                 context2 = None,
                 is_uncond=False,
-                max_steps = 0
+                max_steps = 0,
+                slg_layers=None,
                 ):
         r"""
         Forward pass through the diffusion model
@@ -851,8 +851,8 @@ class WanModel(ModelMixin, ConfigMixin):
             # context=context,
             context_lens=context_lens)

-        for l, block in enumerate(self.blocks):
-            offload.shared_state["layer"] = l
+        for block_idx, block in enumerate(self.blocks):
+            offload.shared_state["layer"] = block_idx
             if "refresh" in offload.shared_state:
                 del offload.shared_state["refresh"]
                 offload.shared_state["callback"](-1, -1, True)
@@ -861,9 +861,16 @@ class WanModel(ModelMixin, ConfigMixin):
                     return None, None
                 else:
                     return [None]
-            for i, (x, context) in enumerate(zip(x_list, context_list)):
-                x_list[i] = block(x, context = context, e= e0, **kwargs)
-                del x
+
+            if slg_layers is not None and block_idx in slg_layers:
+                if is_uncond and not joint_pass:
+                    continue
+                x_list[0] = block(x_list[0], context = context_list[0], e= e0, **kwargs)
+
+            else:
+                for i, (x, context) in enumerate(zip(x_list, context_list)):
+                    x_list[i] = block(x, context = context, e= e0, **kwargs)
+                    del x

         if self.enable_teacache:
             if joint_pass:
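To unpack the branch added above: when the current block index is listed in slg_layers, the unconditional stream simply does not go through that block. In the joint pass, x_list holds both streams, so only x_list[0] (the conditional stream) is run through the block; in the separate unconditional pass (is_uncond and not joint_pass) the block is skipped outright. A minimal sketch of the same control flow with toy tensors and a hypothetical block function:

```python
import torch

def block(x, scale):
    # Hypothetical stand-in for a transformer block.
    return x + scale

def run_blocks(x_list, slg_layers, is_uncond, joint_pass, n_blocks=4):
    """x_list[0] = conditional stream; x_list[1] (joint pass only) = unconditional stream."""
    for block_idx in range(n_blocks):
        if slg_layers is not None and block_idx in slg_layers:
            if is_uncond and not joint_pass:
                continue                        # uncond-only pass: skip the block entirely
            x_list[0] = block(x_list[0], 1.0)   # joint pass: only the cond stream sees this block
        else:
            for i in range(len(x_list)):
                x_list[i] = block(x_list[i], 1.0)
    return x_list

cond, uncond = run_blocks([torch.zeros(1), torch.zeros(1)],
                          slg_layers=[2], is_uncond=False, joint_pass=True)
print(cond.item(), uncond.item())   # 4.0 3.0 -> the uncond stream skipped block 2
```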
@@ -119,20 +119,23 @@ class WanT2V:
         self.sample_neg_prompt = config.sample_neg_prompt

     def generate(self,
                  input_prompt,
                  size=(1280, 720),
                  frame_num=81,
                  shift=5.0,
                  sample_solver='unipc',
                  sampling_steps=50,
                  guide_scale=5.0,
                  n_prompt="",
                  seed=-1,
                  offload_model=True,
                  callback = None,
                  enable_RIFLEx = None,
                  VAE_tile_size = 0,
                  joint_pass = False,
+                 slg_layers = None,
+                 slg_start = 0.0,
+                 slg_end = 1.0,
                  ):
         r"""
         Generates video frames from text prompt using diffusion process.
@@ -253,6 +256,9 @@ class WanT2V:
             callback(-1, None)
         for i, t in enumerate(tqdm(timesteps)):
             latent_model_input = latents
+            slg_layers_local = None
+            if int(slg_start * sampling_steps) <= i < int(slg_end * sampling_steps):
+                slg_layers_local = slg_layers
             timestep = [t]
             offload.set_step_no_for_lora(self.model, i)
             timestep = torch.stack(timestep)
@@ -260,7 +266,7 @@ class WanT2V:
             # self.model.to(self.device)
             if joint_pass:
                 noise_pred_cond, noise_pred_uncond = self.model(
-                    latent_model_input, t=timestep,current_step=i, **arg_both)
+                    latent_model_input, t=timestep,current_step=i, slg_layers=slg_layers_local, **arg_both)
                 if self._interrupt:
                     return None
             else:
@@ -269,7 +275,7 @@ class WanT2V:
                 if self._interrupt:
                     return None
                 noise_pred_uncond = self.model(
-                    latent_model_input, t=timestep,current_step=i, is_uncond = True, **arg_null)[0]
+                    latent_model_input, t=timestep,current_step=i, is_uncond = True, slg_layers=slg_layers_local, **arg_null)[0]
                 if self._interrupt:
                     return None
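Putting the pieces together, a caller that already has a WanT2V instance (named wan_t2v here purely for illustration) would pass the new arguments roughly as below. joint_pass is left off because main() sets it to slg_list is None, i.e. SLG is normally applied on the separate unconditional pass. This is a sketch of the call shape, not a tested invocation.

```python
# Sketch only: assumes wan_t2v is an already-constructed WanT2V instance.
video = wan_t2v.generate(
    "a corgi surfing a wave at sunset",
    size=(1280, 720),
    frame_num=81,
    sampling_steps=50,
    guide_scale=5.0,
    seed=42,
    joint_pass=False,   # main() sets joint_pass = (slg_list is None)
    slg_layers=[9],     # block indices skipped on the unconditional branch
    slg_start=0.1,      # SLG active from 10% ...
    slg_end=0.9,        # ... to 90% of the denoising steps
)
```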