# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import numpy as np
import torch
import torch.nn.functional as F
import torchvision.transforms.functional as TF
from PIL import Image


class VaceImageProcessor(object):

    def __init__(self, downsample=None, seq_len=None):
        self.downsample = downsample
        self.seq_len = seq_len

    def _pillow_convert(self, image, cvt_type='RGB'):
        if image.mode != cvt_type:
            if image.mode == 'P':
                # palette images may carry transparency; go through RGBA first
                image = image.convert(f'{cvt_type}A')
            if image.mode == f'{cvt_type}A':
                # composite the alpha channel over a white background
                bg = Image.new(
                    cvt_type,
                    size=(image.width, image.height),
                    color=(255, 255, 255))
                bg.paste(image, (0, 0), mask=image)
                image = bg
            else:
                image = image.convert(cvt_type)
        return image

    def _load_image(self, img_path):
        if img_path is None or img_path == '':
            return None
        img = Image.open(img_path)
        img = self._pillow_convert(img)
        return img

    def _resize_crop(self, img, oh, ow, normalize=True):
        """
        Resize, center crop, convert to tensor, and normalize.
        """
        # resize and crop
        iw, ih = img.size
        if iw != ow or ih != oh:
            # scale so that both sides cover the target, then center crop
            scale = max(ow / iw, oh / ih)
            img = img.resize((round(scale * iw), round(scale * ih)),
                             resample=Image.Resampling.LANCZOS)
            assert img.width >= ow and img.height >= oh

            # center crop
            x1 = (img.width - ow) // 2
            y1 = (img.height - oh) // 2
            img = img.crop((x1, y1, x1 + ow, y1 + oh))

        # normalize to [-1, 1] and add a singleton frame axis -> (C, 1, H, W)
        if normalize:
            img = TF.to_tensor(img).sub_(0.5).div_(0.5).unsqueeze(1)
        return img

    def _image_preprocess(self, img, oh, ow, normalize=True, **kwargs):
        return self._resize_crop(img, oh, ow, normalize)

    def load_image(self, data_key, **kwargs):
        return self.load_image_batch(data_key, **kwargs)

    def load_image_pair(self, data_key, data_key2, **kwargs):
        return self.load_image_batch(data_key, data_key2, **kwargs)

    def load_image_batch(self,
                         *data_key_batch,
                         normalize=True,
                         seq_len=None,
                         **kwargs):
        seq_len = self.seq_len if seq_len is None else seq_len
        imgs = []
        for data_key in data_key_batch:
            img = self._load_image(data_key)
            imgs.append(img)
        w, h = imgs[0].size
        dh, dw = self.downsample[1:]

        # compute output size: shrink (never enlarge) so that the latent
        # token count (oh // dh) * (ow // dw) stays within seq_len
        scale = min(1., np.sqrt(seq_len / ((h / dh) * (w / dw))))
        oh = int(h * scale) // dh * dh
        ow = int(w * scale) // dw * dw
        assert (oh // dh) * (ow // dw) <= seq_len
        imgs = [self._image_preprocess(img, oh, ow, normalize) for img in imgs]
        return *imgs, (oh, ow)
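

# Illustrative usage sketch (hypothetical path and parameter values, not part
# of the original module). `downsample` is the (frame, height, width) stride
# of the latent space and `seq_len` caps the latent token count:
#   proc = VaceImageProcessor(downsample=(4, 8, 8), seq_len=1024)
#   img, (oh, ow) = proc.load_image('path/to/image.jpg')
#   # img: (3, 1, oh, ow) tensor in [-1, 1]; oh and ow are multiples of 8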


class VaceVideoProcessor(object):

    def __init__(self, downsample, min_area, max_area, min_fps, max_fps,
                 zero_start, seq_len, keep_last, **kwargs):
        self.downsample = downsample
        self.min_area = min_area
        self.max_area = max_area
        self.min_fps = min_fps
        self.max_fps = max_fps
        self.zero_start = zero_start
        self.keep_last = keep_last
        self.seq_len = seq_len
        # the token budget must cover at least one frame at the minimum area
        assert seq_len >= min_area / (self.downsample[1] * self.downsample[2])

    def set_area(self, area):
        self.min_area = area
        self.max_area = area

    def set_seq_len(self, seq_len):
        self.seq_len = seq_len

    @staticmethod
    def resize_crop(video: torch.Tensor, oh: int, ow: int):
        """
        Resize, center crop and normalize a decord-loaded video (torch.Tensor).

        Parameters:
            video - video to process (torch.Tensor): tensor from
                `reader.get_batch(frame_ids)`, in shape of (T, H, W, C)
            oh - target height (int)
            ow - target width (int)

        Returns:
            The processed video (torch.Tensor): normalized to range [-1, 1],
            in shape of (C, T, H, W)
        """
        # permute ([t, h, w, c] -> [t, c, h, w])
        video = video.permute(0, 3, 1, 2)

        # resize and crop
        ih, iw = video.shape[2:]
        if ih != oh or iw != ow:
            # scale so that both sides cover the target, then center crop
            scale = max(ow / iw, oh / ih)
            video = F.interpolate(
                video,
                size=(round(scale * ih), round(scale * iw)),
                mode='bicubic',
                antialias=True)
            assert video.size(3) >= ow and video.size(2) >= oh

            # center crop
            x1 = (video.size(3) - ow) // 2
            y1 = (video.size(2) - oh) // 2
            video = video[:, :, y1:y1 + oh, x1:x1 + ow]

        # permute ([t, c, h, w] -> [c, t, h, w]) and normalize to [-1, 1]
        video = video.transpose(0, 1).float().div_(127.5).sub_(1.)
        return video
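
    # Illustrative shape check for `resize_crop` (dummy data, not part of the
    # original module; frames are cast to float here, whereas decord yields
    # uint8):
    #   frames = torch.randint(0, 256, (16, 480, 640, 3)).float()
    #   out = VaceVideoProcessor.resize_crop(frames, oh=240, ow=320)
    #   # out: (3, 16, 240, 320) float tensor in [-1, 1]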

    def _video_preprocess(self, video, oh, ow):
        return self.resize_crop(video, oh, ow)

    def _get_frameid_bbox_default(self, fps, frame_timestamps, h, w, crop_box,
                                  rng):
        target_fps = min(fps, self.max_fps)
        duration = frame_timestamps[-1].mean()
        x1, x2, y1, y2 = [0, w, 0, h] if crop_box is None else crop_box
        h, w = y2 - y1, x2 - x1
        ratio = h / w
        df, dh, dw = self.downsample

        # per-frame latent area and number of latent frames, both capped so
        # that the total token count stays within seq_len
        area_z = min(self.seq_len, self.max_area / (dh * dw),
                     (h // dh) * (w // dw))
        of = min((int(duration * target_fps) - 1) // df + 1,
                 int(self.seq_len / area_z))

        # deduce target shape of the [latent video]
        target_area_z = min(area_z, int(self.seq_len / of))
        oh = round(np.sqrt(target_area_z * ratio))
        ow = int(target_area_z / oh)
        of = (of - 1) * df + 1
        oh *= dh
        ow *= dw

        # sample frame ids
        target_duration = of / target_fps
        begin = 0. if self.zero_start else rng.uniform(
            0, duration - target_duration)
        timestamps = np.linspace(begin, begin + target_duration, of)
        # map each sampled timestamp to the frame whose [start, end) interval
        # contains it
        frame_ids = np.argmax(
            np.logical_and(timestamps[:, None] >= frame_timestamps[None, :, 0],
                           timestamps[:, None] < frame_timestamps[None, :, 1]),
            axis=1).tolist()
        return frame_ids, (x1, x2, y1, y2), (oh, ow), target_fps

    def _get_frameid_bbox_adjust_last(self, fps, frame_timestamps, h, w,
                                      crop_box, rng):
        duration = frame_timestamps[-1].mean()
        x1, x2, y1, y2 = [0, w, 0, h] if crop_box is None else crop_box
        h, w = y2 - y1, x2 - x1
        ratio = h / w
        df, dh, dw = self.downsample

        area_z = min(self.seq_len, self.max_area / (dh * dw),
                     (h // dh) * (w // dw))
        of = min((len(frame_timestamps) - 1) // df + 1,
                 int(self.seq_len / area_z))

        # deduce target shape of the [latent video]
        target_area_z = min(area_z, int(self.seq_len / of))
        oh = round(np.sqrt(target_area_z * ratio))
        ow = int(target_area_z / oh)
        of = (of - 1) * df + 1
        oh *= dh
        ow *= dw

        # sample frame ids spanning the full duration so the last frame is
        # kept; the fps is adjusted to match instead
        target_duration = duration
        target_fps = of / target_duration
        timestamps = np.linspace(0., target_duration, of)
        frame_ids = np.argmax(
            np.logical_and(timestamps[:, None] >= frame_timestamps[None, :, 0],
                           timestamps[:, None] <= frame_timestamps[None, :, 1]),
            axis=1).tolist()
        return frame_ids, (x1, x2, y1, y2), (oh, ow), target_fps

    def _get_frameid_bbox(self, fps, frame_timestamps, h, w, crop_box, rng):
        if self.keep_last:
            return self._get_frameid_bbox_adjust_last(fps, frame_timestamps, h,
                                                      w, crop_box, rng)
        else:
            return self._get_frameid_bbox_default(fps, frame_timestamps, h, w,
                                                  crop_box, rng)

    def load_video(self, data_key, crop_box=None, seed=2024, **kwargs):
        return self.load_video_batch(
            data_key, crop_box=crop_box, seed=seed, **kwargs)

    def load_video_pair(self,
                        data_key,
                        data_key2,
                        crop_box=None,
                        seed=2024,
                        **kwargs):
        return self.load_video_batch(
            data_key, data_key2, crop_box=crop_box, seed=seed, **kwargs)

    def load_video_batch(self,
                         *data_key_batch,
                         crop_box=None,
                         seed=2024,
                         **kwargs):
        rng = np.random.default_rng(seed + hash(data_key_batch[0]) % 10000)
        # read video
        import decord
        decord.bridge.set_bridge('torch')
        readers = []
        for data_k in data_key_batch:
            reader = decord.VideoReader(data_k)
            readers.append(reader)

        fps = readers[0].get_avg_fps()
        length = min([len(r) for r in readers])
        frame_timestamps = [
            readers[0].get_frame_timestamp(i) for i in range(length)
        ]
        frame_timestamps = np.array(frame_timestamps, dtype=np.float32)
        h, w = readers[0].next().shape[:2]
        frame_ids, (x1, x2, y1, y2), (oh, ow), fps = self._get_frameid_bbox(
            fps, frame_timestamps, h, w, crop_box, rng)

        # preprocess video: crop first, then resize and normalize
        videos = [
            reader.get_batch(frame_ids)[:, y1:y2, x1:x2, :]
            for reader in readers
        ]
        videos = [self._video_preprocess(video, oh, ow) for video in videos]
        return *videos, frame_ids, (oh, ow), fps
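

# Illustrative usage sketch (hypothetical parameter values, not part of the
# original module); reading requires `decord` and a real video path:
#   proc = VaceVideoProcessor(
#       downsample=(4, 8, 8), min_area=480 * 832, max_area=480 * 832,
#       min_fps=16, max_fps=16, zero_start=True, seq_len=32760, keep_last=True)
#   video, frame_ids, (oh, ow), fps = proc.load_video('path/to/video.mp4')
#   # video: (3, T, oh, ow) tensor in [-1, 1]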


def prepare_source(src_video, src_mask, src_ref_images, num_frames, image_size,
                   device):
    # fill missing video/mask slots with a blank clip and an all-ones mask
    for i, (sub_src_video, sub_src_mask) in enumerate(zip(src_video, src_mask)):
        if sub_src_video is None and sub_src_mask is None:
            src_video[i] = torch.zeros(
                (3, num_frames, image_size[0], image_size[1]), device=device)
            src_mask[i] = torch.ones(
                (1, num_frames, image_size[0], image_size[1]), device=device)
    # letterbox each reference image onto a white canvas at the target size
    for i, ref_images in enumerate(src_ref_images):
        if ref_images is not None:
            for j, ref_img in enumerate(ref_images):
                if ref_img is not None and ref_img.shape[-2:] != image_size:
                    canvas_height, canvas_width = image_size
                    ref_height, ref_width = ref_img.shape[-2:]
                    white_canvas = torch.ones(
                        (3, 1, canvas_height, canvas_width),
                        device=device)  # [-1, 1]
                    scale = min(canvas_height / ref_height,
                                canvas_width / ref_width)
                    new_height = int(ref_height * scale)
                    new_width = int(ref_width * scale)
                    resized_image = F.interpolate(
                        ref_img.squeeze(1).unsqueeze(0),
                        size=(new_height, new_width),
                        mode='bilinear',
                        align_corners=False).squeeze(0).unsqueeze(1)
                    top = (canvas_height - new_height) // 2
                    left = (canvas_width - new_width) // 2
                    white_canvas[:, :, top:top + new_height,
                                 left:left + new_width] = resized_image
                    src_ref_images[i][j] = white_canvas
    return src_video, src_mask, src_ref_images
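

if __name__ == '__main__':
    # Minimal smoke test for prepare_source (illustrative only; the shapes
    # and values below are hypothetical, not part of the original module).
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    num_frames, image_size = 17, (480, 832)
    src_video, src_mask = [None], [None]
    # one reference image at a non-target resolution, already in [-1, 1]
    src_ref_images = [[torch.zeros((3, 1, 256, 256), device=device)]]
    src_video, src_mask, src_ref_images = prepare_source(
        src_video, src_mask, src_ref_images, num_frames, image_size, device)
    # expected: (3, 17, 480, 832), (1, 17, 480, 832), (3, 1, 480, 832)
    print(src_video[0].shape, src_mask[0].shape, src_ref_images[0][0].shape)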