diff --git a/generate.py b/generate.py
index 87ed37b..7f0ed21 100644
--- a/generate.py
+++ b/generate.py
@@ -212,13 +212,25 @@ def _init_logging(rank):
 def generate(args):
-    # Set device based on args or availability
-    if args.device:
-        device = torch.device(args.device)
+    # Handle both distributed and single-device scenarios
+    if "RANK" in os.environ:
+        # Distributed setup: the launcher provides rank/world-size info
+        rank = int(os.getenv("RANK", 0))
+        world_size = int(os.getenv("WORLD_SIZE", 1))
+        local_rank = int(os.getenv("LOCAL_RANK", 0))
+        device = torch.device(f"cuda:{local_rank}" if torch.cuda.is_available() else "cpu")
+        _init_logging(rank)
     else:
-        device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
-
-    _init_logging(0)  # Use rank 0 logging for single-device
+        # Single-device setup with CUDA-then-MPS fallback
+        rank = 0
+        world_size = 1
+        if args.device:
+            device = torch.device(args.device)
+        else:
+            device = torch.device("cuda:0" if torch.cuda.is_available() else
+                                  "mps" if torch.backends.mps.is_available() else
+                                  "cpu")
+        _init_logging(rank)
 
     # Ensure all torch operations use this device
     torch.set_default_device(device)
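For context, the distributed branch relies on environment variables that a launcher such as torchrun exports into each worker process (e.g. torchrun --nproc_per_node=2 generate.py sets RANK, WORLD_SIZE, and LOCAL_RANK); a plain python generate.py invocation leaves RANK unset and takes the single-device path. Only the signature of _init_logging is visible in this hunk's context; a minimal sketch of the rank-aware pattern it presumably implements (an assumption, not the file's actual code) might look like:

import logging

def _init_logging(rank):
    # Assumed sketch, not the patch's implementation: log at INFO on
    # rank 0 and WARNING on other ranks so per-step messages are not
    # duplicated once across every distributed worker.
    logging.basicConfig(
        level=logging.INFO if rank == 0 else logging.WARNING,
        format=f"[rank {rank}] %(levelname)s %(message)s",
    )

Quieting non-zero ranks is a common choice here because all workers execute the same code, so unfiltered logging repeats every message world_size times.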