1111from pathlib import Path
1212from typing import List
1313from jinja2 import Template
14+ import argparse
1415
1516from distutils .core import Command
1617from setuptools import Extension , setup , find_packages
@@ -30,6 +31,31 @@ def get_cpu_arch():
3031 else :
3132 raise ValueError (f"Unsupported architecture: { arch } " )
3233
34+ # get device type
def get_device_type():
    """Detect the accelerator backend to build for.

    Backends are probed in a fixed order: CUDA first, then MLU (requires the
    ``torch_mlu`` plugin), then NPU (requires the ``torch_npu`` plugin, and
    is reported as ``"a2"``).

    Returns:
        str: ``"cuda"``, ``"mlu"`` or ``"a2"``.

    Raises:
        SystemExit: with exit code 1 when torch is not importable or none of
            the probed backends reports itself as available.
    """
    # A missing torch must end in the same friendly message as "no device
    # found" rather than an ImportError traceback.
    try:
        import torch
    except ImportError:
        torch = None

    if torch is not None:
        if torch.cuda.is_available():
            return "cuda"

        # Importing the vendor plugin is what makes `torch.mlu` usable, so
        # the import and the availability check share one try block.
        try:
            import torch_mlu  # noqa: F401
            if torch.mlu.is_available():
                return "mlu"
        except ImportError:
            pass

        try:
            import torch_npu  # noqa: F401
            if torch.npu.is_available():
                return "a2"
        except ImportError:
            pass

    print("Unsupported device type, please install torch, torch_mlu or torch_npu")
    # Raise SystemExit directly: the bare `exit()` helper is injected by the
    # `site` module and is not guaranteed to exist (e.g. under `python -S`).
    raise SystemExit(1)
57+
58+
3359def get_cxx_abi ():
3460 try :
3561 import torch
@@ -227,8 +253,6 @@ def set_cuda_envs():
227253 os .environ ["LIBTORCH_ROOT" ] = get_torch_root_path ()
228254 os .environ ["PYTORCH_INSTALL_PATH" ] = get_torch_root_path ()
229255 os .environ ["CUDA_TOOLKIT_ROOT_DIR" ] = "/usr/local/cuda"
230- os .environ ["NCCL_ROOT" ] = get_nccl_root_path ()
231- os .environ ["NCCL_VERSION" ] = "2"
232256
233257class CMakeExtension (Extension ):
234258 def __init__ (self , name : str , path : str , sourcedir : str = "" ) -> None :
@@ -562,54 +586,80 @@ def pre_build():
562586 if not run_shell_command ("sh third_party/dependencies.sh" , cwd = script_path ):
563587 print ("❌ Failed to reset changes!" )
564588 exit (0 )
589+
def parse_arguments():
    """Parse the xllm-specific build options and strip them from ``sys.argv``.

    Only the options declared here are consumed. Every other argument (the
    setup command itself and any option that belongs to it, such as
    ``bdist_wheel --dist-dir out``) is left untouched and, in its original
    order, becomes the new ``sys.argv[1:]`` so that a later ``setup()`` call
    sees exactly what it expects.

    Returns:
        dict: ``device`` (lower-cased str, one of ``auto``/``a2``/``a3``/
        ``mlu``/``cuda``), ``dry_run`` (bool), ``install_xllm_kernels``
        (bool) and ``generate_so`` (bool).

    Raises:
        SystemExit: raised by argparse on ``-h/--help`` or when one of the
            options declared here receives an invalid value.
    """
    truthy = ('true', '1', 'yes', 'y', 'on')
    bool_choices = ['true', 'false', '1', '0', 'yes', 'no', 'y', 'n', 'on', 'off']

    parser = argparse.ArgumentParser(
        description='Setup helper for building xllm',
        epilog='Example: python setup.py build --device a3',
        usage='%(prog)s [COMMAND] [OPTIONS]',
        # Prefix matching could otherwise swallow an abbreviated option that
        # was meant for the setup command.
        allow_abbrev=False,
    )

    parser.add_argument(
        '--device',
        type=str.lower,
        choices=['auto', 'a2', 'a3', 'mlu', 'cuda'],
        default='auto',
        help='Device type: auto, a2, a3, mlu, or cuda (case-insensitive)',
    )

    # `--dry_run` is the historical spelling; keep it working as an alias.
    parser.add_argument(
        '--dry-run', '--dry_run',
        dest='dry_run',
        action='store_true',
        help='Dry run mode (do not execute pre_build)',
    )

    parser.add_argument(
        '--install-xllm-kernels',
        type=str.lower,
        choices=bool_choices,
        default='true',
        help='Whether to install xllm kernels',
    )

    parser.add_argument(
        '--generate-so',
        type=str.lower,
        choices=bool_choices,
        default='false',
        help='Whether to generate so or binary',
    )

    # parse_known_args() returns the unrecognised arguments instead of
    # erroring out on them, which is what lets setup-command options through.
    args, setup_args = parser.parse_known_args()

    sys.argv = [sys.argv[0]] + setup_args

    # The values are already lower-cased by `type=str.lower`.
    return {
        'device': args.device,
        'dry_run': args.dry_run,
        'install_xllm_kernels': args.install_xllm_kernels in truthy,
        'generate_so': args.generate_so in truthy,
    }
