Skip to content

Commit

Permalink
Add support for configuration of TLS
Browse files Browse the repository at this point in the history
  • Loading branch information
eikendev committed Jul 15, 2023
1 parent 833e666 commit 61d5e04
Show file tree
Hide file tree
Showing 8 changed files with 65 additions and 12 deletions.
2 changes: 1 addition & 1 deletion cmd/pushbits/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ func main() {
log.L.Fatal(err)
}

err = runner.Run(engine, c.HTTP.ListenAddress, c.HTTP.Port)
err = runner.Run(engine, c)
if err != nil {
log.L.Fatal(err)
}
Expand Down
6 changes: 6 additions & 0 deletions config.example.yml
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,12 @@ http:
# What proxies to trust.
trustedproxies: []

# Filename of the TLS certificate.
certfile: ''

# Filename of the TLS private key.
keyfile: ''

database:
# Currently sqlite3, mysql, and postgres are supported.
dialect: 'sqlite3'
Expand Down
2 changes: 1 addition & 1 deletion internal/api/util.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ func SuccessOrAbort(ctx *gin.Context, code int, err error) bool {
if err != nil {
// If we know the error force error code
switch err {
case pberrors.ErrorMessageNotFound:
case pberrors.ErrMessageNotFound:
ctx.AbortWithError(http.StatusNotFound, err)
default:
ctx.AbortWithError(code, err)
Expand Down
23 changes: 23 additions & 0 deletions internal/configuration/configuration.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ package configuration

import (
"github.com/jinzhu/configor"
"github.com/pushbits/server/internal/log"
"github.com/pushbits/server/internal/pberrors"
)

// testMode indicates if the package is run in test mode
Expand Down Expand Up @@ -53,6 +55,8 @@ type Configuration struct {
ListenAddress string `default:""`
Port int `default:"8080"`
TrustedProxies []string `default:"[]"`
CertFile string `default:""`
KeyFile string `default:""`
}
Database struct {
Dialect string `default:"sqlite3"`
Expand Down Expand Up @@ -80,6 +84,21 @@ func configFiles() []string {
return []string{"config.yml"}
}

func validateHTTPConfiguration(c *Configuration) error {
certAndKeyEmpty := (c.HTTP.CertFile == "" && c.HTTP.KeyFile == "")
certAndKeyPopulated := (c.HTTP.CertFile != "" && c.HTTP.KeyFile != "")

if !certAndKeyEmpty && !certAndKeyPopulated {
return pberrors.ErrConfigTLSFilesInconsistent
}

return nil
}

func validateConfiguration(c *Configuration) error {
return validateHTTPConfiguration(c)
}

// Get returns the configuration extracted from env variables or config file.
func Get() *Configuration {
config := &Configuration{}
Expand All @@ -93,5 +112,9 @@ func Get() *Configuration {
panic(err)
}

if err := validateConfiguration(config); err != nil {
log.L.Fatal(err)
}

return config
}
16 changes: 16 additions & 0 deletions internal/configuration/configuration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (

"github.com/jinzhu/configor"
"github.com/pushbits/server/internal/log"
"github.com/pushbits/server/internal/pberrors"
"github.com/stretchr/testify/assert"
"gopkg.in/yaml.v2"
)
Expand Down Expand Up @@ -231,3 +232,18 @@ func cleanUp() {
log.L.Warnln("Cannot remove config file: ", err)
}
}

func TestConfigurationValidation_ConfigTLSFilesInconsistent(t *testing.T) {
assert := assert.New(t)

c := Configuration{}
c.Admin.MatrixID = "000000"
c.Matrix.Username = "default-username"
c.Matrix.Password = "default-password"
c.HTTP.CertFile = "populated"
c.HTTP.KeyFile = ""

is := validateConfiguration(&c)
should := pberrors.ErrConfigTLSFilesInconsistent
assert.Equal(is, should, "validateConfiguration() should return ConfigTLSFilesInconsistent")
}
6 changes: 3 additions & 3 deletions internal/dispatcher/notification.go
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ func (d *Dispatcher) DeleteNotification(a *model.Application, n *model.DeleteNot
deleteMessage, err := d.getMessage(a, n.ID)
if err != nil {
log.L.Println(err)
return pberrors.ErrorMessageNotFound
return pberrors.ErrMessageNotFound
}

oldBody, oldFormattedBody, err = bodiesFromMessage(deleteMessage)
Expand Down Expand Up @@ -199,7 +199,7 @@ func (d *Dispatcher) getMessage(a *model.Application, id string) (*event.Event,
start = messages.End
}

return nil, pberrors.ErrorMessageNotFound
return nil, pberrors.ErrMessageNotFound
}

// Replaces the content of a matrix message
Expand Down Expand Up @@ -273,7 +273,7 @@ func (d *Dispatcher) respondToMessage(a *model.Application, body, formattedBody
func bodiesFromMessage(message *event.Event) (body, formattedBody string, err error) {
msgContent := message.Content.AsMessage()
if msgContent == nil {
return "", "", pberrors.ErrorMessageNotFound
return "", "", pberrors.ErrMessageNotFound
}

formattedBody = msgContent.Body
Expand Down
7 changes: 5 additions & 2 deletions internal/pberrors/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,5 +3,8 @@ package pberrors

import "errors"

// ErrorMessageNotFound indicates that a message does not exist
var ErrorMessageNotFound = errors.New("message not found")
// ErrMessageNotFound indicates that a message does not exist
var ErrMessageNotFound = errors.New("message not found")

// ErrConfigTLSFilesInconsistent indicates that either just a certfile or a keyfile was provided
var ErrConfigTLSFilesInconsistent = errors.New("TLS certfile and keyfile must either both be provided or omitted")
15 changes: 10 additions & 5 deletions internal/runner/runner.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,19 @@ import (
"fmt"

"github.com/gin-gonic/gin"
"github.com/pushbits/server/internal/configuration"
)

// Run starts the Gin engine.
func Run(engine *gin.Engine, address string, port int) error {
err := engine.Run(fmt.Sprintf("%s:%d", address, port))
if err != nil {
return err
func Run(engine *gin.Engine, c *configuration.Configuration) error {
var err error
address := fmt.Sprintf("%s:%d", c.HTTP.ListenAddress, c.HTTP.Port)

if c.HTTP.CertFile != "" && c.HTTP.KeyFile != "" {
err = engine.RunTLS(address, c.HTTP.CertFile, c.HTTP.KeyFile)
} else {
err = engine.Run(address)
}

return nil
return err
}

0 comments on commit 61d5e04

Please sign in to comment.