Skip to content

Commit 92b9950

Browse files
committed
feat: determine device type automatically except npu device.
1 parent 4d97206 commit 92b9950

File tree

3 files changed

+100
-50
lines changed

3 files changed

+100
-50
lines changed

docs/en/getting_started/compile.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ pip install --upgrade setuptools wheel
4848
```
4949

5050
## Compilation
51-
Execute the compilation to generate the executable file `build/xllm/core/server/xllm` under `build/`. The default device is A2, for A3, add `--device a3`, for mlu, add `--device mlu`:
51+
Execute the compilation to generate the executable file `build/xllm/core/server/xllm` under `build/`. For A3, add `--device a3`, and no need to add `--device` for other devices:
5252
```bash
5353
python setup.py build
5454
```

docs/zh/getting_started/compile.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ pip install -r cibuild/requirements-dev.txt -i https://mirrors.tuna.tsinghua.edu
4949
pip install --upgrade setuptools wheel
5050
```
5151
## 编译
52-
执行编译,在`build/`下生成可执行文件`build/xllm/core/server/xllm`默认为A2,A3请加 `--device a3`MLU请加 `--device mlu`
52+
执行编译,在`build/`下生成可执行文件`build/xllm/core/server/xllm`如果是A3请加`--device a3`其他设备无需加`--device`
5353
```bash
5454
python setup.py build
5555
```

setup.py

Lines changed: 98 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from pathlib import Path
1212
from typing import List
1313
from jinja2 import Template
14+
import argparse
1415

1516
from distutils.core import Command
1617
from setuptools import Extension, setup, find_packages
@@ -30,6 +31,31 @@ def get_cpu_arch():
3031
else:
3132
raise ValueError(f"Unsupported architecture: {arch}")
3233

34+
# get device type
35+
def get_device_type():
36+
import torch
37+
38+
if torch.cuda.is_available():
39+
return "cuda"
40+
41+
try:
42+
import torch_mlu
43+
if torch.mlu.is_available():
44+
return "mlu"
45+
except ImportError:
46+
pass
47+
48+
try:
49+
import torch_npu
50+
if torch.npu.is_available():
51+
return "a2"
52+
except ImportError:
53+
pass
54+
55+
print("Unsupported device type, please install torch, torch_mlu or torch_npu")
56+
exit(1)
57+
58+
3359
def get_cxx_abi():
3460
try:
3561
import torch
@@ -227,8 +253,6 @@ def set_cuda_envs():
227253
os.environ["LIBTORCH_ROOT"] = get_torch_root_path()
228254
os.environ["PYTORCH_INSTALL_PATH"] = get_torch_root_path()
229255
os.environ["CUDA_TOOLKIT_ROOT_DIR"] = "/usr/local/cuda"
230-
os.environ["NCCL_ROOT"] = get_nccl_root_path()
231-
os.environ["NCCL_VERSION"] = "2"
232256

233257
class CMakeExtension(Extension):
234258
def __init__(self, name: str, path: str, sourcedir: str = "") -> None:
@@ -562,54 +586,80 @@ def pre_build():
562586
if not run_shell_command("sh third_party/dependencies.sh", cwd=script_path):
563587
print("❌ Failed to reset changes!")
564588
exit(0)
589+
590+
def parse_arguments():
591+
parser = argparse.ArgumentParser(
592+
description='Setup helper for building xllm',
593+
epilog='Example: python setup.py build --device a3',
594+
usage='%(prog)s [COMMAND] [OPTIONS]'
595+
)
596+
597+
parser.add_argument(
598+
'setup_args',
599+
nargs='*',
600+
metavar='argparse.REMAINDER',
601+
help='setup command (build, test, bdist_wheel, etc.)'
602+
)
603+
604+
parser.add_argument(
605+
'--device',
606+
type=str.lower,
607+
choices=['auto', 'a2', 'a3', 'mlu', 'cuda'],
608+
default='auto',
609+
help='Device type: a2, a3, mlu, or cuda (case-insensitive)'
610+
)
611+
612+
parser.add_argument(
613+
'--dry-run',
614+
action='store_true',
615+
help='Dry run mode (do not execute pre_build)'
616+
)
617+
618+
parser.add_argument(
619+
'--install-xllm-kernels',
620+
type=str.lower,
621+
choices=['true', 'false', '1', '0', 'yes', 'no', 'y', 'n', 'on', 'off'],
622+
default='true',
623+
help='Whether to install xllm kernels'
624+
)
625+
626+
parser.add_argument(
627+
'--generate-so',
628+
type=str.lower,
629+
choices=['true', 'false', '1', '0', 'yes', 'no', 'y', 'n', 'on', 'off'],
630+
default='false',
631+
help='Whether to generate so or binary'
632+
)
633+
634+
args = parser.parse_args()
635+
636+
sys.argv = [sys.argv[0]] + args.setup_args
637+
638+
install_kernels = args.install_xllm_kernels.lower() in ('true', '1', 'yes', 'y', 'on')
639+
generate_so = args.generate_so.lower() in ('true', '1', 'yes', 'y', 'on')
640+
641+
return {
642+
'device': args.device,
643+
'dry_run': args.dry_run,
644+
'install_xllm_kernels': install_kernels,
645+
'generate_so': generate_so,
646+
}
647+
565648

