# Mirror of https://github.com/Wan-Video/Wan2.1.git (synced 2025-11-04 06:15:17 +00:00).
# File info: 29 lines, 813 B, Python.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# Modified from: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py

from typing import Union
 | 
						|
 | 
						|
import torch
 | 
						|
from torch import Tensor
 | 
						|
from torch import nn
 | 
						|
 | 
						|
 | 
						|
class LayerScale(nn.Module):
    """Multiply the input by a learnable scale vector ``gamma``.

    ``gamma`` is an ``nn.Parameter`` created as ``init_values * torch.ones(dim)``.
    For a float (or a tensor of shape ``()``/``(dim,)``) this gives shape
    ``(dim,)``. ``forward`` multiplies it element-wise into the input, so it
    broadcasts against the input's trailing dimension.

    Args:
        dim: Length of ``gamma``.
        init_values: Initial value(s) for ``gamma``. A float fills every entry.
            A tensor is multiplied by ``torch.ones(dim)`` and must therefore be
            broadcast-compatible with shape ``(dim,)``; ``gamma`` takes the
            resulting broadcast shape.
        inplace: If True, ``forward`` scales the input tensor in place
            (``Tensor.mul_``) and returns that same tensor. Otherwise it
            returns a new tensor and leaves the input untouched.
    """

    def __init__(
        self,
        dim: int,
        init_values: Union[float, Tensor] = 1e-5,
        inplace: bool = False,
    ) -> None:
        super().__init__()
        self.inplace = inplace
        # Multiplying by ones(dim) turns a scalar init value into a full vector.
        self.gamma = nn.Parameter(init_values * torch.ones(dim))

    def forward(self, x: Tensor) -> Tensor:
        """Return ``x * gamma``, computed in place when ``self.inplace`` is set.

        Assumes ``gamma`` broadcasts against ``x`` (typically the trailing
        dimension of ``x`` equals ``dim``). TODO confirm with callers. In the
        in-place path the result must keep ``x``'s shape, because
        ``Tensor.mul_`` cannot grow its receiver.

        NOTE(review): the in-place path overwrites ``x``. Autograd may reject
        this if ``x`` is still needed for the backward pass.
        """
        if self.inplace:
            return x.mul_(self.gamma)
        return x * self.gamma
 |