# Mirror of https://github.com/Wan-Video/Wan2.1.git
# Synced 2025-11-04 14:16:57 +00:00 · 382 lines · 12 KiB · Python
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
|
|
from torch import nn
|
|
import torch
|
|
from typing import Tuple, Optional
|
|
from einops import rearrange
|
|
import torch.nn.functional as F
|
|
import math
|
|
from shared.attention import pay_attention
|
|
|
|
def _merge_batch_and_seq(x):
    """Fold the first two axes into one: [b, s, ...] -> [b * s, ...] (a view; x must be viewable)."""
    return x.view(x.shape[0] * x.shape[1], *x.shape[2:])


def _identity(x):
    """Return ``x`` unchanged."""
    return x


def _swap_seq_and_heads(x):
    """Swap axes 1 and 2, e.g. [b, s, a, d] <-> [b, a, s, d]."""
    return x.transpose(1, 2)


# Per-backend (pre_attn_layout, post_attn_layout) pair of tensor layout transforms.
MEMORY_LAYOUT = {
    "flash": (_merge_batch_and_seq, _identity),
    "torch": (_swap_seq_and_heads, _swap_seq_and_heads),
    "vanilla": (_swap_seq_and_heads, _swap_seq_and_heads),
}
|
|
|
|
|
|
def attention(
    q,
    k,
    v,
    mode="torch",
    drop_rate=0,
    attn_mask=None,
    causal=False,
    max_seqlen_q=None,
    batch_size=1,
):
    """
    Perform multi-head QKV attention.

    Args:
        q (torch.Tensor): Query tensor with shape [b, s, a, d], where a is the number of heads.
        k (torch.Tensor): Key tensor with shape [b, s1, a, d].
        v (torch.Tensor): Value tensor with shape [b, s1, a, d].
        mode (str): Attention backend: 'flash', 'torch' or 'vanilla' (the keys of MEMORY_LAYOUT).
        drop_rate (float): Dropout rate in the attention map. (default: 0)
        attn_mask (torch.Tensor): Optional mask for the 'torch' / 'vanilla' backends, broadcastable to
            [b, a, s, s1]. Boolean masks keep the positions that are True; other dtypes are added to
            the attention logits. (default: None)
        causal (bool): Whether to use causal attention. (default: False)
        max_seqlen_q (int): Query sequence length s; only used by 'flash' to restore [b, s, a, d].
        batch_size (int): Batch size b; only used by 'flash'. (default: 1)

    Returns:
        torch.Tensor: Output tensor with shape [b, s, a * d].

    Raises:
        NotImplementedError: If ``mode`` is not a supported backend.
    """
    # Validate first so an unknown mode raises the intended error instead of a KeyError.
    if mode not in MEMORY_LAYOUT:
        raise NotImplementedError(f"Unsupported attention mode: {mode}")
    pre_attn_layout, post_attn_layout = MEMORY_LAYOUT[mode]

    if mode == "torch":
        # SDPA attends over the second-to-last axis, so move heads in front of the sequence:
        # [b, s, a, d] -> [b, a, s, d]. Without this it would attend across heads, not tokens.
        q, k, v = (pre_attn_layout(t) for t in (q, k, v))
        if attn_mask is not None and attn_mask.dtype != torch.bool:
            attn_mask = attn_mask.to(q.dtype)
        x = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, dropout_p=drop_rate, is_causal=causal)

    elif mode == "flash":
        # flash_attn_func takes [b, s, a, d] directly, so the flattening pre-layout is not applied.
        # NOTE(review): `flash_attn_func` is not imported in this file; this branch raises NameError
        # unless the name is provided at module scope elsewhere. `drop_rate`, `attn_mask` and
        # `causal` are not forwarded to it.
        x = flash_attn_func(q, k, v)
        x = x.view(batch_size, max_seqlen_q, x.shape[-2], x.shape[-1])  # reshape x to [b, s, a, d]

    elif mode == "vanilla":
        # Explicit softmax(q @ k^T / sqrt(d) + bias) @ v on the [b, a, s, d] layout.
        q, k, v = (pre_attn_layout(t) for t in (q, k, v))
        scale_factor = 1 / math.sqrt(q.size(-1))

        b, a, s, _ = q.shape
        s1 = k.size(2)
        attn_bias = torch.zeros(b, a, s, s1, dtype=q.dtype, device=q.device)
        if causal:
            # Only applied to self attention: the mask is built as [s, s], so s must equal s1.
            assert attn_mask is None, "Causal mask and attn_mask cannot be used together"
            temp_mask = torch.ones(b, a, s, s, dtype=torch.bool, device=q.device).tril(diagonal=0)
            attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))

        if attn_mask is not None:
            if attn_mask.dtype == torch.bool:
                attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
            else:
                attn_bias += attn_mask

        attn = (q @ k.transpose(-2, -1)) * scale_factor
        attn += attn_bias
        attn = attn.softmax(dim=-1)
        # train=True: dropout is applied whenever drop_rate > 0, independent of any module mode.
        attn = torch.dropout(attn, p=drop_rate, train=True)
        x = attn @ v

    else:
        raise NotImplementedError(f"Unsupported attention mode: {mode}")

    x = post_attn_layout(x)  # back to [b, s, a, d]
    b, s, _, _ = x.shape
    return x.reshape(b, s, -1)
