From d15467d3d8639acda3092e17179877524dfddcbf Mon Sep 17 00:00:00 2001 From: Ishtiaq Hussain <53497039+Ishticode@users.noreply.github.com> Date: Tue, 13 Feb 2024 10:08:28 +0000 Subject: [PATCH] feat: add unflatten as torch frontend Tensor method and the test tests pass. --- ivy/functional/frontends/torch/tensor.py | 3 ++ .../test_frontends/test_torch/test_tensor.py | 49 +++++++++++++++++++ 2 files changed, 52 insertions(+) diff --git a/ivy/functional/frontends/torch/tensor.py b/ivy/functional/frontends/torch/tensor.py index f470a9860420b..54c6d552d6268 100644 --- a/ivy/functional/frontends/torch/tensor.py +++ b/ivy/functional/frontends/torch/tensor.py @@ -823,6 +823,9 @@ def positive(self): def pow(self, exponent): return torch_frontend.pow(self, exponent) + def unflatten(self, dim, sizes): + return torch_frontend.unflatten(self, dim, sizes) + @with_unsupported_dtypes({"2.2 and below": ("bfloat16",)}, "torch") def pow_(self, exponent): self.ivy_array = self.pow(exponent).ivy_array diff --git a/ivy_tests/test_ivy/test_frontends/test_torch/test_tensor.py b/ivy_tests/test_ivy/test_frontends/test_torch/test_tensor.py index b5490c5ac5f5e..d6338be9e6ad6 100644 --- a/ivy_tests/test_ivy/test_frontends/test_torch/test_tensor.py +++ b/ivy_tests/test_ivy/test_frontends/test_torch/test_tensor.py @@ -21,6 +21,7 @@ # local import ivy_tests.test_ivy.helpers as helpers +from ivy_tests.test_ivy.helpers.hypothesis_helpers.general_helpers import sizes_ from ivy_tests.test_ivy.test_frontends.test_torch.test_blas_and_lapack_ops import ( _get_dtype_and_3dbatch_matrices, _get_dtype_input_and_matrices, @@ -13800,6 +13801,54 @@ def test_torch_unbind( ) +@handle_frontend_method( + class_tree=CLASS_TREE, + init_tree="torch.tensor", + method_name="unflatten", + shape=st.shared(helpers.get_shape(min_num_dims=1), key="shape"), + dtype_and_values=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("valid"), + min_num_dims=1, + shape_key="shape", + ), + axis=helpers.get_axis( + shape=st.shared(helpers.get_shape(min_num_dims=1), key="shape"), + force_int=True, + ), +) +def test_torch_unflatten( + *, + dtype_and_values, + on_device, + frontend, + backend_fw, + shape, + axis, + frontend_method_data, + init_flags, + method_flags, +): + dtype, x = dtype_and_values + sizes = sizes_(shape, axis) + helpers.test_frontend_method( + init_input_dtypes=dtype, + backend_to_test=backend_fw, + init_all_as_kwargs_np={ + "data": x[0], + }, + method_input_dtypes=dtype, + method_all_as_kwargs_np={ + "dim": axis, + "sizes": sizes, + }, + frontend_method_data=frontend_method_data, + init_flags=init_flags, + method_flags=method_flags, + frontend=frontend, + on_device=on_device, + ) + + # unfold @handle_frontend_method( class_tree=CLASS_TREE,