# Mirror of https://github.com/Wan-Video/Wan2.1.git
# Synced 2025-11-04 14:16:57 +00:00
# 132 lines, 4.2 KiB, Python
import json
|
|
import logging
|
|
import os
|
|
from pathlib import Path
|
|
from typing import Union
|
|
|
|
import torch
|
|
from torch.utils.data.dataset import Dataset
|
|
from torchvision.transforms import v2
|
|
from torio.io import StreamingMediaDecoder
|
|
|
|
from ...utils.dist_utils import local_rank
|
|
|
|
# Module-scoped logger: records still propagate to the root logger's handlers,
# but carry this module's name instead of being emitted as "root".
log = logging.getLogger(__name__)

# CLIP stream: frames are resized to a _CLIP_SIZE x _CLIP_SIZE square and the
# video is decoded at _CLIP_FPS frames per second.
_CLIP_SIZE = 384
_CLIP_FPS = 8.0

# Sync stream: frames are resized to a _SYNC_SIZE x _SYNC_SIZE square and the
# video is decoded at _SYNC_FPS frames per second.
_SYNC_SIZE = 224
_SYNC_FPS = 25.0
|
|
|
|
|
|
class MovieGenData(Dataset):
    """Video dataset pairing ``.mp4`` clips with their audio-prompt captions.

    Every ``<name>.mp4`` found in ``video_root`` must have a matching
    ``<name>.jsonl`` file in ``jsonl_root`` containing a JSON object with an
    ``audio_prompt`` key.  Each item decodes the first chunk of the video
    twice: once at ``_CLIP_FPS`` (resized to ``_CLIP_SIZE``) and once at
    ``_SYNC_FPS`` (resized to ``_SYNC_SIZE`` and normalized).
    """

    # Extension that ``sample`` appends when opening a video file.
    _VIDEO_SUFFIX = '.mp4'

    def __init__(
        self,
        video_root: Union[str, Path],
        sync_root: Union[str, Path],
        jsonl_root: Union[str, Path],
        *,
        duration_sec: float = 10.0,
        read_clip: bool = True,
    ):
        """Index the videos and load all captions up front.

        Args:
            video_root: Directory containing the ``.mp4`` files.
            sync_root: Stored on the instance; not otherwise used by this class.
            jsonl_root: Directory containing one ``<name>.jsonl`` per video.
            duration_sec: Length of video (in seconds) each sample must cover.
            read_clip: Stored on the instance; not otherwise used by this class.

        Raises:
            FileNotFoundError: If a video has no matching caption file.
            KeyError: If a caption file has no ``audio_prompt`` entry.
        """
        self.video_root = Path(video_root)
        self.sync_root = Path(sync_root)
        self.jsonl_root = Path(jsonl_root)
        self.read_clip = read_clip

        # Keep only real video files: ``sample`` can open nothing but
        # ``<name>.mp4``, and a stray entry (e.g. ``.DS_Store``) would otherwise
        # be mangled into a bogus id and crash the caption lookup below.
        # Sorting the raw file names first keeps the index order stable.
        suffix_len = len(self._VIDEO_SUFFIX)
        videos = [
            name[:-suffix_len]
            for name in sorted(os.listdir(self.video_root))
            if name.lower().endswith(self._VIDEO_SUFFIX)
        ]

        self.captions = {}
        for video_id in videos:
            caption_path = self.jsonl_root / (video_id + '.jsonl')
            # JSON text is UTF-8; do not depend on the platform's locale.
            with open(caption_path, encoding='utf-8') as f:
                data = json.load(f)
            self.captions[video_id] = data['audio_prompt']

        # Only one process per node reports the dataset size.
        if local_rank == 0:
            log.info('%d videos found in %s', len(videos), video_root)

        self.duration_sec = duration_sec

        # Number of frames each stream must provide for ``duration_sec``.
        self.clip_expected_length = int(_CLIP_FPS * self.duration_sec)
        self.sync_expected_length = int(_SYNC_FPS * self.duration_sec)

        self.clip_augment = v2.Compose([
            v2.Resize((_CLIP_SIZE, _CLIP_SIZE), interpolation=v2.InterpolationMode.BICUBIC),
            v2.ToImage(),
            v2.ToDtype(torch.float32, scale=True),
        ])

        self.sync_augment = v2.Compose([
            v2.Resize((_SYNC_SIZE, _SYNC_SIZE), interpolation=v2.InterpolationMode.BICUBIC),
            # Resize already yields a _SYNC_SIZE square, so this crop keeps
            # the spatial size unchanged.
            v2.CenterCrop(_SYNC_SIZE),
            v2.ToImage(),
            v2.ToDtype(torch.float32, scale=True),
            v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
        ])

        self.videos = videos

    @staticmethod
    def _trim_chunk(chunk, expected_length: int, label: str, video_id: str):
        """Return the first ``expected_length`` frames of ``chunk``.

        Args:
            chunk: Decoded frames with the frame count in dimension 0, or
                ``None`` when the decoder produced nothing for the stream.
            expected_length: Required number of frames.
            label: Stream name used in error messages (``'CLIP'`` / ``'Sync'``).
            video_id: Video identifier used in error messages.

        Raises:
            RuntimeError: If ``chunk`` is ``None`` or has too few frames.
        """
        if chunk is None:
            raise RuntimeError(f'{label} video returned None {video_id}')
        if chunk.shape[0] < expected_length:
            raise RuntimeError(f'{label} video too short {video_id}')
        # At least ``expected_length`` frames exist, so the slice is exact.
        return chunk[:expected_length]

    def sample(self, idx: int) -> dict[str, Union[str, torch.Tensor]]:
        """Decode and preprocess the video at position ``idx``.

        Returns:
            A dict with ``name`` (video id), ``caption`` (audio prompt),
            ``clip_video`` and ``sync_video`` (the transformed frame tensors).

        Raises:
            RuntimeError: If either stream is missing or shorter than the
                number of frames required by ``duration_sec``.
        """
        video_id = self.videos[idx]
        caption = self.captions[video_id]

        # One decoder, two output streams over the same file; chunks come back
        # in the order the streams were added (CLIP first, then sync).
        reader = StreamingMediaDecoder(self.video_root / (video_id + self._VIDEO_SUFFIX))
        reader.add_basic_video_stream(
            frames_per_chunk=self.clip_expected_length,
            frame_rate=_CLIP_FPS,
            format='rgb24',
        )
        reader.add_basic_video_stream(
            frames_per_chunk=self.sync_expected_length,
            frame_rate=_SYNC_FPS,
            format='rgb24',
        )

        reader.fill_buffer()
        clip_chunk, sync_chunk = reader.pop_chunks()

        # Validate both streams before running any (expensive) transform.
        clip_chunk = self._trim_chunk(clip_chunk, self.clip_expected_length, 'CLIP', video_id)
        sync_chunk = self._trim_chunk(sync_chunk, self.sync_expected_length, 'Sync', video_id)

        return {
            'name': video_id,
            'caption': caption,
            'clip_video': self.clip_augment(clip_chunk),
            'sync_video': self.sync_augment(sync_chunk),
        }

    def __getitem__(self, idx: int) -> dict[str, Union[str, torch.Tensor]]:
        """Return the sample at ``idx``; see :meth:`sample`."""
        return self.sample(idx)

    def __len__(self) -> int:
        """Return the number of indexed videos."""
        # ``sample`` indexes ``self.videos``, so that list defines the length.
        return len(self.videos)
|