From ccc05ec80a46b27892e1dc5699978b50c4e0d892 Mon Sep 17 00:00:00 2001 From: Luiz Irber Date: Wed, 14 Mar 2018 22:37:16 -0700 Subject: [PATCH] maybe we should change the intersection method... --- rust | 2 +- sourmash/minhash.py | 35 +++++++++++++++++++---------------- sourmash/sbt.py | 3 +-- sourmash/sbtmh.py | 3 +-- tests/test__minhash.py | 24 ++++++++++++------------ 5 files changed, 34 insertions(+), 33 deletions(-) diff --git a/rust b/rust index ebe1993b19..ee46b7468a 160000 --- a/rust +++ b/rust @@ -1 +1 @@ -Subproject commit ebe1993b195bda6ebd2f2a1f42d52caba4355b9d +Subproject commit ee46b7468a38e9f5592ac3a4b9718fd3ab6f2345 diff --git a/sourmash/minhash.py b/sourmash/minhash.py index 869bbafa58..2baf1e4323 100644 --- a/sourmash/minhash.py +++ b/sourmash/minhash.py @@ -270,27 +270,30 @@ def downsample_scaled(self, new_num): return a - def intersection(self, other): + def intersection(self, other, in_common=False): + if not isinstance(other, MinHash): + raise TypeError("Must be a MinHash!") + if self.num != other.num: err = 'must have same num: {} != {}'.format(self.num, other.num) raise TypeError(err) - else: - num = self.num - combined_mh = MinHash(num, self.ksize, - is_protein=self.is_protein, - seed=self.seed, - max_hash=self.max_hash, - track_abundance=self.track_abundance) - - combined_mh.merge(self) - combined_mh.merge(other) - - common = set(self.get_mins()) - common.intersection_update(other.get_mins()) - common.intersection_update(combined_mh.get_mins()) + if in_common: + # TODO: copy from buffer to Python land instead, + # this way involves more moving data around. + combined_mh = self.copy_and_clear() + combined_mh.merge(self) + combined_mh.merge(other) + + size = len(combined_mh) + common = set(self.get_mins()) + common.intersection_update(other.get_mins()) + common.intersection_update(combined_mh.get_mins()) + else: + size = self._methodcall(lib.kmerminhash_intersection, other._get_objptr()) + common = set() - return common, max(len(combined_mh), 1) + return common, max(size, 1) def compare(self, other): if self.num != other.num: diff --git a/sourmash/sbt.py b/sourmash/sbt.py index 5e6d1a4e5b..f81a8d079c 100644 --- a/sourmash/sbt.py +++ b/sourmash/sbt.py @@ -534,8 +534,7 @@ def _fill_max_n_below(self): parent = self.parent(i) if parent.pos not in self.missing_nodes: max_n_below = parent.node.metadata.get('max_n_below', 0) - max_n_below = max(len(n.data.minhash.get_mins()), - max_n_below) + max_n_below = max(len(n.data.minhash), max_n_below) parent.node.metadata['max_n_below'] = max_n_below current = parent diff --git a/sourmash/sbtmh.py b/sourmash/sbtmh.py index 631772937d..aca81304f0 100644 --- a/sourmash/sbtmh.py +++ b/sourmash/sbtmh.py @@ -55,8 +55,7 @@ def update(self, parent): for v in self.data.minhash.get_mins(): parent.data.count(v) max_n_below = parent.metadata.get('max_n_below', 0) - max_n_below = max(len(self.data.minhash.get_mins()), - max_n_below) + max_n_below = max(len(self.data.minhash), max_n_below) parent.metadata['max_n_below'] = max_n_below @property diff --git a/tests/test__minhash.py b/tests/test__minhash.py index 7291ef5001..4c3a463c8a 100644 --- a/tests/test__minhash.py +++ b/tests/test__minhash.py @@ -274,38 +274,38 @@ def test_intersection_1(track_abundance): common = set(a.get_mins()) combined_size = 3 - intersection, size = a.intersection(b) + intersection, size = a.intersection(b, in_common=True) assert intersection == common assert combined_size == size - intersection, size = b.intersection(b) + intersection, size = b.intersection(b, in_common=True) assert intersection == common assert combined_size == size - intersection, size = b.intersection(a) + intersection, size = b.intersection(a, in_common=True) assert intersection == common assert combined_size == size - intersection, size = a.intersection(a) + intersection, size = a.intersection(a, in_common=True) assert intersection == common assert combined_size == size # add same sequence again b.add_sequence('TGCCGCCCAGCA') - intersection, size = a.intersection(b) + intersection, size = a.intersection(b, in_common=True) assert intersection == common assert combined_size == size - intersection, size = b.intersection(b) + intersection, size = b.intersection(b, in_common=True) assert intersection == common assert combined_size == size - intersection, size = b.intersection(a) + intersection, size = b.intersection(a, in_common=True) assert intersection == common assert combined_size == size - intersection, size = a.intersection(a) + intersection, size = a.intersection(a, in_common=True) assert intersection == common assert combined_size == size @@ -315,18 +315,18 @@ def test_intersection_1(track_abundance): new_in_common = set(a.get_mins()).intersection(set(b.get_mins())) new_combined_size = 8 - intersection, size = a.intersection(b) + intersection, size = a.intersection(b, in_common=True) assert intersection == new_in_common assert size == new_combined_size - intersection, size = b.intersection(a) + intersection, size = b.intersection(a, in_common=True) assert intersection == new_in_common assert size == new_combined_size - intersection, size = a.intersection(a) + intersection, size = a.intersection(a, in_common=True) assert intersection == set(a.get_mins()) - intersection, size = b.intersection(b) + intersection, size = b.intersection(b, in_common=True) assert intersection == set(b.get_mins())