Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 33 additions & 13 deletions nemoguardrails/rails/llm/llmrails.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,6 +284,8 @@ def __init__(
# We also register the kb as a parameter that can be passed to actions.
self.runtime.register_action_param("kb", self.kb)

# detect actions that need isolated LLM instances and create them
self._create_isolated_llms_for_actions()
# Reference to the general ExplainInfo object.
self.explain_info = None

Expand Down Expand Up @@ -507,9 +509,6 @@ def _init_llms(self):

self.runtime.register_action_param("llms", llms)

# detect actions that need isolated LLM instances and create them
self._create_isolated_llms_for_actions()

def _create_isolated_llms_for_actions(self):
"""Create isolated LLM copies for all actions that accept 'llm' parameter."""
if not self.llm:
Expand All @@ -525,17 +524,38 @@ def _create_isolated_llms_for_actions(self):
)

created_count = 0
# Get the actions from flows defined in rails config
get_action_details = partial(
get_action_details_from_flow_id, flows=self.config.flows
)

configured_actions_names = []
for flow_id in self.config.rails.input.flows:
action_name, _ = get_action_details(flow_id)
configured_actions_names.append(action_name)
for flow_id in self.config.rails.output.flows:
action_name, _ = get_action_details(flow_id)
configured_actions_names.append(action_name)
try:
if self.config.flows:
get_action_details = partial(
get_action_details_from_flow_id, flows=self.config.flows
)
for flow_id in self.config.rails.input.flows:
action_name, _ = get_action_details(flow_id)
configured_actions_names.append(action_name)
for flow_id in self.config.rails.output.flows:
action_name, _ = get_action_details(flow_id)
configured_actions_names.append(action_name)
else:
# for configurations without flow definitions, use all actions that need LLMs
print(
"No flow definitions found, creating isolated LLMs for all actions requiring them"
)
configured_actions_names = list(actions_needing_llms)
except Exception as e:
# if flow matching fails, fall back to all actions that need LLMs
log.info(
"No flow definitions found, creating isolated LLMs for all actions requiring them"
)
configured_actions_names = list(actions_needing_llms)
except Exception as e:
# if flow matching fails, fall back to all actions that need LLMs
log.warning(
"Flow matching failed (%s), creating isolated LLMs for all actions requiring them",
e,
)
configured_actions_names = list(actions_needing_llms)

for action_name in configured_actions_names:
if action_name not in actions_needing_llms:
Expand Down
27 changes: 8 additions & 19 deletions nemoguardrails/rails/llm/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import json
from typing import Any, Dict, List, Optional, Tuple, Union
from typing import Any, Dict, List, Tuple, Union

from nemoguardrails.colang.v1_0.runtime.flows import _normalize_flow_id


def get_history_cache_key(messages: List[dict]) -> str:
Expand Down Expand Up @@ -61,37 +63,24 @@ def get_history_cache_key(messages: List[dict]) -> str:
def get_action_details_from_flow_id(
flow_id: str,
flows: List[Union[Dict, Any]],
prefixes: Optional[List[str]] = None,
) -> Tuple[str, Any]:
"""Get the action name and parameters from the flow id.

First, try to find an exact match.
If not found, then if the provided flow_id starts with one of the special prefixes,
return the first flow whose id starts with that same prefix.
"""
supported_prefixes = [
"content safety check output",
"topic safety check output",
]
if prefixes:
supported_prefixes.extend(prefixes)

candidate_flow = None

for flow in flows:
# If exact match, use it
if flow["id"] == flow_id:
flow_id = _normalize_flow_id(flow_id)
normalized_flow_id = _normalize_flow_id(flow_id)
for flow in flows:
# If exact match, use it
if flow["id"] == normalized_flow_id:
candidate_flow = flow
break

# If no exact match, check if both the provided flow_id and this flow's id share a special prefix
for prefix in supported_prefixes:
if flow_id.startswith(prefix) and flow["id"].startswith(prefix):
candidate_flow = flow
# We don't break immediately here because an exact match would have been preferred,
# but since we're in the else branch it's fine to choose the first matching candidate.
# TODO:we should avoid having multiple matchin prefixes
break

