Improving some internal error-handling #846

Merged
merged 5 commits on Apr 1, 2022
Changes from all commits
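The diff below replaces direct comparisons against sentinel errors (err == io.EOF, err == errShortRead, ...) with errors.Is, swaps type assertions such as err.(Error) for errors.As, and wraps two returned errors with fmt.Errorf("...: %w", err) for context. Both helpers walk a chain of wrapped errors, so the checks keep working even when some layer adds context with %w. A minimal, standalone sketch of the sentinel case (not code from this repo):

package main

import (
	"errors"
	"fmt"
	"io"
)

func main() {
	// An io.EOF that was wrapped with %w somewhere along the way.
	err := fmt.Errorf("reading batch: %w", io.EOF)

	// A direct comparison no longer matches once the error is wrapped.
	fmt.Println(err == io.EOF) // false

	// errors.Is walks the wrap chain and still finds the sentinel.
	fmt.Println(errors.Is(err, io.EOF)) // true
}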
14 changes: 8 additions & 6 deletions batch.go
@@ -2,6 +2,7 @@ package kafka

import (
"bufio"
"errors"
"io"
"sync"
"time"
@@ -82,7 +83,7 @@ func (batch *Batch) close() (err error) {
batch.msgs.discard()
}

if err = batch.err; err == io.EOF {
if err = batch.err; errors.Is(batch.err, io.EOF) {
err = nil
}

@@ -93,7 +94,8 @@ func (batch *Batch) close() (err error) {
conn.mutex.Unlock()

if err != nil {
if _, ok := err.(Error); !ok && err != io.ErrShortBuffer {
var kafkaError Error
if !errors.As(err, &kafkaError) && !errors.Is(err, io.ErrShortBuffer) {
conn.Close()
}
}
@@ -238,11 +240,11 @@ func (batch *Batch) readMessage(

var lastOffset int64
offset, lastOffset, timestamp, headers, err = batch.msgs.readMessage(batch.offset, key, val)
switch err {
case nil:
switch {
case err == nil:
batch.offset = offset + 1
batch.lastOffset = lastOffset
case errShortRead:
case errors.Is(err, errShortRead):
// As an "optimization" kafka truncates the returned response after
// producing MaxBytes, which could then cause the code to return
// errShortRead.
@@ -272,7 +274,7 @@ func (batch *Batch) readMessage(
// to MaxBytes truncation
// - `batch.lastOffset` to ensure that the message format contains
// `lastOffset`
if batch.err == io.EOF && batch.msgs.lengthRemain == 0 && batch.lastOffset != -1 {
if errors.Is(batch.err, io.EOF) && batch.msgs.lengthRemain == 0 && batch.lastOffset != -1 {
// Log compaction can create batches that end with compacted
// records so the normal strategy that increments the "next"
// offset as records are read doesn't work as the compacted
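The close path in batch.go above also swaps the err.(Error) type assertion for errors.As. A standalone sketch of why that matters once errors get wrapped (this Error type is a stand-in for kafka-go's protocol error, not the real one):

package main

import (
	"errors"
	"fmt"
)

// Error is a stand-in for a typed error such as kafka-go's protocol Error.
type Error int

func (e Error) Error() string { return fmt.Sprintf("protocol error %d", int(e)) }

func main() {
	// The typed error arrives wrapped in extra context.
	err := fmt.Errorf("closing batch: %w", Error(42))

	// A plain type assertion only sees the outer wrapper.
	_, ok := err.(Error)
	fmt.Println(ok) // false

	// errors.As unwraps until it finds a value of the target type.
	var kafkaError Error
	fmt.Println(errors.As(err, &kafkaError), kafkaError) // true protocol error 42
}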
5 changes: 3 additions & 2 deletions batch_test.go
@@ -2,6 +2,7 @@ package kafka

import (
"context"
"errors"
"io"
"net"
"strconv"
@@ -30,11 +31,11 @@ func TestBatchDontExpectEOF(t *testing.T) {

batch := conn.ReadBatch(1024, 8192)

if _, err := batch.ReadMessage(); err != io.ErrUnexpectedEOF {
if _, err := batch.ReadMessage(); !errors.Is(err, io.ErrUnexpectedEOF) {
t.Error("bad error when reading message:", err)
}

if err := batch.Close(); err != io.ErrUnexpectedEOF {
if err := batch.Close(); !errors.Is(err, io.ErrUnexpectedEOF) {
t.Error("bad error when closing the batch:", err)
}
}
5 changes: 3 additions & 2 deletions client.go
@@ -3,6 +3,7 @@ package kafka
import (
"context"
"errors"
"fmt"
"net"
"time"

@@ -67,7 +68,7 @@ func (c *Client) ConsumerOffsets(ctx context.Context, tg TopicAndGroup) (map[int
})

if err != nil {
return nil, err
return nil, fmt.Errorf("failed to get topic metadata :%w", err)
}

topic := metadata.Topics[0]
@@ -85,7 +86,7 @@ func (c *Client) ConsumerOffsets(ctx context.Context, tg TopicAndGroup) (map[int
})

if err != nil {
return nil, err
return nil, fmt.Errorf("failed to get offsets: %w", err)
}

topicOffsets := offsets.Topics[topic.Name]
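The two returns in ConsumerOffsets now wrap the underlying error with %w, which adds context without hiding the cause from errors.Is. A small self-contained sketch of that behavior (the sentinel and helper here are made up for illustration, not part of kafka-go):

package main

import (
	"errors"
	"fmt"
)

// errBrokerDown stands in for whatever the metadata request actually returns.
var errBrokerDown = errors.New("broker down")

func fetchMetadata() error {
	return errBrokerDown
}

func consumerOffsets() error {
	if err := fetchMetadata(); err != nil {
		// Wrapping with %w adds context and keeps the cause inspectable.
		return fmt.Errorf("failed to get topic metadata: %w", err)
	}
	return nil
}

func main() {
	err := consumerOffsets()
	fmt.Println(err)                           // failed to get topic metadata: broker down
	fmt.Println(errors.Is(err, errBrokerDown)) // true: the cause survives wrapping
}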
7 changes: 4 additions & 3 deletions compress/snappy/xerial.go
@@ -3,6 +3,7 @@ package snappy
import (
"bytes"
"encoding/binary"
"errors"
"io"

"github.com/klauspost/compress/snappy"
@@ -64,7 +65,7 @@ func (x *xerialReader) WriteTo(w io.Writer) (int64, error) {
}

if _, err := x.readChunk(nil); err != nil {
if err == io.EOF {
if errors.Is(err, io.EOF) {
err = nil
}
return wn, err
@@ -128,7 +129,7 @@ func (x *xerialReader) readChunk(dst []byte) (int, error) {
n, err := x.read(x.input[len(x.input):cap(x.input)])
x.input = x.input[:len(x.input)+n]
if err != nil {
if err == io.EOF && len(x.input) > 0 {
if errors.Is(err, io.EOF) && len(x.input) > 0 {
break
}
return 0, err
@@ -212,7 +213,7 @@ func (x *xerialWriter) ReadFrom(r io.Reader) (int64, error) {
}

if err != nil {
if err == io.EOF {
if errors.Is(err, io.EOF) {
err = nil
}
return wn, err
5 changes: 3 additions & 2 deletions conn.go
@@ -853,7 +853,7 @@ func (c *Conn) ReadBatchWith(cfg ReadBatchConfig) *Batch {
default:
throttle, highWaterMark, remain, err = readFetchResponseHeaderV2(&c.rbuf, size)
}
if err == errShortRead {
if errors.Is(err, errShortRead) {
err = checkTimeoutErr(adjustedDeadline)
}

@@ -865,9 +865,10 @@ func (c *Conn) ReadBatchWith(cfg ReadBatchConfig) *Batch {
msgs, err = newMessageSetReader(&c.rbuf, remain)
}
}
if err == errShortRead {
if errors.Is(err, errShortRead) {
err = checkTimeoutErr(adjustedDeadline)
}

return &Batch{
conn: c,
msgs: msgs,
26 changes: 15 additions & 11 deletions conn_test.go
@@ -3,6 +3,7 @@ package kafka
import (
"bytes"
"context"
"errors"
"fmt"
"io"
"math/rand"
@@ -640,10 +641,13 @@ func testConnReadBatchWithMaxWait(t *testing.T, conn *Conn) {
conn.Seek(0, SeekAbsolute)
conn.SetDeadline(time.Now().Add(50 * time.Millisecond))
batch = conn.ReadBatchWith(cfg)
var netErr net.Error
if err := batch.Err(); err == nil {
t.Fatal("should have timed out, but got no error")
} else if netErr, ok := err.(net.Error); !ok || !netErr.Timeout() {
t.Fatalf("should have timed out, but got: %v", err)
} else if errors.As(err, &netErr) {
if !netErr.Timeout() {
t.Fatalf("should have timed out, but got: %v", err)
}
}
}

@@ -761,7 +765,7 @@ func testConnFindCoordinator(t *testing.T, conn *Conn) {

func testConnJoinGroupInvalidGroupID(t *testing.T, conn *Conn) {
_, err := conn.joinGroup(joinGroupRequestV1{})
if err != InvalidGroupId && err != NotCoordinatorForGroup {
if !errors.Is(err, InvalidGroupId) && !errors.Is(err, NotCoordinatorForGroup) {
t.Fatalf("expected %v or %v; got %v", InvalidGroupId, NotCoordinatorForGroup, err)
}
}
@@ -773,7 +777,7 @@ func testConnJoinGroupInvalidSessionTimeout(t *testing.T, conn *Conn) {
_, err := conn.joinGroup(joinGroupRequestV1{
GroupID: groupID,
})
if err != InvalidSessionTimeout && err != NotCoordinatorForGroup {
if !errors.Is(err, InvalidSessionTimeout) && !errors.Is(err, NotCoordinatorForGroup) {
t.Fatalf("expected %v or %v; got %v", InvalidSessionTimeout, NotCoordinatorForGroup, err)
}
}
@@ -786,7 +790,7 @@ func testConnJoinGroupInvalidRefreshTimeout(t *testing.T, conn *Conn) {
GroupID: groupID,
SessionTimeout: int32(3 * time.Second / time.Millisecond),
})
if err != InvalidSessionTimeout && err != NotCoordinatorForGroup {
if !errors.Is(err, InvalidSessionTimeout) && !errors.Is(err, NotCoordinatorForGroup) {
t.Fatalf("expected %v or %v; got %v", InvalidSessionTimeout, NotCoordinatorForGroup, err)
}
}
@@ -798,7 +802,7 @@ func testConnHeartbeatErr(t *testing.T, conn *Conn) {
_, err := conn.syncGroup(syncGroupRequestV0{
GroupID: groupID,
})
if err != UnknownMemberId && err != NotCoordinatorForGroup {
if !errors.Is(err, UnknownMemberId) && !errors.Is(err, NotCoordinatorForGroup) {
t.Fatalf("expected %v or %v; got %v", UnknownMemberId, NotCoordinatorForGroup, err)
}
}
@@ -810,7 +814,7 @@ func testConnLeaveGroupErr(t *testing.T, conn *Conn) {
_, err := conn.leaveGroup(leaveGroupRequestV0{
GroupID: groupID,
})
if err != UnknownMemberId && err != NotCoordinatorForGroup {
if !errors.Is(err, UnknownMemberId) && !errors.Is(err, NotCoordinatorForGroup) {
t.Fatalf("expected %v or %v; got %v", UnknownMemberId, NotCoordinatorForGroup, err)
}
}
@@ -822,7 +826,7 @@ func testConnSyncGroupErr(t *testing.T, conn *Conn) {
_, err := conn.syncGroup(syncGroupRequestV0{
GroupID: groupID,
})
if err != UnknownMemberId && err != NotCoordinatorForGroup {
if !errors.Is(err, UnknownMemberId) && !errors.Is(err, NotCoordinatorForGroup) {
t.Fatalf("expected %v or %v; got %v", UnknownMemberId, NotCoordinatorForGroup, err)
}
}
@@ -985,7 +989,7 @@ func testConnReadShortBuffer(t *testing.T, conn *Conn) {
b[3] = 0

n, err := conn.Read(b)
if err != io.ErrShortBuffer {
if !errors.Is(err, io.ErrShortBuffer) {
t.Error("bad error:", i, err)
}
if n != 4 {
@@ -1061,7 +1065,7 @@ func testDeleteTopicsInvalidTopic(t *testing.T, conn *Conn) {
}
conn.SetDeadline(time.Now().Add(5 * time.Second))
err = conn.DeleteTopics("invalid-topic", topic)
if err != UnknownTopicOrPartition {
if !errors.Is(err, UnknownTopicOrPartition) {
t.Fatalf("expected UnknownTopicOrPartition error, but got %v", err)
}
partitions, err := conn.ReadPartitions(topic)
@@ -1154,7 +1158,7 @@ func TestUnsupportedSASLMechanism(t *testing.T) {
}
defer conn.Close()

if err := conn.saslHandshake("FOO"); err != UnsupportedSASLMechanism {
if err := conn.saslHandshake("FOO"); !errors.Is(err, UnsupportedSASLMechanism) {
t.Errorf("Expected UnsupportedSASLMechanism but got %v", err)
}
}
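testConnReadBatchWithMaxWait above now extracts the net.Error with errors.As instead of a direct type assertion. A standalone sketch of that pattern, using os.ErrDeadlineExceeded as a stand-in timeout (it is what net operations return when a deadline expires, and it satisfies net.Error):

package main

import (
	"errors"
	"fmt"
	"net"
	"os"
)

func main() {
	// Simulate a timeout that was wrapped with extra context on its way up.
	err := fmt.Errorf("reading batch: %w", os.ErrDeadlineExceeded)

	// errors.As finds a net.Error anywhere in the wrap chain, replacing the
	// old err.(net.Error) type assertion.
	var netErr net.Error
	if errors.As(err, &netErr) && netErr.Timeout() {
		fmt.Println("timed out:", netErr)
	}
}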
2 changes: 1 addition & 1 deletion consumergroup.go
@@ -1026,7 +1026,7 @@ func (cg *ConsumerGroup) assignTopicPartitions(conn coordinator, group joinGroup
// assignments for the topic. this matches the behavior of the official
// clients: java, python, and librdkafka.
// a topic watcher can trigger a rebalance when the topic comes into being.
if err != nil && err != UnknownTopicOrPartition {
if err != nil && !errors.Is(err, UnknownTopicOrPartition) {
return nil, err
}

12 changes: 6 additions & 6 deletions consumergroup_test.go
@@ -285,7 +285,7 @@ func TestConsumerGroup(t *testing.T) {
if gen != nil {
t.Errorf("expected generation to be nil")
}
if err != context.Canceled {
if !errors.Is(err, context.Canceled) {
t.Errorf("expected context.Canceled, but got %+v", err)
}
},
@@ -301,7 +301,7 @@ func TestConsumerGroup(t *testing.T) {
if gen != nil {
t.Errorf("expected generation to be nil")
}
if err != ErrGroupClosed {
if !errors.Is(err, ErrGroupClosed) {
t.Errorf("expected ErrGroupClosed, but got %+v", err)
}
},
@@ -398,7 +398,7 @@ func TestConsumerGroupErrors(t *testing.T) {
gen, err := group.Next(ctx)
if err == nil {
t.Errorf("expected an error")
} else if err != NotCoordinatorForGroup {
} else if !errors.Is(err, NotCoordinatorForGroup) {
t.Errorf("got wrong error: %+v", err)
}
if gen != nil {
@@ -460,7 +460,7 @@ func TestConsumerGroupErrors(t *testing.T) {
gen, err := group.Next(ctx)
if err == nil {
t.Errorf("expected an error")
} else if err != InvalidTopic {
} else if !errors.Is(err, InvalidTopic) {
t.Errorf("got wrong error: %+v", err)
}
if gen != nil {
@@ -540,7 +540,7 @@ func TestConsumerGroupErrors(t *testing.T) {
gen, err := group.Next(ctx)
if err == nil {
t.Errorf("expected an error")
} else if err != InvalidTopic {
} else if !errors.Is(err, InvalidTopic) {
t.Errorf("got wrong error: %+v", err)
}
if gen != nil {
@@ -672,7 +672,7 @@ func TestGenerationStartsFunctionAfterClosed(t *testing.T) {
case <-time.After(time.Second):
t.Fatal("timed out waiting for func to run")
case err := <-ch:
if err != ErrGenerationEnded {
if !errors.Is(err, ErrGenerationEnded) {
t.Fatalf("expected %v but got %v", ErrGenerationEnded, err)
}
}