mirror of
https://github.com/Wan-Video/Wan2.1.git
synced 2025-06-05 14:54:54 +00:00
171 lines
4.8 KiB
Python
171 lines
4.8 KiB
Python
# Modified from transformers.models.xlm_roberta.modeling_xlm_roberta
|
|
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
|
|
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
|
|
# Public names exported by a wildcard import of this module.
__all__ = ['XLMRoberta', 'xlm_roberta_large']
|
|
|
|
|
|
class SelfAttention(nn.Module):
    """
    Multi-head self-attention with learned query/key/value/output projections.

    Args:
        dim: Model width; must be divisible by ``num_heads``.
        num_heads: Number of attention heads.
        dropout: Probability used for dropout on the attention weights
            (only while training) and for dropout on the projected output.
        eps: Stored on the module; not used by this layer itself.
    """

    def __init__(self, dim, num_heads, dropout=0.1, eps=1e-5):
        assert dim % num_heads == 0
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.eps = eps

        # Created in q, k, v, o order: this fixes parameter registration
        # order (state_dict layout) and the sequence of seeded initializations.
        self.q = nn.Linear(dim, dim)
        self.k = nn.Linear(dim, dim)
        self.v = nn.Linear(dim, dim)
        self.o = nn.Linear(dim, dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask):
        """
        x: [B, L, C]. ``mask`` is passed straight through as the ``attn_mask``
        of ``F.scaled_dot_product_attention``. Returns a [B, L, C] tensor.
        """
        batch, seq_len, width = x.size()
        heads, head_dim = self.num_heads, self.head_dim

        def to_heads(t):
            # [B, L, C] -> [B, heads, L, head_dim]
            return t.reshape(batch, seq_len, heads, head_dim).transpose(1, 2)

        query = to_heads(self.q(x))
        key = to_heads(self.k(x))
        value = to_heads(self.v(x))

        # Attention-weight dropout is active only in training mode.
        drop_p = self.dropout.p if self.training else 0.0
        attended = F.scaled_dot_product_attention(
            query, key, value, attn_mask=mask, dropout_p=drop_p)

        # [B, heads, L, head_dim] -> [B, L, C], then output projection + dropout.
        merged = attended.transpose(1, 2).reshape(batch, seq_len, width)
        return self.dropout(self.o(merged))
|
|
|
|
|
|
class AttentionBlock(nn.Module):
    """
    Transformer encoder layer: self-attention followed by a GELU feed-forward
    network of width ``4 * dim``, each sub-layer wrapped in a residual add.

    With ``post_norm=True`` a LayerNorm is applied after each residual add.
    Otherwise each sub-layer receives a LayerNorm-ed input and the residual
    stream itself is left unnormalized (pre-norm).
    """

    def __init__(self, dim, num_heads, post_norm, dropout=0.1, eps=1e-5):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.post_norm = post_norm
        self.eps = eps

        # Sub-modules are created in attn, norm1, ffn, norm2 order; the ffn
        # Sequential layout (indices 0..3) defines its state_dict keys.
        hidden_dim = 4 * dim
        self.attn = SelfAttention(dim, num_heads, dropout, eps)
        self.norm1 = nn.LayerNorm(dim, eps=eps)
        self.ffn = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout),
        )
        self.norm2 = nn.LayerNorm(dim, eps=eps)

    def forward(self, x, mask):
        """Apply attention then feed-forward; ``mask`` is forwarded to ``self.attn``."""
        if not self.post_norm:
            # Pre-norm: normalize the sub-layer input, add to the raw stream.
            h = x + self.attn(self.norm1(x), mask)
            return h + self.ffn(self.norm2(h))

        # Post-norm: normalize after each residual addition.
        h = self.norm1(x + self.attn(x, mask))
        return self.norm2(h + self.ffn(h))
