import random
import os
import torch
import torch.distributed as dist
from PIL import Image
import subprocess
import torchvision.transforms as transforms
import torch.nn.functional as F
import torch.nn as nn
# from shared.utils.multitalk_utils import save_video_ffmpeg
# from .kokoro import KPipeline
from transformers import Wav2Vec2FeatureExtractor
from .wav2vec2 import Wav2Vec2Model

import librosa
import pyloudnorm as pyln
import numpy as np
from einops import rearrange
import soundfile as sf
import re
import math

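# custom_init loads a local Wav2Vec2 checkpoint (feature extractor + encoder),
# moves the encoder to `device`, and freezes its CNN feature extractor.
# A minimal usage sketch (the checkpoint path is illustrative; this module itself
# uses "ckpts/chinese-wav2vec2-base" further down):
#
#   feature_extractor, encoder = custom_init("cpu", "ckpts/chinese-wav2vec2-base")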
def custom_init(device, wav2vec):
    audio_encoder = Wav2Vec2Model.from_pretrained(wav2vec, local_files_only=True).to(device)
    audio_encoder.feature_extractor._freeze_parameters()
    wav2vec_feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(wav2vec, local_files_only=True)
    return wav2vec_feature_extractor, audio_encoder


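# loudness_norm normalizes integrated loudness to the target LUFS with pyloudnorm.
# For (near-)silent input the measured loudness can come back extremely low (even -inf),
# so the abs(loudness) > 100 guard below skips normalization instead of applying a huge gain.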
def loudness_norm(audio_array, sr=16000, lufs=-23):
    meter = pyln.Meter(sr)
    loudness = meter.integrated_loudness(audio_array)
    if abs(loudness) > 100:
        return audio_array
    normalized_audio = pyln.normalize.loudness(audio_array, loudness, lufs)
    return normalized_audio


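# get_embedding maps a 16 kHz waveform to per-video-frame audio features: the wav2vec
# feature extractor produces the input values, the encoder is run with
# seq_len = duration * fps so its output is aligned to video frames, and all hidden
# states after the first are stacked, giving a tensor of shape
# (num_frames, num_layers, hidden_dim) after the rearrange below.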
def get_embedding(speech_array, wav2vec_feature_extractor, audio_encoder, sr=16000, device='cpu', fps=25):
    audio_duration = len(speech_array) / sr
    video_length = audio_duration * fps

    # wav2vec feature extractor
    audio_feature = np.squeeze(
        wav2vec_feature_extractor(speech_array, sampling_rate=sr).input_values
    )
    audio_feature = torch.from_numpy(audio_feature).float().to(device=device)
    audio_feature = audio_feature.unsqueeze(0)

    # audio encoder
    with torch.no_grad():
        embeddings = audio_encoder(audio_feature, seq_len=int(video_length), output_hidden_states=True)

    if len(embeddings) == 0:
        print("Failed to extract audio embedding")
        return None

    audio_emb = torch.stack(embeddings.hidden_states[1:], dim=1).squeeze(0)
    audio_emb = rearrange(audio_emb, "b s d -> s b d")

    audio_emb = audio_emb.cpu().detach()
    return audio_emb


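# extract_audio_from_video shells out to ffmpeg (which must be on PATH) to dump the
# audio track of a video file as a temporary 16-bit PCM wav in the current directory,
# loads and loudness-normalizes it, then deletes the temporary file.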
def extract_audio_from_video(filename, sample_rate):
    raw_audio_path = os.path.splitext(os.path.basename(filename))[0] + '.wav'
    ffmpeg_command = [
        "ffmpeg",
        "-y",
        "-i",
        str(filename),
        "-vn",
        "-acodec",
        "pcm_s16le",
        "-ar",
        "16000",
        "-ac",
        "2",
        str(raw_audio_path),
    ]
    subprocess.run(ffmpeg_command, check=True)
    human_speech_array, sr = librosa.load(raw_audio_path, sr=sample_rate)
    human_speech_array = loudness_norm(human_speech_array, sr)
    os.remove(raw_audio_path)

    return human_speech_array


def audio_prepare_single(audio_path, sample_rate=16000, duration=0):
    ext = os.path.splitext(audio_path)[1].lower()
    if ext in ['.mp4', '.mov', '.avi', '.mkv']:
        human_speech_array = extract_audio_from_video(audio_path, sample_rate)
        return human_speech_array
    else:
        human_speech_array, sr = librosa.load(audio_path, duration=duration, sr=sample_rate)
        human_speech_array = loudness_norm(human_speech_array, sr)
        return human_speech_array


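# audio_prepare_multi combines the two speaker tracks:
#   * "para": both tracks play in parallel and are zero-padded to the same length;
#   * "add":  track 1 plays first, then track 2 (each padded with silence on the
#             other side so both arrays cover the full timeline).
# If only one guide is given the mode falls back to "para" with silence for the
# missing speaker. `pad` prepends silence (in samples) to both tracks and
# `min_audio_duration` enforces a minimum length in seconds; the returned
# `sum_human_speechs` mix is built before the `pad` silence is applied.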
def audio_prepare_multi(left_path, right_path, audio_type="add", sample_rate=16000, duration=0, pad=0, min_audio_duration=0):
    if not (left_path is None or right_path is None):
        human_speech_array1 = audio_prepare_single(left_path, duration=duration)
        human_speech_array2 = audio_prepare_single(right_path, duration=duration)
    else:
        audio_type = 'para'
        if left_path is None:
            human_speech_array2 = audio_prepare_single(right_path, duration=duration)
            human_speech_array1 = np.zeros(human_speech_array2.shape[0])
        elif right_path is None:
            human_speech_array1 = audio_prepare_single(left_path, duration=duration)
            human_speech_array2 = np.zeros(human_speech_array1.shape[0])

    if audio_type == 'para':
        new_human_speech1 = human_speech_array1
        new_human_speech2 = human_speech_array2
        if len(new_human_speech1) != len(new_human_speech2):
            if len(new_human_speech1) < len(new_human_speech2):
                new_human_speech1 = np.pad(new_human_speech1, (0, len(new_human_speech2) - len(new_human_speech1)))
            else:
                new_human_speech2 = np.pad(new_human_speech2, (0, len(new_human_speech1) - len(new_human_speech2)))
    elif audio_type == 'add':
        new_human_speech1 = np.concatenate([human_speech_array1, np.zeros(human_speech_array2.shape[0])])
        new_human_speech2 = np.concatenate([np.zeros(human_speech_array1.shape[0]), human_speech_array2])

    duration_changed = False
    if min_audio_duration > 0:
        min_samples = math.ceil(min_audio_duration * sample_rate)
        if len(new_human_speech1) < min_samples:
            new_human_speech1 = np.concatenate([new_human_speech1, np.zeros(min_samples - len(new_human_speech1))])
            duration_changed = True
        if len(new_human_speech2) < min_samples:
            new_human_speech2 = np.concatenate([new_human_speech2, np.zeros(min_samples - len(new_human_speech2))])
            duration_changed = True

    # Don't include the padding on the summed audio, which is used to build the output audio track.
    sum_human_speechs = new_human_speech1 + new_human_speech2

    if pad > 0:
        duration_changed = True
        new_human_speech1 = np.concatenate([np.zeros(pad), new_human_speech1])
        new_human_speech2 = np.concatenate([np.zeros(pad), new_human_speech2])

    return new_human_speech1, new_human_speech2, sum_human_speechs, duration_changed


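# Kokoro TTS helpers. The module-level `from .kokoro import KPipeline` import is
# commented out above, so the two helpers below import it lazily; Kokoro weights are
# expected under weights/Kokoro-82M, and the 24 kHz output is re-loaded at 16 kHz for
# the wav2vec pipeline.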
def process_tts_single(text, save_dir, voice1):
    # Imported lazily so the Kokoro dependency is only required when TTS is used.
    from .kokoro import KPipeline

    s1_sentences = []

    pipeline = KPipeline(lang_code='a', repo_id='weights/Kokoro-82M')

    voice_tensor = torch.load(voice1, weights_only=True)
    generator = pipeline(
        text, voice=voice_tensor,  # <= change voice here
        speed=1, split_pattern=r'\n+'
    )
    audios = []
    for i, (gs, ps, audio) in enumerate(generator):
        audios.append(audio)
    audios = torch.concat(audios, dim=0)
    s1_sentences.append(audios)

    s1_sentences = torch.concat(s1_sentences, dim=0)
    save_path1 = f'{save_dir}/s1.wav'
    sf.write(save_path1, s1_sentences, 24000)  # save each audio file
    s1, _ = librosa.load(save_path1, sr=16000)
    return s1, save_path1


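# process_tts_multi expects the dialogue text to carry speaker markers that the regex
# below parses, e.g. (illustrative): "(s1) Hello there. (s2) Hi, how are you?"
# Each segment is synthesized with the matching voice while the other speaker gets an
# equally long block of silence, so the two tracks stay time-aligned.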
def process_tts_multi(text, save_dir, voice1, voice2):
    # Imported lazily, see process_tts_single.
    from .kokoro import KPipeline

    pattern = r'\(s(\d+)\)\s*(.*?)(?=\s*\(s\d+\)|$)'
    matches = re.findall(pattern, text, re.DOTALL)

    s1_sentences = []
    s2_sentences = []

    pipeline = KPipeline(lang_code='a', repo_id='weights/Kokoro-82M')
    for idx, (speaker, content) in enumerate(matches):
        if speaker == '1':
            voice_tensor = torch.load(voice1, weights_only=True)
            generator = pipeline(
                content, voice=voice_tensor,  # <= change voice here
                speed=1, split_pattern=r'\n+'
            )
            audios = []
            for i, (gs, ps, audio) in enumerate(generator):
                audios.append(audio)
            audios = torch.concat(audios, dim=0)
            s1_sentences.append(audios)
            s2_sentences.append(torch.zeros_like(audios))
        elif speaker == '2':
            voice_tensor = torch.load(voice2, weights_only=True)
            generator = pipeline(
                content, voice=voice_tensor,  # <= change voice here
                speed=1, split_pattern=r'\n+'
            )
            audios = []
            for i, (gs, ps, audio) in enumerate(generator):
                audios.append(audio)
            audios = torch.concat(audios, dim=0)
            s2_sentences.append(audios)
            s1_sentences.append(torch.zeros_like(audios))

    s1_sentences = torch.concat(s1_sentences, dim=0)
    s2_sentences = torch.concat(s2_sentences, dim=0)
    sum_sentences = s1_sentences + s2_sentences
    save_path1 = f'{save_dir}/s1.wav'
    save_path2 = f'{save_dir}/s2.wav'
    save_path_sum = f'{save_dir}/sum.wav'
    sf.write(save_path1, s1_sentences, 24000)  # save each audio file
    sf.write(save_path2, s2_sentences, 24000)
    sf.write(save_path_sum, sum_sentences, 24000)

    s1, _ = librosa.load(save_path1, sr=16000)
    s2, _ = librosa.load(save_path2, sr=16000)
    # sum, _ = librosa.load(save_path_sum, sr=16000)
    return s1, s2, save_path_sum


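# get_full_audio_embeddings is the top-level entry point: it loads the local
# chinese-wav2vec2-base checkpoint on CPU, combines the (up to two) audio guides, and
# returns the per-speaker embeddings plus the mixed waveform used for the output audio
# track. A minimal usage sketch (file names are illustrative):
#
#   full_audio_embs, mixed = get_full_audio_embeddings(
#       audio_guide1="speaker1.wav", audio_guide2="speaker2.wav",
#       combination_type="para", num_frames=81, fps=25,
#   )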
def get_full_audio_embeddings(audio_guide1=None, audio_guide2=None, combination_type="add", num_frames=0, fps=25, sr=16000, padded_frames_for_embeddings=0, min_audio_duration=0, return_sum_only=False):
    wav2vec_feature_extractor, audio_encoder = custom_init('cpu', "ckpts/chinese-wav2vec2-base")
    # wav2vec_feature_extractor, audio_encoder = custom_init('cpu', "ckpts/wav2vec")
    pad = int(padded_frames_for_embeddings / fps * sr)
    new_human_speech1, new_human_speech2, sum_human_speechs, duration_changed = audio_prepare_multi(audio_guide1, audio_guide2, combination_type, duration=num_frames / fps, pad=pad, min_audio_duration=min_audio_duration)
    if return_sum_only:
        full_audio_embs = None
    else:
        audio_embedding_1 = get_embedding(new_human_speech1, wav2vec_feature_extractor, audio_encoder, sr=sr, fps=fps)
        audio_embedding_2 = get_embedding(new_human_speech2, wav2vec_feature_extractor, audio_encoder, sr=sr, fps=fps)
        full_audio_embs = []
        if audio_guide1 is not None: full_audio_embs.append(audio_embedding_1)
        if audio_guide2 is not None: full_audio_embs.append(audio_embedding_2)
    if audio_guide2 is None and not duration_changed: sum_human_speechs = None
    return full_audio_embs, sum_human_speechs


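# get_window_audio_embeddings slices the per-frame embeddings for one generation
# window: for every frame in [audio_start_idx, audio_end_idx) it gathers a centered
# window of neighboring frames (the +-2 offsets in `indices` match the default
# audio_window=5), clamped at the clip boundaries, then regroups the result to match
# the VAE temporal compression (`vae_scale` frames per latent frame). It returns the
# pair [first_frame_audio_emb_s, latter_frame_audio_emb_s].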
def get_window_audio_embeddings(full_audio_embs, audio_start_idx=0, clip_length=81, vae_scale=4, audio_window=5):
    if full_audio_embs is None: return None
    HUMAN_NUMBER = len(full_audio_embs)
    audio_end_idx = audio_start_idx + clip_length
    indices = (torch.arange(2 * 2 + 1) - 2) * 1

    audio_embs = []
    # split audio with window size
    for human_idx in range(HUMAN_NUMBER):
        center_indices = torch.arange(audio_start_idx, audio_end_idx, 1).unsqueeze(1) + indices.unsqueeze(0)
        center_indices = torch.clamp(center_indices, min=0, max=full_audio_embs[human_idx].shape[0] - 1).to(full_audio_embs[human_idx].device)
        audio_emb = full_audio_embs[human_idx][center_indices][None, ...]  # .to(self.device)
        audio_embs.append(audio_emb)
    audio_embs = torch.concat(audio_embs, dim=0)  # .to(self.param_dtype)

    # audio_cond = audio.to(device=x.device, dtype=x.dtype)
    audio_cond = audio_embs
    first_frame_audio_emb_s = audio_cond[:, :1, ...]
    latter_frame_audio_emb = audio_cond[:, 1:, ...]
    latter_frame_audio_emb = rearrange(latter_frame_audio_emb, "b (n_t n) w s c -> b n_t n w s c", n=vae_scale)
    middle_index = audio_window // 2
    latter_first_frame_audio_emb = latter_frame_audio_emb[:, :, :1, :middle_index + 1, ...]
    latter_first_frame_audio_emb = rearrange(latter_first_frame_audio_emb, "b n_t n w s c -> b n_t (n w) s c")
    latter_last_frame_audio_emb = latter_frame_audio_emb[:, :, -1:, middle_index:, ...]
    latter_last_frame_audio_emb = rearrange(latter_last_frame_audio_emb, "b n_t n w s c -> b n_t (n w) s c")
    latter_middle_frame_audio_emb = latter_frame_audio_emb[:, :, 1:-1, middle_index:middle_index + 1, ...]
    latter_middle_frame_audio_emb = rearrange(latter_middle_frame_audio_emb, "b n_t n w s c -> b n_t (n w) s c")
    latter_frame_audio_emb_s = torch.concat([latter_first_frame_audio_emb, latter_middle_frame_audio_emb, latter_last_frame_audio_emb], dim=2)

    return [first_frame_audio_emb_s, latter_frame_audio_emb_s]


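# resize_and_centercrop scales the input by the larger of the two height/width ratios
# so it fully covers the target size, then center-crops to (target_h, target_w).
# Tensors are resized with nearest-neighbor interpolation; PIL images are resized
# bilinearly and come back as a tensor with an extra singleton time dimension.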
def resize_and_centercrop(cond_image, target_size):
    """
    Resize image or tensor to cover the target size without padding, then center-crop.
    """

    # Get the original size
    if isinstance(cond_image, torch.Tensor):
        _, orig_h, orig_w = cond_image.shape
    else:
        orig_h, orig_w = cond_image.height, cond_image.width

    target_h, target_w = target_size

    # Calculate the scaling factor for resizing
    scale_h = target_h / orig_h
    scale_w = target_w / orig_w

    # Compute the final size
    scale = max(scale_h, scale_w)
    final_h = math.ceil(scale * orig_h)
    final_w = math.ceil(scale * orig_w)

    # Resize
    if isinstance(cond_image, torch.Tensor):
        if len(cond_image.shape) == 3:
            cond_image = cond_image[None]
        resized_tensor = nn.functional.interpolate(cond_image, size=(final_h, final_w), mode='nearest').contiguous()
        # crop
        cropped_tensor = transforms.functional.center_crop(resized_tensor, target_size)
        cropped_tensor = cropped_tensor.squeeze(0)
    else:
        resized_image = cond_image.resize((final_w, final_h), resample=Image.BILINEAR)
        resized_image = np.array(resized_image)
        # to tensor and crop
        resized_tensor = torch.from_numpy(resized_image)[None, ...].permute(0, 3, 1, 2).contiguous()
        cropped_tensor = transforms.functional.center_crop(resized_tensor, target_size)
        cropped_tensor = cropped_tensor[:, :, None, :, :]

    return cropped_tensor


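# timestep_transform remaps a timestep t in [0, num_timesteps]: it is normalized,
# shifted with t' = shift * t / (1 + (shift - 1) * t), and rescaled back. For shift > 1
# this biases sampling toward larger timesteps; e.g. with shift=5, t=500 of 1000 maps
# to roughly 833.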
def timestep_transform(
    t,
    shift=5.0,
    num_timesteps=1000,
):
    t = t / num_timesteps
    # shift the timestep based on ratio
    new_t = shift * t / (1 + (shift - 1) * t)
    new_t = new_t * num_timesteps
    return new_t


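# parse_speakers_locations turns a string such as "0:45 55:100" (illustrative) into
# bounding boxes for the two speakers. Each location is either Left:Right or
# Left:Top:Right:Bottom, with all values given as percentages (0-100); the two-value
# form is expanded to a full-height box. On success it returns
# ({"human0": ..., "human1": ...}, ""), on failure ("", error_message).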
def parse_speakers_locations(speakers_locations):
    bbox = {}
    if speakers_locations is None or len(speakers_locations) == 0:
        return None, ""
    speakers = speakers_locations.split(" ")
    if len(speakers) != 2:
        error = "Two speaker locations should be defined"
        return "", error

    for i, speaker in enumerate(speakers):
        location = speaker.strip().split(":")
        if len(location) not in (2, 4):
            error = f"Invalid Speaker Location '{location}'. A Speaker Location should be defined in the format Left:Right or using a BBox Left:Top:Right:Bottom"
            return "", error
        try:
            good = False
            location_float = [float(val) for val in location]
            good = all(0 <= val <= 100 for val in location_float)
        except ValueError:
            pass
        if not good:
            error = f"Invalid Speaker Location '{location}'. Each number should be between 0 and 100."
            return "", error
        if len(location_float) == 2:
            location_float = [location_float[0], 0, location_float[1], 100]
        bbox[f"human{i}"] = location_float
    return bbox, ""


# construct human mask
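# get_target_masks builds one spatial mask per speaker plus a background mask at the
# source resolution: with a single speaker everything is ones; with two speakers it
# either uses the parsed percentage bboxes or simply splits the frame into a left and a
# right half (shrunk by `face_scale`). The masks are then downsampled to the latent
# token grid (lat_h//2 x lat_w//2), binarized, and flattened to one row per mask.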
def get_target_masks(HUMAN_NUMBER, lat_h, lat_w, src_h, src_w, face_scale=0.05, bbox=None):
    human_masks = []
    if HUMAN_NUMBER == 1:
        background_mask = torch.ones([src_h, src_w])
        human_mask1 = torch.ones([src_h, src_w])
        human_mask2 = torch.ones([src_h, src_w])
        human_masks = [human_mask1, human_mask2, background_mask]
    elif HUMAN_NUMBER == 2:
        if bbox is not None:
            assert len(bbox) == HUMAN_NUMBER, "The number of target bboxes should be the same as cond_audio"
            background_mask = torch.zeros([src_h, src_w])
            for _, person_bbox in bbox.items():
                y_min, x_min, y_max, x_max = person_bbox
                x_min, y_min, x_max, y_max = max(x_min, 5), max(y_min, 5), min(x_max, 95), min(y_max, 95)
                x_min, y_min, x_max, y_max = int(src_h * x_min / 100), int(src_w * y_min / 100), int(src_h * x_max / 100), int(src_w * y_max / 100)
                human_mask = torch.zeros([src_h, src_w])
                human_mask[int(x_min):int(x_max), int(y_min):int(y_max)] = 1
                background_mask += human_mask
                human_masks.append(human_mask)
        else:
            x_min, x_max = int(src_h * face_scale), int(src_h * (1 - face_scale))
            background_mask = torch.zeros([src_h, src_w])
            human_mask1 = torch.zeros([src_h, src_w])
            human_mask2 = torch.zeros([src_h, src_w])
            lefty_min, lefty_max = int((src_w // 2) * face_scale), int((src_w // 2) * (1 - face_scale))
            righty_min, righty_max = int((src_w // 2) * face_scale + (src_w // 2)), int((src_w // 2) * (1 - face_scale) + (src_w // 2))
            human_mask1[x_min:x_max, lefty_min:lefty_max] = 1
            human_mask2[x_min:x_max, righty_min:righty_max] = 1
            background_mask += human_mask1
            background_mask += human_mask2
            human_masks = [human_mask1, human_mask2]
        # Invert to get the background; use float 0/1 so the dtype matches the per-speaker masks when stacking.
        background_mask = torch.where(background_mask > 0, torch.tensor(0.0), torch.tensor(1.0))
        human_masks.append(background_mask)
    # toto = Image.fromarray(human_masks[2].mul_(255).unsqueeze(-1).repeat(1,1,3).to(torch.uint8).cpu().numpy())
    ref_target_masks = torch.stack(human_masks, dim=0)  # .to(self.device)
    # resize and centercrop for ref_target_masks
    # ref_target_masks = resize_and_centercrop(ref_target_masks, (target_h, target_w))
    N_h, N_w = lat_h // 2, lat_w // 2
    token_ref_target_masks = F.interpolate(ref_target_masks.unsqueeze(0), size=(N_h, N_w), mode='nearest').squeeze()
    token_ref_target_masks = (token_ref_target_masks > 0)
    token_ref_target_masks = token_ref_target_masks.float()  # .to(self.device)

    token_ref_target_masks = token_ref_target_masks.view(token_ref_target_masks.shape[0], -1)

    return token_ref_target_masks