Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

channel: reject oversized messages on the sender side(, too). #171

Merged
merged 3 commits into from
Sep 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions channel.go
Original file line number Diff line number Diff line change
Expand Up @@ -143,10 +143,10 @@ func (ch *channel) recv() (messageHeader, []byte, error) {
}

func (ch *channel) send(streamID uint32, t messageType, flags uint8, p []byte) error {
// TODO: Error on send rather than on recv
//if len(p) > messageLengthMax {
// return status.Errorf(codes.InvalidArgument, "refusing to send, message length %v exceed maximum message size of %v", len(p), messageLengthMax)
//}
if len(p) > messageLengthMax {
return OversizedMessageError(len(p))
}

if err := writeMessageHeader(ch.bw, ch.hwbuf[:], messageHeader{Length: uint32(len(p)), StreamID: streamID, Type: t, Flags: flags}); err != nil {
return err
}
Expand Down
24 changes: 7 additions & 17 deletions channel_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -89,21 +89,19 @@ func TestReadWriteMessage(t *testing.T) {

func TestMessageOversize(t *testing.T) {
var (
w, r = net.Pipe()
wch, rch = newChannel(w), newChannel(r)
msg = bytes.Repeat([]byte("a message of massive length"), 512<<10)
errs = make(chan error, 1)
w, _ = net.Pipe()
wch = newChannel(w)
msg = bytes.Repeat([]byte("a message of massive length"), 512<<10)
errs = make(chan error, 1)
)

go func() {
if err := wch.send(1, 1, 0, msg); err != nil {
errs <- err
}
errs <- wch.send(1, 1, 0, msg)
}()

_, _, err := rch.recv()
err := <-errs
if err == nil {
t.Fatalf("error expected reading with small buffer")
t.Fatalf("sending oversized message expected to fail")
}

status, ok := status.FromError(err)
Expand All @@ -114,12 +112,4 @@ func TestMessageOversize(t *testing.T) {
if status.Code() != codes.ResourceExhausted {
t.Fatalf("expected grpc status code: %v != %v", status.Code(), codes.ResourceExhausted)
}

select {
case err := <-errs:
if err != nil {
t.Fatal(err)
}
default:
}
}
48 changes: 47 additions & 1 deletion errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,12 @@

package ttrpc

import "errors"
import (
"errors"

"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
)

var (
// ErrProtocol is a general error in the handling the protocol.
Expand All @@ -32,3 +37,44 @@ var (
// ErrStreamClosed is when the streaming connection is closed.
ErrStreamClosed = errors.New("ttrpc: stream closed")
)

// OversizedMessageErr is used to indicate refusal to send an oversized message.
// It wraps a ResourceExhausted grpc Status together with the offending message
// length.
type OversizedMessageErr struct {
messageLength int
err error
}

// OversizedMessageError returns an OversizedMessageErr error for the given message
// length if it exceeds the allowed maximum. Otherwise a nil error is returned.
func OversizedMessageError(messageLength int) error {
if messageLength <= messageLengthMax {
return nil
}

return &OversizedMessageErr{
messageLength: messageLength,
err: status.Errorf(codes.ResourceExhausted, "message length %v exceed maximum message size of %v", messageLength, messageLengthMax),
}
}

// Error returns the error message for the corresponding grpc Status for the error.
func (e *OversizedMessageErr) Error() string {
return e.err.Error()
}

// Unwrap returns the corresponding error with our grpc status code.
func (e *OversizedMessageErr) Unwrap() error {
return e.err
}

// RejectedLength retrieves the rejected message length which triggered the error.
func (e *OversizedMessageErr) RejectedLength() int {
return e.messageLength
}

// MaximumLength retrieves the maximum allowed message length that triggered the error.
func (*OversizedMessageErr) MaximumLength() int {
return messageLengthMax
}
Loading