mirror of
https://github.com/Wan-Video/Wan2.1.git
synced 2025-11-05 06:29:14 +00:00
This UI color is the good one + slightly reduced VRAM when using Sage2 attention
This commit is contained in:
parent
1406a44a69
commit
f2163e0984
@ -1595,7 +1595,7 @@ def create_demo():
|
||||
}
|
||||
"""
|
||||
default_flow_shift = get_default_flow(transformer_filename_i2v if use_image2video else transformer_filename_t2v)
|
||||
with gr.Blocks(css=css, theme=gr.themes.Soft(primary_hue="stone", neutral_hue="slate", text_size= "md")) as demo:
|
||||
with gr.Blocks(css=css, theme=gr.themes.Soft(primary_hue="sky", neutral_hue="slate", text_size= "md")) as demo:
|
||||
state_dict = {}
|
||||
|
||||
if use_image2video:
|
||||
|
||||
@ -38,8 +38,13 @@ import warnings
|
||||
|
||||
try:
|
||||
from sageattention import sageattn
|
||||
@torch.compiler.disable()
|
||||
def sageattn_wrapper(
|
||||
from .sage2_core import sageattn as alt_sageattn
|
||||
except ImportError:
|
||||
sageattn = None
|
||||
alt_sageattn = None
|
||||
|
||||
# @torch.compiler.disable()
|
||||
def sageattn_wrapper(
|
||||
qkv_list,
|
||||
attention_length
|
||||
):
|
||||
@ -48,17 +53,20 @@ try:
|
||||
q = q[:attention_length, :, : ].unsqueeze(0)
|
||||
k = k[:attention_length, :, : ].unsqueeze(0)
|
||||
v = v[:attention_length, :, : ].unsqueeze(0)
|
||||
|
||||
if True:
|
||||
qkv_list = [q,k,v]
|
||||
del q, k ,v
|
||||
o = alt_sageattn(qkv_list, tensor_layout="NHD").squeeze(0)
|
||||
else:
|
||||
o = sageattn(q, k, v, tensor_layout="NHD").squeeze(0)
|
||||
del q, k ,v
|
||||
|
||||
qkv_list.clear()
|
||||
|
||||
if padding_length > 0:
|
||||
o = torch.cat([o, torch.empty( (padding_length, *o.shape[-2:]), dtype= o.dtype, device=o.device ) ], 0)
|
||||
|
||||
return o
|
||||
except ImportError:
|
||||
sageattn = None
|
||||
|
||||
# # try:
|
||||
# if True:
|
||||
|
||||
1094
wan/modules/sage2_core.py
Normal file
1094
wan/modules/sage2_core.py
Normal file
File diff suppressed because it is too large
Load Diff
Loading…
Reference in New Issue
Block a user