This commit is contained in:
yupeng1111 2025-05-28 11:39:09 +10:00 committed by GitHub
commit 0b95ddb46d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -13,6 +13,7 @@ import numpy as np
import torch import torch
import torch.cuda.amp as amp import torch.cuda.amp as amp
import torch.distributed as dist import torch.distributed as dist
import torchvision
import torchvision.transforms.functional as TF import torchvision.transforms.functional as TF
from tqdm import tqdm from tqdm import tqdm
@ -211,7 +212,12 @@ class WanFLF2V:
round(last_frame_size[1] * last_frame_resize_ratio), round(last_frame_size[1] * last_frame_resize_ratio),
] ]
# 2. center crop # 2. center crop
last_frame = TF.center_crop(last_frame, last_frame_size) transform = torchvision.transforms.Compose([
torchvision.transforms.Resize((last_frame_size[0], last_frame_size[1])),
torchvision.transforms.CenterCrop((first_frame_size[0], first_frame_size[1]))
])
last_frame = transform(last_frame)
max_seq_len = ((F - 1) // self.vae_stride[0] + 1) * lat_h * lat_w // ( max_seq_len = ((F - 1) // self.vae_stride[0] + 1) * lat_h * lat_w // (
self.patch_size[1] * self.patch_size[2]) self.patch_size[1] * self.patch_size[2])