Add NeuronxDistributedInference support, Speculative Decoding, Dynamic on-device sampling #16357
@@ -0,0 +1,148 @@
# SPDX-License-Identifier: Apache-2.0
import argparse
import json
import os
import subprocess
import sys

from vllm.logger import init_logger

logger = init_logger("vllm.neuron.multi-node")

NEURON_RT_ROOT_COMM_ID_PORT = 63423


def error_exit(message: str) -> None:
    logger.error(message)
    sys.exit(1)


def arg_parser():
    parser = argparse.ArgumentParser(description="vLLM multi-node launcher")
    parser.add_argument("--model",
                        type=str,
                        required=True,
                        help="Model or model path")
    parser.add_argument("--world-size",
                        type=int,
                        required=True,
                        help="World size for distributed inference")
    parser.add_argument("--max-num-seqs",
                        type=int,
                        required=True,
                        help="Maximum number of sequences (or batch size)")
    parser.add_argument("--max-model-len",
                        type=int,
                        default=8192,
                        help="Maximum sequence length")
    parser.add_argument("--max-context-length",
                        type=int,
                        help="Maximum context length")
    parser.add_argument("--compiled-model-path",
                        help="Path to the compiled model. If not present, "
                        "model artifacts will be created in local-models "
                        "folder")
    parser.add_argument("--local-ranks-size",
                        type=int,
                        default=32,
                        help="Local ranks size")
    parser.add_argument("--on-device-sampling-config",
                        type=json.loads,
                        help="On-device sampling configuration")
    parser.add_argument("--quantized",
                        action="store_true",
                        help="Enable quantized mode (default: False)")
    parser.add_argument("--quantized-checkpoints-path",
                        type=str,
                        help="Path to quantized checkpoints "
                        "(required if --quantized is set)")
    parser.add_argument("--port",
                        type=int,
                        default=8080,
                        help="Port for the API server")

    args = parser.parse_args()
    if args.quantized and not args.quantized_checkpoints_path:
        parser.error("--quantized-checkpoints-path is required when "
                     "--quantized is enabled.")
    return args


def make_override_config(args, rank):
    if rank < 0:
        error_exit("rank must be a non-negative integer")
    start_rank_id = rank * args.local_ranks_size
    override_config = {
        "world_size": args.world_size,
        "tp_degree": args.local_ranks_size,
        "local_ranks_size": args.local_ranks_size,
        "start_rank_id": start_rank_id,
    }

    if args.max_context_length:
        override_config["max_context_length"] = args.max_context_length
    if args.on_device_sampling_config:
        override_config[
            "on_device_sampling_config"] = args.on_device_sampling_config
    if args.quantized:
        override_config[
            "quantized_checkpoints_path"] = args.quantized_checkpoints_path
        override_config["quantized"] = args.quantized

    return override_config


def main() -> None:
    args = arg_parser()

    rank = int(os.environ.get("OMPI_COMM_WORLD_RANK"))
    mpi_world_size = int(os.environ.get("OMPI_COMM_WORLD_SIZE"))
    master_addr = os.environ.get("MASTER_ADDR")
    # TODO: this script can be extended to support TnX
    os.environ["VLLM_NEURON_FRAMEWORK"] = "neuronx-distributed-inference"
    if args.compiled_model_path:
        os.environ["NEURON_COMPILED_ARTIFACTS"] = args.compiled_model_path
    os.environ.update({
        "ENABLE_NEURON_MULTI_NODE": "true",
        "WORLD_SIZE": str(mpi_world_size),
        "NEURON_RT_ROOT_COMM_ID":
        f"{master_addr}:{NEURON_RT_ROOT_COMM_ID_PORT}",
        "NEURON_LOCAL_TP": str(args.local_ranks_size),
        "NEURON_RANK_ID": str(rank)
    })

    override_config = make_override_config(args, rank)
    if rank == 0:
        logger.info("Starting vLLM API server on rank 0...")
        cmd = [
            "python", "-m", "vllm.entrypoints.api_server",
            f"--model={args.model}", f"--port={args.port}", "--device=neuron",
            f"--max-num-seqs={args.max_num_seqs}",
            f"--max-model-len={args.max_model_len}",
            f"--override-neuron-config={json.dumps(override_config)}"
        ]
        logger.debug("Command ran: %s", cmd)
        try:
            subprocess.run(cmd, check=True)
        except subprocess.CalledProcessError:
            error_exit(f"Failed to start vLLM API server on rank {rank}")
    else:
        logger.info("Starting worker on rank: %s", rank)
        current_script_dir = os.path.dirname(os.path.abspath(__file__))
        worker_file_path = os.path.join(current_script_dir, "worker.py")
        cmd = [
            "python", worker_file_path, f"--model={args.model}",
            "--device=neuron", f"--max-num-seqs={args.max_num_seqs}",
            f"--max-model-len={args.max_model_len}",
            f"--override-neuron-config={json.dumps(override_config)}"
        ]
        logger.debug("Command ran: %s", cmd)
        try:
            subprocess.run(cmd, check=True)
        except subprocess.CalledProcessError:
            error_exit(f"Failed to start worker on rank {rank}")


if __name__ == "__main__":
    main()
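For a two-node run with the defaults above, the rank-0 process effectively execs a command like the sketch below. This is derived only from the code in this file; the model path, world size, and batch size are illustrative placeholders, not values from the PR.

```bash
# Illustrative rank-0 command built by the launcher (placeholder values).
python -m vllm.entrypoints.api_server \
    --model <model-path> \
    --port 8080 \
    --device neuron \
    --max-num-seqs 4 \
    --max-model-len 8192 \
    --override-neuron-config '{"world_size": 64, "tp_degree": 32, "local_ranks_size": 32, "start_rank_id": 0}'
```

Non-zero ranks run the same engine arguments through worker.py instead of the API server, with start_rank_id offset by rank * local_ranks_size.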
@@ -0,0 +1,48 @@
#!/bin/bash -ex

HOSTFILE=""
MASTER_ADDR=""
MASTER_PORT=""

usage() {
    echo "Usage: $0 -h <hostfile> -a <master_address> -p <master_port> <python_command>"
    exit 1
}

while getopts "h:a:p:" opt; do
    case "$opt" in
        h) HOSTFILE=$OPTARG ;;
        a) MASTER_ADDR=$OPTARG ;;
        p) MASTER_PORT=$OPTARG ;;
        *) usage ;;
    esac
done

shift $((OPTIND - 1))

if [ -z "$HOSTFILE" ] || [ -z "$MASTER_ADDR" ] || [ -z "$MASTER_PORT" ]; then
    echo "Error: Missing required arguments."
    usage
fi

echo "Using hostfile: $HOSTFILE"
echo "Using address: $MASTER_ADDR"
echo "Using port: $MASTER_PORT"
echo "Python command:"
echo "$@"

# Use mpirun to trigger inference on head/worker nodes
/opt/amazon/openmpi/bin/mpirun \
    --mca mtl ^ofi --mca btl tcp,self --bind-to none \
    -np 2 \
    --hostfile "$HOSTFILE" \
    --prefix /opt/amazon/openmpi \
    -x FI_PROVIDER=efa \
    -x FI_EFA_USE_DEVICE_RDMA=1 \
    -x FI_EFA_FORK_SAFE=1 \
    -x PATH=/opt/amazon/openmpi/bin:$PATH \
    -x PYTHONPATH=$PYTHONPATH \
    -x LD_LIBRARY_PATH=/opt/aws/neuron/lib:/opt/amazon/efa/lib:/opt/amazon/openmpi/lib:$LD_LIBRARY_PATH \
    -x MASTER_ADDR="$MASTER_ADDR" -x MASTER_PORT="$MASTER_PORT" \
    "$@"
@@ -0,0 +1,43 @@
# SPDX-License-Identifier: Apache-2.0
import argparse
import os

from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.logger import init_logger
from vllm.usage.usage_lib import UsageContext

logger = init_logger("vllm.neuron.multi-node.worker")


def initialize_worker():
    parser = argparse.ArgumentParser()
    parser = AsyncEngineArgs.add_cli_args(parser)
    args = parser.parse_args()

    engine_args = AsyncEngineArgs.from_cli_args(args)
    engine = AsyncLLMEngine.from_engine_args(
        engine_args, usage_context=UsageContext.API_SERVER)
    return args, engine


def start_worker():
    rank_id = int(os.getenv("NEURON_RANK_ID"))
    if rank_id == 0:
        logger.error("Worker must have rank > 0")
    args, engine = initialize_worker()
    worker = engine.engine.model_executor.driver_worker
    while True:
        worker.execute_model()


def main():
    try:
        start_worker()
    except Exception as e:
        logger.error("Failed starting worker %s", e)
        exit(1)


if __name__ == "__main__":
    main()
@@ -0,0 +1,60 @@
# SPDX-License-Identifier: Apache-2.0
"""
This example shows how to run offline inference with an EAGLE speculative
decoding model on neuron. To use EAGLE speculative decoding, you must use
a draft model that is specifically fine-tuned for EAGLE speculation.
Additionally, to use EAGLE with NxD Inference, the draft model must include
the LM head weights from the target model. These weights are shared between
the draft and target model.
"""

from vllm import LLM, SamplingParams

# Configurations
TARGET_MODEL_PATH = "/home/ubuntu/model_hf/Meta-Llama-3.1-70B-Instruct"
DRAFT_MODEL_PATH = "/home/ubuntu/model_hf/Llama-3.1-70B-Instruct-EAGLE-Draft"
BATCH_SIZE = 4
SEQ_LEN = 2048
TENSOR_PARALLEL_SIZE = 32
SPECULATION_LENGTH = 5

# Sample prompts.
prompts = [
    "What is annapurna labs?",
]

# Create a sampling params object.
sampling_params = SamplingParams(top_k=1, max_tokens=500, ignore_eos=True)

# Create an LLM.
llm = LLM(
    model=TARGET_MODEL_PATH,
    speculative_model=DRAFT_MODEL_PATH,
    max_num_seqs=BATCH_SIZE,
    # The max_model_len and block_size arguments are required to be same as
    # max sequence length when targeting neuron device.
    # Currently, this is a known limitation in continuous batching support
    # in neuronx-distributed-inference.
    max_model_len=SEQ_LEN,
    block_size=SEQ_LEN,
    speculative_max_model_len=SEQ_LEN,
    # The device can be automatically detected when AWS Neuron SDK is installed.
    # The device argument can be either unspecified for automated detection,
    # or explicitly assigned.
    device="neuron",
    tensor_parallel_size=TENSOR_PARALLEL_SIZE,
    num_speculative_tokens=SPECULATION_LENGTH,
    override_neuron_config={
        "enable_eagle_speculation": True,
        "enable_fused_speculation": True
    },
)

# Generate texts from the prompts. The output is a list of RequestOutput objects
# that contain the prompt, generated text, and other information.
outputs = llm.generate(prompts, sampling_params)
# Print the outputs.
for output in outputs:
    prompt = output.prompt
    generated_text = output.outputs[0].text
    print(f"Prompt: {prompt!r},\n\nGenerated text: {generated_text!r}")
@@ -0,0 +1,62 @@
# SPDX-License-Identifier: Apache-2.0
"""
This example shows how to run offline inference with a speculative
decoding model on neuron.
"""

import os

from vllm import LLM, SamplingParams

# Sample prompts.
prompts = [
    "Hello, I am a language model and I can help",
    "The president of the United States is",
    "The capital of France is",
]


def config_buckets():
    """Configure context length and token gen buckets."""
    # creates XLA hlo graphs for all the context length buckets.
    os.environ['NEURON_CONTEXT_LENGTH_BUCKETS'] = "128,512,1024,2048"
    # creates XLA hlo graphs for all the token gen buckets.
    os.environ['NEURON_TOKEN_GEN_BUCKETS'] = "128,512,1024,2048"


def initialize_model():
    """Create an LLM with speculative decoding."""
    return LLM(
        model="openlm-research/open_llama_7b",
        speculative_model='openlm-research/open_llama_3b',
        num_speculative_tokens=4,
        max_num_seqs=4,
        max_model_len=2048,
        block_size=2048,
        speculative_max_model_len=2048,
        use_v2_block_manager=True,
        device="neuron",
        tensor_parallel_size=32,
    )


def process_requests(model: LLM, sampling_params: SamplingParams):
    """Generate texts from prompts and print them."""
    outputs = model.generate(prompts, sampling_params)
    for output in outputs:
        prompt = output.prompt
        generated_text = output.outputs[0].text
        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")


def main():
    """Main function that sets up the model and processes prompts."""
    config_buckets()
    model = initialize_model()
    # Create a sampling params object.
    sampling_params = SamplingParams(max_tokens=100, top_k=1)
    process_requests(model, sampling_params)


if __name__ == '__main__':
    main()
Similar to #8692.
I would propose staying consistent with the ecosystem and leveraging the interfaces described in https://docs.vllm.ai/en/latest/serving/distributed_serving.html.
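For reference, the multi-node flow that page describes starts a Ray cluster across the nodes and then launches a single server whose tensor-parallel and pipeline-parallel sizes together cover all devices. A rough sketch follows, with an illustrative model, address, and parallel sizes rather than values from this PR.

```bash
# Sketch of the documented Ray-based multi-node flow (illustrative values).
# On the head node:
ray start --head --port=6379
# On each worker node, pointing at the head node's IP:
ray start --address=<head-node-ip>:6379
# Then launch one server that spans the cluster:
vllm serve meta-llama/Meta-Llama-3.1-70B-Instruct \
    --tensor-parallel-size 8 \
    --pipeline-parallel-size 2
```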
Thanks for linking the PR and docs. I checked them out and read the associated comments. In summary, both of these require significant rework of the feature, and the previous PR was closed out because those pending items were not addressed. Is this a good understanding of the situation?
Yes, that's a good understanding.
I discussed with @mrinalks and we decided to not include multi-node support in this PR as this needs to be heavily reworked. I'll remove all the multi-node specific code in the following revision.