Mirror of https://github.com/Wan-Video/Wan2.1.git, synced 2025-06-07 15:54:52 +00:00

Update README.md and text2video.py to offload the model and enable FP8

This commit is contained in:
parent 24007c2c39
commit db54b7c613

README.md, 13 lines changed
README.md, text-to-video section:

````diff
@@ -135,6 +135,8 @@ If you encounter OOM (Out-of-Memory) issues, you can use the `--offload_model True`
 python generate.py --task t2v-1.3B --size 832*480 --ckpt_dir ./Wan2.1-T2V-1.3B --offload_model True --t5_cpu --sample_shift 8 --sample_guide_scale 6 --prompt "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage."
 ```
 
+You can also use the `--fp8` option to enable FP8 precision for reduced memory usage. Make sure to download the [FP8 model weight](https://huggingface.co/Kijai/WanVideo_comfy/resolve/main/Wan2_1-T2V-1_3B_fp8_e4m3fn.safetensors) and place it in the `Wan2.1-T2V-1.3B` folder.
+
 > 💡Note: If you are using the `T2V-1.3B` model, we recommend setting the parameter `--sample_guide_scale 6`. The `--sample_shift parameter` can be adjusted within the range of 8 to 12 based on the performance.
````
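For the 1.3B text-to-video model, the new option slots into the existing low-memory command. A plausible invocation, assuming the FP8 weight file above is already in the checkpoint folder and that `--fp8` composes with the other flags the same way it does in the image-to-video example added later in this commit:

```
python generate.py --task t2v-1.3B --size 832*480 --ckpt_dir ./Wan2.1-T2V-1.3B --offload_model True --t5_cpu --fp8 --sample_shift 8 --sample_guide_scale 6 --prompt "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage."
```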
README.md, image-to-video section:

````diff
@@ -222,6 +224,17 @@ Similar to Text-to-Video, Image-to-Video is also divided into processes with and
 python generate.py --task i2v-14B --size 1280*720 --ckpt_dir ./Wan2.1-I2V-14B-720P --image examples/i2v_input.JPG --prompt "Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard. The fluffy-furred feline gazes directly at the camera with a relaxed expression. Blurred beach scenery forms the background featuring crystal-clear waters, distant green hills, and a blue sky dotted with white clouds. The cat assumes a naturally relaxed posture, as if savoring the sea breeze and warm sunlight. A close-up shot highlights the feline's intricate details and the refreshing atmosphere of the seaside."
 ```
 
+To minimize GPU memory usage, you can enable model offloading with `--offload_model True` and use FP8 precision with `--fp8`.
+
+For example, to run **Wan2.1-I2V-14B-480P** on an RTX 4090 GPU:
+
+1. First, download the [FP8 model weights](https://huggingface.co/Kijai/WanVideo_comfy/resolve/main/Wan2_1-I2V-14B-480P_fp8_e4m3fn.safetensors) and place them in the `Wan2.1-I2V-14B-480P` folder.
+2. Then, execute the following command:
+
+```
+python generate.py --task i2v-14B --size 832*480 --ckpt_dir ./Wan2.1-I2V-14B-480P --offload_model True --fp8 --image examples/i2v_input.JPG --prompt "Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard. The fluffy-furred feline gazes directly at the camera with a relaxed expression. Blurred beach scenery forms the background featuring crystal-clear waters, distant green hills, and a blue sky dotted with white clouds. The cat assumes a naturally relaxed posture, as if savoring the sea breeze and warm sunlight. A close-up shot highlights the feline's intricate details and the refreshing atmosphere of the seaside."
+```
+
 > 💡For the Image-to-Video task, the `size` parameter represents the area of the generated video, with the aspect ratio following that of the original input image.
````
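The FP8 weight file referenced above can also be fetched programmatically; this is one way to do it, assuming the `huggingface_hub` package is installed (any download method that places the file in the checkpoint folder works equally well):

```python
# Sketch: download the FP8 checkpoint into the folder the README expects.
# Assumes `pip install huggingface_hub`; the repo id and filename come from
# the URL quoted in the README snippet above.
from huggingface_hub import hf_hub_download

hf_hub_download(
    repo_id="Kijai/WanVideo_comfy",
    filename="Wan2_1-I2V-14B-480P_fp8_e4m3fn.safetensors",
    local_dir="./Wan2.1-I2V-14B-480P",
)
```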
The remaining hunks wire the new flag through the command-line entry point and the two pipeline classes. In `def generate(args)`, the parsed value is forwarded to the pipeline constructor:

```diff
@@ -311,6 +311,7 @@ def generate(args):
             dit_fsdp=args.dit_fsdp,
             use_usp=(args.ulysses_size > 1 or args.ring_size > 1),
             t5_cpu=args.t5_cpu,
+            fp8=args.fp8,
         )

         logging.info(
```
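The definition of the `--fp8` command-line flag itself is not shown in this diff. A hypothetical sketch of how such a boolean flag could be registered so that `args.fp8` (used in the hunk above) exists; the bare `--fp8` usage in the README suggests a store-true switch, but this is an assumption rather than the repo's actual code:

```python
# Hypothetical sketch only: not part of this commit's diff.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--fp8",
    action="store_true",
    default=False,
    help="Load the DiT weights from an FP8 (e4m3fn) safetensors checkpoint "
         "to reduce GPU memory usage.")

args = parser.parse_args(["--fp8"])  # simulate `generate.py ... --fp8`
print(args.fp8)  # True
```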
In the `WanI2V` pipeline, the "Creating WanModel" log line moves in front of the `if not fp8:` branch so it is emitted on both paths:

```diff
@@ -103,6 +103,7 @@ class WanI2V:
                 config.clip_checkpoint),
             tokenizer_path=os.path.join(checkpoint_dir, config.clip_tokenizer))

+        logging.info(f"Creating WanModel from {checkpoint_dir}")
         if not fp8:
             self.model = WanModel.from_pretrained(checkpoint_dir)
         else:
@@ -131,7 +132,6 @@ class WanI2V:

             with init_empty_weights():
                 self.model = WanModel(**TRANSFORMER_CONFIG)
-            logging.info(f"Creating WanModel from {checkpoint_dir}")

             base_dtype = torch.bfloat16
             dtype = torch.float8_e4m3fn
```
After sampling, when `offload_model` is set the DiT is moved to the CPU, and the VAE is now moved back to the GPU on every rank (previously only on rank 0) before decoding:

```diff
@@ -382,9 +382,10 @@ class WanI2V:
             if offload_model:
                 self.model.cpu()
                 torch.cuda.empty_cache()
+            # load vae model back to device
+            self.vae.model.to(self.device)

             if self.rank == 0:
-                self.vae.model.to(self.device)
                 videos = self.vae.decode(x0, device=self.device)

         del noise, latent
```
In `WanT2V` (text2video.py), the FP8 path brings in `accelerate` (empty-weight initialization and per-tensor placement) and `safetensors` as new imports, so both become runtime dependencies when `--fp8` is used:

```diff
@@ -14,6 +14,10 @@ import torch.cuda.amp as amp
 import torch.distributed as dist
 from tqdm import tqdm

+from accelerate import init_empty_weights
+from accelerate.utils import set_module_tensor_to_device
+from safetensors.torch import load_file
+
 from .distributed.fsdp import shard_model
 from .modules.model import WanModel
 from .modules.t5 import T5EncoderModel
```
The `WanT2V` constructor gains `init_on_cpu` and `fp8` keyword arguments:

```diff
@@ -35,6 +39,8 @@ class WanT2V:
         dit_fsdp=False,
         use_usp=False,
         t5_cpu=False,
+        init_on_cpu=True,
+        fp8=False,
     ):
         r"""
         Initializes the Wan text-to-video generation model components.
```
```diff
@@ -56,6 +62,8 @@ class WanT2V:
                 Enable distribution strategy of USP.
             t5_cpu (`bool`, *optional*, defaults to False):
                 Whether to place T5 model on CPU. Only works without t5_fsdp.
+            fp8 (`bool`, *optional*, defaults to False):
+                Enable 8-bit floating point precision for model parameters.
         """
         self.device = torch.device(f"cuda:{device_id}")
         self.config = config
```
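For orientation, a hedged construction sketch using the new keyword arguments. The positional arguments mirror how `generate.py` builds the pipeline, and `wan.WanT2V` / `wan.configs.WAN_CONFIGS` are assumed to follow the repo's usual package layout; the FP8 safetensors file must already sit inside `checkpoint_dir`:

```python
# Hedged usage sketch, not a verbatim call site from this commit.
import wan
from wan.configs import WAN_CONFIGS  # assumed task-config mapping

pipeline = wan.WanT2V(
    config=WAN_CONFIGS["t2v-1.3B"],
    checkpoint_dir="./Wan2.1-T2V-1.3B",
    device_id=0,
    rank=0,
    t5_cpu=True,        # keep the T5 encoder on CPU, as in the README example
    init_on_cpu=True,   # new: build the DiT on CPU and move it to GPU later
    fp8=True,           # new: load the Wan2_1-T2V-1_3B_fp8_e4m3fn.safetensors weights
)
```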
The core of the change: when `fp8=True`, the weights are read from an FP8 (e4m3fn) safetensors file instead of `WanModel.from_pretrained`, the transformer configuration is inferred from tensor shapes, the model is instantiated with empty (meta) weights, and each parameter is then materialized on the CPU in either bf16 (for numerically sensitive tensors) or float8_e4m3fn. FSDP or USP forces `init_on_cpu` off:

```diff
@@ -81,9 +89,52 @@ class WanT2V:
             device=self.device)

         logging.info(f"Creating WanModel from {checkpoint_dir}")
-        self.model = WanModel.from_pretrained(checkpoint_dir)
+        if not fp8:
+            self.model = WanModel.from_pretrained(checkpoint_dir)
+        else:
+            if '14B' in checkpoint_dir:
+                state_dict = load_file(checkpoint_dir + '/Wan2_1-T2V-14B_fp8_e4m3fn.safetensors', device="cpu")
+            else:
+                state_dict = load_file(checkpoint_dir + '/Wan2_1-T2V-1_3B_fp8_e4m3fn.safetensors', device="cpu")
+
+            dim = state_dict["patch_embedding.weight"].shape[0]
+            in_channels = state_dict["patch_embedding.weight"].shape[1]
+            ffn_dim = state_dict["blocks.0.ffn.0.bias"].shape[0]
+            model_type = "i2v" if in_channels == 36 else "t2v"
+            num_heads = 40 if dim == 5120 else 12
+            num_layers = 40 if dim == 5120 else 30
+            TRANSFORMER_CONFIG = {
+                "dim": dim,
+                "ffn_dim": ffn_dim,
+                "eps": 1e-06,
+                "freq_dim": 256,
+                "in_dim": in_channels,
+                "model_type": model_type,
+                "out_dim": 16,
+                "text_len": 512,
+                "num_heads": num_heads,
+                "num_layers": num_layers,
+            }
+
+            with init_empty_weights():
+                self.model = WanModel(**TRANSFORMER_CONFIG)
+
+            base_dtype = torch.bfloat16
+            dtype = torch.float8_e4m3fn
+            params_to_keep = {"norm", "head", "bias", "time_in", "vector_in", "patch_embedding", "time_", "img_emb", "modulation"}
+            for name, param in self.model.named_parameters():
+                dtype_to_use = base_dtype if any(keyword in name for keyword in params_to_keep) else dtype
+                # dtype_to_use = torch.bfloat16
+                # print("Assigning Parameter name: ", name, " with dtype: ", dtype_to_use)
+                set_module_tensor_to_device(self.model, name, device='cpu', dtype=dtype_to_use, value=state_dict[name])
+
+            del state_dict
+
         self.model.eval().requires_grad_(False)

+        if t5_fsdp or dit_fsdp or use_usp:
+            init_on_cpu = False
+
         if use_usp:
             from xfuser.core.distributed import \
                 get_sequence_parallel_world_size
```
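The pattern above (meta-device construction followed by per-tensor materialization in mixed precision) is general. A self-contained sketch of the same technique on a toy module, assuming `accelerate` and `safetensors`-style weights are available; the module and layer names are made up for illustration, and only `init_empty_weights` and `set_module_tensor_to_device` are the real `accelerate` APIs used in the diff:

```python
# Standalone illustration of the FP8 loading trick on a toy module.
import torch
import torch.nn as nn
from accelerate import init_empty_weights
from accelerate.utils import set_module_tensor_to_device

class Tiny(nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(64, 64)   # bulk weights -> fp8
        self.norm = nn.LayerNorm(64)    # sensitive weights -> bf16

# 1) Build the module on the meta device: structure only, no weight memory.
with init_empty_weights():
    model = Tiny()
model.eval().requires_grad_(False)  # inference only, mirrors the diff

# Stand-in for load_file(...): a CPU state dict in full precision.
state_dict = {name: torch.randn(p.shape, dtype=torch.float32)
              for name, p in model.named_parameters()}

# 2) Materialize each tensor, keeping numerically sensitive ones in bf16
#    and casting the rest to float8_e4m3fn (the diff keeps "norm", "bias",
#    "patch_embedding", etc. in bf16).
params_to_keep = {"norm", "bias"}
base_dtype, low_dtype = torch.bfloat16, torch.float8_e4m3fn

for name, _ in model.named_parameters():
    dtype_to_use = base_dtype if any(k in name for k in params_to_keep) else low_dtype
    set_module_tensor_to_device(model, name, device="cpu",
                                dtype=dtype_to_use, value=state_dict[name])

print({n: p.dtype for n, p in model.named_parameters()})
```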
Moving the DiT to the GPU at construction time is now gated on `init_on_cpu`:

```diff
@@ -103,7 +154,8 @@ class WanT2V:
         if dit_fsdp:
             self.model = shard_fn(self.model)
         else:
-            self.model.to(self.device)
+            if not init_on_cpu:
+                self.model.to(self.device)

         self.sample_neg_prompt = config.sample_neg_prompt

```
The remaining hunks in `generate` implement the offload flow: the VAE is parked on the CPU while the DiT denoises, the DiT is moved to the GPU just before the sampling loop, and afterwards it returns to the CPU while the VAE comes back to the GPU for decoding.

```diff
@@ -190,6 +242,9 @@ class WanT2V:
                     generator=seed_g)
             ]

+        if offload_model:
+            self.vae.model.cpu()
+
         @contextmanager
         def noop_no_sync():
             yield
```
```diff
@@ -226,6 +281,10 @@ class WanT2V:
             arg_c = {'context': context, 'seq_len': seq_len}
             arg_null = {'context': context_null, 'seq_len': seq_len}

+            if offload_model:
+                torch.cuda.empty_cache()
+
+            self.model.to(self.device)
             for _, t in enumerate(tqdm(timesteps)):
                 latent_model_input = latents
                 timestep = [t]
```
```diff
@@ -253,6 +312,9 @@ class WanT2V:
             if offload_model:
                 self.model.cpu()
                 torch.cuda.empty_cache()
+            # load vae model back to device
+            self.vae.model.to(self.device)
+
             if self.rank == 0:
                 videos = self.vae.decode(x0)

```
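Taken together, these offload hunks follow a simple device round-trip. A generic, hedged sketch of the same pattern, with toy modules standing in for the DiT and the VAE (none of the names below come from the repo):

```python
# Generic sketch of the offload_model flow added above.
import torch
import torch.nn as nn

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
dit = nn.Linear(16, 16)   # stand-in for the diffusion transformer
vae = nn.Linear(16, 16)   # stand-in for the VAE decoder

def generate(offload_model=True):
    if offload_model:
        vae.cpu()                      # park the VAE while denoising
        torch.cuda.empty_cache()       # no-op if CUDA is unavailable

    dit.to(device)                     # DiT on GPU only for the sampling loop
    latents = torch.randn(1, 16, device=device)
    for _ in range(4):                 # stand-in for the denoising timesteps
        latents = dit(latents)

    if offload_model:
        dit.cpu()                      # free the DiT's VRAM...
        torch.cuda.empty_cache()
    vae.to(device)                     # ...and bring the VAE back to decode
    return vae(latents)

print(generate().shape)
```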