Skip to content

Commit

Permalink
Improve support for undo/redo functionality in Toolbox (#256)
Browse files Browse the repository at this point in the history
  • Loading branch information
soininen authored Aug 4, 2023
2 parents bf7b3bf + 82b2f22 commit 4168c30
Show file tree
Hide file tree
Showing 2 changed files with 74 additions and 6 deletions.
68 changes: 64 additions & 4 deletions spinedb_api/db_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,10 +105,18 @@ def update_item(self, item):
current_item.cascade_update()

def remove_item(self, id_):
"""Removes item and its referrers from the cache.
Args:
id_ (int): item's database id
Returns:
list of CacheItem: removed items
"""
current_item = self.get(id_)
if current_item:
current_item.cascade_remove()
return current_item
return current_item.cascade_remove()
return []


class CacheItem(dict):
Expand Down Expand Up @@ -146,6 +154,10 @@ def key(self):
return None
return (self._item_type, self["id"])

@property
def referrers(self):
return self._referrers

def __getattr__(self, name):
"""Overridden method to return the dictionary key named after the attribute, or None if it doesn't exist."""
return self.get(name)
Expand Down Expand Up @@ -193,6 +205,32 @@ def get(self, key, default=None):
def copy(self):
return type(self)(self._db_cache, self._item_type, **self)

def deepcopy(self):
"""Makes a deep copy of the item.
Returns:
CacheItem: copied item
"""
copy = self.copy()
self._copy_internal_state(copy)
return copy

def _copy_internal_state(self, other):
"""Copies item's internal state to other cache item.
Args:
other (CacheItem): target item
"""
other._referrers = {key: item.deepcopy() for key, item in self._referrers.items()}
other._weak_referrers = {key: item.deepcopy() for key, item in self._weak_referrers.items()}
other.readd_callbacks = set(self.readd_callbacks)
other.update_callbacks = set(self.update_callbacks)
other.remove_callbacks = set(self.remove_callbacks)
other._to_remove = self._to_remove
other._removed = self._removed
other._corrupted = self._corrupted
other._valid = self._valid

def is_valid(self):
if self._valid is not None:
return self._valid
Expand Down Expand Up @@ -221,23 +259,43 @@ def add_weak_referrer(self, referrer):
if referrer.key not in self._referrers:
self._weak_referrers[referrer.key] = referrer

def readd(self):
"""Adds item back to cache without adding its referrers."""
if not self._removed:
return
self._removed = False
self._to_remove = False
self._call_readd_callbacks()

def cascade_readd(self):
if not self._removed:
return
self._removed = False
self._to_remove = False
for referrer in self._referrers.values():
referrer.cascade_readd()
for weak_referrer in self._weak_referrers.values():
weak_referrer.call_update_callbacks()
self._call_readd_callbacks()

def _call_readd_callbacks(self):
"""Calls readd callbacks and removes obsolete ones."""
obsolete = set()
for callback in self.readd_callbacks:
if not callback(self):
obsolete.add(callback)
self.readd_callbacks -= obsolete

def cascade_remove(self):
"""Sets item and its referrers as removed.
Calls necessary callbacks on weak referrers too.
Returns:
list of CacheItem: removed items
"""
if self._removed:
return
return []
self._removed = True
self._to_remove = False
self._valid = None
Expand All @@ -246,10 +304,12 @@ def cascade_remove(self):
if not callback(self):
obsolete.add(callback)
self.remove_callbacks -= obsolete
removed_items = [self]
for referrer in self._referrers.values():
referrer.cascade_remove()
removed_items += referrer.cascade_remove()
for weak_referrer in self._weak_referrers.values():
weak_referrer.call_update_callbacks()
return removed_items

def cascade_update(self):
self.call_update_callbacks()
Expand Down
12 changes: 10 additions & 2 deletions spinedb_api/db_mapping_remove_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,19 +28,26 @@ def cascade_remove_items(self, cache=None, **kwargs):
Args:
**kwargs: keyword is table name, argument is list of ids to remove
Returns:
list of CacheItem: removed items
"""
cascading_ids = self.cascading_ids(cache=cache, **kwargs)
self.remove_items(**cascading_ids)
return self.remove_items(**cascading_ids)

def remove_items(self, **kwargs):
"""Removes items by id, *not in cascade*.
Args:
**kwargs: keyword is table name, argument is list of ids to remove
Returns:
list of CacheItems: removed items
"""
if not self.committing:
return
self._make_commit_id()
removed_items = []
for tablename, ids in kwargs.items():
if not ids:
continue
Expand All @@ -52,10 +59,11 @@ def remove_items(self, **kwargs):
table_cache = self.cache.get(tablename)
if table_cache:
for id_ in ids:
table_cache.remove_item(id_)
removed_items += table_cache.remove_item(id_)
except DBAPIError as e:
msg = f"DBAPIError while removing {tablename} items: {e.orig.args}"
raise SpineDBAPIError(msg) from e
return removed_items

# pylint: disable=redefined-builtin
def cascading_ids(self, cache=None, **kwargs):
Expand Down

0 comments on commit 4168c30

Please sign in to comment.