Skip to content

Commit 8cc8085

Browse files
authored
Merge pull request #543 from aasgerr/grad-rho
Reimplementation of grad rho
2 parents 900d058 + 1bc0d1d commit 8cc8085

File tree

2 files changed

+331
-0
lines changed

2 files changed

+331
-0
lines changed

mpisppy/extensions/grad_rho.py

Lines changed: 318 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,318 @@
1+
###############################################################################
2+
# mpi-sppy: MPI-based Stochastic Programming in PYthon
3+
#
4+
# Copyright (c) 2024, Lawrence Livermore National Security, LLC, Alliance for
5+
# Sustainable Energy, LLC, The Regents of the University of California, et al.
6+
# All rights reserved. Please see the files COPYRIGHT.md and LICENSE.md for
7+
# full copyright and license information.
8+
###############################################################################
9+
10+
import mpisppy.extensions.dyn_rho_base
11+
import numpy as np
12+
from pyomo.core.expr.calculus.derivatives import differentiate
13+
from pyomo.core.expr.calculus.derivatives import Modes
14+
import pyomo.environ as pyo
15+
import mpisppy.MPI as MPI
16+
from mpisppy import global_toc
17+
import mpisppy.utils.sputils as sputils
18+
from mpisppy.cylinders.spwindow import Field
19+
20+
class GradRho(mpisppy.extensions.dyn_rho_base.Dyn_Rho_extension_base):
    """
    Gradient-based rho from
    Gradient-based rho Parameter for Progressive Hedging
    U. Naepels, David L. Woodruff, 2023

    Includes modifications to extend scenario-based denominators
    to multi-stage problems and calculation of gradients from objective
    expressions on the fly.
    A. Asger, B. Knueven, 2025

    For every nonanticipative variable a candidate rho
    ``|d(objective)/d(x)| / denominator`` is computed per scenario; the
    candidates are then combined across scenarios into a single value by an
    order statistic selected with ``grad_order_stat`` (0 = min,
    0.5 = probability-weighted mean, 1 = max, linear blend in between).
    """

    def __init__(self, opt):
        """ Reads the grad-rho settings from ``opt.options["grad_rho_options"]["cfg"]``.

        Args:
            opt: the optimization object this extension is attached to; it must
                carry ``options["grad_rho_options"]["cfg"]`` with the attributes
                ``grad_order_stat``, ``grad_rho_multiplier``, ``eval_at_xhat``
                and ``indep_denom``.
        """
        cfg = opt.options["grad_rho_options"]["cfg"]
        super().__init__(opt, cfg)
        self.opt = opt
        self.alpha = cfg.grad_order_stat
        assert (self.alpha >= 0 and self.alpha <= 1), f"For grad_order_stat 0 is the min, 0.5 the average, 1 the max; {self.alpha=} is invalid."

        # A falsy multiplier (None or 0) falls back to 1.0.
        self.multiplier = 1.0
        if cfg.grad_rho_multiplier:
            self.multiplier = cfg.grad_rho_multiplier

        self.eval_at_xhat = cfg.eval_at_xhat
        self.indep_denom = cfg.indep_denom

    def _scen_dep_denom(self, s):
        """ Computes scenario dependent denominator for grad rho calculation.

        The denominator of each nonant is ``|x - xbar|`` for this scenario.
        Entries that are not larger than ``opt.E1_tolerance`` are replaced by
        the largest denominator of the scenario (or by the tolerance itself if
        that is larger) so that the later division stays well behaved.

        Args:
            s (Pyomo Concrete Model): scenario

        Returns:
            scen_dep_denom (dict): denominator, keyed like
                ``s._mpisppy_data.nonant_indices``

        """
        xbars = s._mpisppy_model.xbars
        tol = self.opt.E1_tolerance

        raw_denoms = {
            ndn_i: abs(v._value - xbars[ndn_i]._value)
            for ndn_i, v in s._mpisppy_data.nonant_indices.items()
        }
        if not raw_denoms:
            # no nonants: nothing to guard, and max() of nothing would raise
            return raw_denoms

        fallback = max(max(raw_denoms.values()), tol)
        return {
            ndn_i: (fallback if denom <= tol else denom)
            for ndn_i, denom in raw_denoms.items()
        }

    def _scen_indep_denom(self):
        """ Computes scenario independent denominator for grad rho calculation.

        For every node the probability-weighted sum of ``|x - xbar|`` over the
        local scenarios is formed and then summed over the node's communicator.
        As in the scenario dependent case, entries that are not larger than
        ``opt.E1_tolerance`` are replaced (here by the largest entry of the
        same node, or the tolerance if that is larger); without this guard a
        variable on which all scenarios agree would yield an infinite or NaN
        rho.

        Returns:
            scen_indep_denom (dict): denominator, keyed by ``(node name, i)``

        """
        opt = self.opt
        local_denoms = {}
        global_denoms = {}

        for s in opt.local_scenarios.values():
            nlens = s._mpisppy_data.nlens
            xbars = s._mpisppy_model.xbars
            prob_coeff = s._mpisppy_data.prob_coeff
            for node in s._mpisppy_node_list:
                ndn = node.name
                if ndn not in local_denoms:
                    # first time this node is seen on this rank
                    local_denoms[ndn] = np.zeros(nlens[ndn], dtype="d")
                    global_denoms[ndn] = np.zeros(nlens[ndn], dtype="d")

                unweighted_denoms = np.fromiter(
                    (
                        abs(v._value - xbars[ndn, i]._value)
                        for i, v in enumerate(node.nonant_vardata_list)
                    ),
                    dtype="d",
                    count=nlens[ndn],
                )
                local_denoms[ndn] += prob_coeff[ndn] * unweighted_denoms

        # dicts keep insertion order, so the reductions are issued in the
        # order in which the nodes were first encountered
        for nodename, local_denom in local_denoms.items():
            opt.comms[nodename].Allreduce(
                [local_denom, MPI.DOUBLE],
                [global_denoms[nodename], MPI.DOUBLE],
                op=MPI.SUM,
            )

        tol = opt.E1_tolerance
        scen_indep_denom = {}
        for ndn, global_denom in global_denoms.items():
            # The fallback only uses reduced (node-wide) data, so every rank
            # sharing the node ends up with the same denominator.
            # NOTE(review): mirrors the guard in _scen_dep_denom, but per node
            # rather than over all nonants -- confirm this is the desired rule.
            node_max = global_denom.max() if global_denom.size else tol
            fallback = np.float64(max(node_max, tol))
            for i, v in enumerate(global_denom):
                scen_indep_denom[ndn, i] = fallback if v <= tol else v

        return scen_indep_denom

    def _get_grad_exprs(self):
        """ Grabs and caches the gradient expressions for each scenario's objective (without proximal term). """

        self.grad_exprs = {}

        for s in self.opt.local_scenarios.values():
            nonant_indices = s._mpisppy_data.nonant_indices
            derivatives = differentiate(
                sputils.find_active_objective(s),
                wrt_list=nonant_indices.values(),
                mode=Modes.reverse_symbolic,
            )
            # derivatives come back in wrt_list order; re-key them by nonant index
            self.grad_exprs[s] = {
                ndn_i: derivatives[i] for i, ndn_i in enumerate(nonant_indices)
            }

    def _eval_grad_exprs(self, s, xhat):
        """ Evaluates the gradient expressions of the objectives for scenario s at xhat (if available) or the current values.

        Side effect: when ``eval_at_xhat`` is set and the best-xhat receive
        buffer contains no NaN, the nonant variables of ``s`` are overwritten
        with the entries of ``xhat`` and are not restored afterwards.

        Args:
            s (Pyomo Concrete Model): scenario
            xhat (sequence of float): candidate nonant values, indexed in the
                iteration order of ``s._mpisppy_data.nonant_indices``

        Returns:
            grads (dict): gradient value per nonant index
        """
        nonant_indices = s._mpisppy_data.nonant_indices

        # NaN entries in the buffer are treated as "no xhat available"
        if self.eval_at_xhat and not np.isnan(self.best_xhat_buf.value_array()).any():
            for ci, var in enumerate(nonant_indices.values()):
                var.value = xhat[ci]

        scen_grad_exprs = self.grad_exprs[s]
        return {ndn_i: pyo.value(scen_grad_exprs[ndn_i]) for ndn_i in nonant_indices}

    def _order_stat_rho(self, rho_min, rho_mean, rho_max):
        """ Blends per-node min / mean / max arrays according to ``self.alpha``.

        ``alpha`` 0, 0.5 and 1 select min, mean and max respectively; values in
        between interpolate linearly between the two neighbouring statistics.
        """
        alpha = self.alpha
        if alpha == 0.5:
            return rho_mean
        if alpha == 0.0:
            return rho_min
        if alpha == 1.0:
            return rho_max
        if alpha < 0.5:
            return rho_min + alpha * 2 * (rho_mean - rho_min)
        if alpha > 0.5:
            return 2 * rho_mean - rho_max + alpha * 2 * (rho_max - rho_mean)
        # only reachable if alpha compares false to everything (e.g. NaN)
        raise RuntimeError("Coding error.")

    def _compute_and_update_rho(self):
        """ Computes and sets rhos for each scenario and each variable based on scenario dependence of
        the denominator in rho calculation.

        Rho entries whose new value is zero or not finite are left unchanged.
        """
        opt = self.opt
        local_scens = opt.local_scenarios.values()

        if self.indep_denom:
            grad_denom = self._scen_indep_denom()
            loc_denom = {s: grad_denom for s in local_scens}
        else:
            loc_denom = {s: self._scen_dep_denom(s) for s in local_scens}

        xhat = self.best_xhat_buf.value_array()
        costs = {s: self._eval_grad_exprs(s, xhat) for s in local_scens}

        local_rho_mins = {}
        local_rho_maxes = {}
        local_rho_means = {}
        global_rho_mins = {}
        global_rho_maxes = {}
        global_rho_means = {}

        for s in local_scens:
            nlens = s._mpisppy_data.nlens
            prob_coeff = s._mpisppy_data.prob_coeff
            for node in s._mpisppy_node_list:
                ndn = node.name
                nlen = nlens[ndn]

                if ndn not in local_rho_means:
                    # The running minimum has to start at +inf: the candidate
                    # rhos are absolute values, so a zero start could never be
                    # undercut and the minimum would always come out as 0.
                    local_rho_mins[ndn] = np.full(nlen, np.inf, dtype="d")
                    local_rho_maxes[ndn] = np.zeros(nlen, dtype="d")
                    local_rho_means[ndn] = np.zeros(nlen, dtype="d")
                    global_rho_mins[ndn] = np.zeros(nlen, dtype="d")
                    global_rho_maxes[ndn] = np.zeros(nlen, dtype="d")
                    global_rho_means[ndn] = np.zeros(nlen, dtype="d")

                scen_costs = costs[s]
                scen_denom = loc_denom[s]
                scen_rhos = np.fromiter(
                    (
                        abs(scen_costs[ndn, i] / scen_denom[ndn, i])
                        for i, _ in enumerate(node.nonant_vardata_list)
                    ),
                    dtype="d",
                    count=nlen,
                )

                # scenarios with a non-positive coefficient at this node do not
                # take part in the min / max
                contributes = prob_coeff[ndn] > 0
                rho_mins = local_rho_mins[ndn]
                rho_maxes = local_rho_maxes[ndn]
                np.minimum(rho_mins, scen_rhos, out=rho_mins, where=contributes)
                np.maximum(rho_maxes, scen_rhos, out=rho_maxes, where=contributes)
                local_rho_means[ndn] += prob_coeff[ndn] * scen_rhos

        # dicts keep insertion order, so the reductions are issued in the
        # order in which the nodes were first encountered
        for nodename in local_rho_means:
            comm = opt.comms[nodename]
            comm.Allreduce(
                [local_rho_mins[nodename], MPI.DOUBLE],
                [global_rho_mins[nodename], MPI.DOUBLE],
                op=MPI.MIN,
            )
            comm.Allreduce(
                [local_rho_maxes[nodename], MPI.DOUBLE],
                [global_rho_maxes[nodename], MPI.DOUBLE],
                op=MPI.MAX,
            )
            comm.Allreduce(
                [local_rho_means[nodename], MPI.DOUBLE],
                [global_rho_means[nodename], MPI.DOUBLE],
                op=MPI.SUM,
            )

        rhos = {}
        for ndn, rho_mean in global_rho_means.items():
            blended = self._order_stat_rho(
                global_rho_mins[ndn], rho_mean, global_rho_maxes[ndn]
            )
            for i, v in enumerate(blended):
                rhos[ndn, i] = float(v)

        for s in local_scens:
            for ndn_i, rho in s._mpisppy_model.rho.items():
                new_rho = rhos[ndn_i]
                # keep the previous rho when no usable value was obtained
                # (zero gradient, or inf/NaN from a degenerate denominator or
                # from a node where no scenario contributed to the minimum)
                if new_rho != 0 and np.isfinite(new_rho):
                    rho._value = self.multiplier * new_rho

    def compute_and_update_rho(self):
        """ Recomputes the rhos and reports the average over the local scenarios' nonants. """
        self._compute_and_update_rho()
        sum_rho = 0.0
        num_rhos = 0  # could be computed...
        for s in self.opt.local_scenarios.values():
            scen_rho = s._mpisppy_model.rho
            for ndn_i in s._mpisppy_data.nonant_indices:
                sum_rho += scen_rho[ndn_i]._value
                num_rhos += 1
        # avoid a ZeroDivisionError when there is nothing to average
        rho_avg = sum_rho / num_rhos if num_rhos else float("nan")
        global_toc(f"Rho values recomputed - average rank 0 rho={rho_avg}")

    def pre_iter0(self):
        """ No-op. """
        pass

    def iter0_post_solver_creation(self):
        """ No-op. """
        pass

    def post_iter0(self):
        """ Builds the gradient expressions and sets the first grad-rho values. """
        global_toc("Using grad-rho rho setter")
        self.update_caches()
        self._get_grad_exprs()
        self.compute_and_update_rho()

    def miditer(self):
        """ Refreshes the best-xhat buffer and recomputes rho when an update is recommended. """
        self.update_caches()
        self.opt.spcomm.get_receive_buffer(
            self.best_xhat_buf,
            Field.BEST_XHAT,
            self.best_xhat_spoke_index,
        )
        if self._update_recommended():
            self.compute_and_update_rho()

    def enditer(self):
        """ No-op. """
        pass

    def post_everything(self):
        """ No-op. """
        pass

    def register_receive_fields(self):
        """ Registers the receive buffer for ``Field.BEST_XHAT``.

        Exactly one rank must provide ``Field.BEST_XHAT``; its index is stored
        in ``self.best_xhat_spoke_index`` and the buffer in
        ``self.best_xhat_buf``.
        """
        spcomm = self.opt.spcomm
        best_xhat_ranks = spcomm.fields_to_ranks[Field.BEST_XHAT]
        assert len(best_xhat_ranks) == 1

        self.best_xhat_spoke_index = best_xhat_ranks[0]
        self.best_xhat_buf = spcomm.register_recv_field(
            Field.BEST_XHAT,
            self.best_xhat_spoke_index,
        )

mpisppy/utils/config.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -955,6 +955,19 @@ def gradient_args(self):
955955
description="display rho during gradient calcs (default True)",
956956
domain=bool,
957957
default=True)
958+
self.add_to_config("grad_rho_multiplier",
959+
description="multiplier for GradRho (default 1.0)",
960+
domain=float,
961+
default=1.0)
962+
self.add_to_config("eval_at_xhat",
963+
description="evaluate the gradient at xhat whenever available (default False)",
964+
domain=bool,
965+
default=False)
966+
967+
self.add_to_config("indep_denom",
968+
description="evaluate rho using scenario independent denominator (default False)",
969+
domain=bool,
970+
default=False)
958971
# likely unused presently
959972
# self.add_to_config("grad_pd_thresh",
960973
# description="threshold for dual/primal during gradient calcs",

0 commit comments

Comments
 (0)