@@ -25,10 +25,13 @@ import (
2525// This is example of experimental on-demand setup of a CDI devices.
2626// This code is not currently shipping with BuildKit and will probably change.
2727
28- const (
29- cdiKind = "nvidia.com/gpu"
30- defaultVersion = "570.0"
31- )
28+ const cdiKind = "nvidia.com/gpu"
29+
30+ // https://github.com/ollama/ollama/blob/b816ff86c923e0290f58f2275e831fc17c29ba37/discover/gpu_linux.go#L33-L43
31+ var libcudaGlobs = []string {
32+ "/usr/lib/*-linux-gnu/libcuda.so*" ,
33+ "/usr/lib/wsl/drivers/*/libcuda.so*" ,
34+ }
3235
3336func init () {
3437 cdidevices .Register (cdiKind , & setup {})
@@ -92,51 +95,32 @@ func (s *setup) Run(ctx context.Context) (err error) {
9295 return errors .Errorf ("NVIDIA setup is currently only supported on Debian/Ubuntu" )
9396 }
9497
95- var needsDriver bool
96- if nvidiaSmi , err := exec .LookPath ("nvidia-smi" ); err == nil && nvidiaSmi != "" {
97- if err := run (ctx , []string {nvidiaSmi , "-L" }, pw , dgst ); err != nil {
98- needsDriver = true
98+ needsDriver := true
99+ if _ , err := os .Stat ("/proc/driver/nvidia" ); err == nil {
100+ needsDriver = false
101+ } else if nvidiaSmi , err := exec .LookPath ("nvidia-smi" ); err == nil && nvidiaSmi != "" {
102+ if err := run (ctx , []string {nvidiaSmi , "-L" }, pw , dgst ); err == nil {
103+ needsDriver = false
99104 }
100- } else if _ , err := os .Stat ("/proc/driver/nvidia" ); err != nil {
101- needsDriver = true
102- }
103-
104- var arch string
105- switch runtime .GOARCH {
106- case "amd64" :
107- arch = "x86_64"
108- case "arm64" :
109- arch = "sbsa"
110- // for non-sbsa could use https://nvidia.github.io/libnvidia-container/stable/deb
111- }
112-
113- if arch == "" {
114- return errors .Errorf ("unsupported architecture: %s" , runtime .GOARCH )
115105 }
116-
117106 if needsDriver {
118- pw . Write ( identity . NewID (), client. VertexWarning {
119- Vertex : dgst ,
120- Short : [] byte ( "NVIDIA Drivers not found. Installing prebuilt drivers is not recommended" ),
121- } )
107+ if hasWSLGPU () {
108+ return errors . Errorf ( "NVIDIA drivers are required for WSL with non PCI-based GPUs" )
109+ }
110+ return errors . Errorf ( "NVIDIA drivers are required. Try loading NVIDIA kernel module with \" modprobe nvidia \" command" )
122111 }
123112
124113 var dv string
125- if ! hasWSLGPU () {
114+ if ! hasLibsInstalled () && ! hasWSLGPU () {
126115 version , err := readVersion ()
127- if err != nil && ! needsDriver {
116+ if err != nil {
128117 return errors .Wrapf (err , "failed to read NVIDIA driver version" )
129118 }
130- if version == "" {
131- version = defaultVersion
132- }
133119 var ok bool
134120 dv , _ , ok = strings .Cut (version , "." )
135121 if ! ok {
136122 return errors .Errorf ("failed to parse NVIDIA driver version %q" , version )
137123 }
138- } else if needsDriver {
139- return errors .Errorf ("NVIDIA drivers are required for WSL with non PCI-based GPUs" )
140124 }
141125
142126 if err := run (ctx , []string {"apt-get" , "update" }, pw , dgst ); err != nil {
@@ -147,9 +131,58 @@ func (s *setup) Run(ctx context.Context) (err error) {
147131 return err
148132 }
149133
134+ if err := installPackages (ctx , dv , pw , dgst ); err != nil {
135+ return err
136+ }
137+
138+ if err := os .MkdirAll ("/etc/cdi" , 0700 ); err != nil {
139+ return errors .Wrapf (err , "failed to create /etc/cdi" )
140+ }
141+
142+ buf := & bytes.Buffer {}
143+
144+ cmd := exec .CommandContext (ctx , "nvidia-ctk" , "cdi" , "generate" )
145+ cmd .Stdout = buf
146+ cmd .Stderr = newStream (pw , 2 , dgst )
147+ if err := cmd .Run (); err != nil {
148+ return errors .Wrapf (err , "failed to generate CDI spec" )
149+ }
150+
151+ if len (buf .Bytes ()) == 0 {
152+ return errors .Errorf ("nvidia-ctk output is empty" )
153+ }
154+
155+ if err := os .WriteFile ("/etc/cdi/nvidia.yaml" , buf .Bytes (), 0644 ); err != nil {
156+ return errors .Wrapf (err , "failed to write /etc/cdi/nvidia.yaml" )
157+ }
158+
159+ return nil
160+ }
161+
162+ func run (ctx context.Context , args []string , pw progress.Writer , dgst digest.Digest ) error {
163+ fmt .Fprintf (newStream (pw , 2 , dgst ), "> %s\n " , strings .Join (args , " " ))
164+ cmd := exec .CommandContext (ctx , args [0 ], args [1 :]... ) //nolint:gosec
165+ cmd .Stderr = newStream (pw , 2 , dgst )
166+ cmd .Stdout = newStream (pw , 1 , dgst )
167+ return cmd .Run ()
168+ }
169+
170+ func installPackages (ctx context.Context , dv string , pw progress.Writer , dgst digest.Digest ) error {
150171 const aptDistro = "ubuntu2404"
151- aptURL := "https://developer.download.nvidia.com/compute/cuda/repos/" + aptDistro + "/" + arch + "/"
152172
173+ var arch string
174+ switch runtime .GOARCH {
175+ case "amd64" :
176+ arch = "x86_64"
177+ case "arm64" :
178+ arch = "sbsa"
179+ // for non-sbsa could use https://nvidia.github.io/libnvidia-container/stable/deb
180+ }
181+ if arch == "" {
182+ return errors .Errorf ("unsupported architecture: %s" , runtime .GOARCH )
183+ }
184+
185+ aptURL := "https://developer.download.nvidia.com/compute/cuda/repos/" + aptDistro + "/" + arch + "/"
153186 keyTarget := "/usr/share/keyrings/nvidia-cuda-keyring.gpg"
154187
155188 if _ , err := os .Stat (keyTarget ); err != nil {
@@ -182,22 +215,7 @@ func (s *setup) Run(ctx context.Context) (err error) {
182215 return err
183216 }
184217
185- if needsDriver && dv != "" {
186- // this pretty much never works, is it even worth having?
187- // better approach could be to try to create another chroot/container that is built with same kernel packages as the host
188- // could nvidia-headless-no-dkms- be reusable
189- if err := run (ctx , []string {"apt-get" , "install" , "-y" , "nvidia-driver-" + dv }, pw , dgst ); err != nil {
190- return err
191- }
192- _ , err := os .Stat ("/proc/driver/nvidia" )
193- if err != nil {
194- return errors .Wrapf (err , "failed to install NVIDIA kernel module. Please install NVIDIA drivers manually" )
195- }
196- }
197-
198- pkgs := []string {
199- "nvidia-container-toolkit-base" ,
200- }
218+ pkgs := []string {"nvidia-container-toolkit-base" }
201219 if dv != "" {
202220 pkgs = append (pkgs , []string {
203221 "libnvidia-compute-" + dv ,
@@ -207,40 +225,7 @@ func (s *setup) Run(ctx context.Context) (err error) {
207225 }... )
208226 }
209227
210- if err := run (ctx , append ([]string {"apt-get" , "install" , "-y" , "--no-install-recommends" }, pkgs ... ), pw , dgst ); err != nil {
211- return err
212- }
213-
214- if err := os .MkdirAll ("/etc/cdi" , 0700 ); err != nil {
215- return errors .Wrapf (err , "failed to create /etc/cdi" )
216- }
217-
218- buf := & bytes.Buffer {}
219-
220- cmd := exec .CommandContext (ctx , "nvidia-ctk" , "cdi" , "generate" )
221- cmd .Stdout = buf
222- cmd .Stderr = newStream (pw , 2 , dgst )
223- if err := cmd .Run (); err != nil {
224- return errors .Wrapf (err , "failed to generate CDI spec" )
225- }
226-
227- if len (buf .Bytes ()) == 0 {
228- return errors .Errorf ("nvidia-ctk output is empty" )
229- }
230-
231- if err := os .WriteFile ("/etc/cdi/nvidia.yaml" , buf .Bytes (), 0644 ); err != nil {
232- return errors .Wrapf (err , "failed to write /etc/cdi/nvidia.yaml" )
233- }
234-
235- return nil
236- }
237-
238- func run (ctx context.Context , args []string , pw progress.Writer , dgst digest.Digest ) error {
239- fmt .Fprintf (newStream (pw , 2 , dgst ), "> %s\n " , strings .Join (args , " " ))
240- cmd := exec .CommandContext (ctx , args [0 ], args [1 :]... ) //nolint:gosec
241- cmd .Stderr = newStream (pw , 2 , dgst )
242- cmd .Stdout = newStream (pw , 1 , dgst )
243- return cmd .Run ()
228+ return run (ctx , append ([]string {"apt-get" , "install" , "-y" , "--no-install-recommends" }, pkgs ... ), pw , dgst )
244229}
245230
246231func readVersion () (string , error ) {
@@ -326,3 +311,13 @@ func hasWSLGPU() bool {
326311 _ , err := os .Stat ("/dev/dxg" )
327312 return err == nil
328313}
314+
315+ func hasLibsInstalled () bool {
316+ // Check for libcuda in the standard locations to confirm NVIDIA GPU drivers
317+ for _ , p := range libcudaGlobs {
318+ if matches , err := filepath .Glob (p ); err == nil && len (matches ) > 0 {
319+ return true
320+ }
321+ }
322+ return false
323+ }
0 commit comments