Skip to content

Commit

Permalink
first pass at masking interface:
Browse files Browse the repository at this point in the history
- add mask/unmask and apply_mask methods to TreeNeuron, MeshNeuron and Dotprops
- add is_masked property for all neurons
- add `navis.NeuronMask` class
- add __length__ to all neurons
- dotprops: clear `_tree` with temporary attributes
  • Loading branch information
schlegelp committed Oct 24, 2024
1 parent 6a501ae commit 88a2dec
Show file tree
Hide file tree
Showing 8 changed files with 1,261 additions and 316 deletions.
23 changes: 14 additions & 9 deletions docs/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -47,13 +47,13 @@ learn more!
``TreeNeurons``, ``MeshNeurons``, ``VoxelNeurons`` and ``Dotprops`` are neuron
classes. ``NeuronLists`` are containers thereof.

| Class | Description |
|------|------|
| [`navis.TreeNeuron`][] | Skeleton representation of a neuron. |
| [`navis.MeshNeuron`][] | Meshes with vertices and faces. |
| [`navis.VoxelNeuron`][] | 3D images (e.g. from confocal stacks). |
| [`navis.Dotprops`][] | Point cloud + vector representations, used for NBLAST. |
| [`navis.NeuronList`][] | Containers for neurons. |
| Class | Description |
|-------------------------|---------------------------------------------------------|
| [`navis.TreeNeuron`][] | Skeleton representation of a neuron. |
| [`navis.MeshNeuron`][] | Meshes with vertices and faces. |
| [`navis.VoxelNeuron`][] | 3D images (e.g. from confocal stacks). |
| [`navis.Dotprops`][] | Point cloud + vector representations, used for NBLAST. |
| [`navis.NeuronList`][] | Containers for neurons. |

### General Neuron methods

Expand Down Expand Up @@ -89,6 +89,7 @@ to all neurons:
| `Neuron.type` | {{ autosummary("navis.BaseNeuron.type") }} |
| `Neuron.soma` | {{ autosummary("navis.BaseNeuron.soma") }} |
| `Neuron.bbox` | {{ autosummary("navis.BaseNeuron.bbox") }} |
| `Neuron.is_masked` | {{ autosummary("navis.BaseNeuron.is_masked") }} |

!!! note

Expand Down Expand Up @@ -119,6 +120,8 @@ this neuron type. Note that most of them are simply short-hands for the other
| [`TreeNeuron.reroot()`][navis.TreeNeuron.reroot] | {{ autosummary("navis.TreeNeuron.reroot") }} |
| [`TreeNeuron.resample()`][navis.TreeNeuron.resample] | {{ autosummary("navis.TreeNeuron.resample") }} |
| [`TreeNeuron.snap()`][navis.TreeNeuron.snap] | {{ autosummary("navis.TreeNeuron.snap") }} |
| [`TreeNeuron.mask()`][navis.TreeNeuron.mask] | {{ autosummary("navis.TreeNeuron.mask") }} |
| [`TreeNeuron.unmask()`][navis.TreeNeuron.unmask] | {{ autosummary("navis.TreeNeuron.unmask") }} |

In addition, a [`navis.TreeNeuron`][] has a range of different properties:

Expand Down Expand Up @@ -146,7 +149,6 @@ In addition, a [`navis.TreeNeuron`][] has a range of different properties:
| [`TreeNeuron.vertices`][navis.TreeNeuron.vertices] | {{ autosummary("navis.TreeNeuron.vertices") }} |
| [`TreeNeuron.volume`][navis.TreeNeuron.volume] | {{ autosummary("navis.TreeNeuron.volume") }} |


#### Skeleton utility functions

| Function | Description |
Expand All @@ -158,7 +160,6 @@ In addition, a [`navis.TreeNeuron`][] has a range of different properties:
| [`navis.graph.skeleton_adjacency_matrix()`][navis.graph.skeleton_adjacency_matrix] | {{ autosummary("navis.graph.skeleton_adjacency_matrix") }} |



### Mesh neurons

Properties specific to [`navis.MeshNeuron`][]:
Expand All @@ -178,6 +179,8 @@ Methods specific to [`navis.MeshNeuron`][]:
| [`MeshNeuron.skeletonize()`][navis.MeshNeuron.skeletonize] | {{ autosummary("navis.MeshNeuron.skeletonize") }} |
| [`MeshNeuron.snap()`][navis.MeshNeuron.snap] | {{ autosummary("navis.MeshNeuron.snap") }} |
| [`MeshNeuron.validate()`][navis.MeshNeuron.validate] | {{ autosummary("navis.MeshNeuron.validate") }} |
| [`MeshNeuron.mask()`][navis.MeshNeuron.mask] | {{ autosummary("navis.MeshNeuron.mask") }} |
| [`MeshNeuron.unmask()`][navis.MeshNeuron.unmask] | {{ autosummary("navis.MeshNeuron.unmask") }} |


### Voxel neurons
Expand Down Expand Up @@ -215,6 +218,8 @@ These are methods and properties specific to [Dotprops][navis.Dotprops]:
| [`Dotprops.alpha`][navis.Dotprops.alpha] | {{ autosummary("navis.Dotprops.alpha") }} |
| [`Dotprops.to_skeleton()`][navis.Dotprops.to_skeleton] | {{ autosummary("navis.Dotprops.to_skeleton") }} |
| [`Dotprops.snap()`][navis.Dotprops.snap] | {{ autosummary("navis.Dotprops.snap") }} |
| [`Dotprops.mask()`][navis.Dotprops.mask] | {{ autosummary("navis.Dotprops.mask") }} |
| [`Dotprops.unmask()`][navis.Dotprops.unmask] | {{ autosummary("navis.Dotprops.unmask") }} |

### Converting between types

Expand Down
15 changes: 13 additions & 2 deletions navis/core/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,22 @@
from .dotprop import Dotprops
from .voxel import VoxelNeuron
from .neuronlist import NeuronList
from .masking import NeuronMask
from .core_utils import make_dotprops, to_neuron_space, NeuronProcessor

from typing import Union

NeuronObject = Union[NeuronList, TreeNeuron, BaseNeuron, MeshNeuron]

__all__ = ['Volume', 'Neuron', 'BaseNeuron', 'TreeNeuron', 'MeshNeuron',
'Dotprops', 'VoxelNeuron', 'NeuronList', 'make_dotprops']
__all__ = [
"Volume",
"Neuron",
"BaseNeuron",
"TreeNeuron",
"MeshNeuron",
"NeuronMask",
"Dotprops",
"VoxelNeuron",
"NeuronList",
"make_dotprops",
]
96 changes: 95 additions & 1 deletion navis/core/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,8 @@