|
|
|
|
|
|
class CausalConv1d(nn.Module):
    """1-D convolution that is causal along the last (time) axis.

    The input is left-padded by ``dilation * (kernel_size - 1)`` steps using ``pad_mode``. Output
    step ``t`` therefore only depends on input steps ``<= t * stride``, and with ``stride=1`` the
    time length is preserved. Extra ``**kwargs`` are forwarded to ``nn.Conv1d``.
    """

    def __init__(self, chan_in, chan_out, kernel_size=3, stride=1, dilation=1, pad_mode="replicate", **kwargs):
        super().__init__()

        self.pad_mode = pad_mode
        # A dilated kernel spans dilation * (kernel_size - 1) + 1 input steps, so that many past
        # steps must be padded. Padding only kernel_size - 1 would shrink the output and let the
        # kernel look into future frames whenever dilation > 1.
        padding = (dilation * (kernel_size - 1), 0)  # (left, right) padding on T
        self.time_causal_padding = padding

        self.conv = nn.Conv1d(chan_in, chan_out, kernel_size, stride=stride, dilation=dilation, **kwargs)

    def forward(self, x):
        """Left-pad ``x`` along time, then convolve.

        ``x`` is presumably [batch, chan_in, time], the layout ``nn.Conv1d`` consumes.
        """
        x = F.pad(x, self.time_causal_padding, mode=self.pad_mode)
        return self.conv(x)
|
|
|
|
|
|
|
|
class FaceEncoder(nn.Module):
    """Encode a sequence of per-frame feature vectors into ``num_heads`` token streams.

    Pipeline:
      1. A causal conv lifts ``in_dim`` to ``num_heads * 1024`` channels; the channels are split
         into ``num_heads`` groups of 1024, folded into the batch.
      2. Two stride-2 causal convs each halve the time length, rounding up. Every conv is
         followed by a non-affine LayerNorm and SiLU.
      3. A linear layer projects to ``hidden_dim``.
      4. One learned padding token (initialised to zeros) is appended per time step.
    """

    def __init__(self, in_dim: int, hidden_dim: int, num_heads: int, dtype=None, device=None):
        factory_kwargs = {"dtype": dtype, "device": device}
        super().__init__()

        self.num_heads = num_heads
        self.conv1_local = CausalConv1d(in_dim, 1024 * num_heads, 3, stride=1)
        # Each head stream has 1024 channels after conv1_local, so norm1 normalises 1024 features.
        self.norm1 = nn.LayerNorm(1024, elementwise_affine=False, eps=1e-6, **factory_kwargs)
        self.act = nn.SiLU()
        self.conv2 = CausalConv1d(1024, 1024, 3, stride=2)
        self.conv3 = CausalConv1d(1024, 1024, 3, stride=2)

        self.out_proj = nn.Linear(1024, hidden_dim)

        self.norm2 = nn.LayerNorm(1024, elementwise_affine=False, eps=1e-6, **factory_kwargs)

        self.norm3 = nn.LayerNorm(1024, elementwise_affine=False, eps=1e-6, **factory_kwargs)

        self.padding_tokens = nn.Parameter(torch.zeros(1, 1, 1, hidden_dim))

    def forward(self, x):
        """
        Args:
            x (torch.Tensor): [b, t, in_dim].

        Returns:
            torch.Tensor: [b, t', num_heads + 1, hidden_dim] with t' = ceil(ceil(t / 2) / 2).
            The last slot along dim -2 is the learned padding token.
        """
        x = rearrange(x, "b t c -> b c t")
        b = x.shape[0]

        x = self.conv1_local(x)
        # Split channels into num_heads independent streams and fold them into the batch axis.
        x = rearrange(x, "b (n c) t -> (b n) t c", n=self.num_heads)
        x = self.act(self.norm1(x))

        x = rearrange(x, "b t c -> b c t")
        x = self.conv2(x)
        x = rearrange(x, "b c t -> b t c")
        x = self.act(self.norm2(x))

        x = rearrange(x, "b t c -> b c t")
        x = self.conv3(x)
        x = rearrange(x, "b c t -> b t c")
        x = self.act(self.norm3(x))

        x = self.out_proj(x)
        x = rearrange(x, "(b n) t c -> b t n c", b=b)

        # expand() avoids materialising copies; torch.cat below allocates a fresh tensor anyway,
        # so no extra clone is needed before returning.
        padding = self.padding_tokens.expand(b, x.shape[1], -1, -1)
        return torch.cat([x, padding], dim=-2)
