From b5b2f19dc2ed19165615719d49b9ad8919537476 Mon Sep 17 00:00:00 2001 From: Eugene Gavrilov Date: Thu, 6 Dec 2018 01:25:44 +0500 Subject: [PATCH] [#23] Fix check same origin function --- README.md | 27 ++++++++------ api/stream/stream.go | 77 ++++++++++++++++++++++++++------------- api/stream/stream_test.go | 42 +++++++++++++++++---- config/config.go | 3 ++ config/config_test.go | 10 +++++ router/router.go | 2 +- 6 files changed, 117 insertions(+), 44 deletions(-) diff --git a/README.md b/README.md index 1595ce4f..644b729d 100644 --- a/README.md +++ b/README.md @@ -1,10 +1,10 @@ # Gotify Server -[![Build Status][badge-travis]][travis] -[![codecov][badge-codecov]][codecov] -[![Go Report Card][badge-go-report]][go-report] +[![Build Status][badge-travis]][travis] +[![codecov][badge-codecov]][codecov] +[![Go Report Card][badge-go-report]][go-report] [![Swagger Valid][badge-swagger]][swagger] [![FOSSA Status][fossa-badge]][fossa] -[![Api Docs][badge-api-docs]][api-docs] +[![Api Docs][badge-api-docs]][api-docs] [![latest release version][badge-release]][release] @@ -46,7 +46,7 @@ Google Play and the Google Play logo are trademarks of Google LLC. The docker image is available on docker hub at [gotify/server][docker-normal]. ``` bash -$ docker run -p 80:80 -v /etc/gotify/data:/app/data gotify/server +$ docker run -p 80:80 -v /etc/gotify/data:/app/data gotify/server ``` Also there is a specific docker image for arm-7 processors (raspberry pi), named [gotify/server-arm7][docker-arm7]. ``` bash @@ -60,7 +60,7 @@ Visit the [releases page](https://github.com/gotify/server/releases) and downloa ## Configuration ### File -When strings contain reserved characters then they need to be escaped. +When strings contain reserved characters then they need to be escaped. [List of reserved characters and how to escape them](https://stackoverflow.com/a/22235064/4244993). ``` yml @@ -82,6 +82,10 @@ server: responseheaders: # response headers are added to every response (default: none) Access-Control-Allow-Origin: "*" Access-Control-Allow-Methods: "GET,POST" + stream: + allowedorigins: # allowed origins for websocket connections (same origin is always allowed) + - ".+.example.com" + - "otherdomain.com" database: # for database see (configure database section) dialect: sqlite3 connection: data/gotify.db @@ -94,8 +98,8 @@ uploadedimagesdir: data/images # the directory for storing uploaded images ### Environment -Escaped characters in list or map environment settings (`GOTIFY_SERVER_RESPONSEHEADERS` and -`GOTIFY_SERVER_SSL_LETSENCRYPT_HOSTS`) need to be escaped as well. +Escaped characters in list or map environment settings (`GOTIFY_SERVER_RESPONSEHEADERS` and +`GOTIFY_SERVER_SSL_LETSENCRYPT_HOSTS`) need to be escaped as well. [List of reserved characters and how to escape them](https://stackoverflow.com/a/22235064/4244993). ``` bash @@ -111,6 +115,7 @@ GOTIFY_SERVER_SSL_LETSENCRYPT_CACHE=certs # lists are a little weird but do-able (: GOTIFY_SERVER_SSL_LETSENCRYPT_HOSTS=- mydomain.tld\n- myotherdomain.tld GOTIFY_SERVER_RESPONSEHEADERS="Access-Control-Allow-Origin: \"*\"\nAccess-Control-Allow-Methods: \"GET,POST\"" +GOTIFY_SERVER_STREAM_ALLOWEDORIGINS="- \".+.example.com\"\n- \"otherdomain.com\"" GOTIFY_DATABASE_DIALECT=sqlite3 GOTIFY_DATABASE_CONNECTION=gotify.db GOTIFY_DEFAULTUSER_NAME=admin @@ -126,7 +131,7 @@ GOTIFY_UPLOADEDIMAGESDIR=images | mysql | `gotify:secret@/gotifydb?charset=utf8&parseTime=True&loc=Local ` | | postgres | `host=localhost port=3306 user=gotify dbname=gotify password=secret` | -When using postgres without SSL then `sslmode=disable` must be added to the connection string. +When using postgres without SSL then `sslmode=disable` must be added to the connection string. See [#90](https://github.com/gotify/server/issues/90). ## Push Message Examples @@ -141,7 +146,7 @@ $ http -f POST "https://push.example.de/message?token=" title="my titl ``` [More examples can be found here](ADD_MESSAGE_EXAMPLES.md) -Also you can use [gotify/cli](https://github.com/gotify/cli) to push messages. +Also you can use [gotify/cli](https://github.com/gotify/cli) to push messages. The CLI stores url and token in a config file. ```bash @@ -200,7 +205,7 @@ $ go test ./... ``` ## Versioning -We use [SemVer](http://semver.org/) for versioning. For the versions available, see the +We use [SemVer](http://semver.org/) for versioning. For the versions available, see the [tags on this repository](https://github.com/gotify/server/tags). ## License diff --git a/api/stream/stream.go b/api/stream/stream.go index ed742c9f..07a0784f 100644 --- a/api/stream/stream.go +++ b/api/stream/stream.go @@ -1,6 +1,7 @@ package stream import ( + "regexp" "sync" "time" @@ -14,46 +15,25 @@ import ( "github.com/gotify/server/model" ) -var upgrader = websocket.Upgrader{ - ReadBufferSize: 1024, - WriteBufferSize: 1024, - CheckOrigin: func(r *http.Request) bool { - if mode.IsDev() { - return true - } - return checkSameOrigin(r) - }, -} - -func checkSameOrigin(r *http.Request) bool { - origin := r.Header["Origin"] - if len(origin) == 0 { - return true - } - u, err := url.Parse(origin[0]) - if err != nil { - return false - } - return u.Host == r.Host -} - // The API provides a handler for a WebSocket stream API. type API struct { clients map[uint][]*client lock sync.RWMutex pingPeriod time.Duration pongTimeout time.Duration + upgrader *websocket.Upgrader } // New creates a new instance of API. // pingPeriod: is the interval, in which is server sends the a ping to the client. // pongTimeout: is the duration after the connection will be terminated, when the client does not respond with the // pong command. -func New(pingPeriod, pongTimeout time.Duration) *API { +func New(pingPeriod, pongTimeout time.Duration, allowedWebSocketOrigins []string) *API { return &API{ clients: make(map[uint][]*client), pingPeriod: pingPeriod, pongTimeout: pingPeriod + pongTimeout, + upgrader: newUpgrader(allowedWebSocketOrigins), } } @@ -147,7 +127,7 @@ func (a *API) register(client *client) { // schema: // $ref: "#/definitions/Error" func (a *API) Handle(ctx *gin.Context) { - conn, err := upgrader.Upgrade(ctx.Writer, ctx.Request, nil) + conn, err := a.upgrader.Upgrade(ctx.Writer, ctx.Request, nil) if err != nil { return } @@ -172,3 +152,50 @@ func (a *API) Close() { delete(a.clients, k) } } + +func isAllowedOrigin(r *http.Request, allowedOrigins []*regexp.Regexp) bool { + origin := r.Header["Origin"] + if len(origin) == 0 { + return true + } + + u, err := url.Parse(origin[0]) + if err != nil { + return false + } + + if u.Hostname() == r.Host { + return true + } + + for _, allowedOrigin := range allowedOrigins { + if allowedOrigin.Match([]byte(u.Hostname())) { + return true + } + } + + return false +} + +func newUpgrader(allowedWebSocketOrigins []string) *websocket.Upgrader { + compiledAllowedOrigins := compileAllowedWebSocketOrigins(allowedWebSocketOrigins) + return &websocket.Upgrader{ + ReadBufferSize: 1024, + WriteBufferSize: 1024, + CheckOrigin: func(r *http.Request) bool { + if mode.IsDev() { + return true + } + return isAllowedOrigin(r, compiledAllowedOrigins) + }, + } +} + +func compileAllowedWebSocketOrigins(allowedOrigins []string) []*regexp.Regexp { + var compiledAllowedOrigins []*regexp.Regexp + for _, origin := range allowedOrigins { + compiledAllowedOrigins = append(compiledAllowedOrigins, regexp.MustCompile(origin)) + } + + return compiledAllowedOrigins +} diff --git a/api/stream/stream_test.go b/api/stream/stream_test.go index cf88b08b..a6248d26 100644 --- a/api/stream/stream_test.go +++ b/api/stream/stream_test.go @@ -405,14 +405,37 @@ func Test_sameOrigin_returnsTrue(t *testing.T) { mode.Set(mode.Prod) req := httptest.NewRequest("GET", "http://example.com/stream", nil) req.Header.Set("Origin", "http://example.com") - actual := checkSameOrigin(req) + actual := isAllowedOrigin(req, nil) assert.True(t, actual) } +func Test_isAllowedOrigin_withoutAllowedOrigins_failsWhenNotSameOrigin(t *testing.T) { + mode.Set(mode.Prod) + req := httptest.NewRequest("GET", "http://example.com/stream", nil) + req.Header.Set("Origin", "http://gorify.example.com") + actual := isAllowedOrigin(req, nil) + assert.False(t, actual) +} + +func Test_isAllowedOriginMatching(t *testing.T) { + mode.Set(mode.Prod) + compiledAllowedOrigins := compileAllowedWebSocketOrigins([]string{"go.{4}\\.example\\.com", "go\\.example\\.com"}) + + req := httptest.NewRequest("GET", "http://example.me/stream", nil) + req.Header.Set("Origin", "http://gorify.example.com") + assert.True(t, isAllowedOrigin(req, compiledAllowedOrigins)) + + req.Header.Set("Origin", "http://go.example.com") + assert.True(t, isAllowedOrigin(req, compiledAllowedOrigins)) + + req.Header.Set("Origin", "http://hello.example.com") + assert.False(t, isAllowedOrigin(req, compiledAllowedOrigins)) +} + func Test_emptyOrigin_returnsTrue(t *testing.T) { mode.Set(mode.Prod) req := httptest.NewRequest("GET", "http://example.com/stream", nil) - actual := checkSameOrigin(req) + actual := isAllowedOrigin(req, nil) assert.True(t, actual) } @@ -420,7 +443,7 @@ func Test_otherOrigin_returnsFalse(t *testing.T) { mode.Set(mode.Prod) req := httptest.NewRequest("GET", "http://example.com/stream", nil) req.Header.Set("Origin", "http://otherexample.de") - actual := checkSameOrigin(req) + actual := isAllowedOrigin(req, nil) assert.False(t, actual) } @@ -428,10 +451,15 @@ func Test_invalidOrigin_returnsFalse(t *testing.T) { mode.Set(mode.Prod) req := httptest.NewRequest("GET", "http://example.com/stream", nil) req.Header.Set("Origin", "http\\://otherexample.de") - actual := checkSameOrigin(req) + actual := isAllowedOrigin(req, nil) assert.False(t, actual) } +func Test_compileAllowedWebSocketOrigins(t *testing.T) { + assert.Equal(t, 0, len(compileAllowedWebSocketOrigins([]string{}))) + assert.Equal(t, 3, len(compileAllowedWebSocketOrigins([]string{"^.*$", "", "abc"}))) +} + func clients(api *API, user uint) []*client { api.lock.RLock() defer api.lock.RUnlock() @@ -439,7 +467,7 @@ func clients(api *API, user uint) []*client { return api.clients[user] } -func testClient(t *testing.T, url string) *testingClient { +func testClient(t *testing.T, url string) *testingClient { client := createClient(t, url) startReading(client) return client @@ -507,11 +535,11 @@ func (c *testingClient) expectNoMessage() { } func bootTestServer(handlerFunc gin.HandlerFunc) (*httptest.Server, *API) { - r := gin.New() r.Use(handlerFunc) // all 4 seconds a ping, and the client has 1 second to respond - api := New(4*time.Second, 1*time.Second) + api := New(4*time.Second, 1*time.Second, []string{}) + r.GET("/", api.Handle) server := httptest.NewServer(r) return server, api diff --git a/config/config.go b/config/config.go index c8f3e09a..c80288c0 100644 --- a/config/config.go +++ b/config/config.go @@ -25,6 +25,9 @@ type Configuration struct { } } ResponseHeaders map[string]string + Stream struct { + AllowedOrigins []string + } } Database struct { Dialect string `default:"sqlite3"` diff --git a/config/config_test.go b/config/config_test.go index 666cb1c9..c4c64ddf 100644 --- a/config/config_test.go +++ b/config/config_test.go @@ -15,15 +15,20 @@ func TestConfigEnv(t *testing.T) { os.Setenv("GOTIFY_SERVER_RESPONSEHEADERS", "Access-Control-Allow-Origin: \"*\"\nAccess-Control-Allow-Methods: \"GET,POST\"", ) + os.Setenv("GOTIFY_SERVER_STREAM_ALLOWEDORIGINS", "- \".+.example.com\"\n- \"otherdomain.com\"") + conf := Get() assert.Equal(t, 80, conf.Server.Port, "should use defaults") assert.Equal(t, "jmattheis", conf.DefaultUser.Name, "should not use default but env var") assert.Equal(t, []string{"push.example.tld", "push.other.tld"}, conf.Server.SSL.LetsEncrypt.Hosts) assert.Equal(t, "*", conf.Server.ResponseHeaders["Access-Control-Allow-Origin"]) assert.Equal(t, "GET,POST", conf.Server.ResponseHeaders["Access-Control-Allow-Methods"]) + assert.Equal(t, []string{".+.example.com", "otherdomain.com"}, conf.Server.Stream.AllowedOrigins) + os.Unsetenv("GOTIFY_DEFAULTUSER_NAME") os.Unsetenv("GOTIFY_SERVER_SSL_LETSENCRYPT_HOSTS") os.Unsetenv("GOTIFY_SERVER_RESPONSEHEADERS") + os.Unsetenv("GOTIFY_SERVER_STREAM_ALLOWEDORIGINS") } func TestAddSlash(t *testing.T) { @@ -75,6 +80,10 @@ server: responseheaders: Access-Control-Allow-Origin: "*" Access-Control-Allow-Methods: "GET,POST" + stream: + allowedorigins: + - ".+.example.com" + - "otherdomain.com" database: dialect: mysql connection: user name @@ -94,6 +103,7 @@ defaultuser: assert.Equal(t, "user name", conf.Database.Connection) assert.Equal(t, "*", conf.Server.ResponseHeaders["Access-Control-Allow-Origin"]) assert.Equal(t, "GET,POST", conf.Server.ResponseHeaders["Access-Control-Allow-Methods"]) + assert.Equal(t, []string{".+.example.com", "otherdomain.com"}, conf.Server.Stream.AllowedOrigins) assert.Nil(t, os.Remove("config.yml")) } diff --git a/router/router.go b/router/router.go index ccf49fe7..dc2adf5c 100644 --- a/router/router.go +++ b/router/router.go @@ -23,7 +23,7 @@ import ( // Create creates the gin engine with all routes. func Create(db *database.GormDatabase, vInfo *model.VersionInfo, conf *config.Configuration) (*gin.Engine, func()) { - streamHandler := stream.New(200*time.Second, 15*time.Second) + streamHandler := stream.New(200*time.Second, 15*time.Second, conf.Server.Stream.AllowedOrigins) authentication := auth.Auth{DB: db} messageHandler := api.MessageAPI{Notifier: streamHandler, DB: db} clientHandler := api.ClientAPI{