Skip to content

Commit e2ef5e5

Browse files
committed
contrib: check if nvidia drivers are already installed
Signed-off-by: CrazyMax <[email protected]>
1 parent 7f1278d commit e2ef5e5

File tree

1 file changed

+81
-79
lines changed

1 file changed

+81
-79
lines changed

contrib/cdisetup/nvidia/nvidia.go

Lines changed: 81 additions & 79 deletions
Original file line numberDiff line numberDiff line change
@@ -25,10 +25,13 @@ import (
2525
// This is example of experimental on-demand setup of a CDI devices.
2626
// This code is not currently shipping with BuildKit and will probably change.
2727

28-
const (
29-
cdiKind = "nvidia.com/gpu"
30-
defaultVersion = "570.0"
31-
)
28+
const cdiKind = "nvidia.com/gpu"
29+
30+
// https://github.com/ollama/ollama/blob/b816ff86c923e0290f58f2275e831fc17c29ba37/discover/gpu_linux.go#L33-L43
31+
var libcudaGlobs = []string{
32+
"/usr/lib/*-linux-gnu/libcuda.so*",
33+
"/usr/lib/wsl/drivers/*/libcuda.so*",
34+
}
3235

3336
func init() {
3437
cdidevices.Register(cdiKind, &setup{})
@@ -93,25 +96,15 @@ func (s *setup) Run(ctx context.Context) (err error) {
9396
}
9497

9598
var needsDriver bool
96-
if nvidiaSmi, err := exec.LookPath("nvidia-smi"); err == nil && nvidiaSmi != "" {
97-
if err := run(ctx, []string{nvidiaSmi, "-L"}, pw, dgst); err != nil {
99+
if !hasDriversInstalled() {
100+
// Check for nvidia-smi or /proc/driver/nvidia to determine if NVIDIA drivers are installed
101+
if nvidiaSmi, err := exec.LookPath("nvidia-smi"); err == nil && nvidiaSmi != "" {
102+
if err := run(ctx, []string{nvidiaSmi, "-L"}, pw, dgst); err != nil {
103+
needsDriver = true
104+
}
105+
} else if _, err := os.Stat("/proc/driver/nvidia"); err != nil {
98106
needsDriver = true
99107
}
100-
} else if _, err := os.Stat("/proc/driver/nvidia"); err != nil {
101-
needsDriver = true
102-
}
103-
104-
var arch string
105-
switch runtime.GOARCH {
106-
case "amd64":
107-
arch = "x86_64"
108-
case "arm64":
109-
arch = "sbsa"
110-
// for non-sbsa could use https://nvidia.github.io/libnvidia-container/stable/deb
111-
}
112-
113-
if arch == "" {
114-
return errors.Errorf("unsupported architecture: %s", runtime.GOARCH)
115108
}
116109

117110
if needsDriver {
@@ -122,21 +115,19 @@ func (s *setup) Run(ctx context.Context) (err error) {
122115
}
123116

124117
var dv string
125-
if !hasWSLGPU() {
118+
if needsDriver {
119+
if hasWSLGPU() {
120+
return errors.Errorf("NVIDIA drivers are required for WSL with non PCI-based GPUs")
121+
}
126122
version, err := readVersion()
127-
if err != nil && !needsDriver {
123+
if err != nil {
128124
return errors.Wrapf(err, "failed to read NVIDIA driver version")
129125
}
130-
if version == "" {
131-
version = defaultVersion
132-
}
133126
var ok bool
134127
dv, _, ok = strings.Cut(version, ".")
135128
if !ok {
136129
return errors.Errorf("failed to parse NVIDIA driver version %q", version)
137130
}
138-
} else if needsDriver {
139-
return errors.Errorf("NVIDIA drivers are required for WSL with non PCI-based GPUs")
140131
}
141132

142133
if err := run(ctx, []string{"apt-get", "update"}, pw, dgst); err != nil {
@@ -147,9 +138,58 @@ func (s *setup) Run(ctx context.Context) (err error) {
147138
return err
148139
}
149140

141+
if err := installPackages(ctx, dv, pw, dgst); err != nil {
142+
return err
143+
}
144+
145+
if err := os.MkdirAll("/etc/cdi", 0700); err != nil {
146+
return errors.Wrapf(err, "failed to create /etc/cdi")
147+
}
148+
149+
buf := &bytes.Buffer{}
150+
151+
cmd := exec.CommandContext(ctx, "nvidia-ctk", "cdi", "generate")
152+
cmd.Stdout = buf
153+
cmd.Stderr = newStream(pw, 2, dgst)
154+
if err := cmd.Run(); err != nil {
155+
return errors.Wrapf(err, "failed to generate CDI spec")
156+
}
157+
158+
if len(buf.Bytes()) == 0 {
159+
return errors.Errorf("nvidia-ctk output is empty")
160+
}
161+
162+
if err := os.WriteFile("/etc/cdi/nvidia.yaml", buf.Bytes(), 0644); err != nil {
163+
return errors.Wrapf(err, "failed to write /etc/cdi/nvidia.yaml")
164+
}
165+
166+
return nil
167+
}
168+
169+
func run(ctx context.Context, args []string, pw progress.Writer, dgst digest.Digest) error {
170+
fmt.Fprintf(newStream(pw, 2, dgst), "> %s\n", strings.Join(args, " "))
171+
cmd := exec.CommandContext(ctx, args[0], args[1:]...) //nolint:gosec
172+
cmd.Stderr = newStream(pw, 2, dgst)
173+
cmd.Stdout = newStream(pw, 1, dgst)
174+
return cmd.Run()
175+
}
176+
177+
func installPackages(ctx context.Context, dv string, pw progress.Writer, dgst digest.Digest) error {
150178
const aptDistro = "ubuntu2404"
151-
aptURL := "https://developer.download.nvidia.com/compute/cuda/repos/" + aptDistro + "/" + arch + "/"
152179

180+
var arch string
181+
switch runtime.GOARCH {
182+
case "amd64":
183+
arch = "x86_64"
184+
case "arm64":
185+
arch = "sbsa"
186+
// for non-sbsa could use https://nvidia.github.io/libnvidia-container/stable/deb
187+
}
188+
if arch == "" {
189+
return errors.Errorf("unsupported architecture: %s", runtime.GOARCH)
190+
}
191+
192+
aptURL := "https://developer.download.nvidia.com/compute/cuda/repos/" + aptDistro + "/" + arch + "/"
153193
keyTarget := "/usr/share/keyrings/nvidia-cuda-keyring.gpg"
154194

155195
if _, err := os.Stat(keyTarget); err != nil {
@@ -182,22 +222,7 @@ func (s *setup) Run(ctx context.Context) (err error) {
182222
return err
183223
}
184224

185-
if needsDriver && dv != "" {
186-
// this pretty much never works, is it even worth having?
187-
// better approach could be to try to create another chroot/container that is built with same kernel packages as the host
188-
// could nvidia-headless-no-dkms- be reusable
189-
if err := run(ctx, []string{"apt-get", "install", "-y", "nvidia-driver-" + dv}, pw, dgst); err != nil {
190-
return err
191-
}
192-
_, err := os.Stat("/proc/driver/nvidia")
193-
if err != nil {
194-
return errors.Wrapf(err, "failed to install NVIDIA kernel module. Please install NVIDIA drivers manually")
195-
}
196-
}
197-
198-
pkgs := []string{
199-
"nvidia-container-toolkit-base",
200-
}
225+
pkgs := []string{"nvidia-container-toolkit-base"}
201226
if dv != "" {
202227
pkgs = append(pkgs, []string{
203228
"libnvidia-compute-" + dv,
@@ -207,40 +232,7 @@ func (s *setup) Run(ctx context.Context) (err error) {
207232
}...)
208233
}
209234

210-
if err := run(ctx, append([]string{"apt-get", "install", "-y", "--no-install-recommends"}, pkgs...), pw, dgst); err != nil {
211-
return err
212-
}
213-
214-
if err := os.MkdirAll("/etc/cdi", 0700); err != nil {
215-
return errors.Wrapf(err, "failed to create /etc/cdi")
216-
}
217-
218-
buf := &bytes.Buffer{}
219-
220-
cmd := exec.CommandContext(ctx, "nvidia-ctk", "cdi", "generate")
221-
cmd.Stdout = buf
222-
cmd.Stderr = newStream(pw, 2, dgst)
223-
if err := cmd.Run(); err != nil {
224-
return errors.Wrapf(err, "failed to generate CDI spec")
225-
}
226-
227-
if len(buf.Bytes()) == 0 {
228-
return errors.Errorf("nvidia-ctk output is empty")
229-
}
230-
231-
if err := os.WriteFile("/etc/cdi/nvidia.yaml", buf.Bytes(), 0644); err != nil {
232-
return errors.Wrapf(err, "failed to write /etc/cdi/nvidia.yaml")
233-
}
234-
235-
return nil
236-
}
237-
238-
func run(ctx context.Context, args []string, pw progress.Writer, dgst digest.Digest) error {
239-
fmt.Fprintf(newStream(pw, 2, dgst), "> %s\n", strings.Join(args, " "))
240-
cmd := exec.CommandContext(ctx, args[0], args[1:]...) //nolint:gosec
241-
cmd.Stderr = newStream(pw, 2, dgst)
242-
cmd.Stdout = newStream(pw, 1, dgst)
243-
return cmd.Run()
235+
return run(ctx, append([]string{"apt-get", "install", "-y", "--no-install-recommends"}, pkgs...), pw, dgst)
244236
}
245237

246238
func readVersion() (string, error) {
@@ -326,3 +318,13 @@ func hasWSLGPU() bool {
326318
_, err := os.Stat("/dev/dxg")
327319
return err == nil
328320
}
321+
322+
func hasDriversInstalled() bool {
323+
// Check for libcuda in the standard locations to confirm NVIDIA GPU drivers
324+
for _, p := range libcudaGlobs {
325+
if matches, err := filepath.Glob(p); err == nil && len(matches) > 0 {
326+
return true
327+
}
328+
}
329+
return false
330+
}

0 commit comments

Comments
 (0)