From 64b301775d5a6508a62e1bed1bf68b6ce3a7dc4b Mon Sep 17 00:00:00 2001 From: Luiz Irber Date: Fri, 11 Jan 2019 14:46:33 +0000 Subject: [PATCH] [MRG] gather optimizations (#615) * don't recalculate scaled query minhash everytime * partial progress * remove_many and add_many tests --- sourmash/_minhash.pxd | 2 ++ sourmash/_minhash.pyx | 5 +++++ sourmash/kmer_min_hash.hh | 17 +++++++++++++++++ sourmash/sbtmh.py | 21 +++++++++++++-------- sourmash/search.py | 23 ++++++++++++----------- tests/test__minhash.py | 38 ++++++++++++++++++++++++++++++++++++++ 6 files changed, 87 insertions(+), 19 deletions(-) diff --git a/sourmash/_minhash.pxd b/sourmash/_minhash.pxd index 59ec09ba21..1d6fe3a2dd 100644 --- a/sourmash/_minhash.pxd +++ b/sourmash/_minhash.pxd @@ -30,6 +30,7 @@ cdef extern from "kmer_min_hash.hh": KmerMinHash(unsigned int, unsigned int, bool, uint32_t, HashIntoType) void add_hash(HashIntoType) except +ValueError + void remove_hash(HashIntoType) except +ValueError void add_word(string word) except +ValueError void add_sequence(const char *, bool) except +ValueError void merge(const KmerMinHash&) except +ValueError @@ -42,6 +43,7 @@ cdef extern from "kmer_min_hash.hh": KmerMinAbundance(unsigned int, unsigned int, bool, uint32_t, HashIntoType) void add_hash(HashIntoType) except +ValueError + void remove_hash(HashIntoType) except +ValueError void add_word(string word) except +ValueError void add_sequence(const char *, bool) except +ValueError void merge(const KmerMinAbundance&) except +ValueError diff --git a/sourmash/_minhash.pyx b/sourmash/_minhash.pyx index a58b38c160..5652218667 100644 --- a/sourmash/_minhash.pyx +++ b/sourmash/_minhash.pyx @@ -189,6 +189,11 @@ cdef class MinHash(object): for hash in hashes: self.add_hash(hash) + def remove_many(self, hashes): + "Remove many hashes at once." + for hash in hashes: + deref(self._this).remove_hash(hash) + def update(self, other): "Update this estimator from all the hashes from the other." self.add_many(other.get_mins()) diff --git a/sourmash/kmer_min_hash.hh b/sourmash/kmer_min_hash.hh index 044f8471f9..068e3fde6f 100644 --- a/sourmash/kmer_min_hash.hh +++ b/sourmash/kmer_min_hash.hh @@ -114,6 +114,14 @@ public: } } } + + virtual void remove_hash(const HashIntoType h) { + auto pos = std::lower_bound(std::begin(mins), std::end(mins), h); + if (pos != mins.cend() and *pos == h) { + mins.erase(pos); + } + } + void add_word(const std::string& word) { const HashIntoType hash = _hash_murmur(word, seed); add_hash(hash); @@ -343,6 +351,15 @@ class KmerMinAbundance: public KmerMinHash { } } + virtual void remove_hash(const HashIntoType h) { + auto pos = std::lower_bound(std::begin(mins), std::end(mins), h); + if (pos != mins.cend() and *pos == h) { + mins.erase(pos); + size_t dist = std::distance(begin(mins), pos); + abunds.erase(begin(abunds) + dist); + } + } + virtual void merge(const KmerMinAbundance& other) { check_compatible(other); diff --git a/sourmash/sbtmh.py b/sourmash/sbtmh.py index ad9b8a8cfe..066b2a952d 100644 --- a/sourmash/sbtmh.py +++ b/sourmash/sbtmh.py @@ -207,24 +207,29 @@ class GatherMinHashesFindBestIgnoreMaxHash(object): def __init__(self, initial_best_match=0.0): self.best_match = initial_best_match - def search(self, node, sig, threshold, results=None): - mins = sig.minhash.get_mins() - + def search(self, node, query, threshold, results=None): score = 0 - if not len(mins): + if not len(query.minhash): return 0 if isinstance(node, SigLeaf): - max_scaled = max(node.data.minhash.scaled, sig.minhash.scaled) + max_scaled = max(node.data.minhash.scaled, query.minhash.scaled) + + mh1 = node.data.minhash + if mh1.scaled != max_scaled: + mh1 = node.data.minhash.downsample_scaled(max_scaled) + + mh2 = query.minhash + if mh2.scaled != max_scaled: + mh2 = query.minhash.downsample_scaled(max_scaled) - mh1 = node.data.minhash.downsample_scaled(max_scaled) - mh2 = sig.minhash.downsample_scaled(max_scaled) matches = mh1.count_common(mh2) else: # Nodegraph by minhash comparison + mins = query.minhash.get_mins() get = node.data.get matches = sum(1 for value in mins if get(value)) - score = float(matches) / len(mins) + score = float(matches) / len(query.minhash) # store results if we have passed in an appropriate dictionary if results is not None: diff --git a/sourmash/search.py b/sourmash/search.py index ffa385fa30..2106405812 100644 --- a/sourmash/search.py +++ b/sourmash/search.py @@ -107,9 +107,19 @@ def search_databases(query, databases, threshold, do_containment, best_only, return results +# define a function to build new query object +def build_new_query(to_remove, old_query, scaled=None): + e = old_query.minhash + e.remove_many(to_remove) + if scaled: + e = e.downsample_scaled(scaled) + return SourmashSignature(e) + + GatherResult = namedtuple('GatherResult', 'intersect_bp, f_orig_query, f_match, f_unique_to_query, f_unique_weighted, average_abund, median_abund, std_abund, filename, name, md5, leaf') + def gather_databases(query, databases, threshold_bp, ignore_abundance): orig_query = query orig_mins = orig_query.minhash.get_hashes() @@ -174,17 +184,8 @@ def find_best(dblist, query, remainder): return best_similarity, best_leaf, filename - # define a function to build new signature object from set of mins - def build_new_signature(mins, template_sig, scaled=None): - e = template_sig.minhash.copy_and_clear() - e.add_many(mins) - if scaled: - e = e.downsample_scaled(scaled) - return SourmashSignature(e) - # construct a new query that doesn't have the max_hash attribute set. - new_mins = query.minhash.get_hashes() - query = build_new_signature(new_mins, orig_query) + query = build_new_query([], orig_query) cmp_scaled = 0 remainder = set() @@ -263,8 +264,8 @@ def build_new_signature(mins, template_sig, scaled=None): leaf=best_leaf) # construct a new query, minus the previous one. + query = build_new_query(found_mins, orig_query, cmp_scaled) query_mins -= set(found_mins) - query = build_new_signature(query_mins, orig_query, cmp_scaled) weighted_missed = sum((orig_abunds[k] for k in query_mins)) \ / sum_abunds diff --git a/tests/test__minhash.py b/tests/test__minhash.py index 2c4159945a..f960ac9999 100644 --- a/tests/test__minhash.py +++ b/tests/test__minhash.py @@ -1030,3 +1030,41 @@ def test_distance_matrix(track_abundance): D1[i][j] = E.similarity(E2, track_abundance) assert numpy.array_equal(D1, D2) + + +def test_remove_many(track_abundance): + a = MinHash(0, 10, track_abundance=track_abundance, max_hash=5000) + + a.add_many(list(range(0, 100, 2))) + + orig_sig = signature.SourmashSignature(a) + orig_md5 = orig_sig.md5sum() + + a.remove_many(list(range(0, 100, 3))) + new_sig = signature.SourmashSignature(a) + new_md5 = new_sig.md5sum() + + assert orig_md5 == "f1cc295157374f5c07cfca5f867188a1" + assert new_md5 == "dd93fa319ef57f4a019c59ee1a8c73e2" + assert orig_md5 != new_md5 + + assert len(a) == 33 + assert all(c % 6 != 0 for c in a.get_mins()) + + +def test_add_many(track_abundance): + a = MinHash(0, 10, track_abundance=track_abundance, max_hash=5000) + b = MinHash(0, 10, track_abundance=track_abundance, max_hash=5000) + + a.add_many(list(range(0, 100, 2))) + a.add_many(list(range(0, 100, 2))) + + assert len(a) == 50 + assert all(c % 2 == 0 for c in a.get_mins()) + + for h in range(0, 100, 2): + b.add_hash(h) + b.add_hash(h) + + assert len(b) == 50 + assert a == b