mirror of
https://github.com/Wan-Video/Wan2.1.git
synced 2025-07-13 11:10:11 +00:00
feat: support SageAttention for Wan generation.
This commit is contained in:
parent
827906c30f
commit
1d0c2ae6c4
@ -9,7 +9,11 @@ from xfuser.core.distributed import (
|
||||
from xfuser.core.long_ctx_attention import xFuserLongContextAttention
|
||||
|
||||
from ..modules.model import sinusoidal_embedding_1d
|
||||
|
||||
import wan.utils.utils as wan_utils
|
||||
# AttnType is used further down in this file to pick the attention kernel
# passed to xFuserLongContextAttention. Per the error message below, it is
# expected to be provided by yunchang 0.6.0 or later.
try:
    from yunchang.kernels import AttnType
except ImportError as exc:
    # Chain the underlying failure explicitly so the real cause (package
    # missing vs. package too old to expose AttnType) stays in the traceback.
    raise ImportError("Please install yunchang 0.6.0 or later") from exc
|
||||
|
||||
def pad_freqs(original_tensor, target_len):
|
||||
seq_len, s1, s2 = original_tensor.shape
|
||||
@ -209,13 +213,20 @@ def usp_attn_forward(self,
|
||||
# q = torch.cat([u[:l] for u, l in zip(q, k_lens)]).unsqueeze(0)
|
||||
# k = torch.cat([u[:l] for u, l in zip(k, k_lens)]).unsqueeze(0)
|
||||
# v = torch.cat([u[:l] for u, l in zip(v, k_lens)]).unsqueeze(0)
|
||||
|
||||
x = xFuserLongContextAttention()(
|
||||
None,
|
||||
query=half(q),
|
||||
key=half(k),
|
||||
value=half(v),
|
||||
window_size=self.window_size)
|
||||
if wan_utils.ENABLE_SAGE_ATTENTION:
|
||||
x = xFuserLongContextAttention(attn_type=AttnType.SAGE_FP8_SM90)(
|
||||
None,
|
||||
query=half(q),
|
||||
key=half(k),
|
||||
value=half(v),
|
||||
window_size=self.window_size)
|
||||
else:
|
||||
x = xFuserLongContextAttention(attn_type=AttnType.FA3)(
|
||||
None,
|
||||
query=half(q),
|
||||
key=half(k),
|
||||
value=half(v),
|
||||
window_size=self.window_size)
|
||||
|
||||
# TODO: padding after attention.
|
||||
# x = torch.cat([x, x.new_zeros(b, s - x.size(1), n, d)], dim=1)
|
||||
|
@ -10,6 +10,13 @@ import torchvision
|
||||
|
||||
__all__ = ['cache_video', 'cache_image', 'str2bool']
|
||||
|
||||
# Probe whether the optional `sageattention` package can be imported. The
# module object itself is not used here; only its availability is recorded.
try:
    import sageattention  # noqa: F401  (availability probe only)
except ImportError:
    HAS_SAGE_ATTENTION = False
else:
    HAS_SAGE_ATTENTION = True

# SageAttention is opt-in: the environment variable ENABLE_SAGE_ATTENTION must
# be exactly '1' (unset defaults to '0'), and the package must be importable.
_SAGE_ATTENTION_REQUESTED = os.environ.get('ENABLE_SAGE_ATTENTION', '0') == '1'
ENABLE_SAGE_ATTENTION = _SAGE_ATTENTION_REQUESTED and HAS_SAGE_ATTENTION

if _SAGE_ATTENTION_REQUESTED and not HAS_SAGE_ATTENTION:
    # Do not drop the user's opt-in silently: report why the flag ended up
    # False instead of leaving them to guess.
    import warnings

    warnings.warn(
        "ENABLE_SAGE_ATTENTION=1 was set but the 'sageattention' package "
        "could not be imported; SageAttention stays disabled.",
        RuntimeWarning,
    )
|
||||
|
||||
def rand_name(length=8, suffix=''):
|
||||
name = binascii.b2a_hex(os.urandom(length)).decode('utf-8')
|
||||
|
Loading…
Reference in New Issue
Block a user