Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 54 additions & 9 deletions haystack/core/pipeline/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@

DEFAULT_MARSHALLER = YamlMarshaller()

# We use a generic type to annotate the return value of classmethods,
# We use a generic type to annotate the return value of class methods,
# so that static analyzers won't be confused when derived classes
# use those methods.
T = TypeVar("T", bound="PipelineBase")
Expand Down Expand Up @@ -619,31 +619,76 @@ def outputs(self, include_components_with_connected_outputs: bool = False) -> Di
}
return outputs

def show(self, server_url: str = "https://mermaid.ink", params: Optional[dict] = None) -> None:
    """
    Render this `Pipeline` as a diagram and display it inline in a Jupyter notebook.

    The diagram is produced by a Mermaid server and shown directly in the notebook output.

    :param server_url:
        The base URL of the Mermaid server used for rendering (default: 'https://mermaid.ink').
        See https://github.com/jihchi/mermaid.ink and https://github.com/mermaid-js/mermaid-live-editor for more
        info on how to set up your own Mermaid server.

    :param params:
        Dictionary of customization parameters to modify the output. Refer to Mermaid documentation for more
        details. Supported keys:
            - format: Output format ('img', 'svg', or 'pdf'). Default: 'img'.
            - type: Image type for /img endpoint ('jpeg', 'png', 'webp'). Default: 'png'.
            - theme: Mermaid theme ('default', 'neutral', 'dark', 'forest'). Default: 'neutral'.
            - bgColor: Background color in hexadecimal (e.g., 'FFFFFF') or named format (e.g., '!white').
            - width: Width of the output image (integer).
            - height: Height of the output image (integer).
            - scale: Scaling factor (1–3). Only applicable if 'width' or 'height' is specified.
            - fit: Whether to fit the diagram size to the page (PDF only, boolean).
            - paper: Paper size for PDFs (e.g., 'a4', 'a3'). Ignored if 'fit' is true.
            - landscape: Landscape orientation for PDFs (boolean). Ignored if 'fit' is true.

    :raises PipelineDrawingError:
        If the function is called outside of a Jupyter notebook or if there is an issue with rendering.
    """
    # Guard clause: inline display only makes sense inside a notebook.
    if not is_in_jupyter():
        msg = "This method is only supported in Jupyter notebooks. Use Pipeline.draw() to save an image locally."
        raise PipelineDrawingError(msg)

    # Imported lazily so IPython is only required when actually running in a notebook.
    from IPython.display import Image, display  # type: ignore

    rendered = _to_mermaid_image(self.graph, server_url=server_url, params=params)
    display(Image(rendered))

def draw(self, path: Path, server_url: str = "https://mermaid.ink", params: Optional[dict] = None) -> None:
    """
    Render this `Pipeline` as a diagram and write it to a file.

    The diagram is produced by a Mermaid server and the returned bytes are saved to `path`.

    :param path:
        The file path where the generated image will be saved.
    :param server_url:
        The base URL of the Mermaid server used for rendering (default: 'https://mermaid.ink').
        See https://github.com/jihchi/mermaid.ink and https://github.com/mermaid-js/mermaid-live-editor for more
        info on how to set up your own Mermaid server.
    :param params:
        Dictionary of customization parameters to modify the output. Refer to Mermaid documentation for more
        details. Supported keys:
            - format: Output format ('img', 'svg', or 'pdf'). Default: 'img'.
            - type: Image type for /img endpoint ('jpeg', 'png', 'webp'). Default: 'png'.
            - theme: Mermaid theme ('default', 'neutral', 'dark', 'forest'). Default: 'neutral'.
            - bgColor: Background color in hexadecimal (e.g., 'FFFFFF') or named format (e.g., '!white').
            - width: Width of the output image (integer).
            - height: Height of the output image (integer).
            - scale: Scaling factor (1–3). Only applicable if 'width' or 'height' is specified.
            - fit: Whether to fit the diagram size to the page (PDF only, boolean).
            - paper: Paper size for PDFs (e.g., 'a4', 'a3'). Ignored if 'fit' is true.
            - landscape: Landscape orientation for PDFs (boolean). Ignored if 'fit' is true.

    :raises PipelineDrawingError:
        If there is an issue with rendering or saving the image.
    """
    # The graph is handed over as-is: `_to_mermaid_image` works on its own copy, so the
    # graph used for running the pipeline is not modified by drawing.
    rendered = _to_mermaid_image(self.graph, server_url=server_url, params=params)
    Path(path).write_bytes(rendered)

