Skip to content

Commit c188ddc

Browse files
committed
clean & doc
Signed-off-by: Ruihang Xia <[email protected]>
1 parent 765809a commit c188ddc

File tree

1 file changed

+89
-92
lines changed

1 file changed

+89
-92
lines changed

datafusion/src/optimizer/common_subexpr_eliminate.rs

Lines changed: 89 additions & 92 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,26 @@ use crate::logical_plan::{
3030
};
3131
use crate::optimizer::optimizer::OptimizerRule;
3232

33+
/// A map from an expression's identifier to a tuple including
34+
/// - the expression itself (cloned)
35+
/// - counter
36+
/// - DataType of this expression.
37+
type ExprSet = HashMap<Identifier, (Expr, usize, DataType)>;
38+
39+
/// Identifier type. The current implementation uses the description of an expression (type String) as
40+
/// Identifier.
41+
///
42+
/// An Identifier should (ideally) be able to "hash", "accumulate", "equal" and "have no
43+
/// collision (as low as possible)"
44+
///
45+
/// Since an identifier is likely to be copied many times, it is better that an identifier
46+
/// is small or "copy"; otherwise some kind of reference counting is needed. A String description
47+
/// here is not such a good choice.
48+
type Identifier = String;
49+
/// Perform Common Sub-expression Elimination optimization.
50+
///
51+
/// Currently only common sub-expressions within one logical plan will
52+
/// be eliminated.
3353
pub struct CommonSubexprEliminate {}
3454

