mirror of
https://github.com/Wan-Video/Wan2.1.git
synced 2025-06-05 14:54:54 +00:00
Update generate.py: Fix hardcoding for one device
This commit is contained in:
parent
b3e69433d0
commit
ece8120415
24
generate.py
24
generate.py
@ -212,13 +212,25 @@ def _init_logging(rank):
|
||||
|
||||
|
||||
def generate(args):
|
||||
# Set device based on args or availability
|
||||
if args.device:
|
||||
device = torch.device(args.device)
|
||||
# Handle both distributed and single-device scenarios
|
||||
if "RANK" in os.environ:
|
||||
# Distributed setup
|
||||
rank = int(os.getenv("RANK", 0))
|
||||
world_size = int(os.getenv("WORLD_SIZE", 1))
|
||||
local_rank = int(os.getenv("LOCAL_RANK", 0))
|
||||
device = torch.device(f"cuda:{local_rank}" if torch.cuda.is_available() else "cpu")
|
||||
init_logging(rank)
|
||||
else:
|
||||
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
|
||||
|
||||
_init_logging(0) # Use rank 0 logging for single-device
|
||||
# Single-device setup with MPS fallback
|
||||
rank = 0
|
||||
world_size = 1
|
||||
if args.device:
|
||||
device = torch.device(args.device)
|
||||
else:
|
||||
device = torch.device("cuda:0" if torch.cuda.is_available() else
|
||||
"mps" if torch.backends.mps.is_available() else
|
||||
"cpu")
|
||||
init_logging(rank)
|
||||
|
||||
# Ensure all torch operations use this device
|
||||
torch.set_default_device(device)
|
||||
|
Loading…
Reference in New Issue
Block a user