mirror of
				https://github.com/Wan-Video/Wan2.1.git
				synced 2025-11-04 06:15:17 +00:00 
			
		
		
		
	
		
			
				
	
	
		
			221 lines
		
	
	
		
			8.3 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			221 lines
		
	
	
		
			8.3 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
"""
 | 
						|
This module provides the implementation of an Audio Projection Model, which is designed for
 | 
						|
audio processing tasks. The model takes audio embeddings as input and outputs context tokens
 | 
						|
that can be used for various downstream applications, such as audio analysis or synthesis.
 | 
						|
 | 
						|
The AudioProjModel class is based on the ModelMixin class from the diffusers library, which
 | 
						|
provides a foundation for building custom models. This implementation includes multiple linear
 | 
						|
layers with ReLU activation functions and a LayerNorm for normalization.
 | 
						|
 | 
						|
Key Features:
 | 
						|
- Audio embedding input with flexible sequence length and block structure.
 | 
						|
- Multiple linear layers for feature transformation.
 | 
						|
- ReLU activation for non-linear transformation.
 | 
						|
- LayerNorm for stabilizing and speeding up training.
 | 
						|
- Rearrangement of input embeddings to match the model's expected input shape.
 | 
						|
- Customizable number of blocks, channels, and context tokens for adaptability.
 | 
						|
 | 
						|
The module is structured to be easily integrated into larger systems or used as a standalone
 | 
						|
component for audio feature extraction and processing.
 | 
						|
 | 
						|
Classes:
 | 
						|
- AudioProjModel: A class representing the audio projection model with configurable parameters.
 | 
						|
 | 
						|
Functions:
 | 
						|
- (none)
 | 
						|
 | 
						|
Dependencies:
 | 
						|
- torch: For tensor operations and neural network components.
 | 
						|
- diffusers: For the ModelMixin base class.
 | 
						|
- einops: For tensor rearrangement operations.
 | 
						|
 | 
						|
"""
 | 
						|
 | 
						|
import torch
 | 
						|
from diffusers import ModelMixin
 | 
						|
from einops import rearrange
 | 
						|
 | 
						|
import math
 | 
						|
import torch.nn as nn
 | 
						|
 | 
						|
class AudioProjNet2(ModelMixin):
 | 
						|
    """Audio Projection Model
 | 
						|
 | 
						|
    This class defines an audio projection model that takes audio embeddings as input
 | 
						|
    and produces context tokens as output. The model is based on the ModelMixin class
 | 
						|
    and consists of multiple linear layers and activation functions. It can be used
 | 
						|
    for various audio processing tasks.
 | 
						|
 | 
						|
    Attributes:
 | 
						|
        seq_len (int): The length of the audio sequence.
 | 
						|
        blocks (int): The number of blocks in the audio projection model.
 | 
						|
        channels (int): The number of channels in the audio projection model.
 | 
						|
        intermediate_dim (int): The intermediate dimension of the model.
 | 
						|
        context_tokens (int): The number of context tokens in the output.
 | 
						|
        output_dim (int): The output dimension of the context tokens.
 | 
						|
 | 
						|
    Methods:
 | 
						|
        __init__(self, seq_len=5, blocks=12, channels=768, intermediate_dim=512, context_tokens=32, output_dim=768):
 | 
						|
            Initializes the AudioProjModel with the given parameters.
 | 
						|
        forward(self, audio_embeds):
 | 
						|
            Defines the forward pass for the AudioProjModel.
 | 
						|
            Parameters:
 | 
						|
            audio_embeds (torch.Tensor): The input audio embeddings with shape (batch_size, video_length, blocks, channels).
 | 
						|
            Returns:
 | 
						|
            context_tokens (torch.Tensor): The output context tokens with shape (batch_size, video_length, context_tokens, output_dim).
 | 
						|
 | 
						|
    """
 | 
						|
 | 
						|
    def __init__(
 | 
						|
        self,
 | 
						|
        seq_len=5,
 | 
						|
        blocks=12,  # add a new parameter blocks
 | 
						|
        channels=768,  # add a new parameter channels
 | 
						|
        intermediate_dim=512,
 | 
						|
        output_dim=768,
 | 
						|
        context_tokens=4,
 | 
						|
    ):
 | 
						|
        super().__init__()
 | 
						|
 | 
						|
        self.seq_len = seq_len
 | 
						|
        self.blocks = blocks
 | 
						|
        self.channels = channels
 | 
						|
        self.input_dim = (
 | 
						|
            seq_len * blocks * channels
 | 
						|
        )  
 | 
						|
        self.intermediate_dim = intermediate_dim
 | 
						|
        self.context_tokens = context_tokens
 | 
						|
        self.output_dim = output_dim
 | 
						|
 | 
						|
        # define multiple linear layers
 | 
						|
        self.proj1 = nn.Linear(self.input_dim, intermediate_dim)
 | 
						|
        self.proj2 = nn.Linear(intermediate_dim, intermediate_dim)
 | 
						|
        self.proj3 = nn.Linear(intermediate_dim, context_tokens * output_dim)
 | 
						|
 | 
						|
        self.norm = nn.LayerNorm(output_dim)
 | 
						|
 | 
						|
 | 
						|
    def forward(self, audio_embeds):
 | 
						|
          
 | 
						|
        video_length = audio_embeds.shape[1]
 | 
						|
        audio_embeds = rearrange(audio_embeds, "bz f w b c -> (bz f) w b c")
 | 
						|
        batch_size, window_size, blocks, channels = audio_embeds.shape
 | 
						|
        audio_embeds = audio_embeds.view(batch_size, window_size * blocks * channels)
 | 
						|
 | 
						|
        audio_embeds = torch.relu(self.proj1(audio_embeds))
 | 
						|
        audio_embeds = torch.relu(self.proj2(audio_embeds))
 | 
						|
 | 
						|
        context_tokens = self.proj3(audio_embeds).reshape(
 | 
						|
            batch_size, self.context_tokens, self.output_dim
 | 
						|
        )
 | 
						|
        context_tokens = self.norm(context_tokens)
 | 
						|
        out_all = rearrange(
 | 
						|
            context_tokens, "(bz f) m c -> bz f m c", f=video_length
 | 
						|
        )
 | 
						|
 | 
						|
        return out_all
 | 
						|
 | 
						|
 | 
						|
def reshape_tensor(x, heads):
 | 
						|
    bs, length, width = x.shape
 | 
						|
    # (bs, length, width) --> (bs, length, n_heads, dim_per_head)
 | 
						|
    x = x.view(bs, length, heads, -1)
 | 
						|
    # (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head)
 | 
						|
    x = x.transpose(1, 2)
 | 
						|
    # (bs, n_heads, length, dim_per_head) 
 | 
						|
    x = x.reshape(bs, heads, length, -1)
 | 
						|
    return x
 | 
						|
 | 
						|
 | 
						|
