diff --git a/arm/c2farm/launch_utils.py b/arm/c2farm/launch_utils.py
index 1ecd086..17ccc4c 100644
--- a/arm/c2farm/launch_utils.py
+++ b/arm/c2farm/launch_utils.py
@@ -24,6 +24,23 @@ def create_replay(batch_size: int, timesteps: int, prioritisation: bool,
                   save_dir: str, cameras: list, env: Env,
                   voxel_sizes, replay_size=1e5):
+    '''
+    create the replay buffer; the extra observation/replay `Element`s are appended first
+
+    Input:
+    - batch_size: default batch size for sampling
+    - timesteps: how many timesteps are stacked together
+    - prioritisation: whether to use prioritised experience replay
+    - save_dir: directory used when saving the replay buffer to disk
+    - cameras: list of camera names
+    - env: RLBench environment
+    - voxel_sizes: list of voxel grid sizes, one per attention depth
+    - replay_size: replay buffer capacity
+
+    Output:
+    - replay_buffer: the created replay buffer
+    '''
+
    trans_indicies_size = 3 * len(voxel_sizes)
    rot_and_grip_indicies_size = (3 + 1)

@@ -74,6 +91,23 @@ def _get_action(
        bounds_offset: List[float],
        rotation_resolution: int,
        crop_augmentation: bool):
+    '''
+    extract the gripper pose from the observation and discretise it into
+    translation and rotation indices
+
+    Input:
+    - obs_tp1: the observation at the next keypoint ('tp1' = next state)
+    - rlbench_scene_bounds: bounds of the 3D scene in world coordinates
+    - voxel_sizes: list of voxel grid sizes, one per attention depth
+    - bounds_offset:
+    - rotation_resolution: size of a discrete rotation bin, in degrees
+    - crop_augmentation: whether to apply crop augmentation
+
+    Output:
+    - trans_indicies: voxel indices of the gripper position, per depth
+    - rot_and_grip_indicies: discretised euler rotation and gripper open state
+    - action: the continuous gripper pose and open state
+    '''
+
    quat = utils.normalize_quaternion(obs_tp1.gripper_pose[3:])
    if quat[-1] < 0:
        quat = -quat

@@ -115,6 +149,23 @@ def _add_keypoints_to_replay(
        bounds_offset: List[float],
        rotation_resolution: int,
        crop_augmentation: bool):
+    '''
+    add the transitions between consecutive keypoints of a demo to the replay buffer
+
+    Input:
+    - replay: the replay buffer to which the keypoint transitions are added
+    - inital_obs: the initial observation of a (sub-)trajectory
+    - demo: demo trajectory
+    - env: RLBench environment
+    - episode_keypoints: indices of the keypoints in the demo
+    - cameras: list of camera names
+    - rlbench_scene_bounds: bounds of the 3D scene in world coordinates
+    - voxel_sizes: list of voxel grid sizes, one per attention depth
+    - bounds_offset:
+    - rotation_resolution: size of a discrete rotation bin, in degrees
+    - crop_augmentation: whether to apply crop augmentation
+    '''
+
    prev_action = None
    obs = inital_obs
    for k, keypoint in enumerate(episode_keypoints):

@@ -170,6 +221,19 @@ def fill_replay(replay: ReplayBuffer,
                bounds_offset: List[float],
                rotation_resolution: int,
                crop_augmentation: bool):
+    '''
+    load the demo trajectories and add them to the replay buffer
+
+    Input:
+    - replay: replay buffer
+    - task (str): unique id of the RLBench task
+    - env: RLBench environment
+    - num_demos: number of demo trajectories
+    - demo_augmentation: whether to additionally add transitions starting from intermediate demo steps
+    - cameras (list): list of camera names
+    - rlbench_scene_bounds: bounds of the 3D scene in world coordinates
+    '''
+
    logging.info('Filling replay with demos...')
    for d_idx in range(num_demos):
diff --git a/arm/c2farm/networks.py b/arm/c2farm/networks.py
index e2ff9e1..ecefdb4 100644
--- a/arm/c2farm/networks.py
+++ b/arm/c2farm/networks.py
@@ -18,6 +18,24 @@ def __init__(self,
                 activation: str = 'relu',
                 dense_feats: int = 32,
                 include_prev_layer = False,):
+        '''
+        q-network over a volume of voxels
+
+        Input:
+        - in_channels: int,
+        - out_channels: int,
+        - out_dense: int,
+        - voxel_size: int,
+        - low_dim_size: int,
+        - kernels: int,
+        - norm: str = None,
+        - activation: str = 'relu',
+        - dense_feats: int = 32,
+        - include_prev_layer=False
+        Output:
+
+        '''
+
        super(Qattention3DNet, self).__init__()
        self._in_channels = in_channels
        self._out_channels = out_channels

@@ -127,6 +145,20 @@ def build(self):
            self._dense_feats, self._out_dense, None, None)

    def forward(self, ins, proprio, prev_layer_voxel_grid):
+        '''
+        apply a stack of 3D convolutions to calculate the q value for each voxel
+
+        Input:
+        - ins (torch.Tensor): the input voxel grid with shape (batch_size,
+                              voxel_feat, voxel_size, voxel_size, voxel_size)
+        - proprio (torch.Tensor): proprioceptive data of the robot arm
+                                  ([1, 3]) or ([128, 3])
+
+        Output:
+        - trans: q values for translation, one per voxel
+        - rot_and_grip_out: q values for the discrete rotation bins and the
+                            gripper open/close action
+        '''
+
        b, _, d, h, w = ins.shape
        x = self._input_preprocess(ins)
diff --git a/arm/c2farm/qattention_agent.py b/arm/c2farm/qattention_agent.py
index 7f9fab3..b864df1 100644
--- a/arm/c2farm/qattention_agent.py
+++ b/arm/c2farm/qattention_agent.py
@@ -36,12 +36,42 @@ def __init__(self,
        self._qnet.build()

    def _argmax_3d(self, tensor_orig):
+        '''
+        calculate the (x, y, z) index of the voxel with the maximum
+        q value in the voxel grid
+
+        Input:
+        - tensor_orig: the q value voxel grid with shape (batch_size,
+                       channel_size, voxel_size, voxel_size, voxel_size)
+
+        Output:
+        - indices: index of the voxel with the highest q value in the voxel
+                   grid, with shape (batch_size, 3)
+        '''
+
        b, c, d, h, w = tensor_orig.shape  # c will be one
        idxs = tensor_orig.view(b, c, -1).argmax(-1)
        indices = torch.cat([((idxs // h) // d), (idxs // h) % w, idxs % w], 1)
        return indices

    def choose_highest_action(self, q_trans, q_rot_grip):
+        '''
+        choose the voxel with the highest q value.
+        If `q_rot_grip` is `None`, `rot_and_grip_indicies` will be `None` as well.
+
+        Input:
+        - q_trans: the q value voxel grid with shape (batch_size,
+                   channel_size, voxel_size, voxel_size, voxel_size)
+        - q_rot_grip: (batch_size, 360//rotation_resolution*3 + 2)
+
+        Output:
+        - coords: index of the voxel with the highest q value, with
+                  shape (batch_size, 3)
+        - rot_and_grip_indicies: rotation indices for the (xyz) euler angles and
+                  whether the gripper is open (0 or 1), with shape
+                  (batch_size, 4)
+        '''
+
        coords = self._argmax_3d(q_trans)
        rot_and_grip_indicies = None
        if q_rot_grip is not None:

@@ -58,6 +88,23 @@ def choose_highest_action(self, q_trans, q_rot_grip):

    def forward(self, x, proprio, pcd, bounds=None, latent=None):
+        '''
+        Input:
+        - x (list): [rgb, pcd]
+        - proprio: the state of the robot arm
+        - pcd: point cloud
+        - bounds: bounds of the voxel grid in world coordinates
+        - latent
+
+        Output:
+        - q_trans: the q value voxel grid with shape (batch_size,
+                   channel_size, voxel_size, voxel_size, voxel_size)
+        - rot_and_grip_q: (batch_size, 360//rotation_resolution*3 + 2)
+        - voxel_grid: the voxel grid formed from the observation (rgb and
+                      point cloud) with shape (batch_size, voxel_feat,
+                      voxel_size, voxel_size, voxel_size)
+        '''
+
        # x will be list of list (list of [rgb, pcd])
        b = x[0][0].shape[0]
        pcd_flat = torch.cat(

@@ -136,6 +183,14 @@ def __init__(self,
        self._name = NAME + '_layer' + str(self._layer)

    def build(self, training: bool, device: torch.device = None):
+        '''
+        build the network
+
+        Input:
+        - training: whether the built network is in training mode
+        - device: the device (gpu/cpu) used for the agent
+        '''
+
        if device is None:
            device = torch.device('cpu')

@@ -184,6 +239,17 @@ def build(self, training: bool, device: torch.device = None):
        self._device = device

    def _extract_crop(self, pixel_action, observation):
+        '''
+        crop the observation around the pixel given by `pixel_action`
+
+        Input:
+        - pixel_action: (batch_size, 1, 2)
+        - observation: (batch_size, 1, 3, img_h, img_w)
+
+        Output:
+        - crop: the image patch cropped around `pixel_action`
+        '''
+
        # Pixel action will now be (B, 2)
        observation = stack_on_channel(observation)
        h = observation.shape[-1]

@@ -200,6 +266,22 @@ def _extract_crop(self, pixel_action, observation):
        return crop

    def _preprocess_inputs(self, replay_sample):
+        '''
+        pack the inputs.
+
+        If the layer > 0, the rgb/point-cloud images are cropped first.
+
+        Input:
+        - replay_sample: the sampled transitions from the replay buffer
+
+        Output:
+        - obs: rgb images and point clouds
+        - obs_tp1: rgb images and point clouds for the next timestep
+        - pcds: point clouds
+        - pcds_tp1: point clouds for the next timestep
+        '''
+
+
        obs, obs_tp1 = [], []
        pcds, pcds_tp1 = [], []
        self._crop_summary, self._crop_summary_tp1 = [], []

@@ -229,6 +311,28 @@ def _preprocess_inputs(self, replay_sample):
        return obs, obs_tp1, pcds, pcds_tp1

    def _act_preprocess_inputs(self, observation):
+        '''
+        pack the observation for each camera
+
+        Input:
+        - observation:
+            - front_rgb:
+            - front_point_cloud:
+            - low_dim_state:
+            - front_camera_extrinsics:
+            - front_camera_intrinsics:
+
+          If the layer (depth) > 0, the observation also contains:
+
+            - attention_coordinate:
+            - prev_layer_voxel_grid:
+            - front_pixel_coord:
+
+        Output:
+        - obs (list): [[rgb, pcd], [rgb, pcd], ...], one entry per camera
+        - pcds: list of point clouds (each with shape (batch_size, 3, h, w))
+        '''
+
        obs, pcds = [], []
        for n in self._camera_names:
            if self._layer > 0 and 'wrist' not in n:

@@ -243,6 +347,19 @@ def _act_preprocess_inputs(self, observation):
        return obs, pcds

    def _get_value_from_voxel_index(self, q, voxel_idx):
+        '''
+        extract the q value at a given voxel index from the q value voxel grid
+
+        Input:
+        - q: a voxel grid of q values with shape (batch_size, channel_size,
+             voxel_size, voxel_size, voxel_size)
+        - voxel_idx: (batch_size, 3)
+
+        Output:
+        - chosen_voxel_values: (batch_size, channel_size)
+        '''
+
+
        b, c, d, h, w = q.shape
        q_flat = q.view(b, c, d * h * w)
        flat_indicies = (voxel_idx[:, 0] * d * h + voxel_idx[:, 1] * h + voxel_idx[:, 2])[:, None].long()

@@ -251,6 +368,17 @@ def _get_value_from_voxel_index(self, q, voxel_idx):
        return chosen_voxel_values

    def _get_value_from_rot_and_grip(self, rot_grip_q, rot_and_grip_idx):
+        '''
+        extract the q values of the chosen rotation bins and gripper action
+
+        Input:
+        - rot_grip_q: (batch_size, 360//rotation_resolution*3 + 2)
+        - rot_and_grip_idx: (batch_size, 4)
+
+        Output:
+        - rot_and_grip_values: (batch_size, 4)
+        '''
+
        q_rot = torch.stack(torch.split(
            rot_grip_q[:, :-2], int(360 // self._rotation_resolution),
            dim=1), dim=1)  # B, 3, 72

@@ -263,6 +391,16 @@ def _get_value_from_rot_and_grip(self, rot_grip_q, rot_and_grip_idx):
        return rot_and_grip_values

    def update(self, step: int, replay_sample: dict) -> dict:
+        '''
+        update the policy parameters
+
+        NOTE: 'tp1' means next state
+
+        Input:
+        - step: the current training step
+        - replay_sample (dict): contains the sampled transitions
+        '''
+
        action_trans = replay_sample['trans_action_indicies'][:, -1,
                       self._layer * 3:self._layer * 3 + 3]

@@ -374,6 +512,38 @@ def update(self, step: int, replay_sample: dict) -> dict:

    def act(self, step: int, observation: dict,
            deterministic=False) -> ActResult:
        deterministic = True  # TODO: Don't explicitly explore.
+        '''
+        take the observation as input and return the action
+
+        Input:
+        - step: unused here (kept for interface compatibility)
+        - observation (dict):
+            - front_rgb: rgb image of the front camera
+            - front_point_cloud: point cloud of the front camera
+            - low_dim_state: the state of the robot arm
+            - front_camera_extrinsics
+            - front_camera_intrinsics
+
+          (when layer >= 1, the following information is also included)
+            - attention_coordinate
+            - prev_layer_voxel_grid
+            - front_pixel_coord
+
+        Output:
+        - act_result:
+            - action (tuple): contains `coords` and `rot_grip_action`
+                - coords: the gripper position as a voxel grid index
+                - rot_grip_action: discrete rotation indices and gripper open action
+            - observation_elements (dict):
+                - attention_coordinate:
+                - prev_layer_voxel_grid:
+            - info (dict):
+                - voxel_grid_depth:
+                - q_depth:
+                - voxel_idx_depth:
+        '''
+
        bounds = self._coordinate_bounds
        if self._layer > 0:
diff --git a/arm/c2farm/qattention_stack_agent.py b/arm/c2farm/qattention_stack_agent.py
index 15fa59a..52f4e04 100644
--- a/arm/c2farm/qattention_stack_agent.py
+++ b/arm/c2farm/qattention_stack_agent.py
@@ -30,6 +30,14 @@ def __init__(self,
        self._rotation_prediction_depth = rotation_prediction_depth

    def build(self, training: bool, device=None) -> None:
+        '''
+        iteratively build the network for each layer's agent
+
+        Input:
+        - training: whether the agent is in training mode
+        - device: the device (gpu/cpu) used for the agents
+        '''
+
        self._device = device
        if self._device is None:
            self._device = torch.device('cpu')
diff --git a/arm/c2farm/voxel_grid.py b/arm/c2farm/voxel_grid.py
index f3f1aaa..b81dbd4 100644
--- a/arm/c2farm/voxel_grid.py
+++ b/arm/c2farm/voxel_grid.py
@@ -17,6 +17,17 @@ def __init__(self,
                 batch_size,
                 feature_size,  # e.g. rgb or image features
                 max_num_coords: int,):
+        '''
+        module that transforms a point cloud into a voxel grid
+
+        - coord_bounds: the bounds of the 3D volume in world coordinates
+        - voxel_size: number of voxels along each dimension
+        - device: gpu or cpu
+        - batch_size:
+        - feature_size: number of feature channels per point (e.g. rgb)
+        - max_num_coords: the number of pixels/points from the images
+        '''
+
        super(VoxelGrid, self).__init__()
        self._device = device
        self._voxel_size = voxel_size

@@ -95,6 +106,17 @@ def _broadcast(self, src: torch.Tensor, other: torch.Tensor, dim: int):

    def _scatter_mean(self, src: torch.Tensor, index: torch.Tensor, out: torch.Tensor,
                      dim: int = -1):
+        '''
+        average the values from `src` at the indices given by `index` along
+        dimension `dim`, and save the result into `out`
+
+        Input:
+        - src:
+        - index:
+        - out:
+        - dim:
+        '''
+
        out = out.scatter_add_(dim, index, src)

        index_dim = dim

@@ -115,6 +137,21 @@ def _scatter_mean(self, src: torch.Tensor, index: torch.Tensor, out: torch.Tenso
        return out

    def _scatter_nd(self, indices, updates):
+        '''
+        construct the voxel grid from the per-point features and the index of
+        each point in the voxel grid
+
+        Input:
+        - indices: the indices with shape (batch_size*num_coords, 4); the last
+                   dimension holds batch_idx, x_idx, y_idx, z_idx
+        - updates: the feature of each position/point with
+                   shape (batch_size*num_coords, feature_size)
+
+        Output:
+        - voxel_grid: a volume of voxels
+        '''
+
        indices_shape = indices.shape
        num_index_dims = indices_shape[-1]
        flat_updates = updates.view((-1,))

@@ -137,6 +174,20 @@ def _scatter_nd(self, indices, updates):

    def coords_to_bounding_voxel_grid(self, coords, coord_features=None,
                                      coord_bounds=None):
+        '''
+        form the voxel grid bounded by `coord_bounds`.
+
+        Input:
+        - coords: 3D positions, i.e., the point cloud, with shape
+                  (self._batch_size, self._num_coords, 3)
+        - coord_features: feature for each point/position (e.g. RGB) with
+                          shape (self._batch_size, self._num_coords, 3)
+
+        Output:
+        - voxel_grid: (self._batch_size, voxel_size, voxel_size,
+                       voxel_size, voxel_feat)
+        '''
+
        voxel_indicy_denmominator = self._voxel_indicy_denmominator
        res, bb_mins = self._res, self._bb_mins
        if coord_bounds is not None:
diff --git a/arm/custom_rlbench_env.py b/arm/custom_rlbench_env.py
index a9bbfc0..d4da61f 100644
--- a/arm/custom_rlbench_env.py
+++ b/arm/custom_rlbench_env.py
@@ -51,6 +51,15 @@ def observation_elements(self) -> List[ObservationElement]:
        return obs_elems

    def extract_obs(self, obs: Observation, t=None, prev_action=None):
+        '''
+        convert a raw RLBench observation into the observation dict used by the agent
+
+        Input:
+        - obs: observation of `rlbench.backend.observation.Observation` type
+        - t: the current timestep
+        - prev_action: the action executed in the previous step
+        '''
+
+
        obs.joint_velocities = None
        grip_mat = obs.gripper_matrix
        grip_pose = obs.gripper_pose
diff --git a/arm/demo_loading_utils.py b/arm/demo_loading_utils.py
index 2643d39..0ef6142 100644
--- a/arm/demo_loading_utils.py
+++ b/arm/demo_loading_utils.py
@@ -19,6 +19,18 @@ def _is_stopped(demo, i, obs, stopped_buffer, delta=0.1):

def keypoint_discovery(demo: Demo, stopping_delta=0.1) -> List[int]:
+    '''
+    find the indices of the keypoint observations in a demo trajectory
+
+    Input:
+    - demo: a trajectory that contains a sequence of observations
+    - stopping_delta: velocity threshold used to decide that the arm has stopped
+
+    Output:
+    - episode_keypoints (list): indices of the keypoints
+    '''
+
+
    episode_keypoints = []
    prev_gripper_open = demo[0].gripper_open
    stopped_buffer = 0
diff --git a/arm/preprocess_agent.py b/arm/preprocess_agent.py
index 7e9afab..447f0f3 100644
--- a/arm/preprocess_agent.py
+++ b/arm/preprocess_agent.py
@@ -29,6 +29,15 @@ def update(self, step: int, replay_sample: dict) -> dict:

    def act(self, step: int, observation: dict,
            deterministic=False) -> ActResult:
+        '''
+        observation (dict):
+            'front_rgb': (1, 1, 3, 128, 128)
+            'front_point_cloud': (1, 1, 3, 128, 128)
+            'low_dim_state': (1, 1, 3)
+            'front_camera_extrinsics': (1, 1, 4, 4)
+            'front_camera_intrinsics': (1, 1, 3, 3)
+        '''
+
        # observation = {k: torch.tensor(v) for k, v in observation.items()}
        for k, v in observation.items():
            if 'rgb' in k:
diff --git a/arm/utils.py b/arm/utils.py
index 33edfdd..626b298 100644
--- a/arm/utils.py
+++ b/arm/utils.py
@@ -20,6 +20,16 @@ def loss_weights(replay_sample, beta=1.0):

def soft_updates(net, target_net, tau):
+    '''
+    apply soft update from `net` to `target_net` with weight `tau`.
\\ + `target_net = target_net * (1-tau) + net * tau` + + Input: + - net: source network + - target_net: updated network + - tau: the update weight + ''' + for param, target_param in zip(net.parameters(), target_net.parameters()): target_param.data.copy_( tau * param.data + (1 - tau) * target_param.data @@ -44,6 +54,17 @@ def quaternion_to_discrete_euler(quaternion, resolution): def discrete_euler_to_quaternion(discrete_euler, resolution): + ''' + transform the descrete eular to quaternion + + Input: + - discrete_euler: index for xyz euluer + - resolution: the minimum unit in degree + + Output: + - rotation: quaternion + ''' + euluer = (discrete_euler * resolution) - 180 return Rotation.from_euler('xyz', euluer, degrees=True).as_quat() diff --git a/yarr/__init__.py b/yarr/__init__.py new file mode 100755 index 0000000..11d27f8 --- /dev/null +++ b/yarr/__init__.py @@ -0,0 +1 @@ +__version__ = '0.1' diff --git a/yarr/agents/__init__.py b/yarr/agents/__init__.py new file mode 100755 index 0000000..e69de29 diff --git a/yarr/agents/agent.py b/yarr/agents/agent.py new file mode 100755 index 0000000..d394fe6 --- /dev/null +++ b/yarr/agents/agent.py @@ -0,0 +1,85 @@ +from abc import ABC, abstractmethod +from typing import Any, List + + +class Summary(object): + def __init__(self, name: str, value: Any): + self.name = name + self.value = value + + +class ScalarSummary(Summary): + pass + + +class HistogramSummary(Summary): + pass + + +class ImageSummary(Summary): + pass + + +class VideoSummary(Summary): + def __init__(self, name: str, value: Any, fps: int = 30): + super(VideoSummary, self).__init__(name, value) + self.fps = fps + + +class ActResult(object): + def __init__(self, + action: Any, + observation_elements: dict = None, + replay_elements: dict = None, + info: dict = None): + ''' + pack the action and related results + + Input: + - action: the action predicted by the agent + - observation_elements (dict): the processed observation used in + downstream framework + - replay_elements (dict) + - info (dict): remaining information of a transition + ''' + + self.action = action + self.observation_elements = observation_elements or {} + self.replay_elements = replay_elements or {} + self.info = info or {} + + +class Agent(ABC): + @abstractmethod + def build(self, training: bool, device=None) -> None: + pass + + @abstractmethod + def update(self, step: int, replay_sample: dict) -> dict: + pass + + @abstractmethod + def act(self, step: int, observation: dict, + deterministic: bool) -> ActResult: + # returns dict of values that get put in the replay. + # One of these must be 'action'. 
+ pass + + def reset(self) -> None: + pass + + @abstractmethod + def update_summaries(self) -> List[Summary]: + pass + + @abstractmethod + def act_summaries(self) -> List[Summary]: + pass + + @abstractmethod + def load_weights(self, savedir: str) -> None: + pass + + @abstractmethod + def save_weights(self, savedir: str) -> None: + pass diff --git a/yarr/envs/__init__.py b/yarr/envs/__init__.py new file mode 100755 index 0000000..e69de29 diff --git a/yarr/envs/env.py b/yarr/envs/env.py new file mode 100755 index 0000000..bc9fff3 --- /dev/null +++ b/yarr/envs/env.py @@ -0,0 +1,51 @@ +from abc import ABC, abstractmethod +from typing import Any, List + +import numpy as np + +from yarr.utils.observation_type import ObservationElement +from yarr.utils.transition import Transition + + +class Env(ABC): + + def __init__(self): + self._eval_env = False + + @property + def eval(self): + return self._eval_env + + @eval.setter + def eval(self, eval): + self._eval_env = eval + + @abstractmethod + def launch(self) -> None: + pass + + def shutdown(self) -> None: + pass + + @abstractmethod + def reset(self) -> dict: + pass + + @abstractmethod + def step(self, action: np.ndarray) -> Transition: + pass + + @property + @abstractmethod + def observation_elements(self) -> List[ObservationElement]: + pass + + @property + @abstractmethod + def action_shape(self) -> tuple: + pass + + @property + @abstractmethod + def env(self) -> Any: + pass diff --git a/yarr/envs/rlbench_env.py b/yarr/envs/rlbench_env.py new file mode 100755 index 0000000..ff6ba50 --- /dev/null +++ b/yarr/envs/rlbench_env.py @@ -0,0 +1,147 @@ +from typing import Type, List + +import numpy as np +try: + from rlbench import ObservationConfig, Environment, CameraConfig +except (ModuleNotFoundError, ImportError) as e: + print("You need to install RLBench: 'https://github.com/stepjam/RLBench'") + raise e +from rlbench.action_modes import ActionMode +from rlbench.backend.observation import Observation +from rlbench.backend.task import Task + +from yarr.envs.env import Env +from yarr.utils.observation_type import ObservationElement +from yarr.utils.transition import Transition + + +class RLBenchEnv(Env): + + ROBOT_STATE_KEYS = [ + 'joint_velocities', 'joint_positions', 'joint_forces', 'gripper_open', + 'gripper_pose', 'gripper_joint_positions', 'gripper_touch_forces', + 'task_low_dim_state', 'misc' + ] + + def __init__(self, + task_class: Type[Task], + observation_config: ObservationConfig, + action_mode: ActionMode, + dataset_root: str = '', + channels_last=False, + headless=True): + super(RLBenchEnv, self).__init__() + self._task_class = task_class + self._observation_config = observation_config + self._channels_last = channels_last + self._rlbench_env = Environment(action_mode=action_mode, + obs_config=observation_config, + dataset_root=dataset_root, + headless=headless) + self._task = None + + def extract_obs(self, obs: Observation): + obs_dict = vars(obs) + obs_dict = {k: v for k, v in obs_dict.items() if v is not None} + robot_state = obs.get_low_dim_data() + # Remove all of the individual state elements + obs_dict = { + k: v + for k, v in obs_dict.items() + if k not in RLBenchEnv.ROBOT_STATE_KEYS + } + if not self._channels_last: + # Swap channels from last dim to 1st dim + obs_dict = { + k: np.transpose(v, [2, 0, 1]) + if v.ndim == 3 else np.expand_dims(v, 0) + for k, v in obs_dict.items() + } + else: + # Add extra dim to depth data + obs_dict = { + k: v if v.ndim == 3 else np.expand_dims(v, -1) + for k, v in obs_dict.items() + } + 
obs_dict['low_dim_state'] = np.array(robot_state, dtype=np.float32) + return obs_dict + + def launch(self): + self._rlbench_env.launch() + self._task = self._rlbench_env.get_task(self._task_class) + + def shutdown(self): + self._rlbench_env.shutdown() + + def reset(self) -> dict: + descriptions, obs = self._task.reset() + return self.extract_obs(obs) + + def step(self, action: np.ndarray) -> Transition: + obs, reward, terminal = self._task.step(action) + obs = self.extract_obs(obs) + return Transition(obs, reward, terminal) + + def _get_cam_observation_elements(self, camera: CameraConfig, prefix: str): + elements = [] + if camera.rgb: + shape = (camera.image_size + (3, ) if self._channels_last else + (3, ) + camera.image_size) + elements.append( + ObservationElement('%s_rgb' % prefix, shape, np.uint8)) + if camera.depth: + shape = (camera.image_size + (1, ) if self._channels_last else + (1, ) + camera.image_size) + elements.append( + ObservationElement('%s_depth' % prefix, shape, np.float32)) + if camera.mask: + raise NotImplementedError() + return elements + + @property + def observation_elements(self) -> List[ObservationElement]: + elements = [] + robot_state_len = 0 + if self._observation_config.joint_velocities: + robot_state_len += 7 + if self._observation_config.joint_positions: + robot_state_len += 7 + if self._observation_config.joint_forces: + robot_state_len += 7 + if self._observation_config.gripper_open: + robot_state_len += 1 + if self._observation_config.gripper_pose: + robot_state_len += 7 + if self._observation_config.gripper_joint_positions: + robot_state_len += 2 + if self._observation_config.gripper_touch_forces: + robot_state_len += 2 + if self._observation_config.task_low_dim_state: + raise NotImplementedError() + if robot_state_len > 0: + elements.append( + ObservationElement('low_dim_state', (robot_state_len, ), + np.float32)) + elements.extend( + self._get_cam_observation_elements( + self._observation_config.left_shoulder_camera, + 'left_shoulder')) + elements.extend( + self._get_cam_observation_elements( + self._observation_config.right_shoulder_camera, + 'right_shoulder')) + elements.extend( + self._get_cam_observation_elements( + self._observation_config.front_camera, 'front')) + elements.extend( + self._get_cam_observation_elements( + self._observation_config.wrist_camera, 'wrist')) + return elements + + @property + def action_shape(self): + return (self._rlbench_env.action_size, ) + + @property + def env(self) -> Environment: + return self._rlbench_env diff --git a/yarr/replay_buffer/__init__.py b/yarr/replay_buffer/__init__.py new file mode 100755 index 0000000..e69de29 diff --git a/yarr/replay_buffer/prioritized_replay_buffer.py b/yarr/replay_buffer/prioritized_replay_buffer.py new file mode 100755 index 0000000..4e8eb3c --- /dev/null +++ b/yarr/replay_buffer/prioritized_replay_buffer.py @@ -0,0 +1,237 @@ +''' +An implementation of Prioritized Experience Replay (PER). + +This implementation is based on the paper "Prioritized Experience Replay" +by Tom Schaul et al. (2015). +''' +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from .uniform_replay_buffer import * +from .sum_tree import * +import numpy as np + +PRIORITY = 'priority' + + +class PrioritizedReplayBuffer(UniformReplayBuffer): + ''' + An out-of-graph Replay Buffer for Prioritized Experience Replay. + + See uniform_replay_buffer.py for details. 
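+
+    Illustrative usage sketch (added for documentation, not taken from the
+    original code; the TD-error variable below is hypothetical). Once the
+    buffer holds enough transitions, sampling returns the priorities and
+    indices needed to update them:
+
+        buffer = PrioritizedReplayBuffer(batch_size=32, replay_capacity=10000)
+        # ... fill via add() / add_final() ...
+        batch = buffer.sample_transition_batch()
+        # batch['sampling_probabilities'] holds the sampled priorities and
+        # batch['indices'] the sampled positions. After computing TD errors:
+        # buffer.set_priority(batch['indices'].astype(np.int32),
+        #                     np.abs(td_errors))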
+ ''' + def __init__(self, *args, **kwargs): + ''' + Initializes OutOfGraphPrioritizedReplayBuffer. + ''' + + super(PrioritizedReplayBuffer, self).__init__(*args, **kwargs) + self._sum_tree = SumTree(self._replay_capacity) + + def get_storage_signature( + self) -> Tuple[List[ReplayElement], List[ReplayElement]]: + ''' + Returns a default list of elements to be stored in this replay memory. + + Note - Derived classes may return a different signature. + + Returns: + - dict of ReplayElements defining the type of the contents stored. + ''' + + storage_elements, obs_elements = super(PrioritizedReplayBuffer, + self).get_storage_signature() + storage_elements.append(ReplayElement(PRIORITY, (), np.float32), ) + + return storage_elements, obs_elements + + def add(self, action, reward, terminal, timeout, priority=None, **kwargs): + kwargs['priority'] = priority + super(PrioritizedReplayBuffer, self).add(action, reward, terminal, + timeout, **kwargs) + + def _add(self, kwargs: dict): + ''' + Internal add method to add to the storage arrays. + + Args: + - kwargs: All the elements in a transition. + ''' + + with self._lock: + cursor = self.cursor() + priority = kwargs[PRIORITY] + if priority is None: + priority = self._sum_tree.max_recorded_priority + + if self._disk_saving: + self._store[TERMINAL][cursor] = kwargs[TERMINAL] + with open(join(self._save_dir, '%d.replay' % cursor), + 'wb') as f: + pickle.dump(kwargs, f) + # If first add, then pad for correct wrapping + if self.add_count == 0: + self._add_initial_to_disk(kwargs) + else: + for name, data in kwargs.items(): + self._store[name][cursor] = data + + self._sum_tree.set(self.cursor(), priority) + self.add_count += 1 + self.invalid_range = invalid_range(self.cursor(), + self._replay_capacity, + self._timesteps, + self._update_horizon) + + def add_final(self, **kwargs): + ''' + Adds a transition to the replay memory. + + Args: + - **kwargs: The remaining args + ''' + + if self.is_empty() or self._store['terminal'][self.cursor() - 1] != 1: + raise ValueError('The previous transition was not terminal.') + self._check_add_types(kwargs, self._obs_signature) + transition = self._final_transition(kwargs) + for element_type in self._storage_signature: + # 0 priority for final observation. + if element_type.name == PRIORITY: + transition[element_type.name] = 0.0 + self._add(transition) + + def sample_index_batch(self, batch_size): + ''' + Returns a batch of valid indices sampled as in Schaul et al. (2015). + + Args: + - batch_size: int, number of indices returned. + + Returns: + - list of ints, a batch of valid indices sampled uniformly. + + Raises: + - Exception: If the batch was not constructed after maximum number of tries. + ''' + + # Sample stratified indices. Some of them might be invalid. + indices = self._sum_tree.stratified_sample(batch_size) + allowed_attempts = self._max_sample_attempts + for i in range(len(indices)): + if not self.is_valid_transition(indices[i]): + if allowed_attempts == 0: + raise RuntimeError( + 'Max sample attempts: Tried {} times but only sampled {}' + ' valid indices. Batch size is {}'.format( + self._max_sample_attempts, i, batch_size)) + index = indices[i] + while not self.is_valid_transition( + index) and allowed_attempts > 0: + # If index i is not valid keep sampling others. Note that this + # is not stratified. 
+ index = self._sum_tree.sample() + allowed_attempts -= 1 + indices[i] = index + return indices + + def sample_transition_batch(self, + batch_size=None, + indices=None, + pack_in_dict=True): + ''' + Returns a batch of transitions with extra storage and the priorities. + + The extra storage are defined through the extra_storage_types constructor + argument. + + When the transition is terminal next_state_batch has undefined contents. + + Args: + - batch_size: int, number of transitions returned. If None, the default + batch_size will be used. + - indices: None or list of ints, the indices of every transition in the + batch. If None, sample the indices uniformly. + + Returns: + - transition_batch: tuple of np.arrays with the shape and type as in + get_transition_elements(). + ''' + + transition = super(PrioritizedReplayBuffer, + self).sample_transition_batch(batch_size, + indices, + pack_in_dict=False) + + transition_elements = self.get_transition_elements(batch_size) + transition_names = [e.name for e in transition_elements] + probabilities_index = transition_names.index('sampling_probabilities') + indices_index = transition_names.index('indices') + indices = transition[indices_index] + + # print('transition_names', transition_names) + + # The parent returned an empty array for the probabilities. Fill it with the + # contents of the sum tree. + transition[probabilities_index][:] = self.get_priority(indices) + batch_arrays = transition + if pack_in_dict: + batch_arrays = self.unpack_transition(transition, + transition_elements) + return batch_arrays + + def set_priority(self, indices, priorities): + ''' + Sets the priority of the given elements according to Schaul et al. + + Args: + - indices: np.array with dtype int32, of indices in range + [0, replay_capacity). + - priorities: float, the corresponding priorities. + ''' + + assert indices.dtype == np.int32, ('Indices must be integers, ' + 'given: {}'.format(indices.dtype)) + for index, priority in zip(indices, priorities): + self._sum_tree.set(index, priority) + + def get_priority(self, indices): + ''' + Fetches the priorities correspond to a batch of memory indices. + + For any memory location not yet used, the corresponding priority is 0. + + Args: + - indices: np.array with dtype int32, of indices in range + [0, replay_capacity). + + Returns: + - priorities: float, the corresponding priorities. + ''' + + assert indices.shape, 'Indices must be an array.' + assert indices.dtype == np.int32, ('Indices must be int32s, ' + 'given: {}'.format(indices.dtype)) + batch_size = len(indices) + priority_batch = np.empty((batch_size), dtype=np.float32) + for i, memory_index in enumerate(indices): + priority_batch[i] = self._sum_tree.get(memory_index) + return priority_batch + + def get_transition_elements(self, batch_size=None): + '''Returns a 'type signature' for sample_transition_batch. + + Args: + - batch_size: int, number of transitions returned. If None, the default + batch_size will be used. + Returns: + - signature: A namedtuple describing the method's return type signature. 
+ ''' + + parent_transition_type = (super( + PrioritizedReplayBuffer, self).get_transition_elements(batch_size)) + probablilities_type = [ + ReplayElement('sampling_probabilities', (batch_size, ), np.float32) + ] + return parent_transition_type + probablilities_type diff --git a/yarr/replay_buffer/replay_buffer.py b/yarr/replay_buffer/replay_buffer.py new file mode 100755 index 0000000..ddc4dee --- /dev/null +++ b/yarr/replay_buffer/replay_buffer.py @@ -0,0 +1,70 @@ +from abc import ABC +from typing import Tuple, List + + +class ReplayElement(object): + def __init__(self, name, shape, type, is_observation=False): + self.name = name + self.shape = shape + self.type = type + self.is_observation = is_observation + + +class ReplayBuffer(ABC): + def replay_capacity(self): + pass + + def batch_size(self): + pass + + def get_storage_signature( + self) -> Tuple[List[ReplayElement], List[ReplayElement]]: + pass + + def add(self, action, reward, terminal, timeout, **kwargs): + pass + + def add_final(self, **kwargs): + pass + + def is_empty(self): + pass + + def is_full(self): + pass + + def cursor(self): + pass + + def get_range(self, array, start_index, end_index): + pass + + def get_range_stack(self, array, start_index, end_index, terminals=None): + pass + + def get_terminal_stack(self, index): + pass + + def is_valid_transition(self, index): + pass + + def sample_index_batch(self, batch_size): + pass + + def unpack_transition(self, transition_tensors, transition_type): + pass + + def sample_transition_batch(self, + batch_size=None, + indices=None, + pack_in_dict=True): + pass + + def get_transition_elements(self, batch_size=None): + pass + + def shutdown(self): + pass + + def using_disk(self): + pass diff --git a/yarr/replay_buffer/sum_tree.py b/yarr/replay_buffer/sum_tree.py new file mode 100755 index 0000000..df3e944 --- /dev/null +++ b/yarr/replay_buffer/sum_tree.py @@ -0,0 +1,193 @@ +"""A sum tree data structure. + +Used for prioritized experience replay. See prioritized_replay_buffer.py +and Schaul et al. (2015). +""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import math +import random + +import numpy as np + + +class SumTree(object): + """A sum tree data structure for storing replay priorities. + + A sum tree is a complete binary tree whose leaves contain values called + priorities. Internal nodes maintain the sum of the priorities of all leaf + nodes in their subtree. + + For capacity = 4, the tree may look like this: + + +---+ + |2.5| + +-+-+ + | + +-------+--------+ + | | + +-+-+ +-+-+ + |1.5| |1.0| + +-+-+ +-+-+ + | | + +----+----+ +----+----+ + | | | | + +-+-+ +-+-+ +-+-+ +-+-+ + |0.5| |1.0| |0.5| |0.5| + +---+ +---+ +---+ +---+ + + This is stored in a list of numpy arrays: + self.nodes = [ [2.5], [1.5, 1], [0.5, 1, 0.5, 0.5] ] + + For conciseness, we allocate arrays as powers of two, and pad the excess + elements with zero values. + + This is similar to the usual array-based representation of a complete binary + tree, but is a little more user-friendly. + """ + + def __init__(self, capacity): + """Creates the sum tree data structure for the given replay capacity. + + Args: + capacity: int, the maximum number of elements that can be stored in this + data structure. + + Raises: + ValueError: If requested capacity is not positive. + """ + assert isinstance(capacity, int) + if capacity <= 0: + raise ValueError('Sum tree capacity should be positive. Got: {}'. 
+ format(capacity)) + + self.nodes = [] + tree_depth = int(math.ceil(np.log2(capacity))) + level_size = 1 + for _ in range(tree_depth + 1): + nodes_at_this_depth = np.zeros(level_size) + self.nodes.append(nodes_at_this_depth) + + level_size *= 2 + + self.max_recorded_priority = 1.0 + + def _total_priority(self): + """Returns the sum of all priorities stored in this sum tree. + + Returns: + float, sum of priorities stored in this sum tree. + """ + return self.nodes[0][0] + + def sample(self, query_value=None): + """Samples an element from the sum tree. + + Each element has probability p_i / sum_j p_j of being picked, where p_i is + the (positive) value associated with node i (possibly unnormalized). + + Args: + query_value: float in [0, 1], used as the random value to select a + sample. If None, will select one randomly in [0, 1). + + Returns: + int, a random element from the sum tree. + + Raises: + Exception: If the sum tree is empty (i.e. its node values sum to 0), or if + the supplied query_value is larger than the total sum. + """ + if self._total_priority() == 0.0: + raise Exception('Cannot sample from an empty sum tree.') + + if query_value and (query_value < 0. or query_value > 1.): + raise ValueError('query_value must be in [0, 1].') + + # Sample a value in range [0, R), where R is the value stored at the root. + query_value = random.random() if query_value is None else query_value + query_value *= self._total_priority() + + # Now traverse the sum tree. + node_index = 0 + for nodes_at_this_depth in self.nodes[1:]: + # Compute children of previous depth's node. + left_child = node_index * 2 + + left_sum = nodes_at_this_depth[left_child] + # Each subtree describes a range [0, a), where a is its value. + if query_value < left_sum: # Recurse into left subtree. + node_index = left_child + else: # Recurse into right subtree. + node_index = left_child + 1 + # Adjust query to be relative to right subtree. + query_value -= left_sum + + return node_index + + def stratified_sample(self, batch_size): + """Performs stratified sampling using the sum tree. + + Let R be the value at the root (total value of sum tree). This method will + divide [0, R) into batch_size segments, pick a random number from each of + those segments, and use that random number to sample from the sum_tree. This + is as specified in Schaul et al. (2015). + + Args: + batch_size: int, the number of strata to use. + Returns: + list of batch_size elements sampled from the sum tree. + + Raises: + Exception: If the sum tree is empty (i.e. its node values sum to 0). + """ + if self._total_priority() == 0.0: + raise Exception('Cannot sample from an empty sum tree.') + + bounds = np.linspace(0., 1., batch_size + 1) + assert len(bounds) == batch_size + 1 + segments = [(bounds[i], bounds[i + 1]) for i in range(batch_size)] + # TODO removed for now + # query_values = [random.uniform(x[0], x[1]) for x in segments] + query_values = [random.uniform(0, 1) for x in segments] + return [self.sample(query_value=x) for x in query_values] + + def get(self, node_index): + """Returns the value of the leaf node corresponding to the index. + + Args: + node_index: The index of the leaf node. + Returns: + The value of the leaf node. + """ + return self.nodes[-1][node_index] + + def set(self, node_index, value): + """Sets the value of a leaf node and updates internal nodes accordingly. + + This operation takes O(log(capacity)). + Args: + node_index: int, the index of the leaf node to be updated. + value: float, the value which we assign to the node. 
This value must be + nonnegative. Setting value = 0 will cause the element to never be + sampled. + + Raises: + ValueError: If the given value is negative. + """ + if value < 0.0: + raise ValueError('Sum tree values should be nonnegative. Got {}'. + format(value)) + self.max_recorded_priority = max(value, self.max_recorded_priority) + + delta_value = value - self.nodes[-1][node_index] + + # Now traverse back the tree, adjusting all sums along the way. + for nodes_at_this_depth in reversed(self.nodes): + # Note: Adding a delta leads to some tolerable numerical inaccuracies. + nodes_at_this_depth[node_index] += delta_value + node_index //= 2 + + assert node_index == 0, ('Sum tree traversal failed, final node index ' + 'is not 0.') diff --git a/yarr/replay_buffer/uniform_replay_buffer.py b/yarr/replay_buffer/uniform_replay_buffer.py new file mode 100755 index 0000000..46ae6e4 --- /dev/null +++ b/yarr/replay_buffer/uniform_replay_buffer.py @@ -0,0 +1,826 @@ +'''The standard DQN replay memory. + +This implementation is an out-of-graph replay memory + in-graph wrapper. It +supports vanilla n-step updates of the form typically found in the literature, +i.e. where rewards are accumulated for n steps and the intermediate trajectory +is not exposed to the agent. This does not allow, for example, performing +off-policy corrections. +''' +import collections +import concurrent.futures +import os +from os.path import join +import pickle +from typing import List, Tuple, Type +import time +import math +# from threading import Lock +from multiprocessing import Lock +import numpy as np +import logging + +from natsort import natsort + +from yarr.replay_buffer.replay_buffer import ReplayBuffer, ReplayElement +from yarr.utils.observation_type import ObservationElement + +# Defines a type describing part of the tuple returned by the replay +# memory. Each element of the tuple is a tensor of shape [batch, ...] where +# ... is defined the 'shape' field of ReplayElement. The tensor type is +# given by the 'type' field. The 'name' field is for convenience and ease of +# debugging. + +# String constants for storage +ACTION = 'action' +REWARD = 'reward' +TERMINAL = 'terminal' +TIMEOUT = 'timeout' +INDICES = 'indices' + + +def invalid_range(cursor, replay_capacity, stack_size, update_horizon): + ''' + Returns a array with the indices of cursor-related invalid transitions. + + There are update_horizon + stack_size invalid indices: + - The update_horizon indices before the cursor, because we do not have a + valid N-step transition (including the next state). + - The stack_size indices on or immediately after the cursor. + If N = update_horizon, K = stack_size, and the cursor is at c, invalid + indices are: + c - N, c - N + 1, ..., c, c + 1, ..., c + K - 1. + + It handles special cases in a circular buffer in the beginning and the end. + + Args: + - cursor: int, the position of the cursor. + - replay_capacity: int, the size of the replay memory. + - stack_size: int, the size of the stacks returned by the replay memory. + - update_horizon: int, the agent's update horizon. + + Returns: + - np.array of size stack_size with the invalid indices. + ''' + assert cursor < replay_capacity + return np.array([(cursor - update_horizon + i) % replay_capacity + for i in range(stack_size + update_horizon)]) + + +class UniformReplayBuffer(ReplayBuffer): + ''' + A simple out-of-graph Replay Buffer. 
+ + Stores transitions, state, action, reward, next_state, terminal (and any + extra contents specified) in a circular buffer and provides a uniform + transition sampling function. + + When the states consist of stacks of observations storing the states is + inefficient. This class writes observations and constructs the stacked states + at sample time. + + Attributes: + - add_count: int, counter of how many transitions have been added + (including the blank ones at the beginning of an episode). + - invalid_range: np.array, an array with the indices of cursor-related + invalid transitions + ''' + def __init__(self, + batch_size: int = 32, + timesteps: int = 1, + replay_capacity: int = int(1e6), + update_horizon: int = 1, + gamma: float = 0.99, + max_sample_attempts: int = 10000, + action_shape: tuple = (), + action_dtype: Type[np.dtype] = np.float32, + reward_shape: tuple = (), + reward_dtype: Type[np.dtype] = np.float32, + observation_elements: List[ObservationElement] = None, + extra_replay_elements: List[ReplayElement] = None, + save_dir: str = None, + purge_replay_on_shutdown: bool = True): + ''' + Initializes OutOfGraphReplayBuffer. + + Args: + - batch_size: int. + - timesteps: int, number of frames to use in state stack. + - replay_capacity: int, number of transitions to keep in memory. + - update_horizon: int, length of update ('n' in n-step update). + - gamma: int, the discount factor. + - max_sample_attempts: int, the maximum number of attempts allowed to + get a sample. + - action_shape: tuple of ints, the shape for the action vector. + Empty tuple means the action is a scalar. + - action_dtype: np.dtype, type of elements in the action. + - reward_shape: tuple of ints, the shape of the reward vector. + Empty tuple means the reward is a scalar. + - reward_dtype: np.dtype, type of elements in the reward. + - observation_elements: list of ObservationElement defining the type of + the extra contents that will be stored and returned. + - extra_storage_elements: list of ReplayElement defining the type of + the extra contents that will be stored and returned. + + Raises: + - ValueError: If replay_capacity is too small to hold at least one + transition. 
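+
+        Example (an illustrative sketch only; the action shape and image
+        shape below are assumptions, not required values):
+
+            import numpy as np
+            from yarr.utils.observation_type import ObservationElement
+
+            obs_elements = [
+                ObservationElement('front_rgb', (3, 128, 128), np.uint8)]
+            replay = UniformReplayBuffer(
+                batch_size=32, timesteps=1, replay_capacity=int(1e5),
+                action_shape=(8,), observation_elements=obs_elements)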
+ ''' + + if observation_elements is None: + observation_elements = [] + if extra_replay_elements is None: + extra_replay_elements = [] + + if replay_capacity < update_horizon + timesteps: + raise ValueError('There is not enough capacity to cover ' + 'update_horizon and stack_size.') + + logging.info( + 'Creating a %s replay memory with the following parameters:', + self.__class__.__name__) + logging.info('\t timesteps: %d', timesteps) + logging.info('\t replay_capacity: %d', replay_capacity) + logging.info('\t batch_size: %d', batch_size) + logging.info('\t update_horizon: %d', update_horizon) + logging.info('\t gamma: %f', gamma) + + self._disk_saving = save_dir is not None + self._save_dir = save_dir + self._purge_replay_on_shutdown = purge_replay_on_shutdown + if self._disk_saving: + logging.info('\t saving to disk: %s', self._save_dir) + os.makedirs(save_dir, exist_ok=True) + else: + logging.info('\t saving to RAM') + + self._action_shape, self._action_dtype = action_shape, action_dtype + self._reward_shape, self._reward_dtype = reward_shape, reward_dtype + self._timesteps = timesteps + self._replay_capacity = replay_capacity + self._batch_size = batch_size + self._update_horizon = update_horizon + self._gamma = gamma + self._max_sample_attempts = max_sample_attempts + + self._observation_elements = observation_elements + self._extra_replay_elements = extra_replay_elements + + self._storage_signature, self._obs_signature = self.get_storage_signature( + ) + self._create_storage() + + self._lock = Lock() + self.add_count = np.array(0) + + self._replay_capacity = replay_capacity + + self.invalid_range = np.zeros((self._timesteps)) + + # When the horizon is > 1, we compute the sum of discounted rewards as a dot + # product using the precomputed vector . + self._cumulative_discount_vector = np.array( + [math.pow(self._gamma, n) for n in range(update_horizon)], + dtype=np.float32) + return + + @property + def timesteps(self): + return self._timesteps + + @property + def replay_capacity(self): + return self._replay_capacity + + @property + def batch_size(self): + return self._batch_size + + def _create_storage(self): + ''' + Creates the numpy arrays used to store transitions. + ''' + + self._store = {} + for storage_element in self._storage_signature: + array_shape = [self._replay_capacity] + list(storage_element.shape) + if storage_element.name == TERMINAL: + self._store[storage_element.name] = np.full( + array_shape, -1, dtype=storage_element.type) + elif not self._disk_saving: + # If saving to disk, we don't need to store anything else. + self._store[storage_element.name] = np.empty( + array_shape, dtype=storage_element.type) + + def get_storage_signature( + self) -> Tuple[List[ReplayElement], List[ReplayElement]]: + ''' + Returns a default list of elements to be stored in this replay memory. + + Note - Derived classes may return a different signature. + + Returns: + - dict of ReplayElements defining the type of the contents stored. 
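+
+        With the default constructor arguments the returned storage signature
+        is, in order (see the implementation below):
+
+            [ReplayElement('action', (), np.float32),
+             ReplayElement('reward', (), np.float32),
+             ReplayElement('terminal', (), np.int8),
+             ReplayElement('timeout', (), np.bool),
+             *observation_elements, *extra_replay_elements]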
+ ''' + + storage_elements = [ + ReplayElement(ACTION, self._action_shape, self._action_dtype), + ReplayElement(REWARD, self._reward_shape, self._reward_dtype), + ReplayElement(TERMINAL, (), np.int8), + ReplayElement(TIMEOUT, (), np.bool), + ] + + obs_elements = [] + for obs_element in self._observation_elements: + obs_elements.append( + ReplayElement(obs_element.name, obs_element.shape, + obs_element.type)) + storage_elements.extend(obs_elements) + + for extra_replay_element in self._extra_replay_elements: + storage_elements.append(extra_replay_element) + + return storage_elements, obs_elements + + def add(self, action, reward, terminal, timeout, **kwargs): + ''' + Adds a transition to the replay memory. + + WE ONLY STORE THE TPS1s on the final frame + + This function checks the types and handles the padding at the beginning of + an episode. Then it calls the _add function. + + Since the next_observation in the transition will be the observation added + next there is no need to pass it. + + If the replay memory is at capacity the oldest transition will be discarded. + + Args: + - action: int, the action in the transition. + - reward: float, the reward received in the transition. + - terminal: A uint8 acting as a boolean indicating whether + the transition was terminal (1) or not (0). + - **kwargs: The remaining args + ''' + + # If previous transition was a terminal, then add_final wasn't called + if not self.is_empty() and self._store['terminal'][self.cursor() - + 1] == 1: + raise ValueError('The previous transition was a terminal, ' + 'but add_final was not called.') + + kwargs[ACTION] = action + kwargs[REWARD] = reward + kwargs[TERMINAL] = terminal + kwargs[TIMEOUT] = timeout + self._check_add_types(kwargs, self._storage_signature) + self._add(kwargs) + + def add_final(self, **kwargs): + ''' + Adds a transition to the replay memory. + + Args: + - **kwargs: The remaining args + ''' + if self.is_empty() or self._store['terminal'][self.cursor() - 1] != 1: + raise ValueError('The previous transition was not terminal.') + self._check_add_types(kwargs, self._obs_signature) + transition = self._final_transition(kwargs) + self._add(transition) + + def _final_transition(self, kwargs): + transition = {} + for element_type in self._storage_signature: + if element_type.name in kwargs: + transition[element_type.name] = kwargs[element_type.name] + elif element_type.name == TERMINAL: + # Used to check that user is correctly adding transitions + transition[element_type.name] = -1 + else: + transition[element_type.name] = np.empty( + element_type.shape, dtype=element_type.type) + return transition + + def _add_initial_to_disk(self, kwargs: dict): + for i in range(self._timesteps - 1): + with open( + join(self._save_dir, + '%d.replay' % (self._replay_capacity - 1 - i)), + 'wb') as f: + pickle.dump(kwargs, f) + + def _add(self, kwargs: dict): + ''' + Internal add method to add to the storage arrays. + + Args: + - kwargs: All the elements in a transition. 
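+
+            For illustration, the transition dict is expected to carry every
+            element of the storage signature, e.g. (values are placeholders):
+
+                {'action': 0.0, 'reward': 1.0, 'terminal': 0, 'timeout': False,
+                 **observation_element_values}
+
+            When disk saving is enabled the dict is pickled to
+            '<save_dir>/<cursor>.replay'; otherwise it is written into the
+            in-memory numpy arrays.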
+ ''' + + with self._lock: + cursor = self.cursor() + + if self._disk_saving: + self._store[TERMINAL][cursor] = kwargs[TERMINAL] + with open(join(self._save_dir, '%d.replay' % cursor), + 'wb') as f: + pickle.dump(kwargs, f) + # If first add, then pad for correct wrapping + if self.add_count == 0: + self._add_initial_to_disk(kwargs) + else: + for name, data in kwargs.items(): + self._store[name][cursor] = data + + self.add_count += 1 + self.invalid_range = invalid_range(self.cursor(), + self._replay_capacity, + self._timesteps, + self._update_horizon) + + def _get_from_disk(self, start_index, end_index): + ''' + Returns the range of array at the index handling wraparound if necessary. + + Args: + - start_index: int, index to the start of the range to be returned. Range + will wraparound if start_index is smaller than 0. + - end_index: int, exclusive end index. Range will wraparound if end_index + exceeds replay_capacity. + + Returns: + - np.array, with shape [end_index - start_index, array.shape[1:]]. + ''' + assert end_index > start_index, 'end_index must be larger than start_index' + assert end_index >= 0 + assert start_index < self._replay_capacity + if not self.is_full(): + assert end_index <= self.cursor(), ( + 'Index {} has not been added.'.format(start_index)) + + # Here we fake a mini store (buffer) + store = { + store_element.name: {} + for store_element in self._storage_signature + } + if start_index % self._replay_capacity < end_index % self._replay_capacity: + for i in range(start_index, end_index): + with open(join(self._save_dir, '%d.replay' % i), 'rb') as f: + d = pickle.load(f) + for k, v in d.items(): + store[k][i] = v + else: + for i in range(end_index - start_index): + idx = (start_index + i) % self._replay_capacity + with open(join(self._save_dir, '%d.replay' % idx), 'rb') as f: + d = pickle.load(f) + for k, v in d.items(): + store[k][idx] = v + return store + + def _check_add_types(self, kwargs, signature): + ''' + Checks if args passed to the add method match those of the storage. + + Args: + - *args: Args whose types need to be validated. + + Raises: + - ValueError: If args have wrong shape or dtype. + ''' + + if (len(kwargs)) != len(signature): + expected = str(natsort.natsorted([e.name for e in signature])) + actual = str(natsort.natsorted(list(kwargs.keys()))) + error_list = '\nList of expected:\n{}\nList of actual:\n{}'.format( + expected, actual) + raise ValueError('Add expects {} elements, received {}.'.format( + len(signature), len(kwargs)) + error_list) + + for store_element in signature: + arg_element = kwargs[store_element.name] + if isinstance(arg_element, np.ndarray): + arg_shape = arg_element.shape + elif isinstance(arg_element, tuple) or isinstance( + arg_element, list): + # TODO: This is not efficient when arg_element is a list. + arg_shape = np.array(arg_element).shape + else: + # Assume it is scalar. + arg_shape = tuple() + store_element_shape = tuple(store_element.shape) + if arg_shape != store_element_shape: + raise ValueError('arg has shape {}, expected {}'.format( + arg_shape, store_element_shape)) + + def is_empty(self): + ''' + Is the Replay Buffer empty? + ''' + return self.add_count == 0 + + def is_full(self): + ''' + Is the Replay Buffer full? + ''' + return self.add_count >= self._replay_capacity + + def cursor(self): + ''' + Index to the location where the next transition will be written. 
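+
+        For example, with replay_capacity=1000 and add_count=1003, the next
+        transition is written at index 1003 % 1000 == 3.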
+ ''' + return self.add_count % self._replay_capacity + + def get_range(self, array, start_index, end_index): + ''' + Returns the range of array at the index handling wraparound if + necessary. + + Input: + - array (np.array): the array to get the stack from. + - start_index (int): index to the start of the range to be + returned. Range will wraparound if start_index + is smaller than 0. + - end_index (int): exclusive end index. Range will wraparound if + end_index exceeds replay_capacity. + + Output: + - return_array (np.array): with shape [end_index - start_index, + array.shape[1:]]. + ''' + + # check whether the index is valid -------------------------- + assert end_index > start_index, 'end_index must be larger than start_index' + assert end_index >= 0 + assert start_index < self._replay_capacity + + if not self.is_full(): + assert end_index <= self.cursor(), ( + 'Index {} has not been added.'.format(start_index)) + # ----------------------------------------------------------- + + # Fast slice read when there is no wraparound. + if start_index % self._replay_capacity < end_index % self._replay_capacity: + return_array = np.array( + [array[i] for i in range(start_index, end_index)]) + # Slow list read. + else: + indices = [(start_index + i) % self._replay_capacity + for i in range(end_index - start_index)] + return_array = np.array([array[i] for i in indices]) + + return return_array + + def get_range_stack(self, array, start_index, end_index, terminals=None): + ''' + Returns the range of array at the index handling wraparound if + necessary. + + Input: + - array (np.array): the array to get the stack from. + - start_index (int): index to the start of the range to be + returned. Range will wraparound if start_index + is smaller than 0. + - end_index (int): exclusive end index. Range will wraparound if + end_index exceeds replay_capacity. + + Output: + - return_array (np.array): with shape [end_index - start_index, + array.shape[1:]]. + ''' + + return_array = np.array(self.get_range(array, start_index, end_index)) + if terminals is None: + terminals = self.get_range(self._store[TERMINAL], start_index, + end_index) + + terminals = terminals[:-1] + + # Here we now check if we need to pad the front episodes + # If any have a terminal of -1, then we have spilled over + # into the the previous transition + if np.any(terminals == -1): + padding_item = return_array[-1] + _array = list(return_array)[:-1] + arr_len = len(_array) + pad_from_now = False + for i, (ar, term) in enumerate( + zip(reversed(_array), reversed(terminals))): + if term == -1 or pad_from_now: + # The first time we see a -1 term, means we have hit the + # beginning of this episode, so pad from now. + # pad_from_now needed because the next transition (reverse) + # will not be a -1 terminal. + pad_from_now = True + return_array[arr_len - 1 - i] = padding_item + else: + # After we hit out first -1 terminal, we never reassign. + padding_item = ar + + return return_array + + def _get_element_stack(self, array, index, terminals=None): + state = self.get_range_stack(array, + index - self._timesteps + 1, + index + 1, + terminals=terminals) + return state + + def get_terminal_stack(self, index): + return self.get_range(self._store[TERMINAL], + index - self._timesteps + 1, index + 1) + + def is_valid_transition(self, index): + ''' + Checks if the index contains a valid transition. + + Checks for collisions with the end of episodes and the current position + of the cursor. 
+ + Args: + - index: int, the index to the state in the transition. + + Returns: + - Is the index valid: Boolean. + ''' + + # Check the index is in the valid range + if index < 0 or index >= self._replay_capacity: + return False + if not self.is_full(): + # The indices and next_indices must be smaller than the cursor. + if index >= self.cursor() - self._update_horizon: + return False + + # Skip transitions that straddle the cursor. + if index in set(self.invalid_range): + return False + + term_stack = self.get_terminal_stack(index) + if term_stack[-1] == -1: + return False + + return True + + def _create_batch_arrays(self, batch_size): + ''' + Create a tuple of arrays with the type of get_transition_elements. + + When using the WrappedReplayBuffer with staging enabled it is important + to create new arrays every sample because StaginArea keeps a pointer to + the returned arrays. + + Args: + - batch_size: (int) number of transitions returned. If None the default + batch_size will be used. + + Returns: + - Tuple of np.arrays with the shape and type of get_transition_elements. + ''' + + transition_elements = self.get_transition_elements(batch_size) + batch_arrays = [] + for element in transition_elements: + batch_arrays.append(np.empty(element.shape, dtype=element.type)) + return tuple(batch_arrays) + + def sample_index_batch(self, batch_size): + ''' + Returns a batch of valid indices sampled uniformly. + + Args: + - batch_size: int, number of indices returned. + + Returns: + - list of ints, a batch of valid indices sampled uniformly. + + Raises: + - RuntimeError: If the batch was not constructed after maximum number of + tries. + ''' + + if self.is_full(): + # add_count >= self._replay_capacity > self._stack_size + min_id = (self.cursor() - self._replay_capacity + self._timesteps - + 1) + max_id = self.cursor() - self._update_horizon + else: + min_id = 0 + max_id = self.cursor() - self._update_horizon + if max_id <= min_id: + raise RuntimeError( + 'Cannot sample a batch with fewer than stack size ' + '({}) + update_horizon ({}) transitions.'.format( + self._timesteps, self._update_horizon)) + + indices = [] + attempt_count = 0 + while (len(indices) < batch_size + and attempt_count < self._max_sample_attempts): + index = np.random.randint(min_id, max_id) % self._replay_capacity + if self.is_valid_transition(index): + indices.append(index) + else: + attempt_count += 1 + if len(indices) != batch_size: + raise RuntimeError( + 'Max sample attempts: Tried {} times but only sampled {}' + ' valid indices. Batch size is {}'.format( + self._max_sample_attempts, len(indices), batch_size)) + + return indices + + def unpack_transition(self, transition_tensors, transition_type): + ''' + Unpacks the given transition into member variables. To be more, + specific, we reform the transition into `collections.OrderedDict` + + Args: + - transition_tensors: tuple of tf.Tensors. + - transition_type: tuple of ReplayElements matching transition_tensors. + ''' + self.transition = collections.OrderedDict() + for element, element_type in zip(transition_tensors, transition_type): + self.transition[element_type.name] = element + return self.transition + + def sample_transition_batch(self, + batch_size=None, + indices=None, + pack_in_dict=True): + ''' + Returns a batch of transitions (including any extra contents). + + If get_transition_elements has been overridden and defines elements not + stored in self._store, an empty array will be returned and it will be + left to the child class to fill it. 
For example, for the child class + OutOfGraphPrioritizedReplayBuffer, the contents of the + sampling_probabilities are stored separately in a sum tree. + + When the transition is terminal next_state_batch has undefined contents. + + NOTE: This transition contains the indices of the sampled elements. + These are only valid during the call to sample_transition_batch, + i.e. they may be used by subclasses of this replay buffer but may + point to different data as soon as sampling is done. + + Args: + - batch_size: int, number of transitions returned. If None, the default + batch_size will be used. + - indices: None or list of ints, the indices of every transition in the + batch. If None, sample the indices uniformly. + + Returns: + - transition_batch: tuple of np.arrays with the shape and type as in + get_transition_elements(). + + Raises: + - ValueError: If an element to be sampled is missing from the + replay buffer. + ''' + + if batch_size is None: + batch_size = self._batch_size + with self._lock: + if indices is None: + indices = self.sample_index_batch(batch_size) + assert len(indices) == batch_size + + # get the name of the transition elements + # print('replay_buffer batch_size', batch_size) + transition_elements = self.get_transition_elements(batch_size) + + # create empty array + batch_arrays = self._create_batch_arrays(batch_size) + + for batch_element, state_index in enumerate(indices): + + if not self.is_valid_transition(state_index): + raise ValueError('Invalid index %d.' % state_index) + + # calculate the index of the next state according to + # self._update_horizon and termination of the trajectory + trajectory_indices = [(state_index + j) % self._replay_capacity + for j in range(self._update_horizon)] + trajectory_terminals = self._store['terminal'][ + trajectory_indices] + is_terminal_transition = trajectory_terminals.any() + if not is_terminal_transition: + trajectory_length = self._update_horizon + else: + # np.argmax of a bool array returns index of the first True. + trajectory_length = np.argmax( + trajectory_terminals.astype(np.bool), 0) + 1 + + next_state_index = state_index + trajectory_length + # --------------------------------------------------- + + store = self._store + if self._disk_saving: + store = self._get_from_disk( + state_index - (self._timesteps - 1), + next_state_index + 1) + + trajectory_discount_vector = ( + self._cumulative_discount_vector[:trajectory_length]) + trajectory_rewards = self.get_range(store['reward'], + state_index, + next_state_index) + + terminal_stack = self.get_terminal_stack(state_index) + terminal_stack_tp1 = self.get_terminal_stack( + next_state_index % self._replay_capacity) + + # Fill the contents of each array in the sampled batch. + assert len(transition_elements) == len(batch_arrays) + for element_array, element in zip(batch_arrays, + transition_elements): + if element.is_observation: + + # print('original:', element.name) + if element.name.endswith('tp1'): + element_array[ + batch_element] = self._get_element_stack( + store[element.name[:-4]], + next_state_index % self._replay_capacity, + terminal_stack_tp1) + # print(element.name[:-4]) + else: + element_array[ + batch_element] = self._get_element_stack( + store[element.name], state_index, + terminal_stack) + # print(element.name) + elif element.name == REWARD: + # compute discounted sum of rewards in the trajectory. 
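+                        # i.e. the n-step return
+                        # R = r_t + gamma * r_{t+1} + ... + gamma^(n-1) * r_{t+n-1},
+                        # where n = trajectory_length (shortened when a terminal
+                        # is reached inside the update horizon).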
+ element_array[batch_element] = np.sum( + trajectory_discount_vector * trajectory_rewards, + axis=0) + elif element.name == TERMINAL: + element_array[batch_element] = is_terminal_transition + elif element.name == INDICES: + element_array[batch_element] = state_index + elif element.name in store.keys(): + element_array[batch_element] = ( + store[element.name][state_index]) + # print('-'*80) + + if pack_in_dict: + batch_arrays = self.unpack_transition(batch_arrays, + transition_elements) + return batch_arrays + + def get_transition_elements(self, batch_size=None): + ''' + Returns a 'type signature' for sample_transition_batch. + + To be more specific, we first pack basic transition information, then, + we pack `self._observation_elements` (with '_tp1') and + `self._extra_replay_elements`. + + NOTE: 'tp1' means next state + + Args: + - batch_size: int, number of transitions returned. If None, the default + batch_size will be used. + Returns: + - signature: A namedtuple describing the method's return type signature. + ''' + batch_size = self._batch_size if batch_size is None else batch_size + + transition_elements = [ + ReplayElement(ACTION, (batch_size, ) + self._action_shape, + self._action_dtype), + ReplayElement(REWARD, (batch_size, ) + self._reward_shape, + self._reward_dtype), + ReplayElement(TERMINAL, (batch_size, ), np.int8), + ReplayElement(TIMEOUT, (batch_size, ), np.bool), + ReplayElement(INDICES, (batch_size, ), np.int32) + ] + + for element in self._observation_elements: + transition_elements.append( + ReplayElement(element.name, (batch_size, self._timesteps) + + tuple(element.shape), element.type, True)) + transition_elements.append( + ReplayElement(element.name + '_tp1', + (batch_size, self._timesteps) + + tuple(element.shape), element.type, True)) + + for element in self._extra_replay_elements: + transition_elements.append( + ReplayElement(element.name, + (batch_size, ) + tuple(element.shape), + element.type)) + return transition_elements + + def shutdown(self): + if self._purge_replay_on_shutdown: + # Safely delete replay + logging.info('Clearing disk replay buffer.') + for f in [f for f in os.listdir(self._save_dir) if '.replay' in f]: + os.remove(join(self._save_dir, f)) + + def using_disk(self): + return self._disk_saving diff --git a/yarr/replay_buffer/wrappers/__init__.py b/yarr/replay_buffer/wrappers/__init__.py new file mode 100755 index 0000000..7ce7ddb --- /dev/null +++ b/yarr/replay_buffer/wrappers/__init__.py @@ -0,0 +1,24 @@ +from abc import ABC, abstractmethod +from typing import Any + +from yarr.replay_buffer.replay_buffer import ReplayBuffer + + +class WrappedReplayBuffer(ABC): + + def __init__(self, replay_buffer: ReplayBuffer): + """Initializes WrappedReplayBuffer. + + Raises: + ValueError: If update_horizon is not positive. + ValueError: If discount factor is not in [0, 1]. 
+ """ + self._replay_buffer = replay_buffer + + @property + def replay_buffer(self): + return self._replay_buffer + + @abstractmethod + def dataset(self) -> Any: + pass \ No newline at end of file diff --git a/yarr/replay_buffer/wrappers/pytorch_replay_buffer.py b/yarr/replay_buffer/wrappers/pytorch_replay_buffer.py new file mode 100755 index 0000000..0917660 --- /dev/null +++ b/yarr/replay_buffer/wrappers/pytorch_replay_buffer.py @@ -0,0 +1,42 @@ +import time +from threading import Lock, Thread + +from torch.utils.data import IterableDataset, DataLoader + +from yarr.replay_buffer.replay_buffer import ReplayBuffer +from yarr.replay_buffer.wrappers import WrappedReplayBuffer + + +class PyTorchIterableReplayDataset(IterableDataset): + def __init__(self, replay_buffer: ReplayBuffer): + self._replay_buffer = replay_buffer + + def _generator(self): + while True: + yield self._replay_buffer.sample_transition_batch( + pack_in_dict=True) + + def __iter__(self): + return iter(self._generator()) + + +class PyTorchReplayBuffer(WrappedReplayBuffer): + """Wrapper of OutOfGraphReplayBuffer with an in graph sampling mechanism. + + Usage: + To add a transition: call the `add` function. + + To sample a batch: Construct operations that depend on any of the \\ + tensors is the transition dictionary. Every \\ + sess.run that requires any of these tensors will \\ + sample a new transition. + """ + def __init__(self, replay_buffer: ReplayBuffer, num_workers: int = 2): + super(PyTorchReplayBuffer, self).__init__(replay_buffer) + self._num_workers = num_workers + + def dataset(self) -> DataLoader: + # d = PyTorchIterableReplayDataset(self._replay_buffer, self._num_workers) + d = PyTorchIterableReplayDataset(self._replay_buffer) + # Batch size None disables automatic batching + return DataLoader(d, batch_size=None, pin_memory=True) diff --git a/yarr/replay_buffer/wrappers/pytorch_replay_buffer_backup.py b/yarr/replay_buffer/wrappers/pytorch_replay_buffer_backup.py new file mode 100755 index 0000000..5a4af85 --- /dev/null +++ b/yarr/replay_buffer/wrappers/pytorch_replay_buffer_backup.py @@ -0,0 +1,79 @@ +import time +from threading import Lock, Thread + +from torch.utils.data import IterableDataset, DataLoader + +from yarr.replay_buffer.replay_buffer import ReplayBuffer +from yarr.replay_buffer.wrappers import WrappedReplayBuffer + + +class PyTorchIterableReplayDataset(IterableDataset): + + def __init__(self, replay_buffer: ReplayBuffer): + self._replay_buffer = replay_buffer + + def _generator(self): + while True: + yield self._replay_buffer.sample_transition_batch(pack_in_dict=True) + + def __iter__(self): + return iter(self._generator()) + +# class PyTorchIterableReplayDataset(IterableDataset): +# +# BUFFER = 4 +# +# def __init__(self, replay_buffer: ReplayBuffer, num_workers: int): +# self._replay_buffer = replay_buffer +# self._num_wokers = num_workers +# self._samples = [] +# self._lock = Lock() +# +# def _run(self): +# while True: +# # Check if replay buffer is ig enough to be sampled +# while self._replay_buffer.add_count < self._replay_buffer.batch_size: +# time.sleep(1.) 
+# s = self._replay_buffer.sample_transition_batch(pack_in_dict=True) +# while len(self._samples) >= PyTorchIterableReplayDataset.BUFFER: +# time.sleep(0.25) +# with self._lock: +# self._samples.append(s) +# +# def _generator(self): +# ts = [Thread( +# target=self._run, args=()) for _ in range(self._num_wokers)] +# [t.start() for t in ts] +# while True: +# while len(self._samples) == 0: +# time.sleep(0.1) +# with self._lock: +# s = self._samples.pop(0) +# yield s +# +# def __iter__(self): +# i = iter(self._generator()) +# return i + + +class PyTorchReplayBuffer(WrappedReplayBuffer): + """Wrapper of OutOfGraphReplayBuffer with an in graph sampling mechanism. + + Usage: + To add a transition: call the `add` function. + + To sample a batch: Construct operations that depend on any of the \\ + tensors is the transition dictionary. Every \\ + sess.run that requires any of these tensors will \\ + sample a new transition. + """ + + def __init__(self, replay_buffer: ReplayBuffer, num_workers: int = 2): + super(PyTorchReplayBuffer, self).__init__(replay_buffer) + self._num_workers = num_workers + + def dataset(self) -> DataLoader: + # d = PyTorchIterableReplayDataset(self._replay_buffer, self._num_workers) + d = PyTorchIterableReplayDataset(self._replay_buffer) + # Batch size None disables automatic batching + return DataLoader(d, batch_size=None, pin_memory=True) diff --git a/yarr/runners/__init__.py b/yarr/runners/__init__.py new file mode 100755 index 0000000..e69de29 diff --git a/yarr/runners/_env_runner.py b/yarr/runners/_env_runner.py new file mode 100755 index 0000000..a218dc3 --- /dev/null +++ b/yarr/runners/_env_runner.py @@ -0,0 +1,189 @@ +import copy +import logging +import os +import time +from multiprocessing import Process, Manager +from typing import Any + +import numpy as np +from yarr.agents.agent import Agent +from yarr.envs.env import Env +from yarr.utils.rollout_generator import RolloutGenerator + +# try: +# if get_start_method() != 'spawn': +# set_start_method('spawn', force=True) +# except RuntimeError: +# pass + + +class _EnvRunner(object): + def __init__(self, + train_env: Env, + eval_env: Env, + agent: Agent, + timesteps: int, + train_envs: int, + eval_envs: int, + episodes: int, + episode_length: int, + kill_signal: Any, + step_signal: Any, + rollout_generator: RolloutGenerator, + save_load_lock, + current_replay_ratio, + target_replay_ratio, + weightsdir: str = None): + self._train_env, self._eval_env = train_env, eval_env + self._agent = agent + self._train_envs, self._eval_envs = train_envs, eval_envs + self._episodes, self._episode_length = episodes, episode_length + self._rollout_generator = rollout_generator + self._weightsdir = weightsdir + self._previous_loaded_weight_folder = '' + + self._timesteps = timesteps + + self._p_args = {} + self.p_failures = {} + manager = Manager() + self.write_lock = manager.Lock() + self.stored_transitions = manager.list() + self.agent_summaries = manager.list() + self._kill_signal = kill_signal + self._step_signal = step_signal + self._save_load_lock = save_load_lock + self._current_replay_ratio = current_replay_ratio + self._target_replay_ratio = target_replay_ratio + + def restart_process(self, name: str): + ''' + restart a process to run the environment + + Input: + - name: the nickname/tag for this environment + ''' + + p = Process(target=self._run_env, args=self._p_args[name], name=name) + p.start() + return p + + def spin_up_envs(self, name: str, num_envs: int, eval: bool): + ''' + create subprocess for running the environment + 
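+        Each environment runs `_run_env` in its own `multiprocessing.Process`,
+        with its own deep copy of the agent built in inference mode.
+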
+ Input: + - name: the nickname/tag for this environment + - num_envs: number of environment we would like to initiate + - eval: whether the environment is in `eval` mode + + Output: + - ps: the subprocesses + ''' + + ps = [] + for i in range(num_envs): + n = name + str(i) + self._p_args[n] = (n, eval) + self.p_failures[n] = 0 + p = Process(target=self._run_env, args=self._p_args[n], name=n) + p.start() + ps.append(p) + return ps + + def _load_save(self): + if self._weightsdir is None: + logging.info("'weightsdir' was None, so not loading weights.") + return + while True: + weight_folders = [] + with self._save_load_lock: + if os.path.exists(self._weightsdir): + weight_folders = os.listdir(self._weightsdir) + if len(weight_folders) > 0: + weight_folders = sorted(map(int, weight_folders)) + # Only load if there has been a new weight saving + if self._previous_loaded_weight_folder != weight_folders[ + -1]: + self._previous_loaded_weight_folder = weight_folders[ + -1] + d = os.path.join(self._weightsdir, + str(weight_folders[-1])) + try: + self._agent.load_weights(d) + except FileNotFoundError: + # Rare case when agent hasn't finished writing. + time.sleep(1) + self._agent.load_weights(d) + logging.info('Agent %s: Loaded weights: %s' % + (self._name, d)) + break + logging.info('Waiting for weights to become available.') + time.sleep(1) + + def _get_type(self, x): + if x.dtype == np.float64: + return np.float32 + return x.dtype + + def _run_env(self, name: str, eval: bool): + + self._name = name + self._agent = copy.deepcopy(self._agent) + self._agent.build(training=False) + + logging.info('%s: Launching env.' % name) + np.random.seed() + + logging.info('Agent information:') + logging.info(self._agent) + + env = self._train_env + if eval: + env = self._eval_env + env.eval = eval + env.launch() + for ep in range(self._episodes): + self._load_save() + logging.debug('%s: Starting episode %d.' % (name, ep)) + episode_rollout = [] + generator = self._rollout_generator.generator( + self._step_signal, env, self._agent, self._episode_length, + self._timesteps, eval) + try: + for replay_transition in generator: + while True: + if self._kill_signal.value: + env.shutdown() + return + if (eval or self._target_replay_ratio is None + or self._step_signal.value <= 0 + or (self._current_replay_ratio.value > + self._target_replay_ratio)): + break + time.sleep(1) + logging.debug( + 'Agent. Waiting for replay_ratio %f to be more than %f' + % (self._current_replay_ratio.value, + self._target_replay_ratio)) + + with self.write_lock: + if len(self.agent_summaries) == 0: + # Only store new summaries if the previous ones + # have been popped by the main env runner. 
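+                            # (EnvRunner._update copies this shared list and
+                            # clears it on logging steps.)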
+ for s in self._agent.act_summaries(): + self.agent_summaries.append(s) + episode_rollout.append(replay_transition) + except StopIteration as e: + continue + except Exception as e: + env.shutdown() + raise e + + with self.write_lock: + for transition in episode_rollout: + self.stored_transitions.append((name, transition, eval)) + env.shutdown() + + def kill(self): + self._kill_signal.value = True diff --git a/yarr/runners/env_runner.py b/yarr/runners/env_runner.py new file mode 100755 index 0000000..e5cce63 --- /dev/null +++ b/yarr/runners/env_runner.py @@ -0,0 +1,169 @@ +import collections +import logging +import os +import signal +import time +from multiprocessing import Value +from threading import Thread +from typing import List +from typing import Union + +import numpy as np + +from yarr.agents.agent import Agent +from yarr.agents.agent import ScalarSummary +from yarr.agents.agent import Summary +from yarr.envs.env import Env +from yarr.replay_buffer.replay_buffer import ReplayBuffer +from yarr.runners._env_runner import _EnvRunner +from yarr.utils.rollout_generator import RolloutGenerator +from yarr.utils.stat_accumulator import StatAccumulator + + +class EnvRunner(object): + def __init__(self, + env: Env, + agent: Agent, + replay_buffer: ReplayBuffer, + train_envs: int, + eval_envs: int, + episodes: int, + episode_length: int, + stat_accumulator: Union[StatAccumulator, None] = None, + rollout_generator: RolloutGenerator = None, + weightsdir: str = None, + max_fails: int = 10): + self._env, self._agent = env, agent + self._train_envs, self._eval_envs = train_envs, eval_envs + self._replay_buffer = replay_buffer + self._episodes = episodes + self._episode_length = episode_length + self._stat_accumulator = stat_accumulator + self._rollout_generator = (RolloutGenerator() + if rollout_generator is None else + rollout_generator) + self._weightsdir = weightsdir + self._max_fails = max_fails + self._previous_loaded_weight_folder = '' + self._p = None + self._kill_signal = Value('b', 0) + self._step_signal = Value('i', -1) + self._new_transitions = {'train_envs': 0, 'eval_envs': 0} + self._total_transitions = {'train_envs': 0, 'eval_envs': 0} + self.log_freq = 1000 # Will get overridden later + self.target_replay_ratio = None # Will get overridden later + self.current_replay_ratio = Value('f', -1) + + def summaries(self) -> List[Summary]: + summaries = [] + if self._stat_accumulator is not None: + summaries.extend(self._stat_accumulator.pop()) + for key, value in self._new_transitions.items(): + summaries.append(ScalarSummary('%s/new_transitions' % key, value)) + for key, value in self._total_transitions.items(): + summaries.append(ScalarSummary('%s/total_transitions' % key, + value)) + self._new_transitions = {'train_envs': 0, 'eval_envs': 0} + summaries.extend(self._agent_summaries) + return summaries + + def _update(self): + ''' + Move the stored transitions to the replay and accumulate statistics. 
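+        Returns a dict mapping each env process name to the number of newly
+        stored transitions, which `_run` uses to detect hung environments.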
+ ''' + + new_transitions = collections.defaultdict(int) + with self._internal_env_runner.write_lock: + self._agent_summaries = list( + self._internal_env_runner.agent_summaries) + if self._step_signal.value % self.log_freq == 0 and self._step_signal.value > 0: + self._internal_env_runner.agent_summaries[:] = [] + for name, transition, eval in self._internal_env_runner.stored_transitions: + if not eval: + kwargs = dict(transition.observation) + self._replay_buffer.add(np.array(transition.action), + transition.reward, + transition.terminal, + transition.timeout, **kwargs) + if transition.terminal: + self._replay_buffer.add_final( + **transition.final_observation) + new_transitions[name] += 1 + self._new_transitions[ + 'eval_envs' if eval else 'train_envs'] += 1 + self._total_transitions[ + 'eval_envs' if eval else 'train_envs'] += 1 + if self._stat_accumulator is not None: + self._stat_accumulator.step(transition, eval) + self._internal_env_runner.stored_transitions[:] = [] # Clear list + return new_transitions + + def _run(self, save_load_lock): + self._internal_env_runner = _EnvRunner( + self._env, self._env, self._agent, self._replay_buffer.timesteps, + self._train_envs, self._eval_envs, self._episodes, + self._episode_length, self._kill_signal, self._step_signal, + self._rollout_generator, save_load_lock, self.current_replay_ratio, + self.target_replay_ratio, self._weightsdir) + training_envs = self._internal_env_runner.spin_up_envs( + 'train_env', self._train_envs, False) + eval_envs = self._internal_env_runner.spin_up_envs( + 'eval_env', self._eval_envs, True) + envs = training_envs + eval_envs + no_transitions = {env.name: 0 for env in envs} + while True: + for p in envs: + if p.exitcode is not None: + envs.remove(p) + if p.exitcode != 0: + self._internal_env_runner.p_failures[p.name] += 1 + n_failures = self._internal_env_runner.p_failures[ + p.name] + if n_failures > self._max_fails: + logging.error( + 'Env %s failed too many times (%d times > %d)' + % (p.name, n_failures, self._max_fails)) + raise RuntimeError('Too many process failures.') + logging.warning( + 'Env %s failed (%d times <= %d). 
restarting' % + (p.name, n_failures, self._max_fails)) + p = self._internal_env_runner.restart_process(p.name) + envs.append(p) + + if not self._kill_signal.value: + new_transitions = self._update() + for p in envs: + if new_transitions[p.name] == 0: + no_transitions[p.name] += 1 + else: + no_transitions[p.name] = 0 + if no_transitions[p.name] > 600: # 5min + logging.warning("Env %s hangs, so restarting" % p.name) + envs.remove(p) + os.kill(p.pid, signal.SIGTERM) + p = self._internal_env_runner.restart_process(p.name) + envs.append(p) + no_transitions[p.name] = 0 + + if len(envs) == 0: + break + time.sleep(1) + + def start(self, save_load_lock): + self._p = Thread(target=self._run, + args=(save_load_lock, ), + daemon=True) + self._p.name = 'EnvRunnerThread' + self._p.start() + + def wait(self): + if self._p.is_alive(): + self._p.join() + + def stop(self): + if self._p.is_alive(): + self._kill_signal.value = True + self._p.join() + + def set_step(self, step): + self._step_signal.value = step diff --git a/yarr/runners/pytorch_train_runner.py b/yarr/runners/pytorch_train_runner.py new file mode 100755 index 0000000..f5b347e --- /dev/null +++ b/yarr/runners/pytorch_train_runner.py @@ -0,0 +1,264 @@ +import copy +import logging +import os +import shutil +import signal +import sys +import threading +import time +from multiprocessing import Lock +from typing import Optional, List +from typing import Union + +import numpy as np +import psutil +import torch +from yarr.agents.agent import Agent +from yarr.replay_buffer.wrappers.pytorch_replay_buffer import \ + PyTorchReplayBuffer +from yarr.runners.env_runner import EnvRunner +from yarr.runners.train_runner import TrainRunner +from yarr.utils.log_writer import LogWriter +from yarr.utils.stat_accumulator import StatAccumulator + +NUM_WEIGHTS_TO_KEEP = 10 + + +class PyTorchTrainRunner(TrainRunner): + def __init__(self, + agent: Agent, + env_runner: EnvRunner, + wrapped_replay_buffer: Union[PyTorchReplayBuffer, + List[PyTorchReplayBuffer]], + train_device: torch.device, + replay_buffer_sample_rates: List[float] = None, + stat_accumulator: Union[StatAccumulator, None] = None, + iterations: int = int(1e6), + logdir: str = '/tmp/yarr/logs', + log_freq: int = 10, + transitions_before_train: int = 1000, + weightsdir: str = '/tmp/yarr/weights', + save_freq: int = 100, + replay_ratio: Optional[float] = None, + tensorboard_logging: bool = True, + csv_logging: bool = False): + super(PyTorchTrainRunner, + self).__init__(agent, env_runner, wrapped_replay_buffer, + stat_accumulator, iterations, logdir, log_freq, + transitions_before_train, weightsdir, save_freq) + + env_runner.log_freq = log_freq + env_runner.target_replay_ratio = replay_ratio + self._wrapped_buffer = wrapped_replay_buffer if isinstance( + wrapped_replay_buffer, list) else [wrapped_replay_buffer] + self._replay_buffer_sample_rates = ([1.0] if + replay_buffer_sample_rates is None + else replay_buffer_sample_rates) + if len(self._replay_buffer_sample_rates) != len(wrapped_replay_buffer): + raise ValueError( + 'Numbers of replay buffers differs from sampling rates.') + if sum(self._replay_buffer_sample_rates) != 1: + raise ValueError('Sum of sampling rates should be 1.') + + self._train_device = train_device + self._tensorboard_logging = tensorboard_logging + self._csv_logging = csv_logging + + if replay_ratio is not None and replay_ratio < 0: + raise ValueError("max_replay_ratio must be positive.") + self._target_replay_ratio = replay_ratio + + self._writer = None + if logdir is None: + 
logging.info("'logdir' was None. No logging will take place.") + else: + self._writer = LogWriter(self._logdir, tensorboard_logging, + csv_logging) + if weightsdir is None: + logging.info( + "'weightsdir' was None. No weight saving will take place.") + else: + os.makedirs(self._weightsdir, exist_ok=True) + + def _save_model(self, i): + ''' + save the weight of the agent/policy + + Input: + - i: the step index for weight saving + ''' + + with self._save_load_lock: + d = os.path.join(self._weightsdir, str(i)) + os.makedirs(d, exist_ok=True) + self._agent.save_weights(d) + # Remove oldest save + prev_dir = os.path.join( + self._weightsdir, + str(i - self._save_freq * NUM_WEIGHTS_TO_KEEP)) + if os.path.exists(prev_dir): + shutil.rmtree(prev_dir) + + def _step(self, i, sampled_batch): + ''' + update step for the agent with the sample batch transition + + Input: + - i: + - sampled_batch: + ''' + + update_dict = self._agent.update(i, sampled_batch) + acc_bs = 0 + for wb in self._wrapped_buffer: + bs = wb.replay_buffer.batch_size + if 'priority' in update_dict: + wb.replay_buffer.set_priority( + sampled_batch['indices'][acc_bs:acc_bs + + bs].cpu().detach().numpy(), + update_dict['priority'][acc_bs:acc_bs + bs]) + acc_bs += bs + + def _signal_handler(self, sig, frame): + if threading.current_thread().name != 'MainThread': + return + logging.info('SIGINT captured. Shutting down.' + 'This may take a few seconds.') + self._env_runner.stop() + [r.replay_buffer.shutdown() for r in self._wrapped_buffer] + sys.exit(0) + + def _get_add_counts(self): + return np.array( + [r.replay_buffer.add_count for r in self._wrapped_buffer]) + + def _get_sum_add_counts(self): + return sum([r.replay_buffer.add_count for r in self._wrapped_buffer]) + + def start(self): + ''' + start the RL learning process + ''' + + signal.signal(signal.SIGINT, self._signal_handler) + + self._save_load_lock = Lock() + + # Kick off the environments + self._env_runner.start(self._save_load_lock) + + self._agent = copy.deepcopy(self._agent) + self._agent.build(training=True, device=self._train_device) + + if self._weightsdir is not None: + self._save_model(0) # Save weights so workers can load. + + while (np.any( + self._get_add_counts() < self._transitions_before_train)): + time.sleep(1) + logging.info( + 'Waiting for %d samples before training. Currently have %s.' % + (self._transitions_before_train, str(self._get_add_counts()))) + + datasets = [r.dataset() for r in self._wrapped_buffer] + data_iter = [iter(d) for d in datasets] + + init_replay_size = self._get_sum_add_counts().astype(float) + batch_size = sum( + [r.replay_buffer.batch_size for r in self._wrapped_buffer]) + process = psutil.Process(os.getpid()) + num_cpu = psutil.cpu_count() + + for i in range(self._iterations): + self._env_runner.set_step(i) + + log_iteration = i % self._log_freq == 0 and i > 0 + + if log_iteration: + process.cpu_percent(interval=None) + + def get_replay_ratio(): + size_used = batch_size * i + size_added = (self._get_sum_add_counts() - init_replay_size) + replay_ratio = size_used / (size_added + 1e-6) + return replay_ratio + + if self._target_replay_ratio is not None: + # wait for env_runner collecting enough samples + while True: + replay_ratio = get_replay_ratio() + self._env_runner.current_replay_ratio.value = replay_ratio + if replay_ratio < self._target_replay_ratio: + break + time.sleep(1) + logging.debug( + 'Waiting for replay_ratio %f to be less than %f.' 
% + (replay_ratio, self._target_replay_ratio)) + del replay_ratio + + t = time.time() + sampled_batch = [next(di) for di in data_iter] + + if len(sampled_batch) > 1: + result = {} + for key in sampled_batch[0]: + result[key] = torch.cat([d[key] for d in sampled_batch], 0) + sampled_batch = result + else: + sampled_batch = sampled_batch[0] + + sample_time = time.time() - t + batch = { + k: v.to(self._train_device) + for k, v in sampled_batch.items() + } + t = time.time() + self._step(i, batch) + step_time = time.time() - t + + if log_iteration and self._writer is not None: + replay_ratio = get_replay_ratio() + logging.info( + 'Step %d. Sample time: %s. Step time: %s. Replay ratio: %s.' + % (i, sample_time, step_time, replay_ratio)) + agent_summaries = self._agent.update_summaries() + env_summaries = self._env_runner.summaries() + self._writer.add_summaries(i, agent_summaries + env_summaries) + + for r_i, wrapped_buffer in enumerate(self._wrapped_buffer): + self._writer.add_scalar( + i, 'replay%d/add_count' % r_i, + wrapped_buffer.replay_buffer.add_count) + self._writer.add_scalar( + i, 'replay%d/size' % r_i, + wrapped_buffer.replay_buffer.replay_capacity + if wrapped_buffer.replay_buffer.is_full() else + wrapped_buffer.replay_buffer.add_count) + + self._writer.add_scalar(i, 'replay/replay_ratio', replay_ratio) + self._writer.add_scalar( + i, 'replay/update_to_insert_ratio', + float(i) / float(self._get_sum_add_counts() - + init_replay_size + 1e-6)) + + self._writer.add_scalar(i, 'monitoring/sample_time_per_item', + sample_time / batch_size) + self._writer.add_scalar(i, 'monitoring/train_time_per_item', + step_time / batch_size) + self._writer.add_scalar(i, 'monitoring/memory_gb', + process.memory_info().rss * 1e-9) + self._writer.add_scalar( + i, 'monitoring/cpu_percent', + process.cpu_percent(interval=None) / num_cpu) + + self._writer.end_iteration() + + if i % self._save_freq == 0 and self._weightsdir is not None: + self._save_model(i) + + if self._writer is not None: + self._writer.close() + + logging.info('Stopping envs ...') + self._env_runner.stop() + [r.replay_buffer.shutdown() for r in self._wrapped_buffer] diff --git a/yarr/runners/train_runner.py b/yarr/runners/train_runner.py new file mode 100755 index 0000000..41d4e6a --- /dev/null +++ b/yarr/runners/train_runner.py @@ -0,0 +1,34 @@ +from abc import abstractmethod, ABC +from typing import Union, List + +from yarr.agents.agent import Agent +from yarr.replay_buffer.wrappers import WrappedReplayBuffer +from yarr.runners.env_runner import EnvRunner +from yarr.utils.stat_accumulator import StatAccumulator + + +class TrainRunner(ABC): + + def __init__(self, + agent: Agent, + env_runner: EnvRunner, + wrapped_replay_buffer: WrappedReplayBuffer, + stat_accumulator: Union[StatAccumulator, None] = None, + iterations: int = int(1e6), + logdir: str = '/tmp/yarr/logs', + log_freq: int = 500, + transitions_before_train: int = 1000, + weightsdir: str = '/tmp/yarr/weights', + save_freq: int = 100, + ): + self._agent, self._env_runner = agent, env_runner + self._wrapped_buffer = wrapped_replay_buffer + self._stat_accumulator = stat_accumulator + self._iterations = iterations + self._logdir, self._log_freq = logdir, log_freq + self._transitions_before_train = transitions_before_train + self._weightsdir, self._save_freq = weightsdir, save_freq + + @abstractmethod + def start(self): + pass diff --git a/yarr/utils/__init__.py b/yarr/utils/__init__.py new file mode 100755 index 0000000..e69de29 diff --git a/yarr/utils/log_writer.py 
b/yarr/utils/log_writer.py new file mode 100755 index 0000000..f9ccff1 --- /dev/null +++ b/yarr/utils/log_writer.py @@ -0,0 +1,84 @@ +import csv +import logging +import os +from collections import OrderedDict + +import numpy as np +import torch +from yarr.agents.agent import ScalarSummary, HistogramSummary, ImageSummary, \ + VideoSummary +from torch.utils.tensorboard import SummaryWriter + + +class LogWriter(object): + + def __init__(self, + logdir: str, + tensorboard_logging: bool, + csv_logging: bool): + self._tensorboard_logging = tensorboard_logging + self._csv_logging = csv_logging + os.makedirs(logdir, exist_ok=True) + if tensorboard_logging: + self._tf_writer = SummaryWriter(logdir) + if csv_logging: + self._prev_row_data = self._row_data = OrderedDict() + self._csv_file = os.path.join(logdir, 'data.csv') + self._field_names = None + + def add_scalar(self, i, name, value): + if self._tensorboard_logging: + self._tf_writer.add_scalar(name, value, i) + if self._csv_logging: + if len(self._row_data) == 0: + self._row_data['step'] = i + self._row_data[name] = value.item() if isinstance( + value, torch.Tensor) else value + + def add_summaries(self, i, summaries): + for summary in summaries: + try: + if isinstance(summary, ScalarSummary): + self.add_scalar(i, summary.name, summary.value) + elif self._tensorboard_logging: + if isinstance(summary, HistogramSummary): + self._tf_writer.add_histogram( + summary.name, summary.value, i) + elif isinstance(summary, ImageSummary): + # Only grab first item in batch + v = (summary.value if summary.value.ndim == 3 else + summary.value[0]) + self._tf_writer.add_image(summary.name, v, i) + elif isinstance(summary, VideoSummary): + # Only grab first item in batch + v = (summary.value if summary.value.ndim == 5 else + np.array([summary.value])) + self._tf_writer.add_video( + summary.name, v, i, fps=summary.fps) + except Exception as e: + logging.error('Error on summary: %s' % summary.name) + raise e + + def end_iteration(self): + if self._csv_logging and len(self._row_data) > 0: + with open(self._csv_file, mode='a+') as csv_f: + names = self._field_names or self._row_data.keys() + writer = csv.DictWriter(csv_f, fieldnames=names) + if self._field_names is None: + writer.writeheader() + else: + if not np.array_equal(self._field_names, self._row_data.keys()): + # Special case when we are logging faster than new + # summaries are coming in. 
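+                        # Re-use the previous row's value for any column that is
+                        # missing this iteration, so the CSV keeps a fixed header.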
+ missing_keys = list(set(self._field_names) - set( + self._row_data.keys())) + for mk in missing_keys: + self._row_data[mk] = self._prev_row_data[mk] + self._field_names = names + writer.writerow(self._row_data) + self._prev_row_data = self._row_data + self._row_data = OrderedDict() + + def close(self): + if self._tensorboard_logging: + self._tf_writer.close() diff --git a/yarr/utils/multi_task_rollout_generator.py b/yarr/utils/multi_task_rollout_generator.py new file mode 100755 index 0000000..9c66549 --- /dev/null +++ b/yarr/utils/multi_task_rollout_generator.py @@ -0,0 +1,65 @@ +from multiprocessing import Value + +import numpy as np + +from yarr.agents.agent import Agent +from yarr.envs.env import Env +from yarr.envs.multi_task_env import MultiTaskEnv +from yarr.utils.transition import ReplayTransition + + +class RolloutGenerator(object): + + def _get_type(self, x): + if x.dtype == np.float64: + return np.float32 + return x.dtype + + def generator(self, step_signal: Value, env: MultiTaskEnv, agent: Agent, + episode_length: int, timesteps: int, eval: bool): + obs = env.reset() + agent.reset() + obs_history = {k: [np.array(v, dtype=self._get_type(v))] * timesteps for k, v in obs.items()} + for step in range(episode_length): + + prepped_data = {k: np.array([v]) for k, v in obs_history.items()} + + act_result = agent.act(step_signal.value, prepped_data, + deterministic=eval) + + # Convert to np if not already + agent_obs_elems = {k: np.array(v) for k, v in + act_result.observation_elements.items()} + agent_extra_elems = {k: np.array(v) for k, v in + act_result.replay_elements.items()} + + transition = env.step(act_result) + timeout = False + if step == episode_length - 1: + # If last transition, and not terminal, then we timed out + timeout = not transition.terminal + if timeout: + transition.terminal = True + if "needs_reset" in transition.info: + transition.info["needs_reset"] = True + + obs.update(agent_obs_elems) + obs_tp1 = dict(transition.observation) + + for k in obs_history.keys(): + obs_history[k].append(transition.observation[k]) + obs_history[k].pop(0) + + transition.info["active_task_id"] = env.active_task_id + + replay_transition = ReplayTransition( + obs, act_result.action, transition.reward, + transition.terminal, + timeout, obs_tp1, agent_extra_elems, + transition.info) + + obs = transition.observation + yield replay_transition + + if transition.info.get("needs_reset", transition.terminal): + return diff --git a/yarr/utils/observation_type.py b/yarr/utils/observation_type.py new file mode 100755 index 0000000..bc2dd5a --- /dev/null +++ b/yarr/utils/observation_type.py @@ -0,0 +1,10 @@ +from typing import Type +import numpy as np + + +class ObservationElement(object): + + def __init__(self, name: str, shape: tuple, type: Type[np.dtype]): + self.name = name + self.shape = shape + self.type = type diff --git a/yarr/utils/rollout_generator.py b/yarr/utils/rollout_generator.py new file mode 100755 index 0000000..9725e08 --- /dev/null +++ b/yarr/utils/rollout_generator.py @@ -0,0 +1,91 @@ +from multiprocessing import Value + +import numpy as np +from yarr.agents.agent import Agent +from yarr.envs.env import Env +from yarr.utils.transition import ReplayTransition + + +class RolloutGenerator(object): + def _get_type(self, x): + if x.dtype == np.float64: + return np.float32 + return x.dtype + + def generator(self, step_signal: Value, env: Env, agent: Agent, + episode_length: int, timesteps: int, eval: bool): + obs = env.reset() + agent.reset() + obs_history = { + k: [np.array(v, 
dtype=self._get_type(v))] * timesteps + for k, v in obs.items() + } + for step in range(episode_length): + + prepped_data = {k: np.array([v]) for k, v in obs_history.items()} + + act_result = agent.act(step_signal.value, + prepped_data, + deterministic=eval) + + # Convert to np if not already + agent_obs_elems = { + k: np.array(v) + for k, v in act_result.observation_elements.items() + } + extra_replay_elements = { + k: np.array(v) + for k, v in act_result.replay_elements.items() + } + + transition = env.step(act_result) + obs_tp1 = dict(transition.observation) + timeout = False + if step == episode_length - 1: + # If last transition, and not terminal, then we timed out + timeout = not transition.terminal + if timeout: + transition.terminal = True + if "needs_reset" in transition.info: + transition.info["needs_reset"] = True + + obs_and_replay_elems = {} + obs_and_replay_elems.update(obs) + obs_and_replay_elems.update(agent_obs_elems) + obs_and_replay_elems.update(extra_replay_elements) + + for k in obs_history.keys(): + obs_history[k].append(transition.observation[k]) + obs_history[k].pop(0) + + replay_transition = ReplayTransition( + obs_and_replay_elems, + act_result.action, + transition.reward, + transition.terminal, + timeout, + summaries=transition.summaries) + + if transition.terminal or timeout: + # If the agent gives us observations then we need to call act + # one last time (i.e. acting in the terminal state). + if len(act_result.observation_elements) > 0: + prepped_data = { + k: np.array([v]) + for k, v in obs_history.items() + } + act_result = agent.act(step_signal.value, + prepped_data, + deterministic=eval) + agent_obs_elems_tp1 = { + k: np.array(v) + for k, v in act_result.observation_elements.items() + } + obs_tp1.update(agent_obs_elems_tp1) + replay_transition.final_observation = obs_tp1 + + obs = dict(transition.observation) + yield replay_transition + + if transition.info.get("needs_reset", transition.terminal): + return diff --git a/yarr/utils/stat_accumulator.py b/yarr/utils/stat_accumulator.py new file mode 100755 index 0000000..395e39d --- /dev/null +++ b/yarr/utils/stat_accumulator.py @@ -0,0 +1,193 @@ +from multiprocessing import Lock +from typing import List + +import numpy as np +from yarr.agents.agent import Summary, ScalarSummary +from yarr.utils.transition import ReplayTransition + + +class StatAccumulator(object): + + def step(self, transition: ReplayTransition, eval: bool): + pass + + def pop(self) -> List[Summary]: + pass + + def peak(self) -> List[Summary]: + pass + + def reset(self) -> None: + pass + + +class Metric(object): + + def __init__(self): + self._previous = [] + self._current = 0 + + def update(self, value): + self._current += value + + def next(self): + self._previous.append(self._current) + self._current = 0 + + def reset(self): + self._previous.clear() + + def min(self): + return np.min(self._previous) + + def max(self): + return np.max(self._previous) + + def mean(self): + return np.mean(self._previous) + + def median(self): + return np.median(self._previous) + + def std(self): + return np.std(self._previous) + + def __len__(self): + return len(self._previous) + + def __getitem__(self, i): + return self._previous[i] + + +class _SimpleAccumulator(StatAccumulator): + + def __init__(self, prefix, eval_video_fps: int = 30, + mean_only: bool = True): + self._prefix = prefix + self._eval_video_fps = eval_video_fps + self._mean_only = mean_only + self._lock = Lock() + self._episode_returns = Metric() + self._episode_lengths = Metric() + 
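+        # per-episode return/length trackers, advanced via `next()` on terminal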
self._summaries = [] + self._transitions = 0 + + def _reset_data(self): + with self._lock: + self._episode_returns.reset() + self._episode_lengths.reset() + self._summaries.clear() + + def step(self, transition: ReplayTransition, eval: bool): + with self._lock: + self._transitions += 1 + self._episode_returns.update(transition.reward) + self._episode_lengths.update(1) + if transition.terminal: + self._episode_returns.next() + self._episode_lengths.next() + self._summaries.extend(list(transition.summaries)) + + def _get(self) -> List[Summary]: + sums = [] + + if self._mean_only: + stat_keys = ["mean"] + else: + stat_keys = ["min", "max", "mean", "median", "std"] + names = ["return", "length"] + metrics = [self._episode_returns, self._episode_lengths] + for name, metric in zip(names, metrics): + for stat_key in stat_keys: + if self._mean_only: + assert stat_key == "mean" + sum_name = '%s/%s' % (self._prefix, name) + else: + sum_name = '%s/%s/%s' % (self._prefix, name, stat_key) + sums.append( + ScalarSummary(sum_name, getattr(metric, stat_key)())) + sums.append(ScalarSummary( + '%s/total_transitions' % self._prefix, self._transitions)) + sums.extend(self._summaries) + return sums + + def pop(self) -> List[Summary]: + data = [] + if len(self._episode_returns) > 1: + data = self._get() + self._reset_data() + return data + + def peak(self) -> List[Summary]: + return self._get() + + def reset(self): + self._transitions = 0 + self._reset_data() + + +class SimpleAccumulator(StatAccumulator): + + def __init__(self, eval_video_fps: int = 30, mean_only: bool = True): + self._train_acc = _SimpleAccumulator( + 'train_envs', eval_video_fps, mean_only=mean_only) + self._eval_acc = _SimpleAccumulator( + 'eval_envs', eval_video_fps, mean_only=mean_only) + + def step(self, transition: ReplayTransition, eval: bool): + if eval: + self._eval_acc.step(transition, eval) + else: + self._train_acc.step(transition, eval) + + def pop(self) -> List[Summary]: + return self._train_acc.pop() + self._eval_acc.pop() + + def peak(self) -> List[Summary]: + return self._train_acc.peak() + self._eval_acc.peak() + + def reset(self) -> None: + self._train_acc.reset() + self._eval_acc.reset() + + +class MultiTaskAccumulator(StatAccumulator): + + def __init__(self, num_tasks, + eval_video_fps: int = 30, mean_only: bool = True, + train_prefix: str = 'train_task', + eval_prefix: str = 'eval_task'): + self._train_accs = [_SimpleAccumulator( + '%s%d/envs' % (train_prefix, i), eval_video_fps, mean_only=mean_only) + for i in range(num_tasks)] + self._eval_accs = [_SimpleAccumulator( + '%s%d/envs' % (eval_prefix, i), eval_video_fps, mean_only=mean_only) + for i in range(num_tasks)] + self._train_accs_mean = _SimpleAccumulator( + '%s_summary/envs' % train_prefix, eval_video_fps, + mean_only=mean_only) + + def step(self, transition: ReplayTransition, eval: bool): + raise NotImplementedError() + # replay_index = transition.extra_replay_elements["active_task_id"] + if eval: + self._eval_accs[replay_index].step(transition, eval) + else: + self._train_accs[replay_index].step(transition, eval) + self._train_accs_mean.step(transition, eval) + + def pop(self) -> List[Summary]: + combined = self._train_accs_mean.pop() + for acc in self._train_accs + self._eval_accs: + combined.extend(acc.pop()) + return combined + + def peak(self) -> List[Summary]: + combined = self._train_accs_mean.peak() + for acc in self._train_accs + self._eval_accs: + combined.extend(acc.peak()) + return combined + + def reset(self) -> None: + self._train_accs_mean.reset() + 
[acc.reset() for acc in self._train_accs + self._eval_accs] diff --git a/yarr/utils/transition.py b/yarr/utils/transition.py new file mode 100755 index 0000000..d7d5d3a --- /dev/null +++ b/yarr/utils/transition.py @@ -0,0 +1,39 @@ +from typing import List + +import numpy as np +from yarr.agents.agent import Summary + + +class Transition(object): + def __init__(self, + observation: dict, + reward: float, + terminal: bool, + info: dict = None, + summaries: List[Summary] = None): + self.observation = observation + self.reward = reward + self.terminal = terminal + self.info = info or {} + self.summaries = summaries or [] + + +class ReplayTransition(object): + def __init__(self, + observation: dict, + action: np.ndarray, + reward: float, + terminal: bool, + timeout: bool, + final_observation: dict = None, + summaries: List[Summary] = None, + info: dict = None): + self.observation = observation + self.action = action + self.reward = reward + self.terminal = terminal + self.timeout = timeout + # final only populated on last timestep + self.final_observation = final_observation + self.summaries = summaries or [] + self.info = info
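
# Illustrative wiring of the pieces added above (a minimal sketch, not part of
# the patch). `env`, `agent` and `replay_buffer` are assumed to be an already
# constructed yarr Env, Agent and ReplayBuffer; directory paths, env counts and
# iteration counts are example values only.
import torch

from yarr.replay_buffer.wrappers.pytorch_replay_buffer import PyTorchReplayBuffer
from yarr.runners.env_runner import EnvRunner
from yarr.runners.pytorch_train_runner import PyTorchTrainRunner
from yarr.utils.stat_accumulator import SimpleAccumulator

stat_accum = SimpleAccumulator()

# Wrap the replay so batches are served through a torch DataLoader.
# A list is passed because PyTorchTrainRunner checks len() of this argument.
wrapped_replay = [PyTorchReplayBuffer(replay_buffer)]

env_runner = EnvRunner(
    env, agent, replay_buffer,
    train_envs=1, eval_envs=1,
    episodes=10000, episode_length=100,
    stat_accumulator=stat_accum,
    weightsdir='/tmp/yarr/weights')   # env processes reload agent weights from here

train_runner = PyTorchTrainRunner(
    agent, env_runner, wrapped_replay,
    train_device=torch.device('cuda:0' if torch.cuda.is_available() else 'cpu'),
    stat_accumulator=stat_accum,
    iterations=100000,
    logdir='/tmp/yarr/logs',
    weightsdir='/tmp/yarr/weights')

train_runner.start()  # blocks: saves initial weights, spins up env processes,
                      # then runs the sample/update loop for `iterations` steps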