Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions src/any_llm/providers/cohere/cohere.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
MISSING_PACKAGES_ERROR = None
try:
import cohere
from cohere import V2ChatResponse

from .utils import (
_convert_models_list,
Expand Down Expand Up @@ -37,7 +38,7 @@ class CohereProvider(AnyLLM):
SUPPORTS_COMPLETION_STREAMING = True
SUPPORTS_COMPLETION = True
SUPPORTS_RESPONSES = False
SUPPORTS_COMPLETION_REASONING = False
SUPPORTS_COMPLETION_REASONING = True
SUPPORTS_COMPLETION_IMAGE = False
SUPPORTS_COMPLETION_PDF = False
SUPPORTS_EMBEDDING = False
Expand All @@ -60,10 +61,10 @@ def _convert_completion_params(params: CompletionParams, **kwargs: Any) -> dict[
return converted_params

@staticmethod
def _convert_completion_response(response: Any, **kwargs: Any) -> ChatCompletion:
def _convert_completion_response(response: V2ChatResponse, **kwargs: Any) -> ChatCompletion:
"""Convert Cohere response to OpenAI format."""
# We need the model parameter for conversion
model = kwargs.get("model", getattr(response, "model", "cohere-model"))
model = kwargs.get("model", "unknown")
return _convert_response(response, model)

@staticmethod
Expand Down
32 changes: 23 additions & 9 deletions src/any_llm/providers/cohere/utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from collections.abc import Sequence
from typing import Any

from cohere import V2ChatResponse
from cohere.types import ListModelsResponse as CohereListModelsResponse

from any_llm.types.completion import (
Expand All @@ -11,6 +12,7 @@
Choice,
CompletionUsage,
Function,
Reasoning,
)
from any_llm.types.model import Model

Expand Down Expand Up @@ -61,9 +63,12 @@ def _create_openai_chunk_from_cohere_chunk(chunk: Any) -> ChatCompletionChunk:
and chunk.delta.message
and hasattr(chunk.delta.message, "content")
and chunk.delta.message.content
and hasattr(chunk.delta.message.content, "text")
):
delta["content"] = chunk.delta.message.content.text
content_obj = chunk.delta.message.content
if hasattr(content_obj, "text") and content_obj.text:
delta["content"] = content_obj.text
elif hasattr(content_obj, "thinking") and content_obj.thinking:
delta["reasoning"] = {"content": content_obj.thinking}

elif chunk_type == "tool-call-start":
if (
Expand Down Expand Up @@ -144,7 +149,7 @@ def _create_openai_chunk_from_cohere_chunk(chunk: Any) -> ChatCompletionChunk:
return ChatCompletionChunk.model_validate(chunk_dict)


def _convert_response(response: Any, model: str) -> ChatCompletion:
def _convert_response(response: V2ChatResponse, model: str) -> ChatCompletion:
"""Convert Cohere response to OpenAI ChatCompletion format directly."""
prompt_tokens = 0
completion_tokens = 0
Expand All @@ -166,11 +171,13 @@ def _convert_response(response: Any, model: str) -> ChatCompletion:
content=response.message.tool_plan,
tool_calls=[
ChatCompletionMessageFunctionToolCall(
id=tool_call.id,
id=tool_call.id or "",
type="function",
function=Function(
name=tool_call.function.name if tool_call.function else "",
arguments=tool_call.function.arguments if tool_call.function else "",
name=tool_call.function.name if tool_call.function and tool_call.function.name else "",
arguments=tool_call.function.arguments
if tool_call.function and tool_call.function.arguments
else "",
),
)
],
Expand All @@ -185,10 +192,17 @@ def _convert_response(response: Any, model: str) -> ChatCompletion:
usage=usage,
)
content = ""
if response.message.content and len(response.message.content) > 0:
content = response.message.content[0].text
reasoning_content = None

message = ChatCompletionMessage(role="assistant", content=content, tool_calls=None)
if response.message.content and len(response.message.content) > 0:
for item in response.message.content:
if hasattr(item, "type"):
if item.type == "text" and hasattr(item, "text"):
content += item.text
elif item.type == "thinking" and hasattr(item, "thinking"):
reasoning_content = Reasoning(content=item.thinking)

message = ChatCompletionMessage(role="assistant", content=content, tool_calls=None, reasoning=reasoning_content)
choice = Choice(
index=0,
finish_reason="stop",
Expand Down
1 change: 1 addition & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ def provider_reasoning_model_map() -> dict[LLMProvider, str]:
LLMProvider.LLAMACPP: "N/A",
LLMProvider.LMSTUDIO: "openai/gpt-oss-20b", # You must have LM Studio running and the server enabled
LLMProvider.AZUREOPENAI: "azure/<your_deployment_name>",
LLMProvider.COHERE: "command-a-reasoning-08-2025",
}


Expand Down