Merge branch 'release_0.0.3' into mul_device_teacache

kousakawang 2025-05-26 21:58:25 +08:00 committed by GitHub
commit ed01a1e721

@@ -1,8 +1,11 @@
 # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
 import torch
 import torch.cuda.amp as amp
 import numpy as np
 import logging
+import yunchang
+from yunchang.kernels import AttnType
 from xfuser.core.distributed import (
     get_sequence_parallel_rank,
@@ -13,6 +16,12 @@ from xfuser.core.long_ctx_attention import xFuserLongContextAttention
 from ..modules.model import sinusoidal_embedding_1d
 
+try:
+    import flash_attn_interface
+    FLASH_ATTN_3_AVAILABLE = True
+except ModuleNotFoundError:
+    FLASH_ATTN_3_AVAILABLE = False
+
 def pad_freqs(original_tensor, target_len):
     seq_len, s1, s2 = original_tensor.shape
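
The hunk above uses the standard optional-dependency probe: import the FlashAttention-3 bindings once at module load and cache the result in a module-level flag, so call sites can branch without re-importing. A minimal self-contained sketch of the idiom, shown outside the diff:

    # Probe for the FA3 bindings once, at import time.
    try:
        import flash_attn_interface  # noqa: F401 -- only import success matters here
        FLASH_ATTN_3_AVAILABLE = True
    except ModuleNotFoundError:
        FLASH_ATTN_3_AVAILABLE = False
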
@@ -280,6 +289,21 @@ def usp_attn_forward(self,
     # k = torch.cat([u[:l] for u, l in zip(k, k_lens)]).unsqueeze(0)
     # v = torch.cat([u[:l] for u, l in zip(v, k_lens)]).unsqueeze(0)
+    if FLASH_ATTN_3_AVAILABLE:
+        x = xFuserLongContextAttention(attn_type=AttnType.FA3)(
+            None,
+            query=half(q),
+            key=half(k),
+            value=half(v),
+            window_size=self.window_size)
+    else:
+        x = xFuserLongContextAttention()(
+            None,
+            query=half(q),
+            key=half(k),
+            value=half(v),
+            window_size=self.window_size)
     x = xFuserLongContextAttention()(
         None,
         query=half(q),
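
Taken out of the diff context, the new branch simply selects the attention backend at call time. A minimal sketch, assuming the imports and flag from the earlier hunks; select_usp_attention is a hypothetical helper added for illustration only, not part of the commit or the xfuser API:

    from xfuser.core.long_ctx_attention import xFuserLongContextAttention
    from yunchang.kernels import AttnType

    def select_usp_attention(flash_attn_3_available: bool):
        # Prefer the FlashAttention-3 kernel when its bindings imported cleanly;
        # otherwise fall back to xfuser's default long-context attention backend.
        if flash_attn_3_available:
            return xFuserLongContextAttention(attn_type=AttnType.FA3)
        return xFuserLongContextAttention()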