fixed save quantization

This commit is contained in:
DeepBeepMeep 2025-06-13 18:41:35 +02:00
parent 5a63326bb9
commit 73cf4e43c3
12 changed files with 116 additions and 64 deletions

View File

@ -66,14 +66,20 @@ If a model is not quantized, it is assumed to be mostly 16 bits (with maybe a fe
If a model is quantized, the term *quanto* should also be included since WanGP supports for the moment only *quanto* quantized models; more specifically, you should replace *fp16* by *quanto_fp16_int8* or *bf16* by *quanto_bf16_int8*.
Please note it is important that *bf16*, *fp16* and *quanto* are all in lower case letters.
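For example, a non quantized *fp16* checkpoint and its *quanto* quantized counterpart would be named along the following lines (the model name itself is made up):

```python
# Illustrative file names following the convention above (hypothetical model name):
non_quantized = "mymodel_fp16.safetensors"
quantized = "mymodel_quanto_fp16_int8.safetensors"  # fp16 -> quanto_fp16_int8
```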
## Creating a Quanto Quantized file
If you launch the app with the *--save-quantized* switch, WanGP will create a quantized file in the **ckpts** subfolder just after the model has been loaded. Please note that the model will be *bf16* or *fp16* quantized depending on what you chose in the configuration menu.
1) Make sure that in the finetune definition json file there is only a URL or filepath that points to the non quantized model
2) Launch WanGP with *python wgp.py --save-quantized*
3) In the configuration menu, for the *Transformer Data Type* property choose either *BF16* or *FP16*
4) Launch a generation (settings used do not matter). As soon as the model is loaded, a new quantized model will be created in the **ckpts** subfolder it doesn't already exist.
4) Launch a video generation (settings used do not matter). As soon as the model is loaded, a new quantized model will be created in the **ckpts** subfolder if it doesn't already exist.
5) To test that this works properly set the local path in the "URLs" key of the finetune definition file. For instance *URLs = ["ckpts/finetune_quanto_fp16_int8.safetensors"]*
6) Restart WanGP and select *Scaled Int8 Quantization* in the *Transformer Model Quantization* property
7) Launch a new generation an verify in the terminal window that the right quantized model is loaded
8) In order to share the finetune definition file will need to store the fine model weights in the cloud. You can upload them for instance on *Huggingface*. You can now replace in the definition file the local path by a URL (on Huggingface to get the URL of the model file click *Copy download link* when accessing the model properties)
6) Remove *--save-quantized*, restart WanGP and select *Scaled Int8 Quantization* in the *Transformer Model Quantization* property
7) Launch a new generation and verify in the terminal window that the right quantized model is loaded
8) In order to share the finetune definition file, you will need to store the finetuned model weights in the cloud. You can upload them, for instance, to *Huggingface*. You can then replace the local path in the definition file with a URL (on Huggingface, to get the URL of the model file, click *Copy download link* when accessing the model properties)
You need to create a quantized model specifically for *bf16* or *fp16* as they cannot be converted on the fly. However, there is no need to do so for a non quantized model, as it can be converted on the fly while being loaded.
Wan models support both *fp16* and *bf16* data types, although *fp16* delivers in theory better quality. By contrast, Hunyuan and LTXV support only *bf16*.
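Putting the steps together, a finetune definition might evolve as sketched below. Only the *URLs* key is taken from the steps above; the file names and the definition file path are illustrative:

```python
import json

# Step 1: the definition points at the non quantized weights (illustrative path).
definition = {"URLs": ["ckpts/finetune_fp16.safetensors"]}

# After step 4 has written ckpts/finetune_quanto_fp16_int8.safetensors, step 5
# switches the definition to the quantized file; step 8 would then replace this
# local path with a Huggingface *Copy download link* URL.
definition["URLs"] = ["ckpts/finetune_quanto_fp16_int8.safetensors"]

with open("my_finetune.json", "w") as f:  # hypothetical definition file name
    json.dump(definition, f, indent=2)
```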

View File

@ -949,11 +949,11 @@ class HunyuanVideoPipeline(DiffusionPipeline):
# width = width or self.transformer.config.sample_size * self.vae_scale_factor
# to deal with lora scaling and other possible forward hooks
trans = self.transformer
if trans.enable_teacache:
if trans.enable_cache:
teacache_multiplier = trans.teacache_multiplier
trans.accumulated_rel_l1_distance = 0
trans.rel_l1_thresh = 0.1 if teacache_multiplier < 2 else 0.15
# trans.teacache_start_step = int(tea_cache_start_step_perc*num_inference_steps/100)
# trans.cache_start_step = int(tea_cache_start_step_perc*num_inference_steps/100)
# 1. Check inputs. Raise error if not correct
self.check_inputs(
prompt,
@ -1208,7 +1208,7 @@ class HunyuanVideoPipeline(DiffusionPipeline):
if ip_cfg_scale>0:
latent_items += 1
if self.transformer.enable_teacache:
if self.transformer.enable_cache:
self.transformer.previous_residual = [None] * latent_items
# if is_progress_bar:

View File

@ -934,7 +934,7 @@ class HunyuanVideoAudioPipeline(DiffusionPipeline):
transformer = self.transformer
if transformer.enable_teacache:
if transformer.enable_cache:
teacache_multiplier = transformer.teacache_multiplier
transformer.accumulated_rel_l1_distance = 0
transformer.rel_l1_thresh = 0.1 if teacache_multiplier < 2 else 0.15
@ -1136,7 +1136,7 @@ class HunyuanVideoAudioPipeline(DiffusionPipeline):
if self._interrupt:
return [None]
if transformer.enable_teacache:
if transformer.enable_cache:
cache_size = round( infer_length / frames_per_batch )
transformer.previous_residual = [None] * latent_items
cache_all_previous_residual = [None] * latent_items
@ -1180,7 +1180,7 @@ class HunyuanVideoAudioPipeline(DiffusionPipeline):
img_ref_len = (latent_model_input.shape[-1] // 2) * (latent_model_input.shape[-2] // 2) * ( 1)
img_all_len = (latents_all.shape[-1] // 2) * (latents_all.shape[-2] // 2) * latents_all.shape[-3]
if transformer.enable_teacache and cache_size > 1:
if transformer.enable_cache and cache_size > 1:
for l in range(latent_items):
if cache_all_previous_residual[l] != None:
bsz = cache_all_previous_residual[l].shape[0]
@ -1297,7 +1297,7 @@ class HunyuanVideoAudioPipeline(DiffusionPipeline):
pred_latents[:, :, p] += latents[:, :, iii]
counter[:, :, p] += 1
if transformer.enable_teacache and cache_size > 1:
if transformer.enable_cache and cache_size > 1:
for l in range(latent_items):
if transformer.previous_residual[l] != None:
bsz = transformer.previous_residual[l].shape[0]

View File

@ -922,7 +922,7 @@ class HYVideoDiffusionTransformer(ModelMixin, ConfigMixin):
freqs_cis = (freqs_cos, freqs_sin) if freqs_cos is not None else None
if self.enable_teacache:
if self.enable_cache:
if x_id == 0:
self.should_calc = True
inp = img[0:1]
@ -932,7 +932,7 @@ class HYVideoDiffusionTransformer(ModelMixin, ConfigMixin):
normed_inp = normed_inp.to(torch.bfloat16)
modulated_inp = modulate( normed_inp, shift=img_mod1_shift, scale=img_mod1_scale )
del normed_inp, img_mod1_shift, img_mod1_scale
if step_no <= self.teacache_start_step or step_no == self.num_steps-1:
if step_no <= self.cache_start_step or step_no == self.num_steps-1:
self.accumulated_rel_l1_distance = 0
else:
coefficients = [7.33226126e+02, -4.01131952e+02, 6.75869174e+01, -3.14987800e+00, 9.61237896e-02]
@ -950,7 +950,7 @@ class HYVideoDiffusionTransformer(ModelMixin, ConfigMixin):
if not self.should_calc:
img += self.previous_residual[x_id]
else:
if self.enable_teacache:
if self.enable_cache:
self.previous_residual[x_id] = None
ori_img = img[0:1].clone()
# --------------------- Pass through DiT blocks ------------------------
@ -1014,7 +1014,7 @@ class HYVideoDiffusionTransformer(ModelMixin, ConfigMixin):
single_block_args = None
# img = x[:, :img_seq_len, ...]
if self.enable_teacache:
if self.enable_cache:
if len(img) > 1:
self.previous_residual[0] = torch.empty_like(img)
for i, (x, residual) in enumerate(zip(img, self.previous_residual[0])):
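Across the hunks above, the `enable_teacache` / `teacache_start_step` attributes are renamed to `enable_cache` / `cache_start_step`; the skip decision itself is unchanged. A minimal sketch of how those attributes interact, assuming a transformer `trans` carrying the fields set up in the hunks (the helper name and the exact accumulation details are illustrative, not the repository's code):

```python
import numpy as np
import torch

def should_compute(trans, step_no: int, modulated_inp: torch.Tensor) -> bool:
    """Sketch of the cache skip decision driven by the renamed cache_* attributes."""
    if step_no <= trans.cache_start_step or step_no == trans.num_steps - 1:
        # Warm-up steps and the final step are always computed.
        trans.accumulated_rel_l1_distance = 0
        trans.previous_modulated_input = modulated_inp
        return True
    prev = trans.previous_modulated_input
    rel_l1 = ((modulated_inp - prev).abs().mean() / prev.abs().mean()).item()
    # The raw distance is rescaled by the fitted polynomial before being accumulated.
    trans.accumulated_rel_l1_distance += np.polyval(trans.coefficients, rel_l1)
    trans.previous_modulated_input = modulated_inp
    if trans.accumulated_rel_l1_distance < trans.rel_l1_thresh:
        trans.teacache_skipped_steps += 1
        return False  # the caller reuses previous_residual instead of running the blocks
    trans.accumulated_rel_l1_distance = 0
    return True
```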

View File

@ -551,8 +551,8 @@ def main():
# Setup tea cache if needed
trans = wan_model.model
trans.enable_teacache = (args.teacache > 0)
if trans.enable_teacache:
trans.enable_cache = (args.teacache > 0)
if trans.enable_cache:
if "480p" in args.transformer_file:
# example from your code
trans.coefficients = [-3.02331670e+02, 2.23948934e+02, -5.25463970e+01, 5.87348440e+00, -2.01973289e-01]
@ -582,10 +582,10 @@ def main():
enable_riflex = args.riflex
# If teacache => reset counters
if trans.enable_teacache:
if trans.enable_cache:
trans.teacache_counter = 0
trans.teacache_multiplier = args.teacache
trans.teacache_start_step = int(args.teacache_start * args.steps / 100.0)
trans.cache_start_step = int(args.teacache_start * args.steps / 100.0)
trans.num_steps = args.steps
trans.teacache_skipped_steps = 0
trans.previous_residual_uncond = None
@ -655,7 +655,7 @@ def main():
raise RuntimeError("No frames were returned (maybe generation was aborted or failed).")
# If teacache was used, we can see how many steps were skipped
if trans.enable_teacache:
if trans.enable_cache:
print(f"TeaCache skipped steps: {trans.teacache_skipped_steps} / {args.steps}")
# Save result

View File

@ -78,7 +78,7 @@ class DTT2V:
self.model.eval().requires_grad_(False)
if save_quantized:
from wan.utils.utils import save_quantized_model
save_quantized_model(self.model, model_filename[-1], dtype, base_config_file)
save_quantized_model(self.model, model_filename[0], dtype, base_config_file)
self.scheduler = FlowUniPCMultistepScheduler()
@ -316,7 +316,7 @@ class DTT2V:
updated_num_steps= len(step_matrix)
if callback != None:
callback(-1, None, True, override_num_inference_steps = updated_num_steps)
if self.model.enable_teacache:
if self.model.enable_cache:
x_count = 2 if self.do_classifier_free_guidance else 1
self.model.previous_residual = [None] * x_count
time_steps_comb = []
@ -327,7 +327,7 @@ class DTT2V:
if overlap_noise > 0 and valid_interval_start < predix_video_latent_length:
timestep[:, valid_interval_start:predix_video_latent_length] = overlap_noise
time_steps_comb.append(timestep)
self.model.compute_teacache_threshold(self.model.teacache_start_step, time_steps_comb, self.model.teacache_multiplier)
self.model.compute_teacache_threshold(self.model.cache_start_step, time_steps_comb, self.model.teacache_multiplier)
del time_steps_comb
from mmgp import offload
freqs = get_rotary_pos_embed(latents.shape[1 :], enable_RIFLEx= False)

View File

@ -116,7 +116,7 @@ class WanI2V:
self.model.eval().requires_grad_(False)
if save_quantized:
from wan.utils.utils import save_quantized_model
save_quantized_model(self.model, model_filename[-1], dtype, base_config_file)
save_quantized_model(self.model, model_filename[0], dtype, base_config_file)
self.sample_neg_prompt = config.sample_neg_prompt
@ -317,9 +317,9 @@ class WanI2V:
"audio_context_lens": audio_context_lens,
})
if self.model.enable_teacache:
if self.model.enable_cache:
self.model.previous_residual = [None] * (3 if audio_cfg_scale !=None else 2)
self.model.compute_teacache_threshold(self.model.teacache_start_step, timesteps, self.model.teacache_multiplier)
self.model.compute_teacache_threshold(self.model.cache_start_step, timesteps, self.model.teacache_multiplier)
# self.model.to(self.device)
if callback != None:

View File

@ -194,6 +194,11 @@ def pay_attention(
q = q.to(v.dtype)
k = k.to(v.dtype)
if attn == "chipmunk":
from src.chipmunk.modules import SparseDiffMlp, SparseDiffAttn
from src.chipmunk.util import LayerCounter, GLOBAL_CONFIG
if b > 1 and k_lens != None and attn in ("sage2", "sdpa"):
assert attention_mask == None
# Poor man's variable k_lens attention

View File

@ -11,6 +11,8 @@ from typing import Union,Optional
from mmgp import offload
from .attention import pay_attention
from torch.backends.cuda import sdp_kernel
# from src.chipmunk.modules import SparseDiffMlp, SparseDiffAttn
# from src.chipmunk.util import LayerCounter, GLOBAL_CONFIG
__all__ = ['WanModel']
@ -172,6 +174,11 @@ class WanSelfAttention(nn.Module):
self.norm_q = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity()
self.norm_k = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity()
# Only initialize SparseDiffAttn if this is not a subclass initialization
# if self.__class__ == WanSelfAttention:
# layer_num, layer_counter = LayerCounter.build_for_layer(is_attn_sparse=True, is_mlp_sparse=False)
# self.attn = SparseDiffAttn(layer_num, layer_counter)
def forward(self, xlist, grid_sizes, freqs, block_mask = None):
r"""
Args:
@ -197,7 +204,10 @@ class WanSelfAttention(nn.Module):
del q,k
q,k = apply_rotary_emb(qklist, freqs, head_first=False)
if block_mask == None:
chipmunk = offload.shared_state["_chipmunk"]
if chipmunk:
x = self.attn(q, k, v)
elif block_mask == None:
qkv_list = [q,k,v]
del q,k,v
x = pay_attention(
@ -954,6 +964,16 @@ class WanModel(ModelMixin, ConfigMixin):
x_list[i] = x
x, y = None, None
offload.shared_state["_chipmunk"] = False
chipmunk = offload.shared_state["_chipmunk"]
if chipmunk:
voxel_shape = (4, 6, 8)
for x in x_list:
from src.chipmunk.ops.voxel import voxel_chunk_no_padding, reverse_voxel_chunk_no_padding
x = x.unsqueeze(-1)
x_og_shape = x.shape
x = voxel_chunk_no_padding(x, voxel_shape).squeeze(-1).transpose(1, 2)
x = None
block_mask = None
if causal_attention and causal_block_size > 0 and False: # NEVER WORKED
@ -1027,11 +1047,11 @@ class WanModel(ModelMixin, ConfigMixin):
del c
should_calc = True
if self.enable_teacache:
if self.enable_cache:
if x_id != 0:
should_calc = self.should_calc
else:
if current_step <= self.teacache_start_step or current_step == self.num_steps-1:
if current_step <= self.cache_start_step or current_step == self.num_steps-1:
should_calc = True
self.accumulated_rel_l1_distance = 0
else:
@ -1057,7 +1077,7 @@ class WanModel(ModelMixin, ConfigMixin):
x += self.previous_residual[x_id]
x = None
else:
if self.enable_teacache:
if self.enable_cache:
if joint_pass:
self.previous_residual = [ None ] * len(self.previous_residual)
else:
@ -1084,7 +1104,7 @@ class WanModel(ModelMixin, ConfigMixin):
del x
del context, hints
if self.enable_teacache:
if self.enable_cache:
if joint_pass:
for i, (x, ori, is_source) in enumerate(zip(x_list, ori_hidden_states, is_source_x)) :
if i == 0 or is_source and i != last_x_idx :
@ -1101,6 +1121,10 @@ class WanModel(ModelMixin, ConfigMixin):
residual, ori_hidden_states = None, None
for i, x in enumerate(x_list):
if chipmunk:
x = reverse_voxel_chunk_no_padding(x.transpose(1, 2).unsqueeze(-1), x_og_shape, voxel_shape).squeeze(-1)
x = x.flatten(2).transpose(1, 2)
# head
x = self.head(x, e)

View File

@ -101,7 +101,7 @@ class WanT2V:
self.model.eval().requires_grad_(False)
if save_quantized:
from wan.utils.utils import save_quantized_model
save_quantized_model(self.model, model_filename[-1], dtype, base_config_file)
save_quantized_model(self.model, model_filename[1 if base_model_type=="fantasy" else 0], dtype, base_config_file)
self.sample_neg_prompt = config.sample_neg_prompt
@ -458,13 +458,24 @@ class WanT2V:
z_reactive = [ zz[0:16, 0:overlapped_latents_size + ref_images_count].clone() for zz in z]
if self.model.enable_teacache:
if self.model.enable_cache:
x_count = 3 if phantom else 2
self.model.previous_residual = [None] * x_count
self.model.compute_teacache_threshold(self.model.teacache_start_step, timesteps, self.model.teacache_multiplier)
self.model.compute_teacache_threshold(self.model.cache_start_step, timesteps, self.model.teacache_multiplier)
if callback != None:
callback(-1, None, True)
prev = 50/1000
# seq_shape = (21, 45, 80)
# local_heads_num = 40 #12 for 1.3B
# self.model.blocks[0].self_attn.attn.initialize_static_mask(
# seq_shape=seq_shape,
# txt_len=0,
# local_heads_num=local_heads_num,
# device='cuda'
# )
# self.model.blocks[0].self_attn.attn.layer_counter.reset()
for i, t in enumerate(tqdm(timesteps)):
timestep = [t]

View File

@ -338,17 +338,19 @@ def create_progress_hook(filename):
return hook
def save_quantized_model(model, model_filename, dtype, config_file):
if "quanto" in model_filename:
return
from mmgp import offload
if dtype == torch.bfloat16:
model_filename = model_filename.replace("fp16", "bf16")
model_filename = model_filename.replace("fp16", "bf16").replace("FP16", "bf16")
elif dtype == torch.float16:
model_filename = model_filename.replace("bf16", "fp16")
model_filename = model_filename.replace("bf16", "fp16").replace("BF16", "bf16")
if "_fp16" in model_filename:
model_filename = model_filename.replace("_fp16", "_quanto_fp16_int8")
elif "_bf16" in model_filename:
model_filename = model_filename.replace("_bf16", "_quanto_bf16_int8")
else:
for rep in ["mfp16", "fp16", "mbf16", "bf16"]:
if "_" + rep in model_filename:
model_filename = model_filename.replace("_" + rep, "_quanto_" + rep + "_int8")
break
if not "quanto" in model_filename:
pos = model_filename.rfind(".")
model_filename = model_filename[:pos] + "_quanto_int8" + model_filename[pos:]
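The effect of the extended renaming loop above is easiest to see on a few filenames. The helper below simply replays the naming rules on made-up paths; it is an illustration, not part of the module:

```python
def expected_quantized_name(model_filename: str) -> str:
    # Replays the renaming rules of save_quantized_model (illustration only).
    for rep in ["mfp16", "fp16", "mbf16", "bf16"]:
        if "_" + rep in model_filename:
            return model_filename.replace("_" + rep, "_quanto_" + rep + "_int8")
    # No recognised precision suffix: tag the name just before the extension.
    pos = model_filename.rfind(".")
    return model_filename[:pos] + "_quanto_int8" + model_filename[pos:]

assert expected_quantized_name("ckpts/model_mbf16.safetensors") == "ckpts/model_quanto_mbf16_int8.safetensors"
assert expected_quantized_name("ckpts/model.safetensors") == "ckpts/model_quanto_int8.safetensors"
```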

wgp.py
View File

@ -1624,7 +1624,7 @@ def get_model_family(model_type):
def test_class_i2v(model_type):
model_type = get_base_model_type(model_type)
return model_type in ["i2v", "fun_inp_1.3B", "fun_inp", "flf2v_720p", "fantasy", "hunyuan_i2v" ]
return model_type in ["i2v", "i2v_720p", "fun_inp_1.3B", "fun_inp", "flf2v_720p", "fantasy", "hunyuan_i2v" ]
def get_model_name(model_type, description_container = [""]):
finetune_def = get_model_finetune_def(model_type)
@ -1731,19 +1731,19 @@ def get_model_filename(model_type, quantization ="int8", dtype_policy = ""):
raw_filename = choices[0]
else:
if quantization in ("int8", "fp8"):
sub_choices = [ name for name in choices if quantization in name]
sub_choices = [ name for name in choices if quantization in name or quantization.upper() in name]
else:
sub_choices = [ name for name in choices if "quanto" not in name]
if len(sub_choices) > 0:
dtype_str = "fp16" if dtype == torch.float16 else "bf16"
new_sub_choices = [ name for name in sub_choices if dtype_str in name]
new_sub_choices = [ name for name in sub_choices if dtype_str in name or dtype_str.upper() in name]
sub_choices = new_sub_choices if len(new_sub_choices) > 0 else sub_choices
raw_filename = sub_choices[0]
else:
raw_filename = choices[0]
if dtype == torch.float16 and not "fp16" in raw_filename and model_family == "wan" and finetune_def == None :
if dtype == torch.float16 and not any(x in raw_filename for x in ("fp16", "FP16")) and model_family == "wan" and finetune_def == None :
if "quanto_int8" in raw_filename:
raw_filename = raw_filename.replace("quanto_int8", "quanto_fp16_int8")
elif "quanto_bf16_int8" in raw_filename:
@ -1753,6 +1753,8 @@ def get_model_filename(model_type, quantization ="int8", dtype_policy = ""):
return raw_filename
def get_transformer_dtype(model_family, transformer_dtype_policy):
if not isinstance(transformer_dtype_policy, str):
return transformer_dtype_policy
if len(transformer_dtype_policy) == 0:
if not bfloat16_supported:
return torch.float16
@ -2290,19 +2292,25 @@ def get_transformer_model(model):
def load_models(model_type):
global transformer_type, transformer_loras_filenames
model_filename = get_model_filename(model_type=model_type, quantization= transformer_quantization, dtype_policy = transformer_dtype_policy)
base_model_type = get_base_model_type(model_type)
finetune_def = get_model_finetune_def(model_type)
quantizeTransformer = finetune_def !=None and transformer_quantization in ("int8", "fp8") and finetune_def.get("auto_quantize", False) and not "quanto" in model_filename
model_family = get_model_family(model_type)
perc_reserved_mem_max = args.perc_reserved_mem_max
preload =int(args.preload)
save_quantized = args.save_quantized
save_quantized = args.save_quantized and finetune_def != None
model_filename = get_model_filename(model_type=model_type, quantization= "" if save_quantized else transformer_quantization, dtype_policy = transformer_dtype_policy)
if save_quantized and "quanto" in model_filename:
save_quantized = False
print("Need to provide a non quantized model to create a quantized model to be saved")
quantizeTransformer = not save_quantized and finetune_def !=None and transformer_quantization in ("int8", "fp8") and finetune_def.get("auto_quantize", False) and not "quanto" in model_filename
model_family = get_model_family(model_type)
transformer_dtype = get_transformer_dtype(model_family, transformer_dtype_policy)
if quantizeTransformer or "quanto" in model_filename:
transformer_dtype = torch.bfloat16 if "bf16" in model_filename or "BF16" in model_filename else transformer_dtype
transformer_dtype = torch.float16 if "fp16" in model_filename or "FP16" in model_filename else transformer_dtype
perc_reserved_mem_max = args.perc_reserved_mem_max
if preload == 0:
preload = server_config.get("preload_in_VRAM", 0)
new_transformer_loras_filenames = None
dependent_models, dependent_models_types = get_dependent_models(model_type, quantization= transformer_quantization, dtype_policy = transformer_dtype_policy)
dependent_models, dependent_models_types = get_dependent_models(model_type, quantization= transformer_quantization, dtype_policy = transformer_dtype)
new_transformer_loras_filenames = [model_filename] if "_lora" in model_filename else None
model_file_list = dependent_models + [model_filename]
@ -2310,15 +2318,11 @@ def load_models(model_type):
new_transformer_filename = model_file_list[-1]
if finetune_def != None:
for module_type in finetune_def.get("modules", []):
model_file_list.append(get_model_filename(module_type, transformer_quantization, transformer_dtype_policy))
model_file_list.append(get_model_filename(module_type, transformer_quantization, transformer_dtype))
model_type_list.append(module_type)
for filename, file_model_type in zip(model_file_list, model_type_list):
download_models(filename, file_model_type)
transformer_dtype = get_transformer_dtype(model_family, transformer_dtype_policy)
if quantizeTransformer:
transformer_dtype = torch.bfloat16 if "bf16" in model_filename else transformer_dtype
transformer_dtype = torch.float16 if "fp16" in model_filename else transformer_dtype
VAE_dtype = torch.float16 if server_config.get("vae_precision","16") == "16" else torch.float
mixed_precision_transformer = server_config.get("mixed_precision","0") == "1"
transformer_filename = None
@ -2364,7 +2368,7 @@ def load_models(model_type):
prompt_enhancer_llm_tokenizer = None
offloadobj = offload.profile(pipe, profile_no= profile, compile = compile, quantizeTransformer = quantizeTransformer, loras = "transformer", coTenantsMap= {}, perc_reserved_mem_max = perc_reserved_mem_max , convertWeightsFloatTo = transformer_dtype, **kwargs)
offloadobj = offload.profile(pipe, profile_no= profile, compile = compile, quantizeTransformer = False, loras = "transformer", coTenantsMap= {}, perc_reserved_mem_max = perc_reserved_mem_max , convertWeightsFloatTo = transformer_dtype, **kwargs)
if len(args.gpu) > 0:
torch.set_default_device(args.gpu)
transformer_loras_filenames = new_transformer_loras_filenames
@ -3092,11 +3096,11 @@ def generate_video(
# TeaCache
if args.teacache > 0:
tea_cache_setting = args.teacache
trans.enable_teacache = tea_cache_setting > 0
if trans.enable_teacache:
trans.enable_cache = tea_cache_setting > 0
if trans.enable_cache:
trans.teacache_multiplier = tea_cache_setting
trans.rel_l1_thresh = 0
trans.teacache_start_step = int(tea_cache_start_step_perc*num_inference_steps/100)
trans.cache_start_step = int(tea_cache_start_step_perc*num_inference_steps/100)
if get_model_family(model_type) == "wan":
if image2video:
if '720p' in model_filename:
@ -3323,7 +3327,7 @@ def generate_video(
progress_args = [0, merge_status_context(status, "Encoding Prompt")]
send_cmd("progress", progress_args)
if trans.enable_teacache:
if trans.enable_cache:
trans.teacache_counter = 0
trans.num_steps = num_inference_steps
trans.teacache_skipped_steps = 0
@ -3412,7 +3416,7 @@ def generate_video(
trans.previous_residual = None
trans.previous_modulated_input = None
if trans.enable_teacache:
if trans.enable_cache:
print(f"Teacache Skipped Steps:{trans.teacache_skipped_steps}/{trans.num_steps}" )
if samples != None: