diff --git a/docs/src/How_to_implement_a_new_algorithm.md b/docs/src/How_to_implement_a_new_algorithm.md
index ab262ff29..6e723f7c2 100644
--- a/docs/src/How_to_implement_a_new_algorithm.md
+++ b/docs/src/How_to_implement_a_new_algorithm.md
@@ -1,6 +1,6 @@
 # How to implement a new algorithm

-All algorithms in ReinforcementLearning.jl are based on a common `run` function defined in [run.jl](https://github.com/JuliaReinforcementLearning/ReinforcementLearning.jl/blob/main/src/ReinforcementLearningCore/src/core/run.jl) that will be dispatched based on the type of its arguments. As you can see, the run function first performs a check and then calls a "private" `_run(policy::AbstractPolicy, env::AbstractEnv, stop_condition, hook::AbstractHook)`, this is the main function we are interested in. It consists of an outer and an inner loop that will repeateadly call `policy(stage, env)`.
+All algorithms in ReinforcementLearning.jl are based on a common `run` function defined in [run.jl](https://github.com/JuliaReinforcementLearning/ReinforcementLearning.jl/blob/main/src/ReinforcementLearningCore/src/core/run.jl) that is dispatched based on the type of its arguments. As you can see, the `run` function first performs a check and then calls a "private" `_run(policy::AbstractPolicy, env::AbstractEnv, stop_condition, hook::AbstractHook)`; this is the main function we are interested in. It consists of an outer and an inner loop that repeatedly call `push!(policy, stage, env)` and `optimise!(policy, stage)`.

 Let's look at it closer in this simplified version (hooks are discussed [here](./How_to_use_hooks.md)):

@@ -16,36 +16,40 @@ function _run(policy::AbstractPolicy,
     while !is_stop
         reset!(env)
         push!(policy, PreEpisodeStage(), env)
+        optimise!(policy, PreEpisodeStage())

         while !reset_condition(policy, env) # one episode
             push!(policy, PreActStage(), env)
+            optimise!(policy, PreActStage())

             action = RLBase.plan!(policy, env)
             act!(env, action)

-            optimise!(policy)
-            push!(policy, PostActStage(), env)
+            optimise!(policy, PostActStage())

             if check_stop(stop_condition, policy, env)
                 is_stop = true
                 push!(policy, PreActStage(), env)
+                optimise!(policy, PreActStage())
                 RLBase.plan!(policy, env) # let the policy see the last observation
                 break
             end
         end # end of an episode

         push!(policy, PostEpisodeStage(), env) # let the policy see the last observation
+        optimise!(policy, PostEpisodeStage())
+
     end
     push!(policy, PostExperimentStage(), env)
 end
 ```

-Implementing a new algorithm mainly consists of creating your own `AbstractPolicy` subtype, its action sampling method `Base.push!(policy::PolicyType, env)` and implementing its behavior at each stage. However, ReinforcemementLearning.jl provides plenty of pre-implemented utilities that you should use to 1) have less code to write 2) lower the chances of bugs and 3) make your code more understandable and maintainable (if you intend to contribute your algorithm).
+Implementing a new algorithm mainly consists of creating your own `AbstractPolicy` subtype, defining its action sampling method (by overloading `RLBase.plan!(policy::YourPolicyType, env)`), and implementing its behavior at each stage. However, ReinforcementLearning.jl provides plenty of pre-implemented utilities that you should use to 1) write less code, 2) lower the chances of bugs, and 3) make your code more understandable and maintainable (if you intend to contribute your algorithm).
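+
+To make this concrete, here is a minimal sketch of such a subtype. The `RandomishPolicy` name and its fields are made up for this illustration; only the `RLBase.plan!` method is strictly required, and the other stages fall back to the framework's no-op defaults:
+
+```julia
+using ReinforcementLearningBase
+
+# A hypothetical policy that samples actions uniformly at random and counts episodes.
+mutable struct RandomishPolicy{R} <: AbstractPolicy
+    rng::R
+    episodes::Int
+end
+
+# Action sampling: the run loop calls this as `action = RLBase.plan!(policy, env)`.
+RLBase.plan!(p::RandomishPolicy, env::AbstractEnv) = rand(p.rng, action_space(env))
+
+# Stage behavior: react only to the stages you care about.
+Base.push!(p::RandomishPolicy, ::PostEpisodeStage, env::AbstractEnv) = p.episodes += 1
+```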

 ## Using Agents

-A better way is to use the policy wrapper `Agent`. An agent is an AbstractPolicy that wraps a policy and a trajectory (also called Experience Replay Buffer in RL literature). Agent comes with default implementations of `push!(agent, stage, env)` that will probably fit what you need at most stages so that you don't have to write them again. Looking at the [source code](https://github.com/JuliaReinforcementLearning/ReinforcementLearning.jl/blob/main/src/ReinforcementLearningCore/src/policies/agent.jl/), we can see that the default Agent calls are
+The recommended way is to use the policy wrapper `Agent`. An agent is itself an `AbstractPolicy` that wraps a policy and a trajectory (also called an Experience Replay Buffer in the RL literature). `Agent` comes with default implementations of `push!(agent, stage, env)` that will probably fit what you need at most stages, so that you don't have to write them again. Looking at the [source code](https://github.com/JuliaReinforcementLearning/ReinforcementLearning.jl/blob/main/src/ReinforcementLearningCore/src/policies/agent.jl/), we can see that the default Agent calls are

 ```julia
 function RLBase.plan!(agent::Agent{P,T,C}, env::AbstractEnv) where {P,T,C}
@@ -63,18 +67,58 @@ function Base.push!(agent::Agent{P,T,C}, ::PostActStage, env::E) where {P,T,C,E<
 end
 ```

-The default behavior at other stages is a no-op. The first function, `RLBase.plan!(agent::Agent, env::AbstractEnv)`, is called at the `action = RLBase.plan!(policy, env)` line. It gets an action from the policy (since you implemented the `RLBase.plan!(your_new_policy, env)` function), then it pushes its `cache` and the action to the trajectory of the agent. Finally, it empties the cache and returns the action (which is immediately applied to env after). At the `PreActStage()` and the `PostActStage()`, the agent simply records the current state of the environment, the returned reward and the terminal state signal to its cache (to be pushed to the trajectory by the first function).
+The first function, `RLBase.plan!(agent::Agent, env::AbstractEnv)`, is called at the `action = RLBase.plan!(policy, env)` line. It gets an action from the policy (since you implemented the `RLBase.plan!(your_new_policy, env)` function), then it pushes its `cache` and the action to the trajectory of the agent. Finally, it empties the cache and returns the action (which is then applied to the environment). At the `PreActStage()` and the `PostActStage()`, the agent simply records the current state of the environment, the returned reward, and the terminal signal to its cache (to be pushed to the trajectory by the first function). Notice that the `PreActStage` push and plan are also called at the end of the episode loop to store the very last state and action in the trajectory.

-If you need a different behavior at some stages, then you can overload the `Base.push!(Agent{<:YourPolicyType}, [stage,] env)` or `Base.push!(Agent{<:Any, <: YourTrajectoryType}, [stage,] env)`, depending on whether you have a custom policy or just a custom trajectory. For example, many algorithms (such as PPO) need to store an additional trace of the logpdf of the sampled actions and thus overload the function at the `PreActStage()`.
+If you need a different behavior at some stages, you can overload `Base.push!(Agent{<:YourPolicyType}, [stage,] env)` or `Base.push!(Agent{<:Any, <:YourTrajectoryType}, [stage,] env)` (or the corresponding `RLBase.plan!` method), depending on whether you have a custom policy or just a custom trajectory. For example, many algorithms (such as PPO) need to store an additional trace of the logpdf of the sampled actions and thus overload the function at the `PreActStage()`.
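+
+For instance, a sketch of such an overload might look like the following. `MyPPOLikePolicy`, its `last_log_prob` field, and the `:action_log_prob` trace are hypothetical names used only for illustration, and the cache call is assumed to mirror the default method in `agent.jl`:
+
+```julia
+using ReinforcementLearningBase, ReinforcementLearningCore
+
+# Hypothetical policy type that remembers the log-probability of its last action.
+struct MyPPOLikePolicy <: AbstractPolicy
+    last_log_prob::Base.RefValue{Float32}
+end
+
+function Base.push!(agent::Agent{<:MyPPOLikePolicy}, ::PreActStage, env::AbstractEnv)
+    # Same bookkeeping as the default method: remember the current state.
+    push!(agent.cache, state(env))
+    # Additionally record a custom quantity; this assumes the wrapped trajectory was
+    # built with an extra `:action_log_prob` trace to receive it.
+    push!(agent.trajectory, (action_log_prob = agent.policy.last_log_prob[],))
+end
+```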

 ## Updating the policy

-Finally, you need to implement the learning function by implementing `RLBase.optimise!(p::YourPolicyType, batch::NamedTuple)` (see that it is called by `optimise!(agent)` then `RLBase.optimise!(p::YourPolicyType, b::Trajectory)`).
-In principle you can do the update at other stages by overload the `push!(agent::Agent)` but this is not recommended because the trajectory may not be consistent and samples could be incorrect. If you choose to do it, make sure to know what you are doing.
+Finally, you need to implement the learning function `RLBase.optimise!(::YourPolicyType, ::Stage, ::Trajectory)`. By default this does nothing at any stage. Overload it for the stage at which you wish to optimise (most often `PreActStage` or `PostEpisodeStage`). This function should loop over the trajectory to sample batches; inside the loop, do whatever is required. For example:
+
+```julia
+function RLBase.optimise!(p::YourPolicyType, ::PostEpisodeStage, traj::Trajectory)
+    for batch in traj
+        optimise!(p, batch)
+    end
+end
+```
+
+where `optimise!(p, batch)` is a function that will typically compute a gradient and update a neural network, or update a tabular policy. What goes inside the loop is up to you. This is further discussed in the next section on `Trajectory`s.

 ## ReinforcementLearningTrajectories

-Trajectories are handled in a stand-alone package called [ReinforcementLearningTrajectories](https://github.com/JuliaReinforcementLearning/ReinforcementLearningTrajectories.jl). Refer to its documentation (in progress) to learn how to use it.
+Trajectories are handled in a stand-alone package called [ReinforcementLearningTrajectories](https://github.com/JuliaReinforcementLearning/ReinforcementLearningTrajectories.jl). However, it is central to the implementation of your algorithm, as it controls many aspects of it, such as the batch size, the sampling frequency, and the replay buffer length.
+A `Trajectory` is composed of three elements: a `container`, a `controller`, and a `sampler`.
+
+### Container
+
+The container is typically an `AbstractTraces`, an object that stores a set of `Trace`s in a structured manner. You can either define your own (and contribute it to the package if it is likely to be usable for other algorithms), or use a predefined one if it exists.
+
+The most common `AbstractTraces` object is the `CircularArraySARTTraces`. This is a fixed-length container that stores the following traces: `:state` (S), `:action` (A), `:reward` (R), `:terminal` (T), which together are aliased to `SART = (:state, :action, :reward, :terminal)`. Let us see how it is constructed in this simplified version, as an example of how to build a custom set of traces.
+
+```julia
+function CircularArraySARTTraces(capacity, state_size, state_eltype, action_size, action_eltype, reward_eltype)
+    MultiplexTraces{SS′}(CircularArrayBuffer{state_eltype}(state_size..., capacity + 1)) +
+    MultiplexTraces{AA′}(CircularArrayBuffer{action_eltype}(action_size..., capacity + 1)) +
+    Traces(
+        reward=CircularArrayBuffer{reward_eltype}(1, capacity),
+        terminal=CircularArrayBuffer{Bool}(1, capacity),
+    )
+end
+```
+
+We can see it is composed (with the `+` operator) of two `MultiplexTraces` and a `Traces`.
+- `MultiplexTraces` is a special trace that stores two names in one container. Here, the two names of the first one are `SS′ = (:state, :next_state)`. When `:next_state` is sampled at index `i`, it returns the state stored at `i+1`. This way, states and next states are managed together seamlessly (note, however, that their buffers need a capacity of `capacity + 1`).
+- `Traces` is for simpler traces: define a name for each (here `reward` and `terminal`) and assign each one a container.
+
+The containers used here are `CircularArrayBuffer`s. These are preallocated arrays that, once full, overwrite the oldest element in storage, as if they were circular. A `CircularArrayBuffer` takes as arguments the size of each of its dimensions, where the last one is the capacity of the buffer. For example, if a state is a 256 x 256 image, `state_size` would be the tuple `(256, 256)`; for vector states use `(256,)`, and for scalars `1` or `()`.
+
+### Controller
+
+ReinforcementLearningTrajectories' design aims to eventually support distributed experience collection, hence the somewhat involved design of trajectories and the presence of a controller. The controller is an object that decides when the trajectory is ready to be sampled. Let us look at the only controller implemented so far: `InsertSampleRatioController(ratio, threshold)`. Despite its name, it is quite simple: it records the number of insertions into the trajectory (`ins`) and the number of batches sampled (`sam`); if `sam/ins > ratio`, the controller stops the batch-sampling loop. For example, a ratio of `1/1000` means that one batch will be sampled for every 1000 insertions into the trajectory. `threshold` is simply the minimum number of insertions required before the controller starts sampling.
+
+### Sampler
+
+The sampler is the object that fetches data from your trajectory to create the `batch` in the `optimise!` for-loop. The simplest one is `BatchSampler{names}(batchsize, rng)`. `batchsize` is the number of elements to sample, and `rng` is an optional argument that you may set to a custom rng for reproducibility. `names` is the set of traces the sampler must query. For example, a `BatchSampler{(:state, :action, :next_state)}(32)` will sample a named tuple `(state = [32 states], action = [32 actions], next_state = [32 states offset by one with respect to those in state])`.
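+
+Putting the three pieces together, a trajectory for a hypothetical environment with 4-dimensional `Float32` state vectors and integer actions might be assembled as sketched below. The keyword spellings of `CircularArraySARTTraces` and `Trajectory` follow the experiment scripts shipped with the package, but check ReinforcementLearningTrajectories for the exact constructors:
+
+```julia
+using ReinforcementLearningTrajectories
+
+trajectory = Trajectory(
+    container = CircularArraySARTTraces(
+        capacity = 10_000,          # replay buffer length
+        state = Float32 => (4,),    # eltype and size of each state
+        action = Int => (),         # scalar actions
+    ),
+    sampler = BatchSampler{(:state, :action, :reward, :terminal, :next_state)}(32),
+    controller = InsertSampleRatioController(1 / 10, 100),  # (ratio, threshold)
+)
+```
+
+Wrapped in an `Agent(policy = YourPolicyType(...), trajectory = trajectory)`, this samples batches of 32 transitions, at most one batch per 10 insertions, and only after 100 transitions have been inserted.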

 ## Using resources from RLCore
diff --git a/src/ReinforcementLearningBase/src/interface.jl b/src/ReinforcementLearningBase/src/interface.jl
index b3e15d370..47f57d781 100644
--- a/src/ReinforcementLearningBase/src/interface.jl
+++ b/src/ReinforcementLearningBase/src/interface.jl
@@ -39,7 +39,7 @@ action is determined by a `plan!` which takes an environment and policy and retu
 @api plan!(π::AbstractPolicy, env)

 """
-    optimise!(π::AbstractPolicy, experience)
+    RLBase.optimise!(π::AbstractPolicy, experience)

 Optimise the policy `π` with online/offline experience or parameters.
""" diff --git a/src/ReinforcementLearningCore/src/policies/agent/base.jl b/src/ReinforcementLearningCore/src/policies/agent/base.jl index dc7798dc2..7654de4d6 100644 --- a/src/ReinforcementLearningCore/src/policies/agent/base.jl +++ b/src/ReinforcementLearningCore/src/policies/agent/base.jl @@ -39,18 +39,15 @@ end Agent(;policy, trajectory, cache = SRT()) = Agent(policy, trajectory, cache) -RLBase.optimise!(agent::Agent, stage::S) where {S<:AbstractStage} = optimise!(TrajectoryStyle(agent.trajectory), agent, stage) +RLBase.optimise!(agent::Agent, stage::S) where {S<:AbstractStage} =RLBase.optimise!(TrajectoryStyle(agent.trajectory), agent, stage) RLBase.optimise!(::SyncTrajectoryStyle, agent::Agent, stage::S) where {S<:AbstractStage} = - optimise!(agent.policy, stage, agent.trajectory) + RLBase.optimise!(agent.policy, stage, agent.trajectory) # already spawn a task to optimise inner policy when initializing the agent RLBase.optimise!(::AsyncTrajectoryStyle, agent::Agent, stage::S) where {S<:AbstractStage} = nothing -function RLBase.optimise!(policy::AbstractPolicy, stage::S, trajectory::Trajectory) where {S<:AbstractStage} - for batch in trajectory - optimise!(policy, stage, batch) - end -end +#by default, optimise does nothing at all stage +function RLBase.optimise!(policy::AbstractPolicy, stage::AbstractStage, trajectory::Trajectory) end @functor Agent (policy,) diff --git a/src/ReinforcementLearningCore/src/policies/q_based_policy.jl b/src/ReinforcementLearningCore/src/policies/q_based_policy.jl index cb3376038..2c19f9b92 100644 --- a/src/ReinforcementLearningCore/src/policies/q_based_policy.jl +++ b/src/ReinforcementLearningCore/src/policies/q_based_policy.jl @@ -37,4 +37,8 @@ end RLBase.prob(p::QBasedPolicy{L,Ex}, env::AbstractEnv) where {L<:AbstractLearner,Ex<:AbstractExplorer} = prob(p.explorer, forward(p.learner, env), legal_action_space_mask(env)) -RLBase.optimise!(p::QBasedPolicy{L,Ex}, stage::S, x::NamedTuple) where {L<:AbstractLearner,Ex<:AbstractExplorer, S<:AbstractStage} = optimise!(p.learner, x) +function RLBase.optimise!(p::QBasedPolicy{L,Ex}, ::PostActStage, trajectory::Trajectory) where {L<:AbstractLearner,Ex<:AbstractExplorer} + for batch in trajectory + RLBase.optimise!(p.learner, batch) + end +end \ No newline at end of file diff --git a/src/ReinforcementLearningExperiments/deps/experiments/experiments/DQN/JuliaRL_BasicDQN_CartPole.jl b/src/ReinforcementLearningExperiments/deps/experiments/experiments/DQN/JuliaRL_BasicDQN_CartPole.jl index afe28701d..5a43c8c01 100644 --- a/src/ReinforcementLearningExperiments/deps/experiments/experiments/DQN/JuliaRL_BasicDQN_CartPole.jl +++ b/src/ReinforcementLearningExperiments/deps/experiments/experiments/DQN/JuliaRL_BasicDQN_CartPole.jl @@ -7,7 +7,7 @@ # --- #+ tangle=true -using ReinforcementLearningCore, ReinforcementLearningBase, ReinforcementLearningZoo, ReinforcementLearningZoo +using ReinforcementLearningCore, ReinforcementLearningBase, ReinforcementLearningZoo using ReinforcementLearningEnvironments using Flux using Flux: glorot_uniform @@ -33,7 +33,7 @@ function RLCore.Experiment( Dense(ns, 128, relu; init=glorot_uniform(rng)), Dense(128, 128, relu; init=glorot_uniform(rng)), Dense(128, na; init=glorot_uniform(rng)), - ) |> gpu, + ), optimiser=Adam(), ), loss_func=huber_loss, diff --git a/src/ReinforcementLearningExperiments/deps/experiments/experiments/DQN/JuliaRL_BasicDQN_MountainCar.jl b/src/ReinforcementLearningExperiments/deps/experiments/experiments/DQN/JuliaRL_BasicDQN_MountainCar.jl index 234a7506c..f3f1046e8 
100644 --- a/src/ReinforcementLearningExperiments/deps/experiments/experiments/DQN/JuliaRL_BasicDQN_MountainCar.jl +++ b/src/ReinforcementLearningExperiments/deps/experiments/experiments/DQN/JuliaRL_BasicDQN_MountainCar.jl @@ -25,12 +25,12 @@ function RLCore.Experiment( agent = Agent( policy = QBasedPolicy( learner = BasicDQNLearner( - approximator = NeuralNetworkApproximator( + approximator = Approximator( model = Chain( Dense(ns, 64, relu; init = glorot_uniform(rng)), Dense(64, 64, relu; init = glorot_uniform(rng)), Dense(64, na; init = glorot_uniform(rng)), - ) |> gpu, + ), optimizer = Adam(), ), batch_size = 32, diff --git a/src/ReinforcementLearningExperiments/deps/experiments/experiments/DQN/JuliaRL_BasicDQN_PendulumDiscrete.jl b/src/ReinforcementLearningExperiments/deps/experiments/experiments/DQN/JuliaRL_BasicDQN_PendulumDiscrete.jl index ee3ceaa56..ecd2c24c4 100644 --- a/src/ReinforcementLearningExperiments/deps/experiments/experiments/DQN/JuliaRL_BasicDQN_PendulumDiscrete.jl +++ b/src/ReinforcementLearningExperiments/deps/experiments/experiments/DQN/JuliaRL_BasicDQN_PendulumDiscrete.jl @@ -30,7 +30,7 @@ function RLCore.Experiment( Dense(ns, 64, relu; init = glorot_uniform(rng)), Dense(64, 64, relu; init = glorot_uniform(rng)), Dense(64, na; init = glorot_uniform(rng)), - ) |> gpu, + ), optimizer = Adam(), ), batch_size = 32, diff --git a/src/ReinforcementLearningExperiments/deps/experiments/experiments/DQN/JuliaRL_BasicDQN_SingleRoomUndirected.jl b/src/ReinforcementLearningExperiments/deps/experiments/experiments/DQN/JuliaRL_BasicDQN_SingleRoomUndirected.jl index e6c7c2877..ca17bc056 100644 --- a/src/ReinforcementLearningExperiments/deps/experiments/experiments/DQN/JuliaRL_BasicDQN_SingleRoomUndirected.jl +++ b/src/ReinforcementLearningExperiments/deps/experiments/experiments/DQN/JuliaRL_BasicDQN_SingleRoomUndirected.jl @@ -37,7 +37,7 @@ function RLCore.Experiment( Dense(ns, 128, relu; init=glorot_uniform(rng)), Dense(128, 128, relu; init=glorot_uniform(rng)), Dense(128, na; init=glorot_uniform(rng)), - ) |> gpu, + ), optimizer=Adam(), ), batch_size=32, diff --git a/src/ReinforcementLearningExperiments/deps/experiments/experiments/DQN/JuliaRL_DQN_CartPole.jl b/src/ReinforcementLearningExperiments/deps/experiments/experiments/DQN/JuliaRL_DQN_CartPole.jl index 974f852b9..3f481eff4 100644 --- a/src/ReinforcementLearningExperiments/deps/experiments/experiments/DQN/JuliaRL_DQN_CartPole.jl +++ b/src/ReinforcementLearningExperiments/deps/experiments/experiments/DQN/JuliaRL_DQN_CartPole.jl @@ -41,7 +41,7 @@ function RLCore.Experiment( sync_freq=100 ), optimiser=Adam(), - ) |> gpu, + ), n=n, γ=γ, is_enable_double_DQN=is_enable_double_DQN, diff --git a/src/ReinforcementLearningExperiments/deps/experiments/experiments/DQN/JuliaRL_DQN_MountainCar.jl b/src/ReinforcementLearningExperiments/deps/experiments/experiments/DQN/JuliaRL_DQN_MountainCar.jl index 5b0f181e7..1e2e752b4 100644 --- a/src/ReinforcementLearningExperiments/deps/experiments/experiments/DQN/JuliaRL_DQN_MountainCar.jl +++ b/src/ReinforcementLearningExperiments/deps/experiments/experiments/DQN/JuliaRL_DQN_MountainCar.jl @@ -26,20 +26,20 @@ function RLCore.Experiment( agent = Agent( policy = QBasedPolicy( learner = DQNLearner( - approximator = NeuralNetworkApproximator( + approximator = Approximator( model = Chain( Dense(ns, 64, relu; init = glorot_uniform(rng)), Dense(64, 64, relu; init = glorot_uniform(rng)), Dense(64, na; init = glorot_uniform(rng)), - ) |> gpu, + ), optimizer = Adam(), ), - target_approximator = 
NeuralNetworkApproximator( + target_approximator = Approximator( model = Chain( Dense(ns, 64, relu; init = glorot_uniform(rng)), Dense(64, 64, relu; init = glorot_uniform(rng)), Dense(64, na; init = glorot_uniform(rng)), - ) |> gpu, + ), optimizer = Adam(), ), loss_func = huber_loss, diff --git a/src/ReinforcementLearningExperiments/deps/experiments/experiments/DQN/JuliaRL_IQN_CartPole.jl b/src/ReinforcementLearningExperiments/deps/experiments/experiments/DQN/JuliaRL_IQN_CartPole.jl index a1a36d318..0bc2832ff 100644 --- a/src/ReinforcementLearningExperiments/deps/experiments/experiments/DQN/JuliaRL_IQN_CartPole.jl +++ b/src/ReinforcementLearningExperiments/deps/experiments/experiments/DQN/JuliaRL_IQN_CartPole.jl @@ -34,7 +34,7 @@ function RLCore.Experiment( ψ=Dense(ns, n_hidden, relu; init=init), ϕ=Dense(Nₑₘ, n_hidden, relu; init=init), header=Dense(n_hidden, na; init=init), - ) |> gpu + ) agent = Agent( policy=QBasedPolicy( diff --git a/src/ReinforcementLearningExperiments/deps/experiments/experiments/Offline/JuliaRL_BC_CartPole.jl b/src/ReinforcementLearningExperiments/deps/experiments/experiments/Offline/JuliaRL_BC_CartPole.jl index debc20fa4..9362db4c1 100644 --- a/src/ReinforcementLearningExperiments/deps/experiments/experiments/Offline/JuliaRL_BC_CartPole.jl +++ b/src/ReinforcementLearningExperiments/deps/experiments/experiments/Offline/JuliaRL_BC_CartPole.jl @@ -82,7 +82,7 @@ function RLCore.Experiment( for i in 1:300 _, batch = s(hook.records) - optimise!(bc, batch) + RLBase.optimise!(bc, batch) end hook = TotalRewardPerEpisode() diff --git a/src/ReinforcementLearningZoo/src/algorithms/dqns/basic_dqn.jl b/src/ReinforcementLearningZoo/src/algorithms/dqns/basic_dqn.jl index a77679c8b..49e9e8f5f 100644 --- a/src/ReinforcementLearningZoo/src/algorithms/dqns/basic_dqn.jl +++ b/src/ReinforcementLearningZoo/src/algorithms/dqns/basic_dqn.jl @@ -23,9 +23,9 @@ own customized algorithm. - `loss_func=huber_loss`: the loss function to use. - `γ::Float32=0.99f0`: discount rate. 
""" -Base.@kwdef mutable struct BasicDQNLearner{Q} <: AbstractLearner +Base.@kwdef mutable struct BasicDQNLearner{Q, F} <: AbstractLearner approximator::Q - loss_func::Any = huber_loss + loss_func::F = huber_loss γ::Float32 = 0.99f0 # for debugging loss::Float32 = 0.0f0 @@ -37,13 +37,13 @@ RLCore.forward(L::BasicDQNLearner, s::AbstractArray) = RLCore.forward(L.approxim function RLCore.optimise!( learner::BasicDQNLearner, - batch::NamedTuple{SS′ART}, + batch::NamedTuple ) Q = learner.approximator γ = learner.γ loss_func = learner.loss_func - + s, s′, a, r, t = send_to_device(device(Q), batch) a = CartesianIndex.(a, 1:length(a)) @@ -58,5 +58,5 @@ function RLCore.optimise!( loss end - optimise!(Q, gs) + RLBase.optimise!(Q, gs) end diff --git a/src/ReinforcementLearningZoo/src/algorithms/dqns/dqn.jl b/src/ReinforcementLearningZoo/src/algorithms/dqns/dqn.jl index fd3757066..cdae7ab6f 100644 --- a/src/ReinforcementLearningZoo/src/algorithms/dqns/dqn.jl +++ b/src/ReinforcementLearningZoo/src/algorithms/dqns/dqn.jl @@ -1,15 +1,15 @@ export DQNLearner -using Random: AbstractRNG, GLOBAL_RNG +using Random: AbstractRNG using Functors: @functor -Base.@kwdef mutable struct DQNLearner{A<:Approximator{<:TwinNetwork}} <: AbstractLearner +Base.@kwdef mutable struct DQNLearner{A<:Approximator{<:TwinNetwork}, F, R} <: AbstractLearner approximator::A - loss_func::Any + loss_func::F n::Int = 1 γ::Float32 = 0.99f0 is_enable_double_DQN::Bool = true - rng::AbstractRNG = GLOBAL_RNG + rng::R = Random.default_rng() # for logging loss::Float32 = 0.0f0 end @@ -18,7 +18,7 @@ RLCore.forward(L::DQNLearner, s::A) where {A<:AbstractArray} = RLCore.forward(L @functor DQNLearner (approximator,) -function RLBase.optimise!(learner::DQNLearner, batch::Union{NamedTuple{SS′ART},NamedTuple{SS′L′ART}}) +function RLBase.optimise!(learner::DQNLearner, batch::NamedTuple) A = learner.approximator Q = A.model.source Qₜ = A.model.target @@ -48,5 +48,5 @@ function RLBase.optimise!(learner::DQNLearner, batch::Union{NamedTuple{SS′ART} loss end - optimise!(A, gs) + RLBase.optimise!(A, gs) end diff --git a/src/ReinforcementLearningZoo/src/algorithms/dqns/iqn.jl b/src/ReinforcementLearningZoo/src/algorithms/dqns/iqn.jl index 485a2751e..33a6302c8 100644 --- a/src/ReinforcementLearningZoo/src/algorithms/dqns/iqn.jl +++ b/src/ReinforcementLearningZoo/src/algorithms/dqns/iqn.jl @@ -2,7 +2,7 @@ export IQNLearner, ImplicitQuantileNet using Functors: @functor using Flux: params, unsqueeze, gradient -using Random: AbstractRNG, GLOBAL_RNG +import Random using StatsBase: mean using ChainRulesCore: ignore_derivatives @@ -36,7 +36,7 @@ function (net::ImplicitQuantileNet)(s, emb) reshape(quantiles, :, size(merged, 2), size(merged, 3)) # (n_action, N, batch_size) end -Base.@kwdef mutable struct IQNLearner{A<:Approximator{<:TwinNetwork}} <: AbstractLearner +Base.@kwdef mutable struct IQNLearner{A<:Approximator{<:TwinNetwork}, R1, R2} <: AbstractLearner approximator::A γ::Float32 = 0.99f0 κ::Float32 = 1.0f0 @@ -44,8 +44,8 @@ Base.@kwdef mutable struct IQNLearner{A<:Approximator{<:TwinNetwork}} <: Abstrac N′::Int = 32 Nₑₘ::Int = 64 K::Int = 32 - rng::AbstractRNG = GLOBAL_RNG - device_rng::AbstractRNG = rng + rng::R1 = default_rng() + device_rng::R2 = rng # for logging loss::Float32 = 0.0f0 end @@ -77,7 +77,7 @@ function RLBase.optimise!(learner::IQNLearner, batch::NamedTuple) N′ = learner.N′ Nₑₘ = learner.Nₑₘ κ = learner.κ - + s, s′, a, r, t = map(x -> batch[x], SS′ART) batch_size = length(t) τ′ = rand(learner.device_rng, Float32, N′, batch_size) # TODO: 
support β distribution @@ -126,5 +126,5 @@ function RLBase.optimise!(learner::IQNLearner, batch::NamedTuple) loss end - optimise!(A, gs) + RLBase.optimise!(A, gs) end diff --git a/src/ReinforcementLearningZoo/src/algorithms/dqns/prioritized_dqn.jl b/src/ReinforcementLearningZoo/src/algorithms/dqns/prioritized_dqn.jl index 5a5f4a7b8..407808a1a 100644 --- a/src/ReinforcementLearningZoo/src/algorithms/dqns/prioritized_dqn.jl +++ b/src/ReinforcementLearningZoo/src/algorithms/dqns/prioritized_dqn.jl @@ -1,19 +1,19 @@ export PrioritizedDQNLearner -using Random: AbstractRNG, GLOBAL_RNG +import Random using Functors: @functor using LinearAlgebra: dot using Flux using Flux: gradient, params -Base.@kwdef mutable struct PrioritizedDQNLearner{A<:Approximator{<:TwinNetwork}} <: AbstractLearner +Base.@kwdef mutable struct PrioritizedDQNLearner{A<:Approximator{<:TwinNetwork}, R} <: AbstractLearner approximator::A loss_func::Any # !!! here the loss func must return the loss before reducing over the batch dimension n::Int = 1 γ::Float32 = 0.99f0 β_priority::Float32 = 0.5f0 is_enable_double_DQN::Bool = true - rng::AbstractRNG = GLOBAL_RNG + rng::R = Random.default_rng() # for logging loss::Float32 = 0.0f0 end @@ -67,13 +67,13 @@ function RLBase.optimise!( loss end - optimise!(A, gs) - k => p′ + RLBase.optimise!(A, gs) + k, p′ end -function RLBase.optimise!(policy::QBasedPolicy{<:PrioritizedDQNLearner}, ::PostActStage, trajectory::Trajectory) +function RLBase.optimise!(policy::QBasedPolicy{L, Ex}, ::PostActStage, trajectory::Trajectory) where {L<:PrioritizedDQNLearner, Ex<:AbstractExplorer} for batch in trajectory - k, p = optimise!(policy, PostActStage(), batch) |> send_to_host + k, p = RLBase.optimise!(policy.learner, batch) |> send_to_host trajectory[:priority, k] = p end end diff --git a/src/ReinforcementLearningZoo/src/algorithms/dqns/qr_dqn.jl b/src/ReinforcementLearningZoo/src/algorithms/dqns/qr_dqn.jl index 4102c1c4a..03f8e893a 100644 --- a/src/ReinforcementLearningZoo/src/algorithms/dqns/qr_dqn.jl +++ b/src/ReinforcementLearningZoo/src/algorithms/dqns/qr_dqn.jl @@ -1,7 +1,7 @@ export QRDQNLearner, quantile_huber_loss using ChainRulesCore: ignore_derivatives -using Random: GLOBAL_RNG, AbstractRNG +import Random using StatsBase: mean using Functors: @functor using Flux @@ -22,12 +22,12 @@ function quantile_huber_loss(ŷ, y; κ=1.0f0) mean(sum(loss; dims=1)) end -Base.@kwdef mutable struct QRDQNLearner{A<:Approximator{<:TwinNetwork}} <: AbstractLearner +Base.@kwdef mutable struct QRDQNLearner{A<:Approximator{<:TwinNetwork}, F, R} <: AbstractLearner approximator::A n_quantile::Int - loss_func::Any = quantile_huber_loss + loss_func::F = quantile_huber_loss γ::Float32 = 0.99f0 - rng::AbstractRNG = GLOBAL_RNG + rng::R = Random.default_rng() # for recording loss::Float32 = 0.0f0 end @@ -66,5 +66,5 @@ function RLBase.optimise!(learner::QRDQNLearner, batch::NamedTuple) loss end - optimise!(A, gs) + RLBase.optimise!(A, gs) end diff --git a/src/ReinforcementLearningZoo/src/algorithms/dqns/rainbow.jl b/src/ReinforcementLearningZoo/src/algorithms/dqns/rainbow.jl index a10a86416..f3942d1c0 100644 --- a/src/ReinforcementLearningZoo/src/algorithms/dqns/rainbow.jl +++ b/src/ReinforcementLearningZoo/src/algorithms/dqns/rainbow.jl @@ -1,11 +1,11 @@ export RainbowLearner -using Random: AbstractRNG, GLOBAL_RNG +import Random using Flux: params, unsqueeze, softmax, gradient using Flux.Losses: logitcrossentropy using Functors: @functor -Base.@kwdef mutable struct RainbowLearner{A<:Approximator{<:TwinNetwork}} <: 
AbstractLearner +Base.@kwdef mutable struct RainbowLearner{A<:Approximator{<:TwinNetwork}, F, R} <: AbstractLearner approximator::A Vₘₐₓ::Float32 Vₘᵢₙ::Float32 @@ -17,8 +17,8 @@ Base.@kwdef mutable struct RainbowLearner{A<:Approximator{<:TwinNetwork}} <: Abs delta_z::Float32 = convert(Float32, support.step) default_priority::Float32 = 1.0f2 β_priority::Float32 = 0.5f0 - loss_func::Any = (ŷ, y) -> logitcrossentropy(ŷ, y; agg=identity) - rng::AbstractRNG = GLOBAL_RNG + loss_func::F = (ŷ, y) -> logitcrossentropy(ŷ, y; agg=identity) + rng::R = Random.default_rng() # for logging loss::Float32 = 0.0f0 end @@ -109,7 +109,7 @@ function RLBase.optimise!(learner::RainbowLearner, batch::NamedTuple) loss end - optimise!(A, gs) + RLBase.optimise!(A, gs) is_use_PER ? batch.key => updated_priorities : nothing end @@ -139,9 +139,9 @@ function project_distribution(supports, weights, target_support, delta_z, vmin, reshape(sum(projection, dims=1), n_atoms, batch_size) end -function RLBase.optimise!(policy::QBasedPolicy{<:RainbowLearner}, ::PostActStage, trajectory::Trajectory) +function RLBase.optimise!(policy::QBasedPolicy{L, Ex}, ::PostActStage, trajectory::Trajectory) where {L<:RainbowLearner, Ex<:AbstractExplorer} for batch in trajectory - res = optimise!(policy, PostActStage(), batch) |> send_to_host + res = RLBase.optimise!(policy, PostActStage(), batch) |> send_to_host if !isnothing(res) k, p = res trajectory[:priority, k] = p diff --git a/src/ReinforcementLearningZoo/src/algorithms/dqns/rem_dqn.jl b/src/ReinforcementLearningZoo/src/algorithms/dqns/rem_dqn.jl index bb75551e9..e79fafdd5 100644 --- a/src/ReinforcementLearningZoo/src/algorithms/dqns/rem_dqn.jl +++ b/src/ReinforcementLearningZoo/src/algorithms/dqns/rem_dqn.jl @@ -70,6 +70,6 @@ function RLBase.optimise!(learner::REMDQNLearner, batch::NamedTuple) loss end - optimise!(A, gs) + RLBase.optimise!(A, gs) end diff --git a/src/ReinforcementLearningZoo/src/algorithms/policy_gradient/mpo.jl b/src/ReinforcementLearningZoo/src/algorithms/policy_gradient/mpo.jl index 1f9aa7370..f878fb2f9 100644 --- a/src/ReinforcementLearningZoo/src/algorithms/policy_gradient/mpo.jl +++ b/src/ReinforcementLearningZoo/src/algorithms/policy_gradient/mpo.jl @@ -28,6 +28,11 @@ mutable struct MPOPolicy{P<:Approximator,Q<:Approximator,R} <: AbstractPolicy logs::Dict{Symbol, Vector{Float32}} end +function check(agent::Agent{<:MPOPolicy}, env) + error_string = "MPO requires a trajectory sampler that is a `MetaSampler` composed of two `MultiBatchSampler`. The first must be named `:actor` and sample `(:state,)`, the second must be named `:critic` and sample `SS′ART`" + @assert agent.trajectory.sampler isa MetaSampler{(:actor, :critic), Tuple{MultiBatchSampler{BatchSampler{(:state,)}}, MultiBatchSampler{BatchSampler{(:state, :next_state, :action, :reward, :terminal)}}}} error_string +end + """ MPOPolicy(; @@ -91,19 +96,16 @@ function RLBase.plan!(p::MPOPolicy, env; testmode = false) send_to_host(action) end + function RLBase.optimise!( p::MPOPolicy, - ::PostActStage, - batches::NamedTuple{ - (:actor, :critic), - <: Tuple{ - <: Vector{<: NamedTuple{(:state,)}}, - <: Vector{<: NamedTuple{SS′ART}} - } - } -) - update_critic!(p, batches[:critic]) - update_actor!(p, batches[:actor]) + ::PostActStage, + trajectory::Trajectory) + + for batches in trajectory + update_critic!(p, batches[:critic]) + update_actor!(p, batches[:actor]) + end end #Here we apply the TD3 Q network approach. The original MPO paper uses retrace. 
diff --git a/src/ReinforcementLearningZoo/src/algorithms/policy_gradient/trpo.jl b/src/ReinforcementLearningZoo/src/algorithms/policy_gradient/trpo.jl index 572abd962..a57599be2 100644 --- a/src/ReinforcementLearningZoo/src/algorithms/policy_gradient/trpo.jl +++ b/src/ReinforcementLearningZoo/src/algorithms/policy_gradient/trpo.jl @@ -35,7 +35,7 @@ end function Base.push!(p::Agent{<:TRPO}, ::PostEpisodeStage, env::AbstractEnv) p.trajectory.container[] = true - optimise!(p.policy, p.trajectory.container) + RLBase.optimise!(p.policy, p.trajectory.container) empty!(p.trajectory.container) end @@ -44,7 +44,7 @@ RLBase.optimise!(::Agent{<:TRPO}, ::PostActStage) = nothing function RLBase.optimise!(π::TRPO, ::PostActStage, episode::Episode) gain = discount_rewards(episode[:reward][:], π.γ) for inds in Iterators.partition(shuffle(π.rng, 1:length(episode)), π.batch_size) - optimise!(π, (state=episode[:state][inds], action=episode[:action][inds], gain=gain[inds])) + RLBase.optimise!(π, (state=episode[:state][inds], action=episode[:action][inds], gain=gain[inds])) end end @@ -66,7 +66,7 @@ function RLBase.optimise!(p::TRPO, ::PostActStage, batch::NamedTuple{(:state, :a end loss end - optimise!(B, gs) + RLBase.optimise!(B, gs) end # store logits as intermediate value diff --git a/src/ReinforcementLearningZoo/src/algorithms/policy_gradient/util.jl b/src/ReinforcementLearningZoo/src/algorithms/policy_gradient/util.jl index d126648ae..bcd9da67a 100644 --- a/src/ReinforcementLearningZoo/src/algorithms/policy_gradient/util.jl +++ b/src/ReinforcementLearningZoo/src/algorithms/policy_gradient/util.jl @@ -77,7 +77,7 @@ policy_gradient_estimate(policy::AbstractPolicy, states, actions, advantage) = function policy_gradient_estimate(::IsPolicyGradient, policy, states, actions, advantage) gs = gradient(params(policy.approximator)) do - action_logits = action_distribution(policy.dist, policy.approximator(states)) + action_logits = action_distribution(policy.dist, RLCore.forward(policy.approximator,states)) total_loss = logpdf.(action_logits, actions) .* advantage loss = -mean(total_loss) loss diff --git a/src/ReinforcementLearningZoo/src/algorithms/policy_gradient/vpg.jl b/src/ReinforcementLearningZoo/src/algorithms/policy_gradient/vpg.jl index d9a3ec0c7..eeb864c07 100644 --- a/src/ReinforcementLearningZoo/src/algorithms/policy_gradient/vpg.jl +++ b/src/ReinforcementLearningZoo/src/algorithms/policy_gradient/vpg.jl @@ -1,6 +1,6 @@ export VPG -using Random: GLOBAL_RNG, shuffle +using Random: Random, shuffle using Distributions: ContinuousDistribution, DiscreteDistribution, logpdf using Functors: @functor using Flux: params, softmax, gradient, logsoftmax @@ -10,7 +10,7 @@ using ChainRulesCore: ignore_derivatives """ Vanilla Policy Gradient """ -Base.@kwdef struct VPG{A,B,D} <: AbstractPolicy +Base.@kwdef struct VPG{A,B,D, R} <: AbstractPolicy "For discrete actions, logits before softmax is expected. 
For continuous actions, a `Tuple` of arguments are expected to initialize `dist`" approximator::A baseline::B = nothing @@ -19,7 +19,7 @@ Base.@kwdef struct VPG{A,B,D} <: AbstractPolicy "discount ratio" γ::Float32 = 0.99f0 batch_size::Int = 1024 - rng::AbstractRNG = GLOBAL_RNG + rng::R = Random.default_rng() end IsPolicyGradient(::Type{<:VPG}) = IsPolicyGradient() @@ -30,22 +30,22 @@ function RLBase.plan!(π::VPG, env::AbstractEnv) rand(π.rng, action_distribution(π.dist, res)[1]) end -function update!(p::Agent{<:VPG}, ::PostEpisodeStage, env::AbstractEnv) - p.trajectory.container[] = true - optimise!(p.policy, p.trajectory.container) - empty!(p.trajectory.container) +function RLBase.optimise!(p::VPG, ::PostEpisodeStage, trajectory::Trajectory) + trajectory.container[] = true + for batch in trajectory + RLBase.optimise!(p, batch) + end + empty!(trajectory.container) end -RLBase.optimise!(::Agent{<:VPG}, ::PostActStage) = nothing - -function RLBase.optimise!(π::VPG, ::PostActStage, episode::Episode) +function RLBase.optimise!(π::VPG, episode::Episode) gain = discount_rewards(episode[:reward][:], π.γ) for inds in Iterators.partition(shuffle(π.rng, 1:length(episode)), π.batch_size) - optimise!(π, PostActStage(), (state=episode[:state][inds], action=episode[:action][inds], gain=gain[inds])) + RLBase.optimise!(π, (state=episode[:state][inds], action=episode[:action][inds], gain=gain[inds])) end end -function RLBase.optimise!(p::VPG, ::PostActStage, batch::NamedTuple{(:state, :action, :gain)}) +function RLBase.optimise!(p::VPG, batch::NamedTuple{(:state, :action, :gain)}) A = p.approximator B = p.baseline s, a, g = map(Array, batch) # !!! FIXME @@ -56,16 +56,16 @@ function RLBase.optimise!(p::VPG, ::PostActStage, batch::NamedTuple{(:state, :ac loss = 0 else gs = gradient(params(B)) do - δ = g - vec(B(s)) + δ = g - vec(RLCore.forward(B, s)) loss = mean(δ .^ 2) ignore_derivatives() do # @info "VPG/baseline" loss = loss δ end loss end - optimise!(B, gs) + RLBase.optimise!(B, gs) end gs = policy_gradient_estimate(p, s, a, δ) - optimise!(A, gs) + RLBase.optimise!(A, gs) end