Skip to content

Commit

Permalink
Fix typos and finalize ABA
Browse files Browse the repository at this point in the history
  • Loading branch information
flferretti committed Dec 12, 2023
1 parent e148da1 commit 3c2de13
Show file tree
Hide file tree
Showing 3 changed files with 55 additions and 17 deletions.
12 changes: 12 additions & 0 deletions src/adam/casadi/casadi_like.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,18 @@ def horzcat(*x) -> "CasadiLike":
y = [xi.array if isinstance(xi, CasadiLike) else xi for xi in x]
return CasadiLike(cs.horzcat(*y))

@staticmethod
def solve(A: "CasadiLike", b: "CasadiLike") -> "CasadiLike":
"""
Args:
A (CasadiLike): matrix
b (CasadiLike): vector
Returns:
CasadiLike: solution of A*x=b
"""
return CasadiLike(cs.solve(A.array, b.array))


if __name__ == "__main__":
math = SpatialMath()
Expand Down
54 changes: 39 additions & 15 deletions src/adam/core/rbd_algorithms.py
Original file line number Diff line number Diff line change
Expand Up @@ -462,7 +462,7 @@ def aba(
base_transform: npt.ArrayLike,
joint_positions: npt.ArrayLike,
joint_velocities: npt.ArrayLike,
joint_torques: npt.ArrayLike,
tau: npt.ArrayLike,
g: npt.ArrayLike,
) -> npt.ArrayLike:
"""Implementation of Articulated Body Algorithm
Expand All @@ -483,24 +483,27 @@ def aba(
c = self.math.factory.zeros(self.model.N, 6, 1)
pA = self.math.factory.zeros(self.model.N, 6, 1)
IA = self.math.factory.zeros(self.model.N, 6, 6)
U = self.math.factory.zeros(self.model.N, 6, 6)
D = self.math.factory.zeros(self.model.N, 6, 6)
u = self.math.factory.zeros(self.model.N, 6, 1)
U = self.math.factory.zeros(self.model.N, 6, 1)
D = self.math.factory.zeros(self.model.N, 1, 1)
u = self.math.factory.zeros(self.model.N, 1, 1)
a = self.math.factory.zeros(self.model.N, 6, 1)
f = self.math.factory.zeros(self.model.N, 6, 1)
sdd = self.math.factory.zeros(self.model.N, 1, 1)
B_X_W = self.math.adjoint_mixed_inverse(base_transform)

# Pass 1
for i, node in enumerate(self.model.tree):
link_i, joint_i, link_pi = node.get_elements()

if link_i.name == self.root_link:
continue
q = joint_positions[joint_i.idx] if joint_i.idx is not None else 0.0
q_dot = joint_velocities[joint_i.idx] if joint_i.idx is not None else 0.0

pi = self.model.tree.get_idx_from_name(link_pi.name)

# Parent-child transform
i_X_pi[i] = joint_i.spatial_transform(joint_positions[i])
v_J = joint_i.motion_subspace() * joint_velocities[i]
i_X_pi[i] = joint_i.spatial_transform(q)
v_J = joint_i.motion_subspace() * q_dot

v[i] = i_X_pi[i] @ v[pi] + v_J
c[i] = i_X_pi[i] @ c[pi] + self.math.spatial_skew(v[i]) @ v_J
Expand All @@ -519,26 +522,47 @@ def aba(
continue

pi = self.model.tree.get_idx_from_name(link_pi.name)
tau_i = tau[joint_i.idx] if joint_i.idx is not None else 0.0

U[i] = IA[i] @ node.joint.motion_subspace()
D[i] = node.joint.motion_subspace().T @ U[i]
u[i] = tau[i] - node.joint.motion_subspace().T @ pA[i]
U[i] = IA[i] @ joint_i.motion_subspace()
D[i] = joint_i.motion_subspace().T @ U[i]
u[i] = self.math.vertcat(tau_i) - joint_i.motion_subspace().T @ pA[i]

Ia = IA[i] - U[i] / D[i] @ U[i].T
pa = pA[i] + Ia @ c[i] + U[i] * u[i] / D[i]

a[0] = B_X_W @ g if self.model.floating_base else self.math.solve(-IA[0], pA[0])

# Pass 3
for i, node in enumerate(self.model.tree):
link_i, joint_i, link_pi = node.get_elements()

if link_i.name == self.root_link:
IA[pi] += i_X_pi[i].T @ Ia @ i_X_pi[i]
pA[pi] += i_X_pi[i].T @ pa
continue

pi = self.model.tree.get_idx_from_name(link_pi.name)

sdd = (u[i] - U[i].T @ a[i]) / D[i]
sdd[i - 1] = (u[i] - U[i].T @ a[i]) / D[i]

a[i] += i_X_pi[i].T @ a[pi] + joint_i.motion_subspace() * sdd[i - 1] + c[i]

# Filter sdd to remove NaNs generate with lumped joints
s_ddot = self.math.vertcat(
*[sdd[i] for i in range(self.model.N) if sdd[i] == sdd[i]]
)

a[i] = i_X_pi[i].T @ a[pi] + node.joint.motion_subspace() * sdd + c[i]
if (
self.frame_velocity_representation
== Representations.BODY_FIXED_REPRESENTATION
):
return a[0], s_ddot

return a, sdd
elif self.frame_velocity_representation == Representations.MIXED_REPRESENTATION:
return (
self.math.vertcat(
self.math.solve(B_X_W, a[0]) + g
if self.model.floating_base
else self.math.zeros(6, 1),
),
s_ddot,
)
6 changes: 4 additions & 2 deletions src/adam/jax/computations.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,10 +259,12 @@ def forward_dynamics(
base_acceleration (jnp.array): The base acceleration in mixed representation
joint_accelerations (jnp.array): The joints acceleration
"""
return self.rbdalgos.aba(
base_acceleration, joint_accelerations = self.rbdalgos.aba(
base_transform,
joint_positions,
joint_velocities,
joint_torques,
self.g,
).array.squeeze()
)

return base_acceleration.array.squeeze(), joint_accelerations.array.squeeze()

0 comments on commit 3c2de13

Please sign in to comment.