Skip to content

Commit 6e07a83

Browse files
authored
improvement: better error handling in type-mismatch for mo.ui.altair_chart (#6986)
If the types mismatch, we log a clearer error and skip that filter condition instead of surfacing a stack trace
1 parent 65e4429 commit 6e07a83

File tree

2 files changed

+152
-65
lines changed

2 files changed

+152
-65
lines changed

marimo/_plugins/ui/_impl/altair_chart.py

Lines changed: 111 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -185,59 +185,84 @@ def _filter_dataframe(
185185
dtype = schema[field]
186186
resolved_values = _resolve_values(values, dtype)
187187

188-
if is_point_selection and not is_binned:
189-
df = df.filter(nw.col(field).is_in(resolved_values))
190-
elif len(resolved_values) == 1:
191-
df = df.filter(nw.col(field) == resolved_values[0])
192-
# Range selection
193-
elif len(resolved_values) == 2 and _is_numeric_or_date(
194-
resolved_values[0]
188+
# Validate that resolved values have compatible types
189+
# If coercion failed, the values will still be strings when they should be dates/numbers
190+
if nw.Date == dtype or (
191+
nw.Datetime == dtype and isinstance(dtype, nw.Datetime)
195192
):
196-
left_value, right_value = resolved_values
197-
198-
# For binned fields, we need to check if this is the last bin
199-
# by comparing the right boundary to the maximum value in the dataset.
200-
# If they're equal (or right boundary >= max), use inclusive right boundary.
201-
if is_binned:
202-
# Get the maximum value in the dataset for this field
203-
max_value_df = df.select(nw.col(field).max())
204-
max_value_collected = (
205-
max_value_df.collect()
206-
if is_narwhals_lazyframe(max_value_df)
207-
else max_value_df
193+
# Check if any values are still strings (indicating failed coercion)
194+
if any(isinstance(v, str) for v in resolved_values):
195+
LOGGER.error(
196+
f"Type mismatch for field '{field}': Column has {dtype} type, "
197+
f"but values {resolved_values} could not be properly coerced. "
198+
"Skipping this filter condition."
208199
)
209-
max_value = max_value_collected[field][0]
200+
continue
210201

211-
# If right boundary >= max value, this is the last bin
212-
is_last_bin = right_value >= max_value
213-
214-
if is_last_bin:
215-
# Last bin: use inclusive right boundary
216-
df = df.filter(
217-
(nw.col(field) >= left_value)
218-
& (nw.col(field) <= right_value)
202+
try:
203+
if is_point_selection and not is_binned:
204+
df = df.filter(nw.col(field).is_in(resolved_values))
205+
elif len(resolved_values) == 1:
206+
df = df.filter(nw.col(field) == resolved_values[0])
207+
# Range selection
208+
elif len(resolved_values) == 2 and _is_numeric_or_date(
209+
resolved_values[0]
210+
):
211+
left_value, right_value = resolved_values
212+
213+
# For binned fields, we need to check if this is the last bin
214+
# by comparing the right boundary to the maximum value in the dataset.
215+
# If they're equal (or right boundary >= max), use inclusive right boundary.
216+
if is_binned:
217+
# Get the maximum value in the dataset for this field
218+
max_value_df = df.select(nw.col(field).max())
219+
max_value_collected = (
220+
max_value_df.collect()
221+
if is_narwhals_lazyframe(max_value_df)
222+
else max_value_df
219223
)
224+
max_value = max_value_collected[field][0]
225+
226+
# If right boundary >= max value, this is the last bin
227+
is_last_bin = right_value >= max_value
228+
229+
if is_last_bin:
230+
# Last bin: use inclusive right boundary
231+
df = df.filter(
232+
(nw.col(field) >= left_value)
233+
& (nw.col(field) <= right_value)
234+
)
235+
else:
236+
# Not last bin: use exclusive right boundary
237+
df = df.filter(
238+
(nw.col(field) >= left_value)
239+
& (nw.col(field) < right_value)
240+
)
220241
else:
221-
# Not last bin: use exclusive right boundary
242+
# Non-binned fields: use inclusive right boundary
222243
df = df.filter(
223244
(nw.col(field) >= left_value)
224-
& (nw.col(field) < right_value)
245+
& (nw.col(field) <= right_value)
225246
)
247+
# Multi-selection via range
248+
# This can happen when you use an interval selection
249+
# on categorical data
250+
elif len(resolved_values) > 1:
251+
df = df.filter(nw.col(field).is_in(resolved_values))
226252
else:
227-
# Non-binned fields: use inclusive right boundary
228-
df = df.filter(
229-
(nw.col(field) >= left_value)
230-
& (nw.col(field) <= right_value)
253+
raise ValueError(
254+
f"Invalid selection: {field}={resolved_values}"
231255
)
232-
# Multi-selection via range
233-
# This can happen when you use an interval selection
234-
# on categorical data
235-
elif len(resolved_values) > 1:
236-
df = df.filter(nw.col(field).is_in(resolved_values))
237-
else:
238-
raise ValueError(
239-
f"Invalid selection: {field}={resolved_values}"
256+
except (TypeError, ValueError, Exception) as e:
257+
# Handle type comparison errors and other database errors gracefully
258+
# (e.g., DuckDB BinderException, Polars errors, etc.)
259+
LOGGER.error(
260+
f"Error during filter comparison for field '{field}': {e}. "
261+
f"Attempted to compare {dtype} column with values {resolved_values}. "
262+
"Skipping this filter condition."
240263
)
264+
# Continue without this filter - don't break the entire operation
265+
continue
241266

242267
if not is_lazy and is_narwhals_lazyframe(df):
243268
# Undo the lazy
@@ -250,30 +275,51 @@ def _resolve_values(values: Any, dtype: Any) -> list[Any]:
250275
def _coerce_value(value: Any, dtype: Any) -> Any:
251276
import zoneinfo
252277

253-
if nw.Date == dtype:
254-
if isinstance(value, str):
255-
return datetime.date.fromisoformat(value)
256-
# Value is milliseconds since epoch
257-
# so we convert to seconds since epoch
258-
return datetime.date.fromtimestamp(value / 1000)
259-
if nw.Datetime == dtype and isinstance(dtype, nw.Datetime):
260-
if isinstance(value, str):
261-
res = datetime.datetime.fromisoformat(value)
262-
# If dtype has no timezone, but value has timezone, remove timezone without shifting
263-
if dtype.time_zone is None and res.tzinfo is not None:
264-
return res.replace(tzinfo=None)
265-
return res
266-
267-
# Value is milliseconds since epoch
268-
# so we convert to seconds since epoch
269-
return datetime.datetime.fromtimestamp(
270-
value / 1000,
271-
tz=(
272-
zoneinfo.ZoneInfo(dtype.time_zone)
273-
if dtype.time_zone
274-
else None
275-
),
278+
try:
279+
if nw.Date == dtype:
280+
if isinstance(value, str):
281+
return datetime.date.fromisoformat(value)
282+
# Value is milliseconds since epoch
283+
# so we convert to seconds since epoch
284+
if isinstance(value, (int, float)):
285+
return datetime.date.fromtimestamp(value / 1000)
286+
# If value is already a date or datetime, return as-is
287+
if isinstance(value, datetime.date):
288+
return value
289+
# Otherwise, try to convert to string then parse
290+
return datetime.date.fromisoformat(str(value))
291+
if nw.Datetime == dtype and isinstance(dtype, nw.Datetime):
292+
if isinstance(value, str):
293+
res = datetime.datetime.fromisoformat(value)
294+
# If dtype has no timezone, but value has timezone, remove timezone without shifting
295+
if dtype.time_zone is None and res.tzinfo is not None:
296+
return res.replace(tzinfo=None)
297+
return res
298+
299+
# Value is milliseconds since epoch
300+
# so we convert to seconds since epoch
301+
if isinstance(value, (int, float)):
302+
return datetime.datetime.fromtimestamp(
303+
value / 1000,
304+
tz=(
305+
zoneinfo.ZoneInfo(dtype.time_zone)
306+
if dtype.time_zone
307+
else None
308+
),
309+
)
310+
# If value is already a datetime, return as-is
311+
if isinstance(value, datetime.datetime):
312+
return value
313+
# Otherwise, try to convert to string then parse
314+
return datetime.datetime.fromisoformat(str(value))
315+
except (ValueError, TypeError, OSError) as e:
316+
# Log the error but return the original value
317+
# to avoid breaking the filter entirely
318+
LOGGER.warning(
319+
f"Failed to coerce value {value!r} to {dtype}: {e}. "
320+
"Using original value."
276321
)
322+
return value
277323
return value
278324

279325
if isinstance(values, list):

tests/_plugins/ui/_impl/test_altair_chart.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -244,6 +244,47 @@ def test_filter_dataframe_with_dates(
244244
assert str(first) == "value1"
245245
assert str(second) == "value2"
246246

247+
@staticmethod
248+
@pytest.mark.parametrize(
249+
"df",
250+
create_dataframes(
251+
{
252+
"field": ["value1", "value2", "value3"],
253+
"date_column": [
254+
datetime.date(2020, 1, 1),
255+
datetime.date(2020, 1, 8),
256+
datetime.date(2020, 1, 10),
257+
],
258+
},
259+
),
260+
)
261+
def test_filter_dataframe_with_dates_graceful_error(
262+
df: ChartDataType,
263+
) -> None:
264+
"""Test that invalid date comparisons are handled gracefully."""
265+
# Try with invalid date strings that can't be parsed
266+
interval_selection: ChartSelection = {
267+
"signal_channel": {"date_column": ["invalid_date", "also_invalid"]}
268+
}
269+
# Should not raise an error, but skip the filter condition
270+
# and return the original dataframe
271+
filtered_df = _filter_dataframe(df, selection=interval_selection)
272+
# Since the filter failed gracefully, we should get the full dataframe
273+
assert get_len(filtered_df) == 3
274+
275+
# Try with mixed valid/invalid values - the coercion should handle it
276+
interval_selection = {
277+
"signal_channel": {
278+
"date_column": [
279+
datetime.date(2020, 1, 1).isoformat(),
280+
"not_a_valid_date",
281+
]
282+
}
283+
}
284+
# The filter should be skipped due to type error
285+
filtered_df = _filter_dataframe(df, selection=interval_selection)
286+
assert get_len(filtered_df) == 3
287+
247288
@staticmethod
248289
@pytest.mark.skipif(
249290
not HAS_DEPS, reason="optional dependencies not installed"

0 commit comments

Comments
 (0)