def Neuron(
x: Union[nx.DiGraph, str, pd.DataFrame, "TreeNeuron", "MeshNeuron"], **metadata
x: Union[nx.DiGraph, str, pd.DataFrame, "TreeNeuron", "MeshNeuron"],
**metadata, # noqa: F821
):
"""Constructor for Neuron objects. Depending on the input, either a
`TreeNeuron` or a `MeshNeuron` is returned.
Expand Down Expand Up @@ -195,6 +196,9 @@ class BaseNeuron(UnitObject):
#: Core data table(s) used to calculate hash
CORE_DATA = []

#: Property used to calculate length of neuron
_LENGTH_DATA = None

def __init__(self, **kwargs):
# Set a random ID -> may be replaced later
self.id = uuid.uuid4()
Expand Down Expand Up @@ -303,6 +307,14 @@ def __isub__(self, other):
"""Subtraction with assignment (-=)."""
return self.__sub__(other, copy=False)

def __len__(self):
if self._LENGTH_DATA is None:
return None
# Deal with potential empty neurons
if not hasattr(self, self._LENGTH_DATA):
return 0
return len(getattr(self, self._LENGTH_DATA))

def _repr_html_(self):
frame = self.summary().to_frame()
frame.columns = [""]
Expand Down Expand Up @@ -654,6 +666,7 @@ def copy(self, deepcopy=False) -> "BaseNeuron":

def summary(self, add_props=None) -> pd.Series:
"""Get a summary of this neuron."""

# Do not remove the list -> otherwise we might change the original!
props = list(self.SUMMARY_PROPS)

Expand Down Expand Up @@ -721,6 +734,87 @@ def plot3d(self, **kwargs):

return plot3d(core.NeuronList(self, make_copy=False), **kwargs)

@property
def is_masked(self):
"""Test if neuron is masked.
See Also
--------
[`navis.BaseNeuron.mask`][]
Mask neuron.
[`navis.BaseNeuron.unmask`][]
Remove mask from neuron.
[`navis.NeuronMask`][]
Context manager for masking neurons.
"""
return hasattr(self, "_masked_data")

def mask(self, mask):
"""Mask neuron."""
raise NotImplementedError(
f"Masking not implemented for neuron of type {type(self)}."
)

def unmask(self):
"""Unmask neuron.
Returns the neuron to its original state before masking.
Returns
-------
self
See Also
--------
[`Neuron.is_masked`][navis.BaseNeuron.is_masked]
Check if neuron. is masked.
[`Neuron.mask`][navis.BaseNeuron.unmask]
Mask neuron.
[`navis.NeuronMask`][]
Context manager for masking neurons.
"""
if not self.is_masked:
raise ValueError("Neuron is not masked.")

for k, v in self._masked_data.items():
if hasattr(self, k):
setattr(self, k, v)

delattr(self, "_mask")
delattr(self, "_masked_data")
self._clear_temp_attr()

return self

def apply_mask(self, inplace=False):
"""Apply mask to neuron.
This will effectively make the mask permanent.
Parameters
----------
inplace : bool
If True will apply mask in-place. If False
will return a copy and the original neuron
will remain masked.
Returns
-------
Neuron
Neuron with mask applied.
"""
if not self.is_masked:
raise ValueError("Neuron is not masked.")

n = self if inplace else self.copy()

delattr(n, "_mask")
delattr(n, "_masked_data")

return n

def map_units(
self,
units: Union[pint.Unit, str],
Expand Down
139 changes: 135 additions & 4 deletions navis/core/dotprop.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,11 +105,14 @@ class Dotprops(BaseNeuron):
EQ_ATTRIBUTES = ['name', 'n_points', 'k']

#: Temporary attributes that need clearing when neuron data changes
TEMP_ATTR = ['_memory_usage']
TEMP_ATTR = ['_memory_usage', "_tree"]

#: Core data table(s) used to calculate hash
_CORE_DATA = ['points', 'vect']

#: Property used to calculate length of neuron
_LENGTH_DATA = 'points'

def __init__(self,
points: np.ndarray,
k: int,
Expand Down Expand Up @@ -230,9 +233,6 @@ def __getstate__(self):

return state

def __len__(self):
return len(self.points)

@property
def alpha(self):
"""Alpha value for tangent vectors (optional)."""
Expand Down Expand Up @@ -539,6 +539,137 @@ def drop_fluff(self, epsilon, keep_size: int = None, n_largest: int = None, inpl
if not inplace:
return x

def mask(self, mask, copy=True):
"""Mask neuron with given mask.
This is always done in-place!
Parameters
----------
mask : np.ndarray
Mask to apply. Can be:
- 1D array with boolean values
- callable that accepts a neuron and returns a mask
- string with property name
Returns
-------
self
The masked neuron.
See Also
--------
[`Dotprops.unmask`][navis.Dotprops.unmask]
Remove mask from neuron.
[`Dotprops.is_masked`][navis.Dotprops.is_masked]
Check if neuron is masked.
[`navis.NeuronMask`][]
Context manager for masking neurons.
"""
if self.is_masked:
raise ValueError(
"Neuron already masked. Layering multiple masks is currently not supported, please unmask first."
)

if callable(mask):
mask = mask(self)
elif isinstance(mask, str):
mask = getattr(self, mask)

mask = np.asarray(mask)

if mask.dtype != bool:
raise ValueError("Mask must be boolean array.")
elif mask.shape[0] != len(self):
raise ValueError("Mask must have same length as points.")

self._mask = mask
self._masked_data = {}
self._masked_data['_points'] = self.points

# Drop soma if masked out
if self.soma is not None:
if isinstance(self.soma, (list, np.ndarray)):
soma_left = self.soma[mask[self.soma]]
self._masked_data['_soma'] = self.soma

if any(soma_left):
self.soma = soma_left
else:
self.soma = None
elif not mask[self.soma]:
self._masked_data['_soma'] = self.soma
self.soma = None

# N.B. we're directly setting `._nodes`` to avoid overhead from checks
for att in ("_points", "_vect", "_alpha"):
if hasattr(self, att):
self._masked_data[att] = getattr(self, att)
setattr(self, att, getattr(self, att)[mask])

if copy:
setattr(self, att, getattr(self, att).copy())

if hasattr(self, "_connectors") and "point_ix" in self._connectors.columns:
self._masked_data['connectors'] = self.connectors
self._connectors = self._connectors.loc[
self.connectors.point_ix.isin(np.arange(len(mask))[mask])
]
if copy:
self._connectors = self._connectors.copy()

self._clear_temp_attr()

return self

def unmask(self, reset=True):
"""Unmask neuron.
Returns the neuron to its original state before masking.
Parameters
----------
reset : bool
Whether to reset the neuron to its original state before masking.
If False, edits made to the neuron after masking will be kept.
Returns
-------
self
See Also
--------
[`Dotprops.is_masked`][navis.Dotprops.is_masked]
Check if neuron is masked.
[`Dotprops.mask`][navis.Dotprops.mask]
Mask neuron.
[`navis.NeuronMask`][]
Context manager for masking neurons.
"""
if not self.is_masked:
raise ValueError("Neuron is not masked.")

if reset:
# Unmask and reset to original state
super().unmask()
return self

mask = self._mask
for k, v in self._masked_data.items():
# Combine with current data
if hasattr(self, k):
v = np.concatenate((v[~mask], getattr(self, k)), axis=0)
setattr(self, k, v)

del self._mask
del self._masked_data

self._clear_temp_attr()

return self

def recalculate_tangents(self, k: int, inplace=False):
"""Recalculate tangent vectors and alpha with a new `k`.
Expand Down
Loading

0 comments on commit 88a2dec

Please sign in to comment.