33# install_prerequisites.py
44import argparse
55import glob
6+ import json
67import os
78import subprocess
89import sys
10+ import urllib .request
911
1012# --- Configuration ---
1113WHEELS_CACHE_HOME = os .environ .get ("WHEELS_CACHE_HOME" , "/tmp/wheels_cache" )
1820
1921
2022# --- Helper Functions ---
23+ def get_latest_nixl_version ():
24+ """Helper function to get latest release version of NIXL"""
25+ try :
26+ nixl_release_url = "https://api.github.com/repos/ai-dynamo/nixl/releases/latest"
27+ with urllib .request .urlopen (nixl_release_url ) as response :
28+ data = json .load (response )
29+ return data .get ("tag_name" , "0.7.0" )
30+ except Exception :
31+ return "0.7.0"
32+
33+
34+ NIXL_VERSION = os .environ .get ("NIXL_VERSION" , get_latest_nixl_version ())
35+
36+
2137def run_command (command , cwd = "." , env = None ):
2238 """Helper function to run a shell command and check for errors."""
2339 print (f"--> Running command: { ' ' .join (command )} in '{ cwd } '" , flush = True )
@@ -37,7 +53,7 @@ def is_pip_package_installed(package_name):
3753def find_nixl_wheel_in_cache (cache_dir ):
3854 """Finds a nixl wheel file in the specified cache directory."""
3955 # The repaired wheel will have a 'manylinux' tag, but this glob still works.
40- search_pattern = os .path .join (cache_dir , "nixl*.whl" )
56+ search_pattern = os .path .join (cache_dir , f "nixl* { NIXL_VERSION } *.whl" )
4157 wheels = glob .glob (search_pattern )
4258 if wheels :
4359 # Sort to get the most recent/highest version if multiple exist
@@ -146,6 +162,10 @@ def build_and_install_prerequisites(args):
146162 print ("\n [2/3] Building NIXL wheel from source..." , flush = True )
147163 if not os .path .exists (NIXL_DIR ):
148164 run_command (["git" , "clone" , NIXL_REPO_URL , NIXL_DIR ])
165+ else :
166+ run_command (["git" , "fetch" , "--tags" ], cwd = NIXL_DIR )
167+ run_command (["git" , "checkout" , NIXL_VERSION ], cwd = NIXL_DIR )
168+ print (f"--> Checked out NIXL version: { NIXL_VERSION } " , flush = True )
149169
150170 build_env = os .environ .copy ()
151171 build_env ["PKG_CONFIG_PATH" ] = os .path .join (ucx_install_path , "lib" , "pkgconfig" )
@@ -203,7 +223,14 @@ def build_and_install_prerequisites(args):
203223 { os .path .basename (newly_built_wheel )} . Now installing..." ,
204224 flush = True ,
205225 )
206- install_command = [sys .executable , "-m" , "pip" , "install" , newly_built_wheel ]
226+ install_command = [
227+ sys .executable ,
228+ "-m" ,
229+ "pip" ,
230+ "install" ,
231+ "--no-deps" , # w/o "no-deps", it will install cuda-torch
232+ newly_built_wheel ,
233+ ]
207234 if args .force_reinstall :
208235 install_command .insert (- 1 , "--force-reinstall" )
209236
0 commit comments