Skip to content

Commit 1154513

Browse files
committed
working bbs/bfs evaluator bnb algorithm
1 parent 9707655 commit 1154513

1 file changed

Lines changed: 54 additions & 19 deletions

File tree

src/hiopbbpy/opt/bnbalgorithm.py

Lines changed: 54 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -299,7 +299,8 @@ def __init__(self, acqf, options = {}):
299299
num_bfs_workers = 1
300300
self.bbsevaluator = MPIEvaluator(function_mode=False, max_workers = num_bbs_workers)
301301
self.bfsevaluator = MPIEvaluator(function_mode=False, max_workers = num_bfs_workers)
302-
302+
self.num_bbs_workers = num_bbs_workers
303+
self.num_bfs_workers = num_bfs_workers
303304

304305
# For minimization, we find a feasible function value as the upper bound on the minimum value of the acquisition function.
305306
def compute_acqf_upper_bound(self, l, u):
@@ -334,11 +335,15 @@ def compute_acqf_bounds(self, l, u):
334335
acqf_U = self.acqf.evaluate(x_midpoint).flatten()[0]
335336
return acqf_bounds[0], acqf_U
336337
def _prune_queue(self, queue, lub, eps):
337-
"""Keep only nodes that can beat current GUB within tolerance; then re-heapify."""
338+
"""Keep only nodes whose lower-bound is not greater or equal least upper-bound + eps; then re-heapify."""
338339
# queue items are (L, counter, node)
339-
pruned = [(L, c, n) for (L, c, n) in queue if L <= lub + eps]
340+
pruned = [(L, c, n) for (L, c, n) in queue if L < lub + eps]
340341
heapq.heapify(pruned)
341342
return pruned
343+
def _prune_node_list(self, node_list, lub, eps):
344+
"""Keep only nodes whose lower-bound is not greater or equal least upper-bound + eps; then re-heapify."""
345+
pruned_node_list = [node for node in node_list if node.aq_L < lub + eps]
346+
return pruned_node_list
342347
def optimize(self):
343348
opt = self.bnboptimize(self.gpsurrogate.xlimits[:,0], self.gpsurrogate.xlimits[:,1])
344349
lopt = opt[0]
@@ -393,49 +398,78 @@ def bnboptimize(self, l_init, u_init):
393398

394399
heapq.heapify(self.queue)
395400

396-
401+
all_bfsnodes = []
397402

398403
# stopping criterion should be on the total maximum number of branched nodes
399404
num_branches = 0
400405
while num_branches < self.max_bnbiter:
401406

402407
# collect nodes to be branched on in list structure
403-
nodes = []
404-
408+
bbsnodes = []
405409
for i in range(self.nodes_per_batch - 1):
406410
if not self.queue:
407411
break # no more nodes in queue to branch on
408412
_, _, node = heapq.heappop(self.queue)
409-
nodes.append(node)
413+
bbsnodes.append(node)
410414

411415
# parallel branching and upper/lower bound node compuatations
412416
brancher = branching_wrapper(self.acqf, LUB = self.LUB, epsilon_prune=self.epsilon_prune)
413-
nodes = np.array(nodes)
414-
self.evaluator.submit_tasks(brancher.callback, nodes)
417+
bbsnodes = np.array(bbsnodes)
418+
self.bbsevaluator.submit_tasks(brancher.callback, bbsnodes)
419+
420+
# TODO: different nodes_per_batch for bbs/bfs sets
421+
bfsnodes = []
422+
for i in range(self.nodes_per_batch - 1):
423+
if len(all_bfsnodes) == 0:
424+
break
425+
node = all_bfsnodes.pop(0)
426+
bfsnodes.append(node)
427+
bfsnodes = np.array(bfsnodes)
428+
self.bfsevaluator.submit_tasks(brancher.callback, bfsnodes)
415429

416-
# self.evaluator.sync()
417430
# asynchronously retrieve results from Evaluator that have been processed
418-
children = self.evaluator.retrieve_results()
431+
bbschildren = self.bbsevaluator.retrieve_results()
419432

420433
# not all children are return, hence children is a ragged array
421434
# need to flatten this ragged list
422-
children = [item for sublist in children for item in sublist]
423-
num_branches += len(children)
424-
435+
bbschildren = [item for sublist in bbschildren for item in sublist]
436+
num_branches += len(bbschildren)
437+
438+
bfschildren = self.bfsevaluator.retrieve_results()
439+
bfschildren = [item for sublist in bfschildren for item in sublist]
440+
num_branches += len(bfschildren)
441+
442+
children = bbschildren + bfschildren # join child lists
443+
if len(children) == 0:
444+
continue
425445
# update best_node and queue via children
426446
updated_best_node = False
427447
for child in children:
428448
if child.aq_U < self.LUB:
429449
self.best_node = child
430450
self.LUB = self.best_node.aq_U
431451
updated_best_node=True
432-
# if gap of child with LUB is small enough exit here
433-
if child.aq_L < self.LUB + self.epsilon_prune:
434-
heapq.heappush(self.queue, (child.aq_L, next(self._ctr), child))
452+
children_lower_bounds = [child.aq_L for child in children]
453+
args = np.argwhere(np.array(children_lower_bounds) < self.LUB + self.epsilon_prune).flatten()
454+
children = [children[arg] for arg in args]
435455

456+
# TODO: criteria for moving children to all_bfsnodes or queue
457+
# now move pruned children to data structs for (potential) future evaluation
458+
children_lower_bounds = [child.aq_L for child in children]
459+
# sort the children in order of increasing acqf lower-bounds
460+
args = np.argsort(children_lower_bounds)
461+
children = [children[arg] for arg in args]
462+
for child in children:
463+
if len(self.queue) < self.num_bbs_workers:
464+
heapq.heappush(self.queue, (child.aq_L, next(self._ctr), child))
465+
else:
466+
all_bfsnodes.append(child)
467+
468+
469+
436470
# BnB opt progress report
437-
# only report if a child was returned in most recent
438-
# retrieve_results() Evaluator call
471+
# only report if one or more children were returned in most recent
472+
# retrieve_results() MPIEvaluator calls
439473
# only check optimality gap and reprune if one or more children were returned
440474
if len(children) > 0:
441475
gap = self.best_node.aq_U - self.best_node.aq_L
@@ -446,6 +480,7 @@ def bnboptimize(self, l_init, u_init):
446480
print(f"gap = {gap}")
447481
# prune queue based on potentially updated least upper bound
448482
self.queue = self._prune_queue(self.queue, self.LUB, self.epsilon_prune)
483+
all_bfsnodes = self._prune_node_list(all_bfsnodes, self.LUB, self.epsilon_prune)
449484

450485
if updated_best_node:
451486
if gap < self.epsilon_gap:

0 commit comments

Comments
 (0)