支持Kijai的fp8模型

This commit is contained in:
Yexiong Lin 2025-02-27 17:01:58 +11:00
parent 5cb59ec52b
commit bebb16bb8e
4 changed files with 58 additions and 8 deletions

View File

@ -15,7 +15,8 @@ i2v_14B.t5_tokenizer = 'google/umt5-xxl'
# clip # clip
i2v_14B.clip_model = 'clip_xlm_roberta_vit_h_14' i2v_14B.clip_model = 'clip_xlm_roberta_vit_h_14'
i2v_14B.clip_dtype = torch.float16 i2v_14B.clip_dtype = torch.float16
i2v_14B.clip_checkpoint = 'models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth' # i2v_14B.clip_checkpoint = 'models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth'
i2v_14B.clip_checkpoint = 'open-clip-xlm-roberta-large-vit-huge-14_fp16.safetensors'
i2v_14B.clip_tokenizer = 'xlm-roberta-large' i2v_14B.clip_tokenizer = 'xlm-roberta-large'
# vae # vae

View File

@ -16,6 +16,10 @@ import torch.distributed as dist
import torchvision.transforms.functional as TF import torchvision.transforms.functional as TF
from tqdm import tqdm from tqdm import tqdm
from accelerate import init_empty_weights
from accelerate.utils import set_module_tensor_to_device
from safetensors.torch import load_file
from .distributed.fsdp import shard_model from .distributed.fsdp import shard_model
from .modules.clip import CLIPModel from .modules.clip import CLIPModel
from .modules.model import WanModel from .modules.model import WanModel
@ -84,6 +88,8 @@ class WanI2V:
self.vae_stride = config.vae_stride self.vae_stride = config.vae_stride
self.patch_size = config.patch_size self.patch_size = config.patch_size
print('device:', self.device)
self.vae = WanVAE( self.vae = WanVAE(
vae_pth=os.path.join(checkpoint_dir, config.vae_checkpoint), vae_pth=os.path.join(checkpoint_dir, config.vae_checkpoint),
device=self.device) device=self.device)
@ -95,8 +101,44 @@ class WanI2V:
config.clip_checkpoint), config.clip_checkpoint),
tokenizer_path=os.path.join(checkpoint_dir, config.clip_tokenizer)) tokenizer_path=os.path.join(checkpoint_dir, config.clip_tokenizer))
state_dict = load_file(checkpoint_dir+'/Wan2_1-I2V-14B-480P_fp8_e4m3fn.safetensors', device="cpu")
dim = state_dict["patch_embedding.weight"].shape[0]
in_channels = state_dict["patch_embedding.weight"].shape[1]
print("in_channels: ", in_channels)
ffn_dim = state_dict["blocks.0.ffn.0.bias"].shape[0]
model_type = "i2v" if in_channels == 36 else "t2v"
num_heads = 40 if dim == 5120 else 12
num_layers = 40 if dim == 5120 else 30
TRANSFORMER_CONFIG= {
"dim": dim,
"ffn_dim": ffn_dim,
"eps": 1e-06,
"freq_dim": 256,
"in_dim": in_channels,
"model_type": model_type,
"out_dim": 16,
"text_len": 512,
"num_heads": num_heads,
"num_layers": num_layers,
}
with init_empty_weights():
self.model = WanModel(**TRANSFORMER_CONFIG)
logging.info(f"Creating WanModel from {checkpoint_dir}") logging.info(f"Creating WanModel from {checkpoint_dir}")
self.model = WanModel.from_pretrained(checkpoint_dir)
base_dtype=torch.bfloat16
dtype=torch.float8_e4m3fn
params_to_keep = {"norm", "head", "bias", "time_in", "vector_in", "patch_embedding", "time_", "img_emb", "modulation"}
for name, param in self.model.named_parameters():
dtype_to_use = base_dtype if any(keyword in name for keyword in params_to_keep) else dtype
# dtype_to_use = torch.bfloat16
# print("Assigning Parameter name: ", name, " with dtype: ", dtype_to_use)
set_module_tensor_to_device(self.model, name, device='cpu', dtype=dtype_to_use, value=state_dict[name])
del state_dict
print("Model loaded successfully")
# self.model = WanModel.from_pretrained(checkpoint_dir)
self.model.eval().requires_grad_(False) self.model.eval().requires_grad_(False)
if t5_fsdp or dit_fsdp or use_usp: if t5_fsdp or dit_fsdp or use_usp:
@ -242,9 +284,12 @@ class WanI2V:
torch.zeros(3, 80, h, w) torch.zeros(3, 80, h, w)
], ],
dim=1).to(self.device) dim=1).to(self.device)
])[0] ],device=self.device)[0]
y = torch.concat([msk, y]) y = torch.concat([msk, y])
if offload_model:
self.vae.model.cpu()
@contextmanager @contextmanager
def noop_no_sync(): def noop_no_sync():
yield yield
@ -334,7 +379,8 @@ class WanI2V:
torch.cuda.empty_cache() torch.cuda.empty_cache()
if self.rank == 0: if self.rank == 0:
videos = self.vae.decode(x0) self.vae.model.to(self.device)
videos = self.vae.decode(x0, device=self.device)
del noise, latent del noise, latent
del sample_scheduler del sample_scheduler

View File

@ -7,6 +7,7 @@ import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
import torchvision.transforms as T import torchvision.transforms as T
from safetensors.torch import load_file
from .attention import flash_attention from .attention import flash_attention
from .tokenizers import HuggingfaceTokenizer from .tokenizers import HuggingfaceTokenizer
@ -515,8 +516,10 @@ class CLIPModel:
device=device) device=device)
self.model = self.model.eval().requires_grad_(False) self.model = self.model.eval().requires_grad_(False)
logging.info(f'loading {checkpoint_path}') logging.info(f'loading {checkpoint_path}')
self.model.load_state_dict( state_dict = load_file(checkpoint_path, device='cpu')
torch.load(checkpoint_path, map_location='cpu')) self.model.load_state_dict(state_dict)
# self.model.load_state_dict(
# torch.load(checkpoint_path, map_location='cpu'))
# init tokenizer # init tokenizer
self.tokenizer = HuggingfaceTokenizer( self.tokenizer = HuggingfaceTokenizer(

View File

@ -644,7 +644,7 @@ class WanVAE:
z_dim=z_dim, z_dim=z_dim,
).eval().requires_grad_(False).to(device) ).eval().requires_grad_(False).to(device)
def encode(self, videos): def encode(self, videos, device=None):
""" """
videos: A list of videos each with shape [C, T, H, W]. videos: A list of videos each with shape [C, T, H, W].
""" """
@ -654,7 +654,7 @@ class WanVAE:
for u in videos for u in videos
] ]
def decode(self, zs): def decode(self, zs, device=None):
with amp.autocast(dtype=self.dtype): with amp.autocast(dtype=self.dtype):
return [ return [
self.model.decode(u.unsqueeze(0), self.model.decode(u.unsqueeze(0),