From ece812041590b5131412fa05ededc3530fda2b4a Mon Sep 17 00:00:00 2001
From: Bakhtiyor Sulaymonov
Date: Mon, 17 Mar 2025 12:42:17 +0500
Subject: [PATCH] Update generate.py: Fix hardcoding for one device

---
 generate.py | 24 ++++++++++++++++++------
 1 file changed, 18 insertions(+), 6 deletions(-)

diff --git a/generate.py b/generate.py
index 87ed37b..7f0ed21 100644
--- a/generate.py
+++ b/generate.py
@@ -212,13 +212,25 @@ def _init_logging(rank):
 
 
 def generate(args):
-    # Set device based on args or availability
-    if args.device:
-        device = torch.device(args.device)
+    # Handle both distributed and single-device scenarios
+    if "RANK" in os.environ:
+        # Distributed setup
+        rank = int(os.getenv("RANK", 0))
+        world_size = int(os.getenv("WORLD_SIZE", 1))
+        local_rank = int(os.getenv("LOCAL_RANK", 0))
+        device = torch.device(f"cuda:{local_rank}" if torch.cuda.is_available() else "cpu")
+        _init_logging(rank)
     else:
-        device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
-
-    _init_logging(0)  # Use rank 0 logging for single-device
+        # Single-device setup with MPS fallback
+        rank = 0
+        world_size = 1
+        if args.device:
+            device = torch.device(args.device)
+        else:
+            device = torch.device("cuda:0" if torch.cuda.is_available() else
+                                  "mps" if torch.backends.mps.is_available() else
+                                  "cpu")
+        _init_logging(rank)
 
     # Ensure all torch operations use this device
     torch.set_default_device(device)
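
For anyone who wants to sanity-check the fallback order outside of generate.py, below is a minimal, self-contained sketch of the same selection logic. It assumes only that torch is installed; pick_device is an illustrative helper, not a function in the patch, and the real code additionally derives rank and world_size for distributed runs.

# Illustrative sketch (not part of the patch): mirrors the device selection
# that the patched generate() performs.
import os

import torch


def pick_device(requested=None):
    # Distributed launchers such as torchrun export RANK and LOCAL_RANK,
    # so each worker binds to its own GPU (or CPU if CUDA is unavailable).
    if "RANK" in os.environ:
        local_rank = int(os.getenv("LOCAL_RANK", 0))
        return torch.device(f"cuda:{local_rank}" if torch.cuda.is_available() else "cpu")
    # Single-device path: explicit request, then CUDA, then MPS, then CPU.
    if requested:
        return torch.device(requested)
    if torch.cuda.is_available():
        return torch.device("cuda:0")
    if torch.backends.mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")


print(pick_device())  # e.g. cuda:<local_rank> under torchrun, mps on Apple silicon, else cpu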