mirror of
https://github.com/Wan-Video/Wan2.1.git
synced 2025-07-13 19:20:09 +00:00
236 lines
6.9 KiB
Python
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
|
|
import torch
|
|
import math
|
|
|
|
# Probe the optional FlashAttention backends. ImportError (the parent class of
# ModuleNotFoundError) is caught so that a package that is installed but fails
# to import still leaves the flag False, letting the wrappers below fall back
# to manual attention instead of breaking the import of this module.
try:
    import flash_attn_interface
    FLASH_ATTN_3_AVAILABLE = True
except ImportError:
    FLASH_ATTN_3_AVAILABLE = False

try:
    import flash_attn
    FLASH_ATTN_2_AVAILABLE = True
except ImportError:
    FLASH_ATTN_2_AVAILABLE = False
|
|
|
|
import warnings
|
|
|
|
# Names exported by ``from <this module> import *``.
__all__ = ["flash_attention", "attention"]
|
|
|
|
# Module-wide switch read by log_debug(); when False, log_debug() prints nothing.
DEBUG_ATTENTION = True


def log_debug(message):
    """Print ``message`` to stdout with a ``[DEBUG]`` prefix when DEBUG_ATTENTION is set."""
    if not DEBUG_ATTENTION:
        return
    print(f"[DEBUG] {message}")
|
|
|
|
|
|
def _attention_allow_mask(batch, len_q, len_k, q_lens, k_lens, causal,
                          window_size, device):
    """Build the boolean visibility mask used by ``manual_attention``.

    Returns a ``[batch, 1, len_q, len_k]`` bool tensor that is True where a
    query may attend to a key. Padded keys (index >= ``k_lens[b]``) are hidden.
    Causal and sliding-window constraints follow the bottom-right alignment
    documented for ``flash_attn_varlen_func``: within each sequence, query
    ``i`` may see keys in ``[i + k_len - q_len - left, i + k_len - q_len + right]``,
    where ``-1`` means "unbounded" on that side and ``causal`` caps the upper
    end at ``i + k_len - q_len``.
    """
    key_pos = torch.arange(len_k, device=device)
    allowed = torch.ones(batch, len_q, len_k, dtype=torch.bool, device=device)

    if k_lens is not None:
        # [B, Lk] validity, broadcast over the query axis.
        allowed &= (key_pos[None, :] < k_lens[:, None])[:, None, :]

    left, right = window_size
    if causal or left >= 0 or right >= 0:
        q_valid = q_lens if q_lens is not None else torch.full(
            (batch,), len_q, device=device)
        k_valid = k_lens if k_lens is not None else torch.full(
            (batch,), len_k, device=device)
        # Per-sequence diagonal shift that aligns the mask bottom-right.
        shift = (k_valid - q_valid).view(batch, 1, 1)
        query_pos = torch.arange(len_q, device=device).view(1, len_q, 1)
        # rel == 0 on the aligned diagonal; rel > 0 means the key lies after it.
        rel = key_pos.view(1, 1, len_k) - query_pos - shift
        if causal:
            allowed &= rel <= 0
        if left >= 0:
            allowed &= rel >= -left
        if right >= 0:
            allowed &= rel <= right

    return allowed.unsqueeze(1)


