Skip to content

Commit

Permalink
Avoid asserting on error message for cancel tests
Browse files Browse the repository at this point in the history
  • Loading branch information
rafiss committed Apr 7, 2022
1 parent 8446d16 commit 4b55993
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 37 deletions.
59 changes: 31 additions & 28 deletions conn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"database/sql"
"database/sql/driver"
"errors"
"fmt"
"io"
"net"
Expand Down Expand Up @@ -1859,34 +1860,34 @@ func TestStmtQueryContext(t *testing.T) {
defer db.Close()

tests := []struct {
name string
ctx func() (context.Context, context.CancelFunc)
sql string
err error
name string
ctx func() (context.Context, context.CancelFunc)
sql string
cancelExpected bool
}{
{
name: "context.Background",
ctx: func() (context.Context, context.CancelFunc) {
return context.Background(), nil
},
sql: "SELECT pg_sleep(1);",
err: nil,
sql: "SELECT pg_sleep(1);",
cancelExpected: false,
},
{
name: "context.WithTimeout exceeded",
ctx: func() (context.Context, context.CancelFunc) {
return context.WithTimeout(context.Background(), 1*time.Second)
},
sql: "SELECT pg_sleep(10);",
err: &Error{Message: "canceling statement due to user request"},
sql: "SELECT pg_sleep(10);",
cancelExpected: true,
},
{
name: "context.WithTimeout",
ctx: func() (context.Context, context.CancelFunc) {
return context.WithTimeout(context.Background(), time.Minute)
},
sql: "SELECT pg_sleep(1);",
err: nil,
sql: "SELECT pg_sleep(1);",
cancelExpected: false,
},
}
for _, tt := range tests {
Expand All @@ -1900,11 +1901,12 @@ func TestStmtQueryContext(t *testing.T) {
t.Fatal(err)
}
_, err = stmt.QueryContext(ctx)
pgErr := (*Error)(nil)
switch {
case (err != nil) != (tt.err != nil):
t.Fatalf("stmt.QueryContext() unexpected nil err got = %v, expected = %v", err, tt.err)
case (err != nil && tt.err != nil) && (err.Error() != tt.err.Error()):
t.Errorf("stmt.QueryContext() got = %v, expected = %v", err.Error(), tt.err.Error())
case (err != nil) != tt.cancelExpected:
t.Fatalf("stmt.QueryContext() unexpected nil err got = %v, cancelExpected = %v", err, tt.cancelExpected)
case (err != nil && tt.cancelExpected) && !(errors.As(err, &pgErr) && pgErr.Code == cancelErrorCode):
t.Errorf("stmt.QueryContext() got = %v, cancelExpected = %v", err.Error(), tt.cancelExpected)
}
})
}
Expand All @@ -1915,34 +1917,34 @@ func TestStmtExecContext(t *testing.T) {
defer db.Close()

tests := []struct {
name string
ctx func() (context.Context, context.CancelFunc)
sql string
err error
name string
ctx func() (context.Context, context.CancelFunc)
sql string
cancelExpected bool
}{
{
name: "context.Background",
ctx: func() (context.Context, context.CancelFunc) {
return context.Background(), nil
},
sql: "SELECT pg_sleep(1);",
err: nil,
sql: "SELECT pg_sleep(1);",
cancelExpected: false,
},
{
name: "context.WithTimeout exceeded",
ctx: func() (context.Context, context.CancelFunc) {
return context.WithTimeout(context.Background(), 1*time.Second)
},
sql: "SELECT pg_sleep(10);",
err: &Error{Message: "canceling statement due to user request"},
sql: "SELECT pg_sleep(10);",
cancelExpected: true,
},
{
name: "context.WithTimeout",
ctx: func() (context.Context, context.CancelFunc) {
return context.WithTimeout(context.Background(), time.Minute)
},
sql: "SELECT pg_sleep(1);",
err: nil,
sql: "SELECT pg_sleep(1);",
cancelExpected: false,
},
}
for _, tt := range tests {
Expand All @@ -1956,11 +1958,12 @@ func TestStmtExecContext(t *testing.T) {
t.Fatal(err)
}
_, err = stmt.ExecContext(ctx)
pgErr := (*Error)(nil)
switch {
case (err != nil) != (tt.err != nil):
t.Fatalf("stmt.ExecContext() unexpected nil err got = %v, expected = %v", err, tt.err)
case (err != nil && tt.err != nil) && (err.Error() != tt.err.Error()):
t.Errorf("stmt.ExecContext() got = %v, expected = %v", err.Error(), tt.err.Error())
case (err != nil) != tt.cancelExpected:
t.Fatalf("stmt.QueryContext() unexpected nil err got = %v, cancelExpected = %v", err, tt.cancelExpected)
case (err != nil && tt.cancelExpected) && !(errors.As(err, &pgErr) && pgErr.Code == cancelErrorCode):
t.Errorf("stmt.QueryContext() got = %v, cancelExpected = %v", err.Error(), tt.cancelExpected)
}
})
}
Expand Down
13 changes: 8 additions & 5 deletions go18_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"database/sql"
"database/sql/driver"
"errors"
"runtime"
"strings"
"testing"
Expand Down Expand Up @@ -75,6 +76,8 @@ func TestMultipleSimpleQuery(t *testing.T) {

const contextRaceIterations = 100

const cancelErrorCode ErrorCode = "57014"

func TestContextCancelExec(t *testing.T) {
db := openTestConn(t)
defer db.Close()
Expand All @@ -87,7 +90,7 @@ func TestContextCancelExec(t *testing.T) {
// Not canceled until after the exec has started.
if _, err := db.ExecContext(ctx, "select pg_sleep(1)"); err == nil {
t.Fatal("expected error")
} else if err.Error() != "pq: canceling statement due to user request" {
} else if pgErr := (*Error)(nil); !(errors.As(err, &pgErr) && pgErr.Code == cancelErrorCode) {
t.Fatalf("unexpected error: %s", err)
}

Expand Down Expand Up @@ -125,7 +128,7 @@ func TestContextCancelQuery(t *testing.T) {
// Not canceled until after the exec has started.
if _, err := db.QueryContext(ctx, "select pg_sleep(1)"); err == nil {
t.Fatal("expected error")
} else if err.Error() != "pq: canceling statement due to user request" {
} else if pgErr := (*Error)(nil); !(errors.As(err, &pgErr) && pgErr.Code == cancelErrorCode) {
t.Fatalf("unexpected error: %s", err)
}

Expand Down Expand Up @@ -215,7 +218,7 @@ func TestContextCancelBegin(t *testing.T) {
// Not canceled until after the exec has started.
if _, err := tx.Exec("select pg_sleep(1)"); err == nil {
t.Fatal("expected error")
} else if err.Error() != "pq: canceling statement due to user request" {
} else if pgErr := (*Error)(nil); !(errors.As(err, &pgErr) && pgErr.Code == cancelErrorCode) {
t.Fatalf("unexpected error: %s", err)
}

Expand All @@ -240,8 +243,8 @@ func TestContextCancelBegin(t *testing.T) {
cancel()
if err != nil {
t.Fatal(err)
} else if err := tx.Rollback(); err != nil &&
err.Error() != "pq: canceling statement due to user request" &&
} else if err, pgErr := tx.Rollback(), (*Error)(nil); err != nil &&
!(errors.As(err, &pgErr) && pgErr.Code == cancelErrorCode) &&
err != sql.ErrTxDone && err != driver.ErrBadConn && err != context.Canceled {
t.Fatal(err)
}
Expand Down
10 changes: 6 additions & 4 deletions issues_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package pq

import (
"context"
"errors"
"testing"
"time"
)
Expand Down Expand Up @@ -51,10 +52,9 @@ func TestIssue1046(t *testing.T) {
t.Logf("FAIL %s: query returned after context deadline: %v\n", t.Name(), since)
t.Fail()
}
expectedErr := &Error{Message: "canceling statement due to user request"}
if err == nil || err.Error() != expectedErr.Error() {
if pgErr := (*Error)(nil); !(errors.As(err, &pgErr) && pgErr.Code == cancelErrorCode) {
t.Logf("ctx.Err(): [%T]%+v\n", ctx.Err(), ctx.Err())
t.Logf("got err: [%T] %+v expected err: [%T] %+v", err, err, expectedErr, expectedErr)
t.Logf("got err: [%T] %+v expected errCode: %v", err, err, cancelErrorCode)
t.Fail()
}
}
Expand All @@ -72,7 +72,9 @@ func TestIssue1062(t *testing.T) {

var v int
err := row.Scan(&v)
if err != nil && err != context.Canceled && err.Error() != "pq: canceling statement due to user request" {
if pgErr := (*Error)(nil); err != nil &&
err != context.Canceled &&
!(errors.As(err, &pgErr) && pgErr.Code == cancelErrorCode) {
t.Fatalf("Scan resulted in unexpected error %v for canceled QueryRowContext at attempt %d", err, i+1)
}
}
Expand Down

0 comments on commit 4b55993

Please sign in to comment.