Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 3 additions & 26 deletions src/any_llm/providers/sagemaker/sagemaker.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,12 @@
from collections.abc import AsyncIterator, Callable, Iterator, Sequence
from typing import Any

from pydantic import BaseModel

from any_llm.any_llm import AnyLLM
from any_llm.config import ClientConfig
from any_llm.exceptions import MissingApiKeyError
from any_llm.exceptions import MissingApiKeyError, UnsupportedParameterError
from any_llm.logging import logger
from any_llm.types.completion import ChatCompletion, ChatCompletionChunk, CompletionParams, CreateEmbeddingResponse
from any_llm.types.model import Model
from any_llm.utils.instructor import _convert_instructor_response

MISSING_PACKAGES_ERROR = None
try:
Expand Down Expand Up @@ -137,28 +134,8 @@ def _completion(
completion_kwargs = self._convert_completion_params(params, **kwargs)

if params.response_format:
if params.stream:
msg = "stream is not supported for response_format"
raise ValueError(msg)

if not isinstance(params.response_format, type) or not issubclass(params.response_format, BaseModel):
msg = "response_format must be a pydantic model"
raise ValueError(msg)

response = self.client.invoke_endpoint(
EndpointName=params.model_id,
Body=json.dumps(completion_kwargs),
ContentType="application/json",
)

response_body = json.loads(response["Body"].read())

try:
structured_response = params.response_format.model_validate(response_body)
return _convert_instructor_response(structured_response, params.model_id, "aws")
except (ValueError, TypeError) as e:
logger.warning("Failed to parse structured response: %s", e)
return self._convert_completion_response({"model": params.model_id, **response_body})
param = "response_format"
raise UnsupportedParameterError(param, "sagemaker")

if params.stream:
response = self.client.invoke_endpoint_with_response_stream(
Expand Down