1+ import re
2+
13import numpy as np
24import pytest
35
1012@pytest .mark .parametrize ("observation_shape" , [(4 ,), ((4 ,), (8 ,))])
1113@pytest .mark .parametrize ("action_size" , [2 ])
1214@pytest .mark .parametrize ("length" , [1000 ])
13- @pytest .mark .parametrize ("terminal" , [False , True ])
15+ @pytest .mark .parametrize (
16+ "episode_end_type" , ["terminal" , "truncated" , "overlap" ]
17+ )
1418def test_episode_generator (
15- observation_shape : Shape , action_size : int , length : int , terminal : bool
19+ observation_shape : Shape ,
20+ action_size : int ,
21+ length : int ,
22+ episode_end_type : str ,
1623) -> None :
1724 observations = create_observations (observation_shape , length )
1825 actions = np .random .random ((length , action_size ))
1926 rewards : Float32NDArray = np .random .random ((length , 1 )).astype (np .float32 )
2027 terminals : Float32NDArray = np .zeros (length , dtype = np .float32 )
2128 timeouts : Float32NDArray = np .zeros (length , dtype = np .float32 )
2229 for i in range (length // 100 ):
23- if terminal :
30+ if episode_end_type == "terminal" :
2431 terminals [(i + 1 ) * 100 - 1 ] = 1.0
32+ terminal = True
2533 else :
34+ terminal = False
35+ if episode_end_type == "truncated" or episode_end_type == "overlap" :
2636 timeouts [(i + 1 ) * 100 - 1 ] = 1.0
2737
2838 episode_generator = EpisodeGenerator (
@@ -48,3 +58,25 @@ def test_episode_generator(
4858 assert episode .actions .shape == (100 , action_size )
4959 assert episode .rewards .shape == (100 , 1 )
5060 assert episode .terminated == terminal
61+
62+
def test_episode_generator_raises_on_no_termination() -> None:
    """Check that EpisodeGenerator rejects data with no episode end.

    Both ``terminals`` and ``timeouts`` are all zeros here, and constructing
    the generator is expected to fail with an ``AssertionError`` carrying
    the message asserted below.
    """
    num_steps = 100
    zero_observations = create_observations((4,), num_steps)
    zero_actions = np.zeros((num_steps, 2))
    zero_rewards: Float32NDArray = np.zeros((num_steps, 1), dtype=np.float32)
    # Neither flag array contains a non-zero entry.
    no_terminals = np.zeros(num_steps, dtype=np.float32)
    no_timeouts = np.zeros(num_steps, dtype=np.float32)

    # Escape the message so pytest matches it literally, not as a regex.
    error_pattern = re.escape(
        "No episode termination was found. "
        "Either terminals or timeouts must include non-zero values."
    )

    with pytest.raises(AssertionError, match=error_pattern):
        EpisodeGenerator(
            observations=zero_observations,
            actions=zero_actions,
            rewards=zero_rewards,
            terminals=no_terminals,
            timeouts=no_timeouts,
        )
0 commit comments