Skip to content

Commit f85a66c

Browse files
committed
contrib: check if nvidia drivers are already installed
Signed-off-by: CrazyMax <[email protected]>
1 parent 7f1278d commit f85a66c

File tree

1 file changed

+81
-86
lines changed

1 file changed

+81
-86
lines changed

contrib/cdisetup/nvidia/nvidia.go

Lines changed: 81 additions & 86 deletions
Original file line numberDiff line numberDiff line change
@@ -25,10 +25,13 @@ import (
2525
// This is example of experimental on-demand setup of a CDI devices.
2626
// This code is not currently shipping with BuildKit and will probably change.
2727

28-
const (
29-
cdiKind = "nvidia.com/gpu"
30-
defaultVersion = "570.0"
31-
)
28+
const cdiKind = "nvidia.com/gpu"
29+
30+
// https://github.com/ollama/ollama/blob/b816ff86c923e0290f58f2275e831fc17c29ba37/discover/gpu_linux.go#L33-L43
31+
var libcudaGlobs = []string{
32+
"/usr/lib/*-linux-gnu/libcuda.so*",
33+
"/usr/lib/wsl/drivers/*/libcuda.so*",
34+
}
3235

3336
func init() {
3437
cdidevices.Register(cdiKind, &setup{})
@@ -92,51 +95,32 @@ func (s *setup) Run(ctx context.Context) (err error) {
9295
return errors.Errorf("NVIDIA setup is currently only supported on Debian/Ubuntu")
9396
}
9497

95-
var needsDriver bool
96-
if nvidiaSmi, err := exec.LookPath("nvidia-smi"); err == nil && nvidiaSmi != "" {
97-
if err := run(ctx, []string{nvidiaSmi, "-L"}, pw, dgst); err != nil {
98-
needsDriver = true
98+
needsDriver := true
99+
if _, err := os.Stat("/proc/driver/nvidia"); err == nil {
100+
needsDriver = false
101+
} else if nvidiaSmi, err := exec.LookPath("nvidia-smi"); err == nil && nvidiaSmi != "" {
102+
if err := run(ctx, []string{nvidiaSmi, "-L"}, pw, dgst); err == nil {
103+
needsDriver = false
99104
}
100-
} else if _, err := os.Stat("/proc/driver/nvidia"); err != nil {
101-
needsDriver = true
102-
}
103-
104-
var arch string
105-
switch runtime.GOARCH {
106-
case "amd64":
107-
arch = "x86_64"
108-
case "arm64":
109-
arch = "sbsa"
110-
// for non-sbsa could use https://nvidia.github.io/libnvidia-container/stable/deb
111-
}
112-
113-
if arch == "" {
114-
return errors.Errorf("unsupported architecture: %s", runtime.GOARCH)
115105
}
116-
117106
if needsDriver {
118-
pw.Write(identity.NewID(), client.VertexWarning{
119-
Vertex: dgst,
120-
Short: []byte("NVIDIA Drivers not found. Installing prebuilt drivers is not recommended"),
121-
})
107+
if hasWSLGPU() {
108+
return errors.Errorf("NVIDIA drivers are required for WSL with non PCI-based GPUs")
109+
}
110+
return errors.Errorf("NVIDIA drivers are required. Try loading NVIDIA kernel module with \"modprobe nvidia\" command")
122111
}
123112

124113
var dv string
125-
if !hasWSLGPU() {
114+
if !hasLibsInstalled() && !hasWSLGPU() {
126115
version, err := readVersion()
127-
if err != nil && !needsDriver {
116+
if err != nil {
128117
return errors.Wrapf(err, "failed to read NVIDIA driver version")
129118
}
130-
if version == "" {
131-
version = defaultVersion
132-
}
133119
var ok bool
134120
dv, _, ok = strings.Cut(version, ".")
135121
if !ok {
136122
return errors.Errorf("failed to parse NVIDIA driver version %q", version)
137123
}
138-
} else if needsDriver {
139-
return errors.Errorf("NVIDIA drivers are required for WSL with non PCI-based GPUs")
140124
}
141125

142126
if err := run(ctx, []string{"apt-get", "update"}, pw, dgst); err != nil {
@@ -147,9 +131,58 @@ func (s *setup) Run(ctx context.Context) (err error) {
147131
return err
148132
}
149133

134+
if err := installPackages(ctx, dv, pw, dgst); err != nil {
135+
return err
136+
}
137+
138+
if err := os.MkdirAll("/etc/cdi", 0700); err != nil {
139+
return errors.Wrapf(err, "failed to create /etc/cdi")
140+
}
141+
142+
buf := &bytes.Buffer{}
143+
144+
cmd := exec.CommandContext(ctx, "nvidia-ctk", "cdi", "generate")
145+
cmd.Stdout = buf
146+
cmd.Stderr = newStream(pw, 2, dgst)
147+
if err := cmd.Run(); err != nil {
148+
return errors.Wrapf(err, "failed to generate CDI spec")
149+
}
150+
151+
if len(buf.Bytes()) == 0 {
152+
return errors.Errorf("nvidia-ctk output is empty")
153+
}
154+
155+
if err := os.WriteFile("/etc/cdi/nvidia.yaml", buf.Bytes(), 0644); err != nil {
156+
return errors.Wrapf(err, "failed to write /etc/cdi/nvidia.yaml")
157+
}
158+
159+
return nil
160+
}
161+
162+
func run(ctx context.Context, args []string, pw progress.Writer, dgst digest.Digest) error {
163+
fmt.Fprintf(newStream(pw, 2, dgst), "> %s\n", strings.Join(args, " "))
164+
cmd := exec.CommandContext(ctx, args[0], args[1:]...) //nolint:gosec
165+
cmd.Stderr = newStream(pw, 2, dgst)
166+
cmd.Stdout = newStream(pw, 1, dgst)
167+
return cmd.Run()
168+
}
169+
170+
func installPackages(ctx context.Context, dv string, pw progress.Writer, dgst digest.Digest) error {
150171
const aptDistro = "ubuntu2404"
151-
aptURL := "https://developer.download.nvidia.com/compute/cuda/repos/" + aptDistro + "/" + arch + "/"
152172

173+
var arch string
174+
switch runtime.GOARCH {
175+
case "amd64":
176+
arch = "x86_64"
177+
case "arm64":
178+
arch = "sbsa"
179+
// for non-sbsa could use https://nvidia.github.io/libnvidia-container/stable/deb
180+
}
181+
if arch == "" {
182+
return errors.Errorf("unsupported architecture: %s", runtime.GOARCH)
183+
}
184+
185+
aptURL := "https://developer.download.nvidia.com/compute/cuda/repos/" + aptDistro + "/" + arch + "/"
153186
keyTarget := "/usr/share/keyrings/nvidia-cuda-keyring.gpg"
154187

155188
if _, err := os.Stat(keyTarget); err != nil {
@@ -182,22 +215,7 @@ func (s *setup) Run(ctx context.Context) (err error) {
182215
return err
183216
}
184217

185-
if needsDriver && dv != "" {
186-
// this pretty much never works, is it even worth having?
187-
// better approach could be to try to create another chroot/container that is built with same kernel packages as the host
188-
// could nvidia-headless-no-dkms- be reusable
189-
if err := run(ctx, []string{"apt-get", "install", "-y", "nvidia-driver-" + dv}, pw, dgst); err != nil {
190-
return err
191-
}
192-
_, err := os.Stat("/proc/driver/nvidia")
193-
if err != nil {
194-
return errors.Wrapf(err, "failed to install NVIDIA kernel module. Please install NVIDIA drivers manually")
195-
}
196-
}
197-
198-
pkgs := []string{
199-
"nvidia-container-toolkit-base",
200-
}
218+
pkgs := []string{"nvidia-container-toolkit-base"}
201219
if dv != "" {
202220
pkgs = append(pkgs, []string{
203221
"libnvidia-compute-" + dv,
@@ -207,40 +225,7 @@ func (s *setup) Run(ctx context.Context) (err error) {
207225
}...)
208226
}
209227

210-
if err := run(ctx, append([]string{"apt-get", "install", "-y", "--no-install-recommends"}, pkgs...), pw, dgst); err != nil {
211-
return err
212-
}
213-
214-
if err := os.MkdirAll("/etc/cdi", 0700); err != nil {
215-
return errors.Wrapf(err, "failed to create /etc/cdi")
216-
}
217-
218-
buf := &bytes.Buffer{}
219-
220-
cmd := exec.CommandContext(ctx, "nvidia-ctk", "cdi", "generate")
221-
cmd.Stdout = buf
222-
cmd.Stderr = newStream(pw, 2, dgst)
223-
if err := cmd.Run(); err != nil {
224-
return errors.Wrapf(err, "failed to generate CDI spec")
225-
}
226-
227-
if len(buf.Bytes()) == 0 {
228-
return errors.Errorf("nvidia-ctk output is empty")
229-
}
230-
231-
if err := os.WriteFile("/etc/cdi/nvidia.yaml", buf.Bytes(), 0644); err != nil {
232-
return errors.Wrapf(err, "failed to write /etc/cdi/nvidia.yaml")
233-
}
234-
235-
return nil
236-
}
237-
238-
func run(ctx context.Context, args []string, pw progress.Writer, dgst digest.Digest) error {
239-
fmt.Fprintf(newStream(pw, 2, dgst), "> %s\n", strings.Join(args, " "))
240-
cmd := exec.CommandContext(ctx, args[0], args[1:]...) //nolint:gosec
241-
cmd.Stderr = newStream(pw, 2, dgst)
242-
cmd.Stdout = newStream(pw, 1, dgst)
243-
return cmd.Run()
228+
return run(ctx, append([]string{"apt-get", "install", "-y", "--no-install-recommends"}, pkgs...), pw, dgst)
244229
}
245230

246231
func readVersion() (string, error) {
@@ -326,3 +311,13 @@ func hasWSLGPU() bool {
326311
_, err := os.Stat("/dev/dxg")
327312
return err == nil
328313
}
314+
315+
func hasLibsInstalled() bool {
316+
// Check for libcuda in the standard locations to confirm NVIDIA GPU drivers
317+
for _, p := range libcudaGlobs {
318+
if matches, err := filepath.Glob(p); err == nil && len(matches) > 0 {
319+
return true
320+
}
321+
}
322+
return false
323+
}

0 commit comments

Comments
 (0)