Add fp8 quantization to wan

This commit is contained in:
Tianhao 2025-03-16 12:07:18 +08:00
parent b58b7c5737
commit 8f11f5ab6e
3 changed files with 60 additions and 3 deletions

View File: wan/image2video.py

@@ -96,7 +96,8 @@ class WanI2V:
             tokenizer_path=os.path.join(checkpoint_dir, config.clip_tokenizer))
         logging.info(f"Creating WanModel from {checkpoint_dir}")
-        self.model = WanModel.from_pretrained(checkpoint_dir)
+        #self.model = WanModel.from_pretrained(checkpoint_dir)
+        self.model = WanModel.from_pretrained(checkpoint_dir, torch_dtype=torch.float8_e4m3fn)
         self.model.eval().requires_grad_(False)
         if t5_fsdp or dit_fsdp or use_usp:
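Note: passing torch_dtype=torch.float8_e4m3fn here relies on the diffusers-style ModelMixin.from_pretrained that WanModel appears to inherit; under that assumption the argument amounts to a module-wide cast after loading. A minimal, self-contained sketch of that cast (toy module, purely illustrative):

    import torch
    import torch.nn as nn

    # nn.Module.to accepts float8_e4m3fn because it is a floating-point dtype;
    # every floating-point parameter is re-stored in fp8, halving weight memory vs bf16.
    m = nn.Linear(8, 8)
    m = m.to(torch.float8_e4m3fn)
    print(m.weight.dtype)  # torch.float8_e4m3fn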
@@ -174,6 +175,7 @@ class WanI2V:
                 - H: Frame height (from max_area)
                 - W: Frame width (from max_area)
         """
+        offload_model = False
         img = TF.to_tensor(img).sub_(0.5).div_(0.5).to(self.device)
         F = frame_num
@@ -295,7 +297,19 @@ class WanI2V:
             if offload_model:
                 torch.cuda.empty_cache()
+            self.model.to(self.device)
+            # import gc
+            # del self.text_encoder
+            # del self.clip
+            # del self.vae
+            # gc.collect()  # trigger garbage collection immediately
+            # torch.cuda.empty_cache()  # clear the CUDA cache
+            # torch.cuda.reset_peak_memory_stats()
+            # start_mem = torch.cuda.memory_allocated()
+            # print(f"GPU memory in use at the start of this stage: {start_mem / 1024**3:.2f} GB")
+            # self.model.to(self.device)
             for _, t in enumerate(tqdm(timesteps)):
                 latent_model_input = [latent.to(self.device)]
                 timestep = [t]
@@ -329,6 +343,9 @@ class WanI2V:
                 x0 = [latent.to(self.device)]
                 del latent_model_input, timestep
+            # peak_mem_bytes = torch.cuda.max_memory_allocated()
+            # print(f"Peak GPU memory in this stage: {peak_mem_bytes / 1024**3:.2f} GB")
             if offload_model:
                 self.model.cpu()
                 torch.cuda.empty_cache()
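The commented-out probes above follow the standard torch.cuda accounting pattern: reset the peak counter, read the allocation before the stage, then read the high-water mark after it. A minimal standalone sketch (function name is illustrative):

    import torch

    def run_with_memory_report(stage_fn):
        # reset the peak counter so max_memory_allocated reflects only this stage
        torch.cuda.reset_peak_memory_stats()
        start_mem = torch.cuda.memory_allocated()
        stage_fn()
        peak_mem = torch.cuda.max_memory_allocated()
        print(f"start: {start_mem / 1024**3:.2f} GB, peak: {peak_mem / 1024**3:.2f} GB")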

View File: wan/modules/model.py

@@ -80,6 +80,8 @@ class WanRMSNorm(nn.Module):
         Args:
             x(Tensor): Shape [B, L, C]
         """
+        if self.weight.dtype != torch.bfloat16:
+            self.weight.data = self.weight.data.to(dtype=torch.bfloat16)
         return self._norm(x.float()).type_as(x) * self.weight

     def _norm(self, x):
@@ -290,6 +292,9 @@ class WanAttentionBlock(nn.Module):
             freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2]
         """
         assert e.dtype == torch.float32
+        if self.modulation.dtype != torch.float16:
+            # if it is not fp16, convert it to fp16
+            self.modulation.data = self.modulation.data.to(dtype=torch.float16)
         with amp.autocast(dtype=torch.float32):
             e = (self.modulation + e).chunk(6, dim=1)
         assert e[0].dtype == torch.float32
@@ -337,6 +342,9 @@ class Head(nn.Module):
             e(Tensor): Shape [B, C]
         """
         assert e.dtype == torch.float32
+        if self.modulation.dtype != torch.float16:
+            # if it is not fp16, convert it to fp16
+            self.modulation.data = self.modulation.data.to(dtype=torch.float16)
         with amp.autocast(dtype=torch.float32):
             e = (self.modulation + e.unsqueeze(1)).chunk(2, dim=1)
             x = (self.head(self.norm(x) * (1 + e[1]) + e[0]))
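These in-place .data casts exist because PyTorch ships almost no arithmetic kernels for float8_e4m3fn: fp8 is viable as a storage dtype, but a tensor must be upcast to bf16/fp16/fp32 before it participates in math (note the commit picks bfloat16 for the RMSNorm weight and float16 for the modulation tables). A small illustration of the failure mode and the one-time fix:

    import torch

    w = torch.randn(16).to(torch.float8_e4m3fn)      # fp8 storage, as loaded from the checkpoint
    x = torch.randn(16, dtype=torch.bfloat16)

    # y = x * w                   # raises: float8 has no multiply kernel / no type promotion
    y = x * w.to(torch.bfloat16)  # upcast first, which is what the casts above do once per parameter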

View File: wan/text2video.py

@@ -11,6 +11,7 @@ from functools import partial
 import torch
 import torch.cuda.amp as amp
+import torch.nn as nn
 import torch.distributed as dist
 from tqdm import tqdm
@@ -22,6 +23,19 @@ from .utils.fm_solvers import (FlowDPMSolverMultistepScheduler,
                                get_sampling_sigmas, retrieve_timesteps)
 from .utils.fm_solvers_unipc import FlowUniPCMultistepScheduler
+# def convert_linear_conv_to_fp8(module):
+#     for name, child in module.named_children():
+#         # recurse into child modules
+#         convert_linear_conv_to_fp8(child)
+#         # check whether this is a Linear or Conv layer
+#         if isinstance(child, (nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d)):
+#             # convert the weight
+#             if hasattr(child, 'weight') and child.weight is not None:
+#                 # keep the Parameter type, only modify its data
+#                 child.weight.data = child.weight.data.to(dtype=torch.float8_e4m3fn)
+#             # optional: convert the bias too (enable if needed)

 class WanT2V:
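For reference, a runnable version of the helper sketched above. It converts storage only; it assumes (as model.py now does for norm weights and modulation) that anything kept in fp8 is upcast to a compute dtype before use:

    import torch
    import torch.nn as nn

    def convert_linear_conv_to_fp8(module: nn.Module) -> None:
        for child in module.children():
            convert_linear_conv_to_fp8(child)  # recurse into submodules first
            if isinstance(child, (nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d)):
                if child.weight is not None:
                    # keep the Parameter wrapper; swap only the underlying storage dtype
                    child.weight.data = child.weight.data.to(torch.float8_e4m3fn)
                # bias conversion is optional and left out here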
@@ -81,7 +95,9 @@ class WanT2V:
             device=self.device)
         logging.info(f"Creating WanModel from {checkpoint_dir}")
-        self.model = WanModel.from_pretrained(checkpoint_dir)
+        self.model = WanModel.from_pretrained(checkpoint_dir, torch_dtype=torch.float8_e4m3fn)
+        #self.model = WanModel.from_pretrained(checkpoint_dir)
         self.model.eval().requires_grad_(False)
         if use_usp:
@@ -102,7 +118,9 @@ class WanT2V:
             dist.barrier()
         if dit_fsdp:
             self.model = shard_fn(self.model)
+            # convert_linear_conv_to_fp8(self.model)
         else:
+            # convert_linear_conv_to_fp8(self.model)
             self.model.to(self.device)
         self.sample_neg_prompt = config.sample_neg_prompt
@@ -152,6 +170,7 @@ class WanT2V:
                 - W: Frame width (from size)
         """
         # preprocess
+        #offload_model = False
         F = frame_num
         target_shape = (self.vae.model.z_dim, (F - 1) // self.vae_stride[0] + 1,
                         size[1] // self.vae_stride[1],
@@ -226,6 +245,16 @@ class WanT2V:
             arg_c = {'context': context, 'seq_len': seq_len}
             arg_null = {'context': context_null, 'seq_len': seq_len}
+            # import gc
+            # del self.text_encoder
+            # del self.vae
+            # gc.collect()  # trigger garbage collection immediately
+            # torch.cuda.empty_cache()  # clear the CUDA cache
+            # torch.cuda.reset_peak_memory_stats()
+            # start_mem = torch.cuda.memory_allocated()
+            # print(f"GPU memory in use at the start of this stage: {start_mem / 1024**3:.2f} GB")
             for _, t in enumerate(tqdm(timesteps)):
                 latent_model_input = latents
                 timestep = [t]
@@ -249,6 +278,9 @@ class WanT2V:
                     generator=seed_g)[0]
                 latents = [temp_x0.squeeze(0)]
+            # peak_mem_bytes = torch.cuda.max_memory_allocated()
+            # print(f"Peak GPU memory in this stage: {peak_mem_bytes / 1024**3:.2f} GB")
             x0 = latents
             if offload_model:
                 self.model.cpu()