mirror of
https://github.com/Wan-Video/Wan2.1.git
synced 2025-06-05 14:54:54 +00:00
Merge 8f11f5ab6e
into ec902046f6
This commit is contained in:
commit
1bd6028362
@ -99,7 +99,8 @@ class WanI2V:
|
||||
tokenizer_path=os.path.join(checkpoint_dir, config.clip_tokenizer))
|
||||
|
||||
logging.info(f"Creating WanModel from {checkpoint_dir}")
|
||||
self.model = WanModel.from_pretrained(checkpoint_dir)
|
||||
#self.model = WanModel.from_pretrained(checkpoint_dir)
|
||||
self.model = WanModel.from_pretrained(checkpoint_dir ,torch_dtype=torch.float8_e4m3fn)
|
||||
self.model.eval().requires_grad_(False)
|
||||
|
||||
if t5_fsdp or dit_fsdp or use_usp:
|
||||
@ -178,6 +179,7 @@ class WanI2V:
|
||||
- H: Frame height (from max_area)
|
||||
- W: Frame width from max_area)
|
||||
"""
|
||||
offload_model = False
|
||||
img = TF.to_tensor(img).sub_(0.5).div_(0.5).to(self.device)
|
||||
|
||||
F = frame_num
|
||||
@ -298,7 +300,19 @@ class WanI2V:
|
||||
if offload_model:
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
self.model.to(self.device)
|
||||
|
||||
# import gc
|
||||
# del self.text_encoder
|
||||
# del self.clip
|
||||
# del self.vae
|
||||
# gc.collect() # 立即触发垃圾回收
|
||||
# torch.cuda.empty_cache() # 清空CUDA缓存
|
||||
# torch.cuda.reset_peak_memory_stats()
|
||||
|
||||
# start_mem = torch.cuda.memory_allocated()
|
||||
# print(f"该阶段开始时显存占用:{start_mem / 1024**3:.2f} GB")
|
||||
|
||||
# self.model.to(self.device)
|
||||
for _, t in enumerate(tqdm(timesteps)):
|
||||
latent_model_input = [latent.to(self.device)]
|
||||
timestep = [t]
|
||||
@ -332,6 +346,9 @@ class WanI2V:
|
||||
x0 = [latent.to(self.device)]
|
||||
del latent_model_input, timestep
|
||||
|
||||
# peak_mem_bytes = torch.cuda.max_memory_allocated()
|
||||
# print(f"该阶段最大显存占用:{peak_mem_bytes / 1024**3:.2f} GB")
|
||||
|
||||
if offload_model:
|
||||
self.model.cpu()
|
||||
torch.cuda.empty_cache()
|
||||
|
@ -83,6 +83,8 @@ class WanRMSNorm(nn.Module):
|
||||
Args:
|
||||
x(Tensor): Shape [B, L, C]
|
||||
"""
|
||||
if self.weight.dtype != torch.bfloat16:
|
||||
self.weight.data = self.weight.data.to(dtype=torch.bfloat16)
|
||||
return self._norm(x.float()).type_as(x) * self.weight
|
||||
|
||||
def _norm(self, x):
|
||||
@ -294,6 +296,9 @@ class WanAttentionBlock(nn.Module):
|
||||
freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2]
|
||||
"""
|
||||
assert e.dtype == torch.float32
|
||||
if self.modulation.dtype != torch.float16:
|
||||
# 如果不是 fp16,则转换为 fp16
|
||||
self.modulation.data = self.modulation.data.to(dtype=torch.float16)
|
||||
with amp.autocast(dtype=torch.float32):
|
||||
e = (self.modulation + e).chunk(6, dim=1)
|
||||
assert e[0].dtype == torch.float32
|
||||
@ -341,6 +346,9 @@ class Head(nn.Module):
|
||||
e(Tensor): Shape [B, C]
|
||||
"""
|
||||
assert e.dtype == torch.float32
|
||||
if self.modulation.dtype != torch.float16:
|
||||
# 如果不是 fp16,则转换为 fp16
|
||||
self.modulation.data = self.modulation.data.to(dtype=torch.float16)
|
||||
with amp.autocast(dtype=torch.float32):
|
||||
e = (self.modulation + e.unsqueeze(1)).chunk(2, dim=1)
|
||||
x = (self.head(self.norm(x) * (1 + e[1]) + e[0]))
|
||||
|
@ -11,6 +11,7 @@ from functools import partial
|
||||
|
||||
import torch
|
||||
import torch.cuda.amp as amp
|
||||
import torch.nn as nn
|
||||
import torch.distributed as dist
|
||||
from tqdm import tqdm
|
||||
|
||||
@ -25,6 +26,19 @@ from .utils.fm_solvers import (
|
||||
)
|
||||
from .utils.fm_solvers_unipc import FlowUniPCMultistepScheduler
|
||||
|
||||
# def convert_linear_conv_to_fp8(module):
|
||||
# for name, child in module.named_children():
|
||||
# # 递归处理子模块
|
||||
# convert_linear_conv_to_fp8(child)
|
||||
|
||||
# # 判断是否为 Linear 或 Conv 层
|
||||
# if isinstance(child, (nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d)):
|
||||
# # 转换权重
|
||||
# if hasattr(child, 'weight') and child.weight is not None:
|
||||
# # 保留 Parameter 类型,仅修改数据
|
||||
# child.weight.data = child.weight.data.to(dtype=torch.float8_e4m3fn)
|
||||
# # 可选:转换偏置(根据需求开启)
|
||||
|
||||
|
||||
class WanT2V:
|
||||
|
||||
@ -84,7 +98,9 @@ class WanT2V:
|
||||
device=self.device)
|
||||
|
||||
logging.info(f"Creating WanModel from {checkpoint_dir}")
|
||||
self.model = WanModel.from_pretrained(checkpoint_dir)
|
||||
self.model = WanModel.from_pretrained(checkpoint_dir ,torch_dtype=torch.float8_e4m3fn)
|
||||
#self.model = WanModel.from_pretrained(checkpoint_dir )
|
||||
|
||||
self.model.eval().requires_grad_(False)
|
||||
|
||||
if use_usp:
|
||||
@ -106,7 +122,9 @@ class WanT2V:
|
||||
dist.barrier()
|
||||
if dit_fsdp:
|
||||
self.model = shard_fn(self.model)
|
||||
# convert_linear_conv_to_fp8(self.model)
|
||||
else:
|
||||
# convert_linear_conv_to_fp8(self.model)
|
||||
self.model.to(self.device)
|
||||
|
||||
self.sample_neg_prompt = config.sample_neg_prompt
|
||||
@ -156,6 +174,7 @@ class WanT2V:
|
||||
- W: Frame width from size)
|
||||
"""
|
||||
# preprocess
|
||||
#offload_model = False
|
||||
F = frame_num
|
||||
target_shape = (self.vae.model.z_dim, (F - 1) // self.vae_stride[0] + 1,
|
||||
size[1] // self.vae_stride[1],
|
||||
@ -229,6 +248,16 @@ class WanT2V:
|
||||
|
||||
arg_c = {'context': context, 'seq_len': seq_len}
|
||||
arg_null = {'context': context_null, 'seq_len': seq_len}
|
||||
|
||||
# import gc
|
||||
# del self.text_encoder
|
||||
# del self.vae
|
||||
# gc.collect() # 立即触发垃圾回收
|
||||
# torch.cuda.empty_cache() # 清空CUDA缓存
|
||||
# torch.cuda.reset_peak_memory_stats()
|
||||
|
||||
# start_mem = torch.cuda.memory_allocated()
|
||||
#print(f"该阶段开始时显存占用:{start_mem / 1024**3:.2f} GB")
|
||||
|
||||
for _, t in enumerate(tqdm(timesteps)):
|
||||
latent_model_input = latents
|
||||
@ -252,6 +281,9 @@ class WanT2V:
|
||||
return_dict=False,
|
||||
generator=seed_g)[0]
|
||||
latents = [temp_x0.squeeze(0)]
|
||||
|
||||
# peak_mem_bytes = torch.cuda.max_memory_allocated()
|
||||
# print(f"该阶段最大显存占用:{peak_mem_bytes / 1024**3:.2f} GB")
|
||||
|
||||
x0 = latents
|
||||
if offload_model:
|
||||
|
Loading…
Reference in New Issue
Block a user