Skip to content

Commit 284683c

Browse files
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent e32d624 commit 284683c

3 files changed

Lines changed: 11 additions & 15 deletions

File tree

sheeprl/algos/dreamer_v3/agent.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -83,9 +83,7 @@ def __init__(
8383
layer_args={"kernel_size": 4, "stride": 2, "padding": 1, "bias": layer_norm_cls == nn.Identity},
8484
activation=activation,
8585
norm_layer=[layer_norm_cls] * stages,
86-
norm_args=[
87-
{**layer_norm_kw, "normalized_shape": (2**i) * channels_multiplier} for i in range(stages)
88-
],
86+
norm_args=[{**layer_norm_kw, "normalized_shape": (2**i) * channels_multiplier} for i in range(stages)],
8987
),
9088
nn.Flatten(-3, -1),
9189
)

sheeprl/data/buffers.py

Lines changed: 8 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -135,12 +135,10 @@ def to_tensor(
135135
return buf
136136

137137
@typing.overload
138-
def add(self, data: "ReplayBuffer", validate_args: bool = False) -> None:
139-
...
138+
def add(self, data: "ReplayBuffer", validate_args: bool = False) -> None: ...
140139

141140
@typing.overload
142-
def add(self, data: Dict[str, np.ndarray], validate_args: bool = False) -> None:
143-
...
141+
def add(self, data: Dict[str, np.ndarray], validate_args: bool = False) -> None: ...
144142

145143
def add(self, data: "ReplayBuffer" | Dict[str, np.ndarray], validate_args: bool = False) -> None:
146144
"""Add data to the replay buffer. If the replay buffer is full, then the oldest data is overwritten.
@@ -617,12 +615,10 @@ def __len__(self) -> int:
617615
return self.buffer_size
618616

619617
@typing.overload
620-
def add(self, data: "ReplayBuffer", validate_args: bool = False) -> None:
621-
...
618+
def add(self, data: "ReplayBuffer", validate_args: bool = False) -> None: ...
622619

623620
@typing.overload
624-
def add(self, data: Dict[str, np.ndarray], validate_args: bool = False) -> None:
625-
...
621+
def add(self, data: Dict[str, np.ndarray], validate_args: bool = False) -> None: ...
626622

627623
def add(
628624
self,
@@ -860,17 +856,17 @@ def __len__(self) -> int:
860856
return self._cum_lengths[-1] if len(self._buf) > 0 else 0
861857

862858
@typing.overload
863-
def add(self, data: "ReplayBuffer", env_idxes: Sequence[int] | None = None, validate_args: bool = False) -> None:
864-
...
859+
def add(
860+
self, data: "ReplayBuffer", env_idxes: Sequence[int] | None = None, validate_args: bool = False
861+
) -> None: ...
865862

866863
@typing.overload
867864
def add(
868865
self,
869866
data: Dict[str, np.ndarray],
870867
env_idxes: Sequence[int] | None = None,
871868
validate_args: bool = False,
872-
) -> None:
873-
...
869+
) -> None: ...
874870

875871
def add(
876872
self,

sheeprl/utils/distribution.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -307,6 +307,7 @@ class OneHotCategoricalValidateArgs(Distribution):
307307
probs (Tensor): event probabilities
308308
logits (Tensor): event log probabilities (unnormalized)
309309
"""
310+
310311
arg_constraints = {"probs": constraints.simplex, "logits": constraints.real_vector}
311312
support = constraints.one_hot
312313
has_enumerate_support = True
@@ -391,6 +392,7 @@ class OneHotCategoricalStraightThroughValidateArgs(OneHotCategoricalValidateArgs
391392
[1] Estimating or Propagating Gradients Through Stochastic Neurons for Conditional Computation
392393
(Bengio et al, 2013)
393394
"""
395+
394396
has_rsample = True
395397

396398
def rsample(self, sample_shape=torch.Size()):

0 commit comments

Comments
 (0)