Mirror of https://github.com/Wan-Video/Wan2.1.git
Add support for FP8 T5
parent db54b7c613
commit 1c7b73d13e
@@ -137,6 +137,12 @@ python generate.py --task t2v-1.3B --size 832*480 --ckpt_dir ./Wan2.1-T2V-1.3B
 
 You can also use the `--fp8` option to enable FP8 precision for reduced memory usage. Make sure to download the [FP8 model weight](https://huggingface.co/Kijai/WanVideo_comfy/resolve/main/Wan2_1-T2V-1_3B_fp8_e4m3fn.safetensors) and place it in the `Wan2.1-T2V-1.3B` folder.
 
+Additionally, an [FP8 version of the T5 model](https://huggingface.co/Kijai/WanVideo_comfy/resolve/main/umt5-xxl-enc-fp8_e4m3fn.safetensors) is available. To use the FP8 T5 model, update the configuration file:
+
+```
+t2v_1_3B.t5_checkpoint = 'umt5-xxl-enc-fp8_e4m3fn.safetensors'
+```
+
 > 💡Note: If you are using the `T2V-1.3B` model, we recommend setting the parameter `--sample_guide_scale 6`. The `--sample_shift` parameter can be adjusted within the range of 8 to 12 based on performance.
 
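For reference, a full command with FP8 enabled might then look like the following; the prompt is illustrative, while `--fp8`, the task, size, and checkpoint directory come from the instructions above:

```
python generate.py --task t2v-1.3B --size 832*480 --ckpt_dir ./Wan2.1-T2V-1.3B --fp8 --prompt "A cat walking on grass, realistic style."
```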
@@ -10,6 +10,7 @@ i2v_14B = EasyDict(__name__='Config: Wan I2V 14B')
 i2v_14B.update(wan_shared_cfg)
 
 i2v_14B.t5_checkpoint = 'models_t5_umt5-xxl-enc-bf16.pth'
+# i2v_14B.t5_checkpoint = 'umt5-xxl-enc-fp8_e4m3fn.safetensors' # fp8 model
 i2v_14B.t5_tokenizer = 'google/umt5-xxl'
 
 # clip
@@ -10,6 +10,7 @@ t2v_14B.update(wan_shared_cfg)
 
 # t5
 t2v_14B.t5_checkpoint = 'models_t5_umt5-xxl-enc-bf16.pth'
+# t2v_14B.t5_checkpoint = 'umt5-xxl-enc-fp8_e4m3fn.safetensors' # fp8 model
 t2v_14B.t5_tokenizer = 'google/umt5-xxl'
 
 # vae
@@ -10,6 +10,7 @@ t2v_1_3B.update(wan_shared_cfg)
 
 # t5
 t2v_1_3B.t5_checkpoint = 'models_t5_umt5-xxl-enc-bf16.pth'
+# t2v_1_3B.t5_checkpoint = 'umt5-xxl-enc-fp8_e4m3fn.safetensors' # fp8 model
 t2v_1_3B.t5_tokenizer = 'google/umt5-xxl'
 
 # vae
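The three config hunks are symmetric: enabling the FP8 text encoder for any model amounts to swapping the `t5_checkpoint` value, as the commented lines suggest. An illustrative example for the 14B text-to-video config (the exact filename matters, since the pipeline changes below key the quantization mode off it):

```
# Illustrative: point the 14B T2V config at the FP8 T5 weights instead of bf16.
t2v_14B.t5_checkpoint = 'umt5-xxl-enc-fp8_e4m3fn.safetensors'
```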
@@ -80,6 +80,10 @@ class WanI2V:
         self.param_dtype = config.param_dtype
 
         shard_fn = partial(shard_model, device_id=device_id)
+        if config.t5_checkpoint == 'umt5-xxl-enc-fp8_e4m3fn.safetensors':
+            quantization = "fp8_e4m3fn"
+        else:
+            quantization = "disabled"
         self.text_encoder = T5EncoderModel(
             text_len=config.text_len,
             dtype=config.t5_dtype,
@@ -87,6 +91,7 @@ class WanI2V:
             checkpoint_path=os.path.join(checkpoint_dir, config.t5_checkpoint),
             tokenizer_path=os.path.join(checkpoint_dir, config.t5_tokenizer),
             shard_fn=shard_fn if t5_fsdp else None,
+            quantization=quantization,
         )
 
         self.vae_stride = config.vae_stride
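Putting the two WanI2V hunks together, the construction path can be sketched standalone roughly as below. The import path `wan.modules.t5` and the concrete values (`text_len`, dtype, checkpoint directory) are assumptions for illustration, not taken from this diff:

```
import os
import torch
from wan.modules.t5 import T5EncoderModel  # assumed module path

checkpoint_dir = './Wan2.1-T2V-1.3B'                   # illustrative
t5_checkpoint = 'umt5-xxl-enc-fp8_e4m3fn.safetensors'  # FP8 T5 weights

# Same filename check that the diff adds to WanI2V.__init__ / WanT2V.__init__.
if t5_checkpoint == 'umt5-xxl-enc-fp8_e4m3fn.safetensors':
    quantization = "fp8_e4m3fn"
else:
    quantization = "disabled"

text_encoder = T5EncoderModel(
    text_len=512,                  # assumed value, not taken from this diff
    dtype=torch.bfloat16,          # assumed compute dtype
    device=torch.device('cpu'),
    checkpoint_path=os.path.join(checkpoint_dir, t5_checkpoint),
    tokenizer_path=os.path.join(checkpoint_dir, 'google/umt5-xxl'),
    quantization=quantization,
)
```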
@@ -266,11 +271,13 @@ class WanI2V:
         # preprocess
         if not self.t5_cpu:
             self.text_encoder.model.to(self.device)
-            context = self.text_encoder([input_prompt], self.device)
-            context_null = self.text_encoder([n_prompt], self.device)
+            with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True):
+                context = self.text_encoder([input_prompt], self.device)
+                context_null = self.text_encoder([n_prompt], self.device)
             if offload_model:
                 self.text_encoder.model.cpu()
         else:
-            context = self.text_encoder([input_prompt], torch.device('cpu'))
-            context_null = self.text_encoder([n_prompt], torch.device('cpu'))
+            with torch.autocast(device_type="cpu", dtype=torch.bfloat16, enabled=True):
+                context = self.text_encoder([input_prompt], torch.device('cpu'))
+                context_null = self.text_encoder([n_prompt], torch.device('cpu'))
             context = [t.to(self.device) for t in context]
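The autocast wrapper is what keeps the encoder usable once part of its weights are stored in FP8: under `torch.autocast(..., dtype=torch.bfloat16)`, autocast-eligible ops run in bfloat16, presumably so matmuls involving FP8-stored weights are upcast rather than executed (or rejected) in FP8. A minimal generic sketch of the same pattern; the helper name and signature are illustrative, not part of the repo:

```
import torch

def encode_under_bf16_autocast(text_encoder, prompts, device):
    # Mirror the pattern added around the text-encoder calls in
    # WanI2V.generate / WanT2V.generate (illustrative helper, not repo code).
    device_type = "cuda" if device.type == "cuda" else "cpu"
    with torch.autocast(device_type=device_type, dtype=torch.bfloat16, enabled=True):
        return text_encoder(prompts, device)
```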
@@ -9,6 +9,10 @@ import torch.nn.functional as F
 
 from .tokenizers import HuggingfaceTokenizer
 
+from accelerate import init_empty_weights
+from accelerate.utils import set_module_tensor_to_device
+from safetensors.torch import load_file
+
 __all__ = [
     'T5Model',
     'T5Encoder',
@@ -442,7 +446,7 @@ def _t5(name,
     model = model_cls(**kwargs)
 
     # set device
-    model = model.to(dtype=dtype, device=device)
+    # model = model.to(dtype=dtype, device=device)
 
     # init tokenizer
     if return_tokenizer:
@@ -479,6 +483,7 @@ class T5EncoderModel:
         checkpoint_path=None,
         tokenizer_path=None,
         shard_fn=None,
+        quantization="disabled",
     ):
         self.text_len = text_len
         self.dtype = dtype
@@ -486,14 +491,31 @@ class T5EncoderModel:
         self.checkpoint_path = checkpoint_path
         self.tokenizer_path = tokenizer_path
 
-        # init model
-        model = umt5_xxl(
-            encoder_only=True,
-            return_tokenizer=False,
-            dtype=dtype,
-            device=device).eval().requires_grad_(False)
-        logging.info(f'loading {checkpoint_path}')
-        model.load_state_dict(torch.load(checkpoint_path, map_location='cpu'))
+        logging.info(f'loading {checkpoint_path}')
+        if quantization == "disabled":
+            # init model
+            model = umt5_xxl(
+                encoder_only=True,
+                return_tokenizer=False,
+                dtype=dtype,
+                device=device).eval().requires_grad_(False)
+            model.load_state_dict(torch.load(checkpoint_path, map_location='cpu'))
+        elif quantization == "fp8_e4m3fn":
+            with init_empty_weights():
+                model = umt5_xxl(
+                    encoder_only=True,
+                    return_tokenizer=False,
+                    dtype=dtype,
+                    device=device).eval().requires_grad_(False)
+            cast_dtype = torch.float8_e4m3fn
+            state_dict = load_file(checkpoint_path, device="cpu")
+            params_to_keep = {'norm', 'pos_embedding', 'token_embedding'}
+            for name, param in model.named_parameters():
+                dtype_to_use = dtype if any(keyword in name for keyword in params_to_keep) else cast_dtype
+                set_module_tensor_to_device(model, name, device=device, dtype=dtype_to_use, value=state_dict[name])
+            del state_dict
+
         self.model = model
         if shard_fn is not None:
             self.model = shard_fn(self.model, sync_module_states=False)
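As a standalone illustration of the loading technique above, here is the same recipe on a toy module (the class, file name, and `params_to_keep` entries are placeholders, not Wan2.1 code): build the module under `init_empty_weights()` so its parameters live on the meta device and allocate nothing, then materialize each parameter from the safetensors state dict, keeping norm-like weights in the compute dtype and casting the rest to `torch.float8_e4m3fn` (requires a PyTorch version with FP8 dtypes plus `accelerate` and `safetensors`):

```
import torch
import torch.nn as nn
from accelerate import init_empty_weights
from accelerate.utils import set_module_tensor_to_device
from safetensors.torch import load_file, save_file


class Toy(nn.Module):
    """Stand-in for umt5_xxl: one large weight plus a norm to keep in bf16."""

    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(64, 64, bias=False)
        self.norm = nn.LayerNorm(64)


# Build on the meta device: parameters exist, but no memory is allocated yet.
with init_empty_weights():
    model = Toy().eval().requires_grad_(False)

# Pretend this file holds the pre-quantized checkpoint (written here so the
# sketch runs end to end).
save_file(Toy().state_dict(), 'toy_fp8.safetensors')
state_dict = load_file('toy_fp8.safetensors', device='cpu')

compute_dtype = torch.bfloat16
params_to_keep = {'norm'}  # name substrings whose tensors stay in bf16
for name, _ in model.named_parameters():
    dtype_to_use = compute_dtype if any(k in name for k in params_to_keep) else torch.float8_e4m3fn
    set_module_tensor_to_device(model, name, device='cpu', dtype=dtype_to_use, value=state_dict[name])
del state_dict

print(model.proj.weight.dtype)  # torch.float8_e4m3fn
print(model.norm.weight.dtype)  # torch.bfloat16
```

The FP8-stored weights still have to be upcast at compute time; that appears to be the role of the bfloat16 autocast added around the text-encoder calls in `generate()`.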
@@ -74,13 +74,19 @@ class WanT2V:
         self.param_dtype = config.param_dtype
 
         shard_fn = partial(shard_model, device_id=device_id)
+        if config.t5_checkpoint == 'umt5-xxl-enc-fp8_e4m3fn.safetensors':
+            quantization = "fp8_e4m3fn"
+        else:
+            quantization = "disabled"
+
         self.text_encoder = T5EncoderModel(
             text_len=config.text_len,
             dtype=config.t5_dtype,
             device=torch.device('cpu'),
             checkpoint_path=os.path.join(checkpoint_dir, config.t5_checkpoint),
             tokenizer_path=os.path.join(checkpoint_dir, config.t5_tokenizer),
-            shard_fn=shard_fn if t5_fsdp else None)
+            shard_fn=shard_fn if t5_fsdp else None,
+            quantization=quantization)
 
         self.vae_stride = config.vae_stride
         self.patch_size = config.patch_size
@@ -221,11 +227,13 @@ class WanT2V:
 
         if not self.t5_cpu:
             self.text_encoder.model.to(self.device)
-            context = self.text_encoder([input_prompt], self.device)
-            context_null = self.text_encoder([n_prompt], self.device)
+            with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True):
+                context = self.text_encoder([input_prompt], self.device)
+                context_null = self.text_encoder([n_prompt], self.device)
             if offload_model:
                 self.text_encoder.model.cpu()
         else:
-            context = self.text_encoder([input_prompt], torch.device('cpu'))
-            context_null = self.text_encoder([n_prompt], torch.device('cpu'))
+            with torch.autocast(device_type="cpu", dtype=torch.bfloat16, enabled=True):
+                context = self.text_encoder([input_prompt], torch.device('cpu'))
+                context_null = self.text_encoder([n_prompt], torch.device('cpu'))
             context = [t.to(self.device) for t in context]