Mirror of https://github.com/Wan-Video/Wan2.1.git (synced 2025-11-04 06:15:17 +00:00)

Merge 36d6d91b90 into 7c81b2f27d
commit 2f26de8423

README.md | 19

@@ -171,6 +171,14 @@ If you encounter OOM (Out-of-Memory) issues, you can use the `--offload_model True`
python generate.py --task t2v-1.3B --size 832*480 --ckpt_dir ./Wan2.1-T2V-1.3B --offload_model True --t5_cpu --sample_shift 8 --sample_guide_scale 6 --prompt "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage."
```

You can also use the `--fp8` option to enable FP8 precision for reduced memory usage. Make sure to download the [FP8 model weight](https://huggingface.co/Kijai/WanVideo_comfy/resolve/main/Wan2_1-T2V-1_3B_fp8_e4m3fn.safetensors) and place it in the `Wan2.1-T2V-1.3B` folder.

Additionally, an [FP8 version of the T5 model](https://huggingface.co/Kijai/WanVideo_comfy/resolve/main/umt5-xxl-enc-fp8_e4m3fn.safetensors) is available. To use the FP8 T5 model, update the configuration file:

```
t2v_1_3B.t5_checkpoint = 'umt5-xxl-enc-fp8_e4m3fn.safetensors'
```

> 💡Note: If you are using the `T2V-1.3B` model, we recommend setting the parameter `--sample_guide_scale 6`. The `--sample_shift` parameter can be adjusted within the range of 8 to 12 based on performance.
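
For example, the flags above can be combined with FP8 like this (a hypothetical invocation, assuming the FP8 DiT weight linked above has been placed in `Wan2.1-T2V-1.3B`):

```
python generate.py --task t2v-1.3B --size 832*480 --ckpt_dir ./Wan2.1-T2V-1.3B --offload_model True --t5_cpu --fp8 --sample_shift 8 --sample_guide_scale 6 --prompt "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage."
```
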
@@ -307,6 +315,17 @@ Similar to Text-to-Video, Image-to-Video is also divided into processes with and
python generate.py --task i2v-14B --size 1280*720 --ckpt_dir ./Wan2.1-I2V-14B-720P --image examples/i2v_input.JPG --prompt "Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard. The fluffy-furred feline gazes directly at the camera with a relaxed expression. Blurred beach scenery forms the background featuring crystal-clear waters, distant green hills, and a blue sky dotted with white clouds. The cat assumes a naturally relaxed posture, as if savoring the sea breeze and warm sunlight. A close-up shot highlights the feline's intricate details and the refreshing atmosphere of the seaside."
```

To minimize GPU memory usage, you can enable model offloading with `--offload_model True` and use FP8 precision with `--fp8`.

For example, to run **Wan2.1-I2V-14B-480P** on an RTX 4090 GPU:

1. First, download the [FP8 model weights](https://huggingface.co/Kijai/WanVideo_comfy/resolve/main/Wan2_1-I2V-14B-480P_fp8_e4m3fn.safetensors) and place them in the `Wan2.1-I2V-14B-480P` folder.
2. Then, execute the following command:

```
python generate.py --task i2v-14B --size 832*480 --ckpt_dir ./Wan2.1-I2V-14B-480P --offload_model True --fp8 --image examples/i2v_input.JPG --prompt "Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard. The fluffy-furred feline gazes directly at the camera with a relaxed expression. Blurred beach scenery forms the background featuring crystal-clear waters, distant green hills, and a blue sky dotted with white clouds. The cat assumes a naturally relaxed posture, as if savoring the sea breeze and warm sunlight. A close-up shot highlights the feline's intricate details and the refreshing atmosphere of the seaside."
```

> 💡For the Image-to-Video task, the `size` parameter represents the area of the generated video, with the aspect ratio following that of the original input image.
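
To make the note above concrete, here is a small hypothetical helper (not part of this patch) showing how a target pixel area plus the input image's aspect ratio determine the output resolution; the actual pipeline additionally snaps the result to model-friendly multiples:

```
# Hypothetical illustration only: --size gives a target pixel area,
# and the output keeps the aspect ratio of the input image.
def i2v_output_resolution(size_area, input_w, input_h):
    aspect = input_w / input_h
    h = round((size_area / aspect) ** 0.5)
    w = round(h * aspect)
    return w, h

# e.g. --size 832*480 (399,360 px) with a 1280x720 input gives roughly (843, 474)
print(i2v_output_resolution(832 * 480, 1280, 720))
```
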
@@ -155,6 +155,11 @@ def _parse_args():
        action="store_true",
        default=False,
        help="Whether to use FSDP for DiT.")
    parser.add_argument(
        "--fp8",
        action="store_true",
        default=False,
        help="Whether to use fp8.")
    parser.add_argument(
        "--save_file",
        type=str,

@@ -366,6 +371,7 @@ def generate(args):
            dit_fsdp=args.dit_fsdp,
            use_usp=(args.ulysses_size > 1 or args.ring_size > 1),
            t5_cpu=args.t5_cpu,
            fp8=args.fp8,
        )

        logging.info(

@@ -423,6 +429,7 @@ def generate(args):
            dit_fsdp=args.dit_fsdp,
            use_usp=(args.ulysses_size > 1 or args.ring_size > 1),
            t5_cpu=args.t5_cpu,
            fp8=args.fp8,
        )

        logging.info("Generating video ...")

@@ -11,12 +11,14 @@ i2v_14B.update(wan_shared_cfg)
i2v_14B.sample_neg_prompt = "镜头晃动," + i2v_14B.sample_neg_prompt

i2v_14B.t5_checkpoint = 'models_t5_umt5-xxl-enc-bf16.pth'
# i2v_14B.t5_checkpoint = 'umt5-xxl-enc-fp8_e4m3fn.safetensors' # fp8 model
i2v_14B.t5_tokenizer = 'google/umt5-xxl'

# clip
i2v_14B.clip_model = 'clip_xlm_roberta_vit_h_14'
i2v_14B.clip_dtype = torch.float16
i2v_14B.clip_checkpoint = 'models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth'
# i2v_14B.clip_checkpoint = 'open-clip-xlm-roberta-large-vit-huge-14_fp16.safetensors' # Kijai's fp16 model
i2v_14B.clip_tokenizer = 'xlm-roberta-large'

# vae

@@ -10,6 +10,7 @@ t2v_14B.update(wan_shared_cfg)

# t5
t2v_14B.t5_checkpoint = 'models_t5_umt5-xxl-enc-bf16.pth'
# t2v_14B.t5_checkpoint = 'umt5-xxl-enc-fp8_e4m3fn.safetensors' # fp8 model
t2v_14B.t5_tokenizer = 'google/umt5-xxl'

# vae

@@ -10,6 +10,7 @@ t2v_1_3B.update(wan_shared_cfg)

# t5
t2v_1_3B.t5_checkpoint = 'models_t5_umt5-xxl-enc-bf16.pth'
# t2v_1_3B.t5_checkpoint = 'umt5-xxl-enc-fp8_e4m3fn.safetensors' # fp8 model
t2v_1_3B.t5_tokenizer = 'google/umt5-xxl'

# vae
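
The three config hunks above add the FP8 T5 checkpoint name only as a commented-out alternative. Enabling it presumably amounts to swapping which `t5_checkpoint` line is active (shown here for the 1.3B config; the 14B and I2V configs follow the same pattern):

```
# t2v_1_3B.t5_checkpoint = 'models_t5_umt5-xxl-enc-bf16.pth'
t2v_1_3B.t5_checkpoint = 'umt5-xxl-enc-fp8_e4m3fn.safetensors'  # fp8 model
```

The loaders below check for this exact filename to switch the text encoder to FP8 quantization.
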
@@ -16,6 +16,10 @@ import torch.distributed as dist
import torchvision.transforms.functional as TF
from tqdm import tqdm

from accelerate import init_empty_weights
from accelerate.utils import set_module_tensor_to_device
from safetensors.torch import load_file

from .distributed.fsdp import shard_model
from .modules.clip import CLIPModel
from .modules.model import WanModel

@@ -42,6 +46,7 @@ class WanI2V:
        use_usp=False,
        t5_cpu=False,
        init_on_cpu=True,
        fp8=False,
    ):
        r"""
        Initializes the image-to-video generation model components.
@@ -65,6 +70,8 @@ class WanI2V:
                Whether to place T5 model on CPU. Only works without t5_fsdp.
            init_on_cpu (`bool`, *optional*, defaults to True):
                Enable initializing Transformer Model on CPU. Only works without FSDP or USP.
            fp8 (`bool`, *optional*, defaults to False):
                Enable 8-bit floating point precision for model parameters.
        """
        self.device = torch.device(f"cuda:{device_id}")
        self.config = config

@@ -76,6 +83,10 @@ class WanI2V:
        self.param_dtype = config.param_dtype

        shard_fn = partial(shard_model, device_id=device_id)
        if config.t5_checkpoint == 'umt5-xxl-enc-fp8_e4m3fn.safetensors':
            quantization = "fp8_e4m3fn"
        else:
            quantization = "disabled"
        self.text_encoder = T5EncoderModel(
            text_len=config.text_len,
            dtype=config.t5_dtype,
@@ -83,10 +94,12 @@ class WanI2V:
            checkpoint_path=os.path.join(checkpoint_dir, config.t5_checkpoint),
            tokenizer_path=os.path.join(checkpoint_dir, config.t5_tokenizer),
            shard_fn=shard_fn if t5_fsdp else None,
            quantization=quantization,
        )

        self.vae_stride = config.vae_stride
        self.patch_size = config.patch_size

        self.vae = WanVAE(
            vae_pth=os.path.join(checkpoint_dir, config.vae_checkpoint),
            device=self.device)

@@ -99,7 +112,46 @@ class WanI2V:
            tokenizer_path=os.path.join(checkpoint_dir, config.clip_tokenizer))

        logging.info(f"Creating WanModel from {checkpoint_dir}")
        self.model = WanModel.from_pretrained(checkpoint_dir)
        if not fp8:
            self.model = WanModel.from_pretrained(checkpoint_dir)
        else:
            if '480P' in checkpoint_dir:
                state_dict = load_file(checkpoint_dir + '/Wan2_1-I2V-14B-480P_fp8_e4m3fn.safetensors', device="cpu")
            elif '720P' in checkpoint_dir:
                state_dict = load_file(checkpoint_dir + '/Wan2_1-I2V-14B-720P_fp8_e4m3fn.safetensors', device="cpu")
            dim = state_dict["patch_embedding.weight"].shape[0]
            in_channels = state_dict["patch_embedding.weight"].shape[1]
            ffn_dim = state_dict["blocks.0.ffn.0.bias"].shape[0]
            model_type = "i2v" if in_channels == 36 else "t2v"
            num_heads = 40 if dim == 5120 else 12
            num_layers = 40 if dim == 5120 else 30
            TRANSFORMER_CONFIG = {
                "dim": dim,
                "ffn_dim": ffn_dim,
                "eps": 1e-06,
                "freq_dim": 256,
                "in_dim": in_channels,
                "model_type": model_type,
                "out_dim": 16,
                "text_len": 512,
                "num_heads": num_heads,
                "num_layers": num_layers,
            }

            with init_empty_weights():
                self.model = WanModel(**TRANSFORMER_CONFIG)

            base_dtype = torch.bfloat16
            dtype = torch.float8_e4m3fn
            params_to_keep = {"norm", "head", "bias", "time_in", "vector_in", "patch_embedding", "time_", "img_emb", "modulation"}
            for name, param in self.model.named_parameters():
                dtype_to_use = base_dtype if any(keyword in name for keyword in params_to_keep) else dtype
                # dtype_to_use = torch.bfloat16
                # print("Assigning Parameter name: ", name, " with dtype: ", dtype_to_use)
                set_module_tensor_to_device(self.model, name, device='cpu', dtype=dtype_to_use, value=state_dict[name])

            del state_dict

        self.model.eval().requires_grad_(False)

        if t5_fsdp or dit_fsdp or use_usp:

@@ -222,13 +274,15 @@ class WanI2V:
        # preprocess
        if not self.t5_cpu:
            self.text_encoder.model.to(self.device)
            context = self.text_encoder([input_prompt], self.device)
            context_null = self.text_encoder([n_prompt], self.device)
            with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True):
                context = self.text_encoder([input_prompt], self.device)
                context_null = self.text_encoder([n_prompt], self.device)
            if offload_model:
                self.text_encoder.model.cpu()
        else:
            context = self.text_encoder([input_prompt], torch.device('cpu'))
            context_null = self.text_encoder([n_prompt], torch.device('cpu'))
            with torch.autocast(device_type="cpu", dtype=torch.bfloat16, enabled=True):
                context = self.text_encoder([input_prompt], torch.device('cpu'))
                context_null = self.text_encoder([n_prompt], torch.device('cpu'))
            context = [t.to(self.device) for t in context]
            context_null = [t.to(self.device) for t in context_null]

@@ -245,9 +299,12 @@ class WanI2V:
                torch.zeros(3, F - 1, h, w)
            ],
                         dim=1).to(self.device)
        ])[0]
        ], device=self.device)[0]
        y = torch.concat([msk, y])

        if offload_model:
            self.vae.model.cpu()

        @contextmanager
        def noop_no_sync():
            yield

@@ -335,9 +392,11 @@ class WanI2V:
            if offload_model:
                self.model.cpu()
                torch.cuda.empty_cache()
                # load vae model back to device
                self.vae.model.to(self.device)

            if self.rank == 0:
                videos = self.vae.decode(x0)
                videos = self.vae.decode(x0, device=self.device)

        del noise, latent
        del sample_scheduler

@@ -7,6 +7,7 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as T
from safetensors.torch import load_file

from .attention import flash_attention
from .tokenizers import HuggingfaceTokenizer
@@ -515,8 +516,13 @@ class CLIPModel:
            device=device)
        self.model = self.model.eval().requires_grad_(False)
        logging.info(f'loading {checkpoint_path}')
        self.model.load_state_dict(
            torch.load(checkpoint_path, map_location='cpu'))
        if checkpoint_path.endswith('.safetensors'):
            state_dict = load_file(checkpoint_path, device='cpu')
            self.model.load_state_dict(state_dict)
        elif checkpoint_path.endswith('.pth'):
            self.model.load_state_dict(torch.load(checkpoint_path, map_location='cpu'))
        else:
            raise ValueError(f'Unsupported checkpoint file format: {checkpoint_path}')

        # init tokenizer
        self.tokenizer = HuggingfaceTokenizer(

@@ -9,6 +9,10 @@ import torch.nn.functional as F

from .tokenizers import HuggingfaceTokenizer

from accelerate import init_empty_weights
from accelerate.utils import set_module_tensor_to_device
from safetensors.torch import load_file

__all__ = [
    'T5Model',
    'T5Encoder',

@@ -442,7 +446,7 @@ def _t5(name,
        model = model_cls(**kwargs)

    # set device
    model = model.to(dtype=dtype, device=device)
    # model = model.to(dtype=dtype, device=device)

    # init tokenizer
    if return_tokenizer:

@@ -479,6 +483,7 @@ class T5EncoderModel:
        checkpoint_path=None,
        tokenizer_path=None,
        shard_fn=None,
        quantization="disabled",
    ):
        self.text_len = text_len
        self.dtype = dtype

@@ -486,14 +491,31 @@ class T5EncoderModel:
        self.checkpoint_path = checkpoint_path
        self.tokenizer_path = tokenizer_path

        # init model
        model = umt5_xxl(
            encoder_only=True,
            return_tokenizer=False,
            dtype=dtype,
            device=device).eval().requires_grad_(False)

        logging.info(f'loading {checkpoint_path}')
        model.load_state_dict(torch.load(checkpoint_path, map_location='cpu'))
        if quantization == "disabled":
            # init model
            model = umt5_xxl(
                encoder_only=True,
                return_tokenizer=False,
                dtype=dtype,
                device=device).eval().requires_grad_(False)
            model.load_state_dict(torch.load(checkpoint_path, map_location='cpu'))
        elif quantization == "fp8_e4m3fn":
            with init_empty_weights():
                model = umt5_xxl(
                    encoder_only=True,
                    return_tokenizer=False,
                    dtype=dtype,
                    device=device).eval().requires_grad_(False)
            cast_dtype = torch.float8_e4m3fn
            state_dict = load_file(checkpoint_path, device="cpu")
            params_to_keep = {'norm', 'pos_embedding', 'token_embedding'}
            for name, param in model.named_parameters():
                dtype_to_use = dtype if any(keyword in name for keyword in params_to_keep) else cast_dtype
                set_module_tensor_to_device(model, name, device=device, dtype=dtype_to_use, value=state_dict[name])
            del state_dict

        self.model = model
        if shard_fn is not None:
            self.model = shard_fn(self.model, sync_module_states=False)

@@ -644,7 +644,7 @@ class WanVAE:
            z_dim=z_dim,
        ).eval().requires_grad_(False).to(device)

    def encode(self, videos):
    def encode(self, videos, device=None):
        """
        videos: A list of videos each with shape [C, T, H, W].
        """
@@ -654,7 +654,7 @@ class WanVAE:
                for u in videos
            ]

    def decode(self, zs):
    def decode(self, zs, device=None):
        with amp.autocast(dtype=self.dtype):
            return [
                self.model.decode(u.unsqueeze(0),

@@ -14,6 +14,10 @@ import torch.cuda.amp as amp
import torch.distributed as dist
from tqdm import tqdm

from accelerate import init_empty_weights
from accelerate.utils import set_module_tensor_to_device
from safetensors.torch import load_file

from .distributed.fsdp import shard_model
from .modules.model import WanModel
from .modules.t5 import T5EncoderModel

@@ -38,6 +42,8 @@ class WanT2V:
        dit_fsdp=False,
        use_usp=False,
        t5_cpu=False,
        init_on_cpu=True,
        fp8=False,
    ):
        r"""
        Initializes the Wan text-to-video generation model components.
@@ -59,6 +65,8 @@ class WanT2V:
                Enable distribution strategy of USP.
            t5_cpu (`bool`, *optional*, defaults to False):
                Whether to place T5 model on CPU. Only works without t5_fsdp.
            fp8 (`bool`, *optional*, defaults to False):
                Enable 8-bit floating point precision for model parameters.
        """
        self.device = torch.device(f"cuda:{device_id}")
        self.config = config

@@ -69,13 +77,19 @@ class WanT2V:
        self.param_dtype = config.param_dtype

        shard_fn = partial(shard_model, device_id=device_id)
        if config.t5_checkpoint == 'umt5-xxl-enc-fp8_e4m3fn.safetensors':
            quantization = "fp8_e4m3fn"
        else:
            quantization = "disabled"

        self.text_encoder = T5EncoderModel(
            text_len=config.text_len,
            dtype=config.t5_dtype,
            device=torch.device('cpu'),
            checkpoint_path=os.path.join(checkpoint_dir, config.t5_checkpoint),
            tokenizer_path=os.path.join(checkpoint_dir, config.t5_tokenizer),
            shard_fn=shard_fn if t5_fsdp else None)
            shard_fn=shard_fn if t5_fsdp else None,
            quantization=quantization)

        self.vae_stride = config.vae_stride
        self.patch_size = config.patch_size

@@ -84,9 +98,52 @@ class WanT2V:
            device=self.device)

        logging.info(f"Creating WanModel from {checkpoint_dir}")
        self.model = WanModel.from_pretrained(checkpoint_dir)
        if not fp8:
            self.model = WanModel.from_pretrained(checkpoint_dir)
        else:
            if '14B' in checkpoint_dir:
                state_dict = load_file(checkpoint_dir + '/Wan2_1-T2V-14B_fp8_e4m3fn.safetensors', device="cpu")
            else:
                state_dict = load_file(checkpoint_dir + '/Wan2_1-T2V-1_3B_fp8_e4m3fn.safetensors', device="cpu")

            dim = state_dict["patch_embedding.weight"].shape[0]
            in_channels = state_dict["patch_embedding.weight"].shape[1]
            ffn_dim = state_dict["blocks.0.ffn.0.bias"].shape[0]
            model_type = "i2v" if in_channels == 36 else "t2v"
            num_heads = 40 if dim == 5120 else 12
            num_layers = 40 if dim == 5120 else 30
            TRANSFORMER_CONFIG = {
                "dim": dim,
                "ffn_dim": ffn_dim,
                "eps": 1e-06,
                "freq_dim": 256,
                "in_dim": in_channels,
                "model_type": model_type,
                "out_dim": 16,
                "text_len": 512,
                "num_heads": num_heads,
                "num_layers": num_layers,
            }

            with init_empty_weights():
                self.model = WanModel(**TRANSFORMER_CONFIG)

            base_dtype = torch.bfloat16
            dtype = torch.float8_e4m3fn
            params_to_keep = {"norm", "head", "bias", "time_in", "vector_in", "patch_embedding", "time_", "img_emb", "modulation"}
            for name, param in self.model.named_parameters():
                dtype_to_use = base_dtype if any(keyword in name for keyword in params_to_keep) else dtype
                # dtype_to_use = torch.bfloat16
                # print("Assigning Parameter name: ", name, " with dtype: ", dtype_to_use)
                set_module_tensor_to_device(self.model, name, device='cpu', dtype=dtype_to_use, value=state_dict[name])

            del state_dict

        self.model.eval().requires_grad_(False)

        if t5_fsdp or dit_fsdp or use_usp:
            init_on_cpu = False

        if use_usp:
            from xfuser.core.distributed import get_sequence_parallel_world_size

@@ -107,7 +164,8 @@ class WanT2V:
        if dit_fsdp:
            self.model = shard_fn(self.model)
        else:
            self.model.to(self.device)
            if not init_on_cpu:
                self.model.to(self.device)

        self.sample_neg_prompt = config.sample_neg_prompt

@@ -173,13 +231,15 @@ class WanT2V:

        if not self.t5_cpu:
            self.text_encoder.model.to(self.device)
            context = self.text_encoder([input_prompt], self.device)
            context_null = self.text_encoder([n_prompt], self.device)
            with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True):
                context = self.text_encoder([input_prompt], self.device)
                context_null = self.text_encoder([n_prompt], self.device)
            if offload_model:
                self.text_encoder.model.cpu()
        else:
            context = self.text_encoder([input_prompt], torch.device('cpu'))
            context_null = self.text_encoder([n_prompt], torch.device('cpu'))
            with torch.autocast(device_type="cpu", dtype=torch.bfloat16, enabled=True):
                context = self.text_encoder([input_prompt], torch.device('cpu'))
                context_null = self.text_encoder([n_prompt], torch.device('cpu'))
            context = [t.to(self.device) for t in context]
            context_null = [t.to(self.device) for t in context_null]

@@ -194,6 +254,9 @@ class WanT2V:
                generator=seed_g)
        ]

        if offload_model:
            self.vae.model.cpu()

        @contextmanager
        def noop_no_sync():
            yield

@@ -230,13 +293,15 @@ class WanT2V:
            arg_c = {'context': context, 'seq_len': seq_len}
            arg_null = {'context': context_null, 'seq_len': seq_len}

            if offload_model:
                torch.cuda.empty_cache()

            self.model.to(self.device)
            for _, t in enumerate(tqdm(timesteps)):
                latent_model_input = latents
                timestep = [t]

                timestep = torch.stack(timestep)

                self.model.to(self.device)
                noise_pred_cond = self.model(
                    latent_model_input, t=timestep, **arg_c)[0]
                noise_pred_uncond = self.model(

@@ -257,6 +322,9 @@ class WanT2V:
            if offload_model:
                self.model.cpu()
                torch.cuda.empty_cache()
                # load vae model back to device
                self.vae.model.to(self.device)

            if self.rank == 0:
                videos = self.vae.decode(x0)