64 changes: 54 additions & 10 deletions docs/src/How_to_implement_a_new_algorithm.md
@@ -1,6 +1,6 @@
# How to implement a new algorithm

All algorithms in ReinforcementLearning.jl are based on a common `run` function defined in [run.jl](https://github.com/JuliaReinforcementLearning/ReinforcementLearning.jl/blob/main/src/ReinforcementLearningCore/src/core/run.jl) that is dispatched based on the type of its arguments. As you can see, the `run` function first performs a check and then calls a "private" `_run(policy::AbstractPolicy, env::AbstractEnv, stop_condition, hook::AbstractHook)`; this is the main function we are interested in. It consists of an outer and an inner loop that repeatedly call `policy(stage, env)`.
All algorithms in ReinforcementLearning.jl are based on a common `run` function defined in [run.jl](https://github.com/JuliaReinforcementLearning/ReinforcementLearning.jl/blob/main/src/ReinforcementLearningCore/src/core/run.jl) that is dispatched based on the type of its arguments. As you can see, the `run` function first performs a check and then calls a "private" `_run(policy::AbstractPolicy, env::AbstractEnv, stop_condition, hook::AbstractHook)`; this is the main function we are interested in. It consists of an outer and an inner loop that repeatedly call `push!(policy, stage, env)` and `optimise!(policy, stage)`.

Let's look at it more closely in this simplified version (hooks are discussed [here](./How_to_use_hooks.md)):

@@ -16,36 +16,40 @@ function _run(policy::AbstractPolicy,
    while !is_stop
        reset!(env)
        push!(policy, PreEpisodeStage(), env)
        optimise!(policy, PreEpisodeStage())

        while !reset_condition(policy, env) # one episode
            push!(policy, PreActStage(), env)
            optimise!(policy, PreActStage())

            action = RLBase.plan!(policy, env)
            act!(env, action)

            optimise!(policy)

            push!(policy, PostActStage(), env)
            optimise!(policy, PostActStage())

            if check_stop(stop_condition, policy, env)
                is_stop = true
                push!(policy, PreActStage(), env)
                optimise!(policy, PreActStage())
                RLBase.plan!(policy, env) # let the policy see the last observation
                break
            end
        end # end of an episode

        push!(policy, PostEpisodeStage(), env) # let the policy see the last observation
        optimise!(policy, PostEpisodeStage())

    end
    push!(policy, PostExperimentStage(), env)
end

```

Implementing a new algorithm mainly consists of creating your own `AbstractPolicy` subtype, its action sampling method `Base.push!(policy::PolicyType, env)` and implementing its behavior at each stage. However, ReinforcementLearning.jl provides plenty of pre-implemented utilities that you should use to 1) have less code to write, 2) lower the chances of bugs, and 3) make your code more understandable and maintainable (if you intend to contribute your algorithm).
Implementing a new algorithm mainly consists of creating your own `AbstractPolicy` subtype, implementing its action sampling method (by overloading `RLBase.plan!(policy::YourPolicyType, env)`), and defining its behavior at each stage. However, ReinforcementLearning.jl provides plenty of pre-implemented utilities that you should use to 1) have less code to write, 2) lower the chances of bugs, and 3) make your code more understandable and maintainable (if you intend to contribute your algorithm).
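
To make the required interface concrete, here is a minimal sketch of such a subtype together with its `RLBase.plan!` method. `MyPolicy`, its ε-greedy lookup table, and its field names are invented for this example; only `AbstractPolicy`, `RLBase.plan!`, `state`, and `action_space` come from the library.

```julia
import ReinforcementLearningBase as RLBase
using ReinforcementLearningBase: AbstractPolicy, AbstractEnv, state, action_space

# Hypothetical ε-greedy tabular policy, just to illustrate the required method.
struct MyPolicy <: AbstractPolicy
    q_table::Dict{Any,Vector{Float64}} # state => estimated value of each action
    ϵ::Float64                         # exploration probability
end

function RLBase.plan!(p::MyPolicy, env::AbstractEnv)
    A = action_space(env)
    s = state(env)
    if !haskey(p.q_table, s) || rand() < p.ϵ
        rand(A)                 # explore: pick a random action
    else
        A[argmax(p.q_table[s])] # exploit: pick the best known action
    end
end
```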

## Using Agents
A better way is to use the policy wrapper `Agent`. An agent is an AbstractPolicy that wraps a policy and a trajectory (also called Experience Replay Buffer in RL literature). Agent comes with default implementations of `push!(agent, stage, env)` that will probably fit what you need at most stages so that you don't have to write them again. Looking at the [source code](https://github.com/JuliaReinforcementLearning/ReinforcementLearning.jl/blob/main/src/ReinforcementLearningCore/src/policies/agent.jl/), we can see that the default Agent calls are
The recommended way is to use the policy wrapper `Agent`. An agent is itself an `AbstractPolicy` that wraps a policy and a trajectory (also called an Experience Replay Buffer in the RL literature). `Agent` comes with default implementations of `push!(agent, stage, env)` that will probably fit what you need at most stages, so you don't have to write them yourself. Looking at the [source code](https://github.com/JuliaReinforcementLearning/ReinforcementLearning.jl/blob/main/src/ReinforcementLearningCore/src/policies/agent.jl/), we can see that the default `Agent` methods are

```julia
function RLBase.plan!(agent::Agent{P,T,C}, env::AbstractEnv) where {P,T,C}
@@ -63,18 +67,58 @@ function Base.push!(agent::Agent{P,T,C}, ::PostActStage, env::E) where {P,T,C,E<
end
```

The default behavior at other stages is a no-op. The first function, `RLBase.plan!(agent::Agent, env::AbstractEnv)`, is called at the `action = RLBase.plan!(policy, env)` line. It gets an action from the policy (since you implemented the `RLBase.plan!(your_new_policy, env)` function), then it pushes its `cache` and the action to the trajectory of the agent. Finally, it empties the cache and returns the action (which is immediately applied to env after). At the `PreActStage()` and the `PostActStage()`, the agent simply records the current state of the environment, the returned reward and the terminal state signal to its cache (to be pushed to the trajectory by the first function).
The first function, `RLBase.plan!(agent::Agent, env::AbstractEnv)`, is called at the `action = RLBase.plan!(policy, env)` line. It gets an action from the policy (since you implemented the `RLBase.plan!(your_new_policy, env)` function), then pushes its `cache` and the action to the agent's trajectory. Finally, it empties the cache and returns the action (which is then immediately applied to the environment). At the `PreActStage()` and the `PostActStage()`, the agent simply records the current state of the environment, the returned reward, and the terminal signal to its cache (to be pushed to the trajectory by the first function). Notice that the `PreActStage` push and plan are also called at the end of the episode loop, to store the very last state and action in the trajectory.

If you need a different behavior at some stages, then you can overload the `Base.push!(Agent{<:YourPolicyType}, [stage,] env)` or `Base.push!(Agent{<:Any, <: YourTrajectoryType}, [stage,] env)`, depending on whether you have a custom policy or just a custom trajectory. For example, many algorithms (such as PPO) need to store an additional trace of the logpdf of the sampled actions and thus overload the function at the `PreActStage()`.
If you need a different behavior at some stages, you can overload `Base.push!(Agent{<:YourPolicyType}, [stage,] env)` or `Base.push!(Agent{<:Any, <:YourTrajectoryType}, [stage,] env)` (or `Base.plan!`), depending on whether you have a custom policy or just a custom trajectory. For example, many algorithms (such as PPO) need to store an additional trace of the logpdf of the sampled actions and thus overload the function at the `PreActStage()`.
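
As a purely illustrative sketch of that pattern (not the actual PPO code): suppose `MyPPOPolicy` is a hypothetical custom policy whose trajectory container was built with an extra `:log_prob` trace, and `action_log_prob` is a hypothetical helper returning the logpdf of the action about to be taken. The overload could then look like this; whether your container accepts such a partial push depends on how you constructed it.

```julia
using ReinforcementLearningBase, ReinforcementLearningCore

struct MyPPOPolicy <: AbstractPolicy end # hypothetical policy type

function Base.push!(agent::Agent{<:MyPPOPolicy}, ::PreActStage, env::AbstractEnv)
    s = state(env)
    # record the state and the (hypothetical) logpdf of the upcoming action
    push!(agent.trajectory, (state = s, log_prob = action_log_prob(agent.policy, s)))
end
```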

## Updating the policy

Finally, you need to implement the learning function by implementing `RLBase.optimise!(p::YourPolicyType, batch::NamedTuple)` (see that it is called by `optimise!(agent)` then `RLBase.optimise!(p::YourPolicyType, b::Trajectory)`).
In principle you can do the update at other stages by overloading `push!(agent::Agent, ...)`, but this is not recommended because the trajectory may not be consistent and samples could be incorrect. If you choose to do it, make sure you know what you are doing.
Finally, you need to implement the learning step by overloading `RLBase.optimise!(::YourPolicyType, ::Stage, ::Trajectory)`. By default this function does nothing at every stage. Overload it for the stage at which you wish to optimise (most often `PreActStage` or `PostEpisodeStage`). The function should loop over the trajectory to sample batches; inside the loop, do whatever the update requires. For example:

```julia
function RLBase.optimise!(p::YourPolicyType, ::PostEpisodeStage, traj::Trajectory)
    for batch in traj
        optimise!(p, batch)
    end
end

```
where `optimise!(p, batch)` is a function that will typically compute a gradient and update a neural network, or update a tabular policy. What goes inside the loop is entirely up to you. This is further discussed in the next section on `Trajectory`s.
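
To make this concrete, below is a per-batch update sketch for a hypothetical Q-learning style policy. `MyDQNPolicy`, its fields, and the plain mean-squared TD error are illustrative assumptions rather than the library's implementation; a real learner (such as `BasicDQNLearner`) additionally handles device transfer, Huber loss, target networks, and so on.

```julia
using Flux
import ReinforcementLearningBase as RLBase

# Hypothetical policy: a Flux model plus its optimiser state (from Flux.setup).
struct MyDQNPolicy{M,O} <: RLBase.AbstractPolicy
    model::M
    optimiser_state::O
    γ::Float32
end

function RLBase.optimise!(p::MyDQNPolicy, batch::NamedTuple)
    s, a, r, t, s′ = batch.state, batch.action, batch.reward, batch.terminal, batch.next_state
    # One-step TD target; no gradient is taken through it.
    q′ = vec(maximum(p.model(s′); dims = 1))
    target = r .+ p.γ .* (1 .- t) .* q′
    idx = CartesianIndex.(a, 1:length(a)) # pick the Q-value of the action actually taken
    grads = Flux.gradient(p.model) do m
        q = m(s)[idx]
        sum(abs2, q .- target) / length(target) # mean-squared TD error
    end
    Flux.update!(p.optimiser_state, p.model, grads[1])
end
```

Here `optimiser_state` is assumed to have been created with `Flux.setup(Adam(), model)`.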

## ReinforcementLearningTrajectories

Trajectories are handled in a stand-alone package called [ReinforcementLearningTrajectories](https://github.com/JuliaReinforcementLearning/ReinforcementLearningTrajectories.jl). Refer to its documentation (in progress) to learn how to use it.
Trajectories are handled in a stand-alone package called [ReinforcementLearningTrajectories](https://github.com/JuliaReinforcementLearning/ReinforcementLearningTrajectories.jl). However, it is core to the implementation of your algorithm as it controls many aspects of it, such as the batch size, the sampling frequency, or the replay buffer length.
A `Trajectory` is composed of three elements: a `container`, a `controller`, and a `sampler`.

### Container

The container is typically an `AbstractTraces`, an object that stores a set of `Trace`s in a structured manner. You can either define your own (and contribute it to the package if it is likely to be useful for other algorithms), or use a predefined one if it exists.

The most common `AbstractTraces` object is `CircularArraySARTTraces`. This is a fixed-length container that stores the following traces: `:state` (S), `:action` (A), `:reward` (R), `:terminal` (T), which together are aliased to `SART = (:state, :action, :reward, :terminal)`. As an example of how to build a custom set of traces, let us look at how it is constructed in this simplified version.

```julia
function (capacity, state_size, state_eltype, action_size, action_eltype, reward_eltype)
    MultiplexTraces{SS′}(CircularArrayBuffer{state_eltype}(state_size..., capacity + 1)) +
    MultiplexTraces{AA′}(CircularArrayBuffer{action_eltype}(action_size..., capacity + 1)) +
    Traces(
        reward = CircularArrayBuffer{reward_eltype}(1, capacity),
        terminal = CircularArrayBuffer{Bool}(1, capacity),
    )
end
```
We can see it is composed (with the `+` operator) of two `MultiplexTraces` and a `Traces`.
- `MultiplexTraces` is a special trace that stores two names in one container. In this case, the two names of the first one are `SS′ = (:state, :next_state)`. When `:next_state` is sampled at index `i`, it returns the state stored at `i+1`. This way, states and next states are managed together seamlessly (notice, however, that their buffers must be allocated with `capacity + 1`).
- `Traces` is for simpler traces: define a name for each one (`reward` and `terminal` here) and assign it a container.

The containers used here are `CircularArrayBuffer`s. These are preallocated arrays that, once full, overwrite the oldest element in storage, as if the array were circular. A `CircularArrayBuffer` takes as arguments the size of each of its dimensions, where the last one is the capacity of the buffer. For example, if a state is a 256 x 256 image, `state_size` would be the tuple `(256, 256)`; for vector states use `(256,)`, and for scalars `1` or `()`.
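
For instance, a container for an environment with 256 x 256 image states, scalar discrete actions, and a replay capacity of 1000 could be created with the keyword constructor sketched below (the keyword names and `eltype => size` pairs are assumed from ReinforcementLearningTrajectories; check the package for the exact signature):

```julia
using ReinforcementLearningTrajectories

traces = CircularArraySARTTraces(;
    capacity = 1_000,
    state = Float32 => (256, 256), # eltype => size of a single state
    action = Int => (),            # scalar discrete action
    reward = Float32 => (),
    terminal = Bool => (),
)
```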

### Controller

ReinforcementLearningTrajectories' design aims to eventually support distributed experience collection, hence the somewhat involved design of trajectories and the presence of a controller. The controller is an object that decides when the trajectory is ready to be sampled. Let us look at the only controller implemented so far: `InsertSampleRatioController(ratio, threshold)`. Despite its name, it is quite simple: the controller records the number of insertions (`ins`) into the trajectory and the number of batches sampled (`sam`); if `sam/ins > ratio`, the controller stops the batch-sampling loop. For example, a ratio of 1/1000 means that one batch will be sampled for every 1000 insertions into the trajectory. `threshold` is simply the minimum number of insertions required before the controller allows any sampling.
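
For example, using the `InsertSampleRatioController(ratio, threshold)` signature described above, the following controller waits for 10_000 insertions and then allows roughly one batch to be sampled per 1000 insertions:

```julia
controller = InsertSampleRatioController(1 / 1_000, 10_000)
```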

### Sampler

The sampler is the object that fetches data from your trajectory to create the `batch` used in the `optimise!` for loop. The simplest one is `BatchSampler{names}(batchsize, rng)`. `batchsize` is the number of elements to sample and `rng` is an optional argument that you may set to a custom rng for reproducibility. `names` is the set of traces the sampler must query. For example, a `BatchSampler{(:state, :action, :next_state)}(32)` will sample a named tuple `(state = [32 states], action = [32 actions], next_state = [32 states, each one step ahead of the corresponding state])`.
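
Putting the three pieces together, a complete `Trajectory` for an environment with 4-dimensional vector states might be assembled as in the sketch below; the keyword-style `Trajectory` constructor and the default trace types are assumptions, so refer to ReinforcementLearningTrajectories for the authoritative API. The resulting trajectory is what you pass to `Agent(policy = ..., trajectory = ...)`.

```julia
using ReinforcementLearningTrajectories

trajectory = Trajectory(
    container = CircularArraySARTTraces(; capacity = 10_000, state = Float32 => (4,)),
    sampler = BatchSampler{(:state, :next_state, :action, :reward, :terminal)}(32),
    controller = InsertSampleRatioController(1 / 100, 500),
)
```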

## Using resources from RLCore

2 changes: 1 addition & 1 deletion src/ReinforcementLearningBase/src/interface.jl
@@ -39,7 +39,7 @@ action is determined by a `plan!` which takes an environment and policy and retu
@api plan!(π::AbstractPolicy, env)

"""
optimise!(π::AbstractPolicy, experience)
RLBase.optimise!(π::AbstractPolicy, experience)

Optimise the policy `π` with online/offline experience or parameters.
"""
11 changes: 4 additions & 7 deletions src/ReinforcementLearningCore/src/policies/agent/base.jl
@@ -39,18 +39,15 @@ end

Agent(;policy, trajectory, cache = SRT()) = Agent(policy, trajectory, cache)

RLBase.optimise!(agent::Agent, stage::S) where {S<:AbstractStage} = optimise!(TrajectoryStyle(agent.trajectory), agent, stage)
RLBase.optimise!(agent::Agent, stage::S) where {S<:AbstractStage} = RLBase.optimise!(TrajectoryStyle(agent.trajectory), agent, stage)
RLBase.optimise!(::SyncTrajectoryStyle, agent::Agent, stage::S) where {S<:AbstractStage} =
optimise!(agent.policy, stage, agent.trajectory)
RLBase.optimise!(agent.policy, stage, agent.trajectory)

# a task to optimise the inner policy is already spawned when the agent is initialized
RLBase.optimise!(::AsyncTrajectoryStyle, agent::Agent, stage::S) where {S<:AbstractStage} = nothing

function RLBase.optimise!(policy::AbstractPolicy, stage::S, trajectory::Trajectory) where {S<:AbstractStage}
for batch in trajectory
optimise!(policy, stage, batch)
end
end
# by default, optimise! does nothing at any stage
function RLBase.optimise!(policy::AbstractPolicy, stage::AbstractStage, trajectory::Trajectory) end

@functor Agent (policy,)

6 changes: 5 additions & 1 deletion src/ReinforcementLearningCore/src/policies/q_based_policy.jl
@@ -37,4 +37,8 @@ end
RLBase.prob(p::QBasedPolicy{L,Ex}, env::AbstractEnv) where {L<:AbstractLearner,Ex<:AbstractExplorer} =
prob(p.explorer, forward(p.learner, env), legal_action_space_mask(env))

RLBase.optimise!(p::QBasedPolicy{L,Ex}, stage::S, x::NamedTuple) where {L<:AbstractLearner,Ex<:AbstractExplorer, S<:AbstractStage} = optimise!(p.learner, x)
function RLBase.optimise!(p::QBasedPolicy{L,Ex}, ::PostActStage, trajectory::Trajectory) where {L<:AbstractLearner,Ex<:AbstractExplorer}
    for batch in trajectory
        RLBase.optimise!(p.learner, batch)
    end
end
@@ -7,7 +7,7 @@
# ---

#+ tangle=true
using ReinforcementLearningCore, ReinforcementLearningBase, ReinforcementLearningZoo, ReinforcementLearningZoo
using ReinforcementLearningCore, ReinforcementLearningBase, ReinforcementLearningZoo
using ReinforcementLearningEnvironments
using Flux
using Flux: glorot_uniform
@@ -33,7 +33,7 @@ function RLCore.Experiment(
Dense(ns, 128, relu; init=glorot_uniform(rng)),
Dense(128, 128, relu; init=glorot_uniform(rng)),
Dense(128, na; init=glorot_uniform(rng)),
) |> gpu,
),
optimiser=Adam(),
),
loss_func=huber_loss,
@@ -25,12 +25,12 @@ function RLCore.Experiment(
agent = Agent(
policy = QBasedPolicy(
learner = BasicDQNLearner(
approximator = NeuralNetworkApproximator(
approximator = Approximator(
model = Chain(
Dense(ns, 64, relu; init = glorot_uniform(rng)),
Dense(64, 64, relu; init = glorot_uniform(rng)),
Dense(64, na; init = glorot_uniform(rng)),
) |> gpu,
),
optimizer = Adam(),
),
batch_size = 32,
@@ -30,7 +30,7 @@ function RLCore.Experiment(
Dense(ns, 64, relu; init = glorot_uniform(rng)),
Dense(64, 64, relu; init = glorot_uniform(rng)),
Dense(64, na; init = glorot_uniform(rng)),
) |> gpu,
),
optimizer = Adam(),
),
batch_size = 32,
@@ -37,7 +37,7 @@ function RLCore.Experiment(
Dense(ns, 128, relu; init=glorot_uniform(rng)),
Dense(128, 128, relu; init=glorot_uniform(rng)),
Dense(128, na; init=glorot_uniform(rng)),
) |> gpu,
),
optimizer=Adam(),
),
batch_size=32,
@@ -41,7 +41,7 @@ function RLCore.Experiment(
sync_freq=100
),
optimiser=Adam(),
) |> gpu,
),
n=n,
γ=γ,
is_enable_double_DQN=is_enable_double_DQN,
@@ -26,20 +26,20 @@ function RLCore.Experiment(
agent = Agent(
policy = QBasedPolicy(
learner = DQNLearner(
approximator = NeuralNetworkApproximator(
approximator = Approximator(
model = Chain(
Dense(ns, 64, relu; init = glorot_uniform(rng)),
Dense(64, 64, relu; init = glorot_uniform(rng)),
Dense(64, na; init = glorot_uniform(rng)),
) |> gpu,
),
optimizer = Adam(),
),
target_approximator = NeuralNetworkApproximator(
target_approximator = Approximator(
model = Chain(
Dense(ns, 64, relu; init = glorot_uniform(rng)),
Dense(64, 64, relu; init = glorot_uniform(rng)),
Dense(64, na; init = glorot_uniform(rng)),
) |> gpu,
),
optimizer = Adam(),
),
loss_func = huber_loss,
@@ -34,7 +34,7 @@ function RLCore.Experiment(
ψ=Dense(ns, n_hidden, relu; init=init),
ϕ=Dense(Nₑₘ, n_hidden, relu; init=init),
header=Dense(n_hidden, na; init=init),
) |> gpu
)

agent = Agent(
policy=QBasedPolicy(
@@ -82,7 +82,7 @@

for i in 1:300
_, batch = s(hook.records)
optimise!(bc, batch)
RLBase.optimise!(bc, batch)
end

hook = TotalRewardPerEpisode()
10 changes: 5 additions & 5 deletions src/ReinforcementLearningZoo/src/algorithms/dqns/basic_dqn.jl
@@ -23,9 +23,9 @@ own customized algorithm.
- `loss_func=huber_loss`: the loss function to use.
- `γ::Float32=0.99f0`: discount rate.
"""
Base.@kwdef mutable struct BasicDQNLearner{Q} <: AbstractLearner
Base.@kwdef mutable struct BasicDQNLearner{Q, F} <: AbstractLearner
approximator::Q
loss_func::Any = huber_loss
loss_func::F = huber_loss
γ::Float32 = 0.99f0
# for debugging
loss::Float32 = 0.0f0
@@ -37,13 +37,13 @@ RLCore.forward(L::BasicDQNLearner, s::AbstractArray) = RLCore.forward(L.approxim

function RLCore.optimise!(
learner::BasicDQNLearner,
batch::NamedTuple{SS′ART},
batch::NamedTuple
)

Q = learner.approximator
γ = learner.γ
loss_func = learner.loss_func

s, s′, a, r, t = send_to_device(device(Q), batch)
a = CartesianIndex.(a, 1:length(a))

@@ -58,5 +58,5 @@ function RLCore.optimise!(
loss
end

optimise!(Q, gs)
RLBase.optimise!(Q, gs)
end