mirror of
https://github.com/Wan-Video/Wan2.1.git
synced 2025-06-07 15:54:52 +00:00
支持Kijai的fp8模型
This commit is contained in:
parent
5cb59ec52b
commit
bebb16bb8e
@ -15,7 +15,8 @@ i2v_14B.t5_tokenizer = 'google/umt5-xxl'
|
|||||||
# clip
|
# clip
|
||||||
i2v_14B.clip_model = 'clip_xlm_roberta_vit_h_14'
|
i2v_14B.clip_model = 'clip_xlm_roberta_vit_h_14'
|
||||||
i2v_14B.clip_dtype = torch.float16
|
i2v_14B.clip_dtype = torch.float16
|
||||||
i2v_14B.clip_checkpoint = 'models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth'
|
# i2v_14B.clip_checkpoint = 'models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth'
|
||||||
|
i2v_14B.clip_checkpoint = 'open-clip-xlm-roberta-large-vit-huge-14_fp16.safetensors'
|
||||||
i2v_14B.clip_tokenizer = 'xlm-roberta-large'
|
i2v_14B.clip_tokenizer = 'xlm-roberta-large'
|
||||||
|
|
||||||
# vae
|
# vae
|
||||||
|
@ -16,6 +16,10 @@ import torch.distributed as dist
|
|||||||
import torchvision.transforms.functional as TF
|
import torchvision.transforms.functional as TF
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
from accelerate import init_empty_weights
|
||||||
|
from accelerate.utils import set_module_tensor_to_device
|
||||||
|
from safetensors.torch import load_file
|
||||||
|
|
||||||
from .distributed.fsdp import shard_model
|
from .distributed.fsdp import shard_model
|
||||||
from .modules.clip import CLIPModel
|
from .modules.clip import CLIPModel
|
||||||
from .modules.model import WanModel
|
from .modules.model import WanModel
|
||||||
@ -84,6 +88,8 @@ class WanI2V:
|
|||||||
|
|
||||||
self.vae_stride = config.vae_stride
|
self.vae_stride = config.vae_stride
|
||||||
self.patch_size = config.patch_size
|
self.patch_size = config.patch_size
|
||||||
|
print('device:', self.device)
|
||||||
|
|
||||||
self.vae = WanVAE(
|
self.vae = WanVAE(
|
||||||
vae_pth=os.path.join(checkpoint_dir, config.vae_checkpoint),
|
vae_pth=os.path.join(checkpoint_dir, config.vae_checkpoint),
|
||||||
device=self.device)
|
device=self.device)
|
||||||
@ -95,8 +101,44 @@ class WanI2V:
|
|||||||
config.clip_checkpoint),
|
config.clip_checkpoint),
|
||||||
tokenizer_path=os.path.join(checkpoint_dir, config.clip_tokenizer))
|
tokenizer_path=os.path.join(checkpoint_dir, config.clip_tokenizer))
|
||||||
|
|
||||||
|
state_dict = load_file(checkpoint_dir+'/Wan2_1-I2V-14B-480P_fp8_e4m3fn.safetensors', device="cpu")
|
||||||
|
dim = state_dict["patch_embedding.weight"].shape[0]
|
||||||
|
in_channels = state_dict["patch_embedding.weight"].shape[1]
|
||||||
|
print("in_channels: ", in_channels)
|
||||||
|
ffn_dim = state_dict["blocks.0.ffn.0.bias"].shape[0]
|
||||||
|
model_type = "i2v" if in_channels == 36 else "t2v"
|
||||||
|
num_heads = 40 if dim == 5120 else 12
|
||||||
|
num_layers = 40 if dim == 5120 else 30
|
||||||
|
TRANSFORMER_CONFIG= {
|
||||||
|
"dim": dim,
|
||||||
|
"ffn_dim": ffn_dim,
|
||||||
|
"eps": 1e-06,
|
||||||
|
"freq_dim": 256,
|
||||||
|
"in_dim": in_channels,
|
||||||
|
"model_type": model_type,
|
||||||
|
"out_dim": 16,
|
||||||
|
"text_len": 512,
|
||||||
|
"num_heads": num_heads,
|
||||||
|
"num_layers": num_layers,
|
||||||
|
}
|
||||||
|
|
||||||
|
with init_empty_weights():
|
||||||
|
self.model = WanModel(**TRANSFORMER_CONFIG)
|
||||||
logging.info(f"Creating WanModel from {checkpoint_dir}")
|
logging.info(f"Creating WanModel from {checkpoint_dir}")
|
||||||
self.model = WanModel.from_pretrained(checkpoint_dir)
|
|
||||||
|
base_dtype=torch.bfloat16
|
||||||
|
dtype=torch.float8_e4m3fn
|
||||||
|
params_to_keep = {"norm", "head", "bias", "time_in", "vector_in", "patch_embedding", "time_", "img_emb", "modulation"}
|
||||||
|
for name, param in self.model.named_parameters():
|
||||||
|
dtype_to_use = base_dtype if any(keyword in name for keyword in params_to_keep) else dtype
|
||||||
|
# dtype_to_use = torch.bfloat16
|
||||||
|
# print("Assigning Parameter name: ", name, " with dtype: ", dtype_to_use)
|
||||||
|
set_module_tensor_to_device(self.model, name, device='cpu', dtype=dtype_to_use, value=state_dict[name])
|
||||||
|
|
||||||
|
del state_dict
|
||||||
|
print("Model loaded successfully")
|
||||||
|
|
||||||
|
# self.model = WanModel.from_pretrained(checkpoint_dir)
|
||||||
self.model.eval().requires_grad_(False)
|
self.model.eval().requires_grad_(False)
|
||||||
|
|
||||||
if t5_fsdp or dit_fsdp or use_usp:
|
if t5_fsdp or dit_fsdp or use_usp:
|
||||||
@ -242,9 +284,12 @@ class WanI2V:
|
|||||||
torch.zeros(3, 80, h, w)
|
torch.zeros(3, 80, h, w)
|
||||||
],
|
],
|
||||||
dim=1).to(self.device)
|
dim=1).to(self.device)
|
||||||
])[0]
|
],device=self.device)[0]
|
||||||
y = torch.concat([msk, y])
|
y = torch.concat([msk, y])
|
||||||
|
|
||||||
|
if offload_model:
|
||||||
|
self.vae.model.cpu()
|
||||||
|
|
||||||
@contextmanager
|
@contextmanager
|
||||||
def noop_no_sync():
|
def noop_no_sync():
|
||||||
yield
|
yield
|
||||||
@ -334,7 +379,8 @@ class WanI2V:
|
|||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
if self.rank == 0:
|
if self.rank == 0:
|
||||||
videos = self.vae.decode(x0)
|
self.vae.model.to(self.device)
|
||||||
|
videos = self.vae.decode(x0, device=self.device)
|
||||||
|
|
||||||
del noise, latent
|
del noise, latent
|
||||||
del sample_scheduler
|
del sample_scheduler
|
||||||
|
@ -7,6 +7,7 @@ import torch
|
|||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
import torchvision.transforms as T
|
import torchvision.transforms as T
|
||||||
|
from safetensors.torch import load_file
|
||||||
|
|
||||||
from .attention import flash_attention
|
from .attention import flash_attention
|
||||||
from .tokenizers import HuggingfaceTokenizer
|
from .tokenizers import HuggingfaceTokenizer
|
||||||
@ -515,8 +516,10 @@ class CLIPModel:
|
|||||||
device=device)
|
device=device)
|
||||||
self.model = self.model.eval().requires_grad_(False)
|
self.model = self.model.eval().requires_grad_(False)
|
||||||
logging.info(f'loading {checkpoint_path}')
|
logging.info(f'loading {checkpoint_path}')
|
||||||
self.model.load_state_dict(
|
state_dict = load_file(checkpoint_path, device='cpu')
|
||||||
torch.load(checkpoint_path, map_location='cpu'))
|
self.model.load_state_dict(state_dict)
|
||||||
|
# self.model.load_state_dict(
|
||||||
|
# torch.load(checkpoint_path, map_location='cpu'))
|
||||||
|
|
||||||
# init tokenizer
|
# init tokenizer
|
||||||
self.tokenizer = HuggingfaceTokenizer(
|
self.tokenizer = HuggingfaceTokenizer(
|
||||||
|
@ -644,7 +644,7 @@ class WanVAE:
|
|||||||
z_dim=z_dim,
|
z_dim=z_dim,
|
||||||
).eval().requires_grad_(False).to(device)
|
).eval().requires_grad_(False).to(device)
|
||||||
|
|
||||||
def encode(self, videos):
|
def encode(self, videos, device=None):
|
||||||
"""
|
"""
|
||||||
videos: A list of videos each with shape [C, T, H, W].
|
videos: A list of videos each with shape [C, T, H, W].
|
||||||
"""
|
"""
|
||||||
@ -654,7 +654,7 @@ class WanVAE:
|
|||||||
for u in videos
|
for u in videos
|
||||||
]
|
]
|
||||||
|
|
||||||
def decode(self, zs):
|
def decode(self, zs, device=None):
|
||||||
with amp.autocast(dtype=self.dtype):
|
with amp.autocast(dtype=self.dtype):
|
||||||
return [
|
return [
|
||||||
self.model.decode(u.unsqueeze(0),
|
self.model.decode(u.unsqueeze(0),
|
||||||
|
Loading…
Reference in New Issue
Block a user