mirror of
https://github.com/Wan-Video/Wan2.1.git
synced 2025-11-05 06:29:14 +00:00
fixed bug with Sage2 sm86 architecture
This commit is contained in:
parent
408ebc44ce
commit
35071d4c95
@ -223,7 +223,7 @@ def _parse_args():
|
||||
"--res",
|
||||
type=str,
|
||||
default="480p",
|
||||
choices=["480p", "720p", "823p", "1024p", "1280p"],
|
||||
choices=["480p", "720p", "832p", "1024p", "1280p"],
|
||||
help="Default resolution for the video (480p, 720p, 823p, 1024p or 1280p)"
|
||||
)
|
||||
|
||||
@ -730,7 +730,7 @@ def apply_changes( state,
|
||||
text_encoder_filename = server_config["text_encoder_filename"]
|
||||
vae_config = server_config["vae_config"]
|
||||
boost = server_config["boost"]
|
||||
if all(change in ["attention_mode", "vae_config", "default_ui", "boost"] for change in changes ):
|
||||
if all(change in ["attention_mode", "vae_config", "default_ui", "boost", "save_path", "metadata_choice"] for change in changes ):
|
||||
if "attention_mode" in changes:
|
||||
pass
|
||||
|
||||
@ -1218,7 +1218,7 @@ def generate_video(
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
s = str(e)
|
||||
keyword_list = ["vram", "VRAM", "memory", "triton", "cuda", "allocat"]
|
||||
keyword_list = ["vram", "VRAM", "memory","allocat"]
|
||||
VRAM_crash= False
|
||||
if any( keyword in s for keyword in keyword_list):
|
||||
VRAM_crash = True
|
||||
@ -1895,7 +1895,7 @@ def create_demo():
|
||||
# ("624x832 (3:4, 540p)", "624x832"),
|
||||
# ("720x720 (1:1, 540p)", "720x720"),
|
||||
],
|
||||
value={"480p": "832x480","720p": "1280x720","823p": "480x832","1024p": "1024x1024","1280p": "720x1280",}.get(res, "832x480"),
|
||||
value={"480p": "832x480","720p": "1280x720","832p": "480x832","1024p": "1024x1024","1280p": "720x1280",}.get(res, "832x480"),
|
||||
label="Resolution"
|
||||
)
|
||||
|
||||
@ -1930,7 +1930,7 @@ def create_demo():
|
||||
with gr.Row():
|
||||
negative_prompt = gr.Textbox(label="Negative Prompt", value="")
|
||||
with gr.Column(visible = len(loras)>0) as loras_column:
|
||||
gr.Markdown("<B>Loras can be used to create special effects on the video by mentioned a trigger word in the Prompt. You can save Loras combinations in presets.</B>")
|
||||
gr.Markdown("<B>Loras can be used to create special effects on the video by mentioning a trigger word in the Prompt. You can save Loras combinations in presets.</B>")
|
||||
loras_choices = gr.Dropdown(
|
||||
choices=[
|
||||
(lora_name, str(i) ) for i, lora_name in enumerate(loras_names)
|
||||
|
||||
@ -137,9 +137,10 @@ def sageattn(
|
||||
|
||||
@torch.compiler.disable
|
||||
def sageattn_qk_int8_pv_fp16_triton(
|
||||
q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
qkv_list,
|
||||
# q: torch.Tensor,
|
||||
# k: torch.Tensor,
|
||||
# v: torch.Tensor,
|
||||
tensor_layout: str = "HND",
|
||||
quantization_backend: str = "triton",
|
||||
is_causal: bool =False,
|
||||
@ -211,7 +212,8 @@ def sageattn_qk_int8_pv_fp16_triton(
|
||||
- All tensors must be on the same cuda device.
|
||||
- `smooth_k` will introduce slight overhead but will improve the accuracy under most circumstances.
|
||||
"""
|
||||
|
||||
q, k, v = qkv_list
|
||||
qkv_list.clear()
|
||||
dtype = q.dtype
|
||||
assert q.is_cuda, "Input tensors must be on cuda."
|
||||
assert dtype in [torch.float16, torch.bfloat16], "Input tensors must be in dtype of torch.float16 or torch.bfloat16"
|
||||
@ -266,6 +268,8 @@ def sageattn_qk_int8_pv_fp16_triton(
|
||||
q_int8, q_scale, k_int8, k_scale = per_block_int8_cuda(q, k, km=km, sm_scale=sm_scale, tensor_layout=tensor_layout)
|
||||
else:
|
||||
raise ValueError(f"Unsupported quantization backend: {quantization_backend}")
|
||||
del q,k, km
|
||||
|
||||
if is_causal:
|
||||
o, lse = attn_true(q_int8, k_int8, v, q_scale, k_scale, tensor_layout=tensor_layout, output_dtype=dtype, return_lse=return_lse)
|
||||
else:
|
||||
|
||||
Loading…
Reference in New Issue
Block a user