Skip to content

Commit 052454e

Browse files
committed
Implement Conn
1 parent 9df3f43 commit 052454e

File tree

7 files changed

+386
-40
lines changed

7 files changed

+386
-40
lines changed

accept.go

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -116,17 +116,22 @@ func Accept(w http.ResponseWriter, r *http.Request, opts ...AcceptOption) (*Conn
116116

117117
w.WriteHeader(http.StatusSwitchingProtocols)
118118

119-
c, brw, err := hj.Hijack()
119+
netConn, brw, err := hj.Hijack()
120120
if err != nil {
121121
err = xerrors.Errorf("websocket: failed to hijack connection: %v", err)
122122
http.Error(w, err.Error(), http.StatusInternalServerError)
123123
return nil, err
124124
}
125125

126-
_ = c
127-
_ = brw
126+
c := &Conn{
127+
subprotocol: w.Header().Get("Sec-WebSocket-Protocol"),
128+
br: brw.Reader,
129+
bw: brw.Writer,
130+
closer: netConn,
131+
}
132+
c.init()
128133

129-
return nil, nil
134+
return c, nil
130135
}
131136

132137
func selectSubprotocol(w http.ResponseWriter, r *http.Request, subprotocols []string) {

dial.go

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package websocket
22

33
import (
4+
"bufio"
45
"bytes"
56
"context"
67
"encoding/base64"
@@ -132,5 +133,14 @@ func Dial(ctx context.Context, u string, opts ...DialOption) (_ *Conn, _ *http.R
132133
return nil, resp, xerrors.Errorf("websocket: body is not a read write closer but should be: %T", rwc)
133134
}
134135

135-
return nil, resp, nil
136+
c := &Conn{
137+
subprotocol: resp.Header.Get("Sec-WebSocket-Protocol"),
138+
br: bufio.NewReader(rwc),
139+
bw: bufio.NewWriter(rwc),
140+
closer: rwc,
141+
client: true,
142+
}
143+
c.init()
144+
145+
return c, resp, nil
136146
}

header.go

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package websocket
22

33
import (
44
"encoding/binary"
5+
"fmt"
56
"io"
67

78
"golang.org/x/xerrors"
@@ -16,7 +17,6 @@ const maxHeaderSize = 1 + 1 + 8 + 4
1617

1718
// header represents a WebSocket frame header.
1819
// See https://tools.ietf.org/html/rfc6455#section-5.2
19-
// The fields are exported for easy printing for debugging.
2020
type header struct {
2121
fin bool
2222
rsv1 bool
@@ -34,7 +34,7 @@ type header struct {
3434

3535
// bytes returns the bytes of the header.
3636
// See https://tools.ietf.org/html/rfc6455#section-5.2
37-
func marshalHeader(h header) ([]byte, error) {
37+
func marshalHeader(h header) []byte {
3838
b := make([]byte, 2, maxHeaderSize)
3939

4040
if h.fin {
@@ -54,7 +54,7 @@ func marshalHeader(h header) ([]byte, error) {
5454

5555
switch {
5656
case h.payloadLength < 0:
57-
return nil, xerrors.Errorf("invalid header: negative length: %v", h.payloadLength)
57+
panic(fmt.Sprintf("websocket: invalid header: negative length: %v", h.payloadLength))
5858
case h.payloadLength <= 125:
5959
b[1] = byte(h.payloadLength)
6060
case h.payloadLength <= 1<<16:
@@ -73,7 +73,7 @@ func marshalHeader(h header) ([]byte, error) {
7373
copy(b[len(b)-4:], h.maskKey[:])
7474
}
7575

76-
return b, nil
76+
return b
7777
}
7878

7979
// readHeader reads a header from the reader.
@@ -130,6 +130,9 @@ func readHeader(r io.Reader) (header, error) {
130130
b = b[2:]
131131
case payloadLength == 127:
132132
h.payloadLength = int64(binary.BigEndian.Uint64(b))
133+
if h.payloadLength < 0 {
134+
return header{}, xerrors.Errorf("websocket: header has negative payload length: %v", h.payloadLength)
135+
}
133136
b = b[8:]
134137
}
135138

header_test.go

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -38,11 +38,7 @@ func TestHeader(t *testing.T) {
3838

3939
t.Logf("header: %#v", h)
4040

41-
b, err := marshalHeader(h)
42-
if err != nil {
43-
t.Fatalf("failed to marshal header: %v", err)
44-
}
45-
41+
b := marshalHeader(h)
4642
t.Logf("bytes: %b", b)
4743

4844
r := bytes.NewReader(b)

opcode.go

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,3 +15,11 @@ const (
1515
opPong
1616
// 11-16 are reserved for further control frames.
1717
)
18+
19+
func (o opcode) controlOp() bool {
20+
switch o {
21+
case opClose, opPing, opPong:
22+
return true
23+
}
24+
return false
25+
}

statuscode.go

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@ import (
55
"errors"
66
"fmt"
77
"math/bits"
8+
9+
"golang.org/x/xerrors"
810
)
911

1012
// StatusCode represents a WebSocket status code.
@@ -43,18 +45,23 @@ func (e CloseError) Error() string {
4345
return fmt.Sprintf("WebSocket closed with status = %v and reason = %q", e.Code, e.Reason)
4446
}
4547

46-
func parseClosePayload(p []byte) (code StatusCode, reason []byte, err error) {
48+
func parseClosePayload(p []byte) (code StatusCode, reason string, err error) {
4749
if len(p) < 2 {
48-
return 0, nil, fmt.Errorf("close payload too small, cannot even contain the 2 byte status code")
50+
return 0, "", fmt.Errorf("close payload too small, cannot even contain the 2 byte status code")
4951
}
5052

5153
code = StatusCode(binary.BigEndian.Uint16(p))
52-
reason = p[2:]
54+
reason = string(p[2:])
5355

5456
return code, reason, nil
5557
}
5658

57-
func closePayload(code StatusCode, reason []byte) ([]byte, error) {
59+
const maxControlFramePayload = 125
60+
61+
func closePayload(code StatusCode, reason string) ([]byte, error) {
62+
if len(reason) > maxControlFramePayload-2 {
63+
return nil, xerrors.Errorf("reason string max is %v but got %q with length %v", maxControlFramePayload-2, reason, len(reason))
64+
}
5865
if bits.Len(uint(code)) > 16 {
5966
return nil, errors.New("status code is larger than 2 bytes")
6067
}

0 commit comments

Comments
 (0)