import os
import cv2
import math
import json
import torch
import random
import librosa
import traceback
import torchvision
import numpy as np
import pandas as pd
from PIL import Image
from einops import rearrange
from torch.utils.data import Dataset
from decord import VideoReader, cpu
from transformers import CLIPImageProcessor
import torchvision.transforms as transforms
from torchvision.transforms import ToPILImage

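# Whisper-style log-mel extraction over fixed windows. At 16 kHz, 640 samples
# correspond to one video frame at 25 fps (16000 / 640 = 25), so one
# 750 * 640-sample window covers 750 frames, i.e. 30 seconds of audio per
# extractor call; the second return value is the total frame count.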
def get_audio_feature(feature_extractor, audio_path):
    audio_input, sampling_rate = librosa.load(audio_path, sr=16000)
    assert sampling_rate == 16000

    audio_features = []
    window = 750 * 640
    for i in range(0, len(audio_input), window):
        audio_feature = feature_extractor(
            audio_input[i:i+window],
            sampling_rate=sampling_rate,
            return_tensors="pt",
        ).input_features
        audio_features.append(audio_feature)

    audio_features = torch.cat(audio_features, dim=-1)
    return audio_features, len(audio_input) // 640

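# The meta CSV consumed below is expected to provide the columns
# videoid, image, audio, prompt, fps -- e.g. (illustrative row, not from the repo):
#   videoid,image,audio,prompt,fps
#   0001,assets/0001.png,assets/0001.wav,"a person talking",25.0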
class VideoAudioTextLoaderVal(Dataset):
    def __init__(
        self,
        image_size: int,
        meta_file: str,
        **kwargs,
    ):
        super().__init__()
        self.meta_file = meta_file
        self.image_size = image_size
        self.text_encoder = kwargs.get("text_encoder", None)        # llava_text_encoder
        self.text_encoder_2 = kwargs.get("text_encoder_2", None)    # clipL_text_encoder
        self.feature_extractor = kwargs.get("feature_extractor", None)
        self.meta_files = []

        csv_data = pd.read_csv(meta_file)
        for idx in range(len(csv_data)):
            self.meta_files.append(
                {
                    "videoid": str(csv_data["videoid"][idx]),
                    "image_path": str(csv_data["image"][idx]),
                    "audio_path": str(csv_data["audio"][idx]),
                    "prompt": str(csv_data["prompt"][idx]),
                    "fps": float(csv_data["fps"][idx]),
                }
            )

        # 336x336 input normalized with the OpenAI CLIP mean/std, as expected
        # by the LLaVA vision tower.
        self.llava_transform = transforms.Compose(
            [
                transforms.Resize((336, 336), interpolation=transforms.InterpolationMode.BILINEAR),
                transforms.ToTensor(),
                transforms.Normalize((0.48145466, 0.4578275, 0.4082107), (0.26862954, 0.26130258, 0.27577711)),
            ]
        )
        self.clip_image_processor = CLIPImageProcessor()

        self.device = torch.device("cuda")
        self.weight_dtype = torch.float16

    def __len__(self):
        return len(self.meta_files)

    @staticmethod
    def get_text_tokens(text_encoder, description, dtype_encode="video"):
        text_inputs = text_encoder.text2tokens(description, data_type=dtype_encode)
        text_ids = text_inputs["input_ids"].squeeze(0)
        text_mask = text_inputs["attention_mask"].squeeze(0)
        return text_ids, text_mask

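    # get_batch_data builds one validation sample. The text encoders passed in
    # via kwargs are assumed to expose a text2tokens(description, data_type=...)
    # method returning a dict with "input_ids" and "attention_mask"; that is
    # the interface get_text_tokens above relies on.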
    def get_batch_data(self, idx):
        meta_file = self.meta_files[idx]
        videoid = meta_file["videoid"]
        image_path = meta_file["image_path"]
        audio_path = meta_file["audio_path"]
        prompt = "Authentic, Realistic, Natural, High-quality, Lens-Fixed, " + meta_file["prompt"]
        fps = meta_file["fps"]

        img_size = self.image_size
        ref_image = Image.open(image_path).convert('RGB')

        # Resize the reference image so its short side matches img_size, with
        # both sides rounded to multiples of 64.
        w, h = ref_image.size
        scale = img_size / min(w, h)
        new_w = round(w * scale / 64) * 64
        new_h = round(h * scale / 64) * 64

        # Cap the total area for the 704 setting. The original code defined
        # img_size_long only inside this branch, so the area check raised a
        # NameError for any other image_size; the cap is therefore applied
        # only when img_size == 704.
        if img_size == 704:
            img_size_long = 1216
            if new_w * new_h > img_size * img_size_long:
                scale = math.sqrt(img_size * img_size_long / w / h)
                new_w = round(w * scale / 64) * 64
                new_h = round(h * scale / 64) * 64

        ref_image = ref_image.resize((new_w, new_h), Image.LANCZOS)
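        # Worked example of the arithmetic above: a 1920x1080 input with
        # img_size=704 first scales by 704/1080 to 1280x704 (after rounding to
        # multiples of 64); 1280*704 = 901120 exceeds 704*1216 = 856064, so the
        # area cap rescales by sqrt(856064/2073600) and yields 1216x704.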
        
        ref_image = np.array(ref_image)
        ref_image = torch.from_numpy(ref_image)

        audio_input, audio_len = get_audio_feature(self.feature_extractor, audio_path)
        audio_prompts = audio_input[0]

        motion_bucket_id_heads = np.array([25] * 4)
        motion_bucket_id_exps = np.array([30] * 4)
        motion_bucket_id_heads = torch.from_numpy(motion_bucket_id_heads)
        motion_bucket_id_exps = torch.from_numpy(motion_bucket_id_exps)
        fps = torch.from_numpy(np.array(fps))

        to_pil = ToPILImage()
        pixel_value_ref = rearrange(ref_image.clone().unsqueeze(0), "b h w c -> b c h w")   # (b c h w)

        pixel_value_ref_llava = [self.llava_transform(to_pil(image)) for image in pixel_value_ref]
        pixel_value_ref_llava = torch.stack(pixel_value_ref_llava, dim=0)
        pixel_value_ref_clip = self.clip_image_processor(
            images=Image.fromarray((pixel_value_ref[0].permute(1, 2, 0)).data.cpu().numpy().astype(np.uint8)),
            return_tensors="pt"
        ).pixel_values[0]
        pixel_value_ref_clip = pixel_value_ref_clip.unsqueeze(0)

        # Encode text prompts
        text_ids, text_mask = self.get_text_tokens(self.text_encoder, prompt)
        text_ids_2, text_mask_2 = self.get_text_tokens(self.text_encoder_2, prompt)

        # Output batch
        batch = {
            "text_prompt": prompt,
            "videoid": videoid,
            "pixel_value_ref": pixel_value_ref.to(dtype=torch.float16),                 # reference image for VAE features, (1, 3, h, w), values in [0, 255]
            "pixel_value_ref_llava": pixel_value_ref_llava.to(dtype=torch.float16),     # reference image for LLaVA features, (1, 3, 336, 336), CLIP value range
            "pixel_value_ref_clip": pixel_value_ref_clip.to(dtype=torch.float16),       # reference image for clip_image_encoder features, (1, 3, 224, 224), CLIP value range
            "audio_prompts": audio_prompts.to(dtype=torch.float16),
            "motion_bucket_id_heads": motion_bucket_id_heads.to(dtype=text_ids.dtype),
            "motion_bucket_id_exps": motion_bucket_id_exps.to(dtype=text_ids.dtype),
            "fps": fps.to(dtype=torch.float16),
            "text_ids": text_ids.clone(),                                               # from llava_text_encoder
            "text_mask": text_mask.clone(),                                             # from llava_text_encoder
            "text_ids_2": text_ids_2.clone(),                                           # from clip_text_encoder
            "text_mask_2": text_mask_2.clone(),                                         # from clip_text_encoder
            "audio_len": audio_len,
            "image_path": image_path,
            "audio_path": audio_path,
        }
        return batch

    def __getitem__(self, idx):
        return self.get_batch_data(idx)
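

# Minimal usage sketch (not part of the original file). It exercises only the
# standalone audio path; the checkpoint name and wav path are illustrative
# placeholders, and a Whisper-style feature extractor is assumed because the
# code above reads .input_features from the extractor's output.
if __name__ == "__main__":
    from transformers import WhisperFeatureExtractor

    extractor = WhisperFeatureExtractor.from_pretrained("openai/whisper-tiny")
    feats, num_frames = get_audio_feature(extractor, "sample.wav")  # hypothetical path
    print(feats.shape, num_frames)  # e.g. (1, 80, 3000) per 30 s window; frames at 25 fps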