mirror of
				https://github.com/Wan-Video/Wan2.1.git
				synced 2025-11-04 06:15:17 +00:00 
			
		
		
		
	Merge 5f7e7ed289 into 7c81b2f27d
				
					
				
			This commit is contained in:
		
						commit
						3e1b0ea31e
					
				@ -79,53 +79,101 @@ def flash_attention(
 | 
			
		||||
        k = half(torch.cat([u[:v] for u, v in zip(k, k_lens)]))
 | 
			
		||||
        v = half(torch.cat([u[:v] for u, v in zip(v, k_lens)]))
 | 
			
		||||
 | 
			
		||||
    q = q.to(v.dtype)
 | 
			
		||||
    k = k.to(v.dtype)
 | 
			
		||||
    try:
 | 
			
		||||
        q = q.to(v.dtype)
 | 
			
		||||
        k = k.to(v.dtype)
 | 
			
		||||
    
 | 
			
		||||
        if q_scale is not None:
 | 
			
		||||
            q = q * q_scale
 | 
			
		||||
    
 | 
			
		||||
        if version is not None and version == 3 and not FLASH_ATTN_3_AVAILABLE:
 | 
			
		||||
            warnings.warn(
 | 
			
		||||
                'Flash attention 3 is not available, use flash attention 2 instead.'
 | 
			
		||||
            )
 | 
			
		||||
    
 | 
			
		||||
        # apply attention
 | 
			
		||||
        if (version is None or version == 3) and FLASH_ATTN_3_AVAILABLE:
 | 
			
		||||
            # Note: dropout_p, window_size are not supported in FA3 now.
 | 
			
		||||
            x = flash_attn_interface.flash_attn_varlen_func(
 | 
			
		||||
                q=q,
 | 
			
		||||
                k=k,
 | 
			
		||||
                v=v,
 | 
			
		||||
                cu_seqlens_q=torch.cat([q_lens.new_zeros([1]), q_lens]).cumsum(
 | 
			
		||||
                    0, dtype=torch.int32).to(q.device, non_blocking=True),
 | 
			
		||||
                cu_seqlens_k=torch.cat([k_lens.new_zeros([1]), k_lens]).cumsum(
 | 
			
		||||
                    0, dtype=torch.int32).to(q.device, non_blocking=True),
 | 
			
		||||
                seqused_q=None,
 | 
			
		||||
                seqused_k=None,
 | 
			
		||||
                max_seqlen_q=lq,
 | 
			
		||||
                max_seqlen_k=lk,
 | 
			
		||||
                softmax_scale=softmax_scale,
 | 
			
		||||
                causal=causal,
 | 
			
		||||
                deterministic=deterministic)[0].unflatten(0, (b, lq))
 | 
			
		||||
        else:
 | 
			
		||||
            assert FLASH_ATTN_2_AVAILABLE
 | 
			
		||||
            x = flash_attn.flash_attn_varlen_func(
 | 
			
		||||
                q=q,
 | 
			
		||||
                k=k,
 | 
			
		||||
                v=v,
 | 
			
		||||
                cu_seqlens_q=torch.cat([q_lens.new_zeros([1]), q_lens]).cumsum(
 | 
			
		||||
                    0, dtype=torch.int32).to(q.device, non_blocking=True),
 | 
			
		||||
                cu_seqlens_k=torch.cat([k_lens.new_zeros([1]), k_lens]).cumsum(
 | 
			
		||||
                    0, dtype=torch.int32).to(q.device, non_blocking=True),
 | 
			
		||||
                max_seqlen_q=lq,
 | 
			
		||||
                max_seqlen_k=lk,
 | 
			
		||||
                dropout_p=dropout_p,
 | 
			
		||||
                softmax_scale=softmax_scale,
 | 
			
		||||
                causal=causal,
 | 
			
		||||
                window_size=window_size,
 | 
			
		||||
                deterministic=deterministic).unflatten(0, (b, lq))
 | 
			
		||||
    
 | 
			
		||||
    except RuntimeError as e:
 | 
			
		||||
        if "FlashAttention only supports Ampere GPUs or newer" in str(e):
 | 
			
		||||
            #for cards like 2080ti that aren't Ampere structure
 | 
			
		||||
            from torch import nn
 | 
			
		||||
            import torch.nn.functional as F
 | 
			
		||||
        
 | 
			
		||||
            q = q.to(half(k).dtype)
 | 
			
		||||
 | 
			
		||||
    if q_scale is not None:
 | 
			
		||||
        q = q * q_scale
 | 
			
		||||
            # 转置维度,保证形状为 [B, N, L, C]
 | 
			
		||||
            q = q.view(b, lq, q.size(1), q.size(2)).transpose(1, 2)
 | 
			
		||||
            k = k.view(b, lk, k.size(1), k.size(2)).transpose(1, 2)
 | 
			
		||||
            v = v.view(b, lk, v.size(1), v.size(2)).transpose(1, 2)
 | 
			
		||||
 | 
			
		||||
    if version is not None and version == 3 and not FLASH_ATTN_3_AVAILABLE:
 | 
			
		||||
        warnings.warn(
 | 
			
		||||
            'Flash attention 3 is not available, use flash attention 2 instead.'
 | 
			
		||||
        )
 | 
			
		||||
            # 计算注意力
 | 
			
		||||
            # 注意:确保 Q、K、V 的形状为 [B, N, L, C]
 | 
			
		||||
            # 设置默认缩放因子
 | 
			
		||||
            if softmax_scale is None:
 | 
			
		||||
                softmax_scale = 1.0 / q.size(-1) ** 0.5
 | 
			
		||||
 | 
			
		||||
    # apply attention
 | 
			
		||||
    if (version is None or version == 3) and FLASH_ATTN_3_AVAILABLE:
 | 
			
		||||
        # Note: dropout_p, window_size are not supported in FA3 now.
 | 
			
		||||
        x = flash_attn_interface.flash_attn_varlen_func(
 | 
			
		||||
            q=q,
 | 
			
		||||
            k=k,
 | 
			
		||||
            v=v,
 | 
			
		||||
            cu_seqlens_q=torch.cat([q_lens.new_zeros([1]), q_lens]).cumsum(
 | 
			
		||||
                0, dtype=torch.int32).to(q.device, non_blocking=True),
 | 
			
		||||
            cu_seqlens_k=torch.cat([k_lens.new_zeros([1]), k_lens]).cumsum(
 | 
			
		||||
                0, dtype=torch.int32).to(q.device, non_blocking=True),
 | 
			
		||||
            seqused_q=None,
 | 
			
		||||
            seqused_k=None,
 | 
			
		||||
            max_seqlen_q=lq,
 | 
			
		||||
            max_seqlen_k=lk,
 | 
			
		||||
            softmax_scale=softmax_scale,
 | 
			
		||||
            causal=causal,
 | 
			
		||||
            deterministic=deterministic)[0].unflatten(0, (b, lq))
 | 
			
		||||
    else:
 | 
			
		||||
        assert FLASH_ATTN_2_AVAILABLE
 | 
			
		||||
        x = flash_attn.flash_attn_varlen_func(
 | 
			
		||||
            q=q,
 | 
			
		||||
            k=k,
 | 
			
		||||
            v=v,
 | 
			
		||||
            cu_seqlens_q=torch.cat([q_lens.new_zeros([1]), q_lens]).cumsum(
 | 
			
		||||
                0, dtype=torch.int32).to(q.device, non_blocking=True),
 | 
			
		||||
            cu_seqlens_k=torch.cat([k_lens.new_zeros([1]), k_lens]).cumsum(
 | 
			
		||||
                0, dtype=torch.int32).to(q.device, non_blocking=True),
 | 
			
		||||
            max_seqlen_q=lq,
 | 
			
		||||
            max_seqlen_k=lk,
 | 
			
		||||
            dropout_p=dropout_p,
 | 
			
		||||
            softmax_scale=softmax_scale,
 | 
			
		||||
            causal=causal,
 | 
			
		||||
            window_size=window_size,
 | 
			
		||||
            deterministic=deterministic).unflatten(0, (b, lq))
 | 
			
		||||
            # 如果 q_scale 存在,则应用缩放
 | 
			
		||||
            if q_scale is not None:
 | 
			
		||||
                q = q * q_scale
 | 
			
		||||
 | 
			
		||||
            # 创建掩码
 | 
			
		||||
            if causal:
 | 
			
		||||
                attn_mask = torch.triu(torch.full((q.size(2), k.size(2)), -torch.inf), diagonal=1).to(q.device)
 | 
			
		||||
            else:
 | 
			
		||||
                attn_mask = None
 | 
			
		||||
 | 
			
		||||
            # 计算注意力
 | 
			
		||||
            # 使用 scaled_dot_product_attention
 | 
			
		||||
            x = F.scaled_dot_product_attention(
 | 
			
		||||
                q, k, v,
 | 
			
		||||
                attn_mask=attn_mask,
 | 
			
		||||
                dropout_p=dropout_p,
 | 
			
		||||
                is_causal=causal,
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
            # 转换回原形状 [B, L, N, C]
 | 
			
		||||
            x = x.transpose(1, 2).contiguous()
 | 
			
		||||
 | 
			
		||||
            # 对输出应用 Dropout
 | 
			
		||||
            dropout = nn.Dropout(dropout_p)
 | 
			
		||||
            x = dropout(x)            
 | 
			
		||||
        else:
 | 
			
		||||
            raise
 | 
			
		||||
            
 | 
			
		||||
    # output
 | 
			
		||||
    return x.type(out_dtype)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
		Loading…
	
		Reference in New Issue
	
	Block a user