diff --git a/allenact_plugins/ithor_plugin/ithor_environment.py b/allenact_plugins/ithor_plugin/ithor_environment.py index 50c5db406..c77b08ada 100644 --- a/allenact_plugins/ithor_plugin/ithor_environment.py +++ b/allenact_plugins/ithor_plugin/ithor_environment.py @@ -15,6 +15,8 @@ from allenact.utils.system import get_logger from allenact_plugins.ithor_plugin.ithor_constants import VISIBILITY_DISTANCE, FOV from allenact_plugins.ithor_plugin.ithor_util import round_to_factor +from ai2thor.util import metrics +from allenact.utils.cache_utils import DynamicDistanceCache class IThorEnvironment(object): @@ -31,6 +33,7 @@ class IThorEnvironment(object): def __init__( self, + all_metadata_available: bool = True, x_display: Optional[str] = None, docker_enabled: bool = False, local_thor_build: Optional[str] = None, @@ -38,11 +41,16 @@ def __init__( fov: float = FOV, player_screen_width: int = 300, player_screen_height: int = 300, + grid_size: float = 0.25, + rotate_step_degrees: int = 90, quality: str = "Very Low", restrict_to_initially_reachable_points: bool = False, make_agents_visible: bool = True, object_open_speed: float = 1.0, simplify_physics: bool = False, + snap_to_grid: bool = True, + agent_count: int = 1, + **kwargs, ) -> None: """Initializer. @@ -81,11 +89,14 @@ def __init__( self.controller: Optional[Controller] = None self._started = False self._quality = quality + self._snap_to_grid = snap_to_grid + self.agent_count = agent_count self._initially_reachable_points: Optional[List[Dict]] = None self._initially_reachable_points_set: Optional[Set[Tuple[float, float]]] = None self._move_mag: Optional[float] = None - self._grid_size: Optional[float] = None + self._grid_size: Optional[float] = grid_size + self._rotate_step_degrees = rotate_step_degrees self._visibility_distance = visibility_distance self._fov = fov self.restrict_to_initially_reachable_points = ( @@ -95,10 +106,113 @@ def __init__( self.object_open_speed = object_open_speed self._always_return_visible_range = False self.simplify_physics = simplify_physics + self.all_metadata_available = all_metadata_available + + self.scene_to_reachable_positions: Optional[Dict[str, Any]] = None + self.distance_cache: Optional[DynamicDistanceCache] = None self.start(None) # noinspection PyTypeHints + if self.all_metadata_available: + self.scene_to_reachable_positions = { + self.scene_name: copy.deepcopy(self.currently_reachable_points) + } + assert len(self.scene_to_reachable_positions[self.scene_name]) > 10 + + self.distance_cache = DynamicDistanceCache(rounding=1) self.controller.docker_enabled = docker_enabled # type: ignore + self._extra_teleport_kwargs: Dict[ + str, Any + ] = {} # Used for backwards compatability with the teleport action + + def path_from_point_to_object_type( + self, point: Dict[str, float], object_type: str, allowed_error: float + ) -> Optional[List[Dict[str, float]]]: + event = self.controller.step( + action="GetShortestPath", + objectType=object_type, + position=point, + allowedError=allowed_error, + ) + if event.metadata["lastActionSuccess"]: + return event.metadata["actionReturn"]["corners"] + else: + get_logger().debug( + "Failed to find path for {} in {}. 
Start point {}, agent state {}.".format( + object_type, + self.controller.last_event.metadata["sceneName"], + point, + self.agent_state(), + ) + ) + return None + + def distance_from_point_to_object_type( + self, point: Dict[str, float], object_type: str, allowed_error: float + ) -> float: + """Minimal geodesic distance from a point to an object of the given + type. + It might return -1.0 for unreachable targets. + """ + path = self.path_from_point_to_object_type(point, object_type, allowed_error) + if path: + # Because `allowed_error != 0` means that the path returned above might not start + # at `point`, we explicitly add any offset there is. + s_dist = math.sqrt( + (point["x"] - path[0]["x"]) ** 2 + (point["z"] - path[0]["z"]) ** 2 + ) + return metrics.path_distance(path) + s_dist + return -1.0 + + def distance_to_object_type(self, object_type: str, agent_id: int = 0) -> float: + """Minimal geodesic distance to object of given type from agent's + current location. + It might return -1.0 for unreachable targets. + """ + assert 0 <= agent_id < self.agent_count + assert ( + self.all_metadata_available + ), "`distance_to_object_type` cannot be called when `self.all_metadata_available` is `False`." + + def retry_dist(position: Dict[str, float], object_type: str): + allowed_error = 0.05 + debug_log = "" + d = -1.0 + while allowed_error < 2.5: + d = self.distance_from_point_to_object_type( + position, object_type, allowed_error + ) + if d < 0: + debug_log = ( + f"In scene {self.scene_name}, could not find a path from {position} to {object_type} with" + f" {allowed_error} error tolerance. Increasing this tolerance to" + f" {2 * allowed_error} any trying again." + ) + allowed_error *= 2 + else: + break + if d < 0: + get_logger().warning( + f"In scene {self.scene_name}, could not find a path from {position} to {object_type}" + f" with {allowed_error} error tolerance. Returning a distance of -1." + ) + elif debug_log != "": + get_logger().debug(debug_log) + return d + + return self.distance_cache.find_distance( + self.scene_name, + self.controller.last_event.events[agent_id].metadata["agent"]["position"], + object_type, + retry_dist, + ) + + @property + def currently_reachable_points(self) -> List[Dict[str, float]]: + """List of {"x": x, "y": y, "z": z} locations in the scene that are + currently reachable.""" + self.step({"action": "GetReachablePositions"}) + return self.last_event.metadata["actionReturn"] # type:ignore @property def scene_name(self) -> str: @@ -164,10 +278,7 @@ def last_action_return(self, value: Any) -> None: self.controller.last_event.metadata["actionReturn"] = value def start( - self, - scene_name: Optional[str], - move_mag: float = 0.25, - **kwargs, + self, scene_name: Optional[str], move_mag: float = 0.25, **kwargs, ) -> None: """Starts the ai2thor controller if it was previously stopped. @@ -189,8 +300,12 @@ def start( width=self._start_player_screen_width, height=self._start_player_screen_height, local_executable_path=self._local_thor_build, + snapToGrid=self._snap_to_grid, quality=self._quality, server_class=ai2thor.fifo_server.FifoServer, + gridSize=self._grid_size, + rotateStepDegrees=self._rotate_step_degrees, + visibilityDistance=self._visibility_distance, ) if ( @@ -218,10 +333,7 @@ def stop(self) -> None: self._started = False def reset( - self, - scene_name: Optional[str], - move_mag: float = 0.25, - **kwargs, + self, scene_name: Optional[str], move_mag: float = 0.25, **kwargs, ): """Resets the ai2thor in a new scene. 
@@ -529,6 +641,17 @@ def currently_reachable_points(self) -> List[Dict[str, float]]: self.step({"action": "GetReachablePositions"}) return self.last_event.metadata["actionReturn"] # type:ignore + def agent_state(self, agent_id: int = 0) -> Dict: + """Return agent position, rotation and horizon.""" + assert 0 <= agent_id < self.agent_count + + agent_meta = self.last_event.events[agent_id].metadata["agent"] + return { + **{k: float(v) for k, v in agent_meta["position"].items()}, + "rotation": {k: float(v) for k, v in agent_meta["rotation"].items()}, + "horizon": round(float(agent_meta["cameraHorizon"]), 1), + } + def get_agent_location(self) -> Dict[str, Union[float, bool]]: """Gets agent's location.""" metadata = self.controller.last_event.metadata @@ -728,6 +851,36 @@ def step( return sr + def set_object_filter(self, object_ids: List[str]): + self.controller.step("SetObjectFilter", objectIds=object_ids, renderImage=False) + + def reset_object_filter(self): + self.controller.step("ResetObjectFilter", renderImage=False) + + def teleport( + self, + position: Dict[str, float], + rotation: Dict[str, float], + horizon: float = 0.0, + ): + try: + e = self.controller.step( + action="TeleportFull", + x=position["x"], + y=position["y"], + z=position["z"], + rotation=rotation, + horizon=horizon, + **self._extra_teleport_kwargs, + ) + except ValueError as e: + if len(self._extra_teleport_kwargs) == 0: + self._extra_teleport_kwargs["standing"] = True + else: + raise e + return self.teleport(position=position, rotation=rotation, horizon=horizon) + return e.metadata["lastActionSuccess"] + @staticmethod def position_dist( p0: Mapping[str, Any], diff --git a/allenact_plugins/ithor_plugin/ithor_task_samplers.py b/allenact_plugins/ithor_plugin/ithor_task_samplers.py index e43b699af..19d0f4e6a 100644 --- a/allenact_plugins/ithor_plugin/ithor_task_samplers.py +++ b/allenact_plugins/ithor_plugin/ithor_task_samplers.py @@ -1,6 +1,8 @@ import copy import random -from typing import List, Dict, Optional, Any, Union, cast +import gzip +import json +from typing import List, Optional, Union, Dict, Any, cast import gym @@ -198,3 +200,219 @@ def set_seed(self, seed: int): self.seed = seed if seed is not None: set_seed(seed) + + +class ObjectNaviThorDatasetTaskSampler(TaskSampler): + def __init__( + self, + scenes: List[str], + scene_directory: str, + sensors: List[Sensor], + max_steps: int, + env_args: Dict[str, Any], + action_space: gym.Space, + rewards_config: Dict, + seed: Optional[int] = None, + deterministic_cudnn: bool = False, + loop_dataset: bool = True, + allow_flipping=False, + env_class=IThorEnvironment, + **kwargs, + ) -> None: + self.rewards_config = rewards_config + self.env_args = env_args + self.scenes = scenes + self.episodes = { + scene: ObjectNaviThorDatasetTaskSampler.load_dataset(scene, scene_directory) + for scene in scenes + } + self.env_class = env_class + self.object_types = [ + ep["object_type"] for scene in self.episodes for ep in self.episodes[scene] + ] + self.env: Optional[IThorEnvironment] = None + self.sensors = sensors + self.max_steps = max_steps + self._action_space = action_space + self.allow_flipping = allow_flipping + self.scene_counter: Optional[int] = None + self.scene_order: Optional[List[str]] = None + self.scene_id: Optional[int] = None + # get the total number of tasks assigned to this process + if loop_dataset: + self.max_tasks = None + else: + self.max_tasks = sum(len(self.episodes[scene]) for scene in self.episodes) + self.reset_tasks = self.max_tasks + self.scene_index = 0 
+ self.episode_index = 0 + + self._last_sampled_task: Optional[ObjectNaviThorGridTask] = None + + self.seed: Optional[int] = None + self.set_seed(seed) + + if deterministic_cudnn: + set_deterministic_cudnn() + + self.reset() + + def _create_environment(self) -> IThorEnvironment: + env = self.env_class( + make_agents_visible=False, + object_open_speed=0.05, + restrict_to_initially_reachable_points=False, + **self.env_args, + ) + return env + + @staticmethod + def load_dataset(scene: str, base_directory: str) -> List[Dict]: + filename = ( + "/".join([base_directory, scene]) + if base_directory[-1] != "/" + else "".join([base_directory, scene]) + ) + filename += ".json.gz" + fin = gzip.GzipFile(filename, "r") + json_bytes = fin.read() + fin.close() + json_str = json_bytes.decode("utf-8") + data = json.loads(json_str) + random.shuffle(data) + return data + + @staticmethod + def load_distance_cache_from_file(scene: str, base_directory: str) -> Dict: + filename = ( + "/".join([base_directory, scene]) + if base_directory[-1] != "/" + else "".join([base_directory, scene]) + ) + filename += ".json.gz" + fin = gzip.GzipFile(filename, "r") + json_bytes = fin.read() + fin.close() + json_str = json_bytes.decode("utf-8") + data = json.loads(json_str) + return data + + @property + def __len__(self) -> Union[int, float]: + """Length. + + # Returns + + Number of total tasks remaining that can be sampled. Can be float('inf'). + """ + return float("inf") if self.max_tasks is None else self.max_tasks + + @property + def total_unique(self) -> Optional[Union[int, float]]: + return self.reset_tasks + + @property + def last_sampled_task(self) -> Optional[ObjectNaviThorGridTask]: + return self._last_sampled_task + + def close(self) -> None: + if self.env is not None: + self.env.stop() + + @property + def all_observation_spaces_equal(self) -> bool: + """Check if observation spaces equal. + + # Returns + + True if all Tasks that can be sampled by this sampler have the + same observation space. Otherwise False. + """ + return True + + @property + def length(self) -> Union[int, float]: + """Length. + + # Returns + + Number of total tasks remaining that can be sampled. Can be float('inf'). 
+ """ + return float("inf") if self.max_tasks is None else self.max_tasks + + def next_task( + self, force_advance_scene: bool = False + ) -> Optional[ObjectNaviThorGridTask]: + if self.max_tasks is not None and self.max_tasks <= 0: + return None + + if self.episode_index >= len(self.episodes[self.scenes[self.scene_index]]): + self.scene_index = (self.scene_index + 1) % len(self.scenes) + # shuffle the new list of episodes to train on + random.shuffle(self.episodes[self.scenes[self.scene_index]]) + self.episode_index = 0 + scene = self.scenes[self.scene_index] + episode = self.episodes[scene][self.episode_index] + if self.env is None: + self.env = self._create_environment() + + if scene.replace("_physics", "") != self.env.scene_name.replace("_physics", ""): + self.env.reset(scene_name=scene) + else: + self.env.reset_object_filter() + + self.env.set_object_filter( + object_ids=[ + o["objectId"] + for o in self.env.last_event.metadata["objects"] + if o["objectType"] == episode["object_type"] + ] + ) + + task_info = {"scene": scene, "object_type": episode["object_type"]} + if len(task_info) == 0: + get_logger().warning( + "Scene {} does not contain any" + " objects of any of the types {}.".format(scene, self.object_types) + ) + task_info["initial_position"] = episode["initial_position"] + task_info["initial_orientation"] = episode["initial_orientation"] + task_info["initial_horizon"] = episode.get("initial_horizon", 0) + task_info["distance_to_target"] = episode.get("shortest_path_length") + task_info["path_to_target"] = episode.get("shortest_path") + task_info["object_type"] = episode["object_type"] + task_info["id"] = episode["id"] + if self.allow_flipping and random.random() > 0.5: + task_info["mirrored"] = True + else: + task_info["mirrored"] = False + + self.episode_index += 1 + if self.max_tasks is not None: + self.max_tasks -= 1 + if not self.env.teleport( + position=episode["initial_position"], + rotation=episode["initial_orientation"], + horizon=episode.get("initial_horizon", 0), + ): + return self.next_task() + self._last_sampled_task = ObjectNaviThorGridTask( + env=self.env, + sensors=self.sensors, + task_info=task_info, + max_steps=self.max_steps, + action_space=self._action_space, + reward_configs=self.rewards_config, + ) + + return self._last_sampled_task + + def reset(self): + self.episode_index = 0 + self.scene_index = 0 + self.max_tasks = self.reset_tasks + + def set_seed(self, seed: int): + self.seed = seed + if seed is not None: + set_seed(seed) diff --git a/allenact_plugins/ithor_plugin/ithor_tasks.py b/allenact_plugins/ithor_plugin/ithor_tasks.py index 75670db57..44664c686 100644 --- a/allenact_plugins/ithor_plugin/ithor_tasks.py +++ b/allenact_plugins/ithor_plugin/ithor_tasks.py @@ -3,6 +3,7 @@ import gym import numpy as np +import math from allenact.base_abstractions.misc import RLStepResult from allenact.base_abstractions.sensor import Sensor @@ -65,6 +66,7 @@ def __init__( sensors: List[Sensor], task_info: Dict[str, Any], max_steps: int, + reward_configs: Dict[str, Any], **kwargs, ) -> None: """Initializer. 
@@ -74,15 +76,30 @@ def __init__( super().__init__( env=env, sensors=sensors, task_info=task_info, max_steps=max_steps, **kwargs ) + + self._rewards: List[float] = [] + self.reward_configs = reward_configs self._took_end_action: bool = False self._success: Optional[bool] = False self._subsampled_locations_from_which_obj_visible: Optional[ List[Tuple[float, float, int, int]] ] = None + self._all_metadata_available = env.all_metadata_available + self.path: List = ( + [] + ) # the initial coordinate will be directly taken from the optimal path + self.travelled_distance = 0.0 self.task_info["followed_path"] = [self.env.get_agent_location()] self.task_info["action_names"] = self.class_action_names() + if self._all_metadata_available: + self.last_geodesic_distance = self.env.distance_to_object_type( + self.task_info["object_type"] + ) + self.optimal_distance = self.last_geodesic_distance + self.closest_geo_distance = self.last_geodesic_distance + @property def action_space(self): return gym.spaces.Discrete(len(self._actions)) @@ -116,13 +133,24 @@ def _step(self, action: Union[int, Sequence[int]]) -> RLStepResult: ) and self._CACHED_LOCATIONS_FROM_WHICH_OBJECT_IS_VISIBLE is not None: self.env.update_graph_with_failed_action(failed_action=action_str) - self.task_info["followed_path"].append(self.env.get_agent_location()) + pose = self.env.agent_state() + + self.path.append({k: pose[k] for k in ["x", "y", "z"]}) + self.task_info["followed_path"].append(pose) + if len(self.path) > 1: + self.travelled_distance += IThorEnvironment.position_dist( + p0=self.path[-1], p1=self.path[-2], ignore_y=True + ) + # self.task_info["followed_path"].append(self.env.get_agent_location()) step_result = RLStepResult( observation=self.get_observations(), reward=self.judge(), done=self.is_done(), - info={"last_action_success": self.last_action_success}, + info={ + "last_action_success": self.last_action_success, + "action": action_str, + }, ) return step_result @@ -137,15 +165,70 @@ def is_goal_object_visible(self) -> bool: for o in self.env.visible_objects() ) + def shaping(self) -> float: + rew = 0.0 + + if self.reward_configs["shaping_weight"] == 0.0: + return rew + + geodesic_distance = self.env.distance_to_object_type( + self.task_info["object_type"] + ) + + # Ensuring the reward magnitude is not greater than the total distance moved + max_reward_mag = 0.0 + if len(self.path) >= 2: + p0, p1 = self.path[-2:] + max_reward_mag = math.sqrt( + (p0["x"] - p1["x"]) ** 2 + (p0["z"] - p1["z"]) ** 2 + ) + + if self.reward_configs.get("positive_only_reward", False): + if geodesic_distance > 0.5: + rew = max(self.closest_geo_distance - geodesic_distance, 0) + else: + if ( + self.last_geodesic_distance > -0.5 and geodesic_distance > -0.5 + ): # (robothor limits) + rew += self.last_geodesic_distance - geodesic_distance + + self.last_geodesic_distance = geodesic_distance + self.closest_geo_distance = min(self.closest_geo_distance, geodesic_distance) + + return ( + max( + min(rew, max_reward_mag), + -max_reward_mag, + ) + * self.reward_configs["shaping_weight"] + ) + def judge(self) -> float: + """Judge the last event.""" + reward = self.reward_configs["step_penalty"] + + reward += self.shaping() + + if self._took_end_action: + if self._success: + reward += self.reward_configs["goal_success_reward"] + else: + reward += self.reward_configs["failed_stop_reward"] + elif self.num_steps_taken() + 1 >= self.max_steps: + reward += self.reward_configs.get("reached_max_steps_reward", 0.0) + + self._rewards.append(float(reward)) + return 
float(reward) + + def judge_old(self) -> float: """Compute the reward after having taken a step.""" reward = -0.01 if not self.last_action_success: - reward += -0.03 + reward += -0.00 if self._took_end_action: - reward += 1.0 if self._success else -1.0 + reward += 10.0 if self._success else -0.0 return float(reward) diff --git a/projects/tutorials/object_nav_ithor_ppo_baseline.py b/projects/tutorials/object_nav_ithor_ppo_baseline.py new file mode 100644 index 000000000..ed7a3d68c --- /dev/null +++ b/projects/tutorials/object_nav_ithor_ppo_baseline.py @@ -0,0 +1,383 @@ +from math import ceil +from typing import Dict, Any, List, Optional, Sequence +import glob +import os +import gym +import numpy as np +import torch +import torch.nn as nn +import torch.optim as optim +from torch.optim.lr_scheduler import LambdaLR +from torchvision import models + +from allenact.base_abstractions.preprocessor import SensorPreprocessorGraph + +from allenact.embodiedai.preprocessors.resnet import ResNetPreprocessor +from allenact.embodiedai.sensors.vision_sensors import RGBSensor, DepthSensor +from allenact.embodiedai.preprocessors.resnet import ResNetPreprocessor +from allenact.utils.experiment_utils import Builder +from allenact.utils.experiment_utils import evenly_distribute_count_into_bins +from allenact.algorithms.onpolicy_sync.losses import PPO +from allenact.algorithms.onpolicy_sync.losses.ppo import PPOConfig +from allenact.base_abstractions.experiment_config import ExperimentConfig, MachineParams +from allenact.base_abstractions.sensor import SensorSuite +from allenact.base_abstractions.task import TaskSampler +from allenact.utils.experiment_utils import ( + Builder, + PipelineStage, + TrainingPipeline, + LinearDecay, +) +from allenact_plugins.ithor_plugin.ithor_sensors import ( + RGBSensorThor, + GoalObjectTypeThorSensor, +) +from allenact_plugins.ithor_plugin.ithor_task_samplers import ( + ObjectNaviThorDatasetTaskSampler, +) +from allenact_plugins.ithor_plugin.ithor_tasks import ObjectNaviThorGridTask +from projects.objectnav_baselines.models.object_nav_models import ( + ResnetTensorObjectNavActorCritic, +) + + +class ObjectNavThorPPOExperimentConfig(ExperimentConfig): + """A simple object navigation experiment in THOR. + + Training with PPO. 
+ """ + + # A simple setting, train/valid/test are all the same single scene + # and we're looking for a single object + OBJECT_TYPES = sorted( + [ + "AlarmClock", + "Apple", + "Book", + "Bowl", + "Box", + "Candle", + "GarbageCan", + "HousePlant", + "Laptop", + "SoapBottle", + "Television", + "Toaster", + ] + ) + train_path = os.path.join( + os.getcwd(), "datasets/ithor-objectnav/train/episodes", "*.json.gz" + ) + val_path = os.path.join( + os.getcwd(), "datasets/ithor-objectnav/val/episodes", "*.json.gz" + ) + test_path = os.path.join( + os.getcwd(), "datasets/ithor-objectnav/val/episodes", "*.json.gz" + ) + + TRAIN_SCENES = [ + os.path.basename(scene).split(".")[0] for scene in glob.glob(train_path) + ] + VALID_SCENES = [ + os.path.basename(scene).split(".")[0] for scene in glob.glob(val_path) + ] + TEST_SCENES = [ + os.path.basename(scene).split(".")[0] for scene in glob.glob(test_path) + ] + + # Setting up sensors and basic environment details + CAMERA_WIDTH = 400 + CAMERA_HEIGHT = 300 + SCREEN_SIZE = 224 + SENSORS = [ + RGBSensorThor( + height=SCREEN_SIZE, + width=SCREEN_SIZE, + use_resnet_normalization=True, + uuid="rgb_lowres", + ), + GoalObjectTypeThorSensor(object_types=OBJECT_TYPES), + ] + + PREPROCESSORS = [ + Builder( + ResNetPreprocessor, + { + "input_height": SCREEN_SIZE, + "input_width": SCREEN_SIZE, + "output_width": 7, + "output_height": 7, + "output_dims": 512, + "pool": False, + "torchvision_resnet_model": models.resnet18, + "input_uuids": ["rgb_lowres"], + "output_uuid": "rgb_resnet", + }, + ), + ] + + ENV_ARGS = { + "player_screen_height": CAMERA_HEIGHT, + "player_screen_width": CAMERA_WIDTH, + "quality": "Very Low", + "rotate_step_degrees": 30, + "visibility_distance": 1.0, + "grid_size": 0.25, + "snap_to_grid": False, + } + + MAX_STEPS = 500 + REWARD_CONFIG = { + "step_penalty": -0.01, + "goal_success_reward": 10.0, + "failed_stop_reward": 0.0, + "shaping_weight": 1.0, + } + ADVANCE_SCENE_ROLLOUT_PERIOD: Optional[int] = None + VALID_SAMPLES_IN_SCENE = 10 + TEST_SAMPLES_IN_SCENE = 100 + + DEFAULT_TRAIN_GPU_IDS = tuple(range(torch.cuda.device_count())) + DEFAULT_VALID_GPU_IDS = (torch.cuda.device_count() - 1,) + DEFAULT_TEST_GPU_IDS = (torch.cuda.device_count() - 1,) + + @classmethod + def tag(cls): + return "ObjectNaviThorPPOResnetGRU" + + @classmethod + def training_pipeline(cls, **kwargs): + ppo_steps = int(300000000) + lr = 3e-4 + num_mini_batch = 1 if not torch.cuda.is_available() else 6 + update_repeats = 4 + num_steps = 128 + metric_accumulate_interval = 10000 # Log every 10 max length tasks + save_interval = 5000000 + gamma = 0.99 + use_gae = True + gae_lambda = 0.95 + max_grad_norm = 0.5 + + return TrainingPipeline( + save_interval=save_interval, + metric_accumulate_interval=metric_accumulate_interval, + optimizer_builder=Builder(optim.Adam, dict(lr=lr)), + num_mini_batch=num_mini_batch, + update_repeats=update_repeats, + max_grad_norm=max_grad_norm, + num_steps=num_steps, + named_losses={ + "ppo_loss": PPO(clip_decay=LinearDecay(ppo_steps), **PPOConfig), + }, + gamma=gamma, + use_gae=use_gae, + gae_lambda=gae_lambda, + advance_scene_rollout_period=cls.ADVANCE_SCENE_ROLLOUT_PERIOD, + pipeline_stages=[ + PipelineStage(loss_names=["ppo_loss"], max_stage_steps=ppo_steps,), + ], + lr_scheduler_builder=Builder( + LambdaLR, {"lr_lambda": LinearDecay(steps=ppo_steps)} + ), + ) + + @classmethod + def machine_params(cls, mode="train", **kwargs): + + if mode == "train": + nprocesses = 40 + workers_per_device = 1 + gpu_ids = ( + [torch.device("cpu")] + if not 
torch.cuda.is_available() + else cls.DEFAULT_TRAIN_GPU_IDS * workers_per_device + ) + nprocesses = evenly_distribute_count_into_bins( + nprocesses, max(len(gpu_ids), 1) + ) + sampler_devices = cls.DEFAULT_TRAIN_GPU_IDS + elif mode == "valid": + nprocesses = 1 + gpu_ids = [] if not torch.cuda.is_available() else cls.DEFAULT_VALID_GPU_IDS + elif mode == "test": + nprocesses = 1 + gpu_ids = [] if not torch.cuda.is_available() else cls.DEFAULT_TEST_GPU_IDS + else: + raise NotImplementedError("mode must be 'train', 'valid', or 'test'.") + + sensor_preprocessor_graph = ( + SensorPreprocessorGraph( + source_observation_spaces=SensorSuite(cls.SENSORS).observation_spaces, + preprocessors=cls.PREPROCESSORS, + ) + if mode == "train" + or ( + (isinstance(nprocesses, int) and nprocesses > 0) + or (isinstance(nprocesses, Sequence) and sum(nprocesses) > 0) + ) + else None + ) + + return MachineParams( + nprocesses=nprocesses, + devices=gpu_ids, + sampler_devices=sampler_devices + if mode == "train" + else gpu_ids, # ignored with > 1 gpu_ids + sensor_preprocessor_graph=sensor_preprocessor_graph, + ) + + @classmethod + def create_model(cls, **kwargs) -> nn.Module: + has_rgb = any(isinstance(s, RGBSensor) for s in cls.SENSORS) + has_depth = any(isinstance(s, DepthSensor) for s in cls.SENSORS) + goal_sensor_uuid = next( + (s.uuid for s in cls.SENSORS if isinstance(s, GoalObjectTypeThorSensor)), + None, + ) + + return ResnetTensorObjectNavActorCritic( + action_space=gym.spaces.Discrete( + len(ObjectNaviThorGridTask.class_action_names()) + ), + observation_space=kwargs["sensor_preprocessor_graph"].observation_spaces, + goal_sensor_uuid=goal_sensor_uuid, + rgb_resnet_preprocessor_uuid="rgb_resnet" if has_rgb else None, + depth_resnet_preprocessor_uuid="depth_resnet" if has_depth else None, + hidden_size=512, + goal_dims=32, + ) + + @classmethod + def make_sampler_fn(cls, **kwargs) -> TaskSampler: + return ObjectNaviThorDatasetTaskSampler(**kwargs) + + @staticmethod + def _partition_inds(n: int, num_parts: int): + return np.round(np.linspace(0, n, num_parts + 1, endpoint=True)).astype( + np.int32 + ) + + def _get_sampler_args_for_scene_split( + self, + scenes: List[str], + process_ind: int, + total_processes: int, + seeds: Optional[List[int]] = None, + deterministic_cudnn: bool = False, + ) -> Dict[str, Any]: + if total_processes > len(scenes): # oversample some scenes -> bias + if total_processes % len(scenes) != 0: + print( + "Warning: oversampling some of the scenes to feed all processes." + " You can avoid this by setting a number of workers divisible by the number of scenes" + ) + scenes = scenes * int(ceil(total_processes / len(scenes))) + scenes = scenes[: total_processes * (len(scenes) // total_processes)] + else: + if len(scenes) % total_processes != 0: + print( + "Warning: oversampling some of the scenes to feed all processes." 
+ " You can avoid this by setting a number of workers divisor of the number of scenes" + ) + inds = self._partition_inds(len(scenes), total_processes) + + return { + "scenes": scenes[inds[process_ind] : inds[process_ind + 1]], + "object_types": self.OBJECT_TYPES, + "env_args": self.ENV_ARGS, + "max_steps": self.MAX_STEPS, + "sensors": self.SENSORS, + "action_space": gym.spaces.Discrete( + len(ObjectNaviThorGridTask.class_action_names()) + ), + "seed": seeds[process_ind] if seeds is not None else None, + "deterministic_cudnn": deterministic_cudnn, + "rewards_config": self.REWARD_CONFIG, + } + + def train_task_sampler_args( + self, + process_ind: int, + total_processes: int, + devices: Optional[List[int]] = None, + seeds: Optional[List[int]] = None, + deterministic_cudnn: bool = False, + ) -> Dict[str, Any]: + res = self._get_sampler_args_for_scene_split( + self.TRAIN_SCENES, + process_ind, + total_processes, + seeds=seeds, + deterministic_cudnn=deterministic_cudnn, + ) + res["scene_directory"] = os.path.join( + os.getcwd(), "datasets/ithor-objectnav/train/episodes" + ) + res["loop_dataset"] = True + res["scene_period"] = "manual" + res["env_args"] = {} + res["env_args"].update(self.ENV_ARGS) + res["env_args"]["x_display"] = ( + ("0.%d" % devices[process_ind % len(devices)]) + if devices is not None and len(devices) > 0 + else None + ) + return res + + def valid_task_sampler_args( + self, + process_ind: int, + total_processes: int, + devices: Optional[List[int]] = None, + seeds: Optional[List[int]] = None, + deterministic_cudnn: bool = False, + ) -> Dict[str, Any]: + res = self._get_sampler_args_for_scene_split( + self.VALID_SCENES, + process_ind, + total_processes, + seeds=seeds, + deterministic_cudnn=deterministic_cudnn, + ) + res["loop_dataset"] = False + res["scene_directory"] = os.path.join( + os.getcwd(), "datasets/ithor-objectnav/val/episodes" + ) + res["env_args"] = {} + res["env_args"].update(self.ENV_ARGS) + res["env_args"]["x_display"] = ( + ("0.%d" % devices[process_ind % len(devices)]) + if devices is not None and len(devices) > 0 + else None + ) + return res + + def test_task_sampler_args( + self, + process_ind: int, + total_processes: int, + devices: Optional[List[int]] = None, + seeds: Optional[List[int]] = None, + deterministic_cudnn: bool = False, + ) -> Dict[str, Any]: + res = self._get_sampler_args_for_scene_split( + self.TEST_SCENES, + process_ind, + total_processes, + seeds=seeds, + deterministic_cudnn=deterministic_cudnn, + ) + res["scene_directory"] = os.path.join( + os.getcwd(), "datasets/ithor-objectnav/val/episodes" + ) + res["loop_dataset"] = False + res["env_args"] = {} + res["env_args"].update(self.ENV_ARGS) + res["env_args"]["x_display"] = ( + ("0.%d" % devices[process_ind % len(devices)]) + if devices is not None and len(devices) > 0 + else None + ) + return res