def manual_attention(
    q,
    k,
    v,
    q_lens=None,
    k_lens=None,
    dropout_p=0.,
    softmax_scale=None,
    q_scale=None,
    causal=False,
    window_size=(-1, -1),
    deterministic=False,
    dtype=torch.bfloat16,
):
    """Plain (non-fused) scaled dot-product attention usable on any device.

    Used as the fallback for ``flash_attention``, so the causal / window
    masking follows the conventions of ``flash_attn_varlen_func``.

    Args:
        q: Query tensor ``[B, Lq, N, C]``.
        k: Key tensor ``[B, Lk, N, C]``; moved to ``q``'s device.
        v: Value tensor, expected ``[B, Lk, N, Cv]``; moved to ``q``'s device.
        q_lens: Optional ``[B]`` integer tensor of valid query lengths; output
            rows at or beyond ``q_lens[b]`` are zeroed.
        k_lens: Optional ``[B]`` integer tensor of valid key lengths; keys at or
            beyond ``k_lens[b]`` are masked out.
        dropout_p: Dropout probability applied to the attention weights.
        softmax_scale: Multiplier for ``q @ k^T``; ``None`` means ``1/sqrt(C)``.
        q_scale: Optional factor multiplied into ``q`` after the dtype cast:
            a number, a one-element tensor, or a tensor with one value per head.
        causal: Apply a causal mask (bottom-right aligned per sequence).
        window_size: ``(left, right)`` sliding-window extent; ``-1`` leaves that
            side unbounded, ``(-1, -1)`` disables the window.
        deterministic: When True, dropout is skipped. NOTE(review): the
            FlashAttention 2 path of ``flash_attention`` forwards ``dropout_p``
            regardless of this flag.
        dtype: Compute dtype; q, k and v are cast to it before the matmuls.

    Returns:
        Tensor ``[B, Lq, N, Cv]`` in ``q``'s original dtype. Query rows that can
        see no key at all (e.g. ``k_lens[b] == 0``) are zero instead of NaN.
    """
    device = q.device
    k = k.to(device)
    v = v.to(device)
    if q_lens is not None:
        q_lens = q_lens.to(device)
    if k_lens is not None:
        k_lens = k_lens.to(device)

    batch, len_q, _, head_dim = q.shape
    _, len_k, _, _ = k.shape
    original_dtype = q.dtype

    # [B, L, N, C] -> [B, N, L, C] so the matmuls batch over (batch, head).
    q = q.to(dtype).transpose(1, 2)
    k = k.to(dtype).transpose(1, 2)
    v = v.to(dtype).transpose(1, 2)

    if q_scale is not None:
        # Cast to the compute dtype so a float32 scale cannot promote q and
        # leave it with a different dtype from k in the matmul below.
        q_scale = torch.as_tensor(q_scale, dtype=q.dtype, device=device)
        if q_scale.numel() > 1:
            q_scale = q_scale.view(1, -1, 1, 1)  # one value per head
        q = q * q_scale

    # Explicit None check: an explicit softmax_scale of 0.0 must be honoured.
    scale = 1.0 / math.sqrt(head_dim) if softmax_scale is None else softmax_scale
    scores = torch.matmul(q, k.transpose(-2, -1)) * scale

    allowed = _attention_allow_mask(batch, len_q, len_k, q_lens, k_lens,
                                    causal, window_size, device)
    scores = scores.masked_fill(~allowed, float('-inf'))
    weights = torch.softmax(scores, dim=-1)
    # A row whose keys are all masked is softmax(-inf, ..., -inf) = NaN;
    # replace it with zeros so it contributes a zero output row.
    weights = weights.masked_fill(~allowed.any(dim=-1, keepdim=True), 0.0)

    if not deterministic and dropout_p > 0:
        weights = torch.dropout(weights, dropout_p, True)

    out = torch.matmul(weights, v)

    if q_lens is not None:
        valid_q = torch.arange(len_q, device=device)[None, :] < q_lens[:, None]
        out = out.masked_fill(~valid_q.view(batch, 1, len_q, 1), 0.0)

    # Back to [B, Lq, N, Cv] in the caller's dtype.
    return out.transpose(1, 2).contiguous().to(original_dtype)
