Skip to content

Commit 679ddb8

Browse files
committed
Drastically improve non autobahn test coverage
Also simplified and refactored the Conn tests. More changes soon.
1 parent a3a891b commit 679ddb8

File tree

10 files changed

+761
-306
lines changed

10 files changed

+761
-306
lines changed

accept_test.go

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,39 @@ import (
66
"testing"
77
)
88

9+
func TestAccept(t *testing.T) {
10+
t.Parallel()
11+
12+
t.Run("badClientHandshake", func(t *testing.T) {
13+
t.Parallel()
14+
15+
w := httptest.NewRecorder()
16+
r := httptest.NewRequest("GET", "/", nil)
17+
18+
_, err := Accept(w, r, AcceptOptions{})
19+
if err == nil {
20+
t.Fatalf("unexpected error value: %v", err)
21+
}
22+
23+
})
24+
25+
t.Run("requireHttpHijacker", func(t *testing.T) {
26+
t.Parallel()
27+
28+
w := httptest.NewRecorder()
29+
r := httptest.NewRequest("GET", "/", nil)
30+
r.Header.Set("Connection", "Upgrade")
31+
r.Header.Set("Upgrade", "websocket")
32+
r.Header.Set("Sec-WebSocket-Version", "13")
33+
r.Header.Set("Sec-WebSocket-Key", "meow123")
34+
35+
_, err := Accept(w, r, AcceptOptions{})
36+
if err == nil || !strings.Contains(err.Error(), "http.Hijacker") {
37+
t.Fatalf("unexpected error value: %v", err)
38+
}
39+
})
40+
}
41+
942
func Test_verifyClientHandshake(t *testing.T) {
1043
t.Parallel()
1144

ci/test.sh

Lines changed: 23 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4,19 +4,34 @@ set -euo pipefail
44
cd "$(dirname "${0}")"
55
cd "$(git rev-parse --show-toplevel)"
66

7-
mkdir -p ci/out/websocket
8-
testFlags=(
7+
argv=(
8+
go run gotest.tools/gotestsum
9+
# https://circleci.com/docs/2.0/collect-test-data/
10+
"--junitfile=ci/out/websocket/testReport.xml"
11+
"--format=short-verbose"
12+
--
913
-race
1014
"-vet=off"
11-
# "-bench=."
15+
"-bench=."
16+
)
17+
# Interactive usage probably does not want to enable benchmarks, race detection
18+
# turn off vet or use gotestsum by default.
19+
if [[ $# -gt 0 ]]; then
20+
argv=(go test "$@")
21+
fi
22+
23+
# We always want coverage.
24+
argv+=(
1225
"-coverprofile=ci/out/coverage.prof"
1326
"-coverpkg=./..."
1427
)
15-
# https://circleci.com/docs/2.0/collect-test-data/
16-
go run gotest.tools/gotestsum \
17-
--junitfile ci/out/websocket/testReport.xml \
18-
--format=short-verbose \
19-
-- "${testFlags[@]}"
28+
29+
mkdir -p ci/out/websocket
30+
"${argv[@]}"
31+
32+
# Removes coverage of generated files.
33+
grep -v _string.go < ci/out/coverage.prof > ci/out/coverage2.prof
34+
mv ci/out/coverage2.prof ci/out/coverage.prof
2035

2136
go tool cover -html=ci/out/coverage.prof -o=ci/out/coverage.html
2237
if [[ ${CI:-} ]]; then

dial_test.go

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,14 +33,21 @@ func TestBadDials(t *testing.T) {
3333
},
3434
},
3535
},
36+
{
37+
name: "badTLS",
38+
url: "wss://totallyfake.nhooyr.io",
39+
},
3640
}
3741

3842
for _, tc := range testCases {
3943
tc := tc
4044
t.Run(tc.name, func(t *testing.T) {
4145
t.Parallel()
4246

43-
_, _, err := Dial(context.Background(), tc.url, tc.opts)
47+
ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
48+
defer cancel()
49+
50+
_, _, err := Dial(ctx, tc.url, tc.opts)
4451
if err == nil {
4552
t.Fatalf("expected non nil error: %+v", err)
4653
}

export_test.go

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,13 @@
11
package websocket
22

3-
var Compute = handleSecWebSocketKey
3+
import (
4+
"context"
5+
)
6+
7+
type Addr = websocketAddr
8+
9+
type Header = header
10+
11+
func (c *Conn) WriteFrame(ctx context.Context, fin bool, opcode opcode, p []byte) (int, error) {
12+
return c.writeFrame(ctx, fin, opcode, p)
13+
}

header_test.go

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package websocket
22

33
import (
44
"bytes"
5+
"io"
56
"math/rand"
67
"strconv"
78
"testing"
@@ -21,6 +22,36 @@ func randBool() bool {
2122
func TestHeader(t *testing.T) {
2223
t.Parallel()
2324

25+
t.Run("eof", func(t *testing.T) {
26+
t.Parallel()
27+
28+
testCases := []struct {
29+
name string
30+
bytes []byte
31+
}{
32+
{
33+
"start",
34+
[]byte{0xff},
35+
},
36+
{
37+
"middle",
38+
[]byte{0xff, 0xff, 0xff},
39+
},
40+
}
41+
for _, tc := range testCases {
42+
tc := tc
43+
t.Run(tc.name, func(t *testing.T) {
44+
t.Parallel()
45+
46+
b := bytes.NewBuffer(tc.bytes)
47+
_, err := readHeader(nil, b)
48+
if io.ErrUnexpectedEOF != err {
49+
t.Fatalf("expected %v but got: %v", io.ErrUnexpectedEOF, err)
50+
}
51+
})
52+
}
53+
})
54+
2455
t.Run("writeNegativeLength", func(t *testing.T) {
2556
t.Parallel()
2657

netconn.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -101,8 +101,8 @@ func (c *netConn) Read(p []byte) (int, error) {
101101
return 0, err
102102
}
103103
if typ != c.msgType {
104-
c.c.Close(StatusUnsupportedData, fmt.Sprintf("can only accept %v messages", c.msgType))
105-
return 0, xerrors.Errorf("unexpected frame type read for net conn adapter (expected %v): %v", c.msgType, typ)
104+
c.c.Close(StatusUnsupportedData, fmt.Sprintf("unexpected frame type read (expected %v): %v", c.msgType, typ))
105+
return 0, c.c.closeErr
106106
}
107107
c.reader = r
108108
}

statuscode.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ const (
3535
StatusTryAgainLater
3636
StatusBadGateway
3737
// statusTLSHandshake is unexported because we just return
38-
// handshake error in dial. We do not return a conn
38+
// the handshake error in dial. We do not return a conn
3939
// so there is nothing to use this on. At least until WASM.
4040
statusTLSHandshake
4141
)

statuscode_test.go

Lines changed: 104 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,14 +4,13 @@ import (
44
"math"
55
"strings"
66
"testing"
7+
8+
"github.com/google/go-cmp/cmp"
79
)
810

911
func TestCloseError(t *testing.T) {
1012
t.Parallel()
1113

12-
// Other parts of close error are tested by websocket_test.go right now
13-
// with the autobahn tests.
14-
1514
testCases := []struct {
1615
name string
1716
ce CloseError
@@ -50,7 +49,108 @@ func TestCloseError(t *testing.T) {
5049

5150
_, err := tc.ce.bytes()
5251
if (err == nil) != tc.success {
53-
t.Fatalf("unexpected error value: %v", err)
52+
t.Fatalf("unexpected error value: %+v", err)
53+
}
54+
})
55+
}
56+
}
57+
58+
func Test_parseClosePayload(t *testing.T) {
59+
t.Parallel()
60+
61+
testCases := []struct {
62+
name string
63+
p []byte
64+
success bool
65+
ce CloseError
66+
}{
67+
{
68+
name: "normal",
69+
p: append([]byte{0x3, 0xE8}, []byte("hello")...),
70+
success: true,
71+
ce: CloseError{
72+
Code: StatusNormalClosure,
73+
Reason: "hello",
74+
},
75+
},
76+
{
77+
name: "nothing",
78+
success: true,
79+
ce: CloseError{
80+
Code: StatusNoStatusRcvd,
81+
},
82+
},
83+
{
84+
name: "oneByte",
85+
p: []byte{0},
86+
success: false,
87+
},
88+
{
89+
name: "badStatusCode",
90+
p: []byte{0x17, 0x70},
91+
success: false,
92+
},
93+
}
94+
95+
for _, tc := range testCases {
96+
tc := tc
97+
t.Run(tc.name, func(t *testing.T) {
98+
t.Parallel()
99+
100+
ce, err := parseClosePayload(tc.p)
101+
if (err == nil) != tc.success {
102+
t.Fatalf("unexpected expected error value: %+v", err)
103+
}
104+
105+
if tc.success && tc.ce != ce {
106+
t.Fatalf("unexpected close error: %v", cmp.Diff(tc.ce, ce))
107+
}
108+
})
109+
}
110+
}
111+
112+
func Test_validWireCloseCode(t *testing.T) {
113+
t.Parallel()
114+
115+
testCases := []struct {
116+
name string
117+
code StatusCode
118+
valid bool
119+
}{
120+
{
121+
name: "normal",
122+
code: StatusNormalClosure,
123+
valid: true,
124+
},
125+
{
126+
name: "noStatus",
127+
code: StatusNoStatusRcvd,
128+
valid: false,
129+
},
130+
{
131+
name: "3000",
132+
code: 3000,
133+
valid: true,
134+
},
135+
{
136+
name: "4999",
137+
code: 4999,
138+
valid: true,
139+
},
140+
{
141+
name: "unknown",
142+
code: 5000,
143+
valid: false,
144+
},
145+
}
146+
147+
for _, tc := range testCases {
148+
tc := tc
149+
t.Run(tc.name, func(t *testing.T) {
150+
t.Parallel()
151+
152+
if valid := validWireCloseCode(tc.code); tc.valid != valid {
153+
t.Fatalf("expected %v for %v but got %v", tc.valid, tc.code, valid)
54154
}
55155
})
56156
}

0 commit comments

Comments
 (0)