diff --git a/README.md b/README.md index a17cb24..a080035 100644 --- a/README.md +++ b/README.md @@ -141,3 +141,24 @@ func main() { item := cache.Get("key from file") } ``` + +To restrict the cache's capacity based on criteria beyond the number +of items it can hold, the `ttlcache.WithMaxCost` option allows for +implementing custom strategies. The following example demonstrates +how to limit the maximum memory usage of a cache to 5KiB: +```go +import ( + "github.com/jellydator/ttlcache" + "github.com/DmitriyVTitov/size" +) + +func main() { + cache := ttlcache.New[string, string]( + ttlcache.WithMaxCost[string, string](5120, func(item *ttlcache.Item[string, string]) uint64 { + return uint64(size.Of(item)) + }), + ) + + cache.Set("first", "value1", ttlcache.DefaultTTL) +} +``` diff --git a/cache.go b/cache.go index fbfaebe..24ca09a 100644 --- a/cache.go +++ b/cache.go @@ -15,6 +15,7 @@ const ( EvictionReasonDeleted EvictionReason = iota + 1 EvictionReasonCapacityReached EvictionReasonExpired + EvictionReasonMaxCostExceeded ) // EvictionReason is used to specify why a certain item was @@ -36,6 +37,7 @@ type Cache[K comparable, V any] struct { timerCh chan time.Duration } + cost uint64 metricsMu sync.RWMutex metrics Metrics @@ -137,9 +139,20 @@ func (c *Cache[K, V]) set(key K, value V, ttl time.Duration) *Item[K, V] { if elem != nil { // update/overwrite an existing item item := elem.Value.(*Item[K, V]) + oldItemCost := item.cost + item.update(value, ttl) + c.updateExpirations(false, elem) + if c.options.maxCost != 0 { + c.cost = c.cost - oldItemCost + item.cost + + for c.cost > c.options.maxCost { + c.evict(EvictionReasonMaxCostExceeded, c.items.lru.Back()) + } + } + return item } @@ -153,11 +166,19 @@ func (c *Cache[K, V]) set(key K, value V, ttl time.Duration) *Item[K, V] { } // create a new item - item := NewItem(key, value, ttl, c.options.enableVersionTracking) + item := newItemWithOpts(key, value, ttl, c.options.itemOpts...) elem = c.items.lru.PushFront(item) c.items.values[key] = elem c.updateExpirations(true, elem) + if c.options.maxCost != 0 { + c.cost += item.cost + + for c.cost > c.options.maxCost { + c.evict(EvictionReasonMaxCostExceeded, c.items.lru.Back()) + } + } + c.metricsMu.Lock() c.metrics.Insertions++ c.metricsMu.Unlock() @@ -258,6 +279,11 @@ func (c *Cache[K, V]) evict(reason EvictionReason, elems ...*list.Element) { for i := range elems { item := elems[i].Value.(*Item[K, V]) delete(c.items.values, item.key) + + if c.options.maxCost != 0 { + c.cost -= item.cost + } + c.items.lru.Remove(elems[i]) c.items.expQueue.remove(elems[i]) diff --git a/cache_test.go b/cache_test.go index 68bd340..ba58c3f 100644 --- a/cache_test.go +++ b/cache_test.go @@ -122,7 +122,7 @@ func Test_Cache_updateExpirations(t *testing.T) { t.Run(cn, func(t *testing.T) { t.Parallel() - cache := prepCache(time.Hour) + cache := prepCache(0, time.Hour) if c.TimerChValue > 0 { cache.items.timerCh <- c.TimerChValue @@ -172,6 +172,7 @@ func Test_Cache_set(t *testing.T) { cc := map[string]struct { Capacity uint64 + MaxCost uint64 Key string TTL time.Duration Metrics Metrics @@ -244,6 +245,33 @@ func Test_Cache_set(t *testing.T) { }, ExpectFns: true, }, + "Set with existing key and eviction caused by exhausted cost": { + MaxCost: 30, + Key: existingKey, + TTL: DefaultTTL, + Metrics: Metrics{ + Insertions: 0, + Evictions: 1, + }, + }, + "Set with existing key and no eviction": { + MaxCost: 50, + Key: existingKey, + TTL: DefaultTTL, + Metrics: Metrics{ + Insertions: 0, + Evictions: 0, + }, + }, + "Set with new key and eviction caused by exhausted cost": { + MaxCost: 40, + Key: newKey, + TTL: DefaultTTL, + Metrics: Metrics{ + Insertions: 1, + Evictions: 1, + }, + }, } for cn, c := range cc { @@ -260,7 +288,7 @@ func Test_Cache_set(t *testing.T) { // calculated based on how addToCache sets ttl existingKeyTTL := time.Hour + time.Minute - cache := prepCache(time.Hour, evictedKey, existingKey, "test3") + cache := prepCache(c.MaxCost, time.Hour, evictedKey, existingKey, "test3") cache.options.capacity = c.Capacity cache.options.ttl = time.Minute * 20 cache.events.insertion.fns[1] = func(item *Item[string, string]) { @@ -269,16 +297,18 @@ func Test_Cache_set(t *testing.T) { } cache.events.insertion.fns[2] = cache.events.insertion.fns[1] cache.events.eviction.fns[1] = func(r EvictionReason, item *Item[string, string]) { - assert.Equal(t, EvictionReasonCapacityReached, r) + if c.MaxCost != 0 { + assert.Equal(t, EvictionReasonMaxCostExceeded, r) + } else { + assert.Equal(t, EvictionReasonCapacityReached, r) + } + assert.Equal(t, evictedKey, item.key) evictionFnsCalls++ } cache.events.eviction.fns[2] = cache.events.eviction.fns[1] - total := 3 - if c.Key == newKey && (c.Capacity == 0 || c.Capacity >= 4) { - total++ - } + total := 3 - int(c.Metrics.Evictions) + int(c.Metrics.Insertions) item := cache.set(c.Key, "value123", c.TTL) @@ -390,7 +420,7 @@ func Test_Cache_get(t *testing.T) { t.Run(cn, func(t *testing.T) { t.Parallel() - cache := prepCache(time.Hour, existingKey, "test2", "test3") + cache := prepCache(0, time.Hour, existingKey, "test2", "test3") addToCache(cache, time.Nanosecond, expiredKey) time.Sleep(time.Millisecond) // force expiration @@ -441,7 +471,7 @@ func Test_Cache_evict(t *testing.T) { key4FnsCalls int ) - cache := prepCache(time.Hour, "1", "2", "3", "4") + cache := prepCache(0, time.Hour, "1", "2", "3", "4") cache.events.eviction.fns[1] = func(r EvictionReason, item *Item[string, string]) { assert.Equal(t, EvictionReasonDeleted, r) switch item.key { @@ -486,7 +516,7 @@ func Test_Cache_evict(t *testing.T) { } func Test_Cache_Set(t *testing.T) { - cache := prepCache(time.Hour, "test1", "test2", "test3") + cache := prepCache(0, time.Hour, "test1", "test2", "test3") item := cache.Set("hello", "value123", time.Minute) require.NotNil(t, item) assert.Same(t, item, cache.items.values["hello"].Value) @@ -599,7 +629,7 @@ func Test_Cache_Get(t *testing.T) { t.Run(cn, func(t *testing.T) { t.Parallel() - cache := prepCache(time.Minute, foundKey, "test2", "test3") + cache := prepCache(0, time.Minute, foundKey, "test2", "test3") oldExpiresAt := cache.items.values[foundKey].Value.(*Item[string, string]).expiresAt cache.options = c.DefaultOptions @@ -632,7 +662,7 @@ func Test_Cache_Get(t *testing.T) { func Test_Cache_Delete(t *testing.T) { var fnsCalls int - cache := prepCache(time.Hour, "1", "2", "3", "4") + cache := prepCache(0, time.Hour, "1", "2", "3", "4") cache.events.eviction.fns[1] = func(r EvictionReason, item *Item[string, string]) { assert.Equal(t, EvictionReasonDeleted, r) fnsCalls++ @@ -652,7 +682,7 @@ func Test_Cache_Delete(t *testing.T) { } func Test_Cache_Has(t *testing.T) { - cache := prepCache(time.Hour, "1") + cache := prepCache(0, time.Hour, "1") addToCache(cache, time.Nanosecond, "2") assert.True(t, cache.Has("1")) @@ -661,7 +691,7 @@ func Test_Cache_Has(t *testing.T) { } func Test_Cache_GetOrSet(t *testing.T) { - cache := prepCache(time.Hour) + cache := prepCache(0, time.Hour) item, retrieved := cache.GetOrSet("test", "1", WithTTL[string, string](time.Minute)) require.NotNil(t, item) assert.Same(t, item, cache.items.values["test"].Value) @@ -685,7 +715,7 @@ func Test_Cache_GetOrSet(t *testing.T) { } func Test_Cache_GetAndDelete(t *testing.T) { - cache := prepCache(time.Hour, "test1", "test2", "test3") + cache := prepCache(0, time.Hour, "test1", "test2", "test3") listItem := cache.items.lru.Front() require.NotNil(t, listItem) assert.Same(t, listItem, cache.items.values["test3"]) @@ -721,7 +751,7 @@ func Test_Cache_DeleteAll(t *testing.T) { key4FnsCalls int ) - cache := prepCache(time.Hour, "1", "2", "3", "4") + cache := prepCache(0, time.Hour, "1", "2", "3", "4") cache.events.eviction.fns[1] = func(r EvictionReason, item *Item[string, string]) { assert.Equal(t, EvictionReasonDeleted, r) switch item.key { @@ -751,7 +781,7 @@ func Test_Cache_DeleteExpired(t *testing.T) { key2FnsCalls int ) - cache := prepCache(time.Hour) + cache := prepCache(0, time.Hour) cache.events.eviction.fns[1] = func(r EvictionReason, item *Item[string, string]) { assert.Equal(t, EvictionReasonExpired, r) switch item.key { @@ -792,7 +822,7 @@ func Test_Cache_DeleteExpired(t *testing.T) { } func Test_Cache_Touch(t *testing.T) { - cache := prepCache(time.Hour, "1", "2") + cache := prepCache(0, time.Hour, "1", "2") oldExpiresAt := cache.items.values["1"].Value.(*Item[string, string]).expiresAt cache.Touch("1") @@ -803,7 +833,7 @@ func Test_Cache_Touch(t *testing.T) { } func Test_Cache_Len(t *testing.T) { - cache := prepCache(time.Hour) + cache := prepCache(0, time.Hour) assert.Equal(t, 0, cache.Len()) addToCache(cache, time.Hour, "1") @@ -820,13 +850,13 @@ func Test_Cache_Len(t *testing.T) { } func Test_Cache_Keys(t *testing.T) { - cache := prepCache(time.Hour, "1", "2", "3") + cache := prepCache(0, time.Hour, "1", "2", "3") addToCache(cache, time.Nanosecond, "4") assert.ElementsMatch(t, []string{"1", "2", "3"}, cache.Keys()) } func Test_Cache_Items(t *testing.T) { - cache := prepCache(time.Hour, "1", "2", "3") + cache := prepCache(0, time.Hour, "1", "2", "3") addToCache(cache, time.Nanosecond, "4") items := cache.Items() require.Len(t, items, 3) @@ -840,7 +870,7 @@ func Test_Cache_Items(t *testing.T) { } func Test_Cache_Range(t *testing.T) { - c := prepCache(DefaultTTL, "1", "2", "3", "4", "5") + c := prepCache(0, DefaultTTL, "1", "2", "3", "4", "5") addToCache(c, time.Nanosecond, "6") var results []string @@ -860,7 +890,7 @@ func Test_Cache_Range(t *testing.T) { } func Test_Cache_RangeBackwards(t *testing.T) { - c := prepCache(DefaultTTL) + c := prepCache(0, DefaultTTL) addToCache(c, time.Nanosecond, "1") addToCache(c, time.Hour, "2", "3", "4", "5") @@ -890,7 +920,7 @@ func Test_Cache_Metrics(t *testing.T) { } func Test_Cache_Start(t *testing.T) { - cache := prepCache(0) + cache := prepCache(0, 0) cache.stopCh = make(chan struct{}) addToCache(cache, time.Nanosecond, "1") @@ -938,7 +968,7 @@ func Test_Cache_Stop(t *testing.T) { func Test_Cache_OnInsertion(t *testing.T) { checkCh := make(chan struct{}) resCh := make(chan struct{}) - cache := prepCache(time.Hour) + cache := prepCache(0, time.Hour) del1 := cache.OnInsertion(func(_ context.Context, _ *Item[string, string]) { checkCh <- struct{}{} }) @@ -1022,7 +1052,7 @@ func Test_Cache_OnInsertion(t *testing.T) { func Test_Cache_OnEviction(t *testing.T) { checkCh := make(chan struct{}) resCh := make(chan struct{}) - cache := prepCache(time.Hour) + cache := prepCache(0, time.Hour) del1 := cache.OnEviction(func(_ context.Context, _ EvictionReason, _ *Item[string, string]) { checkCh <- struct{}{} }) @@ -1181,7 +1211,7 @@ func Test_SuppressedLoader_Load(t *testing.T) { item1, item2 *Item[string, string] ) - cache := prepCache(time.Hour) + cache := prepCache(0, time.Hour) // nil result wg.Add(2) @@ -1228,9 +1258,20 @@ func Test_SuppressedLoader_Load(t *testing.T) { assert.Equal(t, 1, loadCalls) } -func prepCache(ttl time.Duration, keys ...string) *Cache[string, string] { +func prepCache(maxCost uint64, ttl time.Duration, keys ...string) *Cache[string, string] { c := &Cache[string, string]{} c.options.ttl = ttl + c.options.itemOpts = append(c.options.itemOpts, + withVersionTracking[string, string](false)) + + if maxCost != 0 { + c.options.maxCost = maxCost + c.options.itemOpts = append(c.options.itemOpts, + withCostFunc[string, string](func(item *Item[string, string]) uint64 { + return uint64(len(item.value)) + })) + } + c.items.values = make(map[string]*list.Element) c.items.lru = list.New() c.items.expQueue = newExpirationQueue[string, string]() @@ -1245,14 +1286,19 @@ func prepCache(ttl time.Duration, keys ...string) *Cache[string, string] { func addToCache(c *Cache[string, string], ttl time.Duration, keys ...string) { for i, key := range keys { - item := NewItem( + value := fmt.Sprint("value of", key) + item := newItemWithOpts( key, - fmt.Sprint("value of", key), + value, ttl+time.Duration(i)*time.Minute, - false, + c.options.itemOpts..., ) elem := c.items.lru.PushFront(item) c.items.values[key] = elem c.items.expQueue.push(elem) + + if c.options.maxCost != 0 { + c.cost += item.cost + } } } diff --git a/item.go b/item.go index 92eb73b..a72add5 100644 --- a/item.go +++ b/item.go @@ -30,28 +30,38 @@ type Item[K comparable, V any] struct { // well, so locking this mutex would be redundant. // In other words, this mutex is only useful when these fields // are being read from the outside (e.g. in event functions). - mu sync.RWMutex - key K - value V - ttl time.Duration - expiresAt time.Time - queueIndex int - version int64 + mu sync.RWMutex + key K + value V + ttl time.Duration + expiresAt time.Time + queueIndex int + version int64 + calculateCost CostFunc[K, V] + cost uint64 } // NewItem creates a new cache item. func NewItem[K comparable, V any](key K, value V, ttl time.Duration, enableVersionTracking bool) *Item[K, V] { + return newItemWithOpts(key, value, ttl, withVersionTracking[K, V](enableVersionTracking)) +} + +// newItemWithOpts creates a new cache item. +func newItemWithOpts[K comparable, V any](key K, value V, ttl time.Duration, opts ...itemOption[K, V]) *Item[K, V] { item := &Item[K, V]{ - key: key, - value: value, - ttl: ttl, + key: key, + value: value, + ttl: ttl, + version: -1, + calculateCost: func(item *Item[K, V]) uint64 { return 0 }, } - if !enableVersionTracking { - item.version = -1 + for _, opt := range opts { + opt.apply(item) } item.touch() + item.cost = item.calculateCost(item) return item } @@ -62,6 +72,7 @@ func (item *Item[K, V]) update(value V, ttl time.Duration) { defer item.mu.Unlock() item.value = value + item.cost = item.calculateCost(item) // update version if enabled if item.version > -1 { diff --git a/item_test.go b/item_test.go index 7e7699e..02becb3 100644 --- a/item_test.go +++ b/item_test.go @@ -9,6 +9,8 @@ import ( ) func Test_NewItem(t *testing.T) { + t.Parallel() + item := NewItem("key", 123, time.Hour, false) require.NotNil(t, item) assert.Equal(t, "key", item.key) @@ -18,33 +20,166 @@ func Test_NewItem(t *testing.T) { assert.WithinDuration(t, time.Now().Add(time.Hour), item.expiresAt, time.Minute) } -func Test_Item_update(t *testing.T) { - item := Item[string, string]{ - expiresAt: time.Now().Add(-time.Hour), - value: "hello", - version: 0, +func Test_newItemWithOpts(t *testing.T) { + t.Parallel() + + for _, tc := range []struct { + uc string + opts []itemOption[string, int] + assert func(t *testing.T, item *Item[string, int]) + }{ + { + uc: "item without any options", + assert: func(t *testing.T, item *Item[string, int]) { + assert.Equal(t, int64(-1), item.version) + assert.Equal(t, uint64(0), item.cost) + require.NotNil(t, item.calculateCost) + assert.Equal(t, uint64(0), item.calculateCost(item)) + }, + }, + { + uc: "item with version tracking disabled", + opts: []itemOption[string, int]{ + withVersionTracking[string, int](false), + }, + assert: func(t *testing.T, item *Item[string, int]) { + assert.Equal(t, int64(-1), item.version) + assert.Equal(t, uint64(0), item.cost) + require.NotNil(t, item.calculateCost) + assert.Equal(t, uint64(0), item.calculateCost(item)) + }, + }, + { + uc: "item with version tracking explicitly enabled", + opts: []itemOption[string, int]{ + withVersionTracking[string, int](true), + }, + assert: func(t *testing.T, item *Item[string, int]) { + assert.Equal(t, int64(0), item.version) + assert.Equal(t, uint64(0), item.cost) + require.NotNil(t, item.calculateCost) + assert.Equal(t, uint64(0), item.calculateCost(item)) + }, + }, + { + uc: "item with cost calculation", + opts: []itemOption[string, int]{ + withCostFunc[string, int](func(item *Item[string, int]) uint64 { return 5 }), + }, + assert: func(t *testing.T, item *Item[string, int]) { + assert.Equal(t, int64(-1), item.version) + assert.Equal(t, uint64(5), item.cost) + require.NotNil(t, item.calculateCost) + assert.Equal(t, uint64(5), item.calculateCost(item)) + }, + }, + } { + t.Run(tc.uc, func(t *testing.T) { + item := newItemWithOpts("key", 123, time.Hour, tc.opts...) + require.NotNil(t, item) + assert.Equal(t, "key", item.key) + assert.Equal(t, 123, item.value) + assert.Equal(t, time.Hour, item.ttl) + assert.WithinDuration(t, time.Now().Add(time.Hour), item.expiresAt, time.Minute) + tc.assert(t, item) + }) } +} - item.update("test", time.Hour) - assert.Equal(t, "test", item.value) - assert.Equal(t, time.Hour, item.ttl) - assert.Equal(t, int64(1), item.version) - assert.WithinDuration(t, time.Now().Add(time.Hour), item.expiresAt, time.Minute) +func Test_Item_update(t *testing.T) { + t.Parallel() - item.update("previous ttl", PreviousOrDefaultTTL) - assert.Equal(t, "previous ttl", item.value) - assert.Equal(t, time.Hour, item.ttl) - assert.Equal(t, int64(2), item.version) - assert.WithinDuration(t, time.Now().Add(time.Hour), item.expiresAt, time.Minute) + initialTTL := -1 * time.Hour + newValue := "world" + + for _, tc := range []struct { + uc string + opts []itemOption[string, string] + ttl time.Duration + assert func(t *testing.T, item *Item[string, string]) + }{ + { + uc: "with expiration in an hour", + ttl: time.Hour, + assert: func(t *testing.T, item *Item[string, string]) { + t.Helper() + + assert.Equal(t, uint64(0), item.cost) + assert.Equal(t, time.Hour, item.ttl) + assert.Equal(t, int64(-1), item.version) + assert.WithinDuration(t, time.Now().Add(time.Hour), item.expiresAt, time.Minute) + }, + }, + { + uc: "with previous or default ttl", + ttl: PreviousOrDefaultTTL, + assert: func(t *testing.T, item *Item[string, string]) { + t.Helper() + + assert.Equal(t, uint64(0), item.cost) + assert.Equal(t, initialTTL, item.ttl) + assert.Equal(t, int64(-1), item.version) + }, + }, + { + uc: "with no ttl", + ttl: NoTTL, + assert: func(t *testing.T, item *Item[string, string]) { + t.Helper() + + assert.Equal(t, uint64(0), item.cost) + assert.Equal(t, NoTTL, item.ttl) + assert.Equal(t, int64(-1), item.version) + assert.Zero(t, item.expiresAt) + }, + }, + { + uc: "with version tracking explicitly disabled", + opts: []itemOption[string, string]{ + withVersionTracking[string, string](false), + }, + ttl: time.Hour, + assert: func(t *testing.T, item *Item[string, string]) { + t.Helper() + + assert.Equal(t, uint64(0), item.cost) + assert.Equal(t, time.Hour, item.ttl) + assert.Equal(t, int64(-1), item.version) + assert.WithinDuration(t, time.Now().Add(time.Hour), item.expiresAt, time.Minute) + }, + }, + { + uc: "with version calculation and version tracking", + opts: []itemOption[string, string]{ + withVersionTracking[string, string](true), + withCostFunc[string, string](func(item *Item[string, string]) uint64 { return uint64(len(item.value)) }), + }, + ttl: time.Hour, + assert: func(t *testing.T, item *Item[string, string]) { + t.Helper() + + assert.Equal(t, uint64(len(newValue)), item.cost) + assert.Equal(t, time.Hour, item.ttl) + assert.Equal(t, int64(1), item.version) + assert.WithinDuration(t, time.Now().Add(time.Hour), item.expiresAt, time.Minute) + }, + }, + } { + t.Run(tc.uc, func(t *testing.T) { + item := newItemWithOpts[string, string]("test", "hello", initialTTL, tc.opts...) + + item.update(newValue, tc.ttl) + + assert.Equal(t, newValue, item.value) + tc.assert(t, item) + }) + } - item.update("hi", NoTTL) - assert.Equal(t, "hi", item.value) - assert.Equal(t, NoTTL, item.ttl) - assert.Equal(t, int64(3), item.version) - assert.Zero(t, item.expiresAt) } func Test_Item_touch(t *testing.T) { + t.Parallel() + var item Item[string, string] item.touch() assert.Equal(t, int64(0), item.version) @@ -57,6 +192,8 @@ func Test_Item_touch(t *testing.T) { } func Test_Item_IsExpired(t *testing.T) { + t.Parallel() + // no ttl item := Item[string, string]{ expiresAt: time.Now().Add(-time.Hour), @@ -74,6 +211,8 @@ func Test_Item_IsExpired(t *testing.T) { } func Test_Item_Key(t *testing.T) { + t.Parallel() + item := Item[string, string]{ key: "test", } @@ -82,6 +221,8 @@ func Test_Item_Key(t *testing.T) { } func Test_Item_Value(t *testing.T) { + t.Parallel() + item := Item[string, string]{ value: "test", } @@ -90,6 +231,8 @@ func Test_Item_Value(t *testing.T) { } func Test_Item_TTL(t *testing.T) { + t.Parallel() + item := Item[string, string]{ ttl: time.Hour, } @@ -98,6 +241,8 @@ func Test_Item_TTL(t *testing.T) { } func Test_Item_ExpiresAt(t *testing.T) { + t.Parallel() + now := time.Now() item := Item[string, string]{ expiresAt: now, @@ -107,6 +252,8 @@ func Test_Item_ExpiresAt(t *testing.T) { } func Test_Item_Version(t *testing.T) { + t.Parallel() + item := Item[string, string]{version: 5} assert.Equal(t, int64(5), item.Version()) } diff --git a/options.go b/options.go index 8a6088c..6607a09 100644 --- a/options.go +++ b/options.go @@ -15,13 +15,18 @@ func (fn optionFunc[K, V]) apply(opts *options[K, V]) { fn(opts) } +// CostFunc is used to calculate the cost of the key and the item to be +// inserted into the cache. +type CostFunc[K comparable, V any] func(item *Item[K, V]) uint64 + // options holds all available cache configuration options. type options[K comparable, V any] struct { - capacity uint64 - ttl time.Duration - loader Loader[K, V] - disableTouchOnHit bool - enableVersionTracking bool + capacity uint64 + maxCost uint64 + ttl time.Duration + loader Loader[K, V] + disableTouchOnHit bool + itemOpts []itemOption[K, V] } // applyOptions applies the provided option values to the option struct. @@ -52,7 +57,7 @@ func WithTTL[K comparable, V any](ttl time.Duration) Option[K, V] { // It has no effect when used with Get(). func WithVersion[K comparable, V any](enable bool) Option[K, V] { return optionFunc[K, V](func(opts *options[K, V]) { - opts.enableVersionTracking = enable + opts.itemOpts = append(opts.itemOpts, withVersionTracking[K, V](enable)) }) } @@ -75,3 +80,48 @@ func WithDisableTouchOnHit[K comparable, V any]() Option[K, V] { opts.disableTouchOnHit = true }) } + +// WithMaxCost sets the maximum cost the cache is allowed to use (e.g. the used memory). +// The actual cost calculation for each inserted item happens by making use of the +// callback CostFunc. +func WithMaxCost[K comparable, V any](s uint64, callback CostFunc[K, V]) Option[K, V] { + return optionFunc[K, V](func(opts *options[K, V]) { + opts.maxCost = s + opts.itemOpts = append(opts.itemOpts, withCostFunc[K, V](callback)) + }) +} + +// itemOption represents an option to be applied to an Item on creation +type itemOption[K comparable, V any] interface { + apply(item *Item[K, V]) +} + +// itemOptionFunc wraps a function and implements the itemOption interface. +type itemOptionFunc[K comparable, V any] func(*Item[K, V]) + +// apply calls the wrapped function. +func (fn itemOptionFunc[K, V]) apply(item *Item[K, V]) { + fn(item) +} + +// withVersionTracking deactivates ot activates item version tracking. +// If version tracking is disabled, the version is always -1. +// It has no effect when used with Get(). +func withVersionTracking[K comparable, V any](enable bool) itemOption[K, V] { + return itemOptionFunc[K, V](func(item *Item[K, V]) { + if enable { + item.version = 0 + } else { + item.version = -1 + } + }) +} + +// withCostFunc configures the cost calculation function for an item +func withCostFunc[K comparable, V any](costFunc CostFunc[K, V]) itemOption[K, V] { + return itemOptionFunc[K, V](func(item *Item[K, V]) { + if costFunc != nil { + item.calculateCost = costFunc + } + }) +} diff --git a/options_test.go b/options_test.go index 8cf0fb2..478dc40 100644 --- a/options_test.go +++ b/options_test.go @@ -5,9 +5,12 @@ import ( "time" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func Test_optionFunc_apply(t *testing.T) { + t.Parallel() + var called bool optionFunc[string, string](func(_ *options[string, string]) { @@ -17,6 +20,8 @@ func Test_optionFunc_apply(t *testing.T) { } func Test_applyOptions(t *testing.T) { + t.Parallel() + var opts options[string, string] applyOptions(&opts, @@ -29,6 +34,8 @@ func Test_applyOptions(t *testing.T) { } func Test_WithCapacity(t *testing.T) { + t.Parallel() + var opts options[string, string] WithCapacity[string, string](12).apply(&opts) @@ -36,6 +43,8 @@ func Test_WithCapacity(t *testing.T) { } func Test_WithTTL(t *testing.T) { + t.Parallel() + var opts options[string, string] WithTTL[string, string](time.Hour).apply(&opts) @@ -43,16 +52,26 @@ func Test_WithTTL(t *testing.T) { } func Test_WithVersion(t *testing.T) { + t.Parallel() + var opts options[string, string] + var item Item[string, string] WithVersion[string, string](true).apply(&opts) - assert.Equal(t, true, opts.enableVersionTracking) + assert.Len(t, opts.itemOpts, 1) + opts.itemOpts[0].apply(&item) + assert.Equal(t, int64(0), item.version) + opts.itemOpts = []itemOption[string, string]{} WithVersion[string, string](false).apply(&opts) - assert.Equal(t, false, opts.enableVersionTracking) + assert.Len(t, opts.itemOpts, 1) + opts.itemOpts[0].apply(&item) + assert.Equal(t, int64(-1), item.version) } func Test_WithLoader(t *testing.T) { + t.Parallel() + var opts options[string, string] l := LoaderFunc[string, string](func(_ *Cache[string, string], _ string) *Item[string, string] { @@ -63,8 +82,54 @@ func Test_WithLoader(t *testing.T) { } func Test_WithDisableTouchOnHit(t *testing.T) { + t.Parallel() + var opts options[string, string] WithDisableTouchOnHit[string, string]().apply(&opts) assert.True(t, opts.disableTouchOnHit) } + +func Test_WithMaxCost(t *testing.T) { + t.Parallel() + + var opts options[string, string] + var item Item[string, string] + + WithMaxCost[string, string](1024, func(item *Item[string, string]) uint64 { return 1 }).apply(&opts) + + assert.Equal(t, uint64(1024), opts.maxCost) + assert.Len(t, opts.itemOpts, 1) + opts.itemOpts[0].apply(&item) + assert.Equal(t, uint64(0), item.cost) + assert.NotNil(t, item.calculateCost) + assert.Equal(t, uint64(1), item.calculateCost(&item)) +} + +func Test_withVersionTracking(t *testing.T) { + t.Parallel() + + var item Item[string, string] + + opt := withVersionTracking[string, string](false) + opt.apply(&item) + assert.Equal(t, int64(-1), item.version) + + opt = withVersionTracking[string, string](true) + opt.apply(&item) + assert.Equal(t, int64(0), item.version) +} + +func Test_withCostFunc(t *testing.T) { + t.Parallel() + + var item Item[string, string] + + opt := withCostFunc[string, string](func(item *Item[string, string]) uint64 { + return 10 + }) + opt.apply(&item) + assert.Equal(t, uint64(0), item.cost) + require.NotNil(t, item.calculateCost) + assert.Equal(t, uint64(10), item.calculateCost(&item)) +}