fix pytorch version extraction

This commit is contained in:
Chris Malone 2025-04-14 04:53:21 +10:00
parent 93e876f06a
commit aef624bd5b

View File

@ -52,9 +52,14 @@ def get_root_logger(logger_name='basicsr', log_level=logging.INFO, log_file=None
initialized_logger[logger_name] = True
return logger
IS_HIGH_VERSION = [int(m) for m in list(re.findall(r"^([0-9]+)\.([0-9]+)\.([0-9]+)([^0-9][a-zA-Z0-9]*)?(\+git.*)?$",\
torch.__version__)[0][:3])] >= [1, 12, 0]
match = re.match(r"^([0-9]+)\.([0-9]+)\.([0-9]+)", torch.__version__)
if match:
version_tuple = match.groups()
IS_HIGH_VERSION = [int(v) for v in version_tuple] >= [1, 12, 0]
else:
logger = get_root_logger()
logger.warning(f"Could not parse torch version '{torch.__version__}'. Assuming it's not a high version >= 1.12.0.")
IS_HIGH_VERSION = False
def gpu_is_available():
if IS_HIGH_VERSION: