Skip to content

Commit a8d3a5e

Browse files
timsaucer and claude committed
refactor: split cast_to_type into cast_to_type and try_cast_to_type
Replace the try_cast bool flag with separate cast_to_type and try_cast_to_type functions, matching upstream DataFusion and the arrow_cast / arrow_try_cast pair.

Also drop the redundant data_type parametrization on test_arrow_try_cast_null_on_failure, since the str-vs-pyarrow distinction is already covered by test_arrow_cast_variants.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
1 parent 25cb8e2 commit a8d3a5e

3 files changed

Lines changed: 33 additions & 26 deletions

File tree

crates/core/src/functions.rs

Lines changed: 3 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -609,15 +609,8 @@ expr_fn!(arrow_typeof, arg_1);
609609
expr_fn!(arrow_cast, arg_1 datatype);
610610
expr_fn!(arrow_try_cast, arg_1 datatype);
611611
expr_fn!(arrow_field, arg_1);
612-
#[pyfunction]
613-
#[pyo3(signature = (arg_1, reference, *, try_cast = false))]
614-
fn cast_to_type(arg_1: PyExpr, reference: PyExpr, try_cast: bool) -> PyExpr {
615-
if try_cast {
616-
functions::expr_fn::try_cast_to_type(arg_1.into(), reference.into()).into()
617-
} else {
618-
functions::expr_fn::cast_to_type(arg_1.into(), reference.into()).into()
619-
}
620-
}
612+
expr_fn!(cast_to_type, arg_1 reference);
613+
expr_fn!(try_cast_to_type, arg_1 reference);
621614
expr_fn_vec!(arrow_metadata);
622615
expr_fn_vec!(with_metadata);
623616
expr_fn!(union_tag, arg1);
@@ -977,6 +970,7 @@ pub(crate) fn init_module(m: &Bound<'_, PyModule>) -> PyResult<()> {
977970
m.add_wrapped(wrap_pyfunction!(arrow_try_cast))?;
978971
m.add_wrapped(wrap_pyfunction!(arrow_field))?;
979972
m.add_wrapped(wrap_pyfunction!(cast_to_type))?;
973+
m.add_wrapped(wrap_pyfunction!(try_cast_to_type))?;
980974
m.add_wrapped(wrap_pyfunction!(arrow_metadata))?;
981975
m.add_wrapped(wrap_pyfunction!(with_metadata))?;
982976
m.add_wrapped(wrap_pyfunction!(ascii))?;

python/datafusion/functions.py

Lines changed: 22 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -360,6 +360,7 @@
360360
"translate",
361361
"trim",
362362
"trunc",
363+
"try_cast_to_type",
363364
"union_extract",
364365
"union_tag",
365366
"upper",
@@ -2988,13 +2989,13 @@ def arrow_field(expr: Expr) -> Expr:
29882989
return Expr(f.arrow_field(expr.expr))
29892990

29902991

2991-
def cast_to_type(value: Expr, type_ref: Expr, *, try_cast: bool = False) -> Expr:
2992+
def cast_to_type(value: Expr, type_ref: Expr) -> Expr:
29922993
"""Casts ``value`` to the data type of ``type_ref``.
29932994
29942995
Only the *type* of ``type_ref`` is used; its value is ignored. This is
29952996
useful when the target type comes from another column or expression
2996-
rather than being known up-front. When ``try_cast=True``, casts that
2997-
fail produce NULL instead of erroring.
2997+
rather than being known up-front. Casts that fail produce an error; use
2998+
:py:func:`try_cast_to_type` for the NULL-on-failure variant.
29982999
29993000
If the target type is known statically, prefer :py:func:`arrow_cast`
30003001
(or :py:func:`arrow_try_cast` for the NULL-on-failure variant) and
@@ -3010,17 +3011,32 @@ def cast_to_type(value: Expr, type_ref: Expr, *, try_cast: bool = False) -> Expr
30103011
... )
30113012
>>> result.collect_column("c")[0].as_py()
30123013
1.0
3014+
"""
3015+
return Expr(f.cast_to_type(value.expr, type_ref.expr))
3016+
3017+
3018+
def try_cast_to_type(value: Expr, type_ref: Expr) -> Expr:
3019+
"""Casts ``value`` to the data type of ``type_ref``, NULL on failure.
3020+
3021+
Like :py:func:`cast_to_type`, but casts that fail produce NULL instead
3022+
of erroring. Only the *type* of ``type_ref`` is used; its value is
3023+
ignored.
3024+
3025+
If the target type is known statically, prefer :py:func:`arrow_try_cast`
3026+
and pass a type string or ``pyarrow.DataType`` directly.
30133027
3028+
Examples:
3029+
>>> ctx = dfn.SessionContext()
30143030
>>> df = ctx.from_pydict({"a": ["oops"], "b": [1.0]})
30153031
>>> result = df.select(
3016-
... dfn.functions.cast_to_type(
3017-
... dfn.col("a"), dfn.col("b"), try_cast=True
3032+
... dfn.functions.try_cast_to_type(
3033+
... dfn.col("a"), dfn.col("b")
30183034
... ).alias("c")
30193035
... )
30203036
>>> result.collect_column("c")[0].as_py() is None
30213037
True
30223038
"""
3023-
return Expr(f.cast_to_type(value.expr, type_ref.expr, try_cast=try_cast))
3039+
return Expr(f.try_cast_to_type(value.expr, type_ref.expr))
30243040

30253041

30263042
def arrow_metadata(expr: Expr, key: Expr | str | None = None) -> Expr:

python/tests/test_functions.py

Lines changed: 8 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1316,13 +1316,12 @@ def test_arrow_cast_variants(df, cast_fn, data_type, expected):
13161316
assert result.column(0) == expected
13171317

13181318

1319-
@pytest.mark.parametrize("data_type", ["Float64", pa.float64()])
1320-
def test_arrow_try_cast_null_on_failure(data_type):
1319+
def test_arrow_try_cast_null_on_failure():
13211320
ctx = SessionContext()
13221321
batch = pa.RecordBatch.from_arrays([pa.array(["1.5", "oops", "3"])], names=["s"])
13231322
df = ctx.create_dataframe([[batch]])
13241323

1325-
result = df.select(f.arrow_try_cast(column("s"), data_type).alias("c")).collect()[0]
1324+
result = df.select(f.arrow_try_cast(column("s"), "Float64").alias("c")).collect()[0]
13261325

13271326
assert result.column(0).to_pylist() == [1.5, None, 3.0]
13281327

@@ -1348,23 +1347,21 @@ def test_arrow_field():
13481347

13491348

13501349
@pytest.mark.parametrize(
1351-
("values", "try_cast", "expected"),
1350+
("cast_fn", "values", "expected"),
13521351
[
1353-
(pa.array([4, 5, 6]), False, [4.0, 5.0, 6.0]),
1354-
(pa.array(["oops", "2", "3"]), True, [None, 2.0, 3.0]),
1352+
(f.cast_to_type, pa.array([4, 5, 6]), [4.0, 5.0, 6.0]),
1353+
(f.try_cast_to_type, pa.array(["oops", "2", "3"]), [None, 2.0, 3.0]),
13551354
],
13561355
)
1357-
def test_cast_to_type(values, try_cast, expected):
1358-
"""cast_to_type takes target type from ``type_ref``; try_cast nullifies failures."""
1356+
def test_cast_to_type(cast_fn, values, expected):
1357+
"""cast_to_type / try_cast_to_type take target type from ``type_ref``."""
13591358
ctx = SessionContext()
13601359
batch = pa.RecordBatch.from_arrays(
13611360
[values, pa.array([1.0, 2.0, 3.0])], names=["v", "fl"]
13621361
)
13631362
df = ctx.create_dataframe([[batch]])
13641363

1365-
result = df.select(
1366-
f.cast_to_type(column("v"), column("fl"), try_cast=try_cast).alias("c")
1367-
).collect()[0]
1364+
result = df.select(cast_fn(column("v"), column("fl")).alias("c")).collect()[0]
13681365

13691366
assert result.column(0).to_pylist() == expected
13701367
assert result.column(0).type == pa.float64()

0 commit comments

Comments
 (0)