diff --git a/cmd/pushbits/main.go b/cmd/pushbits/main.go index 1d10352..27458ef 100644 --- a/cmd/pushbits/main.go +++ b/cmd/pushbits/main.go @@ -88,7 +88,7 @@ func main() { log.L.Fatal(err) } - err = runner.Run(engine, c.HTTP.ListenAddress, c.HTTP.Port) + err = runner.Run(engine, c) if err != nil { log.L.Fatal(err) } diff --git a/config.example.yml b/config.example.yml index 8b8c54b..626fc60 100644 --- a/config.example.yml +++ b/config.example.yml @@ -16,6 +16,12 @@ http: # What proxies to trust. trustedproxies: [] + # Filename of the TLS certificate. + certfile: '' + + # Filename of the TLS private key. + keyfile: '' + database: # Currently sqlite3, mysql, and postgres are supported. dialect: 'sqlite3' diff --git a/internal/api/util.go b/internal/api/util.go index 546aa5a..54877e8 100644 --- a/internal/api/util.go +++ b/internal/api/util.go @@ -15,7 +15,7 @@ func SuccessOrAbort(ctx *gin.Context, code int, err error) bool { if err != nil { // If we know the error force error code switch err { - case pberrors.ErrorMessageNotFound: + case pberrors.ErrMessageNotFound: ctx.AbortWithError(http.StatusNotFound, err) default: ctx.AbortWithError(code, err) diff --git a/internal/configuration/configuration.go b/internal/configuration/configuration.go index be38f41..1a6a247 100644 --- a/internal/configuration/configuration.go +++ b/internal/configuration/configuration.go @@ -3,6 +3,8 @@ package configuration import ( "github.com/jinzhu/configor" + "github.com/pushbits/server/internal/log" + "github.com/pushbits/server/internal/pberrors" ) // testMode indicates if the package is run in test mode @@ -53,6 +55,8 @@ type Configuration struct { ListenAddress string `default:""` Port int `default:"8080"` TrustedProxies []string `default:"[]"` + CertFile string `default:""` + KeyFile string `default:""` } Database struct { Dialect string `default:"sqlite3"` @@ -80,6 +84,21 @@ func configFiles() []string { return []string{"config.yml"} } +func validateHTTPConfiguration(c *Configuration) error { + certAndKeyEmpty := (c.HTTP.CertFile == "" && c.HTTP.KeyFile == "") + certAndKeyPopulated := (c.HTTP.CertFile != "" && c.HTTP.KeyFile != "") + + if !certAndKeyEmpty && !certAndKeyPopulated { + return pberrors.ErrConfigTLSFilesInconsistent + } + + return nil +} + +func validateConfiguration(c *Configuration) error { + return validateHTTPConfiguration(c) +} + // Get returns the configuration extracted from env variables or config file. func Get() *Configuration { config := &Configuration{} @@ -93,5 +112,9 @@ func Get() *Configuration { panic(err) } + if err := validateConfiguration(config); err != nil { + log.L.Fatal(err) + } + return config } diff --git a/internal/configuration/configuration_test.go b/internal/configuration/configuration_test.go index 1e2339c..b279005 100644 --- a/internal/configuration/configuration_test.go +++ b/internal/configuration/configuration_test.go @@ -8,6 +8,7 @@ import ( "github.com/jinzhu/configor" "github.com/pushbits/server/internal/log" + "github.com/pushbits/server/internal/pberrors" "github.com/stretchr/testify/assert" "gopkg.in/yaml.v2" ) @@ -231,3 +232,18 @@ func cleanUp() { log.L.Warnln("Cannot remove config file: ", err) } } + +func TestConfigurationValidation_ConfigTLSFilesInconsistent(t *testing.T) { + assert := assert.New(t) + + c := Configuration{} + c.Admin.MatrixID = "000000" + c.Matrix.Username = "default-username" + c.Matrix.Password = "default-password" + c.HTTP.CertFile = "populated" + c.HTTP.KeyFile = "" + + is := validateConfiguration(&c) + should := pberrors.ErrConfigTLSFilesInconsistent + assert.Equal(is, should, "validateConfiguration() should return ConfigTLSFilesInconsistent") +} diff --git a/internal/dispatcher/notification.go b/internal/dispatcher/notification.go index 3708876..8949a72 100644 --- a/internal/dispatcher/notification.go +++ b/internal/dispatcher/notification.go @@ -90,7 +90,7 @@ func (d *Dispatcher) DeleteNotification(a *model.Application, n *model.DeleteNot deleteMessage, err := d.getMessage(a, n.ID) if err != nil { log.L.Println(err) - return pberrors.ErrorMessageNotFound + return pberrors.ErrMessageNotFound } oldBody, oldFormattedBody, err = bodiesFromMessage(deleteMessage) @@ -199,7 +199,7 @@ func (d *Dispatcher) getMessage(a *model.Application, id string) (*event.Event, start = messages.End } - return nil, pberrors.ErrorMessageNotFound + return nil, pberrors.ErrMessageNotFound } // Replaces the content of a matrix message @@ -273,7 +273,7 @@ func (d *Dispatcher) respondToMessage(a *model.Application, body, formattedBody func bodiesFromMessage(message *event.Event) (body, formattedBody string, err error) { msgContent := message.Content.AsMessage() if msgContent == nil { - return "", "", pberrors.ErrorMessageNotFound + return "", "", pberrors.ErrMessageNotFound } formattedBody = msgContent.Body diff --git a/internal/pberrors/errors.go b/internal/pberrors/errors.go index f3255aa..edaab03 100644 --- a/internal/pberrors/errors.go +++ b/internal/pberrors/errors.go @@ -3,5 +3,8 @@ package pberrors import "errors" -// ErrorMessageNotFound indicates that a message does not exist -var ErrorMessageNotFound = errors.New("message not found") +// ErrMessageNotFound indicates that a message does not exist +var ErrMessageNotFound = errors.New("message not found") + +// ErrConfigTLSFilesInconsistent indicates that either just a certfile or a keyfile was provided +var ErrConfigTLSFilesInconsistent = errors.New("TLS certfile and keyfile must either both be provided or omitted") diff --git a/internal/runner/runner.go b/internal/runner/runner.go index c0cf8c3..05befb3 100644 --- a/internal/runner/runner.go +++ b/internal/runner/runner.go @@ -5,14 +5,19 @@ import ( "fmt" "github.com/gin-gonic/gin" + "github.com/pushbits/server/internal/configuration" ) // Run starts the Gin engine. -func Run(engine *gin.Engine, address string, port int) error { - err := engine.Run(fmt.Sprintf("%s:%d", address, port)) - if err != nil { - return err +func Run(engine *gin.Engine, c *configuration.Configuration) error { + var err error + address := fmt.Sprintf("%s:%d", c.HTTP.ListenAddress, c.HTTP.Port) + + if c.HTTP.CertFile != "" && c.HTTP.KeyFile != "" { + err = engine.RunTLS(address, c.HTTP.CertFile, c.HTTP.KeyFile) + } else { + err = engine.Run(address) } - return nil + return err }