Skip to content

Commit

Permalink
stage progress
Browse files Browse the repository at this point in the history
  • Loading branch information
xinlifoobar committed Jul 15, 2024
1 parent 983664a commit 1676e93
Show file tree
Hide file tree
Showing 6 changed files with 119 additions and 25 deletions.
4 changes: 0 additions & 4 deletions datafusion/core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,6 @@ default = [
"unicode_expressions",
"compression",
"parquet",
"arrow_udf",
]
encoding_expressions = ["datafusion-functions/encoding_expressions"]
# Used for testing ONLY: causes all values to hash to the same value (test for collisions)
Expand All @@ -76,9 +75,6 @@ unicode_expressions = [
"datafusion-sql/unicode_expressions",
"datafusion-functions/unicode_expressions",
]
arrow_udf = [
"datafusion-functions/arrow_udf",
]

[dependencies]
ahash = { workspace = true }
Expand Down
18 changes: 14 additions & 4 deletions datafusion/functions/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,18 @@ rust-version = { workspace = true }
[lints]
workspace = true

[profile.dev]
codegen-units = 1

[profile.release]
codegen-units = 1

[profile.bench]
codegen-units = 1

[profile.test]
codegen-units = 1

[features]
# enable core functions
core_expressions = []
Expand All @@ -46,7 +58,6 @@ default = [
"regex_expressions",
"string_expressions",
"unicode_expressions",
"arrow_udf",
]
# enable encode/decode functions
encoding_expressions = ["base64", "hex"]
Expand All @@ -58,8 +69,6 @@ regex_expressions = ["regex"]
string_expressions = ["uuid"]
# enable unicode functions
unicode_expressions = ["hashbrown", "unicode-segmentation"]
# enable arrow_udf
arrow_udf = ["arrow-udf",]

[lib]
name = "datafusion_functions"
Expand All @@ -70,7 +79,8 @@ path = "src/lib.rs"
[dependencies]
arrow = { workspace = true }
arrow-buffer = { workspace = true }
arrow-udf = { workspace = true, optional = true, features = ["global_registry"] }
arrow-udf = { version="0.3.0", features = ["global_registry"] }
linkme = { version = "0.3.27"}
base64 = { version = "0.22", optional = true }
blake2 = { version = "^0.10.2", optional = true }
blake3 = { version = "1.0", optional = true }
Expand Down
1 change: 0 additions & 1 deletion datafusion/functions/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,6 @@ make_stub_package!(unicode, "unicode_expressions");
#[cfg(any(feature = "datetime_expressions", feature = "unicode_expressions"))]
pub mod planner;

#[cfg(feature = "arrow_udf")]
pub mod udf;

mod utils;
Expand Down
103 changes: 96 additions & 7 deletions datafusion/functions/src/udf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,100 @@

use arrow_udf::function;

