11import abc
22import importlib .util
33import threading
4- from typing import Generic , TypeVar
4+ from collections .abc import Callable
5+ from typing import Any , cast , Generic , TypeVar , Union
56
67import equinox as eqx
78import equinox .internal as eqxi
@@ -73,10 +74,11 @@ def init(self) -> None:
7374 return None
7475
7576 def step (self , state , progress : FloatScalarLike ) -> None :
77+ del progress
7678 return state
7779
7880 def close (self , state ):
79- pass
81+ del state
8082
8183
8284NoProgressMeter .__init__ .__doc__ = """**Arguments:**
@@ -91,18 +93,7 @@ def _unvmap_min(x): # No `eqxi.unvmap_min` at the moment.
9193
9294class _TextProgressMeterState (eqx .Module ):
9395 progress : FloatScalarLike
94-
95-
96- def _print_percent_callback (progress ):
97- print (f"{ 100 * progress .item ():.2f} %" )
98-
99-
100- def _print_percent (progress ):
101- # `io_callback` would be preferable here, to indicate that it provides an output,
102- # but that's not supported in vmap-of-while.
103- progress = eqxi .nonbatchable (progress ) # check we'll only call the callback once.
104- jax .debug .callback (_print_percent_callback , progress , ordered = True )
105- return progress
96+ meter_idx : IntScalarLike
10697
10798
10899class TextProgressMeter (AbstractProgressMeter ):
@@ -118,9 +109,24 @@ class TextProgressMeter(AbstractProgressMeter):
118109
119110 minimum_increase : RealScalarLike = 0.02
120111
112+ @staticmethod
113+ def _init_bar () -> list [float ]:
114+ print ("0.00%" )
115+ return [0.0 ]
116+
121117 def init (self ) -> _TextProgressMeterState :
122- _print_percent (0.0 )
123- return _TextProgressMeterState (progress = jnp .array (0.0 ))
118+ meter_idx = _progress_meter_manager .init (self ._init_bar )
119+ return _TextProgressMeterState (meter_idx = meter_idx , progress = jnp .array (0.0 ))
120+
121+ @staticmethod
122+ def _step_bar (bar : list [float ], progress : FloatScalarLike ) -> None :
123+ if eqx .is_array (progress ):
124+ # May not be an array when called with `JAX_DISABLE_JIT=1`
125+ progress = cast (Union [Array , np .ndarray ], progress )
126+ progress = progress .item ()
127+ progress = cast (float , progress )
128+ bar [0 ] = progress
129+ print (f"{ 100 * progress :.2f} %" )
124130
125131 def step (
126132 self , state : _TextProgressMeterState , progress : FloatScalarLike
@@ -129,33 +135,31 @@ def step(
129135 # `state.progress` and `progress` will pick up a batch tracer.
130136 # (For the former, because the condition for the while-loop-over-steps becomes
131137 # batched, so necessarily everything in the body of the loop is as well.)
132- #
133- # We take a `min` over `progress` and a `max` over `state.progress`, as we want
134- # to report the progress made over the worst batch element.
135- state_progress = eqxi .unvmap_max (state .progress )
136- del state
137- progress = _unvmap_min (progress )
138- pred = eqxi .nonbatchable (progress - state_progress > self .minimum_increase )
138+ pred = eqxi .unvmap_all (
139+ (progress - state .progress > self .minimum_increase ) | (progress == 1 )
140+ )
139141
140142 # We only print if the progress has increased by at least `minimum_increase` to
141143 # avoid flooding the user with too many updates.
142- next_progress = jax .lax .cond (
143- pred ,
144- _print_percent ,
145- lambda _ : state_progress ,
146- progress ,
144+ next_progress , meter_idx = jax .lax .cond (
145+ eqxi .nonbatchable (pred ),
146+ lambda _idx : (
147+ progress ,
148+ _progress_meter_manager .step (self ._step_bar , progress , _idx ),
149+ ),
150+ lambda _idx : (state .progress , _idx ),
151+ state .meter_idx ,
147152 )
148153
149- return _TextProgressMeterState (progress = next_progress )
154+ return _TextProgressMeterState (progress = next_progress , meter_idx = meter_idx )
155+
156+ @staticmethod
157+ def _close_bar (bar : list [float ]):
158+ if bar [0 ] != 1 :
159+ print ("100.00%" )
150160
151161 def close (self , state : _TextProgressMeterState ):
152- # As in `step`, we `unvmap` to handle batched state.
153- # This means we only call the callback once.
154- progress = _unvmap_min (state .progress )
155- # Consumes `progress` without using it, to get the order of callbacks correct.
156- progress = jax .debug .callback (
157- lambda _ : print ("100.00%" ), progress , ordered = True
158- )
162+ _progress_meter_manager .close (self ._close_bar , state .meter_idx )
159163
160164
161165TextProgressMeter .__init__ .__doc__ = """**Arguments:**
@@ -168,7 +172,7 @@ def close(self, state: _TextProgressMeterState):
168172
169173
170174class _TqdmProgressMeterState (eqx .Module ):
171- progress_meter_id : IntScalarLike
175+ meter_idx : IntScalarLike
172176 step : IntScalarLike
173177
174178
@@ -184,73 +188,56 @@ def __check_init__(self):
184188 "Install it via `pip install tqdm`."
185189 )
186190
187- def init ( self ) -> _TqdmProgressMeterState :
188- # Not `pure_callback` because it's not a deterministic function of its input
189- # arguments.
190- # Not `debug.callback` because it has a return value.
191- progress_meter_id = io_callback (
192- _progress_meter_manager . init , jax . ShapeDtypeStruct ((), jnp . int32 )
191+ @ staticmethod
192+ def _init_bar () -> "tqdm.tqdm" : # pyright: ignore # noqa: F821
193+ import tqdm # pyright: ignore
194+
195+ bar_format = (
196+ "{percentage:.2f}%|{bar}| [{elapsed}<{remaining}, {rate_fmt}{postfix}]"
193197 )
194- progress_meter_id = eqxi .nonbatchable (progress_meter_id )
195- return _TqdmProgressMeterState (
196- progress_meter_id = progress_meter_id , step = jnp .array (0 )
198+ return tqdm .tqdm (
199+ total = 100 ,
200+ unit = "%" ,
201+ bar_format = bar_format ,
197202 )
198203
204+ def init (self ) -> _TqdmProgressMeterState :
205+ meter_idx = _progress_meter_manager .init (self ._init_bar )
206+ return _TqdmProgressMeterState (meter_idx = meter_idx , step = jnp .array (0 ))
207+
208+ @staticmethod
209+ def _step_bar (bar : "tqdm.tqdm" , progress : FloatScalarLike ) -> None : # pyright: ignore # noqa: F821
210+ bar .n = round (100 * float (progress ), 2 )
211+ bar .update (n = 0 )
212+ bar .refresh ()
213+
199214 def step (
200215 self ,
201216 state : _TqdmProgressMeterState ,
202217 progress : FloatScalarLike ,
203218 ) -> _TqdmProgressMeterState :
204- # As in `TextProgressMeter`, then `state` may pick up a batch tracer from a
205- # batched condition, so we need to handle that.
206- #
207- # In practice it should always be the case that this remains constant over the
208- # solve, so we can just do a max to extract the single value we want.
209- progress_meter_id = eqxi .unvmap_max (state .progress_meter_id )
210- # What happens here is that all batch values for `state.step` start off in sync,
219+ # Here we update every `refresh_rate` steps in order to limit expensive
220+ # callbacks.
221+ # The `unvmap_max` is because batch values for `state.step` start off in sync,
211222 # and then eventually will freeze their values as that batch element finishes
212223 # its solve. So take a `max` to get the true number of overall solve steps for
213224 # the batched system.
214- step = eqxi .unvmap_max (state .step )
215- del state
216- # Track the slowest batch element.
217- progress = _unvmap_min (progress )
218-
219- def update_progress_bar ():
220- # `io_callback` would be preferable here (to indicate the side-effect), but
221- # that's not supported in vmap-of-while. (Even when none of the inputs to
222- # the callback are batched.)
223- jax .debug .callback (
224- _progress_meter_manager .step , progress , progress_meter_id , ordered = True
225- )
226-
227- # Here we update every `refresh_rate` steps in order to limit expensive
228- # callbacks.
229- jax .lax .cond (
230- eqxi .nonbatchable (step % self .refresh_steps == 0 ),
231- update_progress_bar ,
232- lambda : None ,
225+ meter_idx = jax .lax .cond (
226+ eqxi .nonbatchable (eqxi .unvmap_max (state .step ) % self .refresh_steps == 0 ),
227+ lambda _idx : _progress_meter_manager .step (self ._step_bar , progress , _idx ),
228+ lambda _idx : _idx ,
229+ state .meter_idx ,
233230 )
231+ return _TqdmProgressMeterState (meter_idx = meter_idx , step = state .step + 1 )
234232
235- return _TqdmProgressMeterState (
236- progress_meter_id = progress_meter_id , step = step + 1
237- )
233+ @staticmethod
234+ def _close_bar (bar : "tqdm.tqdm" ): # pyright: ignore # noqa: F821
235+ bar .n = 100.0
236+ bar .update (n = 0 )
237+ bar .close ()
238238
239239 def close (self , state : _TqdmProgressMeterState ):
240- # `unvmap_max` as in `step`.
241- progress_meter_id = eqxi .unvmap_max (state .progress_meter_id )
242- # Pass in `step` to thread the order correctly. (`ordered=True` seems sketchy.
243- # At the very least it doesn't also hold the order wrt
244- # `jax.debug.callback(..., ordered=True)`.)
245- # In addition, unvmap it to be sure the callback is only called once.
246- step = eqxi .unvmap_max (state .step )
247- del state
248- io_callback (
249- lambda idx , _ : _progress_meter_manager .close (idx ),
250- None ,
251- progress_meter_id ,
252- step ,
253- )
240+ _progress_meter_manager .close (self ._close_bar , state .meter_idx )
254241
255242
256243TqdmProgressMeter .__init__ .__doc__ = """**Arguments:**
@@ -261,45 +248,65 @@ def close(self, state: _TqdmProgressMeterState):
261248"""
262249
263250
264- class _TqdmProgressMeterManager :
265- """Host-side progress meter manager for TqdmProgressMeter ."""
251+ class _ProgressMeterManager :
252+ """Host-side progress meter manager."""
266253
267254 def __init__ (self ):
268255 self .idx = 0
269256 self .bars = {}
270257 # Not sure how important a lock really is, but included just in case.
271258 self .lock = threading .Lock ()
272259
273- def init (self ) -> IntScalarLike :
274- with self .lock :
275- import tqdm # pyright: ignore
260+ def init (self , init_bar : Callable [[], Any ]) -> IntScalarLike :
261+ def _init () -> IntScalarLike :
262+ with self .lock :
263+ bar = init_bar ()
264+ self .idx += 1
265+ self .bars [self .idx ] = bar
266+ return np .array (self .idx , dtype = jnp .int32 )
276267
277- bar_format = (
278- "{percentage:.2f}%|{bar}| [{elapsed}<{remaining}, {rate_fmt}{postfix}]"
279- )
280- bar = tqdm .tqdm (
281- total = 100 ,
282- unit = "%" ,
283- bar_format = bar_format ,
284- )
285- self .idx += 1
286- self .bars [self .idx ] = bar
287- return np .array (self .idx , dtype = jnp .int32 )
288-
289- def step (self , progress : FloatScalarLike , idx : IntScalarLike ):
290- with self .lock :
291- bar = self .bars [int (idx )]
292- bar .n = round (100 * float (progress ), 2 )
293- bar .update (n = 0 )
294-
295- def close (self , idx : IntScalarLike ):
296- with self .lock :
297- idx = int (idx )
298- bar = self .bars [idx ]
299- bar .n = 100.0
300- bar .update (n = 0 )
301- bar .close ()
302- del self .bars [idx ]
303-
304-
305- _progress_meter_manager = _TqdmProgressMeterManager ()
268+ # Not `pure_callback` because it's not a deterministic function of its input
269+ # arguments.
270+ # Not `debug.callback` because it has a return value.
271+ meter_idx = io_callback (_init , jax .ShapeDtypeStruct ((), jnp .int32 ))
272+ return eqxi .nonbatchable (meter_idx )
273+
274+ def step (
275+ self ,
276+ step_bar : Callable [[Any , FloatScalarLike ], None ],
277+ progress : FloatScalarLike ,
278+ idx : IntScalarLike ,
279+ ) -> IntScalarLike :
280+ # Track the slowest batch element.
281+ progress = _unvmap_min (progress )
282+
283+ def _step (_progress , _idx ):
284+ with self .lock :
285+ try :
286+ # This may pick up a spurious batch tracer from a batched condition,
287+ # so we need to handle that. We do this by using an `np.unique`.
288+ # It should always be the case that `_idx` has precisely one value!
289+ bar = self .bars [np .unique (_idx ).item ()]
290+ except KeyError :
291+ pass # E.g. the backward pass after a forward pass.
292+ else :
293+ step_bar (bar , _progress )
294+ # Return the idx to thread the callbacks in the correct order.
295+ return _idx
296+
297+ return jax .pure_callback (_step , idx , progress , idx , vectorized = True ) # pyright: ignore
298+
299+ def close (self , close_bar : Callable [[Any ], None ], idx : IntScalarLike ):
300+ def _close (_idx ):
301+ with self .lock :
302+ _idx = _idx .item ()
303+ bar = self .bars [_idx ]
304+ close_bar (bar )
305+ del self .bars [_idx ]
306+
307+ # Unlike in `step`, we do the `unvmap_max` here. For mysterious reasons this
308+ # callback does not trigger at all otherwise.
309+ io_callback (_close , None , eqxi .unvmap_max (idx ))
310+
311+
312+ _progress_meter_manager = _ProgressMeterManager ()
0 commit comments