Mirror of https://github.com/Wan-Video/Wan2.1.git (synced 2025-11-05 14:33:15 +00:00)
This UI color is the good one + slightly reduced VRAM when using Sage2 attention
Commit f2163e0984 (parent 1406a44a69)
@@ -1595,7 +1595,7 @@ def create_demo():
     }
     """
     default_flow_shift = get_default_flow(transformer_filename_i2v if use_image2video else transformer_filename_t2v)
-    with gr.Blocks(css=css, theme=gr.themes.Soft(primary_hue="stone", neutral_hue="slate", text_size= "md")) as demo:
+    with gr.Blocks(css=css, theme=gr.themes.Soft(primary_hue="sky", neutral_hue="slate", text_size= "md")) as demo:
         state_dict = {}
 
         if use_image2video:
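For context, gr.themes.Soft accepts string shortcuts for hues and sizes, so the change above simply swaps the UI's primary palette from "stone" to "sky". A minimal, self-contained sketch of the same theme call (the Blocks body is illustrative only, not code from the repo):

import gradio as gr

# Same theme settings as the updated line above; the demo content is a
# placeholder just to preview the palette.
theme = gr.themes.Soft(primary_hue="sky", neutral_hue="slate", text_size="md")

with gr.Blocks(theme=theme) as demo:
    gr.Markdown("Theme preview")
    gr.Button("Primary color comes from primary_hue")

if __name__ == "__main__":
    demo.launch()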
@@ -38,27 +38,35 @@ import warnings
 try:
     from sageattention import sageattn
-    @torch.compiler.disable()
-    def sageattn_wrapper(
-        qkv_list,
-        attention_length
-    ):
-        q,k, v = qkv_list
-        padding_length = q.shape[0] -attention_length
-        q = q[:attention_length, :, : ].unsqueeze(0)
-        k = k[:attention_length, :, : ].unsqueeze(0)
-        v = v[:attention_length, :, : ].unsqueeze(0)
-
-        o = sageattn(q, k, v, tensor_layout="NHD").squeeze(0)
-        del q, k ,v
-        qkv_list.clear()
-
-        if padding_length > 0:
-            o = torch.cat([o, torch.empty( (padding_length, *o.shape[-2:]), dtype= o.dtype, device=o.device ) ], 0)
-
-        return o
+    from .sage2_core import sageattn as alt_sageattn
 except ImportError:
     sageattn = None
+    alt_sageattn = None
 
+# @torch.compiler.disable()
+def sageattn_wrapper(
+        qkv_list,
+        attention_length
+    ):
+    q,k, v = qkv_list
+    padding_length = q.shape[0] -attention_length
+    q = q[:attention_length, :, : ].unsqueeze(0)
+    k = k[:attention_length, :, : ].unsqueeze(0)
+    v = v[:attention_length, :, : ].unsqueeze(0)
+    if True:
+        qkv_list = [q,k,v]
+        del q, k ,v
+        o = alt_sageattn(qkv_list, tensor_layout="NHD").squeeze(0)
+    else:
+        o = sageattn(q, k, v, tensor_layout="NHD").squeeze(0)
+        del q, k ,v
+
+    qkv_list.clear()
+
+    if padding_length > 0:
+        o = torch.cat([o, torch.empty( (padding_length, *o.shape[-2:]), dtype= o.dtype, device=o.device ) ], 0)
+
+    return o
 
 # # try:
 # if True:
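The VRAM saving comes from how the rewritten wrapper hands q/k/v to the Sage2 kernel: instead of three separate tensor arguments, it passes them inside a mutable list and deletes its own references, so downstream code holds the only copies and can release each tensor as soon as it is no longer needed. A minimal sketch of that reference-dropping pattern, using torch.nn.functional.scaled_dot_product_attention as a stand-in for the Sage2 kernel (run_attention and the shapes are illustrative, not part of the repo):

import torch
import torch.nn.functional as F

def run_attention(qkv_list):
    # Unpack, then clear the shared list so the local names below hold the
    # only remaining references to q/k/v.
    q, k, v = qkv_list
    qkv_list.clear()
    out = F.scaled_dot_product_attention(q, k, v)
    del q, k, v  # inputs can be freed here instead of lingering in the caller
    return out

if __name__ == "__main__":
    q = torch.randn(1, 8, 128, 64)
    k = torch.randn(1, 8, 128, 64)
    v = torch.randn(1, 8, 128, 64)
    qkv = [q, k, v]
    del q, k, v               # the caller keeps no direct references
    o = run_attention(qkv)
    print(o.shape)            # torch.Size([1, 8, 128, 64])

In the commit itself, alt_sageattn (from the new wan/modules/sage2_core.py) accepts the list directly, and the `if True:` branch hard-wires the wrapper to that path whenever the Sage2 import succeeds.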
New file: wan/modules/sage2_core.py (1094 lines). File diff suppressed because it is too large.