|
|
|
|
|
|
|
|
class RMSNorm(nn.Module):
    """Root-mean-square normalisation over the last dimension.

    Computes ``x / sqrt(mean(x ** 2, dim=-1) + eps)``, optionally scaled by a learnable
    per-channel ``weight``.

    Args:
        dim (int): Size of the normalised (last) dimension; the length of ``weight``.
        elementwise_affine (bool): If True, register a learnable ``weight`` initialised to ones.
            If False, no ``weight`` attribute is created and the output is not rescaled.
        eps (float): Added inside the square root for numerical stability. Default 1e-6.
        device: Forwarded to the ``weight`` parameter's constructor.
        dtype: Forwarded to the ``weight`` parameter's constructor.
    """

    def __init__(
        self,
        dim: int,
        elementwise_affine=True,
        eps: float = 1e-6,
        device=None,
        dtype=None,
    ):
        super().__init__()
        self.eps = eps
        if not elementwise_affine:
            return
        self.weight = nn.Parameter(torch.ones(dim, device=device, dtype=dtype))

    def _norm(self, x):
        """Return ``x`` divided by its root-mean-square along the last axis."""
        mean_square = x.square().mean(dim=-1, keepdim=True)
        return x * torch.rsqrt(mean_square + self.eps)

    def forward(self, x):
        """Normalise ``x`` in float32, cast back to ``x``'s dtype, then apply ``weight`` if present."""
        normalised = self._norm(x.float()).type_as(x)
        return normalised * self.weight if hasattr(self, "weight") else normalised
|
|
|
|
|
|
def get_norm_layer(norm_layer):
    """
    Map a normalisation name to its layer class.

    Args:
        norm_layer (str): "layer" for ``nn.LayerNorm`` or "rms" for ``RMSNorm``.

    Returns:
        type: The normalisation layer class (not an instance).

    Raises:
        NotImplementedError: For any other name.
    """
    if norm_layer == "rms":
        return RMSNorm
    if norm_layer == "layer":
        return nn.LayerNorm
    raise NotImplementedError(f"Norm layer {norm_layer} is not implemented")
|
|
|
|
|
|
class FaceAdapter(nn.Module):
    """Container of ``num_adapter_layers`` FaceBlock fusers; ``forward`` runs the one at ``idx``."""

    def __init__(
        self,
        hidden_dim: int,
        heads_num: int,
        qk_norm: bool = True,
        qk_norm_type: str = "rms",
        num_adapter_layers: int = 1,
        dtype=None,
        device=None,
    ):
        super().__init__()
        self.hidden_size = hidden_dim
        self.heads_num = heads_num

        blocks = []
        for _ in range(num_adapter_layers):
            blocks.append(
                FaceBlock(
                    hidden_dim,
                    heads_num,
                    qk_norm=qk_norm,
                    qk_norm_type=qk_norm_type,
                    dtype=dtype,
                    device=device,
                )
            )
        self.fuser_blocks = nn.ModuleList(blocks)

    def forward(
        self,
        x: torch.Tensor,
        motion_embed: torch.Tensor,
        idx: int,
        freqs_cis_q: Tuple[torch.Tensor, torch.Tensor] = None,
        freqs_cis_k: Tuple[torch.Tensor, torch.Tensor] = None,
    ) -> torch.Tensor:
        """Run fuser block ``idx`` on ``(x, motion_embed)`` and return its output.

        NOTE(review): the two trailing arguments are passed positionally. They land in
        FaceBlock.forward's ``motion_mask`` and ``use_context_parallel`` parameters, despite
        their ``freqs_cis_*`` names. Leave both as None unless that mapping is intended.
        """
        fuser = self.fuser_blocks[idx]
        return fuser(x, motion_embed, freqs_cis_q, freqs_cis_k)
