@@ -30,6 +30,26 @@ use crate::logical_plan::{
3030} ;
3131use crate :: optimizer:: optimizer:: OptimizerRule ;
3232
33+ /// A map from expression's identifier to tuple including
34+ /// - the expression itself (cloned)
35+ /// - counter
36+ /// - DataType of this expression.
37+ type ExprSet = HashMap < Identifier , ( Expr , usize , DataType ) > ;
38+
39+ /// Identifier type. Current implementation use describe of a expression (type String) as
40+ /// Identifier.
41+ ///
42+ /// A Identifier should (ideally) be able to "hash", "accumulate", "equal" and "have no
43+ /// collision (as low as possible)"
44+ ///
45+ /// Since a identifier is likely to be copied many times, it is better that a identifier
46+ /// is small or "copy". otherwise some kinds of reference count is needed. String description
47+ /// here is not such a good choose.
48+ type Identifier = String ;
49+ /// Perform Common Sub-expression Elimination optimization.
50+ ///
51+ /// Currently only common sub-expressions within one logical plan will
52+ /// be eliminated.
3353pub struct CommonSubexprEliminate { }
3454
3555impl OptimizerRule for CommonSubexprEliminate {
@@ -50,26 +70,19 @@ impl CommonSubexprEliminate {}
5070
5171fn optimize ( plan : & LogicalPlan , execution_props : & ExecutionProps ) -> Result < LogicalPlan > {
5272 let mut expr_set = ExprSet :: new ( ) ;
53- let mut addr_map = ExprAddrToId :: new ( ) ;
5473 let mut affected_id = HashSet :: new ( ) ;
5574
5675 match plan {
5776 LogicalPlan :: Projection {
5877 expr,
5978 input,
60- schema,
79+ schema : _ ,
6180 } => {
6281 let mut arrays = vec ! [ ] ;
6382 for e in expr {
6483 let data_type = e. get_type ( input. schema ( ) ) ?;
6584 let mut id_array = vec ! [ ] ;
66- expr_to_identifier (
67- e,
68- & mut expr_set,
69- & mut addr_map,
70- & mut id_array,
71- data_type,
72- ) ?;
85+ expr_to_identifier ( e, & mut expr_set, & mut id_array, data_type) ?;
7386 arrays. push ( id_array) ;
7487 }
7588
@@ -78,32 +91,20 @@ fn optimize(plan: &LogicalPlan, execution_props: &ExecutionProps) -> Result<Logi
7891 LogicalPlan :: Filter { predicate, input } => {
7992 let data_type = predicate. get_type ( input. schema ( ) ) ?;
8093 let mut id_array = vec ! [ ] ;
81- expr_to_identifier (
82- predicate,
83- & mut expr_set,
84- & mut addr_map,
85- & mut id_array,
86- data_type,
87- ) ?;
94+ expr_to_identifier ( predicate, & mut expr_set, & mut id_array, data_type) ?;
8895
8996 return optimize ( input, execution_props) ;
9097 }
9198 LogicalPlan :: Window {
9299 input,
93100 window_expr,
94- schema,
101+ schema : _ ,
95102 } => {
96103 let mut arrays = vec ! [ ] ;
97104 for e in window_expr {
98105 let data_type = e. get_type ( input. schema ( ) ) ?;
99106 let mut id_array = vec ! [ ] ;
100- expr_to_identifier (
101- e,
102- & mut expr_set,
103- & mut addr_map,
104- & mut id_array,
105- data_type,
106- ) ?;
107+ expr_to_identifier ( e, & mut expr_set, & mut id_array, data_type) ?;
107108 arrays. push ( id_array) ;
108109 }
109110
@@ -120,26 +121,14 @@ fn optimize(plan: &LogicalPlan, execution_props: &ExecutionProps) -> Result<Logi
120121 for e in group_expr {
121122 let data_type = e. get_type ( input. schema ( ) ) ?;
122123 let mut id_array = vec ! [ ] ;
123- expr_to_identifier (
124- e,
125- & mut expr_set,
126- & mut addr_map,
127- & mut id_array,
128- data_type,
129- ) ?;
124+ expr_to_identifier ( e, & mut expr_set, & mut id_array, data_type) ?;
130125 group_arrays. push ( id_array) ;
131126 }
132127 let mut aggr_arrays = vec ! [ ] ;
133128 for e in aggr_expr {
134129 let data_type = e. get_type ( input. schema ( ) ) ?;
135130 let mut id_array = vec ! [ ] ;
136- expr_to_identifier (
137- e,
138- & mut expr_set,
139- & mut addr_map,
140- & mut id_array,
141- data_type,
142- ) ?;
131+ expr_to_identifier ( e, & mut expr_set, & mut id_array, data_type) ?;
143132 aggr_arrays. push ( id_array) ;
144133 }
145134
@@ -163,8 +152,7 @@ fn optimize(plan: &LogicalPlan, execution_props: &ExecutionProps) -> Result<Logi
163152
164153 let mut new_input = optimize ( input, execution_props) ?;
165154 if !affected_id. is_empty ( ) {
166- new_input =
167- build_project_plan ( new_input, affected_id, & expr_set, schema) ?;
155+ new_input = build_project_plan ( new_input, affected_id, & expr_set) ?;
168156 }
169157
170158 return Ok ( LogicalPlan :: Aggregate {
@@ -174,7 +162,7 @@ fn optimize(plan: &LogicalPlan, execution_props: &ExecutionProps) -> Result<Logi
174162 schema : schema. clone ( ) ,
175163 } ) ;
176164 }
177- LogicalPlan :: Sort { expr, input } => { }
165+ LogicalPlan :: Sort { expr : _ , input : _ } => { }
178166 LogicalPlan :: Join { .. }
179167 | LogicalPlan :: CrossJoin { .. }
180168 | LogicalPlan :: Repartition { .. }
@@ -193,13 +181,12 @@ fn build_project_plan(
193181 input : LogicalPlan ,
194182 affected_id : HashSet < Identifier > ,
195183 expr_set : & ExprSet ,
196- schema : & DFSchema ,
197184) -> Result < LogicalPlan > {
198185 let mut project_exprs = vec ! [ ] ;
199186 let mut fields = vec ! [ ] ;
200187
201188 for id in affected_id {
202- let ( expr, _, _ , _ , data_type) = expr_set. get ( & id) . unwrap ( ) ;
189+ let ( expr, _, data_type) = expr_set. get ( & id) . unwrap ( ) ;
203190 // todo: check `nullable`
204191 fields. push ( DFField :: new ( None , & id, data_type. clone ( ) , true ) ) ;
205192 project_exprs. push ( expr. clone ( ) ) ;
@@ -212,50 +199,44 @@ fn build_project_plan(
212199 } )
213200}
214201
215- // Helper struct & func
216-
217- /// A map from expression's identifier to tuple including
218- /// - the expression itself (cloned)
219- /// - a hash set contains all addresses with the same identifier.
220- /// - counter
221- /// - A alternative plan.
222- pub type ExprSet =
223- HashMap < Identifier , ( Expr , HashSet < * const Expr > , usize , Option < ( ) > , DataType ) > ;
224-
225- pub type ExprAddrToId = HashMap < * const Expr , Identifier > ;
226-
227- /// Identifier type. Current implementation use descript of a expression (type String) as
228- /// Identifier.
202+ /// Go through an expression tree and generate identifier.
229203///
230- /// A Identifier should (ideally) be "hashable", "accumulatable", "Eq", "collisionless"
231- /// (as low probability as possible).
204+ /// An identifier contains information of the expression itself and its sub-expression.
205+ /// This visitor implementation use a stack `visit_stack` to track traversal, which
206+ /// lets us know when a sub-tree's visiting is finished. When `pre_visit` is called
207+ /// (traversing to a new node), an `EnterMark` and an `ExprItem` will be pushed into stack.
208+ /// And try to pop out a `EnterMark` on leaving a node (`post_visit()`). All `ExprItem`
209+ /// before the first `EnterMark` is considered to be sub-tree of the leaving node.
210+ ///
211+ /// This visitor also records identifier in `id_array`. Makes the following traverse
212+ /// pass can get the identifier of a node without recalculate it. We assign each node
213+ /// in the expr tree a series number, start from 1, maintained by `series_number`.
214+ /// Series number represents the order we left (`post_visit`) a node. Has the property
215+ /// that child node's series number always smaller than parent's. While `id_array` is
216+ /// organized in the order we enter (`pre_visit`) a node. `node_count` helps us to
217+ /// get the index of `id_array` for each node.
232218///
233- /// Since a identifier is likely to be copied many times, it is better that a identifier
234- /// is small or "copy". otherwise some kinds of reference count is needed.
235- pub type Identifier = String ;
236-
237- /// Go through an expression tree and generate identifier.
238219/// `Expr` without sub-expr (column, literal etc.) will not have identifier
239220/// because they should not be recognized as common sub-expr.
240221struct ExprIdentifierVisitor < ' a > {
241222 // param
242223 expr_set : & ' a mut ExprSet ,
243- addr_map : & ' a mut ExprAddrToId ,
244224 /// series number (usize) and identifier.
245225 id_array : & ' a mut Vec < ( usize , Identifier ) > ,
246226 data_type : DataType ,
247227
248228 // inner states
249- visit_stack : Vec < Item > ,
229+ visit_stack : Vec < VisitRecord > ,
250230 /// increased in pre_visit, start from 0.
251231 node_count : usize ,
252232 /// increased in post_visit, start from 1.
253- post_visit_number : usize ,
233+ series_number : usize ,
254234}
255235
256- enum Item {
236+ /// Record item that used when traversing a expression tree.
237+ enum VisitRecord {
257238 /// `usize` is the monotone increasing series number assigned in pre_visit().
258- /// Starts from 0. Is used to index ithe dentifier array `id_array` in post_visit().
239+ /// Starts from 0. Is used to index the identifier array `id_array` in post_visit().
259240 EnterMark ( usize ) ,
260241 /// Accumulated identifier of sub expression.
261242 ExprItem ( Identifier ) ,
@@ -311,35 +292,38 @@ impl ExprIdentifierVisitor<'_> {
311292 desc
312293 }
313294
295+ /// Find the first `EnterMark` in the stack, and accumulates every `ExprItem`
296+ /// before it.
314297 fn pop_enter_mark ( & mut self ) -> ( usize , Identifier ) {
315298 let mut desc = String :: new ( ) ;
316299
317300 while let Some ( item) = self . visit_stack . pop ( ) {
318301 match item {
319- Item :: EnterMark ( idx) => {
302+ VisitRecord :: EnterMark ( idx) => {
320303 return ( idx, desc) ;
321304 }
322- Item :: ExprItem ( s) => {
305+ VisitRecord :: ExprItem ( s) => {
323306 desc. push_str ( & s) ;
324307 }
325308 }
326309 }
327310
328- ( 0 , desc )
311+ unreachable ! ( "Enter mark should paired with node number" ) ;
329312 }
330313}
331314
332315impl ExpressionVisitor for ExprIdentifierVisitor < ' _ > {
333316 fn pre_visit ( mut self , _expr : & Expr ) -> Result < Recursion < Self > > {
334- self . visit_stack . push ( Item :: EnterMark ( self . node_count ) ) ;
317+ self . visit_stack
318+ . push ( VisitRecord :: EnterMark ( self . node_count ) ) ;
335319 self . node_count += 1 ;
336320 // put placeholder
337321 self . id_array . push ( ( 0 , "" . to_string ( ) ) ) ;
338322 Ok ( Recursion :: Continue ( self ) )
339323 }
340324
341325 fn post_visit ( mut self , expr : & Expr ) -> Result < Self > {
342- self . post_visit_number += 1 ;
326+ self . series_number += 1 ;
343327
344328 let ( idx, sub_expr_desc) = self . pop_enter_mark ( ) ;
345329 // skip exprs should not be recognize.
@@ -350,58 +334,61 @@ impl ExpressionVisitor for ExprIdentifierVisitor<'_> {
350334 | Expr :: ScalarVariable ( ..)
351335 | Expr :: Wildcard
352336 ) {
353- self . id_array [ idx] . 0 = self . post_visit_number ;
337+ self . id_array [ idx] . 0 = self . series_number ;
354338 let desc = Self :: desc_expr ( expr) ;
355- self . visit_stack . push ( Item :: ExprItem ( desc) ) ;
339+ self . visit_stack . push ( VisitRecord :: ExprItem ( desc) ) ;
356340 return Ok ( self ) ;
357341 }
358342 let mut desc = Self :: desc_expr ( expr) ;
359343 desc. push_str ( & sub_expr_desc) ;
360344
361- self . id_array [ idx] = ( self . post_visit_number , desc. clone ( ) ) ;
362- self . visit_stack . push ( Item :: ExprItem ( desc. clone ( ) ) ) ;
345+ self . id_array [ idx] = ( self . series_number , desc. clone ( ) ) ;
346+ self . visit_stack . push ( VisitRecord :: ExprItem ( desc. clone ( ) ) ) ;
363347 let data_type = self . data_type . clone ( ) ;
364348 self . expr_set
365349 . entry ( desc. clone ( ) )
366- . or_insert_with ( || ( expr. clone ( ) , HashSet :: new ( ) , 0 , None , data_type) )
367- . 2 += 1 ;
368- self . addr_map . insert ( expr as * const Expr , desc) ;
350+ . or_insert_with ( || ( expr. clone ( ) , 0 , data_type) )
351+ . 1 += 1 ;
369352 Ok ( self )
370353 }
371354}
372355
373- /// Go through an expression tree and generate identity .
356+ /// Go through an expression tree and generate identifier for every node in this tree .
374357fn expr_to_identifier (
375358 expr : & Expr ,
376359 expr_set : & mut ExprSet ,
377- addr_map : & mut ExprAddrToId ,
378360 id_array : & mut Vec < ( usize , Identifier ) > ,
379361 data_type : DataType ,
380362) -> Result < ( ) > {
381363 expr. accept ( ExprIdentifierVisitor {
382364 expr_set,
383- addr_map,
384365 id_array,
385366 data_type,
386367 visit_stack : vec ! [ ] ,
387368 node_count : 0 ,
388- post_visit_number : 0 ,
369+ series_number : 0 ,
389370 } ) ?;
390371
391372 Ok ( ( ) )
392373}
393374
375+ /// Rewrite expression by replacing detected common sub-expression with
376+ /// the corresponding temporary column name. That column contains the
377+ /// evaluate result of replaced expression.
394378struct CommonSubexprRewriter < ' a > {
395379 expr_set : & ' a mut ExprSet ,
396380 id_array : & ' a Vec < ( usize , Identifier ) > ,
381+ /// Which identifier is replaced.
397382 affected_id : & ' a mut HashSet < Identifier > ,
398383
384+ /// the max series number we have rewritten. Other expression nodes
385+ /// with smaller series number is already replaced and shouldn't
386+ /// do anything with them.
399387 max_series_number : usize ,
388+ /// current node's information's index in `id_array`.
400389 curr_index : usize ,
401390}
402391
403- impl CommonSubexprRewriter < ' _ > { }
404-
405392impl ExprRewriter for CommonSubexprRewriter < ' _ > {
406393 fn pre_visit ( & mut self , _: & Expr ) -> Result < RewriteRecursion > {
407394 if self . curr_index >= self . id_array . len ( )
@@ -416,8 +403,7 @@ impl ExprRewriter for CommonSubexprRewriter<'_> {
416403 self . curr_index += 1 ;
417404 return Ok ( RewriteRecursion :: Continue ) ;
418405 }
419- let ( stored_expr, somewhat_set, counter, another_thing, _) =
420- self . expr_set . get ( curr_id) . unwrap ( ) ;
406+ let ( _, counter, _) = self . expr_set . get ( curr_id) . unwrap ( ) ;
421407 if * counter > 1 {
422408 self . affected_id . insert ( curr_id. clone ( ) ) ;
423409 Ok ( RewriteRecursion :: Mutate )
@@ -471,8 +457,16 @@ mod test {
471457 use crate :: logical_plan:: { binary_expr, col, lit, sum, LogicalPlanBuilder , Operator } ;
472458 use crate :: test:: * ;
473459
460+ fn assert_optimized_plan_eq ( plan : & LogicalPlan , expected : & str ) {
461+ let optimizer = CommonSubexprEliminate { } ;
462+ let optimized_plan = optimizer
463+ . optimize ( plan, & ExecutionProps :: new ( ) )
464+ . expect ( "failed to optimize plan" ) ;
465+ let formatted_plan = format ! ( "{:?}" , optimized_plan) ;
466+ assert_eq ! ( formatted_plan, expected) ;
467+ }
468+
474469 #[ test]
475- #[ ignore]
476470 fn dev_driver_tpch_q1_simplified ( ) -> Result < ( ) > {
477471 // SQL:
478472 // select
@@ -504,8 +498,11 @@ mod test {
504498 ) ?
505499 . build ( ) ?;
506500
507- let optimizer = CommonSubexprEliminate { } ;
508- let new_plan = optimizer. optimize ( & plan, & ExecutionProps :: new ( ) ) . unwrap ( ) ;
501+ let expected = "Aggregate: groupBy=[[]], aggr=[[SUM(#BinaryExpr-*BinaryExpr--Column-test.bLiteral1Column-test.a), SUM(#BinaryExpr-*BinaryExpr--Column-test.bLiteral1Column-test.a Multiply Int32(1) Plus #test.c)]]\
502+ \n Projection: #test.a Multiply Int32(1) Minus #test.b\
503+ \n TableScan: test projection=None";
504+
505+ assert_optimized_plan_eq ( & plan, expected) ;
509506
510507 Ok ( ( ) )
511508 }
0 commit comments