Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 20 additions & 58 deletions nemoguardrails/rails/llm/llmrails.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,10 @@
GenerationOptions,
GenerationResponse,
)
from nemoguardrails.rails.llm.utils import get_history_cache_key
from nemoguardrails.rails.llm.utils import (
get_action_details_from_flow_id,
get_history_cache_key,
)
from nemoguardrails.streaming import END_OF_STREAM, StreamingHandler
from nemoguardrails.utils import (
extract_error_json,
Expand Down Expand Up @@ -520,7 +523,21 @@ def _create_isolated_llms_for_actions(self):
)

created_count = 0
for action_name in actions_needing_llms:
# out of flows defined in rails config we get the actions
get_action_details = partial(
get_action_details_from_flow_id, flows=self.config.flows
)
configured_actions_names = []
for flow_id in self.config.rails.input.flows:
action_name, _ = get_action_details(flow_id)
configured_actions_names.append(action_name)
for flow_id in self.config.rails.output.flows:
action_name, _ = get_action_details(flow_id)
configured_actions_names.append(action_name)

for action_name in configured_actions_names:
if action_name not in actions_needing_llms:
continue
if f"{action_name}_llm" not in self.runtime.registered_action_params:
isolated_llm = self._create_action_llm_copy(self.llm, action_name)
if isolated_llm:
Expand Down Expand Up @@ -1590,7 +1607,7 @@ def _prepare_params(
output_rails_flows_id = self.config.rails.output.flows
stream_first = stream_first or output_rails_streaming_config.stream_first
get_action_details = partial(
_get_action_details_from_flow_id, flows=self.config.flows
get_action_details_from_flow_id, flows=self.config.flows
)

parallel_mode = getattr(self.config.rails.output, "parallel", False)
Expand Down Expand Up @@ -1746,58 +1763,3 @@ def _prepare_params(
# yield the individual chunks directly from the buffer strategy
for chunk in user_output_chunks:
yield chunk


def _get_action_details_from_flow_id(
flow_id: str,
flows: List[Union[Dict, Any]],
prefixes: Optional[List[str]] = None,
) -> Tuple[str, Any]:
"""Get the action name and parameters from the flow id.

First, try to find an exact match.
If not found, then if the provided flow_id starts with one of the special prefixes,
return the first flow whose id starts with that same prefix.
"""

supported_prefixes = [
"content safety check output",
"topic safety check output",
]
if prefixes:
supported_prefixes.extend(prefixes)

candidate_flow = None

for flow in flows:
# If exact match, use it
if flow["id"] == flow_id:
candidate_flow = flow
break

# If no exact match, check if both the provided flow_id and this flow's id share a special prefix
for prefix in supported_prefixes:
if flow_id.startswith(prefix) and flow["id"].startswith(prefix):
candidate_flow = flow
# We don't break immediately here because an exact match would have been preferred,
# but since we're in the else branch it's fine to choose the first matching candidate.
# TODO:we should avoid having multiple matchin prefixes
break

if candidate_flow is not None:
break

if candidate_flow is None:
raise ValueError(f"No action found for flow_id: {flow_id}")

# we have identified a candidate, look for the run_action element.
for element in candidate_flow["elements"]:
if (
element["_type"] == "run_action"
and element["_source_mapping"]["filename"].endswith(".co")
and "execute" in element["_source_mapping"]["line_text"]
and "action_name" in element
):
return element["action_name"], element["action_params"]

raise ValueError(f"No run_action element found for flow_id: {flow_id}")
56 changes: 55 additions & 1 deletion nemoguardrails/rails/llm/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import json
from typing import List
from typing import Any, Dict, List, Optional, Tuple, Union


def get_history_cache_key(messages: List[dict]) -> str:
Expand Down Expand Up @@ -56,3 +56,57 @@ def get_history_cache_key(messages: List[dict]) -> str:
history_cache_key = ":".join(key_items)

return history_cache_key


def get_action_details_from_flow_id(
flow_id: str,
flows: List[Union[Dict, Any]],
prefixes: Optional[List[str]] = None,
) -> Tuple[str, Any]:
"""Get the action name and parameters from the flow id.

First, try to find an exact match.
If not found, then if the provided flow_id starts with one of the special prefixes,
return the first flow whose id starts with that same prefix.
"""
supported_prefixes = [
"content safety check output",
"topic safety check output",
]
if prefixes:
supported_prefixes.extend(prefixes)

candidate_flow = None

for flow in flows:
# If exact match, use it
if flow["id"] == flow_id:
candidate_flow = flow
break

# If no exact match, check if both the provided flow_id and this flow's id share a special prefix
for prefix in supported_prefixes:
if flow_id.startswith(prefix) and flow["id"].startswith(prefix):
candidate_flow = flow
# We don't break immediately here because an exact match would have been preferred,
# but since we're in the else branch it's fine to choose the first matching candidate.
# TODO:we should avoid having multiple matchin prefixes
break

if candidate_flow is not None:
break

if candidate_flow is None:
raise ValueError(f"No action found for flow_id: {flow_id}")

# we have identified a candidate, look for the run_action element.
for element in candidate_flow["elements"]:
if (
element["_type"] == "run_action"
and element["_source_mapping"]["filename"].endswith(".co")
and "execute" in element["_source_mapping"]["line_text"]
and "action_name" in element
):
return element["action_name"], element["action_params"]

raise ValueError(f"No run_action element found for flow_id: {flow_id}")
153 changes: 139 additions & 14 deletions tests/test_llm_isolation.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,48 +256,89 @@ def test_create_isolated_llms_for_actions_integration(self, rails_with_mock_llm)
"""Test the full isolated LLM creation process."""
rails = rails_with_mock_llm

# Mock rails configuration with flows
rails.config.rails = Mock()
rails.config.rails.input = Mock()
rails.config.rails.output = Mock()
rails.config.rails.input.flows = ["input_flow_1", "input_flow_2"]
rails.config.rails.output.flows = ["output_flow_1"]

rails.runtime = Mock()
rails.runtime.action_dispatcher = MockActionDispatcher()
rails.runtime.registered_action_params = {}
rails.runtime.register_action_param = Mock()

rails._create_isolated_llms_for_actions()

expected_calls = [
# Mock get_action_details_from_flow_id to return actions that need LLMs
def mock_get_action_details(flow_id, flows):
mapping = {
"input_flow_1": ("action_with_llm", {}),
"input_flow_2": ("generate_user_intent", {}),
"output_flow_1": ("self_check_output", {}),
}
return mapping.get(flow_id, ("unknown_action", {}))

with patch(
"nemoguardrails.rails.llm.llmrails.get_action_details_from_flow_id",
side_effect=mock_get_action_details,
):
rails._create_isolated_llms_for_actions()

expected_llm_params = [
"action_with_llm_llm",
"generate_user_intent_llm",
"self_check_output_llm",
]

actual_calls = [
registered_llm_params = [
call[0][0] for call in rails.runtime.register_action_param.call_args_list
]

for expected_call in expected_calls:
assert expected_call in actual_calls
for expected_param in expected_llm_params:
assert expected_param in registered_llm_params

def test_create_isolated_llms_skips_existing_specialized_llms(
self, rails_with_mock_llm
):
"""Test that existing specialized LLMs are not overridden."""
rails = rails_with_mock_llm

# Mock rails configuration with flows
rails.config.rails = Mock()
rails.config.rails.input = Mock()
rails.config.rails.output = Mock()
rails.config.rails.input.flows = ["input_flow_1", "input_flow_2"]
rails.config.rails.output.flows = ["output_flow_1"]

rails.runtime = Mock()
rails.runtime.action_dispatcher = MockActionDispatcher()
rails.runtime.registered_action_params = {"self_check_output_llm": Mock()}
rails.runtime.register_action_param = Mock()

rails._create_isolated_llms_for_actions()

# verify self_check_output_llm was NOT re-registered
actual_calls = [
# Mock get_action_details_from_flow_id to return actions that need LLMs
def mock_get_action_details(flow_id, flows):
mapping = {
"input_flow_1": ("action_with_llm", {}),
"input_flow_2": ("generate_user_intent", {}),
"output_flow_1": (
"self_check_output",
{},
), # This one already has an LLM
}
return mapping.get(flow_id, ("unknown_action", {}))

with patch(
"nemoguardrails.rails.llm.llmrails.get_action_details_from_flow_id",
side_effect=mock_get_action_details,
):
rails._create_isolated_llms_for_actions()

registered_llm_params = [
call[0][0] for call in rails.runtime.register_action_param.call_args_list
]
assert "self_check_output_llm" not in actual_calls

# but other actions should still get isolated LLMs
assert "action_with_llm_llm" in actual_calls
assert "generate_user_intent_llm" in actual_calls
assert "self_check_output_llm" not in registered_llm_params
assert "action_with_llm_llm" in registered_llm_params
assert "generate_user_intent_llm" in registered_llm_params

def test_create_isolated_llms_handles_no_main_llm(self, mock_config):
"""Test graceful handling when no main LLM is available."""
Expand Down Expand Up @@ -412,3 +453,87 @@ def test_action_detection_parametrized(
assert action_name in actions_needing_llms
else:
assert action_name not in actions_needing_llms

def test_create_isolated_llms_for_configured_actions_only(
self, rails_with_mock_llm
):
"""Test that isolated LLMs are created only for actions configured in rails flows."""
rails = rails_with_mock_llm

rails.config.rails = Mock()
rails.config.rails.input = Mock()
rails.config.rails.output = Mock()
rails.config.rails.input.flows = [
"input_flow_1",
"input_flow_2",
"input_flow_3",
]
rails.config.rails.output.flows = ["output_flow_1", "output_flow_2"]

rails.runtime = Mock()
rails.runtime.action_dispatcher = MockActionDispatcher()
rails.runtime.registered_action_params = {}
rails.runtime.register_action_param = Mock()

def mock_get_action_details(flow_id, flows):
mapping = {
"input_flow_1": ("action_with_llm", {}),
"input_flow_2": ("action_without_llm", {}),
"input_flow_3": ("self_check_output", {}),
"output_flow_1": ("generate_user_intent", {}),
"output_flow_2": ("non_configured_action", {}),
}
return mapping.get(flow_id, ("unknown_action", {}))

with patch(
"nemoguardrails.rails.llm.llmrails.get_action_details_from_flow_id",
side_effect=mock_get_action_details,
):
rails._create_isolated_llms_for_actions()

registered_llm_params = [
call[0][0] for call in rails.runtime.register_action_param.call_args_list
]

expected_isolated_llm_params = [
"action_with_llm_llm",
"generate_user_intent_llm",
"self_check_output_llm",
]

for expected_param in expected_isolated_llm_params:
assert (
expected_param in registered_llm_params
), f"Expected {expected_param} to be registered as action param"

assert "action_without_llm_llm" not in registered_llm_params
assert "non_configured_action_llm" not in registered_llm_params

assert len(registered_llm_params) == 3, (
f"Should only create isolated LLMs for actions from config flows that need LLMs. "
f"Got {registered_llm_params}"
)

def test_create_isolated_llms_handles_empty_rails_config(self, rails_with_mock_llm):
"""Test that the method handles empty rails configuration gracefully."""
rails = rails_with_mock_llm

rails.config.rails = Mock()
rails.config.rails.input = Mock()
rails.config.rails.output = Mock()
rails.config.rails.input.flows = []
rails.config.rails.output.flows = []

rails.runtime = Mock()
rails.runtime.action_dispatcher = MockActionDispatcher()
rails.runtime.registered_action_params = {}
rails.runtime.register_action_param = Mock()

with patch(
"nemoguardrails.rails.llm.llmrails.get_action_details_from_flow_id"
) as mock_get_action:
rails._create_isolated_llms_for_actions()

mock_get_action.assert_not_called()

rails.runtime.register_action_param.assert_not_called()
Loading
Loading