1+ ###############################################################################
2+ # mpi-sppy: MPI-based Stochastic Programming in PYthon
3+ #
4+ # Copyright (c) 2024, Lawrence Livermore National Security, LLC, Alliance for
5+ # Sustainable Energy, LLC, The Regents of the University of California, et al.
6+ # All rights reserved. Please see the files COPYRIGHT.md and LICENSE.md for
7+ # full copyright and license information.
8+ ###############################################################################
9+
10+ import mpisppy .extensions .dyn_rho_base
11+ import numpy as np
12+ from pyomo .core .expr .calculus .derivatives import differentiate
13+ from pyomo .core .expr .calculus .derivatives import Modes
14+ import pyomo .environ as pyo
15+ import mpisppy .MPI as MPI
16+ from mpisppy import global_toc
17+ import mpisppy .utils .sputils as sputils
18+ from mpisppy .cylinders .spwindow import Field
19+
20+ class GradRho (mpisppy .extensions .dyn_rho_base .Dyn_Rho_extension_base ):
21+ """
22+ Gradient-based rho from
23+ Gradient-based rho Parameter for Progressive Hedging
24+ U. Naepels, David L. Woodruff, 2023
25+
26+ Includes modifications to extend scenario-based denominators
27+ to multi-stage problems and calculation of gradients from objective
28+ expressions on the fly.
29+ A. Asger, B. Knueven, 2025
30+ """
31+
32+ def __init__ (self , opt ):
33+ cfg = opt .options ["grad_rho_options" ]["cfg" ]
34+ super ().__init__ (opt , cfg )
35+ self .opt = opt
36+ self .alpha = cfg .grad_order_stat
37+ assert (self .alpha >= 0 and self .alpha <= 1 ), f"For grad_order_stat 0 is the min, 0.5 the average, 1 the max; { self .alpha = } is invalid."
38+ self .multiplier = 1.0
39+
40+ if (
41+ cfg .grad_rho_multiplier
42+ ):
43+ self .multiplier = cfg .grad_rho_multiplier
44+
45+ self .eval_at_xhat = cfg .eval_at_xhat
46+ self .indep_denom = cfg .indep_denom
47+
48+ def _scen_dep_denom (self , s ):
49+ """ Computes scenario dependent denominator for grad rho calculation.
50+
51+ Args:
52+ s (Pyomo Concrete Model): scenario
53+
54+ Returns:
55+ scen_dep_denom (numpy array): denominator
56+
57+ """
58+
59+ scen_dep_denom = {}
60+
61+ xbars = s ._mpisppy_model .xbars
62+
63+ for ndn_i , v in s ._mpisppy_data .nonant_indices .items ():
64+ scen_dep_denom [ndn_i ] = abs (v ._value - xbars [ndn_i ]._value )
65+
66+ denom_max = max (scen_dep_denom .values ())
67+
68+ for ndn_i , v in s ._mpisppy_data .nonant_indices .items ():
69+ if scen_dep_denom [ndn_i ] <= self .opt .E1_tolerance :
70+ scen_dep_denom [ndn_i ] = max (denom_max , self .opt .E1_tolerance )
71+
72+ return scen_dep_denom
73+
74+ def _scen_indep_denom (self ):
75+ """ Computes scenario independent denominator for grad rho calculation.
76+
77+ Returns:
78+ scen_indep_denom (numpy array): denominator
79+
80+ """
81+ opt = self .opt
82+ local_nodenames = []
83+ local_denoms = {}
84+ global_denoms = {}
85+
86+ for k , s in opt .local_scenarios .items ():
87+ nlens = s ._mpisppy_data .nlens
88+ for node in s ._mpisppy_node_list :
89+ if node .name not in local_nodenames :
90+ ndn = node .name
91+ local_nodenames .append (ndn )
92+ nlen = nlens [ndn ]
93+
94+ local_denoms [ndn ] = np .zeros (nlen , dtype = "d" )
95+ global_denoms [ndn ] = np .zeros (nlen , dtype = "d" )
96+
97+ for k , s in opt .local_scenarios .items ():
98+ nlens = s ._mpisppy_data .nlens
99+ xbars = s ._mpisppy_model .xbars
100+ for node in s ._mpisppy_node_list :
101+ ndn = node .name
102+ denoms = local_denoms [ndn ]
103+
104+ unweighted_denoms = np .fromiter (
105+ (
106+ abs (v ._value - xbars [ndn , i ]._value )
107+ for i , v in enumerate (node .nonant_vardata_list )
108+ ),
109+ dtype = "d" ,
110+ count = nlens [ndn ],
111+ )
112+ denoms += s ._mpisppy_data .prob_coeff [ndn ] * unweighted_denoms
113+
114+ for nodename in local_nodenames :
115+ opt .comms [nodename ].Allreduce (
116+ [local_denoms [nodename ], MPI .DOUBLE ],
117+ [global_denoms [nodename ], MPI .DOUBLE ],
118+ op = MPI .SUM ,
119+ )
120+
121+ scen_indep_denom = {}
122+ for ndn , global_denom in global_denoms .items ():
123+ for i , v in enumerate (global_denom ):
124+ scen_indep_denom [ndn , i ] = v
125+
126+ return scen_indep_denom
127+
128+ def _get_grad_exprs (self ):
129+ """ Grabs and caches the gradient expressions for each scenario's objective (without proximal term). """
130+
131+ self .grad_exprs = dict ()
132+
133+ for s in self .opt .local_scenarios .values ():
134+ self .grad_exprs [s ] = differentiate (sputils .find_active_objective (s ),
135+ wrt_list = s ._mpisppy_data .nonant_indices .values (),
136+ mode = Modes .reverse_symbolic ,
137+ )
138+
139+ self .grad_exprs [s ] = {ndn_i : self .grad_exprs [s ][i ] for i , ndn_i in enumerate (s ._mpisppy_data .nonant_indices )}
140+
141+ return
142+
143+ def _eval_grad_exprs (self , s , xhat ):
144+ """ Evaluates the gradient expressions of the objectives for scenario s at xhat (if available) or the current values. """
145+
146+ ci = 0
147+ grads = {}
148+
149+ if self .eval_at_xhat :
150+ if True not in np .isnan (self .best_xhat_buf .value_array ()):
151+ for ndn_i , var in s ._mpisppy_data .nonant_indices .items ():
152+ var .value = xhat [ci ]
153+ ci += 1
154+
155+ for ndn_i , var in s ._mpisppy_data .nonant_indices .items ():
156+ grads [ndn_i ] = pyo .value (self .grad_exprs [s ][ndn_i ])
157+
158+ return grads
159+
160+ def _compute_and_update_rho (self ):
161+ """ Computes and sets rhos for each scenario and each variable based on scenario dependence of
162+ the denominator in rho calculation.
163+ """
164+
165+ opt = self .opt
166+ local_scens = opt .local_scenarios .values ()
167+
168+ if self .indep_denom :
169+ grad_denom = self ._scen_indep_denom ()
170+ loc_denom = {s : grad_denom for s in local_scens }
171+ else :
172+ loc_denom = {s : self ._scen_dep_denom (s )
173+ for s in opt .local_scenarios .values ()}
174+
175+ costs = {s : self ._eval_grad_exprs (s , self .best_xhat_buf .value_array ())
176+ for s in opt .local_scenarios .values ()}
177+
178+ local_nodenames = []
179+ local_rhos = {}
180+ local_rho_mins = {}
181+ local_rho_maxes = {}
182+ local_rho_means = {}
183+ global_rho_mins = {}
184+ global_rho_maxes = {}
185+ global_rho_means = {}
186+
187+ for k , s in opt .local_scenarios .items ():
188+ nlens = s ._mpisppy_data .nlens
189+ for node in s ._mpisppy_node_list :
190+ if node .name not in local_nodenames :
191+ ndn = node .name
192+ local_nodenames .append (ndn )
193+ nlen = nlens [ndn ]
194+
195+ local_rhos [ndn ] = np .zeros (nlen , dtype = "d" )
196+ local_rho_mins [ndn ] = np .zeros (nlen , dtype = "d" )
197+ local_rho_maxes [ndn ] = np .zeros (nlen , dtype = "d" )
198+ local_rho_means [ndn ] = np .zeros (nlen , dtype = "d" )
199+ global_rho_mins [ndn ] = np .zeros (nlen , dtype = "d" )
200+ global_rho_maxes [ndn ] = np .zeros (nlen , dtype = "d" )
201+ global_rho_means [ndn ] = np .zeros (nlen , dtype = "d" )
202+
203+ for k , s in opt .local_scenarios .items ():
204+ nlens = s ._mpisppy_data .nlens
205+ for node in s ._mpisppy_node_list :
206+ ndn = node .name
207+ rhos = local_rhos [ndn ]
208+ rho_mins = local_rho_mins [ndn ]
209+ rho_maxes = local_rho_maxes [ndn ]
210+ rho_means = local_rho_means [ndn ]
211+
212+ rhos = np .fromiter (
213+ (
214+ abs (costs [s ][ndn , i ]/ loc_denom [s ][ndn , i ])
215+ for i , v in enumerate (node .nonant_vardata_list )
216+ ),
217+ dtype = "d" ,
218+ count = nlens [ndn ],
219+ )
220+
221+ np .minimum (rho_mins , rhos , out = rho_mins , where = (s ._mpisppy_data .prob_coeff [ndn ] > 0 ))
222+ np .maximum (rho_maxes , rhos , out = rho_maxes , where = (s ._mpisppy_data .prob_coeff [ndn ] > 0 ))
223+ rho_means += s ._mpisppy_data .prob_coeff [ndn ] * rhos
224+
225+ for nodename in local_nodenames :
226+ opt .comms [nodename ].Allreduce (
227+ [local_rho_mins [nodename ], MPI .DOUBLE ],
228+ [global_rho_mins [nodename ], MPI .DOUBLE ],
229+ op = MPI .MIN ,
230+ )
231+
232+ opt .comms [nodename ].Allreduce (
233+ [local_rho_maxes [nodename ], MPI .DOUBLE ],
234+ [global_rho_maxes [nodename ], MPI .DOUBLE ],
235+ op = MPI .MAX ,
236+ )
237+
238+ opt .comms [nodename ].Allreduce (
239+ [local_rho_means [nodename ], MPI .DOUBLE ],
240+ [global_rho_means [nodename ], MPI .DOUBLE ],
241+ op = MPI .SUM ,
242+ )
243+
244+ if self .alpha == 0.5 :
245+ rhos = {(ndn , i ): float (v ) for ndn , rho_mean in global_rho_means .items () for i , v in enumerate (rho_mean )}
246+ elif self .alpha == 0.0 :
247+ rhos = {(ndn , i ): float (v ) for ndn , rho_min in global_rho_mins .items () for i , v in enumerate (rho_min )}
248+ elif self .alpha == 1.0 :
249+ rhos = {(ndn , i ): float (v ) for ndn , rho_max in global_rho_maxes .items () for i , v in enumerate (rho_max )}
250+ elif self .alpha < 0.5 :
251+ rhos = {(ndn , i ): float (min_v + self .alpha * 2 * (mean_v - min_v ))
252+ for ndn in global_rho_mins .keys ()
253+ for i , (min_v , mean_v ) in enumerate (zip (global_rho_mins [ndn ], global_rho_means [ndn ]))}
254+ elif self .alpha > 0.5 :
255+ rhos = {(ndn , i ): float (2 * mean_v - max_v + self .alpha * 2 * (max_v - mean_v ))
256+ for ndn in global_rho_maxes .keys ()
257+ for i , (max_v , mean_v ) in enumerate (zip (global_rho_maxes [ndn ], global_rho_means [ndn ]))}
258+ else :
259+ raise RuntimeError ("Coding error." )
260+
261+ for s in opt .local_scenarios .values ():
262+ for ndn_i , rho in s ._mpisppy_model .rho .items ():
263+ if rhos [ndn_i ] != 0 :
264+ rho ._value = self .multiplier * rhos [ndn_i ]
265+
266+ def compute_and_update_rho (self ):
267+ self ._compute_and_update_rho ()
268+ sum_rho = 0.0
269+ num_rhos = 0 # could be computed...
270+ for sname , s in self .opt .local_scenarios .items ():
271+ for ndn_i , nonant in s ._mpisppy_data .nonant_indices .items ():
272+ sum_rho += s ._mpisppy_model .rho [ndn_i ]._value
273+ num_rhos += 1
274+ rho_avg = sum_rho / num_rhos
275+ global_toc (f"Rho values recomputed - average rank 0 rho={ rho_avg } " )
276+
277+ def pre_iter0 (self ):
278+ pass
279+
280+ def iter0_post_solver_creation (self ):
281+ pass
282+
283+ def post_iter0 (self ):
284+ global_toc ("Using grad-rho rho setter" )
285+ self .update_caches ()
286+ self ._get_grad_exprs ()
287+ self .compute_and_update_rho ()
288+
289+ def miditer (self ):
290+ self .update_caches ()
291+ self .opt .spcomm .get_receive_buffer (
292+ self .best_xhat_buf ,
293+ Field .BEST_XHAT ,
294+ self .best_xhat_spoke_index ,
295+ )
296+ if self ._update_recommended ():
297+ self .compute_and_update_rho ()
298+
299+ def enditer (self ):
300+ pass
301+
302+ def post_everything (self ):
303+ pass
304+
305+ def register_receive_fields (self ):
306+ spcomm = self .opt .spcomm
307+ best_xhat_ranks = spcomm .fields_to_ranks [Field .BEST_XHAT ]
308+ assert len (best_xhat_ranks ) == 1
309+ index = best_xhat_ranks [0 ]
310+
311+ self .best_xhat_spoke_index = index
312+
313+ self .best_xhat_buf = spcomm .register_recv_field (
314+ Field .BEST_XHAT ,
315+ self .best_xhat_spoke_index ,
316+ )
317+
318+ return
0 commit comments