mirror of
				https://github.com/Wan-Video/Wan2.1.git
				synced 2025-11-03 22:04:21 +00:00 
			
		
		
		
	
		
			
				
	
	
		
			514 lines
		
	
	
		
			16 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			514 lines
		
	
	
		
			16 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
# Modified from transformers.models.t5.modeling_t5
 | 
						|
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
 | 
						|
import logging
 | 
						|
import math
 | 
						|
 | 
						|
import torch
 | 
						|
import torch.nn as nn
 | 
						|
import torch.nn.functional as F
 | 
						|
 | 
						|
from .tokenizers import HuggingfaceTokenizer
 | 
						|
 | 
						|
__all__ = [
 | 
						|
    'T5Model',
 | 
						|
    'T5Encoder',
 | 
						|
    'T5Decoder',
 | 
						|
    'T5EncoderModel',
 | 
						|
]
 | 
						|
 | 
						|
 | 
						|
def fp16_clamp(x):
 | 
						|
    if x.dtype == torch.float16 and torch.isinf(x).any():
 | 
						|
        clamp = torch.finfo(x.dtype).max - 1000
 | 
						|
        x = torch.clamp(x, min=-clamp, max=clamp)
 | 
						|
    return x
 | 
						|
 | 
						|
 | 
						|
def init_weights(m):
 | 
						|
    if isinstance(m, T5LayerNorm):
 | 
						|
        nn.init.ones_(m.weight)
 | 
						|
    elif isinstance(m, T5Model):
 | 
						|
        nn.init.normal_(m.token_embedding.weight, std=1.0)
 | 
						|
    elif isinstance(m, T5FeedForward):
 | 
						|
        nn.init.normal_(m.gate[0].weight, std=m.dim**-0.5)
 | 
						|
        nn.init.normal_(m.fc1.weight, std=m.dim**-0.5)
 | 
						|
        nn.init.normal_(m.fc2.weight, std=m.dim_ffn**-0.5)
 | 
						|
    elif isinstance(m, T5Attention):
 | 
						|
        nn.init.normal_(m.q.weight, std=(m.dim * m.dim_attn)**-0.5)
 | 
						|
        nn.init.normal_(m.k.weight, std=m.dim**-0.5)
 | 
						|
        nn.init.normal_(m.v.weight, std=m.dim**-0.5)
 | 
						|
        nn.init.normal_(m.o.weight, std=(m.num_heads * m.dim_attn)**-0.5)
 | 
						|
    elif isinstance(m, T5RelativeEmbedding):
 | 
						|
        nn.init.normal_(
 | 
						|
            m.embedding.weight, std=(2 * m.num_buckets * m.num_heads)**-0.5)
 | 
						|
 | 
						|
 | 
						|
class GELU(nn.Module):
 | 
						|
 | 
						|
    def forward(self, x):
 | 
						|
        return 0.5 * x * (1.0 + torch.tanh(
 | 
						|
            math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0))))
 | 
						|
 | 
						|
 | 
						|
class T5LayerNorm(nn.Module):
 | 
						|
 | 
						|
    def __init__(self, dim, eps=1e-6):
 | 
						|
        super(T5LayerNorm, self).__init__()
 | 
						|
        self.dim = dim
 | 
						|
        self.eps = eps
 | 
						|
        self.weight = nn.Parameter(torch.ones(dim))
 | 
						|
 | 
						|
    def forward(self, x):
 | 
						|
        x = x * torch.rsqrt(x.float().pow(2).mean(dim=-1, keepdim=True) +
 | 
						|
                            self.eps)
 | 
						|
        if self.weight.dtype in [torch.float16, torch.bfloat16]:
 | 
						|
            x = x.type_as(self.weight)
 | 
						|
        return self.weight * x
 | 
						|
 | 
						|
 | 
						|
