Skip to content

Commit

Permalink
Annotate the remainder of the tracker module
Browse files Browse the repository at this point in the history
  • Loading branch information
mthuurne committed Mar 21, 2023
1 parent e60e881 commit 998eb70
Showing 1 changed file with 95 additions and 53 deletions.
148 changes: 95 additions & 53 deletions model_utils/tracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,25 @@

from copy import deepcopy
from functools import wraps
from typing import TYPE_CHECKING, Any, cast
from typing import TYPE_CHECKING, Any, TypeVar, cast, overload

from django.core.exceptions import FieldError
from django.db import models
from django.db.models.fields.files import FieldFile, FileDescriptor
from django.db.models.query_utils import DeferredAttribute

if TYPE_CHECKING:
from collections.abc import Callable, Iterable, Mapping
from types import TracebackType

class _AugmentedModel(models.Model):
_instance_initialized: bool
_deferred_fields: set[str]


T = TypeVar("T")


class LightStateFieldFile(FieldFile):
"""
FieldFile subclass with the only aim to remove the instance from the state.
Expand All @@ -24,32 +31,34 @@ class LightStateFieldFile(FieldFile):
Django 3.1+ can make the app unusable, as CPU and memory usage gets easily
multiplied by magnitudes.
"""
def __getstate__(self):
def __getstate__(self) -> dict[str, Any]:
"""
We don't need to deepcopy the instance, so nullify if provided.
"""
state = super().__getstate__()
# django-stubs 1.16.0 doesn't annotate __getstate__(), but it does exist
# in Django itself.
state = super().__getstate__() # type: ignore[misc]
if 'instance' in state:
state['instance'] = None
return state


def lightweight_deepcopy(value):
def lightweight_deepcopy(value: T) -> T:
"""
Use our lightweight class to avoid copying the instance on a FieldFile deepcopy.
"""
if isinstance(value, FieldFile):
value = LightStateFieldFile(
value = cast(T, LightStateFieldFile(
instance=value.instance,
field=value.field,
name=value.name,
)
))
return deepcopy(value)


class DescriptorMixin:
field_name: str
tracker_instance: Any = None
tracker_instance: FieldInstanceTracker

