feat (WMS): Improve caching performance of Limiter
chrisburr committed Nov 22, 2024
1 parent cb84f65 commit b10bed7
Showing 1 changed file with 121 additions and 13 deletions.
134 changes: 121 additions & 13 deletions src/DIRAC/WorkloadManagementSystem/Client/Limiter.py
@@ -2,6 +2,15 @@
Utilities and classes here are used by the Matcher
"""
import threading
from collections import defaultdict
from collections.abc import Callable
from concurrent.futures import ThreadPoolExecutor, wait, Future
from functools import partial
from typing import Any

from cachetools import TTLCache

from DIRAC import S_OK, S_ERROR
from DIRAC import gLogger

@@ -12,10 +21,109 @@
from DIRAC.WorkloadManagementSystem.Client import JobStatus


class TwoLevelCache:
"""A two-level caching system with soft and hard time-to-live (TTL) expiration.
This cache implements a two-tier caching mechanism to allow for background refresh
of cached values. It uses a soft TTL for quick access and a hard TTL as a fallback,
which helps in reducing latency and maintaining data freshness.
Attributes:
soft_cache (TTLCache): A cache with a shorter TTL for quick access.
hard_cache (TTLCache): A cache with a longer TTL as a fallback.
locks (defaultdict): Thread-safe locks for each cache key.
futures (dict): Stores ongoing asynchronous population tasks.
pool (ThreadPoolExecutor): Thread pool for executing cache population tasks.
Args:
soft_ttl (int): Time-to-live in seconds for the soft cache.
hard_ttl (int): Time-to-live in seconds for the hard cache.
max_workers (int): Maximum number of workers in the thread pool.
max_items (int): Maximum number of items in the cache.
Example:
>>> cache = TwoLevelCache(soft_ttl=60, hard_ttl=300)
>>> def populate_func():
... return "cached_value"
>>> value = cache.get("key", populate_func)
Note:
The cache uses a ThreadPoolExecutor with a maximum of 10 workers to
handle concurrent cache population requests.
"""

def __init__(self, soft_ttl: int, hard_ttl: int, *, max_workers: int = 10, max_items: int = 1_000_000):
"""Initialize the TwoLevelCache with specified TTLs."""
self.soft_cache = TTLCache(max_items, soft_ttl)
self.hard_cache = TTLCache(max_items, hard_ttl)
self.locks = defaultdict(threading.Lock)
self.futures: dict[str, Future] = {}
self.pool = ThreadPoolExecutor(max_workers=max_workers)

def get(self, key: str, populate_func: Callable[[], Any]):
"""Retrieve a value from the cache, populating it if necessary.
This method first checks the soft cache for the key. If not found,
it checks the hard cache while initiating a background refresh.
If the key is not in either cache, it waits for the populate_func
to complete and stores the result in both caches.
Locks are used to ensure there is never more than one concurrent
population task for a given key.
Args:
key (str): The cache key to retrieve or populate.
populate_func (Callable[[], Any]): A function to call to populate the cache
if the key is not found.
Returns:
Any: The cached value associated with the key.
Note:
This method is thread-safe and handles concurrent requests for the same key.
"""
if result := self.soft_cache.get(key):
return result
with self.locks[key]:
if key not in self.futures:
self.futures[key] = self.pool.submit(self._work, key, populate_func)
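            # A refresh is now running in the background; prefer returning a
            # stale hard-cache value immediately over blocking on the refresh.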
if result := self.hard_cache.get(key):
self.soft_cache[key] = result
return result
# It is critical that ``future`` is waited for outside of the lock as
        # _work acquires the lock before filling the caches. This also means
        # we can guarantee that the future has not yet been removed from the
# futures dict.
future = self.futures[key]
wait([future])
return self.hard_cache[key]

def _work(self, key: str, populate_func: Callable[[], Any]) -> None:
"""Internal method to execute the populate_func and update caches.
This method is intended to be run in a separate thread. It calls the
populate_func, stores the result in both caches, and cleans up the
associated future.
Args:
key (str): The cache key to populate.
populate_func (Callable[[], Any]): The function to call to get the value.
Note:
This method is not intended to be called directly by users of the class.
"""
result = populate_func()
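        # Remove the future and fill both tiers while holding the key's lock,
        # so waiters never observe the future gone before the caches are filled.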
with self.locks[key]:
self.futures.pop(key)
self.hard_cache[key] = result
self.soft_cache[key] = result


class Limiter:
# static variables shared between all instances of this class
csDictCache = DictCache()
condCache = DictCache()
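    # soft TTL of 10 s keeps matching fast; hard TTL of 300 s bounds how stale
    # a value can be served while a background refresh runs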
newCache = TwoLevelCache(10, 300)
delayMem = {}

def __init__(self, jobDB=None, opsHelper=None, pilotRef=None):
@@ -177,19 +285,7 @@ def __getRunningCondition(self, siteName, gridCE=None):
if attName not in self.jobDB.jobAttributeNames:
self.log.error("Attribute does not exist", f"({attName}). Check the job limits")
continue
cK = f"Running:{siteName}:{attName}"
data = self.condCache.get(cK)
if not data:
result = self.jobDB.getCounters(
"Jobs",
[attName],
{"Site": siteName, "Status": [JobStatus.RUNNING, JobStatus.MATCHED, JobStatus.STALLED]},
)
if not result["OK"]:
return result
data = result["Value"]
data = {k[0][attName]: k[1] for k in data}
self.condCache.add(cK, 10, data)
data = self.newCache.get(f"Running:{siteName}:{attName}", partial(self._countsByJobType, siteName, attName))
for attValue in limitsDict[attName]:
limit = limitsDict[attName][attValue]
running = data.get(attValue, 0)
@@ -249,3 +345,15 @@ def __getDelayCondition(self, siteName):
negCond[attName] = []
negCond[attName].append(attValue)
return S_OK(negCond)

def _countsByJobType(self, siteName, attName):
result = self.jobDB.getCounters(
"Jobs",
[attName],
{"Site": siteName, "Status": [JobStatus.RUNNING, JobStatus.MATCHED, JobStatus.STALLED]},
)
if not result["OK"]:
return result
data = result["Value"]
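        # getCounters returns a list of ({attName: value}, count) pairs;
        # flatten it into a {value: count} mapping for quick lookups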
data = {k[0][attName]: k[1] for k in data}
return data
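For context, a minimal sketch of how the new cache behaves from a caller's point of view. The key string and count_running_jobs function below are illustrative stand-ins (in the commit the populate function is Limiter._countsByJobType), not part of the change itself:

from DIRAC.WorkloadManagementSystem.Client.Limiter import TwoLevelCache

cache = TwoLevelCache(soft_ttl=10, hard_ttl=300)

def count_running_jobs():
    # Stand-in for an expensive JobDB query, e.g. running-job counts per JobType.
    return {"User": 42, "MCSimulation": 7}

# First call: nothing is cached, so this blocks until count_running_jobs()
# completes and both cache tiers are filled.
counts = cache.get("Running:LCG.CERN.cern:JobType", count_running_jobs)

# Within soft_ttl (10 s) the soft cache answers immediately. Between soft_ttl
# and hard_ttl the stale hard-cache value is returned right away while a
# background worker re-runs count_running_jobs() to refresh both tiers.
counts = cache.get("Running:LCG.CERN.cern:JobType", count_running_jobs)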
