diff --git a/src/any_llm/providers/cohere/cohere.py b/src/any_llm/providers/cohere/cohere.py
index cc924112..95c815b7 100644
--- a/src/any_llm/providers/cohere/cohere.py
+++ b/src/any_llm/providers/cohere/cohere.py
@@ -10,6 +10,7 @@
 MISSING_PACKAGES_ERROR = None
 try:
     import cohere
+    from cohere import V2ChatResponse
 
     from .utils import (
         _convert_models_list,
@@ -37,7 +38,7 @@ class CohereProvider(AnyLLM):
     SUPPORTS_COMPLETION_STREAMING = True
     SUPPORTS_COMPLETION = True
     SUPPORTS_RESPONSES = False
-    SUPPORTS_COMPLETION_REASONING = False
+    SUPPORTS_COMPLETION_REASONING = True
     SUPPORTS_COMPLETION_IMAGE = False
     SUPPORTS_COMPLETION_PDF = False
     SUPPORTS_EMBEDDING = False
@@ -60,10 +61,10 @@ def _convert_completion_params(params: CompletionParams, **kwargs: Any) -> dict[
         return converted_params
 
     @staticmethod
-    def _convert_completion_response(response: Any, **kwargs: Any) -> ChatCompletion:
+    def _convert_completion_response(response: V2ChatResponse, **kwargs: Any) -> ChatCompletion:
         """Convert Cohere response to OpenAI format."""
         # We need the model parameter for conversion
-        model = kwargs.get("model", getattr(response, "model", "cohere-model"))
+        model = kwargs.get("model", "unknown")
         return _convert_response(response, model)
 
     @staticmethod
diff --git a/src/any_llm/providers/cohere/utils.py b/src/any_llm/providers/cohere/utils.py
index 1df8330f..ace7c245 100644
--- a/src/any_llm/providers/cohere/utils.py
+++ b/src/any_llm/providers/cohere/utils.py
@@ -1,6 +1,7 @@
 from collections.abc import Sequence
 from typing import Any
 
+from cohere import V2ChatResponse
 from cohere.types import ListModelsResponse as CohereListModelsResponse
 
 from any_llm.types.completion import (
@@ -11,6 +12,7 @@
     Choice,
     CompletionUsage,
     Function,
+    Reasoning,
 )
 from any_llm.types.model import Model
@@ -61,9 +63,12 @@ def _create_openai_chunk_from_cohere_chunk(chunk: Any) -> ChatCompletionChunk:
         and chunk.delta.message
         and hasattr(chunk.delta.message, "content")
         and chunk.delta.message.content
-        and hasattr(chunk.delta.message.content, "text")
     ):
-        delta["content"] = chunk.delta.message.content.text
+        content_obj = chunk.delta.message.content
+        if hasattr(content_obj, "text") and content_obj.text:
+            delta["content"] = content_obj.text
+        elif hasattr(content_obj, "thinking") and content_obj.thinking:
+            delta["reasoning"] = {"content": content_obj.thinking}
 
     elif chunk_type == "tool-call-start":
         if (
@@ -144,7 +149,7 @@
     return ChatCompletionChunk.model_validate(chunk_dict)
 
 
-def _convert_response(response: Any, model: str) -> ChatCompletion:
+def _convert_response(response: V2ChatResponse, model: str) -> ChatCompletion:
     """Convert Cohere response to OpenAI ChatCompletion format directly."""
     prompt_tokens = 0
     completion_tokens = 0
@@ -166,11 +171,13 @@
                 content=response.message.tool_plan,
                 tool_calls=[
                     ChatCompletionMessageFunctionToolCall(
-                        id=tool_call.id,
+                        id=tool_call.id or "",
                         type="function",
                         function=Function(
-                            name=tool_call.function.name if tool_call.function else "",
-                            arguments=tool_call.function.arguments if tool_call.function else "",
+                            name=tool_call.function.name if tool_call.function and tool_call.function.name else "",
+                            arguments=tool_call.function.arguments
+                            if tool_call.function and tool_call.function.arguments
+                            else "",
                         ),
                     )
                 ],
@@ -185,10 +192,17 @@
             usage=usage,
         )
     content = ""
-    if response.message.content and len(response.message.content) > 0:
-        content = response.message.content[0].text
+    reasoning_content = None
 
-    message = ChatCompletionMessage(role="assistant", content=content, tool_calls=None)
+    if response.message.content and len(response.message.content) > 0:
+        for item in response.message.content:
+            if hasattr(item, "type"):
+                if item.type == "text" and hasattr(item, "text"):
+                    content += item.text
+                elif item.type == "thinking" and hasattr(item, "thinking"):
+                    reasoning_content = Reasoning(content=item.thinking)
+
+    message = ChatCompletionMessage(role="assistant", content=content, tool_calls=None, reasoning=reasoning_content)
     choice = Choice(
         index=0,
         finish_reason="stop",
diff --git a/tests/conftest.py b/tests/conftest.py
index efbc4291..595aaaa0 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -24,6 +24,7 @@ def provider_reasoning_model_map() -> dict[LLMProvider, str]:
         LLMProvider.LLAMACPP: "N/A",
         LLMProvider.LMSTUDIO: "openai/gpt-oss-20b",  # You must have LM Studio running and the server enabled
         LLMProvider.AZUREOPENAI: "azure/",
+        LLMProvider.COHERE: "command-a-reasoning-08-2025",
     }
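
For reviewers, a minimal sketch of exercising the new reasoning path end to end. It assumes any_llm's top-level completion() helper and its "provider/model" routing string (both live outside this diff); the model ID is the one added to conftest.py above, and the reasoning attribute is the Reasoning(content=...) object that _convert_response now populates from Cohere "thinking" content items.

# Sketch only: `completion` and the "cohere/<model>" routing are assumptions
# about the surrounding any_llm API, not part of this diff.
from any_llm import completion

response = completion(
    model="cohere/command-a-reasoning-08-2025",  # reasoning model from conftest.py
    messages=[{"role": "user", "content": "What is 17 * 24? Think it through."}],
)

message = response.choices[0].message
print(message.content)  # answer text, concatenated from "text" content items
if message.reasoning is not None:
    # Populated from "thinking" content items by _convert_response
    print(message.reasoning.content)

The same field shows up when streaming: the "thinking" branch added to _create_openai_chunk_from_cohere_chunk emits it as delta["reasoning"] on each chunk instead of delta["content"].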