add support for FlashAttention-3 (FA3) on multi-GPU inference

songrui.771 2025-05-26 14:46:33 +08:00
parent e5a741309d
commit 4c35b3fd58


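The commit switches the sequence-parallel attention path to FlashAttention-3 (FA3) kernels when they are available. As a reading aid, the following is a minimal standalone sketch (not part of the commit) of the detection-and-dispatch pattern the diff below introduces. It assumes xfuser and yunchang are installed; run_usp_attention, half, q, k, v and window_size are hypothetical stand-ins for values that usp_attn_forward prepares elsewhere in the module.

    # Minimal sketch of the backend selection added by this commit.
    # Assumes xfuser and yunchang are installed; names below are
    # illustrative stand-ins, not the module's real entry points.
    from yunchang.kernels import AttnType
    from xfuser.core.long_ctx_attention import xFuserLongContextAttention

    try:
        import flash_attn_interface  # FlashAttention-3 bindings
        FLASH_ATTN_3_AVAILABLE = True
    except ModuleNotFoundError:
        FLASH_ATTN_3_AVAILABLE = False

    def run_usp_attention(q, k, v, window_size, half):
        # Prefer FA3 kernels when flash_attn_interface is importable,
        # otherwise fall back to the default long-context attention backend.
        if FLASH_ATTN_3_AVAILABLE:
            attn = xFuserLongContextAttention(attn_type=AttnType.FA3)
        else:
            attn = xFuserLongContextAttention()
        return attn(
            None,
            query=half(q),
            key=half(k),
            value=half(v),
            window_size=window_size)

Because the non-FA3 branch is left unchanged, installations without flash_attn_interface keep using the existing backend with no configuration change.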
@@ -1,6 +1,10 @@
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import torch
import torch.cuda.amp as amp
import yunchang
from yunchang.kernels import AttnType
from xfuser.core.distributed import (
    get_sequence_parallel_rank,
    get_sequence_parallel_world_size,
@@ -10,6 +14,12 @@ from xfuser.core.long_ctx_attention import xFuserLongContextAttention
from ..modules.model import sinusoidal_embedding_1d

try:
    import flash_attn_interface
    FLASH_ATTN_3_AVAILABLE = True
except ModuleNotFoundError:
    FLASH_ATTN_3_AVAILABLE = False


def pad_freqs(original_tensor, target_len):
    seq_len, s1, s2 = original_tensor.shape
@@ -210,6 +220,21 @@ def usp_attn_forward(self,
    # k = torch.cat([u[:l] for u, l in zip(k, k_lens)]).unsqueeze(0)
    # v = torch.cat([u[:l] for u, l in zip(v, k_lens)]).unsqueeze(0)

    if FLASH_ATTN_3_AVAILABLE:
        x = xFuserLongContextAttention(attn_type=AttnType.FA3)(
            None,
            query=half(q),
            key=half(k),
            value=half(v),
            window_size=self.window_size)
    else:
        x = xFuserLongContextAttention()(
            None,
            query=half(q),
            key=half(k),
            value=half(v),
            window_size=self.window_size)
    x = xFuserLongContextAttention()(
        None,
        query=half(q),