fixed standin modulation bug

This commit is contained in:
deepbeepmeep 2025-08-30 04:08:22 +02:00
parent 492fa97c94
commit e15afefb3d

View File

@@ -1308,7 +1308,7 @@ class WanModel(ModelMixin, ConfigMixin):
                 kwargs["standin_phase"] = 2
                 if (current_step == 0 or not standin_cache_enabled) and x_id == 0:
                     standin_x = self.patch_embedding(standin_ref).to(modulation_dtype).flatten(2).transpose(1, 2)
-                    standin_e = self.time_embedding( sinusoidal_embedding_1d(self.freq_dim, torch.zeros_like(t)) )
+                    standin_e = self.time_embedding( sinusoidal_embedding_1d(self.freq_dim, torch.zeros_like(t)).to(modulation_dtype) )
                     standin_e0 = self.time_projection(standin_e).unflatten(1, (6, self.dim)).to(e.dtype)
                     standin_e = standin_ref = None