Compare commits

...

3 Commits

Author SHA1 Message Date
yupeng1111
4037cb07e1
Merge c5a6d87db7 into 827906c30f 2025-06-05 17:51:35 +05:30
Shiwei Zhang
827906c30f
Update README.md 2025-06-05 10:02:00 +08:00
澎鹏
c5a6d87db7 fix frame size bug 2025-04-30 14:44:10 +08:00
2 changed files with 8 additions and 1 deletions

View File

@ -36,6 +36,7 @@ In this repository, we present **Wan2.1**, a comprehensive and open suite of vid
## Community Works
If your work has improved **Wan2.1** and you would like more people to see it, please inform us.
- [ATI](https://github.com/bytedance/ATI), built on **Wan2.1-I2V-14B**, is a trajectory-based motion-control framework that unifies object, local, and camera movements in video generation. Refer to [their website](https://anytraj.github.io/) for more examples.
- [Phantom](https://github.com/Phantom-video/Phantom) has developed a unified video generation framework for single and multi-subject references based on both **Wan2.1-T2V-1.3B** and **Wan2.1-T2V-14B**. Please refer to [their examples](https://github.com/Phantom-video/Phantom).
- [UniAnimate-DiT](https://github.com/ali-vilab/UniAnimate-DiT), based on **Wan2.1-14B-I2V**, has trained a Human image animation model and has open-sourced the inference and training code. Feel free to enjoy it!
- [CFG-Zero](https://github.com/WeichenFan/CFG-Zero-star) enhances **Wan2.1** (covering both T2V and I2V models) from the perspective of CFG.

View File

@ -13,6 +13,7 @@ import numpy as np
import torch
import torch.cuda.amp as amp
import torch.distributed as dist
import torchvision
import torchvision.transforms.functional as TF
from tqdm import tqdm
@ -211,7 +212,12 @@ class WanFLF2V:
round(last_frame_size[1] * last_frame_resize_ratio),
]
# 2. center crop
last_frame = TF.center_crop(last_frame, last_frame_size)
transform = torchvision.transforms.Compose([
torchvision.transforms.Resize((last_frame_size[0], last_frame_size[1])),
torchvision.transforms.CenterCrop((first_frame_size[0], first_frame_size[1]))
])
last_frame = transform(last_frame)
max_seq_len = ((F - 1) // self.vae_stride[0] + 1) * lat_h * lat_w // (
self.patch_size[1] * self.patch_size[2])