Skip to content

Commit

Permalink
JWKS synchronization (#398)
Browse files Browse the repository at this point in the history
* Test for JWKS synchronization
* Set lock on Keys and expiry
* Use local variables to reduce read locks
* changelog entry
  • Loading branch information
johakoch committed Nov 29, 2021
1 parent 539fab3 commit 5e0b787
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 16 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ Unreleased changes are available as `avenga/couper:edge` container.
* **Fixed**
* build-date configuration for binary and docker builds ([#396](https://github.com/avenga/couper/pull/396))
* exclude file descriptor limit startup-logs for Windows ([#396](https://github.com/avenga/couper/pull/396), [#383](https://github.com/avenga/couper/pull/383))
* possible race conditions while updating JWKS for the [JWT access control](./docs/REFERENCE.md#jwt-block) ([#398](https://github.com/avenga/couper/pull/398))

---

Expand Down
41 changes: 27 additions & 14 deletions accesscontrol/jwks.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"io/ioutil"
"net/http"
"strings"
"sync"
"time"

"github.com/avenga/couper/config/reader"
Expand All @@ -21,6 +22,7 @@ type JWKS struct {
uri string
transport http.RoundTripper
ttl time.Duration
mtx sync.RWMutex
}

func NewJWKS(uri string, ttl string, transport http.RoundTripper, confContext context.Context) (*JWKS, error) {
Expand Down Expand Up @@ -49,15 +51,23 @@ func NewJWKS(uri string, ttl string, transport http.RoundTripper, confContext co
}

func (j *JWKS) GetKeys(kid string) ([]JWK, error) {
var keys []JWK

if len(j.Keys) == 0 || j.hasExpired() {
if err := j.Load(); err != nil {
var (
keys []JWK
err error
)

j.mtx.RLock()
allKeys := j.Keys
expired := j.hasExpired()
j.mtx.RUnlock()
if len(allKeys) == 0 || expired {
allKeys, err = j.Load()
if err != nil {
return keys, fmt.Errorf("error loading JWKS: %v", err)
}
}

for _, key := range j.Keys {
for _, key := range allKeys {
if key.KeyID == kid {
keys = append(keys, key)
}
Expand All @@ -79,53 +89,56 @@ func (j *JWKS) GetKey(kid string, alg string, use string) (*JWK, error) {
return nil, nil
}

func (j *JWKS) Load() error {
func (j *JWKS) Load() ([]JWK, error) {
var rawJSON []byte

if j.file != "" {
j, err := reader.ReadFromFile("jwks_url", j.file)
if err != nil {
return err
return nil, err
}
rawJSON = j
} else if j.transport != nil {
req, err := http.NewRequest("GET", "", nil)
if err != nil {
return err
return nil, err
}
ctx := context.WithValue(j.context, request.URLAttribute, j.uri)
// TODO which roundtrip name?
ctx = context.WithValue(ctx, request.RoundTripName, "jwks")
req = req.WithContext(ctx)
response, err := j.transport.RoundTrip(req)
if err != nil {
return err
return nil, err
}
if response.StatusCode != 200 {
return fmt.Errorf("status code %d", response.StatusCode)
return nil, fmt.Errorf("status code %d", response.StatusCode)
}

defer response.Body.Close()

body, err := ioutil.ReadAll(response.Body)
if err != nil {
return fmt.Errorf("error reading JWKS response for %q: %v", j.uri, err)
return nil, fmt.Errorf("error reading JWKS response for %q: %v", j.uri, err)
}
rawJSON = body
} else {
return fmt.Errorf("jwks: missing both file and request")
return nil, fmt.Errorf("jwks: missing both file and request")
}

var jwks JWKS
err := json.Unmarshal(rawJSON, &jwks)
if err != nil {
return err
return nil, err
}

j.mtx.Lock()
defer j.mtx.Unlock()

j.Keys = jwks.Keys
j.expiry = time.Now().Unix() + int64(j.ttl.Seconds())

return nil
return j.Keys, nil
}

func (jwks *JWKS) hasExpired() bool {
Expand Down
27 changes: 25 additions & 2 deletions accesscontrol/jwks_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package accesscontrol_test

import (
"sync"
"testing"

ac "github.com/avenga/couper/accesscontrol"
Expand Down Expand Up @@ -47,7 +48,7 @@ func Test_JWKS_Load(t *testing.T) {
t.Run(tt.name, func(subT *testing.T) {
jwks, err := ac.NewJWKS("file:"+tt.file, "", nil, nil)
helper.Must(err)
err = jwks.Load()
_, err = jwks.Load()
if err != nil && tt.expParsed {
subT.Error("no jwks parsed")
}
Expand Down Expand Up @@ -91,7 +92,7 @@ func Test_JWKS_GetKey(t *testing.T) {
helper := test.New(subT)
jwks, err := ac.NewJWKS("file:"+tt.file, "", nil, nil)
helper.Must(err)
err = jwks.Load()
_, err = jwks.Load()
helper.Must(err)
jwk, err := jwks.GetKey(tt.kid, tt.alg, tt.use)
if jwk == nil && tt.expFound {
Expand All @@ -103,3 +104,25 @@ func Test_JWKS_GetKey(t *testing.T) {
})
}
}

func Test_JWKS_LoadSynced(t *testing.T) {
helper := test.New(t)

memQuitCh := make(chan struct{})
defer close(memQuitCh)

jwks, err := ac.NewJWKS("file:testdata/jwks.json", "", nil, nil)
helper.Must(err)

wg := sync.WaitGroup{}
wg.Add(10)
for i := 0; i < 10; i++ {
go func(idx int) {
defer wg.Done()

_, e := jwks.GetKeys("kid1")
helper.Must(e)
}(i)
}
wg.Wait()
}

0 comments on commit 5e0b787

Please sign in to comment.