Skip to content

Commit 8045155

Browse files
[BugFix] Fix meltingpot buffer transforms (#148)
* fixes * empty
1 parent 9813807 commit 8045155

File tree

3 files changed

+8
-13
lines changed

3 files changed

+8
-13
lines changed

benchmarl/environments/common.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -282,13 +282,14 @@ def get_env_transforms(self, env: EnvBase) -> List[Transform]:
282282
"""
283283
return []
284284

285-
def get_replay_buffer_transforms(self, env: EnvBase) -> List[Transform]:
285+
def get_replay_buffer_transforms(self, env: EnvBase, group: str) -> List[Transform]:
286286
"""
287-
Returns a list of :class:`torchrl.envs.Transform` to be applied to the :class:`torchrl.data.ReplayBuffer`.
287+
Returns a list of :class:`torchrl.envs.Transform` to be applied to the :class:`torchrl.data.ReplayBuffer`
288+
of the specified group.
288289
289290
Args:
290291
env (EnvBase): An environment created via self.get_env_fun
291-
292+
group (str): The agent group using the replay buffer
292293
293294
"""
294295
return []

benchmarl/environments/meltingpot/common.py

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -125,22 +125,16 @@ def get_env_transforms(self, env: EnvBase) -> List[Transform]:
125125
else []
126126
)
127127

128-
def get_replay_buffer_transforms(self, env: EnvBase) -> List[Transform]:
128+
def get_replay_buffer_transforms(self, env: EnvBase, group: str) -> List[Transform]:
129129
return [
130130
DTypeCastTransform(
131131
dtype_in=torch.uint8,
132132
dtype_out=torch.float,
133133
in_keys=[
134134
"RGB",
135-
*[
136-
(group, "observation", "RGB")
137-
for group in self.group_map(env).keys()
138-
],
135+
(group, "observation", "RGB"),
139136
("next", "RGB"),
140-
*[
141-
("next", group, "observation", "RGB")
142-
for group in self.group_map(env).keys()
143-
],
137+
("next", group, "observation", "RGB"),
144138
],
145139
in_keys_inv=[],
146140
)

benchmarl/experiment/experiment.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -466,7 +466,7 @@ def _setup_algorithm(self):
466466
self.replay_buffers = {
467467
group: self.algorithm.get_replay_buffer(
468468
group=group,
469-
transforms=self.task.get_replay_buffer_transforms(self.test_env),
469+
transforms=self.task.get_replay_buffer_transforms(self.test_env, group),
470470
)
471471
for group in self.group_map.keys()
472472
}

0 commit comments

Comments
 (0)