Skip to content
Merged
181 changes: 163 additions & 18 deletions datafusion/physical-expr/src/utils/guarantee.rs
Original file line number Diff line number Diff line change
Expand Up @@ -184,16 +184,6 @@ impl LiteralGuarantee {
.filter_map(|expr| ColOpLit::try_new(expr))
.collect::<Vec<_>>();

if terms.is_empty() {
return builder;
}

// if not all terms are of the form (col <op> literal),
// can't infer any guarantees
if terms.len() != disjunctions.len() {
return builder;
}

// if all terms are 'col <op> literal' with the same column
// and operation we can infer any guarantees
//
Expand All @@ -203,18 +193,68 @@ impl LiteralGuarantee {
// foo is required for the expression to be true.
// So we can only create a multi value guarantee for `=`
// (or a single value). (e.g. ignore `a != foo OR a != bar`)
let first_term = &terms[0];
if terms.iter().all(|term| {
term.col.name() == first_term.col.name()
&& term.guarantee == Guarantee::In
}) {
let first_term = terms.first();
if !terms.is_empty()
&& terms.len() == disjunctions.len()
&& terms.iter().all(|term| {
term.col.name() == first_term.unwrap().col.name()
&& term.guarantee == Guarantee::In
})
{
builder.aggregate_multi_conjunct(
first_term.col,
first_term.unwrap().col,
Guarantee::In,
terms.iter().map(|term| term.lit.value()),
)
} else {
// can't infer anything
// Handle disjunctions with conjunctions like (a = 1 AND b = 2) OR (a = 2 AND b = 3)
// Extract termsets from each disjunction
// if in each termset, they have same column, and the guarantee is In,
// we can infer a guarantee for the column
// e.g. (a = 1 AND b = 2) OR (a = 2 AND b = 3) is `a IN (1, 2) AND b IN (2, 3)`
// otherwise, we can't infer a guarantee
let termsets: Vec<Vec<ColOpLit>> = disjunctions
.iter()
.map(|expr| {
split_conjunction(expr)
.into_iter()
.filter_map(ColOpLit::try_new)
.filter(|term| term.guarantee == Guarantee::In)
.collect()
})
.collect();

// Early return if any termset is empty (can't infer guarantees)
if termsets.iter().any(|terms| terms.is_empty()) {
return builder;
}

// Find columns that appear in all termsets
let common_cols = find_common_columns(&termsets);
if common_cols.is_empty() {
return builder;
}

// Build guarantees for common columns
let mut builder = builder;
for col in common_cols {
let literals: Vec<_> = termsets
.iter()
.filter_map(|terms| {
terms
.iter()
.find(|term| term.col == col)
.map(|term| term.lit.value())
})
.collect();

builder = builder.aggregate_multi_conjunct(
col,
Guarantee::In,
literals.into_iter(),
);
}

builder
}
}
Expand Down Expand Up @@ -410,6 +450,36 @@ impl<'a> ColOpLit<'a> {
}
}

/// Find the columns that appear in every termset.
///
/// Each termset holds the `col <op> literal` terms extracted from one branch
/// of a disjunction. The returned columns are listed in the order in which
/// they occur in the first termset, so the result is deterministic and does
/// not depend on hash iteration order.
///
/// Returns an empty `Vec` when:
/// * `termsets` is empty,
/// * no column is shared by all termsets, or
/// * any termset mentions the same column more than once,
///   e.g. `(a = 1 AND a = 2) OR (a = 2 AND b = 3)`.
///
/// TODO: for the duplicate-column case above we could still infer
/// `a IN (2) AND b IN (3)`.
fn find_common_columns<'a>(
    termsets: &[Vec<ColOpLit<'a>>],
) -> Vec<&'a crate::expressions::Column> {
    let Some((first, rest)) = termsets.split_first() else {
        return Vec::new();
    };

    // A column repeated within the first termset means no guarantee is inferred.
    let first_cols: HashSet<_> = first.iter().map(|term| term.col).collect();
    if first_cols.len() != first.len() {
        return Vec::new();
    }

    // Collect the distinct columns of every remaining termset, applying the
    // same "no repeated column" rule to each of them.
    let mut rest_cols = Vec::with_capacity(rest.len());
    for termset in rest {
        let cols: HashSet<_> = termset.iter().map(|term| term.col).collect();
        if cols.len() != termset.len() {
            return Vec::new();
        }
        rest_cols.push(cols);
    }

    // Keep the columns of the first termset that are present in all others.
    // The first termset has no duplicates (checked above), so the output has
    // none either, and it preserves the first termset's ordering.
    first
        .iter()
        .map(|term| term.col)
        .filter(|col| rest_cols.iter().all(|cols| cols.contains(col)))
        .collect()
}

#[cfg(test)]
mod test {
use std::sync::LazyLock;
Expand Down Expand Up @@ -824,13 +894,87 @@ mod test {
);
}

