Mirror of https://github.com/Wan-Video/Wan2.1.git (synced 2025-11-04 14:16:57 +00:00)

Commit 3a8bd05c6e (parent 621687c12a): multitalk files
wan/multitalk/attention.py (new file, 382 lines)
@@ -0,0 +1,382 @@
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import torch
import torch.nn as nn
from einops import rearrange, repeat
from .multitalk_utils import RotaryPositionalEmbedding1D, normalize_and_scale, split_token_counts_and_frame_ids
from wan.modules.attention import pay_attention

import xformers.ops

try:
    import flash_attn_interface
    FLASH_ATTN_3_AVAILABLE = True
except ModuleNotFoundError:
    FLASH_ATTN_3_AVAILABLE = False

try:
    import flash_attn
    FLASH_ATTN_2_AVAILABLE = True
except ModuleNotFoundError:
    FLASH_ATTN_2_AVAILABLE = False

import warnings

__all__ = [
    'flash_attention',
    'attention',
]


def flash_attention(
    q,
    k,
    v,
    q_lens=None,
    k_lens=None,
    dropout_p=0.,
    softmax_scale=None,
    q_scale=None,
    causal=False,
    window_size=(-1, -1),
    deterministic=False,
    dtype=torch.bfloat16,
    version=None,
):
    """
    q: [B, Lq, Nq, C1].
    k: [B, Lk, Nk, C1].
    v: [B, Lk, Nk, C2]. Nq must be divisible by Nk.
    q_lens: [B].
    k_lens: [B].
    dropout_p: float. Dropout probability.
    softmax_scale: float. The scaling of QK^T before applying softmax.
    causal: bool. Whether to apply causal attention mask.
    window_size: (left, right). If not (-1, -1), apply sliding window local attention.
    deterministic: bool. If True, slightly slower and uses more memory.
    dtype: torch.dtype. Applied when the dtype of q/k/v is not float16/bfloat16.
    """
    half_dtypes = (torch.float16, torch.bfloat16)
    assert dtype in half_dtypes
    assert q.device.type == 'cuda' and q.size(-1) <= 256

    # params
    b, lq, lk, out_dtype = q.size(0), q.size(1), k.size(1), q.dtype

    def half(x):
        return x if x.dtype in half_dtypes else x.to(dtype)

    # preprocess query
    if q_lens is None:
        q = half(q.flatten(0, 1))
        q_lens = torch.tensor(
            [lq] * b, dtype=torch.int32).to(
                device=q.device, non_blocking=True)
    else:
        q = half(torch.cat([u[:v] for u, v in zip(q, q_lens)]))

    # preprocess key, value
    if k_lens is None:
        k = half(k.flatten(0, 1))
        v = half(v.flatten(0, 1))
        k_lens = torch.tensor(
            [lk] * b, dtype=torch.int32).to(
                device=k.device, non_blocking=True)
    else:
        k = half(torch.cat([u[:v] for u, v in zip(k, k_lens)]))
        v = half(torch.cat([u[:v] for u, v in zip(v, k_lens)]))

    q = q.to(v.dtype)
    k = k.to(v.dtype)

    if q_scale is not None:
        q = q * q_scale

    if version is not None and version == 3 and not FLASH_ATTN_3_AVAILABLE:
        warnings.warn(
            'Flash attention 3 is not available, use flash attention 2 instead.'
        )

    # apply attention
    if (version is None or version == 3) and FLASH_ATTN_3_AVAILABLE:
        # Note: dropout_p, window_size are not supported in FA3 now.
        x = flash_attn_interface.flash_attn_varlen_func(
            q=q,
            k=k,
            v=v,
            cu_seqlens_q=torch.cat([q_lens.new_zeros([1]), q_lens]).cumsum(
                0, dtype=torch.int32).to(q.device, non_blocking=True),
            cu_seqlens_k=torch.cat([k_lens.new_zeros([1]), k_lens]).cumsum(
                0, dtype=torch.int32).to(q.device, non_blocking=True),
            seqused_q=None,
            seqused_k=None,
            max_seqlen_q=lq,
            max_seqlen_k=lk,
            softmax_scale=softmax_scale,
            causal=causal,
            deterministic=deterministic)[0].unflatten(0, (b, lq))
    else:
        assert FLASH_ATTN_2_AVAILABLE
        x = flash_attn.flash_attn_varlen_func(
            q=q,
            k=k,
            v=v,
            cu_seqlens_q=torch.cat([q_lens.new_zeros([1]), q_lens]).cumsum(
                0, dtype=torch.int32).to(q.device, non_blocking=True),
            cu_seqlens_k=torch.cat([k_lens.new_zeros([1]), k_lens]).cumsum(
                0, dtype=torch.int32).to(q.device, non_blocking=True),
            max_seqlen_q=lq,
            max_seqlen_k=lk,
            dropout_p=dropout_p,
            softmax_scale=softmax_scale,
            causal=causal,
            window_size=window_size,
            deterministic=deterministic).unflatten(0, (b, lq))

    # output
    return x.type(out_dtype)


def attention(
    q,
    k,
    v,
    q_lens=None,
    k_lens=None,
    dropout_p=0.,
    softmax_scale=None,
    q_scale=None,
    causal=False,
    window_size=(-1, -1),
    deterministic=False,
    dtype=torch.bfloat16,
    fa_version=None,
):
    if FLASH_ATTN_2_AVAILABLE or FLASH_ATTN_3_AVAILABLE:
        return flash_attention(
            q=q,
            k=k,
            v=v,
            q_lens=q_lens,
            k_lens=k_lens,
            dropout_p=dropout_p,
            softmax_scale=softmax_scale,
            q_scale=q_scale,
            causal=causal,
            window_size=window_size,
            deterministic=deterministic,
            dtype=dtype,
            version=fa_version,
        )
    else:
        if q_lens is not None or k_lens is not None:
            warnings.warn(
                'Padding mask is disabled when using scaled_dot_product_attention. It can have a significant impact on performance.'
            )
        attn_mask = None

        q = q.transpose(1, 2).to(dtype)
        k = k.transpose(1, 2).to(dtype)
        v = v.transpose(1, 2).to(dtype)

        out = torch.nn.functional.scaled_dot_product_attention(
            q, k, v, attn_mask=attn_mask, is_causal=causal, dropout_p=dropout_p)

        out = out.transpose(1, 2).contiguous()
        return out
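
# A minimal usage sketch for the `attention` dispatcher above (illustrative
# only, not part of the original file). Shapes follow the docstring of
# `flash_attention`; it assumes a CUDA device with flash-attention installed
# and head_dim <= 256, which that code path requires.
def _attention_usage_sketch():
    b, lq, lk, heads, head_dim = 2, 128, 77, 8, 64
    q = torch.randn(b, lq, heads, head_dim, device='cuda', dtype=torch.bfloat16)
    k = torch.randn(b, lk, heads, head_dim, device='cuda', dtype=torch.bfloat16)
    v = torch.randn(b, lk, heads, head_dim, device='cuda', dtype=torch.bfloat16)
    # per-sample valid key lengths; queries are used at full length
    k_lens = torch.tensor([lk, lk // 2], dtype=torch.int32, device='cuda')
    out = attention(q, k, v, k_lens=k_lens)
    return out  # [B, Lq, Nq, C], cast back to the input dtype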

class SingleStreamAttention(nn.Module):
    def __init__(
        self,
        dim: int,
        encoder_hidden_states_dim: int,
        num_heads: int,
        qkv_bias: bool,
        qk_norm: bool,
        norm_layer: nn.Module,
        attn_drop: float = 0.0,
        proj_drop: float = 0.0,
        eps: float = 1e-6,
    ) -> None:
        super().__init__()
        assert dim % num_heads == 0, "dim should be divisible by num_heads"
        self.dim = dim
        self.encoder_hidden_states_dim = encoder_hidden_states_dim
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim**-0.5
        self.qk_norm = qk_norm

        self.q_linear = nn.Linear(dim, dim, bias=qkv_bias)

        self.q_norm = norm_layer(self.head_dim, eps=eps) if qk_norm else nn.Identity()
        self.k_norm = norm_layer(self.head_dim, eps=eps) if qk_norm else nn.Identity()

        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        self.kv_linear = nn.Linear(encoder_hidden_states_dim, dim * 2, bias=qkv_bias)

        self.add_q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
        self.add_k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()

    def forward(self, x: torch.Tensor, encoder_hidden_states: torch.Tensor, shape=None, enable_sp=False, kv_seq=None) -> torch.Tensor:
        N_t, N_h, N_w = shape

        x = rearrange(x, "B (N_t S) C -> (B N_t) S C", N_t=N_t)
        # get q for hidden_state
        B, N, C = x.shape
        q = self.q_linear(x)
        q_shape = (B, N, self.num_heads, self.head_dim)
        q = q.view(q_shape).permute((0, 2, 1, 3))

        if self.qk_norm:
            q = self.q_norm(q)

        # get kv from encoder_hidden_states
        _, N_a, _ = encoder_hidden_states.shape
        encoder_kv = self.kv_linear(encoder_hidden_states)
        encoder_kv_shape = (B, N_a, 2, self.num_heads, self.head_dim)
        encoder_kv = encoder_kv.view(encoder_kv_shape).permute((2, 0, 3, 1, 4))
        encoder_k, encoder_v = encoder_kv.unbind(0)

        if self.qk_norm:
            encoder_k = self.add_k_norm(encoder_k)

        q = rearrange(q, "B H M K -> B M H K")
        encoder_k = rearrange(encoder_k, "B H M K -> B M H K")
        encoder_v = rearrange(encoder_v, "B H M K -> B M H K")

        attn_bias = None
        # x = xformers.ops.memory_efficient_attention(q, encoder_k, encoder_v, attn_bias=attn_bias, op=None,)
        qkv_list = [q, encoder_k, encoder_v]
        q = encoder_k = encoder_v = None
        x = pay_attention(qkv_list)
        x = rearrange(x, "B M H K -> B H M K")

        # linear transform
        x_output_shape = (B, N, C)
        x = x.transpose(1, 2)
        x = x.reshape(x_output_shape)
        x = self.proj(x)
        x = self.proj_drop(x)

        # reshape x to origin shape
        x = rearrange(x, "(B N_t) S C -> B (N_t S) C", N_t=N_t)

        return x
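
# A construction sketch for the audio cross-attention block above (illustrative
# only, not part of the original file). The sizes are placeholders; the point is
# the shape contract: `x` carries B * (N_t * N_h * N_w) video tokens, while the
# audio tokens arrive with one batch entry per latent frame. The underlying
# pay_attention kernel may require a GPU, so treat this purely as a shape sketch.
def _single_stream_attention_sketch():
    N_t, N_h, N_w = 4, 8, 8
    block = SingleStreamAttention(
        dim=512,
        encoder_hidden_states_dim=768,
        num_heads=8,
        qkv_bias=True,
        qk_norm=True,
        norm_layer=nn.LayerNorm,
    )
    x = torch.randn(1, N_t * N_h * N_w, 512)    # video tokens, [B, N_t*S, C]
    audio = torch.randn(N_t, 32, 768)           # per-frame audio tokens, [B*N_t, N_a, C_a]
    out = block(x, audio, shape=(N_t, N_h, N_w))
    return out                                  # [B, N_t*S, C], same shape as x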

class SingleStreamMutiAttention(SingleStreamAttention):
    def __init__(
        self,
        dim: int,
        encoder_hidden_states_dim: int,
        num_heads: int,
        qkv_bias: bool,
        qk_norm: bool,
        norm_layer: nn.Module,
        attn_drop: float = 0.0,
        proj_drop: float = 0.0,
        eps: float = 1e-6,
        class_range: int = 24,
        class_interval: int = 4,
    ) -> None:
        super().__init__(
            dim=dim,
            encoder_hidden_states_dim=encoder_hidden_states_dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_norm=qk_norm,
            norm_layer=norm_layer,
            attn_drop=attn_drop,
            proj_drop=proj_drop,
            eps=eps,
        )
        self.class_interval = class_interval
        self.class_range = class_range
        self.rope_h1 = (0, self.class_interval)
        self.rope_h2 = (self.class_range - self.class_interval, self.class_range)
        self.rope_bak = int(self.class_range // 2)

        self.rope_1d = RotaryPositionalEmbedding1D(self.head_dim)

    def forward(self,
                x: torch.Tensor,
                encoder_hidden_states: torch.Tensor,
                shape=None,
                x_ref_attn_map=None,
                ) -> torch.Tensor:

        encoder_hidden_states = encoder_hidden_states.squeeze(0)
        if x_ref_attn_map is None:
            return super().forward(x, encoder_hidden_states, shape)

        N_t, _, _ = shape
        x = rearrange(x, "B (N_t S) C -> (B N_t) S C", N_t=N_t)

        # get q for hidden_state
        B, N, C = x.shape
        q = self.q_linear(x)
        q_shape = (B, N, self.num_heads, self.head_dim)
        q = q.view(q_shape).permute((0, 2, 1, 3))

        if self.qk_norm:
            q = self.q_norm(q)

        max_values = x_ref_attn_map.max(1).values[:, None, None]
        min_values = x_ref_attn_map.min(1).values[:, None, None]
        max_min_values = torch.cat([max_values, min_values], dim=2)

        human1_max_value, human1_min_value = max_min_values[0, :, 0].max(), max_min_values[0, :, 1].min()
        human2_max_value, human2_min_value = max_min_values[1, :, 0].max(), max_min_values[1, :, 1].min()

        human1 = normalize_and_scale(x_ref_attn_map[0], (human1_min_value, human1_max_value), (self.rope_h1[0], self.rope_h1[1]))
        human2 = normalize_and_scale(x_ref_attn_map[1], (human2_min_value, human2_max_value), (self.rope_h2[0], self.rope_h2[1]))
        back = torch.full((x_ref_attn_map.size(1),), self.rope_bak, dtype=human1.dtype).to(human1.device)
        max_indices = x_ref_attn_map.argmax(dim=0)
        normalized_map = torch.stack([human1, human2, back], dim=1)
        normalized_pos = normalized_map[range(x_ref_attn_map.size(1)), max_indices]  # N

        q = rearrange(q, "(B N_t) H S C -> B H (N_t S) C", N_t=N_t)
        q = self.rope_1d(q, normalized_pos)
        q = rearrange(q, "B H (N_t S) C -> (B N_t) H S C", N_t=N_t)

        _, N_a, _ = encoder_hidden_states.shape
        encoder_kv = self.kv_linear(encoder_hidden_states)
        encoder_kv_shape = (B, N_a, 2, self.num_heads, self.head_dim)
        encoder_kv = encoder_kv.view(encoder_kv_shape).permute((2, 0, 3, 1, 4))
        encoder_k, encoder_v = encoder_kv.unbind(0)

        if self.qk_norm:
            encoder_k = self.add_k_norm(encoder_k)

        per_frame = torch.zeros(N_a, dtype=encoder_k.dtype).to(encoder_k.device)
        per_frame[:per_frame.size(0)//2] = (self.rope_h1[0] + self.rope_h1[1]) / 2
        per_frame[per_frame.size(0)//2:] = (self.rope_h2[0] + self.rope_h2[1]) / 2
        encoder_pos = torch.concat([per_frame]*N_t, dim=0)
        encoder_k = rearrange(encoder_k, "(B N_t) H S C -> B H (N_t S) C", N_t=N_t)
        encoder_k = self.rope_1d(encoder_k, encoder_pos)
        encoder_k = rearrange(encoder_k, "B H (N_t S) C -> (B N_t) H S C", N_t=N_t)

        q = rearrange(q, "B H M K -> B M H K")
        encoder_k = rearrange(encoder_k, "B H M K -> B M H K")
        encoder_v = rearrange(encoder_v, "B H M K -> B M H K")
        # x = xformers.ops.memory_efficient_attention(q, encoder_k, encoder_v, attn_bias=None, op=None,)
        qkv_list = [q, encoder_k, encoder_v]
        q = encoder_k = encoder_v = None
        x = pay_attention(qkv_list)

        x = rearrange(x, "B M H K -> B H M K")

        # linear transform
        x_output_shape = (B, N, C)
        x = x.transpose(1, 2)
        x = x.reshape(x_output_shape)
        x = self.proj(x)
        x = self.proj_drop(x)

        # reshape x to origin shape
        x = rearrange(x, "(B N_t) S C -> B (N_t S) C", N_t=N_t)

        return x
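
# A toy sketch of the positional assignment used in SingleStreamMutiAttention
# above (illustrative only, not part of the original file). Each video token is
# pushed into the RoPE band of whichever speaker's reference-attention map wins
# the argmax, mirroring the forward pass; it assumes `normalize_and_scale` maps
# the given source range onto the target range, as it is used above.
def _rope_band_sketch():
    class_range, class_interval = 24, 4
    tokens = 16
    x_ref_attn_map = torch.rand(2, tokens)                     # [num_speakers, tokens]
    human1 = normalize_and_scale(x_ref_attn_map[0],
                                 (x_ref_attn_map[0].min(), x_ref_attn_map[0].max()),
                                 (0, class_interval))
    human2 = normalize_and_scale(x_ref_attn_map[1],
                                 (x_ref_attn_map[1].min(), x_ref_attn_map[1].max()),
                                 (class_range - class_interval, class_range))
    back = torch.full((tokens,), class_range // 2, dtype=human1.dtype)
    normalized_map = torch.stack([human1, human2, back], dim=1)
    max_indices = x_ref_attn_map.argmax(dim=0)
    return normalized_map[range(tokens), max_indices]          # one RoPE position per token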
wan/multitalk/kokoro/__init__.py (new file, 23 lines)
@@ -0,0 +1,23 @@
__version__ = '0.9.4'

from loguru import logger
import sys

# Remove default handler
logger.remove()

# Add custom handler with clean format including module and line number
logger.add(
    sys.stderr,
    format="<green>{time:HH:mm:ss}</green> | <cyan>{module:>16}:{line}</cyan> | <level>{level: >8}</level> | <level>{message}</level>",
    colorize=True,
    level="INFO"  # "DEBUG" to enable logger.debug("message") and up prints
                  # "ERROR" to enable only logger.error("message") prints
                  # etc
)

# Disable before release or as needed
logger.disable("kokoro")

from .model import KModel
from .pipeline import KPipeline
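
# Usage note (illustrative, not part of the original file): the package disables
# its own logger above, so an application that wants to see Kokoro's log output
# has to re-enable the namespace explicitly, e.g.:
#
#   from loguru import logger
#   logger.enable("kokoro")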
wan/multitalk/kokoro/__main__.py (new file, 148 lines)
@@ -0,0 +1,148 @@
"""Kokoro TTS CLI
Example usage:
python3 -m kokoro --text "The sky above the port was the color of television, tuned to a dead channel." -o file.wav --debug

echo "Bom dia mundo, como vão vocês" > text.txt
python3 -m kokoro -i text.txt -l p --voice pm_alex > audio.wav

Common issues:
pip not installed: `uv pip install pip`
(Temporary workaround while https://github.com/explosion/spaCy/issues/13747 is not fixed)

espeak not installed: `apt-get install espeak-ng`
"""

import argparse
import wave
from pathlib import Path
from typing import Generator, TYPE_CHECKING

import numpy as np
from loguru import logger

languages = [
    "a",  # American English
    "b",  # British English
    "h",  # Hindi
    "e",  # Spanish
    "f",  # French
    "i",  # Italian
    "p",  # Brazilian Portuguese
    "j",  # Japanese
    "z",  # Mandarin Chinese
]

if TYPE_CHECKING:
    from kokoro import KPipeline


def generate_audio(
    text: str, kokoro_language: str, voice: str, speed=1
) -> Generator["KPipeline.Result", None, None]:
    from kokoro import KPipeline

    if not voice.startswith(kokoro_language):
        logger.warning(f"Voice {voice} is not made for language {kokoro_language}")
    pipeline = KPipeline(lang_code=kokoro_language)
    yield from pipeline(text, voice=voice, speed=speed, split_pattern=r"\n+")


def generate_and_save_audio(
    output_file: Path, text: str, kokoro_language: str, voice: str, speed=1
) -> None:
    with wave.open(str(output_file.resolve()), "wb") as wav_file:
        wav_file.setnchannels(1)   # Mono audio
        wav_file.setsampwidth(2)   # 2 bytes per sample (16-bit audio)
        wav_file.setframerate(24000)  # Sample rate

        for result in generate_audio(
            text, kokoro_language=kokoro_language, voice=voice, speed=speed
        ):
            logger.debug(result.phonemes)
            if result.audio is None:
                continue
            audio_bytes = (result.audio.numpy() * 32767).astype(np.int16).tobytes()
            wav_file.writeframes(audio_bytes)
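
# A minimal programmatic sketch (illustrative only, not part of the original
# file): synthesize a short English clip with the helper above, bypassing the
# argparse front end. It assumes the kokoro package and its default voice
# weights are available locally or downloadable.
def _save_sample_clip():
    generate_and_save_audio(
        output_file=Path("sample.wav"),
        text="The sky above the port was the color of television.",
        kokoro_language="a",   # American English, see `languages` above
        voice="af_heart",      # the CLI default voice
        speed=1.0,
    )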

def main() -> None:
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-m",
        "--voice",
        default="af_heart",
        help="Voice to use",
    )
    parser.add_argument(
        "-l",
        "--language",
        help="Language to use (defaults to the one corresponding to the voice)",
        choices=languages,
    )
    parser.add_argument(
        "-o",
        "--output-file",
        "--output_file",
        type=Path,
        help="Path to output WAV file",
        required=True,
    )
    parser.add_argument(
        "-i",
        "--input-file",
        "--input_file",
        type=Path,
        help="Path to input text file (default: stdin)",
    )
    parser.add_argument(
        "-t",
        "--text",
        help="Text to use instead of reading from stdin",
    )
    parser.add_argument(
        "-s",
        "--speed",
        type=float,
        default=1.0,
        help="Speech speed",
    )
    parser.add_argument(
        "--debug",
        action="store_true",
        help="Print DEBUG messages to console",
    )
    args = parser.parse_args()
    if args.debug:
        logger.level("DEBUG")
    logger.debug(args)

    lang = args.language or args.voice[0]

    if args.text is not None and args.input_file is not None:
        raise Exception("You cannot specify both 'text' and 'input_file'")
    elif args.text:
        text = args.text
    elif args.input_file:
        file: Path = args.input_file
        text = file.read_text()
    else:
        import sys
        print("Press Ctrl+D to stop reading input and start generating", flush=True)
        text = '\n'.join(sys.stdin)

    logger.debug(f"Input text: {text!r}")

    out_file: Path = args.output_file
    if not out_file.suffix == ".wav":
        logger.warning("The output file name should end with .wav")
    generate_and_save_audio(
        output_file=out_file,
        text=text,
        kokoro_language=lang,
        voice=args.voice,
        speed=args.speed,
    )


if __name__ == "__main__":
    main()
wan/multitalk/kokoro/custom_stft.py (new file, 197 lines)
@@ -0,0 +1,197 @@
from attr import attr
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

class CustomSTFT(nn.Module):
    """
    STFT/iSTFT without unfold/complex ops, using conv1d and conv_transpose1d.

    - forward STFT => Real-part conv1d + Imag-part conv1d
    - inverse STFT => Real-part conv_transpose1d + Imag-part conv_transpose1d + sum
    - avoids F.unfold, so easier to export to ONNX
    - uses replicate or constant padding for 'center=True' to approximate 'reflect'
      (reflect is not supported for dynamic shapes in ONNX)
    """

    def __init__(
        self,
        filter_length=800,
        hop_length=200,
        win_length=800,
        window="hann",
        center=True,
        pad_mode="replicate",  # or 'constant'
    ):
        super().__init__()
        self.filter_length = filter_length
        self.hop_length = hop_length
        self.win_length = win_length
        self.n_fft = filter_length
        self.center = center
        self.pad_mode = pad_mode

        # Number of frequency bins for real-valued STFT with onesided=True
        self.freq_bins = self.n_fft // 2 + 1

        # Build window
        assert window == 'hann', window
        window_tensor = torch.hann_window(win_length, periodic=True, dtype=torch.float32)
        if self.win_length < self.n_fft:
            # Zero-pad up to n_fft
            extra = self.n_fft - self.win_length
            window_tensor = F.pad(window_tensor, (0, extra))
        elif self.win_length > self.n_fft:
            window_tensor = window_tensor[: self.n_fft]
        self.register_buffer("window", window_tensor)

        # Precompute forward DFT (real, imag)
        # PyTorch stft uses e^{-j 2 pi k n / N} => real=cos(...), imag=-sin(...)
        n = np.arange(self.n_fft)
        k = np.arange(self.freq_bins)
        angle = 2 * np.pi * np.outer(k, n) / self.n_fft  # shape (freq_bins, n_fft)
        dft_real = np.cos(angle)
        dft_imag = -np.sin(angle)  # note negative sign

        # Combine window and dft => shape (freq_bins, filter_length)
        # We'll make 2 conv weight tensors of shape (freq_bins, 1, filter_length).
        forward_window = window_tensor.numpy()  # shape (n_fft,)
        forward_real = dft_real * forward_window  # (freq_bins, n_fft)
        forward_imag = dft_imag * forward_window

        # Convert to PyTorch
        forward_real_torch = torch.from_numpy(forward_real).float()
        forward_imag_torch = torch.from_numpy(forward_imag).float()

        # Register as Conv1d weight => (out_channels, in_channels, kernel_size)
        # out_channels = freq_bins, in_channels=1, kernel_size=n_fft
        self.register_buffer(
            "weight_forward_real", forward_real_torch.unsqueeze(1)
        )
        self.register_buffer(
            "weight_forward_imag", forward_imag_torch.unsqueeze(1)
        )

        # Precompute inverse DFT
        # Real iFFT formula => scale = 1/n_fft, doubling for bins 1..freq_bins-2 if n_fft even, etc.
        # For simplicity, we won't do the "DC/nyquist not doubled" approach here.
        # If you want perfect real iSTFT, you can add that logic.
        # This version just yields good approximate reconstruction with Hann + typical overlap.
        inv_scale = 1.0 / self.n_fft
        n = np.arange(self.n_fft)
        angle_t = 2 * np.pi * np.outer(n, k) / self.n_fft  # shape (n_fft, freq_bins)
        idft_cos = np.cos(angle_t).T  # => (freq_bins, n_fft)
        idft_sin = np.sin(angle_t).T  # => (freq_bins, n_fft)

        # Multiply by window again for typical overlap-add
        # We also incorporate the scale factor 1/n_fft
        inv_window = window_tensor.numpy() * inv_scale
        backward_real = idft_cos * inv_window  # (freq_bins, n_fft)
        backward_imag = idft_sin * inv_window

        # We'll implement iSTFT as real+imag conv_transpose with stride=hop.
        self.register_buffer(
            "weight_backward_real", torch.from_numpy(backward_real).float().unsqueeze(1)
        )
        self.register_buffer(
            "weight_backward_imag", torch.from_numpy(backward_imag).float().unsqueeze(1)
        )


    def transform(self, waveform: torch.Tensor):
        """
        Forward STFT => returns magnitude, phase
        Output shape => (batch, freq_bins, frames)
        """
        # waveform shape => (B, T). conv1d expects (B, 1, T).
        # Optional center pad
        if self.center:
            pad_len = self.n_fft // 2
            waveform = F.pad(waveform, (pad_len, pad_len), mode=self.pad_mode)

        x = waveform.unsqueeze(1)  # => (B, 1, T)
        # Convolution to get real part => shape (B, freq_bins, frames)
        real_out = F.conv1d(
            x,
            self.weight_forward_real,
            bias=None,
            stride=self.hop_length,
            padding=0,
        )
        # Imag part
        imag_out = F.conv1d(
            x,
            self.weight_forward_imag,
            bias=None,
            stride=self.hop_length,
            padding=0,
        )

        # magnitude, phase
        magnitude = torch.sqrt(real_out**2 + imag_out**2 + 1e-14)
        phase = torch.atan2(imag_out, real_out)
        # Handle the case where imag_out is 0 and real_out is negative to correct ONNX atan2 to match PyTorch
        # In this case, PyTorch returns pi, ONNX returns -pi
        correction_mask = (imag_out == 0) & (real_out < 0)
        phase[correction_mask] = torch.pi
        return magnitude, phase


    def inverse(self, magnitude: torch.Tensor, phase: torch.Tensor, length=None):
        """
        Inverse STFT => returns waveform shape (B, T).
        """
        # magnitude, phase => (B, freq_bins, frames)
        # Re-create real/imag => shape (B, freq_bins, frames)
        real_part = magnitude * torch.cos(phase)
        imag_part = magnitude * torch.sin(phase)

        # conv_transpose wants shape (B, freq_bins, frames). We'll treat "frames" as time dimension
        # so we do (B, freq_bins, frames) => (B, freq_bins, frames)
        # But PyTorch conv_transpose1d expects (B, in_channels, input_length)
        real_part = real_part  # (B, freq_bins, frames)
        imag_part = imag_part

        # real iSTFT => convolve with "backward_real", "backward_imag", and sum
        # We'll do 2 conv_transpose calls, each giving (B, 1, time),
        # then add them => (B, 1, time).
        real_rec = F.conv_transpose1d(
            real_part,
            self.weight_backward_real,  # shape (freq_bins, 1, filter_length)
            bias=None,
            stride=self.hop_length,
            padding=0,
        )
        imag_rec = F.conv_transpose1d(
            imag_part,
            self.weight_backward_imag,
            bias=None,
            stride=self.hop_length,
            padding=0,
        )
        # sum => (B, 1, time)
        waveform = real_rec - imag_rec  # typical real iFFT has minus for imaginary part

        # If we used "center=True" in forward, we should remove pad
        if self.center:
            pad_len = self.n_fft // 2
            # Because of transposed convolution, total length might have extra samples
            # We remove `pad_len` from start & end if possible
            waveform = waveform[..., pad_len:-pad_len]

        # If a specific length is desired, clamp
        if length is not None:
            waveform = waveform[..., :length]

        # shape => (B, T)
        return waveform

    def forward(self, x: torch.Tensor):
        """
        Full STFT -> iSTFT pass: returns time-domain reconstruction.
        Same interface as your original code.
        """
        mag, phase = self.transform(x)
        return self.inverse(mag, phase, length=x.shape[-1])
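
# A round-trip sketch for the class above (illustrative only, not part of the
# original file): analyze one second of noise at 24 kHz and resynthesize it,
# trimming the output back to the input length.
def _custom_stft_roundtrip():
    stft = CustomSTFT(filter_length=800, hop_length=200, win_length=800)
    wav = torch.randn(1, 24000)                  # (B, T)
    magnitude, phase = stft.transform(wav)       # (B, freq_bins, frames)
    return stft.inverse(magnitude, phase, length=wav.shape[-1])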
wan/multitalk/kokoro/istftnet.py (new file, 421 lines)
@@ -0,0 +1,421 @@
# ADAPTED from https://github.com/yl4579/StyleTTS2/blob/main/Modules/istftnet.py
from .custom_stft import CustomSTFT
from torch.nn.utils import weight_norm
import math
import torch
import torch.nn as nn
import torch.nn.functional as F


# https://github.com/yl4579/StyleTTS2/blob/main/Modules/utils.py
def init_weights(m, mean=0.0, std=0.01):
    classname = m.__class__.__name__
    if classname.find("Conv") != -1:
        m.weight.data.normal_(mean, std)

def get_padding(kernel_size, dilation=1):
    return int((kernel_size*dilation - dilation)/2)


class AdaIN1d(nn.Module):
    def __init__(self, style_dim, num_features):
        super().__init__()
        # affine should be False; however, a bug in the old torch.onnx.export (not the newer
        # dynamo exporter) drops the channel dimension when affine=False. Setting affine=True
        # only adds extra learnable parameters, which is harmless since we run in inference mode.
        self.norm = nn.InstanceNorm1d(num_features, affine=True)
        self.fc = nn.Linear(style_dim, num_features*2)

    def forward(self, x, s):
        h = self.fc(s)
        h = h.view(h.size(0), h.size(1), 1)
        gamma, beta = torch.chunk(h, chunks=2, dim=1)
        return (1 + gamma) * self.norm(x) + beta
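
# A shape sketch for the style-conditioned normalization above (illustrative
# only, not part of the original file): the style vector is projected to a
# per-channel (gamma, beta) pair and applied on top of instance normalization.
def _adain1d_sketch():
    adain = AdaIN1d(style_dim=128, num_features=64)
    x = torch.randn(2, 64, 200)    # (B, channels, time)
    s = torch.randn(2, 128)        # (B, style_dim)
    return adain(x, s)             # same shape as x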

class AdaINResBlock1(nn.Module):
    def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5), style_dim=64):
        super(AdaINResBlock1, self).__init__()
        self.convs1 = nn.ModuleList([
            weight_norm(nn.Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0],
                                  padding=get_padding(kernel_size, dilation[0]))),
            weight_norm(nn.Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1],
                                  padding=get_padding(kernel_size, dilation[1]))),
            weight_norm(nn.Conv1d(channels, channels, kernel_size, 1, dilation=dilation[2],
                                  padding=get_padding(kernel_size, dilation[2])))
        ])
        self.convs1.apply(init_weights)
        self.convs2 = nn.ModuleList([
            weight_norm(nn.Conv1d(channels, channels, kernel_size, 1, dilation=1,
                                  padding=get_padding(kernel_size, 1))),
            weight_norm(nn.Conv1d(channels, channels, kernel_size, 1, dilation=1,
                                  padding=get_padding(kernel_size, 1))),
            weight_norm(nn.Conv1d(channels, channels, kernel_size, 1, dilation=1,
                                  padding=get_padding(kernel_size, 1)))
        ])
        self.convs2.apply(init_weights)
        self.adain1 = nn.ModuleList([
            AdaIN1d(style_dim, channels),
            AdaIN1d(style_dim, channels),
            AdaIN1d(style_dim, channels),
        ])
        self.adain2 = nn.ModuleList([
            AdaIN1d(style_dim, channels),
            AdaIN1d(style_dim, channels),
            AdaIN1d(style_dim, channels),
        ])
        self.alpha1 = nn.ParameterList([nn.Parameter(torch.ones(1, channels, 1)) for i in range(len(self.convs1))])
        self.alpha2 = nn.ParameterList([nn.Parameter(torch.ones(1, channels, 1)) for i in range(len(self.convs2))])

    def forward(self, x, s):
        for c1, c2, n1, n2, a1, a2 in zip(self.convs1, self.convs2, self.adain1, self.adain2, self.alpha1, self.alpha2):
            xt = n1(x, s)
            xt = xt + (1 / a1) * (torch.sin(a1 * xt) ** 2)  # Snake1D
            xt = c1(xt)
            xt = n2(xt, s)
            xt = xt + (1 / a2) * (torch.sin(a2 * xt) ** 2)  # Snake1D
            xt = c2(xt)
            x = xt + x
        return x


class TorchSTFT(nn.Module):
    def __init__(self, filter_length=800, hop_length=200, win_length=800, window='hann'):
        super().__init__()
        self.filter_length = filter_length
        self.hop_length = hop_length
        self.win_length = win_length
        assert window == 'hann', window
        self.window = torch.hann_window(win_length, periodic=True, dtype=torch.float32)

    def transform(self, input_data):
        forward_transform = torch.stft(
            input_data,
            self.filter_length, self.hop_length, self.win_length, window=self.window.to(input_data.device),
            return_complex=True)
        return torch.abs(forward_transform), torch.angle(forward_transform)

    def inverse(self, magnitude, phase):
        inverse_transform = torch.istft(
            magnitude * torch.exp(phase * 1j),
            self.filter_length, self.hop_length, self.win_length, window=self.window.to(magnitude.device))
        return inverse_transform.unsqueeze(-2)  # unsqueeze to stay consistent with conv_transpose1d implementation

    def forward(self, input_data):
        self.magnitude, self.phase = self.transform(input_data)
        reconstruction = self.inverse(self.magnitude, self.phase)
        return reconstruction


class SineGen(nn.Module):
    """ Definition of sine generator
    SineGen(samp_rate, harmonic_num = 0,
            sine_amp = 0.1, noise_std = 0.003,
            voiced_threshold = 0,
            flag_for_pulse=False)
    samp_rate: sampling rate in Hz
    harmonic_num: number of harmonic overtones (default 0)
    sine_amp: amplitude of sine waveform (default 0.1)
    noise_std: std of Gaussian noise (default 0.003)
    voiced_threshold: F0 threshold for U/V classification (default 0)
    flag_for_pulse: this SineGen is used inside PulseGen (default False)
    Note: when flag_for_pulse is True, the first time step of a voiced
    segment is always sin(torch.pi) or cos(0)
    """
    def __init__(self, samp_rate, upsample_scale, harmonic_num=0,
                 sine_amp=0.1, noise_std=0.003,
                 voiced_threshold=0,
                 flag_for_pulse=False):
        super(SineGen, self).__init__()
        self.sine_amp = sine_amp
        self.noise_std = noise_std
        self.harmonic_num = harmonic_num
        self.dim = self.harmonic_num + 1
        self.sampling_rate = samp_rate
        self.voiced_threshold = voiced_threshold
        self.flag_for_pulse = flag_for_pulse
        self.upsample_scale = upsample_scale

    def _f02uv(self, f0):
        # generate uv signal
        uv = (f0 > self.voiced_threshold).type(torch.float32)
        return uv

    def _f02sine(self, f0_values):
        """ f0_values: (batchsize, length, dim)
            where dim indicates fundamental tone and overtones
        """
        # convert F0 to radians. The integer part n can be ignored
        # because 2 * torch.pi * n doesn't affect phase
        rad_values = (f0_values / self.sampling_rate) % 1
        # initial phase noise (no noise for fundamental component)
        rand_ini = torch.rand(f0_values.shape[0], f0_values.shape[2], device=f0_values.device)
        rand_ini[:, 0] = 0
        rad_values[:, 0, :] = rad_values[:, 0, :] + rand_ini
        # instantaneous phase sine[t] = sin(2*pi \sum_i=1 ^{t} rad)
        if not self.flag_for_pulse:
            rad_values = F.interpolate(rad_values.transpose(1, 2), scale_factor=1/self.upsample_scale, mode="linear").transpose(1, 2)
            phase = torch.cumsum(rad_values, dim=1) * 2 * torch.pi
            phase = F.interpolate(phase.transpose(1, 2) * self.upsample_scale, scale_factor=self.upsample_scale, mode="linear").transpose(1, 2)
            sines = torch.sin(phase)
        else:
            # If necessary, make sure that the first time step of every
            # voiced segment is sin(pi) or cos(0).
            # This is used for pulse-train generation.
            # identify the last time step in unvoiced segments
            uv = self._f02uv(f0_values)
            uv_1 = torch.roll(uv, shifts=-1, dims=1)
            uv_1[:, -1, :] = 1
            u_loc = (uv < 1) * (uv_1 > 0)
            # get the instantaneous phase
            tmp_cumsum = torch.cumsum(rad_values, dim=1)
            # different batch entries need to be processed differently
            for idx in range(f0_values.shape[0]):
                temp_sum = tmp_cumsum[idx, u_loc[idx, :, 0], :]
                temp_sum[1:, :] = temp_sum[1:, :] - temp_sum[0:-1, :]
                # stores the accumulation of instantaneous phase within
                # each voiced segment
                tmp_cumsum[idx, :, :] = 0
                tmp_cumsum[idx, u_loc[idx, :, 0], :] = temp_sum
            # rad_values - tmp_cumsum: remove the accumulation of instantaneous
            # phase within the previous voiced segment.
            i_phase = torch.cumsum(rad_values - tmp_cumsum, dim=1)
            # get the sines
            sines = torch.cos(i_phase * 2 * torch.pi)
        return sines

    def forward(self, f0):
        """ sine_tensor, uv = forward(f0)
        input F0: tensor(batchsize=1, length, dim=1)
            f0 for unvoiced steps should be 0
        output sine_tensor: tensor(batchsize=1, length, dim)
        output uv: tensor(batchsize=1, length, 1)
        """
        f0_buf = torch.zeros(f0.shape[0], f0.shape[1], self.dim, device=f0.device)
        # fundamental component
        fn = torch.multiply(f0, torch.FloatTensor([[range(1, self.harmonic_num + 2)]]).to(f0.device))
        # generate sine waveforms
        sine_waves = self._f02sine(fn) * self.sine_amp
        # generate uv signal
        # uv = torch.ones(f0.shape)
        # uv = uv * (f0 > self.voiced_threshold)
        uv = self._f02uv(f0)
        # noise: for unvoiced should be similar to sine_amp
        # std = self.sine_amp/3 -> max value ~ self.sine_amp
        # for voiced regions is self.noise_std
        noise_amp = uv * self.noise_std + (1 - uv) * self.sine_amp / 3
        noise = noise_amp * torch.randn_like(sine_waves)
        # first: set the unvoiced part to 0 by uv
        # then: additive noise
        sine_waves = sine_waves * uv + noise
        return sine_waves, uv, noise
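
# A shape sketch for the sine source above (illustrative only, not part of the
# original file). The rates are placeholders chosen so that the internal
# down/up-sampling divides evenly: an upsampled F0 contour in Hz goes in,
# per-harmonic sine waves plus the voiced mask and noise come out.
def _sinegen_sketch():
    gen = SineGen(samp_rate=24000, upsample_scale=300, harmonic_num=8)
    f0 = torch.full((1, 2400, 1), 120.0)     # a constant 120 Hz contour
    sine_waves, uv, noise = gen(f0)
    return sine_waves                        # (1, 2400, harmonic_num + 1)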

class SourceModuleHnNSF(nn.Module):
    """ SourceModule for hn-nsf
    SourceModule(sampling_rate, harmonic_num=0, sine_amp=0.1,
                 add_noise_std=0.003, voiced_threshod=0)
    sampling_rate: sampling rate in Hz
    harmonic_num: number of harmonics above F0 (default: 0)
    sine_amp: amplitude of sine source signal (default: 0.1)
    add_noise_std: std of additive Gaussian noise (default: 0.003)
        note that the amplitude of noise in unvoiced regions is decided
        by sine_amp
    voiced_threshold: threshold to set U/V given F0 (default: 0)
    Sine_source, noise_source = SourceModuleHnNSF(F0_sampled)
    F0_sampled (batchsize, length, 1)
    Sine_source (batchsize, length, 1)
    noise_source (batchsize, length, 1)
    uv (batchsize, length, 1)
    """
    def __init__(self, sampling_rate, upsample_scale, harmonic_num=0, sine_amp=0.1,
                 add_noise_std=0.003, voiced_threshod=0):
        super(SourceModuleHnNSF, self).__init__()
        self.sine_amp = sine_amp
        self.noise_std = add_noise_std
        # to produce sine waveforms
        self.l_sin_gen = SineGen(sampling_rate, upsample_scale, harmonic_num,
                                 sine_amp, add_noise_std, voiced_threshod)
        # to merge source harmonics into a single excitation
        self.l_linear = nn.Linear(harmonic_num + 1, 1)
        self.l_tanh = nn.Tanh()

    def forward(self, x):
        """
        Sine_source, noise_source = SourceModuleHnNSF(F0_sampled)
        F0_sampled (batchsize, length, 1)
        Sine_source (batchsize, length, 1)
        noise_source (batchsize, length, 1)
        """
        # source for harmonic branch
        with torch.no_grad():
            sine_wavs, uv, _ = self.l_sin_gen(x)
        sine_merge = self.l_tanh(self.l_linear(sine_wavs))
        # source for noise branch, in the same shape as uv
        noise = torch.randn_like(uv) * self.sine_amp / 3
        return sine_merge, noise, uv


class Generator(nn.Module):
    def __init__(self, style_dim, resblock_kernel_sizes, upsample_rates, upsample_initial_channel, resblock_dilation_sizes, upsample_kernel_sizes, gen_istft_n_fft, gen_istft_hop_size, disable_complex=False):
        super(Generator, self).__init__()
        self.num_kernels = len(resblock_kernel_sizes)
        self.num_upsamples = len(upsample_rates)
        self.m_source = SourceModuleHnNSF(
            sampling_rate=24000,
            upsample_scale=math.prod(upsample_rates) * gen_istft_hop_size,
            harmonic_num=8, voiced_threshod=10)
        self.f0_upsamp = nn.Upsample(scale_factor=math.prod(upsample_rates) * gen_istft_hop_size)
        self.noise_convs = nn.ModuleList()
        self.noise_res = nn.ModuleList()
        self.ups = nn.ModuleList()
        for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
            self.ups.append(weight_norm(
                nn.ConvTranspose1d(upsample_initial_channel//(2**i), upsample_initial_channel//(2**(i+1)),
                                   k, u, padding=(k-u)//2)))
        self.resblocks = nn.ModuleList()
        for i in range(len(self.ups)):
            ch = upsample_initial_channel//(2**(i+1))
            for j, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)):
                self.resblocks.append(AdaINResBlock1(ch, k, d, style_dim))
            c_cur = upsample_initial_channel // (2 ** (i + 1))
            if i + 1 < len(upsample_rates):
                stride_f0 = math.prod(upsample_rates[i + 1:])
                self.noise_convs.append(nn.Conv1d(
                    gen_istft_n_fft + 2, c_cur, kernel_size=stride_f0 * 2, stride=stride_f0, padding=(stride_f0+1) // 2))
                self.noise_res.append(AdaINResBlock1(c_cur, 7, [1, 3, 5], style_dim))
            else:
                self.noise_convs.append(nn.Conv1d(gen_istft_n_fft + 2, c_cur, kernel_size=1))
                self.noise_res.append(AdaINResBlock1(c_cur, 11, [1, 3, 5], style_dim))
        self.post_n_fft = gen_istft_n_fft
        self.conv_post = weight_norm(nn.Conv1d(ch, self.post_n_fft + 2, 7, 1, padding=3))
        self.ups.apply(init_weights)
        self.conv_post.apply(init_weights)
        self.reflection_pad = nn.ReflectionPad1d((1, 0))
        self.stft = (
            CustomSTFT(filter_length=gen_istft_n_fft, hop_length=gen_istft_hop_size, win_length=gen_istft_n_fft)
            if disable_complex
            else TorchSTFT(filter_length=gen_istft_n_fft, hop_length=gen_istft_hop_size, win_length=gen_istft_n_fft)
        )

    def forward(self, x, s, f0):
        with torch.no_grad():
            f0 = self.f0_upsamp(f0[:, None]).transpose(1, 2)  # bs,n,t
            har_source, noi_source, uv = self.m_source(f0)
            har_source = har_source.transpose(1, 2).squeeze(1)
            har_spec, har_phase = self.stft.transform(har_source)
            har = torch.cat([har_spec, har_phase], dim=1)
        for i in range(self.num_upsamples):
            x = F.leaky_relu(x, negative_slope=0.1)
            x_source = self.noise_convs[i](har)
            x_source = self.noise_res[i](x_source, s)
            x = self.ups[i](x)
            if i == self.num_upsamples - 1:
                x = self.reflection_pad(x)
            x = x + x_source
            xs = None
            for j in range(self.num_kernels):
                if xs is None:
                    xs = self.resblocks[i*self.num_kernels+j](x, s)
                else:
                    xs += self.resblocks[i*self.num_kernels+j](x, s)
            x = xs / self.num_kernels
        x = F.leaky_relu(x)
        x = self.conv_post(x)
        spec = torch.exp(x[:, :self.post_n_fft // 2 + 1, :])
        phase = torch.sin(x[:, self.post_n_fft // 2 + 1:, :])
        return self.stft.inverse(spec, phase)


class UpSample1d(nn.Module):
    def __init__(self, layer_type):
        super().__init__()
        self.layer_type = layer_type

    def forward(self, x):
        if self.layer_type == 'none':
            return x
        else:
            return F.interpolate(x, scale_factor=2, mode='nearest')


class AdainResBlk1d(nn.Module):
    def __init__(self, dim_in, dim_out, style_dim=64, actv=nn.LeakyReLU(0.2), upsample='none', dropout_p=0.0):
        super().__init__()
        self.actv = actv
        self.upsample_type = upsample
        self.upsample = UpSample1d(upsample)
        self.learned_sc = dim_in != dim_out
        self._build_weights(dim_in, dim_out, style_dim)
        self.dropout = nn.Dropout(dropout_p)
        if upsample == 'none':
            self.pool = nn.Identity()
        else:
            self.pool = weight_norm(nn.ConvTranspose1d(dim_in, dim_in, kernel_size=3, stride=2, groups=dim_in, padding=1, output_padding=1))

    def _build_weights(self, dim_in, dim_out, style_dim):
        self.conv1 = weight_norm(nn.Conv1d(dim_in, dim_out, 3, 1, 1))
        self.conv2 = weight_norm(nn.Conv1d(dim_out, dim_out, 3, 1, 1))
        self.norm1 = AdaIN1d(style_dim, dim_in)
        self.norm2 = AdaIN1d(style_dim, dim_out)
        if self.learned_sc:
            self.conv1x1 = weight_norm(nn.Conv1d(dim_in, dim_out, 1, 1, 0, bias=False))

    def _shortcut(self, x):
        x = self.upsample(x)
        if self.learned_sc:
            x = self.conv1x1(x)
        return x

    def _residual(self, x, s):
        x = self.norm1(x, s)
        x = self.actv(x)
        x = self.pool(x)
        x = self.conv1(self.dropout(x))
        x = self.norm2(x, s)
        x = self.actv(x)
        x = self.conv2(self.dropout(x))
        return x

    def forward(self, x, s):
        out = self._residual(x, s)
        out = (out + self._shortcut(x)) * torch.rsqrt(torch.tensor(2))
        return out


class Decoder(nn.Module):
    def __init__(self, dim_in, style_dim, dim_out,
                 resblock_kernel_sizes,
                 upsample_rates,
                 upsample_initial_channel,
                 resblock_dilation_sizes,
                 upsample_kernel_sizes,
                 gen_istft_n_fft, gen_istft_hop_size,
                 disable_complex=False):
        super().__init__()
        self.encode = AdainResBlk1d(dim_in + 2, 1024, style_dim)
        self.decode = nn.ModuleList()
        self.decode.append(AdainResBlk1d(1024 + 2 + 64, 1024, style_dim))
        self.decode.append(AdainResBlk1d(1024 + 2 + 64, 1024, style_dim))
        self.decode.append(AdainResBlk1d(1024 + 2 + 64, 1024, style_dim))
        self.decode.append(AdainResBlk1d(1024 + 2 + 64, 512, style_dim, upsample=True))
        self.F0_conv = weight_norm(nn.Conv1d(1, 1, kernel_size=3, stride=2, groups=1, padding=1))
        self.N_conv = weight_norm(nn.Conv1d(1, 1, kernel_size=3, stride=2, groups=1, padding=1))
        self.asr_res = nn.Sequential(weight_norm(nn.Conv1d(512, 64, kernel_size=1)))
        self.generator = Generator(style_dim, resblock_kernel_sizes, upsample_rates,
                                   upsample_initial_channel, resblock_dilation_sizes,
                                   upsample_kernel_sizes, gen_istft_n_fft, gen_istft_hop_size, disable_complex=disable_complex)

    def forward(self, asr, F0_curve, N, s):
        F0 = self.F0_conv(F0_curve.unsqueeze(1))
        N = self.N_conv(N.unsqueeze(1))
        x = torch.cat([asr, F0, N], axis=1)
        x = self.encode(x, s)
        asr_res = self.asr_res(asr)
        res = True
        for block in self.decode:
            if res:
                x = torch.cat([x, asr_res, F0, N], axis=1)
            x = block(x, s)
            if block.upsample_type != "none":
                res = False
        x = self.generator(x, s, F0_curve)
        return x
wan/multitalk/kokoro/model.py (new file, 155 lines)
@@ -0,0 +1,155 @@
from .istftnet import Decoder
from .modules import CustomAlbert, ProsodyPredictor, TextEncoder
from dataclasses import dataclass
from huggingface_hub import hf_hub_download
from loguru import logger
from transformers import AlbertConfig
from typing import Dict, Optional, Union
import json
import torch
import os

class KModel(torch.nn.Module):
    '''
    KModel is a torch.nn.Module with 2 main responsibilities:
    1. Init weights, downloading config.json + model.pth from HF if needed
    2. forward(phonemes: str, ref_s: FloatTensor) -> (audio: FloatTensor)

    You likely only need one KModel instance, and it can be reused across
    multiple KPipelines to avoid redundant memory allocation.

    Unlike KPipeline, KModel is language-blind.

    KModel stores self.vocab and thus knows how to map phonemes -> input_ids,
    so there is no need to repeatedly download config.json outside of KModel.
    '''

    MODEL_NAMES = {
        'hexgrad/Kokoro-82M': 'kokoro-v1_0.pth',
        'hexgrad/Kokoro-82M-v1.1-zh': 'kokoro-v1_1-zh.pth',
    }

    def __init__(
        self,
        repo_id: Optional[str] = None,
        config: Union[Dict, str, None] = None,
        model: Optional[str] = None,
        disable_complex: bool = False
    ):
        super().__init__()
        if repo_id is None:
            repo_id = 'hexgrad/Kokoro-82M'
            print(f"WARNING: Defaulting repo_id to {repo_id}. Pass repo_id='{repo_id}' to suppress this warning.")
        self.repo_id = repo_id
        if not isinstance(config, dict):
            if not config:
                logger.debug("No config provided, downloading from HF")
                config = hf_hub_download(repo_id=repo_id, filename='config.json')
            with open(config, 'r', encoding='utf-8') as r:
                config = json.load(r)
                logger.debug(f"Loaded config: {config}")
        self.vocab = config['vocab']
        self.bert = CustomAlbert(AlbertConfig(vocab_size=config['n_token'], **config['plbert']))
        self.bert_encoder = torch.nn.Linear(self.bert.config.hidden_size, config['hidden_dim'])
        self.context_length = self.bert.config.max_position_embeddings
        self.predictor = ProsodyPredictor(
            style_dim=config['style_dim'], d_hid=config['hidden_dim'],
            nlayers=config['n_layer'], max_dur=config['max_dur'], dropout=config['dropout']
        )
        self.text_encoder = TextEncoder(
            channels=config['hidden_dim'], kernel_size=config['text_encoder_kernel_size'],
            depth=config['n_layer'], n_symbols=config['n_token']
        )
        self.decoder = Decoder(
            dim_in=config['hidden_dim'], style_dim=config['style_dim'],
            dim_out=config['n_mels'], disable_complex=disable_complex, **config['istftnet']
        )
        if not model:
            try:
                model = hf_hub_download(repo_id=repo_id, filename=KModel.MODEL_NAMES[repo_id])
            except:
                model = os.path.join(repo_id, 'kokoro-v1_0.pth')
        for key, state_dict in torch.load(model, map_location='cpu', weights_only=True).items():
            assert hasattr(self, key), key
            try:
                getattr(self, key).load_state_dict(state_dict)
            except:
                logger.debug(f"Did not load {key} from state_dict")
                state_dict = {k[7:]: v for k, v in state_dict.items()}
                getattr(self, key).load_state_dict(state_dict, strict=False)

    @property
    def device(self):
        return self.bert.device

    @dataclass
    class Output:
        audio: torch.FloatTensor
        pred_dur: Optional[torch.LongTensor] = None

    @torch.no_grad()
    def forward_with_tokens(
        self,
        input_ids: torch.LongTensor,
        ref_s: torch.FloatTensor,
        speed: float = 1
    ) -> tuple[torch.FloatTensor, torch.LongTensor]:
        input_lengths = torch.full(
            (input_ids.shape[0],),
            input_ids.shape[-1],
            device=input_ids.device,
            dtype=torch.long
        )

        text_mask = torch.arange(input_lengths.max()).unsqueeze(0).expand(input_lengths.shape[0], -1).type_as(input_lengths)
        text_mask = torch.gt(text_mask+1, input_lengths.unsqueeze(1)).to(self.device)
        bert_dur = self.bert(input_ids, attention_mask=(~text_mask).int())
        d_en = self.bert_encoder(bert_dur).transpose(-1, -2)
        s = ref_s[:, 128:]
        d = self.predictor.text_encoder(d_en, s, input_lengths, text_mask)
        x, _ = self.predictor.lstm(d)
        duration = self.predictor.duration_proj(x)
        duration = torch.sigmoid(duration).sum(axis=-1) / speed
        pred_dur = torch.round(duration).clamp(min=1).long().squeeze()
        indices = torch.repeat_interleave(torch.arange(input_ids.shape[1], device=self.device), pred_dur)
        pred_aln_trg = torch.zeros((input_ids.shape[1], indices.shape[0]), device=self.device)
        pred_aln_trg[indices, torch.arange(indices.shape[0])] = 1
        pred_aln_trg = pred_aln_trg.unsqueeze(0).to(self.device)
        en = d.transpose(-1, -2) @ pred_aln_trg
        F0_pred, N_pred = self.predictor.F0Ntrain(en, s)
        t_en = self.text_encoder(input_ids, input_lengths, text_mask)
        asr = t_en @ pred_aln_trg
        audio = self.decoder(asr, F0_pred, N_pred, ref_s[:, :128]).squeeze()
        return audio, pred_dur

    def forward(
        self,
        phonemes: str,
        ref_s: torch.FloatTensor,
        speed: float = 1,
        return_output: bool = False
    ) -> Union['KModel.Output', torch.FloatTensor]:
        input_ids = list(filter(lambda i: i is not None, map(lambda p: self.vocab.get(p), phonemes)))
        logger.debug(f"phonemes: {phonemes} -> input_ids: {input_ids}")
        assert len(input_ids)+2 <= self.context_length, (len(input_ids)+2, self.context_length)
        input_ids = torch.LongTensor([[0, *input_ids, 0]]).to(self.device)
        ref_s = ref_s.to(self.device)
        audio, pred_dur = self.forward_with_tokens(input_ids, ref_s, speed)
        audio = audio.squeeze().cpu()
        pred_dur = pred_dur.cpu() if pred_dur is not None else None
        logger.debug(f"pred_dur: {pred_dur}")
        return self.Output(audio=audio, pred_dur=pred_dur) if return_output else audio
|
|
||||||
|
class KModelForONNX(torch.nn.Module):
|
||||||
|
def __init__(self, kmodel: KModel):
|
||||||
|
super().__init__()
|
||||||
|
self.kmodel = kmodel
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
input_ids: torch.LongTensor,
|
||||||
|
ref_s: torch.FloatTensor,
|
||||||
|
speed: float = 1
|
||||||
|
) -> tuple[torch.FloatTensor, torch.LongTensor]:
|
||||||
|
waveform, duration = self.kmodel.forward_with_tokens(input_ids, ref_s, speed)
|
||||||
|
return waveform, duration
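# Illustrative sketch (not part of the upstream code): how forward_with_tokens above turns
# per-token durations into a hard alignment matrix via repeat_interleave. The durations below
# are made-up numbers; in the real model they come from the duration predictor.
def _example_duration_alignment():
    import torch
    pred_dur = torch.tensor([2, 3, 1])                  # frames assigned to each of 3 tokens
    n_tokens = pred_dur.shape[0]
    indices = torch.repeat_interleave(torch.arange(n_tokens), pred_dur)
    aln = torch.zeros((n_tokens, int(indices.shape[0])))
    aln[indices, torch.arange(indices.shape[0])] = 1    # one-hot column per output frame
    # en = d.transpose(-1, -2) @ aln then stretches token features across the output frames
    return aln                                          # shape [3, 6]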
|
||||||
183
wan/multitalk/kokoro/modules.py
Normal file
183
wan/multitalk/kokoro/modules.py
Normal file
@ -0,0 +1,183 @@
|
|||||||
|
# https://github.com/yl4579/StyleTTS2/blob/main/models.py
|
||||||
|
from .istftnet import AdainResBlk1d
|
||||||
|
from torch.nn.utils import weight_norm
|
||||||
|
from transformers import AlbertModel
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
|
||||||
|
class LinearNorm(nn.Module):
|
||||||
|
def __init__(self, in_dim, out_dim, bias=True, w_init_gain='linear'):
|
||||||
|
super(LinearNorm, self).__init__()
|
||||||
|
self.linear_layer = nn.Linear(in_dim, out_dim, bias=bias)
|
||||||
|
nn.init.xavier_uniform_(self.linear_layer.weight, gain=nn.init.calculate_gain(w_init_gain))
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
return self.linear_layer(x)
|
||||||
|
|
||||||
|
|
||||||
|
class LayerNorm(nn.Module):
|
||||||
|
def __init__(self, channels, eps=1e-5):
|
||||||
|
super().__init__()
|
||||||
|
self.channels = channels
|
||||||
|
self.eps = eps
|
||||||
|
self.gamma = nn.Parameter(torch.ones(channels))
|
||||||
|
self.beta = nn.Parameter(torch.zeros(channels))
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
x = x.transpose(1, -1)
|
||||||
|
x = F.layer_norm(x, (self.channels,), self.gamma, self.beta, self.eps)
|
||||||
|
return x.transpose(1, -1)
|
||||||
|
|
||||||
|
|
||||||
|
class TextEncoder(nn.Module):
|
||||||
|
def __init__(self, channels, kernel_size, depth, n_symbols, actv=nn.LeakyReLU(0.2)):
|
||||||
|
super().__init__()
|
||||||
|
self.embedding = nn.Embedding(n_symbols, channels)
|
||||||
|
padding = (kernel_size - 1) // 2
|
||||||
|
self.cnn = nn.ModuleList()
|
||||||
|
for _ in range(depth):
|
||||||
|
self.cnn.append(nn.Sequential(
|
||||||
|
weight_norm(nn.Conv1d(channels, channels, kernel_size=kernel_size, padding=padding)),
|
||||||
|
LayerNorm(channels),
|
||||||
|
actv,
|
||||||
|
nn.Dropout(0.2),
|
||||||
|
))
|
||||||
|
self.lstm = nn.LSTM(channels, channels//2, 1, batch_first=True, bidirectional=True)
|
||||||
|
|
||||||
|
def forward(self, x, input_lengths, m):
|
||||||
|
x = self.embedding(x) # [B, T, emb]
|
||||||
|
x = x.transpose(1, 2) # [B, emb, T]
|
||||||
|
m = m.unsqueeze(1)
|
||||||
|
x.masked_fill_(m, 0.0)
|
||||||
|
for c in self.cnn:
|
||||||
|
x = c(x)
|
||||||
|
x.masked_fill_(m, 0.0)
|
||||||
|
x = x.transpose(1, 2) # [B, T, chn]
|
||||||
|
lengths = input_lengths if input_lengths.device == torch.device('cpu') else input_lengths.to('cpu')
|
||||||
|
x = nn.utils.rnn.pack_padded_sequence(x, lengths, batch_first=True, enforce_sorted=False)
|
||||||
|
self.lstm.flatten_parameters()
|
||||||
|
x, _ = self.lstm(x)
|
||||||
|
x, _ = nn.utils.rnn.pad_packed_sequence(x, batch_first=True)
|
||||||
|
x = x.transpose(-1, -2)
|
||||||
|
x_pad = torch.zeros([x.shape[0], x.shape[1], m.shape[-1]], device=x.device)
|
||||||
|
x_pad[:, :, :x.shape[-1]] = x
|
||||||
|
x = x_pad
|
||||||
|
x.masked_fill_(m, 0.0)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class AdaLayerNorm(nn.Module):
|
||||||
|
def __init__(self, style_dim, channels, eps=1e-5):
|
||||||
|
super().__init__()
|
||||||
|
self.channels = channels
|
||||||
|
self.eps = eps
|
||||||
|
self.fc = nn.Linear(style_dim, channels*2)
|
||||||
|
|
||||||
|
def forward(self, x, s):
|
||||||
|
x = x.transpose(-1, -2)
|
||||||
|
x = x.transpose(1, -1)
|
||||||
|
h = self.fc(s)
|
||||||
|
h = h.view(h.size(0), h.size(1), 1)
|
||||||
|
gamma, beta = torch.chunk(h, chunks=2, dim=1)
|
||||||
|
gamma, beta = gamma.transpose(1, -1), beta.transpose(1, -1)
|
||||||
|
x = F.layer_norm(x, (self.channels,), eps=self.eps)
|
||||||
|
x = (1 + gamma) * x + beta
|
||||||
|
return x.transpose(1, -1).transpose(-1, -2)
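# Minimal shape check (illustrative only): AdaLayerNorm applies a plain LayerNorm and then
# modulates it with a per-sample (gamma, beta) pair predicted from the style vector.
# The dimensions below are hypothetical.
def _example_ada_layer_norm():
    norm = AdaLayerNorm(style_dim=128, channels=512)
    x = torch.randn(2, 50, 512)   # [batch, time, channels]
    s = torch.randn(2, 128)       # one style vector per sample
    y = norm(x, s)                # same shape as x: [2, 50, 512]
    return y.shape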
|
||||||
|
|
||||||
|
|
||||||
|
class ProsodyPredictor(nn.Module):
|
||||||
|
def __init__(self, style_dim, d_hid, nlayers, max_dur=50, dropout=0.1):
|
||||||
|
super().__init__()
|
||||||
|
self.text_encoder = DurationEncoder(sty_dim=style_dim, d_model=d_hid, nlayers=nlayers, dropout=dropout)
|
||||||
|
self.lstm = nn.LSTM(d_hid + style_dim, d_hid // 2, 1, batch_first=True, bidirectional=True)
|
||||||
|
self.duration_proj = LinearNorm(d_hid, max_dur)
|
||||||
|
self.shared = nn.LSTM(d_hid + style_dim, d_hid // 2, 1, batch_first=True, bidirectional=True)
|
||||||
|
self.F0 = nn.ModuleList()
|
||||||
|
self.F0.append(AdainResBlk1d(d_hid, d_hid, style_dim, dropout_p=dropout))
|
||||||
|
self.F0.append(AdainResBlk1d(d_hid, d_hid // 2, style_dim, upsample=True, dropout_p=dropout))
|
||||||
|
self.F0.append(AdainResBlk1d(d_hid // 2, d_hid // 2, style_dim, dropout_p=dropout))
|
||||||
|
self.N = nn.ModuleList()
|
||||||
|
self.N.append(AdainResBlk1d(d_hid, d_hid, style_dim, dropout_p=dropout))
|
||||||
|
self.N.append(AdainResBlk1d(d_hid, d_hid // 2, style_dim, upsample=True, dropout_p=dropout))
|
||||||
|
self.N.append(AdainResBlk1d(d_hid // 2, d_hid // 2, style_dim, dropout_p=dropout))
|
||||||
|
self.F0_proj = nn.Conv1d(d_hid // 2, 1, 1, 1, 0)
|
||||||
|
self.N_proj = nn.Conv1d(d_hid // 2, 1, 1, 1, 0)
|
||||||
|
|
||||||
|
def forward(self, texts, style, text_lengths, alignment, m):
|
||||||
|
d = self.text_encoder(texts, style, text_lengths, m)
|
||||||
|
m = m.unsqueeze(1)
|
||||||
|
lengths = text_lengths if text_lengths.device == torch.device('cpu') else text_lengths.to('cpu')
|
||||||
|
x = nn.utils.rnn.pack_padded_sequence(d, lengths, batch_first=True, enforce_sorted=False)
|
||||||
|
self.lstm.flatten_parameters()
|
||||||
|
x, _ = self.lstm(x)
|
||||||
|
x, _ = nn.utils.rnn.pad_packed_sequence(x, batch_first=True)
|
||||||
|
x_pad = torch.zeros([x.shape[0], m.shape[-1], x.shape[-1]], device=x.device)
|
||||||
|
x_pad[:, :x.shape[1], :] = x
|
||||||
|
x = x_pad
|
||||||
|
duration = self.duration_proj(nn.functional.dropout(x, 0.5, training=False))
|
||||||
|
en = (d.transpose(-1, -2) @ alignment)
|
||||||
|
return duration.squeeze(-1), en
|
||||||
|
|
||||||
|
def F0Ntrain(self, x, s):
|
||||||
|
x, _ = self.shared(x.transpose(-1, -2))
|
||||||
|
F0 = x.transpose(-1, -2)
|
||||||
|
for block in self.F0:
|
||||||
|
F0 = block(F0, s)
|
||||||
|
F0 = self.F0_proj(F0)
|
||||||
|
N = x.transpose(-1, -2)
|
||||||
|
for block in self.N:
|
||||||
|
N = block(N, s)
|
||||||
|
N = self.N_proj(N)
|
||||||
|
return F0.squeeze(1), N.squeeze(1)
|
||||||
|
|
||||||
|
|
||||||
|
class DurationEncoder(nn.Module):
|
||||||
|
def __init__(self, sty_dim, d_model, nlayers, dropout=0.1):
|
||||||
|
super().__init__()
|
||||||
|
self.lstms = nn.ModuleList()
|
||||||
|
for _ in range(nlayers):
|
||||||
|
self.lstms.append(nn.LSTM(d_model + sty_dim, d_model // 2, num_layers=1, batch_first=True, bidirectional=True, dropout=dropout))
|
||||||
|
self.lstms.append(AdaLayerNorm(sty_dim, d_model))
|
||||||
|
self.dropout = dropout
|
||||||
|
self.d_model = d_model
|
||||||
|
self.sty_dim = sty_dim
|
||||||
|
|
||||||
|
def forward(self, x, style, text_lengths, m):
|
||||||
|
masks = m
|
||||||
|
x = x.permute(2, 0, 1)
|
||||||
|
s = style.expand(x.shape[0], x.shape[1], -1)
|
||||||
|
x = torch.cat([x, s], axis=-1)
|
||||||
|
x.masked_fill_(masks.unsqueeze(-1).transpose(0, 1), 0.0)
|
||||||
|
x = x.transpose(0, 1)
|
||||||
|
x = x.transpose(-1, -2)
|
||||||
|
for block in self.lstms:
|
||||||
|
if isinstance(block, AdaLayerNorm):
|
||||||
|
x = block(x.transpose(-1, -2), style).transpose(-1, -2)
|
||||||
|
x = torch.cat([x, s.permute(1, 2, 0)], axis=1)
|
||||||
|
x.masked_fill_(masks.unsqueeze(-1).transpose(-1, -2), 0.0)
|
||||||
|
else:
|
||||||
|
lengths = text_lengths if text_lengths.device == torch.device('cpu') else text_lengths.to('cpu')
|
||||||
|
x = x.transpose(-1, -2)
|
||||||
|
x = nn.utils.rnn.pack_padded_sequence(
|
||||||
|
x, lengths, batch_first=True, enforce_sorted=False)
|
||||||
|
block.flatten_parameters()
|
||||||
|
x, _ = block(x)
|
||||||
|
x, _ = nn.utils.rnn.pad_packed_sequence(
|
||||||
|
x, batch_first=True)
|
||||||
|
x = F.dropout(x, p=self.dropout, training=False)
|
||||||
|
x = x.transpose(-1, -2)
|
||||||
|
x_pad = torch.zeros([x.shape[0], x.shape[1], m.shape[-1]], device=x.device)
|
||||||
|
x_pad[:, :, :x.shape[-1]] = x
|
||||||
|
x = x_pad
|
||||||
|
|
||||||
|
return x.transpose(-1, -2)
|
||||||
|
|
||||||
|
|
||||||
|
# https://github.com/yl4579/StyleTTS2/blob/main/Utils/PLBERT/util.py
|
||||||
|
class CustomAlbert(AlbertModel):
|
||||||
|
def forward(self, *args, **kwargs):
|
||||||
|
outputs = super().forward(*args, **kwargs)
|
||||||
|
return outputs.last_hidden_state
|
||||||
445
wan/multitalk/kokoro/pipeline.py
Normal file
445
wan/multitalk/kokoro/pipeline.py
Normal file
@ -0,0 +1,445 @@
|
|||||||
|
from .model import KModel
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from huggingface_hub import hf_hub_download
|
||||||
|
from loguru import logger
|
||||||
|
from misaki import en, espeak
|
||||||
|
from typing import Callable, Generator, List, Optional, Tuple, Union
|
||||||
|
import re
|
||||||
|
import torch
|
||||||
|
import os
|
||||||
|
|
||||||
|
ALIASES = {
|
||||||
|
'en-us': 'a',
|
||||||
|
'en-gb': 'b',
|
||||||
|
'es': 'e',
|
||||||
|
'fr-fr': 'f',
|
||||||
|
'hi': 'h',
|
||||||
|
'it': 'i',
|
||||||
|
'pt-br': 'p',
|
||||||
|
'ja': 'j',
|
||||||
|
'zh': 'z',
|
||||||
|
}
|
||||||
|
|
||||||
|
LANG_CODES = dict(
|
||||||
|
# pip install misaki[en]
|
||||||
|
a='American English',
|
||||||
|
b='British English',
|
||||||
|
|
||||||
|
# espeak-ng
|
||||||
|
e='es',
|
||||||
|
f='fr-fr',
|
||||||
|
h='hi',
|
||||||
|
i='it',
|
||||||
|
p='pt-br',
|
||||||
|
|
||||||
|
# pip install misaki[ja]
|
||||||
|
j='Japanese',
|
||||||
|
|
||||||
|
# pip install misaki[zh]
|
||||||
|
z='Mandarin Chinese',
|
||||||
|
)
|
||||||
|
|
||||||
|
class KPipeline:
|
||||||
|
'''
|
||||||
|
KPipeline is a language-aware support class with 2 main responsibilities:
|
||||||
|
1. Perform language-specific G2P, mapping (and chunking) text -> phonemes
|
||||||
|
2. Manage and store voices, lazily downloaded from HF if needed
|
||||||
|
|
||||||
|
You are expected to have one KPipeline per language. If you have multiple
|
||||||
|
KPipelines, you should reuse one KModel instance across all of them.
|
||||||
|
|
||||||
|
KPipeline is designed to work with a KModel, but this is not required.
|
||||||
|
There are 2 ways to pass an existing model into a pipeline:
|
||||||
|
1. On init: us_pipeline = KPipeline(lang_code='a', model=model)
|
||||||
|
2. On call: us_pipeline(text, voice, model=model)
|
||||||
|
|
||||||
|
By default, KPipeline will automatically initialize its own KModel. To
|
||||||
|
suppress this, construct a "quiet" KPipeline with model=False.
|
||||||
|
|
||||||
|
A "quiet" KPipeline yields (graphemes, phonemes, None) without generating
|
||||||
|
any audio. You can use this to phonemize and chunk your text in advance.
|
||||||
|
|
||||||
|
A "loud" KPipeline _with_ a model yields (graphemes, phonemes, audio).
|
||||||
|
'''
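# Usage sketch (illustrative; the voice name and text are examples, not requirements):
#
#   loud = KPipeline(lang_code='a', repo_id='hexgrad/Kokoro-82M')            # builds its own KModel
#   for graphemes, phonemes, audio in loud("Hello world!", voice='af_heart'):
#       ...                                                                  # audio is a waveform tensor
#
#   quiet = KPipeline(lang_code='a', repo_id='hexgrad/Kokoro-82M', model=False)
#   for graphemes, phonemes, audio in quiet("Hello world!"):
#       ...                                                                  # audio is None here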
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
lang_code: str,
|
||||||
|
repo_id: Optional[str] = None,
|
||||||
|
model: Union[KModel, bool] = True,
|
||||||
|
trf: bool = False,
|
||||||
|
en_callable: Optional[Callable[[str], str]] = None,
|
||||||
|
device: Optional[str] = None
|
||||||
|
):
|
||||||
|
"""Initialize a KPipeline.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
lang_code: Language code for G2P processing
|
||||||
|
model: KModel instance, True to create new model, False for no model
|
||||||
|
trf: Whether to use transformer-based G2P
|
||||||
|
device: Override default device selection ('cuda' or 'cpu', or None for auto)
|
||||||
|
If None, will auto-select cuda if available
|
||||||
|
If 'cuda' and not available, will explicitly raise an error
|
||||||
|
"""
|
||||||
|
if repo_id is None:
|
||||||
|
repo_id = 'hexgrad/Kokoro-82M'
|
||||||
|
print(f"WARNING: Defaulting repo_id to {repo_id}. Pass repo_id='{repo_id}' to suppress this warning.")
|
||||||
|
config = None
|
||||||
|
else:
|
||||||
|
config = os.path.join(repo_id, 'config.json')
|
||||||
|
self.repo_id = repo_id
|
||||||
|
lang_code = lang_code.lower()
|
||||||
|
lang_code = ALIASES.get(lang_code, lang_code)
|
||||||
|
assert lang_code in LANG_CODES, (lang_code, LANG_CODES)
|
||||||
|
self.lang_code = lang_code
|
||||||
|
self.model = None
|
||||||
|
if isinstance(model, KModel):
|
||||||
|
self.model = model
|
||||||
|
elif model:
|
||||||
|
if device == 'cuda' and not torch.cuda.is_available():
|
||||||
|
raise RuntimeError("CUDA requested but not available")
|
||||||
|
if device == 'mps' and not torch.backends.mps.is_available():
|
||||||
|
raise RuntimeError("MPS requested but not available")
|
||||||
|
if device == 'mps' and os.environ.get('PYTORCH_ENABLE_MPS_FALLBACK') != '1':
|
||||||
|
raise RuntimeError("MPS requested but fallback not enabled")
|
||||||
|
if device is None:
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
device = 'cuda'
|
||||||
|
elif os.environ.get('PYTORCH_ENABLE_MPS_FALLBACK') == '1' and torch.backends.mps.is_available():
|
||||||
|
device = 'mps'
|
||||||
|
else:
|
||||||
|
device = 'cpu'
|
||||||
|
try:
|
||||||
|
self.model = KModel(repo_id=repo_id, config=config).to(device).eval()
|
||||||
|
except RuntimeError as e:
|
||||||
|
if device == 'cuda':
|
||||||
|
raise RuntimeError(f"""Failed to initialize model on CUDA: {e}.
|
||||||
|
Try setting device='cpu' or check CUDA installation.""")
|
||||||
|
raise
|
||||||
|
self.voices = {}
|
||||||
|
if lang_code in 'ab':
|
||||||
|
try:
|
||||||
|
fallback = espeak.EspeakFallback(british=lang_code=='b')
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning("EspeakFallback not Enabled: OOD words will be skipped")
|
||||||
|
logger.warning({str(e)})
|
||||||
|
fallback = None
|
||||||
|
self.g2p = en.G2P(trf=trf, british=lang_code=='b', fallback=fallback, unk='')
|
||||||
|
elif lang_code == 'j':
|
||||||
|
try:
|
||||||
|
from misaki import ja
|
||||||
|
self.g2p = ja.JAG2P()
|
||||||
|
except ImportError:
|
||||||
|
logger.error("You need to `pip install misaki[ja]` to use lang_code='j'")
|
||||||
|
raise
|
||||||
|
elif lang_code == 'z':
|
||||||
|
try:
|
||||||
|
from misaki import zh
|
||||||
|
self.g2p = zh.ZHG2P(
|
||||||
|
version=None if repo_id.endswith('/Kokoro-82M') else '1.1',
|
||||||
|
en_callable=en_callable
|
||||||
|
)
|
||||||
|
except ImportError:
|
||||||
|
logger.error("You need to `pip install misaki[zh]` to use lang_code='z'")
|
||||||
|
raise
|
||||||
|
else:
|
||||||
|
language = LANG_CODES[lang_code]
|
||||||
|
logger.warning(f"Using EspeakG2P(language='{language}'). Chunking logic not yet implemented, so long texts may be truncated unless you split them with '\\n'.")
|
||||||
|
self.g2p = espeak.EspeakG2P(language=language)
|
||||||
|
|
||||||
|
def load_single_voice(self, voice: str):
|
||||||
|
if voice in self.voices:
|
||||||
|
return self.voices[voice]
|
||||||
|
if voice.endswith('.pt'):
|
||||||
|
f = voice
|
||||||
|
else:
|
||||||
|
f = hf_hub_download(repo_id=self.repo_id, filename=f'voices/{voice}.pt')
|
||||||
|
if not voice.startswith(self.lang_code):
|
||||||
|
v = LANG_CODES.get(voice, voice)
|
||||||
|
p = LANG_CODES.get(self.lang_code, self.lang_code)
|
||||||
|
logger.warning(f'Language mismatch, loading {v} voice into {p} pipeline.')
|
||||||
|
pack = torch.load(f, weights_only=True)
|
||||||
|
self.voices[voice] = pack
|
||||||
|
return pack
|
||||||
|
|
||||||
|
"""
|
||||||
|
load_voice is a helper function that lazily downloads and loads a voice:
|
||||||
|
Single voice can be requested (e.g. 'af_bella') or multiple voices (e.g. 'af_bella,af_jessica').
|
||||||
|
If multiple voices are requested, they are averaged.
|
||||||
|
Delimiter is optional and defaults to ','.
|
||||||
|
"""
|
||||||
|
def load_voice(self, voice: Union[str, torch.FloatTensor], delimiter: str = ",") -> torch.FloatTensor:
|
||||||
|
if isinstance(voice, torch.FloatTensor):
|
||||||
|
return voice
|
||||||
|
if voice in self.voices:
|
||||||
|
return self.voices[voice]
|
||||||
|
logger.debug(f"Loading voice: {voice}")
|
||||||
|
packs = [self.load_single_voice(v) for v in voice.split(delimiter)]
|
||||||
|
if len(packs) == 1:
|
||||||
|
return packs[0]
|
||||||
|
self.voices[voice] = torch.mean(torch.stack(packs), dim=0)
|
||||||
|
return self.voices[voice]
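# Note (illustrative): requesting a blended voice like 'af_bella,af_jessica' loads both packs
# and averages them, roughly:
#
#   packs = [self.load_single_voice('af_bella'), self.load_single_voice('af_jessica')]
#   blended = torch.mean(torch.stack(packs), dim=0)
#
# The blended tensor is cached under the combined key so later calls reuse it.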
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def tokens_to_ps(tokens: List[en.MToken]) -> str:
|
||||||
|
return ''.join(t.phonemes + (' ' if t.whitespace else '') for t in tokens).strip()
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def waterfall_last(
|
||||||
|
tokens: List[en.MToken],
|
||||||
|
next_count: int,
|
||||||
|
waterfall: List[str] = ['!.?…', ':;', ',—'],
|
||||||
|
bumps: List[str] = [')', '”']
|
||||||
|
) -> int:
|
||||||
|
for w in waterfall:
|
||||||
|
z = next((i for i, t in reversed(list(enumerate(tokens))) if t.phonemes in set(w)), None)
|
||||||
|
if z is None:
|
||||||
|
continue
|
||||||
|
z += 1
|
||||||
|
if z < len(tokens) and tokens[z].phonemes in bumps:
|
||||||
|
z += 1
|
||||||
|
if next_count - len(KPipeline.tokens_to_ps(tokens[:z])) <= 510:
|
||||||
|
return z
|
||||||
|
return len(tokens)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def tokens_to_text(tokens: List[en.MToken]) -> str:
|
||||||
|
return ''.join(t.text + t.whitespace for t in tokens).strip()
|
||||||
|
|
||||||
|
def en_tokenize(
|
||||||
|
self,
|
||||||
|
tokens: List[en.MToken]
|
||||||
|
) -> Generator[Tuple[str, str, List[en.MToken]], None, None]:
|
||||||
|
tks = []
|
||||||
|
pcount = 0
|
||||||
|
for t in tokens:
|
||||||
|
# American English: ɾ => T
|
||||||
|
t.phonemes = '' if t.phonemes is None else t.phonemes#.replace('ɾ', 'T')
|
||||||
|
next_ps = t.phonemes + (' ' if t.whitespace else '')
|
||||||
|
next_pcount = pcount + len(next_ps.rstrip())
|
||||||
|
if next_pcount > 510:
|
||||||
|
z = KPipeline.waterfall_last(tks, next_pcount)
|
||||||
|
text = KPipeline.tokens_to_text(tks[:z])
|
||||||
|
logger.debug(f"Chunking text at {z}: '{text[:30]}{'...' if len(text) > 30 else ''}'")
|
||||||
|
ps = KPipeline.tokens_to_ps(tks[:z])
|
||||||
|
yield text, ps, tks[:z]
|
||||||
|
tks = tks[z:]
|
||||||
|
pcount = len(KPipeline.tokens_to_ps(tks))
|
||||||
|
if not tks:
|
||||||
|
next_ps = next_ps.lstrip()
|
||||||
|
tks.append(t)
|
||||||
|
pcount += len(next_ps)
|
||||||
|
if tks:
|
||||||
|
text = KPipeline.tokens_to_text(tks)
|
||||||
|
ps = KPipeline.tokens_to_ps(tks)
|
||||||
|
yield text.strip(), ps.strip(), tks
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def infer(
|
||||||
|
model: KModel,
|
||||||
|
ps: str,
|
||||||
|
pack: torch.FloatTensor,
|
||||||
|
speed: Union[float, Callable[[int], float]] = 1
|
||||||
|
) -> KModel.Output:
|
||||||
|
if callable(speed):
|
||||||
|
speed = speed(len(ps))
|
||||||
|
return model(ps, pack[len(ps)-1], speed, return_output=True)
|
||||||
|
|
||||||
|
def generate_from_tokens(
|
||||||
|
self,
|
||||||
|
tokens: Union[str, List[en.MToken]],
|
||||||
|
voice: str,
|
||||||
|
speed: float = 1,
|
||||||
|
model: Optional[KModel] = None
|
||||||
|
) -> Generator['KPipeline.Result', None, None]:
|
||||||
|
"""Generate audio from either raw phonemes or pre-processed tokens.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tokens: Either a phoneme string or list of pre-processed MTokens
|
||||||
|
voice: The voice to use for synthesis
|
||||||
|
speed: Speech speed modifier (default: 1)
|
||||||
|
model: Optional KModel instance (uses pipeline's model if not provided)
|
||||||
|
|
||||||
|
Yields:
|
||||||
|
KPipeline.Result containing the input tokens and generated audio
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If no voice is provided or token sequence exceeds model limits
|
||||||
|
"""
|
||||||
|
model = model or self.model
|
||||||
|
if model and voice is None:
|
||||||
|
raise ValueError('Specify a voice: pipeline.generate_from_tokens(..., voice="af_heart")')
|
||||||
|
|
||||||
|
pack = self.load_voice(voice).to(model.device) if model else None
|
||||||
|
|
||||||
|
# Handle raw phoneme string
|
||||||
|
if isinstance(tokens, str):
|
||||||
|
logger.debug("Processing phonemes from raw string")
|
||||||
|
if len(tokens) > 510:
|
||||||
|
raise ValueError(f'Phoneme string too long: {len(tokens)} > 510')
|
||||||
|
output = KPipeline.infer(model, tokens, pack, speed) if model else None
|
||||||
|
yield self.Result(graphemes='', phonemes=tokens, output=output)
|
||||||
|
return
|
||||||
|
|
||||||
|
logger.debug("Processing MTokens")
|
||||||
|
# Handle pre-processed tokens
|
||||||
|
for gs, ps, tks in self.en_tokenize(tokens):
|
||||||
|
if not ps:
|
||||||
|
continue
|
||||||
|
elif len(ps) > 510:
|
||||||
|
logger.warning(f"Unexpected len(ps) == {len(ps)} > 510 and ps == '{ps}'")
|
||||||
|
logger.warning("Truncating to 510 characters")
|
||||||
|
ps = ps[:510]
|
||||||
|
output = KPipeline.infer(model, ps, pack, speed) if model else None
|
||||||
|
if output is not None and output.pred_dur is not None:
|
||||||
|
KPipeline.join_timestamps(tks, output.pred_dur)
|
||||||
|
yield self.Result(graphemes=gs, phonemes=ps, tokens=tks, output=output)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def join_timestamps(tokens: List[en.MToken], pred_dur: torch.LongTensor):
|
||||||
|
# Multiply by 600 to go from pred_dur frames to sample_rate 24000
|
||||||
|
# Equivalent to dividing pred_dur frames by 40 to get timestamp in seconds
|
||||||
|
# We will count nice round half-frames, so the divisor is 80
|
||||||
|
MAGIC_DIVISOR = 80
|
||||||
|
if not tokens or len(pred_dur) < 3:
|
||||||
|
# We expect at least 3: <bos>, token, <eos>
|
||||||
|
return
|
||||||
|
# We track 2 counts, measured in half-frames: (left, right)
|
||||||
|
# This way we can cut space characters in half
|
||||||
|
# TODO: Is -3 an appropriate offset?
|
||||||
|
left = right = 2 * max(0, pred_dur[0].item() - 3)
|
||||||
|
# Updates:
|
||||||
|
# left = right + (2 * token_dur) + space_dur
|
||||||
|
# right = left + space_dur
|
||||||
|
i = 1
|
||||||
|
for t in tokens:
|
||||||
|
if i >= len(pred_dur)-1:
|
||||||
|
break
|
||||||
|
if not t.phonemes:
|
||||||
|
if t.whitespace:
|
||||||
|
i += 1
|
||||||
|
left = right + pred_dur[i].item()
|
||||||
|
right = left + pred_dur[i].item()
|
||||||
|
i += 1
|
||||||
|
continue
|
||||||
|
j = i + len(t.phonemes)
|
||||||
|
if j >= len(pred_dur):
|
||||||
|
break
|
||||||
|
t.start_ts = left / MAGIC_DIVISOR
|
||||||
|
token_dur = pred_dur[i: j].sum().item()
|
||||||
|
space_dur = pred_dur[j].item() if t.whitespace else 0
|
||||||
|
left = right + (2 * token_dur) + space_dur
|
||||||
|
t.end_ts = left / MAGIC_DIVISOR
|
||||||
|
right = left + space_dur
|
||||||
|
i = j + (1 if t.whitespace else 0)
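# Worked example (illustrative numbers): a token whose phonemes cover frames summing to
# token_dur = 20, starting at left = right = 100 half-frames with no trailing space:
#   start_ts = 100 / 80     = 1.25 s
#   left     = 100 + 2 * 20 = 140 half-frames
#   end_ts   = 140 / 80     = 1.75 s
# which matches 20 frames * 600 samples / 24000 Hz = 0.5 s of audio for that token.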
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class Result:
|
||||||
|
graphemes: str
|
||||||
|
phonemes: str
|
||||||
|
tokens: Optional[List[en.MToken]] = None
|
||||||
|
output: Optional[KModel.Output] = None
|
||||||
|
text_index: Optional[int] = None
|
||||||
|
|
||||||
|
@property
|
||||||
|
def audio(self) -> Optional[torch.FloatTensor]:
|
||||||
|
return None if self.output is None else self.output.audio
|
||||||
|
|
||||||
|
@property
|
||||||
|
def pred_dur(self) -> Optional[torch.LongTensor]:
|
||||||
|
return None if self.output is None else self.output.pred_dur
|
||||||
|
|
||||||
|
### MARK: BEGIN BACKWARD COMPAT ###
|
||||||
|
def __iter__(self):
|
||||||
|
yield self.graphemes
|
||||||
|
yield self.phonemes
|
||||||
|
yield self.audio
|
||||||
|
|
||||||
|
def __getitem__(self, index):
|
||||||
|
return [self.graphemes, self.phonemes, self.audio][index]
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return 3
|
||||||
|
#### MARK: END BACKWARD COMPAT ####
|
||||||
|
|
||||||
|
def __call__(
|
||||||
|
self,
|
||||||
|
text: Union[str, List[str]],
|
||||||
|
voice: Optional[str] = None,
|
||||||
|
speed: Union[float, Callable[[int], float]] = 1,
|
||||||
|
split_pattern: Optional[str] = r'\n+',
|
||||||
|
model: Optional[KModel] = None
|
||||||
|
) -> Generator['KPipeline.Result', None, None]:
|
||||||
|
model = model or self.model
|
||||||
|
if model and voice is None:
|
||||||
|
raise ValueError('Specify a voice: en_us_pipeline(text="Hello world!", voice="af_heart")')
|
||||||
|
pack = self.load_voice(voice).to(model.device) if model else None
|
||||||
|
|
||||||
|
# Convert input to list of segments
|
||||||
|
if isinstance(text, str):
|
||||||
|
text = re.split(split_pattern, text.strip()) if split_pattern else [text]
|
||||||
|
|
||||||
|
# Process each segment
|
||||||
|
for graphemes_index, graphemes in enumerate(text):
|
||||||
|
if not graphemes.strip(): # Skip empty segments
|
||||||
|
continue
|
||||||
|
|
||||||
|
# English processing (unchanged)
|
||||||
|
if self.lang_code in 'ab':
|
||||||
|
logger.debug(f"Processing English text: {graphemes[:50]}{'...' if len(graphemes) > 50 else ''}")
|
||||||
|
_, tokens = self.g2p(graphemes)
|
||||||
|
for gs, ps, tks in self.en_tokenize(tokens):
|
||||||
|
if not ps:
|
||||||
|
continue
|
||||||
|
elif len(ps) > 510:
|
||||||
|
logger.warning(f"Unexpected len(ps) == {len(ps)} > 510 and ps == '{ps}'")
|
||||||
|
ps = ps[:510]
|
||||||
|
output = KPipeline.infer(model, ps, pack, speed) if model else None
|
||||||
|
if output is not None and output.pred_dur is not None:
|
||||||
|
KPipeline.join_timestamps(tks, output.pred_dur)
|
||||||
|
yield self.Result(graphemes=gs, phonemes=ps, tokens=tks, output=output, text_index=graphemes_index)
|
||||||
|
|
||||||
|
# Non-English processing with chunking
|
||||||
|
else:
|
||||||
|
# Split long text into smaller chunks (roughly 400 characters each)
|
||||||
|
# Using sentence boundaries when possible
|
||||||
|
chunk_size = 400
|
||||||
|
chunks = []
|
||||||
|
|
||||||
|
# Try to split on sentence boundaries first
|
||||||
|
sentences = re.split(r'([.!?]+)', graphemes)
|
||||||
|
current_chunk = ""
|
||||||
|
|
||||||
|
for i in range(0, len(sentences), 2):
|
||||||
|
sentence = sentences[i]
|
||||||
|
# Add the punctuation back if it exists
|
||||||
|
if i + 1 < len(sentences):
|
||||||
|
sentence += sentences[i + 1]
|
||||||
|
|
||||||
|
if len(current_chunk) + len(sentence) <= chunk_size:
|
||||||
|
current_chunk += sentence
|
||||||
|
else:
|
||||||
|
if current_chunk:
|
||||||
|
chunks.append(current_chunk.strip())
|
||||||
|
current_chunk = sentence
|
||||||
|
|
||||||
|
if current_chunk:
|
||||||
|
chunks.append(current_chunk.strip())
|
||||||
|
|
||||||
|
# If no chunks were created (no sentence boundaries), fall back to character-based chunking
|
||||||
|
if not chunks:
|
||||||
|
chunks = [graphemes[i:i+chunk_size] for i in range(0, len(graphemes), chunk_size)]
|
||||||
|
|
||||||
|
# Process each chunk
|
||||||
|
for chunk in chunks:
|
||||||
|
if not chunk.strip():
|
||||||
|
continue
|
||||||
|
|
||||||
|
ps, _ = self.g2p(chunk)
|
||||||
|
if not ps:
|
||||||
|
continue
|
||||||
|
elif len(ps) > 510:
|
||||||
|
logger.warning(f'Truncating len(ps) == {len(ps)} > 510')
|
||||||
|
ps = ps[:510]
|
||||||
|
|
||||||
|
output = KPipeline.infer(model, ps, pack, speed) if model else None
|
||||||
|
yield self.Result(graphemes=chunk, phonemes=ps, output=output, text_index=graphemes_index)
|
||||||
319
wan/multitalk/multitalk.py
Normal file
319
wan/multitalk/multitalk.py
Normal file
@ -0,0 +1,319 @@
|
|||||||
|
import random
|
||||||
|
import os
|
||||||
|
import torch
|
||||||
|
import torch.distributed as dist
|
||||||
|
from PIL import Image
|
||||||
|
import subprocess
|
||||||
|
import torchvision.transforms as transforms
|
||||||
|
import torch.nn.functional as F
|
||||||
|
import torch.nn as nn
|
||||||
|
import wan
|
||||||
|
from wan.configs import SIZE_CONFIGS, SUPPORTED_SIZES, WAN_CONFIGS
|
||||||
|
from wan.utils.utils import cache_image, cache_video, str2bool
|
||||||
|
# from wan.utils.multitalk_utils import save_video_ffmpeg
|
||||||
|
# from .kokoro import KPipeline  # required by process_tts_single / process_tts_multi below
|
||||||
|
from transformers import Wav2Vec2FeatureExtractor
|
||||||
|
from .wav2vec2 import Wav2Vec2Model
|
||||||
|
|
||||||
|
import librosa
|
||||||
|
import pyloudnorm as pyln
|
||||||
|
import numpy as np
|
||||||
|
from einops import rearrange
|
||||||
|
import soundfile as sf
|
||||||
|
import re
|
||||||
|
import math
|
||||||
|
|
||||||
|
def custom_init(device, wav2vec):
|
||||||
|
audio_encoder = Wav2Vec2Model.from_pretrained(wav2vec, local_files_only=True).to(device)
|
||||||
|
audio_encoder.feature_extractor._freeze_parameters()
|
||||||
|
wav2vec_feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(wav2vec, local_files_only=True)
|
||||||
|
return wav2vec_feature_extractor, audio_encoder
|
||||||
|
|
||||||
|
def loudness_norm(audio_array, sr=16000, lufs=-23):
|
||||||
|
meter = pyln.Meter(sr)
|
||||||
|
loudness = meter.integrated_loudness(audio_array)
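# integrated_loudness() returns -inf (or an extreme value) for silent or near-silent input;
# the guard below skips normalization in that case rather than amplifying noise.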
|
||||||
|
if abs(loudness) > 100:
|
||||||
|
return audio_array
|
||||||
|
normalized_audio = pyln.normalize.loudness(audio_array, loudness, lufs)
|
||||||
|
return normalized_audio
|
||||||
|
|
||||||
|
|
||||||
|
def get_embedding(speech_array, wav2vec_feature_extractor, audio_encoder, sr=16000, device='cpu', fps = 25):
|
||||||
|
audio_duration = len(speech_array) / sr
|
||||||
|
video_length = audio_duration * fps
|
||||||
|
|
||||||
|
# wav2vec_feature_extractor
|
||||||
|
audio_feature = np.squeeze(
|
||||||
|
wav2vec_feature_extractor(speech_array, sampling_rate=sr).input_values
|
||||||
|
)
|
||||||
|
audio_feature = torch.from_numpy(audio_feature).float().to(device=device)
|
||||||
|
audio_feature = audio_feature.unsqueeze(0)
|
||||||
|
|
||||||
|
# audio encoder
|
||||||
|
with torch.no_grad():
|
||||||
|
embeddings = audio_encoder(audio_feature, seq_len=int(video_length), output_hidden_states=True)
|
||||||
|
|
||||||
|
if len(embeddings) == 0:
|
||||||
|
print("Fail to extract audio embedding")
|
||||||
|
return None
|
||||||
|
|
||||||
|
audio_emb = torch.stack(embeddings.hidden_states[1:], dim=1).squeeze(0)
|
||||||
|
audio_emb = rearrange(audio_emb, "b s d -> s b d")
|
||||||
|
|
||||||
|
audio_emb = audio_emb.cpu().detach()
|
||||||
|
return audio_emb
|
||||||
|
|
||||||
|
def audio_prepare_single(audio_path, sample_rate=16000, duration = 0):
|
||||||
|
ext = os.path.splitext(audio_path)[1].lower()
|
||||||
|
if ext in ['.mp4', '.mov', '.avi', '.mkv']:
|
||||||
|
human_speech_array = extract_audio_from_video(audio_path, sample_rate)
|
||||||
|
return human_speech_array
|
||||||
|
else:
|
||||||
|
human_speech_array, sr = librosa.load(audio_path, duration=duration, sr=sample_rate)
|
||||||
|
human_speech_array = loudness_norm(human_speech_array, sr)
|
||||||
|
return human_speech_array
|
||||||
|
|
||||||
|
|
||||||
|
def audio_prepare_multi(left_path, right_path, audio_type = "add", sample_rate=16000, duration = 0):
|
||||||
|
if left_path is not None and right_path is not None:
    human_speech_array1 = audio_prepare_single(left_path, duration=duration)
    human_speech_array2 = audio_prepare_single(right_path, duration=duration)
elif left_path is None:
    human_speech_array2 = audio_prepare_single(right_path, duration=duration)
    human_speech_array1 = np.zeros(human_speech_array2.shape[0])
elif right_path is None:
    human_speech_array1 = audio_prepare_single(left_path, duration=duration)
    human_speech_array2 = np.zeros(human_speech_array1.shape[0])
|
||||||
|
|
||||||
|
if audio_type=='para':
|
||||||
|
new_human_speech1 = human_speech_array1
|
||||||
|
new_human_speech2 = human_speech_array2
|
||||||
|
elif audio_type=='add':
|
||||||
|
new_human_speech1 = np.concatenate([human_speech_array1, np.zeros(human_speech_array2.shape[0])])
new_human_speech2 = np.concatenate([np.zeros(human_speech_array1.shape[0]), human_speech_array2])
|
||||||
|
sum_human_speechs = new_human_speech1 + new_human_speech2
|
||||||
|
return new_human_speech1, new_human_speech2, sum_human_speechs
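# Illustrative sketch of the two combination modes above (array lengths are made up):
# 'para' keeps both tracks as-is (speakers overlap), while 'add' pads each track with
# silence so speaker 1 is heard first and speaker 2 second.
def _example_combination_modes():
    a = np.ones(4)                                        # dummy samples for speaker 1
    b = np.full(6, 2.0)                                   # dummy samples for speaker 2
    s1_add = np.concatenate([a, np.zeros(b.shape[0])])    # [a..., silence] -> length 10
    s2_add = np.concatenate([np.zeros(a.shape[0]), b])    # [silence, b...] -> length 10
    return s1_add + s2_add                                # sequential mix, length 10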
|
||||||
|
|
||||||
|
def process_tts_single(text, save_dir, voice1):
|
||||||
|
s1_sentences = []
|
||||||
|
|
||||||
|
pipeline = KPipeline(lang_code='a', repo_id='weights/Kokoro-82M')
|
||||||
|
|
||||||
|
voice_tensor = torch.load(voice1, weights_only=True)
|
||||||
|
generator = pipeline(
|
||||||
|
text, voice=voice_tensor, # <= change voice here
|
||||||
|
speed=1, split_pattern=r'\n+'
|
||||||
|
)
|
||||||
|
audios = []
|
||||||
|
for i, (gs, ps, audio) in enumerate(generator):
|
||||||
|
audios.append(audio)
|
||||||
|
audios = torch.concat(audios, dim=0)
|
||||||
|
s1_sentences.append(audios)
|
||||||
|
s1_sentences = torch.concat(s1_sentences, dim=0)
|
||||||
|
save_path1 = f'{save_dir}/s1.wav'
|
||||||
|
sf.write(save_path1, s1_sentences, 24000) # save each audio file
|
||||||
|
s1, _ = librosa.load(save_path1, sr=16000)
|
||||||
|
return s1, save_path1
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def process_tts_multi(text, save_dir, voice1, voice2):
|
||||||
|
pattern = r'\(s(\d+)\)\s*(.*?)(?=\s*\(s\d+\)|$)'
|
||||||
|
matches = re.findall(pattern, text, re.DOTALL)
|
||||||
|
|
||||||
|
s1_sentences = []
|
||||||
|
s2_sentences = []
|
||||||
|
|
||||||
|
pipeline = KPipeline(lang_code='a', repo_id='weights/Kokoro-82M')
|
||||||
|
for idx, (speaker, content) in enumerate(matches):
|
||||||
|
if speaker == '1':
|
||||||
|
voice_tensor = torch.load(voice1, weights_only=True)
|
||||||
|
generator = pipeline(
|
||||||
|
content, voice=voice_tensor, # <= change voice here
|
||||||
|
speed=1, split_pattern=r'\n+'
|
||||||
|
)
|
||||||
|
audios = []
|
||||||
|
for i, (gs, ps, audio) in enumerate(generator):
|
||||||
|
audios.append(audio)
|
||||||
|
audios = torch.concat(audios, dim=0)
|
||||||
|
s1_sentences.append(audios)
|
||||||
|
s2_sentences.append(torch.zeros_like(audios))
|
||||||
|
elif speaker == '2':
|
||||||
|
voice_tensor = torch.load(voice2, weights_only=True)
|
||||||
|
generator = pipeline(
|
||||||
|
content, voice=voice_tensor, # <= change voice here
|
||||||
|
speed=1, split_pattern=r'\n+'
|
||||||
|
)
|
||||||
|
audios = []
|
||||||
|
for i, (gs, ps, audio) in enumerate(generator):
|
||||||
|
audios.append(audio)
|
||||||
|
audios = torch.concat(audios, dim=0)
|
||||||
|
s2_sentences.append(audios)
|
||||||
|
s1_sentences.append(torch.zeros_like(audios))
|
||||||
|
|
||||||
|
s1_sentences = torch.concat(s1_sentences, dim=0)
|
||||||
|
s2_sentences = torch.concat(s2_sentences, dim=0)
|
||||||
|
sum_sentences = s1_sentences + s2_sentences
|
||||||
|
save_path1 = f'{save_dir}/s1.wav'
save_path2 = f'{save_dir}/s2.wav'
|
||||||
|
save_path_sum = f'{save_dir}/sum.wav'
|
||||||
|
sf.write(save_path1, s1_sentences, 24000) # save each audio file
|
||||||
|
sf.write(save_path2, s2_sentences, 24000)
|
||||||
|
sf.write(save_path_sum, sum_sentences, 24000)
|
||||||
|
|
||||||
|
s1, _ = librosa.load(save_path1, sr=16000)
|
||||||
|
s2, _ = librosa.load(save_path2, sr=16000)
|
||||||
|
# sum, _ = librosa.load(save_path_sum, sr=16000)
|
||||||
|
return s1, s2, save_path_sum
|
||||||
|
|
||||||
|
|
||||||
|
def get_full_audio_embeddings(audio_guide1 = None, audio_guide2 = None, combination_type ="add", num_frames = 0, fps = 25, sr = 16000):
|
||||||
|
wav2vec_feature_extractor, audio_encoder= custom_init('cpu', "ckpts/chinese-wav2vec2-base")
|
||||||
|
# wav2vec_feature_extractor, audio_encoder= custom_init('cpu', "ckpts/wav2vec")
|
||||||
|
|
||||||
|
new_human_speech1, new_human_speech2, sum_human_speechs = audio_prepare_multi(audio_guide1, audio_guide2, combination_type, duration= num_frames / fps)
|
||||||
|
audio_embedding_1 = get_embedding(new_human_speech1, wav2vec_feature_extractor, audio_encoder, sr=sr, fps= fps)
|
||||||
|
audio_embedding_2 = get_embedding(new_human_speech2, wav2vec_feature_extractor, audio_encoder, sr=sr, fps= fps)
|
||||||
|
|
||||||
|
full_audio_embs = []
|
||||||
|
if audio_guide1 is not None: full_audio_embs.append(audio_embedding_1)
if audio_guide2 is not None: full_audio_embs.append(audio_embedding_2)
if audio_guide2 is None: sum_human_speechs = None
|
||||||
|
return full_audio_embs, sum_human_speechs
|
||||||
|
|
||||||
|
|
||||||
|
def get_window_audio_embeddings(full_audio_embs, audio_start_idx=0, clip_length = 81, vae_scale = 4, audio_window = 5):
|
||||||
|
HUMAN_NUMBER = len(full_audio_embs)
|
||||||
|
audio_end_idx = audio_start_idx + clip_length
|
||||||
|
indices = torch.arange(2 * 2 + 1) - 2  # offsets [-2, ..., 2]: a 5-frame audio context window
|
||||||
|
|
||||||
|
audio_embs = []
|
||||||
|
# split audio with window size
|
||||||
|
for human_idx in range(HUMAN_NUMBER):
|
||||||
|
center_indices = torch.arange(
|
||||||
|
audio_start_idx,
|
||||||
|
audio_end_idx,
|
||||||
|
1
|
||||||
|
).unsqueeze(
|
||||||
|
1
|
||||||
|
) + indices.unsqueeze(0)
|
||||||
|
center_indices = torch.clamp(center_indices, min=0, max=full_audio_embs[human_idx].shape[0]-1).to(full_audio_embs[human_idx].device)
|
||||||
|
audio_emb = full_audio_embs[human_idx][center_indices][None,...] #.to(self.device)
|
||||||
|
audio_embs.append(audio_emb)
|
||||||
|
audio_embs = torch.concat(audio_embs, dim=0) #.to(self.param_dtype)
|
||||||
|
|
||||||
|
# audio_cond = audio.to(device=x.device, dtype=x.dtype)
|
||||||
|
audio_cond = audio_embs
|
||||||
|
first_frame_audio_emb_s = audio_cond[:, :1, ...]
|
||||||
|
latter_frame_audio_emb = audio_cond[:, 1:, ...]
|
||||||
|
latter_frame_audio_emb = rearrange(latter_frame_audio_emb, "b (n_t n) w s c -> b n_t n w s c", n=vae_scale)
|
||||||
|
middle_index = audio_window // 2
|
||||||
|
latter_first_frame_audio_emb = latter_frame_audio_emb[:, :, :1, :middle_index+1, ...]
|
||||||
|
latter_first_frame_audio_emb = rearrange(latter_first_frame_audio_emb, "b n_t n w s c -> b n_t (n w) s c")
|
||||||
|
latter_last_frame_audio_emb = latter_frame_audio_emb[:, :, -1:, middle_index:, ...]
|
||||||
|
latter_last_frame_audio_emb = rearrange(latter_last_frame_audio_emb, "b n_t n w s c -> b n_t (n w) s c")
|
||||||
|
latter_middle_frame_audio_emb = latter_frame_audio_emb[:, :, 1:-1, middle_index:middle_index+1, ...]
|
||||||
|
latter_middle_frame_audio_emb = rearrange(latter_middle_frame_audio_emb, "b n_t n w s c -> b n_t (n w) s c")
|
||||||
|
latter_frame_audio_emb_s = torch.concat([latter_first_frame_audio_emb, latter_middle_frame_audio_emb, latter_last_frame_audio_emb], dim=2)
|
||||||
|
|
||||||
|
return [first_frame_audio_emb_s, latter_frame_audio_emb_s]
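# Illustrative sketch of the windowed gather above: each video frame collects a 5-step
# context window of audio features (offsets -2..+2), clamped at the clip boundaries.
# The tensor below stands in for one speaker's full audio embedding (dummy values).
def _example_audio_window_gather():
    full_emb = torch.arange(10).float().unsqueeze(-1)            # [10, 1] dummy features
    offsets = torch.arange(5) - 2                                # [-2, -1, 0, 1, 2]
    centers = torch.arange(0, 4).unsqueeze(1) + offsets.unsqueeze(0)
    centers = torch.clamp(centers, min=0, max=full_emb.shape[0] - 1)
    return full_emb[centers].shape                               # [4, 5, 1]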
|
||||||
|
|
||||||
|
def resize_and_centercrop(cond_image, target_size):
|
||||||
|
"""
|
||||||
|
Resize image or tensor to the target size without padding.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Get the original size
|
||||||
|
if isinstance(cond_image, torch.Tensor):
|
||||||
|
_, orig_h, orig_w = cond_image.shape
|
||||||
|
else:
|
||||||
|
orig_h, orig_w = cond_image.height, cond_image.width
|
||||||
|
|
||||||
|
target_h, target_w = target_size
|
||||||
|
|
||||||
|
# Calculate the scaling factor for resizing
|
||||||
|
scale_h = target_h / orig_h
|
||||||
|
scale_w = target_w / orig_w
|
||||||
|
|
||||||
|
# Compute the final size
|
||||||
|
scale = max(scale_h, scale_w)
|
||||||
|
final_h = math.ceil(scale * orig_h)
|
||||||
|
final_w = math.ceil(scale * orig_w)
|
||||||
|
|
||||||
|
# Resize
|
||||||
|
if isinstance(cond_image, torch.Tensor):
|
||||||
|
if len(cond_image.shape) == 3:
|
||||||
|
cond_image = cond_image[None]
|
||||||
|
resized_tensor = nn.functional.interpolate(cond_image, size=(final_h, final_w), mode='nearest').contiguous()
|
||||||
|
# crop
|
||||||
|
cropped_tensor = transforms.functional.center_crop(resized_tensor, target_size)
|
||||||
|
cropped_tensor = cropped_tensor.squeeze(0)
|
||||||
|
else:
|
||||||
|
resized_image = cond_image.resize((final_w, final_h), resample=Image.BILINEAR)
|
||||||
|
resized_image = np.array(resized_image)
|
||||||
|
# tensor and crop
|
||||||
|
resized_tensor = torch.from_numpy(resized_image)[None, ...].permute(0, 3, 1, 2).contiguous()
|
||||||
|
cropped_tensor = transforms.functional.center_crop(resized_tensor, target_size)
|
||||||
|
cropped_tensor = cropped_tensor[:, :, None, :, :]
|
||||||
|
|
||||||
|
return cropped_tensor
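# Worked example (illustrative numbers): resizing a 480x640 input to a 512x512 target gives
# scale = max(512/480, 512/640) = 1.0667, so the image is resized to 512x683 and then
# center-cropped to 512x512; no padding is ever introduced.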
|
||||||
|
|
||||||
|
|
||||||
|
def timestep_transform(
|
||||||
|
t,
|
||||||
|
shift=5.0,
|
||||||
|
num_timesteps=1000,
|
||||||
|
):
|
||||||
|
t = t / num_timesteps
|
||||||
|
# shift the timestep based on ratio
|
||||||
|
new_t = shift * t / (1 + (shift - 1) * t)
|
||||||
|
new_t = new_t * num_timesteps
|
||||||
|
return new_t
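# Worked example (illustrative): with shift=5.0 and num_timesteps=1000, t=500 maps to
#   5 * 0.5 / (1 + 4 * 0.5) = 0.8333  ->  ~833,
# i.e. timesteps are pushed toward the noisier end of the schedule.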
|
||||||
|
|
||||||
|
|
||||||
|
# construct human mask
|
||||||
|
def get_target_masks(HUMAN_NUMBER, lat_h, lat_w, src_h, src_w, face_scale = 0.05, bbox = None):
|
||||||
|
human_masks = []
|
||||||
|
if HUMAN_NUMBER==1:
|
||||||
|
background_mask = torch.ones([src_h, src_w])
|
||||||
|
human_mask1 = torch.ones([src_h, src_w])
|
||||||
|
human_mask2 = torch.ones([src_h, src_w])
|
||||||
|
human_masks = [human_mask1, human_mask2, background_mask]
|
||||||
|
elif HUMAN_NUMBER==2:
|
||||||
|
if bbox is not None:
    assert len(bbox) == HUMAN_NUMBER, "The number of target bboxes should match the number of audio conditions"
|
||||||
|
background_mask = torch.zeros([src_h, src_w])
|
||||||
|
for _, person_bbox in bbox.items():
|
||||||
|
x_min, y_min, x_max, y_max = person_bbox
|
||||||
|
human_mask = torch.zeros([src_h, src_w])
|
||||||
|
human_mask[int(x_min):int(x_max), int(y_min):int(y_max)] = 1
|
||||||
|
background_mask += human_mask
|
||||||
|
human_masks.append(human_mask)
|
||||||
|
else:
|
||||||
|
x_min, x_max = int(src_h * face_scale), int(src_h * (1 - face_scale))
|
||||||
|
background_mask = torch.zeros([src_h, src_w])
|
||||||
|
human_mask1 = torch.zeros([src_h, src_w])
|
||||||
|
human_mask2 = torch.zeros([src_h, src_w])
|
||||||
|
lefty_min, lefty_max = int((src_w//2) * face_scale), int((src_w//2) * (1 - face_scale))
|
||||||
|
righty_min, righty_max = int((src_w//2) * face_scale + (src_w//2)), int((src_w//2) * (1 - face_scale) + (src_w//2))
|
||||||
|
human_mask1[x_min:x_max, lefty_min:lefty_max] = 1
|
||||||
|
human_mask2[x_min:x_max, righty_min:righty_max] = 1
|
||||||
|
background_mask += human_mask1
|
||||||
|
background_mask += human_mask2
|
||||||
|
human_masks = [human_mask1, human_mask2]
|
||||||
|
background_mask = torch.where(background_mask > 0, torch.tensor(0), torch.tensor(1))
|
||||||
|
human_masks.append(background_mask)
|
||||||
|
|
||||||
|
ref_target_masks = torch.stack(human_masks, dim=0) #.to(self.device)
|
||||||
|
# resize and centercrop for ref_target_masks
|
||||||
|
# ref_target_masks = resize_and_centercrop(ref_target_masks, (target_h, target_w))
|
||||||
|
N_h, N_w = lat_h // 2, lat_w // 2
|
||||||
|
token_ref_target_masks = F.interpolate(ref_target_masks.unsqueeze(0), size=(N_h, N_w), mode='nearest').squeeze()
|
||||||
|
token_ref_target_masks = (token_ref_target_masks > 0)
|
||||||
|
token_ref_target_masks = token_ref_target_masks.float() #.to(self.device)
|
||||||
|
|
||||||
|
token_ref_target_masks = token_ref_target_masks.view(token_ref_target_masks.shape[0], -1)
|
||||||
|
|
||||||
|
return token_ref_target_masks
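# Minimal sanity check (illustrative sizes): with the default two-speaker layout, speaker 1
# occupies the left half of the frame and speaker 2 the right half (minus the face_scale
# margin), and the background mask covers everything else before being downsampled to the
# latent token grid.
def _example_target_masks():
    masks = get_target_masks(HUMAN_NUMBER=2, lat_h=64, lat_w=64, src_h=512, src_w=512)
    return masks.shape   # [3, (64 // 2) * (64 // 2)] = [3, 1024]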
|
||||||
799
wan/multitalk/multitalk_model.py
Normal file
799
wan/multitalk/multitalk_model.py
Normal file
@ -0,0 +1,799 @@
|
|||||||
|
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
|
||||||
|
import math
|
||||||
|
import numpy as np
|
||||||
|
import os
|
||||||
|
import torch
|
||||||
|
import torch.cuda.amp as amp
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
from einops import rearrange
|
||||||
|
from diffusers import ModelMixin
|
||||||
|
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
||||||
|
|
||||||
|
from .attention import flash_attention, SingleStreamMutiAttention
|
||||||
|
from ..utils.multitalk_utils import get_attn_map_with_target
|
||||||
|
|
||||||
|
__all__ = ['WanModel']
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def sinusoidal_embedding_1d(dim, position):
|
||||||
|
# preprocess
|
||||||
|
assert dim % 2 == 0
|
||||||
|
half = dim // 2
|
||||||
|
position = position.type(torch.float64)
|
||||||
|
|
||||||
|
# calculation
|
||||||
|
sinusoid = torch.outer(
|
||||||
|
position, torch.pow(10000, -torch.arange(half).to(position).div(half)))
|
||||||
|
x = torch.cat([torch.cos(sinusoid), torch.sin(sinusoid)], dim=1)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
@amp.autocast(enabled=False)
|
||||||
|
def rope_params(max_seq_len, dim, theta=10000):
|
||||||
|
|
||||||
|
assert dim % 2 == 0
|
||||||
|
freqs = torch.outer(
|
||||||
|
torch.arange(max_seq_len),
|
||||||
|
1.0 / torch.pow(theta,
|
||||||
|
torch.arange(0, dim, 2).to(torch.float64).div(dim)))
|
||||||
|
freqs = torch.polar(torch.ones_like(freqs), freqs)
|
||||||
|
return freqs
|
||||||
|
|
||||||
|
|
||||||
|
@amp.autocast(enabled=False)
|
||||||
|
def rope_apply(x, grid_sizes, freqs):
|
||||||
|
s, n, c = x.size(1), x.size(2), x.size(3) // 2
|
||||||
|
|
||||||
|
freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)
|
||||||
|
|
||||||
|
output = []
|
||||||
|
for i, (f, h, w) in enumerate(grid_sizes.tolist()):
|
||||||
|
seq_len = f * h * w
|
||||||
|
|
||||||
|
x_i = torch.view_as_complex(x[i, :s].to(torch.float64).reshape(
|
||||||
|
s, n, -1, 2))
|
||||||
|
freqs_i = torch.cat([
|
||||||
|
freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
|
||||||
|
freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
|
||||||
|
freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1)
|
||||||
|
],
|
||||||
|
dim=-1).reshape(seq_len, 1, -1)
|
||||||
|
freqs_i = freqs_i.to(device=x_i.device)
|
||||||
|
x_i = torch.view_as_real(x_i * freqs_i).flatten(2)
|
||||||
|
x_i = torch.cat([x_i, x[i, seq_len:]])
|
||||||
|
|
||||||
|
output.append(x_i)
|
||||||
|
return torch.stack(output).float()
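# Note on the split above (illustrative numbers): the per-head half-dim c is divided into one
# temporal and two spatial parts of sizes (c - 2 * (c // 3), c // 3, c // 3); for c = 64 this is
# (22, 21, 21), so frame, height and width each get their own rotary frequencies.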
|
||||||
|
|
||||||
|
|
||||||
|
class WanRMSNorm(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self, dim, eps=1e-5):
|
||||||
|
super().__init__()
|
||||||
|
self.dim = dim
|
||||||
|
self.eps = eps
|
||||||
|
self.weight = nn.Parameter(torch.ones(dim))
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
r"""
|
||||||
|
Args:
|
||||||
|
x(Tensor): Shape [B, L, C]
|
||||||
|
"""
|
||||||
|
return self._norm(x.float()).type_as(x) * self.weight
|
||||||
|
|
||||||
|
def _norm(self, x):
|
||||||
|
return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
|
||||||
|
|
||||||
|
|
||||||
|
class WanLayerNorm(nn.LayerNorm):
|
||||||
|
|
||||||
|
def __init__(self, dim, eps=1e-6, elementwise_affine=False):
|
||||||
|
super().__init__(dim, elementwise_affine=elementwise_affine, eps=eps)
|
||||||
|
|
||||||
|
def forward(self, inputs: torch.Tensor) -> torch.Tensor:
|
||||||
|
origin_dtype = inputs.dtype
|
||||||
|
out = F.layer_norm(
|
||||||
|
inputs.float(),
|
||||||
|
self.normalized_shape,
|
||||||
|
None if self.weight is None else self.weight.float(),
|
||||||
|
None if self.bias is None else self.bias.float(),
|
||||||
|
self.eps
|
||||||
|
).to(origin_dtype)
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
class WanSelfAttention(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
dim,
|
||||||
|
num_heads,
|
||||||
|
window_size=(-1, -1),
|
||||||
|
qk_norm=True,
|
||||||
|
eps=1e-6):
|
||||||
|
assert dim % num_heads == 0
|
||||||
|
super().__init__()
|
||||||
|
self.dim = dim
|
||||||
|
self.num_heads = num_heads
|
||||||
|
self.head_dim = dim // num_heads
|
||||||
|
self.window_size = window_size
|
||||||
|
self.qk_norm = qk_norm
|
||||||
|
self.eps = eps
|
||||||
|
|
||||||
|
# layers
|
||||||
|
self.q = nn.Linear(dim, dim)
|
||||||
|
self.k = nn.Linear(dim, dim)
|
||||||
|
self.v = nn.Linear(dim, dim)
|
||||||
|
self.o = nn.Linear(dim, dim)
|
||||||
|
self.norm_q = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity()
|
||||||
|
self.norm_k = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity()
|
||||||
|
|
||||||
|
def forward(self, x, seq_lens, grid_sizes, freqs, ref_target_masks=None):
|
||||||
|
b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim
|
||||||
|
|
||||||
|
# query, key, value function
|
||||||
|
def qkv_fn(x):
|
||||||
|
q = self.norm_q(self.q(x)).view(b, s, n, d)
|
||||||
|
k = self.norm_k(self.k(x)).view(b, s, n, d)
|
||||||
|
v = self.v(x).view(b, s, n, d)
|
||||||
|
return q, k, v
|
||||||
|
q, k, v = qkv_fn(x)
|
||||||
|
|
||||||
|
q = rope_apply(q, grid_sizes, freqs)
|
||||||
|
k = rope_apply(k, grid_sizes, freqs)
|
||||||
|
|
||||||
|
|
||||||
|
x = flash_attention(
|
||||||
|
q=q,
|
||||||
|
k=k,
|
||||||
|
v=v,
|
||||||
|
k_lens=seq_lens,
|
||||||
|
window_size=self.window_size
|
||||||
|
).type_as(x)
|
||||||
|
|
||||||
|
# output
|
||||||
|
x = x.flatten(2)
|
||||||
|
x = self.o(x)
|
||||||
|
with torch.no_grad():
|
||||||
|
x_ref_attn_map = get_attn_map_with_target(q.type_as(x), k.type_as(x), grid_sizes[0],
|
||||||
|
ref_target_masks=ref_target_masks)
|
||||||
|
|
||||||
|
return x, x_ref_attn_map
|
||||||
|
|
||||||
|
|
||||||
|
class WanI2VCrossAttention(WanSelfAttention):
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
dim,
|
||||||
|
num_heads,
|
||||||
|
window_size=(-1, -1),
|
||||||
|
qk_norm=True,
|
||||||
|
eps=1e-6):
|
||||||
|
super().__init__(dim, num_heads, window_size, qk_norm, eps)
|
||||||
|
|
||||||
|
self.k_img = nn.Linear(dim, dim)
|
||||||
|
self.v_img = nn.Linear(dim, dim)
|
||||||
|
self.norm_k_img = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity()
|
||||||
|
|
||||||
|
def forward(self, x, context, context_lens):
|
||||||
|
context_img = context[:, :257]
|
||||||
|
context = context[:, 257:]
|
||||||
|
b, n, d = x.size(0), self.num_heads, self.head_dim
|
||||||
|
|
||||||
|
# compute query, key, value
|
||||||
|
q = self.norm_q(self.q(x)).view(b, -1, n, d)
|
||||||
|
k = self.norm_k(self.k(context)).view(b, -1, n, d)
|
||||||
|
v = self.v(context).view(b, -1, n, d)
|
||||||
|
k_img = self.norm_k_img(self.k_img(context_img)).view(b, -1, n, d)
|
||||||
|
v_img = self.v_img(context_img).view(b, -1, n, d)
|
||||||
|
img_x = flash_attention(q, k_img, v_img, k_lens=None)
|
||||||
|
# compute attention
|
||||||
|
x = flash_attention(q, k, v, k_lens=context_lens)
|
||||||
|
|
||||||
|
# output
|
||||||
|
x = x.flatten(2)
|
||||||
|
img_x = img_x.flatten(2)
|
||||||
|
x = x + img_x
|
||||||
|
x = self.o(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class WanAttentionBlock(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
cross_attn_type,
|
||||||
|
dim,
|
||||||
|
ffn_dim,
|
||||||
|
num_heads,
|
||||||
|
window_size=(-1, -1),
|
||||||
|
qk_norm=True,
|
||||||
|
cross_attn_norm=False,
|
||||||
|
eps=1e-6,
|
||||||
|
output_dim=768,
|
||||||
|
norm_input_visual=True,
|
||||||
|
class_range=24,
|
||||||
|
class_interval=4):
|
||||||
|
super().__init__()
|
||||||
|
self.dim = dim
|
||||||
|
self.ffn_dim = ffn_dim
|
||||||
|
self.num_heads = num_heads
|
||||||
|
self.window_size = window_size
|
||||||
|
self.qk_norm = qk_norm
|
||||||
|
self.cross_attn_norm = cross_attn_norm
|
||||||
|
self.eps = eps
|
||||||
|
|
||||||
|
# layers
|
||||||
|
self.norm1 = WanLayerNorm(dim, eps)
|
||||||
|
self.self_attn = WanSelfAttention(dim, num_heads, window_size, qk_norm, eps)
|
||||||
|
self.norm3 = WanLayerNorm(
|
||||||
|
dim, eps,
|
||||||
|
elementwise_affine=True) if cross_attn_norm else nn.Identity()
|
||||||
|
self.cross_attn = WanI2VCrossAttention(dim,
|
||||||
|
num_heads,
|
||||||
|
(-1, -1),
|
||||||
|
qk_norm,
|
||||||
|
eps)
|
||||||
|
self.norm2 = WanLayerNorm(dim, eps)
|
||||||
|
self.ffn = nn.Sequential(
|
||||||
|
nn.Linear(dim, ffn_dim), nn.GELU(approximate='tanh'),
|
||||||
|
nn.Linear(ffn_dim, dim))
|
||||||
|
|
||||||
|
# modulation
|
||||||
|
self.modulation = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5)
|
||||||
|
|
||||||
|
# init audio module
|
||||||
|
self.audio_cross_attn = SingleStreamMutiAttention(
|
||||||
|
dim=dim,
|
||||||
|
encoder_hidden_states_dim=output_dim,
|
||||||
|
num_heads=num_heads,
|
||||||
|
qk_norm=False,
|
||||||
|
qkv_bias=True,
|
||||||
|
eps=eps,
|
||||||
|
norm_layer=WanRMSNorm,
|
||||||
|
class_range=class_range,
|
||||||
|
class_interval=class_interval
|
||||||
|
)
|
||||||
|
self.norm_x = WanLayerNorm(dim, eps, elementwise_affine=True) if norm_input_visual else nn.Identity()
|
||||||
|
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
x,
|
||||||
|
e,
|
||||||
|
seq_lens,
|
||||||
|
grid_sizes,
|
||||||
|
freqs,
|
||||||
|
context,
|
||||||
|
context_lens,
|
||||||
|
audio_embedding=None,
|
||||||
|
ref_target_masks=None,
|
||||||
|
human_num=None,
|
||||||
|
):
|
||||||
|
|
||||||
|
dtype = x.dtype
|
||||||
|
assert e.dtype == torch.float32
|
||||||
|
with amp.autocast(dtype=torch.float32):
|
||||||
|
e = (self.modulation.to(e.device) + e).chunk(6, dim=1)
|
||||||
|
assert e[0].dtype == torch.float32
|
||||||
|
|
||||||
|
# self-attention
|
||||||
|
y, x_ref_attn_map = self.self_attn(
|
||||||
|
(self.norm1(x).float() * (1 + e[1]) + e[0]).type_as(x), seq_lens, grid_sizes,
|
||||||
|
freqs, ref_target_masks=ref_target_masks)
|
||||||
|
with amp.autocast(dtype=torch.float32):
|
||||||
|
x = x + y * e[2]
|
||||||
|
|
||||||
|
x = x.to(dtype)
|
||||||
|
|
||||||
|
# cross-attention of text
|
||||||
|
x = x + self.cross_attn(self.norm3(x), context, context_lens)
|
||||||
|
|
||||||
|
# cross attn of audio
|
||||||
|
x_a = self.audio_cross_attn(self.norm_x(x), encoder_hidden_states=audio_embedding,
|
||||||
|
shape=grid_sizes[0], x_ref_attn_map=x_ref_attn_map, human_num=human_num)
|
||||||
|
x = x + x_a
|
||||||
|
|
||||||
|
y = self.ffn((self.norm2(x).float() * (1 + e[4]) + e[3]).to(dtype))
|
||||||
|
with amp.autocast(dtype=torch.float32):
|
||||||
|
x = x + y * e[5]
|
||||||
|
|
||||||
|
|
||||||
|
x = x.to(dtype)
|
||||||
|
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class Head(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self, dim, out_dim, patch_size, eps=1e-6):
|
||||||
|
super().__init__()
|
||||||
|
self.dim = dim
|
||||||
|
self.out_dim = out_dim
|
||||||
|
self.patch_size = patch_size
|
||||||
|
self.eps = eps
|
||||||
|
|
||||||
|
# layers
|
||||||
|
out_dim = math.prod(patch_size) * out_dim
|
||||||
|
self.norm = WanLayerNorm(dim, eps)
|
||||||
|
self.head = nn.Linear(dim, out_dim)
|
||||||
|
|
||||||
|
# modulation
|
||||||
|
self.modulation = nn.Parameter(torch.randn(1, 2, dim) / dim**0.5)
|
||||||
|
|
||||||
|
def forward(self, x, e):
|
||||||
|
r"""
|
||||||
|
Args:
|
||||||
|
x(Tensor): Shape [B, L1, C]
|
||||||
|
e(Tensor): Shape [B, C]
|
||||||
|
"""
|
||||||
|
assert e.dtype == torch.float32
|
||||||
|
with amp.autocast(dtype=torch.float32):
|
||||||
|
e = (self.modulation.to(e.device) + e.unsqueeze(1)).chunk(2, dim=1)
|
||||||
|
x = (self.head(self.norm(x) * (1 + e[1]) + e[0]))
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class MLPProj(torch.nn.Module):
|
||||||
|
|
||||||
|
def __init__(self, in_dim, out_dim):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.proj = torch.nn.Sequential(
|
||||||
|
torch.nn.LayerNorm(in_dim), torch.nn.Linear(in_dim, in_dim),
|
||||||
|
torch.nn.GELU(), torch.nn.Linear(in_dim, out_dim),
|
||||||
|
torch.nn.LayerNorm(out_dim))
|
||||||
|
|
||||||
|
def forward(self, image_embeds):
|
||||||
|
clip_extra_context_tokens = self.proj(image_embeds)
|
||||||
|
return clip_extra_context_tokens


class AudioProjModel(ModelMixin, ConfigMixin):
    def __init__(
        self,
        seq_len=5,
        seq_len_vf=12,
        blocks=12,
        channels=768,
        intermediate_dim=512,
        output_dim=768,
        context_tokens=32,
        norm_output_audio=False,
    ):
        super().__init__()

        self.seq_len = seq_len
        self.blocks = blocks
        self.channels = channels
        self.input_dim = seq_len * blocks * channels
        self.input_dim_vf = seq_len_vf * blocks * channels
        self.intermediate_dim = intermediate_dim
        self.context_tokens = context_tokens
        self.output_dim = output_dim

        # define multiple linear layers
        self.proj1 = nn.Linear(self.input_dim, intermediate_dim)
        self.proj1_vf = nn.Linear(self.input_dim_vf, intermediate_dim)
        self.proj2 = nn.Linear(intermediate_dim, intermediate_dim)
        self.proj3 = nn.Linear(intermediate_dim, context_tokens * output_dim)
        self.norm = nn.LayerNorm(output_dim) if norm_output_audio else nn.Identity()

    def forward(self, audio_embeds, audio_embeds_vf):
        video_length = audio_embeds.shape[1] + audio_embeds_vf.shape[1]
        B, _, _, S, C = audio_embeds.shape

        # process audio of first frame
        audio_embeds = rearrange(audio_embeds, "bz f w b c -> (bz f) w b c")
        batch_size, window_size, blocks, channels = audio_embeds.shape
        audio_embeds = audio_embeds.view(batch_size, window_size * blocks * channels)

        # process audio of latter frame
        audio_embeds_vf = rearrange(audio_embeds_vf, "bz f w b c -> (bz f) w b c")
        batch_size_vf, window_size_vf, blocks_vf, channels_vf = audio_embeds_vf.shape
        audio_embeds_vf = audio_embeds_vf.view(batch_size_vf, window_size_vf * blocks_vf * channels_vf)

        # first projection
        audio_embeds = torch.relu(self.proj1(audio_embeds))
        audio_embeds_vf = torch.relu(self.proj1_vf(audio_embeds_vf))
        audio_embeds = rearrange(audio_embeds, "(bz f) c -> bz f c", bz=B)
        audio_embeds_vf = rearrange(audio_embeds_vf, "(bz f) c -> bz f c", bz=B)
        audio_embeds_c = torch.concat([audio_embeds, audio_embeds_vf], dim=1)
        batch_size_c, N_t, C_a = audio_embeds_c.shape
        audio_embeds_c = audio_embeds_c.view(batch_size_c * N_t, C_a)

        # second projection
        audio_embeds_c = torch.relu(self.proj2(audio_embeds_c))

        context_tokens = self.proj3(audio_embeds_c).reshape(batch_size_c * N_t, self.context_tokens, self.output_dim)

        # normalization and reshape
        context_tokens = self.norm(context_tokens)
        context_tokens = rearrange(context_tokens, "(bz f) m c -> bz f m c", f=video_length)

        return context_tokens
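

# Usage sketch for AudioProjModel (illustrative only; the tensor sizes below are
# assumptions for a single speaker, not values taken from this repository).
# With the defaults, the first latent frame packs seq_len=5 audio frames of
# blocks=12 x channels=768 wav2vec features, each latter latent frame packs
# seq_len_vf=12 frames, and the output is a
# [batch, video_length, context_tokens, output_dim] tensor of audio context tokens.
def _audio_proj_model_shape_check():
    import torch

    proj = AudioProjModel()  # seq_len=5, seq_len_vf=12, blocks=12, channels=768
    first = torch.randn(1, 1, 5, 12, 768)    # [B, 1, seq_len, blocks, channels]
    latter = torch.randn(1, 4, 12, 12, 768)  # [B, T-1, seq_len_vf, blocks, channels]
    tokens = proj(first, latter)
    assert tokens.shape == (1, 5, 32, 768)   # [B, video_length, context_tokens, output_dim]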


class WanModel(ModelMixin, ConfigMixin):
    r"""
    Wan diffusion backbone supporting both text-to-video and image-to-video.
    """

    ignore_for_config = [
        'patch_size', 'cross_attn_norm', 'qk_norm', 'text_dim', 'window_size'
    ]
    _no_split_modules = ['WanAttentionBlock']

    @register_to_config
    def __init__(self,
                 model_type='i2v',
                 patch_size=(1, 2, 2),
                 text_len=512,
                 in_dim=16,
                 dim=2048,
                 ffn_dim=8192,
                 freq_dim=256,
                 text_dim=4096,
                 out_dim=16,
                 num_heads=16,
                 num_layers=32,
                 window_size=(-1, -1),
                 qk_norm=True,
                 cross_attn_norm=True,
                 eps=1e-6,
                 # audio params
                 audio_window=5,
                 intermediate_dim=512,
                 output_dim=768,
                 context_tokens=32,
                 vae_scale=4,  # VAE temporal downsample scale
                 norm_input_visual=True,
                 norm_output_audio=True):
        super().__init__()

        assert model_type == 'i2v', 'MultiTalk model requires model_type to be i2v.'
        self.model_type = model_type

        self.patch_size = patch_size
        self.text_len = text_len
        self.in_dim = in_dim
        self.dim = dim
        self.ffn_dim = ffn_dim
        self.freq_dim = freq_dim
        self.text_dim = text_dim
        self.out_dim = out_dim
        self.num_heads = num_heads
        self.num_layers = num_layers
        self.window_size = window_size
        self.qk_norm = qk_norm
        self.cross_attn_norm = cross_attn_norm
        self.eps = eps

        self.norm_output_audio = norm_output_audio
        self.audio_window = audio_window
        self.intermediate_dim = intermediate_dim
        self.vae_scale = vae_scale

        # embeddings
        self.patch_embedding = nn.Conv3d(
            in_dim, dim, kernel_size=patch_size, stride=patch_size)
        self.text_embedding = nn.Sequential(
            nn.Linear(text_dim, dim), nn.GELU(approximate='tanh'),
            nn.Linear(dim, dim))

        self.time_embedding = nn.Sequential(
            nn.Linear(freq_dim, dim), nn.SiLU(), nn.Linear(dim, dim))
        self.time_projection = nn.Sequential(nn.SiLU(), nn.Linear(dim, dim * 6))

        # blocks
        cross_attn_type = 'i2v_cross_attn'
        self.blocks = nn.ModuleList([
            WanAttentionBlock(cross_attn_type, dim, ffn_dim, num_heads,
                              window_size, qk_norm, cross_attn_norm, eps,
                              output_dim=output_dim, norm_input_visual=norm_input_visual)
            for _ in range(num_layers)
        ])

        # head
        self.head = Head(dim, out_dim, patch_size, eps)

        assert (dim % num_heads) == 0 and (dim // num_heads) % 2 == 0
        d = dim // num_heads
        self.freqs = torch.cat([
            rope_params(1024, d - 4 * (d // 6)),
            rope_params(1024, 2 * (d // 6)),
            rope_params(1024, 2 * (d // 6))
        ], dim=1)

        if model_type == 'i2v':
            self.img_emb = MLPProj(1280, dim)
        else:
            raise NotImplementedError('Not supported model type.')

        # init audio adapter
        self.audio_proj = AudioProjModel(
            seq_len=audio_window,
            seq_len_vf=audio_window + vae_scale - 1,
            intermediate_dim=intermediate_dim,
            output_dim=output_dim,
            context_tokens=context_tokens,
            norm_output_audio=norm_output_audio,
        )

        # initialize weights
        self.init_weights()

    def teacache_init(
        self,
        use_ret_steps=True,
        teacache_thresh=0.2,
        sample_steps=40,
        model_scale='multitalk-480',
    ):
        print("teacache_init")
        self.enable_teacache = True

        self.__class__.cnt = 0
        self.__class__.num_steps = sample_steps * 3
        self.__class__.teacache_thresh = teacache_thresh
        self.__class__.accumulated_rel_l1_distance_even = 0
        self.__class__.accumulated_rel_l1_distance_odd = 0
        self.__class__.previous_e0_even = None
        self.__class__.previous_e0_odd = None
        self.__class__.previous_residual_even = None
        self.__class__.previous_residual_odd = None
        self.__class__.use_ret_steps = use_ret_steps

        if use_ret_steps:
            if model_scale == 'multitalk-480':
                self.__class__.coefficients = [2.57151496e+05, -3.54229917e+04, 1.40286849e+03, -1.35890334e+01, 1.32517977e-01]
            if model_scale == 'multitalk-720':
                self.__class__.coefficients = [8.10705460e+03, 2.13393892e+03, -3.72934672e+02, 1.66203073e+01, -4.17769401e-02]
            self.__class__.ret_steps = 5 * 3
            self.__class__.cutoff_steps = sample_steps * 3
        else:
            if model_scale == 'multitalk-480':
                self.__class__.coefficients = [-3.02331670e+02, 2.23948934e+02, -5.25463970e+01, 5.87348440e+00, -2.01973289e-01]
            if model_scale == 'multitalk-720':
                self.__class__.coefficients = [-114.36346466, 65.26524496, -18.82220707, 4.91518089, -0.23412683]
            self.__class__.ret_steps = 1 * 3
            self.__class__.cutoff_steps = sample_steps * 3 - 3
        print("teacache_init done")

    def disable_teacache(self):
        self.enable_teacache = False

    def forward(
        self,
        x,
        t,
        context,
        seq_len,
        clip_fea=None,
        y=None,
        audio=None,
        ref_target_masks=None,
    ):
        assert clip_fea is not None and y is not None

        _, T, H, W = x[0].shape
        N_t = T // self.patch_size[0]
        N_h = H // self.patch_size[1]
        N_w = W // self.patch_size[2]

        if y is not None:
            x = [torch.cat([u, v], dim=0) for u, v in zip(x, y)]
        x[0] = x[0].to(context[0].dtype)

        # embeddings
        x = [self.patch_embedding(u.unsqueeze(0)) for u in x]
        grid_sizes = torch.stack(
            [torch.tensor(u.shape[2:], dtype=torch.long) for u in x])
        x = [u.flatten(2).transpose(1, 2) for u in x]
        seq_lens = torch.tensor([u.size(1) for u in x], dtype=torch.long)
        assert seq_lens.max() <= seq_len
        x = torch.cat([
            torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))],
                      dim=1) for u in x
        ])

        # time embeddings
        with amp.autocast(dtype=torch.float32):
            e = self.time_embedding(
                sinusoidal_embedding_1d(self.freq_dim, t).float())
            e0 = self.time_projection(e).unflatten(1, (6, self.dim))
            assert e.dtype == torch.float32 and e0.dtype == torch.float32

        # text embedding
        context_lens = None
        context = self.text_embedding(
            torch.stack([
                torch.cat(
                    [u, u.new_zeros(self.text_len - u.size(0), u.size(1))])
                for u in context
            ]))

        # clip embedding
        if clip_fea is not None:
            context_clip = self.img_emb(clip_fea)
            context = torch.concat([context_clip, context], dim=1).to(x.dtype)

        audio_cond = audio.to(device=x.device, dtype=x.dtype)
        first_frame_audio_emb_s = audio_cond[:, :1, ...]
        latter_frame_audio_emb = audio_cond[:, 1:, ...]
        latter_frame_audio_emb = rearrange(latter_frame_audio_emb, "b (n_t n) w s c -> b n_t n w s c", n=self.vae_scale)
        middle_index = self.audio_window // 2
        latter_first_frame_audio_emb = latter_frame_audio_emb[:, :, :1, :middle_index+1, ...]
        latter_first_frame_audio_emb = rearrange(latter_first_frame_audio_emb, "b n_t n w s c -> b n_t (n w) s c")
        latter_last_frame_audio_emb = latter_frame_audio_emb[:, :, -1:, middle_index:, ...]
        latter_last_frame_audio_emb = rearrange(latter_last_frame_audio_emb, "b n_t n w s c -> b n_t (n w) s c")
        latter_middle_frame_audio_emb = latter_frame_audio_emb[:, :, 1:-1, middle_index:middle_index+1, ...]
        latter_middle_frame_audio_emb = rearrange(latter_middle_frame_audio_emb, "b n_t n w s c -> b n_t (n w) s c")
        latter_frame_audio_emb_s = torch.concat([latter_first_frame_audio_emb, latter_middle_frame_audio_emb, latter_last_frame_audio_emb], dim=2)
        audio_embedding = self.audio_proj(first_frame_audio_emb_s, latter_frame_audio_emb_s)
        human_num = len(audio_embedding)
        audio_embedding = torch.concat(audio_embedding.split(1), dim=2).to(x.dtype)

        # convert ref_target_masks to token_ref_target_masks
        if ref_target_masks is not None:
            ref_target_masks = ref_target_masks.unsqueeze(0).to(torch.float32)
            token_ref_target_masks = nn.functional.interpolate(ref_target_masks, size=(N_h, N_w), mode='nearest')
            token_ref_target_masks = token_ref_target_masks.squeeze(0)
            token_ref_target_masks = (token_ref_target_masks > 0)
            token_ref_target_masks = token_ref_target_masks.view(token_ref_target_masks.shape[0], -1)
            token_ref_target_masks = token_ref_target_masks.to(x.dtype)

        # teacache
        if self.enable_teacache:
            modulated_inp = e0 if self.use_ret_steps else e
            if self.cnt % 3 == 0:  # cond
                if self.cnt < self.ret_steps or self.cnt >= self.cutoff_steps:
                    should_calc_cond = True
                    self.accumulated_rel_l1_distance_cond = 0
                else:
                    rescale_func = np.poly1d(self.coefficients)
                    self.accumulated_rel_l1_distance_cond += rescale_func(((modulated_inp - self.previous_e0_cond).abs().mean() / self.previous_e0_cond.abs().mean()).cpu().item())
                    if self.accumulated_rel_l1_distance_cond < self.teacache_thresh:
                        should_calc_cond = False
                    else:
                        should_calc_cond = True
                        self.accumulated_rel_l1_distance_cond = 0
                self.previous_e0_cond = modulated_inp.clone()
            elif self.cnt % 3 == 1:  # drop_text
                if self.cnt < self.ret_steps or self.cnt >= self.cutoff_steps:
                    should_calc_drop_text = True
                    self.accumulated_rel_l1_distance_drop_text = 0
                else:
                    rescale_func = np.poly1d(self.coefficients)
                    self.accumulated_rel_l1_distance_drop_text += rescale_func(((modulated_inp - self.previous_e0_drop_text).abs().mean() / self.previous_e0_drop_text.abs().mean()).cpu().item())
                    if self.accumulated_rel_l1_distance_drop_text < self.teacache_thresh:
                        should_calc_drop_text = False
                    else:
                        should_calc_drop_text = True
                        self.accumulated_rel_l1_distance_drop_text = 0
                self.previous_e0_drop_text = modulated_inp.clone()
            else:  # uncond
                if self.cnt < self.ret_steps or self.cnt >= self.cutoff_steps:
                    should_calc_uncond = True
                    self.accumulated_rel_l1_distance_uncond = 0
                else:
                    rescale_func = np.poly1d(self.coefficients)
                    self.accumulated_rel_l1_distance_uncond += rescale_func(((modulated_inp - self.previous_e0_uncond).abs().mean() / self.previous_e0_uncond.abs().mean()).cpu().item())
                    if self.accumulated_rel_l1_distance_uncond < self.teacache_thresh:
                        should_calc_uncond = False
                    else:
                        should_calc_uncond = True
                        self.accumulated_rel_l1_distance_uncond = 0
                self.previous_e0_uncond = modulated_inp.clone()

        # arguments
        kwargs = dict(
            e=e0,
            seq_lens=seq_lens,
            grid_sizes=grid_sizes,
            freqs=self.freqs,
            context=context,
            context_lens=context_lens,
            audio_embedding=audio_embedding,
            ref_target_masks=token_ref_target_masks,
            human_num=human_num,
        )
        if self.enable_teacache:
            if self.cnt % 3 == 0:
                if not should_calc_cond:
                    x += self.previous_residual_cond
                else:
                    ori_x = x.clone()
                    for block in self.blocks:
                        x = block(x, **kwargs)
                    self.previous_residual_cond = x - ori_x
            elif self.cnt % 3 == 1:
                if not should_calc_drop_text:
                    x += self.previous_residual_drop_text
                else:
                    ori_x = x.clone()
                    for block in self.blocks:
                        x = block(x, **kwargs)
                    self.previous_residual_drop_text = x - ori_x
            else:
                if not should_calc_uncond:
                    x += self.previous_residual_uncond
                else:
                    ori_x = x.clone()
                    for block in self.blocks:
                        x = block(x, **kwargs)
                    self.previous_residual_uncond = x - ori_x
        else:
            for block in self.blocks:
                x = block(x, **kwargs)

        # head
        x = self.head(x, e)

        # unpatchify
        x = self.unpatchify(x, grid_sizes)
        if self.enable_teacache:
            self.cnt += 1
            if self.cnt >= self.num_steps:
                self.cnt = 0

        return torch.stack(x).float()

    def unpatchify(self, x, grid_sizes):
        r"""
        Reconstruct video tensors from patch embeddings.

        Args:
            x (List[Tensor]):
                List of patchified features, each with shape [L, C_out * prod(patch_size)]
            grid_sizes (Tensor):
                Original spatial-temporal grid dimensions before patching,
                shape [B, 3] (3 dimensions correspond to F_patches, H_patches, W_patches)

        Returns:
            List[Tensor]:
                Reconstructed video tensors with shape [C_out, F, H / 8, W / 8]
        """

        c = self.out_dim
        out = []
        for u, v in zip(x, grid_sizes.tolist()):
            u = u[:math.prod(v)].view(*v, *self.patch_size, c)
            u = torch.einsum('fhwpqrc->cfphqwr', u)
            u = u.reshape(c, *[i * j for i, j in zip(v, self.patch_size)])
            out.append(u)
        return out

    def init_weights(self):
        r"""
        Initialize model parameters using Xavier initialization.
        """

        # basic init
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

        # init embeddings
        nn.init.xavier_uniform_(self.patch_embedding.weight.flatten(1))
        for m in self.text_embedding.modules():
            if isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, std=.02)
        for m in self.time_embedding.modules():
            if isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, std=.02)

        # init output layer
        nn.init.zeros_(self.head.head.weight)
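

# Minimal sketch of the TeaCache skip decision used in WanModel.forward, written
# against dummy tensors (the coefficients and threshold below are the
# 'multitalk-480' ret-steps values from teacache_init; everything else is an
# assumption for illustration). A block pass is skipped when the accumulated,
# polynomial-rescaled relative L1 change of the modulated timestep embedding
# stays under the threshold, and the cached residual is reused instead.
def _teacache_skip_decision_sketch():
    import numpy as np
    import torch

    coefficients = [2.57151496e+05, -3.54229917e+04, 1.40286849e+03, -1.35890334e+01, 1.32517977e-01]
    rescale_func = np.poly1d(coefficients)
    teacache_thresh = 0.2

    previous_e0 = torch.randn(1, 6, 2048)
    current_e0 = previous_e0 + 1e-4 * torch.randn_like(previous_e0)  # nearly unchanged step

    rel_l1 = ((current_e0 - previous_e0).abs().mean() / previous_e0.abs().mean()).item()
    accumulated = rescale_func(rel_l1)
    should_calc = accumulated >= teacache_thresh  # False here -> reuse the cached residual
    return should_calc


# Shape walk-through of WanModel.unpatchify for a single sample, assuming
# patch_size=(1, 2, 2), out_dim=16 and a latent patch grid of (9, 30, 52).
def _unpatchify_shape_check():
    import math
    import torch

    patch_size, c = (1, 2, 2), 16
    grid = (9, 30, 52)
    u = torch.randn(math.prod(grid), math.prod(patch_size) * c)
    u = u.view(*grid, *patch_size, c)
    u = torch.einsum('fhwpqrc->cfphqwr', u)
    u = u.reshape(c, *[i * j for i, j in zip(grid, patch_size)])
    assert u.shape == (16, 9, 60, 104)
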
353
wan/multitalk/multitalk_utils.py
Normal file
@ -0,0 +1,353 @@
import os
from einops import rearrange

import torch
import torch.nn as nn

from einops import rearrange, repeat
from functools import lru_cache
import imageio
import uuid
from tqdm import tqdm
import numpy as np
import subprocess
import soundfile as sf
import torchvision
import binascii
import os.path as osp


VID_EXTENSIONS = (".mp4", ".avi", ".mov", ".mkv")
ASPECT_RATIO_627 = {
    '0.26': ([320, 1216], 1), '0.38': ([384, 1024], 1), '0.50': ([448, 896], 1), '0.67': ([512, 768], 1),
    '0.82': ([576, 704], 1), '1.00': ([640, 640], 1), '1.22': ([704, 576], 1), '1.50': ([768, 512], 1),
    '1.86': ([832, 448], 1), '2.00': ([896, 448], 1), '2.50': ([960, 384], 1), '2.83': ([1088, 384], 1),
    '3.60': ([1152, 320], 1), '3.80': ([1216, 320], 1), '4.00': ([1280, 320], 1)}


ASPECT_RATIO_960 = {
    '0.22': ([448, 2048], 1), '0.29': ([512, 1792], 1), '0.36': ([576, 1600], 1), '0.45': ([640, 1408], 1),
    '0.55': ([704, 1280], 1), '0.63': ([768, 1216], 1), '0.76': ([832, 1088], 1), '0.88': ([896, 1024], 1),
    '1.00': ([960, 960], 1), '1.14': ([1024, 896], 1), '1.31': ([1088, 832], 1), '1.50': ([1152, 768], 1),
    '1.58': ([1216, 768], 1), '1.82': ([1280, 704], 1), '1.91': ([1344, 704], 1), '2.20': ([1408, 640], 1),
    '2.30': ([1472, 640], 1), '2.67': ([1536, 576], 1), '2.89': ([1664, 576], 1), '3.62': ([1856, 512], 1),
    '3.75': ([1920, 512], 1)}


def torch_gc():
    torch.cuda.empty_cache()
    torch.cuda.ipc_collect()


def split_token_counts_and_frame_ids(T, token_frame, world_size, rank):

    S = T * token_frame
    split_sizes = [S // world_size + (1 if i < S % world_size else 0) for i in range(world_size)]
    start = sum(split_sizes[:rank])
    end = start + split_sizes[rank]
    counts = [0] * T
    for idx in range(start, end):
        t = idx // token_frame
        counts[t] += 1

    counts_filtered = []
    frame_ids = []
    for t, c in enumerate(counts):
        if c > 0:
            counts_filtered.append(c)
            frame_ids.append(t)
    return counts_filtered, frame_ids
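

# Worked example for split_token_counts_and_frame_ids (hypothetical numbers):
# T=3 latent frames with token_frame=4 tokens each gives S=12 tokens. With
# world_size=2 the split sizes are [6, 6]; rank 0 owns tokens 0..5, i.e. all 4
# tokens of frame 0 and 2 tokens of frame 1.
def _split_token_counts_example():
    counts, frame_ids = split_token_counts_and_frame_ids(T=3, token_frame=4, world_size=2, rank=0)
    assert counts == [4, 2] and frame_ids == [0, 1]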


def normalize_and_scale(column, source_range, target_range, epsilon=1e-8):

    source_min, source_max = source_range
    new_min, new_max = target_range

    normalized = (column - source_min) / (source_max - source_min + epsilon)
    scaled = normalized * (new_max - new_min) + new_min
    return scaled
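

# Quick sanity check: mapping [0, 10] onto [-1, 1] sends 0 -> -1, 5 -> 0 and
# 10 -> 1 (up to the epsilon used to avoid division by zero).
def _normalize_and_scale_example():
    import torch

    x = torch.tensor([0.0, 5.0, 10.0])
    y = normalize_and_scale(x, (0.0, 10.0), (-1.0, 1.0))
    assert torch.allclose(y, torch.tensor([-1.0, 0.0, 1.0]), atol=1e-6)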


# @torch.compile
def calculate_x_ref_attn_map(visual_q, ref_k, ref_target_masks, ref_images_count, mode='mean', attn_bias=None):

    ref_k = ref_k.to(visual_q.dtype).to(visual_q.device)
    scale = 1.0 / visual_q.shape[-1] ** 0.5
    visual_q = visual_q * scale
    visual_q = visual_q.transpose(1, 2)
    ref_k = ref_k.transpose(1, 2)
    attn = visual_q @ ref_k.transpose(-2, -1)

    if attn_bias is not None:
        attn += attn_bias

    x_ref_attn_map_source = attn.softmax(-1)  # B, H, x_seqlens, ref_seqlens

    x_ref_attn_maps = []
    ref_target_masks = ref_target_masks.to(visual_q.dtype)
    x_ref_attn_map_source = x_ref_attn_map_source.to(visual_q.dtype)

    for class_idx, ref_target_mask in enumerate(ref_target_masks):
        ref_target_mask = ref_target_mask[None, None, None, ...]
        x_ref_attnmap = x_ref_attn_map_source * ref_target_mask
        x_ref_attnmap = x_ref_attnmap.sum(-1) / ref_target_mask.sum()  # B, H, x_seqlens, ref_seqlens --> B, H, x_seqlens
        x_ref_attnmap = x_ref_attnmap.permute(0, 2, 1)  # B, x_seqlens, H

        if mode == 'mean':
            x_ref_attnmap = x_ref_attnmap.mean(-1)  # B, x_seqlens
        elif mode == 'max':
            x_ref_attnmap = x_ref_attnmap.max(-1).values  # B, x_seqlens (torch.max returns (values, indices))

        x_ref_attn_maps.append(x_ref_attnmap)

    del attn
    del x_ref_attn_map_source

    return torch.concat(x_ref_attn_maps, dim=0)


def get_attn_map_with_target(visual_q, ref_k, shape, ref_target_masks=None, split_num=10, ref_images_count=0):
    """Args:
        query (torch.tensor): B M H K
        key (torch.tensor): B M H K
        shape (tuple): (N_t, N_h, N_w)
        ref_target_masks: [B, N_h * N_w]
    """

    N_t, N_h, N_w = shape

    x_seqlens = N_h * N_w
    ref_k = ref_k[:, :x_seqlens]
    if ref_images_count > 0:
        visual_q_shape = visual_q.shape
        visual_q = visual_q.reshape(visual_q_shape[0], N_t, -1)
        visual_q = visual_q[:, ref_images_count:]
        visual_q = visual_q.reshape(visual_q_shape[0], -1, *visual_q_shape[-2:])

    _, seq_lens, heads, _ = visual_q.shape
    class_num, _ = ref_target_masks.shape
    x_ref_attn_maps = torch.zeros(class_num, seq_lens).to(visual_q.device).to(visual_q.dtype)

    split_chunk = heads // split_num

    for i in range(split_num):
        x_ref_attn_maps_perhead = calculate_x_ref_attn_map(visual_q[:, :, i*split_chunk:(i+1)*split_chunk, :], ref_k[:, :, i*split_chunk:(i+1)*split_chunk, :], ref_target_masks, ref_images_count)
        x_ref_attn_maps += x_ref_attn_maps_perhead

    x_ref_attn_maps /= split_num
    return x_ref_attn_maps
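

# Shape sketch for get_attn_map_with_target with made-up sizes: a latent grid of
# (N_t, N_h, N_w) = (2, 4, 4), 10 attention heads (so split_num=10 uses one head
# per chunk) and two speaker masks over the N_h * N_w reference tokens. The result
# is one attention map per speaker over the full visual token sequence.
def _x_ref_attn_map_shape_check():
    import torch

    N_t, N_h, N_w, heads, head_dim = 2, 4, 4, 10, 64
    x_seqlens = N_h * N_w
    visual_q = torch.randn(1, N_t * x_seqlens, heads, head_dim)
    ref_k = torch.randn(1, N_t * x_seqlens, heads, head_dim)
    masks = torch.zeros(2, x_seqlens)
    masks[0, :8] = 1.0   # speaker 0 occupies the left half of the reference frame
    masks[1, 8:] = 1.0   # speaker 1 occupies the right half
    maps = get_attn_map_with_target(visual_q, ref_k, (N_t, N_h, N_w), ref_target_masks=masks)
    assert maps.shape == (2, N_t * x_seqlens)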


def rotate_half(x):
    x = rearrange(x, "... (d r) -> ... d r", r=2)
    x1, x2 = x.unbind(dim=-1)
    x = torch.stack((-x2, x1), dim=-1)
    return rearrange(x, "... d r -> ... (d r)")


class RotaryPositionalEmbedding1D(nn.Module):

    def __init__(self,
                 head_dim,
                 ):
        super().__init__()
        self.head_dim = head_dim
        self.base = 10000

    @lru_cache(maxsize=32)
    def precompute_freqs_cis_1d(self, pos_indices):

        freqs = 1.0 / (self.base ** (torch.arange(0, self.head_dim, 2)[: (self.head_dim // 2)].float() / self.head_dim))
        freqs = freqs.to(pos_indices.device)
        freqs = torch.einsum("..., f -> ... f", pos_indices.float(), freqs)
        freqs = repeat(freqs, "... n -> ... (n r)", r=2)
        return freqs

    def forward(self, x, pos_indices):
        """1D RoPE.

        Args:
            query (torch.tensor): [B, head, seq, head_dim]
            pos_indices (torch.tensor): [seq,]
        Returns:
            query with the same shape as input.
        """
        freqs_cis = self.precompute_freqs_cis_1d(pos_indices)

        x_ = x.float()

        freqs_cis = freqs_cis.float().to(x.device)
        cos, sin = freqs_cis.cos(), freqs_cis.sin()
        cos, sin = rearrange(cos, 'n d -> 1 1 n d'), rearrange(sin, 'n d -> 1 1 n d')
        x_ = (x_ * cos) + (rotate_half(x_) * sin)

        return x_.type_as(x)
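

# Small sanity check for the 1D RoPE above (dummy sizes): rotating each
# (even, odd) feature pair by a position-dependent angle preserves the per-token
# norm of the query/key vectors.
def _rope_1d_norm_check():
    import torch

    rope = RotaryPositionalEmbedding1D(head_dim=64)
    x = torch.randn(1, 2, 8, 64)   # [B, head, seq, head_dim]
    pos = torch.arange(8)          # [seq]
    y = rope(x, pos)
    assert y.shape == x.shape
    assert torch.allclose(x.norm(dim=-1), y.norm(dim=-1), atol=1e-4)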


def rand_name(length=8, suffix=''):
    name = binascii.b2a_hex(os.urandom(length)).decode('utf-8')
    if suffix:
        if not suffix.startswith('.'):
            suffix = '.' + suffix
        name += suffix
    return name


def cache_video(tensor,
                save_file=None,
                fps=30,
                suffix='.mp4',
                nrow=8,
                normalize=True,
                value_range=(-1, 1),
                retry=5):

    # cache file
    cache_file = osp.join('/tmp', rand_name(
        suffix=suffix)) if save_file is None else save_file

    # save to cache
    error = None
    for _ in range(retry):

        # preprocess
        tensor = tensor.clamp(min(value_range), max(value_range))
        tensor = torch.stack([
            torchvision.utils.make_grid(
                u, nrow=nrow, normalize=normalize, value_range=value_range)
            for u in tensor.unbind(2)
        ], dim=1).permute(1, 2, 3, 0)
        tensor = (tensor * 255).type(torch.uint8).cpu()

        # write video
        writer = imageio.get_writer(cache_file, fps=fps, codec='libx264', quality=10, ffmpeg_params=["-crf", "10"])
        for frame in tensor.numpy():
            writer.append_data(frame)
        writer.close()
        return cache_file


def save_video_ffmpeg(gen_video_samples, save_path, vocal_audio_list, fps=25, quality=5, high_quality_save=False):

    def save_video(frames, save_path, fps, quality=9, ffmpeg_params=None):
        writer = imageio.get_writer(
            save_path, fps=fps, quality=quality, ffmpeg_params=ffmpeg_params
        )
        for frame in tqdm(frames, desc="Saving video"):
            frame = np.array(frame)
            writer.append_data(frame)
        writer.close()

    save_path_tmp = save_path + "-temp.mp4"

    if high_quality_save:
        cache_video(
            tensor=gen_video_samples.unsqueeze(0),
            save_file=save_path_tmp,
            fps=fps,
            nrow=1,
            normalize=True,
            value_range=(-1, 1)
        )
    else:
        video_audio = (gen_video_samples + 1) / 2  # C T H W
        video_audio = video_audio.permute(1, 2, 3, 0).cpu().numpy()
        video_audio = np.clip(video_audio * 255, 0, 255).astype(np.uint8)  # to [0, 255]
        save_video(video_audio, save_path_tmp, fps=fps, quality=quality)

    # crop audio according to video length
    _, T, _, _ = gen_video_samples.shape
    duration = T / fps
    save_path_crop_audio = save_path + "-cropaudio.wav"
    final_command = [
        "ffmpeg",
        "-i",
        vocal_audio_list[0],
        "-t",
        f'{duration}',
        save_path_crop_audio,
    ]
    subprocess.run(final_command, check=True)

    save_path = save_path + ".mp4"
    if high_quality_save:
        final_command = [
            "ffmpeg",
            "-y",
            "-i", save_path_tmp,
            "-i", save_path_crop_audio,
            "-c:v", "libx264",
            "-crf", "0",
            "-preset", "veryslow",
            "-c:a", "aac",
            "-shortest",
            save_path,
        ]
        subprocess.run(final_command, check=True)
        os.remove(save_path_tmp)
        os.remove(save_path_crop_audio)
    else:
        final_command = [
            "ffmpeg",
            "-y",
            "-i",
            save_path_tmp,
            "-i",
            save_path_crop_audio,
            "-c:v",
            "libx264",
            "-c:a",
            "aac",
            "-shortest",
            save_path,
        ]
        subprocess.run(final_command, check=True)
        os.remove(save_path_tmp)
        os.remove(save_path_crop_audio)


class MomentumBuffer:
    def __init__(self, momentum: float):
        self.momentum = momentum
        self.running_average = 0

    def update(self, update_value: torch.Tensor):
        new_average = self.momentum * self.running_average
        self.running_average = update_value + new_average


def project(
        v0: torch.Tensor,  # [B, C, T, H, W]
        v1: torch.Tensor,  # [B, C, T, H, W]
):
    dtype = v0.dtype
    v0, v1 = v0.double(), v1.double()
    v1 = torch.nn.functional.normalize(v1, dim=[-1, -2, -3, -4])
    v0_parallel = (v0 * v1).sum(dim=[-1, -2, -3, -4], keepdim=True) * v1
    v0_orthogonal = v0 - v0_parallel
    return v0_parallel.to(dtype), v0_orthogonal.to(dtype)


def adaptive_projected_guidance(
        diff: torch.Tensor,  # [B, C, T, H, W]
        pred_cond: torch.Tensor,  # [B, C, T, H, W]
        momentum_buffer: MomentumBuffer = None,
        eta: float = 0.0,
        norm_threshold: float = 55,
):
    if momentum_buffer is not None:
        momentum_buffer.update(diff)
        diff = momentum_buffer.running_average
    if norm_threshold > 0:
        ones = torch.ones_like(diff)
        diff_norm = diff.norm(p=2, dim=[-1, -2, -3, -4], keepdim=True)
        print(f"diff_norm: {diff_norm}")
        scale_factor = torch.minimum(ones, norm_threshold / diff_norm)
        diff = diff * scale_factor
    diff_parallel, diff_orthogonal = project(diff, pred_cond)
    normalized_update = diff_orthogonal + eta * diff_parallel
    return normalized_update
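

# Usage sketch for adaptive_projected_guidance with dummy [B, C, T, H, W]
# predictions. The momentum value and the way the update is folded back into a
# guided prediction below are illustrative assumptions, not taken from this file.
def _apg_update_sketch():
    import torch

    pred_cond = torch.randn(1, 16, 4, 8, 8)
    pred_uncond = torch.randn(1, 16, 4, 8, 8)
    buf = MomentumBuffer(momentum=-0.75)

    update = adaptive_projected_guidance(pred_cond - pred_uncond, pred_cond, momentum_buffer=buf)
    guidance_scale = 5.0
    guided = pred_cond + (guidance_scale - 1.0) * update  # one plausible CFG-style combination
    assert guided.shape == pred_cond.shape
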
20
wan/multitalk/torch_utils.py
Normal file
@ -0,0 +1,20 @@
import torch
import torch.nn.functional as F


def get_mask_from_lengths(lengths, max_len=None):
    lengths = lengths.to(torch.long)
    if max_len is None:
        max_len = torch.max(lengths).item()

    ids = torch.arange(0, max_len).unsqueeze(0).expand(lengths.shape[0], -1).to(lengths.device)
    mask = ids < lengths.unsqueeze(1).expand(-1, max_len)

    return mask


def linear_interpolation(features, seq_len):
    features = features.transpose(1, 2)
    output_features = F.interpolate(features, size=seq_len, align_corners=True, mode='linear')
    return output_features.transpose(1, 2)
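

# Quick examples for the two helpers above (dummy sizes): get_mask_from_lengths
# builds a boolean padding mask from per-item lengths, and linear_interpolation
# resamples a [B, T, C] feature sequence to a new temporal length.
def _torch_utils_examples():
    lengths = torch.tensor([2, 4])
    mask = get_mask_from_lengths(lengths)
    assert mask.tolist() == [[True, True, False, False], [True, True, True, True]]

    feats = torch.randn(1, 50, 768)                # e.g. audio encoder features
    resampled = linear_interpolation(feats, seq_len=81)
    assert resampled.shape == (1, 81, 768)
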
125
wan/multitalk/wav2vec2.py
Normal file
@ -0,0 +1,125 @@
from transformers import Wav2Vec2Config, Wav2Vec2Model
from transformers.modeling_outputs import BaseModelOutput

from .torch_utils import linear_interpolation

# the implementation of Wav2Vec2Model is borrowed from
# https://github.com/huggingface/transformers/blob/HEAD/src/transformers/models/wav2vec2/modeling_wav2vec2.py
# initialize our encoder with the pre-trained wav2vec 2.0 weights.
class Wav2Vec2Model(Wav2Vec2Model):
    def __init__(self, config: Wav2Vec2Config):
        super().__init__(config)

    def forward(
        self,
        input_values,
        seq_len,
        attention_mask=None,
        mask_time_indices=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        self.config.output_attentions = True

        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        extract_features = self.feature_extractor(input_values)
        extract_features = extract_features.transpose(1, 2)
        extract_features = linear_interpolation(extract_features, seq_len=seq_len)

        if attention_mask is not None:
            # compute reduced attention_mask corresponding to feature vectors
            attention_mask = self._get_feature_vector_attention_mask(
                extract_features.shape[1], attention_mask, add_adapter=False
            )

        hidden_states, extract_features = self.feature_projection(extract_features)
        hidden_states = self._mask_hidden_states(
            hidden_states, mask_time_indices=mask_time_indices, attention_mask=attention_mask
        )

        encoder_outputs = self.encoder(
            hidden_states,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        hidden_states = encoder_outputs[0]

        if self.adapter is not None:
            hidden_states = self.adapter(hidden_states)

        if not return_dict:
            return (hidden_states, ) + encoder_outputs[1:]
        return BaseModelOutput(
            last_hidden_state=hidden_states,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )

    def feature_extract(
        self,
        input_values,
        seq_len,
    ):
        extract_features = self.feature_extractor(input_values)
        extract_features = extract_features.transpose(1, 2)
        extract_features = linear_interpolation(extract_features, seq_len=seq_len)

        return extract_features

    def encode(
        self,
        extract_features,
        attention_mask=None,
        mask_time_indices=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        self.config.output_attentions = True

        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if attention_mask is not None:
            # compute reduced attention_mask corresponding to feature vectors
            attention_mask = self._get_feature_vector_attention_mask(
                extract_features.shape[1], attention_mask, add_adapter=False
            )

        hidden_states, extract_features = self.feature_projection(extract_features)
        hidden_states = self._mask_hidden_states(
            hidden_states, mask_time_indices=mask_time_indices, attention_mask=attention_mask
        )

        encoder_outputs = self.encoder(
            hidden_states,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        hidden_states = encoder_outputs[0]

        if self.adapter is not None:
            hidden_states = self.adapter(hidden_states)

        if not return_dict:
            return (hidden_states, ) + encoder_outputs[1:]
        return BaseModelOutput(
            last_hidden_state=hidden_states,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )
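

# Usage sketch (randomly initialised config, no pretrained weights, made-up
# sizes): one second of 16 kHz audio is resampled by linear_interpolation to
# seq_len features, so the hidden states line up with the video frame count.
def _wav2vec2_shape_check():
    import torch

    config = Wav2Vec2Config()
    model = Wav2Vec2Model(config).eval()
    wav = torch.randn(1, 16000)            # [B, samples] at 16 kHz
    with torch.no_grad():
        out = model(wav, seq_len=25)       # e.g. 25 video frames for 1 s at 25 fps
    assert out.last_hidden_state.shape == (1, 25, config.hidden_size)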