diff --git a/wan/modules/attention.py b/wan/modules/attention.py
index 4dbbe03..d145753 100644
--- a/wan/modules/attention.py
+++ b/wan/modules/attention.py
@@ -1,14 +1,17 @@
 # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
 import torch
+import math
 
 try:
     import flash_attn_interface
+    FLASH_ATTN_3_AVAILABLE = True
 except ModuleNotFoundError:
     FLASH_ATTN_3_AVAILABLE = False
 
 try:
     import flash_attn
+    FLASH_ATTN_2_AVAILABLE = True
 except ModuleNotFoundError:
     FLASH_ATTN_2_AVAILABLE = False
 
@@ -20,160 +23,214 @@ __all__ = [
     'attention',
 ]
+
+DEBUG_ATTENTION = True
+
+
+def log_debug(message):
+    if DEBUG_ATTENTION:
+        print(f"[DEBUG] {message}")
+
+
+def manual_attention(
+        q,
+        k,
+        v,
+        q_lens=None,
+        k_lens=None,
+        dropout_p=0.,
+        softmax_scale=None,
+        q_scale=None,
+        causal=False,
+        window_size=(-1, -1),
+        deterministic=False,
+        dtype=torch.bfloat16,
+):
+    """Manual attention fallback that runs on any device."""
+    # Move all inputs to the query's device right away
+    device = q.device
+    k = k.to(device)
+    v = v.to(device)
+    if q_lens is not None: q_lens = q_lens.to(device)
+    if k_lens is not None: k_lens = k_lens.to(device)
+
+    B, Lq, N, C = q.shape
+    _, Lk, _, _ = k.shape
+    original_dtype = q.dtype
+
+    # Convert to the compute dtype and move heads to dim 1: [B, N, L, C]
+    q = q.to(dtype).transpose(1, 2)
+    k = k.to(dtype).transpose(1, 2)
+    v = v.to(dtype).transpose(1, 2)
+
+    # Scaling
+    scale_factor = softmax_scale or (1.0 / math.sqrt(C))
+    if q_scale is not None:
+        q = q * q_scale.view(1, -1, 1, 1)
+
+    # Attention scores
+    attn_scores = torch.matmul(q, k.transpose(-2, -1)) * scale_factor
+
+    # Additive mask, built in float32
+    attn_mask = torch.zeros(B, 1, Lq, Lk, device=device, dtype=torch.float32)
+
+    # Key padding mask
+    if k_lens is not None:
+        key_mask = torch.arange(Lk, device=device)[None, :] < k_lens[:, None]
+        attn_mask = attn_mask.masked_fill(~key_mask.view(B, 1, 1, Lk), float('-inf'))
+
+    # Causal mask
+    if causal:
+        causal_mask = torch.ones(Lq, Lk, device=device, dtype=torch.bool).tril()
+        attn_mask = attn_mask.masked_fill(~causal_mask, float('-inf'))
+
+    # Sliding-window mask
+    if window_size != (-1, -1):
+        left, right = window_size
+        row = torch.arange(Lq, device=device)[:, None]
+        col = torch.arange(Lk, device=device)[None, :]
+        window_mask = (row - col >= -left) & (row - col <= right)
+        attn_mask = attn_mask.masked_fill(~window_mask, float('-inf'))
+
+    # Apply the mask
+    attn_scores += attn_mask
+
+    # Softmax and dropout
+    attn_weights = torch.softmax(attn_scores, dim=-1)
+    if not deterministic and dropout_p > 0:
+        attn_weights = torch.dropout(attn_weights, dropout_p, True)
+
+    # Weighted sum of the values
+    out = torch.matmul(attn_weights, v)
+
+    # Query padding mask
+    if q_lens is not None:
+        query_mask = torch.arange(Lq, device=device)[None, :] < q_lens[:, None]
+        out = out * query_mask.view(B, 1, Lq, 1).to(out.dtype)
+
+    # Back to [B, L, N, C] and the original dtype
+    return out.transpose(1, 2).contiguous().to(original_dtype)
+
 
 def flash_attention(
-    q,
-    k,
-    v,
-    q_lens=None,
-    k_lens=None,
-    dropout_p=0.,
-    softmax_scale=None,
-    q_scale=None,
-    causal=False,
-    window_size=(-1, -1),
-    deterministic=False,
-    dtype=torch.bfloat16,
-    version=None,
+        q,
+        k,
+        v,
+        q_lens=None,
+        k_lens=None,
+        dropout_p=0.,
+        softmax_scale=None,
+        q_scale=None,
+        causal=False,
+        window_size=(-1, -1),
+        deterministic=False,
+        dtype=torch.bfloat16,
+        version=None,
 ):
-    """
-    q: [B, Lq, Nq, C1].
-    k: [B, Lk, Nk, C1].
-    v: [B, Lk, Nk, C2]. Nq must be divisible by Nk.
-    q_lens: [B].
-    k_lens: [B].
-    dropout_p: float. Dropout probability.
-    softmax_scale: float. The scaling of QK^T before applying softmax.
-    causal: bool. Whether to apply causal attention mask.
-    window_size: (left right). If not (-1, -1), apply sliding window local attention.
-    deterministic: bool. If True, slightly slower and uses more memory.
-    dtype: torch.dtype. Apply when dtype of q/k/v is not float16/bfloat16.
-    """
-    half_dtypes = (torch.float16, torch.bfloat16)
-    assert dtype in half_dtypes
-    assert q.device.type == 'cuda' and q.size(-1) <= 256
-
-    # params
-    b, lq, lk, out_dtype = q.size(0), q.size(1), k.size(1), q.dtype
-
-    def half(x):
-        return x if x.dtype in half_dtypes else x.to(dtype)
-
-    # preprocess query
-    if q_lens is None:
-        q = half(q.flatten(0, 1))
-        q_lens = torch.tensor(
-            [lq] * b, dtype=torch.int32).to(
-                device=q.device, non_blocking=True)
-    else:
-        q = half(torch.cat([u[:v] for u, v in zip(q, q_lens)]))
-
-    # preprocess key, value
-    if k_lens is None:
-        k = half(k.flatten(0, 1))
-        v = half(v.flatten(0, 1))
-        k_lens = torch.tensor(
-            [lk] * b, dtype=torch.int32).to(
-                device=k.device, non_blocking=True)
-    else:
-        k = half(torch.cat([u[:v] for u, v in zip(k, k_lens)]))
-        v = half(torch.cat([u[:v] for u, v in zip(v, k_lens)]))
-
-    q = q.to(v.dtype)
-    k = k.to(v.dtype)
-
-    if q_scale is not None:
-        q = q * q_scale
-
-    if version is not None and version == 3 and not FLASH_ATTN_3_AVAILABLE:
-        warnings.warn(
-            'Flash attention 3 is not available, use flash attention 2 instead.'
+    """FlashAttention wrapper with a manual fallback."""
+    # Fall back when no FlashAttention build is available
+    if not (FLASH_ATTN_2_AVAILABLE or FLASH_ATTN_3_AVAILABLE):
+        return manual_attention(
+            q, k, v, q_lens, k_lens, dropout_p, softmax_scale,
+            q_scale, causal, window_size, deterministic, dtype
         )
 
-    # apply attention
-    if (version is None or version == 3) and FLASH_ATTN_3_AVAILABLE:
-        # Note: dropout_p, window_size are not supported in FA3 now.
-        x = flash_attn_interface.flash_attn_varlen_func(
-            q=q,
-            k=k,
-            v=v,
-            cu_seqlens_q=torch.cat([q_lens.new_zeros([1]), q_lens]).cumsum(
-                0, dtype=torch.int32).to(q.device, non_blocking=True),
-            cu_seqlens_k=torch.cat([k_lens.new_zeros([1]), k_lens]).cumsum(
-                0, dtype=torch.int32).to(q.device, non_blocking=True),
-            seqused_q=None,
-            seqused_k=None,
-            max_seqlen_q=lq,
-            max_seqlen_k=lk,
-            softmax_scale=softmax_scale,
-            causal=causal,
-            deterministic=deterministic)[0].unflatten(0, (b, lq))
-    else:
-        assert FLASH_ATTN_2_AVAILABLE
-        x = flash_attn.flash_attn_varlen_func(
-            q=q,
-            k=k,
-            v=v,
-            cu_seqlens_q=torch.cat([q_lens.new_zeros([1]), q_lens]).cumsum(
-                0, dtype=torch.int32).to(q.device, non_blocking=True),
-            cu_seqlens_k=torch.cat([k_lens.new_zeros([1]), k_lens]).cumsum(
-                0, dtype=torch.int32).to(q.device, non_blocking=True),
-            max_seqlen_q=lq,
-            max_seqlen_k=lk,
-            dropout_p=dropout_p,
-            softmax_scale=softmax_scale,
-            causal=causal,
-            window_size=window_size,
-            deterministic=deterministic).unflatten(0, (b, lq))
+    # Shapes, device and output dtype
+    device = q.device
+    b, lq, lk = q.size(0), q.size(1), k.size(1)
+    out_dtype = q.dtype
 
-    # output
-    return x.type(out_dtype)
+    # Flatten the (possibly variable-length) sequences
+    if q_lens is None:
+        q_lens = torch.full((b,), lq, dtype=torch.int32, device=device)
+        q_flat = q.flatten(0, 1)
+    else:
+        q_lens = q_lens.to(device)
+        q_flat = torch.cat([u[:l] for u, l in zip(q, q_lens)])
+
+    if k_lens is None:
+        k_lens = torch.full((b,), lk, dtype=torch.int32, device=device)
+        k_flat = k.flatten(0, 1)
+        v_flat = v.flatten(0, 1)
+    else:
+        k_lens = k_lens.to(device)
+        k_flat = torch.cat([u[:l] for u, l in zip(k, k_lens)])
+        v_flat = torch.cat([u[:l] for u, l in zip(v, k_lens)])
+
+    # Cast to the compute dtype
+    q_flat = q_flat.to(dtype)
+    k_flat = k_flat.to(dtype)
+    v_flat = v_flat.to(dtype)
+
+    # Apply q_scale
+    if q_scale is not None:
+        q_flat = q_flat * q_scale
+
+    # Cumulative sequence lengths (int32, as the varlen kernels require)
+    cu_seqlens_q = torch.cat([q_lens.new_zeros([1]), q_lens]).cumsum(0, dtype=torch.int32)
+    cu_seqlens_k = torch.cat([k_lens.new_zeros([1]), k_lens]).cumsum(0, dtype=torch.int32)
+
+    # Call FlashAttention
+    try:
+        if FLASH_ATTN_3_AVAILABLE and (version is None or version == 3):
+            x = flash_attn_interface.flash_attn_varlen_func(
+                q_flat, k_flat, v_flat,
+                cu_seqlens_q, cu_seqlens_k,
+                max_seqlen_q=lq, max_seqlen_k=lk,
+                softmax_scale=softmax_scale,
+                causal=causal,
+                deterministic=deterministic
+            )[0]
+        else:
+            x = flash_attn.flash_attn_varlen_func(
+                q_flat, k_flat, v_flat,
+                cu_seqlens_q, cu_seqlens_k,
+                max_seqlen_q=lq, max_seqlen_k=lk,
+                dropout_p=dropout_p,
+                softmax_scale=softmax_scale,
+                causal=causal,
+                window_size=window_size,
+                deterministic=deterministic
+            )
+        return x.unflatten(0, (b, lq)).to(out_dtype)
+    except Exception as e:
+        warnings.warn(f"FlashAttention failed: {e}, using manual attention")
+        return manual_attention(
+            q, k, v, q_lens, k_lens, dropout_p, softmax_scale,
+            q_scale, causal, window_size, deterministic, dtype
+        )
 
 
 def attention(
-    q,
-    k,
-    v,
-    q_lens=None,
-    k_lens=None,
-    dropout_p=0.,
-    softmax_scale=None,
-    q_scale=None,
-    causal=False,
-    window_size=(-1, -1),
-    deterministic=False,
-    dtype=torch.bfloat16,
-    fa_version=None,
+        q,
+        k,
+        v,
+        q_lens=None,
+        k_lens=None,
+        dropout_p=0.,
+        softmax_scale=None,
+        q_scale=None,
+        causal=False,
+        window_size=(-1, -1),
+        deterministic=False,
+        dtype=torch.bfloat16,
+        fa_version=None,
 ):
+    """Unified attention entry point."""
+    # Keep every input on the query's device
+    device = q.device
+    k = k.to(device)
+    v = v.to(device)
+    if q_lens is not None: q_lens = q_lens.to(device)
+    if k_lens is not None: k_lens = k_lens.to(device)
+
+    # Pick the implementation
     if FLASH_ATTN_2_AVAILABLE or FLASH_ATTN_3_AVAILABLE:
         return flash_attention(
-            q=q,
-            k=k,
-            v=v,
-            q_lens=q_lens,
-            k_lens=k_lens,
-            dropout_p=dropout_p,
-            softmax_scale=softmax_scale,
-            q_scale=q_scale,
-            causal=causal,
-            window_size=window_size,
-            deterministic=deterministic,
-            dtype=dtype,
-            version=fa_version,
+            q, k, v, q_lens, k_lens, dropout_p, softmax_scale,
+            q_scale, causal, window_size, deterministic, dtype, fa_version
         )
     else:
-        if q_lens is not None or k_lens is not None:
-            warnings.warn(
-                'Padding mask is disabled when using scaled_dot_product_attention. It can have a significant impact on performance.'
-            )
-        attn_mask = None
-
-        q = q.transpose(1, 2).to(dtype)
-        k = k.transpose(1, 2).to(dtype)
-        v = v.transpose(1, 2).to(dtype)
-
-        out = torch.nn.functional.scaled_dot_product_attention(
-            q, k, v, attn_mask=attn_mask, is_causal=causal, dropout_p=dropout_p)
-
-        out = out.transpose(1, 2).contiguous()
-        return out
+        return manual_attention(
+            q, k, v, q_lens, k_lens, dropout_p, softmax_scale,
+            q_scale, causal, window_size, deterministic, dtype
+        )
\ No newline at end of file
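
For reference, a minimal smoke test of the patched module (a sketch, not part of the patch; it assumes the file above is importable as wan.modules.attention and uses arbitrary shapes in the module's [B, L, N, C] layout; on a machine without a usable flash-attn build the call resolves to manual_attention):

    import torch
    from wan.modules.attention import attention, manual_attention

    b, l, n, c = 2, 16, 8, 64
    q = torch.randn(b, l, n, c)
    k = torch.randn(b, l, n, c)
    v = torch.randn(b, l, n, c)

    out = attention(q, k, v, causal=True, dtype=torch.bfloat16)         # dispatches to FA2/FA3 or the fallback
    ref = manual_attention(q, k, v, causal=True, dtype=torch.bfloat16)  # pure-PyTorch reference
    print(out.shape)                                                    # torch.Size([2, 16, 8, 64])
    print(torch.allclose(out.float(), ref.float(), atol=2e-2))          # loose tolerance for bf16 compute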