Skip to content

Commit

Permalink
a few additional fixes for masking:
Browse files Browse the repository at this point in the history
- fix NeuronMask doctest
- TreeNeuron.un/mask: make sure to re-classify
- TreeNeuron.unmask: fix re-connecting
  • Loading branch information
schlegelp committed Oct 24, 2024
1 parent 88a2dec commit 7de016e
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 3 deletions.
2 changes: 1 addition & 1 deletion navis/core/masking.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ class NeuronMask:
>>> # Grab a few skeletons
>>> nl = navis.example_neurons(3)
>>> # Label axon and dendrites
>>> navis.split_axon_dendrite(nl, label_only=True)
>>> _ = navis.split_axon_dendrite(nl, label_only=True)
>>> # Mask by axon
>>> with navis.NeuronMask(nl, lambda x: x.nodes.compartment == 'axon'):
... print("Axon cable length:", nl.cable_length * nl[0].units)
Expand Down
14 changes: 12 additions & 2 deletions navis/core/skeleton.py
Original file line number Diff line number Diff line change
Expand Up @@ -1009,10 +1009,17 @@ def mask(self, mask, copy=True):
self._masked_data["_nodes"] = self.nodes

# N.B. we're directly setting `._nodes`` to avoid overhead from checks
self._nodes = self._nodes.loc[mask]
self._nodes = self._nodes.loc[mask].drop("type", axis=1, errors="ignore")
if copy:
self._nodes = self._nodes.copy()

# See if any parent IDs have ceased to exist
missing_parents = ~self._nodes.parent_id.isin(self._nodes.node_id) & (
self._nodes.parent_id >= 0
)
if any(missing_parents):
self.nodes.loc[missing_parents, "parent_id"] = -1

if hasattr(self, "_connectors"):
self._masked_data["_connectors"] = self.connectors
self._connectors = self._connectors.loc[
Expand Down Expand Up @@ -1092,7 +1099,7 @@ def unmask(self, reset=True):
if r not in pre_parents:
continue
# Skip if this was also a root in the pre-masked data
if pre_parents[r] >= 0:
if pre_parents[r] < 0:
continue
# Skip if the old parent does not exist anymore
if pre_parents[r] not in self.nodes.node_id.values:
Expand All @@ -1110,6 +1117,9 @@ def unmask(self, reset=True):
if any(missing_parents):
self.nodes.loc[missing_parents, "parent_id"] = -1

# Force nodes to be re-classified
self.nodes.drop("type", axis=1, errors="ignore", inplace=True)

# TODO: Make sure that edges have a consistent orientation
# (not sure this is much of a problem)

Expand Down

0 comments on commit 7de016e

Please sign in to comment.