Skip to content

Commit 9556bcd

Browse files
tlm365Tai Le Manh
andauthored
[datafusion-spark] Implement factorical function (#16125)
* Implement spark-compatible factorical function Signed-off-by: Tai Le Manh <[email protected]> * Remove clippy warning * Add unit tests --------- Signed-off-by: Tai Le Manh <[email protected]> Co-authored-by: Tai Le Manh <[email protected]>
1 parent cdaaef7 commit 9556bcd

File tree

3 files changed

+244
-3
lines changed

3 files changed

+244
-3
lines changed
Lines changed: 196 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,196 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
use std::any::Any;
19+
use std::sync::Arc;
20+
21+
use arrow::array::{Array, Int64Array};
22+
use arrow::datatypes::DataType;
23+
use arrow::datatypes::DataType::{Int32, Int64};
24+
use datafusion_common::cast::as_int32_array;
25+
use datafusion_common::{exec_err, DataFusionError, Result, ScalarValue};
26+
use datafusion_expr::Signature;
27+
use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Volatility};
28+
29+
/// <https://spark.apache.org/docs/latest/api/sql/index.html#factorial>
30+
#[derive(Debug)]
31+
pub struct SparkFactorial {
32+
signature: Signature,
33+
aliases: Vec<String>,
34+
}
35+
36+
impl Default for SparkFactorial {
37+
fn default() -> Self {
38+
Self::new()
39+
}
40+
}
41+
42+
impl SparkFactorial {
43+
pub fn new() -> Self {
44+
Self {
45+
signature: Signature::exact(vec![Int32], Volatility::Immutable),
46+
aliases: vec![],
47+
}
48+
}
49+
}
50+
51+
impl ScalarUDFImpl for SparkFactorial {
52+
fn as_any(&self) -> &dyn Any {
53+
self
54+
}
55+
56+
fn name(&self) -> &str {
57+
"factorial"
58+
}
59+
60+
fn signature(&self) -> &Signature {
61+
&self.signature
62+
}
63+
64+
fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
65+
Ok(Int64)
66+
}
67+
68+
fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
69+
spark_factorial(&args.args)
70+
}
71+
72+
fn aliases(&self) -> &[String] {
73+
&self.aliases
74+
}
75+
}
76+
77+
const FACTORIALS: [i64; 21] = [
78+
1,
79+
1,
80+
2,
81+
6,
82+
24,
83+
120,
84+
720,
85+
5040,
86+
40320,
87+
362880,
88+
3628800,
89+
39916800,
90+
479001600,
91+
6227020800,
92+
87178291200,
93+
1307674368000,
94+
20922789888000,
95+
355687428096000,
96+
6402373705728000,
97+
121645100408832000,
98+
2432902008176640000,
99+
];
100+
101+
pub fn spark_factorial(args: &[ColumnarValue]) -> Result<ColumnarValue, DataFusionError> {
102+
if args.len() != 1 {
103+
return Err(DataFusionError::Internal(
104+
"`factorial` expects exactly one argument".to_string(),
105+
));
106+
}
107+
108+
match &args[0] {
109+
ColumnarValue::Scalar(ScalarValue::Int32(value)) => {
110+
let result = compute_factorial(*value);
111+
Ok(ColumnarValue::Scalar(ScalarValue::Int64(result)))
112+
}
113+
ColumnarValue::Scalar(other) => {
114+
exec_err!("`factorial` got an unexpected scalar type: {:?}", other)
115+
}
116+
ColumnarValue::Array(array) => match array.data_type() {
117+
Int32 => {
118+
let array = as_int32_array(array)?;
119+
120+
let result: Int64Array = array.iter().map(compute_factorial).collect();
121+
122+
Ok(ColumnarValue::Array(Arc::new(result)))
123+
}
124+
other => {
125+
exec_err!("`factorial` got an unexpected argument type: {:?}", other)
126+
}
127+
},
128+
}
129+
}
130+
131+
#[inline]
132+
fn compute_factorial(num: Option<i32>) -> Option<i64> {
133+
num.filter(|&v| (0..=20).contains(&v))
134+
.map(|v| FACTORIALS[v as usize])
135+
}
136+
137+
#[cfg(test)]
138+
mod test {
139+
use crate::function::math::factorial::spark_factorial;
140+
use arrow::array::{Int32Array, Int64Array};
141+
use datafusion_common::cast::as_int64_array;
142+
use datafusion_common::ScalarValue;
143+
use datafusion_expr::ColumnarValue;
144+
use std::sync::Arc;
145+
146+
#[test]
147+
fn test_spark_factorial_array() {
148+
let input = Int32Array::from(vec![
149+
Some(-1),
150+
Some(0),
151+
Some(1),
152+
Some(2),
153+
Some(4),
154+
Some(20),
155+
Some(21),
156+
None,
157+
]);
158+
159+
let args = ColumnarValue::Array(Arc::new(input));
160+
let result = spark_factorial(&[args]).unwrap();
161+
let result = match result {
162+
ColumnarValue::Array(array) => array,
163+
_ => panic!("Expected array"),
164+
};
165+
166+
let actual = as_int64_array(&result).unwrap();
167+
let expected = Int64Array::from(vec![
168+
None,
169+
Some(1),
170+
Some(1),
171+
Some(2),
172+
Some(24),
173+
Some(2432902008176640000),
174+
None,
175+
None,
176+
]);
177+
178+
assert_eq!(actual, &expected);
179+
}
180+
181+
#[test]
182+
fn test_spark_factorial_scalar() {
183+
let input = ScalarValue::Int32(Some(5));
184+
185+
let args = ColumnarValue::Scalar(input);
186+
let result = spark_factorial(&[args]).unwrap();
187+
let result = match result {
188+
ColumnarValue::Scalar(ScalarValue::Int64(val)) => val,
189+
_ => panic!("Expected scalar"),
190+
};
191+
let actual = result.unwrap();
192+
let expected = 120_i64;
193+
194+
assert_eq!(actual, expected);
195+
}
196+
}

