Add Nodegraph.similarity(): Jaccard-style similarity between two nodegraphs.

BitStorage::similarity() walks every table byte by byte, counts the bits set
in both storages (intersection) and the bits set in either storage (union),
and returns intersection / union as a double. When no bit is set in either
storage the union is forced to 1, so the result is 0 instead of a division
by zero. It throws oxli_exception when the table sizes differ.

Nodegraph::similarity() checks that both graphs use the same k size,
downcasts both stores to BitStorage and delegates to the method above. It
throws oxli_exception on a k-size mismatch or when either downcast fails.
The method is exposed to Python through the Cython Nodegraph wrapper and is
covered by tests/test_nodegraph.py::test_similarity_1.

NOTE(review): line breaks and indentation were reconstructed from a
whitespace-collapsed copy of this patch; blank context lines are inferred
from the hunk counts. Run "git apply --check" before applying.
NOTE(review): the context line "template void
Hashgraph::consume_seqfile_and_tag(" in the src/oxli/hashgraph.cc hunk is
reproduced as given; it may have lost an explicit template argument list in
extraction. Confirm it against the target tree.

diff --git a/include/oxli/hashgraph.hh b/include/oxli/hashgraph.hh
index f450a42e6b..f0d0d4e2db 100644
--- a/include/oxli/hashgraph.hh
+++ b/include/oxli/hashgraph.hh
@@ -293,6 +293,7 @@ public:
         : Hashgraph(ksize, new BitStorage(sizes)) { } ;
 
     void update_from(const Nodegraph &other);
+    double similarity(const Nodegraph &other);
 };
 
 }
diff --git a/include/oxli/storage.hh b/include/oxli/storage.hh
index 33fb6f7f73..37a69d9df8 100644
--- a/include/oxli/storage.hh
+++ b/include/oxli/storage.hh
@@ -226,6 +226,7 @@ public:
     }
 
     void update_from(const BitStorage&);
+    double similarity(const BitStorage&);
 };
 
 
diff --git a/khmer/_oxli/graphs.pxd b/khmer/_oxli/graphs.pxd
index ce4d290e86..16e8627591 100644
--- a/khmer/_oxli/graphs.pxd
+++ b/khmer/_oxli/graphs.pxd
@@ -204,6 +204,7 @@ cdef extern from "oxli/hashgraph.hh" namespace "oxli" nogil:
         CpNodegraph(WordLength, vector[uint64_t])
 
         void update_from(const CpNodegraph &) except +oxli_raise_py_error
+        double similarity(const CpNodegraph &) except +oxli_raise_py_error
 
 
 cdef extern from "oxli/labelhash.hh" namespace "oxli":
diff --git a/khmer/_oxli/graphs.pyx b/khmer/_oxli/graphs.pyx
index 992ae526ae..bf9c9b4f5e 100644
--- a/khmer/_oxli/graphs.pyx
+++ b/khmer/_oxli/graphs.pyx
@@ -866,3 +866,6 @@ cdef class Nodegraph(Hashgraph):
 
     def update(self, Nodegraph other):
         deref(self._ng_this).update_from(deref(other._ng_this))
+
+    def similarity(self, Nodegraph other):
+        return deref(self._ng_this).similarity(deref(other._ng_this))
diff --git a/src/oxli/hashgraph.cc b/src/oxli/hashgraph.cc
index c9cd78d860..cc8aa0a3fe 100644
--- a/src/oxli/hashgraph.cc
+++ b/src/oxli/hashgraph.cc
@@ -906,6 +906,23 @@ void Nodegraph::update_from(const Nodegraph &otherBASE)
     }
 }
 
+double Nodegraph::similarity(const Nodegraph &otherBASE)
+{
+    if (_ksize != otherBASE._ksize) {
+        throw oxli_exception("both nodegraphs must have same k size");
+    }
+    BitStorage * myself = dynamic_cast<BitStorage *>(this->store);
+    const BitStorage * other;
+    other = dynamic_cast<const BitStorage*>(otherBASE.store);
+
+    // if dynamic_cast worked, then the pointers will be not null.
+    if (myself && other) {
+        return myself->similarity(*other);
+    } else {
+        throw oxli_exception("similarity failed with incompatible objects");
+    }
+}
+
 template void Hashgraph::consume_seqfile_and_tag(
     std::string const &filename,
     unsigned int &total_reads,
diff --git a/src/oxli/storage.cc b/src/oxli/storage.cc
index 923843f451..8e9f559b20 100644
--- a/src/oxli/storage.cc
+++ b/src/oxli/storage.cc
@@ -95,6 +95,33 @@ void BitStorage::update_from(const BitStorage& other)
     }
 }
 
+double BitStorage::similarity(const BitStorage& other)
+{
+    if (_tablesizes != other._tablesizes) {
+        throw oxli_exception("both nodegraphs must have same table sizes");
+    }
+
+    uint64_t intersection = 0;
+    uint64_t union_size = 0;
+    for (unsigned int table_num = 0; table_num < _n_tables; table_num++) {
+        Byte * me = _counts[table_num];
+        Byte * ot = other._counts[table_num];
+        uint64_t tablesize = _tablesizes[table_num];
+        uint64_t tablebytes = tablesize / 8 + 1;
+
+        for (uint64_t index = 0; index < tablebytes; index++) {
+            // First, get how many values in common we have
+            intersection += __builtin_popcountll(me[index] & ot[index]);
+            union_size += __builtin_popcountll(me[index] | ot[index]);
+        }
+    }
+
+    if (union_size == 0) {
+        union_size = 1;
+    }
+
+    return double(intersection) / double(union_size);
+}
 
 void BitStorage::save(std::string outfilename, WordLength ksize)
 {
diff --git a/tests/test_nodegraph.py b/tests/test_nodegraph.py
index 607d521bfe..a5d1d42a7b 100755
--- a/tests/test_nodegraph.py
+++ b/tests/test_nodegraph.py
@@ -216,6 +216,26 @@ def test_update_from_diff_num_tables():
         print(str(err))
 
 
+def test_similarity_1():
+    nodegraph = khmer.Nodegraph(5, 1000, 4)
+    other_nodegraph = khmer.Nodegraph(5, 1000, 4)
+
+    assert nodegraph.similarity(other_nodegraph) == 0
+
+    other_nodegraph.count('AAAAA')
+
+    assert nodegraph.similarity(other_nodegraph) == 0
+
+    nodegraph.count('GCGCG')
+
+    assert nodegraph.similarity(other_nodegraph) == 0
+
+    nodegraph.count('AAAAA')
+    other_nodegraph.count('GCGCG')
+
+    assert nodegraph.similarity(other_nodegraph) == 1
+
+
 def test_n_occupied_1():
     filename = utils.get_test_data('random-20-a.fa')
 