mirror of
https://github.com/Wan-Video/Wan2.1.git
synced 2025-12-20 14:12:04 +00:00
35 lines
884 B
Python
35 lines
884 B
Python
from typing import Optional
|
|
|
|
import torch
|
|
|
|
try:
|
|
import torch_musa
|
|
except ModuleNotFoundError:
|
|
pass
|
|
|
|
|
|
def _is_musa() -> bool:
    """Report whether a MUSA device is usable through ``torch.musa``.

    ``torch.musa`` is presumably registered by the optional ``torch_musa``
    extension imported (best-effort) at the top of this module. When that
    attribute is missing, looking it up raises ``AttributeError``, which the
    previous ``except ModuleNotFoundError`` did not catch. Both cases are now
    treated as "no MUSA".

    Returns:
        True if ``torch.musa.is_available()`` reports a device, otherwise False
        (always a real bool, never ``None``).
    """
    musa = getattr(torch, "musa", None)
    if musa is None:
        # torch_musa not installed/imported: no MUSA backend to query.
        return False
    try:
        return bool(musa.is_available())
    except ModuleNotFoundError:
        # Kept from the original: the backend may fail while loading its parts.
        return False
|
|
|
|
|
|
def get_device(local_rank: Optional[int] = None) -> torch.device:
    """Select the torch device this process should use.

    Preference order: CUDA (``torch.cuda.is_available()``), then MUSA
    (``_is_musa()``), then CPU.

    Args:
        local_rank: Accelerator index to use. If ``None``, the backend's
            current device index (``current_device()``) is used. Ignored
            when falling back to CPU.

    Returns:
        A ``torch.device`` that always carries its backend type.
    """
    if torch.cuda.is_available():
        # current_device() returns a bare int; wrap it so the result matches the
        # declared return type and states its backend explicitly.
        index = torch.cuda.current_device() if local_rank is None else local_rank
        return torch.device("cuda", index)
    if _is_musa():
        # Same wrapping for MUSA: a bare index would not say which backend it
        # belongs to.
        index = torch.musa.current_device() if local_rank is None else local_rank
        return torch.device("musa", index)
    return torch.device("cpu")
|
|
|
|
|
|
def get_torch_distributed_backend() -> str:
    """Map the available accelerator to a ``torch.distributed`` backend name.

    The probes are tried in order and the first one that succeeds wins:
    ``torch.cuda.is_available()`` selects ``"nccl"``, otherwise ``_is_musa()``
    selects ``"mccl"``.

    Raises:
        NotImplementedError: if no probe reports an available accelerator.
    """
    probes = (
        (torch.cuda.is_available, "nccl"),
        (_is_musa, "mccl"),
    )
    for probe, backend_name in probes:
        if probe():
            return backend_name
    raise NotImplementedError(
        "No Accelerators(AMD/NV/MTT GPU, AMD MI instinct accelerators) available"
    )
|