Mirror of https://github.com/Wan-Video/Wan2.1.git (synced 2025-11-05 06:29:14 +00:00)

Commit 35071d4c95 (parent 408ebc44ce): "fixed bug with Sage2 sm86 architecture"
@@ -223,7 +223,7 @@ def _parse_args():
         "--res",
         type=str,
         default="480p",
-        choices=["480p", "720p", "823p", "1024p", "1280p"],
+        choices=["480p", "720p", "832p", "1024p", "1280p"],
         help="Default resolution for the video (480p, 720p, 823p, 1024p or 1280p)"
     )
 
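The change above corrects a transposed digit: "823p" should have been "832p", the portrait 480x832 preset. Because argparse rejects any value not listed in `choices`, the typo made the real preset unreachable from the command line. A minimal sketch of that behavior (standalone, not the repository's full parser):

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--res",
        type=str,
        default="480p",
        # before the fix this list contained "823p", so "--res 832p"
        # exited with "invalid choice" before anything else ran
        choices=["480p", "720p", "832p", "1024p", "1280p"],
    )
    print(parser.parse_args(["--res", "832p"]).res)  # now accepted

Note that the help string in this hunk still reads "823p" after the commit; only `choices` was corrected.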
@@ -730,7 +730,7 @@ def apply_changes( state,
     text_encoder_filename = server_config["text_encoder_filename"]
     vae_config = server_config["vae_config"]
     boost = server_config["boost"]
-    if all(change in ["attention_mode", "vae_config", "default_ui", "boost"] for change in changes ):
+    if all(change in ["attention_mode", "vae_config", "default_ui", "boost", "save_path", "metadata_choice"] for change in changes ):
         if "attention_mode" in changes:
             pass
 
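This hunk widens the set of settings that can be applied without reloading the model: "save_path" and "metadata_choice" join the reload-free list. The `all(...)` guard is the key pattern, as in this sketch (a simplified stand-in, with `reload_model` as a hypothetical callback):

    RELOAD_FREE = {"attention_mode", "vae_config", "default_ui", "boost",
                   "save_path", "metadata_choice"}

    def apply(changes, reload_model):
        # hot-apply only if every changed key is known to be reload-free
        if all(change in RELOAD_FREE for change in changes):
            return "applied in place"
        reload_model()  # e.g. a changed text_encoder_filename forces this path
        return "model reloaded"

    print(apply({"save_path"}, reload_model=lambda: None))              # applied in place
    print(apply({"text_encoder_filename"}, reload_model=lambda: None))  # model reloaded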
@@ -1218,7 +1218,7 @@ def generate_video(
         gc.collect()
         torch.cuda.empty_cache()
         s = str(e)
-        keyword_list = ["vram", "VRAM", "memory", "triton", "cuda", "allocat"]
+        keyword_list = ["vram", "VRAM", "memory","allocat"]
         VRAM_crash= False
         if any( keyword in s for keyword in keyword_list):
             VRAM_crash = True
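Dropping "triton" and "cuda" from the keyword list narrows the out-of-memory heuristic, presumably because with Sage2 on sm86 a Triton kernel-compilation failure contains those words but is not a VRAM crash, so the old list misclassified it. A self-contained sketch of the same message-scanning heuristic:

    import gc

    def looks_like_vram_crash(e: Exception) -> bool:
        s = str(e)
        # "triton"/"cuda" were removed: they also match kernel-compilation
        # errors (e.g. Sage2 on sm86) that are not memory-related at all
        keyword_list = ["vram", "VRAM", "memory", "allocat"]
        return any(keyword in s for keyword in keyword_list)

    try:
        raise RuntimeError("CUDA out of memory. Tried to allocate 2.00 GiB")
    except RuntimeError as e:
        gc.collect()  # the real handler also calls torch.cuda.empty_cache()
        print(looks_like_vram_crash(e))  # True: "memory" and "allocat" match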
@@ -1895,7 +1895,7 @@ def create_demo():
                 # ("624x832 (3:4, 540p)", "624x832"),
                 # ("720x720 (1:1, 540p)", "720x720"),
             ],
-            value={"480p": "832x480","720p": "1280x720","823p": "480x832","1024p": "1024x1024","1280p": "720x1280",}.get(res, "832x480"),
+            value={"480p": "832x480","720p": "1280x720","832p": "480x832","1024p": "1024x1024","1280p": "720x1280",}.get(res, "832x480"),
             label="Resolution"
         )
 
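The same transposed digit reappears here as a dictionary key, and this occurrence failed silently: `.get(res, "832x480")` never raises on a missing key, so selecting 832p returned the landscape fallback instead of the portrait 480x832 preset. A short reproduction with the values from the hunk:

    res = "832p"
    presets = {"480p": "832x480", "720p": "1280x720", "823p": "480x832",
               "1024p": "1024x1024", "1280p": "720x1280"}
    # before the fix the key is misspelled, so the lookup silently
    # falls back to the default instead of returning "480x832"
    print(presets.get(res, "832x480"))  # 832x480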
@@ -1930,7 +1930,7 @@ def create_demo():
         with gr.Row():
             negative_prompt = gr.Textbox(label="Negative Prompt", value="")
         with gr.Column(visible = len(loras)>0) as loras_column:
-            gr.Markdown("<B>Loras can be used to create special effects on the video by mentioned a trigger word in the Prompt. You can save Loras combinations in presets.</B>")
+            gr.Markdown("<B>Loras can be used to create special effects on the video by mentioning a trigger word in the Prompt. You can save Loras combinations in presets.</B>")
             loras_choices = gr.Dropdown(
                 choices=[
                     (lora_name, str(i) ) for i, lora_name in enumerate(loras_names)
@@ -137,9 +137,10 @@ def sageattn(
 
 @torch.compiler.disable
 def sageattn_qk_int8_pv_fp16_triton(
-    q: torch.Tensor,
-    k: torch.Tensor,
-    v: torch.Tensor,
+    qkv_list,
+    # q: torch.Tensor,
+    # k: torch.Tensor,
+    # v: torch.Tensor,
     tensor_layout: str = "HND",
     quantization_backend: str = "triton",
     is_causal: bool =False,
@@ -211,7 +212,8 @@ def sageattn_qk_int8_pv_fp16_triton(
     - All tensors must be on the same cuda device.
     - `smooth_k` will introduce slight overhead but will improve the accuracy under most circumstances.
     """
-
+    q, k, v = qkv_list
+    qkv_list.clear()
     dtype = q.dtype
     assert q.is_cuda, "Input tensors must be on cuda."
     assert dtype in [torch.float16, torch.bfloat16], "Input tensors must be in dtype of torch.float16 or torch.bfloat16"
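These two hunks implement a reference-count trick: the function now receives q, k, and v inside a single mutable list, unpacks them, and immediately empties the list. A caller that builds the list and drops its own q/k/v names leaves the callee's locals as the only remaining references, so each tensor becomes freeable the moment the callee deletes it; ordinary positional arguments cannot express that hand-off. A minimal sketch with plain objects and a weak reference standing in for CUDA tensors (all names illustrative):

    import weakref

    class FakeTensor:  # stand-in for a large CUDA tensor
        pass

    def attention(qkv_list):
        q, k, v = qkv_list
        qkv_list.clear()          # the list no longer keeps q/k/v alive
        probe = weakref.ref(q)
        del q, k                  # locals were the last references to q and k
        print("q freed inside callee:", probe() is None)  # True
        return v

    # caller side: pack, pass, and hold no separate q/k/v names afterwards
    attention([FakeTensor(), FakeTensor(), FakeTensor()])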
@@ -266,6 +268,8 @@ def sageattn_qk_int8_pv_fp16_triton(
         q_int8, q_scale, k_int8, k_scale = per_block_int8_cuda(q, k, km=km, sm_scale=sm_scale, tensor_layout=tensor_layout)
     else:
         raise ValueError(f"Unsupported quantization backend: {quantization_backend}")
+    del q,k, km
+
     if is_causal:
         o, lse = attn_true(q_int8, k_int8, v, q_scale, k_scale, tensor_layout=tensor_layout, output_dtype=dtype, return_lse=return_lse)
     else:
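`del q,k, km` is the companion fix: after per-block INT8 quantization, the fp16 `q`/`k` (and the smoothing mean `km`) are dead weight, and deleting them before the attention kernel runs lets the allocator reuse that VRAM for the kernel's output, which matters most on smaller sm86 cards. A hedged sketch of the shape of this, with a crude per-tensor quantizer standing in for SageAttention's per-block one:

    import torch

    def per_tensor_int8(q, k):
        # illustrative stand-in, not SageAttention's per_block_int8_cuda
        scale = 127.0 / q.abs().amax().clamp(min=1e-6)
        return (q * scale).round().to(torch.int8), (k * scale).round().to(torch.int8)

    def quantized_attention(qkv_list):
        q, k, v = qkv_list
        qkv_list.clear()                      # see the qkv_list pattern above
        q_int8, k_int8 = per_tensor_int8(q, k)
        del q, k                              # fp16 copies are dead weight now;
                                              # freeing them lowers peak memory here
        scores = (q_int8.float() @ k_int8.float().transpose(-1, -2)).softmax(dim=-1)
        return scores @ v

    out = quantized_attention([torch.randn(8, 16) for _ in range(3)])
    print(out.shape)  # torch.Size([8, 16])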