mirror of
https://github.com/Wan-Video/Wan2.1.git
synced 2025-11-04 14:16:57 +00:00
24 lines
591 B
Python
24 lines
591 B
Python
import torch.nn as nn
|
|
|
|
|
|
def get_activation_layer(act_type):
    """Return a factory for the activation layer named by ``act_type``.

    The returned object is not an activation instance: it is a callable that
    takes no arguments and constructs a fresh ``torch.nn`` activation module
    each time it is called, e.g. ``get_activation_layer("relu")()``.

    Args:
        act_type (str): the activation type. Supported values are
            ``"gelu"``, ``"gelu_tanh"``, ``"relu"`` and ``"silu"``.

    Returns:
        callable: a zero-argument callable (a module class or a small
        wrapper) that builds the corresponding activation module.

    Raises:
        ValueError: if ``act_type`` is not one of the supported values.
    """
    # Zero-argument factories keyed by activation name. Plain classes are used
    # where the default constructor is enough; a wrapper is only needed when a
    # keyword argument has to be bound.
    factories = {
        "gelu": nn.GELU,
        # Approximate `tanh` requires torch >= 1.13
        "gelu_tanh": lambda: nn.GELU(approximate="tanh"),
        "relu": nn.ReLU,
        "silu": nn.SiLU,
    }
    try:
        return factories[act_type]
    except (KeyError, TypeError):
        # TypeError covers unhashable inputs (e.g. a list) so that every
        # unsupported value is reported with the same ValueError.
        raise ValueError(f"Unknown activation type: {act_type}") from None
|