@@ -25,10 +25,13 @@ import (
2525// This is example of experimental on-demand setup of a CDI devices.
2626// This code is not currently shipping with BuildKit and will probably change.
2727
28- const (
29- cdiKind = "nvidia.com/gpu"
30- defaultVersion = "570.0"
31- )
28+ const cdiKind = "nvidia.com/gpu"
29+
30+ // https://github.com/ollama/ollama/blob/b816ff86c923e0290f58f2275e831fc17c29ba37/discover/gpu_linux.go#L33-L43
31+ var libcudaGlobs = []string {
32+ "/usr/lib/*-linux-gnu/libcuda.so*" ,
33+ "/usr/lib/wsl/drivers/*/libcuda.so*" ,
34+ }
3235
3336func init () {
3437 cdidevices .Register (cdiKind , & setup {})
@@ -39,8 +42,7 @@ type setup struct{}
3942var _ cdidevices.Setup = & setup {}
4043
4144func (s * setup ) Validate () error {
42- _ , err := readVersion ()
43- if err == nil {
45+ if _ , err := readVersion (); err == nil {
4446 return nil
4547 }
4648 b , err := hasNvidiaDevices ()
@@ -93,55 +95,94 @@ func (s *setup) Run(ctx context.Context) (err error) {
9395 return errors .Errorf ("NVIDIA setup is currently only supported on Debian/Ubuntu" )
9496 }
9597
96- var needsDriver bool
97-
98- if _ , err := os .Stat ("/proc/driver/nvidia" ); err != nil {
99- needsDriver = true
98+ needsDriver := true
99+ if _ , err := os .Stat ("/proc/driver/nvidia" ); err == nil {
100+ needsDriver = false
101+ } else if nvidiaSmi , err := exec .LookPath ("nvidia-smi" ); err == nil && nvidiaSmi != "" {
102+ if err := run (ctx , []string {nvidiaSmi , "-L" }, pw , dgst ); err == nil {
103+ needsDriver = false
104+ }
105+ }
106+ if needsDriver {
107+ if hasWSLGPU () {
108+ return errors .Errorf ("NVIDIA drivers are required for WSL with non PCI-based GPUs" )
109+ }
110+ return errors .Errorf ("NVIDIA drivers are required. Try loading NVIDIA kernel module with \" modprobe nvidia\" command" )
100111 }
101112
102- var arch string
103- switch runtime .GOARCH {
104- case "amd64" :
105- arch = "x86_64"
106- case "arm64" :
107- arch = "sbsa"
108- // for non-sbsa could use https://nvidia.github.io/libnvidia-container/stable/deb
113+ var dv string
114+ if ! hasLibsInstalled () && ! hasWSLGPU () {
115+ version , err := readVersion ()
116+ if err != nil {
117+ return errors .Wrapf (err , "failed to read NVIDIA driver version" )
118+ }
119+ var ok bool
120+ dv , _ , ok = strings .Cut (version , "." )
121+ if ! ok {
122+ return errors .Errorf ("failed to parse NVIDIA driver version %q" , version )
123+ }
109124 }
110125
111- if arch == "" {
112- return errors . Errorf ( "unsupported architecture: %s" , runtime . GOARCH )
126+ if err := run ( ctx , [] string { "apt-get" , "update" }, pw , dgst ); err != nil {
127+ return err
113128 }
114129
115- if needsDriver {
116- pw .Write (identity .NewID (), client.VertexWarning {
117- Vertex : dgst ,
118- Short : []byte ("NVIDIA Drivers not found. Installing prebuilt drivers is not recommended" ),
119- })
130+ if err := run (ctx , []string {"apt-get" , "install" , "-y" , "gpg" }, pw , dgst ); err != nil {
131+ return err
120132 }
121133
122- version , err := readVersion ()
123- if err != nil && ! needsDriver {
124- return errors .Wrapf (err , "failed to read NVIDIA driver version" )
134+ if err := installPackages (ctx , dv , pw , dgst ); err != nil {
135+ return err
125136 }
126- if version == "" {
127- version = defaultVersion
137+
138+ if err := os .MkdirAll ("/etc/cdi" , 0700 ); err != nil {
139+ return errors .Wrapf (err , "failed to create /etc/cdi" )
128140 }
129- v1 , _ , ok := strings .Cut (version , "." )
130- if ! ok {
131- return errors .Errorf ("failed to parse NVIDIA driver version %q" , version )
141+
142+ buf := & bytes.Buffer {}
143+
144+ cmd := exec .CommandContext (ctx , "nvidia-ctk" , "cdi" , "generate" )
145+ cmd .Stdout = buf
146+ cmd .Stderr = newStream (pw , 2 , dgst )
147+ if err := cmd .Run (); err != nil {
148+ return errors .Wrapf (err , "failed to generate CDI spec" )
132149 }
133150
134- if err := run ( ctx , [] string { "apt-get" , "update" }, pw , dgst ); err != nil {
135- return err
151+ if len ( buf . Bytes ()) == 0 {
152+ return errors . Errorf ( "nvidia-ctk output is empty" )
136153 }
137154
138- if err := run ( ctx , [] string { "apt-get " , "install" , "-y" , "gpg" }, pw , dgst ); err != nil {
139- return err
155+ if err := os . WriteFile ( "/etc/cdi/nvidia.yaml " , buf . Bytes (), 0644 ); err != nil {
156+ return errors . Wrapf ( err , "failed to write /etc/cdi/nvidia.yaml" )
140157 }
141158
159+ return nil
160+ }
161+
162+ func run (ctx context.Context , args []string , pw progress.Writer , dgst digest.Digest ) error {
163+ fmt .Fprintf (newStream (pw , 2 , dgst ), "> %s\n " , strings .Join (args , " " ))
164+ cmd := exec .CommandContext (ctx , args [0 ], args [1 :]... ) //nolint:gosec
165+ cmd .Stderr = newStream (pw , 2 , dgst )
166+ cmd .Stdout = newStream (pw , 1 , dgst )
167+ return cmd .Run ()
168+ }
169+
170+ func installPackages (ctx context.Context , dv string , pw progress.Writer , dgst digest.Digest ) error {
142171 const aptDistro = "ubuntu2404"
143- aptURL := "https://developer.download.nvidia.com/compute/cuda/repos/" + aptDistro + "/" + arch + "/"
144172
173+ var arch string
174+ switch runtime .GOARCH {
175+ case "amd64" :
176+ arch = "x86_64"
177+ case "arm64" :
178+ arch = "sbsa"
179+ // for non-sbsa could use https://nvidia.github.io/libnvidia-container/stable/deb
180+ }
181+ if arch == "" {
182+ return errors .Errorf ("unsupported architecture: %s" , runtime .GOARCH )
183+ }
184+
185+ aptURL := "https://developer.download.nvidia.com/compute/cuda/repos/" + aptDistro + "/" + arch + "/"
145186 keyTarget := "/usr/share/keyrings/nvidia-cuda-keyring.gpg"
146187
147188 if _ , err := os .Stat (keyTarget ); err != nil {
@@ -174,59 +215,17 @@ func (s *setup) Run(ctx context.Context) (err error) {
174215 return err
175216 }
176217
177- if needsDriver {
178- // this pretty much never works, is it even worth having?
179- // better approach could be to try to create another chroot/container that is built with same kernel packages as the host
180- // could nvidia-headless-no-dkms- be reusable
181- if err := run (ctx , []string {"apt-get" , "install" , "-y" , "nvidia-driver-" + v1 }, pw , dgst ); err != nil {
182- return err
183- }
184- _ , err := os .Stat ("/proc/driver/nvidia" )
185- if err != nil {
186- return errors .Wrapf (err , "failed to install NVIDIA kernel module. Please install NVIDIA drivers manually" )
187- }
188- }
189-
190- if err := run (ctx , []string {"apt-get" , "install" , "-y" , "--no-install-recommends" ,
191- "libnvidia-compute-" + v1 ,
192- "libnvidia-extra-" + v1 ,
193- "libnvidia-gl-" + v1 ,
194- "nvidia-utils-" + v1 ,
195- "nvidia-container-toolkit-base" ,
196- }, pw , dgst ); err != nil {
197- return err
198- }
199-
200- if err := os .MkdirAll ("/etc/cdi" , 0700 ); err != nil {
201- return errors .Wrapf (err , "failed to create /etc/cdi" )
218+ pkgs := []string {"nvidia-container-toolkit-base" }
219+ if dv != "" {
220+ pkgs = append (pkgs , []string {
221+ "libnvidia-compute-" + dv ,
222+ "libnvidia-extra-" + dv ,
223+ "libnvidia-gl-" + dv ,
224+ "nvidia-utils-" + dv ,
225+ }... )
202226 }
203227
204- buf := & bytes.Buffer {}
205-
206- cmd := exec .CommandContext (ctx , "nvidia-ctk" , "cdi" , "generate" )
207- cmd .Stdout = buf
208- cmd .Stderr = newStream (pw , 2 , dgst )
209- if err := cmd .Run (); err != nil {
210- return errors .Wrapf (err , "failed to generate CDI spec" )
211- }
212-
213- if len (buf .Bytes ()) == 0 {
214- return errors .Errorf ("nvidia-ctk output is empty" )
215- }
216-
217- if err := os .WriteFile ("/etc/cdi/nvidia.yaml" , buf .Bytes (), 0644 ); err != nil {
218- return errors .Wrapf (err , "failed to write /etc/cdi/nvidia.yaml" )
219- }
220-
221- return nil
222- }
223-
224- func run (ctx context.Context , args []string , pw progress.Writer , dgst digest.Digest ) error {
225- fmt .Fprintf (newStream (pw , 2 , dgst ), "> %s\n " , strings .Join (args , " " ))
226- cmd := exec .CommandContext (ctx , args [0 ], args [1 :]... ) //nolint:gosec
227- cmd .Stderr = newStream (pw , 2 , dgst )
228- cmd .Stdout = newStream (pw , 1 , dgst )
229- return cmd .Run ()
228+ return run (ctx , append ([]string {"apt-get" , "install" , "-y" , "--no-install-recommends" }, pkgs ... ), pw , dgst )
230229}
231230
232231func readVersion () (string , error ) {
@@ -268,6 +267,10 @@ func hasNvidiaDevices() (bool, error) {
268267 }
269268 }
270269
270+ if ! found {
271+ found = hasWSLGPU ()
272+ }
273+
271274 return found , nil
272275}
273276
@@ -302,3 +305,19 @@ func isDebianOrUbuntu() (bool, error) {
302305
303306 return id == "debian" || id == "ubuntu" , nil
304307}
308+
309+ func hasWSLGPU () bool {
310+ // WSL-specific GPU mapping that doesn't expose PCI info.
311+ _ , err := os .Stat ("/dev/dxg" )
312+ return err == nil
313+ }
314+
315+ func hasLibsInstalled () bool {
316+ // Check for libcuda in the standard locations to confirm NVIDIA GPU drivers
317+ for _ , p := range libcudaGlobs {
318+ if matches , err := filepath .Glob (p ); err == nil && len (matches ) > 0 {
319+ return true
320+ }
321+ }
322+ return false
323+ }
0 commit comments