Merged
4 changes: 2 additions & 2 deletions benchmarks/backend_request_func.py
@@ -225,8 +225,8 @@ async def async_request_openai_completions(
) -> RequestFuncOutput:
api_url = request_func_input.api_url
assert api_url.endswith(
Collaborator:
Maybe instead of deleting it, we could add another condition for start_profile?
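
A minimal sketch of one way to read that suggestion (hypothetical; the merged change below instead widens the assertion to also accept "profile"):

    if api_url.endswith(("start_profile", "stop_profile")):
        pass  # profiler control request, so skip the completions-URL check
    else:
        assert api_url.endswith(
            "completions"
        ), "OpenAI Completions API URL must end with 'completions'."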

Collaborator:
I feel the naming of this function should be changed instead.

"completions"
), "OpenAI Completions API URL must end with 'completions'."
("completions", "profile")
), "OpenAI Completions API URL must end with 'completions' or 'profile'."

async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
assert not request_func_input.use_beam_search
43 changes: 43 additions & 0 deletions benchmarks/benchmark_serving.py
@@ -295,13 +295,15 @@ def calculate_metrics(
async def benchmark(
backend: str,
api_url: str,
base_url: str,
model_id: str,
tokenizer: PreTrainedTokenizerBase,
input_requests: List[Tuple[str, int, int]],
best_of: int,
use_beam_search: bool,
request_rate: float,
disable_tqdm: bool,
profile: bool,
):
if backend in ASYNC_REQUEST_FUNCS:
request_func = ASYNC_REQUEST_FUNCS[backend]
@@ -326,6 +328,22 @@ async def benchmark(
f"are correctly specified. Error: {test_output.error}")
else:
print("Initial test run completed. Starting main benchmark run...")

if profile:
print("Starting profiler...")
profile_input = RequestFuncInput(
model=model_id,
prompt=test_prompt,
api_url=base_url + "/start_profile",
prompt_len=test_prompt_len,
output_len=test_output_len,
best_of=best_of,
use_beam_search=use_beam_search,
)
profile_output = await request_func(request_func_input=profile_input)
if profile_output.success:
print("Profiler started")

print(f"Traffic request rate: {request_rate}")

pbar = None if disable_tqdm else tqdm(total=len(input_requests))
@@ -349,6 +367,21 @@ async def benchmark(
pbar=pbar)))
outputs: List[RequestFuncOutput] = await asyncio.gather(*tasks)

if profile:
print("Stopping profiler...")
profile_input = RequestFuncInput(
model=model_id,
prompt=test_prompt,
api_url=base_url + "/stop_profile",
prompt_len=test_prompt_len,
output_len=test_output_len,
best_of=best_of,
use_beam_search=use_beam_search,
)
profile_output = await request_func(request_func_input=profile_input)
if profile_output.success:
print("Profiler stopped")

if pbar is not None:
pbar.close()

@@ -433,8 +466,10 @@ def main(args: argparse.Namespace):

if args.base_url is not None:
api_url = f"{args.base_url}{args.endpoint}"
base_url = f"{args.base_url}"
else:
api_url = f"http://{args.host}:{args.port}{args.endpoint}"
base_url = f"http://{args.host}:{args.port}"

tokenizer = get_tokenizer(tokenizer_id,
trust_remote_code=args.trust_remote_code)
@@ -506,13 +541,15 @@ def main(args: argparse.Namespace):
benchmark(
backend=backend,
api_url=api_url,
base_url=base_url,
model_id=model_id,
tokenizer=tokenizer,
input_requests=input_requests,
best_of=args.best_of,
use_beam_search=args.use_beam_search,
request_rate=args.request_rate,
disable_tqdm=args.disable_tqdm,
profile=args.profile,
))

# Save config and results to json
@@ -693,6 +730,12 @@ def main(args: argparse.Namespace):
action="store_true",
help="Specify to disable tqdm progress bar.",
)
parser.add_argument(
"--profile",
action="store_true",
help="Use Torch Profiler. The endpoint must be launched with "
"VLLM_TORCH_PROFILER_DIR to enable profiler.",
)
parser.add_argument(
"--save-result",
action="store_true",
33 changes: 33 additions & 0 deletions docs/source/dev/profiling/profiling_index.rst
@@ -0,0 +1,33 @@
Profiling vLLM
=================================

We support tracing vLLM workers using the ``torch.profiler`` module. You can enable tracing by setting the ``VLLM_TORCH_PROFILER_DIR`` environment variable to the directory where you want to save the traces: ``VLLM_TORCH_PROFILER_DIR=/mnt/traces/``
Collaborator:
Suggest using a more common path as an example, such as $HOME/traces/ or /tmp/traces.


The OpenAI server also needs to be started with the ``VLLM_TORCH_PROFILER_DIR`` environment variable set.
Collaborator:
I'd suggest covering offline batching as well. It should be even easier: just set the environment variable before creating the engine.
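
A rough sketch of what that offline flow could look like (hypothetical; this PR only wires ``start_profile``/``stop_profile`` into the async engine, so the sketch calls the same worker hook through the executor, mirroring the code added below):

.. code-block:: python

    import os

    # The environment variable must be set before the engine (and its
    # workers) are created.
    os.environ["VLLM_TORCH_PROFILER_DIR"] = "/tmp/traces"

    from vllm import LLM, SamplingParams

    llm = LLM(model="meta-llama/Meta-Llama-3-70B")

    # Mirror the async-engine methods added in this PR by invoking the
    # worker hook through the executor; the offline LLM class does not
    # expose start_profile/stop_profile in this diff.
    llm.llm_engine.model_executor._run_workers("start_profile")
    outputs = llm.generate(["Hello, my name is"],
                           SamplingParams(max_tokens=32))
    llm.llm_engine.model_executor._run_workers("stop_profile")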


When using ``benchmarks/benchmark_serving.py``, you can enable profiling by passing the ``--profile`` flag.

.. warning::

Only enable profiling in a development environment.


Traces can be visualized using https://ui.perfetto.dev/.

.. tip::

Only send a few requests through vLLM when profiling, as the traces can get quite large. Also, there is no need to decompress the traces; they can be viewed directly.

Example commands:

OpenAI Server:

.. code-block:: bash

VLLM_TORCH_PROFILER_DIR=/mnt/traces/ python -m vllm.entrypoints.openai.api_server --model meta-llama/Meta-Llama-3-70B

benchmark_serving.py:

.. code-block:: bash

python benchmarks/benchmark_serving.py --backend vllm --model meta-llama/Meta-Llama-3-70B --dataset-name sharegpt --dataset-path sharegpt.json --profile --num-prompts 2
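
The profiler can also be toggled directly through the ``/start_profile`` and ``/stop_profile`` endpoints added in this PR. A sketch, assuming the server above is listening on ``localhost:8000``:

.. code-block:: python

    import urllib.request

    def post(path: str) -> int:
        """Send an empty POST to the running OpenAI-compatible server."""
        req = urllib.request.Request(f"http://localhost:8000{path}",
                                     method="POST")
        with urllib.request.urlopen(req) as resp:
            return resp.status

    post("/start_profile")  # begin tracing
    # ... send a few requests through the server ...
    post("/stop_profile")   # stop tracing; traces land in VLLM_TORCH_PROFILER_DIR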
1 change: 1 addition & 0 deletions docs/source/index.rst
@@ -136,6 +136,7 @@ Documentation
dev/input_processing/model_inputs_index
dev/multimodal/multimodal_index
dev/dockerfile/dockerfile
dev/profiling/profiling_index

.. toctree::
:maxdepth: 1
6 changes: 6 additions & 0 deletions vllm/engine/async_llm_engine.py
@@ -1266,3 +1266,9 @@ def remove_logger(self, logger_name: str) -> None:
logger_name=logger_name))
else:
self.engine.remove_logger(logger_name=logger_name)

async def start_profile(self) -> None:
self.engine.model_executor._run_workers("start_profile")

async def stop_profile(self) -> None:
self.engine.model_executor._run_workers("stop_profile")
8 changes: 8 additions & 0 deletions vllm/engine/protocol.py
@@ -91,3 +91,11 @@ async def do_log_stats(
async def check_health(self) -> None:
"""Raise if unhealthy"""
...

async def start_profile(self) -> None:
"""Start profiling the engine"""
...

async def stop_profile(self) -> None:
"""Stop profiling the engine"""
...
20 changes: 20 additions & 0 deletions vllm/entrypoints/openai/api_server.py
@@ -294,6 +294,26 @@ async def create_embedding(request: EmbeddingRequest, raw_request: Request):
assert_never(generator)


if envs.VLLM_TORCH_PROFILER_DIR:
logger.warning(
"Torch Profiler is enabled in the API server. This should ONLY be "
"used for local development!")

@router.post("/start_profile")
async def start_profile():
logger.info("Starting profiler...")
await async_engine_client.start_profile()
logger.info("Profiler started.")
return Response(status_code=200)

@router.post("/stop_profile")
async def stop_profile():
logger.info("Stopping profiler...")
await async_engine_client.stop_profile()
logger.info("Profiler stopped.")
return Response(status_code=200)


def build_app(args: Namespace) -> FastAPI:
app = FastAPI(lifespan=lifespan)
app.include_router(router)
2 changes: 2 additions & 0 deletions vllm/entrypoints/openai/rpc/__init__.py
@@ -46,6 +46,8 @@ class RPCUtilityRequest(Enum):
DO_LOG_STATS = 7
IS_SERVER_HEALTHY = 8
IS_TRACING_ENABLED = 9
START_PROFILE = 10
STOP_PROFILE = 11


RPC_REQUEST_TYPE = Union[RPCGenerateRequest, RPCAbortRequest,
14 changes: 14 additions & 0 deletions vllm/entrypoints/openai/rpc/client.py
@@ -400,3 +400,17 @@ async def encode(self, *args,
**kwargs) -> AsyncGenerator[EmbeddingRequestOutput, None]:
raise NotImplementedError(
"Embeddings not supported with multiprocessing backend")

async def start_profile(self) -> None:
"""Start profiling the engine"""

await self._send_one_way_rpc_request(
request=RPCUtilityRequest.START_PROFILE,
error_message="RPCRequest START_PROFILE failed.")

async def stop_profile(self) -> None:
"""Stop profiling the engine"""

await self._send_one_way_rpc_request(
request=RPCUtilityRequest.STOP_PROFILE,
error_message="RPCRequest STOP_PROFILE failed.")
24 changes: 24 additions & 0 deletions vllm/entrypoints/openai/rpc/server.py
@@ -124,6 +124,26 @@ async def check_health(self, identity):
except Exception as e:
await self.socket.send_multipart([identity, cloudpickle.dumps(e)])

async def start_profile(self, identity):
logger.info("Starting profiler...")
await self.engine.start_profile()
logger.info("Profiler started.")

await self.socket.send_multipart([
identity,
cloudpickle.dumps(VLLM_RPC_SUCCESS_STR),
])

async def stop_profile(self, identity):
logger.info("Stopping profiler...")
await self.engine.stop_profile()
logger.info("Profiler stopped.")

await self.socket.send_multipart([
identity,
cloudpickle.dumps(VLLM_RPC_SUCCESS_STR),
])

def _make_handler_coro(self, identity,
message) -> Coroutine[Any, Any, Never]:
"""Route the zmq message to the handler coroutine."""
@@ -153,6 +173,10 @@ def _make_handler_coro(self, identity,
return self.check_health(identity)
elif request == RPCUtilityRequest.IS_TRACING_ENABLED:
return self.is_tracing_enabled(identity)
elif request == RPCUtilityRequest.START_PROFILE:
return self.start_profile(identity)
elif request == RPCUtilityRequest.STOP_PROFILE:
return self.stop_profile(identity)
else:
raise ValueError(f"Unknown RPCUtilityRequest type: {request}")

7 changes: 7 additions & 0 deletions vllm/envs.py
@@ -58,6 +58,7 @@
VLLM_TEST_FORCE_FP8_MARLIN: bool = False
VLLM_ALLOW_ENGINE_USE_RAY: bool = False
VLLM_PLUGINS: Optional[List[str]] = None
VLLM_TORCH_PROFILER_DIR: Optional[str] = None


def get_default_cache_root():
@@ -384,6 +385,12 @@ def get_default_config_root():
"VLLM_PLUGINS":
lambda: None if "VLLM_PLUGINS" not in os.environ else os.environ[
"VLLM_PLUGINS"].split(","),

# Enables torch profiler if set. Path to the directory where torch profiler
# traces are saved. Note that it must be an absolute path.
"VLLM_TORCH_PROFILER_DIR":
lambda: (None if os.getenv("VLLM_TORCH_PROFILER_DIR", None) is None else os
.path.expanduser(os.getenv("VLLM_TORCH_PROFILER_DIR", "."))),
}

# end-env-vars-definition
31 changes: 31 additions & 0 deletions vllm/worker/worker.py
@@ -6,13 +6,15 @@
import torch
import torch.distributed

import vllm.envs as envs
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
ModelConfig, ObservabilityConfig, ParallelConfig,
PromptAdapterConfig, SchedulerConfig,
SpeculativeConfig)
from vllm.distributed import (ensure_model_parallel_initialized,
init_distributed_environment,
set_custom_all_reduce)
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.model_executor import set_random_seed
from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
@@ -27,6 +29,8 @@
from vllm.worker.model_runner import GPUModelRunnerBase, ModelRunner
from vllm.worker.worker_base import LocalOrDistributedWorkerBase, WorkerInput

logger = init_logger(__name__)


class Worker(LocalOrDistributedWorkerBase):
"""A worker class that executes (a partition of) the model on a GPU.
@@ -113,6 +117,33 @@ def __init__(
self.gpu_cache: Optional[List[List[torch.Tensor]]] = None
self._seq_group_metadata_cache: Dict[str, SequenceGroupMetadata] = {}

# Torch profiler. Enabled and configured through env vars:
# VLLM_TORCH_PROFILER_DIR=/path/to/save/trace
if envs.VLLM_TORCH_PROFILER_DIR:
torch_profiler_trace_dir = envs.VLLM_TORCH_PROFILER_DIR
logger.info("Profiling enabled. Traces will be saved to: %s",
torch_profiler_trace_dir)
self.profiler = torch.profiler.profile(
activities=[
torch.profiler.ProfilerActivity.CPU,
torch.profiler.ProfilerActivity.CUDA,
],
with_stack=True,
on_trace_ready=torch.profiler.tensorboard_trace_handler(
torch_profiler_trace_dir, use_gzip=True))
else:
self.profiler = None

def start_profile(self):
if self.profiler is None:
raise RuntimeError("Profiler is not enabled.")
self.profiler.start()

def stop_profile(self):
if self.profiler is None:
raise RuntimeError("Profiler is not enabled.")
self.profiler.stop()

def _is_encoder_decoder_model(self):
return self.model_config.is_encoder_decoder_model
