Skip to content

Commit

Permalink
Delay restoring thread context until after all client calls (#2798)
Browse files Browse the repository at this point in the history
Signed-off-by: Daniel Widdis <widdis@gmail.com>
  • Loading branch information
dbwiddis authored Aug 4, 2024
1 parent e94acc6 commit 6173a93
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 26 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,7 @@ private void handleConnectorAccessValidationFailure(String connectorId, Exceptio

private void checkForModelsUsingConnector(String connectorId, String tenantId, ActionListener<DeleteResponse> actionListener) {
try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
ActionListener<DeleteResponse> restoringListener = ActionListener.runBefore(actionListener, context::restore);
SearchSourceBuilder sourceBuilder = new SearchSourceBuilder();
sourceBuilder.query(QueryBuilders.matchQuery(MLModel.CONNECTOR_ID_FIELD, connectorId));
if (mlFeatureEnabledSetting.isMultiTenancyEnabled()) {
Expand All @@ -133,26 +134,25 @@ private void checkForModelsUsingConnector(String connectorId, String tenantId, A
sdkClient
.searchDataObjectAsync(searchDataObjectRequest, client.threadPool().executor(GENERAL_THREAD_POOL))
.whenComplete((sr, st) -> {
context.restore();
if (sr != null) {
try {
SearchResponse searchResponse = SearchResponse.fromXContent(sr.parser());
SearchHit[] searchHits = searchResponse.getHits().getHits();
if (searchHits.length == 0) {
deleteConnector(connectorId, actionListener);
deleteConnector(connectorId, restoringListener);
} else {
handleModelsUsingConnector(searchHits, connectorId, actionListener);
handleModelsUsingConnector(searchHits, connectorId, restoringListener);
}
} catch (Exception e) {
log.error("Failed to parse search response", e);
actionListener
restoringListener
.onFailure(
new OpenSearchStatusException("Failed to parse search response", RestStatus.INTERNAL_SERVER_ERROR)
);
}
} else {
Exception cause = SdkClientUtils.unwrapAndConvertToException(st);
handleSearchFailure(connectorId, cause, actionListener);
handleSearchFailure(connectorId, cause, restoringListener);
}
});
} catch (Exception e) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,7 @@ protected void doExecute(Task task, ActionRequest request, ActionListener<Update
try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
connectorAccessControlHelper
.getConnector(sdkClient, client, context, getDataObjectRequest, connectorId, ActionListener.wrap(connector -> {
// context is already restored here
if (TenantAwareHelper
.validateTenantResource(
mlFeatureEnabledSetting,
Expand All @@ -123,7 +124,7 @@ protected void doExecute(Task task, ActionRequest request, ActionListener<Update
listener
)) {
boolean hasPermission = connectorAccessControlHelper.validateConnectorAccess(client, connector);
if (Boolean.TRUE.equals(hasPermission)) {
if (hasPermission) {
connector.update(mlUpdateConnectorAction.getUpdateContent(), mlEngine::encrypt);
connector.validateConnectorURL(trustedConnectorEndpointsRegex);
UpdateDataObjectRequest updateDataObjectRequest = UpdateDataObjectRequest
Expand All @@ -132,7 +133,13 @@ protected void doExecute(Task task, ActionRequest request, ActionListener<Update
.id(connectorId)
.dataObject(connector)
.build();
updateUndeployedConnector(connectorId, updateDataObjectRequest, listener, context);
try (ThreadContext.StoredContext innerContext = client.threadPool().getThreadContext().stashContext()) {
updateUndeployedConnector(
connectorId,
updateDataObjectRequest,
ActionListener.runBefore(listener, innerContext::restore)
);
}
} else {
listener
.onFailure(
Expand All @@ -155,8 +162,7 @@ protected void doExecute(Task task, ActionRequest request, ActionListener<Update
private void updateUndeployedConnector(
String connectorId,
UpdateDataObjectRequest updateDataObjectRequest,
ActionListener<UpdateResponse> listener,
ThreadContext.StoredContext context
ActionListener<UpdateResponse> listener
) {
SearchSourceBuilder sourceBuilder = new SearchSourceBuilder();
BoolQueryBuilder boolQueryBuilder = QueryBuilders.boolQuery();
Expand All @@ -180,11 +186,7 @@ private void updateUndeployedConnector(
sdkClient
.updateDataObjectAsync(updateDataObjectRequest, client.threadPool().executor(GENERAL_THREAD_POOL))
.whenComplete((r, throwable) -> {
handleUpdateDataObjectCompletionStage(
r,
throwable,
getUpdateResponseListener(connectorId, listener, context)
);
handleUpdateDataObjectCompletionStage(r, throwable, getUpdateResponseListener(connectorId, listener));
});
} else {
log.error(searchHits.length + " models are still using this connector, please undeploy the models first!");
Expand Down Expand Up @@ -214,11 +216,7 @@ private void updateUndeployedConnector(
sdkClient
.updateDataObjectAsync(updateDataObjectRequest, client.threadPool().executor(GENERAL_THREAD_POOL))
.whenComplete((r, throwable) -> {
handleUpdateDataObjectCompletionStage(
r,
throwable,
getUpdateResponseListener(connectorId, listener, context)
);
handleUpdateDataObjectCompletionStage(r, throwable, getUpdateResponseListener(connectorId, listener));
});
return;
} else {
Expand Down Expand Up @@ -246,12 +244,8 @@ private void handleUpdateDataObjectCompletionStage(
}
}

private ActionListener<UpdateResponse> getUpdateResponseListener(
String connectorId,
ActionListener<UpdateResponse> actionListener,
ThreadContext.StoredContext context
) {
return ActionListener.runBefore(ActionListener.wrap(updateResponse -> {
private ActionListener<UpdateResponse> getUpdateResponseListener(String connectorId, ActionListener<UpdateResponse> actionListener) {
return ActionListener.wrap(updateResponse -> {
if (updateResponse != null && updateResponse.getResult() != DocWriteResponse.Result.UPDATED) {
log.error("Failed to update the connector with ID: {}", connectorId);
actionListener.onResponse(updateResponse);
Expand All @@ -262,6 +256,6 @@ private ActionListener<UpdateResponse> getUpdateResponseListener(
}, exception -> {
log.error("Failed to update ML connector with ID {}. Details: {}", connectorId, exception);
actionListener.onFailure(exception);
}), context::restore);
});
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,15 @@ public void getConnector(Client client, String connectorId, ActionListener<Conne
}));
}

/**
* Gets a connector with the provided clients.
* @param sdkClient The SDKClient
* @param client The OpenSearch client for thread pool management
* @param context The Stored Context. Executing this method will restore this context.
* @param getDataObjectRequest The get request
* @param connectorId The connector Id
* @param listener the action listener to complete with the GetResponse or Exception
*/
public void getConnector(
SdkClient sdkClient,
Client client,
Expand Down

0 comments on commit 6173a93

Please sign in to comment.