class T5Attention(nn.Module):
 | 
						|
 | 
						|
    def __init__(self, dim, dim_attn, num_heads, dropout=0.1):
 | 
						|
        assert dim_attn % num_heads == 0
 | 
						|
        super(T5Attention, self).__init__()
 | 
						|
        self.dim = dim
 | 
						|
        self.dim_attn = dim_attn
 | 
						|
        self.num_heads = num_heads
 | 
						|
        self.head_dim = dim_attn // num_heads
 | 
						|
 | 
						|
        # layers
 | 
						|
        self.q = nn.Linear(dim, dim_attn, bias=False)
 | 
						|
        self.k = nn.Linear(dim, dim_attn, bias=False)
 | 
						|
        self.v = nn.Linear(dim, dim_attn, bias=False)
 | 
						|
        self.o = nn.Linear(dim_attn, dim, bias=False)
 | 
						|
        self.dropout = nn.Dropout(dropout)
 | 
						|
 | 
						|
    def forward(self, x, context=None, mask=None, pos_bias=None):
 | 
						|
        """
 | 
						|
        x:          [B, L1, C].
 | 
						|
        context:    [B, L2, C] or None.
 | 
						|
        mask:       [B, L2] or [B, L1, L2] or None.
 | 
						|
        """
 | 
						|
        # check inputs
 | 
						|
        context = x if context is None else context
 | 
						|
        b, n, c = x.size(0), self.num_heads, self.head_dim
 | 
						|
 | 
						|
        # compute query, key, value
 | 
						|
        q = self.q(x).view(b, -1, n, c)
 | 
						|
        k = self.k(context).view(b, -1, n, c)
 | 
						|
        v = self.v(context).view(b, -1, n, c)
 | 
						|
 | 
						|
        # attention bias
 | 
						|
        attn_bias = x.new_zeros(b, n, q.size(1), k.size(1))
 | 
						|
        if pos_bias is not None:
 | 
						|
            attn_bias += pos_bias
 | 
						|
        if mask is not None:
 | 
						|
            assert mask.ndim in [2, 3]
 | 
						|
            mask = mask.view(b, 1, 1,
 | 
						|
                             -1) if mask.ndim == 2 else mask.unsqueeze(1)
 | 
						|
            attn_bias.masked_fill_(mask == 0, torch.finfo(x.dtype).min)
 | 
						|
 | 
						|
        # compute attention (T5 does not use scaling)
 | 
						|
        attn = torch.einsum('binc,bjnc->bnij', q, k) + attn_bias
 | 
						|
        attn = F.softmax(attn.float(), dim=-1).type_as(attn)
 | 
						|
        x = torch.einsum('bnij,bjnc->binc', attn, v)
 | 
						|
 | 
						|
        # output
 | 
						|
        x = x.reshape(b, -1, n * c)
 | 
						|
        x = self.o(x)
 | 
						|
        x = self.dropout(x)
 | 
						|
        return x
 | 
						|
 | 
						|
 | 
						|
class T5FeedForward(nn.Module):
 | 
						|
 | 
						|
    def __init__(self, dim, dim_ffn, dropout=0.1):
 | 
						|
        super(T5FeedForward, self).__init__()
 | 
						|
        self.dim = dim
 | 
						|
        self.dim_ffn = dim_ffn
 | 
						|
 | 
						|
        # layers
 | 
						|
        self.gate = nn.Sequential(nn.Linear(dim, dim_ffn, bias=False), GELU())
 | 
						|
        self.fc1 = nn.Linear(dim, dim_ffn, bias=False)
 | 
						|
        self.fc2 = nn.Linear(dim_ffn, dim, bias=False)
 | 
						|
        self.dropout = nn.Dropout(dropout)
 | 
						|
 | 
						|
    def forward(self, x):
 | 
						|
        x = self.fc1(x) * self.gate(x)
 | 
						|
        x = self.dropout(x)
 | 
						|
        x = self.fc2(x)
 | 
						|
        x = self.dropout(x)
 | 
						|
        return x
 | 
						|
 | 
						|
 | 
						|
class T5SelfAttention(nn.Module):
 | 
						|
 | 
						|
    def __init__(self,
 | 
						|
                 dim,
 | 
						|
                 dim_attn,
 | 
						|
                 dim_ffn,
 | 
						|
                 num_heads,
 | 
						|
                 num_buckets,
 | 
						|
                 shared_pos=True,
 | 
						|
                 dropout=0.1):
 | 
						|
        super(T5SelfAttention, self).__init__()
 | 
						|
        self.dim = dim
 | 
						|
        self.dim_attn = dim_attn
 | 
						|
        self.dim_ffn = dim_ffn
 | 
						|
        self.num_heads = num_heads
 | 
						|
        self.num_buckets = num_buckets
 | 
						|
        self.shared_pos = shared_pos
 | 
						|
 | 
						|
        # layers
 | 
						|
        self.norm1 = T5LayerNorm(dim)
 | 
						|
        self.attn = T5Attention(dim, dim_attn, num_heads, dropout)
 | 
						|
        self.norm2 = T5LayerNorm(dim)
 | 
						|
        self.ffn = T5FeedForward(dim, dim_ffn, dropout)
 | 
						|
        self.pos_embedding = None if shared_pos else T5RelativeEmbedding(
 | 
						|
            num_buckets, num_heads, bidirectional=True)
 | 
						|
 | 
						|
    def forward(self, x, mask=None, pos_bias=None):
 | 
						|
        e = pos_bias if self.shared_pos else self.pos_embedding(
 | 
						|
            x.size(1), x.size(1))
 | 
						|
        x = fp16_clamp(x + self.attn(self.norm1(x), mask=mask, pos_bias=e))
 | 
						|
        x = fp16_clamp(x + self.ffn(self.norm2(x)))
 | 
						|
        return x
 | 
						|
 | 
						|
 | 
						|
