diff --git a/wan/image2video.py b/wan/image2video.py index 6882c53..17849b2 100644 --- a/wan/image2video.py +++ b/wan/image2video.py @@ -207,7 +207,7 @@ class WanI2V: generator=seed_g, device=self.device) - msk = torch.ones(1, 81, lat_h, lat_w, device=self.device) + msk = torch.ones(1, frame_num, lat_h, lat_w, device=self.device) msk[:, 1:] = 0 msk = torch.concat([ torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]