Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 35 additions & 16 deletions hugr-passes/src/dataflow/datalog.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
//! [ascent] datalog implementation of analysis.

use std::collections::HashMap;
use std::collections::{HashMap, HashSet};

use ascent::Lattice;
use ascent::lattice::BoundedLattice;
Expand Down Expand Up @@ -90,12 +90,13 @@ impl<H: HugrView, V: AbstractValue> Machine<H, V> {
}

/// Run the analysis (iterate until a lattice fixpoint is reached).
/// As a shortcut, for Hugrs whose root is a [`FuncDefn`](OpType::FuncDefn),
/// [CFG](OpType::CFG), [DFG](OpType::DFG), [Conditional](OpType::Conditional)
/// or [`TailLoop`] only (that is: *not* [Module](OpType::Module),
/// As a shortcut, for Hugrs whose [HugrView::entrypoint] is a
/// [`FuncDefn`](OpType::FuncDefn), [CFG](OpType::CFG), [DFG](OpType::DFG),
/// [Conditional](OpType::Conditional) or [`TailLoop`](OpType::TailLoop) only
/// (that is: *not* [Module](OpType::Module),
/// [`DataflowBlock`](OpType::DataflowBlock) or [Case](OpType::Case)),
/// `in_values` may provide initial values for the root-node inputs,
/// equivalent to calling `prepopulate_inputs` with the root node.
/// `in_values` may provide initial values for the entrypoint-node inputs,
/// equivalent to calling `prepopulate_inputs` with the entrypoint node.
///
/// The context passed in allows interpretation of leaf operations.
///
Expand All @@ -107,26 +108,24 @@ impl<H: HugrView, V: AbstractValue> Machine<H, V> {
context: impl DFContext<V, Node = H::Node>,
in_values: impl IntoIterator<Item = (IncomingPort, PartialValue<V, H::Node>)>,
) -> AnalysisResults<V, H> {
let root = self.0.entrypoint();
if self.0.get_optype(root).is_module() {
if self.0.entrypoint_optype().is_module() {
assert!(
in_values.into_iter().next().is_none(),
"No inputs possible for Module"
);
} else {
let ep = self.0.entrypoint();
let mut p = in_values.into_iter().peekable();
// We must provide some inputs to the root so that they are Top rather than Bottom.
// (However, this test will fail for DataflowBlock or Case roots, i.e. if no
// inputs have been provided they will still see Bottom. We could store the "input"
// values for even these nodes in self.1 and then convert to actual Wire values
// (outputs from the Input node) before we run_datalog, but we would need to have
// a separate store of output-wire values in self to keep prepopulate_wire working.)
if p.peek().is_some() || !self.1.contains_key(&root) {
self.prepopulate_inputs(root, p).unwrap();
if p.peek().is_some() || !self.1.contains_key(&ep) {
self.prepopulate_inputs(ep, p).unwrap();
}
}
// Note/TODO, if analysis is running on a subregion then we should do similar
// for any nonlocal edges providing values from outside the region.
run_datalog(
context,
self.0,
Expand Down Expand Up @@ -164,7 +163,13 @@ pub(super) fn run_datalog<V: AbstractValue, H: HugrView>(
lattice in_wire_value(H::Node, IncomingPort, PV<V, H::Node>); // <Node> receives, on <IncomingPort>, the value <PV>
lattice node_in_value_row(H::Node, ValueRow<V, H::Node>); // <Node>'s inputs are <ValueRow>

node(n) <-- for n in hugr.entry_descendants();
// Analyse all nodes as this will compute the most accurate results for the desired nodes
// (i.e. the entry_descendants). Moreover, this is the only sound policy until we correctly
// mark incoming edges as `Top`, see https://github.com/CQCL/hugr/issues/2254), so is a
// workaround for that.
// When that issue is solved, we can consider a flag to restrict analysis to the subregion
// (for efficiency - will still decrease accuracy of solutions, but will at least be safe).
node(n) <-- for n in hugr.nodes();

in_wire(n, p) <-- node(n), for (p,_) in hugr.in_value_types(*n); // Note, gets connected inports only
out_wire(n, p) <-- node(n), for (p,_) in hugr.out_value_types(*n); // (and likewise)
Expand Down Expand Up @@ -359,17 +364,31 @@ pub(super) fn run_datalog<V: AbstractValue, H: HugrView>(
if matches!(v, PartialValue::Top | PartialValue::Value(_)),
for p in ci.signature().output_ports();
};
let entry_descs = hugr.entry_descendants().collect::<HashSet<_>>();
let out_wire_values = all_results
.out_wire_value
.iter()
.filter(|(n, _, _)| entry_descs.contains(n))
.map(|(n, p, v)| (Wire::new(*n, *p), v.clone()))
.collect();
AnalysisResults {
hugr,
out_wire_values,
in_wire_value: all_results.in_wire_value,
case_reachable: all_results.case_reachable,
bb_reachable: all_results.bb_reachable,
in_wire_value: all_results
.in_wire_value
.into_iter()
.filter(|(n, _, _)| entry_descs.contains(n))
.collect(),
case_reachable: all_results
.case_reachable
.into_iter()
.filter(|(_, n)| entry_descs.contains(n))
.collect(),
bb_reachable: all_results
.bb_reachable
.into_iter()
.filter(|(_, n)| entry_descs.contains(n))
.collect(),
}
}

Expand Down
55 changes: 21 additions & 34 deletions hugr-passes/src/dataflow/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -442,47 +442,34 @@ fn test_region() {
let mut builder = DFGBuilder::new(Signature::new(vec![bool_t()], vec![bool_t(); 2])).unwrap();
let [in_w] = builder.input_wires_arr();
let cst_w = builder.add_load_const(Value::false_val());
// Create a nested DFG which gets in_w passed as an input, but has a nonlocal edge
// from the LoadConstant
let nested = builder
.dfg_builder(Signature::new_endo(vec![bool_t(); 2]), [in_w, cst_w])
.dfg_builder(Signature::new(bool_t(), vec![bool_t(); 2]), [in_w])
.unwrap();
let nested_ins = nested.input_wires();
let nested = nested.finish_with_outputs(nested_ins).unwrap();
let [nested_in] = nested.input_wires_arr();
let nested = nested.finish_with_outputs([nested_in, cst_w]).unwrap();
let hugr = builder.finish_hugr_with_outputs(nested.outputs()).unwrap();
let [nested_input, _] = hugr.get_io(nested.node()).unwrap();
let whole_hugr_results = Machine::new(&hugr).run(TestContext, [(0.into(), pv_true())]);
assert_eq!(
whole_hugr_results.read_out_wire(Wire::new(nested_input, 0)),
Some(pv_true())
);
assert_eq!(
whole_hugr_results.read_out_wire(Wire::new(nested_input, 1)),
Some(pv_false())
);
assert_eq!(
whole_hugr_results.read_out_wire(Wire::new(hugr.entrypoint(), 0)),
Some(pv_true())
);
assert_eq!(
whole_hugr_results.read_out_wire(Wire::new(hugr.entrypoint(), 1)),
Some(pv_false())
);

// Do not provide a value on the second input (constant false in the whole hugr, above)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This was not testing what was intended: not passing the second input, just makes Machine::run / Machine::prepopulate_inputs provide Top on the missing input, so there is no connection with (or reliance on the edge from) the graph outside the entrypoint. Hence, replacing the second input with a nonlocal edge.

let sub_hugr_results =
Machine::new(hugr.with_entrypoint(nested.node())).run(TestContext, [(0.into(), pv_true())]);
assert_eq!(
sub_hugr_results.read_out_wire(Wire::new(nested_input, 0)),
Some(pv_true())
);
assert_eq!(
sub_hugr_results.read_out_wire(Wire::new(nested_input, 1)),
Some(PartialValue::Top)
);
for w in [0, 1] {
assert_eq!(
sub_hugr_results.read_out_wire(Wire::new(hugr.entrypoint(), w)),
None
);
for (wire, val) in [
(Wire::new(nested_input, 0), Some(pv_true())),
(Wire::new(nested.node(), 0), Some(pv_true())),
(Wire::new(nested.node(), 1), Some(pv_false())),
] {
assert_eq!(whole_hugr_results.read_out_wire(wire), val);
assert_eq!(sub_hugr_results.read_out_wire(wire), val);
}

for (wire, val) in [
(cst_w, pv_false()),
(Wire::new(hugr.entrypoint(), 0), pv_true()),
(Wire::new(hugr.entrypoint(), 1), pv_false()),
] {
assert_eq!(whole_hugr_results.read_out_wire(wire), Some(val));
assert_eq!(sub_hugr_results.read_out_wire(wire), None);
}
}

Expand Down
Loading