Skip to content

Commit

Permalink
Changed numpy frontend's default dtype setting to only happen in case…
Browse files Browse the repository at this point in the history
…s dtype is not defined and all args are not arrays to prevent major issues with jax x64 flag
  • Loading branch information
fspyridakos committed Mar 22, 2023
1 parent cb815ff commit bc95858
Showing 1 changed file with 15 additions and 8 deletions.
23 changes: 15 additions & 8 deletions ivy/functional/frontends/numpy/func_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -360,10 +360,15 @@ def new_fn(*args, order="K", **kwargs):
# handle order and call unmodified function
# ToDo: Remove this default dtype setting
# once frontend specific backend setting is added
ivy.set_default_int_dtype(
"int64"
) if platform.system() != "Windows" else ivy.set_default_int_dtype("int32")
ivy.set_default_float_dtype("float64")
set_default_dtype = False
if not ("dtype" in kwargs and ivy.exists(kwargs["dtype"])) and all(
[not (ivy.is_array(i) or hasattr(i, "ivy_array")) for i in args]
):
ivy.set_default_int_dtype(
"int64"
) if platform.system() != "Windows" else ivy.set_default_int_dtype("int32")
ivy.set_default_float_dtype("float64")
set_default_dtype = True
if contains_order:
if len(args) >= (order_pos + 1):
order = args[order_pos]
Expand All @@ -372,14 +377,16 @@ def new_fn(*args, order="K", **kwargs):
try:
ret = fn(*args, order=order, **kwargs)
finally:
ivy.unset_default_int_dtype()
ivy.unset_default_float_dtype()
if set_default_dtype:
ivy.unset_default_int_dtype()
ivy.unset_default_float_dtype()
else:
try:
ret = fn(*args, **kwargs)
finally:
ivy.unset_default_int_dtype()
ivy.unset_default_float_dtype()
if set_default_dtype:
ivy.unset_default_int_dtype()
ivy.unset_default_float_dtype()
if not ivy.get_array_mode():
return ret
# convert all returned arrays to `ndarray` instances
Expand Down

0 comments on commit bc95858

Please sign in to comment.