Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve model handling speed #31

Merged
merged 11 commits into from
May 17, 2024
6 changes: 3 additions & 3 deletions src/rod/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,9 +115,9 @@ def check_compatible_sdformat(specification_version: str) -> None:

if not GazeboHelper.has_gazebo():
return
else:
cmdline = GazeboHelper.get_gazebo_executable()
logging.info(f"Calling sdformat through '{cmdline} sdf'")

cmdline = GazeboHelper.get_gazebo_executable()
logging.info(f"Calling sdformat through '{cmdline} sdf'")

output_sdf_version = packaging.version.Version(
xmltodict.parse(
Expand Down
82 changes: 42 additions & 40 deletions src/rod/builder/primitive_builder.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from __future__ import annotations

import abc
import dataclasses
from typing import Optional, Union
from typing import Optional

import numpy as np
import numpy.typing as npt
Expand All @@ -14,15 +16,15 @@ class PrimitiveBuilder(abc.ABC):
name: str
mass: float

element: Union[rod.Model, rod.Link, rod.Inertial, rod.Collision, rod.Visual] = (
element: rod.Model | rod.Link | rod.Inertial | rod.Collision | rod.Visual = (
dataclasses.field(
default=None, init=False, repr=False, hash=False, compare=False
)
)

def build(
self,
) -> Union[rod.Model, rod.Link, rod.Inertial, rod.Collision, rod.Visual]:
) -> rod.Model | rod.Link | rod.Inertial | rod.Collision | rod.Visual:
return self.element

# ================
Expand All @@ -43,9 +45,9 @@ def _geometry(self) -> rod.Geometry:

def build_model(
self,
name: Optional[str] = None,
pose: Optional[rod.Pose] = None,
) -> "PrimitiveBuilder":
name: str | None = None,
pose: rod.Pose | None = None,
) -> PrimitiveBuilder:
self._check_element()

self.element = self._model(name=name, pose=pose)
Expand All @@ -54,16 +56,16 @@ def build_model(

def build_link(
self,
name: Optional[str] = None,
pose: Optional[rod.Pose] = None,
) -> "PrimitiveBuilder":
name: str | None = None,
pose: rod.Pose | None = None,
) -> PrimitiveBuilder:
self._check_element()

self.element = self._link(name=name, pose=pose)

return self

def build_inertial(self, pose: Optional[rod.Pose] = None) -> "PrimitiveBuilder":
def build_inertial(self, pose: rod.Pose | None = None) -> PrimitiveBuilder:
self._check_element()

self.element = self._inertial(pose=pose)
Expand All @@ -72,9 +74,9 @@ def build_inertial(self, pose: Optional[rod.Pose] = None) -> "PrimitiveBuilder":

def build_visual(
self,
name: Optional[str] = None,
pose: Optional[rod.Pose] = None,
) -> "PrimitiveBuilder":
name: str | None = None,
pose: rod.Pose | None = None,
) -> PrimitiveBuilder:
self._check_element()

self.element = self._visual(name=name, pose=pose)
Expand All @@ -83,9 +85,9 @@ def build_visual(

def build_collision(
self,
name: Optional[str] = None,
pose: Optional[rod.Pose] = None,
) -> "PrimitiveBuilder":
name: str | None = None,
pose: rod.Pose | None = None,
) -> PrimitiveBuilder:
self._check_element()

self.element = self._collision(name=name, pose=pose)
Expand All @@ -98,10 +100,10 @@ def build_collision(

def add_link(
self,
name: Optional[str] = None,
pose: Optional[rod.Pose] = None,
link: Optional[rod.Link] = None,
) -> "PrimitiveBuilder":
name: str | None = None,
pose: rod.Pose | None = None,
link: rod.Link | None = None,
) -> PrimitiveBuilder:
if not isinstance(self.element, rod.Model):
raise ValueError(type(self.element))

Expand All @@ -116,9 +118,9 @@ def add_link(

def add_inertial(
self,
pose: Optional[rod.Pose] = None,
inertial: Optional[rod.Inertial] = None,
) -> "PrimitiveBuilder":
pose: rod.Pose | None = None,
inertial: rod.Inertial | None = None,
) -> PrimitiveBuilder:
if not isinstance(self.element, (rod.Model, rod.Link)):
raise ValueError(type(self.element))

Expand All @@ -144,11 +146,11 @@ def add_inertial(

def add_visual(
self,
name: Optional[str] = None,
name: str | None = None,
use_inertial_pose: bool = True,
pose: Optional[rod.Pose] = None,
visual: Optional[rod.Visual] = None,
) -> "PrimitiveBuilder":
pose: rod.Pose | None = None,
visual: rod.Visual | None = None,
) -> PrimitiveBuilder:
if not isinstance(self.element, (rod.Model, rod.Link)):
raise ValueError(type(self.element))

