Mirror of https://github.com/Wan-Video/Wan2.1.git (synced 2025-11-04 06:15:17 +00:00)
a few more fixes

commit 95727e618f
parent 743f6911a1
@@ -9,8 +9,7 @@
         "URLs": [
             "https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/Wan14BT2VFusioniX_fp16.safetensors",
             "https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/Wan14BT2VFusioniX_quanto_bf16_int8.safetensors",
-            "https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/Wan14BT2VFusioniX_quanto_fp16_int8.safetensors",
-            "ckpts/Wan14BT2VFusioniX_quanto_fp16_int8.safetensors"
+            "https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/Wan14BT2VFusioniX_quanto_fp16_int8.safetensors"
         ],
         "auto_quantize": true
     },
@@ -11,8 +11,6 @@ from typing import Union,Optional
 from mmgp import offload
 from .attention import pay_attention
 from torch.backends.cuda import sdp_kernel
-# from src.chipmunk.modules import SparseDiffMlp, SparseDiffAttn
-# from src.chipmunk.util import LayerCounter, GLOBAL_CONFIG

 __all__ = ['WanModel']

@@ -156,7 +154,8 @@ class WanSelfAttention(nn.Module):
                  num_heads,
                  window_size=(-1, -1),
                  qk_norm=True,
-                 eps=1e-6):
+                 eps=1e-6,
+                 block_no=0):
         assert dim % num_heads == 0
         super().__init__()
         self.dim = dim
@@ -165,6 +164,7 @@ class WanSelfAttention(nn.Module):
         self.window_size = window_size
         self.qk_norm = qk_norm
         self.eps = eps
+        self.block_no = block_no

         # layers
         self.q = nn.Linear(dim, dim)
@@ -174,10 +174,6 @@ class WanSelfAttention(nn.Module):
         self.norm_q = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity()
         self.norm_k = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity()

-        # Only initialize SparseDiffAttn if this is not a subclass initialization
-        # if self.__class__ == WanSelfAttention:
-        #     layer_num, layer_counter = LayerCounter.build_for_layer(is_attn_sparse=True, is_mlp_sparse=False)
-        #     self.attn = SparseDiffAttn(layer_num, layer_counter)

     def forward(self, xlist, grid_sizes, freqs, block_mask = None):
         r"""
@@ -204,9 +200,14 @@ class WanSelfAttention(nn.Module):
         del q,k

         q,k = apply_rotary_emb(qklist, freqs, head_first=False)
-        chipmunk = offload.shared_state["_chipmunk"]
-        if chipmunk:
-            x = self.attn(q, k, v)
+        chipmunk = offload.shared_state.get("_chipmunk", False)
+        if chipmunk and self.__class__ == WanSelfAttention:
+            q = q.transpose(1,2)
+            k = k.transpose(1,2)
+            v = v.transpose(1,2)
+            attn_layers = offload.shared_state["_chipmunk_layers"]
+            x = attn_layers[self.block_no](q, k, v)
+            x = x.transpose(1,2)
         elif block_mask == None:
             qkv_list = [q,k,v]
             del q,k,v
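Note: with SparseDiffAttn no longer owned by each attention module, the sparse path is resolved at call time: a block looks up its own layer by block_no in the list that setup_chipmunk (added further down) publishes through the shared offload state. A minimal sketch of that dispatch, not the repo's code; dense_attn stands in for the regular pay_attention call and tensors are assumed to be [batch, seq, heads, head_dim]:

from mmgp import offload

def attention_dispatch(q, k, v, block_no, dense_attn):
    # Sparse path only when chipmunk mode is on and a layer list was published.
    if offload.shared_state.get("_chipmunk", False):
        attn_layers = offload.shared_state["_chipmunk_layers"]
        # The diff swaps the seq and head dims before calling SparseDiffAttn,
        # then swaps them back on the result.
        q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
        return attn_layers[block_no](q, k, v).transpose(1, 2)
    # Dense fallback: whatever the block normally uses.
    return dense_attn(q, k, v)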
@@ -372,7 +373,8 @@ class WanAttentionBlock(nn.Module):
                  qk_norm=True,
                  cross_attn_norm=False,
                  eps=1e-6,
-                 block_id=None
+                 block_id=None,
+                 block_no = 0
                  ):
         super().__init__()
         self.dim = dim
@@ -382,11 +384,12 @@ class WanAttentionBlock(nn.Module):
         self.qk_norm = qk_norm
         self.cross_attn_norm = cross_attn_norm
         self.eps = eps
+        self.block_no = block_no

         # layers
         self.norm1 = WanLayerNorm(dim, eps)
         self.self_attn = WanSelfAttention(dim, num_heads, window_size, qk_norm,
-                                          eps)
+                                          eps, block_no= block_no)
         self.norm3 = WanLayerNorm(
             dim, eps,
             elementwise_affine=True) if cross_attn_norm else nn.Identity()
@@ -394,7 +397,8 @@ class WanAttentionBlock(nn.Module):
                                                                       num_heads,
                                                                       (-1, -1),
                                                                       qk_norm,
-                                                                      eps)
+                                                                      eps,
+                                                                      block_no)
         self.norm2 = WanLayerNorm(dim, eps)
         self.ffn = nn.Sequential(
             nn.Linear(dim, ffn_dim), nn.GELU(approximate='tanh'),
@@ -599,6 +603,27 @@ class MLPProj(torch.nn.Module):


 class WanModel(ModelMixin, ConfigMixin):
+    def setup_chipmunk(self):
+        from chipmunk.util import LayerCounter
+        from chipmunk.modules import SparseDiffMlp, SparseDiffAttn
+        seq_shape = (21, 45, 80)
+        chipmunk_layers =[]
+        for i in range(self.num_layers):
+            layer_num, layer_counter = LayerCounter.build_for_layer(is_attn_sparse=True, is_mlp_sparse=False)
+            chipmunk_layers.append( SparseDiffAttn(layer_num, layer_counter))
+        offload.shared_state["_chipmunk_layers"] = chipmunk_layers
+
+        chipmunk_layers[0].initialize_static_mask(
+            seq_shape=seq_shape,
+            txt_len=0,
+            local_heads_num=self.num_heads,
+            device='cuda'
+        )
+        chipmunk_layers[0].layer_counter.reset()
+
+    def release_chipmunk(self):
+        offload.shared_state["_chipmunk_layers"] = None
+
     def preprocess_loras(self, model_type, sd):

         first = next(iter(sd), None)
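Note: the two new methods bracket a generation run: setup_chipmunk builds one SparseDiffAttn per transformer block, publishes the list through offload.shared_state, and initialises the static mask and layer counter on block 0; release_chipmunk drops the list afterwards. A rough usage sketch, assuming model is a WanModel and the chipmunk package is installed; the try/finally wrapper is illustrative only, the diff releases explicitly after the denoising loop:

from mmgp import offload

def generate_with_chipmunk(model, timesteps, step_fn):
    offload.shared_state["_chipmunk"] = True   # opt in to the sparse attention path
    model.setup_chipmunk()                     # publishes offload.shared_state["_chipmunk_layers"]
    try:
        for t in timesteps:
            step_fn(t)                         # one denoising step; each block indexes the list by block_no
    finally:
        model.release_chipmunk()               # per the diff's comment, release is needed at every exit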
@@ -766,8 +791,8 @@ class WanModel(ModelMixin, ConfigMixin):
             cross_attn_type = 't2v_cross_attn' if model_type == 't2v' else 'i2v_cross_attn'
             self.blocks = nn.ModuleList([
                 WanAttentionBlock(cross_attn_type, dim, ffn_dim, num_heads,
-                                window_size, qk_norm, cross_attn_norm, eps)
-                for _ in range(num_layers)
+                                window_size, qk_norm, cross_attn_norm, eps, block_no =i)
+                for i in range(num_layers)
             ])

         # head
@@ -791,7 +816,7 @@ class WanModel(ModelMixin, ConfigMixin):
             # blocks
             self.blocks = nn.ModuleList([
                 WanAttentionBlock('t2v_cross_attn', self.dim, self.ffn_dim, self.num_heads, self.window_size, self.qk_norm,
-                                    self.cross_attn_norm, self.eps,
+                                    self.cross_attn_norm, self.eps, block_no =i,
                                     block_id=self.vace_layers_mapping[i] if i in self.vace_layers else None)
                 for i in range(self.num_layers)
             ])
@@ -944,6 +969,10 @@ class WanModel(ModelMixin, ConfigMixin):
         if torch.is_tensor(freqs) and freqs.device != device:
             freqs = freqs.to(device)

+        chipmunk = offload.shared_state.get("_chipmunk", False)
+        if chipmunk:
+            from src.chipmunk.ops.voxel import voxel_chunk_no_padding, reverse_voxel_chunk_no_padding
+            voxel_shape = (4, 6, 8)

         x_list = x
         joint_pass = len(x_list) > 1
@@ -960,20 +989,15 @@ class WanModel(ModelMixin, ConfigMixin):
                 # embeddings
                 x = self.patch_embedding(x.unsqueeze(0)).to(modulation_dtype)
                 grid_sizes = x.shape[2:]
-                x = x.flatten(2).transpose(1, 2)
+                if chipmunk:
+                    x = x.unsqueeze(-1)
+                    x_og_shape = x.shape
+                    x = voxel_chunk_no_padding(x, voxel_shape).squeeze(-1).transpose(1, 2)
+                else:
+                    x = x.flatten(2).transpose(1, 2)
                 x_list[i] = x
         x, y = None, None

-        offload.shared_state["_chipmunk"] = False
-        chipmunk = offload.shared_state["_chipmunk"]
-        if chipmunk:
-            voxel_shape = (4, 6, 8)
-            for x in x_list:
-                from src.chipmunk.ops.voxel import voxel_chunk_no_padding, reverse_voxel_chunk_no_padding
-                x = x.unsqueeze(-1)
-                x_og_shape = x.shape
-                x = voxel_chunk_no_padding(x, voxel_shape).squeeze(-1).transpose(1, 2)
-            x = None

         block_mask = None
         if causal_attention and causal_block_size > 0 and False: # NEVER WORKED
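Note: the block deleted in the hunk above could never take effect: it forced "_chipmunk" to False immediately before reading it, and rebinding x inside `for x in x_list:` never writes back into the list. Voxel chunking therefore moves inline, per sample, right after patch embedding. A plain-Python illustration of the rebinding pitfall (nothing Wan-specific):

xs = [1, 2, 3]
for x in xs:
    x = x * 10               # rebinds the loop variable only; xs is unchanged
assert xs == [1, 2, 3]

for i, x in enumerate(xs):
    xs[i] = x * 10           # writing back through the index updates the list,
assert xs == [10, 20, 30]    # which is what the `x_list[i] = x` pattern relies on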
@@ -108,7 +108,7 @@ class WanT2V:

         self.sample_neg_prompt = config.sample_neg_prompt

-        if "Vace" in model_filename[-1]:
+        if base_model_type in ["vace_14B", "vace_1.3B"]:
             self.vid_proc = VaceVideoProcessor(downsample=tuple([x * y for x, y in zip(config.vae_stride, self.patch_size)]),
                                             min_area=480*832,
                                             max_area=480*832,
@@ -492,16 +492,10 @@ class WanT2V:
         if callback != None:
             callback(-1, None, True)

-        # seq_shape = (21, 45, 80)
-        # local_heads_num = 40 #12 for 1.3B
-
-        # self.model.blocks[0].self_attn.attn.initialize_static_mask(
-        #     seq_shape=seq_shape,
-        #     txt_len=0,
-        #     local_heads_num=local_heads_num,
-        #     device='cuda'
-        # )
-        # self.model.blocks[0].self_attn.attn.layer_counter.reset()
+        offload.shared_state["_chipmunk"] =  False
+        chipmunk = offload.shared_state.get("_chipmunk", False)
+        if chipmunk:
+            self.model.setup_chipmunk()

         for i, t in enumerate(tqdm(timesteps)):

@@ -614,6 +608,9 @@ class WanT2V:

         x0 = [latents]

+        if chipmunk:
+            self.model.release_chipmunk() # need to add it at every exit when in prof
+
         if return_latent_slice != None:
             if overlapped_latents != None:
                 # latents [:, 1:] = self.toto
@@ -640,6 +637,4 @@
             module = modules_dict[f"vace_blocks.{vace_layer}"]
             target = modules_dict[f"blocks.{model_layer}"]
             setattr(target, "vace", module )
-
-
-        delattr(model, "vace_blocks")
+        delattr(model, "vace_blocks")

wgp.py (28 lines changed)
@@ -1772,7 +1772,7 @@ def fix_settings(model_type, ui_defaults):
         if not "I" in video_prompt_type:  # workaround for settings corruption
             video_prompt_type += "I"
     if model_type in ["hunyuan"]:
-        del_in_sequence(video_prompt_type, "I")
+        video_prompt_type = video_prompt_type.replace("I", "")
     ui_defaults["video_prompt_type"] = video_prompt_type

 def get_default_settings(model_type):
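Note: Python strings are immutable, so the replaced call could not change video_prompt_type in place, and its return value (whatever del_in_sequence produces) was discarded, leaving the line a no-op. The fix rebinds the result of str.replace. A toy illustration (the "VAI" value is made up):

video_prompt_type = "VAI"
video_prompt_type.replace("I", "")                      # result discarded, variable unchanged
assert video_prompt_type == "VAI"
video_prompt_type = video_prompt_type.replace("I", "")  # rebinding, as the fixed line does
assert video_prompt_type == "VA"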
@@ -1990,7 +1990,9 @@ def save_quantized_model(model, model_type, model_filename, dtype,  config_file):
         pos = model_filename.rfind(".")
         model_filename =  model_filename[:pos] + "_quanto_int8" + model_filename[pos+1:]

-    if not os.path.isfile(model_filename):
+    if os.path.isfile(model_filename):
+        print(f"There isn't any model to quantize as quantized model '{model_filename}' aready exists")
+    else:
         offload.save_model(model, model_filename, do_quantize= True, config_file_path=config_file)
         print(f"New quantized file '{model_filename}' had been created for finetune Id '{model_type}'.")
         finetune_def = get_model_finetune_def(model_type)
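Note: the saving condition itself is unchanged, a quantized file is still only written when none exists yet; inverting the test just surfaces an explicit message in the skip case instead of doing nothing silently. Condensed shape of the new control flow, using the diff's names with bodies elided and the message shortened:

import os

if os.path.isfile(model_filename):
    print(f"quantized model '{model_filename}' already exists, nothing to do")
else:
    offload.save_model(model, model_filename, do_quantize=True, config_file_path=config_file)
    ...  # then record the new file in the finetune definition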
@@ -2360,10 +2362,18 @@ def load_models(model_type):
     preload =int(args.preload)
     save_quantized = args.save_quantized and finetune_def != None
     model_filename = get_model_filename(model_type=model_type, quantization= "" if save_quantized else transformer_quantization, dtype_policy = transformer_dtype_policy)
+    modules = finetune_def.get("modules", []) if finetune_def != None else []
     if save_quantized and "quanto" in model_filename:
         save_quantized = False
         print("Need to provide a non quantized model to create a quantized model to be saved")
-    quantizeTransformer = not save_quantized and finetune_def !=None and transformer_quantization in ("int8", "fp8") and finetune_def.get("auto_quantize", False) and not "quanto" in model_filename
+    if save_quantized and len(modules) > 0:
+        _, model_types_no_module =  dependent_models_types = get_dependent_models(base_model_type, transformer_quantization, transformer_dtype_policy)
+        print(f"Unable to create a finetune quantized model as some modules are declared in the finetune definition. If your finetune includes already the module weights you can remove the 'modules' entry and try again. If not you will need also to change temporarly the model 'architecture' to an architecture that wont require the modules part ('{model_types_no_module[0] if len(model_types_no_module)>0 else ''}' ?) to quantize and then add back the original 'modules' and 'architecture' entries.")
+        save_quantized = False
+    quantizeTransformer = not save_quantized and finetune_def !=None and transformer_quantization in ("int8", "fp8") and finetune_def.get("auto_quantize", False) and not "quanto" in model_filename
+    if quantizeTransformer and len(modules) > 0:
+        print(f"Autoquantize is not yet supported if some modules are declared")
+        quantizeTransformer = False
     model_family = get_model_family(model_type)
     transformer_dtype = get_transformer_dtype(model_family, transformer_dtype_policy)
     if quantizeTransformer or "quanto" in model_filename:
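Note: load_models now resolves the finetune's module list once and uses it to veto quantized-model saving and auto-quantization early. Condensed decision order, using the diff's names, with the print calls dropped (a sketch, not the verbatim code):

modules = finetune_def.get("modules", []) if finetune_def != None else []

if save_quantized and "quanto" in model_filename:
    save_quantized = False            # source is already quantized, nothing to convert
if save_quantized and len(modules) > 0:
    save_quantized = False            # modules declared in the finetune: refuse, with a hint printed
quantizeTransformer = (not save_quantized and finetune_def != None
                       and transformer_quantization in ("int8", "fp8")
                       and finetune_def.get("auto_quantize", False)
                       and "quanto" not in model_filename)
if quantizeTransformer and len(modules) > 0:
    quantizeTransformer = False       # auto-quantize is not yet supported when modules are declared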
@@ -2379,16 +2389,14 @@ def load_models(model_type):
     model_file_list = dependent_models + [model_filename]
     model_type_list = dependent_models_types + [model_type]
     new_transformer_filename = model_file_list[-1]
-    if finetune_def != None:
-        for module_type in finetune_def.get("modules", []):
-            model_file_list.append(get_model_filename(module_type, transformer_quantization, transformer_dtype))
-            model_type_list.append(module_type)
+    for module_type in modules:
+        model_file_list.append(get_model_filename(module_type, transformer_quantization, transformer_dtype))
+        model_type_list.append(module_type)

     for filename, file_model_type in zip(model_file_list, model_type_list):
         download_models(filename, file_model_type)
     VAE_dtype = torch.float16 if server_config.get("vae_precision","16") == "16" else torch.float
     mixed_precision_transformer =  server_config.get("mixed_precision","0") == "1"
     transformer_filename = None
-    transformer_loras_filenames = None
     transformer_type = None
     for i, filename in enumerate(model_file_list):
@@ -5317,9 +5325,9 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = None):
                             video_guide_outpainting_left = gr.Slider(0, 100, value= video_guide_outpainting_list[2], step=5, label="Left %", show_reset_button= False)
                             video_guide_outpainting_right = gr.Slider(0, 100, value= video_guide_outpainting_list[3], step=5, label="Right %", show_reset_button= False)

-                video_mask = gr.Video(label= "Video Mask Area (for Inpainting, white = Control Area, black = Unchanged)", visible= "A" in video_prompt_type_value and not "U" in video_prompt_type_value , value= ui_defaults.get("video_mask", None))
+                video_mask = gr.Video(label= "Video Mask Area (for Inpainting, white = Control Area, black = Unchanged)", visible= "V" in video_prompt_type_value and "A" in video_prompt_type_value and not "U" in video_prompt_type_value , value= ui_defaults.get("video_mask", None))

-                mask_expand = gr.Slider(-10, 50, value=ui_defaults.get("mask_expand", 0), step=1, label="Expand / Shrink Mask Area", visible= "A" in video_prompt_type_value and not "U" in video_prompt_type_value )
+                mask_expand = gr.Slider(-10, 50, value=ui_defaults.get("mask_expand", 0), step=1, label="Expand / Shrink Mask Area", visible= "V" in video_prompt_type_value and "A" in video_prompt_type_value and not "U" in video_prompt_type_value )

                 image_refs = gr.Gallery( label ="Start Image" if hunyuan_video_avatar else "Reference Images",
                         type ="pil",   show_label= True,
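Note: both mask widgets now additionally require "V" in video_prompt_type_value, so they only appear when the prompt type includes a control video (presumably what "V" encodes), not merely when a mask ("A") is requested without "U". Pulling the repeated condition into a helper makes the intent easier to read; a sketch only, the UI keeps the expression inline:

def mask_controls_visible(video_prompt_type_value: str) -> bool:
    # New gating: "V" is required on top of the previous "A" present / "U" absent test.
    return ("V" in video_prompt_type_value
            and "A" in video_prompt_type_value
            and "U" not in video_prompt_type_value)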