diff --git a/d3rlpy/dataset/episode_generator.py b/d3rlpy/dataset/episode_generator.py
index 5857bc15..b01cf5aa 100644
--- a/d3rlpy/dataset/episode_generator.py
+++ b/d3rlpy/dataset/episode_generator.py
@@ -58,9 +58,9 @@ def __init__(
         if timeouts is None:
             timeouts = np.zeros_like(terminals)
 
-        assert (
-            np.sum(np.logical_and(terminals, timeouts)) == 0
-        ), "terminals and timeouts never become True at the same time"
+        if np.sum(np.logical_and(terminals, timeouts)) != 0:
+            # In case of overlap, 'terminal' is the important end condition
+            timeouts[terminals.astype(np.bool_)] = 0.0
         assert (np.sum(terminals) + np.sum(timeouts)) > 0, (
             "No episode termination was found. Either terminals"
             " or timeouts must include non-zero values."
diff --git a/reproductions/offline/qdt.py b/reproductions/offline/qdt.py
index 0886f34e..67da1b27 100644
--- a/reproductions/offline/qdt.py
+++ b/reproductions/offline/qdt.py
@@ -120,8 +120,9 @@ def relabel_dataset_rtg(
             sampled_actions = q_algo.sample_action(episode.observations)
             v = q_algo.predict_value(episode.observations, sampled_actions)
             values.append(
-                v if q_algo.reward_scaler is None
-                else q_algo.reward_scaler.reverse_transform(v)
+                v
+                if q_algo.reward_scaler is None
+                else q_algo.reward_scaler.reverse_transform(v)
             )
         value = np.array(values).mean(axis=0)
         rewards = np.squeeze(episode.rewards, axis=1)
diff --git a/tests/dataset/test_episode_generator.py b/tests/dataset/test_episode_generator.py
index 79ea7eea..417adc11 100644
--- a/tests/dataset/test_episode_generator.py
+++ b/tests/dataset/test_episode_generator.py
@@ -1,3 +1,5 @@
+import re
+
 import numpy as np
 import pytest
 
@@ -10,9 +12,14 @@
 @pytest.mark.parametrize("observation_shape", [(4,), ((4,), (8,))])
 @pytest.mark.parametrize("action_size", [2])
 @pytest.mark.parametrize("length", [1000])
-@pytest.mark.parametrize("terminal", [False, True])
+@pytest.mark.parametrize(
+    "episode_end_type", ["terminal", "truncated", "overlap"]
+)
 def test_episode_generator(
-    observation_shape: Shape, action_size: int, length: int, terminal: bool
+    observation_shape: Shape,
+    action_size: int,
+    length: int,
+    episode_end_type: str,
 ) -> None:
     observations = create_observations(observation_shape, length)
     actions = np.random.random((length, action_size))
@@ -20,9 +27,12 @@ def test_episode_generator(
     terminals: Float32NDArray = np.zeros(length, dtype=np.float32)
     timeouts: Float32NDArray = np.zeros(length, dtype=np.float32)
     for i in range(length // 100):
-        if terminal:
+        if episode_end_type == "terminal":
             terminals[(i + 1) * 100 - 1] = 1.0
+            terminal = True
         else:
+            terminal = False
+        if episode_end_type == "truncated" or episode_end_type == "overlap":
             timeouts[(i + 1) * 100 - 1] = 1.0
 
     episode_generator = EpisodeGenerator(
@@ -48,3 +58,25 @@ def test_episode_generator(
         assert episode.actions.shape == (100, action_size)
         assert episode.rewards.shape == (100, 1)
         assert episode.terminated == terminal
+
+
+def test_episode_generator_raises_on_no_termination() -> None:
+    observations = create_observations((4,), 100)
+    actions = np.zeros((100, 2))
+    rewards: Float32NDArray = np.zeros((100, 1), dtype=np.float32)
+    terminals = np.zeros(100, dtype=np.float32)
+    timeouts = np.zeros(100, dtype=np.float32)
+
+    expected_msg = (
+        "No episode termination was found. "
+        "Either terminals or timeouts must include non-zero values."
+    )
+
+    with pytest.raises(AssertionError, match=re.escape(expected_msg)):
+        EpisodeGenerator(
+            observations=observations,
+            actions=actions,
+            rewards=rewards,
+            terminals=terminals,
+            timeouts=timeouts,
+        )