diff --git a/.gitignore b/.gitignore index 049c3f06..652def38 100644 --- a/.gitignore +++ b/.gitignore @@ -22,3 +22,4 @@ dist /.idea/ *.egg-info /.vscode/ +venv diff --git a/d3rlpy/__init__.py b/d3rlpy/__init__.py index 6a73196a..40df87bb 100644 --- a/d3rlpy/__init__.py +++ b/d3rlpy/__init__.py @@ -26,6 +26,10 @@ from .constants import ActionSpace, LoggingStrategy, PositionEncodingType from .healthcheck import run_healthcheck from .torch_utility import Modules, TorchMiniBatch +from .transformation import ( + TransformationProtocol, + register_transformation_callable, +) __all__ = [ "algos", @@ -50,6 +54,8 @@ "Modules", "TorchMiniBatch", "seed", + "TransformationProtocol", + "register_transformation_callable", ] diff --git a/d3rlpy/algos/qlearning/base.py b/d3rlpy/algos/qlearning/base.py index b56f8082..7ca3c7e8 100644 --- a/d3rlpy/algos/qlearning/base.py +++ b/d3rlpy/algos/qlearning/base.py @@ -45,7 +45,13 @@ sync_optimizer_state, train_api, ) -from ...types import GymEnv, NDArray, Observation, TorchObservation +from ...types import ( + Float32NDArray, + GymEnv, + NDArray, + Observation, + TorchObservation, +) from ..utility import ( assert_action_space_with_dataset, assert_action_space_with_env, @@ -74,11 +80,15 @@ def inner_update( pass @eval_api - def predict_best_action(self, x: TorchObservation) -> torch.Tensor: - return self.inner_predict_best_action(x) + def predict_best_action( + self, x: TorchObservation, embedding: Optional[torch.Tensor] + ) -> torch.Tensor: + return self.inner_predict_best_action(x, embedding) @abstractmethod - def inner_predict_best_action(self, x: TorchObservation) -> torch.Tensor: + def inner_predict_best_action( + self, x: TorchObservation, embedding: Optional[torch.Tensor] + ) -> torch.Tensor: pass @eval_api @@ -222,7 +232,7 @@ def _func(*x: Sequence[torch.Tensor]) -> torch.Tensor: observation ) - action = self._impl.predict_best_action(observation) + action = self._impl.predict_best_action(observation, None) if self._config.action_scaler: action = self._config.action_scaler.reverse_transform(action) @@ -253,7 +263,9 @@ def _func(*x: Sequence[torch.Tensor]) -> torch.Tensor: # workaround until version 1.6 self._impl.modules.unfreeze() - def predict(self, x: Observation) -> NDArray: + def predict( + self, x: Observation, embedding: Optional[Float32NDArray] + ) -> NDArray: """Returns greedy actions. .. code-block:: python @@ -275,12 +287,17 @@ def predict(self, x: Observation) -> NDArray: assert check_non_1d_array(x), "Input must have batch dimension." 
torch_x = convert_to_torch_recursively(x, self._device) + torch_embedding = ( + None + if embedding is None + else convert_to_torch_recursively(embedding, self._device) + ) with torch.no_grad(): if self._config.observation_scaler: torch_x = self._config.observation_scaler.transform(torch_x) - action = self._impl.predict_best_action(torch_x) + action = self._impl.predict_best_action(torch_x, torch_embedding) if self._config.action_scaler: action = self._config.action_scaler.reverse_transform(action) @@ -508,7 +525,7 @@ def fitter( # setup logger if experiment_name is None: experiment_name = self.__class__.__name__ - logger = D3RLPyLogger( + self.logger = D3RLPyLogger( algo=self, adapter_factory=logger_adapter, experiment_name=experiment_name, @@ -517,7 +534,7 @@ def fitter( ) # save hyperparameters - save_config(self, logger) + save_config(self, self.logger) # training loop n_epochs = n_steps // n_steps_per_epoch @@ -533,20 +550,20 @@ def fitter( ) for itr in range_gen: - with logger.measure_time("step"): + with self.logger.measure_time("step"): # pick transitions - with logger.measure_time("sample_batch"): + with self.logger.measure_time("sample_batch"): batch = dataset.sample_transition_batch( self._config.batch_size ) # update parameters - with logger.measure_time("algorithm_update"): + with self.logger.measure_time("algorithm_update"): loss = self.update(batch) # record metrics for name, val in loss.items(): - logger.add_metric(name, val) + self.logger.add_metric(name, val) epoch_loss[name].append(val) # update progress postfix with losses @@ -562,7 +579,7 @@ def fitter( logging_strategy == LoggingStrategy.STEPS and total_step % logging_steps == 0 ): - metrics = logger.commit(epoch, total_step) + metrics = self.logger.commit(epoch, total_step) # call callback if given if callback: @@ -575,19 +592,18 @@ def fitter( if evaluators: for name, evaluator in evaluators.items(): test_score = evaluator(self, dataset) - logger.add_metric(name, test_score) + self.logger.add_metric(name, test_score) # save metrics - if logging_strategy == LoggingStrategy.EPOCH: - metrics = logger.commit(epoch, total_step) + metrics = self.logger.commit(epoch, total_step) # save model parameters if epoch % save_interval == 0: - logger.save_model(total_step, self) + self.logger.save_model(total_step, self) yield epoch, metrics - logger.close() + self.logger.close() def fit_online( self, @@ -866,16 +882,23 @@ def update(self, batch: TransitionMiniBatch) -> dict[str, float]: Dictionary of metrics. 
""" assert self._impl, IMPL_NOT_INITIALIZED_ERROR - torch_batch = TorchMiniBatch.from_batch( - batch=batch, - gamma=self._config.gamma, - compute_returns_to_go=self.need_returns_to_go, - device=self._device, - observation_scaler=self._config.observation_scaler, - action_scaler=self._config.action_scaler, - reward_scaler=self._config.reward_scaler, - ) - loss = self._impl.update(torch_batch, self._grad_step) + with self.logger.measure_time("algorithm_update_mini_batch"): + torch_batch = TorchMiniBatch.from_batch( + batch=batch, + gamma=self._config.gamma, + compute_returns_to_go=self.need_returns_to_go, + device=self._device, + observation_scaler=self._config.observation_scaler, + action_scaler=self._config.action_scaler, + reward_scaler=self._config.reward_scaler, + ) + + with self.logger.measure_time("algorithm_update_augmentation"): + if self._config.transform: + torch_batch = self._config.transform(torch_batch) + + with self.logger.measure_time("algorithm_update_update"): + loss = self._impl.update(torch_batch, self._grad_step) self._grad_step += 1 return loss diff --git a/d3rlpy/algos/qlearning/bc.py b/d3rlpy/algos/qlearning/bc.py index da9cc275..ea9f2c0f 100644 --- a/d3rlpy/algos/qlearning/bc.py +++ b/d3rlpy/algos/qlearning/bc.py @@ -1,4 +1,5 @@ import dataclasses +from typing import Optional from ...base import DeviceArg, LearnableConfig, register_learnable from ...constants import ActionSpace @@ -149,6 +150,11 @@ class DiscreteBCConfig(LearnableConfig): optim_factory: OptimizerFactory = make_optimizer_field() encoder_factory: EncoderFactory = make_encoder_field() beta: float = 0.5 + entropy_beta: float = 0.0 + embedding_size: Optional[int] = None + automatic_mixed_precision: bool = False + scheduler_on_train_step: bool = True + label_smoothing: float = 0.0 def create( self, device: DeviceArg = False, enable_ddp: bool = False @@ -170,6 +176,7 @@ def inner_create_impl( self._config.encoder_factory, device=self._device, enable_ddp=self._enable_ddp, + embedding_size=self._config.embedding_size, ) optim = self._config.optim_factory.create( @@ -185,8 +192,12 @@ def inner_create_impl( action_size=action_size, modules=modules, beta=self._config.beta, + entropy_beta=self._config.entropy_beta, compiled=self.compiled, device=self._device, + automatic_mixed_precision=self._config.automatic_mixed_precision, + scheduler_on_train_step=self._config.scheduler_on_train_step, + label_smoothing=self._config.label_smoothing, ) def get_action_type(self) -> ActionSpace: diff --git a/d3rlpy/algos/qlearning/torch/bc_impl.py b/d3rlpy/algos/qlearning/torch/bc_impl.py index 319e2ba8..0dd54b2c 100644 --- a/d3rlpy/algos/qlearning/torch/bc_impl.py +++ b/d3rlpy/algos/qlearning/torch/bc_impl.py @@ -1,9 +1,10 @@ import dataclasses from abc import ABCMeta, abstractmethod -from typing import Callable, Union +from typing import Callable, Optional, Union import torch from torch.optim import Optimizer +import torch.nn.functional as F from ....dataclass_utils import asdict_as_float from ....models.torch import ( @@ -56,7 +57,9 @@ def __init__( def compute_imitator_grad(self, batch: TorchMiniBatch) -> ImitationLoss: self._modules.optim.zero_grad() - loss = self.compute_loss(batch.observations, batch.actions) + loss = self.compute_loss( + batch.observations, batch.actions, batch.embeddings + ) loss.loss.backward() return loss @@ -67,12 +70,17 @@ def update_imitator(self, batch: TorchMiniBatch) -> dict[str, float]: @abstractmethod def compute_loss( - self, obs_t: TorchObservation, act_t: torch.Tensor + self, + obs_t: 
TorchObservation, + act_t: torch.Tensor, + embedding_t: Optional[torch.Tensor], ) -> ImitationLoss: pass - def inner_sample_action(self, x: TorchObservation) -> torch.Tensor: - return self.inner_predict_best_action(x) + def inner_sample_action( + self, x: TorchObservation, embedding: Optional[torch.Tensor] = None + ) -> torch.Tensor: + return self.inner_predict_best_action(x, embedding) def inner_predict_value( self, x: TorchObservation, action: torch.Tensor @@ -112,11 +120,16 @@ def __init__( ) self._policy_type = policy_type - def inner_predict_best_action(self, x: TorchObservation) -> torch.Tensor: - return self._modules.imitator(x).squashed_mu + def inner_predict_best_action( + self, x: TorchObservation, embedding: Optional[torch.Tensor] + ) -> torch.Tensor: + return self._modules.imitator(x, embedding).squashed_mu def compute_loss( - self, obs_t: TorchObservation, act_t: torch.Tensor + self, + obs_t: TorchObservation, + act_t: torch.Tensor, + embedding_t: Optional[torch.Tensor], ) -> ImitationLoss: if self._policy_type == "deterministic": assert isinstance(self._modules.imitator, DeterministicPolicy) @@ -148,6 +161,7 @@ class DiscreteBCModules(BCBaseModules): class DiscreteBCImpl(BCBaseImpl): _modules: DiscreteBCModules _beta: float + _entropy_beta: float def __init__( self, @@ -155,8 +169,12 @@ def __init__( action_size: int, modules: DiscreteBCModules, beta: float, + entropy_beta: float, compiled: bool, device: str, + automatic_mixed_precision: bool, + scheduler_on_train_step: bool, + label_smoothing: float ): super().__init__( observation_shape=observation_shape, @@ -166,16 +184,60 @@ def __init__( device=device, ) self._beta = beta + self._entropy_beta = entropy_beta + self.automatic_mixed_precision = automatic_mixed_precision + self.scheduler_on_train_step = scheduler_on_train_step + self.label_smoothing = label_smoothing + self.grad_scaler = torch.amp.GradScaler(enabled=self.automatic_mixed_precision) + + def inner_predict_best_action( + self, x: TorchObservation, embedding: Optional[torch.Tensor] + ) -> torch.Tensor: + return self._modules.imitator(x, embedding).logits.argmax(dim=1) + + def compute_imitator_grad(self, batch: TorchMiniBatch) -> ImitationLoss: + self._modules.optim.zero_grad() + with torch.autocast(device_type=self.device, enabled=self.automatic_mixed_precision): + dist = self._modules.imitator(batch.observations, batch.embeddings) + imitation_loss = F.cross_entropy(dist.logits, batch.actions.view(-1).long(), label_smoothing=self.label_smoothing) + + penalty = (dist.logits ** 2).mean() + regularization_loss = self._beta * penalty + + entropy_loss = -self._entropy_beta * dist.entropy().mean() - def inner_predict_best_action(self, x: TorchObservation) -> torch.Tensor: - return self._modules.imitator(x).logits.argmax(dim=1) + loss = imitation_loss + regularization_loss + entropy_loss + + self.grad_scaler.scale(loss).backward() + + if self._modules.optim._clip_grad_norm: + self.grad_scaler.unscale_(self._modules.optim.optim) + torch.nn.utils.clip_grad_norm_(self._modules.imitator.parameters(), max_norm=self._modules.optim._clip_grad_norm) + + self.grad_scaler.step(self._modules.optim.optim) + + self.grad_scaler.update() + + if self._modules.optim._lr_scheduler and self.scheduler_on_train_step: + self._modules.optim._lr_scheduler.step() + + return DiscreteImitationLoss(loss=loss, imitation_loss=imitation_loss, regularization_loss=regularization_loss, entropy_loss=entropy_loss) + + def update_imitator(self, batch: TorchMiniBatch) -> dict[str, float]: + loss = 
self._compute_imitator_grad(batch) + return asdict_as_float(loss) def compute_loss( - self, obs_t: TorchObservation, act_t: torch.Tensor + self, + obs_t: TorchObservation, + act_t: torch.Tensor, + embedding_t: Optional[torch.Tensor], ) -> DiscreteImitationLoss: return compute_discrete_imitation_loss( policy=self._modules.imitator, x=obs_t, + embedding=embedding_t, action=act_t.long(), beta=self._beta, + entropy_beta=self._entropy_beta, ) diff --git a/d3rlpy/algos/transformer/base.py b/d3rlpy/algos/transformer/base.py index 0bae1a59..9bbd85ec 100644 --- a/d3rlpy/algos/transformer/base.py +++ b/d3rlpy/algos/transformer/base.py @@ -16,7 +16,11 @@ from typing_extensions import Self from ...base import ImplBase, LearnableBase, LearnableConfig, save_config -from ...constants import IMPL_NOT_INITIALIZED_ERROR, ActionSpace +from ...constants import ( + IMPL_NOT_INITIALIZED_ERROR, + ActionSpace, + LoggingStrategy, +) from ...dataset import ReplayBuffer, TrajectoryMiniBatch, is_tuple_shape from ...logging import ( LOG, @@ -24,9 +28,15 @@ FileAdapterFactory, LoggerAdapterFactory, ) -from ...metrics import evaluate_transformer_with_environment +from ...metrics import EvaluatorProtocol, evaluate_transformer_with_environment from ...torch_utility import TorchTrajectoryMiniBatch, eval_api, train_api -from ...types import GymEnv, NDArray, Observation, TorchObservation +from ...types import ( + Float32NDArray, + GymEnv, + NDArray, + Observation, + TorchObservation, +) from ..utility import ( assert_action_space_with_dataset, build_scalers_with_trajectory_slicer, @@ -134,13 +144,16 @@ def __init__( context_size = algo.config.context_size self._observations = deque([], maxlen=context_size) + self._embeddings = deque([], maxlen=context_size) self._actions = deque([self._get_pad_action()], maxlen=context_size) self._rewards = deque([], maxlen=context_size) self._returns_to_go = deque([], maxlen=context_size) self._timesteps = deque([], maxlen=context_size) self._timestep = 1 - def predict(self, x: Observation, reward: float) -> Union[NDArray, int]: + def predict( + self, x: Observation, reward: float, embedding: Float32NDArray + ) -> Union[NDArray, int]: r"""Returns action. Args: @@ -151,6 +164,7 @@ def predict(self, x: Observation, reward: float) -> Union[NDArray, int]: Action. 
""" self._observations.append(x) + self._embeddings.append(embedding) self._rewards.append(reward) self._returns_to_go.append(self._return_rest - reward) self._timesteps.append(self._timestep) @@ -170,6 +184,9 @@ def predict(self, x: Observation, reward: float) -> Union[NDArray, int]: rewards=np.array(self._rewards).reshape((-1, 1)), returns_to_go=np.array(self._returns_to_go).reshape((-1, 1)), timesteps=np.array(self._timesteps), + embeddings=( + None if embedding is None else np.array(self._embeddings) + ), ) action = self._action_sampler(self._algo.predict(inpt)) self._actions[-1] = action @@ -181,6 +198,7 @@ def predict(self, x: Observation, reward: float) -> Union[NDArray, int]: def reset(self) -> None: """Clears stateful information.""" self._observations.clear() + self._embeddings.clear() self._actions.clear() self._rewards.clear() self._returns_to_go.clear() @@ -378,6 +396,8 @@ def fit( dataset: ReplayBuffer, n_steps: int, n_steps_per_epoch: int = 10000, + logging_steps: int = 500, + logging_strategy: LoggingStrategy = LoggingStrategy.EPOCH, experiment_name: Optional[str] = None, with_timestamp: bool = True, logger_adapter: LoggerAdapterFactory = FileAdapterFactory(), @@ -386,7 +406,9 @@ def fit( eval_target_return: Optional[float] = None, eval_action_sampler: Optional[TransformerActionSampler] = None, save_interval: int = 1, + evaluators: Optional[dict[str, EvaluatorProtocol]] = None, callback: Optional[Callable[[Self, int, int], None]] = None, + epoch_callback: Optional[Callable[[Self, int, int], None]] = None, ) -> None: """Trains with given dataset. @@ -433,7 +455,7 @@ def fit( # setup logger if experiment_name is None: experiment_name = self.__class__.__name__ - logger = D3RLPyLogger( + self.logger = D3RLPyLogger( algo=self, adapter_factory=logger_adapter, experiment_name=experiment_name, @@ -442,7 +464,7 @@ def fit( ) # save hyperparameters - save_config(self, logger) + save_config(self, self.logger) # training loop n_epochs = n_steps // n_steps_per_epoch @@ -458,21 +480,21 @@ def fit( ) for itr in range_gen: - with logger.measure_time("step"): + with self.logger.measure_time("step"): # pick transitions - with logger.measure_time("sample_batch"): + with self.logger.measure_time("sample_batch"): batch = dataset.sample_trajectory_batch( self._config.batch_size, length=self._config.context_size, ) # update parameters - with logger.measure_time("algorithm_update"): + with self.logger.measure_time("algorithm_update"): loss = self.update(batch) # record metrics for name, val in loss.items(): - logger.add_metric(name, val) + self.logger.add_metric(name, val) epoch_loss[name].append(val) # update progress postfix with losses @@ -484,10 +506,25 @@ def fit( total_step += 1 + if ( + logging_strategy == LoggingStrategy.STEPS + and total_step % logging_steps == 0 + ): + self.logger.commit(epoch, total_step) + # call callback if given if callback: callback(self, epoch, total_step) + # call epoch_callback if given + if epoch_callback: + epoch_callback(self, epoch, total_step) + + if evaluators: + for name, evaluator in evaluators.items(): + test_score = evaluator(self, dataset) + self.logger.add_metric(name, test_score) + if eval_env: assert eval_target_return is not None eval_score = evaluate_transformer_with_environment( @@ -497,16 +534,17 @@ def fit( ), env=eval_env, ) - logger.add_metric("environment", eval_score) + self.logger.add_metric("environment", eval_score) # save metrics - logger.commit(epoch, total_step) + if logging_strategy == LoggingStrategy.EPOCH: + 
self.logger.commit(epoch, total_step) # save model parameters if epoch % save_interval == 0: - logger.save_model(total_step, self) + self.logger.save_model(total_step, self) - logger.close() + self.logger.close() def update(self, batch: TrajectoryMiniBatch) -> dict[str, float]: """Update parameters with mini-batch of data. @@ -525,6 +563,10 @@ def update(self, batch: TrajectoryMiniBatch) -> dict[str, float]: action_scaler=self._config.action_scaler, reward_scaler=self._config.reward_scaler, ) + + if self._config.transform: + torch_batch = self._config.transform(torch_batch) + loss = self._impl.update(torch_batch, self._grad_step) self._grad_step += 1 return loss diff --git a/d3rlpy/algos/transformer/decision_transformer.py b/d3rlpy/algos/transformer/decision_transformer.py index 8a822e4e..f2f42c34 100644 --- a/d3rlpy/algos/transformer/decision_transformer.py +++ b/d3rlpy/algos/transformer/decision_transformer.py @@ -1,4 +1,5 @@ import dataclasses +from typing import Optional from ...base import DeviceArg, register_learnable from ...constants import ActionSpace, PositionEncodingType @@ -71,6 +72,7 @@ class DecisionTransformerConfig(TransformerConfig): embed_dropout: float = 0.1 activation_type: str = "relu" position_encoding_type: PositionEncodingType = PositionEncodingType.SIMPLE + embedding_size: Optional[int] = None compile_graph: bool = False def create( @@ -180,6 +182,7 @@ class DiscreteDecisionTransformerConfig(TransformerConfig): position_encoding_type: PositionEncodingType = PositionEncodingType.GLOBAL warmup_tokens: int = 10240 final_tokens: int = 30000000 + embedding_size: Optional[int] = None compile_graph: bool = False def create( @@ -216,6 +219,7 @@ def inner_create_impl( position_encoding_type=self._config.position_encoding_type, device=self._device, enable_ddp=self._enable_ddp, + embedding_size=self._config.embedding_size, ) optim = self._config.optim_factory.create( transformer.named_modules(), diff --git a/d3rlpy/algos/transformer/inputs.py b/d3rlpy/algos/transformer/inputs.py index 191a7d80..faccc864 100644 --- a/d3rlpy/algos/transformer/inputs.py +++ b/d3rlpy/algos/transformer/inputs.py @@ -30,6 +30,7 @@ class TransformerInput: rewards: Float32NDArray # (L, 1) returns_to_go: Float32NDArray # (L, 1) timesteps: Int32NDArray # (L,) + embeddings: Optional[Float32NDArray] def __post_init__(self) -> None: # check sequence size @@ -53,6 +54,7 @@ class TorchTransformerInput: timesteps: torch.Tensor # (1, L) masks: torch.Tensor # (1, L) length: int + embeddings: Optional[torch.Tensor] = None @classmethod def from_numpy( @@ -74,6 +76,11 @@ def from_numpy( returns_to_go = inpt.returns_to_go[-context_size:] timesteps = inpt.timesteps[-context_size:] masks = np.ones(context_size, dtype=np.float32) + embeddings = ( + inpt.embeddings[-context_size:] + if inpt.embeddings is not None + else None + ) else: pad_size = context_size - inpt.length observations = batch_pad_observations(inpt.observations, pad_size) @@ -84,6 +91,11 @@ def from_numpy( masks = batch_pad_array( np.ones(inpt.length, dtype=np.float32), pad_size ) + embeddings = ( + batch_pad_array(inpt.embeddings, pad_size) + if inpt.embeddings is not None + else None + ) # convert numpy array to torch tensor observations_pt = convert_to_torch_recursively(observations, device) @@ -92,6 +104,9 @@ def from_numpy( returns_to_go_pt = convert_to_torch(returns_to_go, device) timesteps_pt = convert_to_torch(timesteps, device).long() masks_pt = convert_to_torch(masks, device) + embeddings_pt = ( + None if embeddings is None else 
convert_to_torch(embeddings, device) + ) # apply scaler if observation_scaler: @@ -115,4 +130,7 @@ def from_numpy( timesteps=timesteps_pt.unsqueeze(0), masks=masks_pt.unsqueeze(0), length=context_size, + embeddings=( + None if embeddings_pt is None else embeddings_pt.unsqueeze(0) + ), ) diff --git a/d3rlpy/algos/transformer/torch/decision_transformer_impl.py b/d3rlpy/algos/transformer/torch/decision_transformer_impl.py index e486a14e..03482d48 100644 --- a/d3rlpy/algos/transformer/torch/decision_transformer_impl.py +++ b/d3rlpy/algos/transformer/torch/decision_transformer_impl.py @@ -60,6 +60,7 @@ def inner_predict(self, inpt: TorchTransformerInput) -> torch.Tensor: inpt.returns_to_go, inpt.timesteps, 1 - inpt.masks, + inpt.embeddings, ) # (1, T, A) -> (A,) return action[0][-1] @@ -86,6 +87,7 @@ def compute_loss(self, batch: TorchTrajectoryMiniBatch) -> torch.Tensor: batch.returns_to_go, batch.timesteps, 1 - batch.masks, + batch.embeddings, ) # (B, T, A) -> (B, T) loss = ((action - batch.actions) ** 2).sum(dim=-1) @@ -142,6 +144,7 @@ def inner_predict(self, inpt: TorchTransformerInput) -> torch.Tensor: inpt.returns_to_go, inpt.timesteps, 1 - inpt.masks, + inpt.embeddings, ) # (1, T, A) -> (A,) return logits[0][-1] @@ -187,6 +190,7 @@ def compute_loss(self, batch: TorchTrajectoryMiniBatch) -> torch.Tensor: batch.returns_to_go, batch.timesteps, 1 - batch.masks, + batch.embeddings, ) loss = F.cross_entropy( logits.view(-1, self._action_size), diff --git a/d3rlpy/base.py b/d3rlpy/base.py index 45fcdd23..4643dc5f 100644 --- a/d3rlpy/base.py +++ b/d3rlpy/base.py @@ -2,8 +2,9 @@ import io import pickle from abc import ABCMeta, abstractmethod -from typing import BinaryIO, Generic, Optional, TypeVar, Union +from typing import BinaryIO, Callable, Generic, Optional, TypeVar, Union +import torch from gym.spaces import Box from gymnasium.spaces import Box as GymnasiumBox from typing_extensions import Self @@ -20,8 +21,18 @@ make_observation_scaler_field, make_reward_scaler_field, ) -from .serializable_config import DynamicConfig, generate_config_registration -from .torch_utility import Checkpointer, Modules +from .serializable_config import ( + DynamicConfig, + generate_config_registration, +) +from .torch_utility import ( + Checkpointer, + Modules, + TorchMiniBatch, + TorchTrajectoryMiniBatch, + map_location +) +from .transformation import make_transformation_callable_field from .types import GymEnv, Shape __all__ = [ @@ -96,6 +107,12 @@ class LearnableConfig(DynamicConfig): ) action_scaler: Optional[ActionScaler] = make_action_scaler_field() reward_scaler: Optional[RewardScaler] = make_reward_scaler_field() + transform: Optional[ + Callable[ + [TorchMiniBatch | TorchTrajectoryMiniBatch], + TorchMiniBatch | TorchTrajectoryMiniBatch, + ] + ] = make_transformation_callable_field() compile_graph: bool = False def create( @@ -183,6 +200,12 @@ def dump_learnable( "config": config.serialize(), "version": __version__, } + + if hasattr(algo.impl, "grad_scaler"): + scaler_bytes = io.BytesIO() + torch.save(algo.impl.grad_scaler.state_dict(), scaler_bytes) + obj["scaler"] = scaler_bytes.getvalue() + pickle.dump(obj, f) @@ -201,6 +224,11 @@ def load_learnable( algo = config.create(device=device, enable_ddp=enable_ddp) assert algo.impl algo.impl.load_model(io.BytesIO(obj["torch"])) + + if "scaler" in obj: + scaler_state_dict = torch.load(io.BytesIO(obj["scaler"]), map_location=map_location(algo._device)) + algo.impl.grad_scaler.load_state_dict(scaler_state_dict) + return algo diff --git 
a/d3rlpy/dataset/components.py b/d3rlpy/dataset/components.py index b8e50f76..d848ba72 100644 --- a/d3rlpy/dataset/components.py +++ b/d3rlpy/dataset/components.py @@ -1,5 +1,5 @@ import dataclasses -from typing import Any, Protocol, Sequence +from typing import Any, Optional, Protocol, Sequence import numpy as np @@ -79,6 +79,7 @@ class Transition: terminal: float interval: int rewards_to_go: Float32NDArray # (L, 1) + embedding: Float32NDArray @property def observation_signature(self) -> Signature: @@ -142,6 +143,7 @@ class PartialTrajectory: timesteps: Int32NDArray # (L,) masks: Float32NDArray # (L,) length: int + embeddings: Float32NDArray # (L) @property def observation_signature(self) -> Signature: @@ -238,6 +240,10 @@ def terminated(self) -> bool: """ raise NotImplementedError + @property + def embeddings(self) -> Float32NDArray: + raise NotImplementedError + @property def observation_signature(self) -> Signature: r"""Returns observation signature. @@ -331,6 +337,7 @@ class Episode: actions: NDArray rewards: Float32NDArray terminated: bool + embeddings: Optional[Float32NDArray] = None @property def observation_signature(self) -> Signature: @@ -367,6 +374,7 @@ def serialize(self) -> dict[str, Any]: "actions": self.actions, "rewards": self.rewards, "terminated": self.terminated, + "embeddings": self.embeddings, } @classmethod @@ -376,6 +384,7 @@ def deserialize(cls, serializedData: dict[str, Any]) -> "Episode": actions=serializedData["actions"], rewards=serializedData["rewards"], terminated=serializedData["terminated"], + embeddings=serializedData["embeddings"], ) def __len__(self) -> int: diff --git a/d3rlpy/dataset/episode_generator.py b/d3rlpy/dataset/episode_generator.py index 5857bc15..8abe4e0f 100644 --- a/d3rlpy/dataset/episode_generator.py +++ b/d3rlpy/dataset/episode_generator.py @@ -37,6 +37,8 @@ class EpisodeGenerator(EpisodeGeneratorProtocol): _rewards: Float32NDArray _terminals: Float32NDArray _timeouts: Float32NDArray + # Hold extra data for observations. 
Concatenate with observations to process + _embeddings: Float32NDArray def __init__( self, @@ -45,6 +47,7 @@ def __init__( rewards: Float32NDArray, terminals: Float32NDArray, timeouts: Optional[Float32NDArray] = None, + embeddings: Optional[Float32NDArray] = None, ): if actions.ndim == 1: actions = np.reshape(actions, [-1, 1]) @@ -71,6 +74,7 @@ def __init__( self._rewards = rewards self._terminals = terminals self._timeouts = timeouts + self._embeddings = embeddings def __call__(self) -> Sequence[Episode]: start = 0 @@ -85,6 +89,11 @@ def __call__(self) -> Sequence[Episode]: actions=self._actions[start:end], rewards=self._rewards[start:end], terminated=bool(self._terminals[i]), + embeddings=( + None + if self._embeddings is None + else self._embeddings[start:end] + ), ) episodes.append(episode) start = end diff --git a/d3rlpy/dataset/mini_batch.py b/d3rlpy/dataset/mini_batch.py index 7449aeff..c6f97f3c 100644 --- a/d3rlpy/dataset/mini_batch.py +++ b/d3rlpy/dataset/mini_batch.py @@ -42,6 +42,7 @@ class TransitionMiniBatch: terminals: Float32NDArray # (B, 1) intervals: Float32NDArray # (B, 1) transitions: Sequence[Transition] + embeddings: Float32NDArray # (B, 1) def __post_init__(self) -> None: assert check_non_1d_array(self.observations) @@ -94,6 +95,10 @@ def from_transitions( np.array([transition.interval for transition in transitions]), [-1, 1], ) + embeddings = np.stack( + [transition.embedding for transition in transitions], axis=0 + ) + return TransitionMiniBatch( observations=cast_recursively(observations, np.float32), actions=cast_recursively(actions, np.float32), @@ -103,6 +108,7 @@ def from_transitions( terminals=cast_recursively(terminals, np.float32), intervals=cast_recursively(intervals, np.float32), transitions=transitions, + embeddings=cast_recursively(embeddings, np.float32), ) @property @@ -159,6 +165,7 @@ class TrajectoryMiniBatch: timesteps: Float32NDArray # (B, L) masks: Float32NDArray # (B, L) length: int + embeddings: Float32NDArray def __post_init__(self) -> None: assert check_dtype(self.observations, np.float32) @@ -192,6 +199,9 @@ def from_partial_trajectories( terminals = np.stack([traj.terminals for traj in trajectories], axis=0) timesteps = np.stack([traj.timesteps for traj in trajectories], axis=0) masks = np.stack([traj.masks for traj in trajectories], axis=0) + embeddings = np.stack( + [traj.embeddings for traj in trajectories], axis=0 + ) return TrajectoryMiniBatch( observations=cast_recursively(observations, np.float32), actions=cast_recursively(actions, np.float32), @@ -201,6 +211,7 @@ def from_partial_trajectories( timesteps=cast_recursively(timesteps, np.float32), masks=cast_recursively(masks, np.float32), length=trajectories[0].length, + embeddings=cast_recursively(embeddings, np.float32), ) @property diff --git a/d3rlpy/dataset/trajectory_slicers.py b/d3rlpy/dataset/trajectory_slicers.py index 01853008..8f617917 100644 --- a/d3rlpy/dataset/trajectory_slicers.py +++ b/d3rlpy/dataset/trajectory_slicers.py @@ -74,6 +74,12 @@ def __call__( # compute backward padding size pad_size = size - actual_size + embeddings = ( + episode.embeddings[start:end] + if episode.embeddings is not None + else None + ) + if pad_size == 0: return PartialTrajectory( observations=observations, @@ -84,6 +90,7 @@ def __call__( timesteps=timesteps, masks=masks, length=size, + embeddings=embeddings, ) return PartialTrajectory( @@ -95,6 +102,11 @@ def __call__( timesteps=batch_pad_array(timesteps, pad_size), masks=batch_pad_array(masks, pad_size), length=size, + embeddings=( + None + 
if embeddings is None + else batch_pad_array(embeddings, pad_size) + ), ) diff --git a/d3rlpy/dataset/transition_pickers.py b/d3rlpy/dataset/transition_pickers.py index 0b059a10..ad9da9e5 100644 --- a/d3rlpy/dataset/transition_pickers.py +++ b/d3rlpy/dataset/transition_pickers.py @@ -69,6 +69,11 @@ def __call__(self, episode: EpisodeBase, index: int) -> Transition: terminal=float(is_terminal), interval=1, rewards_to_go=episode.rewards[index:], + embedding=( + episode.embeddings[index] + if episode.embeddings is not None + else None + ), ) diff --git a/d3rlpy/models/builders.py b/d3rlpy/models/builders.py index 74564013..9ef93540 100644 --- a/d3rlpy/models/builders.py +++ b/d3rlpy/models/builders.py @@ -1,4 +1,4 @@ -from typing import Sequence, cast +from typing import Optional, Sequence, cast import torch from torch import nn @@ -204,9 +204,12 @@ def create_categorical_policy( encoder_factory: EncoderFactory, device: str, enable_ddp: bool, + embedding_size: Optional[int] = None, ) -> CategoricalPolicy: encoder = encoder_factory.create(observation_shape) - hidden_size = compute_output_size([observation_shape], encoder) + hidden_size = compute_output_size( + [observation_shape], encoder, embedding_size + ) policy = CategoricalPolicy( encoder=encoder, hidden_size=hidden_size, action_size=action_size ) @@ -372,9 +375,12 @@ def create_discrete_decision_transformer( position_encoding_type: PositionEncodingType, device: str, enable_ddp: bool, + embedding_size: Optional[int] = None, ) -> DiscreteDecisionTransformer: encoder = encoder_factory.create(observation_shape) - hidden_size = compute_output_size([observation_shape], encoder) + hidden_size = compute_output_size( + [observation_shape], encoder, embedding_size + ) position_encoding = _create_position_encoding( position_encoding_type=position_encoding_type, diff --git a/d3rlpy/models/torch/encoders.py b/d3rlpy/models/torch/encoders.py index cd912d30..397252d0 100644 --- a/d3rlpy/models/torch/encoders.py +++ b/d3rlpy/models/torch/encoders.py @@ -26,7 +26,9 @@ class Encoder(nn.Module, metaclass=ABCMeta): # type: ignore def forward(self, x: TorchObservation) -> torch.Tensor: pass - def __call__(self, x: TorchObservation) -> torch.Tensor: + def __call__( + self, x: TorchObservation, embedding: Optional[torch.Tensor] + ) -> torch.Tensor: return super().__call__(x) @@ -367,7 +369,9 @@ def forward( def compute_output_size( - input_shapes: Sequence[Shape], encoder: nn.Module + input_shapes: Sequence[Shape], + encoder: nn.Module, + embedding_size: Optional[int] = None, ) -> int: device = next(encoder.parameters()).device with torch.no_grad(): @@ -377,5 +381,20 @@ def compute_output_size( inputs.append([torch.rand(2, *s, device=device) for s in shape]) else: inputs.append(torch.rand(2, *shape, device=device)) - y = encoder(*inputs) + if embedding_size is None: + y = encoder(*inputs, None) + else: + if inputs[0].ndim == 4: + # B x C x H X W + y = encoder( + *inputs, torch.rand(2, embedding_size, device=device) + ) + else: + # B x T x C x H x W + y = encoder( + *inputs, + torch.rand( + 2, inputs[0].shape[1], embedding_size, device=device + ) + ) return int(y.shape[1]) diff --git a/d3rlpy/models/torch/imitators.py b/d3rlpy/models/torch/imitators.py index ba91dde8..71ca6644 100644 --- a/d3rlpy/models/torch/imitators.py +++ b/d3rlpy/models/torch/imitators.py @@ -1,5 +1,5 @@ import dataclasses -from typing import cast +from typing import Optional, cast import torch import torch.nn.functional as F @@ -171,23 +171,28 @@ class ImitationLoss: class 
DiscreteImitationLoss(ImitationLoss): imitation_loss: torch.Tensor regularization_loss: torch.Tensor + entropy_loss: torch.Tensor def compute_discrete_imitation_loss( policy: CategoricalPolicy, x: TorchObservation, + embedding: Optional[torch.Tensor], action: torch.Tensor, beta: float, + entropy_beta: float, ) -> DiscreteImitationLoss: - dist = policy(x) + dist = policy(x, embedding) penalty = (dist.logits**2).mean() log_probs = F.log_softmax(dist.logits, dim=1) imitation_loss = F.nll_loss(log_probs, action.view(-1)) regularization_loss = beta * penalty + entropy_loss = -entropy_beta * dist.entropy().mean() return DiscreteImitationLoss( - loss=imitation_loss + regularization_loss, + loss=imitation_loss + regularization_loss + entropy_loss, imitation_loss=imitation_loss, regularization_loss=regularization_loss, + entropy_loss=entropy_loss, ) diff --git a/d3rlpy/models/torch/policies.py b/d3rlpy/models/torch/policies.py index a0815f56..6c00d6c8 100644 --- a/d3rlpy/models/torch/policies.py +++ b/d3rlpy/models/torch/policies.py @@ -67,7 +67,7 @@ def __init__(self, encoder: Encoder, hidden_size: int, action_size: int): self._fc = nn.Linear(hidden_size, action_size) def forward(self, x: TorchObservation, *args: Any) -> ActionOutput: - h = self._encoder(x) + h = self._encoder(x, None) mu = self._fc(h) return ActionOutput(mu, torch.tanh(mu), logstd=None) @@ -129,7 +129,7 @@ def __init__( self._logstd = nn.Linear(hidden_size, action_size) def forward(self, x: TorchObservation, *args: Any) -> ActionOutput: - h = self._encoder(x) + h = self._encoder(x, None) mu = self._mu(h) if self._use_std_parameter: @@ -154,8 +154,12 @@ def __init__(self, encoder: Encoder, hidden_size: int, action_size: int): self._encoder = encoder self._fc = nn.Linear(hidden_size, action_size) - def forward(self, x: TorchObservation) -> Categorical: - return Categorical(logits=self._fc(self._encoder(x))) + def forward( + self, x: TorchObservation, embedding: torch.Tensor + ) -> Categorical: + return Categorical(logits=self._fc(self._encoder(x, embedding))) - def __call__(self, x: TorchObservation) -> Categorical: - return super().__call__(x) + def __call__( + self, x: TorchObservation, embedding: torch.Tensor + ) -> Categorical: + return super().__call__(x, embedding) diff --git a/d3rlpy/models/torch/transformers.py b/d3rlpy/models/torch/transformers.py index ae1f5f54..f9129a2a 100644 --- a/d3rlpy/models/torch/transformers.py +++ b/d3rlpy/models/torch/transformers.py @@ -1,5 +1,6 @@ import math from abc import ABCMeta, abstractmethod +from typing import Optional import torch import torch.nn.functional as F @@ -412,6 +413,7 @@ def forward( return_to_go: torch.Tensor, timesteps: torch.Tensor, attention_mask: torch.Tensor, + embedding: Optional[torch.Tensor] = None, ) -> tuple[torch.Tensor, torch.Tensor]: batch_size, context_size, _ = return_to_go.shape position_embedding = self._position_encoding(timesteps) @@ -420,7 +422,18 @@ def forward( flat_x = x.reshape(-1, *x.shape[2:]) else: flat_x = [_x.reshape(-1, *_x.shape[2:]) for _x in x] - flat_state_embedding = self._encoder(flat_x) + if embedding is None: + flat_embedding = None + else: + if isinstance(embedding, torch.Tensor): + flat_embedding = embedding.reshape(-1, *embedding.shape[2:]) + else: + flat_embedding = [ + _embedding.reshape(-1, *_embedding.shape[2:]) + for _embedding in embedding + ] + + flat_state_embedding = self._encoder(flat_x, flat_embedding) state_embedding = flat_state_embedding.view( batch_size, context_size, -1 ) diff --git a/d3rlpy/torch_utility.py 
b/d3rlpy/torch_utility.py index ccd22c0c..461a0306 100644 --- a/d3rlpy/torch_utility.py +++ b/d3rlpy/torch_utility.py @@ -85,8 +85,7 @@ def sync_optimizer_state(targ_optim: Optimizer, optim: Optimizer) -> None: def map_location(device: str) -> Any: if "cuda" in device: - _, index = device.split(":") - return lambda storage, loc: storage.cuda(int(index)) + return "cuda" if "cpu" in device: return "cpu" raise ValueError(f"invalid device={device}") @@ -212,6 +211,7 @@ class TorchMiniBatch: intervals: torch.Tensor device: str numpy_batch: Optional[TransitionMiniBatch] = None + embeddings: Optional[torch.Tensor] = None @classmethod def from_batch( @@ -273,6 +273,7 @@ def from_batch( intervals=intervals, device=device, numpy_batch=batch, + embeddings=convert_to_torch_recursively(batch.embeddings, device), ) def copy_(self, src: Self) -> None: @@ -298,6 +299,7 @@ class TorchTrajectoryMiniBatch: masks: torch.Tensor # (B, L) device: str numpy_batch: Optional[TrajectoryMiniBatch] = None + embeddings: Optional[torch.Tensor] = None @classmethod def from_batch( @@ -337,6 +339,7 @@ def from_batch( masks=masks, device=device, numpy_batch=batch, + embeddings=convert_to_torch_recursively(batch.embeddings, device), ) def copy_(self, src: Self) -> None: diff --git a/d3rlpy/transformation.py b/d3rlpy/transformation.py new file mode 100644 index 00000000..8404bffe --- /dev/null +++ b/d3rlpy/transformation.py @@ -0,0 +1,13 @@ +from .serializable_config import ( + DynamicConfig, + generate_optional_config_generation, +) + + +class TransformationProtocol(DynamicConfig): ... + + +( + register_transformation_callable, + make_transformation_callable_field, +) = generate_optional_config_generation(TransformationProtocol) diff --git a/reproductions/offline/qdt.py b/reproductions/offline/qdt.py index 0886f34e..67da1b27 100644 --- a/reproductions/offline/qdt.py +++ b/reproductions/offline/qdt.py @@ -120,8 +120,9 @@ def relabel_dataset_rtg( sampled_actions = q_algo.sample_action(episode.observations) v = q_algo.predict_value(episode.observations, sampled_actions) values.append( - v if q_algo.reward_scaler is None - else q_algo.reward_scaler.reverse_transform(v) + v + if q_algo.reward_scaler is None + else q_algo.reward_scaler.reverse_transform(v) ) value = np.array(values).mean(axis=0) rewards = np.squeeze(episode.rewards, axis=1) diff --git a/tests/algos/qlearning/algo_test.py b/tests/algos/qlearning/algo_test.py index 8463041e..01db4022 100644 --- a/tests/algos/qlearning/algo_test.py +++ b/tests/algos/qlearning/algo_test.py @@ -150,7 +150,7 @@ def predict_tester( ) -> None: algo.create_impl(observation_shape, action_size) x = create_observations(observation_shape, 100) - y = algo.predict(x) + y = algo.predict(x, None) if algo.get_action_type() == ActionSpace.DISCRETE: assert y.shape == (100,) else: @@ -258,8 +258,8 @@ def save_and_load_tester( if deterministic_best_action: observations = create_observations(observation_shape, 100) - action1 = algo.predict(observations) - action2 = algo2.predict(observations) + action1 = algo.predict(observations, None) + action2 = algo2.predict(observations, None) assert np.all(action1 == action2) except NotImplementedError: # check interface at least @@ -299,6 +299,7 @@ def update_tester( rewards_to_go=rewards_to_go, terminal=terminal, interval=1, + embedding=None, ) transitions.append(transition) @@ -370,7 +371,7 @@ def save_policy_tester( if deterministic_best_action: action = action.detach().numpy() observations = convert_to_numpy_recursively(torch_observations) - assert 
np.allclose(action, algo.predict(observations), atol=1e-5) + assert np.allclose(action, algo.predict(observations, None), atol=1e-5) # check save_policy as ONNX algo.save_policy(os.path.join("test_data", "model.onnx")) @@ -393,4 +394,4 @@ def save_policy_tester( # TODO: check probablistic policy # https://github.com/pytorch/pytorch/pull/25753 if deterministic_best_action: - assert np.allclose(action, algo.predict(observations), atol=1e-5) + assert np.allclose(action, algo.predict(observations, None), atol=1e-5) diff --git a/tests/algos/transformer/algo_test.py b/tests/algos/transformer/algo_test.py index a5137ff1..5e19c2c2 100644 --- a/tests/algos/transformer/algo_test.py +++ b/tests/algos/transformer/algo_test.py @@ -122,6 +122,7 @@ def predict_tester( rewards=np.random.random((context_size, 1)).astype(np.float32), returns_to_go=np.random.random((context_size, 1)).astype(np.float32), timesteps=np.arange(context_size), + embeddings=None, ) y = algo.predict(inpt) if algo.get_action_type() == ActionSpace.DISCRETE: @@ -152,8 +153,8 @@ def save_and_load_tester( actor2 = algo2.as_stateful_wrapper(0, action_sampler) observation = create_observation(observation_shape) - action1 = actor1.predict(observation, 0) - action2 = actor2.predict(observation, 0) + action1 = actor1.predict(observation, 0, None) + action2 = actor2.predict(observation, 0, None) assert np.all(action1 == action2) @@ -184,6 +185,7 @@ def update_tester( timesteps=np.arange(context_size), masks=np.zeros(context_size, dtype=np.float32), length=context_size, + embeddings=None, ) trajectories.append(trajectory) @@ -220,7 +222,7 @@ def stateful_wrapper_tester( # check predict for _ in range(10): observation, reward = create_observation(observation_shape), 0.0 - action = wrapper.predict(observation, reward) + action = wrapper.predict(observation, reward, None) if algo.get_action_type() == ActionSpace.DISCRETE: assert isinstance(action, int) else: @@ -230,14 +232,14 @@ def stateful_wrapper_tester( # check reset observation1, reward1 = create_observation(observation_shape), 0.0 - action1 = wrapper.predict(observation1, reward1) + action1 = wrapper.predict(observation1, reward1, None) observation, reward = create_observation(observation_shape), 0.0 - action2 = wrapper.predict(observation, reward) + action2 = wrapper.predict(observation, reward, None) # in discrete case, there is high chance that action is the same. if algo.get_action_type() == ActionSpace.CONTINUOUS: assert np.all(action1 != action2) wrapper.reset() - action3 = wrapper.predict(observation1, reward1) + action3 = wrapper.predict(observation1, reward1, None) assert np.all(action1 == action3) @@ -291,6 +293,7 @@ def save_policy_tester( rewards=inputs[num_observations + 1].numpy(), returns_to_go=inputs[num_observations + 1].numpy(), timesteps=inputs[num_observations + 2].numpy(), + embeddings=None, ) if algo.get_action_type() == ActionSpace.DISCRETE: assert action == algo.predict(inpt).argmax()
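A few usage sketches follow; they are illustrative only and not part of the patch. First, the dataset layer now threads an optional per-step embedding from Episode through Transition and PartialTrajectory into the mini-batches. A sketch of building episodes with embeddings through EpisodeGenerator; array sizes are made up, and the embedding array is assumed to be aligned one-to-one with observations.

import numpy as np

from d3rlpy.dataset import EpisodeGenerator

n_steps, obs_dim, emb_dim = 1000, 8, 16
observations = np.random.random((n_steps, obs_dim)).astype(np.float32)
actions = np.random.randint(4, size=(n_steps, 1))
rewards = np.random.random((n_steps, 1)).astype(np.float32)
terminals = np.zeros(n_steps, dtype=np.float32)
terminals[-1] = 1.0

# new: one embedding vector per step, carried alongside observations
embeddings = np.random.random((n_steps, emb_dim)).astype(np.float32)

episodes = EpisodeGenerator(
    observations=observations,
    actions=actions,
    rewards=rewards,
    terminals=terminals,
    embeddings=embeddings,
)()
assert episodes[0].embeddings is not None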
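Next, predict() on Q-learning algos now takes an explicit embedding argument, so existing callers have to pass None. A minimal example against the updated DiscreteBC, mirroring the create_impl() pattern used in the tests above; the shapes and config values are made up for illustration.

import numpy as np
import d3rlpy

observation_shape = (8,)  # hypothetical flat observation space
action_size = 4

config = d3rlpy.algos.DiscreteBCConfig(
    beta=0.5,
    entropy_beta=0.01,                # new: weight of the entropy bonus in the loss
    label_smoothing=0.1,              # new: forwarded to F.cross_entropy
    automatic_mixed_precision=False,  # new: wraps the update in autocast/GradScaler
)
bc = config.create(device="cpu")
bc.create_impl(observation_shape, action_size)

x = np.random.random((100, *observation_shape)).astype(np.float32)

# embedding is now a required positional argument; pass None when no
# per-observation embedding is available.
actions = bc.predict(x, None)
assert actions.shape == (100,)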
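The transformer-side inference API changes the same way: TransformerInput gains an embeddings field and StatefulTransformerWrapper.predict() takes the embedding as a third argument. A sketch with an untrained DiscreteDecisionTransformer and a random observation, again purely for illustration; pass None when embeddings are not used.

import numpy as np
import d3rlpy

observation_shape = (8,)
action_size = 4

config = d3rlpy.algos.DiscreteDecisionTransformerConfig(context_size=20)
dt = config.create(device="cpu")
dt.create_impl(observation_shape, action_size)

actor = dt.as_stateful_wrapper(100.0)  # target return, default action sampler
observation = np.random.random(observation_shape).astype(np.float32)
reward = 0.0

# embedding is the new third argument; pass None when unused
action = actor.predict(observation, reward, None)
actor.reset()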
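Finally, the new transform field on LearnableConfig is applied to each TorchMiniBatch (or TorchTrajectoryMiniBatch) right before the gradient update. A rough sketch of a user-defined transformation, assuming flat tensor observations; GaussianObservationNoise and its type string are hypothetical, and the get_type()/registration pattern is assumed to follow the existing dynamic-config machinery that generate_optional_config_generation builds on.

import dataclasses

import torch

from d3rlpy import (
    TorchMiniBatch,
    TransformationProtocol,
    register_transformation_callable,
)
from d3rlpy.algos import DiscreteBCConfig


@dataclasses.dataclass()
class GaussianObservationNoise(TransformationProtocol):
    # Hypothetical data augmentation: add Gaussian noise to observations.
    std: float = 0.01

    def __call__(self, batch: TorchMiniBatch) -> TorchMiniBatch:
        # assumes non-tuple observations
        noisy = batch.observations + self.std * torch.randn_like(batch.observations)
        return dataclasses.replace(batch, observations=noisy)

    @staticmethod
    def get_type() -> str:
        return "gaussian_observation_noise"


# register the class so save_config()/load_learnable() can round-trip it
register_transformation_callable(GaussianObservationNoise)

config = DiscreteBCConfig(transform=GaussianObservationNoise(std=0.05))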