diff --git a/README.md b/README.md index 80971fd26..acc5d5d4a 100644 --- a/README.md +++ b/README.md @@ -90,7 +90,7 @@ First, thank you so much for wanting to contribute! It means so much that you ca - Consider opening an issue **BEFORE** creating a Pull request (PR): you won't lose your time on fixing non-existing bugs, or fixing the wrong bug. Also we can help you to produce the best PR! -- Open a PR against the `main` branch if your PR is for mainstrem or version +- Open a PR against the `main` branch if your PR is for mainstream or version specific branch e.g. `v1` if your PR is for specific version. Note that the valid branch for a new feature request PR should be `main` while a PR against a version specific branch are allowed only for bugfixes. diff --git a/errors.go b/errors.go index 228bd3865..ee6a619f2 100644 --- a/errors.go +++ b/errors.go @@ -174,7 +174,7 @@ func defaultErrorHandler(status int, origErr error, c Context) error { c.Logger().Error(origErr) c.Response().WriteHeader(status) - if env != nil && env.(string) == "production" { + if env != nil && env.(string) != "development" { switch strings.ToLower(requestCT) { case "application/json", "text/json", "json", "application/xml", "text/xml", "xml": defaultErrorResponse = &ErrorResponse{ diff --git a/errors_test.go b/errors_test.go index bf55f6002..10d359d03 100644 --- a/errors_test.go +++ b/errors_test.go @@ -66,81 +66,74 @@ func Test_defaultErrorHandler_Logger(t *testing.T) { } func Test_defaultErrorHandler_JSON_development(t *testing.T) { - r := require.New(t) - app := New(Options{}) - app.GET("/", func(c Context) error { - return c.Error(http.StatusUnauthorized, fmt.Errorf("boom")) - }) - - w := httptest.New(app) - res := w.JSON("/").Get() - r.Equal(http.StatusUnauthorized, res.Code) - ct := res.Header().Get("content-type") - r.Equal("application/json", ct) - b := res.Body.String() - r.Contains(b, `"code":401`) - r.Contains(b, `"error":"boom"`) - r.Contains(b, `"trace":"`) + testDefaultErrorHandler(t, "application/json", "development") } func Test_defaultErrorHandler_XML_development(t *testing.T) { - r := require.New(t) - app := New(Options{}) - app.GET("/", func(c Context) error { - return c.Error(http.StatusUnauthorized, fmt.Errorf("boom")) - }) + testDefaultErrorHandler(t, "text/xml", "development") +} - w := httptest.New(app) - res := w.XML("/").Get() - r.Equal(http.StatusUnauthorized, res.Code) - ct := res.Header().Get("content-type") - r.Equal("text/xml", ct) - b := res.Body.String() - r.Contains(b, ``) - r.Contains(b, `boom`) - r.Contains(b, ``) - r.Contains(b, ``) - r.Contains(b, ``) +func Test_defaultErrorHandler_JSON_staging(t *testing.T) { + testDefaultErrorHandler(t, "application/json", "staging") } -func Test_defaultErrorHandler_JSON_production(t *testing.T) { - r := require.New(t) - app := New(Options{}) - app.Env = "production" - app.GET("/", func(c Context) error { - return c.Error(http.StatusUnauthorized, fmt.Errorf("boom")) - }) +func Test_defaultErrorHandler_XML_staging(t *testing.T) { + testDefaultErrorHandler(t, "text/xml", "staging") +} - w := httptest.New(app) - res := w.JSON("/").Get() - r.Equal(http.StatusUnauthorized, res.Code) - ct := res.Header().Get("content-type") - r.Equal("application/json", ct) - b := res.Body.String() - r.Contains(b, `"code":401`) - r.Contains(b, fmt.Sprintf(`"error":"%s"`, http.StatusText(http.StatusUnauthorized))) - r.NotContains(b, `"trace":"`) +func Test_defaultErrorHandler_JSON_production(t *testing.T) { + testDefaultErrorHandler(t, "application/json", "production") } func Test_defaultErrorHandler_XML_production(t *testing.T) { + testDefaultErrorHandler(t, "text/xml", "production") +} + +func testDefaultErrorHandler(t *testing.T, contentType, env string) { r := require.New(t) app := New(Options{}) - app.Env = "production" + app.Env = env app.GET("/", func(c Context) error { return c.Error(http.StatusUnauthorized, fmt.Errorf("boom")) }) w := httptest.New(app) - res := w.XML("/").Get() + var res *httptest.Response + if contentType == "application/json" { + res = w.JSON("/").Get().Response + } else { + res = w.XML("/").Get().Response + } r.Equal(http.StatusUnauthorized, res.Code) ct := res.Header().Get("content-type") - r.Equal("text/xml", ct) + r.Equal(contentType, ct) b := res.Body.String() - r.Contains(b, ``) - r.Contains(b, fmt.Sprintf(`%s`, http.StatusText(http.StatusUnauthorized))) - r.NotContains(b, ``) - r.NotContains(b, ``) - r.Contains(b, ``) + + if env == "development" { + if contentType == "text/xml" { + r.Contains(b, ``) + r.Contains(b, `boom`) + r.Contains(b, ``) + r.Contains(b, ``) + r.Contains(b, ``) + } else { + r.Contains(b, `"code":401`) + r.Contains(b, `"error":"boom"`) + r.Contains(b, `"trace":"`) + } + } else { + if contentType == "text/xml" { + r.Contains(b, ``) + r.Contains(b, fmt.Sprintf(`%s`, http.StatusText(http.StatusUnauthorized))) + r.NotContains(b, ``) + r.NotContains(b, ``) + r.Contains(b, ``) + } else { + r.Contains(b, `"code":401`) + r.Contains(b, fmt.Sprintf(`"error":"%s"`, http.StatusText(http.StatusUnauthorized))) + r.NotContains(b, `"trace":"`) + } + } } func Test_defaultErrorHandler_nil_error(t *testing.T) {