mirror of
https://github.com/Wan-Video/Wan2.1.git
synced 2025-12-18 21:22:11 +00:00
Compare commits
3 Commits
86a4de8ab1
...
926066bd64
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
926066bd64 | ||
|
|
841fe5237b | ||
|
|
c5a6d87db7 |
@ -13,6 +13,7 @@ import numpy as np
|
||||
import torch
|
||||
import torch.cuda.amp as amp
|
||||
import torch.distributed as dist
|
||||
import torchvision
|
||||
import torchvision.transforms.functional as TF
|
||||
from tqdm import tqdm
|
||||
|
||||
@ -211,7 +212,12 @@ class WanFLF2V:
|
||||
round(last_frame_size[1] * last_frame_resize_ratio),
|
||||
]
|
||||
# 2. center crop
|
||||
last_frame = TF.center_crop(last_frame, last_frame_size)
|
||||
transform = torchvision.transforms.Compose([
|
||||
torchvision.transforms.Resize((last_frame_size[0], last_frame_size[1])),
|
||||
torchvision.transforms.CenterCrop((first_frame_size[0], first_frame_size[1]))
|
||||
])
|
||||
|
||||
last_frame = transform(last_frame)
|
||||
|
||||
max_seq_len = ((F - 1) // self.vae_stride[0] + 1) * lat_h * lat_w // (
|
||||
self.patch_size[1] * self.patch_size[2])
|
||||
|
||||
@ -207,7 +207,7 @@ class WanI2V:
|
||||
generator=seed_g,
|
||||
device=self.device)
|
||||
|
||||
msk = torch.ones(1, 81, lat_h, lat_w, device=self.device)
|
||||
msk = torch.ones(1, F, lat_h, lat_w, device=self.device)
|
||||
msk[:, 1:] = 0
|
||||
msk = torch.concat([
|
||||
torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]
|
||||
|
||||
Loading…
Reference in New Issue
Block a user