Skip to content

Commit

Permalink
feat: allow returning blob fragments from write_fragments
Browse files Browse the repository at this point in the history
  • Loading branch information
fecet committed Dec 13, 2024
1 parent 99ae761 commit 71c464c
Show file tree
Hide file tree
Showing 3 changed files with 82 additions and 2 deletions.
20 changes: 19 additions & 1 deletion python/python/lance/fragment.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@

from .dependencies import _check_for_pandas
from .dependencies import pandas as pd
from .lance import _Fragment, _write_fragments
from .lance import _Fragment, _write_fragments, _write_fragments_with_blobs
from .lance import _FragmentMetadata as _FragmentMetadata
from .progress import FragmentWriteProgress, NoopFragmentWriteProgress
from .udf import BatchUDF, normalize_transform
Expand Down Expand Up @@ -498,6 +498,7 @@ def write_fragments(
data_storage_version: Optional[str] = None,
use_legacy_format: Optional[bool] = None,
storage_options: Optional[Dict[str, str]] = None,
with_blobs: bool = False,
) -> List[FragmentMetadata]:
"""
Write data into one or more fragments.
Expand Down Expand Up @@ -582,6 +583,23 @@ def write_fragments(
else:
data_storage_version = "stable"

if with_blobs:
default_fragments, blob_fragments = _write_fragments_with_blobs(
dataset_uri,
reader,
mode=mode,
max_rows_per_file=max_rows_per_file,
max_rows_per_group=max_rows_per_group,
max_bytes_per_file=max_bytes_per_file,
progress=progress,
data_storage_version=data_storage_version,
storage_options=storage_options,
)

return [FragmentMetadata.from_metadata(frag) for frag in default_fragments], [
FragmentMetadata.from_metadata(frag) for frag in blob_fragments
]

fragments = _write_fragments(
dataset_uri,
reader,
Expand Down
61 changes: 61 additions & 0 deletions python/src/fragment.rs
Original file line number Diff line number Diff line change
Expand Up @@ -513,6 +513,67 @@ pub fn write_fragments(
.collect()
}

#[pyfunction(name = "_write_fragments_with_blobs")]
#[pyo3(signature = (dest, reader, **kwargs))]
pub fn write_fragments_with_blobs(
dest: &Bound<PyAny>,
reader: &Bound<PyAny>,
kwargs: Option<&PyDict>,
) -> PyResult<(Vec<FragmentMetadata>, Vec<FragmentMetadata>)> {
let batches = convert_reader(reader)?;

let params = kwargs
.and_then(|params| get_write_params(params).transpose())
.transpose()?
.unwrap_or_default();

let dest = if dest.is_instance_of::<Dataset>() {
let dataset: Dataset = dest.extract()?;
WriteDestination::Dataset(dataset.ds.clone())
} else {
WriteDestination::Uri(dest.extract()?)
};

let written = RT
.block_on(
Some(reader.py()),
InsertBuilder::new(dest)
.with_params(&params)
.execute_uncommitted_stream(batches),
)?
.map_err(|err| PyIOError::new_err(err.to_string()))?;

let get_fragments = |operation: Operation| match operation {
Operation::Overwrite { fragments, .. } => Ok(fragments),
Operation::Append { fragments, .. } => Ok(fragments),
_ => Err(Error::Internal {
message: "Unexpected operation".into(),
location: location!(),
}),
};

let default_fragments =
get_fragments(written.operation).map_err(|err| PyRuntimeError::new_err(err.to_string()))?;

let blob_fragments = if let Some(blob_op) = written.blobs_op {
get_fragments(blob_op).map_err(|err| PyRuntimeError::new_err(err.to_string()))?
} else {
Vec::new()
};

let default_meta: Vec<_> = default_fragments
.into_iter()
.map(FragmentMetadata::new)
.collect();

let blob_meta: Vec<_> = blob_fragments
.into_iter()
.map(FragmentMetadata::new)
.collect();

Ok((default_meta, blob_meta))
}

fn convert_reader(reader: &Bound<PyAny>) -> PyResult<Box<dyn RecordBatchReader + Send + 'static>> {
if reader.is_instance_of::<Scanner>() {
let scanner: Scanner = reader.extract()?;
Expand Down
3 changes: 2 additions & 1 deletion python/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ pub(crate) mod tracing;
pub(crate) mod utils;

pub use crate::arrow::{bfloat16_array, BFloat16};
use crate::fragment::write_fragments;
use crate::fragment::{write_fragments, write_fragments_with_blobs};
pub use crate::tracing::{trace_to_chrome, TraceGuard};
use crate::utils::Hnsw;
use crate::utils::KMeans;
Expand Down Expand Up @@ -139,6 +139,7 @@ fn lance(py: Python, m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_wrapped(wrap_pyfunction!(bfloat16_array))?;
m.add_wrapped(wrap_pyfunction!(write_dataset))?;
m.add_wrapped(wrap_pyfunction!(write_fragments))?;
m.add_wrapped(wrap_pyfunction!(write_fragments_with_blobs))?;
m.add_wrapped(wrap_pyfunction!(schema_to_json))?;
m.add_wrapped(wrap_pyfunction!(json_to_schema))?;
m.add_wrapped(wrap_pyfunction!(infer_tfrecord_schema))?;
Expand Down

0 comments on commit 71c464c

Please sign in to comment.