From c3b48747fab998df01a518de0c8a3569ed05cd11 Mon Sep 17 00:00:00 2001 From: Bernard Knueven Date: Wed, 2 Apr 2025 12:44:41 -0600 Subject: [PATCH 01/19] adding fields for best xhat and recent xhats --- mpisppy/cylinders/spoke.py | 4 ++-- mpisppy/cylinders/spwindow.py | 41 ++++++++++++++++++++--------------- 2 files changed, 26 insertions(+), 19 deletions(-) diff --git a/mpisppy/cylinders/spoke.py b/mpisppy/cylinders/spoke.py index 69485c373..56d7d6be5 100644 --- a/mpisppy/cylinders/spoke.py +++ b/mpisppy/cylinders/spoke.py @@ -154,7 +154,7 @@ class InnerBoundSpoke(_BoundSpoke): Hub, and do not need information from the main PH OPT hub. """ - send_fields = (*_BoundSpoke.send_fields, Field.OBJECTIVE_INNER_BOUND, ) + send_fields = (*_BoundSpoke.send_fields, Field.OBJECTIVE_INNER_BOUND, Field.BEST_XHAT, Field.RECENT_XHATS, ) receive_fields = (*_BoundSpoke.receive_fields, ) converger_spoke_char = 'I' @@ -246,7 +246,7 @@ class InnerBoundNonantSpoke(_BoundNonantSpoke): and restoring results """ - send_fields = (*_BoundNonantSpoke.send_fields, Field.OBJECTIVE_INNER_BOUND, ) + send_fields = (*_BoundNonantSpoke.send_fields, Field.OBJECTIVE_INNER_BOUND, Field.BEST_XHAT, Field.RECENT_XHATS, ) receive_fields = (*_BoundNonantSpoke.receive_fields, Field.NONANT) converger_spoke_char = 'I' diff --git a/mpisppy/cylinders/spwindow.py b/mpisppy/cylinders/spwindow.py index 6dbfd1cdc..d65b11c3d 100644 --- a/mpisppy/cylinders/spwindow.py +++ b/mpisppy/cylinders/spwindow.py @@ -30,14 +30,19 @@ class Field(enum.IntEnum): CROSS_SCENARIO_COST=400 NONANT_LOWER_BOUNDS=500 NONANT_UPPER_BOUNDS=501 + BEST_XHAT=600 # buffer having the best xhat and its total cost per scenario + RECENT_XHATS=601 # buffer having some recent xhats and their total cost per scenario WHOLE=1_000_000 -_field_length_components = pyo.ConcreteModel() -_field_length_components.local_nonant_length = pyo.Param(mutable=True) -_field_length_components.local_scenario_length = pyo.Param(mutable=True) -_field_length_components.total_number_nonants = pyo.Param(mutable=True) -_field_length_components.total_number_scenarios = pyo.Param(mutable=True) +field_length_components = pyo.ConcreteModel() +field_length_components._local_nonant_length = pyo.Param(mutable=True) +field_length_components._local_scenario_length = pyo.Param(mutable=True) +field_length_components._total_number_nonants = pyo.Param(mutable=True) +field_length_components._total_number_scenarios = pyo.Param(mutable=True) + +# these could be modified by the user... +field_length_components.total_number_recent_xhats = pyo.Param(mutable=True, initialize=10, within=pyo.NonNegativeIntegers) _field_lengths = { Field.SHUTDOWN : 1, @@ -47,12 +52,14 @@ class Field(enum.IntEnum): Field.BEST_OBJECTIVE_BOUNDS : 2, Field.OBJECTIVE_INNER_BOUND : 1, Field.OBJECTIVE_OUTER_BOUND : 1, - Field.EXPECTED_REDUCED_COST : _field_length_components.total_number_nonants, - Field.SCENARIO_REDUCED_COST : _field_length_components.local_nonant_length, - Field.CROSS_SCENARIO_CUT : _field_length_components.total_number_scenarios * (_field_length_components.total_number_nonants + 1 + 1), - Field.CROSS_SCENARIO_COST : _field_length_components.total_number_scenarios * _field_length_components.total_number_scenarios, - Field.NONANT_LOWER_BOUNDS : _field_length_components.total_number_nonants, - Field.NONANT_UPPER_BOUNDS : _field_length_components.total_number_nonants, + Field.EXPECTED_REDUCED_COST : field_length_components._total_number_nonants, + Field.SCENARIO_REDUCED_COST : field_length_components._local_nonant_length, + Field.CROSS_SCENARIO_CUT : field_length_components._total_number_scenarios * (field_length_components._total_number_nonants + 1 + 1), + Field.CROSS_SCENARIO_COST : field_length_components._total_number_scenarios * field_length_components._total_number_scenarios, + Field.NONANT_LOWER_BOUNDS : field_length_components._total_number_nonants, + Field.NONANT_UPPER_BOUNDS : field_length_components._total_number_nonants, + Field.BEST_XHAT : field_length_components._local_nonant_length + field_length_components._local_scenario_length, + Field.RECENT_XHATS : field_length_components.total_number_recent_xhats * (field_length_components._local_nonant_length + field_length_components._local_scenario_length), } @@ -65,15 +72,15 @@ def __init__(self, opt): ) ) - _field_length_components.local_nonant_length.value = number_nonants - _field_length_components.local_scenario_length.value = len(opt.local_scenarios) - _field_length_components.total_number_nonants.value = opt.nonant_length - _field_length_components.total_number_scenarios.value = len(opt.local_scenarios) + field_length_components._local_nonant_length.value = number_nonants + field_length_components._local_scenario_length.value = len(opt.local_scenarios) + field_length_components._total_number_nonants.value = opt.nonant_length + field_length_components._total_number_scenarios.value = len(opt.local_scenarios) self._field_lengths = {k : pyo.value(v) for k, v in _field_lengths.items()} - # reset the _field_length_components - for p in _field_length_components.component_data_objects(): + # reset the field_length_components + for p in field_length_components.component_data_objects(): p.clear() def __getitem__(self, field: Field): From e24620293e3c1a1357c616efe6d9f3c42a0965a9 Mon Sep 17 00:00:00 2001 From: Bernard Knueven Date: Mon, 16 Jun 2025 11:49:03 -0600 Subject: [PATCH 02/19] properly resolve conflict in spwindow --- mpisppy/cylinders/spwindow.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/mpisppy/cylinders/spwindow.py b/mpisppy/cylinders/spwindow.py index d65b11c3d..6b9bb41a0 100644 --- a/mpisppy/cylinders/spwindow.py +++ b/mpisppy/cylinders/spwindow.py @@ -46,9 +46,9 @@ class Field(enum.IntEnum): _field_lengths = { Field.SHUTDOWN : 1, - Field.NONANT : _field_length_components.local_nonant_length, - Field.DUALS : _field_length_components.local_nonant_length, - Field.RELAXED_NONANT : _field_length_components.local_nonant_length, + Field.NONANT : field_length_components.local_nonant_length, + Field.DUALS : field_length_components.local_nonant_length, + Field.RELAXED_NONANT : field_length_components.local_nonant_length, Field.BEST_OBJECTIVE_BOUNDS : 2, Field.OBJECTIVE_INNER_BOUND : 1, Field.OBJECTIVE_OUTER_BOUND : 1, From c20381706a5b75f1ca20c18e302fe48a288a0582 Mon Sep 17 00:00:00 2001 From: Bernard Knueven Date: Mon, 16 Jun 2025 11:57:46 -0600 Subject: [PATCH 03/19] resolve merge conflicts part 2 --- mpisppy/cylinders/spwindow.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/mpisppy/cylinders/spwindow.py b/mpisppy/cylinders/spwindow.py index 6b9bb41a0..1145a7761 100644 --- a/mpisppy/cylinders/spwindow.py +++ b/mpisppy/cylinders/spwindow.py @@ -46,9 +46,9 @@ class Field(enum.IntEnum): _field_lengths = { Field.SHUTDOWN : 1, - Field.NONANT : field_length_components.local_nonant_length, - Field.DUALS : field_length_components.local_nonant_length, - Field.RELAXED_NONANT : field_length_components.local_nonant_length, + Field.NONANT : field_length_components._local_nonant_length, + Field.DUALS : field_length_components._local_nonant_length, + Field.RELAXED_NONANT : field_length_components._local_nonant_length, Field.BEST_OBJECTIVE_BOUNDS : 2, Field.OBJECTIVE_INNER_BOUND : 1, Field.OBJECTIVE_OUTER_BOUND : 1, From 3d3ce2aceff9cf92dfacc9722a449c50d107f351 Mon Sep 17 00:00:00 2001 From: Bernard Knueven Date: Tue, 17 Jun 2025 09:28:30 -0600 Subject: [PATCH 04/19] fix for EF bundles --- mpisppy/extensions/xhatbase.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/mpisppy/extensions/xhatbase.py b/mpisppy/extensions/xhatbase.py index 83adb4183..454940167 100644 --- a/mpisppy/extensions/xhatbase.py +++ b/mpisppy/extensions/xhatbase.py @@ -202,6 +202,10 @@ def _try_one(self, snamedict, solver_options=None, verbose=False, self.opt.local_scenarios[sname].pprint() # get the global obj obj = self.opt.Eobjective(verbose=verbose) + # set the scenario objective value for communication + for k,s in self.opt.local_scenarios.items(): + objfct = self.opt.saved_objectives[k] + s._mpisppy_data.inner_bound = pyo.value(objfct) self.opt.update_best_solution_if_improving(obj) if restore_nonants: self.opt._restore_nonants() From 828e84964d65a2cffb144638c4fa8fb176df41c2 Mon Sep 17 00:00:00 2001 From: Bernard Knueven Date: Tue, 17 Jun 2025 09:36:00 -0600 Subject: [PATCH 05/19] only reset private parameters --- mpisppy/cylinders/spwindow.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/mpisppy/cylinders/spwindow.py b/mpisppy/cylinders/spwindow.py index 1145a7761..3e2b915fe 100644 --- a/mpisppy/cylinders/spwindow.py +++ b/mpisppy/cylinders/spwindow.py @@ -81,7 +81,10 @@ def __init__(self, opt): # reset the field_length_components for p in field_length_components.component_data_objects(): - p.clear() + # leave user-set parameter alone, just clear the + # "private" parameters + if p.name[0] == "_": + p.clear() def __getitem__(self, field: Field): return self._field_lengths[field] From 7f18b37699cfda93a494cf765f97d659ee51b276 Mon Sep 17 00:00:00 2001 From: Bernard Knueven Date: Tue, 17 Jun 2025 09:44:55 -0600 Subject: [PATCH 06/19] standarize agnostic to use inner_bound --- mpisppy/agnostic/ampl_guest.py | 6 +++--- mpisppy/agnostic/examples/farmer_gurobipy_model.py | 2 +- mpisppy/agnostic/gams_guest.py | 4 ++-- mpisppy/agnostic/pyomo_guest.py | 2 +- mpisppy/spopt.py | 2 +- 5 files changed, 8 insertions(+), 8 deletions(-) diff --git a/mpisppy/agnostic/ampl_guest.py b/mpisppy/agnostic/ampl_guest.py index 195185071..b5e1aa48a 100644 --- a/mpisppy/agnostic/ampl_guest.py +++ b/mpisppy/agnostic/ampl_guest.py @@ -270,7 +270,7 @@ def solve_one(self, Ag, s, solve_keyword_args, gripe, tee=False, need_solution=T if gripe: print (f"Solve failed for scenario {s.name} on rank {global_rank}") print(f"{gs.solve_result =}") - s._mpisppy_data._obj_from_agnostic = None + s._mpisppy_data.inner_bound = None return else: @@ -289,7 +289,7 @@ def solve_one(self, Ag, s, solve_keyword_args, gripe, tee=False, need_solution=T if gd["sense"] == pyo.minimize: s._mpisppy_data.outer_bound = objval - mipgap else: - s._mpisppy_data.inner_bound = objval + mipgap + s._mpisppy_data.outer_bound = objval + mipgap # copy the nonant x values from gs to s so mpisppy can use them in s # in general, we need more checks (see the pyomo agnostic guest example) @@ -311,7 +311,7 @@ def solve_one(self, Ag, s, solve_keyword_args, gripe, tee=False, need_solution=T s._mpisppy_data.nonant_indices[ndn_i]._value = gxvar.value() - s._mpisppy_data._obj_from_agnostic = objval + s._mpisppy_data.inner_bound = objval # local helper diff --git a/mpisppy/agnostic/examples/farmer_gurobipy_model.py b/mpisppy/agnostic/examples/farmer_gurobipy_model.py index e42cd5bc7..8f8bb3fd2 100644 --- a/mpisppy/agnostic/examples/farmer_gurobipy_model.py +++ b/mpisppy/agnostic/examples/farmer_gurobipy_model.py @@ -200,7 +200,7 @@ def solve_one(Ag, s, solve_keyword_args, gripe, tee, need_solution=True): s._mpisppy_data.nonant_indices[ndn_i]._value = grb_var.X # Store the objective function value in the host scenario - s._mpisppy_data._obj_from_agnostic = objval + s._mpisppy_data.inner_bound = objval # Additional checks and operations for bundling if needed (depending on the problem) # ... diff --git a/mpisppy/agnostic/gams_guest.py b/mpisppy/agnostic/gams_guest.py index 829d6ad77..e33e1f72b 100644 --- a/mpisppy/agnostic/gams_guest.py +++ b/mpisppy/agnostic/gams_guest.py @@ -204,7 +204,7 @@ def solve_one(self, Ag, s, solve_keyword_args, gripe, tee, need_solution=True): if gripe: print (f"Solve failed for scenario {s.name} on rank {global_rank}") print(f"{gs.model_status =}") - s._mpisppy_data._obj_from_agnostic = None + s._mpisppy_data.inner_bound = None return if solver_exception is not None and need_solution: @@ -248,7 +248,7 @@ def solve_one(self, Ag, s, solve_keyword_args, gripe, tee, need_solution=True): s._mpisppy_data.outer_bound = objval # the next line ignores bundling - s._mpisppy_data._obj_from_agnostic = objval + s._mpisppy_data.inner_bound = objval # TBD: deal with other aspects of bundling (see solve_one in spopt.py) #print(f"For {s.name} in {global_rank=}: {objval=}") diff --git a/mpisppy/agnostic/pyomo_guest.py b/mpisppy/agnostic/pyomo_guest.py index 3dae79487..7233e5544 100644 --- a/mpisppy/agnostic/pyomo_guest.py +++ b/mpisppy/agnostic/pyomo_guest.py @@ -262,7 +262,7 @@ def solve_one(self, Ag, s, solve_keyword_args, gripe, tee=False, need_solution=T s._mpisppy_data.nonant_indices[ndn_i]._value = gxvar._value # the next line ignore bundles (other than proper bundles) - s._mpisppy_data._obj_from_agnostic = pyo.value(sputils.get_objs(gs)[0]) + s._mpisppy_data.inner_bound = pyo.value(sputils.get_objs(gs)[0]) # local helper diff --git a/mpisppy/spopt.py b/mpisppy/spopt.py index 3d53276aa..d2ff30116 100644 --- a/mpisppy/spopt.py +++ b/mpisppy/spopt.py @@ -421,7 +421,7 @@ def Eobjective(self, verbose=False): local_Eobjs.append(s._mpisppy_probability * pyo.value(objfct)) else: # Agnostic will have attached the objective (and doesn't bundle as of Aug 2023) - local_Eobjs.append(s._mpisppy_probability * s._mpisppy_data._obj_from_agnostic) + local_Eobjs.append(s._mpisppy_probability * s._mpisppy_data.inner_bound) if verbose: print ("caller", inspect.stack()[1][3]) print ("E_Obj Scenario {}, prob={}, Obj={}, ObjExpr={}"\ From 8b117ed5f1ff8caa5e9c96321d77140f871ebf12 Mon Sep 17 00:00:00 2001 From: Bernard Knueven Date: Tue, 17 Jun 2025 10:37:44 -0600 Subject: [PATCH 07/19] _obj_from_agnostic -> inner_bound in examples --- examples/farmer/agnostic/farmer_ampl_agnostic.py | 2 +- examples/farmer/agnostic/farmer_gurobipy_agnostic.py | 2 +- examples/farmer/agnostic/farmer_pyomo_agnostic.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/examples/farmer/agnostic/farmer_ampl_agnostic.py b/examples/farmer/agnostic/farmer_ampl_agnostic.py index a4022b73a..c6dc3377a 100644 --- a/examples/farmer/agnostic/farmer_ampl_agnostic.py +++ b/examples/farmer/agnostic/farmer_ampl_agnostic.py @@ -326,7 +326,7 @@ def solve_one(Ag, s, solve_keyword_args, gripe, tee, need_solution=True): s._mpisppy_data.nonant_indices[ndn_i]._value = gxvar.value() # the next line ignores bundling - s._mpisppy_data._obj_from_agnostic = objval + s._mpisppy_data.inner_bound = objval # TBD: deal with other aspects of bundling (see solve_one in spopt.py) diff --git a/examples/farmer/agnostic/farmer_gurobipy_agnostic.py b/examples/farmer/agnostic/farmer_gurobipy_agnostic.py index 835a93b15..103f4defe 100644 --- a/examples/farmer/agnostic/farmer_gurobipy_agnostic.py +++ b/examples/farmer/agnostic/farmer_gurobipy_agnostic.py @@ -285,7 +285,7 @@ def solve_one(Ag, s, solve_keyword_args, gripe, tee, need_solution=True): s._mpisppy_data.nonant_indices[ndn_i]._value = grb_var.X # Store the objective function value in the host scenario - s._mpisppy_data._obj_from_agnostic = objval + s._mpisppy_data.inner_bound = objval # Additional checks and operations for bundling if needed (depending on the problem) # ... diff --git a/examples/farmer/agnostic/farmer_pyomo_agnostic.py b/examples/farmer/agnostic/farmer_pyomo_agnostic.py index 1a4232855..99e3f9fd0 100644 --- a/examples/farmer/agnostic/farmer_pyomo_agnostic.py +++ b/examples/farmer/agnostic/farmer_pyomo_agnostic.py @@ -241,7 +241,7 @@ def solve_one(Ag, s, solve_keyword_args, gripe, tee=False, need_solution=True): s._mpisppy_data.nonant_indices[ndn_i]._value = gxvar._value # the next line ignore bundling - s._mpisppy_data._obj_from_agnostic = pyo.value(gs.Total_Cost_Objective) + s._mpisppy_data.inner_bound = pyo.value(gs.Total_Cost_Objective) # TBD: deal with other aspects of bundling (see solve_one in spopt.py) From 86b87681d3ef6f30ea6a309daffa171dff96e186 Mon Sep 17 00:00:00 2001 From: Bernard Knueven Date: Tue, 17 Jun 2025 10:49:09 -0600 Subject: [PATCH 08/19] put best xhat in buffer --- mpisppy/cylinders/spoke.py | 83 +++++++++++++++++++++----------------- 1 file changed, 45 insertions(+), 38 deletions(-) diff --git a/mpisppy/cylinders/spoke.py b/mpisppy/cylinders/spoke.py index 56d7d6be5..a392fb6fd 100644 --- a/mpisppy/cylinders/spoke.py +++ b/mpisppy/cylinders/spoke.py @@ -19,7 +19,7 @@ class Spoke(SPCommunicator): send_fields = (*SPCommunicator.send_fields, ) - receive_fields = (*SPCommunicator.receive_fields, Field.SHUTDOWN, ) + receive_fields = (*SPCommunicator.receive_fields, Field.SHUTDOWN, Field.BEST_OBJECTIVE_BOUNDS, ) def got_kill_signal(self): """ Spoke should call this method at least every iteration @@ -50,9 +50,6 @@ class _BoundSpoke(Spoke): """ A base class for bound spokes """ - send_fields = (*Spoke.send_fields, ) - receive_fields = (*Spoke.receive_fields, Field.BEST_OBJECTIVE_BOUNDS) - def __init__(self, spbase_object, fullcomm, strata_comm, cylinder_comm, options=None): super().__init__(spbase_object, fullcomm, strata_comm, cylinder_comm, options) if self.cylinder_rank == 0 and \ @@ -159,6 +156,47 @@ class InnerBoundSpoke(_BoundSpoke): converger_spoke_char = 'I' + def __init__(self, spbase_object, fullcomm, strata_comm, cylinder_comm, options=None): + super().__init__(spbase_object, fullcomm, strata_comm, cylinder_comm, options) + self.is_minimizing = self.opt.is_minimizing + self.best_inner_bound = math.inf if self.is_minimizing else -math.inf + self.solver_options = None # can be overwritten by derived classes + + def update_if_improving(self, candidate_inner_bound, update_best_solution_cache=True): + if update_best_solution_cache: + update = self.opt.update_best_solution_if_improving(candidate_inner_bound) + else: + update = ( (candidate_inner_bound < self.best_inner_bound) + if self.is_minimizing else + (self.best_inner_bound < candidate_inner_bound) + ) + if update: + self.best_inner_bound = candidate_inner_bound + # send to hub + self.send_bound(candidate_inner_bound) + self.send_best_xhat() + return True + return False + + def send_best_xhat(self): + best_xhat_buf = self.send_buffers[Field.BEST_XHAT] + # NOTE: this does not work with "loose" bundles + ci = 0 + for s in self.opt.local_scenarios.values(): + for ndn_var in s._mpisppy_data.nonant_indices.values(): + best_xhat_buf[ci] = s._mpisppy_data.best_solution_cache[ndn_var] + ci += 1 + best_xhat_buf[ci] = s._mpisppy_data.inner_bound + ci += 1 + # print(f"{self.cylinder_rank=} sending {best_xhat_buf.value_array()=}") + self.put_send_buffer(best_xhat_buf, Field.BEST_XHAT) + + def finalize(self): + if self.opt.load_best_solution(): + self.final_bound = self.bound + return self.final_bound + return None + def bound_type(self) -> Field: return Field.OBJECTIVE_INNER_BOUND @@ -236,7 +274,7 @@ def update_nonants(self) -> bool: return self._update_nonant_len_buffer() -class InnerBoundNonantSpoke(_BoundNonantSpoke): +class InnerBoundNonantSpoke(_BoundNonantSpoke, InnerBoundSpoke): """ For Spokes that provide an inner (incumbent) bound through self.send_bound to the Hub, and receive the nonants from @@ -246,42 +284,11 @@ class InnerBoundNonantSpoke(_BoundNonantSpoke): and restoring results """ - send_fields = (*_BoundNonantSpoke.send_fields, Field.OBJECTIVE_INNER_BOUND, Field.BEST_XHAT, Field.RECENT_XHATS, ) - receive_fields = (*_BoundNonantSpoke.receive_fields, Field.NONANT) + send_fields = (*InnerBoundSpoke.send_fields, ) + receive_fields = (*InnerBoundSpoke.receive_fields, Field.NONANT) converger_spoke_char = 'I' - def __init__(self, spbase_object, fullcomm, strata_comm, cylinder_comm, options=None): - super().__init__(spbase_object, fullcomm, strata_comm, cylinder_comm, options) - self.is_minimizing = self.opt.is_minimizing - self.best_inner_bound = math.inf if self.is_minimizing else -math.inf - self.solver_options = None # can be overwritten by derived classes - - def update_if_improving(self, candidate_inner_bound, update_best_solution_cache=True): - if update_best_solution_cache: - update = self.opt.update_best_solution_if_improving(candidate_inner_bound) - else: - update = ( (candidate_inner_bound < self.best_inner_bound) - if self.is_minimizing else - (self.best_inner_bound < candidate_inner_bound) - ) - if update: - self.best_inner_bound = candidate_inner_bound - # send to hub - self.send_bound(candidate_inner_bound) - return True - return False - - def finalize(self): - if self.opt.load_best_solution(): - self.final_bound = self.bound - return self.final_bound - return None - - def bound_type(self) -> Field: - return Field.OBJECTIVE_INNER_BOUND - - class OuterBoundNonantSpoke(_BoundNonantSpoke): """ For Spokes that provide an outer From 0e938490dc815715789e0df3fa7047a1f9585ec9 Mon Sep 17 00:00:00 2001 From: Bernard Knueven Date: Tue, 17 Jun 2025 10:52:45 -0600 Subject: [PATCH 09/19] putting all recent solutions in buffer --- mpisppy/cylinders/spcommunicator.py | 40 +++++++++++++++++++++++++++++ mpisppy/cylinders/spoke.py | 23 ++++++++++++++++- mpisppy/spbase.py | 8 +++++- 3 files changed, 69 insertions(+), 2 deletions(-) diff --git a/mpisppy/cylinders/spcommunicator.py b/mpisppy/cylinders/spcommunicator.py index b22cd2099..d6bea9448 100644 --- a/mpisppy/cylinders/spcommunicator.py +++ b/mpisppy/cylinders/spcommunicator.py @@ -114,6 +114,46 @@ def _pull_id(self) -> int: return self._id +class _CircularBuffer: + + def __init__(self, data: FieldArray, field_length: int, buffer_size: int): + # last byte is the "write pointer" + assert len(data.value_array()) == field_length * buffer_size + self.data = data + self._field_length = field_length + self._buffer_size = buffer_size + + def _get_value_array(self, read_write_index): + position = read_write_index % self._buffer_size + return self.data._array[(position*self._field_length):((position+1)*self._field_length)] + + +class SendCircularBuffer(_CircularBuffer): + + def __init__(self, data: SendArray, field_length: int, buffer_size: int): + super().__init__(data, field_length, buffer_size) + + def next_value_array(self): + return self._get_value_array(self.data.id()) + + +class RecvCircularBuffer(_CircularBuffer): + + def __init__(self, data: RecvArray, field_length: int, buffer_size: int): + super().__init__(data, field_length, buffer_size) + self._read_id = 0 + + def next_value_arrays(self): + # if the writes have already "wrapped around" the buffer, + # we need to fast-forward the read index so we don't read + # the same data multiple times + while self.data.id() > self._read_id + self._buffer_size: + self._read_id += 1 + while self._read_id < self.data.id(): + yield self._get_value_array(self._read_id) + self._read_id += 1 + + class SPCommunicator: """ Base class for communicator objects. Each communicator object should register as a class attribute what Field attributes it provides in its buffer diff --git a/mpisppy/cylinders/spoke.py b/mpisppy/cylinders/spoke.py index a392fb6fd..26eb85657 100644 --- a/mpisppy/cylinders/spoke.py +++ b/mpisppy/cylinders/spoke.py @@ -12,7 +12,7 @@ import os import math -from mpisppy.cylinders.spcommunicator import SPCommunicator +from mpisppy.cylinders.spcommunicator import SPCommunicator, SendCircularBuffer from mpisppy.cylinders.spwindow import Field @@ -162,6 +162,14 @@ def __init__(self, spbase_object, fullcomm, strata_comm, cylinder_comm, options= self.best_inner_bound = math.inf if self.is_minimizing else -math.inf self.solver_options = None # can be overwritten by derived classes + def register_send_fields(self): + super().register_send_fields() + self._recent_xhat_send_circular_buffer = SendCircularBuffer( + self.send_buffers[Field.RECENT_XHATS], + self._field_lengths[Field.BEST_XHAT], + self._field_lengths[Field.RECENT_XHATS] // self._field_lengths[Field.BEST_XHAT], + ) + def update_if_improving(self, candidate_inner_bound, update_best_solution_cache=True): if update_best_solution_cache: update = self.opt.update_best_solution_if_improving(candidate_inner_bound) @@ -170,6 +178,7 @@ def update_if_improving(self, candidate_inner_bound, update_best_solution_cache= if self.is_minimizing else (self.best_inner_bound < candidate_inner_bound) ) + self.send_latest_xhat() if update: self.best_inner_bound = candidate_inner_bound # send to hub @@ -191,6 +200,18 @@ def send_best_xhat(self): # print(f"{self.cylinder_rank=} sending {best_xhat_buf.value_array()=}") self.put_send_buffer(best_xhat_buf, Field.BEST_XHAT) + def send_latest_xhat(self): + recent_xhat_buf = self._recent_xhat_send_circular_buffer.next_value_array() + ci = 0 + for s in self.opt.local_scenarios.values(): + for ndn_var in s._mpisppy_data.nonant_indices.values(): + recent_xhat_buf[ci] = s._mpisppy_data.latest_solution_cache[ndn_var] + ci += 1 + recent_xhat_buf[ci] = s._mpisppy_data.inner_bound + ci += 1 + # print(f"{self.cylinder_rank=} sending {recent_xhat_buf=}") + self.put_send_buffer(self._recent_xhat_send_circular_buffer.data, Field.RECENT_XHATS) + def finalize(self): if self.opt.load_best_solution(): self.final_bound = self.bound diff --git a/mpisppy/spbase.py b/mpisppy/spbase.py index c08feb411..f6c23da78 100644 --- a/mpisppy/spbase.py +++ b/mpisppy/spbase.py @@ -571,6 +571,8 @@ def update_best_solution_if_improving(self, obj_val): update = (obj_val < self.best_solution_obj_val) else: update = (self.best_solution_obj_val < obj_val) + if obj_val is not None: + self._cache_latest_solution() if update: self.best_solution_obj_val = obj_val self._cache_best_solution() @@ -578,11 +580,15 @@ def update_best_solution_if_improving(self, obj_val): return False def _cache_best_solution(self): + for k,s in self.local_scenarios.items(): + s._mpisppy_data.best_solution_cache = s._mpisppy_data.latest_solution_cache + + def _cache_latest_solution(self): for k,s in self.local_scenarios.items(): scenario_cache = pyo.ComponentMap() for var in s.component_data_objects(pyo.Var): scenario_cache[var] = var.value - s._mpisppy_data.best_solution_cache = scenario_cache + s._mpisppy_data.latest_solution_cache = scenario_cache def _get_cylinder_name(self): if self.spcomm: From 50050e8b3835308fad7b07d7fb3ad2cc710524bb Mon Sep 17 00:00:00 2001 From: Bernard Knueven Date: Tue, 17 Jun 2025 12:31:24 -0600 Subject: [PATCH 10/19] trying to fix performance issue --- mpisppy/cylinders/spoke.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mpisppy/cylinders/spoke.py b/mpisppy/cylinders/spoke.py index 26eb85657..d1168149e 100644 --- a/mpisppy/cylinders/spoke.py +++ b/mpisppy/cylinders/spoke.py @@ -193,7 +193,7 @@ def send_best_xhat(self): ci = 0 for s in self.opt.local_scenarios.values(): for ndn_var in s._mpisppy_data.nonant_indices.values(): - best_xhat_buf[ci] = s._mpisppy_data.best_solution_cache[ndn_var] + best_xhat_buf[ci] = ndn_var.value ci += 1 best_xhat_buf[ci] = s._mpisppy_data.inner_bound ci += 1 @@ -205,7 +205,7 @@ def send_latest_xhat(self): ci = 0 for s in self.opt.local_scenarios.values(): for ndn_var in s._mpisppy_data.nonant_indices.values(): - recent_xhat_buf[ci] = s._mpisppy_data.latest_solution_cache[ndn_var] + recent_xhat_buf[ci] = ndn_var.value ci += 1 recent_xhat_buf[ci] = s._mpisppy_data.inner_bound ci += 1 From 2cf70b2567e9d72071bc3053031e7002558e9373 Mon Sep 17 00:00:00 2001 From: Bernard Knueven Date: Tue, 17 Jun 2025 14:05:10 -0600 Subject: [PATCH 11/19] try to resolve performance, for real --- mpisppy/cylinders/spoke.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/mpisppy/cylinders/spoke.py b/mpisppy/cylinders/spoke.py index d1168149e..ca48d357c 100644 --- a/mpisppy/cylinders/spoke.py +++ b/mpisppy/cylinders/spoke.py @@ -192,8 +192,9 @@ def send_best_xhat(self): # NOTE: this does not work with "loose" bundles ci = 0 for s in self.opt.local_scenarios.values(): + solution_cache = s._mpisppy_data.best_solution_cache for ndn_var in s._mpisppy_data.nonant_indices.values(): - best_xhat_buf[ci] = ndn_var.value + best_xhat_buf[ci] = solution_cache._dict[id(ndn_var)][1] ci += 1 best_xhat_buf[ci] = s._mpisppy_data.inner_bound ci += 1 @@ -204,8 +205,9 @@ def send_latest_xhat(self): recent_xhat_buf = self._recent_xhat_send_circular_buffer.next_value_array() ci = 0 for s in self.opt.local_scenarios.values(): + solution_cache = s._mpisppy_data.latest_solution_cache for ndn_var in s._mpisppy_data.nonant_indices.values(): - recent_xhat_buf[ci] = ndn_var.value + recent_xhat_buf[ci] = solution_cache._dict[id(ndn_var)][1] ci += 1 recent_xhat_buf[ci] = s._mpisppy_data.inner_bound ci += 1 From 59593adc672de1bfba88fdadd039d6c9b25b76b2 Mon Sep 17 00:00:00 2001 From: Bernard Knueven Date: Tue, 17 Jun 2025 14:18:04 -0600 Subject: [PATCH 12/19] since we are misusing ComponentMap, might as well be efficient --- mpisppy/cylinders/spoke.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/mpisppy/cylinders/spoke.py b/mpisppy/cylinders/spoke.py index ca48d357c..5ff5d2a30 100644 --- a/mpisppy/cylinders/spoke.py +++ b/mpisppy/cylinders/spoke.py @@ -192,9 +192,9 @@ def send_best_xhat(self): # NOTE: this does not work with "loose" bundles ci = 0 for s in self.opt.local_scenarios.values(): - solution_cache = s._mpisppy_data.best_solution_cache - for ndn_var in s._mpisppy_data.nonant_indices.values(): - best_xhat_buf[ci] = solution_cache._dict[id(ndn_var)][1] + solution_cache = s._mpisppy_data.best_solution_cache._dict + for ndn_varid in s._mpisppy_data.varid_to_nonant_index: + best_xhat_buf[ci] = solution_cache[ndn_varid][1] ci += 1 best_xhat_buf[ci] = s._mpisppy_data.inner_bound ci += 1 @@ -205,9 +205,9 @@ def send_latest_xhat(self): recent_xhat_buf = self._recent_xhat_send_circular_buffer.next_value_array() ci = 0 for s in self.opt.local_scenarios.values(): - solution_cache = s._mpisppy_data.latest_solution_cache - for ndn_var in s._mpisppy_data.nonant_indices.values(): - recent_xhat_buf[ci] = solution_cache._dict[id(ndn_var)][1] + solution_cache = s._mpisppy_data.latest_solution_cache._dict + for ndn_varid in s._mpisppy_data.varid_to_nonant_index: + recent_xhat_buf[ci] = solution_cache[ndn_varid][1] ci += 1 recent_xhat_buf[ci] = s._mpisppy_data.inner_bound ci += 1 From 55bbcb01aa65e920ae62050fb97310ed52e9ef0b Mon Sep 17 00:00:00 2001 From: Bernard Knueven Date: Tue, 17 Jun 2025 15:08:41 -0600 Subject: [PATCH 13/19] faster componentmap writes --- .github/workflows/test_pr_and_main.yml | 4 ++++ mpisppy/spbase.py | 8 +++++-- mpisppy/tests/test_component_map_usage.py | 28 +++++++++++++++++++++++ 3 files changed, 38 insertions(+), 2 deletions(-) create mode 100644 mpisppy/tests/test_component_map_usage.py diff --git a/.github/workflows/test_pr_and_main.yml b/.github/workflows/test_pr_and_main.yml index 18fa3798d..3b07d9012 100644 --- a/.github/workflows/test_pr_and_main.yml +++ b/.github/workflows/test_pr_and_main.yml @@ -86,6 +86,10 @@ jobs: cd examples python afew.py xpress_persistent + - name: Test ComponentMap + run: | + python mpisppy/tests/test_component_map_usage.py + - name: Test docs run: | cd ./doc/src/ diff --git a/mpisppy/spbase.py b/mpisppy/spbase.py index f6c23da78..a22e20b4d 100644 --- a/mpisppy/spbase.py +++ b/mpisppy/spbase.py @@ -586,8 +586,7 @@ def _cache_best_solution(self): def _cache_latest_solution(self): for k,s in self.local_scenarios.items(): scenario_cache = pyo.ComponentMap() - for var in s.component_data_objects(pyo.Var): - scenario_cache[var] = var.value + _put_var_vals_in_component_map_dict(scenario_cache._dict, s) s._mpisppy_data.latest_solution_cache = scenario_cache def _get_cylinder_name(self): @@ -747,3 +746,8 @@ def write_tree_solution(self, directory_name, self.mpicomm.Barrier() for scenario_name, scenario in self.local_scenarios.items(): scenario_tree_solution_writer(directory_name, scenario_name, scenario, self.bundling) + + +def _put_var_vals_in_component_map_dict(sn_cache_dict, model): + for var in model.component_data_objects(pyo.Var): + sn_cache_dict[id(var)] = (var, var.value) diff --git a/mpisppy/tests/test_component_map_usage.py b/mpisppy/tests/test_component_map_usage.py new file mode 100644 index 000000000..8ef67e247 --- /dev/null +++ b/mpisppy/tests/test_component_map_usage.py @@ -0,0 +1,28 @@ +############################################################################### +# mpi-sppy: MPI-based Stochastic Programming in PYthon +# +# Copyright (c) 2024, Lawrence Livermore National Security, LLC, Alliance for +# Sustainable Energy, LLC, The Regents of the University of California, et al. +# All rights reserved. Please see the files COPYRIGHT.md and LICENSE.md for +# full copyright and license information. +############################################################################### + +import pyomo.environ as pyo + +from mpisppy.spbase import _put_var_vals_in_component_map_dict + +def test_component_map_usage(): + m = pyo.ConcreteModel() + m.x = pyo.Var([1,2], initialize=2) + m.y = pyo.Var(["a", "b"], initialize=5) + m.z = pyo.Var(initialize=42) + + cmap = pyo.ComponentMap() + + _put_var_vals_in_component_map_dict(cmap._dict, m) + + assert cmap[m.x[1]] == 2 + assert cmap[m.x[2]] == 2 + assert cmap[m.y["a"]] == 5 + assert cmap[m.y["b"]] == 5 + assert cmap[m.z] == 42 From a5d030f2ac8c91efb3c8534590d648ad3e9595f3 Mon Sep 17 00:00:00 2001 From: Bernard Knueven Date: Tue, 17 Jun 2025 15:18:21 -0600 Subject: [PATCH 14/19] actually test ... --- .github/workflows/test_pr_and_main.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/test_pr_and_main.yml b/.github/workflows/test_pr_and_main.yml index 3b07d9012..469e758c3 100644 --- a/.github/workflows/test_pr_and_main.yml +++ b/.github/workflows/test_pr_and_main.yml @@ -88,7 +88,7 @@ jobs: - name: Test ComponentMap run: | - python mpisppy/tests/test_component_map_usage.py + pytest mpisppy/tests/test_component_map_usage.py - name: Test docs run: | From f8d59026c75a26a2ade75be489b9afdc7a0a4f71 Mon Sep 17 00:00:00 2001 From: bknueven <30801372+bknueven@users.noreply.github.com> Date: Wed, 18 Jun 2025 08:37:02 -0600 Subject: [PATCH 15/19] Use startswith Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- mpisppy/cylinders/spwindow.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mpisppy/cylinders/spwindow.py b/mpisppy/cylinders/spwindow.py index 3e2b915fe..4881649f3 100644 --- a/mpisppy/cylinders/spwindow.py +++ b/mpisppy/cylinders/spwindow.py @@ -83,7 +83,7 @@ def __init__(self, opt): for p in field_length_components.component_data_objects(): # leave user-set parameter alone, just clear the # "private" parameters - if p.name[0] == "_": + if p.name.startswith("_"): p.clear() def __getitem__(self, field: Field): From 53ffc5f6b7e1b3ed7ff6a547901fd7ccb6d72de2 Mon Sep 17 00:00:00 2001 From: Bernard Knueven Date: Fri, 20 Jun 2025 15:25:57 -0600 Subject: [PATCH 16/19] next_value_array -> next_value_array_reference --- mpisppy/cylinders/spcommunicator.py | 2 +- mpisppy/cylinders/spoke.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/mpisppy/cylinders/spcommunicator.py b/mpisppy/cylinders/spcommunicator.py index d6bea9448..a29f7c89f 100644 --- a/mpisppy/cylinders/spcommunicator.py +++ b/mpisppy/cylinders/spcommunicator.py @@ -133,7 +133,7 @@ class SendCircularBuffer(_CircularBuffer): def __init__(self, data: SendArray, field_length: int, buffer_size: int): super().__init__(data, field_length, buffer_size) - def next_value_array(self): + def next_value_array_reference(self): return self._get_value_array(self.data.id()) diff --git a/mpisppy/cylinders/spoke.py b/mpisppy/cylinders/spoke.py index 5ff5d2a30..91801a542 100644 --- a/mpisppy/cylinders/spoke.py +++ b/mpisppy/cylinders/spoke.py @@ -202,7 +202,7 @@ def send_best_xhat(self): self.put_send_buffer(best_xhat_buf, Field.BEST_XHAT) def send_latest_xhat(self): - recent_xhat_buf = self._recent_xhat_send_circular_buffer.next_value_array() + recent_xhat_buf = self._recent_xhat_send_circular_buffer.next_value_array_reference() ci = 0 for s in self.opt.local_scenarios.values(): solution_cache = s._mpisppy_data.latest_solution_cache._dict From 9e607d069a71c4efc20e52e5ffb09ee6ee3b8e2d Mon Sep 17 00:00:00 2001 From: Bernard Knueven Date: Fri, 20 Jun 2025 15:40:29 -0600 Subject: [PATCH 17/19] remove unneeded __init__, add comments / documentation --- mpisppy/cylinders/spcommunicator.py | 25 ++++++++++++++++++++++--- 1 file changed, 22 insertions(+), 3 deletions(-) diff --git a/mpisppy/cylinders/spcommunicator.py b/mpisppy/cylinders/spcommunicator.py index a29f7c89f..affde2f17 100644 --- a/mpisppy/cylinders/spcommunicator.py +++ b/mpisppy/cylinders/spcommunicator.py @@ -115,6 +115,19 @@ def _pull_id(self) -> int: class _CircularBuffer: + """ + The circular buffer is meant for holding several versions of a Field + (defined by the `buffer_size`). The `data` object is an instance of + `FieldArray`. + + To know where in the buffer we are, we use the FieldArray._id. The layout + looks like this for a `buffer_size` of 4: + + |--0--|--1--|--2--|--3--|id| + + The id % buffer_size tells us which data point is the most recent, such + that individual ids are not needed for each instance. + """ def __init__(self, data: FieldArray, field_length: int, buffer_size: int): # last byte is the "write pointer" @@ -130,14 +143,20 @@ def _get_value_array(self, read_write_index): class SendCircularBuffer(_CircularBuffer): - def __init__(self, data: SendArray, field_length: int, buffer_size: int): - super().__init__(data, field_length, buffer_size) - def next_value_array_reference(self): + # NOTE: The id gets incremented in the call + # to `put_send_buffer`, which is necessarily + # called *after* this method. Therefore + # we start at 0 and go up, and when sent + # will be the id of the next *open* position return self._get_value_array(self.data.id()) class RecvCircularBuffer(_CircularBuffer): + # The _read_id tells us where we last read from, and the + # (data.id % buffer_size) - 1 + # has the last place written to. Therefore, we know which + # items in the buffer are new based on their difference. def __init__(self, data: RecvArray, field_length: int, buffer_size: int): super().__init__(data, field_length, buffer_size) From 3ddaa13bd1e9dbca0779782b143fe1d5a3ba15ac Mon Sep 17 00:00:00 2001 From: Bernard Knueven Date: Fri, 20 Jun 2025 15:41:09 -0600 Subject: [PATCH 18/19] next_value_arrays -> most_recent_value_arrays --- mpisppy/cylinders/spcommunicator.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mpisppy/cylinders/spcommunicator.py b/mpisppy/cylinders/spcommunicator.py index affde2f17..081d54577 100644 --- a/mpisppy/cylinders/spcommunicator.py +++ b/mpisppy/cylinders/spcommunicator.py @@ -162,7 +162,7 @@ def __init__(self, data: RecvArray, field_length: int, buffer_size: int): super().__init__(data, field_length, buffer_size) self._read_id = 0 - def next_value_arrays(self): + def most_recent_value_arrays(self): # if the writes have already "wrapped around" the buffer, # we need to fast-forward the read index so we don't read # the same data multiple times From d7552f2e4ff5f21a327b7f8052ffe3cab612073d Mon Sep 17 00:00:00 2001 From: Bernard Knueven Date: Fri, 20 Jun 2025 15:54:06 -0600 Subject: [PATCH 19/19] only store nonants for each solution --- mpisppy/cylinders/spoke.py | 8 ++++---- mpisppy/spbase.py | 25 ++++++++++++++--------- mpisppy/tests/test_component_map_usage.py | 2 +- 3 files changed, 20 insertions(+), 15 deletions(-) diff --git a/mpisppy/cylinders/spoke.py b/mpisppy/cylinders/spoke.py index 91801a542..5a83b0b0f 100644 --- a/mpisppy/cylinders/spoke.py +++ b/mpisppy/cylinders/spoke.py @@ -205,10 +205,10 @@ def send_latest_xhat(self): recent_xhat_buf = self._recent_xhat_send_circular_buffer.next_value_array_reference() ci = 0 for s in self.opt.local_scenarios.values(): - solution_cache = s._mpisppy_data.latest_solution_cache._dict - for ndn_varid in s._mpisppy_data.varid_to_nonant_index: - recent_xhat_buf[ci] = solution_cache[ndn_varid][1] - ci += 1 + solution_cache = s._mpisppy_data.latest_nonant_solution_cache + len_nonants = len(s._mpisppy_data.nonant_indices) + recent_xhat_buf[ci:ci+len_nonants] = solution_cache[:] + ci += len_nonants recent_xhat_buf[ci] = s._mpisppy_data.inner_bound ci += 1 # print(f"{self.cylinder_rank=} sending {recent_xhat_buf=}") diff --git a/mpisppy/spbase.py b/mpisppy/spbase.py index a22e20b4d..3dcabce95 100644 --- a/mpisppy/spbase.py +++ b/mpisppy/spbase.py @@ -130,7 +130,7 @@ def __init__( self._verify_nonant_lengths() self._set_sense() self._use_variable_probability_setter() - self._set_best_solution_cache() + self._set_solution_cache() ## SPCommunicator object self._spcomm = None @@ -553,10 +553,11 @@ def _options_check(self, required_options, given_options): if missing: raise ValueError(f"Missing the following required options: {', '.join(missing)}") - def _set_best_solution_cache(self): + def _set_solution_cache(self): # set up best solution cache for k,s in self.local_scenarios.items(): s._mpisppy_data.best_solution_cache = None + s._mpisppy_data.latest_nonant_solution_cache = np.full(len(s._mpisppy_data.nonant_indices), np.nan) def update_best_solution_if_improving(self, obj_val): """ Call if the variable values have a nonanticipative solution @@ -572,7 +573,7 @@ def update_best_solution_if_improving(self, obj_val): else: update = (self.best_solution_obj_val < obj_val) if obj_val is not None: - self._cache_latest_solution() + self._cache_latest_solution_nonants() if update: self.best_solution_obj_val = obj_val self._cache_best_solution() @@ -581,13 +582,17 @@ def update_best_solution_if_improving(self, obj_val): def _cache_best_solution(self): for k,s in self.local_scenarios.items(): - s._mpisppy_data.best_solution_cache = s._mpisppy_data.latest_solution_cache + scenario_cache = pyo.ComponentMap() + _put_var_vals_in_component_map_dict( + scenario_cache._dict, + s.component_data_objects(pyo.Var) + ) + s._mpisppy_data.best_solution_cache = scenario_cache - def _cache_latest_solution(self): + def _cache_latest_solution_nonants(self): for k,s in self.local_scenarios.items(): - scenario_cache = pyo.ComponentMap() - _put_var_vals_in_component_map_dict(scenario_cache._dict, s) - s._mpisppy_data.latest_solution_cache = scenario_cache + for idx, v in enumerate(s._mpisppy_data.nonant_indices.values()): + s._mpisppy_data.latest_nonant_solution_cache[idx] = v.value def _get_cylinder_name(self): if self.spcomm: @@ -748,6 +753,6 @@ def write_tree_solution(self, directory_name, scenario_tree_solution_writer(directory_name, scenario_name, scenario, self.bundling) -def _put_var_vals_in_component_map_dict(sn_cache_dict, model): - for var in model.component_data_objects(pyo.Var): +def _put_var_vals_in_component_map_dict(sn_cache_dict, var_iter): + for var in var_iter: sn_cache_dict[id(var)] = (var, var.value) diff --git a/mpisppy/tests/test_component_map_usage.py b/mpisppy/tests/test_component_map_usage.py index 8ef67e247..d361b8503 100644 --- a/mpisppy/tests/test_component_map_usage.py +++ b/mpisppy/tests/test_component_map_usage.py @@ -19,7 +19,7 @@ def test_component_map_usage(): cmap = pyo.ComponentMap() - _put_var_vals_in_component_map_dict(cmap._dict, m) + _put_var_vals_in_component_map_dict(cmap._dict, m.component_data_objects(pyo.Var)) assert cmap[m.x[1]] == 2 assert cmap[m.x[2]] == 2