class T5CrossAttention(nn.Module):
 | 
						|
 | 
						|
    def __init__(self,
 | 
						|
                 dim,
 | 
						|
                 dim_attn,
 | 
						|
                 dim_ffn,
 | 
						|
                 num_heads,
 | 
						|
                 num_buckets,
 | 
						|
                 shared_pos=True,
 | 
						|
                 dropout=0.1):
 | 
						|
        super(T5CrossAttention, self).__init__()
 | 
						|
        self.dim = dim
 | 
						|
        self.dim_attn = dim_attn
 | 
						|
        self.dim_ffn = dim_ffn
 | 
						|
        self.num_heads = num_heads
 | 
						|
        self.num_buckets = num_buckets
 | 
						|
        self.shared_pos = shared_pos
 | 
						|
 | 
						|
        # layers
 | 
						|
        self.norm1 = T5LayerNorm(dim)
 | 
						|
        self.self_attn = T5Attention(dim, dim_attn, num_heads, dropout)
 | 
						|
        self.norm2 = T5LayerNorm(dim)
 | 
						|
        self.cross_attn = T5Attention(dim, dim_attn, num_heads, dropout)
 | 
						|
        self.norm3 = T5LayerNorm(dim)
 | 
						|
        self.ffn = T5FeedForward(dim, dim_ffn, dropout)
 | 
						|
        self.pos_embedding = None if shared_pos else T5RelativeEmbedding(
 | 
						|
            num_buckets, num_heads, bidirectional=False)
 | 
						|
 | 
						|
    def forward(self,
 | 
						|
                x,
 | 
						|
                mask=None,
 | 
						|
                encoder_states=None,
 | 
						|
                encoder_mask=None,
 | 
						|
                pos_bias=None):
 | 
						|
        e = pos_bias if self.shared_pos else self.pos_embedding(
 | 
						|
            x.size(1), x.size(1))
 | 
						|
        x = fp16_clamp(x + self.self_attn(self.norm1(x), mask=mask, pos_bias=e))
 | 
						|
        x = fp16_clamp(x + self.cross_attn(
 | 
						|
            self.norm2(x), context=encoder_states, mask=encoder_mask))
 | 
						|
        x = fp16_clamp(x + self.ffn(self.norm3(x)))
 | 
						|
        return x
 | 
						|
 | 
						|
 | 
						|
class T5RelativeEmbedding(nn.Module):
 | 
						|
 | 
						|
    def __init__(self, num_buckets, num_heads, bidirectional, max_dist=128):
 | 
						|
        super(T5RelativeEmbedding, self).__init__()
 | 
						|
        self.num_buckets = num_buckets
 | 
						|
        self.num_heads = num_heads
 | 
						|
        self.bidirectional = bidirectional
 | 
						|
        self.max_dist = max_dist
 | 
						|
 | 
						|
        # layers
 | 
						|
        self.embedding = nn.Embedding(num_buckets, num_heads)
 | 
						|
 | 
						|
    def forward(self, lq, lk):
 | 
						|
        device = self.embedding.weight.device
 | 
						|
        # rel_pos = torch.arange(lk).unsqueeze(0).to(device) - \
 | 
						|
        #     torch.arange(lq).unsqueeze(1).to(device)
 | 
						|
        rel_pos = torch.arange(lk, device=device).unsqueeze(0) - \
 | 
						|
            torch.arange(lq, device=device).unsqueeze(1)
 | 
						|
        rel_pos = self._relative_position_bucket(rel_pos)
 | 
						|
        rel_pos_embeds = self.embedding(rel_pos)
 | 
						|
        rel_pos_embeds = rel_pos_embeds.permute(2, 0, 1).unsqueeze(
 | 
						|
            0)  # [1, N, Lq, Lk]
 | 
						|
        return rel_pos_embeds.contiguous()
 | 
						|
 | 
						|
    def _relative_position_bucket(self, rel_pos):
 | 
						|
        # preprocess
 | 
						|
        if self.bidirectional:
 | 
						|
            num_buckets = self.num_buckets // 2
 | 
						|
            rel_buckets = (rel_pos > 0).long() * num_buckets
 | 
						|
            rel_pos = torch.abs(rel_pos)
 | 
						|
        else:
 | 
						|
            num_buckets = self.num_buckets
 | 
						|
            rel_buckets = 0
 | 
						|
            rel_pos = -torch.min(rel_pos, torch.zeros_like(rel_pos))
 | 
						|
 | 
						|
        # embeddings for small and large positions
 | 
						|
        max_exact = num_buckets // 2
 | 
						|
        rel_pos_large = max_exact + (torch.log(rel_pos.float() / max_exact) /
 | 
						|
                                     math.log(self.max_dist / max_exact) *
 | 
						|
                                     (num_buckets - max_exact)).long()
 | 
						|
        rel_pos_large = torch.min(
 | 
						|
            rel_pos_large, torch.full_like(rel_pos_large, num_buckets - 1))
 | 
						|
        rel_buckets += torch.where(rel_pos < max_exact, rel_pos, rel_pos_large)
 | 
						|
        return rel_buckets
 | 
						|
 | 
						|
 | 
						|
