diff --git a/CHANGELOG.md b/CHANGELOG.md index 2ed20569b..ae007c43c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,6 +10,7 @@ Unreleased changes are available as `avenga/couper:edge` container. * **Fixed** * build-date configuration for binary and docker builds ([#396](https://github.com/avenga/couper/pull/396)) * exclude file descriptor limit startup-logs for Windows ([#396](https://github.com/avenga/couper/pull/396), [#383](https://github.com/avenga/couper/pull/383)) + * possible race conditions while updating JWKS for the [JWT access control](./docs/REFERENCE.md#jwt-block) ([#398](https://github.com/avenga/couper/pull/398)) --- diff --git a/accesscontrol/jwks.go b/accesscontrol/jwks.go index 6d5429b09..b35df9151 100644 --- a/accesscontrol/jwks.go +++ b/accesscontrol/jwks.go @@ -7,6 +7,7 @@ import ( "io/ioutil" "net/http" "strings" + "sync" "time" "github.com/avenga/couper/config/reader" @@ -21,6 +22,7 @@ type JWKS struct { uri string transport http.RoundTripper ttl time.Duration + mtx sync.RWMutex } func NewJWKS(uri string, ttl string, transport http.RoundTripper, confContext context.Context) (*JWKS, error) { @@ -49,15 +51,23 @@ func NewJWKS(uri string, ttl string, transport http.RoundTripper, confContext co } func (j *JWKS) GetKeys(kid string) ([]JWK, error) { - var keys []JWK - - if len(j.Keys) == 0 || j.hasExpired() { - if err := j.Load(); err != nil { + var ( + keys []JWK + err error + ) + + j.mtx.RLock() + allKeys := j.Keys + expired := j.hasExpired() + j.mtx.RUnlock() + if len(allKeys) == 0 || expired { + allKeys, err = j.Load() + if err != nil { return keys, fmt.Errorf("error loading JWKS: %v", err) } } - for _, key := range j.Keys { + for _, key := range allKeys { if key.KeyID == kid { keys = append(keys, key) } @@ -79,19 +89,19 @@ func (j *JWKS) GetKey(kid string, alg string, use string) (*JWK, error) { return nil, nil } -func (j *JWKS) Load() error { +func (j *JWKS) Load() ([]JWK, error) { var rawJSON []byte if j.file != "" { j, err := reader.ReadFromFile("jwks_url", j.file) if err != nil { - return err + return nil, err } rawJSON = j } else if j.transport != nil { req, err := http.NewRequest("GET", "", nil) if err != nil { - return err + return nil, err } ctx := context.WithValue(j.context, request.URLAttribute, j.uri) // TODO which roundtrip name? @@ -99,33 +109,36 @@ func (j *JWKS) Load() error { req = req.WithContext(ctx) response, err := j.transport.RoundTrip(req) if err != nil { - return err + return nil, err } if response.StatusCode != 200 { - return fmt.Errorf("status code %d", response.StatusCode) + return nil, fmt.Errorf("status code %d", response.StatusCode) } defer response.Body.Close() body, err := ioutil.ReadAll(response.Body) if err != nil { - return fmt.Errorf("error reading JWKS response for %q: %v", j.uri, err) + return nil, fmt.Errorf("error reading JWKS response for %q: %v", j.uri, err) } rawJSON = body } else { - return fmt.Errorf("jwks: missing both file and request") + return nil, fmt.Errorf("jwks: missing both file and request") } var jwks JWKS err := json.Unmarshal(rawJSON, &jwks) if err != nil { - return err + return nil, err } + j.mtx.Lock() + defer j.mtx.Unlock() + j.Keys = jwks.Keys j.expiry = time.Now().Unix() + int64(j.ttl.Seconds()) - return nil + return j.Keys, nil } func (jwks *JWKS) hasExpired() bool { diff --git a/accesscontrol/jwks_test.go b/accesscontrol/jwks_test.go index bdc97ea46..8aa94a728 100644 --- a/accesscontrol/jwks_test.go +++ b/accesscontrol/jwks_test.go @@ -1,6 +1,7 @@ package accesscontrol_test import ( + "sync" "testing" ac "github.com/avenga/couper/accesscontrol" @@ -47,7 +48,7 @@ func Test_JWKS_Load(t *testing.T) { t.Run(tt.name, func(subT *testing.T) { jwks, err := ac.NewJWKS("file:"+tt.file, "", nil, nil) helper.Must(err) - err = jwks.Load() + _, err = jwks.Load() if err != nil && tt.expParsed { subT.Error("no jwks parsed") } @@ -91,7 +92,7 @@ func Test_JWKS_GetKey(t *testing.T) { helper := test.New(subT) jwks, err := ac.NewJWKS("file:"+tt.file, "", nil, nil) helper.Must(err) - err = jwks.Load() + _, err = jwks.Load() helper.Must(err) jwk, err := jwks.GetKey(tt.kid, tt.alg, tt.use) if jwk == nil && tt.expFound { @@ -103,3 +104,25 @@ func Test_JWKS_GetKey(t *testing.T) { }) } } + +func Test_JWKS_LoadSynced(t *testing.T) { + helper := test.New(t) + + memQuitCh := make(chan struct{}) + defer close(memQuitCh) + + jwks, err := ac.NewJWKS("file:testdata/jwks.json", "", nil, nil) + helper.Must(err) + + wg := sync.WaitGroup{} + wg.Add(10) + for i := 0; i < 10; i++ { + go func(idx int) { + defer wg.Done() + + _, e := jwks.GetKeys("kid1") + helper.Must(e) + }(i) + } + wg.Wait() +}