def __get__(
self,
Expand All @@ -75,12 +84,20 @@ def _get_field_name(self) -> str:

class DescriptorWrapper:

def __init__(self, field_name, descriptor, tracker_attname):
def __init__(self, field_name: str, descriptor: models.Field, tracker_attname: str):
self.field_name = field_name
self.descriptor = descriptor
self.tracker_attname = tracker_attname

def __get__(self, instance, owner):
@overload
def __get__(self, instance: None, owner: type[models.Model]) -> DescriptorWrapper:
...

@overload
def __get__(self, instance: models.Model, owner: type[models.Model]) -> models.Field:
...

def __get__(self, instance: models.Model | None, owner: type[models.Model]) -> DescriptorWrapper | models.Field:
if instance is None:
return self
was_deferred = self.field_name in instance.get_deferred_fields()
Expand All @@ -93,7 +110,7 @@ def __get__(self, instance, owner):
tracker_instance.saved_data[self.field_name] = lightweight_deepcopy(value)
return value

def __set__(self, instance, value):
def __set__(self, instance: models.Model, value: models.Field) -> None:
initialized = hasattr(instance, '_instance_initialized')
was_deferred = self.field_name in instance.get_deferred_fields()

Expand All @@ -117,7 +134,7 @@ def __set__(self, instance, value):
instance.__dict__[self.field_name] = value

@staticmethod
def cls_for_descriptor(descriptor):
def cls_for_descriptor(descriptor: models.Field) -> type[DescriptorWrapper]:
if hasattr(descriptor, '__delete__'):
return FullDescriptorWrapper
else:
Expand All @@ -128,8 +145,8 @@ class FullDescriptorWrapper(DescriptorWrapper):
"""
Wrapper for descriptors with all three descriptor methods.
"""
def __delete__(self, obj):
self.descriptor.__delete__(obj)
def __delete__(self, obj: models.Field) -> None:
self.descriptor.__delete__(obj) # type: ignore[attr-defined]


class FieldsContext:
Expand All @@ -153,7 +170,12 @@ class FieldsContext:
"""

def __init__(self, tracker, *fields, state=None):
def __init__(
self,
tracker: FieldInstanceTracker,
*fields: str,
state: dict[str, int] | None = None
):
"""
:param tracker: FieldInstanceTracker instance to be reset after
context exit
Expand All @@ -171,7 +193,7 @@ def __init__(self, tracker, *fields, state=None):
self.fields = fields
self.state = state

def __enter__(self):
def __enter__(self) -> FieldsContext:
"""
Increments tracked fields occurrences count in shared state.
"""
Expand All @@ -180,7 +202,12 @@ def __enter__(self):
self.state[f] += 1
return self

def __exit__(self, exc_type, exc_val, exc_tb):
def __exit__(
self,
exc_type: type[BaseException] | None,
exc_val: BaseException | None,
exc_tb: TracebackType | None
) -> None:
"""
Decrements tracked fields occurrences count in shared state.
Expand All @@ -198,29 +225,34 @@ def __exit__(self, exc_type, exc_val, exc_tb):


class FieldInstanceTracker:
def __init__(self, instance, fields, field_map):
self.instance = instance
def __init__(self, instance: models.Model, fields: Iterable[str], field_map: Mapping[str, str]):
self.instance = cast("_AugmentedModel", instance)
self.fields = fields
self.field_map = field_map
self.context = FieldsContext(self, *self.fields)

def __enter__(self):
def __enter__(self) -> FieldsContext:
return self.context.__enter__()

def __exit__(self, exc_type, exc_val, exc_tb):
def __exit__(
self,
exc_type: type[BaseException] | None,
exc_val: BaseException | None,
exc_tb: TracebackType | None
) -> None:
return self.context.__exit__(exc_type, exc_val, exc_tb)

def __call__(self, *fields):
def __call__(self, *fields: str) -> FieldsContext:
return FieldsContext(self, *fields, state=self.context.state)

@property
def deferred_fields(self):
def deferred_fields(self) -> set[str]:
return self.instance.get_deferred_fields()

def get_field_value(self, field):
def get_field_value(self, field: str) -> Any:
return getattr(self.instance, self.field_map[field])

def set_saved_fields(self, fields=None):
def set_saved_fields(self, fields: Iterable[str] | None = None) -> None:
if not self.instance.pk:
self.saved_data = {}
elif fields is None:
Expand All @@ -232,7 +264,7 @@ def set_saved_fields(self, fields=None):
for field, field_value in self.saved_data.items():
self.saved_data[field] = lightweight_deepcopy(field_value)

def current(self, fields=None):
def current(self, fields: Iterable[str] | None = None) -> dict[str, Any]:
"""Returns dict of current values for all tracked fields"""
if fields is None:
deferred_fields = self.deferred_fields
Expand All @@ -246,7 +278,7 @@ def current(self, fields=None):

return {f: self.get_field_value(f) for f in fields}

def has_changed(self, field):
def has_changed(self, field: str) -> bool:
"""Returns ``True`` if field has changed from currently saved value"""
if field in self.fields:
# deferred fields haven't changed
Expand All @@ -256,7 +288,7 @@ def has_changed(self, field):
else:
raise FieldError('field "%s" not tracked' % field)

def previous(self, field):
def previous(self, field: str) -> Any:
"""Returns currently saved value of given field"""

# handle deferred fields that have not yet been loaded from the database
Expand All @@ -276,15 +308,15 @@ def previous(self, field):

return self.saved_data.get(field)

def changed(self):
def changed(self) -> dict[str, Any]:
"""Returns dict of fields that changed since save (with old values)"""
return {
field: self.previous(field)
for field in self.fields
if self.has_changed(field)
}

def init_deferred_fields(self):
def init_deferred_fields(self) -> None:
self.instance._deferred_fields = set()
if hasattr(self.instance, '_deferred') and not self.instance._deferred:
return
Expand All @@ -295,31 +327,36 @@ class DeferredAttributeTracker(DescriptorMixin, DeferredAttribute):
class FileDescriptorTracker(DescriptorMixin, FileDescriptor):
tracker_instance = self

def _get_field_name(self):
def _get_field_name(self) -> str:
return self.field.name

self.instance._deferred_fields = self.instance.get_deferred_fields()
for field in self.instance._deferred_fields:
field_obj = self.instance.__class__.__dict__.get(field)
if isinstance(field_obj, FileDescriptor):
field_tracker = FileDescriptorTracker(field_obj.field)
setattr(self.instance.__class__, field, field_tracker)
file_descriptor_tracker = FileDescriptorTracker(field_obj.field)
setattr(self.instance.__class__, field, file_descriptor_tracker)
else:
field_tracker = DeferredAttributeTracker(field)
setattr(self.instance.__class__, field, field_tracker)
deferred_attribute_tracker = DeferredAttributeTracker(field)
setattr(self.instance.__class__, field, deferred_attribute_tracker)


class FieldTracker:

tracker_class = FieldInstanceTracker

def __init__(self, fields=None):
self.fields = fields
def __init__(self, fields: Iterable[str] | None = None):
# finalize_class() will replace None; pretend it is never None.
self.fields = cast("Iterable[str]", fields)

def __call__(self, func=None, fields=None):
def decorator(f):
def __call__(
self,
func: Callable | None = None,
fields: Iterable[str] | None = None
) -> Any:
def decorator(f: Callable) -> Callable:
@wraps(f)
def inner(obj, *args, **kwargs):
def inner(obj: models.Model, *args: object, **kwargs: object) -> object:
tracker = getattr(obj, self.attname)
field_list = tracker.fields if fields is None else fields
with tracker(*field_list):
Expand All @@ -330,25 +367,25 @@ def inner(obj, *args, **kwargs):
return decorator
return decorator(func)

def get_field_map(self, cls):
def get_field_map(self, cls: type[models.Model]) -> dict[str, str]:
"""Returns dict mapping fields names to model attribute names"""
field_map = {field: field for field in self.fields}
all_fields = {f.name: f.attname for f in cls._meta.fields}
field_map.update(**{k: v for (k, v) in all_fields.items()
if k in field_map})
return field_map

def contribute_to_class(self, cls, name):
def contribute_to_class(self, cls: type[models.Model], name: str) -> None:
self.name = name
self.attname = '_%s' % name
models.signals.class_prepared.connect(self.finalize_class, sender=cls)

def finalize_class(self, sender, **kwargs):
def finalize_class(self, sender: type[models.Model], **kwargs: object) -> None:
if self.fields is None:
self.fields = (field.attname for field in sender._meta.fields)
self.fields = set(self.fields)
for field_name in self.fields:
descriptor = getattr(sender, field_name)
descriptor: models.Field = getattr(sender, field_name)
wrapper_cls = DescriptorWrapper.cls_for_descriptor(descriptor)
wrapped_descriptor = wrapper_cls(field_name, descriptor, self.attname)
setattr(sender, field_name, wrapped_descriptor)
Expand All @@ -358,24 +395,29 @@ def finalize_class(self, sender, **kwargs):
setattr(sender, self.name, self)
self.patch_save(sender)

def initialize_tracker(self, sender, instance, **kwargs):
def initialize_tracker(
self,
sender: type[models.Model],
instance: models.Model,
**kwargs: object
) -> None:
if not isinstance(instance, self.model_class):
return # Only init instances of given model (including children)
tracker = self.tracker_class(instance, self.fields, self.field_map)
setattr(instance, self.attname, tracker)
tracker.set_saved_fields()
instance._instance_initialized = True
cast("_AugmentedModel", instance)._instance_initialized = True

def patch_save(self, model):
def patch_save(self, model: type[models.Model]) -> None:
self._patch(model, 'save_base', 'update_fields')
self._patch(model, 'refresh_from_db', 'fields')

def _patch(self, model, method, fields_kwarg):
def _patch(self, model: type[models.Model], method: str, fields_kwarg: str) -> None:
original = getattr(model, method)

@wraps(original)
def inner(instance, *args, **kwargs):
update_fields = kwargs.get(fields_kwarg)
def inner(instance: models.Model, *args: object, **kwargs: Any) -> object:
update_fields: Iterable[str] | None = kwargs.get(fields_kwarg)
if update_fields is None:
fields = self.fields
else:
Expand All @@ -389,7 +431,7 @@ def inner(instance, *args, **kwargs):

setattr(model, method, inner)

def __get__(self, instance, owner):
def __get__(self, instance: models.Model | None, owner: type[models.Model]) -> FieldTracker:
if instance is None:
return self
else:
Expand All @@ -398,7 +440,7 @@ def __get__(self, instance, owner):

class ModelInstanceTracker(FieldInstanceTracker):

def has_changed(self, field):
def has_changed(self, field: str) -> bool:
"""Returns ``True`` if field has changed from currently saved value"""
if not self.instance.pk:
return True
Expand All @@ -407,7 +449,7 @@ def has_changed(self, field):
else:
raise FieldError('field "%s" not tracked' % field)

def changed(self):
def changed(self) -> dict[str, Any]:
"""Returns dict of fields that changed since save (with old values)"""
if not self.instance.pk:
return {}
Expand All @@ -419,5 +461,5 @@ def changed(self):
class ModelTracker(FieldTracker):
tracker_class = ModelInstanceTracker

def get_field_map(self, cls):
def get_field_map(self, cls: type[models.Model]) -> dict[str, str]:
return {field: field for field in self.fields}

0 comments on commit 998eb70

Please sign in to comment.