4 changes: 3 additions & 1 deletion nemoguardrails/actions/llm/utils.py
@@ -36,6 +36,8 @@
from nemoguardrails.logging.callbacks import logging_callbacks
from nemoguardrails.logging.explain import LLMCallInfo

log = logging.getLogger(__name__)


class LLMCallException(Exception):
"""A wrapper around the LLM call invocation exception.
@@ -113,7 +115,7 @@ def get_llm_provider(llm: BaseLanguageModel) -> Optional[str]:
return _infer_provider_from_module(llm)


def _infer_model_name(llm: BaseLanguageModel):
def _infer_model_name(llm: Union[BaseLanguageModel, Runnable]) -> str:
"""Helper to infer the model name based from an LLM instance.

Because not all models correctly implement LangChain's _identifying_params, we have to
19 changes: 19 additions & 0 deletions nemoguardrails/logging/callbacks.py
@@ -32,6 +32,7 @@
from nemoguardrails.logging.explain import LLMCallInfo
from nemoguardrails.logging.processing_log import processing_log_var
from nemoguardrails.logging.stats import LLMStats
from nemoguardrails.logging.utils import extract_model_name_and_base_url
from nemoguardrails.utils import new_uuid

log = logging.getLogger(__name__)
@@ -64,6 +65,15 @@ async def on_llm_start(
if explain_info:
explain_info.llm_calls.append(llm_call_info)

# Log model name and base URL
model_name, base_url = extract_model_name_and_base_url(serialized)
if base_url:
log.info(f"Invoking LLM: model={model_name}, url={base_url}")
elif model_name:
log.info(f"Invoking LLM: model={model_name}")
else:
log.info("Invoking LLM")
Comment on lines +68 to +75
Contributor
style: The exact same logging logic appears in both on_llm_start (lines 68-75) and on_chat_model_start (lines 118-125). Extract to a helper method to reduce duplication and improve maintainability.
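A minimal sketch of the suggested extraction, assuming the helper becomes a private method on LoggingCallbackHandler (the name _log_llm_invocation is hypothetical, not part of this PR; extract_model_name_and_base_url and log already exist in callbacks.py):

    def _log_llm_invocation(self, serialized: Dict[str, Any]) -> None:
        """Log the model name and base URL of the LLM being invoked, if known."""
        # Method name and placement are assumptions for illustration only.
        model_name, base_url = extract_model_name_and_base_url(serialized)
        if base_url:
            log.info(f"Invoking LLM: model={model_name}, url={base_url}")
        elif model_name:
            log.info(f"Invoking LLM: model={model_name}")
        else:
            log.info("Invoking LLM")

Both on_llm_start and on_chat_model_start could then replace the duplicated block with self._log_llm_invocation(serialized).
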
log.info("Invocation Params :: %s", kwargs.get("invocation_params", {}))
log.info(
"Prompt :: %s",
@@ -105,6 +115,15 @@ async def on_chat_model_start(
if explain_info:
explain_info.llm_calls.append(llm_call_info)

# Log model name and base URL
model_name, base_url = extract_model_name_and_base_url(serialized)
if base_url:
log.info(f"Invoking LLM: model={model_name}, url={base_url}")
elif model_name:
log.info(f"Invoking LLM: model={model_name}")
else:
log.info("Invoking LLM")

type_map = {
"human": "User",
"ai": "Bot",
77 changes: 77 additions & 0 deletions nemoguardrails/logging/utils.py
@@ -0,0 +1,77 @@
# SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import re
from typing import Any, Dict, Optional

log = logging.getLogger(__name__)


def extract_model_name_and_base_url(
serialized: Dict[str, Any]
) -> tuple[Optional[str], Optional[str]]:
"""Extract model name and base URL from serialized LLM parameters.

Args:
serialized: The serialized LLM configuration

Returns:
A tuple of (model_name, base_url). Either value can be None if not found.
"""
model_name = None
base_url = None

# Case 1: Try to extract from kwargs (we expect kwargs to be populated for the `ChatOpenAI` class).
if "kwargs" in serialized:
kwargs = serialized["kwargs"]

# Check for model_name in kwargs (ChatOpenAI attribute)
if "model_name" in kwargs and kwargs["model_name"]:
model_name = str(kwargs["model_name"])

# Check for openai_api_base in kwargs (ChatOpenAI attribute)
if "openai_api_base" in kwargs and kwargs["openai_api_base"]:
base_url = str(kwargs["openai_api_base"])

# Case 2: For other providers, parse `repr`, a string representation of the provider class. We don't have
# a reference to the actual class, so we need to parse the string representation.
if "repr" in serialized and isinstance(serialized["repr"], str):
repr_str = serialized["repr"]

# Extract model name. We expect the property to be formatted like model='...' or model_name='...'
if not model_name:
match = re.search(r"model(?:_name)?=['\"]([^'\"]+)['\"]", repr_str)
if match:
model_name = match.group(1)

# Extract base URL. The property name may vary between providers, so try common attribute patterns.
if not base_url:
url_attrs = [
"api_base",
"api_host",
"azure_endpoint",
"base_url",
"endpoint",
"endpoint_url",
"openai_api_base",
]
for attr in url_attrs:
match = re.search(rf"{attr}=['\"]([^'\"]+)['\"]", repr_str)
if match:
base_url = match.group(1)
break

return model_name, base_url
120 changes: 120 additions & 0 deletions tests/test_callbacks.py
@@ -31,6 +31,7 @@
from nemoguardrails.logging.callbacks import LoggingCallbackHandler
from nemoguardrails.logging.explain import ExplainInfo, LLMCallInfo
from nemoguardrails.logging.stats import LLMStats
from nemoguardrails.logging.utils import extract_model_name_and_base_url


@pytest.mark.asyncio
@@ -261,3 +262,122 @@ def __init__(self, content, msg_type):
assert logged_prompt is not None
assert "[cyan]Custom[/]" in logged_prompt
assert "[cyan]Function[/]" in logged_prompt


def test_extract_model_and_url_from_kwargs():
"""Test extracting model_name and openai_api_base from kwargs (ChatOpenAI case)."""
serialized = {
"kwargs": {
"model_name": "gpt-4",
"openai_api_base": "https://api.openai.com/v1",
"temperature": 0.7,
}
}

model_name, base_url = extract_model_name_and_base_url(serialized)

assert model_name == "gpt-4"
assert base_url == "https://api.openai.com/v1"


def test_extract_model_and_url_from_repr():
"""Test extracting from repr string (ChatNIM case)."""
# Property values in single-quotes
serialized = {
"kwargs": {"temperature": 0.1},
"repr": "ChatNIM(model='meta/llama-3.3-70b-instruct', client=<openai.OpenAI object at 0x10d8e4e90>, endpoint_url='https://nim.int.aire.nvidia.com/v1')",
}

model_name, base_url = extract_model_name_and_base_url(serialized)

assert model_name == "meta/llama-3.3-70b-instruct"
assert base_url == "https://nim.int.aire.nvidia.com/v1"

# Property values in double-quotes
serialized = {
"repr": 'ChatOpenAI(model="gpt-3.5-turbo", base_url="https://custom.api.com/v1")'
}

model_name, base_url = extract_model_name_and_base_url(serialized)

assert model_name == "gpt-3.5-turbo"
assert base_url == "https://custom.api.com/v1"

# Model is stored in the `model_name` property
serialized = {
"repr": "SomeProvider(model_name='custom-model-v2', api_base='https://example.com')"
}

model_name, base_url = extract_model_name_and_base_url(serialized)

assert model_name == "custom-model-v2"
assert base_url == "https://example.com"


def test_extract_model_and_url_from_various_url_properties():
"""Test extracting various URL property names."""
test_cases = [
("api_base='https://api1.com'", "https://api1.com"),
("api_host='https://api2.com'", "https://api2.com"),
("azure_endpoint='https://azure.com'", "https://azure.com"),
("endpoint='https://endpoint.com'", "https://endpoint.com"),
("openai_api_base='https://openai.com'", "https://openai.com"),
]

for url_pattern, expected_url in test_cases:
serialized = {"repr": f"Provider(model='test-model', {url_pattern})"}
model_name, base_url = extract_model_name_and_base_url(serialized)
assert base_url == expected_url, f"Failed for pattern: {url_pattern}"


def test_extract_model_and_url_kwargs_priority_over_repr():
"""Test that kwargs values, if present, take priority over repr values."""
serialized = {
"kwargs": {
"model_name": "gpt-4-from-kwargs",
"openai_api_base": "https://kwargs.api.com",
},
"repr": "ChatOpenAI(model='gpt-3.5-from-repr', base_url='https://repr.api.com')",
}

model_name, base_url = extract_model_name_and_base_url(serialized)

assert model_name == "gpt-4-from-kwargs"
assert base_url == "https://kwargs.api.com"


def test_extract_model_and_url_with_missing_values():
"""Test extraction when values are missing."""
# No model or URL
serialized = {"kwargs": {"temperature": 0.7}}
model_name, base_url = extract_model_name_and_base_url(serialized)
assert model_name is None
assert base_url is None

# Only model, no URL
serialized = {"kwargs": {"model_name": "gpt-4"}}
model_name, base_url = extract_model_name_and_base_url(serialized)
assert model_name == "gpt-4"
assert base_url is None

# Only URL, no model
serialized = {"repr": "Provider(endpoint_url='https://example.com')"}
model_name, base_url = extract_model_name_and_base_url(serialized)
assert model_name is None
assert base_url == "https://example.com"


def test_extract_model_and_url_with_empty_values():
"""Test extraction when values are empty strings."""
serialized = {"kwargs": {"model_name": "", "openai_api_base": ""}}
model_name, base_url = extract_model_name_and_base_url(serialized)
assert model_name is None
assert base_url is None


def test_extract_model_and_url_with_empty_serialized_data():
"""Test extraction with empty or minimal serialized dict."""
serialized = {}
model_name, base_url = extract_model_name_and_base_url(serialized)
assert model_name is None
assert base_url is None