diff --git a/zaptest/observer/observer.go b/zaptest/observer/observer.go index 0f6462d3d..ed36ba3be 100644 --- a/zaptest/observer/observer.go +++ b/zaptest/observer/observer.go @@ -80,6 +80,38 @@ func (o *ObservedLogs) AllUntimed() []LoggedEntry { return ret } +// FilterMessage filters entries to those that have the specified message. +func (o *ObservedLogs) FilterMessage(msg string) *ObservedLogs { + return o.filter(func(e LoggedEntry) bool { + return e.Message == msg + }) +} + +// FilterField filters entries to those that have the specified field. +func (o *ObservedLogs) FilterField(field zapcore.Field) *ObservedLogs { + return o.filter(func(e LoggedEntry) bool { + for _, ctxField := range e.Context { + if ctxField == field { + return true + } + } + return false + }) +} + +func (o *ObservedLogs) filter(match func(LoggedEntry) bool) *ObservedLogs { + o.mu.RLock() + defer o.mu.RUnlock() + + var filtered []LoggedEntry + for _, entry := range o.logs { + if match(entry) { + filtered = append(filtered, entry) + } + } + return &ObservedLogs{logs: filtered} +} + func (o *ObservedLogs) add(log LoggedEntry) { o.mu.Lock() o.logs = append(o.logs, log) diff --git a/zaptest/observer/observer_test.go b/zaptest/observer/observer_test.go index cb02d046b..e6fbe599c 100644 --- a/zaptest/observer/observer_test.go +++ b/zaptest/observer/observer_test.go @@ -118,3 +118,66 @@ func TestObserverWith(t *testing.T) { }, }, logs.All(), "expected no field sharing between With siblings") } + +func TestFilters(t *testing.T) { + logs := []LoggedEntry{ + { + Entry: zapcore.Entry{Level: zap.InfoLevel, Message: "log a"}, + Context: []zapcore.Field{zap.String("fStr", "1"), zap.Int("a", 1)}, + }, + { + Entry: zapcore.Entry{Level: zap.InfoLevel, Message: "log a"}, + Context: []zapcore.Field{zap.String("fStr", "2"), zap.Int("b", 2)}, + }, + { + Entry: zapcore.Entry{Level: zap.InfoLevel, Message: "log b"}, + Context: []zapcore.Field{zap.Int("a", 1), zap.Int("b", 2)}, + }, + { + Entry: zapcore.Entry{Level: zap.InfoLevel, Message: "log c"}, + Context: []zapcore.Field{zap.Int("a", 1), zap.Namespace("ns"), zap.Int("a", 2)}, + }, + } + + logger, sink := New(zap.InfoLevel) + for _, log := range logs { + logger.Write(log.Entry, log.Context) + } + + tests := []struct { + msg string + filtered *ObservedLogs + want []LoggedEntry + }{ + { + msg: "filter by message", + filtered: sink.FilterMessage("log a"), + want: logs[0:2], + }, + { + msg: "filter by field", + filtered: sink.FilterField(zap.String("fStr", "1")), + want: logs[0:1], + }, + { + msg: "filter by message and field", + filtered: sink.FilterMessage("log a").FilterField(zap.Int("b", 2)), + want: logs[1:2], + }, + { + msg: "filter by field with duplicate fields", + filtered: sink.FilterField(zap.Int("a", 2)), + want: logs[3:4], + }, + { + msg: "filter doesn't match any messages", + filtered: sink.FilterMessage("no match"), + want: []LoggedEntry{}, + }, + } + + for _, tt := range tests { + got := tt.filtered.AllUntimed() + assert.Equal(t, tt.want, got, tt.msg) + } +}