50 changes: 34 additions & 16 deletions libs/partners/groq/langchain_groq/chat_models.py
@@ -37,6 +37,10 @@
ToolMessage,
ToolMessageChunk,
)
from langchain_core.messages.ai import (
InputTokenDetails,
UsageMetadata,
)
from langchain_core.output_parsers import JsonOutputParser, PydanticOutputParser
from langchain_core.output_parsers.base import OutputParserLike
from langchain_core.output_parsers.openai_tools import (
@@ -704,15 +708,7 @@ def _create_chat_result(
for res in response["choices"]:
message = _convert_dict_to_message(res["message"])
if token_usage and isinstance(message, AIMessage):
input_tokens = token_usage.get("prompt_tokens", 0)
output_tokens = token_usage.get("completion_tokens", 0)
message.usage_metadata = {
"input_tokens": input_tokens,
"output_tokens": output_tokens,
"total_tokens": token_usage.get(
"total_tokens", input_tokens + output_tokens
),
}
message.usage_metadata = _create_usage_metadata(token_usage)
generation_info = {"finish_reason": res.get("finish_reason")}
if "logprobs" in res:
generation_info["logprobs"] = res["logprobs"]
@@ -1303,13 +1299,7 @@ def _convert_chunk_to_message_chunk(
{k: executed_tool[k] for k in executed_tool if k != "output"}
)
if usage := (chunk.get("x_groq") or {}).get("usage"):
input_tokens = usage.get("prompt_tokens", 0)
output_tokens = usage.get("completion_tokens", 0)
usage_metadata = {
"input_tokens": input_tokens,
"output_tokens": output_tokens,
"total_tokens": usage.get("total_tokens", input_tokens + output_tokens),
}
usage_metadata = _create_usage_metadata(usage)
else:
usage_metadata = None
return AIMessageChunk(
@@ -1409,3 +1399,31 @@ def _lc_invalid_tool_call_to_groq_tool_call(
"arguments": invalid_tool_call["args"],
},
}


def _create_usage_metadata(groq_token_usage: dict) -> UsageMetadata:
"""Create usage metadata from Groq token usage response.

Args:
groq_token_usage: Token usage dict from Groq API response.

Returns:
Usage metadata dict with input/output token details.
"""
input_tokens = groq_token_usage.get("prompt_tokens") or 0
output_tokens = groq_token_usage.get("completion_tokens") or 0
total_tokens = groq_token_usage.get("total_tokens") or input_tokens + output_tokens
input_token_details: dict = {
"cache_read": (groq_token_usage.get("prompt_tokens_details") or {}).get(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I can't find anywhere in Groq's docs where prompt_token_details are returned?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it should be input_tokens_details

Copy link
Contributor Author

@MshariAlaeena MshariAlaeena Nov 6, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for checking, It’s actually documented here:
https://console.groq.com/docs/prompt-caching#tracking-cache-usage

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@MshariAlaeena I think that might be wrong? See Slack

"cached_tokens"
),
}
usage_metadata: UsageMetadata = {
"input_tokens": input_tokens,
"output_tokens": output_tokens,
"total_tokens": total_tokens,
}

if filtered_input := {k: v for k, v in input_token_details.items() if v}:
usage_metadata["input_token_details"] = InputTokenDetails(**filtered_input) # type: ignore[typeddict-item]
return usage_metadata
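
Reviewer note: a minimal, hypothetical usage sketch (not part of this diff) showing how the new metadata surfaces to callers. The model name and prompt are placeholders, it assumes GROQ_API_KEY is set and that the response includes a usage block, and the keys mirror the tests below.

from langchain_groq import ChatGroq

# Placeholder model/prompt, not part of this change.
llm = ChatGroq(model="llama-3.3-70b-versatile")
msg = llm.invoke("Summarize the cached system prompt.")

# usage_metadata is built by _create_usage_metadata from the API's "usage" block.
usage = msg.usage_metadata
print(usage["input_tokens"], usage["output_tokens"], usage["total_tokens"])

# When Groq reports prompt caching (prompt_tokens_details.cached_tokens, per the
# docs link in the review thread), the cached portion appears as
# input_token_details["cache_read"]; otherwise the key is omitted entirely.
print(usage.get("input_token_details", {}).get("cache_read", 0))
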
276 changes: 275 additions & 1 deletion libs/partners/groq/tests/unit_tests/test_chat_models.py
@@ -9,14 +9,20 @@
import pytest
from langchain_core.messages import (
AIMessage,
AIMessageChunk,
FunctionMessage,
HumanMessage,
InvalidToolCall,
SystemMessage,
ToolCall,
)

from langchain_groq.chat_models import ChatGroq, _convert_dict_to_message
from langchain_groq.chat_models import (
ChatGroq,
_convert_chunk_to_message_chunk,
_convert_dict_to_message,
_create_usage_metadata,
)

if "GROQ_API_KEY" not in os.environ:
os.environ["GROQ_API_KEY"] = "fake-key"
@@ -283,3 +289,271 @@ def test_groq_serialization() -> None:

# Ensure a None was preserved
assert llm.groq_api_base == llm2.groq_api_base


def test_create_usage_metadata_basic() -> None:
"""Test basic usage metadata creation without details."""
token_usage = {
"prompt_tokens": 100,
"completion_tokens": 50,
"total_tokens": 150,
}

result = _create_usage_metadata(token_usage)

assert isinstance(result, dict)
assert result["input_tokens"] == 100
assert result["output_tokens"] == 50
assert result["total_tokens"] == 150
assert "input_token_details" not in result
assert "output_token_details" not in result


def test_create_usage_metadata_with_cached_tokens() -> None:
"""Test usage metadata with prompt caching."""
token_usage = {
"prompt_tokens": 2006,
"completion_tokens": 300,
"total_tokens": 2306,
"prompt_tokens_details": {"cached_tokens": 1920},
}

result = _create_usage_metadata(token_usage)

assert isinstance(result, dict)
assert result["input_tokens"] == 2006
assert result["output_tokens"] == 300
assert result["total_tokens"] == 2306
assert "input_token_details" in result
assert isinstance(result["input_token_details"], dict)
assert result["input_token_details"]["cache_read"] == 1920
assert "output_token_details" not in result


def test_create_usage_metadata_with_all_details() -> None:
"""Test usage metadata with all available details."""
token_usage = {
"prompt_tokens": 2006,
"completion_tokens": 300,
"total_tokens": 2306,
"prompt_tokens_details": {"cached_tokens": 1920},
}

result = _create_usage_metadata(token_usage)

assert isinstance(result, dict)
assert result["input_tokens"] == 2006
assert result["output_tokens"] == 300
assert result["total_tokens"] == 2306

assert "input_token_details" in result
assert isinstance(result["input_token_details"], dict)
assert result["input_token_details"]["cache_read"] == 1920

assert "output_token_details" not in result


def test_create_usage_metadata_missing_total_tokens() -> None:
"""Test that total_tokens is calculated when missing."""
token_usage = {
"prompt_tokens": 100,
"completion_tokens": 50,
}

result = _create_usage_metadata(token_usage)

assert result["input_tokens"] == 100
assert result["output_tokens"] == 50
assert result["total_tokens"] == 150


def test_create_usage_metadata_empty_details() -> None:
"""Test that empty detail dicts don't create token detail objects."""
token_usage = {
"prompt_tokens": 100,
"completion_tokens": 50,
"total_tokens": 150,
"prompt_tokens_details": {},
}

result = _create_usage_metadata(token_usage)

assert result["input_tokens"] == 100
assert result["output_tokens"] == 50
assert result["total_tokens"] == 150
assert "input_token_details" not in result
assert "output_token_details" not in result


def test_create_usage_metadata_zero_cached_tokens() -> None:
"""Test that zero cached tokens are not included (falsy)."""
token_usage = {
"prompt_tokens": 100,
"completion_tokens": 50,
"total_tokens": 150,
"prompt_tokens_details": {"cached_tokens": 0},
}

result = _create_usage_metadata(token_usage)

assert result["input_tokens"] == 100
assert result["output_tokens"] == 50
assert result["total_tokens"] == 150
assert "input_token_details" not in result


def test_chat_result_with_usage_metadata() -> None:
"""Test that _create_chat_result properly includes usage metadata."""
llm = ChatGroq(model="test-model")

mock_response = {
"id": "chatcmpl-123",
"object": "chat.completion",
"created": 1234567890,
"model": "test-model",
"choices": [
{
"index": 0,
"message": {
"role": "assistant",
"content": "Test response",
},
"finish_reason": "stop",
}
],
"usage": {
"prompt_tokens": 2006,
"completion_tokens": 300,
"total_tokens": 2306,
"prompt_tokens_details": {"cached_tokens": 1920},
},
}

result = llm._create_chat_result(mock_response, {})

assert len(result.generations) == 1
message = result.generations[0].message
assert isinstance(message, AIMessage)
assert message.content == "Test response"

assert message.usage_metadata is not None
assert isinstance(message.usage_metadata, dict)
assert message.usage_metadata["input_tokens"] == 2006
assert message.usage_metadata["output_tokens"] == 300
assert message.usage_metadata["total_tokens"] == 2306

assert "input_token_details" in message.usage_metadata
assert message.usage_metadata["input_token_details"]["cache_read"] == 1920

assert "output_token_details" not in message.usage_metadata


def test_chat_result_backward_compatibility() -> None:
"""Test that responses without new fields still work."""
llm = ChatGroq(model="test-model")

mock_response = {
"id": "chatcmpl-123",
"object": "chat.completion",
"created": 1234567890,
"model": "test-model",
"choices": [
{
"index": 0,
"message": {
"role": "assistant",
"content": "Test response",
},
"finish_reason": "stop",
}
],
"usage": {
"prompt_tokens": 100,
"completion_tokens": 50,
"total_tokens": 150,
},
}

result = llm._create_chat_result(mock_response, {})

assert len(result.generations) == 1
message = result.generations[0].message
assert isinstance(message, AIMessage)

assert message.usage_metadata is not None
assert message.usage_metadata["input_tokens"] == 100
assert message.usage_metadata["output_tokens"] == 50
assert message.usage_metadata["total_tokens"] == 150

assert "input_token_details" not in message.usage_metadata
assert "output_token_details" not in message.usage_metadata


def test_streaming_with_usage_metadata() -> None:
"""Test that streaming properly includes usage metadata."""
chunk = {
"id": "chatcmpl-123",
"object": "chat.completion.chunk",
"created": 1234567890,
"model": "test-model",
"choices": [
{
"index": 0,
"delta": {
"role": "assistant",
"content": "Hello",
},
"finish_reason": None,
}
],
"x_groq": {
"usage": {
"prompt_tokens": 2006,
"completion_tokens": 300,
"total_tokens": 2306,
"prompt_tokens_details": {"cached_tokens": 1920},
}
},
}

result = _convert_chunk_to_message_chunk(chunk, AIMessageChunk)

assert isinstance(result, AIMessageChunk)
assert result.content == "Hello"

assert result.usage_metadata is not None
assert isinstance(result.usage_metadata, dict)
assert result.usage_metadata["input_tokens"] == 2006
assert result.usage_metadata["output_tokens"] == 300
assert result.usage_metadata["total_tokens"] == 2306

assert "input_token_details" in result.usage_metadata
assert result.usage_metadata["input_token_details"]["cache_read"] == 1920

assert "output_token_details" not in result.usage_metadata


def test_streaming_without_usage_metadata() -> None:
"""Test that streaming works without usage metadata (backward compatibility)."""
chunk = {
"id": "chatcmpl-123",
"object": "chat.completion.chunk",
"created": 1234567890,
"model": "test-model",
"choices": [
{
"index": 0,
"delta": {
"role": "assistant",
"content": "Hello",
},
"finish_reason": None,
}
],
}

result = _convert_chunk_to_message_chunk(chunk, AIMessageChunk)

assert isinstance(result, AIMessageChunk)
assert result.content == "Hello"
assert result.usage_metadata is None