Skip to content

Commit

Permalink
Clean up internal compile interface
Browse files Browse the repository at this point in the history
  • Loading branch information
josevalim committed Sep 5, 2024
1 parent e9b3d73 commit 2e140b7
Showing 1 changed file with 14 additions and 45 deletions.
59 changes: 14 additions & 45 deletions exla/lib/exla/defn.ex
Original file line number Diff line number Diff line change
Expand Up @@ -32,13 +32,8 @@ defmodule EXLA.Defn do

@doc false
def __stream__(key, input, acc, vars, fun, [args], options) do
{debug?, options} = Keyword.pop(options, :debug, false)
{run_options, compile_options} = Keyword.pop(options, :run_options, [])

{client_name, compile_options} =
Keyword.pop_lazy(compile_options, :client, &EXLA.Client.default_name/0)

client = EXLA.Client.fetch!(client_name)
debug? = Keyword.get(compile_options, :debug, false)
compile_options = Keyword.put(compile_options, :lazy_transfers, :never)

input_length = length(Nx.Defn.Composite.flatten_list([input]))
Expand All @@ -50,21 +45,10 @@ defmodule EXLA.Defn do
used_inputs = Enum.to_list(input_length..(input_length + acc_length - 1)//1)

comp_fun =
&to_stream_computation(client, input_length, acc_length, &1, &2, &3, &4, compile_options)
&to_stream_computation(input_length, acc_length, &1, &2, &3, &4, &5, compile_options)

{executable, {used_inputs, {output, acc_output}, outfeed, input_typespecs}} =
compile(
client,
key,
vars,
fun,
compile_options,
used_buffers,
used_inputs,
_stream = true,
debug?,
comp_fun
)
compile(key, vars, fun, compile_options, used_buffers, used_inputs, true, comp_fun)

# Now discard the infeed from used inputs, similar to how it is done to buffers.
# Note we discard all lazy transfers too, as they are not possible with streams.
Expand Down Expand Up @@ -136,13 +120,13 @@ defmodule EXLA.Defn do
end

defp to_stream_computation(
client,
input_length,
acc_length,
%Function{} = builder,
expr,
used_typespecs,
outfeed,
client,
options
) do
%{token: root_token, infeeds: []} = outfeed
Expand Down Expand Up @@ -237,18 +221,12 @@ defmodule EXLA.Defn do

@doc false
def __compile__(key, vars, fun, options) do
{debug?, options} = Keyword.pop(options, :debug, false)
{run_options, compile_options} = Keyword.pop(options, :run_options, [])

{client_name, compile_options} =
Keyword.pop_lazy(compile_options, :client, &EXLA.Client.default_name/0)

client = EXLA.Client.fetch!(client_name)

callback = &to_root_computation(&1, &2, &3, &4, Keyword.put(compile_options, :client, client))
debug? = Keyword.get(compile_options, :debug, false)
callback = &to_root_computation(&1, &2, &3, &4, &5, compile_options)

{executable, {used_inputs, outputs, outfeed, _input_typespecs?}} =
compile(client, key, vars, fun, compile_options, 0, [], _stream = false, debug?, callback)
compile(key, vars, fun, compile_options, 0, [], _stream = false, callback)

fn [args] ->
{time, lock} =
Expand All @@ -270,14 +248,12 @@ defmodule EXLA.Defn do
end
end

defp to_root_computation(%Function{} = function, expr, used_typespecs, outfeed, options) do
defp to_root_computation(%Function{} = function, expr, used_typespecs, outfeed, client, options) do
params =
Enum.zip_with(used_typespecs, Function.get_arguments(function), fn {pos, _typespec}, arg ->
{pos, arg}
end)

client = Keyword.fetch!(options, :client)

unless client do
raise ArgumentError, "missing client"
end
Expand Down Expand Up @@ -342,22 +318,15 @@ defmodule EXLA.Defn do

## Compile

defp compile(
client,
key,
vars,
fun,
options,
used_buffers,
used_inputs,
stream?,
debug?,
to_computation
) do
defp compile(key, vars, fun, options, used_buffers, used_inputs, stream?, to_computation) do
{cache, options} = Keyword.pop(options, :cache, true)
{hooks, options} = Keyword.pop(options, :hooks, %{})
{debug?, options} = Keyword.pop(options, :debug, false)
{lazy_transfers, options} = Keyword.pop(options, :lazy_transfers, :opt_in)

{client_name, options} = Keyword.pop_lazy(options, :client, &EXLA.Client.default_name/0)
client = EXLA.Client.fetch!(client_name)

{args_key, reverse_args_identifiers} =
Enum.map_reduce(vars, [], fn var, acc ->
Nx.Defn.Composite.traverse(var, acc, fn
Expand Down Expand Up @@ -453,7 +422,7 @@ defmodule EXLA.Defn do
end

expr = Nx.Defn.Composite.traverse(expr || fun.(vars), &Nx.devectorize/1)
outfeed = to_computation.(builder, expr, inputs_and_typespecs, outfeed)
outfeed = to_computation.(builder, expr, inputs_and_typespecs, outfeed, client)

{xla_time, executable} =
:timer.tc(fn ->
Expand Down

0 comments on commit 2e140b7

Please sign in to comment.