This UI color is the good one + slightly reduced VRAM when using Sage2 attention

This commit is contained in:
DeepBeepMeep 2025-03-19 23:33:18 +01:00
parent 1406a44a69
commit f2163e0984
3 changed files with 1122 additions and 20 deletions

View File

@ -1595,7 +1595,7 @@ def create_demo():
}
"""
default_flow_shift = get_default_flow(transformer_filename_i2v if use_image2video else transformer_filename_t2v)
with gr.Blocks(css=css, theme=gr.themes.Soft(primary_hue="stone", neutral_hue="slate", text_size= "md")) as demo:
with gr.Blocks(css=css, theme=gr.themes.Soft(primary_hue="sky", neutral_hue="slate", text_size= "md")) as demo:
state_dict = {}
if use_image2video:

View File

@ -38,8 +38,13 @@ import warnings
try:
from sageattention import sageattn
@torch.compiler.disable()
def sageattn_wrapper(
from .sage2_core import sageattn as alt_sageattn
except ImportError:
sageattn = None
alt_sageattn = None
# @torch.compiler.disable()
def sageattn_wrapper(
qkv_list,
attention_length
):
@ -48,17 +53,20 @@ try:
q = q[:attention_length, :, : ].unsqueeze(0)
k = k[:attention_length, :, : ].unsqueeze(0)
v = v[:attention_length, :, : ].unsqueeze(0)
if True:
qkv_list = [q,k,v]
del q, k ,v
o = alt_sageattn(qkv_list, tensor_layout="NHD").squeeze(0)
else:
o = sageattn(q, k, v, tensor_layout="NHD").squeeze(0)
del q, k ,v
qkv_list.clear()
if padding_length > 0:
o = torch.cat([o, torch.empty( (padding_length, *o.shape[-2:]), dtype= o.dtype, device=o.device ) ], 0)
return o
except ImportError:
sageattn = None
# # try:
# if True:

1094
wan/modules/sage2_core.py Normal file

File diff suppressed because it is too large Load Diff