class T5Encoder(nn.Module):
 | 
						|
 | 
						|
    def __init__(self,
 | 
						|
                 vocab,
 | 
						|
                 dim,
 | 
						|
                 dim_attn,
 | 
						|
                 dim_ffn,
 | 
						|
                 num_heads,
 | 
						|
                 num_layers,
 | 
						|
                 num_buckets,
 | 
						|
                 shared_pos=True,
 | 
						|
                 dropout=0.1):
 | 
						|
        super(T5Encoder, self).__init__()
 | 
						|
        self.dim = dim
 | 
						|
        self.dim_attn = dim_attn
 | 
						|
        self.dim_ffn = dim_ffn
 | 
						|
        self.num_heads = num_heads
 | 
						|
        self.num_layers = num_layers
 | 
						|
        self.num_buckets = num_buckets
 | 
						|
        self.shared_pos = shared_pos
 | 
						|
 | 
						|
        # layers
 | 
						|
        self.token_embedding = vocab if isinstance(vocab, nn.Embedding) \
 | 
						|
            else nn.Embedding(vocab, dim)
 | 
						|
        self.pos_embedding = T5RelativeEmbedding(
 | 
						|
            num_buckets, num_heads, bidirectional=True) if shared_pos else None
 | 
						|
        self.dropout = nn.Dropout(dropout)
 | 
						|
        self.blocks = nn.ModuleList([
 | 
						|
            T5SelfAttention(dim, dim_attn, dim_ffn, num_heads, num_buckets,
 | 
						|
                            shared_pos, dropout) for _ in range(num_layers)
 | 
						|
        ])
 | 
						|
        self.norm = T5LayerNorm(dim)
 | 
						|
 | 
						|
        # initialize weights
 | 
						|
        self.apply(init_weights)
 | 
						|
 | 
						|
    def forward(self, ids, mask=None):
 | 
						|
        x = self.token_embedding(ids)
 | 
						|
        x = self.dropout(x)
 | 
						|
        e = self.pos_embedding(x.size(1),
 | 
						|
                               x.size(1)) if self.shared_pos else None
 | 
						|
        for block in self.blocks:
 | 
						|
            x = block(x, mask, pos_bias=e)
 | 
						|
        x = self.norm(x)
 | 
						|
        x = self.dropout(x)
 | 
						|
        return x
 | 
						|
 | 
						|
 | 
						|
