diff --git a/wan/image2video.py b/wan/image2video.py index 468f17c..66d8eea 100644 --- a/wan/image2video.py +++ b/wan/image2video.py @@ -204,13 +204,19 @@ class WanI2V: generator=seed_g, device=self.device) - msk = torch.ones(1, 81, lat_h, lat_w, device=self.device) + # 20250226 pftq: fix noise tensor that assumed frame count hardcoded at 81 (21 latent frames) + latent_frame_num = (frame_num - 1) // self.vae_stride[0] + 1 + noise = torch.randn(16, latent_frame_num, lat_h, lat_w, dtype=torch.float32, generator=seed_g, device=self.device) + + #msk = torch.ones(1, 81, lat_h, lat_w, device=self.device) + msk = torch.ones(1, F, lat_h, lat_w, device=self.device) # 20250226 pftq: fixed frame count being hardcoded as 81 msk[:, 1:] = 0 msk = torch.concat([ torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:] ], dim=1) - msk = msk.view(1, msk.shape[1] // 4, 4, lat_h, lat_w) + #msk = msk.view(1, msk.shape[1] // 4, 4, lat_h, lat_w) + msk = msk.view(1, latent_frame_num, 4, lat_h, lat_w) # 20250226 pftq: align to actual latent frame count, not hardcoded 81 frames msk = msk.transpose(1, 2)[0] if n_prompt == "": @@ -239,7 +245,8 @@ class WanI2V: torch.nn.functional.interpolate( img[None].cpu(), size=(h, w), mode='bicubic').transpose( 0, 1), - torch.zeros(3, 80, h, w) + #torch.zeros(3, 80, h, w) + torch.zeros(3, F-1, h, w) # 20250226 pftq: fixed hardcoded 80; generalized to frame_num - 1 ], dim=1).to(self.device) ])[0]