1111from mpisppy import MPI
1212from mpisppy .utils .lshaped_cuts import LShapedCutGenerator
1313from mpisppy .cylinders .spwindow import Field
14+ from mpisppy .cylinders .spoke import Spoke
1415
1516import numpy as np
1617import pyomo .environ as pyo
17- import mpisppy .cylinders .spoke as spoke
1818
19- class CrossScenarioCutSpoke (spoke .Spoke ):
20- def __init__ (self , spbase_object , fullcomm , strata_comm , cylinder_comm , options = None ):
21- super ().__init__ (spbase_object , fullcomm , strata_comm , cylinder_comm , options = options )
19+ class CrossScenarioCutSpoke (Spoke ):
20+
21+ send_fields = (* Spoke .send_fields , Field .CROSS_SCENARIO_CUT )
22+ receive_fields = (* Spoke .receive_fields , Field .NONANT , Field .CROSS_SCENARIO_COST )
2223
2324 def register_send_fields (self ) -> None :
2425
@@ -37,18 +38,28 @@ def register_send_fields(self) -> None:
3738 (self .nonant_per_scen , remainder ) = divmod (vbuflen , local_scen_count )
3839 assert (remainder == 0 )
3940
40- ## the _locals will also have the kill signal
4141 self .all_nonant_len = vbuflen
4242 self .all_eta_len = nscen * local_scen_count
4343
44- self .all_nonants = self .register_recv_field (Field .NONANT , 0 , vbuflen )
45- self .all_etas = self .register_recv_field (Field .CROSS_SCENARIO_COST , 0 , nscen * nscen )
4644
47- self .all_coefs = self .register_send_field (Field .CROSS_SCENARIO_CUT ,
48- nscen * (self .nonant_per_scen + 1 + 1 ))
45+ self .all_coefs = self .send_buffers [Field .CROSS_SCENARIO_CUT ]
4946
5047 return
5148
49+ def register_receive_fields (self ):
50+ super ().register_receive_fields ()
51+
52+ nonant_ranks = self .opt .spcomm .fields_to_ranks [Field .NONANT ]
53+ cs_cost_ranks = self .opt .spcomm .fields_to_ranks [Field .CROSS_SCENARIO_COST ]
54+
55+ assert len (nonant_ranks ) == 1
56+ assert len (cs_cost_ranks ) == 1
57+ assert nonant_ranks [0 ] == cs_cost_ranks [0 ]
58+ source_rank = nonant_ranks [0 ]
59+
60+ self .all_nonants = self .register_recv_field (Field .NONANT , source_rank )
61+ self .all_etas = self .register_recv_field (Field .CROSS_SCENARIO_COST , source_rank )
62+
5263 def prep_cs_cuts (self ):
5364 # create a map scenario -> index, this index is used for various lists containing scenario dependent info.
5465 self .scenario_to_index = { scen : indx for indx , scen in enumerate (self .opt .all_scenario_names ) }
@@ -135,7 +146,7 @@ def make_eta_lb_cut(self):
135146 ## this cut -- [ LB, -1, *0s ], i.e., -1*\eta + LB <= 0
136147 all_coefs [row_len * idx ] = self ._eta_lb_array [idx ]
137148 all_coefs [row_len * idx + 1 ] = - 1
138- self .spoke_to_hub (all_coefs , Field .CROSS_SCENARIO_CUT )
149+ self .put_send_buffer (all_coefs , Field .CROSS_SCENARIO_CUT )
139150
140151 def make_cut (self ):
141152
@@ -293,7 +304,7 @@ def make_cut(self):
293304 all_coefs [row_len * idx :row_len * (idx + 1 )] = coef_dict [k ]
294305 elif feas_cuts :
295306 all_coefs [row_len * idx :row_len * (idx + 1 )] = feas_cuts .pop ()
296- self .spoke_to_hub (all_coefs , Field .CROSS_SCENARIO_CUT )
307+ self .put_send_buffer (all_coefs , Field .CROSS_SCENARIO_CUT )
297308
298309 def main (self ):
299310 # call main cut generation routine
@@ -303,7 +314,6 @@ def main(self):
303314
304315 # main loop
305316 while not (self .got_kill_signal ()):
306- # if self._new_locals:
307317 if self .all_nonants .is_new () and self .all_etas .is_new ():
308318 self .make_cut ()
309319 ## End if
0 commit comments