class T5Decoder(nn.Module):
 | 
						|
 | 
						|
    def __init__(self,
 | 
						|
                 vocab,
 | 
						|
                 dim,
 | 
						|
                 dim_attn,
 | 
						|
                 dim_ffn,
 | 
						|
                 num_heads,
 | 
						|
                 num_layers,
 | 
						|
                 num_buckets,
 | 
						|
                 shared_pos=True,
 | 
						|
                 dropout=0.1):
 | 
						|
        super(T5Decoder, self).__init__()
 | 
						|
        self.dim = dim
 | 
						|
        self.dim_attn = dim_attn
 | 
						|
        self.dim_ffn = dim_ffn
 | 
						|
        self.num_heads = num_heads
 | 
						|
        self.num_layers = num_layers
 | 
						|
        self.num_buckets = num_buckets
 | 
						|
        self.shared_pos = shared_pos
 | 
						|
 | 
						|
        # layers
 | 
						|
        self.token_embedding = vocab if isinstance(vocab, nn.Embedding) \
 | 
						|
            else nn.Embedding(vocab, dim)
 | 
						|
        self.pos_embedding = T5RelativeEmbedding(
 | 
						|
            num_buckets, num_heads, bidirectional=False) if shared_pos else None
 | 
						|
        self.dropout = nn.Dropout(dropout)
 | 
						|
        self.blocks = nn.ModuleList([
 | 
						|
            T5CrossAttention(dim, dim_attn, dim_ffn, num_heads, num_buckets,
 | 
						|
                             shared_pos, dropout) for _ in range(num_layers)
 | 
						|
        ])
 | 
						|
        self.norm = T5LayerNorm(dim)
 | 
						|
 | 
						|
        # initialize weights
 | 
						|
        self.apply(init_weights)
 | 
						|
 | 
						|
    def forward(self, ids, mask=None, encoder_states=None, encoder_mask=None):
 | 
						|
        b, s = ids.size()
 | 
						|
 | 
						|
        # causal mask
 | 
						|
        if mask is None:
 | 
						|
            mask = torch.tril(torch.ones(1, s, s).to(ids.device))
 | 
						|
        elif mask.ndim == 2:
 | 
						|
            mask = torch.tril(mask.unsqueeze(1).expand(-1, s, -1))
 | 
						|
 | 
						|
        # layers
 | 
						|
        x = self.token_embedding(ids)
 | 
						|
        x = self.dropout(x)
 | 
						|
        e = self.pos_embedding(x.size(1),
 | 
						|
                               x.size(1)) if self.shared_pos else None
 | 
						|
        for block in self.blocks:
 | 
						|
            x = block(x, mask, encoder_states, encoder_mask, pos_bias=e)
 | 
						|
        x = self.norm(x)
 | 
						|
        x = self.dropout(x)
 | 
						|
        return x
 | 
						|
 | 
						|
 | 
						|
class T5Model(nn.Module):
 | 
						|
 | 
						|
    def __init__(self,
 | 
						|
                 vocab_size,
 | 
						|
                 dim,
 | 
						|
                 dim_attn,
 | 
						|
                 dim_ffn,
 | 
						|
                 num_heads,
 | 
						|
                 encoder_layers,
 | 
						|
                 decoder_layers,
 | 
						|
                 num_buckets,
 | 
						|
                 shared_pos=True,
 | 
						|
                 dropout=0.1):
 | 
						|
        super(T5Model, self).__init__()
 | 
						|
        self.vocab_size = vocab_size
 | 
						|
        self.dim = dim
 | 
						|
        self.dim_attn = dim_attn
 | 
						|
        self.dim_ffn = dim_ffn
 | 
						|
        self.num_heads = num_heads
 | 
						|
        self.encoder_layers = encoder_layers
 | 
						|
        self.decoder_layers = decoder_layers
 | 
						|
        self.num_buckets = num_buckets
 | 
						|
 | 
						|
        # layers
 | 
						|
        self.token_embedding = nn.Embedding(vocab_size, dim)
 | 
						|
        self.encoder = T5Encoder(self.token_embedding, dim, dim_attn, dim_ffn,
 | 
						|
                                 num_heads, encoder_layers, num_buckets,
 | 
						|
                                 shared_pos, dropout)
 | 
						|
        self.decoder = T5Decoder(self.token_embedding, dim, dim_attn, dim_ffn,
 | 
						|
                                 num_heads, decoder_layers, num_buckets,
 | 
						|
                                 shared_pos, dropout)
 | 
						|
        self.head = nn.Linear(dim, vocab_size, bias=False)
 | 
						|
 | 
						|
        # initialize weights
 | 
						|
        self.apply(init_weights)
 | 
						|
 | 
						|
    def forward(self, encoder_ids, encoder_mask, decoder_ids, decoder_mask):
 | 
						|
        x = self.encoder(encoder_ids, encoder_mask)
 | 
						|
        x = self.decoder(decoder_ids, decoder_mask, x, encoder_mask)
 | 
						|
        x = self.head(x)
 | 
						|
        return x
 | 
						|
 | 
						|
 | 
						|
