diff --git a/wan/first_last_frame2video.py b/wan/first_last_frame2video.py index 4f300ca..7fedf61 100644 --- a/wan/first_last_frame2video.py +++ b/wan/first_last_frame2video.py @@ -224,7 +224,7 @@ class WanFLF2V: generator=seed_g, device=self.device) - msk = torch.ones(1, 81, lat_h, lat_w, device=self.device) + msk = torch.ones(1, F, lat_h, lat_w, device=self.device) msk[:, 1: -1] = 0 msk = torch.concat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1) msk = msk.view(1, msk.shape[1] // 4, 4, lat_h, lat_w) diff --git a/wan/image2video.py b/wan/image2video.py index 5004f46..bb839a4 100644 --- a/wan/image2video.py +++ b/wan/image2video.py @@ -204,7 +204,7 @@ class WanI2V: generator=seed_g, device=self.device) - msk = torch.ones(1, 81, lat_h, lat_w, device=self.device) + msk = torch.ones(1, F, lat_h, lat_w, device=self.device) msk[:, 1:] = 0 msk = torch.concat([ torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]