diff --git a/gpu-switch.sh b/gpu-switch.sh index 31f686c..22238e8 100644 --- a/gpu-switch.sh +++ b/gpu-switch.sh @@ -35,21 +35,24 @@ convert_to_pci_address() { local device="$1" local gpu_address="" - if [[ "$device" =~ ^[0-9]+$ ]]; then - # Convert GPU index to PCI address - gpu_address=$(nvidia-smi --id=$device --query-gpu=gpu_bus_id --format=csv,noheader 2>/dev/null | tr -d '[:space:]') - elif [[ "$device" =~ ^GPU-.*$ ]]; then - # Handle UUID - gpu_address=$(nvidia-smi --id=$device --query-gpu=gpu_bus_id --format=csv,noheader 2>/dev/null | tr -d '[:space:]') + if [[ "$device" =~ ^[0-9]+$ || "$device" =~ ^GPU-.*$ ]]; then + # Convert GPU index or UUID to PCI address + gpu_address=$(nvidia-smi --id="$device" --query-gpu=gpu_bus_id --format=csv,noheader 2>/dev/null | tr -d '[:space:]') else # Direct PCI address provided - gpu_address=$device + gpu_address="$device" + fi + + # Check for valid output + if [ -z "$gpu_address" ]; then + error_exit "Failed to get PCI address for device: $device" fi # Standardize format - echo "$gpu_address" | sed 's/0000://' | sed 's/\./:/g' + echo "$gpu_address" | sed -e 's/0000://' -e 's/\./:/g' } + get_gpu_addresses() { # Split devices by comma IFS=',' read -ra DEVICES <<< "$NVIDIA_VISIBLE_DEVICES"