3555
impl OptimizerRule for CommonSubexprEliminate {
@@ -50,26 +70,19 @@ impl CommonSubexprEliminate {}
5070

5171
fn optimize(plan: &LogicalPlan, execution_props: &ExecutionProps) -> Result<LogicalPlan> {
5272
let mut expr_set = ExprSet::new();
53-
let mut addr_map = ExprAddrToId::new();
5473
let mut affected_id = HashSet::new();
5574

5675
match plan {
5776
LogicalPlan::Projection {
5877
expr,
5978
input,
60-
schema,
79+
schema: _,
6180
} => {
6281
let mut arrays = vec![];
6382
for e in expr {
6483
let data_type = e.get_type(input.schema())?;
6584
let mut id_array = vec![];
66-
expr_to_identifier(
67-
e,
68-
&mut expr_set,
69-
&mut addr_map,
70-
&mut id_array,
71-
data_type,
72-
)?;
85+
expr_to_identifier(e, &mut expr_set, &mut id_array, data_type)?;
7386
arrays.push(id_array);
7487
}
7588

@@ -78,32 +91,20 @@ fn optimize(plan: &LogicalPlan, execution_props: &ExecutionProps) -> Result<Logi
7891
LogicalPlan::Filter { predicate, input } => {
7992
let data_type = predicate.get_type(input.schema())?;
8093
let mut id_array = vec![];
81-
expr_to_identifier(
82-
predicate,
83-
&mut expr_set,
84-
&mut addr_map,
85-
&mut id_array,
86-
data_type,
87-
)?;
94+
expr_to_identifier(predicate, &mut expr_set, &mut id_array, data_type)?;
8895

8996
return optimize(input, execution_props);
9097
}
9198
LogicalPlan::Window {
9299
input,
93100
window_expr,
94-
schema,
101+
schema: _,
95102
} => {
96103
let mut arrays = vec![];
97104
for e in window_expr {
98105
let data_type = e.get_type(input.schema())?;
99106
let mut id_array = vec![];
100-
expr_to_identifier(
101-
e,
102-
&mut expr_set,
103-
&mut addr_map,
104-
&mut id_array,
105-
data_type,
106-
)?;
107+
expr_to_identifier(e, &mut expr_set, &mut id_array, data_type)?;
107108
arrays.push(id_array);
108109
}
109110

@@ -120,26 +121,14 @@ fn optimize(plan: &LogicalPlan, execution_props: &ExecutionProps) -> Result<Logi
120121
for e in group_expr {
121122
let data_type = e.get_type(input.schema())?;
122123
let mut id_array = vec![];
123-
expr_to_identifier(
124-
e,
125-
&mut expr_set,
126-
&mut addr_map,
127-
&mut id_array,
128-
data_type,
129-
)?;
124+
expr_to_identifier(e, &mut expr_set, &mut id_array, data_type)?;
130125
group_arrays.push(id_array);
131126
}
132127
let mut aggr_arrays = vec![];
133128
for e in aggr_expr {
134129
let data_type = e.get_type(input.schema())?;
135130
let mut id_array = vec![];
136-
expr_to_identifier(
137-
e,
138-
&mut expr_set,
139-
&mut addr_map,
140-
&mut id_array,
141-
data_type,
142-
)?;
131+
expr_to_identifier(e, &mut expr_set, &mut id_array, data_type)?;
143132
aggr_arrays.push(id_array);
144133
}
145134

@@ -163,8 +152,7 @@ fn optimize(plan: &LogicalPlan, execution_props: &ExecutionProps) -> Result<Logi
163152

164153
let mut new_input = optimize(input, execution_props)?;
165154
if !affected_id.is_empty() {
166-
new_input =
167-
build_project_plan(new_input, affected_id, &expr_set, schema)?;
155+
new_input = build_project_plan(new_input, affected_id, &expr_set)?;
168156
}
169157

170158
return Ok(LogicalPlan::Aggregate {
@@ -174,7 +162,7 @@ fn optimize(plan: &LogicalPlan, execution_props: &ExecutionProps) -> Result<Logi
174162
schema: schema.clone(),
175163
});
176164
}
177-
LogicalPlan::Sort { expr, input } => {}
165+
LogicalPlan::Sort { expr: _, input: _ } => {}
178166
LogicalPlan::Join { .. }
179167
| LogicalPlan::CrossJoin { .. }
180168
| LogicalPlan::Repartition { .. }
@@ -193,13 +181,12 @@ fn build_project_plan(
193181
input: LogicalPlan,
194182
affected_id: HashSet<Identifier>,
195183
expr_set: &ExprSet,
196-
schema: &DFSchema,
197184
) -> Result<LogicalPlan> {
198185
let mut project_exprs = vec![];
199186
let mut fields = vec![];
200187

201188
for id in affected_id {
202-
let (expr, _, _, _, data_type) = expr_set.get(&id).unwrap();
189+
let (expr, _, data_type) = expr_set.get(&id).unwrap();
203190
// todo: check `nullable`
204191
fields.push(DFField::new(None, &id, data_type.clone(), true));
205192
project_exprs.push(expr.clone());
@@ -212,50 +199,44 @@ fn build_project_plan(
212199
})
213200
}
214201

215-
// Helper struct & func
216-
217-
/// A map from expression's identifier to tuple including
218-
/// - the expression itself (cloned)
219-
/// - a hash set contains all addresses with the same identifier.
220-
/// - counter
221-
/// - A alternative plan.
222-
pub type ExprSet =
223-
HashMap<Identifier, (Expr, HashSet<*const Expr>, usize, Option<()>, DataType)>;
224-
225-
pub type ExprAddrToId = HashMap<*const Expr, Identifier>;
226-
227-
/// Identifier type. Current implementation use descript of a expression (type String) as
228-
/// Identifier.
202+
/// Go through an expression tree and generate identifier.
229203
///
230-
/// A Identifier should (ideally) be "hashable", "accumulatable", "Eq", "collisionless"
231-
/// (as low probability as possible).
204+
/// An identifier contains information of the expression itself and its sub-expression.
205+
/// This visitor implementation uses a stack `visit_stack` to track traversal, which
206+
/// lets us know when a sub-tree's visiting is finished. When `pre_visit` is called
207+
/// (traversing to a new node), an `EnterMark` and an `ExprItem` will be pushed into stack.
208+
/// On leaving a node (`post_visit()`), it tries to pop an `EnterMark`. All `ExprItem`s
209+
/// before the first `EnterMark` are considered to be the sub-tree of the leaving node.
210+
///
211+
/// This visitor also records identifiers in `id_array`, so that the following traversal
212+
/// pass can get the identifier of a node without recalculating it. We assign each node
213+
/// in the expr tree a series number, starting from 1, maintained by `series_number`.
214+
/// The series number represents the order in which we leave (`post_visit`) a node. It has the
215+
/// property that a child node's series number is always smaller than its parent's, while
216+
/// `id_array` is organized in the order we enter (`pre_visit`) a node. `node_count` helps us
217+
/// get the index of `id_array` for each node.
232218
///
233-
/// Since a identifier is likely to be copied many times, it is better that a identifier
234-
/// is small or "copy". otherwise some kinds of reference count is needed.
235-
pub type Identifier = String;
236-
237-
/// Go through an expression tree and generate identifier.
238219
/// `Expr` without sub-expr (column, literal etc.) will not have identifier
239220
/// because they should not be recognized as common sub-expr.
240221
struct ExprIdentifierVisitor<'a> {
241222
// param
242223
expr_set: &'a mut ExprSet,
243-
addr_map: &'a mut ExprAddrToId,
244224
/// series number (usize) and identifier.
245225
id_array: &'a mut Vec<(usize, Identifier)>,
246226
data_type: DataType,
247227

248228
// inner states
249-
visit_stack: Vec<Item>,
229+
visit_stack: Vec<VisitRecord>,
250230
/// increased in pre_visit, start from 0.
251231
node_count: usize,
252232
/// increased in post_visit, start from 1.
253-
post_visit_number: usize,
233+
series_number: usize,
254234
}
255235

256-
enum Item {
236+
/// Record item used when traversing an expression tree.
237+
enum VisitRecord {
257238
/// `usize` is the monotone increasing series number assigned in pre_visit().
258-
/// Starts from 0. Is used to index ithe dentifier array `id_array` in post_visit().
239+
/// Starts from 0. Is used to index the identifier array `id_array` in post_visit().
259240
EnterMark(usize),
260241
/// Accumulated identifier of sub expression.
261242
ExprItem(Identifier),
@@ -311,35 +292,38 @@ impl ExprIdentifierVisitor<'_> {
311292
desc
312293
}
313294

295+
/// Find the first `EnterMark` in the stack, and accumulate every `ExprItem`
296+
/// before it.
314297
fn pop_enter_mark(&mut self) -> (usize, Identifier) {
315298
let mut desc = String::new();
316299

317300
while let Some(item) = self.visit_stack.pop() {
318301
match item {
319-
Item::EnterMark(idx) => {
302+
VisitRecord::EnterMark(idx) => {
320303
return (idx, desc);
321304
}
322-
Item::ExprItem(s) => {
305+
VisitRecord::ExprItem(s) => {
323306
desc.push_str(&s);
324307
}
325308
}
326309
}
327310

328-
(0, desc)
311+
unreachable!("Enter mark should paired with node number");
329312
}
330313
}
331314

332315
impl ExpressionVisitor for ExprIdentifierVisitor<'_> {
333316
fn pre_visit(mut self, _expr: &Expr) -> Result<Recursion<Self>> {
334-
self.visit_stack.push(Item::EnterMark(self.node_count));
317+
self.visit_stack
318+
.push(VisitRecord::EnterMark(self.node_count));
335319
self.node_count += 1;
336320
// put placeholder
337321
self.id_array.push((0, "".to_string()));
338322
Ok(Recursion::Continue(self))
339323
}
340324

341325
fn post_visit(mut self, expr: &Expr) -> Result<Self> {
342-
self.post_visit_number += 1;
326+
self.series_number += 1;
343327

344328
let (idx, sub_expr_desc) = self.pop_enter_mark();
345329
// skip exprs that should not be recognized.
@@ -350,58 +334,61 @@ impl ExpressionVisitor for ExprIdentifierVisitor<'_> {
350334
| Expr::ScalarVariable(..)
351335
| Expr::Wildcard
352336
) {
353-
self.id_array[idx].0 = self.post_visit_number;
337+
self.id_array[idx].0 = self.series_number;
354338
let desc = Self::desc_expr(expr);
355-
self.visit_stack.push(Item::ExprItem(desc));
339+
self.visit_stack.push(VisitRecord::ExprItem(desc));
356340
return Ok(self);
357341
}
358342
let mut desc = Self::desc_expr(expr);
359343
desc.push_str(&sub_expr_desc);
360344

361-
self.id_array[idx] = (self.post_visit_number, desc.clone());
362-
self.visit_stack.push(Item::ExprItem(desc.clone()));
345+
self.id_array[idx] = (self.series_number, desc.clone());
346+
self.visit_stack.push(VisitRecord::ExprItem(desc.clone()));
363347
let data_type = self.data_type.clone();
364348
self.expr_set
365349
.entry(desc.clone())
366-
.or_insert_with(|| (expr.clone(), HashSet::new(), 0, None, data_type))
367-
.2 += 1;
368-
self.addr_map.insert(expr as *const Expr, desc);
350+
.or_insert_with(|| (expr.clone(), 0, data_type))
351+
.1 += 1;
369352
Ok(self)
370353
}
371354
}
372355

373-
/// Go through an expression tree and generate identity.
356+
/// Go through an expression tree and generate identifier for every node in this tree.
374357
fn expr_to_identifier(
375358
expr: &Expr,
376359
expr_set: &mut ExprSet,
377-
addr_map: &mut ExprAddrToId,
378360
id_array: &mut Vec<(usize, Identifier)>,
379361
data_type: DataType,
380362
) -> Result<()> {
381363
expr.accept(ExprIdentifierVisitor {
382364
expr_set,
383-
addr_map,
384365
id_array,
385366
data_type,
386367
visit_stack: vec![],
387368
node_count: 0,
388-
post_visit_number: 0,
369+
series_number: 0,
389370
})?;
390371

391372
Ok(())
392373
}
393374

375+
/// Rewrite expression by replacing detected common sub-expression with
376+
/// the corresponding temporary column name. That column contains the
377+
/// evaluated result of the replaced expression.
394378
struct CommonSubexprRewriter<'a> {
395379
expr_set: &'a mut ExprSet,
396380
id_array: &'a Vec<(usize, Identifier)>,
381+
/// Which identifier is replaced.
397382
affected_id: &'a mut HashSet<Identifier>,
398383

384+
/// The max series number we have rewritten. Other expression nodes
385+
/// with a smaller series number are already replaced, and we shouldn't
386+
/// do anything with them.
399387
max_series_number: usize,
388+
/// Index of the current node's information in `id_array`.
400389
curr_index: usize,
401390
}
402391

403-
impl CommonSubexprRewriter<'_> {}
404-
405392
impl ExprRewriter for CommonSubexprRewriter<'_> {
406393
fn pre_visit(&mut self, _: &Expr) -> Result<RewriteRecursion> {
407394
if self.curr_index >= self.id_array.len()
@@ -416,8 +403,7 @@ impl ExprRewriter for CommonSubexprRewriter<'_> {
416403
self.curr_index += 1;
417404
return Ok(RewriteRecursion::Continue);
418405
}
419-
let (stored_expr, somewhat_set, counter, another_thing, _) =
420-
self.expr_set.get(curr_id).unwrap();
406+
let (_, counter, _) = self.expr_set.get(curr_id).unwrap();
421407
if *counter > 1 {
422408
self.affected_id.insert(curr_id.clone());
423409
Ok(RewriteRecursion::Mutate)
@@ -471,8 +457,16 @@ mod test {
471457
use crate::logical_plan::{binary_expr, col, lit, sum, LogicalPlanBuilder, Operator};
472458
use crate::test::*;
473459

460+
fn assert_optimized_plan_eq(plan: &LogicalPlan, expected: &str) {
461+
let optimizer = CommonSubexprEliminate {};
462+
let optimized_plan = optimizer
463+
.optimize(plan, &ExecutionProps::new())
464+
.expect("failed to optimize plan");
465+
let formatted_plan = format!("{:?}", optimized_plan);
466+
assert_eq!(formatted_plan, expected);
467+
}
468+
474469
#[test]
475-
#[ignore]
476470
fn dev_driver_tpch_q1_simplified() -> Result<()> {
477471
// SQL:
478472
// select
@@ -504,8 +498,11 @@ mod test {
504498
)?
505499
.build()?;
506500

507-
let optimizer = CommonSubexprEliminate {};
508-
let new_plan = optimizer.optimize(&plan, &ExecutionProps::new()).unwrap();
501+
let expected = "Aggregate: groupBy=[[]], aggr=[[SUM(#BinaryExpr-*BinaryExpr--Column-test.bLiteral1Column-test.a), SUM(#BinaryExpr-*BinaryExpr--Column-test.bLiteral1Column-test.a Multiply Int32(1) Plus #test.c)]]\
502+
\n Projection: #test.a Multiply Int32(1) Minus #test.b\
503+
\n TableScan: test projection=None";
504+
505+
assert_optimized_plan_eq(&plan, expected);
509506

510507
Ok(())
511508
}

0 commit comments

Comments
 (0)