Skip to content

Commit

Permalink
Add a cancel cause to job cancellation (#1728)
Browse files Browse the repository at this point in the history
  • Loading branch information
mcastorina authored Aug 30, 2023
1 parent c77c117 commit 522b2fa
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 7 deletions.
8 changes: 4 additions & 4 deletions pkg/sources/job_progress.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,11 +71,11 @@ func (r *JobProgressRef) Done() <-chan struct{} {
// CancelRun requests that the job this is referencing is cancelled and stops
// running. This method will have no effect if the job does not allow
// cancellation.
func (r *JobProgressRef) CancelRun() {
func (r *JobProgressRef) CancelRun(cause error) {
if r.jobProgress == nil || r.jobProgress.jobCancel == nil {
return
}
r.jobProgress.jobCancel()
r.jobProgress.jobCancel(cause)
}

// Fatal is a wrapper around error to differentiate non-fatal errors from fatal
Expand Down Expand Up @@ -108,7 +108,7 @@ type JobProgress struct {
ctx context.Context
cancel context.CancelFunc
// Requests to cancel the job.
jobCancel context.CancelFunc
jobCancel context.CancelCauseFunc
// Metrics.
metrics JobProgressMetrics
metricsLock sync.Mutex
Expand Down Expand Up @@ -150,7 +150,7 @@ func WithHooks(hooks ...JobProgressHook) func(*JobProgress) {
}

// WithCancel allows cancelling the job by the JobProgressRef.
func WithCancel(cancel context.CancelFunc) func(*JobProgress) {
func WithCancel(cancel context.CancelCauseFunc) func(*JobProgress) {
return func(jp *JobProgress) { jp.jobCancel = cancel }
}

Expand Down
10 changes: 8 additions & 2 deletions pkg/sources/source_manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@ func (s *SourceManager) asyncRun(ctx context.Context, handle handle) (JobProgres
return JobProgressRef{SourceID: int64(handle), SourceName: sourceName}, err
}
// Create a JobProgress object for tracking progress.
ctx, cancel := context.WithCancel(ctx)
ctx, cancel := context.WithCancelCause(ctx)
progress := NewJobProgress(jobID, int64(handle), sourceName, WithHooks(s.hooks...), WithCancel(cancel))
s.pool.Go(func() error {
atomic.AddInt32(&s.currentRunningCount, 1)
Expand All @@ -180,7 +180,7 @@ func (s *SourceManager) asyncRun(ctx context.Context, handle handle) (JobProgres
"source_manager_worker_id", common.RandomID(5),
)
defer common.Recover(ctx)
defer cancel()
defer cancel(nil)
return s.run(ctx, handle, jobID, progress)
})
return progress.Ref(), nil
Expand Down Expand Up @@ -247,6 +247,12 @@ func (s *SourceManager) run(ctx context.Context, handle handle, jobID int64, rep
report.Start(time.Now())
defer func() { report.End(time.Now()) }()

defer func() {
if err := context.Cause(ctx); err != nil {
report.ReportError(Fatal{err})
}
}()

// Initialize the source.
sourceInfo, ok := s.getSourceInfo(handle)
if !ok {
Expand Down
4 changes: 3 additions & 1 deletion pkg/sources/source_manager_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -312,10 +312,12 @@ func TestSourceManagerCancelRun(t *testing.T) {
ref, err := mgr.ScheduleRun(context.Background(), handle)
assert.NoError(t, err)

ref.CancelRun()
cancelErr := fmt.Errorf("abort! abort!")
ref.CancelRun(cancelErr)
<-ref.Done()
assert.Error(t, ref.Snapshot().FatalError())
assert.True(t, errors.Is(ref.Snapshot().FatalError(), returnedErr))
assert.True(t, errors.Is(ref.Snapshot().FatalErrors(), cancelErr))
}

func TestSourceManagerAvailableCapacity(t *testing.T) {
Expand Down

0 comments on commit 522b2fa

Please sign in to comment.