Mirror of https://github.com/Wan-Video/Wan2.1.git (synced 2025-11-04 14:16:57 +00:00)

Commit 3a8bd05c6e (parent 621687c12a): multitalk files
382  wan/multitalk/attention.py  (new file)
@@ -0,0 +1,382 @@
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import torch
import torch.nn as nn
from einops import rearrange, repeat

from .multitalk_utils import RotaryPositionalEmbedding1D, normalize_and_scale, split_token_counts_and_frame_ids
from wan.modules.attention import pay_attention

import xformers.ops

try:
    import flash_attn_interface
    FLASH_ATTN_3_AVAILABLE = True
except ModuleNotFoundError:
    FLASH_ATTN_3_AVAILABLE = False

try:
    import flash_attn
    FLASH_ATTN_2_AVAILABLE = True
except ModuleNotFoundError:
    FLASH_ATTN_2_AVAILABLE = False

import warnings

__all__ = [
    'flash_attention',
    'attention',
]


def flash_attention(
    q,
    k,
    v,
    q_lens=None,
    k_lens=None,
    dropout_p=0.,
    softmax_scale=None,
    q_scale=None,
    causal=False,
    window_size=(-1, -1),
    deterministic=False,
    dtype=torch.bfloat16,
    version=None,
):
    """
    q: [B, Lq, Nq, C1].
    k: [B, Lk, Nk, C1].
    v: [B, Lk, Nk, C2]. Nq must be divisible by Nk.
    q_lens: [B].
    k_lens: [B].
    dropout_p: float. Dropout probability.
    softmax_scale: float. The scaling of QK^T before applying softmax.
    causal: bool. Whether to apply causal attention mask.
    window_size: (left, right). If not (-1, -1), apply sliding window local attention.
    deterministic: bool. If True, slightly slower and uses more memory.
    dtype: torch.dtype. Apply when dtype of q/k/v is not float16/bfloat16.
    """
    half_dtypes = (torch.float16, torch.bfloat16)
    assert dtype in half_dtypes
    assert q.device.type == 'cuda' and q.size(-1) <= 256

    # params
    b, lq, lk, out_dtype = q.size(0), q.size(1), k.size(1), q.dtype

    def half(x):
        return x if x.dtype in half_dtypes else x.to(dtype)

    # preprocess query
    if q_lens is None:
        q = half(q.flatten(0, 1))
        q_lens = torch.tensor(
            [lq] * b, dtype=torch.int32).to(
                device=q.device, non_blocking=True)
    else:
        q = half(torch.cat([u[:v] for u, v in zip(q, q_lens)]))

    # preprocess key, value
    if k_lens is None:
        k = half(k.flatten(0, 1))
        v = half(v.flatten(0, 1))
        k_lens = torch.tensor(
            [lk] * b, dtype=torch.int32).to(
                device=k.device, non_blocking=True)
    else:
        k = half(torch.cat([u[:v] for u, v in zip(k, k_lens)]))
        v = half(torch.cat([u[:v] for u, v in zip(v, k_lens)]))

    q = q.to(v.dtype)
    k = k.to(v.dtype)

    if q_scale is not None:
        q = q * q_scale

    if version is not None and version == 3 and not FLASH_ATTN_3_AVAILABLE:
        warnings.warn(
            'Flash attention 3 is not available, use flash attention 2 instead.'
        )

    # apply attention
    if (version is None or version == 3) and FLASH_ATTN_3_AVAILABLE:
        # Note: dropout_p, window_size are not supported in FA3 now.
        x = flash_attn_interface.flash_attn_varlen_func(
            q=q,
            k=k,
            v=v,
            cu_seqlens_q=torch.cat([q_lens.new_zeros([1]), q_lens]).cumsum(
                0, dtype=torch.int32).to(q.device, non_blocking=True),
            cu_seqlens_k=torch.cat([k_lens.new_zeros([1]), k_lens]).cumsum(
                0, dtype=torch.int32).to(q.device, non_blocking=True),
            seqused_q=None,
            seqused_k=None,
            max_seqlen_q=lq,
            max_seqlen_k=lk,
            softmax_scale=softmax_scale,
            causal=causal,
            deterministic=deterministic)[0].unflatten(0, (b, lq))
    else:
        assert FLASH_ATTN_2_AVAILABLE
        x = flash_attn.flash_attn_varlen_func(
            q=q,
            k=k,
            v=v,
            cu_seqlens_q=torch.cat([q_lens.new_zeros([1]), q_lens]).cumsum(
                0, dtype=torch.int32).to(q.device, non_blocking=True),
            cu_seqlens_k=torch.cat([k_lens.new_zeros([1]), k_lens]).cumsum(
                0, dtype=torch.int32).to(q.device, non_blocking=True),
            max_seqlen_q=lq,
            max_seqlen_k=lk,
            dropout_p=dropout_p,
            softmax_scale=softmax_scale,
            causal=causal,
            window_size=window_size,
            deterministic=deterministic).unflatten(0, (b, lq))

    # output
    return x.type(out_dtype)


def attention(
    q,
    k,
    v,
    q_lens=None,
    k_lens=None,
    dropout_p=0.,
    softmax_scale=None,
    q_scale=None,
    causal=False,
    window_size=(-1, -1),
    deterministic=False,
    dtype=torch.bfloat16,
    fa_version=None,
):
    if FLASH_ATTN_2_AVAILABLE or FLASH_ATTN_3_AVAILABLE:
        return flash_attention(
            q=q,
            k=k,
            v=v,
            q_lens=q_lens,
            k_lens=k_lens,
            dropout_p=dropout_p,
            softmax_scale=softmax_scale,
            q_scale=q_scale,
            causal=causal,
            window_size=window_size,
            deterministic=deterministic,
            dtype=dtype,
            version=fa_version,
        )
    else:
        if q_lens is not None or k_lens is not None:
            warnings.warn(
                'Padding mask is disabled when using scaled_dot_product_attention. It can have a significant impact on performance.'
            )
        attn_mask = None

        q = q.transpose(1, 2).to(dtype)
        k = k.transpose(1, 2).to(dtype)
        v = v.transpose(1, 2).to(dtype)

        out = torch.nn.functional.scaled_dot_product_attention(
            q, k, v, attn_mask=attn_mask, is_causal=causal, dropout_p=dropout_p)

        out = out.transpose(1, 2).contiguous()
        return out
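
# Usage sketch (illustrative only, not part of the original file): the attention()
# wrapper above takes [B, L, num_heads, head_dim] tensors; it assumes a CUDA device
# when a flash-attention backend is installed, otherwise it falls back to
# torch.nn.functional.scaled_dot_product_attention.
b, lq, lk, heads, head_dim = 1, 256, 256, 8, 64
q = torch.randn(b, lq, heads, head_dim, device='cuda', dtype=torch.bfloat16)
k = torch.randn(b, lk, heads, head_dim, device='cuda', dtype=torch.bfloat16)
v = torch.randn(b, lk, heads, head_dim, device='cuda', dtype=torch.bfloat16)
out = attention(q, k, v)   # -> (b, lq, heads, head_dim)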


class SingleStreamAttention(nn.Module):
    def __init__(
        self,
        dim: int,
        encoder_hidden_states_dim: int,
        num_heads: int,
        qkv_bias: bool,
        qk_norm: bool,
        norm_layer: nn.Module,
        attn_drop: float = 0.0,
        proj_drop: float = 0.0,
        eps: float = 1e-6,
    ) -> None:
        super().__init__()
        assert dim % num_heads == 0, "dim should be divisible by num_heads"
        self.dim = dim
        self.encoder_hidden_states_dim = encoder_hidden_states_dim
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim**-0.5
        self.qk_norm = qk_norm

        self.q_linear = nn.Linear(dim, dim, bias=qkv_bias)

        self.q_norm = norm_layer(self.head_dim, eps=eps) if qk_norm else nn.Identity()
        self.k_norm = norm_layer(self.head_dim, eps=eps) if qk_norm else nn.Identity()

        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        self.kv_linear = nn.Linear(encoder_hidden_states_dim, dim * 2, bias=qkv_bias)

        self.add_q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
        self.add_k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()

    def forward(self, x: torch.Tensor, encoder_hidden_states: torch.Tensor, shape=None, enable_sp=False, kv_seq=None) -> torch.Tensor:
        N_t, N_h, N_w = shape

        x = rearrange(x, "B (N_t S) C -> (B N_t) S C", N_t=N_t)
        # get q for hidden_state
        B, N, C = x.shape
        q = self.q_linear(x)
        q_shape = (B, N, self.num_heads, self.head_dim)
        q = q.view(q_shape).permute((0, 2, 1, 3))

        if self.qk_norm:
            q = self.q_norm(q)

        # get kv from encoder_hidden_states
        _, N_a, _ = encoder_hidden_states.shape
        encoder_kv = self.kv_linear(encoder_hidden_states)
        encoder_kv_shape = (B, N_a, 2, self.num_heads, self.head_dim)
        encoder_kv = encoder_kv.view(encoder_kv_shape).permute((2, 0, 3, 1, 4))
        encoder_k, encoder_v = encoder_kv.unbind(0)

        if self.qk_norm:
            encoder_k = self.add_k_norm(encoder_k)

        q = rearrange(q, "B H M K -> B M H K")
        encoder_k = rearrange(encoder_k, "B H M K -> B M H K")
        encoder_v = rearrange(encoder_v, "B H M K -> B M H K")

        attn_bias = None
        # x = xformers.ops.memory_efficient_attention(q, encoder_k, encoder_v, attn_bias=attn_bias, op=None,)
        qkv_list = [q, encoder_k, encoder_v]
        q = encoder_k = encoder_v = None
        x = pay_attention(qkv_list)
        x = rearrange(x, "B M H K -> B H M K")

        # linear transform
        x_output_shape = (B, N, C)
        x = x.transpose(1, 2)
        x = x.reshape(x_output_shape)
        x = self.proj(x)
        x = self.proj_drop(x)

        # reshape x to origin shape
        x = rearrange(x, "(B N_t) S C -> B (N_t S) C", N_t=N_t)

        return x


class SingleStreamMutiAttention(SingleStreamAttention):
    def __init__(
        self,
        dim: int,
        encoder_hidden_states_dim: int,
        num_heads: int,
        qkv_bias: bool,
        qk_norm: bool,
        norm_layer: nn.Module,
        attn_drop: float = 0.0,
        proj_drop: float = 0.0,
        eps: float = 1e-6,
        class_range: int = 24,
        class_interval: int = 4,
    ) -> None:
        super().__init__(
            dim=dim,
            encoder_hidden_states_dim=encoder_hidden_states_dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_norm=qk_norm,
            norm_layer=norm_layer,
            attn_drop=attn_drop,
            proj_drop=proj_drop,
            eps=eps,
        )
        self.class_interval = class_interval
        self.class_range = class_range
        self.rope_h1 = (0, self.class_interval)
        self.rope_h2 = (self.class_range - self.class_interval, self.class_range)
        self.rope_bak = int(self.class_range // 2)

        self.rope_1d = RotaryPositionalEmbedding1D(self.head_dim)

    def forward(self,
                x: torch.Tensor,
                encoder_hidden_states: torch.Tensor,
                shape=None,
                x_ref_attn_map=None,
                ) -> torch.Tensor:

        encoder_hidden_states = encoder_hidden_states.squeeze(0)
        if x_ref_attn_map is None:
            return super().forward(x, encoder_hidden_states, shape)

        N_t, _, _ = shape
        x = rearrange(x, "B (N_t S) C -> (B N_t) S C", N_t=N_t)

        # get q for hidden_state
        B, N, C = x.shape
        q = self.q_linear(x)
        q_shape = (B, N, self.num_heads, self.head_dim)
        q = q.view(q_shape).permute((0, 2, 1, 3))

        if self.qk_norm:
            q = self.q_norm(q)

        max_values = x_ref_attn_map.max(1).values[:, None, None]
        min_values = x_ref_attn_map.min(1).values[:, None, None]
        max_min_values = torch.cat([max_values, min_values], dim=2)

        human1_max_value, human1_min_value = max_min_values[0, :, 0].max(), max_min_values[0, :, 1].min()
        human2_max_value, human2_min_value = max_min_values[1, :, 0].max(), max_min_values[1, :, 1].min()

        human1 = normalize_and_scale(x_ref_attn_map[0], (human1_min_value, human1_max_value), (self.rope_h1[0], self.rope_h1[1]))
        human2 = normalize_and_scale(x_ref_attn_map[1], (human2_min_value, human2_max_value), (self.rope_h2[0], self.rope_h2[1]))
        back = torch.full((x_ref_attn_map.size(1),), self.rope_bak, dtype=human1.dtype).to(human1.device)
        max_indices = x_ref_attn_map.argmax(dim=0)
        normalized_map = torch.stack([human1, human2, back], dim=1)
        normalized_pos = normalized_map[range(x_ref_attn_map.size(1)), max_indices]  # N

        q = rearrange(q, "(B N_t) H S C -> B H (N_t S) C", N_t=N_t)
        q = self.rope_1d(q, normalized_pos)
        q = rearrange(q, "B H (N_t S) C -> (B N_t) H S C", N_t=N_t)

        _, N_a, _ = encoder_hidden_states.shape
        encoder_kv = self.kv_linear(encoder_hidden_states)
        encoder_kv_shape = (B, N_a, 2, self.num_heads, self.head_dim)
        encoder_kv = encoder_kv.view(encoder_kv_shape).permute((2, 0, 3, 1, 4))
        encoder_k, encoder_v = encoder_kv.unbind(0)

        if self.qk_norm:
            encoder_k = self.add_k_norm(encoder_k)

        per_frame = torch.zeros(N_a, dtype=encoder_k.dtype).to(encoder_k.device)
        per_frame[:per_frame.size(0)//2] = (self.rope_h1[0] + self.rope_h1[1]) / 2
        per_frame[per_frame.size(0)//2:] = (self.rope_h2[0] + self.rope_h2[1]) / 2
        encoder_pos = torch.concat([per_frame] * N_t, dim=0)
        encoder_k = rearrange(encoder_k, "(B N_t) H S C -> B H (N_t S) C", N_t=N_t)
        encoder_k = self.rope_1d(encoder_k, encoder_pos)
        encoder_k = rearrange(encoder_k, "B H (N_t S) C -> (B N_t) H S C", N_t=N_t)

        q = rearrange(q, "B H M K -> B M H K")
        encoder_k = rearrange(encoder_k, "B H M K -> B M H K")
        encoder_v = rearrange(encoder_v, "B H M K -> B M H K")
        # x = xformers.ops.memory_efficient_attention(q, encoder_k, encoder_v, attn_bias=None, op=None,)
        qkv_list = [q, encoder_k, encoder_v]
        q = encoder_k = encoder_v = None
        x = pay_attention(qkv_list)

        x = rearrange(x, "B M H K -> B H M K")

        # linear transform
        x_output_shape = (B, N, C)
        x = x.transpose(1, 2)
        x = x.reshape(x_output_shape)
        x = self.proj(x)
        x = self.proj_drop(x)

        # reshape x to origin shape
        x = rearrange(x, "(B N_t) S C -> B (N_t S) C", N_t=N_t)

        return x
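
# Usage sketch (illustrative only, not part of the original file): shape contract for
# SingleStreamAttention. Assumes pay_attention's backend is available and that the
# audio context is laid out per latent frame, i.e. (B * N_t, N_a, enc_dim).
attn = SingleStreamAttention(
    dim=768, encoder_hidden_states_dim=768, num_heads=12,
    qkv_bias=True, qk_norm=True, norm_layer=nn.LayerNorm)
B, N_t, N_h, N_w = 1, 4, 8, 8
x = torch.randn(B, N_t * N_h * N_w, 768)            # video tokens
audio_ctx = torch.randn(B * N_t, 32, 768)           # assumed per-frame audio tokens
out = attn(x, audio_ctx, shape=(N_t, N_h, N_w))     # -> (B, N_t*N_h*N_w, 768)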
23  wan/multitalk/kokoro/__init__.py  (new file)
@@ -0,0 +1,23 @@
__version__ = '0.9.4'

from loguru import logger
import sys

# Remove default handler
logger.remove()

# Add custom handler with clean format including module and line number
logger.add(
    sys.stderr,
    format="<green>{time:HH:mm:ss}</green> | <cyan>{module:>16}:{line}</cyan> | <level>{level: >8}</level> | <level>{message}</level>",
    colorize=True,
    level="INFO"  # "DEBUG" to enable logger.debug("message") and up prints
                  # "ERROR" to enable only logger.error("message") prints
                  # etc.
)

# Disable before release or as needed
logger.disable("kokoro")

from .model import KModel
from .pipeline import KPipeline
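
# Usage sketch (illustrative only, not part of the original file): the sink added above
# starts at INFO and package logging is disabled on import; application code can
# re-enable it explicitly.
from loguru import logger
logger.enable("kokoro")   # mirrors the logger.disable("kokoro") call above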
148  wan/multitalk/kokoro/__main__.py  (new file)
@@ -0,0 +1,148 @@
|
||||
"""Kokoro TTS CLI
|
||||
Example usage:
|
||||
python3 -m kokoro --text "The sky above the port was the color of television, tuned to a dead channel." -o file.wav --debug
|
||||
|
||||
echo "Bom dia mundo, como vão vocês" > text.txt
|
||||
python3 -m kokoro -i text.txt -l p --voice pm_alex > audio.wav
|
||||
|
||||
Common issues:
|
||||
pip not installed: `uv pip install pip`
|
||||
(Temporary workaround while https://github.com/explosion/spaCy/issues/13747 is not fixed)
|
||||
|
||||
espeak not installed: `apt-get install espeak-ng`
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import wave
|
||||
from pathlib import Path
|
||||
from typing import Generator, TYPE_CHECKING
|
||||
|
||||
import numpy as np
|
||||
from loguru import logger
|
||||
|
||||
languages = [
|
||||
"a", # American English
|
||||
"b", # British English
|
||||
"h", # Hindi
|
||||
"e", # Spanish
|
||||
"f", # French
|
||||
"i", # Italian
|
||||
"p", # Brazilian Portuguese
|
||||
"j", # Japanese
|
||||
"z", # Mandarin Chinese
|
||||
]
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from kokoro import KPipeline
|
||||
|
||||
|
||||
def generate_audio(
|
||||
text: str, kokoro_language: str, voice: str, speed=1
|
||||
) -> Generator["KPipeline.Result", None, None]:
|
||||
from kokoro import KPipeline
|
||||
|
||||
if not voice.startswith(kokoro_language):
|
||||
logger.warning(f"Voice {voice} is not made for language {kokoro_language}")
|
||||
pipeline = KPipeline(lang_code=kokoro_language)
|
||||
yield from pipeline(text, voice=voice, speed=speed, split_pattern=r"\n+")
|
||||
|
||||
|
||||
def generate_and_save_audio(
|
||||
output_file: Path, text: str, kokoro_language: str, voice: str, speed=1
|
||||
) -> None:
|
||||
with wave.open(str(output_file.resolve()), "wb") as wav_file:
|
||||
wav_file.setnchannels(1) # Mono audio
|
||||
wav_file.setsampwidth(2) # 2 bytes per sample (16-bit audio)
|
||||
wav_file.setframerate(24000) # Sample rate
|
||||
|
||||
for result in generate_audio(
|
||||
text, kokoro_language=kokoro_language, voice=voice, speed=speed
|
||||
):
|
||||
logger.debug(result.phonemes)
|
||||
if result.audio is None:
|
||||
continue
|
||||
audio_bytes = (result.audio.numpy() * 32767).astype(np.int16).tobytes()
|
||||
wav_file.writeframes(audio_bytes)
|
||||
|
||||
|
||||
def main() -> None:
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"-m",
|
||||
"--voice",
|
||||
default="af_heart",
|
||||
help="Voice to use",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-l",
|
||||
"--language",
|
||||
help="Language to use (defaults to the one corresponding to the voice)",
|
||||
choices=languages,
|
||||
)
|
||||
parser.add_argument(
|
||||
"-o",
|
||||
"--output-file",
|
||||
"--output_file",
|
||||
type=Path,
|
||||
help="Path to output WAV file",
|
||||
required=True,
|
||||
)
|
||||
parser.add_argument(
|
||||
"-i",
|
||||
"--input-file",
|
||||
"--input_file",
|
||||
type=Path,
|
||||
help="Path to input text file (default: stdin)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-t",
|
||||
"--text",
|
||||
help="Text to use instead of reading from stdin",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-s",
|
||||
"--speed",
|
||||
type=float,
|
||||
default=1.0,
|
||||
help="Speech speed",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--debug",
|
||||
action="store_true",
|
||||
help="Print DEBUG messages to console",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
if args.debug:
|
||||
logger.level("DEBUG")
|
||||
logger.debug(args)
|
||||
|
||||
lang = args.language or args.voice[0]
|
||||
|
||||
if args.text is not None and args.input_file is not None:
|
||||
raise Exception("You cannot specify both 'text' and 'input_file'")
|
||||
elif args.text:
|
||||
text = args.text
|
||||
elif args.input_file:
|
||||
file: Path = args.input_file
|
||||
text = file.read_text()
|
||||
else:
|
||||
import sys
|
||||
print("Press Ctrl+D to stop reading input and start generating", flush=True)
|
||||
text = '\n'.join(sys.stdin)
|
||||
|
||||
logger.debug(f"Input text: {text!r}")
|
||||
|
||||
out_file: Path = args.output_file
|
||||
if not out_file.suffix == ".wav":
|
||||
logger.warning("The output file name should end with .wav")
|
||||
generate_and_save_audio(
|
||||
output_file=out_file,
|
||||
text=text,
|
||||
kokoro_language=lang,
|
||||
voice=args.voice,
|
||||
speed=args.speed,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
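
# Usage sketch (illustrative only, not part of the original file): call the helper above
# directly instead of going through the CLI. Assumes the kokoro package and espeak-ng
# are installed.
from pathlib import Path
generate_and_save_audio(
    output_file=Path("out.wav"),
    text="The sky above the port was the color of television.",
    kokoro_language="a",     # American English
    voice="af_heart",
)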
197  wan/multitalk/kokoro/custom_stft.py  (new file)
@@ -0,0 +1,197 @@
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
class CustomSTFT(nn.Module):
|
||||
"""
|
||||
STFT/iSTFT without unfold/complex ops, using conv1d and conv_transpose1d.
|
||||
|
||||
- forward STFT => Real-part conv1d + Imag-part conv1d
|
||||
- inverse STFT => Real-part conv_transpose1d + Imag-part conv_transpose1d + sum
|
||||
- avoids F.unfold, so easier to export to ONNX
|
||||
- uses replicate or constant padding for 'center=True' to approximate 'reflect'
|
||||
(reflect is not supported for dynamic shapes in ONNX)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
filter_length=800,
|
||||
hop_length=200,
|
||||
win_length=800,
|
||||
window="hann",
|
||||
center=True,
|
||||
pad_mode="replicate", # or 'constant'
|
||||
):
|
||||
super().__init__()
|
||||
self.filter_length = filter_length
|
||||
self.hop_length = hop_length
|
||||
self.win_length = win_length
|
||||
self.n_fft = filter_length
|
||||
self.center = center
|
||||
self.pad_mode = pad_mode
|
||||
|
||||
# Number of frequency bins for real-valued STFT with onesided=True
|
||||
self.freq_bins = self.n_fft // 2 + 1
|
||||
|
||||
# Build window
|
||||
assert window == 'hann', window
|
||||
window_tensor = torch.hann_window(win_length, periodic=True, dtype=torch.float32)
|
||||
if self.win_length < self.n_fft:
|
||||
# Zero-pad up to n_fft
|
||||
extra = self.n_fft - self.win_length
|
||||
window_tensor = F.pad(window_tensor, (0, extra))
|
||||
elif self.win_length > self.n_fft:
|
||||
window_tensor = window_tensor[: self.n_fft]
|
||||
self.register_buffer("window", window_tensor)
|
||||
|
||||
# Precompute forward DFT (real, imag)
|
||||
# PyTorch stft uses e^{-j 2 pi k n / N} => real=cos(...), imag=-sin(...)
|
||||
n = np.arange(self.n_fft)
|
||||
k = np.arange(self.freq_bins)
|
||||
angle = 2 * np.pi * np.outer(k, n) / self.n_fft # shape (freq_bins, n_fft)
|
||||
dft_real = np.cos(angle)
|
||||
dft_imag = -np.sin(angle) # note negative sign
|
||||
|
||||
# Combine window and dft => shape (freq_bins, filter_length)
|
||||
# We'll make 2 conv weight tensors of shape (freq_bins, 1, filter_length).
|
||||
forward_window = window_tensor.numpy() # shape (n_fft,)
|
||||
forward_real = dft_real * forward_window # (freq_bins, n_fft)
|
||||
forward_imag = dft_imag * forward_window
|
||||
|
||||
# Convert to PyTorch
|
||||
forward_real_torch = torch.from_numpy(forward_real).float()
|
||||
forward_imag_torch = torch.from_numpy(forward_imag).float()
|
||||
|
||||
# Register as Conv1d weight => (out_channels, in_channels, kernel_size)
|
||||
# out_channels = freq_bins, in_channels=1, kernel_size=n_fft
|
||||
self.register_buffer(
|
||||
"weight_forward_real", forward_real_torch.unsqueeze(1)
|
||||
)
|
||||
self.register_buffer(
|
||||
"weight_forward_imag", forward_imag_torch.unsqueeze(1)
|
||||
)
|
||||
|
||||
# Precompute inverse DFT
|
||||
# Real iFFT formula => scale = 1/n_fft, doubling for bins 1..freq_bins-2 if n_fft even, etc.
|
||||
# For simplicity, we won't do the "DC/nyquist not doubled" approach here.
|
||||
# If you want perfect real iSTFT, you can add that logic.
|
||||
# This version just yields good approximate reconstruction with Hann + typical overlap.
|
||||
inv_scale = 1.0 / self.n_fft
|
||||
n = np.arange(self.n_fft)
|
||||
angle_t = 2 * np.pi * np.outer(n, k) / self.n_fft # shape (n_fft, freq_bins)
|
||||
idft_cos = np.cos(angle_t).T # => (freq_bins, n_fft)
|
||||
idft_sin = np.sin(angle_t).T # => (freq_bins, n_fft)
|
||||
|
||||
# Multiply by window again for typical overlap-add
|
||||
# We also incorporate the scale factor 1/n_fft
|
||||
inv_window = window_tensor.numpy() * inv_scale
|
||||
backward_real = idft_cos * inv_window # (freq_bins, n_fft)
|
||||
backward_imag = idft_sin * inv_window
|
||||
|
||||
# We'll implement iSTFT as real+imag conv_transpose with stride=hop.
|
||||
self.register_buffer(
|
||||
"weight_backward_real", torch.from_numpy(backward_real).float().unsqueeze(1)
|
||||
)
|
||||
self.register_buffer(
|
||||
"weight_backward_imag", torch.from_numpy(backward_imag).float().unsqueeze(1)
|
||||
)
|
||||
|
||||
|
||||
|
||||
def transform(self, waveform: torch.Tensor):
|
||||
"""
|
||||
Forward STFT => returns magnitude, phase
|
||||
Output shape => (batch, freq_bins, frames)
|
||||
"""
|
||||
# waveform shape => (B, T). conv1d expects (B, 1, T).
|
||||
# Optional center pad
|
||||
if self.center:
|
||||
pad_len = self.n_fft // 2
|
||||
waveform = F.pad(waveform, (pad_len, pad_len), mode=self.pad_mode)
|
||||
|
||||
x = waveform.unsqueeze(1) # => (B, 1, T)
|
||||
# Convolution to get real part => shape (B, freq_bins, frames)
|
||||
real_out = F.conv1d(
|
||||
x,
|
||||
self.weight_forward_real,
|
||||
bias=None,
|
||||
stride=self.hop_length,
|
||||
padding=0,
|
||||
)
|
||||
# Imag part
|
||||
imag_out = F.conv1d(
|
||||
x,
|
||||
self.weight_forward_imag,
|
||||
bias=None,
|
||||
stride=self.hop_length,
|
||||
padding=0,
|
||||
)
|
||||
|
||||
# magnitude, phase
|
||||
magnitude = torch.sqrt(real_out**2 + imag_out**2 + 1e-14)
|
||||
phase = torch.atan2(imag_out, real_out)
|
||||
# Handle the case where imag_out is 0 and real_out is negative to correct ONNX atan2 to match PyTorch
|
||||
# In this case, PyTorch returns pi, ONNX returns -pi
|
||||
correction_mask = (imag_out == 0) & (real_out < 0)
|
||||
phase[correction_mask] = torch.pi
|
||||
return magnitude, phase
|
||||
|
||||
|
||||
def inverse(self, magnitude: torch.Tensor, phase: torch.Tensor, length=None):
|
||||
"""
|
||||
Inverse STFT => returns waveform shape (B, T).
|
||||
"""
|
||||
# magnitude, phase => (B, freq_bins, frames)
|
||||
# Re-create real/imag => shape (B, freq_bins, frames)
|
||||
real_part = magnitude * torch.cos(phase)
|
||||
imag_part = magnitude * torch.sin(phase)
|
||||
|
||||
# conv_transpose wants shape (B, freq_bins, frames). We'll treat "frames" as time dimension
|
||||
# so we do (B, freq_bins, frames) => (B, freq_bins, frames)
|
||||
# But PyTorch conv_transpose1d expects (B, in_channels, input_length)
|
||||
real_part = real_part # (B, freq_bins, frames)
|
||||
imag_part = imag_part
|
||||
|
||||
# real iSTFT => convolve with "backward_real", "backward_imag", and sum
|
||||
# We'll do 2 conv_transpose calls, each giving (B, 1, time),
|
||||
# then add them => (B, 1, time).
|
||||
real_rec = F.conv_transpose1d(
|
||||
real_part,
|
||||
self.weight_backward_real, # shape (freq_bins, 1, filter_length)
|
||||
bias=None,
|
||||
stride=self.hop_length,
|
||||
padding=0,
|
||||
)
|
||||
imag_rec = F.conv_transpose1d(
|
||||
imag_part,
|
||||
self.weight_backward_imag,
|
||||
bias=None,
|
||||
stride=self.hop_length,
|
||||
padding=0,
|
||||
)
|
||||
# sum => (B, 1, time)
|
||||
waveform = real_rec - imag_rec # typical real iFFT has minus for imaginary part
|
||||
|
||||
# If we used "center=True" in forward, we should remove pad
|
||||
if self.center:
|
||||
pad_len = self.n_fft // 2
|
||||
# Because of transposed convolution, total length might have extra samples
|
||||
# We remove `pad_len` from start & end if possible
|
||||
waveform = waveform[..., pad_len:-pad_len]
|
||||
|
||||
# If a specific length is desired, clamp
|
||||
if length is not None:
|
||||
waveform = waveform[..., :length]
|
||||
|
||||
# shape => (B, T)
|
||||
return waveform
|
||||
|
||||
def forward(self, x: torch.Tensor):
|
||||
"""
|
||||
Full STFT -> iSTFT pass: returns time-domain reconstruction.
|
||||
Same interface as your original code.
|
||||
"""
|
||||
mag, phase = self.transform(x)
|
||||
return self.inverse(mag, phase, length=x.shape[-1])
|
||||
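
# Usage sketch (illustrative only, not part of the original file): round-trip a random
# waveform through CustomSTFT; with the default 800/200 Hann setup the reconstruction
# is approximate, not bit-exact.
import torch
stft = CustomSTFT(filter_length=800, hop_length=200, win_length=800)
wav = torch.randn(2, 24000)                                # (batch, samples), 1 s at 24 kHz
mag, phase = stft.transform(wav)                           # (2, 401, 121) each
recon = stft.inverse(mag, phase, length=wav.shape[-1])     # (2, 1, 24000)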
421  wan/multitalk/kokoro/istftnet.py  (new file)
@@ -0,0 +1,421 @@
|
||||
# ADAPTED from https://github.com/yl4579/StyleTTS2/blob/main/Modules/istftnet.py
|
||||
from .custom_stft import CustomSTFT
|
||||
from torch.nn.utils import weight_norm
|
||||
import math
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
# https://github.com/yl4579/StyleTTS2/blob/main/Modules/utils.py
|
||||
def init_weights(m, mean=0.0, std=0.01):
|
||||
classname = m.__class__.__name__
|
||||
if classname.find("Conv") != -1:
|
||||
m.weight.data.normal_(mean, std)
|
||||
|
||||
def get_padding(kernel_size, dilation=1):
|
||||
return int((kernel_size*dilation - dilation)/2)
|
||||
|
||||
|
||||
class AdaIN1d(nn.Module):
|
||||
def __init__(self, style_dim, num_features):
|
||||
super().__init__()
|
||||
# affine should be False, but a bug in the legacy torch.onnx.export (not the newer dynamo exporter) drops the channel dimension when affine=False. With affine=True there are only extra learnable parameters, which is harmless here since we run in inference mode.
|
||||
self.norm = nn.InstanceNorm1d(num_features, affine=True)
|
||||
self.fc = nn.Linear(style_dim, num_features*2)
|
||||
|
||||
def forward(self, x, s):
|
||||
h = self.fc(s)
|
||||
h = h.view(h.size(0), h.size(1), 1)
|
||||
gamma, beta = torch.chunk(h, chunks=2, dim=1)
|
||||
return (1 + gamma) * self.norm(x) + beta
|
||||
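
# Usage sketch (illustrative only, not part of the original file): AdaIN1d modulates a
# (batch, channels, frames) feature map with a per-sample style vector.
import torch
adain = AdaIN1d(style_dim=128, num_features=512)
feats = torch.randn(2, 512, 100)     # (batch, channels, frames)
style = torch.randn(2, 128)          # style embedding
out = adain(feats, style)            # same shape as feats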
|
||||
|
||||
class AdaINResBlock1(nn.Module):
|
||||
def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5), style_dim=64):
|
||||
super(AdaINResBlock1, self).__init__()
|
||||
self.convs1 = nn.ModuleList([
|
||||
weight_norm(nn.Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0],
|
||||
padding=get_padding(kernel_size, dilation[0]))),
|
||||
weight_norm(nn.Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1],
|
||||
padding=get_padding(kernel_size, dilation[1]))),
|
||||
weight_norm(nn.Conv1d(channels, channels, kernel_size, 1, dilation=dilation[2],
|
||||
padding=get_padding(kernel_size, dilation[2])))
|
||||
])
|
||||
self.convs1.apply(init_weights)
|
||||
self.convs2 = nn.ModuleList([
|
||||
weight_norm(nn.Conv1d(channels, channels, kernel_size, 1, dilation=1,
|
||||
padding=get_padding(kernel_size, 1))),
|
||||
weight_norm(nn.Conv1d(channels, channels, kernel_size, 1, dilation=1,
|
||||
padding=get_padding(kernel_size, 1))),
|
||||
weight_norm(nn.Conv1d(channels, channels, kernel_size, 1, dilation=1,
|
||||
padding=get_padding(kernel_size, 1)))
|
||||
])
|
||||
self.convs2.apply(init_weights)
|
||||
self.adain1 = nn.ModuleList([
|
||||
AdaIN1d(style_dim, channels),
|
||||
AdaIN1d(style_dim, channels),
|
||||
AdaIN1d(style_dim, channels),
|
||||
])
|
||||
self.adain2 = nn.ModuleList([
|
||||
AdaIN1d(style_dim, channels),
|
||||
AdaIN1d(style_dim, channels),
|
||||
AdaIN1d(style_dim, channels),
|
||||
])
|
||||
self.alpha1 = nn.ParameterList([nn.Parameter(torch.ones(1, channels, 1)) for i in range(len(self.convs1))])
|
||||
self.alpha2 = nn.ParameterList([nn.Parameter(torch.ones(1, channels, 1)) for i in range(len(self.convs2))])
|
||||
|
||||
def forward(self, x, s):
|
||||
for c1, c2, n1, n2, a1, a2 in zip(self.convs1, self.convs2, self.adain1, self.adain2, self.alpha1, self.alpha2):
|
||||
xt = n1(x, s)
|
||||
xt = xt + (1 / a1) * (torch.sin(a1 * xt) ** 2) # Snake1D
|
||||
xt = c1(xt)
|
||||
xt = n2(xt, s)
|
||||
xt = xt + (1 / a2) * (torch.sin(a2 * xt) ** 2) # Snake1D
|
||||
xt = c2(xt)
|
||||
x = xt + x
|
||||
return x
|
||||
|
||||
|
||||
class TorchSTFT(nn.Module):
|
||||
def __init__(self, filter_length=800, hop_length=200, win_length=800, window='hann'):
|
||||
super().__init__()
|
||||
self.filter_length = filter_length
|
||||
self.hop_length = hop_length
|
||||
self.win_length = win_length
|
||||
assert window == 'hann', window
|
||||
self.window = torch.hann_window(win_length, periodic=True, dtype=torch.float32)
|
||||
|
||||
def transform(self, input_data):
|
||||
forward_transform = torch.stft(
|
||||
input_data,
|
||||
self.filter_length, self.hop_length, self.win_length, window=self.window.to(input_data.device),
|
||||
return_complex=True)
|
||||
return torch.abs(forward_transform), torch.angle(forward_transform)
|
||||
|
||||
def inverse(self, magnitude, phase):
|
||||
inverse_transform = torch.istft(
|
||||
magnitude * torch.exp(phase * 1j),
|
||||
self.filter_length, self.hop_length, self.win_length, window=self.window.to(magnitude.device))
|
||||
return inverse_transform.unsqueeze(-2) # unsqueeze to stay consistent with conv_transpose1d implementation
|
||||
|
||||
def forward(self, input_data):
|
||||
self.magnitude, self.phase = self.transform(input_data)
|
||||
reconstruction = self.inverse(self.magnitude, self.phase)
|
||||
return reconstruction
|
||||
|
||||
|
||||
class SineGen(nn.Module):
|
||||
""" Definition of sine generator
|
||||
SineGen(samp_rate, harmonic_num = 0,
|
||||
sine_amp = 0.1, noise_std = 0.003,
|
||||
voiced_threshold = 0,
|
||||
flag_for_pulse=False)
|
||||
samp_rate: sampling rate in Hz
|
||||
harmonic_num: number of harmonic overtones (default 0)
|
||||
sine_amp: amplitude of sine waveform (default 0.1)
|
||||
noise_std: std of Gaussian noise (default 0.003)
|
||||
voiced_threshold: F0 threshold for U/V classification (default 0)
|
||||
flag_for_pulse: this SineGen is used inside PulseGen (default False)
|
||||
Note: when flag_for_pulse is True, the first time step of a voiced
|
||||
segment is always sin(torch.pi) or cos(0)
|
||||
"""
|
||||
def __init__(self, samp_rate, upsample_scale, harmonic_num=0,
|
||||
sine_amp=0.1, noise_std=0.003,
|
||||
voiced_threshold=0,
|
||||
flag_for_pulse=False):
|
||||
super(SineGen, self).__init__()
|
||||
self.sine_amp = sine_amp
|
||||
self.noise_std = noise_std
|
||||
self.harmonic_num = harmonic_num
|
||||
self.dim = self.harmonic_num + 1
|
||||
self.sampling_rate = samp_rate
|
||||
self.voiced_threshold = voiced_threshold
|
||||
self.flag_for_pulse = flag_for_pulse
|
||||
self.upsample_scale = upsample_scale
|
||||
|
||||
def _f02uv(self, f0):
|
||||
# generate uv signal
|
||||
uv = (f0 > self.voiced_threshold).type(torch.float32)
|
||||
return uv
|
||||
|
||||
def _f02sine(self, f0_values):
|
||||
""" f0_values: (batchsize, length, dim)
|
||||
where dim indicates fundamental tone and overtones
|
||||
"""
|
||||
# convert to F0 in rad. The integer part n can be ignored
|
||||
# because 2 * torch.pi * n doesn't affect phase
|
||||
rad_values = (f0_values / self.sampling_rate) % 1
|
||||
# initial phase noise (no noise for fundamental component)
|
||||
rand_ini = torch.rand(f0_values.shape[0], f0_values.shape[2], device=f0_values.device)
|
||||
rand_ini[:, 0] = 0
|
||||
rad_values[:, 0, :] = rad_values[:, 0, :] + rand_ini
|
||||
# instantaneous phase: sine[t] = sin(2 * pi * \sum_{i=1}^{t} rad_i)
|
||||
if not self.flag_for_pulse:
|
||||
rad_values = F.interpolate(rad_values.transpose(1, 2), scale_factor=1/self.upsample_scale, mode="linear").transpose(1, 2)
|
||||
phase = torch.cumsum(rad_values, dim=1) * 2 * torch.pi
|
||||
phase = F.interpolate(phase.transpose(1, 2) * self.upsample_scale, scale_factor=self.upsample_scale, mode="linear").transpose(1, 2)
|
||||
sines = torch.sin(phase)
|
||||
else:
|
||||
# If necessary, make sure that the first time step of every
|
||||
# voiced segments is sin(pi) or cos(0)
|
||||
# This is used for pulse-train generation
|
||||
# identify the last time step in unvoiced segments
|
||||
uv = self._f02uv(f0_values)
|
||||
uv_1 = torch.roll(uv, shifts=-1, dims=1)
|
||||
uv_1[:, -1, :] = 1
|
||||
u_loc = (uv < 1) * (uv_1 > 0)
|
||||
# get the instantaneous phase
|
||||
tmp_cumsum = torch.cumsum(rad_values, dim=1)
|
||||
# different batch needs to be processed differently
|
||||
for idx in range(f0_values.shape[0]):
|
||||
temp_sum = tmp_cumsum[idx, u_loc[idx, :, 0], :]
|
||||
temp_sum[1:, :] = temp_sum[1:, :] - temp_sum[0:-1, :]
|
||||
# stores the accumulation of i.phase within
|
||||
# each voiced segments
|
||||
tmp_cumsum[idx, :, :] = 0
|
||||
tmp_cumsum[idx, u_loc[idx, :, 0], :] = temp_sum
|
||||
# rad_values - tmp_cumsum: remove the accumulation of i.phase
|
||||
# within the previous voiced segment.
|
||||
i_phase = torch.cumsum(rad_values - tmp_cumsum, dim=1)
|
||||
# get the sines
|
||||
sines = torch.cos(i_phase * 2 * torch.pi)
|
||||
return sines
|
||||
|
||||
def forward(self, f0):
|
||||
""" sine_tensor, uv = forward(f0)
|
||||
input F0: tensor(batchsize=1, length, dim=1)
|
||||
f0 for unvoiced steps should be 0
|
||||
output sine_tensor: tensor(batchsize=1, length, dim)
|
||||
output uv: tensor(batchsize=1, length, 1)
|
||||
"""
|
||||
f0_buf = torch.zeros(f0.shape[0], f0.shape[1], self.dim, device=f0.device)
|
||||
# fundamental component
|
||||
fn = torch.multiply(f0, torch.FloatTensor([[range(1, self.harmonic_num + 2)]]).to(f0.device))
|
||||
# generate sine waveforms
|
||||
sine_waves = self._f02sine(fn) * self.sine_amp
|
||||
# generate uv signal
|
||||
# uv = torch.ones(f0.shape)
|
||||
# uv = uv * (f0 > self.voiced_threshold)
|
||||
uv = self._f02uv(f0)
|
||||
# noise: for unvoiced should be similar to sine_amp
|
||||
# std = self.sine_amp/3 -> max value ~ self.sine_amp
|
||||
# for voiced regions is self.noise_std
|
||||
noise_amp = uv * self.noise_std + (1 - uv) * self.sine_amp / 3
|
||||
noise = noise_amp * torch.randn_like(sine_waves)
|
||||
# first: set the unvoiced part to 0 by uv
|
||||
# then: additive noise
|
||||
sine_waves = sine_waves * uv + noise
|
||||
return sine_waves, uv, noise
|
||||
|
||||
|
||||
class SourceModuleHnNSF(nn.Module):
|
||||
""" SourceModule for hn-nsf
|
||||
SourceModule(sampling_rate, harmonic_num=0, sine_amp=0.1,
|
||||
add_noise_std=0.003, voiced_threshod=0)
|
||||
sampling_rate: sampling_rate in Hz
|
||||
harmonic_num: number of harmonic above F0 (default: 0)
|
||||
sine_amp: amplitude of sine source signal (default: 0.1)
|
||||
add_noise_std: std of additive Gaussian noise (default: 0.003)
|
||||
note that amplitude of noise in unvoiced is decided
|
||||
by sine_amp
|
||||
voiced_threshold: threshold to set U/V given F0 (default: 0)
|
||||
Sine_source, noise_source = SourceModuleHnNSF(F0_sampled)
|
||||
F0_sampled (batchsize, length, 1)
|
||||
Sine_source (batchsize, length, 1)
|
||||
noise_source (batchsize, length 1)
|
||||
uv (batchsize, length, 1)
|
||||
"""
|
||||
def __init__(self, sampling_rate, upsample_scale, harmonic_num=0, sine_amp=0.1,
|
||||
add_noise_std=0.003, voiced_threshod=0):
|
||||
super(SourceModuleHnNSF, self).__init__()
|
||||
self.sine_amp = sine_amp
|
||||
self.noise_std = add_noise_std
|
||||
# to produce sine waveforms
|
||||
self.l_sin_gen = SineGen(sampling_rate, upsample_scale, harmonic_num,
|
||||
sine_amp, add_noise_std, voiced_threshod)
|
||||
# to merge source harmonics into a single excitation
|
||||
self.l_linear = nn.Linear(harmonic_num + 1, 1)
|
||||
self.l_tanh = nn.Tanh()
|
||||
|
||||
def forward(self, x):
|
||||
"""
|
||||
Sine_source, noise_source = SourceModuleHnNSF(F0_sampled)
|
||||
F0_sampled (batchsize, length, 1)
|
||||
Sine_source (batchsize, length, 1)
|
||||
noise_source (batchsize, length 1)
|
||||
"""
|
||||
# source for harmonic branch
|
||||
with torch.no_grad():
|
||||
sine_wavs, uv, _ = self.l_sin_gen(x)
|
||||
sine_merge = self.l_tanh(self.l_linear(sine_wavs))
|
||||
# source for noise branch, in the same shape as uv
|
||||
noise = torch.randn_like(uv) * self.sine_amp / 3
|
||||
return sine_merge, noise, uv
|
||||
|
||||
|
||||
class Generator(nn.Module):
|
||||
def __init__(self, style_dim, resblock_kernel_sizes, upsample_rates, upsample_initial_channel, resblock_dilation_sizes, upsample_kernel_sizes, gen_istft_n_fft, gen_istft_hop_size, disable_complex=False):
|
||||
super(Generator, self).__init__()
|
||||
self.num_kernels = len(resblock_kernel_sizes)
|
||||
self.num_upsamples = len(upsample_rates)
|
||||
self.m_source = SourceModuleHnNSF(
|
||||
sampling_rate=24000,
|
||||
upsample_scale=math.prod(upsample_rates) * gen_istft_hop_size,
|
||||
harmonic_num=8, voiced_threshod=10)
|
||||
self.f0_upsamp = nn.Upsample(scale_factor=math.prod(upsample_rates) * gen_istft_hop_size)
|
||||
self.noise_convs = nn.ModuleList()
|
||||
self.noise_res = nn.ModuleList()
|
||||
self.ups = nn.ModuleList()
|
||||
for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
|
||||
self.ups.append(weight_norm(
|
||||
nn.ConvTranspose1d(upsample_initial_channel//(2**i), upsample_initial_channel//(2**(i+1)),
|
||||
k, u, padding=(k-u)//2)))
|
||||
self.resblocks = nn.ModuleList()
|
||||
for i in range(len(self.ups)):
|
||||
ch = upsample_initial_channel//(2**(i+1))
|
||||
for j, (k, d) in enumerate(zip(resblock_kernel_sizes,resblock_dilation_sizes)):
|
||||
self.resblocks.append(AdaINResBlock1(ch, k, d, style_dim))
|
||||
c_cur = upsample_initial_channel // (2 ** (i + 1))
|
||||
if i + 1 < len(upsample_rates):
|
||||
stride_f0 = math.prod(upsample_rates[i + 1:])
|
||||
self.noise_convs.append(nn.Conv1d(
|
||||
gen_istft_n_fft + 2, c_cur, kernel_size=stride_f0 * 2, stride=stride_f0, padding=(stride_f0+1) // 2))
|
||||
self.noise_res.append(AdaINResBlock1(c_cur, 7, [1,3,5], style_dim))
|
||||
else:
|
||||
self.noise_convs.append(nn.Conv1d(gen_istft_n_fft + 2, c_cur, kernel_size=1))
|
||||
self.noise_res.append(AdaINResBlock1(c_cur, 11, [1,3,5], style_dim))
|
||||
self.post_n_fft = gen_istft_n_fft
|
||||
self.conv_post = weight_norm(nn.Conv1d(ch, self.post_n_fft + 2, 7, 1, padding=3))
|
||||
self.ups.apply(init_weights)
|
||||
self.conv_post.apply(init_weights)
|
||||
self.reflection_pad = nn.ReflectionPad1d((1, 0))
|
||||
self.stft = (
|
||||
CustomSTFT(filter_length=gen_istft_n_fft, hop_length=gen_istft_hop_size, win_length=gen_istft_n_fft)
|
||||
if disable_complex
|
||||
else TorchSTFT(filter_length=gen_istft_n_fft, hop_length=gen_istft_hop_size, win_length=gen_istft_n_fft)
|
||||
)
|
||||
|
||||
def forward(self, x, s, f0):
|
||||
with torch.no_grad():
|
||||
f0 = self.f0_upsamp(f0[:, None]).transpose(1, 2) # bs,n,t
|
||||
har_source, noi_source, uv = self.m_source(f0)
|
||||
har_source = har_source.transpose(1, 2).squeeze(1)
|
||||
har_spec, har_phase = self.stft.transform(har_source)
|
||||
har = torch.cat([har_spec, har_phase], dim=1)
|
||||
for i in range(self.num_upsamples):
|
||||
x = F.leaky_relu(x, negative_slope=0.1)
|
||||
x_source = self.noise_convs[i](har)
|
||||
x_source = self.noise_res[i](x_source, s)
|
||||
x = self.ups[i](x)
|
||||
if i == self.num_upsamples - 1:
|
||||
x = self.reflection_pad(x)
|
||||
x = x + x_source
|
||||
xs = None
|
||||
for j in range(self.num_kernels):
|
||||
if xs is None:
|
||||
xs = self.resblocks[i*self.num_kernels+j](x, s)
|
||||
else:
|
||||
xs += self.resblocks[i*self.num_kernels+j](x, s)
|
||||
x = xs / self.num_kernels
|
||||
x = F.leaky_relu(x)
|
||||
x = self.conv_post(x)
|
||||
spec = torch.exp(x[:,:self.post_n_fft // 2 + 1, :])
|
||||
phase = torch.sin(x[:, self.post_n_fft // 2 + 1:, :])
|
||||
return self.stft.inverse(spec, phase)
|
||||
|
||||
|
||||
class UpSample1d(nn.Module):
|
||||
def __init__(self, layer_type):
|
||||
super().__init__()
|
||||
self.layer_type = layer_type
|
||||
|
||||
def forward(self, x):
|
||||
if self.layer_type == 'none':
|
||||
return x
|
||||
else:
|
||||
return F.interpolate(x, scale_factor=2, mode='nearest')
|
||||
|
||||
|
||||
class AdainResBlk1d(nn.Module):
|
||||
def __init__(self, dim_in, dim_out, style_dim=64, actv=nn.LeakyReLU(0.2), upsample='none', dropout_p=0.0):
|
||||
super().__init__()
|
||||
self.actv = actv
|
||||
self.upsample_type = upsample
|
||||
self.upsample = UpSample1d(upsample)
|
||||
self.learned_sc = dim_in != dim_out
|
||||
self._build_weights(dim_in, dim_out, style_dim)
|
||||
self.dropout = nn.Dropout(dropout_p)
|
||||
if upsample == 'none':
|
||||
self.pool = nn.Identity()
|
||||
else:
|
||||
self.pool = weight_norm(nn.ConvTranspose1d(dim_in, dim_in, kernel_size=3, stride=2, groups=dim_in, padding=1, output_padding=1))
|
||||
|
||||
def _build_weights(self, dim_in, dim_out, style_dim):
|
||||
self.conv1 = weight_norm(nn.Conv1d(dim_in, dim_out, 3, 1, 1))
|
||||
self.conv2 = weight_norm(nn.Conv1d(dim_out, dim_out, 3, 1, 1))
|
||||
self.norm1 = AdaIN1d(style_dim, dim_in)
|
||||
self.norm2 = AdaIN1d(style_dim, dim_out)
|
||||
if self.learned_sc:
|
||||
self.conv1x1 = weight_norm(nn.Conv1d(dim_in, dim_out, 1, 1, 0, bias=False))
|
||||
|
||||
def _shortcut(self, x):
|
||||
x = self.upsample(x)
|
||||
if self.learned_sc:
|
||||
x = self.conv1x1(x)
|
||||
return x
|
||||
|
||||
def _residual(self, x, s):
|
||||
x = self.norm1(x, s)
|
||||
x = self.actv(x)
|
||||
x = self.pool(x)
|
||||
x = self.conv1(self.dropout(x))
|
||||
x = self.norm2(x, s)
|
||||
x = self.actv(x)
|
||||
x = self.conv2(self.dropout(x))
|
||||
return x
|
||||
|
||||
def forward(self, x, s):
|
||||
out = self._residual(x, s)
|
||||
out = (out + self._shortcut(x)) * torch.rsqrt(torch.tensor(2))
|
||||
return out
|
||||
|
||||
|
||||
class Decoder(nn.Module):
|
||||
def __init__(self, dim_in, style_dim, dim_out,
|
||||
resblock_kernel_sizes,
|
||||
upsample_rates,
|
||||
upsample_initial_channel,
|
||||
resblock_dilation_sizes,
|
||||
upsample_kernel_sizes,
|
||||
gen_istft_n_fft, gen_istft_hop_size,
|
||||
disable_complex=False):
|
||||
super().__init__()
|
||||
self.encode = AdainResBlk1d(dim_in + 2, 1024, style_dim)
|
||||
self.decode = nn.ModuleList()
|
||||
self.decode.append(AdainResBlk1d(1024 + 2 + 64, 1024, style_dim))
|
||||
self.decode.append(AdainResBlk1d(1024 + 2 + 64, 1024, style_dim))
|
||||
self.decode.append(AdainResBlk1d(1024 + 2 + 64, 1024, style_dim))
|
||||
self.decode.append(AdainResBlk1d(1024 + 2 + 64, 512, style_dim, upsample=True))
|
||||
self.F0_conv = weight_norm(nn.Conv1d(1, 1, kernel_size=3, stride=2, groups=1, padding=1))
|
||||
self.N_conv = weight_norm(nn.Conv1d(1, 1, kernel_size=3, stride=2, groups=1, padding=1))
|
||||
self.asr_res = nn.Sequential(weight_norm(nn.Conv1d(512, 64, kernel_size=1)))
|
||||
self.generator = Generator(style_dim, resblock_kernel_sizes, upsample_rates,
|
||||
upsample_initial_channel, resblock_dilation_sizes,
|
||||
upsample_kernel_sizes, gen_istft_n_fft, gen_istft_hop_size, disable_complex=disable_complex)
|
||||
|
||||
def forward(self, asr, F0_curve, N, s):
|
||||
F0 = self.F0_conv(F0_curve.unsqueeze(1))
|
||||
N = self.N_conv(N.unsqueeze(1))
|
||||
x = torch.cat([asr, F0, N], axis=1)
|
||||
x = self.encode(x, s)
|
||||
asr_res = self.asr_res(asr)
|
||||
res = True
|
||||
for block in self.decode:
|
||||
if res:
|
||||
x = torch.cat([x, asr_res, F0, N], axis=1)
|
||||
x = block(x, s)
|
||||
if block.upsample_type != "none":
|
||||
res = False
|
||||
x = self.generator(x, s, F0_curve)
|
||||
return x
|
||||
155  wan/multitalk/kokoro/model.py  (new file)
@@ -0,0 +1,155 @@
|
||||
from .istftnet import Decoder
|
||||
from .modules import CustomAlbert, ProsodyPredictor, TextEncoder
|
||||
from dataclasses import dataclass
|
||||
from huggingface_hub import hf_hub_download
|
||||
from loguru import logger
|
||||
from transformers import AlbertConfig
|
||||
from typing import Dict, Optional, Union
|
||||
import json
|
||||
import torch
|
||||
import os
|
||||
|
||||
class KModel(torch.nn.Module):
|
||||
'''
|
||||
KModel is a torch.nn.Module with 2 main responsibilities:
|
||||
1. Init weights, downloading config.json + model.pth from HF if needed
|
||||
2. forward(phonemes: str, ref_s: FloatTensor) -> (audio: FloatTensor)
|
||||
|
||||
You likely only need one KModel instance, and it can be reused across
|
||||
multiple KPipelines to avoid redundant memory allocation.
|
||||
|
||||
Unlike KPipeline, KModel is language-blind.
|
||||
|
||||
KModel stores self.vocab and thus knows how to map phonemes -> input_ids,
|
||||
so there is no need to repeatedly download config.json outside of KModel.
|
||||
'''
|
||||
|
||||
MODEL_NAMES = {
|
||||
'hexgrad/Kokoro-82M': 'kokoro-v1_0.pth',
|
||||
'hexgrad/Kokoro-82M-v1.1-zh': 'kokoro-v1_1-zh.pth',
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
repo_id: Optional[str] = None,
|
||||
config: Union[Dict, str, None] = None,
|
||||
model: Optional[str] = None,
|
||||
disable_complex: bool = False
|
||||
):
|
||||
super().__init__()
|
||||
if repo_id is None:
|
||||
repo_id = 'hexgrad/Kokoro-82M'
|
||||
print(f"WARNING: Defaulting repo_id to {repo_id}. Pass repo_id='{repo_id}' to suppress this warning.")
|
||||
self.repo_id = repo_id
|
||||
if not isinstance(config, dict):
|
||||
if not config:
|
||||
logger.debug("No config provided, downloading from HF")
|
||||
config = hf_hub_download(repo_id=repo_id, filename='config.json')
|
||||
with open(config, 'r', encoding='utf-8') as r:
|
||||
config = json.load(r)
|
||||
logger.debug(f"Loaded config: {config}")
|
||||
self.vocab = config['vocab']
|
||||
self.bert = CustomAlbert(AlbertConfig(vocab_size=config['n_token'], **config['plbert']))
|
||||
self.bert_encoder = torch.nn.Linear(self.bert.config.hidden_size, config['hidden_dim'])
|
||||
self.context_length = self.bert.config.max_position_embeddings
|
||||
self.predictor = ProsodyPredictor(
|
||||
style_dim=config['style_dim'], d_hid=config['hidden_dim'],
|
||||
nlayers=config['n_layer'], max_dur=config['max_dur'], dropout=config['dropout']
|
||||
)
|
||||
self.text_encoder = TextEncoder(
|
||||
channels=config['hidden_dim'], kernel_size=config['text_encoder_kernel_size'],
|
||||
depth=config['n_layer'], n_symbols=config['n_token']
|
||||
)
|
||||
self.decoder = Decoder(
|
||||
dim_in=config['hidden_dim'], style_dim=config['style_dim'],
|
||||
dim_out=config['n_mels'], disable_complex=disable_complex, **config['istftnet']
|
||||
)
|
||||
if not model:
|
||||
try:
|
||||
model = hf_hub_download(repo_id=repo_id, filename=KModel.MODEL_NAMES[repo_id])
|
||||
except Exception:
|
||||
model = os.path.join(repo_id, 'kokoro-v1_0.pth')
|
||||
for key, state_dict in torch.load(model, map_location='cpu', weights_only=True).items():
|
||||
assert hasattr(self, key), key
|
||||
try:
|
||||
getattr(self, key).load_state_dict(state_dict)
|
||||
except Exception:
|
||||
logger.debug(f"Did not load {key} from state_dict")
|
||||
state_dict = {k[7:]: v for k, v in state_dict.items()}
|
||||
getattr(self, key).load_state_dict(state_dict, strict=False)
|
||||
|
||||
@property
|
||||
def device(self):
|
||||
return self.bert.device
|
||||
|
||||
@dataclass
|
||||
class Output:
|
||||
audio: torch.FloatTensor
|
||||
pred_dur: Optional[torch.LongTensor] = None
|
||||
|
||||
@torch.no_grad()
|
||||
def forward_with_tokens(
|
||||
self,
|
||||
input_ids: torch.LongTensor,
|
||||
ref_s: torch.FloatTensor,
|
||||
speed: float = 1
|
||||
) -> tuple[torch.FloatTensor, torch.LongTensor]:
|
||||
input_lengths = torch.full(
|
||||
(input_ids.shape[0],),
|
||||
input_ids.shape[-1],
|
||||
device=input_ids.device,
|
||||
dtype=torch.long
|
||||
)
|
||||
|
||||
text_mask = torch.arange(input_lengths.max()).unsqueeze(0).expand(input_lengths.shape[0], -1).type_as(input_lengths)
|
||||
text_mask = torch.gt(text_mask+1, input_lengths.unsqueeze(1)).to(self.device)
|
||||
bert_dur = self.bert(input_ids, attention_mask=(~text_mask).int())
|
||||
d_en = self.bert_encoder(bert_dur).transpose(-1, -2)
|
||||
s = ref_s[:, 128:]
|
||||
d = self.predictor.text_encoder(d_en, s, input_lengths, text_mask)
|
||||
x, _ = self.predictor.lstm(d)
|
||||
duration = self.predictor.duration_proj(x)
|
||||
duration = torch.sigmoid(duration).sum(axis=-1) / speed
|
||||
pred_dur = torch.round(duration).clamp(min=1).long().squeeze()
|
||||
indices = torch.repeat_interleave(torch.arange(input_ids.shape[1], device=self.device), pred_dur)
|
||||
pred_aln_trg = torch.zeros((input_ids.shape[1], indices.shape[0]), device=self.device)
|
||||
pred_aln_trg[indices, torch.arange(indices.shape[0])] = 1
|
||||
pred_aln_trg = pred_aln_trg.unsqueeze(0).to(self.device)
|
||||
en = d.transpose(-1, -2) @ pred_aln_trg
|
||||
F0_pred, N_pred = self.predictor.F0Ntrain(en, s)
|
||||
t_en = self.text_encoder(input_ids, input_lengths, text_mask)
|
||||
asr = t_en @ pred_aln_trg
|
||||
audio = self.decoder(asr, F0_pred, N_pred, ref_s[:, :128]).squeeze()
|
||||
return audio, pred_dur
|
||||
|
||||
def forward(
|
||||
self,
|
||||
phonemes: str,
|
||||
ref_s: torch.FloatTensor,
|
||||
speed: float = 1,
|
||||
return_output: bool = False
|
||||
) -> Union['KModel.Output', torch.FloatTensor]:
|
||||
input_ids = list(filter(lambda i: i is not None, map(lambda p: self.vocab.get(p), phonemes)))
|
||||
logger.debug(f"phonemes: {phonemes} -> input_ids: {input_ids}")
|
||||
assert len(input_ids)+2 <= self.context_length, (len(input_ids)+2, self.context_length)
|
||||
input_ids = torch.LongTensor([[0, *input_ids, 0]]).to(self.device)
|
||||
ref_s = ref_s.to(self.device)
|
||||
audio, pred_dur = self.forward_with_tokens(input_ids, ref_s, speed)
|
||||
audio = audio.squeeze().cpu()
|
||||
pred_dur = pred_dur.cpu() if pred_dur is not None else None
|
||||
logger.debug(f"pred_dur: {pred_dur}")
|
||||
return self.Output(audio=audio, pred_dur=pred_dur) if return_output else audio
|
||||
|
||||
class KModelForONNX(torch.nn.Module):
|
||||
def __init__(self, kmodel: KModel):
|
||||
super().__init__()
|
||||
self.kmodel = kmodel
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.LongTensor,
|
||||
ref_s: torch.FloatTensor,
|
||||
speed: float = 1
|
||||
) -> tuple[torch.FloatTensor, torch.LongTensor]:
|
||||
waveform, duration = self.kmodel.forward_with_tokens(input_ids, ref_s, speed)
|
||||
return waveform, duration
|
||||
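
# Usage sketch (illustrative only, not part of the original file): run KModel end to end.
# Assumes network access to the Hugging Face Hub and that `phonemes` only uses symbols
# present in the model vocab; ref_s below is a placeholder for a real (1, 256) voice/style
# reference vector taken from a voice pack.
import torch
model = KModel(repo_id='hexgrad/Kokoro-82M').eval()
ref_s = torch.randn(1, 256)                     # placeholder reference vector
audio = model(phonemes="həlˈoʊ wˈɜɹld", ref_s=ref_s, speed=1.0)
print(audio.shape)                              # 1-D waveform at 24 kHz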
183  wan/multitalk/kokoro/modules.py  (new file)
@@ -0,0 +1,183 @@
|
||||
# https://github.com/yl4579/StyleTTS2/blob/main/models.py
|
||||
from .istftnet import AdainResBlk1d
|
||||
from torch.nn.utils import weight_norm
|
||||
from transformers import AlbertModel
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
class LinearNorm(nn.Module):
|
||||
def __init__(self, in_dim, out_dim, bias=True, w_init_gain='linear'):
|
||||
super(LinearNorm, self).__init__()
|
||||
self.linear_layer = nn.Linear(in_dim, out_dim, bias=bias)
|
||||
nn.init.xavier_uniform_(self.linear_layer.weight, gain=nn.init.calculate_gain(w_init_gain))
|
||||
|
||||
def forward(self, x):
|
||||
return self.linear_layer(x)
|
||||
|
||||
|
||||
class LayerNorm(nn.Module):
|
||||
def __init__(self, channels, eps=1e-5):
|
||||
super().__init__()
|
||||
self.channels = channels
|
||||
self.eps = eps
|
||||
self.gamma = nn.Parameter(torch.ones(channels))
|
||||
self.beta = nn.Parameter(torch.zeros(channels))
|
||||
|
||||
def forward(self, x):
|
||||
x = x.transpose(1, -1)
|
||||
x = F.layer_norm(x, (self.channels,), self.gamma, self.beta, self.eps)
|
||||
return x.transpose(1, -1)
|
||||
|
||||
|
||||
class TextEncoder(nn.Module):
|
||||
def __init__(self, channels, kernel_size, depth, n_symbols, actv=nn.LeakyReLU(0.2)):
|
||||
super().__init__()
|
||||
self.embedding = nn.Embedding(n_symbols, channels)
|
||||
padding = (kernel_size - 1) // 2
|
||||
self.cnn = nn.ModuleList()
|
||||
for _ in range(depth):
|
||||
self.cnn.append(nn.Sequential(
|
||||
weight_norm(nn.Conv1d(channels, channels, kernel_size=kernel_size, padding=padding)),
|
||||
LayerNorm(channels),
|
||||
actv,
|
||||
nn.Dropout(0.2),
|
||||
))
|
||||
self.lstm = nn.LSTM(channels, channels//2, 1, batch_first=True, bidirectional=True)
|
||||
|
||||
def forward(self, x, input_lengths, m):
|
||||
x = self.embedding(x) # [B, T, emb]
|
||||
x = x.transpose(1, 2) # [B, emb, T]
|
||||
m = m.unsqueeze(1)
|
||||
x.masked_fill_(m, 0.0)
|
||||
for c in self.cnn:
|
||||
x = c(x)
|
||||
x.masked_fill_(m, 0.0)
|
||||
x = x.transpose(1, 2) # [B, T, chn]
|
||||
lengths = input_lengths if input_lengths.device == torch.device('cpu') else input_lengths.to('cpu')
|
||||
x = nn.utils.rnn.pack_padded_sequence(x, lengths, batch_first=True, enforce_sorted=False)
|
||||
self.lstm.flatten_parameters()
|
||||
x, _ = self.lstm(x)
|
||||
x, _ = nn.utils.rnn.pad_packed_sequence(x, batch_first=True)
|
||||
x = x.transpose(-1, -2)
|
||||
x_pad = torch.zeros([x.shape[0], x.shape[1], m.shape[-1]], device=x.device)
|
||||
x_pad[:, :, :x.shape[-1]] = x
|
||||
x = x_pad
|
||||
x.masked_fill_(m, 0.0)
|
||||
return x
|
||||
|
||||
|
||||
class AdaLayerNorm(nn.Module):
|
||||
def __init__(self, style_dim, channels, eps=1e-5):
|
||||
super().__init__()
|
||||
self.channels = channels
|
||||
self.eps = eps
|
||||
self.fc = nn.Linear(style_dim, channels*2)
|
||||
|
||||
def forward(self, x, s):
|
||||
x = x.transpose(-1, -2)
|
||||
x = x.transpose(1, -1)
|
||||
h = self.fc(s)
|
||||
h = h.view(h.size(0), h.size(1), 1)
|
||||
gamma, beta = torch.chunk(h, chunks=2, dim=1)
|
||||
gamma, beta = gamma.transpose(1, -1), beta.transpose(1, -1)
|
||||
x = F.layer_norm(x, (self.channels,), eps=self.eps)
|
||||
x = (1 + gamma) * x + beta
|
||||
return x.transpose(1, -1).transpose(-1, -2)
|
||||
|
||||
|
||||
class ProsodyPredictor(nn.Module):
|
||||
def __init__(self, style_dim, d_hid, nlayers, max_dur=50, dropout=0.1):
|
||||
super().__init__()
|
||||
self.text_encoder = DurationEncoder(sty_dim=style_dim, d_model=d_hid,nlayers=nlayers, dropout=dropout)
|
||||
self.lstm = nn.LSTM(d_hid + style_dim, d_hid // 2, 1, batch_first=True, bidirectional=True)
|
||||
self.duration_proj = LinearNorm(d_hid, max_dur)
|
||||
self.shared = nn.LSTM(d_hid + style_dim, d_hid // 2, 1, batch_first=True, bidirectional=True)
|
||||
self.F0 = nn.ModuleList()
|
||||
self.F0.append(AdainResBlk1d(d_hid, d_hid, style_dim, dropout_p=dropout))
|
||||
self.F0.append(AdainResBlk1d(d_hid, d_hid // 2, style_dim, upsample=True, dropout_p=dropout))
|
||||
self.F0.append(AdainResBlk1d(d_hid // 2, d_hid // 2, style_dim, dropout_p=dropout))
|
||||
self.N = nn.ModuleList()
|
||||
self.N.append(AdainResBlk1d(d_hid, d_hid, style_dim, dropout_p=dropout))
|
||||
self.N.append(AdainResBlk1d(d_hid, d_hid // 2, style_dim, upsample=True, dropout_p=dropout))
|
||||
self.N.append(AdainResBlk1d(d_hid // 2, d_hid // 2, style_dim, dropout_p=dropout))
|
||||
self.F0_proj = nn.Conv1d(d_hid // 2, 1, 1, 1, 0)
|
||||
self.N_proj = nn.Conv1d(d_hid // 2, 1, 1, 1, 0)
|
||||
|
||||
def forward(self, texts, style, text_lengths, alignment, m):
|
||||
d = self.text_encoder(texts, style, text_lengths, m)
|
||||
m = m.unsqueeze(1)
|
||||
lengths = text_lengths if text_lengths.device == torch.device('cpu') else text_lengths.to('cpu')
|
||||
x = nn.utils.rnn.pack_padded_sequence(d, lengths, batch_first=True, enforce_sorted=False)
|
||||
self.lstm.flatten_parameters()
|
||||
x, _ = self.lstm(x)
|
||||
x, _ = nn.utils.rnn.pad_packed_sequence(x, batch_first=True)
|
||||
x_pad = torch.zeros([x.shape[0], m.shape[-1], x.shape[-1]], device=x.device)
|
||||
x_pad[:, :x.shape[1], :] = x
|
||||
x = x_pad
|
||||
duration = self.duration_proj(nn.functional.dropout(x, 0.5, training=False))
|
||||
en = (d.transpose(-1, -2) @ alignment)
|
||||
return duration.squeeze(-1), en
|
||||
|
||||
def F0Ntrain(self, x, s):
|
||||
x, _ = self.shared(x.transpose(-1, -2))
|
||||
F0 = x.transpose(-1, -2)
|
||||
for block in self.F0:
|
||||
F0 = block(F0, s)
|
||||
F0 = self.F0_proj(F0)
|
||||
N = x.transpose(-1, -2)
|
||||
for block in self.N:
|
||||
N = block(N, s)
|
||||
N = self.N_proj(N)
|
||||
return F0.squeeze(1), N.squeeze(1)
|
||||
|
||||
|
||||
class DurationEncoder(nn.Module):
|
||||
def __init__(self, sty_dim, d_model, nlayers, dropout=0.1):
|
||||
super().__init__()
|
||||
self.lstms = nn.ModuleList()
|
||||
for _ in range(nlayers):
|
||||
self.lstms.append(nn.LSTM(d_model + sty_dim, d_model // 2, num_layers=1, batch_first=True, bidirectional=True, dropout=dropout))
|
||||
self.lstms.append(AdaLayerNorm(sty_dim, d_model))
|
||||
self.dropout = dropout
|
||||
self.d_model = d_model
|
||||
self.sty_dim = sty_dim
|
||||
|
||||
def forward(self, x, style, text_lengths, m):
|
||||
masks = m
|
||||
x = x.permute(2, 0, 1)
|
||||
s = style.expand(x.shape[0], x.shape[1], -1)
|
||||
x = torch.cat([x, s], axis=-1)
|
||||
x.masked_fill_(masks.unsqueeze(-1).transpose(0, 1), 0.0)
|
||||
x = x.transpose(0, 1)
|
||||
x = x.transpose(-1, -2)
|
||||
for block in self.lstms:
|
||||
if isinstance(block, AdaLayerNorm):
|
||||
x = block(x.transpose(-1, -2), style).transpose(-1, -2)
|
||||
x = torch.cat([x, s.permute(1, 2, 0)], axis=1)
|
||||
x.masked_fill_(masks.unsqueeze(-1).transpose(-1, -2), 0.0)
|
||||
else:
|
||||
lengths = text_lengths if text_lengths.device == torch.device('cpu') else text_lengths.to('cpu')
|
||||
x = x.transpose(-1, -2)
|
||||
x = nn.utils.rnn.pack_padded_sequence(
|
||||
x, lengths, batch_first=True, enforce_sorted=False)
|
||||
block.flatten_parameters()
|
||||
x, _ = block(x)
|
||||
x, _ = nn.utils.rnn.pad_packed_sequence(
|
||||
x, batch_first=True)
|
||||
x = F.dropout(x, p=self.dropout, training=False)
|
||||
x = x.transpose(-1, -2)
|
||||
x_pad = torch.zeros([x.shape[0], x.shape[1], m.shape[-1]], device=x.device)
|
||||
x_pad[:, :, :x.shape[-1]] = x
|
||||
x = x_pad
|
||||
|
||||
return x.transpose(-1, -2)
|
||||
|
||||
|
||||
# https://github.com/yl4579/StyleTTS2/blob/main/Utils/PLBERT/util.py
|
||||
class CustomAlbert(AlbertModel):
|
||||
def forward(self, *args, **kwargs):
|
||||
outputs = super().forward(*args, **kwargs)
|
||||
return outputs.last_hidden_state
|
||||
445
wan/multitalk/kokoro/pipeline.py
Normal file
@ -0,0 +1,445 @@
|
||||
from .model import KModel
|
||||
from dataclasses import dataclass
|
||||
from huggingface_hub import hf_hub_download
|
||||
from loguru import logger
|
||||
from misaki import en, espeak
|
||||
from typing import Callable, Generator, List, Optional, Tuple, Union
|
||||
import re
|
||||
import torch
|
||||
import os
|
||||
|
||||
ALIASES = {
|
||||
'en-us': 'a',
|
||||
'en-gb': 'b',
|
||||
'es': 'e',
|
||||
'fr-fr': 'f',
|
||||
'hi': 'h',
|
||||
'it': 'i',
|
||||
'pt-br': 'p',
|
||||
'ja': 'j',
|
||||
'zh': 'z',
|
||||
}
|
||||
|
||||
LANG_CODES = dict(
|
||||
# pip install misaki[en]
|
||||
a='American English',
|
||||
b='British English',
|
||||
|
||||
# espeak-ng
|
||||
e='es',
|
||||
f='fr-fr',
|
||||
h='hi',
|
||||
i='it',
|
||||
p='pt-br',
|
||||
|
||||
# pip install misaki[ja]
|
||||
j='Japanese',
|
||||
|
||||
# pip install misaki[zh]
|
||||
z='Mandarin Chinese',
|
||||
)
|
||||
|
||||
class KPipeline:
|
||||
'''
|
||||
KPipeline is a language-aware support class with 2 main responsibilities:
|
||||
1. Perform language-specific G2P, mapping (and chunking) text -> phonemes
|
||||
2. Manage and store voices, lazily downloaded from HF if needed
|
||||
|
||||
You are expected to have one KPipeline per language. If you have multiple
|
||||
KPipelines, you should reuse one KModel instance across all of them.
|
||||
|
||||
KPipeline is designed to work with a KModel, but this is not required.
|
||||
There are 2 ways to pass an existing model into a pipeline:
|
||||
1. On init: us_pipeline = KPipeline(lang_code='a', model=model)
|
||||
2. On call: us_pipeline(text, voice, model=model)
|
||||
|
||||
By default, KPipeline will automatically initialize its own KModel. To
|
||||
suppress this, construct a "quiet" KPipeline with model=False.
|
||||
|
||||
A "quiet" KPipeline yields (graphemes, phonemes, None) without generating
|
||||
any audio. You can use this to phonemize and chunk your text in advance.
|
||||
|
||||
A "loud" KPipeline _with_ a model yields (graphemes, phonemes, audio).
|
||||
'''
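# Illustrative usage sketch (comments only, not part of the original file; the voice
# name below comes from the error messages elsewhere in this module):
#   pipeline = KPipeline(lang_code='a', repo_id='hexgrad/Kokoro-82M')
#   for result in pipeline("Hello world!", voice='af_heart'):
#       audio = result.audio  # waveform tensor, or None for a "quiet" pipeline (model=False)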
|
||||
def __init__(
|
||||
self,
|
||||
lang_code: str,
|
||||
repo_id: Optional[str] = None,
|
||||
model: Union[KModel, bool] = True,
|
||||
trf: bool = False,
|
||||
en_callable: Optional[Callable[[str], str]] = None,
|
||||
device: Optional[str] = None
|
||||
):
|
||||
"""Initialize a KPipeline.
|
||||
|
||||
Args:
|
||||
lang_code: Language code for G2P processing
|
||||
model: KModel instance, True to create new model, False for no model
|
||||
trf: Whether to use transformer-based G2P
|
||||
device: Override default device selection ('cuda' or 'cpu', or None for auto)
|
||||
If None, will auto-select cuda if available
|
||||
If 'cuda' and not available, will explicitly raise an error
|
||||
"""
|
||||
if repo_id is None:
|
||||
repo_id = 'hexgrad/Kokoro-82M'
|
||||
print(f"WARNING: Defaulting repo_id to {repo_id}. Pass repo_id='{repo_id}' to suppress this warning.")
|
||||
config=None
|
||||
else:
|
||||
config = os.path.join(repo_id, 'config.json')
|
||||
self.repo_id = repo_id
|
||||
lang_code = lang_code.lower()
|
||||
lang_code = ALIASES.get(lang_code, lang_code)
|
||||
assert lang_code in LANG_CODES, (lang_code, LANG_CODES)
|
||||
self.lang_code = lang_code
|
||||
self.model = None
|
||||
if isinstance(model, KModel):
|
||||
self.model = model
|
||||
elif model:
|
||||
if device == 'cuda' and not torch.cuda.is_available():
|
||||
raise RuntimeError("CUDA requested but not available")
|
||||
if device == 'mps' and not torch.backends.mps.is_available():
|
||||
raise RuntimeError("MPS requested but not available")
|
||||
if device == 'mps' and os.environ.get('PYTORCH_ENABLE_MPS_FALLBACK') != '1':
|
||||
raise RuntimeError("MPS requested but fallback not enabled")
|
||||
if device is None:
|
||||
if torch.cuda.is_available():
|
||||
device = 'cuda'
|
||||
elif os.environ.get('PYTORCH_ENABLE_MPS_FALLBACK') == '1' and torch.backends.mps.is_available():
|
||||
device = 'mps'
|
||||
else:
|
||||
device = 'cpu'
|
||||
try:
|
||||
self.model = KModel(repo_id=repo_id, config=config).to(device).eval()
|
||||
except RuntimeError as e:
|
||||
if device == 'cuda':
|
||||
raise RuntimeError(f"""Failed to initialize model on CUDA: {e}.
|
||||
Try setting device='cpu' or check CUDA installation.""")
|
||||
raise
|
||||
self.voices = {}
|
||||
if lang_code in 'ab':
|
||||
try:
|
||||
fallback = espeak.EspeakFallback(british=lang_code=='b')
|
||||
except Exception as e:
|
||||
logger.warning("EspeakFallback not Enabled: OOD words will be skipped")
|
||||
logger.warning({str(e)})
|
||||
fallback = None
|
||||
self.g2p = en.G2P(trf=trf, british=lang_code=='b', fallback=fallback, unk='')
|
||||
elif lang_code == 'j':
|
||||
try:
|
||||
from misaki import ja
|
||||
self.g2p = ja.JAG2P()
|
||||
except ImportError:
|
||||
logger.error("You need to `pip install misaki[ja]` to use lang_code='j'")
|
||||
raise
|
||||
elif lang_code == 'z':
|
||||
try:
|
||||
from misaki import zh
|
||||
self.g2p = zh.ZHG2P(
|
||||
version=None if repo_id.endswith('/Kokoro-82M') else '1.1',
|
||||
en_callable=en_callable
|
||||
)
|
||||
except ImportError:
|
||||
logger.error("You need to `pip install misaki[zh]` to use lang_code='z'")
|
||||
raise
|
||||
else:
|
||||
language = LANG_CODES[lang_code]
|
||||
logger.warning(f"Using EspeakG2P(language='{language}'). Chunking logic not yet implemented, so long texts may be truncated unless you split them with '\\n'.")
|
||||
self.g2p = espeak.EspeakG2P(language=language)
|
||||
|
||||
def load_single_voice(self, voice: str):
|
||||
if voice in self.voices:
|
||||
return self.voices[voice]
|
||||
if voice.endswith('.pt'):
|
||||
f = voice
|
||||
else:
|
||||
f = hf_hub_download(repo_id=self.repo_id, filename=f'voices/{voice}.pt')
|
||||
if not voice.startswith(self.lang_code):
|
||||
v = LANG_CODES.get(voice, voice)
|
||||
p = LANG_CODES.get(self.lang_code, self.lang_code)
|
||||
logger.warning(f'Language mismatch, loading {v} voice into {p} pipeline.')
|
||||
pack = torch.load(f, weights_only=True)
|
||||
self.voices[voice] = pack
|
||||
return pack
|
||||
|
||||
"""
|
||||
load_voice is a helper function that lazily downloads and loads a voice:
|
||||
Single voice can be requested (e.g. 'af_bella') or multiple voices (e.g. 'af_bella,af_jessica').
|
||||
If multiple voices are requested, they are averaged.
|
||||
Delimiter is optional and defaults to ','.
|
||||
"""
|
||||
def load_voice(self, voice: Union[str, torch.FloatTensor], delimiter: str = ",") -> torch.FloatTensor:
|
||||
if isinstance(voice, torch.FloatTensor):
|
||||
return voice
|
||||
if voice in self.voices:
|
||||
return self.voices[voice]
|
||||
logger.debug(f"Loading voice: {voice}")
|
||||
packs = [self.load_single_voice(v) for v in voice.split(delimiter)]
|
||||
if len(packs) == 1:
|
||||
return packs[0]
|
||||
self.voices[voice] = torch.mean(torch.stack(packs), dim=0)
|
||||
return self.voices[voice]
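# Illustrative note: load_voice('af_bella,af_jessica') (the pair named in the docstring
# above) downloads both packs via load_single_voice and caches their element-wise mean
# under the combined key, so blended voices are computed only once.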
|
||||
|
||||
@staticmethod
|
||||
def tokens_to_ps(tokens: List[en.MToken]) -> str:
|
||||
return ''.join(t.phonemes + (' ' if t.whitespace else '') for t in tokens).strip()
|
||||
|
||||
@staticmethod
|
||||
def waterfall_last(
|
||||
tokens: List[en.MToken],
|
||||
next_count: int,
|
||||
waterfall: List[str] = ['!.?…', ':;', ',—'],
|
||||
bumps: List[str] = [')', '”']
|
||||
) -> int:
|
||||
for w in waterfall:
|
||||
z = next((i for i, t in reversed(list(enumerate(tokens))) if t.phonemes in set(w)), None)
|
||||
if z is None:
|
||||
continue
|
||||
z += 1
|
||||
if z < len(tokens) and tokens[z].phonemes in bumps:
|
||||
z += 1
|
||||
if next_count - len(KPipeline.tokens_to_ps(tokens[:z])) <= 510:
|
||||
return z
|
||||
return len(tokens)
|
||||
|
||||
@staticmethod
|
||||
def tokens_to_text(tokens: List[en.MToken]) -> str:
|
||||
return ''.join(t.text + t.whitespace for t in tokens).strip()
|
||||
|
||||
def en_tokenize(
|
||||
self,
|
||||
tokens: List[en.MToken]
|
||||
) -> Generator[Tuple[str, str, List[en.MToken]], None, None]:
|
||||
tks = []
|
||||
pcount = 0
|
||||
for t in tokens:
|
||||
# American English: ɾ => T
|
||||
t.phonemes = '' if t.phonemes is None else t.phonemes  # .replace('ɾ', 'T')
|
||||
next_ps = t.phonemes + (' ' if t.whitespace else '')
|
||||
next_pcount = pcount + len(next_ps.rstrip())
|
||||
if next_pcount > 510:
|
||||
z = KPipeline.waterfall_last(tks, next_pcount)
|
||||
text = KPipeline.tokens_to_text(tks[:z])
|
||||
logger.debug(f"Chunking text at {z}: '{text[:30]}{'...' if len(text) > 30 else ''}'")
|
||||
ps = KPipeline.tokens_to_ps(tks[:z])
|
||||
yield text, ps, tks[:z]
|
||||
tks = tks[z:]
|
||||
pcount = len(KPipeline.tokens_to_ps(tks))
|
||||
if not tks:
|
||||
next_ps = next_ps.lstrip()
|
||||
tks.append(t)
|
||||
pcount += len(next_ps)
|
||||
if tks:
|
||||
text = KPipeline.tokens_to_text(tks)
|
||||
ps = KPipeline.tokens_to_ps(tks)
|
||||
yield ''.join(text).strip(), ''.join(ps).strip(), tks
|
||||
|
||||
@staticmethod
|
||||
def infer(
|
||||
model: KModel,
|
||||
ps: str,
|
||||
pack: torch.FloatTensor,
|
||||
speed: Union[float, Callable[[int], float]] = 1
|
||||
) -> KModel.Output:
|
||||
if callable(speed):
|
||||
speed = speed(len(ps))
|
||||
return model(ps, pack[len(ps)-1], speed, return_output=True)
|
||||
|
||||
def generate_from_tokens(
|
||||
self,
|
||||
tokens: Union[str, List[en.MToken]],
|
||||
voice: str,
|
||||
speed: float = 1,
|
||||
model: Optional[KModel] = None
|
||||
) -> Generator['KPipeline.Result', None, None]:
|
||||
"""Generate audio from either raw phonemes or pre-processed tokens.
|
||||
|
||||
Args:
|
||||
tokens: Either a phoneme string or list of pre-processed MTokens
|
||||
voice: The voice to use for synthesis
|
||||
speed: Speech speed modifier (default: 1)
|
||||
model: Optional KModel instance (uses pipeline's model if not provided)
|
||||
|
||||
Yields:
|
||||
KPipeline.Result containing the input tokens and generated audio
|
||||
|
||||
Raises:
|
||||
ValueError: If no voice is provided or token sequence exceeds model limits
|
||||
"""
|
||||
model = model or self.model
|
||||
if model and voice is None:
|
||||
raise ValueError('Specify a voice: pipeline.generate_from_tokens(..., voice="af_heart")')
|
||||
|
||||
pack = self.load_voice(voice).to(model.device) if model else None
|
||||
|
||||
# Handle raw phoneme string
|
||||
if isinstance(tokens, str):
|
||||
logger.debug("Processing phonemes from raw string")
|
||||
if len(tokens) > 510:
|
||||
raise ValueError(f'Phoneme string too long: {len(tokens)} > 510')
|
||||
output = KPipeline.infer(model, tokens, pack, speed) if model else None
|
||||
yield self.Result(graphemes='', phonemes=tokens, output=output)
|
||||
return
|
||||
|
||||
logger.debug("Processing MTokens")
|
||||
# Handle pre-processed tokens
|
||||
for gs, ps, tks in self.en_tokenize(tokens):
|
||||
if not ps:
|
||||
continue
|
||||
elif len(ps) > 510:
|
||||
logger.warning(f"Unexpected len(ps) == {len(ps)} > 510 and ps == '{ps}'")
|
||||
logger.warning("Truncating to 510 characters")
|
||||
ps = ps[:510]
|
||||
output = KPipeline.infer(model, ps, pack, speed) if model else None
|
||||
if output is not None and output.pred_dur is not None:
|
||||
KPipeline.join_timestamps(tks, output.pred_dur)
|
||||
yield self.Result(graphemes=gs, phonemes=ps, tokens=tks, output=output)
|
||||
|
||||
@staticmethod
|
||||
def join_timestamps(tokens: List[en.MToken], pred_dur: torch.LongTensor):
|
||||
# Multiply by 600 to go from pred_dur frames to sample_rate 24000
|
||||
# Equivalent to dividing pred_dur frames by 40 to get timestamp in seconds
|
||||
# We will count nice round half-frames, so the divisor is 80
|
||||
MAGIC_DIVISOR = 80
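# Worked example: one pred_dur frame is 600 samples at 24000 Hz, i.e. 1/40 s, so a token
# whose phonemes sum to 40 frames spans 2 * 40 = 80 half-frames = 80 / MAGIC_DIVISOR = 1.0 s.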
|
||||
if not tokens or len(pred_dur) < 3:
|
||||
# We expect at least 3: <bos>, token, <eos>
|
||||
return
|
||||
# We track 2 counts, measured in half-frames: (left, right)
|
||||
# This way we can cut space characters in half
|
||||
# TODO: Is -3 an appropriate offset?
|
||||
left = right = 2 * max(0, pred_dur[0].item() - 3)
|
||||
# Updates:
|
||||
# left = right + (2 * token_dur) + space_dur
|
||||
# right = left + space_dur
|
||||
i = 1
|
||||
for t in tokens:
|
||||
if i >= len(pred_dur)-1:
|
||||
break
|
||||
if not t.phonemes:
|
||||
if t.whitespace:
|
||||
i += 1
|
||||
left = right + pred_dur[i].item()
|
||||
right = left + pred_dur[i].item()
|
||||
i += 1
|
||||
continue
|
||||
j = i + len(t.phonemes)
|
||||
if j >= len(pred_dur):
|
||||
break
|
||||
t.start_ts = left / MAGIC_DIVISOR
|
||||
token_dur = pred_dur[i: j].sum().item()
|
||||
space_dur = pred_dur[j].item() if t.whitespace else 0
|
||||
left = right + (2 * token_dur) + space_dur
|
||||
t.end_ts = left / MAGIC_DIVISOR
|
||||
right = left + space_dur
|
||||
i = j + (1 if t.whitespace else 0)
|
||||
|
||||
@dataclass
|
||||
class Result:
|
||||
graphemes: str
|
||||
phonemes: str
|
||||
tokens: Optional[List[en.MToken]] = None
|
||||
output: Optional[KModel.Output] = None
|
||||
text_index: Optional[int] = None
|
||||
|
||||
@property
|
||||
def audio(self) -> Optional[torch.FloatTensor]:
|
||||
return None if self.output is None else self.output.audio
|
||||
|
||||
@property
|
||||
def pred_dur(self) -> Optional[torch.LongTensor]:
|
||||
return None if self.output is None else self.output.pred_dur
|
||||
|
||||
### MARK: BEGIN BACKWARD COMPAT ###
|
||||
def __iter__(self):
|
||||
yield self.graphemes
|
||||
yield self.phonemes
|
||||
yield self.audio
|
||||
|
||||
def __getitem__(self, index):
|
||||
return [self.graphemes, self.phonemes, self.audio][index]
|
||||
|
||||
def __len__(self):
|
||||
return 3
|
||||
#### MARK: END BACKWARD COMPAT ####
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
text: Union[str, List[str]],
|
||||
voice: Optional[str] = None,
|
||||
speed: Union[float, Callable[[int], float]] = 1,
|
||||
split_pattern: Optional[str] = r'\n+',
|
||||
model: Optional[KModel] = None
|
||||
) -> Generator['KPipeline.Result', None, None]:
|
||||
model = model or self.model
|
||||
if model and voice is None:
|
||||
raise ValueError('Specify a voice: en_us_pipeline(text="Hello world!", voice="af_heart")')
|
||||
pack = self.load_voice(voice).to(model.device) if model else None
|
||||
|
||||
# Convert input to list of segments
|
||||
if isinstance(text, str):
|
||||
text = re.split(split_pattern, text.strip()) if split_pattern else [text]
|
||||
|
||||
# Process each segment
|
||||
for graphemes_index, graphemes in enumerate(text):
|
||||
if not graphemes.strip(): # Skip empty segments
|
||||
continue
|
||||
|
||||
# English processing (unchanged)
|
||||
if self.lang_code in 'ab':
|
||||
logger.debug(f"Processing English text: {graphemes[:50]}{'...' if len(graphemes) > 50 else ''}")
|
||||
_, tokens = self.g2p(graphemes)
|
||||
for gs, ps, tks in self.en_tokenize(tokens):
|
||||
if not ps:
|
||||
continue
|
||||
elif len(ps) > 510:
|
||||
logger.warning(f"Unexpected len(ps) == {len(ps)} > 510 and ps == '{ps}'")
|
||||
ps = ps[:510]
|
||||
output = KPipeline.infer(model, ps, pack, speed) if model else None
|
||||
if output is not None and output.pred_dur is not None:
|
||||
KPipeline.join_timestamps(tks, output.pred_dur)
|
||||
yield self.Result(graphemes=gs, phonemes=ps, tokens=tks, output=output, text_index=graphemes_index)
|
||||
|
||||
# Non-English processing with chunking
|
||||
else:
|
||||
# Split long text into smaller chunks (roughly 400 characters each)
|
||||
# Using sentence boundaries when possible
|
||||
chunk_size = 400
|
||||
chunks = []
|
||||
|
||||
# Try to split on sentence boundaries first
|
||||
sentences = re.split(r'([.!?]+)', graphemes)
|
||||
current_chunk = ""
|
||||
|
||||
for i in range(0, len(sentences), 2):
|
||||
sentence = sentences[i]
|
||||
# Add the punctuation back if it exists
|
||||
if i + 1 < len(sentences):
|
||||
sentence += sentences[i + 1]
|
||||
|
||||
if len(current_chunk) + len(sentence) <= chunk_size:
|
||||
current_chunk += sentence
|
||||
else:
|
||||
if current_chunk:
|
||||
chunks.append(current_chunk.strip())
|
||||
current_chunk = sentence
|
||||
|
||||
if current_chunk:
|
||||
chunks.append(current_chunk.strip())
|
||||
|
||||
# If no chunks were created (no sentence boundaries), fall back to character-based chunking
|
||||
if not chunks:
|
||||
chunks = [graphemes[i:i+chunk_size] for i in range(0, len(graphemes), chunk_size)]
|
||||
|
||||
# Process each chunk
|
||||
for chunk in chunks:
|
||||
if not chunk.strip():
|
||||
continue
|
||||
|
||||
ps, _ = self.g2p(chunk)
|
||||
if not ps:
|
||||
continue
|
||||
elif len(ps) > 510:
|
||||
logger.warning(f'Truncating len(ps) == {len(ps)} > 510')
|
||||
ps = ps[:510]
|
||||
|
||||
output = KPipeline.infer(model, ps, pack, speed) if model else None
|
||||
yield self.Result(graphemes=chunk, phonemes=ps, output=output, text_index=graphemes_index)
|
||||
319
wan/multitalk/multitalk.py
Normal file
@ -0,0 +1,319 @@
|
||||
import random
|
||||
import os
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from PIL import Image
|
||||
import subprocess
|
||||
import torchvision.transforms as transforms
|
||||
import torch.nn.functional as F
|
||||
import torch.nn as nn
|
||||
import wan
|
||||
from wan.configs import SIZE_CONFIGS, SUPPORTED_SIZES, WAN_CONFIGS
|
||||
from wan.utils.utils import cache_image, cache_video, str2bool
|
||||
# from wan.utils.multitalk_utils import save_video_ffmpeg
|
||||
# from .kokoro import KPipeline
|
||||
from transformers import Wav2Vec2FeatureExtractor
|
||||
from .wav2vec2 import Wav2Vec2Model
|
||||
|
||||
import librosa
|
||||
import pyloudnorm as pyln
|
||||
import numpy as np
|
||||
from einops import rearrange
|
||||
import soundfile as sf
|
||||
import re
|
||||
import math
|
||||
|
||||
def custom_init(device, wav2vec):
|
||||
audio_encoder = Wav2Vec2Model.from_pretrained(wav2vec, local_files_only=True).to(device)
|
||||
audio_encoder.feature_extractor._freeze_parameters()
|
||||
wav2vec_feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(wav2vec, local_files_only=True)
|
||||
return wav2vec_feature_extractor, audio_encoder
|
||||
|
||||
def loudness_norm(audio_array, sr=16000, lufs=-23):
|
||||
meter = pyln.Meter(sr)
|
||||
loudness = meter.integrated_loudness(audio_array)
|
||||
if abs(loudness) > 100:
|
||||
return audio_array
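# (pyloudnorm reports -inf LUFS for silent input, so |loudness| > 100 acts as a guard
# against normalizing silence or degenerate measurements.)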
|
||||
normalized_audio = pyln.normalize.loudness(audio_array, loudness, lufs)
|
||||
return normalized_audio
|
||||
|
||||
|
||||
def get_embedding(speech_array, wav2vec_feature_extractor, audio_encoder, sr=16000, device='cpu', fps = 25):
|
||||
audio_duration = len(speech_array) / sr
|
||||
video_length = audio_duration * fps
|
||||
|
||||
# wav2vec_feature_extractor
|
||||
audio_feature = np.squeeze(
|
||||
wav2vec_feature_extractor(speech_array, sampling_rate=sr).input_values
|
||||
)
|
||||
audio_feature = torch.from_numpy(audio_feature).float().to(device=device)
|
||||
audio_feature = audio_feature.unsqueeze(0)
|
||||
|
||||
# audio encoder
|
||||
with torch.no_grad():
|
||||
embeddings = audio_encoder(audio_feature, seq_len=int(video_length), output_hidden_states=True)
|
||||
|
||||
if len(embeddings) == 0:
|
||||
print("Fail to extract audio embedding")
|
||||
return None
|
||||
|
||||
audio_emb = torch.stack(embeddings.hidden_states[1:], dim=1).squeeze(0)
|
||||
audio_emb = rearrange(audio_emb, "b s d -> s b d")
|
||||
|
||||
audio_emb = audio_emb.cpu().detach()
|
||||
return audio_emb
|
||||
|
||||
def audio_prepare_single(audio_path, sample_rate=16000, duration = 0):
|
||||
ext = os.path.splitext(audio_path)[1].lower()
|
||||
if ext in ['.mp4', '.mov', '.avi', '.mkv']:
|
||||
human_speech_array = extract_audio_from_video(audio_path, sample_rate)
|
||||
return human_speech_array
|
||||
else:
|
||||
human_speech_array, sr = librosa.load(audio_path, duration=duration, sr=sample_rate)
|
||||
human_speech_array = loudness_norm(human_speech_array, sr)
|
||||
return human_speech_array
|
||||
|
||||
|
||||
def audio_prepare_multi(left_path, right_path, audio_type = "add", sample_rate=16000, duration = 0):
|
||||
if left_path is not None and right_path is not None:
human_speech_array1 = audio_prepare_single(left_path, duration=duration)
human_speech_array2 = audio_prepare_single(right_path, duration=duration)
elif left_path is None:
human_speech_array2 = audio_prepare_single(right_path, duration=duration)
human_speech_array1 = np.zeros(human_speech_array2.shape[0])
elif right_path is None:
|
||||
human_speech_array1 = audio_prepare_single(left_path, duration = duration)
|
||||
human_speech_array2 = np.zeros(human_speech_array1.shape[0])
|
||||
|
||||
if audio_type=='para':
|
||||
new_human_speech1 = human_speech_array1
|
||||
new_human_speech2 = human_speech_array2
|
||||
elif audio_type=='add':
|
||||
new_human_speech1 = np.concatenate([human_speech_array1[: human_speech_array1.shape[0]], np.zeros(human_speech_array2.shape[0])])
|
||||
new_human_speech2 = np.concatenate([np.zeros(human_speech_array1.shape[0]), human_speech_array2[:human_speech_array2.shape[0]]])
|
||||
sum_human_speechs = new_human_speech1 + new_human_speech2
|
||||
return new_human_speech1, new_human_speech2, sum_human_speechs
|
||||
|
||||
def process_tts_single(text, save_dir, voice1):
|
||||
s1_sentences = []
|
||||
|
||||
pipeline = KPipeline(lang_code='a', repo_id='weights/Kokoro-82M')
|
||||
|
||||
voice_tensor = torch.load(voice1, weights_only=True)
|
||||
generator = pipeline(
|
||||
text, voice=voice_tensor, # <= change voice here
|
||||
speed=1, split_pattern=r'\n+'
|
||||
)
|
||||
audios = []
|
||||
for i, (gs, ps, audio) in enumerate(generator):
|
||||
audios.append(audio)
|
||||
audios = torch.concat(audios, dim=0)
|
||||
s1_sentences.append(audios)
|
||||
s1_sentences = torch.concat(s1_sentences, dim=0)
|
||||
save_path1 = f'{save_dir}/s1.wav'
|
||||
sf.write(save_path1, s1_sentences, 24000) # save each audio file
|
||||
s1, _ = librosa.load(save_path1, sr=16000)
|
||||
return s1, save_path1
|
||||
|
||||
|
||||
|
||||
def process_tts_multi(text, save_dir, voice1, voice2):
|
||||
pattern = r'\(s(\d+)\)\s*(.*?)(?=\s*\(s\d+\)|$)'
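# The text is expected to interleave speaker tags, e.g. "(s1) Hello there. (s2) Hi!",
# and each regex match yields (speaker_number, utterance_text).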
|
||||
matches = re.findall(pattern, text, re.DOTALL)
|
||||
|
||||
s1_sentences = []
|
||||
s2_sentences = []
|
||||
|
||||
pipeline = KPipeline(lang_code='a', repo_id='weights/Kokoro-82M')
|
||||
for idx, (speaker, content) in enumerate(matches):
|
||||
if speaker == '1':
|
||||
voice_tensor = torch.load(voice1, weights_only=True)
|
||||
generator = pipeline(
|
||||
content, voice=voice_tensor, # <= change voice here
|
||||
speed=1, split_pattern=r'\n+'
|
||||
)
|
||||
audios = []
|
||||
for i, (gs, ps, audio) in enumerate(generator):
|
||||
audios.append(audio)
|
||||
audios = torch.concat(audios, dim=0)
|
||||
s1_sentences.append(audios)
|
||||
s2_sentences.append(torch.zeros_like(audios))
|
||||
elif speaker == '2':
|
||||
voice_tensor = torch.load(voice2, weights_only=True)
|
||||
generator = pipeline(
|
||||
content, voice=voice_tensor, # <= change voice here
|
||||
speed=1, split_pattern=r'\n+'
|
||||
)
|
||||
audios = []
|
||||
for i, (gs, ps, audio) in enumerate(generator):
|
||||
audios.append(audio)
|
||||
audios = torch.concat(audios, dim=0)
|
||||
s2_sentences.append(audios)
|
||||
s1_sentences.append(torch.zeros_like(audios))
|
||||
|
||||
s1_sentences = torch.concat(s1_sentences, dim=0)
|
||||
s2_sentences = torch.concat(s2_sentences, dim=0)
|
||||
sum_sentences = s1_sentences + s2_sentences
|
||||
save_path1 = f'{save_dir}/s1.wav'
save_path2 = f'{save_dir}/s2.wav'
|
||||
save_path_sum = f'{save_dir}/sum.wav'
|
||||
sf.write(save_path1, s1_sentences, 24000) # save each audio file
|
||||
sf.write(save_path2, s2_sentences, 24000)
|
||||
sf.write(save_path_sum, sum_sentences, 24000)
|
||||
|
||||
s1, _ = librosa.load(save_path1, sr=16000)
|
||||
s2, _ = librosa.load(save_path2, sr=16000)
|
||||
# sum, _ = librosa.load(save_path_sum, sr=16000)
|
||||
return s1, s2, save_path_sum
|
||||
|
||||
|
||||
def get_full_audio_embeddings(audio_guide1 = None, audio_guide2 = None, combination_type ="add", num_frames = 0, fps = 25, sr = 16000):
|
||||
wav2vec_feature_extractor, audio_encoder= custom_init('cpu', "ckpts/chinese-wav2vec2-base")
|
||||
# wav2vec_feature_extractor, audio_encoder= custom_init('cpu', "ckpts/wav2vec")
|
||||
|
||||
new_human_speech1, new_human_speech2, sum_human_speechs = audio_prepare_multi(audio_guide1, audio_guide2, combination_type, duration= num_frames / fps)
|
||||
audio_embedding_1 = get_embedding(new_human_speech1, wav2vec_feature_extractor, audio_encoder, sr=sr, fps= fps)
|
||||
audio_embedding_2 = get_embedding(new_human_speech2, wav2vec_feature_extractor, audio_encoder, sr=sr, fps= fps)
|
||||
|
||||
full_audio_embs = []
|
||||
if audio_guide1 is not None: full_audio_embs.append(audio_embedding_1)
if audio_guide2 is not None: full_audio_embs.append(audio_embedding_2)
if audio_guide2 is None: sum_human_speechs = None
|
||||
return full_audio_embs, sum_human_speechs
|
||||
|
||||
|
||||
def get_window_audio_embeddings(full_audio_embs, audio_start_idx=0, clip_length = 81, vae_scale = 4, audio_window = 5):
|
||||
HUMAN_NUMBER = len(full_audio_embs)
|
||||
audio_end_idx = audio_start_idx + clip_length
|
||||
indices = (torch.arange(2 * 2 + 1) - 2) * 1
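# indices == tensor([-2, -1, 0, 1, 2]): offsets for a 5-frame audio context window
# around each frame index, matching the default audio_window of 5 used downstream.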
|
||||
|
||||
audio_embs = []
|
||||
# split audio with window size
|
||||
for human_idx in range(HUMAN_NUMBER):
|
||||
center_indices = torch.arange(
|
||||
audio_start_idx,
|
||||
audio_end_idx,
|
||||
1
|
||||
).unsqueeze(
|
||||
1
|
||||
) + indices.unsqueeze(0)
|
||||
center_indices = torch.clamp(center_indices, min=0, max=full_audio_embs[human_idx].shape[0]-1).to(full_audio_embs[human_idx].device)
|
||||
audio_emb = full_audio_embs[human_idx][center_indices][None,...] #.to(self.device)
|
||||
audio_embs.append(audio_emb)
|
||||
audio_embs = torch.concat(audio_embs, dim=0) #.to(self.param_dtype)
|
||||
|
||||
# audio_cond = audio.to(device=x.device, dtype=x.dtype)
|
||||
audio_cond = audio_embs
|
||||
first_frame_audio_emb_s = audio_cond[:, :1, ...]
|
||||
latter_frame_audio_emb = audio_cond[:, 1:, ...]
|
||||
latter_frame_audio_emb = rearrange(latter_frame_audio_emb, "b (n_t n) w s c -> b n_t n w s c", n=vae_scale)
|
||||
middle_index = audio_window // 2
|
||||
latter_first_frame_audio_emb = latter_frame_audio_emb[:, :, :1, :middle_index+1, ...]
|
||||
latter_first_frame_audio_emb = rearrange(latter_first_frame_audio_emb, "b n_t n w s c -> b n_t (n w) s c")
|
||||
latter_last_frame_audio_emb = latter_frame_audio_emb[:, :, -1:, middle_index:, ...]
|
||||
latter_last_frame_audio_emb = rearrange(latter_last_frame_audio_emb, "b n_t n w s c -> b n_t (n w) s c")
|
||||
latter_middle_frame_audio_emb = latter_frame_audio_emb[:, :, 1:-1, middle_index:middle_index+1, ...]
|
||||
latter_middle_frame_audio_emb = rearrange(latter_middle_frame_audio_emb, "b n_t n w s c -> b n_t (n w) s c")
|
||||
latter_frame_audio_emb_s = torch.concat([latter_first_frame_audio_emb, latter_middle_frame_audio_emb, latter_last_frame_audio_emb], dim=2)
|
||||
|
||||
return [first_frame_audio_emb_s, latter_frame_audio_emb_s]
|
||||
|
||||
def resize_and_centercrop(cond_image, target_size):
|
||||
"""
|
||||
Resize image or tensor to the target size without padding.
|
||||
"""
|
||||
|
||||
# Get the original size
|
||||
if isinstance(cond_image, torch.Tensor):
|
||||
_, orig_h, orig_w = cond_image.shape
|
||||
else:
|
||||
orig_h, orig_w = cond_image.height, cond_image.width
|
||||
|
||||
target_h, target_w = target_size
|
||||
|
||||
# Calculate the scaling factor for resizing
|
||||
scale_h = target_h / orig_h
|
||||
scale_w = target_w / orig_w
|
||||
|
||||
# Compute the final size
|
||||
scale = max(scale_h, scale_w)
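# Scaling by the larger factor guarantees both resized dimensions are >= the target,
# so the center crop below never needs padding.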
|
||||
final_h = math.ceil(scale * orig_h)
|
||||
final_w = math.ceil(scale * orig_w)
|
||||
|
||||
# Resize
|
||||
if isinstance(cond_image, torch.Tensor):
|
||||
if len(cond_image.shape) == 3:
|
||||
cond_image = cond_image[None]
|
||||
resized_tensor = nn.functional.interpolate(cond_image, size=(final_h, final_w), mode='nearest').contiguous()
|
||||
# crop
|
||||
cropped_tensor = transforms.functional.center_crop(resized_tensor, target_size)
|
||||
cropped_tensor = cropped_tensor.squeeze(0)
|
||||
else:
|
||||
resized_image = cond_image.resize((final_w, final_h), resample=Image.BILINEAR)
|
||||
resized_image = np.array(resized_image)
|
||||
# tensor and crop
|
||||
resized_tensor = torch.from_numpy(resized_image)[None, ...].permute(0, 3, 1, 2).contiguous()
|
||||
cropped_tensor = transforms.functional.center_crop(resized_tensor, target_size)
|
||||
cropped_tensor = cropped_tensor[:, :, None, :, :]
|
||||
|
||||
return cropped_tensor
|
||||
|
||||
|
||||
def timestep_transform(
|
||||
t,
|
||||
shift=5.0,
|
||||
num_timesteps=1000,
|
||||
):
|
||||
t = t / num_timesteps
|
||||
# shift the timestep based on ratio
|
||||
new_t = shift * t / (1 + (shift - 1) * t)
|
||||
new_t = new_t * num_timesteps
|
||||
return new_t
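# Worked example: with shift=5.0 and num_timesteps=1000, t=500 maps to
# 1000 * (5 * 0.5) / (1 + 4 * 0.5) ≈ 833, pushing mid-range timesteps toward the noisy end.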
|
||||
|
||||
|
||||
# construct human mask
|
||||
def get_target_masks(HUMAN_NUMBER, lat_h, lat_w, src_h, src_w, face_scale = 0.05, bbox = None):
|
||||
human_masks = []
|
||||
if HUMAN_NUMBER==1:
|
||||
background_mask = torch.ones([src_h, src_w])
|
||||
human_mask1 = torch.ones([src_h, src_w])
|
||||
human_mask2 = torch.ones([src_h, src_w])
|
||||
human_masks = [human_mask1, human_mask2, background_mask]
|
||||
elif HUMAN_NUMBER==2:
|
||||
if bbox is not None:
assert len(bbox) == HUMAN_NUMBER, "The number of target bboxes must match the number of cond_audio entries"
|
||||
background_mask = torch.zeros([src_h, src_w])
|
||||
for _, person_bbox in bbox.items():
|
||||
x_min, y_min, x_max, y_max = person_bbox
|
||||
human_mask = torch.zeros([src_h, src_w])
|
||||
human_mask[int(x_min):int(x_max), int(y_min):int(y_max)] = 1
|
||||
background_mask += human_mask
|
||||
human_masks.append(human_mask)
|
||||
else:
|
||||
x_min, x_max = int(src_h * face_scale), int(src_h * (1 - face_scale))
|
||||
background_mask = torch.zeros([src_h, src_w])
|
||||
human_mask1 = torch.zeros([src_h, src_w])
|
||||
human_mask2 = torch.zeros([src_h, src_w])
|
||||
lefty_min, lefty_max = int((src_w//2) * face_scale), int((src_w//2) * (1 - face_scale))
|
||||
righty_min, righty_max = int((src_w//2) * face_scale + (src_w//2)), int((src_w//2) * (1 - face_scale) + (src_w//2))
|
||||
human_mask1[x_min:x_max, lefty_min:lefty_max] = 1
|
||||
human_mask2[x_min:x_max, righty_min:righty_max] = 1
|
||||
background_mask += human_mask1
|
||||
background_mask += human_mask2
|
||||
human_masks = [human_mask1, human_mask2]
|
||||
background_mask = torch.where(background_mask > 0, torch.tensor(0), torch.tensor(1))
|
||||
human_masks.append(background_mask)
|
||||
|
||||
ref_target_masks = torch.stack(human_masks, dim=0) #.to(self.device)
|
||||
# resize and centercrop for ref_target_masks
|
||||
# ref_target_masks = resize_and_centercrop(ref_target_masks, (target_h, target_w))
|
||||
N_h, N_w = lat_h // 2, lat_w // 2
|
||||
token_ref_target_masks = F.interpolate(ref_target_masks.unsqueeze(0), size=(N_h, N_w), mode='nearest').squeeze()
|
||||
token_ref_target_masks = (token_ref_target_masks > 0)
|
||||
token_ref_target_masks = token_ref_target_masks.float() #.to(self.device)
|
||||
|
||||
token_ref_target_masks = token_ref_target_masks.view(token_ref_target_masks.shape[0], -1)
|
||||
|
||||
return token_ref_target_masks
|
||||
799
wan/multitalk/multitalk_model.py
Normal file
@ -0,0 +1,799 @@
|
||||
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
|
||||
import math
|
||||
import numpy as np
|
||||
import os
|
||||
import torch
|
||||
import torch.cuda.amp as amp
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from einops import rearrange
|
||||
from diffusers import ModelMixin
|
||||
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
||||
|
||||
from .attention import flash_attention, SingleStreamMutiAttention
|
||||
from ..utils.multitalk_utils import get_attn_map_with_target
|
||||
|
||||
__all__ = ['WanModel']
|
||||
|
||||
|
||||
|
||||
def sinusoidal_embedding_1d(dim, position):
|
||||
# preprocess
|
||||
assert dim % 2 == 0
|
||||
half = dim // 2
|
||||
position = position.type(torch.float64)
|
||||
|
||||
# calculation
|
||||
sinusoid = torch.outer(
|
||||
position, torch.pow(10000, -torch.arange(half).to(position).div(half)))
|
||||
x = torch.cat([torch.cos(sinusoid), torch.sin(sinusoid)], dim=1)
|
||||
return x
|
||||
|
||||
|
||||
@amp.autocast(enabled=False)
|
||||
def rope_params(max_seq_len, dim, theta=10000):
|
||||
|
||||
assert dim % 2 == 0
|
||||
freqs = torch.outer(
|
||||
torch.arange(max_seq_len),
|
||||
1.0 / torch.pow(theta,
|
||||
torch.arange(0, dim, 2).to(torch.float64).div(dim)))
|
||||
freqs = torch.polar(torch.ones_like(freqs), freqs)
|
||||
return freqs
|
||||
|
||||
|
||||
@amp.autocast(enabled=False)
|
||||
def rope_apply(x, grid_sizes, freqs):
|
||||
s, n, c = x.size(1), x.size(2), x.size(3) // 2
|
||||
|
||||
freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)
|
||||
|
||||
output = []
|
||||
for i, (f, h, w) in enumerate(grid_sizes.tolist()):
|
||||
seq_len = f * h * w
|
||||
|
||||
x_i = torch.view_as_complex(x[i, :s].to(torch.float64).reshape(
|
||||
s, n, -1, 2))
|
||||
freqs_i = torch.cat([
|
||||
freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
|
||||
freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
|
||||
freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1)
|
||||
],
|
||||
dim=-1).reshape(seq_len, 1, -1)
|
||||
freqs_i = freqs_i.to(device=x_i.device)
|
||||
x_i = torch.view_as_real(x_i * freqs_i).flatten(2)
|
||||
x_i = torch.cat([x_i, x[i, seq_len:]])
|
||||
|
||||
output.append(x_i)
|
||||
return torch.stack(output).float()
|
||||
|
||||
|
||||
class WanRMSNorm(nn.Module):
|
||||
|
||||
def __init__(self, dim, eps=1e-5):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.eps = eps
|
||||
self.weight = nn.Parameter(torch.ones(dim))
|
||||
|
||||
def forward(self, x):
|
||||
r"""
|
||||
Args:
|
||||
x(Tensor): Shape [B, L, C]
|
||||
"""
|
||||
return self._norm(x.float()).type_as(x) * self.weight
|
||||
|
||||
def _norm(self, x):
|
||||
return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
|
||||
|
||||
|
||||
class WanLayerNorm(nn.LayerNorm):
|
||||
|
||||
def __init__(self, dim, eps=1e-6, elementwise_affine=False):
|
||||
super().__init__(dim, elementwise_affine=elementwise_affine, eps=eps)
|
||||
|
||||
def forward(self, inputs: torch.Tensor) -> torch.Tensor:
|
||||
origin_dtype = inputs.dtype
|
||||
out = F.layer_norm(
|
||||
inputs.float(),
|
||||
self.normalized_shape,
|
||||
None if self.weight is None else self.weight.float(),
|
||||
None if self.bias is None else self.bias.float() ,
|
||||
self.eps
|
||||
).to(origin_dtype)
|
||||
return out
|
||||
|
||||
|
||||
class WanSelfAttention(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
dim,
|
||||
num_heads,
|
||||
window_size=(-1, -1),
|
||||
qk_norm=True,
|
||||
eps=1e-6):
|
||||
assert dim % num_heads == 0
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.num_heads = num_heads
|
||||
self.head_dim = dim // num_heads
|
||||
self.window_size = window_size
|
||||
self.qk_norm = qk_norm
|
||||
self.eps = eps
|
||||
|
||||
# layers
|
||||
self.q = nn.Linear(dim, dim)
|
||||
self.k = nn.Linear(dim, dim)
|
||||
self.v = nn.Linear(dim, dim)
|
||||
self.o = nn.Linear(dim, dim)
|
||||
self.norm_q = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity()
|
||||
self.norm_k = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity()
|
||||
|
||||
def forward(self, x, seq_lens, grid_sizes, freqs, ref_target_masks=None):
|
||||
b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim
|
||||
|
||||
# query, key, value function
|
||||
def qkv_fn(x):
|
||||
q = self.norm_q(self.q(x)).view(b, s, n, d)
|
||||
k = self.norm_k(self.k(x)).view(b, s, n, d)
|
||||
v = self.v(x).view(b, s, n, d)
|
||||
return q, k, v
|
||||
q, k, v = qkv_fn(x)
|
||||
|
||||
q = rope_apply(q, grid_sizes, freqs)
|
||||
k = rope_apply(k, grid_sizes, freqs)
|
||||
|
||||
|
||||
x = flash_attention(
|
||||
q=q,
|
||||
k=k,
|
||||
v=v,
|
||||
k_lens=seq_lens,
|
||||
window_size=self.window_size
|
||||
).type_as(x)
|
||||
|
||||
# output
|
||||
x = x.flatten(2)
|
||||
x = self.o(x)
|
||||
with torch.no_grad():
|
||||
x_ref_attn_map = get_attn_map_with_target(q.type_as(x), k.type_as(x), grid_sizes[0],
|
||||
ref_target_masks=ref_target_masks)
|
||||
|
||||
return x, x_ref_attn_map
|
||||
|
||||
|
||||
class WanI2VCrossAttention(WanSelfAttention):
|
||||
|
||||
def __init__(self,
|
||||
dim,
|
||||
num_heads,
|
||||
window_size=(-1, -1),
|
||||
qk_norm=True,
|
||||
eps=1e-6):
|
||||
super().__init__(dim, num_heads, window_size, qk_norm, eps)
|
||||
|
||||
self.k_img = nn.Linear(dim, dim)
|
||||
self.v_img = nn.Linear(dim, dim)
|
||||
self.norm_k_img = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity()
|
||||
|
||||
def forward(self, x, context, context_lens):
|
||||
context_img = context[:, :257]
|
||||
context = context[:, 257:]
|
||||
b, n, d = x.size(0), self.num_heads, self.head_dim
|
||||
|
||||
# compute query, key, value
|
||||
q = self.norm_q(self.q(x)).view(b, -1, n, d)
|
||||
k = self.norm_k(self.k(context)).view(b, -1, n, d)
|
||||
v = self.v(context).view(b, -1, n, d)
|
||||
k_img = self.norm_k_img(self.k_img(context_img)).view(b, -1, n, d)
|
||||
v_img = self.v_img(context_img).view(b, -1, n, d)
|
||||
img_x = flash_attention(q, k_img, v_img, k_lens=None)
|
||||
# compute attention
|
||||
x = flash_attention(q, k, v, k_lens=context_lens)
|
||||
|
||||
# output
|
||||
x = x.flatten(2)
|
||||
img_x = img_x.flatten(2)
|
||||
x = x + img_x
|
||||
x = self.o(x)
|
||||
return x
|
||||
|
||||
|
||||
class WanAttentionBlock(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
cross_attn_type,
|
||||
dim,
|
||||
ffn_dim,
|
||||
num_heads,
|
||||
window_size=(-1, -1),
|
||||
qk_norm=True,
|
||||
cross_attn_norm=False,
|
||||
eps=1e-6,
|
||||
output_dim=768,
|
||||
norm_input_visual=True,
|
||||
class_range=24,
|
||||
class_interval=4):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.ffn_dim = ffn_dim
|
||||
self.num_heads = num_heads
|
||||
self.window_size = window_size
|
||||
self.qk_norm = qk_norm
|
||||
self.cross_attn_norm = cross_attn_norm
|
||||
self.eps = eps
|
||||
|
||||
# layers
|
||||
self.norm1 = WanLayerNorm(dim, eps)
|
||||
self.self_attn = WanSelfAttention(dim, num_heads, window_size, qk_norm, eps)
|
||||
self.norm3 = WanLayerNorm(
|
||||
dim, eps,
|
||||
elementwise_affine=True) if cross_attn_norm else nn.Identity()
|
||||
self.cross_attn = WanI2VCrossAttention(dim,
|
||||
num_heads,
|
||||
(-1, -1),
|
||||
qk_norm,
|
||||
eps)
|
||||
self.norm2 = WanLayerNorm(dim, eps)
|
||||
self.ffn = nn.Sequential(
|
||||
nn.Linear(dim, ffn_dim), nn.GELU(approximate='tanh'),
|
||||
nn.Linear(ffn_dim, dim))
|
||||
|
||||
# modulation
|
||||
self.modulation = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5)
|
||||
|
||||
# init audio module
|
||||
self.audio_cross_attn = SingleStreamMutiAttention(
|
||||
dim=dim,
|
||||
encoder_hidden_states_dim=output_dim,
|
||||
num_heads=num_heads,
|
||||
qk_norm=False,
|
||||
qkv_bias=True,
|
||||
eps=eps,
|
||||
norm_layer=WanRMSNorm,
|
||||
class_range=class_range,
|
||||
class_interval=class_interval
|
||||
)
|
||||
self.norm_x = WanLayerNorm(dim, eps, elementwise_affine=True) if norm_input_visual else nn.Identity()
|
||||
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x,
|
||||
e,
|
||||
seq_lens,
|
||||
grid_sizes,
|
||||
freqs,
|
||||
context,
|
||||
context_lens,
|
||||
audio_embedding=None,
|
||||
ref_target_masks=None,
|
||||
human_num=None,
|
||||
):
|
||||
|
||||
dtype = x.dtype
|
||||
assert e.dtype == torch.float32
|
||||
with amp.autocast(dtype=torch.float32):
|
||||
e = (self.modulation.to(e.device) + e).chunk(6, dim=1)
|
||||
assert e[0].dtype == torch.float32
|
||||
|
||||
# self-attention
|
||||
y, x_ref_attn_map = self.self_attn(
|
||||
(self.norm1(x).float() * (1 + e[1]) + e[0]).type_as(x), seq_lens, grid_sizes,
|
||||
freqs, ref_target_masks=ref_target_masks)
|
||||
with amp.autocast(dtype=torch.float32):
|
||||
x = x + y * e[2]
|
||||
|
||||
x = x.to(dtype)
|
||||
|
||||
# cross-attention of text
|
||||
x = x + self.cross_attn(self.norm3(x), context, context_lens)
|
||||
|
||||
# cross attn of audio
|
||||
x_a = self.audio_cross_attn(self.norm_x(x), encoder_hidden_states=audio_embedding,
|
||||
shape=grid_sizes[0], x_ref_attn_map=x_ref_attn_map, human_num=human_num)
|
||||
x = x + x_a
|
||||
|
||||
y = self.ffn((self.norm2(x).float() * (1 + e[4]) + e[3]).to(dtype))
|
||||
with amp.autocast(dtype=torch.float32):
|
||||
x = x + y * e[5]
|
||||
|
||||
|
||||
x = x.to(dtype)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class Head(nn.Module):
|
||||
|
||||
def __init__(self, dim, out_dim, patch_size, eps=1e-6):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.out_dim = out_dim
|
||||
self.patch_size = patch_size
|
||||
self.eps = eps
|
||||
|
||||
# layers
|
||||
out_dim = math.prod(patch_size) * out_dim
|
||||
self.norm = WanLayerNorm(dim, eps)
|
||||
self.head = nn.Linear(dim, out_dim)
|
||||
|
||||
# modulation
|
||||
self.modulation = nn.Parameter(torch.randn(1, 2, dim) / dim**0.5)
|
||||
|
||||
def forward(self, x, e):
|
||||
r"""
|
||||
Args:
|
||||
x(Tensor): Shape [B, L1, C]
|
||||
e(Tensor): Shape [B, C]
|
||||
"""
|
||||
assert e.dtype == torch.float32
|
||||
with amp.autocast(dtype=torch.float32):
|
||||
e = (self.modulation.to(e.device) + e.unsqueeze(1)).chunk(2, dim=1)
|
||||
x = (self.head(self.norm(x) * (1 + e[1]) + e[0]))
|
||||
return x
|
||||
|
||||
|
||||
class MLPProj(torch.nn.Module):
|
||||
|
||||
def __init__(self, in_dim, out_dim):
|
||||
super().__init__()
|
||||
|
||||
self.proj = torch.nn.Sequential(
|
||||
torch.nn.LayerNorm(in_dim), torch.nn.Linear(in_dim, in_dim),
|
||||
torch.nn.GELU(), torch.nn.Linear(in_dim, out_dim),
|
||||
torch.nn.LayerNorm(out_dim))
|
||||
|
||||
def forward(self, image_embeds):
|
||||
clip_extra_context_tokens = self.proj(image_embeds)
|
||||
return clip_extra_context_tokens
|
||||
|
||||
|
||||
class AudioProjModel(ModelMixin, ConfigMixin):
|
||||
def __init__(
|
||||
self,
|
||||
seq_len=5,
|
||||
seq_len_vf=12,
|
||||
blocks=12,
|
||||
channels=768,
|
||||
intermediate_dim=512,
|
||||
output_dim=768,
|
||||
context_tokens=32,
|
||||
norm_output_audio=False,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.seq_len = seq_len
|
||||
self.blocks = blocks
|
||||
self.channels = channels
|
||||
self.input_dim = seq_len * blocks * channels
|
||||
self.input_dim_vf = seq_len_vf * blocks * channels
|
||||
self.intermediate_dim = intermediate_dim
|
||||
self.context_tokens = context_tokens
|
||||
self.output_dim = output_dim
|
||||
|
||||
# define multiple linear layers
|
||||
self.proj1 = nn.Linear(self.input_dim, intermediate_dim)
|
||||
self.proj1_vf = nn.Linear(self.input_dim_vf, intermediate_dim)
|
||||
self.proj2 = nn.Linear(intermediate_dim, intermediate_dim)
|
||||
self.proj3 = nn.Linear(intermediate_dim, context_tokens * output_dim)
|
||||
self.norm = nn.LayerNorm(output_dim) if norm_output_audio else nn.Identity()
|
||||
|
||||
def forward(self, audio_embeds, audio_embeds_vf):
|
||||
video_length = audio_embeds.shape[1] + audio_embeds_vf.shape[1]
|
||||
B, _, _, S, C = audio_embeds.shape
|
||||
|
||||
# process audio of first frame
|
||||
audio_embeds = rearrange(audio_embeds, "bz f w b c -> (bz f) w b c")
|
||||
batch_size, window_size, blocks, channels = audio_embeds.shape
|
||||
audio_embeds = audio_embeds.view(batch_size, window_size * blocks * channels)
|
||||
|
||||
# process audio of latter frame
|
||||
audio_embeds_vf = rearrange(audio_embeds_vf, "bz f w b c -> (bz f) w b c")
|
||||
batch_size_vf, window_size_vf, blocks_vf, channels_vf = audio_embeds_vf.shape
|
||||
audio_embeds_vf = audio_embeds_vf.view(batch_size_vf, window_size_vf * blocks_vf * channels_vf)
|
||||
|
||||
# first projection
|
||||
audio_embeds = torch.relu(self.proj1(audio_embeds))
|
||||
audio_embeds_vf = torch.relu(self.proj1_vf(audio_embeds_vf))
|
||||
audio_embeds = rearrange(audio_embeds, "(bz f) c -> bz f c", bz=B)
|
||||
audio_embeds_vf = rearrange(audio_embeds_vf, "(bz f) c -> bz f c", bz=B)
|
||||
audio_embeds_c = torch.concat([audio_embeds, audio_embeds_vf], dim=1)
|
||||
batch_size_c, N_t, C_a = audio_embeds_c.shape
|
||||
audio_embeds_c = audio_embeds_c.view(batch_size_c*N_t, C_a)
|
||||
|
||||
# second projection
|
||||
audio_embeds_c = torch.relu(self.proj2(audio_embeds_c))
|
||||
|
||||
context_tokens = self.proj3(audio_embeds_c).reshape(batch_size_c*N_t, self.context_tokens, self.output_dim)
|
||||
|
||||
# normalization and reshape
|
||||
context_tokens = self.norm(context_tokens)
|
||||
context_tokens = rearrange(context_tokens, "(bz f) m c -> bz f m c", f=video_length)
|
||||
|
||||
return context_tokens
|
||||
|
||||
|
||||
class WanModel(ModelMixin, ConfigMixin):
|
||||
r"""
|
||||
Wan diffusion backbone supporting both text-to-video and image-to-video.
|
||||
"""
|
||||
|
||||
ignore_for_config = [
|
||||
'patch_size', 'cross_attn_norm', 'qk_norm', 'text_dim', 'window_size'
|
||||
]
|
||||
_no_split_modules = ['WanAttentionBlock']
|
||||
|
||||
@register_to_config
|
||||
def __init__(self,
|
||||
model_type='i2v',
|
||||
patch_size=(1, 2, 2),
|
||||
text_len=512,
|
||||
in_dim=16,
|
||||
dim=2048,
|
||||
ffn_dim=8192,
|
||||
freq_dim=256,
|
||||
text_dim=4096,
|
||||
out_dim=16,
|
||||
num_heads=16,
|
||||
num_layers=32,
|
||||
window_size=(-1, -1),
|
||||
qk_norm=True,
|
||||
cross_attn_norm=True,
|
||||
eps=1e-6,
|
||||
# audio params
|
||||
audio_window=5,
|
||||
intermediate_dim=512,
|
||||
output_dim=768,
|
||||
context_tokens=32,
|
||||
vae_scale=4, # vae timedownsample scale
|
||||
|
||||
norm_input_visual=True,
|
||||
norm_output_audio=True):
|
||||
super().__init__()
|
||||
|
||||
assert model_type == 'i2v', 'MultiTalk model requires model_type to be i2v.'
|
||||
self.model_type = model_type
|
||||
|
||||
self.patch_size = patch_size
|
||||
self.text_len = text_len
|
||||
self.in_dim = in_dim
|
||||
self.dim = dim
|
||||
self.ffn_dim = ffn_dim
|
||||
self.freq_dim = freq_dim
|
||||
self.text_dim = text_dim
|
||||
self.out_dim = out_dim
|
||||
self.num_heads = num_heads
|
||||
self.num_layers = num_layers
|
||||
self.window_size = window_size
|
||||
self.qk_norm = qk_norm
|
||||
self.cross_attn_norm = cross_attn_norm
|
||||
self.eps = eps
|
||||
|
||||
|
||||
self.norm_output_audio = norm_output_audio
|
||||
self.audio_window = audio_window
|
||||
self.intermediate_dim = intermediate_dim
|
||||
self.vae_scale = vae_scale
|
||||
|
||||
|
||||
# embeddings
|
||||
self.patch_embedding = nn.Conv3d(
|
||||
in_dim, dim, kernel_size=patch_size, stride=patch_size)
|
||||
self.text_embedding = nn.Sequential(
|
||||
nn.Linear(text_dim, dim), nn.GELU(approximate='tanh'),
|
||||
nn.Linear(dim, dim))
|
||||
|
||||
self.time_embedding = nn.Sequential(
|
||||
nn.Linear(freq_dim, dim), nn.SiLU(), nn.Linear(dim, dim))
|
||||
self.time_projection = nn.Sequential(nn.SiLU(), nn.Linear(dim, dim * 6))
|
||||
|
||||
# blocks
|
||||
cross_attn_type = 'i2v_cross_attn'
|
||||
self.blocks = nn.ModuleList([
|
||||
WanAttentionBlock(cross_attn_type, dim, ffn_dim, num_heads,
|
||||
window_size, qk_norm, cross_attn_norm, eps,
|
||||
output_dim=output_dim, norm_input_visual=norm_input_visual)
|
||||
for _ in range(num_layers)
|
||||
])
|
||||
|
||||
# head
|
||||
self.head = Head(dim, out_dim, patch_size, eps)
|
||||
|
||||
assert (dim % num_heads) == 0 and (dim // num_heads) % 2 == 0
|
||||
d = dim // num_heads
|
||||
self.freqs = torch.cat([
|
||||
rope_params(1024, d - 4 * (d // 6)),
|
||||
rope_params(1024, 2 * (d // 6)),
|
||||
rope_params(1024, 2 * (d // 6))
|
||||
],
|
||||
dim=1)
|
||||
|
||||
if model_type == 'i2v':
|
||||
self.img_emb = MLPProj(1280, dim)
|
||||
else:
|
||||
raise NotImplementedError('Not supported model type.')
|
||||
|
||||
# init audio adapter
|
||||
self.audio_proj = AudioProjModel(
|
||||
seq_len=audio_window,
|
||||
seq_len_vf=audio_window+vae_scale-1,
|
||||
intermediate_dim=intermediate_dim,
|
||||
output_dim=output_dim,
|
||||
context_tokens=context_tokens,
|
||||
norm_output_audio=norm_output_audio,
|
||||
)
|
||||
|
||||
|
||||
# initialize weights
|
||||
self.init_weights()
|
||||
|
||||
def teacache_init(
|
||||
self,
|
||||
use_ret_steps=True,
|
||||
teacache_thresh=0.2,
|
||||
sample_steps=40,
|
||||
model_scale='multitalk-480',
|
||||
):
|
||||
print("teacache_init")
|
||||
self.enable_teacache = True
|
||||
|
||||
self.__class__.cnt = 0
|
||||
self.__class__.num_steps = sample_steps*3
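# num_steps is sample_steps * 3 because each sampling step runs three forward passes
# (cond, drop_text, uncond), matched by the cnt % 3 branches in forward().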
|
||||
self.__class__.teacache_thresh = teacache_thresh
|
||||
self.__class__.accumulated_rel_l1_distance_even = 0
|
||||
self.__class__.accumulated_rel_l1_distance_odd = 0
|
||||
self.__class__.previous_e0_even = None
|
||||
self.__class__.previous_e0_odd = None
|
||||
self.__class__.previous_residual_even = None
|
||||
self.__class__.previous_residual_odd = None
|
||||
self.__class__.use_ret_steps = use_ret_steps
|
||||
|
||||
if use_ret_steps:
|
||||
if model_scale == 'multitalk-480':
|
||||
self.__class__.coefficients = [ 2.57151496e+05, -3.54229917e+04, 1.40286849e+03, -1.35890334e+01, 1.32517977e-01]
|
||||
if model_scale == 'multitalk-720':
|
||||
self.__class__.coefficients = [ 8.10705460e+03, 2.13393892e+03, -3.72934672e+02, 1.66203073e+01, -4.17769401e-02]
|
||||
self.__class__.ret_steps = 5*3
|
||||
self.__class__.cutoff_steps = sample_steps*3
|
||||
else:
|
||||
if model_scale == 'multitalk-480':
|
||||
self.__class__.coefficients = [-3.02331670e+02, 2.23948934e+02, -5.25463970e+01, 5.87348440e+00, -2.01973289e-01]
|
||||
|
||||
if model_scale == 'multitalk-720':
|
||||
self.__class__.coefficients = [-114.36346466, 65.26524496, -18.82220707, 4.91518089, -0.23412683]
|
||||
self.__class__.ret_steps = 1*3
|
||||
self.__class__.cutoff_steps = sample_steps*3 - 3
|
||||
print("teacache_init done")
|
||||
|
||||
def disable_teacache(self):
|
||||
self.enable_teacache = False
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x,
|
||||
t,
|
||||
context,
|
||||
seq_len,
|
||||
clip_fea=None,
|
||||
y=None,
|
||||
audio=None,
|
||||
ref_target_masks=None,
|
||||
):
|
||||
assert clip_fea is not None and y is not None
|
||||
|
||||
_, T, H, W = x[0].shape
|
||||
N_t = T // self.patch_size[0]
|
||||
N_h = H // self.patch_size[1]
|
||||
N_w = W // self.patch_size[2]
|
||||
|
||||
if y is not None:
|
||||
x = [torch.cat([u, v], dim=0) for u, v in zip(x, y)]
|
||||
x[0] = x[0].to(context[0].dtype)
|
||||
|
||||
# embeddings
|
||||
x = [self.patch_embedding(u.unsqueeze(0)) for u in x]
|
||||
grid_sizes = torch.stack(
|
||||
[torch.tensor(u.shape[2:], dtype=torch.long) for u in x])
|
||||
x = [u.flatten(2).transpose(1, 2) for u in x]
|
||||
seq_lens = torch.tensor([u.size(1) for u in x], dtype=torch.long)
|
||||
assert seq_lens.max() <= seq_len
|
||||
x = torch.cat([
|
||||
torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))],
|
||||
dim=1) for u in x
|
||||
])
|
||||
|
||||
# time embeddings
|
||||
with amp.autocast(dtype=torch.float32):
|
||||
e = self.time_embedding(
|
||||
sinusoidal_embedding_1d(self.freq_dim, t).float())
|
||||
e0 = self.time_projection(e).unflatten(1, (6, self.dim))
|
||||
assert e.dtype == torch.float32 and e0.dtype == torch.float32
|
||||
|
||||
# text embedding
|
||||
context_lens = None
|
||||
context = self.text_embedding(
|
||||
torch.stack([
|
||||
torch.cat(
|
||||
[u, u.new_zeros(self.text_len - u.size(0), u.size(1))])
|
||||
for u in context
|
||||
]))
|
||||
|
||||
# clip embedding
|
||||
if clip_fea is not None:
|
||||
context_clip = self.img_emb(clip_fea)
|
||||
context = torch.concat([context_clip, context], dim=1).to(x.dtype)
|
||||
|
||||
|
||||
audio_cond = audio.to(device=x.device, dtype=x.dtype)
|
||||
first_frame_audio_emb_s = audio_cond[:, :1, ...]
|
||||
latter_frame_audio_emb = audio_cond[:, 1:, ...]
|
||||
latter_frame_audio_emb = rearrange(latter_frame_audio_emb, "b (n_t n) w s c -> b n_t n w s c", n=self.vae_scale)
|
||||
middle_index = self.audio_window // 2
|
||||
latter_first_frame_audio_emb = latter_frame_audio_emb[:, :, :1, :middle_index+1, ...]
|
||||
latter_first_frame_audio_emb = rearrange(latter_first_frame_audio_emb, "b n_t n w s c -> b n_t (n w) s c")
|
||||
latter_last_frame_audio_emb = latter_frame_audio_emb[:, :, -1:, middle_index:, ...]
|
||||
latter_last_frame_audio_emb = rearrange(latter_last_frame_audio_emb, "b n_t n w s c -> b n_t (n w) s c")
|
||||
latter_middle_frame_audio_emb = latter_frame_audio_emb[:, :, 1:-1, middle_index:middle_index+1, ...]
|
||||
latter_middle_frame_audio_emb = rearrange(latter_middle_frame_audio_emb, "b n_t n w s c -> b n_t (n w) s c")
|
||||
latter_frame_audio_emb_s = torch.concat([latter_first_frame_audio_emb, latter_middle_frame_audio_emb, latter_last_frame_audio_emb], dim=2)
|
||||
audio_embedding = self.audio_proj(first_frame_audio_emb_s, latter_frame_audio_emb_s)
|
||||
human_num = len(audio_embedding)
|
||||
audio_embedding = torch.concat(audio_embedding.split(1), dim=2).to(x.dtype)

        # convert ref_target_masks to token_ref_target_masks
        if ref_target_masks is not None:
            ref_target_masks = ref_target_masks.unsqueeze(0).to(torch.float32)
            token_ref_target_masks = nn.functional.interpolate(ref_target_masks, size=(N_h, N_w), mode='nearest')
            token_ref_target_masks = token_ref_target_masks.squeeze(0)
            token_ref_target_masks = (token_ref_target_masks > 0)
            token_ref_target_masks = token_ref_target_masks.view(token_ref_target_masks.shape[0], -1)
            token_ref_target_masks = token_ref_target_masks.to(x.dtype)
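
        # The per-speaker reference masks are resized to the latent token grid (N_h x N_w)
        # with nearest-neighbour interpolation, binarised, and flattened to
        # [num_speakers, N_h*N_w], so each row marks which visual tokens belong to that
        # speaker and downstream attention can associate each audio stream with its
        # speaker's region.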

        # teacache
        if self.enable_teacache:
            modulated_inp = e0 if self.use_ret_steps else e
            if self.cnt % 3 == 0:  # cond
                if self.cnt < self.ret_steps or self.cnt >= self.cutoff_steps:
                    should_calc_cond = True
                    self.accumulated_rel_l1_distance_cond = 0
                else:
                    rescale_func = np.poly1d(self.coefficients)
                    self.accumulated_rel_l1_distance_cond += rescale_func(((modulated_inp - self.previous_e0_cond).abs().mean() / self.previous_e0_cond.abs().mean()).cpu().item())
                    if self.accumulated_rel_l1_distance_cond < self.teacache_thresh:
                        should_calc_cond = False
                    else:
                        should_calc_cond = True
                        self.accumulated_rel_l1_distance_cond = 0
                self.previous_e0_cond = modulated_inp.clone()
            elif self.cnt % 3 == 1:  # drop_text
                if self.cnt < self.ret_steps or self.cnt >= self.cutoff_steps:
                    should_calc_drop_text = True
                    self.accumulated_rel_l1_distance_drop_text = 0
                else:
                    rescale_func = np.poly1d(self.coefficients)
                    self.accumulated_rel_l1_distance_drop_text += rescale_func(((modulated_inp - self.previous_e0_drop_text).abs().mean() / self.previous_e0_drop_text.abs().mean()).cpu().item())
                    if self.accumulated_rel_l1_distance_drop_text < self.teacache_thresh:
                        should_calc_drop_text = False
                    else:
                        should_calc_drop_text = True
                        self.accumulated_rel_l1_distance_drop_text = 0
                self.previous_e0_drop_text = modulated_inp.clone()
            else:  # uncond
                if self.cnt < self.ret_steps or self.cnt >= self.cutoff_steps:
                    should_calc_uncond = True
                    self.accumulated_rel_l1_distance_uncond = 0
                else:
                    rescale_func = np.poly1d(self.coefficients)
                    self.accumulated_rel_l1_distance_uncond += rescale_func(((modulated_inp - self.previous_e0_uncond).abs().mean() / self.previous_e0_uncond.abs().mean()).cpu().item())
                    if self.accumulated_rel_l1_distance_uncond < self.teacache_thresh:
                        should_calc_uncond = False
                    else:
                        should_calc_uncond = True
                        self.accumulated_rel_l1_distance_uncond = 0
                self.previous_e0_uncond = modulated_inp.clone()
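
        # TeaCache keeps an independent accumulator and cached residual for each of the three
        # passes run per denoising step (cond / drop_text / uncond, selected by cnt % 3). The
        # polynomial rescales the relative L1 change of the modulated input into an estimate
        # of output change; only when the accumulated estimate reaches `teacache_thresh`
        # (or during the forced first / last steps) are the transformer blocks recomputed
        # below, otherwise the cached residual is reused.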

        # arguments
        kwargs = dict(
            e=e0,
            seq_lens=seq_lens,
            grid_sizes=grid_sizes,
            freqs=self.freqs,
            context=context,
            context_lens=context_lens,
            audio_embedding=audio_embedding,
            ref_target_masks=token_ref_target_masks,
            human_num=human_num,
        )

        if self.enable_teacache:
            if self.cnt % 3 == 0:
                if not should_calc_cond:
                    x += self.previous_residual_cond
                else:
                    ori_x = x.clone()
                    for block in self.blocks:
                        x = block(x, **kwargs)
                    self.previous_residual_cond = x - ori_x
            elif self.cnt % 3 == 1:
                if not should_calc_drop_text:
                    x += self.previous_residual_drop_text
                else:
                    ori_x = x.clone()
                    for block in self.blocks:
                        x = block(x, **kwargs)
                    self.previous_residual_drop_text = x - ori_x
            else:
                if not should_calc_uncond:
                    x += self.previous_residual_uncond
                else:
                    ori_x = x.clone()
                    for block in self.blocks:
                        x = block(x, **kwargs)
                    self.previous_residual_uncond = x - ori_x
        else:
            for block in self.blocks:
                x = block(x, **kwargs)

        # head
        x = self.head(x, e)

        # unpatchify
        x = self.unpatchify(x, grid_sizes)
        if self.enable_teacache:
            self.cnt += 1
            if self.cnt >= self.num_steps:
                self.cnt = 0

        return torch.stack(x).float()

    def unpatchify(self, x, grid_sizes):
        r"""
        Reconstruct video tensors from patch embeddings.

        Args:
            x (List[Tensor]):
                List of patchified features, each with shape [L, C_out * prod(patch_size)]
            grid_sizes (Tensor):
                Original spatial-temporal grid dimensions before patching,
                shape [B, 3] (3 dimensions correspond to F_patches, H_patches, W_patches)

        Returns:
            List[Tensor]:
                Reconstructed video tensors with shape [C_out, F, H / 8, W / 8]
        """

        c = self.out_dim
        out = []
        for u, v in zip(x, grid_sizes.tolist()):
            u = u[:math.prod(v)].view(*v, *self.patch_size, c)
            u = torch.einsum('fhwpqrc->cfphqwr', u)
            u = u.reshape(c, *[i * j for i, j in zip(v, self.patch_size)])
            out.append(u)
        return out
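
    # Shape example for unpatchify (illustrative, assuming patch_size = (1, 2, 2)): a
    # grid_sizes row (F, H', W') and out_dim c turn a [F*H'*W', c*4] token tensor into
    # [F, H', W', 1, 2, 2, c], which 'fhwpqrc->cfphqwr' permutes and the final reshape merges
    # into [c, F, 2*H', 2*W'].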

    def init_weights(self):
        r"""
        Initialize model parameters using Xavier initialization.
        """

        # basic init
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

        # init embeddings
        nn.init.xavier_uniform_(self.patch_embedding.weight.flatten(1))
        for m in self.text_embedding.modules():
            if isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, std=.02)
        for m in self.time_embedding.modules():
            if isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, std=.02)

        # init output layer
        nn.init.zeros_(self.head.head.weight)
353
wan/multitalk/multitalk_utils.py
Normal file
@ -0,0 +1,353 @@
import os

import torch
import torch.nn as nn

from einops import rearrange, repeat
from functools import lru_cache
import imageio
import uuid
from tqdm import tqdm
import numpy as np
import subprocess
import soundfile as sf
import torchvision
import binascii
import os.path as osp


VID_EXTENSIONS = (".mp4", ".avi", ".mov", ".mkv")
ASPECT_RATIO_627 = {
    '0.26': ([320, 1216], 1), '0.38': ([384, 1024], 1), '0.50': ([448, 896], 1), '0.67': ([512, 768], 1),
    '0.82': ([576, 704], 1), '1.00': ([640, 640], 1), '1.22': ([704, 576], 1), '1.50': ([768, 512], 1),
    '1.86': ([832, 448], 1), '2.00': ([896, 448], 1), '2.50': ([960, 384], 1), '2.83': ([1088, 384], 1),
    '3.60': ([1152, 320], 1), '3.80': ([1216, 320], 1), '4.00': ([1280, 320], 1)}


ASPECT_RATIO_960 = {
    '0.22': ([448, 2048], 1), '0.29': ([512, 1792], 1), '0.36': ([576, 1600], 1), '0.45': ([640, 1408], 1),
    '0.55': ([704, 1280], 1), '0.63': ([768, 1216], 1), '0.76': ([832, 1088], 1), '0.88': ([896, 1024], 1),
    '1.00': ([960, 960], 1), '1.14': ([1024, 896], 1), '1.31': ([1088, 832], 1), '1.50': ([1152, 768], 1),
    '1.58': ([1216, 768], 1), '1.82': ([1280, 704], 1), '1.91': ([1344, 704], 1), '2.20': ([1408, 640], 1),
    '2.30': ([1472, 640], 1), '2.67': ([1536, 576], 1), '2.89': ([1664, 576], 1), '3.62': ([1856, 512], 1),
    '3.75': ([1920, 512], 1)}


def torch_gc():
    torch.cuda.empty_cache()
    torch.cuda.ipc_collect()


def split_token_counts_and_frame_ids(T, token_frame, world_size, rank):

    S = T * token_frame
    split_sizes = [S // world_size + (1 if i < S % world_size else 0) for i in range(world_size)]
    start = sum(split_sizes[:rank])
    end = start + split_sizes[rank]
    counts = [0] * T
    for idx in range(start, end):
        t = idx // token_frame
        counts[t] += 1

    counts_filtered = []
    frame_ids = []
    for t, c in enumerate(counts):
        if c > 0:
            counts_filtered.append(c)
            frame_ids.append(t)
    return counts_filtered, frame_ids
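
# Worked example (hypothetical numbers): split_token_counts_and_frame_ids(T=3, token_frame=4,
# world_size=2, rank=0) splits the 12 tokens 6/6, so rank 0 gets counts [4, 2] for frame ids
# [0, 1] and rank 1 gets counts [2, 4] for frame ids [1, 2].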


def normalize_and_scale(column, source_range, target_range, epsilon=1e-8):

    source_min, source_max = source_range
    new_min, new_max = target_range

    normalized = (column - source_min) / (source_max - source_min + epsilon)
    scaled = normalized * (new_max - new_min) + new_min
    return scaled
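
# e.g. normalize_and_scale(x, (0, 1), (-1, 1)) maps 0 -> -1 and 1 -> +1 (up to the epsilon
# added to the source span to avoid division by zero).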


# @torch.compile
def calculate_x_ref_attn_map(visual_q, ref_k, ref_target_masks, ref_images_count, mode='mean', attn_bias=None):

    ref_k = ref_k.to(visual_q.dtype).to(visual_q.device)
    scale = 1.0 / visual_q.shape[-1] ** 0.5
    visual_q = visual_q * scale
    visual_q = visual_q.transpose(1, 2)
    ref_k = ref_k.transpose(1, 2)
    attn = visual_q @ ref_k.transpose(-2, -1)

    if attn_bias is not None:
        attn += attn_bias

    x_ref_attn_map_source = attn.softmax(-1)  # B, H, x_seqlens, ref_seqlens

    x_ref_attn_maps = []
    ref_target_masks = ref_target_masks.to(visual_q.dtype)
    x_ref_attn_map_source = x_ref_attn_map_source.to(visual_q.dtype)

    for class_idx, ref_target_mask in enumerate(ref_target_masks):
        ref_target_mask = ref_target_mask[None, None, None, ...]
        x_ref_attnmap = x_ref_attn_map_source * ref_target_mask
        x_ref_attnmap = x_ref_attnmap.sum(-1) / ref_target_mask.sum()  # B, H, x_seqlens, ref_seqlens --> B, H, x_seqlens
        x_ref_attnmap = x_ref_attnmap.permute(0, 2, 1)  # B, x_seqlens, H

        if mode == 'mean':
            x_ref_attnmap = x_ref_attnmap.mean(-1)  # B, x_seqlens
        elif mode == 'max':
            x_ref_attnmap = x_ref_attnmap.max(-1).values  # B, x_seqlens (torch.max over a dim returns (values, indices))

        x_ref_attn_maps.append(x_ref_attnmap)

    del attn
    del x_ref_attn_map_source

    return torch.concat(x_ref_attn_maps, dim=0)


def get_attn_map_with_target(visual_q, ref_k, shape, ref_target_masks=None, split_num=10, ref_images_count=0):
    """Args:
        query (torch.tensor): B M H K
        key (torch.tensor): B M H K
        shape (tuple): (N_t, N_h, N_w)
        ref_target_masks: [B, N_h * N_w]
    """

    N_t, N_h, N_w = shape

    x_seqlens = N_h * N_w
    ref_k = ref_k[:, :x_seqlens]
    if ref_images_count > 0:
        visual_q_shape = visual_q.shape
        visual_q = visual_q.reshape(visual_q_shape[0], N_t, -1)
        visual_q = visual_q[:, ref_images_count:]
        visual_q = visual_q.reshape(visual_q_shape[0], -1, *visual_q_shape[-2:])

    _, seq_lens, heads, _ = visual_q.shape
    class_num, _ = ref_target_masks.shape
    x_ref_attn_maps = torch.zeros(class_num, seq_lens).to(visual_q.device).to(visual_q.dtype)

    split_chunk = heads // split_num

    for i in range(split_num):
        x_ref_attn_maps_perhead = calculate_x_ref_attn_map(visual_q[:, :, i*split_chunk:(i+1)*split_chunk, :], ref_k[:, :, i*split_chunk:(i+1)*split_chunk, :], ref_target_masks, ref_images_count)
        x_ref_attn_maps += x_ref_attn_maps_perhead

    x_ref_attn_maps /= split_num
    return x_ref_attn_maps
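
# The heads are processed in `split_num` chunks so only a slice of the attention matrix is
# materialised at a time, and the per-chunk maps are averaged; the result is a
# [num_speakers, x_seqlens] map of how strongly each visual token attends to each speaker's
# reference region (heads are assumed divisible by `split_num`; any remainder heads are skipped).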


def rotate_half(x):
    x = rearrange(x, "... (d r) -> ... d r", r=2)
    x1, x2 = x.unbind(dim=-1)
    x = torch.stack((-x2, x1), dim=-1)
    return rearrange(x, "... d r -> ... (d r)")


class RotaryPositionalEmbedding1D(nn.Module):

    def __init__(self,
                 head_dim,
                 ):
        super().__init__()
        self.head_dim = head_dim
        self.base = 10000

    @lru_cache(maxsize=32)
    def precompute_freqs_cis_1d(self, pos_indices):

        freqs = 1.0 / (self.base ** (torch.arange(0, self.head_dim, 2)[: (self.head_dim // 2)].float() / self.head_dim))
        freqs = freqs.to(pos_indices.device)
        freqs = torch.einsum("..., f -> ... f", pos_indices.float(), freqs)
        freqs = repeat(freqs, "... n -> ... (n r)", r=2)
        return freqs

    def forward(self, x, pos_indices):
        """1D RoPE.

        Args:
            query (torch.tensor): [B, head, seq, head_dim]
            pos_indices (torch.tensor): [seq,]
        Returns:
            query with the same shape as input.
        """
        freqs_cis = self.precompute_freqs_cis_1d(pos_indices)

        x_ = x.float()

        freqs_cis = freqs_cis.float().to(x.device)
        cos, sin = freqs_cis.cos(), freqs_cis.sin()
        cos, sin = rearrange(cos, 'n d -> 1 1 n d'), rearrange(sin, 'n d -> 1 1 n d')
        x_ = (x_ * cos) + (rotate_half(x_) * sin)

        return x_.type_as(x)
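
# Usage sketch (hypothetical shapes): for x of shape [B, heads, seq, head_dim] and
# pos_indices = torch.arange(seq), this applies standard 1D RoPE by rotating adjacent channel
# pairs through position-dependent angles. Note that lru_cache keys on the tensor object, so
# the cached angle table is reused only when the same `pos_indices` tensor is passed again.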


def rand_name(length=8, suffix=''):
    name = binascii.b2a_hex(os.urandom(length)).decode('utf-8')
    if suffix:
        if not suffix.startswith('.'):
            suffix = '.' + suffix
        name += suffix
    return name


def cache_video(tensor,
                save_file=None,
                fps=30,
                suffix='.mp4',
                nrow=8,
                normalize=True,
                value_range=(-1, 1),
                retry=5):

    # cache file
    cache_file = osp.join('/tmp', rand_name(
        suffix=suffix)) if save_file is None else save_file

    # save to cache
    error = None
    for _ in range(retry):

        # preprocess
        tensor = tensor.clamp(min(value_range), max(value_range))
        tensor = torch.stack([
            torchvision.utils.make_grid(
                u, nrow=nrow, normalize=normalize, value_range=value_range)
            for u in tensor.unbind(2)
        ], dim=1).permute(1, 2, 3, 0)
        tensor = (tensor * 255).type(torch.uint8).cpu()

        # write video
        writer = imageio.get_writer(cache_file, fps=fps, codec='libx264', quality=10, ffmpeg_params=["-crf", "10"])
        for frame in tensor.numpy():
            writer.append_data(frame)
        writer.close()
        return cache_file


def save_video_ffmpeg(gen_video_samples, save_path, vocal_audio_list, fps=25, quality=5, high_quality_save=False):

    def save_video(frames, save_path, fps, quality=9, ffmpeg_params=None):
        writer = imageio.get_writer(
            save_path, fps=fps, quality=quality, ffmpeg_params=ffmpeg_params
        )
        for frame in tqdm(frames, desc="Saving video"):
            frame = np.array(frame)
            writer.append_data(frame)
        writer.close()

    save_path_tmp = save_path + "-temp.mp4"

    if high_quality_save:
        cache_video(
            tensor=gen_video_samples.unsqueeze(0),
            save_file=save_path_tmp,
            fps=fps,
            nrow=1,
            normalize=True,
            value_range=(-1, 1)
        )
    else:
        video_audio = (gen_video_samples + 1) / 2  # C T H W
        video_audio = video_audio.permute(1, 2, 3, 0).cpu().numpy()
        video_audio = np.clip(video_audio * 255, 0, 255).astype(np.uint8)  # to [0, 255]
        save_video(video_audio, save_path_tmp, fps=fps, quality=quality)

    # crop audio according to video length
    _, T, _, _ = gen_video_samples.shape
    duration = T / fps
    save_path_crop_audio = save_path + "-cropaudio.wav"
    final_command = [
        "ffmpeg",
        "-i",
        vocal_audio_list[0],
        "-t",
        f'{duration}',
        save_path_crop_audio,
    ]
    subprocess.run(final_command, check=True)

    save_path = save_path + ".mp4"
    if high_quality_save:
        final_command = [
            "ffmpeg",
            "-y",
            "-i", save_path_tmp,
            "-i", save_path_crop_audio,
            "-c:v", "libx264",
            "-crf", "0",
            "-preset", "veryslow",
            "-c:a", "aac",
            "-shortest",
            save_path,
        ]
        subprocess.run(final_command, check=True)
        os.remove(save_path_tmp)
        os.remove(save_path_crop_audio)
    else:
        final_command = [
            "ffmpeg",
            "-y",
            "-i",
            save_path_tmp,
            "-i",
            save_path_crop_audio,
            "-c:v",
            "libx264",
            "-c:a",
            "aac",
            "-shortest",
            save_path,
        ]
        subprocess.run(final_command, check=True)
        os.remove(save_path_tmp)
        os.remove(save_path_crop_audio)


class MomentumBuffer:
    def __init__(self, momentum: float):
        self.momentum = momentum
        self.running_average = 0

    def update(self, update_value: torch.Tensor):
        new_average = self.momentum * self.running_average
        self.running_average = update_value + new_average


def project(
    v0: torch.Tensor,  # [B, C, T, H, W]
    v1: torch.Tensor,  # [B, C, T, H, W]
):
    dtype = v0.dtype
    v0, v1 = v0.double(), v1.double()
    v1 = torch.nn.functional.normalize(v1, dim=[-1, -2, -3, -4])
    v0_parallel = (v0 * v1).sum(dim=[-1, -2, -3, -4], keepdim=True) * v1
    v0_orthogonal = v0 - v0_parallel
    return v0_parallel.to(dtype), v0_orthogonal.to(dtype)


def adaptive_projected_guidance(
    diff: torch.Tensor,  # [B, C, T, H, W]
    pred_cond: torch.Tensor,  # [B, C, T, H, W]
    momentum_buffer: MomentumBuffer = None,
    eta: float = 0.0,
    norm_threshold: float = 55,
):
    if momentum_buffer is not None:
        momentum_buffer.update(diff)
        diff = momentum_buffer.running_average
    if norm_threshold > 0:
        ones = torch.ones_like(diff)
        diff_norm = diff.norm(p=2, dim=[-1, -2, -3, -4], keepdim=True)
        print(f"diff_norm: {diff_norm}")
        scale_factor = torch.minimum(ones, norm_threshold / diff_norm)
        diff = diff * scale_factor
    diff_parallel, diff_orthogonal = project(diff, pred_cond)
    normalized_update = diff_orthogonal + eta * diff_parallel
    return normalized_update
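
# Adaptive projected guidance (APG): the cond/uncond difference is optionally momentum-averaged,
# its norm is clamped to `norm_threshold`, and it is split into components parallel and
# orthogonal to the conditional prediction; `eta` scales the parallel part (eta=0 keeps only
# the orthogonal component, intended to reduce the over-saturation seen at large CFG scales).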
20
wan/multitalk/torch_utils.py
Normal file
@ -0,0 +1,20 @@
import torch
import torch.nn.functional as F


def get_mask_from_lengths(lengths, max_len=None):
    lengths = lengths.to(torch.long)
    if max_len is None:
        max_len = torch.max(lengths).item()

    ids = torch.arange(0, max_len).unsqueeze(0).expand(lengths.shape[0], -1).to(lengths.device)
    mask = ids < lengths.unsqueeze(1).expand(-1, max_len)

    return mask
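
# e.g. get_mask_from_lengths(torch.tensor([2, 4])) ->
# [[True, True, False, False], [True, True, True, True]]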


def linear_interpolation(features, seq_len):
    features = features.transpose(1, 2)
    output_features = F.interpolate(features, size=seq_len, align_corners=True, mode='linear')
    return output_features.transpose(1, 2)
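
# Resamples features of shape [B, T_in, C] to [B, seq_len, C] along the time axis; in this
# repo it is used to stretch wav2vec outputs to a target length (e.g. the video frame count).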
125
wan/multitalk/wav2vec2.py
Normal file
@ -0,0 +1,125 @@
from transformers import Wav2Vec2Config, Wav2Vec2Model
from transformers.modeling_outputs import BaseModelOutput

from .torch_utils import linear_interpolation

# the implementation of Wav2Vec2Model is borrowed from
# https://github.com/huggingface/transformers/blob/HEAD/src/transformers/models/wav2vec2/modeling_wav2vec2.py
# initialize our encoder with the pre-trained wav2vec 2.0 weights.
class Wav2Vec2Model(Wav2Vec2Model):
    def __init__(self, config: Wav2Vec2Config):
        super().__init__(config)

    def forward(
        self,
        input_values,
        seq_len,
        attention_mask=None,
        mask_time_indices=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        self.config.output_attentions = True

        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        extract_features = self.feature_extractor(input_values)
        extract_features = extract_features.transpose(1, 2)
        extract_features = linear_interpolation(extract_features, seq_len=seq_len)

        if attention_mask is not None:
            # compute reduced attention_mask corresponding to feature vectors
            attention_mask = self._get_feature_vector_attention_mask(
                extract_features.shape[1], attention_mask, add_adapter=False
            )

        hidden_states, extract_features = self.feature_projection(extract_features)
        hidden_states = self._mask_hidden_states(
            hidden_states, mask_time_indices=mask_time_indices, attention_mask=attention_mask
        )

        encoder_outputs = self.encoder(
            hidden_states,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        hidden_states = encoder_outputs[0]

        if self.adapter is not None:
            hidden_states = self.adapter(hidden_states)

        if not return_dict:
            return (hidden_states, ) + encoder_outputs[1:]
        return BaseModelOutput(
            last_hidden_state=hidden_states,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )
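
    # Unlike the stock Hugging Face forward, the CNN features are interpolated to `seq_len`
    # before the projection/encoder so the audio features line up one-to-one with a target
    # frame count; `feature_extract` and `encode` below expose the two halves of that pipeline
    # separately (e.g. so the CNN features can be computed once and re-encoded as needed).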

    def feature_extract(
        self,
        input_values,
        seq_len,
    ):
        extract_features = self.feature_extractor(input_values)
        extract_features = extract_features.transpose(1, 2)
        extract_features = linear_interpolation(extract_features, seq_len=seq_len)

        return extract_features

    def encode(
        self,
        extract_features,
        attention_mask=None,
        mask_time_indices=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        self.config.output_attentions = True

        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if attention_mask is not None:
            # compute reduced attention_mask corresponding to feature vectors
            attention_mask = self._get_feature_vector_attention_mask(
                extract_features.shape[1], attention_mask, add_adapter=False
            )

        hidden_states, extract_features = self.feature_projection(extract_features)
        hidden_states = self._mask_hidden_states(
            hidden_states, mask_time_indices=mask_time_indices, attention_mask=attention_mask
        )

        encoder_outputs = self.encoder(
            hidden_states,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        hidden_states = encoder_outputs[0]

        if self.adapter is not None:
            hidden_states = self.adapter(hidden_states)

        if not return_dict:
            return (hidden_states, ) + encoder_outputs[1:]
        return BaseModelOutput(
            last_hidden_state=hidden_states,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )