diff --git a/README.md b/README.md
index 8261540..939cb7f 100644
--- a/README.md
+++ b/README.md
@@ -122,7 +122,9 @@ See full changelog: **[Changelog](docs/CHANGELOG.md)**
 
 ## 🚀 Quick Start
 
-**One-click installation:** Get started instantly with [Pinokio App](https://pinokio.computer/)
+**One-click installation:**
+- Get started instantly with [Pinokio App](https://pinokio.computer/)
+- Use Redtash1 [One Click Install with Sage](https://github.com/Redtash1/Wan2GP-Windows-One-Click-Install-With-Sage)
 
 **Manual installation:**
 ```bash
@@ -136,8 +138,7 @@ pip install -r requirements.txt
 
 **Run the application:**
 ```bash
-python wgp.py  # Text-to-video (default)
-python wgp.py --i2v  # Image-to-video
+python wgp.py
 ```
 
 **Update the application:**
diff --git a/models/qwen/pipeline_qwenimage.py b/models/qwen/pipeline_qwenimage.py
index 02fd473..c458852 100644
--- a/models/qwen/pipeline_qwenimage.py
+++ b/models/qwen/pipeline_qwenimage.py
@@ -232,7 +232,7 @@ class QwenImagePipeline(): #DiffusionPipeline
         drop_idx = self.prompt_template_encode_start_idx
         txt = [template.format(e) for e in prompt]
 
-        if self.processor is not None and image is not None:
+        if self.processor is not None and image is not None and len(image) > 0:
             img_prompt_template = "Picture {}: <|vision_start|><|image_pad|><|vision_end|>"
             if isinstance(image, list):
                 base_img_prompt = ""
@@ -972,7 +972,7 @@ class QwenImagePipeline(): #DiffusionPipeline
 
                 if callback is not None:
                     preview = self._unpack_latents(latents, height, width, self.vae_scale_factor)
-                    preview = preview.squeeze(0)
+                    preview = preview.transpose(0,2).squeeze(0)
                     callback(i, preview, False)
diff --git a/models/wan/any2video.py b/models/wan/any2video.py
index 285340d..7dc4b19 100644
--- a/models/wan/any2video.py
+++ b/models/wan/any2video.py
@@ -595,7 +595,7 @@ class WanAny2V:
                     color_reference_frame = input_frames[:, -1:].clone()
                     if prefix_frames_count > 0:
                         overlapped_frames_num = prefix_frames_count
-                        overlapped_latents_frames_num = (overlapped_latents_frames_num -1 // 4) + 1
+                        overlapped_latents_frames_num = (overlapped_frames_num -1 // 4) + 1
                         # overlapped_latents_frames_num = overlapped_latents.shape[2]
                         # overlapped_frames_num = (overlapped_latents_frames_num-1) * 4 + 1
                     else:
@@ -735,11 +735,20 @@ class WanAny2V:
 
         if callback != None:
             callback(-1, None, True)
 
+        clear_caches()
         offload.shared_state["_chipmunk"] = False
         chipmunk = offload.shared_state.get("_chipmunk", False)
         if chipmunk:
             self.model.setup_chipmunk()
 
+        offload.shared_state["_radial"] = offload.shared_state["_attention"]=="radial"
+        radial = offload.shared_state.get("_radial", False)
+        if radial:
+            radial_cache = get_cache("radial")
+            from shared.radial_attention.attention import fill_radial_cache
+            fill_radial_cache(radial_cache, len(self.model.blocks), *target_shape[1:])
+
         # init denoising
         updated_num_steps= len(timesteps)
diff --git a/models/wan/modules/model.py b/models/wan/modules/model.py
index cd02470..82a313c 100644
--- a/models/wan/modules/model.py
+++ b/models/wan/modules/model.py
@@ -315,6 +315,8 @@ class WanSelfAttention(nn.Module):
         x_ref_attn_map = None
 
         chipmunk = offload.shared_state.get("_chipmunk", False)
+        radial = offload.shared_state.get("_radial", False)
+
         if chipmunk and self.__class__ == WanSelfAttention:
             q = q.transpose(1,2)
             k = k.transpose(1,2)
@@ -322,12 +324,17 @@ class WanSelfAttention(nn.Module):
             attn_layers = offload.shared_state["_chipmunk_layers"]
             x = attn_layers[self.block_no](q, k, v)
             x = x.transpose(1,2)
+        elif radial and self.__class__ == WanSelfAttention:
+            qkv_list = [q,k,v]
+            del q,k,v
+            radial_cache = get_cache("radial")
+            no_step_no = offload.shared_state["step_no"]
+            x = radial_cache[self.block_no](qkv_list=qkv_list, timestep_no=no_step_no)
         elif block_mask == None:
             qkv_list = [q,k,v]
             del q,k,v
-            x = pay_attention(
-                qkv_list,
-                window_size=self.window_size)
+
+            x = pay_attention( qkv_list, window_size=self.window_size)
         else:
             with sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
@@ -1311,9 +1318,8 @@ class WanModel(ModelMixin, ConfigMixin):
         del causal_mask
 
         offload.shared_state["embed_sizes"] = grid_sizes
-        offload.shared_state["step_no"] = real_step_no
+        offload.shared_state["step_no"] = current_step_no
         offload.shared_state["max_steps"] = max_steps
-        if current_step_no == 0 and x_id == 0: clear_caches()
 
         # arguments
         kwargs = dict(
diff --git a/models/wan/wan_handler.py b/models/wan/wan_handler.py
index c587544..1a30fff 100644
--- a/models/wan/wan_handler.py
+++ b/models/wan/wan_handler.py
@@ -184,6 +184,15 @@ class family_handler():
                 "visible": False
             }
 
+            # extra_model_def["image_ref_choices"] = {
+            #     "choices": [("None", ""),
+            #             ("People / Objects", "I"),
+            #             ("Landscape followed by People / Objects (if any)", "KI"),
+            #             ],
+            #     "visible": False,
+            #     "letters_filter": "KFI",
+            # }
+
         extra_model_def["video_guide_outpainting"] = [0,1]
         extra_model_def["keep_frames_video_guide_not_supported"] = True
         extra_model_def["extract_guide_from_window_start"] = True
diff --git a/shared/attention.py b/shared/attention.py
index cc6ece0..efaa223 100644
--- a/shared/attention.py
+++ b/shared/attention.py
@@ -42,6 +42,11 @@ try:
 except ImportError:
     sageattn_varlen_wrapper = None
 
+try:
+    from spas_sage_attn import block_sparse_sage2_attn_cuda
+except ImportError:
+    block_sparse_sage2_attn_cuda = None
+
 try:
     from .sage2_core import sageattn as sageattn2, is_sage2_supported
@@ -62,6 +67,8 @@ def sageattn2_wrapper(
     return o
 
 
+from sageattn import sageattn_blackwell as sageattn3
+
 try:
     from sageattn import sageattn_blackwell as sageattn3
 except ImportError:
@@ -144,6 +151,9 @@ def get_attention_modes():
         ret.append("sage")
     if sageattn2 != None and version("sageattention").startswith("2") :
         ret.append("sage2")
+    if block_sparse_sage2_attn_cuda != None and version("sageattention").startswith("2") :
+        ret.append("radial")
+
     if sageattn3 != None: # and version("sageattention").startswith("3") :
         ret.append("sage3")
@@ -159,6 +169,8 @@ def get_supported_attention_modes():
     if not sage2_supported:
         if "sage2" in ret:
             ret.remove("sage2")
+        if "radial" in ret:
+            ret.remove("radial")
 
     if major < 7:
         if "sage" in ret:
@@ -225,6 +237,7 @@ def pay_attention(
     if attn == "chipmunk":
         from src.chipmunk.modules import SparseDiffMlp, SparseDiffAttn
         from src.chipmunk.util import LayerCounter, GLOBAL_CONFIG
+    if attn == "radial": attn ="sage2"
 
     if b > 1 and k_lens != None and attn in ("sage2", "sage3", "sdpa"):
         assert attention_mask == None
diff --git a/shared/sage2_core.py b/shared/sage2_core.py
index 04b77f8..33cd98c 100644
--- a/shared/sage2_core.py
+++ b/shared/sage2_core.py
@@ -28,18 +28,24 @@ from sageattention.triton.quant_per_thread import per_thread_int8 as per_thread_
 
 try:
     from sageattention import _qattn_sm80
+    if not hasattr(_qattn_sm80, "qk_int8_sv_f16_accum_f32_attn"):
+        _qattn_sm80 = torch.ops.sageattention_qattn_sm80
     SM80_ENABLED = True
 except:
     SM80_ENABLED = False
 
 try:
     from sageattention import _qattn_sm89
+    if not hasattr(_qattn_sm89, "qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf"):
+        _qattn_sm89 = torch.ops.sageattention_qattn_sm89
     SM89_ENABLED = True
 except:
     SM89_ENABLED = False
 
 try:
     from sageattention import _qattn_sm90
+    if not hasattr(_qattn_sm90, "qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf"):
+        _qattn_sm90 = torch.ops.sageattention_qattn_sm90
     SM90_ENABLED = True
 except:
     SM90_ENABLED = False
diff --git a/wgp.py b/wgp.py
index e00660c..6a257e8 100644
--- a/wgp.py
+++ b/wgp.py
@@ -63,7 +63,7 @@ AUTOSAVE_FILENAME = "queue.zip"
 PROMPT_VARS_MAX = 10
 
 target_mmgp_version = "3.6.0"
-WanGP_version = "8.74"
+WanGP_version = "8.75"
 settings_version = 2.36
 max_source_video_frames = 3000
 prompt_enhancer_image_caption_model, prompt_enhancer_image_caption_processor, prompt_enhancer_llm_model, prompt_enhancer_llm_tokenizer = None, None, None, None
@@ -1394,6 +1394,12 @@ def _parse_args():
         help="View form generation / refresh time"
     )
 
+    parser.add_argument(
+        "--betatest",
+        action="store_true",
+        help="test unreleased features"
+    )
+
     parser.add_argument(
         "--vram-safety-coefficient",
         type=float,
@@ -4367,7 +4373,7 @@ def enhance_prompt(state, prompt, prompt_enhancer, multi_images_gen_type, overri
         if image_start is None or not "I" in prompt_enhancer:
             image_start = [None] * num_prompts
         else:
-            image_start = [img[0] for img in image_start]
+            image_start = [convert_image(img[0]) for img in image_start]
             if len(image_start) == 1:
                 image_start = image_start * num_prompts
             else:
@@ -8712,9 +8718,9 @@ def generate_configuration_tab(state, blocks, header, model_family, model_choice
                                 ("Flash" + check("flash")+ ": good quality - requires additional install (usually complex to set up on Windows without WSL)", "flash"),
                                 ("Xformers" + check("xformers")+ ": good quality - requires additional install (usually complex, may consume less VRAM to set up on Windows without WSL)", "xformers"),
                                 ("Sage" + check("sage")+ ": 30% faster but slightly worse quality - requires additional install (usually complex to set up on Windows without WSL)", "sage"),
-                                ("Sage2/2++" + check("sage2")+ ": 40% faster but slightly worse quality - requires additional install (usually complex to set up on Windows without WSL)", "sage2"),
-                                ("Sage3" + check("sage3")+ ": x2 faster but worse quality - requires additional install (usually complex to set up on Windows without WSL)", "sage3"),
-                                ],
+                                ("Sage2/2++" + check("sage2")+ ": 40% faster but slightly worse quality - requires additional install (usually complex to set up on Windows without WSL)", "sage2")]\
+                                + ([("Radial" + check("radial")+ ": x? faster but ? quality - requires Sparge & Sage 2 Attn (usually complex to set up on Windows without WSL)", "radial")] if args.betatest else [])\
+                                + [("Sage3" + check("sage3")+ ": x2 faster but worse quality - requires additional install (usually complex to set up on Windows without WSL)", "sage3")],
                                 value= attention_mode,
                                 label="Attention Type",
                                 interactive= not lock_ui_attention