Skip to content

Commit 9df3f43

Browse files
committed
Add header implementation with tests
1 parent eebb6d8 commit 9df3f43

File tree

2 files changed

+182
-14
lines changed

2 files changed

+182
-14
lines changed

header.go

Lines changed: 124 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,31 +1,141 @@
11
package websocket
22

33
import (
4+
"encoding/binary"
45
"io"
6+
7+
"golang.org/x/xerrors"
58
)
69

10+
// First byte contains fin, rsv1, rsv2, rsv3.
11+
// Second byte contains mask flag and payload len.
12+
// Next 8 bytes are the maximum extended payload length.
13+
// Last 4 bytes are the mask key.
14+
// https://tools.ietf.org/html/rfc6455#section-5.2
15+
const maxHeaderSize = 1 + 1 + 8 + 4
16+
717
// header represents a WebSocket frame header.
818
// See https://tools.ietf.org/html/rfc6455#section-5.2
919
// The fields are exported for easy printing for debugging.
1020
type header struct {
11-
Fin bool
12-
Rsv1 bool
13-
Rsv2 bool
14-
Rsv3 bool
15-
Opcode opcode
21+
fin bool
22+
rsv1 bool
23+
rsv2 bool
24+
rsv3 bool
25+
opcode opcode
1626

17-
PayloadLength int64
27+
payloadLength int64
1828

19-
Masked bool
20-
MaskKey [4]byte
29+
masked bool
30+
maskKey [4]byte
2131
}
2232

23-
// Bytes returns the bytes of the header.
24-
func (h header) Bytes() []byte {
25-
panic("TODO")
33+
// TODO bitwise helpers
34+
35+
// bytes returns the bytes of the header.
36+
// See https://tools.ietf.org/html/rfc6455#section-5.2
37+
func marshalHeader(h header) ([]byte, error) {
38+
b := make([]byte, 2, maxHeaderSize)
39+
40+
if h.fin {
41+
b[0] |= 1 << 7
42+
}
43+
if h.rsv1 {
44+
b[0] |= 1 << 6
45+
}
46+
if h.rsv2 {
47+
b[0] |= 1 << 5
48+
}
49+
if h.rsv3 {
50+
b[0] |= 1 << 4
51+
}
52+
53+
b[0] |= byte(h.opcode)
54+
55+
switch {
56+
case h.payloadLength < 0:
57+
return nil, xerrors.Errorf("invalid header: negative length: %v", h.payloadLength)
58+
case h.payloadLength <= 125:
59+
b[1] = byte(h.payloadLength)
60+
case h.payloadLength <= 1<<16:
61+
b[1] = 126
62+
b = b[:len(b)+2]
63+
binary.BigEndian.PutUint16(b[len(b)-2:], uint16(h.payloadLength))
64+
default:
65+
b[1] = 127
66+
b = b[:len(b)+8]
67+
binary.BigEndian.PutUint64(b[len(b)-8:], uint64(h.payloadLength))
68+
}
69+
70+
if h.masked {
71+
b[1] |= 1 << 7
72+
b = b[:len(b)+4]
73+
copy(b[len(b)-4:], h.maskKey[:])
74+
}
75+
76+
return b, nil
2677
}
2778

28-
// ReadHeader reads a header from the reader.
29-
func ReadHeader(r io.Reader) []byte {
30-
panic("TODO")
79+
// readHeader reads a header from the reader.
80+
// See https://tools.ietf.org/html/rfc6455#section-5.2
81+
func readHeader(r io.Reader) (header, error) {
82+
// We read the first two bytes directly so that we know
83+
// exactly how long the header is.
84+
b := make([]byte, 2, maxHeaderSize-2)
85+
_, err := io.ReadFull(r, b)
86+
if err != nil {
87+
return header{}, err
88+
}
89+
90+
var h header
91+
h.fin = b[0]&(1<<7) != 0
92+
h.rsv1 = b[0]&(1<<6) != 0
93+
h.rsv2 = b[0]&(1<<5) != 0
94+
h.rsv3 = b[0]&(1<<4) != 0
95+
96+
h.opcode = opcode(b[0] & 0xf)
97+
98+
var extra int
99+
100+
h.masked = b[1]&(1<<7) != 0
101+
if h.masked {
102+
extra += 4
103+
}
104+
105+
payloadLength := b[1] &^ (1 << 7)
106+
switch {
107+
case payloadLength < 126:
108+
h.payloadLength = int64(payloadLength)
109+
case payloadLength == 126:
110+
h.payloadLength = 126
111+
extra += 2
112+
case payloadLength == 127:
113+
h.payloadLength = 127
114+
extra += 8
115+
}
116+
117+
if extra == 0 {
118+
return h, nil
119+
}
120+
121+
b = b[:extra]
122+
_, err = io.ReadFull(r, b)
123+
if err != nil {
124+
return header{}, err
125+
}
126+
127+
switch {
128+
case payloadLength == 126:
129+
h.payloadLength = int64(binary.BigEndian.Uint16(b))
130+
b = b[2:]
131+
case payloadLength == 127:
132+
h.payloadLength = int64(binary.BigEndian.Uint64(b))
133+
b = b[8:]
134+
}
135+
136+
if h.masked {
137+
copy(h.maskKey[:], b)
138+
}
139+
140+
return h, nil
31141
}

header_test.go

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
package websocket
2+
3+
import (
4+
"bytes"
5+
"math/rand"
6+
"testing"
7+
"time"
8+
9+
"github.com/google/go-cmp/cmp"
10+
)
11+
12+
func init() {
13+
rand.Seed(time.Now().UnixNano())
14+
}
15+
16+
func randBool() bool {
17+
return rand.Intn(1) == 0
18+
}
19+
20+
func TestHeader(t *testing.T) {
21+
t.Parallel()
22+
23+
for i := 0; i < 1000; i++ {
24+
h := header{
25+
fin: randBool(),
26+
rsv1: randBool(),
27+
rsv2: randBool(),
28+
rsv3: randBool(),
29+
opcode: opcode(rand.Intn(1 << 4)),
30+
31+
masked: randBool(),
32+
payloadLength: rand.Int63(),
33+
}
34+
35+
if h.masked {
36+
rand.Read(h.maskKey[:])
37+
}
38+
39+
t.Logf("header: %#v", h)
40+
41+
b, err := marshalHeader(h)
42+
if err != nil {
43+
t.Fatalf("failed to marshal header: %v", err)
44+
}
45+
46+
t.Logf("bytes: %b", b)
47+
48+
r := bytes.NewReader(b)
49+
h2, err := readHeader(r)
50+
if err != nil {
51+
t.Fatalf("failed to read header: %v", err)
52+
}
53+
54+
if !cmp.Equal(h, h2, cmp.AllowUnexported(header{})) {
55+
t.Fatalf("parsed and read header differ: %v", cmp.Diff(h, h2, cmp.AllowUnexported(header{})))
56+
}
57+
}
58+
}

0 commit comments

Comments
 (0)