@@ -383,7 +383,11 @@ func TestHandshake(t *testing.T) {
383
383
}
384
384
defer c .Close (websocket .StatusInternalError , "" )
385
385
386
- go c .Reader (r .Context ())
386
+ errc := make (chan error , 1 )
387
+ go func () {
388
+ _ , _ , err2 := c .Read (r .Context ())
389
+ errc <- err2
390
+ }()
387
391
388
392
err = c .Ping (r .Context ())
389
393
if err != nil {
@@ -395,8 +399,12 @@ func TestHandshake(t *testing.T) {
395
399
return err
396
400
}
397
401
398
- c .Close (websocket .StatusNormalClosure , "" )
399
- return nil
402
+ err = <- errc
403
+ var ce websocket.CloseError
404
+ if xerrors .As (err , & ce ) && ce .Code == websocket .StatusNormalClosure {
405
+ return nil
406
+ }
407
+ return xerrors .Errorf ("unexpected error: %w" , err )
400
408
},
401
409
client : func (ctx context.Context , u string ) error {
402
410
c , _ , err := websocket .Dial (ctx , u , websocket.DialOptions {})
@@ -405,19 +413,30 @@ func TestHandshake(t *testing.T) {
405
413
}
406
414
defer c .Close (websocket .StatusInternalError , "" )
407
415
408
- errc := make (chan error , 1 )
416
+ // We read a message from the connection and then keep reading until
417
+ // the Ping completes.
418
+ done := make (chan struct {})
409
419
go func () {
410
- errc <- c .Ping (ctx )
420
+ _ , _ , err := c .Read (ctx )
421
+ if err != nil {
422
+ c .Close (websocket .StatusInternalError , err .Error ())
423
+ return
424
+ }
425
+
426
+ close (done )
427
+
428
+ c .Read (ctx )
411
429
}()
412
430
413
- _ , _ , err = c .Read (ctx )
431
+ err = c .Ping (ctx )
414
432
if err != nil {
415
433
return err
416
434
}
417
435
418
- err = <- errc
436
+ <- done
437
+
419
438
c .Close (websocket .StatusNormalClosure , "" )
420
- return err
439
+ return nil
421
440
},
422
441
},
423
442
{
0 commit comments