From 9aa97f9cb41be5555bb5d255eca0eacd1fc00422 Mon Sep 17 00:00:00 2001 From: Menghan Li Date: Wed, 10 Jun 2020 18:00:24 -0700 Subject: [PATCH] stream: fix calloption.After() race in finish (#3672) --- rpc_util.go | 45 ++++++++++++++------------------- stream.go | 20 +++++++-------- test/end2end_test.go | 59 ++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 87 insertions(+), 37 deletions(-) diff --git a/rpc_util.go b/rpc_util.go index cf9dbe7fd97d..8644b8a7d0da 100644 --- a/rpc_util.go +++ b/rpc_util.go @@ -155,7 +155,6 @@ func (d *gzipDecompressor) Type() string { type callInfo struct { compressorType string failFast bool - stream ClientStream maxReceiveMessageSize *int maxSendMessageSize *int creds credentials.PerRPCCredentials @@ -180,7 +179,7 @@ type CallOption interface { // after is called after the call has completed. after cannot return an // error, so any failures should be reported via output parameters. - after(*callInfo) + after(*callInfo, *csAttempt) } // EmptyCallOption does not alter the Call configuration. @@ -188,8 +187,8 @@ type CallOption interface { // by interceptors. type EmptyCallOption struct{} -func (EmptyCallOption) before(*callInfo) error { return nil } -func (EmptyCallOption) after(*callInfo) {} +func (EmptyCallOption) before(*callInfo) error { return nil } +func (EmptyCallOption) after(*callInfo, *csAttempt) {} // Header returns a CallOptions that retrieves the header metadata // for a unary RPC. @@ -205,10 +204,8 @@ type HeaderCallOption struct { } func (o HeaderCallOption) before(c *callInfo) error { return nil } -func (o HeaderCallOption) after(c *callInfo) { - if c.stream != nil { - *o.HeaderAddr, _ = c.stream.Header() - } +func (o HeaderCallOption) after(c *callInfo, attempt *csAttempt) { + *o.HeaderAddr, _ = attempt.s.Header() } // Trailer returns a CallOptions that retrieves the trailer metadata @@ -225,10 +222,8 @@ type TrailerCallOption struct { } func (o TrailerCallOption) before(c *callInfo) error { return nil } -func (o TrailerCallOption) after(c *callInfo) { - if c.stream != nil { - *o.TrailerAddr = c.stream.Trailer() - } +func (o TrailerCallOption) after(c *callInfo, attempt *csAttempt) { + *o.TrailerAddr = attempt.s.Trailer() } // Peer returns a CallOption that retrieves peer information for a unary RPC. @@ -245,11 +240,9 @@ type PeerCallOption struct { } func (o PeerCallOption) before(c *callInfo) error { return nil } -func (o PeerCallOption) after(c *callInfo) { - if c.stream != nil { - if x, ok := peer.FromContext(c.stream.Context()); ok { - *o.PeerAddr = *x - } +func (o PeerCallOption) after(c *callInfo, attempt *csAttempt) { + if x, ok := peer.FromContext(attempt.s.Context()); ok { + *o.PeerAddr = *x } } @@ -285,7 +278,7 @@ func (o FailFastCallOption) before(c *callInfo) error { c.failFast = o.FailFast return nil } -func (o FailFastCallOption) after(c *callInfo) {} +func (o FailFastCallOption) after(c *callInfo, attempt *csAttempt) {} // MaxCallRecvMsgSize returns a CallOption which sets the maximum message size // in bytes the client can receive. @@ -304,7 +297,7 @@ func (o MaxRecvMsgSizeCallOption) before(c *callInfo) error { c.maxReceiveMessageSize = &o.MaxRecvMsgSize return nil } -func (o MaxRecvMsgSizeCallOption) after(c *callInfo) {} +func (o MaxRecvMsgSizeCallOption) after(c *callInfo, attempt *csAttempt) {} // MaxCallSendMsgSize returns a CallOption which sets the maximum message size // in bytes the client can send. @@ -323,7 +316,7 @@ func (o MaxSendMsgSizeCallOption) before(c *callInfo) error { c.maxSendMessageSize = &o.MaxSendMsgSize return nil } -func (o MaxSendMsgSizeCallOption) after(c *callInfo) {} +func (o MaxSendMsgSizeCallOption) after(c *callInfo, attempt *csAttempt) {} // PerRPCCredentials returns a CallOption that sets credentials.PerRPCCredentials // for a call. @@ -342,7 +335,7 @@ func (o PerRPCCredsCallOption) before(c *callInfo) error { c.creds = o.Creds return nil } -func (o PerRPCCredsCallOption) after(c *callInfo) {} +func (o PerRPCCredsCallOption) after(c *callInfo, attempt *csAttempt) {} // UseCompressor returns a CallOption which sets the compressor used when // sending the request. If WithCompressor is also set, UseCompressor has @@ -363,7 +356,7 @@ func (o CompressorCallOption) before(c *callInfo) error { c.compressorType = o.CompressorType return nil } -func (o CompressorCallOption) after(c *callInfo) {} +func (o CompressorCallOption) after(c *callInfo, attempt *csAttempt) {} // CallContentSubtype returns a CallOption that will set the content-subtype // for a call. For example, if content-subtype is "json", the Content-Type over @@ -396,7 +389,7 @@ func (o ContentSubtypeCallOption) before(c *callInfo) error { c.contentSubtype = o.ContentSubtype return nil } -func (o ContentSubtypeCallOption) after(c *callInfo) {} +func (o ContentSubtypeCallOption) after(c *callInfo, attempt *csAttempt) {} // ForceCodec returns a CallOption that will set the given Codec to be // used for all request and response messages for a call. The result of calling @@ -428,7 +421,7 @@ func (o ForceCodecCallOption) before(c *callInfo) error { c.codec = o.Codec return nil } -func (o ForceCodecCallOption) after(c *callInfo) {} +func (o ForceCodecCallOption) after(c *callInfo, attempt *csAttempt) {} // CallCustomCodec behaves like ForceCodec, but accepts a grpc.Codec instead of // an encoding.Codec. @@ -450,7 +443,7 @@ func (o CustomCodecCallOption) before(c *callInfo) error { c.codec = o.Codec return nil } -func (o CustomCodecCallOption) after(c *callInfo) {} +func (o CustomCodecCallOption) after(c *callInfo, attempt *csAttempt) {} // MaxRetryRPCBufferSize returns a CallOption that limits the amount of memory // used for buffering this RPC's requests for retry purposes. @@ -471,7 +464,7 @@ func (o MaxRetryRPCBufferSizeCallOption) before(c *callInfo) error { c.maxRetryRPCBufferSize = o.MaxRetryRPCBufferSize return nil } -func (o MaxRetryRPCBufferSizeCallOption) after(c *callInfo) {} +func (o MaxRetryRPCBufferSizeCallOption) after(c *callInfo, attempt *csAttempt) {} // The format of the payload: compressed or not? type payloadFormat uint8 diff --git a/stream.go b/stream.go index 934ef68321cc..62d51334488d 100644 --- a/stream.go +++ b/stream.go @@ -277,7 +277,6 @@ func newClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, meth } cs.binlog = binarylog.GetMethodLogger(method) - cs.callInfo.stream = cs // Only this initial attempt has stats/tracing. // TODO(dfawley): move to newAttempt when per-attempt stats are implemented. if err := cs.newAttemptLocked(sh, trInfo); err != nil { @@ -799,6 +798,15 @@ func (cs *clientStream) finish(err error) { } cs.finished = true cs.commitAttemptLocked() + if cs.attempt != nil { + cs.attempt.finish(err) + // after functions all rely upon having a stream. + if cs.attempt.s != nil { + for _, o := range cs.opts { + o.after(cs.callInfo, cs.attempt) + } + } + } cs.mu.Unlock() // For binary logging. only log cancel in finish (could be caused by RPC ctx // canceled or ClientConn closed). Trailer will be logged in RecvMsg. @@ -820,15 +828,6 @@ func (cs *clientStream) finish(err error) { cs.cc.incrCallsSucceeded() } } - if cs.attempt != nil { - cs.attempt.finish(err) - // after functions all rely upon having a stream. - if cs.attempt.s != nil { - for _, o := range cs.opts { - o.after(cs.callInfo) - } - } - } cs.cancel() } @@ -1066,7 +1065,6 @@ func newNonRetryClientStream(ctx context.Context, desc *StreamDesc, method strin t: t, } - as.callInfo.stream = as s, err := as.t.NewStream(as.ctx, as.callHdr) if err != nil { err = toRPCErr(err) diff --git a/test/end2end_test.go b/test/end2end_test.go index f3a60de5a96f..82150fe11be9 100644 --- a/test/end2end_test.go +++ b/test/end2end_test.go @@ -7128,3 +7128,62 @@ func (s) TestGzipBadChecksum(t *testing.T) { t.Errorf("ss.client.UnaryCall(_) = _, %v\n\twant: _, status(codes.Internal, contains %q)", err, gzip.ErrChecksum) } } + +// When an RPC is canceled, it's possible that the last Recv() returns before +// all call options' after are executed. +func (s) TestCanceledRPCCallOptionRace(t *testing.T) { + ss := &stubServer{ + fullDuplexCall: func(stream testpb.TestService_FullDuplexCallServer) error { + err := stream.Send(&testpb.StreamingOutputCallResponse{}) + if err != nil { + return err + } + <-stream.Context().Done() + return nil + }, + } + if err := ss.Start(nil); err != nil { + t.Fatalf("Error starting endpoint server: %v", err) + } + defer ss.Stop() + + const count = 1000 + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + var wg sync.WaitGroup + wg.Add(count) + for i := 0; i < count; i++ { + go func() { + defer wg.Done() + var p peer.Peer + ctx, cancel := context.WithCancel(ctx) + defer cancel() + stream, err := ss.client.FullDuplexCall(ctx, grpc.Peer(&p)) + if err != nil { + t.Errorf("_.FullDuplexCall(_) = _, %v", err) + return + } + if err := stream.Send(&testpb.StreamingOutputCallRequest{}); err != nil { + t.Errorf("_ has error %v while sending", err) + return + } + if _, err := stream.Recv(); err != nil { + t.Errorf("%v.Recv() = %v", stream, err) + return + } + cancel() + if _, err := stream.Recv(); status.Code(err) != codes.Canceled { + t.Errorf("%v compleled with error %v, want %s", stream, err, codes.Canceled) + return + } + // If recv returns before call options are executed, peer.Addr is not set, + // fail the test. + if p.Addr == nil { + t.Errorf("peer.Addr is nil, want non-nil") + return + } + }() + } + wg.Wait() +}