add comments #6

Open · wants to merge 1 commit into main
64 changes: 64 additions & 0 deletions arm/c2farm/launch_utils.py
@@ -24,6 +24,23 @@
def create_replay(batch_size: int, timesteps: int, prioritisation: bool,
save_dir: str, cameras: list, env: Env,
voxel_sizes, replay_size=1e5):
'''
create the replay buffer. More specifically, we first register the extra
observation `Element`s that each transition stores, then construct the buffer

Input:
- batch_size: default batch size for sampling
- timesteps: how many timesteps are stacked together
- prioritisation: whether to use prioritised experience replay
- save_dir: directory the replay buffer is saved to
- cameras: list of camera names
- env: RLBench environment
- voxel_sizes: voxel grid resolution for each q-attention depth
- replay_size: replay buffer capacity

Output:
- replay_buffer: the constructed replay buffer
'''


trans_indicies_size = 3 * len(voxel_sizes)
rot_and_grip_indicies_size = (3 + 1)
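For orientation, a small illustrative calculation of how these two sizes come about; the two-depth setup and values are assumptions, not the repo's configuration:

```python
voxel_sizes = [16, 16]                        # one entry per q-attention depth (assumed)
trans_indicies_size = 3 * len(voxel_sizes)    # an (x, y, z) voxel index per depth -> 6
rot_and_grip_indicies_size = 3 + 1            # three Euler-angle bins + gripper open/close
print(trans_indicies_size, rot_and_grip_indicies_size)   # 6 4
```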
@@ -74,6 +91,23 @@ def _get_action(
bounds_offset: List[float],
rotation_resolution: int,
crop_augmentation: bool):
'''
extract the gripper pose from the observation and discretise it into
voxel-translation indices and rotation/gripper indices

Input:
- obs_tp1: the (next) observation whose gripper pose defines the action
- rlbench_scene_bounds: metric bounds of the RLBench workspace
- voxel_sizes: voxel grid resolution for each q-attention depth
- bounds_offset: offset used to re-centre the bounds at the finer depths
- rotation_resolution: degrees per rotation bin
- crop_augmentation: whether to apply random crop augmentation

Output:
- trans_indicies: voxel index of the gripper position at each depth
- rot_and_grip_indicies: discretised Euler rotation bins plus gripper open/close
- action: the continuous action (gripper pose and open/close state)
'''

quat = utils.normalize_quaternion(obs_tp1.gripper_pose[3:])
if quat[-1] < 0:
quat = -quat
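As context for the quaternion handling above, a self-contained sketch of canonicalising a quaternion and bucketing Euler angles at a given `rotation_resolution`; the helper and all values are assumptions, not the repo's `utils`:

```python
import numpy as np

def normalize_quaternion(q):
    # hypothetical stand-in for utils.normalize_quaternion
    return q / np.linalg.norm(q)

quat = normalize_quaternion(np.array([0.0, 0.0, -0.7071, -0.7071]))  # (x, y, z, w)
if quat[-1] < 0:
    quat = -quat                     # q and -q encode the same rotation; keep w >= 0

rotation_resolution = 5              # degrees per bin (assumed value)
euler_deg = np.array([-180.0, 0.0, 90.0])                  # made-up Euler angles
rot_indicies = ((euler_deg + 180.0) % 360.0) // rotation_resolution
print(quat, rot_indicies)            # quat now has w >= 0; each angle maps to a bin index
```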
@@ -115,6 +149,23 @@ def _add_keypoints_to_replay(
bounds_offset: List[float],
rotation_resolution: int,
crop_augmentation: bool):
'''
add the transitions between consecutive keypoints of a demo trajectory
to the replay buffer

Input:
- replay: the replay buffer we would like to add keypoints to
- inital_obs: the initial observation of a (sub-)trajectory
- demo: demo trajectory
- env: RLBench environment
- episode_keypoints: indices of the keypoint frames in the demo
- cameras: list of camera names
- rlbench_scene_bounds: metric bounds of the RLBench workspace
- voxel_sizes: voxel grid resolution for each q-attention depth
- bounds_offset: offset used to re-centre the bounds at the finer depths
- rotation_resolution: degrees per rotation bin
- crop_augmentation: whether to apply random crop augmentation
'''

prev_action = None
obs = inital_obs
for k, keypoint in enumerate(episode_keypoints):
@@ -170,6 +221,19 @@ def fill_replay(replay: ReplayBuffer,
bounds_offset: List[float],
rotation_resolution: int,
crop_augmentation: bool):
'''
load the demo trajectories and add them to the replay buffer

Input:
- replay: replay buffer
- task (str): unique id of the RLBench task
- env: RLBench environment
- num_demos: number of demo trajectories to load
- demo_augmentation: whether to also start sub-trajectories from
  intermediate (non-keypoint) frames
- cameras (list): list of camera names
- rlbench_scene_bounds: metric bounds of the RLBench workspace
- bounds_offset: offset used to re-centre the bounds at the finer depths
- rotation_resolution: degrees per rotation bin
- crop_augmentation: whether to apply random crop augmentation
'''


logging.info('Filling replay with demos...')
for d_idx in range(num_demos):
32 changes: 32 additions & 0 deletions arm/c2farm/networks.py
@@ -18,6 +18,24 @@ def __init__(self,
activation: str = 'relu',
dense_feats: int = 32,
include_prev_layer = False,):
'''
q-network over a volume of voxels

Input:
- in_channels (int): number of feature channels per input voxel
- out_channels (int): number of output channels per voxel (translation q-values)
- out_dense (int): size of the dense output head (rotation and gripper q-values)
- voxel_size (int): number of voxels along each axis of the grid
- low_dim_size (int): size of the low-dimensional proprioceptive input
- kernels (int): base number of convolution kernels
- norm (str): normalisation to use, default None
- activation (str): activation function, default 'relu'
- dense_feats (int): width of the dense layers, default 32
- include_prev_layer (bool): whether the previous layer's voxel grid is an
  additional input, default False
'''

super(Qattention3DNet, self).__init__()
self._in_channels = in_channels
self._out_channels = out_channels
@@ -127,6 +145,20 @@ def build(self):
self._dense_feats, self._out_dense, None, None)

def forward(self, ins, proprio, prev_layer_voxel_grid):
'''
apply a stack of 3D convolutions to compute a q-value for each voxel

Input:
- ins (torch.Tensor): the input voxel grid with shape (batch_size,
  voxel_feat, voxel_size, voxel_size, voxel_size)
- proprio (torch.Tensor): proprioceptive data of the robot arm with shape
  (batch_size, low_dim_size), e.g. (1, 3) or (128, 3)
- prev_layer_voxel_grid: voxel grid from the previous (coarser) layer, if used

Output:
- trans: per-voxel translation q-values
- rot_and_grip_out: rotation and gripper q-values from the dense head
'''

b, _, d, h, w = ins.shape
x = self._input_preprocess(ins)

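To make the network's interface concrete before moving on to the agent code, a hedged usage sketch based only on the signatures visible in this diff; it assumes the ARM repo is importable, and every numeric value is an illustrative assumption rather than the configured one:

```python
import torch
from arm.c2farm.networks import Qattention3DNet

# assumed sizes: 10 voxel features, a 16^3 grid, a 3-dim proprio state, 5-degree bins
net = Qattention3DNet(in_channels=10, out_channels=1,
                      out_dense=(360 // 5) * 3 + 2,   # three Euler heads + gripper logits
                      voxel_size=16, low_dim_size=3, kernels=32)
net.build()

ins = torch.rand(2, 10, 16, 16, 16)       # (B, voxel_feat, D, H, W)
proprio = torch.rand(2, 3)                # (B, low_dim_size)
q_trans, rot_and_grip_q = net(ins, proprio, None)
print(q_trans.shape)                      # expected: torch.Size([2, 1, 16, 16, 16])
```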
170 changes: 170 additions & 0 deletions arm/c2farm/qattention_agent.py
@@ -36,12 +36,42 @@ def __init__(self,
self._qnet.build()

def _argmax_3d(self, tensor_orig):
'''
compute the (x, y, z) index of the voxel with the maximum q-value in the
voxel grid of each batch element (the index arithmetic assumes a cubic
grid, i.e. d == h == w)

Input:
- tensor_orig: the q-value voxel grid with shape (batch_size,
  channel_size, voxel_size, voxel_size, voxel_size)

Output:
- indices: index of the voxel with the highest q-value in the voxel grid,
  with shape (batch_size, 3)
'''

b, c, d, h, w = tensor_orig.shape # c will be one
idxs = tensor_orig.view(b, c, -1).argmax(-1)
indices = torch.cat([((idxs // h) // d), (idxs // h) % w, idxs % w], 1)
return indices
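A self-contained check of the index arithmetic above (sizes are made up; it relies on the cubic grid noted in the docstring):

```python
import torch

b, c, d, h, w = 2, 1, 4, 4, 4
q = torch.rand(b, c, d, h, w)
idxs = q.view(b, c, -1).argmax(-1)                      # flat index into d*h*w
indices = torch.cat([(idxs // h) // d, (idxs // h) % w, idxs % w], 1)
# the standard unravel for a cubic grid gives the same triple
expected = torch.cat([idxs // (h * w), (idxs // w) % h, idxs % w], 1)
assert torch.equal(indices, expected)
print(indices)                                          # one (depth, row, col) index per batch item
```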

def choose_highest_action(self, q_trans, q_rot_grip):
'''
choose the voxel with the highest q-value.
If `q_rot_grip` is `None`, `rot_and_grip_indicies` will be `None` as well.

Input:
- q_trans: the q-value voxel grid with shape (batch_size,
  channel_size, voxel_size, voxel_size, voxel_size)
- q_rot_grip: rotation and gripper q-values with shape
  (batch_size, 360 // rotation_resolution * 3 + 2)

Output:
- coords: position (index) in the voxel grid with the highest q-value,
  with shape (batch_size, 3)
- rot_and_grip_indicies: rotation bin index for each Euler axis (x, y, z)
  and whether the gripper is open (0 or 1), with shape (batch_size, 4)
'''

coords = self._argmax_3d(q_trans)
rot_and_grip_indicies = None
if q_rot_grip is not None:
@@ -58,6 +88,23 @@ def choose_highest_action(self, q_trans, q_rot_grip):

def forward(self, x, proprio, pcd,
bounds=None, latent=None):
'''
Input:
- x (list): list of [rgb, pcd] pairs, one per camera
- proprio: the low-dimensional state of the robot arm
- pcd: list of point clouds
- bounds: metric bounds of the voxelised workspace
- latent

Output:
- q_trans: the q-value voxel grid with shape (batch_size,
  channel_size, voxel_size, voxel_size, voxel_size)
- rot_and_grip_q: rotation and gripper q-values with shape
  (batch_size, 360 // rotation_resolution * 3 + 2)
- voxel_grid: the voxel grid built from the observation (RGB and
  point cloud) with shape (batch_size, voxel_feat,
  voxel_size, voxel_size, voxel_size)
'''

# x will be list of list (list of [rgb, pcd])
b = x[0][0].shape[0]
pcd_flat = torch.cat(
@@ -136,6 +183,14 @@ def __init__(self,
self._name = NAME + '_layer' + str(self._layer)

def build(self, training: bool, device: torch.device = None):
'''
build the network

Input:
- training: whether the network is built in training mode
- device: the device (GPU/CPU) used by the agent
'''

if device is None:
device = torch.device('cpu')

@@ -184,6 +239,17 @@ def build(self, training: bool, device: torch.device = None):
self._device = device

def _extract_crop(self, pixel_action, observation):
'''
crop a patch out of the observation, using `pixel_action` as the anchor

Input:
- pixel_action: (batch_size, 1, 2)
- observation: (batch_size, 1, 3, img_h, img_w)

Output:
- crop: the image patch cropped around the anchor pixel
'''

# Pixel action will now be (B, 2)
observation = stack_on_channel(observation)
h = observation.shape[-1]
@@ -200,6 +266,22 @@ def _preprocess_inputs(self, replay_sample):
return crop
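The idea, as a hedged standalone sketch using plain slicing and a made-up crop size; the repo's implementation may differ:

```python
import torch

def crop_around_pixel(image, pixel_yx, crop_size=16):
    # image: (C, H, W); pixel_yx: (2,) integer tensor giving the anchor pixel
    _, H, W = image.shape
    half = crop_size // 2
    top = int(pixel_yx[0].clamp(half, H - half)) - half     # clamp so the window stays inside
    left = int(pixel_yx[1].clamp(half, W - half)) - half
    return image[:, top:top + crop_size, left:left + crop_size]

img = torch.rand(3, 128, 128)
crop = crop_around_pixel(img, torch.tensor([5, 120]))
print(crop.shape)                                           # torch.Size([3, 16, 16])
```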

def _preprocess_inputs(self, replay_sample):
'''
pack the network inputs from a replay sample

If the layer > 0, the RGB / point-cloud images are first cropped around
the stored pixel coordinate (see `_extract_crop`).

Input:
- replay_sample: the sampled transitions from the replay buffer

Output:
- obs: rgb image and point cloud per camera
- obs_tp1: rgb image and point cloud for the next timestep
- pcds: point clouds
- pcds_tp1: point clouds for the next timestep
'''


obs, obs_tp1 = [], []
pcds, pcds_tp1 = [], []
self._crop_summary, self._crop_summary_tp1 = [], []
@@ -229,6 +311,28 @@ def _act_preprocess_inputs(self, observation):
return obs, obs_tp1, pcds, pcds_tp1
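A hedged illustration of the `[[rgb, pcd], ...]` packing; the key names follow the `<camera>_rgb` / `<camera>_point_cloud` pattern seen elsewhere in this diff, and the timestep handling here is a simplification:

```python
import torch

replay_sample = {
    'front_rgb': torch.rand(2, 1, 3, 128, 128),            # (B, T, C, H, W)
    'front_point_cloud': torch.rand(2, 1, 3, 128, 128),
}
obs, pcds = [], []
for n in ['front']:
    rgb = replay_sample['%s_rgb' % n][:, -1]                # take the last stacked timestep
    pcd = replay_sample['%s_point_cloud' % n][:, -1]
    obs.append([rgb, pcd])
    pcds.append(pcd)
print(len(obs), obs[0][0].shape)                            # 1 torch.Size([2, 3, 128, 128])
```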

def _act_preprocess_inputs(self, observation):
'''
pack the observation for each camera

Input:
- observation:
    - front_rgb:
    - front_point_cloud:
    - low_dim_state:
    - front_camera_extrinsics:
    - front_camera_intrinsics:

  If the layer > 0, the observation additionally contains:

    - attention_coordinate:
    - prev_layer_voxel_grid:
    - front_pixel_coord:

Output:
- obs (list): [[rgb, pcd], [rgb, pcd], ...], one entry per camera
- pcds: list of point clouds (each with shape (batch_size, 3, h, w))
'''

obs, pcds = [], []
for n in self._camera_names:
if self._layer > 0 and 'wrist' not in n:
@@ -243,6 +347,19 @@ def _get_value_from_voxel_index(self, q, voxel_idx):
return obs, pcds

def _get_value_from_voxel_index(self, q, voxel_idx):
'''
extract the q-value at the given voxel index from the q-value voxel grid

Input:
- q: a voxel grid of q-values with shape (batch_size, channel_size,
  voxel_size, voxel_size, voxel_size)
- voxel_idx: (batch_size, 3)

Output:
- chosen_voxel_values: (batch_size, channel_size)
'''


b, c, d, h, w = q.shape
q_flat = q.view(b, c, d * h * w)
flat_indicies = (voxel_idx[:, 0] * d * h + voxel_idx[:, 1] * h + voxel_idx[:, 2])[:, None].long()
@@ -251,6 +368,17 @@ def _get_value_from_rot_and_grip(self, rot_grip_q, rot_and_grip_idx):
return chosen_voxel_values
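A self-contained illustration of the flat-index gather (sizes are made up; as with `_argmax_3d`, the stride arithmetic assumes a cubic grid):

```python
import torch

b, c, d, h, w = 2, 1, 4, 4, 4
q = torch.rand(b, c, d, h, w)
voxel_idx = torch.tensor([[1, 2, 3], [0, 0, 1]])                            # (B, 3)
q_flat = q.view(b, c, d * h * w)
flat = (voxel_idx[:, 0] * d * h + voxel_idx[:, 1] * h + voxel_idx[:, 2]).long()
chosen = q_flat.gather(2, flat.view(b, 1, 1).expand(b, c, 1)).squeeze(-1)   # (B, C)
assert torch.isclose(chosen[0, 0], q[0, 0, 1, 2, 3])
print(chosen.shape)                                                         # torch.Size([2, 1])
```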

def _get_value_from_rot_and_grip(self, rot_grip_q, rot_and_grip_idx):
'''
extract the q-values at the given rotation / gripper indices from the
packed rotation-and-gripper head

Input:
- rot_grip_q: (batch_size, 360 // rotation_resolution * 3 + 2)
- rot_and_grip_idx: (batch_size, 4)

Output:
- rot_and_grip_values: (batch_size, 4)
'''

q_rot = torch.stack(torch.split(
rot_grip_q[:, :-2], int(360 // self._rotation_resolution),
dim=1), dim=1) # B, 3, 72
@@ -263,6 +391,16 @@ def _get_value_from_rot_and_grip(self, rot_grip_q, rot_and_grip_idx):
return rot_and_grip_values
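An illustrative, standalone decode of the packed head; the 5-degree resolution and all index values are assumptions:

```python
import torch

B, res = 2, 5
bins = 360 // res                                          # 72 bins per Euler axis
rot_grip_q = torch.rand(B, 3 * bins + 2)                   # [x-bins | y-bins | z-bins | grip(2)]
rot_and_grip_idx = torch.tensor([[10, 20, 30, 1], [0, 71, 5, 0]])                 # (B, 4)

q_rot = torch.stack(torch.split(rot_grip_q[:, :-2], bins, dim=1), dim=1)          # (B, 3, 72)
rot_values = q_rot.gather(2, rot_and_grip_idx[:, :3].unsqueeze(-1)).squeeze(-1)   # (B, 3)
grip_values = rot_grip_q[:, -2:].gather(1, rot_and_grip_idx[:, 3:4])              # (B, 1)
rot_and_grip_values = torch.cat([rot_values, grip_values], dim=1)                 # (B, 4)
print(rot_and_grip_values.shape)                           # torch.Size([2, 4])
```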

def update(self, step: int, replay_sample: dict) -> dict:
'''
update the policy parameters from a batch of sampled transitions

NOTE: 'tp1' means next state

Input:
- step: current training step
- replay_sample (dict): the sampled transitions
'''


action_trans = replay_sample['trans_action_indicies'][:, -1,
self._layer * 3:self._layer * 3 + 3]
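How the per-layer translation indices are sliced out of the stacked action indices; the shapes below are illustrative and assume two depths:

```python
import torch

B, T, num_layers = 4, 1, 2
trans_action_indicies = torch.randint(0, 16, (B, T, num_layers * 3))   # stacked over depths
layer = 1
action_trans = trans_action_indicies[:, -1, layer * 3:layer * 3 + 3]   # (B, 3) for this layer
print(action_trans.shape)                                              # torch.Size([4, 3])
```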
@@ -374,6 +512,38 @@ def act(self, step: int, observation: dict,
def act(self, step: int, observation: dict,
deterministic=False) -> ActResult:
deterministic = True # TODO: Don't explicitly explore.
'''
take the observation as input and return the action

Input:
- step: unused here (kept for interface compatibility)
- observation (dict):
    - front_rgb: rgb image of the front camera
    - front_point_cloud: point cloud of the front camera
    - low_dim_state: the state of the robot arm
    - front_camera_extrinsics
    - front_camera_intrinsics

  (when layer >= 1, the following is also included)

    - attention_coordinate
    - prev_layer_voxel_grid
    - front_pixel_coord

Output:
- act_result:
    - action (tuple): contains `coords` and `rot_grip_action`
        - coords: the gripper position as a voxel grid index
        - rot_grip_action: rotation bin indices and gripper open/close
    - observation_elements (dict):
        - attention_coordinate: workspace coordinate of the chosen voxel
        - prev_layer_voxel_grid: the voxel grid passed on to the next layer
    - info (dict):
        - voxel_grid_depth:
        - q_depth:
        - voxel_idx_depth:
'''


bounds = self._coordinate_bounds

if self._layer > 0:
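As a closing note on this file, a hedged sketch of how a chosen voxel index maps back to the workspace `attention_coordinate` mentioned above; the bounds and resolution are made-up values and the repo's exact formula may differ:

```python
import torch

bounds = torch.tensor([[-0.3, -0.5, 0.6, 0.7, 0.5, 1.6]])   # (1, 6): xyz min then xyz max
voxel_size = 16
coords = torch.tensor([[4, 10, 7]])                          # chosen voxel index, (B, 3)

res = (bounds[:, 3:] - bounds[:, :3]) / voxel_size           # metres per voxel along each axis
attention_coordinate = bounds[:, :3] + res * coords.float() + res / 2
print(attention_coordinate)                                  # centre of the chosen voxel
```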
8 changes: 8 additions & 0 deletions arm/c2farm/qattention_stack_agent.py
@@ -30,6 +30,14 @@ def __init__(self,
self._rotation_prediction_depth = rotation_prediction_depth

def build(self, training: bool, device=None) -> None:
'''
iteratively build the network of each per-depth agent

Input:
- training: whether the agents are built in training mode
- device: the device (GPU/CPU) used by the agents
'''

self._device = device
if self._device is None:
self._device = torch.device('cpu')
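A minimal sketch of the "build each sub-agent" delegation described above; the classes here are stand-ins, not the repo's QAttentionAgent:

```python
import torch

class _FakeQAttentionAgent:
    def __init__(self, layer):
        self.layer = layer
    def build(self, training: bool, device: torch.device):
        self.device = device            # the real agent would create its q-network here

qattention_agents = [_FakeQAttentionAgent(i) for i in range(2)]   # one per depth (assumed: 2)
device = torch.device('cpu')
for agent in qattention_agents:
    agent.build(training=True, device=device)
print([(a.layer, a.device) for a in qattention_agents])
```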