optimization for i2v with CausVid

This commit is contained in:
DeepBeepMeep 2025-05-20 23:00:32 +02:00
parent 9c3c6f3b74
commit 6706709230
3 changed files with 25 additions and 28 deletions

View File

@ -492,8 +492,7 @@ class MMSingleStreamBlock(nn.Module):
return img, txt return img, txt
class HYVideoDiffusionTransformer(ModelMixin, ConfigMixin): class HYVideoDiffusionTransformer(ModelMixin, ConfigMixin):
@staticmethod def preprocess_loras(self, model_filename, sd):
def preprocess_loras(model_filename, sd):
if not "i2v" in model_filename: if not "i2v" in model_filename:
return sd return sd
new_sd = {} new_sd = {}

View File

@ -330,8 +330,11 @@ class WanI2V:
'current_step' :i, 'current_step' :i,
}) })
if guide_scale == 1:
if joint_pass: noise_pred = self.model( [latent_model_input], context=[context], audio_scale = None if audio_scale == None else [audio_scale], x_id=0, **kwargs, )[0]
if self._interrupt:
return None
elif joint_pass:
if audio_proj == None: if audio_proj == None:
noise_pred_cond, noise_pred_uncond = self.model( noise_pred_cond, noise_pred_uncond = self.model(
[latent_model_input, latent_model_input], [latent_model_input, latent_model_input],
@ -347,13 +350,7 @@ class WanI2V:
if self._interrupt: if self._interrupt:
return None return None
else: else:
noise_pred_cond = self.model( noise_pred_cond = self.model( [latent_model_input], context=[context], audio_scale = None if audio_scale == None else [audio_scale], x_id=0, **kwargs, )[0]
[latent_model_input],
context=[context],
audio_scale = None if audio_scale == None else [audio_scale],
x_id=0,
**kwargs,
)[0]
if self._interrupt: if self._interrupt:
return None return None
@ -377,22 +374,24 @@ class WanI2V:
return None return None
del latent_model_input del latent_model_input
# CFG Zero *. Thanks to https://github.com/WeichenFan/CFG-Zero-star/ if guide_scale > 1:
if cfg_star_switch: # CFG Zero *. Thanks to https://github.com/WeichenFan/CFG-Zero-star/
positive_flat = noise_pred_cond.view(batch_size, -1) if cfg_star_switch:
negative_flat = noise_pred_uncond.view(batch_size, -1) positive_flat = noise_pred_cond.view(batch_size, -1)
negative_flat = noise_pred_uncond.view(batch_size, -1)
alpha = optimized_scale(positive_flat,negative_flat) alpha = optimized_scale(positive_flat,negative_flat)
alpha = alpha.view(batch_size, 1, 1, 1) alpha = alpha.view(batch_size, 1, 1, 1)
if (i <= cfg_zero_step): if (i <= cfg_zero_step):
noise_pred = noise_pred_cond*0. # it would be faster not to compute noise_pred... noise_pred = noise_pred_cond*0. # it would be faster not to compute noise_pred...
else:
noise_pred_uncond *= alpha
if audio_scale == None:
noise_pred = noise_pred_uncond + guide_scale * (noise_pred_cond - noise_pred_uncond)
else: else:
noise_pred_uncond *= alpha noise_pred = noise_pred_uncond + guide_scale * (noise_pred_noaudio - noise_pred_uncond) + audio_cfg_scale * (noise_pred_cond - noise_pred_noaudio)
if audio_scale == None:
noise_pred = noise_pred_uncond + guide_scale * (noise_pred_cond - noise_pred_uncond)
else:
noise_pred = noise_pred_uncond + guide_scale * (noise_pred_noaudio - noise_pred_uncond) + audio_cfg_scale * (noise_pred_cond - noise_pred_noaudio)
noise_pred_uncond, noise_pred_noaudio = None, None noise_pred_uncond, noise_pred_noaudio = None, None
temp_x0 = sample_scheduler.step( temp_x0 = sample_scheduler.step(
noise_pred.unsqueeze(0), noise_pred.unsqueeze(0),

View File

@ -589,8 +589,7 @@ class MLPProj(torch.nn.Module):
class WanModel(ModelMixin, ConfigMixin): class WanModel(ModelMixin, ConfigMixin):
@staticmethod def preprocess_loras(self, model_filename, sd):
def preprocess_loras(model_filename, sd):
first = next(iter(sd), None) first = next(iter(sd), None)
if first == None: if first == None:
@ -634,8 +633,8 @@ class WanModel(ModelMixin, ConfigMixin):
print(f"Lora alpha'{alpha_key}' is missing") print(f"Lora alpha'{alpha_key}' is missing")
new_sd.update(new_alphas) new_sd.update(new_alphas)
sd = new_sd sd = new_sd
from wgp import test_class_i2v
if "text2video" in model_filename: if not test_class_i2v(model_filename):
new_sd = {} new_sd = {}
# convert loras for i2v to t2v # convert loras for i2v to t2v
for k,v in sd.items(): for k,v in sd.items():