This commit is contained in:
Yexiong Lin 2025-05-27 23:54:39 +10:00 committed by GitHub
commit 77161717c0
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
10 changed files with 213 additions and 28 deletions

View File

@ -166,6 +166,14 @@ If you encounter OOM (Out-of-Memory) issues, you can use the `--offload_model Tr
python generate.py --task t2v-1.3B --size 832*480 --ckpt_dir ./Wan2.1-T2V-1.3B --offload_model True --t5_cpu --sample_shift 8 --sample_guide_scale 6 --prompt "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage."
```
You can also use the `--fp8` option to enable FP8 precision for reduced memory usage. Make sure to download the [FP8 model weight](https://huggingface.co/Kijai/WanVideo_comfy/resolve/main/Wan2_1-T2V-1_3B_fp8_e4m3fn.safetensors) and place it in the `Wan2.1-T2V-1.3B` folder.
Additionally, an [FP8 version of the T5 model](https://huggingface.co/Kijai/WanVideo_comfy/resolve/main/umt5-xxl-enc-fp8_e4m3fn.safetensors) is available. To use the FP8 T5 model, update the configuration file:
```
t2v_1_3B.t5_checkpoint = 'umt5-xxl-enc-fp8_e4m3fn.safetensors'
```
> 💡Note: If you are using the `T2V-1.3B` model, we recommend setting the parameter `--sample_guide_scale 6`. The `--sample_shift parameter` can be adjusted within the range of 8 to 12 based on the performance.
@ -302,6 +310,17 @@ Similar to Text-to-Video, Image-to-Video is also divided into processes with and
python generate.py --task i2v-14B --size 1280*720 --ckpt_dir ./Wan2.1-I2V-14B-720P --image examples/i2v_input.JPG --prompt "Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard. The fluffy-furred feline gazes directly at the camera with a relaxed expression. Blurred beach scenery forms the background featuring crystal-clear waters, distant green hills, and a blue sky dotted with white clouds. The cat assumes a naturally relaxed posture, as if savoring the sea breeze and warm sunlight. A close-up shot highlights the feline's intricate details and the refreshing atmosphere of the seaside."
```
To minimize GPU memory usage, you can enable model offloading with `--offload_model True` and use FP8 precision with `--fp8`.
For example, to run **Wan2.1-I2V-14B-480P** on an RTX 4090 GPU:
1. First, download the [FP8 model weights](https://huggingface.co/Kijai/WanVideo_comfy/resolve/main/Wan2_1-I2V-14B-480P_fp8_e4m3fn.safetensors) and place them in the `Wan2.1-I2V-14B-480P` folder.
2. Then, execute the following command:
```
python generate.py --task i2v-14B --size 832*480 --ckpt_dir ./Wan2.1-I2V-14B-480P --offload_model True --fp8 --image examples/i2v_input.JPG --prompt "Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard. The fluffy-furred feline gazes directly at the camera with a relaxed expression. Blurred beach scenery forms the background featuring crystal-clear waters, distant green hills, and a blue sky dotted with white clouds. The cat assumes a naturally relaxed posture, as if savoring the sea breeze and warm sunlight. A close-up shot highlights the feline's intricate details and the refreshing atmosphere of the seaside."
```
> 💡For the Image-to-Video task, the `size` parameter represents the area of the generated video, with the aspect ratio following that of the original input image.

View File

@ -155,6 +155,11 @@ def _parse_args():
action="store_true",
default=False,
help="Whether to use FSDP for DiT.")
parser.add_argument(
"--fp8",
action="store_true",
default=False,
help="Whether to use fp8.")
parser.add_argument(
"--save_file",
type=str,
@ -366,6 +371,7 @@ def generate(args):
dit_fsdp=args.dit_fsdp,
use_usp=(args.ulysses_size > 1 or args.ring_size > 1),
t5_cpu=args.t5_cpu,
fp8=args.fp8,
)
logging.info(
@ -423,6 +429,7 @@ def generate(args):
dit_fsdp=args.dit_fsdp,
use_usp=(args.ulysses_size > 1 or args.ring_size > 1),
t5_cpu=args.t5_cpu,
fp8=args.fp8,
)
logging.info("Generating video ...")

View File

@ -11,12 +11,14 @@ i2v_14B.update(wan_shared_cfg)
i2v_14B.sample_neg_prompt = "镜头晃动," + i2v_14B.sample_neg_prompt
i2v_14B.t5_checkpoint = 'models_t5_umt5-xxl-enc-bf16.pth'
# i2v_14B.t5_checkpoint = 'umt5-xxl-enc-fp8_e4m3fn.safetensors' # fp8 model
i2v_14B.t5_tokenizer = 'google/umt5-xxl'
# clip
i2v_14B.clip_model = 'clip_xlm_roberta_vit_h_14'
i2v_14B.clip_dtype = torch.float16
i2v_14B.clip_checkpoint = 'models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth'
# i2v_14B.clip_checkpoint = 'open-clip-xlm-roberta-large-vit-huge-14_fp16.safetensors' # Kijai's fp16 model
i2v_14B.clip_tokenizer = 'xlm-roberta-large'
# vae

View File

@ -10,6 +10,7 @@ t2v_14B.update(wan_shared_cfg)
# t5
t2v_14B.t5_checkpoint = 'models_t5_umt5-xxl-enc-bf16.pth'
# t2v_14B.t5_checkpoint = 'umt5-xxl-enc-fp8_e4m3fn.safetensors' # fp8 model
t2v_14B.t5_tokenizer = 'google/umt5-xxl'
# vae

View File

@ -10,6 +10,7 @@ t2v_1_3B.update(wan_shared_cfg)
# t5
t2v_1_3B.t5_checkpoint = 'models_t5_umt5-xxl-enc-bf16.pth'
# t2v_1_3B.t5_checkpoint = 'umt5-xxl-enc-fp8_e4m3fn.safetensors' # fp8 model
t2v_1_3B.t5_tokenizer = 'google/umt5-xxl'
# vae

View File

@ -16,6 +16,10 @@ import torch.distributed as dist
import torchvision.transforms.functional as TF
from tqdm import tqdm
from accelerate import init_empty_weights
from accelerate.utils import set_module_tensor_to_device
from safetensors.torch import load_file
from .distributed.fsdp import shard_model
from .modules.clip import CLIPModel
from .modules.model import WanModel
@ -42,6 +46,7 @@ class WanI2V:
use_usp=False,
t5_cpu=False,
init_on_cpu=True,
fp8=False,
):
r"""
Initializes the image-to-video generation model components.
@ -65,6 +70,8 @@ class WanI2V:
Whether to place T5 model on CPU. Only works without t5_fsdp.
init_on_cpu (`bool`, *optional*, defaults to True):
Enable initializing Transformer Model on CPU. Only works without FSDP or USP.
fp8 (`bool`, *optional*, defaults to False):
Enable 8-bit floating point precision for model parameters.
"""
self.device = torch.device(f"cuda:{device_id}")
self.config = config
@ -76,6 +83,10 @@ class WanI2V:
self.param_dtype = config.param_dtype
shard_fn = partial(shard_model, device_id=device_id)
if config.t5_checkpoint == 'umt5-xxl-enc-fp8_e4m3fn.safetensors':
quantization = "fp8_e4m3fn"
else:
quantization = "disabled"
self.text_encoder = T5EncoderModel(
text_len=config.text_len,
dtype=config.t5_dtype,
@ -83,10 +94,12 @@ class WanI2V:
checkpoint_path=os.path.join(checkpoint_dir, config.t5_checkpoint),
tokenizer_path=os.path.join(checkpoint_dir, config.t5_tokenizer),
shard_fn=shard_fn if t5_fsdp else None,
quantization=quantization,
)
self.vae_stride = config.vae_stride
self.patch_size = config.patch_size
self.vae = WanVAE(
vae_pth=os.path.join(checkpoint_dir, config.vae_checkpoint),
device=self.device)
@ -99,7 +112,46 @@ class WanI2V:
tokenizer_path=os.path.join(checkpoint_dir, config.clip_tokenizer))
logging.info(f"Creating WanModel from {checkpoint_dir}")
self.model = WanModel.from_pretrained(checkpoint_dir)
if not fp8:
self.model = WanModel.from_pretrained(checkpoint_dir)
else:
if '480P' in checkpoint_dir:
state_dict = load_file(checkpoint_dir+'/Wan2_1-I2V-14B-480P_fp8_e4m3fn.safetensors', device="cpu")
elif '720P' in checkpoint_dir:
state_dict = load_file(checkpoint_dir+'/Wan2_1-I2V-14B-720P_fp8_e4m3fn.safetensors', device="cpu")
dim = state_dict["patch_embedding.weight"].shape[0]
in_channels = state_dict["patch_embedding.weight"].shape[1]
ffn_dim = state_dict["blocks.0.ffn.0.bias"].shape[0]
model_type = "i2v" if in_channels == 36 else "t2v"
num_heads = 40 if dim == 5120 else 12
num_layers = 40 if dim == 5120 else 30
TRANSFORMER_CONFIG= {
"dim": dim,
"ffn_dim": ffn_dim,
"eps": 1e-06,
"freq_dim": 256,
"in_dim": in_channels,
"model_type": model_type,
"out_dim": 16,
"text_len": 512,
"num_heads": num_heads,
"num_layers": num_layers,
}
with init_empty_weights():
self.model = WanModel(**TRANSFORMER_CONFIG)
base_dtype=torch.bfloat16
dtype=torch.float8_e4m3fn
params_to_keep = {"norm", "head", "bias", "time_in", "vector_in", "patch_embedding", "time_", "img_emb", "modulation"}
for name, param in self.model.named_parameters():
dtype_to_use = base_dtype if any(keyword in name for keyword in params_to_keep) else dtype
# dtype_to_use = torch.bfloat16
# print("Assigning Parameter name: ", name, " with dtype: ", dtype_to_use)
set_module_tensor_to_device(self.model, name, device='cpu', dtype=dtype_to_use, value=state_dict[name])
del state_dict
self.model.eval().requires_grad_(False)
if t5_fsdp or dit_fsdp or use_usp:
@ -222,13 +274,15 @@ class WanI2V:
# preprocess
if not self.t5_cpu:
self.text_encoder.model.to(self.device)
context = self.text_encoder([input_prompt], self.device)
context_null = self.text_encoder([n_prompt], self.device)
with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True):
context = self.text_encoder([input_prompt], self.device)
context_null = self.text_encoder([n_prompt], self.device)
if offload_model:
self.text_encoder.model.cpu()
else:
context = self.text_encoder([input_prompt], torch.device('cpu'))
context_null = self.text_encoder([n_prompt], torch.device('cpu'))
with torch.autocast(device_type="cpu", dtype=torch.bfloat16, enabled=True):
context = self.text_encoder([input_prompt], torch.device('cpu'))
context_null = self.text_encoder([n_prompt], torch.device('cpu'))
context = [t.to(self.device) for t in context]
context_null = [t.to(self.device) for t in context_null]
@ -245,9 +299,12 @@ class WanI2V:
torch.zeros(3, F - 1, h, w)
],
dim=1).to(self.device)
])[0]
],device=self.device)[0]
y = torch.concat([msk, y])
if offload_model:
self.vae.model.cpu()
@contextmanager
def noop_no_sync():
yield
@ -335,9 +392,11 @@ class WanI2V:
if offload_model:
self.model.cpu()
torch.cuda.empty_cache()
# load vae model back to device
self.vae.model.to(self.device)
if self.rank == 0:
videos = self.vae.decode(x0)
videos = self.vae.decode(x0, device=self.device)
del noise, latent
del sample_scheduler

View File

@ -7,6 +7,7 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as T
from safetensors.torch import load_file
from .attention import flash_attention
from .tokenizers import HuggingfaceTokenizer
@ -515,8 +516,13 @@ class CLIPModel:
device=device)
self.model = self.model.eval().requires_grad_(False)
logging.info(f'loading {checkpoint_path}')
self.model.load_state_dict(
torch.load(checkpoint_path, map_location='cpu'))
if checkpoint_path.endswith('.safetensors'):
state_dict = load_file(checkpoint_path, device='cpu')
self.model.load_state_dict(state_dict)
elif checkpoint_path.endswith('.pth'):
self.model.load_state_dict(torch.load(checkpoint_path, map_location='cpu'))
else:
raise ValueError(f'Unsupported checkpoint file format: {checkpoint_path}')
# init tokenizer
self.tokenizer = HuggingfaceTokenizer(

View File

@ -9,6 +9,10 @@ import torch.nn.functional as F
from .tokenizers import HuggingfaceTokenizer
from accelerate import init_empty_weights
from accelerate.utils import set_module_tensor_to_device
from safetensors.torch import load_file
__all__ = [
'T5Model',
'T5Encoder',
@ -442,7 +446,7 @@ def _t5(name,
model = model_cls(**kwargs)
# set device
model = model.to(dtype=dtype, device=device)
# model = model.to(dtype=dtype, device=device)
# init tokenizer
if return_tokenizer:
@ -479,6 +483,7 @@ class T5EncoderModel:
checkpoint_path=None,
tokenizer_path=None,
shard_fn=None,
quantization="disabled",
):
self.text_len = text_len
self.dtype = dtype
@ -486,14 +491,31 @@ class T5EncoderModel:
self.checkpoint_path = checkpoint_path
self.tokenizer_path = tokenizer_path
# init model
model = umt5_xxl(
encoder_only=True,
return_tokenizer=False,
dtype=dtype,
device=device).eval().requires_grad_(False)
logging.info(f'loading {checkpoint_path}')
model.load_state_dict(torch.load(checkpoint_path, map_location='cpu'))
if quantization == "disabled":
# init model
model = umt5_xxl(
encoder_only=True,
return_tokenizer=False,
dtype=dtype,
device=device).eval().requires_grad_(False)
model.load_state_dict(torch.load(checkpoint_path, map_location='cpu'))
elif quantization == "fp8_e4m3fn":
with init_empty_weights():
model = umt5_xxl(
encoder_only=True,
return_tokenizer=False,
dtype=dtype,
device=device).eval().requires_grad_(False)
cast_dtype = torch.float8_e4m3fn
state_dict = load_file(checkpoint_path, device="cpu")
params_to_keep = {'norm', 'pos_embedding', 'token_embedding'}
for name, param in model.named_parameters():
dtype_to_use = dtype if any(keyword in name for keyword in params_to_keep) else cast_dtype
set_module_tensor_to_device(model, name, device=device, dtype=dtype_to_use, value=state_dict[name])
del state_dict
self.model = model
if shard_fn is not None:
self.model = shard_fn(self.model, sync_module_states=False)

View File

@ -644,7 +644,7 @@ class WanVAE:
z_dim=z_dim,
).eval().requires_grad_(False).to(device)
def encode(self, videos):
def encode(self, videos, device=None):
"""
videos: A list of videos each with shape [C, T, H, W].
"""
@ -654,7 +654,7 @@ class WanVAE:
for u in videos
]
def decode(self, zs):
def decode(self, zs, device=None):
with amp.autocast(dtype=self.dtype):
return [
self.model.decode(u.unsqueeze(0),

View File

@ -14,6 +14,10 @@ import torch.cuda.amp as amp
import torch.distributed as dist
from tqdm import tqdm
from accelerate import init_empty_weights
from accelerate.utils import set_module_tensor_to_device
from safetensors.torch import load_file
from .distributed.fsdp import shard_model
from .modules.model import WanModel
from .modules.t5 import T5EncoderModel
@ -38,6 +42,8 @@ class WanT2V:
dit_fsdp=False,
use_usp=False,
t5_cpu=False,
init_on_cpu=True,
fp8=False,
):
r"""
Initializes the Wan text-to-video generation model components.
@ -59,6 +65,8 @@ class WanT2V:
Enable distribution strategy of USP.
t5_cpu (`bool`, *optional*, defaults to False):
Whether to place T5 model on CPU. Only works without t5_fsdp.
fp8 (`bool`, *optional*, defaults to False):
Enable 8-bit floating point precision for model parameters.
"""
self.device = torch.device(f"cuda:{device_id}")
self.config = config
@ -69,13 +77,19 @@ class WanT2V:
self.param_dtype = config.param_dtype
shard_fn = partial(shard_model, device_id=device_id)
if config.t5_checkpoint == 'umt5-xxl-enc-fp8_e4m3fn.safetensors':
quantization = "fp8_e4m3fn"
else:
quantization = "disabled"
self.text_encoder = T5EncoderModel(
text_len=config.text_len,
dtype=config.t5_dtype,
device=torch.device('cpu'),
checkpoint_path=os.path.join(checkpoint_dir, config.t5_checkpoint),
tokenizer_path=os.path.join(checkpoint_dir, config.t5_tokenizer),
shard_fn=shard_fn if t5_fsdp else None)
shard_fn=shard_fn if t5_fsdp else None,
quantization=quantization)
self.vae_stride = config.vae_stride
self.patch_size = config.patch_size
@ -84,9 +98,52 @@ class WanT2V:
device=self.device)
logging.info(f"Creating WanModel from {checkpoint_dir}")
self.model = WanModel.from_pretrained(checkpoint_dir)
if not fp8:
self.model = WanModel.from_pretrained(checkpoint_dir)
else:
if '14B' in checkpoint_dir:
state_dict = load_file(checkpoint_dir+'/Wan2_1-T2V-14B_fp8_e4m3fn.safetensors', device="cpu")
else:
state_dict = load_file(checkpoint_dir+'/Wan2_1-T2V-1_3B_fp8_e4m3fn.safetensors', device="cpu")
dim = state_dict["patch_embedding.weight"].shape[0]
in_channels = state_dict["patch_embedding.weight"].shape[1]
ffn_dim = state_dict["blocks.0.ffn.0.bias"].shape[0]
model_type = "i2v" if in_channels == 36 else "t2v"
num_heads = 40 if dim == 5120 else 12
num_layers = 40 if dim == 5120 else 30
TRANSFORMER_CONFIG= {
"dim": dim,
"ffn_dim": ffn_dim,
"eps": 1e-06,
"freq_dim": 256,
"in_dim": in_channels,
"model_type": model_type,
"out_dim": 16,
"text_len": 512,
"num_heads": num_heads,
"num_layers": num_layers,
}
with init_empty_weights():
self.model = WanModel(**TRANSFORMER_CONFIG)
base_dtype=torch.bfloat16
dtype=torch.float8_e4m3fn
params_to_keep = {"norm", "head", "bias", "time_in", "vector_in", "patch_embedding", "time_", "img_emb", "modulation"}
for name, param in self.model.named_parameters():
dtype_to_use = base_dtype if any(keyword in name for keyword in params_to_keep) else dtype
# dtype_to_use = torch.bfloat16
# print("Assigning Parameter name: ", name, " with dtype: ", dtype_to_use)
set_module_tensor_to_device(self.model, name, device='cpu', dtype=dtype_to_use, value=state_dict[name])
del state_dict
self.model.eval().requires_grad_(False)
if t5_fsdp or dit_fsdp or use_usp:
init_on_cpu = False
if use_usp:
from xfuser.core.distributed import get_sequence_parallel_world_size
@ -107,7 +164,8 @@ class WanT2V:
if dit_fsdp:
self.model = shard_fn(self.model)
else:
self.model.to(self.device)
if not init_on_cpu:
self.model.to(self.device)
self.sample_neg_prompt = config.sample_neg_prompt
@ -173,13 +231,15 @@ class WanT2V:
if not self.t5_cpu:
self.text_encoder.model.to(self.device)
context = self.text_encoder([input_prompt], self.device)
context_null = self.text_encoder([n_prompt], self.device)
with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True):
context = self.text_encoder([input_prompt], self.device)
context_null = self.text_encoder([n_prompt], self.device)
if offload_model:
self.text_encoder.model.cpu()
else:
context = self.text_encoder([input_prompt], torch.device('cpu'))
context_null = self.text_encoder([n_prompt], torch.device('cpu'))
with torch.autocast(device_type="cpu", dtype=torch.bfloat16, enabled=True):
context = self.text_encoder([input_prompt], torch.device('cpu'))
context_null = self.text_encoder([n_prompt], torch.device('cpu'))
context = [t.to(self.device) for t in context]
context_null = [t.to(self.device) for t in context_null]
@ -194,6 +254,9 @@ class WanT2V:
generator=seed_g)
]
if offload_model:
self.vae.model.cpu()
@contextmanager
def noop_no_sync():
yield
@ -230,13 +293,15 @@ class WanT2V:
arg_c = {'context': context, 'seq_len': seq_len}
arg_null = {'context': context_null, 'seq_len': seq_len}
if offload_model:
torch.cuda.empty_cache()
self.model.to(self.device)
for _, t in enumerate(tqdm(timesteps)):
latent_model_input = latents
timestep = [t]
timestep = torch.stack(timestep)
self.model.to(self.device)
noise_pred_cond = self.model(
latent_model_input, t=timestep, **arg_c)[0]
noise_pred_uncond = self.model(
@ -257,6 +322,9 @@ class WanT2V:
if offload_model:
self.model.cpu()
torch.cuda.empty_cache()
# load vae model back to device
self.vae.model.to(self.device)
if self.rank == 0:
videos = self.vae.decode(x0)