From de273bf7511de4710934b92415a00d471a6118cb Mon Sep 17 00:00:00 2001 From: Krisztian Litkey Date: Wed, 21 Aug 2024 14:04:00 +0300 Subject: [PATCH 1/7] channel: reject oversized messages on the sender side. Reject oversized messages on the sender side, keeping the receiver side rejection intact. This should provide minimal low-level plumbing for clients to attempt application level corrective actions on the requestor side, if the high-level protocol is designed with this in mind. Co-authored-by: Alessio Cantillo Co-authored-by: Qian Zhang Signed-off-by: Krisztian Litkey --- channel.go | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/channel.go b/channel.go index feafd9a6b..1af2a8408 100644 --- a/channel.go +++ b/channel.go @@ -143,10 +143,10 @@ func (ch *channel) recv() (messageHeader, []byte, error) { } func (ch *channel) send(streamID uint32, t messageType, flags uint8, p []byte) error { - // TODO: Error on send rather than on recv - //if len(p) > messageLengthMax { - // return status.Errorf(codes.InvalidArgument, "refusing to send, message length %v exceed maximum message size of %v", len(p), messageLengthMax) - //} + if len(p) > messageLengthMax { + return status.Errorf(codes.ResourceExhausted, "refusing to send, message length %v exceed maximum message size of %v", len(p), messageLengthMax) + } + if err := writeMessageHeader(ch.bw, ch.hwbuf[:], messageHeader{Length: uint32(len(p)), StreamID: streamID, Type: t, Flags: flags}); err != nil { return err } From d8c00dfec306c305efef44aa526f2acf8ebd165b Mon Sep 17 00:00:00 2001 From: Krisztian Litkey Date: Wed, 21 Aug 2024 14:20:27 +0300 Subject: [PATCH 2/7] channel_test: update oversize message test. Co-authored-by: Alessio Cantillo Co-authored-by: Qian Zhang Signed-off-by: Krisztian Litkey --- channel_test.go | 24 +++++++----------------- 1 file changed, 7 insertions(+), 17 deletions(-) diff --git a/channel_test.go b/channel_test.go index de8b66d38..9eab63148 100644 --- a/channel_test.go +++ b/channel_test.go @@ -89,21 +89,19 @@ func TestReadWriteMessage(t *testing.T) { func TestMessageOversize(t *testing.T) { var ( - w, r = net.Pipe() - wch, rch = newChannel(w), newChannel(r) - msg = bytes.Repeat([]byte("a message of massive length"), 512<<10) - errs = make(chan error, 1) + w, _ = net.Pipe() + wch = newChannel(w) + msg = bytes.Repeat([]byte("a message of massive length"), 512<<10) + errs = make(chan error, 1) ) go func() { - if err := wch.send(1, 1, 0, msg); err != nil { - errs <- err - } + errs <- wch.send(1, 1, 0, msg) }() - _, _, err := rch.recv() + err := <-errs if err == nil { - t.Fatalf("error expected reading with small buffer") + t.Fatalf("sending oversized message expected to fail") } status, ok := status.FromError(err) @@ -114,12 +112,4 @@ func TestMessageOversize(t *testing.T) { if status.Code() != codes.ResourceExhausted { t.Fatalf("expected grpc status code: %v != %v", status.Code(), codes.ResourceExhausted) } - - select { - case err := <-errs: - if err != nil { - t.Fatal(err) - } - default: - } } From b5cd6e4b32878158dc44b7854a7d14b454f75daf Mon Sep 17 00:00:00 2001 From: Krisztian Litkey Date: Thu, 22 Aug 2024 13:22:06 +0300 Subject: [PATCH 3/7] channel: allow discovery of overflown message size. Use a dedicated, grpc Status-compatible error to wrap the unique grpc status code, the size of the rejected message and the maximum allowed size when a message is rejected due to size limitations by the sending side. Signed-off-by: Krisztian Litkey --- channel.go | 2 +- errors.go | 48 +++++++++++++++++++++++++++++++++++++++++++++++- 2 files changed, 48 insertions(+), 2 deletions(-) diff --git a/channel.go b/channel.go index 1af2a8408..872261e6d 100644 --- a/channel.go +++ b/channel.go @@ -144,7 +144,7 @@ func (ch *channel) recv() (messageHeader, []byte, error) { func (ch *channel) send(streamID uint32, t messageType, flags uint8, p []byte) error { if len(p) > messageLengthMax { - return status.Errorf(codes.ResourceExhausted, "refusing to send, message length %v exceed maximum message size of %v", len(p), messageLengthMax) + return OversizedMessageError(len(p)) } if err := writeMessageHeader(ch.bw, ch.hwbuf[:], messageHeader{Length: uint32(len(p)), StreamID: streamID, Type: t, Flags: flags}); err != nil { diff --git a/errors.go b/errors.go index ec14b7952..632dbe8bd 100644 --- a/errors.go +++ b/errors.go @@ -16,7 +16,12 @@ package ttrpc -import "errors" +import ( + "errors" + + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" +) var ( // ErrProtocol is a general error in the handling the protocol. @@ -32,3 +37,44 @@ var ( // ErrStreamClosed is when the streaming connection is closed. ErrStreamClosed = errors.New("ttrpc: stream closed") ) + +// OversizedMessageErr is used to indicate refusal to send an oversized message. +// It wraps a ResourceExhausted grpc Status together with the offending message +// length. +type OversizedMessageErr struct { + messageLength int + err error +} + +// OversizedMessageError returns an OversizedMessageErr error for the given message +// length if it exceeds the allowed maximum. Otherwise a nil error is returned. +func OversizedMessageError(messageLength int) error { + if messageLength <= messageLengthMax { + return nil + } + + return &OversizedMessageErr{ + messageLength: messageLength, + err: status.Errorf(codes.ResourceExhausted, "message length %v exceed maximum message size of %v", messageLength, messageLengthMax), + } +} + +// Error returns the error message for the corresponding grpc Status for the error. +func (e *OversizedMessageErr) Error() string { + return e.err.Error() +} + +// Unwrap returns the corresponding error with our grpc status code. +func (e *OversizedMessageErr) Unwrap() error { + return e.err +} + +// RejectedLength retrieves the rejected message length which triggered the error. +func (e *OversizedMessageErr) RejectedLength() int { + return e.messageLength +} + +// MaximumLength retrieves the maximum allowed message length that triggered the error. +func (*OversizedMessageErr) MaximumLength() int { + return messageLengthMax +} From dd32ddee42956ba18f74fc4eec3281058fe928c7 Mon Sep 17 00:00:00 2001 From: Krisztian Litkey Date: Fri, 13 Sep 2024 16:02:32 +0300 Subject: [PATCH 4/7] server_test: add Serve()/Shutdown() race test. Signed-off-by: Krisztian Litkey --- server_test.go | 31 +++++++++++++++++++++++++++++++ 1 file changed, 31 insertions(+) diff --git a/server_test.go b/server_test.go index cf34986d6..8ee3fe3a9 100644 --- a/server_test.go +++ b/server_test.go @@ -298,6 +298,37 @@ func TestServerClose(t *testing.T) { checkServerShutdown(t, server) } +func TestImmediateServerShutdown(t *testing.T) { + for i := 0; i < 1024; i++ { + var ( + ctx = context.Background() + server = mustServer(t)(NewServer()) + addr, listener = newTestListener(t) + errs = make(chan error, 1) + _, cleanup = newTestClient(t, addr) + ) + defer cleanup() + defer listener.Close() + go func() { + errs <- server.Serve(ctx, listener) + }() + + registerTestingService(server, &testingServer{}) + + if err := server.Shutdown(ctx); err != nil { + t.Fatal(err) + } + select { + case err := <-errs: + if err != ErrServerClosed { + t.Fatal(err) + } + case <-time.After(time.Second): + t.Fatal("retreiving error from server.Shutdown() timed out") + } + } +} + func TestOversizeCall(t *testing.T) { var ( ctx = context.Background() From 525ddcea260866db151e6c1862e1e828e5fd914b Mon Sep 17 00:00:00 2001 From: Krisztian Litkey Date: Fri, 13 Sep 2024 15:06:23 +0300 Subject: [PATCH 5/7] server: fix Serve() vs. immediate Shutdown() race. Fix a race where an asynchronous server.Serve() invoked in a a goroutine races with an almost immediate server.Shutdown(). If Shutdown() finishes its locked closing of listeners before Serve() gets around to add the new one, Serve will sit stuck forever in l.Accept(), unless the caller closes the listener in addition to Shutdown(). This is probably almost impossible to trigger in real life, but some of the unit tests, which run the server and client in the same process, occasionally do trigger this. Then, if the test tries to verify a final ErrServerClosed error from Serve() after Shutdown() it gets stuck forever. Signed-off-by: Krisztian Litkey --- server.go | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/server.go b/server.go index 26419831d..bb71de677 100644 --- a/server.go +++ b/server.go @@ -74,9 +74,18 @@ func (s *Server) RegisterService(name string, desc *ServiceDesc) { } func (s *Server) Serve(ctx context.Context, l net.Listener) error { - s.addListener(l) + s.mu.Lock() + s.addListenerLocked(l) defer s.closeListener(l) + select { + case <-s.done: + s.mu.Unlock() + return ErrServerClosed + default: + } + s.mu.Unlock() + var ( backoff time.Duration handshaker = s.config.handshaker @@ -188,9 +197,7 @@ func (s *Server) Close() error { return err } -func (s *Server) addListener(l net.Listener) { - s.mu.Lock() - defer s.mu.Unlock() +func (s *Server) addListenerLocked(l net.Listener) { s.listeners[l] = struct{}{} } From e1f03b3b07edfe3c395faafbbf65432bbcf78d8b Mon Sep 17 00:00:00 2001 From: Krisztian Litkey Date: Mon, 26 Aug 2024 17:03:42 +0300 Subject: [PATCH 6/7] client,server: configurable wire message size limits. Implement configurable limits for the maximum accepted message size of the wire protocol. The default limit can be overridden using the WithClientWireMessageLimit() option for clients and using the WithServerWireMessageLimit() option for servers. Add exported constants for the minimum, maximum and default limits. Signed-off-by: Krisztian Litkey --- channel.go | 61 ++++++++++++++++++++++++++++++++++++++---------------- client.go | 19 ++++++++++++----- config.go | 10 +++++++++ errors.go | 50 +++++++++++++++++++++++++++++++++++++++----- server.go | 21 ++++++++++++++++++- 5 files changed, 132 insertions(+), 29 deletions(-) diff --git a/channel.go b/channel.go index 872261e6d..50c917adc 100644 --- a/channel.go +++ b/channel.go @@ -23,14 +23,13 @@ import ( "io" "net" "sync" - - "google.golang.org/grpc/codes" - "google.golang.org/grpc/status" ) const ( - messageHeaderLength = 10 - messageLengthMax = 4 << 20 + messageHeaderLength = 10 + MinMessageLengthLimit = 4 << 10 + MaxMessageLengthLimit = 4 << 22 + DefaultMessageLengthLimit = 4 << 20 ) type messageType uint8 @@ -96,18 +95,23 @@ func writeMessageHeader(w io.Writer, p []byte, mh messageHeader) error { var buffers sync.Pool type channel struct { - conn net.Conn - bw *bufio.Writer - br *bufio.Reader - hrbuf [messageHeaderLength]byte // avoid alloc when reading header - hwbuf [messageHeaderLength]byte + conn net.Conn + bw *bufio.Writer + br *bufio.Reader + hrbuf [messageHeaderLength]byte // avoid alloc when reading header + hwbuf [messageHeaderLength]byte + maxMsgLen int } -func newChannel(conn net.Conn) *channel { +func newChannel(conn net.Conn, maxMsgLen int) *channel { + if maxMsgLen == 0 { + maxMsgLen = DefaultMessageLengthLimit + } return &channel{ - conn: conn, - bw: bufio.NewWriter(conn), - br: bufio.NewReader(conn), + conn: conn, + bw: bufio.NewWriter(conn), + br: bufio.NewReader(conn), + maxMsgLen: maxMsgLen, } } @@ -123,12 +127,12 @@ func (ch *channel) recv() (messageHeader, []byte, error) { return messageHeader{}, nil, err } - if mh.Length > uint32(messageLengthMax) { + if maxMsgLen := ch.maxMsgLimit(true); mh.Length > uint32(maxMsgLen) { if _, err := ch.br.Discard(int(mh.Length)); err != nil { return mh, nil, fmt.Errorf("failed to discard after receiving oversized message: %w", err) } - return mh, nil, status.Errorf(codes.ResourceExhausted, "message length %v exceed maximum message size of %v", mh.Length, messageLengthMax) + return mh, nil, OversizedMessageError(int(mh.Length), maxMsgLen) } var p []byte @@ -143,8 +147,10 @@ func (ch *channel) recv() (messageHeader, []byte, error) { } func (ch *channel) send(streamID uint32, t messageType, flags uint8, p []byte) error { - if len(p) > messageLengthMax { - return OversizedMessageError(len(p)) + if maxMsgLen := ch.maxMsgLimit(false); maxMsgLen != 0 { + if len(p) > maxMsgLen { + return OversizedMessageError(len(p), maxMsgLen) + } } if err := writeMessageHeader(ch.bw, ch.hwbuf[:], messageHeader{Length: uint32(len(p)), StreamID: streamID, Type: t, Flags: flags}); err != nil { @@ -180,3 +186,22 @@ func (ch *channel) getmbuf(size int) []byte { func (ch *channel) putmbuf(p []byte) { buffers.Put(&p) } + +func (ch *channel) maxMsgLimit(recv bool) int { + if ch.maxMsgLen == 0 && recv { + return DefaultMessageLengthLimit + } + return ch.maxMsgLen +} + +func clampWireMessageLimit(maxMsgLen int) int { + switch { + case maxMsgLen == 0: + return 0 + case maxMsgLen < MinMessageLengthLimit: + return MinMessageLengthLimit + case maxMsgLen > MaxMessageLengthLimit: + return MaxMessageLengthLimit + } + return maxMsgLen +} diff --git a/client.go b/client.go index b1bc7a3fc..09cfb2897 100644 --- a/client.go +++ b/client.go @@ -35,9 +35,10 @@ import ( // Client for a ttrpc server type Client struct { - codec codec - conn net.Conn - channel *channel + codec codec + conn net.Conn + channel *channel + maxMsgLen int streamLock sync.RWMutex streams map[streamID]*stream @@ -107,14 +108,20 @@ func chainUnaryInterceptors(interceptors []UnaryClientInterceptor, final Invoker } } +// WithClientWireMessageLimit sets the maximum allowed message length on the wire for the client. +func WithClientWireMessageLimit(maxMsgLen int) ClientOpts { + maxMsgLen = clampWireMessageLimit(maxMsgLen) + return func(c *Client) { + c.maxMsgLen = maxMsgLen + } +} + // NewClient creates a new ttrpc client using the given connection func NewClient(conn net.Conn, opts ...ClientOpts) *Client { ctx, cancel := context.WithCancel(context.Background()) - channel := newChannel(conn) c := &Client{ codec: codec{}, conn: conn, - channel: channel, streams: make(map[streamID]*stream), nextStreamID: 1, closed: cancel, @@ -127,6 +134,8 @@ func NewClient(conn net.Conn, opts ...ClientOpts) *Client { o(c) } + c.channel = newChannel(conn, c.maxMsgLen) + if c.interceptor == nil { c.interceptor = defaultClientInterceptor } diff --git a/config.go b/config.go index f401f67be..5995b9a8b 100644 --- a/config.go +++ b/config.go @@ -24,6 +24,7 @@ import ( type serverConfig struct { handshaker Handshaker interceptor UnaryServerInterceptor + maxMsgLen int } // ServerOpt for configuring a ttrpc server @@ -84,3 +85,12 @@ func chainUnaryServerInterceptors(info *UnaryServerInfo, method Method, intercep chainUnaryServerInterceptors(info, method, interceptors[1:])) } } + +// WithServerWireMessageLimit sets the maximum allowed message length on the wire for the server. +func WithServerWireMessageLimit(maxMsgLen int) ServerOpt { + maxMsgLen = clampWireMessageLimit(maxMsgLen) + return func(c *serverConfig) error { + c.maxMsgLen = maxMsgLen + return nil + } +} diff --git a/errors.go b/errors.go index 632dbe8bd..5f81bc491 100644 --- a/errors.go +++ b/errors.go @@ -18,6 +18,7 @@ package ttrpc import ( "errors" + "fmt" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" @@ -43,20 +44,59 @@ var ( // length. type OversizedMessageErr struct { messageLength int + maxLength int err error } +var ( + oversizedMsgFmt = "message length %d exceeds maximum message size of %d" + oversizedMsgScanFmt = fmt.Sprintf("%v", status.New(codes.ResourceExhausted, oversizedMsgFmt)) +) + // OversizedMessageError returns an OversizedMessageErr error for the given message // length if it exceeds the allowed maximum. Otherwise a nil error is returned. -func OversizedMessageError(messageLength int) error { - if messageLength <= messageLengthMax { +func OversizedMessageError(messageLength, maxLength int) error { + if messageLength <= maxLength { return nil } return &OversizedMessageErr{ messageLength: messageLength, - err: status.Errorf(codes.ResourceExhausted, "message length %v exceed maximum message size of %v", messageLength, messageLengthMax), + maxLength: maxLength, + err: OversizedMessageStatus(messageLength, maxLength).Err(), + } +} + +// OversizedMessageStatus returns a Status for an oversized message error. +func OversizedMessageStatus(messageLength, maxLength int) *status.Status { + return status.Newf(codes.ResourceExhausted, oversizedMsgFmt, messageLength, maxLength) +} + +// OversizedMessageFromError reconstructs an OversizedMessageErr from a Status. +func OversizedMessageFromError(err error) (*OversizedMessageErr, bool) { + var ( + messageLength int + maxLength int + ) + + st, ok := status.FromError(err) + if !ok || st.Code() != codes.ResourceExhausted { + return nil, false } + + // TODO(klihub): might be too ugly to recover an error this way... An + // alternative would be to define our custom status detail proto type, + // then use status.WithDetails() and status.Details(). + + n, _ := fmt.Sscanf(st.Message(), oversizedMsgScanFmt, &messageLength, &maxLength) + if n != 2 { + n, _ = fmt.Sscanf(st.Message(), oversizedMsgFmt, &messageLength, &maxLength) + } + if n != 2 { + return nil, false + } + + return OversizedMessageError(messageLength, maxLength).(*OversizedMessageErr), true } // Error returns the error message for the corresponding grpc Status for the error. @@ -75,6 +115,6 @@ func (e *OversizedMessageErr) RejectedLength() int { } // MaximumLength retrieves the maximum allowed message length that triggered the error. -func (*OversizedMessageErr) MaximumLength() int { - return messageLengthMax +func (e *OversizedMessageErr) MaximumLength() int { + return e.maxLength } diff --git a/server.go b/server.go index bb71de677..f30d5c924 100644 --- a/server.go +++ b/server.go @@ -339,7 +339,7 @@ func (c *serverConn) run(sctx context.Context) { ) var ( - ch = newChannel(c.conn) + ch = newChannel(c.conn, c.server.config.maxMsgLen) ctx, cancel = context.WithCancel(sctx) state connState = connStateIdle responses = make(chan response) @@ -373,6 +373,14 @@ func (c *serverConn) run(sctx context.Context) { } } + isResourceExhaustedError := func(err error) (*status.Status, bool) { + st, ok := status.FromError(err) + if !ok || st.Code() != codes.ResourceExhausted { + return nil, false + } + return st, true + } + go func(recvErr chan error) { defer close(recvErr) for { @@ -525,6 +533,17 @@ func (c *serverConn) run(sctx context.Context) { } if err := ch.send(response.id, messageTypeResponse, 0, p); err != nil { + if st, ok := isResourceExhaustedError(err); ok { + p, err = c.server.codec.Marshal(&Response{ + Status: st.Proto(), + }) + if err != nil { + log.G(ctx).WithError(err).Error("failed marshaling error response") + return + } + ch.send(response.id, messageTypeResponse, 0, p) + return + } log.G(ctx).WithError(err).Error("failed sending message on channel") return } From 9e19b2d241f80abd6212daf24e38697f26f95a36 Mon Sep 17 00:00:00 2001 From: Krisztian Litkey Date: Mon, 26 Aug 2024 17:20:35 +0300 Subject: [PATCH 7/7] {channel,server}_test: add tests for message limits. Adjust unit test to accomodate for altered internal interfaces. Add unit tests to exercise the new message size limit options. Signed-off-by: Krisztian Litkey --- channel_test.go | 6 +- server_test.go | 273 +++++++++++++++++++++++++++++++++++++++++++----- 2 files changed, 247 insertions(+), 32 deletions(-) diff --git a/channel_test.go b/channel_test.go index 9eab63148..30263730d 100644 --- a/channel_test.go +++ b/channel_test.go @@ -31,8 +31,8 @@ import ( func TestReadWriteMessage(t *testing.T) { var ( w, r = net.Pipe() - ch = newChannel(w) - rch = newChannel(r) + ch = newChannel(w, 0) + rch = newChannel(r, 0) messages = [][]byte{ []byte("hello"), []byte("this is a test"), @@ -90,7 +90,7 @@ func TestReadWriteMessage(t *testing.T) { func TestMessageOversize(t *testing.T) { var ( w, _ = net.Pipe() - wch = newChannel(w) + wch = newChannel(w, 0) msg = bytes.Repeat([]byte("a message of massive length"), 512<<10) errs = make(chan error, 1) ) diff --git a/server_test.go b/server_test.go index 8ee3fe3a9..76a6de131 100644 --- a/server_test.go +++ b/server_test.go @@ -19,6 +19,7 @@ package ttrpc import ( "bytes" "context" + "crypto/md5" "errors" "fmt" "net" @@ -61,10 +62,17 @@ func (tc *testingClient) Test(ctx context.Context, req *internal.TestPayload) (* } // testingServer is what would be implemented by the user of this package. -type testingServer struct{} +type testingServer struct { + echoOnce bool +} func (s *testingServer) Test(ctx context.Context, req *internal.TestPayload) (*internal.TestPayload, error) { - tp := &internal.TestPayload{Foo: strings.Repeat(req.Foo, 2)} + tp := &internal.TestPayload{} + if s.echoOnce { + tp.Foo = req.Foo + } else { + tp.Foo = strings.Repeat(req.Foo, 2) + } if dl, ok := ctx.Deadline(); ok { tp.Deadline = dl.UnixNano() } @@ -330,38 +338,238 @@ func TestImmediateServerShutdown(t *testing.T) { } func TestOversizeCall(t *testing.T) { - var ( - ctx = context.Background() - server = mustServer(t)(NewServer()) - addr, listener = newTestListener(t) - errs = make(chan error, 1) - client, cleanup = newTestClient(t, addr) - ) - defer cleanup() - defer listener.Close() - go func() { - errs <- server.Serve(ctx, listener) - }() + type testCase struct { + name string + echoOnce bool + clientLimit int + serverLimit int + requestSize int + clientFail bool + sendFail bool + serverFail bool + } + + overhead := getWireMessageOverhead(t) + + clientOpts := func(tc *testCase) []ClientOpts { + if tc.clientLimit == 0 { + return nil + } + return []ClientOpts{WithClientWireMessageLimit(tc.clientLimit)} + } + serverOpts := func(tc *testCase) []ServerOpt { + if tc.serverLimit == 0 { + return nil + } + return []ServerOpt{WithServerWireMessageLimit(tc.serverLimit)} + } - registerTestingService(server, &testingServer{}) + runTest := func(t *testing.T, tc *testCase) { + var ( + ctx = context.Background() + server = mustServer(t)(NewServer(serverOpts(tc)...)) + addr, listener = newTestListener(t) + errs = make(chan error, 1) + client, cleanup = newTestClient(t, addr, clientOpts(tc)...) + ) + defer cleanup() + defer listener.Close() + go func() { + errs <- server.Serve(ctx, listener) + }() + + registerTestingService(server, &testingServer{echoOnce: tc.echoOnce}) - tp := &internal.TestPayload{ - Foo: strings.Repeat("a", 1+messageLengthMax), + req := &internal.TestPayload{ + Foo: strings.Repeat("a", tc.requestSize), + } + rsp := &internal.TestPayload{} + + err := client.Call(ctx, serviceName, "Test", req, rsp) + if tc.clientFail { + if err == nil { + t.Fatalf("expected error from oversized message") + } else if status, ok := status.FromError(err); !ok { + t.Fatalf("expected status present in error: %v", err) + } else if status.Code() != codes.ResourceExhausted { + t.Fatalf("expected code: %v != %v", status.Code(), codes.ResourceExhausted) + } + if tc.sendFail { + var msgLenErr *OversizedMessageErr + if !errors.As(err, &msgLenErr) { + t.Fatalf("failed to retrieve client send OversizedMessageErr") + } + rejLen, maxLen := msgLenErr.RejectedLength(), msgLenErr.MaximumLength() + if rejLen == 0 { + t.Fatalf("zero rejected length in client send oversized message error") + } + if maxLen == 0 { + t.Fatalf("zero maximum length in client send oversized message error") + } + if rejLen <= maxLen { + t.Fatalf("client send oversized message error rejected < max. length (%d < %d)", + rejLen, maxLen) + } + } + } else if tc.serverFail { + if err == nil { + t.Fatalf("expected error from server-side oversized message") + } else { + if status, ok := status.FromError(err); !ok { + t.Fatalf("expected status present in error: %v", err) + } else if status.Code() != codes.ResourceExhausted { + t.Fatalf("expected code: %v != %v", status.Code(), codes.ResourceExhausted) + } + if msgLenErr, ok := OversizedMessageFromError(err); !ok { + t.Fatalf("failed to retrieve oversized message error") + } else { + rejLen, maxLen := msgLenErr.RejectedLength(), msgLenErr.MaximumLength() + if rejLen == 0 { + t.Fatalf("zero rejected length in oversized message error") + } + if maxLen == 0 { + t.Fatalf("zero maximum length in oversized message error") + } + if rejLen <= maxLen { + t.Fatalf("oversized message error rejected < max. length (%d < %d)", + rejLen, maxLen) + } + } + } + } else { + if err != nil { + t.Fatalf("expected success, got error %v", err) + } + } + + if err := server.Shutdown(ctx); err != nil { + t.Fatal(err) + } + if err := <-errs; err != ErrServerClosed { + t.Fatal(err) + } } - if err := client.Call(ctx, serviceName, "Test", tp, tp); err == nil { - t.Fatalf("expected error from oversized message") - } else if status, ok := status.FromError(err); !ok { - t.Fatalf("expected status present in error: %v", err) - } else if status.Code() != codes.ResourceExhausted { - t.Fatalf("expected code: %v != %v", status.Code(), codes.ResourceExhausted) + + for _, tc := range []*testCase{ + { + name: "default limits, fitting request and response", + echoOnce: true, + clientLimit: 0, + serverLimit: 0, + requestSize: DefaultMessageLengthLimit - overhead, + }, + { + name: "default limits, only recv side check", + clientLimit: 0, + serverLimit: 0, + requestSize: DefaultMessageLengthLimit - overhead, + serverFail: true, + }, + + { + name: "default limits, oversized request", + echoOnce: true, + clientLimit: 0, + serverLimit: 0, + requestSize: DefaultMessageLengthLimit, + clientFail: true, + }, + { + name: "default limits, oversized response", + clientLimit: 0, + serverLimit: 0, + requestSize: DefaultMessageLengthLimit / 2, + serverFail: true, + }, + { + name: "8K limits, 4K request and response", + echoOnce: true, + clientLimit: 8 * 1024, + serverLimit: 8 * 1024, + requestSize: 4 * 1024, + }, + { + name: "4K limits, barely fitting cc. 4K request and response", + echoOnce: true, + clientLimit: 4 * 1024, + serverLimit: 4 * 1024, + requestSize: 4*1024 - overhead, + }, + { + name: "4K limits, oversized request on client side", + echoOnce: true, + clientLimit: 4 * 1024, + serverLimit: 4 * 1024, + requestSize: 4 * 1024, + clientFail: true, + sendFail: true, + }, + { + name: "4K limits, oversized request on server side", + echoOnce: true, + clientLimit: 4*1024 + overhead, + serverLimit: 4 * 1024, + requestSize: 4 * 1024, + serverFail: true, + }, + { + name: "4K limits, oversized response on client side", + clientLimit: 4*1024 + overhead, + serverLimit: 4 * 1024, + requestSize: 8*1024 + overhead, + clientFail: true, + }, + { + name: "4K limits, oversized response on server side", + clientLimit: 4*1024 + overhead, + serverLimit: 4 * 1024, + requestSize: 4 * 1024, + serverFail: true, + }, + { + name: "too small limits, adjusted to minimum accepted limit", + echoOnce: true, + clientLimit: 4, + serverLimit: 4, + requestSize: 4*1024 - overhead, + }, + { + name: "maximum allowed protocol limit", + echoOnce: true, + clientLimit: MaxMessageLengthLimit, + serverLimit: MaxMessageLengthLimit, + requestSize: MaxMessageLengthLimit - overhead, + }, + } { + t.Run(tc.name, func(t *testing.T) { + runTest(t, tc) + }) } +} - if err := server.Shutdown(ctx); err != nil { - t.Fatal(err) +func getWireMessageOverhead(t *testing.T) int { + emptyReq, err := codec{}.Marshal(&Request{ + Service: serviceName, + Method: "Test", + }) + if err != nil { + t.Fatalf("failed to marshal empty request: %v", err) } - if err := <-errs; err != ErrServerClosed { - t.Fatal(err) + + emptyRsp, err := codec{}.Marshal(&Response{ + Status: status.New(codes.OK, "").Proto(), + }) + if err != nil { + t.Fatalf("failed to marshal empty response: %v", err) + } + + reqLen := len(emptyReq) + rspLen := len(emptyRsp) + if reqLen > rspLen { + return reqLen + messageHeaderLength } + + return rspLen + messageHeaderLength } func TestClientEOF(t *testing.T) { @@ -582,13 +790,20 @@ func newTestClient(t testing.TB, addr string, opts ...ClientOpts) (*Client, func } func newTestListener(t testing.TB) (string, net.Listener) { - var prefix string + var ( + name = t.Name() + prefix string + ) // Abstracts sockets are only available on Linux. if runtime.GOOS == "linux" { prefix = "\x00" + } else { + if split := strings.SplitN(name, "/", 2); len(split) == 2 { + name = split[0] + "-" + fmt.Sprintf("%x", md5.Sum([]byte(split[1]))) + } } - addr := prefix + t.Name() + addr := prefix + name listener, err := net.Listen("unix", addr) if err != nil { t.Fatal(err)