From bebb16bb8ed4bacf1b2352381ec84eb969f46dee Mon Sep 17 00:00:00 2001
From: Yexiong Lin
Date: Thu, 27 Feb 2025 17:01:58 +1100
Subject: [PATCH 1/6] Support Kijai's fp8 model
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 wan/configs/wan_i2v_14B.py |  3 ++-
 wan/image2video.py         | 52 +++++++++++++++++++++++++++++++++++---
 wan/modules/clip.py        |  7 +++--
 wan/modules/vae.py         |  4 +--
 4 files changed, 58 insertions(+), 8 deletions(-)

diff --git a/wan/configs/wan_i2v_14B.py b/wan/configs/wan_i2v_14B.py
index 12e8e20..b2ab367 100644
--- a/wan/configs/wan_i2v_14B.py
+++ b/wan/configs/wan_i2v_14B.py
@@ -15,7 +15,8 @@ i2v_14B.t5_tokenizer = 'google/umt5-xxl'
 # clip
 i2v_14B.clip_model = 'clip_xlm_roberta_vit_h_14'
 i2v_14B.clip_dtype = torch.float16
-i2v_14B.clip_checkpoint = 'models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth'
+# i2v_14B.clip_checkpoint = 'models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth'
+i2v_14B.clip_checkpoint = 'open-clip-xlm-roberta-large-vit-huge-14_fp16.safetensors'
 i2v_14B.clip_tokenizer = 'xlm-roberta-large'
 
 # vae
diff --git a/wan/image2video.py b/wan/image2video.py
index 468f17c..df63624 100644
--- a/wan/image2video.py
+++ b/wan/image2video.py
@@ -16,6 +16,10 @@ import torch.distributed as dist
 import torchvision.transforms.functional as TF
 from tqdm import tqdm
 
+from accelerate import init_empty_weights
+from accelerate.utils import set_module_tensor_to_device
+from safetensors.torch import load_file
+
 from .distributed.fsdp import shard_model
 from .modules.clip import CLIPModel
 from .modules.model import WanModel
@@ -84,6 +88,8 @@ class WanI2V:
         self.vae_stride = config.vae_stride
         self.patch_size = config.patch_size
 
+        print('device:', self.device)
+
         self.vae = WanVAE(
             vae_pth=os.path.join(checkpoint_dir, config.vae_checkpoint),
             device=self.device)
@@ -95,8 +101,44 @@ class WanI2V:
                                          config.clip_checkpoint),
             tokenizer_path=os.path.join(checkpoint_dir, config.clip_tokenizer))
 
+        state_dict = load_file(checkpoint_dir+'/Wan2_1-I2V-14B-480P_fp8_e4m3fn.safetensors', device="cpu")
+        dim = state_dict["patch_embedding.weight"].shape[0]
+        in_channels = state_dict["patch_embedding.weight"].shape[1]
+        print("in_channels: ", in_channels)
+        ffn_dim = state_dict["blocks.0.ffn.0.bias"].shape[0]
+        model_type = "i2v" if in_channels == 36 else "t2v"
+        num_heads = 40 if dim == 5120 else 12
+        num_layers = 40 if dim == 5120 else 30
+        TRANSFORMER_CONFIG= {
+            "dim": dim,
+            "ffn_dim": ffn_dim,
+            "eps": 1e-06,
+            "freq_dim": 256,
+            "in_dim": in_channels,
+            "model_type": model_type,
+            "out_dim": 16,
+            "text_len": 512,
+            "num_heads": num_heads,
+            "num_layers": num_layers,
+        }
+
+        with init_empty_weights():
+            self.model = WanModel(**TRANSFORMER_CONFIG)
         logging.info(f"Creating WanModel from {checkpoint_dir}")
-        self.model = WanModel.from_pretrained(checkpoint_dir)
+
+        base_dtype=torch.bfloat16
+        dtype=torch.float8_e4m3fn
+        params_to_keep = {"norm", "head", "bias", "time_in", "vector_in", "patch_embedding", "time_", "img_emb", "modulation"}
+        for name, param in self.model.named_parameters():
+            dtype_to_use = base_dtype if any(keyword in name for keyword in params_to_keep) else dtype
+            # dtype_to_use = torch.bfloat16
+            # print("Assigning Parameter name: ", name, " with dtype: ", dtype_to_use)
+            set_module_tensor_to_device(self.model, name, device='cpu', dtype=dtype_to_use, value=state_dict[name])
+
+        del state_dict
+        print("Model loaded successfully")
+
+        # self.model = WanModel.from_pretrained(checkpoint_dir)
         self.model.eval().requires_grad_(False)
 
         if t5_fsdp or dit_fsdp or use_usp:
@@ -242,9 +284,12 @@ class WanI2V:
                 torch.zeros(3, 80, h, w)
             ],
                          dim=1).to(self.device)
-        ])[0]
+        ],device=self.device)[0]
         y = torch.concat([msk, y])
 
+        if offload_model:
+            self.vae.model.cpu()
+
         @contextmanager
         def noop_no_sync():
             yield
@@ -334,7 +379,8 @@ class WanI2V:
                 torch.cuda.empty_cache()
 
         if self.rank == 0:
-            videos = self.vae.decode(x0)
+            self.vae.model.to(self.device)
+            videos = self.vae.decode(x0, device=self.device)
 
         del noise, latent
         del sample_scheduler
diff --git a/wan/modules/clip.py b/wan/modules/clip.py
index 42dda04..ceccafc 100644
--- a/wan/modules/clip.py
+++ b/wan/modules/clip.py
@@ -7,6 +7,7 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 import torchvision.transforms as T
+from safetensors.torch import load_file
 
 from .attention import flash_attention
 from .tokenizers import HuggingfaceTokenizer
@@ -515,8 +516,10 @@ class CLIPModel:
             device=device)
         self.model = self.model.eval().requires_grad_(False)
         logging.info(f'loading {checkpoint_path}')
-        self.model.load_state_dict(
-            torch.load(checkpoint_path, map_location='cpu'))
+        state_dict = load_file(checkpoint_path, device='cpu')
+        self.model.load_state_dict(state_dict)
+        # self.model.load_state_dict(
+        #     torch.load(checkpoint_path, map_location='cpu'))
 
         # init tokenizer
         self.tokenizer = HuggingfaceTokenizer(
diff --git a/wan/modules/vae.py b/wan/modules/vae.py
index 5c6da57..dd98b42 100644
--- a/wan/modules/vae.py
+++ b/wan/modules/vae.py
@@ -644,7 +644,7 @@ class WanVAE:
             z_dim=z_dim,
         ).eval().requires_grad_(False).to(device)
 
-    def encode(self, videos):
+    def encode(self, videos, device=None):
         """
         videos: A list of videos each with shape [C, T, H, W].
         """
@@ -654,7 +654,7 @@ class WanVAE:
                 for u in videos
             ]
 
-    def decode(self, zs):
+    def decode(self, zs, device=None):
         with amp.autocast(dtype=self.dtype):
             return [
                 self.model.decode(u.unsqueeze(0),

From bc2aff711e2334628dec9845f6c9066199b7d059 Mon Sep 17 00:00:00 2001
From: Yexiong Lin
Date: Thu, 27 Feb 2025 17:42:53 +1100
Subject: [PATCH 2/6] support fp8 model

---
 generate.py                |  6 +++
 wan/configs/wan_i2v_14B.py |  2 +-
 wan/image2video.py         | 77 ++++++++++++++++++++------------------
 wan/modules/clip.py        | 11 ++++--
 4 files changed, 55 insertions(+), 41 deletions(-)

diff --git a/generate.py b/generate.py
index 327918f..f8bd885 100644
--- a/generate.py
+++ b/generate.py
@@ -126,6 +126,11 @@ def _parse_args():
         action="store_true",
         default=False,
         help="Whether to use FSDP for DiT.")
+    parser.add_argument(
+        "--fp8",
+        action="store_true",
+        default=False,
+        help="Whether to use fp8.")
     parser.add_argument(
         "--save_file",
         type=str,
@@ -363,6 +368,7 @@ def generate(args):
         dit_fsdp=args.dit_fsdp,
         use_usp=(args.ulysses_size > 1 or args.ring_size > 1),
         t5_cpu=args.t5_cpu,
+        fp8=args.fp8,
     )
 
     logging.info("Generating video ...")
diff --git a/wan/configs/wan_i2v_14B.py b/wan/configs/wan_i2v_14B.py
index b2ab367..ccfd1af 100644
--- a/wan/configs/wan_i2v_14B.py
+++ b/wan/configs/wan_i2v_14B.py
@@ -16,7 +16,7 @@ i2v_14B.t5_tokenizer = 'google/umt5-xxl'
 i2v_14B.clip_model = 'clip_xlm_roberta_vit_h_14'
 i2v_14B.clip_dtype = torch.float16
 # i2v_14B.clip_checkpoint = 'models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth'
-i2v_14B.clip_checkpoint = 'open-clip-xlm-roberta-large-vit-huge-14_fp16.safetensors'
+i2v_14B.clip_checkpoint = 'open-clip-xlm-roberta-large-vit-huge-14_fp16.safetensors' # Kijai's fp16 model
 i2v_14B.clip_tokenizer = 'xlm-roberta-large'
 
 # vae
diff --git a/wan/image2video.py b/wan/image2video.py
index df63624..8b9e653 100644
--- a/wan/image2video.py
+++ b/wan/image2video.py
@@ -43,6 +43,7 @@ class WanI2V:
         use_usp=False,
         t5_cpu=False,
         init_on_cpu=True,
+        fp8=False,
     ):
         r"""
         Initializes the image-to-video generation model components.
@@ -66,6 +67,8 @@ class WanI2V:
             Whether to place T5 model on CPU. Only works without t5_fsdp.
         init_on_cpu (`bool`, *optional*, defaults to True):
             Enable initializing Transformer Model on CPU. Only works without FSDP or USP.
+        fp8 (`bool`, *optional*, defaults to False):
+            Enable 8-bit floating point precision for model parameters.
""" self.device = torch.device(f"cuda:{device_id}") self.config = config @@ -88,7 +91,6 @@ class WanI2V: self.vae_stride = config.vae_stride self.patch_size = config.patch_size - print('device:', self.device) self.vae = WanVAE( vae_pth=os.path.join(checkpoint_dir, config.vae_checkpoint), @@ -101,44 +103,47 @@ class WanI2V: config.clip_checkpoint), tokenizer_path=os.path.join(checkpoint_dir, config.clip_tokenizer)) - state_dict = load_file(checkpoint_dir+'/Wan2_1-I2V-14B-480P_fp8_e4m3fn.safetensors', device="cpu") - dim = state_dict["patch_embedding.weight"].shape[0] - in_channels = state_dict["patch_embedding.weight"].shape[1] - print("in_channels: ", in_channels) - ffn_dim = state_dict["blocks.0.ffn.0.bias"].shape[0] - model_type = "i2v" if in_channels == 36 else "t2v" - num_heads = 40 if dim == 5120 else 12 - num_layers = 40 if dim == 5120 else 30 - TRANSFORMER_CONFIG= { - "dim": dim, - "ffn_dim": ffn_dim, - "eps": 1e-06, - "freq_dim": 256, - "in_dim": in_channels, - "model_type": model_type, - "out_dim": 16, - "text_len": 512, - "num_heads": num_heads, - "num_layers": num_layers, - } + if not fp8: + self.model = WanModel.from_pretrained(checkpoint_dir) + else: + if '480P' in checkpoint_dir: + state_dict = load_file(checkpoint_dir+'/Wan2_1-I2V-14B-480P_fp8_e4m3fn.safetensors', device="cpu") + elif '720P' in checkpoint_dir: + state_dict = load_file(checkpoint_dir+'/Wan2_1-I2V-14B-720P_fp8_e4m3fn.safetensors', device="cpu") + dim = state_dict["patch_embedding.weight"].shape[0] + in_channels = state_dict["patch_embedding.weight"].shape[1] + ffn_dim = state_dict["blocks.0.ffn.0.bias"].shape[0] + model_type = "i2v" if in_channels == 36 else "t2v" + num_heads = 40 if dim == 5120 else 12 + num_layers = 40 if dim == 5120 else 30 + TRANSFORMER_CONFIG= { + "dim": dim, + "ffn_dim": ffn_dim, + "eps": 1e-06, + "freq_dim": 256, + "in_dim": in_channels, + "model_type": model_type, + "out_dim": 16, + "text_len": 512, + "num_heads": num_heads, + "num_layers": num_layers, + } - with init_empty_weights(): - self.model = WanModel(**TRANSFORMER_CONFIG) - logging.info(f"Creating WanModel from {checkpoint_dir}") - - base_dtype=torch.bfloat16 - dtype=torch.float8_e4m3fn - params_to_keep = {"norm", "head", "bias", "time_in", "vector_in", "patch_embedding", "time_", "img_emb", "modulation"} - for name, param in self.model.named_parameters(): - dtype_to_use = base_dtype if any(keyword in name for keyword in params_to_keep) else dtype - # dtype_to_use = torch.bfloat16 - # print("Assigning Parameter name: ", name, " with dtype: ", dtype_to_use) - set_module_tensor_to_device(self.model, name, device='cpu', dtype=dtype_to_use, value=state_dict[name]) + with init_empty_weights(): + self.model = WanModel(**TRANSFORMER_CONFIG) + logging.info(f"Creating WanModel from {checkpoint_dir}") + + base_dtype=torch.bfloat16 + dtype=torch.float8_e4m3fn + params_to_keep = {"norm", "head", "bias", "time_in", "vector_in", "patch_embedding", "time_", "img_emb", "modulation"} + for name, param in self.model.named_parameters(): + dtype_to_use = base_dtype if any(keyword in name for keyword in params_to_keep) else dtype + # dtype_to_use = torch.bfloat16 + # print("Assigning Parameter name: ", name, " with dtype: ", dtype_to_use) + set_module_tensor_to_device(self.model, name, device='cpu', dtype=dtype_to_use, value=state_dict[name]) - del state_dict - print("Model loaded successfully") + del state_dict - # self.model = WanModel.from_pretrained(checkpoint_dir) self.model.eval().requires_grad_(False) if t5_fsdp or dit_fsdp or use_usp: diff 
--git a/wan/modules/clip.py b/wan/modules/clip.py index ceccafc..619cf63 100644 --- a/wan/modules/clip.py +++ b/wan/modules/clip.py @@ -516,10 +516,13 @@ class CLIPModel: device=device) self.model = self.model.eval().requires_grad_(False) logging.info(f'loading {checkpoint_path}') - state_dict = load_file(checkpoint_path, device='cpu') - self.model.load_state_dict(state_dict) - # self.model.load_state_dict( - # torch.load(checkpoint_path, map_location='cpu')) + if checkpoint_path.endswith('.safetensors'): + state_dict = load_file(checkpoint_path, device='cpu') + self.model.load_state_dict(state_dict) + elif checkpoint_path.endswith('.pth'): + self.model.load_state_dict(torch.load(checkpoint_path, map_location='cpu')) + else: + raise ValueError(f'Unsupported checkpoint file format: {checkpoint_path}') # init tokenizer self.tokenizer = HuggingfaceTokenizer( From 24007c2c3924df39aab0feeb9ea67812d8de10d0 Mon Sep 17 00:00:00 2001 From: Yexiong Lin Date: Thu, 27 Feb 2025 17:46:05 +1100 Subject: [PATCH 3/6] support fp8 model --- wan/configs/wan_i2v_14B.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/wan/configs/wan_i2v_14B.py b/wan/configs/wan_i2v_14B.py index ccfd1af..61f27cd 100644 --- a/wan/configs/wan_i2v_14B.py +++ b/wan/configs/wan_i2v_14B.py @@ -15,8 +15,8 @@ i2v_14B.t5_tokenizer = 'google/umt5-xxl' # clip i2v_14B.clip_model = 'clip_xlm_roberta_vit_h_14' i2v_14B.clip_dtype = torch.float16 -# i2v_14B.clip_checkpoint = 'models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth' -i2v_14B.clip_checkpoint = 'open-clip-xlm-roberta-large-vit-huge-14_fp16.safetensors' # Kijai's fp16 model +i2v_14B.clip_checkpoint = 'models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth' +# i2v_14B.clip_checkpoint = 'open-clip-xlm-roberta-large-vit-huge-14_fp16.safetensors' # Kijai's fp16 model i2v_14B.clip_tokenizer = 'xlm-roberta-large' # vae From db54b7c613d57329b2299e03809d9bee151f315f Mon Sep 17 00:00:00 2001 From: Yexiong Lin Date: Fri, 28 Feb 2025 16:35:02 +1100 Subject: [PATCH 4/6] Update README.md and text2video.py to offload model and enable using fp8 --- README.md | 13 +++++++++ generate.py | 1 + wan/image2video.py | 5 ++-- wan/text2video.py | 66 ++++++++++++++++++++++++++++++++++++++++++++-- 4 files changed, 81 insertions(+), 4 deletions(-) diff --git a/README.md b/README.md index 0511a54..697a266 100644 --- a/README.md +++ b/README.md @@ -135,6 +135,8 @@ If you encounter OOM (Out-of-Memory) issues, you can use the `--offload_model Tr python generate.py --task t2v-1.3B --size 832*480 --ckpt_dir ./Wan2.1-T2V-1.3B --offload_model True --t5_cpu --sample_shift 8 --sample_guide_scale 6 --prompt "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage." ``` +You can also use the `--fp8` option to enable FP8 precision for reduced memory usage. Make sure to download the [FP8 model weight](https://huggingface.co/Kijai/WanVideo_comfy/resolve/main/Wan2_1-T2V-1_3B_fp8_e4m3fn.safetensors) and place it in the `Wan2.1-T2V-1.3B` folder. + > 💡Note: If you are using the `T2V-1.3B` model, we recommend setting the parameter `--sample_guide_scale 6`. The `--sample_shift parameter` can be adjusted within the range of 8 to 12 based on the performance. @@ -222,6 +224,17 @@ Similar to Text-to-Video, Image-to-Video is also divided into processes with and python generate.py --task i2v-14B --size 1280*720 --ckpt_dir ./Wan2.1-I2V-14B-720P --image examples/i2v_input.JPG --prompt "Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard. 
The fluffy-furred feline gazes directly at the camera with a relaxed expression. Blurred beach scenery forms the background featuring crystal-clear waters, distant green hills, and a blue sky dotted with white clouds. The cat assumes a naturally relaxed posture, as if savoring the sea breeze and warm sunlight. A close-up shot highlights the feline's intricate details and the refreshing atmosphere of the seaside." ``` +To minimize GPU memory usage, you can enable model offloading with `--offload_model True` and use FP8 precision with `--fp8`. + +For example, to run **Wan2.1-I2V-14B-480P** on an RTX 4090 GPU: + +1. First, download the [FP8 model weights](https://huggingface.co/Kijai/WanVideo_comfy/resolve/main/Wan2_1-I2V-14B-480P_fp8_e4m3fn.safetensors) and place them in the `Wan2.1-I2V-14B-480P` folder. +2. Then, execute the following command: + +``` +python generate.py --task i2v-14B --size 832*480 --ckpt_dir ./Wan2.1-I2V-14B-480P --offload_model True --fp8 --image examples/i2v_input.JPG --prompt "Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard. The fluffy-furred feline gazes directly at the camera with a relaxed expression. Blurred beach scenery forms the background featuring crystal-clear waters, distant green hills, and a blue sky dotted with white clouds. The cat assumes a naturally relaxed posture, as if savoring the sea breeze and warm sunlight. A close-up shot highlights the feline's intricate details and the refreshing atmosphere of the seaside." +``` + > 💡For the Image-to-Video task, the `size` parameter represents the area of the generated video, with the aspect ratio following that of the original input image. diff --git a/generate.py b/generate.py index f8bd885..8dd8e05 100644 --- a/generate.py +++ b/generate.py @@ -311,6 +311,7 @@ def generate(args): dit_fsdp=args.dit_fsdp, use_usp=(args.ulysses_size > 1 or args.ring_size > 1), t5_cpu=args.t5_cpu, + fp8=args.fp8, ) logging.info( diff --git a/wan/image2video.py b/wan/image2video.py index 8b9e653..84eb6f7 100644 --- a/wan/image2video.py +++ b/wan/image2video.py @@ -103,6 +103,7 @@ class WanI2V: config.clip_checkpoint), tokenizer_path=os.path.join(checkpoint_dir, config.clip_tokenizer)) + logging.info(f"Creating WanModel from {checkpoint_dir}") if not fp8: self.model = WanModel.from_pretrained(checkpoint_dir) else: @@ -131,7 +132,6 @@ class WanI2V: with init_empty_weights(): self.model = WanModel(**TRANSFORMER_CONFIG) - logging.info(f"Creating WanModel from {checkpoint_dir}") base_dtype=torch.bfloat16 dtype=torch.float8_e4m3fn @@ -382,9 +382,10 @@ class WanI2V: if offload_model: self.model.cpu() torch.cuda.empty_cache() + # load vae model back to device + self.vae.model.to(self.device) if self.rank == 0: - self.vae.model.to(self.device) videos = self.vae.decode(x0, device=self.device) del noise, latent diff --git a/wan/text2video.py b/wan/text2video.py index 2400545..7b0bf3d 100644 --- a/wan/text2video.py +++ b/wan/text2video.py @@ -14,6 +14,10 @@ import torch.cuda.amp as amp import torch.distributed as dist from tqdm import tqdm +from accelerate import init_empty_weights +from accelerate.utils import set_module_tensor_to_device +from safetensors.torch import load_file + from .distributed.fsdp import shard_model from .modules.model import WanModel from .modules.t5 import T5EncoderModel @@ -35,6 +39,8 @@ class WanT2V: dit_fsdp=False, use_usp=False, t5_cpu=False, + init_on_cpu=True, + fp8=False, ): r""" Initializes the Wan text-to-video generation model components. 
@@ -56,6 +62,8 @@ class WanT2V: Enable distribution strategy of USP. t5_cpu (`bool`, *optional*, defaults to False): Whether to place T5 model on CPU. Only works without t5_fsdp. + fp8 (`bool`, *optional*, defaults to False): + Enable 8-bit floating point precision for model parameters. """ self.device = torch.device(f"cuda:{device_id}") self.config = config @@ -81,9 +89,52 @@ class WanT2V: device=self.device) logging.info(f"Creating WanModel from {checkpoint_dir}") - self.model = WanModel.from_pretrained(checkpoint_dir) + if not fp8: + self.model = WanModel.from_pretrained(checkpoint_dir) + else: + if '14B' in checkpoint_dir: + state_dict = load_file(checkpoint_dir+'/Wan2_1-T2V-14B_fp8_e4m3fn.safetensors', device="cpu") + else: + state_dict = load_file(checkpoint_dir+'/Wan2_1-T2V-1_3B_fp8_e4m3fn.safetensors', device="cpu") + + dim = state_dict["patch_embedding.weight"].shape[0] + in_channels = state_dict["patch_embedding.weight"].shape[1] + ffn_dim = state_dict["blocks.0.ffn.0.bias"].shape[0] + model_type = "i2v" if in_channels == 36 else "t2v" + num_heads = 40 if dim == 5120 else 12 + num_layers = 40 if dim == 5120 else 30 + TRANSFORMER_CONFIG= { + "dim": dim, + "ffn_dim": ffn_dim, + "eps": 1e-06, + "freq_dim": 256, + "in_dim": in_channels, + "model_type": model_type, + "out_dim": 16, + "text_len": 512, + "num_heads": num_heads, + "num_layers": num_layers, + } + + with init_empty_weights(): + self.model = WanModel(**TRANSFORMER_CONFIG) + + base_dtype=torch.bfloat16 + dtype=torch.float8_e4m3fn + params_to_keep = {"norm", "head", "bias", "time_in", "vector_in", "patch_embedding", "time_", "img_emb", "modulation"} + for name, param in self.model.named_parameters(): + dtype_to_use = base_dtype if any(keyword in name for keyword in params_to_keep) else dtype + # dtype_to_use = torch.bfloat16 + # print("Assigning Parameter name: ", name, " with dtype: ", dtype_to_use) + set_module_tensor_to_device(self.model, name, device='cpu', dtype=dtype_to_use, value=state_dict[name]) + + del state_dict + self.model.eval().requires_grad_(False) + if t5_fsdp or dit_fsdp or use_usp: + init_on_cpu = False + if use_usp: from xfuser.core.distributed import \ get_sequence_parallel_world_size @@ -103,7 +154,8 @@ class WanT2V: if dit_fsdp: self.model = shard_fn(self.model) else: - self.model.to(self.device) + if not init_on_cpu: + self.model.to(self.device) self.sample_neg_prompt = config.sample_neg_prompt @@ -190,6 +242,9 @@ class WanT2V: generator=seed_g) ] + if offload_model: + self.vae.model.cpu() + @contextmanager def noop_no_sync(): yield @@ -226,6 +281,10 @@ class WanT2V: arg_c = {'context': context, 'seq_len': seq_len} arg_null = {'context': context_null, 'seq_len': seq_len} + if offload_model: + torch.cuda.empty_cache() + + self.model.to(self.device) for _, t in enumerate(tqdm(timesteps)): latent_model_input = latents timestep = [t] @@ -253,6 +312,9 @@ class WanT2V: if offload_model: self.model.cpu() torch.cuda.empty_cache() + # load vae model back to device + self.vae.model.to(self.device) + if self.rank == 0: videos = self.vae.decode(x0) From 1c7b73d13e909a958f6d9d80f5140c8dbcf1d358 Mon Sep 17 00:00:00 2001 From: Yexiong Lin Date: Fri, 28 Feb 2025 21:21:09 +1100 Subject: [PATCH 5/6] Add the support for fp8 t5 --- README.md | 6 ++++++ wan/configs/wan_i2v_14B.py | 1 + wan/configs/wan_t2v_14B.py | 1 + wan/configs/wan_t2v_1_3B.py | 1 + wan/image2video.py | 15 +++++++++++---- wan/modules/t5.py | 38 +++++++++++++++++++++++++++++-------- wan/text2video.py | 18 +++++++++++++----- 7 files changed, 63 
insertions(+), 17 deletions(-) diff --git a/README.md b/README.md index 697a266..bdaa586 100644 --- a/README.md +++ b/README.md @@ -137,6 +137,12 @@ python generate.py --task t2v-1.3B --size 832*480 --ckpt_dir ./Wan2.1-T2V-1.3B You can also use the `--fp8` option to enable FP8 precision for reduced memory usage. Make sure to download the [FP8 model weight](https://huggingface.co/Kijai/WanVideo_comfy/resolve/main/Wan2_1-T2V-1_3B_fp8_e4m3fn.safetensors) and place it in the `Wan2.1-T2V-1.3B` folder. +Additionally, an [FP8 version of the T5 model](https://huggingface.co/Kijai/WanVideo_comfy/resolve/main/umt5-xxl-enc-fp8_e4m3fn.safetensors) is available. To use the FP8 T5 model, update the configuration file: + +``` +t2v_1_3B.t5_checkpoint = 'umt5-xxl-enc-fp8_e4m3fn.safetensors' +``` + > 💡Note: If you are using the `T2V-1.3B` model, we recommend setting the parameter `--sample_guide_scale 6`. The `--sample_shift parameter` can be adjusted within the range of 8 to 12 based on the performance. diff --git a/wan/configs/wan_i2v_14B.py b/wan/configs/wan_i2v_14B.py index 61f27cd..abab971 100644 --- a/wan/configs/wan_i2v_14B.py +++ b/wan/configs/wan_i2v_14B.py @@ -10,6 +10,7 @@ i2v_14B = EasyDict(__name__='Config: Wan I2V 14B') i2v_14B.update(wan_shared_cfg) i2v_14B.t5_checkpoint = 'models_t5_umt5-xxl-enc-bf16.pth' +# i2v_14B.t5_checkpoint = 'umt5-xxl-enc-fp8_e4m3fn.safetensors' # fp8 model i2v_14B.t5_tokenizer = 'google/umt5-xxl' # clip diff --git a/wan/configs/wan_t2v_14B.py b/wan/configs/wan_t2v_14B.py index 9d0ee69..513c863 100644 --- a/wan/configs/wan_t2v_14B.py +++ b/wan/configs/wan_t2v_14B.py @@ -10,6 +10,7 @@ t2v_14B.update(wan_shared_cfg) # t5 t2v_14B.t5_checkpoint = 'models_t5_umt5-xxl-enc-bf16.pth' +# t2v_14B.t5_checkpoint = 'umt5-xxl-enc-fp8_e4m3fn.safetensors' # fp8 model t2v_14B.t5_tokenizer = 'google/umt5-xxl' # vae diff --git a/wan/configs/wan_t2v_1_3B.py b/wan/configs/wan_t2v_1_3B.py index ea9502b..b88c30a 100644 --- a/wan/configs/wan_t2v_1_3B.py +++ b/wan/configs/wan_t2v_1_3B.py @@ -10,6 +10,7 @@ t2v_1_3B.update(wan_shared_cfg) # t5 t2v_1_3B.t5_checkpoint = 'models_t5_umt5-xxl-enc-bf16.pth' +# t2v_1_3B.t5_checkpoint = 'umt5-xxl-enc-fp8_e4m3fn.safetensors' # fp8 model t2v_1_3B.t5_tokenizer = 'google/umt5-xxl' # vae diff --git a/wan/image2video.py b/wan/image2video.py index 84eb6f7..6e0e08c 100644 --- a/wan/image2video.py +++ b/wan/image2video.py @@ -80,6 +80,10 @@ class WanI2V: self.param_dtype = config.param_dtype shard_fn = partial(shard_model, device_id=device_id) + if config.t5_checkpoint == 'umt5-xxl-enc-fp8_e4m3fn.safetensors': + quantization = "fp8_e4m3fn" + else: + quantization = "disabled" self.text_encoder = T5EncoderModel( text_len=config.text_len, dtype=config.t5_dtype, @@ -87,6 +91,7 @@ class WanI2V: checkpoint_path=os.path.join(checkpoint_dir, config.t5_checkpoint), tokenizer_path=os.path.join(checkpoint_dir, config.t5_tokenizer), shard_fn=shard_fn if t5_fsdp else None, + quantization=quantization, ) self.vae_stride = config.vae_stride @@ -266,13 +271,15 @@ class WanI2V: # preprocess if not self.t5_cpu: self.text_encoder.model.to(self.device) - context = self.text_encoder([input_prompt], self.device) - context_null = self.text_encoder([n_prompt], self.device) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + context = self.text_encoder([input_prompt], self.device) + context_null = self.text_encoder([n_prompt], self.device) if offload_model: self.text_encoder.model.cpu() else: - context = self.text_encoder([input_prompt], 
torch.device('cpu')) - context_null = self.text_encoder([n_prompt], torch.device('cpu')) + with torch.autocast(device_type="cpu", dtype=torch.bfloat16, enabled=True): + context = self.text_encoder([input_prompt], torch.device('cpu')) + context_null = self.text_encoder([n_prompt], torch.device('cpu')) context = [t.to(self.device) for t in context] context_null = [t.to(self.device) for t in context_null] diff --git a/wan/modules/t5.py b/wan/modules/t5.py index c841b04..8960f34 100644 --- a/wan/modules/t5.py +++ b/wan/modules/t5.py @@ -9,6 +9,10 @@ import torch.nn.functional as F from .tokenizers import HuggingfaceTokenizer +from accelerate import init_empty_weights +from accelerate.utils import set_module_tensor_to_device +from safetensors.torch import load_file + __all__ = [ 'T5Model', 'T5Encoder', @@ -442,7 +446,7 @@ def _t5(name, model = model_cls(**kwargs) # set device - model = model.to(dtype=dtype, device=device) + # model = model.to(dtype=dtype, device=device) # init tokenizer if return_tokenizer: @@ -479,6 +483,7 @@ class T5EncoderModel: checkpoint_path=None, tokenizer_path=None, shard_fn=None, + quantization="disabled", ): self.text_len = text_len self.dtype = dtype @@ -486,14 +491,31 @@ class T5EncoderModel: self.checkpoint_path = checkpoint_path self.tokenizer_path = tokenizer_path - # init model - model = umt5_xxl( - encoder_only=True, - return_tokenizer=False, - dtype=dtype, - device=device).eval().requires_grad_(False) + logging.info(f'loading {checkpoint_path}') - model.load_state_dict(torch.load(checkpoint_path, map_location='cpu')) + if quantization == "disabled": + # init model + model = umt5_xxl( + encoder_only=True, + return_tokenizer=False, + dtype=dtype, + device=device).eval().requires_grad_(False) + model.load_state_dict(torch.load(checkpoint_path, map_location='cpu')) + elif quantization == "fp8_e4m3fn": + with init_empty_weights(): + model = umt5_xxl( + encoder_only=True, + return_tokenizer=False, + dtype=dtype, + device=device).eval().requires_grad_(False) + cast_dtype = torch.float8_e4m3fn + state_dict = load_file(checkpoint_path, device="cpu") + params_to_keep = {'norm', 'pos_embedding', 'token_embedding'} + for name, param in model.named_parameters(): + dtype_to_use = dtype if any(keyword in name for keyword in params_to_keep) else cast_dtype + set_module_tensor_to_device(model, name, device=device, dtype=dtype_to_use, value=state_dict[name]) + del state_dict + self.model = model if shard_fn is not None: self.model = shard_fn(self.model, sync_module_states=False) diff --git a/wan/text2video.py b/wan/text2video.py index 7b0bf3d..4460b8b 100644 --- a/wan/text2video.py +++ b/wan/text2video.py @@ -74,13 +74,19 @@ class WanT2V: self.param_dtype = config.param_dtype shard_fn = partial(shard_model, device_id=device_id) + if config.t5_checkpoint == 'umt5-xxl-enc-fp8_e4m3fn.safetensors': + quantization = "fp8_e4m3fn" + else: + quantization = "disabled" + self.text_encoder = T5EncoderModel( text_len=config.text_len, dtype=config.t5_dtype, device=torch.device('cpu'), checkpoint_path=os.path.join(checkpoint_dir, config.t5_checkpoint), tokenizer_path=os.path.join(checkpoint_dir, config.t5_tokenizer), - shard_fn=shard_fn if t5_fsdp else None) + shard_fn=shard_fn if t5_fsdp else None, + quantization=quantization) self.vae_stride = config.vae_stride self.patch_size = config.patch_size @@ -221,13 +227,15 @@ class WanT2V: if not self.t5_cpu: self.text_encoder.model.to(self.device) - context = self.text_encoder([input_prompt], self.device) - context_null = 
self.text_encoder([n_prompt], self.device) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + context = self.text_encoder([input_prompt], self.device) + context_null = self.text_encoder([n_prompt], self.device) if offload_model: self.text_encoder.model.cpu() else: - context = self.text_encoder([input_prompt], torch.device('cpu')) - context_null = self.text_encoder([n_prompt], torch.device('cpu')) + with torch.autocast(device_type="cpu", dtype=torch.bfloat16, enabled=True): + context = self.text_encoder([input_prompt], torch.device('cpu')) + context_null = self.text_encoder([n_prompt], torch.device('cpu')) context = [t.to(self.device) for t in context] context_null = [t.to(self.device) for t in context_null] From 36d6d91b9043ff42f22f9f78647747f83fcb5538 Mon Sep 17 00:00:00 2001 From: Yexiong Lin Date: Tue, 4 Mar 2025 19:54:54 +1100 Subject: [PATCH 6/6] update text2video.py --- wan/text2video.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/wan/text2video.py b/wan/text2video.py index 4460b8b..99716bf 100644 --- a/wan/text2video.py +++ b/wan/text2video.py @@ -298,8 +298,6 @@ class WanT2V: timestep = [t] timestep = torch.stack(timestep) - - self.model.to(self.device) noise_pred_cond = self.model( latent_model_input, t=timestep, **arg_c)[0] noise_pred_uncond = self.model(
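The same fp8 weight-loading pattern recurs in wan/image2video.py, wan/text2video.py and wan/modules/t5.py above: load Kijai's fp8 safetensors state dict on CPU, build the module under init_empty_weights(), then assign each tensor with set_module_tensor_to_device(), keeping norm/bias/embedding-style parameters in bf16 and casting the rest to float8_e4m3fn. A rough standalone sketch of that pattern is shown below; the helper name and its arguments are illustrative only and not part of these patches.

```
import torch
from accelerate import init_empty_weights
from accelerate.utils import set_module_tensor_to_device
from safetensors.torch import load_file

from wan.modules.model import WanModel

# Parameter-name keywords kept in bf16 (same set as in the patches above);
# everything else is stored as fp8 to roughly halve the weight memory.
PARAMS_TO_KEEP = {"norm", "head", "bias", "time_in", "vector_in",
                  "patch_embedding", "time_", "img_emb", "modulation"}


def load_wan_fp8(ckpt_path, transformer_config,
                 base_dtype=torch.bfloat16, fp8_dtype=torch.float8_e4m3fn):
    # Illustrative helper, not part of the patches.
    state_dict = load_file(ckpt_path, device="cpu")

    # Build the transformer on the meta device so no weight memory is allocated yet.
    with init_empty_weights():
        model = WanModel(**transformer_config)

    # Materialize each parameter from the checkpoint with a per-parameter dtype.
    for name, _ in model.named_parameters():
        dtype = base_dtype if any(k in name for k in PARAMS_TO_KEEP) else fp8_dtype
        set_module_tensor_to_device(model, name, device="cpu",
                                    dtype=dtype, value=state_dict[name])

    del state_dict
    return model.eval().requires_grad_(False)
```

The fp8 branches added to WanI2V and WanT2V follow this shape, differing only in how the checkpoint filename and TRANSFORMER_CONFIG are derived from the loaded state dict.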