Merge branch 'release_0.0.3' into mul_device_teacache

This commit is contained in:
kousakawang 2025-05-26 21:58:25 +08:00 committed by GitHub
commit ed01a1e721
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -1,8 +1,11 @@
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import torch
import torch.cuda.amp as amp
import numpy as np
import logging
import yunchang
from yunchang.kernels import AttnType
from xfuser.core.distributed import (
get_sequence_parallel_rank,
@ -13,6 +16,12 @@ from xfuser.core.long_ctx_attention import xFuserLongContextAttention
from ..modules.model import sinusoidal_embedding_1d
try:
import flash_attn_interface
FLASH_ATTN_3_AVAILABLE = True
except ModuleNotFoundError:
FLASH_ATTN_3_AVAILABLE = False
def pad_freqs(original_tensor, target_len):
seq_len, s1, s2 = original_tensor.shape
@ -280,6 +289,21 @@ def usp_attn_forward(self,
# k = torch.cat([u[:l] for u, l in zip(k, k_lens)]).unsqueeze(0)
# v = torch.cat([u[:l] for u, l in zip(v, k_lens)]).unsqueeze(0)
if FLASH_ATTN_3_AVAILABLE:
x = xFuserLongContextAttention(attn_type=AttnType.FA3)(
None,
query=half(q),
key=half(k),
value=half(v),
window_size=self.window_size)
else:
x = xFuserLongContextAttention()(
None,
query=half(q),
key=half(k),
value=half(v),
window_size=self.window_size)
x = xFuserLongContextAttention()(
None,
query=half(q),