Skip to content

Commit ee2fad2

Browse files
committed
updating bnb algorithm, asynchronous in a state of flux
1 parent 1036ba9 commit ee2fad2

File tree

1 file changed

+119
-71
lines changed

1 file changed

+119
-71
lines changed

src/hiopbbpy/opt/bnbalgorithm.py

Lines changed: 119 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -358,24 +358,9 @@ def compute_acqf_bounds(self, l, u):
358358
x_points[i, j] = l[j] + ((u[j] - l[j])/2.) * np.floor(i / (3**j)).astype(int) % 3
359359
acqf_eval = self.acqf.evaluate(x_points)
360360
acqf_U = min(acqf_eval.flatten())
361+
if acqf_U < acqf_bounds[0]:
362+
print("ERROR in bound computations U < L")
361363

362-
#acqf_callback = {'obj' : self.acqf.scalar_evaluate}
363-
#if self.acqf.has_gradient:
364-
# acqf_callback['grad'] = self.acqf.scalar_eval_g
365-
#minimizer_method = "SLSQP"
366-
#minimizer_options = {"maxiter" : 100}
367-
#minimizer_constraints = ()
368-
#acqf_minimizer = minimizer_wrapper(acqf_callback, minimizer_method, self.gpsurrogate.xlimits, minimizer_constraints, minimizer_options)
369-
#x0_pts = np.array([[uniform(b[0], b[1]) for b in self.gpsurrogate.xlimits] for _ in range(1)])
370-
371-
#opt_evaluator = Evaluator()
372-
#opt_output = opt_evaluator.run(acqf_minimizer.minimizer_callback, x0_pts)[0]
373-
#assert opt_output[2], f"local optimizer failed"
374-
375-
376-
#return acqf_bounds[0], opt_output[1]#acqf_U
377-
378-
379364
return acqf_bounds[0], acqf_U
380365
def _prune_queue(self, queue, lub, eps):
381366
"""Keep only nodes whose lower bound is less than the least upper bound + eps; then re-heapify."""
@@ -452,69 +437,155 @@ def bnboptimize(self, l_init, u_init):
452437
max_bbs_node_size = 0
453438
max_bfs_node_size = 0
454439
start_time = time.time()
455-
while self.num_branches < self.max_bnbiter:
440+
while self.num_branches < self.max_bnbiter: # iteration limit
441+
if time.time() - start_time > self.max_bnbtime: # time limit
442+
print("maximum time has elapsed")
443+
break
444+
445+
# -- retrieve submitted tasks --
446+
# asynchronously retrieve results from Evaluator that have been processed
447+
self.bbsevaluator.sync()
448+
bbschildren = self.bbsevaluator.retrieve_results()
449+
450+
# not all children are returned, hence the result is a ragged list
451+
# need to flatten this ragged list
452+
bbschildren = [item for sublist in bbschildren for item in sublist]
453+
454+
self.bfsevaluator.sync()
455+
bfschildren = self.bfsevaluator.retrieve_results()
456+
bfschildren = [item for sublist in bfschildren for item in sublist]
457+
458+
children = bbschildren + bfschildren # join child lists
459+
460+
if len(children) == 0:
461+
if len(self.queue) == 0 and len(all_bfsnodes) == 0:
462+
print("no children retrieved and no nodes in bfs/bbs node lists")
463+
exit()
464+
if len(children) > 0:
465+
self.num_branches += len(children)
466+
print(f"{len(children)} children evaluated")
467+
print(f"elapsed time: {time.time() - start_time}")
468+
# update best_node via children
469+
updated_best_node = False
470+
for child in children:
471+
if child.aq_U < child.aq_L:
472+
print("ERROR: child upper bound < child lower bound")
473+
exit()
474+
if child.aq_U <= self.LUB:
475+
self.best_node = child
476+
self.LUB = self.best_node.aq_U
477+
updated_best_node = True
478+
children_lower_bounds = [child.aq_L for child in children]
479+
args = np.argwhere(np.array(children_lower_bounds) < self.LUB + self.epsilon_prune).flatten()
480+
print(f"{len(args)} children to be appended to bbs/bfs lists")
481+
children = [children[arg] for arg in args]
482+
483+
# now move pruned children to data structs for (potential) future evaluation
484+
children_lower_bounds = [child.aq_L for child in children]
485+
# sort the children in order of increasing acqf lower-bounds
486+
args = np.argsort(children_lower_bounds)
487+
children = [children[arg] for arg in args]
488+
for child in children:
489+
if len(self.queue) < self.max_queue_size:
490+
heapq.heappush(self.queue, (child.aq_L, next(self._ctr), child))
491+
else:
492+
all_bfsnodes.append(child)
493+
max_bbs_node_size = max(max_bbs_node_size, len(self.queue))
494+
max_bfs_node_size = max(max_bfs_node_size, len(all_bfsnodes))
495+
496+
# reprune
497+
print(f"|bbs nodes| = {len(self.queue)}, |bfs nodes| = {len(all_bfsnodes)} (prior to pruning)")
498+
self.queue = self._prune_queue(self.queue, self.LUB, self.epsilon_prune)
499+
all_bfsnodes = self._prune_node_list(all_bfsnodes, self.LUB, self.epsilon_prune)
500+
print(f"|bbs nodes| = {len(self.queue)}, |bfs nodes| = {len(all_bfsnodes)} (after pruning)")
501+
if updated_best_node:
502+
print("best node not yet submitted to evaluator")
503+
#if self.best_node not in all_bfsnodes and self.best_node not in
504+
505+
506+
# BnB opt progress report
507+
gap = self.best_node.aq_U - self.best_node.aq_L
508+
print(f"\n--- Total number branches {self.num_branches} ---")
509+
print(f"Best node bounds: l={self.best_node.l}, u={self.best_node.u}")
510+
print(f"Node acquisition bounds: L={self.best_node.aq_L}, U={self.best_node.aq_U}")
511+
print(f"Current best feasible value (LUB): {self.LUB}")
512+
print(f"gap = {gap}")
513+
print(f"size of bbs queue = {len(self.queue)}")
514+
print(f"size of bfs node list = {len(all_bfsnodes)}")
515+
print(f"number of submitted jobs (bbs): {self.bbsevaluator.num_submitted_tasks()}")
516+
print(f"number of submitted jobs (bfs): {self.bfsevaluator.num_submitted_tasks()}")
517+
print(f"--- ---\n")
518+
519+
520+
if updated_best_node:
521+
if gap < self.epsilon_gap:
522+
print(f"STOP: optimality gap = {gap} < {self.epsilon_gap}")
523+
break
524+
525+
526+
# -- submit new tasks --
527+
456528
# collect nodes to be branched on in list structure
457529
bbsnodes = []
458-
num_submitted_nodes = 0
459530

