From f8f1e48ab8356b19027baf0b7eff39f418e51adf Mon Sep 17 00:00:00 2001
From: Paulo Valente <16843419+polvalente@users.noreply.github.com>
Date: Sat, 14 Sep 2024 11:48:55 -0300
Subject: [PATCH] fix: use shared mlir context thread pool

---
 exla/c_src/exla/exla.cc            | 40 +++++++++++++++++++++++++++---
 exla/lib/exla/application.ex       |  6 +++--
 exla/lib/exla/mlir/context_pool.ex | 11 ++++++--
 exla/lib/exla/nif.ex               |  3 ++-
 4 files changed, 52 insertions(+), 8 deletions(-)

diff --git a/exla/c_src/exla/exla.cc b/exla/c_src/exla/exla.cc
index c91b9fb621..13f7e39572 100644
--- a/exla/c_src/exla/exla.cc
+++ b/exla/c_src/exla/exla.cc
@@ -11,6 +11,7 @@
 #include "stablehlo/dialect/StablehloOps.h"
 #include "xla/pjrt/pjrt_api.h"
 #include "xla/service/platform_util.h"
+#include "llvm/Support/ThreadPool.h"
 
 // All of these are created with calls to `new` and subsequently
 // passed to the VM as pointers-to-pointers so we balance it out
@@ -69,6 +70,10 @@ static int open_resources(ErlNifEnv* env) {
   if (!exla::nif::open_resource(env, mod, "MLIRContext")) {
     return -1;
   }
+
+  if (!exla::nif::open_resource(env, mod, "TheadPool")) {
+    return -1;
+  }
   return 1;
 }
 
@@ -150,12 +155,40 @@ ERL_NIF_TERM mlir_compile(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]) {
   return exla::nif::ok(env, exla::nif::make(env, executable));
 }
 
+
+ERL_NIF_TERM mlir_new_thread_pool(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]) {
+  if (argc != 1) {
+    return exla::nif::error(env, "Bad argument count.");
+  }
+
+  int concurrency;
+
+  if (!exla::nif::get(env, argv[0], &concurrency)) {
+    return exla::nif::error(env, "Unable to get concurrency.");
+  }
+
+  llvm::ThreadPoolStrategy strategy = llvm::hardware_concurrency(concurrency);
+  llvm::StdThreadPool* pool = new llvm::StdThreadPool(strategy);
+
+  auto ret = exla::nif::make(env, pool);
+  return exla::nif::ok(env, ret);
+}
+
 ERL_NIF_TERM mlir_new_context(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]) {
-  if (argc != 0) {
+  if (argc != 1) {
     return exla::nif::error(env, "Bad argument count.");
   }
 
-  mlir::MLIRContext* context = new mlir::MLIRContext();
+  llvm::StdThreadPool** thread_pool;
+
+  if (!exla::nif::get(env, argv[0], thread_pool)) {
+    return exla::nif::error(env, "Unable to get thread pool.");
+  }
+
+  mlir::MLIRContext* context = new mlir::MLIRContext(mlir::MLIRContext::Threading::DISABLED);
+
+  auto interface_ptr = reinterpret_cast(*thread_pool);
+  context->setThreadPool(*interface_ptr);
   context->getOrLoadDialect();
   context->getOrLoadDialect();
   context->getOrLoadDialect();
@@ -909,7 +942,8 @@ ERL_NIF_TERM start_log_sink(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[])
 
 static ErlNifFunc exla_funcs[] = {
     // MLIR Builder
-    {"mlir_new_context", 0, mlir_new_context},
+    {"mlir_new_thread_pool", 1, mlir_new_thread_pool},
+    {"mlir_new_context", 1, mlir_new_context},
     {"mlir_new_module", 1, mlir_new_module},
     {"mlir_create_function", 5, mlir_create_function},
     {"mlir_get_function_arguments", 1, mlir_get_function_arguments},
diff --git a/exla/lib/exla/application.ex b/exla/lib/exla/application.ex
index 3bdfa30d0c..9ec098a3e6 100644
--- a/exla/lib/exla/application.ex
+++ b/exla/lib/exla/application.ex
@@ -10,11 +10,13 @@ defmodule EXLA.Application do
       _ -> :os.set_signal(:sigchld, :default)
     end
 
+    pool_size = System.schedulers_online()
+
     children = [
       EXLA.Logger,
       {NimblePool,
-       worker: {EXLA.MLIR.ContextPool, :pool_state},
-       pool_size: System.schedulers_online(),
+       worker: {EXLA.MLIR.ContextPool, %{pool_size: pool_size}},
+       pool_size: pool_size,
        name: EXLA.MLIR.ContextPool,
        lazy: true},
       EXLA.Client,
diff --git a/exla/lib/exla/mlir/context_pool.ex b/exla/lib/exla/mlir/context_pool.ex
index 14cf11429f..9e0d8155bb 100644
--- a/exla/lib/exla/mlir/context_pool.ex
+++ b/exla/lib/exla/mlir/context_pool.ex
@@ -13,8 +13,15 @@ defmodule EXLA.MLIR.ContextPool do
   end
 
   @impl NimblePool
-  def init_worker(pool_state) do
-    {:ok, context} = EXLA.NIF.mlir_new_context()
+  def init_pool(%{pool_size: pool_size}) do
+    {:ok, thread_pool} = EXLA.NIF.mlir_new_thread_pool(pool_size)
+
+    {:ok, %{thread_pool: thread_pool}}
+  end
+
+  @impl NimblePool
+  def init_worker(%{thread_pool: thread_pool} = pool_state) do
+    {:ok, context} = EXLA.NIF.mlir_new_context(thread_pool)
     {:ok, context, pool_state}
   end
 
diff --git a/exla/lib/exla/nif.ex b/exla/lib/exla/nif.ex
index 6830df726c..be0567cc0a 100644
--- a/exla/lib/exla/nif.ex
+++ b/exla/lib/exla/nif.ex
@@ -7,7 +7,8 @@ defmodule EXLA.NIF do
     :erlang.load_nif(path, 0)
   end
 
-  def mlir_new_context, do: :erlang.nif_error(:undef)
+  def mlir_new_thread_pool(_concurrency), do: :erlang.nif_error(:undef)
+  def mlir_new_context(_thread_pool_ref), do: :erlang.nif_error(:undef)
 
   def mlir_new_module(_context), do: :erlang.nif_error(:undef)
 