Skip to content

Commit

Permalink
fix: check for empty tensor in mean for paddle backend
Browse files Browse the repository at this point in the history
  • Loading branch information
jacksondm33 committed Feb 22, 2024
1 parent f6b0966 commit eeecaf7
Showing 1 changed file with 13 additions and 0 deletions.
13 changes: 13 additions & 0 deletions ivy/functional/backends/paddle/statistical.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,16 @@ def max(
return ret.astype(ret_dtype)


def _calculate_reduced_shape(x, axis, keepdims):
if axis is None:
axis = tuple(range(len(x.shape)))
elif type(axis) not in (tuple, list):
axis = (axis,)
if keepdims:
return [1 if i in axis else x.shape[i] for i in range(len(x.shape))]
return [x.shape[i] for i in range(len(x.shape)) if i not in axis]


@with_supported_dtypes(
{"2.6.0 and below": ("bool", "complex", "float32", "float64")}, backend_version
)
Expand All @@ -116,6 +126,9 @@ def mean(
out: Optional[paddle.Tensor] = None,
) -> paddle.Tensor:
ret_dtype = x.dtype
if 0 in x.shape:
shape = _calculate_reduced_shape(x, axis, keepdims)
ret = paddle.empty(shape)
if paddle.is_complex(x):
ret = paddle.complex(
paddle.mean(x.real(), axis=axis, keepdim=keepdims),
Expand Down

0 comments on commit eeecaf7

Please sign in to comment.