From 4c35b3fd589b2018e7897665230535b84cb47557 Mon Sep 17 00:00:00 2001
From: "songrui.771"
Date: Mon, 26 May 2025 14:46:33 +0800
Subject: [PATCH] add support for FA3 on multi-gpu inference

---
 wan/distributed/xdit_context_parallel.py | 25 +++++++++++++++++++++++++
 1 file changed, 25 insertions(+)

diff --git a/wan/distributed/xdit_context_parallel.py b/wan/distributed/xdit_context_parallel.py
index 4718577..b5c7e27 100644
--- a/wan/distributed/xdit_context_parallel.py
+++ b/wan/distributed/xdit_context_parallel.py
@@ -1,6 +1,10 @@
 # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
 import torch
 import torch.cuda.amp as amp
+
+import yunchang
+from yunchang.kernels import AttnType
+
 from xfuser.core.distributed import (
     get_sequence_parallel_rank,
     get_sequence_parallel_world_size,
@@ -10,6 +14,12 @@ from xfuser.core.long_ctx_attention import xFuserLongContextAttention
 
 from ..modules.model import sinusoidal_embedding_1d
 
+try:
+    import flash_attn_interface
+    FLASH_ATTN_3_AVAILABLE = True
+except ModuleNotFoundError:
+    FLASH_ATTN_3_AVAILABLE = False
+
 
 def pad_freqs(original_tensor, target_len):
     seq_len, s1, s2 = original_tensor.shape
@@ -210,6 +220,21 @@ def usp_attn_forward(self,
     # k = torch.cat([u[:l] for u, l in zip(k, k_lens)]).unsqueeze(0)
     # v = torch.cat([u[:l] for u, l in zip(v, k_lens)]).unsqueeze(0)
 
+    if FLASH_ATTN_3_AVAILABLE:
+        x = xFuserLongContextAttention(attn_type=AttnType.FA3)(
+            None,
+            query=half(q),
+            key=half(k),
+            value=half(v),
+            window_size=self.window_size)
+    else:
+        x = xFuserLongContextAttention()(
+            None,
+            query=half(q),
+            key=half(k),
+            value=half(v),
+            window_size=self.window_size)
+
     x = xFuserLongContextAttention()(
         None,
         query=half(q),
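
Note (not part of the patch): a minimal standalone sketch of the backend gate the hunks above introduce. It assumes, as the diff does, that the FlashAttention-3 build exposes a `flash_attn_interface` module and that xfuser's `xFuserLongContextAttention` accepts `attn_type=AttnType.FA3` from yunchang.

    # Sketch of the FA3 gating pattern used in this patch (assumptions as noted above).
    from xfuser.core.long_ctx_attention import xFuserLongContextAttention
    from yunchang.kernels import AttnType

    # Probe for the FlashAttention-3 Python interface at import time.
    try:
        import flash_attn_interface
        FLASH_ATTN_3_AVAILABLE = True
    except ModuleNotFoundError:
        FLASH_ATTN_3_AVAILABLE = False

    # Prefer the FA3 kernel when the module is importable; otherwise fall back
    # to the default long-context attention path. Both are called the same way.
    if FLASH_ATTN_3_AVAILABLE:
        attn = xFuserLongContextAttention(attn_type=AttnType.FA3)
    else:
        attn = xFuserLongContextAttention()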