Skip to content

Commit c21770b

Browse files
mscolnickLight2Darkpre-commit-ci[bot]
authored
improvement: print matplotlib Figures/Axes in rich table output (#6904)
Fixes #6893 Print matplotlib Figures/Axes in rich table output. Updated `narwhals_table.py` to handle of Pillow images and Matplotlib figures <img width="1038" height="894" alt="image" src="https://github.com/user-attachments/assets/d261b09b-4aee-4ecf-b135-c04bfcefaadb" /> --------- Co-authored-by: Shahmir Varqha <[email protected]> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 884ec62 commit c21770b

File tree

4 files changed

+310
-3
lines changed

4 files changed

+310
-3
lines changed

marimo/_plugins/ui/_impl/tables/narwhals_table.py

Lines changed: 31 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -715,9 +715,37 @@ def _sanitize_table_value(self, value: Any) -> Any:
715715

716716
# Handle Pillow images
717717
if DependencyManager.pillow.imported():
718-
from PIL import Image
718+
try:
719+
from PIL import Image
720+
721+
if isinstance(value, Image.Image):
722+
return io_to_data_url(value, "image/png")
723+
except Exception:
724+
LOGGER.debug(
725+
"Unable to convert image to data URL", exc_info=True
726+
)
719727

720-
if isinstance(value, Image.Image):
721-
return io_to_data_url(value, "image/png")
728+
# Handle Matplotlib figures
729+
if DependencyManager.matplotlib.imported():
730+
try:
731+
import matplotlib.figure
732+
from matplotlib.axes import Axes
733+
734+
from marimo._output.formatting import as_html
735+
from marimo._plugins.stateless.flex import vstack
736+
737+
if isinstance(value, matplotlib.figure.Figure):
738+
html = as_html(vstack([str(value), value]))
739+
mimetype, data = html._mime_()
740+
741+
if isinstance(value, Axes):
742+
html = as_html(vstack([str(value), value]))
743+
mimetype, data = html._mime_()
744+
return {"mimetype": mimetype, "data": data}
745+
except Exception:
746+
LOGGER.debug(
747+
"Error converting matplotlib figures to HTML",
748+
exc_info=True,
749+
)
722750

723751
return value
Lines changed: 122 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,122 @@
1+
# /// script
2+
# requires-python = ">=3.11"
3+
# dependencies = [
4+
# "pandas",
5+
# "matplotlib",
6+
# ]
7+
# ///
8+
9+
import marimo
10+
11+
__generated_with = "0.17.0"
12+
app = marimo.App(width="medium")
13+
14+
15+
@app.cell
16+
def _():
17+
import marimo as mo
18+
return (mo,)
19+
20+
21+
@app.cell(hide_code=True)
22+
def _(mo):
23+
mo.md(
24+
"""
25+
# Issue #6893: Pandas Subplots Not Displaying
26+
27+
This smoke test reproduces the issue where pandas DataFrame box plots
28+
with `subplots=True` don't render properly in marimo.
29+
30+
**Expected behavior**: Box plots should display as images
31+
32+
**Actual behavior**: Only textual representation appears
33+
34+
The issue occurs because `df.plot.box(subplots=True)` returns a numpy
35+
ndarray of matplotlib Axes objects, which marimo doesn't currently format.
36+
"""
37+
)
38+
return
39+
40+
41+
@app.cell
42+
def _():
43+
import pandas as pd
44+
import matplotlib
45+
46+
# Load NYC taxi data from stable GitHub URL
47+
taxi_url = (
48+
"https://raw.githubusercontent.com/mwaskom/seaborn-data/master/taxis.csv"
49+
)
50+
df = pd.read_csv(taxi_url)
51+
df
52+
return (df,)
53+
54+
55+
@app.cell(hide_code=True)
56+
def _(mo):
57+
mo.md("""## Test Case 1: Box plot WITHOUT subplots (works correctly)""")
58+
return
59+
60+
61+
@app.cell
62+
def _(df):
63+
# This should work - returns a single Axes object
64+
df[["distance", "total"]].plot.box()
65+
return
66+
67+
68+
@app.cell(hide_code=True)
69+
def _(mo):
70+
mo.md(
71+
"""
72+
## Test Case 2: Box plot WITH subplots (reproduces issue #6893)
73+
74+
This is the exact scenario from the bug report.
75+
"""
76+
)
77+
return
78+
79+
80+
@app.cell
81+
def _(df):
82+
# This reproduces the issue - returns ndarray of Axes
83+
df[["distance", "total"]].plot.box(subplots=True)
84+
return
85+
86+
87+
@app.cell(hide_code=True)
88+
def _(mo):
89+
mo.md("""## Test Case 3: Multiple subplots with custom layout""")
90+
return
91+
92+
93+
@app.cell
94+
def _(df):
95+
# Test with 2x2 layout - also returns ndarray of Axes
96+
df[["distance", "total", "fare", "tip"]].plot.box(
97+
subplots=True, layout=(2, 2), figsize=(10, 8)
98+
)
99+
return
100+
101+
102+
@app.cell(hide_code=True)
103+
def _(mo):
104+
mo.md(
105+
"""
106+
## Test Case 4: 1D array of subplots
107+
108+
Test with a single row of subplots.
109+
"""
110+
)
111+
return
112+
113+
114+
@app.cell
115+
def _(df):
116+
# Returns 1D ndarray of Axes
117+
df[["fare", "tip", "tolls"]].plot.box(subplots=True, layout=(1, 3))
118+
return
119+
120+
121+
if __name__ == "__main__":
122+
app.run()
Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
import marimo
2+
3+
__generated_with = "0.17.0"
4+
app = marimo.App(width="medium", sql_output="native")
5+
6+
7+
@app.cell
8+
def _():
9+
import marimo as mo
10+
import pandas as pd
11+
import numpy as np
12+
import sqlglot
13+
import pyarrow
14+
return np, pd
15+
16+
17+
@app.cell
18+
def _(np, pd):
19+
example_df = pd.DataFrame(
20+
[
21+
[np.random.random(size=[2, 2]) for _col in range(2)]
22+
for _row in range(3)
23+
],
24+
columns=["A", "B"],
25+
)
26+
return (example_df,)
27+
28+
29+
@app.cell
30+
def _(example_df):
31+
import duckdb
32+
33+
res = duckdb.sql(
34+
f"""
35+
SELECT COUNT(*) FROM example_df
36+
""",
37+
)
38+
print(res)
39+
return
40+
41+
42+
if __name__ == "__main__":
43+
app.run()

tests/_plugins/ui/_impl/tables/test_narwhals.py

Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1600,3 +1600,117 @@ def test_calculate_top_k_rows_cache_invalidation(df: Any) -> None:
16001600

16011601
# Verify the actual results are different
16021602
assert result1 != result2
1603+
1604+
1605+
@pytest.mark.skipif(not HAS_DEPS, reason="optional dependencies not installed")
1606+
class TestSanitizeTableValue:
1607+
"""Tests for the _sanitize_table_value method."""
1608+
1609+
def setUp(self) -> None:
1610+
import polars as pl
1611+
1612+
self.data = pl.DataFrame({"A": [1, 2, 3]})
1613+
self.manager = NarwhalsTableManager.from_dataframe(self.data)
1614+
1615+
def test_sanitize_none(self) -> None:
1616+
"""Test that None values are returned as-is."""
1617+
manager = self._get_manager()
1618+
assert manager._sanitize_table_value(None) is None
1619+
1620+
def test_sanitize_primitive_values(self) -> None:
1621+
"""Test that primitive values are returned unchanged."""
1622+
manager = self._get_manager()
1623+
assert manager._sanitize_table_value(42) == 42
1624+
assert manager._sanitize_table_value("hello") == "hello"
1625+
assert manager._sanitize_table_value(3.14) == 3.14
1626+
assert manager._sanitize_table_value(True) is True
1627+
1628+
@pytest.mark.skipif(
1629+
not DependencyManager.pillow.has(),
1630+
reason="Pillow not installed",
1631+
)
1632+
def test_sanitize_pillow_image(self) -> None:
1633+
"""Test that Pillow images are converted to data URLs."""
1634+
from PIL import Image
1635+
1636+
manager = self._get_manager()
1637+
1638+
# Create a simple test image
1639+
img = Image.new("RGB", (10, 10), color="red")
1640+
1641+
result = manager._sanitize_table_value(img)
1642+
1643+
# Verify it returns a data URL string
1644+
assert isinstance(result, str)
1645+
assert result.startswith("data:image/png;base64,")
1646+
1647+
@pytest.mark.skipif(
1648+
not DependencyManager.matplotlib.has(),
1649+
reason="Matplotlib not installed",
1650+
)
1651+
def test_sanitize_matplotlib_figure(self) -> None:
1652+
"""Test that Matplotlib figures are returned unchanged (no conversion)."""
1653+
import matplotlib.pyplot as plt
1654+
1655+
manager = self._get_manager()
1656+
1657+
# Create a simple figure
1658+
fig, ax = plt.subplots()
1659+
ax.plot([1, 2, 3], [1, 2, 3])
1660+
1661+
result = manager._sanitize_table_value(fig)
1662+
1663+
# Figure is currently returned unchanged because there's no return statement for figures
1664+
# (only for axes)
1665+
assert result == fig
1666+
1667+
plt.close(fig)
1668+
1669+
@pytest.mark.skipif(
1670+
not DependencyManager.matplotlib.has(),
1671+
reason="Matplotlib not installed",
1672+
)
1673+
def test_sanitize_matplotlib_axes(self) -> None:
1674+
"""Test that Matplotlib axes are converted to HTML."""
1675+
import matplotlib.pyplot as plt
1676+
1677+
manager = self._get_manager()
1678+
1679+
# Create a simple axes
1680+
fig, ax = plt.subplots()
1681+
ax.plot([1, 2, 3], [1, 2, 3])
1682+
1683+
result = manager._sanitize_table_value(ax)
1684+
1685+
# Verify it returns a dict with mimetype and data
1686+
assert isinstance(result, dict)
1687+
assert "mimetype" in result
1688+
assert "data" in result
1689+
1690+
plt.close(fig)
1691+
1692+
def test_sanitize_unsupported_types(self) -> None:
1693+
"""Test that unsupported types are returned unchanged."""
1694+
manager = self._get_manager()
1695+
1696+
# Test various unsupported types
1697+
class CustomClass:
1698+
pass
1699+
1700+
obj = CustomClass()
1701+
assert manager._sanitize_table_value(obj) == obj
1702+
1703+
# Test dict
1704+
d = {"key": "value"}
1705+
assert manager._sanitize_table_value(d) == d
1706+
1707+
# Test list
1708+
lst = [1, 2, 3]
1709+
assert manager._sanitize_table_value(lst) == lst
1710+
1711+
def _get_manager(self) -> NarwhalsTableManager[Any]:
1712+
"""Helper method to create a manager."""
1713+
import polars as pl
1714+
1715+
data = pl.DataFrame({"A": [1, 2, 3]})
1716+
return NarwhalsTableManager.from_dataframe(data)

0 commit comments

Comments
 (0)