mirror of
				https://github.com/Wan-Video/Wan2.1.git
				synced 2025-11-04 06:15:17 +00:00 
			
		
		
		
	
		
			
				
	
	
		
			136 lines
		
	
	
		
			4.7 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			136 lines
		
	
	
		
			4.7 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
import os
 | 
						|
import re
 | 
						|
import random
 | 
						|
import time
 | 
						|
import torch
 | 
						|
import torch.nn as nn
 | 
						|
import logging
 | 
						|
import numpy as np
 | 
						|
from os import path as osp
 | 
						|
 | 
						|
def constant_init(module, val, bias=0):
    """Fill a module's ``weight`` with ``val`` and its ``bias`` with ``bias``.

    Each attribute is left alone when the module does not have it or when it
    is ``None``; the weight is always handled before the bias.
    """
    for attr_name, fill_value in (('weight', val), ('bias', bias)):
        param = getattr(module, attr_name, None)
        if param is not None:
            nn.init.constant_(param, fill_value)
 | 
						|
 | 
						|
# Names of loggers already configured by get_root_logger (value is always True).
initialized_logger = {}


def get_root_logger(logger_name='basicsr', log_level=logging.INFO, log_file=None):
    """Get the root logger.

    The logger is configured on the first call for a given ``logger_name``.
    Later calls with the same name return the already-configured logger
    unchanged; their ``log_level`` and ``log_file`` arguments are ignored.
    A StreamHandler is always attached. If ``log_file`` is specified, a
    FileHandler that appends to that file is attached as well.

    Args:
        logger_name (str): root logger name. Default: 'basicsr'.
        log_level (int): The logger level. It is applied to the logger
            whether or not ``log_file`` is given, and to the FileHandler when
            one is created. Default: logging.INFO.
        log_file (str | None): The log filename. If specified, a FileHandler
            will be added to the root logger.

    Returns:
        logging.Logger: The root logger.
    """
    logger = logging.getLogger(logger_name)
    # if the logger has been initialized, just return it
    if logger_name in initialized_logger:
        return logger

    format_str = '%(asctime)s %(levelname)s: %(message)s'
    stream_handler = logging.StreamHandler()
    stream_handler.setFormatter(logging.Formatter(format_str))
    logger.addHandler(stream_handler)
    # Keep records from also reaching ancestor loggers' handlers.
    logger.propagate = False
    # Apply the requested level unconditionally. Setting it only when a
    # log_file was given left the logger at NOTSET, so log_level was ignored.
    logger.setLevel(log_level)

    if log_file is not None:
        # Append mode ('a') keeps the log of previous runs.
        file_handler = logging.FileHandler(log_file, 'a')
        file_handler.setFormatter(logging.Formatter(format_str))
        file_handler.setLevel(log_level)
        logger.addHandler(file_handler)
    initialized_logger[logger_name] = True
    return logger
 | 
						|
 | 
						|
# IS_HIGH_VERSION is True when the installed torch is >= 1.12.0. The device
# helpers below only consult torch.backends.mps when it is True (presumably
# because older torch releases lack that backend — confirm if bumping).
match = re.match(r"^([0-9]+)\.([0-9]+)\.([0-9]+)", torch.__version__)
if not match:
    # Unparseable version string: warn and fall back to the conservative path.
    logger = get_root_logger()
    logger.warning(f"Could not parse torch version '{torch.__version__}'. Assuming it's not a high version >= 1.12.0.")
    IS_HIGH_VERSION = False
else:
    version_tuple = match.groups()
    IS_HIGH_VERSION = tuple(map(int, version_tuple)) >= (1, 12, 0)
 | 
						|
 | 
						|
def gpu_is_available():
    """Report whether a GPU backend usable by this module is present.

    When ``IS_HIGH_VERSION`` is set and the MPS backend is available, this
    returns True. Otherwise CUDA counts only if cuDNN is also available.
    """
    if IS_HIGH_VERSION and torch.backends.mps.is_available():
        return True
    cuda_ok = torch.cuda.is_available() and torch.backends.cudnn.is_available()
    return bool(cuda_ok)
 | 
						|
 | 
						|
def get_device(gpu_id=None):
    """Pick the torch device to run on.

    Args:
        gpu_id (int | None): Optional device index, appended as ``:<id>``
            to the device string.

    Returns:
        torch.device: The first match from this order:
            ``mps[:id]`` when ``IS_HIGH_VERSION`` is set and MPS is available;
            ``cuda[:id]`` when both CUDA and cuDNN are available;
            otherwise ``cpu``, which never gets an index.

    Raises:
        TypeError: If ``gpu_id`` is neither ``None`` nor an ``int``.
    """
    if gpu_id is not None and not isinstance(gpu_id, int):
        raise TypeError('Input should be int value.')
    suffix = '' if gpu_id is None else f':{gpu_id}'

    if IS_HIGH_VERSION and torch.backends.mps.is_available():
        return torch.device('mps' + suffix)
    if torch.cuda.is_available() and torch.backends.cudnn.is_available():
        return torch.device('cuda' + suffix)
    return torch.device('cpu')
 | 
						|
 | 
						|
 | 
						|
def set_random_seed(seed):
    """Seed Python ``random``, NumPy's global RNG and torch with ``seed``.

    The torch seeding goes through ``torch.manual_seed``,
    ``torch.cuda.manual_seed`` and ``torch.cuda.manual_seed_all``, in that order.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seeder in seeders:
        seeder(seed)
 | 
						|
 | 
						|
 | 
						|
def get_time_str():
    """Return the current local time as a ``YYYYmmdd_HHMMSS`` string."""
    fmt = '%Y%m%d_%H%M%S'
    now = time.localtime()
    return time.strftime(fmt, now)
 | 
						|
 | 
						|
 | 
						|
def scandir(dir_path, suffix=None, recursive=False, full_path=False):
    """Scan a directory to find the interested files.

    Args:
        dir_path (str): Path of the directory.
        suffix (str | tuple(str), optional): File suffix that we are
            interested in. Default: None.
        recursive (bool, optional): If set to True, recursively scan the
            directory. Default: False.
        full_path (bool, optional): If set to True, include the dir_path.
            Default: False.

    Returns:
        A generator for all the interested files. Paths are relative to
        ``dir_path`` unless ``full_path`` is True. Files whose name starts
        with '.' are skipped. With ``recursive``, every sub-directory is
        descended into, including hidden ones.

    Raises:
        TypeError: If ``suffix`` is not None, a str, or a tuple. This is
            raised immediately on call, not on first iteration.
    """
    if (suffix is not None) and not isinstance(suffix, (str, tuple)):
        raise TypeError('"suffix" must be a string or tuple of strings')

    root = dir_path

    def _scandir(dir_path, suffix, recursive):
        # Use the iterator as a context manager so its directory handle is
        # released even if the consumer stops iterating early.
        with os.scandir(dir_path) as entries:
            for entry in entries:
                if not entry.name.startswith('.') and entry.is_file():
                    if full_path:
                        return_path = entry.path
                    else:
                        return_path = osp.relpath(entry.path, root)

                    if suffix is None or return_path.endswith(suffix):
                        yield return_path
                elif recursive and entry.is_dir():
                    # Only descend into directories. Hidden files and other
                    # non-directory entries used to reach this branch too,
                    # and os.scandir then raised NotADirectoryError on them.
                    yield from _scandir(entry.path, suffix=suffix, recursive=recursive)

    return _scandir(dir_path, suffix=suffix, recursive=recursive)