# Mirror of https://github.com/Wan-Video/Wan2.1.git (synced 2025-11-04 14:16:57 +00:00)
# 170 lines, 6.8 KiB, Python
import os
|
||
import cv2
|
||
import math
|
||
import json
|
||
import torch
|
||
import random
|
||
import librosa
|
||
import traceback
|
||
import torchvision
|
||
import numpy as np
|
||
import pandas as pd
|
||
from PIL import Image
|
||
from einops import rearrange
|
||
from torch.utils.data import Dataset
|
||
from decord import VideoReader, cpu
|
||
from transformers import CLIPImageProcessor
|
||
import torchvision.transforms as transforms
|
||
from torchvision.transforms import ToPILImage
|
||
|
||
|
||
|
||
def get_audio_feature(feature_extractor, audio_path):
    """Load an audio file and extract features from it window by window.

    The waveform is loaded with ``librosa`` resampled to 16 kHz, cut into
    consecutive windows of ``750 * 640`` samples (30 s at 16 kHz; the last
    window may be shorter), and each window is passed to ``feature_extractor``.

    Args:
        feature_extractor: Callable invoked as
            ``feature_extractor(samples, sampling_rate=..., return_tensors="pt")``
            whose result exposes ``.input_features``.
            NOTE(review): presumably a Whisper-style feature extractor -
            confirm against the caller.
        audio_path: Path of the audio file to load.

    Returns:
        A tuple ``(audio_features, num_frames)`` where ``audio_features`` is
        the per-window ``input_features`` concatenated along the last
        dimension and ``num_frames`` is ``len(waveform) // 640``.
        NOTE(review): 640 samples at 16 kHz is 1/25 s, so ``num_frames`` is
        presumably a count of 25 fps video frames - confirm.
    """
    samples_per_frame = 640
    chunk_size = 750 * samples_per_frame

    waveform, sampling_rate = librosa.load(audio_path, sr=16000)
    assert sampling_rate == 16000

    # One extractor call per window; the slice end may run past the waveform,
    # in which case the final window is simply shorter.
    chunk_features = [
        feature_extractor(
            waveform[start:start + chunk_size],
            sampling_rate=sampling_rate,
            return_tensors="pt",
        ).input_features
        for start in range(0, len(waveform), chunk_size)
    ]

    stacked = torch.cat(chunk_features, dim=-1)
    return stacked, len(waveform) // samples_per_frame