#[function("eq(bool, bool) -> bool", output="eval_eq_boolean")]
#[function("eq(string, string) -> bool", output="eval_eq_string")]
#[function("eq(binary, binary) -> bool", output="eval_eq_binary")]
#[function("eq(largestring, largestring) -> bool", output="eval_eq_largestring")]
#[function("eq(largebinary, largebinary) -> bool", output="eval_eq_largebinary")]
fn eq<T: std::cmp::Eq>(_lhs: T, _rhs: T) -> bool {
_lhs == _rhs
#[function("eq(boolean, boolean) -> boolean")]
/// Returns `true` when both boolean operands hold the same value.
fn eq(lhs: bool, rhs: bool) -> bool {
    // XOR is true exactly when the operands differ, so its negation is equality.
    !(lhs ^ rhs)
}

#[function("gcd(int, int) -> int", output = "eval_gcd")]
/// Computes the greatest common divisor of `a` and `b` using the Euclidean
/// algorithm.
///
/// `gcd(x, 0)` returns `x` unchanged and `gcd(0, 0)` returns `0`. The sign of
/// the result is not normalized: it follows the remainder sequence, so
/// negative inputs can produce a negative result.
///
/// This function never panics, including for the `i32::MIN` / `-1` pair.
fn gcd(mut a: i32, mut b: i32) -> i32 {
    while b != 0 {
        // `i32::MIN % -1` overflows and panics with the plain `%` operator.
        // `wrapping_rem` returns 0 for that single case (the mathematically
        // correct remainder) and is identical to `%` for every other input.
        (a, b) = (b, a.wrapping_rem(b));
    }
    a
}

#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use arrow::{
        array::{BooleanArray, Int32Array, RecordBatch},
        datatypes::{DataType, Field, Schema},
    };
    use arrow_udf::sig::REGISTRY;

    /// Looks up the boolean `eq` function in the global registry and checks
    /// its output for every row of a two-column input batch.
    #[test]
    fn test_eq() {
        let bool_field = Field::new("", DataType::Boolean, false);
        // `eq` takes two arguments, so the input batch must carry two columns
        // (one per argument); a single-column batch cannot be evaluated.
        let schema = Schema::new(vec![bool_field.clone(), bool_field.clone()]);
        let record_batch = RecordBatch::try_new(
            Arc::new(schema),
            vec![
                Arc::new(BooleanArray::from(vec![true, false, true])),
                Arc::new(BooleanArray::from(vec![true, true, false])),
            ],
        )
        .expect("columns match the two-field boolean schema");

        let eval_eq_boolean = REGISTRY
            .get("eq", &[bool_field.clone(), bool_field.clone()], &bool_field)
            .expect("eq(boolean, boolean) -> boolean is registered")
            .function
            .as_scalar()
            .expect("eq is a scalar function");

        let result = eval_eq_boolean(&record_batch).expect("eq evaluates");
        let column = result
            .column(0)
            .as_any()
            .downcast_ref::<BooleanArray>()
            .expect("eq returns a boolean array");

        // Check every row, not just the first one.
        let actual: Vec<bool> = (0..result.num_rows()).map(|i| column.value(i)).collect();
        assert_eq!(actual, vec![true, false, false]);
    }

    /// Looks up the `gcd` function in the global registry and checks its
    /// output for every row of a two-column Int32 batch.
    #[test]
    fn test_gcd() {
        let int_field = Field::new("", DataType::Int32, false);
        let schema = Schema::new(vec![int_field.clone(), int_field.clone()]);
        let record_batch = RecordBatch::try_new(
            Arc::new(schema),
            vec![
                Arc::new(Int32Array::from(vec![10, 20, 30])),
                Arc::new(Int32Array::from(vec![20, 30, 40])),
            ],
        )
        .expect("columns match the two-field Int32 schema");

        let eval_gcd_int = REGISTRY
            .get("gcd", &[int_field.clone(), int_field.clone()], &int_field)
            .expect("gcd(int, int) -> int is registered")
            .function
            .as_scalar()
            .expect("gcd is a scalar function");

        let result = eval_gcd_int(&record_batch).expect("gcd evaluates");
        let column = result
            .column(0)
            .as_any()
            .downcast_ref::<Int32Array>()
            .expect("gcd returns an Int32 array");

        let actual: Vec<i32> = (0..result.num_rows()).map(|i| column.value(i)).collect();
        assert_eq!(actual, vec![10, 10, 10]);
    }
}
6 changes: 1 addition & 5 deletions datafusion/physical-expr/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -42,10 +42,6 @@ default = [
]
encoding_expressions = ["base64", "hex"]
regex_expressions = ["regex"]
arrow_udf = [
"arrow-udf",
"datafusion-functions/arrow_udf",
]

[dependencies]
ahash = { workspace = true }
Expand All @@ -55,7 +51,7 @@ arrow-buffer = { workspace = true }
arrow-ord = { workspace = true }
arrow-schema = { workspace = true }
arrow-string = { workspace = true }
arrow-udf = { workspace = true, optional = true, features = ["global_registry"] }
arrow-udf = { workspace = true, features = ["global_registry"] }
base64 = { version = "0.22", optional = true }
chrono = { workspace = true }
datafusion-common = { workspace = true, default-features = true }
Expand Down
12 changes: 8 additions & 4 deletions datafusion/physical-expr/src/expressions/binary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -303,7 +303,7 @@ impl PhysicalExpr for BinaryExpr {
println!("schema: {:?}", schema);

let record_batch = RecordBatch::try_new(
schema.clone(),
Arc::clone(&schema),
vec![
lhs.clone().into_array(batch.num_rows())?,
rhs.clone().into_array(batch.num_rows())?,
Expand All @@ -315,15 +315,19 @@ impl PhysicalExpr for BinaryExpr {
let Some(eval_eq_string) = REGISTRY
.get(
"eq",
&schema
schema
.all_fields()
.into_iter()
.map(|f| f.to_owned())
.collect::<Vec<Field>>()
.as_slice(),
&Field::new("bool", DataType::Boolean, false),
)
.and_then(|f| f.function.as_scalar())
.and_then(|f| {
println!("Function found");

return f.function.as_scalar();
})
else {
return internal_err!("Failed to get eq function");
};
Expand All @@ -336,7 +340,7 @@ impl PhysicalExpr for BinaryExpr {
return internal_err!("Failed to get result array");
};

return Ok(ColumnarValue::Array(result_array.clone()));
return Ok(ColumnarValue::Array(Arc::clone(result_array)));
}
Operator::NotEq => return apply_cmp(&lhs, &rhs, neq),
Operator::Lt => return apply_cmp(&lhs, &rhs, lt),
Expand Down

0 comments on commit 1676e93

Please sign in to comment.