Skip to content

Commit b2f0170

Browse files
euf_completion with AC: add first cut of AC matching for top-level, add plugins and fix shared expression rewriting in ac-plugin
1 parent bc31276 commit b2f0170

File tree

8 files changed

+139
-26
lines changed

8 files changed

+139
-26
lines changed

src/ast/euf/euf_ac_plugin.cpp

Lines changed: 40 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ namespace euf {
9393
return;
9494
for (auto arg : enode_args(n))
9595
if (is_op(arg))
96-
register_shared(arg); // TODO optimization to avoid registering shared terms twice
96+
register_shared(arg);
9797
}
9898

9999
void ac_plugin::register_shared(enode* n) {
@@ -180,7 +180,7 @@ namespace euf {
180180
std::ostream& ac_plugin::display_monomial(std::ostream& out, ptr_vector<node> const& m) const {
181181
for (auto n : m) {
182182
if (n->n->num_args() == 0)
183-
out << mk_pp(n->n->get_expr(), g.get_manager()) << " ";
183+
out << n->n->get_expr_id() << ": " << mk_pp(n->n->get_expr(), g.get_manager()) << " ";
184184
else
185185
out << g.bpp(n->n) << " ";
186186
}
@@ -244,6 +244,7 @@ namespace euf {
244244
if (l == r)
245245
return;
246246
auto j = justification::equality(l, r);
247+
TRACE(plugin, tout << g.bpp(l) << " == " << g.bpp(r) << " " << is_op(l) << " " << is_op(r) << "\n");
247248
if (!is_op(l) && !is_op(r))
248249
merge(mk_node(l), mk_node(r), j);
249250
else
@@ -263,6 +264,7 @@ namespace euf {
263264
void ac_plugin::init_equation(eq const& e) {
264265
m_eqs.push_back(e);
265266
auto& eq = m_eqs.back();
267+
TRACE(plugin, display_equation(tout, e) << "\n");
266268
if (orient_equation(eq)) {
267269

268270
unsigned eq_id = m_eqs.size() - 1;
@@ -273,6 +275,8 @@ namespace euf {
273275
n->root->n->mark1();
274276
push_undo(is_add_eq_index);
275277
m_node_trail.push_back(n->root);
278+
for (auto s : n->root->shared)
279+
m_shared_todo.insert(s);
276280
}
277281
}
278282

@@ -282,6 +286,8 @@ namespace euf {
282286
n->root->n->mark1();
283287
push_undo(is_add_eq_index);
284288
m_node_trail.push_back(n->root);
289+
for (auto s : n->root->shared)
290+
m_shared_todo.insert(s);
285291
}
286292
}
287293

@@ -291,6 +297,7 @@ namespace euf {
291297
for (auto n : monomial(eq.r))
292298
n->root->n->unmark1();
293299

300+
TRACE(plugin, display_equation(tout, e) << "\n");
294301
m_to_simplify_todo.insert(eq_id);
295302
}
296303
else
@@ -368,6 +375,7 @@ namespace euf {
368375
}
369376

370377
void ac_plugin::merge(node* root, node* other, justification j) {
378+
TRACE(plugin, tout << root << " == " << other << " num shared " << other->shared.size() << "\n");
371379
for (auto n : equiv(other))
372380
n->root = root;
373381
m_merge_trail.push_back({ other, root->shared.size(), root->eqs.size() });
@@ -394,22 +402,34 @@ namespace euf {
394402
ptr_vector<node> m;
395403
ns.push_back(n);
396404
for (unsigned i = 0; i < ns.size(); ++i) {
397-
n = ns[i];
398-
if (is_op(n))
399-
ns.append(n->num_args(), n->args());
405+
auto k = ns[i];
406+
if (is_op(k))
407+
ns.append(k->num_args(), k->args());
400408
else
401-
m.push_back(mk_node(n));
409+
m.push_back(mk_node(k));
402410
}
403411
return to_monomial(n, m);
404412
}
405413

406414
unsigned ac_plugin::to_monomial(enode* e, ptr_vector<node> const& ms) {
407415
unsigned id = m_monomials.size();
408-
m_monomials.push_back({ ms, bloom() });
416+
m_monomials.push_back({ ms, bloom(), e });
409417
push_undo(is_add_monomial);
410418
return id;
411419
}
412420

421+
enode* ac_plugin::from_monomial(ptr_vector<node> const& mon) {
422+
auto& m = g.get_manager();
423+
ptr_buffer<expr> args;
424+
enode_vector nodes;
425+
for (auto arg : mon) {
426+
nodes.push_back(arg->root->n);
427+
args.push_back(arg->root->n->get_expr());
428+
}
429+
auto n = m.mk_app(m_fid, m_op, args.size(), args.data());
430+
return g.mk(n, 0, nodes.size(), nodes.data());
431+
}
432+
413433
ac_plugin::node* ac_plugin::node::mk(region& r, enode* n) {
414434
auto* mem = r.allocate(sizeof(node));
415435
node* res = new (mem) node();
@@ -427,6 +447,9 @@ namespace euf {
427447
push_undo(is_add_node);
428448
m_nodes.setx(id, r, nullptr);
429449
m_node_trail.push_back(r);
450+
if (is_op(n)) {
451+
// extract shared sub-expressions
452+
}
430453
return r;
431454
}
432455

@@ -983,6 +1006,7 @@ namespace euf {
9831006
//
9841007

9851008
void ac_plugin::propagate_shared() {
1009+
TRACE(plugin, tout << "num shared todo " << m_shared_todo.size() << "\n");
9861010
if (m_shared_todo.empty())
9871011
return;
9881012
while (!m_shared_todo.empty()) {
@@ -1007,12 +1031,15 @@ namespace euf {
10071031
void ac_plugin::simplify_shared(unsigned idx, shared s) {
10081032
auto j = s.j;
10091033
auto old_m = s.m;
1034+
auto old_n = monomial(old_m).m_src;
10101035
ptr_vector<node> m1(monomial(old_m).m_nodes);
1011-
TRACE(plugin, tout << "simplify " << m_pp(*this, monomial(old_m)) << "\n");
1036+
TRACE(plugin, tout << "simplify " << g.bpp(old_n) << ": " << m_pp(*this, monomial(old_m)) << "\n");
10121037
if (!reduce(m1, j))
10131038
return;
10141039

1015-
auto new_m = to_monomial(m1);
1040+
1041+
auto new_n = from_monomial(m1);
1042+
auto new_m = to_monomial(new_n, m1);
10161043
// update shared occurrences for members of the new monomial that are not already in the old monomial.
10171044
for (auto n : monomial(old_m))
10181045
n->root->n->mark1();
@@ -1029,6 +1056,10 @@ namespace euf {
10291056
push_undo(is_update_shared);
10301057
m_shared[idx].m = new_m;
10311058
m_shared[idx].j = j;
1059+
1060+
TRACE(plugin, tout << "shared simplified to " << m_pp(*this, monomial(new_m)) << "\n");
1061+
1062+
push_merge(old_n, new_n, j);
10321063
}
10331064

10341065
justification ac_plugin::justify_rewrite(unsigned eq1, unsigned eq2) {

src/ast/euf/euf_ac_plugin.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,7 @@ namespace euf {
9797
struct monomial_t {
9898
ptr_vector<node> m_nodes;
9999
bloom m_bloom;
100+
enode* m_src = nullptr;
100101
node* operator[](unsigned i) const { return m_nodes[i]; }
101102
unsigned size() const { return m_nodes.size(); }
102103
void set(ptr_vector<node> const& ns) { m_nodes.reset(); m_nodes.append(ns); m_bloom.m_tick = 0; }
@@ -187,6 +188,7 @@ namespace euf {
187188
unsigned to_monomial(enode* n);
188189
unsigned to_monomial(enode* n, ptr_vector<node> const& ms);
189190
unsigned to_monomial(ptr_vector<node> const& ms) { return to_monomial(nullptr, ms); }
191+
enode* from_monomial(ptr_vector<node> const& m);
190192
monomial_t const& monomial(unsigned i) const { return m_monomials[i]; }
191193
monomial_t& monomial(unsigned i) { return m_monomials[i]; }
192194
void sort(monomial_t& monomial);

src/ast/euf/euf_arith_plugin.cpp

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,10 +33,13 @@ namespace euf {
3333
}
3434

3535
void arith_plugin::register_node(enode* n) {
36-
// no-op
36+
TRACE(plugin, tout << g.bpp(n) << "\n");
37+
m_add.register_node(n);
38+
m_mul.register_node(n);
3739
}
3840

3941
void arith_plugin::merge_eh(enode* n1, enode* n2) {
42+
TRACE(plugin, tout << g.bpp(n1) << " == " << g.bpp(n2) << "\n");
4043
m_add.merge_eh(n1, n2);
4144
m_mul.merge_eh(n1, n2);
4245
}

src/ast/euf/euf_egraph.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -310,6 +310,13 @@ namespace euf {
310310
}
311311
}
312312

313+
void egraph::register_shared(enode* n, theory_id id) {
314+
force_push();
315+
auto* p = get_plugin(id);
316+
if (p)
317+
p->register_node(n);
318+
}
319+
313320
void egraph::undo_add_th_var(enode* n, theory_id tid) {
314321
theory_var v = n->get_th_var(tid);
315322
SASSERT(v != null_theory_var);

src/ast/euf/euf_egraph.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -318,6 +318,7 @@ namespace euf {
318318

319319

320320
void add_th_var(enode* n, theory_var v, theory_id id);
321+
void register_shared(enode* n, theory_id id);
321322
void set_th_propagates_diseqs(theory_id id);
322323
void set_cgc_enabled(enode* n, bool enable_cgc);
323324
void set_merge_tf_enabled(enode* n, bool enable_merge_tf);

src/ast/euf/euf_mam.cpp

Lines changed: 67 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -649,7 +649,7 @@ namespace euf {
649649
}
650650

651651
bool is_ac(func_decl* f) const {
652-
return false && f->is_associative() && f->is_commutative();
652+
return f->is_associative() && f->is_commutative();
653653
}
654654

655655
instruction * mk_init(func_decl* f, unsigned n) {
@@ -1777,6 +1777,10 @@ namespace euf {
17771777
m_use_filters(use_filters) {
17781778
}
17791779

1780+
bool is_ac(func_decl* f) const {
1781+
return f->is_associative() && f->is_commutative();
1782+
}
1783+
17801784
/**
17811785
\brief Create a new code tree for the given quantifier.
17821786
@@ -1791,6 +1795,8 @@ namespace euf {
17911795
code_tree * r = m_ct_manager.mk_code_tree(p->get_decl(), num_args, filter_candidates);
17921796
init(r, qa, mp, first_idx);
17931797
linearise(r->m_root, first_idx);
1798+
if (is_ac(p->get_decl()))
1799+
++m_num_choices;
17941800
r->m_num_choices = m_num_choices;
17951801
TRACE(mam_compiler, tout << "new tree for:\n" << mk_pp(mp, m) << "\n" << *r;);
17961802
return r;
@@ -1861,9 +1867,6 @@ namespace euf {
18611867
unsigned m_old_max_generation;
18621868
union {
18631869
enode * m_curr;
1864-
struct {
1865-
unsigned m_next_pattern;
1866-
};
18671870
struct {
18681871
enode_vector * m_to_recycle;
18691872
enode * const * m_it;
@@ -2009,7 +2012,7 @@ namespace euf {
20092012

20102013
void display_pc_info(std::ostream & out);
20112014

2012-
bool match_ac(initn const* pc);
2015+
bool next_ac_match(initn const* pc);
20132016

20142017
#define INIT_ARGS_SIZE 16
20152018

@@ -2291,9 +2294,53 @@ namespace euf {
22912294
// Established: use Diophantine equations to capture matchability.
22922295
//
22932296

2294-
bool interpreter::match_ac(initn const* pc) {
2297+
bool interpreter::next_ac_match(initn const* pc) {
22952298
unsigned f_args = pc->m_num_args;
22962299
SASSERT(f_args <= m_acargs.size());
2300+
for (unsigned i = f_args; i-- > 0;) {
2301+
unsigned j = m_acpatarg[i];
2302+
m_acbitset[j] = false;
2303+
next_j:
2304+
++j;
2305+
for (; j < m_acargs.size(); ++j) {
2306+
if (m_acbitset[j])
2307+
continue;
2308+
m_registers[i + 1] = m_acargs[j];
2309+
m_acbitset[j] = true;
2310+
m_acpatarg[i] = j;
2311+
break;
2312+
}
2313+
if (j == m_acargs.size())
2314+
continue;
2315+
2316+
for (unsigned ii = i + 1; ii < f_args; ++ii) {
2317+
unsigned k = 0;
2318+
// populate arguments after i
2319+
for (; k < m_acargs.size(); ++k) {
2320+
if (!m_acbitset[k]) {
2321+
m_registers[ii + 1] = m_acargs[k];
2322+
m_acbitset[k] = true;
2323+
m_acpatarg[ii] = k;
2324+
break;
2325+
}
2326+
}
2327+
if (k == m_acargs.size()) {
2328+
--ii;
2329+
// clean up
2330+
for (; ii >= i; --ii) {
2331+
k = m_acpatarg[ii];
2332+
m_acbitset[k] = false;
2333+
}
2334+
goto next_j;
2335+
}
2336+
}
2337+
IF_VERBOSE(2,
2338+
verbose_stream() << "next ac: ";
2339+
for (unsigned j = 0; j < f_args; ++j)
2340+
verbose_stream() << m_acpatarg[j] << " ";
2341+
verbose_stream() << "\n";);
2342+
return true;
2343+
}
22972344
return false;
22982345
}
22992346

@@ -2412,6 +2459,7 @@ namespace euf {
24122459
m_acargs.reset();
24132460
m_acargs.push_back(m_app);
24142461
auto* f = m_app->get_decl();
2462+
auto num_pat_args = static_cast<const initn*>(m_pc)->m_num_args;
24152463
for (unsigned i = 0; i < m_acargs.size(); ++i) {
24162464
auto* arg = m_acargs[i];
24172465
if (is_app(arg->get_expr()) && f == arg->get_decl()) {
@@ -2421,19 +2469,20 @@ namespace euf {
24212469
--i;
24222470
}
24232471
}
2424-
if (static_cast<const initn*>(m_pc)->m_num_args > m_acargs.size())
2472+
if (num_pat_args > m_acargs.size())
24252473
goto backtrack;
24262474
m_acbitset.reset();
24272475
m_acbitset.reserve(m_acargs.size(), false);
24282476
m_acpatarg.reset();
24292477
m_acpatarg.reserve(m_acargs.size(), 0);
24302478
m_backtrack_stack[m_top].m_instr = m_pc;
24312479
m_backtrack_stack[m_top].m_old_max_generation = m_curr_max_generation;
2432-
m_backtrack_stack[m_top].m_next_pattern = 0;
2433-
++m_top;
2434-
// perform the match relative index
2435-
if (!match_ac(static_cast<initn const*>(m_pc)))
2436-
goto backtrack;
2480+
++m_top;
2481+
for (unsigned i = 0; i < num_pat_args; ++i) {
2482+
m_acpatarg[i] = i;
2483+
m_acbitset[i] = true;
2484+
m_registers[i + 1] = m_acargs[i];
2485+
}
24372486
m_pc = m_pc->m_next;
24382487
goto main_loop;
24392488
}
@@ -2499,7 +2548,7 @@ namespace euf {
24992548
m_app = get_first_f_app(static_cast<const bind *>(m_pc)->m_label, static_cast<const bind *>(m_pc)->m_num_args, m_n1); \
25002549
if (!m_app) \
25012550
goto backtrack; \
2502-
TRACE(mam_int, tout << "bind candidate: " << mk_pp(m_app->get_expr(), m) << "\n";); \
2551+
TRACE(mam_int, tout << "bind candidate: " << mk_pp(m_app->get_expr(), m) << " " << m_top << " " << m_backtrack_stack.size() << "\n";); \
25032552
m_backtrack_stack[m_top].m_instr = m_pc; \
25042553
m_backtrack_stack[m_top].m_old_max_generation = m_curr_max_generation; \
25052554
m_backtrack_stack[m_top].m_curr = m_app; \
@@ -2832,7 +2881,11 @@ namespace euf {
28322881

28332882
case INITAC:
28342883
// this is a backtracking point.
2835-
NOT_IMPLEMENTED_YET();
2884+
if (!next_ac_match(static_cast<initn const*>(bp.m_instr))) {
2885+
--m_top;
2886+
goto backtrack;
2887+
}
2888+
m_pc = bp.m_instr->m_next;
28362889
goto main_loop;
28372890

28382891
case CONTINUE:

0 commit comments

Comments
 (0)