From 08343ddb160208de67ea35393b8dd35f669af218 Mon Sep 17 00:00:00 2001 From: Xin Li Date: Thu, 22 Aug 2024 14:51:55 +0800 Subject: [PATCH] Add utf8 support and bench --- datafusion/functions/benches/regx.rs | 32 ++++- datafusion/functions/src/regex/mod.rs | 12 +- datafusion/functions/src/regex/regexpcount.rs | 11 +- datafusion/sqllogictest/test_files/regexp.slt | 113 ++++++++++++++++++ 4 files changed, 160 insertions(+), 8 deletions(-) diff --git a/datafusion/functions/benches/regx.rs b/datafusion/functions/benches/regx.rs index 23d57f38efae..dd902400d3a4 100644 --- a/datafusion/functions/benches/regx.rs +++ b/datafusion/functions/benches/regx.rs @@ -18,8 +18,9 @@ extern crate criterion; use arrow::array::builder::StringBuilder; -use arrow::array::{ArrayRef, StringArray}; +use arrow::array::{ArrayRef, Int64Array, StringArray}; use criterion::{black_box, criterion_group, criterion_main, Criterion}; +use datafusion_functions::regex::regexpcount::regexp_count; use datafusion_functions::regex::regexplike::regexp_like; use datafusion_functions::regex::regexpmatch::regexp_match; use datafusion_functions::regex::regexpreplace::regexp_replace; @@ -59,6 +60,15 @@ fn regex(rng: &mut ThreadRng) -> StringArray { StringArray::from(data) } +fn start(rng: &mut ThreadRng) -> Int64Array { + let mut data: Vec<i64> = vec![]; + for _ in 0..1000 { + data.push(rng.gen_range(1..5)); + } + + Int64Array::from(data) +} + fn flags(rng: &mut ThreadRng) -> StringArray { let samples = [Some("i".to_string()), Some("im".to_string()), None]; let mut sb = StringBuilder::new(); @@ -75,6 +85,26 @@ fn flags(rng: &mut ThreadRng) -> StringArray { } fn criterion_benchmark(c: &mut Criterion) { + c.bench_function("regexp_count_1000", |b| { + let mut rng = rand::thread_rng(); + let data = Arc::new(data(&mut rng)) as ArrayRef; + let regex = Arc::new(regex(&mut rng)) as ArrayRef; + let start = Arc::new(start(&mut rng)) as ArrayRef; + let flags = Arc::new(flags(&mut rng)) as ArrayRef; + + b.iter(|| { + black_box( + 
regexp_count::<i32>(&[ + Arc::clone(&data), + Arc::clone(&regex), + Arc::clone(&start), + Arc::clone(&flags), + ]) + .expect("regexp_count should work on valid values"), + ) + }) + }); + c.bench_function("regexp_like_1000", |b| { let mut rng = rand::thread_rng(); let data = Arc::new(data(&mut rng)) as ArrayRef; diff --git a/datafusion/functions/src/regex/mod.rs b/datafusion/functions/src/regex/mod.rs index a0b4d0fe6960..fdd461c6a30b 100644 --- a/datafusion/functions/src/regex/mod.rs +++ b/datafusion/functions/src/regex/mod.rs @@ -34,8 +34,18 @@ make_udf_function!( pub mod expr_fn { use datafusion_expr::Expr; - pub fn regexp_count(values: Expr, regex: Expr, flags: Option<Expr>) -> Expr { + /// Returns the number of consecutive occurrences of a regular expression in a string. + pub fn regexp_count( + values: Expr, + regex: Expr, + start: Option<Expr>, + flags: Option<Expr>, + ) -> Expr { let mut args = vec![values, regex]; + if let Some(start) = start { + args.push(start); + }; + if let Some(flags) = flags { args.push(flags); }; diff --git a/datafusion/functions/src/regex/regexpcount.rs b/datafusion/functions/src/regex/regexpcount.rs index cc86c66c4792..246d326fd0c8 100644 --- a/datafusion/functions/src/regex/regexpcount.rs +++ b/datafusion/functions/src/regex/regexpcount.rs @@ -22,14 +22,14 @@ use arrow::array::{ Array, ArrayRef, AsArray, Datum, GenericStringArray, Int64Array, OffsetSizeTrait, Scalar, }; -use arrow::datatypes::DataType::{self, Int64, LargeUtf8, Utf8}; +use arrow::datatypes::DataType::{self, Int64, LargeUtf8, Utf8, Utf8View}; use arrow::datatypes::Int64Type; use arrow::error::ArrowError; use datafusion_common::cast::{as_generic_string_array, as_primitive_array}; use datafusion_common::{ arrow_err, exec_err, internal_err, DataFusionError, Result, ScalarValue, }; -use datafusion_expr::TypeSignature::Exact; +use datafusion_expr::TypeSignature::{Exact, Uniform}; use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; use itertools::izip; use regex::Regex; @@ 
-51,14 +51,13 @@ impl RegexpCountFunc { Self { signature: Signature::one_of( vec![ - Exact(vec![Utf8, Utf8]), + Uniform(2, vec![Utf8, LargeUtf8, Utf8View]), Exact(vec![Utf8, Utf8, Int64]), Exact(vec![Utf8, Utf8, Int64, Utf8]), - Exact(vec![Utf8, Utf8, Int64, LargeUtf8]), - Exact(vec![LargeUtf8, LargeUtf8]), Exact(vec![LargeUtf8, LargeUtf8, Int64]), - Exact(vec![LargeUtf8, LargeUtf8, Int64, Utf8]), Exact(vec![LargeUtf8, LargeUtf8, Int64, LargeUtf8]), + Exact(vec![Utf8View, Utf8View, Int64]), + Exact(vec![Utf8View, Utf8View, Int64, Utf8View]), ], Volatility::Immutable, ), diff --git a/datafusion/sqllogictest/test_files/regexp.slt b/datafusion/sqllogictest/test_files/regexp.slt index 1c86ea5e6be5..95d7bd318718 100644 --- a/datafusion/sqllogictest/test_files/regexp.slt +++ b/datafusion/sqllogictest/test_files/regexp.slt @@ -553,6 +553,119 @@ SELECT regexp_count(str, 'ab', 1, 'i') from t; 0 +query I +SELECT regexp_count(str, pattern) from t; +---- +1 +1 +0 +0 +0 +0 +1 +1 +1 +1 +1 + +query I +SELECT regexp_count(str, pattern, start) from t; +---- +1 +1 +0 +0 +0 +0 +0 +1 +1 +1 +1 + +query I +SELECT regexp_count(str, pattern, start, flags) from t; +---- +1 +1 +1 +0 +0 +0 +0 +1 +1 +1 +1 + +# test type coercion +query I +SELECT regexp_count(arrow_cast(str, 'Utf8'), arrow_cast(pattern, 'LargeUtf8'), arrow_cast(start, 'Int32'), flags) from t; +---- +1 +1 +1 +0 +0 +0 +0 +1 +1 +1 +1 + +# test string views + +statement ok +CREATE TABLE t_stringview AS +SELECT arrow_cast(str, 'Utf8View') as str, arrow_cast(pattern, 'Utf8View') as pattern, arrow_cast(start, 'Int64') as start, arrow_cast(flags, 'Utf8View') as flags FROM t; + +query I +SELECT regexp_count(str, '\w') from t; +---- +3 +3 +3 +3 +3 +4 +4 +10 +6 +4 +7 + +query I +SELECT regexp_count(str, '\w{2}', start) from t; +---- +1 +1 +1 +1 +0 +2 +1 +4 +1 +2 +3 + +query I +SELECT regexp_count(str, 'ab', 1, 'i') from t; +---- +1 +1 +1 +1 +1 +0 +0 +0 +0 +0 +0 + + query I SELECT regexp_count(str, pattern) from t; ----