@@ -25,10 +25,13 @@ import (
2525// This is an example of an experimental on-demand setup of CDI devices.
2626// This code is not currently shipping with BuildKit and will probably change.
2727
28- const (
29- cdiKind = "nvidia.com/gpu"
30- defaultVersion = "570.0"
31- )
// cdiKind is the CDI device kind under which init registers the setup handler.
const cdiKind = "nvidia.com/gpu"

// libcudaGlobs holds the glob patterns that hasDriversInstalled probes to find
// an existing libcuda shared library on the host.
//
// Reference for these locations:
// https://github.com/ollama/ollama/blob/b816ff86c923e0290f58f2275e831fc17c29ba37/discover/gpu_linux.go#L33-L43
var libcudaGlobs = []string{
	"/usr/lib/*-linux-gnu/libcuda.so*",
	"/usr/lib/wsl/drivers/*/libcuda.so*",
}
3235
3336func init () {
3437 cdidevices .Register (cdiKind , & setup {})
@@ -93,25 +96,15 @@ func (s *setup) Run(ctx context.Context) (err error) {
9396 }
9497
9598 var needsDriver bool
96- if nvidiaSmi , err := exec .LookPath ("nvidia-smi" ); err == nil && nvidiaSmi != "" {
97- if err := run (ctx , []string {nvidiaSmi , "-L" }, pw , dgst ); err != nil {
99+ if ! hasDriversInstalled () {
100+ // Check for nvidia-smi or /proc/driver/nvidia to determine if NVIDIA drivers are installed
101+ if nvidiaSmi , err := exec .LookPath ("nvidia-smi" ); err == nil && nvidiaSmi != "" {
102+ if err := run (ctx , []string {nvidiaSmi , "-L" }, pw , dgst ); err != nil {
103+ needsDriver = true
104+ }
105+ } else if _ , err := os .Stat ("/proc/driver/nvidia" ); err != nil {
98106 needsDriver = true
99107 }
100- } else if _ , err := os .Stat ("/proc/driver/nvidia" ); err != nil {
101- needsDriver = true
102- }
103-
104- var arch string
105- switch runtime .GOARCH {
106- case "amd64" :
107- arch = "x86_64"
108- case "arm64" :
109- arch = "sbsa"
110- // for non-sbsa could use https://nvidia.github.io/libnvidia-container/stable/deb
111- }
112-
113- if arch == "" {
114- return errors .Errorf ("unsupported architecture: %s" , runtime .GOARCH )
115108 }
116109
117110 if needsDriver {
@@ -122,21 +115,19 @@ func (s *setup) Run(ctx context.Context) (err error) {
122115 }
123116
124117 var dv string
125- if ! hasWSLGPU () {
118+ if needsDriver {
119+ if hasWSLGPU () {
120+ return errors .Errorf ("NVIDIA drivers are required for WSL with non PCI-based GPUs" )
121+ }
126122 version , err := readVersion ()
127- if err != nil && ! needsDriver {
123+ if err != nil {
128124 return errors .Wrapf (err , "failed to read NVIDIA driver version" )
129125 }
130- if version == "" {
131- version = defaultVersion
132- }
133126 var ok bool
134127 dv , _ , ok = strings .Cut (version , "." )
135128 if ! ok {
136129 return errors .Errorf ("failed to parse NVIDIA driver version %q" , version )
137130 }
138- } else if needsDriver {
139- return errors .Errorf ("NVIDIA drivers are required for WSL with non PCI-based GPUs" )
140131 }
141132
142133 if err := run (ctx , []string {"apt-get" , "update" }, pw , dgst ); err != nil {
@@ -147,9 +138,58 @@ func (s *setup) Run(ctx context.Context) (err error) {
147138 return err
148139 }
149140
141+ if err := installPackages (ctx , dv , pw , dgst ); err != nil {
142+ return err
143+ }
144+
145+ if err := os .MkdirAll ("/etc/cdi" , 0700 ); err != nil {
146+ return errors .Wrapf (err , "failed to create /etc/cdi" )
147+ }
148+
149+ buf := & bytes.Buffer {}
150+
151+ cmd := exec .CommandContext (ctx , "nvidia-ctk" , "cdi" , "generate" )
152+ cmd .Stdout = buf
153+ cmd .Stderr = newStream (pw , 2 , dgst )
154+ if err := cmd .Run (); err != nil {
155+ return errors .Wrapf (err , "failed to generate CDI spec" )
156+ }
157+
158+ if len (buf .Bytes ()) == 0 {
159+ return errors .Errorf ("nvidia-ctk output is empty" )
160+ }
161+
162+ if err := os .WriteFile ("/etc/cdi/nvidia.yaml" , buf .Bytes (), 0644 ); err != nil {
163+ return errors .Wrapf (err , "failed to write /etc/cdi/nvidia.yaml" )
164+ }
165+
166+ return nil
167+ }
168+
169+ func run (ctx context.Context , args []string , pw progress.Writer , dgst digest.Digest ) error {
170+ fmt .Fprintf (newStream (pw , 2 , dgst ), "> %s\n " , strings .Join (args , " " ))
171+ cmd := exec .CommandContext (ctx , args [0 ], args [1 :]... ) //nolint:gosec
172+ cmd .Stderr = newStream (pw , 2 , dgst )
173+ cmd .Stdout = newStream (pw , 1 , dgst )
174+ return cmd .Run ()
175+ }
176+
177+ func installPackages (ctx context.Context , dv string , pw progress.Writer , dgst digest.Digest ) error {
150178 const aptDistro = "ubuntu2404"
151- aptURL := "https://developer.download.nvidia.com/compute/cuda/repos/" + aptDistro + "/" + arch + "/"
152179
180+ var arch string
181+ switch runtime .GOARCH {
182+ case "amd64" :
183+ arch = "x86_64"
184+ case "arm64" :
185+ arch = "sbsa"
186+ // for non-sbsa could use https://nvidia.github.io/libnvidia-container/stable/deb
187+ }
188+ if arch == "" {
189+ return errors .Errorf ("unsupported architecture: %s" , runtime .GOARCH )
190+ }
191+
192+ aptURL := "https://developer.download.nvidia.com/compute/cuda/repos/" + aptDistro + "/" + arch + "/"
153193 keyTarget := "/usr/share/keyrings/nvidia-cuda-keyring.gpg"
154194
155195 if _ , err := os .Stat (keyTarget ); err != nil {
@@ -182,22 +222,7 @@ func (s *setup) Run(ctx context.Context) (err error) {
182222 return err
183223 }
184224
185- if needsDriver && dv != "" {
186- // this pretty much never works, is it even worth having?
187- // better approach could be to try to create another chroot/container that is built with same kernel packages as the host
188- // could nvidia-headless-no-dkms- be reusable
189- if err := run (ctx , []string {"apt-get" , "install" , "-y" , "nvidia-driver-" + dv }, pw , dgst ); err != nil {
190- return err
191- }
192- _ , err := os .Stat ("/proc/driver/nvidia" )
193- if err != nil {
194- return errors .Wrapf (err , "failed to install NVIDIA kernel module. Please install NVIDIA drivers manually" )
195- }
196- }
197-
198- pkgs := []string {
199- "nvidia-container-toolkit-base" ,
200- }
225+ pkgs := []string {"nvidia-container-toolkit-base" }
201226 if dv != "" {
202227 pkgs = append (pkgs , []string {
203228 "libnvidia-compute-" + dv ,
@@ -207,40 +232,7 @@ func (s *setup) Run(ctx context.Context) (err error) {
207232 }... )
208233 }
209234
210- if err := run (ctx , append ([]string {"apt-get" , "install" , "-y" , "--no-install-recommends" }, pkgs ... ), pw , dgst ); err != nil {
211- return err
212- }
213-
214- if err := os .MkdirAll ("/etc/cdi" , 0700 ); err != nil {
215- return errors .Wrapf (err , "failed to create /etc/cdi" )
216- }
217-
218- buf := & bytes.Buffer {}
219-
220- cmd := exec .CommandContext (ctx , "nvidia-ctk" , "cdi" , "generate" )
221- cmd .Stdout = buf
222- cmd .Stderr = newStream (pw , 2 , dgst )
223- if err := cmd .Run (); err != nil {
224- return errors .Wrapf (err , "failed to generate CDI spec" )
225- }
226-
227- if len (buf .Bytes ()) == 0 {
228- return errors .Errorf ("nvidia-ctk output is empty" )
229- }
230-
231- if err := os .WriteFile ("/etc/cdi/nvidia.yaml" , buf .Bytes (), 0644 ); err != nil {
232- return errors .Wrapf (err , "failed to write /etc/cdi/nvidia.yaml" )
233- }
234-
235- return nil
236- }
237-
238- func run (ctx context.Context , args []string , pw progress.Writer , dgst digest.Digest ) error {
239- fmt .Fprintf (newStream (pw , 2 , dgst ), "> %s\n " , strings .Join (args , " " ))
240- cmd := exec .CommandContext (ctx , args [0 ], args [1 :]... ) //nolint:gosec
241- cmd .Stderr = newStream (pw , 2 , dgst )
242- cmd .Stdout = newStream (pw , 1 , dgst )
243- return cmd .Run ()
235+ return run (ctx , append ([]string {"apt-get" , "install" , "-y" , "--no-install-recommends" }, pkgs ... ), pw , dgst )
244236}
245237
246238func readVersion () (string , error ) {
@@ -326,3 +318,13 @@ func hasWSLGPU() bool {
326318 _ , err := os .Stat ("/dev/dxg" )
327319 return err == nil
328320}
321+
322+ func hasDriversInstalled () bool {
323+ // Check for libcuda in the standard locations to confirm NVIDIA GPU drivers
324+ for _ , p := range libcudaGlobs {
325+ if matches , err := filepath .Glob (p ); err == nil && len (matches ) > 0 {
326+ return true
327+ }
328+ }
329+ return false
330+ }
0 commit comments