benchmarl/algorithms/common.py (16 additions, 2 deletions)
@@ -18,7 +18,11 @@
     ReplayBuffer,
     TensorDictReplayBuffer,
 )
-from torchrl.data.replay_buffers import RandomSampler, SamplerWithoutReplacement
+from torchrl.data.replay_buffers import (
+    PrioritizedSampler,
+    RandomSampler,
+    SamplerWithoutReplacement,
+)
 from torchrl.envs import Compose, EnvBase, Transform
 from torchrl.objectives import LossModule
 from torchrl.objectives.utils import HardUpdate, SoftUpdate, TargetNetUpdater
@@ -158,7 +162,17 @@ def get_replay_buffer(
             memory_size = -(-memory_size // sequence_length)
             sampling_size = -(-sampling_size // sequence_length)
 
-        sampler = SamplerWithoutReplacement() if self.on_policy else RandomSampler()
+        if self.on_policy:
+            sampler = SamplerWithoutReplacement()
+        elif self.experiment_config.off_policy_use_prioritized_replay_buffer:
+            sampler = PrioritizedSampler(
+                memory_size,
+                self.experiment_config.off_policy_prb_alpha,
+                self.experiment_config.off_policy_prb_beta,
+            )
+        else:
+            sampler = RandomSampler()
+
         return TensorDictReplayBuffer(
             storage=LazyTensorStorage(
                 memory_size,
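For context, the new branch above builds a TorchRL `PrioritizedSampler` inside the experiment's `TensorDictReplayBuffer`. Below is a minimal standalone sketch of that combination; the storage size, priority key, and dummy data are illustrative assumptions, not values taken from BenchMARL.

```python
# Minimal sketch of the sampler/buffer combination built above (assumed values).
import torch
from tensordict import TensorDict
from torchrl.data import LazyTensorStorage, TensorDictReplayBuffer
from torchrl.data.replay_buffers import PrioritizedSampler

memory_size = 1_000  # illustrative; BenchMARL passes its computed memory_size here
rb = TensorDictReplayBuffer(
    storage=LazyTensorStorage(memory_size),
    sampler=PrioritizedSampler(memory_size, alpha=0.6, beta=0.4),
    priority_key="td_error",  # TorchRL's default priority key; assumed for this sketch
    batch_size=32,
)

# Dummy transitions carrying a per-sample TD error used as the priority.
data = TensorDict(
    {"observation": torch.randn(100, 4), "td_error": torch.rand(100)},
    batch_size=[100],
)
rb.extend(data)

batch = rb.sample()                   # entries with larger td_error are drawn more often
rb.update_tensordict_priority(batch)  # push refreshed priorities back into the sampler
```

Samples with larger `td_error` are drawn more often, and the `_weight` entry of the sampled batch carries the importance-sampling correction controlled by `beta`.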
benchmarl/conf/experiment/base_experiment.yaml (12 additions, 5 deletions)
@@ -15,7 +15,7 @@ share_policy_params: True
 prefer_continuous_actions: True
 # If False collection is done using a collector (under no grad). If True, collection is done with gradients.
 collect_with_grad: False
-# In case of non-vectorized environments, weather to run collection of multiple processes
+# In case of non-vectorized environments, whether to run collection of multiple processes
 # If this is used, there will be n_envs_per_worker processes, collecting frames_per_batch/n_envs_per_worker frames each
 parallel_collection: False
 
@@ -34,7 +34,7 @@ clip_grad_val: 5
 soft_target_update: True
 # If soft_target_update is True, this is its polyak_tau
 polyak_tau: 0.005
-# If soft_target_update is False, this is the frequency of the hard trarget updates in terms of n_optimizer_steps
+# If soft_target_update is False, this is the frequency of the hard target updates in terms of n_optimizer_steps
 hard_target_update_frequency: 5
 
 # When an exploration wrapper is used. This is its initial epsilon for annealing
@@ -54,7 +54,7 @@ max_n_frames: 3_000_000
 on_policy_collected_frames_per_batch: 6000
 # Number of environments used for collection
 # If the environment is vectorized, this will be the number of batched environments.
-# Otherwise batching will be simulated and each env will be run sequentially or parallely depending on parallel_collection.
+# Otherwise batching will be simulated and each env will be run sequentially or parallelly depending on parallel_collection.
 on_policy_n_envs_per_worker: 10
 # This is the number of times collected_frames_per_batch will be split into minibatches and trained
 on_policy_n_minibatch_iters: 45
@@ -66,7 +66,7 @@ on_policy_minibatch_size: 400
 off_policy_collected_frames_per_batch: 6000
 # Number of environments used for collection
 # If the environment is vectorized, this will be the number of batched environments.
-# Otherwise batching will be simulated and each env will be run sequentially or parallely depending on parallel_collection.
+# Otherwise batching will be simulated and each env will be run sequentially or parallelly depending on parallel_collection.
 off_policy_n_envs_per_worker: 10
 # This is the number of times off_policy_train_batch_size will be sampled from the buffer and trained over.
 off_policy_n_optimizer_steps: 1000
@@ -76,6 +76,13 @@ off_policy_train_batch_size: 128
 off_policy_memory_size: 1_000_000
 # Number of random action frames to prefill the replay buffer with
 off_policy_init_random_frames: 0
+# whether to use priorities while sampling from the replay buffer
+off_policy_use_prioritized_replay_buffer: False
+# exponent that determines how much prioritization is used when off_policy_use_prioritized_replay_buffer = True
+# PRB reduces to random sampling when alpha=0
+off_policy_prb_alpha: 0.6
+# importance sampling negative exponent when off_policy_use_prioritized_replay_buffer = True
+off_policy_prb_beta: 0.4
 
 
 evaluation: True
@@ -108,7 +115,7 @@ restore_map_location: null
 # Interval for experiment saving in terms of collected frames (this should be a multiple of on/off_policy_collected_frames_per_batch).
 # Set it to 0 to disable checkpointing
 checkpoint_interval: 0
-# Wether to checkpoint when the experiment is done
+# Whether to checkpoint when the experiment is done
 checkpoint_at_end: False
 # How many checkpoints to keep. As new checkpoints are taken, temporally older checkpoints are deleted to keep this number of
 # checkpoints. The checkpoint at the end is included in this number. Set to `null` to keep all checkpoints.
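The three new keys default to plain uniform sampling (`off_policy_use_prioritized_replay_buffer: False`). A minimal sketch of enabling them from Python on top of the YAML defaults above; the same keys should also be reachable through BenchMARL's usual Hydra config layering.

```python
# Sketch: enable the new keys on top of the defaults loaded from base_experiment.yaml.
from benchmarl.experiment import ExperimentConfig

experiment_config = ExperimentConfig.get_from_yaml()
experiment_config.off_policy_use_prioritized_replay_buffer = True
experiment_config.off_policy_prb_alpha = 0.6  # alpha = 0 falls back to uniform sampling
experiment_config.off_policy_prb_beta = 0.4   # importance-sampling correction exponent
```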
benchmarl/experiment/experiment.py (3 additions, 0 deletions)
@@ -90,6 +90,9 @@ class ExperimentConfig:
     off_policy_train_batch_size: int = MISSING
     off_policy_memory_size: int = MISSING
     off_policy_init_random_frames: int = MISSING
+    off_policy_use_prioritized_replay_buffer: bool = MISSING
+    off_policy_prb_alpha: float = MISSING
+    off_policy_prb_beta: float = MISSING
 
     evaluation: bool = MISSING
     render: bool = MISSING
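Putting the pieces together, here is a sketch of an off-policy run that exercises the new fields. The algorithm, task, and model choices (`IqlConfig`, `VmasTask.BALANCE`, `MlpConfig`) are illustrative and follow BenchMARL's documented Python API; treat this as a usage sketch rather than part of the PR.

```python
# Hypothetical end-to-end usage; algorithm, task, and model choices are illustrative.
from benchmarl.algorithms import IqlConfig
from benchmarl.environments import VmasTask
from benchmarl.experiment import Experiment, ExperimentConfig
from benchmarl.models.mlp import MlpConfig

experiment_config = ExperimentConfig.get_from_yaml()
experiment_config.off_policy_use_prioritized_replay_buffer = True

experiment = Experiment(
    task=VmasTask.BALANCE.get_from_yaml(),
    algorithm_config=IqlConfig.get_from_yaml(),  # off-policy, so the PRB settings apply
    model_config=MlpConfig.get_from_yaml(),
    critic_model_config=MlpConfig.get_from_yaml(),
    seed=0,
    config=experiment_config,
)
experiment.run()
```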