mirror of
				https://github.com/Wan-Video/Wan2.1.git
				synced 2025-11-04 06:15:17 +00:00 
			
		
		
		
	
		
			
				
	
	
		
			152 lines
		
	
	
		
			4.0 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			152 lines
		
	
	
		
			4.0 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
import torch.nn as nn
 | 
						|
 | 
						|
 | 
						|
def _make_scratch(in_shape, out_shape, groups=1, expand=False):
 | 
						|
    scratch = nn.Module()
 | 
						|
 | 
						|
    out_shape1 = out_shape
 | 
						|
    out_shape2 = out_shape
 | 
						|
    out_shape3 = out_shape
 | 
						|
    if len(in_shape) >= 4:
 | 
						|
        out_shape4 = out_shape
 | 
						|
 | 
						|
    if expand:
 | 
						|
        out_shape1 = out_shape
 | 
						|
        out_shape2 = out_shape * 2
 | 
						|
        out_shape3 = out_shape * 4
 | 
						|
        if len(in_shape) >= 4:
 | 
						|
            out_shape4 = out_shape * 8
 | 
						|
 | 
						|
    scratch.layer1_rn = nn.Conv2d(in_shape[0], out_shape1, kernel_size=3, stride=1, padding=1, bias=False,
 | 
						|
                                  groups=groups)
 | 
						|
    scratch.layer2_rn = nn.Conv2d(in_shape[1], out_shape2, kernel_size=3, stride=1, padding=1, bias=False,
 | 
						|
                                  groups=groups)
 | 
						|
    scratch.layer3_rn = nn.Conv2d(in_shape[2], out_shape3, kernel_size=3, stride=1, padding=1, bias=False,
 | 
						|
                                  groups=groups)
 | 
						|
    if len(in_shape) >= 4:
 | 
						|
        scratch.layer4_rn = nn.Conv2d(in_shape[3], out_shape4, kernel_size=3, stride=1, padding=1, bias=False,
 | 
						|
                                      groups=groups)
 | 
						|
 | 
						|
    return scratch
 | 
						|
 | 
						|
 | 
						|
class ResidualConvUnit(nn.Module):
    """Residual convolution module.

    Pre-activation residual block: two 3x3 convolutions (each optionally
    followed by batch norm), with the input added back at the end:
    out = x + conv2(act(conv1(act(x)))).
    """

    def __init__(self, features, activation, bn):
        """Init.

        Args:
            features (int): number of features (input channels == output channels)
            activation: activation module applied before each convolution
            bn (bool): insert BatchNorm2d after each convolution
        """
        super().__init__()

        self.bn = bn

        self.groups = 1

        self.conv1 = nn.Conv2d(features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups)
        self.conv2 = nn.Conv2d(features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups)

        if self.bn:
            self.bn1 = nn.BatchNorm2d(features)
            self.bn2 = nn.BatchNorm2d(features)

        self.activation = activation

        # Quantization-friendly addition for the residual connection.
        self.skip_add = nn.quantized.FloatFunctional()

    def forward(self, x):
        """Forward pass.

        Args:
            x (tensor): input of shape (N, features, H, W)

        Returns:
            tensor: output, same shape as the input
        """

        out = self.activation(x)
        out = self.conv1(out)
        if self.bn:
            out = self.bn1(out)

        out = self.activation(out)
        out = self.conv2(out)
        if self.bn:
            out = self.bn2(out)

        # NOTE(review): self.groups is hardcoded to 1 above, so this branch is
        # unreachable; self.conv_merge is never defined anywhere and would
        # raise AttributeError if groups were ever set > 1. Kept for fidelity
        # with the upstream implementation.
        if self.groups > 1:
            out = self.conv_merge(out)

        return self.skip_add.add(out, x)
 | 
						|
 | 
						|
 | 
						|
class FeatureFusionBlock(nn.Module):
    """Feature fusion block.

    Fuses a decoder feature map with an optional skip connection (each
    passed through a ResidualConvUnit), upsamples the result bilinearly
    (2x by default, or to an explicit target size), and projects it with
    a 1x1 convolution.
    """

    def __init__(
            self,
            features,
            activation,
            deconv=False,
            bn=False,
            expand=False,
            align_corners=True,
            size=None
    ):
        """Init.

        Args:
            features (int): number of features
            activation: activation module used inside the residual units
            deconv (bool): stored but unused here; kept for interface compatibility
            bn (bool): use batch norm inside the residual units
            expand (bool): halve the channel count in the output projection
            align_corners (bool): forwarded to the bilinear interpolation
            size (tuple or None): default output spatial size for forward()
        """
        super().__init__()

        self.deconv = deconv
        self.align_corners = align_corners

        self.groups = 1

        self.expand = expand
        out_features = features
        if self.expand:
            out_features = features // 2

        self.out_conv = nn.Conv2d(features, out_features, kernel_size=1, stride=1, padding=0, bias=True, groups=1)

        self.resConfUnit1 = ResidualConvUnit(features, activation, bn)
        self.resConfUnit2 = ResidualConvUnit(features, activation, bn)

        # Quantization-friendly addition for the skip connection.
        self.skip_add = nn.quantized.FloatFunctional()

        self.size = size

    def forward(self, *xs, size=None):
        """Forward pass.

        Args:
            *xs: one or two tensors; xs[0] is the main input, xs[1] (when
                given) is the skip-connection input fused into it.
            size: explicit output spatial size; overrides self.size.

        Returns:
            tensor: output
        """
        output = xs[0]

        if len(xs) == 2:
            res = self.resConfUnit1(xs[1])
            output = self.skip_add.add(output, res)

        output = self.resConfUnit2(output)

        # Upsampling target priority: explicit `size` argument, then the
        # size configured at construction, else a fixed 2x upscale.
        if (size is None) and (self.size is None):
            modifier = {"scale_factor": 2}
        elif size is None:
            modifier = {"size": self.size}
        else:
            modifier = {"size": size}

        output = nn.functional.interpolate(output, **modifier, mode="bilinear", align_corners=self.align_corners)
        output = self.out_conv(output)

        return output
 |