Fix I2V frame count being hardcoded at 81 in WanI2V.generate()

To find the changed lines, search the code for comments tagged "20250226" (most are tagged "20250226 pftq").
This commit is contained in:
pftq 2025-02-26 13:18:01 -08:00 committed by GitHub
parent b5d66656ae
commit f9267254b5
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@@ -204,13 +204,19 @@ class WanI2V:
generator=seed_g,
device=self.device)
msk = torch.ones(1, 81, lat_h, lat_w, device=self.device)
#20250226 fix noise assuming frames hardcoded at 81 (21 latent frames)
latent_frame_num = (frame_num - 1) // self.vae_stride[0] + 1
noise = torch.randn(16, latent_frame_num, lat_h, lat_w, dtype=torch.float32, generator=seed_g, device=self.device)
#msk = torch.ones(1, 81, lat_h, lat_w, device=self.device)
msk = torch.ones(1, F, lat_h, lat_w, device=self.device) # 20250226 pftq: Fixed frames being hardcoded as 81
msk[:, 1:] = 0
msk = torch.concat([
torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]
],
dim=1)
msk = msk.view(1, msk.shape[1] // 4, 4, lat_h, lat_w)
#msk = msk.view(1, msk.shape[1] // 4, 4, lat_h, lat_w)
msk = msk.view(1, latent_frame_num, 4, lat_h, lat_w) # 20250226 pftq: align to actual frames, not hardcoded 81 frames
msk = msk.transpose(1, 2)[0]
if n_prompt == "":
@@ -239,7 +245,8 @@ class WanI2V:
torch.nn.functional.interpolate(
img[None].cpu(), size=(h, w), mode='bicubic').transpose(
0, 1),
torch.zeros(3, 80, h, w)
#torch.zeros(3, 80, h, w)
torch.zeros(3, F-1, h, w) # 20250226 pftq: fixed 80 being hardcoded frame-1
],
dim=1).to(self.device)
])[0]