diff --git a/glide.lock b/glide.lock index f33cfe7554..e4b00be511 100644 --- a/glide.lock +++ b/glide.lock @@ -1,5 +1,5 @@ -hash: e5972bbdf15ad612d99ce8cd34e19537b9eacb5ff53688f339e0da285eb8ec22 -updated: 2018-11-12T19:38:56.235070564+01:00 +hash: 70e399f3424964c1535cefb66bce0e47af25ea6bb0f32a254e83e91bd774b5f2 +updated: 2018-11-20T09:49:19.83565589-05:00 imports: - name: github.com/beevik/etree version: 4cd0dd976db869f817248477718071a28e978df0 @@ -54,7 +54,7 @@ imports: - diff - pretty - name: github.com/lib/pq - version: 50761b0867bd1d9d069276790bcd4a3bccf2324a + version: 9eb73efc1fcc404148b56765b0d3f61d9a5ef8ee subpackages: - oid - name: github.com/mattn/go-sqlite3 diff --git a/glide.yaml b/glide.yaml index b8f459bea3..e4909ff41f 100644 --- a/glide.yaml +++ b/glide.yaml @@ -114,7 +114,7 @@ import: - package: github.com/mattn/go-sqlite3 version: 3fb7a0e792edd47bf0cf1e919dfc14e2be412e15 - package: github.com/lib/pq - version: 50761b0867bd1d9d069276790bcd4a3bccf2324a + version: 9eb73efc1fcc404148b56765b0d3f61d9a5ef8ee # etcd driver - package: github.com/coreos/etcd diff --git a/storage/sql/crud.go b/storage/sql/crud.go index d7c055ab18..a1406e2000 100644 --- a/storage/sql/crud.go +++ b/storage/sql/crud.go @@ -134,7 +134,7 @@ func (c *conn) CreateAuthRequest(a storage.AuthRequest) error { } func (c *conn) UpdateAuthRequest(id string, updater func(a storage.AuthRequest) (storage.AuthRequest, error)) error { - return c.ExecTx(func(tx *trans) error { + err := c.ExecTx(func(tx *trans) error { r, err := getAuthRequest(tx, id) if err != nil { return err @@ -144,6 +144,7 @@ func (c *conn) UpdateAuthRequest(id string, updater func(a storage.AuthRequest) if err != nil { return err } + _, err = tx.Exec(` update auth_request set @@ -163,21 +164,31 @@ func (c *conn) UpdateAuthRequest(id string, updater func(a storage.AuthRequest) a.ConnectorID, a.ConnectorData, a.Expiry, r.ID, ) - if err != nil { - return fmt.Errorf("update auth request: %v", err) - } - return nil + return err }) + if err != nil { + return fmt.Errorf("update auth request: %v", err) + } + return nil } func (c *conn) GetAuthRequest(id string) (storage.AuthRequest, error) { - return getAuthRequest(c, id) + req, err := getAuthRequest(c, id) + if err != nil { + if err == sql.ErrNoRows { + return storage.AuthRequest{}, storage.ErrNotFound + } + + return storage.AuthRequest{}, fmt.Errorf("select auth request: %v", err) + } + + return req, nil } func getAuthRequest(q querier, id string) (a storage.AuthRequest, err error) { err = q.QueryRow(` - select + select id, client_id, response_types, scopes, redirect_uri, nonce, state, force_approval_prompt, logged_in, claims_user_id, claims_username, claims_email, claims_email_verified, @@ -192,10 +203,7 @@ func getAuthRequest(q querier, id string) (a storage.AuthRequest, err error) { &a.ConnectorID, &a.ConnectorData, &a.Expiry, ) if err != nil { - if err == sql.ErrNoRows { - return a, storage.ErrNotFound - } - return a, fmt.Errorf("select auth request: %v", err) + return a, err } return a, nil } @@ -269,20 +277,22 @@ func (c *conn) CreateRefresh(r storage.RefreshToken) error { if c.alreadyExistsCheck(err) { return storage.ErrAlreadyExists } - return fmt.Errorf("insert refresh_token: %v", err) + return fmt.Errorf("insert refresh token: %v", err) } return nil } func (c *conn) UpdateRefreshToken(id string, updater func(old storage.RefreshToken) (storage.RefreshToken, error)) error { - return c.ExecTx(func(tx *trans) error { + err := c.ExecTx(func(tx *trans) error { r, err := getRefresh(tx, id) if err != nil { return err } + if r, err = updater(r); err != nil { return err } + _, err = tx.Exec(` update refresh_token set @@ -308,15 +318,25 @@ func (c *conn) UpdateRefreshToken(id string, updater func(old storage.RefreshTok r.ConnectorID, r.ConnectorData, r.Token, r.CreatedAt, r.LastUsed, id, ) - if err != nil { - return fmt.Errorf("update refresh token: %v", err) - } - return nil + return err }) + if err != nil { + return fmt.Errorf("update refresh token: %v", err) + } + return nil } func (c *conn) GetRefresh(id string) (storage.RefreshToken, error) { - return getRefresh(c, id) + req, err := getRefresh(c, id) + if err != nil { + if err == sql.ErrNoRows { + return storage.RefreshToken{}, storage.ErrNotFound + } + + return storage.RefreshToken{}, fmt.Errorf("get refresh token: %v", err) + } + + return req, nil } func getRefresh(q querier, id string) (storage.RefreshToken, error) { @@ -342,14 +362,15 @@ func (c *conn) ListRefreshTokens() ([]storage.RefreshToken, error) { from refresh_token; `) if err != nil { - return nil, fmt.Errorf("query: %v", err) + return nil, fmt.Errorf("select refresh tokens: %v", err) } var tokens []storage.RefreshToken for rows.Next() { r, err := scanRefresh(rows) if err != nil { - return nil, err + return nil, fmt.Errorf("scan refresh token: %s", err) } + tokens = append(tokens, r) } if err := rows.Err(); err != nil { @@ -367,10 +388,7 @@ func scanRefresh(s scanner) (r storage.RefreshToken, err error) { &r.Token, &r.CreatedAt, &r.LastUsed, ) if err != nil { - if err == sql.ErrNoRows { - return r, storage.ErrNotFound - } - return r, fmt.Errorf("scan refresh_token: %v", err) + return r, err } return r, nil } @@ -381,12 +399,11 @@ func (c *conn) UpdateKeys(updater func(old storage.Keys) (storage.Keys, error)) // TODO(ericchiang): errors may cause a transaction be rolled back by the SQL // server. Test this, and consider adding a COUNT() command beforehand. old, err := getKeys(tx) - if err != nil { - if err != storage.ErrNotFound { - return fmt.Errorf("get keys: %v", err) - } + if err == sql.ErrNoRows { firstUpdate = true old = storage.Keys{} + } else if err != nil { + return err } nk, err := updater(old) @@ -405,12 +422,12 @@ func (c *conn) UpdateKeys(updater func(old storage.Keys) (storage.Keys, error)) encoder(nk.SigningKeyPub), nk.NextRotation, ) if err != nil { - return fmt.Errorf("insert: %v", err) + return err } } else { _, err = tx.Exec(` update keys - set + set verification_keys = $1, signing_key = $2, signing_key_pub = $3, @@ -421,15 +438,24 @@ func (c *conn) UpdateKeys(updater func(old storage.Keys) (storage.Keys, error)) encoder(nk.SigningKeyPub), nk.NextRotation, keysRowID, ) if err != nil { - return fmt.Errorf("update: %v", err) + return err } } return nil }) } -func (c *conn) GetKeys() (keys storage.Keys, err error) { - return getKeys(c) +func (c *conn) GetKeys() (storage.Keys, error) { + keys, err := getKeys(c) + if err != nil { + if err == sql.ErrNoRows { + return storage.Keys{}, storage.ErrNotFound + } + + return storage.Keys{}, fmt.Errorf("select keys: %s", err) + } + + return keys, nil } func getKeys(q querier) (keys storage.Keys, err error) { @@ -443,20 +469,18 @@ func getKeys(q querier) (keys storage.Keys, err error) { decoder(&keys.SigningKeyPub), &keys.NextRotation, ) if err != nil { - if err == sql.ErrNoRows { - return keys, storage.ErrNotFound - } - return keys, fmt.Errorf("query keys: %v", err) + return keys, err } return keys, nil } func (c *conn) UpdateClient(id string, updater func(old storage.Client) (storage.Client, error)) error { - return c.ExecTx(func(tx *trans) error { + err := c.ExecTx(func(tx *trans) error { cli, err := getClient(tx, id) if err != nil { return err } + nc, err := updater(cli) if err != nil { return err @@ -474,11 +498,13 @@ func (c *conn) UpdateClient(id string, updater func(old storage.Client) (storage where id = $7; `, nc.Secret, encoder(nc.RedirectURIs), encoder(nc.TrustedPeers), nc.Public, nc.Name, nc.LogoURL, id, ) - if err != nil { - return fmt.Errorf("update client: %v", err) - } - return nil + return err }) + if err != nil { + return fmt.Errorf("update client: %v", err) + } + + return nil } func (c *conn) CreateClient(cli storage.Client) error { @@ -509,7 +535,16 @@ func getClient(q querier, id string) (storage.Client, error) { } func (c *conn) GetClient(id string) (storage.Client, error) { - return getClient(c, id) + client, err := getClient(c, id) + if err != nil { + if err == sql.ErrNoRows { + return storage.Client{}, storage.ErrNotFound + } + + return storage.Client{}, fmt.Errorf("select client: %v", err) + } + + return client, nil } func (c *conn) ListClients() ([]storage.Client, error) { @@ -525,12 +560,12 @@ func (c *conn) ListClients() ([]storage.Client, error) { for rows.Next() { cli, err := scanClient(rows) if err != nil { - return nil, err + return nil, fmt.Errorf("scan client: %s", err) } clients = append(clients, cli) } if err := rows.Err(); err != nil { - return nil, err + return nil, fmt.Errorf("scan: %s", err) } return clients, nil } @@ -541,10 +576,7 @@ func scanClient(s scanner) (cli storage.Client, err error) { &cli.Public, &cli.Name, &cli.LogoURL, ) if err != nil { - if err == sql.ErrNoRows { - return cli, storage.ErrNotFound - } - return cli, fmt.Errorf("get client: %v", err) + return cli, err } return cli, nil } @@ -571,7 +603,7 @@ func (c *conn) CreatePassword(p storage.Password) error { } func (c *conn) UpdatePassword(email string, updater func(p storage.Password) (storage.Password, error)) error { - return c.ExecTx(func(tx *trans) error { + err := c.ExecTx(func(tx *trans) error { p, err := getPassword(tx, email) if err != nil { return err @@ -581,6 +613,7 @@ func (c *conn) UpdatePassword(email string, updater func(p storage.Password) (st if err != nil { return err } + _, err = tx.Exec(` update password set @@ -589,15 +622,25 @@ func (c *conn) UpdatePassword(email string, updater func(p storage.Password) (st `, np.Hash, np.Username, np.UserID, p.Email, ) - if err != nil { - return fmt.Errorf("update password: %v", err) - } - return nil + return err }) + if err != nil { + return fmt.Errorf("update password: %v", err) + } + return nil } func (c *conn) GetPassword(email string) (storage.Password, error) { - return getPassword(c, email) + pass, err := getPassword(c, email) + if err != nil { + if err == sql.ErrNoRows { + return storage.Password{}, storage.ErrNotFound + } + + return storage.Password{}, fmt.Errorf("get password: %s", err) + } + + return pass, nil } func getPassword(q querier, email string) (p storage.Password, err error) { @@ -622,12 +665,12 @@ func (c *conn) ListPasswords() ([]storage.Password, error) { for rows.Next() { p, err := scanPassword(rows) if err != nil { - return nil, err + return nil, fmt.Errorf("scan password: %s", err) } passwords = append(passwords, p) } if err := rows.Err(); err != nil { - return nil, err + return nil, fmt.Errorf("scan: %s", err) } return passwords, nil } @@ -637,10 +680,7 @@ func scanPassword(s scanner) (p storage.Password, err error) { &p.Email, &p.Hash, &p.Username, &p.UserID, ) if err != nil { - if err == sql.ErrNoRows { - return p, storage.ErrNotFound - } - return p, fmt.Errorf("select password: %v", err) + return p, err } return p, nil } @@ -666,7 +706,7 @@ func (c *conn) CreateOfflineSessions(s storage.OfflineSessions) error { } func (c *conn) UpdateOfflineSessions(userID string, connID string, updater func(s storage.OfflineSessions) (storage.OfflineSessions, error)) error { - return c.ExecTx(func(tx *trans) error { + err := c.ExecTx(func(tx *trans) error { s, err := getOfflineSessions(tx, userID, connID) if err != nil { return err @@ -676,6 +716,7 @@ func (c *conn) UpdateOfflineSessions(userID string, connID string, updater func( if err != nil { return err } + _, err = tx.Exec(` update offline_session set @@ -684,15 +725,26 @@ func (c *conn) UpdateOfflineSessions(userID string, connID string, updater func( `, encoder(newSession.Refresh), s.UserID, s.ConnID, ) - if err != nil { - return fmt.Errorf("update offline session: %v", err) - } - return nil + return err }) + if err != nil { + return fmt.Errorf("update offline session: %v", err) + } + + return nil } func (c *conn) GetOfflineSessions(userID string, connID string) (storage.OfflineSessions, error) { - return getOfflineSessions(c, userID, connID) + sessions, err := getOfflineSessions(c, userID, connID) + if err != nil { + if err == sql.ErrNoRows { + return storage.OfflineSessions{}, storage.ErrNotFound + } + + return storage.OfflineSessions{}, fmt.Errorf("get offline sessions: %s", err) + } + + return sessions, nil } func getOfflineSessions(q querier, userID string, connID string) (storage.OfflineSessions, error) { @@ -709,10 +761,7 @@ func scanOfflineSessions(s scanner) (o storage.OfflineSessions, err error) { &o.UserID, &o.ConnID, decoder(&o.Refresh), ) if err != nil { - if err == sql.ErrNoRows { - return o, storage.ErrNotFound - } - return o, fmt.Errorf("select offline session: %v", err) + return o, err } return o, nil } @@ -738,7 +787,7 @@ func (c *conn) CreateConnector(connector storage.Connector) error { } func (c *conn) UpdateConnector(id string, updater func(s storage.Connector) (storage.Connector, error)) error { - return c.ExecTx(func(tx *trans) error { + err := c.ExecTx(func(tx *trans) error { connector, err := getConnector(tx, id) if err != nil { return err @@ -748,9 +797,10 @@ func (c *conn) UpdateConnector(id string, updater func(s storage.Connector) (sto if err != nil { return err } + _, err = tx.Exec(` update connector - set + set type = $1, name = $2, resource_version = $3, @@ -759,15 +809,26 @@ func (c *conn) UpdateConnector(id string, updater func(s storage.Connector) (sto `, newConn.Type, newConn.Name, newConn.ResourceVersion, newConn.Config, connector.ID, ) - if err != nil { - return fmt.Errorf("update connector: %v", err) - } - return nil + return err }) + if err != nil { + return fmt.Errorf("update connector: %v", err) + } + + return nil } func (c *conn) GetConnector(id string) (storage.Connector, error) { - return getConnector(c, id) + connector, err := getConnector(c, id) + if err != nil { + if err == sql.ErrNoRows { + return storage.Connector{}, storage.ErrNotFound + } + + return storage.Connector{}, fmt.Errorf("get connector: %s", err) + } + + return connector, nil } func getConnector(q querier, id string) (storage.Connector, error) { @@ -784,10 +845,7 @@ func scanConnector(s scanner) (c storage.Connector, err error) { &c.ID, &c.Type, &c.Name, &c.ResourceVersion, &c.Config, ) if err != nil { - if err == sql.ErrNoRows { - return c, storage.ErrNotFound - } - return c, fmt.Errorf("select connector: %v", err) + return c, err } return c, nil } @@ -805,12 +863,12 @@ func (c *conn) ListConnectors() ([]storage.Connector, error) { for rows.Next() { conn, err := scanConnector(rows) if err != nil { - return nil, err + return nil, fmt.Errorf("scan connector: %s", err) } connectors = append(connectors, conn) } if err := rows.Err(); err != nil { - return nil, err + return nil, fmt.Errorf("scan: %s", err) } return connectors, nil } diff --git a/storage/sql/sql.go b/storage/sql/sql.go index dc6be4a1f4..69b03cbd3b 100644 --- a/storage/sql/sql.go +++ b/storage/sql/sql.go @@ -2,14 +2,15 @@ package sql import ( + "context" "database/sql" "regexp" "time" + "github.com/lib/pq" "github.com/sirupsen/logrus" // import third party drivers - _ "github.com/lib/pq" _ "github.com/mattn/go-sqlite3" ) @@ -39,31 +40,66 @@ func matchLiteral(s string) *regexp.Regexp { return regexp.MustCompile(`\b` + regexp.QuoteMeta(s) + `\b`) } +// Detect a serialization failure, which should trigger retrying the +// transaction according to PostgreSQL docs: +// +// https://www.postgresql.org/docs/current/transaction-iso.html#XACT-SERIALIZABLE +// +// "applications using this level must be prepared to retry transactions due to +// serialization failures" +func isRetryableSerializationFailure(err error) bool { + if pqErr, ok := err.(*pq.Error); ok { + return pqErr.Code.Name() == "serialization_failure" + } + + return false +} + var ( // The "github.com/lib/pq" driver is the default flavor. All others are // translations of this. flavorPostgres = flavor{ - // The default behavior for Postgres transactions is consistent reads, not consistent writes. - // For each transaction opened, ensure it has the correct isolation level. + // The default behavior for Postgres transactions is consistent reads, not + // consistent writes. For each transaction opened, ensure it has the + // correct isolation level. // // See: https://www.postgresql.org/docs/9.3/static/sql-set-transaction.html // - // NOTE(ericchiang): For some reason using `SET SESSION CHARACTERISTICS AS TRANSACTION` at a - // session level didn't work for some edge cases. Might be something worth exploring. + // Be careful not to wrap sql errors in the callback 'fn', otherwise + // serialization failures will not be detected and retried. executeTx: func(db *sql.DB, fn func(sqlTx *sql.Tx) error) error { - tx, err := db.Begin() - if err != nil { - return err - } - defer tx.Rollback() + ctx, cancel := context.WithCancel(context.TODO()) + defer cancel() - if _, err := tx.Exec(`SET TRANSACTION ISOLATION LEVEL SERIALIZABLE;`); err != nil { - return err + opts := &sql.TxOptions{ + Isolation: sql.LevelSerializable, } - if err := fn(tx); err != nil { - return err + + for { + tx, err := db.BeginTx(ctx, opts) + if err != nil { + return err + } + + if err := fn(tx); err != nil { + if isRetryableSerializationFailure(err) { + continue + } + + return err + } + + err = tx.Commit() + if err != nil { + if isRetryableSerializationFailure(err) { + continue + } + + return err + } + + return nil } - return tx.Commit() }, supportsTimezones: true, diff --git a/vendor/github.com/lib/pq/array.go b/vendor/github.com/lib/pq/array.go index 27eb07a9e1..e4933e2276 100644 --- a/vendor/github.com/lib/pq/array.go +++ b/vendor/github.com/lib/pq/array.go @@ -13,7 +13,7 @@ import ( var typeByteSlice = reflect.TypeOf([]byte{}) var typeDriverValuer = reflect.TypeOf((*driver.Valuer)(nil)).Elem() -var typeSqlScanner = reflect.TypeOf((*sql.Scanner)(nil)).Elem() +var typeSQLScanner = reflect.TypeOf((*sql.Scanner)(nil)).Elem() // Array returns the optimal driver.Valuer and sql.Scanner for an array or // slice of any dimension. @@ -70,6 +70,9 @@ func (a *BoolArray) Scan(src interface{}) error { return a.scanBytes(src) case string: return a.scanBytes([]byte(src)) + case nil: + *a = nil + return nil } return fmt.Errorf("pq: cannot convert %T to BoolArray", src) @@ -80,7 +83,7 @@ func (a *BoolArray) scanBytes(src []byte) error { if err != nil { return err } - if len(elems) == 0 { + if *a != nil && len(elems) == 0 { *a = (*a)[:0] } else { b := make(BoolArray, len(elems)) @@ -141,6 +144,9 @@ func (a *ByteaArray) Scan(src interface{}) error { return a.scanBytes(src) case string: return a.scanBytes([]byte(src)) + case nil: + *a = nil + return nil } return fmt.Errorf("pq: cannot convert %T to ByteaArray", src) @@ -151,7 +157,7 @@ func (a *ByteaArray) scanBytes(src []byte) error { if err != nil { return err } - if len(elems) == 0 { + if *a != nil && len(elems) == 0 { *a = (*a)[:0] } else { b := make(ByteaArray, len(elems)) @@ -210,6 +216,9 @@ func (a *Float64Array) Scan(src interface{}) error { return a.scanBytes(src) case string: return a.scanBytes([]byte(src)) + case nil: + *a = nil + return nil } return fmt.Errorf("pq: cannot convert %T to Float64Array", src) @@ -220,7 +229,7 @@ func (a *Float64Array) scanBytes(src []byte) error { if err != nil { return err } - if len(elems) == 0 { + if *a != nil && len(elems) == 0 { *a = (*a)[:0] } else { b := make(Float64Array, len(elems)) @@ -269,7 +278,7 @@ func (GenericArray) evaluateDestination(rt reflect.Type) (reflect.Type, func([]b // TODO calculate the assign function for other types // TODO repeat this section on the element type of arrays or slices (multidimensional) { - if reflect.PtrTo(rt).Implements(typeSqlScanner) { + if reflect.PtrTo(rt).Implements(typeSQLScanner) { // dest is always addressable because it is an element of a slice. assign = func(src []byte, dest reflect.Value) (err error) { ss := dest.Addr().Interface().(sql.Scanner) @@ -320,6 +329,11 @@ func (a GenericArray) Scan(src interface{}) error { return a.scanBytes(src, dv) case string: return a.scanBytes([]byte(src), dv) + case nil: + if dv.Kind() == reflect.Slice { + dv.Set(reflect.Zero(dv.Type())) + return nil + } } return fmt.Errorf("pq: cannot convert %T to %s", src, dv.Type()) @@ -386,7 +400,13 @@ func (a GenericArray) Value() (driver.Value, error) { rv := reflect.ValueOf(a.A) - if k := rv.Kind(); k != reflect.Array && k != reflect.Slice { + switch rv.Kind() { + case reflect.Slice: + if rv.IsNil() { + return nil, nil + } + case reflect.Array: + default: return nil, fmt.Errorf("pq: Unable to convert %T to array", a.A) } @@ -412,6 +432,9 @@ func (a *Int64Array) Scan(src interface{}) error { return a.scanBytes(src) case string: return a.scanBytes([]byte(src)) + case nil: + *a = nil + return nil } return fmt.Errorf("pq: cannot convert %T to Int64Array", src) @@ -422,7 +445,7 @@ func (a *Int64Array) scanBytes(src []byte) error { if err != nil { return err } - if len(elems) == 0 { + if *a != nil && len(elems) == 0 { *a = (*a)[:0] } else { b := make(Int64Array, len(elems)) @@ -470,6 +493,9 @@ func (a *StringArray) Scan(src interface{}) error { return a.scanBytes(src) case string: return a.scanBytes([]byte(src)) + case nil: + *a = nil + return nil } return fmt.Errorf("pq: cannot convert %T to StringArray", src) @@ -480,7 +506,7 @@ func (a *StringArray) scanBytes(src []byte) error { if err != nil { return err } - if len(elems) == 0 { + if *a != nil && len(elems) == 0 { *a = (*a)[:0] } else { b := make(StringArray, len(elems)) @@ -561,7 +587,7 @@ func appendArrayElement(b []byte, rv reflect.Value) ([]byte, string, error) { } } - var del string = "," + var del = "," var err error var iv interface{} = rv.Interface() @@ -639,6 +665,9 @@ Element: for i < len(src) { switch src[i] { case '{': + if depth == len(dims) { + break Element + } depth++ dims[depth-1] = 0 i++ @@ -680,11 +709,11 @@ Element: } for i < len(src) { - if bytes.HasPrefix(src[i:], del) { + if bytes.HasPrefix(src[i:], del) && depth > 0 { dims[depth-1]++ i += len(del) goto Element - } else if src[i] == '}' { + } else if src[i] == '}' && depth > 0 { dims[depth-1]++ depth-- i++ diff --git a/vendor/github.com/lib/pq/conn.go b/vendor/github.com/lib/pq/conn.go index 8e1aee9f01..43c8df29f1 100644 --- a/vendor/github.com/lib/pq/conn.go +++ b/vendor/github.com/lib/pq/conn.go @@ -3,15 +3,12 @@ package pq import ( "bufio" "crypto/md5" - "crypto/tls" - "crypto/x509" "database/sql" "database/sql/driver" "encoding/binary" "errors" "fmt" "io" - "io/ioutil" "net" "os" "os/user" @@ -30,18 +27,26 @@ var ( ErrNotSupported = errors.New("pq: Unsupported command") ErrInFailedTransaction = errors.New("pq: Could not complete operation in a failed transaction") ErrSSLNotSupported = errors.New("pq: SSL is not enabled on the server") - ErrSSLKeyHasWorldPermissions = errors.New("pq: Private key file has group or world access. Permissions should be u=rw (0600) or less.") - ErrCouldNotDetectUsername = errors.New("pq: Could not detect default username. Please provide one explicitly.") + ErrSSLKeyHasWorldPermissions = errors.New("pq: Private key file has group or world access. Permissions should be u=rw (0600) or less") + ErrCouldNotDetectUsername = errors.New("pq: Could not detect default username. Please provide one explicitly") + + errUnexpectedReady = errors.New("unexpected ReadyForQuery") + errNoRowsAffected = errors.New("no RowsAffected available after the empty statement") + errNoLastInsertID = errors.New("no LastInsertId available after the empty statement") ) -type drv struct{} +// Driver is the Postgres database driver. +type Driver struct{} -func (d *drv) Open(name string) (driver.Conn, error) { +// Open opens a new connection to the database. name is a connection string. +// Most users should only use it through database/sql package from the standard +// library. +func (d *Driver) Open(name string) (driver.Conn, error) { return Open(name) } func init() { - sql.Register("postgres", &drv{}) + sql.Register("postgres", &Driver{}) } type parameterStatus struct { @@ -77,6 +82,8 @@ func (s transactionStatus) String() string { panic("not reached") } +// Dialer is the dialer interface. It can be used to obtain more control over +// how pq creates network connections. type Dialer interface { Dial(network, address string) (net.Conn, error) DialTimeout(network, address string, timeout time.Duration) (net.Conn, error) @@ -97,6 +104,15 @@ type conn struct { namei int scratch [512]byte txnStatus transactionStatus + txnFinish func() + + // Save connection arguments to use during CancelRequest. + dialer Dialer + opts values + + // Cancellation key data for use with CancelRequest messages. + processID int + secretKey int parameterStatus parameterStatus @@ -115,12 +131,15 @@ type conn struct { // Whether to always send []byte parameters over as binary. Enables single // round-trip mode for non-prepared Query calls. binaryParameters bool + + // If true this connection is in the middle of a COPY + inCopy bool } // Handle driver-side settings in parsed connection string. -func (c *conn) handleDriverSettings(o values) (err error) { +func (cn *conn) handleDriverSettings(o values) (err error) { boolSetting := func(key string, val *bool) error { - if value := o.Get(key); value != "" { + if value, ok := o[key]; ok { if value == "yes" { *val = true } else if value == "no" { @@ -132,32 +151,32 @@ func (c *conn) handleDriverSettings(o values) (err error) { return nil } - err = boolSetting("disable_prepared_binary_result", &c.disablePreparedBinaryResult) - if err != nil { - return err - } - err = boolSetting("binary_parameters", &c.binaryParameters) + err = boolSetting("disable_prepared_binary_result", &cn.disablePreparedBinaryResult) if err != nil { return err } - return nil + return boolSetting("binary_parameters", &cn.binaryParameters) } -func (c *conn) handlePgpass(o values) { +func (cn *conn) handlePgpass(o values) { // if a password was supplied, do not process .pgpass - _, ok := o["password"] - if ok { + if _, ok := o["password"]; ok { return } filename := os.Getenv("PGPASSFILE") if filename == "" { // XXX this code doesn't work on Windows where the default filename is // XXX %APPDATA%\postgresql\pgpass.conf - user, err := user.Current() - if err != nil { - return + // Prefer $HOME over user.Current due to glibc bug: golang.org/issue/13470 + userHome := os.Getenv("HOME") + if userHome == "" { + user, err := user.Current() + if err != nil { + return + } + userHome = user.HomeDir } - filename = filepath.Join(user.HomeDir, ".pgpass") + filename = filepath.Join(userHome, ".pgpass") } fileinfo, err := os.Stat(filename) if err != nil { @@ -174,11 +193,11 @@ func (c *conn) handlePgpass(o values) { } defer file.Close() scanner := bufio.NewScanner(io.Reader(file)) - hostname := o.Get("host") + hostname := o["host"] ntw, _ := network(o) - port := o.Get("port") - db := o.Get("dbname") - username := o.Get("user") + port := o["port"] + db := o["dbname"] + username := o["user"] // From: https://github.com/tg/pgpass/blob/master/reader.go getFields := func(s string) []string { fs := make([]string, 0, 5) @@ -217,18 +236,22 @@ func (c *conn) handlePgpass(o values) { } } -func (c *conn) writeBuf(b byte) *writeBuf { - c.scratch[0] = b +func (cn *conn) writeBuf(b byte) *writeBuf { + cn.scratch[0] = b return &writeBuf{ - buf: c.scratch[:5], + buf: cn.scratch[:5], pos: 1, } } +// Open opens a new connection to the database. name is a connection string. +// Most users should only use it through database/sql package from the standard +// library. func Open(name string) (_ driver.Conn, err error) { return DialOpen(defaultDialer{}, name) } +// DialOpen opens a new connection to the database using a dialer. func DialOpen(d Dialer, name string) (_ driver.Conn, err error) { // Handle any panics during connection initialization. Note that we // specifically do *not* want to use errRecover(), as that would turn any @@ -243,13 +266,13 @@ func DialOpen(d Dialer, name string) (_ driver.Conn, err error) { // * Very low precedence defaults applied in every situation // * Environment variables // * Explicitly passed connection information - o.Set("host", "localhost") - o.Set("port", "5432") + o["host"] = "localhost" + o["port"] = "5432" // N.B.: Extra float digits should be set to 3, but that breaks // Postgres 8.4 and older, where the max is 2. - o.Set("extra_float_digits", "2") + o["extra_float_digits"] = "2" for k, v := range parseEnviron(os.Environ()) { - o.Set(k, v) + o[k] = v } if strings.HasPrefix(name, "postgres://") || strings.HasPrefix(name, "postgresql://") { @@ -264,9 +287,9 @@ func DialOpen(d Dialer, name string) (_ driver.Conn, err error) { } // Use the "fallback" application name if necessary - if fallback := o.Get("fallback_application_name"); fallback != "" { - if !o.Isset("application_name") { - o.Set("application_name", fallback) + if fallback, ok := o["fallback_application_name"]; ok { + if _, ok := o["application_name"]; !ok { + o["application_name"] = fallback } } @@ -277,33 +300,35 @@ func DialOpen(d Dialer, name string) (_ driver.Conn, err error) { // parsing its value is not worth it. Instead, we always explicitly send // client_encoding as a separate run-time parameter, which should override // anything set in options. - if enc := o.Get("client_encoding"); enc != "" && !isUTF8(enc) { + if enc, ok := o["client_encoding"]; ok && !isUTF8(enc) { return nil, errors.New("client_encoding must be absent or 'UTF8'") } - o.Set("client_encoding", "UTF8") + o["client_encoding"] = "UTF8" // DateStyle needs a similar treatment. - if datestyle := o.Get("datestyle"); datestyle != "" { + if datestyle, ok := o["datestyle"]; ok { if datestyle != "ISO, MDY" { panic(fmt.Sprintf("setting datestyle must be absent or %v; got %v", "ISO, MDY", datestyle)) } } else { - o.Set("datestyle", "ISO, MDY") + o["datestyle"] = "ISO, MDY" } // If a user is not provided by any other means, the last // resort is to use the current operating system provided user // name. - if o.Get("user") == "" { + if _, ok := o["user"]; !ok { u, err := userCurrent() if err != nil { return nil, err - } else { - o.Set("user", u) } + o["user"] = u } - cn := &conn{} + cn := &conn{ + opts: o, + dialer: d, + } err = cn.handleDriverSettings(o) if err != nil { return nil, err @@ -314,14 +339,28 @@ func DialOpen(d Dialer, name string) (_ driver.Conn, err error) { if err != nil { return nil, err } - cn.ssl(o) + + err = cn.ssl(o) + if err != nil { + return nil, err + } + + // cn.startup panics on error. Make sure we don't leak cn.c. + panicking := true + defer func() { + if panicking { + cn.c.Close() + } + }() + cn.buf = bufio.NewReader(cn.c) cn.startup(o) // reset the deadline, in case one was set (see dial) - if timeout := o.Get("connect_timeout"); timeout != "" && timeout != "0" { + if timeout, ok := o["connect_timeout"]; ok && timeout != "0" { err = cn.c.SetDeadline(time.Time{}) } + panicking = false return cn, err } @@ -333,7 +372,7 @@ func dial(d Dialer, o values) (net.Conn, error) { } // Zero or not specified means wait indefinitely. - if timeout := o.Get("connect_timeout"); timeout != "" && timeout != "0" { + if timeout, ok := o["connect_timeout"]; ok && timeout != "0" { seconds, err := strconv.ParseInt(timeout, 10, 0) if err != nil { return nil, fmt.Errorf("invalid value for parameter connect_timeout: %s", err) @@ -355,31 +394,18 @@ func dial(d Dialer, o values) (net.Conn, error) { } func network(o values) (string, string) { - host := o.Get("host") + host := o["host"] if strings.HasPrefix(host, "/") { - sockPath := path.Join(host, ".s.PGSQL."+o.Get("port")) + sockPath := path.Join(host, ".s.PGSQL."+o["port"]) return "unix", sockPath } - return "tcp", net.JoinHostPort(host, o.Get("port")) + return "tcp", net.JoinHostPort(host, o["port"]) } type values map[string]string -func (vs values) Set(k, v string) { - vs[k] = v -} - -func (vs values) Get(k string) (v string) { - return vs[k] -} - -func (vs values) Isset(k string) bool { - _, ok := vs[k] - return ok -} - // scanner implements a tokenizer for libpq-style option strings. type scanner struct { s []rune @@ -450,7 +476,7 @@ func parseOpts(name string, o values) error { // Skip any whitespace after the = if r, ok = s.SkipSpaces(); !ok { // If we reach the end here, the last value is just an empty string as per libpq. - o.Set(string(keyRunes), "") + o[string(keyRunes)] = "" break } @@ -485,7 +511,7 @@ func parseOpts(name string, o values) error { } } - o.Set(string(keyRunes), string(valRunes)) + o[string(keyRunes)] = string(valRunes) } return nil @@ -504,13 +530,17 @@ func (cn *conn) checkIsInTransaction(intxn bool) { } func (cn *conn) Begin() (_ driver.Tx, err error) { + return cn.begin("") +} + +func (cn *conn) begin(mode string) (_ driver.Tx, err error) { if cn.bad { return nil, driver.ErrBadConn } defer cn.errRecover(&err) cn.checkIsInTransaction(false) - _, commandTag, err := cn.simpleExec("BEGIN") + _, commandTag, err := cn.simpleExec("BEGIN" + mode) if err != nil { return nil, err } @@ -525,7 +555,14 @@ func (cn *conn) Begin() (_ driver.Tx, err error) { return cn, nil } +func (cn *conn) closeTxn() { + if finish := cn.txnFinish; finish != nil { + finish() + } +} + func (cn *conn) Commit() (err error) { + defer cn.closeTxn() if cn.bad { return driver.ErrBadConn } @@ -561,6 +598,7 @@ func (cn *conn) Commit() (err error) { } func (cn *conn) Rollback() (err error) { + defer cn.closeTxn() if cn.bad { return driver.ErrBadConn } @@ -598,11 +636,16 @@ func (cn *conn) simpleExec(q string) (res driver.Result, commandTag string, err res, commandTag = cn.parseComplete(r.string()) case 'Z': cn.processReadyForQuery(r) + if res == nil && err == nil { + err = errUnexpectedReady + } // done return case 'E': err = parseError(r) - case 'T', 'D', 'I': + case 'I': + res = emptyRows + case 'T', 'D': // ignore any results default: cn.bad = true @@ -635,6 +678,12 @@ func (cn *conn) simpleQuery(q string) (res *rows, err error) { cn: cn, } } + // Set the result and tag to the last command complete if there wasn't a + // query already run. Although queries usually return from here and cede + // control to Next, a query with zero results does not. + if t == 'C' && res.colNames == nil { + res.result, res.tag = cn.parseComplete(r.string()) + } res.done = true case 'Z': cn.processReadyForQuery(r) @@ -666,9 +715,23 @@ func (cn *conn) simpleQuery(q string) (res *rows, err error) { } } +type noRows struct{} + +var emptyRows noRows + +var _ driver.Result = noRows{} + +func (noRows) LastInsertId() (int64, error) { + return 0, errNoLastInsertID +} + +func (noRows) RowsAffected() (int64, error) { + return 0, errNoRowsAffected +} + // Decides which column formats to use for a prepared statement. The input is // an array of type oids, one element per result column. -func decideColumnFormats(colTyps []oid.Oid, forceText bool) (colFmts []format, colFmtData []byte) { +func decideColumnFormats(colTyps []fieldDesc, forceText bool) (colFmts []format, colFmtData []byte) { if len(colTyps) == 0 { return nil, colFmtDataAllText } @@ -680,8 +743,8 @@ func decideColumnFormats(colTyps []oid.Oid, forceText bool) (colFmts []format, c allBinary := true allText := true - for i, o := range colTyps { - switch o { + for i, t := range colTyps { + switch t.OID { // This is the list of types to use binary mode for when receiving them // through a prepared statement. If a type appears in this list, it // must also be implemented in binaryDecode in encode.go. @@ -692,6 +755,8 @@ func decideColumnFormats(colTyps []oid.Oid, forceText bool) (colFmts []format, c case oid.T_int4: fallthrough case oid.T_int2: + fallthrough + case oid.T_uuid: colFmts[i] = formatBinary allText = false @@ -743,32 +808,45 @@ func (cn *conn) Prepare(q string) (_ driver.Stmt, err error) { defer cn.errRecover(&err) if len(q) >= 4 && strings.EqualFold(q[:4], "COPY") { - return cn.prepareCopyIn(q) + s, err := cn.prepareCopyIn(q) + if err == nil { + cn.inCopy = true + } + return s, err } return cn.prepareTo(q, cn.gname()), nil } func (cn *conn) Close() (err error) { - if cn.bad { - return driver.ErrBadConn - } + // Skip cn.bad return here because we always want to close a connection. defer cn.errRecover(&err) + // Ensure that cn.c.Close is always run. Since error handling is done with + // panics and cn.errRecover, the Close must be in a defer. + defer func() { + cerr := cn.c.Close() + if err == nil { + err = cerr + } + }() + // Don't go through send(); ListenerConn relies on us not scribbling on the // scratch buffer of this connection. - err = cn.sendSimpleMessage('X') - if err != nil { - return err - } - - return cn.c.Close() + return cn.sendSimpleMessage('X') } // Implement the "Queryer" interface -func (cn *conn) Query(query string, args []driver.Value) (_ driver.Rows, err error) { +func (cn *conn) Query(query string, args []driver.Value) (driver.Rows, error) { + return cn.query(query, args) +} + +func (cn *conn) query(query string, args []driver.Value) (_ *rows, err error) { if cn.bad { return nil, driver.ErrBadConn } + if cn.inCopy { + return nil, errCopyInProgress + } defer cn.errRecover(&err) // Check to see if we can use the "simpleQuery" interface, which is @@ -786,16 +864,15 @@ func (cn *conn) Query(query string, args []driver.Value) (_ driver.Rows, err err rows.colNames, rows.colFmts, rows.colTyps = cn.readPortalDescribeResponse() cn.postExecuteWorkaround() return rows, nil - } else { - st := cn.prepareTo(query, "") - st.exec(args) - return &rows{ - cn: cn, - colNames: st.colNames, - colTyps: st.colTyps, - colFmts: st.colFmts, - }, nil } + st := cn.prepareTo(query, "") + st.exec(args) + return &rows{ + cn: cn, + colNames: st.colNames, + colTyps: st.colTyps, + colFmts: st.colFmts, + }, nil } // Implement the optional "Execer" interface for one-shot queries @@ -822,17 +899,16 @@ func (cn *conn) Exec(query string, args []driver.Value) (res driver.Result, err cn.postExecuteWorkaround() res, _, err = cn.readExecuteResponse("Execute") return res, err - } else { - // Use the unnamed statement to defer planning until bind - // time, or else value-based selectivity estimates cannot be - // used. - st := cn.prepareTo(query, "") - r, err := st.Exec(args) - if err != nil { - panic(err) - } - return r, err } + // Use the unnamed statement to defer planning until bind + // time, or else value-based selectivity estimates cannot be + // used. + st := cn.prepareTo(query, "") + r, err := st.Exec(args) + if err != nil { + panic(err) + } + return r, err } func (cn *conn) send(m *writeBuf) { @@ -842,16 +918,9 @@ func (cn *conn) send(m *writeBuf) { } } -func (cn *conn) sendStartupPacket(m *writeBuf) { - // sanity check - if m.buf[0] != 0 { - panic("oops") - } - +func (cn *conn) sendStartupPacket(m *writeBuf) error { _, err := cn.c.Write((m.wrap())[1:]) - if err != nil { - panic(err) - } + return err } // Send a message of type typ to the server on the other end of cn. The @@ -964,165 +1033,35 @@ func (cn *conn) recv1() (t byte, r *readBuf) { return t, r } -func (cn *conn) ssl(o values) { - verifyCaOnly := false - tlsConf := tls.Config{} - switch mode := o.Get("sslmode"); mode { - // "require" is the default. - case "", "require": - // We must skip TLS's own verification since it requires full - // verification since Go 1.3. - tlsConf.InsecureSkipVerify = true - - // From http://www.postgresql.org/docs/current/static/libpq-ssl.html: - // Note: For backwards compatibility with earlier versions of PostgreSQL, if a - // root CA file exists, the behavior of sslmode=require will be the same as - // that of verify-ca, meaning the server certificate is validated against the - // CA. Relying on this behavior is discouraged, and applications that need - // certificate validation should always use verify-ca or verify-full. - if _, err := os.Stat(o.Get("sslrootcert")); err == nil { - verifyCaOnly = true - } else { - o.Set("sslrootcert", "") - } - case "verify-ca": - // We must skip TLS's own verification since it requires full - // verification since Go 1.3. - tlsConf.InsecureSkipVerify = true - verifyCaOnly = true - case "verify-full": - tlsConf.ServerName = o.Get("host") - case "disable": - return - default: - errorf(`unsupported sslmode %q; only "require" (default), "verify-full", "verify-ca", and "disable" supported`, mode) +func (cn *conn) ssl(o values) error { + upgrade, err := ssl(o) + if err != nil { + return err } - cn.setupSSLClientCertificates(&tlsConf, o) - cn.setupSSLCA(&tlsConf, o) + if upgrade == nil { + // Nothing to do + return nil + } w := cn.writeBuf(0) w.int32(80877103) - cn.sendStartupPacket(w) + if err = cn.sendStartupPacket(w); err != nil { + return err + } b := cn.scratch[:1] - _, err := io.ReadFull(cn.c, b) + _, err = io.ReadFull(cn.c, b) if err != nil { - panic(err) + return err } if b[0] != 'S' { - panic(ErrSSLNotSupported) - } - - client := tls.Client(cn.c, &tlsConf) - if verifyCaOnly { - cn.verifyCA(client, &tlsConf) - } - cn.c = client -} - -// verifyCA carries out a TLS handshake to the server and verifies the -// presented certificate against the effective CA, i.e. the one specified in -// sslrootcert or the system CA if sslrootcert was not specified. -func (cn *conn) verifyCA(client *tls.Conn, tlsConf *tls.Config) { - err := client.Handshake() - if err != nil { - panic(err) - } - certs := client.ConnectionState().PeerCertificates - opts := x509.VerifyOptions{ - DNSName: client.ConnectionState().ServerName, - Intermediates: x509.NewCertPool(), - Roots: tlsConf.RootCAs, - } - for i, cert := range certs { - if i == 0 { - continue - } - opts.Intermediates.AddCert(cert) - } - _, err = certs[0].Verify(opts) - if err != nil { - panic(err) - } -} - -// This function sets up SSL client certificates based on either the "sslkey" -// and "sslcert" settings (possibly set via the environment variables PGSSLKEY -// and PGSSLCERT, respectively), or if they aren't set, from the .postgresql -// directory in the user's home directory. If the file paths are set -// explicitly, the files must exist. The key file must also not be -// world-readable, or this function will panic with -// ErrSSLKeyHasWorldPermissions. -func (cn *conn) setupSSLClientCertificates(tlsConf *tls.Config, o values) { - var missingOk bool - - sslkey := o.Get("sslkey") - sslcert := o.Get("sslcert") - if sslkey != "" && sslcert != "" { - // If the user has set an sslkey and sslcert, they *must* exist. - missingOk = false - } else { - // Automatically load certificates from ~/.postgresql. - user, err := user.Current() - if err != nil { - // user.Current() might fail when cross-compiling. We have to - // ignore the error and continue without client certificates, since - // we wouldn't know where to load them from. - return - } - - sslkey = filepath.Join(user.HomeDir, ".postgresql", "postgresql.key") - sslcert = filepath.Join(user.HomeDir, ".postgresql", "postgresql.crt") - missingOk = true + return ErrSSLNotSupported } - // Check that both files exist, and report the error or stop, depending on - // which behaviour we want. Note that we don't do any more extensive - // checks than this (such as checking that the paths aren't directories); - // LoadX509KeyPair() will take care of the rest. - keyfinfo, err := os.Stat(sslkey) - if err != nil && missingOk { - return - } else if err != nil { - panic(err) - } - _, err = os.Stat(sslcert) - if err != nil && missingOk { - return - } else if err != nil { - panic(err) - } - - // If we got this far, the key file must also have the correct permissions - kmode := keyfinfo.Mode() - if kmode != kmode&0600 { - panic(ErrSSLKeyHasWorldPermissions) - } - - cert, err := tls.LoadX509KeyPair(sslcert, sslkey) - if err != nil { - panic(err) - } - tlsConf.Certificates = []tls.Certificate{cert} -} - -// Sets up RootCAs in the TLS configuration if sslrootcert is set. -func (cn *conn) setupSSLCA(tlsConf *tls.Config, o values) { - if sslrootcert := o.Get("sslrootcert"); sslrootcert != "" { - tlsConf.RootCAs = x509.NewCertPool() - - cert, err := ioutil.ReadFile(sslrootcert) - if err != nil { - panic(err) - } - - ok := tlsConf.RootCAs.AppendCertsFromPEM(cert) - if !ok { - errorf("couldn't parse pem in sslrootcert") - } - } + cn.c, err = upgrade(cn.c) + return err } // isDriverSetting returns true iff a setting is purely for configuring the @@ -1171,12 +1110,15 @@ func (cn *conn) startup(o values) { w.string(v) } w.string("") - cn.sendStartupPacket(w) + if err := cn.sendStartupPacket(w); err != nil { + panic(err) + } for { t, r := cn.recv() switch t { case 'K': + cn.processBackendKeyData(r) case 'S': cn.processParameterStatus(r) case 'R': @@ -1196,7 +1138,7 @@ func (cn *conn) auth(r *readBuf, o values) { // OK case 3: w := cn.writeBuf('p') - w.string(o.Get("password")) + w.string(o["password"]) cn.send(w) t, r := cn.recv() @@ -1210,7 +1152,7 @@ func (cn *conn) auth(r *readBuf, o values) { case 5: s := string(r.next(4)) w := cn.writeBuf('p') - w.string("md5" + md5s(md5s(o.Get("password")+o.Get("user"))+s)) + w.string("md5" + md5s(md5s(o["password"]+o["user"])+s)) cn.send(w) t, r := cn.recv() @@ -1232,10 +1174,10 @@ const formatText format = 0 const formatBinary format = 1 // One result-column format code with the value 1 (i.e. all binary). -var colFmtDataAllBinary []byte = []byte{0, 1, 0, 1} +var colFmtDataAllBinary = []byte{0, 1, 0, 1} // No result-column format codes (i.e. all text). -var colFmtDataAllText []byte = []byte{0, 0} +var colFmtDataAllText = []byte{0, 0} type stmt struct { cn *conn @@ -1243,7 +1185,7 @@ type stmt struct { colNames []string colFmts []format colFmtData []byte - colTyps []oid.Oid + colTyps []fieldDesc paramTyps []oid.Oid closed bool } @@ -1404,21 +1346,32 @@ func (cn *conn) parseComplete(commandTag string) (driver.Result, string) { type rows struct { cn *conn + finish func() colNames []string - colTyps []oid.Oid + colTyps []fieldDesc colFmts []format done bool rb readBuf + result driver.Result + tag string } func (rs *rows) Close() error { + if finish := rs.finish; finish != nil { + defer finish() + } // no need to look at cn.bad as Next() will for { err := rs.Next(nil) switch err { case nil: case io.EOF: - return nil + // rs.Next can return io.EOF on both 'Z' (ready for query) and 'T' (row + // description, used with HasNextResultSet). We need to fetch messages until + // we hit a 'Z', which is done by waiting for done to be set. + if rs.done { + return nil + } default: return err } @@ -1429,6 +1382,17 @@ func (rs *rows) Columns() []string { return rs.colNames } +func (rs *rows) Result() driver.Result { + if rs.result == nil { + return emptyRows + } + return rs.result +} + +func (rs *rows) Tag() string { + return rs.tag +} + func (rs *rows) Next(dest []driver.Value) (err error) { if rs.done { return io.EOF @@ -1446,6 +1410,9 @@ func (rs *rows) Next(dest []driver.Value) (err error) { case 'E': err = parseError(&rs.rb) case 'C', 'I': + if t == 'C' { + rs.result, rs.tag = conn.parseComplete(rs.rb.string()) + } continue case 'Z': conn.processReadyForQuery(&rs.rb) @@ -1469,21 +1436,33 @@ func (rs *rows) Next(dest []driver.Value) (err error) { dest[i] = nil continue } - dest[i] = decode(&conn.parameterStatus, rs.rb.next(l), rs.colTyps[i], rs.colFmts[i]) + dest[i] = decode(&conn.parameterStatus, rs.rb.next(l), rs.colTyps[i].OID, rs.colFmts[i]) } return + case 'T': + rs.colNames, rs.colFmts, rs.colTyps = parsePortalRowDescribe(&rs.rb) + return io.EOF default: errorf("unexpected message after execute: %q", t) } } } +func (rs *rows) HasNextResultSet() bool { + return !rs.done +} + +func (rs *rows) NextResultSet() error { + return nil +} + // QuoteIdentifier quotes an "identifier" (e.g. a table or a column name) to be // used as part of an SQL statement. For example: // // tblname := "my_table" // data := "my_data" -// err = db.Exec(fmt.Sprintf("INSERT INTO %s VALUES ($1)", pq.QuoteIdentifier(tblname)), data) +// quoted := pq.QuoteIdentifier(tblname) +// err := db.Exec(fmt.Sprintf("INSERT INTO %s VALUES ($1)", quoted), data) // // Any double quotes in name will be escaped. The quoted identifier will be // case sensitive when used in a query. If the input string contains a zero @@ -1564,7 +1543,7 @@ func (cn *conn) sendBinaryModeQuery(query string, args []driver.Value) { cn.send(b) } -func (c *conn) processParameterStatus(r *readBuf) { +func (cn *conn) processParameterStatus(r *readBuf) { var err error param := r.string() @@ -1575,13 +1554,13 @@ func (c *conn) processParameterStatus(r *readBuf) { var minor int _, err = fmt.Sscanf(r.string(), "%d.%d.%d", &major1, &major2, &minor) if err == nil { - c.parameterStatus.serverVersion = major1*10000 + major2*100 + minor + cn.parameterStatus.serverVersion = major1*10000 + major2*100 + minor } case "TimeZone": - c.parameterStatus.currentLocation, err = time.LoadLocation(r.string()) + cn.parameterStatus.currentLocation, err = time.LoadLocation(r.string()) if err != nil { - c.parameterStatus.currentLocation = nil + cn.parameterStatus.currentLocation = nil } default: @@ -1589,8 +1568,8 @@ func (c *conn) processParameterStatus(r *readBuf) { } } -func (c *conn) processReadyForQuery(r *readBuf) { - c.txnStatus = transactionStatus(r.byte()) +func (cn *conn) processReadyForQuery(r *readBuf) { + cn.txnStatus = transactionStatus(r.byte()) } func (cn *conn) readReadyForQuery() { @@ -1605,6 +1584,11 @@ func (cn *conn) readReadyForQuery() { } } +func (cn *conn) processBackendKeyData(r *readBuf) { + cn.processID = r.int32() + cn.secretKey = r.int32() +} + func (cn *conn) readParseResponse() { t, r := cn.recv1() switch t { @@ -1620,7 +1604,7 @@ func (cn *conn) readParseResponse() { } } -func (cn *conn) readStatementDescribeResponse() (paramTyps []oid.Oid, colNames []string, colTyps []oid.Oid) { +func (cn *conn) readStatementDescribeResponse() (paramTyps []oid.Oid, colNames []string, colTyps []fieldDesc) { for { t, r := cn.recv1() switch t { @@ -1646,7 +1630,7 @@ func (cn *conn) readStatementDescribeResponse() (paramTyps []oid.Oid, colNames [ } } -func (cn *conn) readPortalDescribeResponse() (colNames []string, colFmts []format, colTyps []oid.Oid) { +func (cn *conn) readPortalDescribeResponse() (colNames []string, colFmts []format, colTyps []fieldDesc) { t, r := cn.recv1() switch t { case 'T': @@ -1720,6 +1704,9 @@ func (cn *conn) readExecuteResponse(protocolState string) (res driver.Result, co res, commandTag = cn.parseComplete(r.string()) case 'Z': cn.processReadyForQuery(r) + if res == nil && err == nil { + err = errUnexpectedReady + } return res, commandTag, err case 'E': err = parseError(r) @@ -1728,6 +1715,9 @@ func (cn *conn) readExecuteResponse(protocolState string) (res driver.Result, co cn.bad = true errorf("unexpected %q after error %s", t, err) } + if t == 'I' { + res = emptyRows + } // ignore any results default: cn.bad = true @@ -1736,31 +1726,33 @@ func (cn *conn) readExecuteResponse(protocolState string) (res driver.Result, co } } -func parseStatementRowDescribe(r *readBuf) (colNames []string, colTyps []oid.Oid) { +func parseStatementRowDescribe(r *readBuf) (colNames []string, colTyps []fieldDesc) { n := r.int16() colNames = make([]string, n) - colTyps = make([]oid.Oid, n) + colTyps = make([]fieldDesc, n) for i := range colNames { colNames[i] = r.string() r.next(6) - colTyps[i] = r.oid() - r.next(6) + colTyps[i].OID = r.oid() + colTyps[i].Len = r.int16() + colTyps[i].Mod = r.int32() // format code not known when describing a statement; always 0 r.next(2) } return } -func parsePortalRowDescribe(r *readBuf) (colNames []string, colFmts []format, colTyps []oid.Oid) { +func parsePortalRowDescribe(r *readBuf) (colNames []string, colFmts []format, colTyps []fieldDesc) { n := r.int16() colNames = make([]string, n) colFmts = make([]format, n) - colTyps = make([]oid.Oid, n) + colTyps = make([]fieldDesc, n) for i := range colNames { colNames[i] = r.string() r.next(6) - colTyps[i] = r.oid() - r.next(6) + colTyps[i].OID = r.oid() + colTyps[i].Len = r.int16() + colTyps[i].Mod = r.int32() colFmts[i] = format(r.int16()) } return diff --git a/vendor/github.com/lib/pq/conn_go18.go b/vendor/github.com/lib/pq/conn_go18.go new file mode 100644 index 0000000000..81c9ee4758 --- /dev/null +++ b/vendor/github.com/lib/pq/conn_go18.go @@ -0,0 +1,129 @@ +package pq + +import ( + "context" + "database/sql" + "database/sql/driver" + "fmt" + "io" + "io/ioutil" +) + +// Implement the "QueryerContext" interface +func (cn *conn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) { + list := make([]driver.Value, len(args)) + for i, nv := range args { + list[i] = nv.Value + } + finish := cn.watchCancel(ctx) + r, err := cn.query(query, list) + if err != nil { + if finish != nil { + finish() + } + return nil, err + } + r.finish = finish + return r, nil +} + +// Implement the "ExecerContext" interface +func (cn *conn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) { + list := make([]driver.Value, len(args)) + for i, nv := range args { + list[i] = nv.Value + } + + if finish := cn.watchCancel(ctx); finish != nil { + defer finish() + } + + return cn.Exec(query, list) +} + +// Implement the "ConnBeginTx" interface +func (cn *conn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) { + var mode string + + switch sql.IsolationLevel(opts.Isolation) { + case sql.LevelDefault: + // Don't touch mode: use the server's default + case sql.LevelReadUncommitted: + mode = " ISOLATION LEVEL READ UNCOMMITTED" + case sql.LevelReadCommitted: + mode = " ISOLATION LEVEL READ COMMITTED" + case sql.LevelRepeatableRead: + mode = " ISOLATION LEVEL REPEATABLE READ" + case sql.LevelSerializable: + mode = " ISOLATION LEVEL SERIALIZABLE" + default: + return nil, fmt.Errorf("pq: isolation level not supported: %d", opts.Isolation) + } + + if opts.ReadOnly { + mode += " READ ONLY" + } else { + mode += " READ WRITE" + } + + tx, err := cn.begin(mode) + if err != nil { + return nil, err + } + cn.txnFinish = cn.watchCancel(ctx) + return tx, nil +} + +func (cn *conn) watchCancel(ctx context.Context) func() { + if done := ctx.Done(); done != nil { + finished := make(chan struct{}) + go func() { + select { + case <-done: + _ = cn.cancel() + finished <- struct{}{} + case <-finished: + } + }() + return func() { + select { + case <-finished: + case finished <- struct{}{}: + } + } + } + return nil +} + +func (cn *conn) cancel() error { + c, err := dial(cn.dialer, cn.opts) + if err != nil { + return err + } + defer c.Close() + + { + can := conn{ + c: c, + } + err = can.ssl(cn.opts) + if err != nil { + return err + } + + w := can.writeBuf(0) + w.int32(80877102) // cancel request code + w.int32(cn.processID) + w.int32(cn.secretKey) + + if err := can.sendStartupPacket(w); err != nil { + return err + } + } + + // Read until EOF to ensure that the server received the cancel. + { + _, err := io.Copy(ioutil.Discard, c) + return err + } +} diff --git a/vendor/github.com/lib/pq/connector.go b/vendor/github.com/lib/pq/connector.go new file mode 100644 index 0000000000..9e66eb5df8 --- /dev/null +++ b/vendor/github.com/lib/pq/connector.go @@ -0,0 +1,43 @@ +// +build go1.10 + +package pq + +import ( + "context" + "database/sql/driver" +) + +// Connector represents a fixed configuration for the pq driver with a given +// name. Connector satisfies the database/sql/driver Connector interface and +// can be used to create any number of DB Conn's via the database/sql OpenDB +// function. +// +// See https://golang.org/pkg/database/sql/driver/#Connector. +// See https://golang.org/pkg/database/sql/#OpenDB. +type connector struct { + name string +} + +// Connect returns a connection to the database using the fixed configuration +// of this Connector. Context is not used. +func (c *connector) Connect(_ context.Context) (driver.Conn, error) { + return (&Driver{}).Open(c.name) +} + +// Driver returnst the underlying driver of this Connector. +func (c *connector) Driver() driver.Driver { + return &Driver{} +} + +var _ driver.Connector = &connector{} + +// NewConnector returns a connector for the pq driver in a fixed configuration +// with the given name. The returned connector can be used to create any number +// of equivalent Conn's. The returned connector is intended to be used with +// database/sql.OpenDB. +// +// See https://golang.org/pkg/database/sql/driver/#Connector. +// See https://golang.org/pkg/database/sql/#OpenDB. +func NewConnector(name string) (driver.Connector, error) { + return &connector{name: name}, nil +} diff --git a/vendor/github.com/lib/pq/copy.go b/vendor/github.com/lib/pq/copy.go index 101f111330..345c2398f6 100644 --- a/vendor/github.com/lib/pq/copy.go +++ b/vendor/github.com/lib/pq/copy.go @@ -13,6 +13,7 @@ var ( errBinaryCopyNotSupported = errors.New("pq: only text format supported for COPY") errCopyToNotSupported = errors.New("pq: COPY TO is not supported") errCopyNotSupportedOutsideTxn = errors.New("pq: COPY is only allowed inside a transaction") + errCopyInProgress = errors.New("pq: COPY in progress") ) // CopyIn creates a COPY FROM statement which can be prepared with @@ -96,13 +97,13 @@ awaitCopyInResponse: err = parseError(r) case 'Z': if err == nil { - cn.bad = true + ci.setBad() errorf("unexpected ReadyForQuery in response to COPY") } cn.processReadyForQuery(r) return nil, err default: - cn.bad = true + ci.setBad() errorf("unknown response for copy query: %q", t) } } @@ -121,7 +122,7 @@ awaitCopyInResponse: cn.processReadyForQuery(r) return nil, err default: - cn.bad = true + ci.setBad() errorf("unknown response for CopyFail: %q", t) } } @@ -142,7 +143,7 @@ func (ci *copyin) resploop() { var r readBuf t, err := ci.cn.recvMessage(&r) if err != nil { - ci.cn.bad = true + ci.setBad() ci.setError(err) ci.done <- true return @@ -160,7 +161,7 @@ func (ci *copyin) resploop() { err := parseError(&r) ci.setError(err) default: - ci.cn.bad = true + ci.setBad() ci.setError(fmt.Errorf("unknown response during CopyIn: %q", t)) ci.done <- true return @@ -168,6 +169,19 @@ func (ci *copyin) resploop() { } } +func (ci *copyin) setBad() { + ci.Lock() + ci.cn.bad = true + ci.Unlock() +} + +func (ci *copyin) isBad() bool { + ci.Lock() + b := ci.cn.bad + ci.Unlock() + return b +} + func (ci *copyin) isErrorSet() bool { ci.Lock() isSet := (ci.err != nil) @@ -205,7 +219,7 @@ func (ci *copyin) Exec(v []driver.Value) (r driver.Result, err error) { return nil, errCopyInClosed } - if ci.cn.bad { + if ci.isBad() { return nil, driver.ErrBadConn } defer ci.cn.errRecover(&err) @@ -243,7 +257,7 @@ func (ci *copyin) Close() (err error) { } ci.closed = true - if ci.cn.bad { + if ci.isBad() { return driver.ErrBadConn } defer ci.cn.errRecover(&err) @@ -258,6 +272,7 @@ func (ci *copyin) Close() (err error) { } <-ci.done + ci.cn.inCopy = false if ci.isErrorSet() { err = ci.err diff --git a/vendor/github.com/lib/pq/doc.go b/vendor/github.com/lib/pq/doc.go index 19798dfc92..2a60054e2e 100644 --- a/vendor/github.com/lib/pq/doc.go +++ b/vendor/github.com/lib/pq/doc.go @@ -11,7 +11,8 @@ using this package directly. For example: ) func main() { - db, err := sql.Open("postgres", "user=pqgotest dbname=pqgotest sslmode=verify-full") + connStr := "user=pqgotest dbname=pqgotest sslmode=verify-full" + db, err := sql.Open("postgres", connStr) if err != nil { log.Fatal(err) } @@ -23,7 +24,8 @@ using this package directly. For example: You can also connect to a database using a URL. For example: - db, err := sql.Open("postgres", "postgres://pqgotest:password@localhost/pqgotest?sslmode=verify-full") + connStr := "postgres://pqgotest:password@localhost/pqgotest?sslmode=verify-full" + db, err := sql.Open("postgres", connStr) Connection String Parameters @@ -43,21 +45,28 @@ supported: * dbname - The name of the database to connect to * user - The user to sign in as * password - The user's password - * host - The host to connect to. Values that start with / are for unix domain sockets. (default is localhost) + * host - The host to connect to. Values that start with / are for unix + domain sockets. (default is localhost) * port - The port to bind to. (default is 5432) - * sslmode - Whether or not to use SSL (default is require, this is not the default for libpq) + * sslmode - Whether or not to use SSL (default is require, this is not + the default for libpq) * fallback_application_name - An application_name to fall back to if one isn't provided. - * connect_timeout - Maximum wait for connection, in seconds. Zero or not specified means wait indefinitely. + * connect_timeout - Maximum wait for connection, in seconds. Zero or + not specified means wait indefinitely. * sslcert - Cert file location. The file must contain PEM encoded data. * sslkey - Key file location. The file must contain PEM encoded data. - * sslrootcert - The location of the root certificate file. The file must contain PEM encoded data. + * sslrootcert - The location of the root certificate file. The file + must contain PEM encoded data. Valid values for sslmode are: * disable - No SSL * require - Always SSL (skip verification) - * verify-ca - Always SSL (verify that the certificate presented by the server was signed by a trusted CA) - * verify-full - Always SSL (verify that the certification presented by the server was signed by a trusted CA and the server host name matches the one in the certificate) + * verify-ca - Always SSL (verify that the certificate presented by the + server was signed by a trusted CA) + * verify-full - Always SSL (verify that the certification presented by + the server was signed by a trusted CA and the server host name + matches the one in the certificate) See http://www.postgresql.org/docs/current/static/libpq-connect.html#LIBPQ-CONNSTRING for more information about connection string parameters. @@ -68,7 +77,7 @@ Use single quotes for values that contain whitespace: A backslash will escape the next character in values: - "user=space\ man password='it\'s valid' + "user=space\ man password='it\'s valid'" Note that the connection parameter client_encoding (which sets the text encoding for the connection) may be set but must be "UTF8", @@ -89,8 +98,10 @@ provided connection parameters. The pgpass mechanism as described in http://www.postgresql.org/docs/current/static/libpq-pgpass.html is supported, but on Windows PGPASSFILE must be specified explicitly. + Queries + database/sql does not dictate any specific format for parameter markers in query strings, and pq uses the Postgres-native ordinal markers, as shown above. The same marker can be reused for the same parameter: @@ -114,8 +125,30 @@ For more details on RETURNING, see the Postgres documentation: For additional instructions on querying see the documentation for the database/sql package. + +Data Types + + +Parameters pass through driver.DefaultParameterConverter before they are handled +by this package. When the binary_parameters connection option is enabled, +[]byte values are sent directly to the backend as data in binary format. + +This package returns the following types for values from the PostgreSQL backend: + + - integer types smallint, integer, and bigint are returned as int64 + - floating-point types real and double precision are returned as float64 + - character types char, varchar, and text are returned as string + - temporal types date, time, timetz, timestamp, and timestamptz are + returned as time.Time + - the boolean type is returned as bool + - the bytea type is returned as []byte + +All other types are returned directly from the backend as []byte values in text format. + + Errors + pq may return errors of type *pq.Error which can be interrogated for error details: if err, ok := err.(*pq.Error); ok { @@ -206,7 +239,7 @@ for more information). Note that the channel name will be truncated to 63 bytes by the PostgreSQL server. You can find a complete, working example of Listener usage at -http://godoc.org/github.com/lib/pq/listen_example. +https://godoc.org/github.com/lib/pq/example/listen. */ package pq diff --git a/vendor/github.com/lib/pq/encode.go b/vendor/github.com/lib/pq/encode.go index 29e8f6ff7c..3b0d365f29 100644 --- a/vendor/github.com/lib/pq/encode.go +++ b/vendor/github.com/lib/pq/encode.go @@ -76,6 +76,12 @@ func binaryDecode(parameterStatus *parameterStatus, s []byte, typ oid.Oid) inter return int64(int32(binary.BigEndian.Uint32(s))) case oid.T_int2: return int64(int16(binary.BigEndian.Uint16(s))) + case oid.T_uuid: + b, err := decodeUUIDBinary(s) + if err != nil { + panic(err) + } + return b default: errorf("don't know how to decode binary parameter of type %d", uint32(typ)) @@ -361,8 +367,15 @@ func ParseTimestamp(currentLocation *time.Location, str string) (time.Time, erro timeSep := daySep + 3 day := p.mustAtoi(str, daySep+1, timeSep) + minLen := monSep + len("01-01") + 1 + + isBC := strings.HasSuffix(str, " BC") + if isBC { + minLen += 3 + } + var hour, minute, second int - if len(str) > monSep+len("01-01")+1 { + if len(str) > minLen { p.expect(str, ' ', timeSep) minSep := timeSep + 3 p.expect(str, ':', minSep) @@ -418,7 +431,8 @@ func ParseTimestamp(currentLocation *time.Location, str string) (time.Time, erro tzOff = tzSign * ((tzHours * 60 * 60) + (tzMin * 60) + tzSec) } var isoYear int - if remainderIdx+3 <= len(str) && str[remainderIdx:remainderIdx+3] == " BC" { + + if isBC { isoYear = 1 - year remainderIdx += 3 } else { @@ -471,7 +485,7 @@ func FormatTimestamp(t time.Time) []byte { t = t.AddDate((-t.Year())*2+1, 0, 0) bc = true } - b := []byte(t.Format(time.RFC3339Nano)) + b := []byte(t.Format("2006-01-02 15:04:05.999999999Z07:00")) _, offset := t.Zone() offset = offset % 60 diff --git a/vendor/github.com/lib/pq/error.go b/vendor/github.com/lib/pq/error.go index b4bb44cee3..96aae29c65 100644 --- a/vendor/github.com/lib/pq/error.go +++ b/vendor/github.com/lib/pq/error.go @@ -153,6 +153,7 @@ var errorCodeNames = map[ErrorCode]string{ "22004": "null_value_not_allowed", "22002": "null_value_no_indicator_parameter", "22003": "numeric_value_out_of_range", + "2200H": "sequence_generator_limit_exceeded", "22026": "string_data_length_mismatch", "22001": "string_data_right_truncation", "22011": "substring_error", @@ -459,6 +460,11 @@ func errorf(s string, args ...interface{}) { panic(fmt.Errorf("pq: %s", fmt.Sprintf(s, args...))) } +// TODO(ainar-g) Rename to errorf after removing panics. +func fmterrorf(s string, args ...interface{}) error { + return fmt.Errorf("pq: %s", fmt.Sprintf(s, args...)) +} + func errRecoverNoErrBadConn(err *error) { e := recover() if e == nil { @@ -487,7 +493,8 @@ func (c *conn) errRecover(err *error) { *err = v } case *net.OpError: - *err = driver.ErrBadConn + c.bad = true + *err = v case error: if v == io.EOF || v.(error).Error() == "remote error: handshake failure" { *err = driver.ErrBadConn diff --git a/vendor/github.com/lib/pq/notify.go b/vendor/github.com/lib/pq/notify.go index 09f94244b9..850bb9040c 100644 --- a/vendor/github.com/lib/pq/notify.go +++ b/vendor/github.com/lib/pq/notify.go @@ -60,7 +60,7 @@ type ListenerConn struct { replyChan chan message } -// Creates a new ListenerConn. Use NewListener instead. +// NewListenerConn creates a new ListenerConn. Use NewListener instead. func NewListenerConn(name string, notificationChan chan<- *Notification) (*ListenerConn, error) { return newDialListenerConn(defaultDialer{}, name, notificationChan) } @@ -214,17 +214,17 @@ func (l *ListenerConn) listenerConnMain() { // this ListenerConn is done } -// Send a LISTEN query to the server. See ExecSimpleQuery. +// Listen sends a LISTEN query to the server. See ExecSimpleQuery. func (l *ListenerConn) Listen(channel string) (bool, error) { return l.ExecSimpleQuery("LISTEN " + QuoteIdentifier(channel)) } -// Send an UNLISTEN query to the server. See ExecSimpleQuery. +// Unlisten sends an UNLISTEN query to the server. See ExecSimpleQuery. func (l *ListenerConn) Unlisten(channel string) (bool, error) { return l.ExecSimpleQuery("UNLISTEN " + QuoteIdentifier(channel)) } -// Send `UNLISTEN *` to the server. See ExecSimpleQuery. +// UnlistenAll sends an `UNLISTEN *` query to the server. See ExecSimpleQuery. func (l *ListenerConn) UnlistenAll() (bool, error) { return l.ExecSimpleQuery("UNLISTEN *") } @@ -267,8 +267,8 @@ func (l *ListenerConn) sendSimpleQuery(q string) (err error) { return nil } -// Execute a "simple query" (i.e. one with no bindable parameters) on the -// connection. The possible return values are: +// ExecSimpleQuery executes a "simple query" (i.e. one with no bindable +// parameters) on the connection. The possible return values are: // 1) "executed" is true; the query was executed to completion on the // database server. If the query failed, err will be set to the error // returned by the database, otherwise err will be nil. @@ -333,6 +333,7 @@ func (l *ListenerConn) ExecSimpleQuery(q string) (executed bool, err error) { } } +// Close closes the connection. func (l *ListenerConn) Close() error { l.connectionLock.Lock() if l.err != nil { @@ -346,7 +347,7 @@ func (l *ListenerConn) Close() error { return l.cn.c.Close() } -// Err() returns the reason the connection was closed. It is not safe to call +// Err returns the reason the connection was closed. It is not safe to call // this function until l.Notify has been closed. func (l *ListenerConn) Err() error { return l.err @@ -354,32 +355,43 @@ func (l *ListenerConn) Err() error { var errListenerClosed = errors.New("pq: Listener has been closed") +// ErrChannelAlreadyOpen is returned from Listen when a channel is already +// open. var ErrChannelAlreadyOpen = errors.New("pq: channel is already open") + +// ErrChannelNotOpen is returned from Unlisten when a channel is not open. var ErrChannelNotOpen = errors.New("pq: channel is not open") +// ListenerEventType is an enumeration of listener event types. type ListenerEventType int const ( - // Emitted only when the database connection has been initially - // initialized. err will always be nil. + // ListenerEventConnected is emitted only when the database connection + // has been initially initialized. The err argument of the callback + // will always be nil. ListenerEventConnected ListenerEventType = iota - // Emitted after a database connection has been lost, either because of an - // error or because Close has been called. err will be set to the reason - // the database connection was lost. + // ListenerEventDisconnected is emitted after a database connection has + // been lost, either because of an error or because Close has been + // called. The err argument will be set to the reason the database + // connection was lost. ListenerEventDisconnected - // Emitted after a database connection has been re-established after - // connection loss. err will always be nil. After this event has been - // emitted, a nil pq.Notification is sent on the Listener.Notify channel. + // ListenerEventReconnected is emitted after a database connection has + // been re-established after connection loss. The err argument of the + // callback will always be nil. After this event has been emitted, a + // nil pq.Notification is sent on the Listener.Notify channel. ListenerEventReconnected - // Emitted after a connection to the database was attempted, but failed. - // err will be set to an error describing why the connection attempt did - // not succeed. + // ListenerEventConnectionAttemptFailed is emitted after a connection + // to the database was attempted, but failed. The err argument will be + // set to an error describing why the connection attempt did not + // succeed. ListenerEventConnectionAttemptFailed ) +// EventCallbackType is the event callback type. See also ListenerEventType +// constants' documentation. type EventCallbackType func(event ListenerEventType, err error) // Listener provides an interface for listening to notifications from a @@ -454,9 +466,9 @@ func NewDialListener(d Dialer, return l } -// Returns the notification channel for this listener. This is the same -// channel as Notify, and will not be recreated during the life time of the -// Listener. +// NotificationChannel returns the notification channel for this listener. +// This is the same channel as Notify, and will not be recreated during the +// life time of the Listener. func (l *Listener) NotificationChannel() <-chan *Notification { return l.Notify } @@ -625,7 +637,7 @@ func (l *Listener) disconnectCleanup() error { // after the connection has been established. func (l *Listener) resync(cn *ListenerConn, notificationChan <-chan *Notification) error { doneChan := make(chan error) - go func() { + go func(notificationChan <-chan *Notification) { for channel := range l.channels { // If we got a response, return that error to our caller as it's // going to be more descriptive than cn.Err(). @@ -639,14 +651,14 @@ func (l *Listener) resync(cn *ListenerConn, notificationChan <-chan *Notificatio // close and then return the error message from the connection, as // per ListenerConn's interface. if err != nil { - for _ = range notificationChan { + for range notificationChan { } doneChan <- cn.Err() return } } doneChan <- nil - }() + }(notificationChan) // Ignore notifications while synchronization is going on to avoid // deadlocks. We have to send a nil notification over Notify anyway as @@ -713,6 +725,9 @@ func (l *Listener) Close() error { } l.isClosed = true + // Unblock calls to Listen() + l.reconnectCond.Broadcast() + return nil } @@ -772,7 +787,7 @@ func (l *Listener) listenerConnLoop() { } l.emitEvent(ListenerEventDisconnected, err) - time.Sleep(nextReconnect.Sub(time.Now())) + time.Sleep(time.Until(nextReconnect)) } } diff --git a/vendor/github.com/lib/pq/oid/gen.go b/vendor/github.com/lib/pq/oid/gen.go index cd4aea8086..7c634cdc5c 100644 --- a/vendor/github.com/lib/pq/oid/gen.go +++ b/vendor/github.com/lib/pq/oid/gen.go @@ -10,10 +10,22 @@ import ( "log" "os" "os/exec" + "strings" _ "github.com/lib/pq" ) +// OID represent a postgres Object Identifier Type. +type OID struct { + ID int + Type string +} + +// Name returns an upper case version of the oid type. +func (o OID) Name() string { + return strings.ToUpper(o.Type) +} + func main() { datname := os.Getenv("PGDATABASE") sslmode := os.Getenv("PGSSLMODE") @@ -30,6 +42,25 @@ func main() { if err != nil { log.Fatal(err) } + rows, err := db.Query(` + SELECT typname, oid + FROM pg_type WHERE oid < 10000 + ORDER BY oid; + `) + if err != nil { + log.Fatal(err) + } + oids := make([]*OID, 0) + for rows.Next() { + var oid OID + if err = rows.Scan(&oid.Type, &oid.ID); err != nil { + log.Fatal(err) + } + oids = append(oids, &oid) + } + if err = rows.Err(); err != nil { + log.Fatal(err) + } cmd := exec.Command("gofmt") cmd.Stderr = os.Stderr w, err := cmd.StdinPipe() @@ -45,30 +76,18 @@ func main() { if err != nil { log.Fatal(err) } - fmt.Fprintln(w, "// generated by 'go run gen.go'; do not edit") + fmt.Fprintln(w, "// Code generated by gen.go. DO NOT EDIT.") fmt.Fprintln(w, "\npackage oid") fmt.Fprintln(w, "const (") - rows, err := db.Query(` - SELECT typname, oid - FROM pg_type WHERE oid < 10000 - ORDER BY oid; - `) - if err != nil { - log.Fatal(err) - } - var name string - var oid int - for rows.Next() { - err = rows.Scan(&name, &oid) - if err != nil { - log.Fatal(err) - } - fmt.Fprintf(w, "T_%s Oid = %d\n", name, oid) - } - if err = rows.Err(); err != nil { - log.Fatal(err) + for _, oid := range oids { + fmt.Fprintf(w, "T_%s Oid = %d\n", oid.Type, oid.ID) } fmt.Fprintln(w, ")") + fmt.Fprintln(w, "var TypeName = map[Oid]string{") + for _, oid := range oids { + fmt.Fprintf(w, "T_%s: \"%s\",\n", oid.Type, oid.Name()) + } + fmt.Fprintln(w, "}") w.Close() cmd.Wait() } diff --git a/vendor/github.com/lib/pq/oid/types.go b/vendor/github.com/lib/pq/oid/types.go index 03df05a617..ecc84c2c86 100644 --- a/vendor/github.com/lib/pq/oid/types.go +++ b/vendor/github.com/lib/pq/oid/types.go @@ -1,4 +1,4 @@ -// generated by 'go run gen.go'; do not edit +// Code generated by gen.go. DO NOT EDIT. package oid @@ -18,6 +18,7 @@ const ( T_xid Oid = 28 T_cid Oid = 29 T_oidvector Oid = 30 + T_pg_ddl_command Oid = 32 T_pg_type Oid = 71 T_pg_attribute Oid = 75 T_pg_proc Oid = 81 @@ -28,6 +29,7 @@ const ( T_pg_node_tree Oid = 194 T__json Oid = 199 T_smgr Oid = 210 + T_index_am_handler Oid = 325 T_point Oid = 600 T_lseg Oid = 601 T_path Oid = 602 @@ -133,6 +135,9 @@ const ( T__uuid Oid = 2951 T_txid_snapshot Oid = 2970 T_fdw_handler Oid = 3115 + T_pg_lsn Oid = 3220 + T__pg_lsn Oid = 3221 + T_tsm_handler Oid = 3310 T_anyenum Oid = 3500 T_tsvector Oid = 3614 T_tsquery Oid = 3615 @@ -144,6 +149,8 @@ const ( T__regconfig Oid = 3735 T_regdictionary Oid = 3769 T__regdictionary Oid = 3770 + T_jsonb Oid = 3802 + T__jsonb Oid = 3807 T_anyrange Oid = 3831 T_event_trigger Oid = 3838 T_int4range Oid = 3904 @@ -158,4 +165,179 @@ const ( T__daterange Oid = 3913 T_int8range Oid = 3926 T__int8range Oid = 3927 + T_pg_shseclabel Oid = 4066 + T_regnamespace Oid = 4089 + T__regnamespace Oid = 4090 + T_regrole Oid = 4096 + T__regrole Oid = 4097 ) + +var TypeName = map[Oid]string{ + T_bool: "BOOL", + T_bytea: "BYTEA", + T_char: "CHAR", + T_name: "NAME", + T_int8: "INT8", + T_int2: "INT2", + T_int2vector: "INT2VECTOR", + T_int4: "INT4", + T_regproc: "REGPROC", + T_text: "TEXT", + T_oid: "OID", + T_tid: "TID", + T_xid: "XID", + T_cid: "CID", + T_oidvector: "OIDVECTOR", + T_pg_ddl_command: "PG_DDL_COMMAND", + T_pg_type: "PG_TYPE", + T_pg_attribute: "PG_ATTRIBUTE", + T_pg_proc: "PG_PROC", + T_pg_class: "PG_CLASS", + T_json: "JSON", + T_xml: "XML", + T__xml: "_XML", + T_pg_node_tree: "PG_NODE_TREE", + T__json: "_JSON", + T_smgr: "SMGR", + T_index_am_handler: "INDEX_AM_HANDLER", + T_point: "POINT", + T_lseg: "LSEG", + T_path: "PATH", + T_box: "BOX", + T_polygon: "POLYGON", + T_line: "LINE", + T__line: "_LINE", + T_cidr: "CIDR", + T__cidr: "_CIDR", + T_float4: "FLOAT4", + T_float8: "FLOAT8", + T_abstime: "ABSTIME", + T_reltime: "RELTIME", + T_tinterval: "TINTERVAL", + T_unknown: "UNKNOWN", + T_circle: "CIRCLE", + T__circle: "_CIRCLE", + T_money: "MONEY", + T__money: "_MONEY", + T_macaddr: "MACADDR", + T_inet: "INET", + T__bool: "_BOOL", + T__bytea: "_BYTEA", + T__char: "_CHAR", + T__name: "_NAME", + T__int2: "_INT2", + T__int2vector: "_INT2VECTOR", + T__int4: "_INT4", + T__regproc: "_REGPROC", + T__text: "_TEXT", + T__tid: "_TID", + T__xid: "_XID", + T__cid: "_CID", + T__oidvector: "_OIDVECTOR", + T__bpchar: "_BPCHAR", + T__varchar: "_VARCHAR", + T__int8: "_INT8", + T__point: "_POINT", + T__lseg: "_LSEG", + T__path: "_PATH", + T__box: "_BOX", + T__float4: "_FLOAT4", + T__float8: "_FLOAT8", + T__abstime: "_ABSTIME", + T__reltime: "_RELTIME", + T__tinterval: "_TINTERVAL", + T__polygon: "_POLYGON", + T__oid: "_OID", + T_aclitem: "ACLITEM", + T__aclitem: "_ACLITEM", + T__macaddr: "_MACADDR", + T__inet: "_INET", + T_bpchar: "BPCHAR", + T_varchar: "VARCHAR", + T_date: "DATE", + T_time: "TIME", + T_timestamp: "TIMESTAMP", + T__timestamp: "_TIMESTAMP", + T__date: "_DATE", + T__time: "_TIME", + T_timestamptz: "TIMESTAMPTZ", + T__timestamptz: "_TIMESTAMPTZ", + T_interval: "INTERVAL", + T__interval: "_INTERVAL", + T__numeric: "_NUMERIC", + T_pg_database: "PG_DATABASE", + T__cstring: "_CSTRING", + T_timetz: "TIMETZ", + T__timetz: "_TIMETZ", + T_bit: "BIT", + T__bit: "_BIT", + T_varbit: "VARBIT", + T__varbit: "_VARBIT", + T_numeric: "NUMERIC", + T_refcursor: "REFCURSOR", + T__refcursor: "_REFCURSOR", + T_regprocedure: "REGPROCEDURE", + T_regoper: "REGOPER", + T_regoperator: "REGOPERATOR", + T_regclass: "REGCLASS", + T_regtype: "REGTYPE", + T__regprocedure: "_REGPROCEDURE", + T__regoper: "_REGOPER", + T__regoperator: "_REGOPERATOR", + T__regclass: "_REGCLASS", + T__regtype: "_REGTYPE", + T_record: "RECORD", + T_cstring: "CSTRING", + T_any: "ANY", + T_anyarray: "ANYARRAY", + T_void: "VOID", + T_trigger: "TRIGGER", + T_language_handler: "LANGUAGE_HANDLER", + T_internal: "INTERNAL", + T_opaque: "OPAQUE", + T_anyelement: "ANYELEMENT", + T__record: "_RECORD", + T_anynonarray: "ANYNONARRAY", + T_pg_authid: "PG_AUTHID", + T_pg_auth_members: "PG_AUTH_MEMBERS", + T__txid_snapshot: "_TXID_SNAPSHOT", + T_uuid: "UUID", + T__uuid: "_UUID", + T_txid_snapshot: "TXID_SNAPSHOT", + T_fdw_handler: "FDW_HANDLER", + T_pg_lsn: "PG_LSN", + T__pg_lsn: "_PG_LSN", + T_tsm_handler: "TSM_HANDLER", + T_anyenum: "ANYENUM", + T_tsvector: "TSVECTOR", + T_tsquery: "TSQUERY", + T_gtsvector: "GTSVECTOR", + T__tsvector: "_TSVECTOR", + T__gtsvector: "_GTSVECTOR", + T__tsquery: "_TSQUERY", + T_regconfig: "REGCONFIG", + T__regconfig: "_REGCONFIG", + T_regdictionary: "REGDICTIONARY", + T__regdictionary: "_REGDICTIONARY", + T_jsonb: "JSONB", + T__jsonb: "_JSONB", + T_anyrange: "ANYRANGE", + T_event_trigger: "EVENT_TRIGGER", + T_int4range: "INT4RANGE", + T__int4range: "_INT4RANGE", + T_numrange: "NUMRANGE", + T__numrange: "_NUMRANGE", + T_tsrange: "TSRANGE", + T__tsrange: "_TSRANGE", + T_tstzrange: "TSTZRANGE", + T__tstzrange: "_TSTZRANGE", + T_daterange: "DATERANGE", + T__daterange: "_DATERANGE", + T_int8range: "INT8RANGE", + T__int8range: "_INT8RANGE", + T_pg_shseclabel: "PG_SHSECLABEL", + T_regnamespace: "REGNAMESPACE", + T__regnamespace: "_REGNAMESPACE", + T_regrole: "REGROLE", + T__regrole: "_REGROLE", +} diff --git a/vendor/github.com/lib/pq/rows.go b/vendor/github.com/lib/pq/rows.go new file mode 100644 index 0000000000..c6aa5b9a36 --- /dev/null +++ b/vendor/github.com/lib/pq/rows.go @@ -0,0 +1,93 @@ +package pq + +import ( + "math" + "reflect" + "time" + + "github.com/lib/pq/oid" +) + +const headerSize = 4 + +type fieldDesc struct { + // The object ID of the data type. + OID oid.Oid + // The data type size (see pg_type.typlen). + // Note that negative values denote variable-width types. + Len int + // The type modifier (see pg_attribute.atttypmod). + // The meaning of the modifier is type-specific. + Mod int +} + +func (fd fieldDesc) Type() reflect.Type { + switch fd.OID { + case oid.T_int8: + return reflect.TypeOf(int64(0)) + case oid.T_int4: + return reflect.TypeOf(int32(0)) + case oid.T_int2: + return reflect.TypeOf(int16(0)) + case oid.T_varchar, oid.T_text: + return reflect.TypeOf("") + case oid.T_bool: + return reflect.TypeOf(false) + case oid.T_date, oid.T_time, oid.T_timetz, oid.T_timestamp, oid.T_timestamptz: + return reflect.TypeOf(time.Time{}) + case oid.T_bytea: + return reflect.TypeOf([]byte(nil)) + default: + return reflect.TypeOf(new(interface{})).Elem() + } +} + +func (fd fieldDesc) Name() string { + return oid.TypeName[fd.OID] +} + +func (fd fieldDesc) Length() (length int64, ok bool) { + switch fd.OID { + case oid.T_text, oid.T_bytea: + return math.MaxInt64, true + case oid.T_varchar, oid.T_bpchar: + return int64(fd.Mod - headerSize), true + default: + return 0, false + } +} + +func (fd fieldDesc) PrecisionScale() (precision, scale int64, ok bool) { + switch fd.OID { + case oid.T_numeric, oid.T__numeric: + mod := fd.Mod - headerSize + precision = int64((mod >> 16) & 0xffff) + scale = int64(mod & 0xffff) + return precision, scale, true + default: + return 0, 0, false + } +} + +// ColumnTypeScanType returns the value type that can be used to scan types into. +func (rs *rows) ColumnTypeScanType(index int) reflect.Type { + return rs.colTyps[index].Type() +} + +// ColumnTypeDatabaseTypeName return the database system type name. +func (rs *rows) ColumnTypeDatabaseTypeName(index int) string { + return rs.colTyps[index].Name() +} + +// ColumnTypeLength returns the length of the column type if the column is a +// variable length type. If the column is not a variable length type ok +// should return false. +func (rs *rows) ColumnTypeLength(index int) (length int64, ok bool) { + return rs.colTyps[index].Length() +} + +// ColumnTypePrecisionScale should return the precision and scale for decimal +// types. If not applicable, ok should be false. +func (rs *rows) ColumnTypePrecisionScale(index int) (precision, scale int64, ok bool) { + return rs.colTyps[index].PrecisionScale() +} diff --git a/vendor/github.com/lib/pq/ssl.go b/vendor/github.com/lib/pq/ssl.go new file mode 100644 index 0000000000..d902084558 --- /dev/null +++ b/vendor/github.com/lib/pq/ssl.go @@ -0,0 +1,175 @@ +package pq + +import ( + "crypto/tls" + "crypto/x509" + "io/ioutil" + "net" + "os" + "os/user" + "path/filepath" +) + +// ssl generates a function to upgrade a net.Conn based on the "sslmode" and +// related settings. The function is nil when no upgrade should take place. +func ssl(o values) (func(net.Conn) (net.Conn, error), error) { + verifyCaOnly := false + tlsConf := tls.Config{} + switch mode := o["sslmode"]; mode { + // "require" is the default. + case "", "require": + // We must skip TLS's own verification since it requires full + // verification since Go 1.3. + tlsConf.InsecureSkipVerify = true + + // From http://www.postgresql.org/docs/current/static/libpq-ssl.html: + // + // Note: For backwards compatibility with earlier versions of + // PostgreSQL, if a root CA file exists, the behavior of + // sslmode=require will be the same as that of verify-ca, meaning the + // server certificate is validated against the CA. Relying on this + // behavior is discouraged, and applications that need certificate + // validation should always use verify-ca or verify-full. + if sslrootcert, ok := o["sslrootcert"]; ok { + if _, err := os.Stat(sslrootcert); err == nil { + verifyCaOnly = true + } else { + delete(o, "sslrootcert") + } + } + case "verify-ca": + // We must skip TLS's own verification since it requires full + // verification since Go 1.3. + tlsConf.InsecureSkipVerify = true + verifyCaOnly = true + case "verify-full": + tlsConf.ServerName = o["host"] + case "disable": + return nil, nil + default: + return nil, fmterrorf(`unsupported sslmode %q; only "require" (default), "verify-full", "verify-ca", and "disable" supported`, mode) + } + + err := sslClientCertificates(&tlsConf, o) + if err != nil { + return nil, err + } + err = sslCertificateAuthority(&tlsConf, o) + if err != nil { + return nil, err + } + + // Accept renegotiation requests initiated by the backend. + // + // Renegotiation was deprecated then removed from PostgreSQL 9.5, but + // the default configuration of older versions has it enabled. Redshift + // also initiates renegotiations and cannot be reconfigured. + tlsConf.Renegotiation = tls.RenegotiateFreelyAsClient + + return func(conn net.Conn) (net.Conn, error) { + client := tls.Client(conn, &tlsConf) + if verifyCaOnly { + err := sslVerifyCertificateAuthority(client, &tlsConf) + if err != nil { + return nil, err + } + } + return client, nil + }, nil +} + +// sslClientCertificates adds the certificate specified in the "sslcert" and +// "sslkey" settings, or if they aren't set, from the .postgresql directory +// in the user's home directory. The configured files must exist and have +// the correct permissions. +func sslClientCertificates(tlsConf *tls.Config, o values) error { + // user.Current() might fail when cross-compiling. We have to ignore the + // error and continue without home directory defaults, since we wouldn't + // know from where to load them. + user, _ := user.Current() + + // In libpq, the client certificate is only loaded if the setting is not blank. + // + // https://github.com/postgres/postgres/blob/REL9_6_2/src/interfaces/libpq/fe-secure-openssl.c#L1036-L1037 + sslcert := o["sslcert"] + if len(sslcert) == 0 && user != nil { + sslcert = filepath.Join(user.HomeDir, ".postgresql", "postgresql.crt") + } + // https://github.com/postgres/postgres/blob/REL9_6_2/src/interfaces/libpq/fe-secure-openssl.c#L1045 + if len(sslcert) == 0 { + return nil + } + // https://github.com/postgres/postgres/blob/REL9_6_2/src/interfaces/libpq/fe-secure-openssl.c#L1050:L1054 + if _, err := os.Stat(sslcert); os.IsNotExist(err) { + return nil + } else if err != nil { + return err + } + + // In libpq, the ssl key is only loaded if the setting is not blank. + // + // https://github.com/postgres/postgres/blob/REL9_6_2/src/interfaces/libpq/fe-secure-openssl.c#L1123-L1222 + sslkey := o["sslkey"] + if len(sslkey) == 0 && user != nil { + sslkey = filepath.Join(user.HomeDir, ".postgresql", "postgresql.key") + } + + if len(sslkey) > 0 { + if err := sslKeyPermissions(sslkey); err != nil { + return err + } + } + + cert, err := tls.LoadX509KeyPair(sslcert, sslkey) + if err != nil { + return err + } + + tlsConf.Certificates = []tls.Certificate{cert} + return nil +} + +// sslCertificateAuthority adds the RootCA specified in the "sslrootcert" setting. +func sslCertificateAuthority(tlsConf *tls.Config, o values) error { + // In libpq, the root certificate is only loaded if the setting is not blank. + // + // https://github.com/postgres/postgres/blob/REL9_6_2/src/interfaces/libpq/fe-secure-openssl.c#L950-L951 + if sslrootcert := o["sslrootcert"]; len(sslrootcert) > 0 { + tlsConf.RootCAs = x509.NewCertPool() + + cert, err := ioutil.ReadFile(sslrootcert) + if err != nil { + return err + } + + if !tlsConf.RootCAs.AppendCertsFromPEM(cert) { + return fmterrorf("couldn't parse pem in sslrootcert") + } + } + + return nil +} + +// sslVerifyCertificateAuthority carries out a TLS handshake to the server and +// verifies the presented certificate against the CA, i.e. the one specified in +// sslrootcert or the system CA if sslrootcert was not specified. +func sslVerifyCertificateAuthority(client *tls.Conn, tlsConf *tls.Config) error { + err := client.Handshake() + if err != nil { + return err + } + certs := client.ConnectionState().PeerCertificates + opts := x509.VerifyOptions{ + DNSName: client.ConnectionState().ServerName, + Intermediates: x509.NewCertPool(), + Roots: tlsConf.RootCAs, + } + for i, cert := range certs { + if i == 0 { + continue + } + opts.Intermediates.AddCert(cert) + } + _, err = certs[0].Verify(opts) + return err +} diff --git a/vendor/github.com/lib/pq/ssl_permissions.go b/vendor/github.com/lib/pq/ssl_permissions.go new file mode 100644 index 0000000000..3b7c3a2a31 --- /dev/null +++ b/vendor/github.com/lib/pq/ssl_permissions.go @@ -0,0 +1,20 @@ +// +build !windows + +package pq + +import "os" + +// sslKeyPermissions checks the permissions on user-supplied ssl key files. +// The key file should have very little access. +// +// libpq does not check key file permissions on Windows. +func sslKeyPermissions(sslkey string) error { + info, err := os.Stat(sslkey) + if err != nil { + return err + } + if info.Mode().Perm()&0077 != 0 { + return ErrSSLKeyHasWorldPermissions + } + return nil +} diff --git a/vendor/github.com/lib/pq/ssl_windows.go b/vendor/github.com/lib/pq/ssl_windows.go new file mode 100644 index 0000000000..5d2c763ceb --- /dev/null +++ b/vendor/github.com/lib/pq/ssl_windows.go @@ -0,0 +1,9 @@ +// +build windows + +package pq + +// sslKeyPermissions checks the permissions on user-supplied ssl key files. +// The key file should have very little access. +// +// libpq does not check key file permissions on Windows. +func sslKeyPermissions(string) error { return nil } diff --git a/vendor/github.com/lib/pq/uuid.go b/vendor/github.com/lib/pq/uuid.go new file mode 100644 index 0000000000..9a1b9e0748 --- /dev/null +++ b/vendor/github.com/lib/pq/uuid.go @@ -0,0 +1,23 @@ +package pq + +import ( + "encoding/hex" + "fmt" +) + +// decodeUUIDBinary interprets the binary format of a uuid, returning it in text format. +func decodeUUIDBinary(src []byte) ([]byte, error) { + if len(src) != 16 { + return nil, fmt.Errorf("pq: unable to decode uuid; bad length: %d", len(src)) + } + + dst := make([]byte, 36) + dst[8], dst[13], dst[18], dst[23] = '-', '-', '-', '-' + hex.Encode(dst[0:], src[0:4]) + hex.Encode(dst[9:], src[4:6]) + hex.Encode(dst[14:], src[6:8]) + hex.Encode(dst[19:], src[8:10]) + hex.Encode(dst[24:], src[10:16]) + + return dst, nil +}