From edc342d02121cf593d853820818b97aad639243a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20W=C3=BCrbach?= Date: Fri, 3 Feb 2023 16:04:25 +0100 Subject: [PATCH] contrib/net/http: add errCheck function --- contrib/net/http/option.go | 10 ++++++ contrib/net/http/roundtripper.go | 10 ++++-- contrib/net/http/roundtripper_test.go | 47 +++++++++++++++++++++++++++ 3 files changed, 65 insertions(+), 2 deletions(-) diff --git a/contrib/net/http/option.go b/contrib/net/http/option.go index 9fee57a5bc..ae6f0c600f 100644 --- a/contrib/net/http/option.go +++ b/contrib/net/http/option.go @@ -129,6 +129,7 @@ type roundTripperConfig struct { resourceNamer func(req *http.Request) string ignoreRequest func(*http.Request) bool spanOpts []ddtrace.StartSpanOption + errCheck func(err error) bool } func newRoundTripperConfig() *roundTripperConfig { @@ -216,3 +217,12 @@ func RTWithIgnoreRequest(f func(*http.Request) bool) RoundTripperOption { cfg.ignoreRequest = f } } + +// RTWithErrorCheck specifies a function fn which determines whether the passed +// error should be marked as an error. The fn is called whenever an aws operation +// finishes with an error +func RTWithErrorCheck(fn func(err error) bool) RoundTripperOption { + return func(cfg *roundTripperConfig) { + cfg.errCheck = fn + } +} diff --git a/contrib/net/http/roundtripper.go b/contrib/net/http/roundtripper.go index 0d54ee9b0e..a9800c9af0 100644 --- a/contrib/net/http/roundtripper.go +++ b/contrib/net/http/roundtripper.go @@ -52,7 +52,11 @@ func (rt *roundTripper) RoundTrip(req *http.Request) (res *http.Response, err er if rt.cfg.after != nil { rt.cfg.after(res, span) } - span.Finish(tracer.WithError(err)) + if rt.cfg.errCheck == nil || rt.cfg.errCheck(err) { + span.Finish(tracer.WithError(err)) + } else { + span.Finish() + } }() if rt.cfg.before != nil { rt.cfg.before(req, span) @@ -67,7 +71,9 @@ func (rt *roundTripper) RoundTrip(req *http.Request) (res *http.Response, err er res, err = rt.base.RoundTrip(r2) if err != nil { span.SetTag("http.errors", err.Error()) - span.SetTag(ext.Error, err) + if rt.cfg.errCheck == nil || rt.cfg.errCheck(err) { + span.SetTag(ext.Error, err) + } } else { span.SetTag(ext.HTTPCode, strconv.Itoa(res.StatusCode)) // treat 5XX as errors diff --git a/contrib/net/http/roundtripper_test.go b/contrib/net/http/roundtripper_test.go index 5d56411513..114e6512a3 100644 --- a/contrib/net/http/roundtripper_test.go +++ b/contrib/net/http/roundtripper_test.go @@ -183,6 +183,53 @@ func TestRoundTripperNetworkError(t *testing.T) { assert.Equal(t, "net/http", s0.Tag(ext.Component)) } +func TestRoundTripperNetworkErrorWithErrorCheck(t *testing.T) { + failedRequest := func(t *testing.T, mt mocktracer.Tracer, forwardErr bool, opts ...RoundTripperOption) mocktracer.Span { + done := make(chan struct{}) + s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _, err := tracer.Extract(tracer.HTTPHeadersCarrier(r.Header)) + assert.NoError(t, err) + <-done + })) + defer s.Close() + defer close(done) + + rt := WrapRoundTripper(http.DefaultTransport, + RTWithErrorCheck(func(err error) bool { + return forwardErr + })) + + client := &http.Client{ + Transport: rt, + Timeout: 1 * time.Millisecond, + } + + client.Get(s.URL + "/hello/world") + + spans := mt.FinishedSpans() + assert.Len(t, spans, 1) + + s0 := spans[0] + return s0 + } + + t.Run("error skipped", func(t *testing.T) { + mt := mocktracer.Start() + defer mt.Stop() + + span := failedRequest(t, mt, false) + assert.Nil(t, span.Tag(ext.Error)) + }) + + t.Run("error forwarded", func(t *testing.T) { + mt := mocktracer.Start() + defer mt.Stop() + + span := failedRequest(t, mt, true) + assert.NotNil(t, span.Tag(ext.Error)) + }) +} + func TestRoundTripperCredentials(t *testing.T) { mt := mocktracer.Start() defer mt.Stop()