[RUST] python (re)integration v1 (#436)
* Support limit offset as ExecNode

* fmt

* add test

* fix merge errors

* address comments

* use type alias and remove static lifetime

* fmt

* [RUST] python integration hello world

Hello world for python (re)integration from the Rust package.
Enables the bare minimum:

```python
import lance
batch = lance.read_batch("/path/to/dataset.lance")
lance.write_batch(batch, "/path/to/write/dataset.lance")
```

Notes / known issues:
1. Separate the Rust lib build and the python lib build (e.g., `cargo build` doesn't seem to work as is)
2. Reading a batch succeeds but prints errors about the tokio worker thread being cancelled
3. Currently each method call has to create its own tokio runtime, which is far from ideal
4. Function signatures / docstrings still need to be fleshed out

Our main point of integration with pyarrow is extending the Dataset and Scanner interfaces. In the current python package we:
1. Expose the appropriate classes/methods from C++ in the cython code.
2. Define a cython class FileSystemDataset(Dataset).
3. Define methods that match the Dataset interface and call our own BuildScanner.

For Rust, we have 2 options:
1. Manually define PyTypeInfo for pyarrow.dataset.Dataset so that everything can be defined in _lib.pyx using pyo3
2. Expose the bare minimum from pyo3 and then code up FileSystemDataset(Dataset) by hand in python

Either way, it will require mixing Rust/pyo3-generated python that can be imported and used by a hand-crafted python module on top. Option 2 seems much easier; a rough sketch follows below.
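As an illustrative sketch of option 2 (not the final implementation; the `dataset.py` added further down takes a similar duck-typing approach), a hand-written Python layer could delegate to the pyo3-exposed `_Dataset` object along these lines:

```python
# Hypothetical sketch of option 2: hand-written Python on top of the
# pyo3-exposed objects. Class and method names here are illustrative.
from typing import Optional

import pyarrow as pa

from .lance import _Dataset  # compiled pyo3 extension module


class FileSystemDataset:
    """Duck-types the parts of pyarrow.dataset.Dataset that we need."""

    def __init__(self, uri: str):
        self._ds = _Dataset(uri)

    @property
    def schema(self) -> pa.Schema:
        return self._ds.schema

    def to_table(self, columns: Optional[list] = None) -> pa.Table:
        # Column selection is pushed down into the Rust scanner;
        # (columns, limit, offset, nearest) mirrors _Dataset.scanner().
        scanner = self._ds.scanner(columns, 0, None, None)
        return scanner.to_pyarrow().read_all()
```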

What do we want to do if people want to use the native pa.dataset.write_dataset and pa.dataset.dataset interfaces?

* separate packages

* function signature

* dataset/scanner classes exposed to python

* can trick duckdb but segfaults

* let's not break pyarrow

* round trip is functional

* fmt

* fmt

* reorg imports

* remove prints

* add from LanceError for ArrowError

* tests for generator

* test limit offset

* fmt

* pass in scanner and dataset

* fix merge errors

* nearest but segfaults

* fix the nearest segfault

* fmt
changhiskhan authored Jan 24, 2023
1 parent 82587cc commit b23a3f1
Showing 20 changed files with 812 additions and 22 deletions.
1 change: 1 addition & 0 deletions rust/.gitignore
@@ -0,0 +1 @@
.env
4 changes: 4 additions & 0 deletions rust/README.md
@@ -21,3 +21,7 @@
| Dictionary | Yes | No | No |
| RLE | No | No | No |


## Python integration

Python bindings are implemented via pyo3 under the `pylance` directory (the module is still imported as `lance`, though).
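For example, a minimal sanity check (assuming the extension module has been built and installed, e.g. via maturin; the path below is illustrative) might look like:

```python
# The package lives under rust/pylance, but the module is imported as `lance`.
import pandas as pd
import pyarrow as pa

import lance

tbl = pa.Table.from_pandas(pd.DataFrame({"a": range(10)}))
lance.write_dataset(tbl, "/tmp/demo.lance")
print(lance.dataset("/tmp/demo.lance").to_table())
```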
42 changes: 42 additions & 0 deletions rust/pylance/Cargo.toml
@@ -0,0 +1,42 @@
[package]
name = "pylance"
version = "0.1.2"
edition = "2021"
authors = ["Lance Devs <[email protected]>"]
description = "python wrapper for lance-rs"
license = "Apache-2.0"
repository = "https://github.com/eto-ai/lance"
readme = "README.md"
rust-version = "1.65"
keywords = [
"data-format",
"data-science",
"machine-learning",
"apache-arrow",
"data-analytics"
]
categories = [
"database-implementations",
"data-structures",
"development-tools",
"science"
]


[lib]
name = "lance"
crate-type = ["cdylib"]

[dependencies]
lance = { path = ".." }
arrow-array = "31.0"
arrow-data = "31.0"
arrow-schema = "31.0"
chrono = "0.4.23"
tokio = { version = "1.23", features = ["rt-multi-thread"] }
futures = "0.3"
pyo3 = { version = "0.17.3", features = ["extension-module", "abi3-py38"] }
arrow = { version = "31.0.0", features = ["pyarrow"] }

[build-dependencies]
prost-build = "0.11"
1 change: 1 addition & 0 deletions rust/pylance/README.md
20 changes: 20 additions & 0 deletions rust/pylance/build.rs
@@ -0,0 +1,20 @@
/// Build script using 'built' crate to generate build info.
fn main() {
    #[cfg(feature = "build_info")]
    {
        println!("cargo:rerun-if-changed=build.rs");
        extern crate built;
        use std::env;
        use std::path::Path;

        let src = env::var("CARGO_MANIFEST_DIR").unwrap();
        let dst = Path::new(&env::var("OUT_DIR").unwrap()).join("built.rs");
        let mut opts = built::Options::default();

        opts.set_dependencies(true).set_compiler(true).set_env(true);

        built::write_built_file_with_opts(&opts, Path::new(&src), &dst)
            .expect("Failed to acquire build-time information");
    }
}
15 changes: 15 additions & 0 deletions rust/pylance/pyproject.toml
@@ -0,0 +1,15 @@
[project]
name = "pylance"
dependencies = ["pyarrow>=10", "pandas", "numpy"]

[tool.maturin]
python-source = "python"

[build-system]
requires = ["maturin>=0.14,<0.15"]
build-backend = "maturin"

[project.optional-dependencies]
tests = [
"pytest",
]
6 changes: 6 additions & 0 deletions rust/pylance/python/lance/__init__.py
@@ -0,0 +1,6 @@
from .dataset import LanceDataset, write_dataset


def dataset(uri: str) -> LanceDataset:
    """Open the Lance dataset stored at the given uri."""
    return LanceDataset(uri)

211 changes: 211 additions & 0 deletions rust/pylance/python/lance/dataset.py
@@ -0,0 +1,211 @@
from __future__ import annotations
from pathlib import Path
from typing import Optional, Union

import numpy as np
import pandas as pd
import pyarrow as pa
import pyarrow.dataset

from .lance import _Dataset, _Scanner, _write_dataset


class LanceDataset:
    """
    A dataset in Lance format where the data is stored at the given uri.
    """

    def __init__(self, uri: Union[str, Path]):
        if isinstance(uri, Path):
            uri = str(uri.absolute())
        self._uri = uri
        self._ds = _Dataset(uri)

    @property
    def uri(self) -> str:
        """
        The location of the data.
        """
        return self._uri

    def scanner(self, columns: Optional[list[str]] = None,
                limit: int = 0, offset: Optional[int] = None,
                nearest: Optional[dict] = None) -> LanceScanner:
        """
        Return a Scanner that can support various pushdowns.

        Parameters
        ----------
        columns: list of str, default None
            List of column names to be fetched.
            All columns if None or unspecified.
        limit: int, default 0
            Fetch up to this many rows. All rows if 0 or unspecified.
        offset: int, default None
            Fetch starting with this row. 0 if None or unspecified.
        nearest: dict, default None
            Get the rows corresponding to the K most similar vectors.
            nearest should look like {
                "column": <embedding col name>,
                "q": <query vector as pa.Float32Array>,
                "k": 10
            }
        """
        return (ScannerBuilder(self)
                .columns(columns)
                .limit(limit)
                .offset(offset)
                .nearest(**(nearest or {}))
                .to_scanner())

    @property
    def schema(self) -> pa.Schema:
        """
        The pyarrow Schema for this dataset.
        """
        return self._ds.schema

    def to_table(self, columns: Optional[list[str]] = None,
                 limit: int = 0, offset: Optional[int] = None,
                 nearest: Optional[dict] = None) -> pa.Table:
        """
        Read the data into memory and return a pyarrow Table.

        Parameters
        ----------
        columns: list of str, default None
            List of column names to be fetched.
            All columns if None or unspecified.
        limit: int, default 0
            Fetch up to this many rows. All rows if 0 or unspecified.
        offset: int, default None
            Fetch starting with this row. 0 if None or unspecified.
        nearest: dict, default None
            Get the rows corresponding to the K most similar vectors.
            nearest should look like {
                "column": <embedding col name>,
                "q": <query vector as pa.Float32Array>,
                "k": 10
            }
        """
        return self.scanner(
            columns=columns, limit=limit, offset=offset, nearest=nearest
        ).to_table()


class ScannerBuilder:

    def __init__(self, ds: LanceDataset):
        self.ds = ds
        self._limit = 0
        self._offset = None
        self._columns = None
        self._nearest = None

    def limit(self, n: int = 0) -> ScannerBuilder:
        if int(n) < 0:
            raise ValueError("Limit must be non-negative")
        self._limit = n
        return self

    def offset(self, n: Optional[int] = None) -> ScannerBuilder:
        if n is not None and int(n) < 0:
            raise ValueError("Offset must be non-negative")
        self._offset = n
        return self

    def columns(self, cols: Optional[list[str]] = None) -> ScannerBuilder:
        if cols is not None and len(cols) == 0:
            cols = None
        self._columns = cols
        return self

    def nearest(self, column: Optional[str] = None,
                q: Optional[pa.Float32Array] = None,
                k: Optional[int] = None) -> ScannerBuilder:
        if column is None or q is None:
            self._nearest = None
            return self

        if self.ds.schema.get_field_index(column) < 0:
            raise ValueError(f"Embedding column {column} not in dataset")
        if isinstance(q, (np.ndarray, list, tuple)):
            q = pa.Float32Array.from_pandas(q)
        if k is not None and int(k) <= 0:
            raise ValueError(f"Nearest-K must be > 0 but got {k}")
        self._nearest = {
            "column": column,
            "q": q,
            "k": k
        }
        return self

    def to_scanner(self) -> LanceScanner:
        scanner = self.ds._ds.scanner(self._columns, self._limit, self._offset, self._nearest)
        return LanceScanner(scanner)


class LanceScanner:

    def __init__(self, scanner: _Scanner):
        self._scanner = scanner

    def to_table(self) -> pa.Table:
        """
        Read the data into memory and return a pyarrow Table.
        """
        return self.to_reader().read_all()

    def to_reader(self) -> pa.RecordBatchReader:
        return self._scanner.to_pyarrow()


ReaderLike = Union[pa.Table, pa.dataset.Dataset, pa.dataset.Scanner,
                   pa.RecordBatchReader, LanceDataset, LanceScanner]


def write_dataset(data_obj: ReaderLike, uri: Union[str, Path],
                  mode: str = "create",
                  max_rows_per_file: int = 1024*1024,
                  max_rows_per_group: int = 1024) -> bool:
    """
    Write a given data_obj to the given uri.

    Parameters
    ----------
    data_obj: Reader-like
        The data to be written. Acceptable types are:
        - Pyarrow Table, Dataset, Scanner, or RecordBatchReader
        - LanceDataset or LanceScanner
    uri: str or Path
        Where to write the dataset to (directory)
    mode: str
        create - create a new dataset (raises if uri already exists)
        overwrite - create a new snapshot version
        append - create a new version that is the concatenation of the input
        and the latest version (raises if uri does not exist)
    max_rows_per_file: int, default 1024 * 1024
        The max number of rows to write before starting a new file
    max_rows_per_group: int, default 1024
        The max number of rows before starting a new group (in the same file)
    """
    if isinstance(data_obj, pa.Table):
        reader = data_obj.to_reader()
    elif isinstance(data_obj, pa.dataset.Dataset):
        reader = pa.dataset.Scanner.from_dataset(data_obj).to_reader()
    elif isinstance(data_obj, pa.dataset.Scanner):
        reader = data_obj.to_reader()
    elif isinstance(data_obj, pa.RecordBatchReader):
        reader = data_obj
    else:
        raise TypeError(f"Unknown data_obj type {type(data_obj)}")
    # TODO add support for passing in LanceDataset and LanceScanner here

    params = {
        "mode": mode,
        "max_rows_per_file": max_rows_per_file,
        "max_rows_per_group": max_rows_per_group
    }
    if isinstance(uri, Path):
        uri = str(uri.absolute())
    return _write_dataset(reader, str(uri), params)
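Putting the pieces above together, the new Python surface can be exercised roughly like this (a sketch based on the classes in this file; the path and vector data are made up for illustration):

```python
import numpy as np
import pyarrow as pa

import lance

# 100 rows of 32-dimensional float32 embeddings in a fixed-size-list column.
values = pa.array(np.random.rand(32 * 100), type=pa.float32())
emb = pa.FixedSizeListArray.from_arrays(values, 32)
tbl = pa.Table.from_arrays([emb], names=["emb"])

lance.write_dataset(tbl, "/tmp/vectors.lance", mode="create")
ds = lance.dataset("/tmp/vectors.lance")

# Column / limit / offset pushdowns go through LanceDataset.scanner().
subset = ds.to_table(columns=["emb"], limit=10, offset=5)

# Nearest-neighbor search uses the `nearest` dict described in the docstrings;
# the result also carries a `score` column (see test_nearest below).
top10 = ds.to_table(nearest={"column": "emb", "q": emb[0].values, "k": 10})
```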
80 changes: 80 additions & 0 deletions rust/pylance/python/tests/test_lance.py
@@ -0,0 +1,80 @@
import numpy as np
import pandas as pd
import pyarrow as pa
import pyarrow.dataset
import lance


def test_table_roundtrip(tmp_path):
    uri = tmp_path

    df = pd.DataFrame({'a': range(100), 'b': range(100)})
    tbl = pa.Table.from_pandas(df)
    lance.write_dataset(tbl, uri)

    dataset = lance.dataset(uri)
    assert dataset.uri == str(uri.absolute())
    assert tbl.schema == dataset.schema
    assert tbl == dataset.to_table()

    one_col = dataset.to_table(columns=['a'])
    assert one_col == tbl.select(['a'])

    table = dataset.to_table(columns=['a'], limit=20)
    assert len(table) == 20
    with_offset = dataset.to_table(columns=['a'], offset=10, limit=10)
    assert with_offset == table[10:]


def test_input_types(tmp_path):
    # test all input types for write_dataset
    uri = tmp_path

    df = pd.DataFrame({'a': range(100), 'b': range(100)})
    tbl = pa.Table.from_pandas(df)
    _check_roundtrip(tbl, uri / "table.lance", tbl)

    parquet_uri = str(uri / "dataset.parquet")
    pa.dataset.write_dataset(tbl, parquet_uri, format="parquet")
    ds = pa.dataset.dataset(parquet_uri)
    _check_roundtrip(ds, uri / "ds.lance", tbl)

    scanner = pa.dataset.Scanner.from_dataset(ds)
    _check_roundtrip(scanner, uri / "scanner.lance", tbl)

    reader = scanner.to_reader()
    _check_roundtrip(reader, uri / "reader.lance", tbl)

    # TODO allow Dataset::create to take both async and also RecordBatchReader
    # lance_dataset = lance.dataset(uri / "table.lance")
    # _check_roundtrip(lance_dataset, uri / "lance_dataset.lance", tbl)

    # lance_scanner = lance_dataset.scanner()
    # _check_roundtrip(lance_scanner, uri / "lance_scanner.lance", tbl)


def _check_roundtrip(data_obj, uri, expected):
    lance.write_dataset(data_obj, uri)
    assert expected == lance.dataset(uri).to_table()


def test_nearest(tmp_path):
    uri = tmp_path

    schema = pa.schema([pa.field("emb", pa.list_(pa.float32(), 32), False)])
    npvals = np.random.rand(32 * 100)
    values = pa.array(npvals, type=pa.float32())
    arr = pa.FixedSizeListArray.from_arrays(values, 32)
    tbl = pa.Table.from_arrays([arr], schema=schema)
    lance.write_dataset(tbl, uri)

    dataset = lance.dataset(uri)
    top10 = dataset.to_table(nearest={"column": "emb", "q": arr[0].values, "k": 10})
    scores = l2sq(arr[0].values, npvals.reshape((100, 32)))
    indices = np.argsort(scores)
    assert tbl.take(indices[:10]).to_pandas().equals(top10.to_pandas()[["emb"]])
    assert np.allclose(scores[indices[:10]], top10.to_pandas().score.values)


def l2sq(vec, mat):
    # squared L2 distance from vec to each row of mat
    return np.sum((mat - vec)**2, axis=1)