From 6dbd1b9d5a3a0a8d9c64f161e23b7f355a56e588 Mon Sep 17 00:00:00 2001 From: John Cormie Date: Wed, 14 Aug 2024 09:11:49 -0700 Subject: [PATCH] Add newAttachMetadataServerInterceptor() MetadataUtil (#11458) --- .../main/java/io/grpc/stub/MetadataUtils.java | 64 +++++++ .../java/io/grpc/stub/MetadataUtilsTest.java | 175 ++++++++++++++++++ 2 files changed, 239 insertions(+) create mode 100644 stub/src/test/java/io/grpc/stub/MetadataUtilsTest.java diff --git a/stub/src/main/java/io/grpc/stub/MetadataUtils.java b/stub/src/main/java/io/grpc/stub/MetadataUtils.java index addf54c0f81..4208d3ca652 100644 --- a/stub/src/main/java/io/grpc/stub/MetadataUtils.java +++ b/stub/src/main/java/io/grpc/stub/MetadataUtils.java @@ -22,10 +22,15 @@ import io.grpc.Channel; import io.grpc.ClientCall; import io.grpc.ClientInterceptor; +import io.grpc.ExperimentalApi; import io.grpc.ForwardingClientCall.SimpleForwardingClientCall; import io.grpc.ForwardingClientCallListener.SimpleForwardingClientCallListener; +import io.grpc.ForwardingServerCall.SimpleForwardingServerCall; import io.grpc.Metadata; import io.grpc.MethodDescriptor; +import io.grpc.ServerCall; +import io.grpc.ServerCallHandler; +import io.grpc.ServerInterceptor; import io.grpc.Status; import java.util.concurrent.atomic.AtomicReference; @@ -143,4 +148,63 @@ public void onClose(Status status, Metadata trailers) { } } } + + /** + * Returns a ServerInterceptor that adds the specified Metadata to every response stream, one way + * or another. + * + *

If, absent this interceptor, a stream would have headers, 'extras' will be added to those + * headers. Otherwise, 'extras' will be sent as trailers. This pattern is useful when you have + * some fixed information, server identity say, that should be included no matter how the call + * turns out. The fallback to trailers avoids artificially committing clients to error responses + * that could otherwise be retried (see https://grpc.io/docs/guides/retry/ for more). + * + *

For correct operation, be sure to arrange for this interceptor to run *before* any others + * that might add headers. + * + * @param extras the Metadata to be added to each stream. Caller gives up ownership. + */ + @ExperimentalApi("https://github.com/grpc/grpc-java/issues/11462") + public static ServerInterceptor newAttachMetadataServerInterceptor(Metadata extras) { + return new MetadataAttachingServerInterceptor(extras); + } + + private static final class MetadataAttachingServerInterceptor implements ServerInterceptor { + + private final Metadata extras; + + MetadataAttachingServerInterceptor(Metadata extras) { + this.extras = extras; + } + + @Override + public ServerCall.Listener interceptCall( + ServerCall call, Metadata headers, ServerCallHandler next) { + return next.startCall(new MetadataAttachingServerCall<>(call), headers); + } + + final class MetadataAttachingServerCall + extends SimpleForwardingServerCall { + boolean headersSent; + + MetadataAttachingServerCall(ServerCall delegate) { + super(delegate); + } + + @Override + public void sendHeaders(Metadata headers) { + headers.merge(extras); + headersSent = true; + super.sendHeaders(headers); + } + + @Override + public void close(Status status, Metadata trailers) { + if (!headersSent) { + trailers.merge(extras); + } + super.close(status, trailers); + } + } + } } diff --git a/stub/src/test/java/io/grpc/stub/MetadataUtilsTest.java b/stub/src/test/java/io/grpc/stub/MetadataUtilsTest.java new file mode 100644 index 00000000000..f9890ac0433 --- /dev/null +++ b/stub/src/test/java/io/grpc/stub/MetadataUtilsTest.java @@ -0,0 +1,175 @@ +/* + * Copyright 2024 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.stub; + +import static com.google.common.truth.Truth.assertThat; +import static io.grpc.stub.MetadataUtils.newAttachMetadataServerInterceptor; +import static io.grpc.stub.MetadataUtils.newCaptureMetadataInterceptor; +import static org.junit.Assert.fail; + +import com.google.common.collect.ImmutableList; +import io.grpc.CallOptions; +import io.grpc.ManagedChannel; +import io.grpc.Metadata; +import io.grpc.MethodDescriptor; +import io.grpc.ServerCallHandler; +import io.grpc.ServerInterceptors; +import io.grpc.ServerMethodDefinition; +import io.grpc.ServerServiceDefinition; +import io.grpc.Status; +import io.grpc.Status.Code; +import io.grpc.StatusRuntimeException; +import io.grpc.StringMarshaller; +import io.grpc.inprocess.InProcessChannelBuilder; +import io.grpc.inprocess.InProcessServerBuilder; +import io.grpc.testing.GrpcCleanupRule; +import java.io.IOException; +import java.util.Iterator; +import java.util.concurrent.atomic.AtomicReference; +import org.junit.Rule; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public class MetadataUtilsTest { + + @Rule public final GrpcCleanupRule grpcCleanup = new GrpcCleanupRule(); + + private static final String SERVER_NAME = "test"; + private static final Metadata.Key FOO_KEY = + Metadata.Key.of("foo-key", Metadata.ASCII_STRING_MARSHALLER); + + private final MethodDescriptor echoMethod = + MethodDescriptor.newBuilder(StringMarshaller.INSTANCE, StringMarshaller.INSTANCE) + .setFullMethodName("test/echo") + .setType(MethodDescriptor.MethodType.UNARY) + .build(); + + private final ServerCallHandler echoCallHandler = + ServerCalls.asyncUnaryCall( + (req, respObserver) -> { + respObserver.onNext(req); + respObserver.onCompleted(); + }); + + MethodDescriptor echoServerStreamingMethod = + MethodDescriptor.newBuilder(StringMarshaller.INSTANCE, StringMarshaller.INSTANCE) + .setFullMethodName("test/echoStream") + .setType(MethodDescriptor.MethodType.SERVER_STREAMING) + .build(); + + private final AtomicReference trailersCapture = new AtomicReference<>(); + private final AtomicReference headersCapture = new AtomicReference<>(); + + @Test + public void shouldAttachHeadersToResponse() throws IOException { + Metadata extras = new Metadata(); + extras.put(FOO_KEY, "foo-value"); + + ServerServiceDefinition serviceDef = + ServerInterceptors.intercept( + ServerServiceDefinition.builder("test").addMethod(echoMethod, echoCallHandler).build(), + ImmutableList.of(newAttachMetadataServerInterceptor(extras))); + + grpcCleanup.register(newInProcessServerBuilder().addService(serviceDef).build().start()); + ManagedChannel channel = + grpcCleanup.register( + newInProcessChannelBuilder() + .intercept(newCaptureMetadataInterceptor(headersCapture, trailersCapture)) + .build()); + + String response = + ClientCalls.blockingUnaryCall(channel, echoMethod, CallOptions.DEFAULT, "hello"); + assertThat(response).isEqualTo("hello"); + assertThat(trailersCapture.get() == null || !trailersCapture.get().containsKey(FOO_KEY)) + .isTrue(); + assertThat(headersCapture.get().get(FOO_KEY)).isEqualTo("foo-value"); + } + + @Test + public void shouldAttachTrailersWhenNoResponse() throws IOException { + Metadata extras = new Metadata(); + extras.put(FOO_KEY, "foo-value"); + + ServerServiceDefinition serviceDef = + ServerInterceptors.intercept( + ServerServiceDefinition.builder("test") + .addMethod( + ServerMethodDefinition.create( + echoServerStreamingMethod, + ServerCalls.asyncUnaryCall( + (req, respObserver) -> respObserver.onCompleted()))) + .build(), + ImmutableList.of(newAttachMetadataServerInterceptor(extras))); + grpcCleanup.register(newInProcessServerBuilder().addService(serviceDef).build().start()); + + ManagedChannel channel = + grpcCleanup.register( + newInProcessChannelBuilder() + .intercept(newCaptureMetadataInterceptor(headersCapture, trailersCapture)) + .build()); + + Iterator response = + ClientCalls.blockingServerStreamingCall( + channel, echoServerStreamingMethod, CallOptions.DEFAULT, "hello"); + assertThat(response.hasNext()).isFalse(); + assertThat(headersCapture.get() == null || !headersCapture.get().containsKey(FOO_KEY)).isTrue(); + assertThat(trailersCapture.get().get(FOO_KEY)).isEqualTo("foo-value"); + } + + @Test + public void shouldAttachTrailersToErrorResponse() throws IOException { + Metadata extras = new Metadata(); + extras.put(FOO_KEY, "foo-value"); + + ServerServiceDefinition serviceDef = + ServerInterceptors.intercept( + ServerServiceDefinition.builder("test") + .addMethod( + echoMethod, + ServerCalls.asyncUnaryCall( + (req, respObserver) -> + respObserver.onError(Status.INVALID_ARGUMENT.asRuntimeException()))) + .build(), + ImmutableList.of(newAttachMetadataServerInterceptor(extras))); + grpcCleanup.register(newInProcessServerBuilder().addService(serviceDef).build().start()); + + ManagedChannel channel = + grpcCleanup.register( + newInProcessChannelBuilder() + .intercept(newCaptureMetadataInterceptor(headersCapture, trailersCapture)) + .build()); + try { + ClientCalls.blockingUnaryCall(channel, echoMethod, CallOptions.DEFAULT, "hello"); + fail(); + } catch (StatusRuntimeException e) { + assertThat(e.getStatus()).isNotNull(); + assertThat(e.getStatus().getCode()).isEqualTo(Code.INVALID_ARGUMENT); + } + assertThat(headersCapture.get() == null || !headersCapture.get().containsKey(FOO_KEY)).isTrue(); + assertThat(trailersCapture.get().get(FOO_KEY)).isEqualTo("foo-value"); + } + + private static InProcessServerBuilder newInProcessServerBuilder() { + return InProcessServerBuilder.forName(SERVER_NAME).directExecutor(); + } + + private static InProcessChannelBuilder newInProcessChannelBuilder() { + return InProcessChannelBuilder.forName(SERVER_NAME).directExecutor(); + } +}