diff --git a/README.md b/README.md
index 0511a54..697a266 100644
--- a/README.md
+++ b/README.md
@@ -135,6 +135,8 @@ If you encounter OOM (Out-of-Memory) issues, you can use the `--offload_model Tr
 python generate.py --task t2v-1.3B --size 832*480 --ckpt_dir ./Wan2.1-T2V-1.3B --offload_model True --t5_cpu --sample_shift 8 --sample_guide_scale 6 --prompt "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage."
 ```
 
+You can also use the `--fp8` option to enable FP8 precision for reduced memory usage. Make sure to download the [FP8 model weight](https://huggingface.co/Kijai/WanVideo_comfy/resolve/main/Wan2_1-T2V-1_3B_fp8_e4m3fn.safetensors) and place it in the `Wan2.1-T2V-1.3B` folder.
+
 > 💡Note: If you are using the `T2V-1.3B` model, we recommend setting the parameter `--sample_guide_scale 6`. The `--sample_shift parameter` can be adjusted within the range of 8 to 12 based on the performance.
 
@@ -222,6 +224,17 @@ Similar to Text-to-Video, Image-to-Video is also divided into processes with and
 python generate.py --task i2v-14B --size 1280*720 --ckpt_dir ./Wan2.1-I2V-14B-720P --image examples/i2v_input.JPG --prompt "Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard. The fluffy-furred feline gazes directly at the camera with a relaxed expression. Blurred beach scenery forms the background featuring crystal-clear waters, distant green hills, and a blue sky dotted with white clouds. The cat assumes a naturally relaxed posture, as if savoring the sea breeze and warm sunlight. A close-up shot highlights the feline's intricate details and the refreshing atmosphere of the seaside."
 ```
 
+To minimize GPU memory usage, you can enable model offloading with `--offload_model True` and use FP8 precision with `--fp8`.
+
+For example, to run **Wan2.1-I2V-14B-480P** on an RTX 4090 GPU:
+
+1. First, download the [FP8 model weights](https://huggingface.co/Kijai/WanVideo_comfy/resolve/main/Wan2_1-I2V-14B-480P_fp8_e4m3fn.safetensors) and place them in the `Wan2.1-I2V-14B-480P` folder.
+2. Then, execute the following command:
+
+```
+python generate.py --task i2v-14B --size 832*480 --ckpt_dir ./Wan2.1-I2V-14B-480P --offload_model True --fp8 --image examples/i2v_input.JPG --prompt "Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard. The fluffy-furred feline gazes directly at the camera with a relaxed expression. Blurred beach scenery forms the background featuring crystal-clear waters, distant green hills, and a blue sky dotted with white clouds. The cat assumes a naturally relaxed posture, as if savoring the sea breeze and warm sunlight. A close-up shot highlights the feline's intricate details and the refreshing atmosphere of the seaside."
+```
+
 > 💡For the Image-to-Video task, the `size` parameter represents the area of the generated video, with the aspect ratio following that of the original input image.
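As a convenience for the download steps above, the FP8 checkpoints can also be fetched with `huggingface-cli`. The repository and file names below are taken from the links in the README hunks; the command itself is only a suggested workflow, not part of this patch, so adjust the target folders to match your `--ckpt_dir`:

```
huggingface-cli download Kijai/WanVideo_comfy Wan2_1-T2V-1_3B_fp8_e4m3fn.safetensors --local-dir ./Wan2.1-T2V-1.3B
huggingface-cli download Kijai/WanVideo_comfy Wan2_1-I2V-14B-480P_fp8_e4m3fn.safetensors --local-dir ./Wan2.1-I2V-14B-480P
```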
diff --git a/generate.py b/generate.py
index f8bd885..8dd8e05 100644
--- a/generate.py
+++ b/generate.py
@@ -311,6 +311,7 @@ def generate(args):
             dit_fsdp=args.dit_fsdp,
             use_usp=(args.ulysses_size > 1 or args.ring_size > 1),
             t5_cpu=args.t5_cpu,
+            fp8=args.fp8,
         )
 
         logging.info(
diff --git a/wan/image2video.py b/wan/image2video.py
index 8b9e653..84eb6f7 100644
--- a/wan/image2video.py
+++ b/wan/image2video.py
@@ -103,6 +103,7 @@ class WanI2V:
                 config.clip_checkpoint),
             tokenizer_path=os.path.join(checkpoint_dir, config.clip_tokenizer))
 
+        logging.info(f"Creating WanModel from {checkpoint_dir}")
         if not fp8:
             self.model = WanModel.from_pretrained(checkpoint_dir)
         else:
@@ -131,7 +132,6 @@ class WanI2V:
             with init_empty_weights():
                 self.model = WanModel(**TRANSFORMER_CONFIG)
 
-            logging.info(f"Creating WanModel from {checkpoint_dir}")
             base_dtype=torch.bfloat16
             dtype=torch.float8_e4m3fn
 
@@ -382,9 +382,10 @@ class WanI2V:
         if offload_model:
             self.model.cpu()
             torch.cuda.empty_cache()
+        # load vae model back to device
+        self.vae.model.to(self.device)
 
         if self.rank == 0:
-            self.vae.model.to(self.device)
             videos = self.vae.decode(x0, device=self.device)
 
         del noise, latent
diff --git a/wan/text2video.py b/wan/text2video.py
index 2400545..7b0bf3d 100644
--- a/wan/text2video.py
+++ b/wan/text2video.py
@@ -14,6 +14,10 @@ import torch.cuda.amp as amp
 import torch.distributed as dist
 from tqdm import tqdm
 
+from accelerate import init_empty_weights
+from accelerate.utils import set_module_tensor_to_device
+from safetensors.torch import load_file
+
 from .distributed.fsdp import shard_model
 from .modules.model import WanModel
 from .modules.t5 import T5EncoderModel
@@ -35,6 +39,8 @@ class WanT2V:
         dit_fsdp=False,
         use_usp=False,
         t5_cpu=False,
+        init_on_cpu=True,
+        fp8=False,
     ):
         r"""
         Initializes the Wan text-to-video generation model components.
@@ -56,6 +62,8 @@ class WanT2V:
                 Enable distribution strategy of USP.
             t5_cpu (`bool`, *optional*, defaults to False):
                 Whether to place T5 model on CPU. Only works without t5_fsdp.
+            fp8 (`bool`, *optional*, defaults to False):
+                Enable 8-bit floating point precision for model parameters.
""" self.device = torch.device(f"cuda:{device_id}") self.config = config @@ -81,9 +89,52 @@ class WanT2V: device=self.device) logging.info(f"Creating WanModel from {checkpoint_dir}") - self.model = WanModel.from_pretrained(checkpoint_dir) + if not fp8: + self.model = WanModel.from_pretrained(checkpoint_dir) + else: + if '14B' in checkpoint_dir: + state_dict = load_file(checkpoint_dir+'/Wan2_1-T2V-14B_fp8_e4m3fn.safetensors', device="cpu") + else: + state_dict = load_file(checkpoint_dir+'/Wan2_1-T2V-1_3B_fp8_e4m3fn.safetensors', device="cpu") + + dim = state_dict["patch_embedding.weight"].shape[0] + in_channels = state_dict["patch_embedding.weight"].shape[1] + ffn_dim = state_dict["blocks.0.ffn.0.bias"].shape[0] + model_type = "i2v" if in_channels == 36 else "t2v" + num_heads = 40 if dim == 5120 else 12 + num_layers = 40 if dim == 5120 else 30 + TRANSFORMER_CONFIG= { + "dim": dim, + "ffn_dim": ffn_dim, + "eps": 1e-06, + "freq_dim": 256, + "in_dim": in_channels, + "model_type": model_type, + "out_dim": 16, + "text_len": 512, + "num_heads": num_heads, + "num_layers": num_layers, + } + + with init_empty_weights(): + self.model = WanModel(**TRANSFORMER_CONFIG) + + base_dtype=torch.bfloat16 + dtype=torch.float8_e4m3fn + params_to_keep = {"norm", "head", "bias", "time_in", "vector_in", "patch_embedding", "time_", "img_emb", "modulation"} + for name, param in self.model.named_parameters(): + dtype_to_use = base_dtype if any(keyword in name for keyword in params_to_keep) else dtype + # dtype_to_use = torch.bfloat16 + # print("Assigning Parameter name: ", name, " with dtype: ", dtype_to_use) + set_module_tensor_to_device(self.model, name, device='cpu', dtype=dtype_to_use, value=state_dict[name]) + + del state_dict + self.model.eval().requires_grad_(False) + if t5_fsdp or dit_fsdp or use_usp: + init_on_cpu = False + if use_usp: from xfuser.core.distributed import \ get_sequence_parallel_world_size @@ -103,7 +154,8 @@ class WanT2V: if dit_fsdp: self.model = shard_fn(self.model) else: - self.model.to(self.device) + if not init_on_cpu: + self.model.to(self.device) self.sample_neg_prompt = config.sample_neg_prompt @@ -190,6 +242,9 @@ class WanT2V: generator=seed_g) ] + if offload_model: + self.vae.model.cpu() + @contextmanager def noop_no_sync(): yield @@ -226,6 +281,10 @@ class WanT2V: arg_c = {'context': context, 'seq_len': seq_len} arg_null = {'context': context_null, 'seq_len': seq_len} + if offload_model: + torch.cuda.empty_cache() + + self.model.to(self.device) for _, t in enumerate(tqdm(timesteps)): latent_model_input = latents timestep = [t] @@ -253,6 +312,9 @@ class WanT2V: if offload_model: self.model.cpu() torch.cuda.empty_cache() + # load vae model back to device + self.vae.model.to(self.device) + if self.rank == 0: videos = self.vae.decode(x0)