diff --git a/README.md b/README.md
index bcb23a6..4793a7a 100644
--- a/README.md
+++ b/README.md
@@ -19,7 +19,9 @@ In this repository, we present **Wan2.1**, a comprehensive and open suite of vid
 
 ## 🔥 Latest News!!
 
-* Mar 14, 2025: 👋 Wan2.1GP v1.7: Lora Fest special edition: very fast loading / unload of loras so for those Loras collectors around. You will need to refresh the requirements *pip install -r requirements.txt*
+* Mar 14, 2025: 👋 Wan2.1GP v1.7:
+  - Lora Fest special edition: very fast loading / unloading of LoRAs for the LoRA collectors out there. You can also now add / remove LoRAs in the Lora folder without restarting the app. You will need to refresh the requirements with *pip install -r requirements.txt*
+  - Added experimental Skip Layer Guidance (advanced settings), which should improve image quality at no extra cost. Many thanks to *AmericanPresidentJimmyCarter* for the original implementation
 * Mar 13, 2025: 👋 Wan2.1GP v1.6: Better Loras support, accelerated loading Loras. You will need to refresh the requirements *pip install -r requirements.txt*
 * Mar 10, 2025: 👋 Wan2.1GP v1.5: Official Teachache support + Smart Teacache (find automatically best parameters for a requested speed multiplier), 10% speed boost with no quality loss, improved lora presets (they can now include prompts and comments to guide the user)
 * Mar 07, 2025: 👋 Wan2.1GP v1.4: Fix Pytorch compilation, now it is really 20% faster when activated
diff --git a/gradio_server.py b/gradio_server.py
index 479c7ff..1597a6e 100644
--- a/gradio_server.py
+++ b/gradio_server.py
@@ -738,6 +738,10 @@ def generate_video(
     video_to_continue,
     max_frames,
     RIFLEx_setting,
+    slg_switch,
+    slg_layers,
+    slg_start,
+    slg_end,
     state,
     progress=gr.Progress() #track_tqdm= True
@@ -760,7 +764,8 @@ def generate_video(
     width, height = resolution.split("x")
     width, height = int(width), int(height)
 
-
+    if slg_switch == 0:
+        slg_layers = None
     if use_image2video:
         if "480p" in transformer_filename_i2v and width * height > 848*480:
             raise gr.Error("You must use the 720P image to video model to generate videos with a resolution equivalent to 720P")
@@ -982,6 +987,9 @@ def generate_video(
                 enable_RIFLEx = enable_RIFLEx,
                 VAE_tile_size = VAE_tile_size,
                 joint_pass = joint_pass,
+                slg_layers = slg_layers,
+                slg_start = slg_start/100,
+                slg_end = slg_end/100,
             )
 
         else:
@@ -999,6 +1007,9 @@ def generate_video(
                 enable_RIFLEx = enable_RIFLEx,
                 VAE_tile_size = VAE_tile_size,
                 joint_pass = joint_pass,
+                slg_layers = slg_layers,
+                slg_start = slg_start/100,
+                slg_end = slg_end/100,
             )
     except Exception as e:
         gen_in_progress = False
@@ -1490,6 +1501,34 @@ def create_demo():
                         label="RIFLEx positional embedding to generate long video"
                     )
 
+
+                with gr.Row():
+                    gr.Markdown("Experimental: Skip Layer Guidance, should improve video quality")
+                with gr.Row():
+                    slg_switch = gr.Dropdown(
+                        choices=[
+                            ("OFF", 0),
+                            ("ON", 1),
+                        ],
+                        value= 0,
+                        visible=True,
+                        scale = 1,
+                        label="Skip Layer guidance"
+                    )
+                    slg_layers = gr.Dropdown(
+                        choices=[
+                            (str(i), i ) for i in range(40)
+                        ],
+                        value= [9],
+                        multiselect= True,
+                        label="Skip Layers",
+                        scale= 3
+                    )
+                with gr.Row():
+                    slg_start_perc = gr.Slider(0, 100, value=10, step=1, label="Denoising Steps % start")
+                    slg_end_perc = gr.Slider(0, 100, value=90, step=1, label="Denoising Steps % end")
+
+
                 show_advanced.change(fn=lambda x: gr.Row(visible=x), inputs=[show_advanced], outputs=[advanced_row])
 
             with gr.Column():
@@ -1537,7 +1576,11 @@ def create_demo():
                 video_to_continue,
                 max_frames,
                 RIFLEx_setting,
-                state
+                slg_switch,
+                slg_layers,
+                slg_start_perc,
+                slg_end_perc,
+                state,
             ],
             outputs= [gen_status] #,state
diff --git a/i2v_inference.py b/i2v_inference.py
index 7a9fe5a..9b5a5cc 100644
--- a/i2v_inference.py
+++ b/i2v_inference.py
@@ -413,6 +413,9 @@ def parse_args():
     parser.add_argument("--teacache", type=float, default=0.25, help="TeaCache multiplier, e.g. 0.5, 2.0, etc.")
     parser.add_argument("--teacache-start", type=float, default=0.1, help="Teacache start step percentage [0..100]")
     parser.add_argument("--seed", type=int, default=-1, help="Random seed. -1 means random each time.")
+    parser.add_argument("--slg-layers", type=str, default=None, help="Comma-separated list of layers to skip for skip layer guidance, e.g. '9' or '9,10'")
+    parser.add_argument("--slg-start", type=float, default=0.0, help="Fraction of denoising steps (0.0-1.0) at which SLG starts")
+    parser.add_argument("--slg-end", type=float, default=1.0, help="Fraction of denoising steps (0.0-1.0) at which SLG ends")
 
     # LoRA usage
     parser.add_argument("--loras-choices", type=str, default="", help="Comma-separated list of chosen LoRA indices or preset names to load. Usually you only use the preset.")
@@ -540,6 +543,12 @@ def main():
     except:
         raise ValueError(f"Invalid resolution: '{resolution_str}'")
 
+    # Parse slg_layers from comma-separated string to a Python list of ints (or None if not provided)
+    if args.slg_layers:
+        slg_list = [int(x) for x in args.slg_layers.split(",")]
+    else:
+        slg_list = None
+
     # Additional checks (from your original code).
     if "480p" in args.transformer_file:
         # Then we cannot exceed certain area for 480p model
@@ -628,6 +637,10 @@ def main():
             callback=None,  # or define your own callback if you want
             enable_RIFLEx=enable_riflex,
             VAE_tile_size=VAE_tile_size,
+            joint_pass=slg_list is None,  # joint pass gives a small speed improvement; only enabled here when SLG is off
+            slg_layers=slg_list,
+            slg_start=args.slg_start,
+            slg_end=args.slg_end,
         )
     except Exception as e:
         offloadobj.unload_all()
diff --git a/wan/image2video.py b/wan/image2video.py
index 170f90f..088f51e 100644
--- a/wan/image2video.py
+++ b/wan/image2video.py
@@ -132,22 +132,25 @@ class WanI2V:
         self.sample_neg_prompt = config.sample_neg_prompt
 
     def generate(self,
-                 input_prompt,
-                 img,
-                 max_area=720 * 1280,
-                 frame_num=81,
-                 shift=5.0,
-                 sample_solver='unipc',
-                 sampling_steps=40,
-                 guide_scale=5.0,
-                 n_prompt="",
-                 seed=-1,
-                 offload_model=True,
-                 callback = None,
-                 enable_RIFLEx = False,
-                 VAE_tile_size= 0,
-                 joint_pass = False,
-                 ):
+        input_prompt,
+        img,
+        max_area=720 * 1280,
+        frame_num=81,
+        shift=5.0,
+        sample_solver='unipc',
+        sampling_steps=40,
+        guide_scale=5.0,
+        n_prompt="",
+        seed=-1,
+        offload_model=True,
+        callback = None,
+        enable_RIFLEx = False,
+        VAE_tile_size= 0,
+        joint_pass = False,
+        slg_layers = None,
+        slg_start = 0.0,
+        slg_end = 1.0,
+        ):
         r"""
         Generates video frames from input image and text prompt using diffusion process.
 
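As an aside (not part of the patch): a minimal sketch of how the new `slg_start` / `slg_end` fractions select the denoising-step window in which Skip Layer Guidance is applied. It mirrors the check used in the sampling loops in the hunks below; `slg_active` is an illustrative helper name, not a function in the repository.

```python
# Illustrative only: slg_start / slg_end are fractions of the sampling schedule.
def slg_active(i: int, sampling_steps: int, slg_start: float, slg_end: float) -> bool:
    # SLG is applied on step i when i falls inside the [slg_start, slg_end) window.
    return int(slg_start * sampling_steps) <= i < int(slg_end * sampling_steps)

# With 40 steps and the Gradio defaults (10% and 90%, divided by 100 in generate_video),
# SLG is active on steps 4 through 35.
assert [i for i in range(40) if slg_active(i, 40, 0.10, 0.90)] == list(range(4, 36))
```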
@@ -332,24 +335,41 @@ class WanI2V:
 
         for i, t in enumerate(tqdm(timesteps)):
             offload.set_step_no_for_lora(self.model, i)
+            slg_layers_local = None
+            if int(slg_start * sampling_steps) <= i < int(slg_end * sampling_steps):
+                slg_layers_local = slg_layers
+
             latent_model_input = [latent.to(self.device)]
             timestep = [t]
 
             timestep = torch.stack(timestep).to(self.device)
             if joint_pass:
+                # if slg_layers is not None:
+                #     raise ValueError('Can not use SLG and joint-pass')
                 noise_pred_cond, noise_pred_uncond = self.model(
-                    latent_model_input, t=timestep, current_step=i, **arg_both)
+                    latent_model_input, t=timestep, current_step=i, slg_layers=slg_layers_local, **arg_both)
                 if self._interrupt:
                     return None
             else:
                 noise_pred_cond = self.model(
-                    latent_model_input, t=timestep, current_step=i, is_uncond = False, **arg_c)[0]
+                    latent_model_input,
+                    t=timestep,
+                    current_step=i,
+                    is_uncond=False,
+                    **arg_c,
+                )[0]
                 if self._interrupt:
                     return None
                 if offload_model:
                     torch.cuda.empty_cache()
                 noise_pred_uncond = self.model(
-                    latent_model_input, t=timestep, current_step=i, is_uncond = True, **arg_null)[0]
+                    latent_model_input,
+                    t=timestep,
+                    current_step=i,
+                    is_uncond=True,
+                    slg_layers=slg_layers_local,
+                    **arg_null,
+                )[0]
                 if self._interrupt:
                     return None
             del latent_model_input
diff --git a/wan/modules/model.py b/wan/modules/model.py
index 9089c41..e864898 100644
--- a/wan/modules/model.py
+++ b/wan/modules/model.py
@@ -717,8 +717,8 @@ class WanModel(ModelMixin, ConfigMixin):
         current_step = 0,
         context2 = None,
         is_uncond=False,
-        max_steps = 0
-
+        max_steps = 0,
+        slg_layers=None,
     ):
         r"""
         Forward pass through the diffusion model
@@ -851,8 +851,8 @@ class WanModel(ModelMixin, ConfigMixin):
         #     context=context,
         #     context_lens=context_lens)
 
-        for l, block in enumerate(self.blocks):
-            offload.shared_state["layer"] = l
+        for block_idx, block in enumerate(self.blocks):
+            offload.shared_state["layer"] = block_idx
             if "refresh" in offload.shared_state:
                 del offload.shared_state["refresh"]
                 offload.shared_state["callback"](-1, -1, True)
@@ -861,9 +861,16 @@ class WanModel(ModelMixin, ConfigMixin):
                     return None, None
                 else:
                     return [None]
-            for i, (x, context) in enumerate(zip(x_list, context_list)):
-                x_list[i] = block(x, context = context, e= e0, **kwargs)
-                del x
+
+            if slg_layers is not None and block_idx in slg_layers:
+                if is_uncond and not joint_pass:
+                    continue
+                x_list[0] = block(x_list[0], context = context_list[0], e= e0, **kwargs)
+
+            else:
+                for i, (x, context) in enumerate(zip(x_list, context_list)):
+                    x_list[i] = block(x, context = context, e= e0, **kwargs)
+                    del x
 
         if self.enable_teacache:
             if joint_pass:
diff --git a/wan/text2video.py b/wan/text2video.py
index d99266e..385cdfd 100644
--- a/wan/text2video.py
+++ b/wan/text2video.py
@@ -119,20 +119,23 @@ class WanT2V:
         self.sample_neg_prompt = config.sample_neg_prompt
 
     def generate(self,
-                 input_prompt,
-                 size=(1280, 720),
-                 frame_num=81,
-                 shift=5.0,
-                 sample_solver='unipc',
-                 sampling_steps=50,
-                 guide_scale=5.0,
-                 n_prompt="",
-                 seed=-1,
-                 offload_model=True,
-                 callback = None,
-                 enable_RIFLEx = None,
-                 VAE_tile_size = 0,
-                 joint_pass = False,
+        input_prompt,
+        size=(1280, 720),
+        frame_num=81,
+        shift=5.0,
+        sample_solver='unipc',
+        sampling_steps=50,
+        guide_scale=5.0,
+        n_prompt="",
+        seed=-1,
+        offload_model=True,
+        callback = None,
+        enable_RIFLEx = None,
+        VAE_tile_size = 0,
+        joint_pass = False,
+        slg_layers = None,
+        slg_start = 0.0,
+        slg_end = 1.0,
         ):
         r"""
         Generates video frames from text prompt using diffusion process.
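Again as an illustrative aside rather than part of the patch: the branch added to `WanModel.forward` above means that layers listed in `slg_layers` still run for the conditional prediction but are dropped from the unconditional one, whether the two predictions are computed jointly or separately. `block_runs` below is a hypothetical helper that just restates that decision.

```python
# Illustrative only: which transformer blocks execute under skip layer guidance.
def block_runs(block_idx: int, is_uncond: bool, slg_layers) -> bool:
    """True if this block should run for the given stream (conditional or unconditional)."""
    if slg_layers is not None and block_idx in slg_layers:
        # Skipped layers are dropped from the unconditional prediction only.
        return not is_uncond
    return True

# With the UI default slg_layers == [9]:
assert block_runs(9, is_uncond=True, slg_layers=[9]) is False   # skipped for the uncond pass
assert block_runs(9, is_uncond=False, slg_layers=[9]) is True   # still runs for the cond pass
assert block_runs(8, is_uncond=True, slg_layers=[9]) is True    # other layers are unaffected
```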
@@ -253,6 +256,9 @@ class WanT2V:
             callback(-1, None)
         for i, t in enumerate(tqdm(timesteps)):
             latent_model_input = latents
+            slg_layers_local = None
+            if int(slg_start * sampling_steps) <= i < int(slg_end * sampling_steps):
+                slg_layers_local = slg_layers
             timestep = [t]
             offload.set_step_no_for_lora(self.model, i)
             timestep = torch.stack(timestep)
@@ -260,7 +266,7 @@ class WanT2V:
             # self.model.to(self.device)
             if joint_pass:
                 noise_pred_cond, noise_pred_uncond = self.model(
-                    latent_model_input, t=timestep,current_step=i, **arg_both)
+                    latent_model_input, t=timestep,current_step=i, slg_layers=slg_layers_local, **arg_both)
                 if self._interrupt:
                     return None
             else:
@@ -269,7 +275,7 @@ class WanT2V:
                 if self._interrupt:
                     return None
                 noise_pred_uncond = self.model(
-                    latent_model_input, t=timestep,current_step=i, is_uncond = True, **arg_null)[0]
+                    latent_model_input, t=timestep,current_step=i, is_uncond = True, slg_layers=slg_layers_local, **arg_null)[0]
                 if self._interrupt:
                     return None
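Finally, a hedged end-to-end sketch of how the two front-ends in this patch normalise the SLG settings before calling `generate`: the CLI (`i2v_inference.py`) parses `--slg-layers` as a comma-separated string and passes `--slg-start` / `--slg-end` as fractions, while the Gradio UI (`gradio_server.py`) exposes whole percentages that `generate_video` divides by 100. The literal values mirror the defaults visible in the diff; names like `slg_layers_arg` are illustrative.

```python
# Illustrative only: how the two entry points normalise SLG settings.

# CLI path (i2v_inference.py): "--slg-layers 9,10" -> [9, 10]; start/end are already fractions.
slg_layers_arg = "9,10"                     # hypothetical command-line value
slg_list = [int(x) for x in slg_layers_arg.split(",")] if slg_layers_arg else None
cli_start, cli_end = 0.0, 1.0               # --slg-start / --slg-end defaults

# Gradio path (gradio_server.py): sliders give whole percentages, converted to fractions.
slg_start_perc, slg_end_perc = 10, 90       # UI slider defaults
ui_start, ui_end = slg_start_perc / 100, slg_end_perc / 100

assert slg_list == [9, 10]
assert (ui_start, ui_end) == (0.1, 0.9)
```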