@@ -25,10 +25,13 @@ import (
2525// This is an example of experimental on-demand setup of CDI devices.
2626// This code is not currently shipping with BuildKit and will probably change.
2727
28- const (
29- cdiKind = "nvidia.com/gpu"
30- defaultVersion = "570.0"
31- )
// cdiKind is the device kind that init passes to cdidevices.Register.
28+ const cdiKind = "nvidia.com/gpu"
29+
// libcudaGlobs lists the glob patterns that hasDriversInstalled probes for
// libcuda; a match on any one of them is treated as "drivers are installed".
// The link below is the reference kept by the author for these locations.
30+ // https://github.com/ollama/ollama/blob/b816ff86c923e0290f58f2275e831fc17c29ba37/discover/gpu_linux.go#L33-L43
31+ var libcudaGlobs = []string {
32+ "/usr/lib/*-linux-gnu/libcuda.so*" ,
33+ "/usr/lib/wsl/drivers/*/libcuda.so*" ,
34+ }
3235
3336func init () {
3437 cdidevices .Register (cdiKind , & setup {})
@@ -93,25 +96,15 @@ func (s *setup) Run(ctx context.Context) (err error) {
9396 }
9497
9598 var needsDriver bool
96- if nvidiaSmi , err := exec .LookPath ("nvidia-smi" ); err == nil && nvidiaSmi != "" {
97- if err := run (ctx , []string {nvidiaSmi , "-L" }, pw , dgst ); err != nil {
99+ if ! hasDriversInstalled () {
100+ // Check for nvidia-smi or /proc/driver/nvidia to determine if NVIDIA drivers are installed
101+ if nvidiaSmi , err := exec .LookPath ("nvidia-smi" ); err == nil && nvidiaSmi != "" {
102+ if err := run (ctx , []string {nvidiaSmi , "-L" }, pw , dgst ); err != nil {
103+ needsDriver = true
104+ }
105+ } else if _ , err := os .Stat ("/proc/driver/nvidia" ); err != nil {
98106 needsDriver = true
99107 }
100- } else if _ , err := os .Stat ("/proc/driver/nvidia" ); err != nil {
101- needsDriver = true
102- }
103-
104- var arch string
105- switch runtime .GOARCH {
106- case "amd64" :
107- arch = "x86_64"
108- case "arm64" :
109- arch = "sbsa"
110- // for non-sbsa could use https://nvidia.github.io/libnvidia-container/stable/deb
111- }
112-
113- if arch == "" {
114- return errors .Errorf ("unsupported architecture: %s" , runtime .GOARCH )
115108 }
116109
117110 if needsDriver {
@@ -122,21 +115,19 @@ func (s *setup) Run(ctx context.Context) (err error) {
122115 }
123116
124117 var dv string
125- if ! hasWSLGPU () {
118+ if needsDriver {
119+ if hasWSLGPU () {
120+ return errors .Errorf ("NVIDIA drivers are required for WSL with non PCI-based GPUs" )
121+ }
126122 version , err := readVersion ()
127- if err != nil && ! needsDriver {
123+ if err != nil {
128124 return errors .Wrapf (err , "failed to read NVIDIA driver version" )
129125 }
130- if version == "" {
131- version = defaultVersion
132- }
133126 var ok bool
134127 dv , _ , ok = strings .Cut (version , "." )
135128 if ! ok {
136129 return errors .Errorf ("failed to parse NVIDIA driver version %q" , version )
137130 }
138- } else if needsDriver {
139- return errors .Errorf ("NVIDIA drivers are required for WSL with non PCI-based GPUs" )
140131 }
141132
142133 if err := run (ctx , []string {"apt-get" , "update" }, pw , dgst ); err != nil {
@@ -147,100 +138,138 @@ func (s *setup) Run(ctx context.Context) (err error) {
147138 return err
148139 }
149140
141+ if needsDriver {
142+ if err := installDrivers (ctx , dv , pw , dgst ); err != nil {
143+ return err
144+ }
145+ }
146+ if err := installContainerToolkit (ctx , pw , dgst ); err != nil {
147+ return err
148+ }
149+
150+ if err := os .MkdirAll ("/etc/cdi" , 0700 ); err != nil {
151+ return errors .Wrapf (err , "failed to create /etc/cdi" )
152+ }
153+
154+ buf := & bytes.Buffer {}
155+
156+ cmd := exec .CommandContext (ctx , "nvidia-ctk" , "cdi" , "generate" )
157+ cmd .Stdout = buf
158+ cmd .Stderr = newStream (pw , 2 , dgst )
159+ if err := cmd .Run (); err != nil {
160+ return errors .Wrapf (err , "failed to generate CDI spec" )
161+ }
162+
163+ if len (buf .Bytes ()) == 0 {
164+ return errors .Errorf ("nvidia-ctk output is empty" )
165+ }
166+
167+ if err := os .WriteFile ("/etc/cdi/nvidia.yaml" , buf .Bytes (), 0644 ); err != nil {
168+ return errors .Wrapf (err , "failed to write /etc/cdi/nvidia.yaml" )
169+ }
170+
171+ return nil
172+ }
173+
// run executes the command args[0] with arguments args[1:] under ctx.
//
// Before starting it, the full command line is echoed (prefixed with "> ") to
// the stream returned by newStream(pw, 2, dgst). The command's stderr is sent
// to a stream created with id 2 and its stdout to a stream created with id 1
// (presumably the stderr/stdout progress streams; newStream is defined
// elsewhere). The result of cmd.Run is returned unwrapped.
//
// args must be non-empty: args[0] is indexed without a length check.
174+ func run (ctx context.Context , args []string , pw progress.Writer , dgst digest.Digest ) error {
175+ fmt .Fprintf (newStream (pw , 2 , dgst ), "> %s\n " , strings .Join (args , " " ))
176+ cmd := exec .CommandContext (ctx , args [0 ], args [1 :]... ) //nolint:gosec
177+ cmd .Stderr = newStream (pw , 2 , dgst )
178+ cmd .Stdout = newStream (pw , 1 , dgst )
179+ return cmd .Run ()
180+ }
181+
// installDrivers adds NVIDIA's CUDA apt repository (for the hard-coded distro
// "ubuntu2404") and installs the userspace driver packages for driver branch dv.
//
// dv is appended verbatim to the package names libnvidia-compute-,
// libnvidia-extra-, libnvidia-gl- and nvidia-utils-; the caller in Run derives
// it from the text before the first "." of the driver version.
//
// Steps: map runtime.GOARCH to the repository architecture directory, fetch and
// dearmor the repository signing key if the keyring file is missing, write the
// apt source list, run apt-get update, then apt-get install. Returns an error
// for any GOARCH other than amd64/arm64 or when any step fails.
182+ func installDrivers (ctx context.Context , dv string , pw progress.Writer , dgst digest.Digest ) error {
150183 const aptDistro = "ubuntu2404"
151- aptURL := "https://developer.download.nvidia.com/compute/cuda/repos/" + aptDistro + "/" + arch + "/"
152184
// Translate the Go architecture name into the directory name used by the
// CUDA repository; arch stays empty for anything unsupported.
185+ var arch string
186+ switch runtime .GOARCH {
187+ case "amd64" :
188+ arch = "x86_64"
189+ case "arm64" :
190+ arch = "sbsa"
191+ // for non-sbsa could use https://nvidia.github.io/libnvidia-container/stable/deb
192+ }
193+ if arch == "" {
194+ return errors .Errorf ("unsupported architecture: %s" , runtime .GOARCH )
195+ }
196+
197+ aptURL := "https://developer.download.nvidia.com/compute/cuda/repos/" + aptDistro + "/" + arch + "/"
153198 keyTarget := "/usr/share/keyrings/nvidia-cuda-keyring.gpg"
154199
// Any os.Stat error (not only "does not exist") triggers a fresh download of
// the key; an existing keyring file is reused as-is.
155200 if _ , err := os .Stat (keyTarget ); err != nil {
156- fmt .Fprintf (newStream (pw , 2 , dgst ), "Downloading NVIDIA GPG key\n " )
201+ fmt .Fprintf (newStream (pw , 2 , dgst ), "Downloading NVIDIA cuda GPG key\n " )
157202
158203 req , err := http .NewRequestWithContext (ctx , http .MethodGet , aptURL + "3bf863cc.pub" , nil )
159204 if err != nil {
160- return errors .Wrapf (err , "failed to create request for NVIDIA GPG key" )
205+ return errors .Wrapf (err , "failed to create request for NVIDIA cuda GPG key" )
161206 }
162207
163208 resp , err := http .DefaultClient .Do (req )
164209 if err != nil {
165- return errors .Wrapf (err , "failed to download NVIDIA GPG key" )
210+ return errors .Wrapf (err , "failed to download NVIDIA cuda GPG key" )
166211 }
167212
// The response body is streamed straight into gpg's stdin, which writes the
// dearmored key to keyTarget.
// NOTE(review): resp.StatusCode is never checked, so a non-200 error page
// would be fed to gpg; and resp.Body is only closed after gpg succeeds, so
// the early return below leaves it open. Consider a status check plus
// `defer resp.Body.Close()` right after Do.
168213 cmd := exec .CommandContext (ctx , "gpg" , "--dearmor" , "-o" , keyTarget )
169214 cmd .Stdin = resp .Body
170215 cmd .Stderr = newStream (pw , 2 , dgst )
171216 if err := cmd .Run (); err != nil {
172- return errors .Wrapf (err , "failed to install NVIDIA GPG key" )
217+ return errors .Wrapf (err , "failed to install NVIDIA cuda GPG key" )
173218 }
174219 resp .Body .Close ()
175220 }
176221
// Register the repository as a flat apt source pinned to the keyring above.
177222 if err := os .WriteFile ("/etc/apt/sources.list.d/nvidia-cuda.list" , []byte ("deb [signed-by=" + keyTarget + "] " + aptURL + " /" ), 0644 ); err != nil {
178- return errors .Wrapf (err , "failed to add NVIDIA apt repo" )
223+ return errors .Wrapf (err , "failed to add NVIDIA cuda apt repo" )
179224 }
180225
181226 if err := run (ctx , []string {"apt-get" , "update" }, pw , dgst ); err != nil {
182227 return err
183228 }
184229
185- if needsDriver && dv != "" {
186- // this pretty much never works, is it even worth having?
187- // better approach could be to try to create another chroot/container that is built with same kernel packages as the host
188- // could nvidia-headless-no-dkms- be reusable
189- if err := run (ctx , []string {"apt-get" , "install" , "-y" , "nvidia-driver-" + dv }, pw , dgst ); err != nil {
190- return err
191- }
192- _ , err := os .Stat ("/proc/driver/nvidia" )
193- if err != nil {
194- return errors .Wrapf (err , "failed to install NVIDIA kernel module. Please install NVIDIA drivers manually" )
195- }
196- }
// Only the userspace libraries and utilities for branch dv are installed here;
// no nvidia-driver-<dv> package is requested by this function.
230+ return run (ctx , []string {"apt-get" , "install" , "-y" , "--no-install-recommends" ,
231+ "libnvidia-compute-" + dv ,
232+ "libnvidia-extra-" + dv ,
233+ "libnvidia-gl-" + dv ,
234+ "nvidia-utils-" + dv ,
235+ }, pw , dgst )
236+ }
197237
198- pkgs := []string {
199- "nvidia-container-toolkit-base" ,
200- }
201- if dv != "" {
202- pkgs = append (pkgs , []string {
203- "libnvidia-compute-" + dv ,
204- "libnvidia-extra-" + dv ,
205- "libnvidia-gl-" + dv ,
206- "nvidia-utils-" + dv ,
207- }... )
208- }
// installContainerToolkit adds the libnvidia-container apt repository and
// installs the nvidia-container-toolkit-base package.
//
// Steps: fetch and dearmor the repository signing key if the keyring file is
// missing, write the apt source list, run apt-get update, then apt-get install.
// The first failing step's error is returned; file and key errors are wrapped
// with context, errors from run are returned as-is.
238+ func installContainerToolkit (ctx context.Context , pw progress.Writer , dgst digest.Digest ) error {
239+ aptURL := "https://nvidia.github.io/libnvidia-container/"
240+ keyTarget := "/usr/share/keyrings/nvidia-container-toolkit-keyring.gpg"
209241
210- if err := run (ctx , append ([]string {"apt-get" , "install" , "-y" , "--no-install-recommends" }, pkgs ... ), pw , dgst ); err != nil {
211- return err
212- }
// Any os.Stat error (not only "does not exist") triggers a fresh download of
// the key; an existing keyring file is reused as-is.
242+ if _ , err := os .Stat (keyTarget ); err != nil {
243+ fmt .Fprintf (newStream (pw , 2 , dgst ), "Downloading NVIDIA container toolkit GPG key\n " )
213244
214- if err := os .MkdirAll ("/etc/cdi" , 0700 ); err != nil {
215- return errors .Wrapf (err , "failed to create /etc/cdi" )
216- }
245+ req , err := http .NewRequestWithContext (ctx , http .MethodGet , aptURL + "gpgkey" , nil )
246+ if err != nil {
247+ return errors .Wrapf (err , "failed to create request for NVIDIA container toolkit GPG key" )
248+ }
217249
218- buf := & bytes.Buffer {}
250+ resp , err := http .DefaultClient .Do (req )
251+ if err != nil {
252+ return errors .Wrapf (err , "failed to download NVIDIA container toolkit GPG key" )
253+ }
219254
220- cmd := exec .CommandContext (ctx , "nvidia-ctk" , "cdi" , "generate" )
221- cmd .Stdout = buf
222- cmd .Stderr = newStream (pw , 2 , dgst )
223- if err := cmd .Run (); err != nil {
224- return errors .Wrapf (err , "failed to generate CDI spec" )
// The response body is streamed straight into gpg's stdin, which writes the
// dearmored key to keyTarget.
// NOTE(review): resp.StatusCode is never checked, so a non-200 error page
// would be fed to gpg; and resp.Body is only closed after gpg succeeds, so
// the early return below leaves it open. Consider a status check plus
// `defer resp.Body.Close()` right after Do.
255+ cmd := exec .CommandContext (ctx , "gpg" , "--dearmor" , "-o" , keyTarget )
256+ cmd .Stdin = resp .Body
257+ cmd .Stderr = newStream (pw , 2 , dgst )
258+ if err := cmd .Run (); err != nil {
259+ return errors .Wrapf (err , "failed to install NVIDIA container toolkit GPG key" )
260+ }
261+ resp .Body .Close ()
225262 }
226263
227- if len ( buf . Bytes ()) == 0 {
228- return errors .Errorf ( "nvidia-ctk output is empty " )
// "$(ARCH)" is written to the list file verbatim, not expanded by Go;
// presumably apt substitutes the system architecture when it reads the entry
// (see sources.list(5)) — confirm.
264+ if err := os . WriteFile ( "/etc/apt/sources.list.d/nvidia-container-toolkit.list" , [] byte ( "deb [signed-by=" + keyTarget + "] " + aptURL + "stable/deb/$(ARCH) /" ), 0644 ); err != nil {
265+ return errors .Wrapf ( err , "failed to add NVIDIA container toolkit apt repo " )
229266 }
230267
231- if err := os . WriteFile ( "/etc/cdi/nvidia.yaml " , buf . Bytes (), 0644 ); err != nil {
232- return errors . Wrapf ( err , "failed to write /etc/cdi/nvidia.yaml" )
// Refresh the package index so the newly added repository is visible to the
// install below.
268+ if err := run ( ctx , [] string { "apt-get " , "update" }, pw , dgst ); err != nil {
269+ return err
233270 }
234271
235- return nil
236- }
237-
238- func run (ctx context.Context , args []string , pw progress.Writer , dgst digest.Digest ) error {
239- fmt .Fprintf (newStream (pw , 2 , dgst ), "> %s\n " , strings .Join (args , " " ))
240- cmd := exec .CommandContext (ctx , args [0 ], args [1 :]... ) //nolint:gosec
241- cmd .Stderr = newStream (pw , 2 , dgst )
242- cmd .Stdout = newStream (pw , 1 , dgst )
243- return cmd .Run ()
272+ return run (ctx , []string {"apt-get" , "install" , "-y" , "nvidia-container-toolkit-base" }, pw , dgst )
244273}
245274
246275func readVersion () (string , error ) {
@@ -326,3 +355,13 @@ func hasWSLGPU() bool {
326355 _ , err := os .Stat ("/dev/dxg" )
327356 return err == nil
328357}
358+
// hasDriversInstalled reports whether at least one pattern in libcudaGlobs
// matches an existing path. It returns true on the first pattern with a match
// and false when none match; a pattern for which filepath.Glob returns an
// error is skipped rather than reported.
359+ func hasDriversInstalled () bool {
360+ // Check for libcuda in the standard locations to confirm NVIDIA GPU drivers
361+ for _ , p := range libcudaGlobs {
362+ if matches , err := filepath .Glob (p ); err == nil && len (matches ) > 0 {
363+ return true
364+ }
365+ }
366+ return false
367+ }
0 commit comments