#[test]
fn test_disjunction_and_conjunction_multi_column() {
    // Every branch constrains both `a` and `b` with `=`:
    // (a = "foo" AND b = 1) OR (a = "bar" AND b = 2)
    let expr = (col("a").eq(lit("foo")).and(col("b").eq(lit(1))))
        .or(col("a").eq(lit("bar")).and(col("b").eq(lit(2))));
    let expected = vec![in_guarantee("a", ["foo", "bar"]), in_guarantee("b", [1, 2])];
    test_analyze(expr, expected);

    // Only `b` shows up in all three branches:
    // (a = "foo" AND b = 1) OR (a = "bar" AND b = 2) OR (b = 3)
    let expr = (col("a").eq(lit("foo")).and(col("b").eq(lit(1))))
        .or(col("a").eq(lit("bar")).and(col("b").eq(lit(2))))
        .or(col("b").eq(lit(3)));
    let expected = vec![in_guarantee("b", [1, 2, 3])];
    test_analyze(expr, expected);

    // The last branch shares no column with the others:
    // (a = "foo" AND b = 1) OR (a = "bar" AND b = 2) OR (c = 3)
    let expr = (col("a").eq(lit("foo")).and(col("b").eq(lit(1))))
        .or(col("a").eq(lit("bar")).and(col("b").eq(lit(2))))
        .or(col("c").eq(lit(3)));
    let expected = vec![];
    test_analyze(expr, expected);

    // `a` is compared with `!=` in the second branch, so only `b` is reported:
    // (a = "foo" AND b = 1) OR (a != "bar" AND b = 2)
    let expr = (col("a").eq(lit("foo")).and(col("b").eq(lit(1))))
        .or(col("a").not_eq(lit("bar")).and(col("b").eq(lit(2))));
    let expected = vec![in_guarantee("b", [1, 2])];
    test_analyze(expr, expected);

    // `b` is compared with `>` in the first branch, so only `a` is reported:
    // (a = "foo" AND b > 1) OR (a = "bar" AND b = 2)
    let expr = (col("a").eq(lit("foo")).and(col("b").gt(lit(1))))
        .or(col("a").eq(lit("bar")).and(col("b").eq(lit(2))));
    let expected = vec![in_guarantee("a", ["foo", "bar"])];
    test_analyze(expr, expected);

    // Each column is missing from exactly one branch:
    // (a = "foo" AND b = 1) OR (b = 1 AND c = 2) OR (c = 3 AND a = "bar")
    let expr = (col("a").eq(lit("foo")).and(col("b").eq(lit(1))))
        .or(col("b").eq(lit(1)).and(col("c").eq(lit(2))))
        .or(col("c").eq(lit(3)).and(col("a").eq(lit("bar"))));
    let expected = vec![];
    test_analyze(expr, expected);

    // `a` appears twice in the first branch:
    // (a = "foo" AND a = "bar") OR (a = "good" AND b = 1)
    // TODO: this should be `a IN ("good") AND b IN (1)`
    let expr = (col("a").eq(lit("foo")).and(col("a").eq(lit("bar"))))
        .or(col("a").eq(lit("good")).and(col("b").eq(lit(1))));
    let expected = vec![];
    test_analyze(expr, expected);

    // `a` appears twice in the first branch, both times with the same literal:
    // (a = "foo" AND a = "foo") OR (a = "good" AND b = 1)
    // TODO: this should be `a IN ("foo", "good")`
    let expr = (col("a").eq(lit("foo")).and(col("a").eq(lit("foo"))))
        .or(col("a").eq(lit("good")).and(col("b").eq(lit(1))));
    let expected = vec![];
    test_analyze(expr, expected);

    // `b` appears twice in the middle branch:
    // (a = "foo" AND b = 3) OR (b = 4 AND b = 1) OR (b = 2 AND a = "bar")
    let expr = (col("a").eq(lit("foo")).and(col("b").eq(lit(3))))
        .or(col("b").eq(lit(4)).and(col("b").eq(lit(1))))
        .or(col("b").eq(lit(2)).and(col("a").eq(lit("bar"))));
    let expected = vec![];
    test_analyze(expr, expected);

    // (b = 1 AND b > 3) OR (a = "foo" AND b = 4)
    // The first branch can never be true, but the guarantee still holds:
    // if b is neither 1 nor 4 the whole expression cannot be true.
    let expr = (col("b").eq(lit(1)).and(col("b").gt(lit(3))))
        .or(col("a").eq(lit("foo")).and(col("b").eq(lit(4))));
    let expected = vec![in_guarantee("b", [1, 4])];
    test_analyze(expr, expected);
}

/// Tests that analyzing expr results in the expected guarantees
fn test_analyze(expr: Expr, expected: Vec<LiteralGuarantee>) {
println!("Begin analyze of {expr}");
let schema = schema();
let physical_expr = logical2physical(&expr, &schema);

let actual = LiteralGuarantee::analyze(&physical_expr);
let actual = LiteralGuarantee::analyze(&physical_expr)
.into_iter()
.sorted_by_key(|g| g.column.name().to_string())
.collect::<Vec<_>>();
assert_eq!(
expected, actual,
"expr: {expr}\
Expand Down Expand Up @@ -867,6 +1011,7 @@ mod test {
Arc::new(Schema::new(vec![
Field::new("a", DataType::Utf8, false),
Field::new("b", DataType::Int32, false),
Field::new("c", DataType::Int32, false),
]))
});
Arc::clone(&SCHEMA)
Expand Down
Loading