From 8f11f5ab6e2bf5a71f5fdde53efe7b0bb6d53360 Mon Sep 17 00:00:00 2001
From: Tianhao
Date: Sun, 16 Mar 2025 12:07:18 +0800
Subject: [PATCH] Add fp8 quantization to wan
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
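
Load the DiT weights as torch.float8_e4m3fn to cut VRAM usage. PyTorch
implements almost no arithmetic kernels for float8, so the RMSNorm weights
and the modulation tables are promoted back to bf16/fp16 on their first
forward pass; memory-profiling helpers are left commented out for
debugging.

For reference, a runnable version of the conversion helper that this patch
keeps commented out in wan/text2video.py could look as follows. This is a
minimal sketch, not part of the applied diff: it assumes PyTorch >= 2.1
(where torch.float8_e4m3fn exists) and converts only Linear/Conv weights,
leaving biases and norm parameters in their original dtype:

    import torch
    import torch.nn as nn

    def convert_linear_conv_to_fp8(module: nn.Module) -> None:
        """Store Linear/Conv weights as float8_e4m3fn, in place."""
        for child in module.children():
            # recurse into child modules first
            convert_linear_conv_to_fp8(child)
            # only matmul weights go to fp8; everything else keeps its dtype
            if isinstance(child, (nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d)):
                if child.weight is not None:
                    # keep the Parameter wrapper, only swap its storage
                    child.weight.data = child.weight.data.to(torch.float8_e4m3fn)
                # bias intentionally left in high precision

    # toy usage: only the Linear weights end up in fp8
    model = nn.Sequential(nn.Linear(8, 16), nn.LayerNorm(16), nn.Linear(16, 4))
    convert_linear_conv_to_fp8(model)
    print(model[0].weight.dtype)  # torch.float8_e4m3fn
    print(model[1].weight.dtype)  # torch.float32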

---
 wan/image2video.py   | 21 +++++++++++++++++++--
 wan/modules/model.py |  8 ++++++++
 wan/text2video.py    | 34 +++++++++++++++++++++++++++++++++-
 3 files changed, 60 insertions(+), 3 deletions(-)

diff --git a/wan/image2video.py b/wan/image2video.py
index 468f17c..8402797 100644
--- a/wan/image2video.py
+++ b/wan/image2video.py
@@ -96,7 +96,8 @@ class WanI2V:
             tokenizer_path=os.path.join(checkpoint_dir, config.clip_tokenizer))
 
         logging.info(f"Creating WanModel from {checkpoint_dir}")
-        self.model = WanModel.from_pretrained(checkpoint_dir)
+        #self.model = WanModel.from_pretrained(checkpoint_dir)
+        self.model = WanModel.from_pretrained(checkpoint_dir, torch_dtype=torch.float8_e4m3fn)
         self.model.eval().requires_grad_(False)
 
         if t5_fsdp or dit_fsdp or use_usp:
@@ -174,6 +175,7 @@ class WanI2V:
                 - H: Frame height (from max_area)
                 - W: Frame width from max_area)
         """
+        offload_model = False
        img = TF.to_tensor(img).sub_(0.5).div_(0.5).to(self.device)
 
        F = frame_num
@@ -295,7 +297,19 @@
            if offload_model:
                torch.cuda.empty_cache()
 
-            self.model.to(self.device)
+
+            # import gc
+            # del self.text_encoder
+            # del self.clip
+            # del self.vae
+            # gc.collect()  # trigger garbage collection immediately
+            # torch.cuda.empty_cache()  # empty the CUDA cache
+            # torch.cuda.reset_peak_memory_stats()
+
+            # start_mem = torch.cuda.memory_allocated()
+            # print(f"GPU memory allocated at the start of this stage: {start_mem / 1024**3:.2f} GB")
+
+            # self.model.to(self.device)
            for _, t in enumerate(tqdm(timesteps)):
                latent_model_input = [latent.to(self.device)]
                timestep = [t]
@@ -329,6 +343,9 @@
                x0 = [latent.to(self.device)]
                del latent_model_input, timestep
 
+            # peak_mem_bytes = torch.cuda.max_memory_allocated()
+            # print(f"Peak GPU memory during this stage: {peak_mem_bytes / 1024**3:.2f} GB")
+
            if offload_model:
                self.model.cpu()
                torch.cuda.empty_cache()
diff --git a/wan/modules/model.py b/wan/modules/model.py
index b65021c..7605f8a 100644
--- a/wan/modules/model.py
+++ b/wan/modules/model.py
@@ -80,6 +80,8 @@ class WanRMSNorm(nn.Module):
         Args:
             x(Tensor): Shape [B, L, C]
         """
+        if self.weight.dtype != torch.bfloat16:
+            self.weight.data = self.weight.data.to(dtype=torch.bfloat16)
         return self._norm(x.float()).type_as(x) * self.weight
 
     def _norm(self, x):
@@ -290,6 +292,9 @@ class WanAttentionBlock(nn.Module):
             freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2]
         """
         assert e.dtype == torch.float32
+        if self.modulation.dtype != torch.float16:
+            # convert to fp16 if not already
+            self.modulation.data = self.modulation.data.to(dtype=torch.float16)
         with amp.autocast(dtype=torch.float32):
             e = (self.modulation + e).chunk(6, dim=1)
         assert e[0].dtype == torch.float32
@@ -337,6 +342,9 @@ class Head(nn.Module):
             e(Tensor): Shape [B, C]
         """
         assert e.dtype == torch.float32
+        if self.modulation.dtype != torch.float16:
+            # convert to fp16 if not already
+            self.modulation.data = self.modulation.data.to(dtype=torch.float16)
         with amp.autocast(dtype=torch.float32):
             e = (self.modulation + e.unsqueeze(1)).chunk(2, dim=1)
             x = (self.head(self.norm(x) * (1 + e[1]) + e[0]))
diff --git a/wan/text2video.py b/wan/text2video.py
index 2400545..8bb1917 100644
--- a/wan/text2video.py
+++ b/wan/text2video.py
@@ -11,6 +11,7 @@ from functools import partial
 
 import torch
 import torch.cuda.amp as amp
+import torch.nn as nn
 import torch.distributed as dist
 from tqdm import tqdm
 
@@ -22,6 +23,19 @@
 from .utils.fm_solvers import (FlowDPMSolverMultistepScheduler,
                                get_sampling_sigmas, retrieve_timesteps)
 from .utils.fm_solvers_unipc import FlowUniPCMultistepScheduler
+# def convert_linear_conv_to_fp8(module):
+#     for name, child in module.named_children():
+#         # recurse into child submodules
+#         convert_linear_conv_to_fp8(child)
+
+#         # check whether this is a Linear or Conv layer
+#         if isinstance(child, (nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d)):
+#             # convert the weight
+#             if hasattr(child, 'weight') and child.weight is not None:
+#                 # keep the Parameter type, only change its data
+#                 child.weight.data = child.weight.data.to(dtype=torch.float8_e4m3fn)
+#             # optional: also convert the bias (enable if needed)
+
 
 
 class WanT2V:
@@ -81,7 +95,9 @@ class WanT2V:
             device=self.device)
 
         logging.info(f"Creating WanModel from {checkpoint_dir}")
-        self.model = WanModel.from_pretrained(checkpoint_dir)
+        self.model = WanModel.from_pretrained(checkpoint_dir, torch_dtype=torch.float8_e4m3fn)
+        #self.model = WanModel.from_pretrained(checkpoint_dir)
+
         self.model.eval().requires_grad_(False)
 
         if use_usp:
@@ -102,7 +118,9 @@ class WanT2V:
             dist.barrier()
             if dit_fsdp:
                 self.model = shard_fn(self.model)
+                # convert_linear_conv_to_fp8(self.model)
             else:
+                # convert_linear_conv_to_fp8(self.model)
                 self.model.to(self.device)
 
         self.sample_neg_prompt = config.sample_neg_prompt
@@ -152,6 +170,7 @@ class WanT2V:
                 - W: Frame width from size)
         """
         # preprocess
+        #offload_model = False
         F = frame_num
         target_shape = (self.vae.model.z_dim, (F - 1) // self.vae_stride[0] + 1,
                         size[1] // self.vae_stride[1],
@@ -225,6 +244,16 @@
 
             arg_c = {'context': context, 'seq_len': seq_len}
             arg_null = {'context': context_null, 'seq_len': seq_len}
+
+            # import gc
+            # del self.text_encoder
+            # del self.vae
+            # gc.collect()  # trigger garbage collection immediately
+            # torch.cuda.empty_cache()  # empty the CUDA cache
+            # torch.cuda.reset_peak_memory_stats()
+
+            # start_mem = torch.cuda.memory_allocated()
+            # print(f"GPU memory allocated at the start of this stage: {start_mem / 1024**3:.2f} GB")
 
             for _, t in enumerate(tqdm(timesteps)):
                 latent_model_input = latents
@@ -248,6 +277,9 @@
                     return_dict=False,
                     generator=seed_g)[0]
                 latents = [temp_x0.squeeze(0)]
+
+                # peak_mem_bytes = torch.cuda.max_memory_allocated()
+                # print(f"Peak GPU memory during this stage: {peak_mem_bytes / 1024**3:.2f} GB")
 
             x0 = latents
             if offload_model:
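
Note: the lazy-cast pattern this patch applies to WanRMSNorm and the
modulation tables can be illustrated standalone. The sketch below is not
part of the diff, and LazyCastRMSNorm is a hypothetical name; it mirrors
the changed WanRMSNorm forward and shows why the cast is needed: PyTorch
provides almost no elementwise kernels for float8, so an fp8 parameter
must be promoted once before it can participate in the multiply.

    import torch
    import torch.nn as nn

    class LazyCastRMSNorm(nn.Module):
        # Minimal RMSNorm whose weight may arrive in fp8 (e.g. loaded via
        # from_pretrained(torch_dtype=torch.float8_e4m3fn)) and is promoted
        # to bf16 on first use, mirroring the WanRMSNorm change above.

        def __init__(self, dim, eps=1e-5):
            super().__init__()
            self.eps = eps
            self.weight = nn.Parameter(torch.ones(dim))

        def forward(self, x):
            if self.weight.dtype != torch.bfloat16:
                # one-time promotion: fp8 cannot be used in the multiply below
                self.weight.data = self.weight.data.to(torch.bfloat16)
            return self._norm(x.float()).type_as(x) * self.weight

        def _norm(self, x):
            return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)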