multitalk files

DeepBeepMeep 2025-07-08 18:44:55 +02:00
parent 621687c12a
commit 3a8bd05c6e
13 changed files with 3570 additions and 0 deletions

wan/multitalk/attention.py

@@ -0,0 +1,382 @@
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import torch
import torch.nn as nn
from einops import rearrange, repeat
from .multitalk_utils import RotaryPositionalEmbedding1D, normalize_and_scale, split_token_counts_and_frame_ids
from wan.modules.attention import pay_attention
import xformers.ops
try:
import flash_attn_interface
FLASH_ATTN_3_AVAILABLE = True
except ModuleNotFoundError:
FLASH_ATTN_3_AVAILABLE = False
try:
import flash_attn
FLASH_ATTN_2_AVAILABLE = True
except ModuleNotFoundError:
FLASH_ATTN_2_AVAILABLE = False
import warnings
__all__ = [
'flash_attention',
'attention',
]
def flash_attention(
q,
k,
v,
q_lens=None,
k_lens=None,
dropout_p=0.,
softmax_scale=None,
q_scale=None,
causal=False,
window_size=(-1, -1),
deterministic=False,
dtype=torch.bfloat16,
version=None,
):
"""
q: [B, Lq, Nq, C1].
k: [B, Lk, Nk, C1].
v: [B, Lk, Nk, C2]. Nq must be divisible by Nk.
q_lens: [B].
k_lens: [B].
dropout_p: float. Dropout probability.
softmax_scale: float. The scaling of QK^T before applying softmax.
causal: bool. Whether to apply causal attention mask.
window_size: (left, right). If not (-1, -1), apply sliding window local attention.
deterministic: bool. If True, slightly slower and uses more memory.
dtype: torch.dtype. Cast q/k/v to this dtype when they are not already float16/bfloat16.
"""
half_dtypes = (torch.float16, torch.bfloat16)
assert dtype in half_dtypes
assert q.device.type == 'cuda' and q.size(-1) <= 256
# params
b, lq, lk, out_dtype = q.size(0), q.size(1), k.size(1), q.dtype
def half(x):
return x if x.dtype in half_dtypes else x.to(dtype)
# preprocess query
if q_lens is None:
q = half(q.flatten(0, 1))
q_lens = torch.tensor(
[lq] * b, dtype=torch.int32).to(
device=q.device, non_blocking=True)
else:
q = half(torch.cat([u[:v] for u, v in zip(q, q_lens)]))
# preprocess key, value
if k_lens is None:
k = half(k.flatten(0, 1))
v = half(v.flatten(0, 1))
k_lens = torch.tensor(
[lk] * b, dtype=torch.int32).to(
device=k.device, non_blocking=True)
else:
k = half(torch.cat([u[:v] for u, v in zip(k, k_lens)]))
v = half(torch.cat([u[:v] for u, v in zip(v, k_lens)]))
q = q.to(v.dtype)
k = k.to(v.dtype)
if q_scale is not None:
q = q * q_scale
if version is not None and version == 3 and not FLASH_ATTN_3_AVAILABLE:
warnings.warn(
'Flash attention 3 is not available, use flash attention 2 instead.'
)
# apply attention
if (version is None or version == 3) and FLASH_ATTN_3_AVAILABLE:
# Note: dropout_p, window_size are not supported in FA3 now.
x = flash_attn_interface.flash_attn_varlen_func(
q=q,
k=k,
v=v,
cu_seqlens_q=torch.cat([q_lens.new_zeros([1]), q_lens]).cumsum(
0, dtype=torch.int32).to(q.device, non_blocking=True),
cu_seqlens_k=torch.cat([k_lens.new_zeros([1]), k_lens]).cumsum(
0, dtype=torch.int32).to(q.device, non_blocking=True),
seqused_q=None,
seqused_k=None,
max_seqlen_q=lq,
max_seqlen_k=lk,
softmax_scale=softmax_scale,
causal=causal,
deterministic=deterministic)[0].unflatten(0, (b, lq))
else:
assert FLASH_ATTN_2_AVAILABLE
x = flash_attn.flash_attn_varlen_func(
q=q,
k=k,
v=v,
cu_seqlens_q=torch.cat([q_lens.new_zeros([1]), q_lens]).cumsum(
0, dtype=torch.int32).to(q.device, non_blocking=True),
cu_seqlens_k=torch.cat([k_lens.new_zeros([1]), k_lens]).cumsum(
0, dtype=torch.int32).to(q.device, non_blocking=True),
max_seqlen_q=lq,
max_seqlen_k=lk,
dropout_p=dropout_p,
softmax_scale=softmax_scale,
causal=causal,
window_size=window_size,
deterministic=deterministic).unflatten(0, (b, lq))
# output
return x.type(out_dtype)
def attention(
q,
k,
v,
q_lens=None,
k_lens=None,
dropout_p=0.,
softmax_scale=None,
q_scale=None,
causal=False,
window_size=(-1, -1),
deterministic=False,
dtype=torch.bfloat16,
fa_version=None,
):
if FLASH_ATTN_2_AVAILABLE or FLASH_ATTN_3_AVAILABLE:
return flash_attention(
q=q,
k=k,
v=v,
q_lens=q_lens,
k_lens=k_lens,
dropout_p=dropout_p,
softmax_scale=softmax_scale,
q_scale=q_scale,
causal=causal,
window_size=window_size,
deterministic=deterministic,
dtype=dtype,
version=fa_version,
)
else:
if q_lens is not None or k_lens is not None:
warnings.warn(
'Padding mask is disabled when using scaled_dot_product_attention. It can have a significant impact on performance.'
)
attn_mask = None
q = q.transpose(1, 2).to(dtype)
k = k.transpose(1, 2).to(dtype)
v = v.transpose(1, 2).to(dtype)
out = torch.nn.functional.scaled_dot_product_attention(
q, k, v, attn_mask=attn_mask, is_causal=causal, dropout_p=dropout_p)
out = out.transpose(1, 2).contiguous()
return out
class SingleStreamAttention(nn.Module):
def __init__(
self,
dim: int,
encoder_hidden_states_dim: int,
num_heads: int,
qkv_bias: bool,
qk_norm: bool,
norm_layer: nn.Module,
attn_drop: float = 0.0,
proj_drop: float = 0.0,
eps: float = 1e-6,
) -> None:
super().__init__()
assert dim % num_heads == 0, "dim should be divisible by num_heads"
self.dim = dim
self.encoder_hidden_states_dim = encoder_hidden_states_dim
self.num_heads = num_heads
self.head_dim = dim // num_heads
self.scale = self.head_dim**-0.5
self.qk_norm = qk_norm
self.q_linear = nn.Linear(dim, dim, bias=qkv_bias)
self.q_norm = norm_layer(self.head_dim, eps=eps) if qk_norm else nn.Identity()
self.k_norm = norm_layer(self.head_dim,eps=eps) if qk_norm else nn.Identity()
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
self.kv_linear = nn.Linear(encoder_hidden_states_dim, dim * 2, bias=qkv_bias)
self.add_q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
self.add_k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
def forward(self, x: torch.Tensor, encoder_hidden_states: torch.Tensor, shape=None, enable_sp=False, kv_seq=None) -> torch.Tensor:
N_t, N_h, N_w = shape
x = rearrange(x, "B (N_t S) C -> (B N_t) S C", N_t=N_t)
# get q for hidden_state
B, N, C = x.shape
q = self.q_linear(x)
q_shape = (B, N, self.num_heads, self.head_dim)
q = q.view(q_shape).permute((0, 2, 1, 3))
if self.qk_norm:
q = self.q_norm(q)
# get kv from encoder_hidden_states
_, N_a, _ = encoder_hidden_states.shape
encoder_kv = self.kv_linear(encoder_hidden_states)
encoder_kv_shape = (B, N_a, 2, self.num_heads, self.head_dim)
encoder_kv = encoder_kv.view(encoder_kv_shape).permute((2, 0, 3, 1, 4))
encoder_k, encoder_v = encoder_kv.unbind(0)
if self.qk_norm:
encoder_k = self.add_k_norm(encoder_k)
q = rearrange(q, "B H M K -> B M H K")
encoder_k = rearrange(encoder_k, "B H M K -> B M H K")
encoder_v = rearrange(encoder_v, "B H M K -> B M H K")
attn_bias = None
# x = xformers.ops.memory_efficient_attention(q, encoder_k, encoder_v, attn_bias=attn_bias, op=None,)
qkv_list = [q, encoder_k, encoder_v]
q = encoder_k = encoder_v = None
x = pay_attention(qkv_list)
x = rearrange(x, "B M H K -> B H M K")
# linear transform
x_output_shape = (B, N, C)
x = x.transpose(1, 2)
x = x.reshape(x_output_shape)
x = self.proj(x)
x = self.proj_drop(x)
# reshape x to origin shape
x = rearrange(x, "(B N_t) S C -> B (N_t S) C", N_t=N_t)
return x
class SingleStreamMutiAttention(SingleStreamAttention):
def __init__(
self,
dim: int,
encoder_hidden_states_dim: int,
num_heads: int,
qkv_bias: bool,
qk_norm: bool,
norm_layer: nn.Module,
attn_drop: float = 0.0,
proj_drop: float = 0.0,
eps: float = 1e-6,
class_range: int = 24,
class_interval: int = 4,
) -> None:
super().__init__(
dim=dim,
encoder_hidden_states_dim=encoder_hidden_states_dim,
num_heads=num_heads,
qkv_bias=qkv_bias,
qk_norm=qk_norm,
norm_layer=norm_layer,
attn_drop=attn_drop,
proj_drop=proj_drop,
eps=eps,
)
self.class_interval = class_interval
self.class_range = class_range
self.rope_h1 = (0, self.class_interval)
self.rope_h2 = (self.class_range - self.class_interval, self.class_range)
self.rope_bak = int(self.class_range // 2)
self.rope_1d = RotaryPositionalEmbedding1D(self.head_dim)
def forward(self,
x: torch.Tensor,
encoder_hidden_states: torch.Tensor,
shape=None,
x_ref_attn_map=None,
) -> torch.Tensor:
encoder_hidden_states = encoder_hidden_states.squeeze(0)
if x_ref_attn_map is None:
return super().forward(x, encoder_hidden_states, shape)
N_t, _, _ = shape
x = rearrange(x, "B (N_t S) C -> (B N_t) S C", N_t=N_t)
# get q for hidden_state
B, N, C = x.shape
q = self.q_linear(x)
q_shape = (B, N, self.num_heads, self.head_dim)
q = q.view(q_shape).permute((0, 2, 1, 3))
if self.qk_norm:
q = self.q_norm(q)
max_values = x_ref_attn_map.max(1).values[:, None, None]
min_values = x_ref_attn_map.min(1).values[:, None, None]
max_min_values = torch.cat([max_values, min_values], dim=2)
human1_max_value, human1_min_value = max_min_values[0, :, 0].max(), max_min_values[0, :, 1].min()
human2_max_value, human2_min_value = max_min_values[1, :, 0].max(), max_min_values[1, :, 1].min()
human1 = normalize_and_scale(x_ref_attn_map[0], (human1_min_value, human1_max_value), (self.rope_h1[0], self.rope_h1[1]))
human2 = normalize_and_scale(x_ref_attn_map[1], (human2_min_value, human2_max_value), (self.rope_h2[0], self.rope_h2[1]))
back = torch.full((x_ref_attn_map.size(1),), self.rope_bak, dtype=human1.dtype).to(human1.device)
max_indices = x_ref_attn_map.argmax(dim=0)
normalized_map = torch.stack([human1, human2, back], dim=1)
normalized_pos = normalized_map[range(x_ref_attn_map.size(1)), max_indices] # N
q = rearrange(q, "(B N_t) H S C -> B H (N_t S) C", N_t=N_t)
q = self.rope_1d(q, normalized_pos)
q = rearrange(q, "B H (N_t S) C -> (B N_t) H S C", N_t=N_t)
_, N_a, _ = encoder_hidden_states.shape
encoder_kv = self.kv_linear(encoder_hidden_states)
encoder_kv_shape = (B, N_a, 2, self.num_heads, self.head_dim)
encoder_kv = encoder_kv.view(encoder_kv_shape).permute((2, 0, 3, 1, 4))
encoder_k, encoder_v = encoder_kv.unbind(0)
if self.qk_norm:
encoder_k = self.add_k_norm(encoder_k)
per_frame = torch.zeros(N_a, dtype=encoder_k.dtype).to(encoder_k.device)
per_frame[:per_frame.size(0)//2] = (self.rope_h1[0] + self.rope_h1[1]) / 2
per_frame[per_frame.size(0)//2:] = (self.rope_h2[0] + self.rope_h2[1]) / 2
encoder_pos = torch.concat([per_frame]*N_t, dim=0)
encoder_k = rearrange(encoder_k, "(B N_t) H S C -> B H (N_t S) C", N_t=N_t)
encoder_k = self.rope_1d(encoder_k, encoder_pos)
encoder_k = rearrange(encoder_k, "B H (N_t S) C -> (B N_t) H S C", N_t=N_t)
q = rearrange(q, "B H M K -> B M H K")
encoder_k = rearrange(encoder_k, "B H M K -> B M H K")
encoder_v = rearrange(encoder_v, "B H M K -> B M H K")
# x = xformers.ops.memory_efficient_attention(q, encoder_k, encoder_v, attn_bias=None, op=None,)
qkv_list = [q, encoder_k, encoder_v]
q = encoder_k = encoder_v = None
x = pay_attention(qkv_list)
x = rearrange(x, "B M H K -> B H M K")
# linear transform
x_output_shape = (B, N, C)
x = x.transpose(1, 2)
x = x.reshape(x_output_shape)
x = self.proj(x)
x = self.proj_drop(x)
# reshape x to origin shape
x = rearrange(x, "(B N_t) S C -> B (N_t S) C", N_t=N_t)
return x
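A quick shape check for the wrappers above (sizes are illustrative; the SDPA fallback runs on CPU, while the flash-attention paths additionally require a CUDA device):

import torch

b, lq, lk, heads, head_dim = 1, 16, 32, 8, 64  # illustrative sizes
device = "cuda" if torch.cuda.is_available() else "cpu"
q = torch.randn(b, lq, heads, head_dim, device=device, dtype=torch.bfloat16)
k = torch.randn(b, lk, heads, head_dim, device=device, dtype=torch.bfloat16)
v = torch.randn(b, lk, heads, head_dim, device=device, dtype=torch.bfloat16)
# attention() dispatches to flash_attention when flash-attn 2/3 is installed,
# otherwise it falls back to torch.nn.functional.scaled_dot_product_attention.
out = attention(q, k, v)
print(out.shape)  # torch.Size([1, 16, 8, 64])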


@@ -0,0 +1,23 @@
__version__ = '0.9.4'
from loguru import logger
import sys
# Remove default handler
logger.remove()
# Add custom handler with clean format including module and line number
logger.add(
sys.stderr,
format="<green>{time:HH:mm:ss}</green> | <cyan>{module:>16}:{line}</cyan> | <level>{level: >8}</level> | <level>{message}</level>",
colorize=True,
level="INFO" # "DEBUG" to enable logger.debug("message") and up prints
# "ERROR" to enable only logger.error("message") prints
# etc
)
# Disable before release or as needed
logger.disable("kokoro")
from .model import KModel
from .pipeline import KPipeline
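Calling code can re-enable the package logging that is disabled above when debugging, for example:

from loguru import logger
logger.enable("kokoro")  # undoes the logger.disable("kokoro") above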


@@ -0,0 +1,148 @@
"""Kokoro TTS CLI
Example usage:
python3 -m kokoro --text "The sky above the port was the color of television, tuned to a dead channel." -o file.wav --debug
echo "Bom dia mundo, como vão vocês" > text.txt
python3 -m kokoro -i text.txt -l p --voice pm_alex > audio.wav
Common issues:
pip not installed: `uv pip install pip`
(Temporary workaround while https://github.com/explosion/spaCy/issues/13747 is not fixed)
espeak not installed: `apt-get install espeak-ng`
"""
import argparse
import wave
from pathlib import Path
from typing import Generator, TYPE_CHECKING
import numpy as np
from loguru import logger
languages = [
"a", # American English
"b", # British English
"h", # Hindi
"e", # Spanish
"f", # French
"i", # Italian
"p", # Brazilian Portuguese
"j", # Japanese
"z", # Mandarin Chinese
]
if TYPE_CHECKING:
from kokoro import KPipeline
def generate_audio(
text: str, kokoro_language: str, voice: str, speed=1
) -> Generator["KPipeline.Result", None, None]:
from kokoro import KPipeline
if not voice.startswith(kokoro_language):
logger.warning(f"Voice {voice} is not made for language {kokoro_language}")
pipeline = KPipeline(lang_code=kokoro_language)
yield from pipeline(text, voice=voice, speed=speed, split_pattern=r"\n+")
def generate_and_save_audio(
output_file: Path, text: str, kokoro_language: str, voice: str, speed=1
) -> None:
with wave.open(str(output_file.resolve()), "wb") as wav_file:
wav_file.setnchannels(1) # Mono audio
wav_file.setsampwidth(2) # 2 bytes per sample (16-bit audio)
wav_file.setframerate(24000) # Sample rate
for result in generate_audio(
text, kokoro_language=kokoro_language, voice=voice, speed=speed
):
logger.debug(result.phonemes)
if result.audio is None:
continue
audio_bytes = (result.audio.numpy() * 32767).astype(np.int16).tobytes()
wav_file.writeframes(audio_bytes)
def main() -> None:
parser = argparse.ArgumentParser()
parser.add_argument(
"-m",
"--voice",
default="af_heart",
help="Voice to use",
)
parser.add_argument(
"-l",
"--language",
help="Language to use (defaults to the one corresponding to the voice)",
choices=languages,
)
parser.add_argument(
"-o",
"--output-file",
"--output_file",
type=Path,
help="Path to output WAV file",
required=True,
)
parser.add_argument(
"-i",
"--input-file",
"--input_file",
type=Path,
help="Path to input text file (default: stdin)",
)
parser.add_argument(
"-t",
"--text",
help="Text to use instead of reading from stdin",
)
parser.add_argument(
"-s",
"--speed",
type=float,
default=1.0,
help="Speech speed",
)
parser.add_argument(
"--debug",
action="store_true",
help="Print DEBUG messages to console",
)
args = parser.parse_args()
if args.debug:
logger.level("DEBUG")
logger.debug(args)
lang = args.language or args.voice[0]
if args.text is not None and args.input_file is not None:
raise Exception("You cannot specify both 'text' and 'input_file'")
elif args.text:
text = args.text
elif args.input_file:
file: Path = args.input_file
text = file.read_text()
else:
import sys
print("Press Ctrl+D to stop reading input and start generating", flush=True)
text = '\n'.join(sys.stdin)
logger.debug(f"Input text: {text!r}")
out_file: Path = args.output_file
if not out_file.suffix == ".wav":
logger.warning("The output file name should end with .wav")
generate_and_save_audio(
output_file=out_file,
text=text,
kokoro_language=lang,
voice=args.voice,
speed=args.speed,
)
if __name__ == "__main__":
main()
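The helpers above can also be called without the CLI; a sketch using the CLI's default voice (the output path is illustrative):

from pathlib import Path

generate_and_save_audio(
    output_file=Path("hello.wav"),
    text="Hello world.",
    kokoro_language="a",  # American English
    voice="af_heart",     # same as the CLI default
    speed=1.0,
)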


@@ -0,0 +1,197 @@
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
class CustomSTFT(nn.Module):
"""
STFT/iSTFT without unfold/complex ops, using conv1d and conv_transpose1d.
- forward STFT => Real-part conv1d + Imag-part conv1d
- inverse STFT => Real-part conv_transpose1d + Imag-part conv_transpose1d + sum
- avoids F.unfold, so easier to export to ONNX
- uses replicate or constant padding for 'center=True' to approximate 'reflect'
(reflect is not supported for dynamic shapes in ONNX)
"""
def __init__(
self,
filter_length=800,
hop_length=200,
win_length=800,
window="hann",
center=True,
pad_mode="replicate", # or 'constant'
):
super().__init__()
self.filter_length = filter_length
self.hop_length = hop_length
self.win_length = win_length
self.n_fft = filter_length
self.center = center
self.pad_mode = pad_mode
# Number of frequency bins for real-valued STFT with onesided=True
self.freq_bins = self.n_fft // 2 + 1
# Build window
assert window == 'hann', window
window_tensor = torch.hann_window(win_length, periodic=True, dtype=torch.float32)
if self.win_length < self.n_fft:
# Zero-pad up to n_fft
extra = self.n_fft - self.win_length
window_tensor = F.pad(window_tensor, (0, extra))
elif self.win_length > self.n_fft:
window_tensor = window_tensor[: self.n_fft]
self.register_buffer("window", window_tensor)
# Precompute forward DFT (real, imag)
# PyTorch stft uses e^{-j 2 pi k n / N} => real=cos(...), imag=-sin(...)
n = np.arange(self.n_fft)
k = np.arange(self.freq_bins)
angle = 2 * np.pi * np.outer(k, n) / self.n_fft # shape (freq_bins, n_fft)
dft_real = np.cos(angle)
dft_imag = -np.sin(angle) # note negative sign
# Combine window and dft => shape (freq_bins, filter_length)
# We'll make 2 conv weight tensors of shape (freq_bins, 1, filter_length).
forward_window = window_tensor.numpy() # shape (n_fft,)
forward_real = dft_real * forward_window # (freq_bins, n_fft)
forward_imag = dft_imag * forward_window
# Convert to PyTorch
forward_real_torch = torch.from_numpy(forward_real).float()
forward_imag_torch = torch.from_numpy(forward_imag).float()
# Register as Conv1d weight => (out_channels, in_channels, kernel_size)
# out_channels = freq_bins, in_channels=1, kernel_size=n_fft
self.register_buffer(
"weight_forward_real", forward_real_torch.unsqueeze(1)
)
self.register_buffer(
"weight_forward_imag", forward_imag_torch.unsqueeze(1)
)
# Precompute inverse DFT
# Real iFFT formula => scale = 1/n_fft, doubling for bins 1..freq_bins-2 if n_fft even, etc.
# For simplicity, we won't do the "DC/nyquist not doubled" approach here.
# If you want perfect real iSTFT, you can add that logic.
# This version just yields good approximate reconstruction with Hann + typical overlap.
inv_scale = 1.0 / self.n_fft
n = np.arange(self.n_fft)
angle_t = 2 * np.pi * np.outer(n, k) / self.n_fft # shape (n_fft, freq_bins)
idft_cos = np.cos(angle_t).T # => (freq_bins, n_fft)
idft_sin = np.sin(angle_t).T # => (freq_bins, n_fft)
# Multiply by window again for typical overlap-add
# We also incorporate the scale factor 1/n_fft
inv_window = window_tensor.numpy() * inv_scale
backward_real = idft_cos * inv_window # (freq_bins, n_fft)
backward_imag = idft_sin * inv_window
# We'll implement iSTFT as real+imag conv_transpose with stride=hop.
self.register_buffer(
"weight_backward_real", torch.from_numpy(backward_real).float().unsqueeze(1)
)
self.register_buffer(
"weight_backward_imag", torch.from_numpy(backward_imag).float().unsqueeze(1)
)
def transform(self, waveform: torch.Tensor):
"""
Forward STFT => returns magnitude, phase
Output shape => (batch, freq_bins, frames)
"""
# waveform shape => (B, T). conv1d expects (B, 1, T).
# Optional center pad
if self.center:
pad_len = self.n_fft // 2
waveform = F.pad(waveform, (pad_len, pad_len), mode=self.pad_mode)
x = waveform.unsqueeze(1) # => (B, 1, T)
# Convolution to get real part => shape (B, freq_bins, frames)
real_out = F.conv1d(
x,
self.weight_forward_real,
bias=None,
stride=self.hop_length,
padding=0,
)
# Imag part
imag_out = F.conv1d(
x,
self.weight_forward_imag,
bias=None,
stride=self.hop_length,
padding=0,
)
# magnitude, phase
magnitude = torch.sqrt(real_out**2 + imag_out**2 + 1e-14)
phase = torch.atan2(imag_out, real_out)
# Handle the case where imag_out is 0 and real_out is negative to correct ONNX atan2 to match PyTorch
# In this case, PyTorch returns pi, ONNX returns -pi
correction_mask = (imag_out == 0) & (real_out < 0)
phase[correction_mask] = torch.pi
return magnitude, phase
def inverse(self, magnitude: torch.Tensor, phase: torch.Tensor, length=None):
"""
Inverse STFT => returns waveform of shape (B, 1, T), matching the conv_transpose1d output.
"""
# magnitude, phase => (B, freq_bins, frames)
# Re-create real/imag => shape (B, freq_bins, frames)
real_part = magnitude * torch.cos(phase)
imag_part = magnitude * torch.sin(phase)
# conv_transpose wants shape (B, freq_bins, frames). We'll treat "frames" as time dimension
# so we do (B, freq_bins, frames) => (B, freq_bins, frames)
# But PyTorch conv_transpose1d expects (B, in_channels, input_length)
real_part = real_part # (B, freq_bins, frames)
imag_part = imag_part
# real iSTFT => convolve with "backward_real", "backward_imag", and sum
# We'll do 2 conv_transpose calls, each giving (B, 1, time),
# then add them => (B, 1, time).
real_rec = F.conv_transpose1d(
real_part,
self.weight_backward_real, # shape (freq_bins, 1, filter_length)
bias=None,
stride=self.hop_length,
padding=0,
)
imag_rec = F.conv_transpose1d(
imag_part,
self.weight_backward_imag,
bias=None,
stride=self.hop_length,
padding=0,
)
# sum => (B, 1, time)
waveform = real_rec - imag_rec # typical real iFFT has minus for imaginary part
# If we used "center=True" in forward, we should remove pad
if self.center:
pad_len = self.n_fft // 2
# Because of transposed convolution, total length might have extra samples
# We remove `pad_len` from start & end if possible
waveform = waveform[..., pad_len:-pad_len]
# If a specific length is desired, clamp
if length is not None:
waveform = waveform[..., :length]
# shape => (B, 1, T)
return waveform
def forward(self, x: torch.Tensor):
"""
Full STFT -> iSTFT pass: returns time-domain reconstruction.
Same transform/inverse interface as the torch.stft-based implementation.
"""
mag, phase = self.transform(x)
return self.inverse(mag, phase, length=x.shape[-1])
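A round-trip sketch with the default 800/200 STFT configuration (batch size and signal length are illustrative):

import torch

stft = CustomSTFT(filter_length=800, hop_length=200, win_length=800)
wave = torch.randn(2, 24000)               # two 1-second clips at 24 kHz
mag, phase = stft.transform(wave)          # (2, 401, 121): (batch, freq_bins, frames)
recon = stft.inverse(mag, phase, length=wave.shape[-1])
print(mag.shape, recon.shape)              # (2, 401, 121) and (2, 1, 24000)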


@@ -0,0 +1,421 @@
# ADAPTED from https://github.com/yl4579/StyleTTS2/blob/main/Modules/istftnet.py
from .custom_stft import CustomSTFT
from torch.nn.utils import weight_norm
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
# https://github.com/yl4579/StyleTTS2/blob/main/Modules/utils.py
def init_weights(m, mean=0.0, std=0.01):
classname = m.__class__.__name__
if classname.find("Conv") != -1:
m.weight.data.normal_(mean, std)
def get_padding(kernel_size, dilation=1):
return int((kernel_size*dilation - dilation)/2)
class AdaIN1d(nn.Module):
def __init__(self, style_dim, num_features):
super().__init__()
# affine should be False; however, a bug in the old torch.onnx.export (not the newer dynamo exporter) causes the channel dimension to be lost when affine=False. With affine=True there are extra learnable parameters, but that should not matter here since we run in inference mode.
self.norm = nn.InstanceNorm1d(num_features, affine=True)
self.fc = nn.Linear(style_dim, num_features*2)
def forward(self, x, s):
h = self.fc(s)
h = h.view(h.size(0), h.size(1), 1)
gamma, beta = torch.chunk(h, chunks=2, dim=1)
return (1 + gamma) * self.norm(x) + beta
class AdaINResBlock1(nn.Module):
def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5), style_dim=64):
super(AdaINResBlock1, self).__init__()
self.convs1 = nn.ModuleList([
weight_norm(nn.Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0],
padding=get_padding(kernel_size, dilation[0]))),
weight_norm(nn.Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1],
padding=get_padding(kernel_size, dilation[1]))),
weight_norm(nn.Conv1d(channels, channels, kernel_size, 1, dilation=dilation[2],
padding=get_padding(kernel_size, dilation[2])))
])
self.convs1.apply(init_weights)
self.convs2 = nn.ModuleList([
weight_norm(nn.Conv1d(channels, channels, kernel_size, 1, dilation=1,
padding=get_padding(kernel_size, 1))),
weight_norm(nn.Conv1d(channels, channels, kernel_size, 1, dilation=1,
padding=get_padding(kernel_size, 1))),
weight_norm(nn.Conv1d(channels, channels, kernel_size, 1, dilation=1,
padding=get_padding(kernel_size, 1)))
])
self.convs2.apply(init_weights)
self.adain1 = nn.ModuleList([
AdaIN1d(style_dim, channels),
AdaIN1d(style_dim, channels),
AdaIN1d(style_dim, channels),
])
self.adain2 = nn.ModuleList([
AdaIN1d(style_dim, channels),
AdaIN1d(style_dim, channels),
AdaIN1d(style_dim, channels),
])
self.alpha1 = nn.ParameterList([nn.Parameter(torch.ones(1, channels, 1)) for i in range(len(self.convs1))])
self.alpha2 = nn.ParameterList([nn.Parameter(torch.ones(1, channels, 1)) for i in range(len(self.convs2))])
def forward(self, x, s):
for c1, c2, n1, n2, a1, a2 in zip(self.convs1, self.convs2, self.adain1, self.adain2, self.alpha1, self.alpha2):
xt = n1(x, s)
xt = xt + (1 / a1) * (torch.sin(a1 * xt) ** 2) # Snake1D
xt = c1(xt)
xt = n2(xt, s)
xt = xt + (1 / a2) * (torch.sin(a2 * xt) ** 2) # Snake1D
xt = c2(xt)
x = xt + x
return x
class TorchSTFT(nn.Module):
def __init__(self, filter_length=800, hop_length=200, win_length=800, window='hann'):
super().__init__()
self.filter_length = filter_length
self.hop_length = hop_length
self.win_length = win_length
assert window == 'hann', window
self.window = torch.hann_window(win_length, periodic=True, dtype=torch.float32)
def transform(self, input_data):
forward_transform = torch.stft(
input_data,
self.filter_length, self.hop_length, self.win_length, window=self.window.to(input_data.device),
return_complex=True)
return torch.abs(forward_transform), torch.angle(forward_transform)
def inverse(self, magnitude, phase):
inverse_transform = torch.istft(
magnitude * torch.exp(phase * 1j),
self.filter_length, self.hop_length, self.win_length, window=self.window.to(magnitude.device))
return inverse_transform.unsqueeze(-2) # unsqueeze to stay consistent with conv_transpose1d implementation
def forward(self, input_data):
self.magnitude, self.phase = self.transform(input_data)
reconstruction = self.inverse(self.magnitude, self.phase)
return reconstruction
class SineGen(nn.Module):
""" Definition of sine generator
SineGen(samp_rate, harmonic_num = 0,
sine_amp = 0.1, noise_std = 0.003,
voiced_threshold = 0,
flag_for_pulse=False)
samp_rate: sampling rate in Hz
harmonic_num: number of harmonic overtones (default 0)
sine_amp: amplitude of sine waveform (default 0.1)
noise_std: std of Gaussian noise (default 0.003)
voiced_threshold: F0 threshold for U/V classification (default 0)
flag_for_pulse: this SineGen is used inside PulseGen (default False)
Note: when flag_for_pulse is True, the first time step of a voiced
segment is always sin(torch.pi) or cos(0)
"""
def __init__(self, samp_rate, upsample_scale, harmonic_num=0,
sine_amp=0.1, noise_std=0.003,
voiced_threshold=0,
flag_for_pulse=False):
super(SineGen, self).__init__()
self.sine_amp = sine_amp
self.noise_std = noise_std
self.harmonic_num = harmonic_num
self.dim = self.harmonic_num + 1
self.sampling_rate = samp_rate
self.voiced_threshold = voiced_threshold
self.flag_for_pulse = flag_for_pulse
self.upsample_scale = upsample_scale
def _f02uv(self, f0):
# generate uv signal
uv = (f0 > self.voiced_threshold).type(torch.float32)
return uv
def _f02sine(self, f0_values):
""" f0_values: (batchsize, length, dim)
where dim indicates fundamental tone and overtones
"""
# convert to F0 in rad. The integer part n can be ignored
# because 2 * torch.pi * n doesn't affect phase
rad_values = (f0_values / self.sampling_rate) % 1
# initial phase noise (no noise for fundamental component)
rand_ini = torch.rand(f0_values.shape[0], f0_values.shape[2], device=f0_values.device)
rand_ini[:, 0] = 0
rad_values[:, 0, :] = rad_values[:, 0, :] + rand_ini
# instantaneous phase: sine[t] = sin(2*pi * sum_{i=1}^{t} rad_i)
if not self.flag_for_pulse:
rad_values = F.interpolate(rad_values.transpose(1, 2), scale_factor=1/self.upsample_scale, mode="linear").transpose(1, 2)
phase = torch.cumsum(rad_values, dim=1) * 2 * torch.pi
phase = F.interpolate(phase.transpose(1, 2) * self.upsample_scale, scale_factor=self.upsample_scale, mode="linear").transpose(1, 2)
sines = torch.sin(phase)
else:
# If necessary, make sure that the first time step of every
# voiced segment is sin(pi) or cos(0)
# This is used for pulse-train generation
# identify the last time step in unvoiced segments
uv = self._f02uv(f0_values)
uv_1 = torch.roll(uv, shifts=-1, dims=1)
uv_1[:, -1, :] = 1
u_loc = (uv < 1) * (uv_1 > 0)
# get the instantaneous phase
tmp_cumsum = torch.cumsum(rad_values, dim=1)
# different batch needs to be processed differently
for idx in range(f0_values.shape[0]):
temp_sum = tmp_cumsum[idx, u_loc[idx, :, 0], :]
temp_sum[1:, :] = temp_sum[1:, :] - temp_sum[0:-1, :]
# stores the accumulation of i.phase within
# each voiced segments
tmp_cumsum[idx, :, :] = 0
tmp_cumsum[idx, u_loc[idx, :, 0], :] = temp_sum
# rad_values - tmp_cumsum: remove the accumulation of i.phase
# within the previous voiced segment.
i_phase = torch.cumsum(rad_values - tmp_cumsum, dim=1)
# get the sines
sines = torch.cos(i_phase * 2 * torch.pi)
return sines
def forward(self, f0):
""" sine_tensor, uv = forward(f0)
input F0: tensor(batchsize=1, length, dim=1)
f0 for unvoiced steps should be 0
output sine_tensor: tensor(batchsize=1, length, dim)
output uv: tensor(batchsize=1, length, 1)
"""
f0_buf = torch.zeros(f0.shape[0], f0.shape[1], self.dim, device=f0.device)
# fundamental component
fn = torch.multiply(f0, torch.FloatTensor([[range(1, self.harmonic_num + 2)]]).to(f0.device))
# generate sine waveforms
sine_waves = self._f02sine(fn) * self.sine_amp
# generate uv signal
# uv = torch.ones(f0.shape)
# uv = uv * (f0 > self.voiced_threshold)
uv = self._f02uv(f0)
# noise: for unvoiced should be similar to sine_amp
# std = self.sine_amp/3 -> max value ~ self.sine_amp
# for voiced regions is self.noise_std
noise_amp = uv * self.noise_std + (1 - uv) * self.sine_amp / 3
noise = noise_amp * torch.randn_like(sine_waves)
# first: set the unvoiced part to 0 by uv
# then: additive noise
sine_waves = sine_waves * uv + noise
return sine_waves, uv, noise
class SourceModuleHnNSF(nn.Module):
""" SourceModule for hn-nsf
SourceModule(sampling_rate, harmonic_num=0, sine_amp=0.1,
add_noise_std=0.003, voiced_threshod=0)
sampling_rate: sampling_rate in Hz
harmonic_num: number of harmonic above F0 (default: 0)
sine_amp: amplitude of sine source signal (default: 0.1)
add_noise_std: std of additive Gaussian noise (default: 0.003)
note that amplitude of noise in unvoiced is decided
by sine_amp
voiced_threshold: threshold to set U/V given F0 (default: 0)
Sine_source, noise_source = SourceModuleHnNSF(F0_sampled)
F0_sampled (batchsize, length, 1)
Sine_source (batchsize, length, 1)
noise_source (batchsize, length, 1)
uv (batchsize, length, 1)
"""
def __init__(self, sampling_rate, upsample_scale, harmonic_num=0, sine_amp=0.1,
add_noise_std=0.003, voiced_threshod=0):
super(SourceModuleHnNSF, self).__init__()
self.sine_amp = sine_amp
self.noise_std = add_noise_std
# to produce sine waveforms
self.l_sin_gen = SineGen(sampling_rate, upsample_scale, harmonic_num,
sine_amp, add_noise_std, voiced_threshod)
# to merge source harmonics into a single excitation
self.l_linear = nn.Linear(harmonic_num + 1, 1)
self.l_tanh = nn.Tanh()
def forward(self, x):
"""
Sine_source, noise_source = SourceModuleHnNSF(F0_sampled)
F0_sampled (batchsize, length, 1)
Sine_source (batchsize, length, 1)
noise_source (batchsize, length, 1)
"""
# source for harmonic branch
with torch.no_grad():
sine_wavs, uv, _ = self.l_sin_gen(x)
sine_merge = self.l_tanh(self.l_linear(sine_wavs))
# source for noise branch, in the same shape as uv
noise = torch.randn_like(uv) * self.sine_amp / 3
return sine_merge, noise, uv
class Generator(nn.Module):
def __init__(self, style_dim, resblock_kernel_sizes, upsample_rates, upsample_initial_channel, resblock_dilation_sizes, upsample_kernel_sizes, gen_istft_n_fft, gen_istft_hop_size, disable_complex=False):
super(Generator, self).__init__()
self.num_kernels = len(resblock_kernel_sizes)
self.num_upsamples = len(upsample_rates)
self.m_source = SourceModuleHnNSF(
sampling_rate=24000,
upsample_scale=math.prod(upsample_rates) * gen_istft_hop_size,
harmonic_num=8, voiced_threshod=10)
self.f0_upsamp = nn.Upsample(scale_factor=math.prod(upsample_rates) * gen_istft_hop_size)
self.noise_convs = nn.ModuleList()
self.noise_res = nn.ModuleList()
self.ups = nn.ModuleList()
for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
self.ups.append(weight_norm(
nn.ConvTranspose1d(upsample_initial_channel//(2**i), upsample_initial_channel//(2**(i+1)),
k, u, padding=(k-u)//2)))
self.resblocks = nn.ModuleList()
for i in range(len(self.ups)):
ch = upsample_initial_channel//(2**(i+1))
for j, (k, d) in enumerate(zip(resblock_kernel_sizes,resblock_dilation_sizes)):
self.resblocks.append(AdaINResBlock1(ch, k, d, style_dim))
c_cur = upsample_initial_channel // (2 ** (i + 1))
if i + 1 < len(upsample_rates):
stride_f0 = math.prod(upsample_rates[i + 1:])
self.noise_convs.append(nn.Conv1d(
gen_istft_n_fft + 2, c_cur, kernel_size=stride_f0 * 2, stride=stride_f0, padding=(stride_f0+1) // 2))
self.noise_res.append(AdaINResBlock1(c_cur, 7, [1,3,5], style_dim))
else:
self.noise_convs.append(nn.Conv1d(gen_istft_n_fft + 2, c_cur, kernel_size=1))
self.noise_res.append(AdaINResBlock1(c_cur, 11, [1,3,5], style_dim))
self.post_n_fft = gen_istft_n_fft
self.conv_post = weight_norm(nn.Conv1d(ch, self.post_n_fft + 2, 7, 1, padding=3))
self.ups.apply(init_weights)
self.conv_post.apply(init_weights)
self.reflection_pad = nn.ReflectionPad1d((1, 0))
self.stft = (
CustomSTFT(filter_length=gen_istft_n_fft, hop_length=gen_istft_hop_size, win_length=gen_istft_n_fft)
if disable_complex
else TorchSTFT(filter_length=gen_istft_n_fft, hop_length=gen_istft_hop_size, win_length=gen_istft_n_fft)
)
def forward(self, x, s, f0):
with torch.no_grad():
f0 = self.f0_upsamp(f0[:, None]).transpose(1, 2) # bs,n,t
har_source, noi_source, uv = self.m_source(f0)
har_source = har_source.transpose(1, 2).squeeze(1)
har_spec, har_phase = self.stft.transform(har_source)
har = torch.cat([har_spec, har_phase], dim=1)
for i in range(self.num_upsamples):
x = F.leaky_relu(x, negative_slope=0.1)
x_source = self.noise_convs[i](har)
x_source = self.noise_res[i](x_source, s)
x = self.ups[i](x)
if i == self.num_upsamples - 1:
x = self.reflection_pad(x)
x = x + x_source
xs = None
for j in range(self.num_kernels):
if xs is None:
xs = self.resblocks[i*self.num_kernels+j](x, s)
else:
xs += self.resblocks[i*self.num_kernels+j](x, s)
x = xs / self.num_kernels
x = F.leaky_relu(x)
x = self.conv_post(x)
spec = torch.exp(x[:,:self.post_n_fft // 2 + 1, :])
phase = torch.sin(x[:, self.post_n_fft // 2 + 1:, :])
return self.stft.inverse(spec, phase)
class UpSample1d(nn.Module):
def __init__(self, layer_type):
super().__init__()
self.layer_type = layer_type
def forward(self, x):
if self.layer_type == 'none':
return x
else:
return F.interpolate(x, scale_factor=2, mode='nearest')
class AdainResBlk1d(nn.Module):
def __init__(self, dim_in, dim_out, style_dim=64, actv=nn.LeakyReLU(0.2), upsample='none', dropout_p=0.0):
super().__init__()
self.actv = actv
self.upsample_type = upsample
self.upsample = UpSample1d(upsample)
self.learned_sc = dim_in != dim_out
self._build_weights(dim_in, dim_out, style_dim)
self.dropout = nn.Dropout(dropout_p)
if upsample == 'none':
self.pool = nn.Identity()
else:
self.pool = weight_norm(nn.ConvTranspose1d(dim_in, dim_in, kernel_size=3, stride=2, groups=dim_in, padding=1, output_padding=1))
def _build_weights(self, dim_in, dim_out, style_dim):
self.conv1 = weight_norm(nn.Conv1d(dim_in, dim_out, 3, 1, 1))
self.conv2 = weight_norm(nn.Conv1d(dim_out, dim_out, 3, 1, 1))
self.norm1 = AdaIN1d(style_dim, dim_in)
self.norm2 = AdaIN1d(style_dim, dim_out)
if self.learned_sc:
self.conv1x1 = weight_norm(nn.Conv1d(dim_in, dim_out, 1, 1, 0, bias=False))
def _shortcut(self, x):
x = self.upsample(x)
if self.learned_sc:
x = self.conv1x1(x)
return x
def _residual(self, x, s):
x = self.norm1(x, s)
x = self.actv(x)
x = self.pool(x)
x = self.conv1(self.dropout(x))
x = self.norm2(x, s)
x = self.actv(x)
x = self.conv2(self.dropout(x))
return x
def forward(self, x, s):
out = self._residual(x, s)
out = (out + self._shortcut(x)) * torch.rsqrt(torch.tensor(2))
return out
class Decoder(nn.Module):
def __init__(self, dim_in, style_dim, dim_out,
resblock_kernel_sizes,
upsample_rates,
upsample_initial_channel,
resblock_dilation_sizes,
upsample_kernel_sizes,
gen_istft_n_fft, gen_istft_hop_size,
disable_complex=False):
super().__init__()
self.encode = AdainResBlk1d(dim_in + 2, 1024, style_dim)
self.decode = nn.ModuleList()
self.decode.append(AdainResBlk1d(1024 + 2 + 64, 1024, style_dim))
self.decode.append(AdainResBlk1d(1024 + 2 + 64, 1024, style_dim))
self.decode.append(AdainResBlk1d(1024 + 2 + 64, 1024, style_dim))
self.decode.append(AdainResBlk1d(1024 + 2 + 64, 512, style_dim, upsample=True))
self.F0_conv = weight_norm(nn.Conv1d(1, 1, kernel_size=3, stride=2, groups=1, padding=1))
self.N_conv = weight_norm(nn.Conv1d(1, 1, kernel_size=3, stride=2, groups=1, padding=1))
self.asr_res = nn.Sequential(weight_norm(nn.Conv1d(512, 64, kernel_size=1)))
self.generator = Generator(style_dim, resblock_kernel_sizes, upsample_rates,
upsample_initial_channel, resblock_dilation_sizes,
upsample_kernel_sizes, gen_istft_n_fft, gen_istft_hop_size, disable_complex=disable_complex)
def forward(self, asr, F0_curve, N, s):
F0 = self.F0_conv(F0_curve.unsqueeze(1))
N = self.N_conv(N.unsqueeze(1))
x = torch.cat([asr, F0, N], axis=1)
x = self.encode(x, s)
asr_res = self.asr_res(asr)
res = True
for block in self.decode:
if res:
x = torch.cat([x, asr_res, F0, N], axis=1)
x = block(x, s)
if block.upsample_type != "none":
res = False
x = self.generator(x, s, F0_curve)
return x
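A standalone sketch of the harmonic-plus-noise source used by Generator. The F0 track must already be upsampled to the audio rate (Generator does this with f0_upsamp), and its length should be a multiple of upsample_scale; all sizes below are illustrative:

import torch

source = SourceModuleHnNSF(sampling_rate=24000, upsample_scale=300,
                           harmonic_num=8, voiced_threshod=10)
f0 = torch.full((1, 1200, 1), 220.0)  # constant 220 Hz track, length = 4 * upsample_scale
with torch.no_grad():
    har_source, noise_source, uv = source(f0)
print(har_source.shape, noise_source.shape, uv.shape)  # each torch.Size([1, 1200, 1])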


@@ -0,0 +1,155 @@
from .istftnet import Decoder
from .modules import CustomAlbert, ProsodyPredictor, TextEncoder
from dataclasses import dataclass
from huggingface_hub import hf_hub_download
from loguru import logger
from transformers import AlbertConfig
from typing import Dict, Optional, Union
import json
import torch
import os
class KModel(torch.nn.Module):
'''
KModel is a torch.nn.Module with 2 main responsibilities:
1. Init weights, downloading config.json + model.pth from HF if needed
2. forward(phonemes: str, ref_s: FloatTensor) -> (audio: FloatTensor)
You likely only need one KModel instance, and it can be reused across
multiple KPipelines to avoid redundant memory allocation.
Unlike KPipeline, KModel is language-blind.
KModel stores self.vocab and thus knows how to map phonemes -> input_ids,
so there is no need to repeatedly download config.json outside of KModel.
'''
MODEL_NAMES = {
'hexgrad/Kokoro-82M': 'kokoro-v1_0.pth',
'hexgrad/Kokoro-82M-v1.1-zh': 'kokoro-v1_1-zh.pth',
}
def __init__(
self,
repo_id: Optional[str] = None,
config: Union[Dict, str, None] = None,
model: Optional[str] = None,
disable_complex: bool = False
):
super().__init__()
if repo_id is None:
repo_id = 'hexgrad/Kokoro-82M'
print(f"WARNING: Defaulting repo_id to {repo_id}. Pass repo_id='{repo_id}' to suppress this warning.")
self.repo_id = repo_id
if not isinstance(config, dict):
if not config:
logger.debug("No config provided, downloading from HF")
config = hf_hub_download(repo_id=repo_id, filename='config.json')
with open(config, 'r', encoding='utf-8') as r:
config = json.load(r)
logger.debug(f"Loaded config: {config}")
self.vocab = config['vocab']
self.bert = CustomAlbert(AlbertConfig(vocab_size=config['n_token'], **config['plbert']))
self.bert_encoder = torch.nn.Linear(self.bert.config.hidden_size, config['hidden_dim'])
self.context_length = self.bert.config.max_position_embeddings
self.predictor = ProsodyPredictor(
style_dim=config['style_dim'], d_hid=config['hidden_dim'],
nlayers=config['n_layer'], max_dur=config['max_dur'], dropout=config['dropout']
)
self.text_encoder = TextEncoder(
channels=config['hidden_dim'], kernel_size=config['text_encoder_kernel_size'],
depth=config['n_layer'], n_symbols=config['n_token']
)
self.decoder = Decoder(
dim_in=config['hidden_dim'], style_dim=config['style_dim'],
dim_out=config['n_mels'], disable_complex=disable_complex, **config['istftnet']
)
if not model:
try:
model = hf_hub_download(repo_id=repo_id, filename=KModel.MODEL_NAMES[repo_id])
except Exception:
model = os.path.join(repo_id, 'kokoro-v1_0.pth')
for key, state_dict in torch.load(model, map_location='cpu', weights_only=True).items():
assert hasattr(self, key), key
try:
getattr(self, key).load_state_dict(state_dict)
except Exception:
logger.debug(f"Did not load {key} from state_dict")
state_dict = {k[7:]: v for k, v in state_dict.items()}
getattr(self, key).load_state_dict(state_dict, strict=False)
@property
def device(self):
return self.bert.device
@dataclass
class Output:
audio: torch.FloatTensor
pred_dur: Optional[torch.LongTensor] = None
@torch.no_grad()
def forward_with_tokens(
self,
input_ids: torch.LongTensor,
ref_s: torch.FloatTensor,
speed: float = 1
) -> tuple[torch.FloatTensor, torch.LongTensor]:
input_lengths = torch.full(
(input_ids.shape[0],),
input_ids.shape[-1],
device=input_ids.device,
dtype=torch.long
)
text_mask = torch.arange(input_lengths.max()).unsqueeze(0).expand(input_lengths.shape[0], -1).type_as(input_lengths)
text_mask = torch.gt(text_mask+1, input_lengths.unsqueeze(1)).to(self.device)
bert_dur = self.bert(input_ids, attention_mask=(~text_mask).int())
d_en = self.bert_encoder(bert_dur).transpose(-1, -2)
s = ref_s[:, 128:]
d = self.predictor.text_encoder(d_en, s, input_lengths, text_mask)
x, _ = self.predictor.lstm(d)
duration = self.predictor.duration_proj(x)
duration = torch.sigmoid(duration).sum(axis=-1) / speed
pred_dur = torch.round(duration).clamp(min=1).long().squeeze()
indices = torch.repeat_interleave(torch.arange(input_ids.shape[1], device=self.device), pred_dur)
pred_aln_trg = torch.zeros((input_ids.shape[1], indices.shape[0]), device=self.device)
pred_aln_trg[indices, torch.arange(indices.shape[0])] = 1
pred_aln_trg = pred_aln_trg.unsqueeze(0).to(self.device)
en = d.transpose(-1, -2) @ pred_aln_trg
F0_pred, N_pred = self.predictor.F0Ntrain(en, s)
t_en = self.text_encoder(input_ids, input_lengths, text_mask)
asr = t_en @ pred_aln_trg
audio = self.decoder(asr, F0_pred, N_pred, ref_s[:, :128]).squeeze()
return audio, pred_dur
def forward(
self,
phonemes: str,
ref_s: torch.FloatTensor,
speed: float = 1,
return_output: bool = False
) -> Union['KModel.Output', torch.FloatTensor]:
input_ids = list(filter(lambda i: i is not None, map(lambda p: self.vocab.get(p), phonemes)))
logger.debug(f"phonemes: {phonemes} -> input_ids: {input_ids}")
assert len(input_ids)+2 <= self.context_length, (len(input_ids)+2, self.context_length)
input_ids = torch.LongTensor([[0, *input_ids, 0]]).to(self.device)
ref_s = ref_s.to(self.device)
audio, pred_dur = self.forward_with_tokens(input_ids, ref_s, speed)
audio = audio.squeeze().cpu()
pred_dur = pred_dur.cpu() if pred_dur is not None else None
logger.debug(f"pred_dur: {pred_dur}")
return self.Output(audio=audio, pred_dur=pred_dur) if return_output else audio
class KModelForONNX(torch.nn.Module):
def __init__(self, kmodel: KModel):
super().__init__()
self.kmodel = kmodel
def forward(
self,
input_ids: torch.LongTensor,
ref_s: torch.FloatTensor,
speed: float = 1
) -> tuple[torch.FloatTensor, torch.LongTensor]:
waveform, duration = self.kmodel.forward_with_tokens(input_ids, ref_s, speed)
return waveform, duration
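A sketch of driving KModel directly. Weights are fetched from the Hugging Face repo on first use; the phoneme string and the zero style vector are placeholders, since real ref_s vectors come from the voice packs handled by KPipeline (256 dims assumes style_dim=128):

import torch
from kokoro import KModel

model = KModel(repo_id='hexgrad/Kokoro-82M').eval()
ref_s = torch.zeros(1, 256)  # first 128 dims feed the decoder, the rest the prosody predictor
audio = model("hˈɛloʊ wˈɜːld", ref_s, speed=1.0)
print(audio.shape)           # 1-D waveform at 24 kHz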


@@ -0,0 +1,183 @@
# https://github.com/yl4579/StyleTTS2/blob/main/models.py
from .istftnet import AdainResBlk1d
from torch.nn.utils import weight_norm
from transformers import AlbertModel
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
class LinearNorm(nn.Module):
def __init__(self, in_dim, out_dim, bias=True, w_init_gain='linear'):
super(LinearNorm, self).__init__()
self.linear_layer = nn.Linear(in_dim, out_dim, bias=bias)
nn.init.xavier_uniform_(self.linear_layer.weight, gain=nn.init.calculate_gain(w_init_gain))
def forward(self, x):
return self.linear_layer(x)
class LayerNorm(nn.Module):
def __init__(self, channels, eps=1e-5):
super().__init__()
self.channels = channels
self.eps = eps
self.gamma = nn.Parameter(torch.ones(channels))
self.beta = nn.Parameter(torch.zeros(channels))
def forward(self, x):
x = x.transpose(1, -1)
x = F.layer_norm(x, (self.channels,), self.gamma, self.beta, self.eps)
return x.transpose(1, -1)
class TextEncoder(nn.Module):
def __init__(self, channels, kernel_size, depth, n_symbols, actv=nn.LeakyReLU(0.2)):
super().__init__()
self.embedding = nn.Embedding(n_symbols, channels)
padding = (kernel_size - 1) // 2
self.cnn = nn.ModuleList()
for _ in range(depth):
self.cnn.append(nn.Sequential(
weight_norm(nn.Conv1d(channels, channels, kernel_size=kernel_size, padding=padding)),
LayerNorm(channels),
actv,
nn.Dropout(0.2),
))
self.lstm = nn.LSTM(channels, channels//2, 1, batch_first=True, bidirectional=True)
def forward(self, x, input_lengths, m):
x = self.embedding(x) # [B, T, emb]
x = x.transpose(1, 2) # [B, emb, T]
m = m.unsqueeze(1)
x.masked_fill_(m, 0.0)
for c in self.cnn:
x = c(x)
x.masked_fill_(m, 0.0)
x = x.transpose(1, 2) # [B, T, chn]
lengths = input_lengths if input_lengths.device == torch.device('cpu') else input_lengths.to('cpu')
x = nn.utils.rnn.pack_padded_sequence(x, lengths, batch_first=True, enforce_sorted=False)
self.lstm.flatten_parameters()
x, _ = self.lstm(x)
x, _ = nn.utils.rnn.pad_packed_sequence(x, batch_first=True)
x = x.transpose(-1, -2)
x_pad = torch.zeros([x.shape[0], x.shape[1], m.shape[-1]], device=x.device)
x_pad[:, :, :x.shape[-1]] = x
x = x_pad
x.masked_fill_(m, 0.0)
return x
class AdaLayerNorm(nn.Module):
def __init__(self, style_dim, channels, eps=1e-5):
super().__init__()
self.channels = channels
self.eps = eps
self.fc = nn.Linear(style_dim, channels*2)
def forward(self, x, s):
x = x.transpose(-1, -2)
x = x.transpose(1, -1)
h = self.fc(s)
h = h.view(h.size(0), h.size(1), 1)
gamma, beta = torch.chunk(h, chunks=2, dim=1)
gamma, beta = gamma.transpose(1, -1), beta.transpose(1, -1)
x = F.layer_norm(x, (self.channels,), eps=self.eps)
x = (1 + gamma) * x + beta
return x.transpose(1, -1).transpose(-1, -2)
class ProsodyPredictor(nn.Module):
def __init__(self, style_dim, d_hid, nlayers, max_dur=50, dropout=0.1):
super().__init__()
self.text_encoder = DurationEncoder(sty_dim=style_dim, d_model=d_hid,nlayers=nlayers, dropout=dropout)
self.lstm = nn.LSTM(d_hid + style_dim, d_hid // 2, 1, batch_first=True, bidirectional=True)
self.duration_proj = LinearNorm(d_hid, max_dur)
self.shared = nn.LSTM(d_hid + style_dim, d_hid // 2, 1, batch_first=True, bidirectional=True)
self.F0 = nn.ModuleList()
self.F0.append(AdainResBlk1d(d_hid, d_hid, style_dim, dropout_p=dropout))
self.F0.append(AdainResBlk1d(d_hid, d_hid // 2, style_dim, upsample=True, dropout_p=dropout))
self.F0.append(AdainResBlk1d(d_hid // 2, d_hid // 2, style_dim, dropout_p=dropout))
self.N = nn.ModuleList()
self.N.append(AdainResBlk1d(d_hid, d_hid, style_dim, dropout_p=dropout))
self.N.append(AdainResBlk1d(d_hid, d_hid // 2, style_dim, upsample=True, dropout_p=dropout))
self.N.append(AdainResBlk1d(d_hid // 2, d_hid // 2, style_dim, dropout_p=dropout))
self.F0_proj = nn.Conv1d(d_hid // 2, 1, 1, 1, 0)
self.N_proj = nn.Conv1d(d_hid // 2, 1, 1, 1, 0)
def forward(self, texts, style, text_lengths, alignment, m):
d = self.text_encoder(texts, style, text_lengths, m)
m = m.unsqueeze(1)
lengths = text_lengths if text_lengths.device == torch.device('cpu') else text_lengths.to('cpu')
x = nn.utils.rnn.pack_padded_sequence(d, lengths, batch_first=True, enforce_sorted=False)
self.lstm.flatten_parameters()
x, _ = self.lstm(x)
x, _ = nn.utils.rnn.pad_packed_sequence(x, batch_first=True)
x_pad = torch.zeros([x.shape[0], m.shape[-1], x.shape[-1]], device=x.device)
x_pad[:, :x.shape[1], :] = x
x = x_pad
duration = self.duration_proj(nn.functional.dropout(x, 0.5, training=False))
en = (d.transpose(-1, -2) @ alignment)
return duration.squeeze(-1), en
def F0Ntrain(self, x, s):
x, _ = self.shared(x.transpose(-1, -2))
F0 = x.transpose(-1, -2)
for block in self.F0:
F0 = block(F0, s)
F0 = self.F0_proj(F0)
N = x.transpose(-1, -2)
for block in self.N:
N = block(N, s)
N = self.N_proj(N)
return F0.squeeze(1), N.squeeze(1)
class DurationEncoder(nn.Module):
def __init__(self, sty_dim, d_model, nlayers, dropout=0.1):
super().__init__()
self.lstms = nn.ModuleList()
for _ in range(nlayers):
self.lstms.append(nn.LSTM(d_model + sty_dim, d_model // 2, num_layers=1, batch_first=True, bidirectional=True, dropout=dropout))
self.lstms.append(AdaLayerNorm(sty_dim, d_model))
self.dropout = dropout
self.d_model = d_model
self.sty_dim = sty_dim
def forward(self, x, style, text_lengths, m):
masks = m
x = x.permute(2, 0, 1)
s = style.expand(x.shape[0], x.shape[1], -1)
x = torch.cat([x, s], axis=-1)
x.masked_fill_(masks.unsqueeze(-1).transpose(0, 1), 0.0)
x = x.transpose(0, 1)
x = x.transpose(-1, -2)
for block in self.lstms:
if isinstance(block, AdaLayerNorm):
x = block(x.transpose(-1, -2), style).transpose(-1, -2)
x = torch.cat([x, s.permute(1, 2, 0)], axis=1)
x.masked_fill_(masks.unsqueeze(-1).transpose(-1, -2), 0.0)
else:
lengths = text_lengths if text_lengths.device == torch.device('cpu') else text_lengths.to('cpu')
x = x.transpose(-1, -2)
x = nn.utils.rnn.pack_padded_sequence(
x, lengths, batch_first=True, enforce_sorted=False)
block.flatten_parameters()
x, _ = block(x)
x, _ = nn.utils.rnn.pad_packed_sequence(
x, batch_first=True)
x = F.dropout(x, p=self.dropout, training=False)
x = x.transpose(-1, -2)
x_pad = torch.zeros([x.shape[0], x.shape[1], m.shape[-1]], device=x.device)
x_pad[:, :, :x.shape[-1]] = x
x = x_pad
return x.transpose(-1, -2)
# https://github.com/yl4579/StyleTTS2/blob/main/Utils/PLBERT/util.py
class CustomAlbert(AlbertModel):
def forward(self, *args, **kwargs):
outputs = super().forward(*args, **kwargs)
return outputs.last_hidden_state
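A shape check for TextEncoder with made-up hyperparameters (the real values come from config.json); the mask marks padded positions, mirroring how KModel builds text_mask:

import torch

enc = TextEncoder(channels=128, kernel_size=5, depth=2, n_symbols=100)
tokens = torch.randint(0, 100, (2, 40))
lengths = torch.tensor([40, 25])
mask = torch.arange(40).unsqueeze(0).expand(2, -1) >= lengths.unsqueeze(1)  # True where padded
with torch.no_grad():
    out = enc(tokens, lengths, mask)
print(out.shape)  # torch.Size([2, 128, 40])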


@@ -0,0 +1,445 @@
from .model import KModel
from dataclasses import dataclass
from huggingface_hub import hf_hub_download
from loguru import logger
from misaki import en, espeak
from typing import Callable, Generator, List, Optional, Tuple, Union
import re
import torch
import os
ALIASES = {
'en-us': 'a',
'en-gb': 'b',
'es': 'e',
'fr-fr': 'f',
'hi': 'h',
'it': 'i',
'pt-br': 'p',
'ja': 'j',
'zh': 'z',
}
LANG_CODES = dict(
# pip install misaki[en]
a='American English',
b='British English',
# espeak-ng
e='es',
f='fr-fr',
h='hi',
i='it',
p='pt-br',
# pip install misaki[ja]
j='Japanese',
# pip install misaki[zh]
z='Mandarin Chinese',
)
class KPipeline:
'''
KPipeline is a language-aware support class with 2 main responsibilities:
1. Perform language-specific G2P, mapping (and chunking) text -> phonemes
2. Manage and store voices, lazily downloaded from HF if needed
You are expected to have one KPipeline per language. If you have multiple
KPipelines, you should reuse one KModel instance across all of them.
KPipeline is designed to work with a KModel, but this is not required.
There are 2 ways to pass an existing model into a pipeline:
1. On init: us_pipeline = KPipeline(lang_code='a', model=model)
2. On call: us_pipeline(text, voice, model=model)
By default, KPipeline will automatically initialize its own KModel. To
suppress this, construct a "quiet" KPipeline with model=False.
A "quiet" KPipeline yields (graphemes, phonemes, None) without generating
any audio. You can use this to phonemize and chunk your text in advance.
A "loud" KPipeline _with_ a model yields (graphemes, phonemes, audio).
'''
def __init__(
self,
lang_code: str,
repo_id: Optional[str] = None,
model: Union[KModel, bool] = True,
trf: bool = False,
en_callable: Optional[Callable[[str], str]] = None,
device: Optional[str] = None
):
"""Initialize a KPipeline.
Args:
lang_code: Language code for G2P processing
model: KModel instance, True to create new model, False for no model
trf: Whether to use transformer-based G2P
device: Override default device selection ('cuda' or 'cpu', or None for auto)
If None, will auto-select cuda if available
If 'cuda' and not available, will explicitly raise an error
"""
if repo_id is None:
repo_id = 'hexgrad/Kokoro-82M'
print(f"WARNING: Defaulting repo_id to {repo_id}. Pass repo_id='{repo_id}' to suppress this warning.")
config=None
else:
config = os.path.join(repo_id, 'config.json')
self.repo_id = repo_id
lang_code = lang_code.lower()
lang_code = ALIASES.get(lang_code, lang_code)
assert lang_code in LANG_CODES, (lang_code, LANG_CODES)
self.lang_code = lang_code
self.model = None
if isinstance(model, KModel):
self.model = model
elif model:
if device == 'cuda' and not torch.cuda.is_available():
raise RuntimeError("CUDA requested but not available")
if device == 'mps' and not torch.backends.mps.is_available():
raise RuntimeError("MPS requested but not available")
if device == 'mps' and os.environ.get('PYTORCH_ENABLE_MPS_FALLBACK') != '1':
raise RuntimeError("MPS requested but fallback not enabled")
if device is None:
if torch.cuda.is_available():
device = 'cuda'
elif os.environ.get('PYTORCH_ENABLE_MPS_FALLBACK') == '1' and torch.backends.mps.is_available():
device = 'mps'
else:
device = 'cpu'
try:
self.model = KModel(repo_id=repo_id, config=config).to(device).eval()
except RuntimeError as e:
if device == 'cuda':
raise RuntimeError(f"""Failed to initialize model on CUDA: {e}.
Try setting device='cpu' or check CUDA installation.""")
raise
self.voices = {}
if lang_code in 'ab':
try:
fallback = espeak.EspeakFallback(british=lang_code=='b')
except Exception as e:
logger.warning("EspeakFallback not Enabled: OOD words will be skipped")
logger.warning(str(e))
fallback = None
self.g2p = en.G2P(trf=trf, british=lang_code=='b', fallback=fallback, unk='')
elif lang_code == 'j':
try:
from misaki import ja
self.g2p = ja.JAG2P()
except ImportError:
logger.error("You need to `pip install misaki[ja]` to use lang_code='j'")
raise
elif lang_code == 'z':
try:
from misaki import zh
self.g2p = zh.ZHG2P(
version=None if repo_id.endswith('/Kokoro-82M') else '1.1',
en_callable=en_callable
)
except ImportError:
logger.error("You need to `pip install misaki[zh]` to use lang_code='z'")
raise
else:
language = LANG_CODES[lang_code]
logger.warning(f"Using EspeakG2P(language='{language}'). Chunking logic not yet implemented, so long texts may be truncated unless you split them with '\\n'.")
self.g2p = espeak.EspeakG2P(language=language)
def load_single_voice(self, voice: str):
if voice in self.voices:
return self.voices[voice]
if voice.endswith('.pt'):
f = voice
else:
f = hf_hub_download(repo_id=self.repo_id, filename=f'voices/{voice}.pt')
if not voice.startswith(self.lang_code):
v = LANG_CODES.get(voice, voice)
p = LANG_CODES.get(self.lang_code, self.lang_code)
logger.warning(f'Language mismatch, loading {v} voice into {p} pipeline.')
pack = torch.load(f, weights_only=True)
self.voices[voice] = pack
return pack
"""
load_voice is a helper function that lazily downloads and loads a voice:
Single voice can be requested (e.g. 'af_bella') or multiple voices (e.g. 'af_bella,af_jessica').
If multiple voices are requested, they are averaged.
Delimiter is optional and defaults to ','.
"""
def load_voice(self, voice: Union[str, torch.FloatTensor], delimiter: str = ",") -> torch.FloatTensor:
if isinstance(voice, torch.FloatTensor):
return voice
if voice in self.voices:
return self.voices[voice]
logger.debug(f"Loading voice: {voice}")
packs = [self.load_single_voice(v) for v in voice.split(delimiter)]
if len(packs) == 1:
return packs[0]
self.voices[voice] = torch.mean(torch.stack(packs), dim=0)
return self.voices[voice]
@staticmethod
def tokens_to_ps(tokens: List[en.MToken]) -> str:
return ''.join(t.phonemes + (' ' if t.whitespace else '') for t in tokens).strip()
@staticmethod
def waterfall_last(
tokens: List[en.MToken],
next_count: int,
waterfall: List[str] = ['!.?…', ':;', ',—'],
bumps: List[str] = [')', '']
) -> int:
for w in waterfall:
z = next((i for i, t in reversed(list(enumerate(tokens))) if t.phonemes in set(w)), None)
if z is None:
continue
z += 1
if z < len(tokens) and tokens[z].phonemes in bumps:
z += 1
if next_count - len(KPipeline.tokens_to_ps(tokens[:z])) <= 510:
return z
return len(tokens)
@staticmethod
def tokens_to_text(tokens: List[en.MToken]) -> str:
return ''.join(t.text + t.whitespace for t in tokens).strip()
def en_tokenize(
self,
tokens: List[en.MToken]
) -> Generator[Tuple[str, str, List[en.MToken]], None, None]:
tks = []
pcount = 0
for t in tokens:
# American English: ɾ => T
t.phonemes = '' if t.phonemes is None else t.phonemes#.replace('ɾ', 'T')
next_ps = t.phonemes + (' ' if t.whitespace else '')
next_pcount = pcount + len(next_ps.rstrip())
if next_pcount > 510:
z = KPipeline.waterfall_last(tks, next_pcount)
text = KPipeline.tokens_to_text(tks[:z])
logger.debug(f"Chunking text at {z}: '{text[:30]}{'...' if len(text) > 30 else ''}'")
ps = KPipeline.tokens_to_ps(tks[:z])
yield text, ps, tks[:z]
tks = tks[z:]
pcount = len(KPipeline.tokens_to_ps(tks))
if not tks:
next_ps = next_ps.lstrip()
tks.append(t)
pcount += len(next_ps)
if tks:
text = KPipeline.tokens_to_text(tks)
ps = KPipeline.tokens_to_ps(tks)
yield ''.join(text).strip(), ''.join(ps).strip(), tks
@staticmethod
def infer(
model: KModel,
ps: str,
pack: torch.FloatTensor,
speed: Union[float, Callable[[int], float]] = 1
) -> KModel.Output:
if callable(speed):
speed = speed(len(ps))
return model(ps, pack[len(ps)-1], speed, return_output=True)
def generate_from_tokens(
self,
tokens: Union[str, List[en.MToken]],
voice: str,
speed: float = 1,
model: Optional[KModel] = None
) -> Generator['KPipeline.Result', None, None]:
"""Generate audio from either raw phonemes or pre-processed tokens.
Args:
tokens: Either a phoneme string or list of pre-processed MTokens
voice: The voice to use for synthesis
speed: Speech speed modifier (default: 1)
model: Optional KModel instance (uses pipeline's model if not provided)
Yields:
KPipeline.Result containing the input tokens and generated audio
Raises:
ValueError: If no voice is provided or token sequence exceeds model limits
"""
model = model or self.model
if model and voice is None:
raise ValueError('Specify a voice: pipeline.generate_from_tokens(..., voice="af_heart")')
pack = self.load_voice(voice).to(model.device) if model else None
# Handle raw phoneme string
if isinstance(tokens, str):
logger.debug("Processing phonemes from raw string")
if len(tokens) > 510:
raise ValueError(f'Phoneme string too long: {len(tokens)} > 510')
output = KPipeline.infer(model, tokens, pack, speed) if model else None
yield self.Result(graphemes='', phonemes=tokens, output=output)
return
logger.debug("Processing MTokens")
# Handle pre-processed tokens
for gs, ps, tks in self.en_tokenize(tokens):
if not ps:
continue
elif len(ps) > 510:
logger.warning(f"Unexpected len(ps) == {len(ps)} > 510 and ps == '{ps}'")
logger.warning("Truncating to 510 characters")
ps = ps[:510]
output = KPipeline.infer(model, ps, pack, speed) if model else None
if output is not None and output.pred_dur is not None:
KPipeline.join_timestamps(tks, output.pred_dur)
yield self.Result(graphemes=gs, phonemes=ps, tokens=tks, output=output)
@staticmethod
def join_timestamps(tokens: List[en.MToken], pred_dur: torch.LongTensor):
# Multiply by 600 to go from pred_dur frames to sample_rate 24000
# Equivalent to dividing pred_dur frames by 40 to get timestamp in seconds
# We will count nice round half-frames, so the divisor is 80
MAGIC_DIVISOR = 80
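# Worked example: 24000 samples/s ÷ 600 samples/frame = 40 frames/s, i.e. 80
# half-frames/s, so a cumulative count of 120 half-frames maps to 120 / 80 = 1.5 s.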
if not tokens or len(pred_dur) < 3:
# We expect at least 3: <bos>, token, <eos>
return
# We track 2 counts, measured in half-frames: (left, right)
# This way we can cut space characters in half
# TODO: Is -3 an appropriate offset?
left = right = 2 * max(0, pred_dur[0].item() - 3)
# Updates:
# left = right + (2 * token_dur) + space_dur
# right = left + space_dur
i = 1
for t in tokens:
if i >= len(pred_dur)-1:
break
if not t.phonemes:
if t.whitespace:
i += 1
left = right + pred_dur[i].item()
right = left + pred_dur[i].item()
i += 1
continue
j = i + len(t.phonemes)
if j >= len(pred_dur):
break
t.start_ts = left / MAGIC_DIVISOR
token_dur = pred_dur[i: j].sum().item()
space_dur = pred_dur[j].item() if t.whitespace else 0
left = right + (2 * token_dur) + space_dur
t.end_ts = left / MAGIC_DIVISOR
right = left + space_dur
i = j + (1 if t.whitespace else 0)
@dataclass
class Result:
graphemes: str
phonemes: str
tokens: Optional[List[en.MToken]] = None
output: Optional[KModel.Output] = None
text_index: Optional[int] = None
@property
def audio(self) -> Optional[torch.FloatTensor]:
return None if self.output is None else self.output.audio
@property
def pred_dur(self) -> Optional[torch.LongTensor]:
return None if self.output is None else self.output.pred_dur
### MARK: BEGIN BACKWARD COMPAT ###
def __iter__(self):
yield self.graphemes
yield self.phonemes
yield self.audio
def __getitem__(self, index):
return [self.graphemes, self.phonemes, self.audio][index]
def __len__(self):
return 3
#### MARK: END BACKWARD COMPAT ####
def __call__(
self,
text: Union[str, List[str]],
voice: Optional[str] = None,
speed: Union[float, Callable[[int], float]] = 1,
split_pattern: Optional[str] = r'\n+',
model: Optional[KModel] = None
) -> Generator['KPipeline.Result', None, None]:
model = model or self.model
if model and voice is None:
raise ValueError('Specify a voice: en_us_pipeline(text="Hello world!", voice="af_heart")')
pack = self.load_voice(voice).to(model.device) if model else None
# Convert input to list of segments
if isinstance(text, str):
text = re.split(split_pattern, text.strip()) if split_pattern else [text]
# Process each segment
for graphemes_index, graphemes in enumerate(text):
if not graphemes.strip(): # Skip empty segments
continue
# English processing (unchanged)
if self.lang_code in 'ab':
logger.debug(f"Processing English text: {graphemes[:50]}{'...' if len(graphemes) > 50 else ''}")
_, tokens = self.g2p(graphemes)
for gs, ps, tks in self.en_tokenize(tokens):
if not ps:
continue
elif len(ps) > 510:
logger.warning(f"Unexpected len(ps) == {len(ps)} > 510 and ps == '{ps}'")
ps = ps[:510]
output = KPipeline.infer(model, ps, pack, speed) if model else None
if output is not None and output.pred_dur is not None:
KPipeline.join_timestamps(tks, output.pred_dur)
yield self.Result(graphemes=gs, phonemes=ps, tokens=tks, output=output, text_index=graphemes_index)
# Non-English processing with chunking
else:
# Split long text into smaller chunks (roughly 400 characters each)
# Using sentence boundaries when possible
chunk_size = 400
chunks = []
# Try to split on sentence boundaries first
sentences = re.split(r'([.!?]+)', graphemes)
current_chunk = ""
for i in range(0, len(sentences), 2):
sentence = sentences[i]
# Add the punctuation back if it exists
if i + 1 < len(sentences):
sentence += sentences[i + 1]
if len(current_chunk) + len(sentence) <= chunk_size:
current_chunk += sentence
else:
if current_chunk:
chunks.append(current_chunk.strip())
current_chunk = sentence
if current_chunk:
chunks.append(current_chunk.strip())
# If no chunks were created (no sentence boundaries), fall back to character-based chunking
if not chunks:
chunks = [graphemes[i:i+chunk_size] for i in range(0, len(graphemes), chunk_size)]
# Process each chunk
for chunk in chunks:
if not chunk.strip():
continue
ps, _ = self.g2p(chunk)
if not ps:
continue
elif len(ps) > 510:
logger.warning(f'Truncating len(ps) == {len(ps)} > 510')
ps = ps[:510]
output = KPipeline.infer(model, ps, pack, speed) if model else None
yield self.Result(graphemes=chunk, phonemes=ps, output=output, text_index=graphemes_index)

319
wan/multitalk/multitalk.py Normal file
View File

@ -0,0 +1,319 @@
import random
import os
import torch
import torch.distributed as dist
from PIL import Image
import subprocess
import torchvision.transforms as transforms
import torch.nn.functional as F
import torch.nn as nn
import wan
from wan.configs import SIZE_CONFIGS, SUPPORTED_SIZES, WAN_CONFIGS
from wan.utils.utils import cache_image, cache_video, str2bool
# from wan.utils.multitalk_utils import save_video_ffmpeg
# from .kokoro import KPipeline
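# NOTE: process_tts_single / process_tts_multi below expect KPipeline to be importable
# (see the Kokoro pipeline shipped alongside this module); the import above is left
# commented out, so re-enable it before calling those helpers.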
from transformers import Wav2Vec2FeatureExtractor
from .wav2vec2 import Wav2Vec2Model
import librosa
import pyloudnorm as pyln
import numpy as np
from einops import rearrange
import soundfile as sf
import re
import math
def custom_init(device, wav2vec):
audio_encoder = Wav2Vec2Model.from_pretrained(wav2vec, local_files_only=True).to(device)
audio_encoder.feature_extractor._freeze_parameters()
wav2vec_feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(wav2vec, local_files_only=True)
return wav2vec_feature_extractor, audio_encoder
def loudness_norm(audio_array, sr=16000, lufs=-23):
meter = pyln.Meter(sr)
loudness = meter.integrated_loudness(audio_array)
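# pyloudnorm reports -inf (or extreme values) for silent/near-silent input;
# skip normalization in that case.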
if abs(loudness) > 100:
return audio_array
normalized_audio = pyln.normalize.loudness(audio_array, loudness, lufs)
return normalized_audio
def get_embedding(speech_array, wav2vec_feature_extractor, audio_encoder, sr=16000, device='cpu', fps = 25):
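# Convert a 16 kHz waveform into per-video-frame wav2vec2 features: the encoder is asked
# for seq_len == int(duration * fps) hidden states, and the stacked hidden layers are
# returned as a CPU tensor shaped [frames, layers, channels].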
audio_duration = len(speech_array) / sr
video_length = audio_duration * fps
# wav2vec_feature_extractor
audio_feature = np.squeeze(
wav2vec_feature_extractor(speech_array, sampling_rate=sr).input_values
)
audio_feature = torch.from_numpy(audio_feature).float().to(device=device)
audio_feature = audio_feature.unsqueeze(0)
# audio encoder
with torch.no_grad():
embeddings = audio_encoder(audio_feature, seq_len=int(video_length), output_hidden_states=True)
if len(embeddings) == 0:
print("Fail to extract audio embedding")
return None
audio_emb = torch.stack(embeddings.hidden_states[1:], dim=1).squeeze(0)
audio_emb = rearrange(audio_emb, "b s d -> s b d")
audio_emb = audio_emb.cpu().detach()
return audio_emb
def audio_prepare_single(audio_path, sample_rate=16000, duration = 0):
ext = os.path.splitext(audio_path)[1].lower()
if ext in ['.mp4', '.mov', '.avi', '.mkv']:
human_speech_array = extract_audio_from_video(audio_path, sample_rate)
return human_speech_array
else:
human_speech_array, sr = librosa.load(audio_path, duration=duration, sr=sample_rate)
human_speech_array = loudness_norm(human_speech_array, sr)
return human_speech_array
def audio_prepare_multi(left_path, right_path, audio_type = "add", sample_rate=16000, duration = 0):
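# 'para' keeps the two speakers overlapping in time; 'add' plays them back to back,
# padding each stream with silence of the other's length so the summed track is sequential.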
if left_path is not None and right_path is not None:
human_speech_array1 = audio_prepare_single(left_path, duration = duration)
human_speech_array2 = audio_prepare_single(right_path, duration = duration)
elif left_path is None:
human_speech_array2 = audio_prepare_single(right_path, duration = duration)
human_speech_array1 = np.zeros(human_speech_array2.shape[0])
elif right_path is None:
human_speech_array1 = audio_prepare_single(left_path, duration = duration)
human_speech_array2 = np.zeros(human_speech_array1.shape[0])
if audio_type=='para':
new_human_speech1 = human_speech_array1
new_human_speech2 = human_speech_array2
elif audio_type=='add':
new_human_speech1 = np.concatenate([human_speech_array1[: human_speech_array1.shape[0]], np.zeros(human_speech_array2.shape[0])])
new_human_speech2 = np.concatenate([np.zeros(human_speech_array1.shape[0]), human_speech_array2[:human_speech_array2.shape[0]]])
sum_human_speechs = new_human_speech1 + new_human_speech2
return new_human_speech1, new_human_speech2, sum_human_speechs
def process_tts_single(text, save_dir, voice1):
s1_sentences = []
pipeline = KPipeline(lang_code='a', repo_id='weights/Kokoro-82M')
voice_tensor = torch.load(voice1, weights_only=True)
generator = pipeline(
text, voice=voice_tensor, # <= change voice here
speed=1, split_pattern=r'\n+'
)
audios = []
for i, (gs, ps, audio) in enumerate(generator):
audios.append(audio)
audios = torch.concat(audios, dim=0)
s1_sentences.append(audios)
s1_sentences = torch.concat(s1_sentences, dim=0)
save_path1 =f'{save_dir}/s1.wav'
sf.write(save_path1, s1_sentences, 24000) # save each audio file
s1, _ = librosa.load(save_path1, sr=16000)
return s1, save_path1
def process_tts_multi(text, save_dir, voice1, voice2):
pattern = r'\(s(\d+)\)\s*(.*?)(?=\s*\(s\d+\)|$)'
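# Splits markup of the form "(s1) hello there (s2) hi!" into (speaker_id, content) pairs.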
matches = re.findall(pattern, text, re.DOTALL)
s1_sentences = []
s2_sentences = []
pipeline = KPipeline(lang_code='a', repo_id='weights/Kokoro-82M')
for idx, (speaker, content) in enumerate(matches):
if speaker == '1':
voice_tensor = torch.load(voice1, weights_only=True)
generator = pipeline(
content, voice=voice_tensor, # <= change voice here
speed=1, split_pattern=r'\n+'
)
audios = []
for i, (gs, ps, audio) in enumerate(generator):
audios.append(audio)
audios = torch.concat(audios, dim=0)
s1_sentences.append(audios)
s2_sentences.append(torch.zeros_like(audios))
elif speaker == '2':
voice_tensor = torch.load(voice2, weights_only=True)
generator = pipeline(
content, voice=voice_tensor, # <= change voice here
speed=1, split_pattern=r'\n+'
)
audios = []
for i, (gs, ps, audio) in enumerate(generator):
audios.append(audio)
audios = torch.concat(audios, dim=0)
s2_sentences.append(audios)
s1_sentences.append(torch.zeros_like(audios))
s1_sentences = torch.concat(s1_sentences, dim=0)
s2_sentences = torch.concat(s2_sentences, dim=0)
sum_sentences = s1_sentences + s2_sentences
save_path1 =f'{save_dir}/s1.wav'
save_path2 =f'{save_dir}/s2.wav'
save_path_sum = f'{save_dir}/sum.wav'
sf.write(save_path1, s1_sentences, 24000) # save each audio file
sf.write(save_path2, s2_sentences, 24000)
sf.write(save_path_sum, sum_sentences, 24000)
s1, _ = librosa.load(save_path1, sr=16000)
s2, _ = librosa.load(save_path2, sr=16000)
# sum, _ = librosa.load(save_path_sum, sr=16000)
return s1, s2, save_path_sum
def get_full_audio_embeddings(audio_guide1 = None, audio_guide2 = None, combination_type ="add", num_frames = 0, fps = 25, sr = 16000):
wav2vec_feature_extractor, audio_encoder= custom_init('cpu', "ckpts/chinese-wav2vec2-base")
# wav2vec_feature_extractor, audio_encoder= custom_init('cpu', "ckpts/wav2vec")
new_human_speech1, new_human_speech2, sum_human_speechs = audio_prepare_multi(audio_guide1, audio_guide2, combination_type, duration= num_frames / fps)
audio_embedding_1 = get_embedding(new_human_speech1, wav2vec_feature_extractor, audio_encoder, sr=sr, fps= fps)
audio_embedding_2 = get_embedding(new_human_speech2, wav2vec_feature_extractor, audio_encoder, sr=sr, fps= fps)
full_audio_embs = []
if audio_guide1 is not None: full_audio_embs.append(audio_embedding_1)
if audio_guide2 is not None: full_audio_embs.append(audio_embedding_2)
if audio_guide2 is None: sum_human_speechs = None
return full_audio_embs, sum_human_speechs
def get_window_audio_embeddings(full_audio_embs, audio_start_idx=0, clip_length = 81, vae_scale = 4, audio_window = 5):
HUMAN_NUMBER = len(full_audio_embs)
audio_end_idx = audio_start_idx + clip_length
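# offsets [-2, -1, 0, 1, 2]: a 5-frame context window centred on each frame index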
indices = (torch.arange(2 * 2 + 1) - 2) * 1
audio_embs = []
# split audio with window size
for human_idx in range(HUMAN_NUMBER):
center_indices = torch.arange(
audio_start_idx,
audio_end_idx,
1
).unsqueeze(
1
) + indices.unsqueeze(0)
center_indices = torch.clamp(center_indices, min=0, max=full_audio_embs[human_idx].shape[0]-1).to(full_audio_embs[human_idx].device)
audio_emb = full_audio_embs[human_idx][center_indices][None,...] #.to(self.device)
audio_embs.append(audio_emb)
audio_embs = torch.concat(audio_embs, dim=0) #.to(self.param_dtype)
# audio_cond = audio.to(device=x.device, dtype=x.dtype)
audio_cond = audio_embs
first_frame_audio_emb_s = audio_cond[:, :1, ...]
latter_frame_audio_emb = audio_cond[:, 1:, ...]
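# The first latent frame corresponds to a single video frame, while every later latent
# frame covers `vae_scale` video frames, so the remaining window embeddings are regrouped
# accordingly before being summarised around the window centre.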
latter_frame_audio_emb = rearrange(latter_frame_audio_emb, "b (n_t n) w s c -> b n_t n w s c", n=vae_scale)
middle_index = audio_window // 2
latter_first_frame_audio_emb = latter_frame_audio_emb[:, :, :1, :middle_index+1, ...]
latter_first_frame_audio_emb = rearrange(latter_first_frame_audio_emb, "b n_t n w s c -> b n_t (n w) s c")
latter_last_frame_audio_emb = latter_frame_audio_emb[:, :, -1:, middle_index:, ...]
latter_last_frame_audio_emb = rearrange(latter_last_frame_audio_emb, "b n_t n w s c -> b n_t (n w) s c")
latter_middle_frame_audio_emb = latter_frame_audio_emb[:, :, 1:-1, middle_index:middle_index+1, ...]
latter_middle_frame_audio_emb = rearrange(latter_middle_frame_audio_emb, "b n_t n w s c -> b n_t (n w) s c")
latter_frame_audio_emb_s = torch.concat([latter_first_frame_audio_emb, latter_middle_frame_audio_emb, latter_last_frame_audio_emb], dim=2)
return [first_frame_audio_emb_s, latter_frame_audio_emb_s]
def resize_and_centercrop(cond_image, target_size):
"""
Resize image or tensor to the target size without padding.
"""
# Get the original size
if isinstance(cond_image, torch.Tensor):
_, orig_h, orig_w = cond_image.shape
else:
orig_h, orig_w = cond_image.height, cond_image.width
target_h, target_w = target_size
# Calculate the scaling factor for resizing
scale_h = target_h / orig_h
scale_w = target_w / orig_w
# Compute the final size
scale = max(scale_h, scale_w)
final_h = math.ceil(scale * orig_h)
final_w = math.ceil(scale * orig_w)
# Resize
if isinstance(cond_image, torch.Tensor):
if len(cond_image.shape) == 3:
cond_image = cond_image[None]
resized_tensor = nn.functional.interpolate(cond_image, size=(final_h, final_w), mode='nearest').contiguous()
# crop
cropped_tensor = transforms.functional.center_crop(resized_tensor, target_size)
cropped_tensor = cropped_tensor.squeeze(0)
else:
resized_image = cond_image.resize((final_w, final_h), resample=Image.BILINEAR)
resized_image = np.array(resized_image)
# tensor and crop
resized_tensor = torch.from_numpy(resized_image)[None, ...].permute(0, 3, 1, 2).contiguous()
cropped_tensor = transforms.functional.center_crop(resized_tensor, target_size)
cropped_tensor = cropped_tensor[:, :, None, :, :]
return cropped_tensor
def timestep_transform(
t,
shift=5.0,
num_timesteps=1000,
):
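# Rescales a uniform timestep toward larger values: e.g. with shift=5 and
# num_timesteps=1000, t=500 maps to 5*0.5 / (1 + 4*0.5) * 1000 ≈ 833.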
t = t / num_timesteps
# shift the timestep based on ratio
new_t = shift * t / (1 + (shift - 1) * t)
new_t = new_t * num_timesteps
return new_t
# construct human mask
def get_target_masks(HUMAN_NUMBER, lat_h, lat_w, src_h, src_w, face_scale = 0.05, bbox = None):
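# Build one spatial mask per speaker plus a background mask, then downsample them to the
# latent token grid (lat_h//2 x lat_w//2). With two speakers and no bbox, the frame is
# simply split into left and right halves, shrunk by `face_scale` at the borders.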
human_masks = []
if HUMAN_NUMBER==1:
background_mask = torch.ones([src_h, src_w])
human_mask1 = torch.ones([src_h, src_w])
human_mask2 = torch.ones([src_h, src_w])
human_masks = [human_mask1, human_mask2, background_mask]
elif HUMAN_NUMBER==2:
if bbox is not None:
assert len(bbox) == HUMAN_NUMBER, "The number of target bboxes should match the number of cond_audio entries"
background_mask = torch.zeros([src_h, src_w])
for _, person_bbox in bbox.items():
x_min, y_min, x_max, y_max = person_bbox
human_mask = torch.zeros([src_h, src_w])
human_mask[int(x_min):int(x_max), int(y_min):int(y_max)] = 1
background_mask += human_mask
human_masks.append(human_mask)
else:
x_min, x_max = int(src_h * face_scale), int(src_h * (1 - face_scale))
background_mask = torch.zeros([src_h, src_w])
human_mask1 = torch.zeros([src_h, src_w])
human_mask2 = torch.zeros([src_h, src_w])
lefty_min, lefty_max = int((src_w//2) * face_scale), int((src_w//2) * (1 - face_scale))
righty_min, righty_max = int((src_w//2) * face_scale + (src_w//2)), int((src_w//2) * (1 - face_scale) + (src_w//2))
human_mask1[x_min:x_max, lefty_min:lefty_max] = 1
human_mask2[x_min:x_max, righty_min:righty_max] = 1
background_mask += human_mask1
background_mask += human_mask2
human_masks = [human_mask1, human_mask2]
background_mask = torch.where(background_mask > 0, torch.tensor(0), torch.tensor(1))
human_masks.append(background_mask)
ref_target_masks = torch.stack(human_masks, dim=0) #.to(self.device)
# resize and centercrop for ref_target_masks
# ref_target_masks = resize_and_centercrop(ref_target_masks, (target_h, target_w))
N_h, N_w = lat_h // 2, lat_w // 2
token_ref_target_masks = F.interpolate(ref_target_masks.unsqueeze(0), size=(N_h, N_w), mode='nearest').squeeze()
token_ref_target_masks = (token_ref_target_masks > 0)
token_ref_target_masks = token_ref_target_masks.float() #.to(self.device)
token_ref_target_masks = token_ref_target_masks.view(token_ref_target_masks.shape[0], -1)
return token_ref_target_masks

View File

@ -0,0 +1,799 @@
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import math
import numpy as np
import os
import torch
import torch.cuda.amp as amp
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from diffusers import ModelMixin
from diffusers.configuration_utils import ConfigMixin, register_to_config
from .attention import flash_attention, SingleStreamMutiAttention
from ..utils.multitalk_utils import get_attn_map_with_target
__all__ = ['WanModel']
def sinusoidal_embedding_1d(dim, position):
# preprocess
assert dim % 2 == 0
half = dim // 2
position = position.type(torch.float64)
# calculation
sinusoid = torch.outer(
position, torch.pow(10000, -torch.arange(half).to(position).div(half)))
x = torch.cat([torch.cos(sinusoid), torch.sin(sinusoid)], dim=1)
return x
@amp.autocast(enabled=False)
def rope_params(max_seq_len, dim, theta=10000):
assert dim % 2 == 0
freqs = torch.outer(
torch.arange(max_seq_len),
1.0 / torch.pow(theta,
torch.arange(0, dim, 2).to(torch.float64).div(dim)))
freqs = torch.polar(torch.ones_like(freqs), freqs)
return freqs
@amp.autocast(enabled=False)
def rope_apply(x, grid_sizes, freqs):
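# 3D RoPE: the per-head channels are split into temporal / height / width bands, and each
# band is rotated with frequencies indexed by the frame, row and column of every video token.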
s, n, c = x.size(1), x.size(2), x.size(3) // 2
freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)
output = []
for i, (f, h, w) in enumerate(grid_sizes.tolist()):
seq_len = f * h * w
x_i = torch.view_as_complex(x[i, :s].to(torch.float64).reshape(
s, n, -1, 2))
freqs_i = torch.cat([
freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1)
],
dim=-1).reshape(seq_len, 1, -1)
freqs_i = freqs_i.to(device=x_i.device)
x_i = torch.view_as_real(x_i * freqs_i).flatten(2)
x_i = torch.cat([x_i, x[i, seq_len:]])
output.append(x_i)
return torch.stack(output).float()
class WanRMSNorm(nn.Module):
def __init__(self, dim, eps=1e-5):
super().__init__()
self.dim = dim
self.eps = eps
self.weight = nn.Parameter(torch.ones(dim))
def forward(self, x):
r"""
Args:
x(Tensor): Shape [B, L, C]
"""
return self._norm(x.float()).type_as(x) * self.weight
def _norm(self, x):
return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
class WanLayerNorm(nn.LayerNorm):
def __init__(self, dim, eps=1e-6, elementwise_affine=False):
super().__init__(dim, elementwise_affine=elementwise_affine, eps=eps)
def forward(self, inputs: torch.Tensor) -> torch.Tensor:
origin_dtype = inputs.dtype
out = F.layer_norm(
inputs.float(),
self.normalized_shape,
None if self.weight is None else self.weight.float(),
None if self.bias is None else self.bias.float() ,
self.eps
).to(origin_dtype)
return out
class WanSelfAttention(nn.Module):
def __init__(self,
dim,
num_heads,
window_size=(-1, -1),
qk_norm=True,
eps=1e-6):
assert dim % num_heads == 0
super().__init__()
self.dim = dim
self.num_heads = num_heads
self.head_dim = dim // num_heads
self.window_size = window_size
self.qk_norm = qk_norm
self.eps = eps
# layers
self.q = nn.Linear(dim, dim)
self.k = nn.Linear(dim, dim)
self.v = nn.Linear(dim, dim)
self.o = nn.Linear(dim, dim)
self.norm_q = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity()
self.norm_k = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity()
def forward(self, x, seq_lens, grid_sizes, freqs, ref_target_masks=None):
b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim
# query, key, value function
def qkv_fn(x):
q = self.norm_q(self.q(x)).view(b, s, n, d)
k = self.norm_k(self.k(x)).view(b, s, n, d)
v = self.v(x).view(b, s, n, d)
return q, k, v
q, k, v = qkv_fn(x)
q = rope_apply(q, grid_sizes, freqs)
k = rope_apply(k, grid_sizes, freqs)
x = flash_attention(
q=q,
k=k,
v=v,
k_lens=seq_lens,
window_size=self.window_size
).type_as(x)
# output
x = x.flatten(2)
x = self.o(x)
with torch.no_grad():
x_ref_attn_map = get_attn_map_with_target(q.type_as(x), k.type_as(x), grid_sizes[0],
ref_target_masks=ref_target_masks)
return x, x_ref_attn_map
class WanI2VCrossAttention(WanSelfAttention):
def __init__(self,
dim,
num_heads,
window_size=(-1, -1),
qk_norm=True,
eps=1e-6):
super().__init__(dim, num_heads, window_size, qk_norm, eps)
self.k_img = nn.Linear(dim, dim)
self.v_img = nn.Linear(dim, dim)
self.norm_k_img = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity()
def forward(self, x, context, context_lens):
context_img = context[:, :257]
context = context[:, 257:]
b, n, d = x.size(0), self.num_heads, self.head_dim
# compute query, key, value
q = self.norm_q(self.q(x)).view(b, -1, n, d)
k = self.norm_k(self.k(context)).view(b, -1, n, d)
v = self.v(context).view(b, -1, n, d)
k_img = self.norm_k_img(self.k_img(context_img)).view(b, -1, n, d)
v_img = self.v_img(context_img).view(b, -1, n, d)
img_x = flash_attention(q, k_img, v_img, k_lens=None)
# compute attention
x = flash_attention(q, k, v, k_lens=context_lens)
# output
x = x.flatten(2)
img_x = img_x.flatten(2)
x = x + img_x
x = self.o(x)
return x
class WanAttentionBlock(nn.Module):
def __init__(self,
cross_attn_type,
dim,
ffn_dim,
num_heads,
window_size=(-1, -1),
qk_norm=True,
cross_attn_norm=False,
eps=1e-6,
output_dim=768,
norm_input_visual=True,
class_range=24,
class_interval=4):
super().__init__()
self.dim = dim
self.ffn_dim = ffn_dim
self.num_heads = num_heads
self.window_size = window_size
self.qk_norm = qk_norm
self.cross_attn_norm = cross_attn_norm
self.eps = eps
# layers
self.norm1 = WanLayerNorm(dim, eps)
self.self_attn = WanSelfAttention(dim, num_heads, window_size, qk_norm, eps)
self.norm3 = WanLayerNorm(
dim, eps,
elementwise_affine=True) if cross_attn_norm else nn.Identity()
self.cross_attn = WanI2VCrossAttention(dim,
num_heads,
(-1, -1),
qk_norm,
eps)
self.norm2 = WanLayerNorm(dim, eps)
self.ffn = nn.Sequential(
nn.Linear(dim, ffn_dim), nn.GELU(approximate='tanh'),
nn.Linear(ffn_dim, dim))
# modulation
self.modulation = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5)
# init audio module
self.audio_cross_attn = SingleStreamMutiAttention(
dim=dim,
encoder_hidden_states_dim=output_dim,
num_heads=num_heads,
qk_norm=False,
qkv_bias=True,
eps=eps,
norm_layer=WanRMSNorm,
class_range=class_range,
class_interval=class_interval
)
self.norm_x = WanLayerNorm(dim, eps, elementwise_affine=True) if norm_input_visual else nn.Identity()
def forward(
self,
x,
e,
seq_lens,
grid_sizes,
freqs,
context,
context_lens,
audio_embedding=None,
ref_target_masks=None,
human_num=None,
):
dtype = x.dtype
assert e.dtype == torch.float32
with amp.autocast(dtype=torch.float32):
e = (self.modulation.to(e.device) + e).chunk(6, dim=1)
assert e[0].dtype == torch.float32
# self-attention
y, x_ref_attn_map = self.self_attn(
(self.norm1(x).float() * (1 + e[1]) + e[0]).type_as(x), seq_lens, grid_sizes,
freqs, ref_target_masks=ref_target_masks)
with amp.autocast(dtype=torch.float32):
x = x + y * e[2]
x = x.to(dtype)
# cross-attention of text
x = x + self.cross_attn(self.norm3(x), context, context_lens)
# cross attn of audio
x_a = self.audio_cross_attn(self.norm_x(x), encoder_hidden_states=audio_embedding,
shape=grid_sizes[0], x_ref_attn_map=x_ref_attn_map, human_num=human_num)
x = x + x_a
y = self.ffn((self.norm2(x).float() * (1 + e[4]) + e[3]).to(dtype))
with amp.autocast(dtype=torch.float32):
x = x + y * e[5]
x = x.to(dtype)
return x
class Head(nn.Module):
def __init__(self, dim, out_dim, patch_size, eps=1e-6):
super().__init__()
self.dim = dim
self.out_dim = out_dim
self.patch_size = patch_size
self.eps = eps
# layers
out_dim = math.prod(patch_size) * out_dim
self.norm = WanLayerNorm(dim, eps)
self.head = nn.Linear(dim, out_dim)
# modulation
self.modulation = nn.Parameter(torch.randn(1, 2, dim) / dim**0.5)
def forward(self, x, e):
r"""
Args:
x(Tensor): Shape [B, L1, C]
e(Tensor): Shape [B, C]
"""
assert e.dtype == torch.float32
with amp.autocast(dtype=torch.float32):
e = (self.modulation.to(e.device) + e.unsqueeze(1)).chunk(2, dim=1)
x = (self.head(self.norm(x) * (1 + e[1]) + e[0]))
return x
class MLPProj(torch.nn.Module):
def __init__(self, in_dim, out_dim):
super().__init__()
self.proj = torch.nn.Sequential(
torch.nn.LayerNorm(in_dim), torch.nn.Linear(in_dim, in_dim),
torch.nn.GELU(), torch.nn.Linear(in_dim, out_dim),
torch.nn.LayerNorm(out_dim))
def forward(self, image_embeds):
clip_extra_context_tokens = self.proj(image_embeds)
return clip_extra_context_tokens
class AudioProjModel(ModelMixin, ConfigMixin):
def __init__(
self,
seq_len=5,
seq_len_vf=12,
blocks=12,
channels=768,
intermediate_dim=512,
output_dim=768,
context_tokens=32,
norm_output_audio=False,
):
super().__init__()
self.seq_len = seq_len
self.blocks = blocks
self.channels = channels
self.input_dim = seq_len * blocks * channels
self.input_dim_vf = seq_len_vf * blocks * channels
self.intermediate_dim = intermediate_dim
self.context_tokens = context_tokens
self.output_dim = output_dim
# define multiple linear layers
self.proj1 = nn.Linear(self.input_dim, intermediate_dim)
self.proj1_vf = nn.Linear(self.input_dim_vf, intermediate_dim)
self.proj2 = nn.Linear(intermediate_dim, intermediate_dim)
self.proj3 = nn.Linear(intermediate_dim, context_tokens * output_dim)
self.norm = nn.LayerNorm(output_dim) if norm_output_audio else nn.Identity()
def forward(self, audio_embeds, audio_embeds_vf):
video_length = audio_embeds.shape[1] + audio_embeds_vf.shape[1]
B, _, _, S, C = audio_embeds.shape
# process audio of first frame
audio_embeds = rearrange(audio_embeds, "bz f w b c -> (bz f) w b c")
batch_size, window_size, blocks, channels = audio_embeds.shape
audio_embeds = audio_embeds.view(batch_size, window_size * blocks * channels)
# process audio of latter frame
audio_embeds_vf = rearrange(audio_embeds_vf, "bz f w b c -> (bz f) w b c")
batch_size_vf, window_size_vf, blocks_vf, channels_vf = audio_embeds_vf.shape
audio_embeds_vf = audio_embeds_vf.view(batch_size_vf, window_size_vf * blocks_vf * channels_vf)
# first projection
audio_embeds = torch.relu(self.proj1(audio_embeds))
audio_embeds_vf = torch.relu(self.proj1_vf(audio_embeds_vf))
audio_embeds = rearrange(audio_embeds, "(bz f) c -> bz f c", bz=B)
audio_embeds_vf = rearrange(audio_embeds_vf, "(bz f) c -> bz f c", bz=B)
audio_embeds_c = torch.concat([audio_embeds, audio_embeds_vf], dim=1)
batch_size_c, N_t, C_a = audio_embeds_c.shape
audio_embeds_c = audio_embeds_c.view(batch_size_c*N_t, C_a)
# second projection
audio_embeds_c = torch.relu(self.proj2(audio_embeds_c))
context_tokens = self.proj3(audio_embeds_c).reshape(batch_size_c*N_t, self.context_tokens, self.output_dim)
# normalization and reshape
context_tokens = self.norm(context_tokens)
context_tokens = rearrange(context_tokens, "(bz f) m c -> bz f m c", f=video_length)
return context_tokens
class WanModel(ModelMixin, ConfigMixin):
r"""
Wan diffusion backbone supporting both text-to-video and image-to-video.
"""
ignore_for_config = [
'patch_size', 'cross_attn_norm', 'qk_norm', 'text_dim', 'window_size'
]
_no_split_modules = ['WanAttentionBlock']
@register_to_config
def __init__(self,
model_type='i2v',
patch_size=(1, 2, 2),
text_len=512,
in_dim=16,
dim=2048,
ffn_dim=8192,
freq_dim=256,
text_dim=4096,
out_dim=16,
num_heads=16,
num_layers=32,
window_size=(-1, -1),
qk_norm=True,
cross_attn_norm=True,
eps=1e-6,
# audio params
audio_window=5,
intermediate_dim=512,
output_dim=768,
context_tokens=32,
vae_scale=4, # vae timedownsample scale
norm_input_visual=True,
norm_output_audio=True):
super().__init__()
assert model_type == 'i2v', 'MultiTalk requires model_type to be i2v.'
self.model_type = model_type
self.patch_size = patch_size
self.text_len = text_len
self.in_dim = in_dim
self.dim = dim
self.ffn_dim = ffn_dim
self.freq_dim = freq_dim
self.text_dim = text_dim
self.out_dim = out_dim
self.num_heads = num_heads
self.num_layers = num_layers
self.window_size = window_size
self.qk_norm = qk_norm
self.cross_attn_norm = cross_attn_norm
self.eps = eps
self.norm_output_audio = norm_output_audio
self.audio_window = audio_window
self.intermediate_dim = intermediate_dim
self.vae_scale = vae_scale
# embeddings
self.patch_embedding = nn.Conv3d(
in_dim, dim, kernel_size=patch_size, stride=patch_size)
self.text_embedding = nn.Sequential(
nn.Linear(text_dim, dim), nn.GELU(approximate='tanh'),
nn.Linear(dim, dim))
self.time_embedding = nn.Sequential(
nn.Linear(freq_dim, dim), nn.SiLU(), nn.Linear(dim, dim))
self.time_projection = nn.Sequential(nn.SiLU(), nn.Linear(dim, dim * 6))
# blocks
cross_attn_type = 'i2v_cross_attn'
self.blocks = nn.ModuleList([
WanAttentionBlock(cross_attn_type, dim, ffn_dim, num_heads,
window_size, qk_norm, cross_attn_norm, eps,
output_dim=output_dim, norm_input_visual=norm_input_visual)
for _ in range(num_layers)
])
# head
self.head = Head(dim, out_dim, patch_size, eps)
assert (dim % num_heads) == 0 and (dim // num_heads) % 2 == 0
d = dim // num_heads
self.freqs = torch.cat([
rope_params(1024, d - 4 * (d // 6)),
rope_params(1024, 2 * (d // 6)),
rope_params(1024, 2 * (d // 6))
],
dim=1)
if model_type == 'i2v':
self.img_emb = MLPProj(1280, dim)
else:
raise NotImplementedError('Not supported model type.')
# init audio adapter
self.audio_proj = AudioProjModel(
seq_len=audio_window,
seq_len_vf=audio_window+vae_scale-1,
intermediate_dim=intermediate_dim,
output_dim=output_dim,
context_tokens=context_tokens,
norm_output_audio=norm_output_audio,
)
# initialize weights
self.init_weights()
def teacache_init(
self,
use_ret_steps=True,
teacache_thresh=0.2,
sample_steps=40,
model_scale='multitalk-480',
):
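# TeaCache bookkeeping: forward() is invoked three times per sampling step
# (cond / drop_text / uncond, cycled via cnt % 3). For each stream, a polynomial of the
# relative change in the modulated timestep embedding is accumulated; while it stays below
# `teacache_thresh`, the transformer blocks are skipped and the cached residual is reused.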
print("teacache_init")
self.enable_teacache = True
self.__class__.cnt = 0
self.__class__.num_steps = sample_steps*3
self.__class__.teacache_thresh = teacache_thresh
self.__class__.accumulated_rel_l1_distance_cond = 0
self.__class__.accumulated_rel_l1_distance_drop_text = 0
self.__class__.accumulated_rel_l1_distance_uncond = 0
self.__class__.previous_e0_cond = None
self.__class__.previous_e0_drop_text = None
self.__class__.previous_e0_uncond = None
self.__class__.previous_residual_cond = None
self.__class__.previous_residual_drop_text = None
self.__class__.previous_residual_uncond = None
self.__class__.use_ret_steps = use_ret_steps
if use_ret_steps:
if model_scale == 'multitalk-480':
self.__class__.coefficients = [ 2.57151496e+05, -3.54229917e+04, 1.40286849e+03, -1.35890334e+01, 1.32517977e-01]
if model_scale == 'multitalk-720':
self.__class__.coefficients = [ 8.10705460e+03, 2.13393892e+03, -3.72934672e+02, 1.66203073e+01, -4.17769401e-02]
self.__class__.ret_steps = 5*3
self.__class__.cutoff_steps = sample_steps*3
else:
if model_scale == 'multitalk-480':
self.__class__.coefficients = [-3.02331670e+02, 2.23948934e+02, -5.25463970e+01, 5.87348440e+00, -2.01973289e-01]
if model_scale == 'multitalk-720':
self.__class__.coefficients = [-114.36346466, 65.26524496, -18.82220707, 4.91518089, -0.23412683]
self.__class__.ret_steps = 1*3
self.__class__.cutoff_steps = sample_steps*3 - 3
print("teacache_init done")
def disable_teacache(self):
self.enable_teacache = False
def forward(
self,
x,
t,
context,
seq_len,
clip_fea=None,
y=None,
audio=None,
ref_target_masks=None,
):
assert clip_fea is not None and y is not None
_, T, H, W = x[0].shape
N_t = T // self.patch_size[0]
N_h = H // self.patch_size[1]
N_w = W // self.patch_size[2]
if y is not None:
x = [torch.cat([u, v], dim=0) for u, v in zip(x, y)]
x[0] = x[0].to(context[0].dtype)
# embeddings
x = [self.patch_embedding(u.unsqueeze(0)) for u in x]
grid_sizes = torch.stack(
[torch.tensor(u.shape[2:], dtype=torch.long) for u in x])
x = [u.flatten(2).transpose(1, 2) for u in x]
seq_lens = torch.tensor([u.size(1) for u in x], dtype=torch.long)
assert seq_lens.max() <= seq_len
x = torch.cat([
torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))],
dim=1) for u in x
])
# time embeddings
with amp.autocast(dtype=torch.float32):
e = self.time_embedding(
sinusoidal_embedding_1d(self.freq_dim, t).float())
e0 = self.time_projection(e).unflatten(1, (6, self.dim))
assert e.dtype == torch.float32 and e0.dtype == torch.float32
# text embedding
context_lens = None
context = self.text_embedding(
torch.stack([
torch.cat(
[u, u.new_zeros(self.text_len - u.size(0), u.size(1))])
for u in context
]))
# clip embedding
if clip_fea is not None:
context_clip = self.img_emb(clip_fea)
context = torch.concat([context_clip, context], dim=1).to(x.dtype)
audio_cond = audio.to(device=x.device, dtype=x.dtype)
first_frame_audio_emb_s = audio_cond[:, :1, ...]
latter_frame_audio_emb = audio_cond[:, 1:, ...]
latter_frame_audio_emb = rearrange(latter_frame_audio_emb, "b (n_t n) w s c -> b n_t n w s c", n=self.vae_scale)
middle_index = self.audio_window // 2
latter_first_frame_audio_emb = latter_frame_audio_emb[:, :, :1, :middle_index+1, ...]
latter_first_frame_audio_emb = rearrange(latter_first_frame_audio_emb, "b n_t n w s c -> b n_t (n w) s c")
latter_last_frame_audio_emb = latter_frame_audio_emb[:, :, -1:, middle_index:, ...]
latter_last_frame_audio_emb = rearrange(latter_last_frame_audio_emb, "b n_t n w s c -> b n_t (n w) s c")
latter_middle_frame_audio_emb = latter_frame_audio_emb[:, :, 1:-1, middle_index:middle_index+1, ...]
latter_middle_frame_audio_emb = rearrange(latter_middle_frame_audio_emb, "b n_t n w s c -> b n_t (n w) s c")
latter_frame_audio_emb_s = torch.concat([latter_first_frame_audio_emb, latter_middle_frame_audio_emb, latter_last_frame_audio_emb], dim=2)
audio_embedding = self.audio_proj(first_frame_audio_emb_s, latter_frame_audio_emb_s)
human_num = len(audio_embedding)
audio_embedding = torch.concat(audio_embedding.split(1), dim=2).to(x.dtype)
# convert ref_target_masks to token_ref_target_masks
if ref_target_masks is not None:
ref_target_masks = ref_target_masks.unsqueeze(0).to(torch.float32)
token_ref_target_masks = nn.functional.interpolate(ref_target_masks, size=(N_h, N_w), mode='nearest')
token_ref_target_masks = token_ref_target_masks.squeeze(0)
token_ref_target_masks = (token_ref_target_masks > 0)
token_ref_target_masks = token_ref_target_masks.view(token_ref_target_masks.shape[0], -1)
token_ref_target_masks = token_ref_target_masks.to(x.dtype)
# teacache
if self.enable_teacache:
modulated_inp = e0 if self.use_ret_steps else e
if self.cnt%3==0: # cond
if self.cnt < self.ret_steps or self.cnt >= self.cutoff_steps:
should_calc_cond = True
self.accumulated_rel_l1_distance_cond = 0
else:
rescale_func = np.poly1d(self.coefficients)
self.accumulated_rel_l1_distance_cond += rescale_func(((modulated_inp-self.previous_e0_cond).abs().mean() / self.previous_e0_cond.abs().mean()).cpu().item())
if self.accumulated_rel_l1_distance_cond < self.teacache_thresh:
should_calc_cond = False
else:
should_calc_cond = True
self.accumulated_rel_l1_distance_cond = 0
self.previous_e0_cond = modulated_inp.clone()
elif self.cnt%3==1: # drop_text
if self.cnt < self.ret_steps or self.cnt >= self.cutoff_steps:
should_calc_drop_text = True
self.accumulated_rel_l1_distance_drop_text = 0
else:
rescale_func = np.poly1d(self.coefficients)
self.accumulated_rel_l1_distance_drop_text += rescale_func(((modulated_inp-self.previous_e0_drop_text).abs().mean() / self.previous_e0_drop_text.abs().mean()).cpu().item())
if self.accumulated_rel_l1_distance_drop_text < self.teacache_thresh:
should_calc_drop_text = False
else:
should_calc_drop_text = True
self.accumulated_rel_l1_distance_drop_text = 0
self.previous_e0_drop_text = modulated_inp.clone()
else: # uncond
if self.cnt < self.ret_steps or self.cnt >= self.cutoff_steps:
should_calc_uncond = True
self.accumulated_rel_l1_distance_uncond = 0
else:
rescale_func = np.poly1d(self.coefficients)
self.accumulated_rel_l1_distance_uncond += rescale_func(((modulated_inp-self.previous_e0_uncond).abs().mean() / self.previous_e0_uncond.abs().mean()).cpu().item())
if self.accumulated_rel_l1_distance_uncond < self.teacache_thresh:
should_calc_uncond = False
else:
should_calc_uncond = True
self.accumulated_rel_l1_distance_uncond = 0
self.previous_e0_uncond = modulated_inp.clone()
# arguments
kwargs = dict(
e=e0,
seq_lens=seq_lens,
grid_sizes=grid_sizes,
freqs=self.freqs,
context=context,
context_lens=context_lens,
audio_embedding=audio_embedding,
ref_target_masks=token_ref_target_masks,
human_num=human_num,
)
if self.enable_teacache:
if self.cnt%3==0:
if not should_calc_cond:
x += self.previous_residual_cond
else:
ori_x = x.clone()
for block in self.blocks:
x = block(x, **kwargs)
self.previous_residual_cond = x - ori_x
elif self.cnt%3==1:
if not should_calc_drop_text:
x += self.previous_residual_drop_text
else:
ori_x = x.clone()
for block in self.blocks:
x = block(x, **kwargs)
self.previous_residual_drop_text = x - ori_x
else:
if not should_calc_uncond:
x += self.previous_residual_uncond
else:
ori_x = x.clone()
for block in self.blocks:
x = block(x, **kwargs)
self.previous_residual_uncond = x - ori_x
else:
for block in self.blocks:
x = block(x, **kwargs)
# head
x = self.head(x, e)
# unpatchify
x = self.unpatchify(x, grid_sizes)
if self.enable_teacache:
self.cnt += 1
if self.cnt >= self.num_steps:
self.cnt = 0
return torch.stack(x).float()
def unpatchify(self, x, grid_sizes):
r"""
Reconstruct video tensors from patch embeddings.
Args:
x (List[Tensor]):
List of patchified features, each with shape [L, C_out * prod(patch_size)]
grid_sizes (Tensor):
Original spatial-temporal grid dimensions before patching,
shape [B, 3] (3 dimensions correspond to F_patches, H_patches, W_patches)
Returns:
List[Tensor]:
Reconstructed video tensors with shape [C_out, F, H / 8, W / 8]
"""
c = self.out_dim
out = []
for u, v in zip(x, grid_sizes.tolist()):
u = u[:math.prod(v)].view(*v, *self.patch_size, c)
u = torch.einsum('fhwpqrc->cfphqwr', u)
u = u.reshape(c, *[i * j for i, j in zip(v, self.patch_size)])
out.append(u)
return out
def init_weights(self):
r"""
Initialize model parameters using Xavier initialization.
"""
# basic init
for m in self.modules():
if isinstance(m, nn.Linear):
nn.init.xavier_uniform_(m.weight)
if m.bias is not None:
nn.init.zeros_(m.bias)
# init embeddings
nn.init.xavier_uniform_(self.patch_embedding.weight.flatten(1))
for m in self.text_embedding.modules():
if isinstance(m, nn.Linear):
nn.init.normal_(m.weight, std=.02)
for m in self.time_embedding.modules():
if isinstance(m, nn.Linear):
nn.init.normal_(m.weight, std=.02)
# init output layer
nn.init.zeros_(self.head.head.weight)

View File

@ -0,0 +1,353 @@
import os
from einops import rearrange
import torch
import torch.nn as nn
from einops import rearrange, repeat
from functools import lru_cache
import imageio
import uuid
from tqdm import tqdm
import numpy as np
import subprocess
import soundfile as sf
import torchvision
import binascii
import os.path as osp
VID_EXTENSIONS = (".mp4", ".avi", ".mov", ".mkv")
ASPECT_RATIO_627 = {
'0.26': ([320, 1216], 1), '0.38': ([384, 1024], 1), '0.50': ([448, 896], 1), '0.67': ([512, 768], 1),
'0.82': ([576, 704], 1), '1.00': ([640, 640], 1), '1.22': ([704, 576], 1), '1.50': ([768, 512], 1),
'1.86': ([832, 448], 1), '2.00': ([896, 448], 1), '2.50': ([960, 384], 1), '2.83': ([1088, 384], 1),
'3.60': ([1152, 320], 1), '3.80': ([1216, 320], 1), '4.00': ([1280, 320], 1)}
ASPECT_RATIO_960 = {
'0.22': ([448, 2048], 1), '0.29': ([512, 1792], 1), '0.36': ([576, 1600], 1), '0.45': ([640, 1408], 1),
'0.55': ([704, 1280], 1), '0.63': ([768, 1216], 1), '0.76': ([832, 1088], 1), '0.88': ([896, 1024], 1),
'1.00': ([960, 960], 1), '1.14': ([1024, 896], 1), '1.31': ([1088, 832], 1), '1.50': ([1152, 768], 1),
'1.58': ([1216, 768], 1), '1.82': ([1280, 704], 1), '1.91': ([1344, 704], 1), '2.20': ([1408, 640], 1),
'2.30': ([1472, 640], 1), '2.67': ([1536, 576], 1), '2.89': ([1664, 576], 1), '3.62': ([1856, 512], 1),
'3.75': ([1920, 512], 1)}
def torch_gc():
torch.cuda.empty_cache()
torch.cuda.ipc_collect()
def split_token_counts_and_frame_ids(T, token_frame, world_size, rank):
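# Evenly split T * token_frame tokens across `world_size` ranks; return, for this `rank`,
# the number of tokens it owns in each frame it touches plus the ids of those frames.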
S = T * token_frame
split_sizes = [S // world_size + (1 if i < S % world_size else 0) for i in range(world_size)]
start = sum(split_sizes[:rank])
end = start + split_sizes[rank]
counts = [0] * T
for idx in range(start, end):
t = idx // token_frame
counts[t] += 1
counts_filtered = []
frame_ids = []
for t, c in enumerate(counts):
if c > 0:
counts_filtered.append(c)
frame_ids.append(t)
return counts_filtered, frame_ids
def normalize_and_scale(column, source_range, target_range, epsilon=1e-8):
source_min, source_max = source_range
new_min, new_max = target_range
normalized = (column - source_min) / (source_max - source_min + epsilon)
scaled = normalized * (new_max - new_min) + new_min
return scaled
# @torch.compile
def calculate_x_ref_attn_map(visual_q, ref_k, ref_target_masks, ref_images_count, mode='mean', attn_bias=None):
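# Attention from visual queries to reference-frame keys, aggregated per target mask:
# for each mask the attention over the selected reference tokens is averaged, then reduced
# over heads (mean or max), giving one [B, x_seqlens] map per class.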
ref_k = ref_k.to(visual_q.dtype).to(visual_q.device)
scale = 1.0 / visual_q.shape[-1] ** 0.5
visual_q = visual_q * scale
visual_q = visual_q.transpose(1, 2)
ref_k = ref_k.transpose(1, 2)
attn = visual_q @ ref_k.transpose(-2, -1)
if attn_bias is not None: attn += attn_bias
x_ref_attn_map_source = attn.softmax(-1) # B, H, x_seqlens, ref_seqlens
x_ref_attn_maps = []
ref_target_masks = ref_target_masks.to(visual_q.dtype)
x_ref_attn_map_source = x_ref_attn_map_source.to(visual_q.dtype)
for class_idx, ref_target_mask in enumerate(ref_target_masks):
ref_target_mask = ref_target_mask[None, None, None, ...]
x_ref_attnmap = x_ref_attn_map_source * ref_target_mask
x_ref_attnmap = x_ref_attnmap.sum(-1) / ref_target_mask.sum() # B, H, x_seqlens, ref_seqlens --> B, H, x_seqlens
x_ref_attnmap = x_ref_attnmap.permute(0, 2, 1) # B, x_seqlens, H
if mode == 'mean':
x_ref_attnmap = x_ref_attnmap.mean(-1) # B, x_seqlens
elif mode == 'max':
x_ref_attnmap = x_ref_attnmap.max(-1).values # B, x_seqlens
x_ref_attn_maps.append(x_ref_attnmap)
del attn
del x_ref_attn_map_source
return torch.concat(x_ref_attn_maps, dim=0)
def get_attn_map_with_target(visual_q, ref_k, shape, ref_target_masks=None, split_num=10, ref_images_count=0):
"""Args:
visual_q (torch.tensor): B M H K
ref_k (torch.tensor): B M H K
shape (tuple): (N_t, N_h, N_w)
ref_target_masks: [class_num, N_h * N_w]
"""
N_t, N_h, N_w = shape
x_seqlens = N_h * N_w
ref_k = ref_k[:, :x_seqlens]
if ref_images_count > 0:
visual_q_shape = visual_q.shape
visual_q = visual_q.reshape(visual_q_shape[0], N_t, -1)
visual_q = visual_q[:, ref_images_count:]
visual_q = visual_q.reshape(visual_q_shape[0], -1, *visual_q_shape[-2:])
_, seq_lens, heads, _ = visual_q.shape
class_num, _ = ref_target_masks.shape
x_ref_attn_maps = torch.zeros(class_num, seq_lens).to(visual_q.device).to(visual_q.dtype)
split_chunk = heads // split_num
for i in range(split_num):
x_ref_attn_maps_perhead = calculate_x_ref_attn_map(visual_q[:, :, i*split_chunk:(i+1)*split_chunk, :], ref_k[:, :, i*split_chunk:(i+1)*split_chunk, :], ref_target_masks, ref_images_count)
x_ref_attn_maps += x_ref_attn_maps_perhead
x_ref_attn_maps /= split_num
return x_ref_attn_maps
def rotate_half(x):
x = rearrange(x, "... (d r) -> ... d r", r=2)
x1, x2 = x.unbind(dim=-1)
x = torch.stack((-x2, x1), dim=-1)
return rearrange(x, "... d r -> ... (d r)")
class RotaryPositionalEmbedding1D(nn.Module):
def __init__(self,
head_dim,
):
super().__init__()
self.head_dim = head_dim
self.base = 10000
@lru_cache(maxsize=32)
def precompute_freqs_cis_1d(self, pos_indices):
freqs = 1.0 / (self.base ** (torch.arange(0, self.head_dim, 2)[: (self.head_dim // 2)].float() / self.head_dim))
freqs = freqs.to(pos_indices.device)
freqs = torch.einsum("..., f -> ... f", pos_indices.float(), freqs)
freqs = repeat(freqs, "... n -> ... (n r)", r=2)
return freqs
def forward(self, x, pos_indices):
"""1D RoPE.
Args:
query (torch.tensor): [B, head, seq, head_dim]
pos_indices (torch.tensor): [seq,]
Returns:
query with the same shape as input.
"""
freqs_cis = self.precompute_freqs_cis_1d(pos_indices)
x_ = x.float()
freqs_cis = freqs_cis.float().to(x.device)
cos, sin = freqs_cis.cos(), freqs_cis.sin()
cos, sin = rearrange(cos, 'n d -> 1 1 n d'), rearrange(sin, 'n d -> 1 1 n d')
x_ = (x_ * cos) + (rotate_half(x_) * sin)
return x_.type_as(x)
def rand_name(length=8, suffix=''):
name = binascii.b2a_hex(os.urandom(length)).decode('utf-8')
if suffix:
if not suffix.startswith('.'):
suffix = '.' + suffix
name += suffix
return name
def cache_video(tensor,
save_file=None,
fps=30,
suffix='.mp4',
nrow=8,
normalize=True,
value_range=(-1, 1),
retry=5):
# cache file
cache_file = osp.join('/tmp', rand_name(
suffix=suffix)) if save_file is None else save_file
# save to cache
error = None
for _ in range(retry):
try:
# preprocess
tensor = tensor.clamp(min(value_range), max(value_range))
tensor = torch.stack([
torchvision.utils.make_grid(
u, nrow=nrow, normalize=normalize, value_range=value_range)
for u in tensor.unbind(2)
],
dim=1).permute(1, 2, 3, 0)
tensor = (tensor * 255).type(torch.uint8).cpu()
# write video
writer = imageio.get_writer(cache_file, fps=fps, codec='libx264', quality=10, ffmpeg_params=["-crf", "10"])
for frame in tensor.numpy():
writer.append_data(frame)
writer.close()
return cache_file
except Exception as e:
# keep the last error and retry; give up after `retry` attempts
error = e
continue
print(f'cache_video failed, error: {error}', flush=True)
return None
def save_video_ffmpeg(gen_video_samples, save_path, vocal_audio_list, fps=25, quality=5, high_quality_save=False):
def save_video(frames, save_path, fps, quality=9, ffmpeg_params=None):
writer = imageio.get_writer(
save_path, fps=fps, quality=quality, ffmpeg_params=ffmpeg_params
)
for frame in tqdm(frames, desc="Saving video"):
frame = np.array(frame)
writer.append_data(frame)
writer.close()
save_path_tmp = save_path + "-temp.mp4"
if high_quality_save:
cache_video(
tensor=gen_video_samples.unsqueeze(0),
save_file=save_path_tmp,
fps=fps,
nrow=1,
normalize=True,
value_range=(-1, 1)
)
else:
video_audio = (gen_video_samples+1)/2 # C T H W
video_audio = video_audio.permute(1, 2, 3, 0).cpu().numpy()
video_audio = np.clip(video_audio * 255, 0, 255).astype(np.uint8) # to [0, 255]
save_video(video_audio, save_path_tmp, fps=fps, quality=quality)
# crop audio according to video length
_, T, _, _ = gen_video_samples.shape
duration = T / fps
save_path_crop_audio = save_path + "-cropaudio.wav"
final_command = [
"ffmpeg",
"-i",
vocal_audio_list[0],
"-t",
f'{duration}',
save_path_crop_audio,
]
subprocess.run(final_command, check=True)
save_path = save_path + ".mp4"
if high_quality_save:
final_command = [
"ffmpeg",
"-y",
"-i", save_path_tmp,
"-i", save_path_crop_audio,
"-c:v", "libx264",
"-crf", "0",
"-preset", "veryslow",
"-c:a", "aac",
"-shortest",
save_path,
]
subprocess.run(final_command, check=True)
os.remove(save_path_tmp)
os.remove(save_path_crop_audio)
else:
final_command = [
"ffmpeg",
"-y",
"-i",
save_path_tmp,
"-i",
save_path_crop_audio,
"-c:v",
"libx264",
"-c:a",
"aac",
"-shortest",
save_path,
]
subprocess.run(final_command, check=True)
os.remove(save_path_tmp)
os.remove(save_path_crop_audio)
class MomentumBuffer:
def __init__(self, momentum: float):
self.momentum = momentum
self.running_average = 0
def update(self, update_value: torch.Tensor):
new_average = self.momentum * self.running_average
self.running_average = update_value + new_average
def project(
v0: torch.Tensor, # [B, C, T, H, W]
v1: torch.Tensor, # [B, C, T, H, W]
):
dtype = v0.dtype
v0, v1 = v0.double(), v1.double()
v1 = torch.nn.functional.normalize(v1, dim=[-1, -2, -3, -4])
v0_parallel = (v0 * v1).sum(dim=[-1, -2, -3, -4], keepdim=True) * v1
v0_orthogonal = v0 - v0_parallel
return v0_parallel.to(dtype), v0_orthogonal.to(dtype)
def adaptive_projected_guidance(
diff: torch.Tensor, # [B, C, T, H, W]
pred_cond: torch.Tensor, # [B, C, T, H, W]
momentum_buffer: MomentumBuffer = None,
eta: float = 0.0,
norm_threshold: float = 55,
):
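# APG-style update: optionally momentum-average the guidance direction, clip its norm to
# `norm_threshold`, then keep the component orthogonal to the conditional prediction plus
# `eta` times the parallel component.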
if momentum_buffer is not None:
momentum_buffer.update(diff)
diff = momentum_buffer.running_average
if norm_threshold > 0:
ones = torch.ones_like(diff)
diff_norm = diff.norm(p=2, dim=[-1, -2, -3, -4], keepdim=True)
print(f"diff_norm: {diff_norm}")
scale_factor = torch.minimum(ones, norm_threshold / diff_norm)
diff = diff * scale_factor
diff_parallel, diff_orthogonal = project(diff, pred_cond)
normalized_update = diff_orthogonal + eta * diff_parallel
return normalized_update

View File

@ -0,0 +1,20 @@
import torch
import torch.nn.functional as F
def get_mask_from_lengths(lengths, max_len=None):
lengths = lengths.to(torch.long)
if max_len is None:
max_len = torch.max(lengths).item()
ids = torch.arange(0, max_len).unsqueeze(0).expand(lengths.shape[0], -1).to(lengths.device)
mask = ids < lengths.unsqueeze(1).expand(-1, max_len)
return mask
def linear_interpolation(features, seq_len):
features = features.transpose(1, 2)
output_features = F.interpolate(features, size=seq_len, align_corners=True, mode='linear')
return output_features.transpose(1, 2)

125
wan/multitalk/wav2vec2.py Normal file
View File

@ -0,0 +1,125 @@
from transformers import Wav2Vec2Config, Wav2Vec2Model
from transformers.modeling_outputs import BaseModelOutput
from .torch_utils import linear_interpolation
# the implementation of Wav2Vec2Model is borrowed from
# https://github.com/huggingface/transformers/blob/HEAD/src/transformers/models/wav2vec2/modeling_wav2vec2.py
# initialize our encoder with the pre-trained wav2vec 2.0 weights.
class Wav2Vec2Model(Wav2Vec2Model):
def __init__(self, config: Wav2Vec2Config):
super().__init__(config)
def forward(
self,
input_values,
seq_len,
attention_mask=None,
mask_time_indices=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
):
self.config.output_attentions = True
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
extract_features = self.feature_extractor(input_values)
extract_features = extract_features.transpose(1, 2)
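# Resample the CNN feature sequence to exactly `seq_len` frames so the audio features
# align one-to-one with video frames.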
extract_features = linear_interpolation(extract_features, seq_len=seq_len)
if attention_mask is not None:
# compute reduced attention_mask corresponding to feature vectors
attention_mask = self._get_feature_vector_attention_mask(
extract_features.shape[1], attention_mask, add_adapter=False
)
hidden_states, extract_features = self.feature_projection(extract_features)
hidden_states = self._mask_hidden_states(
hidden_states, mask_time_indices=mask_time_indices, attention_mask=attention_mask
)
encoder_outputs = self.encoder(
hidden_states,
attention_mask=attention_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
hidden_states = encoder_outputs[0]
if self.adapter is not None:
hidden_states = self.adapter(hidden_states)
if not return_dict:
return (hidden_states, ) + encoder_outputs[1:]
return BaseModelOutput(
last_hidden_state=hidden_states,
hidden_states=encoder_outputs.hidden_states,
attentions=encoder_outputs.attentions,
)
def feature_extract(
self,
input_values,
seq_len,
):
extract_features = self.feature_extractor(input_values)
extract_features = extract_features.transpose(1, 2)
extract_features = linear_interpolation(extract_features, seq_len=seq_len)
return extract_features
def encode(
self,
extract_features,
attention_mask=None,
mask_time_indices=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
):
self.config.output_attentions = True
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if attention_mask is not None:
# compute reduced attention_mask corresponding to feature vectors
attention_mask = self._get_feature_vector_attention_mask(
extract_features.shape[1], attention_mask, add_adapter=False
)
hidden_states, extract_features = self.feature_projection(extract_features)
hidden_states = self._mask_hidden_states(
hidden_states, mask_time_indices=mask_time_indices, attention_mask=attention_mask
)
encoder_outputs = self.encoder(
hidden_states,
attention_mask=attention_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
hidden_states = encoder_outputs[0]
if self.adapter is not None:
hidden_states = self.adapter(hidden_states)
if not return_dict:
return (hidden_states, ) + encoder_outputs[1:]
return BaseModelOutput(
last_hidden_state=hidden_states,
hidden_states=encoder_outputs.hidden_states,
attentions=encoder_outputs.attentions,
)