diff --git a/cmd/dex/config_test.go b/cmd/dex/config_test.go index 974e9c41d5..e1b29f7886 100644 --- a/cmd/dex/config_test.go +++ b/cmd/dex/config_test.go @@ -18,10 +18,14 @@ func TestUnmarshalConfig(t *testing.T) { rawConfig := []byte(` issuer: http://127.0.0.1:5556/dex storage: - type: sqlite3 + type: postgres config: - file: examples/dex.db - + host: 10.0.0.1 + port: 65432 + maxOpenConns: 5 + maxIdleConns: 3 + connMaxLifetime: 30 + connectionTimeout: 3 web: http: 127.0.0.1:5556 staticClients: @@ -69,9 +73,14 @@ logger: want := Config{ Issuer: "http://127.0.0.1:5556/dex", Storage: Storage{ - Type: "sqlite3", - Config: &sql.SQLite3{ - File: "examples/dex.db", + Type: "postgres", + Config: &sql.Postgres{ + Host: "10.0.0.1", + Port: 65432, + MaxOpenConns: 5, + MaxIdleConns: 3, + ConnMaxLifetime: 30, + ConnectionTimeout: 3, }, }, Web: Web{ diff --git a/storage/sql/config.go b/storage/sql/config.go index 701df92612..07fe6cf958 100644 --- a/storage/sql/config.go +++ b/storage/sql/config.go @@ -7,6 +7,7 @@ import ( "regexp" "strconv" "strings" + "time" "github.com/lib/pq" sqlite3 "github.com/mattn/go-sqlite3" @@ -88,6 +89,13 @@ type Postgres struct { SSL PostgresSSL `json:"ssl" yaml:"ssl"` ConnectionTimeout int // Seconds + + // database/sql tunables, see + // https://golang.org/pkg/database/sql/#DB.SetConnMaxLifetime and below + // Note: defaults will be set if these are 0 + MaxOpenConns int // default: 5 + MaxIdleConns int // default: 5 + ConnMaxLifetime int // Seconds, default: not set } // Open creates a new storage implementation backed by Postgres. @@ -177,6 +185,23 @@ func (p *Postgres) open(logger logrus.FieldLogger, dataSourceName string) (*conn return nil, err } + // set database/sql tunables if configured + if p.ConnMaxLifetime != 0 { + db.SetConnMaxLifetime(time.Duration(p.ConnMaxLifetime) * time.Second) + } + + if p.MaxIdleConns == 0 { + db.SetMaxIdleConns(5) + } else { + db.SetMaxIdleConns(p.MaxIdleConns) + } + + if p.MaxOpenConns == 0 { + db.SetMaxOpenConns(5) + } else { + db.SetMaxOpenConns(p.MaxOpenConns) + } + errCheck := func(err error) bool { sqlErr, ok := err.(*pq.Error) if !ok { diff --git a/storage/sql/postgres_test.go b/storage/sql/postgres_test.go new file mode 100644 index 0000000000..2e7bb476c9 --- /dev/null +++ b/storage/sql/postgres_test.go @@ -0,0 +1,48 @@ +// +build go1.11 + +package sql + +import ( + "os" + "testing" +) + +func TestPostgresTunables(t *testing.T) { + host := os.Getenv(testPostgresEnv) + if host == "" { + t.Skipf("test environment variable %q not set, skipping", testPostgresEnv) + } + baseCfg := &Postgres{ + Database: getenv("DEX_POSTGRES_DATABASE", "postgres"), + User: getenv("DEX_POSTGRES_USER", "postgres"), + Password: getenv("DEX_POSTGRES_PASSWORD", "postgres"), + Host: host, + SSL: PostgresSSL{ + Mode: sslDisable, // Postgres container doesn't support SSL. + }} + + t.Run("with nothing set, uses defaults", func(t *testing.T) { + cfg := *baseCfg + c, err := cfg.open(logger, cfg.createDataSourceName()) + if err != nil { + t.Fatalf("error opening connector: %s", err.Error()) + } + defer c.db.Close() + if m := c.db.Stats().MaxOpenConnections; m != 5 { + t.Errorf("expected MaxOpenConnections to have its default (5), got %d", m) + } + }) + + t.Run("with something set, uses that", func(t *testing.T) { + cfg := *baseCfg + cfg.MaxOpenConns = 101 + c, err := cfg.open(logger, cfg.createDataSourceName()) + if err != nil { + t.Fatalf("error opening connector: %s", err.Error()) + } + defer c.db.Close() + if m := c.db.Stats().MaxOpenConnections; m != 101 { + t.Errorf("expected MaxOpenConnections to be set to 101, got %d", m) + } + }) +}