fixed bug with Sage2 sm86 architecture

This commit is contained in:
DeepBeepMeep 2025-03-20 09:28:13 +01:00
parent 408ebc44ce
commit 35071d4c95
2 changed files with 13 additions and 9 deletions

View File

@ -223,7 +223,7 @@ def _parse_args():
"--res", "--res",
type=str, type=str,
default="480p", default="480p",
choices=["480p", "720p", "823p", "1024p", "1280p"], choices=["480p", "720p", "832p", "1024p", "1280p"],
help="Default resolution for the video (480p, 720p, 823p, 1024p or 1280p)" help="Default resolution for the video (480p, 720p, 823p, 1024p or 1280p)"
) )
@ -730,7 +730,7 @@ def apply_changes( state,
text_encoder_filename = server_config["text_encoder_filename"] text_encoder_filename = server_config["text_encoder_filename"]
vae_config = server_config["vae_config"] vae_config = server_config["vae_config"]
boost = server_config["boost"] boost = server_config["boost"]
if all(change in ["attention_mode", "vae_config", "default_ui", "boost"] for change in changes ): if all(change in ["attention_mode", "vae_config", "default_ui", "boost", "save_path", "metadata_choice"] for change in changes ):
if "attention_mode" in changes: if "attention_mode" in changes:
pass pass
@ -1218,7 +1218,7 @@ def generate_video(
gc.collect() gc.collect()
torch.cuda.empty_cache() torch.cuda.empty_cache()
s = str(e) s = str(e)
keyword_list = ["vram", "VRAM", "memory", "triton", "cuda", "allocat"] keyword_list = ["vram", "VRAM", "memory","allocat"]
VRAM_crash= False VRAM_crash= False
if any( keyword in s for keyword in keyword_list): if any( keyword in s for keyword in keyword_list):
VRAM_crash = True VRAM_crash = True
@ -1895,7 +1895,7 @@ def create_demo():
# ("624x832 (3:4, 540p)", "624x832"), # ("624x832 (3:4, 540p)", "624x832"),
# ("720x720 (1:1, 540p)", "720x720"), # ("720x720 (1:1, 540p)", "720x720"),
], ],
value={"480p": "832x480","720p": "1280x720","823p": "480x832","1024p": "1024x1024","1280p": "720x1280",}.get(res, "832x480"), value={"480p": "832x480","720p": "1280x720","832p": "480x832","1024p": "1024x1024","1280p": "720x1280",}.get(res, "832x480"),
label="Resolution" label="Resolution"
) )
@ -1930,7 +1930,7 @@ def create_demo():
with gr.Row(): with gr.Row():
negative_prompt = gr.Textbox(label="Negative Prompt", value="") negative_prompt = gr.Textbox(label="Negative Prompt", value="")
with gr.Column(visible = len(loras)>0) as loras_column: with gr.Column(visible = len(loras)>0) as loras_column:
gr.Markdown("<B>Loras can be used to create special effects on the video by mentioned a trigger word in the Prompt. You can save Loras combinations in presets.</B>") gr.Markdown("<B>Loras can be used to create special effects on the video by mentioning a trigger word in the Prompt. You can save Loras combinations in presets.</B>")
loras_choices = gr.Dropdown( loras_choices = gr.Dropdown(
choices=[ choices=[
(lora_name, str(i) ) for i, lora_name in enumerate(loras_names) (lora_name, str(i) ) for i, lora_name in enumerate(loras_names)

View File

@ -137,9 +137,10 @@ def sageattn(
@torch.compiler.disable @torch.compiler.disable
def sageattn_qk_int8_pv_fp16_triton( def sageattn_qk_int8_pv_fp16_triton(
q: torch.Tensor, qkv_list,
k: torch.Tensor, # q: torch.Tensor,
v: torch.Tensor, # k: torch.Tensor,
# v: torch.Tensor,
tensor_layout: str = "HND", tensor_layout: str = "HND",
quantization_backend: str = "triton", quantization_backend: str = "triton",
is_causal: bool =False, is_causal: bool =False,
@ -211,7 +212,8 @@ def sageattn_qk_int8_pv_fp16_triton(
- All tensors must be on the same cuda device. - All tensors must be on the same cuda device.
- `smooth_k` will introduce slight overhead but will improve the accuracy under most circumstances. - `smooth_k` will introduce slight overhead but will improve the accuracy under most circumstances.
""" """
q, k, v = qkv_list
qkv_list.clear()
dtype = q.dtype dtype = q.dtype
assert q.is_cuda, "Input tensors must be on cuda." assert q.is_cuda, "Input tensors must be on cuda."
assert dtype in [torch.float16, torch.bfloat16], "Input tensors must be in dtype of torch.float16 or torch.bfloat16" assert dtype in [torch.float16, torch.bfloat16], "Input tensors must be in dtype of torch.float16 or torch.bfloat16"
@ -266,6 +268,8 @@ def sageattn_qk_int8_pv_fp16_triton(
q_int8, q_scale, k_int8, k_scale = per_block_int8_cuda(q, k, km=km, sm_scale=sm_scale, tensor_layout=tensor_layout) q_int8, q_scale, k_int8, k_scale = per_block_int8_cuda(q, k, km=km, sm_scale=sm_scale, tensor_layout=tensor_layout)
else: else:
raise ValueError(f"Unsupported quantization backend: {quantization_backend}") raise ValueError(f"Unsupported quantization backend: {quantization_backend}")
del q,k, km
if is_causal: if is_causal:
o, lse = attn_true(q_int8, k_int8, v, q_scale, k_scale, tensor_layout=tensor_layout, output_dtype=dtype, return_lse=return_lse) o, lse = attn_true(q_int8, k_int8, v, q_scale, k_scale, tensor_layout=tensor_layout, output_dtype=dtype, return_lse=return_lse)
else: else: