@@ -299,7 +299,8 @@ def __init__(self, acqf, options = {}):
299299 num_bfs_workers = 1
300300 self .bbsevaluator = MPIEvaluator (function_mode = False , max_workers = num_bbs_workers )
301301 self .bfsevaluator = MPIEvaluator (function_mode = False , max_workers = num_bfs_workers )
302-
302+ self .num_bbs_workers = num_bbs_workers
303+ self .num_bfs_workers = num_bfs_workers
303304
304305 # For minimization, we find a feasible function value as the upper bound on the minimum value of the acquisition function.
305306 def compute_acqf_upper_bound (self , l , u ):
@@ -334,11 +335,15 @@ def compute_acqf_bounds(self, l, u):
334335 acqf_U = self .acqf .evaluate (x_midpoint ).flatten ()[0 ]
335336 return acqf_bounds [0 ], acqf_U
336337 def _prune_queue (self , queue , lub , eps ):
337- """Keep only nodes that can beat current GUB within tolerance ; then re-heapify."""
338+ """Keep only nodes whose lower-bound is not greater or equal least upper-bound + eps ; then re-heapify."""
338339 # queue items are (L, counter, node)
339- pruned = [(L , c , n ) for (L , c , n ) in queue if L <= lub + eps ]
340+ pruned = [(L , c , n ) for (L , c , n ) in queue if L < lub + eps ]
340341 heapq .heapify (pruned )
341342 return pruned
343+ def _prune_node_list (self , node_list , lub , eps ):
344+ """Keep only nodes whose lower-bound is not greater or equal least upper-bound + eps; then re-heapify."""
345+ pruned_node_list = [node for node in node_list if node .aq_L < lub + eps ]
346+ return pruned_node_list
342347 def optimize (self ):
343348 opt = self .bnboptimize (self .gpsurrogate .xlimits [:,0 ], self .gpsurrogate .xlimits [:,1 ])
344349 lopt = opt [0 ]
@@ -393,49 +398,78 @@ def bnboptimize(self, l_init, u_init):
393398
394399 heapq .heapify (self .queue )
395400
396-
401+ all_bfsnodes = []
397402
398403 # stopping criterion should be on the total maximum number of branched nodes
399404 num_branches = 0
400405 while num_branches < self .max_bnbiter :
401406
402407 # collect nodes to be branched on in list structure
403- nodes = []
404-
408+ bbsnodes = []
405409 for i in range (self .nodes_per_batch - 1 ):
406410 if not self .queue :
407411 break # no more nodes in queue to branch on
408412 _ , _ , node = heapq .heappop (self .queue )
409- nodes .append (node )
413+ bbsnodes .append (node )
410414
411415 # parallel branching and upper/lower bound node compuatations
412416 brancher = branching_wrapper (self .acqf , LUB = self .LUB , epsilon_prune = self .epsilon_prune )
413- nodes = np .array (nodes )
414- self .evaluator .submit_tasks (brancher .callback , nodes )
417+ bbsnodes = np .array (bbsnodes )
418+ self .bbsevaluator .submit_tasks (brancher .callback , bbsnodes )
419+
420+ # TODO: different nodes_per_batch for bbs/bfs sets
421+ bfsnodes = []
422+ for i in range (self .nodes_per_batch - 1 ):
423+ if len (all_bfsnodes ) == 0 :
424+ break
425+ node = all_bfsnodes .pop (0 )
426+ bfsnodes .append (node )
427+ bfsnodes = np .array (bfsnodes )
428+ self .bfsevaluator .submit_tasks (brancher .callback , bfsnodes )
415429
416- # self.evaluator.sync()
417430 # asynchronously retrieve results from Evaluator that have been processed
418- children = self .evaluator .retrieve_results ()
431+ bbschildren = self .bbsevaluator .retrieve_results ()
419432
420433 # not all children are return, hence children is a ragged array
421434 # need to flatten this ragged list
422- children = [item for sublist in children for item in sublist ]
423- num_branches += len (children )
424-
435+ bbschildren = [item for sublist in bbschildren for item in sublist ]
436+ num_branches += len (bbschildren )
437+
438+ bfschildren = self .bfsevaluator .retrieve_results ()
439+ bfschildren = [item for sublist in bfschildren for item in sublist ]
440+ num_branches += len (bfschildren )
441+
442+ children = bbschildren + bfschildren # join child lists
443+ if len (children ) == 0 :
444+ continue
425445 # update best_node and queue via children
426446 updated_best_node = False
427447 for child in children :
428448 if child .aq_U < self .LUB :
429449 self .best_node = child
430450 self .LUB = self .best_node .aq_U
431451 updated_best_node = True
432- # if gap of child with LUB is small enough exit here
433- if child . aq_L < self .LUB + self .epsilon_prune :
434- heapq . heappush ( self . queue , ( child . aq_L , next ( self . _ctr ), child ))
452+ children_lower_bounds = [ child . aq_L for child in children ]
453+ args = np . argwhere ( np . array ( children_lower_bounds ) < self .LUB + self .epsilon_prune ). flatten ()
454+ children = [ children [ arg ] for arg in args ]
435455
456+ # TODO: criteria for moving children to all_bfsnodes or queue
457+ # now move pruned children to data structs for (potential) future evaluation
458+ children_lower_bounds = [child .aq_L for child in children ]
459+ # sort the children in order of increasing acqf lower-bounds
460+ args = np .argsort (children_lower_bounds )
461+ children = [children [arg ] for arg in args ]
462+ for child in children :
463+ if len (self .queue ) < self .num_bbs_workers :
464+ heapq .heappush (self .queue , (child .aq_L , next (self ._ctr ), child ))
465+ else :
466+ all_bfsnodes .append (child )
467+
468+
469+
436470 # BnB opt progress report
437- # only report if a child was returned in most recent
438- # retrieve_results() Evaluator call
471+ # only report if one or more children were returned in most recent
472+ # retrieve_results() MPIEvaluator calls
439473 # only check optimality gap and reprune if one or more children were returned
440474 if len (children ) > 0 :
441475 gap = self .best_node .aq_U - self .best_node .aq_L
@@ -446,6 +480,7 @@ def bnboptimize(self, l_init, u_init):
446480 print (f"gap = { gap } " )
447481 # prune queue based on potentially updated least upper bound
448482 self .queue = self ._prune_queue (self .queue , self .LUB , self .epsilon_prune )
483+ all_bfsnodes = self ._prune_node_list (all_bfsnodes , self .LUB , self .epsilon_prune )
449484
450485 if updated_best_node :
451486 if gap < self .epsilon_gap :
0 commit comments