diff --git a/docs/content/rest-api.md b/docs/content/rest-api.md index c652c5f4ad..826e03e266 100644 --- a/docs/content/rest-api.md +++ b/docs/content/rest-api.md @@ -1049,6 +1049,7 @@ The request body contains an object that specifies a value for [The input Docume - **metrics** - Return query performance metrics in addition to result. See [Performance Metrics](#performance-metrics) for more detail. - **instrument** - Instrument query evaluation and return a superset of performance metrics in addition to result. See [Performance Metrics](#performance-metrics) for more detail. - **watch** - Set a watch on the data reference if the parameter is present. See [Watches](#watches) for more detail. +- **decision_id** - Client supplied decision ID to use for tracing. If not specified a random ID is generated automatically. #### Status Codes diff --git a/server/server.go b/server/server.go index 004e926270..45a678d4dc 100644 --- a/server/server.go +++ b/server/server.go @@ -1124,6 +1124,7 @@ func (s *Server) v1DataPost(w http.ResponseWriter, r *http.Request) { includeInstrumentation := getBoolParam(r.URL, types.ParamInstrumentV1, true) partial := getBoolParam(r.URL, types.ParamPartialV1, true) provenance := getBoolParam(r.URL, types.ParamProvenanceV1, true) + decisionID = getStringParam(r.URL, types.ParamDecisionIDV1, decisionID) m.Timer(metrics.RegoQueryParse).Start() @@ -2083,6 +2084,14 @@ func validateParsedQuery(body ast.Body) ([]string, error) { return unsafeOperators, nil } +func getStringParam(url *url.URL, name string, ifEmpty string) string { + + if p := url.Query().Get(name); p != "" { + return p + } + return ifEmpty +} + func getBoolParam(url *url.URL, name string, ifEmpty bool) bool { p, ok := url.Query()[name] diff --git a/server/server_test.go b/server/server_test.go index 0ab4e5a130..c52ce755ff 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -2460,6 +2460,39 @@ func TestDecisionIDs(t *testing.T) { } } +func TestUserProvidedDecisionIDs(t *testing.T) { + f := newFixture(t) + f.server = f.server.WithDiagnosticsBuffer(NewBoundedBuffer(4)) + decisionID := 6 + + enableDiagnostics := ` + package system.diagnostics + + config = {"mode": "on"} + ` + + if err := f.v1("PUT", "/policies/test", enableDiagnostics, 200, "{}"); err != nil { + t.Fatal(err) + } + + if err := f.v1("POST", "/data?decision_id=6", "", 200, `{"decision_id": "6", "result": {}}`); err != nil { + t.Fatal(err) + } + + infos := []*Info{} + + f.server.diagnostics.Iter(func(info *Info) { + if info.DecisionID != fmt.Sprint(decisionID) { + t.Fatalf("Expected decision ID to be %v but got: %v", decisionID, info.DecisionID) + } + infos = append(infos, info) + }) + + if len(infos) != 1 { + t.Fatalf("Expected exactly 1 elements but got: %v", len(infos)) + } +} + func TestDecisionLogging(t *testing.T) { f := newFixture(t) diff --git a/server/types/types.go b/server/types/types.go index b0e96869c2..259b80bc05 100644 --- a/server/types/types.go +++ b/server/types/types.go @@ -438,6 +438,10 @@ const ( // indicates the client wants to include bundle activation in the results // of the health API. ParamBundleActivationV1 = "bundle" + + // ParamDecisionIDV1 defines the name of the HTTP URL parameter that + // the client provided the decision id + ParamDecisionIDV1 = "decision_id" ) // BadRequestErr represents an error condition raised if the caller passes