Skip to content
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 18 additions & 4 deletions src/db.rs
Original file line number Diff line number Diff line change
Expand Up @@ -54,10 +54,10 @@ impl<S: Storage> Database<S> {
/// Limit(1)
/// Project(a,b)
let source_plan = binder.bind(&stmts[0])?;
// println!("source_plan plan: {:#?}", source_plan);
//println!("source_plan plan: {:#?}", source_plan);

let best_plan = Self::default_optimizer(source_plan).find_best()?;
// println!("best_plan plan: {:#?}", best_plan);
//println!("best_plan plan: {:#?}", best_plan);

let transaction = RefCell::new(transaction);
let mut stream = build(best_plan, &transaction);
Expand All @@ -78,10 +78,14 @@ impl<S: Storage> Database<S> {
.batch(
"Simplify Filter".to_string(),
HepBatchStrategy::fix_point_topdown(10),
vec![RuleImpl::SimplifyFilter, RuleImpl::ConstantCalculation],
vec![
RuleImpl::LikeRewrite,
RuleImpl::SimplifyFilter,
RuleImpl::ConstantCalculation,
],
)
.batch(
"Predicate Pushdown".to_string(),
"Predicate Pushown".to_string(),
HepBatchStrategy::fix_point_topdown(10),
vec![
RuleImpl::PushPredicateThroughJoin,
Expand Down Expand Up @@ -206,6 +210,12 @@ mod test {
let _ = kipsql
.run("insert into t3 (a, b) values (4, 4444), (5, 5222), (6, 1.00)")
.await?;
let _ = kipsql
.run("create table t4 (a int primary key, b varchar(100))")
.await?;
let _ = kipsql
.run("insert into t4 (a, b) values (1, 'abc'), (2, 'abdc'), (3, 'abcd'), (4, 'ddabc')")
.await?;

println!("show tables:");
let tuples_show_tables = kipsql.run("show tables").await?;
Expand Down Expand Up @@ -371,6 +381,10 @@ mod test {
let tuples_decimal = kipsql.run("select * from t3").await?;
println!("{}", create_table(&tuples_decimal));

println!("like rewrite:");
let tuples_like_rewrite = kipsql.run("select * from t4 where b like 'abc%'").await?;
println!("{}", create_table(&tuples_like_rewrite));

Ok(())
}
}
5 changes: 4 additions & 1 deletion src/optimizer/rule/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@ use crate::optimizer::rule::pushdown_limit::{
};
use crate::optimizer::rule::pushdown_predicates::PushPredicateIntoScan;
use crate::optimizer::rule::pushdown_predicates::PushPredicateThroughJoin;
use crate::optimizer::rule::simplification::ConstantCalculation;
use crate::optimizer::rule::simplification::SimplifyFilter;
use crate::optimizer::rule::simplification::{ConstantCalculation, LikeRewrite};
use crate::optimizer::OptimizerError;

mod column_pruning;
Expand All @@ -37,6 +37,7 @@ pub enum RuleImpl {
// Simplification
SimplifyFilter,
ConstantCalculation,
LikeRewrite,
}

impl Rule for RuleImpl {
Expand All @@ -53,6 +54,7 @@ impl Rule for RuleImpl {
RuleImpl::PushPredicateIntoScan => PushPredicateIntoScan.pattern(),
RuleImpl::SimplifyFilter => SimplifyFilter.pattern(),
RuleImpl::ConstantCalculation => ConstantCalculation.pattern(),
RuleImpl::LikeRewrite => LikeRewrite.pattern(),
}
}

Expand All @@ -69,6 +71,7 @@ impl Rule for RuleImpl {
RuleImpl::SimplifyFilter => SimplifyFilter.apply(node_id, graph),
RuleImpl::PushPredicateIntoScan => PushPredicateIntoScan.apply(node_id, graph),
RuleImpl::ConstantCalculation => ConstantCalculation.apply(node_id, graph),
RuleImpl::LikeRewrite => LikeRewrite.apply(node_id, graph),
}
}
}
Expand Down
117 changes: 109 additions & 8 deletions src/optimizer/rule/simplification.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,19 @@
use crate::expression::{BinaryOperator, ScalarExpression};
use crate::optimizer::core::pattern::{Pattern, PatternChildrenPredicate};
use crate::optimizer::core::rule::Rule;
use crate::optimizer::heuristic::graph::{HepGraph, HepNodeId};
use crate::optimizer::OptimizerError;
use crate::planner::operator::join::JoinCondition;
use crate::planner::operator::Operator;
use crate::types::value::{DataValue, ValueRef};
use lazy_static::lazy_static;
lazy_static! {
static ref LIKE_REWRITE_RULE: Pattern = {
Pattern {
predicate: |op| matches!(op, Operator::Filter(_)),
children: PatternChildrenPredicate::None,
}
};
static ref CONSTANT_CALCULATION_RULE: Pattern = {
Pattern {
predicate: |_| true,
Expand Down Expand Up @@ -109,6 +117,91 @@ impl Rule for SimplifyFilter {
}
}

pub struct LikeRewrite;

impl Rule for LikeRewrite {
fn pattern(&self) -> &Pattern {
&LIKE_REWRITE_RULE
}

fn apply(&self, node_id: HepNodeId, graph: &mut HepGraph) -> Result<(), OptimizerError> {
if let Operator::Filter(mut filter_op) = graph.operator(node_id).clone() {
Copy link
Member

@KKould KKould Dec 25, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Use the operator_mut method to modify directly instead of replace

// if is like expression
if let ScalarExpression::Binary {
op: BinaryOperator::Like,
left_expr,
right_expr,
ty,
} = &mut filter_op.predicate
{
// if left is column and right is constant
if let ScalarExpression::ColumnRef(_) = left_expr.as_ref() {
if let ScalarExpression::Constant(value) = right_expr.as_ref() {
Copy link
Member

@KKould KKould Nov 12, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

reduce unnecessary nesting and matching

if let ScalarExpression::Constant(DataValue::Utf8(mut val)) = right_expr.as_ref() {

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

knock knock

match value.as_ref() {
DataValue::Utf8(val_str) => {
let mut value = val_str.clone().unwrap_or_else(|| "".to_string());

if value.ends_with('%') {
value.pop(); // remove '%'
if let Some(last_char) = value.clone().pop() {
if let Some(next_char) = increment_char(last_char) {
let mut new_value = value.clone();
new_value.pop();
new_value.push(next_char);

let new_expr = ScalarExpression::Binary {
op: BinaryOperator::And,
left_expr: Box::new(ScalarExpression::Binary {
op: BinaryOperator::GtEq,
left_expr: left_expr.clone(),
right_expr: Box::new(
ScalarExpression::Constant(ValueRef::from(
DataValue::Utf8(Some(value)),
)),
),
ty: ty.clone(),
}),
right_expr: Box::new(ScalarExpression::Binary {
op: BinaryOperator::Lt,
left_expr: left_expr.clone(),
right_expr: Box::new(
ScalarExpression::Constant(ValueRef::from(
DataValue::Utf8(Some(new_value)),
)),
),
ty: ty.clone(),
}),
ty: ty.clone(),
};
filter_op.predicate = new_expr;
}
}
}
}
_ => {
graph.version += 1;
return Ok(());
}
}
}
}
}
graph.replace_node(node_id, Operator::Filter(filter_op))
}
// mark changed to skip this rule batch
graph.version += 1;
Ok(())
}
}

fn increment_char(v: char) -> Option<char> {
match v {
'z' => None,
'Z' => None,
_ => std::char::from_u32(v as u32 + 1),
}
}

#[cfg(test)]
mod test {
use crate::binder::test::select_sql_run;
Expand All @@ -118,6 +211,7 @@ mod test {
use crate::expression::{BinaryOperator, ScalarExpression, UnaryOperator};
use crate::optimizer::heuristic::batch::HepBatchStrategy;
use crate::optimizer::heuristic::optimizer::HepOptimizer;
use crate::optimizer::rule::simplification::increment_char;
use crate::optimizer::rule::RuleImpl;
use crate::planner::operator::filter::FilterOperator;
use crate::planner::operator::Operator;
Expand All @@ -127,6 +221,13 @@ mod test {
use std::collections::Bound;
use std::sync::Arc;

#[test]
fn test_increment_char() {
assert_eq!(increment_char('a'), Some('b'));
assert_eq!(increment_char('z'), None);
assert_eq!(increment_char('A'), Some('B'));
}

#[tokio::test]
async fn test_constant_calculation_omitted() -> Result<(), DatabaseError> {
// (2 + (-1)) < -(c1 + 1)
Expand Down Expand Up @@ -343,7 +444,7 @@ mod test {
cb_1_c1,
Some(ConstantBinary::Scope {
min: Bound::Unbounded,
max: Bound::Excluded(Arc::new(DataValue::Int32(Some(-2))))
max: Bound::Excluded(Arc::new(DataValue::Int32(Some(-2)))),
})
);

Expand All @@ -353,7 +454,7 @@ mod test {
cb_1_c2,
Some(ConstantBinary::Scope {
min: Bound::Excluded(Arc::new(DataValue::Int32(Some(2)))),
max: Bound::Unbounded
max: Bound::Unbounded,
})
);

Expand All @@ -363,7 +464,7 @@ mod test {
cb_2_c1,
Some(ConstantBinary::Scope {
min: Bound::Excluded(Arc::new(DataValue::Int32(Some(2)))),
max: Bound::Unbounded
max: Bound::Unbounded,
})
);

Expand All @@ -373,7 +474,7 @@ mod test {
cb_1_c1,
Some(ConstantBinary::Scope {
min: Bound::Unbounded,
max: Bound::Excluded(Arc::new(DataValue::Int32(Some(-2))))
max: Bound::Excluded(Arc::new(DataValue::Int32(Some(-2)))),
})
);

Expand All @@ -383,7 +484,7 @@ mod test {
cb_3_c1,
Some(ConstantBinary::Scope {
min: Bound::Unbounded,
max: Bound::Excluded(Arc::new(DataValue::Int32(Some(-1))))
max: Bound::Excluded(Arc::new(DataValue::Int32(Some(-1)))),
})
);

Expand All @@ -393,7 +494,7 @@ mod test {
cb_3_c2,
Some(ConstantBinary::Scope {
min: Bound::Excluded(Arc::new(DataValue::Int32(Some(0)))),
max: Bound::Unbounded
max: Bound::Unbounded,
})
);

Expand All @@ -403,7 +504,7 @@ mod test {
cb_4_c1,
Some(ConstantBinary::Scope {
min: Bound::Excluded(Arc::new(DataValue::Int32(Some(0)))),
max: Bound::Unbounded
max: Bound::Unbounded,
})
);

Expand All @@ -413,7 +514,7 @@ mod test {
cb_4_c2,
Some(ConstantBinary::Scope {
min: Bound::Unbounded,
max: Bound::Excluded(Arc::new(DataValue::Int32(Some(-1))))
max: Bound::Excluded(Arc::new(DataValue::Int32(Some(-1)))),
})
);

Expand Down