From 14f227fb833add8bccfb016771cc70e2ab9dd838 Mon Sep 17 00:00:00 2001 From: Carter Francis Date: Tue, 7 May 2024 11:40:45 -0500 Subject: [PATCH 1/3] BugFix: Allow get_kwargs for setting kwargs on __mul__ --- orix/quaternion/quaternion.py | 3 ++- orix/vector/vector3d.py | 5 +++++ 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/orix/quaternion/quaternion.py b/orix/quaternion/quaternion.py index e7fdf5e2..ed4d3be1 100644 --- a/orix/quaternion/quaternion.py +++ b/orix/quaternion/quaternion.py @@ -240,7 +240,8 @@ def __mul__(self, other: Union[Quaternion, Vector3d]): m.coordinate_format = other.coordinate_format return m else: - return other.__class__(v) + kwargs = other.get_kwargs(self) + return other.__class__(v, **kwargs) return NotImplemented def __neg__(self) -> Quaternion: diff --git a/orix/vector/vector3d.py b/orix/vector/vector3d.py index a565ba71..5e3fc3f6 100644 --- a/orix/vector/vector3d.py +++ b/orix/vector/vector3d.py @@ -731,6 +731,11 @@ def to_polar( polar = np.rad2deg(polar) return azimuth, polar, self.radial + def get_kwargs(self, *args, **kwargs) -> Dict[str, Any]: + """Return the dictionary of attributes to be used in the + constructor when applying an ufunc.""" + return {} + def in_fundamental_sector(self, symmetry: "Symmetry") -> Vector3d: """Project vectors to a symmetry's fundamental sector (inverse pole figure). From 77ef0daf9775f6eea44f1578db20173b772ee8ed Mon Sep 17 00:00:00 2001 From: Carter Francis Date: Tue, 7 May 2024 12:19:20 -0500 Subject: [PATCH 2/3] BugFix: Add get_kwargs to all dunder functions. --- orix/quaternion/quaternion.py | 4 +-- orix/vector/miller.py | 4 +++ orix/vector/vector3d.py | 50 ++++++++++++++++++++++------------- 3 files changed, 38 insertions(+), 20 deletions(-) diff --git a/orix/quaternion/quaternion.py b/orix/quaternion/quaternion.py index ed4d3be1..3061b6fb 100644 --- a/orix/quaternion/quaternion.py +++ b/orix/quaternion/quaternion.py @@ -240,8 +240,8 @@ def __mul__(self, other: Union[Quaternion, Vector3d]): m.coordinate_format = other.coordinate_format return m else: - kwargs = other.get_kwargs(self) - return other.__class__(v, **kwargs) + kwargs = other.get_kwargs(v, self) + return other.__class__(**kwargs) return NotImplemented def __neg__(self) -> Quaternion: diff --git a/orix/vector/miller.py b/orix/vector/miller.py index f8a573f8..bd9e5f8f 100644 --- a/orix/vector/miller.py +++ b/orix/vector/miller.py @@ -161,6 +161,10 @@ def hkl(self) -> np.ndarray: """ return _transform_space(self.data, "c", "r", self.phase.structure.lattice) + def get_kwargs(self, new_data, *args, **kwargs) -> dict: + """Return the keyword arguments to create the instance.""" + return dict(xyz=new_data) + @hkl.setter def hkl(self, value: np.ndarray): """Set the reciprocal lattice vectors.""" diff --git a/orix/vector/vector3d.py b/orix/vector/vector3d.py index 5e3fc3f6..e1bc04be 100644 --- a/orix/vector/vector3d.py +++ b/orix/vector/vector3d.py @@ -219,31 +219,36 @@ def polar(self) -> np.ndarray: # ------------------------ Dunder methods ------------------------ # def __neg__(self) -> Vector3d: - return self.__class__(-self.data) + kwargs = self.get_kwargs(new_data=-self.data) + return self.__class__(**kwargs) def __add__( self, other: Union[int, float, List, Tuple, np.ndarray, Vector3d] ) -> Vector3d: if isinstance(other, Vector3d): - return self.__class__(self.data + other.data) + kwargs = self.get_kwargs(new_data=self.data + other.data) + return self.__class__(**kwargs) elif isinstance(other, (int, float)): - return self.__class__(self.data + other) + kwargs = self.get_kwargs(new_data=self.data + other) + return self.__class__(**kwargs) elif isinstance(other, (list, tuple)): other = np.array(other) - if isinstance(other, np.ndarray): - return self.__class__(self.data + other[..., np.newaxis]) + kwargs = self.get_kwargs(new_data=self.data + other[..., np.newaxis]) + return self.__class__(**kwargs) return NotImplemented def __radd__(self, other: Union[int, float, List, Tuple, np.ndarray]) -> Vector3d: if isinstance(other, (int, float)): - return self.__class__(other + self.data) + kwargs = self.get_kwargs(new_data=other + self.data) + return self.__class__(**kwargs) elif isinstance(other, (list, tuple)): other = np.array(other) if isinstance(other, np.ndarray): - return self.__class__(other[..., np.newaxis] + self.data) + kwargs = self.get_kwargs(new_data=other[..., np.newaxis] + self.data) + return self.__class__(**kwargs) return NotImplemented @@ -251,14 +256,17 @@ def __sub__( self, other: Union[int, float, List, Tuple, np.ndarray, Vector3d] ) -> Vector3d: if isinstance(other, Vector3d): - return self.__class__(self.data - other.data) + kwargs = self.get_kwargs(new_data=self.data - other.data) + return self.__class__(**kwargs) elif isinstance(other, (int, float)): - return self.__class__(self.data - other) + kwargs = self.get_kwargs(self.data - other) + return self.__class__(**kwargs) elif isinstance(other, (list, tuple)): other = np.array(other) if isinstance(other, np.ndarray): - return self.__class__(self.data - other[..., np.newaxis]) + kwargs = self.get_kwargs(self.data - other[..., np.newaxis]) + return self.__class__(**kwargs) return NotImplemented @@ -282,23 +290,27 @@ def __mul__( "Try `.dot` or `.cross` instead." ) elif isinstance(other, (int, float)): - return self.__class__(self.data * other) + kwargs = self.get_kwargs(new_data=self.data * other) + return self.__class__(**kwargs) elif isinstance(other, (list, tuple)): other = np.array(other) if isinstance(other, np.ndarray): - return self.__class__(self.data * other[..., np.newaxis]) + kwargs = self.get_kwargs(new_data=self.data * other[..., np.newaxis]) + return self.__class__(**kwargs) return NotImplemented def __rmul__(self, other: Union[int, float, List, Tuple, np.ndarray]) -> Vector3d: if isinstance(other, (int, float)): - return self.__class__(other * self.data) + kwargs = self.get_kwargs(new_data=other * self.data) + return self.__class__(**kwargs) elif isinstance(other, (list, tuple)): other = np.array(other) if isinstance(other, np.ndarray): - return self.__class__(other[..., np.newaxis] * self.data) + kwargs = self.get_kwargs(new_data=other[..., np.newaxis] * self.data) + return self.__class__(**kwargs) return NotImplemented @@ -308,12 +320,14 @@ def __truediv__( if isinstance(other, Vector3d): raise ValueError("Dividing vectors is undefined") elif isinstance(other, (int, float)): - return self.__class__(self.data / other) + kwargs = self.get_kwargs(new_data=self.data / other) + return self.__class__(**kwargs) elif isinstance(other, (list, tuple)): other = np.array(other) if isinstance(other, np.ndarray): - return self.__class__(self.data / other[..., np.newaxis]) + kwargs = self.get_kwargs(new_data=self.data / other[..., np.newaxis]) + return self.__class__(**kwargs) return NotImplemented @@ -731,10 +745,10 @@ def to_polar( polar = np.rad2deg(polar) return azimuth, polar, self.radial - def get_kwargs(self, *args, **kwargs) -> Dict[str, Any]: + def get_kwargs(self, new_data, *args, **kwargs) -> Dict[str, Any]: """Return the dictionary of attributes to be used in the constructor when applying an ufunc.""" - return {} + return dict(data=new_data) def in_fundamental_sector(self, symmetry: "Symmetry") -> Vector3d: """Project vectors to a symmetry's fundamental sector (inverse From 5dc50dfcd6e3c229fa1c3f7af31824de7e3a8822 Mon Sep 17 00:00:00 2001 From: Carter Francis Date: Tue, 7 May 2024 15:03:24 -0500 Subject: [PATCH 3/3] BugFix: pass rotation to ``get_kwargs`` --- orix/quaternion/quaternion.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/orix/quaternion/quaternion.py b/orix/quaternion/quaternion.py index 3061b6fb..28f355f3 100644 --- a/orix/quaternion/quaternion.py +++ b/orix/quaternion/quaternion.py @@ -240,7 +240,7 @@ def __mul__(self, other: Union[Quaternion, Vector3d]): m.coordinate_format = other.coordinate_format return m else: - kwargs = other.get_kwargs(v, self) + kwargs = other.get_kwargs(v, rotation=self) return other.__class__(**kwargs) return NotImplemented