Compare commits

...

3 Commits

Author SHA1 Message Date
yupeng1111
926066bd64
Merge c5a6d87db7 into 841fe5237b 2025-12-08 17:53:31 +11:00
Shiguang Ai
841fe5237b
When frame_num is changed from 81 the tensor does not match (#533) 2025-12-02 14:53:33 +08:00
澎鹏
c5a6d87db7 fix frame size bug 2025-04-30 14:44:10 +08:00
2 changed files with 8 additions and 2 deletions

View File

@ -13,6 +13,7 @@ import numpy as np
import torch
import torch.cuda.amp as amp
import torch.distributed as dist
import torchvision
import torchvision.transforms.functional as TF
from tqdm import tqdm
@ -211,7 +212,12 @@ class WanFLF2V:
round(last_frame_size[1] * last_frame_resize_ratio),
]
# 2. center crop
last_frame = TF.center_crop(last_frame, last_frame_size)
transform = torchvision.transforms.Compose([
torchvision.transforms.Resize((last_frame_size[0], last_frame_size[1])),
torchvision.transforms.CenterCrop((first_frame_size[0], first_frame_size[1]))
])
last_frame = transform(last_frame)
max_seq_len = ((F - 1) // self.vae_stride[0] + 1) * lat_h * lat_w // (
self.patch_size[1] * self.patch_size[2])

View File

@ -207,7 +207,7 @@ class WanI2V:
generator=seed_g,
device=self.device)
msk = torch.ones(1, 81, lat_h, lat_w, device=self.device)
msk = torch.ones(1, F, lat_h, lat_w, device=self.device)
msk[:, 1:] = 0
msk = torch.concat([
torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]