给wan加入fp8量化

This commit is contained in:
Tianhao 2025-03-16 12:07:18 +08:00
parent b58b7c5737
commit 8f11f5ab6e
3 changed files with 60 additions and 3 deletions

View File

@ -96,7 +96,8 @@ class WanI2V:
tokenizer_path=os.path.join(checkpoint_dir, config.clip_tokenizer))
logging.info(f"Creating WanModel from {checkpoint_dir}")
self.model = WanModel.from_pretrained(checkpoint_dir)
#self.model = WanModel.from_pretrained(checkpoint_dir)
self.model = WanModel.from_pretrained(checkpoint_dir ,torch_dtype=torch.float8_e4m3fn)
self.model.eval().requires_grad_(False)
if t5_fsdp or dit_fsdp or use_usp:
@ -174,6 +175,7 @@ class WanI2V:
- H: Frame height (from max_area)
- W: Frame width (from max_area)
"""
offload_model = False
img = TF.to_tensor(img).sub_(0.5).div_(0.5).to(self.device)
F = frame_num
@ -295,7 +297,19 @@ class WanI2V:
if offload_model:
torch.cuda.empty_cache()
self.model.to(self.device)
# import gc
# del self.text_encoder
# del self.clip
# del self.vae
# gc.collect() # 立即触发垃圾回收
# torch.cuda.empty_cache() # 清空CUDA缓存
# torch.cuda.reset_peak_memory_stats()
# start_mem = torch.cuda.memory_allocated()
# print(f"该阶段开始时显存占用:{start_mem / 1024**3:.2f} GB")
# self.model.to(self.device)
for _, t in enumerate(tqdm(timesteps)):
latent_model_input = [latent.to(self.device)]
timestep = [t]
@ -329,6 +343,9 @@ class WanI2V:
x0 = [latent.to(self.device)]
del latent_model_input, timestep
# peak_mem_bytes = torch.cuda.max_memory_allocated()
# print(f"该阶段最大显存占用:{peak_mem_bytes / 1024**3:.2f} GB")
if offload_model:
self.model.cpu()
torch.cuda.empty_cache()

View File

@ -80,6 +80,8 @@ class WanRMSNorm(nn.Module):
Args:
x(Tensor): Shape [B, L, C]
"""
if self.weight.dtype != torch.bfloat16:
self.weight.data = self.weight.data.to(dtype=torch.bfloat16)
return self._norm(x.float()).type_as(x) * self.weight
def _norm(self, x):
@ -290,6 +292,9 @@ class WanAttentionBlock(nn.Module):
freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2]
"""
assert e.dtype == torch.float32
if self.modulation.dtype != torch.float16:
# 如果不是 fp16则转换为 fp16
self.modulation.data = self.modulation.data.to(dtype=torch.float16)
with amp.autocast(dtype=torch.float32):
e = (self.modulation + e).chunk(6, dim=1)
assert e[0].dtype == torch.float32
@ -337,6 +342,9 @@ class Head(nn.Module):
e(Tensor): Shape [B, C]
"""
assert e.dtype == torch.float32
if self.modulation.dtype != torch.float16:
# 如果不是 fp16则转换为 fp16
self.modulation.data = self.modulation.data.to(dtype=torch.float16)
with amp.autocast(dtype=torch.float32):
e = (self.modulation + e.unsqueeze(1)).chunk(2, dim=1)
x = (self.head(self.norm(x) * (1 + e[1]) + e[0]))

View File

@ -11,6 +11,7 @@ from functools import partial
import torch
import torch.cuda.amp as amp
import torch.nn as nn
import torch.distributed as dist
from tqdm import tqdm
@ -22,6 +23,19 @@ from .utils.fm_solvers import (FlowDPMSolverMultistepScheduler,
get_sampling_sigmas, retrieve_timesteps)
from .utils.fm_solvers_unipc import FlowUniPCMultistepScheduler
# def convert_linear_conv_to_fp8(module):
# for name, child in module.named_children():
# # 递归处理子模块
# convert_linear_conv_to_fp8(child)
# # 判断是否为 Linear 或 Conv 层
# if isinstance(child, (nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d)):
# # 转换权重
# if hasattr(child, 'weight') and child.weight is not None:
# # 保留 Parameter 类型,仅修改数据
# child.weight.data = child.weight.data.to(dtype=torch.float8_e4m3fn)
# # 可选:转换偏置(根据需求开启)
class WanT2V:
@ -81,7 +95,9 @@ class WanT2V:
device=self.device)
logging.info(f"Creating WanModel from {checkpoint_dir}")
self.model = WanModel.from_pretrained(checkpoint_dir)
self.model = WanModel.from_pretrained(checkpoint_dir ,torch_dtype=torch.float8_e4m3fn)
#self.model = WanModel.from_pretrained(checkpoint_dir )
self.model.eval().requires_grad_(False)
if use_usp:
@ -102,7 +118,9 @@ class WanT2V:
dist.barrier()
if dit_fsdp:
self.model = shard_fn(self.model)
# convert_linear_conv_to_fp8(self.model)
else:
# convert_linear_conv_to_fp8(self.model)
self.model.to(self.device)
self.sample_neg_prompt = config.sample_neg_prompt
@ -152,6 +170,7 @@ class WanT2V:
- W: Frame width (from size)
"""
# preprocess
#offload_model = False
F = frame_num
target_shape = (self.vae.model.z_dim, (F - 1) // self.vae_stride[0] + 1,
size[1] // self.vae_stride[1],
@ -225,6 +244,16 @@ class WanT2V:
arg_c = {'context': context, 'seq_len': seq_len}
arg_null = {'context': context_null, 'seq_len': seq_len}
# import gc
# del self.text_encoder
# del self.vae
# gc.collect() # 立即触发垃圾回收
# torch.cuda.empty_cache() # 清空CUDA缓存
# torch.cuda.reset_peak_memory_stats()
# start_mem = torch.cuda.memory_allocated()
#print(f"该阶段开始时显存占用:{start_mem / 1024**3:.2f} GB")
for _, t in enumerate(tqdm(timesteps)):
latent_model_input = latents
@ -248,6 +277,9 @@ class WanT2V:
return_dict=False,
generator=seed_g)[0]
latents = [temp_x0.squeeze(0)]
# peak_mem_bytes = torch.cuda.max_memory_allocated()
# print(f"该阶段最大显存占用:{peak_mem_bytes / 1024**3:.2f} GB")
x0 = latents
if offload_model: