Mirror of https://github.com/Wan-Video/Wan2.1.git (synced 2025-11-04 14:16:57 +00:00)
fixed save quantization
This commit is contained in: parent 5a63326bb9, commit 73cf4e43c3
@@ -66,14 +66,20 @@ If a model is not quantized, it is assumed to be mostly 16 bits (with maybe a few

If a model is quantized, the term *quanto* should also be included, since for the moment WanGP supports only *quanto* quantized models; more specifically, you should replace *fp16* with *quanto_fp16_int8* or *bf16* with *quanto_bf16_int8*.

Please note that it is important that *bf16*, *fp16* and *quanto* are all in lower-case letters.

## Creating a Quanto Quantized file

If you launch the app with the *--save-quantized* switch, WanGP will create a quantized file in the **ckpts** subfolder just after the model has been loaded. Please note that the model will be *bf16* or *fp16* quantized depending on what you chose in the configuration menu.

1) Make sure that in the finetune definition json file there is only a URL or filepath that points to the non-quantized model
2) Launch WanGP with *python wgp.py --save-quantized*
3) In the configuration menu, set the *Transformer Data Type* property to either *BF16* or *FP16*
-4) Launch a generation (settings used do not matter). As soon as the model is loaded, a new quantized model will be created in the **ckpts** subfolder it doesn't already exist.
+4) Launch a video generation (the settings used do not matter). As soon as the model is loaded, a new quantized model will be created in the **ckpts** subfolder if it doesn't already exist.
5) To test that this works properly, set the local path in the "URLs" key of the finetune definition file. For instance *URLs = ["ckpts/finetune_quanto_fp16_int8.safetensors"]* (a minimal definition file is sketched after this section)
-6) Restart WanGP and select *Scaled Int8 Quantization* in the *Transformer Model Quantization* property
-7) Launch a new generation an verify in the terminal window that the right quantized model is loaded
-8) In order to share the finetune definition file will need to store the fine model weights in the cloud. You can upload them for instance on *Huggingface*. You can now replace in the definition file the local path by a URL (on Huggingface to get the URL of the model file click *Copy download link* when accessing the model properties)
+6) Remove *--save-quantized*, restart WanGP and select *Scaled Int8 Quantization* in the *Transformer Model Quantization* property
+7) Launch a new generation and verify in the terminal window that the right quantized model is loaded
+8) In order to share the finetune definition file, you will need to store the finetuned model weights in the cloud. You can upload them, for instance, to *Huggingface*. You can then replace the local path in the definition file with a URL (on Huggingface, to get the URL of the model file, click *Copy download link* when accessing the model properties)

You need to create a quantized model specifically for *bf16* or *fp16*, as quantized models cannot be converted from one data type to the other on the fly. However, there is no need to do this for a non-quantized model, as it can be converted on the fly while being loaded.

Wan models support both *fp16* and *bf16* data types, although *fp16* in theory delivers better quality. By contrast, Hunyuan and LTXV support only *bf16*.
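To make steps 1 and 5 above more concrete, here is a minimal sketch, in Python, of how the "URLs" key of a finetune definition file might be updated across the procedure. The definition file path and the weight filenames are illustrative assumptions, not files shipped with WanGP, and a real definition file may contain additional keys.

```python
import json

definition_path = "myfinetune.json"  # assumed name; place it wherever WanGP expects finetune definitions

# Step 1: the definition points only to the non-quantized weights (local path or URL).
definition = {"URLs": ["ckpts/myfinetune_fp16.safetensors"]}
with open(definition_path, "w") as f:
    json.dump(definition, f, indent=2)

# Steps 2-4: run `python wgp.py --save-quantized`, pick BF16 or FP16 in the configuration
# menu and launch one generation; a file such as ckpts/myfinetune_quanto_fp16_int8.safetensors
# should then appear in the ckpts subfolder.

# Step 5: point the definition at the generated quantized file to test it locally.
definition["URLs"] = ["ckpts/myfinetune_quanto_fp16_int8.safetensors"]
with open(definition_path, "w") as f:
    json.dump(definition, f, indent=2)

# Step 8: after uploading the weights (for instance to Huggingface), replace the local
# path with the *Copy download link* URL so the definition file can be shared.
```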
@@ -949,11 +949,11 @@ class HunyuanVideoPipeline(DiffusionPipeline):
# width = width or self.transformer.config.sample_size * self.vae_scale_factor
# to deal with lora scaling and other possible forward hooks
trans = self.transformer
-if trans.enable_teacache:
+if trans.enable_cache:
teacache_multiplier = trans.teacache_multiplier
trans.accumulated_rel_l1_distance = 0
trans.rel_l1_thresh = 0.1 if teacache_multiplier < 2 else 0.15
-# trans.teacache_start_step = int(tea_cache_start_step_perc*num_inference_steps/100)
+# trans.cache_start_step = int(tea_cache_start_step_perc*num_inference_steps/100)
# 1. Check inputs. Raise error if not correct
self.check_inputs(
prompt,

@@ -1208,7 +1208,7 @@ class HunyuanVideoPipeline(DiffusionPipeline):
if ip_cfg_scale>0:
latent_items += 1

-if self.transformer.enable_teacache:
+if self.transformer.enable_cache:
self.transformer.previous_residual = [None] * latent_items

# if is_progress_bar:
@@ -934,7 +934,7 @@ class HunyuanVideoAudioPipeline(DiffusionPipeline):
transformer = self.transformer

-if transformer.enable_teacache:
+if transformer.enable_cache:
teacache_multiplier = transformer.teacache_multiplier
transformer.accumulated_rel_l1_distance = 0
transformer.rel_l1_thresh = 0.1 if teacache_multiplier < 2 else 0.15

@@ -1136,7 +1136,7 @@ class HunyuanVideoAudioPipeline(DiffusionPipeline):
if self._interrupt:
return [None]

-if transformer.enable_teacache:
+if transformer.enable_cache:
cache_size = round( infer_length / frames_per_batch )
transformer.previous_residual = [None] * latent_items
cache_all_previous_residual = [None] * latent_items

@@ -1180,7 +1180,7 @@ class HunyuanVideoAudioPipeline(DiffusionPipeline):
img_ref_len = (latent_model_input.shape[-1] // 2) * (latent_model_input.shape[-2] // 2) * ( 1)
img_all_len = (latents_all.shape[-1] // 2) * (latents_all.shape[-2] // 2) * latents_all.shape[-3]

-if transformer.enable_teacache and cache_size > 1:
+if transformer.enable_cache and cache_size > 1:
for l in range(latent_items):
if cache_all_previous_residual[l] != None:
bsz = cache_all_previous_residual[l].shape[0]

@@ -1297,7 +1297,7 @@ class HunyuanVideoAudioPipeline(DiffusionPipeline):
pred_latents[:, :, p] += latents[:, :, iii]
counter[:, :, p] += 1

-if transformer.enable_teacache and cache_size > 1:
+if transformer.enable_cache and cache_size > 1:
for l in range(latent_items):
if transformer.previous_residual[l] != None:
bsz = transformer.previous_residual[l].shape[0]
@@ -922,7 +922,7 @@ class HYVideoDiffusionTransformer(ModelMixin, ConfigMixin):
freqs_cis = (freqs_cos, freqs_sin) if freqs_cos is not None else None

-if self.enable_teacache:
+if self.enable_cache:
if x_id == 0:
self.should_calc = True
inp = img[0:1]

@@ -932,7 +932,7 @@ class HYVideoDiffusionTransformer(ModelMixin, ConfigMixin):
normed_inp = normed_inp.to(torch.bfloat16)
modulated_inp = modulate( normed_inp, shift=img_mod1_shift, scale=img_mod1_scale )
del normed_inp, img_mod1_shift, img_mod1_scale
-if step_no <= self.teacache_start_step or step_no == self.num_steps-1:
+if step_no <= self.cache_start_step or step_no == self.num_steps-1:
self.accumulated_rel_l1_distance = 0
else:
coefficients = [7.33226126e+02, -4.01131952e+02, 6.75869174e+01, -3.14987800e+00, 9.61237896e-02]

@@ -950,7 +950,7 @@ class HYVideoDiffusionTransformer(ModelMixin, ConfigMixin):
if not self.should_calc:
img += self.previous_residual[x_id]
else:
-if self.enable_teacache:
+if self.enable_cache:
self.previous_residual[x_id] = None
ori_img = img[0:1].clone()
# --------------------- Pass through DiT blocks ------------------------

@@ -1014,7 +1014,7 @@ class HYVideoDiffusionTransformer(ModelMixin, ConfigMixin):
single_block_args = None

# img = x[:, :img_seq_len, ...]
-if self.enable_teacache:
+if self.enable_cache:
if len(img) > 1:
self.previous_residual[0] = torch.empty_like(img)
for i, (x, residual) in enumerate(zip(img, self.previous_residual[0])):
@@ -551,8 +551,8 @@ def main():
# Setup tea cache if needed
trans = wan_model.model
-trans.enable_teacache = (args.teacache > 0)
-if trans.enable_teacache:
+trans.enable_cache = (args.teacache > 0)
+if trans.enable_cache:
if "480p" in args.transformer_file:
# example from your code
trans.coefficients = [-3.02331670e+02, 2.23948934e+02, -5.25463970e+01, 5.87348440e+00, -2.01973289e-01]

@@ -582,10 +582,10 @@ def main():
enable_riflex = args.riflex

# If teacache => reset counters
-if trans.enable_teacache:
+if trans.enable_cache:
trans.teacache_counter = 0
trans.teacache_multiplier = args.teacache
-trans.teacache_start_step = int(args.teacache_start * args.steps / 100.0)
+trans.cache_start_step = int(args.teacache_start * args.steps / 100.0)
trans.num_steps = args.steps
trans.teacache_skipped_steps = 0
trans.previous_residual_uncond = None

@@ -655,7 +655,7 @@ def main():
raise RuntimeError("No frames were returned (maybe generation was aborted or failed).")

# If teacache was used, we can see how many steps were skipped
-if trans.enable_teacache:
+if trans.enable_cache:
print(f"TeaCache skipped steps: {trans.teacache_skipped_steps} / {args.steps}")

# Save result
@@ -78,7 +78,7 @@ class DTT2V:
self.model.eval().requires_grad_(False)
if save_quantized:
from wan.utils.utils import save_quantized_model
-save_quantized_model(self.model, model_filename[-1], dtype, base_config_file)
+save_quantized_model(self.model, model_filename[0], dtype, base_config_file)

self.scheduler = FlowUniPCMultistepScheduler()

@@ -316,7 +316,7 @@ class DTT2V:
updated_num_steps= len(step_matrix)
if callback != None:
callback(-1, None, True, override_num_inference_steps = updated_num_steps)
-if self.model.enable_teacache:
+if self.model.enable_cache:
x_count = 2 if self.do_classifier_free_guidance else 1
self.model.previous_residual = [None] * x_count
time_steps_comb = []

@@ -327,7 +327,7 @@ class DTT2V:
if overlap_noise > 0 and valid_interval_start < predix_video_latent_length:
timestep[:, valid_interval_start:predix_video_latent_length] = overlap_noise
time_steps_comb.append(timestep)
-self.model.compute_teacache_threshold(self.model.teacache_start_step, time_steps_comb, self.model.teacache_multiplier)
+self.model.compute_teacache_threshold(self.model.cache_start_step, time_steps_comb, self.model.teacache_multiplier)
del time_steps_comb
from mmgp import offload
freqs = get_rotary_pos_embed(latents.shape[1 :], enable_RIFLEx= False)
@@ -116,7 +116,7 @@ class WanI2V:
self.model.eval().requires_grad_(False)
if save_quantized:
from wan.utils.utils import save_quantized_model
-save_quantized_model(self.model, model_filename[-1], dtype, base_config_file)
+save_quantized_model(self.model, model_filename[0], dtype, base_config_file)

self.sample_neg_prompt = config.sample_neg_prompt

@@ -317,9 +317,9 @@ class WanI2V:
"audio_context_lens": audio_context_lens,
})

-if self.model.enable_teacache:
+if self.model.enable_cache:
self.model.previous_residual = [None] * (3 if audio_cfg_scale !=None else 2)
-self.model.compute_teacache_threshold(self.model.teacache_start_step, timesteps, self.model.teacache_multiplier)
+self.model.compute_teacache_threshold(self.model.cache_start_step, timesteps, self.model.teacache_multiplier)

# self.model.to(self.device)
if callback != None:
@@ -194,6 +194,11 @@ def pay_attention(
q = q.to(v.dtype)
k = k.to(v.dtype)

if attn == "chipmunk":
from src.chipmunk.modules import SparseDiffMlp, SparseDiffAttn
from src.chipmunk.util import LayerCounter, GLOBAL_CONFIG

if b > 1 and k_lens != None and attn in ("sage2", "sdpa"):
assert attention_mask == None
# Poor's man var k len attention

@@ -11,6 +11,8 @@ from typing import Union,Optional
from mmgp import offload
from .attention import pay_attention
from torch.backends.cuda import sdp_kernel
# from src.chipmunk.modules import SparseDiffMlp, SparseDiffAttn
# from src.chipmunk.util import LayerCounter, GLOBAL_CONFIG

__all__ = ['WanModel']
@@ -172,6 +174,11 @@ class WanSelfAttention(nn.Module):
self.norm_q = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity()
self.norm_k = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity()

# Only initialize SparseDiffAttn if this is not a subclass initialization
# if self.__class__ == WanSelfAttention:
# layer_num, layer_counter = LayerCounter.build_for_layer(is_attn_sparse=True, is_mlp_sparse=False)
# self.attn = SparseDiffAttn(layer_num, layer_counter)

def forward(self, xlist, grid_sizes, freqs, block_mask = None):
r"""
Args:

@@ -197,7 +204,10 @@ class WanSelfAttention(nn.Module):
del q,k

q,k = apply_rotary_emb(qklist, freqs, head_first=False)
-if block_mask == None:
+chipmunk = offload.shared_state["_chipmunk"]
+if chipmunk:
+x = self.attn(q, k, v)
+elif block_mask == None:
qkv_list = [q,k,v]
del q,k,v
x = pay_attention(
@@ -954,6 +964,16 @@ class WanModel(ModelMixin, ConfigMixin):
x_list[i] = x
x, y = None, None

offload.shared_state["_chipmunk"] = False
chipmunk = offload.shared_state["_chipmunk"]
if chipmunk:
voxel_shape = (4, 6, 8)
for x in x_list:
from src.chipmunk.ops.voxel import voxel_chunk_no_padding, reverse_voxel_chunk_no_padding
x = x.unsqueeze(-1)
x_og_shape = x.shape
x = voxel_chunk_no_padding(x, voxel_shape).squeeze(-1).transpose(1, 2)
x = None

block_mask = None
if causal_attention and causal_block_size > 0 and False: # NEVER WORKED

@@ -1027,11 +1047,11 @@ class WanModel(ModelMixin, ConfigMixin):
del c

should_calc = True
-if self.enable_teacache:
+if self.enable_cache:
if x_id != 0:
should_calc = self.should_calc
else:
-if current_step <= self.teacache_start_step or current_step == self.num_steps-1:
+if current_step <= self.cache_start_step or current_step == self.num_steps-1:
should_calc = True
self.accumulated_rel_l1_distance = 0
else:

@@ -1057,7 +1077,7 @@ class WanModel(ModelMixin, ConfigMixin):
x += self.previous_residual[x_id]
x = None
else:
-if self.enable_teacache:
+if self.enable_cache:
if joint_pass:
self.previous_residual = [ None ] * len(self.previous_residual)
else:

@@ -1084,7 +1104,7 @@ class WanModel(ModelMixin, ConfigMixin):
del x
del context, hints

-if self.enable_teacache:
+if self.enable_cache:
if joint_pass:
for i, (x, ori, is_source) in enumerate(zip(x_list, ori_hidden_states, is_source_x)) :
if i == 0 or is_source and i != last_x_idx :

@@ -1101,6 +1121,10 @@ class WanModel(ModelMixin, ConfigMixin):
residual, ori_hidden_states = None, None

for i, x in enumerate(x_list):
if chipmunk:
x = reverse_voxel_chunk_no_padding(x.transpose(1, 2).unsqueeze(-1), x_og_shape, voxel_shape).squeeze(-1)
x = x.flatten(2).transpose(1, 2)

# head
x = self.head(x, e)
@@ -101,7 +101,7 @@ class WanT2V:
self.model.eval().requires_grad_(False)
if save_quantized:
from wan.utils.utils import save_quantized_model
-save_quantized_model(self.model, model_filename[-1], dtype, base_config_file)
+save_quantized_model(self.model, model_filename[1 if base_model_type=="fantasy" else 0], dtype, base_config_file)

self.sample_neg_prompt = config.sample_neg_prompt

@@ -458,13 +458,24 @@ class WanT2V:
z_reactive = [ zz[0:16, 0:overlapped_latents_size + ref_images_count].clone() for zz in z]

-if self.model.enable_teacache:
+if self.model.enable_cache:
x_count = 3 if phantom else 2
self.model.previous_residual = [None] * x_count
-self.model.compute_teacache_threshold(self.model.teacache_start_step, timesteps, self.model.teacache_multiplier)
+self.model.compute_teacache_threshold(self.model.cache_start_step, timesteps, self.model.teacache_multiplier)
if callback != None:
callback(-1, None, True)
prev = 50/1000

# seq_shape = (21, 45, 80)
# local_heads_num = 40 #12 for 1.3B

# self.model.blocks[0].self_attn.attn.initialize_static_mask(
# seq_shape=seq_shape,
# txt_len=0,
# local_heads_num=local_heads_num,
# device='cuda'
# )
# self.model.blocks[0].self_attn.attn.layer_counter.reset()

for i, t in enumerate(tqdm(timesteps)):

timestep = [t]
@@ -338,17 +338,19 @@ def create_progress_hook(filename):
return hook

def save_quantized_model(model, model_filename, dtype, config_file):
if "quanto" in model_filename:
return
from mmgp import offload
if dtype == torch.bfloat16:
-model_filename = model_filename.replace("fp16", "bf16")
+model_filename = model_filename.replace("fp16", "bf16").replace("FP16", "bf16")
elif dtype == torch.float16:
-model_filename = model_filename.replace("bf16", "fp16")
+model_filename = model_filename.replace("bf16", "fp16").replace("BF16", "bf16")

if "_fp16" in model_filename:
model_filename = model_filename.replace("_fp16", "_quanto_fp16_int8")
elif "_bf16" in model_filename:
model_filename = model_filename.replace("_bf16", "_quanto_bf16_int8")
else:
for rep in ["mfp16", "fp16", "mbf16", "bf16"]:
if "_" + rep in model_filename:
model_filename = model_filename.replace("_" + rep, "_quanto_" + rep + "_int8")
break
if not "quanto" in model_filename:
pos = model_filename.rfind(".")
model_filename = model_filename[:pos] + "_quanto_int8" + model_filename[pos+1:]
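To make the renaming rule above easier to follow, here is a minimal sketch that re-implements it as a standalone pure function, together with the expected results for the filenames used in the documentation section. The helper name `quantized_name` is hypothetical (it is not part of WanGP), the branches are folded into a single loop, the float16 branch maps the upper-case `BF16` tag to `fp16`, and the fallback branch keeps the dot of the file extension, which the snippet above appears to drop.

```python
import torch

def quantized_name(model_filename: str, dtype: torch.dtype) -> str:
    """Illustrative re-implementation of the renaming rule above (not the actual helper)."""
    if "quanto" in model_filename:  # already a quantized file: nothing to do
        return model_filename
    if dtype == torch.bfloat16:
        model_filename = model_filename.replace("fp16", "bf16").replace("FP16", "bf16")
    elif dtype == torch.float16:
        model_filename = model_filename.replace("bf16", "fp16").replace("BF16", "fp16")
    # insert the quanto/int8 tags next to the data-type tag when one is present
    for rep in ["mfp16", "fp16", "mbf16", "bf16"]:
        if "_" + rep in model_filename:
            return model_filename.replace("_" + rep, "_quanto_" + rep + "_int8")
    # no recognizable tag: append a generic quanto_int8 suffix before the extension
    pos = model_filename.rfind(".")
    return model_filename[:pos] + "_quanto_int8" + model_filename[pos:]

# expected behaviour, following the naming convention described in the documentation
assert quantized_name("ckpts/finetune_fp16.safetensors", torch.float16) == \
       "ckpts/finetune_quanto_fp16_int8.safetensors"
assert quantized_name("ckpts/finetune_fp16.safetensors", torch.bfloat16) == \
       "ckpts/finetune_quanto_bf16_int8.safetensors"
```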
wgp.py (48 changed lines)
@@ -1624,7 +1624,7 @@ def get_model_family(model_type):
def test_class_i2v(model_type):
model_type = get_base_model_type(model_type)
-return model_type in ["i2v", "fun_inp_1.3B", "fun_inp", "flf2v_720p", "fantasy", "hunyuan_i2v" ]
+return model_type in ["i2v", "i2v_720p", "fun_inp_1.3B", "fun_inp", "flf2v_720p", "fantasy", "hunyuan_i2v" ]

def get_model_name(model_type, description_container = [""]):
finetune_def = get_model_finetune_def(model_type)
@@ -1731,19 +1731,19 @@ def get_model_filename(model_type, quantization ="int8", dtype_policy = ""):
raw_filename = choices[0]
else:
if quantization in ("int8", "fp8"):
-sub_choices = [ name for name in choices if quantization in name]
+sub_choices = [ name for name in choices if quantization in name or quantization.upper() in name]
else:
sub_choices = [ name for name in choices if "quanto" not in name]

if len(sub_choices) > 0:
dtype_str = "fp16" if dtype == torch.float16 else "bf16"
-new_sub_choices = [ name for name in sub_choices if dtype_str in name]
+new_sub_choices = [ name for name in sub_choices if dtype_str in name or dtype_str.upper() in name]
sub_choices = new_sub_choices if len(new_sub_choices) > 0 else sub_choices
raw_filename = sub_choices[0]
else:
raw_filename = choices[0]

-if dtype == torch.float16 and not "fp16" in raw_filename and model_family == "wan" and finetune_def == None :
+if dtype == torch.float16 and not any(x in raw_filename for x in ("fp16", "FP16")) and model_family == "wan" and finetune_def == None :
if "quanto_int8" in raw_filename:
raw_filename = raw_filename.replace("quanto_int8", "quanto_fp16_int8")
elif "quanto_bf16_int8" in raw_filename:

@@ -1753,6 +1753,8 @@ def get_model_filename(model_type, quantization ="int8", dtype_policy = ""):
return raw_filename

def get_transformer_dtype(model_family, transformer_dtype_policy):
if not isinstance(transformer_dtype_policy, str):
return transformer_dtype_policy
if len(transformer_dtype_policy) == 0:
if not bfloat16_supported:
return torch.float16
@@ -2290,19 +2292,25 @@ def get_transformer_model(model):
def load_models(model_type):
global transformer_type, transformer_loras_filenames
-model_filename = get_model_filename(model_type=model_type, quantization= transformer_quantization, dtype_policy = transformer_dtype_policy)
base_model_type = get_base_model_type(model_type)
finetune_def = get_model_finetune_def(model_type)
-quantizeTransformer = finetune_def !=None and transformer_quantization in ("int8", "fp8") and finetune_def.get("auto_quantize", False) and not "quanto" in model_filename

-model_family = get_model_family(model_type)
-perc_reserved_mem_max = args.perc_reserved_mem_max
preload =int(args.preload)
-save_quantized = args.save_quantized
+save_quantized = args.save_quantized and finetune_def != None
+model_filename = get_model_filename(model_type=model_type, quantization= "" if save_quantized else transformer_quantization, dtype_policy = transformer_dtype_policy)
+if save_quantized and "quanto" in model_filename:
+save_quantized = False
+print("Need to provide a non quantized model to create a quantized model to be saved")
+quantizeTransformer = not save_quantized and finetune_def !=None and transformer_quantization in ("int8", "fp8") and finetune_def.get("auto_quantize", False) and not "quanto" in model_filename
+model_family = get_model_family(model_type)
+transformer_dtype = get_transformer_dtype(model_family, transformer_dtype_policy)
+if quantizeTransformer or "quanto" in model_filename:
+transformer_dtype = torch.bfloat16 if "bf16" in model_filename or "BF16" in model_filename else transformer_dtype
+transformer_dtype = torch.float16 if "fp16" in model_filename or"FP16" in model_filename else transformer_dtype
+perc_reserved_mem_max = args.perc_reserved_mem_max
if preload == 0:
preload = server_config.get("preload_in_VRAM", 0)
new_transformer_loras_filenames = None
-dependent_models, dependent_models_types = get_dependent_models(model_type, quantization= transformer_quantization, dtype_policy = transformer_dtype_policy)
+dependent_models, dependent_models_types = get_dependent_models(model_type, quantization= transformer_quantization, dtype_policy = transformer_dtype)
new_transformer_loras_filenames = [model_filename] if "_lora" in model_filename else None

model_file_list = dependent_models + [model_filename]

@@ -2310,15 +2318,11 @@ def load_models(model_type):
new_transformer_filename = model_file_list[-1]
if finetune_def != None:
for module_type in finetune_def.get("modules", []):
-model_file_list.append(get_model_filename(module_type, transformer_quantization, transformer_dtype_policy))
+model_file_list.append(get_model_filename(module_type, transformer_quantization, transformer_dtype))
model_type_list.append(module_type)

for filename, file_model_type in zip(model_file_list, model_type_list):
download_models(filename, file_model_type)
-transformer_dtype = get_transformer_dtype(model_family, transformer_dtype_policy)
-if quantizeTransformer:
-transformer_dtype = torch.bfloat16 if "bf16" in model_filename else transformer_dtype
-transformer_dtype = torch.float16 if "fp16" in model_filename else transformer_dtype
VAE_dtype = torch.float16 if server_config.get("vae_precision","16") == "16" else torch.float
mixed_precision_transformer = server_config.get("mixed_precision","0") == "1"
transformer_filename = None

@@ -2364,7 +2368,7 @@ def load_models(model_type):
prompt_enhancer_llm_tokenizer = None

-offloadobj = offload.profile(pipe, profile_no= profile, compile = compile, quantizeTransformer = quantizeTransformer, loras = "transformer", coTenantsMap= {}, perc_reserved_mem_max = perc_reserved_mem_max , convertWeightsFloatTo = transformer_dtype, **kwargs)
+offloadobj = offload.profile(pipe, profile_no= profile, compile = compile, quantizeTransformer = False, loras = "transformer", coTenantsMap= {}, perc_reserved_mem_max = perc_reserved_mem_max , convertWeightsFloatTo = transformer_dtype, **kwargs)
if len(args.gpu) > 0:
torch.set_default_device(args.gpu)
transformer_loras_filenames = new_transformer_loras_filenames
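The load_models change above is the heart of this "fixed save quantization" commit: *--save-quantized* is only honored when a finetune definition exists, a non-quantized source file is requested in that case, and on-the-fly quantization is disabled so the saved weights come from the original precision. The following is a minimal sketch of that decision flow; the helper name and signature are hypothetical, and it omits the int8/fp8 quantization-type check and the actual filename lookup.

```python
def resolve_save_quantized(requested_save: bool, finetune_def, model_filename: str):
    # Hypothetical condensation of the load_models logic above, for illustration only.
    # Saving a quantized copy only makes sense for finetunes that have a definition file.
    save_quantized = requested_save and finetune_def is not None
    # A file that is already quanto-quantized cannot be the source of a new quantized file.
    if save_quantized and "quanto" in model_filename:
        save_quantized = False
    # Skip on-the-fly quantization when a quantized copy is about to be saved,
    # so the loaded weights stay in their original 16-bit precision.
    quantize_on_load = (not save_quantized
                        and finetune_def is not None
                        and finetune_def.get("auto_quantize", False)
                        and "quanto" not in model_filename)
    return save_quantized, quantize_on_load
```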
@@ -3092,11 +3096,11 @@ def generate_video(
# TeaCache
if args.teacache > 0:
tea_cache_setting = args.teacache
-trans.enable_teacache = tea_cache_setting > 0
-if trans.enable_teacache:
+trans.enable_cache = tea_cache_setting > 0
+if trans.enable_cache:
trans.teacache_multiplier = tea_cache_setting
trans.rel_l1_thresh = 0
-trans.teacache_start_step = int(tea_cache_start_step_perc*num_inference_steps/100)
+trans.cache_start_step = int(tea_cache_start_step_perc*num_inference_steps/100)
if get_model_family(model_type) == "wan":
if image2video:
if '720p' in model_filename:

@@ -3323,7 +3327,7 @@ def generate_video(
progress_args = [0, merge_status_context(status, "Encoding Prompt")]
send_cmd("progress", progress_args)

-if trans.enable_teacache:
+if trans.enable_cache:
trans.teacache_counter = 0
trans.num_steps = num_inference_steps
trans.teacache_skipped_steps = 0

@@ -3412,7 +3416,7 @@ def generate_video(
trans.previous_residual = None
trans.previous_modulated_input = None

-if trans.enable_teacache:
+if trans.enable_cache:
print(f"Teacache Skipped Steps:{trans.teacache_skipped_steps}/{trans.num_steps}" )

if samples != None: