Mirror of https://github.com/Wan-Video/Wan2.1.git

Merge 8f11f5ab6e into 7c81b2f27d

This commit is contained in: commit 2596a8d5d5
@@ -99,7 +99,8 @@ class WanI2V:
             tokenizer_path=os.path.join(checkpoint_dir, config.clip_tokenizer))

         logging.info(f"Creating WanModel from {checkpoint_dir}")
-        self.model = WanModel.from_pretrained(checkpoint_dir)
+        # self.model = WanModel.from_pretrained(checkpoint_dir)
+        self.model = WanModel.from_pretrained(checkpoint_dir, torch_dtype=torch.float8_e4m3fn)
         self.model.eval().requires_grad_(False)

         if t5_fsdp or dit_fsdp or use_usp:
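Passing torch_dtype=torch.float8_e4m3fn to from_pretrained stores the DiT weights in 8-bit floating point, roughly 4x smaller than fp32 and 2x smaller than bf16. A minimal sketch of the storage trick this hunk relies on, using a plain nn.Linear as a stand-in for WanModel (float8_e4m3fn requires PyTorch >= 2.1):

import torch
import torch.nn as nn

linear = nn.Linear(4096, 4096, bias=False)  # fp32 weight: 4 bytes/param
fp32_mib = linear.weight.numel() * linear.weight.element_size() / 1024**2

# Cast the storage in place to fp8: 1 byte/param.
linear.weight.data = linear.weight.data.to(torch.float8_e4m3fn)
fp8_mib = linear.weight.numel() * linear.weight.element_size() / 1024**2
print(f"fp32: {fp32_mib:.0f} MiB -> fp8: {fp8_mib:.0f} MiB")

# fp8 tensors have no general compute kernels, so the weight must be
# upcast at use time before any arithmetic:
x = torch.randn(1, 4096)
y = x @ linear.weight.to(x.dtype).t()

This is storage-only quantization, which is why the later hunks add dtype guards to the norm and modulation code: each submodule has to upcast its own parameters before computing with them.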
@@ -178,6 +179,7 @@ class WanI2V:
                 - H: Frame height (from max_area)
                 - W: Frame width from max_area)
         """
+        offload_model = False
         img = TF.to_tensor(img).sub_(0.5).div_(0.5).to(self.device)

         F = frame_num
@@ -298,7 +300,19 @@ class WanI2V:
             if offload_model:
                 torch.cuda.empty_cache()

             self.model.to(self.device)
+
+            # import gc
+            # del self.text_encoder
+            # del self.clip
+            # del self.vae
+            # gc.collect()  # trigger garbage collection immediately
+            # torch.cuda.empty_cache()  # clear the CUDA allocator cache
+            # torch.cuda.reset_peak_memory_stats()
+
+            # start_mem = torch.cuda.memory_allocated()
+            # print(f"GPU memory allocated at the start of this stage: {start_mem / 1024**3:.2f} GB")
+            # self.model.to(self.device)
             for _, t in enumerate(tqdm(timesteps)):
                 latent_model_input = [latent.to(self.device)]
                 timestep = [t]
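The commented-out block is a standard CUDA-allocator profiling bracket around the denoising loop. Pulled out into a reusable helper (run_with_memory_report is our name, not the repo's), the same pattern looks like:

import torch

def run_with_memory_report(fn, *args, **kwargs):
    torch.cuda.empty_cache()              # release cached, unused blocks
    torch.cuda.reset_peak_memory_stats()  # restart the peak counter
    start = torch.cuda.memory_allocated()
    result = fn(*args, **kwargs)
    peak = torch.cuda.max_memory_allocated()
    print(f"start: {start / 1024**3:.2f} GiB, peak: {peak / 1024**3:.2f} GiB")
    return result

memory_allocated() reports tensors currently alive, while max_memory_allocated() reports the high-water mark since the last reset, which is what decides whether a stage fits on the card.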
@@ -332,6 +346,9 @@ class WanI2V:
                 x0 = [latent.to(self.device)]
                 del latent_model_input, timestep

+            # peak_mem_bytes = torch.cuda.max_memory_allocated()
+            # print(f"Peak GPU memory during this stage: {peak_mem_bytes / 1024**3:.2f} GB")
+
             if offload_model:
                 self.model.cpu()
                 torch.cuda.empty_cache()
@@ -83,6 +83,8 @@ class WanRMSNorm(nn.Module):
         Args:
             x(Tensor): Shape [B, L, C]
         """
+        if self.weight.dtype != torch.bfloat16:
+            self.weight.data = self.weight.data.to(dtype=torch.bfloat16)
         return self._norm(x.float()).type_as(x) * self.weight

     def _norm(self, x):
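Because the checkpoint is now loaded in fp8, the RMSNorm weight can arrive in a dtype with no multiply kernel, so the hunk lazily casts it to bf16 on first use. A self-contained version of the same pattern, assuming the standard RMS formula for _norm (consistent with the surrounding code, but written here for illustration):

import torch
import torch.nn as nn

class LazyCastRMSNorm(nn.Module):
    def __init__(self, dim, eps=1e-5):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        # standard RMS normalization
        return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)

    def forward(self, x):
        # the weight may still be fp8 from the checkpoint load; cast it once
        if self.weight.dtype != torch.bfloat16:
            self.weight.data = self.weight.data.to(dtype=torch.bfloat16)
        return self._norm(x.float()).type_as(x) * self.weight

Mutating .data keeps the Parameter wrapper intact, so the module's state_dict and optimizer bookkeeping are unaffected; only the storage dtype changes.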
@@ -294,6 +296,9 @@ class WanAttentionBlock(nn.Module):
             freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2]
         """
         assert e.dtype == torch.float32
+        if self.modulation.dtype != torch.float16:
+            # convert to fp16 if it is not already
+            self.modulation.data = self.modulation.data.to(dtype=torch.float16)
         with amp.autocast(dtype=torch.float32):
             e = (self.modulation + e).chunk(6, dim=1)
         assert e[0].dtype == torch.float32
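The same lazy-cast idea is applied to the per-block modulation table, but targeting fp16 rather than bf16. The surrounding asserts still pass because PyTorch's type promotion widens fp16 + fp32 to fp32, as this tiny check demonstrates (shapes are illustrative, not WanModel's actual dimensions):

import torch

modulation = torch.zeros(1, 6, 1536, dtype=torch.float16)  # fp16 parameter
e = torch.randn(2, 6, 1536, dtype=torch.float32)           # fp32 activations

chunks = (modulation + e).chunk(6, dim=1)
assert chunks[0].dtype == torch.float32  # fp16 + fp32 -> fp32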
@@ -341,6 +346,9 @@ class Head(nn.Module):
             e(Tensor): Shape [B, C]
         """
         assert e.dtype == torch.float32
+        if self.modulation.dtype != torch.float16:
+            # convert to fp16 if it is not already
+            self.modulation.data = self.modulation.data.to(dtype=torch.float16)
         with amp.autocast(dtype=torch.float32):
             e = (self.modulation + e.unsqueeze(1)).chunk(2, dim=1)
             x = (self.head(self.norm(x) * (1 + e[1]) + e[0]))
@@ -11,6 +11,7 @@ from functools import partial

 import torch
 import torch.cuda.amp as amp
+import torch.nn as nn
 import torch.distributed as dist
 from tqdm import tqdm

@@ -25,6 +26,19 @@ from .utils.fm_solvers import (
 )
 from .utils.fm_solvers_unipc import FlowUniPCMultistepScheduler

+# def convert_linear_conv_to_fp8(module):
+#     for name, child in module.named_children():
+#         # recursively process child modules
+#         convert_linear_conv_to_fp8(child)
+
+#         # check whether this is a Linear or Conv layer
+#         if isinstance(child, (nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d)):
+#             # convert the weight
+#             if hasattr(child, 'weight') and child.weight is not None:
+#                 # keep the Parameter type; only modify the data
+#                 child.weight.data = child.weight.data.to(dtype=torch.float8_e4m3fn)
+#                 # optional: convert the bias too (enable if needed)
+

 class WanT2V:

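The commented-out draft above converts only weights and needs torch and nn in scope. A self-contained, runnable version of what it appears to intend, with a toy sanity check (the Sequential stack is illustrative only):

import torch
import torch.nn as nn

def convert_linear_conv_to_fp8(module: nn.Module) -> None:
    """Recursively cast Linear/Conv weights to fp8 storage, in place."""
    for child in module.children():
        convert_linear_conv_to_fp8(child)  # recurse into submodules
        if isinstance(child, (nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d)):
            if getattr(child, 'weight', None) is not None:
                # keep the Parameter wrapper; only swap its storage
                child.weight.data = child.weight.data.to(dtype=torch.float8_e4m3fn)

# quick check on a toy stack
net = nn.Sequential(nn.Linear(8, 8), nn.Conv2d(3, 3, 1))
convert_linear_conv_to_fp8(net)
assert net[0].weight.dtype == torch.float8_e4m3fn

In the commit this helper stays commented out in favor of torch_dtype=torch.float8_e4m3fn at load time, which avoids ever materializing the fp32 weights; the hunk below keeps the call sites (also commented) for both the FSDP and non-FSDP paths.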
@@ -84,7 +98,9 @@ class WanT2V:
             device=self.device)

         logging.info(f"Creating WanModel from {checkpoint_dir}")
-        self.model = WanModel.from_pretrained(checkpoint_dir)
+        self.model = WanModel.from_pretrained(checkpoint_dir, torch_dtype=torch.float8_e4m3fn)
+        # self.model = WanModel.from_pretrained(checkpoint_dir)
+
         self.model.eval().requires_grad_(False)

         if use_usp:
@@ -106,7 +122,9 @@ class WanT2V:
             dist.barrier()
         if dit_fsdp:
             self.model = shard_fn(self.model)
+            # convert_linear_conv_to_fp8(self.model)
         else:
+            # convert_linear_conv_to_fp8(self.model)
             self.model.to(self.device)

         self.sample_neg_prompt = config.sample_neg_prompt
@@ -156,6 +174,7 @@ class WanT2V:
                 - W: Frame width from size)
         """
         # preprocess
+        # offload_model = False
         F = frame_num
         target_shape = (self.vae.model.z_dim, (F - 1) // self.vae_stride[0] + 1,
                         size[1] // self.vae_stride[1],
@@ -229,6 +248,16 @@ class WanT2V:

             arg_c = {'context': context, 'seq_len': seq_len}
             arg_null = {'context': context_null, 'seq_len': seq_len}
+
+            # import gc
+            # del self.text_encoder
+            # del self.vae
+            # gc.collect()  # trigger garbage collection immediately
+            # torch.cuda.empty_cache()  # clear the CUDA allocator cache
+            # torch.cuda.reset_peak_memory_stats()
+
+            # start_mem = torch.cuda.memory_allocated()
+            # print(f"GPU memory allocated at the start of this stage: {start_mem / 1024**3:.2f} GB")

             for _, t in enumerate(tqdm(timesteps)):
                 latent_model_input = latents
@@ -252,6 +281,9 @@ class WanT2V:
                     return_dict=False,
                     generator=seed_g)[0]
                 latents = [temp_x0.squeeze(0)]
+
+            # peak_mem_bytes = torch.cuda.max_memory_allocated()
+            # print(f"Peak GPU memory during this stage: {peak_mem_bytes / 1024**3:.2f} GB")

             x0 = latents
             if offload_model: