diff --git a/zaptest/logger.go b/zaptest/logger.go index 4fa726f67..73ad5b71b 100644 --- a/zaptest/logger.go +++ b/zaptest/logger.go @@ -65,17 +65,39 @@ func NewLogger(t TestingT, opts ...LoggerOption) *zap.Logger { o.applyLoggerOption(&cfg) } + writer := newTestingWriter(t) return zap.New( zapcore.NewCore( zapcore.NewConsoleEncoder(zap.NewDevelopmentEncoderConfig()), - testingWriter{t}, + writer, cfg.Level, ), + + // Send zap errors to the same writer and mark the test as failed if + // that happens. + zap.ErrorOutput(writer.WithMarkFailed(true)), ) } // testingWriter is a WriteSyncer that writes to the given testing.TB. -type testingWriter struct{ t TestingT } +type testingWriter struct { + t TestingT + + // If true, the test will be marked as failed if this testingWriter is + // ever used. + markFailed bool +} + +func newTestingWriter(t TestingT) testingWriter { + return testingWriter{t: t} +} + +// WithMarkFailed returns a copy of this testingWriter with markFailed set to +// the provided value. +func (w testingWriter) WithMarkFailed(v bool) testingWriter { + w.markFailed = v + return w +} func (w testingWriter) Write(p []byte) (n int, err error) { n = len(p) @@ -85,6 +107,10 @@ func (w testingWriter) Write(p []byte) (n int, err error) { // Note: t.Log is safe for concurrent use. w.t.Logf("%s", p) + if w.markFailed { + w.t.Fail() + } + return n, nil } diff --git a/zaptest/logger_test.go b/zaptest/logger_test.go index 10fd33740..b69aa28c8 100644 --- a/zaptest/logger_test.go +++ b/zaptest/logger_test.go @@ -25,16 +25,19 @@ import ( "fmt" "io" "strings" - "sync" "testing" "go.uber.org/zap" + "go.uber.org/zap/internal/ztest" + "go.uber.org/zap/zapcore" "github.com/stretchr/testify/assert" ) func TestTestLogger(t *testing.T) { ts := newTestLogSpy(t) + defer ts.AssertPassed() + log := NewLogger(ts) log.Info("received work order") @@ -57,6 +60,8 @@ func TestTestLogger(t *testing.T) { func TestTestLoggerSupportsLevels(t *testing.T) { ts := newTestLogSpy(t) + defer ts.AssertPassed() + log := NewLogger(ts, Level(zap.WarnLevel)) log.Info("received work order") @@ -77,18 +82,43 @@ func TestTestLoggerSupportsLevels(t *testing.T) { func TestTestingWriter(t *testing.T) { ts := newTestLogSpy(t) - w := testingWriter{ts} + w := newTestingWriter(ts) n, err := io.WriteString(w, "hello\n\n") assert.NoError(t, err, "WriteString must not fail") assert.Equal(t, 7, n) } +func TestTestLoggerErrorOutput(t *testing.T) { + // This test verifies that the test logger logs internal messages to the + // testing.T and marks the test as failed. + + ts := newTestLogSpy(t) + defer ts.AssertFailed() + + log := NewLogger(ts) + + // Replace with a core that fails. + log = log.WithOptions(zap.WrapCore(func(zapcore.Core) zapcore.Core { + return zapcore.NewCore( + zapcore.NewConsoleEncoder(zap.NewDevelopmentEncoderConfig()), + zapcore.Lock(zapcore.AddSync(ztest.FailWriter{})), + zapcore.DebugLevel, + ) + })) + + log.Info("foo") // this fails + + if assert.Len(t, ts.Messages, 1, "expected a log message") { + assert.Regexp(t, `write error: failed`, ts.Messages[0]) + } +} + // testLogSpy is a testing.TB that captures logged messages. type testLogSpy struct { testing.TB - mu sync.Mutex + failed bool Messages []string } @@ -96,6 +126,19 @@ func newTestLogSpy(t testing.TB) *testLogSpy { return &testLogSpy{TB: t} } +func (t *testLogSpy) Fail() { + t.failed = true +} + +func (t *testLogSpy) Failed() bool { + return t.failed +} + +func (t *testLogSpy) FailNow() { + t.Fail() + t.TB.FailNow() +} + func (t *testLogSpy) Logf(format string, args ...interface{}) { // Log messages are in the format, // @@ -105,15 +148,22 @@ func (t *testLogSpy) Logf(format string, args ...interface{}) { // for the timestamp from these tests. m := fmt.Sprintf(format, args...) m = m[strings.IndexByte(m, '\t')+1:] - - // t.Log should be thread-safe. - t.mu.Lock() t.Messages = append(t.Messages, m) - t.mu.Unlock() - t.TB.Log(m) } func (t *testLogSpy) AssertMessages(msgs ...string) { - assert.Equal(t, msgs, t.Messages) + assert.Equal(t.TB, msgs, t.Messages, "logged messages did not match") +} + +func (t *testLogSpy) AssertPassed() { + t.assertFailed(false, "expected test to pass") +} + +func (t *testLogSpy) AssertFailed() { + t.assertFailed(true, "expected test to fail") +} + +func (t *testLogSpy) assertFailed(v bool, msg string) { + assert.Equal(t.TB, v, t.failed, msg) } diff --git a/zaptest/testingt.go b/zaptest/testingt.go index b01d2b853..792463be3 100644 --- a/zaptest/testingt.go +++ b/zaptest/testingt.go @@ -29,6 +29,15 @@ type TestingT interface { // Logs the given message and marks the test as failed. Errorf(string, ...interface{}) + // Marks the test as failed. + Fail() + + // Returns true if the test has been marked as failed. + Failed() bool + + // Returns the name of the test. + Name() string + // Marks the test as failed and stops execution of that test. FailNow() }