From bc2aff711e2334628dec9845f6c9066199b7d059 Mon Sep 17 00:00:00 2001
From: Yexiong Lin
Date: Thu, 27 Feb 2025 17:42:53 +1100
Subject: [PATCH] support fp8 model

---
 generate.py                |  6 +++
 wan/configs/wan_i2v_14B.py |  2 +-
 wan/image2video.py         | 77 ++++++++++++++++++++------------------
 wan/modules/clip.py        | 11 ++++--
 4 files changed, 55 insertions(+), 41 deletions(-)

diff --git a/generate.py b/generate.py
index 327918f..f8bd885 100644
--- a/generate.py
+++ b/generate.py
@@ -126,6 +126,11 @@ def _parse_args():
         action="store_true",
         default=False,
         help="Whether to use FSDP for DiT.")
+    parser.add_argument(
+        "--fp8",
+        action="store_true",
+        default=False,
+        help="Whether to use fp8.")
     parser.add_argument(
         "--save_file",
         type=str,
@@ -363,6 +368,7 @@ def generate(args):
             dit_fsdp=args.dit_fsdp,
             use_usp=(args.ulysses_size > 1 or args.ring_size > 1),
             t5_cpu=args.t5_cpu,
+            fp8=args.fp8,
         )
 
     logging.info("Generating video ...")
diff --git a/wan/configs/wan_i2v_14B.py b/wan/configs/wan_i2v_14B.py
index b2ab367..ccfd1af 100644
--- a/wan/configs/wan_i2v_14B.py
+++ b/wan/configs/wan_i2v_14B.py
@@ -16,7 +16,7 @@ i2v_14B.t5_tokenizer = 'google/umt5-xxl'
 i2v_14B.clip_model = 'clip_xlm_roberta_vit_h_14'
 i2v_14B.clip_dtype = torch.float16
 # i2v_14B.clip_checkpoint = 'models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth'
-i2v_14B.clip_checkpoint = 'open-clip-xlm-roberta-large-vit-huge-14_fp16.safetensors'
+i2v_14B.clip_checkpoint = 'open-clip-xlm-roberta-large-vit-huge-14_fp16.safetensors'  # Kijai's fp16 model
 i2v_14B.clip_tokenizer = 'xlm-roberta-large'
 
 # vae
diff --git a/wan/image2video.py b/wan/image2video.py
index df63624..8b9e653 100644
--- a/wan/image2video.py
+++ b/wan/image2video.py
@@ -43,6 +43,7 @@ class WanI2V:
         use_usp=False,
         t5_cpu=False,
         init_on_cpu=True,
+        fp8=False,
     ):
         r"""
         Initializes the image-to-video generation model components.
@@ -66,6 +67,8 @@ class WanI2V:
                 Whether to place T5 model on CPU. Only works without t5_fsdp.
             init_on_cpu (`bool`, *optional*, defaults to True):
                 Enable initializing Transformer Model on CPU. Only works without FSDP or USP.
+            fp8 (`bool`, *optional*, defaults to False):
+                Enable 8-bit floating point precision for model parameters.
""" self.device = torch.device(f"cuda:{device_id}") self.config = config @@ -88,7 +91,6 @@ class WanI2V: self.vae_stride = config.vae_stride self.patch_size = config.patch_size - print('device:', self.device) self.vae = WanVAE( vae_pth=os.path.join(checkpoint_dir, config.vae_checkpoint), @@ -101,44 +103,47 @@ class WanI2V: config.clip_checkpoint), tokenizer_path=os.path.join(checkpoint_dir, config.clip_tokenizer)) - state_dict = load_file(checkpoint_dir+'/Wan2_1-I2V-14B-480P_fp8_e4m3fn.safetensors', device="cpu") - dim = state_dict["patch_embedding.weight"].shape[0] - in_channels = state_dict["patch_embedding.weight"].shape[1] - print("in_channels: ", in_channels) - ffn_dim = state_dict["blocks.0.ffn.0.bias"].shape[0] - model_type = "i2v" if in_channels == 36 else "t2v" - num_heads = 40 if dim == 5120 else 12 - num_layers = 40 if dim == 5120 else 30 - TRANSFORMER_CONFIG= { - "dim": dim, - "ffn_dim": ffn_dim, - "eps": 1e-06, - "freq_dim": 256, - "in_dim": in_channels, - "model_type": model_type, - "out_dim": 16, - "text_len": 512, - "num_heads": num_heads, - "num_layers": num_layers, - } + if not fp8: + self.model = WanModel.from_pretrained(checkpoint_dir) + else: + if '480P' in checkpoint_dir: + state_dict = load_file(checkpoint_dir+'/Wan2_1-I2V-14B-480P_fp8_e4m3fn.safetensors', device="cpu") + elif '720P' in checkpoint_dir: + state_dict = load_file(checkpoint_dir+'/Wan2_1-I2V-14B-720P_fp8_e4m3fn.safetensors', device="cpu") + dim = state_dict["patch_embedding.weight"].shape[0] + in_channels = state_dict["patch_embedding.weight"].shape[1] + ffn_dim = state_dict["blocks.0.ffn.0.bias"].shape[0] + model_type = "i2v" if in_channels == 36 else "t2v" + num_heads = 40 if dim == 5120 else 12 + num_layers = 40 if dim == 5120 else 30 + TRANSFORMER_CONFIG= { + "dim": dim, + "ffn_dim": ffn_dim, + "eps": 1e-06, + "freq_dim": 256, + "in_dim": in_channels, + "model_type": model_type, + "out_dim": 16, + "text_len": 512, + "num_heads": num_heads, + "num_layers": num_layers, + } - with init_empty_weights(): - self.model = WanModel(**TRANSFORMER_CONFIG) - logging.info(f"Creating WanModel from {checkpoint_dir}") - - base_dtype=torch.bfloat16 - dtype=torch.float8_e4m3fn - params_to_keep = {"norm", "head", "bias", "time_in", "vector_in", "patch_embedding", "time_", "img_emb", "modulation"} - for name, param in self.model.named_parameters(): - dtype_to_use = base_dtype if any(keyword in name for keyword in params_to_keep) else dtype - # dtype_to_use = torch.bfloat16 - # print("Assigning Parameter name: ", name, " with dtype: ", dtype_to_use) - set_module_tensor_to_device(self.model, name, device='cpu', dtype=dtype_to_use, value=state_dict[name]) + with init_empty_weights(): + self.model = WanModel(**TRANSFORMER_CONFIG) + logging.info(f"Creating WanModel from {checkpoint_dir}") + + base_dtype=torch.bfloat16 + dtype=torch.float8_e4m3fn + params_to_keep = {"norm", "head", "bias", "time_in", "vector_in", "patch_embedding", "time_", "img_emb", "modulation"} + for name, param in self.model.named_parameters(): + dtype_to_use = base_dtype if any(keyword in name for keyword in params_to_keep) else dtype + # dtype_to_use = torch.bfloat16 + # print("Assigning Parameter name: ", name, " with dtype: ", dtype_to_use) + set_module_tensor_to_device(self.model, name, device='cpu', dtype=dtype_to_use, value=state_dict[name]) - del state_dict - print("Model loaded successfully") + del state_dict - # self.model = WanModel.from_pretrained(checkpoint_dir) self.model.eval().requires_grad_(False) if t5_fsdp or dit_fsdp or use_usp: diff 
diff --git a/wan/modules/clip.py b/wan/modules/clip.py
index ceccafc..619cf63 100644
--- a/wan/modules/clip.py
+++ b/wan/modules/clip.py
@@ -516,10 +516,13 @@ class CLIPModel:
             device=device)
         self.model = self.model.eval().requires_grad_(False)
         logging.info(f'loading {checkpoint_path}')
-        state_dict = load_file(checkpoint_path, device='cpu')
-        self.model.load_state_dict(state_dict)
-        # self.model.load_state_dict(
-        #     torch.load(checkpoint_path, map_location='cpu'))
+        if checkpoint_path.endswith('.safetensors'):
+            state_dict = load_file(checkpoint_path, device='cpu')
+            self.model.load_state_dict(state_dict)
+        elif checkpoint_path.endswith('.pth'):
+            self.model.load_state_dict(torch.load(checkpoint_path, map_location='cpu'))
+        else:
+            raise ValueError(f'Unsupported checkpoint file format: {checkpoint_path}')
 
         # init tokenizer
         self.tokenizer = HuggingfaceTokenizer(
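
Note on the fp8 loading path in wan/image2video.py: the patch builds the transformer with accelerate's init_empty_weights() and then materialises each parameter with set_module_tensor_to_device(), storing the bulk of the weights as torch.float8_e4m3fn while the params_to_keep filter keeps numerically sensitive tensors (norms, biases, embeddings, time/image projections, modulation) in bfloat16. The snippet below is a minimal, self-contained sketch of that pattern, not the shipped code: TinyModel and its in-memory reference state dict are hypothetical stand-ins for WanModel and the Wan2_1-I2V-14B-*_fp8_e4m3fn.safetensors checkpoint.

    # Sketch of the fp8 weight-assignment pattern from the patch above.
    # TinyModel and the in-memory state dict are hypothetical stand-ins for
    # WanModel and the fp8 safetensors checkpoint loaded with load_file().
    import torch
    import torch.nn as nn
    from accelerate import init_empty_weights
    from accelerate.utils import set_module_tensor_to_device

    class TinyModel(nn.Module):
        def __init__(self, dim=64):
            super().__init__()
            self.patch_embedding = nn.Linear(36, dim)   # name matches params_to_keep -> bf16
            self.blocks = nn.ModuleList([nn.Linear(dim, dim) for _ in range(2)])  # -> fp8
            self.norm = nn.LayerNorm(dim)               # name matches params_to_keep -> bf16

    # Stand-in for load_file(checkpoint, device="cpu"): tensor names match the model.
    state_dict = TinyModel().state_dict()

    # Build the module skeleton on the meta device, without allocating real storage.
    with init_empty_weights():
        model = TinyModel()

    base_dtype = torch.bfloat16       # numerically sensitive tensors
    fp8_dtype = torch.float8_e4m3fn   # bulk of the weights
    params_to_keep = {"norm", "bias", "patch_embedding"}

    for name, _ in model.named_parameters():
        dtype_to_use = base_dtype if any(k in name for k in params_to_keep) else fp8_dtype
        # Materialise each parameter on CPU directly in its target dtype.
        set_module_tensor_to_device(
            model, name, device="cpu", dtype=dtype_to_use, value=state_dict[name])

    model.eval().requires_grad_(False)
    print({n: p.dtype for n, p in model.named_parameters()})

Since most DiT parameters live in large linear layers, casting them to float8_e4m3fn roughly halves transformer weight memory compared with bfloat16, while the params_to_keep filter protects the small, precision-sensitive tensors from that cast.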