from .model import KModel
from dataclasses import dataclass
from huggingface_hub import hf_hub_download
from loguru import logger
from misaki import en, espeak
from typing import Callable, Generator, List, Optional, Tuple, Union
import re
import torch
import os

ALIASES = {
    'en-us': 'a',
    'en-gb': 'b',
    'es': 'e',
    'fr-fr': 'f',
    'hi': 'h',
    'it': 'i',
    'pt-br': 'p',
    'ja': 'j',
    'zh': 'z',
}

LANG_CODES = dict(
    # pip install misaki[en]
    a='American English',
    b='British English',

    # espeak-ng
    e='es',
    f='fr-fr',
    h='hi',
    i='it',
    p='pt-br',

    # pip install misaki[ja]
    j='Japanese',

    # pip install misaki[zh]
    z='Mandarin Chinese',
)

class KPipeline:
    '''
    KPipeline is a language-aware support class with 2 main responsibilities:
    1. Perform language-specific G2P, mapping (and chunking) text -> phonemes
    2. Manage and store voices, lazily downloaded from HF if needed

    You are expected to have one KPipeline per language. If you have multiple
    KPipelines, you should reuse one KModel instance across all of them.

    KPipeline is designed to work with a KModel, but this is not required.
    There are 2 ways to pass an existing model into a pipeline:
    1. On init: us_pipeline = KPipeline(lang_code='a', model=model)
    2. On call: us_pipeline(text, voice, model=model)

    By default, KPipeline will automatically initialize its own KModel. To
    suppress this, construct a "quiet" KPipeline with model=False.

    A "quiet" KPipeline yields (graphemes, phonemes, None) without generating
    any audio. You can use this to phonemize and chunk your text in advance.

    A "loud" KPipeline _with_ a model yields (graphemes, phonemes, audio).
    '''
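    # A minimal usage sketch (illustrative; 'af_heart' is assumed to exist in the
    # configured repo, substitute a voice available for your language):
    #
    #     model = KModel(repo_id='hexgrad/Kokoro-82M')       # one model, shared
    #     quiet = KPipeline(lang_code='a', model=False)      # G2P + chunking only
    #     loud = KPipeline(lang_code='a', model=model, repo_id='hexgrad/Kokoro-82M')
    #     for gs, ps, audio in loud("Hello world!", voice='af_heart'):
    #         ...  # audio is a torch.FloatTensor; it is None for a quiet pipeline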
    def __init__(
        self,
        lang_code: str,
        repo_id: Optional[str] = None,
        model: Union[KModel, bool] = True,
        trf: bool = False,
        en_callable: Optional[Callable[[str], str]] = None,
        device: Optional[str] = None
    ):
        """Initialize a KPipeline.

        Args:
            lang_code: Language code for G2P processing
            repo_id: HF repo (or local path) to load the model config and voices from
            model: KModel instance, True to create new model, False for no model
            trf: Whether to use transformer-based G2P
            en_callable: Optional callable passed through to the Chinese G2P for embedded English
            device: Override default device selection ('cuda' or 'cpu', or None for auto)
                If None, will auto-select cuda if available
                If 'cuda' and not available, will explicitly raise an error
        """
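        # Device selection sketch (illustrative constructor calls, not defaults):
        #   KPipeline(lang_code='a')                  # auto-selects cuda, mps (with fallback enabled), or cpu
        #   KPipeline(lang_code='a', device='cpu')    # force CPU
        #   KPipeline(lang_code='a', device='cuda')   # raises RuntimeError if CUDA is unavailable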
        if repo_id is None:
            repo_id = 'hexgrad/Kokoro-82M'
            print(f"WARNING: Defaulting repo_id to {repo_id}. Pass repo_id='{repo_id}' to suppress this warning.")
            config = None
        else:
            config = os.path.join(repo_id, 'config.json')
        self.repo_id = repo_id
        lang_code = lang_code.lower()
        lang_code = ALIASES.get(lang_code, lang_code)
        assert lang_code in LANG_CODES, (lang_code, LANG_CODES)
        self.lang_code = lang_code
        self.model = None
        if isinstance(model, KModel):
            self.model = model
        elif model:
            if device == 'cuda' and not torch.cuda.is_available():
                raise RuntimeError("CUDA requested but not available")
            if device == 'mps' and not torch.backends.mps.is_available():
                raise RuntimeError("MPS requested but not available")
            if device == 'mps' and os.environ.get('PYTORCH_ENABLE_MPS_FALLBACK') != '1':
                raise RuntimeError("MPS requested but fallback not enabled")
            if device is None:
                if torch.cuda.is_available():
                    device = 'cuda'
                elif os.environ.get('PYTORCH_ENABLE_MPS_FALLBACK') == '1' and torch.backends.mps.is_available():
                    device = 'mps'
                else:
                    device = 'cpu'
            try:
                self.model = KModel(repo_id=repo_id, config=config).to(device).eval()
            except RuntimeError as e:
                if device == 'cuda':
                    raise RuntimeError(
                        f"Failed to initialize model on CUDA: {e}. "
                        "Try setting device='cpu' or check CUDA installation."
                    )
                raise
        self.voices = {}
        if lang_code in 'ab':
            try:
                fallback = espeak.EspeakFallback(british=lang_code=='b')
            except Exception as e:
                logger.warning("EspeakFallback not enabled: OOD words will be skipped")
                logger.warning(str(e))
                fallback = None
            self.g2p = en.G2P(trf=trf, british=lang_code=='b', fallback=fallback, unk='')
        elif lang_code == 'j':
            try:
                from misaki import ja
                self.g2p = ja.JAG2P()
            except ImportError:
                logger.error("You need to `pip install misaki[ja]` to use lang_code='j'")
                raise
        elif lang_code == 'z':
            try:
                from misaki import zh
                self.g2p = zh.ZHG2P(
                    version=None if repo_id.endswith('/Kokoro-82M') else '1.1',
                    en_callable=en_callable
                )
            except ImportError:
                logger.error("You need to `pip install misaki[zh]` to use lang_code='z'")
                raise
        else:
            language = LANG_CODES[lang_code]
            logger.warning(f"Using EspeakG2P(language='{language}'). Chunking logic not yet implemented, so long texts may be truncated unless you split them with '\\n'.")
            self.g2p = espeak.EspeakG2P(language=language)

    def load_single_voice(self, voice: str):
        if voice in self.voices:
            return self.voices[voice]
        if voice.endswith('.pt'):
            f = voice
        else:
            f = hf_hub_download(repo_id=self.repo_id, filename=f'voices/{voice}.pt')
            if not voice.startswith(self.lang_code):
                v = LANG_CODES.get(voice, voice)
                p = LANG_CODES.get(self.lang_code, self.lang_code)
                logger.warning(f'Language mismatch, loading {v} voice into {p} pipeline.')
        pack = torch.load(f, weights_only=True)
        self.voices[voice] = pack
        return pack

"""
|
|
load_voice is a helper function that lazily downloads and loads a voice:
|
|
Single voice can be requested (e.g. 'af_bella') or multiple voices (e.g. 'af_bella,af_jessica').
|
|
If multiple voices are requested, they are averaged.
|
|
Delimiter is optional and defaults to ','.
|
|
"""
|
|
def load_voice(self, voice: Union[str, torch.FloatTensor], delimiter: str = ",") -> torch.FloatTensor:
|
|
if isinstance(voice, torch.FloatTensor):
|
|
return voice
|
|
if voice in self.voices:
|
|
return self.voices[voice]
|
|
logger.debug(f"Loading voice: {voice}")
|
|
packs = [self.load_single_voice(v) for v in voice.split(delimiter)]
|
|
if len(packs) == 1:
|
|
return packs[0]
|
|
self.voices[voice] = torch.mean(torch.stack(packs), dim=0)
|
|
return self.voices[voice]
|
|
|
|
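    # Blending sketch (voice names as in the docstring above; they must exist in the repo):
    #   pipeline.load_voice('af_bella')              # single voice pack
    #   pipeline.load_voice('af_bella,af_jessica')   # element-wise mean of both packs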
    @staticmethod
    def tokens_to_ps(tokens: List[en.MToken]) -> str:
        return ''.join(t.phonemes + (' ' if t.whitespace else '') for t in tokens).strip()

    @staticmethod
    def waterfall_last(
        tokens: List[en.MToken],
        next_count: int,
        waterfall: List[str] = ['!.?…', ':;', ',—'],
        bumps: List[str] = [')', '”']
    ) -> int:
        # Find the last "safe" break point, preferring stronger punctuation first,
        # such that the phonemes remaining after the cut fit within the 510-character limit.
        for w in waterfall:
            z = next((i for i, t in reversed(list(enumerate(tokens))) if t.phonemes in set(w)), None)
            if z is None:
                continue
            z += 1
            if z < len(tokens) and tokens[z].phonemes in bumps:
                z += 1
            if next_count - len(KPipeline.tokens_to_ps(tokens[:z])) <= 510:
                return z
        return len(tokens)

    @staticmethod
    def tokens_to_text(tokens: List[en.MToken]) -> str:
        return ''.join(t.text + t.whitespace for t in tokens).strip()

    def en_tokenize(
        self,
        tokens: List[en.MToken]
    ) -> Generator[Tuple[str, str, List[en.MToken]], None, None]:
        # Greedily accumulate tokens into chunks whose phoneme strings stay within
        # the model's 510-character limit, cutting at waterfall_last break points.
        tks = []
        pcount = 0
        for t in tokens:
            # American English: ɾ => T
            t.phonemes = '' if t.phonemes is None else t.phonemes#.replace('ɾ', 'T')
            next_ps = t.phonemes + (' ' if t.whitespace else '')
            next_pcount = pcount + len(next_ps.rstrip())
            if next_pcount > 510:
                z = KPipeline.waterfall_last(tks, next_pcount)
                text = KPipeline.tokens_to_text(tks[:z])
                logger.debug(f"Chunking text at {z}: '{text[:30]}{'...' if len(text) > 30 else ''}'")
                ps = KPipeline.tokens_to_ps(tks[:z])
                yield text, ps, tks[:z]
                tks = tks[z:]
                pcount = len(KPipeline.tokens_to_ps(tks))
                if not tks:
                    next_ps = next_ps.lstrip()
            tks.append(t)
            pcount += len(next_ps)
        if tks:
            text = KPipeline.tokens_to_text(tks)
            ps = KPipeline.tokens_to_ps(tks)
            yield text, ps, tks

    @staticmethod
    def infer(
        model: KModel,
        ps: str,
        pack: torch.FloatTensor,
        speed: Union[float, Callable[[int], float]] = 1
    ) -> KModel.Output:
        if callable(speed):
            speed = speed(len(ps))
        return model(ps, pack[len(ps)-1], speed, return_output=True)

    def generate_from_tokens(
        self,
        tokens: Union[str, List[en.MToken]],
        voice: str,
        speed: float = 1,
        model: Optional[KModel] = None
    ) -> Generator['KPipeline.Result', None, None]:
        """Generate audio from either raw phonemes or pre-processed tokens.

        Args:
            tokens: Either a phoneme string or list of pre-processed MTokens
            voice: The voice to use for synthesis
            speed: Speech speed modifier (default: 1)
            model: Optional KModel instance (uses pipeline's model if not provided)

        Yields:
            KPipeline.Result containing the input tokens and generated audio

        Raises:
            ValueError: If no voice is provided or token sequence exceeds model limits
        """
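        # Usage sketch (the phoneme string below is illustrative, not validated):
        #   for result in pipeline.generate_from_tokens('həlˈO wˈɜɹld', voice='af_heart'):
        #       audio = result.audio
        # Lists of pre-processed MTokens (e.g. Result.tokens from a "quiet" pipeline)
        # are re-chunked via en_tokenize before inference.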
        model = model or self.model
        if model and voice is None:
            raise ValueError('Specify a voice: pipeline.generate_from_tokens(..., voice="af_heart")')

        pack = self.load_voice(voice).to(model.device) if model else None

        # Handle raw phoneme string
        if isinstance(tokens, str):
            logger.debug("Processing phonemes from raw string")
            if len(tokens) > 510:
                raise ValueError(f'Phoneme string too long: {len(tokens)} > 510')
            output = KPipeline.infer(model, tokens, pack, speed) if model else None
            yield self.Result(graphemes='', phonemes=tokens, output=output)
            return

        logger.debug("Processing MTokens")
        # Handle pre-processed tokens
        for gs, ps, tks in self.en_tokenize(tokens):
            if not ps:
                continue
            elif len(ps) > 510:
                logger.warning(f"Unexpected len(ps) == {len(ps)} > 510 and ps == '{ps}'")
                logger.warning("Truncating to 510 characters")
                ps = ps[:510]
            output = KPipeline.infer(model, ps, pack, speed) if model else None
            if output is not None and output.pred_dur is not None:
                KPipeline.join_timestamps(tks, output.pred_dur)
            yield self.Result(graphemes=gs, phonemes=ps, tokens=tks, output=output)

    @staticmethod
    def join_timestamps(tokens: List[en.MToken], pred_dur: torch.LongTensor):
        # Multiply by 600 to go from pred_dur frames to sample_rate 24000
        # Equivalent to dividing pred_dur frames by 40 to get timestamp in seconds
        # We will count nice round half-frames, so the divisor is 80
        MAGIC_DIVISOR = 80
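        # Worked example (illustrative): a token whose phonemes cover 20 predicted frames
        # contributes 2 * 20 = 40 half-frames, i.e. 40 / 80 = 0.5 seconds of audio.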
        if not tokens or len(pred_dur) < 3:
            # We expect at least 3: <bos>, token, <eos>
            return
        # We track 2 counts, measured in half-frames: (left, right)
        # This way we can cut space characters in half
        # TODO: Is -3 an appropriate offset?
        left = right = 2 * max(0, pred_dur[0].item() - 3)
        # Updates:
        # left = right + (2 * token_dur) + space_dur
        # right = left + space_dur
        i = 1
        for t in tokens:
            if i >= len(pred_dur)-1:
                break
            if not t.phonemes:
                if t.whitespace:
                    i += 1
                    left = right + pred_dur[i].item()
                    right = left + pred_dur[i].item()
                    i += 1
                continue
            j = i + len(t.phonemes)
            if j >= len(pred_dur):
                break
            t.start_ts = left / MAGIC_DIVISOR
            token_dur = pred_dur[i: j].sum().item()
            space_dur = pred_dur[j].item() if t.whitespace else 0
            left = right + (2 * token_dur) + space_dur
            t.end_ts = left / MAGIC_DIVISOR
            right = left + space_dur
            i = j + (1 if t.whitespace else 0)

    @dataclass
    class Result:
        graphemes: str
        phonemes: str
        tokens: Optional[List[en.MToken]] = None
        output: Optional[KModel.Output] = None
        text_index: Optional[int] = None

        @property
        def audio(self) -> Optional[torch.FloatTensor]:
            return None if self.output is None else self.output.audio

        @property
        def pred_dur(self) -> Optional[torch.LongTensor]:
            return None if self.output is None else self.output.pred_dur

        ### MARK: BEGIN BACKWARD COMPAT ###
        def __iter__(self):
            yield self.graphemes
            yield self.phonemes
            yield self.audio

        def __getitem__(self, index):
            return [self.graphemes, self.phonemes, self.audio][index]

        def __len__(self):
            return 3
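        # Backward compat in practice: a Result still unpacks like the old 3-tuple,
        # e.g. `graphemes, phonemes, audio = result`.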
        #### MARK: END BACKWARD COMPAT ####

    def __call__(
        self,
        text: Union[str, List[str]],
        voice: Optional[str] = None,
        speed: Union[float, Callable[[int], float]] = 1,
        split_pattern: Optional[str] = r'\n+',
        model: Optional[KModel] = None
    ) -> Generator['KPipeline.Result', None, None]:
        model = model or self.model
        if model and voice is None:
            raise ValueError('Specify a voice: en_us_pipeline(text="Hello world!", voice="af_heart")')
        pack = self.load_voice(voice).to(model.device) if model else None

        # Convert input to a list of segments
        if isinstance(text, str):
            text = re.split(split_pattern, text.strip()) if split_pattern else [text]

        # Process each segment
        for graphemes_index, graphemes in enumerate(text):
            if not graphemes.strip():  # Skip empty segments
                continue

            # English processing
            if self.lang_code in 'ab':
                logger.debug(f"Processing English text: {graphemes[:50]}{'...' if len(graphemes) > 50 else ''}")
                _, tokens = self.g2p(graphemes)
                for gs, ps, tks in self.en_tokenize(tokens):
                    if not ps:
                        continue
                    elif len(ps) > 510:
                        logger.warning(f"Unexpected len(ps) == {len(ps)} > 510 and ps == '{ps}'")
                        ps = ps[:510]
                    output = KPipeline.infer(model, ps, pack, speed) if model else None
                    if output is not None and output.pred_dur is not None:
                        KPipeline.join_timestamps(tks, output.pred_dur)
                    yield self.Result(graphemes=gs, phonemes=ps, tokens=tks, output=output, text_index=graphemes_index)

            # Non-English processing with chunking
            else:
                # Split long text into smaller chunks (roughly 400 characters each),
                # using sentence boundaries when possible
                chunk_size = 400
                chunks = []

                # Try to split on sentence boundaries first
                sentences = re.split(r'([.!?]+)', graphemes)
                current_chunk = ""

                for i in range(0, len(sentences), 2):
                    sentence = sentences[i]
                    # Add the punctuation back if it exists
                    if i + 1 < len(sentences):
                        sentence += sentences[i + 1]

                    if len(current_chunk) + len(sentence) <= chunk_size:
                        current_chunk += sentence
                    else:
                        if current_chunk:
                            chunks.append(current_chunk.strip())
                        current_chunk = sentence

                if current_chunk:
                    chunks.append(current_chunk.strip())

                # If no chunks were created (no sentence boundaries), fall back to character-based chunking
                if not chunks:
                    chunks = [graphemes[i:i+chunk_size] for i in range(0, len(graphemes), chunk_size)]

                # Process each chunk
                for chunk in chunks:
                    if not chunk.strip():
                        continue

                    ps, _ = self.g2p(chunk)
                    if not ps:
                        continue
                    elif len(ps) > 510:
                        logger.warning(f'Truncating len(ps) == {len(ps)} > 510')
                        ps = ps[:510]

                    output = KPipeline.infer(model, ps, pack, speed) if model else None
                    yield self.Result(graphemes=chunk, phonemes=ps, output=output, text_index=graphemes_index)
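
# A minimal end-to-end sketch, guarded so importing this module has no side effects.
# Assumptions not guaranteed by this file: soundfile is installed and the 'af_heart'
# voice exists in the configured repo; adjust names as needed.
if __name__ == '__main__':
    import soundfile as sf
    pipeline = KPipeline(lang_code='a', repo_id='hexgrad/Kokoro-82M')
    for i, result in enumerate(pipeline("Hello world!\nThis is a second segment.", voice='af_heart')):
        if result.audio is not None:
            # Kokoro audio is 24 kHz (see join_timestamps above)
            sf.write(f'segment_{i}.wav', result.audio.cpu().numpy(), 24000)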