Update generate.py: Fix hardcoding for one device

Author: Bakhtiyor Sulaymonov (committed by GitHub)
Date:   2025-03-17 12:42:17 +05:00
Parent: b3e69433d0
Commit: ece8120415


@@ -212,13 +212,25 @@ def _init_logging(rank):
 def generate(args):
-    # Set device based on args or availability
-    if args.device:
-        device = torch.device(args.device)
+    # Handle both distributed and single-device scenarios
+    if "RANK" in os.environ:
+        # Distributed setup
+        rank = int(os.getenv("RANK", 0))
+        world_size = int(os.getenv("WORLD_SIZE", 1))
+        local_rank = int(os.getenv("LOCAL_RANK", 0))
+        device = torch.device(f"cuda:{local_rank}" if torch.cuda.is_available() else "cpu")
+        _init_logging(rank)
     else:
-        device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
-    _init_logging(0)  # Use rank 0 logging for single-device
+        # Single-device setup with MPS fallback
+        rank = 0
+        world_size = 1
+        if args.device:
+            device = torch.device(args.device)
+        else:
+            device = torch.device("cuda:0" if torch.cuda.is_available() else
+                                  "mps" if torch.backends.mps.is_available() else
+                                  "cpu")
+        _init_logging(rank)
+    # Ensure all torch operations use this device
+    torch.set_default_device(device)
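
For context, a minimal standalone sketch of the same selection logic follows, assuming only PyTorch >= 2.0 (for torch.set_default_device); pick_device is a hypothetical helper name introduced here for illustration, not part of the repo:

import os
import torch

def pick_device(requested=None):
    """Hypothetical standalone version of the device selection above."""
    if "RANK" in os.environ:
        # Distributed launchers (e.g. torchrun) export RANK and LOCAL_RANK;
        # bind this process to its local GPU, or fall back to CPU.
        local_rank = int(os.getenv("LOCAL_RANK", 0))
        return torch.device(f"cuda:{local_rank}" if torch.cuda.is_available() else "cpu")
    if requested:  # an explicit --device argument wins in single-device mode
        return torch.device(requested)
    if torch.cuda.is_available():  # prefer CUDA, then Apple MPS, then CPU
        return torch.device("cuda:0")
    if torch.backends.mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")

device = pick_device()
torch.set_default_device(device)  # newly created tensors default to this device
print(torch.empty(1).device)      # e.g. cuda:0, mps, or cpu

Launchers such as torchrun export RANK, WORLD_SIZE, and LOCAL_RANK, so they take the distributed branch; a bare python generate.py leaves RANK unset and falls through to the single-device branch, which is the hardcoding this commit removes.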