Skip to content

Commit 9c37f52

Browse files
DrTodd13 and Todd A. Anderson authored
Add support for running str.match in arrow compute. (#865)
Co-authored-by: Todd A. Anderson <[email protected]>
1 parent 8f2c217 commit 9c37f52

6 files changed

Lines changed: 61 additions & 11 deletions

File tree

bodo/pandas/physical/expression.cpp

Lines changed: 28 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -166,7 +166,8 @@ std::shared_ptr<array_info> do_arrow_compute_binary(
166166
}
167167

168168
std::shared_ptr<array_info> do_arrow_compute_unary(
169-
std::shared_ptr<ExprResult> left_res, const std::string& comparator) {
169+
std::shared_ptr<ExprResult> left_res, const std::string& comparator,
170+
const arrow::compute::FunctionOptions* func_options) {
170171
// Try to convert the results of our children into array
171172
// or scalar results to see which one they are.
172173
std::shared_ptr<ArrayExprResult> left_as_array =
@@ -187,7 +188,7 @@ std::shared_ptr<array_info> do_arrow_compute_unary(
187188
}
188189

189190
arrow::Result<arrow::Datum> cmp_res =
190-
arrow::compute::CallFunction(comparator, {src1});
191+
arrow::compute::CallFunction(comparator, {src1}, func_options);
191192
if (!cmp_res.ok()) [[unlikely]] {
192193
throw std::runtime_error(
193194
"do_array_compute_unary: Error in Arrow compute: " +
@@ -665,6 +666,31 @@ std::shared_ptr<ExprResult> PhysicalArrowExpression::ProcessBatch(
665666
// which returns a struct. To match the output dtype of Pandas, we Cast
666667
// to Date32 instead.
667668
result = do_arrow_compute_cast(res, duckdb::LogicalType::DATE);
669+
} else if (scalar_func_data.arrow_func_name == "match_substring_regex") {
670+
if (!PyTuple_Check(scalar_func_data.args) ||
671+
PyTuple_Size(scalar_func_data.args) != 1) {
672+
throw std::runtime_error(
673+
"match_substring_regex args not a 1-element tuple.");
674+
}
675+
676+
// Get the first element (borrowed reference)
677+
PyObject* py_str = PyTuple_GetItem(scalar_func_data.args, 0);
678+
679+
if (!PyUnicode_Check(py_str)) {
680+
throw std::runtime_error(
681+
"match_substring_regex args element is not a Python string.");
682+
}
683+
684+
// Convert to UTF‑8 C string
685+
const char* c_str = PyUnicode_AsUTF8(py_str);
686+
if (!c_str) {
687+
throw std::runtime_error(
688+
"match_substring_regex error extracting Python string.");
689+
}
690+
691+
arrow::compute::MatchSubstringOptions opts(c_str);
692+
result = do_arrow_compute_unary(res, scalar_func_data.arrow_func_name,
693+
&opts);
668694
} else {
669695
result = do_arrow_compute_unary(res, scalar_func_data.arrow_func_name);
670696
}

bodo/pandas/physical/expression.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -137,7 +137,8 @@ extern std::function<bool(int)> less_equal_test;
137137
*
138138
*/
139139
std::shared_ptr<array_info> do_arrow_compute_unary(
140-
std::shared_ptr<ExprResult> left_res, const std::string &comparator);
140+
std::shared_ptr<ExprResult> left_res, const std::string &comparator,
141+
const arrow::compute::FunctionOptions *func_options = nullptr);
141142

142143
/**
143144
* @brief Convert two ExprResults to arrow and run compute operation on them.

bodo/pandas/plan.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -739,6 +739,11 @@ def function_name(self):
739739
"""Return the function name."""
740740
return self.args[2]
741741

742+
@property
743+
def function_args(self):
744+
"""Return the function args."""
745+
return self.args[3]
746+
742747
def update_func_expr_source(self, new_source_plan: LazyPlan, col_index_offset: int):
743748
"""Update the source and column index of the function expression."""
744749
if self.source != new_source_plan:
@@ -763,6 +768,7 @@ def update_func_expr_source(self, new_source_plan: LazyPlan, col_index_offset: i
763768
new_source_plan,
764769
(in_col_ind + col_index_offset,) + index_cols,
765770
self.function_name,
771+
self.function_args,
766772
)
767773
expr.is_series = self.is_series
768774
return expr

bodo/pandas/plan_optimizer.pyx

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -676,11 +676,12 @@ cdef class ArrowScalarFuncExpression(Expression):
676676
object out_schema,
677677
LogicalOperator source,
678678
vector[int] input_column_indices,
679-
str function_name):
679+
str function_name,
680+
object args):
680681

681682
self.out_schema = out_schema
682683
self.c_expression = make_scalar_func_expr(
683-
source.c_logical_operator, out_schema, None, input_column_indices, False, False, function_name.encode())
684+
source.c_logical_operator, out_schema, args, input_column_indices, False, False, function_name.encode())
684685

685686
def __str__(self):
686687
return f"ArrowScalarFuncExpression({self.function_name})"

bodo/pandas/series.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2554,10 +2554,7 @@ def make_expr(expr, plan, first, schema, index_cols, side="right"):
25542554
idx = get_new_idx(idx, first, side)
25552555
empty_data = arrow_to_empty_df(pa.schema([expr.pa_schema[0]]))
25562556
return ArrowScalarFuncExpression(
2557-
empty_data,
2558-
plan,
2559-
(idx,) + tuple(index_cols),
2560-
expr.function_name,
2557+
empty_data, plan, (idx,) + tuple(index_cols), expr.function_name, ()
25612558
)
25622559
elif is_arith_expr(expr):
25632560
# TODO: recursively traverse arithmetic expr tree to update col idx.
@@ -2729,6 +2726,7 @@ def _get_series_func_plan(
27292726
"str.swapcase",
27302727
"str.title",
27312728
"str.reverse",
2729+
"str.match",
27322730
)
27332731

27342732
def get_arrow_func(name):
@@ -2740,20 +2738,23 @@ def get_arrow_func(name):
27402738
if name.startswith("str.is"):
27412739
body = name.split(".")[1]
27422740
return "utf8_" + body[:2] + "_" + body[2:]
2741+
if name == "str.match":
2742+
return "match_substring_regex"
27432743
if name.startswith("str."):
27442744
return "utf8_" + name.split(".")[1]
27452745
return name.split(".")[1]
27462746

2747-
if func in arrow_compute_list:
2747+
if func in arrow_compute_list and len(kwargs) == 0:
27482748
func_name = get_arrow_func(func)
2749-
func_args = () # TODO: expand this to enable arrow compute calls with args
2749+
func_args = tuple(args)
27502750
is_cfunc = False
27512751
has_state = False
27522752
expr = ArrowScalarFuncExpression(
27532753
empty_data,
27542754
source_data,
27552755
(col_index,) + tuple(index_cols),
27562756
func_name,
2757+
func_args,
27572758
)
27582759
else:
27592760
# Empty func_name separates Python calls from Arrow calls.

bodo/tests/test_df_lib/test_end_to_end.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3763,3 +3763,18 @@ def test_join_non_equi_key_not_in_output():
37633763
reset_index=True,
37643764
sort_output=True,
37653765
)
3766+
3767+
3768+
def test_series_str_match():
3769+
s = pd.Series(["abc", "a1c", "zzz", None], dtype="string")
3770+
bs = bd.Series(s)
3771+
3772+
# Match strings that start with 'a' and end with 'c'
3773+
pmask = s.str.match(r"a.*c")
3774+
bmask = bs.str.match(r"a.*c")
3775+
3776+
_test_equal(
3777+
bmask,
3778+
pmask,
3779+
check_pandas_types=False,
3780+
)

0 commit comments

Comments
 (0)