Skip to content

Commit

Permalink
Merge pull request #112 from krakendio/extra_cache
Browse files Browse the repository at this point in the history
Add a global cache layer common for all the endpoints
  • Loading branch information
taik0 authored Apr 18, 2023
2 parents 2e12718 + 75d4a38 commit e0c32b6
Show file tree
Hide file tree
Showing 3 changed files with 192 additions and 4 deletions.
2 changes: 1 addition & 1 deletion jwk.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ func SecretProvider(cfg SecretProviderConfig, te auth0.RequestTokenExtractor) (*
client := NewJWKClientWithCache(
opts,
te,
NewMemoryKeyCacher(cacheDuration, auth0.MaxCacheSizeNoCheck, opts.KeyIdentifyStrategy),
NewGlobalMemoryKeyCacher(cacheDuration, auth0.MaxCacheSizeNoCheck, opts.KeyIdentifyStrategy),
)

// request an unexistent key in order to cache all the actual ones
Expand Down
80 changes: 80 additions & 0 deletions jwk_client_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
package jose

import (
"net/http"
"net/http/httptest"
"sync/atomic"
"testing"
"time"

"github.com/auth0-community/go-auth0"
"github.com/luraproject/lura/v2/config"
"github.com/luraproject/lura/v2/logging"
)

func TestJWKClient_globalCache(t *testing.T) {
jwk := []byte(`{ "keys": [{
"kty": "RSA",
"e": "AQAB",
"use": "sig",
"kid": "8-2-2PBmlHKMo5tizxp-uw9pFrQQamfa1M1ZYMrAFZI",
"alg": "RS256",
"n": "n6p2fLU7PLwMvJ-xeukn-f5wrAdyZ0ZaFa6kanQzVBofacLs2l4FVe6_bcjw4VGWM2Ct3WgelZQUYVkFbqePODpMnV0lV8U4hxbIpMEJOJqY3tK48_PBIdEkl02DN8LaucK1Y7GpOlUZFrWAOM68TyWJTjkyc-yx0ibu2MFaGQoXacV7239Yei_x68iGBpQa2f9SYv8U5nJINdI1CuyccQp991qeskJATgn-UVqQfOfHDsUA2qud2yNOf5QKkvqqPEH_IXuTtPcf_yzVuco9rhhUW8q5bC4R0BxjCv9w4b-Q_UKjKEXQK5UlAuiWqWgmQbQO9Ne94EDFpjlkCtil2Q"
}]}`)

var count uint64
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Add("Content-Type", "application/json")
atomic.AddUint64(&count, 1)
w.Write(jwk)
}))

defer backend.Close()
opts := JWKClientOptions{
JWKClientOptions: auth0.JWKClientOptions{
URI: backend.URL,
},
}
te := auth0.FromMultiple(
auth0.RequestTokenExtractorFunc(auth0.FromHeader),
)
cfg := config.ExtraConfig{
ValidatorNamespace: map[string]interface{}{
"shared_cache_duration": 3,
},
}
if err := SetGlobalCacher(logging.NoOp, cfg); err != nil {
t.Error(err)
return
}
for i := 0; i < 10; i++ {
client := NewJWKClientWithCache(
opts,
te,
NewGlobalMemoryKeyCacher(1*time.Second, auth0.MaxCacheSizeNoCheck, opts.KeyIdentifyStrategy),
)
if _, err := client.GetKey("8-2-2PBmlHKMo5tizxp-uw9pFrQQamfa1M1ZYMrAFZI"); err != nil {
t.Error(err)
return
}
}
if count != 1 {
t.Errorf("invalid count %d", count)
return
}
<-time.After(4 * time.Second)
for i := 0; i < 10; i++ {
client := NewJWKClientWithCache(
opts,
te,
NewGlobalMemoryKeyCacher(1*time.Second, auth0.MaxCacheSizeNoCheck, opts.KeyIdentifyStrategy),
)
if _, err := client.GetKey("8-2-2PBmlHKMo5tizxp-uw9pFrQQamfa1M1ZYMrAFZI"); err != nil {
t.Error(err)
return
}
}
if count != 2 {
t.Errorf("invalid count %d", count)
}
}
114 changes: 111 additions & 3 deletions key_cacher.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,20 +2,77 @@ package jose

import (
b64 "encoding/base64"
"encoding/json"
"errors"
"fmt"
"sync"
"time"

"github.com/luraproject/lura/v2/config"
"github.com/luraproject/lura/v2/logging"
"gopkg.in/square/go-jose.v2"
)

var (
ErrNoKeyFound = errors.New("no Keys have been found")
ErrKeyExpired = errors.New("key exists but is expired")
ErrNoKeyFound = errors.New("no Keys have been found")
ErrKeyExpired = errors.New("key exists but is expired")
defaultGlobalCacheMaxAge uint32 = 900
defaultStrategy = "kid"

// Configuring with MaxKeyAgeNoCheck will skip key expiry check
MaxKeyAgeNoCheck = time.Duration(-1)
MaxKeyAgeNoCheck = time.Duration(-1)
globalKeyCacher = map[string]GlobalCacher{}
globalKeyCacherOnce = new(sync.Once)
)