def walk(self) -> Iterator[Tuple[str, Component]]:
Expand Down
133 changes: 119 additions & 14 deletions haystack/core/pipeline/draw.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import base64
import json
import zlib
from typing import Any, Dict, Optional

import networkx # type:ignore
import requests
Expand Down Expand Up @@ -54,7 +55,7 @@ def _prepare_for_drawing(graph: networkx.MultiDiGraph) -> networkx.MultiDiGraph:
ARROWHEAD_MANDATORY = "-->"
ARROWHEAD_OPTIONAL = ".->"
MERMAID_STYLED_TEMPLATE = """
%%{{ init: {{'theme': 'neutral' }} }}%%
%%{{ init: {params} }}%%

graph TD;

Expand All @@ -64,27 +65,133 @@ def _prepare_for_drawing(graph: networkx.MultiDiGraph) -> networkx.MultiDiGraph:
"""


def _to_mermaid_image(graph: networkx.MultiDiGraph):
def _validate_mermaid_params(params: Dict[str, Any]) -> None:
"""
Renders a pipeline using Mermaid (hosted version at 'https://mermaid.ink'). Requires Internet access.
Validates and sets default values for Mermaid parameters.

:param params:
Dictionary of customization parameters to modify the output. Refer to Mermaid documentation for more details.
Supported keys:
- format: Output format ('img', 'svg', or 'pdf'). Default: 'img'.
- type: Image type for /img endpoint ('jpeg', 'png', 'webp'). Default: 'png'.
- theme: Mermaid theme ('default', 'neutral', 'dark', 'forest'). Default: 'neutral'.
- bgColor: Background color in hexadecimal (e.g., 'FFFFFF') or named format (e.g., '!white').
- width: Width of the output image (integer).
- height: Height of the output image (integer).
- scale: Scaling factor (1–3). Only applicable if 'width' or 'height' is specified.
- fit: Whether to fit the diagram size to the page (PDF only, boolean).
- paper: Paper size for PDFs (e.g., 'a4', 'a3'). Ignored if 'fit' is true.
- landscape: Landscape orientation for PDFs (boolean). Ignored if 'fit' is true.

:raises ValueError:
If any parameter is invalid or does not match the expected format.
"""
valid_img_types = {"jpeg", "png", "webp"}
valid_themes = {"default", "neutral", "dark", "forest"}
valid_formats = {"img", "svg", "pdf"}

params.setdefault("format", "img")
params.setdefault("type", "png")
params.setdefault("theme", "neutral")

if params["format"] not in valid_formats:
raise ValueError(f"Invalid image format: {params['format']}. Valid options are: {valid_formats}.")

if params["format"] == "img" and params["type"] not in valid_img_types:
raise ValueError(f"Invalid image type: {params['type']}. Valid options are: {valid_img_types}.")

if params["theme"] not in valid_themes:
raise ValueError(f"Invalid theme: {params['theme']}. Valid options are: {valid_themes}.")

if "width" in params and not isinstance(params["width"], int):
raise ValueError("Width must be an integer.")
if "height" in params and not isinstance(params["height"], int):
raise ValueError("Height must be an integer.")

if "scale" in params and not 1 <= params["scale"] <= 3:
raise ValueError("Scale must be a number between 1 and 3.")
if "scale" in params and not ("width" in params or "height" in params):
raise ValueError("Scale is only allowed when width or height is set.")

if "bgColor" in params and not isinstance(params["bgColor"], str):
raise ValueError("Background color must be a string.")

# PDF specific parameters
if params["format"] == "pdf":
if "fit" in params and not isinstance(params["fit"], bool):
raise ValueError("Fit must be a boolean.")
if "paper" in params and not isinstance(params["paper"], str):
raise ValueError("Paper size must be a string (e.g., 'a4', 'a3').")
if "landscape" in params and not isinstance(params["landscape"], bool):
raise ValueError("Landscape must be a boolean.")
if "fit" in params and ("paper" in params or "landscape" in params):
logger.warning("`fit` overrides `paper` and `landscape` for PDFs. Ignoring `paper` and `landscape`.")


