Skip to content

Commit a998093

Browse files
Fixed progress meters with jax.grad
1 parent ac28ce4 commit a998093

File tree

2 files changed

+194
-130
lines changed

2 files changed

+194
-130
lines changed

diffrax/_progress_meter.py

Lines changed: 130 additions & 123 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
import abc
22
import importlib.util
33
import threading
4-
from typing import Generic, TypeVar
4+
from collections.abc import Callable
5+
from typing import Any, cast, Generic, TypeVar, Union
56

67
import equinox as eqx
78
import equinox.internal as eqxi
@@ -73,10 +74,11 @@ def init(self) -> None:
7374
return None
7475

7576
def step(self, state, progress: FloatScalarLike) -> None:
77+
del progress
7678
return state
7779

7880
def close(self, state):
79-
pass
81+
del state
8082

8183

8284
NoProgressMeter.__init__.__doc__ = """**Arguments:**
@@ -91,18 +93,7 @@ def _unvmap_min(x): # No `eqxi.unvmap_min` at the moment.
9193

9294
class _TextProgressMeterState(eqx.Module):
9395
progress: FloatScalarLike
94-
95-
96-
def _print_percent_callback(progress):
97-
print(f"{100 * progress.item():.2f}%")
98-
99-
100-
def _print_percent(progress):
101-
# `io_callback` would be preferable here, to indicate that it provides an output,
102-
# but that's not supported in vmap-of-while.
103-
progress = eqxi.nonbatchable(progress) # check we'll only call the callback once.
104-
jax.debug.callback(_print_percent_callback, progress, ordered=True)
105-
return progress
96+
meter_idx: IntScalarLike
10697

10798

10899
class TextProgressMeter(AbstractProgressMeter):
@@ -118,9 +109,24 @@ class TextProgressMeter(AbstractProgressMeter):
118109

119110
minimum_increase: RealScalarLike = 0.02
120111

112+
@staticmethod
113+
def _init_bar() -> list[float]:
114+
print("0.00%")
115+
return [0.0]
116+
121117
def init(self) -> _TextProgressMeterState:
122-
_print_percent(0.0)
123-
return _TextProgressMeterState(progress=jnp.array(0.0))
118+
meter_idx = _progress_meter_manager.init(self._init_bar)
119+
return _TextProgressMeterState(meter_idx=meter_idx, progress=jnp.array(0.0))
120+
121+
@staticmethod
122+
def _step_bar(bar: list[float], progress: FloatScalarLike) -> None:
123+
if eqx.is_array(progress):
124+
# May not be an array when called with `JAX_DISABLE_JIT=1`
125+
progress = cast(Union[Array, np.ndarray], progress)
126+
progress = progress.item()
127+
progress = cast(float, progress)
128+
bar[0] = progress
129+
print(f"{100 * progress:.2f}%")
124130

125131
def step(
126132
self, state: _TextProgressMeterState, progress: FloatScalarLike
@@ -129,33 +135,31 @@ def step(
129135
# `state.progress` and `progress` will pick up a batch tracer.
130136
# (For the former, because the condition for the while-loop-over-steps becomes
131137
# batched, so necessarily everything in the body of the loop is as well.)
132-
#
133-
# We take a `min` over `progress` and a `max` over `state.progress`, as we want
134-
# to report the progress made over the worst batch element.
135-
state_progress = eqxi.unvmap_max(state.progress)
136-
del state
137-
progress = _unvmap_min(progress)
138-
pred = eqxi.nonbatchable(progress - state_progress > self.minimum_increase)
138+
pred = eqxi.unvmap_all(
139+
(progress - state.progress > self.minimum_increase) | (progress == 1)
140+
)
139141

140142
# We only print if the progress has increased by at least `minimum_increase` to
141143
# avoid flooding the user with too many updates.
142-
next_progress = jax.lax.cond(
143-
pred,
144-
_print_percent,
145-
lambda _: state_progress,
146-
progress,
144+
next_progress, meter_idx = jax.lax.cond(
145+
eqxi.nonbatchable(pred),
146+
lambda _idx: (
147+
progress,
148+
_progress_meter_manager.step(self._step_bar, progress, _idx),
149+
),
150+
lambda _idx: (state.progress, _idx),
151+
state.meter_idx,
147152
)
148153

149-
return _TextProgressMeterState(progress=next_progress)
154+
return _TextProgressMeterState(progress=next_progress, meter_idx=meter_idx)
155+
156+
@staticmethod
157+
def _close_bar(bar: list[float]):
158+
if bar[0] != 1:
159+
print("100.00%")
150160

151161
def close(self, state: _TextProgressMeterState):
152-
# As in `step`, we `unvmap` to handle batched state.
153-
# This means we only call the callback once.
154-
progress = _unvmap_min(state.progress)
155-
# Consumes `progress` without using it, to get the order of callbacks correct.
156-
progress = jax.debug.callback(
157-
lambda _: print("100.00%"), progress, ordered=True
158-
)
162+
_progress_meter_manager.close(self._close_bar, state.meter_idx)
159163

160164

161165
TextProgressMeter.__init__.__doc__ = """**Arguments:**
@@ -168,7 +172,7 @@ def close(self, state: _TextProgressMeterState):
168172

