Skip to content

Commit

Permalink
Allow using dictionary arrays as filters (#12382)
Browse files Browse the repository at this point in the history
* Allow using dictionaries as filters

* revert, nested

* fmt
  • Loading branch information
adriangb authored Sep 10, 2024
1 parent c575bbf commit 8d2b240
Show file tree
Hide file tree
Showing 2 changed files with 118 additions and 3 deletions.
107 changes: 106 additions & 1 deletion datafusion/core/tests/dataframe/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,10 @@ use arrow::{
},
record_batch::RecordBatch,
};
use arrow_array::{Array, Float32Array, Float64Array, UnionArray};
use arrow_array::{
Array, BooleanArray, DictionaryArray, Float32Array, Float64Array, Int8Array,
UnionArray,
};
use arrow_buffer::ScalarBuffer;
use arrow_schema::{ArrowError, UnionFields, UnionMode};
use datafusion_functions_aggregate::count::count_udaf;
Expand Down Expand Up @@ -2363,3 +2366,105 @@ async fn dense_union_is_null() {
];
assert_batches_sorted_eq!(expected, &result_df.collect().await.unwrap());
}

#[tokio::test]
async fn boolean_dictionary_as_filter() {
let values = vec![Some(true), Some(false), None, Some(true)];
let keys = vec![0, 0, 1, 2, 1, 3, 1];
let values_array = BooleanArray::from(values);
let keys_array = Int8Array::from(keys);
let array =
DictionaryArray::new(keys_array, Arc::new(values_array) as Arc<dyn Array>);
let array = Arc::new(array);

let field = Field::new(
"my_dict",
DataType::Dictionary(Box::new(DataType::Int8), Box::new(DataType::Boolean)),
true,
);
let schema = Arc::new(Schema::new(vec![field]));

let batch = RecordBatch::try_new(schema, vec![array.clone()]).unwrap();

let ctx = SessionContext::new();

ctx.register_batch("dict_batch", batch).unwrap();

let df = ctx.table("dict_batch").await.unwrap();

// view_all
let expected = [
"+---------+",
"| my_dict |",
"+---------+",
"| true |",
"| true |",
"| false |",
"| |",
"| false |",
"| true |",
"| false |",
"+---------+",
];
assert_batches_eq!(expected, &df.clone().collect().await.unwrap());

let result_df = df.clone().filter(col("my_dict")).unwrap();
let expected = [
"+---------+",
"| my_dict |",
"+---------+",
"| true |",
"| true |",
"| true |",
"+---------+",
];
assert_batches_eq!(expected, &result_df.collect().await.unwrap());

// test nested dictionary
let keys = vec![0, 2]; // 0 -> true, 2 -> false
let keys_array = Int8Array::from(keys);
let nested_array = DictionaryArray::new(keys_array, array);

let field = Field::new(
"my_nested_dict",
DataType::Dictionary(
Box::new(DataType::Int8),
Box::new(DataType::Dictionary(
Box::new(DataType::Int8),
Box::new(DataType::Boolean),
)),
),
true,
);

let schema = Arc::new(Schema::new(vec![field]));

let batch = RecordBatch::try_new(schema, vec![Arc::new(nested_array)]).unwrap();

ctx.register_batch("nested_dict_batch", batch).unwrap();

let df = ctx.table("nested_dict_batch").await.unwrap();

// view_all
let expected = [
"+----------------+",
"| my_nested_dict |",
"+----------------+",
"| true |",
"| false |",
"+----------------+",
];

assert_batches_eq!(expected, &df.clone().collect().await.unwrap());

let result_df = df.clone().filter(col("my_nested_dict")).unwrap();
let expected = [
"+----------------+",
"| my_nested_dict |",
"+----------------+",
"| true |",
"+----------------+",
];

assert_batches_eq!(expected, &result_df.collect().await.unwrap());
}
14 changes: 12 additions & 2 deletions datafusion/expr/src/logical_plan/plan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2207,6 +2207,17 @@ impl Filter {
Self::try_new_internal(predicate, input, true)
}

fn is_allowed_filter_type(data_type: &DataType) -> bool {
match data_type {
// Interpret NULL as a missing boolean value.
DataType::Boolean | DataType::Null => true,
DataType::Dictionary(_, value_type) => {
Filter::is_allowed_filter_type(value_type.as_ref())
}
_ => false,
}
}

fn try_new_internal(
predicate: Expr,
input: Arc<LogicalPlan>,
Expand All @@ -2217,8 +2228,7 @@ impl Filter {
// construction (such as with correlated subqueries) so we make a best effort here and
// ignore errors resolving the expression against the schema.
if let Ok(predicate_type) = predicate.get_type(input.schema()) {
// Interpret NULL as a missing boolean value.
if predicate_type != DataType::Boolean && predicate_type != DataType::Null {
if !Filter::is_allowed_filter_type(&predicate_type) {
return plan_err!(
"Cannot create filter with non-boolean predicate '{predicate}' returning {predicate_type}"
);
Expand Down

0 comments on commit 8d2b240

Please sign in to comment.