diff --git a/jetstream/pull.go b/jetstream/pull.go index bb5479aa0..3196ad519 100644 --- a/jetstream/pull.go +++ b/jetstream/pull.go @@ -32,7 +32,7 @@ type ( // It is returned by [Consumer.Messages] method. MessagesContext interface { // Next retrieves next message on a stream. It will block until the next - // message is available. If the context is cancelled, Next will return + // message is available. If the context is canceled, Next will return // ErrMsgIteratorClosed error. Next() (Msg, error) @@ -531,7 +531,7 @@ var ( ) // Next retrieves next message on a stream. It will block until the next -// message is available. If the context is cancelled, Next will return +// message is available. If the context is canceled, Next will return // ErrMsgIteratorClosed error. func (s *pullSubscription) Next() (Msg, error) { s.Lock() @@ -1081,7 +1081,7 @@ type backoffOpts struct { // for all subsequent retries after reaching the limit customBackoff []time.Duration // cancel channel - // if set, retry will be cancelled when this channel is closed + // if set, retry will be canceled when this channel is closed cancel <-chan struct{} } diff --git a/js.go b/js.go index 462fea17e..97ce71d26 100644 --- a/js.go +++ b/js.go @@ -2861,7 +2861,13 @@ func (sub *Subscription) Fetch(batch int, opts ...PullOpt) ([]*Msg, error) { } var hbTimer *time.Timer var hbErr error - if err == nil && len(msgs) < batch { + sub.mu.Lock() + subClosed := sub.closed || sub.draining + sub.mu.Unlock() + if subClosed { + err = errors.Join(ErrBadSubscription, ErrSubscriptionClosed) + } + if err == nil && len(msgs) < batch && !subClosed { // For batch real size of 1, it does not make sense to set no_wait in // the request. noWait := batch-len(msgs) > 1 @@ -3129,8 +3135,14 @@ func (sub *Subscription) FetchBatch(batch int, opts ...PullOpt) (MessageBatch, e result.msgs <- msg } } - if len(result.msgs) == batch || result.err != nil { + sub.mu.Lock() + subClosed := sub.closed || sub.draining + sub.mu.Unlock() + if len(result.msgs) == batch || result.err != nil || subClosed { close(result.msgs) + if subClosed && len(result.msgs) == 0 { + return nil, errors.Join(ErrBadSubscription, ErrSubscriptionClosed) + } result.done <- struct{}{} return result, nil } diff --git a/jserrors.go b/jserrors.go index b5c968465..2a160405c 100644 --- a/jserrors.go +++ b/jserrors.go @@ -141,6 +141,9 @@ var ( // ErrNoHeartbeat is returned when no heartbeat is received from server when sending requests with pull consumer. ErrNoHeartbeat JetStreamError = &jsError{message: "no heartbeat received"} + // ErrSubscriptionClosed is returned when attempting to send pull request to a closed subscription + ErrSubscriptionClosed JetStreamError = &jsError{message: "subscription closed"} + // DEPRECATED: ErrInvalidDurableName is no longer returned and will be removed in future releases. // Use ErrInvalidConsumerName instead. ErrInvalidDurableName = errors.New("nats: invalid durable name") diff --git a/test/js_test.go b/test/js_test.go index 540ae41c9..279736c16 100644 --- a/test/js_test.go +++ b/test/js_test.go @@ -1239,6 +1239,64 @@ func TestPullSubscribeFetchWithHeartbeat(t *testing.T) { } } +func TestPullSubscribeFetchDrain(t *testing.T) { + s := RunBasicJetStreamServer() + defer shutdownJSServerAndRemoveStorage(t, s) + + nc, js := jsClient(t, s) + defer nc.Close() + + _, err := js.AddStream(&nats.StreamConfig{ + Name: "TEST", + Subjects: []string{"foo"}, + }) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + defer js.PurgeStream("TEST") + sub, err := js.PullSubscribe("foo", "") + if err != nil { + t.Fatalf("Unexpected error: %s", err) + } + for i := 0; i < 100; i++ { + if _, err := js.Publish("foo", []byte("msg")); err != nil { + t.Fatalf("Unexpected error: %s", err) + } + } + // fill buffer with messages + cinfo, err := sub.ConsumerInfo() + if err != nil { + t.Fatalf("Unexpected error: %s", err) + } + nextSubject := fmt.Sprintf("$JS.API.CONSUMER.MSG.NEXT.TEST.%s", cinfo.Name) + replySubject := strings.Replace(sub.Subject, "*", "abc", 1) + payload := `{"batch":10,"no_wait":true}` + if err := nc.PublishRequest(nextSubject, replySubject, []byte(payload)); err != nil { + t.Fatalf("Unexpected error: %s", err) + } + time.Sleep(100 * time.Millisecond) + + // now drain the subscription, messages should be in the buffer + sub.Drain() + msgs, err := sub.Fetch(100) + if err != nil { + t.Fatalf("Unexpected error: %s", err) + } + for _, msg := range msgs { + msg.Ack() + } + if len(msgs) != 10 { + t.Fatalf("Expected %d messages; got: %d", 10, len(msgs)) + } + + // subsequent fetch should return error, subscription is already drained + _, err = sub.Fetch(10, nats.MaxWait(100*time.Millisecond)) + if !errors.Is(err, nats.ErrSubscriptionClosed) { + t.Fatalf("Expected error: %s; got: %s", nats.ErrSubscriptionClosed, err) + } +} + func TestPullSubscribeFetchBatchWithHeartbeat(t *testing.T) { s := RunBasicJetStreamServer() defer shutdownJSServerAndRemoveStorage(t, s) @@ -1761,6 +1819,55 @@ func TestPullSubscribeFetchBatch(t *testing.T) { t.Errorf("Expected error: %s; got: %s", nats.ErrNoDeadlineContext, err) } }) + + t.Run("close subscription", func(t *testing.T) { + defer js.PurgeStream("TEST") + sub, err := js.PullSubscribe("foo", "") + if err != nil { + t.Fatalf("Unexpected error: %s", err) + } + for i := 0; i < 100; i++ { + if _, err := js.Publish("foo", []byte("msg")); err != nil { + t.Fatalf("Unexpected error: %s", err) + } + } + // fill buffer with messages + cinfo, err := sub.ConsumerInfo() + if err != nil { + t.Fatalf("Unexpected error: %s", err) + } + nextSubject := fmt.Sprintf("$JS.API.CONSUMER.MSG.NEXT.TEST.%s", cinfo.Name) + replySubject := strings.Replace(sub.Subject, "*", "abc", 1) + payload := `{"batch":10,"no_wait":true}` + if err := nc.PublishRequest(nextSubject, replySubject, []byte(payload)); err != nil { + t.Fatalf("Unexpected error: %s", err) + } + time.Sleep(100 * time.Millisecond) + + // now drain the subscription, messages should be in the buffer + sub.Drain() + res, err := sub.FetchBatch(100) + if err != nil { + t.Fatalf("Unexpected error: %s", err) + } + msgs := make([]*nats.Msg, 0) + for msg := range res.Messages() { + msgs = append(msgs, msg) + msg.Ack() + } + if res.Error() != nil { + t.Fatalf("Unexpected error: %s", res.Error()) + } + if len(msgs) != 10 { + t.Fatalf("Expected %d messages; got: %d", 10, len(msgs)) + } + + // subsequent fetch should return error, subscription is already drained + _, err = sub.FetchBatch(10, nats.MaxWait(100*time.Millisecond)) + if !errors.Is(err, nats.ErrSubscriptionClosed) { + t.Fatalf("Expected error: %s; got: %s", nats.ErrSubscriptionClosed, err) + } + }) } func TestPullSubscribeConsumerDeleted(t *testing.T) { @@ -7646,7 +7753,7 @@ func testJetStreamFetchOptions(t *testing.T, srvs ...*jsServer) { if err == nil { t.Fatal("Unexpected success") } - if err != nats.ErrBadSubscription { + if !errors.Is(err, nats.ErrBadSubscription) { t.Fatalf("Unexpected error: %v", err) } })