Skip to content

Commit 780bda4

Browse files
committed
Fix race with c.readerShouldLock
Closes #168
1 parent e36318f commit 780bda4

File tree

1 file changed

+36
-17
lines changed

1 file changed

+36
-17
lines changed

conn.go

Lines changed: 36 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -78,11 +78,10 @@ type Conn struct {
7878
readLock chan struct{}
7979

8080
// messageReader state.
81-
readerMsgCtx context.Context
82-
readerMsgHeader header
83-
readerFrameEOF bool
84-
readerMaskPos int
85-
readerShouldLock bool
81+
readerMsgCtx context.Context
82+
readerMsgHeader header
83+
readerFrameEOF bool
84+
readerMaskPos int
8685

8786
setReadTimeout chan context.Context
8887
setWriteTimeout chan context.Context
@@ -445,7 +444,6 @@ func (c *Conn) reader(ctx context.Context, lock bool) (MessageType, io.Reader, e
445444
c.readerFrameEOF = false
446445
c.readerMaskPos = 0
447446
c.readMsgLeft = c.msgReadLimit.Load()
448-
c.readerShouldLock = lock
449447

450448
r := &messageReader{
451449
c: c,
@@ -465,7 +463,11 @@ func (r *messageReader) eof() bool {
465463

466464
// Read reads as many bytes as possible into p.
467465
func (r *messageReader) Read(p []byte) (int, error) {
468-
n, err := r.read(p)
466+
return r.exportedRead(p, true)
467+
}
468+
469+
func (r *messageReader) exportedRead(p []byte, lock bool) (int, error) {
470+
n, err := r.read(p, lock)
469471
if err != nil {
470472
// Have to return io.EOF directly for now, we cannot wrap as errors.Is
471473
// isn't used widely yet.
@@ -477,17 +479,29 @@ func (r *messageReader) Read(p []byte) (int, error) {
477479
return n, nil
478480
}
479481

480-
func (r *messageReader) read(p []byte) (int, error) {
481-
if r.c.readerShouldLock {
482-
err := r.c.acquireLock(r.c.readerMsgCtx, r.c.readLock)
483-
if err != nil {
484-
return 0, err
482+
func (r *messageReader) readUnlocked(p []byte) (int, error) {
483+
return r.exportedRead(p, false)
484+
}
485+
486+
func (r *messageReader) read(p []byte, lock bool) (int, error) {
487+
if lock {
488+
// If we cannot acquire the read lock, then
489+
// there is either a concurrent read or the close handshake
490+
// is proceeding.
491+
select {
492+
case r.c.readLock <- struct{}{}:
493+
defer r.c.releaseLock(r.c.readLock)
494+
default:
495+
if r.c.closing.Load() == 1 {
496+
<-r.c.closed
497+
return 0, r.c.closeErr
498+
}
499+
return 0, errors.New("concurrent read detected")
485500
}
486-
defer r.c.releaseLock(r.c.readLock)
487501
}
488502

489503
if r.eof() {
490-
return 0, fmt.Errorf("cannot use EOFed reader")
504+
return 0, errors.New("cannot use EOFed reader")
491505
}
492506

493507
if r.c.readMsgLeft <= 0 {
@@ -950,8 +964,6 @@ func (c *Conn) waitClose() error {
950964
return c.closeReceived
951965
}
952966

953-
c.readerShouldLock = false
954-
955967
b := bpool.Get()
956968
buf := b.Bytes()
957969
buf = buf[:cap(buf)]
@@ -965,7 +977,8 @@ func (c *Conn) waitClose() error {
965977
}
966978
}
967979

968-
_, err = io.CopyBuffer(ioutil.Discard, c.activeReader, buf)
980+
r := readerFunc(c.activeReader.readUnlocked)
981+
_, err = io.CopyBuffer(ioutil.Discard, r, buf)
969982
if err != nil {
970983
return err
971984
}
@@ -1019,6 +1032,12 @@ func (c *Conn) ping(ctx context.Context, p string) error {
10191032
}
10201033
}
10211034

1035+
type readerFunc func(p []byte) (int, error)
1036+
1037+
func (f readerFunc) Read(p []byte) (int, error) {
1038+
return f(p)
1039+
}
1040+
10221041
type writerFunc func(p []byte) (int, error)
10231042

10241043
func (f writerFunc) Write(p []byte) (int, error) {

0 commit comments

Comments
 (0)