Mirror of https://github.com/Wan-Video/Wan2.1.git (synced 2025-11-04 14:16:57 +00:00)

Commit 73cf4e43c3 (parent 5a63326bb9): fixed save quantization
@@ -66,14 +66,20 @@ If a model is not quantized, it is assumed to be mostly 16 bits (with maybe a fe
 
 If a model is quantized, the term *quanto* should also be included, since WanGP currently supports only *quanto* quantized models; more specifically, you should replace *fp16* with *quanto_fp16_int8* or *bf16* with *quanto_bf16_int8*.
 
+Please note it is important that *bf16*, *fp16* and *quanto* are all in lower case letters.
+
 ## Creating a Quanto Quantized file
 If you launch the app with the *--save-quantized* switch, WanGP will create a quantized file in the **ckpts** subfolder just after the model has been loaded. Please note that the model will be *bf16* or *fp16* quantized depending on what you chose in the configuration menu.
 
 1) Make sure that in the finetune definition json file there is only a URL or filepath that points to the non-quantized model
 2) Launch WanGP: *python wgp.py --save-quantized*
 3) In the configuration menu, set the *Transformer Data Type* property to either *BF16* or *FP16*
-4) Launch a generation (settings used do not matter). As soon as the model is loaded, a new quantized model will be created in the **ckpts** subfolder it doesn't already exist.
+4) Launch a video generation (settings used do not matter). As soon as the model is loaded, a new quantized model will be created in the **ckpts** subfolder if it doesn't already exist.
 5) To test that this works properly, set the local path in the "URLs" key of the finetune definition file. For instance *URLs = ["ckpts/finetune_quanto_fp16_int8.safetensors"]*
-6) Restart WanGP and select *Scaled Int8 Quantization* in the *Transformer Model Quantization* property
+6) Remove *--save-quantized*, restart WanGP and select *Scaled Int8 Quantization* in the *Transformer Model Quantization* property
-7) Launch a new generation an verify in the terminal window that the right quantized model is loaded
+7) Launch a new generation and verify in the terminal window that the right quantized model is loaded
-8) In order to share the finetune definition file will need to store the fine model weights in the cloud. You can upload them for instance on *Huggingface*. You can now replace in the definition file the local path by a URL (on Huggingface to get the URL of the model file click *Copy download link* when accessing the model properties)
+8) In order to share the finetune definition file you will need to store the finetune model weights in the cloud. You can upload them for instance to *Huggingface*. You can then replace the local path in the definition file with a URL (on Huggingface, click *Copy download link* in the model file's properties to get the URL of the model file)
+
+You need to create a quantized model specifically for *bf16* or *fp16*, as quantized models cannot be converted on the fly. However, there is no need for separate *bf16* and *fp16* non-quantized models, as those can be converted on the fly while being loaded.
+
+Wan models support both *fp16* and *bf16* data types, although *fp16* in theory delivers better quality. In contrast, Hunyuan and LTXV support only *bf16*.
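To make the naming rule above concrete, here is a minimal sketch (not part of this commit) of how the expected *quanto* filename can be derived from a non-quantized checkpoint name; the helper name `quanto_name` and the example filenames are purely illustrative:

```python
# Illustrative only: derive the quanto filename that the documentation above expects.
def quanto_name(filename: str) -> str:
    # e.g. "finetune_fp16.safetensors" -> "finetune_quanto_fp16_int8.safetensors"
    for dtype in ("fp16", "bf16"):           # lower case only, as required above
        tag = "_" + dtype
        if tag in filename:
            return filename.replace(tag, "_quanto_" + dtype + "_int8")
    # no dtype tag in the name: fall back to a generic quanto suffix
    stem, dot, ext = filename.rpartition(".")
    return stem + "_quanto_int8" + dot + ext if dot else filename + "_quanto_int8"

print(quanto_name("finetune_fp16.safetensors"))
# finetune_quanto_fp16_int8.safetensors -> the value to put in the "URLs" key (step 5)
```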
@@ -949,11 +949,11 @@ class HunyuanVideoPipeline(DiffusionPipeline):
 # width = width or self.transformer.config.sample_size * self.vae_scale_factor
 # to deal with lora scaling and other possible forward hooks
 trans = self.transformer
-if trans.enable_teacache:
+if trans.enable_cache:
 teacache_multiplier = trans.teacache_multiplier
 trans.accumulated_rel_l1_distance = 0
 trans.rel_l1_thresh = 0.1 if teacache_multiplier < 2 else 0.15
-# trans.teacache_start_step = int(tea_cache_start_step_perc*num_inference_steps/100)
+# trans.cache_start_step = int(tea_cache_start_step_perc*num_inference_steps/100)
 # 1. Check inputs. Raise error if not correct
 self.check_inputs(
 prompt,
@@ -1208,7 +1208,7 @@ class HunyuanVideoPipeline(DiffusionPipeline):
 if ip_cfg_scale>0:
 latent_items += 1
 
-if self.transformer.enable_teacache:
+if self.transformer.enable_cache:
 self.transformer.previous_residual = [None] * latent_items
 
 # if is_progress_bar:
@@ -934,7 +934,7 @@ class HunyuanVideoAudioPipeline(DiffusionPipeline):
 
 transformer = self.transformer
 
-if transformer.enable_teacache:
+if transformer.enable_cache:
 teacache_multiplier = transformer.teacache_multiplier
 transformer.accumulated_rel_l1_distance = 0
 transformer.rel_l1_thresh = 0.1 if teacache_multiplier < 2 else 0.15
@@ -1136,7 +1136,7 @@ class HunyuanVideoAudioPipeline(DiffusionPipeline):
 if self._interrupt:
 return [None]
 
-if transformer.enable_teacache:
+if transformer.enable_cache:
 cache_size = round( infer_length / frames_per_batch )
 transformer.previous_residual = [None] * latent_items
 cache_all_previous_residual = [None] * latent_items
@@ -1180,7 +1180,7 @@ class HunyuanVideoAudioPipeline(DiffusionPipeline):
 img_ref_len = (latent_model_input.shape[-1] // 2) * (latent_model_input.shape[-2] // 2) * ( 1)
 img_all_len = (latents_all.shape[-1] // 2) * (latents_all.shape[-2] // 2) * latents_all.shape[-3]
 
-if transformer.enable_teacache and cache_size > 1:
+if transformer.enable_cache and cache_size > 1:
 for l in range(latent_items):
 if cache_all_previous_residual[l] != None:
 bsz = cache_all_previous_residual[l].shape[0]
@@ -1297,7 +1297,7 @@ class HunyuanVideoAudioPipeline(DiffusionPipeline):
 pred_latents[:, :, p] += latents[:, :, iii]
 counter[:, :, p] += 1
 
-if transformer.enable_teacache and cache_size > 1:
+if transformer.enable_cache and cache_size > 1:
 for l in range(latent_items):
 if transformer.previous_residual[l] != None:
 bsz = transformer.previous_residual[l].shape[0]
@@ -922,7 +922,7 @@ class HYVideoDiffusionTransformer(ModelMixin, ConfigMixin):
 freqs_cis = (freqs_cos, freqs_sin) if freqs_cos is not None else None
 
 
-if self.enable_teacache:
+if self.enable_cache:
 if x_id == 0:
 self.should_calc = True
 inp = img[0:1]
@@ -932,7 +932,7 @@ class HYVideoDiffusionTransformer(ModelMixin, ConfigMixin):
 normed_inp = normed_inp.to(torch.bfloat16)
 modulated_inp = modulate( normed_inp, shift=img_mod1_shift, scale=img_mod1_scale )
 del normed_inp, img_mod1_shift, img_mod1_scale
-if step_no <= self.teacache_start_step or step_no == self.num_steps-1:
+if step_no <= self.cache_start_step or step_no == self.num_steps-1:
 self.accumulated_rel_l1_distance = 0
 else:
 coefficients = [7.33226126e+02, -4.01131952e+02, 6.75869174e+01, -3.14987800e+00, 9.61237896e-02]
@@ -950,7 +950,7 @@ class HYVideoDiffusionTransformer(ModelMixin, ConfigMixin):
 if not self.should_calc:
 img += self.previous_residual[x_id]
 else:
-if self.enable_teacache:
+if self.enable_cache:
 self.previous_residual[x_id] = None
 ori_img = img[0:1].clone()
 # --------------------- Pass through DiT blocks ------------------------
@@ -1014,7 +1014,7 @@ class HYVideoDiffusionTransformer(ModelMixin, ConfigMixin):
 single_block_args = None
 
 # img = x[:, :img_seq_len, ...]
-if self.enable_teacache:
+if self.enable_cache:
 if len(img) > 1:
 self.previous_residual[0] = torch.empty_like(img)
 for i, (x, residual) in enumerate(zip(img, self.previous_residual[0])):
@@ -551,8 +551,8 @@ def main():
 
 # Setup tea cache if needed
 trans = wan_model.model
-trans.enable_teacache = (args.teacache > 0)
-if trans.enable_teacache:
+trans.enable_cache = (args.teacache > 0)
+if trans.enable_cache:
 if "480p" in args.transformer_file:
 # example from your code
 trans.coefficients = [-3.02331670e+02, 2.23948934e+02, -5.25463970e+01, 5.87348440e+00, -2.01973289e-01]
@@ -582,10 +582,10 @@ def main():
 enable_riflex = args.riflex
 
 # If teacache => reset counters
-if trans.enable_teacache:
+if trans.enable_cache:
 trans.teacache_counter = 0
 trans.teacache_multiplier = args.teacache
-trans.teacache_start_step = int(args.teacache_start * args.steps / 100.0)
+trans.cache_start_step = int(args.teacache_start * args.steps / 100.0)
 trans.num_steps = args.steps
 trans.teacache_skipped_steps = 0
 trans.previous_residual_uncond = None
@@ -655,7 +655,7 @@ def main():
 raise RuntimeError("No frames were returned (maybe generation was aborted or failed).")
 
 # If teacache was used, we can see how many steps were skipped
-if trans.enable_teacache:
+if trans.enable_cache:
 print(f"TeaCache skipped steps: {trans.teacache_skipped_steps} / {args.steps}")
 
 # Save result
@@ -78,7 +78,7 @@ class DTT2V:
 self.model.eval().requires_grad_(False)
 if save_quantized:
 from wan.utils.utils import save_quantized_model
-save_quantized_model(self.model, model_filename[-1], dtype, base_config_file)
+save_quantized_model(self.model, model_filename[0], dtype, base_config_file)
 
 self.scheduler = FlowUniPCMultistepScheduler()
 
@@ -316,7 +316,7 @@ class DTT2V:
 updated_num_steps= len(step_matrix)
 if callback != None:
 callback(-1, None, True, override_num_inference_steps = updated_num_steps)
-if self.model.enable_teacache:
+if self.model.enable_cache:
 x_count = 2 if self.do_classifier_free_guidance else 1
 self.model.previous_residual = [None] * x_count
 time_steps_comb = []
@@ -327,7 +327,7 @@ class DTT2V:
 if overlap_noise > 0 and valid_interval_start < predix_video_latent_length:
 timestep[:, valid_interval_start:predix_video_latent_length] = overlap_noise
 time_steps_comb.append(timestep)
-self.model.compute_teacache_threshold(self.model.teacache_start_step, time_steps_comb, self.model.teacache_multiplier)
+self.model.compute_teacache_threshold(self.model.cache_start_step, time_steps_comb, self.model.teacache_multiplier)
 del time_steps_comb
 from mmgp import offload
 freqs = get_rotary_pos_embed(latents.shape[1 :], enable_RIFLEx= False)
@@ -116,7 +116,7 @@ class WanI2V:
 self.model.eval().requires_grad_(False)
 if save_quantized:
 from wan.utils.utils import save_quantized_model
-save_quantized_model(self.model, model_filename[-1], dtype, base_config_file)
+save_quantized_model(self.model, model_filename[0], dtype, base_config_file)
 
 
 self.sample_neg_prompt = config.sample_neg_prompt
@@ -317,9 +317,9 @@ class WanI2V:
 "audio_context_lens": audio_context_lens,
 })
 
-if self.model.enable_teacache:
+if self.model.enable_cache:
 self.model.previous_residual = [None] * (3 if audio_cfg_scale !=None else 2)
-self.model.compute_teacache_threshold(self.model.teacache_start_step, timesteps, self.model.teacache_multiplier)
+self.model.compute_teacache_threshold(self.model.cache_start_step, timesteps, self.model.teacache_multiplier)
 
 # self.model.to(self.device)
 if callback != None:
@@ -194,6 +194,11 @@ def pay_attention(
 
 q = q.to(v.dtype)
 k = k.to(v.dtype)
+
+if attn == "chipmunk":
+from src.chipmunk.modules import SparseDiffMlp, SparseDiffAttn
+from src.chipmunk.util import LayerCounter, GLOBAL_CONFIG
+
 if b > 1 and k_lens != None and attn in ("sage2", "sdpa"):
 assert attention_mask == None
 # Poor's man var k len attention
@@ -11,6 +11,8 @@ from typing import Union,Optional
 from mmgp import offload
 from .attention import pay_attention
 from torch.backends.cuda import sdp_kernel
+# from src.chipmunk.modules import SparseDiffMlp, SparseDiffAttn
+# from src.chipmunk.util import LayerCounter, GLOBAL_CONFIG
 
 __all__ = ['WanModel']
 
@@ -172,6 +174,11 @@ class WanSelfAttention(nn.Module):
 self.norm_q = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity()
 self.norm_k = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity()
 
+# Only initialize SparseDiffAttn if this is not a subclass initialization
+# if self.__class__ == WanSelfAttention:
+# layer_num, layer_counter = LayerCounter.build_for_layer(is_attn_sparse=True, is_mlp_sparse=False)
+# self.attn = SparseDiffAttn(layer_num, layer_counter)
+
 def forward(self, xlist, grid_sizes, freqs, block_mask = None):
 r"""
 Args:
@@ -197,7 +204,10 @@ class WanSelfAttention(nn.Module):
 del q,k
 
 q,k = apply_rotary_emb(qklist, freqs, head_first=False)
-if block_mask == None:
+chipmunk = offload.shared_state["_chipmunk"]
+if chipmunk:
+x = self.attn(q, k, v)
+elif block_mask == None:
 qkv_list = [q,k,v]
 del q,k,v
 x = pay_attention(
@@ -954,6 +964,16 @@ class WanModel(ModelMixin, ConfigMixin):
 x_list[i] = x
 x, y = None, None
 
+offload.shared_state["_chipmunk"] = False
+chipmunk = offload.shared_state["_chipmunk"]
+if chipmunk:
+voxel_shape = (4, 6, 8)
+for x in x_list:
+from src.chipmunk.ops.voxel import voxel_chunk_no_padding, reverse_voxel_chunk_no_padding
+x = x.unsqueeze(-1)
+x_og_shape = x.shape
+x = voxel_chunk_no_padding(x, voxel_shape).squeeze(-1).transpose(1, 2)
+x = None
 
 block_mask = None
 if causal_attention and causal_block_size > 0 and False: # NEVER WORKED
@@ -1027,11 +1047,11 @@ class WanModel(ModelMixin, ConfigMixin):
 del c
 
 should_calc = True
-if self.enable_teacache:
+if self.enable_cache:
 if x_id != 0:
 should_calc = self.should_calc
 else:
-if current_step <= self.teacache_start_step or current_step == self.num_steps-1:
+if current_step <= self.cache_start_step or current_step == self.num_steps-1:
 should_calc = True
 self.accumulated_rel_l1_distance = 0
 else:
@@ -1057,7 +1077,7 @@ class WanModel(ModelMixin, ConfigMixin):
 x += self.previous_residual[x_id]
 x = None
 else:
-if self.enable_teacache:
+if self.enable_cache:
 if joint_pass:
 self.previous_residual = [ None ] * len(self.previous_residual)
 else:
@@ -1084,7 +1104,7 @@ class WanModel(ModelMixin, ConfigMixin):
 del x
 del context, hints
 
-if self.enable_teacache:
+if self.enable_cache:
 if joint_pass:
 for i, (x, ori, is_source) in enumerate(zip(x_list, ori_hidden_states, is_source_x)) :
 if i == 0 or is_source and i != last_x_idx :
@@ -1101,6 +1121,10 @@ class WanModel(ModelMixin, ConfigMixin):
 residual, ori_hidden_states = None, None
 
 for i, x in enumerate(x_list):
+if chipmunk:
+x = reverse_voxel_chunk_no_padding(x.transpose(1, 2).unsqueeze(-1), x_og_shape, voxel_shape).squeeze(-1)
+x = x.flatten(2).transpose(1, 2)
+
 # head
 x = self.head(x, e)
 
@@ -101,7 +101,7 @@ class WanT2V:
 self.model.eval().requires_grad_(False)
 if save_quantized:
 from wan.utils.utils import save_quantized_model
-save_quantized_model(self.model, model_filename[-1], dtype, base_config_file)
+save_quantized_model(self.model, model_filename[1 if base_model_type=="fantasy" else 0], dtype, base_config_file)
 
 self.sample_neg_prompt = config.sample_neg_prompt
 
@@ -458,13 +458,24 @@ class WanT2V:
 z_reactive = [ zz[0:16, 0:overlapped_latents_size + ref_images_count].clone() for zz in z]
 
 
-if self.model.enable_teacache:
+if self.model.enable_cache:
 x_count = 3 if phantom else 2
 self.model.previous_residual = [None] * x_count
-self.model.compute_teacache_threshold(self.model.teacache_start_step, timesteps, self.model.teacache_multiplier)
+self.model.compute_teacache_threshold(self.model.cache_start_step, timesteps, self.model.teacache_multiplier)
 if callback != None:
 callback(-1, None, True)
-prev = 50/1000
+# seq_shape = (21, 45, 80)
+# local_heads_num = 40 #12 for 1.3B
+
+# self.model.blocks[0].self_attn.attn.initialize_static_mask(
+# seq_shape=seq_shape,
+# txt_len=0,
+# local_heads_num=local_heads_num,
+# device='cuda'
+# )
+# self.model.blocks[0].self_attn.attn.layer_counter.reset()
 
 for i, t in enumerate(tqdm(timesteps)):
 
 timestep = [t]
@@ -338,17 +338,19 @@ def create_progress_hook(filename):
 return hook
 
 def save_quantized_model(model, model_filename, dtype, config_file):
+if "quanto" in model_filename:
+return
 from mmgp import offload
 if dtype == torch.bfloat16:
-model_filename = model_filename.replace("fp16", "bf16")
+model_filename = model_filename.replace("fp16", "bf16").replace("FP16", "bf16")
 elif dtype == torch.float16:
-model_filename = model_filename.replace("bf16", "fp16")
+model_filename = model_filename.replace("bf16", "fp16").replace("BF16", "fp16")
 
-if "_fp16" in model_filename:
-model_filename = model_filename.replace("_fp16", "_quanto_fp16_int8")
-elif "_bf16" in model_filename:
-model_filename = model_filename.replace("_bf16", "_quanto_bf16_int8")
-else:
+for rep in ["mfp16", "fp16", "mbf16", "bf16"]:
+if "_" + rep in model_filename:
+model_filename = model_filename.replace("_" + rep, "_quanto_" + rep + "_int8")
+break
+if not "quanto" in model_filename:
 pos = model_filename.rfind(".")
 model_filename = model_filename[:pos] + "_quanto_int8" + model_filename[pos+1:]
 
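For readers following the change to save_quantized_model() above, here is a standalone sketch (illustrative, not the repository code) of the filename mapping it performs before the quantized weights are written; the helper name and example path are assumptions:

```python
# Sketch of the quanto filename mapping; mirrors the updated logic above, but keeps
# the "." before the extension in the fallback branch.
def quantized_filename(model_filename: str, use_bf16: bool):
    if "quanto" in model_filename:
        return None  # already quantized, nothing to save
    # normalise the dtype tag (upper or lower case) to the chosen data type
    if use_bf16:
        model_filename = model_filename.replace("fp16", "bf16").replace("FP16", "bf16")
    else:
        model_filename = model_filename.replace("bf16", "fp16").replace("BF16", "fp16")
    # insert the quanto/int8 marker next to the dtype tag when one is present
    for rep in ["mfp16", "fp16", "mbf16", "bf16"]:
        if "_" + rep in model_filename:
            return model_filename.replace("_" + rep, "_quanto_" + rep + "_int8")
    # otherwise append a generic quanto marker before the extension
    pos = model_filename.rfind(".")
    return model_filename[:pos] + "_quanto_int8" + model_filename[pos:]

print(quantized_filename("ckpts/finetune_mbf16.safetensors", use_bf16=True))
# ckpts/finetune_quanto_mbf16_int8.safetensors
```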
wgp.py (48 changed lines)
@@ -1624,7 +1624,7 @@ def get_model_family(model_type):
 
 def test_class_i2v(model_type):
 model_type = get_base_model_type(model_type)
-return model_type in ["i2v", "fun_inp_1.3B", "fun_inp", "flf2v_720p", "fantasy", "hunyuan_i2v" ]
+return model_type in ["i2v", "i2v_720p", "fun_inp_1.3B", "fun_inp", "flf2v_720p", "fantasy", "hunyuan_i2v" ]
 
 def get_model_name(model_type, description_container = [""]):
 finetune_def = get_model_finetune_def(model_type)
@@ -1731,19 +1731,19 @@ def get_model_filename(model_type, quantization ="int8", dtype_policy = ""):
 raw_filename = choices[0]
 else:
 if quantization in ("int8", "fp8"):
-sub_choices = [ name for name in choices if quantization in name]
+sub_choices = [ name for name in choices if quantization in name or quantization.upper() in name]
 else:
 sub_choices = [ name for name in choices if "quanto" not in name]
 
 if len(sub_choices) > 0:
 dtype_str = "fp16" if dtype == torch.float16 else "bf16"
-new_sub_choices = [ name for name in sub_choices if dtype_str in name]
+new_sub_choices = [ name for name in sub_choices if dtype_str in name or dtype_str.upper() in name]
 sub_choices = new_sub_choices if len(new_sub_choices) > 0 else sub_choices
 raw_filename = sub_choices[0]
 else:
 raw_filename = choices[0]
 
-if dtype == torch.float16 and not "fp16" in raw_filename and model_family == "wan" and finetune_def == None :
+if dtype == torch.float16 and not ("fp16" in raw_filename or "FP16" in raw_filename) and model_family == "wan" and finetune_def == None :
 if "quanto_int8" in raw_filename:
 raw_filename = raw_filename.replace("quanto_int8", "quanto_fp16_int8")
 elif "quanto_bf16_int8" in raw_filename:
@@ -1753,6 +1753,8 @@ def get_model_filename(model_type, quantization ="int8", dtype_policy = ""):
 return raw_filename
 
 def get_transformer_dtype(model_family, transformer_dtype_policy):
+if not isinstance(transformer_dtype_policy, str):
+return transformer_dtype_policy
 if len(transformer_dtype_policy) == 0:
 if not bfloat16_supported:
 return torch.float16
@@ -2290,19 +2292,25 @@ def get_transformer_model(model):
 
 def load_models(model_type):
 global transformer_type, transformer_loras_filenames
-model_filename = get_model_filename(model_type=model_type, quantization= transformer_quantization, dtype_policy = transformer_dtype_policy)
 base_model_type = get_base_model_type(model_type)
 finetune_def = get_model_finetune_def(model_type)
-quantizeTransformer = finetune_def !=None and transformer_quantization in ("int8", "fp8") and finetune_def.get("auto_quantize", False) and not "quanto" in model_filename
-
-model_family = get_model_family(model_type)
-perc_reserved_mem_max = args.perc_reserved_mem_max
 preload =int(args.preload)
-save_quantized = args.save_quantized
+save_quantized = args.save_quantized and finetune_def != None
+model_filename = get_model_filename(model_type=model_type, quantization= "" if save_quantized else transformer_quantization, dtype_policy = transformer_dtype_policy)
+if save_quantized and "quanto" in model_filename:
+save_quantized = False
+print("Need to provide a non quantized model to create a quantized model to be saved")
+quantizeTransformer = not save_quantized and finetune_def !=None and transformer_quantization in ("int8", "fp8") and finetune_def.get("auto_quantize", False) and not "quanto" in model_filename
+model_family = get_model_family(model_type)
+transformer_dtype = get_transformer_dtype(model_family, transformer_dtype_policy)
+if quantizeTransformer or "quanto" in model_filename:
+transformer_dtype = torch.bfloat16 if "bf16" in model_filename or "BF16" in model_filename else transformer_dtype
+transformer_dtype = torch.float16 if "fp16" in model_filename or "FP16" in model_filename else transformer_dtype
+perc_reserved_mem_max = args.perc_reserved_mem_max
 if preload == 0:
 preload = server_config.get("preload_in_VRAM", 0)
 new_transformer_loras_filenames = None
-dependent_models, dependent_models_types = get_dependent_models(model_type, quantization= transformer_quantization, dtype_policy = transformer_dtype_policy)
+dependent_models, dependent_models_types = get_dependent_models(model_type, quantization= transformer_quantization, dtype_policy = transformer_dtype)
 new_transformer_loras_filenames = [model_filename] if "_lora" in model_filename else None
 
 model_file_list = dependent_models + [model_filename]
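To summarize the load_models() change above, here is a self-contained sketch (illustrative names, not the wgp.py API) of the decision it now makes: with --save-quantized, the non-quantized checkpoint is requested and on-the-fly quantization is skipped so that a fresh quanto file can be written; the real code additionally folds in the configured quantization type and transformer dtype.

```python
# Illustrative decision logic only; "has_finetune_def" and "auto_quantize" stand in
# for the finetune definition lookups done by the real code.
def resolve_quantization(model_filename: str, save_quantized_flag: bool,
                         has_finetune_def: bool, auto_quantize: bool):
    save_quantized = save_quantized_flag and has_finetune_def
    if save_quantized and "quanto" in model_filename:
        save_quantized = False  # an already quantized source cannot be re-saved as quanto
    quantize_on_load = (not save_quantized and has_finetune_def
                        and auto_quantize and "quanto" not in model_filename)
    return save_quantized, quantize_on_load

print(resolve_quantization("ckpts/finetune_fp16.safetensors", True, True, True))
# (True, False): save a new quanto file, do not quantize on the fly
print(resolve_quantization("ckpts/finetune_quanto_fp16_int8.safetensors", True, True, True))
# (False, False): nothing to save, and the checkpoint is already quantized
```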
@@ -2310,15 +2318,11 @@ def load_models(model_type):
 new_transformer_filename = model_file_list[-1]
 if finetune_def != None:
 for module_type in finetune_def.get("modules", []):
-model_file_list.append(get_model_filename(module_type, transformer_quantization, transformer_dtype_policy))
+model_file_list.append(get_model_filename(module_type, transformer_quantization, transformer_dtype))
 model_type_list.append(module_type)
 
 for filename, file_model_type in zip(model_file_list, model_type_list):
 download_models(filename, file_model_type)
-transformer_dtype = get_transformer_dtype(model_family, transformer_dtype_policy)
-if quantizeTransformer:
-transformer_dtype = torch.bfloat16 if "bf16" in model_filename else transformer_dtype
-transformer_dtype = torch.float16 if "fp16" in model_filename else transformer_dtype
 VAE_dtype = torch.float16 if server_config.get("vae_precision","16") == "16" else torch.float
 mixed_precision_transformer = server_config.get("mixed_precision","0") == "1"
 transformer_filename = None
@@ -2364,7 +2368,7 @@ def load_models(model_type):
 prompt_enhancer_llm_tokenizer = None
 
 
-offloadobj = offload.profile(pipe, profile_no= profile, compile = compile, quantizeTransformer = quantizeTransformer, loras = "transformer", coTenantsMap= {}, perc_reserved_mem_max = perc_reserved_mem_max , convertWeightsFloatTo = transformer_dtype, **kwargs)
+offloadobj = offload.profile(pipe, profile_no= profile, compile = compile, quantizeTransformer = False, loras = "transformer", coTenantsMap= {}, perc_reserved_mem_max = perc_reserved_mem_max , convertWeightsFloatTo = transformer_dtype, **kwargs)
 if len(args.gpu) > 0:
 torch.set_default_device(args.gpu)
 transformer_loras_filenames = new_transformer_loras_filenames
@@ -3092,11 +3096,11 @@ def generate_video(
 # TeaCache
 if args.teacache > 0:
 tea_cache_setting = args.teacache
-trans.enable_teacache = tea_cache_setting > 0
-if trans.enable_teacache:
+trans.enable_cache = tea_cache_setting > 0
+if trans.enable_cache:
 trans.teacache_multiplier = tea_cache_setting
 trans.rel_l1_thresh = 0
-trans.teacache_start_step = int(tea_cache_start_step_perc*num_inference_steps/100)
+trans.cache_start_step = int(tea_cache_start_step_perc*num_inference_steps/100)
 if get_model_family(model_type) == "wan":
 if image2video:
 if '720p' in model_filename:
@@ -3323,7 +3327,7 @@ def generate_video(
 progress_args = [0, merge_status_context(status, "Encoding Prompt")]
 send_cmd("progress", progress_args)
 
-if trans.enable_teacache:
+if trans.enable_cache:
 trans.teacache_counter = 0
 trans.num_steps = num_inference_steps
 trans.teacache_skipped_steps = 0
@@ -3412,7 +3416,7 @@ def generate_video(
 trans.previous_residual = None
 trans.previous_modulated_input = None
 
-if trans.enable_teacache:
+if trans.enable_cache:
 print(f"Teacache Skipped Steps:{trans.teacache_skipped_steps}/{trans.num_steps}" )
 
 if samples != None: