Mirror of https://github.com/Wan-Video/Wan2.1.git (synced 2025-06-17 12:57:40 +00:00)

Comparing commits 74c2e13417...406d4ba4a2 (8 commits)
Commits in this range:

- 406d4ba4a2
- e5a741309d
- 36d6d91b90
- 1c7b73d13e
- db54b7c613
- 24007c2c39
- bc2aff711e
- bebb16bb8e
README.md (21 changed lines)
````diff
@@ -166,6 +166,14 @@ If you encounter OOM (Out-of-Memory) issues, you can use the `--offload_model True`
 python generate.py --task t2v-1.3B --size 832*480 --ckpt_dir ./Wan2.1-T2V-1.3B --offload_model True --t5_cpu --sample_shift 8 --sample_guide_scale 6 --prompt "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage."
 ```
+
+You can also use the `--fp8` option to enable FP8 precision for reduced memory usage. Make sure to download the [FP8 model weight](https://huggingface.co/Kijai/WanVideo_comfy/resolve/main/Wan2_1-T2V-1_3B_fp8_e4m3fn.safetensors) and place it in the `Wan2.1-T2V-1.3B` folder.
+
+Additionally, an [FP8 version of the T5 model](https://huggingface.co/Kijai/WanVideo_comfy/resolve/main/umt5-xxl-enc-fp8_e4m3fn.safetensors) is available. To use the FP8 T5 model, update the configuration file:
+
+```
+t2v_1_3B.t5_checkpoint = 'umt5-xxl-enc-fp8_e4m3fn.safetensors'
+```
 
 > 💡Note: If you are using the `T2V-1.3B` model, we recommend setting the parameter `--sample_guide_scale 6`. The `--sample_shift` parameter can be adjusted within the range of 8 to 12 based on the performance.
 
````
````diff
@@ -302,6 +310,17 @@ Similar to Text-to-Video, Image-to-Video is also divided into processes with and
 python generate.py --task i2v-14B --size 1280*720 --ckpt_dir ./Wan2.1-I2V-14B-720P --image examples/i2v_input.JPG --prompt "Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard. The fluffy-furred feline gazes directly at the camera with a relaxed expression. Blurred beach scenery forms the background featuring crystal-clear waters, distant green hills, and a blue sky dotted with white clouds. The cat assumes a naturally relaxed posture, as if savoring the sea breeze and warm sunlight. A close-up shot highlights the feline's intricate details and the refreshing atmosphere of the seaside."
 ```
+
+To minimize GPU memory usage, you can enable model offloading with `--offload_model True` and use FP8 precision with `--fp8`.
+
+For example, to run **Wan2.1-I2V-14B-480P** on an RTX 4090 GPU:
+
+1. First, download the [FP8 model weights](https://huggingface.co/Kijai/WanVideo_comfy/resolve/main/Wan2_1-I2V-14B-480P_fp8_e4m3fn.safetensors) and place them in the `Wan2.1-I2V-14B-480P` folder.
+2. Then, execute the following command:
+
+```
+python generate.py --task i2v-14B --size 832*480 --ckpt_dir ./Wan2.1-I2V-14B-480P --offload_model True --fp8 --image examples/i2v_input.JPG --prompt "Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard. The fluffy-furred feline gazes directly at the camera with a relaxed expression. Blurred beach scenery forms the background featuring crystal-clear waters, distant green hills, and a blue sky dotted with white clouds. The cat assumes a naturally relaxed posture, as if savoring the sea breeze and warm sunlight. A close-up shot highlights the feline's intricate details and the refreshing atmosphere of the seaside."
+```
 
 > 💡For the Image-to-Video task, the `size` parameter represents the area of the generated video, with the aspect ratio following that of the original input image.
 
````
````diff
@@ -643,7 +662,7 @@ If you find our work helpful, please cite us.
 ```
 @article{wan2025,
       title={Wan: Open and Advanced Large-Scale Video Generative Models},
-      author={Ang Wang and Baole Ai and Bin Wen and Chaojie Mao and Chen-Wei Xie and Di Chen and Feiwu Yu and Haiming Zhao and Jianxiao Yang and Jianyuan Zeng and Jiayu Wang and Jingfeng Zhang and Jingren Zhou and Jinkai Wang and Jixuan Chen and Kai Zhu and Kang Zhao and Keyu Yan and Lianghua Huang and Mengyang Feng and Ningyi Zhang and Pandeng Li and Pingyu Wu and Ruihang Chu and Ruili Feng and Shiwei Zhang and Siyang Sun and Tao Fang and Tianxing Wang and Tianyi Gui and Tingyu Weng and Tong Shen and Wei Lin and Wei Wang and Wei Wang and Wenmeng Zhou and Wente Wang and Wenting Shen and Wenyuan Yu and Xianzhong Shi and Xiaoming Huang and Xin Xu and Yan Kou and Yangyu Lv and Yifei Li and Yijing Liu and Yiming Wang and Yingya Zhang and Yitong Huang and Yong Li and You Wu and Yu Liu and Yulin Pan and Yun Zheng and Yuntao Hong and Yupeng Shi and Yutong Feng and Zeyinzi Jiang and Zhen Han and Zhi-Fan Wu and Ziyu Liu},
+      author={Team Wan and Ang Wang and Baole Ai and Bin Wen and Chaojie Mao and Chen-Wei Xie and Di Chen and Feiwu Yu and Haiming Zhao and Jianxiao Yang and Jianyuan Zeng and Jiayu Wang and Jingfeng Zhang and Jingren Zhou and Jinkai Wang and Jixuan Chen and Kai Zhu and Kang Zhao and Keyu Yan and Lianghua Huang and Mengyang Feng and Ningyi Zhang and Pandeng Li and Pingyu Wu and Ruihang Chu and Ruili Feng and Shiwei Zhang and Siyang Sun and Tao Fang and Tianxing Wang and Tianyi Gui and Tingyu Weng and Tong Shen and Wei Lin and Wei Wang and Wei Wang and Wenmeng Zhou and Wente Wang and Wenting Shen and Wenyuan Yu and Xianzhong Shi and Xiaoming Huang and Xin Xu and Yan Kou and Yangyu Lv and Yifei Li and Yijing Liu and Yiming Wang and Yingya Zhang and Yitong Huang and Yong Li and You Wu and Yu Liu and Yulin Pan and Yun Zheng and Yuntao Hong and Yupeng Shi and Yutong Feng and Zeyinzi Jiang and Zhen Han and Zhi-Fan Wu and Ziyu Liu},
       journal = {arXiv preprint arXiv:2503.20314},
       year={2025}
 }
````
generate.py

```diff
@@ -155,6 +155,11 @@ def _parse_args():
         action="store_true",
         default=False,
         help="Whether to use FSDP for DiT.")
+    parser.add_argument(
+        "--fp8",
+        action="store_true",
+        default=False,
+        help="Whether to use fp8.")
     parser.add_argument(
         "--save_file",
         type=str,
```
```diff
@@ -366,6 +371,7 @@ def generate(args):
            dit_fsdp=args.dit_fsdp,
            use_usp=(args.ulysses_size > 1 or args.ring_size > 1),
            t5_cpu=args.t5_cpu,
+           fp8=args.fp8,
        )

        logging.info(
```
```diff
@@ -423,6 +429,7 @@ def generate(args):
            dit_fsdp=args.dit_fsdp,
            use_usp=(args.ulysses_size > 1 or args.ring_size > 1),
            t5_cpu=args.t5_cpu,
+           fp8=args.fp8,
        )

        logging.info("Generating video ...")
```
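The flag is plumbed straight through: `_parse_args` defines it with `store_true`, and both call sites in `generate()` forward `args.fp8` into the pipeline constructors. A minimal, runnable sketch of that flow — `build_pipeline` is a hypothetical stand-in for the `wan.WanT2V`/`wan.WanI2V` constructors, not a function in the repo:

```python
import argparse

def build_pipeline(fp8: bool):
    # Hypothetical stand-in for wan.WanT2V / wan.WanI2V construction.
    print(f"loading weights in {'fp8' if fp8 else 'default precision'}")

parser = argparse.ArgumentParser()
parser.add_argument(
    "--fp8",
    action="store_true",
    default=False,
    help="Whether to use fp8.")
args = parser.parse_args(["--fp8"])
build_pipeline(fp8=args.fp8)  # store_true flips the default False to True
```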
wan/configs/wan_i2v_14B.py

```diff
@@ -11,12 +11,14 @@ i2v_14B.update(wan_shared_cfg)
 i2v_14B.sample_neg_prompt = "镜头晃动," + i2v_14B.sample_neg_prompt
 
 i2v_14B.t5_checkpoint = 'models_t5_umt5-xxl-enc-bf16.pth'
+# i2v_14B.t5_checkpoint = 'umt5-xxl-enc-fp8_e4m3fn.safetensors' # fp8 model
 i2v_14B.t5_tokenizer = 'google/umt5-xxl'
 
 # clip
 i2v_14B.clip_model = 'clip_xlm_roberta_vit_h_14'
 i2v_14B.clip_dtype = torch.float16
 i2v_14B.clip_checkpoint = 'models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth'
+# i2v_14B.clip_checkpoint = 'open-clip-xlm-roberta-large-vit-huge-14_fp16.safetensors' # Kijai's fp16 model
 i2v_14B.clip_tokenizer = 'xlm-roberta-large'
 
 # vae
```
wan/configs/wan_t2v_14B.py

```diff
@@ -10,6 +10,7 @@ t2v_14B.update(wan_shared_cfg)
 
 # t5
 t2v_14B.t5_checkpoint = 'models_t5_umt5-xxl-enc-bf16.pth'
+# t2v_14B.t5_checkpoint = 'umt5-xxl-enc-fp8_e4m3fn.safetensors' # fp8 model
 t2v_14B.t5_tokenizer = 'google/umt5-xxl'
 
 # vae
```
wan/configs/wan_t2v_1_3B.py

```diff
@@ -10,6 +10,7 @@ t2v_1_3B.update(wan_shared_cfg)
 
 # t5
 t2v_1_3B.t5_checkpoint = 'models_t5_umt5-xxl-enc-bf16.pth'
+# t2v_1_3B.t5_checkpoint = 'umt5-xxl-enc-fp8_e4m3fn.safetensors' # fp8 model
 t2v_1_3B.t5_tokenizer = 'google/umt5-xxl'
 
 # vae
```
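The commented lines document the switch rather than performing it; uncommenting one makes the FP8 T5 checkpoint the default for that task. A hedged runtime alternative, assuming the module path `wan/configs/wan_t2v_1_3B.py` shown above exposes the `t2v_1_3B` config object:

```python
# Hedged sketch: point the 1.3B T2V config at the FP8 T5 encoder at runtime
# instead of editing the config file.
from wan.configs.wan_t2v_1_3B import t2v_1_3B

t2v_1_3B.t5_checkpoint = 'umt5-xxl-enc-fp8_e4m3fn.safetensors'
print(t2v_1_3B.t5_checkpoint)
```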
wan/image2video.py

```diff
@@ -16,6 +16,10 @@ import torch.distributed as dist
 import torchvision.transforms.functional as TF
 from tqdm import tqdm
 
+from accelerate import init_empty_weights
+from accelerate.utils import set_module_tensor_to_device
+from safetensors.torch import load_file
+
 from .distributed.fsdp import shard_model
 from .modules.clip import CLIPModel
 from .modules.model import WanModel
```
```diff
@@ -42,6 +46,7 @@ class WanI2V:
         use_usp=False,
         t5_cpu=False,
         init_on_cpu=True,
+        fp8=False,
     ):
         r"""
         Initializes the image-to-video generation model components.
```
```diff
@@ -65,6 +70,8 @@ class WanI2V:
                 Whether to place T5 model on CPU. Only works without t5_fsdp.
             init_on_cpu (`bool`, *optional*, defaults to True):
                 Enable initializing Transformer Model on CPU. Only works without FSDP or USP.
+            fp8 (`bool`, *optional*, defaults to False):
+                Enable 8-bit floating point precision for model parameters.
         """
         self.device = torch.device(f"cuda:{device_id}")
         self.config = config
```
```diff
@@ -76,6 +83,10 @@ class WanI2V:
         self.param_dtype = config.param_dtype
 
         shard_fn = partial(shard_model, device_id=device_id)
+        if config.t5_checkpoint == 'umt5-xxl-enc-fp8_e4m3fn.safetensors':
+            quantization = "fp8_e4m3fn"
+        else:
+            quantization = "disabled"
         self.text_encoder = T5EncoderModel(
             text_len=config.text_len,
             dtype=config.t5_dtype,
```
```diff
@@ -83,10 +94,12 @@ class WanI2V:
             checkpoint_path=os.path.join(checkpoint_dir, config.t5_checkpoint),
             tokenizer_path=os.path.join(checkpoint_dir, config.t5_tokenizer),
             shard_fn=shard_fn if t5_fsdp else None,
+            quantization=quantization,
         )
 
         self.vae_stride = config.vae_stride
         self.patch_size = config.patch_size
         self.vae = WanVAE(
             vae_pth=os.path.join(checkpoint_dir, config.vae_checkpoint),
             device=self.device)
```
```diff
@@ -99,7 +112,46 @@ class WanI2V:
             tokenizer_path=os.path.join(checkpoint_dir, config.clip_tokenizer))
 
         logging.info(f"Creating WanModel from {checkpoint_dir}")
-        self.model = WanModel.from_pretrained(checkpoint_dir)
+        if not fp8:
+            self.model = WanModel.from_pretrained(checkpoint_dir)
+        else:
+            if '480P' in checkpoint_dir:
+                state_dict = load_file(checkpoint_dir + '/Wan2_1-I2V-14B-480P_fp8_e4m3fn.safetensors', device="cpu")
+            elif '720P' in checkpoint_dir:
+                state_dict = load_file(checkpoint_dir + '/Wan2_1-I2V-14B-720P_fp8_e4m3fn.safetensors', device="cpu")
+            dim = state_dict["patch_embedding.weight"].shape[0]
+            in_channels = state_dict["patch_embedding.weight"].shape[1]
+            ffn_dim = state_dict["blocks.0.ffn.0.bias"].shape[0]
+            model_type = "i2v" if in_channels == 36 else "t2v"
+            num_heads = 40 if dim == 5120 else 12
+            num_layers = 40 if dim == 5120 else 30
+            TRANSFORMER_CONFIG = {
+                "dim": dim,
+                "ffn_dim": ffn_dim,
+                "eps": 1e-06,
+                "freq_dim": 256,
+                "in_dim": in_channels,
+                "model_type": model_type,
+                "out_dim": 16,
+                "text_len": 512,
+                "num_heads": num_heads,
+                "num_layers": num_layers,
+            }
+
+            with init_empty_weights():
+                self.model = WanModel(**TRANSFORMER_CONFIG)
+
+            base_dtype = torch.bfloat16
+            dtype = torch.float8_e4m3fn
+            params_to_keep = {"norm", "head", "bias", "time_in", "vector_in", "patch_embedding", "time_", "img_emb", "modulation"}
+            for name, param in self.model.named_parameters():
+                dtype_to_use = base_dtype if any(keyword in name for keyword in params_to_keep) else dtype
+                # dtype_to_use = torch.bfloat16
+                # print("Assigning Parameter name: ", name, " with dtype: ", dtype_to_use)
+                set_module_tensor_to_device(self.model, name, device='cpu', dtype=dtype_to_use, value=state_dict[name])
+
+            del state_dict
+
         self.model.eval().requires_grad_(False)
 
         if t5_fsdp or dit_fsdp or use_usp:
```
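The fp8 branch above avoids ever holding a full-precision copy of the 14B transformer in memory: `init_empty_weights()` builds the module on the meta device with no storage, and each tensor is then materialized directly at its target dtype from the safetensors state dict. A self-contained sketch of the same pattern on a toy module, assuming a PyTorch build with `torch.float8_e4m3fn` (2.1+); the shapes and dtype split are illustrative, not the model's:

```python
import torch
import torch.nn as nn
from accelerate import init_empty_weights
from accelerate.utils import set_module_tensor_to_device

# Build the module on the meta device: parameters have shapes but no storage.
with init_empty_weights():
    toy = nn.Linear(1024, 1024)

# Stand-in for a dict returned by safetensors.torch.load_file(...).
state_dict = {
    "weight": torch.randn(1024, 1024),
    "bias": torch.zeros(1024),
}

# Materialize one tensor at a time at its target dtype; biases stay in
# higher precision, mirroring the params_to_keep rule above ("bias" is
# in the keep set).
for name, _ in toy.named_parameters():
    dtype = torch.bfloat16 if "bias" in name else torch.float8_e4m3fn
    set_module_tensor_to_device(toy, name, device="cpu", dtype=dtype,
                                value=state_dict[name])

print(toy.weight.dtype, toy.bias.dtype)  # torch.float8_e4m3fn torch.bfloat16
```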
```diff
@@ -222,11 +274,13 @@ class WanI2V:
         # preprocess
         if not self.t5_cpu:
             self.text_encoder.model.to(self.device)
-            context = self.text_encoder([input_prompt], self.device)
-            context_null = self.text_encoder([n_prompt], self.device)
+            with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True):
+                context = self.text_encoder([input_prompt], self.device)
+                context_null = self.text_encoder([n_prompt], self.device)
             if offload_model:
                 self.text_encoder.model.cpu()
         else:
-            context = self.text_encoder([input_prompt], torch.device('cpu'))
-            context_null = self.text_encoder([n_prompt], torch.device('cpu'))
+            with torch.autocast(device_type="cpu", dtype=torch.bfloat16, enabled=True):
+                context = self.text_encoder([input_prompt], torch.device('cpu'))
+                context_null = self.text_encoder([n_prompt], torch.device('cpu'))
             context = [t.to(self.device) for t in context]
```
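Wrapping the encoder call in `torch.autocast` makes autocast-eligible ops execute in bfloat16 regardless of how the weights are stored, which is what lets the fp8-materialized T5 run without an explicit upcast pass. A minimal, runnable illustration of the region's effect, using plain fp32 inputs on CPU:

```python
import torch

a = torch.randn(8, 8)  # fp32 storage
b = torch.randn(8, 8)

with torch.autocast(device_type="cpu", dtype=torch.bfloat16, enabled=True):
    c = a @ b          # matmul is autocast-eligible, so it runs in bf16

print(a.dtype, c.dtype)  # torch.float32 torch.bfloat16
```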
```diff
@@ -245,9 +299,12 @@ class WanI2V:
                     torch.zeros(3, F - 1, h, w)
                 ],
                 dim=1).to(self.device)
-        ])[0]
+        ], device=self.device)[0]
         y = torch.concat([msk, y])
 
+        if offload_model:
+            self.vae.model.cpu()
+
         @contextmanager
         def noop_no_sync():
             yield
```
```diff
@@ -335,9 +392,11 @@ class WanI2V:
             if offload_model:
                 self.model.cpu()
                 torch.cuda.empty_cache()
+                # load vae model back to device
+                self.vae.model.to(self.device)
 
             if self.rank == 0:
-                videos = self.vae.decode(x0)
+                videos = self.vae.decode(x0, device=self.device)
 
         del noise, latent
         del sample_scheduler
```
wan/modules/clip.py

```diff
@@ -7,6 +7,7 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 import torchvision.transforms as T
+from safetensors.torch import load_file
 
 from .attention import flash_attention
 from .tokenizers import HuggingfaceTokenizer
```
```diff
@@ -515,8 +516,13 @@ class CLIPModel:
             device=device)
         self.model = self.model.eval().requires_grad_(False)
         logging.info(f'loading {checkpoint_path}')
-        self.model.load_state_dict(
-            torch.load(checkpoint_path, map_location='cpu'))
+        if checkpoint_path.endswith('.safetensors'):
+            state_dict = load_file(checkpoint_path, device='cpu')
+            self.model.load_state_dict(state_dict)
+        elif checkpoint_path.endswith('.pth'):
+            self.model.load_state_dict(torch.load(checkpoint_path, map_location='cpu'))
+        else:
+            raise ValueError(f'Unsupported checkpoint file format: {checkpoint_path}')
 
         # init tokenizer
         self.tokenizer = HuggingfaceTokenizer(
```
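The same extension-based dispatch could live in a small standalone helper; a hedged sketch — the function name is ours, not the repo's:

```python
import torch
from safetensors.torch import load_file

def load_checkpoint(path: str) -> dict:
    # Hypothetical helper mirroring the dispatch added to CLIPModel above:
    # safetensors for .safetensors files, torch.load for legacy .pth files,
    # and an early, explicit failure for anything else.
    if path.endswith('.safetensors'):
        return load_file(path, device='cpu')
    elif path.endswith('.pth'):
        return torch.load(path, map_location='cpu')
    raise ValueError(f'Unsupported checkpoint file format: {path}')
```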
wan/modules/t5.py

```diff
@@ -9,6 +9,10 @@ import torch.nn.functional as F
 
 from .tokenizers import HuggingfaceTokenizer
 
+from accelerate import init_empty_weights
+from accelerate.utils import set_module_tensor_to_device
+from safetensors.torch import load_file
+
 __all__ = [
     'T5Model',
     'T5Encoder',
```
```diff
@@ -442,7 +446,7 @@ def _t5(name,
     model = model_cls(**kwargs)
 
     # set device
-    model = model.to(dtype=dtype, device=device)
+    # model = model.to(dtype=dtype, device=device)
 
     # init tokenizer
     if return_tokenizer:
```
```diff
@@ -479,6 +483,7 @@ class T5EncoderModel:
                  checkpoint_path=None,
                  tokenizer_path=None,
                  shard_fn=None,
+                 quantization="disabled",
     ):
         self.text_len = text_len
         self.dtype = dtype
```
```diff
@@ -486,14 +491,31 @@ class T5EncoderModel:
         self.checkpoint_path = checkpoint_path
         self.tokenizer_path = tokenizer_path
 
+
+        logging.info(f'loading {checkpoint_path}')
+        if quantization == "disabled":
+            # init model
+            model = umt5_xxl(
+                encoder_only=True,
+                return_tokenizer=False,
+                dtype=dtype,
+                device=device).eval().requires_grad_(False)
+            model.load_state_dict(torch.load(checkpoint_path, map_location='cpu'))
+        elif quantization == "fp8_e4m3fn":
+            with init_empty_weights():
+                model = umt5_xxl(
+                    encoder_only=True,
+                    return_tokenizer=False,
+                    dtype=dtype,
+                    device=device).eval().requires_grad_(False)
+            cast_dtype = torch.float8_e4m3fn
+            state_dict = load_file(checkpoint_path, device="cpu")
+            params_to_keep = {'norm', 'pos_embedding', 'token_embedding'}
+            for name, param in model.named_parameters():
+                dtype_to_use = dtype if any(keyword in name for keyword in params_to_keep) else cast_dtype
+                set_module_tensor_to_device(model, name, device=device, dtype=dtype_to_use, value=state_dict[name])
+            del state_dict
-        # init model
-        model = umt5_xxl(
-            encoder_only=True,
-            return_tokenizer=False,
-            dtype=dtype,
-            device=device).eval().requires_grad_(False)
-        logging.info(f'loading {checkpoint_path}')
-        model.load_state_dict(torch.load(checkpoint_path, map_location='cpu'))
         self.model = model
         if shard_fn is not None:
             self.model = shard_fn(self.model, sync_module_states=False)
```
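The `params_to_keep` set implements a simple substring rule: any parameter whose name contains a precision-sensitive keyword (normalization scales, embeddings) keeps the full T5 dtype, while everything else is stored as fp8. Illustrated in isolation:

```python
import torch

params_to_keep = {'norm', 'pos_embedding', 'token_embedding'}

def dtype_for(name: str,
              keep_dtype=torch.bfloat16,
              quant_dtype=torch.float8_e4m3fn):
    # Substring match, exactly as in the loop above: 'norm' also catches
    # nested names like 'blocks.0.norm1.weight'.
    return keep_dtype if any(k in name for k in params_to_keep) else quant_dtype

print(dtype_for('blocks.0.norm1.weight'))   # torch.bfloat16
print(dtype_for('blocks.0.attn.q.weight'))  # torch.float8_e4m3fn
```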
wan/modules/vae.py

```diff
@@ -644,7 +644,7 @@ class WanVAE:
             z_dim=z_dim,
         ).eval().requires_grad_(False).to(device)
 
-    def encode(self, videos):
+    def encode(self, videos, device=None):
         """
         videos: A list of videos each with shape [C, T, H, W].
         """
@@ -654,7 +654,7 @@ class WanVAE:
             for u in videos
         ]
 
-    def decode(self, zs):
+    def decode(self, zs, device=None):
         with amp.autocast(dtype=self.dtype):
             return [
                 self.model.decode(u.unsqueeze(0),
```
wan/text2video.py

```diff
@@ -14,6 +14,10 @@ import torch.cuda.amp as amp
 import torch.distributed as dist
 from tqdm import tqdm
 
+from accelerate import init_empty_weights
+from accelerate.utils import set_module_tensor_to_device
+from safetensors.torch import load_file
+
 from .distributed.fsdp import shard_model
 from .modules.model import WanModel
 from .modules.t5 import T5EncoderModel
```
```diff
@@ -38,6 +42,8 @@ class WanT2V:
         dit_fsdp=False,
         use_usp=False,
         t5_cpu=False,
+        init_on_cpu=True,
+        fp8=False,
     ):
         r"""
         Initializes the Wan text-to-video generation model components.
```
```diff
@@ -59,6 +65,8 @@ class WanT2V:
                 Enable distribution strategy of USP.
             t5_cpu (`bool`, *optional*, defaults to False):
                 Whether to place T5 model on CPU. Only works without t5_fsdp.
+            fp8 (`bool`, *optional*, defaults to False):
+                Enable 8-bit floating point precision for model parameters.
         """
         self.device = torch.device(f"cuda:{device_id}")
         self.config = config
```
```diff
@@ -69,13 +77,19 @@ class WanT2V:
         self.param_dtype = config.param_dtype
 
         shard_fn = partial(shard_model, device_id=device_id)
+        if config.t5_checkpoint == 'umt5-xxl-enc-fp8_e4m3fn.safetensors':
+            quantization = "fp8_e4m3fn"
+        else:
+            quantization = "disabled"
+
         self.text_encoder = T5EncoderModel(
             text_len=config.text_len,
             dtype=config.t5_dtype,
             device=torch.device('cpu'),
             checkpoint_path=os.path.join(checkpoint_dir, config.t5_checkpoint),
             tokenizer_path=os.path.join(checkpoint_dir, config.t5_tokenizer),
-            shard_fn=shard_fn if t5_fsdp else None)
+            shard_fn=shard_fn if t5_fsdp else None,
+            quantization=quantization)
 
         self.vae_stride = config.vae_stride
         self.patch_size = config.patch_size
```
```diff
@@ -84,9 +98,52 @@ class WanT2V:
             device=self.device)
 
         logging.info(f"Creating WanModel from {checkpoint_dir}")
-        self.model = WanModel.from_pretrained(checkpoint_dir)
+        if not fp8:
+            self.model = WanModel.from_pretrained(checkpoint_dir)
+        else:
+            if '14B' in checkpoint_dir:
+                state_dict = load_file(checkpoint_dir + '/Wan2_1-T2V-14B_fp8_e4m3fn.safetensors', device="cpu")
+            else:
+                state_dict = load_file(checkpoint_dir + '/Wan2_1-T2V-1_3B_fp8_e4m3fn.safetensors', device="cpu")
+
+            dim = state_dict["patch_embedding.weight"].shape[0]
+            in_channels = state_dict["patch_embedding.weight"].shape[1]
+            ffn_dim = state_dict["blocks.0.ffn.0.bias"].shape[0]
+            model_type = "i2v" if in_channels == 36 else "t2v"
+            num_heads = 40 if dim == 5120 else 12
+            num_layers = 40 if dim == 5120 else 30
+            TRANSFORMER_CONFIG = {
+                "dim": dim,
+                "ffn_dim": ffn_dim,
+                "eps": 1e-06,
+                "freq_dim": 256,
+                "in_dim": in_channels,
+                "model_type": model_type,
+                "out_dim": 16,
+                "text_len": 512,
+                "num_heads": num_heads,
+                "num_layers": num_layers,
+            }
+
+            with init_empty_weights():
+                self.model = WanModel(**TRANSFORMER_CONFIG)
+
+            base_dtype = torch.bfloat16
+            dtype = torch.float8_e4m3fn
+            params_to_keep = {"norm", "head", "bias", "time_in", "vector_in", "patch_embedding", "time_", "img_emb", "modulation"}
+            for name, param in self.model.named_parameters():
+                dtype_to_use = base_dtype if any(keyword in name for keyword in params_to_keep) else dtype
+                # dtype_to_use = torch.bfloat16
+                # print("Assigning Parameter name: ", name, " with dtype: ", dtype_to_use)
+                set_module_tensor_to_device(self.model, name, device='cpu', dtype=dtype_to_use, value=state_dict[name])
+
+            del state_dict
+
         self.model.eval().requires_grad_(False)
 
+        if t5_fsdp or dit_fsdp or use_usp:
+            init_on_cpu = False
+
         if use_usp:
             from xfuser.core.distributed import get_sequence_parallel_world_size
 
```
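Because the FP8 checkpoints ship without a `config.json`, the constructor recovers the architecture directly from tensor shapes: `dim` 5120 identifies the 14B model (40 heads, 40 layers), anything else is treated as the 1.3B model (12 heads, 30 layers), and 36 input channels mark an I2V checkpoint. A standalone sketch with dummy tensors shaped like the published Wan2.1-T2V-1.3B weights:

```python
import torch

# Dummy tensors with the published Wan2.1-T2V-1.3B shapes.
state_dict = {
    "patch_embedding.weight": torch.empty(1536, 16, 1, 2, 2),
    "blocks.0.ffn.0.bias": torch.empty(8960),
}

dim = state_dict["patch_embedding.weight"].shape[0]          # 1536
in_channels = state_dict["patch_embedding.weight"].shape[1]  # 16
ffn_dim = state_dict["blocks.0.ffn.0.bias"].shape[0]         # 8960
model_type = "i2v" if in_channels == 36 else "t2v"           # "t2v"
num_heads = 40 if dim == 5120 else 12                        # 12
num_layers = 40 if dim == 5120 else 30                       # 30

print(dim, ffn_dim, model_type, num_heads, num_layers)
```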
```diff
@@ -107,6 +164,7 @@ class WanT2V:
         if dit_fsdp:
             self.model = shard_fn(self.model)
         else:
-            self.model.to(self.device)
+            if not init_on_cpu:
+                self.model.to(self.device)
 
         self.sample_neg_prompt = config.sample_neg_prompt
```
```diff
@@ -173,11 +231,13 @@ class WanT2V:
 
         if not self.t5_cpu:
             self.text_encoder.model.to(self.device)
-            context = self.text_encoder([input_prompt], self.device)
-            context_null = self.text_encoder([n_prompt], self.device)
+            with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True):
+                context = self.text_encoder([input_prompt], self.device)
+                context_null = self.text_encoder([n_prompt], self.device)
             if offload_model:
                 self.text_encoder.model.cpu()
         else:
-            context = self.text_encoder([input_prompt], torch.device('cpu'))
-            context_null = self.text_encoder([n_prompt], torch.device('cpu'))
+            with torch.autocast(device_type="cpu", dtype=torch.bfloat16, enabled=True):
+                context = self.text_encoder([input_prompt], torch.device('cpu'))
+                context_null = self.text_encoder([n_prompt], torch.device('cpu'))
             context = [t.to(self.device) for t in context]
```
```diff
@@ -194,6 +254,9 @@ class WanT2V:
                 generator=seed_g)
         ]
 
+        if offload_model:
+            self.vae.model.cpu()
+
         @contextmanager
         def noop_no_sync():
             yield
```
```diff
@@ -230,13 +293,15 @@ class WanT2V:
             arg_c = {'context': context, 'seq_len': seq_len}
             arg_null = {'context': context_null, 'seq_len': seq_len}
 
+            if offload_model:
+                torch.cuda.empty_cache()
+            self.model.to(self.device)
             for _, t in enumerate(tqdm(timesteps)):
                 latent_model_input = latents
                 timestep = [t]
 
                 timestep = torch.stack(timestep)
 
-                self.model.to(self.device)
                 noise_pred_cond = self.model(
                     latent_model_input, t=timestep, **arg_c)[0]
                 noise_pred_uncond = self.model(
```
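Note the relocation: `self.model.to(self.device)` now runs once before the denoising loop, after an optional cache flush, instead of being re-issued on every step. The general shape of that pattern, as a hedged sketch:

```python
import torch
import torch.nn as nn

def run_denoising(model: nn.Module, timesteps, device, offload_model=True):
    # Free cached allocator blocks before the big transfer, then move the
    # model once, outside the loop, rather than on every step.
    if offload_model:
        torch.cuda.empty_cache()  # no-op when CUDA was never initialized
    model.to(device)
    for t in timesteps:
        pass  # conditional + unconditional forward passes would go here

run_denoising(nn.Linear(4, 4), range(50), device="cpu")
```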
```diff
@@ -257,6 +322,9 @@ class WanT2V:
             if offload_model:
                 self.model.cpu()
                 torch.cuda.empty_cache()
+                # load vae model back to device
+                self.vae.model.to(self.device)
+
             if self.rank == 0:
                 videos = self.vae.decode(x0)
 
```