mirror of https://github.com/Wan-Video/Wan2.1.git (synced 2025-11-04 06:15:17 +00:00)

fixed save quantization

parent 5a63326bb9
commit 73cf4e43c3
@@ -66,14 +66,20 @@ If a model is not quantized, it is assumed to be mostly 16 bits (with maybe a fe

If a model is quantized, the term *quanto* should also be included, since WanGP currently supports only *quanto* quantized models. More specifically, you should replace *fp16* with *quanto_fp16_int8* or *bf16* with *quanto_bf16_int8*.

Please note it is important that *bf16*, *fp16* and *quanto* are all in lower case letters.
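As a rough illustration of this naming rule, the sketch below maps a 16-bit checkpoint name to its expected *quanto* counterpart. This is an editor's illustration only: the helper and the example filename are made up, not code shipped with WanGP.

```python
# Illustrative sketch of the naming convention described above;
# quanto_name() and the example filename are hypothetical, not WanGP code.
def quanto_name(filename: str) -> str:
    """Map a 16-bit checkpoint name to the expected quanto int8 name."""
    if "fp16" in filename:
        return filename.replace("fp16", "quanto_fp16_int8")
    if "bf16" in filename:
        return filename.replace("bf16", "quanto_bf16_int8")
    raise ValueError("expected 'fp16' or 'bf16' (lower case) in the filename")

print(quanto_name("my_finetune_fp16.safetensors"))
# -> my_finetune_quanto_fp16_int8.safetensors
```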
## Creating a Quanto Quantized file

If you launch the app with the *--save-quantized* switch, WanGP will create a quantized file in the **ckpts** subfolder just after the model has been loaded. Please note that the model will be *bf16* or *fp16* quantized depending on what you chose in the configuration menu.

1) Make sure that the finetune definition json file contains only a URL or filepath that points to the non quantized model
2) Launch WanGP with *python wgp.py --save-quantized*
3) In the configuration menu, set the *Transformer Data Type* property to either *BF16* or *FP16*
-4) Launch a generation (settings used do not matter). As soon as the model is loaded, a new quantized model  will be created in the **ckpts** subfolder it doesn't already exist.
+4) Launch a video generation (settings used do not matter). As soon as the model is loaded, a new quantized model will be created in the **ckpts** subfolder if it doesn't already exist.
5) To test that this works properly, set the local path in the "URLs" key of the finetune definition file, for instance *URLs = ["ckpts/finetune_quanto_fp16_int8.safetensors"]* (see the example after this list)
-6) Restart WanGP and select *Scaled Int8 Quantization* in the *Transformer Model Quantization* property
-7) Launch a new generation an verify in the terminal window that the right quantized model is loaded
-8) In order to share the finetune definition file will need to store the fine model weights in the cloud. You can upload them for instance on *Huggingface*. You can now replace in the definition file the local path by a URL (on Huggingface to get the URL of the model file click *Copy download link* when accessing the model properties)
+6) Remove *--save-quantized*, restart WanGP and select *Scaled Int8 Quantization* in the *Transformer Model Quantization* property
+7) Launch a new generation and verify in the terminal window that the right quantized model is loaded
+8) In order to share the finetune definition file you will need to store the finetuned model weights in the cloud. You can upload them, for instance, to *Huggingface*. You can then replace the local path in the definition file with a URL (on Huggingface, click *Copy download link* in the model file's properties to get the URL of the model file)
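For example, the *URLs* entry of the finetune definition would evolve roughly as shown below between step 1 and step 5. This is an editor's illustration with hypothetical paths, written as Python dictionaries mirroring the JSON definition file.

```python
# Hypothetical paths, mirroring the JSON finetune definition described above.

# Step 1: the definition points only at the non quantized checkpoint.
finetune_def_step1 = {
    "URLs": ["https://huggingface.co/your-account/your-repo/resolve/main/my_finetune_fp16.safetensors"]
}

# Step 5: after running with --save-quantized, point at the local quantized file in ckpts.
finetune_def_step5 = {
    "URLs": ["ckpts/my_finetune_quanto_fp16_int8.safetensors"]
}
```

In step 8 you would then replace the local *ckpts* path with the Huggingface download URL of the uploaded quantized file.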
You need to create a quantized model specifically for *bf16* or *fp16*, as quantized models cannot be converted between these data types on the fly. This is not needed for a non quantized model, which can be converted on the fly while being loaded.

Wan models support both the *fp16* and *bf16* data types, although *fp16* in theory delivers better quality. By contrast, Hunyuan and LTXV support only *bf16*.
@@ -949,11 +949,11 @@ class HunyuanVideoPipeline(DiffusionPipeline):
        # width = width or self.transformer.config.sample_size * self.vae_scale_factor
        # to deal with lora scaling and other possible forward hooks
        trans = self.transformer
-        if trans.enable_teacache:
+        if trans.enable_cache:
            teacache_multiplier = trans.teacache_multiplier
            trans.accumulated_rel_l1_distance = 0
            trans.rel_l1_thresh = 0.1 if teacache_multiplier < 2 else 0.15
-            # trans.teacache_start_step =  int(tea_cache_start_step_perc*num_inference_steps/100)
+            # trans.cache_start_step =  int(tea_cache_start_step_perc*num_inference_steps/100)
        # 1. Check inputs. Raise error if not correct
        self.check_inputs(
            prompt,
@@ -1208,7 +1208,7 @@ class HunyuanVideoPipeline(DiffusionPipeline):
        if ip_cfg_scale>0:
            latent_items += 1

-        if self.transformer.enable_teacache:
+        if self.transformer.enable_cache:
            self.transformer.previous_residual = [None] * latent_items

        # if is_progress_bar:
@@ -934,7 +934,7 @@ class HunyuanVideoAudioPipeline(DiffusionPipeline):

        transformer = self.transformer

-        if transformer.enable_teacache:
+        if transformer.enable_cache:
            teacache_multiplier = transformer.teacache_multiplier
            transformer.accumulated_rel_l1_distance = 0
            transformer.rel_l1_thresh = 0.1 if teacache_multiplier < 2 else 0.15
@@ -1136,7 +1136,7 @@ class HunyuanVideoAudioPipeline(DiffusionPipeline):
        if self._interrupt:
            return [None]

-        if transformer.enable_teacache:
+        if transformer.enable_cache:
            cache_size = round( infer_length / frames_per_batch )
            transformer.previous_residual = [None] * latent_items
            cache_all_previous_residual =  [None] * latent_items
@@ -1180,7 +1180,7 @@ class HunyuanVideoAudioPipeline(DiffusionPipeline):
                    img_ref_len = (latent_model_input.shape[-1] // 2)  * (latent_model_input.shape[-2] // 2) * ( 1) 
                    img_all_len = (latents_all.shape[-1] // 2)  * (latents_all.shape[-2] // 2) * latents_all.shape[-3]

-                    if transformer.enable_teacache and cache_size > 1:
+                    if transformer.enable_cache and cache_size > 1:
                        for l in range(latent_items):
                            if cache_all_previous_residual[l] != None:
                                bsz = cache_all_previous_residual[l].shape[0]
@@ -1297,7 +1297,7 @@ class HunyuanVideoAudioPipeline(DiffusionPipeline):
                        pred_latents[:, :, p] += latents[:, :, iii]
                        counter[:, :, p] += 1

-                    if transformer.enable_teacache and cache_size > 1:
+                    if transformer.enable_cache and cache_size > 1:
                        for l in range(latent_items):
                            if transformer.previous_residual[l] != None:
                                bsz = transformer.previous_residual[l].shape[0]
@@ -922,7 +922,7 @@ class HYVideoDiffusionTransformer(ModelMixin, ConfigMixin):
        freqs_cis = (freqs_cos, freqs_sin) if freqs_cos is not None else None

-        if self.enable_teacache:
+        if self.enable_cache:
            if x_id == 0:
                self.should_calc = True
                inp = img[0:1] 
@@ -932,7 +932,7 @@ class HYVideoDiffusionTransformer(ModelMixin, ConfigMixin):
                normed_inp = normed_inp.to(torch.bfloat16)
                modulated_inp = modulate( normed_inp, shift=img_mod1_shift, scale=img_mod1_scale )
                del normed_inp, img_mod1_shift, img_mod1_scale
-                if step_no <= self.teacache_start_step or step_no == self.num_steps-1:
+                if step_no <= self.cache_start_step or step_no == self.num_steps-1:
                    self.accumulated_rel_l1_distance = 0
                else: 
                    coefficients = [7.33226126e+02, -4.01131952e+02,  6.75869174e+01, -3.14987800e+00, 9.61237896e-02]
@@ -950,7 +950,7 @@ class HYVideoDiffusionTransformer(ModelMixin, ConfigMixin):
        if not self.should_calc:
            img += self.previous_residual[x_id]
        else:
-            if self.enable_teacache:            
+            if self.enable_cache:            
                self.previous_residual[x_id] = None
                ori_img = img[0:1].clone()
            # --------------------- Pass through DiT blocks ------------------------
@@ -1014,7 +1014,7 @@ class HYVideoDiffusionTransformer(ModelMixin, ConfigMixin):
                    single_block_args = None

            # img = x[:, :img_seq_len, ...]
-            if self.enable_teacache:
+            if self.enable_cache:
                if len(img) > 1:
                    self.previous_residual[0] = torch.empty_like(img)
                    for i, (x, residual) in enumerate(zip(img, self.previous_residual[0])):
@@ -551,8 +551,8 @@ def main():

    # Setup tea cache if needed
    trans = wan_model.model
-    trans.enable_teacache = (args.teacache > 0)
-    if trans.enable_teacache:
+    trans.enable_cache = (args.teacache > 0)
+    if trans.enable_cache:
        if "480p" in args.transformer_file:
            # example from your code
            trans.coefficients = [-3.02331670e+02,  2.23948934e+02, -5.25463970e+01, 5.87348440e+00, -2.01973289e-01]
@@ -582,10 +582,10 @@ def main():
    enable_riflex = args.riflex

    # If teacache => reset counters
-    if trans.enable_teacache:
+    if trans.enable_cache:
        trans.teacache_counter = 0
        trans.teacache_multiplier = args.teacache
-        trans.teacache_start_step = int(args.teacache_start * args.steps / 100.0)
+        trans.cache_start_step = int(args.teacache_start * args.steps / 100.0)
        trans.num_steps = args.steps
        trans.teacache_skipped_steps = 0
        trans.previous_residual_uncond = None
@@ -655,7 +655,7 @@ def main():
        raise RuntimeError("No frames were returned (maybe generation was aborted or failed).")

    # If teacache was used, we can see how many steps were skipped
-    if trans.enable_teacache:
+    if trans.enable_cache:
        print(f"TeaCache skipped steps: {trans.teacache_skipped_steps} / {args.steps}")

    # Save result
@@ -78,7 +78,7 @@ class DTT2V:
        self.model.eval().requires_grad_(False)
        if save_quantized:            
            from wan.utils.utils import save_quantized_model
-            save_quantized_model(self.model, model_filename[-1], dtype, base_config_file)
+            save_quantized_model(self.model, model_filename[0], dtype, base_config_file)

        self.scheduler = FlowUniPCMultistepScheduler()
@@ -316,7 +316,7 @@ class DTT2V:
        updated_num_steps=  len(step_matrix)
        if callback != None:
            callback(-1, None, True, override_num_inference_steps = updated_num_steps)
-        if self.model.enable_teacache:
+        if self.model.enable_cache:
            x_count = 2 if self.do_classifier_free_guidance else 1
            self.model.previous_residual = [None] * x_count 
            time_steps_comb = []
@@ -327,7 +327,7 @@ class DTT2V:
                if overlap_noise > 0 and valid_interval_start < predix_video_latent_length:
                    timestep[:, valid_interval_start:predix_video_latent_length] = overlap_noise
                time_steps_comb.append(timestep)
-            self.model.compute_teacache_threshold(self.model.teacache_start_step, time_steps_comb, self.model.teacache_multiplier)
+            self.model.compute_teacache_threshold(self.model.cache_start_step, time_steps_comb, self.model.teacache_multiplier)
            del time_steps_comb
        from mmgp import offload
        freqs = get_rotary_pos_embed(latents.shape[1 :], enable_RIFLEx= False) 
@@ -116,7 +116,7 @@ class WanI2V:
        self.model.eval().requires_grad_(False)
        if save_quantized:            
            from wan.utils.utils import save_quantized_model
-            save_quantized_model(self.model, model_filename[-1], dtype, base_config_file)
+            save_quantized_model(self.model, model_filename[0], dtype, base_config_file)


        self.sample_neg_prompt = config.sample_neg_prompt
@@ -317,9 +317,9 @@ class WanI2V:
            "audio_context_lens": audio_context_lens,
            }) 

-        if self.model.enable_teacache:
+        if self.model.enable_cache:
            self.model.previous_residual = [None] * (3 if audio_cfg_scale !=None else 2)
-            self.model.compute_teacache_threshold(self.model.teacache_start_step, timesteps, self.model.teacache_multiplier)
+            self.model.compute_teacache_threshold(self.model.cache_start_step, timesteps, self.model.teacache_multiplier)

        # self.model.to(self.device)
        if callback != None:
@@ -194,6 +194,11 @@ def pay_attention(

    q = q.to(v.dtype)
    k = k.to(v.dtype)

+    if attn == "chipmunk":
+        from src.chipmunk.modules import SparseDiffMlp, SparseDiffAttn
+        from src.chipmunk.util import LayerCounter, GLOBAL_CONFIG
+
    if b > 1 and k_lens != None and attn in ("sage2", "sdpa"):
        assert attention_mask == None
        # Poor's man var k len attention
@@ -11,6 +11,8 @@ from typing import Union,Optional
from mmgp import offload
from .attention import pay_attention
from torch.backends.cuda import sdp_kernel
+# from src.chipmunk.modules import SparseDiffMlp, SparseDiffAttn
+# from src.chipmunk.util import LayerCounter, GLOBAL_CONFIG

__all__ = ['WanModel']
@@ -172,6 +174,11 @@ class WanSelfAttention(nn.Module):
        self.norm_q = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity()
        self.norm_k = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity()

+        # Only initialize SparseDiffAttn if this is not a subclass initialization
+        # if self.__class__ == WanSelfAttention:
+        #     layer_num, layer_counter = LayerCounter.build_for_layer(is_attn_sparse=True, is_mlp_sparse=False)
+        #     self.attn = SparseDiffAttn(layer_num, layer_counter)
+
    def forward(self, xlist, grid_sizes, freqs, block_mask = None):
        r"""
        Args:
@@ -197,7 +204,10 @@ class WanSelfAttention(nn.Module):
        del q,k

        q,k = apply_rotary_emb(qklist, freqs, head_first=False)
-        if block_mask == None:
+        chipmunk = offload.shared_state["_chipmunk"] 
+        if chipmunk:
+            x = self.attn(q, k, v)
+        elif block_mask == None:
            qkv_list = [q,k,v]
            del q,k,v
            x = pay_attention(
@@ -954,6 +964,16 @@ class WanModel(ModelMixin, ConfigMixin):
                x_list[i] = x
        x, y = None, None

+        offload.shared_state["_chipmunk"] = False
+        chipmunk = offload.shared_state["_chipmunk"] 
+        if chipmunk:
+            voxel_shape = (4, 6, 8)
+            for x in x_list:
+                from src.chipmunk.ops.voxel import voxel_chunk_no_padding, reverse_voxel_chunk_no_padding
+                x = x.unsqueeze(-1)
+                x_og_shape = x.shape
+                x = voxel_chunk_no_padding(x, voxel_shape).squeeze(-1).transpose(1, 2)
+            x = None

        block_mask = None
        if causal_attention and causal_block_size > 0 and False: # NEVER WORKED
@@ -1027,11 +1047,11 @@ class WanModel(ModelMixin, ConfigMixin):
            del c

        should_calc = True
-        if self.enable_teacache: 
+        if self.enable_cache: 
            if x_id != 0:
                should_calc = self.should_calc
            else:
-                if current_step <= self.teacache_start_step or current_step == self.num_steps-1:
+                if current_step <= self.cache_start_step or current_step == self.num_steps-1:
                    should_calc = True
                    self.accumulated_rel_l1_distance = 0
                else:
@@ -1057,7 +1077,7 @@ class WanModel(ModelMixin, ConfigMixin):
                x += self.previous_residual[x_id]
            x = None
        else:
-            if self.enable_teacache:
+            if self.enable_cache:
                if joint_pass:
                    self.previous_residual = [ None ] * len(self.previous_residual)
                else:
@@ -1084,7 +1104,7 @@ class WanModel(ModelMixin, ConfigMixin):
                        del x
                    del context, hints

-            if self.enable_teacache:
+            if self.enable_cache:
                if joint_pass:
                    for i, (x, ori, is_source) in enumerate(zip(x_list, ori_hidden_states, is_source_x)) :
                        if i == 0 or is_source and i != last_x_idx  :
@@ -1101,6 +1121,10 @@ class WanModel(ModelMixin, ConfigMixin):
                residual, ori_hidden_states = None, None

        for i, x in enumerate(x_list):
+            if chipmunk:
+                x = reverse_voxel_chunk_no_padding(x.transpose(1, 2).unsqueeze(-1), x_og_shape, voxel_shape).squeeze(-1)
+                x = x.flatten(2).transpose(1, 2)
+
            # head
            x = self.head(x, e)

@@ -101,7 +101,7 @@ class WanT2V:
        self.model.eval().requires_grad_(False)
        if save_quantized:            
            from wan.utils.utils import save_quantized_model
-            save_quantized_model(self.model, model_filename[-1], dtype, base_config_file)
+            save_quantized_model(self.model, model_filename[1 if base_model_type=="fantasy" else 0], dtype, base_config_file)

        self.sample_neg_prompt = config.sample_neg_prompt
@@ -458,13 +458,24 @@ class WanT2V:
                z_reactive = [  zz[0:16, 0:overlapped_latents_size + ref_images_count].clone() for zz in z]


-        if self.model.enable_teacache:
+        if self.model.enable_cache:
            x_count = 3 if phantom else 2
            self.model.previous_residual = [None] * x_count 
-            self.model.compute_teacache_threshold(self.model.teacache_start_step, timesteps, self.model.teacache_multiplier)
+            self.model.compute_teacache_threshold(self.model.cache_start_step, timesteps, self.model.teacache_multiplier)
        if callback != None:
            callback(-1, None, True)
+        prev = 50/1000
+
+        # seq_shape = (21, 45, 80)
+        # local_heads_num = 40 #12 for 1.3B
+
+        # self.model.blocks[0].self_attn.attn.initialize_static_mask(
+        #     seq_shape=seq_shape,
+        #     txt_len=0,
+        #     local_heads_num=local_heads_num,
+        #     device='cuda'
+        # )
+        # self.model.blocks[0].self_attn.attn.layer_counter.reset()

        for i, t in enumerate(tqdm(timesteps)):

            timestep = [t]
@@ -338,17 +338,19 @@ def create_progress_hook(filename):
    return hook

def save_quantized_model(model, model_filename, dtype,  config_file):
    if "quanto" in model_filename:
        return
    from mmgp import offload
    if dtype == torch.bfloat16:
-         model_filename =  model_filename.replace("fp16", "bf16")
+         model_filename =  model_filename.replace("fp16", "bf16").replace("FP16", "bf16")
    elif dtype == torch.float16:
-         model_filename =  model_filename.replace("bf16", "fp16")
+         model_filename =  model_filename.replace("bf16", "fp16").replace("BF16", "bf16")

-    if "_fp16" in model_filename:
-        model_filename = model_filename.replace("_fp16", "_quanto_fp16_int8")
-    elif "_bf16" in model_filename:
-        model_filename = model_filename.replace("_bf16", "_quanto_bf16_int8")
-    else:
+    for rep in ["mfp16", "fp16", "mbf16", "bf16"]:
+        if "_" + rep in model_filename:
+            model_filename = model_filename.replace("_" + rep, "_quanto_" + rep + "_int8")
+            break
+    if not "quanto" in model_filename:
        pos = model_filename.rfind(".")
        model_filename =  model_filename[:pos] + "_quanto_int8" + model_filename[pos+1:] 

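To make the effect of the new renaming loop above concrete, here is a small standalone sketch (an editor's illustration, not part of the commit) that reproduces just the filename logic on a hypothetical checkpoint name:

```python
# Standalone reproduction of the renaming logic added above (illustration only).
def quantized_filename(model_filename: str) -> str:
    # Insert the quanto_<dtype>_int8 tag in place of a recognised dtype suffix.
    for rep in ["mfp16", "fp16", "mbf16", "bf16"]:
        if "_" + rep in model_filename:
            model_filename = model_filename.replace("_" + rep, "_quanto_" + rep + "_int8")
            break
    if not "quanto" in model_filename:
        # Fallback when no recognised dtype tag is present: append a generic marker.
        pos = model_filename.rfind(".")
        model_filename = model_filename[:pos] + "_quanto_int8" + model_filename[pos+1:]
    return model_filename

print(quantized_filename("my_finetune_mbf16.safetensors"))
# -> my_finetune_quanto_mbf16_int8.safetensors
```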
							
								
								
									
wgp.py (48 lines changed)
@@ -1624,7 +1624,7 @@ def get_model_family(model_type):

def test_class_i2v(model_type):
    model_type = get_base_model_type(model_type)
-    return model_type in ["i2v", "fun_inp_1.3B", "fun_inp", "flf2v_720p",  "fantasy", "hunyuan_i2v" ] 
+    return model_type in ["i2v", "i2v_720p", "fun_inp_1.3B", "fun_inp", "flf2v_720p",  "fantasy", "hunyuan_i2v" ] 

def get_model_name(model_type, description_container = [""]):
    finetune_def = get_model_finetune_def(model_type)
@@ -1731,19 +1731,19 @@ def get_model_filename(model_type, quantization ="int8", dtype_policy = ""):
        raw_filename = choices[0]
    else:
        if quantization in ("int8", "fp8"):
-            sub_choices = [ name for name in choices if quantization in name]
+            sub_choices = [ name for name in choices if quantization in name or quantization.upper() in name]
        else:
            sub_choices = [ name for name in choices if "quanto" not in name]

        if len(sub_choices) > 0:
            dtype_str = "fp16" if dtype == torch.float16 else "bf16"
-            new_sub_choices = [ name for name in sub_choices if dtype_str in name]
+            new_sub_choices = [ name for name in sub_choices if dtype_str in name or dtype_str.upper() in name]
            sub_choices = new_sub_choices if len(new_sub_choices) > 0 else sub_choices
            raw_filename = sub_choices[0]
        else:
            raw_filename = choices[0]

-    if dtype == torch.float16 and not "fp16" in raw_filename and model_family == "wan" and finetune_def == None :
+    if dtype == torch.float16 and not any(x in raw_filename for x in ("fp16", "FP16")) and model_family == "wan" and finetune_def == None :
        if "quanto_int8" in raw_filename:
            raw_filename = raw_filename.replace("quanto_int8", "quanto_fp16_int8")
        elif "quanto_bf16_int8" in raw_filename:
@@ -1753,6 +1753,8 @@ def get_model_filename(model_type, quantization ="int8", dtype_policy = ""):
    return raw_filename

def get_transformer_dtype(model_family, transformer_dtype_policy):
+    if not isinstance(transformer_dtype_policy, str):
+        return transformer_dtype_policy
    if len(transformer_dtype_policy) == 0:
        if not bfloat16_supported:
            return torch.float16
@@ -2290,19 +2292,25 @@ def get_transformer_model(model):

def load_models(model_type):
    global transformer_type, transformer_loras_filenames
-    model_filename = get_model_filename(model_type=model_type, quantization= transformer_quantization, dtype_policy = transformer_dtype_policy) 
    base_model_type = get_base_model_type(model_type)
    finetune_def = get_model_finetune_def(model_type)
-    quantizeTransformer =  finetune_def !=None and  transformer_quantization in ("int8", "fp8") and finetune_def.get("auto_quantize", False) and not "quanto" in model_filename 
-
-    model_family = get_model_family(model_type)
-    perc_reserved_mem_max = args.perc_reserved_mem_max
    preload =int(args.preload)
-    save_quantized = args.save_quantized
+    save_quantized = args.save_quantized and finetune_def != None
+    model_filename = get_model_filename(model_type=model_type, quantization= "" if save_quantized else transformer_quantization, dtype_policy = transformer_dtype_policy) 
+    if save_quantized and "quanto" in model_filename:
+        save_quantized = False
+        print("Need to provide a non quantized model to create a quantized model to be saved") 
+    quantizeTransformer = not save_quantized and finetune_def !=None and transformer_quantization in ("int8", "fp8") and finetune_def.get("auto_quantize", False) and not "quanto" in model_filename
+    model_family = get_model_family(model_type)
+    transformer_dtype = get_transformer_dtype(model_family, transformer_dtype_policy)
+    if quantizeTransformer or "quanto" in model_filename:
+        transformer_dtype = torch.bfloat16 if "bf16" in model_filename or "BF16" in model_filename else transformer_dtype
+        transformer_dtype = torch.float16 if "fp16" in model_filename or"FP16" in model_filename else transformer_dtype
+    perc_reserved_mem_max = args.perc_reserved_mem_max
    if preload == 0:
        preload = server_config.get("preload_in_VRAM", 0)
    new_transformer_loras_filenames = None
-    dependent_models, dependent_models_types = get_dependent_models(model_type, quantization= transformer_quantization, dtype_policy = transformer_dtype_policy) 
+    dependent_models, dependent_models_types = get_dependent_models(model_type, quantization= transformer_quantization, dtype_policy = transformer_dtype) 
    new_transformer_loras_filenames = [model_filename]  if "_lora" in model_filename else None

    model_file_list = dependent_models + [model_filename]
@@ -2310,15 +2318,11 @@ def load_models(model_type):
    new_transformer_filename = model_file_list[-1] 
    if finetune_def != None:
        for module_type in finetune_def.get("modules", []):
-            model_file_list.append(get_model_filename(module_type, transformer_quantization, transformer_dtype_policy))
+            model_file_list.append(get_model_filename(module_type, transformer_quantization, transformer_dtype))
            model_type_list.append(module_type)

    for filename, file_model_type in zip(model_file_list, model_type_list): 
        download_models(filename, file_model_type)
-    transformer_dtype = get_transformer_dtype(model_family, transformer_dtype_policy)
-    if quantizeTransformer:
-        transformer_dtype = torch.bfloat16 if "bf16" in model_filename else transformer_dtype
-        transformer_dtype = torch.float16 if "fp16" in model_filename else transformer_dtype
    VAE_dtype = torch.float16 if server_config.get("vae_precision","16") == "16" else torch.float
    mixed_precision_transformer =  server_config.get("mixed_precision","0") == "1"
    transformer_filename = None
@@ -2364,7 +2368,7 @@ def load_models(model_type):
        prompt_enhancer_llm_tokenizer = None

-    offloadobj = offload.profile(pipe, profile_no= profile, compile = compile, quantizeTransformer = quantizeTransformer, loras = "transformer", coTenantsMap= {}, perc_reserved_mem_max = perc_reserved_mem_max , convertWeightsFloatTo = transformer_dtype, **kwargs)  
+    offloadobj = offload.profile(pipe, profile_no= profile, compile = compile, quantizeTransformer = False, loras = "transformer", coTenantsMap= {}, perc_reserved_mem_max = perc_reserved_mem_max , convertWeightsFloatTo = transformer_dtype, **kwargs)  
    if len(args.gpu) > 0:
        torch.set_default_device(args.gpu)
    transformer_loras_filenames = new_transformer_loras_filenames
@@ -3092,11 +3096,11 @@ def generate_video(
    # TeaCache
    if args.teacache > 0:
        tea_cache_setting = args.teacache 
-    trans.enable_teacache = tea_cache_setting > 0
-    if trans.enable_teacache:
+    trans.enable_cache = tea_cache_setting > 0
+    if trans.enable_cache:
        trans.teacache_multiplier = tea_cache_setting
        trans.rel_l1_thresh = 0
-        trans.teacache_start_step =  int(tea_cache_start_step_perc*num_inference_steps/100)
+        trans.cache_start_step =  int(tea_cache_start_step_perc*num_inference_steps/100)
        if get_model_family(model_type) == "wan":
            if image2video:
                if '720p' in model_filename:
@@ -3323,7 +3327,7 @@ def generate_video(
            progress_args = [0, merge_status_context(status, "Encoding Prompt")]
            send_cmd("progress", progress_args)

-            if trans.enable_teacache:
+            if trans.enable_cache:
                trans.teacache_counter = 0
                trans.num_steps = num_inference_steps                
                trans.teacache_skipped_steps = 0    
@@ -3412,7 +3416,7 @@ def generate_video(
                trans.previous_residual = None
                trans.previous_modulated_input = None

-            if trans.enable_teacache:
+            if trans.enable_cache:
                print(f"Teacache Skipped Steps:{trans.teacache_skipped_steps}/{trans.num_steps}" )

            if samples != None: