mirror of
https://github.com/Wan-Video/Wan2.1.git
synced 2025-11-04 14:16:57 +00:00
optimization for i2v with CausVid
This commit is contained in:
parent
9c3c6f3b74
commit
6706709230
@ -492,8 +492,7 @@ class MMSingleStreamBlock(nn.Module):
|
|||||||
return img, txt
|
return img, txt
|
||||||
|
|
||||||
class HYVideoDiffusionTransformer(ModelMixin, ConfigMixin):
|
class HYVideoDiffusionTransformer(ModelMixin, ConfigMixin):
|
||||||
@staticmethod
|
def preprocess_loras(self, model_filename, sd):
|
||||||
def preprocess_loras(model_filename, sd):
|
|
||||||
if not "i2v" in model_filename:
|
if not "i2v" in model_filename:
|
||||||
return sd
|
return sd
|
||||||
new_sd = {}
|
new_sd = {}
|
||||||
|
|||||||
@ -330,8 +330,11 @@ class WanI2V:
|
|||||||
'current_step' :i,
|
'current_step' :i,
|
||||||
})
|
})
|
||||||
|
|
||||||
|
if guide_scale == 1:
|
||||||
if joint_pass:
|
noise_pred = self.model( [latent_model_input], context=[context], audio_scale = None if audio_scale == None else [audio_scale], x_id=0, **kwargs, )[0]
|
||||||
|
if self._interrupt:
|
||||||
|
return None
|
||||||
|
elif joint_pass:
|
||||||
if audio_proj == None:
|
if audio_proj == None:
|
||||||
noise_pred_cond, noise_pred_uncond = self.model(
|
noise_pred_cond, noise_pred_uncond = self.model(
|
||||||
[latent_model_input, latent_model_input],
|
[latent_model_input, latent_model_input],
|
||||||
@ -347,13 +350,7 @@ class WanI2V:
|
|||||||
if self._interrupt:
|
if self._interrupt:
|
||||||
return None
|
return None
|
||||||
else:
|
else:
|
||||||
noise_pred_cond = self.model(
|
noise_pred_cond = self.model( [latent_model_input], context=[context], audio_scale = None if audio_scale == None else [audio_scale], x_id=0, **kwargs, )[0]
|
||||||
[latent_model_input],
|
|
||||||
context=[context],
|
|
||||||
audio_scale = None if audio_scale == None else [audio_scale],
|
|
||||||
x_id=0,
|
|
||||||
**kwargs,
|
|
||||||
)[0]
|
|
||||||
if self._interrupt:
|
if self._interrupt:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
@ -377,6 +374,7 @@ class WanI2V:
|
|||||||
return None
|
return None
|
||||||
del latent_model_input
|
del latent_model_input
|
||||||
|
|
||||||
|
if guide_scale > 1:
|
||||||
# CFG Zero *. Thanks to https://github.com/WeichenFan/CFG-Zero-star/
|
# CFG Zero *. Thanks to https://github.com/WeichenFan/CFG-Zero-star/
|
||||||
if cfg_star_switch:
|
if cfg_star_switch:
|
||||||
positive_flat = noise_pred_cond.view(batch_size, -1)
|
positive_flat = noise_pred_cond.view(batch_size, -1)
|
||||||
@ -393,6 +391,7 @@ class WanI2V:
|
|||||||
noise_pred = noise_pred_uncond + guide_scale * (noise_pred_cond - noise_pred_uncond)
|
noise_pred = noise_pred_uncond + guide_scale * (noise_pred_cond - noise_pred_uncond)
|
||||||
else:
|
else:
|
||||||
noise_pred = noise_pred_uncond + guide_scale * (noise_pred_noaudio - noise_pred_uncond) + audio_cfg_scale * (noise_pred_cond - noise_pred_noaudio)
|
noise_pred = noise_pred_uncond + guide_scale * (noise_pred_noaudio - noise_pred_uncond) + audio_cfg_scale * (noise_pred_cond - noise_pred_noaudio)
|
||||||
|
|
||||||
noise_pred_uncond, noise_pred_noaudio = None, None
|
noise_pred_uncond, noise_pred_noaudio = None, None
|
||||||
temp_x0 = sample_scheduler.step(
|
temp_x0 = sample_scheduler.step(
|
||||||
noise_pred.unsqueeze(0),
|
noise_pred.unsqueeze(0),
|
||||||
|
|||||||
@ -589,8 +589,7 @@ class MLPProj(torch.nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
class WanModel(ModelMixin, ConfigMixin):
|
class WanModel(ModelMixin, ConfigMixin):
|
||||||
@staticmethod
|
def preprocess_loras(self, model_filename, sd):
|
||||||
def preprocess_loras(model_filename, sd):
|
|
||||||
|
|
||||||
first = next(iter(sd), None)
|
first = next(iter(sd), None)
|
||||||
if first == None:
|
if first == None:
|
||||||
@ -634,8 +633,8 @@ class WanModel(ModelMixin, ConfigMixin):
|
|||||||
print(f"Lora alpha'{alpha_key}' is missing")
|
print(f"Lora alpha'{alpha_key}' is missing")
|
||||||
new_sd.update(new_alphas)
|
new_sd.update(new_alphas)
|
||||||
sd = new_sd
|
sd = new_sd
|
||||||
|
from wgp import test_class_i2v
|
||||||
if "text2video" in model_filename:
|
if not test_class_i2v(model_filename):
|
||||||
new_sd = {}
|
new_sd = {}
|
||||||
# convert loras for i2v to t2v
|
# convert loras for i2v to t2v
|
||||||
for k,v in sd.items():
|
for k,v in sd.items():
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user