mirror of
				https://github.com/Wan-Video/Wan2.1.git
				synced 2025-11-04 06:15:17 +00:00 
			
		
		
		
	fix pytorch version extraction
This commit is contained in:
		
							parent
							
								
									93e876f06a
								
							
						
					
					
						commit
						aef624bd5b
					
				@ -52,9 +52,14 @@ def get_root_logger(logger_name='basicsr', log_level=logging.INFO, log_file=None
 | 
			
		||||
    initialized_logger[logger_name] = True
 | 
			
		||||
    return logger
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
IS_HIGH_VERSION = [int(m) for m in list(re.findall(r"^([0-9]+)\.([0-9]+)\.([0-9]+)([^0-9][a-zA-Z0-9]*)?(\+git.*)?$",\
 | 
			
		||||
    torch.__version__)[0][:3])] >= [1, 12, 0]
 | 
			
		||||
match = re.match(r"^([0-9]+)\.([0-9]+)\.([0-9]+)", torch.__version__)
 | 
			
		||||
if match:
 | 
			
		||||
    version_tuple = match.groups()
 | 
			
		||||
    IS_HIGH_VERSION = [int(v) for v in version_tuple] >= [1, 12, 0]
 | 
			
		||||
else:
 | 
			
		||||
    logger = get_root_logger()
 | 
			
		||||
    logger.warning(f"Could not parse torch version '{torch.__version__}'. Assuming it's not a high version >= 1.12.0.")
 | 
			
		||||
    IS_HIGH_VERSION = False
 | 
			
		||||
 | 
			
		||||
def gpu_is_available():
 | 
			
		||||
    if IS_HIGH_VERSION:
 | 
			
		||||
 | 
			
		||||
		Loading…
	
		Reference in New Issue
	
	Block a user