diff --git a/wan/image2video.py b/wan/image2video.py index 6882c53..5e2a89d 100644 --- a/wan/image2video.py +++ b/wan/image2video.py @@ -99,7 +99,8 @@ class WanI2V: tokenizer_path=os.path.join(checkpoint_dir, config.clip_tokenizer)) logging.info(f"Creating WanModel from {checkpoint_dir}") - self.model = WanModel.from_pretrained(checkpoint_dir) + #self.model = WanModel.from_pretrained(checkpoint_dir) + self.model = WanModel.from_pretrained(checkpoint_dir ,torch_dtype=torch.float8_e4m3fn) self.model.eval().requires_grad_(False) if t5_fsdp or dit_fsdp or use_usp: @@ -178,6 +179,7 @@ class WanI2V: - H: Frame height (from max_area) - W: Frame width from max_area) """ + offload_model = False img = TF.to_tensor(img).sub_(0.5).div_(0.5).to(self.device) F = frame_num @@ -298,7 +300,19 @@ class WanI2V: if offload_model: torch.cuda.empty_cache() - self.model.to(self.device) + + # import gc + # del self.text_encoder + # del self.clip + # del self.vae + # gc.collect() # 立即触发垃圾回收 + # torch.cuda.empty_cache() # 清空CUDA缓存 + # torch.cuda.reset_peak_memory_stats() + + # start_mem = torch.cuda.memory_allocated() + # print(f"该阶段开始时显存占用:{start_mem / 1024**3:.2f} GB") + + # self.model.to(self.device) for _, t in enumerate(tqdm(timesteps)): latent_model_input = [latent.to(self.device)] timestep = [t] @@ -332,6 +346,9 @@ class WanI2V: x0 = [latent.to(self.device)] del latent_model_input, timestep + # peak_mem_bytes = torch.cuda.max_memory_allocated() + # print(f"该阶段最大显存占用:{peak_mem_bytes / 1024**3:.2f} GB") + if offload_model: self.model.cpu() torch.cuda.empty_cache() diff --git a/wan/modules/model.py b/wan/modules/model.py index a5425da..d11b8b5 100644 --- a/wan/modules/model.py +++ b/wan/modules/model.py @@ -83,6 +83,8 @@ class WanRMSNorm(nn.Module): Args: x(Tensor): Shape [B, L, C] """ + if self.weight.dtype != torch.bfloat16: + self.weight.data = self.weight.data.to(dtype=torch.bfloat16) return self._norm(x.float()).type_as(x) * self.weight def _norm(self, x): @@ -294,6 +296,9 @@ class WanAttentionBlock(nn.Module): freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2] """ assert e.dtype == torch.float32 + if self.modulation.dtype != torch.float16: + # 如果不是 fp16,则转换为 fp16 + self.modulation.data = self.modulation.data.to(dtype=torch.float16) with amp.autocast(dtype=torch.float32): e = (self.modulation + e).chunk(6, dim=1) assert e[0].dtype == torch.float32 @@ -341,6 +346,9 @@ class Head(nn.Module): e(Tensor): Shape [B, C] """ assert e.dtype == torch.float32 + if self.modulation.dtype != torch.float16: + # 如果不是 fp16,则转换为 fp16 + self.modulation.data = self.modulation.data.to(dtype=torch.float16) with amp.autocast(dtype=torch.float32): e = (self.modulation + e.unsqueeze(1)).chunk(2, dim=1) x = (self.head(self.norm(x) * (1 + e[1]) + e[0])) diff --git a/wan/text2video.py b/wan/text2video.py index c518b61..6763997 100644 --- a/wan/text2video.py +++ b/wan/text2video.py @@ -11,6 +11,7 @@ from functools import partial import torch import torch.cuda.amp as amp +import torch.nn as nn import torch.distributed as dist from tqdm import tqdm @@ -25,6 +26,19 @@ from .utils.fm_solvers import ( ) from .utils.fm_solvers_unipc import FlowUniPCMultistepScheduler +# def convert_linear_conv_to_fp8(module): +# for name, child in module.named_children(): +# # 递归处理子模块 +# convert_linear_conv_to_fp8(child) + +# # 判断是否为 Linear 或 Conv 层 +# if isinstance(child, (nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d)): +# # 转换权重 +# if hasattr(child, 'weight') and child.weight is not None: +# # 保留 Parameter 类型,仅修改数据 +# child.weight.data = child.weight.data.to(dtype=torch.float8_e4m3fn) +# # 可选:转换偏置(根据需求开启) + class WanT2V: @@ -84,7 +98,9 @@ class WanT2V: device=self.device) logging.info(f"Creating WanModel from {checkpoint_dir}") - self.model = WanModel.from_pretrained(checkpoint_dir) + self.model = WanModel.from_pretrained(checkpoint_dir ,torch_dtype=torch.float8_e4m3fn) + #self.model = WanModel.from_pretrained(checkpoint_dir ) + self.model.eval().requires_grad_(False) if use_usp: @@ -106,7 +122,9 @@ class WanT2V: dist.barrier() if dit_fsdp: self.model = shard_fn(self.model) + # convert_linear_conv_to_fp8(self.model) else: + # convert_linear_conv_to_fp8(self.model) self.model.to(self.device) self.sample_neg_prompt = config.sample_neg_prompt @@ -156,6 +174,7 @@ class WanT2V: - W: Frame width from size) """ # preprocess + #offload_model = False F = frame_num target_shape = (self.vae.model.z_dim, (F - 1) // self.vae_stride[0] + 1, size[1] // self.vae_stride[1], @@ -229,6 +248,16 @@ class WanT2V: arg_c = {'context': context, 'seq_len': seq_len} arg_null = {'context': context_null, 'seq_len': seq_len} + + # import gc + # del self.text_encoder + # del self.vae + # gc.collect() # 立即触发垃圾回收 + # torch.cuda.empty_cache() # 清空CUDA缓存 + # torch.cuda.reset_peak_memory_stats() + + # start_mem = torch.cuda.memory_allocated() + #print(f"该阶段开始时显存占用:{start_mem / 1024**3:.2f} GB") for _, t in enumerate(tqdm(timesteps)): latent_model_input = latents @@ -252,6 +281,9 @@ class WanT2V: return_dict=False, generator=seed_g)[0] latents = [temp_x0.squeeze(0)] + + # peak_mem_bytes = torch.cuda.max_memory_allocated() + # print(f"该阶段最大显存占用:{peak_mem_bytes / 1024**3:.2f} GB") x0 = latents if offload_model: