Skip to content

Commit

Permalink
Add Cookie header support for WebSockets
Browse files Browse the repository at this point in the history
Closes #1226
  • Loading branch information
lcd1232 authored and Ivan Mirić committed Jun 8, 2021
1 parent 86315c6 commit 3dd5a6b
Show file tree
Hide file tree
Showing 3 changed files with 109 additions and 6 deletions.
75 changes: 75 additions & 0 deletions js/modules/k6/ws/ws.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,9 @@ import (
"io/ioutil"
"net"
"net/http"
"net/http/cookiejar"
netURL "net/url"
"reflect"
"strconv"
"strings"
"sync"
Expand Down Expand Up @@ -109,9 +112,20 @@ func (*WS) Connect(ctx context.Context, url string, args ...goja.Value) (*WSHTTP

// Leave header to nil by default so we can pass it directly to the Dialer
var header http.Header
var jar http.CookieJar

tags := state.CloneTags()

u, err := netURL.Parse(url)
if err != nil {
return nil, err
}

cookiesMap := make(map[string]http.Cookie)
for _, cookie := range state.CookieJar.Cookies(u) {
cookiesMap[cookie.Name] = *cookie
}

// Parse the optional second argument (params)
if !goja.IsUndefined(paramsV) && !goja.IsNull(paramsV) {
params := paramsV.ToObject(rt)
Expand All @@ -130,6 +144,43 @@ func (*WS) Connect(ctx context.Context, url string, args ...goja.Value) (*WSHTTP
for _, key := range headersObj.Keys() {
header.Set(key, headersObj.Get(key).String())
}
case "cookies":
cookiesV := params.Get(k)
if goja.IsUndefined(cookiesV) || goja.IsNull(cookiesV) {
continue
}
cookies := cookiesV.ToObject(rt)
if cookies == nil {
continue
}
for _, key := range cookies.Keys() {
replace := false
value := ""
cookieV := cookies.Get(key)
if goja.IsUndefined(cookieV) || goja.IsNull(cookieV) {
continue
}
switch cookieV.ExportType() {
case reflect.TypeOf(map[string]interface{}{}):
cookie := cookieV.ToObject(rt)
for _, attr := range cookie.Keys() {
switch strings.ToLower(attr) {
case "replace":
replace = cookie.Get(attr).ToBoolean()
case "value":
value = cookie.Get(attr).String()
}
}
default:
value = cookieV.String()
}
if _, ok := cookiesMap[key]; !ok || replace {
cookiesMap[key] = http.Cookie{
Name: key,
Value: value,
}
}
}
case "tags":
tagsV := params.Get(k)
if goja.IsUndefined(tagsV) || goja.IsNull(tagsV) {
Expand All @@ -147,6 +198,29 @@ func (*WS) Connect(ctx context.Context, url string, args ...goja.Value) (*WSHTTP

}

cookies := make([]*http.Cookie, 0, len(cookiesMap))
for _, cookie := range cookiesMap {
v := cookie
cookies = append(cookies, &v)
}
if len(cookies) > 0 {
jar, err = cookiejar.New(nil)
if err != nil {
return nil, err
}

// SetCookies looking for http or https scheme. Otherwise it does nothing. This is mini hack to bypass it.
oldScheme := u.Scheme
switch u.Scheme {
case "ws":
u.Scheme = "http"
case "wss":
u.Scheme = "https"
}
jar.SetCookies(u, cookies)
u.Scheme = oldScheme
}

if state.Options.SystemTags.Has(stats.TagURL) {
tags["url"] = url
}
Expand All @@ -165,6 +239,7 @@ func (*WS) Connect(ctx context.Context, url string, args ...goja.Value) (*WSHTTP
NetDialContext: state.Dialer.DialContext,
Proxy: http.ProxyFromEnvironment,
TLSClientConfig: tlsConfig,
Jar: jar,
}

start := time.Now()
Expand Down
20 changes: 19 additions & 1 deletion js/modules/k6/ws/ws_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ func TestSession(t *testing.T) {
socket.send("test")
})
socket.on("message", function (data){
if (!data=="test") {
if (!(data=="test")) {
throw new Error ("echo'd data doesn't match our message!");
}
socket.close()
Expand Down Expand Up @@ -315,6 +315,24 @@ func TestSession(t *testing.T) {
})
assertSessionMetricsEmitted(t, stats.GetBufferedSamples(samples), "", sr("WSBIN_URL/ws-echo"), 101, "")

t.Run("client_cookie", func(t *testing.T) {
_, err := common.RunString(rt, sr(`
var params = {
cookies: { "Session": "123" },
};
var res = ws.connect("WSBIN_URL/ws-echo-cookie", params, function(socket){
socket.on("message", function (data){
if (!(data == "Session=123")) {
throw new Error ("echo'd data doesn't match our message!");
}
socket.close();
});
});
`))
assert.NoError(t, err)
assertSessionMetricsEmitted(t, stats.GetBufferedSamples(samples), "", sr("WSBIN_URL/ws-echo-cookie"), 101, "")
})

serverCloseTests := []struct {
name string
endpoint string
Expand Down
20 changes: 15 additions & 5 deletions lib/testutils/httpmultibin/httpmultibin.go
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ type jsonBody struct {
Compression string `json:"compression"`
}

func getWebsocketHandler(echo bool, closePrematurely bool) http.Handler {
func getWebsocketHandler(echo bool, closePrematurely bool, echoCookies bool) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
conn, err := (&websocket.Upgrader{}).Upgrade(w, req, w.Header())
if err != nil {
Expand All @@ -129,6 +129,15 @@ func getWebsocketHandler(echo bool, closePrematurely bool) http.Handler {
return
}
}
if echoCookies {
var cookies []string
for _, cookie := range req.Cookies() {
cookies = append(cookies, cookie.String())
}
if err = conn.WriteMessage(websocket.TextMessage, []byte(strings.Join(cookies, "; "))); err != nil {
return
}
}
// closePrematurely=true mimics an invalid WS server that doesn't
// send a close control frame before closing the connection.
if !closePrematurely {
Expand Down Expand Up @@ -257,10 +266,11 @@ func NewHTTPMultiBin(t testing.TB) *HTTPMultiBin {
// Create a http.ServeMux and set the httpbin handler as the default
mux := http.NewServeMux()
mux.Handle("/brotli", getEncodedHandler(t, "br"))
mux.Handle("/ws-echo", getWebsocketHandler(true, false))
mux.Handle("/ws-echo-invalid", getWebsocketHandler(true, true))
mux.Handle("/ws-close", getWebsocketHandler(false, false))
mux.Handle("/ws-close-invalid", getWebsocketHandler(false, true))
mux.Handle("/ws-echo", getWebsocketHandler(true, false, false))
mux.Handle("/ws-echo-cookie", getWebsocketHandler(false, false, true))
mux.Handle("/ws-echo-invalid", getWebsocketHandler(true, true, false))
mux.Handle("/ws-close", getWebsocketHandler(false, false, false))
mux.Handle("/ws-close-invalid", getWebsocketHandler(false, true, false))
mux.Handle("/zstd", getEncodedHandler(t, "zstd"))
mux.Handle("/zstd-br", getZstdBrHandler(t))
mux.Handle("/", httpbin.New().Handler())
Expand Down

0 comments on commit 3dd5a6b

Please sign in to comment.