Expand Down Expand Up @@ -180,11 +182,11 @@ def add_visual(

def add_collision(
self,
name: Optional[str] = None,
name: str | None = None,
use_inertial_pose: bool = True,
pose: Optional[rod.Pose] = None,
collision: Optional[rod.Collision] = None,
) -> "PrimitiveBuilder":
pose: rod.Pose | None = None,
collision: rod.Collision | None = None,
) -> PrimitiveBuilder:
if not isinstance(self.element, (rod.Model, rod.Link)):
raise ValueError(type(self.element))

Expand Down Expand Up @@ -224,8 +226,8 @@ def add_collision(

def _model(
self,
name: Optional[str] = None,
pose: Optional[rod.Pose] = None,
name: str | None = None,
pose: rod.Pose | None = None,
) -> rod.Model:
name = name if name is not None else self.name
logging.debug(f"Building model '{name}'")
Expand All @@ -240,15 +242,15 @@ def _model(

def _link(
self,
name: Optional[str] = None,
pose: Optional[rod.Pose] = None,
name: str | None = None,
pose: rod.Pose | None = None,
) -> rod.Link:
return rod.Link(
name=name if name is not None else f"{self.name}_link",
pose=pose,
)

def _inertial(self, pose: Optional[rod.Pose] = None) -> rod.Inertial:
def _inertial(self, pose: rod.Pose | None = None) -> rod.Inertial:
return rod.Inertial(
pose=pose,
mass=self.mass,
Expand All @@ -257,8 +259,8 @@ def _inertial(self, pose: Optional[rod.Pose] = None) -> rod.Inertial:

def _visual(
self,
name: Optional[str] = None,
pose: Optional[rod.Pose] = None,
name: str | None = None,
pose: rod.Pose | None = None,
) -> rod.Visual:
name = name if name is not None else f"{self.name}_visual"

Expand All @@ -271,7 +273,7 @@ def _visual(
def _collision(
self,
name: Optional[str],
pose: Optional[rod.Pose] = None,
pose: rod.Pose | None = None,
) -> rod.Collision:
name = name if name is not None else f"{self.name}_collision"

Expand All @@ -297,7 +299,7 @@ def build_pose(
relative_to: str = None,
degrees: bool = None,
rotation_format: str = None,
) -> Optional[rod.Pose]:
) -> rod.Pose | None:
if pos is None and rpy is None:
return rod.Pose.from_transform(transform=np.eye(4), relative_to=relative_to)

Expand Down
3 changes: 1 addition & 2 deletions src/rod/builder/primitives.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import dataclasses
import pathlib
from typing import Union

import trimesh
from numpy.typing import NDArray
Expand Down Expand Up @@ -63,7 +62,7 @@ def _geometry(self) -> rod.Geometry:

@dataclasses.dataclass
class MeshBuilder(PrimitiveBuilder):
mesh_path: Union[str, pathlib.Path]
mesh_path: str | pathlib.Path
scale: NDArray

def __post_init__(self) -> None:
Expand Down
10 changes: 6 additions & 4 deletions src/rod/kinematics/kinematic_tree.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
from __future__ import annotations

import copy
import dataclasses
import functools
from typing import Dict, List, Sequence, Tuple, Union
from typing import Dict, List, Sequence, Tuple

import numpy as np

Expand All @@ -12,7 +14,7 @@

@dataclasses.dataclass(frozen=True)
class KinematicTree(DirectedTree):
model: "rod.Model"
model: rod.Model

joints: List[TreeEdge] = dataclasses.field(default_factory=list)
frames: List[TreeFrame] = dataclasses.field(default_factory=list)
Expand Down Expand Up @@ -46,7 +48,7 @@ def joint_names(self) -> List[str]:
return [joint.name() for joint in self.joints]

@staticmethod
def build(model: "rod.Model", is_top_level: bool = True) -> "KinematicTree":
def build(model: rod.Model, is_top_level: bool = True) -> KinematicTree:
logging.debug(msg=f"Building kinematic tree of model '{model.name}'")

if model.model is not None:
Expand Down Expand Up @@ -199,7 +201,7 @@ def build(model: "rod.Model", is_top_level: bool = True) -> "KinematicTree":
new_base_node, additional_frames = KinematicTree.remove_edge(
edge=world_to_base_edge, keep_parent=False
)
assert any([f.name() == TreeFrame.WORLD for f in additional_frames])
assert any(f.name() == TreeFrame.WORLD for f in additional_frames)

# Replace the former base node with the new base node
nodes_links_dict[new_base_node.name()] = new_base_node
Expand Down
Loading
Loading