Merge pull request #3 from songruizz/feat/xdit-fa3

Add support for Flash Attention 3 for multi-GPU inference.
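For context, the change boils down to the dispatch sketched below: probe for the Flash Attention 3 Python interface at import time and, when it is present, build the xfuser long-context attention with the FA3 kernel type; otherwise keep the previous default backend. This is a minimal sketch, not repo code: the run_usp_attention helper, its arguments, and the (-1, -1) window default are illustrative, and it assumes the xfuser sequence-parallel environment has already been initialized. The class and enum names (xFuserLongContextAttention, AttnType.FA3) are taken from the diff below.

# Minimal sketch of the FA3 dispatch this PR introduces (helper name is illustrative).
from yunchang.kernels import AttnType
from xfuser.core.long_ctx_attention import xFuserLongContextAttention

try:
    import flash_attn_interface  # ships with Flash Attention 3
    FLASH_ATTN_3_AVAILABLE = True
except ModuleNotFoundError:
    FLASH_ATTN_3_AVAILABLE = False

def run_usp_attention(q, k, v, window_size=(-1, -1)):
    # Prefer the FA3 kernel when its interface imports cleanly; otherwise fall
    # back to the default long-context attention used before this change.
    if FLASH_ATTN_3_AVAILABLE:
        attn = xFuserLongContextAttention(attn_type=AttnType.FA3)
    else:
        attn = xFuserLongContextAttention()
    return attn(None, query=q, key=k, value=v, window_size=window_size)

Probing the import rather than adding a config flag keeps the non-FA3 path byte-for-byte identical to the existing behavior on machines where FA3 is not installed.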
Commit cce8316296 by zc8gerard, 2025-05-26 21:11:34 +08:00 (committed via GitHub)


@@ -1,6 +1,10 @@
 # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
 import torch
 import torch.cuda.amp as amp
+import yunchang
+from yunchang.kernels import AttnType
 from xfuser.core.distributed import (
     get_sequence_parallel_rank,
     get_sequence_parallel_world_size,
@@ -10,6 +14,12 @@ from xfuser.core.long_ctx_attention import xFuserLongContextAttention
 from ..modules.model import sinusoidal_embedding_1d
 
+try:
+    import flash_attn_interface
+    FLASH_ATTN_3_AVAILABLE = True
+except ModuleNotFoundError:
+    FLASH_ATTN_3_AVAILABLE = False
+
 
 def pad_freqs(original_tensor, target_len):
     seq_len, s1, s2 = original_tensor.shape
@@ -210,6 +220,21 @@ def usp_attn_forward(self,
     # k = torch.cat([u[:l] for u, l in zip(k, k_lens)]).unsqueeze(0)
     # v = torch.cat([u[:l] for u, l in zip(v, k_lens)]).unsqueeze(0)
+    if FLASH_ATTN_3_AVAILABLE:
+        x = xFuserLongContextAttention(attn_type=AttnType.FA3)(
+            None,
+            query=half(q),
+            key=half(k),
+            value=half(v),
+            window_size=self.window_size)
+    else:
+        x = xFuserLongContextAttention()(
+            None,
+            query=half(q),
+            key=half(k),
+            value=half(v),
+            window_size=self.window_size)
-    x = xFuserLongContextAttention()(
-        None,
-        query=half(q),