import logging
import os
from pathlib import Path
from typing import Optional, Union

import open_clip
import pandas as pd
import torch
import torchaudio
from torch.utils.data.dataset import Dataset

log = logging.getLogger()


class WavTextClipsDataset(Dataset):
    """Pairs fixed-length audio clips with their captions and CLIP text tokens.

    Clips are defined by rows of `clips_tsv`; captions come from `captions_tsv`
    and are matched to clips by audio name. Audio is read from `root`.
    """

    def __init__(
        self,
        root: Union[str, Path],
        *,
        captions_tsv: Union[str, Path],
        clips_tsv: Union[str, Path],
        sample_rate: int,
        num_samples: int,
        normalize_audio: bool = False,
        reject_silent: bool = False,
        tokenizer_id: str = 'ViT-H-14-378-quickgelu',
    ):
        self.root = Path(root)
        self.sample_rate = sample_rate
        self.num_samples = num_samples
        self.normalize_audio = normalize_audio
        self.reject_silent = reject_silent
        self.tokenizer = open_clip.get_tokenizer(tokenizer_id)
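
        # NOTE: open_clip tokenizers pad/truncate captions to the model's
        # context length (77 tokens for CLIP-family models), so the 'tokens'
        # entry returned by __getitem__ has a fixed shape and batches directly.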

        # stems of the audio files present under root (currently unused; the
        # clip list below is driven by the TSVs, not by this set)
        audios = sorted(os.listdir(self.root))
        audios = set([
            Path(audio).stem for audio in audios
            if audio.endswith('.wav') or audio.endswith('.flac')
        ])
        self.captions = {}
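
        # Expected TSV layouts, inferred from the reads below (an assumption,
        # not a documented schema):
        #   captions_tsv: id <TAB> caption
        #   clips_tsv:    id <TAB> name <TAB> start_sample <TAB> end_sample
        # start_sample/end_sample index the source file at its native sample
        # rate, since __getitem__ slices before resampling.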

        # read the caption tsv, keyed by id
        df_list = pd.read_csv(captions_tsv, sep='\t', dtype={'id': str}).to_dict('records')
        for record in df_list:
            self.captions[record['id']] = record['caption']

        # read the clip tsv; keep only clips whose audio has a caption
        df_list = pd.read_csv(clips_tsv, sep='\t', dtype={
            'id': str,
            'name': str
        }).to_dict('records')
        self.clips = []
        for record in df_list:
            name = record['name']
            if name not in self.captions:
                log.warning(f'Audio {name} not found in {captions_tsv}')
                continue
            record['caption'] = self.captions[name]
            self.clips.append(record)

        log.info(f'Found {len(self.clips)} usable clips in {self.root}')

        # resamplers are created lazily and cached per source sample rate
        self.resampler = {}

    def __getitem__(self, idx: int) -> Optional[dict]:
        try:
            clip = self.clips[idx]
            audio_name = clip['name']
            audio_id = clip['id']
            caption = clip['caption']
            start_sample = clip['start_sample']
            end_sample = clip['end_sample']

            # prefer .flac, fall back to .wav
            audio_path = self.root / f'{audio_name}.flac'
            if not audio_path.exists():
                audio_path = self.root / f'{audio_name}.wav'
                assert audio_path.exists(), f'No .flac or .wav for {audio_name}'

            audio_chunk, sample_rate = torchaudio.load(audio_path)
            audio_chunk = audio_chunk.mean(dim=0)  # downmix to mono
            abs_max = audio_chunk.abs().max()

            if self.reject_silent and abs_max < 1e-6:
                log.warning(f'Rejecting silent audio: {audio_path}')
                return None

            # peak-normalize to 0.95; the abs_max > 0 guard avoids dividing
            # by zero on all-silent audio when reject_silent is off
            if self.normalize_audio and abs_max > 0:
                audio_chunk = audio_chunk / abs_max * 0.95

            audio_chunk = audio_chunk[start_sample:end_sample]

            # resample to the target rate if needed, caching one resampler per
            # source rate with "kaiser best" parameters, see
            # https://pytorch.org/audio/stable/tutorials/audio_resampling_tutorial.html#kaiser-best
            if sample_rate != self.sample_rate:
                if sample_rate not in self.resampler:
                    self.resampler[sample_rate] = torchaudio.transforms.Resample(
                        sample_rate,
                        self.sample_rate,
                        lowpass_filter_width=64,
                        rolloff=0.9475937167399596,
                        resampling_method='sinc_interp_kaiser',
                        beta=14.769656459379492,
                    )
                audio_chunk = self.resampler[sample_rate](audio_chunk)

            if audio_chunk.shape[0] < self.num_samples:
                raise ValueError('Audio is too short')
            audio_chunk = audio_chunk[:self.num_samples]

            tokens = self.tokenizer([caption])[0]

            output = {
                'waveform': audio_chunk,
                'id': audio_id,
                'caption': caption,
                'tokens': tokens,
            }

            return output
        except Exception as e:
            # avoid referencing audio_path here: it may be unbound if the
            # failure happened before the path was resolved
            log.error(f'Error reading clip {idx}: {e}')
            return None

    def __len__(self):
        return len(self.clips)
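

# A minimal usage sketch (not part of the original file). Because __getitem__
# returns None for silent or unreadable clips, a DataLoader needs a collate_fn
# that drops those entries before batching. All paths below are hypothetical.
if __name__ == '__main__':
    from torch.utils.data import DataLoader
    from torch.utils.data.dataloader import default_collate

    def skip_none_collate(batch):
        # drop samples the dataset rejected (returned as None)
        batch = [item for item in batch if item is not None]
        return default_collate(batch) if batch else None

    dataset = WavTextClipsDataset(
        root='data/audio',                 # hypothetical path
        captions_tsv='data/captions.tsv',  # hypothetical path
        clips_tsv='data/clips.tsv',        # hypothetical path
        sample_rate=16000,
        num_samples=16000 * 8,             # 8-second clips at 16 kHz
    )
    loader = DataLoader(dataset, batch_size=4, collate_fn=skip_none_collate)
    for batch in loader:
        if batch is None:  # every sample in the batch was rejected
            continue
        print(batch['waveform'].shape, batch['tokens'].shape)
        break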