Skip to content

Commit d93d9c7

Browse files
committed
Implement handshake
1 parent 5050331 commit d93d9c7

File tree

12 files changed

+306
-23
lines changed

12 files changed

+306
-23
lines changed

.github/fmt/entrypoint.sh

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@ gen() {
88
go list ./... > /dev/null
99
go mod tidy
1010

11-
go install golang.org/x/tools/cmd/stringer
1211
go generate ./...
1312
}
1413

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ go get nhooyr.io/websocket@master
1717
## Features
1818

1919
- Full support of the WebSocket protocol
20+
- Only depends on the stdlib
2021
- Simple to use because of the minimal API
2122
- Uses the context package for cancellation
2223
- Uses net/http's Client to do WebSocket dials

accept.go

Lines changed: 138 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,33 +1,48 @@
11
package websocket
22

33
import (
4-
"fmt"
4+
"crypto/sha1"
5+
"encoding/base64"
56
"net/http"
7+
"net/url"
8+
"strings"
9+
10+
"golang.org/x/net/http/httpguts"
11+
"golang.org/x/xerrors"
612
)
713

814
// AcceptOption is an option that can be passed to Accept.
15+
// The implementations of this interface are printable.
916
type AcceptOption interface {
1017
acceptOption()
11-
fmt.Stringer
1218
}
1319

20+
type acceptSubprotocols []string
21+
22+
func (o acceptSubprotocols) acceptOption() {}
23+
1424
// AcceptSubprotocols list the subprotocols that Accept will negotiate with a client.
1525
// The first protocol that a client supports will be negotiated.
16-
// Pass "" as a subprotocol if you would like to allow the default protocol.
26+
// Pass "" as a subprotocol if you would like to allow the default protocol along with
27+
// specific subprotocols.
1728
func AcceptSubprotocols(subprotocols ...string) AcceptOption {
18-
panic("TODO")
29+
return acceptSubprotocols(subprotocols)
1930
}
2031

32+
type acceptOrigins []string
33+
34+
func (o acceptOrigins) acceptOption() {}
35+
2136
// AcceptOrigins lists the origins that Accept will accept.
2237
// Accept will always accept r.Host as the origin so you do not need to
2338
// specify that with this option.
2439
//
2540
// Use this option with caution to avoid exposing your WebSocket
2641
// server to a CSRF attack.
2742
// See https://stackoverflow.com/a/37837709/4283659
28-
// You can use a * to specify wildcards.
43+
// You can use a * for wildcards.
2944
func AcceptOrigins(origins ...string) AcceptOption {
30-
panic("TODO")
45+
return AcceptOrigins(origins...)
3146
}
3247

3348
// Accept accepts a WebSocket handshake from a client and upgrades the
@@ -36,5 +51,121 @@ func AcceptOrigins(origins ...string) AcceptOption {
3651
// InsecureAcceptOrigin is passed.
3752
// Accept uses w to write the handshake response so the timeouts on the http.Server apply.
3853
func Accept(w http.ResponseWriter, r *http.Request, opts ...AcceptOption) (*Conn, error) {
39-
panic("TODO")
54+
var subprotocols []string
55+
origins := []string{r.Host}
56+
for _, opt := range opts {
57+
switch opt := opt.(type) {
58+
case acceptOrigins:
59+
origins = []string(opt)
60+
case acceptSubprotocols:
61+
subprotocols = []string(opt)
62+
}
63+
}
64+
65+
if !httpguts.HeaderValuesContainsToken(r.Header["Connection"], "Upgrade") {
66+
err := xerrors.Errorf("websocket: protocol violation: Connection header does not contain Upgrade: %q", r.Header.Get("Connection"))
67+
http.Error(w, err.Error(), http.StatusBadRequest)
68+
return nil, err
69+
}
70+
71+
if !httpguts.HeaderValuesContainsToken(r.Header["Upgrade"], "websocket") {
72+
err := xerrors.Errorf("websocket: protocol violation: Upgrade header does not contain websocket: %q", r.Header.Get("Upgrade"))
73+
http.Error(w, err.Error(), http.StatusBadRequest)
74+
return nil, err
75+
}
76+
77+
if r.Method != "GET" {
78+
err := xerrors.Errorf("websocket: protocol violation: handshake request method is not GET: %q", r.Method)
79+
http.Error(w, err.Error(), http.StatusBadRequest)
80+
return nil, err
81+
}
82+
83+
if r.Header.Get("Sec-WebSocket-Version") != "13" {
84+
err := xerrors.Errorf("websocket: unsupported protocol version: %q", r.Header.Get("Sec-WebSocket-Version"))
85+
http.Error(w, err.Error(), http.StatusBadRequest)
86+
return nil, err
87+
}
88+
89+
if r.Header.Get("Sec-WebSocket-Key") == "" {
90+
err := xerrors.New("websocket: protocol violation: missing Sec-WebSocket-Key")
91+
http.Error(w, err.Error(), http.StatusBadRequest)
92+
return nil, err
93+
}
94+
95+
origins = append(origins, r.Host)
96+
97+
err := authenticateOrigin(r, origins)
98+
if err != nil {
99+
http.Error(w, err.Error(), http.StatusForbidden)
100+
return nil, err
101+
}
102+
103+
hj, ok := w.(http.Hijacker)
104+
if !ok {
105+
err = xerrors.New("websocket: response writer does not implement http.Hijacker")
106+
http.Error(w, err.Error(), http.StatusInternalServerError)
107+
return nil, err
108+
}
109+
110+
w.Header().Set("Upgrade", "websocket")
111+
w.Header().Set("Connection", "Upgrade")
112+
113+
handleKey(w, r)
114+
115+
selectSubprotocol(w, r, subprotocols)
116+
117+
w.WriteHeader(http.StatusSwitchingProtocols)
118+
119+
c, brw, err := hj.Hijack()
120+
if err != nil {
121+
err = xerrors.Errorf("websocket: failed to hijack connection: %v", err)
122+
http.Error(w, err.Error(), http.StatusInternalServerError)
123+
return nil, err
124+
}
125+
126+
_ = c
127+
_ = brw
128+
129+
return nil, nil
130+
}
131+
132+
func selectSubprotocol(w http.ResponseWriter, r *http.Request, subprotocols []string) {
133+
clientSubprotocols := strings.Split(r.Header.Get("Sec-WebSocket-Protocol"), "\n")
134+
for _, sp := range subprotocols {
135+
for _, cp := range clientSubprotocols {
136+
if sp == strings.TrimSpace(cp) {
137+
w.Header().Set("Sec-WebSocket-Protocol", sp)
138+
return
139+
}
140+
}
141+
}
142+
}
143+
144+
var keyGUID = []byte("258EAFA5-E914-47DA-95CA-C5AB0DC85B11")
145+
146+
func handleKey(w http.ResponseWriter, r *http.Request) {
147+
key := r.Header.Get("Sec-WebSocket-Key")
148+
h := sha1.New()
149+
h.Write([]byte(key))
150+
h.Write(keyGUID)
151+
152+
responseKey := base64.StdEncoding.EncodeToString(h.Sum(nil))
153+
w.Header().Set("Sec-WebSocket-Accept", responseKey)
154+
}
155+
156+
func authenticateOrigin(r *http.Request, origins []string) error {
157+
origin := r.Header.Get("Origin")
158+
if origin == "" {
159+
return nil
160+
}
161+
u, err := url.Parse(origin)
162+
if err != nil {
163+
return xerrors.Errorf("failed to parse Origin header %q: %v", origin, err)
164+
}
165+
for _, o := range origins {
166+
if u.Host == o {
167+
return nil
168+
}
169+
}
170+
return xerrors.New("request origin is not authorized")
40171
}

datatype.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
package websocket
22

33
// DataType represents the Opcode of a WebSocket data frame.
4-
//go:generate stringer -type=DataType
4+
//go:generate go run golang.org/x/tools/cmd/stringer -type=DataType
55
type DataType int
66

77
// DataType constants.

dial.go

Lines changed: 102 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,34 +1,53 @@
11
package websocket
22

33
import (
4+
"bytes"
45
"context"
56
"encoding/base64"
6-
"fmt"
7+
"io"
8+
"io/ioutil"
79
"net/http"
10+
"net/url"
11+
"strings"
12+
13+
"golang.org/x/net/http/httpguts"
14+
"golang.org/x/xerrors"
815
)
916

1017
// DialOption represents a dial option that can be passed to Dial.
18+
// The implementations are printable for easy debugging.
1119
type DialOption interface {
1220
dialOption()
13-
fmt.Stringer
1421
}
1522

23+
type dialHTTPClient http.Client
24+
25+
func (o dialHTTPClient) dialOption() {}
26+
1627
// DialHTTPClient is the http client used for the handshake.
1728
// Its Transport must use HTTP/1.1 and must return writable bodies
1829
// for WebSocket handshakes.
1930
// http.Transport does this correctly.
20-
func DialHTTPClient(h *http.Client) DialOption {
21-
panic("TODO")
31+
func DialHTTPClient(hc *http.Client) DialOption {
32+
return (*dialHTTPClient)(hc)
2233
}
2334

35+
type dialHeader http.Header
36+
37+
func (o dialHeader) dialOption() {}
38+
2439
// DialHeader are the HTTP headers included in the handshake request.
2540
func DialHeader(h http.Header) DialOption {
26-
panic("TODO")
41+
return dialHeader(h)
2742
}
2843

44+
type dialSubprotocols []string
45+
46+
func (o dialSubprotocols) dialOption() {}
47+
2948
// DialSubprotocols accepts a slice of protcols to include in the Sec-WebSocket-Protocol header.
3049
func DialSubprotocols(subprotocols ...string) DialOption {
31-
panic("TODO")
50+
return dialSubprotocols(subprotocols)
3251
}
3352

3453
// We use this key for all client requests as the Sec-WebSocket-Key header is useless.
@@ -37,6 +56,81 @@ func DialSubprotocols(subprotocols ...string) DialOption {
3756
var secWebSocketKey = base64.StdEncoding.EncodeToString(make([]byte, 16))
3857

3958
// Dial performs a WebSocket handshake on the given url with the given options.
40-
func Dial(ctx context.Context, u string, opts ...DialOption) (*Conn, *http.Response, error) {
41-
panic("TODO")
59+
func Dial(ctx context.Context, u string, opts ...DialOption) (_ *Conn, _ *http.Response, err error) {
60+
httpClient := http.DefaultClient
61+
var subprotocols []string
62+
header := http.Header{}
63+
for _, o := range opts {
64+
switch o := o.(type) {
65+
case dialSubprotocols:
66+
subprotocols = o
67+
case dialHeader:
68+
header = http.Header(o)
69+
case *dialHTTPClient:
70+
httpClient = (*http.Client)(o)
71+
}
72+
}
73+
74+
parsedURL, err := url.Parse(u)
75+
if err != nil {
76+
return nil, nil, xerrors.Errorf("failed to parse websocket url: %v", err)
77+
}
78+
79+
switch parsedURL.Scheme {
80+
case "ws", "http":
81+
parsedURL.Scheme = "http"
82+
case "wss", "https":
83+
parsedURL.Scheme = "https"
84+
default:
85+
return nil, nil, xerrors.Errorf("unknown scheme in url: %q", parsedURL.Scheme)
86+
}
87+
88+
req, _ := http.NewRequest("GET", u, nil)
89+
req = req.WithContext(ctx)
90+
req.Header = header
91+
req.Header.Set("Connection", "Upgrade")
92+
req.Header.Set("Upgrade", "websocket")
93+
req.Header.Set("Sec-WebSocket-Version", "13")
94+
req.Header.Set("Sec-WebSocket-Key", secWebSocketKey)
95+
if len(subprotocols) > 0 {
96+
req.Header.Set("Sec-WebSocket-Protocol", strings.Join(subprotocols, ","))
97+
}
98+
99+
resp, err := httpClient.Do(req)
100+
if err != nil {
101+
return nil, nil, err
102+
}
103+
defer func() {
104+
respBody := resp.Body
105+
if err != nil {
106+
// We read a bit of the body for better debugging.
107+
r := io.LimitReader(resp.Body, 1024)
108+
b, _ := ioutil.ReadAll(r)
109+
resp.Body = ioutil.NopCloser(bytes.NewReader(b))
110+
}
111+
respBody.Close()
112+
}()
113+
114+
if resp.StatusCode != http.StatusSwitchingProtocols {
115+
return nil, resp, xerrors.Errorf("websocket: expected status code %v but got %v", http.StatusSwitchingProtocols)
116+
}
117+
118+
if !httpguts.HeaderValuesContainsToken(resp.Header["Connection"], "Upgrade") {
119+
return nil, resp, xerrors.Errorf("websocket: protocol violation: Connection header does not contain Upgrade: %q", resp.Header.Get("Connection"))
120+
}
121+
122+
if !httpguts.HeaderValuesContainsToken(resp.Header["Upgrade"], "websocket") {
123+
return nil, resp, xerrors.Errorf("websocket: protocol violation: Upgrade header does not contain websocket: %q", resp.Header.Get("Upgrade"))
124+
125+
}
126+
127+
// We do not care about Sec-WebSocket-Accept because it does not matter.
128+
// See the secWebSocketKey global variable.
129+
130+
rwc, ok := resp.Body.(io.ReadWriteCloser)
131+
if !ok {
132+
return nil, resp, xerrors.Errorf("websocket: body is not a read write closer but should be: %T", rwc)
133+
}
134+
135+
return nil, resp, nil
42136
}

doc.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
// Package websocket is a minimal and idiomatic implementation of the WebSocket protocol.
22
//
3+
// See https://tools.ietf.org/html/rfc6455
4+
//
35
// For now the docs are at https://github.com/nhooyr/websocket#websocket. I will move them here later.
46
package websocket

go.mod

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,9 @@ require (
77
github.com/kr/pretty v0.1.0 // indirect
88
go.coder.com/go-tools v0.0.0-20190317003359-0c6a35b74a16
99
golang.org/x/lint v0.0.0-20190313153728-d0100b6bd8b3
10+
golang.org/x/net v0.0.0-20190311183353-d8887717615a
1011
golang.org/x/time v0.0.0-20190308202827-9d24e82272b4
1112
golang.org/x/tools v0.0.0-20190329215204-73054e8977d1
13+
golang.org/x/xerrors v0.0.0-20190315151331-d61658bd2e18
1214
mvdan.cc/sh v2.6.4+incompatible
1315
)

go.sum

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,5 +20,7 @@ golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxb
2020
golang.org/x/tools v0.0.0-20190311212946-11955173bddd/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs=
2121
golang.org/x/tools v0.0.0-20190329215204-73054e8977d1 h1:rLRH2E2wN5JjGJSVlBe1ioUkCKgb6eoL9X8bDmtEpsk=
2222
golang.org/x/tools v0.0.0-20190329215204-73054e8977d1/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs=
23+
golang.org/x/xerrors v0.0.0-20190315151331-d61658bd2e18 h1:1AGvnywFL1aB5KLRxyLseWJI6aSYPo3oF7HSpXdWQdU=
24+
golang.org/x/xerrors v0.0.0-20190315151331-d61658bd2e18/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
2325
mvdan.cc/sh v2.6.4+incompatible h1:eD6tDeh0pw+/TOTI1BBEryZ02rD2nMcFsgcvde7jffM=
2426
mvdan.cc/sh v2.6.4+incompatible/go.mod h1:IeeQbZq+x2SUGBensq/jge5lLQbS3XT2ktyp3wrt4x8=

opcode.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
package websocket
22

33
// opcode represents a WebSocket Opcode.
4-
//go:generate stringer -type=opcode
4+
//go:generate go run golang.org/x/tools/cmd/stringer -type=opcode
55
type opcode int
66

77
// opcode constants.

statuscode.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ import (
88
)
99

1010
// StatusCode represents a WebSocket status code.
11-
//go:generate stringer -type=StatusCode
11+
//go:generate go run golang.org/x/tools/cmd/stringer -type=StatusCode
1212
type StatusCode int
1313

1414
// These codes were retrieved from:

0 commit comments

Comments
 (0)