566649
if __name__ == "__main__":
567-
device = 'a2' # default
650+
config = parse_arguments()
651+
568652
arch = get_cpu_arch()
569-
install_kernels = True
570-
generate_so = False
571-
if '--device' in sys.argv:
572-
idx = sys.argv.index('--device')
573-
if idx + 1 < len(sys.argv):
574-
device = sys.argv[idx+1].lower()
575-
if device not in ('a2', 'a3', 'mlu', 'cuda'):
576-
print("Error: --device must be a2 or a3 or mlu (case-insensitive)")
577-
sys.exit(1)
578-
# Remove the arguments so setup() doesn't see them
579-
del sys.argv[idx]
580-
del sys.argv[idx]
581-
if '--dry_run' not in sys.argv:
582-
pre_build()
583-
else:
584-
sys.argv.remove("--dry_run")
653+
device = config['device']
654+
if device == 'auto':
655+
device = get_device_type()
656+
print(f"🚀 Build xllm with CPU arch: {arch} and target device: {device}")
585657

586-
if '--install-xllm-kernels' in sys.argv:
587-
idx = sys.argv.index('--install-xllm-kernels')
588-
if idx + 1 < len(sys.argv):
589-
install_kernels = sys.argv[idx+1].lower()
590-
if install_kernels in ('true', '1', 'yes', 'y', 'on'):
591-
install_kernels = True
592-
elif install_kernels in ('false', '0', 'no', 'n', 'off'):
593-
install_kernels = False
594-
else:
595-
print("Error: --install-xllm-kernels must be true or false")
596-
sys.exit(1)
597-
sys.argv.pop(idx)
598-
sys.argv.pop(idx)
599-
600-
if '--generate-so' in sys.argv:
601-
idx = sys.argv.index('--generate-so')
602-
if idx + 1 < len(sys.argv):
603-
generate_so_val = sys.argv[idx+1].lower()
604-
if generate_so_val in ('true', '1', 'yes', 'y', 'on'):
605-
generate_so = True
606-
elif generate_so_val in ('false', '0', 'no', 'n', 'off'):
607-
generate_so = False
608-
else:
609-
print("Error: --generate-so must be true or false")
610-
sys.exit(1)
611-
sys.argv.pop(idx)
612-
sys.argv.pop(idx)
658+
if not config['dry_run']:
659+
pre_build()
660+
661+
install_kernels = config['install_xllm_kernels']
662+
generate_so = config['generate_so']
613663

614664
if "SKIP_TEST" in os.environ:
615665
BUILD_TEST_FILE = False
@@ -631,7 +681,7 @@ def pre_build():
631681
long_description=read_readme(),
632682
long_description_content_type="text/markdown",
633683
url="https://github.com/jd-opensource/xllm",
634-
project_url={
684+
project_urls={
635685
"Homepage": "https://xllm.readthedocs.io/zh-cn/latest/",
636686
"Documentation": "https://xllm.readthedocs.io/zh-cn/latest/",
637687
},
@@ -658,7 +708,7 @@ def pre_build():
658708
options={'build_ext': {
659709
'device': device,
660710
'arch': arch,
661-
'install_xllm_kernels': install_kernels if install_kernels is not None else "false",
711+
'install_xllm_kernels': install_kernels,
662712
'generate_so': generate_so
663713
},
664714
'bdist_wheel': {

0 commit comments

Comments
 (0)