Skip to content

Commit

Permalink
fix workerthread hang due to grpc client cancel connection (pytorch#1854)
Browse files Browse the repository at this point in the history

* check grpc client cancel connection

* grpc setOnCancelHandler

Co-authored-by: Aaqib <maaquib@gmail.com>
  • Loading branch information
2 people authored and jagadeeshi2i committed Nov 1, 2022
1 parent b7b1060 commit 944c48e
Show file tree
Hide file tree
Showing 2 changed files with 80 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import com.google.protobuf.ByteString;
import com.google.protobuf.Empty;
import io.grpc.Status;
import io.grpc.stub.ServerCallStreamObserver;
import io.grpc.stub.StreamObserver;
import java.net.HttpURLConnection;
import java.util.Map;
Expand All @@ -25,11 +26,23 @@
import org.pytorch.serve.util.messages.RequestInput;
import org.pytorch.serve.util.messages.WorkerCommands;
import org.pytorch.serve.wlm.ModelManager;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class InferenceImpl extends InferenceAPIsServiceImplBase {
private static final Logger logger = LoggerFactory.getLogger(InferenceImpl.class);

@Override
public void ping(Empty request, StreamObserver<TorchServeHealthResponse> responseObserver) {
((ServerCallStreamObserver<TorchServeHealthResponse>) responseObserver)
.setOnCancelHandler(
() -> {
logger.warn("grpc client call already cancelled");
responseObserver.onError(
io.grpc.Status.CANCELLED
.withDescription("call already cancelled")
.asRuntimeException());
});
Runnable r =
() -> {
String response = ApiUtils.getWorkerStatus();
Expand All @@ -49,6 +62,15 @@ public void ping(Empty request, StreamObserver<TorchServeHealthResponse> respons
@Override
public void predictions(
PredictionsRequest request, StreamObserver<PredictionResponse> responseObserver) {
((ServerCallStreamObserver<PredictionResponse>) responseObserver)
.setOnCancelHandler(
() -> {
logger.warn("grpc client call already cancelled");
responseObserver.onError(
io.grpc.Status.CANCELLED
.withDescription("call already cancelled")
.asRuntimeException());
});
String modelName = request.getModelName();
String modelVersion = request.getModelVersion();

Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package org.pytorch.serve.grpcimpl;

import io.grpc.Status;
import io.grpc.stub.ServerCallStreamObserver;
import io.grpc.stub.StreamObserver;
import java.util.UUID;
import java.util.concurrent.ExecutionException;
Expand All @@ -26,12 +27,24 @@
import org.pytorch.serve.util.JsonUtils;
import org.pytorch.serve.util.messages.RequestInput;
import org.pytorch.serve.wlm.ModelManager;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class ManagementImpl extends ManagementAPIsServiceImplBase {
private static final Logger logger = LoggerFactory.getLogger(ManagementImpl.class);

@Override
public void describeModel(
DescribeModelRequest request, StreamObserver<ManagementResponse> responseObserver) {
((ServerCallStreamObserver<ManagementResponse>) responseObserver)
.setOnCancelHandler(
() -> {
logger.warn("grpc client call already cancelled");
responseObserver.onError(
io.grpc.Status.CANCELLED
.withDescription("call already cancelled")
.asRuntimeException());
});
String requestId = UUID.randomUUID().toString();
RequestInput input = new RequestInput(requestId);
String modelName = request.getModelName();
Expand Down Expand Up @@ -67,6 +80,15 @@ public void describeModel(
@Override
public void listModels(
ListModelsRequest request, StreamObserver<ManagementResponse> responseObserver) {
((ServerCallStreamObserver<ManagementResponse>) responseObserver)
.setOnCancelHandler(
() -> {
logger.warn("grpc client call already cancelled");
responseObserver.onError(
io.grpc.Status.CANCELLED
.withDescription("call already cancelled")
.asRuntimeException());
});
int limit = request.getLimit();
int pageToken = request.getNextPageToken();

Expand All @@ -77,6 +99,15 @@ public void listModels(
@Override
public void registerModel(
RegisterModelRequest request, StreamObserver<ManagementResponse> responseObserver) {
((ServerCallStreamObserver<ManagementResponse>) responseObserver)
.setOnCancelHandler(
() -> {
logger.warn("grpc client call already cancelled");
responseObserver.onError(
io.grpc.Status.CANCELLED
.withDescription("call already cancelled")
.asRuntimeException());
});
org.pytorch.serve.http.messages.RegisterModelRequest registerModelRequest =
new org.pytorch.serve.http.messages.RegisterModelRequest(request);

Expand All @@ -98,6 +129,15 @@ public void registerModel(
@Override
public void scaleWorker(
ScaleWorkerRequest request, StreamObserver<ManagementResponse> responseObserver) {
((ServerCallStreamObserver<ManagementResponse>) responseObserver)
.setOnCancelHandler(
() -> {
logger.warn("grpc client call already cancelled");
responseObserver.onError(
io.grpc.Status.CANCELLED
.withDescription("call already cancelled")
.asRuntimeException());
});
int minWorkers = GRPCUtils.getRegisterParam(request.getMinWorker(), 1);
int maxWorkers = GRPCUtils.getRegisterParam(request.getMaxWorker(), minWorkers);
String modelName = GRPCUtils.getRegisterParam(request.getModelName(), null);
Expand Down Expand Up @@ -128,6 +168,15 @@ public void scaleWorker(
@Override
public void setDefault(
SetDefaultRequest request, StreamObserver<ManagementResponse> responseObserver) {
((ServerCallStreamObserver<ManagementResponse>) responseObserver)
.setOnCancelHandler(
() -> {
logger.warn("grpc client call already cancelled");
responseObserver.onError(
io.grpc.Status.CANCELLED
.withDescription("call already cancelled")
.asRuntimeException());
});
String modelName = request.getModelName();
String newModelVersion = request.getModelVersion();

Expand All @@ -142,6 +191,15 @@ public void setDefault(
@Override
public void unregisterModel(
UnregisterModelRequest request, StreamObserver<ManagementResponse> responseObserver) {
((ServerCallStreamObserver<ManagementResponse>) responseObserver)
.setOnCancelHandler(
() -> {
logger.warn("grpc client call already cancelled");
responseObserver.onError(
io.grpc.Status.CANCELLED
.withDescription("call already cancelled")
.asRuntimeException());
});
try {
String modelName = request.getModelName();
if (modelName == null || ("").equals(modelName)) {
Expand Down

0 comments on commit 944c48e

Please sign in to comment.