Skip to content

Commit 38c34b6

Browse files
authored
handle overlapping terminal and timeout flags by prioritizing terminals (#463)
1 parent 4f0956b commit 38c34b6

File tree

3 files changed

+41
-8
lines changed

3 files changed

+41
-8
lines changed

d3rlpy/dataset/episode_generator.py

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -58,9 +58,9 @@ def __init__(
5858
if timeouts is None:
5959
timeouts = np.zeros_like(terminals)
6060

61-
assert (
62-
np.sum(np.logical_and(terminals, timeouts)) == 0
63-
), "terminals and timeouts never become True at the same time"
61+
if np.sum(np.logical_and(terminals, timeouts)) != 0:
62+
# In case of overlap, 'terminal' is the important end condition
63+
timeouts[terminals.astype(np.bool_)] = 0.0
6464
assert (np.sum(terminals) + np.sum(timeouts)) > 0, (
6565
"No episode termination was found. Either terminals"
6666
" or timeouts must include non-zero values."

reproductions/offline/qdt.py

Lines changed: 3 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -120,8 +120,9 @@ def relabel_dataset_rtg(
120120
sampled_actions = q_algo.sample_action(episode.observations)
121121
v = q_algo.predict_value(episode.observations, sampled_actions)
122122
values.append(
123-
v if q_algo.reward_scaler is None
124-
else q_algo.reward_scaler.reverse_transform(v)
123+
v
124+
if q_algo.reward_scaler is None
125+
else q_algo.reward_scaler.reverse_transform(v)
125126
)
126127
value = np.array(values).mean(axis=0)
127128
rewards = np.squeeze(episode.rewards, axis=1)

tests/dataset/test_episode_generator.py

Lines changed: 35 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -1,3 +1,5 @@
1+
import re
2+
13
import numpy as np
24
import pytest
35

@@ -10,19 +12,27 @@
1012
@pytest.mark.parametrize("observation_shape", [(4,), ((4,), (8,))])
1113
@pytest.mark.parametrize("action_size", [2])
1214
@pytest.mark.parametrize("length", [1000])
13-
@pytest.mark.parametrize("terminal", [False, True])
15+
@pytest.mark.parametrize(
16+
"episode_end_type", ["terminal", "truncated", "overlap"]
17+
)
1418
def test_episode_generator(
15-
observation_shape: Shape, action_size: int, length: int, terminal: bool
19+
observation_shape: Shape,
20+
action_size: int,
21+
length: int,
22+
episode_end_type: str,
1623
) -> None:
1724
observations = create_observations(observation_shape, length)
1825
actions = np.random.random((length, action_size))
1926
rewards: Float32NDArray = np.random.random((length, 1)).astype(np.float32)
2027
terminals: Float32NDArray = np.zeros(length, dtype=np.float32)
2128
timeouts: Float32NDArray = np.zeros(length, dtype=np.float32)
2229
for i in range(length // 100):
23-
if terminal:
30+
if episode_end_type == "terminal":
2431
terminals[(i + 1) * 100 - 1] = 1.0
32+
terminal = True
2533
else:
34+
terminal = False
35+
if episode_end_type == "truncated" or episode_end_type == "overlap":
2636
timeouts[(i + 1) * 100 - 1] = 1.0
2737

2838
episode_generator = EpisodeGenerator(
@@ -48,3 +58,25 @@ def test_episode_generator(
4858
assert episode.actions.shape == (100, action_size)
4959
assert episode.rewards.shape == (100, 1)
5060
assert episode.terminated == terminal
61+
62+
63+
def test_episode_generator_raises_on_no_termination() -> None:
64+
observations = create_observations((4,), 100)
65+
actions = np.zeros((100, 2))
66+
rewards: Float32NDArray = np.zeros((100, 1), dtype=np.float32)
67+
terminals = np.zeros(100, dtype=np.float32)
68+
timeouts = np.zeros(100, dtype=np.float32)
69+
70+
expected_msg = (
71+
"No episode termination was found. "
72+
"Either terminals or timeouts must include non-zero values."
73+
)
74+
75+
with pytest.raises(AssertionError, match=re.escape(expected_msg)):
76+
EpisodeGenerator(
77+
observations=observations,
78+
actions=actions,
79+
rewards=rewards,
80+
terminals=terminals,
81+
timeouts=timeouts,
82+
)

0 commit comments

Comments
 (0)