diff --git a/python/ray/llm/_internal/serve/configs/openai_api_models.py b/python/ray/llm/_internal/serve/configs/openai_api_models.py index 0936abb9589b..98c3b9d491ce 100644 --- a/python/ray/llm/_internal/serve/configs/openai_api_models.py +++ b/python/ray/llm/_internal/serve/configs/openai_api_models.py @@ -1,719 +1,70 @@ -""" -Note (genesu): majority of this file is adapted from -- https://github.com/vllm-project/vllm/blob/5095e966069b9e65b7c4c63427e06cebacaad0a0/vllm/entrypoints/openai/protocol.py -- https://github.com/vllm-project/vllm/blob/5095e966069b9e65b7c4c63427e06cebacaad0a0/vllm/entrypoints/chat_utils.py -- https://github.com/openai/openai-python/tree/2e56c8da6f163db00a4ca362020148bb391edca9/src/openai/types/chat - -We patched `ErrorResponse` and `ResponseFormat` to be slightly different from the -original source. -""" - - -import time -from argparse import Namespace -from typing import ( - Any, - AsyncGenerator, - Dict, - Iterable, - List, - Literal, - Optional, - TypeVar, - Union, -) +from typing import Union, AsyncGenerator, Optional, Dict, Any, List from pydantic import ( BaseModel, - Field, - model_validator, -) -from typing_extensions import Annotated, Required, TypeAlias, TypedDict - -from ray.llm._internal.serve.configs.openai_api_models_patch import ( - ErrorResponse, - ResponseFormatType as ResponseFormat, -) -from ray.llm._internal.serve.configs.server_models import ( - LLMConfig, - LLMRawResponse, - ModelData, -) -from ray.serve._private.utils import ( - generate_request_id, + ConfigDict, ) -# openai.types.chat aliases. -# We use aliases becasuse openai.types.chat is not installed in the docs build. -# This is a hack to make the docs build pass. -ChatCompletionContentPartInputAudioParam = TypeVar( - "ChatCompletionContentPartInputAudioParam", bound=Any +from vllm.entrypoints.openai.protocol import ( + ChatCompletionRequest as vLLMChatCompletionRequest, + ChatCompletionResponse as vLLMChatCompletionResponse, + ChatCompletionStreamResponse as vLLMChatCompletionStreamResponse, + ErrorResponse as vLLMErrorResponse, + CompletionRequest as vLLMCompletionRequest, + CompletionResponse as vLLMCompletionResponse, + CompletionStreamResponse as vLLMCompletionStreamResponse, + EmbeddingCompletionRequest as vLLMEmbeddingCompletionRequest, + EmbeddingChatRequest as vLLMEmbeddingChatRequest, + EmbeddingResponse as vLLMEmbeddingResponse, ) -ChatCompletionContentPartRefusalParam = TypeVar( - "ChatCompletionContentPartRefusalParam", bound=Any -) -ChatCompletionMessageToolCallParam = TypeVar( - "ChatCompletionMessageToolCallParam", bound=Any -) -OpenAIChatCompletionContentPartParam = TypeVar( - "OpenAIChatCompletionContentPartParam", bound=Any -) - -_LONG_INFO = Namespace(min=-9223372036854775808, max=9223372036854775807) - - -class AudioURL(TypedDict, total=False): - url: Required[str] - """ - Either a URL of the audio or a data URL with base64 encoded audio data. - """ - - -class ChatCompletionContentPartAudioParam(TypedDict, total=False): - audio_url: Required[AudioURL] - - type: Required[Literal["audio_url"]] - """The type of the content part.""" - - -class VideoURL(TypedDict, total=False): - url: Required[str] - """ - Either a URL of the video or a data URL with base64 encoded video data. 
- """ - - -class ChatCompletionContentPartVideoParam(TypedDict, total=False): - video_url: Required[VideoURL] - - type: Required[Literal["video_url"]] - """The type of the content part.""" - - -class CustomChatCompletionContentSimpleImageParam(TypedDict, total=False): - """A simpler version of the param that only accepts a plain image_url. - This is supported by OpenAI API, although it is not documented. - - Example: - { - "image_url": "https://example.com/image.jpg" - } - """ - - image_url: Required[str] - - -class CustomChatCompletionContentSimpleAudioParam(TypedDict, total=False): - """A simpler version of the param that only accepts a plain audio_url. - - Example: - { - "audio_url": "https://example.com/audio.mp3" - } - """ - - audio_url: Required[str] - - -class CustomChatCompletionContentSimpleVideoParam(TypedDict, total=False): - """A simpler version of the param that only accepts a plain audio_url. - - Example: - { - "video_url": "https://example.com/video.mp4" - } - """ - - video_url: Required[str] - - -# Ref: https://huggingface.co/mistral-community/pixtral-12b -# -# Community version of pixtral uses the key `content` instead of `text` in the content. -# This is to support the "content" content type in the prompt format, as opposite of -# the "text" content from the above which most other model uses. -class ChatCompletionContentPartContentParam(TypedDict, total=False): - content: Required[str] - """The content content.""" - - type: Required[Literal["text"]] - """The type of the content part.""" - - -ChatCompletionContentPartParam: TypeAlias = Union[ - OpenAIChatCompletionContentPartParam, - ChatCompletionContentPartAudioParam, - ChatCompletionContentPartInputAudioParam, - ChatCompletionContentPartVideoParam, - ChatCompletionContentPartRefusalParam, - CustomChatCompletionContentSimpleImageParam, - CustomChatCompletionContentSimpleAudioParam, - CustomChatCompletionContentSimpleVideoParam, - str, -] - - -class ChatCompletionMessageParam(TypedDict, total=False): - """Enables custom roles in the Chat Completion API.""" - - role: Required[str] - """The role of the message's author.""" - - content: Union[str, List[ChatCompletionContentPartParam]] - """The contents of the message.""" - - name: str - """An optional name for the participant. - - Provides the model information to differentiate between participants of the - same role. 
- """ - - tool_call_id: Optional[str] - """Tool call that this message is responding to.""" - - tool_calls: Optional[Iterable[ChatCompletionMessageToolCallParam]] - """The tool calls generated by the model, such as function calls.""" - - -class StreamOptions(BaseModel): - include_usage: Optional[bool] = True - continuous_usage_stats: Optional[bool] = False - - -class FunctionDefinition(BaseModel): - name: str - description: Optional[str] = None - parameters: Optional[Dict[str, Any]] = None - - -class ChatCompletionToolsParam(BaseModel): - type: Literal["function"] = "function" - function: FunctionDefinition - - -class ChatCompletionNamedFunction(BaseModel): - name: str - -class ChatCompletionNamedToolChoiceParam(BaseModel): - function: ChatCompletionNamedFunction - type: Literal["function"] = "function" +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from ray.llm._internal.serve.configs.server_models import LLMConfig -class LogitsProcessorConstructor(BaseModel): - qualname: str - args: Optional[List[Any]] = None - kwargs: Optional[Dict[str, Any]] = None +class ChatCompletionRequest(vLLMChatCompletionRequest): + pass -LogitsProcessors = List[Union[str, LogitsProcessorConstructor]] +class ChatCompletionResponse(vLLMChatCompletionResponse): + pass -class ChatCompletionRequest(BaseModel): - # Ordered by official OpenAI API documentation - # https://platform.openai.com/docs/api-reference/chat/create - messages: Annotated[List[ChatCompletionMessageParam], Field(min_length=1)] - model: str - frequency_penalty: Optional[float] = 0.0 - logit_bias: Optional[Dict[str, float]] = None - logprobs: Optional[bool] = False - top_logprobs: Optional[int] = 0 - # TODO(#9845): remove max_tokens when field is removed from OpenAI API - max_tokens: Optional[int] = Field( - default=None, - deprecated="max_tokens is deprecated in favor of the max_completion_tokens field", - ) - max_completion_tokens: Optional[int] = None - n: Optional[int] = 1 - presence_penalty: Optional[float] = 0.0 - response_format: Optional[ResponseFormat] = None - seed: Optional[int] = Field(None, ge=_LONG_INFO.min, le=_LONG_INFO.max) - stop: Optional[Union[str, List[str]]] = Field(default_factory=list) - stream: Optional[bool] = False - stream_options: Optional[StreamOptions] = None - temperature: Optional[float] = None - top_p: Optional[float] = None - tools: Optional[List[ChatCompletionToolsParam]] = None - tool_choice: Optional[ - Union[Literal["none"], Literal["auto"], ChatCompletionNamedToolChoiceParam] - ] = "none" - - # NOTE this will be ignored by vLLM -- the model determines the behavior - parallel_tool_calls: Optional[bool] = False - user: Optional[str] = None - - # doc: begin-chat-completion-sampling-params - best_of: Optional[int] = None - use_beam_search: bool = False - top_k: Optional[int] = None - min_p: Optional[float] = None - repetition_penalty: Optional[float] = None - length_penalty: float = 1.0 - stop_token_ids: Optional[List[int]] = Field(default_factory=list) - include_stop_str_in_output: bool = False - ignore_eos: bool = False - min_tokens: int = 0 - skip_special_tokens: bool = True - spaces_between_special_tokens: bool = True - truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None - prompt_logprobs: Optional[int] = None - # doc: end-chat-completion-sampling-params - - # doc: begin-chat-completion-extra-params - echo: bool = Field( - default=False, - description=( - "If true, the new message will be prepended with the last message " - "if they belong to the same role." 
- ), - ) - add_generation_prompt: bool = Field( - default=True, - description=( - "If true, the generation prompt will be added to the chat template. " - "This is a parameter used by chat template in tokenizer config of the " - "model." - ), - ) - continue_final_message: bool = Field( - default=False, - description=( - "If this is set, the chat will be formatted so that the final " - "message in the chat is open-ended, without any EOS tokens. The " - "model will continue this message rather than starting a new one. " - 'This allows you to "prefill" part of the model\'s response for it. ' - "Cannot be used at the same time as `add_generation_prompt`." - ), - ) - add_special_tokens: bool = Field( - default=False, - description=( - "If true, special tokens (e.g. BOS) will be added to the prompt " - "on top of what is added by the chat template. " - "For most models, the chat template takes care of adding the " - "special tokens so this should be set to false (as is the " - "default)." - ), - ) - documents: Optional[List[Dict[str, str]]] = Field( - default=None, - description=( - "A list of dicts representing documents that will be accessible to " - "the model if it is performing RAG (retrieval-augmented generation)." - " If the template does not support RAG, this argument will have no " - "effect. We recommend that each document should be a dict containing " - '"title" and "text" keys.' - ), - ) - chat_template: Optional[str] = Field( - default=None, - description=( - "A Jinja template to use for this conversion. " - "As of transformers v4.44, default chat template is no longer " - "allowed, so you must provide a chat template if the tokenizer " - "does not define one." - ), - ) - chat_template_kwargs: Optional[Dict[str, Any]] = Field( - default=None, - description=( - "Additional kwargs to pass to the template renderer. " - "Will be accessible by the chat template." - ), - ) - guided_json: Optional[Union[str, dict, BaseModel]] = Field( - default=None, - description=("If specified, the output will follow the JSON schema."), - ) - guided_regex: Optional[str] = Field( - default=None, - description=("If specified, the output will follow the regex pattern."), - ) - guided_choice: Optional[List[str]] = Field( - default=None, - description=("If specified, the output will be exactly one of the choices."), - ) - guided_grammar: Optional[str] = Field( - default=None, - description=("If specified, the output will follow the context free grammar."), - ) - guided_decoding_backend: Optional[str] = Field( - default=None, - description=( - "If specified, will override the default guided decoding backend " - "of the server for this specific request. If set, must be either " - "'outlines' / 'lm-format-enforcer'" - ), - ) - guided_whitespace_pattern: Optional[str] = Field( - default=None, - description=( - "If specified, will override the default whitespace pattern " - "for guided json decoding." - ), - ) - priority: int = Field( - default=0, - description=( - "The priority of the request (lower means earlier handling; " - "default: 0). Any priority other than 0 will raise an error " - "if the served model does not use priority scheduling." - ), - ) - request_id: str = Field( - default_factory=lambda: f"{generate_request_id()}", - description=( - "The request_id related to this request. If the caller does " - "not set it, a generate_request_id will be generated. This id is used " - "through out the inference process and return in response." 
- ), - ) - logits_processors: Optional[LogitsProcessors] = Field( - default=None, - description=( - "A list of either qualified names of logits processors, or " - "constructor objects, to apply when sampling. A constructor is " - "a JSON object with a required 'qualname' field specifying the " - "qualified name of the processor class/factory, and optional " - "'args' and 'kwargs' fields containing positional and keyword " - "arguments. For example: {'qualname': " - "'my_module.MyLogitsProcessor', 'args': [1, 2], 'kwargs': " - "{'param': 'value'}}." - ), - ) - - # doc: end-chat-completion-extra-params - - -class CompletionRequest(BaseModel): - # Ordered by official OpenAI API documentation - # https://platform.openai.com/docs/api-reference/completions/create - model: str - prompt: Union[List[int], List[List[int]], str, List[str]] - best_of: Optional[int] = None - echo: Optional[bool] = False - frequency_penalty: Optional[float] = 0.0 - logit_bias: Optional[Dict[str, float]] = None - logprobs: Optional[int] = None - max_tokens: Optional[int] = 16 - n: int = 1 - presence_penalty: Optional[float] = 0.0 - seed: Optional[int] = Field(None, ge=_LONG_INFO.min, le=_LONG_INFO.max) - stop: Optional[Union[str, List[str]]] = Field(default_factory=list) - stream: Optional[bool] = False - stream_options: Optional[StreamOptions] = None - suffix: Optional[str] = None - temperature: Optional[float] = None - top_p: Optional[float] = None - user: Optional[str] = None - - # doc: begin-completion-sampling-params - use_beam_search: bool = False - top_k: Optional[int] = None - min_p: Optional[float] = None - repetition_penalty: Optional[float] = None - length_penalty: float = 1.0 - stop_token_ids: Optional[List[int]] = Field(default_factory=list) - include_stop_str_in_output: bool = False - ignore_eos: bool = False - min_tokens: int = 0 - skip_special_tokens: bool = True - spaces_between_special_tokens: bool = True - truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None - allowed_token_ids: Optional[List[int]] = None - prompt_logprobs: Optional[int] = None - # doc: end-completion-sampling-params - - # doc: begin-completion-extra-params - add_special_tokens: bool = Field( - default=True, - description=( - "If true (the default), special tokens (e.g. BOS) will be added to " - "the prompt." - ), - ) - response_format: Optional[ResponseFormat] = Field( - default=None, - description=( - "Similar to chat completion, this parameter specifies the format of " - "output. Only {'type': 'json_object'}, {'type': 'json_schema'} or " - "{'type': 'text' } is supported." - ), - ) - guided_json: Optional[Union[str, dict, BaseModel]] = Field( - default=None, - description="If specified, the output will follow the JSON schema.", - ) - guided_regex: Optional[str] = Field( - default=None, - description=("If specified, the output will follow the regex pattern."), - ) - guided_choice: Optional[List[str]] = Field( - default=None, - description=("If specified, the output will be exactly one of the choices."), - ) - guided_grammar: Optional[str] = Field( - default=None, - description=("If specified, the output will follow the context free grammar."), - ) - guided_decoding_backend: Optional[str] = Field( - default=None, - description=( - "If specified, will override the default guided decoding backend " - "of the server for this specific request. 
If set, must be one of " - "'outlines' / 'lm-format-enforcer'" - ), - ) - guided_whitespace_pattern: Optional[str] = Field( - default=None, - description=( - "If specified, will override the default whitespace pattern " - "for guided json decoding." - ), - ) - priority: int = Field( - default=0, - description=( - "The priority of the request (lower means earlier handling; " - "default: 0). Any priority other than 0 will raise an error " - "if the served model does not use priority scheduling." - ), - ) - logits_processors: Optional[LogitsProcessors] = Field( - default=None, - description=( - "A list of either qualified names of logits processors, or " - "constructor objects, to apply when sampling. A constructor is " - "a JSON object with a required 'qualname' field specifying the " - "qualified name of the processor class/factory, and optional " - "'args' and 'kwargs' fields containing positional and keyword " - "arguments. For example: {'qualname': " - "'my_module.MyLogitsProcessor', 'args': [1, 2], 'kwargs': " - "{'param': 'value'}}." - ), - ) - - # doc: end-completion-extra-params - - -class FunctionCall(BaseModel): - name: str - arguments: str - - -class ToolCall(BaseModel): - id: str = Field(default_factory=lambda: f"chatcmpl-tool-{generate_request_id()}") - type: Literal["function"] = "function" - function: FunctionCall +class ChatCompletionStreamResponse(vLLMChatCompletionStreamResponse): + pass -class ChatMessage(BaseModel): - role: str - reasoning_content: Optional[str] = None - content: Optional[str] = None - tool_calls: List[ToolCall] = Field(default_factory=list) +class ErrorResponse(vLLMErrorResponse): + pass -class ChatCompletionLogProb(BaseModel): - token: str - logprob: float = -9999.0 - bytes: Optional[List[int]] = None +class CompletionRequest(vLLMCompletionRequest): + pass -class ChatCompletionLogProbsContent(ChatCompletionLogProb): - top_logprobs: List[ChatCompletionLogProb] = Field(default_factory=list) +class CompletionResponse(vLLMCompletionResponse): + pass -class ChatCompletionLogProbs(BaseModel): - content: Optional[List[ChatCompletionLogProbsContent]] = None +class CompletionStreamResponse(vLLMCompletionStreamResponse): + pass -class ChatCompletionResponseChoice(BaseModel): - index: int - message: ChatMessage - logprobs: Optional[ChatCompletionLogProbs] = None - # per OpenAI spec this is the default - finish_reason: Optional[str] = "stop" - # not part of the OpenAI spec but included in vLLM for legacy reasons - stop_reason: Optional[Union[int, str]] = None +class EmbeddingCompletionRequest(vLLMEmbeddingCompletionRequest): + pass -class DeltaFunctionCall(BaseModel): - name: Optional[str] = None - arguments: Optional[str] = None +class EmbeddingChatRequest(vLLMEmbeddingChatRequest): + pass -class DeltaToolCall(BaseModel): - id: str = Field(default_factory=lambda: f"chatcmpl-tool-{generate_request_id()}") - type: Literal["function"] = "function" - index: int - function: Optional[DeltaFunctionCall] = None +class EmbeddingResponse(vLLMEmbeddingResponse): + pass -class DeltaMessage(BaseModel): - role: Optional[str] = None - content: Optional[str] = None - reasoning_content: Optional[str] = None - tool_calls: List[DeltaToolCall] = Field(default_factory=list) - - @model_validator(mode="after") - def _non_null_content(self): - self.content = self.content or "" - return self - - -class ChatCompletionResponseStreamChoice(BaseModel): - index: int - delta: DeltaMessage - logprobs: Optional[ChatCompletionLogProbs] = None - finish_reason: Optional[str] = None - stop_reason: 
Optional[Union[int, str]] = None - - -class PromptTokenUsageInfo(BaseModel): - cached_tokens: Optional[int] = None - - -class UsageInfo(BaseModel): - prompt_tokens: int = 0 - total_tokens: int = 0 - completion_tokens: Optional[int] = 0 - prompt_tokens_details: Optional[PromptTokenUsageInfo] = None - - -class Logprob(BaseModel): - """Infos for supporting OpenAI compatible logprobs and token ranks. - - Attributes: - logprob: The logprob of chosen token - rank: The vocab rank of chosen token (>=1) - decoded_token: The decoded chosen token index - """ - - logprob: float - rank: Optional[int] = None - decoded_token: Optional[str] = None - - -class ChatCompletionStreamResponse(BaseModel): - id: str = Field(default_factory=lambda: f"chatcmpl-{generate_request_id()}") - object: Literal["chat.completion.chunk"] = "chat.completion.chunk" - created: int = Field(default_factory=lambda: int(time.time())) - model: str - choices: List[ChatCompletionResponseStreamChoice] - usage: Optional[UsageInfo] = Field(default=None) - - -class ChatCompletionResponse(BaseModel): - id: str = Field(default_factory=lambda: f"chatcmpl-{generate_request_id()}") - object: Literal["chat.completion"] = "chat.completion" - created: int = Field(default_factory=lambda: int(time.time())) - model: str - choices: List[ChatCompletionResponseChoice] - usage: UsageInfo - prompt_logprobs: Optional[List[Optional[Dict[int, Logprob]]]] = None - - -class CompletionLogProbs(BaseModel): - text_offset: List[int] = Field(default_factory=list) - token_logprobs: List[Optional[float]] = Field(default_factory=list) - tokens: List[str] = Field(default_factory=list) - top_logprobs: List[Optional[Dict[str, float]]] = Field(default_factory=list) - - -class CompletionResponseChoice(BaseModel): - index: int - text: str - logprobs: Optional[CompletionLogProbs] = None - finish_reason: Optional[str] = None - stop_reason: Optional[Union[int, str]] = Field( - default=None, - description=( - "The stop string or token id that caused the completion " - "to stop, None if the completion finished for some other reason " - "including encountering the EOS token" - ), - ) - prompt_logprobs: Optional[List[Optional[Dict[int, Logprob]]]] = None - - -class CompletionResponse(BaseModel): - id: str = Field(default_factory=lambda: f"cmpl-{generate_request_id()}") - object: str = "text_completion" - created: int = Field(default_factory=lambda: int(time.time())) - model: str - choices: List[CompletionResponseChoice] - usage: UsageInfo - - -class CompletionResponseStreamChoice(BaseModel): - index: int - text: str - logprobs: Optional[CompletionLogProbs] = None - finish_reason: Optional[str] = None - stop_reason: Optional[Union[int, str]] = Field( - default=None, - description=( - "The stop string or token id that caused the completion " - "to stop, None if the completion finished for some other reason " - "including encountering the EOS token" - ), - ) - - -class CompletionStreamResponse(BaseModel): - id: str = Field(default_factory=lambda: f"cmpl-{generate_request_id()}") - object: str = "text_completion" - created: int = Field(default_factory=lambda: int(time.time())) - model: str - choices: List[CompletionResponseStreamChoice] - usage: Optional[UsageInfo] = Field(default=None) - - -class EmbeddingCompletionRequest(BaseModel): - model: Optional[str] = None - input: Union[List[int], List[List[int]], str, List[str]] - encoding_format: Literal["float", "base64"] = "float" - dimensions: Optional[int] = None - user: Optional[str] = None - truncate_prompt_tokens: 
Optional[Annotated[int, Field(ge=1)]] = None - - additional_data: Optional[Any] = None - add_special_tokens: bool = Field( - default=True, - description=( - "If true (the default), special tokens (e.g. BOS) will be added to " - "the prompt." - ), - ) - priority: int = Field( - default=0, - description=( - "The priority of the request (lower means earlier handling; " - "default: 0). Any priority other than 0 will raise an error " - "if the served model does not use priority scheduling." - ), - ) - - -EmbeddingRequest = EmbeddingCompletionRequest - - -class EmbeddingResponseData(BaseModel): - index: int - object: str = "embedding" - embedding: Union[List[float], str] - - -class EmbeddingResponse(BaseModel): - id: str = Field(default_factory=lambda: f"embd-{generate_request_id()}") - object: str = "list" - created: int = Field(default_factory=lambda: int(time.time())) - model: str - data: List[EmbeddingResponseData] - usage: UsageInfo +EmbeddingRequest = Union[EmbeddingCompletionRequest, EmbeddingChatRequest] LLMEmbeddingsResponse = Union[ AsyncGenerator[Union[EmbeddingResponse, ErrorResponse], None], @@ -731,7 +82,7 @@ class EmbeddingResponse(BaseModel): ], ] - +# TODO: remove this class class OpenAIHTTPException(Exception): def __init__( self, @@ -745,21 +96,33 @@ def __init__( self.type = type self.internal_message = internal_message - @classmethod - def from_model_response(cls, response: LLMRawResponse) -> "OpenAIHTTPException": - return cls( - status_code=response.error.code, - message=response.error.message, - type=response.error.type, - internal_message=response.error.internal_message, - ) + +# TODO: upstream metadata for ModelData +# Compared to vLLM this has a metadata field. +class ModelCard(BaseModel): + model_config = ConfigDict(protected_namespaces=tuple()) + + id: str + object: str + owned_by: str + permission: List[str] + metadata: Dict[str, Any] + + @property + def model_type(self) -> str: + return self.metadata["engine_config"]["model_type"] + + +class ModelList(BaseModel): + data: List[ModelCard] + object: str = "list" def to_model_metadata( model_id: str, - model_config: LLMConfig, + model_config: "LLMConfig", overrides: Optional[Dict[str, Any]] = None, -): +) -> ModelCard: """Creates an OpenAI-compatible ModelData object. Args: @@ -779,10 +142,10 @@ def to_model_metadata( if overrides: metadata.update(overrides) - return ModelData( + return ModelCard( id=model_id, - rayllm_metadata=metadata, object="model", owned_by="organization-owner", permission=[], + metadata=metadata, ) diff --git a/python/ray/llm/_internal/serve/configs/prompt_formats.py b/python/ray/llm/_internal/serve/configs/prompt_formats.py deleted file mode 100644 index fe1bdf47527b..000000000000 --- a/python/ray/llm/_internal/serve/configs/prompt_formats.py +++ /dev/null @@ -1,119 +0,0 @@ -from typing import ( - Any, - Dict, - List, - Literal, - Optional, - Union, -) - -from pydantic import ( - BaseModel, - field_validator, - model_validator, -) - -from ray.llm._internal.common.utils.import_utils import try_import - -transformers = try_import("transformers") - - -class Text(BaseModel): - type: str = "text" - text: str - - -# Ref: https://huggingface.co/mistral-community/pixtral-12b -# -# Community version of pixtral uses the key `content` instead of `text` in the content. -# This is to support the "content" content type in the prompt format, as opposite of -# the "text" content from the above which most other model uses. 
-class Content(BaseModel): - type: str = "text" - content: str - - -class Image(BaseModel): - type: str = "image_url" - image_url: Dict - - @field_validator("image_url") - @classmethod - def check_image_url(cls, value): - """Checks if the image_url is a dict with a 'url' key. - Example: - image_url = { - "url": "https://example.com/image.png" - } - """ - if "url" not in value or not value["url"] or not isinstance(value["url"], str): - raise ValueError( - # TODO(xwjiang): Link to doc. - "Expecting 'url' string to be provided under 'image_url' dict." - ) - return value - - -ContentList = List[Union[Image, Text, Content]] - - -class Message(BaseModel): - role: Literal["system", "assistant", "user"] - content: Optional[Union[str, ContentList]] = None - - def __str__(self): - return self.model_dump_json() - - @model_validator(mode="after") - def check_fields(self): - if self.role == "system": - if not isinstance(self.content, str): - raise ValueError("System content must be a string") - if self.role == "user" and self.content is None: - raise ValueError("User content must not be None.") - if self.role == "assistant": - # passing a regular assistant message - if self.content is not None and not isinstance(self.content, str): - raise ValueError("content must be a string or None") - return self - - -class Prompt(BaseModel): - prompt: Union[str, List[Message]] - use_prompt_format: bool = True - parameters: Optional[Dict[str, Any]] = None - - @field_validator("parameters", mode="before") - @classmethod - def parse_parameters(cls, value): - if isinstance(value, BaseModel): - # Use exclude_unset so that we can distinguish unset values from default values - return value.model_dump(exclude_unset=True) - return value - - @field_validator("prompt") - @classmethod - def check_prompt(cls, value): - if isinstance(value, list) and not value: - raise ValueError("Messages cannot be an empty list.") - return value - - def to_unformatted_string(self) -> str: - if isinstance(self.prompt, list): - return ", ".join(str(message.content) for message in self.prompt) - return self.prompt - - -class ImageInput(BaseModel): - """Prompt output that contains image info.""" - - image_url: str - - -class EngineInput(BaseModel): - """Input to the engine. 
- - Which is also output from `PromptFormat.generate_prompt()`.""" - - text: str - image: Optional[List[ImageInput]] = None diff --git a/python/ray/llm/_internal/serve/configs/server_models.py b/python/ray/llm/_internal/serve/configs/server_models.py index 72857df44069..ccb67b260b5e 100644 --- a/python/ray/llm/_internal/serve/configs/server_models.py +++ b/python/ray/llm/_internal/serve/configs/server_models.py @@ -7,9 +7,7 @@ List, Optional, Sequence, - Set, Tuple, - Type, TypeVar, Union, ) @@ -37,17 +35,9 @@ DEFAULT_MULTIPLEX_DOWNLOAD_TIMEOUT_S, DEFAULT_MULTIPLEX_DOWNLOAD_TRIES, ENABLE_WORKER_PROCESS_SETUP_HOOK, - MAX_NUM_STOPPING_SEQUENCES, MODEL_RESPONSE_BATCH_TIMEOUT_MS, ) -from ray.llm._internal.serve.configs.error_handling import TooManyStoppingSequences -from ray.llm._internal.serve.configs.openai_api_models_patch import ( - ErrorResponse, - ResponseFormatType, -) -from ray.llm._internal.serve.configs.prompt_formats import ( - Prompt, -) +from ray.llm._internal.serve.configs.openai_api_models import ErrorResponse from ray.llm._internal.serve.observability.logging import get_logger from ray.serve._private.config import DeploymentConfig @@ -572,32 +562,6 @@ def parse_args(self) -> "LLMServingArgs": return LLMServingArgs(llm_configs=llm_configs) -TModel = TypeVar("TModel", bound="Model") - - -class ModelData(BaseModel): - model_config = ConfigDict(protected_namespaces=tuple()) - - id: str - object: str - owned_by: str - permission: List[str] - rayllm_metadata: Dict[str, Any] - - @property - def model_type(self) -> str: - return self.rayllm_metadata["engine_config"]["model_type"] - - -class Model(BaseModel): - data: List[ModelData] - object: str = "list" - - @classmethod - def list(cls) -> TModel: - pass - - class FinishReason(str, Enum): LENGTH = "length" STOP = "stop" @@ -866,100 +830,3 @@ def merge_dicts(base: Dict, overwrite: Dict) -> Dict: else: base[key] = overwrite[key] return base - - -class SamplingParams(BaseModelExtended): - """Parameters for controlling text generation sampling. - - Args: - max_tokens: The maximum number of tokens to generate. Defaults to inf. - temperature: What sampling temperature to use. - top_p: An alternative to sampling with temperature, called nucleus sampling. - n: How many completions to generate for each prompt. - logprobs: Include the log probabilities on the `logprobs` most likely - tokens, as well the chosen tokens. - top_logprobs: The number of logprobs to return. Defaults to 1. `logprobs` - must be set to `True` in order to use top_logprobs. - stop: Up to 4 sequences where the API will stop generating further tokens. - The returned text will not contain the stop sequence. - stop_tokens: Tokens to stop on (applied before detokenization). - presence_penalty: Number between -2.0 and 2.0. - Positive values penalize new tokens based on whether they appear in - the text so far, increasing the model's likelihood to talk about - new topics. - frequency_penalty: Number between -2.0 and 2.0. Positive values penalize - new tokens based on their existing frequency in the text so far, - decreasing the model's likelihood to repeat the same line verbatim. - best_of: Generates `best_of` completions server-side and returns the "best". - logit_bias: Modify the likelihood of specified tokens appearing in - the completion. - response_format: Format to return the final response in. 
Can be for ex:
-            response_format={"type": "json", "schema": "{...}"}
-    """
-
-    _ignored_fields: Set[str] = set()
-
-    max_tokens: Optional[int] = None
-    temperature: Optional[float] = None
-    top_p: Optional[float] = None
-    n: int = 1
-    logprobs: Optional[bool] = None
-    top_logprobs: Optional[int] = None
-    logit_bias: Optional[Dict[str, float]] = None
-    stop: Optional[List[str]] = None
-    stop_tokens: Optional[List[int]] = None
-    ignore_eos: Optional[bool] = None
-    presence_penalty: Optional[float] = None
-    frequency_penalty: Optional[float] = None
-    best_of: int = 1
-    response_format: Optional[ResponseFormatType] = None
-
-    def model_dump(self, **kwargs) -> Dict[str, Any]:
-        if kwargs.get("exclude", None) is None:
-            kwargs["exclude"] = self._ignored_fields
-        return super().model_dump(**kwargs)
-
-    @field_validator("stop", mode="before")
-    @classmethod
-    def validate_stopping_sequences(cls, values):
-        if not values:
-            return values
-
-        unique_val = sorted(set(values))
-
-        if len(unique_val) > MAX_NUM_STOPPING_SEQUENCES:
-            TooManyStoppingSequences(
-                len(unique_val), MAX_NUM_STOPPING_SEQUENCES
-            ).raise_exception()
-
-        return list(unique_val)
-
-    @field_validator("stop_tokens", mode="before")
-    @classmethod
-    def validate_stop_tokens(cls, values):
-        if not values:
-            return values
-        return sorted(set(values))
-
-    @classmethod
-    def _get_model_validate_kwargs(cls: Type[ModelT], prompt: Prompt) -> Dict[str, Any]:
-        generate_kwargs = prompt.parameters or {}
-        if not isinstance(generate_kwargs, dict):
-            generate_kwargs = generate_kwargs.model_dump(exclude_unset=True)
-
-        return generate_kwargs
-
-    @classmethod
-    def from_prompt(cls: Type[ModelT], prompt: Prompt) -> ModelT:
-        # Extract parameters object from prompt
-        generate_kwargs = cls._get_model_validate_kwargs(prompt)
-        return cls.model_validate(generate_kwargs)
-
-
-class GenerationRequest(BaseModelExtended):
-    prompt: Union[str, List[int], List[str]]
-    prompt_token_ids: Optional[List[int]] = None
-    request_id: Union[str, List[str]]
-    sampling_params: Optional[Union[SamplingParams, List[SamplingParams]]] = None
-    stream: bool = False
-    metadata: Optional[Dict[str, Any]] = None
diff --git a/python/ray/llm/_internal/serve/deployments/llm/llm_engine.py b/python/ray/llm/_internal/serve/deployments/llm/llm_engine.py
index cd32c4640005..f0d0637990e0 100644
--- a/python/ray/llm/_internal/serve/deployments/llm/llm_engine.py
+++ b/python/ray/llm/_internal/serve/deployments/llm/llm_engine.py
@@ -1,17 +1,14 @@
 import abc
-from typing import AsyncGenerator, Optional
+from typing import AsyncGenerator, Any
 
 from ray.llm._internal.serve.configs.server_models import (
     DiskMultiplexConfig,
-    GenerationRequest,
     LLMConfig,
-    LLMRawResponse,
-    Prompt,
 )
 
 
 class LLMEngine(abc.ABC):
-    """Base class for all LLM engines"""
+    """Base protocol class for all LLM engines"""
 
     @abc.abstractmethod
     def __init__(self, llm_config: LLMConfig):
@@ -24,22 +21,23 @@ async def start(self):
         pass
 
     @abc.abstractmethod
-    async def prepare_request(
-        self,
-        request_id: str,
-        prompt: Prompt,
-        stream: bool,
-        disk_lora_model: Optional[DiskMultiplexConfig] = None,
-        **kwargs,
-    ) -> GenerationRequest:
-        """Prepare a GenerationRequest for the engine"""
+    async def resolve_lora(self, lora_model: DiskMultiplexConfig):
+        """Resolve the LoRA model"""
         pass
 
     @abc.abstractmethod
-    async def generate(
-        self, request: GenerationRequest
-    ) -> AsyncGenerator[LLMRawResponse, None]:
-        """Generate an LLMRawResponse stream based on the GenerationRequest"""
+    async def chat(self, request) -> 
AsyncGenerator[Any, None]: + """Chat with the engine""" + pass + + @abc.abstractmethod + async def completions(self, request) -> AsyncGenerator[Any, None]: + """Completion with the engine""" + pass + + @abc.abstractmethod + async def embeddings(self, request) -> AsyncGenerator[Any, None]: + """Embed with the engine""" pass async def check_health(self) -> None: diff --git a/python/ray/llm/_internal/serve/deployments/llm/llm_server.py b/python/ray/llm/_internal/serve/deployments/llm/llm_server.py index 9573d25f42a3..023ceec971fa 100644 --- a/python/ray/llm/_internal/serve/deployments/llm/llm_server.py +++ b/python/ray/llm/_internal/serve/deployments/llm/llm_server.py @@ -1,13 +1,11 @@ import asyncio import os from abc import ABC, abstractmethod -from typing import Any, AsyncGenerator, Dict, Optional, Type, Union +from typing import Any, Dict, Optional, Type, Union, AsyncGenerator, List -# Third-party imports from ray import serve from ray._common.utils import import_attr -# Local imports from ray.llm._internal.serve.configs.constants import ( DEFAULT_HEALTH_CHECK_PERIOD_S, DEFAULT_HEALTH_CHECK_TIMEOUT_S, @@ -16,57 +14,34 @@ RAYLLM_VLLM_ENGINE_CLS_ENV, ) from ray.llm._internal.serve.configs.openai_api_models import ( - ChatCompletionLogProb, - ChatCompletionLogProbs, - ChatCompletionLogProbsContent, ChatCompletionRequest, ChatCompletionResponse, - ChatCompletionResponseChoice, - ChatCompletionResponseStreamChoice, - ChatCompletionStreamResponse, - ChatMessage, CompletionRequest, CompletionResponse, - CompletionResponseChoice, - CompletionResponseStreamChoice, - CompletionStreamResponse, - DeltaMessage, EmbeddingRequest, EmbeddingResponse, - EmbeddingResponseData, LLMChatResponse, LLMCompletionsResponse, - LLMEmbeddingsResponse, - UsageInfo, ) -from ray.llm._internal.serve.configs.prompt_formats import Message, Prompt +from ray.llm._internal.serve.deployments.llm.multiplex.lora_model_loader import ( + LoraModelLoader, +) from ray.llm._internal.serve.configs.server_models import ( - DiskMultiplexConfig, LLMConfig, - LLMRawResponse, ) from ray.llm._internal.serve.deployments.llm.llm_engine import LLMEngine -from ray.llm._internal.serve.deployments.llm.multiplex.lora_model_loader import ( - LoraModelLoader, -) from ray.llm._internal.serve.deployments.llm.vllm.vllm_engine import VLLMEngine -from ray.llm._internal.serve.deployments.llm.vllm.vllm_models import ( - VLLMEmbeddingRequest, -) from ray.llm._internal.serve.deployments.utils.batcher import OpenAIResponseBatcher -from ray.llm._internal.serve.deployments.utils.error_handling_utils import ( - StreamingErrorHandler, -) from ray.llm._internal.serve.deployments.utils.server_utils import ( - get_model_request_id, - get_response_for_error, get_serve_request_id, ) +from ray.llm._internal.serve.configs.server_models import DiskMultiplexConfig from ray.llm._internal.serve.observability.logging import get_logger from ray.llm._internal.serve.observability.usage_telemetry.usage import ( push_telemetry_report_for_all_models, ) + logger = get_logger(__name__) @@ -77,11 +52,10 @@ class _LLMServerBase(ABC): """ # TODO (Kourosh): I don't know why this is an async init. Need to fix. - async def __init__(self, llm_config: LLMConfig): + async def __init__(self): """ Constructor takes in an LLMConfig object and start the underlying engine. """ - self._llm_config = llm_config @abstractmethod async def chat(self, request: ChatCompletionRequest) -> LLMChatResponse: @@ -105,308 +79,22 @@ async def check_health(self) -> None: """ ... 
-    async def llm_config(self) -> LLMConfig:
-        return self._llm_config
-
+    # TODO (Kourosh): This does not belong here.
+    async def llm_config(self) -> Optional[LLMConfig]:
+        return None
 
 
-class ResponsePostprocessor:
-    """Processes raw LLM responses into OpenAI-compatible formats.
-    This class handles:
-    1. Error handling for the response stream
-    2. Converting LLMRawResponse to Chat/Completion API formats
-    3. Supporting both streaming and non-streaming responses
-    """
+class LLMServer(_LLMServerBase):
+    """This is a shim layer to decouple the LLM engine from the ingress deployment.
 
-    def __init__(self):
-        self.metrics_wrapper = StreamingErrorHandler()
-
-    async def handle_failure(
-        self, model: str, gen: AsyncGenerator[LLMRawResponse, None]
-    ) -> AsyncGenerator[LLMRawResponse, None]:
-        async for llm_response in self.metrics_wrapper.handle_failure(model, gen):
-            yield llm_response
-
-    @staticmethod
-    async def merge_stream(
-        response_stream: AsyncGenerator[LLMRawResponse, None]
-    ) -> LLMRawResponse:
-        responses = [resp async for resp in response_stream]
-        return LLMRawResponse.merge_stream(*responses)
-
-    async def process_chat(
-        self, model: str, gen: AsyncGenerator[LLMRawResponse, None], stream: bool
-    ) -> LLMChatResponse:
-        """Process raw LLM responses into chat completion format."""
-        gen = self.handle_failure(model=model, gen=gen)
-        request_id = get_serve_request_id()
-        completion_id = get_model_request_id(model)
-
-        if stream:
-            # Stream processing - preserve batching from generator
-            yielded_role = False
-            all_results = []
-            try:
-                async for batched_results in gen:
-
-                    for result in batched_results.unpack():
-                        all_results.append(result)
-
-                        # Handle errors
-                        if result.error:
-                            logger.error(f"{result.error}")
-                            # Drop finish reason as OpenAI doesn't expect it for errors
-                            result.finish_reason = None
-                            all_results.pop()
-                            yield result.error
-                            return
-
-                        finish_reason = result.finish_reason
-
-                        # Send role message first
-                        if not yielded_role:
-                            yield ChatCompletionStreamResponse(
-                                id=completion_id,
-                                model=model,
-                                choices=[
-                                    ChatCompletionResponseStreamChoice(
-                                        delta=DeltaMessage(role="assistant"),
-                                        index=0,
-                                        finish_reason=None,
-                                        logprobs=ChatCompletionLogProbs(content=[]),
-                                    )
-                                ],
-                                usage=None,
-                            )
-                            yielded_role = True
-
-                        # Process logprobs if present
-                        logprobs = None
-                        if result.logprobs:
-                            logprobs = ChatCompletionLogProbs(
-                                content=[
-                                    ChatCompletionLogProbsContent(
-                                        token=logprobs.token,
-                                        logprob=logprobs.logprob,
-                                        bytes=logprobs.bytes,
-                                        top_logprobs=[
-                                            ChatCompletionLogProb(
-                                                token=logprob.token,
-                                                logprob=logprob.logprob,
-                                                bytes=logprob.bytes,
-                                            )
-                                            for logprob in logprobs.top_logprobs
-                                        ],
-                                    )
-                                    for logprobs in result.logprobs
-                                ]
-                            )
-
-                        yield ChatCompletionStreamResponse(
-                            id=completion_id,
-                            model=model,
-                            choices=[
-                                ChatCompletionResponseStreamChoice(
-                                    delta=DeltaMessage(
-                                        content=result.generated_text or ""
-                                    ),
-                                    index=0,
-                                    finish_reason=None,
-                                    logprobs=logprobs,
-                                )
-                            ],
-                            usage=None,
-                        )
-
-                # Send final message with finish_reason if there were any results
-                # TODO (Kourosh): Doing this much for the last token
-                # (usage token) might add extra overhead to ITL of the last token.
-                # We should find a better way to do this. 
- if all_results: - merged_results = LLMRawResponse.merge_stream(*all_results) - finish_reason = merged_results.finish_reason - usage = UsageInfo( - prompt_tokens=merged_results.num_input_tokens or 0, - completion_tokens=merged_results.num_generated_tokens or 0, - total_tokens=(merged_results.num_input_tokens or 0) - + (merged_results.num_generated_tokens or 0), - ) - - yield ChatCompletionStreamResponse( - id=completion_id, - model=model, - choices=[ - ChatCompletionResponseStreamChoice( - delta=DeltaMessage(), - index=0, - finish_reason=finish_reason, - ) - ], - usage=usage, - ) - except Exception as e: - logger.error( - f"Failed while handling chat-completions for request ({request_id}): {repr(e)}", - exc_info=e, - ) - yield get_response_for_error(e, request_id).error - else: - # Non-streaming processing - merge and return a single response - try: - results: LLMRawResponse = await self.merge_stream(gen) - if results.error: - yield results.error - return - - logprobs = None - if results.logprobs: - logprobs = ChatCompletionLogProbs( - content=[ - ChatCompletionLogProbsContent( - token=logprobs.token, - logprob=logprobs.logprob, - bytes=logprobs.bytes, - top_logprobs=[ - ChatCompletionLogProb( - token=logprob.token, - logprob=logprob.logprob, - bytes=logprob.bytes, - ) - for logprob in logprobs.top_logprobs - ], - ) - for logprobs in results.logprobs - ] - ) - - yield ChatCompletionResponse( - id=completion_id, - model=model, - choices=[ - ChatCompletionResponseChoice( - message=ChatMessage( - role="assistant", - content=results.generated_text or "", - ), - index=0, - finish_reason=results.finish_reason, - logprobs=logprobs, - ) - ], - usage=UsageInfo( - prompt_tokens=results.num_input_tokens or 0, - completion_tokens=results.num_generated_tokens or 0, - total_tokens=(results.num_input_tokens or 0) - + (results.num_generated_tokens or 0), - ), - ) - except Exception as e: - logger.error( - f"Failed while handling chat-completions for request ({request_id}): {repr(e)}", - exc_info=e, - ) - yield get_response_for_error(e, request_id).error - - async def process_completions( - self, model: str, gen: AsyncGenerator[LLMRawResponse, None], stream: bool - ) -> LLMCompletionsResponse: - """Process raw LLM responses into completions format.""" - gen = self.handle_failure(model=model, gen=gen) - request_id = get_serve_request_id() - completion_id = get_model_request_id(model) - - if stream: - # Stream processing - preserve batching from generator - all_results = [] - try: - async for batched_results in gen: - - for result in batched_results.unpack(): - all_results.append(result) - - # Handle errors - if result.error: - # Drop finish reason as OpenAI doesn't expect it for errors - result.finish_reason = None - logger.error( - f"Reporting back an error: {result.error}", - extra={ - "ray_serve_extra_fields": {"response": str(result)} - }, - ) - all_results.pop() - yield result.error - return - - # Calculate usage if finished - usage = None - if result.finish_reason: - merged_results = LLMRawResponse.merge_stream(*all_results) - usage = UsageInfo( - prompt_tokens=merged_results.num_input_tokens or 0, - completion_tokens=merged_results.num_generated_tokens - or 0, - total_tokens=(merged_results.num_input_tokens or 0) - + (merged_results.num_generated_tokens or 0), - ) - - chunk = CompletionStreamResponse( - id=completion_id, - model=model, - choices=[ - CompletionResponseStreamChoice( - text=result.generated_text or "", - index=0, - logprobs={}, - finish_reason=result.finish_reason, - ) - ], - 
usage=usage,
-                        )
-
-                        yield chunk
-
-            except Exception as e:
-                logger.error(
-                    f"Failed while handling completions for request ({request_id}): {repr(e)}",
-                    exc_info=e,
-                )
-                yield get_response_for_error(e, request_id).error
-        else:
-            # Non-streaming processing - merge and return a single response
-            try:
-                results: LLMRawResponse = await self.merge_stream(gen)
-                if results.error:
-                    yield results.error
-                    return
-
-                yield CompletionResponse(
-                    id=completion_id,
-                    model=model,
-                    choices=[
-                        CompletionResponseChoice(
-                            text=results.generated_text or "",
-                            index=0,
-                            logprobs={},
-                            finish_reason=results.finish_reason,
-                        )
-                    ],
-                    usage=UsageInfo(
-                        prompt_tokens=results.num_input_tokens or 0,
-                        completion_tokens=results.num_generated_tokens or 0,
-                        total_tokens=(results.num_input_tokens or 0)
-                        + (results.num_generated_tokens or 0),
-                    ),
-                )
-            except Exception as e:
-                logger.error(
-                    f"Failed while handling completions for request ({request_id}): {repr(e)}",
-                    exc_info=e,
-                )
-                yield get_response_for_error(e, request_id).error
+    It has a very similar API to the engine. Almost all of the abstractions are implemented by the engine; this class just adds a little bit more logic on top:
+    1. Logic for serve multiplexing (e.g. LoRA loading).
+    2. Request id handling from serve context.
+    3. Batching in case of streaming (only for chat and completions).
+    4. Telemetry reporting.
+    """
 
-class LLMServer(_LLMServerBase):
     _default_engine_cls = VLLMEngine
 
     async def __init__(
@@ -414,7 +102,7 @@ async def __init__(
         llm_config: LLMConfig,
         *,
         engine_cls: Optional[Type[LLMEngine]] = None,
-        model_downloader: Optional[LoraModelLoader] = None,
+        model_downloader: Optional[Type[LoraModelLoader]] = None,
     ):
         """Constructor of LLMServer.
 
@@ -425,10 +113,11 @@ async def __init__(
             llm_config: LLMConfig for the model.
             engine_cls: Dependency injection for the vllm engine class.
                 Defaults to `VLLMEngine`.
-            model_downloader: Dependency injection for the model downloader
-                object. Defaults to be initialized with `LoraModelLoader`.
+            model_downloader: Dependency injection for the model downloader.
+                Defaults to `LoraModelLoader`. 
""" - await super().__init__(llm_config) + await super().__init__() + self._llm_config = llm_config self._engine_cls = engine_cls or self._get_default_engine_class() self.engine: Optional[LLMEngine] = None @@ -436,24 +125,37 @@ async def __init__( self.engine = self._engine_cls(self._llm_config) await asyncio.wait_for(self._start_engine(), timeout=ENGINE_START_TIMEOUT_S) - multiplex_config = self._llm_config.multiplex_config() - if model_downloader: - self.model_downloader = model_downloader - elif multiplex_config: - self.model_downloader = LoraModelLoader( - download_timeout_s=multiplex_config.download_timeout_s, - max_tries=multiplex_config.max_download_tries, + self._init_multiplex_loader(model_downloader) + + def _init_multiplex_loader( + self, model_downloader_cls: Optional[Type[LoraModelLoader]] = None + ): + """Initialize the multiplex loader.""" + + model_downloader_cls = model_downloader_cls or LoraModelLoader + mx_config = self._llm_config.multiplex_config() + + if mx_config is not None: + model_downloader = model_downloader_cls( + download_timeout_s=mx_config.download_timeout_s, + max_tries=mx_config.max_download_tries, ) + + async def _load_model(lora_model_id: str) -> DiskMultiplexConfig: + return await model_downloader.load_model( + lora_model_id=lora_model_id, + llm_config=self._llm_config, + ) + + self._load_model = serve.multiplexed( + max_num_models_per_replica=mx_config.max_num_models_per_replica + )(_load_model) else: - self.model_downloader = LoraModelLoader() - # Hack that lets us set max_num_models_per_replica from the llm_config - if multiplex_config: - self.load_model = serve.multiplexed( - max_num_models_per_replica=multiplex_config.max_num_models_per_replica - )(lambda lora_model_id: self._load_model(lora_model_id)) + async def _load_model(lora_model_id: str) -> DiskMultiplexConfig: + raise ValueError("LoRA config is not set in the LLMConfig") - self.response_postprocessor = ResponsePostprocessor() + self._load_model = _load_model def _get_default_engine_class(self) -> Type[LLMEngine]: """Helper to load the engine class from the environment variable. @@ -474,40 +176,6 @@ async def _start_engine(self): # Push telemetry reports for the model in the current deployment. push_telemetry_report_for_all_models(all_models=[self._llm_config]) - async def _predict( - self, - request_id: str, - prompt: Prompt, - stream: bool, - ) -> AsyncGenerator[LLMRawResponse, None]: - """A thin wrapper around VLLMEngine.generate(). - - 1. Load the model to disk - 2. Format parameters correctly - 3. Forward request to VLLMEngine.generate() - """ - - logger.info(f"Received streaming request {request_id}") - multiplexed_model_id = serve.get_multiplexed_model_id() - - if multiplexed_model_id: - assert ( - self._llm_config.lora_config is not None - ), "Must setup lora config for multiplexed requests." 
- disk_lora_model = await self._disk_lora_model(multiplexed_model_id) - else: - disk_lora_model = None - - llm_request = await self.engine.prepare_request( - request_id=request_id, - prompt=prompt, - stream=stream, - disk_lora_model=disk_lora_model, - ) - - async for llm_response in self.engine.generate(llm_request): - yield llm_response - def _get_batch_interval_ms(self, stream: bool = True) -> int: """Calculate the batching interval for responses.""" stream_batching_interval_ms = self._llm_config.experimental_configs.get( @@ -517,80 +185,106 @@ def _get_batch_interval_ms(self, stream: bool = True) -> int: stream_batching_interval_ms = MODEL_RESPONSE_BATCH_TIMEOUT_MS return stream_batching_interval_ms if stream else None - def _process_llm_request( - self, request: Union[ChatCompletionRequest, CompletionRequest], is_chat: bool - ) -> Union[LLMChatResponse, LLMCompletionsResponse]: - """Common processing pipeline for both chat and completions APIs. + async def _maybe_add_request_id_to_request( + self, request: Union[ChatCompletionRequest, CompletionRequest, EmbeddingRequest] + ): + """Add the request id to the request.""" + request_id = get_serve_request_id() + if request_id: + request.request_id = request_id + + async def _maybe_resolve_lora_from_multiplex(self) -> None: + """Handle the lora model for the request.""" + multiplexed_model_id = serve.get_multiplexed_model_id() + if multiplexed_model_id: + if self._llm_config.lora_config is None: + raise ValueError("Must setup lora config for multiplexed requests.") + disk_lora_model = await self._load_model(multiplexed_model_id) + await self.engine.resolve_lora(disk_lora_model) + + def _batch_output_stream(self, generator): + return OpenAIResponseBatcher( + generator, + interval_ms=self._get_batch_interval_ms(), + ).stream() + + async def _run_request( + self, + request: Union[ChatCompletionRequest, CompletionRequest, EmbeddingRequest], + *, + engine_method: str, + batch_output_stream: bool = False, + ) -> AsyncGenerator[Any, None]: + """Run the engine method on the request + perform batching when stream=True. Args: - request: Either a ChatCompletionRequest or CompletionRequest object - is_chat: Whether this is a chat request (True) or completions request (False) + request: The request to run. + engine_method: The method to call on the engine. + batch_output_stream: Whether to batch the output stream. Returns: - A generator of response objects (either chat completion or text completion) + An AsyncGenerator of the response. If stream is True and batching is enabled, then the generator will yield a list of streaming responses (strings of the format data: {response_json}\n\n). Otherwise, it will yield the non-streaming response from engine directly. """ - request_id = get_serve_request_id() + await self._maybe_add_request_id_to_request(request) + await self._maybe_resolve_lora_from_multiplex() - # 1. Construct the appropriate prompt based on request type - if is_chat: - prompt = Prompt( - prompt=[ - Message.model_validate(message) for message in request.messages - ], - parameters=request, + is_stream = hasattr(request, "stream") and request.stream + if is_stream and batch_output_stream: + stream = self._batch_output_stream( + getattr(self.engine, engine_method)(request) ) else: - prompt = Prompt( - prompt=request.prompt, - parameters=request, - use_prompt_format=False, - ) - - # 2. 
Predict using the engine - gen = self._predict(request_id=request_id, prompt=prompt, stream=request.stream) + stream = getattr(self.engine, engine_method)(request) - # 3. Convert raw LLM responses to OpenAI format - processor_method = ( - self.response_postprocessor.process_chat - if is_chat - else self.response_postprocessor.process_completions - ) - openai_resp_generator = processor_method( - model=self._llm_config.model_id, gen=gen, stream=request.stream - ) + return stream - if request.stream: - # 4. Apply batching with appropriate interval in case of streaming - batched_openai_response_stream = OpenAIResponseBatcher( - openai_resp_generator, - interval_ms=self._get_batch_interval_ms(), - ) - - return batched_openai_response_stream.stream() - - return openai_resp_generator - - async def chat(self, request: ChatCompletionRequest) -> LLMChatResponse: + async def chat( + self, request: ChatCompletionRequest + ) -> AsyncGenerator[Union[List[str], ChatCompletionResponse], None]: """Runs a chat request to the LLM engine and returns the response. Args: request: A ChatCompletionRequest object. Returns: - A LLMChatResponse object. + An AsyncGenerator of the response. If stream is True and batching is enabled, then the generator will yield a list of chat streaming responses (strings of the format data: {response_json}\n\n). Otherwise, it will yield the ChatCompletionResponse object directly. """ - return self._process_llm_request(request, is_chat=True) + return await self._run_request( + request, engine_method="chat", batch_output_stream=True + ) - async def completions(self, request: CompletionRequest) -> LLMCompletionsResponse: + async def completions( + self, request: CompletionRequest + ) -> AsyncGenerator[Union[List[str], CompletionResponse], None]: """Runs a completion request to the LLM engine and returns the response. Args: request: A CompletionRequest object. Returns: - A LLMCompletionsResponse object. + An AsyncGenerator of the response. If stream is True and batching is enabled, then the generator will yield a list of completion streaming responses (strings of the format data: {response_json}\n\n). Otherwise, it will yield the CompletionResponse object directly. + """ + return await self._run_request( + request, engine_method="completions", batch_output_stream=True + ) + + async def embeddings( + self, request: EmbeddingRequest + ) -> AsyncGenerator[EmbeddingResponse, None]: + """Runs an embeddings request to the engine and returns the response. + + Returns an AsyncGenerator over the EmbeddingResponse object. This is so that the caller can have a consistent interface across all the methods of chat, completions, and embeddings. + + Args: + request: An EmbeddingRequest object. + + Returns: + An AsyncGenerator over the EmbeddingResponse object. """ - return self._process_llm_request(request, is_chat=False) + # NOTE: Embeddings does not need batching. + return await self._run_request( + request, engine_method="embeddings", batch_output_stream=False + ) async def check_health(self) -> None: """ @@ -605,68 +299,12 @@ async def check_health(self) -> None: logger.error("Engine health check failed in LLMServer.check_health: %s", e) raise e - async def embeddings(self, request: EmbeddingRequest) -> LLMEmbeddingsResponse: - """Runs an embeddings request to the vllm engine, and return the response. - - Args: - request: An EmbeddingRequest object. - - Returns: - A LLMEmbeddingsResponse object. 
- """ - request_id = get_serve_request_id() - try: - multiplexed_model_id = serve.get_multiplexed_model_id() - - if multiplexed_model_id: - assert ( - self._llm_config.lora_config is not None - ), "Must setup lora config for multiplexed requests." - disk_lora_model = await self._disk_lora_model(multiplexed_model_id) - else: - disk_lora_model = None - - request_params = { - "request_id": request_id, - "prompt": request.input, - "encoding_format": request.encoding_format, - "disk_multiplex_config": disk_lora_model, - "serve_request_context": serve.context._serve_request_context.get(), - } - vllm_request = VLLMEmbeddingRequest(**request_params) - embedding_data, total_tokens = await self.engine.embed(vllm_request) - - data = [ - EmbeddingResponseData( - object="embedding", index=index, embedding=embedding - ) - for index, embedding in enumerate(embedding_data) - ] - - usage = UsageInfo(prompt_tokens=total_tokens, total_tokens=total_tokens) - - yield EmbeddingResponse( - model=self._llm_config.model_id, data=data, usage=usage, object="list" - ) - except Exception as e: - logger.error( - f"Failed while handling embeddings for request ({request_id}): {repr(e)}", - exc_info=e, - ) - - async def _load_model(self, lora_model_id: str) -> DiskMultiplexConfig: - return await self.model_downloader.load_model( - lora_model_id=lora_model_id, - llm_config=self._llm_config, - ) - - async def _disk_lora_model(self, lora_model_id: str) -> DiskMultiplexConfig: - disk_lora_model: DiskMultiplexConfig = await self.load_model(lora_model_id) - return disk_lora_model + async def llm_config(self) -> Optional[LLMConfig]: + return self._llm_config @classmethod def as_deployment( - cls, deployment_options: Dict[str, Any] = None + cls, deployment_options: Optional[Dict[str, Any]] = None ) -> serve.Deployment: """Convert the LLMServer to a Ray Serve deployment. 
diff --git a/python/ray/llm/_internal/serve/deployments/llm/vllm/vllm_engine.py b/python/ray/llm/_internal/serve/deployments/llm/vllm/vllm_engine.py index c0ef0cff357e..ec8c9500f3de 100644 --- a/python/ray/llm/_internal/serve/deployments/llm/vllm/vllm_engine.py +++ b/python/ray/llm/_internal/serve/deployments/llm/vllm/vllm_engine.py @@ -1,82 +1,56 @@ import os -import re -import time import uuid -from concurrent.futures.thread import ThreadPoolExecutor -from typing import TYPE_CHECKING, AsyncGenerator, List, Optional, Tuple +import argparse +from starlette.datastructures import State + +from typing import TYPE_CHECKING, AsyncGenerator, Tuple, Union from transformers.dynamic_module_utils import init_hf_modules import ray from ray.llm._internal.common.utils.import_utils import try_import -from ray.llm._internal.serve.configs.constants import ( - MAX_NUM_TOPLOGPROBS_ALLOWED, - MIN_NUM_TOPLOGPROBS_ALLOWED, - RAYLLM_ENABLE_REQUEST_PROMPT_LOGS, - RAYLLM_GUIDED_DECODING_BACKEND, -) -from ray.llm._internal.serve.configs.error_handling import ( - InputTooLong, - ValidationError, +from ray.llm._internal.serve.configs.openai_api_models import ( + CompletionRequest, + CompletionResponse, + ChatCompletionRequest, + ChatCompletionResponse, + EmbeddingRequest, + EmbeddingResponse, + ErrorResponse, ) + from ray.llm._internal.serve.configs.server_models import ( DiskMultiplexConfig, - FinishReason, - GenerationRequest, LLMConfig, - LLMRawResponse, - LogProb, - LogProbs, - Prompt, ) +from transformers.dynamic_module_utils import init_hf_modules + from ray.llm._internal.serve.deployments.llm.llm_engine import LLMEngine from ray.llm._internal.serve.deployments.llm.vllm.vllm_engine_stats import ( - ArgUsage, VLLMEngineStatTracker, - usage_counters, ) from ray.llm._internal.serve.deployments.llm.vllm.vllm_models import ( - KV_TRANSFER_PARAMS_KEY, - VLLMEmbeddingRequest, VLLMEngineConfig, - VLLMGenerationRequest, - VLLMSamplingParams, ) from ray.llm._internal.serve.deployments.utils.node_initialization_utils import ( InitializeNodeOutput, initialize_node, ) -from ray.llm._internal.serve.deployments.utils.server_utils import floats_to_base64 from ray.llm._internal.serve.observability.logging import get_logger -from ray.llm._internal.serve.observability.metrics.utils import ( - LONG_RANGE_LATENCY_HISTOGRAM_BUCKETS_MS, - ClockUnit, - MsClock, -) -from ray.util import metrics from ray.util.placement_group import PlacementGroup from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy +from vllm.entrypoints.openai.cli_args import FrontendArgs +from vllm.engine.arg_utils import AsyncEngineArgs +from vllm.entrypoints.openai.protocol import ErrorResponse as VLLMErrorResponse + if TYPE_CHECKING: - from vllm import SamplingParams as VLLMInternalSamplingParams - from vllm.config import ModelConfig, VllmConfig - from vllm.engine.arg_utils import AsyncEngineArgs + from vllm.config import VllmConfig from vllm.engine.protocol import EngineClient - from vllm.outputs import PoolingRequestOutput, RequestOutput vllm = try_import("vllm") logger = get_logger(__name__) -time_in_queue_histogram = metrics.Histogram( - "vllm_engine_stats_time_in_queue_ms", - "Time a request spends in the queue first forward pass not included (ms).", - boundaries=LONG_RANGE_LATENCY_HISTOGRAM_BUCKETS_MS, -) - -V1_TOO_LONG_PATTERN = re.compile( - r".* (\d+).* is longer than the maximum model length of (\d+).*" -) - def _get_vllm_engine_config( llm_config: LLMConfig, @@ -179,86 +153,37 @@ def __init__( port = 
vllm_envs.VLLM_NIXL_SIDE_CHANNEL_PORT kv_transfer_config.engine_id = "-".join([engine_id, host, str(port)]) - assert isinstance( - llm_config, LLMConfig - ), f"Got invalid config {llm_config} of type {type(llm_config)}" - self.llm_config = llm_config - + # TODO (Kourosh): What do we do with this stats tracker? self._stats = VLLMEngineStatTracker() self._running = False - self.model_config: "ModelConfig" = None - self._engine_client = None - self.vllm_config: "VllmConfig" = None - - # Chat template content format (openai or string) - self._resolved_content_format = None - # Also need local instance of the tokenizer to manage prompt formatting. - self._tokenizer = None - - self._tokenizer_executor = ThreadPoolExecutor(max_workers=1) - self._atokenize = vllm_utils.make_async( - self._tokenize, executor=self._tokenizer_executor - ) - def _tokenize( - self, prompt_text: str, add_special_tokens: bool = False - ) -> List[int]: - encoded = self._tokenizer(prompt_text, add_special_tokens=add_special_tokens) - return encoded.input_ids + # vLLM Integration points. Will be set through .start() + self._engine_client = None + self._oai_models = None + self._oai_serving_chat = None + self._oai_serving_completion = None + self._oai_serving_embedding = None async def start(self) -> None: """Start the vLLM engine. If the engine is already running, do nothing. """ - from vllm.entrypoints.chat_utils import ( - resolve_chat_template_content_format as _resolve_chat_template_content_format, - ) if self._running: # The engine is already running! logger.info("Skipping engine restart because the engine is already running") return - self._engine_client = await self._start_engine() - self._running = True - self.model_config = await self._engine_client.get_model_config() - - self._tokenizer = await self._engine_client.get_tokenizer() - - def resolve_chat_template_content_format(model_config, **kwargs): - try: - return _resolve_chat_template_content_format( - model_config=model_config, **kwargs - ) - except TypeError: - # Legacy API before vLLM 0.9.0. - # TODO(#52975): Remove this try-except once vLLM <0.9.0 is no longer supported. - return _resolve_chat_template_content_format( - trust_remote_code=model_config.trust_remote_code, **kwargs - ) - - self._resolved_content_format = resolve_chat_template_content_format( - model_config=self.model_config, - # Use HF to get the chat template so set it to None here. - chat_template=None, - # Default to None, change when it's needed. - # vLLM does not have a high level API to support all of this. - tools=None, - # Let vLLM decide the content format. - given_format="auto", - tokenizer=self._tokenizer, - ) - - logger.info("Started vLLM engine.") + from vllm.entrypoints.openai.api_server import init_app_state - async def _start_engine(self) -> "EngineClient": - # Initialize node and return all configurations node_initialization = await initialize_node(self.llm_config) - vllm_engine_args, vllm_engine_config = await self._prepare_engine_config( - node_initialization - ) + ( + vllm_engine_args, + vllm_frontend_args, + vllm_engine_config, + ) = self._prepare_engine_config(node_initialization) # Apply checkpoint info to the llm_config. 
# This is needed for capturing model capabilities @@ -269,22 +194,59 @@ async def _start_engine(self) -> "EngineClient": trust_remote_code=config.trust_remote_code, ) - return self._start_async_llm_engine( + self._engine_client = self._start_async_llm_engine( vllm_engine_args, vllm_engine_config, node_initialization.placement_group, ) - async def _prepare_engine_config(self, node_initialization: InitializeNodeOutput): - """Prepare the engine config to start the engine. + state = State() + args = argparse.Namespace( + **vllm_frontend_args.__dict__, + **vllm_engine_args.__dict__, + ) - Args: - node_initialization: The node initialization. + await init_app_state( + engine_client=self._engine_client, + vllm_config=vllm_engine_config, + state=state, + args=args, + ) + + self._oai_models = state.openai_serving_models + self._oai_serving_chat = state.openai_serving_chat + self._oai_serving_completion = state.openai_serving_completion + self._oai_serving_embedding = state.openai_serving_embedding + + self._validate_openai_serving_models() + + self._running = True + + logger.info("Started vLLM engine.") + + def _validate_openai_serving_models(self): + if not hasattr(self._oai_models, "lora_requests"): + raise ValueError("oai_models must have a lora_requests attribute") + + if not hasattr(self._oai_models, "load_lora_adapter"): + raise ValueError("oai_models must have a load_lora_adapter attribute") + + def _validate_openai_serving_chat(self): + if not hasattr(self._oai_serving_chat, "create_chat_completion"): + raise ValueError( + "oai_serving_chat must have a create_chat_completion attribute" + ) + + def _prepare_engine_config(self, node_initialization: InitializeNodeOutput): + """Prepare the engine config to start the engine. Returns: engine_args: The vLLM's internal engine arguments that is flattened. + frontend_args: The vLLM's internal frontend arguments that is + flattened. engine_config: The vLLM's internal engine config that is nested. """ + engine_config: VLLMEngineConfig = self.llm_config.get_engine_config() if engine_config.use_gpu: @@ -310,29 +272,27 @@ async def _prepare_engine_config(self, node_initialization: InitializeNodeOutput self.llm_config ) - # Note (genesu): vllm_config is used to extract the scheduler config for - # computing the correct prompt limit. 
- self.vllm_config = vllm_engine_config - return vllm_engine_args, vllm_engine_config + vllm_frontend_args = FrontendArgs(**engine_config.frontend_kwargs) + return vllm_engine_args, vllm_frontend_args, vllm_engine_config def _start_async_llm_engine_v0( self, - vllm_engine_args: "AsyncEngineArgs", - vllm_engine_config: "VllmConfig", + engine_args: "AsyncEngineArgs", + vllm_config: "VllmConfig", placement_group: PlacementGroup, ) -> "EngineClient": - from vllm.engine.async_llm_engine import AsyncLLMEngine from vllm.executor.ray_distributed_executor import RayDistributedExecutor + from vllm.engine.async_llm_engine import AsyncLLMEngine - vllm_engine_config.parallel_config.placement_group = placement_group + vllm_config.parallel_config.placement_group = placement_group _clear_current_platform_cache() engine_client = AsyncLLMEngine( - vllm_config=vllm_engine_config, + vllm_config=vllm_config, executor_class=RayDistributedExecutor, - log_stats=not vllm_engine_args.disable_log_stats, + log_stats=not engine_args.disable_log_stats, ) return engine_client @@ -352,8 +312,8 @@ def _start_async_llm_engine( vllm_engine_args, vllm_engine_config, placement_group ) - from vllm.v1.engine.async_llm import AsyncLLM from vllm.v1.executor.abstract import Executor + from vllm.v1.engine.async_llm import AsyncLLM vllm_engine_config.parallel_config.placement_group = placement_group @@ -380,301 +340,126 @@ def _start_async_llm_engine( return engine_client - async def prepare_request( - self, - request_id: str, - prompt: Prompt, - stream: bool, - disk_lora_model: Optional[DiskMultiplexConfig] = None, - ) -> GenerationRequest: - from vllm.entrypoints.chat_utils import ( - apply_hf_chat_template as _apply_hf_chat_template, - parse_chat_messages_futures, - ) + async def resolve_lora(self, disk_lora_model: DiskMultiplexConfig): + from vllm.entrypoints.openai.protocol import LoadLoRAAdapterRequest - model_config = self.model_config - mm_data = None - - if isinstance(prompt.prompt, list): - messages = [m.model_dump() for m in prompt.prompt] - conversation, mm_futures = parse_chat_messages_futures( - messages=messages, - model_config=model_config, - tokenizer=self._tokenizer, - content_format=self._resolved_content_format, - ) - mm_data = await mm_futures - - def apply_hf_chat_template(model_config, **kwargs): - try: - return _apply_hf_chat_template(model_config=model_config, **kwargs) - except TypeError: - # Legacy API before vLLM 0.9.0. - # TODO(#52975): Remove above once vLLM <0.9.0 is no longer supported. 
- return _apply_hf_chat_template( - trust_remote_code=model_config.trust_remote_code, **kwargs - ) + if disk_lora_model.model_id in self._oai_models.lora_requests: + # Lora is already loaded, return + return - prompt_text = apply_hf_chat_template( - model_config=model_config, - tokenizer=self._tokenizer, - conversation=conversation, - chat_template=None, - tools=None, - tokenize=False, - # **kwargs for tokenizer.apply_chat_template - trust_remote_code=model_config.trust_remote_code, - add_generation_prompt=True, - continue_final_message=False, + lora_request = await self._oai_models.load_lora_adapter( + request=LoadLoRAAdapterRequest( + lora_name=disk_lora_model.model_id, + lora_path=disk_lora_model.local_path, ) - else: - prompt_text = prompt.prompt - - prompt_token_ids = await self._atokenize(prompt_text) + ) - request_params = { - "prompt": prompt_text, - "prompt_token_ids": prompt_token_ids, - "request_id": request_id, - "sampling_params": VLLMSamplingParams.from_prompt(prompt), - "disk_multiplex_config": disk_lora_model, - "stream": stream, - } - if mm_data: - request_params["multi_modal_data"] = mm_data + if isinstance(lora_request, VLLMErrorResponse): + raise ValueError(f"Failed to load lora model: {lora_request.message}") - vllm_request = VLLMGenerationRequest(**request_params) - return vllm_request + async def chat( + self, request: ChatCompletionRequest + ) -> AsyncGenerator[Union[str, ChatCompletionResponse, ErrorResponse], None]: + """ - async def generate( - self, request: GenerationRequest - ) -> AsyncGenerator[LLMRawResponse, None]: - """Generate an LLMRawResponse stream + input: Take a genric free form input type and cast it to the target engine request type inside the engine. - The vLLM generation request will be passed into vLLM, and the resulting output - will be wrapped in an LLMRawResponse and yielded back to the user. + output: + - stream: True --> for each chunk, yield astring representing data: \n\n + - stream: False --> yield only one string representing the response - Error handling: + Error: + option A: + when request hits an error, raise an HTTPException(msg, code, type) + option B: + yield a HTTPException object + """ - We schedule a finalizer that will abort the request on the engine. + self._validate_openai_serving_chat() - If an exception is raised in this function or vllm, the finalizer guarantees that the request is aborted. - If an exception is raised in the caller, when this generator is gced, it will run the finalizer and abort the request. + chat_response = await self._oai_serving_chat.create_chat_completion(request) - This should also handle the case where the caller is cancelled (raises asyncio.CancelledError) - """ - if RAYLLM_ENABLE_REQUEST_PROMPT_LOGS: - logger.info( - f"Request {request.request_id} started. 
" f"Prompt: {request.prompt}" - ) - - if request.prompt_token_ids is not None: - prompt = vllm.TokensPrompt( - prompt_token_ids=request.prompt_token_ids, - multi_modal_data=request.multi_modal_data, - ) + if isinstance(chat_response, AsyncGenerator): + async for response in chat_response: + if not isinstance(response, str): + raise ValueError( + f"Expected create_chat_completion to return a stream of strings, got and item with type {type(response)}" + ) + yield response else: - prompt = vllm.TextPrompt( - prompt=request.prompt, - multi_modal_data=request.multi_modal_data, + logger.info( + f"[Kourosh] non streaming response received, type: {type(chat_response)}, chat_response: {chat_response}" ) + if isinstance(chat_response, VLLMErrorResponse): + yield ErrorResponse(**chat_response.model_dump()) + yield ChatCompletionResponse(**chat_response.model_dump()) - # Construct a results generator from vLLM - results_generator: AsyncGenerator[ - "RequestOutput", None - ] = self._engine_client.generate( - prompt=prompt, - sampling_params=self._parse_sampling_params(request.sampling_params), - request_id=request.request_id, - lora_request=request.lora_request, # type: ignore - ) + async def completions( + self, request: CompletionRequest + ) -> AsyncGenerator[Union[str, CompletionResponse, ErrorResponse], None]: + """ - # Loop over the results - num_text_returned = 0 - all_tokens_collected = 0 - clock = MsClock(unit=ClockUnit.s) - log_probs_idx = 0 - finish_reason = None - num_input_tokens = 0 - try: - start = time.perf_counter() - request_output = None - async for request_output in self._stats.auto_track(results_generator): - # TODO(tchordia): handle more than one output - assert ( - len(request_output.outputs) == 1 - ), "Received more than 1 output from vllm, aborting" - - output = request_output.outputs[0] - text_output = output.text[num_text_returned:] - num_text_returned += len(text_output) - num_input_tokens = len(request_output.prompt_token_ids) - tokens_collected = len(output.token_ids) - all_tokens_collected - all_tokens_collected += tokens_collected - finish_reason = FinishReason.from_vllm_finish_reason( - output.finish_reason - ) + input: Take a generic free form input type and cast it to the target engine request type inside the engine. - self._handle_input_too_long(request_output, finish_reason) + output: + - stream: True --> for each chunk, yield a string representing data: \n\n + - stream: False --> yield only one string representing the response - log_probs, log_probs_idx = self._extract_logprobs( - output, - log_probs_idx, - request.sampling_params.top_logprobs, - ) - internal_metadata = {} - if getattr(request_output, "kv_transfer_params", None) is not None: - internal_metadata[ - KV_TRANSFER_PARAMS_KEY - ] = request_output.kv_transfer_params - yield LLMRawResponse( - generated_text=text_output, - num_generated_tokens=tokens_collected, - logprobs=log_probs, - num_generated_tokens_batch=tokens_collected, - num_input_tokens=num_input_tokens, - num_input_tokens_batch=num_input_tokens, - preprocessing_time=0, - generation_time=clock.reset_interval(), - finish_reason=finish_reason, - metadata=internal_metadata, - ) + Error: + option A: + when request hits an error, raise an HTTPException(msg, code, type) + option B: + yield a HTTPException object + """ - if request_output is not None: - total_request_time = time.perf_counter() - start - if request_output.metrics is None: - # vLLM V1 metrics are not included in the request output yet. 
- queue_time = "N/A" - generation_time_str = "N/A" - tokens_s = "N/A" - generated_tokens_s = "N/A" - else: - time_in_queue_histogram.observe( - request_output.metrics.time_in_queue - ) - queue_time = f"{request_output.metrics.time_in_queue}s" - generation_time = ( - total_request_time - request_output.metrics.time_in_queue - ) - generation_time_str = f"{generation_time}s" - tokens_s = ( - num_input_tokens + all_tokens_collected - ) / generation_time - generated_tokens_s = all_tokens_collected / generation_time - - logger.info( - f"Request {request.request_id} finished ({finish_reason}). " - f"Total time: {total_request_time}s, " - f"Queue time: {queue_time}, " - f"Generation+async time: {generation_time_str}, " - f"Input tokens: {num_input_tokens}, " - f"Generated tokens: {all_tokens_collected}, " - f"tokens/s: {tokens_s}, " - f"generated tokens/s: {generated_tokens_s}." - ) - else: - logger.warning( - f"Request {request.request_id} " - "finished without any output. " - f"Input tokens: {num_input_tokens}." - ) - except ValueError as e: - error_args = e.args - if len(error_args) == 3 and "Input too long." == error_args[0]: - _, input_length, max_input_length = error_args - raise InputTooLong(input_length, max_input_length).exception from None - elif len(error_args) == 1 and V1_TOO_LONG_PATTERN.match(error_args[0]): - parsed_error = V1_TOO_LONG_PATTERN.match(error_args[0]) - raise InputTooLong( - int(parsed_error[1]), int(parsed_error[2]) - ).exception from None - else: - raise e from None - finally: - # Ensure that we cancel on the engine once we have exited the streaming - # phase - await self._engine_client.abort(request.request_id) + if self._oai_serving_completion is None: + raise RuntimeError( + "Completion service is not available. Make sure the engine is started and supports completions." + ) - def _get_prompt_limit(self) -> int: - """Helper to get the prompt limit from scheduler config + completion_response = await self._oai_serving_completion.create_completion( + request + ) - Port from https://github.com/vllm-project/vllm/blob/7b5ecf79bd94aab0d782c70126d0dcc37c16bc60/vllm/core/scheduler.py#L939 - """ - scheduler_config = self.vllm_config.scheduler_config - if ( - scheduler_config.chunked_prefill_enabled - and not scheduler_config.is_multi_step - ): - prompt_limit = scheduler_config.max_model_len + if isinstance(completion_response, AsyncGenerator): + async for response in completion_response: + if not isinstance(response, str): + raise ValueError( + f"Expected create_completion to return a stream of strings, got and item with type {type(response)}" + ) + yield response else: - prompt_limit = min( - scheduler_config.max_model_len, - scheduler_config.max_num_batched_tokens, - ) - return prompt_limit - - def _handle_input_too_long( - self, request_output: "RequestOutput", finish_reason: Optional[FinishReason] - ): - if ( - finish_reason - and finish_reason == FinishReason.LENGTH - and hasattr(request_output.metrics, "first_token_time") - and request_output.metrics.first_token_time is None - ): - # This means that the prompt was too long and we did not generate anything. 
- raise InputTooLong( - len(request_output.prompt_token_ids), self._get_prompt_limit() - ).exception - - async def embed( - self, vllm_embedding_request: VLLMEmbeddingRequest - ) -> Tuple[List[List[float]], int]: - """Return (embeddings, num_prompt_tokens)""" - - num_prompts = len(vllm_embedding_request.prompt) - if RAYLLM_ENABLE_REQUEST_PROMPT_LOGS: logger.info( - f"Encoding request {vllm_embedding_request.request_id} started. " - f"Num prompts: {num_prompts}" + f"[Kourosh] non streaming response received, type: {type(completion_response)}, completion_response: {completion_response}" ) + if isinstance(completion_response, VLLMErrorResponse): + yield ErrorResponse(**completion_response.model_dump()) + else: + yield CompletionResponse(**completion_response.model_dump()) - generators: List[AsyncGenerator["PoolingRequestOutput", None]] = [] - - prompts = vllm_embedding_request.prompt - if isinstance(prompts, str): - prompts = [prompts] - - for i, prompt in enumerate(prompts): - request_id = f"{vllm_embedding_request.request_id}-{i}" - gen: AsyncGenerator[ - "PoolingRequestOutput", None - ] = self._engine_client.encode( - prompt=vllm.TextPrompt( - prompt=prompt, - ), - pooling_params=vllm.PoolingParams(), - request_id=request_id, - lora_request=vllm_embedding_request.lora_request, # type: ignore - ) - generators.append(gen) + async def embeddings( + self, request: EmbeddingRequest + ) -> AsyncGenerator[Union[EmbeddingResponse, ErrorResponse], None]: + """Generate embeddings using vLLM's OpenAI-compatible API. + + Args: + request: An EmbeddingRequest object. - embedding_data = [] - total_prompt_tokens = 0 + Yields: + An EmbeddingResponse or ErrorResponse object. + """ - for gen in generators: - async for result in gen: - if hasattr(result.outputs, "embedding"): - embedding = result.outputs.embedding - else: - embedding = result.outputs.data.tolist() - if vllm_embedding_request.encoding_format == "base64": - embedding = floats_to_base64(embedding) + if self._oai_serving_embedding is None: + raise RuntimeError( + "Embedding service is not available. Make sure the engine is started and supports embeddings." + ) - embedding_data.append(embedding) - total_prompt_tokens += len(result.prompt_token_ids) + embedding_response = await self._oai_serving_embedding.create_embedding(request) - return embedding_data, total_prompt_tokens + if isinstance(embedding_response, VLLMErrorResponse): + yield ErrorResponse(**embedding_response.model_dump()) + else: + yield EmbeddingResponse(**embedding_response.model_dump()) async def check_health(self) -> None: if not hasattr(self._engine_client, "check_health"): @@ -687,163 +472,3 @@ async def check_health(self) -> None: except BaseException as e: logger.error("Healthcheck failed. 
The replica will be restarted") raise e from None - - @staticmethod - def _collect_usage_metrics(sampling_params: VLLMSamplingParams) -> None: - if sampling_params.best_of is not None: - usage_counters[ArgUsage.BEST_OF].inc() - - if sampling_params.presence_penalty is not None: - usage_counters[ArgUsage.PRESENCE_PENALTY].inc() - - if sampling_params.frequency_penalty is not None: - usage_counters[ArgUsage.FREQUENCY_PENALTY].inc() - - if ( - sampling_params.presence_penalty is not None - and sampling_params.frequency_penalty is not None - ): - usage_counters[ArgUsage.PRESENCE_AND_FREQUENCY_PENALTY].inc() - - if sampling_params.temperature is not None: - usage_counters[ArgUsage.TEMPERATURE].inc() - - if sampling_params.top_p is not None: - usage_counters[ArgUsage.TOP_P].inc() - - if sampling_params.top_k is not None: - usage_counters[ArgUsage.TOP_K].inc() - - if sampling_params.stop is not None: - usage_counters[ArgUsage.STOP].inc() - - if sampling_params.max_tokens is not None: - usage_counters[ArgUsage.MAX_TOKENS].inc() - - if sampling_params.logprobs is not None: - usage_counters[ArgUsage.LOGPROBS].inc() - - def _parse_sampling_params( - self, sampling_params: VLLMSamplingParams - ) -> "VLLMInternalSamplingParams": - """Parse the vllm sampling parameters from the prompt. - This function is used to parse the sampling parameters from the prompt. - It also collects the usage metrics for the sampling parameters. - Args: - sampling_params: The sampling parameters defined in ray.serve.llm. - Returns: - vllm.SamplingParams, The parsed sampling parameters. - """ - self._collect_usage_metrics(sampling_params) - try: - if self.model_config is None: - raise RuntimeError( - "VLLMEngine.model_config not set. Maybe VLLMEngine.start() was not called?" - ) - - log_probs = None - if sampling_params.logprobs: - max_logprobs = getattr(self.model_config, "max_logprobs", 0) - max_logprobs = min(MAX_NUM_TOPLOGPROBS_ALLOWED, max_logprobs) - if max_logprobs == 0: - raise ValueError("This model doesn't support outputting logprobs.") - if sampling_params.top_logprobs: - if not ( - MIN_NUM_TOPLOGPROBS_ALLOWED - <= sampling_params.top_logprobs - <= max_logprobs - ): - raise ValueError( - f"top_logprobs must be between {MIN_NUM_TOPLOGPROBS_ALLOWED} " - f"and {max_logprobs}. Got {sampling_params.top_logprobs}." 
- ) - log_probs = sampling_params.top_logprobs - else: - log_probs = 1 - else: - if sampling_params.top_logprobs: - raise ValueError( - "if top_logprobs is specified, logprobs must be set to `True`" - ) - - kwargs = dict( - n=1, - best_of=sampling_params.best_of, - presence_penalty=0.0, - frequency_penalty=0.0, - repetition_penalty=1.0, - temperature=1.0, - top_p=1.0, - top_k=-1, - stop=sampling_params.stop, - stop_token_ids=sampling_params.stop_tokens, - ignore_eos=False, - # vLLM will cancel internally if input+output>max_tokens - max_tokens=self.model_config.max_model_len, - logprobs=log_probs, - ) - if sampling_params.presence_penalty is not None: - kwargs["presence_penalty"] = sampling_params.presence_penalty - if sampling_params.frequency_penalty is not None: - kwargs["frequency_penalty"] = sampling_params.frequency_penalty - if sampling_params.repetition_penalty is not None: - kwargs["repetition_penalty"] = sampling_params.repetition_penalty - if sampling_params.temperature is not None: - kwargs["temperature"] = sampling_params.temperature - if sampling_params.top_p is not None: - kwargs["top_p"] = sampling_params.top_p - if sampling_params.top_k is not None: - kwargs["top_k"] = sampling_params.top_k - if sampling_params.ignore_eos is not None: - kwargs["ignore_eos"] = sampling_params.ignore_eos - if sampling_params.max_tokens is not None: - kwargs["max_tokens"] = sampling_params.max_tokens - # If we set it to None, vLLM will throw an exception - # as that is not the default value. Omitting it - # will allow vLLM to generate a new seed internally, - # as expected. - if sampling_params.seed is not None: - kwargs["seed"] = sampling_params.seed - if sampling_params.response_format is not None: - kwargs[ - "guided_decoding" - ] = sampling_params.response_format.to_guided_decoding_params( - backend=RAYLLM_GUIDED_DECODING_BACKEND - ) - if sampling_params.kv_transfer_params is not None: - kwargs["extra_args"] = { - KV_TRANSFER_PARAMS_KEY: sampling_params.kv_transfer_params - } - - return vllm.SamplingParams(**kwargs) - except Exception as e: - # Wrap the error in ValidationError so the status code - # returned to the user is correct. 
- raise ValidationError(str(e)) from e - - @staticmethod - def _extract_logprobs( - output: "RequestOutput", - log_probs_idx: int, - top_logprobs: Optional[int] = None, - ) -> Tuple[List[LogProbs], int]: - all_log_probs = output.logprobs[log_probs_idx:] if output.logprobs else None - return_log_probs = [] - if all_log_probs: - for log_probs in all_log_probs: - log_probs_for_n_sampled = [ - LogProb( - logprob=log_prob.logprob, - token=log_prob.decoded_token, - bytes=list(log_prob.decoded_token.encode()), - ) - for log_prob in log_probs.values() - if log_prob.decoded_token is not None - ] - if log_probs_for_n_sampled: - return_log_probs += [ - LogProbs.create( - logprobs=log_probs_for_n_sampled, top_logprobs=top_logprobs - ) - ] - return return_log_probs, log_probs_idx + len(return_log_probs) diff --git a/python/ray/llm/_internal/serve/deployments/llm/vllm/vllm_models.py b/python/ray/llm/_internal/serve/deployments/llm/vllm/vllm_models.py index f0c79e636e23..9dac86c7f7ee 100644 --- a/python/ray/llm/_internal/serve/deployments/llm/vllm/vllm_models.py +++ b/python/ray/llm/_internal/serve/deployments/llm/vllm/vllm_models.py @@ -1,8 +1,9 @@ import dataclasses import os -from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Union +from typing import TYPE_CHECKING, Any, Dict, List, Optional +import dataclasses -from pydantic import ConfigDict, Field, ValidationError, field_validator +from pydantic import ConfigDict, Field from vllm.engine.arg_utils import AsyncEngineArgs from ray.llm._internal.common.base_pydantic import BaseModelExtended @@ -13,13 +14,9 @@ ENV_VARS_TO_PROPAGATE, RAYLLM_GUIDED_DECODING_BACKEND, ) -from ray.llm._internal.serve.configs.prompt_formats import Prompt from ray.llm._internal.serve.configs.server_models import ( - DiskMultiplexConfig, - GenerationRequest, GPUType, LLMConfig, - SamplingParams, ) from ray.llm._internal.serve.observability.logging import get_logger from ray.util.placement_group import ( @@ -29,6 +26,9 @@ placement_group_table, ) +from vllm.engine.arg_utils import AsyncEngineArgs +from vllm.entrypoints.openai.cli_args import FrontendArgs + # The key for the kv_transfer_params in the internal metadata. KV_TRANSFER_PARAMS_KEY = "kv_transfer_params" @@ -77,10 +77,6 @@ def actual_hf_model_id(self) -> str: def trust_remote_code(self) -> bool: return self.engine_kwargs.get("trust_remote_code", False) - @property - def sampling_params_model(self): - return VLLMSamplingParams - def get_initialization_kwargs(self) -> dict: """ Get kwargs that will be actually passed to the LLMInitializer @@ -112,9 +108,6 @@ def get_initialization_kwargs(self) -> dict: ) engine_kwargs["disable_log_stats"] = False - if "guided_decoding_backend" not in engine_kwargs: - engine_kwargs["guided_decoding_backend"] = RAYLLM_GUIDED_DECODING_BACKEND - return engine_kwargs def get_runtime_env_with_local_env_vars(self) -> dict: @@ -145,17 +138,20 @@ def from_llm_config(cls, llm_config: LLMConfig) -> "VLLMEngineConfig": frontend_kwargs = {} # Get field names from dataclasses + frontend_field_names = { + field.name for field in dataclasses.fields(FrontendArgs) + } async_engine_field_names = { field.name for field in dataclasses.fields(AsyncEngineArgs) } for key, value in all_engine_kwargs.items(): - if key in async_engine_field_names: + if key in frontend_field_names: + frontend_kwargs[key] = value + elif key in async_engine_field_names: engine_kwargs[key] = value else: - # Assume anything that is not an engine argument is a frontend - # argument. 
- frontend_kwargs[key] = value + raise ValueError(f"Unknown engine argument: {key}") return VLLMEngineConfig( model_id=llm_config.model_id, @@ -257,92 +253,3 @@ def get_or_create_pg(self) -> PlacementGroup: logger.info(f"Using new placement group {pg}. {placement_group_table(pg)}") return pg - - -class VLLMSamplingParams(SamplingParams): - """Sampling parameters specific to vLLM engine. - - Args: - top_k: The number of highest probability vocabulary tokens to keep for top-k-filtering. - seed: Seed for deterministic sampling with temperature>0. - repetition_penalty: Float that penalizes new tokens based on whether they - appear in the prompt and the generated text so far. Values > 1 encourage - the model to use new tokens, while values < 1 encourage the model to repeat - tokens. - """ - - _ignored_fields = {"best_of", "n", "logit_bias"} - - top_k: Optional[int] = None - repetition_penalty: Optional[float] = None - seed: Optional[int] = None - kv_transfer_params: Optional[Dict[str, Any]] = None - - @field_validator("n", mode="before") - @classmethod - def validate_n(cls, values): - if values != 1: - raise ValidationError("n>1 is not supported yet in rayllm.") - return values - - @classmethod - def _get_model_validate_kwargs(cls, prompt: Prompt) -> Dict[str, Any]: - """ - Extend the base class's `_get_model_validate_kwargs` to include vllm-specific parameters. - """ - generate_kwargs = super()._get_model_validate_kwargs(prompt) - if ( - prompt.parameters is not None - and KV_TRANSFER_PARAMS_KEY in prompt.parameters - ): - generate_kwargs[KV_TRANSFER_PARAMS_KEY] = prompt.parameters[ - KV_TRANSFER_PARAMS_KEY - ] - return generate_kwargs - - -class VLLMGenerationRequest(GenerationRequest): - model_config = ConfigDict(arbitrary_types_allowed=True) - - # Intentionally override the base class's `sampling_params` field. 
- sampling_params: Optional[ - Union[ - VLLMSamplingParams, - List[VLLMSamplingParams], - ] - ] = None - multi_modal_data: Optional[Dict[str, Any]] = None - disk_multiplex_config: Optional[DiskMultiplexConfig] = None - - @property - def lora_request(self) -> "LoRARequest": - disk_vllm_config = self.disk_multiplex_config - if not disk_vllm_config: - return None - else: - return vllm.lora.request.LoRARequest( - lora_name=disk_vllm_config.model_id, - lora_int_id=disk_vllm_config.lora_assigned_int_id, - lora_local_path=disk_vllm_config.local_path, - long_lora_max_len=disk_vllm_config.max_total_tokens, - ) - - -class VLLMEmbeddingRequest(GenerationRequest): - model_config = ConfigDict(arbitrary_types_allowed=True) - encoding_format: Optional[Literal["float", "base64"]] = "float" - dimensions: Optional[int] = None - disk_multiplex_config: Optional[DiskMultiplexConfig] = None - - @property - def lora_request(self) -> "LoRARequest": - disk_vllm_config = self.disk_multiplex_config - if not disk_vllm_config: - return None - else: - return vllm.lora.request.LoRARequest( - lora_name=disk_vllm_config.model_id, - lora_int_id=disk_vllm_config.lora_assigned_int_id, - lora_local_path=disk_vllm_config.local_path, - long_lora_max_len=disk_vllm_config.max_total_tokens, - ) diff --git a/python/ray/llm/_internal/serve/deployments/prefill_decode_disagg/prefill_decode_disagg.py b/python/ray/llm/_internal/serve/deployments/prefill_decode_disagg/prefill_decode_disagg.py index 399ddbba584b..25579d284f23 100644 --- a/python/ray/llm/_internal/serve/deployments/prefill_decode_disagg/prefill_decode_disagg.py +++ b/python/ray/llm/_internal/serve/deployments/prefill_decode_disagg/prefill_decode_disagg.py @@ -8,12 +8,9 @@ from vllm.config import KVTransferConfig from ray import serve -from ray.llm._internal.serve.configs.prompt_formats import Prompt from ray.llm._internal.serve.configs.server_models import ( - LLMRawResponse, parse_args as parse_llm_configs, ) -from ray.llm._internal.serve.deployments.llm.llm_server import ResponsePostprocessor from ray.llm._internal.serve.deployments.llm.vllm.vllm_models import ( KV_TRANSFER_PARAMS_KEY, ) @@ -26,8 +23,18 @@ ModelLoadingConfig, build_llm_deployment, ) +from ray.llm._internal.serve.configs.openai_api_models import ( + ChatCompletionRequest, + CompletionRequest, + ChatCompletionResponse, + CompletionResponse, + ErrorResponse, + EmbeddingRequest, + EmbeddingResponse, +) logger = logging.getLogger(__name__) +RequestType = Union[ChatCompletionRequest, CompletionRequest] class PDServingArgs(BaseModel): @@ -92,27 +99,20 @@ async def __init__( llm_config, ) - self.prefill_server = prefill_server - self.decode_server = decode_server + self.prefill_server = prefill_server.options(stream=True) + self.decode_server = decode_server.options(stream=True) - async def _predict( - self, - request_id: str, - prompt: Prompt, - stream: bool, - ) -> AsyncGenerator[LLMRawResponse, None]: - """ - Disaggregate the P/D requests: - 1. Send the request to the prefill server. - 2. Parse the response and forward necessary fields to the decode server. - 3. Return the response from the decode server. 
- """ + async def embeddings( + self, request: EmbeddingRequest + ) -> AsyncGenerator[EmbeddingResponse, None]: + raise NotImplementedError("Embedding is not supported for P/D disaggregation") + def _prepare_prefill_request(self, request: RequestType) -> RequestType: assert ( - prompt.parameters.get(KV_TRANSFER_PARAMS_KEY, None) is None - ), f"{KV_TRANSFER_PARAMS_KEY} should be empty before proxy" - prefill_prompt = prompt.model_copy(deep=True) - prefill_prompt.parameters[KV_TRANSFER_PARAMS_KEY] = { + getattr(request, "kv_transfer_params", None) is None + ), f"kv_transfer_params should be empty before proxy" + prefill_request = request.model_copy(deep=True) + prefill_request.kv_transfer_params = { "do_remote_decode": True, "do_remote_prefill": False, "remote_engine_id": None, @@ -120,37 +120,61 @@ async def _predict( "remote_host": None, "remote_port": None, } - prefill_prompt.parameters["max_tokens"] = 1 - - prefill_response_gen: AsyncGenerator[ - LLMRawResponse, None - ] = self.prefill_server.options( - # _predict returns generator, we have to set stream=True - stream=True - )._predict.remote( - request_id=request_id, prompt=prefill_prompt, stream=False - ) + prefill_request.max_tokens = 1 + prefill_request.stream = False - prefill_response = await ResponsePostprocessor.merge_stream( - prefill_response_gen - ) + return prefill_request + + def _prepare_decode_request( + self, + request: RequestType, + prefill_chunk: Union[ChatCompletionResponse, CompletionResponse], + ) -> RequestType: + decode_request = request.model_copy(deep=True) + decode_request.kv_transfer_params = prefill_chunk.kv_transfer_params - if prefill_response.error: - logger.error(f"Prefill server returned error: {prefill_response.error}") - yield prefill_response + return decode_request + + async def _handle_request( + self, + request: RequestType, + ) -> AsyncGenerator[ + Union[str, ChatCompletionResponse, CompletionResponse, ErrorResponse], None + ]: + + if isinstance(request, ChatCompletionRequest): + method = "chat" + elif isinstance(request, CompletionRequest): + method = "completions" + else: + raise ValueError(f"Unsupported request type: {type(request)}") + + prefill_request = self._prepare_prefill_request(request) + prefill_gen = getattr(self.prefill_server, method).remote(prefill_request) + + prefill_chunk = await anext(prefill_gen) + + if isinstance(prefill_chunk, ErrorResponse): + logger.error(f"Prefill returned error: {prefill_chunk.error}") + yield prefill_chunk return - kv_transfer_params = prefill_response.metadata[KV_TRANSFER_PARAMS_KEY] - logger.debug( - f"Prefill metadata[{KV_TRANSFER_PARAMS_KEY}]: {kv_transfer_params}" - ) - prompt.parameters[KV_TRANSFER_PARAMS_KEY] = kv_transfer_params + decode_request = self._prepare_decode_request(request, prefill_chunk) + decode_gen = self.decode_server.chat.remote(decode_request) - async for chunk in self.decode_server.options(stream=True)._predict.remote( - request_id=request_id, prompt=prompt, stream=stream - ): + async for chunk in decode_gen: yield chunk + async def chat( + self, request: ChatCompletionRequest + ) -> AsyncGenerator[Union[str, ChatCompletionResponse, ErrorResponse], None]: + return self._handle_request(request) + + async def completions( + self, request: CompletionRequest + ) -> AsyncGenerator[Union[str, CompletionResponse, ErrorResponse], None]: + return self._handle_request(request) + @classmethod def as_deployment(cls) -> serve.Deployment: """Turns PDProxyServer into a Ray Serve deployment.""" diff --git 
a/python/ray/llm/_internal/serve/deployments/routers/middleware.py b/python/ray/llm/_internal/serve/deployments/routers/middleware.py index d2c2a7a2abde..961e199332ff 100644 --- a/python/ray/llm/_internal/serve/deployments/routers/middleware.py +++ b/python/ray/llm/_internal/serve/deployments/routers/middleware.py @@ -70,7 +70,7 @@ def _uncaught_exception_handler(request: Request, e: Exception): response_payload = get_response_for_error(e, request_id) return JSONResponse( - content=response_payload.model_dump(), status_code=response_payload.error.code + content=response_payload.model_dump(), status_code=response_payload.code ) @@ -115,7 +115,7 @@ async def _handle_application_exceptions( return JSONResponse( content=response_payload.model_dump(), - status_code=response_payload.error.code, + status_code=response_payload.code, ) # This adds last-resort uncaught exception handler into Starlette diff --git a/python/ray/llm/_internal/serve/deployments/routers/router.py b/python/ray/llm/_internal/serve/deployments/routers/router.py index e488f269605c..9782940dc9e5 100644 --- a/python/ray/llm/_internal/serve/deployments/routers/router.py +++ b/python/ray/llm/_internal/serve/deployments/routers/router.py @@ -43,15 +43,12 @@ LLMEmbeddingsResponse, OpenAIHTTPException, to_model_metadata, -) -from ray.llm._internal.serve.configs.openai_api_models_patch import ( ErrorResponse, + ModelCard, + ModelList, ) -from ray.llm._internal.serve.configs.server_models import ( - LLMConfig, - Model, - ModelData, -) + +from ray.llm._internal.serve.configs.server_models import LLMConfig from ray.llm._internal.serve.deployments.llm.multiplex.utils import ( get_base_model_id, get_lora_model_ids, @@ -136,10 +133,15 @@ def _apply_openai_json_format( data: \n\ndata: \n\n... """ if isinstance(response, list): + first_response = next(iter(response)) + if isinstance(first_response, str): + return "".join(response) return "".join(f"data: {r.model_dump_json()}\n\n" for r in response) if hasattr(response, "model_dump_json"): return f"data: {response.model_dump_json()}\n\n" - raise ValueError(f"Unexpected response type: {type(response)}") + if isinstance(response, str): + return response + raise ValueError(f"Unexpected response type: {type(response)}, {response=}") async def _peek_at_generator( @@ -296,7 +298,7 @@ async def _get_response( async for response in getattr(model_handle, call_method).remote(body): yield response - async def model(self, model_id: str) -> Optional[ModelData]: + async def model(self, model_id: str) -> Optional[ModelCard]: if model_id in self._llm_configs: return to_model_metadata(model_id, self._llm_configs[model_id]) @@ -322,8 +324,8 @@ async def model(self, model_id: str) -> Optional[ModelData]: "Check that adapter config file exists in cloud bucket." 
) - @fastapi_router_app.get("/v1/models", response_model=Model) - async def models(self) -> Model: + @fastapi_router_app.get("/v1/models", response_model=ModelList) + async def models(self) -> ModelList: """OpenAI API-compliant endpoint to get all rayllm models.""" all_models = dict() for base_model_id, llm_config in self._llm_configs.items(): @@ -341,11 +343,11 @@ async def models(self) -> Model: if model_data is not None: all_models[lora_id] = model_data - return Model(data=list(all_models.values())) + return ModelList(data=list(all_models.values())) # :path allows us to have slashes in the model name - @fastapi_router_app.get("/v1/models/{model:path}", response_model=ModelData) - async def model_data(self, model: str) -> ModelData: + @fastapi_router_app.get("/v1/models/{model:path}", response_model=ModelCard) + async def model_data(self, model: str) -> ModelCard: """OpenAI API-compliant endpoint to get one rayllm model. :param model: The model ID (e.g. "amazon/LightGPT") diff --git a/python/ray/llm/_internal/serve/deployments/utils/node_initialization_utils.py b/python/ray/llm/_internal/serve/deployments/utils/node_initialization_utils.py index c1ba2edb005f..af1650fbe996 100644 --- a/python/ray/llm/_internal/serve/deployments/utils/node_initialization_utils.py +++ b/python/ray/llm/_internal/serve/deployments/utils/node_initialization_utils.py @@ -143,7 +143,8 @@ def _initialize_local_node( if not isinstance(local_path, str) or not os.path.exists(local_path): logger.info(f"Downloading the tokenizer for {engine_config.actual_hf_model_id}") - _ = transformers.AutoTokenizer.from_pretrained( - engine_config.actual_hf_model_id, - trust_remote_code=engine_config.trust_remote_code, - ) + # TODO: NEEDED for Mistral models that don't support tekken + # _ = transformers.AutoTokenizer.from_pretrained( + # engine_config.actual_hf_model_id, + # trust_remote_code=engine_config.trust_remote_code, + # ) diff --git a/python/ray/llm/_internal/serve/deployments/utils/server_utils.py b/python/ray/llm/_internal/serve/deployments/utils/server_utils.py index b54b4cb6d5b5..e6628e266e38 100644 --- a/python/ray/llm/_internal/serve/deployments/utils/server_utils.py +++ b/python/ray/llm/_internal/serve/deployments/utils/server_utils.py @@ -11,12 +11,7 @@ from ray import serve from ray.llm._internal.serve.configs.openai_api_models import OpenAIHTTPException -from ray.llm._internal.serve.configs.openai_api_models_patch import ( - ErrorResponse, -) -from ray.llm._internal.serve.configs.server_models import ( - LLMRawResponse, -) +from ray.llm._internal.serve.configs.openai_api_models import ErrorResponse from ray.llm._internal.serve.observability.logging import get_logger logger = get_logger(__name__) @@ -78,7 +73,7 @@ def _extract_message(e): def get_response_for_error( e: Exception, request_id: str, -) -> LLMRawResponse: +) -> ErrorResponse: if isinstance(e, HTTPException): status_code = e.status_code elif isinstance(e, OpenAIHTTPException): @@ -116,13 +111,11 @@ def get_response_for_error( internal_message += f" (Request ID: {request_id})" error_response = ErrorResponse( - message=message, + message=f"Message: {message}, Internal exception: {internal_message}, original exception: {str(e)}", code=status_code, - internal_message=internal_message, type=exc_type, - original_exception=e, ) - return LLMRawResponse(error=error_response) + return error_response def get_serve_request_id() -> str: @@ -140,10 +133,3 @@ def get_model_request_id(model: str): def replace_prefix(model: str) -> str: """Replace -- with / in model 
name to handle slashes within the URL path segment""" return model.replace("--", "/") - - -def floats_to_base64(float_list: List[float]) -> str: - """Encode a list of floats as base64 as needed for the embedding API response.""" - binary = struct.pack(f"{len(float_list)}f", *float_list) - encoded = base64.b64encode(binary).decode("utf-8") - return encoded diff --git a/python/ray/llm/tests/serve/conftest.py b/python/ray/llm/tests/serve/conftest.py index 4ca469db2bea..4b6c5a38390e 100644 --- a/python/ray/llm/tests/serve/conftest.py +++ b/python/ray/llm/tests/serve/conftest.py @@ -14,6 +14,11 @@ from ray.llm._internal.serve.deployments.llm.vllm.vllm_models import ( VLLMEngineConfig, ) +from ray.llm._internal.serve.configs.openai_api_models import ( + ChatCompletionRequest, + CompletionRequest, + EmbeddingCompletionRequest, +) from ray.serve.llm import ( LLMConfig, LLMServer, @@ -62,6 +67,50 @@ def llm_config(model_pixtral_12b, disable_placement_bundles): ) +@pytest.fixture +def mock_llm_config(): + """LLM config for mock engine testing.""" + return LLMConfig( + model_loading_config=ModelLoadingConfig(model_id="mock-model"), + runtime_env={}, + log_engine_metrics=False, + ) + + +@pytest.fixture +def mock_chat_request(stream, max_tokens): + """Fixture for creating chat completion requests for mock testing.""" + return ChatCompletionRequest( + model="mock-model", + messages=[{"role": "user", "content": "Hello, world!"}], + max_tokens=max_tokens, + stream=stream, + ) + + +@pytest.fixture +def mock_completion_request(stream, max_tokens): + """Fixture for creating text completion requests for mock testing.""" + return CompletionRequest( + model="mock-model", + prompt="Complete this text:", + max_tokens=max_tokens, + stream=stream, + ) + + +@pytest.fixture +def mock_embedding_request(dimensions): + """Fixture for creating embedding requests for mock testing.""" + request = EmbeddingCompletionRequest( + model="mock-model", + input="Text to embed", + ) + if dimensions: + request.dimensions = dimensions + return request + + def get_test_model_path(yaml_file: str) -> pathlib.Path: current_file_dir = pathlib.Path(__file__).absolute().parent test_model_path = current_file_dir / yaml_file diff --git a/python/ray/llm/tests/serve/cpu/config_generator/test_text_completion.py b/python/ray/llm/tests/serve/cpu/config_generator/test_text_completion.py index d8464402fd19..306594caad43 100644 --- a/python/ray/llm/tests/serve/cpu/config_generator/test_text_completion.py +++ b/python/ray/llm/tests/serve/cpu/config_generator/test_text_completion.py @@ -101,7 +101,7 @@ def test_populate_custom_model( model_config = populate_text_completion_model_config(input_model_config) self._assert_models(model_config, input_model_config) - serve_config = get_serve_config(input_model_config, "./file.yaml") + serve_config = get_serve_config("./file.yaml") assert len(serve_config["applications"][0]["args"]["llm_configs"]) == 1 def _assert_models( diff --git a/python/ray/llm/tests/serve/cpu/configs/test_openai_api_models.py b/python/ray/llm/tests/serve/cpu/configs/test_openai_api_models.py deleted file mode 100644 index ff92ecea0a7b..000000000000 --- a/python/ray/llm/tests/serve/cpu/configs/test_openai_api_models.py +++ /dev/null @@ -1,29 +0,0 @@ -from ray.llm._internal.serve.configs.openai_api_models import DeltaMessage - - -def test_delta_message_null_content(): - """Test that the DeltaMessage class is correctly constructed. - - When the content is passed as None, it should be set to an empty string. 
- """ - role = "user" - delta_message_implicitly_null_content = DeltaMessage( - role=role, - ) - - delta_message_explicitly_null_content = DeltaMessage( - role=role, - content=None, - ) - - delta_message_empty_string_content = DeltaMessage( - role=role, - content="", - ) - - assert delta_message_implicitly_null_content.role == role - assert delta_message_explicitly_null_content.role == role - assert delta_message_empty_string_content.role == role - assert delta_message_implicitly_null_content.content == "" - assert delta_message_explicitly_null_content.content == "" - assert delta_message_empty_string_content.content == "" diff --git a/python/ray/llm/tests/serve/cpu/configs/test_prompt_formats.py b/python/ray/llm/tests/serve/cpu/configs/test_prompt_formats.py deleted file mode 100644 index e120d7c1f5f5..000000000000 --- a/python/ray/llm/tests/serve/cpu/configs/test_prompt_formats.py +++ /dev/null @@ -1,83 +0,0 @@ -import sys - -import pytest -from pydantic import ValidationError - -from ray.llm._internal.serve.configs.prompt_formats import ( - Image, - Message, - Prompt, - Text, -) - - -def test_validation_message(): - # check that message with assistant role can have content that - # is a string or none, but nothing else - Message.model_validate({"role": "assistant", "content": "Hello, World!"}) - - Message.model_validate({"role": "assistant", "content": ""}) - - Message.model_validate({"role": "assistant", "content": None}) - - with pytest.raises(ValueError): - Message.model_validate( - { - "role": "assistant", - "content": { - "NOT_VALID", - }, - } - ) - - # Test system and user roles - for role in ["system", "user"]: - # this should pass - Message.model_validate({"role": role, "content": "Hello, World!"}) - - Message.model_validate({"role": role, "content": ""}) - - # a non string content should raise an error - - with pytest.raises(ValueError): - Message.model_validate( - { - "role": role, - "content": { - "NOT_VALID", - }, - } - ) - - with pytest.raises(ValueError): - Message.model_validate({"role": role, "content": None}) - - # test message with image. 
- Message( - role="user", - content=[ - Text(type="text", text="This is a test."), - Image(type="image_url", image_url={"url": "foo"}), - ], - ) - - -def test_prompt_validation(): - # Test valid prompt creation - Prompt(prompt="This is a test message.") - - Prompt( - prompt=[ - Message(role="system", content="You are a helpful assistant."), - Message(role="user", content="Hello!"), - ] - ) - - # Test invalid prompt creation - with pytest.raises(ValidationError): - # Empty list should raise error - Prompt(prompt=[]) - - -if __name__ == "__main__": - sys.exit(pytest.main(["-v", __file__])) diff --git a/python/ray/llm/tests/serve/cpu/configs/test_server_models.py b/python/ray/llm/tests/serve/cpu/configs/test_server_models.py deleted file mode 100644 index a885a88e2b11..000000000000 --- a/python/ray/llm/tests/serve/cpu/configs/test_server_models.py +++ /dev/null @@ -1,96 +0,0 @@ -import sys - -import pytest - -from ray.llm._internal.serve.configs.prompt_formats import Prompt -from ray.llm._internal.serve.configs.server_models import SamplingParams - - -class TestSamplingParams: - def test_default_initialization(self): - """Test that SamplingParams can be initialized with default values.""" - params = SamplingParams() - - assert params.max_tokens is None - assert params.temperature is None - assert params.top_p is None - assert params.n == 1 - assert params.logprobs is None - assert params.top_logprobs is None - assert params.logit_bias is None - assert params.stop is None - assert params.stop_tokens is None - assert params.ignore_eos is None - assert params.presence_penalty is None - assert params.frequency_penalty is None - assert params.best_of == 1 - assert params.response_format is None - - def test_initialization_with_values(self): - """Test that SamplingParams can be initialized with specific values.""" - params = SamplingParams( - max_tokens=100, - temperature=0.7, - top_p=0.9, - n=2, - logprobs=True, - top_logprobs=5, - stop=["END", "STOP"], - stop_tokens=[1, 2, 3], - presence_penalty=0.5, - frequency_penalty=0.3, - best_of=3, - ) - - assert params.max_tokens == 100 - assert params.temperature == 0.7 - assert params.top_p == 0.9 - assert params.n == 2 - assert params.logprobs is True - assert params.top_logprobs == 5 - assert params.stop == ["END", "STOP"] - assert params.stop_tokens == [1, 2, 3] - assert params.presence_penalty == 0.5 - assert params.frequency_penalty == 0.3 - assert params.best_of == 3 - - def test_stop_valid_sequences(self): - """Test that valid stop sequences are processed correctly.""" - stop_sequences = ["END", "STOP", "FINISH", "END"] - params = SamplingParams(stop=stop_sequences) - assert params.stop == ["END", "FINISH", "STOP"] # Should be unique - - def test_idempotency(self): - params = SamplingParams() - new_params = SamplingParams.model_validate(params.model_dump()) - assert params.model_dump() == new_params.model_dump() - - @pytest.mark.parametrize( - "stop, stop_tokens", - [ - (["B-END", "A-End"], None), - (["B-END", "A-End"], []), - (None, [100, 50]), - (None, None), - ], - ) - def test_from_prompt_with_dict_parameters(self, stop, stop_tokens): - """Test from_prompt method with dictionary parameters.""" - prompt = Prompt( - prompt="Test prompt", - parameters={ - "stop": stop, - "stop_tokens": stop_tokens, - }, - ) - - params = SamplingParams.from_prompt(prompt) - - assert params.stop == (sorted(stop) if stop is not None else None) - assert params.stop_tokens == ( - sorted(stop_tokens) if stop_tokens is not None else None - ) - - -if __name__ == 
"__main__": - sys.exit(pytest.main(["-v", __file__])) diff --git a/python/ray/llm/tests/serve/cpu/deployments/llm/multiplex/test_lora_deployment_base_client.py b/python/ray/llm/tests/serve/cpu/deployments/llm/multiplex/test_lora_deployment_base_client.py index 282130cefa20..7c806cade746 100644 --- a/python/ray/llm/tests/serve/cpu/deployments/llm/multiplex/test_lora_deployment_base_client.py +++ b/python/ray/llm/tests/serve/cpu/deployments/llm/multiplex/test_lora_deployment_base_client.py @@ -6,9 +6,9 @@ from fastapi import HTTPException from ray import serve -from ray.llm._internal.serve.configs.server_models import ModelData +from ray.llm._internal.serve.configs.openai_api_models import ModelCard from ray.llm._internal.serve.deployments.llm.llm_server import LLMDeployment -from ray.llm.tests.serve.mocks.mock_vllm_engine import MockEchoVLLMEngine +from ray.llm.tests.serve.mocks.mock_vllm_engine import MockVLLMEngine from ray.serve.handle import DeploymentHandle from ray.serve.llm import LLMConfig, LLMRouter, LoraConfig @@ -57,7 +57,7 @@ def get_mocked_llm_deployments(llm_configs) -> List[DeploymentHandle]: llm_deployments.append( deployment.bind( llm_config=llm_config, - engine_cls=MockEchoVLLMEngine, + engine_cls=MockVLLMEngine, ) ) return llm_deployments @@ -97,10 +97,10 @@ async def test_lora_get_model(shutdown_ray_and_serve, disable_placement_bundles) # Case 2: Model has only the base model config. base_model_config = await router_handle.model.remote(base_model_id) - assert isinstance(base_model_config, ModelData) + assert isinstance(base_model_config, ModelCard) base_model_data = base_model_config.model_dump() assert base_model_data["id"] == base_model_id - base_model_config = base_model_data["rayllm_metadata"] + base_model_config = base_model_data["metadata"] # Case 3: model has a multiplex config in the cloud. 
llm_config = VLLM_APP.model_copy(deep=True) @@ -122,10 +122,10 @@ async def fake_get_lora_model_metadata(*args, **kwargs): router_handle = serve.run(router_deployment) lora_model_config = await router_handle.model.remote(lora_model) - assert isinstance(lora_model_config, ModelData) + assert isinstance(lora_model_config, ModelCard) lora_model_data = lora_model_config.model_dump() assert lora_model_data["id"] == lora_model - lora_metadata = lora_model_data["rayllm_metadata"] + lora_metadata = lora_model_data["metadata"] assert lora_metadata["model_id"] == lora_model assert lora_metadata["base_model_id"] == base_model_id assert lora_metadata["max_request_context_length"] == 4096 diff --git a/python/ray/llm/tests/serve/cpu/deployments/llm/multiplex/test_multiplex_deployment.py b/python/ray/llm/tests/serve/cpu/deployments/llm/multiplex/test_multiplex_deployment.py deleted file mode 100644 index 4680ad8b273f..000000000000 --- a/python/ray/llm/tests/serve/cpu/deployments/llm/multiplex/test_multiplex_deployment.py +++ /dev/null @@ -1,83 +0,0 @@ -import sys - -import pytest - -from ray import serve -from ray.llm._internal.serve.configs.prompt_formats import ( - Prompt, -) -from ray.llm._internal.serve.configs.server_models import ( - LLMConfig, -) -from ray.llm._internal.serve.deployments.llm.llm_server import LLMDeployment -from ray.llm.tests.serve.mocks.mock_vllm_engine import ( - FakeLoraModelLoader, - MockMultiplexEngine, -) - - -@pytest.fixture(name="handle") -def handle(shutdown_ray_and_serve): - - llm_config = LLMConfig( - model_loading_config={ - "model_id": "meta-llama/Llama-2-7b-hf", - }, - lora_config={ - "max_num_adapters_per_replica": 16, - "dynamic_lora_loading_path": "s3://my/s3/path_here", - }, - ) - - handle = serve.run( - LLMDeployment.options(placement_group_bundles=[{"CPU": 1}],).bind( - llm_config, - engine_cls=MockMultiplexEngine, - model_downloader=FakeLoraModelLoader(), - ), - ) - - return handle - - -@pytest.mark.asyncio -@pytest.mark.parametrize("stream_tokens", [True, False]) -@pytest.mark.parametrize("multiplexed_model_id", ["test_model", None]) -async def test_multiplex_deployment( - handle, - stream_tokens: bool, - multiplexed_model_id: str, -): - - gen = handle.options( - stream=True, multiplexed_model_id=multiplexed_model_id - )._predict.remote( - "req_id", - Prompt(prompt="Generate some sql please.", use_prompt_format=False), - stream=stream_tokens, - ) - - # gen is an async generator - # we need to convert it to a list of outputs in one line - outputs = [] - async for x in gen: - outputs.append(x) - - assert len(outputs) == 1 - output = outputs[0] - - assert output.stream == stream_tokens - - if multiplexed_model_id is None: - assert output.disk_multiplex_config is None - else: - assert output.disk_multiplex_config.model_dump() == { - "model_id": multiplexed_model_id, - "max_total_tokens": None, - "local_path": "/local/path", - "lora_assigned_int_id": 1, - } - - -if __name__ == "__main__": - sys.exit(pytest.main(["-v", __file__])) diff --git a/python/ray/llm/tests/serve/cpu/deployments/llm/test_llm_engine.py b/python/ray/llm/tests/serve/cpu/deployments/llm/test_llm_engine.py new file mode 100644 index 000000000000..a7253dde1dec --- /dev/null +++ b/python/ray/llm/tests/serve/cpu/deployments/llm/test_llm_engine.py @@ -0,0 +1,82 @@ +"""This tests the LLM engine by testing the mocked implementations directly. + +This implicitly tests the consistency of the engine API through time. 
+Also tests that our Mock is behaving as expected to ensure that the downstream tests using Mocks are correct from Mock implementation perspective. + + +We have the following Mock: + +- An engine that returns a string of form "test_i" for i in range(max_tokens) +""" + +from ray.llm.tests.serve.mocks.mock_vllm_engine import MockVLLMEngine +from ray.llm.tests.serve.utils.testing_utils import LLMResponseValidator + +import pytest +from typing import Optional + + +class TestMockLLMEngine: + @pytest.mark.parametrize("api_type", ["chat", "completion"]) + @pytest.mark.parametrize("stream", [False, True]) + @pytest.mark.parametrize("max_tokens", [5]) + @pytest.mark.asyncio + async def test_unified_llm_engine( + self, + mock_llm_config, + mock_chat_request, + mock_completion_request, + api_type: str, + stream: bool, + max_tokens: int, + ): + """Unified test for both chat and completion APIs, streaming and non-streaming.""" + # Create and start the engine + engine = MockVLLMEngine(mock_llm_config) + await engine.start() + + # Create request based on API type + if api_type == "chat": + request = mock_chat_request + response_generator = engine.chat(request) + elif api_type == "completion": + request = mock_completion_request + response_generator = engine.completions(request) + + print( + f"\n\n_____ {api_type.upper()} ({'STREAMING' if stream else 'NON-STREAMING'}) max_tokens={max_tokens} _____\n\n" + ) + + if stream: + # Collect streaming chunks + chunks = [] + async for chunk in response_generator: + assert isinstance(chunk, str) + chunks.append(chunk) + + # Validate streaming response + LLMResponseValidator.validate_streaming_chunks(chunks, api_type, max_tokens) + else: + # Validate non-streaming response + async for response in response_generator: + LLMResponseValidator.validate_non_streaming_response( + response, api_type, max_tokens + ) + + @pytest.mark.parametrize("dimensions", [None, 512]) + @pytest.mark.asyncio + async def test_embedding_mock_engine( + self, mock_llm_config, mock_embedding_request, dimensions: Optional[int] + ): + """Test embedding API with different dimensions.""" + # Create and start the engine + engine = MockVLLMEngine(mock_llm_config) + await engine.start() + + # Create embedding request + request = mock_embedding_request + + print(f"\n\n_____ EMBEDDING dimensions={dimensions} _____\n\n") + + async for response in engine.embeddings(request): + LLMResponseValidator.validate_embedding_response(response, dimensions) diff --git a/python/ray/llm/tests/serve/cpu/deployments/llm/test_llm_server.py b/python/ray/llm/tests/serve/cpu/deployments/llm/test_llm_server.py index 146aa7f96d8e..dd16a4f094f2 100644 --- a/python/ray/llm/tests/serve/cpu/deployments/llm/test_llm_server.py +++ b/python/ray/llm/tests/serve/cpu/deployments/llm/test_llm_server.py @@ -1,427 +1,267 @@ import sys -from unittest.mock import AsyncMock +from typing import Optional import pytest +from unittest.mock import patch -from ray.llm._internal.serve.configs.constants import MODEL_RESPONSE_BATCH_TIMEOUT_MS -from ray.llm._internal.serve.configs.openai_api_models import ( - ChatCompletionRequest, - CompletionRequest, - ErrorResponse, +from ray.llm.tests.serve.mocks.mock_vllm_engine import ( + MockVLLMEngine, + FakeLoraModelLoader, ) -from ray.llm._internal.serve.configs.server_models import ( - FinishReason, - LLMConfig, - LLMRawResponse, - ModelLoadingConfig, -) -from ray.llm._internal.serve.deployments.llm.llm_server import ( - ResponsePostprocessor, -) -from ray.llm.tests.serve.mocks.mock_vllm_engine import 
MockVLLMEngine - - -async def stream_generator(): - yield LLMRawResponse( - generated_text="Hello", - num_generated_tokens=1, - num_generated_tokens_batch=1, - num_input_tokens=5, - finish_reason=None, +from ray.llm.tests.serve.utils.testing_utils import LLMResponseValidator +from ray import serve +from ray.llm._internal.serve.deployments.llm.llm_server import LLMServer +from ray.llm._internal.serve.configs.server_models import LoraConfig + + +@pytest.fixture +def serve_handle(mock_llm_config, stream_batching_interval_ms=0): + mock_llm_config.experimental_configs = { + "stream_batching_interval_ms": stream_batching_interval_ms, + } + + app = serve.deployment(LLMServer).bind(mock_llm_config, engine_cls=MockVLLMEngine) + handle = serve.run(app) + # We set stream=True because the interfaces are async generators regardless + # of the stream flag on request. + handle = handle.options(stream=True) + yield handle + serve.shutdown() + + +@pytest.fixture +def multiplexed_serve_handle(mock_llm_config, stream_batching_interval_ms=0): + mock_llm_config.experimental_configs = { + "stream_batching_interval_ms": stream_batching_interval_ms, + } + mock_llm_config.lora_config = LoraConfig( + dynamic_lora_loading_path="s3://my/s3/path_here", + download_timeout_s=60, + max_download_tries=3, ) - yield LLMRawResponse( - generated_text=" world", - num_generated_tokens=1, - num_generated_tokens_batch=1, - num_input_tokens=5, - finish_reason=FinishReason.STOP, + app = serve.deployment(LLMServer).bind( + mock_llm_config, + engine_cls=MockVLLMEngine, + model_downloader=FakeLoraModelLoader, ) - - -class TestResponsePostprocessor: - @pytest.mark.asyncio - async def test_process_chat_streaming(self): - """Test processing streaming chat responses.""" - postprocessor = ResponsePostprocessor() - model = "test_model" - - # Process the generator as a streaming chat response - response_gen = postprocessor.process_chat( - model, stream_generator(), stream=True - ) - - # Collect all responses - responses = [resp async for resp in response_gen] - - # Verify we got the expected responses - assert len(responses) >= 3 # Role message + content chunks + final message - assert ( - responses[0].choices[0].delta.role == "assistant" - ) # First message has role - assert ( - responses[1].choices[0].delta.content == "Hello" - ) # Second has first chunk - assert ( - responses[-1].choices[0].finish_reason == "stop" - ) # Last has finish reason - - @pytest.mark.asyncio - async def test_process_chat_non_streaming(self): - """Test processing non-streaming chat responses.""" - postprocessor = ResponsePostprocessor() - model = "test_model" - - # Process the generator as a non-streaming chat response - response_gen = postprocessor.process_chat( - model, stream_generator(), stream=False - ) - - # Collect the single response - responses = [resp async for resp in response_gen] - assert len(responses) == 1 - - # Verify the content of the response - response = responses[0] - assert response.choices[0].message.role == "assistant" - assert response.choices[0].message.content == "Hello world" - assert response.choices[0].finish_reason == "stop" - assert response.usage.prompt_tokens == 5 - assert response.usage.completion_tokens == 2 - assert response.usage.total_tokens == 7 - - @pytest.mark.asyncio - async def test_process_completions_streaming(self): - """Test processing streaming completion responses.""" - postprocessor = ResponsePostprocessor() - model = "test_model" - - # Process the generator as a streaming completion response - response_gen = 
postprocessor.process_completions( - model, stream_generator(), stream=True - ) - - # Collect all responses - responses = [resp async for resp in response_gen] - - # Verify we got the expected responses - assert len(responses) == 2 - assert responses[0].choices[0].text == "Hello" - assert responses[0].choices[0].finish_reason is None - assert responses[1].choices[0].text == " world" - assert responses[1].choices[0].finish_reason == "stop" - - @pytest.mark.asyncio - async def test_process_completions_non_streaming(self): - """Test processing non-streaming completion responses.""" - postprocessor = ResponsePostprocessor() - model = "test_model" - - # Process the generator as a non-streaming completion response - response_gen = postprocessor.process_completions( - model, stream_generator(), stream=False - ) - - # Collect the single response - responses = [resp async for resp in response_gen] - assert len(responses) == 1 - - # Verify the content of the response - response = responses[0] - assert response.choices[0].text == "Hello world" - assert response.choices[0].finish_reason == "stop" - assert response.usage.prompt_tokens == 5 - assert response.usage.completion_tokens == 2 - assert response.usage.total_tokens == 7 - - @pytest.mark.asyncio - async def test_error_handling(self): - """Test error handling in response streams.""" - postprocessor = ResponsePostprocessor() - model = "test_model" - - # Create a generator that raises an exception - - error_response = ErrorResponse( - message="Test error", - code=500, - internal_message="Test error", - type="Test error", - original_exception=Exception("Test error"), - ) - - async def gen(): - yield LLMRawResponse( - error=error_response, - ) - yield LLMRawResponse( - generated_text="Hello", - num_generated_tokens=1, - num_generated_tokens_batch=1, - num_input_tokens=5, - finish_reason=None, - ) - - # Process the generator as a non-streaming chat response - response_gen = postprocessor.process_chat(model, gen(), stream=False) - - # Collect the responses, should contain the error - responses = [resp async for resp in response_gen] - assert len(responses) == 1 - assert responses[0] == error_response + handle = serve.run(app) + handle = handle.options(stream=True, multiplexed_model_id="test_model_id") + yield handle + serve.shutdown() class TestLLMServer: + @pytest.mark.parametrize("api_type", ["chat", "completion"]) + @pytest.mark.parametrize("stream", [False, True]) + @pytest.mark.parametrize("max_tokens", [5]) + @pytest.mark.parametrize("stream_batching_interval_ms", [0, 10000]) @pytest.mark.asyncio - async def test_get_batch_interval_ms(self, create_server): - """Test that the batch interval is set correctly in the config.""" - - # Test with a no stream_batching_interval_ms. - llm_config = LLMConfig( - model_loading_config=ModelLoadingConfig( - model_id="llm_model_id", - ), - ) - server = await create_server(llm_config, engine_cls=MockVLLMEngine) - - assert server._get_batch_interval_ms() == MODEL_RESPONSE_BATCH_TIMEOUT_MS - - # Test with a non-zero stream_batching_interval_ms. - llm_config = LLMConfig( - model_loading_config=ModelLoadingConfig( - model_id="llm_model_id", - ), - experimental_configs={ - "stream_batching_interval_ms": 13, - }, - ) - server = await create_server(llm_config, engine_cls=MockVLLMEngine) - assert server._get_batch_interval_ms() == 13 - - # Test with zero stream_batching_interval_ms. 
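# Illustrative sketch (editor's note, not part of this patch): with stream=True
# on the Serve handle, the new tests above treat each item yielded by
# `serve_handle.chat.remote(request)` as a batch whose size depends on the
# "stream_batching_interval_ms" experimental config, and flatten the batches
# before validating. A minimal, self-contained sketch of that flattening, using
# a hypothetical stand-in for the handle's async generator:
import asyncio
from typing import AsyncGenerator, List

async def _fake_batched_stream() -> AsyncGenerator[List[str], None]:
    # Stands in for the batched chunks returned by the deployment handle.
    yield ["test_0 ", "test_1 "]
    yield ["test_2 "]

async def _flatten() -> List[str]:
    chunks: List[str] = []
    async for batch in _fake_batched_stream():
        chunks.extend(batch)
    return chunks

assert asyncio.run(_flatten()) == ["test_0 ", "test_1 ", "test_2 "]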
- llm_config = LLMConfig( - model_loading_config=ModelLoadingConfig( - model_id="llm_model_id", - ), - experimental_configs={ - "stream_batching_interval_ms": 0, - }, - ) - server = await create_server(llm_config, engine_cls=MockVLLMEngine) - assert server._get_batch_interval_ms() == 0 - - @pytest.mark.asyncio - async def test_chat_streaming(self, create_server): - """Test chat completion in streaming mode.""" - llm_config = LLMConfig( - model_loading_config=ModelLoadingConfig( - model_id="test_model", - ), - experimental_configs={ - # Maximum batching - "stream_batching_interval_ms": 10000, - }, + async def test_unified_llm_server( + self, + serve_handle, + mock_llm_config, + mock_chat_request, + mock_completion_request, + api_type: str, + stream: bool, + max_tokens: int, + stream_batching_interval_ms: int, + ): + """Unified test for both chat and completion APIs, streaming and non-streaming.""" + + # Create request based on API type + if api_type == "chat": + request = mock_chat_request + batched_chunks = serve_handle.chat.remote(request) + elif api_type == "completion": + request = mock_completion_request + batched_chunks = serve_handle.completions.remote(request) + + print( + f"\n\n_____ {api_type.upper()} ({'STREAMING' if stream else 'NON-STREAMING'}) max_tokens={max_tokens} batching_interval_ms={stream_batching_interval_ms} _____\n\n" ) - server = await create_server(llm_config, engine_cls=MockVLLMEngine) - - # Create a chat completion request - request = ChatCompletionRequest( - model="test_model", - messages=[dict(role="user", content="Hello")], - stream=True, - max_tokens=5, - ) - - # Get the response stream - response_stream = await server.chat(request) - - # Collect responses from the stream - responses = [] - async for response in response_stream: - responses.append(response) - - # Each response should be an iterator over ChatCompletionStreamResponse - # Check that we got responses - assert len(responses) > 0 - - text = "" - role = None - for response in responses: - assert isinstance(response, list) - for chunk in response: - if chunk.choices[0].delta.role is not None and role is None: - role = chunk.choices[0].delta.role - - text += chunk.choices[0].delta.content - - assert role == "assistant" - # What mock vllm engine returns - assert text == "test_0 test_1 test_2 test_3 test_4 " + if stream: + # Collect responses from the stream + chunks = [] + async for batch in batched_chunks: + chunks.extend(batch) + + # Check that we got responses + assert len(chunks) > 0 + + # Validate streaming response + LLMResponseValidator.validate_streaming_chunks(chunks, api_type, max_tokens) + else: + # Collect non-streaming response + chunks = [] + async for batch in batched_chunks: + chunks.append(batch) + + # Check that we got one response + assert len(chunks) == 1 + + # Validate non-streaming response + LLMResponseValidator.validate_non_streaming_response( + chunks[0], api_type, max_tokens + ) + @pytest.mark.parametrize("dimensions", [None, 512]) @pytest.mark.asyncio - async def test_chat_non_streaming(self, create_server): - """Test non-streaming chat completion.""" - llm_config = LLMConfig( - model_loading_config=ModelLoadingConfig( - model_id="test_model", - ), - ) + async def test_embedding_llm_server( + self, + serve_handle, + mock_llm_config, + mock_embedding_request, + dimensions: Optional[int], + ): + """Test embedding API from LLMServer perspective.""" - server = await create_server(llm_config, engine_cls=MockVLLMEngine) + # Create embedding request + request = 
mock_embedding_request - # Create a chat completion request - request = ChatCompletionRequest( - model="test_model", - messages=[dict(role="user", content="Hello")], - stream=False, - max_tokens=5, - ) + print(f"\n\n_____ EMBEDDING SERVER dimensions={dimensions} _____\n\n") # Get the response - response_stream = await server.chat(request) + batched_chunks = serve_handle.embeddings.remote(request) # Collect responses (should be just one) - responses = [] - async for response in response_stream: - responses.append(response) + chunks = [] + async for batch in batched_chunks: + chunks.append(batch) # Check that we got one response - assert len(responses) == 1 - assert responses[0].choices[0].message.role == "assistant" - assert ( - responses[0].choices[0].message.content - == "test_0 test_1 test_2 test_3 test_4 " - ) - assert responses[0].choices[0].finish_reason == "stop" - - @pytest.mark.asyncio - async def test_completions_streaming(self, create_server): - """Test streaming text completion.""" - llm_config = LLMConfig( - model_loading_config=ModelLoadingConfig( - model_id="test_model", - ), - experimental_configs={ - # Maximum batching - "stream_batching_interval_ms": 10000, - }, - ) - - server = await create_server(llm_config, engine_cls=MockVLLMEngine) - - # Create a completion request - request = CompletionRequest( - model="test_model", - prompt="Hello", - stream=True, - max_tokens=5, - ) - - # Get the response stream - response_stream = await server.completions(request) - - # Collect responses from the stream - responses = [] - async for response in response_stream: - responses.append(response) + assert len(chunks) == 1 - # Check that we got responses - assert len(responses) > 0 - - text = "" - for response in responses: - assert isinstance(response, list) - for chunk in response: - text += chunk.choices[0].text - - assert text == "test_0 test_1 test_2 test_3 test_4 " + # Validate embedding response + LLMResponseValidator.validate_embedding_response(chunks[0], dimensions) @pytest.mark.asyncio - async def test_completions_non_streaming(self, create_server): - """Test non-streaming text completion.""" - llm_config = LLMConfig( - model_loading_config=ModelLoadingConfig( - model_id="test_model", - ), - ) - - server = await create_server(llm_config, engine_cls=MockVLLMEngine) - - # Create a completion request - request = CompletionRequest( - model="test_model", - prompt="Hello", - stream=False, - max_tokens=5, - ) - - # Get the response - response_stream = await server.completions(request) + async def test_check_health(self, create_server, mock_llm_config): + """Test health check functionality.""" - # Collect responses (should be just one) - responses = [] - async for response in response_stream: - responses.append(response) + # Mock the engine's check_health method + class LocalMockEngine(MockVLLMEngine): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.check_health_called = False - # Check that we got one response - assert len(responses) == 1 - assert responses[0].choices[0].text == "test_0 test_1 test_2 test_3 test_4 " - assert responses[0].choices[0].finish_reason == "stop" - - @pytest.mark.asyncio - async def test_check_health(self, create_server): - """Test health check functionality.""" - llm_config = LLMConfig( - model_loading_config=ModelLoadingConfig( - model_id="test_model", - ), - ) + async def check_health(self): + self.check_health_called = True # Create a server with a mocked engine - server = await create_server(llm_config, 
engine_cls=MockVLLMEngine) - - # Mock the engine's check_health method - server.engine.check_health = AsyncMock(return_value=None) + server = await create_server(mock_llm_config, engine_cls=LocalMockEngine) # Perform the health check, no exceptions should be raised await server.check_health() - server.engine.check_health.assert_called_once() - - @pytest.mark.asyncio - async def test_error_handling(self, create_server): - """Test error handling in the server.""" - llm_config = LLMConfig( - model_loading_config=ModelLoadingConfig( - model_id="test_model", - ), - ) - server = await create_server(llm_config, engine_cls=MockVLLMEngine) + # Check that the health check method was called + assert server.engine.check_health_called - # Mock the _predict method to raise an exception - server._predict = AsyncMock(side_effect=Exception("Test error")) + @pytest.mark.asyncio + async def test_llm_config_property(self, create_server, mock_llm_config): + """Test the llm_config property.""" + server = await create_server(mock_llm_config, engine_cls=MockVLLMEngine) + llm_config = await server.llm_config() + assert isinstance(llm_config, type(mock_llm_config)) + + @pytest.mark.parametrize("stream", [False]) + @pytest.mark.parametrize("max_tokens", [5]) + @pytest.mark.asyncio + async def test_request_id_handling( + self, + serve_handle, + mock_llm_config, + mock_chat_request, + stream: bool, + max_tokens: int, + ): + """Test that the request id is handled correctly.""" # Create a chat completion request - request = ChatCompletionRequest( - model="test_model", - messages=[dict(role="user", content="Hello")], - stream=False, + # We should patch get_server_request_id to return a test_request_id + serve.context._serve_request_context.set( + serve.context._RequestContext(**{"request_id": "test_request_id"}) ) - # Get the response - response_stream = await server.chat(request) + chunks = [] + async for chunk in serve_handle.chat.remote(mock_chat_request): + chunks.append(chunk) - # Collect responses (should contain an error) - responses = [] - async for response in response_stream: - responses.append(response) + assert len(chunks) == 1 + assert chunks[0].id == "test_request_id" - # Check that we got an error response - assert len(responses) > 0 - assert isinstance(responses[0], ErrorResponse) + @pytest.mark.parametrize("api_type", ["chat", "completion"]) + @pytest.mark.parametrize("stream", [False, True]) + @pytest.mark.parametrize("max_tokens", [5]) + @pytest.mark.parametrize("stream_batching_interval_ms", [0, 10000]) + @pytest.mark.asyncio + async def test_multiplexed_request_handling( + self, + multiplexed_serve_handle, + mock_chat_request, + mock_completion_request, + api_type: str, + stream: bool, + max_tokens: int, + stream_batching_interval_ms: int, + ): + """Unified test for multiplexed (LoRA) requests - both chat and completion APIs, streaming and non-streaming.""" + + # Create request based on API type and set model ID for multiplexing + if api_type == "chat": + request = mock_chat_request + batched_chunks = multiplexed_serve_handle.chat.remote(request) + elif api_type == "completion": + request = mock_completion_request + batched_chunks = multiplexed_serve_handle.completions.remote(request) + + request.model = "test_model_id" + print( + f"\n\n_____ MULTIPLEXED {api_type.upper()} ({'STREAMING' if stream else 'NON-STREAMING'}) max_tokens={max_tokens} batching_interval_ms={stream_batching_interval_ms} _____\n\n" + ) - # Internal server error - assert responses[0].code == 500 + if stream: + # Collect responses 
from the stream + chunks = [] + async for batch in batched_chunks: + if isinstance(batch, list): + chunks.extend(batch) + else: + chunks.append(batch) + + # Check that we got responses + assert len(chunks) > 0 + + # Validate streaming response with LoRA model ID + LLMResponseValidator.validate_streaming_chunks( + chunks, api_type, max_tokens, lora_model_id=request.model + ) + else: + # Collect non-streaming response + chunks = [] + async for batch in batched_chunks: + if isinstance(batch, list): + chunks.extend(batch) + else: + chunks.append(batch) + + # Check that we got one response + assert len(chunks) == 1 + + # Validate non-streaming response with LoRA model ID + LLMResponseValidator.validate_non_streaming_response( + chunks[0], api_type, max_tokens, lora_model_id=request.model + ) + + @pytest.mark.asyncio + async def test_push_telemetry(self, create_server, mock_llm_config): + """Test that the telemetry push is called properly.""" + with patch( + "ray.llm._internal.serve.deployments.llm.llm_server.push_telemetry_report_for_all_models" + ) as mock_push_telemetry: + await create_server(mock_llm_config, engine_cls=MockVLLMEngine) + mock_push_telemetry.assert_called_once() if __name__ == "__main__": diff --git a/python/ray/llm/tests/serve/cpu/deployments/llm/vllm/test_vllm_engine.py b/python/ray/llm/tests/serve/cpu/deployments/llm/vllm/test_vllm_engine.py deleted file mode 100644 index 57cd19f20283..000000000000 --- a/python/ray/llm/tests/serve/cpu/deployments/llm/vllm/test_vllm_engine.py +++ /dev/null @@ -1,199 +0,0 @@ -import asyncio -import json -import sys -from types import SimpleNamespace -from typing import List -from unittest.mock import Mock - -import pytest - -from ray.llm._internal.serve.configs.server_models import ( - FinishReason, - LLMConfig, -) -from ray.llm._internal.serve.deployments.llm.vllm.vllm_engine import ( - VLLMEngine, -) -from ray.llm._internal.serve.deployments.llm.vllm.vllm_models import ( - VLLMGenerationRequest, - VLLMSamplingParams, -) - - -class FakeVLLMEngine: - def __init__(self, mock: Mock, output=None): - self._engine_client = mock - - self._output = output or [] - self.num_generated = 0 - - async def generate(self, *args, **kwargs): - # Record the call - self._engine_client.generate(*args, **kwargs) - - for x in self._output: - await asyncio.sleep(0.01) - self.num_generated += 1 - yield x - - async def abort(self, request_id: str): - # Record the call - self._engine_client.abort(request_id) - - def _abort(self, request_id: str, **kwargs): - # Record the call - self._engine_client.abort(request_id) - - -def get_fake_responses(*tokens: List[str]): - total = "" - output = [] - - for token in tokens: - total += token - # For some reason vLLM appears to return the full text on each iteration - # We should fix this in vllm - output.append( - SimpleNamespace( - outputs=[ - SimpleNamespace( - text=total, - finish_reason="stop", # for some reason, vllm returns a finish reason on all tokens. We should fix this too. 
- token_ids=[0], - logprobs=[], - ) - ], - prompt_token_ids=[0], - metrics=SimpleNamespace(time_in_queue=0.01), - ) - ) - - return output - - -def get_fake_engine_and_request(llm_config: LLMConfig, expected_out: List[str]): - vllm_engine = VLLMEngine(llm_config) - # We normally set the model config when calling VLLMEngine.start() - vllm_engine.model_config = Mock() - vllm_engine.model_config.max_model_len = 1 - - engine_mock = Mock() - vllm_engine._engine_client = FakeVLLMEngine( - engine_mock, get_fake_responses(*expected_out) - ) - - req = VLLMGenerationRequest( - prompt="prompt", - request_id="req_id", - sampling_params=VLLMSamplingParams(), - disk_multiplex_config=None, - stream=True, - ) - return vllm_engine, req, engine_mock - - -class TestVLLMEngine: - """Test the VLLMEngine.""" - - @pytest.mark.asyncio - async def test_generate(self, llm_config): - expected_out = ["hi ", "i ", "am ", "vllm."] - vllm_engine, req, engine_mock = get_fake_engine_and_request( - llm_config, expected_out - ) - - cur_idx = 0 - async for x in vllm_engine.generate(req): - if cur_idx < len(expected_out): - assert x.generated_text == expected_out[cur_idx] - cur_idx += 1 - assert x.generation_time == pytest.approx( - 0.01, abs=0.01 - ), "We are sleeping for this long before returning tokens in the fake" - assert ( - x.num_input_tokens == 1 - ), "We are setting the num input tokens to len 1 in the fake output" - else: - assert x.finish_reason == FinishReason.STOP - - await asyncio.sleep(0.02) # wait for asyncio task scheduling - - # Abort should be called - engine_mock.abort.assert_called_once_with("req_id") - - @pytest.mark.asyncio - async def test_vllm_engine_error_in_caller(self, llm_config): - expected_out = ["hi ", "i ", "am ", "vllm."] - vllm_engine, req, engine_mock = get_fake_engine_and_request( - llm_config, expected_out - ) - - with pytest.raises(RuntimeError): - async for _x in vllm_engine.generate(req): - raise RuntimeError() - - await asyncio.sleep(0.02) # wait for asyncio task scheduling - # Abort should be called - engine_mock.abort.assert_called_once_with("req_id") - - @pytest.mark.asyncio - async def test_vllm_engine_caller_cancellation(self, llm_config): - expected_out = ["hi ", "i ", "am ", "vllm.", "and more"] * 10 # many tokens - vllm_engine, req, engine_mock = get_fake_engine_and_request( - llm_config, expected_out - ) - - async def run(): - async for x in vllm_engine.generate(req): - print(x) - - task = asyncio.create_task(run()) - await asyncio.sleep(0.02) # wait for some tokens to be returned - - # Cancel the task - task.cancel() - - await asyncio.sleep(0.02) # wait for asyncio task scheduling - # Abort should be called - engine_mock.abort.assert_called_once_with("req_id") - assert ( - vllm_engine._engine_client.num_generated <= 4 - ), "We should have generated not more than 4 tokens" - - @pytest.mark.parametrize("enable_json_mode", [True, False]) - def test_parse_sampling_params_json_mode( - self, llm_config: LLMConfig, enable_json_mode: bool - ): - # Make a deep copy to avoid modifying the session-scoped fixture - llm_config = llm_config.model_copy(deep=True) - vllm_engine = VLLMEngine(llm_config) - - # Mock model_config to avoid None errors - vllm_engine.model_config = Mock() - vllm_engine.model_config.max_model_len = 1000 - - # Create sampling params with response format - sampling_params = VLLMSamplingParams( - response_format={ - "type": "json_object", - "schema": { - "type": "object", - "properties": {"name": {"type": "string"}}, - }, - } - ) - - # Parse the sampling params - 
parsed_params = vllm_engine._parse_sampling_params(sampling_params) - - # For both cases we should now have guided decoding since we are using oss vllm. - # When json_mode is disabled, guided_decoding should be used instead - assert hasattr(parsed_params, "guided_decoding") - # Parse the JSON string from guided_decoding into a dict - guided_json = json.loads(parsed_params.guided_decoding.json) - assert guided_json == sampling_params.response_format.json_schema - assert getattr(parsed_params, "response_format", None) is None - - -if __name__ == "__main__": - sys.exit(pytest.main(["-v", __file__])) diff --git a/python/ray/llm/tests/serve/cpu/deployments/prefill_decode_disagg/test_prefill_decode_disagg.py b/python/ray/llm/tests/serve/cpu/deployments/prefill_decode_disagg/test_prefill_decode_disagg.py index c73e8d3cfa6f..c6cd17b3f66a 100644 --- a/python/ray/llm/tests/serve/cpu/deployments/prefill_decode_disagg/test_prefill_decode_disagg.py +++ b/python/ray/llm/tests/serve/cpu/deployments/prefill_decode_disagg/test_prefill_decode_disagg.py @@ -1,18 +1,12 @@ import sys -from unittest.mock import patch import pytest -from vllm.config import KVTransferConfig -from vllm.platforms.interface import UnspecifiedPlatform -from ray.llm._internal.serve.configs.prompt_formats import Prompt -from ray.llm._internal.serve.configs.server_models import LLMRawResponse +from ray.serve.llm import LLMConfig + from ray.llm._internal.serve.deployments.prefill_decode_disagg.prefill_decode_disagg import ( build_app, ) -from ray.llm.tests.serve.mocks.mock_vllm_engine import MockPDDisaggVLLMEngine -from ray.serve.llm import LLMConfig, ModelLoadingConfig -from ray.serve.llm.openai_api_models import ChatCompletionRequest class TestServingArgsParsing: @@ -55,127 +49,5 @@ def test_parse_dict(self): assert app is not None -class FakePlatform(UnspecifiedPlatform): - """ - vllm UnspecifiedPlatform has some interfaces that's left unimplemented, which - could trigger exception in following tests. So we implement needed interfaces - and patch. - """ - - def is_async_output_supported(self, enforce_eager: bool) -> bool: - return True - - -class TestPDDisaggLLMServer: - """Test PD-disaggregated LLM server. - - A real P/D disaggregation use case will spawn multiple LLM servers, - so this test suite just does smoke test and verifies certain expected - parameters exist in responses. - """ - - @pytest.mark.asyncio - @patch("vllm.platforms.current_platform", FakePlatform()) - async def test_chat_non_streaming( - self, - create_server, - # model_pixtral_12b is a fixture that only contains config files without weights - model_pixtral_12b, - ): - """This is smoke testing that normal chat completion works.""" - llm_config = LLMConfig( - # Here we - # 1. want to skip GPU placement in cpu test cases (https://github.com/ray-project/ray/blob/945b9d5dd55c9215d0aeb94a66cfda3b71c2fd43/python/ray/llm/_internal/serve/deployments/llm/vllm/vllm_engine.py#L330) - # 2. cannot set it to None, otherwise it defaults to use_gpu=True (https://github.com/ray-project/ray/blob/c7e07328c9efbd0d67bf2da4fa098d6492478ef4/python/ray/llm/_internal/serve/deployments/llm/vllm/vllm_models.py#L159) - # 3. cannot use "CPU" or anything random, which violates the check (https://github.com/ray-project/ray/blob/945b9d5dd55c9215d0aeb94a66cfda3b71c2fd43/python/ray/llm/_internal/serve/configs/server_models.py#L325) - # so we select a non-NVIDIA type here: Intel-GAUDI. 
- accelerator_type="Intel-GAUDI", - model_loading_config=ModelLoadingConfig( - model_id=model_pixtral_12b, - ), - engine_kwargs={ - "kv_transfer_config": KVTransferConfig( - kv_connector="NixlConnector", - kv_role="kv_both", - ), - }, - ) - - server = await create_server(llm_config, engine_cls=MockPDDisaggVLLMEngine) - - # Create a chat completion request - request = ChatCompletionRequest( - model="test_model", - messages=[dict(role="user", content="Hello")], - stream=False, - max_tokens=5, - ) - - # Get the response - response_stream = await server.chat(request) - - # Collect responses (should be just one) - responses = [r async for r in response_stream] - - # Check that we got one response - assert len(responses) == 1 - assert responses[0].choices[0].message.role == "assistant" - assert ( - responses[0].choices[0].message.content - == "mock_pd_client_response_0 mock_pd_client_response_1 mock_pd_client_response_2 mock_pd_client_response_3 mock_pd_client_response_4 " - ) - - @pytest.mark.asyncio - @patch("vllm.platforms.current_platform", FakePlatform()) - async def test_predict_non_streaming( - self, - create_server, - # model_pixtral_12b is a fixture that only contains config files without weights - model_pixtral_12b, - ): - """Test non-streaming predict.""" - llm_config = LLMConfig( - # Here we - # 1. want to skip GPU placement in cpu test cases (https://github.com/ray-project/ray/blob/945b9d5dd55c9215d0aeb94a66cfda3b71c2fd43/python/ray/llm/_internal/serve/deployments/llm/vllm/vllm_engine.py#L330) - # 2. cannot set it to None, otherwise it defaults to use_gpu=True (https://github.com/ray-project/ray/blob/c7e07328c9efbd0d67bf2da4fa098d6492478ef4/python/ray/llm/_internal/serve/deployments/llm/vllm/vllm_models.py#L159) - # 3. cannot use "CPU" or anything random, which violates the check (https://github.com/ray-project/ray/blob/945b9d5dd55c9215d0aeb94a66cfda3b71c2fd43/python/ray/llm/_internal/serve/configs/server_models.py#L325) - # so we select a non-NVIDIA type here: Intel-GAUDI. 
- accelerator_type="Intel-GAUDI", - model_loading_config=ModelLoadingConfig( - model_id=model_pixtral_12b, - ), - engine_kwargs={ - "kv_transfer_config": KVTransferConfig( - kv_connector="NixlConnector", - kv_role="kv_both", - ), - }, - ) - - server = await create_server(llm_config, engine_cls=MockPDDisaggVLLMEngine) - - # Create a predict request - request = Prompt( - prompt="test prompt", - parameters=dict( - max_tokens=1, - stream=False, - kv_transfer_params=dict(field_that_does_not_matter="1"), - ), - ) - - # Get the response - responses: list[LLMRawResponse] = [] - async for response in server._predict( - request_id="test_request_id", prompt=request, stream=False - ): - responses.append(response) - - # Collect responses (should be just one) - assert len(responses) == 1 - assert responses[0].generated_text == "mock_pd_client_response_0 " - assert responses[0].metadata is not None - - if __name__ == "__main__": sys.exit(pytest.main(["-v", __file__])) diff --git a/python/ray/llm/tests/serve/cpu/deployments/routers/test_router.py b/python/ray/llm/tests/serve/cpu/deployments/routers/test_router.py index 5ba14036df08..4204231fd069 100644 --- a/python/ray/llm/tests/serve/cpu/deployments/routers/test_router.py +++ b/python/ray/llm/tests/serve/cpu/deployments/routers/test_router.py @@ -86,7 +86,7 @@ async def test_chat(self, stream_batching_interval_ms, client, stream): role = response.choices[0].message.role assert role == "assistant" - assert text == "".join([f"test_{i} " for i in range(n_tokens)]) + assert text.strip() == " ".join([f"test_{i}" for i in range(n_tokens)]) @pytest.mark.asyncio @pytest.mark.parametrize("stream_batching_interval_ms", [None, 0, 10000]) @@ -112,8 +112,8 @@ async def test_completion(self, stream_batching_interval_ms, client, stream): text = response.choices[0].text # The mock engine produces "test_0 test_1 test_2 ..." 
pattern - expected_text = "".join([f"test_{i} " for i in range(n_tokens)]) - assert text == expected_text + expected_text = " ".join([f"test_{i}" for i in range(n_tokens)]) + assert text.strip() == expected_text def test_router_with_num_router_replicas_config(self): """Test the router with num_router_replicas config.""" diff --git a/python/ray/llm/tests/serve/gpu/deployments/llm/vllm/test_vllm_engine_gpu.py b/python/ray/llm/tests/serve/gpu/deployments/llm/vllm/test_vllm_engine_gpu.py deleted file mode 100644 index 0607bd59951d..000000000000 --- a/python/ray/llm/tests/serve/gpu/deployments/llm/vllm/test_vllm_engine_gpu.py +++ /dev/null @@ -1,73 +0,0 @@ -import sys - -import pytest - -from ray.llm._internal.serve.configs.server_models import ( - LLMConfig, -) -from ray.llm._internal.serve.deployments.llm.vllm.vllm_engine import ( - VLLMEngine, - _get_vllm_engine_config, -) - - -class TestVLLMEngine: - """Test the VLLMEngine.""" - - @pytest.mark.asyncio - @pytest.mark.parametrize( - "engine_kwargs, expected_prompt_limit", - [ - ({"enable_chunked_prefill": True}, 1024000), - ( - { - "enable_chunked_prefill": True, - "max_model_len": 999, - }, - 999, - ), - ( - { - "enable_chunked_prefill": True, - "max_num_batched_tokens": 888, - }, - 1024000, - ), - ( - { - "enable_chunked_prefill": True, - "max_model_len": 999, - "max_num_batched_tokens": 888, - "enforce_eager": True, - }, - 999, - ), - ({"enable_chunked_prefill": False}, 1024000), - ( - { - "enable_chunked_prefill": False, - "max_model_len": 999, - }, - 999, - ), - ], - ) - async def test_get_prompt_limit( - # llm_config is a fixture defined in serve.tests.conftest.py - self, - llm_config: LLMConfig, - engine_kwargs: dict, - expected_prompt_limit: int, - ): - llm_config = llm_config.model_copy(deep=True) - vllm_engine = VLLMEngine(llm_config) - - # Test with default engine kwargs - llm_config.engine_kwargs = engine_kwargs - _, vllm_config = _get_vllm_engine_config(llm_config) - vllm_engine.vllm_config = vllm_config - assert vllm_engine._get_prompt_limit() == expected_prompt_limit - - -if __name__ == "__main__": - sys.exit(pytest.main(["-v", __file__])) diff --git a/python/ray/llm/tests/serve/gpu/integration/test_openai_compatibility.py b/python/ray/llm/tests/serve/gpu/integration/test_openai_compatibility.py index a5405cbded72..e1a4f02b8c22 100644 --- a/python/ray/llm/tests/serve/gpu/integration/test_openai_compatibility.py +++ b/python/ray/llm/tests/serve/gpu/integration/test_openai_compatibility.py @@ -3,11 +3,6 @@ import openai import pytest -from ray.llm._internal.serve.configs.constants import ( - MAX_NUM_TOPLOGPROBS_ALLOWED, - MIN_NUM_TOPLOGPROBS_ALLOWED, -) - class TestOpenAICompatibility: """Test that the rayllm are compatible with the OpenAI API""" @@ -17,7 +12,7 @@ def test_models(self, testing_model): # noqa: F811 models = client.models.list() assert len(models.data) == 1, "Only the test model should be returned" assert models.data[0].id == model, "The test model id should match" - assert models.data[0].rayllm_metadata["input_modality"] == "text" + assert models.data[0].metadata["input_modality"] == "text" def test_completions(self, testing_model): # noqa: F811 client, model = testing_model @@ -28,7 +23,7 @@ def test_completions(self, testing_model): # noqa: F811 ) assert completion.model == model assert completion.model - assert completion.choices[0].text == "test_0 test_1 " + assert completion.choices[0].text == "test_0 test_1" def test_chat(self, testing_model): # noqa: F811 client, model = testing_model @@ -43,97 +38,6 @@ def 
test_chat(self, testing_model): # noqa: F811 assert isinstance(chat_completion.choices, list) assert chat_completion.choices[0].message.content - def test_chat_logprobs(self, testing_model): - client, model = testing_model - num_tokens = 5 - # test logprobs for non-streaming chat completions - for top_logprobs in range(5): - chat_completion = client.chat.completions.create( - model=model, - max_tokens=num_tokens, - messages=[{"role": "user", "content": "Hello world"}], - logprobs=True, - top_logprobs=top_logprobs, - ) - logprobs = chat_completion.choices[0].logprobs.content - assert logprobs, "Logprobs should be not be None or Empty" - assert len(logprobs) == num_tokens - assert all( - len(logprob.top_logprobs) == top_logprobs for logprob in logprobs - ) - text_from_logprobs = [] - for logprob in logprobs: - text_from_logprobs.append(logprob.token) - if logprob.top_logprobs: - assert logprob.token == logprob.top_logprobs[0].token - text_from_logprobs = "".join(text_from_logprobs) - assert ( - text_from_logprobs == chat_completion.choices[0].message.content - ), "Text from logprobs should match text from completion" - - for num_top_logprobs in range(5): - chat_completion = client.chat.completions.create( - model=model, - max_tokens=num_tokens, - messages=[{"role": "user", "content": "Hello world"}], - logprobs=True, - top_logprobs=num_top_logprobs, - stream=True, - ) - - for c in chat_completion: - choice_logprobs = c.choices[0].logprobs - if choice_logprobs and choice_logprobs.content: - for chat_completion_token_logprob in choice_logprobs.content: - top_logprobs_res = chat_completion_token_logprob.top_logprobs - assert len(top_logprobs_res) == num_top_logprobs - if top_logprobs_res: - assert ( - top_logprobs_res[0].token - == chat_completion_token_logprob.token - ) - - # try to send logprobs request with invalid number of toplogprobs - with pytest.raises(openai.BadRequestError): - for top_logprobs in [ - MAX_NUM_TOPLOGPROBS_ALLOWED + 1, - MIN_NUM_TOPLOGPROBS_ALLOWED - 1, - ]: - client.chat.completions.create( - model=model, - max_tokens=num_tokens, - messages=[{"role": "user", "content": "Hello world"}], - logprobs=True, - top_logprobs=top_logprobs, - ) - - def test_completions_bad_request(self, testing_model): # noqa: F811 - client, model = testing_model - with pytest.raises(openai.BadRequestError) as exc_info: - client.completions.create( - model=model, - prompt="Hello world", - temperature=-0.1, - ) - assert "temperature" in str(exc_info.value) - - def test_chat_bad_request(self, testing_model): # noqa: F811 - client, model = testing_model - with pytest.raises(openai.BadRequestError) as exc_info: - client.chat.completions.create( - model=model, - messages=[{"role": "user", "content": "Hello world"}], - temperature=-0.1, - ) - assert "temperature" in str(exc_info.value) - - with pytest.raises(openai.BadRequestError) as exc_info: - client.chat.completions.create( - model=model, - messages=[], - ) - assert "least 1 item" in str(exc_info.value) - def test_completions_missing_model(self, testing_model): # noqa: F811 client, _ = testing_model with pytest.raises(openai.NotFoundError) as exc_info: @@ -174,8 +78,12 @@ def test_chat_stream(self, testing_model): # noqa: F811 model=model, messages=[{"role": "user", "content": "Hello world"}], stream=True, + stream_options=dict( + include_usage=True, + ), temperature=0.4, frequency_penalty=0.02, + max_tokens=5, ): if i == 0: assert chat_completion @@ -190,45 +98,6 @@ def test_chat_stream(self, testing_model): # noqa: F811 
chat_completion.choices[0].delta, "content" ) i += 1 - assert chat_completion - assert chat_completion.id - assert isinstance(chat_completion.choices, list) - assert not chat_completion.choices[0].delta.content - assert chat_completion.choices[0].finish_reason - assert i > 4 - - def test_completions_stream_bad_request(self, testing_model): # noqa: F811 - client, model = testing_model - with pytest.raises(openai.BadRequestError) as exc_info: - for _ in client.completions.create( - model=model, - prompt="Hello world", - stream=True, - temperature=-0.1, - ): - pass - assert "temperature" in str(exc_info.value) - - def test_chat_stream_bad_request(self, testing_model): # noqa: F811 - client, model = testing_model - with pytest.raises(openai.BadRequestError) as exc_info: - for _chat_completion in client.chat.completions.create( - model=model, - messages=[{"role": "user", "content": "Hello world"}], - stream=True, - temperature=-0.1, - ): - pass - assert "temperature" in str(exc_info.value) - - with pytest.raises(openai.BadRequestError) as exc_info: - for _chat_completion in client.chat.completions.create( - model=model, - messages=[], - stream=True, - ): - pass - assert "least 1 item" in str(exc_info.value) def test_completions_stream_missing_model(self, testing_model): # noqa: F811 client, _ = testing_model diff --git a/python/ray/llm/tests/serve/gpu/integration/test_openai_compatibility_no_accelerator_type.py b/python/ray/llm/tests/serve/gpu/integration/test_openai_compatibility_no_accelerator_type.py index 549f655da85b..1142700b34ed 100644 --- a/python/ray/llm/tests/serve/gpu/integration/test_openai_compatibility_no_accelerator_type.py +++ b/python/ray/llm/tests/serve/gpu/integration/test_openai_compatibility_no_accelerator_type.py @@ -27,7 +27,7 @@ def test_completions_no_accelerator_type( ) assert completion.model == model assert completion.model - assert completion.choices[0].text == "test_0 test_1 " + assert completion.choices[0].text == "test_0 test_1" def test_chat_no_accelerator_type(self, testing_model_no_accelerator): # noqa: F811 """Check chat completions without accelerator_type""" diff --git a/python/ray/llm/tests/serve/mocks/mock_vllm_engine.py b/python/ray/llm/tests/serve/mocks/mock_vllm_engine.py index 356e1d3b3313..7f33b93d0ce6 100644 --- a/python/ray/llm/tests/serve/mocks/mock_vllm_engine.py +++ b/python/ray/llm/tests/serve/mocks/mock_vllm_engine.py @@ -2,628 +2,278 @@ import json import random from random import randint -from typing import AsyncGenerator, Dict, Optional - -from PIL import Image -from transformers import AutoTokenizer -from vllm import CompletionOutput, PromptType, RequestOutput -from vllm.config import DeviceConfig, KVTransferConfig, ModelConfig, VllmConfig -from vllm.engine.protocol import EngineClient -from vllm.sampling_params import SamplingParams as VLLMInternalSamplingParams - -from ray.llm._internal.serve.configs.error_handling import ValidationError -from ray.llm._internal.serve.configs.openai_api_models_patch import ( - ResponseFormatJsonObject, +from typing import AsyncGenerator, Dict, Optional, Any, List, Union + +from ray.llm._internal.serve.configs.openai_api_models import ( + ChatCompletionRequest, + ChatCompletionResponse, + CompletionRequest, + CompletionResponse, + EmbeddingRequest, + EmbeddingResponse, + ErrorResponse, ) from ray.llm._internal.serve.configs.server_models import ( DiskMultiplexConfig, - FinishReason, LLMConfig, - LLMRawResponse, - LogProb, - LogProbs, - Prompt, ) from ray.llm._internal.serve.deployments.llm.llm_engine 
import LLMEngine -from ray.llm._internal.serve.deployments.llm.vllm.vllm_engine import VLLMEngine -from ray.llm._internal.serve.deployments.llm.vllm.vllm_engine_stats import ( - VLLMEngineStats, - VLLMEngineStatTracker, -) -from ray.llm._internal.serve.deployments.llm.vllm.vllm_models import ( - KV_TRANSFER_PARAMS_KEY, - VLLMGenerationRequest, - VLLMSamplingParams, +from ray.llm._internal.serve.deployments.llm.multiplex.lora_model_loader import ( + LoraModelLoader, ) class MockVLLMEngine(LLMEngine): + """Mock vLLM Engine that generates fake text responses. + + - In case of LoRA it generates a prefix with the model name in the text part of the response. + """ + def __init__(self, llm_config: LLMConfig): - """Create a vLLM Engine class + """Create a mock vLLM Engine. Args: llm_config: The llm configuration for this engine """ - assert isinstance( - llm_config, LLMConfig - ), f"Got invalid config {llm_config} of type {type(llm_config)}" self.llm_config = llm_config - - self._stats = VLLMEngineStatTracker() - - async def start(self): - """No-Op""" - return - - @staticmethod - async def async_range(count): - for i in range(count): - yield i - await asyncio.sleep(0.0) - - async def prepare_request( - self, request_id: str, prompt: Prompt, stream: bool, **kwargs - ) -> VLLMGenerationRequest: - - if isinstance(prompt.prompt, list): - # Simplification: Assume prompt is a list of messages with one user message - assert len(prompt.prompt) == 1 - assert hasattr(prompt.prompt[0], "content") - prompt_text = prompt.prompt[0].content - else: - prompt_text = prompt.prompt - - return VLLMGenerationRequest( - request_id=request_id, - prompt=prompt_text, - stream=stream, - sampling_params=VLLMSamplingParams.from_prompt(prompt), - ) - - async def generate(self, vllm_engine_request: VLLMGenerationRequest): - sampling_params = self._parse_sampling_params( - vllm_engine_request.sampling_params - ) - max_tokens = sampling_params.max_tokens - if not max_tokens: - max_tokens = randint(1, 10) - prompt = vllm_engine_request.prompt - prompt_len = ( - len(prompt.split()) if isinstance(prompt, str) else len(prompt.prompt) - ) - generation_time = 0.001 - - async for i in self.async_range(max_tokens): - if i == max_tokens - 1: - finish_reason = FinishReason.STOP - else: - finish_reason = None - llm_response = LLMRawResponse( - generated_text=f"test_{i} ", - num_input_tokens=prompt_len, - num_input_tokens_batch=prompt_len, - num_generated_tokens=1, - preprocessing_time=0, - generation_time=generation_time, - finish_reason=finish_reason, - logprobs=self.get_logprobs(i, vllm_engine_request, sampling_params), - ) - yield llm_response - await asyncio.sleep(generation_time) - - async def check_health(self) -> None: - return - - def stats(self) -> VLLMEngineStats: - return self._stats.to_stats() - - def shutdown(self, shutdown_pg: bool = True): - raise NotImplementedError() - - def _parse_sampling_params( - self, sampling_params: VLLMSamplingParams - ) -> VLLMInternalSamplingParams: - try: - if sampling_params.n != 1: - raise ValueError("n>1 is not supported yet in rayllm") - if sampling_params.logprobs: - if sampling_params.top_logprobs: - if not (0 <= sampling_params.top_logprobs <= 5): - raise ValueError("top_logprobs must be between 0 and 5") - log_probs = sampling_params.top_logprobs - else: - log_probs = 1 - else: - if sampling_params.top_logprobs: - raise ValueError( - "if top_logprobs is specified, logprobs must be set to `True`" - ) - log_probs = None - - return VLLMInternalSamplingParams( - n=1, - 
best_of=sampling_params.best_of, - presence_penalty=sampling_params.presence_penalty - if sampling_params.presence_penalty is not None - else 0.0, - frequency_penalty=sampling_params.frequency_penalty - if sampling_params.frequency_penalty is not None - else 0.0, - repetition_penalty=sampling_params.repetition_penalty - if sampling_params.repetition_penalty is not None - else 1.0, - temperature=sampling_params.temperature - if sampling_params.temperature is not None - else 1.0, - top_p=sampling_params.top_p - if sampling_params.top_p is not None - else 1.0, - top_k=sampling_params.top_k - if sampling_params.top_k is not None - else -1, - stop=sampling_params.stop, - stop_token_ids=sampling_params.stop_tokens, - ignore_eos=False, - # vLLM will cancel internally if input+output>max_tokens - max_tokens=sampling_params.max_tokens - or self.llm_config.max_request_context_length, - logprobs=log_probs, - ) - except Exception as e: - # Wrap the error in ValidationError so the status code - # returned to the user is correct. - raise ValidationError(str(e)) from e - - def get_logprobs( - self, - i: int, - vllm_engine_request: VLLMGenerationRequest, - sampling_params: VLLMSamplingParams, - ): - """Helper function for generating LLMRawResponse logprobs""" - num_logprobs = sampling_params.logprobs - top_logprobs = vllm_engine_request.sampling_params.top_logprobs - if num_logprobs: - log_probs = [ - LogProbs.create( - logprobs=[ - LogProb( - logprob=0.0, - token=( - f"test_{i} " if idx == 0 else f"candidate_token_{idx}" - ), - bytes=[], - ) - for idx in range(num_logprobs) - ], - top_logprobs=top_logprobs, - ) - ] - else: - log_probs = None - - return log_probs - - -class MockEchoVLLMEngine(MockVLLMEngine): - """ - Mock engine that responds with information about the request sent to it. Useful - for testing the contents of VLLMGenerationRequests created in RayLLM code up to - the vLLM boundary. - """ - - def _convert_to_json(self, vllm_engine_request: VLLMGenerationRequest) -> Dict: - """Converts request to json. - - If the request contains an image, this method removes the image - from `vllm_engine_request` and sets `has_image: true` in the - output dictionary. - This is because `Image.Image` is not json serializable. 
- """ - mm_data = vllm_engine_request.multi_modal_data - if isinstance(mm_data, dict) and "image" in mm_data: - assert isinstance(mm_data["image"], Image.Image) or ( - isinstance(mm_data["image"], list) - and all( - [ - isinstance(image, Image.Image) - for image in vllm_engine_request.multi_modal_data["image"] - ] - ) - ), "Image must be of type Image.Image or a list of Image.Image" - mm_data["image"] = None - has_image = True - else: - has_image = False - res = vllm_engine_request.model_dump() - res.update({"has_image": has_image}) - return json.dumps(res) - - async def generate(self, vllm_engine_request: VLLMGenerationRequest): - yield LLMRawResponse( - generated_text=self._convert_to_json(vllm_engine_request), - num_input_tokens=0, - num_input_tokens_batch=0, - num_generated_tokens=1, - preprocessing_time=0, - generation_time=0.01, - finish_reason=FinishReason.STOP, - logprobs=None, - ) - - -class MockMultiplexEngine(LLMEngine): - def __init__(self, *args, **kwargs): self.started = False - - async def prepare_request( - self, - request_id: str, - prompt: Prompt, - stream: bool, - disk_lora_model: Optional[DiskMultiplexConfig] = None, - ) -> VLLMGenerationRequest: - - if isinstance(prompt.prompt, list): - # Simplification: Assume prompt is a list of messages with one user message - assert len(prompt.prompt) == 1 - assert hasattr(prompt.prompt[0], "content") - prompt_text = prompt.prompt[0].content - else: - prompt_text = prompt.prompt - - output = VLLMGenerationRequest( - request_id=request_id, - prompt=prompt_text, - stream=stream, - sampling_params=VLLMSamplingParams.from_prompt(prompt), - disk_multiplex_config=disk_lora_model, - ) - return output + self._current_lora_model: Dict[str, DiskMultiplexConfig] = {} async def start(self): + """Start the mock engine.""" self.started = True - async def generate(self, arg): - assert self.started, "Engine was not started" - yield arg - - async def check_health(self): - return True + async def resolve_lora(self, lora_model: DiskMultiplexConfig): + """Resolve/load a LoRA model.""" + self._current_lora_model[lora_model.model_id] = lora_model + async def check_health(self) -> None: + """Check the health of the mock engine.""" + if not self.started: + raise RuntimeError("Engine not started") + + async def chat( + self, request: ChatCompletionRequest + ) -> AsyncGenerator[Union[str, ChatCompletionResponse, ErrorResponse], None]: + """Mock chat completion.""" + if not self.started: + raise RuntimeError("Engine not started") + + # Extract prompt text from messages + prompt_text = "" + if request.messages: + for message in request.messages: + if hasattr(message, "content") and message.content: + prompt_text += str(message.content) + " " + + max_tokens = getattr(request, "max_tokens", None) or randint(1, 10) + + # Generate streaming response + async for response in self._generate_chat_response( + request=request, prompt_text=prompt_text.strip(), max_tokens=max_tokens + ): + yield response -class FakeLoraModelLoader: - async def load_model( - self, lora_model_id: str, llm_config: LLMConfig - ) -> DiskMultiplexConfig: - return DiskMultiplexConfig.model_validate( - { - "model_id": lora_model_id, - "max_total_tokens": llm_config.max_request_context_length, - "local_path": "/local/path", - "lora_assigned_int_id": 1, - } - ) + async def completions( + self, request: CompletionRequest + ) -> AsyncGenerator[Union[str, CompletionResponse, ErrorResponse], None]: + """Mock text completion.""" + if not self.started: + raise RuntimeError("Engine not started") + 
prompt_text = str(request.prompt) if request.prompt else "" + max_tokens = getattr(request, "max_tokens", None) or randint(5, 20) -class MockJSONModeVLLMEngine(MockVLLMEngine): - async def generate_text(self, max_tokens, prompt_len): - generation_time = 0.001 - async for i in self.async_range(max_tokens): - if i == max_tokens - 1: - finish_reason = FinishReason.STOP - else: - finish_reason = None - llm_response = LLMRawResponse( - generated_text=f"test_{i} ", - num_input_tokens=prompt_len, - num_input_tokens_batch=prompt_len, - num_generated_tokens=1, - preprocessing_time=0, - generation_time=generation_time, - finish_reason=finish_reason, - ) - yield llm_response - await asyncio.sleep(generation_time) - - async def generate_json(self, json_schema, max_tokens, prompt_len): - random_valid_json = str(generate_from_schema(json_schema)) - # the json has double quotes where single quotes should be and single quotes where double quotes should be: - random_valid_json = random_valid_json.replace("'", '"') - - tokens = split_string_into_chunks(random_valid_json, max_tokens) - - generation_time = 0.001 - async for i in self.async_range(max_tokens): - finish_reason = None - if i == max_tokens - 1: - finish_reason = FinishReason.STOP - - generated_text = tokens[i] - llm_response = LLMRawResponse( - generated_text=generated_text, - num_input_tokens=prompt_len, - num_input_tokens_batch=prompt_len, - num_generated_tokens=1, - preprocessing_time=0, - generation_time=generation_time, - finish_reason=finish_reason, + # Generate streaming response + async for response in self._generate_completion_response( + request=request, prompt_text=prompt_text, max_tokens=max_tokens + ): + yield response + + async def embeddings( + self, request: EmbeddingRequest + ) -> AsyncGenerator[Union[str, EmbeddingResponse, ErrorResponse], None]: + """Mock embeddings generation.""" + if not self.started: + raise RuntimeError("Engine not started") + + # Generate a mock embedding response + embedding_data = [] + inputs = request.input if isinstance(request.input, list) else [request.input] + + for i, text in enumerate(inputs): + # Generate random embedding vector + dimensions = getattr(request, "dimensions", None) or 1536 + embedding = [random.uniform(-1, 1) for _ in range(dimensions)] + + embedding_data.append( + {"object": "embedding", "embedding": embedding, "index": i} ) - yield llm_response - await asyncio.sleep(generation_time) - async def generate(self, vllm_engine_request: VLLMGenerationRequest): - sampling_params = self._parse_sampling_params( - vllm_engine_request.sampling_params + response = EmbeddingResponse( + object="list", + data=embedding_data, + model=getattr(request, "model", "mock-model"), + usage={ + "prompt_tokens": len(str(request.input).split()), + "total_tokens": len(str(request.input).split()), + }, ) - max_tokens = sampling_params.max_tokens - if not max_tokens: - max_tokens = randint(1, 10) - prompt = vllm_engine_request.prompt - prompt_len = get_prompt_length(prompt) - response_format = sampling_params.response_format - if response_format and isinstance(response_format, ResponseFormatJsonObject): - response_format = sampling_params.response_format - generator = self.generate_json( - response_format.json_schema, - max_tokens=max_tokens, - prompt_len=prompt_len, - ) + yield response + + async def _generate_chat_response( + self, request: ChatCompletionRequest, prompt_text: str, max_tokens: int + ) -> AsyncGenerator[Union[str, ChatCompletionResponse], None]: + """Generate mock chat completion response.""" + 
+ request_id = request.request_id or f"chatcmpl-{random.randint(1000, 9999)}" + lora_prefix = ( + "" + if request.model not in self._current_lora_model + else f"[lora_model] {request.model}: " + ) + if request.stream: + # Streaming response - return SSE formatted strings + created_time = int(asyncio.get_event_loop().time()) + model_name = getattr(request, "model", "mock-model") + + for i in range(max_tokens): + if i == 0: + token = f"{lora_prefix}test_{i} " + else: + token = f"test_{i} " + if i == max_tokens - 1: + # no space for the last token + token = f"test_{i}" + + # Create streaming chunk + choice = { + "index": 0, + "delta": { + "content": token, + "role": "assistant" if i == 0 else None, + }, + "finish_reason": "stop" if i == max_tokens - 1 else None, + } + + chunk_data = { + "id": request_id, + "object": "chat.completion.chunk", + "created": created_time, + "model": model_name, + "choices": [choice], + } + + # Format as SSE + yield f"data: {json.dumps(chunk_data)}\n\n" + await asyncio.sleep(0.01) # Simulate processing time + + # Send final [DONE] message + yield "data: [DONE]\n\n" else: - generator = self.generate_text(max_tokens=max_tokens, prompt_len=prompt_len) - async for x in generator: - yield x - - def _parse_sampling_params( - self, sampling_params: VLLMSamplingParams - ) -> VLLMInternalSamplingParams: - new_sampling_params = super()._parse_sampling_params(sampling_params) - new_sampling_params.response_format = sampling_params.response_format - return new_sampling_params - - -class MockPDDisaggVLLMEngineClient(EngineClient): - """ - Mock vllm EngineClient that supports PD Disaggregation. - """ - - def __init__(self, vllm_config: VllmConfig): - self._llm_config = vllm_config - self._model_config = vllm_config.model_config - - @property - def kv_transfer_config(self): - # https://github.com/vllm-project/vllm/blob/980a172474fa0f32433dda87ae1fa4aadba24c51/vllm/config.py#L4061 - kv_transfer_config = self._llm_config.kv_transfer_config - if kv_transfer_config is not None: - assert isinstance(kv_transfer_config, KVTransferConfig) - return kv_transfer_config - - @staticmethod - async def async_range(count): - for i in range(count): - yield i - await asyncio.sleep(0.0) - - def is_running(self) -> bool: - return True - - @property - def is_stopped(self) -> bool: - return False - - @property - def errored(self) -> bool: - return False - - @property - def dead_error(self) -> BaseException: - return None - - def generate( - self, - prompt: PromptType, - sampling_params: VLLMInternalSamplingParams, - request_id: str, - **kwargs, - ) -> AsyncGenerator[RequestOutput, None]: - """Generate outputs for a request.""" - max_tokens = sampling_params.max_tokens or randint(1, 10) - - # vLLM uses `extra_args` to pass in `kv_transfer_params`: - # https://github.com/vllm-project/vllm/blob/980a172474fa0f32433dda87ae1fa4aadba24c51/vllm/v1/request.py#L65 - kv_transfer_params = None - if ( - self.kv_transfer_config is not None - and KV_TRANSFER_PARAMS_KEY in sampling_params.extra_args - ): - # For now we don't test the items in request/response, so just pass empty dict. - kv_transfer_params = {} # noqa: F841 - - async def generate_response(): - # vLLM EngineClient spits accumulated output in the response. - # ray serve's engine spits output in chunk. 
- accumulated_output = "" - async for i in self.async_range(max_tokens): - accumulated_output += f"mock_pd_client_response_{i} " - yield RequestOutput( - finished=(i == max_tokens - 1), - request_id=request_id, - prompt=prompt, - prompt_token_ids=[i], - prompt_logprobs=[0.0], - outputs=[ - CompletionOutput( - index=i, - text=accumulated_output, - token_ids=[i], - cumulative_logprob=None, - logprobs=None, - ) - ], - kv_transfer_params=kv_transfer_params, - ) - - return generate_response() - - def encode( - self, - prompt: PromptType, - request_id: str, - **kwargs, - ) -> AsyncGenerator: - """Generate outputs for a request from a pooling model.""" - raise NotImplementedError("Not expected to be reached") - - async def abort(self, request_id: str) -> None: - """Abort a request. - - Args: - request_id: The unique id of the request. - """ - return - - async def get_vllm_config(self): - """Get the vllm configuration of the vLLM engine.""" - return self._llm_config - - async def get_model_config(self): - """Get the model configuration of the vLLM engine.""" - return self._model_config - - async def get_decoding_config(self): - """Get the decoding configuration of the vLLM engine.""" - raise NotImplementedError("Not expected to be reached") - - async def get_input_preprocessor(self): - """Get the input processor of the vLLM engine.""" - raise NotImplementedError("Not expected to be reached") - - async def get_tokenizer( - self, - lora_request=None, - ) -> any: - """Get the appropriate tokenizer for the request""" - return AutoTokenizer.from_pretrained(self._model_config.model) - - async def is_tracing_enabled(self) -> bool: - """Check if tracing is enabled""" - raise NotImplementedError("Not expected to be reached") - - async def do_log_stats( - self, - scheduler_outputs=None, - model_output=None, - ) -> None: - raise NotImplementedError("Not expected to be reached") + # Non-streaming response - return response object + generated_text = " ".join([f"test_{i}" for i in range(max_tokens)]) + generated_text = f"{lora_prefix}{generated_text}" + + choice = { + "index": 0, + "message": {"role": "assistant", "content": generated_text}, + "finish_reason": "stop", + } - async def check_health(self) -> None: - """Raise if unhealthy""" - return - - async def start_profile(self) -> None: - """Start profiling the engine""" - raise NotImplementedError("Not expected to be reached") - - async def stop_profile(self) -> None: - """Start profiling the engine""" - raise NotImplementedError("Not expected to be reached") - - async def reset_prefix_cache(self, device=None) -> None: - """Reset the prefix cache""" - raise NotImplementedError("Not expected to be reached") - - async def sleep(self, level: int = 1) -> None: - """Sleep the engine""" - raise NotImplementedError("Not expected to be reached") - - async def wake_up(self, tags: Optional[list[str]] = None) -> None: - """Wake up the engine""" - raise NotImplementedError("Not expected to be reached") - - async def is_sleeping(self) -> bool: - """Check whether the engine is sleeping""" - raise NotImplementedError("Not expected to be reached") - - async def add_lora(self, lora_request) -> None: - """Load a new LoRA adapter into the engine for future requests.""" - raise NotImplementedError("Not expected to be reached") - - async def reset_mm_cache(self) -> None: - """Reset the multi-modal cache""" - raise NotImplementedError("Not expected to be reached") - - -class MockPDDisaggVLLMEngine(VLLMEngine): - async def _start_engine(self) -> EngineClient: - return 
MockPDDisaggVLLMEngineClient( - VllmConfig( - model_config=ModelConfig( - model=self.llm_config.model_loading_config.model_id, - task="auto", - tokenizer=self.llm_config.model_loading_config.model_id, - tokenizer_mode="auto", - trust_remote_code=False, - dtype="auto", - seed=0, - ), - device_config=DeviceConfig( - device="cpu", - ), + response = ChatCompletionResponse( + id=request_id, + object="chat.completion", + created=int(asyncio.get_event_loop().time()), + model=getattr(request, "model", "mock-model"), + choices=[choice], + usage={ + "prompt_tokens": len(prompt_text.split()), + "completion_tokens": max_tokens, + "total_tokens": len(prompt_text.split()) + max_tokens, + }, ) - ) - - -def generate_from_schema(schema): - if "type" not in schema: - raise ValueError("Schema must have a 'type' property") - - # Check for enum and return a random value from it - if "enum" in schema: - return schema["enum"][0] - - if schema["type"] == "object": - obj = {} - for prop, prop_schema in schema.get("properties", {}).items(): - obj[prop] = generate_from_schema(prop_schema) - return obj - - elif schema["type"] == "array": - item_schema = schema.get("items", {}) - return [generate_from_schema(item_schema) for _ in range(random.randint(1, 3))] - - elif schema["type"] == "string": - return "sample_string" - - elif schema["type"] == "integer": - return random.randint(0, 100) - - elif schema["type"] == "number": - return random.uniform(0, 100) - - elif schema["type"] == "boolean": - return random.choice([True, False]) - else: - raise ValueError(f"Unsupported type: {schema['type']}") + yield response + async def _generate_completion_response( + self, request: CompletionRequest, prompt_text: str, max_tokens: int + ) -> AsyncGenerator[Union[str, CompletionResponse], None]: + """Generate mock completion response.""" -def split_string_into_chunks(s, n): - if n <= 0: - raise ValueError("Number of chunks must be greater than 0") - - chunk_size = len(s) // n - remainder = len(s) % n + request_id = request.request_id or f"cmpl-{random.randint(1000, 9999)}" + lora_prefix = ( + "" + if request.model not in self._current_lora_model + else f"[lora_model] {request.model}: " + ) + if request.stream: + # Streaming response - return SSE formatted strings + created_time = int(asyncio.get_event_loop().time()) + model_name = getattr(request, "model", "mock-model") + + for i in range(max_tokens): + if i == 0: + token = f"{lora_prefix}test_{i} " + else: + token = f"test_{i} " + if i == max_tokens - 1: + # no space for the last token + token = f"test_{i}" + + choice = { + "index": 0, + "text": token, + "finish_reason": "stop" if i == max_tokens - 1 else None, + } + + chunk_data = { + "id": request_id, + "object": "text_completion", + "created": created_time, + "model": model_name, + "choices": [choice], + } + + # Format as SSE + yield f"data: {json.dumps(chunk_data)}\n\n" + await asyncio.sleep(0.01) + + # Send final [DONE] message + yield "data: [DONE]\n\n" + else: + # Non-streaming response - return response object + generated_text = " ".join([f"test_{i}" for i in range(max_tokens)]) + generated_text = f"{lora_prefix}{generated_text}" + + choice = {"index": 0, "text": generated_text, "finish_reason": "stop"} + + response = CompletionResponse( + id=request_id, + object="text_completion", + created=int(asyncio.get_event_loop().time()), + model=getattr(request, "model", "mock-model"), + choices=[choice], + usage={ + "prompt_tokens": len(prompt_text.split()), + "completion_tokens": max_tokens, + "total_tokens": 
len(prompt_text.split()) + max_tokens, + }, + ) - chunks = [] - start = 0 - for i in range(n): - end = start + chunk_size + (1 if i < remainder else 0) - chunks.append(s[start:end]) - start = end + yield response - return chunks +class FakeLoraModelLoader(LoraModelLoader): + """Fake LoRA model loader for testing.""" -def get_prompt_length(prompt): - return len(prompt.split()) if isinstance(prompt, str) else len(prompt) + async def load_model( + self, lora_model_id: str, llm_config: LLMConfig + ) -> DiskMultiplexConfig: + """Load a fake LoRA model.""" + return DiskMultiplexConfig( + model_id=lora_model_id, + max_total_tokens=llm_config.max_request_context_length, + local_path="/fake/local/path", + lora_assigned_int_id=random.randint(1, 100), + ) diff --git a/python/ray/llm/tests/serve/utils/__init__.py b/python/ray/llm/tests/serve/utils/__init__.py new file mode 100644 index 000000000000..e356527468b2 --- /dev/null +++ b/python/ray/llm/tests/serve/utils/__init__.py @@ -0,0 +1 @@ +# Testing utilities for Ray LLM serve tests diff --git a/python/ray/llm/tests/serve/utils/testing_utils.py b/python/ray/llm/tests/serve/utils/testing_utils.py new file mode 100644 index 000000000000..1cdab168418b --- /dev/null +++ b/python/ray/llm/tests/serve/utils/testing_utils.py @@ -0,0 +1,96 @@ +"""Shared testing utilities for Ray LLM serve tests. + +This is written with assumptions around how mocks for testing are expected to behave. +""" + +import json +import re +from typing import Union, List, Optional + +from ray.llm._internal.serve.configs.openai_api_models import ( + ChatCompletionResponse, + CompletionResponse, + EmbeddingResponse, +) + + +class LLMResponseValidator: + """Reusable validation logic for LLM responses.""" + + @staticmethod + def get_expected_content( + api_type: str, max_tokens: int, lora_model_id: str = "" + ) -> str: + """Get expected content based on API type.""" + expected_content = " ".join(f"test_{i}" for i in range(max_tokens)) + if lora_model_id: + expected_content = f"[lora_model] {lora_model_id}: {expected_content}" + return expected_content + + @staticmethod + def validate_non_streaming_response( + response: Union[ChatCompletionResponse, CompletionResponse], + api_type: str, + max_tokens: int, + lora_model_id: str = "", + ): + """Validate non-streaming responses.""" + expected_content = LLMResponseValidator.get_expected_content( + api_type, max_tokens, lora_model_id + ) + + if api_type == "chat": + assert isinstance(response, ChatCompletionResponse) + assert response.choices[0].message.content == expected_content + elif api_type == "completion": + assert isinstance(response, CompletionResponse) + assert response.choices[0].text == expected_content + + @staticmethod + def validate_streaming_chunks( + chunks: List[str], api_type: str, max_tokens: int, lora_model_id: str = "" + ): + """Validate streaming response chunks.""" + # Should have max_tokens + 1 chunks (tokens + [DONE]) + assert len(chunks) == max_tokens + 1 + + # Validate each chunk except the last [DONE] chunk + for chunk_iter, chunk in enumerate(chunks[:-1]): + pattern = r"data: (.*)\n\n" + match = re.match(pattern, chunk) + assert match is not None + chunk_data = json.loads(match.group(1)) + + expected_chunk = f"test_{chunk_iter}" + if lora_model_id and chunk_iter == 0: + expected_chunk = f"[lora_model] {lora_model_id}: {expected_chunk}" + + if api_type == "chat": + delta = chunk_data["choices"][0]["delta"] + if chunk_iter == 0: + assert delta["role"] == "assistant" + else: + assert delta["role"] is None + assert 
delta["content"].strip() == expected_chunk + elif api_type == "completion": + text = chunk_data["choices"][0]["text"] + assert text.strip() == expected_chunk + + @staticmethod + def validate_embedding_response( + response: EmbeddingResponse, expected_dimensions: Optional[int] = None + ): + """Validate embedding responses.""" + assert isinstance(response, EmbeddingResponse) + assert response.object == "list" + assert len(response.data) == 1 + assert response.data[0].object == "embedding" + assert isinstance(response.data[0].embedding, list) + assert ( + len(response.data[0].embedding) > 0 + ) # Should have some embedding dimensions + assert response.data[0].index == 0 + + # Check dimensions if specified + if expected_dimensions: + assert len(response.data[0].embedding) == expected_dimensions diff --git a/python/ray/serve/llm/openai_api_models.py b/python/ray/serve/llm/openai_api_models.py index 210984cc1bd0..496cf794ac4b 100644 --- a/python/ray/serve/llm/openai_api_models.py +++ b/python/ray/serve/llm/openai_api_models.py @@ -72,14 +72,7 @@ class CompletionResponse(_CompletionResponse): pass -@PublicAPI(stability="alpha") -class EmbeddingRequest(_EmbeddingRequest): - """EmbeddingRequest is the request body for the embedding API. - - This model is compatible with vLLM's OpenAI API models. - """ - - pass +EmbeddingRequest = _EmbeddingRequest @PublicAPI(stability="alpha") diff --git a/release/llm_tests/serve/probes/models.py b/release/llm_tests/serve/probes/models.py index 5c067515df26..f0714c209ad9 100644 --- a/release/llm_tests/serve/probes/models.py +++ b/release/llm_tests/serve/probes/models.py @@ -97,15 +97,11 @@ def is_release_test_model(model: "openai.types.model.Model") -> bool: def is_finetuned_model(model: "openai.types.model.Model") -> bool: # If base_model_id is set, this is a finetuned model - return ( - model.model_dump().get("rayllm_metadata", {}).get("base_model_id") is not None - ) + return model.model_dump().get("metadata", {}).get("base_model_id") is not None def is_vision_language_model(model: "openai.types.model.Model") -> bool: - return ( - model.model_dump().get("rayllm_metadata", {}).get("input_modality") == "image" - ) + return model.model_dump().get("metadata", {}).get("input_modality") == "image" def is_rate_liming_test_model(model: "openai.types.model.Model") -> bool: @@ -134,7 +130,7 @@ def is_completions_only_model(model: "openai.types.model.Model") -> bool: def supports_function_calling_via_prompt(model: "openai.types.model.Model") -> bool: # True if tool template is specified in the generation config - gen_config = model.model_dump().get("rayllm_metadata", {}).get("generation", False) + gen_config = model.model_dump().get("metadata", {}).get("generation", False) if not gen_config: return False diff --git a/release/llm_tests/serve/probes/query_utils.py b/release/llm_tests/serve/probes/query_utils.py index e76d2338e3fc..1026e303f19f 100644 --- a/release/llm_tests/serve/probes/query_utils.py +++ b/release/llm_tests/serve/probes/query_utils.py @@ -42,7 +42,12 @@ def _apply_delta(base, delta): # in order to merge them, not recursively merge them. 
if key == "logprobs": if delta[key]: - base[key]["content"].extend(delta[key]["content"]) + cur_val = (base[key] or {}).get("content", []) or [] + cur_val.extend(delta[key]["content"]) + if base[key]: + base[key]["content"] = cur_val + else: + base[key] = {"content": cur_val} continue if isinstance(base[key], dict): @@ -97,6 +102,8 @@ def messages(self): """In case of streamed response, what are the individual chunked messages? that contain the content we care about?""" vals = [] for r in self.response: + if len(r.choices) == 0: + continue v = r.choices[0].model_dump() if "message" in v and "content" in v["message"]: vals.append(v["message"]["content"] or "") @@ -128,7 +135,11 @@ def num_completion_tokens(self): def finish_reason(self): # This should be set on the last response. - return self.response[-1].choices[0].finish_reason + for chunk in self.response: + if len(chunk.choices) > 0: + if chunk.choices[0].finish_reason: + return chunk.choices[0].finish_reason + return None class BaseProbe: @@ -171,6 +182,11 @@ async def query( "stream": stream, **chat_args, } + + if stream: + args["stream_options"] = { + "include_usage": True, + } if chat: method = self.client.chat.completions.create else: diff --git a/release/llm_tests/serve/probes/test_basic.py b/release/llm_tests/serve/probes/test_basic.py index ca42f934095e..3b35289bb5da 100755 --- a/release/llm_tests/serve/probes/test_basic.py +++ b/release/llm_tests/serve/probes/test_basic.py @@ -160,7 +160,7 @@ async def test_too_long_completion_request( ) # XXX: AE-686 hack, should read model data instead - length = 20000 + length = 200000 if "8x22" in model: length = 70000 diff --git a/release/llm_tests/serve/probes/test_json_mode.py b/release/llm_tests/serve/probes/test_json_mode.py index a971be59c49a..1dc2eb51af0e 100644 --- a/release/llm_tests/serve/probes/test_json_mode.py +++ b/release/llm_tests/serve/probes/test_json_mode.py @@ -101,8 +101,11 @@ def get_params_and_expected_type(response_type: str, test_id: str): params.update( { "response_format": { - "type": "json_object", - "schema": expected_type.schema_json(), + "type": "json_schema", + "json_schema": { + "name": "expected_schema", + "schema": expected_type.model_json_schema(), + }, } } ) @@ -118,7 +121,7 @@ def get_response_formats(): {"type": "json_object", "schema": json.dumps({})}, {"type": "json_object", "schema": json.loads(BasicResponse.schema_json())}, {"type": "json_object", "schema": BasicResponse.schema_json()}, - {"type": "grammar", "grammar": JSON_GRAMMAR_EBNF_STR}, + # {"type": "grammar", "grammar": JSON_GRAMMAR_EBNF_STR}, ] @@ -201,8 +204,11 @@ async def test_response_format_options( async def test_invalid_schema(model: str, openai_async_client): querier = TextGenerationProbeQuerier(openai_async_client, {"temperature": 0.0}) response_format = { - "type": "json_object", - "schema": {"type": "object", "properties": {"name": {"type": "str"}}}, + "type": "json_schema", + "json_schema": { + "name": "expected_schema", + "schema": {"type": "object", "properties": {"name": {"type": "str"}}}, + }, } params = { diff --git a/release/llm_tests/serve/probes/test_models.py b/release/llm_tests/serve/probes/test_models.py index 84d1207da673..f2ecc4a076a6 100644 --- a/release/llm_tests/serve/probes/test_models.py +++ b/release/llm_tests/serve/probes/test_models.py @@ -8,4 +8,4 @@ def test_get_model(model: str): model_description = openai_client.models.retrieve(model) assert model_description.id == model - assert "rayllm_metadata" in model_description.model_dump() + assert "metadata" in 
model_description.model_dump()
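
Taken together, the new mock engine and `LLMResponseValidator` are meant to be paired in tests: the engine emits deterministic `test_{i}` tokens (with an optional `[lora_model]` prefix when a LoRA model has been resolved), and the validator checks the streamed SSE chunks or the final response object. Below is a minimal sketch of that pairing. The `MockVLLMEngine` import path and constructor arguments are assumptions not stated in this diff; the other imports follow the module paths the patch itself creates or touches.

```python
# Hypothetical usage sketch, not part of the patch. Assumes MockVLLMEngine lives
# in ray.llm.tests.serve.mocks.mock_vllm_engine and accepts llm_config=None.
import asyncio

from ray.llm._internal.serve.configs.openai_api_models import ChatCompletionRequest
from ray.llm.tests.serve.utils.testing_utils import LLMResponseValidator
from ray.llm.tests.serve.mocks.mock_vllm_engine import MockVLLMEngine  # assumed path


async def probe_streaming_chat(max_tokens: int = 3) -> None:
    # Assumed constructor; the diff only shows that the engine tracks LoRA
    # state internally and must be started before serving requests.
    engine = MockVLLMEngine(llm_config=None)
    await engine.start()

    request = ChatCompletionRequest(
        model="mock-model",
        messages=[{"role": "user", "content": "hello"}],
        max_tokens=max_tokens,
        stream=True,
    )

    # When streaming, the mock engine yields SSE-formatted strings and
    # terminates with "data: [DONE]\n\n".
    chunks = [chunk async for chunk in engine.chat(request)]

    # The validator expects max_tokens content chunks plus the trailing [DONE].
    LLMResponseValidator.validate_streaming_chunks(
        chunks=chunks, api_type="chat", max_tokens=max_tokens
    )


if __name__ == "__main__":
    asyncio.run(probe_streaming_chat())
```

The same pattern should carry over to the completions path (`api_type="completion"`) and, for non-streaming requests, to `LLMResponseValidator.validate_non_streaming_response` and `validate_embedding_response`.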