Skip to content

Commit

Permalink
Default integers to 32-bit precision
Browse files Browse the repository at this point in the history
  • Loading branch information
jonatanklosko committed Jun 4, 2024
1 parent 38bc042 commit 2f3c6ef
Show file tree
Hide file tree
Showing 23 changed files with 1,095 additions and 1,084 deletions.
10 changes: 5 additions & 5 deletions exla/lib/exla.ex
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,7 @@ defmodule EXLA do
iex> EXLA.jit(&Nx.add(&1, &1)).(Nx.tensor([1, 2, 3]))
#Nx.Tensor<
s64[3]
s32[3]
[2, 4, 6]
>
Expand Down Expand Up @@ -249,7 +249,7 @@ defmodule EXLA do
iex> EXLA.jit_apply(&Nx.add(&1, &1), [Nx.tensor([1, 2, 3])])
#Nx.Tensor<
s64[3]
s32[3]
[2, 4, 6]
>
Expand All @@ -262,10 +262,10 @@ defmodule EXLA do
@doc """
A shortcut for `Nx.Defn.compile/3` with the EXLA compiler.
iex> fun = EXLA.compile(&Nx.add(&1, &1), [Nx.template({3}, {:s, 64})])
iex> fun = EXLA.compile(&Nx.add(&1, &1), [Nx.template({3}, {:s, 32})])
iex> fun.(Nx.tensor([1, 2, 3]))
#Nx.Tensor<
s64[3]
s32[3]
[2, 4, 6]
>
Expand Down Expand Up @@ -327,7 +327,7 @@ defmodule EXLA do
Now let's invoke it:
stream = EXLA.stream(&Streamed.sum/2, [Nx.template({}, {:s, 64}), 0])
stream = EXLA.stream(&Streamed.sum/2, [Nx.template({}, {:s, 32}), 0])
for i <- 1..5 do
Nx.Stream.send(stream, i)
Expand Down
20 changes: 10 additions & 10 deletions exla/test/exla/backend_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,9 @@ defmodule EXLA.BackendTest do

test "Nx.to_binary/1" do
t = Nx.tensor([1, 2, 3, 4], backend: EXLA.Backend)
assert Nx.to_binary(t) == <<1::64-native, 2::64-native, 3::64-native, 4::64-native>>
assert Nx.to_binary(t, limit: 2) == <<1::64-native, 2::64-native>>
assert Nx.to_binary(t, limit: 6) == <<1::64-native, 2::64-native, 3::64-native, 4::64-native>>
assert Nx.to_binary(t) == <<1::32-native, 2::32-native, 3::32-native, 4::32-native>>
assert Nx.to_binary(t, limit: 2) == <<1::32-native, 2::32-native>>
assert Nx.to_binary(t, limit: 6) == <<1::32-native, 2::32-native, 3::32-native, 4::32-native>>
end

test "Nx.backend_transfer/1" do
Expand All @@ -44,7 +44,7 @@ defmodule EXLA.BackendTest do
assert %EXLA.Backend{buffer: %EXLA.DeviceBuffer{}} = et.data

nt = Nx.backend_transfer(et)
assert Nx.to_binary(nt) == <<1::64-native, 2::64-native, 3::64-native, 4::64-native>>
assert Nx.to_binary(nt) == <<1::32-native, 2::32-native, 3::32-native, 4::32-native>>

assert_raise RuntimeError, ~r"called on deleted or donated buffer", fn ->
Nx.backend_transfer(et)
Expand All @@ -63,7 +63,7 @@ defmodule EXLA.BackendTest do
assert old_buffer == new_buffer

nt = Nx.backend_transfer(et)
assert Nx.to_binary(nt) == <<1::64-native, 2::64-native, 3::64-native, 4::64-native>>
assert Nx.to_binary(nt) == <<1::32-native, 2::32-native, 3::32-native, 4::32-native>>

assert_raise RuntimeError, ~r"called on deleted or donated buffer", fn ->
Nx.backend_transfer(et)
Expand All @@ -83,10 +83,10 @@ defmodule EXLA.BackendTest do
assert old_buffer != new_buffer

nt = Nx.backend_copy(et)
assert Nx.to_binary(nt) == <<1::64-native, 2::64-native, 3::64-native, 4::64-native>>
assert Nx.to_binary(nt) == <<1::32-native, 2::32-native, 3::32-native, 4::32-native>>

nt = Nx.backend_copy(et)
assert Nx.to_binary(nt) == <<1::64-native, 2::64-native, 3::64-native, 4::64-native>>
assert Nx.to_binary(nt) == <<1::32-native, 2::32-native, 3::32-native, 4::32-native>>
end

test "different clients" do
Expand All @@ -102,10 +102,10 @@ defmodule EXLA.BackendTest do
assert new_buffer.device_id == 0

nt = Nx.backend_copy(et)
assert Nx.to_binary(nt) == <<1::64-native, 2::64-native, 3::64-native, 4::64-native>>
assert Nx.to_binary(nt) == <<1::32-native, 2::32-native, 3::32-native, 4::32-native>>

nt = Nx.backend_copy(et)
assert Nx.to_binary(nt) == <<1::64-native, 2::64-native, 3::64-native, 4::64-native>>
assert Nx.to_binary(nt) == <<1::32-native, 2::32-native, 3::32-native, 4::32-native>>
end
end

Expand Down Expand Up @@ -151,7 +151,7 @@ defmodule EXLA.BackendTest do
assert inspect(t) ==
"""
#Nx.Tensor<
s64[4]
s32[4]
[1, 2, 3, 4]
>\
"""
Expand Down
36 changes: 18 additions & 18 deletions exla/test/exla/defn/expr_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -1807,7 +1807,7 @@ defmodule EXLA.Defn.ExprTest do
indices = Nx.tensor([[0]])
updates = Nx.tensor([1])

assert_equal(indexed_add(target, indices, updates), Nx.tensor([1], type: {:s, 64}))
assert_equal(indexed_add(target, indices, updates), Nx.tensor([1], type: {:s, 32}))

target = Nx.tensor([0])
indices = Nx.tensor([[0]])
Expand Down Expand Up @@ -1879,7 +1879,7 @@ defmodule EXLA.Defn.ExprTest do
indices = Nx.tensor([[0]])
updates = Nx.tensor([1])

assert_equal(indexed_put(target, indices, updates), Nx.tensor([1], type: {:s, 64}))
assert_equal(indexed_put(target, indices, updates), Nx.tensor([1], type: {:s, 32}))

target = Nx.tensor([0])
indices = Nx.tensor([[0]])
Expand Down Expand Up @@ -1963,7 +1963,7 @@ defmodule EXLA.Defn.ExprTest do
test "computes the sum across types" do
assert_equal(Nx.tensor([1, 2, 3]) |> sum(), Nx.tensor(6))
assert_equal(Nx.tensor([1, 2, 3], type: {:s, 8}) |> sum(), Nx.tensor(6))
assert_equal(Nx.tensor([1, 2, 3], type: {:u, 8}) |> sum(), Nx.tensor(6, type: {:u, 64}))
assert_equal(Nx.tensor([1, 2, 3], type: {:u, 8}) |> sum(), Nx.tensor(6, type: {:u, 32}))
assert_equal(Nx.tensor([1.0, 2.0, 3.0]) |> sum(), Nx.tensor(6.0))

assert_equal(
Expand All @@ -1986,9 +1986,9 @@ defmodule EXLA.Defn.ExprTest do
defn sum_equal(t), do: Nx.sum(Nx.equal(t, 1.0))

test "does not overflow" do
assert_equal(sum_equal(Nx.tensor(1)), Nx.tensor(1, type: {:u, 64}))
assert_equal(sum_equal(Nx.tensor([1, 1, 1])), Nx.tensor(3, type: {:u, 64}))
assert_equal(sum_equal(Nx.tensor([1, 2, 3])), Nx.tensor(1, type: {:u, 64}))
assert_equal(sum_equal(Nx.tensor(1)), Nx.tensor(1, type: {:u, 32}))
assert_equal(sum_equal(Nx.tensor([1, 1, 1])), Nx.tensor(3, type: {:u, 32}))
assert_equal(sum_equal(Nx.tensor([1, 2, 3])), Nx.tensor(1, type: {:u, 32}))
end

defn sum_keep(t), do: Nx.sum(t, keep_axes: true)
Expand All @@ -2011,7 +2011,7 @@ defmodule EXLA.Defn.ExprTest do
test "computes the product across types" do
assert_equal(Nx.tensor([1, 2, 3]) |> product(), Nx.tensor(6))
assert_equal(Nx.tensor([1, 2, 3], type: {:s, 8}) |> product(), Nx.tensor(6))
assert_equal(Nx.tensor([1, 2, 3], type: {:u, 8}) |> product(), Nx.tensor(6, type: {:u, 64}))
assert_equal(Nx.tensor([1, 2, 3], type: {:u, 8}) |> product(), Nx.tensor(6, type: {:u, 32}))
assert_equal(Nx.tensor([1.0, 2.0, 3.0]) |> product(), Nx.tensor(6.0))

assert_equal(
Expand All @@ -2034,9 +2034,9 @@ defmodule EXLA.Defn.ExprTest do
defn product_equal(t), do: Nx.product(Nx.equal(t, 1.0))

test "does not overflow" do
assert_equal(product_equal(Nx.tensor(1)), Nx.tensor(1, type: {:u, 64}))
assert_equal(product_equal(Nx.tensor([1, 1, 1])), Nx.tensor(1, type: {:u, 64}))
assert_equal(product_equal(Nx.tensor([1, 2, 3])), Nx.tensor(0, type: {:u, 64}))
assert_equal(product_equal(Nx.tensor(1)), Nx.tensor(1, type: {:u, 32}))
assert_equal(product_equal(Nx.tensor([1, 1, 1])), Nx.tensor(1, type: {:u, 32}))
assert_equal(product_equal(Nx.tensor([1, 2, 3])), Nx.tensor(0, type: {:u, 32}))
end

defn product_keep(t), do: Nx.product(t, keep_axes: true)
Expand Down Expand Up @@ -2416,12 +2416,12 @@ defmodule EXLA.Defn.ExprTest do
window_max2(Nx.tensor([[[1, 2, 3], [4, 5, 6]], [[1, 2, 3], [4, 5, 6]]])),
Nx.tensor([
[
[-9_223_372_036_854_775_808, -9_223_372_036_854_775_808],
[-9_223_372_036_854_775_808, 6]
[-2_147_483_648, -2_147_483_648],
[-2_147_483_648, 6]
],
[
[-9_223_372_036_854_775_808, -9_223_372_036_854_775_808],
[-9_223_372_036_854_775_808, 6]
[-2_147_483_648, -2_147_483_648],
[-2_147_483_648, 6]
]
])
)
Expand Down Expand Up @@ -2482,12 +2482,12 @@ defmodule EXLA.Defn.ExprTest do
window_min2(Nx.tensor([[[1, 2, 3], [4, 5, 6]], [[1, 2, 3], [4, 5, 6]]])),
Nx.tensor([
[
[9_223_372_036_854_775_807, 9_223_372_036_854_775_807],
[9_223_372_036_854_775_807, 3]
[2_147_483_647, 2_147_483_647],
[2_147_483_647, 3]
],
[
[9_223_372_036_854_775_807, 9_223_372_036_854_775_807],
[9_223_372_036_854_775_807, 3]
[2_147_483_647, 2_147_483_647],
[2_147_483_647, 3]
]
])
)
Expand Down
Loading

0 comments on commit 2f3c6ef

Please sign in to comment.