We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 828f9bf commit f2e6c21Copy full SHA for f2e6c21
bsuite/baselines/jax/boot_dqn/run.py
@@ -28,8 +28,8 @@
28
29
import haiku as hk
30
from jax import lax
31
-from jax.experimental import optix
32
import jax.numpy as jnp
+import optax
33
34
# Internal imports.
35
@@ -70,7 +70,7 @@ def network(inputs: jnp.ndarray) -> jnp.ndarray:
70
x = hk.Flatten()(inputs)
71
return net(x) + prior_scale * lax.stop_gradient(prior_net(x))
72
73
- optimizer = optix.adam(learning_rate=1e-3)
+ optimizer = optax.adam(learning_rate=1e-3)
74
75
agent = boot_dqn.BootstrappedDqn(
76
obs_spec=env.observation_spec(),
0 commit comments