Wan2.1/wan/utils/platform.py
2025-08-06 20:05:57 +08:00

51 lines
1.1 KiB
Python

from typing import Optional
import torch
# MUSA (MTT GPU) support is optional: ``torch_musa`` is only installed on
# MUSA machines. Record its absence as ``None`` so the helpers below can
# check for it instead of failing at import time.
try:
    import torch_musa
except ModuleNotFoundError:
    # Only a missing module is tolerated; any other ImportError raised while
    # importing an installed-but-broken torch_musa still propagates.
    torch_musa = None
def _is_musa() -> bool:
    """Return True when a MUSA (MTT GPU) device can actually be used.

    Two conditions must both hold: the optional ``torch_musa`` extension
    imported successfully, and PyTorch reports a MUSA device as available.
    """
    # ``torch_musa`` is None when the optional import failed at module load.
    if torch_musa is None:
        return False
    # Having the extension installed does not guarantee a device is present;
    # mirror the ``torch.cuda.is_available()`` gate used for NVIDIA GPUs so a
    # device-less host falls back to CPU instead of selecting "musa"/"mccl".
    return bool(torch.musa.is_available())
def get_device(local_rank: Optional[int] = None) -> torch.device:
    """Return the compute device for this process.

    The preference order is CUDA, then MUSA, then CPU.

    Args:
        local_rank: Index of the accelerator to use. When ``None``, the
            backend's current device index is used. Ignored on CPU.

    Returns:
        A ``torch.device`` such as ``cuda:0``, ``musa:1`` or ``cpu``.
    """
    if torch.cuda.is_available():
        # ``current_device()`` returns a bare int index; wrap it so callers
        # always receive a ``torch.device`` as the annotation promises.
        index = torch.cuda.current_device() if local_rank is None else local_rank
        return torch.device("cuda", index)
    if _is_musa():
        index = torch.musa.current_device() if local_rank is None else local_rank
        return torch.device("musa", index)
    return torch.device("cpu")
def get_device_type() -> str:
    """Name the accelerator type PyTorch should target.

    Returns:
        ``"cuda"`` when CUDA is available, otherwise ``"musa"`` when
        ``_is_musa()`` reports MUSA support, otherwise ``"cpu"``.
    """
    if not torch.cuda.is_available():
        # No NVIDIA GPU: try MUSA next, and settle for CPU as a last resort.
        return "musa" if _is_musa() else "cpu"
    return "cuda"
def get_torch_distributed_backend() -> str:
    """Pick the ``torch.distributed`` backend name for the detected accelerator.

    Returns:
        ``"nccl"`` when CUDA is available, else ``"mccl"`` when
        ``_is_musa()`` reports MUSA support.

    Raises:
        NotImplementedError: If neither accelerator is available; there is
            no CPU (e.g. gloo) fallback here.
    """
    cuda_ready = torch.cuda.is_available()
    if cuda_ready:
        return "nccl"
    if _is_musa():
        return "mccl"
    raise NotImplementedError("No Accelerators(NV/MTT GPU) available")