460531
# if the number of submitted jobs is too large then wait for some jobs to be processed
461532
if self.bbsevaluator.num_submitted_tasks() + self.bfsevaluator.num_submitted_tasks() > 10 * (self.num_bbs_workers + self.num_bfs_workers):
462-
if time.time() - start_time > self.max_bnbtime:
463-
print("maximum time has elapsed")
464-
break
465-
else:
466-
print("num submitted bbs tasks = ", self.bbsevaluator.num_submitted_tasks())
467-
print("num submitted bfs tasks = ", self.bfsevaluator.num_submitted_tasks())
468-
time.sleep(1.0) # give time for Evaluators to process jobs
469-
continue
533+
print("num submitted bbs tasks = ", self.bbsevaluator.num_submitted_tasks())
534+
print("num submitted bfs tasks = ", self.bfsevaluator.num_submitted_tasks())
535+
time.sleep(1.e-6) # give time for Evaluators to process jobs
536+
continue
470537

471538
# only submit additional tasks if there aren't too many in the Evaluators queue
472-
if self.bbsevaluator.num_submitted_tasks() < 10 * self.num_bbs_workers:
539+
if True:#self.bbsevaluator.num_submitted_tasks() < 10 * self.num_bbs_workers:
473540
for i in range(self.nodes_per_batch):
474541
if (not self.queue):
475542
break # no more nodes available to send to evaluator for branching/bound computations
476543
_, _, node = heapq.heappop(self.queue)
477544
bbsnodes.append(node)
478-
num_submitted_nodes += 1
479545

480546
# parallel branching and upper/lower bound node computations
481547
brancher = branching_wrapper(self.acqf, LUB = self.LUB, epsilon_prune=self.epsilon_prune)
482548
bbsnodes = np.array(bbsnodes)
483-
self.bbsevaluator.submit_tasks(brancher.callback, bbsnodes)
484-
549+
if len(bbsnodes) > 0:
550+
self.bbsevaluator.submit_tasks(brancher.callback, bbsnodes)
485551
bfsnodes = []
486552
# only submit additional tasks if there aren't too many in the Evaluators queue
487-
if self.bfsevaluator.num_submitted_tasks() < 10 * self.num_bfs_workers:
553+
if True:#self.bfsevaluator.num_submitted_tasks() < 10 * self.num_bfs_workers:
488554
for i in range(self.nodes_per_batch):
489555
if len(all_bfsnodes) == 0:
490556
break # no more nodes available to send to evaluator for branching/bound computations
491557
node = all_bfsnodes.pop(0)
492558
bfsnodes.append(node)
493559
bfsnodes = np.array(bfsnodes)
494-
self.bfsevaluator.submit_tasks(brancher.callback, bfsnodes)
560+
if len(bfsnodes) > 0:
561+
self.bfsevaluator.submit_tasks(brancher.callback, bfsnodes)
495562

496-
# asynchronously retrieve results from Evaluator that have been processed
497-
bbschildren = self.bbsevaluator.retrieve_results()
498563

499-
# not all children are return, hence children is a ragged array
500-
# need to flatten this ragged list
501-
bbschildren = [item for sublist in bbschildren for item in sublist]
564+
# retrieve all running jobs
565+
self.bbsevaluator.sync()
566+
bbschildren = self.bbsevaluator.retrieve_results()
502567

503-
bfschildren = self.bfsevaluator.retrieve_results()
504-
bfschildren = [item for sublist in bfschildren for item in sublist]
568+
# not all children are returned, hence the result is a ragged list
569+
# need to flatten this ragged list
570+
bbschildren = [item for sublist in bbschildren for item in sublist]
505571

506-
children = bbschildren + bfschildren # join child lists
507-
self.num_branches += len(children)
508-
if len(children) == 0:
509-
continue
572+
self.bfsevaluator.sync()
573+
bfschildren = self.bfsevaluator.retrieve_results()
574+
bfschildren = [item for sublist in bfschildren for item in sublist]
575+
576+
children = bbschildren + bfschildren # join child lists
510577

578+
if len(children) > 0:
579+
self.num_branches += len(children)
580+
print(f"{len(children)} children evaluated")
581+
print(f"elapsed time: {time.time() - start_time}")
511582
# update best_node via children
512583
updated_best_node = False
513584
for child in children:
514585
if child.aq_U <= self.LUB:
515586
self.best_node = child
516587
self.LUB = self.best_node.aq_U
517-
updated_best_node=True
588+
updated_best_node = True
518589
children_lower_bounds = [child.aq_L for child in children]
519590
args = np.argwhere(np.array(children_lower_bounds) < self.LUB + self.epsilon_prune).flatten()
520591
children = [children[arg] for arg in args]
@@ -535,24 +606,7 @@ def bnboptimize(self, l_init, u_init):
535606

536607
# BnB opt progress report
537608
gap = self.best_node.aq_U - self.best_node.aq_L
538-
print(f"\n--- Total number branches {self.num_branches} ---")
539-
print(f"Best node bounds: l={self.best_node.l}, u={self.best_node.u}")
540-
print(f"Node acquisition bounds: L={self.best_node.aq_L}, U={self.best_node.aq_U}")
541-
print(f"Current best feasible value (LUB): {self.LUB}")
542-
print(f"gap = {gap}")
543-
print(f"size of bbs queue = {len(self.queue)}")
544-
print(f"size of bfs node list = {len(all_bfsnodes)}")
545-
print(f"number of submitted jobs (bbs): {self.bbsevaluator.num_submitted_tasks()}")
546-
print(f"number of submitted jobs (bfs): {self.bfsevaluator.num_submitted_tasks()}")
547-
548-
# reprune
549-
self.queue = self._prune_queue(self.queue, self.LUB, self.epsilon_prune)
550-
all_bfsnodes = self._prune_node_list(all_bfsnodes, self.LUB, self.epsilon_prune)
551-
552-
if updated_best_node:
553-
if gap < self.epsilon_gap:
554-
print(f"STOP: optimality gap = {gap} < {self.epsilon_gap}")
555-
break
609+
556610

557611
print("\n=== Optimization Finished ===")
558612
print(f"Total number of branches: {self.num_branches}")
@@ -770,7 +824,8 @@ def compute_acqf_bounds(self, l, u):
770824
x_points[i, j] = l[j] + ((u[j] - l[j])/2.) * np.floor(i / (3**j)).astype(int) % 3
771825
acqf_eval = self.acqf.evaluate(x_points)
772826
acqf_U = min(acqf_eval.flatten())
773-
827+
if acqf_bounds[0] > acqf_U:
828+
print("ERROR in bound computations U < L")
774829

775830
#x_midpoint = np.atleast_2d(( l + u) / 2.)
776831
#acqf_U = self.acqf.evaluate(x_midpoint).flatten()[0]
@@ -795,13 +850,6 @@ def callback(self, nodes):
795850
for node in nodes.flatten():
796851
for child_l, child_u in branch(node.l, node.u):
797852
acqf_L, acqf_U = self.compute_acqf_bounds(child_l, child_u)
798-
# Child-level pre-prune
799-
#TODO: revisit this!
800-
# currently will be easier to track how many function
801-
# evaluations are in the queue for the MPIEvaluator
802-
# by removing this pre-pruning stage
803-
#if acqf_L >= self.LUB + self.epsilon_prune:
804-
# continue
805853
child = BnBNode(child_l, child_u, acqf_L, acqf_U)
806854
output.append(child)
807855
return [output]

0 commit comments

Comments
 (0)