This UI color is the good one + slightly reduced VRAM when using Sage2 attention

This commit is contained in:
DeepBeepMeep 2025-03-19 23:33:18 +01:00
parent 1406a44a69
commit f2163e0984
3 changed files with 1122 additions and 20 deletions

View File

@ -1595,7 +1595,7 @@ def create_demo():
} }
""" """
default_flow_shift = get_default_flow(transformer_filename_i2v if use_image2video else transformer_filename_t2v) default_flow_shift = get_default_flow(transformer_filename_i2v if use_image2video else transformer_filename_t2v)
with gr.Blocks(css=css, theme=gr.themes.Soft(primary_hue="stone", neutral_hue="slate", text_size= "md")) as demo: with gr.Blocks(css=css, theme=gr.themes.Soft(primary_hue="sky", neutral_hue="slate", text_size= "md")) as demo:
state_dict = {} state_dict = {}
if use_image2video: if use_image2video:

View File

@ -38,7 +38,12 @@ import warnings
try: try:
from sageattention import sageattn from sageattention import sageattn
@torch.compiler.disable() from .sage2_core import sageattn as alt_sageattn
except ImportError:
sageattn = None
alt_sageattn = None
# @torch.compiler.disable()
def sageattn_wrapper( def sageattn_wrapper(
qkv_list, qkv_list,
attention_length attention_length
@ -48,17 +53,20 @@ try:
q = q[:attention_length, :, : ].unsqueeze(0) q = q[:attention_length, :, : ].unsqueeze(0)
k = k[:attention_length, :, : ].unsqueeze(0) k = k[:attention_length, :, : ].unsqueeze(0)
v = v[:attention_length, :, : ].unsqueeze(0) v = v[:attention_length, :, : ].unsqueeze(0)
if True:
qkv_list = [q,k,v]
del q, k ,v
o = alt_sageattn(qkv_list, tensor_layout="NHD").squeeze(0)
else:
o = sageattn(q, k, v, tensor_layout="NHD").squeeze(0) o = sageattn(q, k, v, tensor_layout="NHD").squeeze(0)
del q, k ,v del q, k ,v
qkv_list.clear() qkv_list.clear()
if padding_length > 0: if padding_length > 0:
o = torch.cat([o, torch.empty( (padding_length, *o.shape[-2:]), dtype= o.dtype, device=o.device ) ], 0) o = torch.cat([o, torch.empty( (padding_length, *o.shape[-2:]), dtype= o.dtype, device=o.device ) ], 0)
return o return o
except ImportError:
sageattn = None
# # try: # # try:
# if True: # if True:

1094
wan/modules/sage2_core.py Normal file

File diff suppressed because it is too large Load Diff