from typing import Optional

from einops import rearrange
import torch
import torch.nn as nn

from .activation_layers import get_activation_layer
from .attenion import attention
from .norm_layers import get_norm_layer
from .embed_layers import TimestepEmbedder, TextProjection
from .mlp_layers import MLP
from .modulate_layers import modulate, apply_gate


class IndividualTokenRefinerBlock(nn.Module):
    def __init__(
        self,
        hidden_size,
        heads_num,
        mlp_width_ratio: float = 4.0,
        mlp_drop_rate: float = 0.0,
        act_type: str = "silu",
        qk_norm: bool = False,
        qk_norm_type: str = "layer",
        qkv_bias: bool = True,
        dtype: Optional[torch.dtype] = None,
        device: Optional[torch.device] = None,
    ):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.heads_num = heads_num
        head_dim = hidden_size // heads_num
        mlp_hidden_dim = int(hidden_size * mlp_width_ratio)

        self.norm1 = nn.LayerNorm(
            hidden_size, elementwise_affine=True, eps=1e-6, **factory_kwargs
        )
        self.self_attn_qkv = nn.Linear(
            hidden_size, hidden_size * 3, bias=qkv_bias, **factory_kwargs
        )
        qk_norm_layer = get_norm_layer(qk_norm_type)
        self.self_attn_q_norm = (
            qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs)
            if qk_norm
            else nn.Identity()
        )
        self.self_attn_k_norm = (
            qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs)
            if qk_norm
            else nn.Identity()
        )
        self.self_attn_proj = nn.Linear(
            hidden_size, hidden_size, bias=qkv_bias, **factory_kwargs
        )

        self.norm2 = nn.LayerNorm(
            hidden_size, elementwise_affine=True, eps=1e-6, **factory_kwargs
        )
        act_layer = get_activation_layer(act_type)
        self.mlp = MLP(
            in_channels=hidden_size,
            hidden_channels=mlp_hidden_dim,
            act_layer=act_layer,
            drop=mlp_drop_rate,
            **factory_kwargs,
        )

        self.adaLN_modulation = nn.Sequential(
            act_layer(),
            nn.Linear(hidden_size, 2 * hidden_size, bias=True, **factory_kwargs),
        )
        # Zero-initialize the modulation so the block starts out as an identity mapping
        nn.init.zeros_(self.adaLN_modulation[1].weight)
        nn.init.zeros_(self.adaLN_modulation[1].bias)

    def forward(
        self,
        x: torch.Tensor,
        c: torch.Tensor,  # timestep_aware_representations + context_aware_representations
        attn_mask: Optional[torch.Tensor] = None,
    ):
        gate_msa, gate_mlp = self.adaLN_modulation(c).chunk(2, dim=1)

        norm_x = self.norm1(x)
        qkv = self.self_attn_qkv(norm_x)
        q, k, v = rearrange(qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num)
        # Apply QK-Norm if needed
        q = self.self_attn_q_norm(q).to(v)
        k = self.self_attn_k_norm(k).to(v)
        qkv_list = [q, k, v]
        del q, k
        # Self-Attention
        attn = attention(qkv_list, mode="torch", attn_mask=attn_mask)

        x = x + apply_gate(self.self_attn_proj(attn), gate_msa)

        # FFN Layer
        x = x + apply_gate(self.mlp(self.norm2(x)), gate_mlp)

        return x


class IndividualTokenRefiner(nn.Module):
    def __init__(
        self,
        hidden_size,
        heads_num,
        depth,
        mlp_width_ratio: float = 4.0,
        mlp_drop_rate: float = 0.0,
        act_type: str = "silu",
        qk_norm: bool = False,
        qk_norm_type: str = "layer",
        qkv_bias: bool = True,
        dtype: Optional[torch.dtype] = None,
        device: Optional[torch.device] = None,
    ):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.blocks = nn.ModuleList(
            [
                IndividualTokenRefinerBlock(
                    hidden_size=hidden_size,
                    heads_num=heads_num,
                    mlp_width_ratio=mlp_width_ratio,
                    mlp_drop_rate=mlp_drop_rate,
                    act_type=act_type,
                    qk_norm=qk_norm,
                    qk_norm_type=qk_norm_type,
                    qkv_bias=qkv_bias,
                    **factory_kwargs,
                )
                for _ in range(depth)
            ]
        )

    def forward(
        self,
        x: torch.Tensor,
        c: torch.LongTensor,
        mask: Optional[torch.Tensor] = None,
    ):
        self_attn_mask = None
        if mask is not None:
            batch_size = mask.shape[0]
            seq_len = mask.shape[1]
            mask = mask.to(x.device)
            # batch_size x 1 x seq_len x seq_len
            self_attn_mask_1 = mask.view(batch_size, 1, 1, seq_len).repeat(
                1, 1, seq_len, 1
            )
            # batch_size x 1 x seq_len x seq_len
            self_attn_mask_2 = self_attn_mask_1.transpose(2, 3)
            # batch_size x 1 x seq_len x seq_len, 1 for broadcasting of heads_num
            self_attn_mask = (self_attn_mask_1 & self_attn_mask_2).bool()
            # avoids self-attention weight being NaN for padding tokens
            self_attn_mask[:, :, :, 0] = True

        for block in self.blocks:
            x = block(x, c, self_attn_mask)
        return x


class SingleTokenRefiner(nn.Module):
    """
    A token refiner for refining LLM text embeddings, conditioned on the timestep.
    """
    def __init__(
        self,
        in_channels,
        hidden_size,
        heads_num,
        depth,
        mlp_width_ratio: float = 4.0,
        mlp_drop_rate: float = 0.0,
        act_type: str = "silu",
        qk_norm: bool = False,
        qk_norm_type: str = "layer",
        qkv_bias: bool = True,
        attn_mode: str = "torch",
        dtype: Optional[torch.dtype] = None,
        device: Optional[torch.device] = None,
    ):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.attn_mode = attn_mode
        assert self.attn_mode == "torch", "Only support 'torch' mode for token refiner."

        self.input_embedder = nn.Linear(
            in_channels, hidden_size, bias=True, **factory_kwargs
        )

        act_layer = get_activation_layer(act_type)
        # Build timestep embedding layer
        self.t_embedder = TimestepEmbedder(hidden_size, act_layer, **factory_kwargs)
        # Build context embedding layer
        self.c_embedder = TextProjection(
            in_channels, hidden_size, act_layer, **factory_kwargs
        )

        self.individual_token_refiner = IndividualTokenRefiner(
            hidden_size=hidden_size,
            heads_num=heads_num,
            depth=depth,
            mlp_width_ratio=mlp_width_ratio,
            mlp_drop_rate=mlp_drop_rate,
            act_type=act_type,
            qk_norm=qk_norm,
            qk_norm_type=qk_norm_type,
            qkv_bias=qkv_bias,
            **factory_kwargs,
        )

    def forward(
        self,
        x: torch.Tensor,
        t: torch.LongTensor,
        mask: Optional[torch.LongTensor] = None,
    ):
        timestep_aware_representations = self.t_embedder(t)

        # Pool the text embeddings into one context vector per sample
        # (masked mean when a padding mask is provided).
        if mask is None:
            context_aware_representations = x.mean(dim=1)
        else:
            mask_float = mask.float().unsqueeze(-1)  # [b, s1, 1]
            context_aware_representations = (x * mask_float).sum(
                dim=1
            ) / mask_float.sum(dim=1)
        context_aware_representations = self.c_embedder(
            context_aware_representations.to(x.dtype)
        )
        c = timestep_aware_representations + context_aware_representations

        x = self.input_embedder(x)

        x = self.individual_token_refiner(x, c, mask)

        return x
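

if __name__ == "__main__":
    # Minimal smoke-test sketch (not part of the original module). All sizes below are
    # illustrative assumptions, not the real model configuration. Because of the relative
    # imports above, this only runs when the file is executed as part of its package
    # (e.g. `python -m <package>.token_refiner`).
    torch.manual_seed(0)

    refiner = SingleTokenRefiner(
        in_channels=256,  # assumed LLM embedding width, illustrative only
        hidden_size=128,
        heads_num=4,
        depth=2,
    )

    batch_size, seq_len = 2, 16
    x = torch.randn(batch_size, seq_len, 256)           # dummy text embeddings
    t = torch.randint(0, 1000, (batch_size,))           # dummy diffusion timesteps
    mask = torch.ones(batch_size, seq_len, dtype=torch.long)
    mask[:, -4:] = 0                                     # mark the last tokens as padding

    out = refiner(x, t, mask)
    print(out.shape)  # expected: torch.Size([2, 16, 128])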