Skip to content

Commit

Permalink
Capture callstack of canceled contexts (trufflesecurity#979)
Browse files Browse the repository at this point in the history
  • Loading branch information
mcastorina committed Jan 9, 2023
1 parent 09d4422 commit 74831f6
Show file tree
Hide file tree
Showing 2 changed files with 82 additions and 6 deletions.
45 changes: 39 additions & 6 deletions pkg/context/context.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ package context

import (
"context"
"fmt"
"runtime/debug"
"time"

"github.com/go-logr/logr"
Expand Down Expand Up @@ -38,13 +40,21 @@ type logCtx struct {
// Embed context.Context to get all methods for free.
context.Context
log logr.Logger
err *error
}

// Logger returns a structured logger.
func (l logCtx) Logger() logr.Logger {
return l.log
}

func (l logCtx) Err() error {
if l.err != nil && *l.err != nil {
return *l.err
}
return l.Context.Err()
}

// Background returns context.Background with a default logger.
func Background() Context {
return logCtx{
Expand All @@ -64,30 +74,33 @@ func TODO() Context {
// WithCancel returns context.WithCancel with the log object propagated.
func WithCancel(parent Context) (Context, context.CancelFunc) {
ctx, cancel := context.WithCancel(parent)
return logCtx{
lCtx := logCtx{
log: parent.Logger(),
Context: ctx,
}, cancel
}
return captureCancelCallstack(lCtx, cancel)
}

// WithDeadline returns context.WithDeadline with the log object propagated and
// the deadline added to the structured log values.
func WithDeadline(parent Context, d time.Time) (Context, context.CancelFunc) {
ctx, cancel := context.WithDeadline(parent, d)
return logCtx{
lCtx := logCtx{
log: parent.Logger().WithValues("deadline", d),
Context: ctx,
}, cancel
}
return captureCancelCallstack(lCtx, cancel)
}

// WithTimeout returns context.WithTimeout with the log object propagated and
// the timeout added to the structured log values.
func WithTimeout(parent Context, timeout time.Duration) (Context, context.CancelFunc) {
ctx, cancel := context.WithTimeout(parent, timeout)
return logCtx{
lCtx := logCtx{
log: parent.Logger().WithValues("timeout", timeout),
Context: ctx,
}, cancel
}
return captureCancelCallstack(lCtx, cancel)
}

// WithValue returns context.WithValue with the log object propagated and
Expand Down Expand Up @@ -136,3 +149,23 @@ func AddLogger(parent context.Context) Context {
func SetDefaultLogger(l logr.Logger) {
defaultLogger = l
}

// captureCancelCallstack is a helper function to capture the callstack where
// the cancel function was first called.
func captureCancelCallstack(ctx logCtx, f context.CancelFunc) (Context, context.CancelFunc) {
if ctx.err == nil {
var err error
ctx.err = &err
}
return ctx, func() {
// We must check Err() before calling f() since f() sets the error.
// If there's already an error, do nothing special.
if ctx.Err() != nil {
f()
return
}
f()
// Set the error with the stacktrace if the err pointer is non-nil.
*ctx.err = fmt.Errorf("%w (%s)", ctx.Err(), string(debug.Stack()))
}
}
43 changes: 43 additions & 0 deletions pkg/context/context_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -169,3 +169,46 @@ func TestDiscardLogger(t *testing.T) {
ctx := Background()
ctx.Logger().Info("this shouldn't panic")
}

func TestErrCallstack(t *testing.T) {
c, cancel := WithCancel(Background())
ctx := c.(logCtx)
cancel()
select {
case <-ctx.Done():
assert.Contains(t, ctx.Err().Error(), "TestErrCallstack")
case <-time.After(1 * time.Second):
assert.Fail(t, "context should be done")
}
}

func TestErrCallstackTimeout(t *testing.T) {
ctx, cancel := WithTimeout(Background(), 10*time.Millisecond)
defer cancel()

select {
case <-ctx.Done():
// Deadline exceeded errors will not have a callstack from the cancel
// function.
assert.NotContains(t, ctx.Err().Error(), "TestErrCallstackTimeout")
case <-time.After(1 * time.Second):
assert.Fail(t, "context should be done")
}
}

func TestErrCallstackTimeoutCancel(t *testing.T) {
ctx, cancel := WithTimeout(Background(), 10*time.Millisecond)

var err error
select {
case <-ctx.Done():
err = ctx.Err()
case <-time.After(1 * time.Second):
assert.Fail(t, "context should be done")
}

// Calling cancel after deadline exceeded should not overwrite the original
// error.
cancel()
assert.Equal(t, err, ctx.Err())
}

0 comments on commit 74831f6

Please sign in to comment.