Skip to content

Commit

Permalink
Support paging in TfrtSavedModelServable.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 638471838
  • Loading branch information
SiqiaoWu1993 authored and tensorflow-copybara committed May 30, 2024
1 parent eeac086 commit 993a53c
Show file tree
Hide file tree
Showing 6 changed files with 87 additions and 14 deletions.
3 changes: 3 additions & 0 deletions tensorflow_serving/servables/tensorflow/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -1712,14 +1712,17 @@ cc_library(
"//tensorflow_serving/apis:inference_cc_proto",
"//tensorflow_serving/apis:predict_cc_proto",
"//tensorflow_serving/apis:regression_cc_proto",
"@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/functional:any_invocable",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/synchronization",
"@com_google_absl//absl/time",
"@local_tsl//tsl/platform:errors",
"@local_tsl//tsl/platform:statusor",
"@local_tsl//tsl/platform:threadpool_options",
"@org_tensorflow//tensorflow/cc/saved_model:signature_constants",
"@org_tensorflow//tensorflow/core:framework",
Expand Down
27 changes: 15 additions & 12 deletions tensorflow_serving/servables/tensorflow/tfrt_saved_model_factory.cc
Original file line number Diff line number Diff line change
Expand Up @@ -31,12 +31,12 @@ limitations under the License.
#include "tensorflow/compiler/mlir/tfrt/translate/tfrt_compile_options.h"
#include "tensorflow/core/kernels/batching_util/shared_batch_scheduler.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/protobuf/config.pb.h"
#include "tensorflow/core/protobuf/meta_graph.pb.h"
#include "tensorflow/core/public/session_options.h"
#include "tensorflow/core/tfrt/runtime/runtime.h"
#include "tensorflow/core/tfrt/saved_model/saved_model.h"
#include "tsl/platform/casts.h"
#include "tsl/platform/env.h"
#include "tsl/platform/errors.h"
#include "tensorflow_serving/batching/tfrt_saved_model_with_batching.h"
Expand Down Expand Up @@ -104,7 +104,7 @@ absl::Status WrapSavedModelForBatching(
std::unique_ptr<BatchScheduler<SavedModelBatchingTask>>* queue) {
TF_RETURN_IF_ERROR(batch_scheduler->AddQueue(
queue_options, process_batch_callback, queue));
return OkStatus();
return absl::OkStatus();
};
std::vector<FuncNameWithBatchingSchedulerCreator>
func_name_with_batching_scheduler_creator;
Expand Down Expand Up @@ -287,26 +287,29 @@ absl::Status TfrtSavedModelFactory::CreateTfrtSavedModelWithMetadata(

MaybePublishMLMDStreamz(path, metadata.servable_id.name,
metadata.servable_id.version);
TF_ASSIGN_OR_RETURN(auto saved_model_config,
LoadSavedModelConfigOrDefault(path));

*servable = std::make_unique<TfrtSavedModelServable>(
metadata.servable_id.name, metadata.servable_id.version, config_,
saved_model_config, std::move(saved_model), thread_pool_factory_.get(),
recorder_creator_);
TfrtSavedModelServable* tfrt_servable =
down_cast<TfrtSavedModelServable*>(servable->get());

if (config().enable_model_warmup()) {
auto* warmup_options = mutable_config().mutable_model_warmup_options();
warmup_options->set_model_name(metadata.servable_id.name);
warmup_options->set_model_version(metadata.servable_id.version);
TF_RETURN_IF_ERROR(RunSavedModelWarmup(
*warmup_options, path, config().lazy_init_threshold(),
config().skip_warmup_requests_if_initialized(), saved_model.get()));
config().skip_warmup_requests_if_initialized(),
&tfrt_servable->saved_model()));
if (config().freeze_after_init()) {
TF_RETURN_IF_ERROR(Freeze(*saved_model));
TF_RETURN_IF_ERROR(Freeze(tfrt_servable->saved_model()));
}
}

TF_ASSIGN_OR_RETURN(auto saved_model_config,
LoadSavedModelConfigOrDefault(path));

