diff --git a/spinedb_api/db_cache.py b/spinedb_api/db_cache.py index 7123db42..68fe18ef 100644 --- a/spinedb_api/db_cache.py +++ b/spinedb_api/db_cache.py @@ -105,10 +105,18 @@ def update_item(self, item): current_item.cascade_update() def remove_item(self, id_): + """Removes item and its referrers from the cache. + + Args: + id_ (int): item's database id + + Returns: + list of CacheItem: removed items + """ current_item = self.get(id_) if current_item: - current_item.cascade_remove() - return current_item + return current_item.cascade_remove() + return [] class CacheItem(dict): @@ -146,6 +154,10 @@ def key(self): return None return (self._item_type, self["id"]) + @property + def referrers(self): + return self._referrers + def __getattr__(self, name): """Overridden method to return the dictionary key named after the attribute, or None if it doesn't exist.""" return self.get(name) @@ -193,6 +205,32 @@ def get(self, key, default=None): def copy(self): return type(self)(self._db_cache, self._item_type, **self) + def deepcopy(self): + """Makes a deep copy of the item. + + Returns: + CacheItem: copied item + """ + copy = self.copy() + self._copy_internal_state(copy) + return copy + + def _copy_internal_state(self, other): + """Copies item's internal state to other cache item. + + Args: + other (CacheItem): target item + """ + other._referrers = {key: item.deepcopy() for key, item in self._referrers.items()} + other._weak_referrers = {key: item.deepcopy() for key, item in self._weak_referrers.items()} + other.readd_callbacks = set(self.readd_callbacks) + other.update_callbacks = set(self.update_callbacks) + other.remove_callbacks = set(self.remove_callbacks) + other._to_remove = self._to_remove + other._removed = self._removed + other._corrupted = self._corrupted + other._valid = self._valid + def is_valid(self): if self._valid is not None: return self._valid @@ -221,14 +259,27 @@ def add_weak_referrer(self, referrer): if referrer.key not in self._referrers: self._weak_referrers[referrer.key] = referrer + def readd(self): + """Adds item back to cache without adding its referrers.""" + if not self._removed: + return + self._removed = False + self._to_remove = False + self._call_readd_callbacks() + def cascade_readd(self): if not self._removed: return self._removed = False + self._to_remove = False for referrer in self._referrers.values(): referrer.cascade_readd() for weak_referrer in self._weak_referrers.values(): weak_referrer.call_update_callbacks() + self._call_readd_callbacks() + + def _call_readd_callbacks(self): + """Calls readd callbacks and removes obsolete ones.""" obsolete = set() for callback in self.readd_callbacks: if not callback(self): @@ -236,8 +287,15 @@ def cascade_readd(self): self.readd_callbacks -= obsolete def cascade_remove(self): + """Sets item and its referrers as removed. + + Calls necessary callbacks on weak referrers too. + + Returns: + list of CacheItem: removed items + """ if self._removed: - return + return [] self._removed = True self._to_remove = False self._valid = None @@ -246,10 +304,12 @@ def cascade_remove(self): if not callback(self): obsolete.add(callback) self.remove_callbacks -= obsolete + removed_items = [self] for referrer in self._referrers.values(): - referrer.cascade_remove() + removed_items += referrer.cascade_remove() for weak_referrer in self._weak_referrers.values(): weak_referrer.call_update_callbacks() + return removed_items def cascade_update(self): self.call_update_callbacks() diff --git a/spinedb_api/db_mapping_remove_mixin.py b/spinedb_api/db_mapping_remove_mixin.py index d2aa30eb..f4718a3f 100644 --- a/spinedb_api/db_mapping_remove_mixin.py +++ b/spinedb_api/db_mapping_remove_mixin.py @@ -28,19 +28,26 @@ def cascade_remove_items(self, cache=None, **kwargs): Args: **kwargs: keyword is table name, argument is list of ids to remove + + Returns: + list of CacheItem: removed items """ cascading_ids = self.cascading_ids(cache=cache, **kwargs) - self.remove_items(**cascading_ids) + return self.remove_items(**cascading_ids) def remove_items(self, **kwargs): """Removes items by id, *not in cascade*. Args: **kwargs: keyword is table name, argument is list of ids to remove + + Returns: + list of CacheItems: removed items """ if not self.committing: return self._make_commit_id() + removed_items = [] for tablename, ids in kwargs.items(): if not ids: continue @@ -52,10 +59,11 @@ def remove_items(self, **kwargs): table_cache = self.cache.get(tablename) if table_cache: for id_ in ids: - table_cache.remove_item(id_) + removed_items += table_cache.remove_item(id_) except DBAPIError as e: msg = f"DBAPIError while removing {tablename} items: {e.orig.args}" raise SpineDBAPIError(msg) from e + return removed_items # pylint: disable=redefined-builtin def cascading_ids(self, cache=None, **kwargs):