Mirror of https://github.com/Wan-Video/Wan2.1.git, synced 2025-12-18 13:12:16 +00:00

Compare commits: 59a3d1ba28...b8deedd627 (7 commits)
| Author | SHA1 | Date |
|---|---|---|
| | b8deedd627 | |
| | 854bd88e7f | |
| | 8177ee5bc6 | |
| | f134d60bcc | |
| | bcc437daed | |
| | e4f90fa81f | |
| | 8f11f5ab6e | |
@@ -36,6 +36,10 @@ In this repository, we present **Wan2.1**, a comprehensive and open suite of vid

## Community Works

If your work has improved **Wan2.1** and you would like more people to see it, please let us know.

- [LightX2V](https://github.com/ModelTC/LightX2V), a lightweight and efficient video generation framework that integrates **Wan2.1** and **Wan2.2** and supports multiple engineering acceleration techniques for fast inference; it can run on an RTX 5090 and even an RTX 4060 (8GB VRAM).
- [DriVerse](https://github.com/shalfun/DriVerse), an autonomous driving world model based on **Wan2.1-14B-I2V** that generates future driving videos conditioned on any scene frame and a given trajectory. Refer to the [project page](https://github.com/shalfun/DriVerse/tree/main) for more examples.
- [Training-Free-WAN-Editing](https://github.com/KyujinHan/Awesome-Training-Free-WAN2.1-Editing), built on **Wan2.1-T2V-1.3B**, enables training-free video editing using image-based methods such as [FlowEdit](https://arxiv.org/abs/2412.08629) and [FlowAlign](https://arxiv.org/abs/2505.23145).
- [Wan-Move](https://github.com/ali-vilab/Wan-Move), accepted to NeurIPS 2025, a framework that brings **Wan2.1-I2V-14B** to SOTA fine-grained, point-level motion control. Refer to [their project page](https://wan-move.github.io/) for more information.
- [EchoShot](https://github.com/JoHnneyWang/EchoShot), a native multi-shot portrait video generation model based on **Wan2.1-T2V-1.3B** that generates multiple video clips featuring the same character and offers highly flexible content controllability. Refer to [their project page](https://johnneywang.github.io/EchoShot-webpage/) for more information.
- [AniCrafter](https://github.com/MyNiuuu/AniCrafter), a human-centric animation model based on **Wan2.1-14B-I2V** that conditions the video diffusion model on 3DGS avatars to insert and animate anyone into any scene following given motion sequences. Refer to the [project page](https://myniuuu.github.io/AniCrafter) for more examples.
- [HyperMotion](https://vivocameraresearch.github.io/hypermotion/), a human image animation framework based on **Wan2.1** that addresses the challenge of generating complex human body motions in pose-guided animation. Refer to [their website](https://vivocameraresearch.github.io/magictryon/) for more examples.
@@ -99,7 +99,8 @@ class WanI2V:
             tokenizer_path=os.path.join(checkpoint_dir, config.clip_tokenizer))

         logging.info(f"Creating WanModel from {checkpoint_dir}")
-        self.model = WanModel.from_pretrained(checkpoint_dir)
+        # self.model = WanModel.from_pretrained(checkpoint_dir)
+        self.model = WanModel.from_pretrained(checkpoint_dir, torch_dtype=torch.float8_e4m3fn)
         self.model.eval().requires_grad_(False)

         if t5_fsdp or dit_fsdp or use_usp:
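This hunk swaps the default load for fp8 weight storage: `float8_e4m3fn` stores one byte per parameter instead of two, roughly halving the DiT's weight footprint. A back-of-the-envelope sketch of the savings (the 14e9 parameter count is an assumption based on the 14B model name, not taken from the diff):

```python
# Illustrative VRAM estimate for the transformer weights alone.
num_params = 14e9                                           # assumed from "14B"
print(f"bf16 weights: {num_params * 2 / 1024**3:.1f} GiB")  # ~26.1 GiB
print(f"fp8  weights: {num_params * 1 / 1024**3:.1f} GiB")  # ~13.0 GiB
```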
@@ -178,6 +179,7 @@ class WanI2V:
                 - H: Frame height (from max_area)
                 - W: Frame width (from max_area)
         """
+        offload_model = False
         img = TF.to_tensor(img).sub_(0.5).div_(0.5).to(self.device)

         F = frame_num
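Note that hard-coding `offload_model = False` silently overrides whatever the caller passed in. If the intent is to keep the fp8 weights resident on the GPU, a more explicit variant (hypothetical, not part of this diff) would at least log the override:

```python
import logging

offload_model = True  # stand-in for the caller's argument

# Hypothetical guard: keep weights on-device, but tell the caller why.
if offload_model:
    logging.warning("offload_model=True was requested, but weights stay "
                    "on the GPU for fp8 inference; disabling offload.")
    offload_model = False
```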
@@ -298,7 +300,19 @@ class WanI2V:
         if offload_model:
             torch.cuda.empty_cache()

         self.model.to(self.device)
+        # import gc
+        # del self.text_encoder
+        # del self.clip
+        # del self.vae
+        # gc.collect()  # trigger garbage collection immediately
+        # torch.cuda.empty_cache()  # clear the CUDA cache
+        # torch.cuda.reset_peak_memory_stats()
+
+        # start_mem = torch.cuda.memory_allocated()
+        # print(f"GPU memory allocated at the start of this stage: {start_mem / 1024**3:.2f} GB")
+
+        # self.model.to(self.device)
         for _, t in enumerate(tqdm(timesteps)):
             latent_model_input = [latent.to(self.device)]
             timestep = [t]
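The commented-out block is a stage-level VRAM profiling pattern: reset the peak counter before a stage, run it, then read the high-water mark. A self-contained sketch of the same idea (the helper name is ours, not from the diff):

```python
import torch

def measure_peak(fn, *args, **kwargs):
    """Run fn and report GPU memory at the start and the peak during the call."""
    torch.cuda.reset_peak_memory_stats()
    start = torch.cuda.memory_allocated()
    out = fn(*args, **kwargs)
    peak = torch.cuda.max_memory_allocated()
    print(f"start: {start / 1024**3:.2f} GiB, peak: {peak / 1024**3:.2f} GiB")
    return out
```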
@@ -332,6 +346,9 @@ class WanI2V:
             x0 = [latent.to(self.device)]
             del latent_model_input, timestep

+            # peak_mem_bytes = torch.cuda.max_memory_allocated()
+            # print(f"Peak GPU memory in this stage: {peak_mem_bytes / 1024**3:.2f} GB")
+
         if offload_model:
             self.model.cpu()
             torch.cuda.empty_cache()
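The `offload_model` branch frees the DiT's VRAM before the VAE decode stage. The pattern in isolation (a trivial sketch; the Linear module is a stand-in for the transformer):

```python
import torch
import torch.nn as nn

model = nn.Linear(1024, 1024).cuda()  # stand-in for the DiT
model.cpu()                           # move weights to host memory
torch.cuda.empty_cache()              # release cached blocks back to the driver
```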
@@ -83,6 +83,8 @@ class WanRMSNorm(nn.Module):
         Args:
             x(Tensor): Shape [B, L, C]
         """
+        if self.weight.dtype != torch.bfloat16:
+            self.weight.data = self.weight.data.to(dtype=torch.bfloat16)
         return self._norm(x.float()).type_as(x) * self.weight

     def _norm(self, x):
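This lazy upcast exists because float8 tensors support very few elementwise kernels, so the multiply against `self.weight` would fail if the norm weight stayed in fp8 after the `from_pretrained(..., torch_dtype=torch.float8_e4m3fn)` load. A standalone illustration (assuming PyTorch 2.1+ for the fp8 dtype):

```python
import torch

w = torch.randn(64).to(torch.float8_e4m3fn)  # fp8 storage, as loaded
x = torch.randn(2, 10, 64)

# x * w would raise here: elementwise ops are not implemented for fp8.
w = w.to(torch.bfloat16)  # upcast once, as the patch does
y = x * w                 # supported mixed-precision multiply
```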
@@ -294,6 +296,9 @@ class WanAttentionBlock(nn.Module):
             freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2]
         """
         assert e.dtype == torch.float32
+        if self.modulation.dtype != torch.float16:
+            # convert to fp16 if it is not already
+            self.modulation.data = self.modulation.data.to(dtype=torch.float16)
         with amp.autocast(dtype=torch.float32):
             e = (self.modulation + e).chunk(6, dim=1)
         assert e[0].dtype == torch.float32
@@ -341,6 +346,9 @@ class Head(nn.Module):
             e(Tensor): Shape [B, C]
         """
         assert e.dtype == torch.float32
+        if self.modulation.dtype != torch.float16:
+            # convert to fp16 if it is not already
+            self.modulation.data = self.modulation.data.to(dtype=torch.float16)
         with amp.autocast(dtype=torch.float32):
             e = (self.modulation + e.unsqueeze(1)).chunk(2, dim=1)
             x = (self.head(self.norm(x) * (1 + e[1]) + e[0]))
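Both modulation patches follow the same recipe as the RMSNorm change: a parameter that participates in arithmetic is lazily upcast out of fp8, here to fp16, after which `self.modulation + e` promotes cleanly to float32. A schematic of the attention block's 6-way chunk (shapes and the shift/scale/gate names are assumed from standard DiT-style adaLN conventions, not read from the diff):

```python
import torch

B, C = 2, 1536
modulation = torch.zeros(1, 6, C, dtype=torch.float16)  # fp16 after the patch
e = torch.randn(B, 6, C, dtype=torch.float32)

# fp16 + fp32 promotes to fp32, matching the assert after the chunk.
shift_msa, scale_msa, gate_msa, shift_ffn, scale_ffn, gate_ffn = \
    (modulation + e).chunk(6, dim=1)
```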
@@ -11,6 +11,7 @@ from functools import partial

 import torch
 import torch.cuda.amp as amp
+import torch.nn as nn
 import torch.distributed as dist
 from tqdm import tqdm
@@ -25,6 +26,19 @@ from .utils.fm_solvers import (
 )
 from .utils.fm_solvers_unipc import FlowUniPCMultistepScheduler

+# def convert_linear_conv_to_fp8(module):
+#     for name, child in module.named_children():
+#         # recurse into child modules
+#         convert_linear_conv_to_fp8(child)
+
+#         # check whether this is a Linear or Conv layer
+#         if isinstance(child, (nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d)):
+#             # convert the weights
+#             if hasattr(child, 'weight') and child.weight is not None:
+#                 # keep the Parameter type; only modify the data
+#                 child.weight.data = child.weight.data.to(dtype=torch.float8_e4m3fn)
+#             # optional: also convert the bias (enable if needed)
+

 class WanT2V:
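Uncommented and tidied, the helper sketched above would look roughly like this (a sketch under the same assumptions as the diff, not part of the committed code; any layer cast this way needs an upcast before ops that lack fp8 kernels, which is exactly what the model.py patches do):

```python
import torch
import torch.nn as nn

def convert_linear_conv_to_fp8(module: nn.Module) -> None:
    """Recursively cast Linear/Conv weights to float8_e4m3fn storage."""
    for child in module.children():
        convert_linear_conv_to_fp8(child)  # recurse into submodules first
        if isinstance(child, (nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d)):
            if getattr(child, 'weight', None) is not None:
                # keep the Parameter object; only swap its storage dtype
                child.weight.data = child.weight.data.to(torch.float8_e4m3fn)
```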
@@ -84,7 +98,9 @@ class WanT2V:
             device=self.device)

         logging.info(f"Creating WanModel from {checkpoint_dir}")
-        self.model = WanModel.from_pretrained(checkpoint_dir)
+        self.model = WanModel.from_pretrained(checkpoint_dir, torch_dtype=torch.float8_e4m3fn)
+        # self.model = WanModel.from_pretrained(checkpoint_dir)
+
         self.model.eval().requires_grad_(False)

         if use_usp:
@@ -106,7 +122,9 @@ class WanT2V:
             dist.barrier()
         if dit_fsdp:
             self.model = shard_fn(self.model)
+            # convert_linear_conv_to_fp8(self.model)
         else:
+            # convert_linear_conv_to_fp8(self.model)
             self.model.to(self.device)

         self.sample_neg_prompt = config.sample_neg_prompt
@@ -156,6 +174,7 @@ class WanT2V:
                 - W: Frame width (from size)
         """
         # preprocess
+        # offload_model = False
         F = frame_num
         target_shape = (self.vae.model.z_dim, (F - 1) // self.vae_stride[0] + 1,
                         size[1] // self.vae_stride[1],
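For orientation, the `target_shape` arithmetic maps pixel-space frames onto the VAE's latent grid. Worked with typical values (all assumed for illustration, not read from the diff: `z_dim=16`, `frame_num=81`, `vae_stride=(4, 8, 8)`, `size=(1280, 720)`):

```python
# Worked example of the latent target_shape computation.
z_dim, frame_num = 16, 81
vae_stride = (4, 8, 8)
size = (1280, 720)  # (W, H)

target_shape = (z_dim,
                (frame_num - 1) // vae_stride[0] + 1,  # 21 latent frames
                size[1] // vae_stride[1],              # 90 latent rows
                size[0] // vae_stride[2])              # 160 latent cols
print(target_shape)  # (16, 21, 90, 160)
```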
@@ -229,6 +248,16 @@ class WanT2V:

         arg_c = {'context': context, 'seq_len': seq_len}
         arg_null = {'context': context_null, 'seq_len': seq_len}

+        # import gc
+        # del self.text_encoder
+        # del self.vae
+        # gc.collect()  # trigger garbage collection immediately
+        # torch.cuda.empty_cache()  # clear the CUDA cache
+        # torch.cuda.reset_peak_memory_stats()
+
+        # start_mem = torch.cuda.memory_allocated()
+        # print(f"GPU memory allocated at the start of this stage: {start_mem / 1024**3:.2f} GB")
+
         for _, t in enumerate(tqdm(timesteps)):
             latent_model_input = latents
@@ -252,6 +281,9 @@ class WanT2V:
                     return_dict=False,
                     generator=seed_g)[0]
                 latents = [temp_x0.squeeze(0)]

+                # peak_mem_bytes = torch.cuda.max_memory_allocated()
+                # print(f"Peak GPU memory in this stage: {peak_mem_bytes / 1024**3:.2f} GB")
+
         x0 = latents
         if offload_model: