diff --git a/wan/distributed/xdit_context_parallel.py b/wan/distributed/xdit_context_parallel.py
index cf13c9e..2df66c4 100644
--- a/wan/distributed/xdit_context_parallel.py
+++ b/wan/distributed/xdit_context_parallel.py
@@ -1,8 +1,11 @@
 # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
 import torch
 import torch.cuda.amp as amp
+
 import numpy as np
 import logging
+import yunchang
+from yunchang.kernels import AttnType
 
 from xfuser.core.distributed import (
     get_sequence_parallel_rank,
@@ -13,6 +16,12 @@ from xfuser.core.long_ctx_attention import xFuserLongContextAttention
 
 from ..modules.model import sinusoidal_embedding_1d
 
+try:
+    import flash_attn_interface
+    FLASH_ATTN_3_AVAILABLE = True
+except ModuleNotFoundError:
+    FLASH_ATTN_3_AVAILABLE = False
+
 
 def pad_freqs(original_tensor, target_len):
     seq_len, s1, s2 = original_tensor.shape
@@ -280,6 +289,21 @@ def usp_attn_forward(self,
     #     k = torch.cat([u[:l] for u, l in zip(k, k_lens)]).unsqueeze(0)
     #     v = torch.cat([u[:l] for u, l in zip(v, k_lens)]).unsqueeze(0)
 
+    if FLASH_ATTN_3_AVAILABLE:
+        x = xFuserLongContextAttention(attn_type=AttnType.FA3)(
+            None,
+            query=half(q),
+            key=half(k),
+            value=half(v),
+            window_size=self.window_size)
+    else:
+        x = xFuserLongContextAttention()(
+            None,
+            query=half(q),
+            key=half(k),
+            value=half(v),
+            window_size=self.window_size)
+
     x = xFuserLongContextAttention()(
         None,
         query=half(q),
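
For reference, the pattern the patch introduces is plain feature detection: probe for the FlashAttention-3 module at import time, then select the attention backend accordingly. The sketch below restates that dispatch in isolation, using only names that appear in the diff (flash_attn_interface, xFuserLongContextAttention, AttnType.FA3); the helper build_usp_attention is hypothetical and illustrates only the selection logic, not the full usp_attn_forward.

    # Minimal sketch of the availability check + backend dispatch added above.
    # Assumes xfuser and yunchang are installed; falls back to the default
    # kernel when flash_attn_interface cannot be imported.
    try:
        import flash_attn_interface  # FlashAttention-3 bindings
        FLASH_ATTN_3_AVAILABLE = True
    except ModuleNotFoundError:
        FLASH_ATTN_3_AVAILABLE = False


    def build_usp_attention():
        """Return a long-context attention op, preferring the FA3 kernel."""
        from xfuser.core.long_ctx_attention import xFuserLongContextAttention

        if FLASH_ATTN_3_AVAILABLE:
            from yunchang.kernels import AttnType
            return xFuserLongContextAttention(attn_type=AttnType.FA3)
        return xFuserLongContextAttention()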