Compare commits

...

8 Commits

Author SHA1 Message Date
Yexiong Lin
406d4ba4a2 Merge 36d6d91b90 into e5a741309d 2025-05-17 17:32:17 +08:00
Shiwei Zhang
e5a741309d Update README.md (#406) 2025-05-17 10:57:06 +08:00
Yexiong Lin
36d6d91b90 update text2video.py 2025-03-04 19:54:56 +11:00
Yexiong Lin
1c7b73d13e Add the support for fp8 t5 2025-03-04 19:54:56 +11:00
Yexiong Lin
db54b7c613 Update README.md and text2video.py to offload model and enable using fp8 2025-03-04 19:54:56 +11:00
Yexiong Lin
24007c2c39 support fp8 model 2025-03-04 19:54:56 +11:00
Yexiong Lin
bc2aff711e support fp8 model 2025-03-04 19:54:56 +11:00
Yexiong Lin
bebb16bb8e Support Kijai's fp8 model 2025-03-04 19:54:56 +11:00
10 changed files with 214 additions and 29 deletions

View File

@ -166,6 +166,14 @@ If you encounter OOM (Out-of-Memory) issues, you can use the `--offload_model Tr
python generate.py --task t2v-1.3B --size 832*480 --ckpt_dir ./Wan2.1-T2V-1.3B --offload_model True --t5_cpu --sample_shift 8 --sample_guide_scale 6 --prompt "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage."
```
You can also use the `--fp8` option to enable FP8 precision for reduced memory usage. Make sure to download the [FP8 model weight](https://huggingface.co/Kijai/WanVideo_comfy/resolve/main/Wan2_1-T2V-1_3B_fp8_e4m3fn.safetensors) and place it in the `Wan2.1-T2V-1.3B` folder.
Additionally, an [FP8 version of the T5 model](https://huggingface.co/Kijai/WanVideo_comfy/resolve/main/umt5-xxl-enc-fp8_e4m3fn.safetensors) is available. To use the FP8 T5 model, update the configuration file:
```
t2v_1_3B.t5_checkpoint = 'umt5-xxl-enc-fp8_e4m3fn.safetensors'
```
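For example, once the FP8 weights are downloaded and the T5 checkpoint is updated as above, the earlier `T2V-1.3B` command can simply be extended with `--fp8` (illustrative only; every other flag is unchanged from the example above):
```
python generate.py --task t2v-1.3B --size 832*480 --ckpt_dir ./Wan2.1-T2V-1.3B --offload_model True --t5_cpu --fp8 --sample_shift 8 --sample_guide_scale 6 --prompt "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage."
```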
> 💡Note: If you are using the `T2V-1.3B` model, we recommend setting the parameter `--sample_guide_scale 6`. The `--sample_shift` parameter can be adjusted within the range of 8 to 12 based on performance.
@ -302,6 +310,17 @@ Similar to Text-to-Video, Image-to-Video is also divided into processes with and
python generate.py --task i2v-14B --size 1280*720 --ckpt_dir ./Wan2.1-I2V-14B-720P --image examples/i2v_input.JPG --prompt "Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard. The fluffy-furred feline gazes directly at the camera with a relaxed expression. Blurred beach scenery forms the background featuring crystal-clear waters, distant green hills, and a blue sky dotted with white clouds. The cat assumes a naturally relaxed posture, as if savoring the sea breeze and warm sunlight. A close-up shot highlights the feline's intricate details and the refreshing atmosphere of the seaside."
```
To minimize GPU memory usage, you can enable model offloading with `--offload_model True` and use FP8 precision with `--fp8`.
For example, to run **Wan2.1-I2V-14B-480P** on an RTX 4090 GPU:
1. First, download the [FP8 model weights](https://huggingface.co/Kijai/WanVideo_comfy/resolve/main/Wan2_1-I2V-14B-480P_fp8_e4m3fn.safetensors) and place them in the `Wan2.1-I2V-14B-480P` folder.
2. Then, execute the following command:
```
python generate.py --task i2v-14B --size 832*480 --ckpt_dir ./Wan2.1-I2V-14B-480P --offload_model True --fp8 --image examples/i2v_input.JPG --prompt "Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard. The fluffy-furred feline gazes directly at the camera with a relaxed expression. Blurred beach scenery forms the background featuring crystal-clear waters, distant green hills, and a blue sky dotted with white clouds. The cat assumes a naturally relaxed posture, as if savoring the sea breeze and warm sunlight. A close-up shot highlights the feline's intricate details and the refreshing atmosphere of the seaside."
```
> 💡For the Image-to-Video task, the `size` parameter represents the area of the generated video, with the aspect ratio following that of the original input image.
@ -643,7 +662,7 @@ If you find our work helpful, please cite us.
```
@article{wan2025,
title={Wan: Open and Advanced Large-Scale Video Generative Models},
author={Team Wan and Ang Wang and Baole Ai and Bin Wen and Chaojie Mao and Chen-Wei Xie and Di Chen and Feiwu Yu and Haiming Zhao and Jianxiao Yang and Jianyuan Zeng and Jiayu Wang and Jingfeng Zhang and Jingren Zhou and Jinkai Wang and Jixuan Chen and Kai Zhu and Kang Zhao and Keyu Yan and Lianghua Huang and Mengyang Feng and Ningyi Zhang and Pandeng Li and Pingyu Wu and Ruihang Chu and Ruili Feng and Shiwei Zhang and Siyang Sun and Tao Fang and Tianxing Wang and Tianyi Gui and Tingyu Weng and Tong Shen and Wei Lin and Wei Wang and Wei Wang and Wenmeng Zhou and Wente Wang and Wenting Shen and Wenyuan Yu and Xianzhong Shi and Xiaoming Huang and Xin Xu and Yan Kou and Yangyu Lv and Yifei Li and Yijing Liu and Yiming Wang and Yingya Zhang and Yitong Huang and Yong Li and You Wu and Yu Liu and Yulin Pan and Yun Zheng and Yuntao Hong and Yupeng Shi and Yutong Feng and Zeyinzi Jiang and Zhen Han and Zhi-Fan Wu and Ziyu Liu},
journal = {arXiv preprint arXiv:2503.20314},
year={2025}
}

View File

@ -155,6 +155,11 @@ def _parse_args():
action="store_true", action="store_true",
default=False, default=False,
help="Whether to use FSDP for DiT.") help="Whether to use FSDP for DiT.")
parser.add_argument(
"--fp8",
action="store_true",
default=False,
help="Whether to use fp8.")
parser.add_argument(
"--save_file",
type=str,
@ -366,6 +371,7 @@ def generate(args):
dit_fsdp=args.dit_fsdp,
use_usp=(args.ulysses_size > 1 or args.ring_size > 1),
t5_cpu=args.t5_cpu,
fp8=args.fp8,
)
logging.info(
@ -423,6 +429,7 @@ def generate(args):
dit_fsdp=args.dit_fsdp,
use_usp=(args.ulysses_size > 1 or args.ring_size > 1),
t5_cpu=args.t5_cpu,
fp8=args.fp8,
)
logging.info("Generating video ...")

View File

@ -11,12 +11,14 @@ i2v_14B.update(wan_shared_cfg)
i2v_14B.sample_neg_prompt = "镜头晃动," + i2v_14B.sample_neg_prompt
i2v_14B.t5_checkpoint = 'models_t5_umt5-xxl-enc-bf16.pth'
# i2v_14B.t5_checkpoint = 'umt5-xxl-enc-fp8_e4m3fn.safetensors' # fp8 model
i2v_14B.t5_tokenizer = 'google/umt5-xxl'
# clip
i2v_14B.clip_model = 'clip_xlm_roberta_vit_h_14'
i2v_14B.clip_dtype = torch.float16
i2v_14B.clip_checkpoint = 'models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth'
# i2v_14B.clip_checkpoint = 'open-clip-xlm-roberta-large-vit-huge-14_fp16.safetensors' # Kijai's fp16 model
i2v_14B.clip_tokenizer = 'xlm-roberta-large'
# vae

View File

@ -10,6 +10,7 @@ t2v_14B.update(wan_shared_cfg)
# t5
t2v_14B.t5_checkpoint = 'models_t5_umt5-xxl-enc-bf16.pth'
# t2v_14B.t5_checkpoint = 'umt5-xxl-enc-fp8_e4m3fn.safetensors' # fp8 model
t2v_14B.t5_tokenizer = 'google/umt5-xxl'
# vae

View File

@ -10,6 +10,7 @@ t2v_1_3B.update(wan_shared_cfg)
# t5
t2v_1_3B.t5_checkpoint = 'models_t5_umt5-xxl-enc-bf16.pth'
# t2v_1_3B.t5_checkpoint = 'umt5-xxl-enc-fp8_e4m3fn.safetensors' # fp8 model
t2v_1_3B.t5_tokenizer = 'google/umt5-xxl'
# vae

View File

@ -16,6 +16,10 @@ import torch.distributed as dist
import torchvision.transforms.functional as TF
from tqdm import tqdm
from accelerate import init_empty_weights
from accelerate.utils import set_module_tensor_to_device
from safetensors.torch import load_file
from .distributed.fsdp import shard_model
from .modules.clip import CLIPModel
from .modules.model import WanModel
@ -42,6 +46,7 @@ class WanI2V:
use_usp=False,
t5_cpu=False,
init_on_cpu=True,
fp8=False,
):
r"""
Initializes the image-to-video generation model components.
@ -65,6 +70,8 @@ class WanI2V:
Whether to place T5 model on CPU. Only works without t5_fsdp.
init_on_cpu (`bool`, *optional*, defaults to True):
Enable initializing Transformer Model on CPU. Only works without FSDP or USP.
fp8 (`bool`, *optional*, defaults to False):
Enable 8-bit floating point precision for model parameters.
""" """
self.device = torch.device(f"cuda:{device_id}") self.device = torch.device(f"cuda:{device_id}")
self.config = config self.config = config
@ -76,6 +83,10 @@ class WanI2V:
self.param_dtype = config.param_dtype
shard_fn = partial(shard_model, device_id=device_id)
if config.t5_checkpoint == 'umt5-xxl-enc-fp8_e4m3fn.safetensors':
quantization = "fp8_e4m3fn"
else:
quantization = "disabled"
self.text_encoder = T5EncoderModel(
text_len=config.text_len,
dtype=config.t5_dtype,
@ -83,10 +94,12 @@ class WanI2V:
checkpoint_path=os.path.join(checkpoint_dir, config.t5_checkpoint),
tokenizer_path=os.path.join(checkpoint_dir, config.t5_tokenizer),
shard_fn=shard_fn if t5_fsdp else None,
quantization=quantization,
)
self.vae_stride = config.vae_stride
self.patch_size = config.patch_size
self.vae = WanVAE(
vae_pth=os.path.join(checkpoint_dir, config.vae_checkpoint),
device=self.device)
@ -99,7 +112,46 @@ class WanI2V:
tokenizer_path=os.path.join(checkpoint_dir, config.clip_tokenizer))
logging.info(f"Creating WanModel from {checkpoint_dir}")
if not fp8:
self.model = WanModel.from_pretrained(checkpoint_dir)
else:
if '480P' in checkpoint_dir:
state_dict = load_file(checkpoint_dir+'/Wan2_1-I2V-14B-480P_fp8_e4m3fn.safetensors', device="cpu")
elif '720P' in checkpoint_dir:
state_dict = load_file(checkpoint_dir+'/Wan2_1-I2V-14B-720P_fp8_e4m3fn.safetensors', device="cpu")
dim = state_dict["patch_embedding.weight"].shape[0]
in_channels = state_dict["patch_embedding.weight"].shape[1]
ffn_dim = state_dict["blocks.0.ffn.0.bias"].shape[0]
model_type = "i2v" if in_channels == 36 else "t2v"
num_heads = 40 if dim == 5120 else 12
num_layers = 40 if dim == 5120 else 30
TRANSFORMER_CONFIG= {
"dim": dim,
"ffn_dim": ffn_dim,
"eps": 1e-06,
"freq_dim": 256,
"in_dim": in_channels,
"model_type": model_type,
"out_dim": 16,
"text_len": 512,
"num_heads": num_heads,
"num_layers": num_layers,
}
with init_empty_weights():
self.model = WanModel(**TRANSFORMER_CONFIG)
base_dtype=torch.bfloat16
dtype=torch.float8_e4m3fn
params_to_keep = {"norm", "head", "bias", "time_in", "vector_in", "patch_embedding", "time_", "img_emb", "modulation"}
for name, param in self.model.named_parameters():
dtype_to_use = base_dtype if any(keyword in name for keyword in params_to_keep) else dtype
# dtype_to_use = torch.bfloat16
# print("Assigning Parameter name: ", name, " with dtype: ", dtype_to_use)
set_module_tensor_to_device(self.model, name, device='cpu', dtype=dtype_to_use, value=state_dict[name])
del state_dict
self.model.eval().requires_grad_(False)
if t5_fsdp or dit_fsdp or use_usp:
@ -222,13 +274,15 @@ class WanI2V:
# preprocess
if not self.t5_cpu:
self.text_encoder.model.to(self.device)
with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True):
context = self.text_encoder([input_prompt], self.device)
context_null = self.text_encoder([n_prompt], self.device)
if offload_model:
self.text_encoder.model.cpu()
else:
with torch.autocast(device_type="cpu", dtype=torch.bfloat16, enabled=True):
context = self.text_encoder([input_prompt], torch.device('cpu'))
context_null = self.text_encoder([n_prompt], torch.device('cpu'))
context = [t.to(self.device) for t in context]
context_null = [t.to(self.device) for t in context_null]
@ -245,9 +299,12 @@ class WanI2V:
torch.zeros(3, F - 1, h, w)
],
dim=1).to(self.device)
],device=self.device)[0]
y = torch.concat([msk, y])
if offload_model:
self.vae.model.cpu()
@contextmanager
def noop_no_sync():
yield
@ -335,9 +392,11 @@ class WanI2V:
if offload_model:
self.model.cpu()
torch.cuda.empty_cache()
# load vae model back to device
self.vae.model.to(self.device)
if self.rank == 0:
videos = self.vae.decode(x0, device=self.device)
del noise, latent
del sample_scheduler

View File

@ -7,6 +7,7 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as T
from safetensors.torch import load_file
from .attention import flash_attention
from .tokenizers import HuggingfaceTokenizer
@ -515,8 +516,13 @@ class CLIPModel:
device=device)
self.model = self.model.eval().requires_grad_(False)
logging.info(f'loading {checkpoint_path}')
if checkpoint_path.endswith('.safetensors'):
state_dict = load_file(checkpoint_path, device='cpu')
self.model.load_state_dict(state_dict)
elif checkpoint_path.endswith('.pth'):
self.model.load_state_dict(torch.load(checkpoint_path, map_location='cpu'))
else:
raise ValueError(f'Unsupported checkpoint file format: {checkpoint_path}')
# init tokenizer
self.tokenizer = HuggingfaceTokenizer(

View File

@ -9,6 +9,10 @@ import torch.nn.functional as F
from .tokenizers import HuggingfaceTokenizer
from accelerate import init_empty_weights
from accelerate.utils import set_module_tensor_to_device
from safetensors.torch import load_file
__all__ = [
'T5Model',
'T5Encoder',
@ -442,7 +446,7 @@ def _t5(name,
model = model_cls(**kwargs)
# set device
# model = model.to(dtype=dtype, device=device)
# init tokenizer
if return_tokenizer:
@ -479,6 +483,7 @@ class T5EncoderModel:
checkpoint_path=None,
tokenizer_path=None,
shard_fn=None,
quantization="disabled",
):
self.text_len = text_len
self.dtype = dtype
@ -486,14 +491,31 @@ class T5EncoderModel:
self.checkpoint_path = checkpoint_path
self.tokenizer_path = tokenizer_path
# init model
model = umt5_xxl(
encoder_only=True,
return_tokenizer=False,
dtype=dtype,
device=device).eval().requires_grad_(False)
logging.info(f'loading {checkpoint_path}')
if quantization == "disabled":
# init model
model = umt5_xxl(
encoder_only=True,
return_tokenizer=False,
dtype=dtype,
device=device).eval().requires_grad_(False)
model.load_state_dict(torch.load(checkpoint_path, map_location='cpu'))
elif quantization == "fp8_e4m3fn":
with init_empty_weights():
model = umt5_xxl(
encoder_only=True,
return_tokenizer=False,
dtype=dtype,
device=device).eval().requires_grad_(False)
cast_dtype = torch.float8_e4m3fn
state_dict = load_file(checkpoint_path, device="cpu")
params_to_keep = {'norm', 'pos_embedding', 'token_embedding'}
for name, param in model.named_parameters():
dtype_to_use = dtype if any(keyword in name for keyword in params_to_keep) else cast_dtype
set_module_tensor_to_device(model, name, device=device, dtype=dtype_to_use, value=state_dict[name])
del state_dict
self.model = model
if shard_fn is not None:
self.model = shard_fn(self.model, sync_module_states=False)

View File

@ -644,7 +644,7 @@ class WanVAE:
z_dim=z_dim,
).eval().requires_grad_(False).to(device)
def encode(self, videos, device=None):
""" """
videos: A list of videos each with shape [C, T, H, W]. videos: A list of videos each with shape [C, T, H, W].
""" """
@ -654,7 +654,7 @@ class WanVAE:
for u in videos
]
def decode(self, zs, device=None):
with amp.autocast(dtype=self.dtype):
return [
self.model.decode(u.unsqueeze(0),

View File

@ -14,6 +14,10 @@ import torch.cuda.amp as amp
import torch.distributed as dist
from tqdm import tqdm
from accelerate import init_empty_weights
from accelerate.utils import set_module_tensor_to_device
from safetensors.torch import load_file
from .distributed.fsdp import shard_model
from .modules.model import WanModel
from .modules.t5 import T5EncoderModel
@ -38,6 +42,8 @@ class WanT2V:
dit_fsdp=False,
use_usp=False,
t5_cpu=False,
init_on_cpu=True,
fp8=False,
):
r"""
Initializes the Wan text-to-video generation model components.
@ -59,6 +65,8 @@ class WanT2V:
Enable distribution strategy of USP.
t5_cpu (`bool`, *optional*, defaults to False):
Whether to place T5 model on CPU. Only works without t5_fsdp.
fp8 (`bool`, *optional*, defaults to False):
Enable 8-bit floating point precision for model parameters.
""" """
self.device = torch.device(f"cuda:{device_id}") self.device = torch.device(f"cuda:{device_id}")
self.config = config self.config = config
@ -69,13 +77,19 @@ class WanT2V:
self.param_dtype = config.param_dtype
shard_fn = partial(shard_model, device_id=device_id)
if config.t5_checkpoint == 'umt5-xxl-enc-fp8_e4m3fn.safetensors':
quantization = "fp8_e4m3fn"
else:
quantization = "disabled"
self.text_encoder = T5EncoderModel(
text_len=config.text_len,
dtype=config.t5_dtype,
device=torch.device('cpu'),
checkpoint_path=os.path.join(checkpoint_dir, config.t5_checkpoint),
tokenizer_path=os.path.join(checkpoint_dir, config.t5_tokenizer),
shard_fn=shard_fn if t5_fsdp else None,
quantization=quantization)
self.vae_stride = config.vae_stride
self.patch_size = config.patch_size
@ -84,9 +98,52 @@ class WanT2V:
device=self.device)
logging.info(f"Creating WanModel from {checkpoint_dir}")
if not fp8:
self.model = WanModel.from_pretrained(checkpoint_dir)
else:
if '14B' in checkpoint_dir:
state_dict = load_file(checkpoint_dir+'/Wan2_1-T2V-14B_fp8_e4m3fn.safetensors', device="cpu")
else:
state_dict = load_file(checkpoint_dir+'/Wan2_1-T2V-1_3B_fp8_e4m3fn.safetensors', device="cpu")
dim = state_dict["patch_embedding.weight"].shape[0]
in_channels = state_dict["patch_embedding.weight"].shape[1]
ffn_dim = state_dict["blocks.0.ffn.0.bias"].shape[0]
model_type = "i2v" if in_channels == 36 else "t2v"
num_heads = 40 if dim == 5120 else 12
num_layers = 40 if dim == 5120 else 30
TRANSFORMER_CONFIG= {
"dim": dim,
"ffn_dim": ffn_dim,
"eps": 1e-06,
"freq_dim": 256,
"in_dim": in_channels,
"model_type": model_type,
"out_dim": 16,
"text_len": 512,
"num_heads": num_heads,
"num_layers": num_layers,
}
with init_empty_weights():
self.model = WanModel(**TRANSFORMER_CONFIG)
base_dtype=torch.bfloat16
dtype=torch.float8_e4m3fn
params_to_keep = {"norm", "head", "bias", "time_in", "vector_in", "patch_embedding", "time_", "img_emb", "modulation"}
for name, param in self.model.named_parameters():
dtype_to_use = base_dtype if any(keyword in name for keyword in params_to_keep) else dtype
# dtype_to_use = torch.bfloat16
# print("Assigning Parameter name: ", name, " with dtype: ", dtype_to_use)
set_module_tensor_to_device(self.model, name, device='cpu', dtype=dtype_to_use, value=state_dict[name])
del state_dict
self.model.eval().requires_grad_(False)
if t5_fsdp or dit_fsdp or use_usp:
init_on_cpu = False
if use_usp:
from xfuser.core.distributed import get_sequence_parallel_world_size
@ -107,7 +164,8 @@ class WanT2V:
if dit_fsdp:
self.model = shard_fn(self.model)
else:
if not init_on_cpu:
self.model.to(self.device)
self.sample_neg_prompt = config.sample_neg_prompt
@ -173,13 +231,15 @@ class WanT2V:
if not self.t5_cpu:
self.text_encoder.model.to(self.device)
with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True):
context = self.text_encoder([input_prompt], self.device)
context_null = self.text_encoder([n_prompt], self.device)
if offload_model:
self.text_encoder.model.cpu()
else:
with torch.autocast(device_type="cpu", dtype=torch.bfloat16, enabled=True):
context = self.text_encoder([input_prompt], torch.device('cpu'))
context_null = self.text_encoder([n_prompt], torch.device('cpu'))
context = [t.to(self.device) for t in context]
context_null = [t.to(self.device) for t in context_null]
@ -194,6 +254,9 @@ class WanT2V:
generator=seed_g)
]
if offload_model:
self.vae.model.cpu()
@contextmanager
def noop_no_sync():
yield
@ -230,13 +293,15 @@ class WanT2V:
arg_c = {'context': context, 'seq_len': seq_len}
arg_null = {'context': context_null, 'seq_len': seq_len}
if offload_model:
torch.cuda.empty_cache()
self.model.to(self.device)
for _, t in enumerate(tqdm(timesteps)):
latent_model_input = latents
timestep = [t]
timestep = torch.stack(timestep)
self.model.to(self.device)
noise_pred_cond = self.model(
latent_model_input, t=timestep, **arg_c)[0]
noise_pred_uncond = self.model(
@ -257,6 +322,9 @@ class WanT2V:
if offload_model:
self.model.cpu()
torch.cuda.empty_cache()
# load vae model back to device
self.vae.model.to(self.device)
if self.rank == 0:
videos = self.vae.decode(x0)