diff --git a/haystack/core/pipeline/base.py b/haystack/core/pipeline/base.py index 155ba27fe5..8511a9ca38 100644 --- a/haystack/core/pipeline/base.py +++ b/haystack/core/pipeline/base.py @@ -34,7 +34,7 @@ DEFAULT_MARSHALLER = YamlMarshaller() -# We use a generic type to annotate the return value of classmethods, +# We use a generic type to annotate the return value of class methods, # so that static analyzers won't be confused when derived classes # use those methods. T = TypeVar("T", bound="PipelineBase") @@ -619,31 +619,76 @@ def outputs(self, include_components_with_connected_outputs: bool = False) -> Di } return outputs - def show(self) -> None: + def show(self, server_url: str = "https://mermaid.ink", params: Optional[dict] = None) -> None: """ - If running in a Jupyter notebook, display an image representing this `Pipeline`. + Display an image representing this `Pipeline` in a Jupyter notebook. + This function generates a diagram of the `Pipeline` using a Mermaid server and displays it directly in + the notebook. + + :param server_url: + The base URL of the Mermaid server used for rendering (default: 'https://mermaid.ink'). + See https://github.com/jihchi/mermaid.ink and https://github.com/mermaid-js/mermaid-live-editor for more + info on how to set up your own Mermaid server. + + :param params: + Dictionary of customization parameters to modify the output. Refer to Mermaid documentation for more details + Supported keys: + - format: Output format ('img', 'svg', or 'pdf'). Default: 'img'. + - type: Image type for /img endpoint ('jpeg', 'png', 'webp'). Default: 'png'. + - theme: Mermaid theme ('default', 'neutral', 'dark', 'forest'). Default: 'neutral'. + - bgColor: Background color in hexadecimal (e.g., 'FFFFFF') or named format (e.g., '!white'). + - width: Width of the output image (integer). + - height: Height of the output image (integer). + - scale: Scaling factor (1–3). Only applicable if 'width' or 'height' is specified. + - fit: Whether to fit the diagram size to the page (PDF only, boolean). + - paper: Paper size for PDFs (e.g., 'a4', 'a3'). Ignored if 'fit' is true. + - landscape: Landscape orientation for PDFs (boolean). Ignored if 'fit' is true. + + :raises PipelineDrawingError: + If the function is called outside of a Jupyter notebook or if there is an issue with rendering. """ if is_in_jupyter(): from IPython.display import Image, display # type: ignore - image_data = _to_mermaid_image(self.graph) - + image_data = _to_mermaid_image(self.graph, server_url=server_url, params=params) display(Image(image_data)) else: msg = "This method is only supported in Jupyter notebooks. Use Pipeline.draw() to save an image locally." raise PipelineDrawingError(msg) - def draw(self, path: Path) -> None: + def draw(self, path: Path, server_url: str = "https://mermaid.ink", params: Optional[dict] = None) -> None: """ - Save an image representing this `Pipeline` to `path`. + Save an image representing this `Pipeline` to the specified file path. + + This function generates a diagram of the `Pipeline` using the Mermaid server and saves it to the provided path. :param path: - The path to save the image to. + The file path where the generated image will be saved. + :param server_url: + The base URL of the Mermaid server used for rendering (default: 'https://mermaid.ink'). + See https://github.com/jihchi/mermaid.ink and https://github.com/mermaid-js/mermaid-live-editor for more + info on how to set up your own Mermaid server. + :param params: + Dictionary of customization parameters to modify the output. Refer to Mermaid documentation for more details + Supported keys: + - format: Output format ('img', 'svg', or 'pdf'). Default: 'img'. + - type: Image type for /img endpoint ('jpeg', 'png', 'webp'). Default: 'png'. + - theme: Mermaid theme ('default', 'neutral', 'dark', 'forest'). Default: 'neutral'. + - bgColor: Background color in hexadecimal (e.g., 'FFFFFF') or named format (e.g., '!white'). + - width: Width of the output image (integer). + - height: Height of the output image (integer). + - scale: Scaling factor (1–3). Only applicable if 'width' or 'height' is specified. + - fit: Whether to fit the diagram size to the page (PDF only, boolean). + - paper: Paper size for PDFs (e.g., 'a4', 'a3'). Ignored if 'fit' is true. + - landscape: Landscape orientation for PDFs (boolean). Ignored if 'fit' is true. + + :raises PipelineDrawingError: + If there is an issue with rendering or saving the image. """ # Before drawing we edit a bit the graph, to avoid modifying the original that is # used for running the pipeline we copy it. - image_data = _to_mermaid_image(self.graph) + image_data = _to_mermaid_image(self.graph, server_url=server_url, params=params) Path(path).write_bytes(image_data) def walk(self) -> Iterator[Tuple[str, Component]]: diff --git a/haystack/core/pipeline/draw.py b/haystack/core/pipeline/draw.py index b367696d84..51cec9f6bc 100644 --- a/haystack/core/pipeline/draw.py +++ b/haystack/core/pipeline/draw.py @@ -5,6 +5,7 @@ import base64 import json import zlib +from typing import Any, Dict, Optional import networkx # type:ignore import requests @@ -54,7 +55,7 @@ def _prepare_for_drawing(graph: networkx.MultiDiGraph) -> networkx.MultiDiGraph: ARROWHEAD_MANDATORY = "-->" ARROWHEAD_OPTIONAL = ".->" MERMAID_STYLED_TEMPLATE = """ -%%{{ init: {{'theme': 'neutral' }} }}%% +%%{{ init: {params} }}%% graph TD; @@ -64,27 +65,133 @@ def _prepare_for_drawing(graph: networkx.MultiDiGraph) -> networkx.MultiDiGraph: """ -def _to_mermaid_image(graph: networkx.MultiDiGraph): +def _validate_mermaid_params(params: Dict[str, Any]) -> None: """ - Renders a pipeline using Mermaid (hosted version at 'https://mermaid.ink'). Requires Internet access. + Validates and sets default values for Mermaid parameters. + + :param params: + Dictionary of customization parameters to modify the output. Refer to Mermaid documentation for more details. + Supported keys: + - format: Output format ('img', 'svg', or 'pdf'). Default: 'img'. + - type: Image type for /img endpoint ('jpeg', 'png', 'webp'). Default: 'png'. + - theme: Mermaid theme ('default', 'neutral', 'dark', 'forest'). Default: 'neutral'. + - bgColor: Background color in hexadecimal (e.g., 'FFFFFF') or named format (e.g., '!white'). + - width: Width of the output image (integer). + - height: Height of the output image (integer). + - scale: Scaling factor (1–3). Only applicable if 'width' or 'height' is specified. + - fit: Whether to fit the diagram size to the page (PDF only, boolean). + - paper: Paper size for PDFs (e.g., 'a4', 'a3'). Ignored if 'fit' is true. + - landscape: Landscape orientation for PDFs (boolean). Ignored if 'fit' is true. + + :raises ValueError: + If any parameter is invalid or does not match the expected format. + """ + valid_img_types = {"jpeg", "png", "webp"} + valid_themes = {"default", "neutral", "dark", "forest"} + valid_formats = {"img", "svg", "pdf"} + + params.setdefault("format", "img") + params.setdefault("type", "png") + params.setdefault("theme", "neutral") + + if params["format"] not in valid_formats: + raise ValueError(f"Invalid image format: {params['format']}. Valid options are: {valid_formats}.") + + if params["format"] == "img" and params["type"] not in valid_img_types: + raise ValueError(f"Invalid image type: {params['type']}. Valid options are: {valid_img_types}.") + + if params["theme"] not in valid_themes: + raise ValueError(f"Invalid theme: {params['theme']}. Valid options are: {valid_themes}.") + + if "width" in params and not isinstance(params["width"], int): + raise ValueError("Width must be an integer.") + if "height" in params and not isinstance(params["height"], int): + raise ValueError("Height must be an integer.") + + if "scale" in params and not 1 <= params["scale"] <= 3: + raise ValueError("Scale must be a number between 1 and 3.") + if "scale" in params and not ("width" in params or "height" in params): + raise ValueError("Scale is only allowed when width or height is set.") + + if "bgColor" in params and not isinstance(params["bgColor"], str): + raise ValueError("Background color must be a string.") + + # PDF specific parameters + if params["format"] == "pdf": + if "fit" in params and not isinstance(params["fit"], bool): + raise ValueError("Fit must be a boolean.") + if "paper" in params and not isinstance(params["paper"], str): + raise ValueError("Paper size must be a string (e.g., 'a4', 'a3').") + if "landscape" in params and not isinstance(params["landscape"], bool): + raise ValueError("Landscape must be a boolean.") + if "fit" in params and ("paper" in params or "landscape" in params): + logger.warning("`fit` overrides `paper` and `landscape` for PDFs. Ignoring `paper` and `landscape`.") + + +def _to_mermaid_image( + graph: networkx.MultiDiGraph, server_url: str = "https://mermaid.ink", params: Optional[dict] = None +) -> bytes: + """ + Renders a pipeline using a Mermaid server. + + :param graph: + The graph to render as a Mermaid pipeline. + :param server_url: + Base URL of the Mermaid server (default: 'https://mermaid.ink'). + :param params: + Dictionary of customization parameters. See `validate_mermaid_params` for valid keys. + :returns: + The image, SVG, or PDF data returned by the Mermaid server as bytes. + :raises ValueError: + If any parameter is invalid or does not match the expected format. + :raises PipelineDrawingError: + If there is an issue connecting to the Mermaid server or the server returns an error. """ + + if params is None: + params = {} + + _validate_mermaid_params(params) + + theme = params.get("theme") + init_params = json.dumps({"theme": theme}) + # Copy the graph to avoid modifying the original - graph_styled = _to_mermaid_text(graph.copy()) + graph_styled = _to_mermaid_text(graph.copy(), init_params) json_string = json.dumps({"code": graph_styled}) - # Uses the DEFLATE algorithm at the highest level for smallest size - compressor = zlib.compressobj(level=9) + # Compress the JSON string with zlib (RFC 1950) + compressor = zlib.compressobj(level=9, wbits=15) compressed_data = compressor.compress(json_string.encode("utf-8")) + compressor.flush() compressed_url_safe_base64 = base64.urlsafe_b64encode(compressed_data).decode("utf-8").strip() - url = f"https://mermaid.ink/img/pako:{compressed_url_safe_base64}?type=png" + # Determine the correct endpoint + endpoint_format = params.get("format", "img") # Default to /img endpoint + if endpoint_format not in {"img", "svg", "pdf"}: + raise ValueError(f"Invalid format: {endpoint_format}. Valid options are 'img', 'svg', or 'pdf'.") + + # Construct the URL without query parameters + url = f"{server_url}/{endpoint_format}/pako:{compressed_url_safe_base64}" + + # Add query parameters adhering to mermaid.ink documentation + query_params = [] + for key, value in params.items(): + if key not in {"theme", "format"}: # Exclude theme (handled in init_params) and format (endpoint-specific) + if value is True: + query_params.append(f"{key}") + else: + query_params.append(f"{key}={value}") + + if query_params: + url += "?" + "&".join(query_params) logger.debug("Rendering graph at {url}", url=url) try: resp = requests.get(url, timeout=10) if resp.status_code >= 400: logger.warning( - "Failed to draw the pipeline: https://mermaid.ink/img/ returned status {status_code}", + "Failed to draw the pipeline: {server_url} returned status {status_code}", + server_url=server_url, status_code=resp.status_code, ) logger.info("Exact URL requested: {url}", url=url) @@ -93,18 +200,16 @@ def _to_mermaid_image(graph: networkx.MultiDiGraph): except Exception as exc: # pylint: disable=broad-except logger.warning( - "Failed to draw the pipeline: could not connect to https://mermaid.ink/img/ ({error})", error=exc + "Failed to draw the pipeline: could not connect to {server_url} ({error})", server_url=server_url, error=exc ) logger.info("Exact URL requested: {url}", url=url) logger.warning("No pipeline diagram will be saved.") - raise PipelineDrawingError( - "There was an issue with https://mermaid.ink/, see the stacktrace for details." - ) from exc + raise PipelineDrawingError(f"There was an issue with {server_url}, see the stacktrace for details.") from exc return resp.content -def _to_mermaid_text(graph: networkx.MultiDiGraph) -> str: +def _to_mermaid_text(graph: networkx.MultiDiGraph, init_params: str) -> str: """ Converts a Networkx graph into Mermaid syntax. @@ -153,7 +258,7 @@ def _to_mermaid_text(graph: networkx.MultiDiGraph) -> str: ] connections = "\n".join(connections_list + input_connections + output_connections) - graph_styled = MERMAID_STYLED_TEMPLATE.format(connections=connections) + graph_styled = MERMAID_STYLED_TEMPLATE.format(params=init_params, connections=connections) logger.debug("Mermaid diagram:\n{diagram}", diagram=graph_styled) return graph_styled diff --git a/releasenotes/notes/custom-mermaid-server-and-params-b88ca837375c3e0f.yaml b/releasenotes/notes/custom-mermaid-server-and-params-b88ca837375c3e0f.yaml new file mode 100644 index 0000000000..436de3e1b8 --- /dev/null +++ b/releasenotes/notes/custom-mermaid-server-and-params-b88ca837375c3e0f.yaml @@ -0,0 +1,5 @@ +--- + +features: + - | + Drawing pipelines, i.e.: calls to draw() or show(), can now be done using a custom Mermaid server and additional parameters. This allows for more flexibility in how pipelines are rendered. See Mermaid.ink's [documentation](https://github.com/jihchi/mermaid.ink) for more information on how to set up a custom server. diff --git a/test/core/pipeline/test_draw.py b/test/core/pipeline/test_draw.py index f687f6c587..a54bb761bd 100644 --- a/test/core/pipeline/test_draw.py +++ b/test/core/pipeline/test_draw.py @@ -57,7 +57,7 @@ def raise_for_status(self): mock_response.raise_for_status = raise_for_status mock_get.return_value = mock_response - with pytest.raises(PipelineDrawingError, match="There was an issue with https://mermaid.ink/"): + with pytest.raises(PipelineDrawingError, match="There was an issue with https://mermaid.ink"): _to_mermaid_image(pipe.graph) @@ -68,11 +68,12 @@ def test_to_mermaid_text(): pipe.connect("comp1.result", "comp2.value") pipe.connect("comp2.value", "comp1.value") - text = _to_mermaid_text(pipe.graph) + init_params = {"theme": "neutral"} + text = _to_mermaid_text(pipe.graph, init_params) assert ( text == """ -%%{ init: {'theme': 'neutral' } }%% +%%{ init: {'theme': 'neutral'} }%% graph TD; @@ -92,5 +93,108 @@ def test_to_mermaid_text_does_not_edit_graph(): pipe.connect("comp2.value", "comp1.value") expected_pipe = pipe.to_dict() - _to_mermaid_text(pipe.graph) + init_params = {"theme": "neutral"} + _to_mermaid_text(pipe.graph, init_params) assert expected_pipe == pipe.to_dict() + + +@pytest.mark.integration +@pytest.mark.parametrize( + "params", + [ + {"format": "img", "type": "png", "theme": "dark"}, + {"format": "svg", "theme": "forest"}, + {"format": "pdf", "fit": True, "theme": "neutral"}, + ], +) +def test_to_mermaid_image_valid_formats(params): + # Test valid formats + pipe = Pipeline() + pipe.add_component("comp1", Double()) + pipe.add_component("comp2", Double()) + pipe.connect("comp1", "comp2") + + image_data = _to_mermaid_image(pipe.graph, params=params) + assert image_data # Ensure some data is returned + + +def test_to_mermaid_image_invalid_format(): + # Test invalid format + pipe = Pipeline() + pipe.add_component("comp1", Double()) + pipe.add_component("comp2", Double()) + pipe.connect("comp1", "comp2") + + with pytest.raises(ValueError, match="Invalid image format:"): + _to_mermaid_image(pipe.graph, params={"format": "invalid_format"}) + + +@pytest.mark.integration +def test_to_mermaid_image_missing_theme(): + # Test default theme (neutral) + pipe = Pipeline() + pipe.add_component("comp1", Double()) + pipe.add_component("comp2", Double()) + pipe.connect("comp1", "comp2") + + params = {"format": "img"} + image_data = _to_mermaid_image(pipe.graph, params=params) + + assert image_data # Ensure some data is returned + + +def test_to_mermaid_image_invalid_scale(): + # Test invalid scale + pipe = Pipeline() + pipe.add_component("comp1", Double()) + pipe.add_component("comp2", Double()) + pipe.connect("comp1", "comp2") + + with pytest.raises(ValueError, match="Scale must be a number between 1 and 3."): + _to_mermaid_image(pipe.graph, params={"format": "img", "scale": 5}) + + +def test_to_mermaid_image_scale_without_dimensions(): + # Test scale without width/height + pipe = Pipeline() + pipe.add_component("comp1", Double()) + pipe.add_component("comp2", Double()) + pipe.connect("comp1", "comp2") + + with pytest.raises(ValueError, match="Scale is only allowed when width or height is set."): + _to_mermaid_image(pipe.graph, params={"format": "img", "scale": 2}) + + +@patch("haystack.core.pipeline.draw.requests.get") +def test_to_mermaid_image_server_error(mock_get): + # Test server failure + pipe = Pipeline() + pipe.add_component("comp1", Double()) + pipe.add_component("comp2", Double()) + pipe.connect("comp1", "comp2") + + def raise_for_status(self): + raise requests.HTTPError() + + mock_response = MagicMock() + mock_response.status_code = 500 + mock_response.content = '{"error": "server error"}' + mock_response.raise_for_status = raise_for_status + mock_get.return_value = mock_response + + with pytest.raises(PipelineDrawingError, match="There was an issue with https://mermaid.ink"): + _to_mermaid_image(pipe.graph) + + +def test_to_mermaid_image_invalid_server_url(): + # Test invalid server URL + pipe = Pipeline() + pipe.add_component("comp1", AddFixedValue(add=3)) + pipe.add_component("comp2", Double()) + pipe.connect("comp1.result", "comp2.value") + pipe.connect("comp2.value", "comp1.value") + + server_url = "https://invalid.server" + + with pytest.raises(PipelineDrawingError, match=f"There was an issue with {server_url}"): + _to_mermaid_image(pipe.graph, server_url=server_url)