diff --git a/pkg/download/protocol.go b/pkg/download/protocol.go index b9d5945ee..a6942a9c7 100644 --- a/pkg/download/protocol.go +++ b/pkg/download/protocol.go @@ -298,7 +298,7 @@ func union(list1, list2 []string) []string { func copyContextWithCustomTimeout(ctx context.Context, timeout time.Duration) (context.Context, context.CancelFunc) { ctxCopy, cancel := context.WithTimeout(context.Background(), timeout) - requestid.SetInContext(ctxCopy, requestid.FromContext(ctx)) - log.SetEntryInContext(ctxCopy, log.EntryFromContext(ctx)) + ctxCopy = requestid.SetInContext(ctxCopy, requestid.FromContext(ctx)) + ctxCopy = log.SetEntryInContext(ctxCopy, log.EntryFromContext(ctx)) return ctxCopy, cancel } diff --git a/pkg/download/protocol_test.go b/pkg/download/protocol_test.go index 40134a497..d2f910684 100644 --- a/pkg/download/protocol_test.go +++ b/pkg/download/protocol_test.go @@ -17,6 +17,7 @@ import ( "github.com/gomods/athens/pkg/download/mode" "github.com/gomods/athens/pkg/errors" "github.com/gomods/athens/pkg/index/nop" + "github.com/gomods/athens/pkg/log" "github.com/gomods/athens/pkg/module" "github.com/gomods/athens/pkg/stash" "github.com/gomods/athens/pkg/storage" @@ -495,3 +496,38 @@ func (ml *mockLister) List(ctx context.Context, mod string) (*storage.RevInfo, [ ml.called = true return nil, ml.list, ml.err } + +type testEntry struct { + msg string +} + +var _ log.Entry = &testEntry{} + +func (e *testEntry) Debugf(format string, args ...any) { + e.msg = format +} +func (*testEntry) Infof(format string, args ...any) {} +func (*testEntry) Warnf(format string, args ...any) {} +func (*testEntry) Errorf(format string, args ...any) {} +func (*testEntry) WithFields(fields map[string]any) log.Entry { return nil } +func (*testEntry) SystemErr(err error) {} + +func Test_copyContextWithCustomTimeout(t *testing.T) { + testEntry := &testEntry{} + + // create a context with a logger entry + logctx := log.SetEntryInContext(context.Background(), testEntry) + + // check the log work as expected + log.EntryFromContext(logctx).Debugf("first test") + require.Equal(t, "first test", testEntry.msg) + + // use copyContextWithCustomTimeout to create a new context with a custom timeout, + // and the returned context should have the same logger entry + newCtx, cancel := copyContextWithCustomTimeout(logctx, 10*time.Second) + defer cancel() + + // check the log work as expected + log.EntryFromContext(newCtx).Debugf("second test") + require.Equal(t, "second test", testEntry.msg) +}