mirror of
				https://github.com/Wan-Video/Wan2.1.git
				synced 2025-11-04 06:15:17 +00:00 
			
		
		
		
	
		
			
				
	
	
		
			89 lines
		
	
	
		
			2.2 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			89 lines
		
	
	
		
			2.2 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
import torch
 | 
						|
import torch.nn as nn
 | 
						|
 | 
						|
 | 
						|
class RMSNorm(nn.Module):
    """Root-mean-square normalization over the last dimension of a tensor.

    Each vector along the final axis is divided by its root mean square
    (with ``eps`` added under the square root). When affine scaling is
    enabled, the result is also multiplied elementwise by a learnable
    ``weight``.
    """

    def __init__(
        self,
        dim: int,
        elementwise_affine=True,
        eps: float = 1e-6,
        device=None,
        dtype=None,
    ):
        """
        Build an RMSNorm layer.

        Args:
            dim (int): Size of the normalized (last) dimension; also the
                length of ``weight``.
            elementwise_affine (bool, optional): If True, create a learnable
                ``weight`` initialised to ones. If False, no ``weight``
                attribute is created at all. Default is True.
            eps (float, optional): Constant added to the mean of squares
                before the inverse square root, for numerical stability.
                Default is 1e-6.
            device (optional): Device on which ``weight`` is allocated.
            dtype (optional): Data type of ``weight``.

        Attributes:
            eps (float): The stability constant passed in.
            weight (nn.Parameter): Learnable per-channel scale of shape
                ``(dim,)``; only present when ``elementwise_affine`` is True.
        """
        super().__init__()
        self.eps = eps
        # When affine scaling is off the attribute is omitted (not set to
        # None); forward() and apply_() check for it with hasattr().
        if elementwise_affine:
            self.weight = nn.Parameter(torch.ones(dim, device=device, dtype=dtype))

    def _norm(self, x):
        """
        Scale ``x`` by the inverse RMS of its last dimension (no weight).

        Args:
            x (torch.Tensor): Input tensor; the arithmetic runs in its dtype.

        Returns:
            torch.Tensor: ``x / sqrt(mean(x ** 2, dim=-1) + eps)``.
        """
        mean_square = x.pow(2).mean(-1, keepdim=True)
        inv_rms = torch.rsqrt(mean_square + self.eps)
        return x * inv_rms

    def forward(self, x):
        """
        Normalize ``x`` over its last dimension and apply the optional scale.

        The normalization runs in float32 and is then cast back to the input
        dtype. The multiplication by ``weight`` follows PyTorch type promotion,
        so e.g. a float16 input with a float32 ``weight`` yields float32 output.

        Args:
            x (torch.Tensor): Input tensor.

        Returns:
            torch.Tensor: The normalized (and, if enabled, scaled) tensor.
        """
        normalized = self._norm(x.float()).type_as(x)
        if not hasattr(self, "weight"):
            return normalized
        return normalized * self.weight

    def apply_(self, x):
        """
        In-place variant of ``forward``: overwrite ``x`` with its normalized
        (and optionally scaled) values and return it.

        Unlike ``forward``, the statistics are computed in ``x``'s own dtype
        with no float32 upcast. Results can therefore differ for
        low-precision inputs, and squaring large float16 values can overflow.

        Args:
            x (torch.Tensor): Tensor to normalize in place over its last
                dimension.

        Returns:
            torch.Tensor: ``x`` itself, after modification.
        """
        inv_rms = x.pow(2).mean(-1, keepdim=True)
        inv_rms.add_(self.eps).rsqrt_()
        x.mul_(inv_rms)
        # Release the reduction buffer before the optional weight multiply
        # (presumably to keep peak memory down).
        del inv_rms
        if hasattr(self, "weight"):
            x.mul_(self.weight)
        return x
 | 
						|
 | 
						|
 | 
						|
def get_norm_layer(norm_layer):
    """
    Look up a normalization layer class by its short name.

    Args:
        norm_layer (str): Either "layer" (selects ``nn.LayerNorm``) or
            "rms" (selects ``RMSNorm``).

    Returns:
        type: The selected layer class itself, not an instance.

    Raises:
        NotImplementedError: If ``norm_layer`` is any other value.
    """
    if norm_layer == "layer":
        return nn.LayerNorm
    if norm_layer == "rms":
        return RMSNorm
    raise NotImplementedError(f"Norm layer {norm_layer} is not implemented")
 |