diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index c3ec86c4f..0e93263bd 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,6 +1,6 @@ repos: - repo: https://github.com/PyCQA/autoflake - rev: v2.2.1 + rev: v2.3.3 hooks: - id: autoflake name: Remove unused variables and imports @@ -18,7 +18,7 @@ repos: files: \.py$ - repo: https://github.com/PyCQA/isort - rev: 5.13.2 + rev: 8.0.1 hooks: - id: isort name: (isort) Sorting import statements @@ -27,8 +27,8 @@ repos: types: [python] files: \.py$ - - repo: https://github.com/psf/black - rev: 23.12.1 + - repo: https://github.com/psf/black-pre-commit-mirror + rev: 26.3.0 hooks: - id: black name: (black) Format Python code @@ -43,7 +43,7 @@ repos: types: [jupyter] - repo: https://github.com/astral-sh/ruff-pre-commit - rev: "v0.1.11" + rev: "v0.15.5" hooks: - id: ruff args: ["--config", "pyproject.toml", "--fix", "./sheeprl"] diff --git a/notebooks/dreamer_v3_imagination.ipynb b/notebooks/dreamer_v3_imagination.ipynb index 5545c47a5..40ebbd745 100644 --- a/notebooks/dreamer_v3_imagination.ipynb +++ b/notebooks/dreamer_v3_imagination.ipynb @@ -60,7 +60,6 @@ "import torchvision\n", "from lightning.fabric import Fabric\n", "from omegaconf import OmegaConf\n", - "from PIL import Image\n", "\n", "from sheeprl.algos.dreamer_v3.agent import build_agent\n", "from sheeprl.data.buffers import SequentialReplayBuffer\n", diff --git a/sheeprl/algos/dreamer_v3/agent.py b/sheeprl/algos/dreamer_v3/agent.py index 02a40a8a5..2a5064204 100644 --- a/sheeprl/algos/dreamer_v3/agent.py +++ b/sheeprl/algos/dreamer_v3/agent.py @@ -83,9 +83,7 @@ def __init__( layer_args={"kernel_size": 4, "stride": 2, "padding": 1, "bias": layer_norm_cls == nn.Identity}, activation=activation, norm_layer=[layer_norm_cls] * stages, - norm_args=[ - {**layer_norm_kw, "normalized_shape": (2**i) * channels_multiplier} for i in range(stages) - ], + norm_args=[{**layer_norm_kw, "normalized_shape": (2**i) * channels_multiplier} for i in range(stages)], ), nn.Flatten(-3, -1), ) diff --git a/sheeprl/algos/p2e_dv3/p2e_dv3_finetuning.py b/sheeprl/algos/p2e_dv3/p2e_dv3_finetuning.py index d825ac24a..482d16a78 100644 --- a/sheeprl/algos/p2e_dv3/p2e_dv3_finetuning.py +++ b/sheeprl/algos/p2e_dv3/p2e_dv3_finetuning.py @@ -129,7 +129,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]): fabric.print("Decoder MLP keys:", cfg.algo.mlp_keys.decoder) obs_keys = cfg.algo.cnn_keys.encoder + cfg.algo.mlp_keys.encoder - (world_model, _, actor_task, critic_task, target_critic_task, actor_exploration, _, player) = build_agent( + world_model, _, actor_task, critic_task, target_critic_task, actor_exploration, _, player = build_agent( fabric, actions_dim, is_continuous, diff --git a/sheeprl/algos/sac/loss.py b/sheeprl/algos/sac/loss.py index eaf0c73ce..bca0a5449 100644 --- a/sheeprl/algos/sac/loss.py +++ b/sheeprl/algos/sac/loss.py @@ -1,5 +1,4 @@ -"""Based on "Soft Actor-Critic Algorithms and Applications": https://arxiv.org/abs/1812.05905 -""" +"""Based on "Soft Actor-Critic Algorithms and Applications": https://arxiv.org/abs/1812.05905""" from numbers import Number diff --git a/sheeprl/data/buffers.py b/sheeprl/data/buffers.py index 8c51c9b19..e51afadd7 100644 --- a/sheeprl/data/buffers.py +++ b/sheeprl/data/buffers.py @@ -135,12 +135,10 @@ def to_tensor( return buf @typing.overload - def add(self, data: "ReplayBuffer", validate_args: bool = False) -> None: - ... + def add(self, data: "ReplayBuffer", validate_args: bool = False) -> None: ... @typing.overload - def add(self, data: Dict[str, np.ndarray], validate_args: bool = False) -> None: - ... + def add(self, data: Dict[str, np.ndarray], validate_args: bool = False) -> None: ... def add(self, data: "ReplayBuffer" | Dict[str, np.ndarray], validate_args: bool = False) -> None: """Add data to the replay buffer. If the replay buffer is full, then the oldest data is overwritten. @@ -617,12 +615,10 @@ def __len__(self) -> int: return self.buffer_size @typing.overload - def add(self, data: "ReplayBuffer", validate_args: bool = False) -> None: - ... + def add(self, data: "ReplayBuffer", validate_args: bool = False) -> None: ... @typing.overload - def add(self, data: Dict[str, np.ndarray], validate_args: bool = False) -> None: - ... + def add(self, data: Dict[str, np.ndarray], validate_args: bool = False) -> None: ... def add( self, @@ -860,8 +856,9 @@ def __len__(self) -> int: return self._cum_lengths[-1] if len(self._buf) > 0 else 0 @typing.overload - def add(self, data: "ReplayBuffer", env_idxes: Sequence[int] | None = None, validate_args: bool = False) -> None: - ... + def add( + self, data: "ReplayBuffer", env_idxes: Sequence[int] | None = None, validate_args: bool = False + ) -> None: ... @typing.overload def add( @@ -869,8 +866,7 @@ def add( data: Dict[str, np.ndarray], env_idxes: Sequence[int] | None = None, validate_args: bool = False, - ) -> None: - ... + ) -> None: ... def add( self, diff --git a/sheeprl/optim/rmsprop_tf.py b/sheeprl/optim/rmsprop_tf.py index 063ef6b47..754693067 100644 --- a/sheeprl/optim/rmsprop_tf.py +++ b/sheeprl/optim/rmsprop_tf.py @@ -1,4 +1,4 @@ -""" RMSProp modified to behave like Tensorflow impl +"""RMSProp modified to behave like Tensorflow impl Originally cut & paste from PyTorch RMSProp https://github.com/pytorch/pytorch/blob/063946d2b3f3f1e953a2a3b54e0b34f1393de295/torch/optim/rmsprop.py