datafusion/spark/src/function/math/mod.rs

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,22 +16,29 @@
1616
// under the License.
1717

1818
pub mod expm1;
19+
pub mod factorial;
1920
pub mod hex;
2021

2122
use datafusion_expr::ScalarUDF;
2223
use datafusion_functions::make_udf_function;
2324
use std::sync::Arc;
2425

2526
make_udf_function!(expm1::SparkExpm1, expm1);
27+
make_udf_function!(factorial::SparkFactorial, factorial);
2628
make_udf_function!(hex::SparkHex, hex);
2729

2830
pub mod expr_fn {
2931
use datafusion_functions::export_functions;
3032

3133
export_functions!((expm1, "Returns exp(expr) - 1 as a Float64.", arg1));
34+
export_functions!((
35+
factorial,
36+
"Returns the factorial of expr. expr is [0..20]. Otherwise, null.",
37+
arg1
38+
));
3239
export_functions!((hex, "Computes hex value of the given column.", arg1));
3340
}
3441

3542
pub fn functions() -> Vec<Arc<ScalarUDF>> {
36-
vec![expm1(), hex()]
43+
vec![expm1(), factorial(), hex()]
3744
}

datafusion/sqllogictest/test_files/spark/math/factorial.slt

Lines changed: 40 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,44 @@
2323

2424
## Original Query: SELECT factorial(5);
2525
## PySpark 3.5.5 Result: {'factorial(5)': 120, 'typeof(factorial(5))': 'bigint', 'typeof(5)': 'int'}
26-
#query
27-
#SELECT factorial(5::int);
26+
query I
27+
SELECT factorial(5::INT);
28+
----
29+
120
2830

31+
query I
32+
SELECT factorial(a)
33+
FROM VALUES
34+
(-1::INT),
35+
(0::INT), (1::INT), (2::INT), (3::INT), (4::INT), (5::INT), (6::INT), (7::INT), (8::INT), (9::INT), (10::INT),
36+
(11::INT), (12::INT), (13::INT), (14::INT), (15::INT), (16::INT), (17::INT), (18::INT), (19::INT), (20::INT),
37+
(21::INT),
38+
(NULL) AS t(a);
39+
----
40+
NULL
41+
1
42+
1
43+
2
44+
6
45+
24
46+
120
47+
720
48+
5040
49+
40320
50+
362880
51+
3628800
52+
39916800
53+
479001600
54+
6227020800
55+
87178291200
56+
1307674368000
57+
20922789888000
58+
355687428096000
59+
6402373705728000
60+
121645100408832000
61+
2432902008176640000
62+
NULL
63+
NULL
64+
65+
query error Error during planning: Failed to coerce arguments to satisfy a call to 'factorial' function
66+
SELECT factorial(5::BIGINT);

0 commit comments

Comments
 (0)