mirror of
https://github.com/Wan-Video/Wan2.1.git
synced 2025-06-07 07:44:53 +00:00
给wan加入fp8量化
This commit is contained in:
parent
b58b7c5737
commit
8f11f5ab6e
@ -96,7 +96,8 @@ class WanI2V:
|
|||||||
tokenizer_path=os.path.join(checkpoint_dir, config.clip_tokenizer))
|
tokenizer_path=os.path.join(checkpoint_dir, config.clip_tokenizer))
|
||||||
|
|
||||||
logging.info(f"Creating WanModel from {checkpoint_dir}")
|
logging.info(f"Creating WanModel from {checkpoint_dir}")
|
||||||
self.model = WanModel.from_pretrained(checkpoint_dir)
|
#self.model = WanModel.from_pretrained(checkpoint_dir)
|
||||||
|
self.model = WanModel.from_pretrained(checkpoint_dir ,torch_dtype=torch.float8_e4m3fn)
|
||||||
self.model.eval().requires_grad_(False)
|
self.model.eval().requires_grad_(False)
|
||||||
|
|
||||||
if t5_fsdp or dit_fsdp or use_usp:
|
if t5_fsdp or dit_fsdp or use_usp:
|
||||||
@ -174,6 +175,7 @@ class WanI2V:
|
|||||||
- H: Frame height (from max_area)
|
- H: Frame height (from max_area)
|
||||||
- W: Frame width (from max_area)
|
- W: Frame width (from max_area)
|
||||||
"""
|
"""
|
||||||
|
offload_model = False
|
||||||
img = TF.to_tensor(img).sub_(0.5).div_(0.5).to(self.device)
|
img = TF.to_tensor(img).sub_(0.5).div_(0.5).to(self.device)
|
||||||
|
|
||||||
F = frame_num
|
F = frame_num
|
||||||
@ -295,7 +297,19 @@ class WanI2V:
|
|||||||
if offload_model:
|
if offload_model:
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
self.model.to(self.device)
|
|
||||||
|
# import gc
|
||||||
|
# del self.text_encoder
|
||||||
|
# del self.clip
|
||||||
|
# del self.vae
|
||||||
|
# gc.collect() # 立即触发垃圾回收
|
||||||
|
# torch.cuda.empty_cache() # 清空CUDA缓存
|
||||||
|
# torch.cuda.reset_peak_memory_stats()
|
||||||
|
|
||||||
|
# start_mem = torch.cuda.memory_allocated()
|
||||||
|
# print(f"该阶段开始时显存占用:{start_mem / 1024**3:.2f} GB")
|
||||||
|
|
||||||
|
# self.model.to(self.device)
|
||||||
for _, t in enumerate(tqdm(timesteps)):
|
for _, t in enumerate(tqdm(timesteps)):
|
||||||
latent_model_input = [latent.to(self.device)]
|
latent_model_input = [latent.to(self.device)]
|
||||||
timestep = [t]
|
timestep = [t]
|
||||||
@ -329,6 +343,9 @@ class WanI2V:
|
|||||||
x0 = [latent.to(self.device)]
|
x0 = [latent.to(self.device)]
|
||||||
del latent_model_input, timestep
|
del latent_model_input, timestep
|
||||||
|
|
||||||
|
# peak_mem_bytes = torch.cuda.max_memory_allocated()
|
||||||
|
# print(f"该阶段最大显存占用:{peak_mem_bytes / 1024**3:.2f} GB")
|
||||||
|
|
||||||
if offload_model:
|
if offload_model:
|
||||||
self.model.cpu()
|
self.model.cpu()
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
|
@ -80,6 +80,8 @@ class WanRMSNorm(nn.Module):
|
|||||||
Args:
|
Args:
|
||||||
x(Tensor): Shape [B, L, C]
|
x(Tensor): Shape [B, L, C]
|
||||||
"""
|
"""
|
||||||
|
if self.weight.dtype != torch.bfloat16:
|
||||||
|
self.weight.data = self.weight.data.to(dtype=torch.bfloat16)
|
||||||
return self._norm(x.float()).type_as(x) * self.weight
|
return self._norm(x.float()).type_as(x) * self.weight
|
||||||
|
|
||||||
def _norm(self, x):
|
def _norm(self, x):
|
||||||
@ -290,6 +292,9 @@ class WanAttentionBlock(nn.Module):
|
|||||||
freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2]
|
freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2]
|
||||||
"""
|
"""
|
||||||
assert e.dtype == torch.float32
|
assert e.dtype == torch.float32
|
||||||
|
if self.modulation.dtype != torch.float16:
|
||||||
|
# 如果不是 fp16,则转换为 fp16
|
||||||
|
self.modulation.data = self.modulation.data.to(dtype=torch.float16)
|
||||||
with amp.autocast(dtype=torch.float32):
|
with amp.autocast(dtype=torch.float32):
|
||||||
e = (self.modulation + e).chunk(6, dim=1)
|
e = (self.modulation + e).chunk(6, dim=1)
|
||||||
assert e[0].dtype == torch.float32
|
assert e[0].dtype == torch.float32
|
||||||
@ -337,6 +342,9 @@ class Head(nn.Module):
|
|||||||
e(Tensor): Shape [B, C]
|
e(Tensor): Shape [B, C]
|
||||||
"""
|
"""
|
||||||
assert e.dtype == torch.float32
|
assert e.dtype == torch.float32
|
||||||
|
if self.modulation.dtype != torch.float16:
|
||||||
|
# 如果不是 fp16,则转换为 fp16
|
||||||
|
self.modulation.data = self.modulation.data.to(dtype=torch.float16)
|
||||||
with amp.autocast(dtype=torch.float32):
|
with amp.autocast(dtype=torch.float32):
|
||||||
e = (self.modulation + e.unsqueeze(1)).chunk(2, dim=1)
|
e = (self.modulation + e.unsqueeze(1)).chunk(2, dim=1)
|
||||||
x = (self.head(self.norm(x) * (1 + e[1]) + e[0]))
|
x = (self.head(self.norm(x) * (1 + e[1]) + e[0]))
|
||||||
|
@ -11,6 +11,7 @@ from functools import partial
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.cuda.amp as amp
|
import torch.cuda.amp as amp
|
||||||
|
import torch.nn as nn
|
||||||
import torch.distributed as dist
|
import torch.distributed as dist
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
@ -22,6 +23,19 @@ from .utils.fm_solvers import (FlowDPMSolverMultistepScheduler,
|
|||||||
get_sampling_sigmas, retrieve_timesteps)
|
get_sampling_sigmas, retrieve_timesteps)
|
||||||
from .utils.fm_solvers_unipc import FlowUniPCMultistepScheduler
|
from .utils.fm_solvers_unipc import FlowUniPCMultistepScheduler
|
||||||
|
|
||||||
|
# def convert_linear_conv_to_fp8(module):
|
||||||
|
# for name, child in module.named_children():
|
||||||
|
# # 递归处理子模块
|
||||||
|
# convert_linear_conv_to_fp8(child)
|
||||||
|
|
||||||
|
# # 判断是否为 Linear 或 Conv 层
|
||||||
|
# if isinstance(child, (nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d)):
|
||||||
|
# # 转换权重
|
||||||
|
# if hasattr(child, 'weight') and child.weight is not None:
|
||||||
|
# # 保留 Parameter 类型,仅修改数据
|
||||||
|
# child.weight.data = child.weight.data.to(dtype=torch.float8_e4m3fn)
|
||||||
|
# # 可选:转换偏置(根据需求开启)
|
||||||
|
|
||||||
|
|
||||||
class WanT2V:
|
class WanT2V:
|
||||||
|
|
||||||
@ -81,7 +95,9 @@ class WanT2V:
|
|||||||
device=self.device)
|
device=self.device)
|
||||||
|
|
||||||
logging.info(f"Creating WanModel from {checkpoint_dir}")
|
logging.info(f"Creating WanModel from {checkpoint_dir}")
|
||||||
self.model = WanModel.from_pretrained(checkpoint_dir)
|
self.model = WanModel.from_pretrained(checkpoint_dir ,torch_dtype=torch.float8_e4m3fn)
|
||||||
|
#self.model = WanModel.from_pretrained(checkpoint_dir )
|
||||||
|
|
||||||
self.model.eval().requires_grad_(False)
|
self.model.eval().requires_grad_(False)
|
||||||
|
|
||||||
if use_usp:
|
if use_usp:
|
||||||
@ -102,7 +118,9 @@ class WanT2V:
|
|||||||
dist.barrier()
|
dist.barrier()
|
||||||
if dit_fsdp:
|
if dit_fsdp:
|
||||||
self.model = shard_fn(self.model)
|
self.model = shard_fn(self.model)
|
||||||
|
# convert_linear_conv_to_fp8(self.model)
|
||||||
else:
|
else:
|
||||||
|
# convert_linear_conv_to_fp8(self.model)
|
||||||
self.model.to(self.device)
|
self.model.to(self.device)
|
||||||
|
|
||||||
self.sample_neg_prompt = config.sample_neg_prompt
|
self.sample_neg_prompt = config.sample_neg_prompt
|
||||||
@ -152,6 +170,7 @@ class WanT2V:
|
|||||||
- W: Frame width (from size)
|
- W: Frame width (from size)
|
||||||
"""
|
"""
|
||||||
# preprocess
|
# preprocess
|
||||||
|
#offload_model = False
|
||||||
F = frame_num
|
F = frame_num
|
||||||
target_shape = (self.vae.model.z_dim, (F - 1) // self.vae_stride[0] + 1,
|
target_shape = (self.vae.model.z_dim, (F - 1) // self.vae_stride[0] + 1,
|
||||||
size[1] // self.vae_stride[1],
|
size[1] // self.vae_stride[1],
|
||||||
@ -226,6 +245,16 @@ class WanT2V:
|
|||||||
arg_c = {'context': context, 'seq_len': seq_len}
|
arg_c = {'context': context, 'seq_len': seq_len}
|
||||||
arg_null = {'context': context_null, 'seq_len': seq_len}
|
arg_null = {'context': context_null, 'seq_len': seq_len}
|
||||||
|
|
||||||
|
# import gc
|
||||||
|
# del self.text_encoder
|
||||||
|
# del self.vae
|
||||||
|
# gc.collect() # 立即触发垃圾回收
|
||||||
|
# torch.cuda.empty_cache() # 清空CUDA缓存
|
||||||
|
# torch.cuda.reset_peak_memory_stats()
|
||||||
|
|
||||||
|
# start_mem = torch.cuda.memory_allocated()
|
||||||
|
#print(f"该阶段开始时显存占用:{start_mem / 1024**3:.2f} GB")
|
||||||
|
|
||||||
for _, t in enumerate(tqdm(timesteps)):
|
for _, t in enumerate(tqdm(timesteps)):
|
||||||
latent_model_input = latents
|
latent_model_input = latents
|
||||||
timestep = [t]
|
timestep = [t]
|
||||||
@ -249,6 +278,9 @@ class WanT2V:
|
|||||||
generator=seed_g)[0]
|
generator=seed_g)[0]
|
||||||
latents = [temp_x0.squeeze(0)]
|
latents = [temp_x0.squeeze(0)]
|
||||||
|
|
||||||
|
# peak_mem_bytes = torch.cuda.max_memory_allocated()
|
||||||
|
# print(f"该阶段最大显存占用:{peak_mem_bytes / 1024**3:.2f} GB")
|
||||||
|
|
||||||
x0 = latents
|
x0 = latents
|
||||||
if offload_model:
|
if offload_model:
|
||||||
self.model.cpu()
|
self.model.cpu()
|
||||||
|
Loading…
Reference in New Issue
Block a user