support fp8 modell

This commit is contained in:
Yexiong Lin 2025-02-27 17:42:53 +11:00
parent bebb16bb8e
commit bc2aff711e
4 changed files with 55 additions and 41 deletions

View File

@ -126,6 +126,11 @@ def _parse_args():
action="store_true", action="store_true",
default=False, default=False,
help="Whether to use FSDP for DiT.") help="Whether to use FSDP for DiT.")
parser.add_argument(
"--fp8",
action="store_true",
default=False,
help="Whether to use fp8.")
parser.add_argument( parser.add_argument(
"--save_file", "--save_file",
type=str, type=str,
@ -363,6 +368,7 @@ def generate(args):
dit_fsdp=args.dit_fsdp, dit_fsdp=args.dit_fsdp,
use_usp=(args.ulysses_size > 1 or args.ring_size > 1), use_usp=(args.ulysses_size > 1 or args.ring_size > 1),
t5_cpu=args.t5_cpu, t5_cpu=args.t5_cpu,
fp8=args.fp8,
) )
logging.info("Generating video ...") logging.info("Generating video ...")

View File

@ -16,7 +16,7 @@ i2v_14B.t5_tokenizer = 'google/umt5-xxl'
i2v_14B.clip_model = 'clip_xlm_roberta_vit_h_14' i2v_14B.clip_model = 'clip_xlm_roberta_vit_h_14'
i2v_14B.clip_dtype = torch.float16 i2v_14B.clip_dtype = torch.float16
# i2v_14B.clip_checkpoint = 'models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth' # i2v_14B.clip_checkpoint = 'models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth'
i2v_14B.clip_checkpoint = 'open-clip-xlm-roberta-large-vit-huge-14_fp16.safetensors' i2v_14B.clip_checkpoint = 'open-clip-xlm-roberta-large-vit-huge-14_fp16.safetensors' # Kijai's fp16 model
i2v_14B.clip_tokenizer = 'xlm-roberta-large' i2v_14B.clip_tokenizer = 'xlm-roberta-large'
# vae # vae

View File

@ -43,6 +43,7 @@ class WanI2V:
use_usp=False, use_usp=False,
t5_cpu=False, t5_cpu=False,
init_on_cpu=True, init_on_cpu=True,
fp8=False,
): ):
r""" r"""
Initializes the image-to-video generation model components. Initializes the image-to-video generation model components.
@ -66,6 +67,8 @@ class WanI2V:
Whether to place T5 model on CPU. Only works without t5_fsdp. Whether to place T5 model on CPU. Only works without t5_fsdp.
init_on_cpu (`bool`, *optional*, defaults to True): init_on_cpu (`bool`, *optional*, defaults to True):
Enable initializing Transformer Model on CPU. Only works without FSDP or USP. Enable initializing Transformer Model on CPU. Only works without FSDP or USP.
fp8 (`bool`, *optional*, defaults to False):
Enable 8-bit floating point precision for model parameters.
""" """
self.device = torch.device(f"cuda:{device_id}") self.device = torch.device(f"cuda:{device_id}")
self.config = config self.config = config
@ -88,7 +91,6 @@ class WanI2V:
self.vae_stride = config.vae_stride self.vae_stride = config.vae_stride
self.patch_size = config.patch_size self.patch_size = config.patch_size
print('device:', self.device)
self.vae = WanVAE( self.vae = WanVAE(
vae_pth=os.path.join(checkpoint_dir, config.vae_checkpoint), vae_pth=os.path.join(checkpoint_dir, config.vae_checkpoint),
@ -101,10 +103,15 @@ class WanI2V:
config.clip_checkpoint), config.clip_checkpoint),
tokenizer_path=os.path.join(checkpoint_dir, config.clip_tokenizer)) tokenizer_path=os.path.join(checkpoint_dir, config.clip_tokenizer))
if not fp8:
self.model = WanModel.from_pretrained(checkpoint_dir)
else:
if '480P' in checkpoint_dir:
state_dict = load_file(checkpoint_dir+'/Wan2_1-I2V-14B-480P_fp8_e4m3fn.safetensors', device="cpu") state_dict = load_file(checkpoint_dir+'/Wan2_1-I2V-14B-480P_fp8_e4m3fn.safetensors', device="cpu")
elif '720P' in checkpoint_dir:
state_dict = load_file(checkpoint_dir+'/Wan2_1-I2V-14B-720P_fp8_e4m3fn.safetensors', device="cpu")
dim = state_dict["patch_embedding.weight"].shape[0] dim = state_dict["patch_embedding.weight"].shape[0]
in_channels = state_dict["patch_embedding.weight"].shape[1] in_channels = state_dict["patch_embedding.weight"].shape[1]
print("in_channels: ", in_channels)
ffn_dim = state_dict["blocks.0.ffn.0.bias"].shape[0] ffn_dim = state_dict["blocks.0.ffn.0.bias"].shape[0]
model_type = "i2v" if in_channels == 36 else "t2v" model_type = "i2v" if in_channels == 36 else "t2v"
num_heads = 40 if dim == 5120 else 12 num_heads = 40 if dim == 5120 else 12
@ -136,9 +143,7 @@ class WanI2V:
set_module_tensor_to_device(self.model, name, device='cpu', dtype=dtype_to_use, value=state_dict[name]) set_module_tensor_to_device(self.model, name, device='cpu', dtype=dtype_to_use, value=state_dict[name])
del state_dict del state_dict
print("Model loaded successfully")
# self.model = WanModel.from_pretrained(checkpoint_dir)
self.model.eval().requires_grad_(False) self.model.eval().requires_grad_(False)
if t5_fsdp or dit_fsdp or use_usp: if t5_fsdp or dit_fsdp or use_usp:

View File

@ -516,10 +516,13 @@ class CLIPModel:
device=device) device=device)
self.model = self.model.eval().requires_grad_(False) self.model = self.model.eval().requires_grad_(False)
logging.info(f'loading {checkpoint_path}') logging.info(f'loading {checkpoint_path}')
if checkpoint_path.endswith('.safetensors'):
state_dict = load_file(checkpoint_path, device='cpu') state_dict = load_file(checkpoint_path, device='cpu')
self.model.load_state_dict(state_dict) self.model.load_state_dict(state_dict)
# self.model.load_state_dict( elif checkpoint_path.endswith('.pth'):
# torch.load(checkpoint_path, map_location='cpu')) self.model.load_state_dict(torch.load(checkpoint_path, map_location='cpu'))
else:
raise ValueError(f'Unsupported checkpoint file format: {checkpoint_path}')
# init tokenizer # init tokenizer
self.tokenizer = HuggingfaceTokenizer( self.tokenizer = HuggingfaceTokenizer(