mirror of
https://github.com/Wan-Video/Wan2.1.git
synced 2025-11-04 14:16:57 +00:00
fix pytorch version extraction
This commit is contained in:
parent
93e876f06a
commit
aef624bd5b
@ -52,9 +52,14 @@ def get_root_logger(logger_name='basicsr', log_level=logging.INFO, log_file=None
|
||||
initialized_logger[logger_name] = True
|
||||
return logger
|
||||
|
||||
|
||||
IS_HIGH_VERSION = [int(m) for m in list(re.findall(r"^([0-9]+)\.([0-9]+)\.([0-9]+)([^0-9][a-zA-Z0-9]*)?(\+git.*)?$",\
|
||||
torch.__version__)[0][:3])] >= [1, 12, 0]
|
||||
match = re.match(r"^([0-9]+)\.([0-9]+)\.([0-9]+)", torch.__version__)
|
||||
if match:
|
||||
version_tuple = match.groups()
|
||||
IS_HIGH_VERSION = [int(v) for v in version_tuple] >= [1, 12, 0]
|
||||
else:
|
||||
logger = get_root_logger()
|
||||
logger.warning(f"Could not parse torch version '{torch.__version__}'. Assuming it's not a high version >= 1.12.0.")
|
||||
IS_HIGH_VERSION = False
|
||||
|
||||
def gpu_is_available():
|
||||
if IS_HIGH_VERSION:
|
||||
|
||||
Loading…
Reference in New Issue
Block a user