class PerceiverAttentionCA(nn.Module):
 | 
						|
    def __init__(self, *, dim=3072, dim_head=1024, heads=33):
 | 
						|
        super().__init__()
 | 
						|
        self.scale = dim_head ** -0.5
 | 
						|
        self.dim_head = dim_head
 | 
						|
        self.heads = heads
 | 
						|
        inner_dim = dim_head #* heads
 | 
						|
 | 
						|
        self.norm1 = nn.LayerNorm(dim)
 | 
						|
        self.norm2 = nn.LayerNorm(dim)
 | 
						|
 | 
						|
        self.to_q = nn.Linear(dim, inner_dim, bias=False)
 | 
						|
        self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
 | 
						|
        self.to_out = nn.Linear(inner_dim, dim, bias=False)
 | 
						|
 | 
						|
        import torch.nn.init as init
 | 
						|
        init.zeros_(self.to_out.weight)
 | 
						|
        if self.to_out.bias is not None:
 | 
						|
            init.zeros_(self.to_out.bias)
 | 
						|
 | 
						|
    def forward(self, x, latents):
 | 
						|
        """
 | 
						|
        Args:
 | 
						|
            x (torch.Tensor): image features
 | 
						|
                shape (b, t, aa, D)
 | 
						|
            latent (torch.Tensor): latent features
 | 
						|
                shape (b, t, hw, D)
 | 
						|
        """
 | 
						|
        x = self.norm1(x)
 | 
						|
        latents = self.norm2(latents)
 | 
						|
        # print("latents shape: ", latents.shape)
 | 
						|
        # print("x shape: ", x.shape)
 | 
						|
        q = self.to_q(latents)
 | 
						|
        k, v = self.to_kv(x).chunk(2, dim=-1)
 | 
						|
 | 
						|
 | 
						|
        # attention
 | 
						|
        scale = 1 / math.sqrt(math.sqrt(self.dim_head))
 | 
						|
        weight = (q * scale) @ (k * scale).transpose(-2, -1)  # More stable with f16 than dividing afterwards
 | 
						|
        weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
 | 
						|
        out = weight @ v
 | 
						|
        
 | 
						|
        # out = out.permute(0, 2, 1, 3)
 | 
						|
        return self.to_out(out)
 | 
						|
    #def forward(self, x, latents):
 | 
						|
    #    """
 | 
						|
    #    Args:
 | 
						|
    #        x (torch.Tensor): image features
 | 
						|
    #            shape (b, t, aa, D)
 | 
						|
    #        latent (torch.Tensor): latent features
 | 
						|
    #            shape (b, t, hw, D)
 | 
						|
    #    """
 | 
						|
    #    if get_sequence_parallel_state():
 | 
						|
    #        sp_size = nccl_info.sp_size
 | 
						|
    #        sp_rank = nccl_info.rank_within_group
 | 
						|
    #        print("rank:", latents.shape, sp_size, sp_rank)
 | 
						|
    #        latents = torch.chunk(latents, sp_size, dim=1)[sp_rank]
 | 
						|
 | 
						|
    #    x = self.norm1(x)
 | 
						|
    #    latents = self.norm2(latents)
 | 
						|
    #    # print("latents shape: ", latents.shape)
 | 
						|
    #    # print("x shape: ", x.shape)
 | 
						|
    #    q = self.to_q(latents)
 | 
						|
    #    k, v = self.to_kv(x).chunk(2, dim=-1)
 | 
						|
 | 
						|
    #    # print("q, k, v: ", q.shape, k.shape, v.shape)
 | 
						|
 | 
						|
    #    # attention
 | 
						|
    #    #scale = 1 / math.sqrt(math.sqrt(self.dim_head))
 | 
						|
    #    #weight = (q * scale) @ (k * scale).transpose(-2, -1)  # More stable with f16 than dividing afterwards
 | 
						|
    #    #weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
 | 
						|
    #    #out = weight @ v
 | 
						|
    #    def shrink_head(encoder_state, dim):
 | 
						|
    #        local_heads = encoder_state.shape[dim] // nccl_info.sp_size
 | 
						|
    #        return encoder_state.narrow(dim, nccl_info.rank_within_group * local_heads, local_heads)
 | 
						|
 | 
						|
    #    if get_sequence_parallel_state():
 | 
						|
    #    # batch_size, seq_len, attn_heads, head_dim
 | 
						|
    #        q = all_to_all_4D(q, scatter_dim=2, gather_dim=1)  # [2, 32256, 24, 128]
 | 
						|
    #        k = shrink_head(k ,dim=2)
 | 
						|
    #        v = shrink_head(v ,dim=2)
 | 
						|
    #    qkv = torch.stack([query, key, value], dim=2)
 | 
						|
    #    attn = flash_attn_no_pad(qkv, causal=False, dropout_p=0.0, softmax_scale=None)
 | 
						|
    #    # out = out.permute(0, 2, 1, 3)
 | 
						|
    #    #b, s, a, d = attn.shape
 | 
						|
    #    #attn = attn.reshape(b, s, -1)
 | 
						|
    #        
 | 
						|
    #    out = self.to_out(attn)
 | 
						|
    #    if get_sequence_parallel_state():
 | 
						|
    #        out = all_gather(out, dim=1)
 | 
						|
    #    return out
 |