From 35071d4c952d879670a238e2bbb65e59cff189c2 Mon Sep 17 00:00:00 2001
From: DeepBeepMeep
Date: Thu, 20 Mar 2025 09:28:13 +0100
Subject: [PATCH] fixed bug with Sage2 sm86 architecture

---
 gradio_server.py          | 12 ++++++------
 wan/modules/sage2_core.py | 12 ++++++++----
 2 files changed, 14 insertions(+), 10 deletions(-)

diff --git a/gradio_server.py b/gradio_server.py
index 0ad2d8c..b61c326 100644
--- a/gradio_server.py
+++ b/gradio_server.py
@@ -223,7 +223,7 @@ def _parse_args():
         "--res",
         type=str,
         default="480p",
-        choices=["480p", "720p", "823p", "1024p", "1280p"],
-        help="Default resolution for the video (480p, 720p, 823p, 1024p or 1280p)"
+        choices=["480p", "720p", "832p", "1024p", "1280p"],
+        help="Default resolution for the video (480p, 720p, 832p, 1024p or 1280p)"
     )

@@ -730,7 +730,7 @@ def apply_changes( state,
     text_encoder_filename = server_config["text_encoder_filename"]
     vae_config = server_config["vae_config"]
     boost = server_config["boost"]
-    if all(change in ["attention_mode", "vae_config", "default_ui", "boost"] for change in changes ):
+    if all(change in ["attention_mode", "vae_config", "default_ui", "boost", "save_path", "metadata_choice"] for change in changes ):
         if "attention_mode" in changes:
             pass

@@ -1218,7 +1218,7 @@ def generate_video(
             gc.collect()
             torch.cuda.empty_cache()
             s = str(e)
-            keyword_list = ["vram", "VRAM", "memory", "triton", "cuda", "allocat"]
+            keyword_list = ["vram", "VRAM", "memory", "allocat"]
             VRAM_crash= False
             if any( keyword in s for keyword in keyword_list):
                 VRAM_crash = True
@@ -1895,7 +1895,7 @@ def create_demo():
                             # ("624x832 (3:4, 540p)", "624x832"),
                             # ("720x720 (1:1, 540p)", "720x720"),
                         ],
-                        value={"480p": "832x480","720p": "1280x720","823p": "480x832","1024p": "1024x1024","1280p": "720x1280",}.get(res, "832x480"),
+                        value={"480p": "832x480","720p": "1280x720","832p": "480x832","1024p": "1024x1024","1280p": "720x1280",}.get(res, "832x480"),
                         label="Resolution"
                     )

@@ -1930,7 +1930,7 @@ def create_demo():
         with gr.Row():
             negative_prompt = gr.Textbox(label="Negative Prompt", value="")
         with gr.Column(visible = len(loras)>0) as loras_column:
-            gr.Markdown("Loras can be used to create special effects on the video by mentioned a trigger word in the Prompt. You can save Loras combinations in presets.")
+            gr.Markdown("Loras can be used to create special effects on the video by mentioning a trigger word in the Prompt. You can save Loras combinations in presets.")
             loras_choices = gr.Dropdown(
                 choices=[
                     (lora_name, str(i) ) for i, lora_name in enumerate(loras_names)
diff --git a/wan/modules/sage2_core.py b/wan/modules/sage2_core.py
index c4dddc7..d83d255 100644
--- a/wan/modules/sage2_core.py
+++ b/wan/modules/sage2_core.py
@@ -137,9 +137,10 @@ def sageattn(

 @torch.compiler.disable
 def sageattn_qk_int8_pv_fp16_triton(
-    q: torch.Tensor,
-    k: torch.Tensor,
-    v: torch.Tensor,
+    qkv_list,
+    # q: torch.Tensor,
+    # k: torch.Tensor,
+    # v: torch.Tensor,
     tensor_layout: str = "HND",
     quantization_backend: str = "triton",
     is_causal: bool =False,
@@ -211,7 +212,8 @@ def sageattn_qk_int8_pv_fp16_triton(
         - All tensors must be on the same cuda device.
         - `smooth_k` will introduce slight overhead but will improve the accuracy under most circumstances.
     """
-
+    q, k, v = qkv_list
+    qkv_list.clear()
     dtype = q.dtype
     assert q.is_cuda, "Input tensors must be on cuda."
     assert dtype in [torch.float16, torch.bfloat16], "Input tensors must be in dtype of torch.float16 or torch.bfloat16"
@@ -266,6 +268,8 @@ def sageattn_qk_int8_pv_fp16_triton(
         q_int8, q_scale, k_int8, k_scale = per_block_int8_cuda(q, k, km=km, sm_scale=sm_scale, tensor_layout=tensor_layout)
     else:
         raise ValueError(f"Unsupported quantization backend: {quantization_backend}")
+    del q, k, km
+
     if is_causal:
         o, lse = attn_true(q_int8, k_int8, v, q_scale, k_scale, tensor_layout=tensor_layout, output_dtype=dtype, return_lse=return_lse)
     else:
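
The functional change in sage2_core.py is the calling convention of
sageattn_qk_int8_pv_fp16_triton: instead of three tensor arguments it now
receives a single qkv_list, unpacks it, and immediately clears it, and after
int8 quantization it runs `del q, k, km`. Together these drop the last Python
references to the float q/k buffers so their VRAM can be reclaimed before the
attention kernel allocates its output, which is what avoids the out-of-memory
crash on lower-VRAM sm86 (RTX 30xx) GPUs. A minimal caller-side sketch under
assumed names and shapes follows; none of this caller code is part of the
patch itself:

    import torch
    from wan.modules.sage2_core import sageattn_qk_int8_pv_fp16_triton

    # Illustrative tensors in "HND" layout: (batch, heads, seq_len, head_dim).
    q = torch.randn(1, 8, 1024, 64, dtype=torch.float16, device="cuda")
    k = torch.randn(1, 8, 1024, 64, dtype=torch.float16, device="cuda")
    v = torch.randn(1, 8, 1024, 64, dtype=torch.float16, device="cuda")

    # Hand the tensors over inside a list and drop the caller's own names,
    # so the only remaining references live inside the callee. After the
    # callee's qkv_list.clear() and its later `del q, k, km`, the q/k
    # buffers become collectible before the output tensor is allocated.
    qkv_list = [q, k, v]
    del q, k, v
    o = sageattn_qk_int8_pv_fp16_triton(qkv_list, tensor_layout="HND")
    assert len(qkv_list) == 0  # the callee emptied the list in place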