Skip to content

Commit

Permalink
Add timeout support when waiting on servables to load.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 635590139
  • Loading branch information
tensorflower-gardener authored and tensorflow-copybara committed May 20, 2024
1 parent 31ec013 commit 093d841
Show file tree
Hide file tree
Showing 6 changed files with 35 additions and 8 deletions.
1 change: 1 addition & 0 deletions tensorflow_serving/core/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -417,6 +417,7 @@ cc_library(
":servable_id",
":servable_state",
"//tensorflow_serving/util:event_bus",
"@com_google_absl//absl/time",
"@com_google_absl//absl/types:optional",
"@org_tensorflow//tensorflow/core:lib",
],
Expand Down
18 changes: 14 additions & 4 deletions tensorflow_serving/core/servable_state_monitor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ limitations under the License.
#include <utility>
#include <vector>

#include "absl/time/time.h"
#include "tensorflow/core/lib/core/notification.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow_serving/core/servable_state.h"
Expand Down Expand Up @@ -234,11 +235,11 @@ void ServableStateMonitor::Notify(const NotifyFn& notify_fn) {
notify_fns_.push_back(notify_fn);
}

// Blocking variant of NotifyWhenServablesReachState(): registers a callback
// for `servables` / `goal_state`, waits for it to fire for at most `timeout`,
// and returns the value the callback stored in `reached_goal_state`.
bool ServableStateMonitor::WaitUntilServablesReachState(
bool ServableStateMonitor::WaitUntilServablesReachStateWithTimeout(
const std::vector<ServableRequest>& servables,
const ServableState::ManagerState goal_state,
const ServableState::ManagerState goal_state, absl::Duration timeout,
std::map<ServableId, ServableState::ManagerState>* const states_reached) {
bool reached_goal_state;
// Initialised to false so that returning without the callback having run
// (i.e. after a timeout) yields a defined value.
bool reached_goal_state = false;
Notification notified;
NotifyWhenServablesReachState(
servables, goal_state,
Expand All @@ -251,10 +252,19 @@ bool ServableStateMonitor::WaitUntilServablesReachState(
reached_goal_state = incoming_reached_goal_state;
notified.Notify();
});
// NOTE(review): the result of WaitForNotificationWithTimeout() is discarded,
// so a caller cannot distinguish "timed out" from "callback reported false".
// NOTE(review): the callback above assigns to the locals `reached_goal_state`
// and `notified`; its capture list is not visible in this hunk, presumably it
// captures by reference. If the wait times out and the callback runs after
// this function has returned, those writes would hit destroyed stack objects.
// TODO confirm, and consider giving that state shared ownership.
notified.WaitForNotification();
notified.WaitForNotificationWithTimeout(timeout);
return reached_goal_state;
}

bool ServableStateMonitor::WaitUntilServablesReachState(
const std::vector<ServableRequest>& servables,
const ServableState::ManagerState goal_state,
std::map<ServableId, ServableState::ManagerState>* const states_reached) {
return WaitUntilServablesReachStateWithTimeout(
servables, goal_state,
/*timeout=*/absl::InfiniteDuration(), states_reached);
}

// Receives an event-bus entry (`state_and_time`) and does nothing with it:
// this definition has an empty body.
// NOTE(review): presumably an overridable hook invoked before the monitor
// processes the event; confirm against the declaration in the header.
void ServableStateMonitor::PreHandleEvent(
const EventBus<ServableState>::EventAndTime& state_and_time) {}

Expand Down
11 changes: 10 additions & 1 deletion tensorflow_serving/core/servable_state_monitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ limitations under the License.
#include <functional>
#include <map>

#include "absl/time/time.h"
#include "absl/types/optional.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/macros.h"
Expand Down Expand Up @@ -156,11 +157,19 @@ class ServableStateMonitor {
///
/// To understand the return value and the return parameter 'states_reached',
/// please read the documentation on NotifyWhenServablesReachState(...).
/// WaitUntilServablesReachStateWithTimeout and WaitUntilServablesReachState
/// perform the same function, but the former has a timeout while the latter
/// waits indefinitely.
///
/// 'timeout' bounds how long the call blocks. If it elapses before the
/// notification is delivered, the call returns false (the result is
/// initialised to false and only overwritten by the notification callback).
/// Passing absl::InfiniteDuration() gives the behavior of
/// WaitUntilServablesReachState.
/// NOTE(review): on timeout 'states_reached' is presumably left unmodified;
/// confirm against the implementation.
bool WaitUntilServablesReachStateWithTimeout(
const std::vector<ServableRequest>& servables,
ServableState::ManagerState goal_state, absl::Duration timeout,
std::map<ServableId, ServableState::ManagerState>* states_reached =
nullptr) TF_LOCKS_EXCLUDED(mu_) TF_MUST_USE_RESULT;
bool WaitUntilServablesReachState(
const std::vector<ServableRequest>& servables,
ServableState::ManagerState goal_state,
std::map<ServableId, ServableState::ManagerState>* states_reached =
nullptr) TF_LOCKS_EXCLUDED(mu_) TF_MUST_USE_RESULT;
nullptr) TF_MUST_USE_RESULT;

// Subscribes to all servable state changes hitting this monitor. This is
// called after the monitor updates its own state based on the event.
Expand Down
1 change: 1 addition & 0 deletions tensorflow_serving/model_servers/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@ cc_library(
"//tensorflow_serving/util:event_bus",
"//tensorflow_serving/util:unique_ptr_with_deps",
"@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/time",
"@com_google_absl//absl/types:optional",
"@com_google_protobuf//:cc_wkt_protos",
"@org_tensorflow//tensorflow/core:lib",
Expand Down
9 changes: 6 additions & 3 deletions tensorflow_serving/model_servers/server_core.cc
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ limitations under the License.
#include "tensorflow/core/platform/logging.h"
#include "tensorflow_serving/config/file_system_storage_path_source.pb.h"
#include "tensorflow_serving/core/load_servables_fast.h"
#include "tensorflow_serving/core/servable_state_monitor.h"
#include "tensorflow_serving/model_servers/model_platform_types.h"
#include "tensorflow_serving/resources/resource_values.h"
#include "tensorflow_serving/servables/tensorflow/saved_model_bundle_source_adapter.h"
Expand Down Expand Up @@ -296,9 +297,10 @@ Status ServerCore::WaitUntilModelsAvailable(const std::set<string>& models,
awaited_servables.push_back(ServableRequest::Latest(model));
}
std::map<ServableId, ServableState::ManagerState> states_reached;
const bool all_models_available = monitor->WaitUntilServablesReachState(
awaited_servables, ServableState::ManagerState::kAvailable,
&states_reached);
const bool all_models_available =
monitor->WaitUntilServablesReachStateWithTimeout(
awaited_servables, ServableState::ManagerState::kAvailable,
options_.servable_state_waiter_timeout, &states_reached);
if (!all_models_available) {
const int num_unavailable_models = std::count_if(
states_reached.begin(), states_reached.end(),
Expand Down Expand Up @@ -367,6 +369,7 @@ Status ServerCore::AddModelsViaModelConfigList() {
} else {
// Create a fresh servable state monitor, to avoid getting confused if we're
// re-loading a model-version that has previously been unloaded.

ServableStateMonitor fresh_servable_state_monitor(
servable_event_bus_.get());

Expand Down
3 changes: 3 additions & 0 deletions tensorflow_serving/model_servers/server_core.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ limitations under the License.

#include "google/protobuf/any.pb.h"
#include "absl/base/macros.h"
#include "absl/time/time.h"
#include "absl/types/optional.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/cpu_info.h"
Expand Down Expand Up @@ -207,6 +208,8 @@ class ServerCore : public Manager {
// If true, propagate current context to children threads (periodic
// functions) in AspiredVersionsManager.
bool with_current_context = false;

// Timeout handed to
// ServableStateMonitor::WaitUntilServablesReachStateWithTimeout() by
// ServerCore::WaitUntilModelsAvailable(). The default is an infinite
// duration, i.e. the wait has no deadline.
absl::Duration servable_state_waiter_timeout = absl::InfiniteDuration();
};

virtual ~ServerCore() = default;
Expand Down

0 comments on commit 093d841

Please sign in to comment.