Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 62 additions & 22 deletions haystack/components/routers/conditional_router.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,11 +203,19 @@ def __init__( # pylint: disable=too-many-positional-arguments

for route in routes:
# extract inputs
route_input_names = self._extract_variables(self._env, [route["output"], route["condition"]])
route_input_names = self._extract_variables(
self._env,
[route["condition"]] + (route["output"] if isinstance(route["output"], list) else [route["output"]]),
)
input_types.update(route_input_names)

# extract outputs
output_types.update({route["output_name"]: route["output_type"]})
output_names = route["output_name"] if isinstance(route["output_name"], list) else [route["output_name"]]
output_types_list = (
route["output_type"] if isinstance(route["output_type"], list) else [route["output_type"]]
)

output_types.update(dict(zip(output_names, output_types_list)))

# remove optional variables from mandatory input types
mandatory_input_types = input_types - set(self.optional_variables)
Expand Down Expand Up @@ -306,27 +314,45 @@ def run(self, **kwargs):
rendered = ast.literal_eval(rendered)
if not rendered:
continue
# We now evaluate the `output` expression to determine the route output
t_output = self._env.from_string(route["output"])
output = t_output.render(**kwargs)
# We suppress the exception in case the output is already a string, otherwise
# we try to evaluate it and would fail.
# This must be done cause the output could be different literal structures.
# This doesn't support any user types.
with contextlib.suppress(Exception):
if not self._unsafe:
output = ast.literal_eval(output)

# Handle multiple outputs
outputs = route["output"] if isinstance(route["output"], list) else [route["output"]]
output_types = (
route["output_type"] if isinstance(route["output_type"], list) else [route["output_type"]]
)
output_names = (
route["output_name"] if isinstance(route["output_name"], list) else [route["output_name"]]
)

result = {}
for output, output_type, output_name in zip(outputs, output_types, output_names):
# Evaluate output template
t_output = self._env.from_string(output)
output_value = t_output.render(**kwargs)

# We suppress the exception in case the output is already a string, otherwise
# we try to evaluate it and would fail.
# This must be done cause the output could be different literal structures.
# This doesn't support any user types.
with contextlib.suppress(Exception):
if not self._unsafe:
output_value = ast.literal_eval(output_value)

# Validate output type if needed
if self._validate_output_type and not self._output_matches_type(output_value, output_type):
raise ValueError(f"Route '{output_name}' type doesn't match expected type")

result[output_name] = output_value

return result

except Exception as e:
# If this was a type‐validation failure, let it propagate as a ValueError
if isinstance(e, ValueError):
raise
msg = f"Error evaluating condition for route '{route}': {e}"
raise RouteConditionException(msg) from e

if self._validate_output_type and not self._output_matches_type(output, route["output_type"]):
msg = f"""Route '{route["output_name"]}' type doesn't match expected type"""
raise ValueError(msg)

# and return the output as a dictionary under the output_name key
return {route["output_name"]: output}

raise NoRouteSelectedException(f"No route fired. Routes: {self.routes}")

def _validate_routes(self, routes: List[Dict]):
Expand All @@ -347,9 +373,23 @@ def _validate_routes(self, routes: List[Dict]):
raise ValueError(
f"Route must contain 'condition', 'output', 'output_type' and 'output_name' fields: {route}"
)
for field in ["condition", "output"]:
if not self._validate_template(self._env, route[field]):
raise ValueError(f"Invalid template for field '{field}': {route[field]}")

# Validate outputs are consistent
outputs = route["output"] if isinstance(route["output"], list) else [route["output"]]
output_types = route["output_type"] if isinstance(route["output_type"], list) else [route["output_type"]]
output_names = route["output_name"] if isinstance(route["output_name"], list) else [route["output_name"]]

# Check lengths match
if not (len(outputs) == len(output_types) == len(output_names)):
raise ValueError(f"Route output, output_type and output_name must have same length: {route}")

# Validate templates
if not self._validate_template(self._env, route["condition"]):
raise ValueError(f"Invalid template for condition: {route['condition']}")

for output in outputs:
if not self._validate_template(self._env, output):
raise ValueError(f"Invalid template for output: {output}")

def _extract_variables(self, env: Environment, templates: List[str]) -> Set[str]:
"""
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
---
features:
- |
Add support for multiple outputs in ConditionalRouter
31 changes: 31 additions & 0 deletions test/components/routers/test_conditional_router.py
Original file line number Diff line number Diff line change
Expand Up @@ -574,3 +574,34 @@ def test_router_to_dict_does_not_mutate_routes(self):
assert new_router.routes == router.routes
assert new_router.routes[0]["output_type"] is str
assert new_router.routes[0]["output_type"] is original_output_type

def test_multiple_outputs_per_route(self):
"""Test that router handles multiple outputs per route correctly"""
routes = [
{
"condition": "{{streams|length >= 2}}",
"output": ["{{streams}}", "{{query}}"],
"output_type": [List[int], str],
"output_name": ["streams", "query"],
}
]
router = ConditionalRouter(routes)

# Test with valid input
result = router.run(streams=[1, 2, 3], query="test")
assert result == {"streams": [1, 2, 3], "query": "test"}

def test_multiple_outputs_validation(self):
"""Test validation of routes with multiple outputs"""
# Test mismatched lengths
with pytest.raises(ValueError, match="must have same length"):
ConditionalRouter(
[
{
"condition": "{{streams|length >= 2}}",
"output": ["{{streams}}", "{{query}}"],
"output_type": [List[int]],
"output_name": ["streams"],
}
]
)
Loading