Merged
Changes from 4 commits
5 changes: 3 additions & 2 deletions .cspell/cspell.json
@@ -120,7 +120,8 @@
"Norouzi",
"gzopen",
"turbulences",
"Decompressor"
"Decompressor",
"MADDPG"
],
"ignoreWords": [],
"minWordLength": 5,
@@ -143,4 +144,4 @@
"\\{%.*%\\}", // liquid syntax
"/^\\s*```[\\s\\S]*?^\\s*```/gm" // Another attempt at markdown code blocks. https://github.com/streetsidesoftware/vscode-spell-checker/issues/202#issuecomment-377477473
]
}
}
122 changes: 122 additions & 0 deletions docs/experiments/experiments/Policy Gradient/JuliaRL_MADDPG_KuhnPoker.jl
@@ -0,0 +1,122 @@
# ---
# title: JuliaRL\_MADDPG\_KuhnPoker
# cover: assets/JuliaRL_MADDPG_KuhnPoker.png
# description: MADDPG applied to KuhnPoker
# date: 2021-08-09
# author: "[Peter Chen](https://github.com/peterchen96)"
# ---

#+ tangle=true
using ReinforcementLearning
using StableRNGs
using Flux
using IntervalSets

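# A simple hook that records the reward of player 1 every `eval_freq` episodes.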
mutable struct ResultNEpisode <: AbstractHook
eval_freq::Int
episode_counter::Int
episode::Vector{Int}
results::Vector{Float64}
end

function (hook::ResultNEpisode)(::PostEpisodeStage, policy, env)
hook.episode_counter += 1
if hook.episode_counter % hook.eval_freq == 0
push!(hook.episode, hook.episode_counter)
push!(hook.results, reward(env, 1))
end
end

function RL.Experiment(
::Val{:JuliaRL},
::Val{:MADDPG},
::Val{:KuhnPoker},
::Nothing;
seed=123,
)
rng = StableRNG(seed)
env = KuhnPokerEnv()
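    ## observe states as integer indices into the state space and map the continuous DDPG output back to a discrete Kuhn Poker action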
wrapped_env = ActionTransformedEnv(
StateTransformedEnv(
env;
state_mapping = s -> [findfirst(==(s), state_space(env))],
state_space_mapping = ss -> [[findfirst(==(s), state_space(env))] for s in state_space(env)]
),
## add a dummy action for the other agent.
action_mapping = x -> length(x) == 1 ? x : Int(x[current_player(env)] + 1),
)
ns, na = 1, 1
n_players = 2

init = glorot_uniform(rng)

create_actor() = Chain(
Dense(ns, 64, relu; init = init),
Dense(64, 64, relu; init = init),
Dense(64, na, tanh; init = init),
)

create_critic() = Chain(
Dense(n_players * ns + n_players * na, 64, relu; init = init),
Dense(64, 64, relu; init = init),
Dense(64, 1; init = init),
)

agent = Agent(
policy = DDPGPolicy(
behavior_actor = NeuralNetworkApproximator(
model = create_actor(),
optimizer = ADAM(),
),
behavior_critic = NeuralNetworkApproximator(
model = create_critic(),
optimizer = ADAM(),
),
target_actor = NeuralNetworkApproximator(
model = create_actor(),
optimizer = ADAM(),
),
target_critic = NeuralNetworkApproximator(
model = create_critic(),
optimizer = ADAM(),
),
γ = 0.99f0,
ρ = 0.995f0,
na = na,
start_steps = 1000,
start_policy = RandomPolicy(-0.9..0.9; rng = rng),
update_after = 1000,
act_limit = 0.9,
act_noise = 0.1,
rng = rng,
),
trajectory = CircularArraySARTTrajectory(
capacity = 10000, # replay buffer capacity
state = Vector{Int} => (ns, ),
action = Float32 => (na, ),
),
)

agents = MADDPGManager(
Dict((player, deepcopy(agent))
for player in players(env) if player != chance_player(env)),
128, # batch_size
128, # update_freq
0, # step_counter
rng
)

stop_condition = StopAfterEpisode(10_000, is_show_progress=!haskey(ENV, "CI"))
hook = ResultNEpisode(1000, 0, [], [])
Experiment(agents, wrapped_env, stop_condition, hook, "# run MADDPG on KuhnPokerEnv")
end

#+ tangle=false
using Plots
ex = E`JuliaRL_MADDPG_KuhnPoker`
run(ex)
scatter(ex.hook.episode, ex.hook.results, xaxis=:log, xlabel="episode", ylabel="reward of player 1")

savefig("assets/JuliaRL_MADDPG_KuhnPoker.png") #hide

# ![](assets/JuliaRL_MADDPG_KuhnPoker.png)
1 change: 1 addition & 0 deletions docs/experiments/experiments/Policy Gradient/config.json
@@ -4,6 +4,7 @@
"JuliaRL_A2C_CartPole.jl",
"JuliaRL_A2CGAE_CartPole.jl",
"JuliaRL_DDPG_Pendulum.jl",
"JuliaRL_MADDPG_KuhnPoker.jl",
"JuliaRL_MAC_CartPole.jl",
"JuliaRL_PPO_CartPole.jl",
"JuliaRL_PPO_Pendulum.jl",
171 changes: 171 additions & 0 deletions src/ReinforcementLearningZoo/src/algorithms/policy_gradient/maddpg.jl
@@ -0,0 +1,171 @@
export MADDPGManager

"""
    MADDPGManager(agents::Dict{<:Any, <:Agent}, batch_size::Int, update_freq::Int, update_step::Int, rng::AbstractRNG)
Multi-agent Deep Deterministic Policy Gradient (MADDPG) implemented in Julia. Currently it only works for simultaneous games whose action space is discrete.
See the paper https://arxiv.org/abs/1706.02275 for more details.

# Arguments
- `agents::Dict{<:Any, <:Agent{<:DDPGPolicy, <:AbstractTrajectory}}`, each agent collects its own information. When updating the policy, each `critic` assembles all agents' trajectories to update its own network.
- `batch_size::Int`
- `update_freq::Int`
- `update_step::Int`, counts the update steps.
- `rng::AbstractRNG`.
"""
mutable struct MADDPGManager{P<:DDPGPolicy, T<:AbstractTrajectory} <: AbstractPolicy
agents::Dict{<:Any, <:Agent{<:P, <:T}}
batch_size::Int
update_freq::Int
update_step::Int
rng::AbstractRNG
end
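# A minimal construction sketch (mirroring the JuliaRL_MADDPG_KuhnPoker experiment in this PR);
# the agents dict, batch size and update frequency below are example values, not defaults:
#
#     MADDPGManager(
#         Dict((player, deepcopy(agent)) for player in players(env) if player != chance_player(env)),
#         128, # batch_size
#         128, # update_freq
#         0,   # update_step
#         rng,
#     )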

# for simultaneous games with a discrete action space.
function (π::MADDPGManager)(env::AbstractEnv)
while current_player(env) == chance_player(env)
env |> legal_action_space |> rand |> env
end
Dict((player, ceil(agent.policy(env))) for (player, agent) in π.agents)
end

function (π::MADDPGManager)(::PreEpisodeStage, ::AbstractEnv)
for (_, agent) in π.agents
if length(agent.trajectory) > 0
pop!(agent.trajectory[:state])
pop!(agent.trajectory[:action])
if haskey(agent.trajectory, :legal_actions_mask)
pop!(agent.trajectory[:legal_actions_mask])
end
end
end
end

function (π::MADDPGManager)(::PreActStage, env::AbstractEnv, actions)
# update each agent's trajectory
for (player, agent) in π.agents
push!(agent.trajectory[:state], state(env, player))
push!(agent.trajectory[:action], actions[player])
if haskey(agent.trajectory, :legal_actions_mask)
lasm = legal_action_space_mask(env, player)
push!(agent.trajectory[:legal_actions_mask], lasm)
end
end

# update policy
update!(π)
end

function (π::MADDPGManager)(::PostActStage, env::AbstractEnv)
for (player, agent) in π.agents
push!(agent.trajectory[:reward], reward(env, player))
push!(agent.trajectory[:terminal], is_terminated(env))
end
end

function (π::MADDPGManager)(::PostEpisodeStage, env::AbstractEnv)
# add the terminal state and a dummy action to each agent's trajectory
for (player, agent) in π.agents
push!(agent.trajectory[:state], state(env, player))
push!(agent.trajectory[:action], rand(action_space(env)))
if haskey(agent.trajectory, :legal_actions_mask)
lasm = legal_action_space_mask(env, player)
push!(agent.trajectory[:legal_actions_mask], lasm)
end
end

# update policy
update!(π)
end
Member:
How about dispatching to the inner agent's corresponding methods?

Like calling agent(stage, env, action) in the for loop.
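A minimal sketch of that idea, assuming the inner `Agent` stage methods (e.g. `agent(stage, env, action)`) do the trajectory bookkeeping exactly as in the single-agent case:

```julia
# Hypothetical sketch: forward the stage to each wrapped Agent and let it push
# state/action itself; per-player observations would still need special handling.
function (π::MADDPGManager)(stage::PreActStage, env::AbstractEnv, actions)
    for (player, agent) in π.agents
        agent(stage, env, actions[player])
    end
    update!(π)
end
```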

Member:
Can you take a look at the NamedPolicy and see whether we can reuse existing code as much as possible? See also the MultiAgentManager
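A rough sketch of that direction, with several assumptions: `MultiAgentManager` is assumed to accept one agent per player whose policy is a `NamedPolicy`, and `ddpg_agent` is a hypothetical pre-built `Agent{<:DDPGPolicy}` template:

```julia
# Hypothetical sketch only: wrap each player's policy in a NamedPolicy so that the
# existing MultiAgentManager plumbing can identify and route it.
agents = MultiAgentManager(
    (Agent(
        policy = NamedPolicy(p, deepcopy(ddpg_agent.policy)),
        trajectory = deepcopy(ddpg_agent.trajectory),
    ) for p in players(env) if p != chance_player(env))...
)
```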


# update policy
function RLBase.update!(π::MADDPGManager)
π.update_step += 1
π.update_step % π.update_freq == 0 || return

for (_, agent) in π.agents
length(agent.trajectory) > agent.policy.update_after || return
length(agent.trajectory) > π.batch_size || return
end

# get training data
temp_player = rand(keys(π.agents))
t = π.agents[temp_player].trajectory
Member:
Simply use the first agent?
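For example, a one-liner that picks an arbitrary but deterministic agent (any agent works here, since all trajectories are filled in lockstep):

```julia
t = first(values(π.agents)).trajectory
```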

inds = rand(π.rng, 1:length(t), π.batch_size)
batches = Dict((player, RLCore.fetch!(BatchSampler{SARTS}(π.batch_size), agent.trajectory, inds))
Member:
The hardcoded SARTS will make the algorithm work only on environments of MINIMAL_ACTION_SET.

for (player, agent) in π.agents)
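Regarding the hardcoded `SARTS` above, a hedged sketch of making the trace set conditional; `SLARTSL` is assumed to be the mask-aware counterpart of `SARTS` in RLCore, and the downstream numeric batch indices would have to follow the chosen layout:

```julia
# Hypothetical: derive the sampled traces from the trajectory instead of hardcoding SARTS,
# so that FULL_ACTION_SET environments (with legal_actions_mask) could also be handled.
traces = haskey(t, :legal_actions_mask) ? SLARTSL : SARTS
batches = Dict(
    (player, RLCore.fetch!(BatchSampler{traces}(π.batch_size), agent.trajectory, inds))
    for (player, agent) in π.agents
)
```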

# get s, a, s′ for critic
s = vcat((batches[player][1] for (player, _) in π.agents)...)
a = vcat((batches[player][2] for (player, _) in π.agents)...)
s′ = vcat((batches[player][5] for (player, _) in π.agents)...)
Member:
vcat is not very efficient here. Try Flux.batch?
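For reference, a tiny illustration of the suggestion (not claimed to be a drop-in replacement here, since `Flux.batch` stacks samples along a new trailing dimension while `vcat` concatenates along the feature dimension); `reduce(vcat, ...)` is another splat-free alternative:

```julia
using Flux

Flux.batch([[1, 2], [3, 4]])    # 2×2 Matrix: each sample becomes a column
reduce(vcat, ([1, 2], [3, 4]))  # 4-element Vector: plain concatenation without splatting
```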


# for training behavior_actor
mu_actions = vcat(
((
batches[player][1] |> # get the player's own state information
x -> send_to_device(device(agent.policy.behavior_actor), x) |>
agent.policy.behavior_actor |> send_to_host
) for (player, agent) in π.agents)...
)
# for training behavior_critic
new_actions = vcat(
((
batches[player][5] |> # batches[player][5] holds the next-state information
x -> send_to_device(device(agent.policy.target_actor), x) |>
agent.policy.target_actor |> send_to_host
) for (player, agent) in π.agents)...
)

for (player, agent) in π.agents
p = agent.policy
A = p.behavior_actor
C = p.behavior_critic
Aₜ = p.target_actor
Cₜ = p.target_critic

γ = p.γ
ρ = p.ρ

_device(x) = send_to_device(device(A), x)

# Note that A, C, Aₜ, and Cₜ are assumed to be on the same device.
s, a, s′ = _device((s, a, s′))
mu_actions = _device(mu_actions)
new_actions = _device(new_actions)
r = _device(batches[player][:reward])
t = _device(batches[player][:terminal])

qₜ = Cₜ(vcat(s′, new_actions)) |> vec
y = r .+ γ .* (1 .- t) .* qₜ

gs1 = gradient(Flux.params(C)) do
q = C(vcat(s, a)) |> vec
loss = mean((y .- q) .^ 2)
ignore() do
p.critic_loss = loss
end
loss
end

update!(C, gs1)

gs2 = gradient(Flux.params(A)) do
loss = -mean(C(vcat(s, mu_actions)))
ignore() do
p.actor_loss = loss
end
loss
end

update!(A, gs2)

# polyak averaging
for (dest, src) in zip(Flux.params([Aₜ, Cₜ]), Flux.params([A, C]))
dest .= ρ .* dest .+ (1 - ρ) .* src
end

s, a, s′ = send_to_host((s, a, s′))
mu_actions = send_to_host(mu_actions)
new_actions = send_to_host(new_actions)
Member:
Are they required here?

Member Author:
Thanks for your kind reviews! I'll check and update my codes later today.

end
end
1 change: 1 addition & 0 deletions src/ReinforcementLearningZoo/src/algorithms/policy_gradient/policy_gradient.jl
@@ -7,3 +7,4 @@ include("MAC.jl")
include("ddpg.jl")
include("td3.jl")
include("sac.jl")
include("maddpg.jl")