|
||
|
||
|
||
class VideoAudioTextLoaderVal(Dataset):
    """Validation dataset yielding one reference image, audio clip and prompt per item.

    Samples are described by a CSV file with the columns ``videoid``,
    ``image``, ``audio``, ``prompt`` and ``fps``. Each item is returned as a
    dict of tensors and metadata built by :meth:`get_batch_data`.
    """

    def __init__(
        self,
        image_size: int,
        meta_file: str,
        **kwargs,
    ):
        """Read the sample list and set up the image preprocessing pipelines.

        Args:
            image_size: Target length of the shorter image side before it is
                snapped to a multiple of 64.
            meta_file: Path of the CSV file listing the samples.
            **kwargs: Optional ``text_encoder``, ``text_encoder_2`` and
                ``feature_extractor`` objects; each defaults to ``None``.
        """
        super().__init__()
        self.meta_file = meta_file
        self.image_size = image_size
        self.text_encoder = kwargs.get("text_encoder", None)  # llava_text_encoder
        self.text_encoder_2 = kwargs.get("text_encoder_2", None)  # clipL_text_encoder
        self.feature_extractor = kwargs.get("feature_extractor", None)

        csv_data = pd.read_csv(meta_file)
        # Walk the five columns in lockstep instead of indexing cell by cell.
        columns = [
            csv_data[name]
            for name in ("videoid", "image", "audio", "prompt", "fps")
        ]
        self.meta_files = [
            {
                "videoid": str(videoid),
                "image_path": str(image),
                "audio_path": str(audio),
                "prompt": str(prompt),
                "fps": float(fps),
            }
            for videoid, image, audio, prompt, fps in zip(*columns)
        ]

        self.llava_transform = transforms.Compose(
            [
                transforms.Resize((336, 336), interpolation=transforms.InterpolationMode.BILINEAR),
                transforms.ToTensor(),
                transforms.Normalize((0.48145466, 0.4578275, 0.4082107), (0.26862954, 0.26130258, 0.27577711)),
            ]
        )
        self.clip_image_processor = CLIPImageProcessor()

        self.device = torch.device("cuda")
        self.weight_dtype = torch.float16

    def __len__(self):
        """Return the number of samples listed in the meta file."""
        return len(self.meta_files)

    @staticmethod
    def get_text_tokens(text_encoder, description, dtype_encode="video"):
        """Tokenize ``description`` and drop the leading batch dimension.

        Args:
            text_encoder: Object exposing ``text2tokens(text, data_type=...)``
                that returns a mapping with ``input_ids`` and
                ``attention_mask`` tensors.
            description: Text to tokenize.
            dtype_encode: Forwarded to ``text2tokens`` as ``data_type``.

        Returns:
            Tuple ``(text_ids, text_mask)`` with dimension 0 squeezed.
        """
        text_inputs = text_encoder.text2tokens(description, data_type=dtype_encode)
        text_ids = text_inputs["input_ids"].squeeze(0)
        text_mask = text_inputs["attention_mask"].squeeze(0)
        return text_ids, text_mask

    @staticmethod
    def _fit_reference_size(width, height, img_size):
        """Return ``(new_w, new_h)`` for the reference image, both multiples of 64."""
        scale = img_size / min(width, height)
        new_w = round(width * scale / 64) * 64
        new_h = round(height * scale / 64) * 64

        if img_size == 704:
            # Cap the pixel area at 704 x 1216 so very elongated images do
            # not grow without bound; rescale to that area when exceeded.
            img_size_long = 1216
            if new_w * new_h > img_size * img_size_long:
                scale = math.sqrt(img_size * img_size_long / width / height)
                new_w = round(width * scale / 64) * 64
                new_h = round(height * scale / 64) * 64

        return new_w, new_h

    def get_batch_data(self, idx):
        """Build the full sample dict for the entry at position ``idx``.

        Loads and resizes the reference image, extracts audio features via
        ``get_audio_feature``, preprocesses the image for the LLaVA and CLIP
        encoders, and tokenizes the prompt with both text encoders.

        Args:
            idx: Index into ``self.meta_files``.

        Returns:
            Dict with the prompt, ids, paths, image tensors (float16), audio
            features (float16), motion bucket ids, fps and token tensors.
        """
        meta = self.meta_files[idx]
        videoid = meta["videoid"]
        image_path = meta["image_path"]
        audio_path = meta["audio_path"]
        prompt = "Authentic, Realistic, Natural, High-quality, Lens-Fixed, " + meta["prompt"]

        # Close the file handle promptly; convert() returns a detached copy.
        with Image.open(image_path) as raw_image:
            ref_image = raw_image.convert('RGB')

        # Resize reference image
        width, height = ref_image.size
        new_w, new_h = self._fit_reference_size(width, height, self.image_size)
        ref_image = ref_image.resize((new_w, new_h), Image.LANCZOS)

        # uint8 pixels, laid out as (1, c, h, w)
        pixel_value_ref = rearrange(
            torch.from_numpy(np.array(ref_image)).unsqueeze(0),
            "b h w c -> b c h w",
        )

        audio_input, audio_len = get_audio_feature(self.feature_extractor, audio_path)
        audio_prompts = audio_input[0]

        # Both preprocessors take the resized PIL image directly; rebuilding
        # it from the uint8 tensor would produce the same pixels.
        pixel_value_ref_llava = self.llava_transform(ref_image).unsqueeze(0)
        pixel_value_ref_clip = self.clip_image_processor(
            images=ref_image,
            return_tensors="pt",
        ).pixel_values[0].unsqueeze(0)

        # Encode text prompts
        text_ids, text_mask = self.get_text_tokens(self.text_encoder, prompt)
        text_ids_2, text_mask_2 = self.get_text_tokens(self.text_encoder_2, prompt)

        # Fixed motion bucket ids, stored with the same dtype as the token ids.
        motion_bucket_id_heads = torch.full((4,), 25, dtype=text_ids.dtype)
        motion_bucket_id_exps = torch.full((4,), 30, dtype=text_ids.dtype)
        fps = torch.tensor(meta["fps"], dtype=torch.float16)

        # Output batch
        batch = {
            "text_prompt": prompt,
            "videoid": videoid,
            # Reference image for VAE feature extraction, (1, 3, h, w), values in 0..255
            "pixel_value_ref": pixel_value_ref.to(dtype=torch.float16),
            # Reference image for LLaVA feature extraction, (1, 3, 336, 336), CLIP-normalized
            "pixel_value_ref_llava": pixel_value_ref_llava.to(dtype=torch.float16),
            # Reference image for the CLIP image encoder, CLIP-normalized.
            # NOTE(review): the original comment gave the shape as (1, 3, 244, 244);
            # presumably (1, 3, 224, 224) with default processor settings - confirm.
            "pixel_value_ref_clip": pixel_value_ref_clip.to(dtype=torch.float16),
            "audio_prompts": audio_prompts.to(dtype=torch.float16),
            "motion_bucket_id_heads": motion_bucket_id_heads,
            "motion_bucket_id_exps": motion_bucket_id_exps,
            "fps": fps,
            "text_ids": text_ids.clone(),  # from llava_text_encoder
            "text_mask": text_mask.clone(),  # from llava_text_encoder
            "text_ids_2": text_ids_2.clone(),  # from clip_text_encoder
            "text_mask_2": text_mask_2.clone(),  # from clip_text_encoder
            "audio_len": audio_len,
            "image_path": image_path,
            "audio_path": audio_path,
        }
        return batch

    def __getitem__(self, idx):
        """Return the sample dict for ``idx`` (see :meth:`get_batch_data`)."""
        return self.get_batch_data(idx)
|
||
|
||
|
||
|
||
|