Mirror of https://github.com/Wan-Video/Wan2.1.git, synced 2025-07-14 03:30:10 +00:00
add support for FA3 on multi-gpu inference

commit 4c35b3fd58 (parent e5a741309d)
@@ -1,6 +1,10 @@
 # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
 import torch
 import torch.cuda.amp as amp
+
+import yunchang
+from yunchang.kernels import AttnType
+
 from xfuser.core.distributed import (
     get_sequence_parallel_rank,
     get_sequence_parallel_world_size,
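
The two added imports come from yunchang, the long-context attention library that xfuser builds on; its AttnType enum names the attention kernel that xFuserLongContextAttention should dispatch to. A minimal sketch, assuming yunchang is installed (only AttnType.FA3 is confirmed by this diff; other enum members vary by version):

    from yunchang.kernels import AttnType

    # AttnType.FA3 selects the FlashAttention-3 kernel in the hunks below;
    # constructing xFuserLongContextAttention without attn_type keeps the default.
    print(AttnType.FA3)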
@@ -10,6 +14,12 @@ from xfuser.core.long_ctx_attention import xFuserLongContextAttention
 
 from ..modules.model import sinusoidal_embedding_1d
 
+try:
+    import flash_attn_interface
+    FLASH_ATTN_3_AVAILABLE = True
+except ModuleNotFoundError:
+    FLASH_ATTN_3_AVAILABLE = False
+
 
 def pad_freqs(original_tensor, target_len):
     seq_len, s1, s2 = original_tensor.shape
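
This hunk is a standard optional-dependency probe: FlashAttention-3 lives in the separate flash_attn_interface package, so attempting the import inside try/except lets the module load on machines that only have FlashAttention-2. A slightly more defensive variant, shown as an illustrative sketch rather than what the commit does, catches ImportError (the superclass of ModuleNotFoundError), so a present-but-broken FA3 build also falls back cleanly:

    # Probe for FlashAttention-3 without making it a hard dependency.
    try:
        import flash_attn_interface  # noqa: F401 -- the import itself is the capability check
        FLASH_ATTN_3_AVAILABLE = True
    except ImportError:  # also catches installs that import but fail to load
        FLASH_ATTN_3_AVAILABLE = False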
@@ -210,6 +220,21 @@ def usp_attn_forward(self,
     # k = torch.cat([u[:l] for u, l in zip(k, k_lens)]).unsqueeze(0)
     # v = torch.cat([u[:l] for u, l in zip(v, k_lens)]).unsqueeze(0)
 
+    if FLASH_ATTN_3_AVAILABLE:
+        x = xFuserLongContextAttention(attn_type=AttnType.FA3)(
+            None,
+            query=half(q),
+            key=half(k),
+            value=half(v),
+            window_size=self.window_size)
+    else:
+        x = xFuserLongContextAttention()(
+            None,
+            query=half(q),
+            key=half(k),
+            value=half(v),
+            window_size=self.window_size)
+
     x = xFuserLongContextAttention()(
         None,
         query=half(q),
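
The new branch routes attention through FlashAttention-3 whenever the probe above succeeded, and otherwise constructs xFuserLongContextAttention exactly as before. Factored out of usp_attn_forward, the selection logic amounts to the following (a minimal sketch assuming xfuser and yunchang are installed; make_usp_attention is a hypothetical helper, not part of the commit):

    from yunchang.kernels import AttnType
    from xfuser.core.long_ctx_attention import xFuserLongContextAttention

    try:
        import flash_attn_interface  # noqa: F401
        FLASH_ATTN_3_AVAILABLE = True
    except ModuleNotFoundError:
        FLASH_ATTN_3_AVAILABLE = False

    def make_usp_attention():
        # Prefer the FlashAttention-3 kernel when its package is importable;
        # the default construction preserves the pre-commit (FA2) behavior.
        if FLASH_ATTN_3_AVAILABLE:
            return xFuserLongContextAttention(attn_type=AttnType.FA3)
        return xFuserLongContextAttention()

Because the two branches differ only in the constructor argument, both paths take identical query/key/value/window_size arguments at the call site.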