Skip to content

Commit

Permalink
[MRG] gather optimizations (#615)
Browse files Browse the repository at this point in the history
* don't recalculate scaled query minhash everytime

* partial progress

* remove_many and add_many tests
  • Loading branch information
luizirber authored and ctb committed Jan 11, 2019
1 parent 314bd8e commit 64b3017
Show file tree
Hide file tree
Showing 6 changed files with 87 additions and 19 deletions.
2 changes: 2 additions & 0 deletions sourmash/_minhash.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ cdef extern from "kmer_min_hash.hh":

KmerMinHash(unsigned int, unsigned int, bool, uint32_t, HashIntoType)
void add_hash(HashIntoType) except +ValueError
void remove_hash(HashIntoType) except +ValueError
void add_word(string word) except +ValueError
void add_sequence(const char *, bool) except +ValueError
void merge(const KmerMinHash&) except +ValueError
Expand All @@ -42,6 +43,7 @@ cdef extern from "kmer_min_hash.hh":

KmerMinAbundance(unsigned int, unsigned int, bool, uint32_t, HashIntoType)
void add_hash(HashIntoType) except +ValueError
void remove_hash(HashIntoType) except +ValueError
void add_word(string word) except +ValueError
void add_sequence(const char *, bool) except +ValueError
void merge(const KmerMinAbundance&) except +ValueError
Expand Down
5 changes: 5 additions & 0 deletions sourmash/_minhash.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,11 @@ cdef class MinHash(object):
for hash in hashes:
self.add_hash(hash)

def remove_many(self, hashes):
"Remove many hashes at once."
for hash in hashes:
deref(self._this).remove_hash(hash)

def update(self, other):
"Update this estimator from all the hashes from the other."
self.add_many(other.get_mins())
Expand Down
17 changes: 17 additions & 0 deletions sourmash/kmer_min_hash.hh
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,14 @@ public:
}
}
}

virtual void remove_hash(const HashIntoType h) {
auto pos = std::lower_bound(std::begin(mins), std::end(mins), h);
if (pos != mins.cend() and *pos == h) {
mins.erase(pos);
}
}

void add_word(const std::string& word) {
const HashIntoType hash = _hash_murmur(word, seed);
add_hash(hash);
Expand Down Expand Up @@ -343,6 +351,15 @@ class KmerMinAbundance: public KmerMinHash {
}
}

virtual void remove_hash(const HashIntoType h) {
auto pos = std::lower_bound(std::begin(mins), std::end(mins), h);
if (pos != mins.cend() and *pos == h) {
mins.erase(pos);
size_t dist = std::distance(begin(mins), pos);
abunds.erase(begin(abunds) + dist);
}
}

virtual void merge(const KmerMinAbundance& other) {
check_compatible(other);

Expand Down
21 changes: 13 additions & 8 deletions sourmash/sbtmh.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,24 +207,29 @@ class GatherMinHashesFindBestIgnoreMaxHash(object):
def __init__(self, initial_best_match=0.0):
self.best_match = initial_best_match

def search(self, node, sig, threshold, results=None):
mins = sig.minhash.get_mins()

def search(self, node, query, threshold, results=None):
score = 0
if not len(mins):
if not len(query.minhash):
return 0

if isinstance(node, SigLeaf):
max_scaled = max(node.data.minhash.scaled, sig.minhash.scaled)
max_scaled = max(node.data.minhash.scaled, query.minhash.scaled)

mh1 = node.data.minhash
if mh1.scaled != max_scaled:
mh1 = node.data.minhash.downsample_scaled(max_scaled)

mh2 = query.minhash
if mh2.scaled != max_scaled:
mh2 = query.minhash.downsample_scaled(max_scaled)

mh1 = node.data.minhash.downsample_scaled(max_scaled)
mh2 = sig.minhash.downsample_scaled(max_scaled)
matches = mh1.count_common(mh2)
else: # Nodegraph by minhash comparison
mins = query.minhash.get_mins()
get = node.data.get
matches = sum(1 for value in mins if get(value))

score = float(matches) / len(mins)
score = float(matches) / len(query.minhash)

# store results if we have passed in an appropriate dictionary
if results is not None:
Expand Down
23 changes: 12 additions & 11 deletions sourmash/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,9 +107,19 @@ def search_databases(query, databases, threshold, do_containment, best_only,
return results


# define a function to build new query object
def build_new_query(to_remove, old_query, scaled=None):
e = old_query.minhash
e.remove_many(to_remove)
if scaled:
e = e.downsample_scaled(scaled)
return SourmashSignature(e)


GatherResult = namedtuple('GatherResult',
'intersect_bp, f_orig_query, f_match, f_unique_to_query, f_unique_weighted, average_abund, median_abund, std_abund, filename, name, md5, leaf')


def gather_databases(query, databases, threshold_bp, ignore_abundance):
orig_query = query
orig_mins = orig_query.minhash.get_hashes()
Expand Down Expand Up @@ -174,17 +184,8 @@ def find_best(dblist, query, remainder):
return best_similarity, best_leaf, filename


# define a function to build new signature object from set of mins
def build_new_signature(mins, template_sig, scaled=None):
e = template_sig.minhash.copy_and_clear()
e.add_many(mins)
if scaled:
e = e.downsample_scaled(scaled)
return SourmashSignature(e)

# construct a new query that doesn't have the max_hash attribute set.
new_mins = query.minhash.get_hashes()
query = build_new_signature(new_mins, orig_query)
query = build_new_query([], orig_query)

cmp_scaled = 0
remainder = set()
Expand Down Expand Up @@ -263,8 +264,8 @@ def build_new_signature(mins, template_sig, scaled=None):
leaf=best_leaf)

# construct a new query, minus the previous one.
query = build_new_query(found_mins, orig_query, cmp_scaled)
query_mins -= set(found_mins)
query = build_new_signature(query_mins, orig_query, cmp_scaled)

weighted_missed = sum((orig_abunds[k] for k in query_mins)) \
/ sum_abunds
Expand Down
38 changes: 38 additions & 0 deletions tests/test__minhash.py
Original file line number Diff line number Diff line change
Expand Up @@ -1030,3 +1030,41 @@ def test_distance_matrix(track_abundance):
D1[i][j] = E.similarity(E2, track_abundance)

assert numpy.array_equal(D1, D2)


def test_remove_many(track_abundance):
a = MinHash(0, 10, track_abundance=track_abundance, max_hash=5000)

a.add_many(list(range(0, 100, 2)))

orig_sig = signature.SourmashSignature(a)
orig_md5 = orig_sig.md5sum()

a.remove_many(list(range(0, 100, 3)))
new_sig = signature.SourmashSignature(a)
new_md5 = new_sig.md5sum()

assert orig_md5 == "f1cc295157374f5c07cfca5f867188a1"
assert new_md5 == "dd93fa319ef57f4a019c59ee1a8c73e2"
assert orig_md5 != new_md5

assert len(a) == 33
assert all(c % 6 != 0 for c in a.get_mins())


def test_add_many(track_abundance):
a = MinHash(0, 10, track_abundance=track_abundance, max_hash=5000)
b = MinHash(0, 10, track_abundance=track_abundance, max_hash=5000)

a.add_many(list(range(0, 100, 2)))
a.add_many(list(range(0, 100, 2)))

assert len(a) == 50
assert all(c % 2 == 0 for c in a.get_mins())

for h in range(0, 100, 2):
b.add_hash(h)
b.add_hash(h)

assert len(b) == 50
assert a == b

0 comments on commit 64b3017

Please sign in to comment.