Compare commits

...

3 Commits

Author          SHA1        Message                           Date
tiantiankaixin  3eeae18e84  Merge 8f11f5ab6e into 827906c30f  2025-06-05 20:31:08 +08:00
Shiwei Zhang    827906c30f  Update README.md                  2025-06-05 10:02:00 +08:00
Tianhao         8f11f5ab6e  Add fp8 quantization to wan       2025-03-16 12:11:33 +08:00
4 changed files with 61 additions and 3 deletions

View File

@@ -36,6 +36,7 @@ In this repository, we present **Wan2.1**, a comprehensive and open suite of vid
## Community Works
If your work has improved **Wan2.1** and you would like more people to see it, please inform us.
- [ATI](https://github.com/bytedance/ATI), built on **Wan2.1-I2V-14B**, is a trajectory-based motion-control framework that unifies object, local, and camera movements in video generation. Refer to [their website](https://anytraj.github.io/) for more examples.
- [Phantom](https://github.com/Phantom-video/Phantom) has developed a unified video generation framework for single and multi-subject references based on both **Wan2.1-T2V-1.3B** and **Wan2.1-T2V-14B**. Please refer to [their examples](https://github.com/Phantom-video/Phantom).
- [UniAnimate-DiT](https://github.com/ali-vilab/UniAnimate-DiT), based on **Wan2.1-14B-I2V**, has trained a Human image animation model and has open-sourced the inference and training code. Feel free to enjoy it!
- [CFG-Zero](https://github.com/WeichenFan/CFG-Zero-star) enhances **Wan2.1** (covering both T2V and I2V models) from the perspective of CFG.

View File

@@ -99,7 +99,8 @@ class WanI2V:
tokenizer_path=os.path.join(checkpoint_dir, config.clip_tokenizer))
logging.info(f"Creating WanModel from {checkpoint_dir}")
#self.model = WanModel.from_pretrained(checkpoint_dir)
self.model = WanModel.from_pretrained(checkpoint_dir, torch_dtype=torch.float8_e4m3fn)
self.model.eval().requires_grad_(False)
if t5_fsdp or dit_fsdp or use_usp:
@@ -178,6 +179,7 @@ class WanI2V:
- H: Frame height (from max_area)
- W: Frame width (from max_area)
"""
offload_model = False
img = TF.to_tensor(img).sub_(0.5).div_(0.5).to(self.device)
F = frame_num
@@ -298,7 +300,19 @@
if offload_model:
torch.cuda.empty_cache()
self.model.to(self.device)
# import gc
# del self.text_encoder
# del self.clip
# del self.vae
# gc.collect()  # trigger garbage collection immediately
# torch.cuda.empty_cache()  # clear the CUDA cache
# torch.cuda.reset_peak_memory_stats()
# start_mem = torch.cuda.memory_allocated()
# print(f"GPU memory allocated at the start of this stage: {start_mem / 1024**3:.2f} GB")
# self.model.to(self.device)
for _, t in enumerate(tqdm(timesteps)):
latent_model_input = [latent.to(self.device)]
timestep = [t]
@@ -332,6 +346,9 @@
x0 = [latent.to(self.device)]
del latent_model_input, timestep
# peak_mem_bytes = torch.cuda.max_memory_allocated()
# print(f"Peak GPU memory allocated in this stage: {peak_mem_bytes / 1024**3:.2f} GB")
if offload_model:
self.model.cpu()
torch.cuda.empty_cache()
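
The hunk above loads the DiT weights with `torch_dtype=torch.float8_e4m3fn`. For background, below is a minimal sketch of the usual fp8 weight-only pattern this relies on: store parameters in `float8_e4m3fn` (one byte each) and upcast to the activation dtype just before the matmul. The `Fp8Linear` wrapper is an illustrative assumption, not code from this PR, and it needs a PyTorch build that exposes `torch.float8_e4m3fn` (2.1 or newer).

```python
# Illustrative sketch only (assumption, not part of this patch): fp8
# weight-only storage for a Linear layer. Weights live in float8_e4m3fn
# and are upcast to the activation dtype just before the matmul.
import torch
import torch.nn as nn
import torch.nn.functional as F

class Fp8Linear(nn.Module):  # hypothetical wrapper, not in the Wan2.1 code base
    def __init__(self, linear: nn.Linear):
        super().__init__()
        # store the weight in fp8 (e4m3); keep the bias in bf16
        self.weight = nn.Parameter(
            linear.weight.data.to(torch.float8_e4m3fn), requires_grad=False)
        self.bias = None
        if linear.bias is not None:
            self.bias = nn.Parameter(
                linear.bias.data.to(torch.bfloat16), requires_grad=False)

    def forward(self, x):
        # upcast per call; trades a little compute for roughly half the weight memory
        w = self.weight.to(x.dtype)
        b = self.bias.to(x.dtype) if self.bias is not None else None
        return F.linear(x, w, b)
```

Usage would look like `layer = Fp8Linear(nn.Linear(4096, 4096))`, after which `layer(x)` behaves like the original layer while the weight sits at roughly half the memory of bf16.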

View File

@@ -83,6 +83,8 @@ class WanRMSNorm(nn.Module):
Args:
x(Tensor): Shape [B, L, C]
"""
if self.weight.dtype != torch.bfloat16:
self.weight.data = self.weight.data.to(dtype=torch.bfloat16)
return self._norm(x.float()).type_as(x) * self.weight
def _norm(self, x):
@@ -294,6 +296,9 @@ class WanAttentionBlock(nn.Module):
freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2]
"""
assert e.dtype == torch.float32
if self.modulation.dtype != torch.float16:
# if it is not fp16, convert it to fp16
self.modulation.data = self.modulation.data.to(dtype=torch.float16)
with amp.autocast(dtype=torch.float32):
e = (self.modulation + e).chunk(6, dim=1)
assert e[0].dtype == torch.float32
@@ -341,6 +346,9 @@ class Head(nn.Module):
e(Tensor): Shape [B, C]
"""
assert e.dtype == torch.float32
if self.modulation.dtype != torch.float16:
# if it is not fp16, convert it to fp16
self.modulation.data = self.modulation.data.to(dtype=torch.float16)
with amp.autocast(dtype=torch.float32):
e = (self.modulation + e.unsqueeze(1)).chunk(2, dim=1)
x = (self.head(self.norm(x) * (1 + e[1]) + e[0]))
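
The three guards added to this file share one pattern: once the checkpoint is loaded in fp8, every parameter, including RMSNorm weights and the modulation tables, carries `float8_e4m3fn`, so these small parameters are upcast in place the first time they are used. A self-contained sketch of that pattern follows; the class name is hypothetical and, unlike the patch (which targets bfloat16/float16 explicitly), it simply matches the activation dtype.

```python
# Illustrative sketch (assumption, not code from this patch): lazily upcast
# an fp8-loaded parameter to a usable compute dtype on first use, mirroring
# the dtype guards added to WanRMSNorm and the modulation tables.
import torch
import torch.nn as nn

class LazyUpcastRMSNorm(nn.Module):  # hypothetical name, for illustration
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim), requires_grad=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # If the parameter is still in fp8 from loading, rewrite its data in
        # place once; subsequent calls see the compute dtype and skip this.
        if self.weight.dtype != x.dtype:
            self.weight.data = self.weight.data.to(x.dtype)
        norm = x.float() * torch.rsqrt(
            x.float().pow(2).mean(dim=-1, keepdim=True) + self.eps)
        return norm.type_as(x) * self.weight
```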

View File

@@ -11,6 +11,7 @@ from functools import partial
import torch
import torch.cuda.amp as amp
import torch.nn as nn
import torch.distributed as dist
from tqdm import tqdm
@@ -25,6 +26,19 @@ from .utils.fm_solvers import (
)
from .utils.fm_solvers_unipc import FlowUniPCMultistepScheduler
# def convert_linear_conv_to_fp8(module):
# for name, child in module.named_children():
# # recursively process child modules
# convert_linear_conv_to_fp8(child)
# # check whether the child is a Linear or Conv layer
# if isinstance(child, (nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d)):
# # convert the weights
# if hasattr(child, 'weight') and child.weight is not None:
# # keep the Parameter type, only modify its data
# child.weight.data = child.weight.data.to(dtype=torch.float8_e4m3fn)
# # optional: also convert the bias (enable if needed)
class WanT2V:
@@ -84,7 +98,9 @@ class WanT2V:
device=self.device)
logging.info(f"Creating WanModel from {checkpoint_dir}")
self.model = WanModel.from_pretrained(checkpoint_dir, torch_dtype=torch.float8_e4m3fn)
#self.model = WanModel.from_pretrained(checkpoint_dir)
self.model.eval().requires_grad_(False)
if use_usp:
@@ -106,7 +122,9 @@ class WanT2V:
dist.barrier()
if dit_fsdp:
self.model = shard_fn(self.model)
# convert_linear_conv_to_fp8(self.model)
else:
# convert_linear_conv_to_fp8(self.model)
self.model.to(self.device)
self.sample_neg_prompt = config.sample_neg_prompt
@@ -156,6 +174,7 @@ class WanT2V:
- W: Frame width (from size)
"""
# preprocess
#offload_model = False
F = frame_num
target_shape = (self.vae.model.z_dim, (F - 1) // self.vae_stride[0] + 1,
size[1] // self.vae_stride[1],
@@ -230,6 +249,16 @@
arg_c = {'context': context, 'seq_len': seq_len}
arg_null = {'context': context_null, 'seq_len': seq_len}
# import gc
# del self.text_encoder
# del self.vae
# gc.collect()  # trigger garbage collection immediately
# torch.cuda.empty_cache()  # clear the CUDA cache
# torch.cuda.reset_peak_memory_stats()
# start_mem = torch.cuda.memory_allocated()
#print(f"GPU memory allocated at the start of this stage: {start_mem / 1024**3:.2f} GB")
for _, t in enumerate(tqdm(timesteps)):
latent_model_input = latents
timestep = [t]
@@ -253,6 +282,9 @@
generator=seed_g)[0]
latents = [temp_x0.squeeze(0)]
# peak_mem_bytes = torch.cuda.max_memory_allocated()
# print(f"Peak GPU memory allocated in this stage: {peak_mem_bytes / 1024**3:.2f} GB")
x0 = latents
if offload_model:
self.model.cpu()
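
For reference, here is a runnable rendering of the `convert_linear_conv_to_fp8` helper that this patch keeps commented out in the text-to-video pipeline. It recursively walks the module tree and stores the weights of Linear/Conv layers in `float8_e4m3fn` while keeping them wrapped as `nn.Parameter`; it stays disabled in the diff, and the sketch assumes a PyTorch build that provides `torch.float8_e4m3fn`.

```python
# Runnable sketch of the helper left commented out in this patch.
import torch
import torch.nn as nn

def convert_linear_conv_to_fp8(module: nn.Module) -> None:
    for _, child in module.named_children():
        # recursively process child modules
        convert_linear_conv_to_fp8(child)
        # check whether the child is a Linear or Conv layer
        if isinstance(child, (nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d)):
            # convert the weights, keeping the Parameter wrapper and only
            # replacing its data
            if hasattr(child, 'weight') and child.weight is not None:
                child.weight.data = child.weight.data.to(dtype=torch.float8_e4m3fn)
            # optional: also convert the bias (left disabled, as in the patch)
```

Note that standard `F.linear`/conv kernels do not run directly on fp8 weights, so a conversion like this normally pairs with an upcast at call time (as in the earlier `Fp8Linear` sketch) or a dedicated fp8 matmul path.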