if candidate_flow is not None:
break
Expand Down
41 changes: 41 additions & 0 deletions tests/test_llm_isolation.py
Original file line number Diff line number Diff line change
Expand Up @@ -536,3 +536,44 @@ def test_create_isolated_llms_handles_empty_rails_config(self, rails_with_mock_l
mock_get_action.assert_not_called()

rails.runtime.register_action_param.assert_not_called()

def test_llm_isolation_timing_with_empty_flows(self, rails_with_mock_llm, caplog):
"""Test that LLM isolation handles empty flows gracefully during initialization.

This test reproduces the timing issue where _create_isolated_llms_for_actions()
was called before flows were properly loaded. Before the fix, this would fail
when trying to resolve rail flow IDs against an empty flows list, causing
LLM isolation to fail silently with a warning log.
"""
rails = rails_with_mock_llm

rails.llm = MockLLM(model_kwargs={}, temperature=0.7)

# simulate the problematic scenario: rail flows defined but config.flows empty
rails.config.rails = Mock()
rails.config.rails.input = Mock()
rails.config.rails.output = Mock()
rails.config.rails.input.flows = [
"content safety check input $model=content_safety"
]
rails.config.rails.output.flows = [
"content safety check output $model=content_safety"
]
rails.config.flows = [] # Empty flows list (timing issue scenario)

rails.runtime = Mock()
rails.runtime.action_dispatcher = MockActionDispatcher()
rails.runtime.registered_action_params = {}
rails.runtime.register_action_param = Mock()

# before the fix, this would log a warning about failing to create isolated LLMs
# after the fix, it should handle empty flows gracefully without the warning
rails._create_isolated_llms_for_actions()

warning_messages = [
record.message for record in caplog.records if record.levelname == "WARNING"
]
assert not any(
"Failed to create isolated LLMs for actions" in msg
for msg in warning_messages
), f"Fix failed: Warning still logged: {warning_messages}"
64 changes: 8 additions & 56 deletions tests/test_rails_llm_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,11 +145,11 @@ def test_get_action_details_from_flow_id_exact_match():
assert action_params == {"param1": "value1"}


def test_get_action_details_from_flow_id_prefix_match():
"""Test get_action_details_from_flow_id with prefix matching."""
def test_get_action_details_from_flow_id_content_safety():
"""Test get_action_details_from_flow_id ."""
flows = [
{
"id": "content safety check output with model gpt-4",
"id": "content safety check output",
"elements": [
{
"_type": "run_action",
Expand All @@ -165,17 +165,17 @@ def test_get_action_details_from_flow_id_prefix_match():
]

action_name, action_params = get_action_details_from_flow_id(
"content safety check output", flows
"content safety check output $model=anothe_model_config", flows
)
assert action_name == "content_safety_check"
assert action_params == {"model": "gpt-4"}


def test_get_action_details_from_flow_id_topic_safety_prefix():
"""Test get_action_details_from_flow_id with topic safety prefix."""
def test_get_action_details_from_flow_id_topic_safety():
"""Test get_action_details_from_flow_id with topic safety."""
flows = [
{
"id": "topic safety check output with model claude",
"id": "topic safety check output",
"elements": [
{
"_type": "run_action",
Expand All @@ -191,38 +191,12 @@ def test_get_action_details_from_flow_id_topic_safety_prefix():
]

action_name, action_params = get_action_details_from_flow_id(
"topic safety check output", flows
"topic safety check output $model=claude_model", flows
)
assert action_name == "topic_safety_check"
assert action_params == {"model": "claude"}


def test_get_action_details_from_flow_id_custom_prefixes():
"""Test get_action_details_from_flow_id with custom prefixes."""
flows = [
{
"id": "custom prefix flow with params",
"elements": [
{
"_type": "run_action",
"_source_mapping": {
"filename": "custom.co",
"line_text": "execute custom_action",
},
"action_name": "custom_action",
"action_params": {"custom": "value"},
}
],
}
]

action_name, action_params = get_action_details_from_flow_id(
"custom prefix flow", flows, prefixes=["custom prefix"]
)
assert action_name == "custom_action"
assert action_params == {"custom": "value"}


def test_get_action_details_from_flow_id_no_match():
"""Test get_action_details_from_flow_id when no flow matches."""
flows = [
Expand Down Expand Up @@ -410,28 +384,6 @@ def test_get_action_details_exact_match_not_colang_2(dummy_flows):
assert "No run_action element found for flow_id" in str(exc_info.value)


def test_get_action_details_prefix_match(dummy_flows):
# For a flow_id that starts with the prefix "other_flow",
# we expect to retrieve the action details from the flow whose id starts with that prefix.
# we expect a result since we are passing the prefixes argument.
action_name, action_params = get_action_details_from_flow_id(
"other_flow", dummy_flows, prefixes=["other_flow"]
)
assert action_name == "other_action"
assert action_params == {"param2": "value2"}


def test_get_action_details_prefix_match_unsupported_prefix(dummy_flows):
# For a flow_id that starts with the prefix "other_flow",
# we expect to retrieve the action details from the flow whose id starts with that prefix.
# but as the prefix is not supported, we expect a ValueError.

with pytest.raises(ValueError) as exc_info:
get_action_details_from_flow_id("other_flow", dummy_flows)

assert "No action found for flow_id" in str(exc_info.value)


def test_get_action_details_no_match(dummy_flows):
# Tests that a non matching flow_id raises a ValueError
with pytest.raises(ValueError) as exc_info:
Expand Down