Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion docs/misc/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,11 @@
Changelog
==========

Pre-Release 0.9.0a2 (WIP)
Pre-Release 0.9.0 (2020-10-03)
------------------------------

**Bug fixes, get/set parameters and improved docs**

Breaking Changes:
^^^^^^^^^^^^^^^^^
- Removed ``device`` keyword argument of policies; use ``policy.to(device)`` instead. (@qxcv)
Expand Down Expand Up @@ -50,6 +52,7 @@ Others:
- Clarified docstrings on what is saved and loaded to/from files
- Simplified ``save_to_zip_file`` function by removing duplicate code
- Store library version along with the saved models
- DQN loss is now logged

Documentation:
^^^^^^^^^^^^^^
Expand Down
3 changes: 3 additions & 0 deletions stable_baselines3/dqn/dqn.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,7 @@ def train(self, gradient_steps: int, batch_size: int = 100) -> None:
# Update learning rate according to schedule
self._update_learning_rate(self.policy.optimizer)

losses = []
for gradient_step in range(gradient_steps):
# Sample replay buffer
replay_data = self.replay_buffer.sample(batch_size, env=self._vec_normalize_env)
Expand All @@ -169,6 +170,7 @@ def train(self, gradient_steps: int, batch_size: int = 100) -> None:

# Compute Huber loss (less sensitive to outliers)
loss = F.smooth_l1_loss(current_q, target_q)
losses.append(loss.item())

# Optimize the policy
self.policy.optimizer.zero_grad()
Expand All @@ -181,6 +183,7 @@ def train(self, gradient_steps: int, batch_size: int = 100) -> None:
self._n_updates += gradient_steps

logger.record("train/n_updates", self._n_updates, exclude="tensorboard")
logger.record("train/loss", np.mean(losses))

def predict(
self,
Expand Down
2 changes: 1 addition & 1 deletion stable_baselines3/version.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
0.9.0a2
0.9.0