Fix msk frame_num

This commit is contained in:
Feiteng 2025-05-09 07:36:53 +00:00
parent 204f899b64
commit 57843a2111

View File

@ -204,7 +204,7 @@ class WanI2V:
generator=seed_g,
device=self.device)
msk = torch.ones(1, 81, lat_h, lat_w, device=self.device)
msk = torch.ones(1, frame_num, lat_h, lat_w, device=self.device)
msk[:, 1:] = 0
msk = torch.concat([
torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]