add MADDPG algorithm #444
**New file** — the `JuliaRL_MADDPG_KuhnPoker` experiment (122 added lines):

```julia
# ---
# title: JuliaRL_MADDPG_KuhnPoker
# cover: assets/JuliaRL_MADDPG_KuhnPoker.png
# description: MADDPG applied to KuhnPoker
# date: 2021-08-09
# author: "[Peter Chen](https://github.com/peterchen96)"
# ---

#+ tangle=true
using ReinforcementLearning
using StableRNGs
using Flux
using IntervalSets

mutable struct ResultNEpisode <: AbstractHook
    eval_freq::Int
    episode_counter::Int
    episode::Vector{Int}
    results::Vector{Float64}
end

function (hook::ResultNEpisode)(::PostEpisodeStage, policy, env)
    hook.episode_counter += 1
    if hook.episode_counter % hook.eval_freq == 0
        push!(hook.episode, hook.episode_counter)
        push!(hook.results, reward(env, 1))
    end
end

function RL.Experiment(
    ::Val{:JuliaRL},
    ::Val{:MADDPG},
    ::Val{:KuhnPoker},
    ::Nothing;
    seed=123,
)
    rng = StableRNG(seed)
    env = KuhnPokerEnv()
    wrapped_env = ActionTransformedEnv(
        StateTransformedEnv(
            env;
            state_mapping = s -> [findfirst(==(s), state_space(env))],
            state_space_mapping = ss -> [[findfirst(==(s), state_space(env))] for s in state_space(env)]
        ),
        ## add a dummy action for the other agent.
        action_mapping = x -> length(x) == 1 ? x : Int(x[current_player(env)] + 1),
    )
    ns, na = 1, 1
    n_players = 2

    init = glorot_uniform(rng)

    create_actor() = Chain(
        Dense(ns, 64, relu; init = init),
        Dense(64, 64, relu; init = init),
        Dense(64, na, tanh; init = init),
    )

    create_critic() = Chain(
        Dense(n_players * ns + n_players * na, 64, relu; init = init),
        Dense(64, 64, relu; init = init),
        Dense(64, 1; init = init),
    )

    agent = Agent(
        policy = DDPGPolicy(
            behavior_actor = NeuralNetworkApproximator(
                model = create_actor(),
                optimizer = ADAM(),
            ),
            behavior_critic = NeuralNetworkApproximator(
                model = create_critic(),
                optimizer = ADAM(),
            ),
            target_actor = NeuralNetworkApproximator(
                model = create_actor(),
                optimizer = ADAM(),
            ),
            target_critic = NeuralNetworkApproximator(
                model = create_critic(),
                optimizer = ADAM(),
            ),
            γ = 0.99f0,
            ρ = 0.995f0,
            na = na,
            start_steps = 1000,
            start_policy = RandomPolicy(-0.9..0.9; rng = rng),
            update_after = 1000,
            act_limit = 0.9,
            act_noise = 0.1,
            rng = rng,
        ),
        trajectory = CircularArraySARTTrajectory(
            capacity = 10000, # replay buffer capacity
            state = Vector{Int} => (ns, ),
            action = Float32 => (na, ),
        ),
    )

    agents = MADDPGManager(
        Dict((player, deepcopy(agent))
            for player in players(env) if player != chance_player(env)),
        128, # batch_size
        128, # update_freq
        0, # update_step
        rng
    )

    stop_condition = StopAfterEpisode(10_000, is_show_progress=!haskey(ENV, "CI"))
    hook = ResultNEpisode(1000, 0, [], [])
    Experiment(agents, wrapped_env, stop_condition, hook, "# run MADDPG on KuhnPokerEnv")
end

#+ tangle=false
using Plots
ex = E`JuliaRL_MADDPG_KuhnPoker`
run(ex)
scatter(ex.hook.episode, ex.hook.results, xaxis=:log, xlabel="episode", ylabel="reward of player 1")

savefig("assets/JuliaRL_MADDPG_KuhnPoker.png") #hide

# ![](assets/JuliaRL_MADDPG_KuhnPoker.png)
```
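A quick sanity check on the critic sizes used above: MADDPG's critic is centralized, so it receives every player's state and action concatenated together, which is why `create_critic` takes an input of width `n_players * ns + n_players * na`. A minimal standalone sketch of that shape bookkeeping (illustrative sizes only; this snippet is not part of the PR):

```julia
using Flux

ns, na, n_players, batch_size = 1, 1, 2, 128

# same layout as create_critic() above, without the custom initializer
critic = Chain(
    Dense(n_players * ns + n_players * na, 64, relu),
    Dense(64, 64, relu),
    Dense(64, 1),
)

# each player contributes an (ns, batch_size) state block and an (na, batch_size) action block;
# vcat stacks them along the feature dimension to match the critic's input width
s = vcat((rand(Float32, ns, batch_size) for _ in 1:n_players)...)
a = vcat((rand(Float32, na, batch_size) for _ in 1:n_players)...)

q = critic(vcat(s, a))
size(q)   # (1, 128): one Q-value estimate per sampled transition
```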
**New file** — `maddpg.jl`, the `MADDPGManager` implementation (171 added lines):

```julia
export MADDPGManager

"""
    MADDPGManager(; agents::Dict{<:Any, <:Agent}, args...)

Multi-Agent Deep Deterministic Policy Gradient (MADDPG) implemented in Julia. Currently it only
works for simultaneous games with a discrete action space. See the paper
https://arxiv.org/abs/1706.02275 for more details.

# Keyword arguments

- `agents::Dict{<:Any, <:Agent{<:DDPGPolicy, <:AbstractTrajectory}}`, each agent collects its own
  information. When updating the policy, each `critic` assembles all agents' trajectories to update
  its own network.
- `batch_size::Int`
- `update_freq::Int`
- `update_step::Int`, counts the number of update steps.
- `rng::AbstractRNG`.
"""
mutable struct MADDPGManager{P<:DDPGPolicy, T<:AbstractTrajectory} <: AbstractPolicy
    agents::Dict{<:Any, <:Agent{<:P, <:T}}
    batch_size::Int
    update_freq::Int
    update_step::Int
    rng::AbstractRNG
end

# for simultaneous games with a discrete action space
function (π::MADDPGManager)(env::AbstractEnv)
    while current_player(env) == chance_player(env)
        env |> legal_action_space |> rand |> env
    end
    Dict((player, ceil(agent.policy(env))) for (player, agent) in π.agents)
end

function (π::MADDPGManager)(::PreEpisodeStage, ::AbstractEnv)
    # drop the dummy transition that was pushed at the end of the previous episode
    for (_, agent) in π.agents
        if length(agent.trajectory) > 0
            pop!(agent.trajectory[:state])
            pop!(agent.trajectory[:action])
            if haskey(agent.trajectory, :legal_actions_mask)
                pop!(agent.trajectory[:legal_actions_mask])
            end
        end
    end
end

function (π::MADDPGManager)(::PreActStage, env::AbstractEnv, actions)
    # update each agent's trajectory
    for (player, agent) in π.agents
        push!(agent.trajectory[:state], state(env, player))
        push!(agent.trajectory[:action], actions[player])
        if haskey(agent.trajectory, :legal_actions_mask)
            lasm = legal_action_space_mask(env, player)
            push!(agent.trajectory[:legal_actions_mask], lasm)
        end
    end

    # update policy
    update!(π)
end

function (π::MADDPGManager)(::PostActStage, env::AbstractEnv)
    for (player, agent) in π.agents
        push!(agent.trajectory[:reward], reward(env, player))
        push!(agent.trajectory[:terminal], is_terminated(env))
    end
end

function (π::MADDPGManager)(::PostEpisodeStage, env::AbstractEnv)
    # collect the final state and a dummy action into each agent's trajectory
    for (player, agent) in π.agents
        push!(agent.trajectory[:state], state(env, player))
        push!(agent.trajectory[:action], rand(action_space(env)))
        if haskey(agent.trajectory, :legal_actions_mask)
            lasm = legal_action_space_mask(env, player)
            push!(agent.trajectory[:legal_actions_mask], lasm)
        end
    end

    # update policy
    update!(π)
end

# update policy
function RLBase.update!(π::MADDPGManager)
    π.update_step += 1
    π.update_step % π.update_freq == 0 || return

    for (_, agent) in π.agents
        length(agent.trajectory) > agent.policy.update_after || return
        length(agent.trajectory) > π.batch_size || return
    end

    # get training data, sampling the same indices from every agent's trajectory
    temp_player = rand(keys(π.agents))
    t = π.agents[temp_player].trajectory
    inds = rand(π.rng, 1:length(t), π.batch_size)
    batches = Dict((player, RLCore.fetch!(BatchSampler{SARTS}(π.batch_size), agent.trajectory, inds))
        for (player, agent) in π.agents)

    # get s, a, s′ for the critic
    s = vcat((batches[player][1] for (player, _) in π.agents)...)
    a = vcat((batches[player][2] for (player, _) in π.agents)...)
    s′ = vcat((batches[player][5] for (player, _) in π.agents)...)

    # for training behavior_actor
    mu_actions = vcat(
        ((
            batches[player][1] |> # get personal state information
            x -> send_to_device(device(agent.policy.behavior_actor), x) |>
            agent.policy.behavior_actor |> send_to_host
        ) for (player, agent) in π.agents)...
    )
    # for training behavior_critic
    new_actions = vcat(
        ((
            batches[player][5] |> # batches[player][5] holds the next-state information
            x -> send_to_device(device(agent.policy.target_actor), x) |>
            agent.policy.target_actor |> send_to_host
        ) for (player, agent) in π.agents)...
    )

    for (player, agent) in π.agents
        p = agent.policy
        A = p.behavior_actor
        C = p.behavior_critic
        Aₜ = p.target_actor
        Cₜ = p.target_critic

        γ = p.γ
        ρ = p.ρ

        _device(x) = send_to_device(device(A), x)

        # note: by default A, C, Aₜ and Cₜ are assumed to be on the same device
        s, a, s′ = _device((s, a, s′))
        mu_actions = _device(mu_actions)
        new_actions = _device(new_actions)
        r = _device(batches[player][:reward])
        t = _device(batches[player][:terminal])

        qₜ = Cₜ(vcat(s′, new_actions)) |> vec
        y = r .+ γ .* (1 .- t) .* qₜ

        gs1 = gradient(Flux.params(C)) do
            q = C(vcat(s, a)) |> vec
            loss = mean((y .- q) .^ 2)
            ignore() do
                p.critic_loss = loss
            end
            loss
        end

        update!(C, gs1)

        gs2 = gradient(Flux.params(A)) do
            loss = -mean(C(vcat(s, mu_actions)))
            ignore() do
                p.actor_loss = loss
            end
            loss
        end

        update!(A, gs2)

        # polyak averaging
        for (dest, src) in zip(Flux.params([Aₜ, Cₜ]), Flux.params([A, C]))
            dest .= ρ .* dest .+ (1 - ρ) .* src
        end

        s, a, s′ = send_to_host((s, a, s′))
        mu_actions = send_to_host(mu_actions)
        new_actions = send_to_host(new_actions)
    end
end
```

> **Member** (inline comment on the `RLCore.fetch!(BatchSampler{SARTS}(π.batch_size), ...)` call): The hardcoded …
**Changed file** — add the new implementation to the existing includes:

```diff
@@ -7,3 +7,4 @@ include("MAC.jl")
 include("ddpg.jl")
 include("td3.jl")
 include("sac.jl")
+include("maddpg.jl")
```
> **Review comment** (on the `MADDPGManager` stage methods): How about dispatching to the inner agent's corresponding methods? Like calling `agent(stage, env, action)` in the `for` loop.
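A rough sketch of what that suggestion could look like (a hypothetical refactor for illustration only, not code from this PR; it assumes the wrapped `Agent` keeps its existing `agent(stage, env, action)` method, as quoted in the comment):

```julia
# hypothetical: let each inner Agent do its own trajectory bookkeeping instead of
# pushing into agent.trajectory by hand, then run the MADDPG-specific update on top
function (π::MADDPGManager)(stage::PreActStage, env::AbstractEnv, actions)
    for (player, agent) in π.agents
        agent(stage, env, actions[player])  # dispatch to the inner agent's PreActStage method
    end
    update!(π)
end
```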
> **Review comment:** Can you take a look at the `NamedPolicy` and see whether we can reuse existing code as much as possible? See also the `MultiAgentManager`.