diff --git a/channel.go b/channel.go index feafd9a6b..50c917adc 100644 --- a/channel.go +++ b/channel.go @@ -23,14 +23,13 @@ import ( "io" "net" "sync" - - "google.golang.org/grpc/codes" - "google.golang.org/grpc/status" ) const ( - messageHeaderLength = 10 - messageLengthMax = 4 << 20 + messageHeaderLength = 10 + MinMessageLengthLimit = 4 << 10 + MaxMessageLengthLimit = 4 << 22 + DefaultMessageLengthLimit = 4 << 20 ) type messageType uint8 @@ -96,18 +95,23 @@ func writeMessageHeader(w io.Writer, p []byte, mh messageHeader) error { var buffers sync.Pool type channel struct { - conn net.Conn - bw *bufio.Writer - br *bufio.Reader - hrbuf [messageHeaderLength]byte // avoid alloc when reading header - hwbuf [messageHeaderLength]byte + conn net.Conn + bw *bufio.Writer + br *bufio.Reader + hrbuf [messageHeaderLength]byte // avoid alloc when reading header + hwbuf [messageHeaderLength]byte + maxMsgLen int } -func newChannel(conn net.Conn) *channel { +func newChannel(conn net.Conn, maxMsgLen int) *channel { + if maxMsgLen == 0 { + maxMsgLen = DefaultMessageLengthLimit + } return &channel{ - conn: conn, - bw: bufio.NewWriter(conn), - br: bufio.NewReader(conn), + conn: conn, + bw: bufio.NewWriter(conn), + br: bufio.NewReader(conn), + maxMsgLen: maxMsgLen, } } @@ -123,12 +127,12 @@ func (ch *channel) recv() (messageHeader, []byte, error) { return messageHeader{}, nil, err } - if mh.Length > uint32(messageLengthMax) { + if maxMsgLen := ch.maxMsgLimit(true); mh.Length > uint32(maxMsgLen) { if _, err := ch.br.Discard(int(mh.Length)); err != nil { return mh, nil, fmt.Errorf("failed to discard after receiving oversized message: %w", err) } - return mh, nil, status.Errorf(codes.ResourceExhausted, "message length %v exceed maximum message size of %v", mh.Length, messageLengthMax) + return mh, nil, OversizedMessageError(int(mh.Length), maxMsgLen) } var p []byte @@ -143,10 +147,12 @@ func (ch *channel) recv() (messageHeader, []byte, error) { } func (ch *channel) send(streamID uint32, t messageType, flags uint8, p []byte) error { - // TODO: Error on send rather than on recv - //if len(p) > messageLengthMax { - // return status.Errorf(codes.InvalidArgument, "refusing to send, message length %v exceed maximum message size of %v", len(p), messageLengthMax) - //} + if maxMsgLen := ch.maxMsgLimit(false); maxMsgLen != 0 { + if len(p) > maxMsgLen { + return OversizedMessageError(len(p), maxMsgLen) + } + } + if err := writeMessageHeader(ch.bw, ch.hwbuf[:], messageHeader{Length: uint32(len(p)), StreamID: streamID, Type: t, Flags: flags}); err != nil { return err } @@ -180,3 +186,22 @@ func (ch *channel) getmbuf(size int) []byte { func (ch *channel) putmbuf(p []byte) { buffers.Put(&p) } + +func (ch *channel) maxMsgLimit(recv bool) int { + if ch.maxMsgLen == 0 && recv { + return DefaultMessageLengthLimit + } + return ch.maxMsgLen +} + +func clampWireMessageLimit(maxMsgLen int) int { + switch { + case maxMsgLen == 0: + return 0 + case maxMsgLen < MinMessageLengthLimit: + return MinMessageLengthLimit + case maxMsgLen > MaxMessageLengthLimit: + return MaxMessageLengthLimit + } + return maxMsgLen +} diff --git a/channel_test.go b/channel_test.go index de8b66d38..30263730d 100644 --- a/channel_test.go +++ b/channel_test.go @@ -31,8 +31,8 @@ import ( func TestReadWriteMessage(t *testing.T) { var ( w, r = net.Pipe() - ch = newChannel(w) - rch = newChannel(r) + ch = newChannel(w, 0) + rch = newChannel(r, 0) messages = [][]byte{ []byte("hello"), []byte("this is a test"), @@ -89,21 +89,19 @@ func TestReadWriteMessage(t *testing.T) { func TestMessageOversize(t *testing.T) { var ( - w, r = net.Pipe() - wch, rch = newChannel(w), newChannel(r) - msg = bytes.Repeat([]byte("a message of massive length"), 512<<10) - errs = make(chan error, 1) + w, _ = net.Pipe() + wch = newChannel(w, 0) + msg = bytes.Repeat([]byte("a message of massive length"), 512<<10) + errs = make(chan error, 1) ) go func() { - if err := wch.send(1, 1, 0, msg); err != nil { - errs <- err - } + errs <- wch.send(1, 1, 0, msg) }() - _, _, err := rch.recv() + err := <-errs if err == nil { - t.Fatalf("error expected reading with small buffer") + t.Fatalf("sending oversized message expected to fail") } status, ok := status.FromError(err) @@ -114,12 +112,4 @@ func TestMessageOversize(t *testing.T) { if status.Code() != codes.ResourceExhausted { t.Fatalf("expected grpc status code: %v != %v", status.Code(), codes.ResourceExhausted) } - - select { - case err := <-errs: - if err != nil { - t.Fatal(err) - } - default: - } } diff --git a/client.go b/client.go index b1bc7a3fc..09cfb2897 100644 --- a/client.go +++ b/client.go @@ -35,9 +35,10 @@ import ( // Client for a ttrpc server type Client struct { - codec codec - conn net.Conn - channel *channel + codec codec + conn net.Conn + channel *channel + maxMsgLen int streamLock sync.RWMutex streams map[streamID]*stream @@ -107,14 +108,20 @@ func chainUnaryInterceptors(interceptors []UnaryClientInterceptor, final Invoker } } +// WithClientWireMessageLimit sets the maximum allowed message length on the wire for the client. +func WithClientWireMessageLimit(maxMsgLen int) ClientOpts { + maxMsgLen = clampWireMessageLimit(maxMsgLen) + return func(c *Client) { + c.maxMsgLen = maxMsgLen + } +} + // NewClient creates a new ttrpc client using the given connection func NewClient(conn net.Conn, opts ...ClientOpts) *Client { ctx, cancel := context.WithCancel(context.Background()) - channel := newChannel(conn) c := &Client{ codec: codec{}, conn: conn, - channel: channel, streams: make(map[streamID]*stream), nextStreamID: 1, closed: cancel, @@ -127,6 +134,8 @@ func NewClient(conn net.Conn, opts ...ClientOpts) *Client { o(c) } + c.channel = newChannel(conn, c.maxMsgLen) + if c.interceptor == nil { c.interceptor = defaultClientInterceptor } diff --git a/config.go b/config.go index f401f67be..5995b9a8b 100644 --- a/config.go +++ b/config.go @@ -24,6 +24,7 @@ import ( type serverConfig struct { handshaker Handshaker interceptor UnaryServerInterceptor + maxMsgLen int } // ServerOpt for configuring a ttrpc server @@ -84,3 +85,12 @@ func chainUnaryServerInterceptors(info *UnaryServerInfo, method Method, intercep chainUnaryServerInterceptors(info, method, interceptors[1:])) } } + +// WithServerWireMessageLimit sets the maximum allowed message length on the wire for the server. +func WithServerWireMessageLimit(maxMsgLen int) ServerOpt { + maxMsgLen = clampWireMessageLimit(maxMsgLen) + return func(c *serverConfig) error { + c.maxMsgLen = maxMsgLen + return nil + } +} diff --git a/errors.go b/errors.go index ec14b7952..5f81bc491 100644 --- a/errors.go +++ b/errors.go @@ -16,7 +16,13 @@ package ttrpc -import "errors" +import ( + "errors" + "fmt" + + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" +) var ( // ErrProtocol is a general error in the handling the protocol. @@ -32,3 +38,83 @@ var ( // ErrStreamClosed is when the streaming connection is closed. ErrStreamClosed = errors.New("ttrpc: stream closed") ) + +// OversizedMessageErr is used to indicate refusal to send an oversized message. +// It wraps a ResourceExhausted grpc Status together with the offending message +// length. +type OversizedMessageErr struct { + messageLength int + maxLength int + err error +} + +var ( + oversizedMsgFmt = "message length %d exceeds maximum message size of %d" + oversizedMsgScanFmt = fmt.Sprintf("%v", status.New(codes.ResourceExhausted, oversizedMsgFmt)) +) + +// OversizedMessageError returns an OversizedMessageErr error for the given message +// length if it exceeds the allowed maximum. Otherwise a nil error is returned. +func OversizedMessageError(messageLength, maxLength int) error { + if messageLength <= maxLength { + return nil + } + + return &OversizedMessageErr{ + messageLength: messageLength, + maxLength: maxLength, + err: OversizedMessageStatus(messageLength, maxLength).Err(), + } +} + +// OversizedMessageStatus returns a Status for an oversized message error. +func OversizedMessageStatus(messageLength, maxLength int) *status.Status { + return status.Newf(codes.ResourceExhausted, oversizedMsgFmt, messageLength, maxLength) +} + +// OversizedMessageFromError reconstructs an OversizedMessageErr from a Status. +func OversizedMessageFromError(err error) (*OversizedMessageErr, bool) { + var ( + messageLength int + maxLength int + ) + + st, ok := status.FromError(err) + if !ok || st.Code() != codes.ResourceExhausted { + return nil, false + } + + // TODO(klihub): might be too ugly to recover an error this way... An + // alternative would be to define our custom status detail proto type, + // then use status.WithDetails() and status.Details(). + + n, _ := fmt.Sscanf(st.Message(), oversizedMsgScanFmt, &messageLength, &maxLength) + if n != 2 { + n, _ = fmt.Sscanf(st.Message(), oversizedMsgFmt, &messageLength, &maxLength) + } + if n != 2 { + return nil, false + } + + return OversizedMessageError(messageLength, maxLength).(*OversizedMessageErr), true +} + +// Error returns the error message for the corresponding grpc Status for the error. +func (e *OversizedMessageErr) Error() string { + return e.err.Error() +} + +// Unwrap returns the corresponding error with our grpc status code. +func (e *OversizedMessageErr) Unwrap() error { + return e.err +} + +// RejectedLength retrieves the rejected message length which triggered the error. +func (e *OversizedMessageErr) RejectedLength() int { + return e.messageLength +} + +// MaximumLength retrieves the maximum allowed message length that triggered the error. +func (e *OversizedMessageErr) MaximumLength() int { + return e.maxLength +} diff --git a/server.go b/server.go index 26419831d..f30d5c924 100644 --- a/server.go +++ b/server.go @@ -74,9 +74,18 @@ func (s *Server) RegisterService(name string, desc *ServiceDesc) { } func (s *Server) Serve(ctx context.Context, l net.Listener) error { - s.addListener(l) + s.mu.Lock() + s.addListenerLocked(l) defer s.closeListener(l) + select { + case <-s.done: + s.mu.Unlock() + return ErrServerClosed + default: + } + s.mu.Unlock() + var ( backoff time.Duration handshaker = s.config.handshaker @@ -188,9 +197,7 @@ func (s *Server) Close() error { return err } -func (s *Server) addListener(l net.Listener) { - s.mu.Lock() - defer s.mu.Unlock() +func (s *Server) addListenerLocked(l net.Listener) { s.listeners[l] = struct{}{} } @@ -332,7 +339,7 @@ func (c *serverConn) run(sctx context.Context) { ) var ( - ch = newChannel(c.conn) + ch = newChannel(c.conn, c.server.config.maxMsgLen) ctx, cancel = context.WithCancel(sctx) state connState = connStateIdle responses = make(chan response) @@ -366,6 +373,14 @@ func (c *serverConn) run(sctx context.Context) { } } + isResourceExhaustedError := func(err error) (*status.Status, bool) { + st, ok := status.FromError(err) + if !ok || st.Code() != codes.ResourceExhausted { + return nil, false + } + return st, true + } + go func(recvErr chan error) { defer close(recvErr) for { @@ -518,6 +533,17 @@ func (c *serverConn) run(sctx context.Context) { } if err := ch.send(response.id, messageTypeResponse, 0, p); err != nil { + if st, ok := isResourceExhaustedError(err); ok { + p, err = c.server.codec.Marshal(&Response{ + Status: st.Proto(), + }) + if err != nil { + log.G(ctx).WithError(err).Error("failed marshaling error response") + return + } + ch.send(response.id, messageTypeResponse, 0, p) + return + } log.G(ctx).WithError(err).Error("failed sending message on channel") return } diff --git a/server_test.go b/server_test.go index cf34986d6..76a6de131 100644 --- a/server_test.go +++ b/server_test.go @@ -19,6 +19,7 @@ package ttrpc import ( "bytes" "context" + "crypto/md5" "errors" "fmt" "net" @@ -61,10 +62,17 @@ func (tc *testingClient) Test(ctx context.Context, req *internal.TestPayload) (* } // testingServer is what would be implemented by the user of this package. -type testingServer struct{} +type testingServer struct { + echoOnce bool +} func (s *testingServer) Test(ctx context.Context, req *internal.TestPayload) (*internal.TestPayload, error) { - tp := &internal.TestPayload{Foo: strings.Repeat(req.Foo, 2)} + tp := &internal.TestPayload{} + if s.echoOnce { + tp.Foo = req.Foo + } else { + tp.Foo = strings.Repeat(req.Foo, 2) + } if dl, ok := ctx.Deadline(); ok { tp.Deadline = dl.UnixNano() } @@ -298,39 +306,270 @@ func TestServerClose(t *testing.T) { checkServerShutdown(t, server) } +func TestImmediateServerShutdown(t *testing.T) { + for i := 0; i < 1024; i++ { + var ( + ctx = context.Background() + server = mustServer(t)(NewServer()) + addr, listener = newTestListener(t) + errs = make(chan error, 1) + _, cleanup = newTestClient(t, addr) + ) + defer cleanup() + defer listener.Close() + go func() { + errs <- server.Serve(ctx, listener) + }() + + registerTestingService(server, &testingServer{}) + + if err := server.Shutdown(ctx); err != nil { + t.Fatal(err) + } + select { + case err := <-errs: + if err != ErrServerClosed { + t.Fatal(err) + } + case <-time.After(time.Second): + t.Fatal("retreiving error from server.Shutdown() timed out") + } + } +} + func TestOversizeCall(t *testing.T) { - var ( - ctx = context.Background() - server = mustServer(t)(NewServer()) - addr, listener = newTestListener(t) - errs = make(chan error, 1) - client, cleanup = newTestClient(t, addr) - ) - defer cleanup() - defer listener.Close() - go func() { - errs <- server.Serve(ctx, listener) - }() + type testCase struct { + name string + echoOnce bool + clientLimit int + serverLimit int + requestSize int + clientFail bool + sendFail bool + serverFail bool + } + + overhead := getWireMessageOverhead(t) + + clientOpts := func(tc *testCase) []ClientOpts { + if tc.clientLimit == 0 { + return nil + } + return []ClientOpts{WithClientWireMessageLimit(tc.clientLimit)} + } + serverOpts := func(tc *testCase) []ServerOpt { + if tc.serverLimit == 0 { + return nil + } + return []ServerOpt{WithServerWireMessageLimit(tc.serverLimit)} + } + + runTest := func(t *testing.T, tc *testCase) { + var ( + ctx = context.Background() + server = mustServer(t)(NewServer(serverOpts(tc)...)) + addr, listener = newTestListener(t) + errs = make(chan error, 1) + client, cleanup = newTestClient(t, addr, clientOpts(tc)...) + ) + defer cleanup() + defer listener.Close() + go func() { + errs <- server.Serve(ctx, listener) + }() + + registerTestingService(server, &testingServer{echoOnce: tc.echoOnce}) + + req := &internal.TestPayload{ + Foo: strings.Repeat("a", tc.requestSize), + } + rsp := &internal.TestPayload{} + + err := client.Call(ctx, serviceName, "Test", req, rsp) + if tc.clientFail { + if err == nil { + t.Fatalf("expected error from oversized message") + } else if status, ok := status.FromError(err); !ok { + t.Fatalf("expected status present in error: %v", err) + } else if status.Code() != codes.ResourceExhausted { + t.Fatalf("expected code: %v != %v", status.Code(), codes.ResourceExhausted) + } + if tc.sendFail { + var msgLenErr *OversizedMessageErr + if !errors.As(err, &msgLenErr) { + t.Fatalf("failed to retrieve client send OversizedMessageErr") + } + rejLen, maxLen := msgLenErr.RejectedLength(), msgLenErr.MaximumLength() + if rejLen == 0 { + t.Fatalf("zero rejected length in client send oversized message error") + } + if maxLen == 0 { + t.Fatalf("zero maximum length in client send oversized message error") + } + if rejLen <= maxLen { + t.Fatalf("client send oversized message error rejected < max. length (%d < %d)", + rejLen, maxLen) + } + } + } else if tc.serverFail { + if err == nil { + t.Fatalf("expected error from server-side oversized message") + } else { + if status, ok := status.FromError(err); !ok { + t.Fatalf("expected status present in error: %v", err) + } else if status.Code() != codes.ResourceExhausted { + t.Fatalf("expected code: %v != %v", status.Code(), codes.ResourceExhausted) + } + if msgLenErr, ok := OversizedMessageFromError(err); !ok { + t.Fatalf("failed to retrieve oversized message error") + } else { + rejLen, maxLen := msgLenErr.RejectedLength(), msgLenErr.MaximumLength() + if rejLen == 0 { + t.Fatalf("zero rejected length in oversized message error") + } + if maxLen == 0 { + t.Fatalf("zero maximum length in oversized message error") + } + if rejLen <= maxLen { + t.Fatalf("oversized message error rejected < max. length (%d < %d)", + rejLen, maxLen) + } + } + } + } else { + if err != nil { + t.Fatalf("expected success, got error %v", err) + } + } - registerTestingService(server, &testingServer{}) + if err := server.Shutdown(ctx); err != nil { + t.Fatal(err) + } + if err := <-errs; err != ErrServerClosed { + t.Fatal(err) + } + } - tp := &internal.TestPayload{ - Foo: strings.Repeat("a", 1+messageLengthMax), + for _, tc := range []*testCase{ + { + name: "default limits, fitting request and response", + echoOnce: true, + clientLimit: 0, + serverLimit: 0, + requestSize: DefaultMessageLengthLimit - overhead, + }, + { + name: "default limits, only recv side check", + clientLimit: 0, + serverLimit: 0, + requestSize: DefaultMessageLengthLimit - overhead, + serverFail: true, + }, + + { + name: "default limits, oversized request", + echoOnce: true, + clientLimit: 0, + serverLimit: 0, + requestSize: DefaultMessageLengthLimit, + clientFail: true, + }, + { + name: "default limits, oversized response", + clientLimit: 0, + serverLimit: 0, + requestSize: DefaultMessageLengthLimit / 2, + serverFail: true, + }, + { + name: "8K limits, 4K request and response", + echoOnce: true, + clientLimit: 8 * 1024, + serverLimit: 8 * 1024, + requestSize: 4 * 1024, + }, + { + name: "4K limits, barely fitting cc. 4K request and response", + echoOnce: true, + clientLimit: 4 * 1024, + serverLimit: 4 * 1024, + requestSize: 4*1024 - overhead, + }, + { + name: "4K limits, oversized request on client side", + echoOnce: true, + clientLimit: 4 * 1024, + serverLimit: 4 * 1024, + requestSize: 4 * 1024, + clientFail: true, + sendFail: true, + }, + { + name: "4K limits, oversized request on server side", + echoOnce: true, + clientLimit: 4*1024 + overhead, + serverLimit: 4 * 1024, + requestSize: 4 * 1024, + serverFail: true, + }, + { + name: "4K limits, oversized response on client side", + clientLimit: 4*1024 + overhead, + serverLimit: 4 * 1024, + requestSize: 8*1024 + overhead, + clientFail: true, + }, + { + name: "4K limits, oversized response on server side", + clientLimit: 4*1024 + overhead, + serverLimit: 4 * 1024, + requestSize: 4 * 1024, + serverFail: true, + }, + { + name: "too small limits, adjusted to minimum accepted limit", + echoOnce: true, + clientLimit: 4, + serverLimit: 4, + requestSize: 4*1024 - overhead, + }, + { + name: "maximum allowed protocol limit", + echoOnce: true, + clientLimit: MaxMessageLengthLimit, + serverLimit: MaxMessageLengthLimit, + requestSize: MaxMessageLengthLimit - overhead, + }, + } { + t.Run(tc.name, func(t *testing.T) { + runTest(t, tc) + }) } - if err := client.Call(ctx, serviceName, "Test", tp, tp); err == nil { - t.Fatalf("expected error from oversized message") - } else if status, ok := status.FromError(err); !ok { - t.Fatalf("expected status present in error: %v", err) - } else if status.Code() != codes.ResourceExhausted { - t.Fatalf("expected code: %v != %v", status.Code(), codes.ResourceExhausted) +} + +func getWireMessageOverhead(t *testing.T) int { + emptyReq, err := codec{}.Marshal(&Request{ + Service: serviceName, + Method: "Test", + }) + if err != nil { + t.Fatalf("failed to marshal empty request: %v", err) } - if err := server.Shutdown(ctx); err != nil { - t.Fatal(err) + emptyRsp, err := codec{}.Marshal(&Response{ + Status: status.New(codes.OK, "").Proto(), + }) + if err != nil { + t.Fatalf("failed to marshal empty response: %v", err) } - if err := <-errs; err != ErrServerClosed { - t.Fatal(err) + + reqLen := len(emptyReq) + rspLen := len(emptyRsp) + if reqLen > rspLen { + return reqLen + messageHeaderLength } + + return rspLen + messageHeaderLength } func TestClientEOF(t *testing.T) { @@ -551,13 +790,20 @@ func newTestClient(t testing.TB, addr string, opts ...ClientOpts) (*Client, func } func newTestListener(t testing.TB) (string, net.Listener) { - var prefix string + var ( + name = t.Name() + prefix string + ) // Abstracts sockets are only available on Linux. if runtime.GOOS == "linux" { prefix = "\x00" + } else { + if split := strings.SplitN(name, "/", 2); len(split) == 2 { + name = split[0] + "-" + fmt.Sprintf("%x", md5.Sum([]byte(split[1]))) + } } - addr := prefix + t.Name() + addr := prefix + name listener, err := net.Listen("unix", addr) if err != nil { t.Fatal(err)