Compare commits

...

3 Commits

Author SHA1 Message Date
yupeng1111
926066bd64
Merge c5a6d87db7 into 841fe5237b 2025-12-08 17:53:31 +11:00
Shiguang Ai
841fe5237b
When frame_num is changed from 81 the tensor does not match (#533) 2025-12-02 14:53:33 +08:00
澎鹏
c5a6d87db7 fix frame size bug 2025-04-30 14:44:10 +08:00
2 changed files with 8 additions and 2 deletions

View File

@ -13,6 +13,7 @@ import numpy as np
import torch import torch
import torch.cuda.amp as amp import torch.cuda.amp as amp
import torch.distributed as dist import torch.distributed as dist
import torchvision
import torchvision.transforms.functional as TF import torchvision.transforms.functional as TF
from tqdm import tqdm from tqdm import tqdm
@ -211,7 +212,12 @@ class WanFLF2V:
round(last_frame_size[1] * last_frame_resize_ratio), round(last_frame_size[1] * last_frame_resize_ratio),
] ]
# 2. center crop # 2. center crop
last_frame = TF.center_crop(last_frame, last_frame_size) transform = torchvision.transforms.Compose([
torchvision.transforms.Resize((last_frame_size[0], last_frame_size[1])),
torchvision.transforms.CenterCrop((first_frame_size[0], first_frame_size[1]))
])
last_frame = transform(last_frame)
max_seq_len = ((F - 1) // self.vae_stride[0] + 1) * lat_h * lat_w // ( max_seq_len = ((F - 1) // self.vae_stride[0] + 1) * lat_h * lat_w // (
self.patch_size[1] * self.patch_size[2]) self.patch_size[1] * self.patch_size[2])

View File

@ -207,7 +207,7 @@ class WanI2V:
generator=seed_g, generator=seed_g,
device=self.device) device=self.device)
msk = torch.ones(1, 81, lat_h, lat_w, device=self.device) msk = torch.ones(1, F, lat_h, lat_w, device=self.device)
msk[:, 1:] = 0 msk[:, 1:] = 0
msk = torch.concat([ msk = torch.concat([
torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:] torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]