Skip to content

Commit da31bd4

Browse files
committed
contrib: check if nvidia drivers are already installed
Signed-off-by: CrazyMax <[email protected]>
1 parent 7f1278d commit da31bd4

File tree

1 file changed

+121
-82
lines changed

1 file changed

+121
-82
lines changed

contrib/cdisetup/nvidia/nvidia.go

Lines changed: 121 additions & 82 deletions
Original file line numberDiff line numberDiff line change
@@ -25,10 +25,13 @@ import (
2525
// This is example of experimental on-demand setup of a CDI devices.
2626
// This code is not currently shipping with BuildKit and will probably change.
2727

28-
const (
29-
cdiKind = "nvidia.com/gpu"
30-
defaultVersion = "570.0"
31-
)
28+
const cdiKind = "nvidia.com/gpu"
29+
30+
// https://github.com/ollama/ollama/blob/b816ff86c923e0290f58f2275e831fc17c29ba37/discover/gpu_linux.go#L33-L43
31+
var libcudaGlobs = []string{
32+
"/usr/lib/*-linux-gnu/libcuda.so*",
33+
"/usr/lib/wsl/drivers/*/libcuda.so*",
34+
}
3235

3336
func init() {
3437
cdidevices.Register(cdiKind, &setup{})
@@ -93,25 +96,15 @@ func (s *setup) Run(ctx context.Context) (err error) {
9396
}
9497

9598
var needsDriver bool
96-
if nvidiaSmi, err := exec.LookPath("nvidia-smi"); err == nil && nvidiaSmi != "" {
97-
if err := run(ctx, []string{nvidiaSmi, "-L"}, pw, dgst); err != nil {
99+
if !hasDriversInstalled() {
100+
// Check for nvidia-smi or /proc/driver/nvidia to determine if NVIDIA drivers are installed
101+
if nvidiaSmi, err := exec.LookPath("nvidia-smi"); err == nil && nvidiaSmi != "" {
102+
if err := run(ctx, []string{nvidiaSmi, "-L"}, pw, dgst); err != nil {
103+
needsDriver = true
104+
}
105+
} else if _, err := os.Stat("/proc/driver/nvidia"); err != nil {
98106
needsDriver = true
99107
}
100-
} else if _, err := os.Stat("/proc/driver/nvidia"); err != nil {
101-
needsDriver = true
102-
}
103-
104-
var arch string
105-
switch runtime.GOARCH {
106-
case "amd64":
107-
arch = "x86_64"
108-
case "arm64":
109-
arch = "sbsa"
110-
// for non-sbsa could use https://nvidia.github.io/libnvidia-container/stable/deb
111-
}
112-
113-
if arch == "" {
114-
return errors.Errorf("unsupported architecture: %s", runtime.GOARCH)
115108
}
116109

117110
if needsDriver {
@@ -122,21 +115,19 @@ func (s *setup) Run(ctx context.Context) (err error) {
122115
}
123116

124117
var dv string
125-
if !hasWSLGPU() {
118+
if needsDriver {
119+
if hasWSLGPU() {
120+
return errors.Errorf("NVIDIA drivers are required for WSL with non PCI-based GPUs")
121+
}
126122
version, err := readVersion()
127-
if err != nil && !needsDriver {
123+
if err != nil {
128124
return errors.Wrapf(err, "failed to read NVIDIA driver version")
129125
}
130-
if version == "" {
131-
version = defaultVersion
132-
}
133126
var ok bool
134127
dv, _, ok = strings.Cut(version, ".")
135128
if !ok {
136129
return errors.Errorf("failed to parse NVIDIA driver version %q", version)
137130
}
138-
} else if needsDriver {
139-
return errors.Errorf("NVIDIA drivers are required for WSL with non PCI-based GPUs")
140131
}
141132

142133
if err := run(ctx, []string{"apt-get", "update"}, pw, dgst); err != nil {
@@ -147,100 +138,138 @@ func (s *setup) Run(ctx context.Context) (err error) {
147138
return err
148139
}
149140

141+
if needsDriver {
142+
if err := installDrivers(ctx, dv, pw, dgst); err != nil {
143+
return err
144+
}
145+
}
146+
if err := installContainerToolkit(ctx, pw, dgst); err != nil {
147+
return err
148+
}
149+
150+
if err := os.MkdirAll("/etc/cdi", 0700); err != nil {
151+
return errors.Wrapf(err, "failed to create /etc/cdi")
152+
}
153+
154+
buf := &bytes.Buffer{}
155+
156+
cmd := exec.CommandContext(ctx, "nvidia-ctk", "cdi", "generate")
157+
cmd.Stdout = buf
158+
cmd.Stderr = newStream(pw, 2, dgst)
159+
if err := cmd.Run(); err != nil {
160+
return errors.Wrapf(err, "failed to generate CDI spec")
161+
}
162+
163+
if len(buf.Bytes()) == 0 {
164+
return errors.Errorf("nvidia-ctk output is empty")
165+
}
166+
167+
if err := os.WriteFile("/etc/cdi/nvidia.yaml", buf.Bytes(), 0644); err != nil {
168+
return errors.Wrapf(err, "failed to write /etc/cdi/nvidia.yaml")
169+
}
170+
171+
return nil
172+
}
173+
174+
func run(ctx context.Context, args []string, pw progress.Writer, dgst digest.Digest) error {
175+
fmt.Fprintf(newStream(pw, 2, dgst), "> %s\n", strings.Join(args, " "))
176+
cmd := exec.CommandContext(ctx, args[0], args[1:]...) //nolint:gosec
177+
cmd.Stderr = newStream(pw, 2, dgst)
178+
cmd.Stdout = newStream(pw, 1, dgst)
179+
return cmd.Run()
180+
}
181+
182+
func installDrivers(ctx context.Context, dv string, pw progress.Writer, dgst digest.Digest) error {
150183
const aptDistro = "ubuntu2404"
151-
aptURL := "https://developer.download.nvidia.com/compute/cuda/repos/" + aptDistro + "/" + arch + "/"
152184

185+
var arch string
186+
switch runtime.GOARCH {
187+
case "amd64":
188+
arch = "x86_64"
189+
case "arm64":
190+
arch = "sbsa"
191+
// for non-sbsa could use https://nvidia.github.io/libnvidia-container/stable/deb
192+
}
193+
if arch == "" {
194+
return errors.Errorf("unsupported architecture: %s", runtime.GOARCH)
195+
}
196+
197+
aptURL := "https://developer.download.nvidia.com/compute/cuda/repos/" + aptDistro + "/" + arch + "/"
153198
keyTarget := "/usr/share/keyrings/nvidia-cuda-keyring.gpg"
154199

155200
if _, err := os.Stat(keyTarget); err != nil {
156-
fmt.Fprintf(newStream(pw, 2, dgst), "Downloading NVIDIA GPG key\n")
201+
fmt.Fprintf(newStream(pw, 2, dgst), "Downloading NVIDIA cuda GPG key\n")
157202

158203
req, err := http.NewRequestWithContext(ctx, http.MethodGet, aptURL+"3bf863cc.pub", nil)
159204
if err != nil {
160-
return errors.Wrapf(err, "failed to create request for NVIDIA GPG key")
205+
return errors.Wrapf(err, "failed to create request for NVIDIA cuda GPG key")
161206
}
162207

163208
resp, err := http.DefaultClient.Do(req)
164209
if err != nil {
165-
return errors.Wrapf(err, "failed to download NVIDIA GPG key")
210+
return errors.Wrapf(err, "failed to download NVIDIA cuda GPG key")
166211
}
167212

168213
cmd := exec.CommandContext(ctx, "gpg", "--dearmor", "-o", keyTarget)
169214
cmd.Stdin = resp.Body
170215
cmd.Stderr = newStream(pw, 2, dgst)
171216
if err := cmd.Run(); err != nil {
172-
return errors.Wrapf(err, "failed to install NVIDIA GPG key")
217+
return errors.Wrapf(err, "failed to install NVIDIA cuda GPG key")
173218
}
174219
resp.Body.Close()
175220
}
176221

177222
if err := os.WriteFile("/etc/apt/sources.list.d/nvidia-cuda.list", []byte("deb [signed-by="+keyTarget+"] "+aptURL+" /"), 0644); err != nil {
178-
return errors.Wrapf(err, "failed to add NVIDIA apt repo")
223+
return errors.Wrapf(err, "failed to add NVIDIA cuda apt repo")
179224
}
180225

181226
if err := run(ctx, []string{"apt-get", "update"}, pw, dgst); err != nil {
182227
return err
183228
}
184229

185-
if needsDriver && dv != "" {
186-
// this pretty much never works, is it even worth having?
187-
// better approach could be to try to create another chroot/container that is built with same kernel packages as the host
188-
// could nvidia-headless-no-dkms- be reusable
189-
if err := run(ctx, []string{"apt-get", "install", "-y", "nvidia-driver-" + dv}, pw, dgst); err != nil {
190-
return err
191-
}
192-
_, err := os.Stat("/proc/driver/nvidia")
193-
if err != nil {
194-
return errors.Wrapf(err, "failed to install NVIDIA kernel module. Please install NVIDIA drivers manually")
195-
}
196-
}
230+
return run(ctx, []string{"apt-get", "install", "-y", "--no-install-recommends",
231+
"libnvidia-compute-" + dv,
232+
"libnvidia-extra-" + dv,
233+
"libnvidia-gl-" + dv,
234+
"nvidia-utils-" + dv,
235+
}, pw, dgst)
236+
}
197237

198-
pkgs := []string{
199-
"nvidia-container-toolkit-base",
200-
}
201-
if dv != "" {
202-
pkgs = append(pkgs, []string{
203-
"libnvidia-compute-" + dv,
204-
"libnvidia-extra-" + dv,
205-
"libnvidia-gl-" + dv,
206-
"nvidia-utils-" + dv,
207-
}...)
208-
}
238+
func installContainerToolkit(ctx context.Context, pw progress.Writer, dgst digest.Digest) error {
239+
aptURL := "https://nvidia.github.io/libnvidia-container/"
240+
keyTarget := "/usr/share/keyrings/nvidia-container-toolkit-keyring.gpg"
209241

210-
if err := run(ctx, append([]string{"apt-get", "install", "-y", "--no-install-recommends"}, pkgs...), pw, dgst); err != nil {
211-
return err
212-
}
242+
if _, err := os.Stat(keyTarget); err != nil {
243+
fmt.Fprintf(newStream(pw, 2, dgst), "Downloading NVIDIA container toolkit GPG key\n")
213244

214-
if err := os.MkdirAll("/etc/cdi", 0700); err != nil {
215-
return errors.Wrapf(err, "failed to create /etc/cdi")
216-
}
245+
req, err := http.NewRequestWithContext(ctx, http.MethodGet, aptURL+"gpgkey", nil)
246+
if err != nil {
247+
return errors.Wrapf(err, "failed to create request for NVIDIA container toolkit GPG key")
248+
}
217249

218-
buf := &bytes.Buffer{}
250+
resp, err := http.DefaultClient.Do(req)
251+
if err != nil {
252+
return errors.Wrapf(err, "failed to download NVIDIA container toolkit GPG key")
253+
}
219254

220-
cmd := exec.CommandContext(ctx, "nvidia-ctk", "cdi", "generate")
221-
cmd.Stdout = buf
222-
cmd.Stderr = newStream(pw, 2, dgst)
223-
if err := cmd.Run(); err != nil {
224-
return errors.Wrapf(err, "failed to generate CDI spec")
255+
cmd := exec.CommandContext(ctx, "gpg", "--dearmor", "-o", keyTarget)
256+
cmd.Stdin = resp.Body
257+
cmd.Stderr = newStream(pw, 2, dgst)
258+
if err := cmd.Run(); err != nil {
259+
return errors.Wrapf(err, "failed to install NVIDIA container toolkit GPG key")
260+
}
261+
resp.Body.Close()
225262
}
226263

227-
if len(buf.Bytes()) == 0 {
228-
return errors.Errorf("nvidia-ctk output is empty")
264+
if err := os.WriteFile("/etc/apt/sources.list.d/nvidia-container-toolkit.list", []byte("deb [signed-by="+keyTarget+"] "+aptURL+"stable/deb/$(ARCH) /"), 0644); err != nil {
265+
return errors.Wrapf(err, "failed to add NVIDIA container toolkit apt repo")
229266
}
230267

231-
if err := os.WriteFile("/etc/cdi/nvidia.yaml", buf.Bytes(), 0644); err != nil {
232-
return errors.Wrapf(err, "failed to write /etc/cdi/nvidia.yaml")
268+
if err := run(ctx, []string{"apt-get", "update"}, pw, dgst); err != nil {
269+
return err
233270
}
234271

235-
return nil
236-
}
237-
238-
func run(ctx context.Context, args []string, pw progress.Writer, dgst digest.Digest) error {
239-
fmt.Fprintf(newStream(pw, 2, dgst), "> %s\n", strings.Join(args, " "))
240-
cmd := exec.CommandContext(ctx, args[0], args[1:]...) //nolint:gosec
241-
cmd.Stderr = newStream(pw, 2, dgst)
242-
cmd.Stdout = newStream(pw, 1, dgst)
243-
return cmd.Run()
272+
return run(ctx, []string{"apt-get", "install", "-y", "nvidia-container-toolkit-base"}, pw, dgst)
244273
}
245274

246275
func readVersion() (string, error) {
@@ -326,3 +355,13 @@ func hasWSLGPU() bool {
326355
_, err := os.Stat("/dev/dxg")
327356
return err == nil
328357
}
358+
359+
func hasDriversInstalled() bool {
360+
// Check for libcuda in the standard locations to confirm NVIDIA GPU drivers
361+
for _, p := range libcudaGlobs {
362+
if matches, err := filepath.Glob(p); err == nil && len(matches) > 0 {
363+
return true
364+
}
365+
}
366+
return false
367+
}

0 commit comments

Comments
 (0)