|
|
|
|
|
|
class XLMRoberta(nn.Module):
    """
    XLMRobertaModel with no pooler and no LM head.

    Token, token-type and learned position embeddings are summed and fed
    through a stack of ``AttentionBlock`` layers. With ``post_norm=True`` the
    embedding sum is LayerNorm-ed before dropout and the blocks; with
    ``post_norm=False`` the same LayerNorm is applied after the last block.
    """

    def __init__(self,
                 vocab_size=250002,
                 max_seq_len=514,
                 type_size=1,
                 pad_id=1,
                 dim=1024,
                 num_heads=16,
                 num_layers=24,
                 post_norm=True,
                 dropout=0.1,
                 eps=1e-5):
        super().__init__()
        # Hyper-parameters, kept as attributes for introspection.
        self.vocab_size = vocab_size
        self.max_seq_len = max_seq_len
        self.type_size = type_size
        self.pad_id = pad_id
        self.dim = dim
        self.num_heads = num_heads
        self.num_layers = num_layers
        self.post_norm = post_norm
        self.eps = eps

        # Embedding tables. pad_id is the padding row for both the token and
        # the position table (forward maps padded slots to position pad_id).
        self.token_embedding = nn.Embedding(vocab_size, dim, padding_idx=pad_id)
        self.type_embedding = nn.Embedding(type_size, dim)
        self.pos_embedding = nn.Embedding(max_seq_len, dim, padding_idx=pad_id)
        self.dropout = nn.Dropout(dropout)

        # Encoder stack.
        self.blocks = nn.ModuleList(
            [AttentionBlock(dim, num_heads, post_norm, dropout, eps)
             for _ in range(num_layers)])

        # Shared LayerNorm: on the embeddings (post-norm) or on the output (pre-norm).
        self.norm = nn.LayerNorm(dim, eps=eps)

    def forward(self, ids):
        """
        ids: [B, L] of torch.LongTensor.

        Returns the final hidden states, shape [B, L, dim].
        """
        batch, seq_len = ids.shape
        # 1 at real tokens, 0 at padding (tokens equal to pad_id).
        keep = ids.ne(self.pad_id).long()

        # Real tokens get positions pad_id + 1, pad_id + 2, ... counted over
        # non-pad tokens only; padded slots all map to position pad_id.
        positions = torch.cumsum(keep, dim=1).mul(keep).add(self.pad_id)

        hidden = self.token_embedding(ids)
        hidden = hidden + self.type_embedding(torch.zeros_like(ids))
        hidden = hidden + self.pos_embedding(positions)
        if self.post_norm:
            hidden = self.norm(hidden)
        hidden = self.dropout(hidden)

        # Additive attention bias over keys: 0 for real tokens, the most
        # negative value of the hidden dtype for padding; broadcasts over
        # heads and query positions.
        bias = torch.where(
            keep.view(batch, 1, 1, seq_len).gt(0), 0.0,
            torch.finfo(hidden.dtype).min)
        for layer in self.blocks:
            hidden = layer(hidden, bias)

        return hidden if self.post_norm else self.norm(hidden)
|
|
|
|
|
|
def xlm_roberta_large(pretrained=False,
                      return_tokenizer=False,
                      device='cpu',
                      **kwargs):
    """
    XLMRobertaLarge adapted from Huggingface.

    Builds an ``XLMRoberta`` with the large configuration (24 layers,
    16 heads, width 1024, post-norm), using the modules' default
    initialization. No checkpoint is loaded and no tokenizer is created.

    Args:
        pretrained: Accepted for API compatibility only. Weight loading is
            not implemented; a truthy value emits a ``UserWarning``.
        return_tokenizer: Accepted for API compatibility only. No tokenizer
            is built; a truthy value emits a ``UserWarning``.
        device: Used as a ``torch.device`` context during construction, so
            parameters are created on that device.
        **kwargs: Overrides for any entry of the default configuration
            below; forwarded to ``XLMRoberta``.

    Returns:
        The ``XLMRoberta`` model (always just the model, never a tuple).
    """
    # These flags used to be ignored silently. Surface that, so a caller does
    # not mistake a randomly initialized model for a pretrained one, or expect
    # a (model, tokenizer) pair. Warn rather than raise to keep existing
    # callers working.
    ignored = [name for name, flag in (('pretrained', pretrained),
                                       ('return_tokenizer', return_tokenizer))
               if flag]
    if ignored:
        import warnings
        warnings.warn(
            f"xlm_roberta_large() does not implement {', '.join(ignored)}; "
            "the argument(s) are ignored and a randomly initialized "
            "XLMRoberta model is returned.",
            UserWarning,
            stacklevel=2)

    # params
    cfg = dict(
        vocab_size=250002,
        max_seq_len=514,
        type_size=1,
        pad_id=1,
        dim=1024,
        num_heads=16,
        num_layers=24,
        post_norm=True,
        dropout=0.1,
        eps=1e-5)
    cfg.update(kwargs)

    # init a model on device
    with torch.device(device):
        model = XLMRoberta(**cfg)
    return model
|