From 2d6de10ea272f71b00a73e72af78c56a7f9fd4a5 Mon Sep 17 00:00:00 2001 From: akbir Date: Thu, 27 Oct 2022 12:31:29 +0100 Subject: [PATCH 01/36] first push --- pax/env.py | 55 ------------------------------------------------------ 1 file changed, 55 deletions(-) diff --git a/pax/env.py b/pax/env.py index 01972542..2ef800d2 100644 --- a/pax/env.py +++ b/pax/env.py @@ -111,29 +111,6 @@ def runner_reset(ndims, rng): self.batch_step = jax.jit(jax.vmap(jax.vmap(_step))) self.runner_reset = runner_reset - def step(self, actions): - if self._reset_next_step: - return self.reset() - - output, self.state = self.runner_step(actions, self.state) - if (self.state.outer_t == self.outer_ep_length).all(): - self._reset_next_step = True - output = ( - TimeStep( - 2 * jnp.ones(self.num_envs, dtype=jnp.int8), - output[0].reward, - output[0].discount, - output[0].observation, - ), - TimeStep( - 2 * jnp.ones(self.num_envs, dtype=jnp.int8), - output[1].reward, - output[1].discount, - output[1].observation, - ), - ) - return output - def observation_spec(self) -> specs.DiscreteArray: """Returns the observation spec.""" return specs.DiscreteArray(num_values=5, name="previous turn") @@ -221,34 +198,6 @@ def _step(actions, state): jax.vmap(self.runner_step, (0, None), (0, None)) ) - def step( - self, actions: Tuple[jnp.ndarray, jnp.ndarray] - ) -> Tuple[TimeStep, TimeStep]: - """ - takes a tuple of batched policies and produce value functions from infinite game - policy of form [B, 5] - """ - if self._reset_next_step: - return self.reset() - - action_1, action_2 = actions - self._num_steps += 1 - assert action_1.shape == action_2.shape - assert action_1.shape == (self.num_envs, 5) - - outputs, self.state = self._jit_step(actions, self.state) - r1, r2, obs1, obs2, _ = outputs - r1, r2 = (1 - self.gamma) * r1, (1 - self.gamma) * r2 - - if self._num_steps == self.episode_length: - self._reset_next_step = True - return termination(reward=r1, observation=obs1), termination( - reward=r2, observation=obs2 - ) - return transition(reward=r1, observation=obs1), transition( - reward=r2, observation=obs2 - ) - def runner_step( self, actions: Tuple[jnp.ndarray, jnp.ndarray], @@ -728,10 +677,6 @@ def reset(self) -> Tuple[TimeStep, TimeStep]: output, self.state = self._reset(self.key) return output - def step(self, actions: Tuple[int, int]) -> Tuple[TimeStep, TimeStep]: - output, self.state = self.runner_step(actions, self.state) - return output - def observation_spec(self) -> specs.BoundedArray: """Returns the observation spec.""" if self.cnn: From 6e085aafd339d3be7abe2c7504d4e652a9e7916e Mon Sep 17 00:00:00 2001 From: akbir Date: Thu, 27 Oct 2022 16:05:10 +0100 Subject: [PATCH 02/36] first push - add IteratedMatrixGame --- pax/env/iterated_matrix_game.py | 120 ++++++++++++++++++++++++++++++++ 1 file changed, 120 insertions(+) create mode 100644 pax/env/iterated_matrix_game.py diff --git a/pax/env/iterated_matrix_game.py b/pax/env/iterated_matrix_game.py new file mode 100644 index 00000000..e40f1bf4 --- /dev/null +++ b/pax/env/iterated_matrix_game.py @@ -0,0 +1,120 @@ +import jax +from flax import struct +import jax.numpy as jnp +import chex +from gymnax.environments import environment, spaces +from typing import Tuple, Optional + + +class EnvState(struct.DataClass): + inner_t: int + outer_t: int + + +class EnvParams(struct.DataClass): + payoff_matrix: jnp.ndarray + num_inner_steps: int + num_outer_steps: int + num_players: int + num_actions: int + + +class IteratedMatrixGame(environment.Environment): + """ + JAX Compatible version of matrix game environment. Source: + """ + + def __init__(self): + super().__init__() + + def _step( + key: chex.PRNGKey, + state: EnvState, + actions: Tuple[int, int], + params: EnvParams, + ): + inner_t, outer_t = state.inner_t, state.outer_t + a1, a2 = actions + inner_t += 1 + + cc_p1 = params.payoff[0][0] * (a1 - 1.0) * (a2 - 1.0) + cc_p2 = params.payoff[0][1] * (a1 - 1.0) * (a2 - 1.0) + cd_p1 = params.payoff[1][0] * (1.0 - a1) * a2 + cd_p2 = params.payoff[1][1] * (1.0 - a1) * a2 + dc_p1 = params.payoff[2][0] * a1 * (1.0 - a2) + dc_p2 = params.payoff[2][1] * a1 * (1.0 - a2) + dd_p1 = params.payoff[3][0] * a1 * a2 + dd_p2 = params.payoff[3][1] * a1 * a2 + + r1 = cc_p1 + dc_p1 + cd_p1 + dd_p1 + r2 = cc_p2 + dc_p2 + cd_p2 + dd_p2 + + s1 = ( + 0 * (1 - a1) * (1 - a2) + + 1 * (1 - a1) * a2 + + 2 * a1 * (1 - a2) + + 3 * (a1) * (a2) + ) + + s2 = ( + 0 * (1 - a1) * (1 - a2) + + 1 * a1 * (1 - a2) + + 2 * (1 - a1) * a2 + + 3 * (a1) * (a2) + ) + # if first step then return START state. + done = inner_t % params.num_inner_steps == 0 + s1 = jax.lax.select(done, jnp.int8(4), jnp.int8(s1)) + s2 = jax.lax.select(done, jnp.int8(4), jnp.int8(s2)) + obs1 = jax.nn.one_hot(s1, 5, dtype=jnp.int8) + obs2 = jax.nn.one_hot(s2, 5, dtype=jnp.int8) + + # out step keeping + reset_inner = inner_t == params.num_inner_steps + inner_t = jax.lax.select( + reset_inner, jnp.zeros_like(inner_t), inner_t + ) + outer_t_new = outer_t + 1 + outer_t = jax.lax.select(reset_inner, outer_t_new, outer_t) + discount = jnp.zeros((), dtype=jnp.int8) + state = EnvState(inner_t, outer_t) + return (obs1, obs2), state, (r1, r2), done, {"discount": discount} + + # for runner + # self.runner_step = jax.jit(jax.vmap(_step)) + # self.batch_step = jax.jit(jax.vmap(jax.vmap(_step))) + # self.runner_reset = runner_reset + + @property + def name(self) -> str: + """Environment name.""" + return "IteratedMatrixGame-v1" + + @property + def num_actions(self) -> int: + """Number of actions possible in environment.""" + return 2 + + def action_space( + self, params: Optional[EnvParams] = None + ) -> spaces.Discrete: + """Action space of the environment.""" + return spaces.Discrete(4) + + def observation_space(self, params: EnvParams) -> spaces.Box: + """Observation space of the environment.""" + return spaces.Discrete(5) + + def state_space(self, params: EnvParams) -> spaces.Dict: + """State space of the environment.""" + return spaces.Discrete(5) + + def reset_env( + self, key: chex.PRNGKey, params: EnvParams + ) -> Tuple[chex.Array, EnvState]: + state = EnvState( + jnp.zeros((), dtype=jnp.int8), + jnp.zeros((), dtype=jnp.int8), + ) + obs = jax.nn.one_hot(4 * jnp.ones(()), 5, dtype=jnp.int8) + return (obs, obs), state From 6e8f301b5527c3ced2c730fb1b763cfd158880b2 Mon Sep 17 00:00:00 2001 From: akbir Date: Thu, 27 Oct 2022 16:07:15 +0100 Subject: [PATCH 03/36] first push - add IteratedMatrixGame --- pax/env/iterated_matrix_game.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pax/env/iterated_matrix_game.py b/pax/env/iterated_matrix_game.py index e40f1bf4..8dd27eee 100644 --- a/pax/env/iterated_matrix_game.py +++ b/pax/env/iterated_matrix_game.py @@ -21,7 +21,7 @@ class EnvParams(struct.DataClass): class IteratedMatrixGame(environment.Environment): """ - JAX Compatible version of matrix game environment. Source: + JAX Compatible version of matrix game environment. """ def __init__(self): @@ -81,6 +81,7 @@ def _step( return (obs1, obs2), state, (r1, r2), done, {"discount": discount} # for runner + self.step = jax.jit(_step) # self.runner_step = jax.jit(jax.vmap(_step)) # self.batch_step = jax.jit(jax.vmap(jax.vmap(_step))) # self.runner_reset = runner_reset From d89f9d778ae90d946f12b3f0f89eee86eca85a7c Mon Sep 17 00:00:00 2001 From: akbir Date: Thu, 27 Oct 2022 16:22:34 +0100 Subject: [PATCH 04/36] first push - add IteratedMatrixGame --- pax/envs/__init__.py | 0 pax/{env => envs}/iterated_matrix_game.py | 29 ++-- test/envs/test_iterated_matrix_game.py | 180 ++++++++++++++++++++++ 3 files changed, 195 insertions(+), 14 deletions(-) create mode 100644 pax/envs/__init__.py rename pax/{env => envs}/iterated_matrix_game.py (87%) create mode 100644 test/envs/test_iterated_matrix_game.py diff --git a/pax/envs/__init__.py b/pax/envs/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/pax/env/iterated_matrix_game.py b/pax/envs/iterated_matrix_game.py similarity index 87% rename from pax/env/iterated_matrix_game.py rename to pax/envs/iterated_matrix_game.py index 8dd27eee..fcf52343 100644 --- a/pax/env/iterated_matrix_game.py +++ b/pax/envs/iterated_matrix_game.py @@ -6,17 +6,17 @@ from typing import Tuple, Optional -class EnvState(struct.DataClass): +@struct.dataclass +class EnvState: inner_t: int outer_t: int -class EnvParams(struct.DataClass): +@struct.dataclass +class EnvParams: payoff_matrix: jnp.ndarray num_inner_steps: int num_outer_steps: int - num_players: int - num_actions: int class IteratedMatrixGame(environment.Environment): @@ -80,8 +80,19 @@ def _step( state = EnvState(inner_t, outer_t) return (obs1, obs2), state, (r1, r2), done, {"discount": discount} + def _reset_env( + key: chex.PRNGKey, params: EnvParams + ) -> Tuple[chex.Array, EnvState]: + state = EnvState( + jnp.zeros((), dtype=jnp.int8), + jnp.zeros((), dtype=jnp.int8), + ) + obs = jax.nn.one_hot(4 * jnp.ones(()), 5, dtype=jnp.int8) + return (obs, obs), state + # for runner self.step = jax.jit(_step) + self.reset_env = jax.jit(_reset_env) # self.runner_step = jax.jit(jax.vmap(_step)) # self.batch_step = jax.jit(jax.vmap(jax.vmap(_step))) # self.runner_reset = runner_reset @@ -109,13 +120,3 @@ def observation_space(self, params: EnvParams) -> spaces.Box: def state_space(self, params: EnvParams) -> spaces.Dict: """State space of the environment.""" return spaces.Discrete(5) - - def reset_env( - self, key: chex.PRNGKey, params: EnvParams - ) -> Tuple[chex.Array, EnvState]: - state = EnvState( - jnp.zeros((), dtype=jnp.int8), - jnp.zeros((), dtype=jnp.int8), - ) - obs = jax.nn.one_hot(4 * jnp.ones(()), 5, dtype=jnp.int8) - return (obs, obs), state diff --git a/test/envs/test_iterated_matrix_game.py b/test/envs/test_iterated_matrix_game.py new file mode 100644 index 00000000..a6eec790 --- /dev/null +++ b/test/envs/test_iterated_matrix_game.py @@ -0,0 +1,180 @@ +import jax.numpy as jnp +import jax +import pytest + +from pax.envs.iterated_matrix_game import IteratedMatrixGame +from pax.strategies import TitForTat + +# payoff matrices for four games +ipd = [[2, 2], [0, 3], [3, 0], [1, 1]] +stag = [[4, 4], [1, 3], [3, 1], [2, 2]] +sexes = [[3, 2], [0, 0], [0, 0], [2, 3]] +chicken = [[0, 0], [-1, 1], [1, -1], [-2, -2]] +test_payoffs = [ipd, stag, sexes, chicken] + + +@pytest.mark.parametrize("payoff", test_payoffs) +def test_single_batch_rewards(payoff) -> None: + num_envs = 1 + rng = jax.random.PRNGKey(0) + env = IteratedMatrixGame() + env_state = IteratedMatrixGame.State(payoff, 5, 5) + + obs, env_state = env.env_reset(env_state) + + action = jnp.ones((num_envs,), dtype=jnp.float32) + r_array = jnp.ones((num_envs,), dtype=jnp.float32) + + # payoffs + cc_p1, cc_p2 = payoff[0][0], payoff[0][1] + cd_p1, cd_p2 = payoff[1][0], payoff[1][1] + dc_p1, dc_p2 = payoff[2][0], payoff[2][1] + dd_p1, dd_p2 = payoff[3][0], payoff[3][1] + + # first step + tstep_0, tstep_1 = env.step(rng, (0 * action, 0 * action), env_state) + + tstep_0, tstep_1 = env.step((0 * action, 0 * action)) + assert jnp.array_equal(tstep_0.reward, cc_p1 * r_array) + assert jnp.array_equal(tstep_1.reward, cc_p2 * r_array) + + tstep_0, tstep_1 = env.step((1 * action, 0 * action)) + assert jnp.array_equal(tstep_0.reward, dc_p1 * r_array) + assert jnp.array_equal(tstep_1.reward, dc_p2 * r_array) + + tstep_0, tstep_1 = env.step((0 * action, 1 * action)) + assert jnp.array_equal(tstep_0.reward, cd_p1 * r_array) + assert jnp.array_equal(tstep_1.reward, cd_p2 * r_array) + + tstep_0, tstep_1 = env.step((1 * action, 1 * action)) + assert jnp.array_equal(tstep_0.reward, dd_p1 * r_array) + assert jnp.array_equal(tstep_1.reward, dd_p2 * r_array) + + +testdata = [ + ((0, 0), (2, 2), ipd), + ((1, 0), (3, 0), ipd), + ((0, 1), (0, 3), ipd), + ((1, 1), (1, 1), ipd), + ((0, 0), (4, 4), stag), + ((1, 0), (3, 1), stag), + ((0, 1), (1, 3), stag), + ((1, 1), (2, 2), stag), + ((0, 0), (3, 2), sexes), + ((1, 0), (0, 0), sexes), + ((0, 1), (0, 0), sexes), + ((1, 1), (2, 3), sexes), + ((0, 0), (0, 0), chicken), + ((1, 0), (1, -1), chicken), + ((0, 1), (-1, 1), chicken), + ((1, 1), (-2, -2), chicken), +] + + +@pytest.mark.parametrize("actions, expected_rewards, payoff", testdata) +def test_batch_outcomes(actions, expected_rewards, payoff) -> None: + num_envs = 3 + all_ones = jnp.ones((num_envs,)) + env = IteratedMatrixGame(num_envs, payoff, 5, 10) + env.reset() + + action_1, action_2 = actions + expected_r1, expected_r2 = expected_rewards + + tstep_0, tstep_1 = env.step((action_1 * all_ones, action_2 * all_ones)) + + assert jnp.array_equal(tstep_0.reward, expected_r1 * jnp.ones((num_envs,))) + assert jnp.array_equal(tstep_1.reward, expected_r2 * jnp.ones((num_envs,))) + # assert tstep_0.last() == False + # assert tstep_1.last() == False + + +def test_mixed_batched_outcomes() -> None: + pass + + +def test_tit_for_tat_match() -> None: + num_envs = 5 + payoff = [[2, 2], [0, 3], [3, 0], [1, 1]] + env = IteratedMatrixGame(num_envs, payoff, 5, 10) + t_0, t_1 = env.reset() + + tit_for_tat = TitForTat(num_envs) + + action_0 = tit_for_tat.select_action(t_0) + action_1 = tit_for_tat.select_action(t_1) + assert jnp.array_equal(action_0, action_1) + + t_0, t_1 = env.step((action_0, action_1)) + assert jnp.array_equal(t_0.reward, t_1.reward) + + +def test_longer_game() -> None: + num_envs = 1 + payoff = [[2, 2], [0, 3], [3, 0], [1, 1]] + num_steps = 50 + num_inner_steps = 2 + env = IteratedMatrixGame(num_envs, payoff, num_inner_steps, num_steps) + t_0, t_1 = env.reset() + + agent = TitForTat(num_envs) + action = agent.select_action(t_0) + + r1 = [] + r2 = [] + for _ in range(10): + action = agent.select_action(t_0) + t0, t1 = env.step((action, action)) + r1.append(t0.reward) + r2.append(t1.reward) + + assert jnp.array_equal(t_0.reward, t_1.reward) + assert jnp.mean(jnp.stack(r1)) == 2 + assert jnp.mean(jnp.stack(r2)) == 2 + + +def test_done(): + num_envs = 1 + payoff = [[2, 2], [0, 3], [3, 0], [1, 1]] + env = IteratedMatrixGame(num_envs, payoff, 5, 5) + action = jnp.ones((num_envs,)) + + # check first + t_0, t_1 = env.step((0 * action, 0 * action)) + assert t_0.last() == False + assert t_1.last() == False + + for _ in range(4): + t_0, t_1 = env.step((0 * action, 0 * action)) + assert t_0.last() == False + assert t_1.last() == False + + # check final + t_0, t_1 = env.step((0 * action, 0 * action)) + assert t_0.last() == True + assert t_1.last() == True + + # check back at start + assert jnp.array_equal(t_0.observation.argmax(), 4) + assert jnp.array_equal(t_1.observation.argmax(), 4) + + +def test_reset(): + num_envs = 1 + payoff = [[2, 2], [0, 3], [3, 0], [1, 1]] + env = IteratedMatrixGame(num_envs, payoff, 5, 10) + state = jnp.ones((num_envs,)) + + env.reset() + + for _ in range(4): + t_0, t_1 = env.step((0 * state, 0 * state)) + assert t_0.last().all() == False + assert t_1.last().all() == False + + env.reset() + + for _ in range(4): + t_0, t_1 = env.step((0 * state, 0 * state)) + assert t_0.last().all() == False + assert t_1.last().all() == False From 655e278a4bc7c074677bcc50fcc955b9b94c0ebe Mon Sep 17 00:00:00 2001 From: akbir Date: Thu, 27 Oct 2022 16:43:00 +0100 Subject: [PATCH 05/36] first push - add IteratedMatrixGame --- pax/envs/iterated_matrix_game.py | 24 ++++++++++++------------ test/envs/__init__.py | 0 test/envs/test_iterated_matrix_game.py | 17 +++++++++++++---- 3 files changed, 25 insertions(+), 16 deletions(-) create mode 100644 test/envs/__init__.py diff --git a/pax/envs/iterated_matrix_game.py b/pax/envs/iterated_matrix_game.py index fcf52343..b4c24bb8 100644 --- a/pax/envs/iterated_matrix_game.py +++ b/pax/envs/iterated_matrix_game.py @@ -6,13 +6,13 @@ from typing import Tuple, Optional -@struct.dataclass +@chex.dataclass class EnvState: inner_t: int outer_t: int -@struct.dataclass +@chex.dataclass class EnvParams: payoff_matrix: jnp.ndarray num_inner_steps: int @@ -37,14 +37,14 @@ def _step( a1, a2 = actions inner_t += 1 - cc_p1 = params.payoff[0][0] * (a1 - 1.0) * (a2 - 1.0) - cc_p2 = params.payoff[0][1] * (a1 - 1.0) * (a2 - 1.0) - cd_p1 = params.payoff[1][0] * (1.0 - a1) * a2 - cd_p2 = params.payoff[1][1] * (1.0 - a1) * a2 - dc_p1 = params.payoff[2][0] * a1 * (1.0 - a2) - dc_p2 = params.payoff[2][1] * a1 * (1.0 - a2) - dd_p1 = params.payoff[3][0] * a1 * a2 - dd_p2 = params.payoff[3][1] * a1 * a2 + cc_p1 = params.payoff_matrix[0][0] * (a1 - 1.0) * (a2 - 1.0) + cc_p2 = params.payoff_matrix[0][1] * (a1 - 1.0) * (a2 - 1.0) + cd_p1 = params.payoff_matrix[1][0] * (1.0 - a1) * a2 + cd_p2 = params.payoff_matrix[1][1] * (1.0 - a1) * a2 + dc_p1 = params.payoff_matrix[2][0] * a1 * (1.0 - a2) + dc_p2 = params.payoff_matrix[2][1] * a1 * (1.0 - a2) + dd_p1 = params.payoff_matrix[3][0] * a1 * a2 + dd_p2 = params.payoff_matrix[3][1] * a1 * a2 r1 = cc_p1 + dc_p1 + cd_p1 + dd_p1 r2 = cc_p2 + dc_p2 + cd_p2 + dd_p2 @@ -84,8 +84,8 @@ def _reset_env( key: chex.PRNGKey, params: EnvParams ) -> Tuple[chex.Array, EnvState]: state = EnvState( - jnp.zeros((), dtype=jnp.int8), - jnp.zeros((), dtype=jnp.int8), + inner_t=jnp.zeros((), dtype=jnp.int8), + outer_t=jnp.zeros((), dtype=jnp.int8), ) obs = jax.nn.one_hot(4 * jnp.ones(()), 5, dtype=jnp.int8) return (obs, obs), state diff --git a/test/envs/__init__.py b/test/envs/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/test/envs/test_iterated_matrix_game.py b/test/envs/test_iterated_matrix_game.py index a6eec790..fa649430 100644 --- a/test/envs/test_iterated_matrix_game.py +++ b/test/envs/test_iterated_matrix_game.py @@ -2,7 +2,8 @@ import jax import pytest -from pax.envs.iterated_matrix_game import IteratedMatrixGame +from pax.envs.iterated_matrix_game import IteratedMatrixGame, EnvParams + from pax.strategies import TitForTat # payoff matrices for four games @@ -11,6 +12,7 @@ sexes = [[3, 2], [0, 0], [0, 0], [2, 3]] chicken = [[0, 0], [-1, 1], [1, -1], [-2, -2]] test_payoffs = [ipd, stag, sexes, chicken] +test_payoffs = [ipd] @pytest.mark.parametrize("payoff", test_payoffs) @@ -18,9 +20,14 @@ def test_single_batch_rewards(payoff) -> None: num_envs = 1 rng = jax.random.PRNGKey(0) env = IteratedMatrixGame() - env_state = IteratedMatrixGame.State(payoff, 5, 5) + env_params = EnvParams( + payoff_matrix=payoff, num_inner_steps=5, num_outer_steps=1 + ) + + env.reset_env = jax.vmap(env.reset_env, in_axes=(0, None)) + env.step = jax.vmap(env.step, in_axes=(None, None, 0, None)) - obs, env_state = env.env_reset(env_state) + obs, env_state = env.reset_env(rng, env_params) action = jnp.ones((num_envs,), dtype=jnp.float32) r_array = jnp.ones((num_envs,), dtype=jnp.float32) @@ -32,7 +39,9 @@ def test_single_batch_rewards(payoff) -> None: dd_p1, dd_p2 = payoff[3][0], payoff[3][1] # first step - tstep_0, tstep_1 = env.step(rng, (0 * action, 0 * action), env_state) + tstep_0, tstep_1 = env.step( + rng, env_state, (0 * action, 0 * action), env_params + ) tstep_0, tstep_1 = env.step((0 * action, 0 * action)) assert jnp.array_equal(tstep_0.reward, cc_p1 * r_array) From 83a8c886c7189c64204570a9afa9c3eaa861c144 Mon Sep 17 00:00:00 2001 From: akbir Date: Thu, 27 Oct 2022 18:28:50 +0100 Subject: [PATCH 06/36] first test passes! --- pax/envs/iterated_matrix_game.py | 15 ++++-- test/envs/test_iterated_matrix_game.py | 69 ++++++++++++++++---------- 2 files changed, 53 insertions(+), 31 deletions(-) diff --git a/pax/envs/iterated_matrix_game.py b/pax/envs/iterated_matrix_game.py index b4c24bb8..fa3ef174 100644 --- a/pax/envs/iterated_matrix_game.py +++ b/pax/envs/iterated_matrix_game.py @@ -64,6 +64,7 @@ def _step( ) # if first step then return START state. done = inner_t % params.num_inner_steps == 0 + print(done.shape, jnp.int8(4), jnp.int8(s1)) s1 = jax.lax.select(done, jnp.int8(4), jnp.int8(s1)) s2 = jax.lax.select(done, jnp.int8(4), jnp.int8(s2)) obs1 = jax.nn.one_hot(s1, 5, dtype=jnp.int8) @@ -76,9 +77,14 @@ def _step( ) outer_t_new = outer_t + 1 outer_t = jax.lax.select(reset_inner, outer_t_new, outer_t) - discount = jnp.zeros((), dtype=jnp.int8) - state = EnvState(inner_t, outer_t) - return (obs1, obs2), state, (r1, r2), done, {"discount": discount} + state = EnvState(inner_t=inner_t, outer_t=outer_t) + return ( + (obs1, obs2), + state, + (r1, r2), + done, + {"discount": jnp.zeros((), dtype=jnp.int8)}, + ) def _reset_env( key: chex.PRNGKey, params: EnvParams @@ -91,7 +97,8 @@ def _reset_env( return (obs, obs), state # for runner - self.step = jax.jit(_step) + # self.step = jax.jit(_step) + self.step = _step self.reset_env = jax.jit(_reset_env) # self.runner_step = jax.jit(jax.vmap(_step)) # self.batch_step = jax.jit(jax.vmap(jax.vmap(_step))) diff --git a/test/envs/test_iterated_matrix_game.py b/test/envs/test_iterated_matrix_game.py index fa649430..2248ab6f 100644 --- a/test/envs/test_iterated_matrix_game.py +++ b/test/envs/test_iterated_matrix_game.py @@ -12,26 +12,26 @@ sexes = [[3, 2], [0, 0], [0, 0], [2, 3]] chicken = [[0, 0], [-1, 1], [1, -1], [-2, -2]] test_payoffs = [ipd, stag, sexes, chicken] -test_payoffs = [ipd] @pytest.mark.parametrize("payoff", test_payoffs) def test_single_batch_rewards(payoff) -> None: - num_envs = 1 + num_envs = 5 rng = jax.random.PRNGKey(0) env = IteratedMatrixGame() env_params = EnvParams( payoff_matrix=payoff, num_inner_steps=5, num_outer_steps=1 ) - env.reset_env = jax.vmap(env.reset_env, in_axes=(0, None)) - env.step = jax.vmap(env.step, in_axes=(None, None, 0, None)) - - obs, env_state = env.reset_env(rng, env_params) - action = jnp.ones((num_envs,), dtype=jnp.float32) r_array = jnp.ones((num_envs,), dtype=jnp.float32) + # we want to batch over envs purely by actions + env.step = jax.vmap( + env.step, in_axes=(None, None, 0, None), out_axes=(0, None, 0, 0, 0) + ) + obs, env_state = env.reset(rng, env_params) + # payoffs cc_p1, cc_p2 = payoff[0][0], payoff[0][1] cd_p1, cd_p2 = payoff[1][0], payoff[1][1] @@ -39,25 +39,29 @@ def test_single_batch_rewards(payoff) -> None: dd_p1, dd_p2 = payoff[3][0], payoff[3][1] # first step - tstep_0, tstep_1 = env.step( + obs, env_state, rewards, done, info = env.step( rng, env_state, (0 * action, 0 * action), env_params ) + assert jnp.array_equal(rewards[0], cc_p1 * r_array) + assert jnp.array_equal(rewards[1], cc_p2 * r_array) - tstep_0, tstep_1 = env.step((0 * action, 0 * action)) - assert jnp.array_equal(tstep_0.reward, cc_p1 * r_array) - assert jnp.array_equal(tstep_1.reward, cc_p2 * r_array) - - tstep_0, tstep_1 = env.step((1 * action, 0 * action)) - assert jnp.array_equal(tstep_0.reward, dc_p1 * r_array) - assert jnp.array_equal(tstep_1.reward, dc_p2 * r_array) + obs, env_state, rewards, done, info = env.step( + rng, env_state, (1 * action, 0 * action), env_params + ) + assert jnp.array_equal(rewards[0], dc_p1 * r_array) + assert jnp.array_equal(rewards[1], dc_p2 * r_array) - tstep_0, tstep_1 = env.step((0 * action, 1 * action)) - assert jnp.array_equal(tstep_0.reward, cd_p1 * r_array) - assert jnp.array_equal(tstep_1.reward, cd_p2 * r_array) + obs, env_state, rewards, done, info = env.step( + rng, env_state, (0 * action, 1 * action), env_params + ) + assert jnp.array_equal(rewards[0], cd_p1 * r_array) + assert jnp.array_equal(rewards[1], cd_p2 * r_array) - tstep_0, tstep_1 = env.step((1 * action, 1 * action)) - assert jnp.array_equal(tstep_0.reward, dd_p1 * r_array) - assert jnp.array_equal(tstep_1.reward, dd_p2 * r_array) + obs, env_state, rewards, done, info = env.step( + rng, env_state, (1 * action, 1 * action), env_params + ) + assert jnp.array_equal(rewards[0], dd_p1 * r_array) + assert jnp.array_equal(rewards[1], dd_p2 * r_array) testdata = [ @@ -84,16 +88,27 @@ def test_single_batch_rewards(payoff) -> None: def test_batch_outcomes(actions, expected_rewards, payoff) -> None: num_envs = 3 all_ones = jnp.ones((num_envs,)) - env = IteratedMatrixGame(num_envs, payoff, 5, 10) - env.reset() - action_1, action_2 = actions + rng = jax.random.PRNGKey(0) + + a1, a2 = actions expected_r1, expected_r2 = expected_rewards - tstep_0, tstep_1 = env.step((action_1 * all_ones, action_2 * all_ones)) + env = IteratedMatrixGame() + env_params = EnvParams( + payoff_matrix=payoff, num_inner_steps=5, num_outer_steps=1 + ) + # we want to batch over envs purely by actions + env.step = jax.vmap( + env.step, in_axes=(None, None, 0, None), out_axes=(0, None, 0, 0, 0) + ) + obs, env_state = env.reset(rng, env_params) + obs, env_state, rewards, done, info = env.step( + rng, env_state, (a1 * all_ones, a2 * all_ones), env_params + ) - assert jnp.array_equal(tstep_0.reward, expected_r1 * jnp.ones((num_envs,))) - assert jnp.array_equal(tstep_1.reward, expected_r2 * jnp.ones((num_envs,))) + assert jnp.array_equal(rewards[0], expected_r1 * jnp.ones((num_envs,))) + assert jnp.array_equal(rewards[1], expected_r2 * jnp.ones((num_envs,))) # assert tstep_0.last() == False # assert tstep_1.last() == False From 9a32d116c1ee145cbbc07f23c310e82449784d28 Mon Sep 17 00:00:00 2001 From: akbir Date: Thu, 27 Oct 2022 18:54:00 +0100 Subject: [PATCH 07/36] changed everything up to tft --- test/envs/test_iterated_matrix_game.py | 80 ++++++++++++++++++++++---- 1 file changed, 69 insertions(+), 11 deletions(-) diff --git a/test/envs/test_iterated_matrix_game.py b/test/envs/test_iterated_matrix_game.py index 2248ab6f..16e43d25 100644 --- a/test/envs/test_iterated_matrix_game.py +++ b/test/envs/test_iterated_matrix_game.py @@ -26,7 +26,7 @@ def test_single_batch_rewards(payoff) -> None: action = jnp.ones((num_envs,), dtype=jnp.float32) r_array = jnp.ones((num_envs,), dtype=jnp.float32) - # we want to batch over envs purely by actions + # we want to batch over actions env.step = jax.vmap( env.step, in_axes=(None, None, 0, None), out_axes=(0, None, 0, 0, 0) ) @@ -109,28 +109,86 @@ def test_batch_outcomes(actions, expected_rewards, payoff) -> None: assert jnp.array_equal(rewards[0], expected_r1 * jnp.ones((num_envs,))) assert jnp.array_equal(rewards[1], expected_r2 * jnp.ones((num_envs,))) - # assert tstep_0.last() == False - # assert tstep_1.last() == False + assert (done == False).all() + + +def test_batch_by_rngs() -> None: + # we don't use the rng in step but good test of batchings by rngs + num_envs = 2 + payoff = [[2, 2], [0, 3], [3, 0], [1, 1]] + + rng = jnp.concatenate( + [jax.random.PRNGKey(0), jax.random.PRNGKey(0)] + ).reshape(num_envs, -1) + env = IteratedMatrixGame() + env_params = EnvParams( + payoff_matrix=payoff, num_inner_steps=5, num_outer_steps=1 + ) + + action = jnp.ones((num_envs,), dtype=jnp.float32) + r_array = jnp.ones((num_envs,), dtype=jnp.float32) + + # we want to batch over envs purely by actions + env.step = jax.vmap( + env.step, in_axes=(0, None, 0, None), out_axes=(0, None, 0, 0, 0) + ) + obs, env_state = env.reset(rng, env_params) + + # payoffs + cc_p1, cc_p2 = payoff[0][0], payoff[0][1] + cd_p1, cd_p2 = payoff[1][0], payoff[1][1] + dc_p1, dc_p2 = payoff[2][0], payoff[2][1] + dd_p1, dd_p2 = payoff[3][0], payoff[3][1] + + # first step + obs, env_state, rewards, done, info = env.step( + rng, env_state, (0 * action, 0 * action), env_params + ) + assert jnp.array_equal(rewards[0], cc_p1 * r_array) + assert jnp.array_equal(rewards[1], cc_p2 * r_array) + obs, env_state, rewards, done, info = env.step( + rng, env_state, (1 * action, 0 * action), env_params + ) + assert jnp.array_equal(rewards[0], dc_p1 * r_array) + assert jnp.array_equal(rewards[1], dc_p2 * r_array) -def test_mixed_batched_outcomes() -> None: - pass + obs, env_state, rewards, done, info = env.step( + rng, env_state, (0 * action, 1 * action), env_params + ) + assert jnp.array_equal(rewards[0], cd_p1 * r_array) + assert jnp.array_equal(rewards[1], cd_p2 * r_array) + + obs, env_state, rewards, done, info = env.step( + rng, env_state, (1 * action, 1 * action), env_params + ) + assert jnp.array_equal(rewards[0], dd_p1 * r_array) + assert jnp.array_equal(rewards[1], dd_p2 * r_array) def test_tit_for_tat_match() -> None: num_envs = 5 payoff = [[2, 2], [0, 3], [3, 0], [1, 1]] - env = IteratedMatrixGame(num_envs, payoff, 5, 10) - t_0, t_1 = env.reset() + rng = jax.random.PRNGKey(0) + env = IteratedMatrixGame() + env_params = EnvParams( + payoff_matrix=payoff, num_inner_steps=5, num_outer_steps=2 + ) + env.step = jax.vmap( + env.step, in_axes=(0, None, 0, None), out_axes=(0, None, 0, 0, 0) + ) + obs, env_state = env.reset(rng, env_params) tit_for_tat = TitForTat(num_envs) - action_0 = tit_for_tat.select_action(t_0) - action_1 = tit_for_tat.select_action(t_1) + action_0 = tit_for_tat.select_action(obs[0]) + action_1 = tit_for_tat.select_action(obs[1]) assert jnp.array_equal(action_0, action_1) - t_0, t_1 = env.step((action_0, action_1)) - assert jnp.array_equal(t_0.reward, t_1.reward) + obs, env_state, rewards, done, info = env.step( + rng, env_state, (action_0, action_1), env_params + ) + assert jnp.array_equal(rewards[0], rewards[1]) def test_longer_game() -> None: From 75c03f272737170cd57a779f26e4cb71769fa046 Mon Sep 17 00:00:00 2001 From: akbir Date: Thu, 27 Oct 2022 19:08:40 +0100 Subject: [PATCH 08/36] do not use gymnax.environments because it makes single agent assumptions --- pax/envs/iterated_matrix_game.py | 13 ++++--------- 1 file changed, 4 insertions(+), 9 deletions(-) diff --git a/pax/envs/iterated_matrix_game.py b/pax/envs/iterated_matrix_game.py index fa3ef174..bd5dd9a6 100644 --- a/pax/envs/iterated_matrix_game.py +++ b/pax/envs/iterated_matrix_game.py @@ -64,7 +64,6 @@ def _step( ) # if first step then return START state. done = inner_t % params.num_inner_steps == 0 - print(done.shape, jnp.int8(4), jnp.int8(s1)) s1 = jax.lax.select(done, jnp.int8(4), jnp.int8(s1)) s2 = jax.lax.select(done, jnp.int8(4), jnp.int8(s2)) obs1 = jax.nn.one_hot(s1, 5, dtype=jnp.int8) @@ -86,7 +85,7 @@ def _step( {"discount": jnp.zeros((), dtype=jnp.int8)}, ) - def _reset_env( + def _reset( key: chex.PRNGKey, params: EnvParams ) -> Tuple[chex.Array, EnvState]: state = EnvState( @@ -96,13 +95,9 @@ def _reset_env( obs = jax.nn.one_hot(4 * jnp.ones(()), 5, dtype=jnp.int8) return (obs, obs), state - # for runner - # self.step = jax.jit(_step) - self.step = _step - self.reset_env = jax.jit(_reset_env) - # self.runner_step = jax.jit(jax.vmap(_step)) - # self.batch_step = jax.jit(jax.vmap(jax.vmap(_step))) - # self.runner_reset = runner_reset + # overwrite Gymnax as it makes single-agent assumptions + self.step = jax.jit(_step) + self.reset = jax.jit(_reset) @property def name(self) -> str: From 5e48408c2c0069382f4f386f5fe06f9137f19ac0 Mon Sep 17 00:00:00 2001 From: akbir Date: Thu, 27 Oct 2022 19:22:35 +0100 Subject: [PATCH 09/36] titfortat now works --- test/envs/test_iterated_matrix_game.py | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/test/envs/test_iterated_matrix_game.py b/test/envs/test_iterated_matrix_game.py index 16e43d25..3af9860c 100644 --- a/test/envs/test_iterated_matrix_game.py +++ b/test/envs/test_iterated_matrix_game.py @@ -169,26 +169,31 @@ def test_batch_by_rngs() -> None: def test_tit_for_tat_match() -> None: num_envs = 5 payoff = [[2, 2], [0, 3], [3, 0], [1, 1]] - rng = jax.random.PRNGKey(0) + rngs = jnp.concatenate(num_envs * [jax.random.PRNGKey(0)]).reshape( + num_envs, -1 + ) env = IteratedMatrixGame() env_params = EnvParams( payoff_matrix=payoff, num_inner_steps=5, num_outer_steps=2 ) + env.reset = jax.vmap(env.reset, in_axes=(0, None), out_axes=(0, None)) env.step = jax.vmap( env.step, in_axes=(0, None, 0, None), out_axes=(0, None, 0, 0, 0) ) - obs, env_state = env.reset(rng, env_params) + + obs, env_state = env.reset(rngs, env_params) tit_for_tat = TitForTat(num_envs) action_0 = tit_for_tat.select_action(obs[0]) action_1 = tit_for_tat.select_action(obs[1]) assert jnp.array_equal(action_0, action_1) - obs, env_state, rewards, done, info = env.step( - rng, env_state, (action_0, action_1), env_params - ) - assert jnp.array_equal(rewards[0], rewards[1]) + for _ in range(10): + obs, env_state, rewards, done, info = env.step( + rngs, env_state, (action_0, action_1), env_params + ) + assert jnp.array_equal(rewards[0], rewards[1]) def test_longer_game() -> None: From c29f94c09d34f48fae0981193542ab9a06478395 Mon Sep 17 00:00:00 2001 From: akbir Date: Thu, 27 Oct 2022 21:58:22 +0100 Subject: [PATCH 10/36] all tests pass for iterated_matrix_game --- pax/strategies.py | 4 +- test/envs/test_iterated_matrix_game.py | 112 ++++++++++++++++--------- 2 files changed, 74 insertions(+), 42 deletions(-) diff --git a/pax/strategies.py b/pax/strategies.py index eaa1819b..53730bc6 100644 --- a/pax/strategies.py +++ b/pax/strategies.py @@ -391,11 +391,11 @@ def __init__(self, num_envs, *args): def select_action( self, - timestep: TimeStep, + obs: jnp.ndarray, ) -> jnp.ndarray: # state is [batch x time_step x num_players] # return [batch] - return self._reciprocity(timestep.observation) + return self._reciprocity(obs) def update(self, unused0, unused1, state, mem) -> None: return state, mem, {} diff --git a/test/envs/test_iterated_matrix_game.py b/test/envs/test_iterated_matrix_game.py index 3af9860c..cd373612 100644 --- a/test/envs/test_iterated_matrix_game.py +++ b/test/envs/test_iterated_matrix_game.py @@ -113,7 +113,7 @@ def test_batch_outcomes(actions, expected_rewards, payoff) -> None: def test_batch_by_rngs() -> None: - # we don't use the rng in step but good test of batchings by rngs + # we don't use the rng in step but good test of how runners wil use env num_envs = 2 payoff = [[2, 2], [0, 3], [3, 0], [1, 1]] @@ -199,69 +199,101 @@ def test_tit_for_tat_match() -> None: def test_longer_game() -> None: num_envs = 1 payoff = [[2, 2], [0, 3], [3, 0], [1, 1]] - num_steps = 50 + num_outer_steps = 25 num_inner_steps = 2 - env = IteratedMatrixGame(num_envs, payoff, num_inner_steps, num_steps) - t_0, t_1 = env.reset() + env = IteratedMatrixGame() - agent = TitForTat(num_envs) - action = agent.select_action(t_0) + # batch over actions and env_states + env.reset = jax.vmap(env.reset, in_axes=(0, None), out_axes=(0, None)) + env.step = jax.vmap( + env.step, in_axes=(0, None, 0, None), out_axes=(0, None, 0, 0, 0) + ) + + env_params = EnvParams( + payoff_matrix=payoff, + num_inner_steps=num_inner_steps, + num_outer_steps=num_outer_steps, + ) + + rngs = jnp.concatenate(num_envs * [jax.random.PRNGKey(0)]).reshape( + num_envs, -1 + ) + + obs, env_state = env.reset(rngs, env_params) + agent = TitForTat(num_envs) r1 = [] r2 = [] - for _ in range(10): - action = agent.select_action(t_0) - t0, t1 = env.step((action, action)) - r1.append(t0.reward) - r2.append(t1.reward) + for _ in range(num_outer_steps): + for _ in range(num_inner_steps): + action = agent.select_action(obs[0]) + obs, env_state, rewards, done, info = env.step( + rngs, env_state, (action, action), env_params + ) + r1.append(rewards[0]) + r2.append(rewards[1]) + assert jnp.array_equal(rewards[0], rewards[1]) + assert (done == True).all() - assert jnp.array_equal(t_0.reward, t_1.reward) assert jnp.mean(jnp.stack(r1)) == 2 assert jnp.mean(jnp.stack(r2)) == 2 def test_done(): - num_envs = 1 payoff = [[2, 2], [0, 3], [3, 0], [1, 1]] - env = IteratedMatrixGame(num_envs, payoff, 5, 5) - action = jnp.ones((num_envs,)) - - # check first - t_0, t_1 = env.step((0 * action, 0 * action)) - assert t_0.last() == False - assert t_1.last() == False + num_outer_steps = 25 + num_inner_steps = 5 + env = IteratedMatrixGame() + env_params = EnvParams( + payoff_matrix=payoff, + num_inner_steps=num_inner_steps, + num_outer_steps=num_outer_steps, + ) + rng = jax.random.PRNGKey(0) + obs, env_state = env.reset(rng, env_params) + action = 0 for _ in range(4): - t_0, t_1 = env.step((0 * action, 0 * action)) - assert t_0.last() == False - assert t_1.last() == False + obs, env_state, rewards, done, info = env.step( + rng, env_state, (action, action), env_params + ) + assert (done == False).all() # check final - t_0, t_1 = env.step((0 * action, 0 * action)) - assert t_0.last() == True - assert t_1.last() == True + obs, env_state, rewards, done, info = env.step( + rng, env_state, (action, action), env_params + ) + assert (done == True).all() # check back at start - assert jnp.array_equal(t_0.observation.argmax(), 4) - assert jnp.array_equal(t_1.observation.argmax(), 4) + assert jnp.array_equal(obs[0].argmax(), 4) + assert jnp.array_equal(obs[1].argmax(), 4) def test_reset(): - num_envs = 1 payoff = [[2, 2], [0, 3], [3, 0], [1, 1]] - env = IteratedMatrixGame(num_envs, payoff, 5, 10) - state = jnp.ones((num_envs,)) - - env.reset() + rng = jax.random.PRNGKey(0) + env = IteratedMatrixGame() + env_params = EnvParams( + payoff_matrix=payoff, + num_inner_steps=5, + num_outer_steps=2, + ) + action = 0 + obs, env_state = env.reset(rng, env_params) for _ in range(4): - t_0, t_1 = env.step((0 * state, 0 * state)) - assert t_0.last().all() == False - assert t_1.last().all() == False + obs, env_state, rewards, done, info = env.step( + rng, env_state, (action, action), env_params + ) + assert done == False + assert done == False - env.reset() + obs, env_state = env.reset(rng, env_params) for _ in range(4): - t_0, t_1 = env.step((0 * state, 0 * state)) - assert t_0.last().all() == False - assert t_1.last().all() == False + obs, env_state, rewards, done, info = env.step( + rng, env_state, (action, action), env_params + ) + assert done == False + assert done == False From 1f96f89e9ecac8307210e5d47591689ba142befc Mon Sep 17 00:00:00 2001 From: akbir Date: Thu, 27 Oct 2022 22:18:50 +0100 Subject: [PATCH 11/36] first push for infinite matrix game --- pax/envs/infinite_matrix_game.py | 113 +++++++++++++++++++++++++++++++ 1 file changed, 113 insertions(+) create mode 100644 pax/envs/infinite_matrix_game.py diff --git a/pax/envs/infinite_matrix_game.py b/pax/envs/infinite_matrix_game.py new file mode 100644 index 00000000..914704f5 --- /dev/null +++ b/pax/envs/infinite_matrix_game.py @@ -0,0 +1,113 @@ +import jax +from flax import struct +import jax.numpy as jnp +import chex +from gymnax.environments import environment, spaces +from typing import Tuple, Optional + + +@chex.dataclass +class EnvState: + inner_t: int + outer_t: int + + +@chex.dataclass +class EnvParams: + payoff_matrix: jnp.ndarray + num_steps: int + gamma: float + + +class InfiniteMatrixGame(environment.Environment): + def __init__(self): + super().__init__() + + def _step( + key: chex.PRNGKey, + state: EnvState, + actions: Tuple[int, int], + params: EnvParams, + ): + t = state.num_steps + key, _ = jax.random.split(key, 2) + payout_mat_1 = jnp.array([[r[0] for r in params.payoff_matrix]]) + payout_mat_2 = jnp.array([[r[1] for r in params.payoff_matrix]]) + + theta1, theta2 = actions + theta1, theta2 = jax.nn.sigmoid(theta1), jax.nn.sigmoid(theta2) + obs1 = jnp.concatenate([theta1, theta2]) + obs2 = jnp.concatenate([theta2, theta1]) + + _th2 = jnp.array( + [theta2[0], theta2[2], theta2[1], theta2[3], theta2[4]] + ) + + p_1_0 = theta1[4:5] + p_2_0 = _th2[4:5] + p = jnp.concatenate( + [ + p_1_0 * p_2_0, + p_1_0 * (1 - p_2_0), + (1 - p_1_0) * p_2_0, + (1 - p_1_0) * (1 - p_2_0), + ] + ) + p_1 = jnp.reshape(theta1[0:4], (4, 1)) + p_2 = jnp.reshape(_th2[0:4], (4, 1)) + P = jnp.concatenate( + [ + p_1 * p_2, + p_1 * (1 - p_2), + (1 - p_1) * p_2, + (1 - p_1) * (1 - p_2), + ], + axis=1, + ) + M = jnp.matmul(p, jnp.linalg.inv(jnp.eye(4) - params.gamma * P)) + L_1 = jnp.matmul(M, jnp.reshape(payout_mat_1, (4, 1))) + L_2 = jnp.matmul(M, jnp.reshape(payout_mat_2, (4, 1))) + r1 = (1 - params.gamma) * L_1.sum() + r2 = (1 - params.gamma) * L_2.sum() + done = t >= params.num_steps + return ( + (obs1, obs2), + EnvState(t + 1, 0), + (r1, r2), + done, + {"discount": jnp.zeros((), dtype=jnp.int8)}, + ) + + def _reset( + key: chex.PRNGKey, params: EnvParams + ) -> Tuple[chex.Array, EnvState]: + state = EnvState( + inner_t=jnp.zeros((), dtype=jnp.int8), + outer_t=jnp.zeros((), dtype=jnp.int8), + ) + obs = jax.nn.sigmoid(jax.random.uniform(key, (10,))) + return (obs, obs), state + + @property + def name(self) -> str: + """Environment name.""" + return "InfiniteMatrixGame-v1" + + @property + def num_actions(self) -> int: + """Number of actions possible in environment.""" + return 1 + + def action_space( + self, params: Optional[EnvParams] = None + ) -> spaces.Discrete: + """Action space of the environment.""" + return spaces.Box(low=0, high=1, shape=(5,)) + + def observation_space(self, params: EnvParams) -> spaces.Box: + """Observation space of the environment.""" + return spaces.Box(0, 1, (10,), dtype=jnp.float32) + + def state_space(self, params: EnvParams) -> spaces.Dict: + """State space of the environment.""" + return spaces.Box(0, 1, (10,), dtype=jnp.float32) From 0877855932300a994bab76520873a712777bd415 Mon Sep 17 00:00:00 2001 From: akbir Date: Thu, 27 Oct 2022 22:42:51 +0100 Subject: [PATCH 12/36] added tests for infinite matrix game --- test/envs/test_infinite_matrix_game.py | 132 +++++++++++++++++++++++++ 1 file changed, 132 insertions(+) create mode 100644 test/envs/test_infinite_matrix_game.py diff --git a/test/envs/test_infinite_matrix_game.py b/test/envs/test_infinite_matrix_game.py new file mode 100644 index 00000000..c82bdffe --- /dev/null +++ b/test/envs/test_infinite_matrix_game.py @@ -0,0 +1,132 @@ +import jax.numpy as jnp +import jax +import pytest + +from pax.envs.infinite_matrix_game import InfiniteMatrixGame, EnvParams + +from pax.strategies import TitForTat + + +def test_single_infinite_game(): + rng = jax.random.PRNGKey(0) + payoff = [[2, 2], [0, 3], [3, 0], [1, 1]] + alt_policy = 20 * jnp.ones((5)) + def_policy = -20 * jnp.ones((5)) + tft_policy = 20 * jnp.array([1.0, -1.0, 1.0, -1.0, 1.0]) + + # discount of 0.99 -> 1/(0.001) ~ 100 timestep + env = InfiniteMatrixGame() + env_params = EnvParams(payoff_matrix=payoff, num_steps=10, gamma=0.99) + + obs, env_state = env.reset(rng, env_params) + obs, env_state, rewards, done, info = env.step( + rng, env_state, (alt_policy, alt_policy), env_params + ) + + assert rewards[0] == rewards[1] + assert jnp.isclose(2, rewards[0], atol=0.01) + assert jnp.allclose(obs[0], jnp.ones((10)), atol=0.01) + assert jnp.allclose(obs[1], jnp.ones((10)), atol=0.01) + + obs, env_state, rewards, done, info = env.step( + rng, env_state, (def_policy, def_policy), env_params + ) + + assert rewards[0] == rewards[1] + assert jnp.isclose(1, rewards[0], atol=0.01) + assert jnp.allclose(obs[0], jnp.zeros((10))) + assert jnp.allclose(obs[1], jnp.zeros((10))) + + obs, env_state, rewards, done, info = env.step( + rng, env_state, (def_policy, alt_policy), env_params + ) + assert rewards[0] != rewards[1] + assert jnp.isclose(3, rewards[0]) + assert jnp.isclose(0.0, rewards[1], atol=0.0001) + assert jnp.allclose(obs[0], jnp.array([0, 0, 0, 0, 0, 1, 1, 1, 1, 1])) + assert jnp.allclose(obs[1], jnp.array([1, 1, 1, 1, 1, 0, 0, 0, 0, 0])) + + obs, env_state, rewards, done, info = env.step( + rng, env_state, (alt_policy, tft_policy), env_params + ) + assert jnp.isclose(2, rewards[0]) + assert jnp.isclose(2, rewards[1]) + assert jnp.allclose(obs[0], jnp.array([1, 1, 1, 1, 1, 1, 0, 1, 0, 1])) + assert jnp.allclose(obs[1], jnp.array([1, 0, 1, 0, 1, 1, 1, 1, 1, 1])) + + obs, env_state, rewards, done, info = env.step( + rng, env_state, (tft_policy, tft_policy), env_params + ) + assert jnp.isclose(2, rewards[0]) + assert jnp.isclose(2, rewards[1]) + assert jnp.allclose(obs[0], jnp.array([1, 0, 1, 0, 1, 1, 0, 1, 0, 1])) + assert jnp.allclose(obs[1], jnp.array([1, 0, 1, 0, 1, 1, 0, 1, 0, 1])) + + obs, env_state, rewards, done, info = env.step( + rng, env_state, (tft_policy, def_policy), env_params + ) + assert jnp.isclose(0.99, rewards[0]) + assert jnp.isclose(1.02, rewards[1]) + assert jnp.allclose(obs[0], jnp.array([1, 0, 1, 0, 1, 0, 0, 0, 0, 0])) + assert jnp.allclose(obs[1], jnp.array([0, 0, 0, 0, 0, 1, 0, 1, 0, 1])) + + obs, env_state, rewards, done, info = env.step( + rng, env_state, (def_policy, tft_policy), env_params + ) + assert jnp.isclose(1.02, rewards[0]) + assert jnp.isclose(0.99, rewards[1]) + assert jnp.allclose(obs[0], jnp.array([0, 0, 0, 0, 0, 1, 0, 1, 0, 1])) + assert jnp.allclose(obs[1], jnp.array([1, 0, 1, 0, 1, 0, 0, 0, 0, 0])) + + +def test_batch_infinite_game(): + payoff = [[2, 2], [0, 3], [3, 0], [1, 1]] + rng = jax.random.PRNGKey(0) + alt_policy = 20 * jnp.ones((1, 5)) + def_policy = -20 * jnp.ones((1, 5)) + tft_policy = 20 * jnp.array([[1, -1, 1, -1, 1]]) + + env = InfiniteMatrixGame() + env_params = EnvParams(payoff_matrix=payoff, num_steps=10, gamma=0.99) + + env.reset = jax.vmap(env.reset, in_axes=(0, None), out_axes=(0, None)) + env.step = jax.vmap( + env.step, in_axes=(0, None, 0, None), out_axes=(0, None, 0, 0, 0) + ) + rngs = jnp.concatenate([rng, rng, rng], axis=0).reshape(3, -1) + + batched_alt = jnp.concatenate([alt_policy, alt_policy, alt_policy], axis=0) + batched_def = jnp.concatenate([def_policy, def_policy, def_policy], axis=0) + batch_mixed_1 = jnp.concatenate( + [alt_policy, def_policy, tft_policy], axis=0 + ) + batch_mixed_2 = jnp.concatenate( + [def_policy, tft_policy, tft_policy], axis=0 + ) + + # first step + obs, env_state = env.reset(rngs, env_params) + + obs, env_state, rewards, done, info = env.step( + rngs, env_state, (batched_alt, batched_alt), env_params + ) + assert jnp.allclose(rewards[0], rewards[1]) + assert jnp.allclose(jnp.array([2, 2, 2]), rewards[0], atol=0.01) + + obs, env_state, rewards, done, info = env.step( + rngs, env_state, (batched_alt, batched_def), env_params + ) + assert jnp.allclose(jnp.array([0, 0, 0]), rewards[0], atol=0.01) + assert jnp.allclose(jnp.array([3, 3, 3]), rewards[1], atol=0.01) + + obs, env_state, rewards, done, info = env.step( + rngs, env_state, (batch_mixed_1, batch_mixed_2), env_params + ) + assert jnp.allclose(jnp.array([0.0, 1.02, 2]), rewards[0]) + assert jnp.allclose(jnp.array([3, 0.99, 2]), rewards[1]) + + obs, env_state, rewards, done, info = env.step( + rngs, env_state, (batch_mixed_2, batch_mixed_1), env_params + ) + assert jnp.allclose(jnp.array([3, 0.99, 2]), rewards[0], atol=0.01) + assert jnp.allclose(jnp.array([0, 1.02, 2]), rewards[1], atol=0.01) From 1a739ab7cca8b779b730e5bf9c206a9a07f15758 Mon Sep 17 00:00:00 2001 From: akbir Date: Thu, 27 Oct 2022 22:43:48 +0100 Subject: [PATCH 13/36] added tests for infinite matrix game --- pax/envs/infinite_matrix_game.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/pax/envs/infinite_matrix_game.py b/pax/envs/infinite_matrix_game.py index 914704f5..793c6f91 100644 --- a/pax/envs/infinite_matrix_game.py +++ b/pax/envs/infinite_matrix_game.py @@ -29,7 +29,7 @@ def _step( actions: Tuple[int, int], params: EnvParams, ): - t = state.num_steps + t = state.outer_t key, _ = jax.random.split(key, 2) payout_mat_1 = jnp.array([[r[0] for r in params.payoff_matrix]]) payout_mat_2 = jnp.array([[r[1] for r in params.payoff_matrix]]) @@ -70,9 +70,12 @@ def _step( r1 = (1 - params.gamma) * L_1.sum() r2 = (1 - params.gamma) * L_2.sum() done = t >= params.num_steps + state = EnvState( + inner_t=state.inner_t + 1, outer_t=state.outer_t + 1 + ) return ( (obs1, obs2), - EnvState(t + 1, 0), + state, (r1, r2), done, {"discount": jnp.zeros((), dtype=jnp.int8)}, @@ -88,6 +91,9 @@ def _reset( obs = jax.nn.sigmoid(jax.random.uniform(key, (10,))) return (obs, obs), state + self.step = jax.jit(_step) + self.reset = jax.jit(_reset) + @property def name(self) -> str: """Environment name.""" From 077fe48990f175520c577267d60f8286dea6b2ed Mon Sep 17 00:00:00 2001 From: akbir Date: Thu, 27 Oct 2022 22:58:37 +0100 Subject: [PATCH 14/36] shout out chris --- pax/envs/infinite_matrix_game.py | 1 + 1 file changed, 1 insertion(+) diff --git a/pax/envs/infinite_matrix_game.py b/pax/envs/infinite_matrix_game.py index 793c6f91..1ea1e0fb 100644 --- a/pax/envs/infinite_matrix_game.py +++ b/pax/envs/infinite_matrix_game.py @@ -29,6 +29,7 @@ def _step( actions: Tuple[int, int], params: EnvParams, ): + # Thank you @luchris429 for this code! t = state.outer_t key, _ = jax.random.split(key, 2) payout_mat_1 = jnp.array([[r[0] for r in params.payoff_matrix]]) From fa9ca39a72a16f78977ec7f0b25799b93ff019a6 Mon Sep 17 00:00:00 2001 From: akbir Date: Fri, 28 Oct 2022 16:58:58 +0100 Subject: [PATCH 15/36] updated to include coingame --- pax/envs/coin_game.py | 555 ++++++++++++++++++++++++++++++++++++ test/envs/test_coin_game.py | 360 +++++++++++++++++++++++ 2 files changed, 915 insertions(+) create mode 100644 pax/envs/coin_game.py create mode 100644 test/envs/test_coin_game.py diff --git a/pax/envs/coin_game.py b/pax/envs/coin_game.py new file mode 100644 index 00000000..066d55fa --- /dev/null +++ b/pax/envs/coin_game.py @@ -0,0 +1,555 @@ +import jax +from dataclasses import replace +import jax.numpy as jnp +import chex +from gymnax.environments import environment, spaces +from typing import Tuple, Optional + + +@chex.dataclass +class EnvState: + red_pos: jnp.ndarray + blue_pos: jnp.ndarray + red_coin_pos: jnp.ndarray + blue_coin_pos: jnp.ndarray + inner_t: int + outer_t: int + # stats + red_coop: jnp.ndarray + red_defect: jnp.ndarray + blue_coop: jnp.ndarray + blue_defect: jnp.ndarray + counter: jnp.ndarray # 9 + coop1: jnp.ndarray # 9 + coop2: jnp.ndarray # 9 + last_state: jnp.ndarray # 2 + + +@chex.dataclass +class EnvParams: + payoff_matrix: chex.ArrayDevice + + +STATES = jnp.array( + [ + [0], # SS + [1], # CC + [2], # CD + [3], # DC + [4], # DD + [5], # SC + [6], # SD + [7], # CS + [8], # DS + ] +) +MOVES = jnp.array( + [ + [0, 1], # right + [0, -1], # left + [1, 0], # up + [-1, 0], # down + [0, 0], # stay + ] +) + + +class CoinGame(environment.Environment): + """ + JAX Compatible version of matrix game environment. + """ + + def __init__( + self, + num_inner_steps: int, + num_outer_steps: int, + cnn: bool, + egocentric: bool, + ): + + super().__init__() + + # helper functions + def _update_stats( + state: EnvState, + rr: jnp.ndarray, + rb: jnp.ndarray, + br: jnp.ndarray, + bb: jnp.ndarray, + ): + def state2idx(s: jnp.ndarray) -> int: + idx = 0 + idx = jnp.where((s == jnp.array([1, 1])).all(), 1, idx) + idx = jnp.where((s == jnp.array([1, 2])).all(), 2, idx) + idx = jnp.where((s == jnp.array([2, 1])).all(), 3, idx) + idx = jnp.where((s == jnp.array([2, 2])).all(), 4, idx) + idx = jnp.where((s == jnp.array([0, 1])).all(), 5, idx) + idx = jnp.where((s == jnp.array([0, 2])).all(), 6, idx) + idx = jnp.where((s == jnp.array([2, 0])).all(), 7, idx) + idx = jnp.where((s == jnp.array([1, 0])).all(), 8, idx) + return idx + + # actions are X, C, D + a1 = 0 + a1 = jnp.where(rr, 1, a1) + a1 = jnp.where(rb, 2, a1) + + a2 = 0 + a2 = jnp.where(bb, 1, a2) + a2 = jnp.where(br, 2, a2) + + # if we didn't get a coin this turn, use the last convention + convention_1 = jnp.where(a1 > 0, a1, state.last_state[0]) + convention_2 = jnp.where(a2 > 0, a2, state.last_state[1]) + + idx = state2idx(state.last_state) + counter = state.counter + jnp.zeros_like( + state.counter, dtype=jnp.int16 + ).at[idx].set(1) + coop1 = state.coop1 + jnp.zeros_like( + state.counter, dtype=jnp.int16 + ).at[idx].set(rr) + coop2 = state.coop2 + jnp.zeros_like( + state.counter, dtype=jnp.int16 + ).at[idx].set(bb) + convention = jnp.stack([convention_1, convention_2]).reshape(2) + return counter, coop1, coop2, convention + + def _abs_position(state: EnvState) -> jnp.ndarray: + obs1 = jnp.zeros((3, 3, 4), dtype=jnp.int8) + obs2 = jnp.zeros((3, 3, 4), dtype=jnp.int8) + + # obs channels are [red_player, blue_player, red_coin, blue_coin] + obs1 = obs1.at[state.red_pos[0], state.red_pos[1], 0].set(1) + obs1 = obs1.at[state.blue_pos[0], state.blue_pos[1], 1].set(1) + obs1 = obs1.at[ + state.red_coin_pos[0], state.red_coin_pos[1], 2 + ].set(1) + obs1 = obs1.at[ + state.blue_coin_pos[0], state.blue_coin_pos[1], 3 + ].set(1) + + # each agent has egotistic color (so thinks they are red) + obs2 = jnp.stack( + [obs1[:, :, 1], obs1[:, :, 0], obs1[:, :, 3], obs1[:, :, 2]], + axis=-1, + ) + return obs1, obs2 + + def _relative_position(state: EnvState) -> jnp.ndarray: + """Assume canonical agent is red player""" + # (x) redplayer at (2, 2) + # (y) redcoin at (0 ,0) + # + # o o x o o y + # o o o -> o x o + # y o o o o o + # + # redplayer goes to (1, 1) + # redcoing goes to (2, 2) + # offset = (-1, -1) + # new_redcoin = (0, 0) + (-1, -1) = (-1, -1) mod3 + # new_redcoin = (2, 2) + + agent_loc = jnp.array([state.red_pos[0], state.red_pos[1]]) + ego_offset = jnp.ones(2, dtype=jnp.int8) - agent_loc + + rel_other_player = (state.blue_pos + ego_offset) % 3 + rel_red_coin = (state.red_coin_pos + ego_offset) % 3 + rel_blue_coin = (state.blue_coin_pos + ego_offset) % 3 + + # create observation + obs = jnp.zeros((3, 3, 4), dtype=jnp.int8) + obs = obs.at[1, 1, 0].set(1) + obs = obs.at[rel_other_player[0], rel_other_player[1], 1].set(1) + obs = obs.at[rel_red_coin[0], rel_red_coin[1], 2].set(1) + obs = obs.at[rel_blue_coin[0], rel_blue_coin[1], 3].set(1) + return obs + + def _state_to_obs(state: EnvState) -> jnp.ndarray: + if egocentric: + print("Running Egocentric") + obs1 = _relative_position(state) + + # flip red and blue coins for second agent + obs2 = _relative_position( + EnvState( + red_pos=state.blue_pos, + blue_pos=state.red_pos, + red_coin_pos=state.blue_coin_pos, + blue_coin_pos=state.red_coin_pos, + inner_t=0, + outer_t=0, + red_coop=state.blue_coop, + red_defect=state.blue_defect, + blue_coop=state.red_coop, + blue_defect=state.red_defect, + last_state=state.last_state, + counter=state.counter, + coop1=state.coop1, + coop2=state.coop2, + ) + ) + else: + obs1, obs2 = _abs_position(state) + + if not cnn: + return obs1.flatten(), obs2.flatten() + return obs1, obs2 + + def _step( + key: chex.PRNGKey, + state: EnvState, + actions: Tuple[int, int], + params: EnvParams, + ): + action_0, action_1 = actions + new_red_pos = (state.red_pos + MOVES[action_0]) % 3 + new_blue_pos = (state.blue_pos + MOVES[action_1]) % 3 + red_reward, blue_reward = 0, 0 + + red_red_matches = jnp.all( + new_red_pos == state.red_coin_pos, axis=-1 + ) + red_blue_matches = jnp.all( + new_red_pos == state.blue_coin_pos, axis=-1 + ) + + blue_red_matches = jnp.all( + new_blue_pos == state.red_coin_pos, axis=-1 + ) + blue_blue_matches = jnp.all( + new_blue_pos == state.blue_coin_pos, axis=-1 + ) + + ### [[1, -2],[1, -2] + _rr_reward = params.payoff_matrix[0][0] + _rb_reward = params.payoff_matrix[0][1] + _r_penalty = params.payoff_matrix[0][2] + _br_reward = params.payoff_matrix[1][0] + _bb_reward = params.payoff_matrix[1][1] + _b_penalty = params.payoff_matrix[1][2] + + red_reward = jnp.where( + red_red_matches, red_reward + _rr_reward, red_reward + ) + red_reward = jnp.where( + red_blue_matches, red_reward + _rb_reward, red_reward + ) + red_reward = jnp.where( + blue_red_matches, red_reward - _r_penalty, red_reward + ) + + blue_reward = jnp.where( + blue_red_matches, blue_reward + _br_reward, blue_reward + ) + blue_reward = jnp.where( + blue_blue_matches, blue_reward + _bb_reward, blue_reward + ) + blue_reward = jnp.where( + red_blue_matches, blue_reward - _b_penalty, blue_reward + ) + + (counter, coop1, coop2, last_state) = _update_stats( + state, + red_red_matches, + red_blue_matches, + blue_red_matches, + blue_blue_matches, + ) + + key, subkey = jax.random.split(key) + new_random_coin_poses = jax.random.randint( + subkey, shape=(2, 2), minval=0, maxval=3 + ) + new_red_coin_pos = jnp.where( + jnp.logical_or(red_red_matches, blue_red_matches), + new_random_coin_poses[0], + state.red_coin_pos, + ) + new_blue_coin_pos = jnp.where( + jnp.logical_or(red_blue_matches, blue_blue_matches), + new_random_coin_poses[1], + state.blue_coin_pos, + ) + + next_red_coop = state.red_coop + jnp.zeros( + num_outer_steps, dtype=jnp.int8 + ).at[state.outer_t].set(red_red_matches) + next_red_defect = state.red_defect + jnp.zeros( + num_outer_steps, dtype=jnp.int8 + ).at[state.outer_t].set(red_blue_matches) + next_blue_coop = state.blue_coop + jnp.zeros( + num_outer_steps, dtype=jnp.int8 + ).at[state.outer_t].set(blue_blue_matches) + next_blue_defect = state.blue_defect + jnp.zeros( + num_outer_steps, dtype=jnp.int8 + ).at[state.outer_t].set(blue_red_matches) + + next_state = EnvState( + red_pos=new_red_pos, + blue_pos=new_blue_pos, + red_coin_pos=new_red_coin_pos, + blue_coin_pos=new_blue_coin_pos, + inner_t=state.inner_t + 1, + outer_t=state.outer_t, + red_coop=next_red_coop, + red_defect=next_red_defect, + blue_coop=next_blue_coop, + blue_defect=next_blue_defect, + counter=counter, + coop1=coop1, + coop2=coop2, + last_state=last_state, + ) + + obs1, obs2 = _state_to_obs(next_state) + inner_t = next_state.inner_t + outer_t = next_state.outer_t + done = inner_t % num_inner_steps == 0 + + # if inner episode is done, return start state for next game + reset_obs, reset_state = _reset(key, params) + next_state = EnvState( + red_pos=jnp.where( + done, reset_state.red_pos, next_state.red_pos + ), + blue_pos=jnp.where( + done, reset_state.blue_pos, next_state.blue_pos + ), + red_coin_pos=jnp.where( + done, reset_state.red_coin_pos, next_state.red_coin_pos + ), + blue_coin_pos=jnp.where( + done, reset_state.blue_coin_pos, next_state.blue_coin_pos + ), + inner_t=jnp.where( + done, jnp.zeros_like(inner_t), next_state.inner_t + ), + outer_t=jnp.where(done, outer_t + 1, outer_t), + red_coop=next_state.red_coop, + red_defect=next_state.red_defect, + blue_coop=next_state.blue_coop, + blue_defect=next_state.blue_defect, + counter=counter, + coop1=coop1, + coop2=coop2, + last_state=jnp.where(done, jnp.zeros(2), last_state), + ) + + obs1 = jnp.where(done, reset_obs[0], obs1) + obs2 = jnp.where(done, reset_obs[1], obs2) + + blue_reward = jnp.where(done, 0, blue_reward) + red_reward = jnp.where(done, 0, red_reward) + return ( + (obs1, obs2), + next_state, + (red_reward, blue_reward), + done, + {"discount": jnp.zeros((), dtype=jnp.int8)}, + ) + + def _reset( + key: jnp.ndarray, params: EnvParams + ) -> Tuple[jnp.ndarray, EnvState]: + key, subkey = jax.random.split(key) + all_pos = jax.random.randint( + subkey, shape=(4, 2), minval=0, maxval=3 + ) + + empty_stats = jnp.zeros((num_outer_steps), dtype=jnp.int8) + state_stats = jnp.zeros(9) + + state = EnvState( + red_pos=all_pos[0, :], + blue_pos=all_pos[1, :], + red_coin_pos=all_pos[2, :], + blue_coin_pos=all_pos[3, :], + inner_t=0, + outer_t=0, + red_coop=empty_stats, + red_defect=empty_stats, + blue_coop=empty_stats, + blue_defect=empty_stats, + counter=state_stats, + coop1=state_stats, + coop2=state_stats, + last_state=jnp.zeros(2), + ) + obs1, obs2 = _state_to_obs(state) + return (obs1, obs2), state + + # overwrite Gymnax as it makes single-agent assumptions + self.step = jax.jit(_step) + self.reset = jax.jit(_reset) + self.cnn = cnn + + self.step = _step + self.reset = _reset + + @property + def name(self) -> str: + """Environment name.""" + return "CoinGame-v1" + + @property + def num_actions(self) -> int: + """Number of actions possible in environment.""" + return 5 + + def action_space( + self, params: Optional[EnvParams] = None + ) -> spaces.Discrete: + """Action space of the environment.""" + return spaces.Discrete(5) + + def observation_space(self, params: EnvParams) -> spaces.Box: + """Observation space of the environment.""" + _shape = (3, 3, 4) if self.cnn else (36,) + return spaces.Box(low=0, high=1, shape=_shape, dtype=jnp.uint8) + + def state_space(self, params: EnvParams) -> spaces.Dict: + """State space of the environment.""" + _shape = (3, 3, 4) if self.cnn else (36,) + return spaces.Box(low=0, high=1, shape=_shape, dtype=jnp.uint8) + + def render(self, state: EnvState): + from matplotlib.figure import Figure + from matplotlib.backends.backend_agg import ( + FigureCanvasAgg as FigureCanvas, + ) + from PIL import Image + import numpy as np + + """Small utility for plotting the agent's state.""" + fig = Figure((5, 2)) + canvas = FigureCanvas(fig) + ax = fig.add_subplot(121) + ax.imshow( + np.zeros((3, 3)), + cmap="Greys", + vmin=0, + vmax=1, + aspect="equal", + interpolation="none", + origin="lower", + extent=[0, 3, 0, 3], + ) + ax.set_aspect("equal") + + # ax.margins(0) + ax.set_xticks(jnp.arange(1, 4)) + ax.set_yticks(jnp.arange(1, 4)) + ax.grid() + red_pos = jnp.squeeze(state.red_pos) + blue_pos = jnp.squeeze(state.blue_pos) + red_coin_pos = jnp.squeeze(state.red_coin_pos) + blue_coin_pos = jnp.squeeze(state.blue_coin_pos) + ax.annotate( + "R", + fontsize=20, + color="red", + xy=(red_pos[0], red_pos[1]), + xycoords="data", + xytext=(red_pos[0] + 0.5, red_pos[1] + 0.5), + ) + ax.annotate( + "B", + fontsize=20, + color="blue", + xy=(blue_pos[0], blue_pos[1]), + xycoords="data", + xytext=(blue_pos[0] + 0.5, blue_pos[1] + 0.5), + ) + ax.annotate( + "Rc", + fontsize=20, + color="red", + xy=(red_coin_pos[0], red_coin_pos[1]), + xycoords="data", + xytext=(red_coin_pos[0] + 0.3, red_coin_pos[1] + 0.3), + ) + ax.annotate( + "Bc", + color="blue", + fontsize=20, + xy=(blue_coin_pos[0], blue_coin_pos[1]), + xycoords="data", + xytext=( + blue_coin_pos[0] + 0.3, + blue_coin_pos[1] + 0.3, + ), + ) + + ax2 = fig.add_subplot(122) + ax2.text(0.0, 0.95, "Timestep: %s" % (state.inner_t)) + ax2.text(0.0, 0.75, "Episode: %s" % (state.outer_t)) + ax2.text( + 0.0, 0.45, "Red Coop: %s" % (state.red_coop[state.outer_t].sum()) + ) + ax2.text( + 0.6, + 0.45, + "Red Defects : %s" % (state.red_defect[state.outer_t].sum()), + ) + ax2.text( + 0.0, 0.25, "Blue Coop: %s" % (state.blue_coop[state.outer_t].sum()) + ) + ax2.text( + 0.6, + 0.25, + "Blue Defects : %s" % (state.blue_defect[state.outer_t].sum()), + ) + ax2.text( + 0.0, + 0.05, + "Red Total: %s" + % ( + state.red_defect[state.outer_t].sum() + + state.red_coop[state.outer_t].sum() + ), + ) + ax2.text( + 0.6, + 0.05, + "Blue Total: %s" + % ( + state.blue_defect[state.outer_t].sum() + + state.blue_coop[state.outer_t].sum() + ), + ) + ax2.axis("off") + canvas.draw() + image = Image.frombytes( + "RGB", + fig.canvas.get_width_height(), + fig.canvas.tostring_rgb(), + ) + return image + + +if __name__ == "__main__": + bs = 1 + env = CoinGame(bs, 8, 16, 0, True, False) + action = jnp.ones(bs, dtype=int) + t1, t2 = env.reset() + rng = jax.random.PRNGKey(0) + pics = [] + + for _ in range(16): + rng, rng1, rng2 = jax.random.split(rng, 3) + a1 = jax.random.randint(rng1, (1,), minval=0, maxval=4) + a2 = jax.random.randint(rng2, (1,), minval=0, maxval=4) + t1, t2 = env.step((a1 * action, a2 * action)) + img = env.render(env.state) + pics.append(img) + + pics[0].save( + "test1.gif", + format="gif", + save_all=True, + append_images=pics[1:], + duration=300, + loop=0, + ) diff --git a/test/envs/test_coin_game.py b/test/envs/test_coin_game.py new file mode 100644 index 00000000..4a72f012 --- /dev/null +++ b/test/envs/test_coin_game.py @@ -0,0 +1,360 @@ +import jax.numpy as jnp +import jax + +from pax.envs.coin_game import CoinGame, EnvParams, EnvState + + +def test_coingame_shapes(): + rng = jax.random.PRNGKey(0) + env = CoinGame( + num_inner_steps=8, num_outer_steps=2, cnn=False, egocentric=True + ) + + params = EnvParams(payoff_matrix=[[1, 1, -2], [1, 1, -2]]) + + obs, state = env.reset(rng, params) + action = jnp.ones((), dtype=int) + + assert obs[0].shape == (36,) + assert obs[1].shape == (36,) + + obs, new_state, rewards, done, info = env.step( + rng, state, (action, action), params + ) + assert (rewards[0] == 0).all() + assert (rewards[1] == 0).all() + + assert (state.red_pos != new_state.red_pos).any() + assert (state.blue_pos != new_state.blue_pos).any() + + +def test_coingame_move(): + rng = jax.random.PRNGKey(0) + env = CoinGame(8, 2, True, True) + params = EnvParams(payoff_matrix=[[1, 1, -2], [1, 1, -2]]) + + action = 1 + _, state = env.reset(rng, params) + + state = EnvState( + red_pos=jnp.array([0, 0]), + blue_pos=jnp.array([1, 0]), + red_coin_pos=jnp.array([0, 2]), + blue_coin_pos=jnp.array([1, 2]), + inner_t=state.inner_t, + outer_t=state.outer_t, + red_coop=jnp.zeros(1), + red_defect=jnp.zeros(1), + blue_coop=jnp.zeros(1), + blue_defect=jnp.zeros(1), + counter=state.counter, + coop1=state.coop1, + coop2=state.coop2, + last_state=state.last_state, + ) + + _, state, rewards, _, _ = env.step(rng, state, (action, action), params) + assert rewards[0] == 1 + assert rewards[1] == 1 + assert (state.red_coop == jnp.array([1, 0])).all() + assert (state.red_defect == jnp.array([0, 0])).all() + assert (state.blue_coop == jnp.array([1, 0])).all() + assert (state.blue_defect == jnp.array([0, 0])).all() + assert state.inner_t == 1 + assert state.outer_t == 0 + + obs, state, rewards, done, info = env.step( + rng, state, (action, action), params + ) + assert rewards[0] == 0 + assert rewards[1] == 0 + assert (state.red_coop == jnp.array([1, 0])).all() + assert (state.red_defect == jnp.array([0, 0])).all() + assert (state.blue_coop == jnp.array([1, 0])).all() + assert (state.blue_defect == jnp.array([0, 0])).all() + assert state.inner_t == 2 + assert state.outer_t == 0 + + +def test_coingame_egocentric_colors(): + rng = jax.random.PRNGKey(0) + env = CoinGame(8, 2, True, True) + params = EnvParams(payoff_matrix=[[1, 1, -2], [1, 1, -2]]) + _, state = env.reset(rng, params) + action = 1 + + # importantly place agents on top of each other so that we can just check flips + state = EnvState( + red_pos=jnp.array([0, 0]), + blue_pos=jnp.array([0, 0]), + red_coin_pos=jnp.array([0, 2]), + blue_coin_pos=jnp.array([0, 2]), + inner_t=state.inner_t, + outer_t=state.outer_t, + red_coop=jnp.zeros(1), + red_defect=jnp.zeros(1), + blue_coop=jnp.zeros(1), + blue_defect=jnp.zeros(1), + counter=state.counter, + coop1=state.coop1, + coop2=state.coop2, + last_state=state.last_state, + ) + + for _ in range(7): + obs, state, rewards, done, info = env.step( + rng, state, (action, action), params + ) + obs1, obs2 = obs[0], obs[1] + # remove batch + assert (obs1[:, :, 0] == obs2[:, :, 1]).all() + assert (obs1[:, :, 1] == obs2[:, :, 0]).all() + assert (obs1[:, :, 2] == obs2[:, :, 3]).all() + assert (obs1[:, :, 3] == obs2[:, :, 2]).all() + + # here we reset the environment so expect these to break + obs, state, rewards, done, info = env.step( + rng, state, (action, action), params + ) + obs1, obs2 = obs[0], obs[1] + assert (obs1[:, :, 0] != obs2[:, :, 1]).any() + assert (obs1[:, :, 1] != obs2[:, :, 0]).any() + assert (obs1[:, :, 2] != obs2[:, :, 3]).any() + assert (obs1[:, :, 3] != obs2[:, :, 2]).any() + + +def test_coingame_egocentric_pos(): + rng = jax.random.PRNGKey(0) + env = CoinGame(8, 2, True, True) + params = EnvParams(payoff_matrix=[[1, 1, -2], [1, 1, -2]]) + _, state = env.reset(rng, params) + action = 1 + + state = EnvState( + red_pos=jnp.array([1, 2]), + blue_pos=jnp.array([2, 2]), + red_coin_pos=jnp.array([2, 2]), + blue_coin_pos=jnp.array([0, 0]), + inner_t=state.inner_t, + outer_t=state.outer_t, + red_coop=jnp.zeros(1), + red_defect=jnp.zeros(1), + blue_coop=jnp.zeros(1), + blue_defect=jnp.zeros(1), + counter=state.counter, + coop1=state.coop1, + coop2=state.coop2, + last_state=state.last_state, + ) + # it would be nice to have a stay action here lol + obs, state, rewards, done, info = env.step( + rng, state, (action, action), params + ) + + # takes left so red_pos = [1, 1], blue_pos = [2, 1] + obs1, obs2 = obs[0], obs[1] + + expected_obs1 = jnp.array( + [ + [[0, 0, 0], [0, 1, 0], [0, 0, 0]], # agent + [[0, 0, 0], [0, 0, 0], [0, 1, 0]], # other agent + [[0, 0, 0], [0, 0, 0], [0, 0, 1]], # agent coin + [[1, 0, 0], [0, 0, 0], [0, 0, 0]], # other coin + ], + dtype=jnp.int8, + ) + # channel last + expected_obs1 = jnp.transpose(expected_obs1, (1, 2, 0)) + assert (expected_obs1 == obs1).all() + + expected_obs2 = jnp.array( + [ + [[0, 0, 0], [0, 1, 0], [0, 0, 0]], # agent + [[0, 1, 0], [0, 0, 0], [0, 0, 0]], # other agent + [[0, 0, 0], [0, 0, 0], [1, 0, 0]], # agent coin + [[0, 0, 0], [0, 0, 1], [0, 0, 0]], # other coin + ], + dtype=jnp.int8, + ) + expected_obs2 = jnp.transpose(expected_obs2, (1, 2, 0)) + assert (expected_obs2 == obs2).all() + + +def test_coingame_stay(): + rng = jax.random.PRNGKey(0) + env = CoinGame(8, 2, True, True) + params = EnvParams(payoff_matrix=[[1, 1, -2], [1, 1, -2]]) + + action = 4 + _, state = env.reset(rng, params) + + _state = EnvState( + red_pos=jnp.array([0, 0]), + blue_pos=jnp.array([1, 0]), + red_coin_pos=jnp.array([0, 2]), + blue_coin_pos=jnp.array([1, 2]), + inner_t=state.inner_t, + outer_t=state.outer_t, + red_coop=jnp.zeros(1), + red_defect=jnp.zeros(1), + blue_coop=jnp.zeros(1), + blue_defect=jnp.zeros(1), + counter=state.counter, + coop1=state.coop1, + coop2=state.coop2, + last_state=state.last_state, + ) + + obs, state, rewards, done, info = env.step( + rng, _state, (action, action), params + ) + assert (state.red_pos == _state.red_pos).all() + assert (state.blue_pos == _state.blue_pos).all() + assert state.inner_t != _state.inner_t + + +def test_coingame_batch_egocentric(): + bs = 1 + action = jnp.ones(bs, dtype=int) + rngs = jnp.concatenate(bs * [jax.random.PRNGKey(0)]).reshape((bs, -1)) + env = CoinGame(8, 2, True, True) + params = EnvParams(payoff_matrix=[[1, 1, -2], [1, 1, -2]]) + + env.reset = jax.vmap(env.reset, in_axes=(0, None)) + env.step = jax.vmap( + env.step, + in_axes=(0, 0, 0, None), + ) + + obs, state = env.reset(rngs, params) + state = EnvState( + red_pos=jnp.array([[0, 0]]), + blue_pos=jnp.array([[0, 0]]), + red_coin_pos=jnp.array([[0, 2]]), + blue_coin_pos=jnp.array([[0, 2]]), + inner_t=state.inner_t, + outer_t=state.outer_t, + red_coop=jnp.zeros(1), + red_defect=jnp.zeros(1), + blue_coop=jnp.zeros(1), + blue_defect=jnp.zeros(1), + counter=state.counter, + coop1=state.coop1, + coop2=state.coop2, + last_state=state.last_state, + ) + obs1, obs2 = obs[0], obs[1] + + for _ in range(7): + obs, state, rewards, done, info = env.step( + rngs, state, (5 * action, 5 * action), params + ) + obs1, obs2 = obs[0][0], obs[1][0] + # remove batch + assert (obs1[:, :, 0] == obs2[:, :, 1]).all() + assert (obs1[:, :, 1] == obs2[:, :, 0]).all() + assert (obs1[:, :, 2] == obs2[:, :, 3]).all() + assert (obs1[:, :, 3] == obs2[:, :, 2]).all() + + # check negative case + env.state = EnvState( + red_pos=jnp.array([[0, 0]]), + blue_pos=jnp.array([[2, 2]]), + red_coin_pos=jnp.array([[0, 1]]), + blue_coin_pos=jnp.array([[0, 1]]), + inner_t=0 * state.inner_t, + outer_t=0 * state.outer_t, + red_coop=jnp.zeros(1), + red_defect=jnp.zeros(1), + blue_coop=jnp.zeros(1), + blue_defect=jnp.zeros(1), + counter=state.counter, + coop1=state.coop1, + coop2=state.coop2, + last_state=state.last_state, + ) + + obs, state, rewards, done, info = env.step( + rngs, state, (5 * action, 5 * action), params + ) + obs1, obs2 = obs[0][0], obs[1][0] + + assert (obs1[:, :, 0] != obs2[:, :, 1]).any() + assert (obs1[:, :, 1] != obs2[:, :, 0]).any() + assert (obs1[:, :, 2] != obs2[:, :, 3]).any() + assert (obs1[:, :, 3] != obs2[:, :, 2]).any() + + +def test_coingame_non_egocentric_batched(): + bs = 1 + action = jnp.ones(bs, dtype=int) + rngs = jnp.concatenate(bs * [jax.random.PRNGKey(0)]).reshape((bs, -1)) + env = CoinGame(8, 2, True, False) + params = EnvParams(payoff_matrix=[[1, 1, -2], [1, 1, -2]]) + + env.reset = jax.vmap(env.reset, in_axes=(0, None)) + env.step = jax.vmap( + env.step, + in_axes=(0, 0, 0, None), + ) + + _, state = env.reset(rngs, params) + state = EnvState( + red_pos=jnp.array([[0, 0]]), + blue_pos=jnp.array([[2, 2]]), + red_coin_pos=jnp.array([[1, 1]]), + blue_coin_pos=jnp.array([[1, 1]]), + inner_t=state.inner_t, + outer_t=state.outer_t, + red_coop=jnp.zeros(1), + red_defect=jnp.zeros(1), + blue_coop=jnp.zeros(1), + blue_defect=jnp.zeros(1), + counter=state.counter, + coop1=state.coop1, + coop2=state.coop2, + last_state=state.last_state, + ) + + obs, state, rewards, done, info = env.step( + rngs, state, (5 * action, 5 * action), params + ) + obs1, obs2 = obs[0][0], obs[1][0] + + expected_obs1 = jnp.array( + [ + [[1, 0, 0], [0, 0, 0], [0, 0, 0]], # agent + [[0, 0, 0], [0, 0, 0], [0, 0, 1]], # other agent + [[0, 0, 0], [0, 1, 0], [0, 0, 0]], # agent coin + [[0, 0, 0], [0, 1, 0], [0, 0, 0]], # other coin + ], + dtype=jnp.int8, + ) + # channel last + expected_obs1 = jnp.transpose(expected_obs1, (1, 2, 0)) + + assert (obs1[:, :, 0] == expected_obs1[:, :, 0]).all() + assert (expected_obs1 == obs1).all() + + expected_obs2 = jnp.array( + [ + [[0, 0, 0], [0, 0, 0], [0, 0, 1]], # agent + [[1, 0, 0], [0, 0, 0], [0, 0, 0]], # other agent + [[0, 0, 0], [0, 1, 0], [0, 0, 0]], # agent coin + [[0, 0, 0], [0, 1, 0], [0, 0, 0]], # other coin + ], + dtype=jnp.int8, + ) + expected_obs2 = jnp.transpose(expected_obs2, (1, 2, 0)) + assert (expected_obs2 == obs2).all() + + for _ in range(6): + obs, state, rewards, done, info = env.step( + rngs, state, (action, action), params + ) + obs1, obs2 = obs[0][0], obs[1][0] + # remove batch + assert (obs1[:, :, 0] == obs2[:, :, 1]).all() + assert (obs1[:, :, 1] == obs2[:, :, 0]).all() + assert (obs1[:, :, 2] == obs2[:, :, 3]).all() + assert (obs1[:, :, 3] == obs2[:, :, 2]).all() From c087d8934ac5ee2428ef8752592abf523173c469 Mon Sep 17 00:00:00 2001 From: akbir Date: Fri, 28 Oct 2022 17:09:55 +0100 Subject: [PATCH 16/36] moved some arguments back to kwargs --- pax/envs/infinite_matrix_game.py | 5 ++-- pax/envs/iterated_matrix_game.py | 8 ++---- test/envs/test_infinite_matrix_game.py | 11 +++----- test/envs/test_iterated_matrix_game.py | 39 ++++++++------------------ 4 files changed, 21 insertions(+), 42 deletions(-) diff --git a/pax/envs/infinite_matrix_game.py b/pax/envs/infinite_matrix_game.py index 1ea1e0fb..12412f11 100644 --- a/pax/envs/infinite_matrix_game.py +++ b/pax/envs/infinite_matrix_game.py @@ -15,12 +15,11 @@ class EnvState: @chex.dataclass class EnvParams: payoff_matrix: jnp.ndarray - num_steps: int gamma: float class InfiniteMatrixGame(environment.Environment): - def __init__(self): + def __init__(self, num_steps: int): super().__init__() def _step( @@ -70,7 +69,7 @@ def _step( L_2 = jnp.matmul(M, jnp.reshape(payout_mat_2, (4, 1))) r1 = (1 - params.gamma) * L_1.sum() r2 = (1 - params.gamma) * L_2.sum() - done = t >= params.num_steps + done = t >= num_steps state = EnvState( inner_t=state.inner_t + 1, outer_t=state.outer_t + 1 ) diff --git a/pax/envs/iterated_matrix_game.py b/pax/envs/iterated_matrix_game.py index bd5dd9a6..a136ca46 100644 --- a/pax/envs/iterated_matrix_game.py +++ b/pax/envs/iterated_matrix_game.py @@ -15,8 +15,6 @@ class EnvState: @chex.dataclass class EnvParams: payoff_matrix: jnp.ndarray - num_inner_steps: int - num_outer_steps: int class IteratedMatrixGame(environment.Environment): @@ -24,7 +22,7 @@ class IteratedMatrixGame(environment.Environment): JAX Compatible version of matrix game environment. """ - def __init__(self): + def __init__(self, num_inner_steps: int): super().__init__() def _step( @@ -63,14 +61,14 @@ def _step( + 3 * (a1) * (a2) ) # if first step then return START state. - done = inner_t % params.num_inner_steps == 0 + done = inner_t % num_inner_steps == 0 s1 = jax.lax.select(done, jnp.int8(4), jnp.int8(s1)) s2 = jax.lax.select(done, jnp.int8(4), jnp.int8(s2)) obs1 = jax.nn.one_hot(s1, 5, dtype=jnp.int8) obs2 = jax.nn.one_hot(s2, 5, dtype=jnp.int8) # out step keeping - reset_inner = inner_t == params.num_inner_steps + reset_inner = inner_t == num_inner_steps inner_t = jax.lax.select( reset_inner, jnp.zeros_like(inner_t), inner_t ) diff --git a/test/envs/test_infinite_matrix_game.py b/test/envs/test_infinite_matrix_game.py index c82bdffe..1ca8316f 100644 --- a/test/envs/test_infinite_matrix_game.py +++ b/test/envs/test_infinite_matrix_game.py @@ -1,11 +1,8 @@ import jax.numpy as jnp import jax -import pytest from pax.envs.infinite_matrix_game import InfiniteMatrixGame, EnvParams -from pax.strategies import TitForTat - def test_single_infinite_game(): rng = jax.random.PRNGKey(0) @@ -15,8 +12,8 @@ def test_single_infinite_game(): tft_policy = 20 * jnp.array([1.0, -1.0, 1.0, -1.0, 1.0]) # discount of 0.99 -> 1/(0.001) ~ 100 timestep - env = InfiniteMatrixGame() - env_params = EnvParams(payoff_matrix=payoff, num_steps=10, gamma=0.99) + env = InfiniteMatrixGame(num_steps=10) + env_params = EnvParams(payoff_matrix=payoff, gamma=0.99) obs, env_state = env.reset(rng, env_params) obs, env_state, rewards, done, info = env.step( @@ -86,8 +83,8 @@ def test_batch_infinite_game(): def_policy = -20 * jnp.ones((1, 5)) tft_policy = 20 * jnp.array([[1, -1, 1, -1, 1]]) - env = InfiniteMatrixGame() - env_params = EnvParams(payoff_matrix=payoff, num_steps=10, gamma=0.99) + env = InfiniteMatrixGame(num_steps=10) + env_params = EnvParams(payoff_matrix=payoff, gamma=0.99) env.reset = jax.vmap(env.reset, in_axes=(0, None), out_axes=(0, None)) env.step = jax.vmap( diff --git a/test/envs/test_iterated_matrix_game.py b/test/envs/test_iterated_matrix_game.py index cd373612..b393e7cc 100644 --- a/test/envs/test_iterated_matrix_game.py +++ b/test/envs/test_iterated_matrix_game.py @@ -18,10 +18,8 @@ def test_single_batch_rewards(payoff) -> None: num_envs = 5 rng = jax.random.PRNGKey(0) - env = IteratedMatrixGame() - env_params = EnvParams( - payoff_matrix=payoff, num_inner_steps=5, num_outer_steps=1 - ) + env = IteratedMatrixGame(num_inner_steps=5) + env_params = EnvParams(payoff_matrix=payoff) action = jnp.ones((num_envs,), dtype=jnp.float32) r_array = jnp.ones((num_envs,), dtype=jnp.float32) @@ -94,10 +92,8 @@ def test_batch_outcomes(actions, expected_rewards, payoff) -> None: a1, a2 = actions expected_r1, expected_r2 = expected_rewards - env = IteratedMatrixGame() - env_params = EnvParams( - payoff_matrix=payoff, num_inner_steps=5, num_outer_steps=1 - ) + env = IteratedMatrixGame(num_inner_steps=5) + env_params = EnvParams(payoff_matrix=payoff) # we want to batch over envs purely by actions env.step = jax.vmap( env.step, in_axes=(None, None, 0, None), out_axes=(0, None, 0, 0, 0) @@ -120,10 +116,8 @@ def test_batch_by_rngs() -> None: rng = jnp.concatenate( [jax.random.PRNGKey(0), jax.random.PRNGKey(0)] ).reshape(num_envs, -1) - env = IteratedMatrixGame() - env_params = EnvParams( - payoff_matrix=payoff, num_inner_steps=5, num_outer_steps=1 - ) + env = IteratedMatrixGame(num_inner_steps=5) + env_params = EnvParams(payoff_matrix=payoff) action = jnp.ones((num_envs,), dtype=jnp.float32) r_array = jnp.ones((num_envs,), dtype=jnp.float32) @@ -172,10 +166,8 @@ def test_tit_for_tat_match() -> None: rngs = jnp.concatenate(num_envs * [jax.random.PRNGKey(0)]).reshape( num_envs, -1 ) - env = IteratedMatrixGame() - env_params = EnvParams( - payoff_matrix=payoff, num_inner_steps=5, num_outer_steps=2 - ) + env = IteratedMatrixGame(num_inner_steps=5) + env_params = EnvParams(payoff_matrix=payoff) env.reset = jax.vmap(env.reset, in_axes=(0, None), out_axes=(0, None)) env.step = jax.vmap( @@ -201,7 +193,7 @@ def test_longer_game() -> None: payoff = [[2, 2], [0, 3], [3, 0], [1, 1]] num_outer_steps = 25 num_inner_steps = 2 - env = IteratedMatrixGame() + env = IteratedMatrixGame(num_inner_steps=num_inner_steps) # batch over actions and env_states env.reset = jax.vmap(env.reset, in_axes=(0, None), out_axes=(0, None)) @@ -211,8 +203,6 @@ def test_longer_game() -> None: env_params = EnvParams( payoff_matrix=payoff, - num_inner_steps=num_inner_steps, - num_outer_steps=num_outer_steps, ) rngs = jnp.concatenate(num_envs * [jax.random.PRNGKey(0)]).reshape( @@ -241,19 +231,16 @@ def test_longer_game() -> None: def test_done(): payoff = [[2, 2], [0, 3], [3, 0], [1, 1]] - num_outer_steps = 25 num_inner_steps = 5 - env = IteratedMatrixGame() + env = IteratedMatrixGame(num_inner_steps=num_inner_steps) env_params = EnvParams( payoff_matrix=payoff, - num_inner_steps=num_inner_steps, - num_outer_steps=num_outer_steps, ) rng = jax.random.PRNGKey(0) obs, env_state = env.reset(rng, env_params) action = 0 - for _ in range(4): + for _ in range(num_inner_steps - 1): obs, env_state, rewards, done, info = env.step( rng, env_state, (action, action), env_params ) @@ -273,11 +260,9 @@ def test_done(): def test_reset(): payoff = [[2, 2], [0, 3], [3, 0], [1, 1]] rng = jax.random.PRNGKey(0) - env = IteratedMatrixGame() + env = IteratedMatrixGame(num_inner_steps=5) env_params = EnvParams( payoff_matrix=payoff, - num_inner_steps=5, - num_outer_steps=2, ) action = 0 From 6808c8c832014cb76261b1b8d7b0da25dafb7947 Mon Sep 17 00:00:00 2001 From: akbir Date: Fri, 28 Oct 2022 17:11:19 +0100 Subject: [PATCH 17/36] revert existing envs --- pax/env.py | 55 ++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 55 insertions(+) diff --git a/pax/env.py b/pax/env.py index 2ef800d2..01972542 100644 --- a/pax/env.py +++ b/pax/env.py @@ -111,6 +111,29 @@ def runner_reset(ndims, rng): self.batch_step = jax.jit(jax.vmap(jax.vmap(_step))) self.runner_reset = runner_reset + def step(self, actions): + if self._reset_next_step: + return self.reset() + + output, self.state = self.runner_step(actions, self.state) + if (self.state.outer_t == self.outer_ep_length).all(): + self._reset_next_step = True + output = ( + TimeStep( + 2 * jnp.ones(self.num_envs, dtype=jnp.int8), + output[0].reward, + output[0].discount, + output[0].observation, + ), + TimeStep( + 2 * jnp.ones(self.num_envs, dtype=jnp.int8), + output[1].reward, + output[1].discount, + output[1].observation, + ), + ) + return output + def observation_spec(self) -> specs.DiscreteArray: """Returns the observation spec.""" return specs.DiscreteArray(num_values=5, name="previous turn") @@ -198,6 +221,34 @@ def _step(actions, state): jax.vmap(self.runner_step, (0, None), (0, None)) ) + def step( + self, actions: Tuple[jnp.ndarray, jnp.ndarray] + ) -> Tuple[TimeStep, TimeStep]: + """ + takes a tuple of batched policies and produce value functions from infinite game + policy of form [B, 5] + """ + if self._reset_next_step: + return self.reset() + + action_1, action_2 = actions + self._num_steps += 1 + assert action_1.shape == action_2.shape + assert action_1.shape == (self.num_envs, 5) + + outputs, self.state = self._jit_step(actions, self.state) + r1, r2, obs1, obs2, _ = outputs + r1, r2 = (1 - self.gamma) * r1, (1 - self.gamma) * r2 + + if self._num_steps == self.episode_length: + self._reset_next_step = True + return termination(reward=r1, observation=obs1), termination( + reward=r2, observation=obs2 + ) + return transition(reward=r1, observation=obs1), transition( + reward=r2, observation=obs2 + ) + def runner_step( self, actions: Tuple[jnp.ndarray, jnp.ndarray], @@ -677,6 +728,10 @@ def reset(self) -> Tuple[TimeStep, TimeStep]: output, self.state = self._reset(self.key) return output + def step(self, actions: Tuple[int, int]) -> Tuple[TimeStep, TimeStep]: + output, self.state = self.runner_step(actions, self.state) + return output + def observation_spec(self) -> specs.BoundedArray: """Returns the observation spec.""" if self.cnn: From c1d5ed232e70c078d0ce208bf15f556149850eb2 Mon Sep 17 00:00:00 2001 From: akbir Date: Fri, 28 Oct 2022 17:19:42 +0100 Subject: [PATCH 18/36] fixed render --- pax/envs/coin_game.py | 19 +++++++++++-------- 1 file changed, 11 insertions(+), 8 deletions(-) diff --git a/pax/envs/coin_game.py b/pax/envs/coin_game.py index 066d55fa..32c9d47e 100644 --- a/pax/envs/coin_game.py +++ b/pax/envs/coin_game.py @@ -530,19 +530,22 @@ def render(self, state: EnvState): if __name__ == "__main__": - bs = 1 - env = CoinGame(bs, 8, 16, 0, True, False) - action = jnp.ones(bs, dtype=int) - t1, t2 = env.reset() + action = 1 rng = jax.random.PRNGKey(0) + env = CoinGame(8, 16, True, False) + + params = EnvParams(payoff_matrix=[[1, 1, -2], [1, 1, -2]]) + obs, state = env.reset(rng, params) pics = [] for _ in range(16): rng, rng1, rng2 = jax.random.split(rng, 3) - a1 = jax.random.randint(rng1, (1,), minval=0, maxval=4) - a2 = jax.random.randint(rng2, (1,), minval=0, maxval=4) - t1, t2 = env.step((a1 * action, a2 * action)) - img = env.render(env.state) + a1 = jax.random.randint(rng1, (), minval=0, maxval=4) + a2 = jax.random.randint(rng2, (), minval=0, maxval=4) + obs, state, reward, done, info = env.step( + rng, state, (a1 * action, a2 * action), params + ) + img = env.render(state) pics.append(img) pics[0].save( From 01e3a132709a7398ed0f9c6027e2b30378bf2bc9 Mon Sep 17 00:00:00 2001 From: akbir Date: Sun, 30 Oct 2022 18:11:03 +0000 Subject: [PATCH 19/36] WIP: first push --- pax/envs/infinite_matrix_game.py | 2 +- pax/envs/iterated_matrix_game.py | 2 +- pax/experiment.py | 129 +++++++++++-------------------- pax/runner_rl.py | 57 ++++++++++---- 4 files changed, 87 insertions(+), 103 deletions(-) diff --git a/pax/envs/infinite_matrix_game.py b/pax/envs/infinite_matrix_game.py index 12412f11..4c8e05f7 100644 --- a/pax/envs/infinite_matrix_game.py +++ b/pax/envs/infinite_matrix_game.py @@ -14,7 +14,7 @@ class EnvState: @chex.dataclass class EnvParams: - payoff_matrix: jnp.ndarray + payoff_matrix: chex.ArrayDevice gamma: float diff --git a/pax/envs/iterated_matrix_game.py b/pax/envs/iterated_matrix_game.py index a136ca46..24a55f12 100644 --- a/pax/envs/iterated_matrix_game.py +++ b/pax/envs/iterated_matrix_game.py @@ -14,7 +14,7 @@ class EnvState: @chex.dataclass class EnvParams: - payoff_matrix: jnp.ndarray + payoff_matrix: chex.ArrayDevice class IteratedMatrixGame(environment.Environment): diff --git a/pax/experiment.py b/pax/experiment.py index b05036cc..17533a04 100644 --- a/pax/experiment.py +++ b/pax/experiment.py @@ -14,7 +14,17 @@ import wandb -from pax.env import CoinGame, IteratedMatrixGame, InfiniteMatrixGame +from pax.envs.coin_game import CoinGame, EnvParams as CoinGameParams +from pax.envs.iterated_matrix_game import ( + IteratedMatrixGame, + EnvParams as IteratedMatrixGameParams, +) +from pax.envs.infinite_matrix_game import ( + InfiniteMatrixGame, + EnvParams as InfiniteMatrixGameParams, +) +import jax.numpy as jnp + from pax.hyper.ppo import make_hyper from pax.learners import IndependentLearners, EvolutionaryLearners from pax.naive.naive import make_naive_pg @@ -24,7 +34,7 @@ from pax.mfos_ppo.ppo_gru import make_gru_agent as make_mfos_agent from pax.runner_eval import EvalRunner from pax.runner_evo import EvoRunner -from pax.runner_rl import Runner +from pax.runner_rl import RLRunner from pax.strategies import ( Altruistic, Defect, @@ -84,52 +94,23 @@ def env_setup(args, logger=None): """Set up env variables.""" if args.env_id == "ipd": + payoff = jnp.array(args.payoff) if args.env_type == "sequential": - train_env = IteratedMatrixGame( - args.num_envs, - args.payoff, - inner_ep_length=args.num_steps, - num_steps=args.num_steps, - ) - test_env = IteratedMatrixGame( - 1, - args.payoff, - inner_ep_length=args.num_steps, - num_steps=args.num_steps, - ) + env = IteratedMatrixGame(num_inner_steps=args.num_steps) + env_params = IteratedMatrixGameParams(payoff_matrix=payoff) elif args.env_type == "meta": - train_env = IteratedMatrixGame( - args.num_envs, - args.payoff, - inner_ep_length=args.num_inner_steps, - num_steps=args.num_steps, - ) - test_env = IteratedMatrixGame( - 1, - args.payoff, - inner_ep_length=args.num_inner_steps, - num_steps=args.num_steps, - ) + env = IteratedMatrixGame(num_inner_steps=args.num_inner_steps) + env_params = IteratedMatrixGameParams(payoff_matrix=payoff) if logger: logger.info( f"Env Type: Meta | Episode Length: {args.num_steps}" ) elif args.env_type == "infinite": - train_env = InfiniteMatrixGame( - args.num_envs, - args.payoff, - args.num_steps, - args.env_discount, - args.seed, - ) - test_env = InfiniteMatrixGame( - args.num_envs, - args.payoff, - args.num_steps, - args.env_discount, - args.seed + 1, + env = InfiniteMatrixGame(num_steps=args.num_steps) + env_params = InfiniteMatrixGameParams( + payoff_matrix=payoff, gamma=args.gamma ) if logger: logger.info( @@ -140,48 +121,30 @@ def env_setup(args, logger=None): elif args.env_id == "coin_game": if args.env_type == "sequential": - train_env = CoinGame( - args.num_envs, - inner_ep_length=args.num_steps, - num_steps=args.num_steps, - seed=args.seed, - cnn=args.ppo.with_cnn, + env = CoinGame( + num_inner_steps=args.num_steps, + num_outer_steps=args.num_steps, + cnn=args.cnn, egocentric=args.egocentric, ) - test_env = CoinGame( - 1, - inner_ep_length=args.num_steps, - num_steps=args.num_steps, - seed=args.seed, - cnn=args.ppo.with_cnn, - egocentric=args.egocentric, - ) - + env_params = CoinGameParams(args.payoff_matrix) else: - train_env = CoinGame( - args.num_envs, - inner_ep_length=args.num_inner_steps, - num_steps=args.num_steps, - seed=args.seed, - cnn=args.ppo.with_cnn, - egocentric=args.egocentric, - ) - test_env = CoinGame( - 1, - inner_ep_length=args.num_inner_steps, - num_steps=args.num_steps, - seed=args.seed, - cnn=args.ppo.with_cnn, + env = CoinGame( + num_inner_steps=args.num_inner_steps, + num_outer_steps=args.num_steps, + cnn=args.cnn, egocentric=args.egocentric, ) + env_params = CoinGameParams(args.payoff_matrix) + if logger: logger.info( f"Env Type: CoinGame | Episode Length: {args.num_steps}" ) - return train_env, test_env + return env, env_params -def runner_setup(args, agents, save_dir, logger): +def runner_setup(args, env, agents, save_dir, logger): if args.runner == "eval": logger.info("Evaluating with EvalRunner") return EvalRunner(args) @@ -272,21 +235,21 @@ def get_pgpe_strategy(agent): elif args.runner == "rl": logger.info("Training with RL Runner") - return Runner(args, save_dir) + return RLRunner(env, save_dir, args) else: raise ValueError(f"Unknown runner type {args.runner}") # flake8: noqa: C901 -def agent_setup(args, obs_spec, action_spec, logger): +def agent_setup(args, env, env_params, logger): """Set up agent variables.""" if args.env_id == "coin_game": - obs_shape = obs_spec.shape + obs_shape = env.observation_space(env_params).shape elif args.env_id == "ipd": - obs_shape = (obs_spec.num_values,) + obs_shape = (env.observation_space(env_params).n,) - num_actions = action_spec.num_values + num_actions = env.num_actions def get_PPO_memory_agent(seed, player_id): ppo_memory_agent = make_gru_agent( @@ -529,18 +492,16 @@ def main(args): save_dir = global_setup(args) with Section("Env setup", logger=logger): - train_env, test_env = env_setup(args, logger) + env, env_params = env_setup(args, logger) with Section("Agent setup", logger=logger): - agent_pair = agent_setup( - args, train_env.observation_spec(), train_env.action_spec(), logger - ) + agent_pair = agent_setup(args, env, env_params, logger) with Section("Watcher setup", logger=logger): watchers = watcher_setup(args, logger) with Section("Runner setup", logger=logger): - runner = runner_setup(args, agent_pair, save_dir, logger) + runner = runner_setup(args, env, agent_pair, save_dir, logger) if not args.wandb.log: watchers = False @@ -548,28 +509,28 @@ def main(args): if args.runner == "evo": num_iters = args.num_generations # number of generations print(f"Number of Generations: {num_iters}") - runner.run_loop(train_env, agent_pair, num_iters, watchers) + runner.run_loop(env, agent_pair, num_iters, watchers) elif args.runner == "rl": num_iters = int( args.total_timesteps / args.num_steps ) # number of episodes print(f"Number of Episodes: {num_iters}") - runner.run_loop(train_env, agent_pair, num_iters, watchers) + runner.run_loop(env, env_params, agent_pair, num_iters, watchers) elif args.runner == "eval": num_iters = int( args.total_timesteps / args.num_steps ) # number of episodes print(f"Number of Episodes: {num_iters}") - runner.run_loop(train_env, agent_pair, num_iters, watchers) + runner.run_loop(env, agent_pair, num_iters, watchers) elif args.runner == "eval": num_iters = int( args.total_timesteps / args.num_steps ) # number of episodes print(f"Number of Episodes: {num_iters}") - runner.run_loop(train_env, agent_pair, num_iters, watchers) + runner.run_loop(env, agent_pair, num_iters, watchers) if __name__ == "__main__": diff --git a/pax/runner_rl.py b/pax/runner_rl.py index d684ebb0..cae4bc02 100644 --- a/pax/runner_rl.py +++ b/pax/runner_rl.py @@ -50,10 +50,10 @@ def reduce_outer_traj(traj: Sample) -> Sample: ) -class Runner: +class RLRunner: """Holds the runner's state.""" - def __init__(self, args, save_dir): + def __init__(self, env, save_dir, args): self.train_steps = 0 self.eval_steps = 0 self.train_episodes = 0 @@ -76,11 +76,27 @@ def _reshape_opp_dim(x): self.ipd_stats = jax.jit(ipd_visitation) self.cg_stats = jax.jit(cg_visitation) - def run_loop(self, env, agents, num_episodes, watchers): + # we vmap over the rng but not params + env.reset = jax.vmap(env.reset, (0, None), 0) + env.step = jax.vmap( + env.step, (0, 0, 0, None), 0 # rng, state, actions, params + ) + + def run_loop(self, env, env_params, agents, num_episodes, watchers): def _inner_rollout(carry, unused): """Runner for inner episode""" - t1, t2, a1_state, a1_mem, a2_state, a2_mem, env_state = carry - + ( + t1, + t2, + a1_state, + a1_mem, + a2_state, + a2_mem, + env_state, + env_param, + env_rng, + ) = carry + env_rng, _ = jax.random.split(rng) a1, a1_state, new_a1_mem = agent1.batch_policy( a1_state, t1.observation, @@ -91,9 +107,11 @@ def _inner_rollout(carry, unused): t2.observation, a2_mem, ) - (tprime_1, tprime_2), env_state = env.batch_step( - (a1, a2), + (tprime_1, tprime_2), env_state = env.step( + env_rng, env_state, + (a1, a2), + env_param, ) if self.args.agent1 == "MFOS": @@ -134,6 +152,8 @@ def _inner_rollout(carry, unused): a2_state, new_a2_mem, env_state, + env_param, + env_rng, ), ( traj1, traj2, @@ -147,7 +167,7 @@ def _outer_rollout(carry, unused): _inner_rollout, carry, None, - length=env.inner_episode_length, + length=self.args.num_inner_steps, ) # MFOS has to takes a meta-action for each episode @@ -178,21 +198,24 @@ def _outer_rollout(carry, unused): a1_state, a1_mem = agent1._state, agent1._mem a2_state, a2_mem = agent2._state, agent2._mem - num_iters = max(int(num_episodes / (env.num_envs * self.num_opps)), 1) + num_iters = max( + int(num_episodes / (self.args.num_envs * self.num_opps)), 1 + ) log_interval = max(num_iters / MAX_WANDB_CALLS, 5) print(f"Log Interval {log_interval}") # run actual loop for i in range(num_episodes): - rng, rng_run = jax.random.split(rng) - t_init, env_state = env.runner_reset( - (self.num_opps, env.num_envs), rng_run - ) + rngs = jnp.concatenate( + jax.random.split(rng, self.args.num_opps * self.args.num_envs) + ).reshape((self.args.num_opps, self.args.num_envs, -1)) + + obs, env_state = env.reset(rngs, env_params) if self.args.agent1 == "NaiveEx": - a1_state, a1_mem = agent1.batch_init(t_init[0]) + a1_state, a1_mem = agent1.batch_init(obs[0]) if self.args.agent2 == "NaiveEx": - a2_state, a2_mem = agent2.batch_init(t_init[1]) + a2_state, a2_mem = agent2.batch_init(obs[1]) elif self.args.env_type in ["meta", "infinite"]: # meta-experiments - init 2nd agent per trial @@ -202,9 +225,9 @@ def _outer_rollout(carry, unused): # run trials vals, stack = jax.lax.scan( _outer_rollout, - (*t_init, a1_state, a1_mem, a2_state, a2_mem, env_state), + (*obs, a1_state, a1_mem, a2_state, a2_mem, env_state), None, - length=env.outer_ep_length, + length=self.args.num_steps // self.args.num_inner_steps, ) t1, t2, a1_state, a1_mem, a2_state, a2_mem, env_state = vals From 3b6de1660e7753b97a352aefcce2a3571d7a7139 Mon Sep 17 00:00:00 2001 From: akbir Date: Mon, 31 Oct 2022 12:58:45 +0000 Subject: [PATCH 20/36] changed inner/outer/loop --- pax/runner_rl.py | 127 +++++++++++++++++++++++++++++++++++------------ 1 file changed, 94 insertions(+), 33 deletions(-) diff --git a/pax/runner_rl.py b/pax/runner_rl.py index cae4bc02..4a674219 100644 --- a/pax/runner_rl.py +++ b/pax/runner_rl.py @@ -76,38 +76,53 @@ def _reshape_opp_dim(x): self.ipd_stats = jax.jit(ipd_visitation) self.cg_stats = jax.jit(cg_visitation) - # we vmap over the rng but not params + # VMAP for num envs: we vmap over the rng but not params env.reset = jax.vmap(env.reset, (0, None), 0) env.step = jax.vmap( env.step, (0, 0, 0, None), 0 # rng, state, actions, params ) + # VMAP for num opps: we vmap over the rng but not params + env.reset = jax.vmap(env.reset, (0, None), 0) + env.step = jax.vmap( + env.step, (0, 0, 0, None), 0 # rng, state, actions, params + ) + + self.split = jax.vmap(jax.vmap(jax.random.split, (0, None)), (0, None)) + def run_loop(self, env, env_params, agents, num_episodes, watchers): def _inner_rollout(carry, unused): """Runner for inner episode""" ( - t1, - t2, + rng, + obs1, + obs2, a1_state, a1_mem, a2_state, a2_mem, env_state, env_param, - env_rng, ) = carry - env_rng, _ = jax.random.split(rng) + + # unpack rngs + rng = self.split(rng, 4) + env_rng = rng[:, :, 0, :] + # a1_rng = rng[:, :, 1, :] + # a2_rng = rng[:, :, 2, :] + rng = rng[:, :, 3, :] + a1, a1_state, new_a1_mem = agent1.batch_policy( a1_state, - t1.observation, + obs1, a1_mem, ) a2, a2_state, new_a2_mem = agent2.batch_policy( a2_state, - t2.observation, + obs2, a2_mem, ) - (tprime_1, tprime_2), env_state = env.step( + (next_obs1, next_obs2), env_state, rewards, done, info = env.step( env_rng, env_state, (a1, a2), @@ -116,44 +131,44 @@ def _inner_rollout(carry, unused): if self.args.agent1 == "MFOS": traj1 = MFOSSample( - t1.observation, + obs1, a1, - tprime_1.reward, + rewards[0], new_a1_mem.extras["log_probs"], new_a1_mem.extras["values"], - tprime_1.last(), + done[0], a1_mem.hidden, a1_mem.th, ) else: traj1 = Sample( - t1.observation, + obs1, a1, - tprime_1.reward, + rewards[1], new_a1_mem.extras["log_probs"], new_a1_mem.extras["values"], - tprime_1.last(), + done, a1_mem.hidden, ) traj2 = Sample( - t2.observation, + obs2, a2, - tprime_2.reward, + rewards[1], new_a2_mem.extras["log_probs"], new_a2_mem.extras["values"], - tprime_2.last(), + done, a2_mem.hidden, ) return ( - tprime_1, - tprime_2, + rng, + next_obs1, + next_obs2, a1_state, new_a1_mem, a2_state, new_a2_mem, env_state, env_param, - env_rng, ), ( traj1, traj2, @@ -161,7 +176,18 @@ def _inner_rollout(carry, unused): def _outer_rollout(carry, unused): """Runner for trial""" - t1, t2, a1_state, a1_mem, a2_state, a2_memory, env_state = carry + ( + obs1, + obs2, + a1_state, + a1_mem, + a2_state, + a2_mem, + env_state, + env_param, + env_rng, + ) = carry + # play episode of the game vals, trajectories = jax.lax.scan( _inner_rollout, @@ -170,24 +196,36 @@ def _outer_rollout(carry, unused): length=self.args.num_inner_steps, ) - # MFOS has to takes a meta-action for each episode + # MFOS has to take a meta-action for each episode if self.args.agent1 == "MFOS": a1_mem = agent1.meta_policy(a1_mem) # update second agent - t1, t2, a1_state, a1_mem, a2_state, a2_memory, env_state = vals - final_t2 = t2._replace(step_type=2 * jnp.ones_like(t2.step_type)) + ( + obs1, + obs2, + a1_state, + a1_mem, + a2_state, + a2_mem, + env_state, + env_param, + env_rng, + ) = carry + a2_state, a2_memory, a2_metrics = agent2.batch_update( - trajectories[1], final_t2, a2_state, a2_memory + trajectories[1], obs2, a2_state, a2_mem ) return ( - t1, - t2, + obs1, + obs2, a1_state, a1_mem, a2_state, a2_memory, env_state, + env_param, + env_rng, ), (*trajectories, a2_metrics) """Run training of agents in environment""" @@ -205,8 +243,10 @@ def _outer_rollout(carry, unused): print(f"Log Interval {log_interval}") # run actual loop for i in range(num_episodes): + # RNG are the same for num_opps but different for num_envs rngs = jnp.concatenate( - jax.random.split(rng, self.args.num_opps * self.args.num_envs) + [jax.random.split(rng, self.args.num_opps)] + * self.args.num_envs ).reshape((self.args.num_opps, self.args.num_envs, -1)) obs, env_state = env.reset(rngs, env_params) @@ -222,21 +262,42 @@ def _outer_rollout(carry, unused): a2_state, a2_mem = agent2.batch_init( jax.random.split(rng, self.num_opps), a2_mem.hidden ) + # run trials vals, stack = jax.lax.scan( _outer_rollout, - (*obs, a1_state, a1_mem, a2_state, a2_mem, env_state), + ( + rngs, + *obs, + a1_state, + a1_mem, + a2_state, + a2_mem, + env_state, + env_params, + ), None, length=self.args.num_steps // self.args.num_inner_steps, ) - t1, t2, a1_state, a1_mem, a2_state, a2_mem, env_state = vals + ( + rngs, + obs1, + obs2, + a1_state, + a1_mem, + a2_state, + a2_mem, + env_state, + env_params, + ) = vals traj_1, traj_2, a2_metrics = stack + # update outer agent - final_t1 = t1._replace(step_type=2 * jnp.ones_like(t1.step_type)) + # final_t1 = t1._replace(step_type=2 * jnp.ones_like(t1.step_type)) a1_state, _, _ = agent1.update( reduce_outer_traj(traj_1), - self.reduce_opp_dim(final_t1), + self.reduce_opp_dim(obs1), a1_state, self.reduce_opp_dim(a1_mem), ) @@ -276,7 +337,7 @@ def _outer_rollout(carry, unused): self.ipd_stats( traj_1.observations, traj_1.actions, - final_t1.observation, + obs1, ), ) rewards_0 = traj_1.rewards.mean() From 28c03540e0ae41be275d0726282d2d6eba0af957 Mon Sep 17 00:00:00 2001 From: Timon Willi Date: Mon, 31 Oct 2022 13:08:37 +0000 Subject: [PATCH 21/36] getting rid of timestep for ppo --- pax/ppo/ppo.py | 21 +++++++++++---------- 1 file changed, 11 insertions(+), 10 deletions(-) diff --git a/pax/ppo/ppo.py b/pax/ppo/ppo.py index 71720adb..accc7a2a 100644 --- a/pax/ppo/ppo.py +++ b/pax/ppo/ppo.py @@ -6,7 +6,6 @@ import jax import jax.numpy as jnp import optax -from dm_env import TimeStep from pax import utils from pax.ppo.networks import ( @@ -369,24 +368,26 @@ def make_initial_state(key: Any, hidden: jnp.ndarray) -> TrainingState: @jax.jit def prepare_batch( - traj_batch: NamedTuple, t_prime: TimeStep, action_extras: dict + traj_batch: NamedTuple, reward: int, done: Any, action_extras: dict ): # Rollouts complete -> Training begins # Add an additional rollout step for advantage calculation _value = jax.lax.select( - t_prime.last(), + # t_prime.last(), + done, jnp.zeros_like(action_extras["values"]), action_extras["values"], ) _done = jax.lax.select( - t_prime.last(), + # t_prime.last(), + done, 2 * jnp.ones_like(_value), jnp.zeros_like(_value), ) _value = jax.lax.expand_dims(_value, [0]) - _reward = jax.lax.expand_dims(t_prime.reward, [0]) + _reward = jax.lax.expand_dims(reward, [0]) _done = jax.lax.expand_dims(_done, [0]) # need to add final value here traj_batch = traj_batch._replace( @@ -433,10 +434,10 @@ def prepare_batch( self._num_minibatches = num_minibatches # number of minibatches self._num_epochs = num_epochs # number of epochs to use sample - def select_action(self, t: TimeStep): + def select_action(self, obs: jnp.ndarray): """Selects action and updates info with PPO specific information""" actions, self._state, self._mem = self._policy( - self._state, t.observation, self._mem + self._state, obs, self._mem ) return utils.to_numpy(actions) @@ -450,11 +451,11 @@ def reset_memory(self, memory, eval=False) -> TrainingState: ) return memory - def update(self, traj_batch, t_prime, state, mem): + def update(self, traj_batch, reward: int, obs: jnp.ndarray, done: Any, state, mem): """Update the agent -> only called at the end of a trajectory""" - _, _, mem = self._policy(state, t_prime.observation, mem) + _, _, mem = self._policy(state, obs, mem) - traj_batch = self._prepare_batch(traj_batch, t_prime, mem.extras) + traj_batch = self._prepare_batch(traj_batch, reward, done, mem.extras) state, mem, metrics = self._sgd_step(state, traj_batch) self._logger.metrics["sgd_steps"] += ( self._num_minibatches * self._num_epochs From f616a6c6ba42d8b0a58aa28cce93b000a14b33b2 Mon Sep 17 00:00:00 2001 From: akbir Date: Mon, 31 Oct 2022 13:20:58 +0000 Subject: [PATCH 22/36] updated runner --- pax/runner_rl.py | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/pax/runner_rl.py b/pax/runner_rl.py index 4a674219..f7a8380e 100644 --- a/pax/runner_rl.py +++ b/pax/runner_rl.py @@ -97,6 +97,8 @@ def _inner_rollout(carry, unused): rng, obs1, obs2, + r1, + r2, a1_state, a1_mem, a2_state, @@ -163,6 +165,8 @@ def _inner_rollout(carry, unused): rng, next_obs1, next_obs2, + rewards[0], + rewards[1], a1_state, new_a1_mem, a2_state, @@ -179,6 +183,8 @@ def _outer_rollout(carry, unused): ( obs1, obs2, + r1, + r2, a1_state, a1_mem, a2_state, @@ -204,6 +210,8 @@ def _outer_rollout(carry, unused): ( obs1, obs2, + r1, + r2, a1_state, a1_mem, a2_state, @@ -219,6 +227,8 @@ def _outer_rollout(carry, unused): return ( obs1, obs2, + r1, + r2, a1_state, a1_mem, a2_state, @@ -250,6 +260,10 @@ def _outer_rollout(carry, unused): ).reshape((self.args.num_opps, self.args.num_envs, -1)) obs, env_state = env.reset(rngs, env_params) + rewards = [ + jnp.zeros((self.args.num_opps, self.args.num_envs)), + jnp.zeros((self.args.num_opps, self.args.num_envs)), + ] if self.args.agent1 == "NaiveEx": a1_state, a1_mem = agent1.batch_init(obs[0]) @@ -269,6 +283,7 @@ def _outer_rollout(carry, unused): ( rngs, *obs, + *rewards, a1_state, a1_mem, a2_state, @@ -284,6 +299,8 @@ def _outer_rollout(carry, unused): rngs, obs1, obs2, + r1, + r2, a1_state, a1_mem, a2_state, @@ -297,7 +314,9 @@ def _outer_rollout(carry, unused): # final_t1 = t1._replace(step_type=2 * jnp.ones_like(t1.step_type)) a1_state, _, _ = agent1.update( reduce_outer_traj(traj_1), + self.reduce_opp_dim(r1), self.reduce_opp_dim(obs1), + jnp.ones_like(self.reduce_opp_dim(r1), dtype=jnp.bool_), a1_state, self.reduce_opp_dim(a1_mem), ) From 468e8413c68294c5b05ac347456d99717480df3b Mon Sep 17 00:00:00 2001 From: Timon Willi Date: Mon, 31 Oct 2022 13:31:04 +0000 Subject: [PATCH 23/36] removing timestep from ppo, ppogru, and mfos --- pax/mfos_ppo/ppo_gru.py | 24 +++++++++++++----------- pax/ppo/ppo.py | 2 +- pax/ppo/ppo_gru.py | 28 +++++++++++++++------------- 3 files changed, 29 insertions(+), 25 deletions(-) diff --git a/pax/mfos_ppo/ppo_gru.py b/pax/mfos_ppo/ppo_gru.py index 786f7e50..465f4afb 100644 --- a/pax/mfos_ppo/ppo_gru.py +++ b/pax/mfos_ppo/ppo_gru.py @@ -76,7 +76,7 @@ def __init__( ): @jax.jit def policy( - state: TrainingState, observation: TimeStep, mem: MemoryState + state: TrainingState, observation: jnp.ndarray, mem: MemoryState ): """Agent policy to select actions and calculate agent specific information""" key, subkey = jax.random.split(state.random_key) @@ -418,24 +418,24 @@ def make_initial_state( @jax.jit def prepare_batch( - traj_batch: NamedTuple, t_prime: TimeStep, action_extras: dict + traj_batch: NamedTuple, reward: jnp.ndarray, done: Any, t_prime: TimeStep, action_extras: dict ): # Rollouts complete -> Training begins # Add an additional rollout step for advantage calculation _value = jax.lax.select( - t_prime.last(), + done, jnp.zeros_like(action_extras["values"]), action_extras["values"], ) _done = jax.lax.select( - t_prime.last(), + done, 2 * jnp.ones_like(_value), jnp.zeros_like(_value), ) _value = jax.lax.expand_dims(_value, [0]) - _reward = jax.lax.expand_dims(t_prime.reward, [0]) + _reward = jax.lax.expand_dims(reward, [0]) _done = jax.lax.expand_dims(_done, [0]) # need to add final value here @@ -490,13 +490,13 @@ def prepare_batch( self._num_epochs = num_epochs # number of epochs to use sample self._gru_dim = gru_dim - def select_action(self, t: TimeStep): + def select_action(self, obs: jnp.ndarray): """Selects action and updates info with PPO specific information""" ( actions, self._state, self._mem, - ) = self._policy(self._state, t.observation, self._mem) + ) = self._policy(self._state, obs, self._mem) return utils.to_numpy(actions) def reset_memory(self, memory, eval=False) -> TrainingState: @@ -515,15 +515,17 @@ def reset_memory(self, memory, eval=False) -> TrainingState: def update( self, - traj_batch, - t_prime: TimeStep, + traj_batch: NamedTuple, + obs: jnp.ndarray, + reward: jnp.ndarray, + done: Any, state, mem, ): """Update the agent -> only called at the end of a trajectory""" - _, _, mem = self._policy(state, t_prime.observation, mem) - traj_batch = self.prepare_batch(traj_batch, t_prime, mem.extras) + _, _, mem = self._policy(state, obs, mem) + traj_batch = self.prepare_batch(traj_batch, reward, done, mem.extras) state, mem, metrics = self._sgd_step(state, traj_batch) # update logging diff --git a/pax/ppo/ppo.py b/pax/ppo/ppo.py index accc7a2a..19266853 100644 --- a/pax/ppo/ppo.py +++ b/pax/ppo/ppo.py @@ -451,7 +451,7 @@ def reset_memory(self, memory, eval=False) -> TrainingState: ) return memory - def update(self, traj_batch, reward: int, obs: jnp.ndarray, done: Any, state, mem): + def update(self, traj_batch, obs: jnp.ndarray, reward: int, done: Any, state, mem): """Update the agent -> only called at the end of a trajectory""" _, _, mem = self._policy(state, obs, mem) diff --git a/pax/ppo/ppo_gru.py b/pax/ppo/ppo_gru.py index 7f641d3a..177f8927 100644 --- a/pax/ppo/ppo_gru.py +++ b/pax/ppo/ppo_gru.py @@ -6,7 +6,7 @@ import jax import jax.numpy as jnp import optax -from dm_env import TimeStep +# from dm_env import TimeStep from pax import utils from pax.ppo.networks import ( @@ -393,24 +393,24 @@ def make_initial_state( @jax.jit def prepare_batch( - traj_batch: NamedTuple, t_prime: TimeStep, action_extras: dict + traj_batch: NamedTuple, reward: jnp.ndarray, done: Any, action_extras: dict ): # Rollouts complete -> Training begins # Add an additional rollout step for advantage calculation _value = jax.lax.select( - t_prime.last(), + done, jnp.zeros_like(action_extras["values"]), action_extras["values"], ) _done = jax.lax.select( - t_prime.last(), + done, 2 * jnp.ones_like(_value), jnp.zeros_like(_value), ) _value = jax.lax.expand_dims(_value, [0]) - _reward = jax.lax.expand_dims(t_prime.reward, [0]) + _reward = jax.lax.expand_dims(reward, [0]) _done = jax.lax.expand_dims(_done, [0]) # need to add final value here @@ -465,13 +465,13 @@ def prepare_batch( self._num_epochs = num_epochs # number of epochs to use sample self._gru_dim = gru_dim - def select_action(self, t: TimeStep): + def select_action(self, obs: jnp.ndarray): """Selects action and updates info with PPO specific information""" ( actions, self._state, self._mem, - ) = self._policy(self._state, t.observation, self._mem) + ) = self._policy(self._state, obs, self._mem) return utils.to_numpy(actions) def reset_memory(self, memory, eval=False) -> TrainingState: @@ -488,16 +488,18 @@ def reset_memory(self, memory, eval=False) -> TrainingState: def update( self, - traj_batch, - t_prime: TimeStep, - state, - mem, + traj_batch: NamedTuple, + obs: jnp.ndarray, + reward: jnp.ndarray, + done: Any, + state: TrainingState, + mem: jnp.ndarray, ): """Update the agent -> only called at the end of a trajectory""" - _, _, mem = self._policy(state, t_prime.observation, mem) - traj_batch = self.prepare_batch(traj_batch, t_prime, mem.extras) + _, _, mem = self._policy(state, obs, mem) + traj_batch = self.prepare_batch(traj_batch, reward, done, mem.extras) state, mem, metrics = self._sgd_step(state, traj_batch) # update logging From e6c04643f10aaa527d7be79d01385c2df89e1d36 Mon Sep 17 00:00:00 2001 From: Timon Willi Date: Mon, 31 Oct 2022 14:50:54 +0000 Subject: [PATCH 24/36] change strategies update interface --- pax/strategies.py | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/pax/strategies.py b/pax/strategies.py index 53730bc6..7b4f9a5f 100644 --- a/pax/strategies.py +++ b/pax/strategies.py @@ -129,7 +129,7 @@ def select_action( ) return action - def update(self, unused0, unused1, state, mem) -> None: + def update(self, unused0, unused1, unused2, unused3, state, mem) -> None: return state, mem, {} def reset_memory(self, mem, *args) -> MemoryState: @@ -233,7 +233,7 @@ def _policy( self._policy = _policy - def update(self, unused0, unused1, state, mem) -> None: + def update(self, unused0, unused1, unused2, unused3, state, mem) -> None: return state, mem, {} def reset_memory(self, mem, *args) -> MemoryState: @@ -332,7 +332,7 @@ def select_action( ) return action - def update(self, unused0, unused1, state, mem) -> None: + def update(self, unused0, unused1, unused2, unused3, state, mem) -> None: return state, mem, {} def reset_memory(self, mem, *args) -> MemoryState: @@ -355,7 +355,7 @@ def select_action( ) -> jnp.ndarray: return self._trigger(timestep.observation) - def update(self, unused0, unused1, state, mem) -> None: + def update(self, unused0, unused1, unused2, unused3, state, mem) -> None: return state, mem, {} def reset_memory(self, *args) -> TrainingState: @@ -397,7 +397,7 @@ def select_action( # return [batch] return self._reciprocity(obs) - def update(self, unused0, unused1, state, mem) -> None: + def update(self, unused0, unused1, unused2, unused3, state, mem) -> None: return state, mem, {} def reset_memory(self, mem, *args) -> MemoryState: @@ -442,7 +442,7 @@ def select_action( # return jnp.ones((batch_size, 1)) return jnp.ones((batch_size,)) - def update(self, unused0, unused1, state, mem) -> None: + def update(self, unused0, unused1, unused2, unused3, state, mem) -> None: return state, mem, {} def reset_memory(self, mem, *args) -> MemoryState: @@ -496,7 +496,7 @@ def _policy( batch_size = obs.shape[0] return jnp.zeros((batch_size,)), state, mem - def update(self, unused0, unused1, state, mem) -> None: + def update(self, unused0, unused1, unused2, unused3, state, mem) -> None: return state, mem, {} def reset_memory(self, mem, *args) -> MemoryState: @@ -563,7 +563,7 @@ def select_action( ) return action - def update(self, unused0, unused1, state, mem) -> None: + def update(self, unused0, unused1, unused2, unused3, state, mem) -> None: return state, mem, {} def reset_memory(self, mem, *args) -> MemoryState: @@ -607,7 +607,7 @@ def select_action( action = 5 * jnp.ones((batch_size,), dtype=int) return action - def update(self, unused0, unused1, state, mem) -> None: + def update(self, unused0, unused1, unused2, unused3, state, mem) -> None: return state, mem, {} def reset_memory(self, mem, *args) -> MemoryState: @@ -647,7 +647,7 @@ def _policy( action = jnp.tile(20 * jnp.ones((5,)), (batch_size, 1)) return action, state, mem - def update(self, unused0, unused1, state, mem) -> None: + def update(self, unused0, unused1, unused2, unused3, state, mem) -> None: return state, mem, {} def reset_memory(self, *args) -> TrainingState: @@ -685,7 +685,7 @@ def _policy( action = jnp.tile(-20 * jnp.ones((5,)), (batch_size, 1)) return action, state, mem - def update(self, unused0, unused1, state, mem) -> None: + def update(self, unused0, unused1, unused2, unused3, state, mem) -> None: return state, mem, {} def reset_memory(self, *args) -> TrainingState: @@ -730,7 +730,7 @@ def _policy( ) return action, state, mem - def update(self, unused0, unused1, state, mem) -> None: + def update(self, unused0, unused1, unused2, unused3, state, mem) -> None: return state, mem, {} def reset_memory(self, *args) -> TrainingState: From 58e12ba8d218b71583e83d85fc579cfdc1d87bb4 Mon Sep 17 00:00:00 2001 From: akbir Date: Mon, 31 Oct 2022 14:59:44 +0000 Subject: [PATCH 25/36] updated runner --- pax/learners.py | 4 ++- pax/ppo/ppo.py | 12 ++++++--- pax/runner_rl.py | 63 +++++++++++++++++++++--------------------------- 3 files changed, 40 insertions(+), 39 deletions(-) diff --git a/pax/learners.py b/pax/learners.py index 6860e69d..16692e3a 100644 --- a/pax/learners.py +++ b/pax/learners.py @@ -45,7 +45,9 @@ def __init__(self, agents: list, args: dict): agent2.batch_reset = jax.jit( jax.vmap(agent2.reset_memory, (0, None), 0), static_argnums=1 ) - agent2.batch_update = jax.jit(jax.vmap(agent2.update, (1, 0, 0, 0), 0)) + agent2.batch_update = jax.jit( + jax.vmap(agent2.update, (1, 0, 0, 0, 0, 0), 0) + ) if args.agent1 != "NaiveEx": # NaiveEx requires env first step to init. diff --git a/pax/ppo/ppo.py b/pax/ppo/ppo.py index 19266853..5d19b919 100644 --- a/pax/ppo/ppo.py +++ b/pax/ppo/ppo.py @@ -374,14 +374,12 @@ def prepare_batch( # Add an additional rollout step for advantage calculation _value = jax.lax.select( - # t_prime.last(), done, jnp.zeros_like(action_extras["values"]), action_extras["values"], ) _done = jax.lax.select( - # t_prime.last(), done, 2 * jnp.ones_like(_value), jnp.zeros_like(_value), @@ -451,7 +449,15 @@ def reset_memory(self, memory, eval=False) -> TrainingState: ) return memory - def update(self, traj_batch, obs: jnp.ndarray, reward: int, done: Any, state, mem): + def update( + self, + traj_batch, + obs: jnp.ndarray, + reward: int, + done: Any, + state: TrainingState, + mem: MemoryState, + ): """Update the agent -> only called at the end of a trajectory""" _, _, mem = self._policy(state, obs, mem) diff --git a/pax/runner_rl.py b/pax/runner_rl.py index f7a8380e..48933178 100644 --- a/pax/runner_rl.py +++ b/pax/runner_rl.py @@ -83,9 +83,11 @@ def _reshape_opp_dim(x): ) # VMAP for num opps: we vmap over the rng but not params - env.reset = jax.vmap(env.reset, (0, None), 0) - env.step = jax.vmap( - env.step, (0, 0, 0, None), 0 # rng, state, actions, params + env.reset = jax.jit(jax.vmap(env.reset, (0, None), 0)) + env.step = jax.jit( + jax.vmap( + env.step, (0, 0, 0, None), 0 # rng, state, actions, params + ) ) self.split = jax.vmap(jax.vmap(jax.random.split, (0, None)), (0, None)) @@ -94,7 +96,7 @@ def run_loop(self, env, env_params, agents, num_episodes, watchers): def _inner_rollout(carry, unused): """Runner for inner episode""" ( - rng, + rngs, obs1, obs2, r1, @@ -108,11 +110,12 @@ def _inner_rollout(carry, unused): ) = carry # unpack rngs - rng = self.split(rng, 4) - env_rng = rng[:, :, 0, :] - # a1_rng = rng[:, :, 1, :] - # a2_rng = rng[:, :, 2, :] - rng = rng[:, :, 3, :] + rngs = self.split(rngs, 4) + env_rng = rngs[:, :, 0, :] + # a1_rng = rngs[:, :, 1, :] + # a2_rng = rngs[:, :, 2, :] + rngs = rngs[:, :, 3, :] + print(rngs.shape) a1, a1_state, new_a1_mem = agent1.batch_policy( a1_state, @@ -162,7 +165,7 @@ def _inner_rollout(carry, unused): a2_mem.hidden, ) return ( - rng, + rngs, next_obs1, next_obs2, rewards[0], @@ -180,7 +183,15 @@ def _inner_rollout(carry, unused): def _outer_rollout(carry, unused): """Runner for trial""" + # play episode of the game + vals, trajectories = jax.lax.scan( + _inner_rollout, + carry, + None, + length=self.args.num_inner_steps, + ) ( + rngs, obs1, obs2, r1, @@ -191,40 +202,23 @@ def _outer_rollout(carry, unused): a2_mem, env_state, env_param, - env_rng, - ) = carry - - # play episode of the game - vals, trajectories = jax.lax.scan( - _inner_rollout, - carry, - None, - length=self.args.num_inner_steps, - ) + ) = vals # MFOS has to take a meta-action for each episode if self.args.agent1 == "MFOS": a1_mem = agent1.meta_policy(a1_mem) # update second agent - ( - obs1, + a2_state, a2_mem, a2_metrics = agent2.batch_update( + trajectories[1], obs2, - r1, r2, - a1_state, - a1_mem, + jnp.ones_like(r2, dtype=jnp.bool_), a2_state, a2_mem, - env_state, - env_param, - env_rng, - ) = carry - - a2_state, a2_memory, a2_metrics = agent2.batch_update( - trajectories[1], obs2, a2_state, a2_mem ) return ( + rngs, obs1, obs2, r1, @@ -232,10 +226,9 @@ def _outer_rollout(carry, unused): a1_state, a1_mem, a2_state, - a2_memory, + a2_mem, env_state, env_param, - env_rng, ), (*trajectories, a2_metrics) """Run training of agents in environment""" @@ -314,8 +307,8 @@ def _outer_rollout(carry, unused): # final_t1 = t1._replace(step_type=2 * jnp.ones_like(t1.step_type)) a1_state, _, _ = agent1.update( reduce_outer_traj(traj_1), - self.reduce_opp_dim(r1), self.reduce_opp_dim(obs1), + self.reduce_opp_dim(r1), jnp.ones_like(self.reduce_opp_dim(r1), dtype=jnp.bool_), a1_state, self.reduce_opp_dim(a1_mem), From 23ec8106da9bc19c4c5fb81d809401fb963f4e50 Mon Sep 17 00:00:00 2001 From: akbir Date: Mon, 31 Oct 2022 15:01:13 +0000 Subject: [PATCH 26/36] updated runner --- pax/runner_rl.py | 1 - 1 file changed, 1 deletion(-) diff --git a/pax/runner_rl.py b/pax/runner_rl.py index 48933178..3332d95c 100644 --- a/pax/runner_rl.py +++ b/pax/runner_rl.py @@ -115,7 +115,6 @@ def _inner_rollout(carry, unused): # a1_rng = rngs[:, :, 1, :] # a2_rng = rngs[:, :, 2, :] rngs = rngs[:, :, 3, :] - print(rngs.shape) a1, a1_state, new_a1_mem = agent1.batch_policy( a1_state, From 4c699f7754706994967236eb92855cdfd6907756 Mon Sep 17 00:00:00 2001 From: akbir Date: Mon, 31 Oct 2022 15:11:49 +0000 Subject: [PATCH 27/36] fixed for mfos --- pax/conf/experiment/ipd/earl_v_ppo.yaml | 1 - pax/conf/experiment/ipd/mfos_v_tabular.yaml | 2 -- pax/conf/experiment/mp/earl_v_ppo.yaml | 1 - pax/conf/experiment/mp/earl_v_ppo_mem.yaml | 1 - pax/conf/experiment/mp/gs_v_ppo_mem.yaml | 4 +--- pax/conf/experiment/mp/gs_v_tabular.yaml | 4 +--- pax/conf/experiment/mp/mfos_v_tabular.yaml | 5 ++--- pax/experiment.py | 1 + pax/mfos_ppo/ppo_gru.py | 5 ++++- pax/runner_rl.py | 3 +-- 10 files changed, 10 insertions(+), 17 deletions(-) diff --git a/pax/conf/experiment/ipd/earl_v_ppo.yaml b/pax/conf/experiment/ipd/earl_v_ppo.yaml index 873e34ad..8616447a 100644 --- a/pax/conf/experiment/ipd/earl_v_ppo.yaml +++ b/pax/conf/experiment/ipd/earl_v_ppo.yaml @@ -26,7 +26,6 @@ num_generations: 5000 total_timesteps: 1e11 # Evaluation -num_seeds: 20 # # EARL vs. PPO trained on seed=0 # run_path: ucl-dark/ipd/13o3v95p # model_path: exp/EARL-PPO_memory-vs-PPO/run-seed-0-OpenES-pop-size-1000-num-opps-1/2022-09-15_00.15.31.908871/generation_2900 diff --git a/pax/conf/experiment/ipd/mfos_v_tabular.yaml b/pax/conf/experiment/ipd/mfos_v_tabular.yaml index d3e996a7..8d6e39e0 100644 --- a/pax/conf/experiment/ipd/mfos_v_tabular.yaml +++ b/pax/conf/experiment/ipd/mfos_v_tabular.yaml @@ -24,8 +24,6 @@ num_generations: 5000 total_timesteps: 1e11 num_devices: 1 -# Evaluation -num_seeds: 20 # MFOS vs. Tabular trained on seed = 0 run_path: ucl-dark/ipd/1r9txdso model_path: exp/GS-MFOS-vs-Tabular/run-seed-0-pop-size-1000/2022-09-25_20.32.20.821162/generation_4400 diff --git a/pax/conf/experiment/mp/earl_v_ppo.yaml b/pax/conf/experiment/mp/earl_v_ppo.yaml index d8678617..1cf595de 100644 --- a/pax/conf/experiment/mp/earl_v_ppo.yaml +++ b/pax/conf/experiment/mp/earl_v_ppo.yaml @@ -25,7 +25,6 @@ total_timesteps: 1e11 num_devices: 1 # Evaluation -num_seeds: 20 # # EARL vs. PPO trained on seed=0 # run_path: ucl-dark/ipd/13o3v95p # model_path: exp/EARL-PPO_memory-vs-PPO/run-seed-0-OpenES-pop-size-1000-num-opps-1/2022-09-15_00.15.31.908871/generation_2900 diff --git a/pax/conf/experiment/mp/earl_v_ppo_mem.yaml b/pax/conf/experiment/mp/earl_v_ppo_mem.yaml index 8986a21f..15286069 100644 --- a/pax/conf/experiment/mp/earl_v_ppo_mem.yaml +++ b/pax/conf/experiment/mp/earl_v_ppo_mem.yaml @@ -25,7 +25,6 @@ total_timesteps: 1e11 num_devices: 1 # Evaluation -num_seeds: 20 # # EARL vs. PPO trained on seed=0 # run_path: ucl-dark/ipd/13o3v95p # model_path: exp/EARL-PPO_memory-vs-PPO/run-seed-0-OpenES-pop-size-1000-num-opps-1/2022-09-15_00.15.31.908871/generation_2900 diff --git a/pax/conf/experiment/mp/gs_v_ppo_mem.yaml b/pax/conf/experiment/mp/gs_v_ppo_mem.yaml index aef49266..27c4e62d 100644 --- a/pax/conf/experiment/mp/gs_v_ppo_mem.yaml +++ b/pax/conf/experiment/mp/gs_v_ppo_mem.yaml @@ -24,9 +24,7 @@ num_generations: 5000 total_timesteps: 1e11 num_devices: 1 -# Evaluation -num_seeds: 20 -# # EARL vs. PPO trained on seed=0 +# EARL vs. PPO trained on seed=0 # run_path: ucl-dark/ipd/13o3v95p # model_path: exp/EARL-PPO_memory-vs-PPO/run-seed-0-OpenES-pop-size-1000-num-opps-1/2022-09-15_00.15.31.908871/generation_2900 # EARL vs. PPO trained on seed=1 diff --git a/pax/conf/experiment/mp/gs_v_tabular.yaml b/pax/conf/experiment/mp/gs_v_tabular.yaml index a6c6b84b..0ed3fa24 100644 --- a/pax/conf/experiment/mp/gs_v_tabular.yaml +++ b/pax/conf/experiment/mp/gs_v_tabular.yaml @@ -24,9 +24,7 @@ num_generations: 5000 total_timesteps: 1e11 num_devices: 1 -# Evaluation -num_seeds: 20 -# # EARL vs. PPO trained on seed=0 +# EARL vs. PPO trained on seed=0 # run_path: ucl-dark/ipd/13o3v95p # model_path: exp/EARL-PPO_memory-vs-PPO/run-seed-0-OpenES-pop-size-1000-num-opps-1/2022-09-15_00.15.31.908871/generation_2900 # EARL vs. PPO trained on seed=1 diff --git a/pax/conf/experiment/mp/mfos_v_tabular.yaml b/pax/conf/experiment/mp/mfos_v_tabular.yaml index 42222d67..35378b02 100644 --- a/pax/conf/experiment/mp/mfos_v_tabular.yaml +++ b/pax/conf/experiment/mp/mfos_v_tabular.yaml @@ -24,9 +24,8 @@ num_generations: 5000 total_timesteps: 1e11 num_devices: 1 -# Evaluation -num_seeds: 20 -# # EARL vs. PPO trained on seed=0 + +# EARL vs. PPO trained on seed=0 # run_path: ucl-dark/ipd/13o3v95p # model_path: exp/EARL-PPO_memory-vs-PPO/run-seed-0-OpenES-pop-size-1000-num-opps-1/2022-09-15_00.15.31.908871/generation_2900 # EARL vs. PPO trained on seed=1 diff --git a/pax/experiment.py b/pax/experiment.py index 17533a04..14691c46 100644 --- a/pax/experiment.py +++ b/pax/experiment.py @@ -462,6 +462,7 @@ def naive_pg_log(agent): "GoodGreedy": dumb_log, "EvilGreedy": dumb_log, "RandomGreedy": dumb_log, + "MFOS": dumb_log, "PPO": ppo_log, "PPO_memory": ppo_memory_log, "Naive": naive_pg_log, diff --git a/pax/mfos_ppo/ppo_gru.py b/pax/mfos_ppo/ppo_gru.py index 465f4afb..276dfc17 100644 --- a/pax/mfos_ppo/ppo_gru.py +++ b/pax/mfos_ppo/ppo_gru.py @@ -418,7 +418,10 @@ def make_initial_state( @jax.jit def prepare_batch( - traj_batch: NamedTuple, reward: jnp.ndarray, done: Any, t_prime: TimeStep, action_extras: dict + traj_batch: NamedTuple, + reward: jnp.ndarray, + done: Any, + action_extras: dict, ): # Rollouts complete -> Training begins # Add an additional rollout step for advantage calculation diff --git a/pax/runner_rl.py b/pax/runner_rl.py index 3332d95c..785a945f 100644 --- a/pax/runner_rl.py +++ b/pax/runner_rl.py @@ -140,7 +140,7 @@ def _inner_rollout(carry, unused): rewards[0], new_a1_mem.extras["log_probs"], new_a1_mem.extras["values"], - done[0], + done, a1_mem.hidden, a1_mem.th, ) @@ -303,7 +303,6 @@ def _outer_rollout(carry, unused): traj_1, traj_2, a2_metrics = stack # update outer agent - # final_t1 = t1._replace(step_type=2 * jnp.ones_like(t1.step_type)) a1_state, _, _ = agent1.update( reduce_outer_traj(traj_1), self.reduce_opp_dim(obs1), From 3d3dbc14fe91e85a8412e823d0bc92b25061cccd Mon Sep 17 00:00:00 2001 From: akbir Date: Mon, 31 Oct 2022 15:20:36 +0000 Subject: [PATCH 28/36] fixed for mfos --- pax/conf/experiment/ipd/mfos_v_ppo.yaml | 1 - pax/conf/experiment/ipd/mfos_v_ppo_mem.yaml | 3 +-- pax/runner_rl.py | 5 ++--- 3 files changed, 3 insertions(+), 6 deletions(-) diff --git a/pax/conf/experiment/ipd/mfos_v_ppo.yaml b/pax/conf/experiment/ipd/mfos_v_ppo.yaml index 2328315b..3e45fca3 100644 --- a/pax/conf/experiment/ipd/mfos_v_ppo.yaml +++ b/pax/conf/experiment/ipd/mfos_v_ppo.yaml @@ -13,7 +13,6 @@ payoff: [[-1, -1], [-3, 0], [0, -3], [-2, -2]] # Runner runner: rl - # Training top_k: 5 popsize: 1000 diff --git a/pax/conf/experiment/ipd/mfos_v_ppo_mem.yaml b/pax/conf/experiment/ipd/mfos_v_ppo_mem.yaml index ef7c29ca..7516b559 100644 --- a/pax/conf/experiment/ipd/mfos_v_ppo_mem.yaml +++ b/pax/conf/experiment/ipd/mfos_v_ppo_mem.yaml @@ -11,8 +11,7 @@ env_discount: 0.96 payoff: [[-1, -1], [-3, 0], [0, -3], [-2, -2]] # Runner -evo: True -eval: False +runner: rl # Training top_k: 5 diff --git a/pax/runner_rl.py b/pax/runner_rl.py index 785a945f..56cb9429 100644 --- a/pax/runner_rl.py +++ b/pax/runner_rl.py @@ -295,8 +295,8 @@ def _outer_rollout(carry, unused): r2, a1_state, a1_mem, - a2_state, - a2_mem, + _, + _, env_state, env_params, ) = vals @@ -314,7 +314,6 @@ def _outer_rollout(carry, unused): # update second agent a1_mem = agent1.batch_reset(a1_mem, False) - a2_mem = agent2.batch_reset(a2_mem, False) if self.args.save and i % self.args.save_interval == 0: log_savepath = os.path.join(self.save_dir, f"iteration_{i}") From 29260e1d1c5d1cc2c35c3f9f489b51b1a16563b5 Mon Sep 17 00:00:00 2001 From: akbir Date: Mon, 31 Oct 2022 15:46:46 +0000 Subject: [PATCH 29/36] minor changes to mfos --- pax/mfos_ppo/ppo_gru.py | 9 ++++----- pax/ppo/ppo_gru.py | 10 +++++++--- pax/runner_rl.py | 2 +- 3 files changed, 12 insertions(+), 9 deletions(-) diff --git a/pax/mfos_ppo/ppo_gru.py b/pax/mfos_ppo/ppo_gru.py index 276dfc17..4f8971fd 100644 --- a/pax/mfos_ppo/ppo_gru.py +++ b/pax/mfos_ppo/ppo_gru.py @@ -265,11 +265,10 @@ def model_update_epoch( key, params, opt_state, timesteps, batch = carry key, subkey = jax.random.split(key) permutation = jax.random.permutation(subkey, batch_size) - shuffled_batch = jax.tree_map( + shuffled_batch = jax.tree_util.tree_map( lambda x: jnp.take(x, permutation, axis=0), batch ) - shuffled_batch = batch - minibatches = jax.tree_map( + minibatches = jax.tree_util.tree_map( lambda x: jnp.reshape( x, [num_minibatches, -1] + list(x.shape[1:]) ), @@ -522,8 +521,8 @@ def update( obs: jnp.ndarray, reward: jnp.ndarray, done: Any, - state, - mem, + state: TrainingState, + mem: MemoryState, ): """Update the agent -> only called at the end of a trajectory""" diff --git a/pax/ppo/ppo_gru.py b/pax/ppo/ppo_gru.py index 177f8927..7f216156 100644 --- a/pax/ppo/ppo_gru.py +++ b/pax/ppo/ppo_gru.py @@ -6,6 +6,7 @@ import jax import jax.numpy as jnp import optax + # from dm_env import TimeStep from pax import utils @@ -393,7 +394,10 @@ def make_initial_state( @jax.jit def prepare_batch( - traj_batch: NamedTuple, reward: jnp.ndarray, done: Any, action_extras: dict + traj_batch: NamedTuple, + reward: jnp.ndarray, + done: Any, + action_extras: dict, ): # Rollouts complete -> Training begins # Add an additional rollout step for advantage calculation @@ -493,7 +497,7 @@ def update( reward: jnp.ndarray, done: Any, state: TrainingState, - mem: jnp.ndarray, + mem: MemoryState, ): """Update the agent -> only called at the end of a trajectory""" @@ -539,7 +543,7 @@ def make_gru_agent(args, obs_spec, action_spec, seed: int, player_id: int): ) # Optimizer - batch_size = int(args.num_envs * args.num_steps) + batch_size = int(args.num_envs * args.num_steps * args.num_opps) transition_steps = ( args.total_timesteps / batch_size diff --git a/pax/runner_rl.py b/pax/runner_rl.py index 56cb9429..d3a3431e 100644 --- a/pax/runner_rl.py +++ b/pax/runner_rl.py @@ -9,7 +9,7 @@ from pax.watchers import cg_visitation, ipd_visitation from pax.utils import save -MAX_WANDB_CALLS = 10000 +MAX_WANDB_CALLS = 1000000 class Sample(NamedTuple): From dfdf06eecbf4d38ca39591c65afbd8d9188bcdcb Mon Sep 17 00:00:00 2001 From: akbir Date: Mon, 31 Oct 2022 16:10:27 +0000 Subject: [PATCH 30/36] rng fixes --- pax/conf/experiment/ipd/mfos_v_ppo.yaml | 2 +- pax/conf/experiment/ipd/ppo.yaml | 1 + pax/conf/experiment/ipd/ppo_memory.yaml | 1 + pax/runner_rl.py | 18 +++++++++--------- 4 files changed, 12 insertions(+), 10 deletions(-) diff --git a/pax/conf/experiment/ipd/mfos_v_ppo.yaml b/pax/conf/experiment/ipd/mfos_v_ppo.yaml index 3e45fca3..afe945a6 100644 --- a/pax/conf/experiment/ipd/mfos_v_ppo.yaml +++ b/pax/conf/experiment/ipd/mfos_v_ppo.yaml @@ -16,7 +16,7 @@ runner: rl # Training top_k: 5 popsize: 1000 -num_envs: 2 +num_envs: 100 num_opps: 1 num_steps: 10_000 num_inner_steps: 100 diff --git a/pax/conf/experiment/ipd/ppo.yaml b/pax/conf/experiment/ipd/ppo.yaml index 751278ec..d47d93a2 100644 --- a/pax/conf/experiment/ipd/ppo.yaml +++ b/pax/conf/experiment/ipd/ppo.yaml @@ -17,6 +17,7 @@ runner: rl num_envs: 100 num_opps: 1 num_steps: 150 # number of steps per episode +num_inner_steps: 150 total_timesteps: 1_000_000 # Evaluation diff --git a/pax/conf/experiment/ipd/ppo_memory.yaml b/pax/conf/experiment/ipd/ppo_memory.yaml index 3accbd10..b5153239 100644 --- a/pax/conf/experiment/ipd/ppo_memory.yaml +++ b/pax/conf/experiment/ipd/ppo_memory.yaml @@ -20,6 +20,7 @@ eval: False num_envs: 100 num_opps: 1 num_steps: 150 # number of steps per episode +num_inner_steps: 150 total_timesteps: 2e7 # Useful information diff --git a/pax/runner_rl.py b/pax/runner_rl.py index d3a3431e..4ead227d 100644 --- a/pax/runner_rl.py +++ b/pax/runner_rl.py @@ -148,7 +148,7 @@ def _inner_rollout(carry, unused): traj1 = Sample( obs1, a1, - rewards[1], + rewards[0], new_a1_mem.extras["log_probs"], new_a1_mem.extras["values"], done, @@ -243,14 +243,13 @@ def _outer_rollout(carry, unused): ) log_interval = max(num_iters / MAX_WANDB_CALLS, 5) print(f"Log Interval {log_interval}") + # RNG are the same for num_opps but different for num_envs + rngs = jnp.concatenate( + [jax.random.split(rng, self.args.num_envs)] * self.args.num_opps + ).reshape((self.args.num_opps, self.args.num_envs, -1)) + # run actual loop for i in range(num_episodes): - # RNG are the same for num_opps but different for num_envs - rngs = jnp.concatenate( - [jax.random.split(rng, self.args.num_opps)] - * self.args.num_envs - ).reshape((self.args.num_opps, self.args.num_envs, -1)) - obs, env_state = env.reset(rngs, env_params) rewards = [ jnp.zeros((self.args.num_opps, self.args.num_envs)), @@ -295,8 +294,8 @@ def _outer_rollout(carry, unused): r2, a1_state, a1_mem, - _, - _, + a2_state, + a2_mem, env_state, env_params, ) = vals @@ -314,6 +313,7 @@ def _outer_rollout(carry, unused): # update second agent a1_mem = agent1.batch_reset(a1_mem, False) + a2_mem = agent2.batch_reset(a2_mem, False) if self.args.save and i % self.args.save_interval == 0: log_savepath = os.path.join(self.save_dir, f"iteration_{i}") From 6a186eec015ce8efa545744b37f5b043b046e11d Mon Sep 17 00:00:00 2001 From: akbir Date: Mon, 31 Oct 2022 16:50:13 +0000 Subject: [PATCH 31/36] evo runs --- pax/experiment.py | 8 +- pax/learners.py | 4 +- pax/runner_evo.py | 211 +++++++++++++++++++++++++++++++++------------- pax/runner_rl.py | 10 +-- 4 files changed, 163 insertions(+), 70 deletions(-) diff --git a/pax/experiment.py b/pax/experiment.py index 14691c46..25a18127 100644 --- a/pax/experiment.py +++ b/pax/experiment.py @@ -4,7 +4,7 @@ import os # uncomment to debug multi-devices on CPU -# os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8" +# os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=4" # from jax.config import config # config.update('jax_disable_jit', True) @@ -231,7 +231,9 @@ def get_pgpe_strategy(agent): logger.info(f"Evolution Strategy: {algo}") - return EvoRunner(args, strategy, es_params, param_reshaper, save_dir) + return EvoRunner( + env, strategy, es_params, param_reshaper, save_dir, args + ) elif args.runner == "rl": logger.info("Training with RL Runner") @@ -510,7 +512,7 @@ def main(args): if args.runner == "evo": num_iters = args.num_generations # number of generations print(f"Number of Generations: {num_iters}") - runner.run_loop(env, agent_pair, num_iters, watchers) + runner.run_loop(env, env_params, agent_pair, num_iters, watchers) elif args.runner == "rl": num_iters = int( diff --git a/pax/learners.py b/pax/learners.py index 16692e3a..fb0132f1 100644 --- a/pax/learners.py +++ b/pax/learners.py @@ -138,8 +138,8 @@ def __init__(self, agents: list, args: dict): agent2.batch_update = jax.jit( jax.vmap( - jax.vmap(agent2.update, (1, 0, 0, 0)), - (1, 0, 0, 0), + jax.vmap(agent2.update, (1, 0, 0, 0, 0, 0)), + (1, 0, 0, 0, 0, 0), ) ) diff --git a/pax/runner_evo.py b/pax/runner_evo.py index 107c179d..a497952a 100644 --- a/pax/runner_evo.py +++ b/pax/runner_evo.py @@ -1,12 +1,13 @@ from datetime import datetime import os import time -from typing import NamedTuple +from typing import Any, NamedTuple from evosax import FitnessShaper import jax import jax.numpy as jnp import wandb +import chex from pax.utils import save, TrainingState, MemoryState # TODO: import when evosax library is updated @@ -31,7 +32,9 @@ class Sample(NamedTuple): class EvoRunner: """Holds the runner's state.""" - def __init__(self, args, strategy, es_params, param_reshaper, save_dir): + def __init__( + self, env, strategy, es_params, param_reshaper, save_dir, args + ): self.algo = args.es.algo self.args = args self.es_params = es_params @@ -52,56 +55,106 @@ def __init__(self, args, strategy, es_params, param_reshaper, save_dir): self.ipd_stats = jax.jit(ipd_visitation) self.cg_stats = jax.jit(cg_visitation) - def run_loop(self, env, agents, num_generations, watchers): + # Evo Runner has 3 vmap dims (popsize, num_opps, num_envs) + # Evo Runner also has an additional pmap dim (num_devices, ...) + # For the env we vmap over the rng but not params + + # num envs + env.reset = jax.vmap(env.reset, (0, None), 0) + env.step = jax.vmap( + env.step, (0, 0, 0, None), 0 # rng, state, actions, params + ) + + # num opps + env.reset = jax.vmap(env.reset, (0, None), 0) + env.step = jax.vmap( + env.step, (0, 0, 0, None), 0 # rng, state, actions, params + ) + # pop size + env.reset = jax.jit(jax.vmap(env.reset, (0, None), 0)) + env.step = jax.jit( + jax.vmap( + env.step, (0, 0, 0, None), 0 # rng, state, actions, params + ) + ) + self.split = jax.vmap( + jax.vmap(jax.vmap(jax.random.split, (0, None)), (0, None)), + (0, None), + ) + + def run_loop(self, env, env_params, agents, num_generations, watchers): """Run training of agents in environment""" def _inner_rollout(carry, unused): """Runner for inner episode""" - t1, t2, a1_state, a1_mem, a2_state, a2_mem, env_state = carry + ( + rngs, + obs1, + obs2, + r1, + r2, + a1_state, + a1_mem, + a2_state, + a2_mem, + env_state, + env_params, + ) = carry + + # unpack rngs + rngs = self.split(rngs, 4) + env_rng = rngs[:, :, :, 0, :] + # a1_rng = rngs[:, :, :, 1, :] + # a2_rng = rngs[:, :, :, 2, :] + rngs = rngs[:, :, :, 3, :] a1, a1_state, new_a1_mem = agent1.batch_policy( a1_state, - t1.observation, + obs1, a1_mem, ) - a2, a2_state, new_a2_mem = agent2.batch_policy( a2_state, - t2.observation, + obs2, a2_mem, ) - - (tprime_1, tprime_2), env_state = env.batch_step( - (a1, a2), + (next_obs1, next_obs2), env_state, rewards, done, info = env.step( + env_rng, env_state, + (a1, a2), + env_params, ) traj1 = Sample( - t1.observation, + obs1, a1, - tprime_1.reward, + rewards[0], new_a1_mem.extras["log_probs"], new_a1_mem.extras["values"], - tprime_1.last(), + done, a1_mem.hidden, ) traj2 = Sample( - t2.observation, + obs2, a2, - tprime_2.reward, + rewards[1], new_a2_mem.extras["log_probs"], new_a2_mem.extras["values"], - tprime_2.last(), + done, a2_mem.hidden, ) return ( - tprime_1, - tprime_2, + rngs, + next_obs1, + next_obs2, + rewards[0], + rewards[1], a1_state, new_a1_mem, a2_state, new_a2_mem, env_state, + env_params, ), ( traj1, traj2, @@ -109,57 +162,85 @@ def _inner_rollout(carry, unused): def _outer_rollout(carry, unused): """Runner for trial""" - t1, t2, a1_state, a1_mem, a2_state, a2_mem, env_state = carry # play episode of the game vals, trajectories = jax.lax.scan( _inner_rollout, carry, None, - length=env.inner_episode_length, + length=self.args.num_inner_steps, ) - - # update second agent - t1, t2, a1_state, a1_mem, a2_state, a2_mem, env_state = vals - - # MFOS has to takes a meta-action for each episode + ( + rngs, + obs1, + obs2, + r1, + r2, + a1_state, + a1_mem, + a2_state, + a2_mem, + env_state, + env_params, + ) = vals + # MFOS has to take a meta-action for each episode if self.args.agent1 == "MFOS": a1_mem = agent1.meta_policy(a1_mem) - # do second agent update - final_t2 = t2._replace( - step_type=2 * jnp.ones_like(vals[1].step_type) - ) + # update second agent a2_state, a2_mem, a2_metrics = agent2.batch_update( - trajectories[1], final_t2, a2_state, a2_mem + trajectories[1], + obs2, + r2, + jnp.ones_like(r2, dtype=jnp.bool_), + a2_state, + a2_mem, ) - return ( - t1, - t2, + rngs, + obs1, + obs2, + r1, + r2, a1_state, a1_mem, a2_state, a2_mem, env_state, + env_params, ), (*trajectories, a2_metrics) def evo_rollout( params: jnp.ndarray, rng_run: jnp.ndarray, - rng_key: jnp.ndarray, a1_state: TrainingState, a1_mem: MemoryState, + env_params: Any, ): # env reset - t_init, env_state = env.runner_reset( - (popsize, num_opps, env.num_envs), rng_run + rngs = jnp.concatenate( + [jax.random.split(rng_run, self.args.num_envs)] + * self.args.num_opps + * self.args.popsize + ).reshape( + (self.args.popsize, self.args.num_opps, self.args.num_envs, -1) ) + + obs, env_state = env.reset(rngs, env_params) + rewards = [ + jnp.zeros( + (self.args.popsize, self.args.num_opps, self.args.num_envs) + ), + jnp.zeros( + (self.args.popsize, self.args.num_opps, self.args.num_envs) + ), + ] + # Player 1 a1_state = a1_state._replace(params=params) a1_mem = agent1.batch_reset(a1_mem, False) # Player 2 if self.args.agent2 == "NaiveEx": - a2_state, a2_mem = agent2.batch_init(t_init[1]) + a2_state, a2_mem = agent2.batch_init(obs[1]) else: # meta-experiments - init 2nd agent per trial @@ -170,15 +251,38 @@ def evo_rollout( agent2._mem.hidden, ) + # run trials vals, stack = jax.lax.scan( _outer_rollout, - (*t_init, a1_state, a1_mem, a2_state, a2_mem, env_state), + ( + rngs, + *obs, + *rewards, + a1_state, + a1_mem, + a2_state, + a2_mem, + env_state, + env_params, + ), None, - length=env.outer_ep_length, + length=self.args.num_steps // self.args.num_inner_steps, ) + ( + rngs, + obs1, + obs2, + r1, + r2, + a1_state, + a1_mem, + a2_state, + a2_mem, + env_state, + env_params, + ) = vals traj_1, traj_2, a2_metrics = stack - t1, t2, a1_state, a1_mem, a2_state, a2_mem, env_state = vals # Fitness fitness = traj_1.rewards.mean(axis=(0, 1, 3, 4)) @@ -198,15 +302,13 @@ def evo_rollout( "meta", "sequential", ]: - final_t1 = t1._replace( - step_type=2 * jnp.ones_like(t1.step_type) - ) + env_stats = jax.tree_util.tree_map( lambda x: x.mean(), self.ipd_stats( traj_1.observations, traj_1.actions, - final_t1.observation, + obs1, ), ) rewards_0 = traj_1.rewards.mean() @@ -226,8 +328,8 @@ def evo_rollout( print(f"Number of Generations: {num_generations}") print(f"Number of Meta Episodes: {num_generations}") print(f"Population Size: {self.popsize}") - print(f"Number of Environments: {env.num_envs}") - print(f"Number of Opponent: {self.num_opps}") + print(f"Number of Environments: {self.args.num_envs}") + print(f"Number of Opponent: {self.args.num_opps}") print(f"Log Interval: {log_interval}") print("------------------------------") # Initialize agents and RNG @@ -252,19 +354,6 @@ def evo_rollout( log = es_logging.initialize() num_devices = self.args.num_devices - # Evolution specific: add pop size dimension - if self.args.env_type == "infinite" and self.args.env_id == "ipd": - env.batch_step = jax.jit( - jax.vmap(env.batch_step, (0, None), (0, None)) - ) - else: - env.batch_step = jax.jit( - jax.vmap(env.batch_step), - ) - - if self.args.env_id == "coin_game": - env.batch_reset = jax.jit(jax.vmap(env.batch_reset)) - # Reshape a single agent's params before vmapping init_hidden = jnp.tile( agent1._mem.hidden, @@ -280,8 +369,10 @@ def evo_rollout( evo_rollout, in_axes=(0, None, None, None, None), ) + for gen in range(num_gens): rng, rng_run, rng_gen, rng_key = jax.random.split(rng, 4) + # Ask x, evo_state = strategy.ask(rng_gen, evo_state, es_params) params = param_reshaper.reshape(x) @@ -297,7 +388,7 @@ def evo_rollout( rewards_0, rewards_1, a2_metrics, - ) = evo_rollout(params, rng_run, rng_key, a1_state, a1_mem) + ) = evo_rollout(params, rng_run, a1_state, a1_mem, env_params) # Reshape over devices fitness = jnp.reshape(fitness, popsize * num_devices) diff --git a/pax/runner_rl.py b/pax/runner_rl.py index 4ead227d..cb3fdfad 100644 --- a/pax/runner_rl.py +++ b/pax/runner_rl.py @@ -106,7 +106,7 @@ def _inner_rollout(carry, unused): a2_state, a2_mem, env_state, - env_param, + env_params, ) = carry # unpack rngs @@ -130,7 +130,7 @@ def _inner_rollout(carry, unused): env_rng, env_state, (a1, a2), - env_param, + env_params, ) if self.args.agent1 == "MFOS": @@ -174,7 +174,7 @@ def _inner_rollout(carry, unused): a2_state, new_a2_mem, env_state, - env_param, + env_params, ), ( traj1, traj2, @@ -200,7 +200,7 @@ def _outer_rollout(carry, unused): a2_state, a2_mem, env_state, - env_param, + env_params, ) = vals # MFOS has to take a meta-action for each episode @@ -227,7 +227,7 @@ def _outer_rollout(carry, unused): a2_state, a2_mem, env_state, - env_param, + env_params, ), (*trajectories, a2_metrics) """Run training of agents in environment""" From a5c816894d631f58e06074777a6fec4a49df042d Mon Sep 17 00:00:00 2001 From: akbir Date: Mon, 31 Oct 2022 17:07:00 +0000 Subject: [PATCH 32/36] updated configs --- pax/conf/experiment/ipd/mfos_v_ppo.yaml | 4 ++-- pax/conf/experiment/ipd/mfos_v_ppo_mem.yaml | 5 ++--- pax/conf/experiment/ipd/mfos_v_tabular.yaml | 4 ++-- 3 files changed, 6 insertions(+), 7 deletions(-) diff --git a/pax/conf/experiment/ipd/mfos_v_ppo.yaml b/pax/conf/experiment/ipd/mfos_v_ppo.yaml index afe945a6..73b2b89c 100644 --- a/pax/conf/experiment/ipd/mfos_v_ppo.yaml +++ b/pax/conf/experiment/ipd/mfos_v_ppo.yaml @@ -11,7 +11,7 @@ env_discount: 0.96 payoff: [[-1, -1], [-3, 0], [0, -3], [-2, -2]] # Runner -runner: rl +runner: evo # Training top_k: 5 @@ -22,7 +22,7 @@ num_steps: 10_000 num_inner_steps: 100 num_generations: 5000 total_timesteps: 1e11 -num_devices: 1 +num_devices: 2 # PPO agent parameters ppo: diff --git a/pax/conf/experiment/ipd/mfos_v_ppo_mem.yaml b/pax/conf/experiment/ipd/mfos_v_ppo_mem.yaml index 7516b559..4d97fce0 100644 --- a/pax/conf/experiment/ipd/mfos_v_ppo_mem.yaml +++ b/pax/conf/experiment/ipd/mfos_v_ppo_mem.yaml @@ -11,7 +11,7 @@ env_discount: 0.96 payoff: [[-1, -1], [-3, 0], [0, -3], [-2, -2]] # Runner -runner: rl +runner: evo # Training top_k: 5 @@ -22,8 +22,7 @@ num_steps: 10_000 num_inner_steps: 100 num_generations: 5000 total_timesteps: 1e11 -num_devices: 1 -runner: rl +num_devices: 2 # PPO agent parameters ppo: diff --git a/pax/conf/experiment/ipd/mfos_v_tabular.yaml b/pax/conf/experiment/ipd/mfos_v_tabular.yaml index 8d6e39e0..b9e073d3 100644 --- a/pax/conf/experiment/ipd/mfos_v_tabular.yaml +++ b/pax/conf/experiment/ipd/mfos_v_tabular.yaml @@ -11,7 +11,7 @@ env_discount: 0.96 payoff: [[-1, -1], [-3, 0], [0, -3], [-2, -2]] # Runner -runner: rl +runner: evo # Training top_k: 5 @@ -22,7 +22,7 @@ num_steps: 10_000 num_inner_steps: 100 num_generations: 5000 total_timesteps: 1e11 -num_devices: 1 +num_devices: 2 # MFOS vs. Tabular trained on seed = 0 run_path: ucl-dark/ipd/1r9txdso From c159bb8e027018041ff267ac808753ca45a3e8bf Mon Sep 17 00:00:00 2001 From: akbir Date: Mon, 31 Oct 2022 18:57:35 +0000 Subject: [PATCH 33/36] updated eval runner --- pax/runner_eval.py | 168 +++++++++++++++++++++++++++++++++++---------- pax/runner_rl.py | 5 +- 2 files changed, 131 insertions(+), 42 deletions(-) diff --git a/pax/runner_eval.py b/pax/runner_eval.py index e35c9775..7d0ff428 100644 --- a/pax/runner_eval.py +++ b/pax/runner_eval.py @@ -27,7 +27,7 @@ class Sample(NamedTuple): class EvalRunner: """Evaluation runner""" - def __init__(self, args): + def __init__(self, env, args): self.train_steps = 0 self.eval_steps = 0 self.train_episodes = 0 @@ -40,53 +40,93 @@ def __init__(self, args): self.model_path = args.model_path self.ipd_stats = jax.jit(ipd_visitation) self.cg_stats = jax.jit(cg_visitation) + # VMAP for num envs: we vmap over the rng but not params + env.reset = jax.vmap(env.reset, (0, None), 0) + env.step = jax.vmap( + env.step, (0, 0, 0, None), 0 # rng, state, actions, params + ) - def run_loop(self, env, agents, num_episodes, watchers): + # VMAP for num opps: we vmap over the rng but not params + env.reset = jax.jit(jax.vmap(env.reset, (0, None), 0)) + env.step = jax.jit( + jax.vmap( + env.step, (0, 0, 0, None), 0 # rng, state, actions, params + ) + ) + + self.split = jax.vmap(jax.vmap(jax.random.split, (0, None)), (0, None)) + + def run_loop(self, env, env_params, agents, num_episodes, watchers): def _inner_rollout(carry, unused): """Runner for inner episode""" - t1, t2, a1_state, a1_mem, a2_state, a2_mem, env_state = carry + ( + rngs, + obs1, + obs2, + r1, + r2, + a1_state, + a1_mem, + a2_state, + a2_mem, + env_state, + env_params, + ) = carry + + # unpack rngs + rngs = self.split(rngs, 4) + env_rng = rngs[:, :, 0, :] + # a1_rng = rngs[:, :, 1, :] + # a2_rng = rngs[:, :, 2, :] + rngs = rngs[:, :, 3, :] a1, a1_state, new_a1_mem = agent1.batch_policy( a1_state, - t1.observation, + obs1, a1_mem, ) a2, a2_state, new_a2_mem = agent2.batch_policy( a2_state, - t2.observation, + obs2, a2_mem, ) - (tprime_1, tprime_2), env_state = env.batch_step( - (a1, a2), + (next_obs1, next_obs2), env_state, rewards, done, info = env.step( + env_rng, env_state, + (a1, a2), + env_params, ) traj1 = Sample( - t1.observation, + obs1, a1, - tprime_1.reward, + rewards[0], new_a1_mem.extras["log_probs"], new_a1_mem.extras["values"], - tprime_1.last(), + done, a1_mem.hidden, ) traj2 = Sample( - t2.observation, + obs2, a2, - tprime_2.reward, + rewards[1], new_a2_mem.extras["log_probs"], new_a2_mem.extras["values"], - tprime_2.last(), + done, a2_mem.hidden, ) return ( - tprime_1, - tprime_2, + rngs, + next_obs1, + next_obs2, + rewards[0], + rewards[1], a1_state, new_a1_mem, a2_state, new_a2_mem, env_state, + env_params, ), ( traj1, traj2, @@ -94,34 +134,51 @@ def _inner_rollout(carry, unused): def _outer_rollout(carry, unused): """Runner for trial""" - t1, t2, a1_state, a1_mem, a2_state, a2_memory, env_state = carry # play episode of the game vals, trajectories = jax.lax.scan( _inner_rollout, carry, None, - length=env.inner_episode_length, + length=self.args.num_inner_steps, ) - - # update second agent - t1, t2, a1_state, a1_mem, a2_state, a2_memory, env_state = vals - - # MFOS has to takes a meta-action for each episode + ( + rngs, + obs1, + obs2, + r1, + r2, + a1_state, + a1_mem, + a2_state, + a2_mem, + env_state, + env_params, + ) = vals + # MFOS has to take a meta-action for each episode if self.args.agent1 == "MFOS": a1_mem = agent1.meta_policy(a1_mem) - final_t2 = t2._replace(step_type=2 * jnp.ones_like(t2.step_type)) - a2_state, a2_memory, a2_metrics = agent2.batch_update( - trajectories[1], final_t2, a2_state, a2_memory + # update second agent + a2_state, a2_mem, a2_metrics = agent2.batch_update( + trajectories[1], + obs2, + r2, + jnp.ones_like(r2, dtype=jnp.bool_), + a2_state, + a2_mem, ) return ( - t1, - t2, + rngs, + obs1, + obs2, + r1, + r2, a1_state, a1_mem, a2_state, - a2_memory, + a2_mem, env_state, + env_params, ), (*trajectories, a2_metrics) """Run training of agents in environment""" @@ -143,25 +200,58 @@ def _outer_rollout(carry, unused): num_iters = max(int(num_episodes / (env.num_envs * self.num_opps)), 1) log_interval = max(num_iters / MAX_WANDB_CALLS, 5) print(f"Log Interval {log_interval}") + + # RNG are the same for num_opps but different for num_envs + + rngs = jnp.concatenate( + [jax.random.split(rng, self.args.num_envs)] * self.args.num_opps + ).reshape((self.args.num_opps, self.args.num_envs, -1)) # run actual loop for i in range(num_episodes): - rng, rng_run = jax.random.split(rng) - t_init, env_state = env.runner_reset( - (self.num_opps, env.num_envs), rng_run - ) + obs, env_state = env.reset(rngs, env_params) + rewards = [ + jnp.zeros((self.args.num_opps, self.args.num_envs)), + jnp.zeros((self.args.num_opps, self.args.num_envs)), + ] if self.args.agent2 == "NaiveEx": - a2_state, a2_mem = agent2.batch_init(t_init[1]) - + a2_state, a2_mem = agent2.batch_init(obs[1]) + elif self.args.env_type in ["meta", "infinite"]: + # meta-experiments - init 2nd agent per trial + a2_state, a2_mem = agent2.batch_init( + jax.random.split(rng, self.num_opps), a2_mem.hidden + ) # run trials vals, stack = jax.lax.scan( _outer_rollout, - (*t_init, a1_state, a1_mem, a2_state, a2_mem, env_state), + ( + rngs, + *obs, + *rewards, + a1_state, + a1_mem, + a2_state, + a2_mem, + env_state, + env_params, + ), None, - length=env.outer_ep_length, + length=self.args.num_steps // self.args.num_inner_steps, ) - t1, t2, _, a1_mem, a2_state, a2_mem, env_state = vals + ( + rngs, + obs1, + obs2, + r1, + r2, + a1_state, + a1_mem, + a2_state, + a2_mem, + env_state, + env_params, + ) = vals traj_1, traj_2, a2_metrics = stack # reset second agent memory @@ -188,13 +278,15 @@ def _outer_rollout(carry, unused): self.ipd_stats( traj_1.observations, traj_1.actions, - t1.observation, + obs1, ), ) rewards_0 = traj_1.rewards.mean() rewards_1 = traj_2.rewards.mean() else: + rewards_0 = traj_1.rewards.mean() + rewards_1 = traj_2.rewards.mean() env_stats = {} print(f"Env Stats: {env_stats}") diff --git a/pax/runner_rl.py b/pax/runner_rl.py index cb3fdfad..e9af0824 100644 --- a/pax/runner_rl.py +++ b/pax/runner_rl.py @@ -1,6 +1,6 @@ import os import time -from typing import List, NamedTuple +from typing import NamedTuple import jax import jax.numpy as jnp @@ -202,7 +202,6 @@ def _outer_rollout(carry, unused): env_state, env_params, ) = vals - # MFOS has to take a meta-action for each episode if self.args.agent1 == "MFOS": a1_mem = agent1.meta_policy(a1_mem) @@ -267,7 +266,6 @@ def _outer_rollout(carry, unused): a2_state, a2_mem = agent2.batch_init( jax.random.split(rng, self.num_opps), a2_mem.hidden ) - # run trials vals, stack = jax.lax.scan( _outer_rollout, @@ -386,7 +384,6 @@ def _outer_rollout(carry, unused): | env_stats, ) - # update agents for eval loop exit agents.agents[0]._state = a1_state agents.agents[1]._state = a2_state return agents From 0299331749a63d8793a9b2178fb47df7b2d9acd9 Mon Sep 17 00:00:00 2001 From: akbir Date: Mon, 31 Oct 2022 19:12:36 +0000 Subject: [PATCH 34/36] updated to work for cg --- pax/conf/experiment/cg/sanity.yaml | 3 ++- pax/envs/coin_game.py | 7 +++---- pax/experiment.py | 5 ++--- 3 files changed, 7 insertions(+), 8 deletions(-) diff --git a/pax/conf/experiment/cg/sanity.yaml b/pax/conf/experiment/cg/sanity.yaml index d4b6e495..e1cda6a5 100644 --- a/pax/conf/experiment/cg/sanity.yaml +++ b/pax/conf/experiment/cg/sanity.yaml @@ -8,8 +8,9 @@ agent2: 'PPO_memory' env_id: coin_game env_type: sequential egocentric: True +cnn: False env_discount: 0.96 -payoff: [[-1, -1], [-3, 0], [0, -3], [-2, -2]] +payoff: [[1, 1, -2], [1, 1, -2]] runner: rl # Training hyperparameters diff --git a/pax/envs/coin_game.py b/pax/envs/coin_game.py index 32c9d47e..347df1b4 100644 --- a/pax/envs/coin_game.py +++ b/pax/envs/coin_game.py @@ -56,7 +56,7 @@ class EnvParams: class CoinGame(environment.Environment): """ - JAX Compatible version of matrix game environment. + JAX Compatible version of coin game environment. """ def __init__( @@ -168,7 +168,6 @@ def _relative_position(state: EnvState) -> jnp.ndarray: def _state_to_obs(state: EnvState) -> jnp.ndarray: if egocentric: - print("Running Egocentric") obs1 = _relative_position(state) # flip red and blue coins for second agent @@ -340,8 +339,8 @@ def _step( obs1 = jnp.where(done, reset_obs[0], obs1) obs2 = jnp.where(done, reset_obs[1], obs2) - blue_reward = jnp.where(done, 0, blue_reward) - red_reward = jnp.where(done, 0, red_reward) + blue_reward = jnp.where(done, 0.0, blue_reward) + red_reward = jnp.where(done, 0.0, red_reward) return ( (obs1, obs2), next_state, diff --git a/pax/experiment.py b/pax/experiment.py index 25a18127..48a1cc52 100644 --- a/pax/experiment.py +++ b/pax/experiment.py @@ -120,6 +120,8 @@ def env_setup(args, logger=None): raise ValueError(f"Unknown env type {args.env_type}") elif args.env_id == "coin_game": + payoff = jnp.array(args.payoff) + env_params = CoinGameParams(payoff_matrix=payoff) if args.env_type == "sequential": env = CoinGame( num_inner_steps=args.num_steps, @@ -127,7 +129,6 @@ def env_setup(args, logger=None): cnn=args.cnn, egocentric=args.egocentric, ) - env_params = CoinGameParams(args.payoff_matrix) else: env = CoinGame( num_inner_steps=args.num_inner_steps, @@ -135,8 +136,6 @@ def env_setup(args, logger=None): cnn=args.cnn, egocentric=args.egocentric, ) - env_params = CoinGameParams(args.payoff_matrix) - if logger: logger.info( f"Env Type: CoinGame | Episode Length: {args.num_steps}" From b3f8c501598d0f0fdae5a1b021ab43efa600269e Mon Sep 17 00:00:00 2001 From: akbir Date: Mon, 31 Oct 2022 19:25:43 +0000 Subject: [PATCH 35/36] updated to work for cg --- pax/envs/coin_game.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pax/envs/coin_game.py b/pax/envs/coin_game.py index 347df1b4..a3693951 100644 --- a/pax/envs/coin_game.py +++ b/pax/envs/coin_game.py @@ -221,7 +221,7 @@ def _step( new_blue_pos == state.blue_coin_pos, axis=-1 ) - ### [[1, -2],[1, -2] + ### [[1, 1, -2],[1, 1, -2]] _rr_reward = params.payoff_matrix[0][0] _rb_reward = params.payoff_matrix[0][1] _r_penalty = params.payoff_matrix[0][2] @@ -236,7 +236,7 @@ def _step( red_blue_matches, red_reward + _rb_reward, red_reward ) red_reward = jnp.where( - blue_red_matches, red_reward - _r_penalty, red_reward + blue_red_matches, red_reward + _r_penalty, red_reward ) blue_reward = jnp.where( @@ -246,7 +246,7 @@ def _step( blue_blue_matches, blue_reward + _bb_reward, blue_reward ) blue_reward = jnp.where( - red_blue_matches, blue_reward - _b_penalty, blue_reward + red_blue_matches, blue_reward + _b_penalty, blue_reward ) (counter, coop1, coop2, last_state) = _update_stats( From e637fb95f9fce4a9067d8c909dbd02ce17085414 Mon Sep 17 00:00:00 2001 From: akbir Date: Mon, 31 Oct 2022 21:51:03 +0000 Subject: [PATCH 36/36] fixed sequential bug --- pax/envs/coin_game.py | 32 +++++++++++++++++++------------- pax/experiment.py | 2 +- pax/runner_rl.py | 1 - pax/watchers.py | 38 ++++++++++++++++++++------------------ 4 files changed, 40 insertions(+), 33 deletions(-) diff --git a/pax/envs/coin_game.py b/pax/envs/coin_game.py index a3693951..414760b0 100644 --- a/pax/envs/coin_game.py +++ b/pax/envs/coin_game.py @@ -303,29 +303,35 @@ def _step( ) obs1, obs2 = _state_to_obs(next_state) + + # now calculate if done for inner or outer episode inner_t = next_state.inner_t outer_t = next_state.outer_t - done = inner_t % num_inner_steps == 0 + reset_inner = inner_t == num_inner_steps # if inner episode is done, return start state for next game reset_obs, reset_state = _reset(key, params) next_state = EnvState( red_pos=jnp.where( - done, reset_state.red_pos, next_state.red_pos + reset_inner, reset_state.red_pos, next_state.red_pos ), blue_pos=jnp.where( - done, reset_state.blue_pos, next_state.blue_pos + reset_inner, reset_state.blue_pos, next_state.blue_pos ), red_coin_pos=jnp.where( - done, reset_state.red_coin_pos, next_state.red_coin_pos + reset_inner, + reset_state.red_coin_pos, + next_state.red_coin_pos, ), blue_coin_pos=jnp.where( - done, reset_state.blue_coin_pos, next_state.blue_coin_pos + reset_inner, + reset_state.blue_coin_pos, + next_state.blue_coin_pos, ), inner_t=jnp.where( - done, jnp.zeros_like(inner_t), next_state.inner_t + reset_inner, jnp.zeros_like(inner_t), next_state.inner_t ), - outer_t=jnp.where(done, outer_t + 1, outer_t), + outer_t=jnp.where(reset_inner, outer_t + 1, outer_t), red_coop=next_state.red_coop, red_defect=next_state.red_defect, blue_coop=next_state.blue_coop, @@ -333,19 +339,19 @@ def _step( counter=counter, coop1=coop1, coop2=coop2, - last_state=jnp.where(done, jnp.zeros(2), last_state), + last_state=jnp.where(reset_inner, jnp.zeros(2), last_state), ) - obs1 = jnp.where(done, reset_obs[0], obs1) - obs2 = jnp.where(done, reset_obs[1], obs2) + obs1 = jnp.where(reset_inner, reset_obs[0], obs1) + obs2 = jnp.where(reset_inner, reset_obs[1], obs2) - blue_reward = jnp.where(done, 0.0, blue_reward) - red_reward = jnp.where(done, 0.0, red_reward) + blue_reward = jnp.where(reset_inner, 0.0, blue_reward) + red_reward = jnp.where(reset_inner, 0.0, red_reward) return ( (obs1, obs2), next_state, (red_reward, blue_reward), - done, + reset_inner, {"discount": jnp.zeros((), dtype=jnp.int8)}, ) diff --git a/pax/experiment.py b/pax/experiment.py index 48a1cc52..8f289393 100644 --- a/pax/experiment.py +++ b/pax/experiment.py @@ -125,7 +125,7 @@ def env_setup(args, logger=None): if args.env_type == "sequential": env = CoinGame( num_inner_steps=args.num_steps, - num_outer_steps=args.num_steps, + num_outer_steps=1, cnn=args.cnn, egocentric=args.egocentric, ) diff --git a/pax/runner_rl.py b/pax/runner_rl.py index e9af0824..259b3d55 100644 --- a/pax/runner_rl.py +++ b/pax/runner_rl.py @@ -75,7 +75,6 @@ def _reshape_opp_dim(x): self.reduce_opp_dim = jax.jit(_reshape_opp_dim) self.ipd_stats = jax.jit(ipd_visitation) self.cg_stats = jax.jit(cg_visitation) - # VMAP for num envs: we vmap over the rng but not params env.reset = jax.vmap(env.reset, (0, None), 0) env.step = jax.vmap( diff --git a/pax/watchers.py b/pax/watchers.py index 0078bca6..e7b3b86d 100644 --- a/pax/watchers.py +++ b/pax/watchers.py @@ -425,30 +425,32 @@ def ipd_visitation( } -def cg_visitation(env_state: NamedTuple) -> dict: - # env state : [num_opps, num_envs, num_episodes] - env_state = jax.tree_util.tree_map( - lambda x: x.reshape(-1, x.shape[-1]), env_state +def cg_visitation(state: NamedTuple) -> dict: + # [num_opps, num_envs, num_outer_episodes] + total_1 = state.red_coop + state.red_defect + total_2 = state.blue_coop + state.blue_defect + avg_prob_1 = jnp.sum(state.red_coop, axis=-1) / jnp.sum(total_1, axis=-1) + avg_prob_2 = jnp.sum(state.blue_coop, axis=-1) / jnp.sum(total_2, axis=-1) + final_prob_1 = state.red_coop[:, :, -1] / total_1[:, :, -1] + final_prob_2 = state.blue_coop[:, :, -1] / total_2[:, :, -1] + + # [num_opps, num_envs, num_states] + prob_coop_1 = jnp.sum(state.coop1, axis=(0, 1)) / jnp.sum( + state.counter, axis=(0, 1) ) - - total_1 = env_state.red_coop + env_state.red_defect - total_2 = env_state.blue_coop + env_state.blue_defect - - prob_1 = env_state.red_coop / total_1 - prob_2 = env_state.blue_coop / total_2 - - prob_coop_1 = jnp.nanmean(env_state.coop1 / env_state.counter, axis=0) - prob_coop_2 = jnp.nanmean(env_state.coop2 / env_state.counter, axis=0) - count = jnp.nanmean(env_state.counter, axis=0) + prob_coop_2 = jnp.sum(state.coop2, axis=(0, 1)) / jnp.sum( + state.counter, axis=(0, 1) + ) + count = jnp.nanmean(state.counter, axis=0) return { - "prob_coop/1": jnp.nanmean(prob_1, axis=0), # [num_episodes] - "prob_coop/2": jnp.nanmean(prob_2, axis=0), # [num_episodes] + "prob_coop/1": jnp.nanmean(avg_prob_1), # [1] + "prob_coop/2": jnp.nanmean(avg_prob_2), # [1] + "final_prob_coop/1": jnp.nanmean(final_prob_1), # [1] + "final_prob_coop/2": jnp.nanmean(final_prob_2), # [1] "total_coins/1": total_1.sum(), # int "total_coins/2": total_2.sum(), # int "coins_per_episode/1": total_1.mean(axis=0), # [num_episodes] "coins_per_episode/2": total_2.mean(axis=0), # [num_episodes] - "final_prob_coop/1": jnp.nanmean(prob_1, axis=0)[-1], # [1] - "final_prob_coop/2": jnp.nanmean(prob_2, axis=0)[-1], # [1] "final_coin_total/1": total_1.mean(axis=0)[-1], # [1] "final_coin_total/2": total_2.mean(axis=0)[-1], # [1] "cooperation_probability/1/SS": prob_coop_1[0],