libs/genai/langchain_google_genai/chat_models.py (26 additions, 7 deletions)
@@ -788,6 +788,8 @@ def _parse_response_candidate(
         except (AttributeError, TypeError):
             thought_sig = None
 
+        has_function_call = hasattr(part, "function_call") and part.function_call
+
         if hasattr(part, "thought") and part.thought:
             thinking_message = {
                 "type": "thinking",
@@ -797,7 +799,7 @@
             if thought_sig:
                 thinking_message["signature"] = thought_sig
             content = _append_to_content(content, thinking_message)
-        elif text is not None and text:
+        elif text is not None and text.strip() and not has_function_call:
            # Check if this text Part has a signature attached
            if thought_sig:
                # Text with signature needs structured block to preserve signature
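The tightened gate above is the behavioral core of this hunk: a Part whose text is empty or whitespace-only no longer emits a stray text block, and text riding alongside a function call is suppressed. A minimal sketch of the old versus new predicate (the `SimpleNamespace` stand-in for a Part is hypothetical, not the real proto type):

```python
from types import SimpleNamespace

# Hypothetical stand-in for a response Part carrying both
# whitespace-only text and a function call.
part = SimpleNamespace(text=" \n", function_call={"name": "get_weather"})

text = part.text
has_function_call = hasattr(part, "function_call") and part.function_call

old_gate = text is not None and text  # " \n" is truthy: junk text block emitted
new_gate = text is not None and text.strip() and not has_function_call

print(bool(old_gate), bool(new_gate))  # True False
```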
@@ -929,15 +931,25 @@ def _parse_response_candidate(
                 }
                 function_call_signatures.append(sig_block)
 
-        # Add function call signatures to content only if there's already other content
-        # This preserves backward compatibility where content is "" for
-        # function-only responses
-        if function_call_signatures and content is not None:
-            for sig_block in function_call_signatures:
-                content = _append_to_content(content, sig_block)
+    # Add function call signatures to content only if there's already other content
+    # This preserves backward compatibility where content is "" for
+    # function-only responses
+    if function_call_signatures and content is not None:
+        for sig_block in function_call_signatures:
+            content = _append_to_content(content, sig_block)

     if content is None:
         content = ""
 
+    if (
+        hasattr(response_candidate, "logprobs_result")
+        and response_candidate.logprobs_result
+    ):
+        response_metadata["logprobs"] = MessageToDict(
+            response_candidate.logprobs_result._pb,
+            preserving_proto_field_name=True,
+        )
 
     if isinstance(content, list) and any(
         isinstance(item, dict) and "executable_code" in item for item in content
     ):
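For orientation, `MessageToDict(..., preserving_proto_field_name=True)` keeps the proto's snake_case field names, so the new metadata entry comes out shaped roughly like this sketch (token values are illustrative only, mirroring the unit test below):

```python
# Illustrative shape of response_metadata["logprobs"]; snake_case keys
# survive because preserving_proto_field_name=True.
logprobs_metadata = {
    "top_candidates": [
        {
            "candidates": [
                {"token": "Hello", "log_probability": -0.03},
                {"token": "Hi", "log_probability": -3.41},
            ]
        }
    ]
}
```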
@@ -1825,6 +1837,9 @@ class Joke(BaseModel):
     stop: Optional[List[str]] = None
     """Stop sequences for the model."""
 
+    logprobs: Optional[int] = None
+    """The number of logprobs to return."""
+
     streaming: Optional[bool] = None
     """Whether to stream responses from the model."""

@@ -1963,6 +1978,7 @@ def _identifying_params(self) -> Dict[str, Any]:
"media_resolution": self.media_resolution,
"thinking_budget": self.thinking_budget,
"include_thoughts": self.include_thoughts,
"logprobs": self.logprobs,
}

def invoke(
@@ -2037,6 +2053,7 @@ def _prepare_params(
"max_output_tokens": self.max_output_tokens,
"top_k": self.top_k,
"top_p": self.top_p,
"logprobs": getattr(self, "logprobs", None),
"response_modalities": self.response_modalities,
"thinking_config": (
(
@@ -2058,6 +2075,8 @@
             }.items()
             if v is not None
         }
+        if getattr(self, "logprobs", None) is not None:
+            gen_config["response_logprobs"] = True
         if generation_config:
             gen_config = {**gen_config, **generation_config}
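Tracing the two snippets above, a model constructed with `logprobs=3` would yield a generation config along these lines (a sketch with illustrative values for the other keys; None-valued entries are dropped by the comprehension):

```python
# Sketch of the dict _prepare_params assembles when logprobs=3, top_p=0.9.
gen_config = {
    "top_p": 0.9,
    "logprobs": 3,
    "response_logprobs": True,  # switched on because logprobs is not None
}
```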

libs/genai/tests/unit_tests/test_chat_models.py (78 additions, 1 deletion)
@@ -141,24 +141,98 @@ def test_initialization_inside_threadpool() -> None:
     ).result()
 
 
-def test_client_transport() -> None:
+def test_logprobs() -> None:
+    """Test that logprobs parameter is set correctly and is in the response."""
+    llm = ChatGoogleGenerativeAI(
+        model=MODEL_NAME,
+        google_api_key=SecretStr("secret-api-key"),
+        logprobs=10,
+    )
+    assert llm.logprobs == 10
+
+    # Create proper mock response with logprobs_result
+    raw_response = {
+        "candidates": [
+            {
+                "content": {"parts": [{"text": "Test response"}]},
+                "finish_reason": 1,
+                "safety_ratings": [],
+                "logprobs_result": {
+                    "top_candidates": [
+                        {
+                            "candidates": [
+                                {"token": "Test", "log_probability": -0.1},
+                            ]
+                        }
+                    ]
+                },
+            }
+        ],
+        "prompt_feedback": {"block_reason": 0, "safety_ratings": []},
+        "usage_metadata": {
+            "prompt_token_count": 5,
+            "candidates_token_count": 2,
+            "total_token_count": 7,
+        },
+    }
+    response = GenerateContentResponse(raw_response)
+
+    with patch(
+        "langchain_google_genai.chat_models._chat_with_retry"
+    ) as mock_chat_with_retry:
+        mock_chat_with_retry.return_value = response
+        llm = ChatGoogleGenerativeAI(
+            model=MODEL_NAME,
+            google_api_key="test-key",
+            logprobs=1,
+        )
+        result = llm.invoke("test")
+        assert "logprobs" in result.response_metadata
+        assert result.response_metadata["logprobs"] == {
+            "top_candidates": [
+                {
+                    "candidates": [
+                        {"token": "Test", "log_probability": -0.1},
+                    ]
+                }
+            ]
+        }
+
+        mock_chat_with_retry.assert_called_once()
+        request = mock_chat_with_retry.call_args.kwargs["request"]
+        assert request.generation_config.logprobs == 1
+        assert request.generation_config.response_logprobs is True


+@pytest.mark.enable_socket
+@patch("langchain_google_genai._genai_extension.v1betaGenerativeServiceAsyncClient")
+@patch("langchain_google_genai._genai_extension.v1betaGenerativeServiceClient")
+def test_client_transport(mock_client: Mock, mock_async_client: Mock) -> None:
     """Test client transport configuration."""
+    mock_client.return_value.transport = Mock()
+    mock_client.return_value.transport.kind = "grpc"
     model = ChatGoogleGenerativeAI(model=MODEL_NAME, google_api_key="fake-key")
     assert model.client.transport.kind == "grpc"
 
+    mock_client.return_value.transport.kind = "rest"
     model = ChatGoogleGenerativeAI(
         model=MODEL_NAME, google_api_key="fake-key", transport="rest"
     )
     assert model.client.transport.kind == "rest"
 
     async def check_async_client() -> None:
+        mock_async_client.return_value.transport = Mock()
+        mock_async_client.return_value.transport.kind = "grpc_asyncio"
         model = ChatGoogleGenerativeAI(model=MODEL_NAME, google_api_key="fake-key")
         _ = model.async_client
         assert model.async_client.transport.kind == "grpc_asyncio"
 
         # Test auto conversion of transport to "grpc_asyncio" from "rest"
         model = ChatGoogleGenerativeAI(
             model=MODEL_NAME, google_api_key="fake-key", transport="rest"
         )
         model.async_client_running = None
         _ = model.async_client
         assert model.async_client.transport.kind == "grpc_asyncio"
 
     asyncio.run(check_async_client())
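One subtlety in the rewritten test signature: stacked `@patch` decorators apply bottom-up, so the decorator nearest the function (`v1betaGenerativeServiceClient`) binds to the first argument, `mock_client`. A standalone illustration (the `demo` function and patched `os` targets are hypothetical):

```python
from unittest.mock import Mock, patch

@patch("os.getcwd")  # outer decorator -> second mock argument
@patch("os.getenv")  # inner decorator -> first mock argument
def demo(mock_getenv: Mock, mock_getcwd: Mock) -> None:
    # Mirrors how mock_client precedes mock_async_client above.
    assert mock_getenv is not mock_getcwd

demo()
```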
@@ -172,6 +246,7 @@ def test_initalization_without_async() -> None:
     assert chat.async_client is None
 
 
+@pytest.mark.enable_socket
 def test_initialization_with_async() -> None:
     async def initialize_chat_with_async_client() -> ChatGoogleGenerativeAI:
         model = ChatGoogleGenerativeAI(
@@ -1288,6 +1363,7 @@ def test_grounding_metadata_multiple_parts() -> None:
assert grounding["grounding_supports"][0]["segment"]["part_index"] == 1


@pytest.mark.enable_socket
@pytest.mark.parametrize(
"is_async,mock_target,method_name",
[
@@ -1414,6 +1490,7 @@ def mock_stream() -> Iterator[GenerateContentResponse]:
assert "timeout" not in call_kwargs


@pytest.mark.enable_socket
@pytest.mark.parametrize(
"is_async,mock_target,method_name",
[