Skip to content

Commit

Permalink
fix: use shared mlir context thread pool (#1534)
Browse files Browse the repository at this point in the history
  • Loading branch information
polvalente committed Sep 14, 2024
1 parent 8a9c2b3 commit 116b124
Show file tree
Hide file tree
Showing 4 changed files with 52 additions and 8 deletions.
40 changes: 37 additions & 3 deletions exla/c_src/exla/exla.cc
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
#include "stablehlo/dialect/StablehloOps.h"
#include "xla/pjrt/pjrt_api.h"
#include "xla/service/platform_util.h"
#include "llvm/Support/ThreadPool.h"

// All of these are created with calls to `new` and subsequently
// passed to the VM as pointers-to-pointers so we balance it out
Expand Down Expand Up @@ -69,6 +70,10 @@ static int open_resources(ErlNifEnv* env) {
if (!exla::nif::open_resource<mlir::MLIRContext*>(env, mod, "MLIRContext")) {
return -1;
}

if (!exla::nif::open_resource<llvm::StdThreadPool*>(env, mod, "TheadPool")) {
return -1;
}
return 1;
}

Expand Down Expand Up @@ -150,12 +155,40 @@ ERL_NIF_TERM mlir_compile(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]) {
return exla::nif::ok(env, exla::nif::make<exla::ExlaExecutable*>(env, executable));
}


ERL_NIF_TERM mlir_new_thread_pool(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]) {
if (argc != 1) {
return exla::nif::error(env, "Bad argument count.");
}

int concurrency;

if (!exla::nif::get(env, argv[0], &concurrency)) {
return exla::nif::error(env, "Unable to get concurrency.");
}

llvm::ThreadPoolStrategy strategy = llvm::hardware_concurrency(concurrency);
llvm::StdThreadPool* pool = new llvm::StdThreadPool(strategy);

auto ret = exla::nif::make<llvm::StdThreadPool*>(env, pool);
return exla::nif::ok(env, ret);
}

ERL_NIF_TERM mlir_new_context(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]) {
if (argc != 0) {
if (argc != 1) {
return exla::nif::error(env, "Bad argument count.");
}

mlir::MLIRContext* context = new mlir::MLIRContext();
llvm::StdThreadPool** thread_pool;

if (!exla::nif::get<llvm::StdThreadPool*>(env, argv[0], thread_pool)) {
return exla::nif::error(env, "Unable to get thread pool.");
}

mlir::MLIRContext* context = new mlir::MLIRContext(mlir::MLIRContext::Threading::DISABLED);

auto interface_ptr = reinterpret_cast<llvm::ThreadPoolInterface*>(*thread_pool);
context->setThreadPool(*interface_ptr);
context->getOrLoadDialect<mlir::func::FuncDialect>();
context->getOrLoadDialect<mlir::stablehlo::StablehloDialect>();
context->getOrLoadDialect<mlir::mhlo::MhloDialect>();
Expand Down Expand Up @@ -909,7 +942,8 @@ ERL_NIF_TERM start_log_sink(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[])

static ErlNifFunc exla_funcs[] = {
// MLIR Builder
{"mlir_new_context", 0, mlir_new_context},
{"mlir_new_thread_pool", 1, mlir_new_thread_pool},
{"mlir_new_context", 1, mlir_new_context},
{"mlir_new_module", 1, mlir_new_module},
{"mlir_create_function", 5, mlir_create_function},
{"mlir_get_function_arguments", 1, mlir_get_function_arguments},
Expand Down
6 changes: 4 additions & 2 deletions exla/lib/exla/application.ex
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,13 @@ defmodule EXLA.Application do
_ -> :os.set_signal(:sigchld, :default)
end

pool_size = System.schedulers_online()

children = [
EXLA.Logger,
{NimblePool,
worker: {EXLA.MLIR.ContextPool, :pool_state},
pool_size: System.schedulers_online(),
worker: {EXLA.MLIR.ContextPool, %{pool_size: pool_size}},
pool_size: pool_size,
name: EXLA.MLIR.ContextPool,
lazy: true},
EXLA.Client,
Expand Down
11 changes: 9 additions & 2 deletions exla/lib/exla/mlir/context_pool.ex
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,15 @@ defmodule EXLA.MLIR.ContextPool do
end

@impl NimblePool
def init_worker(pool_state) do
{:ok, context} = EXLA.NIF.mlir_new_context()
def init_pool(%{pool_size: pool_size}) do
{:ok, thread_pool} = EXLA.NIF.mlir_new_thread_pool(pool_size)

{:ok, %{thread_pool: thread_pool}}
end

@impl NimblePool
def init_worker(%{thread_pool: thread_pool} = pool_state) do
{:ok, context} = EXLA.NIF.mlir_new_context(thread_pool)
{:ok, context, pool_state}
end

Expand Down
3 changes: 2 additions & 1 deletion exla/lib/exla/nif.ex
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@ defmodule EXLA.NIF do
:erlang.load_nif(path, 0)
end

def mlir_new_context, do: :erlang.nif_error(:undef)
def mlir_new_thread_pool(_concurrency), do: :erlang.nif_error(:undef)
def mlir_new_context(_thread_pool_ref), do: :erlang.nif_error(:undef)

def mlir_new_module(_context), do: :erlang.nif_error(:undef)

Expand Down

0 comments on commit 116b124

Please sign in to comment.