Lora fest + Skip Layer Guidance

DeepBeepMeep 2025-03-15 01:12:51 +01:00
commit e554e1a3d6
6 changed files with 136 additions and 45 deletions

View File

@@ -19,7 +19,9 @@ In this repository, we present **Wan2.1**, a comprehensive and open suite of vid
## 🔥 Latest News!!
* Mar 14, 2025: 👋 Wan2.1GP v1.7: Lora Fest special edition: very fast loading / unloading of loras for the Lora collectors out there. You will need to refresh the requirements: *pip install -r requirements.txt*
* Mar 14, 2025: 👋 Wan2.1GP v1.7:
- Lora Fest special edition: very fast loading / unloading of loras for the Lora collectors out there. You can also now add / remove loras in the Lora folder without restarting the app. You will need to refresh the requirements: *pip install -r requirements.txt*
- Added experimental Skip Layer Guidance (in the advanced settings), which should improve image quality at no extra cost; a short sketch of the idea follows this list. Many thanks to *AmericanPresidentJimmyCarter* for the original implementation
* Mar 13, 2025: 👋 Wan2.1GP v1.6: Better Lora support, accelerated Lora loading. You will need to refresh the requirements: *pip install -r requirements.txt*
* Mar 10, 2025: 👋 Wan2.1GP v1.5: Official TeaCache support + Smart TeaCache (automatically finds the best parameters for a requested speed multiplier), 10% speed boost with no quality loss, improved lora presets (they can now include prompts and comments to guide the user)
* Mar 07, 2025: 👋 Wan2.1GP v1.4: Fixed PyTorch compilation, it is now really 20% faster when activated
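For context, Skip Layer Guidance piggybacks on classifier-free guidance: the unconditional branch skips a few transformer blocks, which weakens it and therefore strengthens the guidance direction. A minimal sketch of the idea, not the repo's actual API (the `skip_blocks` argument and the function name are illustrative):

```python
def slg_guidance(model, x, t, cond, uncond, guide_scale=5.0, slg_layers=(9,)):
    """Illustrative Skip Layer Guidance step: the unconditional pass of
    classifier-free guidance skips the listed transformer blocks."""
    noise_cond = model(x, t, context=cond, skip_blocks=None)                 # full conditional pass
    noise_uncond = model(x, t, context=uncond, skip_blocks=set(slg_layers))  # degraded unconditional pass
    # Standard CFG merge: a weaker uncond prediction makes the guidance term stronger.
    return noise_uncond + guide_scale * (noise_cond - noise_uncond)
```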

View File

@@ -738,6 +738,10 @@ def generate_video(
video_to_continue,
max_frames,
RIFLEx_setting,
slg_switch,
slg_layers,
slg_start,
slg_end,
state,
progress=gr.Progress() #track_tqdm= True
@@ -760,7 +764,8 @@ def generate_video(
width, height = resolution.split("x")
width, height = int(width), int(height)
if slg_switch == 0:
slg_layers = None
if use_image2video:
if "480p" in transformer_filename_i2v and width * height > 848*480:
raise gr.Error("You must use the 720P image to video model to generate videos with a resolution equivalent to 720P")
@@ -982,6 +987,9 @@ def generate_video(
enable_RIFLEx = enable_RIFLEx,
VAE_tile_size = VAE_tile_size,
joint_pass = joint_pass,
slg_layers = slg_layers,
slg_start = slg_start/100,
slg_end = slg_end/100,
)
else:
@@ -999,6 +1007,9 @@ def generate_video(
enable_RIFLEx = enable_RIFLEx,
VAE_tile_size = VAE_tile_size,
joint_pass = joint_pass,
slg_layers = slg_layers,
slg_start = slg_start/100,
slg_end = slg_end/100,
)
except Exception as e:
gen_in_progress = False
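For clarity, the Gradio controls report the SLG window as percentages of the denoising schedule, while the samplers expect fractions; generate_video also clears the layer list when the switch is OFF. A condensed sketch of that glue, mirroring the hunks above (the helper name is invented):

```python
def prepare_slg(slg_switch, slg_layers, slg_start_perc, slg_end_perc):
    """Normalise the UI values before calling the sampler: OFF disables SLG
    entirely, and the 0..100 slider values become 0.0..1.0 fractions."""
    if slg_switch == 0:       # "OFF" in the dropdown
        slg_layers = None     # the samplers treat None as "SLG disabled"
    return slg_layers, slg_start_perc / 100, slg_end_perc / 100

# e.g. prepare_slg(1, [9], 10, 90) -> ([9], 0.1, 0.9)
```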
@@ -1490,6 +1501,34 @@ def create_demo():
label="RIFLEx positional embedding to generate long video"
)
with gr.Row():
gr.Markdown("Experimental: Skip Layer Guidance, should improve video quality")
with gr.Row():
slg_switch = gr.Dropdown(
choices=[
("OFF", 0),
("ON", 1),
],
value= 0,
visible=True,
scale = 1,
label="Skip Layer guidance"
)
slg_layers = gr.Dropdown(
choices=[
(str(i), i ) for i in range(40)
],
value= [9],
multiselect= True,
label="Skip Layers",
scale= 3
)
with gr.Row():
slg_start_perc = gr.Slider(0, 100, value=10, step=1, label="Denoising Steps % start")
slg_end_perc = gr.Slider(0, 100, value=90, step=1, label="Denoising Steps % end")
show_advanced.change(fn=lambda x: gr.Row(visible=x), inputs=[show_advanced], outputs=[advanced_row])
with gr.Column():
@@ -1537,7 +1576,11 @@ def create_demo():
video_to_continue,
max_frames,
RIFLEx_setting,
state
slg_switch,
slg_layers,
slg_start_perc,
slg_end_perc,
state,
],
outputs= [gen_status] #,state

View File

@@ -413,6 +413,9 @@ def parse_args():
parser.add_argument("--teacache", type=float, default=0.25, help="TeaCache multiplier, e.g. 0.5, 2.0, etc.")
parser.add_argument("--teacache-start", type=float, default=0.1, help="Teacache start step percentage [0..100]")
parser.add_argument("--seed", type=int, default=-1, help="Random seed. -1 means random each time.")
parser.add_argument("--slg-layers", type=str, default=None, help="Which layers to use for skip layer guidance")
parser.add_argument("--slg-start", type=float, default=0.0, help="Percentage in to start SLG")
parser.add_argument("--slg-end", type=float, default=1.0, help="Percentage in to end SLG")
# LoRA usage
parser.add_argument("--loras-choices", type=str, default="", help="Comma-separated list of chosen LoRA indices or preset names to load. Usually you only use the preset.")
@@ -540,6 +543,12 @@ def main():
except:
raise ValueError(f"Invalid resolution: '{resolution_str}'")
# Parse slg_layers from comma-separated string to a Python list of ints (or None if not provided)
if args.slg_layers:
slg_list = [int(x) for x in args.slg_layers.split(",")]
else:
slg_list = None
# Additional checks (from your original code).
if "480p" in args.transformer_file:
# Then we cannot exceed certain area for 480p model
@@ -628,6 +637,10 @@ def main():
callback=None, # or define your own callback if you want
enable_RIFLEx=enable_riflex,
VAE_tile_size=VAE_tile_size,
joint_pass=slg_list is None, # joint pass gives a small speed improvement but is only enabled when SLG is off
slg_layers=slg_list,
slg_start=args.slg_start,
slg_end=args.slg_end,
)
except Exception as e:
offloadobj.unload_all()
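On the CLI side, --slg-layers arrives as a comma-separated string and is split into a list of block indices before the generate() call above; joint_pass is only enabled when that list is empty. A small sketch of the parsing (the strip() is a defensive addition, not in the diff):

```python
def parse_slg_layers(raw):
    """'9,10' -> [9, 10]; an empty or missing value means SLG is disabled."""
    if not raw:
        return None
    return [int(x.strip()) for x in raw.split(",")]

assert parse_slg_layers("9,10") == [9, 10]
assert parse_slg_layers(None) is None
```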

View File

@@ -132,22 +132,25 @@ class WanI2V:
self.sample_neg_prompt = config.sample_neg_prompt
def generate(self,
input_prompt,
img,
max_area=720 * 1280,
frame_num=81,
shift=5.0,
sample_solver='unipc',
sampling_steps=40,
guide_scale=5.0,
n_prompt="",
seed=-1,
offload_model=True,
callback = None,
enable_RIFLEx = False,
VAE_tile_size= 0,
joint_pass = False,
):
input_prompt,
img,
max_area=720 * 1280,
frame_num=81,
shift=5.0,
sample_solver='unipc',
sampling_steps=40,
guide_scale=5.0,
n_prompt="",
seed=-1,
offload_model=True,
callback = None,
enable_RIFLEx = False,
VAE_tile_size= 0,
joint_pass = False,
slg_layers = None,
slg_start = 0.0,
slg_end = 1.0,
):
r"""
Generates video frames from input image and text prompt using diffusion process.
@@ -332,24 +335,41 @@ class WanI2V:
for i, t in enumerate(tqdm(timesteps)):
offload.set_step_no_for_lora(self.model, i)
slg_layers_local = None
if int(slg_start * sampling_steps) <= i < int(slg_end * sampling_steps):
slg_layers_local = slg_layers
latent_model_input = [latent.to(self.device)]
timestep = [t]
timestep = torch.stack(timestep).to(self.device)
if joint_pass:
# if slg_layers is not None:
# raise ValueError('Can not use SLG and joint-pass')
noise_pred_cond, noise_pred_uncond = self.model(
latent_model_input, t=timestep, current_step=i, **arg_both)
latent_model_input, t=timestep, current_step=i, slg_layers=slg_layers_local, **arg_both)
if self._interrupt:
return None
else:
noise_pred_cond = self.model(
latent_model_input, t=timestep, current_step=i, is_uncond = False, **arg_c)[0]
latent_model_input,
t=timestep,
current_step=i,
is_uncond=False,
**arg_c,
)[0]
if self._interrupt:
return None
if offload_model:
torch.cuda.empty_cache()
noise_pred_uncond = self.model(
latent_model_input, t=timestep, current_step=i, is_uncond = True, **arg_null)[0]
latent_model_input,
t=timestep,
current_step=i,
is_uncond=True,
slg_layers=slg_layers_local,
**arg_null,
)[0]
if self._interrupt:
return None
del latent_model_input
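The window check at the top of the loop converts the start/end fractions into step indices; with the I2V default of 40 sampling steps and the UI defaults of 10% / 90%, SLG is applied on steps 4 through 35 and the first and last few steps run unmodified. A quick check of that arithmetic:

```python
sampling_steps, slg_start, slg_end = 40, 0.1, 0.9   # I2V default steps, 10%..90% window
active = [i for i in range(sampling_steps)
          if int(slg_start * sampling_steps) <= i < int(slg_end * sampling_steps)]
print(active[0], active[-1])   # 4 35
```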

View File

@@ -717,8 +717,8 @@ class WanModel(ModelMixin, ConfigMixin):
current_step = 0,
context2 = None,
is_uncond=False,
max_steps = 0
max_steps = 0,
slg_layers=None,
):
r"""
Forward pass through the diffusion model
@@ -851,8 +851,8 @@ class WanModel(ModelMixin, ConfigMixin):
# context=context,
context_lens=context_lens)
for l, block in enumerate(self.blocks):
offload.shared_state["layer"] = l
for block_idx, block in enumerate(self.blocks):
offload.shared_state["layer"] = block_idx
if "refresh" in offload.shared_state:
del offload.shared_state["refresh"]
offload.shared_state["callback"](-1, -1, True)
@@ -861,9 +861,16 @@ class WanModel(ModelMixin, ConfigMixin):
return None, None
else:
return [None]
for i, (x, context) in enumerate(zip(x_list, context_list)):
x_list[i] = block(x, context = context, e= e0, **kwargs)
del x
if slg_layers is not None and block_idx in slg_layers:
if is_uncond and not joint_pass:
continue
x_list[0] = block(x_list[0], context = context_list[0], e= e0, **kwargs)
else:
for i, (x, context) in enumerate(zip(x_list, context_list)):
x_list[i] = block(x, context = context, e= e0, **kwargs)
del x
if self.enable_teacache:
if joint_pass:
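The rewritten block loop only perturbs the unconditional branch: in the two-pass path the uncond call skips the listed blocks outright, while in joint-pass mode only the conditional stream (x_list[0]) runs them and the unconditional stream is left untouched. A stripped-down sketch of that dispatch (block kwargs and the TeaCache logic are omitted):

```python
def run_blocks_with_slg(blocks, x_list, context_list, e0, slg_layers, is_uncond, joint_pass):
    """Apply every transformer block, skipping the ones listed in slg_layers
    for the unconditional stream only (simplified from the hunk above)."""
    for block_idx, block in enumerate(blocks):
        if slg_layers is not None and block_idx in slg_layers:
            if is_uncond and not joint_pass:
                continue                  # two-pass mode: uncond stream skips this block
            # joint-pass mode: only the conditional stream (index 0) runs the block
            x_list[0] = block(x_list[0], context=context_list[0], e=e0)
        else:
            for i, (x, context) in enumerate(zip(x_list, context_list)):
                x_list[i] = block(x, context=context, e=e0)
    return x_list
```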

View File

@@ -119,20 +119,23 @@ class WanT2V:
self.sample_neg_prompt = config.sample_neg_prompt
def generate(self,
input_prompt,
size=(1280, 720),
frame_num=81,
shift=5.0,
sample_solver='unipc',
sampling_steps=50,
guide_scale=5.0,
n_prompt="",
seed=-1,
offload_model=True,
callback = None,
enable_RIFLEx = None,
VAE_tile_size = 0,
joint_pass = False,
input_prompt,
size=(1280, 720),
frame_num=81,
shift=5.0,
sample_solver='unipc',
sampling_steps=50,
guide_scale=5.0,
n_prompt="",
seed=-1,
offload_model=True,
callback = None,
enable_RIFLEx = None,
VAE_tile_size = 0,
joint_pass = False,
slg_layers = None,
slg_start = 0.0,
slg_end = 1.0,
):
r"""
Generates video frames from text prompt using diffusion process.
@@ -253,6 +256,9 @@ class WanT2V:
callback(-1, None)
for i, t in enumerate(tqdm(timesteps)):
latent_model_input = latents
slg_layers_local = None
if int(slg_start * sampling_steps) <= i < int(slg_end * sampling_steps):
slg_layers_local = slg_layers
timestep = [t]
offload.set_step_no_for_lora(self.model, i)
timestep = torch.stack(timestep)
@@ -260,7 +266,7 @@ class WanT2V:
# self.model.to(self.device)
if joint_pass:
noise_pred_cond, noise_pred_uncond = self.model(
latent_model_input, t=timestep,current_step=i, **arg_both)
latent_model_input, t=timestep,current_step=i, slg_layers=slg_layers_local, **arg_both)
if self._interrupt:
return None
else:
@@ -269,7 +275,7 @@ class WanT2V:
if self._interrupt:
return None
noise_pred_uncond = self.model(
latent_model_input, t=timestep,current_step=i, is_uncond = True, **arg_null)[0]
latent_model_input, t=timestep,current_step=i, is_uncond = True, slg_layers=slg_layers_local, **arg_null)[0]
if self._interrupt:
return None
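Neither sampler hunk shows how the two predictions are merged, but it is the standard classifier-free-guidance combination, which is why degrading the unconditional branch with SLG sharpens the guidance at essentially no extra cost:

```python
# Standard CFG merge (shown for context, not part of this hunk): a weaker
# noise_pred_uncond enlarges the guidance term (cond - uncond).
noise_pred = noise_pred_uncond + guide_scale * (noise_pred_cond - noise_pred_uncond)
```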