*servable = std::make_unique<TfrtSavedModelServable>(
metadata.servable_id.name, metadata.servable_id.version, config_,
saved_model_config, std::move(saved_model), thread_pool_factory_.get(),
recorder_creator_);

return absl::OkStatus();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ class TfrtSavedModelFactory {
/// @param metadata Metadata to be associated with the saved_model.
/// @param path Path to the model.
/// @param servable Newly created Servable if the returned Status is OK.
absl::Status CreateTfrtSavedModelWithMetadata(
virtual absl::Status CreateTfrtSavedModelWithMetadata(
const Loader::Metadata& metadata, const string& path,
std::unique_ptr<Servable>* servable);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,10 @@ message TfrtSavedModelConfig {
// Whether to enable core selector for TFRT+IFRT. It is only effective when
// tfrt_use_ifrt is true.
bool enable_tfrt_use_ifrt_core_selector = 2021;

// Whether to enable paging. This should only be true when using the Pathways
// backend.
bool enable_paging = 2022;
}

// Config proto for TfrtSavedModelSourceAdapter.
Expand Down
34 changes: 33 additions & 1 deletion tensorflow_serving/servables/tensorflow/tfrt_servable.cc
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,11 @@ limitations under the License.
#include "absl/status/statusor.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/string_view.h"
#include "absl/synchronization/mutex.h"
#include "absl/time/time.h"
#include "tensorflow/cc/saved_model/signature_constants.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/platform/statusor.h"
#include "tensorflow/core/platform/tracing.h" // NOLINT
#include "tensorflow/core/tfrt/saved_model/saved_model.h"
#include "tsl/platform/errors.h"
Expand Down Expand Up @@ -199,6 +199,38 @@ absl::Status TfrtSavedModelServable::MultiInference(
saved_model_.get(), request, response);
}

// Suspends this servable by invoking the callback installed via
// set_suspend_fn().
//
// Returns:
//   - UnimplementedError if no suspend callback has been installed.
//   - OkStatus, without invoking the callback, if the servable is already
//     marked as suspended.
//   - Otherwise the status returned by the suspend callback; the servable is
//     marked as suspended only when that status is OK.
//
// `paging_mu_` is held for the whole call, including while the callback runs.
absl::Status TfrtSavedModelServable::Suspend() {
  TRACELITERAL("TfrtSavedModelServable::Suspend");
  absl::MutexLock paging_lock(&paging_mu_);
  if (!suspend_fn_) {
    return absl::UnimplementedError("Suspend is not implemented");
  }
  if (!suspended_) {
    absl::Status suspend_status = suspend_fn_(this);
    if (!suspend_status.ok()) {
      // Leave the flag untouched so a later Suspend() tries again.
      return suspend_status;
    }
    suspended_ = true;
  }
  return absl::OkStatus();
}

// Resumes a suspended servable by invoking the callback installed via
// set_resume_fn().
//
// Returns:
//   - UnimplementedError if no resume callback has been installed.
//   - OkStatus, without invoking the callback, if the servable is not
//     currently marked as suspended.
//   - Otherwise the status returned by the resume callback.
//
// `paging_mu_` is held for the whole call, including while the callback runs.
absl::Status TfrtSavedModelServable::Resume() {
  TRACELITERAL("TfrtSavedModelServable::Resume");
  absl::MutexLock lock(&paging_mu_);
  if (!resume_fn_) {
    return absl::UnimplementedError("Resume is not implemented");
  }
  if (!suspended_) {
    return absl::OkStatus();
  }
  absl::Status status = resume_fn_(this);
  // Clear the suspended flag only when the resume actually succeeded. On
  // failure the servable remains marked as suspended, so Resume() can be
  // retried and Suspend() stays a no-op.
  if (status.ok()) {
    suspended_ = false;
  }
  return status;
}

namespace {

absl::Status ValidateGetModelMetadataRequest(
Expand Down
31 changes: 31 additions & 0 deletions tensorflow_serving/servables/tensorflow/tfrt_servable.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,15 @@ limitations under the License.
#include <functional>
#include <memory>
#include <string>
#include <utility>

#include "absl/base/thread_annotations.h"
#include "absl/container/flat_hash_set.h"
#include "absl/functional/any_invocable.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
#include "absl/synchronization/mutex.h"
#include "tensorflow/core/tfrt/saved_model/saved_model.h"
#include "tensorflow_serving/apis/classification.pb.h"
#include "tensorflow_serving/apis/get_model_metadata.pb.h"
Expand Down Expand Up @@ -94,8 +97,26 @@ class TfrtSavedModelServable : public Servable {
absl::Status GetModelMetadata(const GetModelMetadataRequest& request,
GetModelMetadataResponse* response) override;

// Always reports that this servable type supports paging. Suspend() and
// Resume() still return an Unimplemented error until the corresponding
// callback has been installed via set_suspend_fn() / set_resume_fn().
bool SupportsPaging() const override { return true; }

// Invokes the callback installed via set_suspend_fn() unless the servable is
// already marked as suspended. Returns an Unimplemented error when no
// callback has been installed.
absl::Status Suspend() override;

// Invokes the callback installed via set_resume_fn() if the servable is
// currently marked as suspended. Returns an Unimplemented error when no
// callback has been installed.
absl::Status Resume() override;

// Returns a mutable reference to the underlying TFRT SavedModel held by this
// servable. Dereferences `saved_model_` unconditionally.
// NOTE(review): presumably `saved_model_` is never null after construction —
// confirm against the constructor.
tfrt_stub::SavedModel& saved_model() const { return *saved_model_; }

// Installs `resume_fn` as the callback that Resume() invokes, replacing any
// callback installed earlier. The callback receives a pointer to this
// servable. Acquires `paging_mu_` for the duration of the assignment.
void set_resume_fn(
    absl::AnyInvocable<absl::Status(TfrtSavedModelServable*)> resume_fn) {
  absl::MutexLock paging_lock(&paging_mu_);
  resume_fn_ = std::move(resume_fn);
}

// Installs `suspend_fn` as the callback that Suspend() invokes, replacing any
// callback installed earlier. The callback receives a pointer to this
// servable. Acquires `paging_mu_` for the duration of the assignment.
void set_suspend_fn(
    absl::AnyInvocable<absl::Status(TfrtSavedModelServable*)> suspend_fn) {
  absl::MutexLock paging_lock(&paging_mu_);
  suspend_fn_ = std::move(suspend_fn);
}

private:
tfrt_stub::SavedModel::RunOptions GetTFRTSavedModelRunOptions(
const Servable::RunOptions& run_options) const;
Expand All @@ -121,6 +142,16 @@ class TfrtSavedModelServable : public Servable {

std::function<std::unique_ptr<RequestRecorder>(TfrtSavedModelServable&)>
recorder_creator_ = [](TfrtSavedModelServable&) { return nullptr; };

// Callback run by Suspend(); empty until set_suspend_fn() is called, in which
// case Suspend() returns an Unimplemented error.
absl::AnyInvocable<absl::Status(TfrtSavedModelServable*)> suspend_fn_
ABSL_GUARDED_BY(paging_mu_);

// Callback run by Resume(); empty until set_resume_fn() is called, in which
// case Resume() returns an Unimplemented error.
absl::AnyInvocable<absl::Status(TfrtSavedModelServable*)> resume_fn_
ABSL_GUARDED_BY(paging_mu_);

// Set to true by Suspend() when `suspend_fn_` returns OK. Consulted by
// Suspend() and Resume() to skip redundant callback invocations.
bool suspended_ ABSL_GUARDED_BY(paging_mu_) = false;

// Guards the paging callbacks and `suspended_`. Suspend() and Resume() hold
// it while the callback runs.
absl::Mutex paging_mu_;
};

// Creates a TfrtSavedModelServable from `saved_model_dir`.
Expand Down

0 comments on commit 993a53c

Please sign in to comment.