from dataclasses import dataclass
from typing import Optional, Tuple

import numpy as np
import torch
import torch.nn as nn

from diffusers.utils import BaseOutput, is_torch_version
from diffusers.utils.torch_utils import randn_tensor
from diffusers.models.attention_processor import SpatialNorm
from .unet_causal_3d_blocks import (
    CausalConv3d,
    UNetMidBlockCausal3D,
    get_down_block3d,
    get_up_block3d,
)


@dataclass
class DecoderOutput(BaseOutput):
    r"""
    Output of decoding method.

    Args:
        sample (`torch.FloatTensor` of shape `(batch_size, num_channels, num_frames, height, width)`):
            The decoded output sample from the last layer of the model.
    """

    sample: torch.FloatTensor


class EncoderCausal3D(nn.Module):
    r"""
    The `EncoderCausal3D` layer of a variational autoencoder that encodes its input into a latent representation.

    Args:
        in_channels (`int`, *optional*, defaults to 3):
            The number of input channels.
        out_channels (`int`, *optional*, defaults to 3):
            The number of output channels.
        down_block_types (`Tuple[str, ...]`, *optional*, defaults to `("DownEncoderBlockCausal3D",)`):
            The types of down blocks to use. See `get_down_block3d` in `unet_causal_3d_blocks` for available options.
        block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(64,)`):
            The number of output channels for each block.
        layers_per_block (`int`, *optional*, defaults to 2):
            The number of layers per block.
        norm_num_groups (`int`, *optional*, defaults to 32):
            The number of groups for normalization.
        act_fn (`str`, *optional*, defaults to `"silu"`):
            The activation function to use. See `~diffusers.models.activations.get_activation` for available options.
        double_z (`bool`, *optional*, defaults to `True`):
            Whether to double the number of output channels for the last block.
        mid_block_add_attention (`bool`, *optional*, defaults to `True`):
            Whether to add attention to the mid block.
        time_compression_ratio (`int`, *optional*, defaults to 4):
            The temporal compression ratio. Must be 4 or 8.
        spatial_compression_ratio (`int`, *optional*, defaults to 8):
            The spatial compression ratio.
        disable_causal (`bool`, *optional*, defaults to `False`):
            Whether to disable causal behavior in the convolutions.
        mid_block_causal_attn (`bool`, *optional*, defaults to `False`):
            Whether to use causal attention in the mid block.
    """

    def __init__(
        self,
        in_channels: int = 3,
        out_channels: int = 3,
        down_block_types: Tuple[str, ...] = ("DownEncoderBlockCausal3D",),
        block_out_channels: Tuple[int, ...] = (64,),
        layers_per_block: int = 2,
        norm_num_groups: int = 32,
        act_fn: str = "silu",
        double_z: bool = True,
        mid_block_add_attention: bool = True,
        time_compression_ratio: int = 4,
        spatial_compression_ratio: int = 8,
        disable_causal: bool = False,
        mid_block_causal_attn: bool = False,
    ):
        super().__init__()
        self.layers_per_block = layers_per_block

        self.conv_in = CausalConv3d(
            in_channels, block_out_channels[0], kernel_size=3, stride=1, disable_causal=disable_causal
        )
        self.mid_block = None
        self.down_blocks = nn.ModuleList([])

        # down
        output_channel = block_out_channels[0]
        for i, down_block_type in enumerate(down_block_types):
            input_channel = output_channel
            output_channel = block_out_channels[i]
            is_final_block = i == len(block_out_channels) - 1
            num_spatial_downsample_layers = int(np.log2(spatial_compression_ratio))
            num_time_downsample_layers = int(np.log2(time_compression_ratio))

            if time_compression_ratio == 4:
                add_spatial_downsample = bool(i < num_spatial_downsample_layers)
                add_time_downsample = bool(
                    i >= (len(block_out_channels) - 1 - num_time_downsample_layers) and not is_final_block
                )
            elif time_compression_ratio == 8:
                add_spatial_downsample = bool(i < num_spatial_downsample_layers)
                add_time_downsample = bool(i < num_time_downsample_layers)
            else:
                raise ValueError(f"Unsupported time_compression_ratio: {time_compression_ratio}")

            downsample_stride_HW = (2, 2) if add_spatial_downsample else (1, 1)
            downsample_stride_T = (2,) if add_time_downsample else (1,)
            downsample_stride = tuple(downsample_stride_T + downsample_stride_HW)
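            # Worked example (added note): with four blocks,
            # spatial_compression_ratio=8 and time_compression_ratio=4, there are
            # log2(8)=3 spatial and log2(4)=2 temporal downsample layers, so the
            # per-block (T, H, W) strides come out as
            #   i=0 -> (1, 2, 2), i=1 -> (2, 2, 2), i=2 -> (2, 2, 2), i=3 -> (1, 1, 1)
            # i.e. temporal downsampling is deferred to the deeper blocks and the
            # final block never downsamples, giving 8x spatial / 4x temporal overall.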
            down_block = get_down_block3d(
                down_block_type,
                num_layers=self.layers_per_block,
                in_channels=input_channel,
                out_channels=output_channel,
                add_downsample=bool(add_spatial_downsample or add_time_downsample),
                downsample_stride=downsample_stride,
                resnet_eps=1e-6,
                downsample_padding=0,
                resnet_act_fn=act_fn,
                resnet_groups=norm_num_groups,
                attention_head_dim=output_channel,
                temb_channels=None,
                disable_causal=disable_causal,
            )
            self.down_blocks.append(down_block)

        # mid
        self.mid_block = UNetMidBlockCausal3D(
            in_channels=block_out_channels[-1],
            resnet_eps=1e-6,
            resnet_act_fn=act_fn,
            output_scale_factor=1,
            resnet_time_scale_shift="default",
            attention_head_dim=block_out_channels[-1],
            resnet_groups=norm_num_groups,
            temb_channels=None,
            add_attention=mid_block_add_attention,
            disable_causal=disable_causal,
            causal_attention=mid_block_causal_attn,
        )

        # out
        self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[-1], num_groups=norm_num_groups, eps=1e-6)
        self.conv_act = nn.SiLU()

        conv_out_channels = 2 * out_channels if double_z else out_channels
        self.conv_out = CausalConv3d(
            block_out_channels[-1], conv_out_channels, kernel_size=3, disable_causal=disable_causal
        )

        self.gradient_checkpointing = False

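    # Added note: with double_z=True the encoder's conv_out emits 2 * out_channels
    # feature maps (mean and logvar stacked on the channel dim), which
    # DiagonalGaussianDistribution below splits with torch.chunk.
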
    def forward(self, sample: torch.FloatTensor) -> torch.FloatTensor:
        r"""The forward method of the `EncoderCausal3D` class."""
        assert len(sample.shape) == 5, "The input tensor should have 5 dimensions"

        sample = self.conv_in(sample)

        if self.training and self.gradient_checkpointing:

            def create_custom_forward(module):
                # Wrap the module so torch.utils.checkpoint can re-run it during
                # the backward pass.
                def custom_forward(*inputs):
                    return module(*inputs)

                return custom_forward

            # down
            if is_torch_version(">=", "1.11.0"):
                for down_block in self.down_blocks:
                    sample = torch.utils.checkpoint.checkpoint(
                        create_custom_forward(down_block), sample, use_reentrant=False
                    )
                # middle
                sample = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(self.mid_block), sample, use_reentrant=False
                )
            else:
                for down_block in self.down_blocks:
                    sample = torch.utils.checkpoint.checkpoint(create_custom_forward(down_block), sample)
                # middle
                sample = torch.utils.checkpoint.checkpoint(create_custom_forward(self.mid_block), sample)

        else:
            # down
            for down_block in self.down_blocks:
                sample = down_block(sample)

            # middle
            sample = self.mid_block(sample)

        # post-process
        sample = self.conv_norm_out(sample)
        sample = self.conv_act(sample)
        sample = self.conv_out(sample)

        return sample


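# Illustrative usage (added sketch, not part of the original module). The channel
# widths and input shape below are assumptions for illustration only:
#
#     encoder = EncoderCausal3D(
#         in_channels=3,
#         out_channels=4,  # latent channels; doubled to 8 when double_z=True
#         down_block_types=("DownEncoderBlockCausal3D",) * 4,
#         block_out_channels=(128, 256, 512, 512),
#         time_compression_ratio=4,
#         spatial_compression_ratio=8,
#     )
#     video = torch.randn(1, 3, 9, 64, 64)  # (B, C, T, H, W)
#     moments = encoder(video)              # mean/logvar stacked on dim 1
#     posterior = DiagonalGaussianDistribution(moments)
#     z = posterior.sample()

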
class DecoderCausal3D(nn.Module):
    r"""
    The `DecoderCausal3D` layer of a variational autoencoder that decodes its latent representation into an output
    sample.

    Args:
        in_channels (`int`, *optional*, defaults to 3):
            The number of input channels.
        out_channels (`int`, *optional*, defaults to 3):
            The number of output channels.
        up_block_types (`Tuple[str, ...]`, *optional*, defaults to `("UpDecoderBlockCausal3D",)`):
            The types of up blocks to use. See `get_up_block3d` in `unet_causal_3d_blocks` for available options.
        block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(64,)`):
            The number of output channels for each block.
        layers_per_block (`int`, *optional*, defaults to 2):
            The number of layers per block.
        norm_num_groups (`int`, *optional*, defaults to 32):
            The number of groups for normalization.
        act_fn (`str`, *optional*, defaults to `"silu"`):
            The activation function to use. See `~diffusers.models.activations.get_activation` for available options.
        norm_type (`str`, *optional*, defaults to `"group"`):
            The normalization type to use. Can be either `"group"` or `"spatial"`.
        mid_block_add_attention (`bool`, *optional*, defaults to `True`):
            Whether to add attention to the mid block.
        time_compression_ratio (`int`, *optional*, defaults to 4):
            The temporal compression ratio. Must be 4 or 8.
        spatial_compression_ratio (`int`, *optional*, defaults to 8):
            The spatial compression ratio.
        disable_causal (`bool`, *optional*, defaults to `False`):
            Whether to disable causal behavior in the convolutions.
        mid_block_causal_attn (`bool`, *optional*, defaults to `False`):
            Whether to use causal attention in the mid block.
    """

    def __init__(
        self,
        in_channels: int = 3,
        out_channels: int = 3,
        up_block_types: Tuple[str, ...] = ("UpDecoderBlockCausal3D",),
        block_out_channels: Tuple[int, ...] = (64,),
        layers_per_block: int = 2,
        norm_num_groups: int = 32,
        act_fn: str = "silu",
        norm_type: str = "group",  # group, spatial
        mid_block_add_attention: bool = True,
        time_compression_ratio: int = 4,
        spatial_compression_ratio: int = 8,
        disable_causal: bool = False,
        mid_block_causal_attn: bool = False,
    ):
        super().__init__()
        self.layers_per_block = layers_per_block

        self.conv_in = CausalConv3d(
            in_channels, block_out_channels[-1], kernel_size=3, stride=1, disable_causal=disable_causal
        )
        self.mid_block = None
        self.up_blocks = nn.ModuleList([])

        temb_channels = in_channels if norm_type == "spatial" else None

        # mid
        self.mid_block = UNetMidBlockCausal3D(
            in_channels=block_out_channels[-1],
            resnet_eps=1e-6,
            resnet_act_fn=act_fn,
            output_scale_factor=1,
            resnet_time_scale_shift="default" if norm_type == "group" else norm_type,
            attention_head_dim=block_out_channels[-1],
            resnet_groups=norm_num_groups,
            temb_channels=temb_channels,
            add_attention=mid_block_add_attention,
            disable_causal=disable_causal,
            causal_attention=mid_block_causal_attn,
        )

        # up
        reversed_block_out_channels = list(reversed(block_out_channels))
        output_channel = reversed_block_out_channels[0]
        for i, up_block_type in enumerate(up_block_types):
            prev_output_channel = output_channel
            output_channel = reversed_block_out_channels[i]
            is_final_block = i == len(block_out_channels) - 1
            num_spatial_upsample_layers = int(np.log2(spatial_compression_ratio))
            num_time_upsample_layers = int(np.log2(time_compression_ratio))

            if time_compression_ratio == 4:
                add_spatial_upsample = bool(i < num_spatial_upsample_layers)
                add_time_upsample = bool(
                    i >= len(block_out_channels) - 1 - num_time_upsample_layers and not is_final_block
                )
            elif time_compression_ratio == 8:
                add_spatial_upsample = bool(i >= len(block_out_channels) - num_spatial_upsample_layers)
                add_time_upsample = bool(i >= len(block_out_channels) - num_time_upsample_layers)
            else:
                raise ValueError(f"Unsupported time_compression_ratio: {time_compression_ratio}")

            upsample_scale_factor_HW = (2, 2) if add_spatial_upsample else (1, 1)
            upsample_scale_factor_T = (2,) if add_time_upsample else (1,)
            upsample_scale_factor = tuple(upsample_scale_factor_T + upsample_scale_factor_HW)
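            # Worked example (added note): for four blocks with
            # time_compression_ratio=4 this schedule mirrors the encoder's, so the
            # per-block (T, H, W) scale factors are (1, 2, 2), (2, 2, 2),
            # (2, 2, 2), (1, 1, 1), and the decoder restores the full 8x spatial /
            # 4x temporal resolution.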
            up_block = get_up_block3d(
                up_block_type,
                num_layers=self.layers_per_block + 1,
                in_channels=prev_output_channel,
                out_channels=output_channel,
                prev_output_channel=None,
                add_upsample=bool(add_spatial_upsample or add_time_upsample),
                upsample_scale_factor=upsample_scale_factor,
                resnet_eps=1e-6,
                resnet_act_fn=act_fn,
                resnet_groups=norm_num_groups,
                attention_head_dim=output_channel,
                temb_channels=temb_channels,
                resnet_time_scale_shift=norm_type,
                disable_causal=disable_causal,
            )
            self.up_blocks.append(up_block)
            prev_output_channel = output_channel

        # out
        if norm_type == "spatial":
            self.conv_norm_out = SpatialNorm(block_out_channels[0], temb_channels)
        else:
            self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=1e-6)
        self.conv_act = nn.SiLU()
        self.conv_out = CausalConv3d(block_out_channels[0], out_channels, kernel_size=3, disable_causal=disable_causal)

        self.gradient_checkpointing = False

    def forward(
        self,
        sample: torch.FloatTensor,
        latent_embeds: Optional[torch.FloatTensor] = None,
    ) -> torch.FloatTensor:
        r"""The forward method of the `DecoderCausal3D` class."""
        assert len(sample.shape) == 5, "The input tensor should have 5 dimensions"

        sample = self.conv_in(sample)

        upscale_dtype = next(iter(self.up_blocks.parameters())).dtype
        if self.training and self.gradient_checkpointing:

            def create_custom_forward(module):
                # Wrap the module so torch.utils.checkpoint can re-run it during
                # the backward pass.
                def custom_forward(*inputs):
                    return module(*inputs)

                return custom_forward

            if is_torch_version(">=", "1.11.0"):
                # middle
                sample = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(self.mid_block),
                    sample,
                    latent_embeds,
                    use_reentrant=False,
                )
                sample = sample.to(upscale_dtype)

                # up
                for up_block in self.up_blocks:
                    sample = torch.utils.checkpoint.checkpoint(
                        create_custom_forward(up_block),
                        sample,
                        latent_embeds,
                        use_reentrant=False,
                    )
            else:
                # middle
                sample = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(self.mid_block), sample, latent_embeds
                )
                sample = sample.to(upscale_dtype)

                # up
                for up_block in self.up_blocks:
                    sample = torch.utils.checkpoint.checkpoint(create_custom_forward(up_block), sample, latent_embeds)
        else:
            # middle
            sample = self.mid_block(sample, latent_embeds)
            sample = sample.to(upscale_dtype)

            # up
            for up_block in self.up_blocks:
                sample = up_block(sample, latent_embeds)

        # post-process
        if latent_embeds is None:
            sample = self.conv_norm_out(sample)
        else:
            sample = self.conv_norm_out(sample, latent_embeds)
        sample = self.conv_act(sample)
        sample = self.conv_out(sample)

        return sample


class DiagonalGaussianDistribution:
    def __init__(self, parameters: torch.Tensor, deterministic: bool = False):
        if parameters.ndim == 3:
            dim = 2  # (B, L, C)
        elif parameters.ndim == 5 or parameters.ndim == 4:
            dim = 1  # (B, C, T, H, W) / (B, C, H, W)
        else:
            raise NotImplementedError
        self.parameters = parameters
        self.mean, self.logvar = torch.chunk(parameters, 2, dim=dim)
        self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
        self.deterministic = deterministic
        self.std = torch.exp(0.5 * self.logvar)
        self.var = torch.exp(self.logvar)
        if self.deterministic:
            self.var = self.std = torch.zeros_like(
                self.mean, device=self.parameters.device, dtype=self.parameters.dtype
            )

    def sample(self, generator: Optional[torch.Generator] = None) -> torch.FloatTensor:
        # make sure sample is on the same device as the parameters and has same dtype
        sample = randn_tensor(
            self.mean.shape,
            generator=generator,
            device=self.parameters.device,
            dtype=self.parameters.dtype,
        )
        # reparameterization trick: x = mean + std * eps
        x = self.mean + self.std * sample
        return x

    def kl(self, other: Optional["DiagonalGaussianDistribution"] = None) -> torch.Tensor:
        if self.deterministic:
            return torch.Tensor([0.0])
        else:
            reduce_dim = list(range(1, self.mean.ndim))
            if other is None:
                return 0.5 * torch.sum(
                    torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar,
                    dim=reduce_dim,
                )
            else:
                return 0.5 * torch.sum(
                    torch.pow(self.mean - other.mean, 2) / other.var
                    + self.var / other.var
                    - 1.0
                    - self.logvar
                    + other.logvar,
                    dim=reduce_dim,
                )

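    # Added note: for q = N(mu, sigma^2) the `other is None` branch above is the
    # closed-form KL divergence against the standard normal prior,
    #   KL(q || N(0, I)) = 1/2 * sum(mu^2 + sigma^2 - 1 - log sigma^2),
    # and the `other` branch is the general diagonal-Gaussian-to-Gaussian form.
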
    def nll(self, sample: torch.Tensor, dims: Tuple[int, ...] = (1, 2, 3)) -> torch.Tensor:
        if self.deterministic:
            return torch.Tensor([0.0])
        logtwopi = np.log(2.0 * np.pi)
        return 0.5 * torch.sum(
            logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
            dim=dims,
        )

    def mode(self) -> torch.Tensor:
        return self.mean
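

if __name__ == "__main__":
    # Minimal smoke test (added sketch, not part of the original module). Run via
    # `python -m <package>.<this_module>` so the relative import above resolves.
    # It builds a posterior from fake encoder moments of shape (B, 2*C, T, H, W)
    # and exercises sample/kl/mode; the shapes are illustrative assumptions.
    moments = torch.randn(2, 8, 5, 16, 16)  # 2*C = 8 -> 4 latent channels
    posterior = DiagonalGaussianDistribution(moments)
    z = posterior.sample()
    assert z.shape == (2, 4, 5, 16, 16)
    print("kl per sample:", posterior.kl().shape)  # -> torch.Size([2])
    print("mode shape:", posterior.mode().shape)   # -> torch.Size([2, 4, 5, 16, 16])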