Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions datafusion/functions/src/core/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ pub mod overlay;
pub mod planner;
pub mod r#struct;
pub mod union_extract;
pub mod union_tag;
pub mod version;

// create UDFs
Expand All @@ -52,6 +53,7 @@ make_udf_function!(coalesce::CoalesceFunc, coalesce);
make_udf_function!(greatest::GreatestFunc, greatest);
make_udf_function!(least::LeastFunc, least);
make_udf_function!(union_extract::UnionExtractFun, union_extract);
make_udf_function!(union_tag::UnionTagFunc, union_tag);
make_udf_function!(version::VersionFunc, version);

pub mod expr_fn {
Expand Down Expand Up @@ -101,6 +103,10 @@ pub mod expr_fn {
least,
"Returns `least(args...)`, which evaluates to the smallest value in the list of expressions or NULL if all the expressions are NULL",
args,
),(
union_tag,
"Returns the name of the currently selected field in the union",
arg1
));

#[doc = "Returns the value of the field with the given name from the struct"]
Expand Down Expand Up @@ -136,6 +142,7 @@ pub fn functions() -> Vec<Arc<ScalarUDF>> {
greatest(),
least(),
union_extract(),
union_tag(),
version(),
r#struct(),
]
Expand Down
223 changes: 223 additions & 0 deletions datafusion/functions/src/core/union_tag.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,223 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

use arrow::array::{Array, AsArray, DictionaryArray, Int8Array, StringArray};
use arrow::datatypes::DataType;
use datafusion_common::utils::take_function_args;
use datafusion_common::{exec_datafusion_err, exec_err, Result, ScalarValue};
use datafusion_doc::Documentation;
use datafusion_expr::{ColumnarValue, ScalarFunctionArgs};
use datafusion_expr::{ScalarUDFImpl, Signature, Volatility};
use datafusion_macros::user_doc;
use std::sync::Arc;

