mirror of
https://github.com/Wan-Video/Wan2.1.git
synced 2025-11-04 14:16:57 +00:00
fix pytorch version extraction
This commit is contained in:
parent
93e876f06a
commit
aef624bd5b
@ -52,9 +52,14 @@ def get_root_logger(logger_name='basicsr', log_level=logging.INFO, log_file=None
|
|||||||
initialized_logger[logger_name] = True
|
initialized_logger[logger_name] = True
|
||||||
return logger
|
return logger
|
||||||
|
|
||||||
|
match = re.match(r"^([0-9]+)\.([0-9]+)\.([0-9]+)", torch.__version__)
|
||||||
IS_HIGH_VERSION = [int(m) for m in list(re.findall(r"^([0-9]+)\.([0-9]+)\.([0-9]+)([^0-9][a-zA-Z0-9]*)?(\+git.*)?$",\
|
if match:
|
||||||
torch.__version__)[0][:3])] >= [1, 12, 0]
|
version_tuple = match.groups()
|
||||||
|
IS_HIGH_VERSION = [int(v) for v in version_tuple] >= [1, 12, 0]
|
||||||
|
else:
|
||||||
|
logger = get_root_logger()
|
||||||
|
logger.warning(f"Could not parse torch version '{torch.__version__}'. Assuming it's not a high version >= 1.12.0.")
|
||||||
|
IS_HIGH_VERSION = False
|
||||||
|
|
||||||
def gpu_is_available():
|
def gpu_is_available():
|
||||||
if IS_HIGH_VERSION:
|
if IS_HIGH_VERSION:
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user