From 93b9b90692c5beb371f80bc6d125b3976567e112 Mon Sep 17 00:00:00 2001
From: Luiz Irber
Date: Mon, 16 Dec 2019 01:40:48 +0000
Subject: [PATCH] Indexing in Rust (#773)

* working on keeping Index trait close to Index abc
add coverage CI job
initial index cmd impl
add batch_insert method
use needletail for fastx parsing
fix finch feature
lint clippy
fix benchmarks
comment out code that depends on unreleased crates...
another ci check, can we publish it\?
add windows, macos, beta and nightly tests
tests
move ocf to another repo
needletail 0.3.2
update for finch 0.3.0
fix warnings
initial search test
oracle tests for search
install sourmash for oracle tests
keep nodegraph working on wasm32
use only hash_functions internally
niffler instead of ocf
test all features for coverage
compatibility with sourmash-py SBT
note about md5sum and filenames
move tests into specific SBTs
remove unwrap
add note to nodegraph parsing
update typed-builder
fix nodegraph behavior to match khmer
ignore roundtrip_sbt while figuring out heisenbug
---
 .github/actions-rs/grcov.yml |   7 +
 .github/workflows/rust.yml   | 231 +++++++++++++++++++++
 .gitignore                   |  10 +-
 .travis.yml                  |  20 --
 Cargo.toml                   |  79 ++++----
 Makefile                     |   5 +-
 benches/index.rs             |  85 +++-----
 include/sourmash.h           |  51 +++++
 ocf/Cargo.toml               |  18 --
 ocf/src/lib.rs               | 275 -------------------------
 requirements.txt             |  19 +-
 setup.py                     |   8 +-
 sourmash/sbtmh.py            |   3 +-
 src/bin/smrs.rs              | 157 ++++++++++-----
 src/bin/sourmash.yml         |  19 ++
 src/cmd.rs                   | 118 ++++++-----
 src/errors.rs                |  13 +-
 src/ffi/minhash.rs           |  96 +++++++--
 src/ffi/mod.rs               |   5 +
 src/ffi/nodegraph.rs         | 181 +++++++++++++++++
 src/ffi/signature.rs         |  18 +-
 src/{ => ffi}/utils.rs       |   4 +-
 src/from.rs                  |  55 +++--
 src/index/bigsi.rs           |  72 +++----
 src/index/linear.rs          |  95 +++++----
 src/index/mod.rs             | 129 +++++++++---
 src/index/sbt/mhbt.rs        | 295 +++++++++++++++++++++++++--
 src/index/sbt/mhmt.rs        | 122 +++++++++--
 src/index/sbt/mod.rs         | 381 ++++++++++++-----------------------
 src/index/sbt/ukhs.rs        |  24 +--
 src/index/storage.rs         |   4 +-
 src/lib.rs                   |   3 -
 src/signature.rs             |  62 +++---
 src/sketch/minhash.rs        | 182 ++++++++++-------
 src/sketch/mod.rs            |   3 +-
 src/sketch/nodegraph.rs      |  93 +++++++--
 src/sketch/ukhs.rs           | 202 +++++++++++--------
 src/wasm.rs                  |  21 +-
 tests/minhash.rs             |  18 +-
 tests/smrs_cmd.rs            | 139 +++++++++++++
 tox.ini                      |  18 +-
 41 files changed, 2118 insertions(+), 1222 deletions(-)
 create mode 100644 .github/actions-rs/grcov.yml
 create mode 100644 .github/workflows/rust.yml
 delete mode 100644 ocf/Cargo.toml
 delete mode 100644 ocf/src/lib.rs
 create mode 100644 src/ffi/nodegraph.rs
 rename src/{ => ffi}/utils.rs (97%)
 create mode 100644 tests/smrs_cmd.rs

diff --git a/.github/actions-rs/grcov.yml b/.github/actions-rs/grcov.yml
new file mode 100644
index 0000000000..d8822ccdb7
--- /dev/null
+++ b/.github/actions-rs/grcov.yml
@@ -0,0 +1,7 @@
+branch: true
+ignore-not-existing: true
+llvm: true
+filter: covered
+output-type: lcov
+output-file: ./lcov.info
+prefix-dir: /home/user/build/
diff --git a/.github/workflows/rust.yml b/.github/workflows/rust.yml
new file mode 100644
index 0000000000..2972ac0b8a
--- /dev/null
+++ b/.github/workflows/rust.yml
@@ -0,0 +1,231 @@
+name: Rust checks
+
+on:
+  push:
+    branches: [master]
+  pull_request:
+    branches: [master]
+
+jobs:
+  check:
+    name: Check
+    runs-on: ubuntu-latest
+    steps:
+      - name: Checkout sources
+        uses: actions/checkout@v1
+
+      - name: Install stable toolchain
+        uses: actions-rs/toolchain@v1
+        with:
+          profile: minimal
+          toolchain: stable
+          override: true
+
+      - name: Run cargo check
+        uses: actions-rs/cargo@v1
+        with:
+          command: check
+
+  test:
+    runs-on: ${{ matrix.os }}
+    strategy:
+      fail-fast: false
+      matrix:
+        build: [beta, stable, windows, macos]
+        include:
+          - build: macos
+            os: macos-latest
+            rust: stable
+          - build: windows
+            os: windows-latest
+            rust: stable
+          - build: beta
+            os: ubuntu-latest
+            rust: beta
+          - build: stable
+            os: ubuntu-latest
+            rust: stable
+    steps:
+      - uses: actions/checkout@v1
+
+      - uses: actions-rs/toolchain@v1
+        with:
+          toolchain: ${{ matrix.rust }}
+          override: true
+
+      - name: Set up Python 3.8
+        if: matrix.os != 'windows-latest'
+        uses: actions/setup-python@v1
+        with:
+          python-version: "3.8"
+
+      - name: Install dependencies
+        if: matrix.os != 'windows-latest'
+        run: |
+          python -m pip install --upgrade pip
+          python -m pip install -e .
+
+      - name: Run tests
+        uses: actions-rs/cargo@v1
+        with:
+          command: test
+          args: --no-fail-fast
+
+  test_all_features:
+    runs-on: ubuntu-latest
+    steps:
+      - uses: actions/checkout@v1
+
+      - uses: actions-rs/toolchain@v1
+        with:
+          toolchain: stable
+          override: true
+
+      - name: Set up Python 3.8
+        uses: actions/setup-python@v1
+        with:
+          python-version: "3.8"
+
+      - name: Install dependencies
+        run: |
+          python -m pip install --upgrade pip
+          python -m pip install -e .
+
+      - name: Run tests
+        uses: actions-rs/cargo@v1
+        with:
+          command: test
+          args: --no-fail-fast --all --all-features
+
+  coverage:
+    runs-on: ubuntu-latest
+    steps:
+      - uses: actions/checkout@v1
+
+      - uses: actions-rs/toolchain@v1
+        with:
+          toolchain: nightly
+          override: true
+
+      - name: Set up Python 3.8
+        uses: actions/setup-python@v1
+        with:
+          python-version: "3.8"
+
+      - name: Install dependencies
+        run: |
+          python -m pip install --upgrade pip
+          python -m pip install -e .
+
+      - name: Run tests
+        uses: actions-rs/cargo@v1
+        with:
+          command: test
+          args: --no-fail-fast --all --all-features
+        env:
+          'CARGO_INCREMENTAL': '0'
+          'RUSTFLAGS': '-Zprofile -Ccodegen-units=1 -Cinline-threshold=0 -Clink-dead-code -Coverflow-checks=off -Zno-landing-pads'
+
+      - name: Collect coverage and generate report with grcov
+        uses: actions-rs/grcov@v0.1.4
+        id: coverage
+
+      - name: Upload coverage to codecov
+        uses: codecov/codecov-action@v1.0.3
+        with:
+          token: ${{ secrets.CODECOV_TOKEN }}
+          file: ${{ steps.coverage.outputs.report }}
+
+  lints:
+    name: Lints
+    runs-on: ubuntu-latest
+    steps:
+      - name: Checkout sources
+        uses: actions/checkout@v1
+
+      - name: Install stable toolchain
+        uses: actions-rs/toolchain@v1
+        with:
+          profile: minimal
+          toolchain: stable
+          override: true
+          components: rustfmt, clippy
+
+      - name: Run cargo fmt
+        uses: actions-rs/cargo@v1
+        with:
+          command: fmt
+          args: --all -- --check
+
+      - name: Run cargo clippy
+        uses: actions-rs/cargo@v1
+        with:
+          command: clippy
+          args: -- -D warnings
+
+  wasm-pack:
+    name: Check if wasm-pack builds a valid package
+    runs-on: ubuntu-latest
+    steps:
+      - uses: actions/checkout@master
+      - uses: actions-rs/toolchain@v1
+        with:
+          toolchain: stable
+          target: wasm32-unknown-unknown
+      - uses: actions-rs/cargo@v1
+        with:
+          command: install
+          args: --force wasm-pack
+      - name: run wasm-pack
+        run: wasm-pack build
+
+  wasm32-wasi:
+    name: Run tests under wasm32-wasi
+    runs-on: ubuntu-latest
+    steps:
+      - uses: actions/checkout@master
+      - name: Install wasm32-wasi target
+        uses: actions-rs/toolchain@v1
+        with:
+          toolchain: stable
+          target: wasm32-wasi
+      - name: Install wasmtime
+        run: "curl https://wasmtime.dev/install.sh -sSf | bash"
+      - name: Add wasmtime to PATH
+        run: echo "::add-path::$HOME/.wasmtime/bin"
+      - name: Install cargo-wasi command
+        uses: actions-rs/cargo@v1
+        with:
+          command: install
+          args: --force cargo-wasi
+      - name: Build code with cargo-wasi
+        uses: actions-rs/cargo@v1
+        with:
+          command: wasi
+          args: build
+      - name: Run tests under wasm32-wasi
+        uses: actions-rs/cargo@v1
+        continue-on-error: true  ## TODO: remove this when tests work...
+        with:
+          command: wasi
+          args: test
+
+  publish:
+    name: Publish (dry-run)
+    runs-on: ubuntu-latest
+    steps:
+      - name: Checkout sources
+        uses: actions/checkout@v1
+
+      - name: Install stable toolchain
+        uses: actions-rs/toolchain@v1
+        with:
+          profile: minimal
+          toolchain: stable
+          override: true
+
+      - name: Make sure we can publish the crate
+        uses: actions-rs/cargo@v1
+        with:
+          command: publish
+          args: --dry-run
diff --git a/.gitignore b/.gitignore
index db06c2c4ae..0aa0e7a6f1 100644
--- a/.gitignore
+++ b/.gitignore
@@ -24,4 +24,12 @@ sourmash/_minhash.cpp
 .pytest_cache
 .python-version
 sourmash/version.py
-*.DS_Store
\ No newline at end of file
+*.DS_Store
+.tox
+sourmash/_lowlevel*.py
+.env
+Pipfile
+Pipfile.lock
+ocf/target/
+target/
+Cargo.lock
diff --git a/.travis.yml b/.travis.yml
index a971a870f1..937b4b6c0b 100644
--- a/.travis.yml
+++ b/.travis.yml
@@ -48,26 +48,6 @@ jobs:
         python: 3.6
     - <<: *test
       python: 3.5
-    - <<: *test
-      name: wasm-pack
-      language: rust
-      rust: stable
-      before_script: skip
-      install: skip
-      script:
-        - rustup target add wasm32-unknown-unknown
-        - cargo install --force wasm-pack
-        - wasm-pack build
-    - <<: *test
-      name: wasi target
-      language: rust
-      rust: stable
-      before_script: skip
-      install: skip
-      script:
-        - rustup target add wasm32-wasi
-        - cargo install --force cargo-wasi
-        - cargo wasi build
     - &wheel
       stage: build wheel and send to github releases
diff --git a/Cargo.toml b/Cargo.toml
index 7c41ba6a2c..f188cc1ad4 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -8,6 +8,9 @@ keywords = ["minhash", "bioinformatics"]
 categories = ["science", "algorithms", "data-structures"]
 license = "BSD-3-Clause"
 edition = "2018"
+default-run = "smrs"
+autoexamples = false
+autobins = false
 
 [lib]
 name = "sourmash"
@@ -18,7 +21,7 @@ bench = false
 lto=true
 
 [features]
-from-finch = ["finch", "needletail"]
+from-finch = ["finch"]
 
 [workspace]
 
@@ -26,53 +29,59 @@ from-finch = ["finch", "needletail"]
 #cbindgen = "~0.6.7"
 
 [dependencies]
-byteorder = "^1.2"
-cfg-if = "0.1"
-clap = { version = "~2.32", features = ["yaml"] }
-env_logger = "0.6.0"
+byteorder = "1.3.2"
+cfg-if = "0.1.10"
+clap = { version = "2.33.0", features = ["yaml"] }
+env_logger = "0.7.1"
 exitfailure = "0.5.1"
-failure = "0.1.3"
-failure_derive = "0.1.3"
-finch = { version = "~0.1.6", optional = true }
-fixedbitset = "^0.1.9"
-human-panic = "1.0.1"
-lazy_static = "1.0.0"
+failure = "0.1.6"
+failure_derive = "0.1.6"
+finch = { version = "0.3.0", optional = true }
+fixedbitset = "0.2.0"
+lazy_static = "1.4.0"
 lazy-init = "0.3.0"
-log = "0.4.0"
-md5 = "0.6.0"
-murmurhash3 = "~0.0.5"
-needletail = { version = "~0.2.1", optional = true }
-serde = "1.0"
-serde_derive = "~1.0.58"
-serde_json = "1.0.2"
-ukhs = { git = "https://github.com/luizirber/ukhs", branch = "feature/alternative_backends", features = ["boomphf_mphf"], default-features = false}
-bio = { git = "https://github.com/luizirber/rust-bio", branch = "feature/fastx_reader" }
+log = "0.4.8"
+md5 = "0.7.0"
+murmurhash3 = "0.0.5"
+serde = "1.0.103"
+serde_derive = "1.0.103"
+serde_json = "1.0.44"
+#ukhs = { git = "https://github.com/luizirber/ukhs", branch = "feature/alternative_backends", features = ["boomphf_mphf"], default-features = false}
 primal = "0.2.3"
-pdatastructs = { git = "https://github.com/luizirber/pdatastructs.rs", branch = "succinct_wasm" }
-itertools = "0.8.0"
-typed-builder = "0.3.0"
-csv = "1.0.7"
+#pdatastructs = { git = "https://github.com/luizirber/pdatastructs.rs", branch = "succinct_wasm" }
+itertools = "0.8.2"
+typed-builder = "0.4.0"
+csv = "1.1.1"
 tempfile = "3.1.0"
 
+[dependencies.needletail]
+version = "0.3.2"
+default-features = false
+#features = ["compression"]
+
 [target.'cfg(all(target_arch = "wasm32", target_vendor="unknown"))'.dependencies.wasm-bindgen]
-version = "^0.2"
+version = "0.2.55"
 features = ["serde-serialize"]
 
-[target.'cfg(not(all(target_arch = "wasm32", target_vendor="unknown")))'.dependencies.ocf]
-version = "0.1"
-path = "ocf"
+[target.'cfg(not(all(target_arch = "wasm32", target_vendor="unknown")))'.dependencies.niffler]
+version = "1.0"
 default-features = false
 
-[target.'cfg(not(target_arch = "wasm32"))'.dependencies.mqf]
-version = "1.0.0"
+[target.'cfg(not(target_arch = "wasm32"))'.dev-dependencies]
+proptest = "0.9.4"
 
 [dev-dependencies]
-proptest = "^0.8"
-criterion = "^0.2"
-rand = "^0.5"
-tempfile = "3"
-assert_matches = "1.2"
+criterion = "0.3.0"
+rand = "0.7.2"
+tempfile = "3.1.0"
+assert_matches = "1.3.0"
+assert_cmd = "0.12.0"
+predicates = "1.0.2"
 
 [[bench]]
 name = "index"
 harness = false
+
+[[bin]]
+name = "smrs"
+path = "src/bin/smrs.rs"
diff --git a/Makefile b/Makefile
index 482f497302..8bb87d29dd 100644
--- a/Makefile
+++ b/Makefile
@@ -7,6 +7,7 @@ all:
 
 clean:
 	$(PYTHON) setup.py clean --all
+	rm -f sourmash/*.so
 	cd doc && make clean
 
 install: all
@@ -22,7 +23,7 @@ test: all
 doc: .PHONY
 	cd doc && make html
 
-include/sourmash.h: src/lib.rs src/ffi/minhash.rs src/ffi/signature.rs src/errors.rs
+include/sourmash.h: src/lib.rs src/ffi/minhash.rs src/ffi/signature.rs src/ffi/nodegraph.rs src/errors.rs
 	rustup override set nightly
 	RUST_BACKTRACE=1 cbindgen --clean -c cbindgen.toml -o $@
 	rustup override set stable
@@ -33,7 +34,7 @@ coverage: all
 	$(PYTHON) -m pytest --cov=. --cov-report term-missing
 
 benchmark:
-	asv continuous master $(git rev-parse HEAD)
+	asv continuous master `git rev-parse HEAD`
 
 check:
 	cargo build
diff --git a/benches/index.rs b/benches/index.rs
index 2ea53a5c72..e2f48a724f 100644
--- a/benches/index.rs
+++ b/benches/index.rs
@@ -6,9 +6,8 @@ use std::path::PathBuf;
 use criterion::{Bencher, Criterion, Fun};
 use sourmash::index::bigsi::BIGSI;
 use sourmash::index::linear::LinearIndex;
-use sourmash::index::storage::ReadData;
+use sourmash::index::Index;
 use sourmash::index::MHBT;
-use sourmash::index::{Dataset, Index};
 use sourmash::signature::Signature;
 
 fn find_small_bench(c: &mut Criterion) {
@@ -17,39 +16,30 @@ fn find_small_bench(c: &mut Criterion) {
 
     let sbt = MHBT::from_path(filename).expect("Loading error");
 
-    let leaf: Dataset<Signature> = (*sbt.datasets().first().unwrap()).clone();
+    let leaf: Signature = (*sbt.signatures().first().unwrap()).clone();
 
     let mut linear = LinearIndex::builder().storage(sbt.storage()).build();
-    for l in &sbt.datasets() {
-        linear.insert(l);
+    for l in sbt.signatures() {
+        linear.insert(l).unwrap();
     }
 
     let mut bigsi = BIGSI::new(10000, 10);
-    for l in &sbt.datasets() {
-        let data = l.data().unwrap();
-        bigsi.insert(data);
+    for l in sbt.signatures() {
+        bigsi.insert(l).unwrap();
     }
 
-    let sbt_find = Fun::new(
-        "sbt_search",
-        move |b: &mut Bencher, leaf: &Dataset<Signature>| b.iter(|| sbt.search(leaf, 0.1, false)),
-    );
-
-    let linear_find = Fun::new(
-        "linear_search",
-        move |b: &mut Bencher, leaf: &Dataset<Signature>| {
-            b.iter(|| linear.search(leaf, 0.1, false))
-        },
-    );
-
-    let bigsi_find = Fun::new(
-        "bigsi_search",
-        move |b: &mut Bencher, leaf: &Dataset<Signature>| {
-            let data = leaf.data().unwrap();
-            b.iter(|| bigsi.search(data, 0.1, false))
-        },
-    );
+    let sbt_find = Fun::new("sbt_search", move |b: &mut Bencher, leaf: &Signature| {
+        b.iter(|| sbt.search(leaf, 0.1, false))
+    });
+
+    let linear_find = Fun::new("linear_search", move |b: &mut Bencher, leaf: &Signature| {
+        b.iter(|| linear.search(leaf, 0.1, false))
+    });
+
+    let bigsi_find = Fun::new("bigsi_search", move |b: &mut Bencher, leaf: &Signature| {
+        b.iter(|| bigsi.search(leaf, 0.1, false))
+    });
 
     let functions = vec![sbt_find, linear_find, bigsi_find];
     c.bench_functions("search_small", functions, leaf);
@@ -61,38 +51,29 @@ fn find_subset_bench(c: &mut Criterion) {
 
     let sbt = MHBT::from_path(filename).expect("Loading error");
 
-    let leaf: Dataset<Signature> = (*sbt.datasets().first().unwrap()).clone();
+    let leaf: Signature = (*sbt.signatures().first().unwrap()).clone();
 
     let mut linear = LinearIndex::builder().storage(sbt.storage()).build();
-    for l in &sbt.datasets() {
-        linear.insert(l);
+    for l in sbt.signatures() {
+        linear.insert(l).unwrap();
    }
 
     let mut bigsi = BIGSI::new(10000, 10);
-    for l in &sbt.datasets() {
-        let data = l.data().unwrap();
-        bigsi.insert(data);
+    for l in sbt.signatures() {
+        bigsi.insert(l).unwrap();
    }
 
-    let sbt_find = Fun::new(
-        "sbt_search",
-        move |b: &mut Bencher, leaf: &Dataset<Signature>| b.iter(|| sbt.search(leaf, 0.1, false)),
-    );
-
-    let linear_find = Fun::new(
-        "linear_search",
-        move |b: &mut Bencher, leaf: &Dataset<Signature>| {
-            b.iter(|| linear.search(leaf, 0.1, false))
-        },
-    );
-
-    let bigsi_find = Fun::new(
-        "bigsi_search",
-        move |b: &mut Bencher, leaf: &Dataset<Signature>| {
-            let data = leaf.data().unwrap();
-            b.iter(|| bigsi.search(data, 0.1, false))
-        },
-    );
+    let sbt_find = Fun::new("sbt_search", move |b: &mut Bencher, leaf: &Signature| {
+        b.iter(|| sbt.search(leaf, 0.1, false))
+    });
+
+    let linear_find = Fun::new("linear_search", move |b: &mut Bencher, leaf: &Signature| {
+        b.iter(|| linear.search(leaf, 0.1, false))
+    });
+
+    let bigsi_find = Fun::new("bigsi_search", move |b: &mut Bencher, leaf: &Signature| {
+        b.iter(|| bigsi.search(leaf, 0.1, false))
+    });
 
     let functions = vec![sbt_find, linear_find, bigsi_find];
     c.bench_functions("search_subset", functions, leaf);
diff --git a/include/sourmash.h b/include/sourmash.h
index ed01e1cb3c..607b37e385 100644
--- a/include/sourmash.h
+++ b/include/sourmash.h
@@ -8,6 +8,14 @@
 #include <stdint.h>
 #include <stdlib.h>
 
+enum HashFunctions {
+  HASH_FUNCTIONS_MURMUR64_DNA = 1,
+  HASH_FUNCTIONS_MURMUR64_PROTEIN = 2,
+  HASH_FUNCTIONS_MURMUR64_DAYHOFF = 3,
+  HASH_FUNCTIONS_MURMUR64_HP = 4,
+};
+typedef uint32_t HashFunctions;
+
 enum SourmashErrorCode {
   SOURMASH_ERROR_CODE_NO_ERROR = 0,
   SOURMASH_ERROR_CODE_PANIC = 1,
@@ -32,6 +40,8 @@ typedef uint32_t SourmashErrorCode;
 
 typedef struct KmerMinHash KmerMinHash;
 
+typedef struct Nodegraph Nodegraph;
+
 typedef struct Signature Signature;
 
 /**
@@ -57,6 +67,8 @@ void kmerminhash_add_word(KmerMinHash *ptr, const char *word);
 
 double kmerminhash_compare(KmerMinHash *ptr, const KmerMinHash *other);
 
+double kmerminhash_containment_ignore_maxhash(KmerMinHash *ptr, const KmerMinHash *other);
+
 uint64_t kmerminhash_count_common(KmerMinHash *ptr, const KmerMinHash *other);
 
 bool kmerminhash_dayhoff(KmerMinHash *ptr);
@@ -79,6 +91,12 @@ const uint64_t *kmerminhash_get_mins(KmerMinHash *ptr);
 
 uintptr_t kmerminhash_get_mins_size(KmerMinHash *ptr);
 
+HashFunctions kmerminhash_hash_function(KmerMinHash *ptr);
+
+void kmerminhash_hash_function_set(KmerMinHash *ptr, HashFunctions hash_function);
+
+bool kmerminhash_hp(KmerMinHash *ptr);
+
 uint64_t kmerminhash_intersection(KmerMinHash *ptr, const KmerMinHash *other);
 
 bool kmerminhash_is_protein(KmerMinHash *ptr);
@@ -95,6 +113,7 @@ KmerMinHash *kmerminhash_new(uint32_t n,
                              uint32_t k,
                              bool prot,
                              bool dayhoff,
+                             bool hp,
                              uint64_t seed,
                              uint64_t mx,
                              bool track_abundance);
@@ -109,6 +128,36 @@ uint64_t kmerminhash_seed(KmerMinHash *ptr);
 
 bool kmerminhash_track_abundance(KmerMinHash *ptr);
 
+bool nodegraph_count(Nodegraph *ptr, uint64_t h);
+
+double nodegraph_expected_collisions(Nodegraph *ptr);
+
+void nodegraph_free(Nodegraph *ptr);
+
+Nodegraph *nodegraph_from_buffer(const char *ptr, uintptr_t insize);
+
+Nodegraph *nodegraph_from_path(const char *filename);
+
+uintptr_t nodegraph_get(Nodegraph *ptr, uint64_t h);
+
+uintptr_t nodegraph_ksize(Nodegraph *ptr);
+
+uintptr_t nodegraph_matches(Nodegraph *ptr, KmerMinHash *mh_ptr);
+
+Nodegraph *nodegraph_new(void);
+
+uintptr_t nodegraph_noccupied(Nodegraph *ptr);
+
+uintptr_t nodegraph_ntables(Nodegraph *ptr);
+
+void nodegraph_save(Nodegraph *ptr, const char *filename);
+
+uintptr_t nodegraph_tablesize(Nodegraph *ptr);
+
+void nodegraph_update(Nodegraph *ptr, Nodegraph *optr);
+
+Nodegraph *nodegraph_with_tables(uintptr_t ksize, uintptr_t starting_size, uintptr_t n_tables);
+
 bool signature_eq(Signature *ptr, Signature *other);
 
 KmerMinHash *signature_first_mh(Signature *ptr);
@@ -152,6 +201,8 @@ SourmashStr signatures_save_buffer(Signature **ptr, uintptr_t size);
 
 char sourmash_aa_to_dayhoff(char aa);
 
+char sourmash_aa_to_hp(char aa);
+
 /**
  * Clears the last error.
 */
diff --git a/ocf/Cargo.toml b/ocf/Cargo.toml
deleted file mode 100644
index 1752b25b6d..0000000000
--- a/ocf/Cargo.toml
+++ /dev/null
@@ -1,18 +0,0 @@
-[package]
-name = "ocf"
-version = "0.1.0"
-authors = ["Luiz Irber "]
-edition = "2018"
-
-[features]
-default = ["bzip2", "xz2"]
-bz2 = ["bzip2"]
-lzma = ["xz2"]
-
-[dependencies]
-bzip2 = { version = "0.3.3", optional = true }
-cfg-if = "0.1"
-failure = "0.1.3"
-flate2 = "1.0"
-enum_primitive = "0.1.1"
-xz2 = { version = "0.1", optional = true }
diff --git a/ocf/src/lib.rs b/ocf/src/lib.rs
deleted file mode 100644
index f9662244fc..0000000000
--- a/ocf/src/lib.rs
+++ /dev/null
@@ -1,275 +0,0 @@
-/*
-Copyright (c) 2018 Pierre Marijon
-
-Permission is hereby granted, free of charge, to any person obtaining a copy
-of this software and associated documentation files (the "Software"), to deal
-in the Software without restriction, including without limitation the rights
-to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-copies of the Software, and to permit persons to whom the Software is
-furnished to do so, subject to the following conditions:
-
-The above copyright notice and this permission notice shall be included in all
-copies or substantial portions of the Software.
-
-THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
-SOFTWARE.
-
-Originally from https://github.com/natir/yacrd/blob/3fc6ef8b5b51256f0c4bc45b8056167acf34fa58/src/file.rs
-Changes:
-  - make bzip2 and lzma support optional
-*/
-
-/* crates use */
-use cfg_if::cfg_if;
-use enum_primitive::{
-    enum_from_primitive, enum_from_primitive_impl, enum_from_primitive_impl_ty, FromPrimitive,
-};
-use failure::{Error, Fail};
-use flate2;
-
-/* standard use */
-use std::fs::File;
-use std::io;
-use std::io::{BufReader, BufWriter};
-
-enum_from_primitive! {
-    #[repr(u64)]
-    #[derive(Debug, PartialEq)]
-    pub enum CompressionFormat {
-        Gzip = 0x1F8B,
-        Bzip = 0x425A,
-        Lzma = 0xFD377A585A,
-        No,
-    }
-}
-
-#[derive(Debug, Fail)]
-pub enum OCFError {
-    #[fail(display = "Feature disabled, enabled it during compilation")]
-    FeatureDisabled,
-}
-
-pub fn get_input(input_name: &str) -> Result<(Box<dyn io::Read>, CompressionFormat), Error> {
-    // choose std::io::stdin or open file
-    if input_name == "-" {
-        Ok((Box::new(get_readable(input_name)), CompressionFormat::No))
-    } else {
-        get_readable_file(input_name)
-    }
-}
-
-pub fn get_readable_file(
-    input_name: &str,
-) -> Result<(Box<dyn io::Read>, CompressionFormat), Error> {
-    let raw_input = get_readable(input_name);
-
-    // check compression
-    let compression = get_compression(raw_input);
-
-    // return readable and compression status
-    match compression {
-        CompressionFormat::Gzip => Ok((
-            Box::new(flate2::read::GzDecoder::new(get_readable(input_name))),
-            CompressionFormat::Gzip,
-        )),
-        CompressionFormat::Bzip => new_bz2_decoder(get_readable(input_name)),
-        CompressionFormat::Lzma => new_lzma_decoder(get_readable(input_name)),
-        CompressionFormat::No => Ok((Box::new(get_readable(input_name)), CompressionFormat::No)),
-    }
-}
-
-pub fn get_readable(input_name: &str) -> Box<dyn io::Read> {
-    match input_name {
-        "-" => Box::new(BufReader::new(io::stdin())),
-        _ => Box::new(BufReader::new(
-            File::open(input_name)
-                .unwrap_or_else(|_| panic!("Can't open input file {}", input_name)),
-        )),
-    }
-}
-
-fn get_compression(mut in_stream: Box<dyn io::Read>) -> CompressionFormat {
-    let mut buf = vec![0u8; 5];
-
-    in_stream
-        .read_exact(&mut buf)
-        .expect("Error durring reading first bit of file");
-
-    let mut five_bit_val: u64 = 0;
-    for (i, item) in buf.iter().enumerate().take(5) {
-        five_bit_val |= (u64::from(*item)) << (8 * (4 - i));
-    }
-    if CompressionFormat::from_u64(five_bit_val) == Some(CompressionFormat::Lzma) {
-        return CompressionFormat::Lzma;
-    }
-
-    let mut two_bit_val: u64 = 0;
-    for (i, item) in buf.iter().enumerate().take(2) {
-        two_bit_val |= (u64::from(*item)) << (8 * (1 - i));
-    }
-
-    match CompressionFormat::from_u64(two_bit_val) {
-        e @ Some(CompressionFormat::Gzip) | e @ Some(CompressionFormat::Bzip) => e.unwrap(),
-        _ => CompressionFormat::No,
-    }
-}
-
-cfg_if! {
-    if #[cfg(feature = "bz2")] {
-        use bzip2;
-
-        fn new_bz2_encoder(out: Box<dyn io::Write>) -> Result<Box<dyn io::Write>, Error> {
-            Ok(Box::new(bzip2::write::BzEncoder::new(
-                out,
-                bzip2::Compression::Best,
-            )))
-        }
-
-        fn new_bz2_decoder(
-            inp: Box<dyn io::Read>,
-        ) -> Result<(Box<dyn io::Read>, CompressionFormat), Error> {
-            use bzip2;
-            Ok((
-                Box::new(bzip2::read::BzDecoder::new(inp)),
-                CompressionFormat::Bzip,
-            ))
-        }
-    } else {
-        fn new_bz2_encoder(_: Box<dyn io::Write>) -> Result<Box<dyn io::Write>, Error> {
-            Err(OCFError::FeatureDisabled.into())
-        }
-
-        fn new_bz2_decoder(_: Box<dyn io::Read>) -> Result<(Box<dyn io::Read>, CompressionFormat), Error> {
-            Err(OCFError::FeatureDisabled.into())
-        }
-    }
-}
-
-cfg_if! {
-    if #[cfg(feature = "lzma")] {
-        use xz2;
-
-        fn new_lzma_encoder(out: Box<dyn io::Write>) -> Result<Box<dyn io::Write>, Error> {
-            Ok(Box::new(xz2::write::XzEncoder::new(out, 9)))
-        }
-
-        fn new_lzma_decoder(
-            inp: Box<dyn io::Read>,
-        ) -> Result<(Box<dyn io::Read>, CompressionFormat), Error> {
-            use xz2;
-            Ok((
-                Box::new(xz2::read::XzDecoder::new(inp)),
-                CompressionFormat::Lzma,
-            ))
-        }
-    } else {
-        fn new_lzma_encoder(_: Box<dyn io::Write>) -> Result<Box<dyn io::Write>, Error> {
-            Err(OCFError::FeatureDisabled.into())
-        }
-
-        fn new_lzma_decoder(_: Box<dyn io::Read>) -> Result<(Box<dyn io::Read>, CompressionFormat), Error> {
-            Err(OCFError::FeatureDisabled.into())
-        }
-    }
-}
-
-pub fn get_output(
-    output_name: &str,
-    format: CompressionFormat,
-) -> Result<Box<dyn io::Write>, Error> {
-    match format {
-        CompressionFormat::Gzip => Ok(Box::new(flate2::write::GzEncoder::new(
-            get_writable(output_name),
-            flate2::Compression::best(),
-        ))),
-        CompressionFormat::Bzip => new_bz2_encoder(get_writable(output_name)),
-        CompressionFormat::Lzma => new_lzma_encoder(get_writable(output_name)),
-        CompressionFormat::No => Ok(Box::new(get_writable(output_name))),
-    }
-}
-
-pub fn choose_compression(
-    input_compression: CompressionFormat,
-    compression_set: bool,
-    compression_value: &str,
-) -> CompressionFormat {
-    if !compression_set {
-        return input_compression;
-    }
-
-    match compression_value {
-        "gzip" => CompressionFormat::Gzip,
-        "bzip2" => CompressionFormat::Bzip,
-        "lzma" => CompressionFormat::Lzma,
-        _ => CompressionFormat::No,
-    }
-}
-
-fn get_writable(output_name: &str) -> Box<dyn io::Write> {
-    match output_name {
-        "-" => Box::new(BufWriter::new(io::stdout())),
-        _ => Box::new(BufWriter::new(
-            File::create(output_name)
-                .unwrap_or_else(|_| panic!("Can't open output file {}", output_name)),
-        )),
-    }
-}
-
-#[cfg(test)]
-mod test {
-
-    use super::*;
-
-    const GZIP_FILE: &'static [u8] = &[0o037, 0o213, 0o0, 0o0, 0o0];
-    const BZIP_FILE: &'static [u8] = &[0o102, 0o132, 0o0, 0o0, 0o0];
-    const LZMA_FILE: &'static [u8] = &[0o375, 0o067, 0o172, 0o130, 0o132];
-
-    #[test]
-    fn compression_from_file() {
-        assert_eq!(
-            get_compression(Box::new(GZIP_FILE)),
-            CompressionFormat::Gzip
-        );
-        assert_eq!(
-            get_compression(Box::new(BZIP_FILE)),
-            CompressionFormat::Bzip
-        );
-        assert_eq!(
-            get_compression(Box::new(LZMA_FILE)),
-            CompressionFormat::Lzma
-        );
-    }
-
-    #[test]
-    fn compression_from_input_or_cli() {
-        assert_eq!(
-            choose_compression(CompressionFormat::Gzip, false, "_"),
-            CompressionFormat::Gzip
-        );
-        assert_eq!(
-            choose_compression(CompressionFormat::Bzip, false, "_"),
-            CompressionFormat::Bzip
-        );
-        assert_eq!(
-            choose_compression(CompressionFormat::Lzma, false, "_"),
-            CompressionFormat::Lzma
-        );
-        assert_eq!(
-            choose_compression(CompressionFormat::No, true, "gzip"),
-            CompressionFormat::Gzip
-        );
-        assert_eq!(
-            choose_compression(CompressionFormat::No, true, "bzip2"),
-            CompressionFormat::Bzip
-        );
-        assert_eq!(
-            choose_compression(CompressionFormat::No, true, "lzma"),
-            CompressionFormat::Lzma
-        );
-    }
-}
diff --git a/requirements.txt b/requirements.txt
index 7def5908b1..5be6ef7bce 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,15 +1,4 @@
-pytest
-screed
-numpy
-matplotlib
-scipy
-Cython
-khmer>=2.1,<3
-sphinx
-alabaster
-recommonmark
-sphinxcontrib-napoleon
-setuptools_scm
-setuptools_scm_git_archive
-nbsphinx
-bam2fasta
+-e .[test]
+-e .[doc]
+-e .[10x]
+-e .[storage]
diff --git a/setup.py b/setup.py
index 6b43497c4a..1f1720821d 100644
--- a/setup.py
+++ b/setup.py
@@ -70,10 +70,12 @@
                        'setuptools_scm', 'setuptools_scm_git_archive'],
     "use_scm_version": {"write_to": "sourmash/version.py"},
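# NOTE: requirements.txt (above) now just pulls in the extras defined here;
# assuming standard pip extras behaviour, a development setup becomes, e.g.:
#
#     pip install -e ".[test,doc,storage]"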
"extras_require": { - 'test' : ['pytest', 'pytest-cov', 'recommonmark'], + 'test' : ['pytest', 'pytest-cov'], 'demo' : ['jupyter', 'jupyter_client', 'ipython'], - 'doc' : ['sphinx'], - '10x': ['bam2fasta==1.0.1'] + 'doc' : ['sphinx', 'recommonmark', 'alabaster', + "sphinxcontrib-napoleon", "nbsphinx"], + '10x': ['bam2fasta==1.0.1'], + 'storage': ["ipfshttpclient", "redis"] }, "include_package_data": True, "package_data": { diff --git a/sourmash/sbtmh.py b/sourmash/sbtmh.py index 44067d8965..2289e0cb94 100644 --- a/sourmash/sbtmh.py +++ b/sourmash/sbtmh.py @@ -57,8 +57,7 @@ def update(self, parent): for v in self.data.minhash.get_mins(): parent.data.count(v) min_n_below = parent.metadata.get('min_n_below', sys.maxsize) - min_n_below = min(len(self.data.minhash.get_mins()), - min_n_below) + min_n_below = min(len(self.data.minhash), min_n_below) if min_n_below == 0: min_n_below = 1 diff --git a/src/bin/smrs.rs b/src/bin/smrs.rs index d0ea0b0847..b58f2c44ee 100644 --- a/src/bin/smrs.rs +++ b/src/bin/smrs.rs @@ -6,25 +6,82 @@ use std::rc::Rc; use clap::{load_yaml, App}; use exitfailure::ExitFailure; use failure::Error; -//use human_panic::setup_panic; -use lazy_init::Lazy; use log::{info, LevelFilter}; -use ocf::{get_output, CompressionFormat}; +use niffler::{get_output, CompressionFormat}; use serde::ser::SerializeStruct; use serde::{Serialize, Serializer}; -use sourmash::cmd::{ - count_unique, draff_compare, draff_index, draff_search, draff_signature, prepare, -}; +/* FIXME bring back after succint-rs changes +use sourmash::cmd::{count_unique, draff_compare, draff_search, draff_signature, prepare}; +*/ +use sourmash::cmd::prepare; + use sourmash::index::linear::LinearIndex; use sourmash::index::sbt::scaffold; use sourmash::index::search::{ search_minhashes, search_minhashes_containment, search_minhashes_find_best, }; -use sourmash::index::{Comparable, Dataset, Index, MHBT}; +use sourmash::index::storage::{FSStorage, Storage}; +use sourmash::index::{Comparable, Index, MHBT}; use sourmash::signature::{Signature, SigsTrait}; use sourmash::sketch::Sketch; +pub fn index( + sig_files: Vec<&str>, + storage: Rc, + outfile: &str, +) -> Result { + let mut index = MHBT::builder().storage(Rc::clone(&storage)).build(); + + for filename in sig_files { + // TODO: check for stdin? can also use get_input()? + + let mut sig = Signature::from_path(filename)?; + + if sig.len() > 1 { + unimplemented!(); + }; + + index.insert(sig.pop().unwrap())?; + } + + // TODO: implement to_writer and use this? + //let mut output = get_output(outfile, CompressionFormat::No)?; + //index.to_writer(&mut output)? + + index.save_file(outfile, Some(storage))?; + + Ok(Indices::MHBT(index)) + + /* + let mut lindex = LinearIndex::::builder() + .storage(Rc::clone(&storage)) + .build(); + + for filename in sig_files { + // TODO: check for stdin? can also use get_input()? + + let mut sig = Signature::from_path(filename)?; + + if sig.len() > 1 { + unimplemented!(); + }; + + lindex.insert(sig.pop().unwrap())?; + } + + let mut index: MHBT = lindex.into(); + + // TODO: implement to_writer and use this? + //let mut output = get_output(outfile, CompressionFormat::No)?; + //index.to_writer(&mut output)? 
+
+    index.save_file(outfile, Some(storage))?;
+
+    Ok(Indices::MHBT(index))
+
+    /*
+    let mut lindex = LinearIndex::<Signature>::builder()
+        .storage(Rc::clone(&storage))
+        .build();
+
+    for filename in sig_files {
+        // TODO: check for stdin? can also use get_input()?
+
+        let mut sig = Signature::from_path(filename)?;
+
+        if sig.len() > 1 {
+            unimplemented!();
+        };
+
+        lindex.insert(sig.pop().unwrap())?;
+    }
+
+    let mut index: MHBT = lindex.into();
+
+    // TODO: implement to_writer and use this?
+    //let mut output = get_output(outfile, CompressionFormat::No)?;
+    //index.to_writer(&mut output)?
+
+    index.save_file(outfile, Some(storage))?;
+
+    Ok(Indices::MHBT(index))
+    */
+}
+
 struct Query<T> {
     data: T,
 }
@@ -53,28 +110,19 @@ impl Query<Signature> {
     }
 
     fn name(&self) -> String {
-        self.data.name().clone()
+        self.data.name()
     }
 }
 
-impl From<Query<Signature>> for Dataset<Signature> {
-    fn from(other: Query<Signature>) -> Dataset<Signature> {
-        let data = Lazy::new();
-        data.get_or_create(|| other.data);
-
-        Dataset::builder()
-            .data(Rc::new(data))
-            .filename("")
-            .name("")
-            .metadata("")
-            .storage(None)
-            .build()
+impl From<Query<Signature>> for Signature {
+    fn from(other: Query<Signature>) -> Signature {
+        other.data
     }
 }
 
 fn load_query_signature(
     query: &str,
-    ksize: usize,
+    ksize: Option<usize>,
     moltype: Option<&str>,
     scaled: Option<u64>,
 ) -> Result<Query<Signature>, Error> {
@@ -93,13 +141,13 @@ struct Database {
     path: String,
 }
 
-enum Indices {
+pub enum Indices {
     MHBT(MHBT),
-    LinearIndex(LinearIndex<Dataset<Signature>>),
+    LinearIndex(LinearIndex<Signature>),
 }
 
-impl Index for Database {
-    type Item = Dataset<Signature>;
+impl Index<'_> for Database {
+    type Item = Signature;
 
     fn find<F>(
         &self,
         search_fn: F,
         sig: &Self::Item,
         threshold: f64,
     ) -> Result<Vec<&Self::Item>, Error>
     where
         F: Fn(&dyn Comparable<Self::Item>, &Self::Item, f64) -> bool,
     {
         match &self.data {
             Indices::MHBT(data) => data.find(search_fn, sig, threshold),
             Indices::LinearIndex(data) => data.find(search_fn, sig, threshold),
         }
     }
 
-    fn insert(&mut self, node: &Self::Item) -> Result<(), Error> {
+    fn insert(&mut self, node: Self::Item) -> Result<(), Error> {
         match &mut self.data {
             Indices::MHBT(data) => data.insert(node),
             Indices::LinearIndex(data) => data.insert(node),
@@ -134,10 +182,17 @@ impl Index for Database {
         unimplemented!();
     }
 
-    fn datasets(&self) -> Vec<Self::Item> {
+    fn signatures(&self) -> Vec<Self::Item> {
+        match &self.data {
+            Indices::MHBT(data) => data.signatures(),
+            Indices::LinearIndex(data) => data.signatures(),
+        }
+    }
+
+    fn signature_refs(&self) -> Vec<&Self::Item> {
         match &self.data {
-            Indices::MHBT(data) => data.datasets(),
-            Indices::LinearIndex(data) => data.datasets(),
+            Indices::MHBT(data) => data.signature_refs(),
+            Indices::LinearIndex(data) => data.signature_refs(),
         }
     }
 }
@@ -170,7 +225,7 @@ fn load_sbts_and_sigs(
             info!("loaded SBT {}", path);
             n_databases += 1;
             continue;
-        } else if let Ok(data) = LinearIndex::<Dataset<Signature>>::from_path(path) {
+        } else if let Ok(data) = LinearIndex::<Signature>::from_path(path) {
             // TODO: check compatible
             dbs.push(Database {
                 data: Indices::LinearIndex(data),
@@ -251,7 +306,7 @@ fn search_databases(
             if similarity >= threshold {
                 results.push(Results {
                     similarity,
-                    match_sig: dataset.clone().into(),
+                    match_sig: dataset.clone(),
                     db: db.path.clone(),
                 })
             }
@@ -263,7 +318,7 @@ fn search_databases(
 }
 
 fn main() -> Result<(), ExitFailure> {
-    //setup_panic!();
+    //better_panic::install();
 
     env_logger::Builder::from_env(env_logger::Env::default().default_filter_or("info")).init();
 
@@ -271,6 +326,7 @@ fn main() -> Result<(), ExitFailure> {
     let m = App::from_yaml(yml).get_matches();
 
     match m.subcommand_name() {
+        /* FIXME bring back after succint-rs changes
         Some("draff") => {
             let cmd = m.subcommand_matches("draff").unwrap();
             let inputs = cmd
                 .values_of("inputs")
                .map(|vals| vals.collect::<Vec<&str>>())
                 .unwrap();
@@ -300,6 +356,14 @@ fn main() -> Result<(), ExitFailure> {
 
             draff_compare(inputs)?;
         }
+        Some("count_unique") => {
+            let cmd = m.subcommand_matches("count_unique").unwrap();
+
+            let index: &str = cmd.value_of("index").unwrap();
+
+            count_unique(index)?;
+        }
+        */
         Some("prepare") => {
             let cmd = m.subcommand_matches("prepare").unwrap();
             let index: &str = cmd.value_of("index").unwrap();
@@ -311,29 +375,29 @@ fn main() -> Result<(), ExitFailure> {
             let inputs = cmd
                 .values_of("inputs")
                 .map(|vals| vals.collect::<Vec<&str>>())
-                .unwrap();
+                .expect("Missing inputs");
 
-            let output: &str = cmd.value_of("output").unwrap();
+            let output: &str = cmd.value_of("output").expect("Missing output");
+            let (output, base) = if output.ends_with(".sbt.json") {
+                (output.to_owned(), output.trim_end_matches(".sbt.json"))
+            } else {
+                (output.to_owned() + ".sbt.json", output)
+            };
+
+            let storage: Rc<dyn Storage> = Rc::new(FSStorage::new(".", &format!(".sbt.{}", base)));
 
-            draff_index(inputs, output)?;
+            index(inputs, storage, &output)?;
         }
         Some("scaffold") => {
             let cmd = m.subcommand_matches("scaffold").unwrap();
             let sbt_file = cmd.value_of("current_sbt").unwrap();
 
             let sbt = MHBT::from_path(sbt_file)?;
-            let mut new_sbt: MHBT = scaffold(sbt.datasets(), sbt.storage());
+            let mut new_sbt: MHBT = scaffold(sbt.leaves(), sbt.storage());
 
             new_sbt.save_file("test", None)?;
 
-            assert_eq!(new_sbt.datasets().len(), sbt.datasets().len());
-        }
-        Some("count_unique") => {
-            let cmd = m.subcommand_matches("count_unique").unwrap();
-
-            let index: &str = cmd.value_of("index").unwrap();
-
-            count_unique(index)?;
+            assert_eq!(new_sbt.leaves().len(), sbt.leaves().len());
         }
         Some("search") => {
             let cmd = m.subcommand_matches("search").unwrap();
@@ -345,10 +409,9 @@ fn main() -> Result<(), ExitFailure> {
             let query = load_query_signature(
                 cmd.value_of("query").unwrap(),
                 if cmd.is_present("ksize") {
-                    cmd.value_of("ksize").unwrap().parse().unwrap()
+                    Some(cmd.value_of("ksize").unwrap().parse().unwrap())
                 } else {
-                    // TODO default k
-                    unimplemented!()
+                    None
                 },
                 Some("dna"), // TODO: select moltype,
                 if cmd.is_present("scaled") {
diff --git a/src/bin/sourmash.yml b/src/bin/sourmash.yml
index 765358e2a1..ecc072e228 100644
--- a/src/bin/sourmash.yml
+++ b/src/bin/sourmash.yml
@@ -108,6 +108,25 @@ subcommands:
       args:
         - index:
             help: SBT index
+    - index:
+        about: create an index
+        settings:
+          - ArgRequiredElseHelp
+        args:
+          - ksize:
+              help: "k-mer size for which to build the SBT."
+              short: k
+              long: "ksize"
+              takes_value: true
+              required: false
+          - output:
+              help: alternative output file
+              short: o
+              takes_value: true
+              required: false
+          - inputs:
+              help: signatures
+              multiple: true
 
 #    groups:
 #      - protein:
diff --git a/src/cmd.rs b/src/cmd.rs
index 79538ff82d..c16ea49c3c 100644
--- a/src/cmd.rs
+++ b/src/cmd.rs
@@ -1,26 +1,29 @@
+use failure::Error;
+
+use crate::index::MHBT;
+
+/* FIXME: bring back after boomphf changes
 use std::path::{Path, PathBuf};
 use std::rc::Rc;
 
-use bio::io::fastx;
-use failure::Error;
 use log::info;
-use ocf::{get_input, get_output, CompressionFormat};
-use pdatastructs::hyperloglog::HyperLogLog;
+use needletail::parse_sequence_path;
 
+use crate::index::{Comparable, Index, MHBT};
 use crate::index::linear::LinearIndex;
 use crate::index::storage::{FSStorage, Storage};
-use crate::index::{Comparable, Dataset, Index, UKHSTree, MHBT};
 use crate::signature::{Signature, SigsTrait};
-use crate::sketch::ukhs::{FlatUKHS, UKHSTrait, UniqueUKHS};
 use crate::sketch::Sketch;
+use crate::index::{UKHSTree};
+use crate::sketch::ukhs::{FlatUKHS, UKHSTrait, UniqueUKHS};
 
 pub fn draff_index(sig_files: Vec<&str>, outfile: &str) -> Result<(), Error> {
     let storage: Rc<dyn Storage> = Rc::new(
-        FSStorage::new(".".into(), ".draff".into()), // TODO: use outfile
+        FSStorage::new(".", ".draff"), // TODO: use outfile
     );
 
     //let mut index = UKHSTree::builder().storage(Rc::clone(&storage)).build();
-    let mut index = LinearIndex::<Dataset<Signature>>::builder()
+    let mut index = LinearIndex::<Signature>::builder()
         .storage(Rc::clone(&storage))
         .build();
 
@@ -40,9 +43,7 @@ pub fn draff_index(sig_files: Vec<&str>, outfile: &str) -> Result<(), Error> {
             .signatures(vec![Sketch::UKHS(ukhs_sig)])
             .build();
 
-        let dataset = sig.into();
-
-        index.insert(&dataset)?;
+        index.insert(sig)?;
     }
 
     // TODO: implement to_writer and use this?
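// A minimal sketch of the reworked `Index` trait seen above: `insert` now
// takes the signature by value and is fallible, and `signatures()` replaces
// the old `datasets()`. Illustrative only, using the calling conventions from
// this patch ("query.sig" is a made-up filename):
//
//     let mut index = LinearIndex::<Signature>::builder().build();
//     let mut sigs = Signature::from_path("query.sig")?;
//     let query = sigs.pop().unwrap();
//     index.insert(query.clone())?;
//     let matches = index.search(&query, 0.9, false)?;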
@@ -85,26 +86,13 @@ pub fn draff_search(index: &str, query: &str) -> Result<(), Error> {
         .signatures(vec![Sketch::UKHS(ukhs_sig)])
         .build();
 
-    let dataset = sig.into();
-
-    for found in index.search(&dataset, 0.9, false)? {
-        println!("{:.2}: {:?}", dataset.similarity(found), found);
+    for found in index.search(&sig, 0.9, false)? {
+        println!("{:.2}: {:?}", sig.similarity(found), found);
     }
 
     Ok(())
 }
 
-pub fn prepare(index_path: &str) -> Result<(), Error> {
-    let mut index = MHBT::from_path(index_path)?;
-
-    // TODO equivalent to fill_internal in python
-    //unimplemented!();
-
-    index.save_file(index_path, None)?;
-
-    Ok(())
-}
-
 pub fn draff_signature(files: Vec<&str>, k: usize, w: usize) -> Result<(), Error> {
     for filename in files {
         // TODO: check for stdin?
@@ -113,62 +101,68 @@ pub fn draff_signature(files: Vec<&str>, k: usize, w: usize) -> Result<(), Error
 
         info!("Build signature for {} with W={}, K={}...", filename, w, k);
 
-        let (input, _) = get_input(filename)?;
-        let reader = fastx::Reader::new(input);
-
-        for record in reader.records() {
-            let record = record?;
-
-            // if there is anything other than ACGT in sequence,
-            // it is replaced with A.
-            // This matches khmer and screed behavior
-            //
-            // NOTE: sourmash is different! It uses the force flag to drop
-            // k-mers that are not ACGT
-            let seq: Vec<u8> = record
-                .seq()
-                .iter()
-                .map(|&x| match x as char {
-                    'A' | 'C' | 'G' | 'T' => x,
-                    'a' | 'c' | 'g' | 't' => x.to_ascii_uppercase(),
-                    _ => 'A' as u8,
-                })
-                .collect();
-
-            ukhs.add_sequence(&seq, false)?;
-        }
+        parse_sequence_path(
+            filename,
+            |_| {},
+            |record| {
+                // if there is anything other than ACGT in sequence,
+                // it is replaced with A.
+                // This matches khmer and screed behavior
+                //
+                // NOTE: sourmash is different! It uses the force flag to drop
+                // k-mers that are not ACGT
+                let seq: Vec<u8> = record
+                    .seq
+                    .iter()
+                    .map(|&x| match x as char {
+                        'A' | 'C' | 'G' | 'T' => x,
+                        'a' | 'c' | 'g' | 't' => x.to_ascii_uppercase(),
+                        _ => b'A',
+                    })
+                    .collect();
+
+                ukhs.add_sequence(&seq, false)
+                    .expect("Error adding sequence");
+            },
+        )?;
 
         let mut outfile = PathBuf::from(filename);
         outfile.set_extension("sig");
 
+        /*
         let mut output = get_output(outfile.to_str().unwrap(), CompressionFormat::No)?;
 
         let flat: FlatUKHS = ukhs.into();
         flat.to_writer(&mut output)?
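        (Annotation: `get_output` above now comes from the external `niffler`
        crate; the in-tree `ocf` crate it replaces is deleted earlier in this
        patch. Per the imports in src/bin/smrs.rs the call shape is unchanged:
        `get_output(path, CompressionFormat::No)` hands back a boxed writer,
        with compression chosen by the format argument.)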
+        */
     }
 
     info!("Done.");
 
     Ok(())
 }
+*/
 
+/* FIXME bring back after succint-rs changes
 pub fn count_unique(index_path: &str) -> Result<(), Error> {
     let index = MHBT::from_path(index_path)?;
 
     info!("Loaded index: {}", index_path);
 
-    let mut hll = HyperLogLog::new(16);
+    let mut hll = pdatastructs::hyperloglog::HyperLogLog::new(16);
 
     let mut total_hashes = 0u64;
 
-    for (n, dataset) in index.datasets().iter().enumerate() {
+    for (n, sig) in index.signatures().iter().enumerate() {
         if n % 1000 == 0 {
-            info!("Processed {} datasets", n);
+            info!("Processed {} signatures", n);
             info!("Unique hashes in {}: {}", index_path, hll.count());
             info!("Total hashes in {}: {}", index_path, total_hashes);
         };
 
-        for hash in dataset.mins() {
-            hll.add(&hash);
-            total_hashes += 1;
+        if let Sketch::MinHash(mh) = &sig.signatures[0] {
+            for hash in mh.mins() {
+                hll.add(&hash);
+                total_hashes += 1;
+            }
         }
     }
 
@@ -177,3 +171,15 @@ pub fn count_unique(index_path: &str) -> Result<(), Error> {
 
     Ok(())
 }
+*/
+
+pub fn prepare(index_path: &str) -> Result<(), Error> {
+    let mut index = MHBT::from_path(index_path)?;
+
+    // TODO equivalent to fill_internal in python
+    //unimplemented!();
+
+    index.save_file(index_path, None)?;
+
+    Ok(())
+}
diff --git a/src/errors.rs b/src/errors.rs
index 23c95de049..a539deb895 100644
--- a/src/errors.rs
+++ b/src/errors.rs
@@ -1,4 +1,4 @@
-use failure::{Error, Fail};
+use failure::Fail;
 
 #[derive(Debug, Fail)]
 pub enum SourmashError {
@@ -21,8 +21,8 @@ pub enum SourmashError {
     #[fail(display = "different signatures cannot be compared")]
     MismatchSignatureType,
 
-    #[fail(display = "Can only set track_abundance=True if the MinHash is empty")]
-    NonEmptyMinHash,
+    #[fail(display = "Can only set {} if the MinHash is empty", message)]
+    NonEmptyMinHash { message: String },
 
     #[fail(display = "invalid DNA character in input k-mer: {}", message)]
     InvalidDNA { message: String },
@@ -64,10 +64,11 @@ pub enum SourmashErrorCode {
     SerdeError = 100_004,
 }
 
+#[cfg(not(all(target_arch = "wasm32", target_vendor = "unknown")))]
 impl SourmashErrorCode {
-    pub fn from_error(error: &Error) -> SourmashErrorCode {
+    pub fn from_error(error: &failure::Error) -> SourmashErrorCode {
         for cause in error.iter_chain() {
-            use crate::utils::Panic;
+            use crate::ffi::utils::Panic;
             if cause.downcast_ref::<Panic>().is_some() {
                 return SourmashErrorCode::Panic;
             }
@@ -82,7 +83,7 @@ impl SourmashErrorCode {
                 SourmashError::MismatchSignatureType => {
                     SourmashErrorCode::MismatchSignatureType
                 }
-                SourmashError::NonEmptyMinHash => SourmashErrorCode::NonEmptyMinHash,
+                SourmashError::NonEmptyMinHash { .. } => SourmashErrorCode::NonEmptyMinHash,
                 SourmashError::InvalidDNA { .. } => SourmashErrorCode::InvalidDNA,
                 SourmashError::InvalidProt { .. } => SourmashErrorCode::InvalidProt,
                 SourmashError::InvalidCodonLength { .. } => {
diff --git a/src/ffi/minhash.rs b/src/ffi/minhash.rs
index 68aa5264db..2f84fab499 100644
--- a/src/ffi/minhash.rs
+++ b/src/ffi/minhash.rs
@@ -1,12 +1,13 @@
 use std::ffi::CStr;
-use std::mem;
 use std::os::raw::c_char;
 use std::ptr;
 use std::slice;
 
 use crate::errors::SourmashError;
 use crate::signature::SigsTrait;
-use crate::sketch::minhash::{aa_to_dayhoff, translate_codon, KmerMinHash};
+use crate::sketch::minhash::{
+    aa_to_dayhoff, aa_to_hp, translate_codon, HashFunctions, KmerMinHash,
+};
 
 #[no_mangle]
 pub unsafe extern "C" fn kmerminhash_new(
@@ -14,19 +15,31 @@ pub unsafe extern "C" fn kmerminhash_new(
     k: u32,
     prot: bool,
     dayhoff: bool,
+    hp: bool,
     seed: u64,
     mx: u64,
     track_abundance: bool,
 ) -> *mut KmerMinHash {
-    mem::transmute(Box::new(KmerMinHash::new(
+    // TODO: at most one of (prot, dayhoff, hp) should be true
+
+    let hash_function = if dayhoff {
+        HashFunctions::murmur64_dayhoff
+    } else if hp {
+        HashFunctions::murmur64_hp
+    } else if prot {
+        HashFunctions::murmur64_protein
+    } else {
+        HashFunctions::murmur64_DNA
+    };
+
+    Box::into_raw(Box::new(KmerMinHash::new(
         n,
         k,
-        prot,
-        dayhoff,
+        hash_function,
         seed,
         mx,
         track_abundance,
-    )))
+    ))) as _
 }
 
 #[no_mangle]
@@ -97,8 +110,13 @@ pub unsafe extern "C" fn sourmash_aa_to_dayhoff(aa: c_char) -> c_char {
 }
 
 #[no_mangle]
-pub extern "C" fn kmerminhash_remove_hash(ptr: *mut KmerMinHash, h: u64) {
-    let mh = unsafe {
+pub unsafe extern "C" fn sourmash_aa_to_hp(aa: c_char) -> c_char {
+    aa_to_hp(aa as u8) as c_char
+}
+
+#[no_mangle]
+pub unsafe extern "C" fn kmerminhash_remove_hash(ptr: *mut KmerMinHash, h: u64) {
+    let mh = {
         assert!(!ptr.is_null());
         &mut *ptr
     };
@@ -107,17 +125,17 @@ pub extern "C" fn kmerminhash_remove_hash(ptr: *mut KmerMinHash, h: u64) {
 }
 
 #[no_mangle]
-pub extern "C" fn kmerminhash_remove_many(
+pub unsafe extern "C" fn kmerminhash_remove_many(
     ptr: *mut KmerMinHash,
     hashes_ptr: *const u64,
     insize: usize,
 ) {
-    let mh = unsafe {
+    let mh = {
         assert!(!ptr.is_null());
         &mut *ptr
     };
 
-    let hashes = unsafe {
+    let hashes = {
         assert!(!hashes_ptr.is_null());
         slice::from_raw_parts(hashes_ptr as *mut u64, insize)
     };
@@ -236,6 +254,15 @@ pub unsafe extern "C" fn kmerminhash_dayhoff(ptr: *mut KmerMinHash) -> bool {
     mh.dayhoff()
 }
 
+#[no_mangle]
+pub unsafe extern "C" fn kmerminhash_hp(ptr: *mut KmerMinHash) -> bool {
+    let mh = {
+        assert!(!ptr.is_null());
+        &mut *ptr
+    };
+    mh.hp()
+}
+
 #[no_mangle]
 pub unsafe extern "C" fn kmerminhash_seed(ptr: *mut KmerMinHash) -> u64 {
     let mh = {
@@ -270,8 +297,8 @@ unsafe fn kmerminhash_enable_abundance(ptr: *mut KmerMinHash) -> Result<()> {
         &mut *ptr
     };
 
-    if mh.mins.len() != 0 {
-        return Err(SourmashError::NonEmptyMinHash.into());
+    if mh.mins.is_empty() {
+        return Err(SourmashError::NonEmptyMinHash { message: "track_abundance=True".into()}.into());
     }
 
     mh.abunds = Some(vec![]);
@@ -306,6 +333,31 @@ pub unsafe extern "C" fn kmerminhash_max_hash(ptr: *mut KmerMinHash) -> u64 {
     mh.max_hash()
 }
 
+#[no_mangle]
+pub unsafe extern "C" fn kmerminhash_hash_function(ptr: *mut KmerMinHash) -> HashFunctions {
+    let mh = {
+        assert!(!ptr.is_null());
+        &mut *ptr
+    };
+    mh.hash_function()
+}
+
+ffi_fn! {
+unsafe fn kmerminhash_hash_function_set(ptr: *mut KmerMinHash, hash_function: HashFunctions) -> Result<()> {
+    let mh = {
+        assert!(!ptr.is_null());
+        &mut *ptr
+    };
+
+    if !mh.mins.is_empty() {
+        return Err(SourmashError::NonEmptyMinHash { message: "hash_function".into()}.into());
+    }
+
+    mh.hash_function = hash_function;
+    Ok(())
+}
+}
+
 ffi_fn! {
 unsafe fn kmerminhash_merge(ptr: *mut KmerMinHash, other: *const KmerMinHash) -> Result<()> {
     let mh = {
@@ -366,13 +418,29 @@ unsafe fn kmerminhash_intersection(ptr: *mut KmerMinHash, other: *const KmerMinH
         &*other
     };
 
-    if let Ok((_, size)) = mh.intersection(other_mh) {
+    if let Ok((_, size)) = mh.intersection_size(other_mh) {
         return Ok(size);
     }
     Ok(0)
 }
 }
 
+ffi_fn! {
+unsafe fn kmerminhash_containment_ignore_maxhash(ptr: *mut KmerMinHash, other: *const KmerMinHash)
+    -> Result<f64> {
+    let mh = {
+        assert!(!ptr.is_null());
+        &mut *ptr
+    };
+    let other_mh = {
+        assert!(!other.is_null());
+        &*other
+    };
+
+    mh.containment_ignore_maxhash(&other_mh)
+}
+}
+
 ffi_fn! {
 unsafe fn kmerminhash_compare(ptr: *mut KmerMinHash, other: *const KmerMinHash) -> Result<f64> {
diff --git a/src/ffi/mod.rs b/src/ffi/mod.rs
index 9c1d5a982e..ec84d16e5a 100644
--- a/src/ffi/mod.rs
+++ b/src/ffi/mod.rs
@@ -1,8 +1,13 @@
 //! # Foreign Function Interface for calling sourmash from a C API
 //!
 //! Primary client for now is the Python version, using CFFI and milksnake.
+#![allow(clippy::missing_safety_doc)]
+
+#[macro_use]
+pub mod utils;
 
 pub mod minhash;
+pub mod nodegraph;
 pub mod signature;
 
 use std::ffi::CStr;
diff --git a/src/ffi/nodegraph.rs b/src/ffi/nodegraph.rs
new file mode 100644
index 0000000000..734ab26286
--- /dev/null
+++ b/src/ffi/nodegraph.rs
@@ -0,0 +1,181 @@
+use std::ffi::CStr;
+use std::os::raw::c_char;
+use std::slice;
+
+use niffler::get_input;
+
+use crate::sketch::minhash::KmerMinHash;
+use crate::sketch::nodegraph::Nodegraph;
+
+#[no_mangle]
+pub unsafe extern "C" fn nodegraph_new() -> *mut Nodegraph {
+    Box::into_raw(Box::new(Nodegraph::default())) as _
+}
+
+#[no_mangle]
+pub unsafe extern "C" fn nodegraph_free(ptr: *mut Nodegraph) {
+    if ptr.is_null() {
+        return;
+    }
+    Box::from_raw(ptr);
+}
+
+#[no_mangle]
+pub unsafe extern "C" fn nodegraph_with_tables(
+    ksize: usize,
+    starting_size: usize,
+    n_tables: usize,
+) -> *mut Nodegraph {
+    Box::into_raw(Box::new(Nodegraph::with_tables(
+        starting_size,
+        n_tables,
+        ksize,
+    ))) as _
+}
+
+#[no_mangle]
+pub unsafe extern "C" fn nodegraph_count(ptr: *mut Nodegraph, h: u64) -> bool {
+    let ng = {
+        assert!(!ptr.is_null());
+        &mut *ptr
+    };
+
+    ng.count(h)
+}
+
+#[no_mangle]
+pub unsafe extern "C" fn nodegraph_get(ptr: *mut Nodegraph, h: u64) -> usize {
+    let ng = {
+        assert!(!ptr.is_null());
+        &mut *ptr
+    };
+
+    ng.get(h)
+}
+
+#[no_mangle]
+pub unsafe extern "C" fn nodegraph_expected_collisions(ptr: *mut Nodegraph) -> f64 {
+    let ng = {
+        assert!(!ptr.is_null());
+        &mut *ptr
+    };
+
+    ng.expected_collisions()
+}
+
+#[no_mangle]
+pub unsafe extern "C" fn nodegraph_ksize(ptr: *mut Nodegraph) -> usize {
+    let ng = {
+        assert!(!ptr.is_null());
+        &mut *ptr
+    };
+
+    ng.ksize()
+}
+
+#[no_mangle]
+pub unsafe extern "C" fn nodegraph_tablesize(ptr: *mut Nodegraph) -> usize {
+    let ng = {
+        assert!(!ptr.is_null());
+        &mut *ptr
+    };
+
+    ng.tablesize()
+}
+
+#[no_mangle]
+pub unsafe extern "C" fn nodegraph_ntables(ptr: *mut Nodegraph) -> usize {
+    let ng = {
+        assert!(!ptr.is_null());
+        &mut *ptr
+    };
+
+    ng.ntables()
+}
+
+#[no_mangle]
+pub unsafe extern "C" fn nodegraph_noccupied(ptr: *mut Nodegraph) -> usize {
+    let ng = {
+        assert!(!ptr.is_null());
+        &mut *ptr
+    };
+
+    ng.noccupied()
+}
+
+#[no_mangle]
+pub unsafe extern "C" fn nodegraph_matches(ptr: *mut Nodegraph, mh_ptr: *mut KmerMinHash) -> usize {
+    let ng = {
+        assert!(!ptr.is_null());
+        &mut *ptr
+    };
+
+    let mh = {
+        assert!(!ptr.is_null());
+        &mut *mh_ptr
+    };
+
+    ng.matches(mh)
+}
+
+#[no_mangle]
+pub unsafe extern "C" fn nodegraph_update(ptr: *mut Nodegraph, optr: *mut Nodegraph) {
+    let ng = {
+        assert!(!ptr.is_null());
+        &mut *ptr
+    };
+
+    let ong = {
+        assert!(!optr.is_null());
+        &mut *optr
+    };
+
+    ng.update(ong);
+}
+
+ffi_fn! {
+unsafe fn nodegraph_from_path(filename: *const c_char) -> Result<*mut Nodegraph> {
+    let c_str = {
+        assert!(!filename.is_null());
+
+        CStr::from_ptr(filename)
+    };
+
+    let (mut input, _) = get_input(c_str.to_str()?)?;
+    let ng = Nodegraph::from_reader(&mut input)?;
+
+    Ok(Box::into_raw(Box::new(ng)))
+}
+}
+
+ffi_fn! {
+unsafe fn nodegraph_from_buffer(ptr: *const c_char, insize: usize) -> Result<*mut Nodegraph> {
+    let buf = {
+        assert!(!ptr.is_null());
+        slice::from_raw_parts(ptr as *mut u8, insize)
+    };
+
+    let ng = Nodegraph::from_reader(&mut &buf[..])?;
+
+    Ok(Box::into_raw(Box::new(ng)))
+}
+}
+
+ffi_fn! {
+unsafe fn nodegraph_save(ptr: *mut Nodegraph, filename: *const c_char) -> Result<()> {
+    let ng = {
+        assert!(!ptr.is_null());
+        &mut *ptr
+    };
+
+    let c_str = {
+        assert!(!filename.is_null());
+
+        CStr::from_ptr(filename)
+    };
+
+    ng.save(c_str.to_str()?)?;
+
+    Ok(())
+}
+}
diff --git a/src/ffi/signature.rs b/src/ffi/signature.rs
index 093ff95d0c..56e147ebe3 100644
--- a/src/ffi/signature.rs
+++ b/src/ffi/signature.rs
@@ -3,13 +3,13 @@ use std::io;
 use std::os::raw::c_char;
 use std::slice;
 
-use ocf::get_input;
+use niffler::get_input;
 use serde_json;
 
+use crate::ffi::utils::SourmashStr;
 use crate::signature::Signature;
 use crate::sketch::minhash::KmerMinHash;
 use crate::sketch::Sketch;
-use crate::utils::SourmashStr;
 
 // Signature methods
 
@@ -243,8 +243,13 @@ unsafe fn signatures_load_path(ptr: *const c_char,
 
     // TODO: implement ignore_md5sum
 
+    let k = match ksize {
+        0 => None,
+        x => Some(x)
+    };
+
     let (mut input, _) = get_input(buf.to_str()?)?;
-    let filtered_sigs = Signature::load_signatures(&mut input, ksize, moltype, None)?;
+    let filtered_sigs = Signature::load_signatures(&mut input, k, moltype, None)?;
 
     let ptr_sigs: Vec<*mut Signature> = filtered_sigs.into_iter().map(|x| {
         Box::into_raw(Box::new(x)) as *mut Signature
@@ -277,10 +282,15 @@ unsafe fn signatures_load_buffer(ptr: *const c_char,
         }
     };
 
+    let k = match ksize {
+        0 => None,
+        x => Some(x)
+    };
+
     // TODO: implement ignore_md5sum
 
     let mut reader = io::BufReader::new(buf);
-    let filtered_sigs = Signature::load_signatures(&mut reader, ksize, moltype, None)?;
+    let filtered_sigs = Signature::load_signatures(&mut reader, k, moltype, None)?;
 
     let ptr_sigs: Vec<*mut Signature> = filtered_sigs.into_iter().map(|x| {
         Box::into_raw(Box::new(x)) as *mut Signature
diff --git a/src/utils.rs b/src/ffi/utils.rs
similarity index 97%
rename from src/utils.rs
rename to src/ffi/utils.rs
index 999dd21e4c..68c883c77e 100644
--- a/src/utils.rs
+++ b/src/ffi/utils.rs
@@ -25,7 +25,7 @@ macro_rules! ffi_fn (
 
         $(#[$attr])*
         pub unsafe extern "C" fn $name($($aname: $aty,)*) -> $rv {
-            $crate::utils::landingpad(|| $body)
+            $crate::ffi::utils::landingpad(|| $body)
         }
     );
 
@@ -39,7 +39,7 @@ macro_rules! ffi_fn (
         pub unsafe extern "C" fn $name($($aname: $aty,)*) {
             // this silences panics and stuff
-            $crate::utils::landingpad(|| { $body; Ok(0 as ::std::os::raw::c_int) });
+            $crate::ffi::utils::landingpad(|| { $body; Ok(0 as ::std::os::raw::c_int) });
         }
     }
 );
diff --git a/src/from.rs b/src/from.rs
index e0846a11d3..67821927a3 100644
--- a/src/from.rs
+++ b/src/from.rs
@@ -1,15 +1,22 @@
-use finch::minhashes::MinHashKmers;
+use finch::sketch_schemes::mash::MashSketcher;
+use finch::sketch_schemes::SketchScheme;
 
-use crate::signatures::minhash::KmerMinHash;
+use crate::sketch::minhash::{HashFunctions, KmerMinHash};
 
-impl From<MinHashKmers> for KmerMinHash {
-    fn from(other: MinHashKmers) -> KmerMinHash {
-        let values = other.into_vec();
+/*
+ TODO:
+  - also convert scaled sketches
+  - sourmash Signature equivalent is the finch Sketch, write conversions for that too
+*/
+
+impl From<MashSketcher> for KmerMinHash {
+    fn from(other: MashSketcher) -> KmerMinHash {
+        let values = other.to_vec();
 
         let mut new_mh = KmerMinHash::new(
             values.len() as u32,
             values.get(0).unwrap().kmer.len() as u32,
-            false,
+            HashFunctions::murmur64_DNA,
             42,
             0,
             true,
@@ -20,7 +27,9 @@ impl From<MashSketcher> for KmerMinHash {
             .map(|x| (x.hash as u64, x.count as u64))
             .collect();
 
-        new_mh.add_many_with_abund(&hash_with_abunds);
+        new_mh
+            .add_many_with_abund(&hash_with_abunds)
+            .expect("Error adding hashes with abund");
 
         new_mh
     }
@@ -32,27 +41,30 @@ mod test {
     use std::collections::HashSet;
     use std::iter::FromIterator;
 
-    use crate::signatures::minhash::KmerMinHash;
+    use crate::signature::SigsTrait;
+    use crate::sketch::minhash::{HashFunctions, KmerMinHash};
 
-    use finch::minhashes::MinHashKmers;
-    use needletail::kmer::canonical;
+    use finch::sketch_schemes::mash::MashSketcher;
+    use needletail::kmer::CanonicalKmers;
+    use needletail::Sequence;
 
     use super::*;
 
     #[test]
     fn finch_behavior() {
-        let mut a = KmerMinHash::new(20, 10, false, 42, 0, true);
-        let mut b = MinHashKmers::new(20, 42);
+        let mut a = KmerMinHash::new(20, 10, HashFunctions::murmur64_DNA, 42, 0, true);
+        let mut b = MashSketcher::new(20, 10, 42);
 
         let seq = b"TGCCGCCCAGCACCGGGTGACTAGGTTGAGCCATGATTAACCTGCAATGA";
+        let rc = seq.reverse_complement();
 
-        a.add_sequence(seq, false);
+        a.add_sequence(seq, false).unwrap();
 
-        for kmer in seq.windows(10) {
-            b.push(&canonical(kmer), 0);
+        for (_, kmer, _) in CanonicalKmers::new(seq, &rc, 10) {
+            b.push(&kmer, 0);
         }
 
-        let b_hashes = b.into_vec();
+        let b_hashes = b.to_vec();
 
         let s1: HashSet<_> = HashSet::from_iter(a.mins.iter().map(|x| *x));
         let s2: HashSet<_> = HashSet::from_iter(b_hashes.iter().map(|x| x.hash as u64));
@@ -76,15 +88,16 @@ mod test {
 
     #[test]
     fn from_finch() {
-        let mut a = KmerMinHash::new(20, 10, false, 42, 0, true);
-        let mut b = MinHashKmers::new(20, 42);
+        let mut a = KmerMinHash::new(20, 10, HashFunctions::murmur64_DNA, 42, 0, true);
+        let mut b = MashSketcher::new(20, 10, 42);
 
         let seq = b"TGCCGCCCAGCACCGGGTGACTAGGTTGAGCCATGATTAACCTGCAATGA";
+        let rc = seq.reverse_complement();
 
-        a.add_sequence(seq, false);
+        a.add_sequence(seq, false).unwrap();
 
-        for kmer in seq.windows(10) {
-            b.push(&canonical(kmer), 0);
+        for (_, kmer, _) in CanonicalKmers::new(seq, &rc, 10) {
+            b.push(&kmer, 0);
         }
 
         let c = KmerMinHash::from(b);
diff --git a/src/index/bigsi.rs b/src/index/bigsi.rs
index 07f67304d4..6fac529a05 100644
--- a/src/index/bigsi.rs
+++ b/src/index/bigsi.rs
@@ -5,7 +5,7 @@ use failure::{Error, Fail};
 use fixedbitset::FixedBitSet;
 use typed_builder::TypedBuilder;
 
-use crate::index::{Comparable, Index};
+use crate::index::Index;
crate::signature::{Signature, SigsTrait}; use crate::sketch::nodegraph::Nodegraph; use crate::sketch::Sketch; @@ -79,22 +79,9 @@ impl BIGSI { } } -impl Index for BIGSI { +impl<'a> Index<'a> for BIGSI { type Item = Signature; - - fn find( - &self, - _search_fn: F, - _sig: &Self::Item, - _threshold: f64, - ) -> Result, Error> - where - F: Fn(&dyn Comparable, &Self::Item, f64) -> bool, - { - // TODO: is there a better way than making this a runtime check? - //Err(BIGSIError::MethodDisabled.into()) - unimplemented!(); - } + //type SignatureIterator = std::slice::Iter<'a, Self::Item>; fn search( &self, @@ -109,12 +96,10 @@ impl Index for BIGSI { let mut counter: HashMap = HashMap::with_capacity(hashes.size()); for hash in &hashes.mins { - self.query(*hash) - .map(|dataset_idx| { - let idx = counter.entry(dataset_idx).or_insert(0); - *idx += 1; - }) - .count(); + self.query(*hash).for_each(|dataset_idx| { + let idx = counter.entry(dataset_idx).or_insert(0); + *idx += 1; + }); } for (idx, count) in counter { @@ -140,8 +125,8 @@ impl Index for BIGSI { } } - fn insert(&mut self, node: &Self::Item) -> Result<(), Error> { - self.add(node.clone()); + fn insert(&mut self, node: Self::Item) -> Result<(), Error> { + self.add(node); Ok(()) } @@ -153,9 +138,19 @@ impl Index for BIGSI { unimplemented!() } - fn datasets(&self) -> Vec { + fn signatures(&self) -> Vec { + unimplemented!() + } + + fn signature_refs(&self) -> Vec<&Self::Item> { unimplemented!() } + + /* + fn iter_signatures(&'a self) -> Self::SignatureIterator { + self.datasets.iter() + } + */ } #[cfg(test)] @@ -163,14 +158,10 @@ mod test { use std::fs::File; use std::io::BufReader; use std::path::PathBuf; - use std::rc::Rc; - - use lazy_init::Lazy; use super::BIGSI; - use crate::index::storage::ReadData; - use crate::index::Dataset; + use crate::index::SigStore; use crate::index::{Index, MHBT}; use crate::signature::Signature; @@ -182,29 +173,20 @@ mod test { let sbt = MHBT::from_path(filename).expect("Loading error"); let mut bigsi = BIGSI::new(10000, 10); - let datasets = sbt.datasets(); + let datasets = sbt.signatures(); let mut filename = PathBuf::from(env!("CARGO_MANIFEST_DIR")); filename.push("tests/test-data/.sbt.v3/60f7e23c24a8d94791cc7a8680c493f9"); let mut reader = BufReader::new(File::open(filename).unwrap()); - let sigs = Signature::load_signatures(&mut reader, 31, Some("DNA".into()), None).unwrap(); + let sigs = + Signature::load_signatures(&mut reader, Some(31), Some("DNA".into()), None).unwrap(); let sig_data = sigs[0].clone(); - let data = Lazy::new(); - data.get_or_create(|| sig_data); - - let leaf = Dataset::builder() - .data(Rc::new(data)) - .filename("") - .name("") - .metadata("") - .storage(None) - .build(); + let leaf: SigStore<_> = sig_data.into(); - for l in &datasets { - let data = l.data().unwrap(); - bigsi.insert(data).expect("insertion error!"); + for l in datasets { + bigsi.insert(l).expect("insertion error!"); } let results_sbt = sbt.search(&leaf, 0.5, false).unwrap(); diff --git a/src/index/linear.rs b/src/index/linear.rs index 6def3bd82b..76ce5b6918 100644 --- a/src/index/linear.rs +++ b/src/index/linear.rs @@ -11,15 +11,18 @@ use serde_derive::{Deserialize, Serialize}; use typed_builder::TypedBuilder; use crate::index::storage::{FSStorage, ReadData, Storage, StorageInfo, ToWriter}; -use crate::index::{Comparable, Dataset, DatasetInfo, Index}; +use crate::index::{Comparable, DatasetInfo, Index, SigStore}; #[derive(TypedBuilder)] -pub struct LinearIndex { +pub struct LinearIndex +where + L: Sync, +{ 
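    // A LinearIndex keeps its signatures as a flat list (plus optional
    // backing storage) and answers queries by scanning every entry,
    // trading query speed for simplicity compared to the tree-structured SBT.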
#[builder(default)] storage: Option>, #[builder(default)] - pub(crate) datasets: Vec, + pub(crate) datasets: Vec>, } #[derive(Serialize, Deserialize)] @@ -29,36 +32,16 @@ struct LinearInfo { leaves: Vec, } -impl Index for LinearIndex +impl<'a, L> Index<'a> for LinearIndex where - L: Clone + Comparable, + L: Sync + Clone + Comparable + 'a, + SigStore: From, { type Item = L; + //type SignatureIterator = std::slice::Iter<'a, Self::Item>; - fn find( - &self, - search_fn: F, - sig: &Self::Item, - threshold: f64, - ) -> Result, Error> - where - F: Fn(&dyn Comparable, &Self::Item, f64) -> bool, - { - Ok(self - .datasets - .iter() - .flat_map(|node| { - if search_fn(node, sig, threshold) { - Some(node) - } else { - None - } - }) - .collect()) - } - - fn insert(&mut self, node: &L) -> Result<(), Error> { - self.datasets.push(node.clone()); + fn insert(&mut self, node: L) -> Result<(), Error> { + self.datasets.push(node.into()); Ok(()) } @@ -77,15 +60,31 @@ where unimplemented!() } - fn datasets(&self) -> Vec { - self.datasets.to_vec() + fn signatures(&self) -> Vec { + self.datasets + .iter() + .map(|x| x.data.get().unwrap().clone()) + .collect() } + + fn signature_refs(&self) -> Vec<&Self::Item> { + self.datasets + .iter() + .map(|x| x.data.get().unwrap()) + .collect() + } + + /* + fn iter_signatures(&'a self) -> Self::SignatureIterator { + self.datasets.iter() + } + */ } -impl LinearIndex> +impl LinearIndex where L: std::marker::Sync + ToWriter, - Dataset: ReadData, + SigStore: ReadData, { pub fn save_file>( &mut self, @@ -127,12 +126,12 @@ where mem::replace(&mut l.storage, Some(Rc::clone(&storage))); let filename = (*l).save(&l.filename).unwrap(); - let new_node = DatasetInfo { - filename: filename, + + DatasetInfo { + filename, name: l.name.clone(), metadata: l.metadata.clone(), - }; - new_node + } }) .collect(), }; @@ -143,7 +142,7 @@ where Ok(()) } - pub fn from_path>(path: P) -> Result>, Error> { + pub fn from_path>(path: P) -> Result, Error> { let file = File::open(&path)?; let mut reader = BufReader::new(file); @@ -153,12 +152,11 @@ where basepath.push(path); basepath.canonicalize()?; - let linear = - LinearIndex::>::from_reader(&mut reader, &basepath.parent().unwrap())?; + let linear = LinearIndex::::from_reader(&mut reader, &basepath.parent().unwrap())?; Ok(linear) } - pub fn from_reader(rdr: &mut R, path: P) -> Result>, Error> + pub fn from_reader(rdr: &mut R, path: P) -> Result, Error> where R: Read, P: AsRef, @@ -177,15 +175,12 @@ where datasets: linear .leaves .into_iter() - .map(|l| { - let new_node = Dataset { - filename: l.filename, - name: l.name, - metadata: l.metadata, - storage: Some(Rc::clone(&storage)), - data: Rc::new(Lazy::new()), - }; - new_node + .map(|l| SigStore { + filename: l.filename, + name: l.name, + metadata: l.metadata, + storage: Some(Rc::clone(&storage)), + data: Rc::new(Lazy::new()), }) .collect(), }) diff --git a/src/index/mod.rs b/src/index/mod.rs index adb07267ef..9f5ecdbe85 100644 --- a/src/index/mod.rs +++ b/src/index/mod.rs @@ -11,10 +11,10 @@ pub mod storage; pub mod search; +use std::ops::Deref; use std::path::Path; use std::rc::Rc; -use cfg_if::cfg_if; use failure::Error; use lazy_init::Lazy; use serde_derive::{Deserialize, Serialize}; @@ -25,21 +25,28 @@ use crate::index::search::{search_minhashes, search_minhashes_containment}; use crate::index::storage::{ReadData, ReadDataError, Storage}; use crate::signature::Signature; use crate::sketch::nodegraph::Nodegraph; -use crate::sketch::ukhs::{FlatUKHS, UKHSTrait}; use crate::sketch::Sketch; -pub type 
MHBT = SBT, Dataset>; -pub type UKHSTree = SBT, Dataset>; +/* FIXME: bring back after boomphf changes +use crate::sketch::ukhs::{FlatUKHS, UKHSTrait}; +pub type UKHSTree = SBT, Signature>; +*/ + +pub type MHBT = SBT, Signature>; +/* FIXME: bring back after MQF works on macOS and Windows +use cfg_if::cfg_if; cfg_if! { if #[cfg(not(target_arch = "wasm32"))] { use mqf::MQF; - pub type MHMT = SBT, Dataset>; + pub type MHMT = SBT, Signature>; } } +*/ -pub trait Index { - type Item; +pub trait Index<'a> { + type Item: Comparable; + //type SignatureIterator: Iterator; fn find( &self, @@ -48,7 +55,20 @@ pub trait Index { threshold: f64, ) -> Result, Error> where - F: Fn(&dyn Comparable, &Self::Item, f64) -> bool; + F: Fn(&dyn Comparable, &Self::Item, f64) -> bool, + { + Ok(self + .signature_refs() + .into_iter() + .flat_map(|node| { + if search_fn(&node, sig, threshold) { + Some(node) + } else { + None + } + }) + .collect()) + } fn search( &self, @@ -65,13 +85,27 @@ pub trait Index { //fn gather(&self, sig: &Self::Item, threshold: f64) -> Result, Error>; - fn insert(&mut self, node: &Self::Item) -> Result<(), Error>; + fn insert(&mut self, node: Self::Item) -> Result<(), Error>; + + fn batch_insert(&mut self, nodes: Vec) -> Result<(), Error> { + for node in nodes { + self.insert(node)?; + } + + Ok(()) + } fn save>(&self, path: P) -> Result<(), Error>; fn load>(path: P) -> Result<(), Error>; - fn datasets(&self) -> Vec; + fn signatures(&self) -> Vec; + + fn signature_refs(&self) -> Vec<&Self::Item>; + + /* + fn iter_signatures(&self) -> Self::SignatureIterator; + */ } // TODO: split into two traits, Similarity and Containment? @@ -101,7 +135,7 @@ pub struct DatasetInfo { } #[derive(TypedBuilder, Default, Clone)] -pub struct Dataset +pub struct SigStore where T: std::marker::Sync, { @@ -114,7 +148,7 @@ where pub(crate) data: Rc>, } -impl Dataset +impl SigStore where T: std::marker::Sync + Default, { @@ -123,20 +157,20 @@ where } } -impl std::fmt::Debug for Dataset +impl std::fmt::Debug for SigStore where T: std::marker::Sync, { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!( f, - "Dataset [filename: {}, name: {}, metadata: {}]", + "SigStore [filename: {}, name: {}, metadata: {}]", self.filename, self.name, self.metadata ) } } -impl ReadData for Dataset { +impl ReadData for SigStore { fn data(&self) -> Result<&Signature, Error> { if let Some(sig) = self.data.get() { Ok(sig) @@ -160,8 +194,8 @@ impl ReadData for Dataset { } } -impl Dataset { - pub fn count_common(&self, other: &Dataset) -> u64 { +impl SigStore { + pub fn count_common(&self, other: &SigStore) -> u64 { let ng: &Signature = self.data().unwrap(); let ong: &Signature = other.data().unwrap(); @@ -188,21 +222,29 @@ impl Dataset { } } -impl From> for Signature { - fn from(other: Dataset) -> Signature { +impl From> for Signature { + fn from(other: SigStore) -> Signature { other.data.get().unwrap().to_owned() } } -impl From for Dataset { - fn from(other: Signature) -> Dataset { +impl Deref for SigStore { + type Target = Signature; + + fn deref(&self) -> &Signature { + self.data.get().unwrap() + } +} + +impl From for SigStore { + fn from(other: Signature) -> SigStore { let name = other.name(); let filename = other.filename(); let data = Lazy::new(); data.get_or_create(|| other); - Dataset::builder() + SigStore::builder() .name(name) .filename(filename) .data(data) @@ -212,8 +254,8 @@ impl From for Dataset { } } -impl Comparable> for Dataset { - fn similarity(&self, other: &Dataset) -> f64 { +impl Comparable> for 
SigStore { + fn similarity(&self, other: &SigStore) -> f64 { let ng: &Signature = self.data().unwrap(); let ong: &Signature = other.data().unwrap(); @@ -225,16 +267,18 @@ impl Comparable> for Dataset { } } + /* FIXME: bring back after boomphf changes if let Sketch::UKHS(mh) = &ng.signatures[0] { if let Sketch::UKHS(omh) = &ong.signatures[0] { return 1. - mh.distance(&omh); } } + */ unimplemented!() } - fn containment(&self, other: &Dataset) -> f64 { + fn containment(&self, other: &SigStore) -> f64 { let ng: &Signature = self.data().unwrap(); let ong: &Signature = other.data().unwrap(); @@ -250,3 +294,38 @@ impl Comparable> for Dataset { unimplemented!() } } + +impl Comparable for Signature { + fn similarity(&self, other: &Signature) -> f64 { + // TODO: select the right signatures... + // TODO: better matching here, what if it is not a mh? + if let Sketch::MinHash(mh) = &self.signatures[0] { + if let Sketch::MinHash(omh) = &other.signatures[0] { + return mh.compare(&omh).unwrap(); + } + } + + /* FIXME: bring back after boomphf changes + if let Sketch::UKHS(mh) = &self.signatures[0] { + if let Sketch::UKHS(omh) = &other.signatures[0] { + return 1. - mh.distance(&omh); + } + } + */ + + unimplemented!() + } + + fn containment(&self, other: &Signature) -> f64 { + // TODO: select the right signatures... + // TODO: better matching here, what if it is not a mh? + if let Sketch::MinHash(mh) = &self.signatures[0] { + if let Sketch::MinHash(omh) = &other.signatures[0] { + let common = mh.count_common(&omh).unwrap(); + let size = mh.mins.len(); + return common as f64 / size as f64; + } + } + unimplemented!() + } +} diff --git a/src/index/sbt/mhbt.rs b/src/index/sbt/mhbt.rs index 46f710e521..b26975f6de 100644 --- a/src/index/sbt/mhbt.rs +++ b/src/index/sbt/mhbt.rs @@ -1,10 +1,13 @@ +use std::collections::HashMap; use std::io::Write; +use std::rc::Rc; use failure::Error; +use lazy_init::Lazy; -use crate::index::sbt::{FromFactory, Node, Update, SBT}; +use crate::index::sbt::{Factory, FromFactory, Node, Update, SBT}; use crate::index::storage::{ReadData, ReadDataError, ToWriter}; -use crate::index::{Comparable, Dataset}; +use crate::index::Comparable; use crate::signature::{Signature, SigsTrait}; use crate::sketch::nodegraph::Nodegraph; use crate::sketch::Sketch; @@ -18,9 +21,24 @@ impl ToWriter for Nodegraph { } } -impl FromFactory> for SBT, L> { - fn factory(&self, _name: &str) -> Result, Error> { - unimplemented!() +impl FromFactory> for SBT, L> { + fn factory(&self, name: &str) -> Result, Error> { + match self.factory { + Factory::GraphFactory { args: (k, t, n) } => { + let n = Nodegraph::with_tables(t as usize, n as usize, k as usize); + + let data = Lazy::new(); + data.get_or_create(|| n); + + Ok(Node::builder() + .filename(name) + .name(name) + .metadata(HashMap::default()) + .storage(self.storage()) + .data(Rc::new(data)) + .build()) + } + } } } @@ -30,9 +48,35 @@ impl Update> for Node { } } -impl Update> for Dataset { - fn update(&self, _other: &mut Node) -> Result<(), Error> { - unimplemented!(); +impl Update> for Signature { + fn update(&self, parent: &mut Node) -> Result<(), Error> { + // TODO: avoid copy here + let mut parent_data = parent.data()?.clone(); + + if let Sketch::MinHash(sig) = &self.signatures[0] { + sig.mins.iter().for_each(|h| { + parent_data.count(*h); + }); + + let min_n_below = parent + .metadata + .entry("min_n_below".into()) + .or_insert(u64::max_value()); + + *min_n_below = u64::min(sig.size() as u64, *min_n_below); + if *min_n_below == 0 { + *min_n_below = 1 + } + } else 
{ + //TODO what if it is not a minhash? + unimplemented!() + } + + let data = Lazy::new(); + data.get_or_create(|| parent_data); + parent.data = Rc::new(data); + + Ok(()) } } @@ -50,13 +94,12 @@ impl Comparable> for Node { } } -impl Comparable> for Node { - fn similarity(&self, other: &Dataset) -> f64 { +impl Comparable for Node { + fn similarity(&self, other: &Signature) -> f64 { let ng: &Nodegraph = self.data().unwrap(); - let oth: &Signature = other.data().unwrap(); // TODO: select the right signatures... - if let Sketch::MinHash(sig) = &oth.signatures[0] { + if let Sketch::MinHash(sig) = &other.signatures[0] { if sig.size() == 0 { return 0.0; } @@ -74,12 +117,11 @@ impl Comparable> for Node { } } - fn containment(&self, other: &Dataset) -> f64 { + fn containment(&self, other: &Signature) -> f64 { let ng: &Nodegraph = self.data().unwrap(); - let oth: &Signature = other.data().unwrap(); // TODO: select the right signatures... - if let Sketch::MinHash(sig) = &oth.signatures[0] { + if let Sketch::MinHash(sig) = &other.signatures[0] { if sig.size() == 0 { return 0.0; } @@ -108,3 +150,226 @@ impl ReadData for Node { } } } + +#[cfg(test)] +mod test { + use std::fs::File; + use std::io::{BufReader, Seek, SeekFrom}; + use std::path::PathBuf; + use std::rc::Rc; + use tempfile; + + use assert_matches::assert_matches; + use lazy_init::Lazy; + + use super::Factory; + + use crate::index::linear::LinearIndex; + use crate::index::sbt::scaffold; + use crate::index::search::{search_minhashes, search_minhashes_containment}; + use crate::index::storage::ReadData; + use crate::index::{Index, SigStore, MHBT}; + use crate::signature::Signature; + + #[test] + fn save_sbt() { + let mut filename = PathBuf::from(env!("CARGO_MANIFEST_DIR")); + filename.push("tests/test-data/v5.sbt.json"); + + let mut sbt = MHBT::from_path(filename).expect("Loading error"); + + let mut tmpfile = tempfile::NamedTempFile::new().unwrap(); + sbt.save_file(tmpfile.path(), None).unwrap(); + + tmpfile.seek(SeekFrom::Start(0)).unwrap(); + } + + #[test] + fn load_sbt() { + let mut filename = PathBuf::from(env!("CARGO_MANIFEST_DIR")); + filename.push("tests/test-data/v5.sbt.json"); + + let sbt = MHBT::from_path(filename).expect("Loading error"); + + assert_eq!(sbt.d, 2); + //assert_eq!(sbt.storage.backend, "FSStorage"); + //assert_eq!(sbt.storage.args["path"], ".sbt.v5"); + //assert_matches!(&sbt.storage, ::FSStorage(args) => { + // assert_eq!(args, &[1, 100000, 4]); + //}); + assert_matches!(&sbt.factory, Factory::GraphFactory { args } => { + assert_eq!(args, &(1, 100000.0, 4)); + }); + + println!("sbt leaves {:?} {:?}", sbt.leaves.len(), sbt.leaves); + + let mut filename = PathBuf::from(env!("CARGO_MANIFEST_DIR")); + filename.push("tests/test-data/.sbt.v3/60f7e23c24a8d94791cc7a8680c493f9"); + + let mut reader = BufReader::new(File::open(filename).unwrap()); + let sigs = + Signature::load_signatures(&mut reader, Some(31), Some("DNA".into()), None).unwrap(); + let sig_data = sigs[0].clone(); + + let data = Lazy::new(); + data.get_or_create(|| sig_data); + + let leaf = SigStore::builder() + .data(Rc::new(data)) + .filename("") + .name("") + .metadata("") + .storage(None) + .build(); + + let results = sbt.find(search_minhashes, &leaf, 0.5).unwrap(); + assert_eq!(results.len(), 1); + println!("results: {:?}", results); + println!("leaf: {:?}", leaf); + + let results = sbt.find(search_minhashes, &leaf, 0.1).unwrap(); + assert_eq!(results.len(), 2); + println!("results: {:?}", results); + println!("leaf: {:?}", leaf); + + let mut linear = 
LinearIndex::builder().storage(sbt.storage()).build(); + for l in &sbt.leaves { + linear.insert(l.1.data().unwrap().clone()).unwrap(); + } + + println!( + "linear leaves {:?} {:?}", + linear.datasets.len(), + linear.datasets + ); + + let results = linear.find(search_minhashes, &leaf, 0.5).unwrap(); + assert_eq!(results.len(), 1); + println!("results: {:?}", results); + println!("leaf: {:?}", leaf); + + let results = linear.find(search_minhashes, &leaf, 0.1).unwrap(); + assert_eq!(results.len(), 2); + println!("results: {:?}", results); + println!("leaf: {:?}", leaf); + + let results = linear + .find(search_minhashes_containment, &leaf, 0.5) + .unwrap(); + assert_eq!(results.len(), 2); + println!("results: {:?}", results); + println!("leaf: {:?}", leaf); + + let results = linear + .find(search_minhashes_containment, &leaf, 0.1) + .unwrap(); + assert_eq!(results.len(), 4); + println!("results: {:?}", results); + println!("leaf: {:?}", leaf); + } + + #[test] + #[ignore] + fn roundtrip_sbt() -> Result<(), Box> { + let mut filename = PathBuf::from(env!("CARGO_MANIFEST_DIR")); + filename.push("tests/test-data/v5.sbt.json"); + + let sbt = MHBT::from_path(filename)?; + + assert_eq!(sbt.d, 2); + //assert_eq!(sbt.storage.backend, "FSStorage"); + //assert_eq!(sbt.storage.args["path"], ".sbt.v5"); + //assert_matches!(&sbt.storage, ::FSStorage(args) => { + // assert_eq!(args, &[1, 100000, 4]); + //}); + assert_matches!(&sbt.factory, Factory::GraphFactory { args } => { + assert_eq!(args, &(1, 100000.0, 4)); + }); + + let mut filename = PathBuf::from(env!("CARGO_MANIFEST_DIR")); + filename.push("tests/test-data/.sbt.v3/60f7e23c24a8d94791cc7a8680c493f9"); + + let mut reader = BufReader::new(File::open(filename)?); + let sigs = Signature::load_signatures(&mut reader, Some(31), Some("DNA".into()), None)?; + let sig_data = sigs[0].clone(); + + let leaf: SigStore<_> = sig_data.into(); + + let results = sbt.find(search_minhashes, &leaf, 0.5)?; + assert_eq!(results.len(), 1); + //println!("results: {:?}", results); + //println!("leaf: {:?}", leaf); + + let results = sbt.find(search_minhashes, &leaf, 0.1)?; + assert_eq!(results.len(), 2); + //println!("results: {:?}", results); + //println!("leaf: {:?}", leaf); + + println!("sbt internal {:?} {:?}", sbt.nodes.len(), sbt.nodes); + println!("sbt leaves {:?} {:?}", sbt.leaves.len(), sbt.leaves); + + let mut new_sbt: MHBT = MHBT::builder().storage(None).build(); + let datasets = sbt.signatures(); + for l in datasets { + new_sbt.insert(l)?; + } + + for (i, node) in &sbt.nodes { + assert_eq!(node.data().unwrap(), new_sbt.nodes[i].data().unwrap()); + } + + assert_eq!(new_sbt.signature_refs().len(), 7); + println!("new_sbt internal {:?} {:?}", sbt.nodes.len(), sbt.nodes); + println!("new_sbt leaves {:?} {:?}", sbt.leaves.len(), sbt.leaves); + + let results = new_sbt.find(search_minhashes, &leaf, 0.5)?; + //println!("results: {:?}", results); + //println!("leaf: {:?}", leaf); + assert_eq!(results.len(), 1); + + let results = new_sbt.find(search_minhashes, &leaf, 0.1)?; + println!("results: {:?}", results); + println!("leaf: {:?}", leaf); + assert_eq!(results.len(), 2); + + let results = new_sbt.find(search_minhashes_containment, &leaf, 0.5)?; + println!("results: {:?}", results); + println!("leaf: {:?}", leaf); + assert_eq!(results.len(), 2); + + let results = new_sbt.find(search_minhashes_containment, &leaf, 0.1)?; + println!("results: {:?}", results); + println!("leaf: {:?}", leaf); + assert_eq!(results.len(), 4); + + Ok(()) + } + + #[test] + fn scaffold_sbt() { + let 
mut filename = PathBuf::from(env!("CARGO_MANIFEST_DIR")); + filename.push("tests/test-data/v5.sbt.json"); + + let sbt = MHBT::from_path(filename).expect("Loading error"); + + let new_sbt: MHBT = scaffold(sbt.leaves(), sbt.storage()); + + assert_eq!(new_sbt.signatures().len(), 7); + } + + #[test] + fn load_v4() { + let mut filename = PathBuf::from(env!("CARGO_MANIFEST_DIR")); + filename.push("tests/test-data/v4.sbt.json"); + + let _sbt = MHBT::from_path(filename).expect("Loading error"); + } + + #[test] + fn load_v5() { + let mut filename = PathBuf::from(env!("CARGO_MANIFEST_DIR")); + filename.push("tests/test-data/v5.sbt.json"); + + let _sbt = MHBT::from_path(filename).expect("Loading error"); + } +} diff --git a/src/index/sbt/mhmt.rs b/src/index/sbt/mhmt.rs index a53834750d..c2baa7496a 100644 --- a/src/index/sbt/mhmt.rs +++ b/src/index/sbt/mhmt.rs @@ -5,7 +5,7 @@ use mqf::MQF; use crate::index::sbt::{FromFactory, Node, Update, SBT}; use crate::index::storage::{ReadData, ReadDataError, ToWriter}; -use crate::index::{Comparable, Dataset}; +use crate::index::Comparable; use crate::signature::{Signature, SigsTrait}; use crate::sketch::Sketch; @@ -34,7 +34,7 @@ impl ReadData for Node { // TODO: using tempfile for now, but ideally want to avoid that let mut tmpfile = tempfile::NamedTempFile::new().unwrap(); - tmpfile.write_all(&mut &raw[..]).unwrap(); + tmpfile.write_all(&raw[..]).unwrap(); MQF::deserialize(tmpfile.path()).unwrap() })) @@ -46,7 +46,7 @@ impl ReadData for Node { } } -impl FromFactory> for SBT, L> { +impl FromFactory> for SBT, L> { fn factory(&self, _name: &str) -> Result, Error> { unimplemented!() } @@ -58,7 +58,7 @@ impl Update> for Node { } } -impl Update> for Dataset { +impl Update> for Signature { fn update(&self, _other: &mut Node) -> Result<(), Error> { unimplemented!(); } @@ -66,27 +66,26 @@ impl Update> for Dataset { impl Comparable> for Node { fn similarity(&self, other: &Node) -> f64 { - let ng: &MQF = self.data().unwrap(); - let ong: &MQF = other.data().unwrap(); + let _ng: &MQF = self.data().unwrap(); + let _ong: &MQF = other.data().unwrap(); unimplemented!(); //ng.similarity(&ong) } fn containment(&self, other: &Node) -> f64 { - let ng: &MQF = self.data().unwrap(); - let ong: &MQF = other.data().unwrap(); + let _ng: &MQF = self.data().unwrap(); + let _ong: &MQF = other.data().unwrap(); unimplemented!(); //ng.containment(&ong) } } -impl Comparable> for Node { - fn similarity(&self, other: &Dataset) -> f64 { +impl Comparable for Node { + fn similarity(&self, other: &Signature) -> f64 { let ng: &MQF = self.data().unwrap(); - let oth: &Signature = other.data().unwrap(); // TODO: select the right signatures... - if let Sketch::MinHash(sig) = &oth.signatures[0] { + if let Sketch::MinHash(sig) = &other.signatures[0] { if sig.size() == 0 { return 0.0; } @@ -109,12 +108,11 @@ impl Comparable> for Node { } } - fn containment(&self, other: &Dataset) -> f64 { + fn containment(&self, other: &Signature) -> f64 { let ng: &MQF = self.data().unwrap(); - let oth: &Signature = other.data().unwrap(); // TODO: select the right signatures... 
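        // note: containment here is estimated as the fraction of the query
        // signature's min-hashes found in this node's MQF filter; the
        // size() == 0 guard below returns 0.0 so an empty sketch never
        // divides by zero.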
- if let Sketch::MinHash(sig) = &oth.signatures[0] { + if let Sketch::MinHash(sig) = &other.signatures[0] { if sig.size() == 0 { return 0.0; } @@ -133,3 +131,97 @@ impl Comparable> for Node { } } } + +/* FIXME: bring back after MQF works on macOS and Windows +#[cfg(test)] +mod test { + use std::fs::File; + use std::io::{BufReader, Seek, SeekFrom}; + use std::path::PathBuf; + use std::rc::Rc; + use tempfile; + + use assert_matches::assert_matches; + use lazy_init::Lazy; + + use super::{scaffold, Factory}; + + use crate::index::linear::LinearIndex; + use crate::index::search::{search_minhashes, search_minhashes_containment}; + use crate::index::storage::ReadData; + use crate::index::{Index, SigStore, MHBT}; + use crate::signature::Signature; + + #[cfg(not(target_arch = "wasm32"))] + #[test] + fn load_mhmt() { + let mut filename = PathBuf::from(env!("CARGO_MANIFEST_DIR")); + filename.push("tests/test-data/v5_mhmt.sbt.json"); + + let mut sbt = crate::index::MHMT::from_path(filename).expect("Loading error"); + + let mut filename = PathBuf::from(env!("CARGO_MANIFEST_DIR")); + filename.push("tests/test-data/.sbt.v3/60f7e23c24a8d94791cc7a8680c493f9"); + + let mut reader = BufReader::new(File::open(filename).unwrap()); + let sigs = Signature::load_signatures(&mut reader, 31, Some("DNA".into()), None).unwrap(); + let sig_data = sigs[0].clone(); + + let data = Lazy::new(); + data.get_or_create(|| sig_data); + + let leaf = SigStore::builder() + .data(Rc::new(data)) + .filename("") + .name("") + .metadata("") + .storage(None) + .build(); + + let results = sbt.find(search_minhashes, &leaf, 0.5).unwrap(); + //assert_eq!(results.len(), 1); + println!("results: {:?}", results); + println!("leaf: {:?}", leaf); + + let results = sbt.find(search_minhashes, &leaf, 0.1).unwrap(); + assert_eq!(results.len(), 2); + println!("results: {:?}", results); + println!("leaf: {:?}", leaf); + + let mut linear = LinearIndex::builder().storage(sbt.storage()).build(); + for l in &sbt.leaves { + linear.insert(l.1.data().unwrap().clone()).unwrap(); + } + + println!( + "linear leaves {:?} {:?}", + linear.datasets.len(), + linear.datasets + ); + + let results = linear.find(search_minhashes, &leaf, 0.5).unwrap(); + assert_eq!(results.len(), 1); + println!("results: {:?}", results); + println!("leaf: {:?}", leaf); + + let results = linear.find(search_minhashes, &leaf, 0.1).unwrap(); + assert_eq!(results.len(), 2); + println!("results: {:?}", results); + println!("leaf: {:?}", leaf); + + let results = linear + .find(search_minhashes_containment, &leaf, 0.5) + .unwrap(); + assert_eq!(results.len(), 2); + println!("results: {:?}", results); + println!("leaf: {:?}", leaf); + + let results = linear + .find(search_minhashes_containment, &leaf, 0.1) + .unwrap(); + assert_eq!(results.len(), 4); + println!("results: {:?}", results); + println!("leaf: {:?}", leaf); + } + */ +} diff --git a/src/index/sbt/mod.rs b/src/index/sbt/mod.rs index 139b6913c1..8262262a98 100644 --- a/src/index/sbt/mod.rs +++ b/src/index/sbt/mod.rs @@ -1,8 +1,13 @@ pub mod mhbt; + +/* FIXME: bring back after boomphf changes pub mod ukhs; +*/ +/* FIXME: bring back after MQF works on macOS and Windows #[cfg(not(target_arch = "wasm32"))] pub mod mhmt; +*/ use std::collections::hash_map::Entry; use std::collections::{HashMap, HashSet}; @@ -17,11 +22,12 @@ use std::rc::Rc; use failure::Error; use lazy_init::Lazy; +use log::info; use serde_derive::{Deserialize, Serialize}; use typed_builder::TypedBuilder; use crate::index::storage::{FSStorage, ReadData, Storage, 
StorageInfo, ToWriter}; -use crate::index::{Comparable, Dataset, DatasetInfo, Index}; +use crate::index::{Comparable, DatasetInfo, Index, SigStore}; use crate::signature::Signature; pub trait Update { @@ -33,7 +39,10 @@ pub trait FromFactory { } #[derive(TypedBuilder)] -pub struct SBT { +pub struct SBT +where + L: Sync, +{ #[builder(default = 2)] d: u32, @@ -47,7 +56,7 @@ pub struct SBT { nodes: HashMap, #[builder(default_code = "HashMap::default()")] - leaves: HashMap, + leaves: HashMap>, } const fn parent(pos: u64, d: u64) -> u64 { @@ -60,7 +69,7 @@ const fn child(parent: u64, pos: u64, d: u64) -> u64 { impl SBT where - L: std::clone::Clone + Default, + L: std::clone::Clone + Default + Sync, N: Default, { #[inline(always)] @@ -86,15 +95,32 @@ where self.storage.clone() } + /* + fn fill_up(&mut self) -> Result<(), Error> { + let mut visited = HashSet::new(); + let mut queue: Vec<_> = self.leaves.keys().collect(); + + while !queue.is_empty() { + let pos = queue.pop().unwrap(); + + if !visited.contains(&pos) { + visited.insert(pos); + } + } + + Ok(()) + } + */ + // combine } -impl SBT, Dataset> +impl SBT, T> where - T: std::marker::Sync + ToWriter, + T: std::marker::Sync + ToWriter + Clone, U: std::marker::Sync + ToWriter, Node: ReadData, - Dataset: ReadData, + SigStore: ReadData, { fn parse_v4(rdr: &mut R) -> Result where @@ -112,7 +138,7 @@ where Ok(SBTInfo::V5(sinfo)) } - pub fn from_reader(rdr: &mut R, path: P) -> Result, Dataset>, Error> + pub fn from_reader(rdr: &mut R, path: P) -> Result, T>, Error> where R: Read, P: AsRef, @@ -165,7 +191,7 @@ where .leaves .into_iter() .map(|(n, l)| { - let new_node = Dataset { + let new_node = SigStore { filename: l.filename, name: l.name, metadata: l.metadata, @@ -192,7 +218,7 @@ where }; Some((*n, new_node)) } - NodeInfoV4::Dataset(_) => None, + NodeInfoV4::Leaf(_) => None, }) .collect(); @@ -201,8 +227,8 @@ where .into_iter() .filter_map(|(n, x)| match x { NodeInfoV4::Node(_) => None, - NodeInfoV4::Dataset(l) => { - let new_node = Dataset { + NodeInfoV4::Leaf(l) => { + let new_node = SigStore { filename: l.filename, name: l.name, metadata: l.metadata, @@ -227,7 +253,7 @@ where }) } - pub fn from_path>(path: P) -> Result, Dataset>, Error> { + pub fn from_path>(path: P) -> Result, T>, Error> { let file = File::open(&path)?; let mut reader = BufReader::new(file); @@ -238,8 +264,7 @@ where // TODO: canonicalize doesn't work on wasm32-wasi //basepath.canonicalize()?; - let sbt = - SBT::, Dataset>::from_reader(&mut reader, &basepath.parent().unwrap())?; + let sbt = SBT::, T>::from_reader(&mut reader, &basepath.parent().unwrap())?; Ok(sbt) } @@ -286,7 +311,7 @@ where let filename = (*l).save(&l.filename).unwrap(); let new_node = NodeInfo { - filename: filename, + filename, name: l.name.clone(), metadata: l.metadata.clone(), }; @@ -303,9 +328,10 @@ where // set storage to new one mem::replace(&mut l.storage, Some(Rc::clone(&storage))); + // TODO: this should be l.md5sum(), not l.filename let filename = (*l).save(&l.filename).unwrap(); let new_node = DatasetInfo { - filename: filename, + filename, name: l.name.clone(), metadata: l.metadata.clone(), }; @@ -319,13 +345,18 @@ where Ok(()) } + + pub fn leaves(&self) -> Vec> { + self.leaves.values().cloned().collect() + } } -impl Index for SBT +impl<'a, N, L> Index<'a> for SBT where N: Comparable + Comparable + Update + Debug + Default, - L: Comparable + Update + Clone + Debug + Default, + L: Comparable + Update + Clone + Debug + Default + Sync, SBT: FromFactory, + SigStore: From + ReadData, { type Item = L; 
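// Illustrative usage sketch, not part of the patch: `copy_signatures` is a
// hypothetical helper showing how the reworked trait is driven from caller
// code. `insert` now takes ownership of its item and `signatures()` returns
// owned copies, so draining one index into another needs no explicit clone.
// Assumes the crate-internal `Index`, `MHBT`, and `Signature` types and the
// `failure::Error` type used throughout this module.
fn copy_signatures<'a, I>(src: &MHBT, dst: &mut I) -> Result<(), failure::Error>
where
    I: Index<'a, Item = Signature>,
{
    for sig in src.signatures() {
        // `batch_insert` (the default method defined above) would also work.
        dst.insert(sig)?;
    }
    Ok(())
}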
@@ -339,6 +370,7 @@ where while !queue.is_empty() { let pos = queue.pop().unwrap(); + if !visited.contains(&pos) { visited.insert(pos); @@ -349,8 +381,9 @@ where } } } else if let Some(leaf) = self.leaves.get(&pos) { - if search_fn(leaf, sig, threshold) { - matches.push(leaf); + let data = leaf.data().expect("Error reading data"); + if search_fn(data, sig, threshold) { + matches.push(data); } } } @@ -359,11 +392,11 @@ where Ok(matches) } - fn insert(&mut self, dataset: &L) -> Result<(), Error> { + fn insert(&mut self, dataset: L) -> Result<(), Error> { if self.leaves.is_empty() { // in this case the tree is empty, // just add the dataset to the first available leaf - self.leaves.entry(0).or_insert(dataset.clone()); + self.leaves.entry(0).or_insert_with(|| dataset.into()); return Ok(()); } @@ -373,6 +406,7 @@ where // TODO: find position by similarity search let pos = self.leaves.keys().max().unwrap() + 1; let parent_pos = self.parent(pos).unwrap(); + let final_pos; if let Entry::Occupied(pnode) = self.leaves.entry(parent_pos) { // Case 1: parent is a Leaf @@ -384,7 +418,7 @@ where // for each children update the parent node // TODO: write the update method - leaf.update(&mut new_node)?; + leaf.data.get().unwrap().update(&mut new_node)?; dataset.update(&mut new_node)?; // node and parent are children of new internal node @@ -393,7 +427,8 @@ where let c2_pos = c_pos.next().unwrap(); self.leaves.entry(c1_pos).or_insert(leaf); - self.leaves.entry(c2_pos).or_insert(dataset.clone()); + self.leaves.entry(c2_pos).or_insert_with(|| dataset.into()); + final_pos = c2_pos; // add the new internal node to self.nodes[parent_pos) // TODO check if it is really empty? @@ -409,26 +444,31 @@ where // (if there isn't an empty spot, it was already covered by case 1) Entry::Occupied(mut pnode) => { dataset.update(&mut pnode.get_mut())?; - self.leaves.entry(pos).or_insert(dataset.clone()); + self.leaves.entry(pos).or_insert_with(|| dataset.into()); + final_pos = pos; } // Case 3: parent is None/empty // this can happen with d != 2, need to create parent node Entry::Vacant(pnode) => { - self.leaves.entry(c_pos).or_insert(dataset.clone()); dataset.update(&mut new_node)?; + self.leaves.entry(c_pos).or_insert_with(|| dataset.into()); + final_pos = c_pos; pnode.insert(new_node); } } } + let entry = &self.leaves[&final_pos]; + let data = entry.data.get().unwrap(); + let mut parent_pos = parent_pos; while let Some(ppos) = self.parent(parent_pos) { if let Entry::Occupied(mut pnode) = self.nodes.entry(parent_pos) { //TODO: use children for this node to update, instead of dragging // dataset up to the root? It would be more generic, but this // works for minhash, draff signatures and nodegraphs... 
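                // note: this loop climbs from the insertion point to the
                // root via parent(pos), folding the new leaf's data into
                // every internal node on the way so that later searches can
                // still prune correctly at each level.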
- dataset.update(&mut pnode.get_mut())?; + data.update(&mut pnode.get_mut())?; } parent_pos = ppos; } @@ -436,17 +476,37 @@ where Ok(()) } + /* + fn batch_insert(&mut self, nodes: Vec) -> Result<(), Error> { + self = scaffold(nodes, self.storage()); + Ok(()) + } + */ + fn save>(&self, _path: P) -> Result<(), Error> { - unimplemented!() + unimplemented!(); } fn load>(_path: P) -> Result<(), Error> { unimplemented!() } - fn datasets(&self) -> Vec { - self.leaves.values().cloned().collect() + fn signatures(&self) -> Vec { + self.leaves + .values() + .map(|x| x.data().unwrap().clone()) + .collect() + } + + fn signature_refs(&self) -> Vec<&Self::Item> { + self.leaves.values().map(|x| x.data().unwrap()).collect() } + + /* + fn iter_signatures(&'a self) -> Self::SignatureIterator { + self.leaves.values() + } + */ } /* @@ -497,7 +557,17 @@ where } } -impl Dataset +impl PartialEq for Node +where + T: Sync + PartialEq, + Node: ReadData, +{ + fn eq(&self, other: &Node) -> bool { + self.data().unwrap() == other.data().unwrap() + } +} + +impl SigStore where T: Sync + ToWriter, { @@ -544,7 +614,7 @@ struct NodeInfo { #[serde(untagged)] enum NodeInfoV4 { Node(NodeInfo), - Dataset(DatasetInfo), + Leaf(DatasetInfo), } #[derive(Serialize, Deserialize)] @@ -603,7 +673,7 @@ type HashIntersection = HashSet>; enum BinaryTree { Empty, Internal(Box>), - Dataset(Box>>), + Leaf(Box>>), } struct TreeNode { @@ -613,20 +683,20 @@ struct TreeNode { } pub fn scaffold( - mut datasets: Vec>, + mut datasets: Vec>, storage: Option>, -) -> SBT, Dataset> +) -> SBT, Signature> where N: std::marker::Sync + std::clone::Clone + std::default::Default, { - let mut leaves: HashMap> = HashMap::with_capacity(datasets.len()); + let mut leaves: HashMap> = HashMap::with_capacity(datasets.len()); let mut next_round = Vec::new(); // generate two bottom levels: // - datasets // - first level of internal nodes - eprintln!("Start processing leaves"); + info!("Start processing leaves"); while !datasets.is_empty() { let next_leaf = datasets.pop().unwrap(); @@ -655,7 +725,7 @@ where .cloned() .collect(); - let simleaf_tree = BinaryTree::Dataset(Box::new(TreeNode { + let simleaf_tree = BinaryTree::Leaf(Box::new(TreeNode { element: similar_leaf, left: BinaryTree::Empty, right: BinaryTree::Empty, @@ -663,7 +733,7 @@ where (simleaf_tree, in_common) }; - let leaf_tree = BinaryTree::Dataset(Box::new(TreeNode { + let leaf_tree = BinaryTree::Leaf(Box::new(TreeNode { element: next_leaf, left: BinaryTree::Empty, right: BinaryTree::Empty, @@ -678,15 +748,15 @@ where next_round.push(tree); if next_round.len() % 100 == 0 { - eprintln!("Processed {} leaves", next_round.len() * 2); + info!("Processed {} leaves", next_round.len() * 2); } } - eprintln!("Finished processing leaves"); + info!("Finished processing leaves"); // while we don't get to the root, generate intermediary levels while next_round.len() != 1 { next_round = BinaryTree::process_internal_level(next_round); - eprintln!("Finished processing round {}", next_round.len()); + info!("Finished processing round {}", next_round.len()); } // Convert from binary tree to nodes/leaves @@ -700,7 +770,7 @@ where visited.insert(pos); match cnode { - BinaryTree::Dataset(leaf) => { + BinaryTree::Leaf(leaf) => { leaves.insert(pos, leaf.element); } BinaryTree::Internal(mut node) => { @@ -761,7 +831,7 @@ impl BinaryTree { BinaryTree::Empty => { std::mem::replace(&mut el1.element, HashIntersection::default()) } - _ => panic!("Should not see a Dataset at this level"), + _ => panic!("Should not see a Leaf at this 
level"), } } else { HashIntersection::default() @@ -784,217 +854,14 @@ impl BinaryTree { } } -#[cfg(test)] -mod test { - use std::fs::File; - use std::io::{BufReader, Seek, SeekFrom}; - use std::path::PathBuf; - use std::rc::Rc; - use tempfile; - - use assert_matches::assert_matches; - use lazy_init::Lazy; - - use super::{scaffold, Factory}; - - use crate::index::linear::LinearIndex; - use crate::index::search::{search_minhashes, search_minhashes_containment}; - use crate::index::{Dataset, Index, MHBT}; - use crate::signature::Signature; - - #[test] - fn save_sbt() { - let mut filename = PathBuf::from(env!("CARGO_MANIFEST_DIR")); - filename.push("tests/test-data/v5.sbt.json"); - - let mut sbt = MHBT::from_path(filename).expect("Loading error"); - - let mut tmpfile = tempfile::NamedTempFile::new().unwrap(); - sbt.save_file(tmpfile.path(), None).unwrap(); - - tmpfile.seek(SeekFrom::Start(0)).unwrap(); - } - - #[cfg(not(target_arch = "wasm32"))] - #[test] - fn load_mhmt() { - let mut filename = PathBuf::from(env!("CARGO_MANIFEST_DIR")); - filename.push("tests/test-data/v5_mhmt.sbt.json"); - - let mut sbt = crate::index::MHMT::from_path(filename).expect("Loading error"); - - let mut filename = PathBuf::from(env!("CARGO_MANIFEST_DIR")); - filename.push("tests/test-data/.sbt.v3/60f7e23c24a8d94791cc7a8680c493f9"); - - let mut reader = BufReader::new(File::open(filename).unwrap()); - let sigs = Signature::load_signatures(&mut reader, 31, Some("DNA".into()), None).unwrap(); - let sig_data = sigs[0].clone(); - - let data = Lazy::new(); - data.get_or_create(|| sig_data); - - let leaf = Dataset::builder() - .data(Rc::new(data)) - .filename("") - .name("") - .metadata("") - .storage(None) - .build(); - - let results = sbt.find(search_minhashes, &leaf, 0.5).unwrap(); - //assert_eq!(results.len(), 1); - println!("results: {:?}", results); - println!("leaf: {:?}", leaf); - - let results = sbt.find(search_minhashes, &leaf, 0.1).unwrap(); - assert_eq!(results.len(), 2); - println!("results: {:?}", results); - println!("leaf: {:?}", leaf); - - let mut linear = LinearIndex::builder().storage(sbt.storage()).build(); - for l in &sbt.leaves { - linear.insert(l.1).unwrap(); - } - - println!( - "linear leaves {:?} {:?}", - linear.datasets.len(), - linear.datasets - ); - - let results = linear.find(search_minhashes, &leaf, 0.5).unwrap(); - assert_eq!(results.len(), 1); - println!("results: {:?}", results); - println!("leaf: {:?}", leaf); - - let results = linear.find(search_minhashes, &leaf, 0.1).unwrap(); - assert_eq!(results.len(), 2); - println!("results: {:?}", results); - println!("leaf: {:?}", leaf); - - let results = linear - .find(search_minhashes_containment, &leaf, 0.5) - .unwrap(); - assert_eq!(results.len(), 2); - println!("results: {:?}", results); - println!("leaf: {:?}", leaf); - - let results = linear - .find(search_minhashes_containment, &leaf, 0.1) - .unwrap(); - assert_eq!(results.len(), 4); - println!("results: {:?}", results); - println!("leaf: {:?}", leaf); - } - - #[test] - fn load_sbt() { - let mut filename = PathBuf::from(env!("CARGO_MANIFEST_DIR")); - filename.push("tests/test-data/v5.sbt.json"); - - let sbt = MHBT::from_path(filename).expect("Loading error"); - - assert_eq!(sbt.d, 2); - //assert_eq!(sbt.storage.backend, "FSStorage"); - //assert_eq!(sbt.storage.args["path"], ".sbt.v5"); - //assert_matches!(&sbt.storage, ::FSStorage(args) => { - // assert_eq!(args, &[1, 100000, 4]); - //}); - assert_matches!(&sbt.factory, Factory::GraphFactory { args } => { - assert_eq!(args, &(1, 100000.0, 
4)); - }); - - println!("sbt leaves {:?} {:?}", sbt.leaves.len(), sbt.leaves); - - let mut filename = PathBuf::from(env!("CARGO_MANIFEST_DIR")); - filename.push("tests/test-data/.sbt.v3/60f7e23c24a8d94791cc7a8680c493f9"); - - let mut reader = BufReader::new(File::open(filename).unwrap()); - let sigs = Signature::load_signatures(&mut reader, 31, Some("DNA".into()), None).unwrap(); - let sig_data = sigs[0].clone(); - - let data = Lazy::new(); - data.get_or_create(|| sig_data); - - let leaf = Dataset::builder() - .data(Rc::new(data)) - .filename("") - .name("") - .metadata("") - .storage(None) - .build(); - - let results = sbt.find(search_minhashes, &leaf, 0.5).unwrap(); - assert_eq!(results.len(), 1); - println!("results: {:?}", results); - println!("leaf: {:?}", leaf); - - let results = sbt.find(search_minhashes, &leaf, 0.1).unwrap(); - assert_eq!(results.len(), 2); - println!("results: {:?}", results); - println!("leaf: {:?}", leaf); - - let mut linear = LinearIndex::builder().storage(sbt.storage()).build(); - for l in &sbt.leaves { - linear.insert(l.1).unwrap(); - } - - println!( - "linear leaves {:?} {:?}", - linear.datasets.len(), - linear.datasets - ); - - let results = linear.find(search_minhashes, &leaf, 0.5).unwrap(); - assert_eq!(results.len(), 1); - println!("results: {:?}", results); - println!("leaf: {:?}", leaf); - - let results = linear.find(search_minhashes, &leaf, 0.1).unwrap(); - assert_eq!(results.len(), 2); - println!("results: {:?}", results); - println!("leaf: {:?}", leaf); - - let results = linear - .find(search_minhashes_containment, &leaf, 0.5) - .unwrap(); - assert_eq!(results.len(), 2); - println!("results: {:?}", results); - println!("leaf: {:?}", leaf); - - let results = linear - .find(search_minhashes_containment, &leaf, 0.1) - .unwrap(); - assert_eq!(results.len(), 4); - println!("results: {:?}", results); - println!("leaf: {:?}", leaf); - } - - #[test] - fn scaffold_sbt() { - let mut filename = PathBuf::from(env!("CARGO_MANIFEST_DIR")); - filename.push("tests/test-data/v5.sbt.json"); - - let sbt = MHBT::from_path(filename).expect("Loading error"); - - let new_sbt: MHBT = scaffold(sbt.datasets(), sbt.storage()); - - assert_eq!(new_sbt.datasets().len(), 7); - } - - #[test] - fn load_v4() { - let mut filename = PathBuf::from(env!("CARGO_MANIFEST_DIR")); - filename.push("tests/test-data/v4.sbt.json"); - - let _sbt = MHBT::from_path(filename).expect("Loading error"); - } - - #[test] - fn load_v5() { - let mut filename = PathBuf::from(env!("CARGO_MANIFEST_DIR")); - filename.push("tests/test-data/v5.sbt.json"); - - let _sbt = MHBT::from_path(filename).expect("Loading error"); +/* +impl From> for SBT, Signature> +where + U: Sync + Default + Clone, +{ + fn from(other: LinearIndex) -> Self { + let storage = other.storage(); + scaffold(other.datasets, storage) } } +*/ diff --git a/src/index/sbt/ukhs.rs b/src/index/sbt/ukhs.rs index 2fac9c7291..30d3573cc4 100644 --- a/src/index/sbt/ukhs.rs +++ b/src/index/sbt/ukhs.rs @@ -7,12 +7,12 @@ use lazy_init::Lazy; use crate::index::sbt::{FromFactory, Node, Update, SBT}; use crate::index::storage::{ReadData, ReadDataError}; -use crate::index::{Comparable, Dataset}; +use crate::index::Comparable; use crate::signature::Signature; use crate::sketch::ukhs::{FlatUKHS, UKHSTrait}; use crate::sketch::Sketch; -impl FromFactory> for SBT, L> { +impl FromFactory> for SBT, L> { fn factory(&self, name: &str) -> Result, Error> { let data = Lazy::new(); // TODO: don't hardcode this! 
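// Illustrative sketch, not part of the patch: the deferred-payload pattern
// used for node data throughout these files. Payloads sit behind
// Rc<Lazy<T>>, so nodes are cheap to construct and the data is computed (or
// read back from storage via ReadData::data) only once, on first access.
// `lazy_payload` is a hypothetical helper name; the `T: Sync` bound matches
// what Lazy<T> itself requires (and why Sync bounds appear on SBT/SigStore).
use lazy_init::Lazy;
use std::rc::Rc;

fn lazy_payload<T: Sync>(build: impl FnOnce() -> T) -> Rc<Lazy<T>> {
    let data = Lazy::new();
    data.get_or_create(build); // first caller builds; later calls reuse it
    Rc::new(data)
}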
@@ -34,15 +34,13 @@ impl Update> for Node { } } -impl Update> for Dataset { +impl Update> for Signature { fn update(&self, other: &mut Node) -> Result<(), Error> { - let data = &self.data()?; - - let sigs = if data.signatures.len() > 1 { + let sigs = if self.signatures.len() > 1 { // TODO: select the right signatures... unimplemented!() } else { - &data.signatures[0] + &self.signatures[0] }; if let Sketch::UKHS(sig) = sigs { @@ -73,14 +71,12 @@ impl Comparable> for Node { } } -impl Comparable> for Node { - fn similarity(&self, other: &Dataset) -> f64 { - let odata = other.data().unwrap(); - - if odata.signatures.len() > 1 { +impl Comparable for Node { + fn similarity(&self, other: &Signature) -> f64 { + if other.signatures.len() > 1 { // TODO: select the right signatures... unimplemented!() - } else if let Sketch::UKHS(o_sig) = &odata.signatures[0] { + } else if let Sketch::UKHS(o_sig) = &other.signatures[0] { // This is doing a variation of Weighted Jaccard. // The internal nodes are built with max(l_i, r_i) for each // left and right children, so if we do a WJ similarity directly @@ -108,7 +104,7 @@ impl Comparable> for Node { } } - fn containment(&self, _other: &Dataset) -> f64 { + fn containment(&self, _other: &Signature) -> f64 { unimplemented!(); } } diff --git a/src/index/storage.rs b/src/index/storage.rs index 184a6805ca..b1e83aaa07 100644 --- a/src/index/storage.rs +++ b/src/index/storage.rs @@ -44,7 +44,7 @@ impl From<&StorageArgs> for FSStorage { fullpath.push(path); FSStorage { - fullpath: fullpath, + fullpath, subdir: path.clone(), } } @@ -105,7 +105,7 @@ impl Storage for FSStorage { let file = File::create(&fpath)?; let mut buf_writer = BufWriter::new(file); - buf_writer.write(content)?; + buf_writer.write_all(content)?; Ok(path.into()) } diff --git a/src/lib.rs b/src/lib.rs index 2410941b4c..6add4804eb 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -18,9 +18,6 @@ pub mod errors; -#[macro_use] -pub mod utils; - pub mod index; pub mod signature; diff --git a/src/signature.rs b/src/signature.rs index dd83d8045d..07823a43c3 100644 --- a/src/signature.rs +++ b/src/signature.rs @@ -82,6 +82,7 @@ pub struct Signature { #[builder(default)] pub filename: Option, + #[serde(skip_serializing_if = "Option::is_none")] pub name: Option, #[serde(default = "default_license")] @@ -154,7 +155,7 @@ impl Signature { pub fn load_signatures( buf: &mut R, - ksize: usize, + ksize: Option, moltype: Option<&str>, _scaled: Option, ) -> Result, Error> @@ -181,33 +182,41 @@ impl Signature { .filter(|sig| { match sig { Sketch::MinHash(mh) => { - if ksize == 0 || ksize == mh.ksize() as usize { - match moltype { - Some(x) => { - if (x.to_lowercase() == "dna" && !mh.is_protein()) - || (x.to_lowercase() == "protein" && mh.is_protein()) - { - return true; - } + if let Some(k) = ksize { + if k != mh.ksize() as usize { + return false; + } + }; + + match moltype { + Some(x) => { + if (x.to_lowercase() == "dna" && !mh.is_protein()) + || (x.to_lowercase() == "protein" && mh.is_protein()) + { + return true; } - None => return true, // TODO: match previous behavior - }; + } + None => return true, // TODO: match previous behavior }; } Sketch::UKHS(hs) => { - if ksize == 0 || ksize == hs.ksize() as usize { - match moltype { - Some(x) => { - if x.to_lowercase() == "dna" { - return true; - } else { - // TODO: draff only supports dna for now - unimplemented!() - } + if let Some(k) = ksize { + if k != hs.ksize() as usize { + return false; + } + }; + + match moltype { + Some(x) => { + if x.to_lowercase() == "dna" { + return 
true; + } else { + // TODO: draff only supports dna for now + unimplemented!() } - None => unimplemented!(), - }; - } + } + None => unimplemented!(), + }; } }; false @@ -231,7 +240,7 @@ impl ToWriter for Signature { where W: io::Write, { - match serde_json::to_writer(writer, &self) { + match serde_json::to_writer(writer, &vec![&self]) { Ok(_) => Ok(()), Err(_) => Err(SourmashError::SerdeError.into()), } @@ -267,6 +276,8 @@ impl PartialEq for Signature { if let Sketch::MinHash(other_mh) = &other.signatures[0] { return metadata && (mh == other_mh); } + } else { + unimplemented!() } metadata } @@ -286,7 +297,8 @@ mod test { filename.push("tests/test-data/.sbt.v3/60f7e23c24a8d94791cc7a8680c493f9"); let mut reader = BufReader::new(File::open(filename).unwrap()); - let sigs = Signature::load_signatures(&mut reader, 31, Some("DNA".into()), None).unwrap(); + let sigs = + Signature::load_signatures(&mut reader, Some(31), Some("DNA".into()), None).unwrap(); let _sig_data = sigs[0].clone(); // TODO: check sig_data } diff --git a/src/sketch/minhash.rs b/src/sketch/minhash.rs index e9e1f3163f..427ccb977d 100644 --- a/src/sketch/minhash.rs +++ b/src/sketch/minhash.rs @@ -17,13 +17,22 @@ use crate::signature::SigsTrait; #[cfg(all(target_arch = "wasm32", target_vendor = "unknown"))] use wasm_bindgen::prelude::*; +#[allow(non_camel_case_types)] +#[derive(Debug, Clone, Copy, PartialEq)] +#[repr(u32)] +pub enum HashFunctions { + murmur64_DNA = 1, + murmur64_protein = 2, + murmur64_dayhoff = 3, + murmur64_hp = 4, +} + #[cfg_attr(all(target_arch = "wasm32", target_vendor = "unknown"), wasm_bindgen)] #[derive(Debug, Clone, PartialEq)] pub struct KmerMinHash { num: u32, ksize: u32, - is_protein: bool, - dayhoff: bool, + pub(crate) hash_function: HashFunctions, seed: u64, max_hash: u64, pub(crate) mins: Vec, @@ -35,8 +44,7 @@ impl Default for KmerMinHash { KmerMinHash { num: 1000, ksize: 21, - is_protein: false, - dayhoff: false, + hash_function: HashFunctions::murmur64_DNA, seed: 42, max_hash: 0, mins: Vec::with_capacity(1000), @@ -69,10 +77,12 @@ impl Serialize for KmerMinHash { partial.serialize_field( "molecule", - match &self.is_protein { + match &self.is_protein() { true => { - if self.dayhoff { + if self.dayhoff() { "dayhoff" + } else if self.hp() { + "hp" } else { "protein" } @@ -105,7 +115,12 @@ impl<'de> Deserialize<'de> for KmerMinHash { let tmpsig = TempSig::deserialize(deserializer)?; let num = if tmpsig.max_hash != 0 { 0 } else { tmpsig.num }; - let molecule = tmpsig.molecule.to_lowercase(); + let hash_function = match tmpsig.molecule.to_lowercase().as_ref() { + "protein" => HashFunctions::murmur64_protein, + "dayhoff" => HashFunctions::murmur64_dayhoff, + "dna" => HashFunctions::murmur64_DNA, + _ => unimplemented!(), // TODO: throw error here + }; Ok(KmerMinHash { num, @@ -114,13 +129,7 @@ impl<'de> Deserialize<'de> for KmerMinHash { max_hash: tmpsig.max_hash, mins: tmpsig.mins, abunds: tmpsig.abundances, - is_protein: match molecule.as_ref() { - "protein" => true, - "dayhoff" => true, - "dna" => false, - _ => unimplemented!(), - }, - dayhoff: molecule == "dayhoff", + hash_function, }) } } @@ -129,8 +138,7 @@ impl KmerMinHash { pub fn new( num: u32, ksize: u32, - is_protein: bool, - dayhoff: bool, + hash_function: HashFunctions, seed: u64, max_hash: u64, track_abundance: bool, @@ -153,8 +161,7 @@ impl KmerMinHash { KmerMinHash { num, ksize, - is_protein, - dayhoff, + hash_function, seed, max_hash, mins, @@ -167,7 +174,7 @@ impl KmerMinHash { } pub fn is_protein(&self) -> bool { - self.is_protein + 
self.hash_function == HashFunctions::murmur64_protein } pub fn seed(&self) -> u64 { @@ -183,8 +190,7 @@ impl KmerMinHash { md5_ctx.consume(self.ksize().to_string()); self.mins .iter() - .map(|x| md5_ctx.consume(x.to_string())) - .count(); + .for_each(|x| md5_ctx.consume(x.to_string())); format!("{:x}", md5_ctx.compute()) } @@ -393,10 +399,7 @@ impl KmerMinHash { pub fn count_common(&self, other: &KmerMinHash) -> Result { self.check_compatible(other)?; - let iter = Intersection { - left: self.mins.iter().peekable(), - right: other.mins.iter().peekable(), - }; + let iter = Intersection::new(self.mins.iter(), other.mins.iter()); Ok(iter.count() as u64) } @@ -407,8 +410,7 @@ impl KmerMinHash { let mut combined_mh = KmerMinHash::new( self.num, self.ksize, - self.is_protein, - self.dayhoff, + self.hash_function, self.seed, self.max_hash, self.abunds.is_some(), @@ -417,18 +419,12 @@ impl KmerMinHash { combined_mh.merge(&self)?; combined_mh.merge(&other)?; - let it1 = Intersection { - left: self.mins.iter().peekable(), - right: other.mins.iter().peekable(), - }; + let it1 = Intersection::new(self.mins.iter(), other.mins.iter()); // TODO: there is probably a way to avoid this Vec here, // and pass the it1 as left in it2. let i1: Vec = it1.cloned().collect(); - let it2 = Intersection { - left: i1.iter().peekable(), - right: combined_mh.mins.iter().peekable(), - }; + let it2 = Intersection::new(i1.iter(), combined_mh.mins.iter()); let common: Vec = it2.cloned().collect(); Ok((common, combined_mh.mins.len() as u64)) @@ -440,8 +436,7 @@ impl KmerMinHash { let mut combined_mh = KmerMinHash::new( self.num, self.ksize, - self.is_protein, - self.dayhoff, + self.hash_function, self.seed, self.max_hash, self.abunds.is_some(), @@ -450,18 +445,12 @@ impl KmerMinHash { combined_mh.merge(&self)?; combined_mh.merge(&other)?; - let it1 = Intersection { - left: self.mins.iter().peekable(), - right: other.mins.iter().peekable(), - }; + let it1 = Intersection::new(self.mins.iter(), other.mins.iter()); // TODO: there is probably a way to avoid this Vec here, // and pass the it1 as left in it2. 
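        // note: it1 streams the sorted hashes common to self and other;
        // intersecting that stream with the merged sketch below keeps only
        // the hashes that survive in the size-bounded union, whose length is
        // the denominator of the Jaccard estimate computed by compare().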
let i1: Vec = it1.cloned().collect(); - let it2 = Intersection { - left: i1.iter().peekable(), - right: combined_mh.mins.iter().peekable(), - }; + let it2 = Intersection::new(i1.iter(), combined_mh.mins.iter()); Ok((it2.count() as u64, combined_mh.mins.len() as u64)) } @@ -469,14 +458,28 @@ impl KmerMinHash { pub fn compare(&self, other: &KmerMinHash) -> Result { self.check_compatible(other)?; if let Ok((common, size)) = self.intersection_size(other) { - return Ok(common as f64 / u64::max(1, size) as f64); + Ok(common as f64 / u64::max(1, size) as f64) } else { - return Ok(0.0); + Ok(0.0) } } + pub fn containment_ignore_maxhash(&self, other: &KmerMinHash) -> Result { + let it = Intersection::new(self.mins.iter(), other.mins.iter()); + + Ok(it.count() as f64 / self.size() as f64) + } + pub fn dayhoff(&self) -> bool { - self.dayhoff + self.hash_function == HashFunctions::murmur64_dayhoff + } + + pub fn hp(&self) -> bool { + self.hash_function == HashFunctions::murmur64_hp + } + + pub fn hash_function(&self) -> HashFunctions { + self.hash_function } pub fn mins(&self) -> Vec { @@ -501,10 +504,8 @@ impl SigsTrait for KmerMinHash { if self.ksize != other.ksize { return Err(SourmashError::MismatchKSizes.into()); } - if self.is_protein != other.is_protein { - return Err(SourmashError::MismatchDNAProt.into()); - } - if self.dayhoff != other.dayhoff { + if self.hash_function != other.hash_function { + // TODO: fix this error return Err(SourmashError::MismatchDNAProt.into()); } if self.max_hash != other.max_hash { @@ -522,7 +523,7 @@ impl SigsTrait for KmerMinHash { .map(|&x| (x as char).to_ascii_uppercase() as u8) .collect(); if sequence.len() >= (self.ksize as usize) { - if !self.is_protein { + if !self.is_protein() { // dna for kmer in sequence.windows(self.ksize as usize) { if _checkdna(kmer) { @@ -551,20 +552,17 @@ impl SigsTrait for KmerMinHash { .skip(i) .take(sequence.len() - i) .collect(); - let aa = to_aa(&substr, self.dayhoff)?; + let aa = to_aa(&substr, self.dayhoff(), self.hp())?; - aa.windows(aa_ksize as usize) - .map(|n| self.add_word(n)) - .count(); + aa.windows(aa_ksize as usize).for_each(|n| self.add_word(n)); let rc_substr: Vec = rc.iter().cloned().skip(i).take(rc.len() - i).collect(); - let aa_rc = to_aa(&rc_substr, self.dayhoff)?; + let aa_rc = to_aa(&rc_substr, self.dayhoff(), self.hp())?; aa_rc .windows(aa_ksize as usize) - .map(|n| self.add_word(n)) - .count(); + .for_each(|n| self.add_word(n)); } } } @@ -573,8 +571,17 @@ impl SigsTrait for KmerMinHash { } struct Intersection> { - left: Peekable, - right: Peekable, + iter: Peekable, + other: Peekable, +} + +impl> Intersection { + pub fn new(left: I, right: I) -> Self { + Intersection { + iter: left.peekable(), + other: right.peekable(), + } + } } impl> Iterator for Intersection { @@ -582,21 +589,21 @@ impl> Iterator for Intersection { fn next(&mut self) -> Option { loop { - let res = match (self.left.peek(), self.right.peek()) { + let res = match (self.iter.peek(), self.other.peek()) { (Some(ref left_key), Some(ref right_key)) => left_key.cmp(right_key), _ => return None, }; match res { Ordering::Less => { - self.left.next(); + self.iter.next(); } Ordering::Greater => { - self.right.next(); + self.other.next(); } Ordering::Equal => { - self.right.next(); - return self.left.next(); + self.other.next(); + return self.iter.next(); } } } @@ -685,7 +692,7 @@ lazy_static! 
         // G
         ("GGT", b'G'), ("GGC", b'G'), ("GGA", b'G'), ("GGG", b'G'), ("GGN", b'G'),
 
-    ].into_iter().cloned().collect()
+    ].iter().cloned().collect()
     };
 }
 
@@ -730,7 +737,33 @@ lazy_static! {
         // e
         (b'F', b'f'), (b'W', b'f'), (b'Y', b'f'),
 
-    ].into_iter().cloned().collect()
+    ].iter().cloned().collect()
     };
 }
+
+// HP Hydrophobic/hydrophilic mapping
+// From: Phillips, R., Kondev, J., Theriot, J. (2008).
+// Physical Biology of the Cell. New York: Garland Science, Taylor & Francis Group. ISBN: 978-0815341635
+//
+// | Amino acid                            | HP      |
+// |---------------------------------------|---------|
+// | A, F, G, I, L, M, P, V, W, Y          | h       |
+// | N, C, S, T, D, E, R, H, K, Q          | p       |
+lazy_static! {
+    static ref HPTABLE: HashMap<u8, u8> = {
+        [
+            // h
+            (b'A', b'h'), (b'F', b'h'), (b'G', b'h'), (b'I', b'h'), (b'L', b'h'),
+            (b'M', b'h'), (b'P', b'h'), (b'V', b'h'), (b'W', b'h'), (b'Y', b'h'),
+
+            // p
+            (b'N', b'p'), (b'C', b'p'), (b'S', b'p'), (b'T', b'p'), (b'D', b'p'),
+            (b'E', b'p'), (b'R', b'p'), (b'H', b'p'), (b'K', b'p'), (b'Q', b'p'),
+        ]
+        .iter()
+        .cloned()
+        .collect()
+    };
+}
@@ -770,8 +803,15 @@ pub(crate) fn aa_to_dayhoff(aa: u8) -> char {
     }
 }
 
+pub(crate) fn aa_to_hp(aa: u8) -> char {
+    match HPTABLE.get(&aa) {
+        Some(letter) => *letter as char,
+        None => 'X',
+    }
+}
+
 #[inline]
-fn to_aa(seq: &[u8], dayhoff: bool) -> Result<Vec<u8>, Error> {
+fn to_aa(seq: &[u8], dayhoff: bool, hp: bool) -> Result<Vec<u8>, Error> {
     let mut converted: Vec<u8> = Vec::with_capacity(seq.len() / 3);
 
     for chunk in seq.chunks(3) {
@@ -782,6 +822,8 @@ fn to_aa(seq: &[u8], dayhoff: bool) -> Result<Vec<u8>, Error> {
         let residue = translate_codon(chunk)?;
         if dayhoff {
             converted.push(aa_to_dayhoff(residue) as u8);
+        } else if hp {
+            converted.push(aa_to_hp(residue) as u8);
         } else {
             converted.push(residue);
         }
diff --git a/src/sketch/mod.rs b/src/sketch/mod.rs
index 9da8dca613..f02d0f7e1c 100644
--- a/src/sketch/mod.rs
+++ b/src/sketch/mod.rs
@@ -1,5 +1,6 @@
 pub mod minhash;
 pub mod nodegraph;
+
 pub mod ukhs;
 
 use serde_derive::{Deserialize, Serialize};
@@ -11,5 +12,5 @@ use crate::sketch::ukhs::FlatUKHS;
 #[serde(untagged)]
 pub enum Sketch {
     MinHash(KmerMinHash),
-    UKHS(FlatUKHS),
+    UKHS(FlatUKHS), // FIXME
 }
diff --git a/src/sketch/nodegraph.rs b/src/sketch/nodegraph.rs
index 9d8b52b452..4ae4725c8e 100644
--- a/src/sketch/nodegraph.rs
+++ b/src/sketch/nodegraph.rs
@@ -7,9 +7,10 @@ use failure::Error;
 use fixedbitset::FixedBitSet;
 use primal;
 
+use crate::sketch::minhash::KmerMinHash;
 use crate::HashIntoType;
 
-#[derive(Debug, Default, Clone, PartialEq)]
+#[derive(Debug, Default, Clone)]
 pub struct Nodegraph {
     pub(crate) bs: Vec<FixedBitSet>,
     ksize: usize,
@@ -17,6 +18,15 @@ pub struct Nodegraph {
     unique_kmers: usize,
 }
 
+// TODO: only checking for the bitset for now,
+// since unique_kmers is not saved in a khmer nodegraph
+// and occupied_bins also has issues...
+impl PartialEq for Nodegraph {
+    fn eq(&self, other: &Nodegraph) -> bool {
+        self.bs == other.bs
+    }
+}
+
 impl Nodegraph {
     pub fn new(tablesizes: &[usize], ksize: usize) -> Nodegraph {
         let mut bs = Vec::with_capacity(tablesizes.len());
@@ -33,13 +43,24 @@ impl Nodegraph {
     }
 
     pub fn with_tables(tablesize: usize, n_tables: usize, ksize: usize) -> Nodegraph {
-        // TODO: cache the Sieve somewhere for repeated calls?
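The hunk that follows replaces the Primes::all() sieve with a descending search: start from an odd candidate at or just below tablesize and step down by two, keeping each value primal::is_prime accepts until n_tables sizes are collected. A standalone sketch of that search, where naive_is_prime is only a stand-in for primal::is_prime and tablesize is assumed to be at least 2:

    // Naive trial-division primality test; a stand-in for primal::is_prime.
    fn naive_is_prime(n: u64) -> bool {
        if n < 2 {
            return false;
        }
        if n % 2 == 0 {
            return n == 2;
        }
        let mut d = 3;
        while d * d <= n {
            if n % d == 0 {
                return false;
            }
            d += 2;
        }
        true
    }

    // Collect the n_tables largest primes at or below tablesize,
    // walking downward through odd candidates like the patch does.
    fn n_primes_near(tablesize: u64, n_tables: usize) -> Vec<u64> {
        let mut sizes = Vec::with_capacity(n_tables);
        let mut i = tablesize - 1;
        if i % 2 == 0 {
            i += 1; // make the starting candidate odd
        }
        while sizes.len() != n_tables {
            if naive_is_prime(i) {
                sizes.push(i);
            }
            if i == 1 {
                break;
            }
            i -= 2;
        }
        sizes
    }

    fn main() {
        // The three largest primes at or below 100 are 97, 89, and 83.
        assert_eq!(n_primes_near(100, 3), vec![97, 89, 83]);
    }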
-        let tablesizes: Vec<usize> = primal::Primes::all()
-            .filter(|p| *p >= tablesize)
-            .take(n_tables)
-            .collect();
+        let mut tablesizes = Vec::with_capacity(n_tables);
+
+        let mut i = (tablesize - 1) as u64;
+        if i % 2 == 0 {
+            i += 1
+        }
+
+        while tablesizes.len() != n_tables {
+            if primal::is_prime(i) {
+                tablesizes.push(i as usize);
+            }
+            if i == 1 {
+                break;
+            }
+            i -= 2;
+        }
 
-        Nodegraph::new(&tablesizes, ksize)
+        Nodegraph::new(tablesizes.as_slice(), ksize)
     }
 
     pub fn count(&mut self, hash: HashIntoType) -> bool {
@@ -85,23 +106,51 @@ impl Nodegraph {
         let mut new_bins = 0;
 
         for (bs, bs_other) in self.bs.iter_mut().zip(&other.bs) {
-            bs_other
-                .ones()
-                .map(|x| {
-                    if !bs.put(x) {
-                        new_bins += 1;
-                    }
-                })
-                .count();
+            bs_other.ones().for_each(|x| {
+                if !bs.put(x) {
+                    new_bins += 1;
+                }
+            });
         }
 
         // TODO: occupied bins seems to be broken in khmer? I don't get the same
         // values...
-        //self.occupied_bins += new_bins;
+        self.occupied_bins += new_bins;
+    }
+
+    pub fn expected_collisions(&self) -> f64 {
+        let min_size = self.bs.iter().map(|x| x.len()).min().unwrap();
+        let n_ht = self.bs.len();
+        let occupancy = self.occupied_bins;
+
+        let fp_one = occupancy / min_size;
+        f64::powf(fp_one as f64, n_ht as f64)
+    }
+
+    pub fn tablesize(&self) -> usize {
+        self.bs.iter().map(|x| x.len()).sum()
+    }
+
+    pub fn noccupied(&self) -> usize {
+        self.occupied_bins
+    }
+
+    pub fn matches(&self, mh: &KmerMinHash) -> usize {
+        mh.mins.iter().filter(|x| self.get(**x) == 1).count()
+    }
+
+    pub fn ntables(&self) -> usize {
+        self.bs.len()
+    }
+
+    pub fn ksize(&self) -> usize {
+        self.ksize
     }
 
     // save
     pub fn save<P: AsRef<Path>>(&self, path: P) -> Result<(), Error> {
-        self.save_to_writer(&mut File::open(path)?)?;
+        // TODO: if it ends with gz, open a compressed file
+        // might use get_output here?
+        self.save_to_writer(&mut File::create(path)?)?;
         Ok(())
     }
@@ -143,6 +192,9 @@ impl Nodegraph {
     where
         R: io::Read,
     {
+        // TODO: see https://github.com/brainstorm/bio-index-formats for an
+        // example of using nom to parse binary data.
+        // Use it here instead of byteorder
         let signature = rdr.read_u32::<BigEndian>()?;
         assert_eq!(signature, 0x4f58_4c49);
@@ -238,11 +290,14 @@ impl Nodegraph {
 #[cfg(test)]
 mod test {
     use super::*;
+    use cfg_if::cfg_if;
     use std::io::{BufReader, BufWriter};
     use std::path::PathBuf;
 
+    cfg_if! {
+        if #[cfg(not(target_arch = "wasm32"))] {
     use proptest::num::u64;
-    use proptest::{proptest, proptest_helper};
+    use proptest::{proptest};
 
     proptest! {
         #[test]
@@ -252,6 +307,8 @@ mod test {
             assert_eq!(ng.get(hash), 1);
         }
     }
+        }
+    }
 
     #[test]
     fn count_and_get_nodegraph() {
diff --git a/src/sketch/ukhs.rs b/src/sketch/ukhs.rs
index 47d2b4e543..03114631e0 100644
--- a/src/sketch/ukhs.rs
+++ b/src/sketch/ukhs.rs
@@ -1,3 +1,41 @@
+use failure::Error;
+use serde_derive::{Deserialize, Serialize};
+
+use crate::signature::SigsTrait;
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct FlatUKHS {}
+
+impl FlatUKHS {
+    pub fn md5sum(&self) -> String {
+        unimplemented!()
+    }
+}
+
+impl SigsTrait for FlatUKHS {
+    fn size(&self) -> usize {
+        unimplemented!()
+    }
+
+    fn to_vec(&self) -> Vec<u64> {
+        unimplemented!()
+    }
+
+    fn ksize(&self) -> usize {
+        unimplemented!()
+    }
+
+    fn check_compatible(&self, _other: &Self) -> Result<(), Error> {
+        unimplemented!()
+    }
+
+    fn add_sequence(&mut self, _seq: &[u8], _force: bool) -> Result<(), Error> {
+        unimplemented!()
+    }
+}
+
+/* FIXME bring back after succint-rs changes
+
 use std::f64::consts::PI;
 use std::fs::File;
 use std::hash::BuildHasherDefault;
@@ -5,18 +43,13 @@ use std::io::{BufReader, BufWriter, Read, Write};
 use std::mem;
 use std::path::Path;
 
-use failure::Error;
 use itertools::Itertools;
 use pdatastructs::hyperloglog::HyperLogLog;
-use serde::de::{Deserialize, Deserializer};
-use serde::ser::{Serialize, SerializeStruct, Serializer};
-use serde_derive::Deserialize;
 use ukhs;
 
 use crate::errors::SourmashError;
 use crate::index::sbt::NoHashHasher;
 use crate::index::storage::ToWriter;
-use crate::signature::SigsTrait;
 use crate::sketch::nodegraph::Nodegraph;
 
 #[derive(Clone)]
@@ -38,8 +71,7 @@ impl UKHS {
         md5_ctx.consume(self.ukhs.k().to_string());
         self.buckets
             .iter()
-            .map(|x| md5_ctx.consume(x.to_string()))
-            .count();
+            .for_each(|x| md5_ctx.consume(x.to_string()));
         format!("{:x}", md5_ctx.compute())
     }
 }
@@ -179,24 +211,22 @@ impl UKHSTrait for UKHS {
         // this is the cosine distance as defined by scipy
         //1. - d
 
-        /*
         // This is the weighted Jaccard distance
         // TODO: don't iterate twice...
-        let mins: u64 = self
-            .buckets
-            .iter()
-            .zip(other.buckets.iter())
-            .map(|(a, b)| u64::min(*a, *b))
-            .sum();
-        let maxs: u64 = self
-            .buckets
-            .iter()
-            .zip(other.buckets.iter())
-            .map(|(a, b)| u64::max(*a, *b))
-            .sum();
-
-        1. - (mins as f64 / maxs as f64)
-        */
+        //let mins: u64 = self
+        //    .buckets
+        //    .iter()
+        //    .zip(other.buckets.iter())
+        //    .map(|(a, b)| u64::min(*a, *b))
+        //    .sum();
+        //let maxs: u64 = self
+        //    .buckets
+        //    .iter()
+        //    .zip(other.buckets.iter())
+        //    .map(|(a, b)| u64::max(*a, *b))
+        //    .sum();
+        //
+        //1. - (mins as f64 / maxs as f64)
     }
 
     fn to_writer<W>(&self, writer: &mut W) -> Result<(), Error>
@@ -232,13 +262,13 @@ impl SigsTrait for UKHS {
         // TODO: is seq.len() > W?
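A few lines up, the distance function keeps only a commented-out weighted Jaccard variant: one minus the sum of elementwise minimums over the sum of elementwise maximums of the two bucket vectors. A minimal sketch of that formula, assuming equal-length bucket vectors with at least one nonzero entry:

    // Weighted Jaccard distance, as in the commented-out block above.
    fn weighted_jaccard_distance(a: &[u64], b: &[u64]) -> f64 {
        let mins: u64 = a.iter().zip(b).map(|(x, y)| *x.min(y)).sum();
        let maxs: u64 = a.iter().zip(b).map(|(x, y)| *x.max(y)).sum();
        1.0 - (mins as f64 / maxs as f64)
    }

    fn main() {
        // Identical bucket vectors are at distance 0...
        assert_eq!(weighted_jaccard_distance(&[1, 2, 3], &[1, 2, 3]), 0.0);
        // ...and vectors with disjoint support are at distance 1.
        assert_eq!(weighted_jaccard_distance(&[0, 4], &[4, 0]), 1.0);
    }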
         let it: Vec<(u64, u64)> = self.ukhs.hash_iter_sequence(seq)?.collect();
 
-        /* This one update every unikmer bucket with w_hash
-        it.into_iter()
-            .map(|(_, k_hash)| {
-                self.buckets[self.ukhs.query_bucket(k_hash).unwrap()] += 1;
-            })
-            .count();
-        */
+        // This one update every unikmer bucket with w_hash
+        //it.into_iter()
+        //    .map(|(_, k_hash)| {
+        //        self.buckets[self.ukhs.query_bucket(k_hash).unwrap()] += 1;
+        //    })
+        //    .count();
+        //
 
         // Only update the bucket for the minimum unikmer found
         for (_, group) in &it.into_iter().group_by(|(w, _)| *w) {
@@ -313,13 +343,12 @@ impl SigsTrait for UKHS {
     fn add_sequence(&mut self, seq: &[u8], _force: bool) -> Result<(), Error> {
         let it: Vec<(u64, u64)> = self.ukhs.hash_iter_sequence(seq)?.collect();
 
-        /* This one update every unikmer bucket with w_hash
-        it.into_iter()
-            .map(|(w_hash, k_hash)| {
-                self.buckets[self.ukhs.query_bucket(k_hash).unwrap()].count(w_hash);
-            })
-            .count();
-        */
+        // This one update every unikmer bucket with w_hash
+        //it.into_iter()
+        //    .map(|(w_hash, k_hash)| {
+        //        self.buckets[self.ukhs.query_bucket(k_hash).unwrap()].count(w_hash);
+        //    })
+        //    .count();
 
         // Only update the bucket for the minimum unikmer found
         for (w_hash, group) in &it.into_iter().group_by(|(w, _)| *w) {
@@ -412,13 +441,12 @@ impl SigsTrait for UKHS {
     fn add_sequence(&mut self, seq: &[u8], _force: bool) -> Result<(), Error> {
         let it: Vec<(u64, u64)> = self.ukhs.hash_iter_sequence(seq)?.collect();
 
-        /* This one update every unikmer bucket with w_hash
-        it.into_iter()
-            .map(|(w_hash, k_hash)| {
-                self.buckets[self.ukhs.query_bucket(k_hash).unwrap()].add(&w_hash);
-            })
-            .count();
-        */
+        // This one update every unikmer bucket with w_hash
+        //it.into_iter()
+        //    .map(|(w_hash, k_hash)| {
+        //        self.buckets[self.ukhs.query_bucket(k_hash).unwrap()].add(&w_hash);
+        //    })
+        //    .count();
 
         // Only update the bucket for the minimum unikmer found
         for (w_hash, group) in &it.into_iter().group_by(|(w, _)| *w) {
@@ -516,40 +544,37 @@ where
 
 // Removed this for now, because calling .into() in these doesn't
 // transfer all the important information...
-/*
-impl From<FlatUKHS> for Dataset<Signature> {
-    fn from(other: FlatUKHS) -> Dataset<Signature> {
-        let data = Lazy::new();
-        data.get_or_create(|| other.into());
-
-        Dataset::builder()
-            .data(Rc::new(data))
-            .filename("")
-            .name("")
-            .metadata("")
-            .storage(None)
-            .build()
-    }
-}
-
-impl From<FlatUKHS> for Signature {
-    fn from(other: FlatUKHS) -> Signature {
-        Signature::builder()
-            .hash_function("nthash") // TODO: spec!
-            .class("draff_signature") // TODO: spec!
-            .name(Some("draff_file".into())) // TODO: spec!
-            .signatures(vec![Sketch::UKHS(other)])
-            .build()
-    }
-}
-*/
+//impl From<FlatUKHS> for Dataset<Signature> {
+//    fn from(other: FlatUKHS) -> Dataset<Signature> {
+//        let data = Lazy::new();
+//        data.get_or_create(|| other.into());
+//
+//        Dataset::builder()
+//            .data(Rc::new(data))
+//            .filename("")
+//            .name("")
+//            .metadata("")
+//            .storage(None)
+//            .build()
+//    }
+//}
+//
+//impl From<FlatUKHS> for Signature {
+//    fn from(other: FlatUKHS) -> Signature {
+//        Signature::builder()
+//            .hash_function("nthash") // TODO: spec!
+//            .class("draff_signature") // TODO: spec!
+//            .name(Some("draff_file".into())) // TODO: spec!
+//            .signatures(vec![Sketch::UKHS(other)])
+//            .build()
+//    }
+//}
 
 #[cfg(test)]
 mod test {
     use std::path::PathBuf;
 
-    use bio::io::fasta::Reader;
-    use ocf::get_input;
+    use needletail::parse_sequence_path;
 
     use super::{FlatUKHS, MemberUKHS, UKHSTrait};
     use crate::signature::SigsTrait;
@@ -561,13 +586,14 @@ mod test {
 
         let mut ukhs = MemberUKHS::new(9, 21).unwrap();
 
-        let (input, _) = get_input(filename.to_str().unwrap()).unwrap();
-        let reader = Reader::new(input);
-
-        for record in reader.records() {
-            let record = record.unwrap();
-            ukhs.add_sequence(record.seq(), false).unwrap();
-        }
+        parse_sequence_path(
+            filename,
+            |_| {},
+            |record| {
+                ukhs.add_sequence(&record.seq, false).unwrap();
+            },
+        )
+        .expect("error parsing");
 
         // TODO: find test case...
         //assert_eq!(ukhs.to_vec(), [1, 2, 3]);
@@ -580,13 +606,14 @@ mod test {
 
         let mut ukhs = FlatUKHS::new(9, 21).unwrap();
 
-        let (input, _) = get_input(filename.to_str().unwrap()).unwrap();
-        let reader = Reader::new(input);
-
-        for record in reader.records() {
-            let record = record.unwrap();
-            ukhs.add_sequence(record.seq(), false).unwrap();
-        }
+        parse_sequence_path(
+            filename,
+            |_| {},
+            |record| {
+                ukhs.add_sequence(&record.seq, false).unwrap();
+            },
+        )
+        .expect("error parsing");
 
         let mut buffer = Vec::new();
         ukhs.to_writer(&mut buffer).unwrap();
@@ -602,3 +629,4 @@ mod test {
     }
 }
+*/
diff --git a/src/wasm.rs b/src/wasm.rs
index 7d7000dbf7..93318bad50 100644
--- a/src/wasm.rs
+++ b/src/wasm.rs
@@ -3,7 +3,7 @@ use wasm_bindgen::prelude::*;
 use serde_json;
 
 use crate::signature::SigsTrait;
-use crate::sketch::minhash::KmerMinHash;
+use crate::sketch::minhash::{HashFunctions, KmerMinHash};
 
 #[wasm_bindgen]
 impl KmerMinHash {
@@ -13,6 +13,7 @@ impl KmerMinHash {
         ksize: u32,
         is_protein: bool,
         dayhoff: bool,
+        hp: bool,
         seed: u32,
         scaled: u32,
         track_abundance: bool,
@@ -25,11 +26,22 @@ impl KmerMinHash {
             u64::max_value() / scaled as u64
         };
 
+        // TODO: at most one of (prot, dayhoff, hp) should be true
+
+        let hash_function = if dayhoff {
+            HashFunctions::murmur64_dayhoff
+        } else if hp {
+            HashFunctions::murmur64_hp
+        } else if is_protein {
+            HashFunctions::murmur64_protein
+        } else {
+            HashFunctions::murmur64_DNA
+        };
+
         KmerMinHash::new(
             num,
             ksize,
-            is_protein,
-            dayhoff,
+            hash_function,
             seed as u64,
             max_hash,
             track_abundance,
@@ -38,7 +50,8 @@ impl KmerMinHash {
 
     #[wasm_bindgen]
     pub fn add_sequence_js(&mut self, buf: &str) {
-        self.add_sequence(buf.as_bytes(), true);
+        self.add_sequence(buf.as_bytes(), true)
+            .expect("Error adding sequence");
     }
 
     #[wasm_bindgen]
diff --git a/tests/minhash.rs b/tests/minhash.rs
index 86a4b5909d..af0853e69f 100644
--- a/tests/minhash.rs
+++ b/tests/minhash.rs
@@ -1,9 +1,9 @@
 use sourmash::signature::SigsTrait;
-use sourmash::sketch::minhash::KmerMinHash;
+use sourmash::sketch::minhash::{HashFunctions, KmerMinHash};
 
 #[test]
 fn throws_error() {
-    let mut mh = KmerMinHash::new(1, 4, false, false, 42, 0, false);
+    let mut mh = KmerMinHash::new(1, 4, HashFunctions::murmur64_DNA, 42, 0, false);
 
     match mh.add_sequence(b"ATGR", false) {
         Ok(_) => assert!(false, "R is not a valid DNA character"),
@@ -13,8 +13,8 @@ fn throws_error() {
 
 #[test]
 fn merge() {
-    let mut a = KmerMinHash::new(20, 10, false, false, 42, 0, false);
-    let mut b = KmerMinHash::new(20, 10, false, false, 42, 0, false);
+    let mut a = KmerMinHash::new(20, 10, HashFunctions::murmur64_DNA, 42, 0, false);
+    let mut b = KmerMinHash::new(20, 10, HashFunctions::murmur64_DNA, 42, 0, false);
 
     a.add_sequence(b"TGCCGCCCAGCA", false).unwrap();
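The wasm constructor above folds the three booleans into one HashFunctions value, checking dayhoff first, then hp, then is_protein, since both encodings imply protein translation. A self-contained sketch of that selection; the enum here is a local stand-in for sourmash::sketch::minhash::HashFunctions:

    #[allow(non_camel_case_types)]
    #[derive(Debug, PartialEq)]
    enum HashFunctions {
        murmur64_DNA,
        murmur64_protein,
        murmur64_dayhoff,
        murmur64_hp,
    }

    // Mirrors the if/else chain in the wasm constructor shown above.
    fn select_hash_function(is_protein: bool, dayhoff: bool, hp: bool) -> HashFunctions {
        if dayhoff {
            HashFunctions::murmur64_dayhoff
        } else if hp {
            HashFunctions::murmur64_hp
        } else if is_protein {
            HashFunctions::murmur64_protein
        } else {
            HashFunctions::murmur64_DNA
        }
    }

    fn main() {
        assert_eq!(select_hash_function(false, false, false), HashFunctions::murmur64_DNA);
        // dayhoff wins even when is_protein is also set.
        assert_eq!(select_hash_function(true, true, false), HashFunctions::murmur64_dayhoff);
    }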
     b.add_sequence(b"TGCCGCCCAGCA", false).unwrap();
@@ -40,8 +40,8 @@ fn merge() {
 
 #[test]
 fn compare() {
-    let mut a = KmerMinHash::new(20, 10, false, false, 42, 0, false);
-    let mut b = KmerMinHash::new(20, 10, false, false, 42, 0, false);
+    let mut a = KmerMinHash::new(20, 10, HashFunctions::murmur64_DNA, 42, 0, false);
+    let mut b = KmerMinHash::new(20, 10, HashFunctions::murmur64_DNA, 42, 0, false);
 
     a.add_sequence(b"TGCCGCCCAGCACCGGGTGACTAGGTTGAGCCATGATTAACCTGCAATGA", false)
         .unwrap();
@@ -67,12 +67,12 @@ fn compare() {
 
 #[test]
 fn dayhoff() {
-    let mut a = KmerMinHash::new(10, 6, true, true, 42, 0, false);
-    let mut b = KmerMinHash::new(10, 6, true, false, 42, 0, false);
+    let mut a = KmerMinHash::new(10, 6, HashFunctions::murmur64_dayhoff, 42, 0, false);
+    let mut b = KmerMinHash::new(10, 6, HashFunctions::murmur64_protein, 42, 0, false);
 
     a.add_sequence(b"ACTGAC", false).unwrap();
     b.add_sequence(b"ACTGAC", false).unwrap();
 
-    assert_eq!(a.size(), 2);
+    assert_eq!(a.size(), 1);
     assert_eq!(b.size(), 2);
 }
diff --git a/tests/smrs_cmd.rs b/tests/smrs_cmd.rs
new file mode 100644
index 0000000000..c203099158
--- /dev/null
+++ b/tests/smrs_cmd.rs
@@ -0,0 +1,139 @@
+use std::fs;
+use std::process::Command;
+
+use assert_cmd::prelude::*;
+use predicates::prelude::*;
+use predicates::str::contains;
+use tempfile::TempDir;
+
+#[test]
+fn search() -> Result<(), Box<dyn std::error::Error>> {
+    let mut cmd = Command::cargo_bin("smrs")?;
+
+    cmd.arg("search")
+        .arg("tests/test-data/demo/SRR2060939_1.sig")
+        .arg("tests/test-data/v5.sbt.json")
+        .assert()
+        .success()
+        .stdout(contains("SRR2060939_1.fastq.gz"))
+        .stdout(contains("SRR2060939_2.fastq.gz"))
+        .stdout(contains("SRR2255622_1.fastq.gz"));
+
+    Ok(())
+}
+
+#[test]
+#[ignore]
+fn search_only_leaves() -> Result<(), Box<dyn std::error::Error>> {
+    let mut cmd = Command::cargo_bin("smrs")?;
+
+    cmd.arg("search")
+        .arg("tests/test-data/demo/SRR2060939_1.sig")
+        .arg("tests/test-data/leaves.sbt.json")
+        .assert()
+        .success()
+        .stdout(contains("SRR2060939_1.fastq.gz"))
+        .stdout(contains("SRR2060939_2.fastq.gz"))
+        .stdout(contains("SRR2255622_1.fastq.gz"));
+
+    Ok(())
+}
+
+#[test]
+#[ignore]
+#[cfg(unix)]
+fn compute_index_and_search() -> Result<(), Box<dyn std::error::Error>> {
+    let tmp_dir = TempDir::new()?;
+    fs::copy("tests/test-data/short.fa", tmp_dir.path().join("short.fa"))?;
+    fs::copy(
+        "tests/test-data/short2.fa",
+        tmp_dir.path().join("short2.fa"),
+    )?;
+
+    assert!(tmp_dir.path().join("short.fa").exists());
+    assert!(tmp_dir.path().join("short2.fa").exists());
+
+    let mut cmd = Command::new("sourmash");
+    cmd.arg("compute")
+        .args(&["short.fa", "short2.fa"])
+        .current_dir(&tmp_dir)
+        .assert()
+        .success();
+
+    assert!(tmp_dir.path().join("short.fa.sig").exists());
+    assert!(tmp_dir.path().join("short2.fa.sig").exists());
+
+    let mut cmd = Command::new("sourmash");
+    //let mut cmd = Command::cargo_bin("smrs")?;
+    cmd.arg("index")
+        .args(&["-k", "31"])
+        //.args(&["-o", "zzz.sbt.json"])
+        .arg("zzz.sbt.json")
+        .args(&["short.fa.sig", "short2.fa.sig"])
+        .current_dir(&tmp_dir)
+        .assert()
+        .success();
+
+    assert!(tmp_dir.path().join("zzz.sbt.json").exists());
+
+    let cmds = vec![Command::new("sourmash"), Command::cargo_bin("smrs")?];
+
+    for mut cmd in cmds {
+        cmd.arg("search")
+            .args(&["-k", "31"])
+            .arg("short.fa.sig")
+            .arg("zzz.sbt.json")
+            .current_dir(&tmp_dir)
+            .assert()
+            .success()
+            .stdout(contains("short.fa"))
+            .stdout(contains("short2.fa"));
+    }
+
+    Ok(())
+}
+
+#[test]
+#[cfg(unix)]
+fn index_and_search() -> Result<(), Box<dyn std::error::Error>> {
+    let tmp_dir = TempDir::new()?;
+    fs::copy(
"tests/test-data/demo/SRR2060939_1.sig", + tmp_dir.path().join("1.sig"), + )?; + fs::copy( + "tests/test-data/demo/SRR2060939_2.sig", + tmp_dir.path().join("2.sig"), + )?; + + assert!(tmp_dir.path().join("1.sig").exists()); + assert!(tmp_dir.path().join("2.sig").exists()); + + let mut cmd = Command::cargo_bin("smrs")?; + cmd.arg("index") + .args(&["-k", "31"]) + .args(&["-o", "zzz.sbt.json"]) + .args(&["1.sig", "2.sig"]) + .current_dir(&tmp_dir) + .assert() + .success(); + + assert!(tmp_dir.path().join("zzz.sbt.json").exists()); + + let cmds = vec![Command::new("sourmash"), Command::cargo_bin("smrs")?]; + + for mut cmd in cmds { + cmd.arg("search") + .args(&["-k", "31"]) + .arg("1.sig") + .arg("zzz.sbt.json") + .current_dir(&tmp_dir) + .assert() + .success() + .stdout(contains("2 matches:")) + .stdout(contains("SRR2060939_1.fastq.gz")) + .stdout(contains("SRR2060939_2.fastq.gz")); + } + + Ok(()) +} diff --git a/tox.ini b/tox.ini index cca92bcccd..a54618c780 100644 --- a/tox.ini +++ b/tox.ini @@ -4,14 +4,14 @@ envlist=py27,py35,py36,py37 [testenv] passenv = CI TRAVIS TRAVIS_* whitelist_externals= - make + make +extras = + test + doc + 10x + storage deps= - codecov - ipfshttpclient - redis - bam2fasta + codecov commands= - pip install -r requirements.txt - pip install -e .[test] - make coverage - codecov --gcov-glob third-party + make coverage + codecov --gcov-glob third-party