diff --git a/models/wan/modules/model.py b/models/wan/modules/model.py index 43cf7f5..95faa4d 100644 --- a/models/wan/modules/model.py +++ b/models/wan/modules/model.py @@ -1308,7 +1308,7 @@ class WanModel(ModelMixin, ConfigMixin): kwargs["standin_phase"] = 2 if (current_step == 0 or not standin_cache_enabled) and x_id == 0: standin_x = self.patch_embedding(standin_ref).to(modulation_dtype).flatten(2).transpose(1, 2) - standin_e = self.time_embedding( sinusoidal_embedding_1d(self.freq_dim, torch.zeros_like(t)) ) + standin_e = self.time_embedding( sinusoidal_embedding_1d(self.freq_dim, torch.zeros_like(t)).to(modulation_dtype) ) standin_e0 = self.time_projection(standin_e).unflatten(1, (6, self.dim)).to(e.dtype) standin_e = standin_ref = None