commit 5ce8fc3d53
parent 79df3aae64

    various fixes
@@ -122,7 +122,9 @@ See full changelog: **[Changelog](docs/CHANGELOG.md)**
 
 ## 🚀 Quick Start
 
-**One-click installation:** Get started instantly with [Pinokio App](https://pinokio.computer/)
+**One-click installation:** 
+- Get started instantly with [Pinokio App](https://pinokio.computer/)
+- Use Redtash1 [One Click Install with Sage](https://github.com/Redtash1/Wan2GP-Windows-One-Click-Install-With-Sage)
 
 **Manual installation:**
 ```bash
@@ -136,8 +138,7 @@ pip install -r requirements.txt
 
 **Run the application:**
 ```bash
-python wgp.py  # Text-to-video (default)
-python wgp.py --i2v  # Image-to-video
+python wgp.py
 ```
 
 **Update the application:**
@@ -232,7 +232,7 @@ class QwenImagePipeline(): #DiffusionPipeline
         drop_idx = self.prompt_template_encode_start_idx
         txt = [template.format(e) for e in prompt]
 
-        if self.processor is not None and image is not None:
+        if self.processor is not None and image is not None and len(image) > 0:
             img_prompt_template = "Picture {}: <|vision_start|><|image_pad|><|vision_end|>"
             if isinstance(image, list):
                 base_img_prompt = ""
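Note: the tightened condition matters because an empty image list passes the old `image is not None` test but should still take the text-only path. A minimal sketch of the new guard; `should_use_vision_prompt` is an illustrative helper, not part of the pipeline:

```python
# Illustrative only: mirrors the condition added in the hunk above.
def should_use_vision_prompt(processor, image):
    # processor present, image list present AND non-empty
    return processor is not None and image is not None and len(image) > 0

assert should_use_vision_prompt(object(), ["img"]) is True
assert should_use_vision_prompt(object(), []) is False   # empty list: text-only path
assert should_use_vision_prompt(None, ["img"]) is False
```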
@@ -972,7 +972,7 @@ class QwenImagePipeline(): #DiffusionPipeline
 
             if callback is not None:
                 preview = self._unpack_latents(latents, height, width, self.vae_scale_factor)
-                preview = preview.squeeze(0)
+                preview = preview.transpose(0,2).squeeze(0)
                 callback(i, preview, False)         
 
 
@@ -595,7 +595,7 @@ class WanAny2V:
                 color_reference_frame = input_frames[:, -1:].clone()
                 if prefix_frames_count > 0:
                     overlapped_frames_num = prefix_frames_count
-                    overlapped_latents_frames_num = (overlapped_latents_frames_num -1 // 4) + 1 
+                    overlapped_latents_frames_num = (overlapped_frames_num -1 // 4) + 1 
                     # overlapped_latents_frames_num = overlapped_latents.shape[2]
                     # overlapped_frames_num = (overlapped_latents_frames_num-1) * 4 + 1
                 else: 
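Note: the commented-out lines in this hunk encode the usual pixel-frame ↔ latent-frame mapping for a temporal compression factor of 4. A small sketch of that relation (the factor of 4 is taken from the commented line above, nothing else is assumed):

```python
# Sketch of the frame/latent-frame relation implied by the commented lines,
# assuming a temporal VAE stride of 4.
def latent_frames_from_pixel_frames(pixel_frames):
    # the first frame plus one latent frame per 4 pixel frames
    return (pixel_frames - 1) // 4 + 1

def pixel_frames_from_latent_frames(latent_frames):
    # inverse mapping, as in the commented line
    return (latent_frames - 1) * 4 + 1

assert latent_frames_from_pixel_frames(81) == 21
assert pixel_frames_from_latent_frames(21) == 81
```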
@@ -735,11 +735,20 @@ class WanAny2V:
         if callback != None:
             callback(-1, None, True)
 
+
+        clear_caches()
         offload.shared_state["_chipmunk"] =  False
         chipmunk = offload.shared_state.get("_chipmunk", False)        
         if chipmunk:
             self.model.setup_chipmunk()
 
+        offload.shared_state["_radial"] =  offload.shared_state["_attention"]=="radial"
+        radial = offload.shared_state.get("_radial", False)        
+        if radial:
+            radial_cache = get_cache("radial")
+            from shared.radial_attention.attention import fill_radial_cache
+            fill_radial_cache(radial_cache, len(self.model.blocks), *target_shape[1:])
+
         # init denoising
         updated_num_steps= len(timesteps)
 
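Note: `get_cache` and `clear_caches` are called here but their definitions do not appear in this diff. A dict-backed sketch of the named-cache pattern those calls imply; the real helpers elsewhere in the repository may differ:

```python
# Sketch only: the actual get_cache()/clear_caches() live elsewhere in the
# repo. This shows the pattern implied by the calls above: named, lazily
# created caches with a global reset.
_caches = {}

def get_cache(name):
    # return the cache for a feature (e.g. "radial"), creating it on first use
    return _caches.setdefault(name, {})

def clear_caches():
    # drop every named cache, e.g. at the start of a new generation
    _caches.clear()
```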
@@ -315,6 +315,8 @@ class WanSelfAttention(nn.Module):
             x_ref_attn_map = None
 
         chipmunk = offload.shared_state.get("_chipmunk", False) 
+        radial = offload.shared_state.get("_radial", False) 
+
         if chipmunk and self.__class__ == WanSelfAttention:
             q = q.transpose(1,2)
             k = k.transpose(1,2)
@@ -322,12 +324,17 @@ class WanSelfAttention(nn.Module):
             attn_layers = offload.shared_state["_chipmunk_layers"]
             x = attn_layers[self.block_no](q, k, v)
             x = x.transpose(1,2)
+        elif radial and self.__class__ == WanSelfAttention:
+            qkv_list = [q,k,v]
+            del q,k,v
+            radial_cache = get_cache("radial")
+            no_step_no = offload.shared_state["step_no"] 
+            x = radial_cache[self.block_no](qkv_list=qkv_list, timestep_no=no_step_no)
         elif block_mask == None:
             qkv_list = [q,k,v]
             del q,k,v
-            x = pay_attention(
-                qkv_list,
-                window_size=self.window_size)
+
+            x = pay_attention( qkv_list, window_size=self.window_size)
 
         else:
             with sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
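Note: both branches hand the tensors over as a single `qkv_list` and then `del q,k,v`, presumably so the callee can release each tensor as soon as it is consumed. A sketch of that pattern with an illustrative stub (the real `pay_attention` and radial kernels are defined elsewhere):

```python
# Illustrative stub, not the real attention kernel: passing a mutable qkv_list
# (and deleting q/k/v in the caller) leaves no extra reference to the large
# activations in the caller's frame.
import torch

def attention_stub(qkv_list):
    q, k, v = qkv_list.pop(0), qkv_list.pop(0), qkv_list.pop(0)
    out = torch.zeros_like(q)      # placeholder for the real kernel's output
    del q, k, v                    # last references released here
    return out

q, k, v = (torch.randn(1, 8, 16, 64) for _ in range(3))
qkv_list = [q, k, v]
del q, k, v                        # caller keeps no direct references
x = attention_stub(qkv_list)
```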
@@ -1311,9 +1318,8 @@ class WanModel(ModelMixin, ConfigMixin):
             del causal_mask
 
         offload.shared_state["embed_sizes"] = grid_sizes 
-        offload.shared_state["step_no"] = real_step_no 
+        offload.shared_state["step_no"] = current_step_no 
         offload.shared_state["max_steps"] = max_steps
-        if current_step_no == 0 and x_id == 0: clear_caches()
         # arguments
 
         kwargs = dict(
@@ -184,6 +184,15 @@ class family_handler():
                 "visible": False
             }
 
+            # extra_model_def["image_ref_choices"] = {
+            #         "choices": [("None", ""),
+            #             ("People / Objects", "I"),
+            #             ("Landscape followed by People / Objects (if any)", "KI"),
+            #             ],
+            #         "visible": False,
+            #         "letters_filter":  "KFI",
+            # }
+
             extra_model_def["video_guide_outpainting"] = [0,1]
             extra_model_def["keep_frames_video_guide_not_supported"] = True
             extra_model_def["extract_guide_from_window_start"] = True
@@ -42,6 +42,11 @@ try:
 except ImportError:
     sageattn_varlen_wrapper = None
 
+try:
+    from spas_sage_attn import block_sparse_sage2_attn_cuda
+except ImportError:
+    block_sparse_sage2_attn_cuda = None
+
 
 try:
     from .sage2_core import sageattn as sageattn2, is_sage2_supported
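Note: this follows the module's optional-dependency convention: import when available, otherwise keep a `None` sentinel that later capability checks test. A compact sketch of the pattern; `radial_available` is an illustrative helper, not something the diff adds:

```python
# Optional-dependency probe, as used throughout this module.
try:
    from spas_sage_attn import block_sparse_sage2_attn_cuda  # optional package
except ImportError:
    block_sparse_sage2_attn_cuda = None

def radial_available():
    # Illustrative helper (not in the diff): later code only advertises the
    # "radial" attention mode when the sparse Sage2 kernel import succeeded.
    return block_sparse_sage2_attn_cuda is not None
```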
@@ -62,6 +67,8 @@ def sageattn2_wrapper(
 
     return o
 
+from sageattn import sageattn_blackwell as sageattn3
+
 try:
     from sageattn import sageattn_blackwell as sageattn3
 except ImportError:
@@ -144,6 +151,9 @@ def get_attention_modes():
         ret.append("sage")
     if sageattn2 != None and version("sageattention").startswith("2") :
         ret.append("sage2")
+    if block_sparse_sage2_attn_cuda != None and version("sageattention").startswith("2") :
+        ret.append("radial")
+
     if sageattn3 != None: # and version("sageattention").startswith("3") :
         ret.append("sage3")
         
@@ -159,6 +169,8 @@ def get_supported_attention_modes():
     if not sage2_supported:
         if "sage2" in ret:
             ret.remove("sage2")
+        if "radial" in ret:
+            ret.remove("radial")
 
     if  major < 7:
         if "sage" in ret:
@@ -225,6 +237,7 @@ def pay_attention(
     if attn == "chipmunk":
         from src.chipmunk.modules import SparseDiffMlp, SparseDiffAttn
         from src.chipmunk.util import LayerCounter, GLOBAL_CONFIG
+    if attn == "radial": attn ="sage2"
 
     if b > 1 and k_lens != None and attn in ("sage2", "sage3", "sdpa"):
         assert attention_mask == None
@@ -28,18 +28,24 @@ from sageattention.triton.quant_per_thread import per_thread_int8 as per_thread_
 
 try:
     from sageattention import _qattn_sm80
+    if not hasattr(_qattn_sm80, "qk_int8_sv_f16_accum_f32_attn"): 
+        _qattn_sm80 = torch.ops.sageattention_qattn_sm80
     SM80_ENABLED = True
 except:
     SM80_ENABLED = False
 
 try:
     from sageattention import _qattn_sm89
+    if not hasattr(_qattn_sm89, "qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf"): 
+        _qattn_sm89 = torch.ops.sageattention_qattn_sm89
     SM89_ENABLED = True
 except:
     SM89_ENABLED = False
 
 try:
     from sageattention import _qattn_sm90
+    if not hasattr(_qattn_sm90, "qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf"): 
+        _qattn_sm90 = torch.ops.sageattention_qattn_sm90
     SM90_ENABLED = True
 except:
     SM90_ENABLED = False
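Note: all three blocks use the same fallback: if the imported extension module no longer exposes the kernel as a plain attribute, resolve it through the `torch.ops` registry instead (presumably for sageattention builds that register the kernels as custom ops). A generalized sketch of that probe; the helper name and rationale are assumptions, only the pattern comes from the diff:

```python
import torch

def resolve_kernel_module(module, probe_attr, ops_namespace):
    # Illustrative helper: keep the module if it still exposes the kernel
    # directly, otherwise fall back to the torch.ops namespace of that name.
    if hasattr(module, probe_attr):
        return module
    return getattr(torch.ops, ops_namespace)

# e.g., mirroring the hunk above:
# _qattn_sm80 = resolve_kernel_module(_qattn_sm80,
#                                     "qk_int8_sv_f16_accum_f32_attn",
#                                     "sageattention_qattn_sm80")
```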
							
								
								
									
wgp.py (16 changed lines)
@@ -63,7 +63,7 @@ AUTOSAVE_FILENAME = "queue.zip"
 PROMPT_VARS_MAX = 10
 
 target_mmgp_version = "3.6.0"
-WanGP_version = "8.74"
+WanGP_version = "8.75"
 settings_version = 2.36
 max_source_video_frames = 3000
 prompt_enhancer_image_caption_model, prompt_enhancer_image_caption_processor, prompt_enhancer_llm_model, prompt_enhancer_llm_tokenizer = None, None, None, None
@@ -1394,6 +1394,12 @@ def _parse_args():
         help="View form generation / refresh time"
     )
 
+    parser.add_argument(
+        "--betatest",
+        action="store_true",
+        help="test unreleased features"
+    )
+
     parser.add_argument(
         "--vram-safety-coefficient",
         type=float,
@@ -4367,7 +4373,7 @@ def enhance_prompt(state, prompt, prompt_enhancer, multi_images_gen_type, overri
     if image_start is None or not "I" in prompt_enhancer:
         image_start = [None] * num_prompts
     else:
-        image_start = [img[0] for img in image_start]
+        image_start = [convert_image(img[0]) for img in image_start]
         if len(image_start) == 1:
             image_start = image_start * num_prompts
         else:
@@ -8712,9 +8718,9 @@ def generate_configuration_tab(state, blocks, header, model_family, model_choice
                         ("Flash" + check("flash")+ ": good quality - requires additional install (usually complex to set up on Windows without WSL)", "flash"),
                         ("Xformers" + check("xformers")+ ": good quality - requires additional install (usually complex, may consume less VRAM to set up on Windows without WSL)", "xformers"),
                         ("Sage" + check("sage")+ ": 30% faster but slightly worse quality - requires additional install (usually complex to set up on Windows without WSL)", "sage"),
-                        ("Sage2/2++" + check("sage2")+ ": 40% faster but slightly worse quality - requires additional install (usually complex to set up on Windows without WSL)", "sage2"),
-                        ("Sage3" + check("sage3")+ ": x2 faster but worse quality - requires additional install (usually complex to set up on Windows without WSL)", "sage3"),
-                    ],
+                        ("Sage2/2++" + check("sage2")+ ": 40% faster but slightly worse quality - requires additional install (usually complex to set up on Windows without WSL)", "sage2")]\
+                        + ([("Radial" + check("radial")+ ": x? faster but ? quality - requires Sparge & Sage 2 Attn (usually complex to set up on Windows without WSL)", "radial")] if args.betatest else [])\
+                        +  [("Sage3" + check("sage3")+ ": x2 faster but worse quality - requires additional install (usually complex to set up on Windows without WSL)", "sage3")],
                     value= attention_mode,
                     label="Attention Type",
                     interactive= not lock_ui_attention
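Note: the dropdown entries are now built by list concatenation, with the Radial choice appended only when the `--betatest` flag added earlier in this commit is set. A minimal, self-contained sketch of that gating (labels shortened; the structure mirrors the diff):

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--betatest", action="store_true", help="test unreleased features")
args = parser.parse_args([])          # e.g. parse_args(["--betatest"]) to enable

# Radial is only offered when the beta flag is on.
choices = [("Sage2/2++", "sage2")] \
    + ([("Radial", "radial")] if args.betatest else []) \
    + [("Sage3", "sage3")]

assert ("Radial", "radial") not in choices   # hidden unless --betatest is given
```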