diff --git a/hugr-passes/src/dataflow/datalog.rs b/hugr-passes/src/dataflow/datalog.rs index 3d1a2a5386..0eadbcdc10 100644 --- a/hugr-passes/src/dataflow/datalog.rs +++ b/hugr-passes/src/dataflow/datalog.rs @@ -1,6 +1,6 @@ //! [ascent] datalog implementation of analysis. -use std::collections::HashMap; +use std::collections::{HashMap, HashSet}; use ascent::Lattice; use ascent::lattice::BoundedLattice; @@ -90,12 +90,13 @@ impl Machine { } /// Run the analysis (iterate until a lattice fixpoint is reached). - /// As a shortcut, for Hugrs whose root is a [`FuncDefn`](OpType::FuncDefn), - /// [CFG](OpType::CFG), [DFG](OpType::DFG), [Conditional](OpType::Conditional) - /// or [`TailLoop`] only (that is: *not* [Module](OpType::Module), + /// As a shortcut, for Hugrs whose [HugrView::entrypoint] is a + /// [`FuncDefn`](OpType::FuncDefn), [CFG](OpType::CFG), [DFG](OpType::DFG), + /// [Conditional](OpType::Conditional) or [`TailLoop`](OpType::TailLoop) only + /// (that is: *not* [Module](OpType::Module), /// [`DataflowBlock`](OpType::DataflowBlock) or [Case](OpType::Case)), - /// `in_values` may provide initial values for the root-node inputs, - /// equivalent to calling `prepopulate_inputs` with the root node. + /// `in_values` may provide initial values for the entrypoint-node inputs, + /// equivalent to calling `prepopulate_inputs` with the entrypoint node. /// /// The context passed in allows interpretation of leaf operations. /// @@ -107,13 +108,13 @@ impl Machine { context: impl DFContext, in_values: impl IntoIterator)>, ) -> AnalysisResults { - let root = self.0.entrypoint(); - if self.0.get_optype(root).is_module() { + if self.0.entrypoint_optype().is_module() { assert!( in_values.into_iter().next().is_none(), "No inputs possible for Module" ); } else { + let ep = self.0.entrypoint(); let mut p = in_values.into_iter().peekable(); // We must provide some inputs to the root so that they are Top rather than Bottom. // (However, this test will fail for DataflowBlock or Case roots, i.e. if no @@ -121,12 +122,10 @@ impl Machine { // values for even these nodes in self.1 and then convert to actual Wire values // (outputs from the Input node) before we run_datalog, but we would need to have // a separate store of output-wire values in self to keep prepopulate_wire working.) - if p.peek().is_some() || !self.1.contains_key(&root) { - self.prepopulate_inputs(root, p).unwrap(); + if p.peek().is_some() || !self.1.contains_key(&ep) { + self.prepopulate_inputs(ep, p).unwrap(); } } - // Note/TODO, if analysis is running on a subregion then we should do similar - // for any nonlocal edges providing values from outside the region. run_datalog( context, self.0, @@ -164,7 +163,13 @@ pub(super) fn run_datalog( lattice in_wire_value(H::Node, IncomingPort, PV); // receives, on , the value lattice node_in_value_row(H::Node, ValueRow); // 's inputs are - node(n) <-- for n in hugr.entry_descendants(); + // Analyse all nodes as this will compute the most accurate results for the desired nodes + // (i.e. the entry_descendants). Moreover, this is the only sound policy until we correctly + // mark incoming edges as `Top`, see https://github.com/CQCL/hugr/issues/2254), so is a + // workaround for that. + // When that issue is solved, we can consider a flag to restrict analysis to the subregion + // (for efficiency - will still decrease accuracy of solutions, but will at least be safe). + node(n) <-- for n in hugr.nodes(); in_wire(n, p) <-- node(n), for (p,_) in hugr.in_value_types(*n); // Note, gets connected inports only out_wire(n, p) <-- node(n), for (p,_) in hugr.out_value_types(*n); // (and likewise) @@ -359,17 +364,31 @@ pub(super) fn run_datalog( if matches!(v, PartialValue::Top | PartialValue::Value(_)), for p in ci.signature().output_ports(); }; + let entry_descs = hugr.entry_descendants().collect::>(); let out_wire_values = all_results .out_wire_value .iter() + .filter(|(n, _, _)| entry_descs.contains(n)) .map(|(n, p, v)| (Wire::new(*n, *p), v.clone())) .collect(); AnalysisResults { hugr, out_wire_values, - in_wire_value: all_results.in_wire_value, - case_reachable: all_results.case_reachable, - bb_reachable: all_results.bb_reachable, + in_wire_value: all_results + .in_wire_value + .into_iter() + .filter(|(n, _, _)| entry_descs.contains(n)) + .collect(), + case_reachable: all_results + .case_reachable + .into_iter() + .filter(|(_, n)| entry_descs.contains(n)) + .collect(), + bb_reachable: all_results + .bb_reachable + .into_iter() + .filter(|(_, n)| entry_descs.contains(n)) + .collect(), } } diff --git a/hugr-passes/src/dataflow/test.rs b/hugr-passes/src/dataflow/test.rs index 1c0d6c6cea..205d9ba4fa 100644 --- a/hugr-passes/src/dataflow/test.rs +++ b/hugr-passes/src/dataflow/test.rs @@ -442,47 +442,34 @@ fn test_region() { let mut builder = DFGBuilder::new(Signature::new(vec![bool_t()], vec![bool_t(); 2])).unwrap(); let [in_w] = builder.input_wires_arr(); let cst_w = builder.add_load_const(Value::false_val()); + // Create a nested DFG which gets in_w passed as an input, but has a nonlocal edge + // from the LoadConstant let nested = builder - .dfg_builder(Signature::new_endo(vec![bool_t(); 2]), [in_w, cst_w]) + .dfg_builder(Signature::new(bool_t(), vec![bool_t(); 2]), [in_w]) .unwrap(); - let nested_ins = nested.input_wires(); - let nested = nested.finish_with_outputs(nested_ins).unwrap(); + let [nested_in] = nested.input_wires_arr(); + let nested = nested.finish_with_outputs([nested_in, cst_w]).unwrap(); let hugr = builder.finish_hugr_with_outputs(nested.outputs()).unwrap(); let [nested_input, _] = hugr.get_io(nested.node()).unwrap(); let whole_hugr_results = Machine::new(&hugr).run(TestContext, [(0.into(), pv_true())]); - assert_eq!( - whole_hugr_results.read_out_wire(Wire::new(nested_input, 0)), - Some(pv_true()) - ); - assert_eq!( - whole_hugr_results.read_out_wire(Wire::new(nested_input, 1)), - Some(pv_false()) - ); - assert_eq!( - whole_hugr_results.read_out_wire(Wire::new(hugr.entrypoint(), 0)), - Some(pv_true()) - ); - assert_eq!( - whole_hugr_results.read_out_wire(Wire::new(hugr.entrypoint(), 1)), - Some(pv_false()) - ); - - // Do not provide a value on the second input (constant false in the whole hugr, above) let sub_hugr_results = Machine::new(hugr.with_entrypoint(nested.node())).run(TestContext, [(0.into(), pv_true())]); - assert_eq!( - sub_hugr_results.read_out_wire(Wire::new(nested_input, 0)), - Some(pv_true()) - ); - assert_eq!( - sub_hugr_results.read_out_wire(Wire::new(nested_input, 1)), - Some(PartialValue::Top) - ); - for w in [0, 1] { - assert_eq!( - sub_hugr_results.read_out_wire(Wire::new(hugr.entrypoint(), w)), - None - ); + for (wire, val) in [ + (Wire::new(nested_input, 0), Some(pv_true())), + (Wire::new(nested.node(), 0), Some(pv_true())), + (Wire::new(nested.node(), 1), Some(pv_false())), + ] { + assert_eq!(whole_hugr_results.read_out_wire(wire), val); + assert_eq!(sub_hugr_results.read_out_wire(wire), val); + } + + for (wire, val) in [ + (cst_w, pv_false()), + (Wire::new(hugr.entrypoint(), 0), pv_true()), + (Wire::new(hugr.entrypoint(), 1), pv_false()), + ] { + assert_eq!(whole_hugr_results.read_out_wire(wire), Some(val)); + assert_eq!(sub_hugr_results.read_out_wire(wire), None); } }