Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Forward Dynamics computation via ABA #51

Draft
wants to merge 31 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
f1c4fa6
Add first version of ABA implementation
flferretti Nov 30, 2023
e24ddb7
Update tests
flferretti Dec 1, 2023
6981667
Update variables name in ABA
flferretti Dec 1, 2023
3e79a86
Add `forward_dynamics` in computations
flferretti Dec 1, 2023
1259756
Update tests
flferretti Dec 1, 2023
eb9e6f6
Fix `ABA` in core.
flferretti Dec 1, 2023
111c83e
Fix computations
flferretti Dec 1, 2023
3088164
Add abstractions for solving linear systems
flferretti Dec 4, 2023
ab5023b
Add `floating_base` field
flferretti Dec 5, 2023
e148da1
Overwrite `__eq__` dunder
flferretti Dec 5, 2023
3c2de13
Fix typos and finalize ABA
flferretti Dec 5, 2023
80c7757
Add base velocity as input
flferretti Jan 4, 2024
8568e17
Reduce model tree
flferretti Jan 5, 2024
241750f
Use reduce model in ABA
flferretti Jan 12, 2024
eda6cd3
Add inverse function in spatial math
flferretti Jan 17, 2024
de550d9
Update tree transform and parent array calculation
flferretti Jan 17, 2024
e4a2c4a
Return SpatialMath object in ABA
flferretti Jan 17, 2024
a9d1b16
Update tests
flferretti Jan 17, 2024
874c16c
Update tree reduction
flferretti Jan 18, 2024
01533b6
Update model reduction logic
flferretti Jan 18, 2024
e0c75ff
Merge branch 'ami-iit:main' into aba
flferretti Jan 31, 2024
e053633
Add `vee` operator
flferretti Jan 19, 2024
e7d2f3e
Make StdJoint and StdLink hashable
flferretti Feb 20, 2024
71995ac
Add method to get the joint list
flferretti Feb 20, 2024
431687e
Tree merging logic and update type hints
Giulero Mar 6, 2024
6a983c6
Refactor model.py to use type hints and improve code readability
Giulero Mar 6, 2024
a1f4ba2
Fixing joint update
Giulero Mar 7, 2024
6191da9
Refactor chain calculation in Model class
Giulero Mar 7, 2024
924da1f
Cleanup commits
Giulero Mar 7, 2024
6c1b74d
Commented out code for lumping links
Giulero Mar 8, 2024
72a24b1
Fix target check in Model class
Giulero Mar 8, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 51 additions & 0 deletions src/adam/casadi/casadi_like.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,12 @@ def __getitem__(self, idx) -> "CasadiLike":
"""Overrides get item operator"""
return CasadiLike(self.array[idx])

def __eq__(self, other: Union["CasadiLike", npt.ArrayLike]) -> bool:
"""Overrides == operator"""
if type(self) is not type(other):
return self.array == other
return self.array == other.array

@property
def T(self) -> "CasadiLike":
"""
Expand Down Expand Up @@ -149,6 +155,39 @@ def skew(x: Union["CasadiLike", npt.ArrayLike]) -> "CasadiLike":
else:
return CasadiLike(cs.skew(x))

@staticmethod
def vee(x: Union["CasadiLike", npt.ArrayLike]) -> "CasadiLike":
"""
Args:
x (Union[CasadiLike, npt.ArrayLike]): 3x3 matrix

Returns:
CasadiLike: the vector from the skew symmetric matrix x
"""
vee = lambda x: cs.vertcat(x[2, 1], x[0, 2], x[1, 0])
if isinstance(x, CasadiLike):
return CasadiLike(vee(x.array))
else:
return CasadiLike(vee(x))

@staticmethod
def inv(x: Union["CasadiLike", npt.ArrayLike]) -> "CasadiLike":
"""
Args:
x (Union[CasadiLike, npt.ArrayLike]): matrix

