diff --git a/wan/configs/wan_i2v_14B.py b/wan/configs/wan_i2v_14B.py
index 12e8e20..b2ab367 100644
--- a/wan/configs/wan_i2v_14B.py
+++ b/wan/configs/wan_i2v_14B.py
@@ -15,7 +15,8 @@ i2v_14B.t5_tokenizer = 'google/umt5-xxl'
 # clip
 i2v_14B.clip_model = 'clip_xlm_roberta_vit_h_14'
 i2v_14B.clip_dtype = torch.float16
-i2v_14B.clip_checkpoint = 'models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth'
+# i2v_14B.clip_checkpoint = 'models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth'
+i2v_14B.clip_checkpoint = 'open-clip-xlm-roberta-large-vit-huge-14_fp16.safetensors'
 i2v_14B.clip_tokenizer = 'xlm-roberta-large'
 
 # vae
diff --git a/wan/image2video.py b/wan/image2video.py
index 468f17c..df63624 100644
--- a/wan/image2video.py
+++ b/wan/image2video.py
@@ -16,6 +16,10 @@ import torch.distributed as dist
 import torchvision.transforms.functional as TF
 from tqdm import tqdm
 
+from accelerate import init_empty_weights
+from accelerate.utils import set_module_tensor_to_device
+from safetensors.torch import load_file
+
 from .distributed.fsdp import shard_model
 from .modules.clip import CLIPModel
 from .modules.model import WanModel
@@ -84,6 +88,8 @@ class WanI2V:
 
         self.vae_stride = config.vae_stride
         self.patch_size = config.patch_size
+        print('device:', self.device)
+
         self.vae = WanVAE(
             vae_pth=os.path.join(checkpoint_dir, config.vae_checkpoint),
             device=self.device)
@@ -95,8 +101,44 @@
                                          config.clip_checkpoint),
             tokenizer_path=os.path.join(checkpoint_dir,
                                         config.clip_tokenizer))
+        state_dict = load_file(checkpoint_dir+'/Wan2_1-I2V-14B-480P_fp8_e4m3fn.safetensors', device="cpu")
+        dim = state_dict["patch_embedding.weight"].shape[0]
+        in_channels = state_dict["patch_embedding.weight"].shape[1]
+        print("in_channels: ", in_channels)
+        ffn_dim = state_dict["blocks.0.ffn.0.bias"].shape[0]
+        model_type = "i2v" if in_channels == 36 else "t2v"
+        num_heads = 40 if dim == 5120 else 12
+        num_layers = 40 if dim == 5120 else 30
+        TRANSFORMER_CONFIG= {
+            "dim": dim,
+            "ffn_dim": ffn_dim,
+            "eps": 1e-06,
+            "freq_dim": 256,
+            "in_dim": in_channels,
+            "model_type": model_type,
+            "out_dim": 16,
+            "text_len": 512,
+            "num_heads": num_heads,
+            "num_layers": num_layers,
+        }
+
+        with init_empty_weights():
+            self.model = WanModel(**TRANSFORMER_CONFIG)
         logging.info(f"Creating WanModel from {checkpoint_dir}")
-        self.model = WanModel.from_pretrained(checkpoint_dir)
+
+        base_dtype=torch.bfloat16
+        dtype=torch.float8_e4m3fn
+        params_to_keep = {"norm", "head", "bias", "time_in", "vector_in", "patch_embedding", "time_", "img_emb", "modulation"}
+        for name, param in self.model.named_parameters():
+            dtype_to_use = base_dtype if any(keyword in name for keyword in params_to_keep) else dtype
+            # dtype_to_use = torch.bfloat16
+            # print("Assigning Parameter name: ", name, " with dtype: ", dtype_to_use)
+            set_module_tensor_to_device(self.model, name, device='cpu', dtype=dtype_to_use, value=state_dict[name])
+
+        del state_dict
+        print("Model loaded successfully")
+
+        # self.model = WanModel.from_pretrained(checkpoint_dir)
         self.model.eval().requires_grad_(False)
 
         if t5_fsdp or dit_fsdp or use_usp:
@@ -242,9 +284,12 @@
                 torch.zeros(3, 80, h, w)
             ],
                          dim=1).to(self.device)
-        ])[0]
+        ],device=self.device)[0]
         y = torch.concat([msk, y])
 
+        if offload_model:
+            self.vae.model.cpu()
+
         @contextmanager
         def noop_no_sync():
             yield
@@ -334,7 +379,8 @@
                     torch.cuda.empty_cache()
             if self.rank == 0:
-                videos = self.vae.decode(x0)
+                self.vae.model.to(self.device)
+                videos = self.vae.decode(x0, device=self.device)
 
         del noise, latent
         del sample_scheduler
diff --git a/wan/modules/clip.py b/wan/modules/clip.py
index 42dda04..ceccafc 100644
--- a/wan/modules/clip.py
+++ b/wan/modules/clip.py
@@ -7,6 +7,7 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 import torchvision.transforms as T
+from safetensors.torch import load_file
 
 from .attention import flash_attention
 from .tokenizers import HuggingfaceTokenizer
@@ -515,8 +516,10 @@
             device=device)
         self.model = self.model.eval().requires_grad_(False)
         logging.info(f'loading {checkpoint_path}')
-        self.model.load_state_dict(
-            torch.load(checkpoint_path, map_location='cpu'))
+        state_dict = load_file(checkpoint_path, device='cpu')
+        self.model.load_state_dict(state_dict)
+        # self.model.load_state_dict(
+        #     torch.load(checkpoint_path, map_location='cpu'))
 
         # init tokenizer
         self.tokenizer = HuggingfaceTokenizer(
diff --git a/wan/modules/vae.py b/wan/modules/vae.py
index 5c6da57..dd98b42 100644
--- a/wan/modules/vae.py
+++ b/wan/modules/vae.py
@@ -644,7 +644,7 @@ class WanVAE:
             z_dim=z_dim,
         ).eval().requires_grad_(False).to(device)
 
-    def encode(self, videos):
+    def encode(self, videos, device=None):
         """
         videos: A list of videos each with shape [C, T, H, W].
         """
@@ -654,7 +654,7 @@
                 for u in videos
             ]
 
-    def decode(self, zs):
+    def decode(self, zs, device=None):
         with amp.autocast(dtype=self.dtype):
             return [
                 self.model.decode(u.unsqueeze(0),
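
For context, the image2video.py hunk relies on a load pattern of building the transformer under init_empty_weights (meta device, no memory allocated) and then assigning each tensor from the safetensors state dict with a per-parameter dtype, so most weights land in fp8 while numerically sensitive ones stay in bf16. A minimal standalone sketch of that pattern follows; the nn.Linear module and the in-memory state dict are hypothetical stand-ins for WanModel and the fp8 checkpoint, not part of the patch.

# Minimal sketch of the empty-weights + per-tensor dtype assignment pattern used above.
# Assumptions: nn.Linear and the in-memory state dict stand in for WanModel and the
# fp8 safetensors checkpoint; only the pattern itself is taken from the patch.
import torch
import torch.nn as nn
from accelerate import init_empty_weights
from accelerate.utils import set_module_tensor_to_device

# Stand-in for load_file(...): a bf16 state dict held in memory.
state_dict = {k: v.to(torch.bfloat16) for k, v in nn.Linear(64, 64).state_dict().items()}

with init_empty_weights():
    # Parameters are created on the meta device, so nothing is materialized yet.
    model = nn.Linear(64, 64)

# Keep numerically sensitive tensors (here just the bias) in bf16, cast the rest to fp8.
params_to_keep = {"bias"}
for name, _ in model.named_parameters():
    dtype_to_use = torch.bfloat16 if any(k in name for k in params_to_keep) else torch.float8_e4m3fn
    set_module_tensor_to_device(model, name, device="cpu", dtype=dtype_to_use, value=state_dict[name])

print({name: p.dtype for name, p in model.named_parameters()})
# expected: {'weight': torch.float8_e4m3fn, 'bias': torch.bfloat16}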