/// Scalar UDF backing the SQL function `union_tag`, which returns the name of
/// the currently selected field of a union value.
///
/// The `user_doc` attribute below supplies the user-facing documentation
/// (section, description, syntax and SQL example) returned by `documentation()`.
#[user_doc(
doc_section(label = "Union Functions"),
description = "Returns the name of the currently selected field in the union",
syntax_example = "union_tag(union_expression)",
sql_example = r#"```sql
❯ select union_column, union_tag(union_column) from table_with_union;
+--------------+-------------------------+
| union_column | union_tag(union_column) |
+--------------+-------------------------+
| {a=1} | a |
| {b=3.0} | b |
| {a=4} | a |
| {b=} | b |
| {a=} | a |
+--------------+-------------------------+
```"#,
standard_argument(name = "union", prefix = "Union")
)]
#[derive(Debug)]
pub struct UnionTagFunc {
// Argument signature of the function; initialised in `UnionTagFunc::new`.
signature: Signature,
}

impl Default for UnionTagFunc {
fn default() -> Self {
Self::new()
}
}

impl UnionTagFunc {
    /// Builds the `union_tag` function.
    ///
    /// The signature accepts exactly one argument of any type and is marked
    /// immutable; whether the argument really is a union is checked when the
    /// function is invoked.
    pub fn new() -> Self {
        let signature = Signature::any(1, Volatility::Immutable);
        Self { signature }
    }
}

impl ScalarUDFImpl for UnionTagFunc {
fn as_any(&self) -> &dyn std::any::Any {
self
}

fn name(&self) -> &str {
"union_tag"
}

fn signature(&self) -> &Signature {
&self.signature
}

fn return_type(&self, _: &[DataType]) -> Result<DataType> {
Ok(DataType::Dictionary(
Box::new(DataType::Int8),
Box::new(DataType::Utf8),
))
}

fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
let [union_] = take_function_args("union_tag", args.args)?;

match union_ {
ColumnarValue::Array(array)
if matches!(array.data_type(), DataType::Union(_, _)) =>
{
let union_array = array.as_union();

let keys = Int8Array::try_new(union_array.type_ids().clone(), None)?;

let fields = match union_array.data_type() {
DataType::Union(fields, _) => fields,
_ => unreachable!(),
};

// Union fields type IDs only constraints are being unique and in the 0..128 range:
// They may not start at 0, be sequential, or even contiguous.
// Therefore, we allocate a values vector with a length equal to the highest type ID plus one,
// ensuring that each field's name can be placed at the index corresponding to its type ID.
Copy link
Contributor Author

@gstvg gstvg Feb 17, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The union column used on the sqllogictests contains a single field with type id 3, so this is put to the test

fn register_union_table(ctx: &SessionContext) {
let union = UnionArray::try_new(
UnionFields::new(vec![3], vec![Field::new("int", DataType::Int32, false)]),
ScalarBuffer::from(vec![3, 3]),
None,
vec![Arc::new(Int32Array::from(vec![1, 2]))],
)
.unwrap();
let schema = Schema::new(vec![Field::new(
"union_column",
union.data_type().clone(),
false,
)]);
let batch =
RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(union)]).unwrap();
ctx.register_batch("union_table", batch).unwrap();
}

"union_function.slt" => {
info!("Registering table with union column");
register_union_table(test_ctx.session_ctx())
}

let values_len = fields
.iter()
.map(|(type_id, _)| type_id + 1)
.max()
.unwrap_or_default() as usize;

let mut values = vec![""; values_len];

for (type_id, field) in fields.iter() {
values[type_id as usize] = field.name().as_str()
}

let values = Arc::new(StringArray::from(values));

// SAFETY: union type_ids are validated to not be smaller than zero.
// values len is the union biggest type id plus one.
// keys is built from the union type_ids, which contains only valid type ids
// therefore, `keys[i] >= values.len() || keys[i] < 0` never occurs
let dict = unsafe { DictionaryArray::new_unchecked(keys, values) };

Ok(ColumnarValue::Array(Arc::new(dict)))
}
ColumnarValue::Scalar(ScalarValue::Union(value, fields, _)) => match value {
Some((value_type_id, _)) => fields
.iter()
.find(|(type_id, _)| value_type_id == *type_id)
.map(|(_, field)| {
ColumnarValue::Scalar(ScalarValue::Dictionary(
Box::new(DataType::Int8),
Box::new(field.name().as_str().into()),
))
})
.ok_or_else(|| {
exec_datafusion_err!(
"union_tag: union scalar with unknow type_id {value_type_id}"
)
}),
None => Ok(ColumnarValue::Scalar(ScalarValue::try_new_null(
args.return_type,
)?)),
},
v => exec_err!("union_tag only support unions, got {:?}", v.data_type()),
}
}

fn documentation(&self) -> Option<&Documentation> {
self.doc()
}
}

#[cfg(test)]
mod tests {
    use super::UnionTagFunc;
    use arrow::datatypes::{DataType, Field, UnionFields, UnionMode};
    use datafusion_common::ScalarValue;
    use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl};
    use std::sync::Arc;

    /// The `Dictionary(Int8, Utf8)` type passed as the expected return type.
    fn tag_type() -> DataType {
        DataType::Dictionary(Box::new(DataType::Int8), Box::new(DataType::Utf8))
    }

    /// Invokes `union_tag` on a single scalar argument and unwraps the result.
    fn invoke_on_scalar(scalar: ScalarValue) -> ColumnarValue {
        let return_type = tag_type();

        UnionTagFunc::new()
            .invoke_with_args(ScalarFunctionArgs {
                args: vec![ColumnarValue::Scalar(scalar)],
                number_rows: 1,
                return_type: &return_type,
            })
            .unwrap()
    }

    /// Fails unless `value` is a scalar equal to `expected`.
    fn assert_scalar(value: ColumnarValue, expected: ScalarValue) {
        let scalar = match value {
            ColumnarValue::Scalar(scalar) => scalar,
            ColumnarValue::Array(array) => panic!("expected scalar got {array:?}"),
        };

        assert_eq!(scalar, expected);
    }

    // Once union scalars can be constructed in SQL, these cases belong in the
    // sqllogictests instead.
    #[test]
    fn union_scalar() {
        let fields: UnionFields =
            std::iter::once((0, Arc::new(Field::new("a", DataType::UInt32, false))))
                .collect();

        let selected = Some((0, Box::new(ScalarValue::UInt32(Some(0)))));
        let scalar = ScalarValue::Union(selected, fields, UnionMode::Dense);

        let expected = ScalarValue::Dictionary(
            Box::new(DataType::Int8),
            Box::new(ScalarValue::Utf8(Some("a".to_string()))),
        );

        assert_scalar(invoke_on_scalar(scalar), expected);
    }

    #[test]
    fn union_scalar_empty() {
        let scalar = ScalarValue::Union(None, UnionFields::empty(), UnionMode::Dense);

        let expected = ScalarValue::Dictionary(
            Box::new(DataType::Int8),
            Box::new(ScalarValue::Utf8(None)),
        );

        assert_scalar(invoke_on_scalar(scalar), expected);
    }
}
15 changes: 15 additions & 0 deletions datafusion/sqllogictest/test_files/union_function.slt
Original file line number Diff line number Diff line change
Expand Up @@ -45,3 +45,18 @@ select union_extract(union_column, 1) from union_table;

query error DataFusion error: Error during planning: The function 'union_extract' expected 2 arguments but received 3
select union_extract(union_column, 'a', 'b') from union_table;

query ?T
select union_column, union_tag(union_column) from union_table;
----
{int=1} int
{int=2} int

query error DataFusion error: Error during planning: 'union_tag' does not support zero arguments
select union_tag() from union_table;

query error DataFusion error: Error during planning: The function 'union_tag' expected 1 arguments but received 2
select union_tag(union_column, 'int') from union_table;

query error DataFusion error: Execution error: union_tag only support unions, got Utf8
select union_tag('int') from union_table;
28 changes: 28 additions & 0 deletions docs/source/user-guide/sql/scalar_functions.md
Original file line number Diff line number Diff line change
Expand Up @@ -4404,6 +4404,7 @@ sha512(expression)
Functions to work with the union data type, also known as tagged unions, variant types, enums or sum types. Note: Not related to the SQL UNION operator

- [union_extract](#union_extract)
- [union_tag](#union_tag)

### `union_extract`

Expand Down Expand Up @@ -4433,6 +4434,33 @@ union_extract(union, field_name)
+--------------+----------------------------------+----------------------------------+
```

### `union_tag`

Returns the name of the currently selected field in the union

```sql
union_tag(union_expression)
```

#### Arguments

- **union**: Union expression to operate on. Can be a constant, column, or function, and any combination of operators.

#### Example

```sql
❯ select union_column, union_tag(union_column) from table_with_union;
+--------------+-------------------------+
| union_column | union_tag(union_column) |
+--------------+-------------------------+
| {a=1} | a |
| {b=3.0} | b |
| {a=4} | a |
| {b=} | b |
| {a=} | a |
+--------------+-------------------------+
```

## Other Functions

- [arrow_cast](#arrow_cast)
Expand Down