# mirror of https://github.com/Wan-Video/Wan2.1.git (synced 2025-11-04 06:15:17 +00:00)
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from typing import Callable, Optional

from torch import Tensor, nn
import torch.nn.functional as F


class SwiGLUFFN(nn.Module):
    """SwiGLU feed-forward network (Shazeer, "GLU Variants Improve Transformer").

    A single fused linear layer projects the input to two hidden tensors; one
    half is passed through SiLU and used to gate the other, and the gated
    result is projected back to the output size.

    Args:
        in_features: Size of each input sample.
        hidden_features: Hidden layer width; defaults to ``in_features``.
        out_features: Size of each output sample; defaults to ``in_features``.
        act_layer: Unused; kept for signature compatibility with FFN classes
            that accept an activation-module factory.
        drop: Unused; kept for signature compatibility.
        bias: Whether both linear layers include a bias term.
    """

    def __init__(
        self,
        in_features: int,
        hidden_features: Optional[int] = None,
        out_features: Optional[int] = None,
        # Fix: the default is None, so the annotation must be Optional
        # (PEP 484 forbids implicit Optional).
        act_layer: Optional[Callable[..., nn.Module]] = None,
        drop: float = 0.0,
        bias: bool = True,
    ) -> None:
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        # One fused projection producing both the gate and value halves.
        self.w12 = nn.Linear(in_features, 2 * hidden_features, bias=bias)
        self.w3 = nn.Linear(hidden_features, out_features, bias=bias)

    def forward(self, x: Tensor) -> Tensor:
        """Return ``w3(silu(x1) * x2)`` where ``x1, x2`` are halves of ``w12(x)``."""
        x12 = self.w12(x)
        x1, x2 = x12.chunk(2, dim=-1)
        hidden = F.silu(x1) * x2
        return self.w3(hidden)


# Prefer xformers' fused SwiGLU implementation when the package is installed;
# otherwise fall back to the pure-PyTorch SwiGLUFFN defined above so the rest
# of the file (and SwiGLUFFNFused below) works either way.
try:
    from xformers.ops import SwiGLU

    XFORMERS_AVAILABLE = True
except ImportError:
    SwiGLU = SwiGLUFFN
    XFORMERS_AVAILABLE = False


class SwiGLUFFNFused(SwiGLU):
    """SwiGLU FFN that uses xformers' fused kernel when available.

    Before delegating to the base class (xformers ``SwiGLU`` or the
    ``SwiGLUFFN`` fallback), the requested hidden width is scaled by ~2/3 —
    a gated FFN has three weight matrices instead of two, so this keeps the
    parameter count comparable to a plain MLP of the same nominal width —
    and rounded up to a multiple of 8 for hardware-friendly shapes.

    Args:
        in_features: Size of each input sample.
        hidden_features: Nominal hidden width (rescaled as described);
            defaults to ``in_features``.
        out_features: Size of each output sample; defaults to ``in_features``.
        act_layer: Unused; kept for signature compatibility with FFN classes
            that accept an activation-module factory.
        drop: Unused; kept for signature compatibility.
        bias: Whether the linear layers include a bias term.
    """

    def __init__(
        self,
        in_features: int,
        hidden_features: Optional[int] = None,
        out_features: Optional[int] = None,
        # Fix: the default is None, so the annotation must be Optional
        # (PEP 484 forbids implicit Optional).
        act_layer: Optional[Callable[..., nn.Module]] = None,
        drop: float = 0.0,
        bias: bool = True,
    ) -> None:
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        # 2/3 rescale keeps parameters on par with a 2-matrix MLP;
        # (+7) // 8 * 8 rounds up to the next multiple of 8.
        hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8
        super().__init__(
            in_features=in_features,
            hidden_features=hidden_features,
            out_features=out_features,
            bias=bias,
        )