diff --git a/haystack/components/routers/conditional_router.py b/haystack/components/routers/conditional_router.py index 8f94882b26..917b57ed71 100644 --- a/haystack/components/routers/conditional_router.py +++ b/haystack/components/routers/conditional_router.py @@ -203,11 +203,19 @@ def __init__( # pylint: disable=too-many-positional-arguments for route in routes: # extract inputs - route_input_names = self._extract_variables(self._env, [route["output"], route["condition"]]) + route_input_names = self._extract_variables( + self._env, + [route["condition"]] + (route["output"] if isinstance(route["output"], list) else [route["output"]]), + ) input_types.update(route_input_names) # extract outputs - output_types.update({route["output_name"]: route["output_type"]}) + output_names = route["output_name"] if isinstance(route["output_name"], list) else [route["output_name"]] + output_types_list = ( + route["output_type"] if isinstance(route["output_type"], list) else [route["output_type"]] + ) + + output_types.update(dict(zip(output_names, output_types_list))) # remove optional variables from mandatory input types mandatory_input_types = input_types - set(self.optional_variables) @@ -306,27 +314,45 @@ def run(self, **kwargs): rendered = ast.literal_eval(rendered) if not rendered: continue - # We now evaluate the `output` expression to determine the route output - t_output = self._env.from_string(route["output"]) - output = t_output.render(**kwargs) - # We suppress the exception in case the output is already a string, otherwise - # we try to evaluate it and would fail. - # This must be done cause the output could be different literal structures. - # This doesn't support any user types. - with contextlib.suppress(Exception): - if not self._unsafe: - output = ast.literal_eval(output) + + # Handle multiple outputs + outputs = route["output"] if isinstance(route["output"], list) else [route["output"]] + output_types = ( + route["output_type"] if isinstance(route["output_type"], list) else [route["output_type"]] + ) + output_names = ( + route["output_name"] if isinstance(route["output_name"], list) else [route["output_name"]] + ) + + result = {} + for output, output_type, output_name in zip(outputs, output_types, output_names): + # Evaluate output template + t_output = self._env.from_string(output) + output_value = t_output.render(**kwargs) + + # We suppress the exception in case the output is already a string, otherwise + # we try to evaluate it and would fail. + # This must be done cause the output could be different literal structures. + # This doesn't support any user types. + with contextlib.suppress(Exception): + if not self._unsafe: + output_value = ast.literal_eval(output_value) + + # Validate output type if needed + if self._validate_output_type and not self._output_matches_type(output_value, output_type): + raise ValueError(f"Route '{output_name}' type doesn't match expected type") + + result[output_name] = output_value + + return result + except Exception as e: + # If this was a type‐validation failure, let it propagate as a ValueError + if isinstance(e, ValueError): + raise msg = f"Error evaluating condition for route '{route}': {e}" raise RouteConditionException(msg) from e - if self._validate_output_type and not self._output_matches_type(output, route["output_type"]): - msg = f"""Route '{route["output_name"]}' type doesn't match expected type""" - raise ValueError(msg) - - # and return the output as a dictionary under the output_name key - return {route["output_name"]: output} - raise NoRouteSelectedException(f"No route fired. Routes: {self.routes}") def _validate_routes(self, routes: List[Dict]): @@ -347,9 +373,23 @@ def _validate_routes(self, routes: List[Dict]): raise ValueError( f"Route must contain 'condition', 'output', 'output_type' and 'output_name' fields: {route}" ) - for field in ["condition", "output"]: - if not self._validate_template(self._env, route[field]): - raise ValueError(f"Invalid template for field '{field}': {route[field]}") + + # Validate outputs are consistent + outputs = route["output"] if isinstance(route["output"], list) else [route["output"]] + output_types = route["output_type"] if isinstance(route["output_type"], list) else [route["output_type"]] + output_names = route["output_name"] if isinstance(route["output_name"], list) else [route["output_name"]] + + # Check lengths match + if not len(outputs) == len(output_types) == len(output_names): + raise ValueError(f"Route output, output_type and output_name must have same length: {route}") + + # Validate templates + if not self._validate_template(self._env, route["condition"]): + raise ValueError(f"Invalid template for condition: {route['condition']}") + + for output in outputs: + if not self._validate_template(self._env, output): + raise ValueError(f"Invalid template for output: {output}") def _extract_variables(self, env: Environment, templates: List[str]) -> Set[str]: """ diff --git a/releasenotes/notes/multiple-outputs-conditional-router-c2f0caad3d3f8ce5.yaml b/releasenotes/notes/multiple-outputs-conditional-router-c2f0caad3d3f8ce5.yaml new file mode 100644 index 0000000000..424a359ceb --- /dev/null +++ b/releasenotes/notes/multiple-outputs-conditional-router-c2f0caad3d3f8ce5.yaml @@ -0,0 +1,4 @@ +--- +features: + - | + Add support for multiple outputs in ConditionalRouter diff --git a/test/components/routers/test_conditional_router.py b/test/components/routers/test_conditional_router.py index 0a86613454..df2321f6ad 100644 --- a/test/components/routers/test_conditional_router.py +++ b/test/components/routers/test_conditional_router.py @@ -574,3 +574,44 @@ def test_router_to_dict_does_not_mutate_routes(self): assert new_router.routes == router.routes assert new_router.routes[0]["output_type"] is str assert new_router.routes[0]["output_type"] is original_output_type + + def test_multiple_outputs_per_route(self): + """Test that router handles multiple outputs per route correctly""" + routes = [ + { + "condition": "{{streams|length >= 2}}", + "output": ["{{streams}}", "{{query}}"], + "output_type": [List[int], str], + "output_name": ["streams", "query"], + }, + { + "condition": "{{streams|length < 2}}", + "output": ["{{streams}}", "{{custom_error_message}}"], + "output_type": [List[int], str], + "output_name": ["streams", "custom_error_message"], + }, + ] + router = ConditionalRouter(routes) + + # Test with sufficient input streams + result = router.run(streams=[1, 2, 3], query="test_1", custom_error_message="Not enough streams") + assert result == {"streams": [1, 2, 3], "query": "test_1"} + + # Test with insufficient input streams + result = router.run(streams=[1], query="test_2", custom_error_message="Not enough streams") + assert result == {"streams": [1], "custom_error_message": "Not enough streams"} + + def test_multiple_outputs_validation(self): + """Test validation of routes with multiple outputs""" + # Test mismatched lengths + with pytest.raises(ValueError, match="must have same length"): + ConditionalRouter( + [ + { + "condition": "{{streams|length >= 2}}", + "output": ["{{streams}}", "{{query}}"], + "output_type": [List[int]], + "output_name": ["streams"], + } + ] + )