Skip to content

Commit f8c1909

Browse files
authored
Merge pull request #5880 from crazy-max/contrib-nvidia-nopci
contrib: check nvidia drivers and support non PCI-based GPUs for WSL
2 parents e2bf281 + f85a66c commit f8c1909

File tree

1 file changed

+107
-88
lines changed

1 file changed

+107
-88
lines changed

contrib/cdisetup/nvidia/nvidia.go

Lines changed: 107 additions & 88 deletions
Original file line number | Diff line number | Diff line change
@@ -25,10 +25,13 @@ import (
2525
// This is an example of an experimental on-demand setup of CDI devices.
2626
// This code is not currently shipping with BuildKit and will probably change.
2727

28-
const (
29-
cdiKind = "nvidia.com/gpu"
30-
defaultVersion = "570.0"
31-
)
28+
const cdiKind = "nvidia.com/gpu"
29+
30+
// https://github.com/ollama/ollama/blob/b816ff86c923e0290f58f2275e831fc17c29ba37/discover/gpu_linux.go#L33-L43
31+
var libcudaGlobs = []string{
32+
"/usr/lib/*-linux-gnu/libcuda.so*",
33+
"/usr/lib/wsl/drivers/*/libcuda.so*",
34+
}
3235

3336
func init() {
3437
cdidevices.Register(cdiKind, &setup{})
@@ -39,8 +42,7 @@ type setup struct{}
3942
var _ cdidevices.Setup = &setup{}
4043

4144
func (s *setup) Validate() error {
42-
_, err := readVersion()
43-
if err == nil {
45+
if _, err := readVersion(); err == nil {
4446
return nil
4547
}
4648
b, err := hasNvidiaDevices()
@@ -93,55 +95,94 @@ func (s *setup) Run(ctx context.Context) (err error) {
9395
return errors.Errorf("NVIDIA setup is currently only supported on Debian/Ubuntu")
9496
}
9597

96-
var needsDriver bool
97-
98-
if _, err := os.Stat("/proc/driver/nvidia"); err != nil {
99-
needsDriver = true
98+
needsDriver := true
99+
if _, err := os.Stat("/proc/driver/nvidia"); err == nil {
100+
needsDriver = false
101+
} else if nvidiaSmi, err := exec.LookPath("nvidia-smi"); err == nil && nvidiaSmi != "" {
102+
if err := run(ctx, []string{nvidiaSmi, "-L"}, pw, dgst); err == nil {
103+
needsDriver = false
104+
}
105+
}
106+
if needsDriver {
107+
if hasWSLGPU() {
108+
return errors.Errorf("NVIDIA drivers are required for WSL with non PCI-based GPUs")
109+
}
110+
return errors.Errorf("NVIDIA drivers are required. Try loading NVIDIA kernel module with \"modprobe nvidia\" command")
100111
}
101112

102-
var arch string
103-
switch runtime.GOARCH {
104-
case "amd64":
105-
arch = "x86_64"
106-
case "arm64":
107-
arch = "sbsa"
108-
// for non-sbsa could use https://nvidia.github.io/libnvidia-container/stable/deb
113+
var dv string
114+
if !hasLibsInstalled() && !hasWSLGPU() {
115+
version, err := readVersion()
116+
if err != nil {
117+
return errors.Wrapf(err, "failed to read NVIDIA driver version")
118+
}
119+
var ok bool
120+
dv, _, ok = strings.Cut(version, ".")
121+
if !ok {
122+
return errors.Errorf("failed to parse NVIDIA driver version %q", version)
123+
}
109124
}
110125

111-
if arch == "" {
112-
return errors.Errorf("unsupported architecture: %s", runtime.GOARCH)
126+
if err := run(ctx, []string{"apt-get", "update"}, pw, dgst); err != nil {
127+
return err
113128
}
114129

115-
if needsDriver {
116-
pw.Write(identity.NewID(), client.VertexWarning{
117-
Vertex: dgst,
118-
Short: []byte("NVIDIA Drivers not found. Installing prebuilt drivers is not recommended"),
119-
})
130+
if err := run(ctx, []string{"apt-get", "install", "-y", "gpg"}, pw, dgst); err != nil {
131+
return err
120132
}
121133

122-
version, err := readVersion()
123-
if err != nil && !needsDriver {
124-
return errors.Wrapf(err, "failed to read NVIDIA driver version")
134+
if err := installPackages(ctx, dv, pw, dgst); err != nil {
135+
return err
125136
}
126-
if version == "" {
127-
version = defaultVersion
137+
138+
if err := os.MkdirAll("/etc/cdi", 0700); err != nil {
139+
return errors.Wrapf(err, "failed to create /etc/cdi")
128140
}
129-
v1, _, ok := strings.Cut(version, ".")
130-
if !ok {
131-
return errors.Errorf("failed to parse NVIDIA driver version %q", version)
141+
142+
buf := &bytes.Buffer{}
143+
144+
cmd := exec.CommandContext(ctx, "nvidia-ctk", "cdi", "generate")
145+
cmd.Stdout = buf
146+
cmd.Stderr = newStream(pw, 2, dgst)
147+
if err := cmd.Run(); err != nil {
148+
return errors.Wrapf(err, "failed to generate CDI spec")
132149
}
133150

134-
if err := run(ctx, []string{"apt-get", "update"}, pw, dgst); err != nil {
135-
return err
151+
if len(buf.Bytes()) == 0 {
152+
return errors.Errorf("nvidia-ctk output is empty")
136153
}
137154

138-
if err := run(ctx, []string{"apt-get", "install", "-y", "gpg"}, pw, dgst); err != nil {
139-
return err
155+
if err := os.WriteFile("/etc/cdi/nvidia.yaml", buf.Bytes(), 0644); err != nil {
156+
return errors.Wrapf(err, "failed to write /etc/cdi/nvidia.yaml")
140157
}
141158

159+
return nil
160+
}
161+
162+
func run(ctx context.Context, args []string, pw progress.Writer, dgst digest.Digest) error {
163+
fmt.Fprintf(newStream(pw, 2, dgst), "> %s\n", strings.Join(args, " "))
164+
cmd := exec.CommandContext(ctx, args[0], args[1:]...) //nolint:gosec
165+
cmd.Stderr = newStream(pw, 2, dgst)
166+
cmd.Stdout = newStream(pw, 1, dgst)
167+
return cmd.Run()
168+
}
169+
170+
func installPackages(ctx context.Context, dv string, pw progress.Writer, dgst digest.Digest) error {
142171
const aptDistro = "ubuntu2404"
143-
aptURL := "https://developer.download.nvidia.com/compute/cuda/repos/" + aptDistro + "/" + arch + "/"
144172

173+
var arch string
174+
switch runtime.GOARCH {
175+
case "amd64":
176+
arch = "x86_64"
177+
case "arm64":
178+
arch = "sbsa"
179+
// for non-sbsa could use https://nvidia.github.io/libnvidia-container/stable/deb
180+
}
181+
if arch == "" {
182+
return errors.Errorf("unsupported architecture: %s", runtime.GOARCH)
183+
}
184+
185+
aptURL := "https://developer.download.nvidia.com/compute/cuda/repos/" + aptDistro + "/" + arch + "/"
145186
keyTarget := "/usr/share/keyrings/nvidia-cuda-keyring.gpg"
146187

147188
if _, err := os.Stat(keyTarget); err != nil {
@@ -174,59 +215,17 @@ func (s *setup) Run(ctx context.Context) (err error) {
174215
return err
175216
}
176217

177-
if needsDriver {
178-
// this pretty much never works, is it even worth having?
179-
// better approach could be to try to create another chroot/container that is built with same kernel packages as the host
180-
// could nvidia-headless-no-dkms- be reusable
181-
if err := run(ctx, []string{"apt-get", "install", "-y", "nvidia-driver-" + v1}, pw, dgst); err != nil {
182-
return err
183-
}
184-
_, err := os.Stat("/proc/driver/nvidia")
185-
if err != nil {
186-
return errors.Wrapf(err, "failed to install NVIDIA kernel module. Please install NVIDIA drivers manually")
187-
}
188-
}
189-
190-
if err := run(ctx, []string{"apt-get", "install", "-y", "--no-install-recommends",
191-
"libnvidia-compute-" + v1,
192-
"libnvidia-extra-" + v1,
193-
"libnvidia-gl-" + v1,
194-
"nvidia-utils-" + v1,
195-
"nvidia-container-toolkit-base",
196-
}, pw, dgst); err != nil {
197-
return err
198-
}
199-
200-
if err := os.MkdirAll("/etc/cdi", 0700); err != nil {
201-
return errors.Wrapf(err, "failed to create /etc/cdi")
218+
pkgs := []string{"nvidia-container-toolkit-base"}
219+
if dv != "" {
220+
pkgs = append(pkgs, []string{
221+
"libnvidia-compute-" + dv,
222+
"libnvidia-extra-" + dv,
223+
"libnvidia-gl-" + dv,
224+
"nvidia-utils-" + dv,
225+
}...)
202226
}
203227

204-
buf := &bytes.Buffer{}
205-
206-
cmd := exec.CommandContext(ctx, "nvidia-ctk", "cdi", "generate")
207-
cmd.Stdout = buf
208-
cmd.Stderr = newStream(pw, 2, dgst)
209-
if err := cmd.Run(); err != nil {
210-
return errors.Wrapf(err, "failed to generate CDI spec")
211-
}
212-
213-
if len(buf.Bytes()) == 0 {
214-
return errors.Errorf("nvidia-ctk output is empty")
215-
}
216-
217-
if err := os.WriteFile("/etc/cdi/nvidia.yaml", buf.Bytes(), 0644); err != nil {
218-
return errors.Wrapf(err, "failed to write /etc/cdi/nvidia.yaml")
219-
}
220-
221-
return nil
222-
}
223-
224-
func run(ctx context.Context, args []string, pw progress.Writer, dgst digest.Digest) error {
225-
fmt.Fprintf(newStream(pw, 2, dgst), "> %s\n", strings.Join(args, " "))
226-
cmd := exec.CommandContext(ctx, args[0], args[1:]...) //nolint:gosec
227-
cmd.Stderr = newStream(pw, 2, dgst)
228-
cmd.Stdout = newStream(pw, 1, dgst)
229-
return cmd.Run()
228+
return run(ctx, append([]string{"apt-get", "install", "-y", "--no-install-recommends"}, pkgs...), pw, dgst)
230229
}
231230

232231
func readVersion() (string, error) {
@@ -268,6 +267,10 @@ func hasNvidiaDevices() (bool, error) {
268267
}
269268
}
270269

270+
if !found {
271+
found = hasWSLGPU()
272+
}
273+
271274
return found, nil
272275
}
273276

@@ -302,3 +305,19 @@ func isDebianOrUbuntu() (bool, error) {
302305

303306
return id == "debian" || id == "ubuntu", nil
304307
}
308+
309+
func hasWSLGPU() bool {
310+
// WSL-specific GPU mapping that doesn't expose PCI info.
311+
_, err := os.Stat("/dev/dxg")
312+
return err == nil
313+
}
314+
315+
func hasLibsInstalled() bool {
316+
// Check for libcuda in the standard locations to confirm that the NVIDIA driver libraries are installed.
317+
for _, p := range libcudaGlobs {
318+
if matches, err := filepath.Glob(p); err == nil && len(matches) > 0 {
319+
return true
320+
}
321+
}
322+
return false
323+
}

0 commit comments

Comments
 (0)