diff --git a/README.md b/README.md
index 697a266..bdaa586 100644
--- a/README.md
+++ b/README.md
@@ -137,6 +137,12 @@ python generate.py --task t2v-1.3B --size 832*480 --ckpt_dir ./Wan2.1-T2V-1.3B
 You can also use the `--fp8` option to enable FP8 precision for reduced memory usage. Make sure to download the [FP8 model weight](https://huggingface.co/Kijai/WanVideo_comfy/resolve/main/Wan2_1-T2V-1_3B_fp8_e4m3fn.safetensors) and place it in the `Wan2.1-T2V-1.3B` folder.
 
+Additionally, an [FP8 version of the T5 model](https://huggingface.co/Kijai/WanVideo_comfy/resolve/main/umt5-xxl-enc-fp8_e4m3fn.safetensors) is available. To use the FP8 T5 model, update the configuration file:
+
+```
+t2v_1_3B.t5_checkpoint = 'umt5-xxl-enc-fp8_e4m3fn.safetensors'
+```
+
 > 💡Note: If you are using the `T2V-1.3B` model, we recommend setting the parameter `--sample_guide_scale 6`. The `--sample_shift parameter` can be adjusted within the range of 8 to 12 based on the performance.
diff --git a/wan/configs/wan_i2v_14B.py b/wan/configs/wan_i2v_14B.py
index 61f27cd..abab971 100644
--- a/wan/configs/wan_i2v_14B.py
+++ b/wan/configs/wan_i2v_14B.py
@@ -10,6 +10,7 @@ i2v_14B = EasyDict(__name__='Config: Wan I2V 14B')
 i2v_14B.update(wan_shared_cfg)
 
 i2v_14B.t5_checkpoint = 'models_t5_umt5-xxl-enc-bf16.pth'
+# i2v_14B.t5_checkpoint = 'umt5-xxl-enc-fp8_e4m3fn.safetensors' # fp8 model
 i2v_14B.t5_tokenizer = 'google/umt5-xxl'
 
 # clip
diff --git a/wan/configs/wan_t2v_14B.py b/wan/configs/wan_t2v_14B.py
index 9d0ee69..513c863 100644
--- a/wan/configs/wan_t2v_14B.py
+++ b/wan/configs/wan_t2v_14B.py
@@ -10,6 +10,7 @@ t2v_14B.update(wan_shared_cfg)
 
 # t5
 t2v_14B.t5_checkpoint = 'models_t5_umt5-xxl-enc-bf16.pth'
+# t2v_14B.t5_checkpoint = 'umt5-xxl-enc-fp8_e4m3fn.safetensors' # fp8 model
 t2v_14B.t5_tokenizer = 'google/umt5-xxl'
 
 # vae
diff --git a/wan/configs/wan_t2v_1_3B.py b/wan/configs/wan_t2v_1_3B.py
index ea9502b..b88c30a 100644
--- a/wan/configs/wan_t2v_1_3B.py
+++ b/wan/configs/wan_t2v_1_3B.py
@@ -10,6 +10,7 @@ t2v_1_3B.update(wan_shared_cfg)
 
 # t5
 t2v_1_3B.t5_checkpoint = 'models_t5_umt5-xxl-enc-bf16.pth'
+# t2v_1_3B.t5_checkpoint = 'umt5-xxl-enc-fp8_e4m3fn.safetensors' # fp8 model
 t2v_1_3B.t5_tokenizer = 'google/umt5-xxl'
 
 # vae
diff --git a/wan/image2video.py b/wan/image2video.py
index 84eb6f7..6e0e08c 100644
--- a/wan/image2video.py
+++ b/wan/image2video.py
@@ -80,6 +80,10 @@ class WanI2V:
         self.param_dtype = config.param_dtype
 
         shard_fn = partial(shard_model, device_id=device_id)
+        if config.t5_checkpoint == 'umt5-xxl-enc-fp8_e4m3fn.safetensors':
+            quantization = "fp8_e4m3fn"
+        else:
+            quantization = "disabled"
         self.text_encoder = T5EncoderModel(
             text_len=config.text_len,
             dtype=config.t5_dtype,
@@ -87,6 +91,7 @@ class WanI2V:
             checkpoint_path=os.path.join(checkpoint_dir, config.t5_checkpoint),
             tokenizer_path=os.path.join(checkpoint_dir, config.t5_tokenizer),
             shard_fn=shard_fn if t5_fsdp else None,
+            quantization=quantization,
         )
 
         self.vae_stride = config.vae_stride
@@ -266,13 +271,15 @@ class WanI2V:
 
         # preprocess
         if not self.t5_cpu:
             self.text_encoder.model.to(self.device)
-            context = self.text_encoder([input_prompt], self.device)
-            context_null = self.text_encoder([n_prompt], self.device)
+            with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True):
+                context = self.text_encoder([input_prompt], self.device)
+                context_null = self.text_encoder([n_prompt], self.device)
             if offload_model:
                 self.text_encoder.model.cpu()
         else:
-            context = self.text_encoder([input_prompt], torch.device('cpu'))
-            context_null = self.text_encoder([n_prompt], torch.device('cpu'))
+            with torch.autocast(device_type="cpu", dtype=torch.bfloat16, enabled=True):
+                context = self.text_encoder([input_prompt], torch.device('cpu'))
+                context_null = self.text_encoder([n_prompt], torch.device('cpu'))
             context = [t.to(self.device) for t in context]
             context_null = [t.to(self.device) for t in context_null]
diff --git a/wan/modules/t5.py b/wan/modules/t5.py
index c841b04..8960f34 100644
--- a/wan/modules/t5.py
+++ b/wan/modules/t5.py
@@ -9,6 +9,10 @@ import torch.nn.functional as F
 
 from .tokenizers import HuggingfaceTokenizer
 
+from accelerate import init_empty_weights
+from accelerate.utils import set_module_tensor_to_device
+from safetensors.torch import load_file
+
 __all__ = [
     'T5Model',
     'T5Encoder',
@@ -442,7 +446,7 @@ def _t5(name,
     model = model_cls(**kwargs)
 
     # set device
-    model = model.to(dtype=dtype, device=device)
+    # model = model.to(dtype=dtype, device=device)
 
     # init tokenizer
     if return_tokenizer:
@@ -479,6 +483,7 @@ class T5EncoderModel:
         checkpoint_path=None,
         tokenizer_path=None,
         shard_fn=None,
+        quantization="disabled",
     ):
         self.text_len = text_len
         self.dtype = dtype
@@ -486,14 +491,31 @@ class T5EncoderModel:
         self.checkpoint_path = checkpoint_path
         self.tokenizer_path = tokenizer_path
 
-        # init model
-        model = umt5_xxl(
-            encoder_only=True,
-            return_tokenizer=False,
-            dtype=dtype,
-            device=device).eval().requires_grad_(False)
+
         logging.info(f'loading {checkpoint_path}')
-        model.load_state_dict(torch.load(checkpoint_path, map_location='cpu'))
+        if quantization == "disabled":
+            # init model
+            model = umt5_xxl(
+                encoder_only=True,
+                return_tokenizer=False,
+                dtype=dtype,
+                device=device).eval().requires_grad_(False)
+            model.load_state_dict(torch.load(checkpoint_path, map_location='cpu'))
+        elif quantization == "fp8_e4m3fn":
+            with init_empty_weights():
+                model = umt5_xxl(
+                    encoder_only=True,
+                    return_tokenizer=False,
+                    dtype=dtype,
+                    device=device).eval().requires_grad_(False)
+            cast_dtype = torch.float8_e4m3fn
+            state_dict = load_file(checkpoint_path, device="cpu")
+            params_to_keep = {'norm', 'pos_embedding', 'token_embedding'}
+            for name, param in model.named_parameters():
+                dtype_to_use = dtype if any(keyword in name for keyword in params_to_keep) else cast_dtype
+                set_module_tensor_to_device(model, name, device=device, dtype=dtype_to_use, value=state_dict[name])
+            del state_dict
+
         self.model = model
         if shard_fn is not None:
             self.model = shard_fn(self.model, sync_module_states=False)
diff --git a/wan/text2video.py b/wan/text2video.py
index 7b0bf3d..4460b8b 100644
--- a/wan/text2video.py
+++ b/wan/text2video.py
@@ -74,13 +74,19 @@ class WanT2V:
         self.param_dtype = config.param_dtype
 
         shard_fn = partial(shard_model, device_id=device_id)
+        if config.t5_checkpoint == 'umt5-xxl-enc-fp8_e4m3fn.safetensors':
+            quantization = "fp8_e4m3fn"
+        else:
+            quantization = "disabled"
+
         self.text_encoder = T5EncoderModel(
             text_len=config.text_len,
             dtype=config.t5_dtype,
             device=torch.device('cpu'),
             checkpoint_path=os.path.join(checkpoint_dir, config.t5_checkpoint),
             tokenizer_path=os.path.join(checkpoint_dir, config.t5_tokenizer),
-            shard_fn=shard_fn if t5_fsdp else None)
+            shard_fn=shard_fn if t5_fsdp else None,
+            quantization=quantization)
 
         self.vae_stride = config.vae_stride
         self.patch_size = config.patch_size
@@ -221,13 +227,15 @@ class WanT2V:
 
         if not self.t5_cpu:
             self.text_encoder.model.to(self.device)
-            context = self.text_encoder([input_prompt], self.device)
-            context_null = self.text_encoder([n_prompt], self.device)
+            with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True):
+                context = self.text_encoder([input_prompt], self.device)
+                context_null = self.text_encoder([n_prompt], self.device)
             if offload_model:
                 self.text_encoder.model.cpu()
         else:
-            context = self.text_encoder([input_prompt], torch.device('cpu'))
-            context_null = self.text_encoder([n_prompt], torch.device('cpu'))
+            with torch.autocast(device_type="cpu", dtype=torch.bfloat16, enabled=True):
+                context = self.text_encoder([input_prompt], torch.device('cpu'))
+                context_null = self.text_encoder([n_prompt], torch.device('cpu'))
             context = [t.to(self.device) for t in context]
             context_null = [t.to(self.device) for t in context_null]
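For reference, the new `quantization` argument introduced above can also be exercised outside the two pipelines. Below is a minimal sketch, not part of the patch; the checkpoint directory layout, `text_len=512`, the bf16 dtype, and the example prompt are illustrative assumptions that mirror the defaults the patched `WanT2V`/`WanI2V` constructors pass in, so adjust the paths to your own setup.

```python
import torch

from wan.modules.t5 import T5EncoderModel

# Build the text encoder the same way the patched pipelines do when
# config.t5_checkpoint points at the FP8 safetensors file.
# (Paths are illustrative; place the weights wherever your --ckpt_dir lives.)
text_encoder = T5EncoderModel(
    text_len=512,
    dtype=torch.bfloat16,
    device=torch.device('cpu'),
    checkpoint_path='./Wan2.1-T2V-1.3B/umt5-xxl-enc-fp8_e4m3fn.safetensors',
    tokenizer_path='./Wan2.1-T2V-1.3B/google/umt5-xxl',
    quantization="fp8_e4m3fn",
)

# Encode under bf16 autocast, matching the call sites patched in
# text2video.py / image2video.py above (CPU offload branch shown here).
with torch.autocast(device_type="cpu", dtype=torch.bfloat16, enabled=True):
    context = text_encoder(["a cat playing the piano"], torch.device('cpu'))
```

Inside the pipelines themselves, no call-site change is needed beyond the config edit shown in the README hunk: `WanT2V` and `WanI2V` select `fp8_e4m3fn` automatically when `config.t5_checkpoint` matches the FP8 safetensors filename, and otherwise fall back to the original bf16 `.pth` loading path.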