# Modified from transformers.models.xlm_roberta.modeling_xlm_roberta
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import torch
import torch.nn as nn
import torch.nn.functional as F

__all__ = ['XLMRoberta', 'xlm_roberta_large']


class SelfAttention(nn.Module):

    def __init__(self, dim, num_heads, dropout=0.1, eps=1e-5):
        assert dim % num_heads == 0
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.eps = eps

        # layers
        self.q = nn.Linear(dim, dim)
        self.k = nn.Linear(dim, dim)
        self.v = nn.Linear(dim, dim)
        self.o = nn.Linear(dim, dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask):
        """
        x:   [B, L, C].
        """
        b, s, c, n, d = *x.size(), self.num_heads, self.head_dim

        # compute query, key, value
        q = self.q(x).reshape(b, s, n, d).permute(0, 2, 1, 3)
        k = self.k(x).reshape(b, s, n, d).permute(0, 2, 1, 3)
        v = self.v(x).reshape(b, s, n, d).permute(0, 2, 1, 3)

        # compute attention
        p = self.dropout.p if self.training else 0.0
        x = F.scaled_dot_product_attention(q, k, v, mask, p)
        x = x.permute(0, 2, 1, 3).reshape(b, s, c)

        # output
        x = self.o(x)
        x = self.dropout(x)
        return x


class AttentionBlock(nn.Module):

    def __init__(self, dim, num_heads, post_norm, dropout=0.1, eps=1e-5):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.post_norm = post_norm
        self.eps = eps

        # layers
        self.attn = SelfAttention(dim, num_heads, dropout, eps)
        self.norm1 = nn.LayerNorm(dim, eps=eps)
        self.ffn = nn.Sequential(
            nn.Linear(dim, dim * 4), nn.GELU(), nn.Linear(dim * 4, dim),
            nn.Dropout(dropout))
        self.norm2 = nn.LayerNorm(dim, eps=eps)

    def forward(self, x, mask):
        if self.post_norm:
            x = self.norm1(x + self.attn(x, mask))
            x = self.norm2(x + self.ffn(x))
        else:
            x = x + self.attn(self.norm1(x), mask)
            x = x + self.ffn(self.norm2(x))
        return x


class XLMRoberta(nn.Module):
    """
    XLMRobertaModel with no pooler and no LM head.
    """

    def __init__(self,
                 vocab_size=250002,
                 max_seq_len=514,
                 type_size=1,
                 pad_id=1,
                 dim=1024,
                 num_heads=16,
                 num_layers=24,
                 post_norm=True,
                 dropout=0.1,
                 eps=1e-5):
        super().__init__()
        self.vocab_size = vocab_size
        self.max_seq_len = max_seq_len
        self.type_size = type_size
        self.pad_id = pad_id
        self.dim = dim
        self.num_heads = num_heads
        self.num_layers = num_layers
        self.post_norm = post_norm
        self.eps = eps

        # embeddings
        self.token_embedding = nn.Embedding(vocab_size, dim, padding_idx=pad_id)
        self.type_embedding = nn.Embedding(type_size, dim)
        self.pos_embedding = nn.Embedding(max_seq_len, dim, padding_idx=pad_id)
        self.dropout = nn.Dropout(dropout)

        # blocks
        self.blocks = nn.ModuleList([
            AttentionBlock(dim, num_heads, post_norm, dropout, eps)
            for _ in range(num_layers)
        ])

        # norm layer
        self.norm = nn.LayerNorm(dim, eps=eps)

    def forward(self, ids):
        """
        ids: [B, L] of torch.LongTensor.
        """
        b, s = ids.shape
        mask = ids.ne(self.pad_id).long()

        # embeddings
        x = self.token_embedding(ids) + \
            self.type_embedding(torch.zeros_like(ids)) + \
            self.pos_embedding(self.pad_id + torch.cumsum(mask, dim=1) * mask)
        if self.post_norm:
            x = self.norm(x)
        x = self.dropout(x)

        # blocks
        mask = torch.where(
            mask.view(b, 1, 1, s).gt(0), 0.0,
            torch.finfo(x.dtype).min)
        for block in self.blocks:
            x = block(x, mask)

        # output
        if not self.post_norm:
            x = self.norm(x)
        return x


def xlm_roberta_large(pretrained=False,
                      return_tokenizer=False,
                      device='cpu',
                      **kwargs):
    """
    XLMRobertaLarge adapted from Huggingface.
    """
    # params
    cfg = dict(
        vocab_size=250002,
        max_seq_len=514,
        type_size=1,
        pad_id=1,
        dim=1024,
        num_heads=16,
        num_layers=24,
        post_norm=True,
        dropout=0.1,
        eps=1e-5)
    cfg.update(**kwargs)

    # init a model on device
    with torch.device(device):
        model = XLMRoberta(**cfg)
    return model
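

# ----------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the upstream module): runs the
# randomly initialised encoder on dummy token ids. In Wan2.1 the tokenizer and
# pretrained-weight loading are handled elsewhere, so the ids below are
# placeholders rather than real XLM-RoBERTa tokens.
# ----------------------------------------------------------------------------
if __name__ == '__main__':
    # build the large config (24 layers, dim 1024) on CPU
    model = xlm_roberta_large(device='cpu').eval()

    # dummy batch: 1 sequence of length 8, last two positions padded (pad_id=1)
    ids = torch.full((1, 8), 5, dtype=torch.long)
    ids[:, -2:] = 1

    # forward pass returns per-token features of shape [1, 8, 1024]
    with torch.no_grad():
        feats = model(ids)
    print(feats.shape)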