Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

client, server: implement configurable wire message size limits. #172

Open
wants to merge 7 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 45 additions & 20 deletions channel.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,13 @@ import (
"io"
"net"
"sync"

"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
)

const (
messageHeaderLength = 10
messageLengthMax = 4 << 20
messageHeaderLength = 10
MinMessageLengthLimit = 4 << 10
MaxMessageLengthLimit = 4 << 22
DefaultMessageLengthLimit = 4 << 20
)

type messageType uint8
Expand Down Expand Up @@ -96,18 +95,23 @@ func writeMessageHeader(w io.Writer, p []byte, mh messageHeader) error {
var buffers sync.Pool

type channel struct {
conn net.Conn
bw *bufio.Writer
br *bufio.Reader
hrbuf [messageHeaderLength]byte // avoid alloc when reading header
hwbuf [messageHeaderLength]byte
conn net.Conn
bw *bufio.Writer
br *bufio.Reader
hrbuf [messageHeaderLength]byte // avoid alloc when reading header
hwbuf [messageHeaderLength]byte
maxMsgLen int
}

func newChannel(conn net.Conn) *channel {
func newChannel(conn net.Conn, maxMsgLen int) *channel {
if maxMsgLen == 0 {
maxMsgLen = DefaultMessageLengthLimit
}
return &channel{
conn: conn,
bw: bufio.NewWriter(conn),
br: bufio.NewReader(conn),
conn: conn,
bw: bufio.NewWriter(conn),
br: bufio.NewReader(conn),
maxMsgLen: maxMsgLen,
}
}

Expand All @@ -123,12 +127,12 @@ func (ch *channel) recv() (messageHeader, []byte, error) {
return messageHeader{}, nil, err
}

if mh.Length > uint32(messageLengthMax) {
if maxMsgLen := ch.maxMsgLimit(true); mh.Length > uint32(maxMsgLen) {
if _, err := ch.br.Discard(int(mh.Length)); err != nil {
return mh, nil, fmt.Errorf("failed to discard after receiving oversized message: %w", err)
}

return mh, nil, status.Errorf(codes.ResourceExhausted, "message length %v exceed maximum message size of %v", mh.Length, messageLengthMax)
return mh, nil, OversizedMessageError(int(mh.Length), maxMsgLen)
}

var p []byte
Expand All @@ -143,10 +147,12 @@ func (ch *channel) recv() (messageHeader, []byte, error) {
}

func (ch *channel) send(streamID uint32, t messageType, flags uint8, p []byte) error {
// TODO: Error on send rather than on recv
//if len(p) > messageLengthMax {
// return status.Errorf(codes.InvalidArgument, "refusing to send, message length %v exceed maximum message size of %v", len(p), messageLengthMax)
//}
if maxMsgLen := ch.maxMsgLimit(false); maxMsgLen != 0 {
if len(p) > maxMsgLen {
return OversizedMessageError(len(p), maxMsgLen)
}
}

if err := writeMessageHeader(ch.bw, ch.hwbuf[:], messageHeader{Length: uint32(len(p)), StreamID: streamID, Type: t, Flags: flags}); err != nil {
return err
}
Expand Down Expand Up @@ -180,3 +186,22 @@ func (ch *channel) getmbuf(size int) []byte {
func (ch *channel) putmbuf(p []byte) {
buffers.Put(&p)
}

func (ch *channel) maxMsgLimit(recv bool) int {
if ch.maxMsgLen == 0 && recv {
return DefaultMessageLengthLimit
}
return ch.maxMsgLen
}

func clampWireMessageLimit(maxMsgLen int) int {
switch {
case maxMsgLen == 0:
return 0
case maxMsgLen < MinMessageLengthLimit:
return MinMessageLengthLimit
case maxMsgLen > MaxMessageLengthLimit:
return MaxMessageLengthLimit
}
return maxMsgLen
}
28 changes: 9 additions & 19 deletions channel_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,8 @@ import (
func TestReadWriteMessage(t *testing.T) {
var (
w, r = net.Pipe()
ch = newChannel(w)
rch = newChannel(r)
ch = newChannel(w, 0)
rch = newChannel(r, 0)
messages = [][]byte{
[]byte("hello"),
[]byte("this is a test"),
Expand Down Expand Up @@ -89,21 +89,19 @@ func TestReadWriteMessage(t *testing.T) {

func TestMessageOversize(t *testing.T) {
var (
w, r = net.Pipe()
wch, rch = newChannel(w), newChannel(r)
msg = bytes.Repeat([]byte("a message of massive length"), 512<<10)
errs = make(chan error, 1)
w, _ = net.Pipe()
wch = newChannel(w, 0)
msg = bytes.Repeat([]byte("a message of massive length"), 512<<10)
errs = make(chan error, 1)
)

go func() {
if err := wch.send(1, 1, 0, msg); err != nil {
errs <- err
}
errs <- wch.send(1, 1, 0, msg)
}()

_, _, err := rch.recv()
err := <-errs
if err == nil {
t.Fatalf("error expected reading with small buffer")
t.Fatalf("sending oversized message expected to fail")
}

status, ok := status.FromError(err)
Expand All @@ -114,12 +112,4 @@ func TestMessageOversize(t *testing.T) {
if status.Code() != codes.ResourceExhausted {
t.Fatalf("expected grpc status code: %v != %v", status.Code(), codes.ResourceExhausted)
}

select {
case err := <-errs:
if err != nil {
t.Fatal(err)
}
default:
}
}
19 changes: 14 additions & 5 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,10 @@ import (

// Client for a ttrpc server
type Client struct {
codec codec
conn net.Conn
channel *channel
codec codec
conn net.Conn
channel *channel
maxMsgLen int

streamLock sync.RWMutex
streams map[streamID]*stream
Expand Down Expand Up @@ -107,14 +108,20 @@ func chainUnaryInterceptors(interceptors []UnaryClientInterceptor, final Invoker
}
}

// WithClientWireMessageLimit sets the maximum allowed message length on the wire for the client.
func WithClientWireMessageLimit(maxMsgLen int) ClientOpts {
maxMsgLen = clampWireMessageLimit(maxMsgLen)
return func(c *Client) {
c.maxMsgLen = maxMsgLen
}
}

// NewClient creates a new ttrpc client using the given connection
func NewClient(conn net.Conn, opts ...ClientOpts) *Client {
ctx, cancel := context.WithCancel(context.Background())
channel := newChannel(conn)
c := &Client{
codec: codec{},
conn: conn,
channel: channel,
streams: make(map[streamID]*stream),
nextStreamID: 1,
closed: cancel,
Expand All @@ -127,6 +134,8 @@ func NewClient(conn net.Conn, opts ...ClientOpts) *Client {
o(c)
}

c.channel = newChannel(conn, c.maxMsgLen)

if c.interceptor == nil {
c.interceptor = defaultClientInterceptor
}
Expand Down
10 changes: 10 additions & 0 deletions config.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import (
type serverConfig struct {
handshaker Handshaker
interceptor UnaryServerInterceptor
maxMsgLen int
}

// ServerOpt for configuring a ttrpc server
Expand Down Expand Up @@ -84,3 +85,12 @@ func chainUnaryServerInterceptors(info *UnaryServerInfo, method Method, intercep
chainUnaryServerInterceptors(info, method, interceptors[1:]))
}
}

// WithServerWireMessageLimit sets the maximum allowed message length on the wire for the server.
func WithServerWireMessageLimit(maxMsgLen int) ServerOpt {
maxMsgLen = clampWireMessageLimit(maxMsgLen)
return func(c *serverConfig) error {
c.maxMsgLen = maxMsgLen
return nil
}
}
88 changes: 87 additions & 1 deletion errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,13 @@

package ttrpc

import "errors"
import (
"errors"
"fmt"

"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
)

var (
// ErrProtocol is a general error in the handling the protocol.
Expand All @@ -32,3 +38,83 @@ var (
// ErrStreamClosed is when the streaming connection is closed.
ErrStreamClosed = errors.New("ttrpc: stream closed")
)

// OversizedMessageErr is used to indicate refusal to send an oversized message.
// It wraps a ResourceExhausted grpc Status together with the offending message
// length.
type OversizedMessageErr struct {
messageLength int
maxLength int
err error
}

var (
oversizedMsgFmt = "message length %d exceeds maximum message size of %d"
oversizedMsgScanFmt = fmt.Sprintf("%v", status.New(codes.ResourceExhausted, oversizedMsgFmt))
)

// OversizedMessageError returns an OversizedMessageErr error for the given message
// length if it exceeds the allowed maximum. Otherwise a nil error is returned.
func OversizedMessageError(messageLength, maxLength int) error {
if messageLength <= maxLength {
return nil
}

return &OversizedMessageErr{
messageLength: messageLength,
maxLength: maxLength,
err: OversizedMessageStatus(messageLength, maxLength).Err(),
}
}

// OversizedMessageStatus returns a Status for an oversized message error.
func OversizedMessageStatus(messageLength, maxLength int) *status.Status {
return status.Newf(codes.ResourceExhausted, oversizedMsgFmt, messageLength, maxLength)
}

// OversizedMessageFromError reconstructs an OversizedMessageErr from a Status.
func OversizedMessageFromError(err error) (*OversizedMessageErr, bool) {
var (
messageLength int
maxLength int
)

st, ok := status.FromError(err)
if !ok || st.Code() != codes.ResourceExhausted {
return nil, false
}

// TODO(klihub): might be too ugly to recover an error this way... An
// alternative would be to define our custom status detail proto type,
// then use status.WithDetails() and status.Details().

n, _ := fmt.Sscanf(st.Message(), oversizedMsgScanFmt, &messageLength, &maxLength)
if n != 2 {
n, _ = fmt.Sscanf(st.Message(), oversizedMsgFmt, &messageLength, &maxLength)
}
if n != 2 {
return nil, false
}

return OversizedMessageError(messageLength, maxLength).(*OversizedMessageErr), true
}

// Error returns the error message for the corresponding grpc Status for the error.
func (e *OversizedMessageErr) Error() string {
return e.err.Error()
}

// Unwrap returns the corresponding error with our grpc status code.
func (e *OversizedMessageErr) Unwrap() error {
return e.err
}

// RejectedLength retrieves the rejected message length which triggered the error.
func (e *OversizedMessageErr) RejectedLength() int {
return e.messageLength
}

// MaximumLength retrieves the maximum allowed message length that triggered the error.
func (e *OversizedMessageErr) MaximumLength() int {
return e.maxLength
}
36 changes: 31 additions & 5 deletions server.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,9 +74,18 @@ func (s *Server) RegisterService(name string, desc *ServiceDesc) {
}

func (s *Server) Serve(ctx context.Context, l net.Listener) error {
s.addListener(l)
s.mu.Lock()
s.addListenerLocked(l)
defer s.closeListener(l)

select {
case <-s.done:
s.mu.Unlock()
return ErrServerClosed
default:
}
s.mu.Unlock()

var (
backoff time.Duration
handshaker = s.config.handshaker
Expand Down Expand Up @@ -188,9 +197,7 @@ func (s *Server) Close() error {
return err
}

func (s *Server) addListener(l net.Listener) {
s.mu.Lock()
defer s.mu.Unlock()
func (s *Server) addListenerLocked(l net.Listener) {
s.listeners[l] = struct{}{}
}

Expand Down Expand Up @@ -332,7 +339,7 @@ func (c *serverConn) run(sctx context.Context) {
)

var (
ch = newChannel(c.conn)
ch = newChannel(c.conn, c.server.config.maxMsgLen)
ctx, cancel = context.WithCancel(sctx)
state connState = connStateIdle
responses = make(chan response)
Expand Down Expand Up @@ -366,6 +373,14 @@ func (c *serverConn) run(sctx context.Context) {
}
}

isResourceExhaustedError := func(err error) (*status.Status, bool) {
st, ok := status.FromError(err)
if !ok || st.Code() != codes.ResourceExhausted {
return nil, false
}
return st, true
}

go func(recvErr chan error) {
defer close(recvErr)
for {
Expand Down Expand Up @@ -518,6 +533,17 @@ func (c *serverConn) run(sctx context.Context) {
}

if err := ch.send(response.id, messageTypeResponse, 0, p); err != nil {
if st, ok := isResourceExhaustedError(err); ok {
p, err = c.server.codec.Marshal(&Response{
Status: st.Proto(),
})
if err != nil {
log.G(ctx).WithError(err).Error("failed marshaling error response")
return
}
ch.send(response.id, messageTypeResponse, 0, p)
return
}
log.G(ctx).WithError(err).Error("failed sending message on channel")
return
}
Expand Down
Loading
Loading