|
|
|
|
|
|
def flash_attention(
    q,
    k,
    v,
    q_lens=None,
    k_lens=None,
    dropout_p=0.,
    softmax_scale=None,
    q_scale=None,
    causal=False,
    window_size=(-1, -1),
    deterministic=False,
    dtype=torch.bfloat16,
    version=None,
):
    """Run variable-length FlashAttention, falling back to ``manual_attention``.

    Args:
        q: Query tensor ``[B, Lq, N, C]`` (only dims 0-1 are read here).
        k: Key tensor ``[B, Lk, ...]``.
        v: Value tensor, expected to share ``k``'s first two dims.
        q_lens: Optional ``[B]`` integer tensor of valid query lengths; padded
            positions are dropped before the kernel call and the corresponding
            output rows are zero.
        k_lens: Optional ``[B]`` integer tensor of valid key lengths; padded
            keys are dropped before the kernel call.
        dropout_p: Dropout probability (forwarded to FlashAttention 2 only).
        softmax_scale: Scale for ``q @ k^T``, forwarded to the kernel.
        q_scale: Optional factor multiplied into the packed queries.
        causal: Forwarded to the kernel.
        window_size: ``(left, right)`` sliding window (FlashAttention 2 only).
        deterministic: Forwarded to the kernel.
        dtype: Dtype q, k and v are cast to before the kernel call.
        version: ``None`` or ``3`` uses FlashAttention 3 when it is installed;
            any other value uses FlashAttention 2.

    Returns:
        Tensor ``[B, Lq, ...]`` in ``q``'s original dtype. If no FlashAttention
        build is importable, or the kernel call raises, the result of
        ``manual_attention`` is returned instead (with a warning in the latter
        case).
    """
    if not (FLASH_ATTN_2_AVAILABLE or FLASH_ATTN_3_AVAILABLE):
        return manual_attention(
            q, k, v, q_lens, k_lens, dropout_p, softmax_scale,
            q_scale, causal, window_size, deterministic, dtype
        )

    device = q.device
    b, lq, lk = q.size(0), q.size(1), k.size(1)
    out_dtype = q.dtype

    # Pack the batch into single [total_tokens, ...] tensors, dropping padding.
    # .tolist() reads the lengths once instead of indexing with 0-d tensors.
    if q_lens is None:
        q_lens = torch.full((b,), lq, dtype=torch.int32, device=device)
        q_flat = q.flatten(0, 1)
    else:
        q_lens = q_lens.to(device)
        q_flat = torch.cat([u[:l] for u, l in zip(q, q_lens.tolist())])

    if k_lens is None:
        k_lens = torch.full((b,), lk, dtype=torch.int32, device=device)
        k_flat = k.flatten(0, 1)
        v_flat = v.flatten(0, 1)
    else:
        k_lens = k_lens.to(device)
        k_len_list = k_lens.tolist()
        k_flat = torch.cat([u[:l] for u, l in zip(k, k_len_list)])
        v_flat = torch.cat([u[:l] for u, l in zip(v, k_len_list)])

    q_flat = q_flat.to(dtype)
    k_flat = k_flat.to(dtype)
    v_flat = v_flat.to(dtype)

    if q_scale is not None:
        # Re-cast: multiplying by a float32 tensor would otherwise promote
        # q_flat and leave it with a different dtype from k_flat / v_flat.
        q_flat = (q_flat * q_scale).to(dtype)

    # Cumulative offsets [0, l0, l0+l1, ...]. flash_attn_varlen_func documents
    # these as int32; a plain cumsum of integer tensors yields int64.
    cu_seqlens_q = torch.cat([q_lens.new_zeros(1), q_lens]).cumsum(0, dtype=torch.int32)
    cu_seqlens_k = torch.cat([k_lens.new_zeros(1), k_lens]).cumsum(0, dtype=torch.int32)

    # Deliberate best-effort: any kernel failure degrades to manual attention.
    try:
        if FLASH_ATTN_3_AVAILABLE and (version is None or version == 3):
            # NOTE(review): dropout_p and window_size are not passed to the
            # FlashAttention 3 kernel, so they have no effect on this path.
            x = flash_attn_interface.flash_attn_varlen_func(
                q_flat, k_flat, v_flat,
                cu_seqlens_q, cu_seqlens_k,
                max_seqlen_q=lq, max_seqlen_k=lk,
                softmax_scale=softmax_scale,
                causal=causal,
                deterministic=deterministic
            )[0]
        else:
            x = flash_attn.flash_attn_varlen_func(
                q_flat, k_flat, v_flat,
                cu_seqlens_q, cu_seqlens_k,
                max_seqlen_q=lq, max_seqlen_k=lk,
                dropout_p=dropout_p,
                softmax_scale=softmax_scale,
                causal=causal,
                window_size=window_size,
                deterministic=deterministic
            )
    except Exception as e:
        warnings.warn(f"FlashAttention failed: {e}, using manual attention")
        return manual_attention(
            q, k, v, q_lens, k_lens, dropout_p, softmax_scale,
            q_scale, causal, window_size, deterministic, dtype
        )

    if x.size(0) == b * lq:
        # No query padding: the packed rows are exactly [B * Lq, ...].
        return x.unflatten(0, (b, lq)).to(out_dtype)

    # Some queries were padding: scatter each packed sequence back into a
    # zero-filled [B, Lq, ...] tensor so padded rows are zero.
    out = x.new_zeros((b, lq) + tuple(x.shape[1:]))
    for i, chunk in enumerate(torch.split(x, q_lens.tolist())):
        out[i, :chunk.size(0)] = chunk
    return out.to(out_dtype)
|
|
|
|
|
|
def attention(
    q,
    k,
    v,
    q_lens=None,
    k_lens=None,
    dropout_p=0.,
    softmax_scale=None,
    q_scale=None,
    causal=False,
    window_size=(-1, -1),
    deterministic=False,
    dtype=torch.bfloat16,
    fa_version=None,
):
    """Unified attention entry point.

    Moves ``k``, ``v`` and any given length tensors onto ``q``'s device. Then it
    dispatches to ``flash_attention`` when a FlashAttention build was imported,
    and to ``manual_attention`` otherwise. All arguments are forwarded
    positionally; ``fa_version`` becomes ``flash_attention``'s ``version``.
    """
    target = q.device
    k, v = k.to(target), v.to(target)
    q_lens = None if q_lens is None else q_lens.to(target)
    k_lens = None if k_lens is None else k_lens.to(target)

    forwarded = (q, k, v, q_lens, k_lens, dropout_p, softmax_scale,
                 q_scale, causal, window_size, deterministic, dtype)
    if not (FLASH_ATTN_2_AVAILABLE or FLASH_ATTN_3_AVAILABLE):
        return manual_attention(*forwarded)
    return flash_attention(*forwarded, fa_version)