def _t5(name,
 | 
						|
        encoder_only=False,
 | 
						|
        decoder_only=False,
 | 
						|
        return_tokenizer=False,
 | 
						|
        tokenizer_kwargs={},
 | 
						|
        dtype=torch.float32,
 | 
						|
        device='cpu',
 | 
						|
        **kwargs):
 | 
						|
    # sanity check
 | 
						|
    assert not (encoder_only and decoder_only)
 | 
						|
 | 
						|
    # params
 | 
						|
    if encoder_only:
 | 
						|
        model_cls = T5Encoder
 | 
						|
        kwargs['vocab'] = kwargs.pop('vocab_size')
 | 
						|
        kwargs['num_layers'] = kwargs.pop('encoder_layers')
 | 
						|
        _ = kwargs.pop('decoder_layers')
 | 
						|
    elif decoder_only:
 | 
						|
        model_cls = T5Decoder
 | 
						|
        kwargs['vocab'] = kwargs.pop('vocab_size')
 | 
						|
        kwargs['num_layers'] = kwargs.pop('decoder_layers')
 | 
						|
        _ = kwargs.pop('encoder_layers')
 | 
						|
    else:
 | 
						|
        model_cls = T5Model
 | 
						|
 | 
						|
    # init model
 | 
						|
    with torch.device(device):
 | 
						|
        model = model_cls(**kwargs)
 | 
						|
 | 
						|
    # set device
 | 
						|
    model = model.to(dtype=dtype, device=device)
 | 
						|
 | 
						|
    # init tokenizer
 | 
						|
    if return_tokenizer:
 | 
						|
        from .tokenizers import HuggingfaceTokenizer
 | 
						|
        tokenizer = HuggingfaceTokenizer(f'google/{name}', **tokenizer_kwargs)
 | 
						|
        return model, tokenizer
 | 
						|
    else:
 | 
						|
        return model
 | 
						|
 | 
						|
 | 
						|
def umt5_xxl(**kwargs):
 | 
						|
    cfg = dict(
 | 
						|
        vocab_size=256384,
 | 
						|
        dim=4096,
 | 
						|
        dim_attn=4096,
 | 
						|
        dim_ffn=10240,
 | 
						|
        num_heads=64,
 | 
						|
        encoder_layers=24,
 | 
						|
        decoder_layers=24,
 | 
						|
        num_buckets=32,
 | 
						|
        shared_pos=False,
 | 
						|
        dropout=0.1)
 | 
						|
    cfg.update(**kwargs)
 | 
						|
    return _t5('umt5-xxl', **cfg)
 | 
						|
 | 
						|
 | 
						|
class T5EncoderModel:
 | 
						|
 | 
						|
    def __init__(
 | 
						|
        self,
 | 
						|
        text_len,
 | 
						|
        dtype=torch.bfloat16,
 | 
						|
        device=torch.cuda.current_device(),
 | 
						|
        checkpoint_path=None,
 | 
						|
        tokenizer_path=None,
 | 
						|
        shard_fn=None,
 | 
						|
    ):
 | 
						|
        self.text_len = text_len
 | 
						|
        self.dtype = dtype
 | 
						|
        self.device = device
 | 
						|
        self.checkpoint_path = checkpoint_path
 | 
						|
        self.tokenizer_path = tokenizer_path
 | 
						|
 | 
						|
        # init model
 | 
						|
        model = umt5_xxl(
 | 
						|
            encoder_only=True,
 | 
						|
            return_tokenizer=False,
 | 
						|
            dtype=dtype,
 | 
						|
            device=device).eval().requires_grad_(False)
 | 
						|
        logging.info(f'loading {checkpoint_path}')
 | 
						|
        model.load_state_dict(torch.load(checkpoint_path, map_location='cpu'))
 | 
						|
        self.model = model
 | 
						|
        if shard_fn is not None:
 | 
						|
            self.model = shard_fn(self.model, sync_module_states=False)
 | 
						|
        else:
 | 
						|
            self.model.to(self.device)
 | 
						|
        # init tokenizer
 | 
						|
        self.tokenizer = HuggingfaceTokenizer(
 | 
						|
            name=tokenizer_path, seq_len=text_len, clean='whitespace')
 | 
						|
 | 
						|
    def __call__(self, texts, device):
 | 
						|
        ids, mask = self.tokenizer(
 | 
						|
            texts, return_mask=True, add_special_tokens=True)
 | 
						|
        ids = ids.to(device)
 | 
						|
        mask = mask.to(device)
 | 
						|
        seq_lens = mask.gt(0).sum(dim=1).long()
 | 
						|
        context = self.model(ids, mask)
 | 
						|
        return [u[:v] for u, v in zip(context, seq_lens)]
 |