diff --git a/internal/transport/http2_client.go b/internal/transport/http2_client.go index 9bd8c27b3655..5922750e81cd 100644 --- a/internal/transport/http2_client.go +++ b/internal/transport/http2_client.go @@ -352,6 +352,7 @@ func newHTTP2Client(connectCtx, ctx context.Context, addr TargetInfo, opts Conne func (t *http2Client) newStream(ctx context.Context, callHdr *CallHdr) *Stream { // TODO(zhaoq): Handle uint32 overflow of Stream.id. s := &Stream{ + ct: t, done: make(chan struct{}), method: callHdr.Method, sendCompress: callHdr.SendCompress, @@ -1191,6 +1192,7 @@ func (t *http2Client) operateHeaders(frame *http2.MetaHeadersFrame) { // If headerChan hasn't been closed yet if atomic.CompareAndSwapUint32(&s.headerChanClosed, 0, 1) { + s.headerValid = true if !endStream { // HEADERS frame block carries a Response-Headers. isHeader = true diff --git a/internal/transport/transport.go b/internal/transport/transport.go index 1c1d106709ac..965c76f18fa1 100644 --- a/internal/transport/transport.go +++ b/internal/transport/transport.go @@ -233,6 +233,7 @@ const ( type Stream struct { id uint32 st ServerTransport // nil for client side Stream + ct *http2Client // nil for server side Stream ctx context.Context // the associated context of the stream cancel context.CancelFunc // always nil for client side Stream done chan struct{} // closed at the end of stream to unblock writers. On the client side. @@ -251,6 +252,10 @@ type Stream struct { headerChan chan struct{} // closed to indicate the end of header metadata. headerChanClosed uint32 // set when headerChan is closed. Used to avoid closing headerChan multiple times. + // headerValid indicates whether a valid header was received. Only + // meaningful after headerChan is closed (always call waitOnHeader() before + // reading its value). + headerValid bool // hdrMu protects header and trailer metadata on the server-side. hdrMu sync.Mutex @@ -303,34 +308,26 @@ func (s *Stream) getState() streamState { return streamState(atomic.LoadUint32((*uint32)(&s.state))) } -func (s *Stream) waitOnHeader() error { +func (s *Stream) waitOnHeader() { if s.headerChan == nil { // On the server headerChan is always nil since a stream originates // only after having received headers. - return nil + return } select { case <-s.ctx.Done(): - // We prefer success over failure when reading messages because we delay - // context error in stream.Read(). To keep behavior consistent, we also - // prefer success here. - select { - case <-s.headerChan: - return nil - default: - } - return ContextErr(s.ctx.Err()) + // Close the stream to prevent headers/trailers from changing after + // this function returns. + err := ContextErr(s.ctx.Err()) + s.ct.closeStream(s, err, false, 0, status.Convert(err), nil, false) case <-s.headerChan: - return nil } } // RecvCompress returns the compression algorithm applied to the inbound // message. It is empty string if there is no compression applied. func (s *Stream) RecvCompress() string { - if err := s.waitOnHeader(); err != nil { - return "" - } + s.waitOnHeader() return s.recvCompress } @@ -358,29 +355,19 @@ func (s *Stream) Header() (metadata.MD, error) { // header after t.WriteHeader is called. return s.header.Copy(), nil } - err := s.waitOnHeader() - // Even if the stream is closed, header is returned if available. - select { - case <-s.headerChan: - if s.header == nil { - return nil, nil - } - return s.header.Copy(), nil - default: + s.waitOnHeader() + if !s.headerValid { + return nil, s.status.Err() } - return nil, err + return s.header.Copy(), nil } // TrailersOnly blocks until a header or trailers-only frame is received and // then returns true if the stream was trailers-only. If the stream ends -// before headers are received, returns true, nil. If a context error happens -// first, returns it as a status error. Client-side only. -func (s *Stream) TrailersOnly() (bool, error) { - err := s.waitOnHeader() - if err != nil { - return false, err - } - return s.noHeaders, nil +// before headers are received, returns true, nil. Client-side only. +func (s *Stream) TrailersOnly() bool { + s.waitOnHeader() + return s.noHeaders } // Trailer returns the cached trailer metedata. Note that if it is not called diff --git a/stream.go b/stream.go index 134a624a15df..bb99940e36fe 100644 --- a/stream.go +++ b/stream.go @@ -488,7 +488,7 @@ func (cs *clientStream) shouldRetry(err error) error { pushback := 0 hasPushback := false if cs.attempt.s != nil { - if to, toErr := cs.attempt.s.TrailersOnly(); toErr != nil || !to { + if !cs.attempt.s.TrailersOnly() { return err } diff --git a/test/context_canceled_test.go b/test/context_canceled_test.go index 9715b5c203e2..781f63f0c04e 100644 --- a/test/context_canceled_test.go +++ b/test/context_canceled_test.go @@ -139,13 +139,12 @@ func (s) TestCancelWhileRecvingWithCompression(t *testing.T) { for i := 0; i < 10; i++ { ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) - defer cancel() s, err := ss.client.FullDuplexCall(ctx, grpc.UseCompressor(gzip.Name)) if err != nil { t.Fatalf("failed to start bidi streaming RPC: %v", err) } // Cancel the stream while receiving to trigger the internal error. - time.AfterFunc(time.Millisecond*10, cancel) + time.AfterFunc(time.Millisecond, cancel) for { _, err := s.Recv() if err != nil {