3030
3131class Hub (SPCommunicator ):
3232
33+ send_fields = (* SPCommunicator .send_fields , Field .SHUTDOWN , Field .BEST_OBJECTIVE_BOUNDS ,)
34+ receive_fields = (* SPCommunicator .receive_fields , )
35+ optional_receive_fields = (* SPCommunicator .optional_receive_fields , Field .OBJECTIVE_INNER_BOUND , Field .OBJECTIVE_OUTER_BOUND , )
36+
3337 _hub_algo_best_bound_provider = False
3438
3539 def __init__ (self , spbase_object , fullcomm , strata_comm , cylinder_comm , communicators , options = None ):
@@ -85,7 +89,7 @@ def register_extension_recv_field(self, field: Field, strata_rank: int, buf_len:
8589 to the extension sync_with_spokes function.
8690 """
8791 key = self ._make_key (field , strata_rank )
88- if key not in self ._locals :
92+ if key not in self .receive_buffers :
8993 # if it is not already registered, we need to update the local buffer
9094 self .extension_recv .add (key )
9195 ## End if
@@ -103,7 +107,7 @@ def register_extension_send_field(self, field: Field, buf_len: int) -> SendArray
103107 return self .register_send_field (field , buf_len )
104108
105109 def is_send_field_registered (self , field : Field ) -> bool :
106- return field in self ._sends
110+ return field in self .send_buffers
107111
108112 def extension_send_field (self , field : Field , buf : SendArray ):
109113 """
@@ -117,7 +121,7 @@ def sync_extension_fields(self):
117121 Update all registered extension fields. Safe to call even when there are no extension fields.
118122 """
119123 for key in self .extension_recv :
120- ext_buf = self ._locals [key ]
124+ ext_buf = self .receive_buffers [key ]
121125 (field , srank ) = self ._split_key (key )
122126 ext_buf ._is_new = self .hub_from_spoke (ext_buf , srank , field )
123127 ## End for
@@ -233,7 +237,7 @@ def receive_innerbounds(self):
233237 logging .debug ("Hub is trying to receive from InnerBounds" )
234238 for idx in self .innerbound_spoke_indices :
235239 key = self ._make_key (Field .OBJECTIVE_INNER_BOUND , idx )
236- recv_buf = self ._locals [key ]
240+ recv_buf = self .receive_buffers [key ]
237241 is_new = self .hub_from_spoke (recv_buf , idx , Field .OBJECTIVE_INNER_BOUND )
238242 if is_new :
239243 bound = recv_buf [0 ]
@@ -249,7 +253,7 @@ def receive_outerbounds(self):
249253 logging .debug ("Hub is trying to receive from OuterBounds" )
250254 for idx in self .outerbound_spoke_indices :
251255 key = self ._make_key (Field .OBJECTIVE_OUTER_BOUND , idx )
252- recv_buf = self ._locals [key ]
256+ recv_buf = self .receive_buffers [key ]
253257 is_new = self .hub_from_spoke (recv_buf , idx , Field .OBJECTIVE_OUTER_BOUND )
254258 if is_new :
255259 bound = recv_buf [0 ]
@@ -320,18 +324,18 @@ def initialize_inner_bound_buffers(self):
320324 def _populate_boundsout_cache (self , buf ):
321325 """ Populate a given buffer with the current bounds
322326 """
323- buf [- 3 ] = self .BestOuterBound
324- buf [- 2 ] = self .BestInnerBound
327+ buf [0 ] = self .BestOuterBound
328+ buf [1 ] = self .BestInnerBound
325329
326330 def send_boundsout (self ):
327331 """ Send bounds to the appropriate spokes
328332 This is called only for spokes which are bounds only.
329333 w and nonant spokes are passed bounds through the w and nonant buffers
330334 """
331- my_bounds = self .boundsout_send_buffer
335+ my_bounds = self .send_buffers [ Field . BEST_OBJECTIVE_BOUNDS ]
332336 self ._populate_boundsout_cache (my_bounds .array ())
333337 logging .debug ("hub is sending bounds={}" .format (my_bounds ))
334- self .hub_to_spoke (my_bounds , Field .OBJECTIVE_BOUNDS )
338+ self .hub_to_spoke (my_bounds , Field .BEST_OBJECTIVE_BOUNDS )
335339 return
336340
337341 def initialize_spoke_indices (self ):
@@ -392,45 +396,7 @@ def initialize_spoke_indices(self):
392396
393397
394398 def register_send_fields (self ):
395-
396- self .shutdown = self .register_send_field (Field .SHUTDOWN , 1 )
397-
398- required_fields = set ()
399- for i , spoke in enumerate (self .communicators ):
400- if i == self .strata_rank :
401- continue
402- spoke_class = spoke ["spcomm_class" ]
403- if hasattr (spoke_class , "converger_spoke_types" ):
404- for cst in spoke_class .converger_spoke_types :
405- if cst == ConvergerSpokeType .W_GETTER :
406- required_fields .add (Field .DUALS )
407- elif cst == ConvergerSpokeType .NONANT_GETTER :
408- required_fields .add (Field .NONANT )
409- elif cst == ConvergerSpokeType .INNER_BOUND or cst == ConvergerSpokeType .OUTER_BOUND :
410- required_fields .add (Field .OBJECTIVE_BOUNDS )
411- else :
412- pass # Intentional no-op
413- ## End if
414- ## End for
415- else :
416- # Intentional no-op. Non-converger spokes need to register any needed
417- # fields separately. See the functions `register_extension_recv_field`
418- # and `register_extension_send_field`.
419- pass
420- ## End if
421- ## End for
422-
423- n_nonants = 0
424- for s in self .opt .local_scenarios .values ():
425- n_nonants += len (s ._mpisppy_data .nonant_indices )
426- ## End for
427-
428- if Field .DUALS in required_fields :
429- self .w_send_buffer = self .register_send_field (Field .DUALS , n_nonants )
430- if Field .NONANT in required_fields :
431- self .nonant_send_buffer = self .register_send_field (Field .NONANT , n_nonants )
432- if Field .OBJECTIVE_BOUNDS in required_fields :
433- self .boundsout_send_buffer = self .register_send_field (Field .OBJECTIVE_BOUNDS , 2 )
399+ super ().register_send_fields ()
434400
435401 # Not all opt classes may have extensions
436402 if getattr (self .opt , "extensions" , None ) is not None :
@@ -439,7 +405,6 @@ def register_send_fields(self):
439405 return
440406
441407
442-
443408 def hub_to_spoke (self , buf : SendArray , field : Field ):
444409 """ Put the specified values into the specified locally-owned buffer
445410 for the spoke to pick up.
@@ -534,13 +499,17 @@ def send_terminate(self):
534499 buffer, so every spoke will see it simultaneously.
535500 processes (don't need to call them one at a time).
536501 """
537- shutdown = self .shutdown
538- shutdown [0 ] = 1.0
539- self .hub_to_spoke (shutdown , Field .SHUTDOWN )
502+ self .send_buffers [Field .SHUTDOWN ][0 ] = 1.0
503+ self .hub_to_spoke (self .send_buffers [Field .SHUTDOWN ], Field .SHUTDOWN )
540504 return
541505
542506
543507class PHHub (Hub ):
508+
509+ send_fields = (* Hub .send_fields , Field .NONANT , Field .DUALS )
510+ receive_fields = (* Hub .receive_fields ,)
511+ optional_receive_fields = (* Hub .optional_receive_fields ,)
512+
544513 def setup_hub (self ):
545514 """ Must be called after make_windows(), so that
546515 the hub knows the sizes of all the spokes windows
@@ -673,8 +642,7 @@ def send_nonants(self):
673642 """
674643 self .opt ._save_nonants ()
675644 ci = 0 ## index to self.nonant_send_buffer
676- # my_nonants = self._sends[Field.NONANT]
677- nonant_send_buffer = self .nonant_send_buffer
645+ nonant_send_buffer = self .send_buffers [Field .NONANT ]
678646 for k , s in self .opt .local_scenarios .items ():
679647 for xvar in s ._mpisppy_data .nonant_indices .values ():
680648 nonant_send_buffer [ci ] = xvar ._value
@@ -690,7 +658,7 @@ def send_ws(self):
690658 """ Send dual weights to the appropriate spokes
691659 """
692660 # NOTE: my_ws.array() and self.w_send_buffer should be the same array.
693- my_ws = self ._sends [Field .DUALS ]
661+ my_ws = self .send_buffers [Field .DUALS ]
694662 self .opt ._populate_W_cache (my_ws .array (), padding = 1 )
695663 logging .debug ("hub is sending Ws={}" .format (my_ws .array ()))
696664
@@ -701,6 +669,10 @@ def send_ws(self):
701669
702670class LShapedHub (Hub ):
703671
672+ send_fields = (* Hub .send_fields , Field .NONANT ,)
673+ receive_fields = (* Hub .receive_fields ,)
674+ optional_receive_fields = (* Hub .optional_receive_fields ,)
675+
704676 def setup_hub (self ):
705677 """ Must be called after make_windows(), so that
706678 the hub knows the sizes of all the spokes windows
@@ -781,7 +753,7 @@ def send_nonants(self):
781753 TODO: Will likely fail with bundling
782754 """
783755 ci = 0 ## index to self.nonant_send_buffer
784- nonant_send_buffer = self .nonant_send_buffer
756+ nonant_send_buffer = self .send_buffers [ Field . NONANT ]
785757 for k , s in self .opt .local_scenarios .items ():
786758 nonant_to_root_var_map = s ._mpisppy_model .subproblem_to_root_vars_map
787759 for xvar in s ._mpisppy_data .nonant_indices .values ():
@@ -797,6 +769,8 @@ def send_nonants(self):
797769
798770class SubgradientHub (PHHub ):
799771
772+ # send / receive fields are same as PHHub
773+
800774 _hub_algo_best_bound_provider = True
801775
802776 def main (self ):
@@ -806,6 +780,8 @@ def main(self):
806780
807781class APHHub (PHHub ):
808782
783+ # send / receive fields are same as PHHub
784+
809785 def main (self ):
810786 """ SPComm gets attached by self.__init___; holding APH harmless """
811787 logger .critical ("aph debug main in hub.py" )
0 commit comments