Finish prototype
xinlifoobar committed Jul 16, 2024
1 parent 1676e93 commit a73167a
Showing 6 changed files with 162 additions and 173 deletions.
27 changes: 13 additions & 14 deletions datafusion/functions/Cargo.toml
@@ -31,18 +31,6 @@ rust-version = { workspace = true }
[lints]
workspace = true

[profile.dev]
codegen-units = 1

[profile.release]
codegen-units = 1

[profile.bench]
codegen-units = 1

[profile.test]
codegen-units = 1

[features]
# enable core functions
core_expressions = []
@@ -58,6 +46,7 @@ default = [
"regex_expressions",
"string_expressions",
"unicode_expressions",
"arrow_udf",
]
# enable encode/decode functions
encoding_expressions = ["base64", "hex"]
@@ -70,6 +59,15 @@ string_expressions = ["uuid"]
# enable unicode functions
unicode_expressions = ["hashbrown", "unicode-segmentation"]

arrow_udf = [
"global_registry",
"arrow-string",
]

global_registry = [
"arrow-udf",
]

[lib]
name = "datafusion_functions"
path = "src/lib.rs"
@@ -79,8 +77,8 @@ path = "src/lib.rs"
[dependencies]
arrow = { workspace = true }
arrow-buffer = { workspace = true }
arrow-udf = { version="0.3.0", features = ["global_registry"] }
linkme = { version = "0.3.27"}
arrow-udf = { workspace = true, optional = true, features = ["global_registry"] }
arrow-string = { workspace = true, optional = true }
base64 = { version = "0.22", optional = true }
blake2 = { version = "^0.10.2", optional = true }
blake3 = { version = "1.0", optional = true }
@@ -91,6 +89,7 @@ datafusion-expr = { workspace = true }
hashbrown = { workspace = true, optional = true }
hex = { version = "0.4", optional = true }
itertools = { workspace = true }
linkme = { version = "0.3.27" }
log = { workspace = true }
md-5 = { version = "^0.10.0", optional = true }
rand = { workspace = true }
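The new `arrow_udf` feature is optional and pulls in the now-optional arrow-udf and arrow-string dependencies; the lib.rs change below only compiles the udf module when the feature is enabled, so downstream code must gate on the same flag. A minimal sketch, assuming a dependent crate that forwards the feature:

    // Sketch: mirror the cfg used in datafusion/functions/src/lib.rs.
    // `apply_udf` is the entry point added in src/udf.rs in this commit.
    #[cfg(feature = "arrow_udf")]
    pub use datafusion_functions::udf::apply_udf;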
1 change: 1 addition & 0 deletions datafusion/functions/src/lib.rs
@@ -133,6 +133,7 @@ make_stub_package!(unicode, "unicode_expressions");
#[cfg(any(feature = "datetime_expressions", feature = "unicode_expressions"))]
pub mod planner;

#[cfg(feature = "arrow_udf")]
pub mod udf;

mod utils;
6 changes: 0 additions & 6 deletions datafusion/functions/src/string/starts_with.rs
@@ -20,7 +20,6 @@ use std::sync::Arc;

use arrow::array::{ArrayRef, OffsetSizeTrait};
use arrow::datatypes::DataType;
use arrow_udf::function;

use datafusion_common::{cast::as_generic_string_array, internal_err, Result};
use datafusion_expr::ColumnarValue;
@@ -29,11 +28,6 @@ use datafusion_expr::{ScalarUDFImpl, Signature, Volatility};

use crate::utils::make_scalar_function;

#[function("starts_with(string, string) -> bool")]
fn starts_with_udf(left: &str, right: &str) -> bool {
left.starts_with(right)
}

/// Returns true if string starts with prefix.
/// starts_with('alphabet', 'alph') = 't'
pub fn starts_with<T: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> {
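The attribute-based definition removed here follows the same registration pattern as the functions kept in src/udf.rs. A sketch of a feature-gated equivalent (hypothetical module name, using the `-> boolean` spelling that udf.rs uses):

    #[cfg(feature = "arrow_udf")]
    mod arrow_udf_starts_with {
        use arrow_udf::function;

        // Same body as the removed function; #[function] registers it into
        // the global registry via linkme (a sketch, not the committed code).
        #[function("starts_with(string, string) -> boolean")]
        fn starts_with_udf(s: &str, prefix: &str) -> bool {
            s.starts_with(prefix)
        }
    }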
190 changes: 105 additions & 85 deletions datafusion/functions/src/udf.rs
@@ -15,102 +15,122 @@
// specific language governing permissions and limitations
// under the License.

use arrow_udf::function;
use std::sync::Arc;

use arrow::{
array::{Array, RecordBatch},
datatypes::{Field, Schema, SchemaRef},
};
use arrow_udf::{function, sig::REGISTRY};
use datafusion_common::{internal_err, DataFusionError, Result};
use datafusion_expr::ColumnarValue;
// use arrow_string::predicate::Predicate;

#[function("eq(boolean, boolean) -> boolean")]
fn eq(lhs: bool, rhs: bool) -> bool {
#[function("eq(int8, int8) -> boolean")]
#[function("eq(int16, int16) -> boolean")]
#[function("eq(int32, int32) -> boolean")]
#[function("eq(int64, int64) -> boolean")]
#[function("eq(uint8, uint8) -> boolean")]
#[function("eq(uint16, uint16) -> boolean")]
#[function("eq(uint32, uint32) -> boolean")]
#[function("eq(uint64, uint64) -> boolean")]
#[function("eq(string, string) -> boolean")]
#[function("eq(binary, binary) -> boolean")]
#[function("eq(largestring, largestring) -> boolean")]
#[function("eq(largebinary, largebinary) -> boolean")]
#[function("eq(date32, date32) -> boolean")]
// #[function("eq(struct Dictionary, struct Dictionary) -> boolean")]
fn eq<T: Eq>(lhs: T, rhs: T) -> bool {
lhs == rhs
}

#[function("gcd(int, int) -> int", output = "eval_gcd")]
fn gcd(mut a: i32, mut b: i32) -> i32 {
while b != 0 {
(a, b) = (b, a % b);
}
a
}

#[cfg(test)]
mod tests {
use std::{sync::Arc, vec};
// Bad: we cannot use arrow-string's non-public `Predicate` API here:
// fn like(lhs: &str, rhs: &str) -> bool {
// Predicate::like(rhs).unwrap().matches(lhs);
// }

use arrow::{
array::{BooleanArray, RecordBatch},
datatypes::{Field, Schema},
};
use arrow_udf::sig::REGISTRY;

#[test]
fn test_eq() {
let bool_field = Field::new("", arrow::datatypes::DataType::Boolean, false);
let schema = Schema::new(vec![bool_field.clone()]);
let record_batch = RecordBatch::try_new(
Arc::new(schema),
vec![Arc::new(BooleanArray::from(vec![true, false, true]))],
)
.unwrap();
pub fn apply_udf(
lhs: &ColumnarValue,
rhs: &ColumnarValue,
return_field: &Field,
udf_name: &str,
) -> Result<ColumnarValue> {
let (record_batch, schema) = match (lhs, rhs) {
(ColumnarValue::Array(left), ColumnarValue::Array(right)) => {
let schema = Arc::new(Schema::new(vec![
Field::new("", left.data_type().clone(), left.is_nullable()),
Field::new("", right.data_type().clone(), right.is_nullable()),
]));
let record_batch =
RecordBatch::try_new(schema.clone(), vec![left.clone(), right.clone()])?;
Ok::<(RecordBatch, SchemaRef), DataFusionError>((record_batch, schema))
}
(ColumnarValue::Scalar(left), ColumnarValue::Array(right)) => {
let schema = Arc::new(Schema::new(vec![
Field::new("", left.data_type().clone(), false),
Field::new("", right.data_type().clone(), right.is_nullable()),
]));
let record_batch = RecordBatch::try_new(
schema.clone(),
vec![left.to_array_of_size(right.len())?, right.clone()],
)?;
Ok((record_batch, schema))
}
(ColumnarValue::Array(left), ColumnarValue::Scalar(right)) => {
let schema = Arc::new(Schema::new(vec![
Field::new("", left.data_type().clone(), left.is_nullable()),
Field::new("", right.data_type().clone(), false),
]));
let record_batch = RecordBatch::try_new(
schema.clone(),
vec![left.clone(), right.to_array_of_size(left.len())?],
)?;
Ok((record_batch, schema))
}
(ColumnarValue::Scalar(left), ColumnarValue::Scalar(right)) => {
let schema = Arc::new(Schema::new(vec![
Field::new("", left.data_type().clone(), false),
Field::new("", right.data_type().clone(), false),
]));
let record_batch = RecordBatch::try_new(
schema.clone(),
vec![left.to_array()?, right.to_array()?],
)?;
Ok((record_batch, schema))
}
}?;

println!("Function signatures:");
REGISTRY.iter().for_each(|sig| {
println!("{:?}", sig.name);
println!("{:?}", sig.arg_types);
println!("{:?}", sig.return_type);
});

let eval_eq_boolean = REGISTRY
.get("eq", &[bool_field.clone(), bool_field.clone()], &bool_field)
.unwrap()
.function
.as_scalar()
.unwrap();

let result = eval_eq_boolean(&record_batch).unwrap();
apply_udf_inner(schema, &record_batch, return_field, udf_name)
}

assert!(result
.column(0)
.as_any()
.downcast_ref::<BooleanArray>()
.unwrap()
.value(0));
}
fn apply_udf_inner(
schema: SchemaRef,
record_batch: &RecordBatch,
return_field: &Field,
udf_name: &str,
) -> Result<ColumnarValue> {
println!("schema: {:?}", schema);

#[test]
fn test_gcd() {
let int_field = Field::new("", arrow::datatypes::DataType::Int32, false);
let schema = Schema::new(vec![int_field.clone(), int_field.clone()]);
let record_batch = RecordBatch::try_new(
Arc::new(schema),
vec![
Arc::new(arrow::array::Int32Array::from(vec![10, 20, 30])),
Arc::new(arrow::array::Int32Array::from(vec![20, 30, 40])),
],
let Some(eval) = REGISTRY
.get(
udf_name,
schema
.all_fields()
.into_iter()
.map(|f| f.to_owned())
.collect::<Vec<_>>()
.as_slice(),
return_field,
)
.unwrap();

println!("Function signatures:");
REGISTRY.iter().for_each(|sig| {
println!("{:?}", sig.name);
println!("{:?}", sig.arg_types);
println!("{:?}", sig.return_type);
});
.and_then(|f| f.function.as_scalar())
else {
return internal_err!("UDF {} not found for schema {}", udf_name, schema);
};

let eval_gcd_int = REGISTRY
.get("gcd", &[int_field.clone(), int_field.clone()], &int_field)
.unwrap()
.function
.as_scalar()
.unwrap();
let result = eval(record_batch)?;

let result = eval_gcd_int(&record_batch).unwrap();
let result_array = result.column_by_name(udf_name).unwrap();

assert_eq!(
result
.column(0)
.as_any()
.downcast_ref::<arrow::array::Int32Array>()
.unwrap()
.value(0),
10
);
}
Ok(ColumnarValue::Array(Arc::clone(result_array)))
}
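A minimal sketch of driving the new entry point, assuming the `eq(int32, int32) -> boolean` registration above and that the generated kernel names its output column after the function, as `apply_udf_inner` expects:

    use std::sync::Arc;

    use arrow::array::Int32Array;
    use arrow::datatypes::{DataType, Field};
    use datafusion_common::Result;
    use datafusion_expr::ColumnarValue;
    use datafusion_functions::udf::apply_udf;

    fn demo_eq() -> Result<ColumnarValue> {
        // Array/array case; apply_udf also broadcasts scalars through
        // to_array_of_size, per the match arms above.
        let lhs = ColumnarValue::Array(Arc::new(Int32Array::from(vec![1, 2, 3])));
        let rhs = ColumnarValue::Array(Arc::new(Int32Array::from(vec![1, 0, 3])));
        // The argument types plus this return field drive the registry lookup.
        let return_field = Field::new("eq", DataType::Boolean, false);
        apply_udf(&lhs, &rhs, &return_field, "eq")
    }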
5 changes: 4 additions & 1 deletion datafusion/physical-expr/Cargo.toml
@@ -39,9 +39,13 @@ path = "src/lib.rs"
default = [
"regex_expressions",
"encoding_expressions",
"arrow_udf"
]
encoding_expressions = ["base64", "hex"]
regex_expressions = ["regex"]
arrow_udf = [
"datafusion-functions/arrow_udf",
]

[dependencies]
ahash = { workspace = true }
@@ -51,7 +55,6 @@ arrow-buffer = { workspace = true }
arrow-ord = { workspace = true }
arrow-schema = { workspace = true }
arrow-string = { workspace = true }
arrow-udf = { workspace = true, features = ["global_registry"] }
base64 = { version = "0.22", optional = true }
chrono = { workspace = true }
datafusion-common = { workspace = true, default-features = true }