169173

170174
class _TqdmProgressMeterState(eqx.Module):
171-
progress_meter_id: IntScalarLike
175+
meter_idx: IntScalarLike
172176
step: IntScalarLike
173177

174178

@@ -184,73 +188,56 @@ def __check_init__(self):
184188
"Install it via `pip install tqdm`."
185189
)
186190

187-
def init(self) -> _TqdmProgressMeterState:
188-
# Not `pure_callback` because it's not a deterministic function of its input
189-
# arguments.
190-
# Not `debug.callback` because it has a return value.
191-
progress_meter_id = io_callback(
192-
_progress_meter_manager.init, jax.ShapeDtypeStruct((), jnp.int32)
191+
@staticmethod
192+
def _init_bar() -> "tqdm.tqdm": # pyright: ignore # noqa: F821
193+
import tqdm # pyright: ignore
194+
195+
bar_format = (
196+
"{percentage:.2f}%|{bar}| [{elapsed}<{remaining}, {rate_fmt}{postfix}]"
193197
)
194-
progress_meter_id = eqxi.nonbatchable(progress_meter_id)
195-
return _TqdmProgressMeterState(
196-
progress_meter_id=progress_meter_id, step=jnp.array(0)
198+
return tqdm.tqdm(
199+
total=100,
200+
unit="%",
201+
bar_format=bar_format,
197202
)
198203

204+
def init(self) -> _TqdmProgressMeterState:
205+
meter_idx = _progress_meter_manager.init(self._init_bar)
206+
return _TqdmProgressMeterState(meter_idx=meter_idx, step=jnp.array(0))
207+
208+
@staticmethod
209+
def _step_bar(bar: "tqdm.tqdm", progress: FloatScalarLike) -> None: # pyright: ignore # noqa: F821
210+
bar.n = round(100 * float(progress), 2)
211+
bar.update(n=0)
212+
bar.refresh()
213+
199214
def step(
200215
self,
201216
state: _TqdmProgressMeterState,
202217
progress: FloatScalarLike,
203218
) -> _TqdmProgressMeterState:
204-
# As in `TextProgressMeter`, then `state` may pick up a batch tracer from a
205-
# batched condition, so we need to handle that.
206-
#
207-
# In practice it should always be the case that this remains constant over the
208-
# solve, so we can just do a max to extract the single value we want.
209-
progress_meter_id = eqxi.unvmap_max(state.progress_meter_id)
210-
# What happens here is that all batch values for `state.step` start off in sync,
219+
# Here we update every `refresh_rate` steps in order to limit expensive
220+
# callbacks.
221+
# The `unvmap_max` is because batch values for `state.step` start off in sync,
211222
# and then eventually will freeze their values as that batch element finishes
212223
# its solve. So take a `max` to get the true number of overall solve steps for
213224
# the batched system.
214-
step = eqxi.unvmap_max(state.step)
215-
del state
216-
# Track the slowest batch element.
217-
progress = _unvmap_min(progress)
218-
219-
def update_progress_bar():
220-
# `io_callback` would be preferable here (to indicate the side-effect), but
221-
# that's not supported in vmap-of-while. (Even when none of the inputs to
222-
# the callback are batched.)
223-
jax.debug.callback(
224-
_progress_meter_manager.step, progress, progress_meter_id, ordered=True
225-
)
226-
227-
# Here we update every `refresh_rate` steps in order to limit expensive
228-
# callbacks.
229-
jax.lax.cond(
230-
eqxi.nonbatchable(step % self.refresh_steps == 0),
231-
update_progress_bar,
232-
lambda: None,
225+
meter_idx = jax.lax.cond(
226+
eqxi.nonbatchable(eqxi.unvmap_max(state.step) % self.refresh_steps == 0),
227+
lambda _idx: _progress_meter_manager.step(self._step_bar, progress, _idx),
228+
lambda _idx: _idx,
229+
state.meter_idx,
233230
)
231+
return _TqdmProgressMeterState(meter_idx=meter_idx, step=state.step + 1)
234232

235-
return _TqdmProgressMeterState(
236-
progress_meter_id=progress_meter_id, step=step + 1
237-
)
233+
@staticmethod
234+
def _close_bar(bar: "tqdm.tqdm"): # pyright: ignore # noqa: F821
235+
bar.n = 100.0
236+
bar.update(n=0)
237+
bar.close()
238238

239239
def close(self, state: _TqdmProgressMeterState):
240-
# `unvmap_max` as in `step`.
241-
progress_meter_id = eqxi.unvmap_max(state.progress_meter_id)
242-
# Pass in `step` to thread the order correctly. (`ordered=True` seems sketchy.
243-
# At the very least it doesn't also hold the order wrt
244-
# `jax.debug.callback(..., ordered=True)`.)
245-
# In addition, unvmap it to be sure the callback is only called once.
246-
step = eqxi.unvmap_max(state.step)
247-
del state
248-
io_callback(
249-
lambda idx, _: _progress_meter_manager.close(idx),
250-
None,
251-
progress_meter_id,
252-
step,
253-
)
240+
_progress_meter_manager.close(self._close_bar, state.meter_idx)
254241

255242

256243
TqdmProgressMeter.__init__.__doc__ = """**Arguments:**
@@ -261,45 +248,65 @@ def close(self, state: _TqdmProgressMeterState):
261248
"""
262249

263250

264-
class _TqdmProgressMeterManager:
265-
"""Host-side progress meter manager for TqdmProgressMeter."""
251+
class _ProgressMeterManager:
252+
"""Host-side progress meter manager."""
266253

267254
def __init__(self):
268255
self.idx = 0
269256
self.bars = {}
270257
# Not sure how important a lock really is, but included just in case.
271258
self.lock = threading.Lock()
272259

273-
def init(self) -> IntScalarLike:
274-
with self.lock:
275-
import tqdm # pyright: ignore
260+
def init(self, init_bar: Callable[[], Any]) -> IntScalarLike:
261+
def _init() -> IntScalarLike:
262+
with self.lock:
263+
bar = init_bar()
264+
self.idx += 1
265+
self.bars[self.idx] = bar
266+
return np.array(self.idx, dtype=jnp.int32)
276267

277-
bar_format = (
278-
"{percentage:.2f}%|{bar}| [{elapsed}<{remaining}, {rate_fmt}{postfix}]"
279-
)
280-
bar = tqdm.tqdm(
281-
total=100,
282-
unit="%",
283-
bar_format=bar_format,
284-
)
285-
self.idx += 1
286-
self.bars[self.idx] = bar
287-
return np.array(self.idx, dtype=jnp.int32)
288-
289-
def step(self, progress: FloatScalarLike, idx: IntScalarLike):
290-
with self.lock:
291-
bar = self.bars[int(idx)]
292-
bar.n = round(100 * float(progress), 2)
293-
bar.update(n=0)
294-
295-
def close(self, idx: IntScalarLike):
296-
with self.lock:
297-
idx = int(idx)
298-
bar = self.bars[idx]
299-
bar.n = 100.0
300-
bar.update(n=0)
301-
bar.close()
302-
del self.bars[idx]
303-
304-
305-
_progress_meter_manager = _TqdmProgressMeterManager()
268+
# Not `pure_callback` because it's not a deterministic function of its input
269+
# arguments.
270+
# Not `debug.callback` because it has a return value.
271+
meter_idx = io_callback(_init, jax.ShapeDtypeStruct((), jnp.int32))
272+
return eqxi.nonbatchable(meter_idx)
273+
274+
def step(
275+
self,
276+
step_bar: Callable[[Any, FloatScalarLike], None],
277+
progress: FloatScalarLike,
278+
idx: IntScalarLike,
279+
) -> IntScalarLike:
280+
# Track the slowest batch element.
281+
progress = _unvmap_min(progress)
282+
283+
def _step(_progress, _idx):
284+
with self.lock:
285+
try:
286+
# This may pick up a spurious batch tracer from a batched condition,
287+
# so we need to handle that. We do this by using an `np.unique`.
288+
# It should always be the case that `_idx` has precisely one value!
289+
bar = self.bars[np.unique(_idx).item()]
290+
except KeyError:
291+
pass # E.g. the backward pass after a forward pass.
292+
else:
293+
step_bar(bar, _progress)
294+
# Return the idx to thread the callbacks in the correct order.
295+
return _idx
296+
297+
return jax.pure_callback(_step, idx, progress, idx, vectorized=True) # pyright: ignore
298+
299+
def close(self, close_bar: Callable[[Any], None], idx: IntScalarLike):
300+
def _close(_idx):
301+
with self.lock:
302+
_idx = _idx.item()
303+
bar = self.bars[_idx]
304+
close_bar(bar)
305+
del self.bars[_idx]
306+
307+
# Unlike in `step`, we do the `unvmap_max` here. For mysterious reasons this
308+
# callback does not trigger at all otherwise.
309+
io_callback(_close, None, eqxi.unvmap_max(idx))
310+
311+
312+
_progress_meter_manager = _ProgressMeterManager()

0 commit comments

Comments
 (0)