Returns:
CasadiLike: inverse of x
"""
if isinstance(x, CasadiLike):
return CasadiLike(cs.inv(x.array))
else:
return (
CasadiLike(cs.inv(x))
if x.size1() <= 5
else CasadiLike(self.solve(x, self.eye(x.size1())))
)

@staticmethod
def sin(x: npt.ArrayLike) -> "CasadiLike":
"""
Expand Down Expand Up @@ -206,6 +245,18 @@ def horzcat(*x) -> "CasadiLike":
y = [xi.array if isinstance(xi, CasadiLike) else xi for xi in x]
return CasadiLike(cs.horzcat(*y))

@staticmethod
def solve(A: "CasadiLike", b: "CasadiLike") -> "CasadiLike":
"""
Args:
A (CasadiLike): matrix
b (CasadiLike): vector

Returns:
CasadiLike: solution of A*x=b
"""
return CasadiLike(cs.solve(A.array, b.array))


if __name__ == "__main__":
math = SpatialMath()
Expand Down
30 changes: 30 additions & 0 deletions src/adam/casadi/computations.py
Original file line number Diff line number Diff line change
Expand Up @@ -411,3 +411,33 @@ def CoM_position(
CoM (Union[cs.SX, cs.DM]): The CoM position
"""
return self.rbdalgos.CoM_position(base_transform, joint_positions).array

def forward_dynamics(
self,
base_transform: Union[cs.SX, cs.DM],
base_velocity: Union[cs.SX, cs.DM],
joint_positions: Union[cs.SX, cs.DM],
joint_velocities: Union[cs.SX, cs.DM],
joint_torques: Union[cs.SX, cs.DM],
) -> Union[cs.SX, cs.DM]:
"""Returns base and joints accelerations of the floating-base dynamics equation

Args:
base_transform (Union[cs.SX, cs.DM]): The homogenous transform from base to world frame
base_velocity (Union[cs.SX, cs.DM]): The base velocity in mixed representation
joint_positions (Union[cs.SX, cs.DM]): The joints position
joint_velocities (Union[cs.SX, cs.DM]): The joints velocity
joint_torques (Union[cs.SX, cs.DM]): The joints torque

Returns:
base_acceleration (Union[cs.SX, cs.DM]): The base acceleration in mixed representation
joint_accelerations (Union[cs.SX, cs.DM]): The joints acceleration
"""
return self.rbdalgos.aba(
base_transform,
base_velocity,
joint_positions,
joint_velocities,
joint_torques,
self.g,
)
145 changes: 143 additions & 2 deletions src/adam/core/rbd_algorithms.py
Original file line number Diff line number Diff line change
Expand Up @@ -457,5 +457,146 @@ def rnea(
tau[:6] = B_X_BI.T @ tau[:6]
return tau

def aba(self):
raise NotImplementedError
def aba(
self,
base_transform: npt.ArrayLike,
base_velocity: npt.ArrayLike,
joint_positions: npt.ArrayLike,
joint_velocities: npt.ArrayLike,
tau: npt.ArrayLike,
g: npt.ArrayLike,
) -> npt.ArrayLike:
"""Implementation of Articulated Body Algorithm

Args:
base_transform (T): The homogenous transform from base to world frame
base_velocity (T): The base velocity in mixed representation
joint_positions (T): The joints position
joint_velocities (T): The joints velocity
tau (T): The generalized force variables
g (T): The 6D gravity acceleration

Returns:
base_acceleration (T): The base acceleration in mixed representation
joint_accelerations (T): The joints acceleration
"""
model = self.model.reduce(self.model.actuated_joints)
joints = list(model.joints.values())

NB = model.N

i_X_pi = self.math.factory.zeros(NB, 6, 6)
v = self.math.factory.zeros(NB, 6, 1)
c = self.math.factory.zeros(NB, 6, 1)
pA = self.math.factory.zeros(NB, 6, 1)
IA = self.math.factory.zeros(NB, 6, 6)
U = self.math.factory.zeros(NB, 6, 1)
D = self.math.factory.zeros(NB, 1, 1)
u = self.math.factory.zeros(NB, 1, 1)

a = self.math.factory.zeros(NB, 6, 1)
sdd = self.math.factory.zeros(NB, 1, 1)
B_X_W = self.math.adjoint_mixed(base_transform)

if model.floating_base:
IA[0] = model.tree.get_node_from_name(self.root_link).link.spatial_inertia()
v[0] = B_X_W @ base_velocity
pA[0] = (
self.math.spatial_skew_star(v[0]) @ IA[0] @ v[0]
) # - self.math.adjoint_inverse(B_X_W).T @ f_ext[0]

def get_tree_transform(self, joints) -> "Array":
"""returns the tree transform

Returns:
Array: the tree transform
"""
relative_transform = lambda j: self.math.inv(
model.tree.graph[j.child].parent_arc.spatial_transform(0)
) @ j.spatial_transform(0)

return self.math.vertcat(
[self.math.factory.eye(6).array]
+ list(
map(
lambda j: relative_transform(j).array
if j.parent != self.root_link
else self.math.factory.eye(6).array,
joints,
)
)
)

tree_transform = get_tree_transform(self, joints)
p = lambda i: list(model.tree.graph).index(joints[i].parent)

# Pass 1
for i, joint in enumerate(joints[1:], start=1):
q = joint_positions[i]
q_dot = joint_velocities[i]

# Parent-child transform
i_X_pi[i] = joint.spatial_transform(q) @ tree_transform[i]
v_J = joint.motion_subspace() * q_dot

v[i] = i_X_pi[i] @ v[p(i)] + v_J
c[i] = i_X_pi[i] @ c[p(i)] + self.math.spatial_skew(v[i]) @ v_J

IA[i] = model.tree.get_node_from_name(joint.parent).link.spatial_inertia()

pA[i] = IA[i] @ c[i] + self.math.spatial_skew_star(v[i]) @ IA[i] @ v[i]

# Pass 2
for i, joint in reversed(
list(
enumerate(
joints[1:],
start=1,
)
)
):
U[i] = IA[i] @ joint.motion_subspace()
D[i] = joint.motion_subspace().T @ U[i]
u[i] = (
self.math.vertcat(tau[joint.idx]) - joint.motion_subspace().T @ pA[i]
if joint.idx is not None
else 0.0
)

Ia = IA[i] - U[i] / D[i] @ U[i].T
pa = pA[i] + Ia @ c[i] + U[i] * u[i] / D[i]

if joint.parent != self.root_link or not model.floating_base:
IA[p(i)] += i_X_pi[i].T @ Ia @ i_X_pi[i]
pA[p(i)] += i_X_pi[i].T @ pa
continue

a[0] = B_X_W @ g if model.floating_base else self.math.solve(-IA[0], pA[0])

# Pass 3
for i, joint in enumerate(joints[1:], start=1):
if joint.parent == self.root_link:
continue

sdd[i - 1] = (u[i] - U[i].T @ a[i]) / D[i]

a[i] += i_X_pi[i].T @ a[p(i)] + joint.motion_subspace() * sdd[i - 1] + c[i]

# Squeeze sdd
s_ddot = self.math.vertcat(*[sdd[i] for i in range(sdd.shape[0])])

if (
self.frame_velocity_representation
== Representations.BODY_FIXED_REPRESENTATION
):
return self.math.horzcat(a[0], s_ddot)

elif self.frame_velocity_representation == Representations.MIXED_REPRESENTATION:
return self.math.horzcat(
self.math.vertcat(
self.math.solve(B_X_W, a[0]) + g
if model.floating_base
else self.math.zeros(6, 1),
),
s_ddot,
)
20 changes: 20 additions & 0 deletions src/adam/core/spatial_math.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,26 @@ def cos(x: npt.ArrayLike) -> npt.ArrayLike:
def skew(x):
pass

@abc.abstractmethod
def vee(x):
pass

@abc.abstractmethod
def inv(x):
pass

@abc.abstractmethod
def solve(A: npt.ArrayLike, b: npt.ArrayLike) -> npt.ArrayLike:
"""
Args:
A (npt.ArrayLike): matrix
b (npt.ArrayLike): vector

Returns:
npt.ArrayLike: solution of the linear system Ax=b
"""
pass

def R_from_axis_angle(self, axis: npt.ArrayLike, q: npt.ArrayLike) -> npt.ArrayLike:
"""
Args:
Expand Down
32 changes: 32 additions & 0 deletions src/adam/jax/computations.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,3 +239,35 @@ def get_total_mass(self) -> float:
mass: The total mass
"""
return self.rbdalgos.get_total_mass()

def forward_dynamics(
self,
base_transform: jnp.array,
base_velocity: jnp.array,
joint_positions: jnp.array,
joint_velocities: jnp.array,
joint_torques: jnp.array,
) -> jnp.array:
"""Returns base and joints accelerations of the floating-base dynamics equation

Args:
base_transform (jnp.array): The homogenous transform from base to world frame
base_velocity (jnp.array): The base velocity in mixed representation
joint_positions (jnp.array): The joints position
joint_velocities (jnp.array): The joints velocity
joint_torques (jnp.array): The joints torques

Returns:
base_acceleration (jnp.array): The base acceleration in mixed representation
joint_accelerations (jnp.array): The joints acceleration
"""
base_acceleration, joint_accelerations = self.rbdalgos.aba(
base_transform,
base_velocity,
joint_positions,
joint_velocities,
joint_torques,
self.g,
)

return base_acceleration.array.squeeze(), joint_accelerations.array.squeeze()
43 changes: 43 additions & 0 deletions src/adam/jax/jax_like.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,12 @@ def __neg__(self) -> "JaxLike":
"""Overrides - operator"""
return JaxLike(-self.array)

def __eq__(self, other: Union["JaxLike", npt.ArrayLike]) -> bool:
"""Overrides == operator"""
if type(self) is not type(other):
return self.array.squeeze() == other.squeeze()
return self.array.squeeze() == other.array.squeeze()


class JaxLikeFactory(ArrayLikeFactory):
@staticmethod
Expand Down Expand Up @@ -188,6 +194,31 @@ def skew(x: Union["JaxLike", npt.ArrayLike]) -> "JaxLike":
x = x.array
return JaxLike(-jnp.cross(jnp.array(x), jnp.eye(3), axisa=0, axisb=0))

@staticmethod
def vee(x: Union["JaxLike", npt.ArrayLike]) -> "JaxLike":
"""
Args:
x (Union[JaxLike, npt.ArrayLike]): matrix

Returns:
JaxLike: the vee operator from x
"""
if not isinstance(x, JaxLike):
return JaxLike(jnp.array([x[2, 1], x[0, 2], x[1, 0]]))
x = x.array
return JaxLike(jnp.array([x[2, 1], x[0, 2], x[1, 0]]))

@staticmethod
def inv(x: "JaxLike") -> "JaxLike":
"""
Args:
x (JaxLike): Matrix

Returns:
JaxLike: Inverse of x
"""
return JaxLike(jnp.linalg.inv(x.array))

@staticmethod
def vertcat(*x) -> "JaxLike":
"""
Expand All @@ -211,3 +242,15 @@ def horzcat(*x) -> "JaxLike":
else:
v = jnp.hstack([x[i] for i in range(len(x))])
return JaxLike(v)

@staticmethod
def solve(A: "JaxLike", b: "JaxLike") -> "JaxLike":
"""
Args:
A (JaxLike): Matrix
b (JaxLike): Vector

Returns:
JaxLike: Solution of Ax=b
"""
return JaxLike(jnp.linalg.solve(A.array, b.array))
Loading
Loading