From 1d0c2ae6c414f2d18e65c67ee997fd8202fc273c Mon Sep 17 00:00:00 2001
From: artetaout
Date: Mon, 9 Jun 2025 08:56:48 +0000
Subject: [PATCH] feat: support sage for Wan generate..

---
 wan/distributed/xdit_context_parallel.py | 27 +++++++++++++++++-------
 wan/utils/utils.py                       |  7 ++++++
 2 files changed, 26 insertions(+), 8 deletions(-)

diff --git a/wan/distributed/xdit_context_parallel.py b/wan/distributed/xdit_context_parallel.py
index 4718577..36aec4b 100644
--- a/wan/distributed/xdit_context_parallel.py
+++ b/wan/distributed/xdit_context_parallel.py
@@ -9,7 +9,11 @@ from xfuser.core.distributed import (
 from xfuser.core.long_ctx_attention import xFuserLongContextAttention
 
 from ..modules.model import sinusoidal_embedding_1d
-
+import wan.utils.utils as wan_utils
+try:
+    from yunchang.kernels import AttnType
+except ImportError:
+    raise ImportError("Please install yunchang 0.6.0 or later")
 
 def pad_freqs(original_tensor, target_len):
     seq_len, s1, s2 = original_tensor.shape
@@ -209,13 +213,20 @@ def usp_attn_forward(self,
     # q = torch.cat([u[:l] for u, l in zip(q, k_lens)]).unsqueeze(0)
     # k = torch.cat([u[:l] for u, l in zip(k, k_lens)]).unsqueeze(0)
     # v = torch.cat([u[:l] for u, l in zip(v, k_lens)]).unsqueeze(0)
-
-    x = xFuserLongContextAttention()(
-        None,
-        query=half(q),
-        key=half(k),
-        value=half(v),
-        window_size=self.window_size)
+    if wan_utils.ENABLE_SAGE_ATTENTION:
+        x = xFuserLongContextAttention(attn_type=AttnType.SAGE_FP8_SM90)(
+            None,
+            query=half(q),
+            key=half(k),
+            value=half(v),
+            window_size=self.window_size)
+    else:
+        x = xFuserLongContextAttention(attn_type=AttnType.FA3)(
+            None,
+            query=half(q),
+            key=half(k),
+            value=half(v),
+            window_size=self.window_size)
 
     # TODO: padding after attention.
     # x = torch.cat([x, x.new_zeros(b, s - x.size(1), n, d)], dim=1)
diff --git a/wan/utils/utils.py b/wan/utils/utils.py
index d725999..50a85e2 100644
--- a/wan/utils/utils.py
+++ b/wan/utils/utils.py
@@ -10,6 +10,13 @@ import torchvision
 
 __all__ = ['cache_video', 'cache_image', 'str2bool']
 
+try:
+    import sageattention
+    HAS_SAGE_ATTENTION = True
+except ImportError:
+    HAS_SAGE_ATTENTION = False
+
+ENABLE_SAGE_ATTENTION = os.environ.get('ENABLE_SAGE_ATTENTION', '0') == '1' and HAS_SAGE_ATTENTION
 
 def rand_name(length=8, suffix=''):
     name = binascii.b2a_hex(os.urandom(length)).decode('utf-8')
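
Note (commentary only, not part of the applied patch): with this change, SageAttention is opt-in. The gate in wan/utils/utils.py is evaluated once at module import time, so ENABLE_SAGE_ATTENTION must be set in the environment before wan.utils.utils is first imported, and it only resolves to True when the `sageattention` package imports cleanly; the distributed attention path additionally needs yunchang 0.6.0 or later for AttnType. When the gate is off, the code now explicitly requests AttnType.FA3 instead of the previous default. Below is a minimal Python sketch of how the gate behaves, assuming the patched repo is on PYTHONPATH; the print call is purely illustrative.

    # Illustrative only: show how the import-time gate resolves.
    import os

    # Opt in before wan.utils.utils is imported; the default "0" keeps the FA3 path.
    os.environ["ENABLE_SAGE_ATTENTION"] = "1"

    import wan.utils.utils as wan_utils

    # True only if ENABLE_SAGE_ATTENTION=1 was set above AND `sageattention` is installed.
    print(wan_utils.ENABLE_SAGE_ATTENTION)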