From 67067092300ba1baa80aa600bb161622ee142b7f Mon Sep 17 00:00:00 2001
From: DeepBeepMeep
Date: Tue, 20 May 2025 23:00:32 +0200
Subject: [PATCH] optimization for i2v with CausVid

---
 hyvideo/modules/models.py |  3 +--
 wan/image2video.py        | 43 +++++++++++++++++++++----------------------
 wan/modules/model.py      |  7 +++----
 3 files changed, 25 insertions(+), 28 deletions(-)

diff --git a/hyvideo/modules/models.py b/hyvideo/modules/models.py
index b8276fd..ef3f766 100644
--- a/hyvideo/modules/models.py
+++ b/hyvideo/modules/models.py
@@ -492,8 +492,7 @@ class MMSingleStreamBlock(nn.Module):
         return img, txt
 
 class HYVideoDiffusionTransformer(ModelMixin, ConfigMixin):
-    @staticmethod
-    def preprocess_loras(model_filename, sd):
+    def preprocess_loras(self, model_filename, sd):
         if not "i2v" in model_filename:
             return sd
         new_sd = {}
diff --git a/wan/image2video.py b/wan/image2video.py
index d8c5aaa..6a30c91 100644
--- a/wan/image2video.py
+++ b/wan/image2video.py
@@ -330,8 +330,11 @@ class WanI2V:
                 'current_step' :i,
                 })
 
-
-            if joint_pass:
+            if guide_scale == 1:
+                noise_pred = self.model( [latent_model_input], context=[context], audio_scale = None if audio_scale == None else [audio_scale], x_id=0, **kwargs, )[0]
+                if self._interrupt:
+                    return None
+            elif joint_pass:
                 if audio_proj == None:
                     noise_pred_cond, noise_pred_uncond = self.model(
                         [latent_model_input, latent_model_input],
@@ -347,13 +350,7 @@ class WanI2V:
                 if self._interrupt:
                     return None
             else:
-                noise_pred_cond = self.model(
-                    [latent_model_input],
-                    context=[context],
-                    audio_scale = None if audio_scale == None else [audio_scale],
-                    x_id=0,
-                    **kwargs,
-                )[0]
+                noise_pred_cond = self.model( [latent_model_input], context=[context], audio_scale = None if audio_scale == None else [audio_scale], x_id=0, **kwargs, )[0]
                 if self._interrupt:
                     return None
 
@@ -377,22 +374,24 @@ class WanI2V:
                     return None
             del latent_model_input
 
-            # CFG Zero *. Thanks to https://github.com/WeichenFan/CFG-Zero-star/
-            if cfg_star_switch:
-                positive_flat = noise_pred_cond.view(batch_size, -1)
-                negative_flat = noise_pred_uncond.view(batch_size, -1)
+            if guide_scale > 1:
+                # CFG Zero *. Thanks to https://github.com/WeichenFan/CFG-Zero-star/
+                if cfg_star_switch:
+                    positive_flat = noise_pred_cond.view(batch_size, -1)
+                    negative_flat = noise_pred_uncond.view(batch_size, -1)
 
-                alpha = optimized_scale(positive_flat,negative_flat)
-                alpha = alpha.view(batch_size, 1, 1, 1)
+                    alpha = optimized_scale(positive_flat,negative_flat)
+                    alpha = alpha.view(batch_size, 1, 1, 1)
 
-                if (i <= cfg_zero_step):
-                    noise_pred = noise_pred_cond*0. # it would be faster not to compute noise_pred...
+                    if (i <= cfg_zero_step):
+                        noise_pred = noise_pred_cond*0. # it would be faster not to compute noise_pred...
+                    else:
+                        noise_pred_uncond *= alpha
+                if audio_scale == None:
+                    noise_pred = noise_pred_uncond + guide_scale * (noise_pred_cond - noise_pred_uncond)
                 else:
-                    noise_pred_uncond *= alpha
-            if audio_scale == None:
-                noise_pred = noise_pred_uncond + guide_scale * (noise_pred_cond - noise_pred_uncond)
-            else:
-                noise_pred = noise_pred_uncond + guide_scale * (noise_pred_noaudio - noise_pred_uncond) + audio_cfg_scale * (noise_pred_cond - noise_pred_noaudio)
+                    noise_pred = noise_pred_uncond + guide_scale * (noise_pred_noaudio - noise_pred_uncond) + audio_cfg_scale * (noise_pred_cond - noise_pred_noaudio)
+
             noise_pred_uncond, noise_pred_noaudio = None, None
             temp_x0 = sample_scheduler.step(
                 noise_pred.unsqueeze(0),
diff --git a/wan/modules/model.py b/wan/modules/model.py
index 9e01486..39f5796 100644
--- a/wan/modules/model.py
+++ b/wan/modules/model.py
@@ -589,8 +589,7 @@ class MLPProj(torch.nn.Module):
 
 
 class WanModel(ModelMixin, ConfigMixin):
-    @staticmethod
-    def preprocess_loras(model_filename, sd):
+    def preprocess_loras(self, model_filename, sd):
         first = next(iter(sd), None)
 
         if first == None:
@@ -634,8 +633,8 @@ class WanModel(ModelMixin, ConfigMixin):
                     print(f"Lora alpha'{alpha_key}' is missing")
             new_sd.update(new_alphas)
             sd = new_sd
-
-        if "text2video" in model_filename:
+        from wgp import test_class_i2v
+        if not test_class_i2v(model_filename):
             new_sd = {}
             # convert loras for i2v to t2v
             for k,v in sd.items():
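
Note on the wan/image2video.py changes: with a CausVid-distilled checkpoint the guidance scale is 1, so the unconditional (and no-audio) passes contribute nothing and the patch skips them entirely; when guidance is active, CFG-Zero* rescales the unconditional prediction by a per-sample projection coefficient before the usual CFG mix. The sketch below is illustrative only, not code from the repository: optimized_scale is re-derived here from the CFG-Zero-star idea, and combine_predictions, zero_init and the tensor layout are assumptions made for the example.

import torch

def optimized_scale(positive_flat: torch.Tensor, negative_flat: torch.Tensor) -> torch.Tensor:
    # Per-sample projection coefficient alpha = <cond, uncond> / ||uncond||^2
    # (the CFG-Zero* rescaling of the unconditional prediction); sketch only.
    dot = (positive_flat * negative_flat).sum(dim=1, keepdim=True)
    squared_norm = negative_flat.pow(2).sum(dim=1, keepdim=True) + 1e-8
    return dot / squared_norm

def combine_predictions(noise_pred_cond, noise_pred_uncond, guide_scale,
                        cfg_star_switch=False, zero_init=False):
    # guide_scale == 1 (e.g. CausVid-style distilled weights): the conditional
    # pass alone is the final prediction, so the unconditional pass is skipped.
    if guide_scale == 1:
        return noise_pred_cond

    batch_size = noise_pred_cond.shape[0]
    if cfg_star_switch:
        if zero_init:
            # Early steps of CFG-Zero*: return an all-zero prediction.
            return torch.zeros_like(noise_pred_cond)
        alpha = optimized_scale(noise_pred_cond.view(batch_size, -1),
                                noise_pred_uncond.view(batch_size, -1))
        # Broadcast alpha back over the non-batch dimensions.
        alpha = alpha.view(batch_size, *([1] * (noise_pred_cond.dim() - 1)))
        noise_pred_uncond = noise_pred_uncond * alpha

    # Standard classifier-free guidance mix.
    return noise_pred_uncond + guide_scale * (noise_pred_cond - noise_pred_uncond)

The sketch assumes the leading dimension is the batch; in the patch itself the same quantities are computed inline inside WanI2V's denoising loop, with an extra audio_cfg_scale term when an audio conditioning branch is present.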