diff --git a/include/triton/core/tritonbackend.h b/include/triton/core/tritonbackend.h
index ad71c79a5..8e58c34cf 100644
--- a/include/triton/core/tritonbackend.h
+++ b/include/triton/core/tritonbackend.h
@@ -94,7 +94,7 @@ struct TRITONBACKEND_Batcher;
 /// }
 ///
 #define TRITONBACKEND_API_VERSION_MAJOR 1
-#define TRITONBACKEND_API_VERSION_MINOR 13
+#define TRITONBACKEND_API_VERSION_MINOR 14
 
 /// Get the TRITONBACKEND API version supported by Triton. This value
 /// can be compared against the TRITONBACKEND_API_VERSION_MAJOR and
@@ -1480,6 +1480,26 @@ TRITONBACKEND_BackendAttributeAddPreferredInstanceGroup(
     const TRITONSERVER_InstanceGroupKind kind, const uint64_t count,
     const uint64_t* device_ids, const uint64_t id_count);
 
+/// Sets whether or not the backend supports concurrently loading multiple
+/// TRITONBACKEND_ModelInstances in a thread-safe manner.
+///
+/// Most backends are thread-safe for parallel execution of model instances,
+/// as that is the primary use of concurrency in backends. However, not all
+/// backends are thread-safe when initializing or finalizing model instances.
+/// For Triton to know that it can safely load instances concurrently, the
+/// backend must opt in by setting this backend attribute to true. By
+/// default, this attribute is false and calls to the
+/// TRITONBACKEND_ModelInstanceInitialize function are made serially. If
+/// this attribute is set to true, then Triton will make calls to
+/// TRITONBACKEND_ModelInstanceInitialize concurrently.
+///
+/// \param backend_attributes The backend attributes object.
+/// \param enabled Whether or not the backend supports loading model instances
+/// in parallel.
+TRITONSERVER_DECLSPEC TRITONSERVER_Error*
+TRITONBACKEND_BackendAttributeSetParallelModelInstanceLoading(
+    TRITONBACKEND_BackendAttribute* backend_attributes, bool enabled);
+
 /// TRITONBACKEND Batching
 ///
 /// API to add custom batching strategy
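For context, a backend opts in from its TRITONBACKEND_GetBackendAttribute entry
point, the existing hook through which Triton queries backend attributes. A
minimal sketch of such an implementation, assuming a backend whose
TRITONBACKEND_ModelInstanceInitialize/Finalize are in fact thread-safe (the
setter above is the only new call here):

#include "triton/core/tritonbackend.h"

extern "C" TRITONSERVER_Error*
TRITONBACKEND_GetBackendAttribute(
    TRITONBACKEND_Backend* backend,
    TRITONBACKEND_BackendAttribute* backend_attributes)
{
  // Declare that this backend's model instances can be initialized and
  // finalized concurrently; without this call Triton defaults to serial
  // instance loading.
  return TRITONBACKEND_BackendAttributeSetParallelModelInstanceLoading(
      backend_attributes, true /* enabled */);
}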
diff --git a/src/backend_manager.cc b/src/backend_manager.cc
index 5327f1ee8..d2b95d066 100644
--- a/src/backend_manager.cc
+++ b/src/backend_manager.cc
@@ -120,6 +120,7 @@ TritonBackend::UpdateAttributes()
   if (!latest.preferred_groups_.empty()) {
     attributes_.preferred_groups_ = latest.preferred_groups_;
   }
+  attributes_.parallel_instance_loading_ = latest.parallel_instance_loading_;
   return Status::Success;
 }
 
diff --git a/src/backend_manager.h b/src/backend_manager.h
index 0aba05802..a1aa68def 100644
--- a/src/backend_manager.h
+++ b/src/backend_manager.h
@@ -47,9 +47,15 @@ namespace triton { namespace core {
 class TritonBackend {
  public:
   struct Attribute {
-    Attribute() : exec_policy_(TRITONBACKEND_EXECUTION_BLOCKING) {}
+    Attribute()
+        : exec_policy_(TRITONBACKEND_EXECUTION_BLOCKING),
+          parallel_instance_loading_(false)
+    {
+    }
     TRITONBACKEND_ExecutionPolicy exec_policy_;
     std::vector<inference::ModelInstanceGroup> preferred_groups_;
+    // Whether the backend supports loading model instances in parallel
+    bool parallel_instance_loading_;
   };
   typedef TRITONSERVER_Error* (*TritonModelInitFn_t)(
       TRITONBACKEND_Model* model);
diff --git a/src/backend_model.cc b/src/backend_model.cc
index 34cb694f5..dec657680 100644
--- a/src/backend_model.cc
+++ b/src/backend_model.cc
@@ -406,6 +406,35 @@ TritonModel::PrepareInstances(
       std::vector<std::shared_ptr<TritonModelInstance>>>
       existing_instances = IndexInstances();
 
+  std::vector<std::future<Status>> creation_results;
+  // Protects added_instances and instance registration during parallel
+  // instance loading
+  std::mutex instance_mu;
+
+  // std::launch::deferred evaluates the task lazily, when its result is
+  // requested. Since the creation_results are requested serially below,
+  // deferred is equivalent to making the calls serially.
+  auto launch_policy = std::launch::deferred;
+
+  // Environment variable override for testing/debugging purposes
+  bool parallel = backend_->BackendAttributes().parallel_instance_loading_;
+  const char* env = std::getenv("TRITON_PARALLEL_INSTANCE_LOADING");
+  if (env != nullptr) {
+    const std::string s_env(env);
+    if (!s_env.empty()) {
+      parallel = (s_env == "1");
+      LOG_VERBOSE(1)
+          << "Using TRITON_PARALLEL_INSTANCE_LOADING environment variable "
+             "override: "
+          << parallel;
+    }
+  }
+
+  // If the backend supports it, std::launch::async allows concurrent calls
+  if (parallel) {
+    launch_policy = std::launch::async;
+  }
+
   // Iterates over all the requested instances on the model config, and decides
   // if each requested instance can reuse an existing instance or a new
   // instance is needed.
@@ -485,17 +514,28 @@ TritonModel::PrepareInstances(
         }
       }
 
-      // The requested instance did not match an existing instance. Create a
-      // new instance.
-      std::shared_ptr<TritonModelInstance> new_instance;
-      LOG_VERBOSE(2) << "Creating model instance named '" << instance_name
-                     << "' with device id '" << is.device_id_ << "'";
-      RETURN_IF_ERROR(TritonModelInstance::CreateInstance(
-          this, instance_name, signature, is.kind_, is.device_id_,
-          profile_names, passive, is.policy_name_, *is.rate_limiter_config_,
-          secondary_devices, &new_instance));
-      added_instances->push_back(new_instance);
-      RegisterBackgroundInstance(std::move(new_instance), passive);
+      // Note that the local variables must be captured by value: they may
+      // go out of scope before an async task runs.
+      creation_results.emplace_back(
+          std::async(launch_policy, [=, &instance_mu]() {
+            // The requested instance did not match an existing instance.
+            // Create a new instance.
+            std::shared_ptr<TritonModelInstance> new_instance;
+            RETURN_IF_ERROR(TritonModelInstance::CreateInstance(
+                this, instance_name, signature, is.kind_, is.device_id_,
+                profile_names, passive, is.policy_name_,
+                *is.rate_limiter_config_, secondary_devices, &new_instance));
+            {
+              std::lock_guard<std::mutex> lk(instance_mu);
+              added_instances->push_back(new_instance);
+              RegisterBackgroundInstance(std::move(new_instance), passive);
+            }
+            // Build the message first and log it with a single stream
+            // operation so concurrent logs do not interleave
+            const auto msg = "Created model instance named '" +
+                             instance_name + "' with device id '" +
+                             std::to_string(is.device_id_) + "'";
+            LOG_VERBOSE(2) << msg;
+            return Status::Success;
+          }));
       }
     }
   }
@@ -506,7 +546,16 @@ TritonModel::PrepareInstances(
         removed_instances->end(), pair.second.begin(), pair.second.end());
   }
 
-  return Status::Success;
+  auto status = Status::Success;
+  for (auto& cr : creation_results) {
+    auto lstatus = cr.get();
+    if (!lstatus.IsOk()) {
+      LOG_ERROR << "Failed to create model instance: " << lstatus.Message();
+      status = lstatus;
+    }
+  }
+
+  return status;
 }
 
 void
@@ -1675,6 +1724,16 @@ TRITONBACKEND_BackendAttributeAddPreferredInstanceGroup(
   return nullptr;
 }
 
+
+TRITONAPI_DECLSPEC TRITONSERVER_Error*
+TRITONBACKEND_BackendAttributeSetParallelModelInstanceLoading(
+    TRITONBACKEND_BackendAttribute* backend_attributes, bool enabled)
+{
+  auto ba = reinterpret_cast<TritonBackend::Attribute*>(backend_attributes);
+  ba->parallel_instance_loading_ = enabled;
+  return nullptr;
+}
+
 }  // extern C
 
 }}  // namespace triton::core
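A note on the launch-policy switch in PrepareInstances above: with
std::launch::deferred the task body does not execute until get() is called on
its future, so draining creation_results in order at the end of the function
reproduces the previous serial loading behavior, while std::launch::async may
run the tasks on separate threads. A standalone sketch of the same pattern
(plain standard C++, not Triton code):

#include <future>
#include <iostream>
#include <vector>

int main()
{
  // Stands in for the backend's parallel_instance_loading_ attribute
  const bool parallel = false;
  const auto policy = parallel ? std::launch::async : std::launch::deferred;

  std::vector<std::future<int>> results;
  for (int i = 0; i < 4; ++i) {
    // With deferred, this lambda runs lazily on the thread calling get(),
    // one task at a time; with async, the tasks may run concurrently.
    results.emplace_back(std::async(policy, [i]() { return i * i; }));
  }
  for (auto& r : results) {
    std::cout << r.get() << "\n";  // serial drain point in both modes
  }
  return 0;
}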
diff --git a/src/rate_limiter.cc b/src/rate_limiter.cc
index ede9982b5..2c0f92397 100644
--- a/src/rate_limiter.cc
+++ b/src/rate_limiter.cc
@@ -389,12 +389,18 @@ RateLimiter::InitializePayloadQueues(const TritonModelInstance* instance)
     }
     payload_queue = payload_queues_[instance->Model()].get();
   }
-  if (payload_queue->specific_queues_.find(instance) ==
-      payload_queue->specific_queues_.end()) {
-    payload_queue->specific_queues_.emplace(
-        instance,
-        new InstanceQueue(
-            config.max_batch_size(), max_queue_delay_microseconds * 1000));
+  {
+    // NOTE: payload_queue can have a data race because instance->Model()
+    // is the same for multiple instances of the same model, so protect it
+    // when creating model instances in parallel.
+    std::lock_guard<std::mutex> lk(payload_queue->mu_);
+    if (payload_queue->specific_queues_.find(instance) ==
+        payload_queue->specific_queues_.end()) {
+      payload_queue->specific_queues_.emplace(
+          instance,
+          new InstanceQueue(
+              config.max_batch_size(), max_queue_delay_microseconds * 1000));
+    }
   }
 }
 
diff --git a/src/tritonserver_stub.cc b/src/tritonserver_stub.cc
index 6087b8021..8ef16a0c4 100644
--- a/src/tritonserver_stub.cc
+++ b/src/tritonserver_stub.cc
@@ -1015,6 +1015,11 @@ TRITONBACKEND_BackendAttributeAddPreferredInstanceGroup()
 {
 }
 
+TRITONAPI_DECLSPEC void
+TRITONBACKEND_BackendAttributeSetParallelModelInstanceLoading()
+{
+}
+
 TRITONAPI_DECLSPEC void
 TRITONCACHE_ApiVersion()
 {
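The rate_limiter.cc change above is the usual check-then-emplace-under-lock
pattern: the find() and the emplace() must happen inside the same critical
section, or two instances of the same model loading in parallel could both
miss the lookup and insert competing queues for one key. A generic sketch of
the pattern with hypothetical names (Registry and EnsureQueue are
illustrative, not Triton code):

#include <map>
#include <memory>
#include <mutex>

struct Registry {
  std::mutex mu_;
  std::map<int, std::unique_ptr<int>> queues_;

  // Safe to call concurrently for the same key: the lookup and the
  // insertion form one critical section, so exactly one caller inserts.
  void EnsureQueue(int key, int initial)
  {
    std::lock_guard<std::mutex> lk(mu_);
    if (queues_.find(key) == queues_.end()) {
      queues_.emplace(key, std::make_unique<int>(initial));
    }
  }
};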