type GlobalCacher struct {
kc KeyCacher
mu *sync.RWMutex
}

func SetGlobalCacher(l logging.Logger, cfg config.ExtraConfig) error {
scfg, err := configGetter(l, cfg)
if err != nil {
if err != ErrNoValidatorCfg {
l.Error("[SERVICE: JOSE]", err.Error())
}
return err
}
duration := time.Duration(scfg.CacheDuration) * time.Second
globalKeyCacherOnce.Do(func() {
globalKeyCacher = map[string]GlobalCacher{
"kid": {kc: NewMemoryKeyCacher(duration, -1, "kid"), mu: new(sync.RWMutex)},
"x5t": {kc: NewMemoryKeyCacher(duration, -1, "x5t"), mu: new(sync.RWMutex)},
"kid_x5t": {kc: NewMemoryKeyCacher(duration, -1, "kid_x5t"), mu: new(sync.RWMutex)},
}
})
return nil
}

type serviceConfig struct {
CacheDuration uint32 `json:"shared_cache_duration"`
}

func configGetter(l logging.Logger, cfg config.ExtraConfig) (serviceConfig, error) {
scfg := serviceConfig{}
e, ok := cfg[ValidatorNamespace].(map[string]interface{})
if !ok {
return scfg, fmt.Errorf("no config")
}
tmp, err := json.Marshal(e)
if err != nil {
return scfg, err
}
if err := json.Unmarshal(tmp, &scfg); err != nil {
return scfg, err
}
if scfg.CacheDuration == 0 {
scfg.CacheDuration = defaultGlobalCacheMaxAge
l.Info("[SERVICE: JOSE] Empty shared_cache_duration, using default (15m)")
}
return scfg, nil
}

// KeyIDGetter extracts a key id from a JSONWebKey
type KeyIDGetter interface {
Get(*jose.JSONWebKey) string
Expand Down Expand Up @@ -74,6 +131,37 @@ type keyCacherEntry struct {
jose.JSONWebKey
}

type GMemoryKeyCacher struct {
*MemoryKeyCacher
Global GlobalCacher
}

func (gkc *GMemoryKeyCacher) Add(keyID string, downloadedKeys []jose.JSONWebKey) (*jose.JSONWebKey, error) {
if gkc.Global.kc != nil {
gkc.Global.mu.Lock()
gkc.Global.kc.Add(keyID, downloadedKeys)
gkc.Global.mu.Unlock()
}

return gkc.MemoryKeyCacher.Add(keyID, downloadedKeys)
}

// Get obtains a key from the cache, and checks if the key is expired
func (gkc *GMemoryKeyCacher) Get(keyID string) (*jose.JSONWebKey, error) {
k, err := gkc.MemoryKeyCacher.Get(keyID)
if err == nil || gkc.Global.kc == nil {
return k, err
}

gkc.Global.mu.RLock()
v, err := gkc.Global.kc.Get(keyID)
gkc.Global.mu.RUnlock()
if err == nil {
gkc.MemoryKeyCacher.Add(keyID, []jose.JSONWebKey{*v})
}
return v, err
}

// NewMemoryKeyCacher creates a new Keycacher interface with option
// to set max age of cached keys and max size of the cache.
func NewMemoryKeyCacher(maxKeyAge time.Duration, maxCacheSize int, keyIdentifyStrategy string) KeyCacher {
Expand All @@ -85,6 +173,26 @@ func NewMemoryKeyCacher(maxKeyAge time.Duration, maxCacheSize int, keyIdentifySt
}
}

func NewGlobalMemoryKeyCacher(maxKeyAge time.Duration, maxCacheSize int, keyIdentifyStrategy string) KeyCacher {
kc := &GMemoryKeyCacher{
MemoryKeyCacher: &MemoryKeyCacher{
entries: map[string]keyCacherEntry{},
maxKeyAge: maxKeyAge,
maxCacheSize: maxCacheSize,
keyIDGetter: KeyIDGetterFactory(keyIdentifyStrategy),
},
Global: GlobalCacher{},
}
if keyIdentifyStrategy == "" {
keyIdentifyStrategy = defaultStrategy
}
if len(globalKeyCacher) > 0 {
g := globalKeyCacher[keyIdentifyStrategy]
kc.Global = g
}
return kc
}

// Get obtains a key from the cache, and checks if the key is expired
func (mkc *MemoryKeyCacher) Get(keyID string) (*jose.JSONWebKey, error) {
searchKey, ok := mkc.entries[keyID]
Expand Down

0 comments on commit e0c32b6

Please sign in to comment.