# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import argparse
import binascii
import math
import os
import os.path as osp

import decord
import imageio
import numpy as np
import torch
import torch.nn.functional as F
import torchvision
import torchvision.transforms.functional as TF
from PIL import Image
from rembg import new_session, remove

__all__ = ['cache_video', 'cache_image', 'str2bool']


def resample(video_fps, video_frames_count, max_frames, target_fps):
    """Return the indices of frames to keep when resampling a video from
    `video_fps` to `target_fps`. `max_frames` caps the result (0 means no cap)."""
    video_frame_duration = 1 / video_fps
    target_frame_duration = 1 / target_fps

    cur_time = 0
    target_time = 0
    frame_no = 0
    frame_ids = []
    while True:
        if max_frames != 0 and len(frame_ids) >= max_frames:
            break
        add_frames_count = math.ceil((target_time - cur_time) / video_frame_duration)
        frame_no += add_frames_count
        if frame_no >= video_frames_count:
            break
        frame_ids.append(frame_no)
        cur_time += add_frames_count * video_frame_duration
        target_time += target_frame_duration
    return frame_ids


def get_video_frame(file_name, frame_no):
    """Read a single frame from a video file and return it as a PIL image."""
    decord.bridge.set_bridge('torch')
    reader = decord.VideoReader(file_name)
    frame = reader.get_batch([frame_no]).squeeze(0)
    img = Image.fromarray(frame.numpy().astype(np.uint8))
    return img


def resize_lanczos(img, h, w):
    """Resize a CHW float tensor in [0, 1] to (h, w) using Lanczos filtering."""
    img = Image.fromarray(
        np.clip(255. * img.movedim(0, -1).cpu().numpy(), 0, 255).astype(np.uint8))
    img = img.resize((w, h), resample=Image.Resampling.LANCZOS)
    return torch.from_numpy(np.array(img).astype(np.float32) / 255.0).movedim(-1, 0)


def remove_background(img, session=None):
    """Remove the background of a CHW float tensor in [0, 1] with rembg and
    return the result as an RGB tensor of the same layout."""
    if session is None:
        session = new_session()
    img = Image.fromarray(
        np.clip(255. * img.movedim(0, -1).cpu().numpy(), 0, 255).astype(np.uint8))
    img = remove(
        img, session=session, alpha_matting=True,
        bgcolor=[255, 255, 255, 0]).convert('RGB')
    return torch.from_numpy(np.array(img).astype(np.float32) / 255.0).movedim(-1, 0)


def resize_and_remove_background(img_list, canvas_width, canvas_height, rm_background):
    """Letterbox each PIL image onto a white canvas of the requested size,
    optionally removing its background first."""
    if rm_background:
        session = new_session()

    output_list = []
    for img in img_list:
        width, height = img.size
        white_canvas = np.full((canvas_height, canvas_width, 3), 255, dtype=np.uint8)
        # Scale to fit inside the canvas while preserving the aspect ratio.
        scale = min(canvas_height / height, canvas_width / width)
        new_height = int(height * scale)
        new_width = int(width * scale)
        resized_image = img.resize(
            (new_width, new_height), resample=Image.Resampling.LANCZOS)
        if rm_background:
            resized_image = remove(
                resized_image, session=session, alpha_matting=True,
                bgcolor=[255, 255, 255, 0]).convert('RGB')
        # Paste the resized image centered on the white canvas.
        top = (canvas_height - new_height) // 2
        left = (canvas_width - new_width) // 2
        white_canvas[top:top + new_height, left:left + new_width, :] = np.array(resized_image)
        img = Image.fromarray(white_canvas)
        output_list.append(img)
    return output_list


def rand_name(length=8, suffix=''):
    """Generate a random hexadecimal file name, optionally with a suffix."""
    name = binascii.b2a_hex(os.urandom(length)).decode('utf-8')
    if suffix:
        if not suffix.startswith('.'):
            suffix = '.' + suffix
        name += suffix
    return name


def cache_video(tensor,
                save_file=None,
                fps=30,
                suffix='.mp4',
                nrow=8,
                normalize=True,
                value_range=(-1, 1),
                retry=5):
    # cache file
    cache_file = osp.join(
        '/tmp', rand_name(suffix=suffix)) if save_file is None else save_file

    # save to cache
    error = None
    for _ in range(retry):
        try:
            # preprocess
            tensor = tensor.clamp(min(value_range), max(value_range))
            tensor = torch.stack([
                torchvision.utils.make_grid(
                    u, nrow=nrow, normalize=normalize, value_range=value_range)
                for u in tensor.unbind(2)
            ],
                                 dim=1).permute(1, 2, 3, 0)
            tensor = (tensor * 255).type(torch.uint8).cpu()

            # write video
            writer = imageio.get_writer(
                cache_file, fps=fps, codec='libx264', quality=8)
            for frame in tensor.numpy():
                writer.append_data(frame)
            writer.close()
            return cache_file
        except Exception as e:
            error = e
            continue
    else:
        print(f'cache_video failed, error: {error}', flush=True)
        return None


def cache_image(tensor,
                save_file,
                nrow=8,
                normalize=True,
                value_range=(-1, 1),
                retry=5):
    # cache file
    suffix = osp.splitext(save_file)[1]
    if suffix.lower() not in ['.jpg', '.jpeg', '.png', '.tiff', '.gif', '.webp']:
        suffix = '.png'

    # save to cache
    error = None
    for _ in range(retry):
        try:
            tensor = tensor.clamp(min(value_range), max(value_range))
            torchvision.utils.save_image(
                tensor,
                save_file,
                nrow=nrow,
                normalize=normalize,
                value_range=value_range)
            return save_file
        except Exception as e:
            error = e
            continue
    else:
        print(f'cache_image failed, error: {error}', flush=True)
        return None


def str2bool(v):
    """
    Convert a string to a boolean.

    Supported true values: 'yes', 'true', 't', 'y', '1'
    Supported false values: 'no', 'false', 'f', 'n', '0'

    Args:
        v (str): String to convert.

    Returns:
        bool: Converted boolean value.

    Raises:
        argparse.ArgumentTypeError: If the value cannot be converted to boolean.
    """
    if isinstance(v, bool):
        return v
    v_lower = v.lower()
    if v_lower in ('yes', 'true', 't', 'y', '1'):
        return True
    elif v_lower in ('no', 'false', 'f', 'n', '0'):
        return False
    else:
        raise argparse.ArgumentTypeError('Boolean value expected (True/False)')
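# Minimal usage sketch (illustrative, not part of the original module): it wires
# resample(), get_video_frame() and cache_image() together. The video path and
# output file name below are assumptions for demonstration only.
if __name__ == '__main__':
    demo_video = 'sample.mp4'  # hypothetical input file; replace with a real path
    reader = decord.VideoReader(demo_video)

    # Pick frame indices so the clip is read back at roughly 8 fps, capped at 16 frames.
    frame_ids = resample(
        video_fps=reader.get_avg_fps(),
        video_frames_count=len(reader),
        max_frames=16,
        target_fps=8)
    print(f'sampled {len(frame_ids)} frames: {frame_ids}')

    # Fetch the first sampled frame, map it to a CHW tensor in [-1, 1]
    # (the default value_range of cache_image), and write it to disk.
    img = get_video_frame(demo_video, frame_ids[0])
    tensor = TF.to_tensor(img).sub_(0.5).div_(0.5)
    cache_image(tensor, save_file='frame0.png')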