Skip to content

Commit

Permalink
maybe we should change the intersection method...
Browse files: browse the repository at this point in the history
  • Loading branch information
luizirber committed Mar 15, 2018
1 parent 4c7855d commit ccc05ec
Show file tree
Hide file tree
Showing 5 changed files with 34 additions and 33 deletions.
2 changes: 1 addition & 1 deletion rust
Submodule rust updated 4 files
+1 −5 appveyor.yml
+16 −0 src/ffi.rs
+7 −0 src/lib.rs
+10 −8 target/sourmash.h
35 changes: 19 additions & 16 deletions sourmash/minhash.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,27 +270,30 @@ def downsample_scaled(self, new_num):

return a

def intersection(self, other):
def intersection(self, other, in_common=False):
if not isinstance(other, MinHash):
raise TypeError("Must be a MinHash!")

if self.num != other.num:
err = 'must have same num: {} != {}'.format(self.num, other.num)
raise TypeError(err)
else:
num = self.num

combined_mh = MinHash(num, self.ksize,
is_protein=self.is_protein,
seed=self.seed,
max_hash=self.max_hash,
track_abundance=self.track_abundance)

combined_mh.merge(self)
combined_mh.merge(other)

common = set(self.get_mins())
common.intersection_update(other.get_mins())
common.intersection_update(combined_mh.get_mins())
if in_common:
# TODO: copy from buffer to Python land instead,
# this way involves more moving data around.
combined_mh = self.copy_and_clear()
combined_mh.merge(self)
combined_mh.merge(other)

size = len(combined_mh)
common = set(self.get_mins())
common.intersection_update(other.get_mins())
common.intersection_update(combined_mh.get_mins())
else:
size = self._methodcall(lib.kmerminhash_intersection, other._get_objptr())
common = set()

return common, max(len(combined_mh), 1)
return common, max(size, 1)

def compare(self, other):
if self.num != other.num:
Expand Down
3 changes: 1 addition & 2 deletions sourmash/sbt.py
Original file line number Diff line number Diff line change
Expand Up @@ -534,8 +534,7 @@ def _fill_max_n_below(self):
parent = self.parent(i)
if parent.pos not in self.missing_nodes:
max_n_below = parent.node.metadata.get('max_n_below', 0)
max_n_below = max(len(n.data.minhash.get_mins()),
max_n_below)
max_n_below = max(len(n.data.minhash), max_n_below)
parent.node.metadata['max_n_below'] = max_n_below

current = parent
Expand Down
3 changes: 1 addition & 2 deletions sourmash/sbtmh.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,7 @@ def update(self, parent):
for v in self.data.minhash.get_mins():
parent.data.count(v)
max_n_below = parent.metadata.get('max_n_below', 0)
max_n_below = max(len(self.data.minhash.get_mins()),
max_n_below)
max_n_below = max(len(self.data.minhash), max_n_below)
parent.metadata['max_n_below'] = max_n_below

@property
Expand Down
24 changes: 12 additions & 12 deletions tests/test__minhash.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,38 +274,38 @@ def test_intersection_1(track_abundance):
common = set(a.get_mins())
combined_size = 3

intersection, size = a.intersection(b)
intersection, size = a.intersection(b, in_common=True)
assert intersection == common
assert combined_size == size

intersection, size = b.intersection(b)
intersection, size = b.intersection(b, in_common=True)
assert intersection == common
assert combined_size == size

intersection, size = b.intersection(a)
intersection, size = b.intersection(a, in_common=True)
assert intersection == common
assert combined_size == size

intersection, size = a.intersection(a)
intersection, size = a.intersection(a, in_common=True)
assert intersection == common
assert combined_size == size

# add same sequence again
b.add_sequence('TGCCGCCCAGCA')

intersection, size = a.intersection(b)
intersection, size = a.intersection(b, in_common=True)
assert intersection == common
assert combined_size == size

intersection, size = b.intersection(b)
intersection, size = b.intersection(b, in_common=True)
assert intersection == common
assert combined_size == size

intersection, size = b.intersection(a)
intersection, size = b.intersection(a, in_common=True)
assert intersection == common
assert combined_size == size

intersection, size = a.intersection(a)
intersection, size = a.intersection(a, in_common=True)
assert intersection == common
assert combined_size == size

Expand All @@ -315,18 +315,18 @@ def test_intersection_1(track_abundance):
new_in_common = set(a.get_mins()).intersection(set(b.get_mins()))
new_combined_size = 8

intersection, size = a.intersection(b)
intersection, size = a.intersection(b, in_common=True)
assert intersection == new_in_common
assert size == new_combined_size

intersection, size = b.intersection(a)
intersection, size = b.intersection(a, in_common=True)
assert intersection == new_in_common
assert size == new_combined_size

intersection, size = a.intersection(a)
intersection, size = a.intersection(a, in_common=True)
assert intersection == set(a.get_mins())

intersection, size = b.intersection(b)
intersection, size = b.intersection(b, in_common=True)
assert intersection == set(b.get_mins())


Expand Down

0 comments on commit ccc05ec

Please sign in to comment.