@@ -232,8 +232,24 @@ func (g *FastGenerator) generateMonobase(lines []string, tmpDir string) ([]strin
232232 "ENV R8_COG_VERSION=" + CogletVersionFromEnvironment (),
233233 }... )
234234
235+ torchVersion , ok := g .Config .TorchVersion ()
236+ if ok {
237+ if ! CheckMajorMinorPatch (torchVersion ) {
238+ return nil , fmt .Errorf ("Torch version must be <major>.<minor>.<patch>: %s" , strings .Join (g .matrix .TorchVersions , ", " ))
239+ }
240+ envs = append (envs , []string {
241+ "ENV R8_TORCH_VERSION=" + torchVersion ,
242+ }... )
243+ }
244+
245+ console .Infof ("OK: %v" , ok )
246+ console .Infof ("Torch Version: %s" , torchVersion )
247+
235248 if g .Config .Build .GPU {
236249 cudaVersion := g .Config .Build .CUDA
250+ if cudaVersion == "" && ok {
251+ cudaVersion = g .matrix .DefaultCUDAVersion (torchVersion )
252+ }
237253 cudnnVersion := g .Config .Build .CuDNN
238254 if cudnnVersion == "" {
239255 cudnnVersion = g .matrix .DefaultCudnnVersion ()
@@ -261,16 +277,6 @@ func (g *FastGenerator) generateMonobase(lines []string, tmpDir string) ([]strin
261277 "ENV R8_PYTHON_VERSION=" + g .Config .Build .PythonVersion ,
262278 }... )
263279
264- torchVersion , ok := g .Config .TorchVersion ()
265- if ok {
266- if ! CheckMajorMinorPatch (torchVersion ) {
267- return nil , fmt .Errorf ("Torch version must be <major>.<minor>.<patch>: %s" , strings .Join (g .matrix .TorchVersions , ", " ))
268- }
269- envs = append (envs , []string {
270- "ENV R8_TORCH_VERSION=" + torchVersion ,
271- }... )
272- }
273-
274280 if ! g .matrix .IsSupported (g .Config .Build .PythonVersion , torchVersion , g .Config .Build .CUDA ) {
275281 return nil , fmt .Errorf (
276282 "Unsupported version combination: Python=%s, Torch=%s, CUDA=%s" ,
0 commit comments