Changes from all commits (48 commits)
ac543ba  callback: use logging instance (hasan-yaman, Dec 24, 2024)
f464c04  Merge pull request #2 from peak/validation-loss (hasan-yaman, Dec 24, 2024)
25f80ce  Revert "callback: use logging instance" (hasan-yaman, Dec 24, 2024)
961bea0  Merge pull request #3 from peak/revert-2-validation-loss (hasan-yaman, Dec 24, 2024)
f287fe5  patch: logging update (hasan-yaman, Dec 24, 2024)
ca636c6  patch: transformation update (for data augmentation) (hasan-yaman, Feb 4, 2025)
4a9b1df  patch: update for import (hasan-yaman, Feb 4, 2025)
60904dc  transformer: update logging & add data augmentation option (hasan-yaman, Feb 20, 2025)
4d9c5ac  bc: add entropy loss option (hasan-yaman, Mar 6, 2025)
262de28  torch_utility: fix map_location (hasan-yaman, Mar 9, 2025)
e1872e9  fix logging commit order (hasan-yaman, Mar 13, 2025)
8f3cf55  add embedding input for bc (hasan-yaman, Mar 21, 2025)
28f74ae  add embedding size to bc config (hasan-yaman, Mar 21, 2025)
fbbc19b  pass embedding explicity (hasan-yaman, Mar 21, 2025)
4f31f9c  add embedding to CategoricalPolicy call (hasan-yaman, Mar 21, 2025)
3d75cd6  pass embedding to predict functions (hasan-yaman, Mar 21, 2025)
16fd821  compute output size for sequential model (hasan-yaman, Mar 25, 2025)
062ec88  compute output size update (hasan-yaman, Mar 25, 2025)
18a92ef  trajectory: add embeddings (hasan-yaman, Mar 25, 2025)
21e1bbb  update compute output size (hasan-yaman, Mar 25, 2025)
04f1b95  episode: handle null embeddings (hasan-yaman, Mar 25, 2025)
55a8796  fix null embeddings (hasan-yaman, Mar 25, 2025)
1f6593a  transformer embedding (hasan-yaman, Mar 25, 2025)
1ec4e7a  add embeddings to transformer (hasan-yaman, Mar 25, 2025)
0b532ce  transformer embedding flatten (hasan-yaman, Mar 25, 2025)
22a2800  stateful transformer: add embedding (hasan-yaman, Apr 4, 2025)
c7ccbe7  embedding convert fix (hasan-yaman, Apr 4, 2025)
362fd84  embedding fix (hasan-yaman, Apr 4, 2025)
c09bea3  Merge branch 'takuseno:master' into master (hasan-yaman, Aug 20, 2025)
0e6e9d8  Merge branch 'master' into patch-v2 (hasan-yaman, Aug 20, 2025)
02eee74  Merge branch 'patch-v2' of github.com:peak/d3rlpy into patch-v2 (hasan-yaman, Aug 20, 2025)
4b76693  fix typing import (hasan-yaman, Aug 20, 2025)
7e5e6dc  fix import (hasan-yaman, Aug 20, 2025)
91977ce  fix loss (hasan-yaman, Aug 20, 2025)
17807d2  remove unused imports (hasan-yaman, Aug 21, 2025)
c04352a  fix (hasan-yaman, Aug 21, 2025)
aa5c559  f (hasan-yaman, Aug 21, 2025)
a22d0e0  embeddings at the end (hasan-yaman, Aug 21, 2025)
d2dad91  format (hasan-yaman, Aug 21, 2025)
1fe4b18  format too (hasan-yaman, Aug 21, 2025)
6c21c83  Merge pull request #4 from peak/patch-v2 (hasan-yaman, Aug 21, 2025)
e0e7ca1  fix call ordering (hasan-yaman, Aug 21, 2025)
dbd98b1  Merge pull request #5 from peak/fix-v2 (hasan-yaman, Aug 22, 2025)
f16605d  bc: automatic mixed precision (hasan-yaman, Aug 28, 2025)
c2073cd  save grad scaler and more logs for measuring time spent (hasan-yaman, Aug 29, 2025)
b127274  f (hasan-yaman, Aug 29, 2025)
8015f8e8  Merge pull request #6 from peak/mixed-precision-training (hasan-yaman, Aug 29, 2025)
bc3e885  bc: label smoothing (hasan-yaman, Sep 1, 2025)
1 change: 1 addition & 0 deletions .gitignore
@@ -22,3 +22,4 @@ dist
/.idea/
*.egg-info
/.vscode/
venv
6 changes: 6 additions & 0 deletions d3rlpy/__init__.py
@@ -26,6 +26,10 @@
from .constants import ActionSpace, LoggingStrategy, PositionEncodingType
from .healthcheck import run_healthcheck
from .torch_utility import Modules, TorchMiniBatch
from .transformation import (
TransformationProtocol,
register_transformation_callable,
)

__all__ = [
"algos",
@@ -50,6 +54,8 @@
"Modules",
"TorchMiniBatch",
"seed",
"TransformationProtocol",
"register_transformation_callable",
]


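The new top-level exports hook user-defined data augmentation into training: as the base.py diff below shows, `update()` now applies `self._config.transform` to each `TorchMiniBatch` before the gradient step. A minimal transform sketch, assuming `TransformationProtocol` is simply a callable from `TorchMiniBatch` to `TorchMiniBatch` and that `TorchMiniBatch` is a dataclass with a single-tensor `observations` field; the exact `register_transformation_callable` signature is not visible in this diff, so the registration call is a guess:

```python
import dataclasses

import torch

import d3rlpy
from d3rlpy.torch_utility import TorchMiniBatch


def gaussian_obs_noise(batch: TorchMiniBatch) -> TorchMiniBatch:
    # Illustrative augmentation: jitter observations with small Gaussian noise.
    # Assumes single-tensor observations; tuple observations would need a loop.
    noisy = batch.observations + 0.01 * torch.randn_like(batch.observations)
    return dataclasses.replace(batch, observations=noisy)


# Assumed (name, callable) signature; only the export is shown in this diff.
d3rlpy.register_transformation_callable("gaussian_obs_noise", gaussian_obs_noise)
```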
81 changes: 52 additions & 29 deletions d3rlpy/algos/qlearning/base.py
@@ -45,7 +45,13 @@
sync_optimizer_state,
train_api,
)
from ...types import GymEnv, NDArray, Observation, TorchObservation
from ...types import (
Float32NDArray,
GymEnv,
NDArray,
Observation,
TorchObservation,
)
from ..utility import (
assert_action_space_with_dataset,
assert_action_space_with_env,
@@ -74,11 +80,15 @@ def inner_update(
pass

@eval_api
def predict_best_action(self, x: TorchObservation) -> torch.Tensor:
return self.inner_predict_best_action(x)
def predict_best_action(
self, x: TorchObservation, embedding: Optional[torch.Tensor]
) -> torch.Tensor:
return self.inner_predict_best_action(x, embedding)

@abstractmethod
def inner_predict_best_action(self, x: TorchObservation) -> torch.Tensor:
def inner_predict_best_action(
self, x: TorchObservation, embedding: Optional[torch.Tensor]
) -> torch.Tensor:
pass

@eval_api
@@ -222,7 +232,7 @@ def _func(*x: Sequence[torch.Tensor]) -> torch.Tensor:
observation
)

action = self._impl.predict_best_action(observation)
action = self._impl.predict_best_action(observation, None)

if self._config.action_scaler:
action = self._config.action_scaler.reverse_transform(action)
@@ -253,7 +263,9 @@ def _func(*x: Sequence[torch.Tensor]) -> torch.Tensor:
# workaround until version 1.6
self._impl.modules.unfreeze()

def predict(self, x: Observation) -> NDArray:
def predict(
self, x: Observation, embedding: Optional[Float32NDArray]
) -> NDArray:
"""Returns greedy actions.

.. code-block:: python
@@ -275,12 +287,17 @@ def predict(self, x: Observation) -> NDArray:
assert check_non_1d_array(x), "Input must have batch dimension."

torch_x = convert_to_torch_recursively(x, self._device)
torch_embedding = (
None
if embedding is None
else convert_to_torch_recursively(embedding, self._device)
)

with torch.no_grad():
if self._config.observation_scaler:
torch_x = self._config.observation_scaler.transform(torch_x)

action = self._impl.predict_best_action(torch_x)
action = self._impl.predict_best_action(torch_x, torch_embedding)

if self._config.action_scaler:
action = self._config.action_scaler.reverse_transform(action)
@@ -508,7 +525,7 @@ def fitter(
# setup logger
if experiment_name is None:
experiment_name = self.__class__.__name__
logger = D3RLPyLogger(
self.logger = D3RLPyLogger(
algo=self,
adapter_factory=logger_adapter,
experiment_name=experiment_name,
@@ -517,7 +534,7 @@ )
)

# save hyperparameters
save_config(self, logger)
save_config(self, self.logger)

# training loop
n_epochs = n_steps // n_steps_per_epoch
@@ -533,20 +550,20 @@ )
)

for itr in range_gen:
with logger.measure_time("step"):
with self.logger.measure_time("step"):
# pick transitions
with logger.measure_time("sample_batch"):
with self.logger.measure_time("sample_batch"):
batch = dataset.sample_transition_batch(
self._config.batch_size
)

# update parameters
with logger.measure_time("algorithm_update"):
with self.logger.measure_time("algorithm_update"):
loss = self.update(batch)

# record metrics
for name, val in loss.items():
logger.add_metric(name, val)
self.logger.add_metric(name, val)
epoch_loss[name].append(val)

# update progress postfix with losses
@@ -562,7 +579,7 @@
logging_strategy == LoggingStrategy.STEPS
and total_step % logging_steps == 0
):
metrics = logger.commit(epoch, total_step)
metrics = self.logger.commit(epoch, total_step)

# call callback if given
if callback:
@@ -575,19 +592,18 @@
if evaluators:
for name, evaluator in evaluators.items():
test_score = evaluator(self, dataset)
logger.add_metric(name, test_score)
self.logger.add_metric(name, test_score)

# save metrics
if logging_strategy == LoggingStrategy.EPOCH:
metrics = logger.commit(epoch, total_step)
metrics = self.logger.commit(epoch, total_step)

# save model parameters
if epoch % save_interval == 0:
logger.save_model(total_step, self)
self.logger.save_model(total_step, self)

yield epoch, metrics

logger.close()
self.logger.close()

def fit_online(
self,
@@ -866,16 +882,23 @@ def update(self, batch: TransitionMiniBatch) -> dict[str, float]:
Dictionary of metrics.
"""
assert self._impl, IMPL_NOT_INITIALIZED_ERROR
torch_batch = TorchMiniBatch.from_batch(
batch=batch,
gamma=self._config.gamma,
compute_returns_to_go=self.need_returns_to_go,
device=self._device,
observation_scaler=self._config.observation_scaler,
action_scaler=self._config.action_scaler,
reward_scaler=self._config.reward_scaler,
)
loss = self._impl.update(torch_batch, self._grad_step)
with self.logger.measure_time("algorithm_update_mini_batch"):
torch_batch = TorchMiniBatch.from_batch(
batch=batch,
gamma=self._config.gamma,
compute_returns_to_go=self.need_returns_to_go,
device=self._device,
observation_scaler=self._config.observation_scaler,
action_scaler=self._config.action_scaler,
reward_scaler=self._config.reward_scaler,
)

with self.logger.measure_time("algorithm_update_augmentation"):
if self._config.transform:
torch_batch = self._config.transform(torch_batch)

with self.logger.measure_time("algorithm_update_update"):
loss = self._impl.update(torch_batch, self._grad_step)
self._grad_step += 1
return loss

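Because `predict_best_action` and `predict` gained an explicit embedding parameter, every caller must now pass an embedding or `None`. A hedged usage sketch with a toy dataset (values are meaningless; that a default-configured algo simply takes `None` here follows from the `predict_best_action(observation, None)` call above):

```python
import numpy as np

import d3rlpy

# Toy dataset so the sketch is self-contained.
observations = np.random.rand(100, 8).astype(np.float32)
actions = np.random.randint(4, size=(100, 1))
rewards = np.random.rand(100).astype(np.float32)
terminals = np.zeros(100, dtype=np.float32)
terminals[-1] = 1.0
dataset = d3rlpy.dataset.MDPDataset(observations, actions, rewards, terminals)

bc = d3rlpy.algos.DiscreteBCConfig().create()
bc.build_with_dataset(dataset)

# The second argument is the new embedding input; None means "no embedding".
greedy_actions = bc.predict(observations[:2], None)
```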
11 changes: 11 additions & 0 deletions d3rlpy/algos/qlearning/bc.py
@@ -1,4 +1,5 @@
import dataclasses
from typing import Optional

from ...base import DeviceArg, LearnableConfig, register_learnable
from ...constants import ActionSpace
@@ -149,6 +150,11 @@ class DiscreteBCConfig(LearnableConfig):
optim_factory: OptimizerFactory = make_optimizer_field()
encoder_factory: EncoderFactory = make_encoder_field()
beta: float = 0.5
entropy_beta: float = 0.0
embedding_size: Optional[int] = None
automatic_mixed_precision: bool = False
scheduler_on_train_step: bool = True
label_smoothing: float = 0.0

def create(
self, device: DeviceArg = False, enable_ddp: bool = False
@@ -170,6 +176,7 @@ def inner_create_impl(
self._config.encoder_factory,
device=self._device,
enable_ddp=self._enable_ddp,
embedding_size=self._config.embedding_size,
)

optim = self._config.optim_factory.create(
@@ -185,8 +192,12 @@ def inner_create_impl(
action_size=action_size,
modules=modules,
beta=self._config.beta,
entropy_beta=self._config.entropy_beta,
compiled=self.compiled,
device=self._device,
automatic_mixed_precision=self._config.automatic_mixed_precision,
scheduler_on_train_step=self._config.scheduler_on_train_step,
label_smoothing=self._config.label_smoothing,
)

def get_action_type(self) -> ActionSpace:
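For reference, a configuration sketch exercising the new `DiscreteBCConfig` fields added in this PR (values are illustrative, not recommendations; the defaults shown in the diff are `entropy_beta=0.0`, `embedding_size=None`, `automatic_mixed_precision=False`, `scheduler_on_train_step=True`, `label_smoothing=0.0`):

```python
from d3rlpy.algos import DiscreteBCConfig

config = DiscreteBCConfig(
    beta=0.5,                        # weight of the logit L2 penalty
    entropy_beta=0.01,               # weight of the entropy bonus
    embedding_size=16,               # enables the extra embedding input
    automatic_mixed_precision=True,  # autocast + GradScaler in the update
    scheduler_on_train_step=True,    # step the LR scheduler every gradient step
    label_smoothing=0.1,             # label smoothing for the cross-entropy loss
)
bc = config.create(device="cuda:0")
```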
84 changes: 73 additions & 11 deletions d3rlpy/algos/qlearning/torch/bc_impl.py
@@ -1,9 +1,10 @@
import dataclasses
from abc import ABCMeta, abstractmethod
from typing import Callable, Union
from typing import Callable, Optional, Union

import torch
from torch.optim import Optimizer
import torch.nn.functional as F

from ....dataclass_utils import asdict_as_float
from ....models.torch import (
@@ -56,7 +57,9 @@ def __init__(

def compute_imitator_grad(self, batch: TorchMiniBatch) -> ImitationLoss:
self._modules.optim.zero_grad()
loss = self.compute_loss(batch.observations, batch.actions)
loss = self.compute_loss(
batch.observations, batch.actions, batch.embeddings
)
loss.loss.backward()
return loss

@@ -67,12 +70,17 @@ def update_imitator(self, batch: TorchMiniBatch) -> dict[str, float]:

@abstractmethod
def compute_loss(
self, obs_t: TorchObservation, act_t: torch.Tensor
self,
obs_t: TorchObservation,
act_t: torch.Tensor,
embedding_t: Optional[torch.Tensor],
) -> ImitationLoss:
pass

def inner_sample_action(self, x: TorchObservation) -> torch.Tensor:
return self.inner_predict_best_action(x)
def inner_sample_action(
self, x: TorchObservation, embedding: Optional[torch.Tensor] = None
) -> torch.Tensor:
return self.inner_predict_best_action(x, embedding)

def inner_predict_value(
self, x: TorchObservation, action: torch.Tensor
@@ -112,11 +120,16 @@ def __init__(
)
self._policy_type = policy_type

def inner_predict_best_action(self, x: TorchObservation) -> torch.Tensor:
return self._modules.imitator(x).squashed_mu
def inner_predict_best_action(
self, x: TorchObservation, embedding: Optional[torch.Tensor]
) -> torch.Tensor:
return self._modules.imitator(x, embedding).squashed_mu

def compute_loss(
self, obs_t: TorchObservation, act_t: torch.Tensor
self,
obs_t: TorchObservation,
act_t: torch.Tensor,
embedding_t: Optional[torch.Tensor],
) -> ImitationLoss:
if self._policy_type == "deterministic":
assert isinstance(self._modules.imitator, DeterministicPolicy)
@@ -148,15 +161,20 @@ class DiscreteBCModules(BCBaseModules):
class DiscreteBCImpl(BCBaseImpl):
_modules: DiscreteBCModules
_beta: float
_entropy_beta: float

def __init__(
self,
observation_shape: Shape,
action_size: int,
modules: DiscreteBCModules,
beta: float,
entropy_beta: float,
compiled: bool,
device: str,
automatic_mixed_precision: bool,
scheduler_on_train_step: bool,
label_smoothing: float
):
super().__init__(
observation_shape=observation_shape,
@@ -166,16 +184,60 @@ def __init__(
device=device,
)
self._beta = beta
self._entropy_beta = entropy_beta
self.automatic_mixed_precision = automatic_mixed_precision
self.scheduler_on_train_step = scheduler_on_train_step
self.label_smoothing = label_smoothing
self.grad_scaler = torch.amp.GradScaler(enabled=self.automatic_mixed_precision)

def inner_predict_best_action(
self, x: TorchObservation, embedding: Optional[torch.Tensor]
) -> torch.Tensor:
return self._modules.imitator(x, embedding).logits.argmax(dim=1)

def compute_imitator_grad(self, batch: TorchMiniBatch) -> ImitationLoss:
self._modules.optim.zero_grad()
with torch.autocast(device_type=self.device, enabled=self.automatic_mixed_precision):
dist = self._modules.imitator(batch.observations, batch.embeddings)
imitation_loss = F.cross_entropy(dist.logits, batch.actions.view(-1).long(), label_smoothing=self.label_smoothing)

penalty = (dist.logits ** 2).mean()
regularization_loss = self._beta * penalty

entropy_loss = -self._entropy_beta * dist.entropy().mean()

def inner_predict_best_action(self, x: TorchObservation) -> torch.Tensor:
return self._modules.imitator(x).logits.argmax(dim=1)
loss = imitation_loss + regularization_loss + entropy_loss

self.grad_scaler.scale(loss).backward()

if self._modules.optim._clip_grad_norm:
self.grad_scaler.unscale_(self._modules.optim.optim)
torch.nn.utils.clip_grad_norm_(self._modules.imitator.parameters(), max_norm=self._modules.optim._clip_grad_norm)

self.grad_scaler.step(self._modules.optim.optim)

self.grad_scaler.update()

if self._modules.optim._lr_scheduler and self.scheduler_on_train_step:
self._modules.optim._lr_scheduler.step()

return DiscreteImitationLoss(loss=loss, imitation_loss=imitation_loss, regularization_loss=regularization_loss, entropy_loss=entropy_loss)

def update_imitator(self, batch: TorchMiniBatch) -> dict[str, float]:
loss = self._compute_imitator_grad(batch)
return asdict_as_float(loss)

def compute_loss(
self, obs_t: TorchObservation, act_t: torch.Tensor
self,
obs_t: TorchObservation,
act_t: torch.Tensor,
embedding_t: Optional[torch.Tensor],
) -> DiscreteImitationLoss:
return compute_discrete_imitation_loss(
policy=self._modules.imitator,
x=obs_t,
embedding=embedding_t,
action=act_t.long(),
beta=self._beta,
entropy_beta=self._entropy_beta,
)
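For clarity, the composite loss assembled in `compute_imitator_grad` above can be written standalone as below; representing the policy output with `torch.distributions.Categorical` is an assumption, since the diff only shows `dist.logits` and `dist.entropy()`:

```python
import torch
import torch.nn.functional as F


def discrete_bc_loss(
    logits: torch.Tensor,   # (N, num_actions) policy logits
    actions: torch.Tensor,  # (N,) or (N, 1) integer actions
    beta: float,            # logit L2 regularization weight
    entropy_beta: float,    # entropy bonus weight
    label_smoothing: float,
) -> torch.Tensor:
    imitation = F.cross_entropy(
        logits, actions.view(-1).long(), label_smoothing=label_smoothing
    )
    regularization = beta * (logits**2).mean()
    entropy = torch.distributions.Categorical(logits=logits).entropy()
    # Subtracting scaled entropy rewards higher-entropy (less peaked) policies.
    return imitation + regularization - entropy_beta * entropy.mean()
```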