def _to_mermaid_image(
graph: networkx.MultiDiGraph, server_url: str = "https://mermaid.ink", params: Optional[dict] = None
) -> bytes:
"""
Renders a pipeline using a Mermaid server.

:param graph:
The graph to render as a Mermaid pipeline.
:param server_url:
Base URL of the Mermaid server (default: 'https://mermaid.ink').
:param params:
Dictionary of customization parameters. See `_validate_mermaid_params` for valid keys.
:returns:
The image, SVG, or PDF data returned by the Mermaid server as bytes.
:raises ValueError:
If any parameter is invalid or does not match the expected format.
:raises PipelineDrawingError:
If there is an issue connecting to the Mermaid server or the server returns an error.
"""

if params is None:
params = {}

_validate_mermaid_params(params)

theme = params.get("theme")
init_params = json.dumps({"theme": theme})

# Copy the graph to avoid modifying the original
graph_styled = _to_mermaid_text(graph.copy())
graph_styled = _to_mermaid_text(graph.copy(), init_params)
json_string = json.dumps({"code": graph_styled})

# Uses the DEFLATE algorithm at the highest level for smallest size
compressor = zlib.compressobj(level=9)
# Compress the JSON string with zlib (RFC 1950)
compressor = zlib.compressobj(level=9, wbits=15)
compressed_data = compressor.compress(json_string.encode("utf-8")) + compressor.flush()
compressed_url_safe_base64 = base64.urlsafe_b64encode(compressed_data).decode("utf-8").strip()

url = f"https://mermaid.ink/img/pako:{compressed_url_safe_base64}?type=png"
# Determine the correct endpoint
endpoint_format = params.get("format", "img") # Default to /img endpoint
if endpoint_format not in {"img", "svg", "pdf"}:
raise ValueError(f"Invalid format: {endpoint_format}. Valid options are 'img', 'svg', or 'pdf'.")

# Construct the URL without query parameters
url = f"{server_url}/{endpoint_format}/pako:{compressed_url_safe_base64}"

# Add query parameters adhering to mermaid.ink documentation
query_params = []
for key, value in params.items():
if key not in {"theme", "format"}: # Exclude theme (handled in init_params) and format (endpoint-specific)
if value is True:
query_params.append(f"{key}")
else:
query_params.append(f"{key}={value}")

if query_params:
url += "?" + "&".join(query_params)

logger.debug("Rendering graph at {url}", url=url)
try:
resp = requests.get(url, timeout=10)
if resp.status_code >= 400:
logger.warning(
"Failed to draw the pipeline: https://mermaid.ink/img/ returned status {status_code}",
"Failed to draw the pipeline: {server_url} returned status {status_code}",
server_url=server_url,
status_code=resp.status_code,
)
logger.info("Exact URL requested: {url}", url=url)
Expand All @@ -93,18 +200,16 @@ def _to_mermaid_image(graph: networkx.MultiDiGraph):

except Exception as exc: # pylint: disable=broad-except
logger.warning(
"Failed to draw the pipeline: could not connect to https://mermaid.ink/img/ ({error})", error=exc
"Failed to draw the pipeline: could not connect to {server_url} ({error})", server_url=server_url, error=exc
)
logger.info("Exact URL requested: {url}", url=url)
logger.warning("No pipeline diagram will be saved.")
raise PipelineDrawingError(
"There was an issue with https://mermaid.ink/, see the stacktrace for details."
) from exc
raise PipelineDrawingError(f"There was an issue with {server_url}, see the stacktrace for details.") from exc

return resp.content


def _to_mermaid_text(graph: networkx.MultiDiGraph) -> str:
def _to_mermaid_text(graph: networkx.MultiDiGraph, init_params: str) -> str:
"""
Converts a Networkx graph into Mermaid syntax.

Expand Down Expand Up @@ -153,7 +258,7 @@ def _to_mermaid_text(graph: networkx.MultiDiGraph) -> str:
]
connections = "\n".join(connections_list + input_connections + output_connections)

graph_styled = MERMAID_STYLED_TEMPLATE.format(connections=connections)
graph_styled = MERMAID_STYLED_TEMPLATE.format(params=init_params, connections=connections)
logger.debug("Mermaid diagram:\n{diagram}", diagram=graph_styled)

return graph_styled
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---

features:
- |
Drawing pipelines, i.e., calling draw() or show(), can now be done using a custom Mermaid server and additional rendering parameters. This allows for more flexibility in how pipelines are rendered. See Mermaid.ink's [documentation](https://github.com/jihchi/mermaid.ink) for more information on how to set up a custom server.
Loading