From aef624bd5bbf8655c2fcb680100b607f83012d5d Mon Sep 17 00:00:00 2001 From: Chris Malone Date: Mon, 14 Apr 2025 04:53:21 +1000 Subject: [PATCH] fix pytorch version extraction --- preprocessing/matanyone/tools/misc.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/preprocessing/matanyone/tools/misc.py b/preprocessing/matanyone/tools/misc.py index 43b8499..868639c 100644 --- a/preprocessing/matanyone/tools/misc.py +++ b/preprocessing/matanyone/tools/misc.py @@ -52,9 +52,14 @@ def get_root_logger(logger_name='basicsr', log_level=logging.INFO, log_file=None initialized_logger[logger_name] = True return logger - -IS_HIGH_VERSION = [int(m) for m in list(re.findall(r"^([0-9]+)\.([0-9]+)\.([0-9]+)([^0-9][a-zA-Z0-9]*)?(\+git.*)?$",\ - torch.__version__)[0][:3])] >= [1, 12, 0] +match = re.match(r"^([0-9]+)\.([0-9]+)\.([0-9]+)", torch.__version__) +if match: + version_tuple = match.groups() + IS_HIGH_VERSION = [int(v) for v in version_tuple] >= [1, 12, 0] +else: + logger = get_root_logger() + logger.warning(f"Could not parse torch version '{torch.__version__}'. Assuming it's not a high version >= 1.12.0.") + IS_HIGH_VERSION = False def gpu_is_available(): if IS_HIGH_VERSION: