From 3a8bd05c6e57a67e0abdeb277872069a1949a034 Mon Sep 17 00:00:00 2001
From: DeepBeepMeep
Date: Tue, 8 Jul 2025 18:44:55 +0200
Subject: [PATCH] multitalk files

---
 wan/multitalk/attention.py | 382 +++++++++++++
 wan/multitalk/kokoro/__init__.py | 23 +
 wan/multitalk/kokoro/__main__.py | 148 ++++++
 wan/multitalk/kokoro/custom_stft.py | 197 +++++++
 wan/multitalk/kokoro/istftnet.py | 421 +++++++++++++++
 wan/multitalk/kokoro/model.py | 155 ++++++
 wan/multitalk/kokoro/modules.py | 183 +++++++
 wan/multitalk/kokoro/pipeline.py | 445 ++++++++++++++++
 wan/multitalk/multitalk.py | 319 +++++++++++
 wan/multitalk/multitalk_model.py | 799 ++++++++++++++++++++++++++++
 wan/multitalk/multitalk_utils.py | 353 ++++++++++++
 wan/multitalk/torch_utils.py | 20 +
 wan/multitalk/wav2vec2.py | 125 +++++
 13 files changed, 3570 insertions(+)
 create mode 100644 wan/multitalk/attention.py
 create mode 100644 wan/multitalk/kokoro/__init__.py
 create mode 100644 wan/multitalk/kokoro/__main__.py
 create mode 100644 wan/multitalk/kokoro/custom_stft.py
 create mode 100644 wan/multitalk/kokoro/istftnet.py
 create mode 100644 wan/multitalk/kokoro/model.py
 create mode 100644 wan/multitalk/kokoro/modules.py
 create mode 100644 wan/multitalk/kokoro/pipeline.py
 create mode 100644 wan/multitalk/multitalk.py
 create mode 100644 wan/multitalk/multitalk_model.py
 create mode 100644 wan/multitalk/multitalk_utils.py
 create mode 100644 wan/multitalk/torch_utils.py
 create mode 100644 wan/multitalk/wav2vec2.py

diff --git a/wan/multitalk/attention.py b/wan/multitalk/attention.py
new file mode 100644
index 0000000..ffc2a50
--- /dev/null
+++ b/wan/multitalk/attention.py
@@ -0,0 +1,382 @@
+# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
+import torch
+import torch.nn as nn
+from einops import rearrange, repeat
+from .multitalk_utils import RotaryPositionalEmbedding1D, normalize_and_scale, split_token_counts_and_frame_ids
+from wan.modules.attention import pay_attention
+
+import xformers.ops
+
+try:
+    import flash_attn_interface
+    FLASH_ATTN_3_AVAILABLE = True
+except ModuleNotFoundError:
+    FLASH_ATTN_3_AVAILABLE = False
+
+try:
+    import flash_attn
+    FLASH_ATTN_2_AVAILABLE = True
+except ModuleNotFoundError:
+    FLASH_ATTN_2_AVAILABLE = False
+
+import warnings
+
+__all__ = [
+    'flash_attention',
+    'attention',
+]
+
+
+def flash_attention(
+    q,
+    k,
+    v,
+    q_lens=None,
+    k_lens=None,
+    dropout_p=0.,
+    softmax_scale=None,
+    q_scale=None,
+    causal=False,
+    window_size=(-1, -1),
+    deterministic=False,
+    dtype=torch.bfloat16,
+    version=None,
+):
+    """
+    q: [B, Lq, Nq, C1].
+    k: [B, Lk, Nk, C1].
+    v: [B, Lk, Nk, C2]. Nq must be divisible by Nk.
+    q_lens: [B].
+    k_lens: [B].
+    dropout_p: float. Dropout probability.
+    softmax_scale: float. The scaling of QK^T before applying softmax.
+    causal: bool. Whether to apply causal attention mask.
+    window_size: (left right). If not (-1, -1), apply sliding window local attention.
+    deterministic: bool. If True, slightly slower and uses more memory.
+    dtype: torch.dtype. Apply when dtype of q/k/v is not float16/bfloat16.
+ """ + half_dtypes = (torch.float16, torch.bfloat16) + assert dtype in half_dtypes + assert q.device.type == 'cuda' and q.size(-1) <= 256 + + # params + b, lq, lk, out_dtype = q.size(0), q.size(1), k.size(1), q.dtype + + def half(x): + return x if x.dtype in half_dtypes else x.to(dtype) + + # preprocess query + if q_lens is None: + q = half(q.flatten(0, 1)) + q_lens = torch.tensor( + [lq] * b, dtype=torch.int32).to( + device=q.device, non_blocking=True) + else: + q = half(torch.cat([u[:v] for u, v in zip(q, q_lens)])) + + # preprocess key, value + if k_lens is None: + k = half(k.flatten(0, 1)) + v = half(v.flatten(0, 1)) + k_lens = torch.tensor( + [lk] * b, dtype=torch.int32).to( + device=k.device, non_blocking=True) + else: + k = half(torch.cat([u[:v] for u, v in zip(k, k_lens)])) + v = half(torch.cat([u[:v] for u, v in zip(v, k_lens)])) + + q = q.to(v.dtype) + k = k.to(v.dtype) + + if q_scale is not None: + q = q * q_scale + + if version is not None and version == 3 and not FLASH_ATTN_3_AVAILABLE: + warnings.warn( + 'Flash attention 3 is not available, use flash attention 2 instead.' + ) + + # apply attention + if (version is None or version == 3) and FLASH_ATTN_3_AVAILABLE: + # Note: dropout_p, window_size are not supported in FA3 now. + x = flash_attn_interface.flash_attn_varlen_func( + q=q, + k=k, + v=v, + cu_seqlens_q=torch.cat([q_lens.new_zeros([1]), q_lens]).cumsum( + 0, dtype=torch.int32).to(q.device, non_blocking=True), + cu_seqlens_k=torch.cat([k_lens.new_zeros([1]), k_lens]).cumsum( + 0, dtype=torch.int32).to(q.device, non_blocking=True), + seqused_q=None, + seqused_k=None, + max_seqlen_q=lq, + max_seqlen_k=lk, + softmax_scale=softmax_scale, + causal=causal, + deterministic=deterministic)[0].unflatten(0, (b, lq)) + else: + assert FLASH_ATTN_2_AVAILABLE + x = flash_attn.flash_attn_varlen_func( + q=q, + k=k, + v=v, + cu_seqlens_q=torch.cat([q_lens.new_zeros([1]), q_lens]).cumsum( + 0, dtype=torch.int32).to(q.device, non_blocking=True), + cu_seqlens_k=torch.cat([k_lens.new_zeros([1]), k_lens]).cumsum( + 0, dtype=torch.int32).to(q.device, non_blocking=True), + max_seqlen_q=lq, + max_seqlen_k=lk, + dropout_p=dropout_p, + softmax_scale=softmax_scale, + causal=causal, + window_size=window_size, + deterministic=deterministic).unflatten(0, (b, lq)) + + # output + return x.type(out_dtype) + + +def attention( + q, + k, + v, + q_lens=None, + k_lens=None, + dropout_p=0., + softmax_scale=None, + q_scale=None, + causal=False, + window_size=(-1, -1), + deterministic=False, + dtype=torch.bfloat16, + fa_version=None, +): + if FLASH_ATTN_2_AVAILABLE or FLASH_ATTN_3_AVAILABLE: + return flash_attention( + q=q, + k=k, + v=v, + q_lens=q_lens, + k_lens=k_lens, + dropout_p=dropout_p, + softmax_scale=softmax_scale, + q_scale=q_scale, + causal=causal, + window_size=window_size, + deterministic=deterministic, + dtype=dtype, + version=fa_version, + ) + else: + if q_lens is not None or k_lens is not None: + warnings.warn( + 'Padding mask is disabled when using scaled_dot_product_attention. It can have a significant impact on performance.' 
+ ) + attn_mask = None + + q = q.transpose(1, 2).to(dtype) + k = k.transpose(1, 2).to(dtype) + v = v.transpose(1, 2).to(dtype) + + out = torch.nn.functional.scaled_dot_product_attention( + q, k, v, attn_mask=attn_mask, is_causal=causal, dropout_p=dropout_p) + + out = out.transpose(1, 2).contiguous() + return out + + +class SingleStreamAttention(nn.Module): + def __init__( + self, + dim: int, + encoder_hidden_states_dim: int, + num_heads: int, + qkv_bias: bool, + qk_norm: bool, + norm_layer: nn.Module, + attn_drop: float = 0.0, + proj_drop: float = 0.0, + eps: float = 1e-6, + ) -> None: + super().__init__() + assert dim % num_heads == 0, "dim should be divisible by num_heads" + self.dim = dim + self.encoder_hidden_states_dim = encoder_hidden_states_dim + self.num_heads = num_heads + self.head_dim = dim // num_heads + self.scale = self.head_dim**-0.5 + self.qk_norm = qk_norm + + self.q_linear = nn.Linear(dim, dim, bias=qkv_bias) + + self.q_norm = norm_layer(self.head_dim, eps=eps) if qk_norm else nn.Identity() + self.k_norm = norm_layer(self.head_dim,eps=eps) if qk_norm else nn.Identity() + + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + self.kv_linear = nn.Linear(encoder_hidden_states_dim, dim * 2, bias=qkv_bias) + + self.add_q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() + self.add_k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() + + def forward(self, x: torch.Tensor, encoder_hidden_states: torch.Tensor, shape=None, enable_sp=False, kv_seq=None) -> torch.Tensor: + N_t, N_h, N_w = shape + + x = rearrange(x, "B (N_t S) C -> (B N_t) S C", N_t=N_t) + # get q for hidden_state + B, N, C = x.shape + q = self.q_linear(x) + q_shape = (B, N, self.num_heads, self.head_dim) + q = q.view(q_shape).permute((0, 2, 1, 3)) + + if self.qk_norm: + q = self.q_norm(q) + + # get kv from encoder_hidden_states + _, N_a, _ = encoder_hidden_states.shape + encoder_kv = self.kv_linear(encoder_hidden_states) + encoder_kv_shape = (B, N_a, 2, self.num_heads, self.head_dim) + encoder_kv = encoder_kv.view(encoder_kv_shape).permute((2, 0, 3, 1, 4)) + encoder_k, encoder_v = encoder_kv.unbind(0) + + if self.qk_norm: + encoder_k = self.add_k_norm(encoder_k) + + q = rearrange(q, "B H M K -> B M H K") + encoder_k = rearrange(encoder_k, "B H M K -> B M H K") + encoder_v = rearrange(encoder_v, "B H M K -> B M H K") + + attn_bias = None + # x = xformers.ops.memory_efficient_attention(q, encoder_k, encoder_v, attn_bias=attn_bias, op=None,) + qkv_list = [q, encoder_k, encoder_v] + q = encoder_k = encoder_v = None + x = pay_attention(qkv_list) + x = rearrange(x, "B M H K -> B H M K") + + # linear transform + x_output_shape = (B, N, C) + x = x.transpose(1, 2) + x = x.reshape(x_output_shape) + x = self.proj(x) + x = self.proj_drop(x) + + # reshape x to origin shape + x = rearrange(x, "(B N_t) S C -> B (N_t S) C", N_t=N_t) + + return x + +class SingleStreamMutiAttention(SingleStreamAttention): + def __init__( + self, + dim: int, + encoder_hidden_states_dim: int, + num_heads: int, + qkv_bias: bool, + qk_norm: bool, + norm_layer: nn.Module, + attn_drop: float = 0.0, + proj_drop: float = 0.0, + eps: float = 1e-6, + class_range: int = 24, + class_interval: int = 4, + ) -> None: + super().__init__( + dim=dim, + encoder_hidden_states_dim=encoder_hidden_states_dim, + num_heads=num_heads, + qkv_bias=qkv_bias, + qk_norm=qk_norm, + norm_layer=norm_layer, + attn_drop=attn_drop, + proj_drop=proj_drop, + eps=eps, + ) + self.class_interval = 
class_interval + self.class_range = class_range + self.rope_h1 = (0, self.class_interval) + self.rope_h2 = (self.class_range - self.class_interval, self.class_range) + self.rope_bak = int(self.class_range // 2) + + self.rope_1d = RotaryPositionalEmbedding1D(self.head_dim) + + def forward(self, + x: torch.Tensor, + encoder_hidden_states: torch.Tensor, + shape=None, + x_ref_attn_map=None, + ) -> torch.Tensor: + + encoder_hidden_states = encoder_hidden_states.squeeze(0) + if x_ref_attn_map == None: + return super().forward(x, encoder_hidden_states, shape) + + N_t, _, _ = shape + x = rearrange(x, "B (N_t S) C -> (B N_t) S C", N_t=N_t) + + # get q for hidden_state + B, N, C = x.shape + q = self.q_linear(x) + q_shape = (B, N, self.num_heads, self.head_dim) + q = q.view(q_shape).permute((0, 2, 1, 3)) + + if self.qk_norm: + q = self.q_norm(q) + + max_values = x_ref_attn_map.max(1).values[:, None, None] + min_values = x_ref_attn_map.min(1).values[:, None, None] + max_min_values = torch.cat([max_values, min_values], dim=2) + + human1_max_value, human1_min_value = max_min_values[0, :, 0].max(), max_min_values[0, :, 1].min() + human2_max_value, human2_min_value = max_min_values[1, :, 0].max(), max_min_values[1, :, 1].min() + + human1 = normalize_and_scale(x_ref_attn_map[0], (human1_min_value, human1_max_value), (self.rope_h1[0], self.rope_h1[1])) + human2 = normalize_and_scale(x_ref_attn_map[1], (human2_min_value, human2_max_value), (self.rope_h2[0], self.rope_h2[1])) + back = torch.full((x_ref_attn_map.size(1),), self.rope_bak, dtype=human1.dtype).to(human1.device) + max_indices = x_ref_attn_map.argmax(dim=0) + normalized_map = torch.stack([human1, human2, back], dim=1) + normalized_pos = normalized_map[range(x_ref_attn_map.size(1)), max_indices] # N + + q = rearrange(q, "(B N_t) H S C -> B H (N_t S) C", N_t=N_t) + q = self.rope_1d(q, normalized_pos) + q = rearrange(q, "B H (N_t S) C -> (B N_t) H S C", N_t=N_t) + + _, N_a, _ = encoder_hidden_states.shape + encoder_kv = self.kv_linear(encoder_hidden_states) + encoder_kv_shape = (B, N_a, 2, self.num_heads, self.head_dim) + encoder_kv = encoder_kv.view(encoder_kv_shape).permute((2, 0, 3, 1, 4)) + encoder_k, encoder_v = encoder_kv.unbind(0) + + if self.qk_norm: + encoder_k = self.add_k_norm(encoder_k) + + per_frame = torch.zeros(N_a, dtype=encoder_k.dtype).to(encoder_k.device) + per_frame[:per_frame.size(0)//2] = (self.rope_h1[0] + self.rope_h1[1]) / 2 + per_frame[per_frame.size(0)//2:] = (self.rope_h2[0] + self.rope_h2[1]) / 2 + encoder_pos = torch.concat([per_frame]*N_t, dim=0) + encoder_k = rearrange(encoder_k, "(B N_t) H S C -> B H (N_t S) C", N_t=N_t) + encoder_k = self.rope_1d(encoder_k, encoder_pos) + encoder_k = rearrange(encoder_k, "B H (N_t S) C -> (B N_t) H S C", N_t=N_t) + + q = rearrange(q, "B H M K -> B M H K") + encoder_k = rearrange(encoder_k, "B H M K -> B M H K") + encoder_v = rearrange(encoder_v, "B H M K -> B M H K") + # x = xformers.ops.memory_efficient_attention(q, encoder_k, encoder_v, attn_bias=None, op=None,) + qkv_list = [q, encoder_k, encoder_v] + q = encoder_k = encoder_v = None + x = pay_attention(qkv_list) + + x = rearrange(x, "B M H K -> B H M K") + + # linear transform + x_output_shape = (B, N, C) + x = x.transpose(1, 2) + x = x.reshape(x_output_shape) + x = self.proj(x) + x = self.proj_drop(x) + + # reshape x to origin shape + x = rearrange(x, "(B N_t) S C -> B (N_t S) C", N_t=N_t) + + return x \ No newline at end of file diff --git a/wan/multitalk/kokoro/__init__.py b/wan/multitalk/kokoro/__init__.py new file mode 100644 
index 0000000..9156e5c --- /dev/null +++ b/wan/multitalk/kokoro/__init__.py @@ -0,0 +1,23 @@ +__version__ = '0.9.4' + +from loguru import logger +import sys + +# Remove default handler +logger.remove() + +# Add custom handler with clean format including module and line number +logger.add( + sys.stderr, + format="{time:HH:mm:ss} | {module:>16}:{line} | {level: >8} | {message}", + colorize=True, + level="INFO" # "DEBUG" to enable logger.debug("message") and up prints + # "ERROR" to enable only logger.error("message") prints + # etc +) + +# Disable before release or as needed +logger.disable("kokoro") + +from .model import KModel +from .pipeline import KPipeline diff --git a/wan/multitalk/kokoro/__main__.py b/wan/multitalk/kokoro/__main__.py new file mode 100644 index 0000000..34ee21a --- /dev/null +++ b/wan/multitalk/kokoro/__main__.py @@ -0,0 +1,148 @@ +"""Kokoro TTS CLI +Example usage: +python3 -m kokoro --text "The sky above the port was the color of television, tuned to a dead channel." -o file.wav --debug + +echo "Bom dia mundo, como vão vocês" > text.txt +python3 -m kokoro -i text.txt -l p --voice pm_alex > audio.wav + +Common issues: +pip not installed: `uv pip install pip` +(Temporary workaround while https://github.com/explosion/spaCy/issues/13747 is not fixed) + +espeak not installed: `apt-get install espeak-ng` +""" + +import argparse +import wave +from pathlib import Path +from typing import Generator, TYPE_CHECKING + +import numpy as np +from loguru import logger + +languages = [ + "a", # American English + "b", # British English + "h", # Hindi + "e", # Spanish + "f", # French + "i", # Italian + "p", # Brazilian Portuguese + "j", # Japanese + "z", # Mandarin Chinese +] + +if TYPE_CHECKING: + from kokoro import KPipeline + + +def generate_audio( + text: str, kokoro_language: str, voice: str, speed=1 +) -> Generator["KPipeline.Result", None, None]: + from kokoro import KPipeline + + if not voice.startswith(kokoro_language): + logger.warning(f"Voice {voice} is not made for language {kokoro_language}") + pipeline = KPipeline(lang_code=kokoro_language) + yield from pipeline(text, voice=voice, speed=speed, split_pattern=r"\n+") + + +def generate_and_save_audio( + output_file: Path, text: str, kokoro_language: str, voice: str, speed=1 +) -> None: + with wave.open(str(output_file.resolve()), "wb") as wav_file: + wav_file.setnchannels(1) # Mono audio + wav_file.setsampwidth(2) # 2 bytes per sample (16-bit audio) + wav_file.setframerate(24000) # Sample rate + + for result in generate_audio( + text, kokoro_language=kokoro_language, voice=voice, speed=speed + ): + logger.debug(result.phonemes) + if result.audio is None: + continue + audio_bytes = (result.audio.numpy() * 32767).astype(np.int16).tobytes() + wav_file.writeframes(audio_bytes) + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument( + "-m", + "--voice", + default="af_heart", + help="Voice to use", + ) + parser.add_argument( + "-l", + "--language", + help="Language to use (defaults to the one corresponding to the voice)", + choices=languages, + ) + parser.add_argument( + "-o", + "--output-file", + "--output_file", + type=Path, + help="Path to output WAV file", + required=True, + ) + parser.add_argument( + "-i", + "--input-file", + "--input_file", + type=Path, + help="Path to input text file (default: stdin)", + ) + parser.add_argument( + "-t", + "--text", + help="Text to use instead of reading from stdin", + ) + parser.add_argument( + "-s", + "--speed", + type=float, + default=1.0, + help="Speech 
speed", + ) + parser.add_argument( + "--debug", + action="store_true", + help="Print DEBUG messages to console", + ) + args = parser.parse_args() + if args.debug: + logger.level("DEBUG") + logger.debug(args) + + lang = args.language or args.voice[0] + + if args.text is not None and args.input_file is not None: + raise Exception("You cannot specify both 'text' and 'input_file'") + elif args.text: + text = args.text + elif args.input_file: + file: Path = args.input_file + text = file.read_text() + else: + import sys + print("Press Ctrl+D to stop reading input and start generating", flush=True) + text = '\n'.join(sys.stdin) + + logger.debug(f"Input text: {text!r}") + + out_file: Path = args.output_file + if not out_file.suffix == ".wav": + logger.warning("The output file name should end with .wav") + generate_and_save_audio( + output_file=out_file, + text=text, + kokoro_language=lang, + voice=args.voice, + speed=args.speed, + ) + + +if __name__ == "__main__": + main() diff --git a/wan/multitalk/kokoro/custom_stft.py b/wan/multitalk/kokoro/custom_stft.py new file mode 100644 index 0000000..c9cf0d2 --- /dev/null +++ b/wan/multitalk/kokoro/custom_stft.py @@ -0,0 +1,197 @@ +from attr import attr +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + +class CustomSTFT(nn.Module): + """ + STFT/iSTFT without unfold/complex ops, using conv1d and conv_transpose1d. + + - forward STFT => Real-part conv1d + Imag-part conv1d + - inverse STFT => Real-part conv_transpose1d + Imag-part conv_transpose1d + sum + - avoids F.unfold, so easier to export to ONNX + - uses replicate or constant padding for 'center=True' to approximate 'reflect' + (reflect is not supported for dynamic shapes in ONNX) + """ + + def __init__( + self, + filter_length=800, + hop_length=200, + win_length=800, + window="hann", + center=True, + pad_mode="replicate", # or 'constant' + ): + super().__init__() + self.filter_length = filter_length + self.hop_length = hop_length + self.win_length = win_length + self.n_fft = filter_length + self.center = center + self.pad_mode = pad_mode + + # Number of frequency bins for real-valued STFT with onesided=True + self.freq_bins = self.n_fft // 2 + 1 + + # Build window + assert window == 'hann', window + window_tensor = torch.hann_window(win_length, periodic=True, dtype=torch.float32) + if self.win_length < self.n_fft: + # Zero-pad up to n_fft + extra = self.n_fft - self.win_length + window_tensor = F.pad(window_tensor, (0, extra)) + elif self.win_length > self.n_fft: + window_tensor = window_tensor[: self.n_fft] + self.register_buffer("window", window_tensor) + + # Precompute forward DFT (real, imag) + # PyTorch stft uses e^{-j 2 pi k n / N} => real=cos(...), imag=-sin(...) + n = np.arange(self.n_fft) + k = np.arange(self.freq_bins) + angle = 2 * np.pi * np.outer(k, n) / self.n_fft # shape (freq_bins, n_fft) + dft_real = np.cos(angle) + dft_imag = -np.sin(angle) # note negative sign + + # Combine window and dft => shape (freq_bins, filter_length) + # We'll make 2 conv weight tensors of shape (freq_bins, 1, filter_length). 
+ forward_window = window_tensor.numpy() # shape (n_fft,) + forward_real = dft_real * forward_window # (freq_bins, n_fft) + forward_imag = dft_imag * forward_window + + # Convert to PyTorch + forward_real_torch = torch.from_numpy(forward_real).float() + forward_imag_torch = torch.from_numpy(forward_imag).float() + + # Register as Conv1d weight => (out_channels, in_channels, kernel_size) + # out_channels = freq_bins, in_channels=1, kernel_size=n_fft + self.register_buffer( + "weight_forward_real", forward_real_torch.unsqueeze(1) + ) + self.register_buffer( + "weight_forward_imag", forward_imag_torch.unsqueeze(1) + ) + + # Precompute inverse DFT + # Real iFFT formula => scale = 1/n_fft, doubling for bins 1..freq_bins-2 if n_fft even, etc. + # For simplicity, we won't do the "DC/nyquist not doubled" approach here. + # If you want perfect real iSTFT, you can add that logic. + # This version just yields good approximate reconstruction with Hann + typical overlap. + inv_scale = 1.0 / self.n_fft + n = np.arange(self.n_fft) + angle_t = 2 * np.pi * np.outer(n, k) / self.n_fft # shape (n_fft, freq_bins) + idft_cos = np.cos(angle_t).T # => (freq_bins, n_fft) + idft_sin = np.sin(angle_t).T # => (freq_bins, n_fft) + + # Multiply by window again for typical overlap-add + # We also incorporate the scale factor 1/n_fft + inv_window = window_tensor.numpy() * inv_scale + backward_real = idft_cos * inv_window # (freq_bins, n_fft) + backward_imag = idft_sin * inv_window + + # We'll implement iSTFT as real+imag conv_transpose with stride=hop. + self.register_buffer( + "weight_backward_real", torch.from_numpy(backward_real).float().unsqueeze(1) + ) + self.register_buffer( + "weight_backward_imag", torch.from_numpy(backward_imag).float().unsqueeze(1) + ) + + + + def transform(self, waveform: torch.Tensor): + """ + Forward STFT => returns magnitude, phase + Output shape => (batch, freq_bins, frames) + """ + # waveform shape => (B, T). conv1d expects (B, 1, T). + # Optional center pad + if self.center: + pad_len = self.n_fft // 2 + waveform = F.pad(waveform, (pad_len, pad_len), mode=self.pad_mode) + + x = waveform.unsqueeze(1) # => (B, 1, T) + # Convolution to get real part => shape (B, freq_bins, frames) + real_out = F.conv1d( + x, + self.weight_forward_real, + bias=None, + stride=self.hop_length, + padding=0, + ) + # Imag part + imag_out = F.conv1d( + x, + self.weight_forward_imag, + bias=None, + stride=self.hop_length, + padding=0, + ) + + # magnitude, phase + magnitude = torch.sqrt(real_out**2 + imag_out**2 + 1e-14) + phase = torch.atan2(imag_out, real_out) + # Handle the case where imag_out is 0 and real_out is negative to correct ONNX atan2 to match PyTorch + # In this case, PyTorch returns pi, ONNX returns -pi + correction_mask = (imag_out == 0) & (real_out < 0) + phase[correction_mask] = torch.pi + return magnitude, phase + + + def inverse(self, magnitude: torch.Tensor, phase: torch.Tensor, length=None): + """ + Inverse STFT => returns waveform shape (B, T). + """ + # magnitude, phase => (B, freq_bins, frames) + # Re-create real/imag => shape (B, freq_bins, frames) + real_part = magnitude * torch.cos(phase) + imag_part = magnitude * torch.sin(phase) + + # conv_transpose wants shape (B, freq_bins, frames). 
We'll treat "frames" as time dimension + # so we do (B, freq_bins, frames) => (B, freq_bins, frames) + # But PyTorch conv_transpose1d expects (B, in_channels, input_length) + real_part = real_part # (B, freq_bins, frames) + imag_part = imag_part + + # real iSTFT => convolve with "backward_real", "backward_imag", and sum + # We'll do 2 conv_transpose calls, each giving (B, 1, time), + # then add them => (B, 1, time). + real_rec = F.conv_transpose1d( + real_part, + self.weight_backward_real, # shape (freq_bins, 1, filter_length) + bias=None, + stride=self.hop_length, + padding=0, + ) + imag_rec = F.conv_transpose1d( + imag_part, + self.weight_backward_imag, + bias=None, + stride=self.hop_length, + padding=0, + ) + # sum => (B, 1, time) + waveform = real_rec - imag_rec # typical real iFFT has minus for imaginary part + + # If we used "center=True" in forward, we should remove pad + if self.center: + pad_len = self.n_fft // 2 + # Because of transposed convolution, total length might have extra samples + # We remove `pad_len` from start & end if possible + waveform = waveform[..., pad_len:-pad_len] + + # If a specific length is desired, clamp + if length is not None: + waveform = waveform[..., :length] + + # shape => (B, T) + return waveform + + def forward(self, x: torch.Tensor): + """ + Full STFT -> iSTFT pass: returns time-domain reconstruction. + Same interface as your original code. + """ + mag, phase = self.transform(x) + return self.inverse(mag, phase, length=x.shape[-1]) diff --git a/wan/multitalk/kokoro/istftnet.py b/wan/multitalk/kokoro/istftnet.py new file mode 100644 index 0000000..1c874fc --- /dev/null +++ b/wan/multitalk/kokoro/istftnet.py @@ -0,0 +1,421 @@ +# ADAPTED from https://github.com/yl4579/StyleTTS2/blob/main/Modules/istftnet.py +from .custom_stft import CustomSTFT +from torch.nn.utils import weight_norm +import math +import torch +import torch.nn as nn +import torch.nn.functional as F + + +# https://github.com/yl4579/StyleTTS2/blob/main/Modules/utils.py +def init_weights(m, mean=0.0, std=0.01): + classname = m.__class__.__name__ + if classname.find("Conv") != -1: + m.weight.data.normal_(mean, std) + +def get_padding(kernel_size, dilation=1): + return int((kernel_size*dilation - dilation)/2) + + +class AdaIN1d(nn.Module): + def __init__(self, style_dim, num_features): + super().__init__() + # affine should be False, however there's a bug in the old torch.onnx.export (not newer dynamo) that causes the channel dimension to be lost if affine=False. When affine is true, there's additional learnably parameters. 
This shouldn't really matter setting it to True, since we're in inference mode + self.norm = nn.InstanceNorm1d(num_features, affine=True) + self.fc = nn.Linear(style_dim, num_features*2) + + def forward(self, x, s): + h = self.fc(s) + h = h.view(h.size(0), h.size(1), 1) + gamma, beta = torch.chunk(h, chunks=2, dim=1) + return (1 + gamma) * self.norm(x) + beta + + +class AdaINResBlock1(nn.Module): + def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5), style_dim=64): + super(AdaINResBlock1, self).__init__() + self.convs1 = nn.ModuleList([ + weight_norm(nn.Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0], + padding=get_padding(kernel_size, dilation[0]))), + weight_norm(nn.Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1], + padding=get_padding(kernel_size, dilation[1]))), + weight_norm(nn.Conv1d(channels, channels, kernel_size, 1, dilation=dilation[2], + padding=get_padding(kernel_size, dilation[2]))) + ]) + self.convs1.apply(init_weights) + self.convs2 = nn.ModuleList([ + weight_norm(nn.Conv1d(channels, channels, kernel_size, 1, dilation=1, + padding=get_padding(kernel_size, 1))), + weight_norm(nn.Conv1d(channels, channels, kernel_size, 1, dilation=1, + padding=get_padding(kernel_size, 1))), + weight_norm(nn.Conv1d(channels, channels, kernel_size, 1, dilation=1, + padding=get_padding(kernel_size, 1))) + ]) + self.convs2.apply(init_weights) + self.adain1 = nn.ModuleList([ + AdaIN1d(style_dim, channels), + AdaIN1d(style_dim, channels), + AdaIN1d(style_dim, channels), + ]) + self.adain2 = nn.ModuleList([ + AdaIN1d(style_dim, channels), + AdaIN1d(style_dim, channels), + AdaIN1d(style_dim, channels), + ]) + self.alpha1 = nn.ParameterList([nn.Parameter(torch.ones(1, channels, 1)) for i in range(len(self.convs1))]) + self.alpha2 = nn.ParameterList([nn.Parameter(torch.ones(1, channels, 1)) for i in range(len(self.convs2))]) + + def forward(self, x, s): + for c1, c2, n1, n2, a1, a2 in zip(self.convs1, self.convs2, self.adain1, self.adain2, self.alpha1, self.alpha2): + xt = n1(x, s) + xt = xt + (1 / a1) * (torch.sin(a1 * xt) ** 2) # Snake1D + xt = c1(xt) + xt = n2(xt, s) + xt = xt + (1 / a2) * (torch.sin(a2 * xt) ** 2) # Snake1D + xt = c2(xt) + x = xt + x + return x + + +class TorchSTFT(nn.Module): + def __init__(self, filter_length=800, hop_length=200, win_length=800, window='hann'): + super().__init__() + self.filter_length = filter_length + self.hop_length = hop_length + self.win_length = win_length + assert window == 'hann', window + self.window = torch.hann_window(win_length, periodic=True, dtype=torch.float32) + + def transform(self, input_data): + forward_transform = torch.stft( + input_data, + self.filter_length, self.hop_length, self.win_length, window=self.window.to(input_data.device), + return_complex=True) + return torch.abs(forward_transform), torch.angle(forward_transform) + + def inverse(self, magnitude, phase): + inverse_transform = torch.istft( + magnitude * torch.exp(phase * 1j), + self.filter_length, self.hop_length, self.win_length, window=self.window.to(magnitude.device)) + return inverse_transform.unsqueeze(-2) # unsqueeze to stay consistent with conv_transpose1d implementation + + def forward(self, input_data): + self.magnitude, self.phase = self.transform(input_data) + reconstruction = self.inverse(self.magnitude, self.phase) + return reconstruction + + +class SineGen(nn.Module): + """ Definition of sine generator + SineGen(samp_rate, harmonic_num = 0, + sine_amp = 0.1, noise_std = 0.003, + voiced_threshold = 0, + flag_for_pulse=False) + 
samp_rate: sampling rate in Hz + harmonic_num: number of harmonic overtones (default 0) + sine_amp: amplitude of sine-wavefrom (default 0.1) + noise_std: std of Gaussian noise (default 0.003) + voiced_thoreshold: F0 threshold for U/V classification (default 0) + flag_for_pulse: this SinGen is used inside PulseGen (default False) + Note: when flag_for_pulse is True, the first time step of a voiced + segment is always sin(torch.pi) or cos(0) + """ + def __init__(self, samp_rate, upsample_scale, harmonic_num=0, + sine_amp=0.1, noise_std=0.003, + voiced_threshold=0, + flag_for_pulse=False): + super(SineGen, self).__init__() + self.sine_amp = sine_amp + self.noise_std = noise_std + self.harmonic_num = harmonic_num + self.dim = self.harmonic_num + 1 + self.sampling_rate = samp_rate + self.voiced_threshold = voiced_threshold + self.flag_for_pulse = flag_for_pulse + self.upsample_scale = upsample_scale + + def _f02uv(self, f0): + # generate uv signal + uv = (f0 > self.voiced_threshold).type(torch.float32) + return uv + + def _f02sine(self, f0_values): + """ f0_values: (batchsize, length, dim) + where dim indicates fundamental tone and overtones + """ + # convert to F0 in rad. The interger part n can be ignored + # because 2 * torch.pi * n doesn't affect phase + rad_values = (f0_values / self.sampling_rate) % 1 + # initial phase noise (no noise for fundamental component) + rand_ini = torch.rand(f0_values.shape[0], f0_values.shape[2], device=f0_values.device) + rand_ini[:, 0] = 0 + rad_values[:, 0, :] = rad_values[:, 0, :] + rand_ini + # instantanouse phase sine[t] = sin(2*pi \sum_i=1 ^{t} rad) + if not self.flag_for_pulse: + rad_values = F.interpolate(rad_values.transpose(1, 2), scale_factor=1/self.upsample_scale, mode="linear").transpose(1, 2) + phase = torch.cumsum(rad_values, dim=1) * 2 * torch.pi + phase = F.interpolate(phase.transpose(1, 2) * self.upsample_scale, scale_factor=self.upsample_scale, mode="linear").transpose(1, 2) + sines = torch.sin(phase) + else: + # If necessary, make sure that the first time step of every + # voiced segments is sin(pi) or cos(0) + # This is used for pulse-train generation + # identify the last time step in unvoiced segments + uv = self._f02uv(f0_values) + uv_1 = torch.roll(uv, shifts=-1, dims=1) + uv_1[:, -1, :] = 1 + u_loc = (uv < 1) * (uv_1 > 0) + # get the instantanouse phase + tmp_cumsum = torch.cumsum(rad_values, dim=1) + # different batch needs to be processed differently + for idx in range(f0_values.shape[0]): + temp_sum = tmp_cumsum[idx, u_loc[idx, :, 0], :] + temp_sum[1:, :] = temp_sum[1:, :] - temp_sum[0:-1, :] + # stores the accumulation of i.phase within + # each voiced segments + tmp_cumsum[idx, :, :] = 0 + tmp_cumsum[idx, u_loc[idx, :, 0], :] = temp_sum + # rad_values - tmp_cumsum: remove the accumulation of i.phase + # within the previous voiced segment. 
+ i_phase = torch.cumsum(rad_values - tmp_cumsum, dim=1) + # get the sines + sines = torch.cos(i_phase * 2 * torch.pi) + return sines + + def forward(self, f0): + """ sine_tensor, uv = forward(f0) + input F0: tensor(batchsize=1, length, dim=1) + f0 for unvoiced steps should be 0 + output sine_tensor: tensor(batchsize=1, length, dim) + output uv: tensor(batchsize=1, length, 1) + """ + f0_buf = torch.zeros(f0.shape[0], f0.shape[1], self.dim, device=f0.device) + # fundamental component + fn = torch.multiply(f0, torch.FloatTensor([[range(1, self.harmonic_num + 2)]]).to(f0.device)) + # generate sine waveforms + sine_waves = self._f02sine(fn) * self.sine_amp + # generate uv signal + # uv = torch.ones(f0.shape) + # uv = uv * (f0 > self.voiced_threshold) + uv = self._f02uv(f0) + # noise: for unvoiced should be similar to sine_amp + # std = self.sine_amp/3 -> max value ~ self.sine_amp + # for voiced regions is self.noise_std + noise_amp = uv * self.noise_std + (1 - uv) * self.sine_amp / 3 + noise = noise_amp * torch.randn_like(sine_waves) + # first: set the unvoiced part to 0 by uv + # then: additive noise + sine_waves = sine_waves * uv + noise + return sine_waves, uv, noise + + +class SourceModuleHnNSF(nn.Module): + """ SourceModule for hn-nsf + SourceModule(sampling_rate, harmonic_num=0, sine_amp=0.1, + add_noise_std=0.003, voiced_threshod=0) + sampling_rate: sampling_rate in Hz + harmonic_num: number of harmonic above F0 (default: 0) + sine_amp: amplitude of sine source signal (default: 0.1) + add_noise_std: std of additive Gaussian noise (default: 0.003) + note that amplitude of noise in unvoiced is decided + by sine_amp + voiced_threshold: threhold to set U/V given F0 (default: 0) + Sine_source, noise_source = SourceModuleHnNSF(F0_sampled) + F0_sampled (batchsize, length, 1) + Sine_source (batchsize, length, 1) + noise_source (batchsize, length 1) + uv (batchsize, length, 1) + """ + def __init__(self, sampling_rate, upsample_scale, harmonic_num=0, sine_amp=0.1, + add_noise_std=0.003, voiced_threshod=0): + super(SourceModuleHnNSF, self).__init__() + self.sine_amp = sine_amp + self.noise_std = add_noise_std + # to produce sine waveforms + self.l_sin_gen = SineGen(sampling_rate, upsample_scale, harmonic_num, + sine_amp, add_noise_std, voiced_threshod) + # to merge source harmonics into a single excitation + self.l_linear = nn.Linear(harmonic_num + 1, 1) + self.l_tanh = nn.Tanh() + + def forward(self, x): + """ + Sine_source, noise_source = SourceModuleHnNSF(F0_sampled) + F0_sampled (batchsize, length, 1) + Sine_source (batchsize, length, 1) + noise_source (batchsize, length 1) + """ + # source for harmonic branch + with torch.no_grad(): + sine_wavs, uv, _ = self.l_sin_gen(x) + sine_merge = self.l_tanh(self.l_linear(sine_wavs)) + # source for noise branch, in the same shape as uv + noise = torch.randn_like(uv) * self.sine_amp / 3 + return sine_merge, noise, uv + + +class Generator(nn.Module): + def __init__(self, style_dim, resblock_kernel_sizes, upsample_rates, upsample_initial_channel, resblock_dilation_sizes, upsample_kernel_sizes, gen_istft_n_fft, gen_istft_hop_size, disable_complex=False): + super(Generator, self).__init__() + self.num_kernels = len(resblock_kernel_sizes) + self.num_upsamples = len(upsample_rates) + self.m_source = SourceModuleHnNSF( + sampling_rate=24000, + upsample_scale=math.prod(upsample_rates) * gen_istft_hop_size, + harmonic_num=8, voiced_threshod=10) + self.f0_upsamp = nn.Upsample(scale_factor=math.prod(upsample_rates) * gen_istft_hop_size) + self.noise_convs = 
nn.ModuleList() + self.noise_res = nn.ModuleList() + self.ups = nn.ModuleList() + for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)): + self.ups.append(weight_norm( + nn.ConvTranspose1d(upsample_initial_channel//(2**i), upsample_initial_channel//(2**(i+1)), + k, u, padding=(k-u)//2))) + self.resblocks = nn.ModuleList() + for i in range(len(self.ups)): + ch = upsample_initial_channel//(2**(i+1)) + for j, (k, d) in enumerate(zip(resblock_kernel_sizes,resblock_dilation_sizes)): + self.resblocks.append(AdaINResBlock1(ch, k, d, style_dim)) + c_cur = upsample_initial_channel // (2 ** (i + 1)) + if i + 1 < len(upsample_rates): + stride_f0 = math.prod(upsample_rates[i + 1:]) + self.noise_convs.append(nn.Conv1d( + gen_istft_n_fft + 2, c_cur, kernel_size=stride_f0 * 2, stride=stride_f0, padding=(stride_f0+1) // 2)) + self.noise_res.append(AdaINResBlock1(c_cur, 7, [1,3,5], style_dim)) + else: + self.noise_convs.append(nn.Conv1d(gen_istft_n_fft + 2, c_cur, kernel_size=1)) + self.noise_res.append(AdaINResBlock1(c_cur, 11, [1,3,5], style_dim)) + self.post_n_fft = gen_istft_n_fft + self.conv_post = weight_norm(nn.Conv1d(ch, self.post_n_fft + 2, 7, 1, padding=3)) + self.ups.apply(init_weights) + self.conv_post.apply(init_weights) + self.reflection_pad = nn.ReflectionPad1d((1, 0)) + self.stft = ( + CustomSTFT(filter_length=gen_istft_n_fft, hop_length=gen_istft_hop_size, win_length=gen_istft_n_fft) + if disable_complex + else TorchSTFT(filter_length=gen_istft_n_fft, hop_length=gen_istft_hop_size, win_length=gen_istft_n_fft) + ) + + def forward(self, x, s, f0): + with torch.no_grad(): + f0 = self.f0_upsamp(f0[:, None]).transpose(1, 2) # bs,n,t + har_source, noi_source, uv = self.m_source(f0) + har_source = har_source.transpose(1, 2).squeeze(1) + har_spec, har_phase = self.stft.transform(har_source) + har = torch.cat([har_spec, har_phase], dim=1) + for i in range(self.num_upsamples): + x = F.leaky_relu(x, negative_slope=0.1) + x_source = self.noise_convs[i](har) + x_source = self.noise_res[i](x_source, s) + x = self.ups[i](x) + if i == self.num_upsamples - 1: + x = self.reflection_pad(x) + x = x + x_source + xs = None + for j in range(self.num_kernels): + if xs is None: + xs = self.resblocks[i*self.num_kernels+j](x, s) + else: + xs += self.resblocks[i*self.num_kernels+j](x, s) + x = xs / self.num_kernels + x = F.leaky_relu(x) + x = self.conv_post(x) + spec = torch.exp(x[:,:self.post_n_fft // 2 + 1, :]) + phase = torch.sin(x[:, self.post_n_fft // 2 + 1:, :]) + return self.stft.inverse(spec, phase) + + +class UpSample1d(nn.Module): + def __init__(self, layer_type): + super().__init__() + self.layer_type = layer_type + + def forward(self, x): + if self.layer_type == 'none': + return x + else: + return F.interpolate(x, scale_factor=2, mode='nearest') + + +class AdainResBlk1d(nn.Module): + def __init__(self, dim_in, dim_out, style_dim=64, actv=nn.LeakyReLU(0.2), upsample='none', dropout_p=0.0): + super().__init__() + self.actv = actv + self.upsample_type = upsample + self.upsample = UpSample1d(upsample) + self.learned_sc = dim_in != dim_out + self._build_weights(dim_in, dim_out, style_dim) + self.dropout = nn.Dropout(dropout_p) + if upsample == 'none': + self.pool = nn.Identity() + else: + self.pool = weight_norm(nn.ConvTranspose1d(dim_in, dim_in, kernel_size=3, stride=2, groups=dim_in, padding=1, output_padding=1)) + + def _build_weights(self, dim_in, dim_out, style_dim): + self.conv1 = weight_norm(nn.Conv1d(dim_in, dim_out, 3, 1, 1)) + self.conv2 = weight_norm(nn.Conv1d(dim_out, dim_out, 3, 
1, 1)) + self.norm1 = AdaIN1d(style_dim, dim_in) + self.norm2 = AdaIN1d(style_dim, dim_out) + if self.learned_sc: + self.conv1x1 = weight_norm(nn.Conv1d(dim_in, dim_out, 1, 1, 0, bias=False)) + + def _shortcut(self, x): + x = self.upsample(x) + if self.learned_sc: + x = self.conv1x1(x) + return x + + def _residual(self, x, s): + x = self.norm1(x, s) + x = self.actv(x) + x = self.pool(x) + x = self.conv1(self.dropout(x)) + x = self.norm2(x, s) + x = self.actv(x) + x = self.conv2(self.dropout(x)) + return x + + def forward(self, x, s): + out = self._residual(x, s) + out = (out + self._shortcut(x)) * torch.rsqrt(torch.tensor(2)) + return out + + +class Decoder(nn.Module): + def __init__(self, dim_in, style_dim, dim_out, + resblock_kernel_sizes, + upsample_rates, + upsample_initial_channel, + resblock_dilation_sizes, + upsample_kernel_sizes, + gen_istft_n_fft, gen_istft_hop_size, + disable_complex=False): + super().__init__() + self.encode = AdainResBlk1d(dim_in + 2, 1024, style_dim) + self.decode = nn.ModuleList() + self.decode.append(AdainResBlk1d(1024 + 2 + 64, 1024, style_dim)) + self.decode.append(AdainResBlk1d(1024 + 2 + 64, 1024, style_dim)) + self.decode.append(AdainResBlk1d(1024 + 2 + 64, 1024, style_dim)) + self.decode.append(AdainResBlk1d(1024 + 2 + 64, 512, style_dim, upsample=True)) + self.F0_conv = weight_norm(nn.Conv1d(1, 1, kernel_size=3, stride=2, groups=1, padding=1)) + self.N_conv = weight_norm(nn.Conv1d(1, 1, kernel_size=3, stride=2, groups=1, padding=1)) + self.asr_res = nn.Sequential(weight_norm(nn.Conv1d(512, 64, kernel_size=1))) + self.generator = Generator(style_dim, resblock_kernel_sizes, upsample_rates, + upsample_initial_channel, resblock_dilation_sizes, + upsample_kernel_sizes, gen_istft_n_fft, gen_istft_hop_size, disable_complex=disable_complex) + + def forward(self, asr, F0_curve, N, s): + F0 = self.F0_conv(F0_curve.unsqueeze(1)) + N = self.N_conv(N.unsqueeze(1)) + x = torch.cat([asr, F0, N], axis=1) + x = self.encode(x, s) + asr_res = self.asr_res(asr) + res = True + for block in self.decode: + if res: + x = torch.cat([x, asr_res, F0, N], axis=1) + x = block(x, s) + if block.upsample_type != "none": + res = False + x = self.generator(x, s, F0_curve) + return x diff --git a/wan/multitalk/kokoro/model.py b/wan/multitalk/kokoro/model.py new file mode 100644 index 0000000..9d6554c --- /dev/null +++ b/wan/multitalk/kokoro/model.py @@ -0,0 +1,155 @@ +from .istftnet import Decoder +from .modules import CustomAlbert, ProsodyPredictor, TextEncoder +from dataclasses import dataclass +from huggingface_hub import hf_hub_download +from loguru import logger +from transformers import AlbertConfig +from typing import Dict, Optional, Union +import json +import torch +import os + +class KModel(torch.nn.Module): + ''' + KModel is a torch.nn.Module with 2 main responsibilities: + 1. Init weights, downloading config.json + model.pth from HF if needed + 2. forward(phonemes: str, ref_s: FloatTensor) -> (audio: FloatTensor) + + You likely only need one KModel instance, and it can be reused across + multiple KPipelines to avoid redundant memory allocation. + + Unlike KPipeline, KModel is language-blind. + + KModel stores self.vocab and thus knows how to map phonemes -> input_ids, + so there is no need to repeatedly download config.json outside of KModel. 
+ ''' + + MODEL_NAMES = { + 'hexgrad/Kokoro-82M': 'kokoro-v1_0.pth', + 'hexgrad/Kokoro-82M-v1.1-zh': 'kokoro-v1_1-zh.pth', + } + + def __init__( + self, + repo_id: Optional[str] = None, + config: Union[Dict, str, None] = None, + model: Optional[str] = None, + disable_complex: bool = False + ): + super().__init__() + if repo_id is None: + repo_id = 'hexgrad/Kokoro-82M' + print(f"WARNING: Defaulting repo_id to {repo_id}. Pass repo_id='{repo_id}' to suppress this warning.") + self.repo_id = repo_id + if not isinstance(config, dict): + if not config: + logger.debug("No config provided, downloading from HF") + config = hf_hub_download(repo_id=repo_id, filename='config.json') + with open(config, 'r', encoding='utf-8') as r: + config = json.load(r) + logger.debug(f"Loaded config: {config}") + self.vocab = config['vocab'] + self.bert = CustomAlbert(AlbertConfig(vocab_size=config['n_token'], **config['plbert'])) + self.bert_encoder = torch.nn.Linear(self.bert.config.hidden_size, config['hidden_dim']) + self.context_length = self.bert.config.max_position_embeddings + self.predictor = ProsodyPredictor( + style_dim=config['style_dim'], d_hid=config['hidden_dim'], + nlayers=config['n_layer'], max_dur=config['max_dur'], dropout=config['dropout'] + ) + self.text_encoder = TextEncoder( + channels=config['hidden_dim'], kernel_size=config['text_encoder_kernel_size'], + depth=config['n_layer'], n_symbols=config['n_token'] + ) + self.decoder = Decoder( + dim_in=config['hidden_dim'], style_dim=config['style_dim'], + dim_out=config['n_mels'], disable_complex=disable_complex, **config['istftnet'] + ) + if not model: + try: + model = hf_hub_download(repo_id=repo_id, filename=KModel.MODEL_NAMES[repo_id]) + except: + model = os.path.join(repo_id, 'kokoro-v1_0.pth') + for key, state_dict in torch.load(model, map_location='cpu', weights_only=True).items(): + assert hasattr(self, key), key + try: + getattr(self, key).load_state_dict(state_dict) + except: + logger.debug(f"Did not load {key} from state_dict") + state_dict = {k[7:]: v for k, v in state_dict.items()} + getattr(self, key).load_state_dict(state_dict, strict=False) + + @property + def device(self): + return self.bert.device + + @dataclass + class Output: + audio: torch.FloatTensor + pred_dur: Optional[torch.LongTensor] = None + + @torch.no_grad() + def forward_with_tokens( + self, + input_ids: torch.LongTensor, + ref_s: torch.FloatTensor, + speed: float = 1 + ) -> tuple[torch.FloatTensor, torch.LongTensor]: + input_lengths = torch.full( + (input_ids.shape[0],), + input_ids.shape[-1], + device=input_ids.device, + dtype=torch.long + ) + + text_mask = torch.arange(input_lengths.max()).unsqueeze(0).expand(input_lengths.shape[0], -1).type_as(input_lengths) + text_mask = torch.gt(text_mask+1, input_lengths.unsqueeze(1)).to(self.device) + bert_dur = self.bert(input_ids, attention_mask=(~text_mask).int()) + d_en = self.bert_encoder(bert_dur).transpose(-1, -2) + s = ref_s[:, 128:] + d = self.predictor.text_encoder(d_en, s, input_lengths, text_mask) + x, _ = self.predictor.lstm(d) + duration = self.predictor.duration_proj(x) + duration = torch.sigmoid(duration).sum(axis=-1) / speed + pred_dur = torch.round(duration).clamp(min=1).long().squeeze() + indices = torch.repeat_interleave(torch.arange(input_ids.shape[1], device=self.device), pred_dur) + pred_aln_trg = torch.zeros((input_ids.shape[1], indices.shape[0]), device=self.device) + pred_aln_trg[indices, torch.arange(indices.shape[0])] = 1 + pred_aln_trg = pred_aln_trg.unsqueeze(0).to(self.device) + en = 
d.transpose(-1, -2) @ pred_aln_trg + F0_pred, N_pred = self.predictor.F0Ntrain(en, s) + t_en = self.text_encoder(input_ids, input_lengths, text_mask) + asr = t_en @ pred_aln_trg + audio = self.decoder(asr, F0_pred, N_pred, ref_s[:, :128]).squeeze() + return audio, pred_dur + + def forward( + self, + phonemes: str, + ref_s: torch.FloatTensor, + speed: float = 1, + return_output: bool = False + ) -> Union['KModel.Output', torch.FloatTensor]: + input_ids = list(filter(lambda i: i is not None, map(lambda p: self.vocab.get(p), phonemes))) + logger.debug(f"phonemes: {phonemes} -> input_ids: {input_ids}") + assert len(input_ids)+2 <= self.context_length, (len(input_ids)+2, self.context_length) + input_ids = torch.LongTensor([[0, *input_ids, 0]]).to(self.device) + ref_s = ref_s.to(self.device) + audio, pred_dur = self.forward_with_tokens(input_ids, ref_s, speed) + audio = audio.squeeze().cpu() + pred_dur = pred_dur.cpu() if pred_dur is not None else None + logger.debug(f"pred_dur: {pred_dur}") + return self.Output(audio=audio, pred_dur=pred_dur) if return_output else audio + +class KModelForONNX(torch.nn.Module): + def __init__(self, kmodel: KModel): + super().__init__() + self.kmodel = kmodel + + def forward( + self, + input_ids: torch.LongTensor, + ref_s: torch.FloatTensor, + speed: float = 1 + ) -> tuple[torch.FloatTensor, torch.LongTensor]: + waveform, duration = self.kmodel.forward_with_tokens(input_ids, ref_s, speed) + return waveform, duration diff --git a/wan/multitalk/kokoro/modules.py b/wan/multitalk/kokoro/modules.py new file mode 100644 index 0000000..05d1575 --- /dev/null +++ b/wan/multitalk/kokoro/modules.py @@ -0,0 +1,183 @@ +# https://github.com/yl4579/StyleTTS2/blob/main/models.py +from .istftnet import AdainResBlk1d +from torch.nn.utils import weight_norm +from transformers import AlbertModel +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class LinearNorm(nn.Module): + def __init__(self, in_dim, out_dim, bias=True, w_init_gain='linear'): + super(LinearNorm, self).__init__() + self.linear_layer = nn.Linear(in_dim, out_dim, bias=bias) + nn.init.xavier_uniform_(self.linear_layer.weight, gain=nn.init.calculate_gain(w_init_gain)) + + def forward(self, x): + return self.linear_layer(x) + + +class LayerNorm(nn.Module): + def __init__(self, channels, eps=1e-5): + super().__init__() + self.channels = channels + self.eps = eps + self.gamma = nn.Parameter(torch.ones(channels)) + self.beta = nn.Parameter(torch.zeros(channels)) + + def forward(self, x): + x = x.transpose(1, -1) + x = F.layer_norm(x, (self.channels,), self.gamma, self.beta, self.eps) + return x.transpose(1, -1) + + +class TextEncoder(nn.Module): + def __init__(self, channels, kernel_size, depth, n_symbols, actv=nn.LeakyReLU(0.2)): + super().__init__() + self.embedding = nn.Embedding(n_symbols, channels) + padding = (kernel_size - 1) // 2 + self.cnn = nn.ModuleList() + for _ in range(depth): + self.cnn.append(nn.Sequential( + weight_norm(nn.Conv1d(channels, channels, kernel_size=kernel_size, padding=padding)), + LayerNorm(channels), + actv, + nn.Dropout(0.2), + )) + self.lstm = nn.LSTM(channels, channels//2, 1, batch_first=True, bidirectional=True) + + def forward(self, x, input_lengths, m): + x = self.embedding(x) # [B, T, emb] + x = x.transpose(1, 2) # [B, emb, T] + m = m.unsqueeze(1) + x.masked_fill_(m, 0.0) + for c in self.cnn: + x = c(x) + x.masked_fill_(m, 0.0) + x = x.transpose(1, 2) # [B, T, chn] + lengths = input_lengths if input_lengths.device == torch.device('cpu') else 
input_lengths.to('cpu') + x = nn.utils.rnn.pack_padded_sequence(x, lengths, batch_first=True, enforce_sorted=False) + self.lstm.flatten_parameters() + x, _ = self.lstm(x) + x, _ = nn.utils.rnn.pad_packed_sequence(x, batch_first=True) + x = x.transpose(-1, -2) + x_pad = torch.zeros([x.shape[0], x.shape[1], m.shape[-1]], device=x.device) + x_pad[:, :, :x.shape[-1]] = x + x = x_pad + x.masked_fill_(m, 0.0) + return x + + +class AdaLayerNorm(nn.Module): + def __init__(self, style_dim, channels, eps=1e-5): + super().__init__() + self.channels = channels + self.eps = eps + self.fc = nn.Linear(style_dim, channels*2) + + def forward(self, x, s): + x = x.transpose(-1, -2) + x = x.transpose(1, -1) + h = self.fc(s) + h = h.view(h.size(0), h.size(1), 1) + gamma, beta = torch.chunk(h, chunks=2, dim=1) + gamma, beta = gamma.transpose(1, -1), beta.transpose(1, -1) + x = F.layer_norm(x, (self.channels,), eps=self.eps) + x = (1 + gamma) * x + beta + return x.transpose(1, -1).transpose(-1, -2) + + +class ProsodyPredictor(nn.Module): + def __init__(self, style_dim, d_hid, nlayers, max_dur=50, dropout=0.1): + super().__init__() + self.text_encoder = DurationEncoder(sty_dim=style_dim, d_model=d_hid,nlayers=nlayers, dropout=dropout) + self.lstm = nn.LSTM(d_hid + style_dim, d_hid // 2, 1, batch_first=True, bidirectional=True) + self.duration_proj = LinearNorm(d_hid, max_dur) + self.shared = nn.LSTM(d_hid + style_dim, d_hid // 2, 1, batch_first=True, bidirectional=True) + self.F0 = nn.ModuleList() + self.F0.append(AdainResBlk1d(d_hid, d_hid, style_dim, dropout_p=dropout)) + self.F0.append(AdainResBlk1d(d_hid, d_hid // 2, style_dim, upsample=True, dropout_p=dropout)) + self.F0.append(AdainResBlk1d(d_hid // 2, d_hid // 2, style_dim, dropout_p=dropout)) + self.N = nn.ModuleList() + self.N.append(AdainResBlk1d(d_hid, d_hid, style_dim, dropout_p=dropout)) + self.N.append(AdainResBlk1d(d_hid, d_hid // 2, style_dim, upsample=True, dropout_p=dropout)) + self.N.append(AdainResBlk1d(d_hid // 2, d_hid // 2, style_dim, dropout_p=dropout)) + self.F0_proj = nn.Conv1d(d_hid // 2, 1, 1, 1, 0) + self.N_proj = nn.Conv1d(d_hid // 2, 1, 1, 1, 0) + + def forward(self, texts, style, text_lengths, alignment, m): + d = self.text_encoder(texts, style, text_lengths, m) + m = m.unsqueeze(1) + lengths = text_lengths if text_lengths.device == torch.device('cpu') else text_lengths.to('cpu') + x = nn.utils.rnn.pack_padded_sequence(d, lengths, batch_first=True, enforce_sorted=False) + self.lstm.flatten_parameters() + x, _ = self.lstm(x) + x, _ = nn.utils.rnn.pad_packed_sequence(x, batch_first=True) + x_pad = torch.zeros([x.shape[0], m.shape[-1], x.shape[-1]], device=x.device) + x_pad[:, :x.shape[1], :] = x + x = x_pad + duration = self.duration_proj(nn.functional.dropout(x, 0.5, training=False)) + en = (d.transpose(-1, -2) @ alignment) + return duration.squeeze(-1), en + + def F0Ntrain(self, x, s): + x, _ = self.shared(x.transpose(-1, -2)) + F0 = x.transpose(-1, -2) + for block in self.F0: + F0 = block(F0, s) + F0 = self.F0_proj(F0) + N = x.transpose(-1, -2) + for block in self.N: + N = block(N, s) + N = self.N_proj(N) + return F0.squeeze(1), N.squeeze(1) + + +class DurationEncoder(nn.Module): + def __init__(self, sty_dim, d_model, nlayers, dropout=0.1): + super().__init__() + self.lstms = nn.ModuleList() + for _ in range(nlayers): + self.lstms.append(nn.LSTM(d_model + sty_dim, d_model // 2, num_layers=1, batch_first=True, bidirectional=True, dropout=dropout)) + self.lstms.append(AdaLayerNorm(sty_dim, d_model)) + self.dropout = dropout + 
self.d_model = d_model + self.sty_dim = sty_dim + + def forward(self, x, style, text_lengths, m): + masks = m + x = x.permute(2, 0, 1) + s = style.expand(x.shape[0], x.shape[1], -1) + x = torch.cat([x, s], axis=-1) + x.masked_fill_(masks.unsqueeze(-1).transpose(0, 1), 0.0) + x = x.transpose(0, 1) + x = x.transpose(-1, -2) + for block in self.lstms: + if isinstance(block, AdaLayerNorm): + x = block(x.transpose(-1, -2), style).transpose(-1, -2) + x = torch.cat([x, s.permute(1, 2, 0)], axis=1) + x.masked_fill_(masks.unsqueeze(-1).transpose(-1, -2), 0.0) + else: + lengths = text_lengths if text_lengths.device == torch.device('cpu') else text_lengths.to('cpu') + x = x.transpose(-1, -2) + x = nn.utils.rnn.pack_padded_sequence( + x, lengths, batch_first=True, enforce_sorted=False) + block.flatten_parameters() + x, _ = block(x) + x, _ = nn.utils.rnn.pad_packed_sequence( + x, batch_first=True) + x = F.dropout(x, p=self.dropout, training=False) + x = x.transpose(-1, -2) + x_pad = torch.zeros([x.shape[0], x.shape[1], m.shape[-1]], device=x.device) + x_pad[:, :, :x.shape[-1]] = x + x = x_pad + + return x.transpose(-1, -2) + + +# https://github.com/yl4579/StyleTTS2/blob/main/Utils/PLBERT/util.py +class CustomAlbert(AlbertModel): + def forward(self, *args, **kwargs): + outputs = super().forward(*args, **kwargs) + return outputs.last_hidden_state diff --git a/wan/multitalk/kokoro/pipeline.py b/wan/multitalk/kokoro/pipeline.py new file mode 100644 index 0000000..098df8e --- /dev/null +++ b/wan/multitalk/kokoro/pipeline.py @@ -0,0 +1,445 @@ +from .model import KModel +from dataclasses import dataclass +from huggingface_hub import hf_hub_download +from loguru import logger +from misaki import en, espeak +from typing import Callable, Generator, List, Optional, Tuple, Union +import re +import torch +import os + +ALIASES = { + 'en-us': 'a', + 'en-gb': 'b', + 'es': 'e', + 'fr-fr': 'f', + 'hi': 'h', + 'it': 'i', + 'pt-br': 'p', + 'ja': 'j', + 'zh': 'z', +} + +LANG_CODES = dict( + # pip install misaki[en] + a='American English', + b='British English', + + # espeak-ng + e='es', + f='fr-fr', + h='hi', + i='it', + p='pt-br', + + # pip install misaki[ja] + j='Japanese', + + # pip install misaki[zh] + z='Mandarin Chinese', +) + +class KPipeline: + ''' + KPipeline is a language-aware support class with 2 main responsibilities: + 1. Perform language-specific G2P, mapping (and chunking) text -> phonemes + 2. Manage and store voices, lazily downloaded from HF if needed + + You are expected to have one KPipeline per language. If you have multiple + KPipelines, you should reuse one KModel instance across all of them. + + KPipeline is designed to work with a KModel, but this is not required. + There are 2 ways to pass an existing model into a pipeline: + 1. On init: us_pipeline = KPipeline(lang_code='a', model=model) + 2. On call: us_pipeline(text, voice, model=model) + + By default, KPipeline will automatically initialize its own KModel. To + suppress this, construct a "quiet" KPipeline with model=False. + + A "quiet" KPipeline yields (graphemes, phonemes, None) without generating + any audio. You can use this to phonemize and chunk your text in advance. + + A "loud" KPipeline _with_ a model yields (graphemes, phonemes, audio). + ''' + def __init__( + self, + lang_code: str, + repo_id: Optional[str] = None, + model: Union[KModel, bool] = True, + trf: bool = False, + en_callable: Optional[Callable[[str], str]] = None, + device: Optional[str] = None + ): + """Initialize a KPipeline. 
+ + Args: + lang_code: Language code for G2P processing + model: KModel instance, True to create new model, False for no model + trf: Whether to use transformer-based G2P + device: Override default device selection ('cuda' or 'cpu', or None for auto) + If None, will auto-select cuda if available + If 'cuda' and not available, will explicitly raise an error + """ + if repo_id is None: + repo_id = 'hexgrad/Kokoro-82M' + print(f"WARNING: Defaulting repo_id to {repo_id}. Pass repo_id='{repo_id}' to suppress this warning.") + config=None + else: + config = os.path.join(repo_id, 'config.json') + self.repo_id = repo_id + lang_code = lang_code.lower() + lang_code = ALIASES.get(lang_code, lang_code) + assert lang_code in LANG_CODES, (lang_code, LANG_CODES) + self.lang_code = lang_code + self.model = None + if isinstance(model, KModel): + self.model = model + elif model: + if device == 'cuda' and not torch.cuda.is_available(): + raise RuntimeError("CUDA requested but not available") + if device == 'mps' and not torch.backends.mps.is_available(): + raise RuntimeError("MPS requested but not available") + if device == 'mps' and os.environ.get('PYTORCH_ENABLE_MPS_FALLBACK') != '1': + raise RuntimeError("MPS requested but fallback not enabled") + if device is None: + if torch.cuda.is_available(): + device = 'cuda' + elif os.environ.get('PYTORCH_ENABLE_MPS_FALLBACK') == '1' and torch.backends.mps.is_available(): + device = 'mps' + else: + device = 'cpu' + try: + self.model = KModel(repo_id=repo_id, config=config).to(device).eval() + except RuntimeError as e: + if device == 'cuda': + raise RuntimeError(f"""Failed to initialize model on CUDA: {e}. + Try setting device='cpu' or check CUDA installation.""") + raise + self.voices = {} + if lang_code in 'ab': + try: + fallback = espeak.EspeakFallback(british=lang_code=='b') + except Exception as e: + logger.warning("EspeakFallback not Enabled: OOD words will be skipped") + logger.warning({str(e)}) + fallback = None + self.g2p = en.G2P(trf=trf, british=lang_code=='b', fallback=fallback, unk='') + elif lang_code == 'j': + try: + from misaki import ja + self.g2p = ja.JAG2P() + except ImportError: + logger.error("You need to `pip install misaki[ja]` to use lang_code='j'") + raise + elif lang_code == 'z': + try: + from misaki import zh + self.g2p = zh.ZHG2P( + version=None if repo_id.endswith('/Kokoro-82M') else '1.1', + en_callable=en_callable + ) + except ImportError: + logger.error("You need to `pip install misaki[zh]` to use lang_code='z'") + raise + else: + language = LANG_CODES[lang_code] + logger.warning(f"Using EspeakG2P(language='{language}'). Chunking logic not yet implemented, so long texts may be truncated unless you split them with '\\n'.") + self.g2p = espeak.EspeakG2P(language=language) + + def load_single_voice(self, voice: str): + if voice in self.voices: + return self.voices[voice] + if voice.endswith('.pt'): + f = voice + else: + f = hf_hub_download(repo_id=self.repo_id, filename=f'voices/{voice}.pt') + if not voice.startswith(self.lang_code): + v = LANG_CODES.get(voice, voice) + p = LANG_CODES.get(self.lang_code, self.lang_code) + logger.warning(f'Language mismatch, loading {v} voice into {p} pipeline.') + pack = torch.load(f, weights_only=True) + self.voices[voice] = pack + return pack + + """ + load_voice is a helper function that lazily downloads and loads a voice: + Single voice can be requested (e.g. 'af_bella') or multiple voices (e.g. 'af_bella,af_jessica'). + If multiple voices are requested, they are averaged. 
+ Delimiter is optional and defaults to ','. + """ + def load_voice(self, voice: Union[str, torch.FloatTensor], delimiter: str = ",") -> torch.FloatTensor: + if isinstance(voice, torch.FloatTensor): + return voice + if voice in self.voices: + return self.voices[voice] + logger.debug(f"Loading voice: {voice}") + packs = [self.load_single_voice(v) for v in voice.split(delimiter)] + if len(packs) == 1: + return packs[0] + self.voices[voice] = torch.mean(torch.stack(packs), dim=0) + return self.voices[voice] + + @staticmethod + def tokens_to_ps(tokens: List[en.MToken]) -> str: + return ''.join(t.phonemes + (' ' if t.whitespace else '') for t in tokens).strip() + + @staticmethod + def waterfall_last( + tokens: List[en.MToken], + next_count: int, + waterfall: List[str] = ['!.?…', ':;', ',—'], + bumps: List[str] = [')', '”'] + ) -> int: + for w in waterfall: + z = next((i for i, t in reversed(list(enumerate(tokens))) if t.phonemes in set(w)), None) + if z is None: + continue + z += 1 + if z < len(tokens) and tokens[z].phonemes in bumps: + z += 1 + if next_count - len(KPipeline.tokens_to_ps(tokens[:z])) <= 510: + return z + return len(tokens) + + @staticmethod + def tokens_to_text(tokens: List[en.MToken]) -> str: + return ''.join(t.text + t.whitespace for t in tokens).strip() + + def en_tokenize( + self, + tokens: List[en.MToken] + ) -> Generator[Tuple[str, str, List[en.MToken]], None, None]: + tks = [] + pcount = 0 + for t in tokens: + # American English: ɾ => T + t.phonemes = '' if t.phonemes is None else t.phonemes#.replace('ɾ', 'T') + next_ps = t.phonemes + (' ' if t.whitespace else '') + next_pcount = pcount + len(next_ps.rstrip()) + if next_pcount > 510: + z = KPipeline.waterfall_last(tks, next_pcount) + text = KPipeline.tokens_to_text(tks[:z]) + logger.debug(f"Chunking text at {z}: '{text[:30]}{'...' if len(text) > 30 else ''}'") + ps = KPipeline.tokens_to_ps(tks[:z]) + yield text, ps, tks[:z] + tks = tks[z:] + pcount = len(KPipeline.tokens_to_ps(tks)) + if not tks: + next_ps = next_ps.lstrip() + tks.append(t) + pcount += len(next_ps) + if tks: + text = KPipeline.tokens_to_text(tks) + ps = KPipeline.tokens_to_ps(tks) + yield ''.join(text).strip(), ''.join(ps).strip(), tks + + @staticmethod + def infer( + model: KModel, + ps: str, + pack: torch.FloatTensor, + speed: Union[float, Callable[[int], float]] = 1 + ) -> KModel.Output: + if callable(speed): + speed = speed(len(ps)) + return model(ps, pack[len(ps)-1], speed, return_output=True) + + def generate_from_tokens( + self, + tokens: Union[str, List[en.MToken]], + voice: str, + speed: float = 1, + model: Optional[KModel] = None + ) -> Generator['KPipeline.Result', None, None]: + """Generate audio from either raw phonemes or pre-processed tokens. 
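+
+        A minimal sketch (voice name is illustrative; the g2p call mirrors the
+        English branch of __call__ below):
+            _, tokens = pipeline.g2p("Hello world!")
+            for result in pipeline.generate_from_tokens(tokens, voice='af_heart'):
+                audio = result.audio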
+ + Args: + tokens: Either a phoneme string or list of pre-processed MTokens + voice: The voice to use for synthesis + speed: Speech speed modifier (default: 1) + model: Optional KModel instance (uses pipeline's model if not provided) + + Yields: + KPipeline.Result containing the input tokens and generated audio + + Raises: + ValueError: If no voice is provided or token sequence exceeds model limits + """ + model = model or self.model + if model and voice is None: + raise ValueError('Specify a voice: pipeline.generate_from_tokens(..., voice="af_heart")') + + pack = self.load_voice(voice).to(model.device) if model else None + + # Handle raw phoneme string + if isinstance(tokens, str): + logger.debug("Processing phonemes from raw string") + if len(tokens) > 510: + raise ValueError(f'Phoneme string too long: {len(tokens)} > 510') + output = KPipeline.infer(model, tokens, pack, speed) if model else None + yield self.Result(graphemes='', phonemes=tokens, output=output) + return + + logger.debug("Processing MTokens") + # Handle pre-processed tokens + for gs, ps, tks in self.en_tokenize(tokens): + if not ps: + continue + elif len(ps) > 510: + logger.warning(f"Unexpected len(ps) == {len(ps)} > 510 and ps == '{ps}'") + logger.warning("Truncating to 510 characters") + ps = ps[:510] + output = KPipeline.infer(model, ps, pack, speed) if model else None + if output is not None and output.pred_dur is not None: + KPipeline.join_timestamps(tks, output.pred_dur) + yield self.Result(graphemes=gs, phonemes=ps, tokens=tks, output=output) + + @staticmethod + def join_timestamps(tokens: List[en.MToken], pred_dur: torch.LongTensor): + # Multiply by 600 to go from pred_dur frames to sample_rate 24000 + # Equivalent to dividing pred_dur frames by 40 to get timestamp in seconds + # We will count nice round half-frames, so the divisor is 80 + MAGIC_DIVISOR = 80 + if not tokens or len(pred_dur) < 3: + # We expect at least 3: , token, + return + # We track 2 counts, measured in half-frames: (left, right) + # This way we can cut space characters in half + # TODO: Is -3 an appropriate offset? 
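+        # left/right are running cursors in half-frames (hence the *2 and the
+        # divisor of 80 above); they start after the first predicted duration,
+        # pulled back by the 3-frame offset questioned in the TODO above.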
+ left = right = 2 * max(0, pred_dur[0].item() - 3) + # Updates: + # left = right + (2 * token_dur) + space_dur + # right = left + space_dur + i = 1 + for t in tokens: + if i >= len(pred_dur)-1: + break + if not t.phonemes: + if t.whitespace: + i += 1 + left = right + pred_dur[i].item() + right = left + pred_dur[i].item() + i += 1 + continue + j = i + len(t.phonemes) + if j >= len(pred_dur): + break + t.start_ts = left / MAGIC_DIVISOR + token_dur = pred_dur[i: j].sum().item() + space_dur = pred_dur[j].item() if t.whitespace else 0 + left = right + (2 * token_dur) + space_dur + t.end_ts = left / MAGIC_DIVISOR + right = left + space_dur + i = j + (1 if t.whitespace else 0) + + @dataclass + class Result: + graphemes: str + phonemes: str + tokens: Optional[List[en.MToken]] = None + output: Optional[KModel.Output] = None + text_index: Optional[int] = None + + @property + def audio(self) -> Optional[torch.FloatTensor]: + return None if self.output is None else self.output.audio + + @property + def pred_dur(self) -> Optional[torch.LongTensor]: + return None if self.output is None else self.output.pred_dur + + ### MARK: BEGIN BACKWARD COMPAT ### + def __iter__(self): + yield self.graphemes + yield self.phonemes + yield self.audio + + def __getitem__(self, index): + return [self.graphemes, self.phonemes, self.audio][index] + + def __len__(self): + return 3 + #### MARK: END BACKWARD COMPAT #### + + def __call__( + self, + text: Union[str, List[str]], + voice: Optional[str] = None, + speed: Union[float, Callable[[int], float]] = 1, + split_pattern: Optional[str] = r'\n+', + model: Optional[KModel] = None + ) -> Generator['KPipeline.Result', None, None]: + model = model or self.model + if model and voice is None: + raise ValueError('Specify a voice: en_us_pipeline(text="Hello world!", voice="af_heart")') + pack = self.load_voice(voice).to(model.device) if model else None + + # Convert input to list of segments + if isinstance(text, str): + text = re.split(split_pattern, text.strip()) if split_pattern else [text] + + # Process each segment + for graphemes_index, graphemes in enumerate(text): + if not graphemes.strip(): # Skip empty segments + continue + + # English processing (unchanged) + if self.lang_code in 'ab': + logger.debug(f"Processing English text: {graphemes[:50]}{'...' 
if len(graphemes) > 50 else ''}") + _, tokens = self.g2p(graphemes) + for gs, ps, tks in self.en_tokenize(tokens): + if not ps: + continue + elif len(ps) > 510: + logger.warning(f"Unexpected len(ps) == {len(ps)} > 510 and ps == '{ps}'") + ps = ps[:510] + output = KPipeline.infer(model, ps, pack, speed) if model else None + if output is not None and output.pred_dur is not None: + KPipeline.join_timestamps(tks, output.pred_dur) + yield self.Result(graphemes=gs, phonemes=ps, tokens=tks, output=output, text_index=graphemes_index) + + # Non-English processing with chunking + else: + # Split long text into smaller chunks (roughly 400 characters each) + # Using sentence boundaries when possible + chunk_size = 400 + chunks = [] + + # Try to split on sentence boundaries first + sentences = re.split(r'([.!?]+)', graphemes) + current_chunk = "" + + for i in range(0, len(sentences), 2): + sentence = sentences[i] + # Add the punctuation back if it exists + if i + 1 < len(sentences): + sentence += sentences[i + 1] + + if len(current_chunk) + len(sentence) <= chunk_size: + current_chunk += sentence + else: + if current_chunk: + chunks.append(current_chunk.strip()) + current_chunk = sentence + + if current_chunk: + chunks.append(current_chunk.strip()) + + # If no chunks were created (no sentence boundaries), fall back to character-based chunking + if not chunks: + chunks = [graphemes[i:i+chunk_size] for i in range(0, len(graphemes), chunk_size)] + + # Process each chunk + for chunk in chunks: + if not chunk.strip(): + continue + + ps, _ = self.g2p(chunk) + if not ps: + continue + elif len(ps) > 510: + logger.warning(f'Truncating len(ps) == {len(ps)} > 510') + ps = ps[:510] + + output = KPipeline.infer(model, ps, pack, speed) if model else None + yield self.Result(graphemes=chunk, phonemes=ps, output=output, text_index=graphemes_index) diff --git a/wan/multitalk/multitalk.py b/wan/multitalk/multitalk.py new file mode 100644 index 0000000..e429371 --- /dev/null +++ b/wan/multitalk/multitalk.py @@ -0,0 +1,319 @@ +import random +import os +import torch +import torch.distributed as dist +from PIL import Image +import subprocess +import torchvision.transforms as transforms +import torch.nn.functional as F +import torch.nn as nn +import wan +from wan.configs import SIZE_CONFIGS, SUPPORTED_SIZES, WAN_CONFIGS +from wan.utils.utils import cache_image, cache_video, str2bool +# from wan.utils.multitalk_utils import save_video_ffmpeg +# from .kokoro import KPipeline +from transformers import Wav2Vec2FeatureExtractor +from .wav2vec2 import Wav2Vec2Model + +import librosa +import pyloudnorm as pyln +import numpy as np +from einops import rearrange +import soundfile as sf +import re +import math + +def custom_init(device, wav2vec): + audio_encoder = Wav2Vec2Model.from_pretrained(wav2vec, local_files_only=True).to(device) + audio_encoder.feature_extractor._freeze_parameters() + wav2vec_feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(wav2vec, local_files_only=True) + return wav2vec_feature_extractor, audio_encoder + +def loudness_norm(audio_array, sr=16000, lufs=-23): + meter = pyln.Meter(sr) + loudness = meter.integrated_loudness(audio_array) + if abs(loudness) > 100: + return audio_array + normalized_audio = pyln.normalize.loudness(audio_array, loudness, lufs) + return normalized_audio + + +def get_embedding(speech_array, wav2vec_feature_extractor, audio_encoder, sr=16000, device='cpu', fps = 25): + audio_duration = len(speech_array) / sr + video_length = audio_duration * fps + + # wav2vec_feature_extractor 
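+    # The feature extractor only prepares/normalizes the raw 16 kHz waveform into
+    # input_values; the embeddings come from audio_encoder below, which resamples
+    # its features to one hidden state per video frame (audio_duration * fps).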
+ audio_feature = np.squeeze( + wav2vec_feature_extractor(speech_array, sampling_rate=sr).input_values + ) + audio_feature = torch.from_numpy(audio_feature).float().to(device=device) + audio_feature = audio_feature.unsqueeze(0) + + # audio encoder + with torch.no_grad(): + embeddings = audio_encoder(audio_feature, seq_len=int(video_length), output_hidden_states=True) + + if len(embeddings) == 0: + print("Fail to extract audio embedding") + return None + + audio_emb = torch.stack(embeddings.hidden_states[1:], dim=1).squeeze(0) + audio_emb = rearrange(audio_emb, "b s d -> s b d") + + audio_emb = audio_emb.cpu().detach() + return audio_emb + +def audio_prepare_single(audio_path, sample_rate=16000, duration = 0): + ext = os.path.splitext(audio_path)[1].lower() + if ext in ['.mp4', '.mov', '.avi', '.mkv']: + human_speech_array = extract_audio_from_video(audio_path, sample_rate) + return human_speech_array + else: + human_speech_array, sr = librosa.load(audio_path, duration=duration, sr=sample_rate) + human_speech_array = loudness_norm(human_speech_array, sr) + return human_speech_array + + +def audio_prepare_multi(left_path, right_path, audio_type = "add", sample_rate=16000, duration = 0): + if not (left_path==None or right_path==None): + human_speech_array1 = audio_prepare_single(left_path, duration = duration) + human_speech_array2 = audio_prepare_single(right_path, duration = duration) + elif left_path==None: + human_speech_array2 = audio_prepare_single(right_path, duration = duration) + human_speech_array1 = np.zeros(human_speech_array2.shape[0]) + elif right_path==None: + human_speech_array1 = audio_prepare_single(left_path, duration = duration) + human_speech_array2 = np.zeros(human_speech_array1.shape[0]) + + if audio_type=='para': + new_human_speech1 = human_speech_array1 + new_human_speech2 = human_speech_array2 + elif audio_type=='add': + new_human_speech1 = np.concatenate([human_speech_array1[: human_speech_array1.shape[0]], np.zeros(human_speech_array2.shape[0])]) + new_human_speech2 = np.concatenate([np.zeros(human_speech_array1.shape[0]), human_speech_array2[:human_speech_array2.shape[0]]]) + sum_human_speechs = new_human_speech1 + new_human_speech2 + return new_human_speech1, new_human_speech2, sum_human_speechs + +def process_tts_single(text, save_dir, voice1): + s1_sentences = [] + + pipeline = KPipeline(lang_code='a', repo_id='weights/Kokoro-82M') + + voice_tensor = torch.load(voice1, weights_only=True) + generator = pipeline( + text, voice=voice_tensor, # <= change voice here + speed=1, split_pattern=r'\n+' + ) + audios = [] + for i, (gs, ps, audio) in enumerate(generator): + audios.append(audio) + audios = torch.concat(audios, dim=0) + s1_sentences.append(audios) + s1_sentences = torch.concat(s1_sentences, dim=0) + save_path1 =f'{save_dir}/s1.wav' + sf.write(save_path1, s1_sentences, 24000) # save each audio file + s1, _ = librosa.load(save_path1, sr=16000) + return s1, save_path1 + + + +def process_tts_multi(text, save_dir, voice1, voice2): + pattern = r'\(s(\d+)\)\s*(.*?)(?=\s*\(s\d+\)|$)' + matches = re.findall(pattern, text, re.DOTALL) + + s1_sentences = [] + s2_sentences = [] + + pipeline = KPipeline(lang_code='a', repo_id='weights/Kokoro-82M') + for idx, (speaker, content) in enumerate(matches): + if speaker == '1': + voice_tensor = torch.load(voice1, weights_only=True) + generator = pipeline( + content, voice=voice_tensor, # <= change voice here + speed=1, split_pattern=r'\n+' + ) + audios = [] + for i, (gs, ps, audio) in enumerate(generator): + audios.append(audio) + 
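+            # the pipeline yields one audio chunk per text split; concatenate them
+            # and pad the other speaker's track with silence of the same length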
audios = torch.concat(audios, dim=0) + s1_sentences.append(audios) + s2_sentences.append(torch.zeros_like(audios)) + elif speaker == '2': + voice_tensor = torch.load(voice2, weights_only=True) + generator = pipeline( + content, voice=voice_tensor, # <= change voice here + speed=1, split_pattern=r'\n+' + ) + audios = [] + for i, (gs, ps, audio) in enumerate(generator): + audios.append(audio) + audios = torch.concat(audios, dim=0) + s2_sentences.append(audios) + s1_sentences.append(torch.zeros_like(audios)) + + s1_sentences = torch.concat(s1_sentences, dim=0) + s2_sentences = torch.concat(s2_sentences, dim=0) + sum_sentences = s1_sentences + s2_sentences + save_path1 =f'{save_dir}/s1.wav' + save_path2 =f'{save_dir}/s2.wav' + save_path_sum = f'{save_dir}/sum.wav' + sf.write(save_path1, s1_sentences, 24000) # save each audio file + sf.write(save_path2, s2_sentences, 24000) + sf.write(save_path_sum, sum_sentences, 24000) + + s1, _ = librosa.load(save_path1, sr=16000) + s2, _ = librosa.load(save_path2, sr=16000) + # sum, _ = librosa.load(save_path_sum, sr=16000) + return s1, s2, save_path_sum + + +def get_full_audio_embeddings(audio_guide1 = None, audio_guide2 = None, combination_type ="add", num_frames = 0, fps = 25, sr = 16000): + wav2vec_feature_extractor, audio_encoder= custom_init('cpu', "ckpts/chinese-wav2vec2-base") + # wav2vec_feature_extractor, audio_encoder= custom_init('cpu', "ckpts/wav2vec") + + new_human_speech1, new_human_speech2, sum_human_speechs = audio_prepare_multi(audio_guide1, audio_guide2, combination_type, duration= num_frames / fps) + audio_embedding_1 = get_embedding(new_human_speech1, wav2vec_feature_extractor, audio_encoder, sr=sr, fps= fps) + audio_embedding_2 = get_embedding(new_human_speech2, wav2vec_feature_extractor, audio_encoder, sr=sr, fps= fps) + + full_audio_embs = [] + if audio_guide1 != None: full_audio_embs.append(audio_embedding_1) + # if audio_guide1 != None: full_audio_embs.append(audio_embedding_1) + if audio_guide2 != None: full_audio_embs.append(audio_embedding_2) + if audio_guide2 == None: sum_human_speechs = None + return full_audio_embs, sum_human_speechs + + +def get_window_audio_embeddings(full_audio_embs, audio_start_idx=0, clip_length = 81, vae_scale = 4, audio_window = 5): + HUMAN_NUMBER = len(full_audio_embs) + audio_end_idx = audio_start_idx + clip_length + indices = (torch.arange(2 * 2 + 1) - 2) * 1 + + audio_embs = [] + # split audio with window size + for human_idx in range(HUMAN_NUMBER): + center_indices = torch.arange( + audio_start_idx, + audio_end_idx, + 1 + ).unsqueeze( + 1 + ) + indices.unsqueeze(0) + center_indices = torch.clamp(center_indices, min=0, max=full_audio_embs[human_idx].shape[0]-1).to(full_audio_embs[human_idx].device) + audio_emb = full_audio_embs[human_idx][center_indices][None,...] #.to(self.device) + audio_embs.append(audio_emb) + audio_embs = torch.concat(audio_embs, dim=0) #.to(self.param_dtype) + + # audio_cond = audio.to(device=x.device, dtype=x.dtype) + audio_cond = audio_embs + first_frame_audio_emb_s = audio_cond[:, :1, ...] + latter_frame_audio_emb = audio_cond[:, 1:, ...] + latter_frame_audio_emb = rearrange(latter_frame_audio_emb, "b (n_t n) w s c -> b n_t n w s c", n=vae_scale) + middle_index = audio_window // 2 + latter_first_frame_audio_emb = latter_frame_audio_emb[:, :, :1, :middle_index+1, ...] + latter_first_frame_audio_emb = rearrange(latter_first_frame_audio_emb, "b n_t n w s c -> b n_t (n w) s c") + latter_last_frame_audio_emb = latter_frame_audio_emb[:, :, -1:, middle_index:, ...] 
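+    # the last raw frame mirrors the first-frame slice above: it keeps the centre
+    # sample plus the right half of its audio window, while the in-between frames
+    # keep only the centre sample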
+ latter_last_frame_audio_emb = rearrange(latter_last_frame_audio_emb, "b n_t n w s c -> b n_t (n w) s c") + latter_middle_frame_audio_emb = latter_frame_audio_emb[:, :, 1:-1, middle_index:middle_index+1, ...] + latter_middle_frame_audio_emb = rearrange(latter_middle_frame_audio_emb, "b n_t n w s c -> b n_t (n w) s c") + latter_frame_audio_emb_s = torch.concat([latter_first_frame_audio_emb, latter_middle_frame_audio_emb, latter_last_frame_audio_emb], dim=2) + + return [first_frame_audio_emb_s, latter_frame_audio_emb_s] + +def resize_and_centercrop(cond_image, target_size): + """ + Resize image or tensor to the target size without padding. + """ + + # Get the original size + if isinstance(cond_image, torch.Tensor): + _, orig_h, orig_w = cond_image.shape + else: + orig_h, orig_w = cond_image.height, cond_image.width + + target_h, target_w = target_size + + # Calculate the scaling factor for resizing + scale_h = target_h / orig_h + scale_w = target_w / orig_w + + # Compute the final size + scale = max(scale_h, scale_w) + final_h = math.ceil(scale * orig_h) + final_w = math.ceil(scale * orig_w) + + # Resize + if isinstance(cond_image, torch.Tensor): + if len(cond_image.shape) == 3: + cond_image = cond_image[None] + resized_tensor = nn.functional.interpolate(cond_image, size=(final_h, final_w), mode='nearest').contiguous() + # crop + cropped_tensor = transforms.functional.center_crop(resized_tensor, target_size) + cropped_tensor = cropped_tensor.squeeze(0) + else: + resized_image = cond_image.resize((final_w, final_h), resample=Image.BILINEAR) + resized_image = np.array(resized_image) + # tensor and crop + resized_tensor = torch.from_numpy(resized_image)[None, ...].permute(0, 3, 1, 2).contiguous() + cropped_tensor = transforms.functional.center_crop(resized_tensor, target_size) + cropped_tensor = cropped_tensor[:, :, None, :, :] + + return cropped_tensor + + +def timestep_transform( + t, + shift=5.0, + num_timesteps=1000, +): + t = t / num_timesteps + # shift the timestep based on ratio + new_t = shift * t / (1 + (shift - 1) * t) + new_t = new_t * num_timesteps + return new_t + + +# construct human mask +def get_target_masks(HUMAN_NUMBER, lat_h, lat_w, src_h, src_w, face_scale = 0.05, bbox = None): + human_masks = [] + if HUMAN_NUMBER==1: + background_mask = torch.ones([src_h, src_w]) + human_mask1 = torch.ones([src_h, src_w]) + human_mask2 = torch.ones([src_h, src_w]) + human_masks = [human_mask1, human_mask2, background_mask] + elif HUMAN_NUMBER==2: + if bbox != None: + assert len(bbox) == HUMAN_NUMBER, f"The number of target bbox should be the same with cond_audio" + background_mask = torch.zeros([src_h, src_w]) + for _, person_bbox in bbox.items(): + x_min, y_min, x_max, y_max = person_bbox + human_mask = torch.zeros([src_h, src_w]) + human_mask[int(x_min):int(x_max), int(y_min):int(y_max)] = 1 + background_mask += human_mask + human_masks.append(human_mask) + else: + x_min, x_max = int(src_h * face_scale), int(src_h * (1 - face_scale)) + background_mask = torch.zeros([src_h, src_w]) + background_mask = torch.zeros([src_h, src_w]) + human_mask1 = torch.zeros([src_h, src_w]) + human_mask2 = torch.zeros([src_h, src_w]) + lefty_min, lefty_max = int((src_w//2) * face_scale), int((src_w//2) * (1 - face_scale)) + righty_min, righty_max = int((src_w//2) * face_scale + (src_w//2)), int((src_w//2) * (1 - face_scale) + (src_w//2)) + human_mask1[x_min:x_max, lefty_min:lefty_max] = 1 + human_mask2[x_min:x_max, righty_min:righty_max] = 1 + background_mask += human_mask1 + background_mask += human_mask2 
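+            # without explicit bboxes, speaker 1 is assumed to occupy the left half
+            # of the frame and speaker 2 the right half, shrunk by face_scale margins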
+ human_masks = [human_mask1, human_mask2] + background_mask = torch.where(background_mask > 0, torch.tensor(0), torch.tensor(1)) + human_masks.append(background_mask) + + ref_target_masks = torch.stack(human_masks, dim=0) #.to(self.device) + # resize and centercrop for ref_target_masks + # ref_target_masks = resize_and_centercrop(ref_target_masks, (target_h, target_w)) + N_h, N_w = lat_h // 2, lat_w // 2 + token_ref_target_masks = F.interpolate(ref_target_masks.unsqueeze(0), size=(N_h, N_w), mode='nearest').squeeze() + token_ref_target_masks = (token_ref_target_masks > 0) + token_ref_target_masks = token_ref_target_masks.float() #.to(self.device) + + token_ref_target_masks = token_ref_target_masks.view(token_ref_target_masks.shape[0], -1) + + return token_ref_target_masks \ No newline at end of file diff --git a/wan/multitalk/multitalk_model.py b/wan/multitalk/multitalk_model.py new file mode 100644 index 0000000..25af83c --- /dev/null +++ b/wan/multitalk/multitalk_model.py @@ -0,0 +1,799 @@ +# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. +import math +import numpy as np +import os +import torch +import torch.cuda.amp as amp +import torch.nn as nn +import torch.nn.functional as F + +from einops import rearrange +from diffusers import ModelMixin +from diffusers.configuration_utils import ConfigMixin, register_to_config + +from .attention import flash_attention, SingleStreamMutiAttention +from ..utils.multitalk_utils import get_attn_map_with_target + +__all__ = ['WanModel'] + + + +def sinusoidal_embedding_1d(dim, position): + # preprocess + assert dim % 2 == 0 + half = dim // 2 + position = position.type(torch.float64) + + # calculation + sinusoid = torch.outer( + position, torch.pow(10000, -torch.arange(half).to(position).div(half))) + x = torch.cat([torch.cos(sinusoid), torch.sin(sinusoid)], dim=1) + return x + + +@amp.autocast(enabled=False) +def rope_params(max_seq_len, dim, theta=10000): + + assert dim % 2 == 0 + freqs = torch.outer( + torch.arange(max_seq_len), + 1.0 / torch.pow(theta, + torch.arange(0, dim, 2).to(torch.float64).div(dim))) + freqs = torch.polar(torch.ones_like(freqs), freqs) + return freqs + + +@amp.autocast(enabled=False) +def rope_apply(x, grid_sizes, freqs): + s, n, c = x.size(1), x.size(2), x.size(3) // 2 + + freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1) + + output = [] + for i, (f, h, w) in enumerate(grid_sizes.tolist()): + seq_len = f * h * w + + x_i = torch.view_as_complex(x[i, :s].to(torch.float64).reshape( + s, n, -1, 2)) + freqs_i = torch.cat([ + freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1), + freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1), + freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1) + ], + dim=-1).reshape(seq_len, 1, -1) + freqs_i = freqs_i.to(device=x_i.device) + x_i = torch.view_as_real(x_i * freqs_i).flatten(2) + x_i = torch.cat([x_i, x[i, seq_len:]]) + + output.append(x_i) + return torch.stack(output).float() + + +class WanRMSNorm(nn.Module): + + def __init__(self, dim, eps=1e-5): + super().__init__() + self.dim = dim + self.eps = eps + self.weight = nn.Parameter(torch.ones(dim)) + + def forward(self, x): + r""" + Args: + x(Tensor): Shape [B, L, C] + """ + return self._norm(x.float()).type_as(x) * self.weight + + def _norm(self, x): + return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps) + + +class WanLayerNorm(nn.LayerNorm): + + def __init__(self, dim, eps=1e-6, elementwise_affine=False): + super().__init__(dim, elementwise_affine=elementwise_affine, eps=eps) + + def 
forward(self, inputs: torch.Tensor) -> torch.Tensor: + origin_dtype = inputs.dtype + out = F.layer_norm( + inputs.float(), + self.normalized_shape, + None if self.weight is None else self.weight.float(), + None if self.bias is None else self.bias.float() , + self.eps + ).to(origin_dtype) + return out + + +class WanSelfAttention(nn.Module): + + def __init__(self, + dim, + num_heads, + window_size=(-1, -1), + qk_norm=True, + eps=1e-6): + assert dim % num_heads == 0 + super().__init__() + self.dim = dim + self.num_heads = num_heads + self.head_dim = dim // num_heads + self.window_size = window_size + self.qk_norm = qk_norm + self.eps = eps + + # layers + self.q = nn.Linear(dim, dim) + self.k = nn.Linear(dim, dim) + self.v = nn.Linear(dim, dim) + self.o = nn.Linear(dim, dim) + self.norm_q = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity() + self.norm_k = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity() + + def forward(self, x, seq_lens, grid_sizes, freqs, ref_target_masks=None): + b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim + + # query, key, value function + def qkv_fn(x): + q = self.norm_q(self.q(x)).view(b, s, n, d) + k = self.norm_k(self.k(x)).view(b, s, n, d) + v = self.v(x).view(b, s, n, d) + return q, k, v + q, k, v = qkv_fn(x) + + q = rope_apply(q, grid_sizes, freqs) + k = rope_apply(k, grid_sizes, freqs) + + + x = flash_attention( + q=q, + k=k, + v=v, + k_lens=seq_lens, + window_size=self.window_size + ).type_as(x) + + # output + x = x.flatten(2) + x = self.o(x) + with torch.no_grad(): + x_ref_attn_map = get_attn_map_with_target(q.type_as(x), k.type_as(x), grid_sizes[0], + ref_target_masks=ref_target_masks) + + return x, x_ref_attn_map + + +class WanI2VCrossAttention(WanSelfAttention): + + def __init__(self, + dim, + num_heads, + window_size=(-1, -1), + qk_norm=True, + eps=1e-6): + super().__init__(dim, num_heads, window_size, qk_norm, eps) + + self.k_img = nn.Linear(dim, dim) + self.v_img = nn.Linear(dim, dim) + self.norm_k_img = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity() + + def forward(self, x, context, context_lens): + context_img = context[:, :257] + context = context[:, 257:] + b, n, d = x.size(0), self.num_heads, self.head_dim + + # compute query, key, value + q = self.norm_q(self.q(x)).view(b, -1, n, d) + k = self.norm_k(self.k(context)).view(b, -1, n, d) + v = self.v(context).view(b, -1, n, d) + k_img = self.norm_k_img(self.k_img(context_img)).view(b, -1, n, d) + v_img = self.v_img(context_img).view(b, -1, n, d) + img_x = flash_attention(q, k_img, v_img, k_lens=None) + # compute attention + x = flash_attention(q, k, v, k_lens=context_lens) + + # output + x = x.flatten(2) + img_x = img_x.flatten(2) + x = x + img_x + x = self.o(x) + return x + + +class WanAttentionBlock(nn.Module): + + def __init__(self, + cross_attn_type, + dim, + ffn_dim, + num_heads, + window_size=(-1, -1), + qk_norm=True, + cross_attn_norm=False, + eps=1e-6, + output_dim=768, + norm_input_visual=True, + class_range=24, + class_interval=4): + super().__init__() + self.dim = dim + self.ffn_dim = ffn_dim + self.num_heads = num_heads + self.window_size = window_size + self.qk_norm = qk_norm + self.cross_attn_norm = cross_attn_norm + self.eps = eps + + # layers + self.norm1 = WanLayerNorm(dim, eps) + self.self_attn = WanSelfAttention(dim, num_heads, window_size, qk_norm, eps) + self.norm3 = WanLayerNorm( + dim, eps, + elementwise_affine=True) if cross_attn_norm else nn.Identity() + self.cross_attn = WanI2VCrossAttention(dim, + num_heads, + (-1, -1), + qk_norm, + eps) + 
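+        # text/image cross-attention effectively uses a full attention window,
+        # independent of the self-attention window_size above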
self.norm2 = WanLayerNorm(dim, eps) + self.ffn = nn.Sequential( + nn.Linear(dim, ffn_dim), nn.GELU(approximate='tanh'), + nn.Linear(ffn_dim, dim)) + + # modulation + self.modulation = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5) + + # init audio module + self.audio_cross_attn = SingleStreamMutiAttention( + dim=dim, + encoder_hidden_states_dim=output_dim, + num_heads=num_heads, + qk_norm=False, + qkv_bias=True, + eps=eps, + norm_layer=WanRMSNorm, + class_range=class_range, + class_interval=class_interval + ) + self.norm_x = WanLayerNorm(dim, eps, elementwise_affine=True) if norm_input_visual else nn.Identity() + + + def forward( + self, + x, + e, + seq_lens, + grid_sizes, + freqs, + context, + context_lens, + audio_embedding=None, + ref_target_masks=None, + human_num=None, + ): + + dtype = x.dtype + assert e.dtype == torch.float32 + with amp.autocast(dtype=torch.float32): + e = (self.modulation.to(e.device) + e).chunk(6, dim=1) + assert e[0].dtype == torch.float32 + + # self-attention + y, x_ref_attn_map = self.self_attn( + (self.norm1(x).float() * (1 + e[1]) + e[0]).type_as(x), seq_lens, grid_sizes, + freqs, ref_target_masks=ref_target_masks) + with amp.autocast(dtype=torch.float32): + x = x + y * e[2] + + x = x.to(dtype) + + # cross-attention of text + x = x + self.cross_attn(self.norm3(x), context, context_lens) + + # cross attn of audio + x_a = self.audio_cross_attn(self.norm_x(x), encoder_hidden_states=audio_embedding, + shape=grid_sizes[0], x_ref_attn_map=x_ref_attn_map, human_num=human_num) + x = x + x_a + + y = self.ffn((self.norm2(x).float() * (1 + e[4]) + e[3]).to(dtype)) + with amp.autocast(dtype=torch.float32): + x = x + y * e[5] + + + x = x.to(dtype) + + return x + + +class Head(nn.Module): + + def __init__(self, dim, out_dim, patch_size, eps=1e-6): + super().__init__() + self.dim = dim + self.out_dim = out_dim + self.patch_size = patch_size + self.eps = eps + + # layers + out_dim = math.prod(patch_size) * out_dim + self.norm = WanLayerNorm(dim, eps) + self.head = nn.Linear(dim, out_dim) + + # modulation + self.modulation = nn.Parameter(torch.randn(1, 2, dim) / dim**0.5) + + def forward(self, x, e): + r""" + Args: + x(Tensor): Shape [B, L1, C] + e(Tensor): Shape [B, C] + """ + assert e.dtype == torch.float32 + with amp.autocast(dtype=torch.float32): + e = (self.modulation.to(e.device) + e.unsqueeze(1)).chunk(2, dim=1) + x = (self.head(self.norm(x) * (1 + e[1]) + e[0])) + return x + + +class MLPProj(torch.nn.Module): + + def __init__(self, in_dim, out_dim): + super().__init__() + + self.proj = torch.nn.Sequential( + torch.nn.LayerNorm(in_dim), torch.nn.Linear(in_dim, in_dim), + torch.nn.GELU(), torch.nn.Linear(in_dim, out_dim), + torch.nn.LayerNorm(out_dim)) + + def forward(self, image_embeds): + clip_extra_context_tokens = self.proj(image_embeds) + return clip_extra_context_tokens + + +class AudioProjModel(ModelMixin, ConfigMixin): + def __init__( + self, + seq_len=5, + seq_len_vf=12, + blocks=12, + channels=768, + intermediate_dim=512, + output_dim=768, + context_tokens=32, + norm_output_audio=False, + ): + super().__init__() + + self.seq_len = seq_len + self.blocks = blocks + self.channels = channels + self.input_dim = seq_len * blocks * channels + self.input_dim_vf = seq_len_vf * blocks * channels + self.intermediate_dim = intermediate_dim + self.context_tokens = context_tokens + self.output_dim = output_dim + + # define multiple linear layers + self.proj1 = nn.Linear(self.input_dim, intermediate_dim) + self.proj1_vf = nn.Linear(self.input_dim_vf, intermediate_dim) + 
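+        # proj1 maps the first video frame's audio window (seq_len features);
+        # proj1_vf maps the remaining frames, whose window is enlarged by the VAE
+        # temporal downsampling (seq_len_vf = audio_window + vae_scale - 1)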
self.proj2 = nn.Linear(intermediate_dim, intermediate_dim) + self.proj3 = nn.Linear(intermediate_dim, context_tokens * output_dim) + self.norm = nn.LayerNorm(output_dim) if norm_output_audio else nn.Identity() + + def forward(self, audio_embeds, audio_embeds_vf): + video_length = audio_embeds.shape[1] + audio_embeds_vf.shape[1] + B, _, _, S, C = audio_embeds.shape + + # process audio of first frame + audio_embeds = rearrange(audio_embeds, "bz f w b c -> (bz f) w b c") + batch_size, window_size, blocks, channels = audio_embeds.shape + audio_embeds = audio_embeds.view(batch_size, window_size * blocks * channels) + + # process audio of latter frame + audio_embeds_vf = rearrange(audio_embeds_vf, "bz f w b c -> (bz f) w b c") + batch_size_vf, window_size_vf, blocks_vf, channels_vf = audio_embeds_vf.shape + audio_embeds_vf = audio_embeds_vf.view(batch_size_vf, window_size_vf * blocks_vf * channels_vf) + + # first projection + audio_embeds = torch.relu(self.proj1(audio_embeds)) + audio_embeds_vf = torch.relu(self.proj1_vf(audio_embeds_vf)) + audio_embeds = rearrange(audio_embeds, "(bz f) c -> bz f c", bz=B) + audio_embeds_vf = rearrange(audio_embeds_vf, "(bz f) c -> bz f c", bz=B) + audio_embeds_c = torch.concat([audio_embeds, audio_embeds_vf], dim=1) + batch_size_c, N_t, C_a = audio_embeds_c.shape + audio_embeds_c = audio_embeds_c.view(batch_size_c*N_t, C_a) + + # second projection + audio_embeds_c = torch.relu(self.proj2(audio_embeds_c)) + + context_tokens = self.proj3(audio_embeds_c).reshape(batch_size_c*N_t, self.context_tokens, self.output_dim) + + # normalization and reshape + context_tokens = self.norm(context_tokens) + context_tokens = rearrange(context_tokens, "(bz f) m c -> bz f m c", f=video_length) + + return context_tokens + + +class WanModel(ModelMixin, ConfigMixin): + r""" + Wan diffusion backbone supporting both text-to-video and image-to-video. + """ + + ignore_for_config = [ + 'patch_size', 'cross_attn_norm', 'qk_norm', 'text_dim', 'window_size' + ] + _no_split_modules = ['WanAttentionBlock'] + + @register_to_config + def __init__(self, + model_type='i2v', + patch_size=(1, 2, 2), + text_len=512, + in_dim=16, + dim=2048, + ffn_dim=8192, + freq_dim=256, + text_dim=4096, + out_dim=16, + num_heads=16, + num_layers=32, + window_size=(-1, -1), + qk_norm=True, + cross_attn_norm=True, + eps=1e-6, + # audio params + audio_window=5, + intermediate_dim=512, + output_dim=768, + context_tokens=32, + vae_scale=4, # vae timedownsample scale + + norm_input_visual=True, + norm_output_audio=True): + super().__init__() + + assert model_type == 'i2v', 'MultiTalk model requires your model_type is i2v.' 
+ self.model_type = model_type + + self.patch_size = patch_size + self.text_len = text_len + self.in_dim = in_dim + self.dim = dim + self.ffn_dim = ffn_dim + self.freq_dim = freq_dim + self.text_dim = text_dim + self.out_dim = out_dim + self.num_heads = num_heads + self.num_layers = num_layers + self.window_size = window_size + self.qk_norm = qk_norm + self.cross_attn_norm = cross_attn_norm + self.eps = eps + + + self.norm_output_audio = norm_output_audio + self.audio_window = audio_window + self.intermediate_dim = intermediate_dim + self.vae_scale = vae_scale + + + # embeddings + self.patch_embedding = nn.Conv3d( + in_dim, dim, kernel_size=patch_size, stride=patch_size) + self.text_embedding = nn.Sequential( + nn.Linear(text_dim, dim), nn.GELU(approximate='tanh'), + nn.Linear(dim, dim)) + + self.time_embedding = nn.Sequential( + nn.Linear(freq_dim, dim), nn.SiLU(), nn.Linear(dim, dim)) + self.time_projection = nn.Sequential(nn.SiLU(), nn.Linear(dim, dim * 6)) + + # blocks + cross_attn_type = 'i2v_cross_attn' + self.blocks = nn.ModuleList([ + WanAttentionBlock(cross_attn_type, dim, ffn_dim, num_heads, + window_size, qk_norm, cross_attn_norm, eps, + output_dim=output_dim, norm_input_visual=norm_input_visual) + for _ in range(num_layers) + ]) + + # head + self.head = Head(dim, out_dim, patch_size, eps) + + assert (dim % num_heads) == 0 and (dim // num_heads) % 2 == 0 + d = dim // num_heads + self.freqs = torch.cat([ + rope_params(1024, d - 4 * (d // 6)), + rope_params(1024, 2 * (d // 6)), + rope_params(1024, 2 * (d // 6)) + ], + dim=1) + + if model_type == 'i2v': + self.img_emb = MLPProj(1280, dim) + else: + raise NotImplementedError('Not supported model type.') + + # init audio adapter + self.audio_proj = AudioProjModel( + seq_len=audio_window, + seq_len_vf=audio_window+vae_scale-1, + intermediate_dim=intermediate_dim, + output_dim=output_dim, + context_tokens=context_tokens, + norm_output_audio=norm_output_audio, + ) + + + # initialize weights + self.init_weights() + + def teacache_init( + self, + use_ret_steps=True, + teacache_thresh=0.2, + sample_steps=40, + model_scale='multitalk-480', + ): + print("teacache_init") + self.enable_teacache = True + + self.__class__.cnt = 0 + self.__class__.num_steps = sample_steps*3 + self.__class__.teacache_thresh = teacache_thresh + self.__class__.accumulated_rel_l1_distance_even = 0 + self.__class__.accumulated_rel_l1_distance_odd = 0 + self.__class__.previous_e0_even = None + self.__class__.previous_e0_odd = None + self.__class__.previous_residual_even = None + self.__class__.previous_residual_odd = None + self.__class__.use_ret_steps = use_ret_steps + + if use_ret_steps: + if model_scale == 'multitalk-480': + self.__class__.coefficients = [ 2.57151496e+05, -3.54229917e+04, 1.40286849e+03, -1.35890334e+01, 1.32517977e-01] + if model_scale == 'multitalk-720': + self.__class__.coefficients = [ 8.10705460e+03, 2.13393892e+03, -3.72934672e+02, 1.66203073e+01, -4.17769401e-02] + self.__class__.ret_steps = 5*3 + self.__class__.cutoff_steps = sample_steps*3 + else: + if model_scale == 'multitalk-480': + self.__class__.coefficients = [-3.02331670e+02, 2.23948934e+02, -5.25463970e+01, 5.87348440e+00, -2.01973289e-01] + + if model_scale == 'multitalk-720': + self.__class__.coefficients = [-114.36346466, 65.26524496, -18.82220707, 4.91518089, -0.23412683] + self.__class__.ret_steps = 1*3 + self.__class__.cutoff_steps = sample_steps*3 - 3 + print("teacache_init done") + + def disable_teacache(self): + self.enable_teacache = False + + def forward( + self, + x, + 
t, + context, + seq_len, + clip_fea=None, + y=None, + audio=None, + ref_target_masks=None, + ): + assert clip_fea is not None and y is not None + + _, T, H, W = x[0].shape + N_t = T // self.patch_size[0] + N_h = H // self.patch_size[1] + N_w = W // self.patch_size[2] + + if y is not None: + x = [torch.cat([u, v], dim=0) for u, v in zip(x, y)] + x[0] = x[0].to(context[0].dtype) + + # embeddings + x = [self.patch_embedding(u.unsqueeze(0)) for u in x] + grid_sizes = torch.stack( + [torch.tensor(u.shape[2:], dtype=torch.long) for u in x]) + x = [u.flatten(2).transpose(1, 2) for u in x] + seq_lens = torch.tensor([u.size(1) for u in x], dtype=torch.long) + assert seq_lens.max() <= seq_len + x = torch.cat([ + torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))], + dim=1) for u in x + ]) + + # time embeddings + with amp.autocast(dtype=torch.float32): + e = self.time_embedding( + sinusoidal_embedding_1d(self.freq_dim, t).float()) + e0 = self.time_projection(e).unflatten(1, (6, self.dim)) + assert e.dtype == torch.float32 and e0.dtype == torch.float32 + + # text embedding + context_lens = None + context = self.text_embedding( + torch.stack([ + torch.cat( + [u, u.new_zeros(self.text_len - u.size(0), u.size(1))]) + for u in context + ])) + + # clip embedding + if clip_fea is not None: + context_clip = self.img_emb(clip_fea) + context = torch.concat([context_clip, context], dim=1).to(x.dtype) + + + audio_cond = audio.to(device=x.device, dtype=x.dtype) + first_frame_audio_emb_s = audio_cond[:, :1, ...] + latter_frame_audio_emb = audio_cond[:, 1:, ...] + latter_frame_audio_emb = rearrange(latter_frame_audio_emb, "b (n_t n) w s c -> b n_t n w s c", n=self.vae_scale) + middle_index = self.audio_window // 2 + latter_first_frame_audio_emb = latter_frame_audio_emb[:, :, :1, :middle_index+1, ...] + latter_first_frame_audio_emb = rearrange(latter_first_frame_audio_emb, "b n_t n w s c -> b n_t (n w) s c") + latter_last_frame_audio_emb = latter_frame_audio_emb[:, :, -1:, middle_index:, ...] + latter_last_frame_audio_emb = rearrange(latter_last_frame_audio_emb, "b n_t n w s c -> b n_t (n w) s c") + latter_middle_frame_audio_emb = latter_frame_audio_emb[:, :, 1:-1, middle_index:middle_index+1, ...] 
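+        # middle raw frames keep only the centre sample of their window; together
+        # with the first/last slices above this mirrors get_window_audio_embeddings
+        # in multitalk.py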
+ latter_middle_frame_audio_emb = rearrange(latter_middle_frame_audio_emb, "b n_t n w s c -> b n_t (n w) s c") + latter_frame_audio_emb_s = torch.concat([latter_first_frame_audio_emb, latter_middle_frame_audio_emb, latter_last_frame_audio_emb], dim=2) + audio_embedding = self.audio_proj(first_frame_audio_emb_s, latter_frame_audio_emb_s) + human_num = len(audio_embedding) + audio_embedding = torch.concat(audio_embedding.split(1), dim=2).to(x.dtype) + + + # convert ref_target_masks to token_ref_target_masks + if ref_target_masks is not None: + ref_target_masks = ref_target_masks.unsqueeze(0).to(torch.float32) + token_ref_target_masks = nn.functional.interpolate(ref_target_masks, size=(N_h, N_w), mode='nearest') + token_ref_target_masks = token_ref_target_masks.squeeze(0) + token_ref_target_masks = (token_ref_target_masks > 0) + token_ref_target_masks = token_ref_target_masks.view(token_ref_target_masks.shape[0], -1) + token_ref_target_masks = token_ref_target_masks.to(x.dtype) + + # teacache + if self.enable_teacache: + modulated_inp = e0 if self.use_ret_steps else e + if self.cnt%3==0: # cond + if self.cnt < self.ret_steps or self.cnt >= self.cutoff_steps: + should_calc_cond = True + self.accumulated_rel_l1_distance_cond = 0 + else: + rescale_func = np.poly1d(self.coefficients) + self.accumulated_rel_l1_distance_cond += rescale_func(((modulated_inp-self.previous_e0_cond).abs().mean() / self.previous_e0_cond.abs().mean()).cpu().item()) + if self.accumulated_rel_l1_distance_cond < self.teacache_thresh: + should_calc_cond = False + else: + should_calc_cond = True + self.accumulated_rel_l1_distance_cond = 0 + self.previous_e0_cond = modulated_inp.clone() + elif self.cnt%3==1: # drop_text + if self.cnt < self.ret_steps or self.cnt >= self.cutoff_steps: + should_calc_drop_text = True + self.accumulated_rel_l1_distance_drop_text = 0 + else: + rescale_func = np.poly1d(self.coefficients) + self.accumulated_rel_l1_distance_drop_text += rescale_func(((modulated_inp-self.previous_e0_drop_text).abs().mean() / self.previous_e0_drop_text.abs().mean()).cpu().item()) + if self.accumulated_rel_l1_distance_drop_text < self.teacache_thresh: + should_calc_drop_text = False + else: + should_calc_drop_text = True + self.accumulated_rel_l1_distance_drop_text = 0 + self.previous_e0_drop_text = modulated_inp.clone() + else: # uncond + if self.cnt < self.ret_steps or self.cnt >= self.cutoff_steps: + should_calc_uncond = True + self.accumulated_rel_l1_distance_uncond = 0 + else: + rescale_func = np.poly1d(self.coefficients) + self.accumulated_rel_l1_distance_uncond += rescale_func(((modulated_inp-self.previous_e0_uncond).abs().mean() / self.previous_e0_uncond.abs().mean()).cpu().item()) + if self.accumulated_rel_l1_distance_uncond < self.teacache_thresh: + should_calc_uncond = False + else: + should_calc_uncond = True + self.accumulated_rel_l1_distance_uncond = 0 + self.previous_e0_uncond = modulated_inp.clone() + + # arguments + kwargs = dict( + e=e0, + seq_lens=seq_lens, + grid_sizes=grid_sizes, + freqs=self.freqs, + context=context, + context_lens=context_lens, + audio_embedding=audio_embedding, + ref_target_masks=token_ref_target_masks, + human_num=human_num, + ) + if self.enable_teacache: + if self.cnt%3==0: + if not should_calc_cond: + x += self.previous_residual_cond + else: + ori_x = x.clone() + for block in self.blocks: + x = block(x, **kwargs) + self.previous_residual_cond = x - ori_x + elif self.cnt%3==1: + if not should_calc_drop_text: + x += self.previous_residual_drop_text + else: + ori_x = x.clone() + 
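+                    # full recompute: run every block and cache the residual so
+                    # later drop_text steps can reuse it whenever the TeaCache
+                    # threshold is not exceeded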
for block in self.blocks: + x = block(x, **kwargs) + self.previous_residual_drop_text = x - ori_x + else: + if not should_calc_uncond: + x += self.previous_residual_uncond + else: + ori_x = x.clone() + for block in self.blocks: + x = block(x, **kwargs) + self.previous_residual_uncond = x - ori_x + else: + for block in self.blocks: + x = block(x, **kwargs) + + # head + x = self.head(x, e) + + # unpatchify + x = self.unpatchify(x, grid_sizes) + if self.enable_teacache: + self.cnt += 1 + if self.cnt >= self.num_steps: + self.cnt = 0 + + return torch.stack(x).float() + + + def unpatchify(self, x, grid_sizes): + r""" + Reconstruct video tensors from patch embeddings. + + Args: + x (List[Tensor]): + List of patchified features, each with shape [L, C_out * prod(patch_size)] + grid_sizes (Tensor): + Original spatial-temporal grid dimensions before patching, + shape [B, 3] (3 dimensions correspond to F_patches, H_patches, W_patches) + + Returns: + List[Tensor]: + Reconstructed video tensors with shape [C_out, F, H / 8, W / 8] + """ + + c = self.out_dim + out = [] + for u, v in zip(x, grid_sizes.tolist()): + u = u[:math.prod(v)].view(*v, *self.patch_size, c) + u = torch.einsum('fhwpqrc->cfphqwr', u) + u = u.reshape(c, *[i * j for i, j in zip(v, self.patch_size)]) + out.append(u) + return out + + def init_weights(self): + r""" + Initialize model parameters using Xavier initialization. + """ + + # basic init + for m in self.modules(): + if isinstance(m, nn.Linear): + nn.init.xavier_uniform_(m.weight) + if m.bias is not None: + nn.init.zeros_(m.bias) + + # init embeddings + nn.init.xavier_uniform_(self.patch_embedding.weight.flatten(1)) + for m in self.text_embedding.modules(): + if isinstance(m, nn.Linear): + nn.init.normal_(m.weight, std=.02) + for m in self.time_embedding.modules(): + if isinstance(m, nn.Linear): + nn.init.normal_(m.weight, std=.02) + + # init output layer + nn.init.zeros_(self.head.head.weight) \ No newline at end of file diff --git a/wan/multitalk/multitalk_utils.py b/wan/multitalk/multitalk_utils.py new file mode 100644 index 0000000..4054361 --- /dev/null +++ b/wan/multitalk/multitalk_utils.py @@ -0,0 +1,353 @@ +import os +from einops import rearrange + +import torch +import torch.nn as nn + +from einops import rearrange, repeat +from functools import lru_cache +import imageio +import uuid +from tqdm import tqdm +import numpy as np +import subprocess +import soundfile as sf +import torchvision +import binascii +import os.path as osp + + +VID_EXTENSIONS = (".mp4", ".avi", ".mov", ".mkv") +ASPECT_RATIO_627 = { + '0.26': ([320, 1216], 1), '0.38': ([384, 1024], 1), '0.50': ([448, 896], 1), '0.67': ([512, 768], 1), + '0.82': ([576, 704], 1), '1.00': ([640, 640], 1), '1.22': ([704, 576], 1), '1.50': ([768, 512], 1), + '1.86': ([832, 448], 1), '2.00': ([896, 448], 1), '2.50': ([960, 384], 1), '2.83': ([1088, 384], 1), + '3.60': ([1152, 320], 1), '3.80': ([1216, 320], 1), '4.00': ([1280, 320], 1)} + + +ASPECT_RATIO_960 = { + '0.22': ([448, 2048], 1), '0.29': ([512, 1792], 1), '0.36': ([576, 1600], 1), '0.45': ([640, 1408], 1), + '0.55': ([704, 1280], 1), '0.63': ([768, 1216], 1), '0.76': ([832, 1088], 1), '0.88': ([896, 1024], 1), + '1.00': ([960, 960], 1), '1.14': ([1024, 896], 1), '1.31': ([1088, 832], 1), '1.50': ([1152, 768], 1), + '1.58': ([1216, 768], 1), '1.82': ([1280, 704], 1), '1.91': ([1344, 704], 1), '2.20': ([1408, 640], 1), + '2.30': ([1472, 640], 1), '2.67': ([1536, 576], 1), '2.89': ([1664, 576], 1), '3.62': ([1856, 512], 1), + '3.75': ([1920, 512], 1)} + + + +def 
torch_gc(): + torch.cuda.empty_cache() + torch.cuda.ipc_collect() + + + +def split_token_counts_and_frame_ids(T, token_frame, world_size, rank): + + S = T * token_frame + split_sizes = [S // world_size + (1 if i < S % world_size else 0) for i in range(world_size)] + start = sum(split_sizes[:rank]) + end = start + split_sizes[rank] + counts = [0] * T + for idx in range(start, end): + t = idx // token_frame + counts[t] += 1 + + counts_filtered = [] + frame_ids = [] + for t, c in enumerate(counts): + if c > 0: + counts_filtered.append(c) + frame_ids.append(t) + return counts_filtered, frame_ids + + +def normalize_and_scale(column, source_range, target_range, epsilon=1e-8): + + source_min, source_max = source_range + new_min, new_max = target_range + + normalized = (column - source_min) / (source_max - source_min + epsilon) + scaled = normalized * (new_max - new_min) + new_min + return scaled + + +# @torch.compile +def calculate_x_ref_attn_map(visual_q, ref_k, ref_target_masks, ref_images_count, mode='mean', attn_bias=None): + + ref_k = ref_k.to(visual_q.dtype).to(visual_q.device) + scale = 1.0 / visual_q.shape[-1] ** 0.5 + visual_q = visual_q * scale + visual_q = visual_q.transpose(1, 2) + ref_k = ref_k.transpose(1, 2) + attn = visual_q @ ref_k.transpose(-2, -1) + + if attn_bias is not None: attn += attn_bias + + x_ref_attn_map_source = attn.softmax(-1) # B, H, x_seqlens, ref_seqlens + + x_ref_attn_maps = [] + ref_target_masks = ref_target_masks.to(visual_q.dtype) + x_ref_attn_map_source = x_ref_attn_map_source.to(visual_q.dtype) + + for class_idx, ref_target_mask in enumerate(ref_target_masks): + ref_target_mask = ref_target_mask[None, None, None, ...] + x_ref_attnmap = x_ref_attn_map_source * ref_target_mask + x_ref_attnmap = x_ref_attnmap.sum(-1) / ref_target_mask.sum() # B, H, x_seqlens, ref_seqlens --> B, H, x_seqlens + x_ref_attnmap = x_ref_attnmap.permute(0, 2, 1) # B, x_seqlens, H + + if mode == 'mean': + x_ref_attnmap = x_ref_attnmap.mean(-1) # B, x_seqlens + elif mode == 'max': + x_ref_attnmap = x_ref_attnmap.max(-1) # B, x_seqlens + + x_ref_attn_maps.append(x_ref_attnmap) + + del attn + del x_ref_attn_map_source + + return torch.concat(x_ref_attn_maps, dim=0) + + +def get_attn_map_with_target(visual_q, ref_k, shape, ref_target_masks=None, split_num=10, ref_images_count = 0): + """Args: + query (torch.tensor): B M H K + key (torch.tensor): B M H K + shape (tuple): (N_t, N_h, N_w) + ref_target_masks: [B, N_h * N_w] + """ + + N_t, N_h, N_w = shape + + x_seqlens = N_h * N_w + ref_k = ref_k[:, :x_seqlens] + if ref_images_count > 0 : + visual_q_shape = visual_q.shape + visual_q = visual_q.reshape(visual_q_shape[0], N_t, -1) + visual_q = visual_q[:, ref_images_count:] + visual_q = visual_q.reshape(visual_q_shape[0], -1, *visual_q_shape[-2:]) + + _, seq_lens, heads, _ = visual_q.shape + class_num, _ = ref_target_masks.shape + x_ref_attn_maps = torch.zeros(class_num, seq_lens).to(visual_q.device).to(visual_q.dtype) + + split_chunk = heads // split_num + + for i in range(split_num): + x_ref_attn_maps_perhead = calculate_x_ref_attn_map(visual_q[:, :, i*split_chunk:(i+1)*split_chunk, :], ref_k[:, :, i*split_chunk:(i+1)*split_chunk, :], ref_target_masks, ref_images_count) + x_ref_attn_maps += x_ref_attn_maps_perhead + + x_ref_attn_maps /= split_num + return x_ref_attn_maps + + +def rotate_half(x): + x = rearrange(x, "... (d r) -> ... d r", r=2) + x1, x2 = x.unbind(dim=-1) + x = torch.stack((-x2, x1), dim=-1) + return rearrange(x, "... d r -> ... 
(d r)") + + +class RotaryPositionalEmbedding1D(nn.Module): + + def __init__(self, + head_dim, + ): + super().__init__() + self.head_dim = head_dim + self.base = 10000 + + + @lru_cache(maxsize=32) + def precompute_freqs_cis_1d(self, pos_indices): + + freqs = 1.0 / (self.base ** (torch.arange(0, self.head_dim, 2)[: (self.head_dim // 2)].float() / self.head_dim)) + freqs = freqs.to(pos_indices.device) + freqs = torch.einsum("..., f -> ... f", pos_indices.float(), freqs) + freqs = repeat(freqs, "... n -> ... (n r)", r=2) + return freqs + + def forward(self, x, pos_indices): + """1D RoPE. + + Args: + query (torch.tensor): [B, head, seq, head_dim] + pos_indices (torch.tensor): [seq,] + Returns: + query with the same shape as input. + """ + freqs_cis = self.precompute_freqs_cis_1d(pos_indices) + + x_ = x.float() + + freqs_cis = freqs_cis.float().to(x.device) + cos, sin = freqs_cis.cos(), freqs_cis.sin() + cos, sin = rearrange(cos, 'n d -> 1 1 n d'), rearrange(sin, 'n d -> 1 1 n d') + x_ = (x_ * cos) + (rotate_half(x_) * sin) + + return x_.type_as(x) + + + +def rand_name(length=8, suffix=''): + name = binascii.b2a_hex(os.urandom(length)).decode('utf-8') + if suffix: + if not suffix.startswith('.'): + suffix = '.' + suffix + name += suffix + return name + +def cache_video(tensor, + save_file=None, + fps=30, + suffix='.mp4', + nrow=8, + normalize=True, + value_range=(-1, 1), + retry=5): + + # cache file + cache_file = osp.join('/tmp', rand_name( + suffix=suffix)) if save_file is None else save_file + + # save to cache + error = None + for _ in range(retry): + + # preprocess + tensor = tensor.clamp(min(value_range), max(value_range)) + tensor = torch.stack([ + torchvision.utils.make_grid( + u, nrow=nrow, normalize=normalize, value_range=value_range) + for u in tensor.unbind(2) + ], + dim=1).permute(1, 2, 3, 0) + tensor = (tensor * 255).type(torch.uint8).cpu() + + # write video + writer = imageio.get_writer(cache_file, fps=fps, codec='libx264', quality=10, ffmpeg_params=["-crf", "10"]) + for frame in tensor.numpy(): + writer.append_data(frame) + writer.close() + return cache_file + +def save_video_ffmpeg(gen_video_samples, save_path, vocal_audio_list, fps=25, quality=5, high_quality_save=False): + + def save_video(frames, save_path, fps, quality=9, ffmpeg_params=None): + writer = imageio.get_writer( + save_path, fps=fps, quality=quality, ffmpeg_params=ffmpeg_params + ) + for frame in tqdm(frames, desc="Saving video"): + frame = np.array(frame) + writer.append_data(frame) + writer.close() + save_path_tmp = save_path + "-temp.mp4" + + if high_quality_save: + cache_video( + tensor=gen_video_samples.unsqueeze(0), + save_file=save_path_tmp, + fps=fps, + nrow=1, + normalize=True, + value_range=(-1, 1) + ) + else: + video_audio = (gen_video_samples+1)/2 # C T H W + video_audio = video_audio.permute(1, 2, 3, 0).cpu().numpy() + video_audio = np.clip(video_audio * 255, 0, 255).astype(np.uint8) # to [0, 255] + save_video(video_audio, save_path_tmp, fps=fps, quality=quality) + + + # crop audio according to video length + _, T, _, _ = gen_video_samples.shape + duration = T / fps + save_path_crop_audio = save_path + "-cropaudio.wav" + final_command = [ + "ffmpeg", + "-i", + vocal_audio_list[0], + "-t", + f'{duration}', + save_path_crop_audio, + ] + subprocess.run(final_command, check=True) + + save_path = save_path + ".mp4" + if high_quality_save: + final_command = [ + "ffmpeg", + "-y", + "-i", save_path_tmp, + "-i", save_path_crop_audio, + "-c:v", "libx264", + "-crf", "0", + "-preset", "veryslow", + "-c:a", "aac", 
+ "-shortest", + save_path, + ] + subprocess.run(final_command, check=True) + os.remove(save_path_tmp) + os.remove(save_path_crop_audio) + else: + final_command = [ + "ffmpeg", + "-y", + "-i", + save_path_tmp, + "-i", + save_path_crop_audio, + "-c:v", + "libx264", + "-c:a", + "aac", + "-shortest", + save_path, + ] + subprocess.run(final_command, check=True) + os.remove(save_path_tmp) + os.remove(save_path_crop_audio) + + +class MomentumBuffer: + def __init__(self, momentum: float): + self.momentum = momentum + self.running_average = 0 + + def update(self, update_value: torch.Tensor): + new_average = self.momentum * self.running_average + self.running_average = update_value + new_average + + + +def project( + v0: torch.Tensor, # [B, C, T, H, W] + v1: torch.Tensor, # [B, C, T, H, W] + ): + dtype = v0.dtype + v0, v1 = v0.double(), v1.double() + v1 = torch.nn.functional.normalize(v1, dim=[-1, -2, -3, -4]) + v0_parallel = (v0 * v1).sum(dim=[-1, -2, -3, -4], keepdim=True) * v1 + v0_orthogonal = v0 - v0_parallel + return v0_parallel.to(dtype), v0_orthogonal.to(dtype) + + +def adaptive_projected_guidance( + diff: torch.Tensor, # [B, C, T, H, W] + pred_cond: torch.Tensor, # [B, C, T, H, W] + momentum_buffer: MomentumBuffer = None, + eta: float = 0.0, + norm_threshold: float = 55, + ): + if momentum_buffer is not None: + momentum_buffer.update(diff) + diff = momentum_buffer.running_average + if norm_threshold > 0: + ones = torch.ones_like(diff) + diff_norm = diff.norm(p=2, dim=[-1, -2, -3, -4], keepdim=True) + print(f"diff_norm: {diff_norm}") + scale_factor = torch.minimum(ones, norm_threshold / diff_norm) + diff = diff * scale_factor + diff_parallel, diff_orthogonal = project(diff, pred_cond) + normalized_update = diff_orthogonal + eta * diff_parallel + return normalized_update \ No newline at end of file diff --git a/wan/multitalk/torch_utils.py b/wan/multitalk/torch_utils.py new file mode 100644 index 0000000..caa40ea --- /dev/null +++ b/wan/multitalk/torch_utils.py @@ -0,0 +1,20 @@ +import torch +import torch.nn.functional as F + + +def get_mask_from_lengths(lengths, max_len=None): + lengths = lengths.to(torch.long) + if max_len is None: + max_len = torch.max(lengths).item() + + ids = torch.arange(0, max_len).unsqueeze(0).expand(lengths.shape[0], -1).to(lengths.device) + mask = ids < lengths.unsqueeze(1).expand(-1, max_len) + + return mask + + +def linear_interpolation(features, seq_len): + features = features.transpose(1, 2) + output_features = F.interpolate(features, size=seq_len, align_corners=True, mode='linear') + return output_features.transpose(1, 2) + diff --git a/wan/multitalk/wav2vec2.py b/wan/multitalk/wav2vec2.py new file mode 100644 index 0000000..5ec9c2b --- /dev/null +++ b/wan/multitalk/wav2vec2.py @@ -0,0 +1,125 @@ +from transformers import Wav2Vec2Config, Wav2Vec2Model +from transformers.modeling_outputs import BaseModelOutput + +from .torch_utils import linear_interpolation + +# the implementation of Wav2Vec2Model is borrowed from +# https://github.com/huggingface/transformers/blob/HEAD/src/transformers/models/wav2vec2/modeling_wav2vec2.py +# initialize our encoder with the pre-trained wav2vec 2.0 weights. 
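+# This subclass keeps the upstream wav2vec2 architecture but resamples the CNN
+# features to `seq_len` (one vector per video frame) via linear_interpolation before
+# the transformer encoder; get_embedding() in multitalk.py then stacks the returned
+# hidden states into the per-frame audio embedding.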
+class Wav2Vec2Model(Wav2Vec2Model): + def __init__(self, config: Wav2Vec2Config): + super().__init__(config) + + def forward( + self, + input_values, + seq_len, + attention_mask=None, + mask_time_indices=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + self.config.output_attentions = True + + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + extract_features = self.feature_extractor(input_values) + extract_features = extract_features.transpose(1, 2) + extract_features = linear_interpolation(extract_features, seq_len=seq_len) + + if attention_mask is not None: + # compute reduced attention_mask corresponding to feature vectors + attention_mask = self._get_feature_vector_attention_mask( + extract_features.shape[1], attention_mask, add_adapter=False + ) + + hidden_states, extract_features = self.feature_projection(extract_features) + hidden_states = self._mask_hidden_states( + hidden_states, mask_time_indices=mask_time_indices, attention_mask=attention_mask + ) + + encoder_outputs = self.encoder( + hidden_states, + attention_mask=attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = encoder_outputs[0] + + if self.adapter is not None: + hidden_states = self.adapter(hidden_states) + + if not return_dict: + return (hidden_states, ) + encoder_outputs[1:] + return BaseModelOutput( + last_hidden_state=hidden_states, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + ) + + + def feature_extract( + self, + input_values, + seq_len, + ): + extract_features = self.feature_extractor(input_values) + extract_features = extract_features.transpose(1, 2) + extract_features = linear_interpolation(extract_features, seq_len=seq_len) + + return extract_features + + def encode( + self, + extract_features, + attention_mask=None, + mask_time_indices=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + self.config.output_attentions = True + + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if attention_mask is not None: + # compute reduced attention_mask corresponding to feature vectors + attention_mask = self._get_feature_vector_attention_mask( + extract_features.shape[1], attention_mask, add_adapter=False + ) + + + hidden_states, extract_features = self.feature_projection(extract_features) + hidden_states = self._mask_hidden_states( + hidden_states, mask_time_indices=mask_time_indices, attention_mask=attention_mask + ) + + encoder_outputs = self.encoder( + hidden_states, + attention_mask=attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = encoder_outputs[0] + + if self.adapter is not None: + hidden_states = self.adapter(hidden_states) + + if not return_dict: + return (hidden_states, ) + encoder_outputs[1:] + return BaseModelOutput( + last_hidden_state=hidden_states, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + )