1616from mpisppy import MPI
1717from mpisppy .cylinders .spcommunicator import RecvArray , SendArray , SPCommunicator
1818from math import inf
19- from mpisppy .cylinders .spoke import ConvergerSpokeType
2019
2120from mpisppy import global_toc
2221
@@ -51,6 +50,8 @@ def __init__(self, spbase_object, fullcomm, strata_comm, cylinder_comm, communic
5150
5251 self .extension_recv = set ()
5352
53+ self .initialize_bound_values ()
54+
5455 return
5556
5657 @abc .abstractmethod
@@ -233,14 +234,12 @@ def receive_innerbounds(self):
233234 (but should be harmless to call if there are none)
234235 """
235236 logging .debug ("Hub is trying to receive from InnerBounds" )
236- for idx in self .innerbound_spoke_indices :
237- key = self ._make_key (Field .OBJECTIVE_INNER_BOUND , idx )
238- recv_buf = self .receive_buffers [key ]
237+ for idx , cls , recv_buf in self .receive_field_spcomms [Field .OBJECTIVE_INNER_BOUND ]:
239238 is_new = self .hub_from_spoke (recv_buf , idx , Field .OBJECTIVE_INNER_BOUND )
240239 if is_new :
241240 bound = recv_buf [0 ]
242241 logging .debug ("!! new InnerBound to opt {}" .format (bound ))
243- self .BestInnerBound = self .InnerBoundUpdate (bound , idx )
242+ self .BestInnerBound = self .InnerBoundUpdate (bound , cls , idx )
244243 logging .debug ("ph back from InnerBounds" )
245244
246245 def receive_outerbounds (self ):
@@ -249,37 +248,35 @@ def receive_outerbounds(self):
249248 (but should be harmless to call if there are none)
250249 """
251250 logging .debug ("Hub is trying to receive from OuterBounds" )
252- for idx in self .outerbound_spoke_indices :
253- key = self ._make_key (Field .OBJECTIVE_OUTER_BOUND , idx )
254- recv_buf = self .receive_buffers [key ]
251+ for idx , cls , recv_buf in self .receive_field_spcomms [Field .OBJECTIVE_OUTER_BOUND ]:
255252 is_new = self .hub_from_spoke (recv_buf , idx , Field .OBJECTIVE_OUTER_BOUND )
256253 if is_new :
257254 bound = recv_buf [0 ]
258255 logging .debug ("!! new OuterBound to opt {}" .format (bound ))
259- self .BestOuterBound = self .OuterBoundUpdate (bound , idx )
256+ self .BestOuterBound = self .OuterBoundUpdate (bound , cls , idx )
260257 logging .debug ("ph back from OuterBounds" )
261258
262- def OuterBoundUpdate (self , new_bound , idx = None , char = '*' ):
259+ def OuterBoundUpdate (self , new_bound , cls = None , idx = None , char = '*' ):
263260 current_bound = self .BestOuterBound
264261 if self ._outer_bound_update (new_bound , current_bound ):
265- if idx is None :
262+ if cls is None :
266263 self .latest_ob_char = char
267264 self .last_ob_idx = 0
268265 else :
269- self .latest_ob_char = self . outerbound_spoke_chars [ idx ]
266+ self .latest_ib_char = cls . converger_spoke_char
270267 self .last_ob_idx = idx
271268 return new_bound
272269 else :
273270 return current_bound
274271
275- def InnerBoundUpdate (self , new_bound , idx = None , char = '*' ):
272+ def InnerBoundUpdate (self , new_bound , cls = None , idx = None , char = '*' ):
276273 current_bound = self .BestInnerBound
277274 if self ._inner_bound_update (new_bound , current_bound ):
278- if idx is None :
275+ if cls is None :
279276 self .latest_ib_char = char
280277 self .last_ib_idx = 0
281278 else :
282- self .latest_ib_char = self . innerbound_spoke_chars [ idx ]
279+ self .latest_ib_char = cls . converger_spoke_char
283280 self .last_ib_idx = idx
284281 return new_bound
285282 else :
@@ -297,28 +294,6 @@ def initialize_bound_values(self):
297294 self ._inner_bound_update = lambda new , old : (new > old )
298295 self ._outer_bound_update = lambda new , old : (new < old )
299296
300- def initialize_outer_bound_buffers (self ):
301- """ Initialize outer bound receive buffers
302- """
303- self .outerbound_receive_buffers = dict ()
304- for idx in self .outerbound_spoke_indices :
305- self .outerbound_receive_buffers [idx ] = self .register_recv_field (
306- Field .OBJECTIVE_OUTER_BOUND , idx , 1 ,
307- )
308- ## End for
309- return
310-
311- def initialize_inner_bound_buffers (self ):
312- """ Initialize inner bound receive buffers
313- """
314- self .innerbound_receive_buffers = dict ()
315- for idx in self .innerbound_spoke_indices :
316- self .innerbound_receive_buffers [idx ] = self .register_recv_field (
317- Field .OBJECTIVE_INNER_BOUND , idx , 1
318- )
319- ## End for
320- return
321-
322297 def _populate_boundsout_cache (self , buf ):
323298 """ Populate a given buffer with the current bounds
324299 """
@@ -327,62 +302,26 @@ def _populate_boundsout_cache(self, buf):
327302
328303 def send_boundsout (self ):
329304 """ Send bounds to the appropriate spokes
330- This is called only for spokes which are bounds only.
331- w and nonant spokes are passed bounds through the w and nonant buffers
332305 """
333306 my_bounds = self .send_buffers [Field .BEST_OBJECTIVE_BOUNDS ]
334307 self ._populate_boundsout_cache (my_bounds .array ())
335308 logging .debug ("hub is sending bounds={}" .format (my_bounds ))
336309 self .hub_to_spoke (my_bounds , Field .BEST_OBJECTIVE_BOUNDS )
337310 return
338311
339- def initialize_spoke_indices (self ):
312+ def register_receive_fields (self ):
340313 """ Figure out what types of spokes we have,
341314 and sort them into the appropriate classes.
342315
343316 Note:
344317 Some spokes may be multiple types (e.g. outerbound and nonant),
345318 though not all combinations are supported.
346319 """
347- self .outerbound_spoke_indices = set ()
348- self .innerbound_spoke_indices = set ()
349- self .nonant_spoke_indices = set ()
350- self .w_spoke_indices = set ()
351-
352- self .outerbound_spoke_chars = dict ()
353- self .innerbound_spoke_chars = dict ()
354-
355- for (i , spoke ) in enumerate (self .communicators ):
356- if i == self .strata_rank :
357- continue
358- spoke_class = spoke ["spcomm_class" ]
359- if hasattr (spoke_class , "converger_spoke_types" ):
360- for cst in spoke_class .converger_spoke_types :
361- if cst == ConvergerSpokeType .OUTER_BOUND :
362- self .outerbound_spoke_indices .add (i )
363- self .outerbound_spoke_chars [i ] = spoke_class .converger_spoke_char
364- elif cst == ConvergerSpokeType .INNER_BOUND :
365- self .innerbound_spoke_indices .add (i )
366- self .innerbound_spoke_chars [i ] = spoke_class .converger_spoke_char
367- elif cst == ConvergerSpokeType .W_GETTER :
368- self .w_spoke_indices .add (i )
369- elif cst == ConvergerSpokeType .NONANT_GETTER :
370- self .nonant_spoke_indices .add (i )
371- else :
372- raise RuntimeError (f"Unrecognized converger_spoke_type { cst } " )
373-
374- else : ##this isn't necessarily wrong, i.e., cut generators
375- logger .debug (f"Spoke class { spoke_class } not recognized by hub" )
376-
377- # all _BoundSpoke spokes get hub bounds so we determine which spokes
378- # are "bounds only"
379- self .bounds_only_indices = \
380- (self .outerbound_spoke_indices | self .innerbound_spoke_indices ) - \
381- (self .w_spoke_indices | self .nonant_spoke_indices )
320+ super ().register_receive_fields ()
382321
383322 # Not all opt classes may have extensions
384323 if getattr (self .opt , "extensions" , None ) is not None :
385- self .opt .extobject .initialize_spoke_indices ()
324+ self .opt .extobject .register_receive_fields ()
386325
387326 return
388327
@@ -511,31 +450,14 @@ def setup_hub(self):
511450 "Cannot call setup_hub before memory windows are constructed"
512451 )
513452
514- self .initialize_spoke_indices ()
515- self .initialize_bound_values ()
516-
517- self .initialize_outer_bound_buffers ()
518- self .initialize_inner_bound_buffers ()
519-
520- ## Do some checking for things we currently don't support
521- if len (self .outerbound_spoke_indices & self .innerbound_spoke_indices ) > 0 :
522- raise RuntimeError (
523- "A Spoke providing both inner and outer "
524- "bounds is currently unsupported"
525- )
526- if len (self .w_spoke_indices & self .nonant_spoke_indices ) > 0 :
527- raise RuntimeError (
528- "A Spoke needing both Ws and nonants is currently unsupported"
529- )
530-
531453 ## Generate some warnings if nothing is giving bounds
532- if not self .outerbound_spoke_indices :
454+ if not self .receive_field_spcomms [ Field . OBJECTIVE_OUTER_BOUND ] :
533455 logger .warn (
534456 "No OuterBound Spokes defined, this converger "
535457 "will not cause the hub to terminate"
536458 )
537459
538- if not self .innerbound_spoke_indices :
460+ if not self .receive_field_spcomms [ Field . OBJECTIVE_INNER_BOUND ] :
539461 logger .warn (
540462 "No InnerBound Spokes defined, this converger "
541463 "will not cause the hub to terminate"
@@ -578,7 +500,7 @@ def is_converged(self):
578500 if self .opt .best_bound_obj_val is not None :
579501 self .BestOuterBound = self .OuterBoundUpdate (self .opt .best_bound_obj_val )
580502
581- if not self .innerbound_spoke_indices :
503+ if not self .receive_field_spcomms [ Field . OBJECTIVE_INNER_BOUND ] :
582504 if self .opt ._PHIter == 1 :
583505 logger .warning (
584506 "PHHub cannot compute convergence without "
@@ -591,7 +513,7 @@ def is_converged(self):
591513
592514 return False
593515
594- if not self .outerbound_spoke_indices :
516+ if not self .receive_field_spcomms [ Field . OBJECTIVE_OUTER_BOUND ] :
595517 if self .opt ._PHIter == 1 and not self ._hub_algo_best_bound_provider :
596518 global_toc (
597519 "Without outer bound spokes, no progress "
@@ -660,24 +582,8 @@ def setup_hub(self):
660582 "Cannot call setup_hub before memory windows are constructed"
661583 )
662584
663- self .initialize_spoke_indices ()
664- self .initialize_bound_values ()
665-
666- self .initialize_outer_bound_buffers ()
667- self .initialize_inner_bound_buffers ()
668-
669- ## Do some checking for things we currently
670- ## do not support
671- if self .w_spoke_indices :
672- raise RuntimeError ("LShaped hub does not compute dual weights (Ws)" )
673- if len (self .outerbound_spoke_indices & self .innerbound_spoke_indices ) > 0 :
674- raise RuntimeError (
675- "A Spoke providing both inner and outer "
676- "bounds is currently unsupported"
677- )
678-
679585 ## Generate some warnings if nothing is giving bounds
680- if not self .innerbound_spoke_indices :
586+ if not self .receive_field_spcomms [ Field . OBJECTIVE_INNER_BOUND ] :
681587 logger .warn (
682588 "No InnerBound Spokes defined, this converger "
683589 "will not cause the hub to terminate"
0 commit comments