Skip to content

Commit

Permalink
feat: Add tensorflow GatherNd raw_ops (#27745)
Browse files Browse the repository at this point in the history
  • Loading branch information
TalhaKhalil authored Feb 7, 2024
1 parent 0c3a4a9 commit 252047c
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 0 deletions.
5 changes: 5 additions & 0 deletions ivy/functional/frontends/tensorflow/raw_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -619,6 +619,11 @@ def Gather(*, params, indices, validate_indices=None, name="Gather"):
return ivy.gather(params, indices, axis=0, batch_dims=0)


@to_ivy_arrays_and_back
def GatherNd(*, params, indices, validate_indices=None, name="GatherNd"):
return ivy.gather_nd(params, indices, batch_dims=0)


@to_ivy_arrays_and_back
def Greater(*, x, y, name="Greater"):
x, y = check_tensorflow_casting(x, y)
Expand Down
36 changes: 36 additions & 0 deletions ivy_tests/test_ivy/test_frontends/test_tensorflow/test_raw_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -2132,6 +2132,42 @@ def test_tensorflow_Gather( # NOQA
)


# GatherNd
@handle_frontend_test(
fn_tree="tensorflow.raw_ops.GatherNd",
params_indices_axis_batch_dims=helpers.array_indices_axis(
array_dtypes=helpers.get_dtypes("valid"),
indices_dtypes=["int64"],
min_num_dims=5,
max_num_dims=10,
min_dim_size=1,
max_dim_size=5,
indices_same_dims=False,
),
test_with_out=st.just(False),
)
def test_tensorflow_GatherNd(
*,
params_indices_axis_batch_dims,
frontend,
test_flags,
fn_tree,
backend_fw,
on_device,
):
input_dtypes, params, indices, axis, batch_dims = params_indices_axis_batch_dims
helpers.test_frontend_function(
input_dtypes=input_dtypes,
backend_to_test=backend_fw,
frontend=frontend,
test_flags=test_flags,
fn_tree=fn_tree,
on_device=on_device,
params=params,
indices=indices,
)


# Greater
@handle_frontend_test(
fn_tree="tensorflow.raw_ops.Greater",
Expand Down

0 comments on commit 252047c

Please sign in to comment.