@@ -78,11 +78,10 @@ type Conn struct {
78
78
readLock chan struct {}
79
79
80
80
// messageReader state.
81
- readerMsgCtx context.Context
82
- readerMsgHeader header
83
- readerFrameEOF bool
84
- readerMaskPos int
85
- readerShouldLock bool
81
+ readerMsgCtx context.Context
82
+ readerMsgHeader header
83
+ readerFrameEOF bool
84
+ readerMaskPos int
86
85
87
86
setReadTimeout chan context.Context
88
87
setWriteTimeout chan context.Context
@@ -445,7 +444,6 @@ func (c *Conn) reader(ctx context.Context, lock bool) (MessageType, io.Reader, e
445
444
c .readerFrameEOF = false
446
445
c .readerMaskPos = 0
447
446
c .readMsgLeft = c .msgReadLimit .Load ()
448
- c .readerShouldLock = lock
449
447
450
448
r := & messageReader {
451
449
c : c ,
@@ -465,7 +463,11 @@ func (r *messageReader) eof() bool {
465
463
466
464
// Read reads as many bytes as possible into p.
467
465
func (r * messageReader ) Read (p []byte ) (int , error ) {
468
- n , err := r .read (p )
466
+ return r .exportedRead (p , true )
467
+ }
468
+
469
+ func (r * messageReader ) exportedRead (p []byte , lock bool ) (int , error ) {
470
+ n , err := r .read (p , lock )
469
471
if err != nil {
470
472
// Have to return io.EOF directly for now, we cannot wrap as errors.Is
471
473
// isn't used widely yet.
@@ -477,17 +479,29 @@ func (r *messageReader) Read(p []byte) (int, error) {
477
479
return n , nil
478
480
}
479
481
480
- func (r * messageReader ) read (p []byte ) (int , error ) {
481
- if r .c .readerShouldLock {
482
- err := r .c .acquireLock (r .c .readerMsgCtx , r .c .readLock )
483
- if err != nil {
484
- return 0 , err
482
+ func (r * messageReader ) readUnlocked (p []byte ) (int , error ) {
483
+ return r .exportedRead (p , false )
484
+ }
485
+
486
+ func (r * messageReader ) read (p []byte , lock bool ) (int , error ) {
487
+ if lock {
488
+ // If we cannot acquire the read lock, then
489
+ // there is either a concurrent read or the close handshake
490
+ // is proceeding.
491
+ select {
492
+ case r .c .readLock <- struct {}{}:
493
+ defer r .c .releaseLock (r .c .readLock )
494
+ default :
495
+ if r .c .closing .Load () == 1 {
496
+ <- r .c .closed
497
+ return 0 , r .c .closeErr
498
+ }
499
+ return 0 , errors .New ("concurrent read detected" )
485
500
}
486
- defer r .c .releaseLock (r .c .readLock )
487
501
}
488
502
489
503
if r .eof () {
490
- return 0 , fmt . Errorf ("cannot use EOFed reader" )
504
+ return 0 , errors . New ("cannot use EOFed reader" )
491
505
}
492
506
493
507
if r .c .readMsgLeft <= 0 {
@@ -950,8 +964,6 @@ func (c *Conn) waitClose() error {
950
964
return c .closeReceived
951
965
}
952
966
953
- c .readerShouldLock = false
954
-
955
967
b := bpool .Get ()
956
968
buf := b .Bytes ()
957
969
buf = buf [:cap (buf )]
@@ -965,7 +977,8 @@ func (c *Conn) waitClose() error {
965
977
}
966
978
}
967
979
968
- _ , err = io .CopyBuffer (ioutil .Discard , c .activeReader , buf )
980
+ r := readerFunc (c .activeReader .readUnlocked )
981
+ _ , err = io .CopyBuffer (ioutil .Discard , r , buf )
969
982
if err != nil {
970
983
return err
971
984
}
@@ -1019,6 +1032,12 @@ func (c *Conn) ping(ctx context.Context, p string) error {
1019
1032
}
1020
1033
}
1021
1034
1035
+ type readerFunc func (p []byte ) (int , error )
1036
+
1037
+ func (f readerFunc ) Read (p []byte ) (int , error ) {
1038
+ return f (p )
1039
+ }
1040
+
1022
1041
type writerFunc func (p []byte ) (int , error )
1023
1042
1024
1043
func (f writerFunc ) Write (p []byte ) (int , error ) {
0 commit comments