mirror of
				https://github.com/Wan-Video/Wan2.1.git
				synced 2025-11-03 22:04:21 +00:00 
			
		
		
		
	
		
			
				
	
	
		
			119 lines
		
	
	
		
			3.2 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			119 lines
		
	
	
		
			3.2 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
 | 
						|
import argparse
 | 
						|
import binascii
 | 
						|
import os
 | 
						|
import os.path as osp
 | 
						|
 | 
						|
import imageio
 | 
						|
import torch
 | 
						|
import torchvision
 | 
						|
 | 
						|
# Names exported by `from <module> import *`; rand_name is defined in this
# module but intentionally left out of the exported list.
__all__ = ['cache_video', 'cache_image', 'str2bool']
 | 
						|
 | 
						|
 | 
						|
def rand_name(length=8, suffix=''):
    """Generate a random hexadecimal file name.

    Args:
        length (int): Number of random bytes to draw from ``os.urandom``;
            the hex stem is ``2 * length`` characters long.
        suffix (str): Optional file extension. A leading ``'.'`` is added
            when it is missing. An empty value yields the bare stem.

    Returns:
        str: The lowercase hex stem followed by the normalized suffix.
    """
    # bytes.hex() produces the same lowercase hex text as
    # binascii.b2a_hex(...).decode('utf-8') without the bytes round-trip.
    name = os.urandom(length).hex()
    if suffix:
        if not suffix.startswith('.'):
            suffix = '.' + suffix
        name += suffix
    return name
 | 
						|
 | 
						|
 | 
						|
def cache_video(tensor,
                save_file=None,
                fps=30,
                suffix='.mp4',
                nrow=8,
                normalize=True,
                value_range=(-1, 1),
                retry=5):
    """Write a batch of videos to disk as a single grid video.

    Every slice of ``tensor`` along dimension 2 is tiled into one image with
    ``torchvision.utils.make_grid``; the resulting images are written in
    order as the frames of a libx264-encoded video.

    Args:
        tensor (torch.Tensor): Video batch. Dimension 2 is iterated as the
            frame axis (presumably a (B, C, T, H, W) layout -- confirm
            against callers).
        save_file (str | None): Output path. When ``None`` a random file
            name under ``/tmp`` with ``suffix`` is used.
        fps (int): Frame rate passed to the imageio writer.
        suffix (str): Extension for the generated name; only used when
            ``save_file`` is ``None``.
        nrow (int): ``nrow`` argument forwarded to ``make_grid``.
        normalize (bool): ``normalize`` argument forwarded to ``make_grid``.
        value_range (tuple): Values are clamped to
            ``[min(value_range), max(value_range)]`` and the pair is also
            forwarded to ``make_grid``.
        retry (int): Maximum number of attempts.

    Returns:
        str | None: The path that was written, or ``None`` when every attempt
        failed (the last error is printed).
    """
    # cache file
    cache_file = osp.join('/tmp', rand_name(
        suffix=suffix)) if save_file is None else save_file

    # save to cache
    frames = None
    error = None
    for _ in range(retry):
        try:
            # Preprocess once into a separate local. The input ``tensor`` is
            # never overwritten, so a retry after a failed write does not
            # re-process an already stacked/permuted uint8 tensor.
            if frames is None:
                clamped = tensor.clamp(min(value_range), max(value_range))
                grids = torch.stack([
                    torchvision.utils.make_grid(
                        u,
                        nrow=nrow,
                        normalize=normalize,
                        value_range=value_range)
                    for u in clamped.unbind(2)
                ],
                                    dim=1).permute(1, 2, 3, 0)
                frames = (grids * 255).type(torch.uint8).cpu().numpy()

            # write video
            writer = imageio.get_writer(
                cache_file, fps=fps, codec='libx264', quality=8)
            try:
                for frame in frames:
                    writer.append_data(frame)
            finally:
                # Always release the writer, even when a frame fails.
                writer.close()
            return cache_file
        except Exception as e:
            # Best-effort by design: remember the failure and try again.
            error = e

    print(f'cache_video failed, error: {error}', flush=True)
    return None
 | 
						|
 | 
						|
 | 
						|
def cache_image(tensor,
                save_file,
                nrow=8,
                normalize=True,
                value_range=(-1, 1),
                retry=5):
    """Save a batch of images to ``save_file`` as a single grid image.

    Args:
        tensor (torch.Tensor): Image batch handed to
            ``torchvision.utils.save_image`` after clamping.
        save_file (str): Output path. No fallback format is applied here; the
            path is passed to ``save_image`` unchanged.
        nrow (int): ``nrow`` argument forwarded to ``save_image``.
        normalize (bool): ``normalize`` argument forwarded to ``save_image``.
        value_range (tuple): Values are clamped to
            ``[min(value_range), max(value_range)]`` and the pair is also
            forwarded to ``save_image``.
        retry (int): Maximum number of attempts.

    Returns:
        str | None: ``save_file`` on success, or ``None`` when every attempt
        failed (the last error is printed).
    """
    # save to cache
    error = None
    for _ in range(retry):
        try:
            clamped = tensor.clamp(min(value_range), max(value_range))
            torchvision.utils.save_image(
                clamped,
                save_file,
                nrow=nrow,
                normalize=normalize,
                value_range=value_range)
            return save_file
        except Exception as e:
            # Best-effort by design: remember the failure and try again.
            error = e

    # Report the failure the same way cache_video does, rather than silently
    # falling off the end of the function.
    print(f'cache_image failed, error: {error}', flush=True)
    return None
 | 
						|
 | 
						|
 | 
						|
def str2bool(v):
    """
    Convert a string to a boolean.

    Supported true values: 'yes', 'true', 't', 'y', '1'
    Supported false values: 'no', 'false', 'f', 'n', '0'
    Matching is case-insensitive. A value that is already a ``bool`` is
    returned unchanged.

    Args:
        v (str | bool): Value to convert.

    Returns:
        bool: Converted boolean value.

    Raises:
        argparse.ArgumentTypeError: If the value cannot be converted to
            boolean, including values that are neither ``str`` nor ``bool``.
    """
    if isinstance(v, bool):
        return v
    # Anything without string semantics (None, numbers, ...) gets the
    # documented error instead of an AttributeError from ``.lower()``.
    if not isinstance(v, str):
        raise argparse.ArgumentTypeError('Boolean value expected (True/False)')
    v_lower = v.lower()
    if v_lower in ('yes', 'true', 't', 'y', '1'):
        return True
    if v_lower in ('no', 'false', 'f', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError('Boolean value expected (True/False)')
 |