Mirror of https://github.com/Wan-Video/Wan2.1.git, synced 2025-11-04 06:15:17 +00:00
multitalk files

parent 621687c12a
commit 3a8bd05c6e

382  wan/multitalk/attention.py  Normal file
@@ -0,0 +1,382 @@
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import torch
import torch.nn as nn
from einops import rearrange, repeat
from .multitalk_utils import RotaryPositionalEmbedding1D, normalize_and_scale, split_token_counts_and_frame_ids
from wan.modules.attention import pay_attention

import xformers.ops

try:
    import flash_attn_interface
    FLASH_ATTN_3_AVAILABLE = True
except ModuleNotFoundError:
    FLASH_ATTN_3_AVAILABLE = False

try:
    import flash_attn
    FLASH_ATTN_2_AVAILABLE = True
except ModuleNotFoundError:
    FLASH_ATTN_2_AVAILABLE = False

import warnings

__all__ = [
    'flash_attention',
    'attention',
]


def flash_attention(
    q,
    k,
    v,
    q_lens=None,
    k_lens=None,
    dropout_p=0.,
    softmax_scale=None,
    q_scale=None,
    causal=False,
    window_size=(-1, -1),
    deterministic=False,
    dtype=torch.bfloat16,
    version=None,
):
    """
    q:              [B, Lq, Nq, C1].
    k:              [B, Lk, Nk, C1].
    v:              [B, Lk, Nk, C2]. Nq must be divisible by Nk.
    q_lens:         [B].
    k_lens:         [B].
    dropout_p:      float. Dropout probability.
    softmax_scale:  float. The scaling of QK^T before applying softmax.
    causal:         bool. Whether to apply causal attention mask.
    window_size:    (left, right). If not (-1, -1), apply sliding window local attention.
    deterministic:  bool. If True, slightly slower and uses more memory.
    dtype:          torch.dtype. Apply when dtype of q/k/v is not float16/bfloat16.
    """
    half_dtypes = (torch.float16, torch.bfloat16)
    assert dtype in half_dtypes
    assert q.device.type == 'cuda' and q.size(-1) <= 256

    # params
    b, lq, lk, out_dtype = q.size(0), q.size(1), k.size(1), q.dtype

    def half(x):
        return x if x.dtype in half_dtypes else x.to(dtype)

    # preprocess query
    if q_lens is None:
        q = half(q.flatten(0, 1))
        q_lens = torch.tensor(
            [lq] * b, dtype=torch.int32).to(
                device=q.device, non_blocking=True)
    else:
        q = half(torch.cat([u[:v] for u, v in zip(q, q_lens)]))

    # preprocess key, value
    if k_lens is None:
        k = half(k.flatten(0, 1))
        v = half(v.flatten(0, 1))
        k_lens = torch.tensor(
            [lk] * b, dtype=torch.int32).to(
                device=k.device, non_blocking=True)
    else:
        k = half(torch.cat([u[:v] for u, v in zip(k, k_lens)]))
        v = half(torch.cat([u[:v] for u, v in zip(v, k_lens)]))

    q = q.to(v.dtype)
    k = k.to(v.dtype)

    if q_scale is not None:
        q = q * q_scale

    if version is not None and version == 3 and not FLASH_ATTN_3_AVAILABLE:
        warnings.warn(
            'Flash attention 3 is not available, use flash attention 2 instead.'
        )

    # apply attention
    if (version is None or version == 3) and FLASH_ATTN_3_AVAILABLE:
        # Note: dropout_p, window_size are not supported in FA3 now.
        x = flash_attn_interface.flash_attn_varlen_func(
            q=q,
            k=k,
            v=v,
            cu_seqlens_q=torch.cat([q_lens.new_zeros([1]), q_lens]).cumsum(
                0, dtype=torch.int32).to(q.device, non_blocking=True),
            cu_seqlens_k=torch.cat([k_lens.new_zeros([1]), k_lens]).cumsum(
                0, dtype=torch.int32).to(q.device, non_blocking=True),
            seqused_q=None,
            seqused_k=None,
            max_seqlen_q=lq,
            max_seqlen_k=lk,
            softmax_scale=softmax_scale,
            causal=causal,
            deterministic=deterministic)[0].unflatten(0, (b, lq))
    else:
        assert FLASH_ATTN_2_AVAILABLE
        x = flash_attn.flash_attn_varlen_func(
            q=q,
            k=k,
            v=v,
            cu_seqlens_q=torch.cat([q_lens.new_zeros([1]), q_lens]).cumsum(
                0, dtype=torch.int32).to(q.device, non_blocking=True),
            cu_seqlens_k=torch.cat([k_lens.new_zeros([1]), k_lens]).cumsum(
                0, dtype=torch.int32).to(q.device, non_blocking=True),
            max_seqlen_q=lq,
            max_seqlen_k=lk,
            dropout_p=dropout_p,
            softmax_scale=softmax_scale,
            causal=causal,
            window_size=window_size,
            deterministic=deterministic).unflatten(0, (b, lq))

    # output
    return x.type(out_dtype)


def attention(
    q,
    k,
    v,
    q_lens=None,
    k_lens=None,
    dropout_p=0.,
    softmax_scale=None,
    q_scale=None,
    causal=False,
    window_size=(-1, -1),
    deterministic=False,
    dtype=torch.bfloat16,
    fa_version=None,
):
    if FLASH_ATTN_2_AVAILABLE or FLASH_ATTN_3_AVAILABLE:
        return flash_attention(
            q=q,
            k=k,
            v=v,
            q_lens=q_lens,
            k_lens=k_lens,
            dropout_p=dropout_p,
            softmax_scale=softmax_scale,
            q_scale=q_scale,
            causal=causal,
            window_size=window_size,
            deterministic=deterministic,
            dtype=dtype,
            version=fa_version,
        )
    else:
        if q_lens is not None or k_lens is not None:
            warnings.warn(
                'Padding mask is disabled when using scaled_dot_product_attention. It can have a significant impact on performance.'
            )
        attn_mask = None

        q = q.transpose(1, 2).to(dtype)
        k = k.transpose(1, 2).to(dtype)
        v = v.transpose(1, 2).to(dtype)

        out = torch.nn.functional.scaled_dot_product_attention(
            q, k, v, attn_mask=attn_mask, is_causal=causal, dropout_p=dropout_p)

        out = out.transpose(1, 2).contiguous()
        return out

class SingleStreamAttention(nn.Module):
    def __init__(
        self,
        dim: int,
        encoder_hidden_states_dim: int,
        num_heads: int,
        qkv_bias: bool,
        qk_norm: bool,
        norm_layer: nn.Module,
        attn_drop: float = 0.0,
        proj_drop: float = 0.0,
        eps: float = 1e-6,
    ) -> None:
        super().__init__()
        assert dim % num_heads == 0, "dim should be divisible by num_heads"
        self.dim = dim
        self.encoder_hidden_states_dim = encoder_hidden_states_dim
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim**-0.5
        self.qk_norm = qk_norm

        self.q_linear = nn.Linear(dim, dim, bias=qkv_bias)

        self.q_norm = norm_layer(self.head_dim, eps=eps) if qk_norm else nn.Identity()
        self.k_norm = norm_layer(self.head_dim, eps=eps) if qk_norm else nn.Identity()

        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        self.kv_linear = nn.Linear(encoder_hidden_states_dim, dim * 2, bias=qkv_bias)

        self.add_q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
        self.add_k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()

    def forward(self, x: torch.Tensor, encoder_hidden_states: torch.Tensor, shape=None, enable_sp=False, kv_seq=None) -> torch.Tensor:
        N_t, N_h, N_w = shape

        x = rearrange(x, "B (N_t S) C -> (B N_t) S C", N_t=N_t)
        # get q for hidden_state
        B, N, C = x.shape
        q = self.q_linear(x)
        q_shape = (B, N, self.num_heads, self.head_dim)
        q = q.view(q_shape).permute((0, 2, 1, 3))

        if self.qk_norm:
            q = self.q_norm(q)

        # get kv from encoder_hidden_states
        _, N_a, _ = encoder_hidden_states.shape
        encoder_kv = self.kv_linear(encoder_hidden_states)
        encoder_kv_shape = (B, N_a, 2, self.num_heads, self.head_dim)
        encoder_kv = encoder_kv.view(encoder_kv_shape).permute((2, 0, 3, 1, 4))
        encoder_k, encoder_v = encoder_kv.unbind(0)

        if self.qk_norm:
            encoder_k = self.add_k_norm(encoder_k)

        q = rearrange(q, "B H M K -> B M H K")
        encoder_k = rearrange(encoder_k, "B H M K -> B M H K")
        encoder_v = rearrange(encoder_v, "B H M K -> B M H K")

        attn_bias = None
        # x = xformers.ops.memory_efficient_attention(q, encoder_k, encoder_v, attn_bias=attn_bias, op=None,)
        qkv_list = [q, encoder_k, encoder_v]
        q = encoder_k = encoder_v = None
        x = pay_attention(qkv_list)
        x = rearrange(x, "B M H K -> B H M K")

        # linear transform
        x_output_shape = (B, N, C)
        x = x.transpose(1, 2)
        x = x.reshape(x_output_shape)
        x = self.proj(x)
        x = self.proj_drop(x)

        # reshape x to origin shape
        x = rearrange(x, "(B N_t) S C -> B (N_t S) C", N_t=N_t)

        return x


class SingleStreamMutiAttention(SingleStreamAttention):
    def __init__(
        self,
        dim: int,
        encoder_hidden_states_dim: int,
        num_heads: int,
        qkv_bias: bool,
        qk_norm: bool,
        norm_layer: nn.Module,
        attn_drop: float = 0.0,
        proj_drop: float = 0.0,
        eps: float = 1e-6,
        class_range: int = 24,
        class_interval: int = 4,
    ) -> None:
        super().__init__(
            dim=dim,
            encoder_hidden_states_dim=encoder_hidden_states_dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_norm=qk_norm,
            norm_layer=norm_layer,
            attn_drop=attn_drop,
            proj_drop=proj_drop,
            eps=eps,
        )
        self.class_interval = class_interval
        self.class_range = class_range
        self.rope_h1  = (0, self.class_interval)
        self.rope_h2  = (self.class_range - self.class_interval, self.class_range)
        self.rope_bak = int(self.class_range // 2)

        self.rope_1d = RotaryPositionalEmbedding1D(self.head_dim)

    def forward(self,
                x: torch.Tensor,
                encoder_hidden_states: torch.Tensor,
                shape=None,
                x_ref_attn_map=None,
                ) -> torch.Tensor:

        encoder_hidden_states = encoder_hidden_states.squeeze(0)
        if x_ref_attn_map is None:
            return super().forward(x, encoder_hidden_states, shape)

        N_t, _, _ = shape
        x = rearrange(x, "B (N_t S) C -> (B N_t) S C", N_t=N_t)

        # get q for hidden_state
        B, N, C = x.shape
        q = self.q_linear(x)
        q_shape = (B, N, self.num_heads, self.head_dim)
        q = q.view(q_shape).permute((0, 2, 1, 3))

        if self.qk_norm:
            q = self.q_norm(q)

        max_values = x_ref_attn_map.max(1).values[:, None, None]
        min_values = x_ref_attn_map.min(1).values[:, None, None]
        max_min_values = torch.cat([max_values, min_values], dim=2)

        human1_max_value, human1_min_value = max_min_values[0, :, 0].max(), max_min_values[0, :, 1].min()
        human2_max_value, human2_min_value = max_min_values[1, :, 0].max(), max_min_values[1, :, 1].min()

        human1 = normalize_and_scale(x_ref_attn_map[0], (human1_min_value, human1_max_value), (self.rope_h1[0], self.rope_h1[1]))
        human2 = normalize_and_scale(x_ref_attn_map[1], (human2_min_value, human2_max_value), (self.rope_h2[0], self.rope_h2[1]))
        back   = torch.full((x_ref_attn_map.size(1),), self.rope_bak, dtype=human1.dtype).to(human1.device)
        max_indices = x_ref_attn_map.argmax(dim=0)
        normalized_map = torch.stack([human1, human2, back], dim=1)
        normalized_pos = normalized_map[range(x_ref_attn_map.size(1)), max_indices]  # N

        q = rearrange(q, "(B N_t) H S C -> B H (N_t S) C", N_t=N_t)
        q = self.rope_1d(q, normalized_pos)
        q = rearrange(q, "B H (N_t S) C -> (B N_t) H S C", N_t=N_t)

        _, N_a, _ = encoder_hidden_states.shape
        encoder_kv = self.kv_linear(encoder_hidden_states)
        encoder_kv_shape = (B, N_a, 2, self.num_heads, self.head_dim)
        encoder_kv = encoder_kv.view(encoder_kv_shape).permute((2, 0, 3, 1, 4))
        encoder_k, encoder_v = encoder_kv.unbind(0)

        if self.qk_norm:
            encoder_k = self.add_k_norm(encoder_k)

        per_frame = torch.zeros(N_a, dtype=encoder_k.dtype).to(encoder_k.device)
        per_frame[:per_frame.size(0)//2] = (self.rope_h1[0] + self.rope_h1[1]) / 2
        per_frame[per_frame.size(0)//2:] = (self.rope_h2[0] + self.rope_h2[1]) / 2
        encoder_pos = torch.concat([per_frame]*N_t, dim=0)
        encoder_k = rearrange(encoder_k, "(B N_t) H S C -> B H (N_t S) C", N_t=N_t)
        encoder_k = self.rope_1d(encoder_k, encoder_pos)
        encoder_k = rearrange(encoder_k, "B H (N_t S) C -> (B N_t) H S C", N_t=N_t)

        q = rearrange(q, "B H M K -> B M H K")
        encoder_k = rearrange(encoder_k, "B H M K -> B M H K")
        encoder_v = rearrange(encoder_v, "B H M K -> B M H K")
        # x = xformers.ops.memory_efficient_attention(q, encoder_k, encoder_v, attn_bias=None, op=None,)
        qkv_list = [q, encoder_k, encoder_v]
        q = encoder_k = encoder_v = None
        x = pay_attention(qkv_list)

        x = rearrange(x, "B M H K -> B H M K")

        # linear transform
        x_output_shape = (B, N, C)
        x = x.transpose(1, 2)
        x = x.reshape(x_output_shape)
        x = self.proj(x)
        x = self.proj_drop(x)

        # reshape x to origin shape
        x = rearrange(x, "(B N_t) S C -> B (N_t S) C", N_t=N_t)

        return x
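For orientation only (not part of the commit), a minimal usage sketch of the attention() wrapper added above. The shapes, the CUDA device, and the bfloat16 dtype are illustrative assumptions; with flash-attn 2/3 installed the varlen kernels are used, otherwise the call falls through to torch scaled_dot_product_attention with the padding mask disabled.

import torch
from wan.multitalk.attention import attention  # assumes the wan package and its deps are importable

B, Lq, Lk, heads, head_dim = 1, 256, 128, 8, 64
q = torch.randn(B, Lq, heads, head_dim, device="cuda", dtype=torch.bfloat16)
k = torch.randn(B, Lk, heads, head_dim, device="cuda", dtype=torch.bfloat16)
v = torch.randn(B, Lk, heads, head_dim, device="cuda", dtype=torch.bfloat16)
out = attention(q, k, v, causal=False)  # -> [B, Lq, heads, head_dim]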
							
								
								
									
23  wan/multitalk/kokoro/__init__.py  Normal file
@@ -0,0 +1,23 @@
__version__ = '0.9.4'

from loguru import logger
import sys

# Remove default handler
logger.remove()

# Add custom handler with clean format including module and line number
logger.add(
    sys.stderr,
    format="<green>{time:HH:mm:ss}</green> | <cyan>{module:>16}:{line}</cyan> | <level>{level: >8}</level> | <level>{message}</level>",
    colorize=True,
    level="INFO"  # "DEBUG" to enable logger.debug("message") and up prints
                  # "ERROR" to enable only logger.error("message") prints
                  # etc
)

# Disable before release or as needed
logger.disable("kokoro")

from .model import KModel
from .pipeline import KPipeline
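Because the module above calls logger.disable("kokoro") at import time, its log records are silenced by default. A small illustrative sketch (plain loguru API, not part of the commit) of how a consumer would opt back in:

from loguru import logger

# Re-enable records emitted under the "kokoro" logger namespace
# (the level of the stderr sink configured above still applies).
logger.enable("kokoro")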
							
								
								
									
148  wan/multitalk/kokoro/__main__.py  Normal file
@@ -0,0 +1,148 @@
"""Kokoro TTS CLI
 | 
			
		||||
Example usage:
 | 
			
		||||
python3 -m kokoro --text "The sky above the port was the color of television, tuned to a dead channel." -o file.wav --debug
 | 
			
		||||
 | 
			
		||||
echo "Bom dia mundo, como vão vocês" > text.txt
 | 
			
		||||
python3 -m kokoro -i text.txt -l p --voice pm_alex > audio.wav
 | 
			
		||||
 | 
			
		||||
Common issues:
 | 
			
		||||
pip not installed: `uv pip install pip`
 | 
			
		||||
(Temporary workaround while https://github.com/explosion/spaCy/issues/13747 is not fixed)
 | 
			
		||||
 | 
			
		||||
espeak not installed: `apt-get install espeak-ng`
 | 
			
		||||
"""
 | 
			
		||||
 | 
			
		||||
import argparse
 | 
			
		||||
import wave
 | 
			
		||||
from pathlib import Path
 | 
			
		||||
from typing import Generator, TYPE_CHECKING
 | 
			
		||||
 | 
			
		||||
import numpy as np
 | 
			
		||||
from loguru import logger
 | 
			
		||||
 | 
			
		||||
languages = [
 | 
			
		||||
    "a",  # American English
 | 
			
		||||
    "b",  # British English
 | 
			
		||||
    "h",  # Hindi
 | 
			
		||||
    "e",  # Spanish
 | 
			
		||||
    "f",  # French
 | 
			
		||||
    "i",  # Italian
 | 
			
		||||
    "p",  # Brazilian Portuguese
 | 
			
		||||
    "j",  # Japanese
 | 
			
		||||
    "z",  # Mandarin Chinese
 | 
			
		||||
]
 | 
			
		||||
 | 
			
		||||
if TYPE_CHECKING:
 | 
			
		||||
    from kokoro import KPipeline
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def generate_audio(
 | 
			
		||||
    text: str, kokoro_language: str, voice: str, speed=1
 | 
			
		||||
) -> Generator["KPipeline.Result", None, None]:
 | 
			
		||||
    from kokoro import KPipeline
 | 
			
		||||
 | 
			
		||||
    if not voice.startswith(kokoro_language):
 | 
			
		||||
        logger.warning(f"Voice {voice} is not made for language {kokoro_language}")
 | 
			
		||||
    pipeline = KPipeline(lang_code=kokoro_language)
 | 
			
		||||
    yield from pipeline(text, voice=voice, speed=speed, split_pattern=r"\n+")
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def generate_and_save_audio(
 | 
			
		||||
    output_file: Path, text: str, kokoro_language: str, voice: str, speed=1
 | 
			
		||||
) -> None:
 | 
			
		||||
    with wave.open(str(output_file.resolve()), "wb") as wav_file:
 | 
			
		||||
        wav_file.setnchannels(1)  # Mono audio
 | 
			
		||||
        wav_file.setsampwidth(2)  # 2 bytes per sample (16-bit audio)
 | 
			
		||||
        wav_file.setframerate(24000)  # Sample rate
 | 
			
		||||
 | 
			
		||||
        for result in generate_audio(
 | 
			
		||||
            text, kokoro_language=kokoro_language, voice=voice, speed=speed
 | 
			
		||||
        ):
 | 
			
		||||
            logger.debug(result.phonemes)
 | 
			
		||||
            if result.audio is None:
 | 
			
		||||
                continue
 | 
			
		||||
            audio_bytes = (result.audio.numpy() * 32767).astype(np.int16).tobytes()
 | 
			
		||||
            wav_file.writeframes(audio_bytes)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def main() -> None:
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-m",
        "--voice",
        default="af_heart",
        help="Voice to use",
    )
    parser.add_argument(
        "-l",
        "--language",
        help="Language to use (defaults to the one corresponding to the voice)",
        choices=languages,
    )
    parser.add_argument(
        "-o",
        "--output-file",
        "--output_file",
        type=Path,
        help="Path to output WAV file",
        required=True,
    )
    parser.add_argument(
        "-i",
        "--input-file",
        "--input_file",
        type=Path,
        help="Path to input text file (default: stdin)",
    )
    parser.add_argument(
        "-t",
        "--text",
        help="Text to use instead of reading from stdin",
    )
    parser.add_argument(
        "-s",
        "--speed",
        type=float,
        default=1.0,
        help="Speech speed",
    )
    parser.add_argument(
        "--debug",
        action="store_true",
        help="Print DEBUG messages to console",
    )
    args = parser.parse_args()
    if args.debug:
        logger.level("DEBUG")
    logger.debug(args)

    lang = args.language or args.voice[0]

    if args.text is not None and args.input_file is not None:
        raise Exception("You cannot specify both 'text' and 'input_file'")
    elif args.text:
        text = args.text
    elif args.input_file:
        file: Path = args.input_file
        text = file.read_text()
    else:
        import sys
        print("Press Ctrl+D to stop reading input and start generating", flush=True)
        text = '\n'.join(sys.stdin)

    logger.debug(f"Input text: {text!r}")

    out_file: Path = args.output_file
    if not out_file.suffix == ".wav":
        logger.warning("The output file name should end with .wav")
    generate_and_save_audio(
        output_file=out_file,
        text=text,
        kokoro_language=lang,
        voice=args.voice,
        speed=args.speed,
    )


if __name__ == "__main__":
    main()
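The same entry point can also be driven programmatically. A hedged sketch (not part of the commit; it assumes the standalone kokoro package is importable, since the CLI does `from kokoro import KPipeline` internally, and the voice/paths are illustrative):

from pathlib import Path
from wan.multitalk.kokoro.__main__ import generate_and_save_audio

generate_and_save_audio(
    output_file=Path("hello.wav"),
    text="Hello from Kokoro.",
    kokoro_language="a",   # American English, matching the "af_heart" voice prefix
    voice="af_heart",
    speed=1.0,
)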
							
								
								
									
197  wan/multitalk/kokoro/custom_stft.py  Normal file
@@ -0,0 +1,197 @@
from attr import attr
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

class CustomSTFT(nn.Module):
    """
    STFT/iSTFT without unfold/complex ops, using conv1d and conv_transpose1d.

    - forward STFT => Real-part conv1d + Imag-part conv1d
    - inverse STFT => Real-part conv_transpose1d + Imag-part conv_transpose1d + sum
    - avoids F.unfold, so easier to export to ONNX
    - uses replicate or constant padding for 'center=True' to approximate 'reflect'
      (reflect is not supported for dynamic shapes in ONNX)
    """

    def __init__(
        self,
        filter_length=800,
        hop_length=200,
        win_length=800,
        window="hann",
        center=True,
        pad_mode="replicate",  # or 'constant'
    ):
        super().__init__()
        self.filter_length = filter_length
        self.hop_length = hop_length
        self.win_length = win_length
        self.n_fft = filter_length
        self.center = center
        self.pad_mode = pad_mode

        # Number of frequency bins for real-valued STFT with onesided=True
        self.freq_bins = self.n_fft // 2 + 1

        # Build window
        assert window == 'hann', window
        window_tensor = torch.hann_window(win_length, periodic=True, dtype=torch.float32)
        if self.win_length < self.n_fft:
            # Zero-pad up to n_fft
            extra = self.n_fft - self.win_length
            window_tensor = F.pad(window_tensor, (0, extra))
        elif self.win_length > self.n_fft:
            window_tensor = window_tensor[: self.n_fft]
        self.register_buffer("window", window_tensor)

        # Precompute forward DFT (real, imag)
        # PyTorch stft uses e^{-j 2 pi k n / N} => real=cos(...), imag=-sin(...)
        n = np.arange(self.n_fft)
        k = np.arange(self.freq_bins)
        angle = 2 * np.pi * np.outer(k, n) / self.n_fft  # shape (freq_bins, n_fft)
        dft_real = np.cos(angle)
        dft_imag = -np.sin(angle)  # note negative sign

        # Combine window and dft => shape (freq_bins, filter_length)
        # We'll make 2 conv weight tensors of shape (freq_bins, 1, filter_length).
        forward_window = window_tensor.numpy()  # shape (n_fft,)
        forward_real = dft_real * forward_window  # (freq_bins, n_fft)
        forward_imag = dft_imag * forward_window

        # Convert to PyTorch
        forward_real_torch = torch.from_numpy(forward_real).float()
        forward_imag_torch = torch.from_numpy(forward_imag).float()

        # Register as Conv1d weight => (out_channels, in_channels, kernel_size)
        # out_channels = freq_bins, in_channels=1, kernel_size=n_fft
        self.register_buffer(
            "weight_forward_real", forward_real_torch.unsqueeze(1)
        )
        self.register_buffer(
            "weight_forward_imag", forward_imag_torch.unsqueeze(1)
        )

        # Precompute inverse DFT
        # Real iFFT formula => scale = 1/n_fft, doubling for bins 1..freq_bins-2 if n_fft even, etc.
        # For simplicity, we won't do the "DC/nyquist not doubled" approach here.
        # If you want perfect real iSTFT, you can add that logic.
        # This version just yields good approximate reconstruction with Hann + typical overlap.
        inv_scale = 1.0 / self.n_fft
        n = np.arange(self.n_fft)
        angle_t = 2 * np.pi * np.outer(n, k) / self.n_fft  # shape (n_fft, freq_bins)
        idft_cos = np.cos(angle_t).T  # => (freq_bins, n_fft)
        idft_sin = np.sin(angle_t).T  # => (freq_bins, n_fft)

        # Multiply by window again for typical overlap-add
        # We also incorporate the scale factor 1/n_fft
        inv_window = window_tensor.numpy() * inv_scale
        backward_real = idft_cos * inv_window  # (freq_bins, n_fft)
        backward_imag = idft_sin * inv_window

        # We'll implement iSTFT as real+imag conv_transpose with stride=hop.
        self.register_buffer(
            "weight_backward_real", torch.from_numpy(backward_real).float().unsqueeze(1)
        )
        self.register_buffer(
            "weight_backward_imag", torch.from_numpy(backward_imag).float().unsqueeze(1)
        )

    def transform(self, waveform: torch.Tensor):
        """
        Forward STFT => returns magnitude, phase
        Output shape => (batch, freq_bins, frames)
        """
        # waveform shape => (B, T).  conv1d expects (B, 1, T).
        # Optional center pad
        if self.center:
            pad_len = self.n_fft // 2
            waveform = F.pad(waveform, (pad_len, pad_len), mode=self.pad_mode)

        x = waveform.unsqueeze(1)  # => (B, 1, T)
        # Convolution to get real part => shape (B, freq_bins, frames)
        real_out = F.conv1d(
            x,
            self.weight_forward_real,
            bias=None,
            stride=self.hop_length,
            padding=0,
        )
        # Imag part
        imag_out = F.conv1d(
            x,
            self.weight_forward_imag,
            bias=None,
            stride=self.hop_length,
            padding=0,
        )

        # magnitude, phase
        magnitude = torch.sqrt(real_out**2 + imag_out**2 + 1e-14)
        phase = torch.atan2(imag_out, real_out)
        # Handle the case where imag_out is 0 and real_out is negative to correct ONNX atan2 to match PyTorch
        # In this case, PyTorch returns pi, ONNX returns -pi
        correction_mask = (imag_out == 0) & (real_out < 0)
        phase[correction_mask] = torch.pi
        return magnitude, phase

    def inverse(self, magnitude: torch.Tensor, phase: torch.Tensor, length=None):
        """
        Inverse STFT => returns waveform shape (B, T).
        """
        # magnitude, phase => (B, freq_bins, frames)
        # Re-create real/imag => shape (B, freq_bins, frames)
        real_part = magnitude * torch.cos(phase)
        imag_part = magnitude * torch.sin(phase)

        # conv_transpose wants shape (B, freq_bins, frames). We'll treat "frames" as time dimension
        # so we do (B, freq_bins, frames) => (B, freq_bins, frames)
        # But PyTorch conv_transpose1d expects (B, in_channels, input_length)
        real_part = real_part  # (B, freq_bins, frames)
        imag_part = imag_part

        # real iSTFT => convolve with "backward_real", "backward_imag", and sum
        # We'll do 2 conv_transpose calls, each giving (B, 1, time),
        # then add them => (B, 1, time).
        real_rec = F.conv_transpose1d(
            real_part,
            self.weight_backward_real,  # shape (freq_bins, 1, filter_length)
            bias=None,
            stride=self.hop_length,
            padding=0,
        )
        imag_rec = F.conv_transpose1d(
            imag_part,
            self.weight_backward_imag,
            bias=None,
            stride=self.hop_length,
            padding=0,
        )
        # sum => (B, 1, time)
        waveform = real_rec - imag_rec  # typical real iFFT has minus for imaginary part

        # If we used "center=True" in forward, we should remove pad
        if self.center:
            pad_len = self.n_fft // 2
            # Because of transposed convolution, total length might have extra samples
            # We remove `pad_len` from start & end if possible
            waveform = waveform[..., pad_len:-pad_len]

        # If a specific length is desired, clamp
        if length is not None:
            waveform = waveform[..., :length]

        # shape => (B, T)
        return waveform

    def forward(self, x: torch.Tensor):
        """
        Full STFT -> iSTFT pass: returns time-domain reconstruction.
        Same interface as your original code.
        """
        mag, phase = self.transform(x)
        return self.inverse(mag, phase, length=x.shape[-1])
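A quick round-trip sketch of the ONNX-friendly STFT above (illustrative, not part of the commit; the reconstruction is only approximate because the inverse deliberately skips the exact DC/Nyquist handling):

import torch
from wan.multitalk.kokoro.custom_stft import CustomSTFT

stft = CustomSTFT(filter_length=800, hop_length=200, win_length=800)
wav = torch.randn(2, 24000)           # (B, T), e.g. one second at 24 kHz
mag, phase = stft.transform(wav)      # each (B, 401, frames) with freq_bins = 800 // 2 + 1
recon = stft.inverse(mag, phase, length=wav.shape[-1])
print(recon.shape)                    # (B, 1, T); squeeze(1) to compare against wav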
							
								
								
									
421  wan/multitalk/kokoro/istftnet.py  Normal file
@@ -0,0 +1,421 @@
# ADAPTED from https://github.com/yl4579/StyleTTS2/blob/main/Modules/istftnet.py
from .custom_stft import CustomSTFT
from torch.nn.utils import weight_norm
import math
import torch
import torch.nn as nn
import torch.nn.functional as F


# https://github.com/yl4579/StyleTTS2/blob/main/Modules/utils.py
def init_weights(m, mean=0.0, std=0.01):
    classname = m.__class__.__name__
    if classname.find("Conv") != -1:
        m.weight.data.normal_(mean, std)

def get_padding(kernel_size, dilation=1):
    return int((kernel_size*dilation - dilation)/2)


class AdaIN1d(nn.Module):
    def __init__(self, style_dim, num_features):
        super().__init__()
        # affine should be False; however, a bug in the old torch.onnx.export (not the newer dynamo export)
        # causes the channel dimension to be lost if affine=False. With affine=True there are additional
        # learnable parameters, which shouldn't really matter since we're in inference mode.
        self.norm = nn.InstanceNorm1d(num_features, affine=True)
        self.fc = nn.Linear(style_dim, num_features*2)

    def forward(self, x, s):
        h = self.fc(s)
        h = h.view(h.size(0), h.size(1), 1)
        gamma, beta = torch.chunk(h, chunks=2, dim=1)
        return (1 + gamma) * self.norm(x) + beta


class AdaINResBlock1(nn.Module):
    def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5), style_dim=64):
        super(AdaINResBlock1, self).__init__()
        self.convs1 = nn.ModuleList([
            weight_norm(nn.Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0],
                                  padding=get_padding(kernel_size, dilation[0]))),
            weight_norm(nn.Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1],
                                  padding=get_padding(kernel_size, dilation[1]))),
            weight_norm(nn.Conv1d(channels, channels, kernel_size, 1, dilation=dilation[2],
                                  padding=get_padding(kernel_size, dilation[2])))
        ])
        self.convs1.apply(init_weights)
        self.convs2 = nn.ModuleList([
            weight_norm(nn.Conv1d(channels, channels, kernel_size, 1, dilation=1,
                                  padding=get_padding(kernel_size, 1))),
            weight_norm(nn.Conv1d(channels, channels, kernel_size, 1, dilation=1,
                                  padding=get_padding(kernel_size, 1))),
            weight_norm(nn.Conv1d(channels, channels, kernel_size, 1, dilation=1,
                                  padding=get_padding(kernel_size, 1)))
        ])
        self.convs2.apply(init_weights)
        self.adain1 = nn.ModuleList([
            AdaIN1d(style_dim, channels),
            AdaIN1d(style_dim, channels),
            AdaIN1d(style_dim, channels),
        ])
        self.adain2 = nn.ModuleList([
            AdaIN1d(style_dim, channels),
            AdaIN1d(style_dim, channels),
            AdaIN1d(style_dim, channels),
        ])
        self.alpha1 = nn.ParameterList([nn.Parameter(torch.ones(1, channels, 1)) for i in range(len(self.convs1))])
        self.alpha2 = nn.ParameterList([nn.Parameter(torch.ones(1, channels, 1)) for i in range(len(self.convs2))])

    def forward(self, x, s):
        for c1, c2, n1, n2, a1, a2 in zip(self.convs1, self.convs2, self.adain1, self.adain2, self.alpha1, self.alpha2):
            xt = n1(x, s)
            xt = xt + (1 / a1) * (torch.sin(a1 * xt) ** 2)  # Snake1D
            xt = c1(xt)
            xt = n2(xt, s)
            xt = xt + (1 / a2) * (torch.sin(a2 * xt) ** 2)  # Snake1D
            xt = c2(xt)
            x = xt + x
        return x


class TorchSTFT(nn.Module):
    def __init__(self, filter_length=800, hop_length=200, win_length=800, window='hann'):
        super().__init__()
        self.filter_length = filter_length
        self.hop_length = hop_length
        self.win_length = win_length
        assert window == 'hann', window
        self.window = torch.hann_window(win_length, periodic=True, dtype=torch.float32)

    def transform(self, input_data):
        forward_transform = torch.stft(
            input_data,
            self.filter_length, self.hop_length, self.win_length, window=self.window.to(input_data.device),
            return_complex=True)
        return torch.abs(forward_transform), torch.angle(forward_transform)

    def inverse(self, magnitude, phase):
        inverse_transform = torch.istft(
            magnitude * torch.exp(phase * 1j),
            self.filter_length, self.hop_length, self.win_length, window=self.window.to(magnitude.device))
        return inverse_transform.unsqueeze(-2)  # unsqueeze to stay consistent with conv_transpose1d implementation

    def forward(self, input_data):
        self.magnitude, self.phase = self.transform(input_data)
        reconstruction = self.inverse(self.magnitude, self.phase)
        return reconstruction


class SineGen(nn.Module):
    """ Definition of sine generator
    SineGen(samp_rate, harmonic_num = 0,
            sine_amp = 0.1, noise_std = 0.003,
            voiced_threshold = 0,
            flag_for_pulse=False)
    samp_rate: sampling rate in Hz
    harmonic_num: number of harmonic overtones (default 0)
    sine_amp: amplitude of sine waveform (default 0.1)
    noise_std: std of Gaussian noise (default 0.003)
    voiced_threshold: F0 threshold for U/V classification (default 0)
    flag_for_pulse: this SineGen is used inside PulseGen (default False)
    Note: when flag_for_pulse is True, the first time step of a voiced
        segment is always sin(torch.pi) or cos(0)
    """
    def __init__(self, samp_rate, upsample_scale, harmonic_num=0,
                 sine_amp=0.1, noise_std=0.003,
                 voiced_threshold=0,
                 flag_for_pulse=False):
        super(SineGen, self).__init__()
        self.sine_amp = sine_amp
        self.noise_std = noise_std
        self.harmonic_num = harmonic_num
        self.dim = self.harmonic_num + 1
        self.sampling_rate = samp_rate
        self.voiced_threshold = voiced_threshold
        self.flag_for_pulse = flag_for_pulse
        self.upsample_scale = upsample_scale

    def _f02uv(self, f0):
        # generate uv signal
        uv = (f0 > self.voiced_threshold).type(torch.float32)
        return uv

    def _f02sine(self, f0_values):
        """ f0_values: (batchsize, length, dim)
            where dim indicates fundamental tone and overtones
        """
        # convert to F0 in rad. The integer part n can be ignored
        # because 2 * torch.pi * n doesn't affect phase
        rad_values = (f0_values / self.sampling_rate) % 1
        # initial phase noise (no noise for fundamental component)
        rand_ini = torch.rand(f0_values.shape[0], f0_values.shape[2], device=f0_values.device)
        rand_ini[:, 0] = 0
        rad_values[:, 0, :] = rad_values[:, 0, :] + rand_ini
        # instantaneous phase: sine[t] = sin(2*pi * \sum_{i=1}^{t} rad)
        if not self.flag_for_pulse:
            rad_values = F.interpolate(rad_values.transpose(1, 2), scale_factor=1/self.upsample_scale, mode="linear").transpose(1, 2)
            phase = torch.cumsum(rad_values, dim=1) * 2 * torch.pi
            phase = F.interpolate(phase.transpose(1, 2) * self.upsample_scale, scale_factor=self.upsample_scale, mode="linear").transpose(1, 2)
            sines = torch.sin(phase)
        else:
            # If necessary, make sure that the first time step of every
            # voiced segment is sin(pi) or cos(0)
            # This is used for pulse-train generation
            # identify the last time step in unvoiced segments
            uv = self._f02uv(f0_values)
            uv_1 = torch.roll(uv, shifts=-1, dims=1)
            uv_1[:, -1, :] = 1
            u_loc = (uv < 1) * (uv_1 > 0)
            # get the instantaneous phase
            tmp_cumsum = torch.cumsum(rad_values, dim=1)
            # different batch needs to be processed differently
            for idx in range(f0_values.shape[0]):
                temp_sum = tmp_cumsum[idx, u_loc[idx, :, 0], :]
                temp_sum[1:, :] = temp_sum[1:, :] - temp_sum[0:-1, :]
                # stores the accumulation of i.phase within
                # each voiced segment
                tmp_cumsum[idx, :, :] = 0
                tmp_cumsum[idx, u_loc[idx, :, 0], :] = temp_sum
            # rad_values - tmp_cumsum: remove the accumulation of i.phase
            # within the previous voiced segment.
            i_phase = torch.cumsum(rad_values - tmp_cumsum, dim=1)
            # get the sines
            sines = torch.cos(i_phase * 2 * torch.pi)
        return sines

    def forward(self, f0):
        """ sine_tensor, uv = forward(f0)
        input F0: tensor(batchsize=1, length, dim=1)
                  f0 for unvoiced steps should be 0
        output sine_tensor: tensor(batchsize=1, length, dim)
        output uv: tensor(batchsize=1, length, 1)
        """
        f0_buf = torch.zeros(f0.shape[0], f0.shape[1], self.dim, device=f0.device)
        # fundamental component
        fn = torch.multiply(f0, torch.FloatTensor([[range(1, self.harmonic_num + 2)]]).to(f0.device))
        # generate sine waveforms
        sine_waves = self._f02sine(fn) * self.sine_amp
        # generate uv signal
        # uv = torch.ones(f0.shape)
        # uv = uv * (f0 > self.voiced_threshold)
        uv = self._f02uv(f0)
        # noise: for unvoiced should be similar to sine_amp
        #        std = self.sine_amp/3 -> max value ~ self.sine_amp
        #        for voiced regions is self.noise_std
        noise_amp = uv * self.noise_std + (1 - uv) * self.sine_amp / 3
        noise = noise_amp * torch.randn_like(sine_waves)
        # first: set the unvoiced part to 0 by uv
        # then: additive noise
        sine_waves = sine_waves * uv + noise
        return sine_waves, uv, noise


class SourceModuleHnNSF(nn.Module):
    """ SourceModule for hn-nsf
    SourceModule(sampling_rate, harmonic_num=0, sine_amp=0.1,
                 add_noise_std=0.003, voiced_threshod=0)
    sampling_rate: sampling rate in Hz
    harmonic_num: number of harmonics above F0 (default: 0)
    sine_amp: amplitude of sine source signal (default: 0.1)
    add_noise_std: std of additive Gaussian noise (default: 0.003)
        note that amplitude of noise in unvoiced is decided
        by sine_amp
    voiced_threshold: threshold to set U/V given F0 (default: 0)
    Sine_source, noise_source = SourceModuleHnNSF(F0_sampled)
    F0_sampled (batchsize, length, 1)
    Sine_source (batchsize, length, 1)
    noise_source (batchsize, length, 1)
    uv (batchsize, length, 1)
    """
    def __init__(self, sampling_rate, upsample_scale, harmonic_num=0, sine_amp=0.1,
                 add_noise_std=0.003, voiced_threshod=0):
        super(SourceModuleHnNSF, self).__init__()
        self.sine_amp = sine_amp
        self.noise_std = add_noise_std
        # to produce sine waveforms
        self.l_sin_gen = SineGen(sampling_rate, upsample_scale, harmonic_num,
                                 sine_amp, add_noise_std, voiced_threshod)
        # to merge source harmonics into a single excitation
        self.l_linear = nn.Linear(harmonic_num + 1, 1)
        self.l_tanh = nn.Tanh()

    def forward(self, x):
        """
        Sine_source, noise_source = SourceModuleHnNSF(F0_sampled)
        F0_sampled (batchsize, length, 1)
        Sine_source (batchsize, length, 1)
        noise_source (batchsize, length, 1)
        """
        # source for harmonic branch
        with torch.no_grad():
            sine_wavs, uv, _ = self.l_sin_gen(x)
        sine_merge = self.l_tanh(self.l_linear(sine_wavs))
        # source for noise branch, in the same shape as uv
        noise = torch.randn_like(uv) * self.sine_amp / 3
        return sine_merge, noise, uv


class Generator(nn.Module):
    def __init__(self, style_dim, resblock_kernel_sizes, upsample_rates, upsample_initial_channel, resblock_dilation_sizes, upsample_kernel_sizes, gen_istft_n_fft, gen_istft_hop_size, disable_complex=False):
        super(Generator, self).__init__()
        self.num_kernels = len(resblock_kernel_sizes)
        self.num_upsamples = len(upsample_rates)
        self.m_source = SourceModuleHnNSF(
                    sampling_rate=24000,
                    upsample_scale=math.prod(upsample_rates) * gen_istft_hop_size,
                    harmonic_num=8, voiced_threshod=10)
        self.f0_upsamp = nn.Upsample(scale_factor=math.prod(upsample_rates) * gen_istft_hop_size)
        self.noise_convs = nn.ModuleList()
        self.noise_res = nn.ModuleList()
        self.ups = nn.ModuleList()
        for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
            self.ups.append(weight_norm(
                nn.ConvTranspose1d(upsample_initial_channel//(2**i), upsample_initial_channel//(2**(i+1)),
                                   k, u, padding=(k-u)//2)))
        self.resblocks = nn.ModuleList()
        for i in range(len(self.ups)):
            ch = upsample_initial_channel//(2**(i+1))
            for j, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)):
                self.resblocks.append(AdaINResBlock1(ch, k, d, style_dim))
            c_cur = upsample_initial_channel // (2 ** (i + 1))
            if i + 1 < len(upsample_rates):
                stride_f0 = math.prod(upsample_rates[i + 1:])
                self.noise_convs.append(nn.Conv1d(
                    gen_istft_n_fft + 2, c_cur, kernel_size=stride_f0 * 2, stride=stride_f0, padding=(stride_f0+1) // 2))
                self.noise_res.append(AdaINResBlock1(c_cur, 7, [1,3,5], style_dim))
            else:
                self.noise_convs.append(nn.Conv1d(gen_istft_n_fft + 2, c_cur, kernel_size=1))
                self.noise_res.append(AdaINResBlock1(c_cur, 11, [1,3,5], style_dim))
        self.post_n_fft = gen_istft_n_fft
        self.conv_post = weight_norm(nn.Conv1d(ch, self.post_n_fft + 2, 7, 1, padding=3))
        self.ups.apply(init_weights)
        self.conv_post.apply(init_weights)
        self.reflection_pad = nn.ReflectionPad1d((1, 0))
        self.stft = (
            CustomSTFT(filter_length=gen_istft_n_fft, hop_length=gen_istft_hop_size, win_length=gen_istft_n_fft)
            if disable_complex
            else TorchSTFT(filter_length=gen_istft_n_fft, hop_length=gen_istft_hop_size, win_length=gen_istft_n_fft)
        )

    def forward(self, x, s, f0):
        with torch.no_grad():
            f0 = self.f0_upsamp(f0[:, None]).transpose(1, 2)  # bs,n,t
            har_source, noi_source, uv = self.m_source(f0)
            har_source = har_source.transpose(1, 2).squeeze(1)
            har_spec, har_phase = self.stft.transform(har_source)
            har = torch.cat([har_spec, har_phase], dim=1)
        for i in range(self.num_upsamples):
            x = F.leaky_relu(x, negative_slope=0.1)
            x_source = self.noise_convs[i](har)
            x_source = self.noise_res[i](x_source, s)
            x = self.ups[i](x)
            if i == self.num_upsamples - 1:
                x = self.reflection_pad(x)
            x = x + x_source
            xs = None
            for j in range(self.num_kernels):
                if xs is None:
                    xs = self.resblocks[i*self.num_kernels+j](x, s)
                else:
                    xs += self.resblocks[i*self.num_kernels+j](x, s)
            x = xs / self.num_kernels
        x = F.leaky_relu(x)
        x = self.conv_post(x)
        spec = torch.exp(x[:, :self.post_n_fft // 2 + 1, :])
        phase = torch.sin(x[:, self.post_n_fft // 2 + 1:, :])
        return self.stft.inverse(spec, phase)


class UpSample1d(nn.Module):
    def __init__(self, layer_type):
        super().__init__()
        self.layer_type = layer_type

    def forward(self, x):
        if self.layer_type == 'none':
            return x
        else:
            return F.interpolate(x, scale_factor=2, mode='nearest')
			
		||||
 | 
			
		||||
class AdainResBlk1d(nn.Module):
 | 
			
		||||
    def __init__(self, dim_in, dim_out, style_dim=64, actv=nn.LeakyReLU(0.2), upsample='none', dropout_p=0.0):
 | 
			
		||||
        super().__init__()
 | 
			
		||||
        self.actv = actv
 | 
			
		||||
        self.upsample_type = upsample
 | 
			
		||||
        self.upsample = UpSample1d(upsample)
 | 
			
		||||
        self.learned_sc = dim_in != dim_out
 | 
			
		||||
        self._build_weights(dim_in, dim_out, style_dim)
 | 
			
		||||
        self.dropout = nn.Dropout(dropout_p)
 | 
			
		||||
        if upsample == 'none':
 | 
			
		||||
            self.pool = nn.Identity()
 | 
			
		||||
        else:
 | 
			
		||||
            self.pool = weight_norm(nn.ConvTranspose1d(dim_in, dim_in, kernel_size=3, stride=2, groups=dim_in, padding=1, output_padding=1))
 | 
			
		||||
 | 
			
		||||
    def _build_weights(self, dim_in, dim_out, style_dim):
 | 
			
		||||
        self.conv1 = weight_norm(nn.Conv1d(dim_in, dim_out, 3, 1, 1))
 | 
			
		||||
        self.conv2 = weight_norm(nn.Conv1d(dim_out, dim_out, 3, 1, 1))
 | 
			
		||||
        self.norm1 = AdaIN1d(style_dim, dim_in)
 | 
			
		||||
        self.norm2 = AdaIN1d(style_dim, dim_out)
 | 
			
		||||
        if self.learned_sc:
 | 
			
		||||
            self.conv1x1 = weight_norm(nn.Conv1d(dim_in, dim_out, 1, 1, 0, bias=False))
 | 
			
		||||
 | 
			
		||||
    def _shortcut(self, x):
 | 
			
		||||
        x = self.upsample(x)
 | 
			
		||||
        if self.learned_sc:
 | 
			
		||||
            x = self.conv1x1(x)
 | 
			
		||||
        return x
 | 
			
		||||
 | 
			
		||||
    def _residual(self, x, s):
 | 
			
		||||
        x = self.norm1(x, s)
 | 
			
		||||
        x = self.actv(x)
 | 
			
		||||
        x = self.pool(x)
 | 
			
		||||
        x = self.conv1(self.dropout(x))
 | 
			
		||||
        x = self.norm2(x, s)
 | 
			
		||||
        x = self.actv(x)
 | 
			
		||||
        x = self.conv2(self.dropout(x))
 | 
			
		||||
        return x
 | 
			
		||||
 | 
			
		||||
    def forward(self, x, s):
 | 
			
		||||
        out = self._residual(x, s)
 | 
			
		||||
        out = (out + self._shortcut(x)) * torch.rsqrt(torch.tensor(2))
 | 
			
		||||
        return out
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class Decoder(nn.Module):
 | 
			
		||||
    def __init__(self, dim_in, style_dim, dim_out, 
 | 
			
		||||
                 resblock_kernel_sizes,
 | 
			
		||||
                 upsample_rates,
 | 
			
		||||
                 upsample_initial_channel,
 | 
			
		||||
                 resblock_dilation_sizes,
 | 
			
		||||
                 upsample_kernel_sizes,
 | 
			
		||||
                 gen_istft_n_fft, gen_istft_hop_size,
 | 
			
		||||
                 disable_complex=False):
 | 
			
		||||
        super().__init__()
 | 
			
		||||
        self.encode = AdainResBlk1d(dim_in + 2, 1024, style_dim)
 | 
			
		||||
        self.decode = nn.ModuleList()
 | 
			
		||||
        self.decode.append(AdainResBlk1d(1024 + 2 + 64, 1024, style_dim))
 | 
			
		||||
        self.decode.append(AdainResBlk1d(1024 + 2 + 64, 1024, style_dim))
 | 
			
		||||
        self.decode.append(AdainResBlk1d(1024 + 2 + 64, 1024, style_dim))
 | 
			
		||||
        self.decode.append(AdainResBlk1d(1024 + 2 + 64, 512, style_dim, upsample=True))
 | 
			
		||||
        self.F0_conv = weight_norm(nn.Conv1d(1, 1, kernel_size=3, stride=2, groups=1, padding=1))
 | 
			
		||||
        self.N_conv = weight_norm(nn.Conv1d(1, 1, kernel_size=3, stride=2, groups=1, padding=1))
 | 
			
		||||
        self.asr_res = nn.Sequential(weight_norm(nn.Conv1d(512, 64, kernel_size=1)))
 | 
			
		||||
        self.generator = Generator(style_dim, resblock_kernel_sizes, upsample_rates, 
 | 
			
		||||
                                   upsample_initial_channel, resblock_dilation_sizes, 
 | 
			
		||||
                                   upsample_kernel_sizes, gen_istft_n_fft, gen_istft_hop_size, disable_complex=disable_complex)
 | 
			
		||||
 | 
			
		||||
    def forward(self, asr, F0_curve, N, s):
 | 
			
		||||
        F0 = self.F0_conv(F0_curve.unsqueeze(1))
 | 
			
		||||
        N = self.N_conv(N.unsqueeze(1))
 | 
			
		||||
        x = torch.cat([asr, F0, N], axis=1)
 | 
			
		||||
        x = self.encode(x, s)
 | 
			
		||||
        asr_res = self.asr_res(asr)
 | 
			
		||||
        res = True
 | 
			
		||||
        for block in self.decode:
 | 
			
		||||
            if res:
 | 
			
		||||
                x = torch.cat([x, asr_res, F0, N], axis=1)
 | 
			
		||||
            x = block(x, s)
 | 
			
		||||
            if block.upsample_type != "none":
 | 
			
		||||
                res = False
 | 
			
		||||
        x = self.generator(x, s, F0_curve)
 | 
			
		||||
        return x
 | 
			
		||||
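# Hedged note (not part of the original file): KModel in model.py wires this
# decoder up roughly as
#
#   decoder = Decoder(dim_in=hidden_dim, style_dim=style_dim, dim_out=n_mels, **istftnet_cfg)
#   audio = decoder(asr, F0_pred, N_pred, ref_s[:, :128])
#
# where asr are the duration-aligned text-encoder features, F0_pred/N_pred come
# from ProsodyPredictor.F0Ntrain, and ref_s[:, :128] is the acoustic half of the
# reference style vector (the 128-element split is taken from KModel.forward_with_tokens;
# hidden_dim/n_mels/istftnet_cfg stand for the corresponding config entries).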
155  wan/multitalk/kokoro/model.py  Normal file
@ -0,0 +1,155 @@
 | 
			
		||||
from .istftnet import Decoder
 | 
			
		||||
from .modules import CustomAlbert, ProsodyPredictor, TextEncoder
 | 
			
		||||
from dataclasses import dataclass
 | 
			
		||||
from huggingface_hub import hf_hub_download
 | 
			
		||||
from loguru import logger
 | 
			
		||||
from transformers import AlbertConfig
 | 
			
		||||
from typing import Dict, Optional, Union
 | 
			
		||||
import json
 | 
			
		||||
import torch
 | 
			
		||||
import os
 | 
			
		||||
 | 
			
		||||
class KModel(torch.nn.Module):
 | 
			
		||||
    '''
 | 
			
		||||
    KModel is a torch.nn.Module with 2 main responsibilities:
 | 
			
		||||
    1. Init weights, downloading config.json + model.pth from HF if needed
 | 
			
		||||
    2. forward(phonemes: str, ref_s: FloatTensor) -> (audio: FloatTensor)
 | 
			
		||||
 | 
			
		||||
    You likely only need one KModel instance, and it can be reused across
 | 
			
		||||
    multiple KPipelines to avoid redundant memory allocation.
 | 
			
		||||
 | 
			
		||||
    Unlike KPipeline, KModel is language-blind.
 | 
			
		||||
 | 
			
		||||
    KModel stores self.vocab and thus knows how to map phonemes -> input_ids,
 | 
			
		||||
    so there is no need to repeatedly download config.json outside of KModel.
 | 
			
		||||
    '''
 | 
			
		||||
 | 
			
		||||
    MODEL_NAMES = {
 | 
			
		||||
        'hexgrad/Kokoro-82M': 'kokoro-v1_0.pth',
 | 
			
		||||
        'hexgrad/Kokoro-82M-v1.1-zh': 'kokoro-v1_1-zh.pth',
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    def __init__(
 | 
			
		||||
        self,
 | 
			
		||||
        repo_id: Optional[str] = None,
 | 
			
		||||
        config: Union[Dict, str, None] = None,
 | 
			
		||||
        model: Optional[str] = None,
 | 
			
		||||
        disable_complex: bool = False
 | 
			
		||||
    ):
 | 
			
		||||
        super().__init__()
 | 
			
		||||
        if repo_id is None:
 | 
			
		||||
            repo_id = 'hexgrad/Kokoro-82M'
 | 
			
		||||
            print(f"WARNING: Defaulting repo_id to {repo_id}. Pass repo_id='{repo_id}' to suppress this warning.")
 | 
			
		||||
        self.repo_id = repo_id
 | 
			
		||||
        if not isinstance(config, dict):
 | 
			
		||||
            if not config:
 | 
			
		||||
                logger.debug("No config provided, downloading from HF")
 | 
			
		||||
                config = hf_hub_download(repo_id=repo_id, filename='config.json')
 | 
			
		||||
            with open(config, 'r', encoding='utf-8') as r:
 | 
			
		||||
                config = json.load(r)
 | 
			
		||||
                logger.debug(f"Loaded config: {config}")
 | 
			
		||||
        self.vocab = config['vocab']
 | 
			
		||||
        self.bert = CustomAlbert(AlbertConfig(vocab_size=config['n_token'], **config['plbert']))
 | 
			
		||||
        self.bert_encoder = torch.nn.Linear(self.bert.config.hidden_size, config['hidden_dim'])
 | 
			
		||||
        self.context_length = self.bert.config.max_position_embeddings
 | 
			
		||||
        self.predictor = ProsodyPredictor(
 | 
			
		||||
            style_dim=config['style_dim'], d_hid=config['hidden_dim'],
 | 
			
		||||
            nlayers=config['n_layer'], max_dur=config['max_dur'], dropout=config['dropout']
 | 
			
		||||
        )
 | 
			
		||||
        self.text_encoder = TextEncoder(
 | 
			
		||||
            channels=config['hidden_dim'], kernel_size=config['text_encoder_kernel_size'],
 | 
			
		||||
            depth=config['n_layer'], n_symbols=config['n_token']
 | 
			
		||||
        )
 | 
			
		||||
        self.decoder = Decoder(
 | 
			
		||||
            dim_in=config['hidden_dim'], style_dim=config['style_dim'],
 | 
			
		||||
            dim_out=config['n_mels'], disable_complex=disable_complex, **config['istftnet']
 | 
			
		||||
        )
 | 
			
		||||
        if not model:
 | 
			
		||||
            try:
 | 
			
		||||
                model = hf_hub_download(repo_id=repo_id, filename=KModel.MODEL_NAMES[repo_id])
 | 
			
		||||
            except Exception:
 | 
			
		||||
                model = os.path.join(repo_id, 'kokoro-v1_0.pth')
 | 
			
		||||
        for key, state_dict in torch.load(model, map_location='cpu', weights_only=True).items():
 | 
			
		||||
            assert hasattr(self, key), key
 | 
			
		||||
            try:
 | 
			
		||||
                getattr(self, key).load_state_dict(state_dict)
 | 
			
		||||
            except Exception:
 | 
			
		||||
                logger.debug(f"Did not load {key} from state_dict")
 | 
			
		||||
                state_dict = {k[7:]: v for k, v in state_dict.items()}
 | 
			
		||||
                getattr(self, key).load_state_dict(state_dict, strict=False)
 | 
			
		||||
 | 
			
		||||
    @property
 | 
			
		||||
    def device(self):
 | 
			
		||||
        return self.bert.device
 | 
			
		||||
 | 
			
		||||
    @dataclass
 | 
			
		||||
    class Output:
 | 
			
		||||
        audio: torch.FloatTensor
 | 
			
		||||
        pred_dur: Optional[torch.LongTensor] = None
 | 
			
		||||
 | 
			
		||||
    @torch.no_grad()
 | 
			
		||||
    def forward_with_tokens(
 | 
			
		||||
        self,
 | 
			
		||||
        input_ids: torch.LongTensor,
 | 
			
		||||
        ref_s: torch.FloatTensor,
 | 
			
		||||
        speed: float = 1
 | 
			
		||||
    ) -> tuple[torch.FloatTensor, torch.LongTensor]:
 | 
			
		||||
        input_lengths = torch.full(
 | 
			
		||||
            (input_ids.shape[0],), 
 | 
			
		||||
            input_ids.shape[-1], 
 | 
			
		||||
            device=input_ids.device,
 | 
			
		||||
            dtype=torch.long
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        text_mask = torch.arange(input_lengths.max()).unsqueeze(0).expand(input_lengths.shape[0], -1).type_as(input_lengths)
 | 
			
		||||
        text_mask = torch.gt(text_mask+1, input_lengths.unsqueeze(1)).to(self.device)
 | 
			
		||||
        bert_dur = self.bert(input_ids, attention_mask=(~text_mask).int())
 | 
			
		||||
        d_en = self.bert_encoder(bert_dur).transpose(-1, -2)
 | 
			
		||||
        s = ref_s[:, 128:]
 | 
			
		||||
        d = self.predictor.text_encoder(d_en, s, input_lengths, text_mask)
 | 
			
		||||
        x, _ = self.predictor.lstm(d)
 | 
			
		||||
        duration = self.predictor.duration_proj(x)
 | 
			
		||||
        duration = torch.sigmoid(duration).sum(axis=-1) / speed
 | 
			
		||||
        pred_dur = torch.round(duration).clamp(min=1).long().squeeze()
 | 
			
		||||
        indices = torch.repeat_interleave(torch.arange(input_ids.shape[1], device=self.device), pred_dur)
 | 
			
		||||
        pred_aln_trg = torch.zeros((input_ids.shape[1], indices.shape[0]), device=self.device)
 | 
			
		||||
        pred_aln_trg[indices, torch.arange(indices.shape[0])] = 1
 | 
			
		||||
        pred_aln_trg = pred_aln_trg.unsqueeze(0).to(self.device)
 | 
			
		||||
        en = d.transpose(-1, -2) @ pred_aln_trg
 | 
			
		||||
        F0_pred, N_pred = self.predictor.F0Ntrain(en, s)
 | 
			
		||||
        t_en = self.text_encoder(input_ids, input_lengths, text_mask)
 | 
			
		||||
        asr = t_en @ pred_aln_trg
 | 
			
		||||
        audio = self.decoder(asr, F0_pred, N_pred, ref_s[:, :128]).squeeze()
 | 
			
		||||
        return audio, pred_dur
 | 
			
		||||
 | 
			
		||||
    def forward(
 | 
			
		||||
        self,
 | 
			
		||||
        phonemes: str,
 | 
			
		||||
        ref_s: torch.FloatTensor,
 | 
			
		||||
        speed: float = 1,
 | 
			
		||||
        return_output: bool = False
 | 
			
		||||
    ) -> Union['KModel.Output', torch.FloatTensor]:
 | 
			
		||||
        input_ids = list(filter(lambda i: i is not None, map(lambda p: self.vocab.get(p), phonemes)))
 | 
			
		||||
        logger.debug(f"phonemes: {phonemes} -> input_ids: {input_ids}")
 | 
			
		||||
        assert len(input_ids)+2 <= self.context_length, (len(input_ids)+2, self.context_length)
 | 
			
		||||
        input_ids = torch.LongTensor([[0, *input_ids, 0]]).to(self.device)
 | 
			
		||||
        ref_s = ref_s.to(self.device)
 | 
			
		||||
        audio, pred_dur = self.forward_with_tokens(input_ids, ref_s, speed)
 | 
			
		||||
        audio = audio.squeeze().cpu()
 | 
			
		||||
        pred_dur = pred_dur.cpu() if pred_dur is not None else None
 | 
			
		||||
        logger.debug(f"pred_dur: {pred_dur}")
 | 
			
		||||
        return self.Output(audio=audio, pred_dur=pred_dur) if return_output else audio
 | 
			
		||||
 | 
			
		||||
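# Hedged usage sketch (illustrative only; the voice-pack indexing mirrors
# KPipeline.infer in pipeline.py, while the file name and phoneme string are
# hypothetical):
#
#   model = KModel(repo_id='hexgrad/Kokoro-82M').eval()
#   pack = torch.load('af_heart.pt', weights_only=True)    # hypothetical local voice pack
#   phonemes = 'həlˈoʊ wˈɜɹld'                              # hypothetical phoneme string
#   audio = model(phonemes, pack[len(phonemes) - 1])        # 24 kHz mono waveform
#
# A single KModel can be shared by several KPipelines; only the pipelines differ per language.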
class KModelForONNX(torch.nn.Module):
 | 
			
		||||
    def __init__(self, kmodel: KModel):
 | 
			
		||||
        super().__init__()
 | 
			
		||||
        self.kmodel = kmodel
 | 
			
		||||
 | 
			
		||||
    def forward(
 | 
			
		||||
        self,
 | 
			
		||||
        input_ids: torch.LongTensor,
 | 
			
		||||
        ref_s: torch.FloatTensor,
 | 
			
		||||
        speed: float = 1
 | 
			
		||||
    ) -> tuple[torch.FloatTensor, torch.LongTensor]:
 | 
			
		||||
        waveform, duration = self.kmodel.forward_with_tokens(input_ids, ref_s, speed)
 | 
			
		||||
        return waveform, duration
 | 
			
		||||
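# Hedged note (not part of the original file): KModelForONNX exposes the
# tensor-only forward_with_tokens path, which is what a torch.onnx.export call
# would trace. A minimal sketch, assuming example tensors of plausible shapes:
#
#   wrapper = KModelForONNX(KModel(repo_id='hexgrad/Kokoro-82M')).eval()
#   example = (torch.zeros(1, 32, dtype=torch.long), torch.zeros(1, 256))
#   torch.onnx.export(wrapper, example, 'kokoro.onnx', opset_version=17,
#                     input_names=['input_ids', 'ref_s'],
#                     output_names=['waveform', 'duration'])
#
# Dynamic sequence lengths would additionally need dynamic_axes; that is not shown here.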
183  wan/multitalk/kokoro/modules.py  Normal file
@ -0,0 +1,183 @@
 | 
			
		||||
# https://github.com/yl4579/StyleTTS2/blob/main/models.py
 | 
			
		||||
from .istftnet import AdainResBlk1d
 | 
			
		||||
from torch.nn.utils import weight_norm
 | 
			
		||||
from transformers import AlbertModel
 | 
			
		||||
import numpy as np
 | 
			
		||||
import torch
 | 
			
		||||
import torch.nn as nn
 | 
			
		||||
import torch.nn.functional as F
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class LinearNorm(nn.Module):
 | 
			
		||||
    def __init__(self, in_dim, out_dim, bias=True, w_init_gain='linear'):
 | 
			
		||||
        super(LinearNorm, self).__init__()
 | 
			
		||||
        self.linear_layer = nn.Linear(in_dim, out_dim, bias=bias)
 | 
			
		||||
        nn.init.xavier_uniform_(self.linear_layer.weight, gain=nn.init.calculate_gain(w_init_gain))
 | 
			
		||||
 | 
			
		||||
    def forward(self, x):
 | 
			
		||||
        return self.linear_layer(x)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class LayerNorm(nn.Module):
 | 
			
		||||
    def __init__(self, channels, eps=1e-5):
 | 
			
		||||
        super().__init__()
 | 
			
		||||
        self.channels = channels
 | 
			
		||||
        self.eps = eps
 | 
			
		||||
        self.gamma = nn.Parameter(torch.ones(channels))
 | 
			
		||||
        self.beta = nn.Parameter(torch.zeros(channels))
 | 
			
		||||
 | 
			
		||||
    def forward(self, x):
 | 
			
		||||
        x = x.transpose(1, -1)
 | 
			
		||||
        x = F.layer_norm(x, (self.channels,), self.gamma, self.beta, self.eps)
 | 
			
		||||
        return x.transpose(1, -1)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class TextEncoder(nn.Module):
 | 
			
		||||
    def __init__(self, channels, kernel_size, depth, n_symbols, actv=nn.LeakyReLU(0.2)):
 | 
			
		||||
        super().__init__()
 | 
			
		||||
        self.embedding = nn.Embedding(n_symbols, channels)
 | 
			
		||||
        padding = (kernel_size - 1) // 2
 | 
			
		||||
        self.cnn = nn.ModuleList()
 | 
			
		||||
        for _ in range(depth):
 | 
			
		||||
            self.cnn.append(nn.Sequential(
 | 
			
		||||
                weight_norm(nn.Conv1d(channels, channels, kernel_size=kernel_size, padding=padding)),
 | 
			
		||||
                LayerNorm(channels),
 | 
			
		||||
                actv,
 | 
			
		||||
                nn.Dropout(0.2),
 | 
			
		||||
            ))
 | 
			
		||||
        self.lstm = nn.LSTM(channels, channels//2, 1, batch_first=True, bidirectional=True)
 | 
			
		||||
 | 
			
		||||
    def forward(self, x, input_lengths, m):
 | 
			
		||||
        x = self.embedding(x)  # [B, T, emb]
 | 
			
		||||
        x = x.transpose(1, 2)  # [B, emb, T]
 | 
			
		||||
        m = m.unsqueeze(1)
 | 
			
		||||
        x.masked_fill_(m, 0.0)
 | 
			
		||||
        for c in self.cnn:
 | 
			
		||||
            x = c(x)
 | 
			
		||||
            x.masked_fill_(m, 0.0)
 | 
			
		||||
        x = x.transpose(1, 2)  # [B, T, chn]
 | 
			
		||||
        lengths = input_lengths if input_lengths.device == torch.device('cpu') else input_lengths.to('cpu')
 | 
			
		||||
        x = nn.utils.rnn.pack_padded_sequence(x, lengths, batch_first=True, enforce_sorted=False)
 | 
			
		||||
        self.lstm.flatten_parameters()
 | 
			
		||||
        x, _ = self.lstm(x)
 | 
			
		||||
        x, _ = nn.utils.rnn.pad_packed_sequence(x, batch_first=True)
 | 
			
		||||
        x = x.transpose(-1, -2)
 | 
			
		||||
        x_pad = torch.zeros([x.shape[0], x.shape[1], m.shape[-1]], device=x.device)
 | 
			
		||||
        x_pad[:, :, :x.shape[-1]] = x
 | 
			
		||||
        x = x_pad
 | 
			
		||||
        x.masked_fill_(m, 0.0)
 | 
			
		||||
        return x
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class AdaLayerNorm(nn.Module):
 | 
			
		||||
    def __init__(self, style_dim, channels, eps=1e-5):
 | 
			
		||||
        super().__init__()
 | 
			
		||||
        self.channels = channels
 | 
			
		||||
        self.eps = eps
 | 
			
		||||
        self.fc = nn.Linear(style_dim, channels*2)
 | 
			
		||||
 | 
			
		||||
    def forward(self, x, s):
 | 
			
		||||
        x = x.transpose(-1, -2)
 | 
			
		||||
        x = x.transpose(1, -1)
 | 
			
		||||
        h = self.fc(s)
 | 
			
		||||
        h = h.view(h.size(0), h.size(1), 1)
 | 
			
		||||
        gamma, beta = torch.chunk(h, chunks=2, dim=1)
 | 
			
		||||
        gamma, beta = gamma.transpose(1, -1), beta.transpose(1, -1)
 | 
			
		||||
        x = F.layer_norm(x, (self.channels,), eps=self.eps)
 | 
			
		||||
        x = (1 + gamma) * x + beta
 | 
			
		||||
        return x.transpose(1, -1).transpose(-1, -2)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class ProsodyPredictor(nn.Module):
 | 
			
		||||
    def __init__(self, style_dim, d_hid, nlayers, max_dur=50, dropout=0.1):
 | 
			
		||||
        super().__init__()
 | 
			
		||||
        self.text_encoder = DurationEncoder(sty_dim=style_dim, d_model=d_hid,nlayers=nlayers, dropout=dropout)
 | 
			
		||||
        self.lstm = nn.LSTM(d_hid + style_dim, d_hid // 2, 1, batch_first=True, bidirectional=True)
 | 
			
		||||
        self.duration_proj = LinearNorm(d_hid, max_dur)
 | 
			
		||||
        self.shared = nn.LSTM(d_hid + style_dim, d_hid // 2, 1, batch_first=True, bidirectional=True)
 | 
			
		||||
        self.F0 = nn.ModuleList()
 | 
			
		||||
        self.F0.append(AdainResBlk1d(d_hid, d_hid, style_dim, dropout_p=dropout))
 | 
			
		||||
        self.F0.append(AdainResBlk1d(d_hid, d_hid // 2, style_dim, upsample=True, dropout_p=dropout))
 | 
			
		||||
        self.F0.append(AdainResBlk1d(d_hid // 2, d_hid // 2, style_dim, dropout_p=dropout))
 | 
			
		||||
        self.N = nn.ModuleList()
 | 
			
		||||
        self.N.append(AdainResBlk1d(d_hid, d_hid, style_dim, dropout_p=dropout))
 | 
			
		||||
        self.N.append(AdainResBlk1d(d_hid, d_hid // 2, style_dim, upsample=True, dropout_p=dropout))
 | 
			
		||||
        self.N.append(AdainResBlk1d(d_hid // 2, d_hid // 2, style_dim, dropout_p=dropout))
 | 
			
		||||
        self.F0_proj = nn.Conv1d(d_hid // 2, 1, 1, 1, 0)
 | 
			
		||||
        self.N_proj = nn.Conv1d(d_hid // 2, 1, 1, 1, 0)
 | 
			
		||||
 | 
			
		||||
    def forward(self, texts, style, text_lengths, alignment, m):
 | 
			
		||||
        d = self.text_encoder(texts, style, text_lengths, m)
 | 
			
		||||
        m = m.unsqueeze(1)
 | 
			
		||||
        lengths = text_lengths if text_lengths.device == torch.device('cpu') else text_lengths.to('cpu')
 | 
			
		||||
        x = nn.utils.rnn.pack_padded_sequence(d, lengths, batch_first=True, enforce_sorted=False)
 | 
			
		||||
        self.lstm.flatten_parameters()
 | 
			
		||||
        x, _ = self.lstm(x)
 | 
			
		||||
        x, _ = nn.utils.rnn.pad_packed_sequence(x, batch_first=True)
 | 
			
		||||
        x_pad = torch.zeros([x.shape[0], m.shape[-1], x.shape[-1]], device=x.device)
 | 
			
		||||
        x_pad[:, :x.shape[1], :] = x
 | 
			
		||||
        x = x_pad
 | 
			
		||||
        duration = self.duration_proj(nn.functional.dropout(x, 0.5, training=False))
 | 
			
		||||
        en = (d.transpose(-1, -2) @ alignment)
 | 
			
		||||
        return duration.squeeze(-1), en
 | 
			
		||||
 | 
			
		||||
    def F0Ntrain(self, x, s):
 | 
			
		||||
        x, _ = self.shared(x.transpose(-1, -2))
 | 
			
		||||
        F0 = x.transpose(-1, -2)
 | 
			
		||||
        for block in self.F0:
 | 
			
		||||
            F0 = block(F0, s)
 | 
			
		||||
        F0 = self.F0_proj(F0)
 | 
			
		||||
        N = x.transpose(-1, -2)
 | 
			
		||||
        for block in self.N:
 | 
			
		||||
            N = block(N, s)
 | 
			
		||||
        N = self.N_proj(N)
 | 
			
		||||
        return F0.squeeze(1), N.squeeze(1)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class DurationEncoder(nn.Module):
 | 
			
		||||
    def __init__(self, sty_dim, d_model, nlayers, dropout=0.1):
 | 
			
		||||
        super().__init__()
 | 
			
		||||
        self.lstms = nn.ModuleList()
 | 
			
		||||
        for _ in range(nlayers):
 | 
			
		||||
            self.lstms.append(nn.LSTM(d_model + sty_dim, d_model // 2, num_layers=1, batch_first=True, bidirectional=True, dropout=dropout))
 | 
			
		||||
            self.lstms.append(AdaLayerNorm(sty_dim, d_model))
 | 
			
		||||
        self.dropout = dropout
 | 
			
		||||
        self.d_model = d_model
 | 
			
		||||
        self.sty_dim = sty_dim
 | 
			
		||||
 | 
			
		||||
    def forward(self, x, style, text_lengths, m):
 | 
			
		||||
        masks = m
 | 
			
		||||
        x = x.permute(2, 0, 1)
 | 
			
		||||
        s = style.expand(x.shape[0], x.shape[1], -1)
 | 
			
		||||
        x = torch.cat([x, s], axis=-1)
 | 
			
		||||
        x.masked_fill_(masks.unsqueeze(-1).transpose(0, 1), 0.0)
 | 
			
		||||
        x = x.transpose(0, 1)
 | 
			
		||||
        x = x.transpose(-1, -2)
 | 
			
		||||
        for block in self.lstms:
 | 
			
		||||
            if isinstance(block, AdaLayerNorm):
 | 
			
		||||
                x = block(x.transpose(-1, -2), style).transpose(-1, -2)
 | 
			
		||||
                x = torch.cat([x, s.permute(1, 2, 0)], axis=1)
 | 
			
		||||
                x.masked_fill_(masks.unsqueeze(-1).transpose(-1, -2), 0.0)
 | 
			
		||||
            else:
 | 
			
		||||
                lengths = text_lengths if text_lengths.device == torch.device('cpu') else text_lengths.to('cpu')
 | 
			
		||||
                x = x.transpose(-1, -2)
 | 
			
		||||
                x = nn.utils.rnn.pack_padded_sequence(
 | 
			
		||||
                    x, lengths, batch_first=True, enforce_sorted=False)
 | 
			
		||||
                block.flatten_parameters()
 | 
			
		||||
                x, _ = block(x)
 | 
			
		||||
                x, _ = nn.utils.rnn.pad_packed_sequence(
 | 
			
		||||
                    x, batch_first=True)
 | 
			
		||||
                x = F.dropout(x, p=self.dropout, training=False)
 | 
			
		||||
                x = x.transpose(-1, -2)
 | 
			
		||||
                x_pad = torch.zeros([x.shape[0], x.shape[1], m.shape[-1]], device=x.device)
 | 
			
		||||
                x_pad[:, :, :x.shape[-1]] = x
 | 
			
		||||
                x = x_pad
 | 
			
		||||
 | 
			
		||||
        return x.transpose(-1, -2)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# https://github.com/yl4579/StyleTTS2/blob/main/Utils/PLBERT/util.py
 | 
			
		||||
class CustomAlbert(AlbertModel):
 | 
			
		||||
    def forward(self, *args, **kwargs):
 | 
			
		||||
        outputs = super().forward(*args, **kwargs)
 | 
			
		||||
        return outputs.last_hidden_state
 | 
			
		||||
445  wan/multitalk/kokoro/pipeline.py  Normal file
@ -0,0 +1,445 @@
 | 
			
		||||
from .model import KModel
 | 
			
		||||
from dataclasses import dataclass
 | 
			
		||||
from huggingface_hub import hf_hub_download
 | 
			
		||||
from loguru import logger
 | 
			
		||||
from misaki import en, espeak
 | 
			
		||||
from typing import Callable, Generator, List, Optional, Tuple, Union
 | 
			
		||||
import re
 | 
			
		||||
import torch
 | 
			
		||||
import os
 | 
			
		||||
 | 
			
		||||
ALIASES = {
 | 
			
		||||
    'en-us': 'a',
 | 
			
		||||
    'en-gb': 'b',
 | 
			
		||||
    'es': 'e',
 | 
			
		||||
    'fr-fr': 'f',
 | 
			
		||||
    'hi': 'h',
 | 
			
		||||
    'it': 'i',
 | 
			
		||||
    'pt-br': 'p',
 | 
			
		||||
    'ja': 'j',
 | 
			
		||||
    'zh': 'z',
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
LANG_CODES = dict(
 | 
			
		||||
    # pip install misaki[en]
 | 
			
		||||
    a='American English',
 | 
			
		||||
    b='British English',
 | 
			
		||||
 | 
			
		||||
    # espeak-ng
 | 
			
		||||
    e='es',
 | 
			
		||||
    f='fr-fr',
 | 
			
		||||
    h='hi',
 | 
			
		||||
    i='it',
 | 
			
		||||
    p='pt-br',
 | 
			
		||||
 | 
			
		||||
    # pip install misaki[ja]
 | 
			
		||||
    j='Japanese',
 | 
			
		||||
 | 
			
		||||
    # pip install misaki[zh]
 | 
			
		||||
    z='Mandarin Chinese',
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
class KPipeline:
 | 
			
		||||
    '''
 | 
			
		||||
    KPipeline is a language-aware support class with 2 main responsibilities:
 | 
			
		||||
    1. Perform language-specific G2P, mapping (and chunking) text -> phonemes
 | 
			
		||||
    2. Manage and store voices, lazily downloaded from HF if needed
 | 
			
		||||
 | 
			
		||||
    You are expected to have one KPipeline per language. If you have multiple
 | 
			
		||||
    KPipelines, you should reuse one KModel instance across all of them.
 | 
			
		||||
 | 
			
		||||
    KPipeline is designed to work with a KModel, but this is not required.
 | 
			
		||||
    There are 2 ways to pass an existing model into a pipeline:
 | 
			
		||||
    1. On init: us_pipeline = KPipeline(lang_code='a', model=model)
 | 
			
		||||
    2. On call: us_pipeline(text, voice, model=model)
 | 
			
		||||
 | 
			
		||||
    By default, KPipeline will automatically initialize its own KModel. To
 | 
			
		||||
    suppress this, construct a "quiet" KPipeline with model=False.
 | 
			
		||||
 | 
			
		||||
    A "quiet" KPipeline yields (graphemes, phonemes, None) without generating
 | 
			
		||||
    any audio. You can use this to phonemize and chunk your text in advance.
 | 
			
		||||
 | 
			
		||||
    A "loud" KPipeline _with_ a model yields (graphemes, phonemes, audio).
 | 
			
		||||
    '''
 | 
			
		||||
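    # Hedged usage sketch (illustrative only; 'af_heart' appears later in this
    # file as an example voice, the rest is an assumption):
    #
    #   model = KModel(repo_id='hexgrad/Kokoro-82M')
    #   us_pipeline = KPipeline(lang_code='a', repo_id='hexgrad/Kokoro-82M', model=model)
    #   quiet = KPipeline(lang_code='a', repo_id='hexgrad/Kokoro-82M', model=False)
    #   for result in us_pipeline("Hello world!", voice='af_heart'):
    #       audio = result.audio          # None for the "quiet" pipeline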
    def __init__(
 | 
			
		||||
        self,
 | 
			
		||||
        lang_code: str,
 | 
			
		||||
        repo_id: Optional[str] = None,
 | 
			
		||||
        model: Union[KModel, bool] = True,
 | 
			
		||||
        trf: bool = False,
 | 
			
		||||
        en_callable: Optional[Callable[[str], str]] = None,
 | 
			
		||||
        device: Optional[str] = None
 | 
			
		||||
    ):
 | 
			
		||||
        """Initialize a KPipeline.
 | 
			
		||||
        
 | 
			
		||||
        Args:
 | 
			
		||||
            lang_code: Language code for G2P processing
 | 
			
		||||
            model: KModel instance, True to create new model, False for no model
 | 
			
		||||
            trf: Whether to use transformer-based G2P
 | 
			
		||||
            device: Override default device selection ('cuda' or 'cpu', or None for auto)
 | 
			
		||||
                   If None, will auto-select cuda if available
 | 
			
		||||
                   If 'cuda' and not available, will explicitly raise an error
 | 
			
		||||
        """
 | 
			
		||||
        if repo_id is None:
 | 
			
		||||
            repo_id = 'hexgrad/Kokoro-82M'
 | 
			
		||||
            print(f"WARNING: Defaulting repo_id to {repo_id}. Pass repo_id='{repo_id}' to suppress this warning.")
 | 
			
		||||
            config=None
 | 
			
		||||
        else:
 | 
			
		||||
            config = os.path.join(repo_id, 'config.json')
 | 
			
		||||
        self.repo_id = repo_id
 | 
			
		||||
        lang_code = lang_code.lower()
 | 
			
		||||
        lang_code = ALIASES.get(lang_code, lang_code)
 | 
			
		||||
        assert lang_code in LANG_CODES, (lang_code, LANG_CODES)
 | 
			
		||||
        self.lang_code = lang_code
 | 
			
		||||
        self.model = None
 | 
			
		||||
        if isinstance(model, KModel):
 | 
			
		||||
            self.model = model
 | 
			
		||||
        elif model:
 | 
			
		||||
            if device == 'cuda' and not torch.cuda.is_available():
 | 
			
		||||
                raise RuntimeError("CUDA requested but not available")
 | 
			
		||||
            if device == 'mps' and not torch.backends.mps.is_available():
 | 
			
		||||
                raise RuntimeError("MPS requested but not available")
 | 
			
		||||
            if device == 'mps' and os.environ.get('PYTORCH_ENABLE_MPS_FALLBACK') != '1':
 | 
			
		||||
                raise RuntimeError("MPS requested but fallback not enabled")
 | 
			
		||||
            if device is None:
 | 
			
		||||
                if torch.cuda.is_available():
 | 
			
		||||
                    device = 'cuda'
 | 
			
		||||
                elif os.environ.get('PYTORCH_ENABLE_MPS_FALLBACK') == '1' and torch.backends.mps.is_available():
 | 
			
		||||
                    device = 'mps'
 | 
			
		||||
                else:
 | 
			
		||||
                    device = 'cpu'
 | 
			
		||||
            try:
 | 
			
		||||
                self.model = KModel(repo_id=repo_id, config=config).to(device).eval()
 | 
			
		||||
            except RuntimeError as e:
 | 
			
		||||
                if device == 'cuda':
 | 
			
		||||
                    raise RuntimeError(f"""Failed to initialize model on CUDA: {e}. 
 | 
			
		||||
                                       Try setting device='cpu' or check CUDA installation.""")
 | 
			
		||||
                raise
 | 
			
		||||
        self.voices = {}
 | 
			
		||||
        if lang_code in 'ab':
 | 
			
		||||
            try:
 | 
			
		||||
                fallback = espeak.EspeakFallback(british=lang_code=='b')
 | 
			
		||||
            except Exception as e:
 | 
			
		||||
                logger.warning("EspeakFallback not Enabled: OOD words will be skipped")
 | 
			
		||||
                logger.warning(str(e))
 | 
			
		||||
                fallback = None
 | 
			
		||||
            self.g2p = en.G2P(trf=trf, british=lang_code=='b', fallback=fallback, unk='')
 | 
			
		||||
        elif lang_code == 'j':
 | 
			
		||||
            try:
 | 
			
		||||
                from misaki import ja
 | 
			
		||||
                self.g2p = ja.JAG2P()
 | 
			
		||||
            except ImportError:
 | 
			
		||||
                logger.error("You need to `pip install misaki[ja]` to use lang_code='j'")
 | 
			
		||||
                raise
 | 
			
		||||
        elif lang_code == 'z':
 | 
			
		||||
            try:
 | 
			
		||||
                from misaki import zh
 | 
			
		||||
                self.g2p = zh.ZHG2P(
 | 
			
		||||
                    version=None if repo_id.endswith('/Kokoro-82M') else '1.1',
 | 
			
		||||
                    en_callable=en_callable
 | 
			
		||||
                )
 | 
			
		||||
            except ImportError:
 | 
			
		||||
                logger.error("You need to `pip install misaki[zh]` to use lang_code='z'")
 | 
			
		||||
                raise
 | 
			
		||||
        else:
 | 
			
		||||
            language = LANG_CODES[lang_code]
 | 
			
		||||
            logger.warning(f"Using EspeakG2P(language='{language}'). Chunking logic not yet implemented, so long texts may be truncated unless you split them with '\\n'.")
 | 
			
		||||
            self.g2p = espeak.EspeakG2P(language=language)
 | 
			
		||||
 | 
			
		||||
    def load_single_voice(self, voice: str):
 | 
			
		||||
        if voice in self.voices:
 | 
			
		||||
            return self.voices[voice]
 | 
			
		||||
        if voice.endswith('.pt'):
 | 
			
		||||
            f = voice
 | 
			
		||||
        else:
 | 
			
		||||
            f = hf_hub_download(repo_id=self.repo_id, filename=f'voices/{voice}.pt')
 | 
			
		||||
            if not voice.startswith(self.lang_code):
 | 
			
		||||
                v = LANG_CODES.get(voice, voice)
 | 
			
		||||
                p = LANG_CODES.get(self.lang_code, self.lang_code)
 | 
			
		||||
                logger.warning(f'Language mismatch, loading {v} voice into {p} pipeline.')
 | 
			
		||||
        pack = torch.load(f, weights_only=True)
 | 
			
		||||
        self.voices[voice] = pack
 | 
			
		||||
        return pack
 | 
			
		||||
 | 
			
		||||
    """
 | 
			
		||||
    load_voice is a helper function that lazily downloads and loads a voice:
 | 
			
		||||
    Single voice can be requested (e.g. 'af_bella') or multiple voices (e.g. 'af_bella,af_jessica').
 | 
			
		||||
    If multiple voices are requested, they are averaged.
 | 
			
		||||
    Delimiter is optional and defaults to ','.
 | 
			
		||||
    """
 | 
			
		||||
    def load_voice(self, voice: Union[str, torch.FloatTensor], delimiter: str = ",") -> torch.FloatTensor:
 | 
			
		||||
        if isinstance(voice, torch.FloatTensor):
 | 
			
		||||
            return voice
 | 
			
		||||
        if voice in self.voices:
 | 
			
		||||
            return self.voices[voice]
 | 
			
		||||
        logger.debug(f"Loading voice: {voice}")
 | 
			
		||||
        packs = [self.load_single_voice(v) for v in voice.split(delimiter)]
 | 
			
		||||
        if len(packs) == 1:
 | 
			
		||||
            return packs[0]
 | 
			
		||||
        self.voices[voice] = torch.mean(torch.stack(packs), dim=0)
 | 
			
		||||
        return self.voices[voice]
 | 
			
		||||
 | 
			
		||||
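    # Hedged example (illustrative; both voice names are taken from the docstring above):
    #
    #   pack = pipeline.load_voice('af_bella')               # single voice pack
    #   blend = pipeline.load_voice('af_bella,af_jessica')   # element-wise mean of the two packs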
    @staticmethod
 | 
			
		||||
    def tokens_to_ps(tokens: List[en.MToken]) -> str:
 | 
			
		||||
        return ''.join(t.phonemes + (' ' if t.whitespace else '') for t in tokens).strip()
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    def waterfall_last(
 | 
			
		||||
        tokens: List[en.MToken],
 | 
			
		||||
        next_count: int,
 | 
			
		||||
        waterfall: List[str] = ['!.?…', ':;', ',—'],
 | 
			
		||||
        bumps: List[str] = [')', '”']
 | 
			
		||||
    ) -> int:
 | 
			
		||||
        for w in waterfall:
 | 
			
		||||
            z = next((i for i, t in reversed(list(enumerate(tokens))) if t.phonemes in set(w)), None)
 | 
			
		||||
            if z is None:
 | 
			
		||||
                continue
 | 
			
		||||
            z += 1
 | 
			
		||||
            if z < len(tokens) and tokens[z].phonemes in bumps:
 | 
			
		||||
                z += 1
 | 
			
		||||
            if next_count - len(KPipeline.tokens_to_ps(tokens[:z])) <= 510:
 | 
			
		||||
                return z
 | 
			
		||||
        return len(tokens)
 | 
			
		||||
 | 
			
		||||
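    # Hedged note (not part of the original file): the 510-phoneme budget used
    # here and in en_tokenize appears to leave room for the two boundary zeros
    # that KModel.forward wraps around input_ids (len(input_ids) + 2 must fit
    # within the model's context length).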
    @staticmethod
 | 
			
		||||
    def tokens_to_text(tokens: List[en.MToken]) -> str:
 | 
			
		||||
        return ''.join(t.text + t.whitespace for t in tokens).strip()
 | 
			
		||||
 | 
			
		||||
    def en_tokenize(
 | 
			
		||||
        self,
 | 
			
		||||
        tokens: List[en.MToken]
 | 
			
		||||
    ) -> Generator[Tuple[str, str, List[en.MToken]], None, None]:
 | 
			
		||||
        tks = []
 | 
			
		||||
        pcount = 0
 | 
			
		||||
        for t in tokens:
 | 
			
		||||
            # American English: ɾ => T
 | 
			
		||||
            t.phonemes = '' if t.phonemes is None else t.phonemes#.replace('ɾ', 'T')
 | 
			
		||||
            next_ps = t.phonemes + (' ' if t.whitespace else '')
 | 
			
		||||
            next_pcount = pcount + len(next_ps.rstrip())
 | 
			
		||||
            if next_pcount > 510:
 | 
			
		||||
                z = KPipeline.waterfall_last(tks, next_pcount)
 | 
			
		||||
                text = KPipeline.tokens_to_text(tks[:z])
 | 
			
		||||
                logger.debug(f"Chunking text at {z}: '{text[:30]}{'...' if len(text) > 30 else ''}'")
 | 
			
		||||
                ps = KPipeline.tokens_to_ps(tks[:z])
 | 
			
		||||
                yield text, ps, tks[:z]
 | 
			
		||||
                tks = tks[z:]
 | 
			
		||||
                pcount = len(KPipeline.tokens_to_ps(tks))
 | 
			
		||||
                if not tks:
 | 
			
		||||
                    next_ps = next_ps.lstrip()
 | 
			
		||||
            tks.append(t)
 | 
			
		||||
            pcount += len(next_ps)
 | 
			
		||||
        if tks:
 | 
			
		||||
            text = KPipeline.tokens_to_text(tks)
 | 
			
		||||
            ps = KPipeline.tokens_to_ps(tks)
 | 
			
		||||
            yield ''.join(text).strip(), ''.join(ps).strip(), tks
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    def infer(
 | 
			
		||||
        model: KModel,
 | 
			
		||||
        ps: str,
 | 
			
		||||
        pack: torch.FloatTensor,
 | 
			
		||||
        speed: Union[float, Callable[[int], float]] = 1
 | 
			
		||||
    ) -> KModel.Output:
 | 
			
		||||
        if callable(speed):
 | 
			
		||||
            speed = speed(len(ps))
 | 
			
		||||
        return model(ps, pack[len(ps)-1], speed, return_output=True)
 | 
			
		||||
 | 
			
		||||
    def generate_from_tokens(
 | 
			
		||||
        self,
 | 
			
		||||
        tokens: Union[str, List[en.MToken]],
 | 
			
		||||
        voice: str,
 | 
			
		||||
        speed: float = 1,
 | 
			
		||||
        model: Optional[KModel] = None
 | 
			
		||||
    ) -> Generator['KPipeline.Result', None, None]:
 | 
			
		||||
        """Generate audio from either raw phonemes or pre-processed tokens.
 | 
			
		||||
        
 | 
			
		||||
        Args:
 | 
			
		||||
            tokens: Either a phoneme string or list of pre-processed MTokens
 | 
			
		||||
            voice: The voice to use for synthesis
 | 
			
		||||
            speed: Speech speed modifier (default: 1)
 | 
			
		||||
            model: Optional KModel instance (uses pipeline's model if not provided)
 | 
			
		||||
        
 | 
			
		||||
        Yields:
 | 
			
		||||
            KPipeline.Result containing the input tokens and generated audio
 | 
			
		||||
            
 | 
			
		||||
        Raises:
 | 
			
		||||
            ValueError: If no voice is provided or token sequence exceeds model limits
 | 
			
		||||
        """
 | 
			
		||||
        model = model or self.model
 | 
			
		||||
        if model and voice is None:
 | 
			
		||||
            raise ValueError('Specify a voice: pipeline.generate_from_tokens(..., voice="af_heart")')
 | 
			
		||||
        
 | 
			
		||||
        pack = self.load_voice(voice).to(model.device) if model else None
 | 
			
		||||
 | 
			
		||||
        # Handle raw phoneme string
 | 
			
		||||
        if isinstance(tokens, str):
 | 
			
		||||
            logger.debug("Processing phonemes from raw string")
 | 
			
		||||
            if len(tokens) > 510:
 | 
			
		||||
                raise ValueError(f'Phoneme string too long: {len(tokens)} > 510')
 | 
			
		||||
            output = KPipeline.infer(model, tokens, pack, speed) if model else None
 | 
			
		||||
            yield self.Result(graphemes='', phonemes=tokens, output=output)
 | 
			
		||||
            return
 | 
			
		||||
        
 | 
			
		||||
        logger.debug("Processing MTokens")
 | 
			
		||||
        # Handle pre-processed tokens
 | 
			
		||||
        for gs, ps, tks in self.en_tokenize(tokens):
 | 
			
		||||
            if not ps:
 | 
			
		||||
                continue
 | 
			
		||||
            elif len(ps) > 510:
 | 
			
		||||
                logger.warning(f"Unexpected len(ps) == {len(ps)} > 510 and ps == '{ps}'")
 | 
			
		||||
                logger.warning("Truncating to 510 characters")
 | 
			
		||||
                ps = ps[:510]
 | 
			
		||||
            output = KPipeline.infer(model, ps, pack, speed) if model else None
 | 
			
		||||
            if output is not None and output.pred_dur is not None:
 | 
			
		||||
                KPipeline.join_timestamps(tks, output.pred_dur)
 | 
			
		||||
            yield self.Result(graphemes=gs, phonemes=ps, tokens=tks, output=output)
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    def join_timestamps(tokens: List[en.MToken], pred_dur: torch.LongTensor):
 | 
			
		||||
        # Multiply by 600 to go from pred_dur frames to sample_rate 24000
 | 
			
		||||
        # Equivalent to dividing pred_dur frames by 40 to get timestamp in seconds
 | 
			
		||||
        # We will count nice round half-frames, so the divisor is 80
 | 
			
		||||
        MAGIC_DIVISOR = 80
 | 
			
		||||
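        # Worked example (added for clarity): a token spanning 40 pred_dur frames
        # covers 40 * 600 = 24000 samples = 1.0 s of audio at 24 kHz. In the
        # half-frame bookkeeping below that is 80 half-frames, and 80 / 80 = 1.0 s.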
        if not tokens or len(pred_dur) < 3:
 | 
			
		||||
            # We expect at least 3: <bos>, token, <eos>
 | 
			
		||||
            return
 | 
			
		||||
        # We track 2 counts, measured in half-frames: (left, right)
 | 
			
		||||
        # This way we can cut space characters in half
 | 
			
		||||
        # TODO: Is -3 an appropriate offset?
 | 
			
		||||
        left = right = 2 * max(0, pred_dur[0].item() - 3)
 | 
			
		||||
        # Updates:
 | 
			
		||||
        # left = right + (2 * token_dur) + space_dur
 | 
			
		||||
        # right = left + space_dur
 | 
			
		||||
        i = 1
 | 
			
		||||
        for t in tokens:
 | 
			
		||||
            if i >= len(pred_dur)-1:
 | 
			
		||||
                break
 | 
			
		||||
            if not t.phonemes:
 | 
			
		||||
                if t.whitespace:
 | 
			
		||||
                    i += 1
 | 
			
		||||
                    left = right + pred_dur[i].item()
 | 
			
		||||
                    right = left + pred_dur[i].item()
 | 
			
		||||
                    i += 1
 | 
			
		||||
                continue
 | 
			
		||||
            j = i + len(t.phonemes)
 | 
			
		||||
            if j >= len(pred_dur):
 | 
			
		||||
                break
 | 
			
		||||
            t.start_ts = left / MAGIC_DIVISOR
 | 
			
		||||
            token_dur = pred_dur[i: j].sum().item()
 | 
			
		||||
            space_dur = pred_dur[j].item() if t.whitespace else 0
 | 
			
		||||
            left = right + (2 * token_dur) + space_dur
 | 
			
		||||
            t.end_ts = left / MAGIC_DIVISOR
 | 
			
		||||
            right = left + space_dur
 | 
			
		||||
            i = j + (1 if t.whitespace else 0)
 | 
			
		||||
 | 
			
		||||
    @dataclass
 | 
			
		||||
    class Result:
 | 
			
		||||
        graphemes: str
 | 
			
		||||
        phonemes: str
 | 
			
		||||
        tokens: Optional[List[en.MToken]] = None
 | 
			
		||||
        output: Optional[KModel.Output] = None
 | 
			
		||||
        text_index: Optional[int] = None
 | 
			
		||||
 | 
			
		||||
        @property
 | 
			
		||||
        def audio(self) -> Optional[torch.FloatTensor]:
 | 
			
		||||
            return None if self.output is None else self.output.audio
 | 
			
		||||
 | 
			
		||||
        @property
 | 
			
		||||
        def pred_dur(self) -> Optional[torch.LongTensor]:
 | 
			
		||||
            return None if self.output is None else self.output.pred_dur
 | 
			
		||||
 | 
			
		||||
        ### MARK: BEGIN BACKWARD COMPAT ###
 | 
			
		||||
        def __iter__(self):
 | 
			
		||||
            yield self.graphemes
 | 
			
		||||
            yield self.phonemes
 | 
			
		||||
            yield self.audio
 | 
			
		||||
 | 
			
		||||
        def __getitem__(self, index):
 | 
			
		||||
            return [self.graphemes, self.phonemes, self.audio][index]
 | 
			
		||||
 | 
			
		||||
        def __len__(self):
 | 
			
		||||
            return 3
 | 
			
		||||
        #### MARK: END BACKWARD COMPAT ####
 | 
			
		||||
 | 
			
		||||
    def __call__(
 | 
			
		||||
        self,
 | 
			
		||||
        text: Union[str, List[str]],
 | 
			
		||||
        voice: Optional[str] = None,
 | 
			
		||||
        speed: Union[float, Callable[[int], float]] = 1,
 | 
			
		||||
        split_pattern: Optional[str] = r'\n+',
 | 
			
		||||
        model: Optional[KModel] = None
 | 
			
		||||
    ) -> Generator['KPipeline.Result', None, None]:
 | 
			
		||||
        model = model or self.model
 | 
			
		||||
        if model and voice is None:
 | 
			
		||||
            raise ValueError('Specify a voice: en_us_pipeline(text="Hello world!", voice="af_heart")')
 | 
			
		||||
        pack = self.load_voice(voice).to(model.device) if model else None
 | 
			
		||||
        
 | 
			
		||||
        # Convert input to list of segments
 | 
			
		||||
        if isinstance(text, str):
 | 
			
		||||
            text = re.split(split_pattern, text.strip()) if split_pattern else [text]
 | 
			
		||||
            
 | 
			
		||||
        # Process each segment
 | 
			
		||||
        for graphemes_index, graphemes in enumerate(text):
 | 
			
		||||
            if not graphemes.strip():  # Skip empty segments
 | 
			
		||||
                continue
 | 
			
		||||
                
 | 
			
		||||
            # English processing (unchanged)
 | 
			
		||||
            if self.lang_code in 'ab':
 | 
			
		||||
                logger.debug(f"Processing English text: {graphemes[:50]}{'...' if len(graphemes) > 50 else ''}")
 | 
			
		||||
                _, tokens = self.g2p(graphemes)
 | 
			
		||||
                for gs, ps, tks in self.en_tokenize(tokens):
 | 
			
		||||
                    if not ps:
 | 
			
		||||
                        continue
 | 
			
		||||
                    elif len(ps) > 510:
 | 
			
		||||
                        logger.warning(f"Unexpected len(ps) == {len(ps)} > 510 and ps == '{ps}'")
 | 
			
		||||
                        ps = ps[:510]
 | 
			
		||||
                    output = KPipeline.infer(model, ps, pack, speed) if model else None
 | 
			
		||||
                    if output is not None and output.pred_dur is not None:
 | 
			
		||||
                        KPipeline.join_timestamps(tks, output.pred_dur)
 | 
			
		||||
                    yield self.Result(graphemes=gs, phonemes=ps, tokens=tks, output=output, text_index=graphemes_index)
 | 
			
		||||
            
 | 
			
		||||
            # Non-English processing with chunking
 | 
			
		||||
            else:
 | 
			
		||||
                # Split long text into smaller chunks (roughly 400 characters each)
 | 
			
		||||
                # Using sentence boundaries when possible
 | 
			
		||||
                chunk_size = 400
 | 
			
		||||
                chunks = []
 | 
			
		||||
                
 | 
			
		||||
                # Try to split on sentence boundaries first
 | 
			
		||||
                sentences = re.split(r'([.!?]+)', graphemes)
 | 
			
		||||
                current_chunk = ""
 | 
			
		||||
                
 | 
			
		||||
                for i in range(0, len(sentences), 2):
 | 
			
		||||
                    sentence = sentences[i]
 | 
			
		||||
                    # Add the punctuation back if it exists
 | 
			
		||||
                    if i + 1 < len(sentences):
 | 
			
		||||
                        sentence += sentences[i + 1]
 | 
			
		||||
                        
 | 
			
		||||
                    if len(current_chunk) + len(sentence) <= chunk_size:
 | 
			
		||||
                        current_chunk += sentence
 | 
			
		||||
                    else:
 | 
			
		||||
                        if current_chunk:
 | 
			
		||||
                            chunks.append(current_chunk.strip())
 | 
			
		||||
                        current_chunk = sentence
 | 
			
		||||
                
 | 
			
		||||
                if current_chunk:
 | 
			
		||||
                    chunks.append(current_chunk.strip())
 | 
			
		||||
                
 | 
			
		||||
                # If no chunks were created (no sentence boundaries), fall back to character-based chunking
 | 
			
		||||
                if not chunks:
 | 
			
		||||
                    chunks = [graphemes[i:i+chunk_size] for i in range(0, len(graphemes), chunk_size)]
 | 
			
		||||
                
 | 
			
		||||
                # Process each chunk
 | 
			
		||||
                for chunk in chunks:
 | 
			
		||||
                    if not chunk.strip():
 | 
			
		||||
                        continue
 | 
			
		||||
                        
 | 
			
		||||
                    ps, _ = self.g2p(chunk)
 | 
			
		||||
                    if not ps:
 | 
			
		||||
                        continue
 | 
			
		||||
                    elif len(ps) > 510:
 | 
			
		||||
                        logger.warning(f'Truncating len(ps) == {len(ps)} > 510')
 | 
			
		||||
                        ps = ps[:510]
 | 
			
		||||
                        
 | 
			
		||||
                    output = KPipeline.infer(model, ps, pack, speed) if model else None
 | 
			
		||||
                    yield self.Result(graphemes=chunk, phonemes=ps, output=output, text_index=graphemes_index)
 | 
			
		||||
319  wan/multitalk/multitalk.py  Normal file
@ -0,0 +1,319 @@
 | 
			
		||||
import random
 | 
			
		||||
import os
 | 
			
		||||
import torch
 | 
			
		||||
import torch.distributed as dist
 | 
			
		||||
from PIL import Image
 | 
			
		||||
import subprocess
 | 
			
		||||
import torchvision.transforms as transforms
 | 
			
		||||
import torch.nn.functional as F
 | 
			
		||||
import torch.nn as nn
 | 
			
		||||
import wan
 | 
			
		||||
from wan.configs import SIZE_CONFIGS, SUPPORTED_SIZES, WAN_CONFIGS
 | 
			
		||||
from wan.utils.utils import cache_image, cache_video, str2bool
 | 
			
		||||
# from wan.utils.multitalk_utils import save_video_ffmpeg
 | 
			
		||||
# from .kokoro import KPipeline
 | 
			
		||||
from transformers import Wav2Vec2FeatureExtractor
 | 
			
		||||
from .wav2vec2 import Wav2Vec2Model
 | 
			
		||||
 | 
			
		||||
import librosa
 | 
			
		||||
import pyloudnorm as pyln
 | 
			
		||||
import numpy as np
 | 
			
		||||
from einops import rearrange
 | 
			
		||||
import soundfile as sf
 | 
			
		||||
import re
 | 
			
		||||
import math
 | 
			
		||||
 | 
			
		||||
def custom_init(device, wav2vec):    
 | 
			
		||||
    audio_encoder = Wav2Vec2Model.from_pretrained(wav2vec, local_files_only=True).to(device)
 | 
			
		||||
    audio_encoder.feature_extractor._freeze_parameters()
 | 
			
		||||
    wav2vec_feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(wav2vec, local_files_only=True)
 | 
			
		||||
    return wav2vec_feature_extractor, audio_encoder
 | 
			
		||||
 | 
			
		||||
def loudness_norm(audio_array, sr=16000, lufs=-23):
 | 
			
		||||
    meter = pyln.Meter(sr)
 | 
			
		||||
    loudness = meter.integrated_loudness(audio_array)
 | 
			
		||||
    if abs(loudness) > 100:
 | 
			
		||||
        return audio_array
 | 
			
		||||
    normalized_audio = pyln.normalize.loudness(audio_array, loudness, lufs)
 | 
			
		||||
    return normalized_audio
 | 
			
		||||
 | 
			
		||||
 
 | 
			
		||||
def get_embedding(speech_array, wav2vec_feature_extractor, audio_encoder, sr=16000, device='cpu', fps = 25):
 | 
			
		||||
    audio_duration = len(speech_array) / sr
 | 
			
		||||
    video_length = audio_duration * fps
 | 
			
		||||
 | 
			
		||||
    # wav2vec_feature_extractor
 | 
			
		||||
    audio_feature = np.squeeze(
 | 
			
		||||
        wav2vec_feature_extractor(speech_array, sampling_rate=sr).input_values
 | 
			
		||||
    )
 | 
			
		||||
    audio_feature = torch.from_numpy(audio_feature).float().to(device=device)
 | 
			
		||||
    audio_feature = audio_feature.unsqueeze(0)
 | 
			
		||||
 | 
			
		||||
    # audio encoder
 | 
			
		||||
    with torch.no_grad():
 | 
			
		||||
        embeddings = audio_encoder(audio_feature, seq_len=int(video_length), output_hidden_states=True)
 | 
			
		||||
 | 
			
		||||
    if len(embeddings) == 0:
 | 
			
		||||
        print("Fail to extract audio embedding")
 | 
			
		||||
        return None
 | 
			
		||||
 | 
			
		||||
    audio_emb = torch.stack(embeddings.hidden_states[1:], dim=1).squeeze(0)
 | 
			
		||||
    audio_emb = rearrange(audio_emb, "b s d -> s b d")
 | 
			
		||||
 | 
			
		||||
    audio_emb = audio_emb.cpu().detach()
 | 
			
		||||
    return audio_emb
 | 
			
		||||
	
 | 
			
		||||
def audio_prepare_single(audio_path, sample_rate=16000, duration = 0):
 | 
			
		||||
    ext = os.path.splitext(audio_path)[1].lower()
 | 
			
		||||
    if ext in ['.mp4', '.mov', '.avi', '.mkv']:
 | 
			
		||||
        human_speech_array = extract_audio_from_video(audio_path, sample_rate)
 | 
			
		||||
        return human_speech_array
 | 
			
		||||
    else:
 | 
			
		||||
        human_speech_array, sr = librosa.load(audio_path, duration=duration, sr=sample_rate)
 | 
			
		||||
        human_speech_array = loudness_norm(human_speech_array, sr)
 | 
			
		||||
        return human_speech_array
 | 
			
		||||
 | 
			
		||||
 
 | 
			
		||||
def audio_prepare_multi(left_path, right_path, audio_type="add", sample_rate=16000, duration=0):
    # Load both speakers; if one path is missing, substitute silence of matching length.
    if left_path is not None and right_path is not None:
        human_speech_array1 = audio_prepare_single(left_path, duration=duration)
        human_speech_array2 = audio_prepare_single(right_path, duration=duration)
    elif left_path is None:
        human_speech_array2 = audio_prepare_single(right_path, duration=duration)
        human_speech_array1 = np.zeros(human_speech_array2.shape[0])
    else:  # right_path is None
        human_speech_array1 = audio_prepare_single(left_path, duration=duration)
        human_speech_array2 = np.zeros(human_speech_array1.shape[0])

    if audio_type == 'para':
        # both speakers talk in parallel
        new_human_speech1 = human_speech_array1
        new_human_speech2 = human_speech_array2
    elif audio_type == 'add':
        # speakers talk one after the other: pad each track with silence for the other's turn
        new_human_speech1 = np.concatenate([human_speech_array1, np.zeros(human_speech_array2.shape[0])])
        new_human_speech2 = np.concatenate([np.zeros(human_speech_array1.shape[0]), human_speech_array2])
    else:
        raise ValueError(f"Unsupported audio_type: {audio_type!r} (expected 'para' or 'add')")
    sum_human_speechs = new_human_speech1 + new_human_speech2
    return new_human_speech1, new_human_speech2, sum_human_speechs
 | 
			
		||||
 | 
			
		||||
def process_tts_single(text, save_dir, voice1):    
 | 
			
		||||
    s1_sentences = []
 | 
			
		||||
 | 
			
		||||
    pipeline = KPipeline(lang_code='a', repo_id='weights/Kokoro-82M')
 | 
			
		||||
 | 
			
		||||
    voice_tensor = torch.load(voice1, weights_only=True)
 | 
			
		||||
    generator = pipeline(
 | 
			
		||||
        text, voice=voice_tensor, # <= change voice here
 | 
			
		||||
        speed=1, split_pattern=r'\n+'
 | 
			
		||||
    )
 | 
			
		||||
    audios = []
 | 
			
		||||
    for i, (gs, ps, audio) in enumerate(generator):
 | 
			
		||||
        audios.append(audio)
 | 
			
		||||
    audios = torch.concat(audios, dim=0)
 | 
			
		||||
    s1_sentences.append(audios)
 | 
			
		||||
    s1_sentences = torch.concat(s1_sentences, dim=0)
 | 
			
		||||
    save_path1 =f'{save_dir}/s1.wav'
 | 
			
		||||
    sf.write(save_path1, s1_sentences, 24000) # save each audio file
 | 
			
		||||
    s1, _ = librosa.load(save_path1, sr=16000)
 | 
			
		||||
    return s1, save_path1
 | 
			
		||||
    
 | 
			
		||||
   
 | 
			
		||||
 | 
			
		||||
def process_tts_multi(text, save_dir, voice1, voice2):
 | 
			
		||||
    pattern = r'\(s(\d+)\)\s*(.*?)(?=\s*\(s\d+\)|$)'
 | 
			
		||||
    matches = re.findall(pattern, text, re.DOTALL)
 | 
			
		||||
    
 | 
			
		||||
    s1_sentences = []
 | 
			
		||||
    s2_sentences = []
 | 
			
		||||
 | 
			
		||||
    pipeline = KPipeline(lang_code='a', repo_id='weights/Kokoro-82M')
 | 
			
		||||
    for idx, (speaker, content) in enumerate(matches):
 | 
			
		||||
        if speaker == '1':
 | 
			
		||||
            voice_tensor = torch.load(voice1, weights_only=True)
 | 
			
		||||
            generator = pipeline(
 | 
			
		||||
                content, voice=voice_tensor, # <= change voice here
 | 
			
		||||
                speed=1, split_pattern=r'\n+'
 | 
			
		||||
            )
 | 
			
		||||
            audios = []
 | 
			
		||||
            for i, (gs, ps, audio) in enumerate(generator):
 | 
			
		||||
                audios.append(audio)
 | 
			
		||||
            audios = torch.concat(audios, dim=0)
 | 
			
		||||
            s1_sentences.append(audios)
 | 
			
		||||
            s2_sentences.append(torch.zeros_like(audios))
 | 
			
		||||
        elif speaker == '2':
 | 
			
		||||
            voice_tensor = torch.load(voice2, weights_only=True)
 | 
			
		||||
            generator = pipeline(
 | 
			
		||||
                content, voice=voice_tensor, # <= change voice here
 | 
			
		||||
                speed=1, split_pattern=r'\n+'
 | 
			
		||||
            )
 | 
			
		||||
            audios = []
 | 
			
		||||
            for i, (gs, ps, audio) in enumerate(generator):
 | 
			
		||||
                audios.append(audio)
 | 
			
		||||
            audios = torch.concat(audios, dim=0)
 | 
			
		||||
            s2_sentences.append(audios)
 | 
			
		||||
            s1_sentences.append(torch.zeros_like(audios))
 | 
			
		||||
    
 | 
			
		||||
    s1_sentences = torch.concat(s1_sentences, dim=0)
 | 
			
		||||
    s2_sentences = torch.concat(s2_sentences, dim=0)
 | 
			
		||||
    sum_sentences = s1_sentences + s2_sentences
 | 
			
		||||
    save_path1 =f'{save_dir}/s1.wav'
 | 
			
		||||
    save_path2 =f'{save_dir}/s2.wav'
 | 
			
		||||
    save_path_sum = f'{save_dir}/sum.wav'
 | 
			
		||||
    sf.write(save_path1, s1_sentences, 24000) # save each audio file
 | 
			
		||||
    sf.write(save_path2, s2_sentences, 24000)
 | 
			
		||||
    sf.write(save_path_sum, sum_sentences, 24000)
 | 
			
		||||
 | 
			
		||||
    s1, _ = librosa.load(save_path1, sr=16000)
 | 
			
		||||
    s2, _ = librosa.load(save_path2, sr=16000)
 | 
			
		||||
    # sum, _ = librosa.load(save_path_sum, sr=16000)
 | 
			
		||||
    return s1, s2, save_path_sum
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def get_full_audio_embeddings(audio_guide1 = None, audio_guide2 = None, combination_type ="add", num_frames =  0, fps = 25, sr = 16000):
 | 
			
		||||
    wav2vec_feature_extractor, audio_encoder= custom_init('cpu', "ckpts/chinese-wav2vec2-base")
 | 
			
		||||
    # wav2vec_feature_extractor, audio_encoder= custom_init('cpu', "ckpts/wav2vec")
 | 
			
		||||
 | 
			
		||||
    new_human_speech1, new_human_speech2, sum_human_speechs = audio_prepare_multi(audio_guide1, audio_guide2, combination_type, duration= num_frames / fps)
 | 
			
		||||
    audio_embedding_1 = get_embedding(new_human_speech1, wav2vec_feature_extractor, audio_encoder, sr=sr, fps= fps)
 | 
			
		||||
    audio_embedding_2 = get_embedding(new_human_speech2, wav2vec_feature_extractor, audio_encoder, sr=sr, fps= fps)
 | 
			
		||||
 | 
			
		||||
    full_audio_embs = []
    if audio_guide1 is not None:
        full_audio_embs.append(audio_embedding_1)
    if audio_guide2 is not None:
        full_audio_embs.append(audio_embedding_2)
    else:
        sum_human_speechs = None
    return full_audio_embs, sum_human_speechs
 | 
			
		||||
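# Illustrative usage sketch of the embedding path above (hedged; the audio path
# is hypothetical, and the wav2vec checkpoint directory is the one hard-coded in
# get_full_audio_embeddings):
#
#   full_audio_embs, sum_speech = get_full_audio_embeddings(
#       audio_guide1="speaker1.wav",   # hypothetical path
#       audio_guide2=None,             # single-speaker case
#       combination_type="add",
#       num_frames=81,
#       fps=25,
#   )
#   # Each entry of full_audio_embs is roughly [num_video_frames, num_wav2vec_layers, 768];
#   # sum_speech is None here because there is no second speaker to mix in.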
 | 
			
		||||
 | 
			
		||||
def get_window_audio_embeddings(full_audio_embs, audio_start_idx=0, clip_length = 81, vae_scale = 4, audio_window = 5):
 | 
			
		||||
    HUMAN_NUMBER = len(full_audio_embs)
 | 
			
		||||
    audio_end_idx = audio_start_idx + clip_length
 | 
			
		||||
    indices = torch.arange(audio_window) - audio_window // 2  # e.g. [-2, -1, 0, 1, 2] for the default window of 5
 | 
			
		||||
 | 
			
		||||
    audio_embs = []
 | 
			
		||||
    # split audio with window size
 | 
			
		||||
    for human_idx in range(HUMAN_NUMBER):   
 | 
			
		||||
        center_indices = torch.arange(
 | 
			
		||||
            audio_start_idx,
 | 
			
		||||
            audio_end_idx,
 | 
			
		||||
            1
 | 
			
		||||
        ).unsqueeze(
 | 
			
		||||
            1
 | 
			
		||||
        ) + indices.unsqueeze(0)
 | 
			
		||||
        center_indices = torch.clamp(center_indices, min=0, max=full_audio_embs[human_idx].shape[0]-1).to(full_audio_embs[human_idx].device)
 | 
			
		||||
        audio_emb = full_audio_embs[human_idx][center_indices][None,...] #.to(self.device)
 | 
			
		||||
        audio_embs.append(audio_emb)
 | 
			
		||||
    audio_embs = torch.concat(audio_embs, dim=0) #.to(self.param_dtype)
 | 
			
		||||
 | 
			
		||||
    # audio_cond = audio.to(device=x.device, dtype=x.dtype)
 | 
			
		||||
    audio_cond = audio_embs
 | 
			
		||||
    first_frame_audio_emb_s = audio_cond[:, :1, ...] 
 | 
			
		||||
    latter_frame_audio_emb = audio_cond[:, 1:, ...] 
 | 
			
		||||
    latter_frame_audio_emb = rearrange(latter_frame_audio_emb, "b (n_t n) w s c -> b n_t n w s c", n=vae_scale) 
 | 
			
		||||
    middle_index = audio_window // 2
 | 
			
		||||
    latter_first_frame_audio_emb = latter_frame_audio_emb[:, :, :1, :middle_index+1, ...] 
 | 
			
		||||
    latter_first_frame_audio_emb = rearrange(latter_first_frame_audio_emb, "b n_t n w s c -> b n_t (n w) s c") 
 | 
			
		||||
    latter_last_frame_audio_emb = latter_frame_audio_emb[:, :, -1:, middle_index:, ...] 
 | 
			
		||||
    latter_last_frame_audio_emb = rearrange(latter_last_frame_audio_emb, "b n_t n w s c -> b n_t (n w) s c") 
 | 
			
		||||
    latter_middle_frame_audio_emb = latter_frame_audio_emb[:, :, 1:-1, middle_index:middle_index+1, ...] 
 | 
			
		||||
    latter_middle_frame_audio_emb = rearrange(latter_middle_frame_audio_emb, "b n_t n w s c -> b n_t (n w) s c") 
 | 
			
		||||
    latter_frame_audio_emb_s = torch.concat([latter_first_frame_audio_emb, latter_middle_frame_audio_emb, latter_last_frame_audio_emb], dim=2) 
 | 
			
		||||
 
 | 
			
		||||
    return [first_frame_audio_emb_s, latter_frame_audio_emb_s]
 | 
			
		||||
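# Worked numbers for the windowing above (illustrative only), assuming the
# defaults clip_length=81, vae_scale=4, audio_window=5:
#   * every video frame gets a 5-frame audio window centred on it;
#   * frame 0 feeds the first latent frame, and the remaining 80 frames are
#     regrouped 4-per-latent-frame into 20 "latter" latent frames;
#   * per latter latent frame the concatenation keeps 3 + 2 + 3 = 8 window
#     slots, i.e. audio_window + vae_scale - 1, matching the seq_len_vf used by
#     AudioProjModel in multitalk_model.py.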
 | 
			
		||||
def resize_and_centercrop(cond_image, target_size):
 | 
			
		||||
        """
 | 
			
		||||
        Resize image or tensor to the target size without padding.
 | 
			
		||||
        """
 | 
			
		||||
 | 
			
		||||
        # Get the original size
 | 
			
		||||
        if isinstance(cond_image, torch.Tensor):
 | 
			
		||||
            _, orig_h, orig_w = cond_image.shape
 | 
			
		||||
        else:
 | 
			
		||||
            orig_h, orig_w = cond_image.height, cond_image.width
 | 
			
		||||
 | 
			
		||||
        target_h, target_w = target_size
 | 
			
		||||
        
 | 
			
		||||
        # Calculate the scaling factor for resizing
 | 
			
		||||
        scale_h = target_h / orig_h
 | 
			
		||||
        scale_w = target_w / orig_w
 | 
			
		||||
        
 | 
			
		||||
        # Compute the final size
 | 
			
		||||
        scale = max(scale_h, scale_w)
 | 
			
		||||
        final_h = math.ceil(scale * orig_h)
 | 
			
		||||
        final_w = math.ceil(scale * orig_w)
 | 
			
		||||
        
 | 
			
		||||
        # Resize
 | 
			
		||||
        if isinstance(cond_image, torch.Tensor):
 | 
			
		||||
            if len(cond_image.shape) == 3:
 | 
			
		||||
                cond_image = cond_image[None]
 | 
			
		||||
            resized_tensor = nn.functional.interpolate(cond_image, size=(final_h, final_w), mode='nearest').contiguous() 
 | 
			
		||||
            # crop
 | 
			
		||||
            cropped_tensor = transforms.functional.center_crop(resized_tensor, target_size) 
 | 
			
		||||
            cropped_tensor = cropped_tensor.squeeze(0)
 | 
			
		||||
        else:
 | 
			
		||||
            resized_image = cond_image.resize((final_w, final_h), resample=Image.BILINEAR)
 | 
			
		||||
            resized_image = np.array(resized_image)
 | 
			
		||||
            # tensor and crop
 | 
			
		||||
            resized_tensor = torch.from_numpy(resized_image)[None, ...].permute(0, 3, 1, 2).contiguous()
 | 
			
		||||
            cropped_tensor = transforms.functional.center_crop(resized_tensor, target_size)
 | 
			
		||||
            cropped_tensor = cropped_tensor[:, :, None, :, :] 
 | 
			
		||||
 | 
			
		||||
        return cropped_tensor
 | 
			
		||||
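# Illustrative usage sketch (hedged; the image path is hypothetical). Note the
# asymmetric return shapes of resize_and_centercrop, which callers must handle:
#   * torch.Tensor input [C, H, W]  ->  [C, target_h, target_w]
#   * PIL.Image input               ->  [1, C, 1, target_h, target_w]
#
#   img = Image.open("cond.png")                   # hypothetical path
#   cond = resize_and_centercrop(img, (480, 832))  # -> [1, 3, 1, 480, 832] for RGB input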
 | 
			
		||||
 | 
			
		||||
def timestep_transform(
 | 
			
		||||
    t,
 | 
			
		||||
    shift=5.0,
 | 
			
		||||
    num_timesteps=1000,
 | 
			
		||||
):
 | 
			
		||||
    t = t / num_timesteps
 | 
			
		||||
    # shift the timestep based on ratio
 | 
			
		||||
    new_t = shift * t / (1 + (shift - 1) * t)
 | 
			
		||||
    new_t = new_t * num_timesteps
 | 
			
		||||
    return new_t
 | 
			
		||||
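# Worked example (illustrative only): with the default shift=5.0 and
# num_timesteps=1000, timestep_transform maps
#   t = 500  ->  0.5                              (normalize)
#            ->  5 * 0.5 / (1 + 4 * 0.5) ~= 0.833
#            ->  ~833                             (rescale back)
# so mid-range timesteps are pushed toward the high-noise end of the schedule.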
 | 
			
		||||
 | 
			
		||||
# construct human mask
 | 
			
		||||
def get_target_masks(HUMAN_NUMBER, lat_h, lat_w, src_h, src_w, face_scale = 0.05, bbox = None):
 | 
			
		||||
    human_masks = []
 | 
			
		||||
    if HUMAN_NUMBER==1:
 | 
			
		||||
        background_mask = torch.ones([src_h, src_w])
 | 
			
		||||
        human_mask1 = torch.ones([src_h, src_w])
 | 
			
		||||
        human_mask2 = torch.ones([src_h, src_w])
 | 
			
		||||
        human_masks = [human_mask1, human_mask2, background_mask]
 | 
			
		||||
    elif HUMAN_NUMBER==2:
 | 
			
		||||
        if bbox != None:
 | 
			
		||||
            assert len(bbox) == HUMAN_NUMBER, "The number of target bboxes must match the number of conditioning audio tracks"
 | 
			
		||||
            background_mask = torch.zeros([src_h, src_w])
 | 
			
		||||
            for _, person_bbox in bbox.items():
 | 
			
		||||
                x_min, y_min, x_max, y_max = person_bbox
 | 
			
		||||
                human_mask = torch.zeros([src_h, src_w])
 | 
			
		||||
                human_mask[int(x_min):int(x_max), int(y_min):int(y_max)] = 1
 | 
			
		||||
                background_mask += human_mask
 | 
			
		||||
                human_masks.append(human_mask)
 | 
			
		||||
        else:
 | 
			
		||||
            x_min, x_max = int(src_h * face_scale), int(src_h * (1 - face_scale))
 | 
			
		||||
            background_mask = torch.zeros([src_h, src_w])
 | 
			
		||||
            human_mask1 = torch.zeros([src_h, src_w])
 | 
			
		||||
            human_mask2 = torch.zeros([src_h, src_w])
 | 
			
		||||
            lefty_min, lefty_max = int((src_w//2) * face_scale), int((src_w//2) * (1 - face_scale))
 | 
			
		||||
            righty_min, righty_max = int((src_w//2) * face_scale + (src_w//2)), int((src_w//2) * (1 - face_scale) + (src_w//2))
 | 
			
		||||
            human_mask1[x_min:x_max, lefty_min:lefty_max] = 1
 | 
			
		||||
            human_mask2[x_min:x_max, righty_min:righty_max] = 1
 | 
			
		||||
            background_mask += human_mask1
 | 
			
		||||
            background_mask += human_mask2
 | 
			
		||||
            human_masks = [human_mask1, human_mask2]
 | 
			
		||||
        background_mask = torch.where(background_mask > 0, torch.tensor(0), torch.tensor(1))
 | 
			
		||||
        human_masks.append(background_mask)
 | 
			
		||||
    
 | 
			
		||||
    ref_target_masks = torch.stack(human_masks, dim=0) #.to(self.device)
 | 
			
		||||
    # resize and centercrop for ref_target_masks 
 | 
			
		||||
    # ref_target_masks = resize_and_centercrop(ref_target_masks, (target_h, target_w))
 | 
			
		||||
    N_h, N_w = lat_h // 2, lat_w // 2
 | 
			
		||||
    token_ref_target_masks = F.interpolate(ref_target_masks.unsqueeze(0), size=(N_h, N_w), mode='nearest').squeeze() 
 | 
			
		||||
    token_ref_target_masks = (token_ref_target_masks > 0) 
 | 
			
		||||
    token_ref_target_masks = token_ref_target_masks.float() #.to(self.device)
 | 
			
		||||
 | 
			
		||||
    token_ref_target_masks = token_ref_target_masks.view(token_ref_target_masks.shape[0], -1) 
 | 
			
		||||
 | 
			
		||||
    return token_ref_target_masks
 | 
			
		||||
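# Illustrative shape example (hedged; the sizes below assume a 480x832 output
# with a stride-8 VAE, so lat_h=60 and lat_w=104):
#
#   masks = get_target_masks(2, lat_h=60, lat_w=104, src_h=480, src_w=832)
#   # masks has shape [3, (60 // 2) * (104 // 2)] = [3, 1560]:
#   # one flattened token mask per speaker plus one for the background.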
							
								
								
									
799	wan/multitalk/multitalk_model.py	Normal file
@ -0,0 +1,799 @@
 | 
			
		||||
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
 | 
			
		||||
import math
 | 
			
		||||
import numpy as np
 | 
			
		||||
import os
 | 
			
		||||
import torch
 | 
			
		||||
import torch.cuda.amp as amp
 | 
			
		||||
import torch.nn as nn
 | 
			
		||||
import torch.nn.functional as F
 | 
			
		||||
 | 
			
		||||
from einops import rearrange
 | 
			
		||||
from diffusers import ModelMixin
 | 
			
		||||
from diffusers.configuration_utils import ConfigMixin, register_to_config
 | 
			
		||||
 | 
			
		||||
from .attention import flash_attention, SingleStreamMutiAttention
 | 
			
		||||
from ..utils.multitalk_utils import get_attn_map_with_target
 | 
			
		||||
 | 
			
		||||
__all__ = ['WanModel']
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def sinusoidal_embedding_1d(dim, position):
 | 
			
		||||
    # preprocess
 | 
			
		||||
    assert dim % 2 == 0
 | 
			
		||||
    half = dim // 2
 | 
			
		||||
    position = position.type(torch.float64)
 | 
			
		||||
 | 
			
		||||
    # calculation
 | 
			
		||||
    sinusoid = torch.outer(
 | 
			
		||||
        position, torch.pow(10000, -torch.arange(half).to(position).div(half)))
 | 
			
		||||
    x = torch.cat([torch.cos(sinusoid), torch.sin(sinusoid)], dim=1)
 | 
			
		||||
    return x
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@amp.autocast(enabled=False)
 | 
			
		||||
def rope_params(max_seq_len, dim, theta=10000):
 | 
			
		||||
 | 
			
		||||
    assert dim % 2 == 0
 | 
			
		||||
    freqs = torch.outer(
 | 
			
		||||
        torch.arange(max_seq_len),
 | 
			
		||||
        1.0 / torch.pow(theta,
 | 
			
		||||
                        torch.arange(0, dim, 2).to(torch.float64).div(dim)))
 | 
			
		||||
    freqs = torch.polar(torch.ones_like(freqs), freqs)
 | 
			
		||||
    return freqs
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@amp.autocast(enabled=False)
 | 
			
		||||
def rope_apply(x, grid_sizes, freqs):
 | 
			
		||||
    s, n, c = x.size(1), x.size(2), x.size(3) // 2
 | 
			
		||||
 | 
			
		||||
    freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)
 | 
			
		||||
 | 
			
		||||
    output = []
 | 
			
		||||
    for i, (f, h, w) in enumerate(grid_sizes.tolist()):
 | 
			
		||||
        seq_len = f * h * w
 | 
			
		||||
 | 
			
		||||
        x_i = torch.view_as_complex(x[i, :s].to(torch.float64).reshape(
 | 
			
		||||
            s, n, -1, 2))
 | 
			
		||||
        freqs_i = torch.cat([
 | 
			
		||||
            freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
 | 
			
		||||
            freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
 | 
			
		||||
            freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1)
 | 
			
		||||
        ],
 | 
			
		||||
                            dim=-1).reshape(seq_len, 1, -1)
 | 
			
		||||
        freqs_i = freqs_i.to(device=x_i.device)
 | 
			
		||||
        x_i = torch.view_as_real(x_i * freqs_i).flatten(2)
 | 
			
		||||
        x_i = torch.cat([x_i, x[i, seq_len:]])
 | 
			
		||||
 | 
			
		||||
        output.append(x_i)
 | 
			
		||||
    return torch.stack(output).float()
 | 
			
		||||
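# Illustrative arithmetic (assumes the default config dim=2048, num_heads=16):
# the head dimension is 128, so c = 64 complex frequency pairs per head and the
# split above, [c - 2 * (c // 3), c // 3, c // 3] = [22, 21, 21], assigns those
# pairs to the temporal, height and width axes. This matches how self.freqs is
# assembled in WanModel.__init__ below from rope_params(1024, d - 4 * (d // 6))
# and two rope_params(1024, 2 * (d // 6)) calls.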
 | 
			
		||||
 | 
			
		||||
class WanRMSNorm(nn.Module):
 | 
			
		||||
 | 
			
		||||
    def __init__(self, dim, eps=1e-5):
 | 
			
		||||
        super().__init__()
 | 
			
		||||
        self.dim = dim
 | 
			
		||||
        self.eps = eps
 | 
			
		||||
        self.weight = nn.Parameter(torch.ones(dim))
 | 
			
		||||
 | 
			
		||||
    def forward(self, x):
 | 
			
		||||
        r"""
 | 
			
		||||
        Args:
 | 
			
		||||
            x(Tensor): Shape [B, L, C]
 | 
			
		||||
        """
 | 
			
		||||
        return self._norm(x.float()).type_as(x) * self.weight
 | 
			
		||||
 | 
			
		||||
    def _norm(self, x):
 | 
			
		||||
        return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class WanLayerNorm(nn.LayerNorm):
 | 
			
		||||
 | 
			
		||||
    def __init__(self, dim, eps=1e-6, elementwise_affine=False):
 | 
			
		||||
        super().__init__(dim, elementwise_affine=elementwise_affine, eps=eps)
 | 
			
		||||
 | 
			
		||||
    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
 | 
			
		||||
        origin_dtype = inputs.dtype
 | 
			
		||||
        out = F.layer_norm(
 | 
			
		||||
            inputs.float(), 
 | 
			
		||||
            self.normalized_shape, 
 | 
			
		||||
            None if self.weight is None else self.weight.float(), 
 | 
			
		||||
            None if self.bias is None else self.bias.float() ,
 | 
			
		||||
            self.eps
 | 
			
		||||
        ).to(origin_dtype)
 | 
			
		||||
        return out
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class WanSelfAttention(nn.Module):
 | 
			
		||||
 | 
			
		||||
    def __init__(self,
 | 
			
		||||
                 dim,
 | 
			
		||||
                 num_heads,
 | 
			
		||||
                 window_size=(-1, -1),
 | 
			
		||||
                 qk_norm=True,
 | 
			
		||||
                 eps=1e-6):
 | 
			
		||||
        assert dim % num_heads == 0
 | 
			
		||||
        super().__init__()
 | 
			
		||||
        self.dim = dim
 | 
			
		||||
        self.num_heads = num_heads
 | 
			
		||||
        self.head_dim = dim // num_heads
 | 
			
		||||
        self.window_size = window_size
 | 
			
		||||
        self.qk_norm = qk_norm
 | 
			
		||||
        self.eps = eps
 | 
			
		||||
 | 
			
		||||
        # layers
 | 
			
		||||
        self.q = nn.Linear(dim, dim)
 | 
			
		||||
        self.k = nn.Linear(dim, dim)
 | 
			
		||||
        self.v = nn.Linear(dim, dim)
 | 
			
		||||
        self.o = nn.Linear(dim, dim)
 | 
			
		||||
        self.norm_q = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity()
 | 
			
		||||
        self.norm_k = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity()
 | 
			
		||||
 | 
			
		||||
    def forward(self, x, seq_lens, grid_sizes, freqs, ref_target_masks=None):
 | 
			
		||||
        b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim
 | 
			
		||||
 | 
			
		||||
        # query, key, value function
 | 
			
		||||
        def qkv_fn(x):
 | 
			
		||||
            q = self.norm_q(self.q(x)).view(b, s, n, d)
 | 
			
		||||
            k = self.norm_k(self.k(x)).view(b, s, n, d)
 | 
			
		||||
            v = self.v(x).view(b, s, n, d)
 | 
			
		||||
            return q, k, v
 | 
			
		||||
        q, k, v = qkv_fn(x)
 | 
			
		||||
 | 
			
		||||
        q = rope_apply(q, grid_sizes, freqs)
 | 
			
		||||
        k = rope_apply(k, grid_sizes, freqs)
 | 
			
		||||
 | 
			
		||||
        
 | 
			
		||||
        x = flash_attention(
 | 
			
		||||
            q=q,
 | 
			
		||||
            k=k,
 | 
			
		||||
            v=v,
 | 
			
		||||
            k_lens=seq_lens,
 | 
			
		||||
            window_size=self.window_size
 | 
			
		||||
        ).type_as(x)
 | 
			
		||||
 | 
			
		||||
        # output
 | 
			
		||||
        x = x.flatten(2)
 | 
			
		||||
        x = self.o(x)
 | 
			
		||||
        with torch.no_grad():
 | 
			
		||||
            x_ref_attn_map = get_attn_map_with_target(q.type_as(x), k.type_as(x), grid_sizes[0], 
 | 
			
		||||
                                                    ref_target_masks=ref_target_masks)
 | 
			
		||||
 | 
			
		||||
        return x, x_ref_attn_map
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class WanI2VCrossAttention(WanSelfAttention):
 | 
			
		||||
 | 
			
		||||
    def __init__(self,
 | 
			
		||||
                 dim,
 | 
			
		||||
                 num_heads,
 | 
			
		||||
                 window_size=(-1, -1),
 | 
			
		||||
                 qk_norm=True,
 | 
			
		||||
                 eps=1e-6):
 | 
			
		||||
        super().__init__(dim, num_heads, window_size, qk_norm, eps)
 | 
			
		||||
 | 
			
		||||
        self.k_img = nn.Linear(dim, dim)
 | 
			
		||||
        self.v_img = nn.Linear(dim, dim)
 | 
			
		||||
        self.norm_k_img = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity()
 | 
			
		||||
 | 
			
		||||
    def forward(self, x, context, context_lens):
 | 
			
		||||
        context_img = context[:, :257]
 | 
			
		||||
        context = context[:, 257:]
 | 
			
		||||
        b, n, d = x.size(0), self.num_heads, self.head_dim
 | 
			
		||||
 | 
			
		||||
        # compute query, key, value
 | 
			
		||||
        q = self.norm_q(self.q(x)).view(b, -1, n, d)
 | 
			
		||||
        k = self.norm_k(self.k(context)).view(b, -1, n, d)
 | 
			
		||||
        v = self.v(context).view(b, -1, n, d)
 | 
			
		||||
        k_img = self.norm_k_img(self.k_img(context_img)).view(b, -1, n, d)
 | 
			
		||||
        v_img = self.v_img(context_img).view(b, -1, n, d)
 | 
			
		||||
        img_x = flash_attention(q, k_img, v_img, k_lens=None)
 | 
			
		||||
        # compute attention
 | 
			
		||||
        x = flash_attention(q, k, v, k_lens=context_lens)
 | 
			
		||||
 | 
			
		||||
        # output
 | 
			
		||||
        x = x.flatten(2)
 | 
			
		||||
        img_x = img_x.flatten(2)
 | 
			
		||||
        x = x + img_x
 | 
			
		||||
        x = self.o(x)
 | 
			
		||||
        return x
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class WanAttentionBlock(nn.Module):
 | 
			
		||||
 | 
			
		||||
    def __init__(self,
 | 
			
		||||
                 cross_attn_type,
 | 
			
		||||
                 dim,
 | 
			
		||||
                 ffn_dim,
 | 
			
		||||
                 num_heads,
 | 
			
		||||
                 window_size=(-1, -1),
 | 
			
		||||
                 qk_norm=True,
 | 
			
		||||
                 cross_attn_norm=False,
 | 
			
		||||
                 eps=1e-6,
 | 
			
		||||
                 output_dim=768,
 | 
			
		||||
                 norm_input_visual=True,
 | 
			
		||||
                 class_range=24,
 | 
			
		||||
                 class_interval=4):
 | 
			
		||||
        super().__init__()
 | 
			
		||||
        self.dim = dim
 | 
			
		||||
        self.ffn_dim = ffn_dim
 | 
			
		||||
        self.num_heads = num_heads
 | 
			
		||||
        self.window_size = window_size
 | 
			
		||||
        self.qk_norm = qk_norm
 | 
			
		||||
        self.cross_attn_norm = cross_attn_norm
 | 
			
		||||
        self.eps = eps
 | 
			
		||||
 | 
			
		||||
        # layers
 | 
			
		||||
        self.norm1 = WanLayerNorm(dim, eps)
 | 
			
		||||
        self.self_attn = WanSelfAttention(dim, num_heads, window_size, qk_norm, eps)
 | 
			
		||||
        self.norm3 = WanLayerNorm(
 | 
			
		||||
            dim, eps,
 | 
			
		||||
            elementwise_affine=True) if cross_attn_norm else nn.Identity()
 | 
			
		||||
        self.cross_attn = WanI2VCrossAttention(dim,
 | 
			
		||||
                                                num_heads,
 | 
			
		||||
                                                (-1, -1),
 | 
			
		||||
                                                qk_norm,
 | 
			
		||||
                                                eps)
 | 
			
		||||
        self.norm2 = WanLayerNorm(dim, eps)
 | 
			
		||||
        self.ffn = nn.Sequential(
 | 
			
		||||
            nn.Linear(dim, ffn_dim), nn.GELU(approximate='tanh'),
 | 
			
		||||
            nn.Linear(ffn_dim, dim))
 | 
			
		||||
 | 
			
		||||
        # modulation
 | 
			
		||||
        self.modulation = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5)
 | 
			
		||||
 | 
			
		||||
        # init audio module
 | 
			
		||||
        self.audio_cross_attn = SingleStreamMutiAttention(
 | 
			
		||||
                dim=dim,
 | 
			
		||||
                encoder_hidden_states_dim=output_dim,
 | 
			
		||||
                num_heads=num_heads,
 | 
			
		||||
                qk_norm=False,
 | 
			
		||||
                qkv_bias=True,
 | 
			
		||||
                eps=eps,
 | 
			
		||||
                norm_layer=WanRMSNorm,
 | 
			
		||||
                class_range=class_range,
 | 
			
		||||
                class_interval=class_interval
 | 
			
		||||
            )
 | 
			
		||||
        self.norm_x = WanLayerNorm(dim, eps, elementwise_affine=True)  if norm_input_visual else nn.Identity()
 | 
			
		||||
        
 | 
			
		||||
 | 
			
		||||
    def forward(
 | 
			
		||||
        self,
 | 
			
		||||
        x,
 | 
			
		||||
        e,
 | 
			
		||||
        seq_lens,
 | 
			
		||||
        grid_sizes,
 | 
			
		||||
        freqs,
 | 
			
		||||
        context,
 | 
			
		||||
        context_lens,
 | 
			
		||||
        audio_embedding=None,
 | 
			
		||||
        ref_target_masks=None,
 | 
			
		||||
        human_num=None,
 | 
			
		||||
    ):
 | 
			
		||||
 | 
			
		||||
        dtype = x.dtype
 | 
			
		||||
        assert e.dtype == torch.float32
 | 
			
		||||
        with amp.autocast(dtype=torch.float32):
 | 
			
		||||
            e = (self.modulation.to(e.device) + e).chunk(6, dim=1)
 | 
			
		||||
        assert e[0].dtype == torch.float32
 | 
			
		||||
 | 
			
		||||
        # self-attention
 | 
			
		||||
        y, x_ref_attn_map = self.self_attn(
 | 
			
		||||
            (self.norm1(x).float() * (1 + e[1]) + e[0]).type_as(x), seq_lens, grid_sizes,
 | 
			
		||||
            freqs, ref_target_masks=ref_target_masks)
 | 
			
		||||
        with amp.autocast(dtype=torch.float32):
 | 
			
		||||
            x = x + y * e[2]
 | 
			
		||||
        
 | 
			
		||||
        x = x.to(dtype)
 | 
			
		||||
 | 
			
		||||
        # cross-attention of text
 | 
			
		||||
        x = x + self.cross_attn(self.norm3(x), context, context_lens)
 | 
			
		||||
 | 
			
		||||
        # cross attn of audio
 | 
			
		||||
        x_a = self.audio_cross_attn(self.norm_x(x), encoder_hidden_states=audio_embedding,
 | 
			
		||||
                                        shape=grid_sizes[0], x_ref_attn_map=x_ref_attn_map, human_num=human_num)
 | 
			
		||||
        x = x + x_a
 | 
			
		||||
 | 
			
		||||
        y = self.ffn((self.norm2(x).float() * (1 + e[4]) + e[3]).to(dtype))
 | 
			
		||||
        with amp.autocast(dtype=torch.float32):
 | 
			
		||||
            x = x + y * e[5]
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
        x = x.to(dtype)
 | 
			
		||||
 | 
			
		||||
        return x
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class Head(nn.Module):
 | 
			
		||||
 | 
			
		||||
    def __init__(self, dim, out_dim, patch_size, eps=1e-6):
 | 
			
		||||
        super().__init__()
 | 
			
		||||
        self.dim = dim
 | 
			
		||||
        self.out_dim = out_dim
 | 
			
		||||
        self.patch_size = patch_size
 | 
			
		||||
        self.eps = eps
 | 
			
		||||
 | 
			
		||||
        # layers
 | 
			
		||||
        out_dim = math.prod(patch_size) * out_dim
 | 
			
		||||
        self.norm = WanLayerNorm(dim, eps)
 | 
			
		||||
        self.head = nn.Linear(dim, out_dim)
 | 
			
		||||
 | 
			
		||||
        # modulation
 | 
			
		||||
        self.modulation = nn.Parameter(torch.randn(1, 2, dim) / dim**0.5)
 | 
			
		||||
 | 
			
		||||
    def forward(self, x, e):
 | 
			
		||||
        r"""
 | 
			
		||||
        Args:
 | 
			
		||||
            x(Tensor): Shape [B, L1, C]
 | 
			
		||||
            e(Tensor): Shape [B, C]
 | 
			
		||||
        """
 | 
			
		||||
        assert e.dtype == torch.float32
 | 
			
		||||
        with amp.autocast(dtype=torch.float32):
 | 
			
		||||
            e = (self.modulation.to(e.device) + e.unsqueeze(1)).chunk(2, dim=1)
 | 
			
		||||
            x = (self.head(self.norm(x) * (1 + e[1]) + e[0]))
 | 
			
		||||
        return x
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class MLPProj(torch.nn.Module):
 | 
			
		||||
 | 
			
		||||
    def __init__(self, in_dim, out_dim):
 | 
			
		||||
        super().__init__()
 | 
			
		||||
 | 
			
		||||
        self.proj = torch.nn.Sequential(
 | 
			
		||||
            torch.nn.LayerNorm(in_dim), torch.nn.Linear(in_dim, in_dim),
 | 
			
		||||
            torch.nn.GELU(), torch.nn.Linear(in_dim, out_dim),
 | 
			
		||||
            torch.nn.LayerNorm(out_dim))
 | 
			
		||||
 | 
			
		||||
    def forward(self, image_embeds):
 | 
			
		||||
        clip_extra_context_tokens = self.proj(image_embeds)
 | 
			
		||||
        return clip_extra_context_tokens
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class AudioProjModel(ModelMixin, ConfigMixin):
 | 
			
		||||
    def __init__(
 | 
			
		||||
        self,
 | 
			
		||||
        seq_len=5,
 | 
			
		||||
        seq_len_vf=12,
 | 
			
		||||
        blocks=12,  
 | 
			
		||||
        channels=768, 
 | 
			
		||||
        intermediate_dim=512,
 | 
			
		||||
        output_dim=768,
 | 
			
		||||
        context_tokens=32,
 | 
			
		||||
        norm_output_audio=False,
 | 
			
		||||
    ):
 | 
			
		||||
        super().__init__()
 | 
			
		||||
 | 
			
		||||
        self.seq_len = seq_len
 | 
			
		||||
        self.blocks = blocks
 | 
			
		||||
        self.channels = channels
 | 
			
		||||
        self.input_dim = seq_len * blocks * channels  
 | 
			
		||||
        self.input_dim_vf = seq_len_vf * blocks * channels
 | 
			
		||||
        self.intermediate_dim = intermediate_dim
 | 
			
		||||
        self.context_tokens = context_tokens
 | 
			
		||||
        self.output_dim = output_dim
 | 
			
		||||
 | 
			
		||||
        # define multiple linear layers
 | 
			
		||||
        self.proj1 = nn.Linear(self.input_dim, intermediate_dim)
 | 
			
		||||
        self.proj1_vf = nn.Linear(self.input_dim_vf, intermediate_dim)
 | 
			
		||||
        self.proj2 = nn.Linear(intermediate_dim, intermediate_dim)
 | 
			
		||||
        self.proj3 = nn.Linear(intermediate_dim, context_tokens * output_dim)
 | 
			
		||||
        self.norm = nn.LayerNorm(output_dim) if norm_output_audio else nn.Identity()
 | 
			
		||||
 | 
			
		||||
    def forward(self, audio_embeds, audio_embeds_vf):
 | 
			
		||||
        video_length = audio_embeds.shape[1] + audio_embeds_vf.shape[1]
 | 
			
		||||
        B, _, _, S, C = audio_embeds.shape
 | 
			
		||||
 | 
			
		||||
        # process audio of first frame
 | 
			
		||||
        audio_embeds = rearrange(audio_embeds, "bz f w b c -> (bz f) w b c")
 | 
			
		||||
        batch_size, window_size, blocks, channels = audio_embeds.shape
 | 
			
		||||
        audio_embeds = audio_embeds.view(batch_size, window_size * blocks * channels)
 | 
			
		||||
 | 
			
		||||
        # process audio of latter frame
 | 
			
		||||
        audio_embeds_vf = rearrange(audio_embeds_vf, "bz f w b c -> (bz f) w b c")
 | 
			
		||||
        batch_size_vf, window_size_vf, blocks_vf, channels_vf = audio_embeds_vf.shape
 | 
			
		||||
        audio_embeds_vf = audio_embeds_vf.view(batch_size_vf, window_size_vf * blocks_vf * channels_vf)
 | 
			
		||||
 | 
			
		||||
        # first projection
 | 
			
		||||
        audio_embeds = torch.relu(self.proj1(audio_embeds)) 
 | 
			
		||||
        audio_embeds_vf = torch.relu(self.proj1_vf(audio_embeds_vf)) 
 | 
			
		||||
        audio_embeds = rearrange(audio_embeds, "(bz f) c -> bz f c", bz=B)
 | 
			
		||||
        audio_embeds_vf = rearrange(audio_embeds_vf, "(bz f) c -> bz f c", bz=B)
 | 
			
		||||
        audio_embeds_c = torch.concat([audio_embeds, audio_embeds_vf], dim=1) 
 | 
			
		||||
        batch_size_c, N_t, C_a = audio_embeds_c.shape
 | 
			
		||||
        audio_embeds_c = audio_embeds_c.view(batch_size_c*N_t, C_a)
 | 
			
		||||
 | 
			
		||||
        # second projection
 | 
			
		||||
        audio_embeds_c = torch.relu(self.proj2(audio_embeds_c))
 | 
			
		||||
 | 
			
		||||
        context_tokens = self.proj3(audio_embeds_c).reshape(batch_size_c*N_t, self.context_tokens, self.output_dim)
 | 
			
		||||
 | 
			
		||||
        # normalization and reshape
 | 
			
		||||
        context_tokens = self.norm(context_tokens)
 | 
			
		||||
        context_tokens = rearrange(context_tokens, "(bz f) m c -> bz f m c", f=video_length)
 | 
			
		||||
 | 
			
		||||
        return context_tokens
 | 
			
		||||
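    # Illustrative shape walk-through (hedged; assumes the values WanModel passes
    # below: seq_len=5 and seq_len_vf=8 for audio_window=5, vae_scale=4, with
    # blocks=12, channels=768, context_tokens=32, and a clip of 21 latent frames):
    #   audio_embeds    : [B, 1,  5, 12, 768]   (first latent frame)
    #   audio_embeds_vf : [B, 20, 8, 12, 768]   (remaining latent frames)
    #   proj1/proj1_vf flatten each window to 5*12*768 and 8*12*768 features,
    #   and the output context_tokens come back as [B, 21, 32, 768].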
 | 
			
		||||
 | 
			
		||||
class WanModel(ModelMixin, ConfigMixin):
 | 
			
		||||
    r"""
 | 
			
		||||
    Wan diffusion backbone supporting both text-to-video and image-to-video.
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
    ignore_for_config = [
 | 
			
		||||
        'patch_size', 'cross_attn_norm', 'qk_norm', 'text_dim', 'window_size'
 | 
			
		||||
    ]
 | 
			
		||||
    _no_split_modules = ['WanAttentionBlock']
 | 
			
		||||
 | 
			
		||||
    @register_to_config
 | 
			
		||||
    def __init__(self,
 | 
			
		||||
                 model_type='i2v',
 | 
			
		||||
                 patch_size=(1, 2, 2),
 | 
			
		||||
                 text_len=512,
 | 
			
		||||
                 in_dim=16,
 | 
			
		||||
                 dim=2048,
 | 
			
		||||
                 ffn_dim=8192,
 | 
			
		||||
                 freq_dim=256,
 | 
			
		||||
                 text_dim=4096,
 | 
			
		||||
                 out_dim=16,
 | 
			
		||||
                 num_heads=16,
 | 
			
		||||
                 num_layers=32,
 | 
			
		||||
                 window_size=(-1, -1),
 | 
			
		||||
                 qk_norm=True,
 | 
			
		||||
                 cross_attn_norm=True,
 | 
			
		||||
                 eps=1e-6,
 | 
			
		||||
                 # audio params
 | 
			
		||||
                 audio_window=5,
 | 
			
		||||
                 intermediate_dim=512,
 | 
			
		||||
                 output_dim=768,
 | 
			
		||||
                 context_tokens=32,
 | 
			
		||||
                 vae_scale=4, # vae timedownsample scale
 | 
			
		||||
 | 
			
		||||
                 norm_input_visual=True,
 | 
			
		||||
                 norm_output_audio=True):
 | 
			
		||||
        super().__init__()
 | 
			
		||||
 | 
			
		||||
        assert model_type == 'i2v', 'MultiTalk model requires model_type to be i2v.'
 | 
			
		||||
        self.model_type = model_type
 | 
			
		||||
 | 
			
		||||
        self.patch_size = patch_size
 | 
			
		||||
        self.text_len = text_len
 | 
			
		||||
        self.in_dim = in_dim
 | 
			
		||||
        self.dim = dim
 | 
			
		||||
        self.ffn_dim = ffn_dim
 | 
			
		||||
        self.freq_dim = freq_dim
 | 
			
		||||
        self.text_dim = text_dim
 | 
			
		||||
        self.out_dim = out_dim
 | 
			
		||||
        self.num_heads = num_heads
 | 
			
		||||
        self.num_layers = num_layers
 | 
			
		||||
        self.window_size = window_size
 | 
			
		||||
        self.qk_norm = qk_norm
 | 
			
		||||
        self.cross_attn_norm = cross_attn_norm
 | 
			
		||||
        self.eps = eps
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
        self.norm_output_audio = norm_output_audio
 | 
			
		||||
        self.audio_window = audio_window
 | 
			
		||||
        self.intermediate_dim = intermediate_dim
 | 
			
		||||
        self.vae_scale = vae_scale
 | 
			
		||||
        
 | 
			
		||||
 | 
			
		||||
        # embeddings
 | 
			
		||||
        self.patch_embedding = nn.Conv3d(
 | 
			
		||||
            in_dim, dim, kernel_size=patch_size, stride=patch_size)
 | 
			
		||||
        self.text_embedding = nn.Sequential(
 | 
			
		||||
            nn.Linear(text_dim, dim), nn.GELU(approximate='tanh'),
 | 
			
		||||
            nn.Linear(dim, dim))
 | 
			
		||||
 | 
			
		||||
        self.time_embedding = nn.Sequential(
 | 
			
		||||
            nn.Linear(freq_dim, dim), nn.SiLU(), nn.Linear(dim, dim))
 | 
			
		||||
        self.time_projection = nn.Sequential(nn.SiLU(), nn.Linear(dim, dim * 6))
 | 
			
		||||
 | 
			
		||||
        # blocks
 | 
			
		||||
        cross_attn_type = 'i2v_cross_attn'
 | 
			
		||||
        self.blocks = nn.ModuleList([
 | 
			
		||||
            WanAttentionBlock(cross_attn_type, dim, ffn_dim, num_heads,
 | 
			
		||||
                              window_size, qk_norm, cross_attn_norm, eps, 
 | 
			
		||||
                              output_dim=output_dim, norm_input_visual=norm_input_visual)
 | 
			
		||||
            for _ in range(num_layers)
 | 
			
		||||
        ])
 | 
			
		||||
 | 
			
		||||
        # head
 | 
			
		||||
        self.head = Head(dim, out_dim, patch_size, eps)
 | 
			
		||||
 | 
			
		||||
        assert (dim % num_heads) == 0 and (dim // num_heads) % 2 == 0
 | 
			
		||||
        d = dim // num_heads
 | 
			
		||||
        self.freqs = torch.cat([
 | 
			
		||||
            rope_params(1024, d - 4 * (d // 6)),
 | 
			
		||||
            rope_params(1024, 2 * (d // 6)),
 | 
			
		||||
            rope_params(1024, 2 * (d // 6))
 | 
			
		||||
        ],
 | 
			
		||||
                               dim=1)
 | 
			
		||||
 | 
			
		||||
        if model_type == 'i2v':
 | 
			
		||||
            self.img_emb = MLPProj(1280, dim)
 | 
			
		||||
        else:
 | 
			
		||||
            raise NotImplementedError('Not supported model type.')
 | 
			
		||||
        
 | 
			
		||||
        # init audio adapter
 | 
			
		||||
        self.audio_proj = AudioProjModel(
 | 
			
		||||
                    seq_len=audio_window,
 | 
			
		||||
                    seq_len_vf=audio_window+vae_scale-1,
 | 
			
		||||
                    intermediate_dim=intermediate_dim,
 | 
			
		||||
                    output_dim=output_dim,
 | 
			
		||||
                    context_tokens=context_tokens,
 | 
			
		||||
                    norm_output_audio=norm_output_audio,
 | 
			
		||||
                )
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
        # TeaCache is disabled by default; call teacache_init() to turn it on
        self.enable_teacache = False

        # initialize weights
        self.init_weights()
 | 
			
		||||
 | 
			
		||||
    def teacache_init(
 | 
			
		||||
        self,
 | 
			
		||||
        use_ret_steps=True,
 | 
			
		||||
        teacache_thresh=0.2,
 | 
			
		||||
        sample_steps=40,
 | 
			
		||||
        model_scale='multitalk-480',
 | 
			
		||||
    ):
 | 
			
		||||
        print("teacache_init")
 | 
			
		||||
        self.enable_teacache = True
 | 
			
		||||
        
 | 
			
		||||
        self.__class__.cnt = 0
 | 
			
		||||
        self.__class__.num_steps = sample_steps*3
 | 
			
		||||
        self.__class__.teacache_thresh = teacache_thresh
 | 
			
		||||
        self.__class__.accumulated_rel_l1_distance_even = 0
 | 
			
		||||
        self.__class__.accumulated_rel_l1_distance_odd = 0
 | 
			
		||||
        self.__class__.previous_e0_even = None
 | 
			
		||||
        self.__class__.previous_e0_odd = None
 | 
			
		||||
        self.__class__.previous_residual_even = None
 | 
			
		||||
        self.__class__.previous_residual_odd = None
 | 
			
		||||
        self.__class__.use_ret_steps = use_ret_steps
 | 
			
		||||
 | 
			
		||||
        if use_ret_steps:
 | 
			
		||||
            if model_scale == 'multitalk-480':
 | 
			
		||||
                self.__class__.coefficients = [ 2.57151496e+05, -3.54229917e+04,  1.40286849e+03, -1.35890334e+01, 1.32517977e-01]
 | 
			
		||||
            if model_scale == 'multitalk-720':
 | 
			
		||||
                self.__class__.coefficients = [ 8.10705460e+03,  2.13393892e+03, -3.72934672e+02,  1.66203073e+01, -4.17769401e-02]
 | 
			
		||||
            self.__class__.ret_steps = 5*3
 | 
			
		||||
            self.__class__.cutoff_steps = sample_steps*3
 | 
			
		||||
        else:
 | 
			
		||||
            if model_scale == 'multitalk-480':
 | 
			
		||||
                self.__class__.coefficients = [-3.02331670e+02,  2.23948934e+02, -5.25463970e+01,  5.87348440e+00, -2.01973289e-01]
 | 
			
		||||
        
 | 
			
		||||
            if model_scale == 'multitalk-720':
 | 
			
		||||
                self.__class__.coefficients = [-114.36346466,   65.26524496,  -18.82220707,    4.91518089,   -0.23412683]
 | 
			
		||||
            self.__class__.ret_steps = 1*3
 | 
			
		||||
            self.__class__.cutoff_steps = sample_steps*3 - 3
 | 
			
		||||
        print("teacache_init done")
 | 
			
		||||
    
 | 
			
		||||
    def disable_teacache(self):
 | 
			
		||||
        self.enable_teacache = False
 | 
			
		||||
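    # Note on the cache layout (hedged interpretation of the code below):
    # teacache_init assumes three forward passes per sampling step (cond,
    # drop_text, uncond), hence num_steps = sample_steps * 3 and the
    # self.cnt % 3 dispatch in forward(). A pass reuses its cached residual
    # whenever the accumulated, polynomial-rescaled relative L1 change of the
    # modulation input stays below teacache_thresh; otherwise the transformer
    # blocks are re-run and the stored residual is refreshed.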
 | 
			
		||||
    def forward(
 | 
			
		||||
            self,
 | 
			
		||||
            x,
 | 
			
		||||
            t,
 | 
			
		||||
            context,
 | 
			
		||||
            seq_len,
 | 
			
		||||
            clip_fea=None,
 | 
			
		||||
            y=None,
 | 
			
		||||
            audio=None,
 | 
			
		||||
            ref_target_masks=None,
 | 
			
		||||
        ):
 | 
			
		||||
        assert clip_fea is not None and y is not None
 | 
			
		||||
 | 
			
		||||
        _, T, H, W = x[0].shape
 | 
			
		||||
        N_t = T // self.patch_size[0]
 | 
			
		||||
        N_h = H // self.patch_size[1]
 | 
			
		||||
        N_w = W // self.patch_size[2]
 | 
			
		||||
 | 
			
		||||
        if y is not None:
 | 
			
		||||
            x = [torch.cat([u, v], dim=0) for u, v in zip(x, y)]
 | 
			
		||||
        x[0] = x[0].to(context[0].dtype)
 | 
			
		||||
 | 
			
		||||
        # embeddings
 | 
			
		||||
        x = [self.patch_embedding(u.unsqueeze(0)) for u in x]
 | 
			
		||||
        grid_sizes = torch.stack(
 | 
			
		||||
            [torch.tensor(u.shape[2:], dtype=torch.long) for u in x])
 | 
			
		||||
        x = [u.flatten(2).transpose(1, 2) for u in x]
 | 
			
		||||
        seq_lens = torch.tensor([u.size(1) for u in x], dtype=torch.long)
 | 
			
		||||
        assert seq_lens.max() <= seq_len
 | 
			
		||||
        x = torch.cat([
 | 
			
		||||
            torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))],
 | 
			
		||||
                      dim=1) for u in x
 | 
			
		||||
        ])
 | 
			
		||||
 | 
			
		||||
        # time embeddings
 | 
			
		||||
        with amp.autocast(dtype=torch.float32):
 | 
			
		||||
            e = self.time_embedding(
 | 
			
		||||
                sinusoidal_embedding_1d(self.freq_dim, t).float())
 | 
			
		||||
            e0 = self.time_projection(e).unflatten(1, (6, self.dim))
 | 
			
		||||
            assert e.dtype == torch.float32 and e0.dtype == torch.float32
 | 
			
		||||
 | 
			
		||||
        # text embedding
 | 
			
		||||
        context_lens = None
 | 
			
		||||
        context = self.text_embedding(
 | 
			
		||||
            torch.stack([
 | 
			
		||||
                torch.cat(
 | 
			
		||||
                    [u, u.new_zeros(self.text_len - u.size(0), u.size(1))])
 | 
			
		||||
                for u in context
 | 
			
		||||
            ]))
 | 
			
		||||
 | 
			
		||||
        # clip embedding
 | 
			
		||||
        if clip_fea is not None:
 | 
			
		||||
            context_clip = self.img_emb(clip_fea) 
 | 
			
		||||
            context = torch.concat([context_clip, context], dim=1).to(x.dtype)
 | 
			
		||||
 | 
			
		||||
        
 | 
			
		||||
        audio_cond = audio.to(device=x.device, dtype=x.dtype)
 | 
			
		||||
        first_frame_audio_emb_s = audio_cond[:, :1, ...] 
 | 
			
		||||
        latter_frame_audio_emb = audio_cond[:, 1:, ...] 
 | 
			
		||||
        latter_frame_audio_emb = rearrange(latter_frame_audio_emb, "b (n_t n) w s c -> b n_t n w s c", n=self.vae_scale) 
 | 
			
		||||
        middle_index = self.audio_window // 2
 | 
			
		||||
        latter_first_frame_audio_emb = latter_frame_audio_emb[:, :, :1, :middle_index+1, ...] 
 | 
			
		||||
        latter_first_frame_audio_emb = rearrange(latter_first_frame_audio_emb, "b n_t n w s c -> b n_t (n w) s c") 
 | 
			
		||||
        latter_last_frame_audio_emb = latter_frame_audio_emb[:, :, -1:, middle_index:, ...] 
 | 
			
		||||
        latter_last_frame_audio_emb = rearrange(latter_last_frame_audio_emb, "b n_t n w s c -> b n_t (n w) s c") 
 | 
			
		||||
        latter_middle_frame_audio_emb = latter_frame_audio_emb[:, :, 1:-1, middle_index:middle_index+1, ...] 
 | 
			
		||||
        latter_middle_frame_audio_emb = rearrange(latter_middle_frame_audio_emb, "b n_t n w s c -> b n_t (n w) s c") 
 | 
			
		||||
        latter_frame_audio_emb_s = torch.concat([latter_first_frame_audio_emb, latter_middle_frame_audio_emb, latter_last_frame_audio_emb], dim=2) 
 | 
			
		||||
        audio_embedding = self.audio_proj(first_frame_audio_emb_s, latter_frame_audio_emb_s) 
 | 
			
		||||
        human_num = len(audio_embedding)
 | 
			
		||||
        audio_embedding = torch.concat(audio_embedding.split(1), dim=2).to(x.dtype)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
        # convert ref_target_masks to token_ref_target_masks
 | 
			
		||||
        if ref_target_masks is not None:
 | 
			
		||||
            ref_target_masks = ref_target_masks.unsqueeze(0).to(torch.float32) 
 | 
			
		||||
            token_ref_target_masks = nn.functional.interpolate(ref_target_masks, size=(N_h, N_w), mode='nearest') 
 | 
			
		||||
            token_ref_target_masks = token_ref_target_masks.squeeze(0)
 | 
			
		||||
            token_ref_target_masks = (token_ref_target_masks > 0)
 | 
			
		||||
            token_ref_target_masks = token_ref_target_masks.view(token_ref_target_masks.shape[0], -1) 
 | 
			
		||||
            token_ref_target_masks = token_ref_target_masks.to(x.dtype)
 | 
			
		||||
 | 
			
		||||
        # teacache
 | 
			
		||||
        if self.enable_teacache:
 | 
			
		||||
            modulated_inp = e0 if self.use_ret_steps else e
 | 
			
		||||
            if self.cnt%3==0: # cond
 | 
			
		||||
                if self.cnt < self.ret_steps or self.cnt >= self.cutoff_steps:
 | 
			
		||||
                    should_calc_cond = True
 | 
			
		||||
                    self.accumulated_rel_l1_distance_cond = 0
 | 
			
		||||
                else:
 | 
			
		||||
                    rescale_func = np.poly1d(self.coefficients)
 | 
			
		||||
                    self.accumulated_rel_l1_distance_cond += rescale_func(((modulated_inp-self.previous_e0_cond).abs().mean() / self.previous_e0_cond.abs().mean()).cpu().item())
 | 
			
		||||
                    if self.accumulated_rel_l1_distance_cond < self.teacache_thresh:
 | 
			
		||||
                        should_calc_cond = False
 | 
			
		||||
                    else:
 | 
			
		||||
                        should_calc_cond = True
 | 
			
		||||
                        self.accumulated_rel_l1_distance_cond = 0
 | 
			
		||||
                self.previous_e0_cond = modulated_inp.clone()
 | 
			
		||||
            elif self.cnt%3==1: # drop_text
 | 
			
		||||
                if self.cnt < self.ret_steps or self.cnt >= self.cutoff_steps:
 | 
			
		||||
                    should_calc_drop_text = True
 | 
			
		||||
                    self.accumulated_rel_l1_distance_drop_text = 0
 | 
			
		||||
                else:
 | 
			
		||||
                    rescale_func = np.poly1d(self.coefficients)
 | 
			
		||||
                    self.accumulated_rel_l1_distance_drop_text += rescale_func(((modulated_inp-self.previous_e0_drop_text).abs().mean() / self.previous_e0_drop_text.abs().mean()).cpu().item())
 | 
			
		||||
                    if self.accumulated_rel_l1_distance_drop_text < self.teacache_thresh:
 | 
			
		||||
                        should_calc_drop_text = False
 | 
			
		||||
                    else:
 | 
			
		||||
                        should_calc_drop_text = True
 | 
			
		||||
                        self.accumulated_rel_l1_distance_drop_text = 0
 | 
			
		||||
                self.previous_e0_drop_text = modulated_inp.clone()
 | 
			
		||||
            else: # uncond
 | 
			
		||||
                if self.cnt < self.ret_steps or self.cnt >= self.cutoff_steps:
 | 
			
		||||
                    should_calc_uncond = True
 | 
			
		||||
                    self.accumulated_rel_l1_distance_uncond = 0
 | 
			
		||||
                else:
 | 
			
		||||
                    rescale_func = np.poly1d(self.coefficients)
 | 
			
		||||
                    self.accumulated_rel_l1_distance_uncond += rescale_func(((modulated_inp-self.previous_e0_uncond).abs().mean() / self.previous_e0_uncond.abs().mean()).cpu().item())
 | 
			
		||||
                    if self.accumulated_rel_l1_distance_uncond < self.teacache_thresh:
 | 
			
		||||
                        should_calc_uncond = False
 | 
			
		||||
                    else:
 | 
			
		||||
                        should_calc_uncond = True
 | 
			
		||||
                        self.accumulated_rel_l1_distance_uncond = 0
 | 
			
		||||
                self.previous_e0_uncond = modulated_inp.clone()
 | 
			
		||||
 | 
			
		||||
        # arguments
 | 
			
		||||
        kwargs = dict(
 | 
			
		||||
            e=e0,
 | 
			
		||||
            seq_lens=seq_lens,
 | 
			
		||||
            grid_sizes=grid_sizes,
 | 
			
		||||
            freqs=self.freqs,
 | 
			
		||||
            context=context,
 | 
			
		||||
            context_lens=context_lens,
 | 
			
		||||
            audio_embedding=audio_embedding,
 | 
			
		||||
            ref_target_masks=token_ref_target_masks,
 | 
			
		||||
            human_num=human_num,
 | 
			
		||||
            )
 | 
			
		||||
        if self.enable_teacache:
 | 
			
		||||
            if self.cnt%3==0:
 | 
			
		||||
                if not should_calc_cond:
 | 
			
		||||
                    x +=  self.previous_residual_cond
 | 
			
		||||
                else:
 | 
			
		||||
                    ori_x = x.clone()
 | 
			
		||||
                    for block in self.blocks:
 | 
			
		||||
                        x = block(x, **kwargs)
 | 
			
		||||
                    self.previous_residual_cond = x - ori_x
 | 
			
		||||
            elif self.cnt%3==1:
 | 
			
		||||
                if not should_calc_drop_text:
 | 
			
		||||
                    x +=  self.previous_residual_drop_text
 | 
			
		||||
                else:
 | 
			
		||||
                    ori_x = x.clone()
 | 
			
		||||
                    for block in self.blocks:
 | 
			
		||||
                        x = block(x, **kwargs)
 | 
			
		||||
                    self.previous_residual_drop_text = x - ori_x
 | 
			
		||||
            else:
 | 
			
		||||
                if not should_calc_uncond:
 | 
			
		||||
                    x +=  self.previous_residual_uncond
 | 
			
		||||
                else:
 | 
			
		||||
                    ori_x = x.clone()
 | 
			
		||||
                    for block in self.blocks:
 | 
			
		||||
                        x = block(x, **kwargs)
 | 
			
		||||
                    self.previous_residual_uncond = x - ori_x
 | 
			
		||||
        else:
 | 
			
		||||
            for block in self.blocks:
 | 
			
		||||
                x = block(x, **kwargs)
 | 
			
		||||
 | 
			
		||||
        # head
 | 
			
		||||
        x = self.head(x, e)
 | 
			
		||||
 | 
			
		||||
        # unpatchify
 | 
			
		||||
        x = self.unpatchify(x, grid_sizes)
 | 
			
		||||
        if self.enable_teacache:
 | 
			
		||||
            self.cnt += 1
 | 
			
		||||
            if self.cnt >= self.num_steps:
 | 
			
		||||
                self.cnt = 0
 | 
			
		||||
 | 
			
		||||
        return torch.stack(x).float()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
    def unpatchify(self, x, grid_sizes):
        r"""
        Reconstruct video tensors from patch embeddings.

        Args:
            x (List[Tensor]):
                List of patchified features, each with shape [L, C_out * prod(patch_size)].
            grid_sizes (Tensor):
                Original spatio-temporal grid dimensions before patching,
                shape [B, 3] (the 3 dimensions are F_patches, H_patches, W_patches).

        Returns:
            List[Tensor]:
                Reconstructed video tensors with shape [C_out, F, H / 8, W / 8].
        """

        c = self.out_dim
        out = []
        for u, v in zip(x, grid_sizes.tolist()):
            u = u[:math.prod(v)].view(*v, *self.patch_size, c)
            u = torch.einsum('fhwpqrc->cfphqwr', u)
            u = u.reshape(c, *[i * j for i, j in zip(v, self.patch_size)])
            out.append(u)
        return out

    def init_weights(self):
        r"""
        Initialize model parameters using Xavier initialization.
        """

        # basic init
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

        # init embeddings
        nn.init.xavier_uniform_(self.patch_embedding.weight.flatten(1))
        for m in self.text_embedding.modules():
            if isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, std=.02)
        for m in self.time_embedding.modules():
            if isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, std=.02)

        # init output layer
        nn.init.zeros_(self.head.head.weight)
							
								
								
									
353  wan/multitalk/multitalk_utils.py  Normal file
@ -0,0 +1,353 @@
import os
import os.path as osp
import binascii
import subprocess
import uuid
from functools import lru_cache

import imageio
import numpy as np
import soundfile as sf
import torch
import torch.nn as nn
import torchvision
from einops import rearrange, repeat
from tqdm import tqdm


VID_EXTENSIONS = (".mp4", ".avi", ".mov", ".mkv")

ASPECT_RATIO_627 = {
    '0.26': ([320, 1216], 1), '0.38': ([384, 1024], 1), '0.50': ([448, 896], 1), '0.67': ([512, 768], 1),
    '0.82': ([576, 704], 1),  '1.00': ([640, 640], 1),  '1.22': ([704, 576], 1), '1.50': ([768, 512], 1),
    '1.86': ([832, 448], 1),  '2.00': ([896, 448], 1),  '2.50': ([960, 384], 1), '2.83': ([1088, 384], 1),
    '3.60': ([1152, 320], 1), '3.80': ([1216, 320], 1), '4.00': ([1280, 320], 1)}

ASPECT_RATIO_960 = {
    '0.22': ([448, 2048], 1), '0.29': ([512, 1792], 1), '0.36': ([576, 1600], 1), '0.45': ([640, 1408], 1),
    '0.55': ([704, 1280], 1), '0.63': ([768, 1216], 1), '0.76': ([832, 1088], 1), '0.88': ([896, 1024], 1),
    '1.00': ([960, 960], 1),  '1.14': ([1024, 896], 1), '1.31': ([1088, 832], 1), '1.50': ([1152, 768], 1),
    '1.58': ([1216, 768], 1), '1.82': ([1280, 704], 1), '1.91': ([1344, 704], 1), '2.20': ([1408, 640], 1),
    '2.30': ([1472, 640], 1), '2.67': ([1536, 576], 1), '2.89': ([1664, 576], 1), '3.62': ([1856, 512], 1),
    '3.75': ([1920, 512], 1)}


def torch_gc():
    """Release cached CUDA memory back to the driver."""
    torch.cuda.empty_cache()
    torch.cuda.ipc_collect()

def split_token_counts_and_frame_ids(T, token_frame, world_size, rank):
    """Shard T frames of `token_frame` tokens each across `world_size` ranks.

    Returns, for the given `rank`, the number of tokens it owns in each frame it
    touches (`counts_filtered`) together with the ids of those frames (`frame_ids`).
    """
    S = T * token_frame
    split_sizes = [S // world_size + (1 if i < S % world_size else 0) for i in range(world_size)]
    start = sum(split_sizes[:rank])
    end = start + split_sizes[rank]
    counts = [0] * T
    for idx in range(start, end):
        t = idx // token_frame
        counts[t] += 1

    counts_filtered = []
    frame_ids = []
    for t, c in enumerate(counts):
        if c > 0:
            counts_filtered.append(c)
            frame_ids.append(t)
    return counts_filtered, frame_ids

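# Illustrative usage sketch (not part of the original file; the sizes below are
# hypothetical): how the sequence-parallel split is typically consumed.
def _example_split_token_counts():
    # e.g. 21 latent frames, 1560 tokens per frame, sharded across 8 ranks
    counts, frame_ids = split_token_counts_and_frame_ids(
        T=21, token_frame=1560, world_size=8, rank=0)
    # counts[i] is the number of tokens rank 0 owns in frame frame_ids[i];
    # summing the counts over all ranks recovers T * token_frame tokens.
    return counts, frame_ids
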
def normalize_and_scale(column, source_range, target_range, epsilon=1e-8):
    """Linearly map `column` from `source_range` onto `target_range`."""
    source_min, source_max = source_range
    new_min, new_max = target_range

    normalized = (column - source_min) / (source_max - source_min + epsilon)
    scaled = normalized * (new_max - new_min) + new_min
    return scaled

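# Illustrative usage sketch (not part of the original file): rescaling a per-token
# attention statistic from its observed range onto a fixed target range.
def _example_normalize_and_scale():
    column = torch.tensor([0.1, 0.4, 0.9])
    low, high = column.min(), column.max()
    return normalize_and_scale(column, (low, high), (0.0, 1.0))
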
# @torch.compile
def calculate_x_ref_attn_map(visual_q, ref_k, ref_target_masks, ref_images_count, mode='mean', attn_bias=None):

    ref_k = ref_k.to(visual_q.dtype).to(visual_q.device)
    scale = 1.0 / visual_q.shape[-1] ** 0.5
    visual_q = visual_q * scale
    visual_q = visual_q.transpose(1, 2)
    ref_k = ref_k.transpose(1, 2)
    attn = visual_q @ ref_k.transpose(-2, -1)

    if attn_bias is not None:
        attn += attn_bias

    x_ref_attn_map_source = attn.softmax(-1)  # B, H, x_seqlens, ref_seqlens

    x_ref_attn_maps = []
    ref_target_masks = ref_target_masks.to(visual_q.dtype)
    x_ref_attn_map_source = x_ref_attn_map_source.to(visual_q.dtype)

    for class_idx, ref_target_mask in enumerate(ref_target_masks):
        ref_target_mask = ref_target_mask[None, None, None, ...]
        x_ref_attnmap = x_ref_attn_map_source * ref_target_mask
        x_ref_attnmap = x_ref_attnmap.sum(-1) / ref_target_mask.sum()  # B, H, x_seqlens, ref_seqlens --> B, H, x_seqlens
        x_ref_attnmap = x_ref_attnmap.permute(0, 2, 1)  # B, x_seqlens, H

        if mode == 'mean':
            x_ref_attnmap = x_ref_attnmap.mean(-1)  # B, x_seqlens
        elif mode == 'max':
            x_ref_attnmap = x_ref_attnmap.max(-1).values  # B, x_seqlens (torch.max returns (values, indices))

        x_ref_attn_maps.append(x_ref_attnmap)

    del attn
    del x_ref_attn_map_source

    return torch.concat(x_ref_attn_maps, dim=0)


def get_attn_map_with_target(visual_q, ref_k, shape, ref_target_masks=None, split_num=10, ref_images_count=0):
    """Args:
        visual_q (torch.Tensor): [B, M, H, K] video-token queries.
        ref_k (torch.Tensor): [B, M, H, K] keys; only the first N_h * N_w reference tokens are used.
        shape (tuple): (N_t, N_h, N_w) latent grid of the video tokens.
        ref_target_masks (torch.Tensor): [num_classes, N_h * N_w] binary masks, one per speaker.
    """

    N_t, N_h, N_w = shape

    x_seqlens = N_h * N_w
    ref_k = ref_k[:, :x_seqlens]
    if ref_images_count > 0:
        visual_q_shape = visual_q.shape
        visual_q = visual_q.reshape(visual_q_shape[0], N_t, -1)
        visual_q = visual_q[:, ref_images_count:]
        visual_q = visual_q.reshape(visual_q_shape[0], -1, *visual_q_shape[-2:])

    _, seq_lens, heads, _ = visual_q.shape
    class_num, _ = ref_target_masks.shape
    x_ref_attn_maps = torch.zeros(class_num, seq_lens).to(visual_q.device).to(visual_q.dtype)

    # attention heads are processed in `split_num` chunks to bound peak memory
    split_chunk = heads // split_num

    for i in range(split_num):
        x_ref_attn_maps_perhead = calculate_x_ref_attn_map(
            visual_q[:, :, i * split_chunk:(i + 1) * split_chunk, :],
            ref_k[:, :, i * split_chunk:(i + 1) * split_chunk, :],
            ref_target_masks, ref_images_count)
        x_ref_attn_maps += x_ref_attn_maps_perhead

    x_ref_attn_maps /= split_num
    return x_ref_attn_maps

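# Illustrative usage sketch (not part of the original file; shapes are hypothetical):
# computing per-speaker attention maps from random tensors.
def _example_get_attn_map_with_target():
    N_t, N_h, N_w, heads, head_dim = 4, 8, 8, 20, 64
    seq = N_t * N_h * N_w
    visual_q = torch.randn(1, seq, heads, head_dim)   # [B, M, H, K]
    ref_k = torch.randn(1, seq, heads, head_dim)
    masks = torch.zeros(2, N_h * N_w)                 # two speakers
    masks[0, : N_h * N_w // 2] = 1.0                  # speaker 1: left half of the frame
    masks[1, N_h * N_w // 2:] = 1.0                   # speaker 2: right half of the frame
    return get_attn_map_with_target(
        visual_q, ref_k, (N_t, N_h, N_w), ref_target_masks=masks, split_num=10)
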
def rotate_half(x):
    x = rearrange(x, "... (d r) -> ... d r", r=2)
    x1, x2 = x.unbind(dim=-1)
    x = torch.stack((-x2, x1), dim=-1)
    return rearrange(x, "... d r -> ... (d r)")


class RotaryPositionalEmbedding1D(nn.Module):

    def __init__(self, head_dim):
        super().__init__()
        self.head_dim = head_dim
        self.base = 10000

    @lru_cache(maxsize=32)
    def precompute_freqs_cis_1d(self, pos_indices):
        freqs = 1.0 / (self.base ** (torch.arange(0, self.head_dim, 2)[: (self.head_dim // 2)].float() / self.head_dim))
        freqs = freqs.to(pos_indices.device)
        freqs = torch.einsum("..., f -> ... f", pos_indices.float(), freqs)
        freqs = repeat(freqs, "... n -> ... (n r)", r=2)
        return freqs

    def forward(self, x, pos_indices):
        """1D RoPE.

        Args:
            x (torch.Tensor): [B, head, seq, head_dim]
            pos_indices (torch.Tensor): [seq,]
        Returns:
            Tensor with the same shape as `x`.
        """
        freqs_cis = self.precompute_freqs_cis_1d(pos_indices)

        x_ = x.float()

        freqs_cis = freqs_cis.float().to(x.device)
        cos, sin = freqs_cis.cos(), freqs_cis.sin()
        cos, sin = rearrange(cos, 'n d -> 1 1 n d'), rearrange(sin, 'n d -> 1 1 n d')
        x_ = (x_ * cos) + (rotate_half(x_) * sin)

        return x_.type_as(x)

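# Illustrative usage sketch (not part of the original file): applying 1D RoPE to a
# sequence of audio tokens; the shapes are hypothetical.
def _example_rope_1d():
    rope = RotaryPositionalEmbedding1D(head_dim=64)
    x = torch.randn(1, 8, 16, 64)   # [B, heads, seq, head_dim]
    pos = torch.arange(16)          # one position index per token
    return rope(x, pos)
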
def rand_name(length=8, suffix=''):
    name = binascii.b2a_hex(os.urandom(length)).decode('utf-8')
    if suffix:
        if not suffix.startswith('.'):
            suffix = '.' + suffix
        name += suffix
    return name


def cache_video(tensor,
                save_file=None,
                fps=30,
                suffix='.mp4',
                nrow=8,
                normalize=True,
                value_range=(-1, 1),
                retry=5):

    # cache file
    cache_file = osp.join('/tmp', rand_name(
        suffix=suffix)) if save_file is None else save_file

    # save to cache
    for _ in range(retry):

        # preprocess: clamp, tile the batch into a grid per frame, convert to uint8 HWC
        tensor = tensor.clamp(min(value_range), max(value_range))
        tensor = torch.stack([
            torchvision.utils.make_grid(
                u, nrow=nrow, normalize=normalize, value_range=value_range)
            for u in tensor.unbind(2)
        ], dim=1).permute(1, 2, 3, 0)
        tensor = (tensor * 255).type(torch.uint8).cpu()

        # write video
        writer = imageio.get_writer(cache_file, fps=fps, codec='libx264', quality=10, ffmpeg_params=["-crf", "10"])
        for frame in tensor.numpy():
            writer.append_data(frame)
        writer.close()
        return cache_file

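# Illustrative usage sketch (not part of the original file): writing a random clip
# in [-1, 1] to an mp4 file; the shape and path are hypothetical.
def _example_cache_video(path="demo.mp4"):
    clip = torch.rand(1, 3, 8, 64, 64) * 2 - 1   # [B, C, T, H, W]
    return cache_video(clip, save_file=path, fps=8, nrow=1)
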
def save_video_ffmpeg(gen_video_samples, save_path, vocal_audio_list, fps=25, quality=5, high_quality_save=False):

    def save_video(frames, save_path, fps, quality=9, ffmpeg_params=None):
        writer = imageio.get_writer(
            save_path, fps=fps, quality=quality, ffmpeg_params=ffmpeg_params
        )
        for frame in tqdm(frames, desc="Saving video"):
            frame = np.array(frame)
            writer.append_data(frame)
        writer.close()

    save_path_tmp = save_path + "-temp.mp4"

    if high_quality_save:
        cache_video(
            tensor=gen_video_samples.unsqueeze(0),
            save_file=save_path_tmp,
            fps=fps,
            nrow=1,
            normalize=True,
            value_range=(-1, 1))
    else:
        video_audio = (gen_video_samples + 1) / 2  # C T H W, [-1, 1] -> [0, 1]
        video_audio = video_audio.permute(1, 2, 3, 0).cpu().numpy()
        video_audio = np.clip(video_audio * 255, 0, 255).astype(np.uint8)  # to [0, 255]
        save_video(video_audio, save_path_tmp, fps=fps, quality=quality)

    # crop audio according to video length
    _, T, _, _ = gen_video_samples.shape
    duration = T / fps
    save_path_crop_audio = save_path + "-cropaudio.wav"
    final_command = [
        "ffmpeg",
        "-i", vocal_audio_list[0],
        "-t", f'{duration}',
        save_path_crop_audio,
    ]
    subprocess.run(final_command, check=True)

    save_path = save_path + ".mp4"
    if high_quality_save:
        # lossless video (CRF 0), muxed with the cropped audio track
        final_command = [
            "ffmpeg",
            "-y",
            "-i", save_path_tmp,
            "-i", save_path_crop_audio,
            "-c:v", "libx264",
            "-crf", "0",
            "-preset", "veryslow",
            "-c:a", "aac",
            "-shortest",
            save_path,
        ]
    else:
        final_command = [
            "ffmpeg",
            "-y",
            "-i", save_path_tmp,
            "-i", save_path_crop_audio,
            "-c:v", "libx264",
            "-c:a", "aac",
            "-shortest",
            save_path,
        ]
    subprocess.run(final_command, check=True)
    os.remove(save_path_tmp)
    os.remove(save_path_crop_audio)

class MomentumBuffer:
    def __init__(self, momentum: float):
        self.momentum = momentum
        self.running_average = 0

    def update(self, update_value: torch.Tensor):
        new_average = self.momentum * self.running_average
        self.running_average = update_value + new_average


def project(
        v0: torch.Tensor,  # [B, C, T, H, W]
        v1: torch.Tensor,  # [B, C, T, H, W]
        ):
    dtype = v0.dtype
    v0, v1 = v0.double(), v1.double()
    v1 = torch.nn.functional.normalize(v1, dim=[-1, -2, -3, -4])
    v0_parallel = (v0 * v1).sum(dim=[-1, -2, -3, -4], keepdim=True) * v1
    v0_orthogonal = v0 - v0_parallel
    return v0_parallel.to(dtype), v0_orthogonal.to(dtype)


def adaptive_projected_guidance(
        diff: torch.Tensor,  # [B, C, T, H, W]
        pred_cond: torch.Tensor,  # [B, C, T, H, W]
        momentum_buffer: MomentumBuffer = None,
        eta: float = 0.0,
        norm_threshold: float = 55,
        ):
    if momentum_buffer is not None:
        momentum_buffer.update(diff)
        diff = momentum_buffer.running_average
    if norm_threshold > 0:
        ones = torch.ones_like(diff)
        diff_norm = diff.norm(p=2, dim=[-1, -2, -3, -4], keepdim=True)
        print(f"diff_norm: {diff_norm}")
        scale_factor = torch.minimum(ones, norm_threshold / diff_norm)
        diff = diff * scale_factor
    diff_parallel, diff_orthogonal = project(diff, pred_cond)
    normalized_update = diff_orthogonal + eta * diff_parallel
    return normalized_update

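# Illustrative sketch (not part of the original file): one common way adaptive
# projected guidance is combined with classifier-free guidance. The momentum and
# guidance scale below, and the exact wiring, are assumptions rather than the
# pipeline's actual usage.
def _example_adaptive_projected_guidance(pred_cond, pred_uncond, guide_scale=5.0):
    buf = MomentumBuffer(momentum=-0.75)   # hypothetical momentum value
    diff = pred_cond - pred_uncond         # [B, C, T, H, W]
    update = adaptive_projected_guidance(diff, pred_cond, momentum_buffer=buf, eta=0.0)
    return pred_cond + (guide_scale - 1.0) * update
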
							
								
								
									
20  wan/multitalk/torch_utils.py  Normal file
@ -0,0 +1,20 @@
import torch
import torch.nn.functional as F


def get_mask_from_lengths(lengths, max_len=None):
    lengths = lengths.to(torch.long)
    if max_len is None:
        max_len = torch.max(lengths).item()

    ids = torch.arange(0, max_len).unsqueeze(0).expand(lengths.shape[0], -1).to(lengths.device)
    mask = ids < lengths.unsqueeze(1).expand(-1, max_len)

    return mask


def linear_interpolation(features, seq_len):
    features = features.transpose(1, 2)
    output_features = F.interpolate(features, size=seq_len, align_corners=True, mode='linear')
    return output_features.transpose(1, 2)

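# Illustrative usage sketch (not part of the original file): the two helpers above
# are typically used together to mask variable-length audio features and resample
# them to the video frame count; the sizes are hypothetical.
def _example_torch_utils():
    lengths = torch.tensor([3, 5])
    mask = get_mask_from_lengths(lengths)              # [2, 5] boolean mask
    feats = torch.randn(2, 50, 768)                    # [B, audio_steps, dim]
    aligned = linear_interpolation(feats, seq_len=81)  # -> [2, 81, 768]
    return mask, aligned
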
							
								
								
									
125  wan/multitalk/wav2vec2.py  Normal file
@ -0,0 +1,125 @@
from transformers import Wav2Vec2Config, Wav2Vec2Model
from transformers.modeling_outputs import BaseModelOutput

from .torch_utils import linear_interpolation


# The implementation of Wav2Vec2Model is borrowed from
# https://github.com/huggingface/transformers/blob/HEAD/src/transformers/models/wav2vec2/modeling_wav2vec2.py
# and initializes our encoder with the pre-trained wav2vec 2.0 weights.
class Wav2Vec2Model(Wav2Vec2Model):
    def __init__(self, config: Wav2Vec2Config):
        super().__init__(config)

    def forward(
        self,
        input_values,
        seq_len,
        attention_mask=None,
        mask_time_indices=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        self.config.output_attentions = True

        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        extract_features = self.feature_extractor(input_values)
        extract_features = extract_features.transpose(1, 2)
        extract_features = linear_interpolation(extract_features, seq_len=seq_len)

        if attention_mask is not None:
            # compute reduced attention_mask corresponding to feature vectors
            attention_mask = self._get_feature_vector_attention_mask(
                extract_features.shape[1], attention_mask, add_adapter=False
            )

        hidden_states, extract_features = self.feature_projection(extract_features)
        hidden_states = self._mask_hidden_states(
            hidden_states, mask_time_indices=mask_time_indices, attention_mask=attention_mask
        )

        encoder_outputs = self.encoder(
            hidden_states,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        hidden_states = encoder_outputs[0]

        if self.adapter is not None:
            hidden_states = self.adapter(hidden_states)

        if not return_dict:
            return (hidden_states, ) + encoder_outputs[1:]
        return BaseModelOutput(
            last_hidden_state=hidden_states,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )

    def feature_extract(
        self,
        input_values,
        seq_len,
    ):
        extract_features = self.feature_extractor(input_values)
        extract_features = extract_features.transpose(1, 2)
        extract_features = linear_interpolation(extract_features, seq_len=seq_len)

        return extract_features

    def encode(
        self,
        extract_features,
        attention_mask=None,
        mask_time_indices=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        self.config.output_attentions = True

        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if attention_mask is not None:
            # compute reduced attention_mask corresponding to feature vectors
            attention_mask = self._get_feature_vector_attention_mask(
                extract_features.shape[1], attention_mask, add_adapter=False
            )

        hidden_states, extract_features = self.feature_projection(extract_features)
        hidden_states = self._mask_hidden_states(
            hidden_states, mask_time_indices=mask_time_indices, attention_mask=attention_mask
        )

        encoder_outputs = self.encoder(
            hidden_states,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        hidden_states = encoder_outputs[0]

        if self.adapter is not None:
            hidden_states = self.adapter(hidden_states)

        if not return_dict:
            return (hidden_states, ) + encoder_outputs[1:]
        return BaseModelOutput(
            last_hidden_state=hidden_states,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )

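# Illustrative usage sketch (not part of the original file): extracting frame-aligned
# audio features with the wrapper above. The checkpoint name and sizes are assumptions.
def _example_wav2vec_features():
    import torch  # local import: the sketch needs it, the module above does not
    model = Wav2Vec2Model.from_pretrained("facebook/wav2vec2-base-960h")  # hypothetical checkpoint
    model.eval()
    waveform = torch.randn(1, 16000 * 4)                  # 4 s of 16 kHz audio
    feats = model.feature_extract(waveform, seq_len=100)  # align to ~100 video frames
    out = model.encode(feats, return_dict=True)
    return out.last_hidden_state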