647+
565648
566649if __name__ == "__main__" :
567- device = 'a2' # default
650+ config = parse_arguments ()
651+
568652 arch = get_cpu_arch ()
569- install_kernels = True
570- generate_so = False
571- if '--device' in sys .argv :
572- idx = sys .argv .index ('--device' )
573- if idx + 1 < len (sys .argv ):
574- device = sys .argv [idx + 1 ].lower ()
575- if device not in ('a2' , 'a3' , 'mlu' , 'cuda' ):
576- print ("Error: --device must be a2 or a3 or mlu (case-insensitive)" )
577- sys .exit (1 )
578- # Remove the arguments so setup() doesn't see them
579- del sys .argv [idx ]
580- del sys .argv [idx ]
581- if '--dry_run' not in sys .argv :
582- pre_build ()
583- else :
584- sys .argv .remove ("--dry_run" )
653+ device = config ['device' ]
654+ if device == 'auto' :
655+ device = get_device_type ()
656+ print (f"🚀 Build xllm with CPU arch: { arch } and target device: { device } " )
585657
586- if '--install-xllm-kernels' in sys .argv :
587- idx = sys .argv .index ('--install-xllm-kernels' )
588- if idx + 1 < len (sys .argv ):
589- install_kernels = sys .argv [idx + 1 ].lower ()
590- if install_kernels in ('true' , '1' , 'yes' , 'y' , 'on' ):
591- install_kernels = True
592- elif install_kernels in ('false' , '0' , 'no' , 'n' , 'off' ):
593- install_kernels = False
594- else :
595- print ("Error: --install-xllm-kernels must be true or false" )
596- sys .exit (1 )
597- sys .argv .pop (idx )
598- sys .argv .pop (idx )
599-
600- if '--generate-so' in sys .argv :
601- idx = sys .argv .index ('--generate-so' )
602- if idx + 1 < len (sys .argv ):
603- generate_so_val = sys .argv [idx + 1 ].lower ()
604- if generate_so_val in ('true' , '1' , 'yes' , 'y' , 'on' ):
605- generate_so = True
606- elif generate_so_val in ('false' , '0' , 'no' , 'n' , 'off' ):
607- generate_so = False
608- else :
609- print ("Error: --generate-so must be true or false" )
610- sys .exit (1 )
611- sys .argv .pop (idx )
612- sys .argv .pop (idx )
658+ if not config ['dry_run' ]:
659+ pre_build ()
660+
661+ install_kernels = config ['install_xllm_kernels' ]
662+ generate_so = config ['generate_so' ]
613663
614664 if "SKIP_TEST" in os .environ :
615665 BUILD_TEST_FILE = False
@@ -631,7 +681,7 @@ def pre_build():
631681 long_description = read_readme (),
632682 long_description_content_type = "text/markdown" ,
633683 url = "https://github.com/jd-opensource/xllm" ,
634- project_url = {
684+ project_urls = {
635685 "Homepage" : "https://xllm.readthedocs.io/zh-cn/latest/" ,
636686 "Documentation" : "https://xllm.readthedocs.io/zh-cn/latest/" ,
637687 },
@@ -658,7 +708,7 @@ def pre_build():
658708 options = {'build_ext' : {
659709 'device' : device ,
660710 'arch' : arch ,
661- 'install_xllm_kernels' : install_kernels if install_kernels is not None else "false" ,
711+ 'install_xllm_kernels' : install_kernels ,
662712 'generate_so' : generate_so
663713 },
664714 'bdist_wheel' : {
0 commit comments