feat: support SageAttention for Wan generation

artetaout 2025-06-09 08:56:48 +00:00
parent 827906c30f
commit 1d0c2ae6c4
2 changed files with 26 additions and 8 deletions

View File

@@ -9,7 +9,11 @@ from xfuser.core.distributed import (
from xfuser.core.long_ctx_attention import xFuserLongContextAttention
from ..modules.model import sinusoidal_embedding_1d
import wan.utils.utils as wan_utils
+try:
+    from yunchang.kernels import AttnType
+except ImportError:
+    raise ImportError("Please install yunchang 0.6.0 or later")
def pad_freqs(original_tensor, target_len):
    seq_len, s1, s2 = original_tensor.shape
@@ -209,13 +213,20 @@ def usp_attn_forward(self,
# q = torch.cat([u[:l] for u, l in zip(q, k_lens)]).unsqueeze(0)
# k = torch.cat([u[:l] for u, l in zip(k, k_lens)]).unsqueeze(0)
# v = torch.cat([u[:l] for u, l in zip(v, k_lens)]).unsqueeze(0)
-x = xFuserLongContextAttention()(
-    None,
-    query=half(q),
-    key=half(k),
-    value=half(v),
-    window_size=self.window_size)
+if wan_utils.ENABLE_SAGE_ATTENTION:
+    x = xFuserLongContextAttention(attn_type=AttnType.SAGE_FP8_SM90)(
+        None,
+        query=half(q),
+        key=half(k),
+        value=half(v),
+        window_size=self.window_size)
+else:
+    x = xFuserLongContextAttention(attn_type=AttnType.FA3)(
+        None,
+        query=half(q),
+        key=half(k),
+        value=half(v),
+        window_size=self.window_size)
# TODO: padding after attention.
# x = torch.cat([x, x.new_zeros(b, s - x.size(1), n, d)], dim=1)
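
For reference, the branch added above amounts to selecting one of yunchang's AttnType backends up front. The sketch below is illustrative only and uses just the names that appear in this diff; the helper name build_usp_attention is hypothetical and not part of the commit.

    # Illustrative sketch, not part of this commit: choose the long-context
    # attention backend once, based on the ENABLE_SAGE_ATTENTION flag from
    # wan.utils.utils.
    import wan.utils.utils as wan_utils
    from xfuser.core.long_ctx_attention import xFuserLongContextAttention
    from yunchang.kernels import AttnType

    def build_usp_attention():
        # SAGE_FP8_SM90 selects SageAttention's FP8 kernel for SM90 (Hopper) GPUs;
        # FA3 selects the FlashAttention-3 backend used in the non-Sage branch.
        attn_type = AttnType.SAGE_FP8_SM90 if wan_utils.ENABLE_SAGE_ATTENTION else AttnType.FA3
        return xFuserLongContextAttention(attn_type=attn_type)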

View File

@@ -10,6 +10,13 @@ import torchvision
__all__ = ['cache_video', 'cache_image', 'str2bool']
+try:
+    import sageattention
+    HAS_SAGE_ATTENTION = True
+except ImportError:
+    HAS_SAGE_ATTENTION = False
+
+ENABLE_SAGE_ATTENTION = os.environ.get('ENABLE_SAGE_ATTENTION', '0') == '1' and HAS_SAGE_ATTENTION
def rand_name(length=8, suffix=''):
    name = binascii.b2a_hex(os.urandom(length)).decode('utf-8')
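
Note that ENABLE_SAGE_ATTENTION is computed once at import time, so the environment variable must be set (and the sageattention package importable) before wan.utils.utils is first imported. A minimal usage sketch, not part of the diff:

    # Illustrative sketch, not part of this commit: opt in to SageAttention
    # before wan.utils.utils is imported, since the flag is read at import time.
    import os
    os.environ["ENABLE_SAGE_ATTENTION"] = "1"

    import wan.utils.utils as wan_utils
    # True only when the variable is '1' and `import sageattention` succeeded.
    print(wan_utils.ENABLE_SAGE_ATTENTION)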