Skip to content

Commit

Permalink
fix: add support for defn compilation in EXLA.to_mlir_module (#1530)
Browse files Browse the repository at this point in the history
  • Loading branch information
polvalente committed Sep 5, 2024
1 parent 2e140b7 commit ad28ea7
Show file tree
Hide file tree
Showing 3 changed files with 82 additions and 12 deletions.
37 changes: 25 additions & 12 deletions exla/lib/exla.ex
Original file line number Diff line number Diff line change
Expand Up @@ -360,11 +360,18 @@ defmodule EXLA do
Takes in a function, the argument templates and the compilation
options and returns the textual representation of the MLIR module.
## Options
* `:within_defn_compiler` - a boolean that indicates whether
this function is being called from within a `defn` compiler.
Defaults to `false`.
## Examples
iex> fun = fn x, y -> Nx.add(Nx.sin(x), Nx.cos(y)) end
iex> args = [1.0, 2.0]
iex> EXLA.to_mlir_module(fun, args)
iex> %{mlir_module: mlir_module} = EXLA.to_mlir_module(fun, args)
iex> mlir_module
"""
module {
func.func public @main(%arg0: tensor<f32>, %arg1: tensor<f32>) -> tensor<f32> {
Expand All @@ -377,20 +384,26 @@ defmodule EXLA do
"""
'''
def to_mlir_module(function, args, options \\ []) do
comp_fun = fn _key, callback ->
{:ok, {_xla_time, executable, _extra, _outfeed}} = callback.()
throw({:mlir_module, executable.ref})
end
{nested_compilation?, options} = Keyword.pop(options, :within_defn_compiler, false)

opts = [
{EXLA, {&EXLA.Defn.LockedCache.run/2, comp_fun}},
{:module_compilation, :to_mlir} | options
]
opts =
Keyword.merge(options,
module_compilation: :to_mlir,
compiler: EXLA
)

jit_apply(function, args, opts)
if nested_compilation? do
EXLA.Defn.__compile__(function, args, function, opts)
else
Nx.Defn.compile(function, args, opts)
end
catch
{:mlir_module, ref} ->
EXLA.MLIR.Module.as_string(%EXLA.MLIR.Module{ref: ref})
{:mlir_module, ref, used_inputs, output_container} ->
%{
used_inputs: used_inputs,
output_container: output_container,
mlir_module: EXLA.MLIR.Module.as_string(%EXLA.MLIR.Module{ref: ref})
}
end

@doc """
Expand Down
4 changes: 4 additions & 0 deletions exla/lib/exla/defn.ex
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,10 @@ defmodule EXLA.Defn do
{executable, {used_inputs, outputs, outfeed, _input_typespecs?}} =
compile(key, vars, fun, compile_options, 0, [], _stream = false, callback)

if compile_options[:module_compilation] == :to_mlir do
throw({:mlir_module, executable.ref, MapSet.new(Map.keys(used_inputs)), outputs})
end

fn [args] ->
{time, lock} =
:timer.tc(fn ->
Expand Down
53 changes: 53 additions & 0 deletions exla/test/exla_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -17,4 +17,57 @@ defmodule EXLATest do
end
end
end

defmodule ValidCompiler do
def __jit__(key, vars, fun, args_list, opts) do
__compile__(key, vars, fun, opts).(args_list)
end

def __compile__(_key, vars, fun, opts) do
result = EXLA.to_mlir_module(fun, vars, Keyword.put(opts, :within_defn_compiler, true))
throw({__MODULE__, result})
end
end

defmodule InvalidCompiler do
def __jit__(key, vars, fun, args_list, opts) do
__compile__(key, vars, fun, opts).(args_list)
end

def __compile__(_key, vars, fun, opts) do
# Keyword.delete to ensure default is false
EXLA.to_mlir_module(fun, vars, Keyword.delete(opts, :within_defn_compiler))
end
end

describe "to_mlir_module/3" do
test "fails if the compiler doesn't set the nested compilation flag" do
assert_raise BadArityError, fn ->
Nx.Defn.jit_apply(&Nx.add/2, [1, 2], compiler: __MODULE__.InvalidCompiler)
end
end

test "works if the compiler sets the nested compilation flag" do
try do
Nx.Defn.jit_apply(&Nx.add/2, [1, 2], compiler: __MODULE__.ValidCompiler)
catch
{__MODULE__.ValidCompiler, result} ->
assert %{mlir_module: module, output_container: container, used_inputs: used_inputs} =
result

assert module == """
module {
func.func public @main(%arg0: tensor<i32>, %arg1: tensor<i32>) -> tensor<i32> {
%0 = stablehlo.add %arg0, %arg1 : tensor<i32>
return %0 : tensor<i32>
}
}
"""

assert Nx.compatible?(container, Nx.template({}, :s32))

assert MapSet.equal?(used_inputs, MapSet.new([0, 1]))
end
end
end
end

0 comments on commit ad28ea7

Please sign in to comment.