|
|
|
|
|
|
|
|
class FaceBlock(nn.Module):
    """Cross-attention block: video tokens ``x`` (queries) attend to per-frame motion tokens (keys/values).

    Both inputs get a non-affine LayerNorm. Q and K get an optional per-head norm. Attention is
    computed independently per frame, and the result is projected back to ``hidden_size``.
    """

    def __init__(
        self,
        hidden_size: int,
        heads_num: int,
        qk_norm: bool = True,
        qk_norm_type: str = "rms",
        qk_scale: float = None,
        dtype: Optional[torch.dtype] = None,
        device: Optional[torch.device] = None,
    ):
        super().__init__()
        factory_kwargs = {"device": device, "dtype": dtype}
        head_dim = hidden_size // heads_num

        self.deterministic = False
        self.hidden_size = hidden_size
        self.heads_num = heads_num
        # Stored for reference; forward() does not read it (the attention kernel applies its own scaling).
        self.scale = qk_scale or head_dim**-0.5

        # Fused key/value projection (2 * hidden_size outputs) for motion tokens, query projection
        # for video tokens, and the output projection.
        self.linear1_kv = nn.Linear(hidden_size, hidden_size * 2, **factory_kwargs)
        self.linear1_q = nn.Linear(hidden_size, hidden_size, **factory_kwargs)
        self.linear2 = nn.Linear(hidden_size, hidden_size, **factory_kwargs)

        # Resolved unconditionally, so an unknown qk_norm_type fails even when qk_norm is False.
        norm_cls = get_norm_layer(qk_norm_type)

        def make_qk_norm():
            if not qk_norm:
                return nn.Identity()
            return norm_cls(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs)

        self.q_norm = make_qk_norm()
        self.k_norm = make_qk_norm()

        self.pre_norm_feat = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs)
        self.pre_norm_motion = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs)

    def forward(
        self,
        x: torch.Tensor,
        motion_vec: torch.Tensor,
        motion_mask: Optional[torch.Tensor] = None,
        use_context_parallel=False,
    ) -> torch.Tensor:
        """
        Args:
            x: [B, T * S, hidden_size] video tokens; the token axis is split into T frames of S tokens.
            motion_vec: [B, T, N, hidden_size], i.e. N motion tokens per frame.
            motion_mask: Optional [B, T, H, W] tensor, flattened and multiplied into the output tokens.
            use_context_parallel: If True, gather ``q`` along its token axis before attention.

        Returns:
            [B, T * S, hidden_size] (assuming pay_attention returns [(B T), S, heads, head_dim]).
        """
        _, num_frames, _, _ = motion_vec.shape

        motion_normed = self.pre_norm_motion(motion_vec)
        feat_normed = self.pre_norm_feat(x)

        kv = self.linear1_kv(motion_normed)
        q = self.linear1_q(feat_normed)

        k, v = rearrange(kv, "B L N (K H D) -> K B L N H D", K=2, H=self.heads_num)
        q = rearrange(q, "B S (H D) -> B S H D", H=self.heads_num)

        # Per-head QK normalisation (Identity when qk_norm=False), cast to v's dtype and device.
        q = self.q_norm(q).to(v)
        k = self.k_norm(k).to(v)

        # Fold frames into the batch so each frame's queries attend only to that frame's N motion tokens.
        k = rearrange(k, "B L N H D -> (B L) N H D")
        v = rearrange(v, "B L N H D -> (B L) N H D")

        if use_context_parallel:
            # NOTE(review): `gather_forward` is not defined or imported in this file, and the output
            # is never re-split back to this rank's shard. Confirm before enabling this path.
            q = gather_forward(q, dim=1)

        q = rearrange(q, "B (L S) H D -> (B L) S H D", L=num_frames)

        # Hand q/k/v over as a list and drop the local names, presumably so pay_attention can
        # release them early. Confirm against shared.attention.
        qkv_list = [q, k, v]
        del q, k, v
        attn = pay_attention(qkv_list)

        attn = attn.reshape(*attn.shape[:2], -1)
        attn = rearrange(attn, "(B L) S C -> B (L S) C", L=num_frames)

        output = self.linear2(attn)

        if motion_mask is not None:
            flat_mask = rearrange(motion_mask, "B T H W -> B (T H W)")
            output = output * flat_mask.unsqueeze(-1)

        return output