@@ -42,6 +42,7 @@ Algorithm for extracting canonical form from an E-graph:
4242#include  " ast/ast_pp.h" 
4343#include  " ast/ast_util.h" 
4444#include  " ast/euf/euf_egraph.h" 
45+ #include  " ast/rewriter/var_subst.h" 
4546#include  " ast/simplifiers/euf_completion.h" 
4647#include  " ast/shared_occs.h" 
4748
@@ -50,6 +51,7 @@ namespace euf {
5051    completion::completion (ast_manager& m, dependent_expr_state& fmls):
5152        dependent_expr_simplifier (m, fmls),
5253        m_egraph (m),
54+         m_mam (mam::mk(*this , *this )),
5355        m_canonical (m),
5456        m_eargs (m),
5557        m_deps (m),
@@ -58,6 +60,19 @@ namespace euf {
5860        m_ff = m_egraph.mk (m.mk_false (), 0 , 0 , nullptr );
5961        m_rewriter.set_order_eq (true );
6062        m_rewriter.set_flat_and_or (false );
63+ 
64+         std::function<void (euf::enode*, euf::enode*)> _on_merge = 
65+             [&](euf::enode* root, euf::enode* other) { 
66+                 m_mam->on_merge (root, other); 
67+         };
68+         
69+         std::function<void (euf::enode*)> _on_make = 
70+             [&](euf::enode* n) {
71+             m_mam->add_node (n, false );
72+         };
73+         
74+         m_egraph.set_on_merge (_on_merge);
75+         m_egraph.set_on_make (_on_make);
6176    }
6277
6378    void  completion::reduce () {
@@ -75,33 +90,67 @@ namespace euf {
7590    void  completion::add_egraph () {
7691        m_nodes_to_canonize.reset ();
7792        unsigned  sz = qtail ();
93+ 
94+         for  (unsigned  i = qhead (); i < sz; ++i) {
95+             auto  [f, p, d] = m_fmls[i]();
96+             add_constraint (f, d);
97+         }
98+         m_should_propagate = true ;
99+         while  (m_should_propagate) {
100+             m_should_propagate = false ;
101+             m_egraph.propagate ();
102+             m_mam->propagate ();
103+         }
104+     }
105+ 
106+     void  completion::add_constraint (expr* f, expr_dependency* d) {
78107        auto  add_children = [&](enode* n) {                
79108            for  (auto * ch : enode_args (n))
80109                m_nodes_to_canonize.push_back (ch);
81110        };
82- 
83-         for  (unsigned  i = qhead (); i < sz; ++i) {
84-             expr* x, * y;
85-             auto  [f, p, d] = m_fmls[i]();
86-             if  (m.is_eq (f, x, y)) {
87-                 enode* a = mk_enode (x);
88-                 enode* b = mk_enode (y);
89-                 m_egraph.merge (a, b, d);
90-                 add_children (a);
91-                 add_children (b);
92-             }
93-             else  if  (m.is_not (f, f)) {
94-                 enode* n = mk_enode (f);
95-                 m_egraph.merge (n, m_ff, d);
96-                 add_children (n);
97-             }
98-             else  {
99-                 enode* n = mk_enode (f);
100-                 m_egraph.merge (n, m_tt, d);
101-                 add_children (n);
111+         expr* x, * y;
112+         if  (m.is_eq (f, x, y)) {
113+             enode* a = mk_enode (x);
114+             enode* b = mk_enode (y);
115+             m_egraph.merge (a, b, d);
116+             add_children (a);
117+             add_children (b);
118+         }
119+         else  if  (m.is_not (f, f)) {
120+             enode* n = mk_enode (f);
121+             m_egraph.merge (n, m_ff, d);
122+             add_children (n);
123+         }
124+         else  {
125+             enode* n = mk_enode (f);
126+             m_egraph.merge (n, m_tt, d);
127+             add_children (n);
128+             if  (is_forall (f)) {
129+                 quantifier* q = to_quantifier (f);
130+                 ptr_vector<app> ground;
131+                 for  (unsigned  i = 0 ; i < q->get_num_patterns (); ++i) {
132+                     auto  p = to_app (q->get_pattern (i));
133+                     mam::ground_subterms (p, ground);
134+                     for  (expr* g : ground)
135+                         mk_enode (g);
136+                     m_mam->add_pattern (q, p);
137+                 }
138+                 if  (!get_dependency (q)) {
139+                     m_q2dep.insert (q, d);
140+                     get_trail ().push (insert_obj_map (m_q2dep, q));
141+                 }                    
102142            }
103143        }
104-         m_egraph.propagate ();
144+     }
145+ 
146+     void  completion::on_binding (quantifier* q, app* pat, enode* const * binding, unsigned  mg, unsigned  ming, unsigned  mx) {
147+         var_subst subst (m);
148+         expr_ref_vector _binding (m);
149+         for  (unsigned  i = 0 ; i < q->get_num_decls (); ++i)
150+             _binding.push_back (binding[i]->get_expr ());
151+         expr_ref r = subst (q->get_expr (), _binding);
152+         add_constraint (r, get_dependency (q));
153+         m_should_propagate = true ;
105154    }
106155
107156    void  completion::read_egraph () {
0 commit comments