Skip to content

Commit f932d48

Browse files
use propagation queues and hash-tables to schedule bindings
1 parent 7b432ae commit f932d48

File tree

2 files changed

+158
-13
lines changed

2 files changed

+158
-13
lines changed

src/ast/simplifiers/euf_completion.cpp

Lines changed: 89 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,7 @@ namespace euf {
6666
m_canonical(m),
6767
m_eargs(m),
6868
m_canonical_proofs(m),
69+
m_infer_patterns(m, m_smt_params),
6970
m_deps(m),
7071
m_rewriter(m) {
7172
m_tt = m_egraph.mk(m.mk_true(), 0, 0, nullptr);
@@ -196,9 +197,10 @@ namespace euf {
196197
m_should_propagate = false;
197198
m_egraph.propagate();
198199
m_mam->propagate();
200+
flush_binding_queue();
199201
propagate_rules();
200202
IF_VERBOSE(11, verbose_stream() << "propagate " << m_stats.m_num_instances << "\n");
201-
if (!m_should_propagate)
203+
if (!m_should_propagate && !should_stop())
202204
propagate_all_rules();
203205
}
204206
}
@@ -229,7 +231,8 @@ namespace euf {
229231
}
230232
else if (m.is_not(f, f)) {
231233
enode* n = mk_enode(f);
232-
m_egraph.merge(n, m_ff, to_ptr(push_pr_dep(pr, d)));
234+
auto j = to_ptr(push_pr_dep(pr, d));
235+
m_egraph.new_diseq(n, j);
233236
add_children(n);
234237
}
235238
else {
@@ -238,6 +241,12 @@ namespace euf {
238241
add_children(n);
239242
if (is_forall(f)) {
240243
quantifier* q = to_quantifier(f);
244+
if (q->get_num_patterns() == 0) {
245+
expr_ref tmp(m);
246+
m_infer_patterns(q, tmp);
247+
m_egraph.mk(tmp, 0, 0, nullptr); // ensure tmp is pinned within this scope.
248+
q = to_quantifier(tmp);
249+
}
241250
ptr_vector<app> ground;
242251
for (unsigned i = 0; i < q->get_num_patterns(); ++i) {
243252
auto p = to_app(q->get_pattern(i));
@@ -396,33 +405,100 @@ namespace euf {
396405
}
397406
}
398407

408+
binding* completion::tmp_binding(quantifier* q, app* pat, euf::enode* const* _binding) {
409+
if (q->get_num_decls() > m_tmp_binding_capacity) {
410+
void* mem = memory::allocate(sizeof(binding) + q->get_num_decls() * sizeof(euf::enode*));
411+
m_tmp_binding = new (mem) binding(q, pat, 0, 0, 0);
412+
m_tmp_binding_capacity = q->get_num_decls();
413+
}
414+
415+
for (unsigned i = q->get_num_decls(); i-- > 0; )
416+
m_tmp_binding->m_nodes[i] = _binding[i];
417+
m_tmp_binding->m_pattern = pat;
418+
m_tmp_binding->m_q = q;
419+
return m_tmp_binding.get();
420+
}
421+
422+
binding* completion::alloc_binding(quantifier* q, app* pat, euf::enode* const* _binding, unsigned max_generation, unsigned min_top, unsigned max_top) {
423+
binding* b = tmp_binding(q, pat, _binding);
424+
425+
if (m_bindings.contains(b))
426+
return nullptr;
427+
428+
for (unsigned i = q->get_num_decls(); i-- > 0; )
429+
b->m_nodes[i] = b->m_nodes[i]->get_root();
430+
431+
if (m_bindings.contains(b))
432+
return nullptr;
433+
434+
unsigned n = q->get_num_decls();
435+
unsigned sz = sizeof(binding) + sizeof(euf::enode* const*) * n;
436+
void* mem = get_region().allocate(sz);
437+
b = new (mem) binding(q, pat, max_generation, min_top, max_top);
438+
b->init(b);
439+
for (unsigned i = 0; i < n; ++i)
440+
b->m_nodes[i] = _binding[i];
441+
442+
m_bindings.insert(b);
443+
get_trail().push(insert_map<bindings, binding*>(m_bindings, b));
444+
return b;
445+
}
446+
399447
// callback when mam finds a binding
400-
void completion::on_binding(quantifier* q, app* pat, enode* const* binding, unsigned mg, unsigned ming, unsigned mx) {
448+
void completion::on_binding(quantifier* q, app* pat, enode* const* binding, unsigned max_global, unsigned min_top, unsigned max_top) {
449+
if (should_stop())
450+
return;
451+
auto* b = alloc_binding(q, pat, binding, max_global, min_top, max_top);
452+
if (!b)
453+
return;
454+
insert_binding(b);
455+
}
456+
457+
void completion::insert_binding(binding* b) {
458+
m_queue.reserve(b->m_max_top_generation + 1);
459+
m_queue[b->m_max_top_generation].push_back(b);
460+
}
461+
462+
void completion::flush_binding_queue() {
463+
TRACE(euf_completion,
464+
tout << "flush-queue\n";
465+
for (unsigned i = 0; i < m_queue.size(); ++i)
466+
tout << i << ": " << m_queue[i].size() << "\n";);
467+
IF_VERBOSE(10,
468+
verbose_stream() << "flush-queue\n";
469+
for (unsigned i = 0; i < m_queue.size(); ++i)
470+
verbose_stream() << i << ": " << m_queue[i].size() << "\n");
471+
472+
for (auto& g : m_queue) {
473+
for (auto b : g)
474+
apply_binding(*b);
475+
g.reset();
476+
}
477+
}
478+
479+
void completion::apply_binding(binding& b) {
401480
if (should_stop())
402481
return;
403482
var_subst subst(m);
404483
expr_ref_vector _binding(m);
405-
unsigned max_generation = 0;
406-
for (unsigned i = 0; i < q->get_num_decls(); ++i) {
407-
_binding.push_back(binding[i]->get_expr());
408-
max_generation = std::max(max_generation, binding[i]->generation());
409-
}
484+
quantifier* q = b.m_q;
485+
for (unsigned i = 0; i < q->get_num_decls(); ++i)
486+
_binding.push_back(b.m_nodes[i]->get_expr());
487+
410488
expr_ref r = subst(q->get_expr(), _binding);
411-
IF_VERBOSE(12, verbose_stream() << "add " << r << "\n");
412-
IF_VERBOSE(10, verbose_stream() << max_generation << "\n");
413-
scoped_generation sg(*this, max_generation + 1);
489+
490+
scoped_generation sg(*this, b.m_max_top_generation + 1);
414491
auto [pr, d] = get_dependency(q);
415492
if (pr)
416493
pr = m.mk_quant_inst(m.mk_or(m.mk_not(q), r), _binding.size(), _binding.data());
417494
add_constraint(r, pr, d);
418495
propagate_rules();
496+
m_egraph.propagate();
419497
m_should_propagate = true;
420498
++m_stats.m_num_instances;
421499
}
422500

423501
void completion::read_egraph() {
424-
//m_egraph.display(verbose_stream());
425-
//exit(0);
426502
if (m_egraph.inconsistent()) {
427503
auto* d = explain_conflict();
428504
proof_ref pr(m);

src/ast/simplifiers/euf_completion.h

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,10 +21,13 @@ Module Name:
2121
#pragma once
2222

2323
#include "util/scoped_vector.h"
24+
#include "util/dlist.h"
2425
#include "ast/simplifiers/dependent_expr_state.h"
2526
#include "ast/euf/euf_egraph.h"
2627
#include "ast/euf/euf_mam.h"
2728
#include "ast/rewriter/th_rewriter.h"
29+
#include "ast/pattern/pattern_inference.h"
30+
#include "params/smt_params.h"
2831

2932
namespace euf {
3033

@@ -43,6 +46,60 @@ namespace euf {
4346
virtual void solve_for(vector<solution>& sol) = 0;
4447
};
4548

49+
struct binding : public dll_base<binding> {
50+
quantifier* m_q;
51+
app* m_pattern;
52+
unsigned m_max_generation;
53+
unsigned m_min_top_generation;
54+
unsigned m_max_top_generation;
55+
euf::enode* m_nodes[0];
56+
57+
binding(quantifier* q, app* pat, unsigned max_generation, unsigned min_top, unsigned max_top) :
58+
m_q(q),
59+
m_pattern(pat),
60+
m_max_generation(max_generation),
61+
m_min_top_generation(min_top),
62+
m_max_top_generation(max_top) {
63+
}
64+
65+
euf::enode* const* nodes() { return m_nodes; }
66+
67+
euf::enode* operator[](unsigned i) const { return m_nodes[i]; }
68+
69+
unsigned size() const { return m_q->get_num_decls(); }
70+
71+
quantifier* q() const { return m_q; }
72+
73+
bool eq(binding const& other) const {
74+
if (q() != other.q())
75+
return false;
76+
for (unsigned i = size(); i-- > 0; )
77+
if ((*this)[i] != other[i])
78+
return false;
79+
return true;
80+
}
81+
};
82+
83+
struct binding_khasher {
84+
unsigned operator()(binding const* f) const { return f->q()->get_id(); }
85+
};
86+
87+
struct binding_chasher {
88+
unsigned operator()(binding const* f, unsigned idx) const { return f->m_nodes[idx]->hash(); }
89+
};
90+
91+
struct binding_hash_proc {
92+
unsigned operator()(binding const* f) const {
93+
return get_composite_hash<binding*, binding_khasher, binding_chasher>(const_cast<binding*>(f), f->size());
94+
}
95+
};
96+
97+
struct binding_eq_proc {
98+
bool operator()(binding const* a, binding const* b) const { return a->eq(*b); }
99+
};
100+
101+
typedef ptr_hashtable<binding, binding_hash_proc, binding_eq_proc> bindings;
102+
46103
class completion : public dependent_expr_simplifier, public on_binding_callback, public mam_solver {
47104

48105
struct stats {
@@ -63,13 +120,18 @@ namespace euf {
63120
m_body(b), m_head(h), m_proofs(prs), m_dep(d) {}
64121
};
65122

123+
smt_params m_smt_params;
66124
egraph m_egraph;
67125
scoped_ptr<mam> m_mam;
68126
enode* m_tt, *m_ff;
69127
ptr_vector<expr> m_todo;
70128
enode_vector m_args, m_reps, m_nodes_to_canonize;
71129
expr_ref_vector m_canonical, m_eargs;
72130
proof_ref_vector m_canonical_proofs;
131+
pattern_inference_rw m_infer_patterns;
132+
bindings m_bindings;
133+
scoped_ptr<binding> m_tmp_binding;
134+
unsigned m_tmp_binding_capacity = 0;
73135
expr_dependency_ref_vector m_deps;
74136
obj_map<quantifier, std::pair<proof*, expr_dependency*>> m_q2dep;
75137
vector<std::pair<proof_ref, expr_dependency*>> m_pr_dep;
@@ -109,6 +171,13 @@ namespace euf {
109171
expr_dependency* explain_conflict();
110172
std::pair<proof*, expr_dependency*> get_dependency(quantifier* q) { return m_q2dep.contains(q) ? m_q2dep[q] : std::pair(nullptr, nullptr); }
111173

174+
binding* tmp_binding(quantifier* q, app* pat, euf::enode* const* _binding);
175+
binding* alloc_binding(quantifier* q, app* pat, euf::enode* const* _binding, unsigned max_generation, unsigned min_top, unsigned max_top);
176+
void insert_binding(binding* b);
177+
void apply_binding(binding& b);
178+
void flush_binding_queue();
179+
vector<ptr_vector<binding>> m_queue;
180+
112181
lbool eval_cond(expr* f, proof_ref& pr, expr_dependency*& d);
113182

114183
bool should_stop();

0 commit comments

Comments
 (0)