-
Notifications
You must be signed in to change notification settings - Fork 233
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[RUST] python (re)integration v1 (#436)
* Support limit offset as ExecNode * fmt * add test * fix merge errors * address comments * use type alias and remove static lifetime * fmt * [RUST] python integration hello world Hello world for python (re)integration from Rust package. Enables bare minimum: ```python import lance batch = lance.read_batch("/path/to/dataset.lance") lance.write_batch(batch, "/path/to/write/dataset.lance") ``` 1. Separate the Rust lib build and the python lib (e.g., cargo build doesn't seem to work as is) 2. Read batch succeeds but prints errors about tokio worker thread being cancelled 3. Currently each method call has to create its own tokio runtime, bad bad bad. 4. Function signatures / docstrings Our main point of integration with pyarrow is extending the Dataset and Scanner interface. In current python package we: 1. Expose appropriate classes/methods from C++ in the cython code. 2. Define cython class FileSystemDataset(Dataset) 3. Define methods that match the Dataset interface and call our own BuildScanner. For Rust, we have 2 options: 1. Manually define PyTypeInfo for pyarrow.dataset.Dataset and we can define everything in _lib.pyx using pyo3 2. Expose the bare minimum from pyo3 and then code up FileSystemDataset(Dataset) by hand in python Either way, it'll require some mixing of Rust/pyo3 generated python that can be imported and used by a hand-crafted python module on top. Option 2 seems much easier. What do we want to do if people want to use the native pa.dataset.write_dataset and pa.dataset.dataset interfaces? * separate packages * function signature * dataset/scanner classes exposed to python * can trick duckdb but segfaults * let's not fuck up pyarrow * round trip is functional * fmt * fmt * reorg imports * remove prints * add from LanceError for ArrowError * tests for generator * test limit offset * fmt * pass in scanner and dataset * fix merge errors * nearest but segfaults * fuck this shit it's fixed * fmt
- Loading branch information
1 parent
82587cc
commit b23a3f1
Showing
20 changed files
with
812 additions
and
22 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
.env |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,42 @@ | ||
[package] | ||
name = "pylance" | ||
version = "0.1.2" | ||
edition = "2021" | ||
authors = ["Lance Devs <[email protected]>"] | ||
description = "python wrapper for lance-rs" | ||
license = "Apache-2.0" | ||
repository = "https://github.com/eto-ai/lance" | ||
readme = "README.md" | ||
rust-version = "1.65" | ||
keywords = [ | ||
"data-format", | ||
"data-science", | ||
"machine-learning", | ||
"apache-arrow", | ||
"data-analytics" | ||
] | ||
categories = [ | ||
"database-implementations", | ||
"data-structures", | ||
"development-tools", | ||
"science" | ||
] | ||
|
||
|
||
[lib] | ||
name = "lance" | ||
crate-type = ["cdylib"] | ||
|
||
[dependencies] | ||
lance = { path = ".." } | ||
arrow-array = "31.0" | ||
arrow-data = "31.0" | ||
arrow-schema = "31.0" | ||
chrono = "0.4.23" | ||
tokio = { version = "1.23", features = ["rt-multi-thread"] } | ||
futures = "0.3" | ||
pyo3 = { version = "0.17.3", features = ["extension-module", "abi3-py38"] } | ||
arrow = { version = "31.0.0", features = ["pyarrow"] } | ||
|
||
[build-dependencies] | ||
prost-build = "0.11" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
../README.md |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
/// Build script using 'built' crate to generate build info. | ||
fn main() { | ||
#[cfg(feature = "build_info")] | ||
{ | ||
println!("cargo:rerun-if-changed=build.rs"); | ||
extern crate built; | ||
use std::env; | ||
use std::path::Path; | ||
|
||
let src = env::var("CARGO_MANIFEST_DIR").unwrap(); | ||
let dst = Path::new(&env::var("OUT_DIR").unwrap()).join("built.rs"); | ||
let mut opts = built::Options::default(); | ||
|
||
opts.set_dependencies(true).set_compiler(true).set_env(true); | ||
|
||
built::write_built_file_with_opts(&opts, Path::new(&src), &dst) | ||
.expect("Failed to acquire build-time information"); | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
[project] | ||
name = "pylance" | ||
dependencies = ["pyarrow>=10", "pandas", "numpy"] | ||
|
||
[tool.maturin] | ||
python-source = "python" | ||
|
||
[build-system] | ||
requires = ["maturin>=0.14,<0.15"] | ||
build-backend = "maturin" | ||
|
||
[project.optional-dependencies] | ||
tests = [ | ||
"pytest", | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
from .dataset import LanceDataset, write_dataset | ||
|
||
|
||
def dataset(uri: str) -> LanceDataset: | ||
return LanceDataset(uri) | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,211 @@ | ||
from __future__ import annotations | ||
from pathlib import Path | ||
from typing import Optional, Union | ||
|
||
import numpy as np | ||
import pandas as pd | ||
import pyarrow as pa | ||
import pyarrow.dataset | ||
|
||
from .lance import _Dataset, _Scanner, _write_dataset | ||
|
||
|
||
class LanceDataset: | ||
""" | ||
A dataset in Lance format where the data is stored at the given uri | ||
""" | ||
|
||
def __init__(self, uri: Union[str, Path]): | ||
if isinstance(uri, Path): | ||
uri = str(uri.absolute()) | ||
self._uri = uri | ||
self._ds = _Dataset(uri) | ||
|
||
@property | ||
def uri(self) -> str: | ||
""" | ||
The location of the data | ||
""" | ||
return self._uri | ||
|
||
def scanner(self, columns: Optional[list[str]] = None, | ||
limit: int = 0, offset: Optional[int] = None, | ||
nearest: Optional[dict] = None) -> LanceScanner: | ||
""" | ||
Return a Scanner that can support various pushdowns | ||
Parameters | ||
---------- | ||
columns: list of str, default None | ||
List of column names to be fetched. | ||
All columns if None or unspecified. | ||
limit: int, default 0 | ||
Fetch up to this many rows. All rows if 0 or unspecified. | ||
offset: int, default None | ||
Fetch starting with this row. 0 if None or unspecified. | ||
nearest: dict, default None | ||
Get the rows corresponding to the K most similar vectors | ||
nearest should look like { | ||
"columns": <embedding col name>, | ||
"q": <query vector as pa.Float32Array>, | ||
"k": 10 | ||
} | ||
""" | ||
return (ScannerBuilder(self) | ||
.columns(columns) | ||
.limit(limit) | ||
.offset(offset) | ||
.nearest(**(nearest or {})) | ||
.to_scanner()) | ||
|
||
@property | ||
def schema(self) -> pa.Schema: | ||
""" | ||
The pyarrow Schema for this dataset | ||
""" | ||
return self._ds.schema | ||
|
||
def to_table(self, columns: Optional[list[str]] = None, | ||
limit: int = 0, offset: Optional[int] = None, | ||
nearest: Optional[dict] = None) -> pa.Table: | ||
""" | ||
Read the data into memory and return a pyarrow Table. | ||
Parameters | ||
---------- | ||
columns: list of str, default None | ||
List of column names to be fetched. | ||
All columns if None or unspecified. | ||
limit: int, default 0 | ||
Fetch up to this many rows. All rows if 0 or unspecified. | ||
offset: int, default None | ||
Fetch starting with this row. 0 if None or unspecified. | ||
nearest: dict, default None | ||
Get the rows corresponding to the K most similar vectors | ||
nearest should look like { | ||
"columns": <embedding col name>, | ||
"q": <query vector as pa.Float32Array>, | ||
"k": 10 | ||
} | ||
""" | ||
return self.scanner( | ||
columns=columns, limit=limit, offset=offset, nearest=nearest | ||
).to_table() | ||
|
||
|
||
class ScannerBuilder: | ||
|
||
def __init__(self, ds: LanceDataset): | ||
self.ds = ds | ||
self._limit = 0 | ||
self._offset = None | ||
self._columns = None | ||
self._nearest = None | ||
|
||
def limit(self, n: int = 0) -> ScannerBuilder: | ||
if int(n) < 0: | ||
raise ValueError("Limit must be non-negative") | ||
self._limit = n | ||
return self | ||
|
||
def offset(self, n: Optional[int] = None) -> ScannerBuilder: | ||
if n is not None and int(n) < 0: | ||
raise ValueError("Offset must be non-negative") | ||
self._offset = n | ||
return self | ||
|
||
def columns(self, cols: Optional[list[str]] = None) -> ScannerBuilder: | ||
if cols is not None and len(cols) == 0: | ||
cols = None | ||
self._columns = cols | ||
return self | ||
|
||
def nearest(self, column: Optional[str] = None, | ||
q: Optional[pa.Float32Array] = None, | ||
k: Optional[int] = None) -> ScannerBuilder: | ||
if column is None or q is None: | ||
self._nearest = None | ||
return self | ||
|
||
if self.ds.schema.get_field_index(column) < 0: | ||
raise ValueError(f"Embedding column {column} not in dataset") | ||
if isinstance(q, (np.ndarray, list, tuple)): | ||
q = pa.Float32Array.from_pandas(q) | ||
if k is not None and int(k) <= 0: | ||
raise ValueError(f"Nearest-K must be > 0 but got {k}") | ||
self._nearest = { | ||
"column": column, | ||
"q": q, | ||
"k": k | ||
} | ||
return self | ||
|
||
def to_scanner(self) -> LanceScanner: | ||
scanner = self.ds._ds.scanner(self._columns, self._limit, self._offset, self._nearest) | ||
return LanceScanner(scanner) | ||
|
||
|
||
class LanceScanner: | ||
|
||
def __init__(self, scanner: _Scanner): | ||
self._scanner = scanner | ||
|
||
def to_table(self) -> pa.Table: | ||
""" | ||
Read the data into memory and return a pyarrow Table. | ||
""" | ||
return self.to_reader().read_all() | ||
|
||
def to_reader(self) -> pa.RecordBatchReader: | ||
return self._scanner.to_pyarrow() | ||
|
||
|
||
ReaderLike = Union[pa.Table, pa.dataset.Dataset, pa.dataset.Scanner, | ||
pa.RecordBatchReader, LanceDataset, LanceScanner] | ||
|
||
|
||
def write_dataset(data_obj: ReaderLike, uri: Union[str, Path], | ||
mode: str = "create", | ||
max_rows_per_file: int = 1024*1024, | ||
max_rows_per_group: int = 1024) -> bool: | ||
""" | ||
Write a given data_obj to the given uri | ||
Parameters | ||
---------- | ||
data_obj: Reader-like | ||
The data to be written. Acceptable types are: | ||
- Pyarrow Table, Dataset, Scanner, or RecordBatchReader | ||
- LanceDataset or LanceScanner | ||
uri: str or Path | ||
Where to write the dataset to (directory) | ||
mode: str | ||
create - create a new dataset (raises if uri already exists) | ||
overwrite - create a new snapshot version | ||
append - create a new version that is the concat of the input the | ||
latest version (raises if uri does not exist) | ||
max_rows_per_file: int, default 1024 * 1024 | ||
The max number of rows to write before starting a new file | ||
max_rows_per_group: int, default 1024 | ||
The max number of rows before starting a new group (in the same file) | ||
""" | ||
if isinstance(data_obj, pa.Table): | ||
reader = data_obj.to_reader() | ||
elif isinstance(data_obj, pa.dataset.Dataset): | ||
reader = pa.dataset.Scanner.from_dataset(data_obj).to_reader() | ||
elif isinstance(data_obj, pa.dataset.Scanner): | ||
reader = data_obj.to_reader() | ||
elif isinstance(data_obj, pa.RecordBatchReader): | ||
reader = data_obj | ||
else: | ||
raise TypeError(f"Unknown data_obj type {type(data_obj)}") | ||
# TODO add support for passing in LanceDataset and LanceScanner here | ||
|
||
params = { | ||
"mode": mode, | ||
"max_rows_per_file": max_rows_per_file, | ||
"max_rows_per_group": max_rows_per_group | ||
} | ||
if isinstance(uri, Path): | ||
uri = str(uri.absolute()) | ||
return _write_dataset(reader, str(uri), params) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,80 @@ | ||
import numpy as np | ||
import pandas as pd | ||
import pyarrow as pa | ||
import pyarrow.dataset | ||
import lance | ||
|
||
|
||
def test_table_roundtrip(tmp_path): | ||
uri = tmp_path | ||
|
||
df = pd.DataFrame({'a': range(100), 'b': range(100)}) | ||
tbl = pa.Table.from_pandas(df) | ||
lance.write_dataset(tbl, uri) | ||
|
||
dataset = lance.dataset(uri) | ||
assert dataset.uri == str(uri.absolute()) | ||
assert tbl.schema == dataset.schema | ||
assert tbl == dataset.to_table() | ||
|
||
one_col = dataset.to_table(columns=['a']) | ||
assert one_col == tbl.select(['a']) | ||
|
||
table = dataset.to_table(columns=['a'], limit=20) | ||
assert len(table) == 20 | ||
with_offset = dataset.to_table(columns=['a'], offset=10, limit=10) | ||
assert with_offset == table[10:] | ||
|
||
|
||
def test_input_types(tmp_path): | ||
# test all input types for write_dataset | ||
uri = tmp_path | ||
|
||
df = pd.DataFrame({'a': range(100), 'b': range(100)}) | ||
tbl = pa.Table.from_pandas(df) | ||
_check_roundtrip(tbl, uri / "table.lance", tbl) | ||
|
||
parquet_uri = str(uri / "dataset.parquet") | ||
pa.dataset.write_dataset(tbl, parquet_uri, format="parquet") | ||
ds = pa.dataset.dataset(parquet_uri) | ||
_check_roundtrip(ds, uri / "ds.lance", tbl) | ||
|
||
scanner = pa.dataset.Scanner.from_dataset(ds) | ||
_check_roundtrip(scanner, uri / "scanner.lance", tbl) | ||
|
||
reader = scanner.to_reader() | ||
_check_roundtrip(reader, uri / "reader.lance", tbl) | ||
|
||
# TODO allow Dataset::create to take both async and also RecordBatchReader | ||
# lance_dataset = lance.dataset(uri / "table.lance") | ||
# _check_roundtrip(lance_dataset, uri / "lance_dataset.lance", tbl) | ||
|
||
# lance_scanner = lance_dataset.scanner() | ||
# _check_roundtrip(lance_scanner, uri / "lance_scanner.lance", tbl) | ||
|
||
|
||
def _check_roundtrip(data_obj, uri, expected): | ||
lance.write_dataset(data_obj, uri) | ||
assert expected == lance.dataset(uri).to_table() | ||
|
||
|
||
def test_nearest(tmp_path): | ||
uri = tmp_path | ||
|
||
schema = pa.schema([pa.field("emb", pa.list_(pa.float32(), 32), False)]) | ||
npvals = np.random.rand(32 * 100) | ||
values = pa.array(npvals, type=pa.float32()) | ||
arr = pa.FixedSizeListArray.from_arrays(values, 32) | ||
tbl = pa.Table.from_arrays([arr], schema=schema) | ||
lance.write_dataset(tbl, uri) | ||
|
||
dataset = lance.dataset(uri) | ||
top10 = dataset.to_table(nearest={"column": "emb", "q": arr[0].values, "k": 10}) | ||
scores = l2sq(arr[0].values, npvals.reshape((100, 32))) | ||
indices = np.argsort(scores) | ||
assert tbl.take(indices[:10]).to_pandas().equals(top10.to_pandas()[["emb"]]) | ||
assert np.allclose(scores[indices[:10]], top10.to_pandas().score.values) | ||
|
||
|
||
def l2sq(vec, mat): | ||
return np.sum((mat - vec)**2, axis=1) |
Oops, something went wrong.