Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion crates/uv-resolver/src/resolver/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1814,7 +1814,7 @@ impl<InstalledPackages: InstalledPackagesProvider> ResolverState<InstalledPackag
.options
.torch_backend
.as_ref()
.filter(|torch_backend| matches!(torch_backend, TorchStrategy::Auto { .. }))
.filter(|torch_backend| matches!(torch_backend, TorchStrategy::Cuda { .. }))
.and_then(|_| pins.get(name, version).and_then(ResolvedDist::index))
.map(IndexUrl::url)
.and_then(SystemDependency::from_index)
Expand Down
6 changes: 5 additions & 1 deletion crates/uv-static/src/env_vars.rs
Original file line number Diff line number Diff line change
Expand Up @@ -718,10 +718,14 @@ impl EnvVars {
/// This is a quasi-standard variable, described, e.g., in `ncurses(3x)`.
pub const COLUMNS: &'static str = "COLUMNS";

/// The CUDA driver version to assume when inferring the PyTorch backend.
/// The CUDA driver version to assume when inferring the PyTorch backend (e.g., `550.144.03`).
#[attr_hidden]
pub const UV_CUDA_DRIVER_VERSION: &'static str = "UV_CUDA_DRIVER_VERSION";

/// The AMD GPU architecture to assume when inferring the PyTorch backend (e.g., `gfx1100`).
#[attr_hidden]
pub const UV_AMD_GPU_ARCHITECTURE: &'static str = "UV_AMD_GPU_ARCHITECTURE";

/// Equivalent to the `--torch-backend` command-line argument (e.g., `cpu`, `cu126`, or `auto`).
pub const UV_TORCH_BACKEND: &'static str = "UV_TORCH_BACKEND";

Expand Down
110 changes: 109 additions & 1 deletion crates/uv-torch/src/accelerator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,17 +13,30 @@ pub enum AcceleratorError {
Version(#[from] uv_pep440::VersionParseError),
#[error(transparent)]
Utf8(#[from] std::string::FromUtf8Error),
#[error("Unknown AMD GPU architecture: {0}")]
UnknownAmdGpuArchitecture(String),
}

/// A description of the accelerator (GPU) available on the host, as produced by
/// [`Accelerator::detect`].
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum Accelerator {
    /// A CUDA-capable device, identified by its driver version.
    Cuda {
        /// The CUDA driver version (e.g., `550.144.03`).
        ///
        /// Not to be confused with the CUDA toolkit version (e.g., `12.8.0`).
        driver_version: Version,
    },
    /// An AMD device, identified by its GPU architecture.
    Amd {
        /// The AMD GPU architecture (e.g., `gfx906`).
        ///
        /// Not to be confused with the user-space ROCm version (e.g., `6.4.0-47`) or the
        /// kernel-mode driver version (e.g., `6.12.12`).
        gpu_architecture: AmdGpuArchitecture,
    },
}

impl std::fmt::Display for Accelerator {
    /// Renders the accelerator as `<vendor label> <identifier>`, e.g. `CUDA 550.144.03` or
    /// `AMD gfx1100`.
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        // Pick the label and the value to print up front, so the output format is spelled out
        // exactly once below.
        let (label, value): (&str, &dyn std::fmt::Display) = match self {
            Self::Cuda { driver_version } => ("CUDA", driver_version),
            Self::Amd { gpu_architecture } => ("AMD", gpu_architecture),
        };
        write!(f, "{label} {value}")
    }
}
Expand All @@ -33,9 +46,11 @@ impl Accelerator {
///
/// Query, in order:
/// 1. The `UV_CUDA_DRIVER_VERSION` environment variable.
/// 2. The `UV_AMD_GPU_ARCHITECTURE` environment variable.
/// 3. `/sys/module/nvidia/version`, which contains the driver version (e.g., `550.144.03`).
/// 4. `/proc/driver/nvidia/version`, which contains the driver version among other information.
/// 5. `nvidia-smi --query-gpu=driver_version --format=csv,noheader`.
/// 6. `rocm_agent_enumerator`, which lists the AMD GPU architectures.
pub fn detect() -> Result<Option<Self>, AcceleratorError> {
// Read from `UV_CUDA_DRIVER_VERSION`.
if let Ok(driver_version) = std::env::var(EnvVars::UV_CUDA_DRIVER_VERSION) {
Expand All @@ -44,6 +59,15 @@ impl Accelerator {
return Ok(Some(Self::Cuda { driver_version }));
}

// Read from `UV_AMD_GPU_ARCHITECTURE`.
if let Ok(gpu_architecture) = std::env::var(EnvVars::UV_AMD_GPU_ARCHITECTURE) {
let gpu_architecture = AmdGpuArchitecture::from_str(&gpu_architecture)?;
debug!(
"Detected AMD GPU architecture from `UV_AMD_GPU_ARCHITECTURE`: {gpu_architecture}"
);
return Ok(Some(Self::Amd { gpu_architecture }));
}

// Read from `/sys/module/nvidia/version`.
match fs_err::read_to_string("/sys/module/nvidia/version") {
Ok(content) => {
Expand Down Expand Up @@ -100,7 +124,34 @@ impl Accelerator {
);
}

debug!("Failed to detect CUDA driver version");
// Query `rocm_agent_enumerator` to detect the AMD GPU architecture.
//
// See: https://rocm.docs.amd.com/projects/rocminfo/en/latest/how-to/use-rocm-agent-enumerator.html
if let Ok(output) = std::process::Command::new("rocm_agent_enumerator").output() {
if output.status.success() {
let stdout = String::from_utf8(output.stdout)?;
if let Some(gpu_architecture) = stdout
.lines()
.map(str::trim)
.filter_map(|line| AmdGpuArchitecture::from_str(line).ok())
.min()
{
debug!(
"Detected AMD GPU architecture from `rocm_agent_enumerator`: {gpu_architecture}"
);
return Ok(Some(Self::Amd { gpu_architecture }));
}
} else {
debug!(
"Failed to query AMD GPU architecture with `rocm_agent_enumerator` with status `{}`: {}",
output.status,
String::from_utf8_lossy(&output.stderr)
);
}
}

debug!("Failed to detect GPU driver version");

Ok(None)
}
}
Expand Down Expand Up @@ -129,6 +180,63 @@ fn parse_proc_driver_nvidia_version(content: &str) -> Result<Option<Version>, Ac
Ok(Some(driver_version))
}

/// The architecture identifier of an AMD GPU (e.g., `gfx1100`).
///
/// The derived ordering follows the declaration order of the variants, so the order below is
/// significant: callers that compare architectures (e.g., by taking the minimum of several
/// detected values) depend on it. Append or insert new variants with that in mind.
///
/// See: <https://rocm.docs.amd.com/projects/install-on-linux/en/latest/reference/system-requirements.html>
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord)]
pub enum AmdGpuArchitecture {
    /// `gfx900`
    Gfx900,
    /// `gfx906`
    Gfx906,
    /// `gfx908`
    Gfx908,
    /// `gfx90a`
    Gfx90a,
    /// `gfx942`
    Gfx942,
    /// `gfx1030`
    Gfx1030,
    /// `gfx1100`
    Gfx1100,
    /// `gfx1101`
    Gfx1101,
    /// `gfx1102`
    Gfx1102,
    /// `gfx1200`
    Gfx1200,
    /// `gfx1201`
    Gfx1201,
}

impl FromStr for AmdGpuArchitecture {
    type Err = AcceleratorError;

    /// Parses an architecture identifier such as `gfx90a` or `gfx1100`.
    ///
    /// The match is exact: the input is neither trimmed nor case-folded.
    ///
    /// # Errors
    ///
    /// Returns [`AcceleratorError::UnknownAmdGpuArchitecture`], carrying the offending input,
    /// when the string is not one of the recognized identifiers.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        let architecture = match s {
            "gfx900" => Self::Gfx900,
            "gfx906" => Self::Gfx906,
            "gfx908" => Self::Gfx908,
            "gfx90a" => Self::Gfx90a,
            "gfx942" => Self::Gfx942,
            "gfx1030" => Self::Gfx1030,
            "gfx1100" => Self::Gfx1100,
            "gfx1101" => Self::Gfx1101,
            "gfx1102" => Self::Gfx1102,
            "gfx1200" => Self::Gfx1200,
            "gfx1201" => Self::Gfx1201,
            unknown => {
                return Err(AcceleratorError::UnknownAmdGpuArchitecture(
                    unknown.to_string(),
                ));
            }
        };
        Ok(architecture)
    }
}

impl std::fmt::Display for AmdGpuArchitecture {
    /// Writes the lowercase architecture identifier (e.g., `gfx1100`), i.e. the same spelling
    /// that the `FromStr` implementation accepts.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let identifier = match self {
            Self::Gfx900 => "gfx900",
            Self::Gfx906 => "gfx906",
            Self::Gfx908 => "gfx908",
            Self::Gfx90a => "gfx90a",
            Self::Gfx942 => "gfx942",
            Self::Gfx1030 => "gfx1030",
            Self::Gfx1100 => "gfx1100",
            Self::Gfx1101 => "gfx1101",
            Self::Gfx1102 => "gfx1102",
            Self::Gfx1200 => "gfx1200",
            Self::Gfx1201 => "gfx1201",
        };
        // Written verbatim (no padding or width handling), matching a literal-only `write!`.
        f.write_str(identifier)
    }
}

#[cfg(test)]
mod tests {
use super::*;
Expand Down
Loading
Loading