Mirror of https://github.com/Wan-Video/Wan2.1.git (synced 2025-06-07 15:54:52 +00:00)
support fp8 model
This commit is contained in: parent bebb16bb8e, commit bc2aff711e
generate.py

@@ -126,6 +126,11 @@ def _parse_args():
         action="store_true",
         default=False,
         help="Whether to use FSDP for DiT.")
+    parser.add_argument(
+        "--fp8",
+        action="store_true",
+        default=False,
+        help="Whether to use fp8.")
     parser.add_argument(
         "--save_file",
         type=str,
@@ -363,6 +368,7 @@ def generate(args):
         dit_fsdp=args.dit_fsdp,
         use_usp=(args.ulysses_size > 1 or args.ring_size > 1),
         t5_cpu=args.t5_cpu,
+        fp8=args.fp8,
     )

     logging.info("Generating video ...")
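With --fp8 defined in _parse_args() and threaded into the WanI2V constructor, an invocation could look like the sketch below. Only --fp8 comes from this commit; the other flags follow the upstream generate.py and may differ between versions, and the checkpoint directory name must contain '480P' or '720P' so the fp8 branch can find the matching *_fp8_e4m3fn.safetensors file.

python generate.py --task i2v-14B --size 832*480 \
    --ckpt_dir ./Wan2.1-I2V-14B-480P \
    --image examples/i2v_input.JPG \
    --prompt "A white cat walking along the beach" \
    --fp8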
wan/configs/wan_i2v_14B.py

@@ -16,7 +16,7 @@ i2v_14B.t5_tokenizer = 'google/umt5-xxl'
 i2v_14B.clip_model = 'clip_xlm_roberta_vit_h_14'
 i2v_14B.clip_dtype = torch.float16
 # i2v_14B.clip_checkpoint = 'models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth'
-i2v_14B.clip_checkpoint = 'open-clip-xlm-roberta-large-vit-huge-14_fp16.safetensors'
+i2v_14B.clip_checkpoint = 'open-clip-xlm-roberta-large-vit-huge-14_fp16.safetensors' # Kijai's fp16 model
 i2v_14B.clip_tokenizer = 'xlm-roberta-large'

 # vae
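The config now points the CLIP encoder at an fp16 .safetensors file instead of the original .pth checkpoint. If only the .pth file is on disk, a conversion along these lines produces a compatible file; this is a hedged sketch (not code from the repo), assuming the .pth holds the flat name-to-tensor state dict that CLIPModel.load_state_dict expects.

import torch
from safetensors.torch import save_file

# Load the original checkpoint (file names taken from the config lines above).
state_dict = torch.load(
    'models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth',
    map_location='cpu')

# Cast floating-point tensors to fp16; clone() gives each tensor its own
# contiguous storage, which save_file requires.
state_dict = {
    k: (v.half() if v.is_floating_point() else v).contiguous().clone()
    for k, v in state_dict.items()
}
save_file(state_dict, 'open-clip-xlm-roberta-large-vit-huge-14_fp16.safetensors')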
wan/image2video.py

@@ -43,6 +43,7 @@ class WanI2V:
         use_usp=False,
         t5_cpu=False,
         init_on_cpu=True,
+        fp8=False,
     ):
         r"""
         Initializes the image-to-video generation model components.
@@ -66,6 +67,8 @@ class WanI2V:
                 Whether to place T5 model on CPU. Only works without t5_fsdp.
             init_on_cpu (`bool`, *optional*, defaults to True):
                 Enable initializing Transformer Model on CPU. Only works without FSDP or USP.
+            fp8 (`bool`, *optional*, defaults to False):
+                Enable 8-bit floating point precision for model parameters.
         """
         self.device = torch.device(f"cuda:{device_id}")
         self.config = config
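For completeness, a minimal construction sketch using the new argument. It assumes the upstream package layout (wan.WanI2V, wan.configs.WAN_CONFIGS) and a checkpoint directory whose name contains '480P' or '720P', which the fp8 branch relies on.

import wan
from wan.configs import WAN_CONFIGS

pipe = wan.WanI2V(
    config=WAN_CONFIGS['i2v-14B'],
    checkpoint_dir='./Wan2.1-I2V-14B-480P',  # name must contain '480P' or '720P' for fp8
    device_id=0,
    rank=0,
    fp8=True,  # new flag; False keeps the WanModel.from_pretrained path
)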
@@ -88,7 +91,6 @@ class WanI2V:

         self.vae_stride = config.vae_stride
         self.patch_size = config.patch_size
-        print('device:', self.device)

         self.vae = WanVAE(
             vae_pth=os.path.join(checkpoint_dir, config.vae_checkpoint),
@@ -101,44 +103,47 @@ class WanI2V:
                                          config.clip_checkpoint),
             tokenizer_path=os.path.join(checkpoint_dir, config.clip_tokenizer))

-        state_dict = load_file(checkpoint_dir+'/Wan2_1-I2V-14B-480P_fp8_e4m3fn.safetensors', device="cpu")
-        dim = state_dict["patch_embedding.weight"].shape[0]
-        in_channels = state_dict["patch_embedding.weight"].shape[1]
-        print("in_channels: ", in_channels)
-        ffn_dim = state_dict["blocks.0.ffn.0.bias"].shape[0]
-        model_type = "i2v" if in_channels == 36 else "t2v"
-        num_heads = 40 if dim == 5120 else 12
-        num_layers = 40 if dim == 5120 else 30
-        TRANSFORMER_CONFIG= {
-            "dim": dim,
-            "ffn_dim": ffn_dim,
-            "eps": 1e-06,
-            "freq_dim": 256,
-            "in_dim": in_channels,
-            "model_type": model_type,
-            "out_dim": 16,
-            "text_len": 512,
-            "num_heads": num_heads,
-            "num_layers": num_layers,
-        }
+        if not fp8:
+            self.model = WanModel.from_pretrained(checkpoint_dir)
+        else:
+            if '480P' in checkpoint_dir:
+                state_dict = load_file(checkpoint_dir+'/Wan2_1-I2V-14B-480P_fp8_e4m3fn.safetensors', device="cpu")
+            elif '720P' in checkpoint_dir:
+                state_dict = load_file(checkpoint_dir+'/Wan2_1-I2V-14B-720P_fp8_e4m3fn.safetensors', device="cpu")
+            dim = state_dict["patch_embedding.weight"].shape[0]
+            in_channels = state_dict["patch_embedding.weight"].shape[1]
+            ffn_dim = state_dict["blocks.0.ffn.0.bias"].shape[0]
+            model_type = "i2v" if in_channels == 36 else "t2v"
+            num_heads = 40 if dim == 5120 else 12
+            num_layers = 40 if dim == 5120 else 30
+            TRANSFORMER_CONFIG= {
+                "dim": dim,
+                "ffn_dim": ffn_dim,
+                "eps": 1e-06,
+                "freq_dim": 256,
+                "in_dim": in_channels,
+                "model_type": model_type,
+                "out_dim": 16,
+                "text_len": 512,
+                "num_heads": num_heads,
+                "num_layers": num_layers,
+            }

-        with init_empty_weights():
-            self.model = WanModel(**TRANSFORMER_CONFIG)
-        logging.info(f"Creating WanModel from {checkpoint_dir}")
-
-        base_dtype=torch.bfloat16
-        dtype=torch.float8_e4m3fn
-        params_to_keep = {"norm", "head", "bias", "time_in", "vector_in", "patch_embedding", "time_", "img_emb", "modulation"}
-        for name, param in self.model.named_parameters():
-            dtype_to_use = base_dtype if any(keyword in name for keyword in params_to_keep) else dtype
-            # dtype_to_use = torch.bfloat16
-            # print("Assigning Parameter name: ", name, " with dtype: ", dtype_to_use)
-            set_module_tensor_to_device(self.model, name, device='cpu', dtype=dtype_to_use, value=state_dict[name])
+            with init_empty_weights():
+                self.model = WanModel(**TRANSFORMER_CONFIG)
+            logging.info(f"Creating WanModel from {checkpoint_dir}")
+
+            base_dtype=torch.bfloat16
+            dtype=torch.float8_e4m3fn
+            params_to_keep = {"norm", "head", "bias", "time_in", "vector_in", "patch_embedding", "time_", "img_emb", "modulation"}
+            for name, param in self.model.named_parameters():
+                dtype_to_use = base_dtype if any(keyword in name for keyword in params_to_keep) else dtype
+                # dtype_to_use = torch.bfloat16
+                # print("Assigning Parameter name: ", name, " with dtype: ", dtype_to_use)
+                set_module_tensor_to_device(self.model, name, device='cpu', dtype=dtype_to_use, value=state_dict[name])

-        del state_dict
-        print("Model loaded successfully")
+            del state_dict

-        # self.model = WanModel.from_pretrained(checkpoint_dir)
         self.model.eval().requires_grad_(False)

         if t5_fsdp or dit_fsdp or use_usp:
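The fp8 branch above builds the DiT on the meta device with accelerate's init_empty_weights and then materializes each tensor from the fp8 state dict, keeping a whitelist of numerically sensitive parameters in bfloat16. A self-contained sketch of that pattern on a toy module (placeholder shapes and names; needs torch >= 2.1 for float8_e4m3fn plus the accelerate package):

import torch
import torch.nn as nn
from accelerate import init_empty_weights
from accelerate.utils import set_module_tensor_to_device

# Stand-in for load_file('<...>_fp8_e4m3fn.safetensors', device='cpu').
state_dict = {
    "norm.weight": torch.ones(4),
    "norm.bias": torch.zeros(4),
    "linear.weight": torch.randn(8, 4),
    "linear.bias": torch.randn(8),
}

# Allocate the module structure only; parameters live on the meta device.
with init_empty_weights():
    model = nn.Sequential()
    model.add_module("norm", nn.LayerNorm(4))
    model.add_module("linear", nn.Linear(4, 8))

base_dtype = torch.bfloat16       # sensitive tensors (norms, biases) stay here
low_dtype = torch.float8_e4m3fn   # bulk weights are stored in fp8
params_to_keep = {"norm", "bias"}

# Materialize every parameter from the state dict with the chosen dtype.
for name in [n for n, _ in model.named_parameters()]:
    dtype_to_use = base_dtype if any(k in name for k in params_to_keep) else low_dtype
    set_module_tensor_to_device(
        model, name, device="cpu", dtype=dtype_to_use, value=state_dict[name])

print({n: p.dtype for n, p in model.named_parameters()})
# Only linear.weight ends up as float8_e4m3fn; norms and biases stay bfloat16.

Storing the bulk weights in fp8 roughly halves their memory footprint compared with bf16; how they are cast back for compute is up to the model's forward pass and is not changed by this commit.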
wan/modules/clip.py

@@ -516,10 +516,13 @@ class CLIPModel:
             device=device)
         self.model = self.model.eval().requires_grad_(False)
         logging.info(f'loading {checkpoint_path}')
-        state_dict = load_file(checkpoint_path, device='cpu')
-        self.model.load_state_dict(state_dict)
-        # self.model.load_state_dict(
-        #     torch.load(checkpoint_path, map_location='cpu'))
+        if checkpoint_path.endswith('.safetensors'):
+            state_dict = load_file(checkpoint_path, device='cpu')
+            self.model.load_state_dict(state_dict)
+        elif checkpoint_path.endswith('.pth'):
+            self.model.load_state_dict(torch.load(checkpoint_path, map_location='cpu'))
+        else:
+            raise ValueError(f'Unsupported checkpoint file format: {checkpoint_path}')

         # init tokenizer
         self.tokenizer = HuggingfaceTokenizer(