From f72c11f9c14f87760fd2db0ab7b7beb3ea7c0910 Mon Sep 17 00:00:00 2001 From: Xin Li Date: Tue, 20 Aug 2024 21:34:41 +0800 Subject: [PATCH 01/12] Implement regexp_ccount --- datafusion/functions/src/regex/mod.rs | 17 +- datafusion/functions/src/regex/regexpcount.rs | 561 ++++++++++++++++++ datafusion/sqllogictest/test_files/regexp.slt | 175 +++++- 3 files changed, 740 insertions(+), 13 deletions(-) create mode 100644 datafusion/functions/src/regex/regexpcount.rs diff --git a/datafusion/functions/src/regex/mod.rs b/datafusion/functions/src/regex/mod.rs index 4ac162290ddb..a0b4d0fe6960 100644 --- a/datafusion/functions/src/regex/mod.rs +++ b/datafusion/functions/src/regex/mod.rs @@ -17,10 +17,12 @@ //! "regx" DataFusion functions +pub mod regexpcount; pub mod regexplike; pub mod regexpmatch; pub mod regexpreplace; // create UDFs +make_udf_function!(regexpcount::RegexpCountFunc, REGEXP_COUNT, regexp_count); make_udf_function!(regexpmatch::RegexpMatchFunc, REGEXP_MATCH, regexp_match); make_udf_function!(regexplike::RegexpLikeFunc, REGEXP_LIKE, regexp_like); make_udf_function!( @@ -32,6 +34,14 @@ make_udf_function!( pub mod expr_fn { use datafusion_expr::Expr; + pub fn regexp_count(values: Expr, regex: Expr, flags: Option) -> Expr { + let mut args = vec![values, regex]; + if let Some(flags) = flags { + args.push(flags); + }; + super::regexp_count().call(args) + } + /// Returns a list of regular expression matches in a string. pub fn regexp_match(values: Expr, regex: Expr, flags: Option) -> Expr { let mut args = vec![values, regex]; @@ -67,5 +77,10 @@ pub mod expr_fn { /// Returns all DataFusion functions defined in this package pub fn functions() -> Vec> { - vec![regexp_match(), regexp_like(), regexp_replace()] + vec![ + regexp_count(), + regexp_match(), + regexp_like(), + regexp_replace(), + ] } diff --git a/datafusion/functions/src/regex/regexpcount.rs b/datafusion/functions/src/regex/regexpcount.rs new file mode 100644 index 000000000000..103f934aa903 --- /dev/null +++ b/datafusion/functions/src/regex/regexpcount.rs @@ -0,0 +1,561 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::iter::repeat; +use std::sync::Arc; + +use arrow::array::{ + Array, ArrayRef, AsArray, Datum, GenericStringArray, Int64Array, OffsetSizeTrait, + Scalar, +}; +use arrow::datatypes::DataType::{self, Int64, LargeUtf8, Utf8}; +use arrow::datatypes::Int64Type; +use arrow::error::ArrowError; +use datafusion_common::cast::{as_generic_string_array, as_primitive_array}; +use datafusion_common::{ + arrow_err, exec_err, internal_err, DataFusionError, Result, ScalarValue, +}; +use datafusion_expr::TypeSignature::Exact; +use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; +use itertools::izip; +use regex::Regex; + +/// regexp_count(string, pattern [, start [, flags ]]) -> int +#[derive(Debug)] +pub struct RegexpCountFunc { + signature: Signature, +} + +impl Default for RegexpCountFunc { + fn default() -> Self { + Self::new() + } +} + +impl RegexpCountFunc { + pub fn new() -> Self { + Self { + signature: Signature::one_of( + vec![ + Exact(vec![Utf8, Utf8]), + Exact(vec![Utf8, Utf8, Int64]), + Exact(vec![Utf8, Utf8, Int64, Utf8]), + Exact(vec![Utf8, Utf8, Int64, LargeUtf8]), + Exact(vec![LargeUtf8, LargeUtf8]), + Exact(vec![LargeUtf8, LargeUtf8, Int64]), + Exact(vec![LargeUtf8, LargeUtf8, Int64, Utf8]), + Exact(vec![LargeUtf8, LargeUtf8, Int64, LargeUtf8]), + ], + Volatility::Immutable, + ), + } + } +} + +impl ScalarUDFImpl for RegexpCountFunc { + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn name(&self) -> &str { + "regexp_count" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + use DataType::*; + + Ok(match &arg_types[0] { + Null => Null, + _ => Int64, + }) + } + + fn invoke(&self, args: &[datafusion_expr::ColumnarValue]) -> Result { + let len = args + .iter() + .fold(Option::::None, |acc, arg| match arg { + ColumnarValue::Scalar(_) => acc, + ColumnarValue::Array(a) => Some(a.len()), + }); + + let is_scalar = len.is_none(); + let inferred_length = len.unwrap_or(1); + let args = args + .iter() + .map(|arg| arg.clone().into_array(inferred_length)) + .collect::>>()?; + + let result = regexp_count_func(&args); + if is_scalar { + // If all inputs are scalar, keeps output as scalar + let result = result.and_then(|arr| ScalarValue::try_from_array(&arr, 0)); + result.map(ColumnarValue::Scalar) + } else { + result.map(ColumnarValue::Array) + } + } +} + +fn regexp_count_func(args: &[ArrayRef]) -> Result { + match args[0].data_type() { + DataType::Utf8 => regexp_count::(args), + DataType::LargeUtf8 => regexp_count::(args), + other => { + internal_err!("Unsupported data type {other:?} for function regexp_like") + } + } +} + +pub fn regexp_count(args: &[ArrayRef]) -> Result { + let arg_len = args.len(); + + match arg_len { + 2..=4 => { + let values = as_generic_string_array::(&args[0])?; + let regex = as_generic_string_array::(&args[1])?; + let regex_datum: &dyn Datum = if regex.len() != 1 { + regex + } else { + &Scalar::new(regex) + }; + let start_scalar: Scalar<&Int64Array>; + let start_array_datum: Option<&dyn Datum> = if arg_len > 2 { + let start_array = as_primitive_array::(&args[2])?; + if start_array.len() != 1 { + Some(start_array as &dyn Datum) + } else { + start_scalar = Scalar::new(start_array); + Some(&start_scalar as &dyn Datum) + } + } else { + None + }; + + let flags_scalar: Scalar<&GenericStringArray>; + let flags_array_datum: Option<&dyn Datum> = if arg_len > 3 { + let flags_array = as_generic_string_array::(&args[3])?; + if flags_array.len() != 1 { + Some(flags_array as &dyn Datum) + } else { + flags_scalar = Scalar::new(flags_array); + Some(&flags_scalar as &dyn Datum) + } + } else { + None + }; + + Ok(regexp_array_count( + values, + regex_datum, + start_array_datum, + flags_array_datum, + ) + .map(|x| Arc::new(x) as ArrayRef)?) + } + other => { + exec_err!( + "regexp_count was called with {other} arguments. It requires at least 2 and at most 4." + ) + } + } +} + +pub fn regexp_array_count( + array: &GenericStringArray, + regex_array: &dyn Datum, + start_array: Option<&dyn Datum>, + flags_array: Option<&dyn Datum>, +) -> Result { + let (rhs, is_rhs_scalar) = regex_array.get(); + + if array.data_type() != rhs.data_type() { + return arrow_err!(ArrowError::ComputeError( + "regexp_count() requires pattern to be either Utf8 or LargeUtf8".to_string(), + )); + } + + if !is_rhs_scalar && array.len() != rhs.len() { + return arrow_err!( + ArrowError::ComputeError( + "regexp_count() requires pattern to be either a scalar or the same length as the input array".to_string(), + ) + ); + } + + let (starts, is_starts_scalar) = match start_array { + Some(starts) => starts.get(), + None => (&Int64Array::from(vec![1]) as &dyn Array, true), + }; + + if *starts.data_type() != Int64 { + return arrow_err!(ArrowError::ComputeError( + "regexp_count() requires start to be Int64".to_string(), + )); + } + + if !is_starts_scalar && array.len() != starts.len() { + return arrow_err!( + ArrowError::ComputeError( + "regexp_count() requires start to be either a scalar or the same length as the input array".to_string(), + ) + ); + } + + let (flags, is_flags_scalar) = match flags_array { + Some(flags) => flags.get(), + None => ( + &GenericStringArray::::from(vec![None as Option<&str>]) as &dyn Array, + true, + ), + }; + + if !is_flags_scalar && array.len() != flags.len() { + return arrow_err!( + ArrowError::ComputeError( + "regexp_count() requires flags to be either a scalar or the same length as the input array".to_string(), + ) + ); + } + + let regex_iter: Box>> = if is_rhs_scalar { + let regex: &arrow::array::GenericByteArray< + arrow::datatypes::GenericStringType, + > = rhs.as_string::(); + let regex = regex.is_valid(0).then(|| regex.value(0)); + if regex.is_none() { + return Ok(Int64Array::from( + repeat(0).take(array.len()).collect::>(), + )); + } + + Box::new(repeat(regex)) + } else { + Box::new(rhs.as_string::().iter()) + }; + + let start_iter: Box> = if is_starts_scalar { + let start = starts.as_primitive::(); + let start = start.is_valid(0).then(|| start.value(0)); + Box::new(repeat(start.unwrap_or(1))) + } else { + Box::new( + starts + .as_primitive::() + .iter() + .map(|x| x.unwrap_or(1)), + ) + }; + + let flags_iter: Box>> = if is_flags_scalar { + let flags = flags.as_string::(); + let flags = flags + .is_valid(0) + .then(|| flags.value(0)) + .map(|x| { + if x.contains('g') { + return arrow_err!(ArrowError::ComputeError( + "regexp_count() does not support global flag".to_string(), + )); + } + Ok(x) + }) + .transpose()?; + + Box::new(repeat(flags)) + } else { + Box::new(flags.as_string::().iter()) + }; + + regex_iter_count(array.iter(), regex_iter, start_iter, flags_iter) +} + +fn regex_iter_count<'a>( + array: impl Iterator>, + regex: impl Iterator>, + start: impl Iterator, + flags: impl Iterator>, +) -> Result { + Ok(Int64Array::from( + izip!(array, regex, start, flags) + .map(|(array, regex, start, flags)| { + if array.is_none() || regex.is_none() { + return Ok(0); + } + + let regex = regex.unwrap(); + if regex.is_empty() { + return Ok(0); + } + + if regex.contains('g') { + return Err(ArrowError::ComputeError( + "regexp_count() does not support global flag".to_string(), + )); + } + + if start <= 0 { + return Err(ArrowError::ComputeError( + "regexp_count() requires start to be 1 based".to_string(), + )); + } + + let array = array.unwrap(); + let start = start as usize; + if start > array.len() { + return Ok(0); + } + + let pattern = if let Some(Some(flags)) = + flags.map(|x| if x.is_empty() { None } else { Some(x) }) + { + format!("(?{flags}){regex}") + } else { + regex.to_string() + }; + + let Ok(re) = Regex::new(pattern.as_str()) else { + return Err(ArrowError::ComputeError(format!( + "Regular expression did not compile: {}", + pattern + ))); + }; + + Ok(re + .find_iter(&array.chars().skip(start - 1).collect::()) + .count() as i64) + }) + .collect::, ArrowError>>()?, + )) +} + +#[cfg(test)] +mod tests { + use crate::regex::regexpcount::regexp_count; + use arrow::array::{ArrayRef, GenericStringArray, Int64Array, OffsetSizeTrait}; + use std::sync::Arc; + + #[test] + fn test_regexp_count() { + test_case_sensitive_regexp_count_scalar::(); + test_case_sensitive_regexp_count_scalar::(); + + test_case_sensitive_regexp_count_scalar_start::(); + test_case_sensitive_regexp_count_scalar_start::(); + + test_case_insensitive_regexp_count_scalar_flags::(); + test_case_insensitive_regexp_count_scalar_flags::(); + + test_case_sensitive_regexp_count_array::(); + test_case_sensitive_regexp_count_array::(); + + test_case_sensitive_regexp_count_array_start::(); + test_case_sensitive_regexp_count_array_start::(); + + test_case_insensitive_regexp_count_array_flags::(); + test_case_insensitive_regexp_count_array_flags::(); + + test_case_sensitive_regexp_count_start_scalar_complex::(); + test_case_sensitive_regexp_count_start_scalar_complex::(); + + test_case_sensitive_regexp_count_array_complex::(); + test_case_sensitive_regexp_count_array_complex::(); + } + + fn test_case_sensitive_regexp_count_scalar() { + let values = GenericStringArray::::from(vec![ + "", + "aabca", + "abcabc", + "abcAbcab", + "abcabcabc", + ]); + let regex = GenericStringArray::::from(vec!["abc"; 1]); + + let expected = Int64Array::from(vec![0, 1, 2, 1, 3]); + + let re = regexp_count::(&[ + Arc::new(values) as ArrayRef, + Arc::new(regex) as ArrayRef, + ]) + .unwrap(); + assert_eq!(re.as_ref(), &expected); + } + + fn test_case_sensitive_regexp_count_scalar_start() { + let values = GenericStringArray::::from(vec![ + "", + "aabca", + "abcabc", + "abcAbcab", + "abcabcabc", + ]); + let regex = GenericStringArray::::from(vec!["abc"; 1]); + let start = Int64Array::from(vec![2]); + + let expected = Int64Array::from(vec![0, 1, 1, 0, 2]); + + let re = regexp_count::(&[ + Arc::new(values) as ArrayRef, + Arc::new(regex) as ArrayRef, + Arc::new(start) as ArrayRef, + ]) + .unwrap(); + assert_eq!(re.as_ref(), &expected); + } + + fn test_case_insensitive_regexp_count_scalar_flags() { + let values = GenericStringArray::::from(vec![ + "", + "aabca", + "abcabc", + "abcAbcab", + "abcabcabc", + ]); + let regex = GenericStringArray::::from(vec!["abc"; 1]); + let start = Int64Array::from(vec![1]); + let flags = GenericStringArray::::from(vec!["i"]); + + let expected = Int64Array::from(vec![0, 1, 2, 2, 3]); + + let re = regexp_count::(&[ + Arc::new(values) as ArrayRef, + Arc::new(regex) as ArrayRef, + Arc::new(start) as ArrayRef, + Arc::new(flags) as ArrayRef, + ]) + .unwrap(); + assert_eq!(re.as_ref(), &expected); + } + + fn test_case_sensitive_regexp_count_array() { + let values = GenericStringArray::::from(vec![ + "", + "aabca", + "abcabc", + "abcAbcab", + "abcabcAbc", + ]); + let regex = GenericStringArray::::from(vec!["", "abc", "a", "bc", "ab"]); + + let expected = Int64Array::from(vec![0, 1, 2, 2, 2]); + + let re = regexp_count::(&[ + Arc::new(values) as ArrayRef, + Arc::new(regex) as ArrayRef, + ]) + .unwrap(); + assert_eq!(re.as_ref(), &expected); + } + + fn test_case_sensitive_regexp_count_array_start() { + let values = GenericStringArray::::from(vec![ + "", + "aAbca", + "abcabc", + "abcAbcab", + "abcabcAbc", + ]); + let regex = GenericStringArray::::from(vec!["", "abc", "a", "bc", "ab"]); + let start = Int64Array::from(vec![1, 2, 3, 4, 5]); + + let expected = Int64Array::from(vec![0, 0, 1, 1, 0]); + + let re = regexp_count::(&[ + Arc::new(values) as ArrayRef, + Arc::new(regex) as ArrayRef, + Arc::new(start) as ArrayRef, + ]) + .unwrap(); + assert_eq!(re.as_ref(), &expected); + } + + fn test_case_insensitive_regexp_count_array_flags() { + let values = GenericStringArray::::from(vec![ + "", + "aAbca", + "abcabc", + "abcAbcab", + "abcabcAbc", + ]); + let regex = GenericStringArray::::from(vec!["", "abc", "a", "bc", "ab"]); + let start = Int64Array::from(vec![1]); + let flags = GenericStringArray::::from(vec!["", "i", "", "", "i"]); + + let expected = Int64Array::from(vec![0, 1, 2, 2, 3]); + + let re = regexp_count::(&[ + Arc::new(values) as ArrayRef, + Arc::new(regex) as ArrayRef, + Arc::new(start) as ArrayRef, + Arc::new(flags) as ArrayRef, + ]) + .unwrap(); + assert_eq!(re.as_ref(), &expected); + } + + fn test_case_sensitive_regexp_count_start_scalar_complex() { + let values = GenericStringArray::::from(vec![ + "", + "aAbca", + "abcabc", + "abcAbcabc", + "abcabcAbc", + ]); + let regex = GenericStringArray::::from(vec!["", "abc", "a", "bc", "ab"]); + let start = Int64Array::from(vec![5]); + let flags = GenericStringArray::::from(vec!["", "i", "", "", "i"]); + + let expected = Int64Array::from(vec![0, 0, 0, 2, 1]); + + let re = regexp_count::(&[ + Arc::new(values) as ArrayRef, + Arc::new(regex) as ArrayRef, + Arc::new(start) as ArrayRef, + Arc::new(flags) as ArrayRef, + ]) + .unwrap(); + assert_eq!(re.as_ref(), &expected); + } + + fn test_case_sensitive_regexp_count_array_complex() { + let values = GenericStringArray::::from(vec![ + "", + "aAbca", + "abcabc", + "abcAbcab", + "abcabcAbc", + ]); + let regex = GenericStringArray::::from(vec!["", "abc", "a", "bc", "ab"]); + let start = Int64Array::from(vec![1, 2, 3, 4, 5]); + let flags = GenericStringArray::::from(vec!["", "i", "", "", "i"]); + + let expected = Int64Array::from(vec![0, 1, 1, 1, 1]); + + let re = regexp_count::(&[ + Arc::new(values) as ArrayRef, + Arc::new(regex) as ArrayRef, + Arc::new(start) as ArrayRef, + Arc::new(flags) as ArrayRef, + ]) + .unwrap(); + assert_eq!(re.as_ref(), &expected); + } +} diff --git a/datafusion/sqllogictest/test_files/regexp.slt b/datafusion/sqllogictest/test_files/regexp.slt index 22322d79ccfe..aa1b350d3661 100644 --- a/datafusion/sqllogictest/test_files/regexp.slt +++ b/datafusion/sqllogictest/test_files/regexp.slt @@ -16,18 +16,18 @@ # under the License. statement ok -CREATE TABLE t (str varchar, pattern varchar, flags varchar) AS VALUES - ('abc', '^(a)', 'i'), - ('ABC', '^(A).*', 'i'), - ('aBc', '(b|d)', 'i'), - ('AbC', '(B|D)', null), - ('aBC', '^(b|c)', null), - ('4000', '\b4([1-9]\d\d|\d[1-9]\d|\d\d[1-9])\b', null), - ('4010', '\b4([1-9]\d\d|\d[1-9]\d|\d\d[1-9])\b', null), - ('Düsseldorf','[\p{Letter}-]+', null), - ('Москва', '[\p{L}-]+', null), - ('Köln', '[a-zA-Z]ö[a-zA-Z]{2}', null), - ('إسرائيل', '^\p{Arabic}+$', null); +CREATE TABLE t (str varchar, pattern varchar, start int, flags varchar) AS VALUES + ('abc', '^(a)', 1, 'i'), + ('ABC', '^(A).*', 1, 'i'), + ('aBc', '(b|d)', 1, 'i'), + ('AbC', '(B|D)', 2, null), + ('aBC', '^(b|c)', 3, null), + ('4000', '\b4([1-9]\d\d|\d[1-9]\d|\d\d[1-9])\b', 1, null), + ('4010', '\b4([1-9]\d\d|\d[1-9]\d|\d\d[1-9])\b', 2, null), + ('Düsseldorf','[\p{Letter}-]+', 3, null), + ('Москва', '[\p{L}-]+', 4, null), + ('Köln', '[a-zA-Z]ö[a-zA-Z]{2}', 1, null), + ('إسرائيل', '^\p{Arabic}+$', 2, null); # # regexp_like tests @@ -395,6 +395,157 @@ SELECT 'foo\nbar\nbaz' LIKE '%bar%'; ---- true + +# regexp_count tests + +# regexp_count tests from postgresql +# https://github.com/postgres/postgres/blob/56d23855c864b7384970724f3ad93fb0fc319e51/src/test/regress/sql/strings.sql#L226-L235 + +query I +SELECT regexp_count('123123123123123', '(12)3'); +---- +5 + +query I +SELECT regexp_count('123123123123', '123', 1); +---- +4 + +query I +SELECT regexp_count('123123123123', '123', 3); +---- +3 + +query I +SELECT regexp_count('123123123123', '123', 33); +---- +0 + +query I +SELECT regexp_count('ABCABCABCABC', 'Abc', 1, ''); +---- +0 + +query I +SELECT regexp_count('ABCABCABCABC', 'Abc', 1, 'i'); +---- +4 + +statement error +External error: query failed: DataFusion error: Arrow error: Compute error: regexp_count() requires start to be 1 based +SELECT regexp_count('123123123123', '123', 0); + +statement error +External error: query failed: DataFusion error: Arrow error: Compute error: regexp_count() requires start to be 1 based +SELECT regexp_count('123123123123', '123', -3); + +query I +SELECT regexp_count(str, '\w') from t; +---- +3 +3 +3 +3 +3 +4 +4 +10 +6 +4 +7 + +query I +SELECT regexp_count(str, '\w{2}', start) from t; +---- +1 +1 +1 +1 +0 +2 +1 +4 +1 +2 +3 + +query I +SELECT regexp_count(str, 'ab', 1, 'i') from t; +---- +1 +1 +1 +1 +1 +0 +0 +0 +0 +0 +0 + + +query I +SELECT regexp_count(str, pattern) from t; +---- +1 +1 +0 +0 +0 +0 +1 +1 +1 +1 +1 + +query I +SELECT regexp_count(str, pattern, start) from t; +---- +1 +1 +0 +0 +0 +0 +0 +1 +1 +1 +1 + +query I +SELECT regexp_count(str, pattern, start, flags) from t; +---- +1 +1 +1 +0 +0 +0 +0 +1 +1 +1 +1 + +# test type coercion +query I +SELECT regexp_count(arrow_cast(str, 'Utf8'), arrow_cast(pattern, 'LargeUtf8'), arrow_cast(start, 'Int32'), flags) from t; +---- +1 +1 +1 +0 +0 +0 +0 +1 +1 +1 +1 + statement ok drop table t; From ee23b977abcfa2eda532c395777ca1f3f65076d2 Mon Sep 17 00:00:00 2001 From: Xin Li Date: Tue, 20 Aug 2024 21:49:26 +0800 Subject: [PATCH 02/12] Update document --- .../source/user-guide/sql/scalar_functions.md | 24 +++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/docs/source/user-guide/sql/scalar_functions.md b/docs/source/user-guide/sql/scalar_functions.md index c7b3409ba7cd..07efc62864e8 100644 --- a/docs/source/user-guide/sql/scalar_functions.md +++ b/docs/source/user-guide/sql/scalar_functions.md @@ -1302,6 +1302,7 @@ Apache DataFusion uses a [PCRE-like] regular expression [syntax] (minus support for several features including look-around and backreferences). The following regular expression functions are supported: +- [regexp_count](#regex_count) - [regexp_like](#regexp_like) - [regexp_match](#regexp_match) - [regexp_replace](#regexp_replace) @@ -1309,6 +1310,29 @@ The following regular expression functions are supported: [pcre-like]: https://en.wikibooks.org/wiki/Regular_Expressions/Perl-Compatible_Regular_Expressions [syntax]: https://docs.rs/regex/latest/regex/#syntax +### `regexp_count` +Returns the number of matchs that a [regular expression] has in a string. + +``` +regexp_count(str, regexp[, start, flags]) +``` + +#### Arguments + +- **str**: String expression to operate on. + Can be a constant, column, or function, and any combination of string operators. +- **regexp**: Regular expression to test against the string expression. + Can be a constant, column, or function. +- **start**: Optional start position to search for the regular expression. + Can be a constant, column, or function. +- **flags**: Optional regular expression flags that control the behavior of the + regular expression. The following flags are supported: + - **i**: case-insensitive: letters match both upper and lower case + - **m**: multi-line mode: ^ and $ match begin/end of line + - **s**: allow . to match \n + - **R**: enables CRLF mode: when multi-line mode is enabled, \r\n is used + - **U**: swap the meaning of x* and x*? + ### `regexp_like` Returns true if a [regular expression] has at least one match in a string, From d5b63f4261d89314ebbc599239891963489010d6 Mon Sep 17 00:00:00 2001 From: Xin Li Date: Tue, 20 Aug 2024 21:57:28 +0800 Subject: [PATCH 03/12] fix check --- datafusion/functions/src/regex/regexpcount.rs | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/datafusion/functions/src/regex/regexpcount.rs b/datafusion/functions/src/regex/regexpcount.rs index 103f934aa903..b07fd455b960 100644 --- a/datafusion/functions/src/regex/regexpcount.rs +++ b/datafusion/functions/src/regex/regexpcount.rs @@ -305,12 +305,6 @@ fn regex_iter_count<'a>( return Ok(0); } - if regex.contains('g') { - return Err(ArrowError::ComputeError( - "regexp_count() does not support global flag".to_string(), - )); - } - if start <= 0 { return Err(ArrowError::ComputeError( "regexp_count() requires start to be 1 based".to_string(), @@ -326,6 +320,12 @@ fn regex_iter_count<'a>( let pattern = if let Some(Some(flags)) = flags.map(|x| if x.is_empty() { None } else { Some(x) }) { + if flags.contains('g') { + return Err(ArrowError::ComputeError( + "regexp_count() does not support global flag".to_string(), + )); + } + format!("(?{flags}){regex}") } else { regex.to_string() From 2acd148534f613e0e31a4f3fef011c83450244fa Mon Sep 17 00:00:00 2001 From: Xin Li Date: Tue, 20 Aug 2024 22:04:18 +0800 Subject: [PATCH 04/12] add more tests --- datafusion/functions/src/regex/regexpcount.rs | 2 +- datafusion/sqllogictest/test_files/regexp.slt | 4 ++++ 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/datafusion/functions/src/regex/regexpcount.rs b/datafusion/functions/src/regex/regexpcount.rs index b07fd455b960..8236fd673550 100644 --- a/datafusion/functions/src/regex/regexpcount.rs +++ b/datafusion/functions/src/regex/regexpcount.rs @@ -119,7 +119,7 @@ fn regexp_count_func(args: &[ArrayRef]) -> Result { DataType::Utf8 => regexp_count::(args), DataType::LargeUtf8 => regexp_count::(args), other => { - internal_err!("Unsupported data type {other:?} for function regexp_like") + internal_err!("Unsupported data type {other:?} for function regexp_count") } } } diff --git a/datafusion/sqllogictest/test_files/regexp.slt b/datafusion/sqllogictest/test_files/regexp.slt index 43a9974589e2..1d2cc90d9bd0 100644 --- a/datafusion/sqllogictest/test_files/regexp.slt +++ b/datafusion/sqllogictest/test_files/regexp.slt @@ -503,6 +503,10 @@ statement error External error: query failed: DataFusion error: Arrow error: Compute error: regexp_count() requires start to be 1 based SELECT regexp_count('123123123123', '123', -3); +statement error +External error: statement failed: DataFusion error: Arrow error: Compute error: regexp_count() does not support global flag +SELECT regexp_count('123123123123', '123', 1, 'g'); + query I SELECT regexp_count(str, '\w') from t; ---- From 27a6fc6e9102da096f68b357c571d76984226bff Mon Sep 17 00:00:00 2001 From: Xin Li Date: Wed, 21 Aug 2024 21:30:20 +0800 Subject: [PATCH 05/12] Update the world to 1.80 --- .github/workflows/rust.yml | 4 ++-- Cargo.toml | 2 +- datafusion-cli/Cargo.toml | 2 +- datafusion/core/Cargo.toml | 2 +- datafusion/proto-common/Cargo.toml | 2 +- datafusion/proto-common/gen/Cargo.toml | 2 +- datafusion/proto/Cargo.toml | 2 +- datafusion/proto/gen/Cargo.toml | 2 +- datafusion/substrait/Cargo.toml | 2 +- 9 files changed, 10 insertions(+), 10 deletions(-) diff --git a/.github/workflows/rust.yml b/.github/workflows/rust.yml index 809f3acd8374..2f4b9eae59a2 100644 --- a/.github/workflows/rust.yml +++ b/.github/workflows/rust.yml @@ -569,9 +569,9 @@ jobs: # # To reproduce: # 1. Install the version of Rust that is failing. Example: - # rustup install 1.76.0 + # rustup install 1.80.0 # 2. Run the command that failed with that version. Example: - # cargo +1.76.0 check -p datafusion + # cargo +1.80.0 check -p datafusion # # To resolve, either: # 1. Change your code to use older Rust features, diff --git a/Cargo.toml b/Cargo.toml index ae344a46a1bd..fb6545c5bc4c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -58,7 +58,7 @@ homepage = "https://datafusion.apache.org" license = "Apache-2.0" readme = "README.md" repository = "https://github.com/apache/datafusion" -rust-version = "1.76" +rust-version = "1.80" version = "41.0.0" [workspace.dependencies] diff --git a/datafusion-cli/Cargo.toml b/datafusion-cli/Cargo.toml index 0a4523a1c04e..57a6c75dc6a2 100644 --- a/datafusion-cli/Cargo.toml +++ b/datafusion-cli/Cargo.toml @@ -26,7 +26,7 @@ license = "Apache-2.0" homepage = "https://datafusion.apache.org" repository = "https://github.com/apache/datafusion" # Specify MSRV here as `cargo msrv` doesn't support workspace version -rust-version = "1.76" +rust-version = "1.80" readme = "README.md" [dependencies] diff --git a/datafusion/core/Cargo.toml b/datafusion/core/Cargo.toml index adbba3eb31d6..625c1067e495 100644 --- a/datafusion/core/Cargo.toml +++ b/datafusion/core/Cargo.toml @@ -30,7 +30,7 @@ authors = { workspace = true } # Specify MSRV here as `cargo msrv` doesn't support workspace version and fails with # "Unable to find key 'package.rust-version' (or 'package.metadata.msrv') in 'arrow-datafusion/Cargo.toml'" # https://github.com/foresterre/cargo-msrv/issues/590 -rust-version = "1.76" +rust-version = "1.80" [lints] workspace = true diff --git a/datafusion/proto-common/Cargo.toml b/datafusion/proto-common/Cargo.toml index e5d65827cdec..9b2f15a9a710 100644 --- a/datafusion/proto-common/Cargo.toml +++ b/datafusion/proto-common/Cargo.toml @@ -26,7 +26,7 @@ homepage = { workspace = true } repository = { workspace = true } license = { workspace = true } authors = { workspace = true } -rust-version = "1.76" +rust-version = "1.80" # Exclude proto files so crates.io consumers don't need protoc exclude = ["*.proto"] diff --git a/datafusion/proto-common/gen/Cargo.toml b/datafusion/proto-common/gen/Cargo.toml index 54ec0e44694b..bb03208b2b70 100644 --- a/datafusion/proto-common/gen/Cargo.toml +++ b/datafusion/proto-common/gen/Cargo.toml @@ -20,7 +20,7 @@ name = "gen-common" description = "Code generation for proto" version = "0.1.0" edition = { workspace = true } -rust-version = "1.76" +rust-version = "1.80" authors = { workspace = true } homepage = { workspace = true } repository = { workspace = true } diff --git a/datafusion/proto/Cargo.toml b/datafusion/proto/Cargo.toml index 95d9e6700a50..4203bd7a28c0 100644 --- a/datafusion/proto/Cargo.toml +++ b/datafusion/proto/Cargo.toml @@ -27,7 +27,7 @@ repository = { workspace = true } license = { workspace = true } authors = { workspace = true } # Specify MSRV here as `cargo msrv` doesn't support workspace version -rust-version = "1.76" +rust-version = "1.80" # Exclude proto files so crates.io consumers don't need protoc exclude = ["*.proto"] diff --git a/datafusion/proto/gen/Cargo.toml b/datafusion/proto/gen/Cargo.toml index 401c51c94563..e69282540cb2 100644 --- a/datafusion/proto/gen/Cargo.toml +++ b/datafusion/proto/gen/Cargo.toml @@ -20,7 +20,7 @@ name = "gen" description = "Code generation for proto" version = "0.1.0" edition = { workspace = true } -rust-version = "1.76" +rust-version = "1.80" authors = { workspace = true } homepage = { workspace = true } repository = { workspace = true } diff --git a/datafusion/substrait/Cargo.toml b/datafusion/substrait/Cargo.toml index 9e7ef9632ad3..0647263225c4 100644 --- a/datafusion/substrait/Cargo.toml +++ b/datafusion/substrait/Cargo.toml @@ -26,7 +26,7 @@ repository = { workspace = true } license = { workspace = true } authors = { workspace = true } # Specify MSRV here as `cargo msrv` doesn't support workspace version -rust-version = "1.76" +rust-version = "1.80" [lints] workspace = true From d17e45d037ea38b4d908d2852b61e2b545bb9ddf Mon Sep 17 00:00:00 2001 From: Xin Li Date: Wed, 21 Aug 2024 21:33:46 +0800 Subject: [PATCH 06/12] Fix doc format --- docs/source/user-guide/sql/scalar_functions.md | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/source/user-guide/sql/scalar_functions.md b/docs/source/user-guide/sql/scalar_functions.md index 07efc62864e8..17d172e7a4d3 100644 --- a/docs/source/user-guide/sql/scalar_functions.md +++ b/docs/source/user-guide/sql/scalar_functions.md @@ -1311,6 +1311,7 @@ The following regular expression functions are supported: [syntax]: https://docs.rs/regex/latest/regex/#syntax ### `regexp_count` + Returns the number of matchs that a [regular expression] has in a string. ``` From ee14adf0992146c0341dcf193d57c5dbffd1ced7 Mon Sep 17 00:00:00 2001 From: Xin Li Date: Thu, 22 Aug 2024 14:14:20 +0800 Subject: [PATCH 07/12] Add null tests --- datafusion/functions/src/regex/regexpcount.rs | 13 +++++- datafusion/sqllogictest/test_files/regexp.slt | 40 +++++++++++++++++++ 2 files changed, 51 insertions(+), 2 deletions(-) diff --git a/datafusion/functions/src/regex/regexpcount.rs b/datafusion/functions/src/regex/regexpcount.rs index 8236fd673550..cc86c66c4792 100644 --- a/datafusion/functions/src/regex/regexpcount.rs +++ b/datafusion/functions/src/regex/regexpcount.rs @@ -131,6 +131,11 @@ pub fn regexp_count(args: &[ArrayRef]) -> Result { 2..=4 => { let values = as_generic_string_array::(&args[0])?; let regex = as_generic_string_array::(&args[1])?; + + if values.is_empty() || regex.is_empty() { + return Ok(Arc::new(Int64Array::new_null(0))); + } + let regex_datum: &dyn Datum = if regex.len() != 1 { regex } else { @@ -139,7 +144,9 @@ pub fn regexp_count(args: &[ArrayRef]) -> Result { let start_scalar: Scalar<&Int64Array>; let start_array_datum: Option<&dyn Datum> = if arg_len > 2 { let start_array = as_primitive_array::(&args[2])?; - if start_array.len() != 1 { + if start_array.is_empty() { + None + } else if start_array.len() != 1 { Some(start_array as &dyn Datum) } else { start_scalar = Scalar::new(start_array); @@ -152,7 +159,9 @@ pub fn regexp_count(args: &[ArrayRef]) -> Result { let flags_scalar: Scalar<&GenericStringArray>; let flags_array_datum: Option<&dyn Datum> = if arg_len > 3 { let flags_array = as_generic_string_array::(&args[3])?; - if flags_array.len() != 1 { + if flags_array.is_empty() { + None + } else if flags_array.len() != 1 { Some(flags_array as &dyn Datum) } else { flags_scalar = Scalar::new(flags_array); diff --git a/datafusion/sqllogictest/test_files/regexp.slt b/datafusion/sqllogictest/test_files/regexp.slt index 1d2cc90d9bd0..1c86ea5e6be5 100644 --- a/datafusion/sqllogictest/test_files/regexp.slt +++ b/datafusion/sqllogictest/test_files/regexp.slt @@ -614,6 +614,46 @@ SELECT regexp_count(arrow_cast(str, 'Utf8'), arrow_cast(pattern, 'LargeUtf8'), 1 1 +# NULL tests + +query I +SELECT regexp_count(NULL, NULL); +---- +0 + +query I +SELECT regexp_count(NULL, 'a'); +---- +0 + +query I +SELECT regexp_count('a', NULL); +---- +0 + +query I +SELECT regexp_count(NULL, NULL, NULL, NULL); +---- +0 + +statement ok +CREATE TABLE empty_table (str varchar, pattern varchar, start int, flags varchar); + +query I +SELECT regexp_count(str, pattern, start, flags) from empty_table; +---- + +statement ok +INSERT INTO empty_table VALUES ('a', NULL, 1, 'i'), (NULL, 'a', 1, 'i'), (NULL, NULL, 1, 'i'), (NULL, NULL, NULL, 'i'); + +query I +SELECT regexp_count(str, pattern, start, flags) from empty_table; +---- +0 +0 +0 +0 + statement ok drop table t; From 08343ddb160208de67ea35393b8dd35f669af218 Mon Sep 17 00:00:00 2001 From: Xin Li Date: Thu, 22 Aug 2024 14:51:55 +0800 Subject: [PATCH 08/12] Add uft8 support and bench --- datafusion/functions/benches/regx.rs | 32 ++++- datafusion/functions/src/regex/mod.rs | 12 +- datafusion/functions/src/regex/regexpcount.rs | 11 +- datafusion/sqllogictest/test_files/regexp.slt | 113 ++++++++++++++++++ 4 files changed, 160 insertions(+), 8 deletions(-) diff --git a/datafusion/functions/benches/regx.rs b/datafusion/functions/benches/regx.rs index 23d57f38efae..dd902400d3a4 100644 --- a/datafusion/functions/benches/regx.rs +++ b/datafusion/functions/benches/regx.rs @@ -18,8 +18,9 @@ extern crate criterion; use arrow::array::builder::StringBuilder; -use arrow::array::{ArrayRef, StringArray}; +use arrow::array::{ArrayRef, Int64Array, StringArray}; use criterion::{black_box, criterion_group, criterion_main, Criterion}; +use datafusion_functions::regex::regexpcount::regexp_count; use datafusion_functions::regex::regexplike::regexp_like; use datafusion_functions::regex::regexpmatch::regexp_match; use datafusion_functions::regex::regexpreplace::regexp_replace; @@ -59,6 +60,15 @@ fn regex(rng: &mut ThreadRng) -> StringArray { StringArray::from(data) } +fn start(rng: &mut ThreadRng) -> Int64Array { + let mut data: Vec = vec![]; + for _ in 0..1000 { + data.push(rng.gen_range(1..5)); + } + + Int64Array::from(data) +} + fn flags(rng: &mut ThreadRng) -> StringArray { let samples = [Some("i".to_string()), Some("im".to_string()), None]; let mut sb = StringBuilder::new(); @@ -75,6 +85,26 @@ fn flags(rng: &mut ThreadRng) -> StringArray { } fn criterion_benchmark(c: &mut Criterion) { + c.bench_function("regexp_count_1000", |b| { + let mut rng = rand::thread_rng(); + let data = Arc::new(data(&mut rng)) as ArrayRef; + let regex = Arc::new(regex(&mut rng)) as ArrayRef; + let start = Arc::new(start(&mut rng)) as ArrayRef; + let flags = Arc::new(flags(&mut rng)) as ArrayRef; + + b.iter(|| { + black_box( + regexp_count::(&[ + Arc::clone(&data), + Arc::clone(®ex), + Arc::clone(&start), + Arc::clone(&flags), + ]) + .expect("regexp_count should work on valid values"), + ) + }) + }); + c.bench_function("regexp_like_1000", |b| { let mut rng = rand::thread_rng(); let data = Arc::new(data(&mut rng)) as ArrayRef; diff --git a/datafusion/functions/src/regex/mod.rs b/datafusion/functions/src/regex/mod.rs index a0b4d0fe6960..fdd461c6a30b 100644 --- a/datafusion/functions/src/regex/mod.rs +++ b/datafusion/functions/src/regex/mod.rs @@ -34,8 +34,18 @@ make_udf_function!( pub mod expr_fn { use datafusion_expr::Expr; - pub fn regexp_count(values: Expr, regex: Expr, flags: Option) -> Expr { + /// Returns the number of consecutive occurrences of a regular expression in a string. + pub fn regexp_count( + values: Expr, + regex: Expr, + start: Option, + flags: Option, + ) -> Expr { let mut args = vec![values, regex]; + if let Some(start) = start { + args.push(start); + }; + if let Some(flags) = flags { args.push(flags); }; diff --git a/datafusion/functions/src/regex/regexpcount.rs b/datafusion/functions/src/regex/regexpcount.rs index cc86c66c4792..246d326fd0c8 100644 --- a/datafusion/functions/src/regex/regexpcount.rs +++ b/datafusion/functions/src/regex/regexpcount.rs @@ -22,14 +22,14 @@ use arrow::array::{ Array, ArrayRef, AsArray, Datum, GenericStringArray, Int64Array, OffsetSizeTrait, Scalar, }; -use arrow::datatypes::DataType::{self, Int64, LargeUtf8, Utf8}; +use arrow::datatypes::DataType::{self, Int64, LargeUtf8, Utf8, Utf8View}; use arrow::datatypes::Int64Type; use arrow::error::ArrowError; use datafusion_common::cast::{as_generic_string_array, as_primitive_array}; use datafusion_common::{ arrow_err, exec_err, internal_err, DataFusionError, Result, ScalarValue, }; -use datafusion_expr::TypeSignature::Exact; +use datafusion_expr::TypeSignature::{Exact, Uniform}; use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; use itertools::izip; use regex::Regex; @@ -51,14 +51,13 @@ impl RegexpCountFunc { Self { signature: Signature::one_of( vec![ - Exact(vec![Utf8, Utf8]), + Uniform(2, vec![Utf8, LargeUtf8, Utf8View]), Exact(vec![Utf8, Utf8, Int64]), Exact(vec![Utf8, Utf8, Int64, Utf8]), - Exact(vec![Utf8, Utf8, Int64, LargeUtf8]), - Exact(vec![LargeUtf8, LargeUtf8]), Exact(vec![LargeUtf8, LargeUtf8, Int64]), - Exact(vec![LargeUtf8, LargeUtf8, Int64, Utf8]), Exact(vec![LargeUtf8, LargeUtf8, Int64, LargeUtf8]), + Exact(vec![Utf8View, Utf8View, Int64]), + Exact(vec![Utf8View, Utf8View, Int64, Utf8View]), ], Volatility::Immutable, ), diff --git a/datafusion/sqllogictest/test_files/regexp.slt b/datafusion/sqllogictest/test_files/regexp.slt index 1c86ea5e6be5..95d7bd318718 100644 --- a/datafusion/sqllogictest/test_files/regexp.slt +++ b/datafusion/sqllogictest/test_files/regexp.slt @@ -553,6 +553,119 @@ SELECT regexp_count(str, 'ab', 1, 'i') from t; 0 +query I +SELECT regexp_count(str, pattern) from t; +---- +1 +1 +0 +0 +0 +0 +1 +1 +1 +1 +1 + +query I +SELECT regexp_count(str, pattern, start) from t; +---- +1 +1 +0 +0 +0 +0 +0 +1 +1 +1 +1 + +query I +SELECT regexp_count(str, pattern, start, flags) from t; +---- +1 +1 +1 +0 +0 +0 +0 +1 +1 +1 +1 + +# test type coercion +query I +SELECT regexp_count(arrow_cast(str, 'Utf8'), arrow_cast(pattern, 'LargeUtf8'), arrow_cast(start, 'Int32'), flags) from t; +---- +1 +1 +1 +0 +0 +0 +0 +1 +1 +1 +1 + +# test string views + +statement ok +CREATE TABLE t_stringview AS +SELECT arrow_cast(str, 'Utf8View') as str, arrow_cast(pattern, 'Utf8View') as pattern, arrow_cast(start, 'Int64') as start, arrow_cast(flags, 'Utf8View') as flags FROM t; + +query I +SELECT regexp_count(str, '\w') from t; +---- +3 +3 +3 +3 +3 +4 +4 +10 +6 +4 +7 + +query I +SELECT regexp_count(str, '\w{2}', start) from t; +---- +1 +1 +1 +1 +0 +2 +1 +4 +1 +2 +3 + +query I +SELECT regexp_count(str, 'ab', 1, 'i') from t; +---- +1 +1 +1 +1 +1 +0 +0 +0 +0 +0 +0 + + query I SELECT regexp_count(str, pattern) from t; ---- From 218ff7be82c84e43891211f0c45ca3b23e2faca5 Mon Sep 17 00:00:00 2001 From: Xin Li Date: Wed, 28 Aug 2024 22:27:57 +0800 Subject: [PATCH 09/12] Refactoring regexp_count --- Cargo.toml | 2 +- datafusion-cli/Cargo.toml | 2 +- datafusion-examples/examples/regex_count.rs | 33 + datafusion/core/Cargo.toml | 2 +- datafusion/functions/src/regex/regexpcount.rs | 639 +++++++++++------- datafusion/proto-common/Cargo.toml | 2 +- datafusion/proto-common/gen/Cargo.toml | 2 +- datafusion/proto/Cargo.toml | 2 +- datafusion/proto/gen/Cargo.toml | 2 +- datafusion/substrait/Cargo.toml | 2 +- 10 files changed, 418 insertions(+), 270 deletions(-) create mode 100644 datafusion-examples/examples/regex_count.rs diff --git a/Cargo.toml b/Cargo.toml index fb6545c5bc4c..ae344a46a1bd 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -58,7 +58,7 @@ homepage = "https://datafusion.apache.org" license = "Apache-2.0" readme = "README.md" repository = "https://github.com/apache/datafusion" -rust-version = "1.80" +rust-version = "1.76" version = "41.0.0" [workspace.dependencies] diff --git a/datafusion-cli/Cargo.toml b/datafusion-cli/Cargo.toml index 57a6c75dc6a2..0a4523a1c04e 100644 --- a/datafusion-cli/Cargo.toml +++ b/datafusion-cli/Cargo.toml @@ -26,7 +26,7 @@ license = "Apache-2.0" homepage = "https://datafusion.apache.org" repository = "https://github.com/apache/datafusion" # Specify MSRV here as `cargo msrv` doesn't support workspace version -rust-version = "1.80" +rust-version = "1.76" readme = "README.md" [dependencies] diff --git a/datafusion-examples/examples/regex_count.rs b/datafusion-examples/examples/regex_count.rs new file mode 100644 index 000000000000..93ec705ff6cc --- /dev/null +++ b/datafusion-examples/examples/regex_count.rs @@ -0,0 +1,33 @@ +use datafusion::common::Result; +use datafusion::prelude::{CsvReadOptions, SessionContext}; + +#[tokio::main] +async fn main() -> Result<()> { + let ctx = SessionContext::new(); + ctx.register_csv( + "examples", + "../../datafusion/physical-expr/tests/data/regex.csv", + CsvReadOptions::new(), + ) + .await?; + + // + // + //regexp_count examples + // + // + // regexp_count format is (regexp_count(text, regex[, flags]) + // + + // use sql and regexp_count function to test col 'values', against patterns in col 'patterns' without flags + let result = ctx + .sql("select regexp_count(values, patterns) from examples") + .await? + .collect() + .await?; + + println!("{:?}", result); + assert_eq!(result.len(), 1); + + Ok(()) +} diff --git a/datafusion/core/Cargo.toml b/datafusion/core/Cargo.toml index 625c1067e495..adbba3eb31d6 100644 --- a/datafusion/core/Cargo.toml +++ b/datafusion/core/Cargo.toml @@ -30,7 +30,7 @@ authors = { workspace = true } # Specify MSRV here as `cargo msrv` doesn't support workspace version and fails with # "Unable to find key 'package.rust-version' (or 'package.metadata.msrv') in 'arrow-datafusion/Cargo.toml'" # https://github.com/foresterre/cargo-msrv/issues/590 -rust-version = "1.80" +rust-version = "1.76" [lints] workspace = true diff --git a/datafusion/functions/src/regex/regexpcount.rs b/datafusion/functions/src/regex/regexpcount.rs index 246d326fd0c8..2b7805c40915 100644 --- a/datafusion/functions/src/regex/regexpcount.rs +++ b/datafusion/functions/src/regex/regexpcount.rs @@ -15,26 +15,24 @@ // specific language governing permissions and limitations // under the License. -use std::iter::repeat; -use std::sync::Arc; - -use arrow::array::{ - Array, ArrayRef, AsArray, Datum, GenericStringArray, Int64Array, OffsetSizeTrait, - Scalar, +use arrow::array::{Array, ArrayRef, Int64Array, OffsetSizeTrait}; +use arrow::datatypes::DataType; +use arrow::datatypes::{ + DataType::Int64, DataType::LargeUtf8, DataType::Utf8, DataType::Utf8View, }; -use arrow::datatypes::DataType::{self, Int64, LargeUtf8, Utf8, Utf8View}; -use arrow::datatypes::Int64Type; use arrow::error::ArrowError; -use datafusion_common::cast::{as_generic_string_array, as_primitive_array}; -use datafusion_common::{ - arrow_err, exec_err, internal_err, DataFusionError, Result, ScalarValue, +use datafusion_common::cast::{as_generic_string_array, as_int64_array}; +use datafusion_common::{exec_err, internal_err, Result, ScalarValue}; +use datafusion_expr::{ + ColumnarValue, ScalarUDFImpl, Signature, TypeSignature::Exact, + TypeSignature::Uniform, Volatility, }; -use datafusion_expr::TypeSignature::{Exact, Uniform}; -use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; use itertools::izip; use regex::Regex; +use std::collections::hash_map::Entry; +use std::collections::HashMap; +use std::sync::Arc; -/// regexp_count(string, pattern [, start [, flags ]]) -> int #[derive(Debug)] pub struct RegexpCountFunc { signature: Signature, @@ -78,13 +76,8 @@ impl ScalarUDFImpl for RegexpCountFunc { &self.signature } - fn return_type(&self, arg_types: &[DataType]) -> Result { - use DataType::*; - - Ok(match &arg_types[0] { - Null => Null, - _ => Int64, - }) + fn return_type(&self, _arg_types: &[DataType]) -> Result { + Ok(Int64) } fn invoke(&self, args: &[datafusion_expr::ColumnarValue]) -> Result { @@ -115,250 +108,388 @@ impl ScalarUDFImpl for RegexpCountFunc { fn regexp_count_func(args: &[ArrayRef]) -> Result { match args[0].data_type() { - DataType::Utf8 => regexp_count::(args), - DataType::LargeUtf8 => regexp_count::(args), + Utf8 => regexp_count::(args), + LargeUtf8 => regexp_count::(args), other => { internal_err!("Unsupported data type {other:?} for function regexp_count") } } } +/// This function `regexp_count` is responsible for counting the occurrences of a regular expression pattern +/// within a string array. It supports optional start positions and flags for case insensitivity. +/// +/// The function accepts a variable number of arguments: +/// - `values`: The array of strings to search within. +/// - `regex_array`: The array of regular expression patterns to search for. +/// - `start_array` (optional): The array of start positions for the search. +/// - `flags_array` (optional): The array of flags to modify the search behavior (e.g., case insensitivity). +/// +/// The function handles different combinations of scalar and array inputs for the regex patterns, start positions, +/// and flags. It uses a cache to store compiled regular expressions for efficiency. +/// +/// # Errors +/// Returns an error if the input arrays have mismatched lengths or if the regular expression fails to compile. pub fn regexp_count(args: &[ArrayRef]) -> Result { - let arg_len = args.len(); + let args_len = args.len(); + if !(2..=4).contains(&args_len) { + return exec_err!("regexp_count was called with {args_len} arguments. It requires at least 2 and at most 4."); + } - match arg_len { - 2..=4 => { - let values = as_generic_string_array::(&args[0])?; - let regex = as_generic_string_array::(&args[1])?; + let values = as_generic_string_array::(&args[0])?; + let regex_array = as_generic_string_array::(&args[1])?; - if values.is_empty() || regex.is_empty() { - return Ok(Arc::new(Int64Array::new_null(0))); - } + let (regex_scalar, is_regex_scalar) = if regex_array.len() == 1 { + (Some(regex_array.value(0)), true) + } else { + (None, false) + }; - let regex_datum: &dyn Datum = if regex.len() != 1 { - regex - } else { - &Scalar::new(regex) - }; - let start_scalar: Scalar<&Int64Array>; - let start_array_datum: Option<&dyn Datum> = if arg_len > 2 { - let start_array = as_primitive_array::(&args[2])?; - if start_array.is_empty() { - None - } else if start_array.len() != 1 { - Some(start_array as &dyn Datum) - } else { - start_scalar = Scalar::new(start_array); - Some(&start_scalar as &dyn Datum) + let (start_array, start_scalar, is_start_scalar) = if args.len() > 2 { + let start = as_int64_array(&args[2])?; + if start.len() == 1 { + (None, Some(start.value(0)), true) + } else { + (Some(start), None, false) + } + } else { + (None, Some(1), true) + }; + + let (flags_array, flags_scalar, is_flags_scalar) = if args.len() > 3 { + let flags = as_generic_string_array::(&args[3])?; + if flags.len() == 1 { + (None, Some(flags.value(0)), true) + } else { + (Some(flags), None, false) + } + } else { + (None, None, true) + }; + + match (is_regex_scalar, is_start_scalar, is_flags_scalar) { + (true, true, true) => { + let regex = match regex_scalar { + None | Some("") => { + return Ok(Arc::new(Int64Array::from(vec![0; values.len()]))) } - } else { - None + Some(regex) => regex, }; - let flags_scalar: Scalar<&GenericStringArray>; - let flags_array_datum: Option<&dyn Datum> = if arg_len > 3 { - let flags_array = as_generic_string_array::(&args[3])?; - if flags_array.is_empty() { - None - } else if flags_array.len() != 1 { - Some(flags_array as &dyn Datum) - } else { - flags_scalar = Scalar::new(flags_array); - Some(&flags_scalar as &dyn Datum) + let pattern = compile_regex(regex, flags_scalar)?; + + Ok(Arc::new(Int64Array::from_iter_values( + values + .iter() + .map(|value| count_matches(value, &pattern, start_scalar)) + .collect::>>()?, + ))) + } + (true, true, false) => { + let regex = match regex_scalar { + None | Some("") => { + return Ok(Arc::new(Int64Array::from(vec![0; values.len()]))) } - } else { - None + Some(regex) => regex, }; - Ok(regexp_array_count( - values, - regex_datum, - start_array_datum, - flags_array_datum, - ) - .map(|x| Arc::new(x) as ArrayRef)?) - } - other => { - exec_err!( - "regexp_count was called with {other} arguments. It requires at least 2 and at most 4." - ) + let flags_array = flags_array.unwrap(); + if values.len() != flags_array.len() { + return exec_err!( + "flags_array must be the same length as values array; got {} and {}", + values.len(), + flags_array.len() + ); + } + + let mut regex_cache = HashMap::new(); + Ok(Arc::new(Int64Array::from_iter_values( + values + .iter() + .zip(flags_array.iter()) + .map(|(value, flags)| { + let pattern = + compile_and_cache_regex(regex, flags, &mut regex_cache)?; + count_matches(value, &pattern, start_scalar) + }) + .collect::>>()?, + ))) } - } -} + (true, false, true) => { + let regex = match regex_scalar { + None | Some("") => { + return Ok(Arc::new(Int64Array::from(vec![0; values.len()]))) + } + Some(regex) => regex, + }; -pub fn regexp_array_count( - array: &GenericStringArray, - regex_array: &dyn Datum, - start_array: Option<&dyn Datum>, - flags_array: Option<&dyn Datum>, -) -> Result { - let (rhs, is_rhs_scalar) = regex_array.get(); - - if array.data_type() != rhs.data_type() { - return arrow_err!(ArrowError::ComputeError( - "regexp_count() requires pattern to be either Utf8 or LargeUtf8".to_string(), - )); - } + let pattern = compile_regex(regex, flags_scalar)?; - if !is_rhs_scalar && array.len() != rhs.len() { - return arrow_err!( - ArrowError::ComputeError( - "regexp_count() requires pattern to be either a scalar or the same length as the input array".to_string(), - ) - ); - } + let start_array = start_array.unwrap(); - let (starts, is_starts_scalar) = match start_array { - Some(starts) => starts.get(), - None => (&Int64Array::from(vec![1]) as &dyn Array, true), - }; + Ok(Arc::new(Int64Array::from_iter_values( + values + .iter() + .zip(start_array.iter()) + .map(|(value, start)| count_matches(value, &pattern, start)) + .collect::>>()?, + ))) + } + (true, false, false) => { + let regex = match regex_scalar { + None | Some("") => { + return Ok(Arc::new(Int64Array::from(vec![0; values.len()]))) + } + Some(regex) => regex, + }; - if *starts.data_type() != Int64 { - return arrow_err!(ArrowError::ComputeError( - "regexp_count() requires start to be Int64".to_string(), - )); - } + let flags_array = flags_array.unwrap(); + if values.len() != flags_array.len() { + return exec_err!( + "flags_array must be the same length as values array; got {} and {}", + values.len(), + flags_array.len() + ); + } - if !is_starts_scalar && array.len() != starts.len() { - return arrow_err!( - ArrowError::ComputeError( - "regexp_count() requires start to be either a scalar or the same length as the input array".to_string(), - ) - ); - } + let mut regex_cache = HashMap::new(); + Ok(Arc::new(Int64Array::from_iter_values( + izip!( + values.iter(), + start_array.unwrap().iter(), + flags_array.iter() + ) + .map(|(value, start, flags)| { + let pattern = + compile_and_cache_regex(regex, flags, &mut regex_cache)?; + + count_matches(value, &pattern, start) + }) + .collect::>>()?, + ))) + } + (false, true, true) => { + if values.len() != regex_array.len() { + return exec_err!( + "regex_array must be the same length as values array; got {} and {}", + values.len(), + regex_array.len() + ); + } - let (flags, is_flags_scalar) = match flags_array { - Some(flags) => flags.get(), - None => ( - &GenericStringArray::::from(vec![None as Option<&str>]) as &dyn Array, - true, - ), - }; + let mut regex_cache = HashMap::new(); + Ok(Arc::new(Int64Array::from_iter_values( + values + .iter() + .zip(regex_array.iter()) + .map(|(value, regex)| { + let regex = match regex { + None | Some("") => return Ok(0), + Some(regex) => regex, + }; + + let pattern = compile_and_cache_regex( + regex, + flags_scalar, + &mut regex_cache, + )?; + count_matches(value, &pattern, start_scalar) + }) + .collect::>>()?, + ))) + } + (false, true, false) => { + if values.len() != regex_array.len() { + return exec_err!( + "regex_array must be the same length as values array; got {} and {}", + values.len(), + regex_array.len() + ); + } - if !is_flags_scalar && array.len() != flags.len() { - return arrow_err!( - ArrowError::ComputeError( - "regexp_count() requires flags to be either a scalar or the same length as the input array".to_string(), - ) - ); - } + let flags_array = flags_array.unwrap(); + if values.len() != flags_array.len() { + return exec_err!( + "flags_array must be the same length as values array; got {} and {}", + values.len(), + flags_array.len() + ); + } - let regex_iter: Box>> = if is_rhs_scalar { - let regex: &arrow::array::GenericByteArray< - arrow::datatypes::GenericStringType, - > = rhs.as_string::(); - let regex = regex.is_valid(0).then(|| regex.value(0)); - if regex.is_none() { - return Ok(Int64Array::from( - repeat(0).take(array.len()).collect::>(), - )); + let mut regex_cache = HashMap::new(); + Ok(Arc::new(Int64Array::from_iter_values( + izip!(values.iter(), regex_array.iter(), flags_array.iter()) + .map(|(value, regex, flags)| { + let regex = match regex { + None | Some("") => return Ok(0), + Some(regex) => regex, + }; + + let pattern = + compile_and_cache_regex(regex, flags, &mut regex_cache)?; + + count_matches(value, &pattern, start_scalar) + }) + .collect::>>()?, + ))) } + (false, false, true) => { + if values.len() != regex_array.len() { + return exec_err!( + "regex_array must be the same length as values array; got {} and {}", + values.len(), + regex_array.len() + ); + } - Box::new(repeat(regex)) - } else { - Box::new(rhs.as_string::().iter()) - }; + let start_array = start_array.unwrap(); + if values.len() != start_array.len() { + return exec_err!( + "start_array must be the same length as values array; got {} and {}", + values.len(), + start_array.len() + ); + } - let start_iter: Box> = if is_starts_scalar { - let start = starts.as_primitive::(); - let start = start.is_valid(0).then(|| start.value(0)); - Box::new(repeat(start.unwrap_or(1))) - } else { - Box::new( - starts - .as_primitive::() - .iter() - .map(|x| x.unwrap_or(1)), - ) - }; + let mut regex_cache = HashMap::new(); + Ok(Arc::new(Int64Array::from_iter_values( + izip!(values.iter(), regex_array.iter(), start_array.iter()) + .map(|(value, regex, start)| { + let regex = match regex { + None | Some("") => return Ok(0), + Some(regex) => regex, + }; + + let pattern = compile_and_cache_regex( + regex, + flags_scalar, + &mut regex_cache, + )?; + count_matches(value, &pattern, start) + }) + .collect::>>()?, + ))) + } + (false, false, false) => { + if values.len() != regex_array.len() { + return exec_err!( + "regex_array must be the same length as values array; got {} and {}", + values.len(), + regex_array.len() + ); + } - let flags_iter: Box>> = if is_flags_scalar { - let flags = flags.as_string::(); - let flags = flags - .is_valid(0) - .then(|| flags.value(0)) - .map(|x| { - if x.contains('g') { - return arrow_err!(ArrowError::ComputeError( - "regexp_count() does not support global flag".to_string(), - )); - } - Ok(x) - }) - .transpose()?; + let start_array = start_array.unwrap(); + if values.len() != start_array.len() { + return exec_err!( + "start_array must be the same length as values array; got {} and {}", + values.len(), + start_array.len() + ); + } - Box::new(repeat(flags)) - } else { - Box::new(flags.as_string::().iter()) - }; + let flags_array = flags_array.unwrap(); + if values.len() != flags_array.len() { + return exec_err!( + "flags_array must be the same length as values array; got {} and {}", + values.len(), + flags_array.len() + ); + } - regex_iter_count(array.iter(), regex_iter, start_iter, flags_iter) + let mut regex_cache = HashMap::new(); + Ok(Arc::new(Int64Array::from_iter_values( + izip!( + values.iter(), + regex_array.iter(), + start_array.iter(), + flags_array.iter() + ) + .map(|(value, regex, start, flags)| { + let regex = match regex { + None | Some("") => return Ok(0), + Some(regex) => regex, + }; + + let pattern = + compile_and_cache_regex(regex, flags, &mut regex_cache)?; + count_matches(value, &pattern, start) + }) + .collect::>>()?, + ))) + } + } } -fn regex_iter_count<'a>( - array: impl Iterator>, - regex: impl Iterator>, - start: impl Iterator, - flags: impl Iterator>, -) -> Result { - Ok(Int64Array::from( - izip!(array, regex, start, flags) - .map(|(array, regex, start, flags)| { - if array.is_none() || regex.is_none() { - return Ok(0); - } +fn compile_and_cache_regex( + regex: &str, + flags: Option<&str>, + regex_cache: &mut HashMap, +) -> Result { + match regex_cache.entry(regex.to_string()) { + Entry::Vacant(entry) => { + let compiled = compile_regex(regex, flags)?; + entry.insert(compiled.clone()); + Ok(compiled) + } + Entry::Occupied(entry) => Ok(entry.get().to_owned()), + } +} - let regex = regex.unwrap(); - if regex.is_empty() { - return Ok(0); - } +fn compile_regex(regex: &str, flags: Option<&str>) -> Result { + let pattern = match flags { + None | Some("") => regex.to_string(), + Some(flags) => { + if flags.contains("g") { + return Err(ArrowError::ComputeError( + "regexp_count() does not support global flag".to_string(), + ) + .into()); + } + format!("(?{}){}", flags, regex) + } + }; - if start <= 0 { - return Err(ArrowError::ComputeError( - "regexp_count() requires start to be 1 based".to_string(), - )); - } + Regex::new(&pattern).map_err(|_| { + ArrowError::ComputeError(format!( + "Regular expression did not compile: {}", + pattern + )) + .into() + }) +} - let array = array.unwrap(); - let start = start as usize; - if start > array.len() { - return Ok(0); - } +fn count_matches( + value: Option<&str>, + pattern: &Regex, + start: Option, +) -> Result { + let value = match value { + None | Some("") => return Ok(0), + Some(value) => value, + }; - let pattern = if let Some(Some(flags)) = - flags.map(|x| if x.is_empty() { None } else { Some(x) }) - { - if flags.contains('g') { - return Err(ArrowError::ComputeError( - "regexp_count() does not support global flag".to_string(), - )); - } - - format!("(?{flags}){regex}") - } else { - regex.to_string() - }; - - let Ok(re) = Regex::new(pattern.as_str()) else { - return Err(ArrowError::ComputeError(format!( - "Regular expression did not compile: {}", - pattern - ))); - }; - - Ok(re - .find_iter(&array.chars().skip(start - 1).collect::()) - .count() as i64) - }) - .collect::, ArrowError>>()?, - )) + if let Some(start) = start { + if start < 1 { + return Err(ArrowError::ComputeError( + "regexp_count() requires start to be 1 based".to_string(), + ) + .into()); + } + + let find_slice = value.chars().skip(start as usize - 1).collect::(); + let count = pattern.find_iter(find_slice.as_str()).count(); + Ok(count as i64) + } else { + let count = pattern.find_iter(value).count(); + Ok(count as i64) + } } #[cfg(test)] mod tests { - use crate::regex::regexpcount::regexp_count; - use arrow::array::{ArrayRef, GenericStringArray, Int64Array, OffsetSizeTrait}; - use std::sync::Arc; + use super::*; + use arrow::array::GenericStringArray; #[test] fn test_regexp_count() { @@ -399,11 +530,7 @@ mod tests { let expected = Int64Array::from(vec![0, 1, 2, 1, 3]); - let re = regexp_count::(&[ - Arc::new(values) as ArrayRef, - Arc::new(regex) as ArrayRef, - ]) - .unwrap(); + let re = regexp_count::(&[Arc::new(values), Arc::new(regex)]).unwrap(); assert_eq!(re.as_ref(), &expected); } @@ -420,12 +547,8 @@ mod tests { let expected = Int64Array::from(vec![0, 1, 1, 0, 2]); - let re = regexp_count::(&[ - Arc::new(values) as ArrayRef, - Arc::new(regex) as ArrayRef, - Arc::new(start) as ArrayRef, - ]) - .unwrap(); + let re = regexp_count::(&[Arc::new(values), Arc::new(regex), Arc::new(start)]) + .unwrap(); assert_eq!(re.as_ref(), &expected); } @@ -444,10 +567,10 @@ mod tests { let expected = Int64Array::from(vec![0, 1, 2, 2, 3]); let re = regexp_count::(&[ - Arc::new(values) as ArrayRef, - Arc::new(regex) as ArrayRef, - Arc::new(start) as ArrayRef, - Arc::new(flags) as ArrayRef, + Arc::new(values), + Arc::new(regex), + Arc::new(start), + Arc::new(flags), ]) .unwrap(); assert_eq!(re.as_ref(), &expected); @@ -465,11 +588,7 @@ mod tests { let expected = Int64Array::from(vec![0, 1, 2, 2, 2]); - let re = regexp_count::(&[ - Arc::new(values) as ArrayRef, - Arc::new(regex) as ArrayRef, - ]) - .unwrap(); + let re = regexp_count::(&[Arc::new(values), Arc::new(regex)]).unwrap(); assert_eq!(re.as_ref(), &expected); } @@ -486,12 +605,8 @@ mod tests { let expected = Int64Array::from(vec![0, 0, 1, 1, 0]); - let re = regexp_count::(&[ - Arc::new(values) as ArrayRef, - Arc::new(regex) as ArrayRef, - Arc::new(start) as ArrayRef, - ]) - .unwrap(); + let re = regexp_count::(&[Arc::new(values), Arc::new(regex), Arc::new(start)]) + .unwrap(); assert_eq!(re.as_ref(), &expected); } @@ -510,10 +625,10 @@ mod tests { let expected = Int64Array::from(vec![0, 1, 2, 2, 3]); let re = regexp_count::(&[ - Arc::new(values) as ArrayRef, - Arc::new(regex) as ArrayRef, - Arc::new(start) as ArrayRef, - Arc::new(flags) as ArrayRef, + Arc::new(values), + Arc::new(regex), + Arc::new(start), + Arc::new(flags), ]) .unwrap(); assert_eq!(re.as_ref(), &expected); @@ -534,10 +649,10 @@ mod tests { let expected = Int64Array::from(vec![0, 0, 0, 2, 1]); let re = regexp_count::(&[ - Arc::new(values) as ArrayRef, - Arc::new(regex) as ArrayRef, - Arc::new(start) as ArrayRef, - Arc::new(flags) as ArrayRef, + Arc::new(values), + Arc::new(regex), + Arc::new(start), + Arc::new(flags), ]) .unwrap(); assert_eq!(re.as_ref(), &expected); @@ -558,10 +673,10 @@ mod tests { let expected = Int64Array::from(vec![0, 1, 1, 1, 1]); let re = regexp_count::(&[ - Arc::new(values) as ArrayRef, - Arc::new(regex) as ArrayRef, - Arc::new(start) as ArrayRef, - Arc::new(flags) as ArrayRef, + Arc::new(values), + Arc::new(regex), + Arc::new(start), + Arc::new(flags), ]) .unwrap(); assert_eq!(re.as_ref(), &expected); diff --git a/datafusion/proto-common/Cargo.toml b/datafusion/proto-common/Cargo.toml index 9b2f15a9a710..e5d65827cdec 100644 --- a/datafusion/proto-common/Cargo.toml +++ b/datafusion/proto-common/Cargo.toml @@ -26,7 +26,7 @@ homepage = { workspace = true } repository = { workspace = true } license = { workspace = true } authors = { workspace = true } -rust-version = "1.80" +rust-version = "1.76" # Exclude proto files so crates.io consumers don't need protoc exclude = ["*.proto"] diff --git a/datafusion/proto-common/gen/Cargo.toml b/datafusion/proto-common/gen/Cargo.toml index bb03208b2b70..54ec0e44694b 100644 --- a/datafusion/proto-common/gen/Cargo.toml +++ b/datafusion/proto-common/gen/Cargo.toml @@ -20,7 +20,7 @@ name = "gen-common" description = "Code generation for proto" version = "0.1.0" edition = { workspace = true } -rust-version = "1.80" +rust-version = "1.76" authors = { workspace = true } homepage = { workspace = true } repository = { workspace = true } diff --git a/datafusion/proto/Cargo.toml b/datafusion/proto/Cargo.toml index 4203bd7a28c0..95d9e6700a50 100644 --- a/datafusion/proto/Cargo.toml +++ b/datafusion/proto/Cargo.toml @@ -27,7 +27,7 @@ repository = { workspace = true } license = { workspace = true } authors = { workspace = true } # Specify MSRV here as `cargo msrv` doesn't support workspace version -rust-version = "1.80" +rust-version = "1.76" # Exclude proto files so crates.io consumers don't need protoc exclude = ["*.proto"] diff --git a/datafusion/proto/gen/Cargo.toml b/datafusion/proto/gen/Cargo.toml index e69282540cb2..401c51c94563 100644 --- a/datafusion/proto/gen/Cargo.toml +++ b/datafusion/proto/gen/Cargo.toml @@ -20,7 +20,7 @@ name = "gen" description = "Code generation for proto" version = "0.1.0" edition = { workspace = true } -rust-version = "1.80" +rust-version = "1.76" authors = { workspace = true } homepage = { workspace = true } repository = { workspace = true } diff --git a/datafusion/substrait/Cargo.toml b/datafusion/substrait/Cargo.toml index 0647263225c4..9e7ef9632ad3 100644 --- a/datafusion/substrait/Cargo.toml +++ b/datafusion/substrait/Cargo.toml @@ -26,7 +26,7 @@ repository = { workspace = true } license = { workspace = true } authors = { workspace = true } # Specify MSRV here as `cargo msrv` doesn't support workspace version -rust-version = "1.80" +rust-version = "1.76" [lints] workspace = true From 07312be60b32421eac49e0af9dbd1ebdbd3b593e Mon Sep 17 00:00:00 2001 From: Xin Li Date: Thu, 29 Aug 2024 12:11:39 +0800 Subject: [PATCH 10/12] Refactoring regexp_count --- datafusion/functions/benches/regx.rs | 30 +- datafusion/functions/src/regex/regexpcount.rs | 442 +++++++++++------- 2 files changed, 288 insertions(+), 184 deletions(-) diff --git a/datafusion/functions/benches/regx.rs b/datafusion/functions/benches/regx.rs index dd902400d3a4..62fe4f53038d 100644 --- a/datafusion/functions/benches/regx.rs +++ b/datafusion/functions/benches/regx.rs @@ -19,8 +19,10 @@ extern crate criterion; use arrow::array::builder::StringBuilder; use arrow::array::{ArrayRef, Int64Array, StringArray}; +use arrow::compute::cast; +use arrow::datatypes::DataType; use criterion::{black_box, criterion_group, criterion_main, Criterion}; -use datafusion_functions::regex::regexpcount::regexp_count; +use datafusion_functions::regex::regexpcount::regexp_count_func; use datafusion_functions::regex::regexplike::regexp_like; use datafusion_functions::regex::regexpmatch::regexp_match; use datafusion_functions::regex::regexpreplace::regexp_replace; @@ -85,7 +87,7 @@ fn flags(rng: &mut ThreadRng) -> StringArray { } fn criterion_benchmark(c: &mut Criterion) { - c.bench_function("regexp_count_1000", |b| { + c.bench_function("regexp_count_1000 string", |b| { let mut rng = rand::thread_rng(); let data = Arc::new(data(&mut rng)) as ArrayRef; let regex = Arc::new(regex(&mut rng)) as ArrayRef; @@ -94,13 +96,33 @@ fn criterion_benchmark(c: &mut Criterion) { b.iter(|| { black_box( - regexp_count::(&[ + regexp_count_func(&[ Arc::clone(&data), Arc::clone(®ex), Arc::clone(&start), Arc::clone(&flags), ]) - .expect("regexp_count should work on valid values"), + .expect("regexp_count should work on utf8"), + ) + }) + }); + + c.bench_function("regexp_count_1000 utf8view", |b| { + let mut rng = rand::thread_rng(); + let data = cast(&data(&mut rng), &DataType::Utf8View).unwrap(); + let regex = cast(®ex(&mut rng), &DataType::Utf8View).unwrap(); + let start = Arc::new(start(&mut rng)) as ArrayRef; + let flags = cast(&flags(&mut rng), &DataType::Utf8View).unwrap(); + + b.iter(|| { + black_box( + regexp_count_func(&[ + Arc::clone(&data), + Arc::clone(®ex), + Arc::clone(&start), + Arc::clone(&flags), + ]) + .expect("regexp_count should work on utf8view"), ) }) }); diff --git a/datafusion/functions/src/regex/regexpcount.rs b/datafusion/functions/src/regex/regexpcount.rs index 2b7805c40915..511481d5e892 100644 --- a/datafusion/functions/src/regex/regexpcount.rs +++ b/datafusion/functions/src/regex/regexpcount.rs @@ -15,13 +15,12 @@ // specific language governing permissions and limitations // under the License. -use arrow::array::{Array, ArrayRef, Int64Array, OffsetSizeTrait}; -use arrow::datatypes::DataType; +use arrow::array::{Array, ArrayRef, AsArray, Datum, Int64Array}; +use arrow::datatypes::{DataType, Int64Type}; use arrow::datatypes::{ DataType::Int64, DataType::LargeUtf8, DataType::Utf8, DataType::Utf8View, }; use arrow::error::ArrowError; -use datafusion_common::cast::{as_generic_string_array, as_int64_array}; use datafusion_common::{exec_err, internal_err, Result, ScalarValue}; use datafusion_expr::{ ColumnarValue, ScalarUDFImpl, Signature, TypeSignature::Exact, @@ -33,6 +32,8 @@ use std::collections::hash_map::Entry; use std::collections::HashMap; use std::sync::Arc; +use crate::string::common::StringArrayType; + #[derive(Debug)] pub struct RegexpCountFunc { signature: Signature, @@ -106,16 +107,32 @@ impl ScalarUDFImpl for RegexpCountFunc { } } -fn regexp_count_func(args: &[ArrayRef]) -> Result { - match args[0].data_type() { - Utf8 => regexp_count::(args), - LargeUtf8 => regexp_count::(args), +pub fn regexp_count_func(args: &[ArrayRef]) -> Result { + let args_len = args.len(); + if !(2..=4).contains(&args_len) { + return exec_err!("regexp_count was called with {args_len} arguments. It requires at least 2 and at most 4."); + } + + let values = &args[0]; + match values.data_type() { + Utf8 | LargeUtf8 | Utf8View => (), other => { - internal_err!("Unsupported data type {other:?} for function regexp_count") + return internal_err!( + "Unsupported data type {other:?} for function regexp_count" + ); } } + + regexp_count( + values, + &args[1], + if args_len > 2 { Some(&args[2]) } else { None }, + if args_len > 3 { Some(&args[3]) } else { None }, + ) + .map_err(|e| e.into()) } +/// `arrow-rs` style implementation of `regexp_count` function. /// This function `regexp_count` is responsible for counting the occurrences of a regular expression pattern /// within a string array. It supports optional start positions and flags for case insensitivity. /// @@ -130,42 +147,122 @@ fn regexp_count_func(args: &[ArrayRef]) -> Result { /// /// # Errors /// Returns an error if the input arrays have mismatched lengths or if the regular expression fails to compile. -pub fn regexp_count(args: &[ArrayRef]) -> Result { - let args_len = args.len(); - if !(2..=4).contains(&args_len) { - return exec_err!("regexp_count was called with {args_len} arguments. It requires at least 2 and at most 4."); +pub fn regexp_count( + values: &dyn Array, + regex_array: &dyn Datum, + start_array: Option<&dyn Datum>, + flags_array: Option<&dyn Datum>, +) -> Result { + let (regex_array, is_regex_scalar) = regex_array.get(); + let (start_array, is_start_scalar) = start_array.map_or((None, true), |start| { + let (start, is_start_scalar) = start.get(); + (Some(start), is_start_scalar) + }); + let (flags_array, is_flags_scalar) = flags_array.map_or((None, true), |flags| { + let (flags, is_flags_scalar) = flags.get(); + (Some(flags), is_flags_scalar) + }); + + match (values.data_type(), regex_array.data_type(), flags_array) { + (Utf8, Utf8, None) => regexp_count_inner( + values.as_string::(), + regex_array.as_string::(), + is_regex_scalar, + start_array.map(|start| start.as_primitive::()), + is_start_scalar, + None, + is_flags_scalar, + ), + (Utf8, Utf8, Some(flags_array)) if *flags_array.data_type() == Utf8 => regexp_count_inner( + values.as_string::(), + regex_array.as_string::(), + is_regex_scalar, + start_array.map(|start| start.as_primitive::()), + is_start_scalar, + Some(flags_array.as_string::()), + is_flags_scalar, + ), + (LargeUtf8, LargeUtf8, None) => regexp_count_inner( + values.as_string::(), + regex_array.as_string::(), + is_regex_scalar, + start_array.map(|start| start.as_primitive::()), + is_start_scalar, + None, + is_flags_scalar, + ), + (LargeUtf8, LargeUtf8, Some(flags_array)) if *flags_array.data_type() == LargeUtf8 => regexp_count_inner( + values.as_string::(), + regex_array.as_string::(), + is_regex_scalar, + start_array.map(|start| start.as_primitive::()), + is_start_scalar, + Some(flags_array.as_string::()), + is_flags_scalar, + ), + (Utf8View, Utf8View, None) => regexp_count_inner( + values.as_string_view(), + regex_array.as_string_view(), + is_regex_scalar, + start_array.map(|start| start.as_primitive::()), + is_start_scalar, + None, + is_flags_scalar, + ), + (Utf8View, Utf8View, Some(flags_array)) if *flags_array.data_type() == Utf8View => regexp_count_inner( + values.as_string_view(), + regex_array.as_string_view(), + is_regex_scalar, + start_array.map(|start| start.as_primitive::()), + is_start_scalar, + Some(flags_array.as_string_view()), + is_flags_scalar, + ), + _ => Err(ArrowError::ComputeError( + "regexp_count() expected the input arrays to be of type Utf8, LargeUtf8, or Utf8View and the data types of the values, regex_array, and flags_array to match".to_string(), + )), } +} - let values = as_generic_string_array::(&args[0])?; - let regex_array = as_generic_string_array::(&args[1])?; - - let (regex_scalar, is_regex_scalar) = if regex_array.len() == 1 { +pub fn regexp_count_inner<'a, S>( + values: S, + regex_array: S, + is_regex_scalar: bool, + start_array: Option<&Int64Array>, + is_start_scalar: bool, + flags_array: Option, + is_flags_scalar: bool, +) -> Result +where + S: StringArrayType<'a>, +{ + let (regex_scalar, is_regex_scalar) = if is_regex_scalar || regex_array.len() == 1 { (Some(regex_array.value(0)), true) } else { (None, false) }; - let (start_array, start_scalar, is_start_scalar) = if args.len() > 2 { - let start = as_int64_array(&args[2])?; - if start.len() == 1 { - (None, Some(start.value(0)), true) + let (start_array, start_scalar, is_start_scalar) = + if let Some(start_array) = start_array { + if is_start_scalar || start_array.len() == 1 { + (None, Some(start_array.value(0)), true) + } else { + (Some(start_array), None, false) + } } else { - (Some(start), None, false) - } - } else { - (None, Some(1), true) - }; - - let (flags_array, flags_scalar, is_flags_scalar) = if args.len() > 3 { - let flags = as_generic_string_array::(&args[3])?; - if flags.len() == 1 { - (None, Some(flags.value(0)), true) + (None, Some(1), true) + }; + + let (flags_array, flags_scalar, is_flags_scalar) = + if let Some(flags_array) = flags_array { + if is_flags_scalar || flags_array.len() == 1 { + (None, Some(flags_array.value(0)), true) + } else { + (Some(flags_array), None, false) + } } else { - (Some(flags), None, false) - } - } else { - (None, None, true) - }; + (None, None, true) + }; match (is_regex_scalar, is_start_scalar, is_flags_scalar) { (true, true, true) => { @@ -182,7 +279,7 @@ pub fn regexp_count(args: &[ArrayRef]) -> Result { values .iter() .map(|value| count_matches(value, &pattern, start_scalar)) - .collect::>>()?, + .collect::, ArrowError>>()?, ))) } (true, true, false) => { @@ -195,11 +292,11 @@ pub fn regexp_count(args: &[ArrayRef]) -> Result { let flags_array = flags_array.unwrap(); if values.len() != flags_array.len() { - return exec_err!( + return Err(ArrowError::ComputeError(format!( "flags_array must be the same length as values array; got {} and {}", values.len(), flags_array.len() - ); + ))); } let mut regex_cache = HashMap::new(); @@ -212,7 +309,7 @@ pub fn regexp_count(args: &[ArrayRef]) -> Result { compile_and_cache_regex(regex, flags, &mut regex_cache)?; count_matches(value, &pattern, start_scalar) }) - .collect::>>()?, + .collect::, ArrowError>>()?, ))) } (true, false, true) => { @@ -232,7 +329,7 @@ pub fn regexp_count(args: &[ArrayRef]) -> Result { .iter() .zip(start_array.iter()) .map(|(value, start)| count_matches(value, &pattern, start)) - .collect::>>()?, + .collect::, ArrowError>>()?, ))) } (true, false, false) => { @@ -245,11 +342,11 @@ pub fn regexp_count(args: &[ArrayRef]) -> Result { let flags_array = flags_array.unwrap(); if values.len() != flags_array.len() { - return exec_err!( + return Err(ArrowError::ComputeError(format!( "flags_array must be the same length as values array; got {} and {}", values.len(), flags_array.len() - ); + ))); } let mut regex_cache = HashMap::new(); @@ -265,16 +362,16 @@ pub fn regexp_count(args: &[ArrayRef]) -> Result { count_matches(value, &pattern, start) }) - .collect::>>()?, + .collect::, ArrowError>>()?, ))) } (false, true, true) => { if values.len() != regex_array.len() { - return exec_err!( + return Err(ArrowError::ComputeError(format!( "regex_array must be the same length as values array; got {} and {}", values.len(), regex_array.len() - ); + ))); } let mut regex_cache = HashMap::new(); @@ -295,25 +392,25 @@ pub fn regexp_count(args: &[ArrayRef]) -> Result { )?; count_matches(value, &pattern, start_scalar) }) - .collect::>>()?, + .collect::, ArrowError>>()?, ))) } (false, true, false) => { if values.len() != regex_array.len() { - return exec_err!( + return Err(ArrowError::ComputeError(format!( "regex_array must be the same length as values array; got {} and {}", values.len(), regex_array.len() - ); + ))); } let flags_array = flags_array.unwrap(); if values.len() != flags_array.len() { - return exec_err!( + return Err(ArrowError::ComputeError(format!( "flags_array must be the same length as values array; got {} and {}", values.len(), flags_array.len() - ); + ))); } let mut regex_cache = HashMap::new(); @@ -330,25 +427,25 @@ pub fn regexp_count(args: &[ArrayRef]) -> Result { count_matches(value, &pattern, start_scalar) }) - .collect::>>()?, + .collect::, ArrowError>>()?, ))) } (false, false, true) => { if values.len() != regex_array.len() { - return exec_err!( + return Err(ArrowError::ComputeError(format!( "regex_array must be the same length as values array; got {} and {}", values.len(), regex_array.len() - ); + ))); } let start_array = start_array.unwrap(); if values.len() != start_array.len() { - return exec_err!( + return Err(ArrowError::ComputeError(format!( "start_array must be the same length as values array; got {} and {}", values.len(), start_array.len() - ); + ))); } let mut regex_cache = HashMap::new(); @@ -367,34 +464,34 @@ pub fn regexp_count(args: &[ArrayRef]) -> Result { )?; count_matches(value, &pattern, start) }) - .collect::>>()?, + .collect::, ArrowError>>()?, ))) } (false, false, false) => { if values.len() != regex_array.len() { - return exec_err!( + return Err(ArrowError::ComputeError(format!( "regex_array must be the same length as values array; got {} and {}", values.len(), regex_array.len() - ); + ))); } let start_array = start_array.unwrap(); if values.len() != start_array.len() { - return exec_err!( + return Err(ArrowError::ComputeError(format!( "start_array must be the same length as values array; got {} and {}", values.len(), start_array.len() - ); + ))); } let flags_array = flags_array.unwrap(); if values.len() != flags_array.len() { - return exec_err!( + return Err(ArrowError::ComputeError(format!( "flags_array must be the same length as values array; got {} and {}", values.len(), flags_array.len() - ); + ))); } let mut regex_cache = HashMap::new(); @@ -415,7 +512,7 @@ pub fn regexp_count(args: &[ArrayRef]) -> Result { compile_and_cache_regex(regex, flags, &mut regex_cache)?; count_matches(value, &pattern, start) }) - .collect::>>()?, + .collect::, ArrowError>>()?, ))) } } @@ -425,7 +522,7 @@ fn compile_and_cache_regex( regex: &str, flags: Option<&str>, regex_cache: &mut HashMap, -) -> Result { +) -> Result { match regex_cache.entry(regex.to_string()) { Entry::Vacant(entry) => { let compiled = compile_regex(regex, flags)?; @@ -436,15 +533,14 @@ fn compile_and_cache_regex( } } -fn compile_regex(regex: &str, flags: Option<&str>) -> Result { +fn compile_regex(regex: &str, flags: Option<&str>) -> Result { let pattern = match flags { None | Some("") => regex.to_string(), Some(flags) => { if flags.contains("g") { return Err(ArrowError::ComputeError( "regexp_count() does not support global flag".to_string(), - ) - .into()); + )); } format!("(?{}){}", flags, regex) } @@ -455,7 +551,6 @@ fn compile_regex(regex: &str, flags: Option<&str>) -> Result { "Regular expression did not compile: {}", pattern )) - .into() }) } @@ -463,7 +558,7 @@ fn count_matches( value: Option<&str>, pattern: &Regex, start: Option, -) -> Result { +) -> Result { let value = match value { None | Some("") => return Ok(0), Some(value) => value, @@ -473,8 +568,7 @@ fn count_matches( if start < 1 { return Err(ArrowError::ComputeError( "regexp_count() requires start to be 1 based".to_string(), - ) - .into()); + )); } let find_slice = value.chars().skip(start as usize - 1).collect::(); @@ -489,84 +583,87 @@ fn count_matches( #[cfg(test)] mod tests { use super::*; - use arrow::array::GenericStringArray; + use arrow::array::{GenericStringArray, StringViewArray}; #[test] fn test_regexp_count() { - test_case_sensitive_regexp_count_scalar::(); - test_case_sensitive_regexp_count_scalar::(); - - test_case_sensitive_regexp_count_scalar_start::(); - test_case_sensitive_regexp_count_scalar_start::(); - - test_case_insensitive_regexp_count_scalar_flags::(); - test_case_insensitive_regexp_count_scalar_flags::(); - - test_case_sensitive_regexp_count_array::(); - test_case_sensitive_regexp_count_array::(); - - test_case_sensitive_regexp_count_array_start::(); - test_case_sensitive_regexp_count_array_start::(); - - test_case_insensitive_regexp_count_array_flags::(); - test_case_insensitive_regexp_count_array_flags::(); - - test_case_sensitive_regexp_count_start_scalar_complex::(); - test_case_sensitive_regexp_count_start_scalar_complex::(); - - test_case_sensitive_regexp_count_array_complex::(); - test_case_sensitive_regexp_count_array_complex::(); + test_case_sensitive_regexp_count_scalar::>(); + test_case_sensitive_regexp_count_scalar::>(); + test_case_sensitive_regexp_count_scalar::(); + + test_case_sensitive_regexp_count_scalar_start::>(); + test_case_sensitive_regexp_count_scalar_start::>(); + test_case_sensitive_regexp_count_scalar_start::(); + + test_case_insensitive_regexp_count_scalar_flags::>(); + test_case_insensitive_regexp_count_scalar_flags::>(); + test_case_insensitive_regexp_count_scalar_flags::(); + + test_case_sensitive_regexp_count_array::>(); + test_case_sensitive_regexp_count_array::>(); + test_case_sensitive_regexp_count_array::(); + + test_case_sensitive_regexp_count_array_start::>(); + test_case_sensitive_regexp_count_array_start::>(); + test_case_sensitive_regexp_count_array_start::(); + + test_case_insensitive_regexp_count_array_flags::>(); + test_case_insensitive_regexp_count_array_flags::>(); + test_case_insensitive_regexp_count_array_flags::(); + + test_case_sensitive_regexp_count_start_scalar_complex::>( + ); + test_case_sensitive_regexp_count_start_scalar_complex::>( + ); + test_case_sensitive_regexp_count_start_scalar_complex::(); + + test_case_sensitive_regexp_count_array_complex::>(); + test_case_sensitive_regexp_count_array_complex::>(); + test_case_sensitive_regexp_count_array_complex::(); } - fn test_case_sensitive_regexp_count_scalar() { - let values = GenericStringArray::::from(vec![ - "", - "aabca", - "abcabc", - "abcAbcab", - "abcabcabc", - ]); - let regex = GenericStringArray::::from(vec!["abc"; 1]); + fn test_case_sensitive_regexp_count_scalar() + where + A: From> + Array + 'static, + { + let values = A::from(vec!["", "aabca", "abcabc", "abcAbcab", "abcabcabc"]); + let regex = A::from(vec!["abc"; 1]); + let start = Int64Array::from(vec![2]); - let expected = Int64Array::from(vec![0, 1, 2, 1, 3]); + let expected = Int64Array::from(vec![0, 1, 1, 0, 2]); - let re = regexp_count::(&[Arc::new(values), Arc::new(regex)]).unwrap(); + let re = regexp_count_func(&[Arc::new(values), Arc::new(regex), Arc::new(start)]) + .unwrap(); assert_eq!(re.as_ref(), &expected); } - fn test_case_sensitive_regexp_count_scalar_start() { - let values = GenericStringArray::::from(vec![ - "", - "aabca", - "abcabc", - "abcAbcab", - "abcabcabc", - ]); - let regex = GenericStringArray::::from(vec!["abc"; 1]); + fn test_case_sensitive_regexp_count_scalar_start() + where + A: From> + Array + 'static, + { + let values = A::from(vec!["", "aabca", "abcabc", "abcAbcab", "abcabcabc"]); + let regex = A::from(vec!["abc"; 1]); let start = Int64Array::from(vec![2]); let expected = Int64Array::from(vec![0, 1, 1, 0, 2]); - let re = regexp_count::(&[Arc::new(values), Arc::new(regex), Arc::new(start)]) + let re = regexp_count_func(&[Arc::new(values), Arc::new(regex), Arc::new(start)]) .unwrap(); assert_eq!(re.as_ref(), &expected); } - fn test_case_insensitive_regexp_count_scalar_flags() { - let values = GenericStringArray::::from(vec![ - "", - "aabca", - "abcabc", - "abcAbcab", - "abcabcabc", - ]); - let regex = GenericStringArray::::from(vec!["abc"; 1]); + fn test_case_insensitive_regexp_count_scalar_flags() + where + A: From> + Array + 'static, + { + let values = A::from(vec!["", "aabca", "abcabc", "abcAbcab", "abcabcabc"]); + let regex = A::from(vec!["abc"; 1]); let start = Int64Array::from(vec![1]); - let flags = GenericStringArray::::from(vec!["i"]); + let flags = A::from(vec!["i"]); let expected = Int64Array::from(vec![0, 1, 2, 2, 3]); - let re = regexp_count::(&[ + let re = regexp_count_func(&[ Arc::new(values), Arc::new(regex), Arc::new(start), @@ -576,55 +673,46 @@ mod tests { assert_eq!(re.as_ref(), &expected); } - fn test_case_sensitive_regexp_count_array() { - let values = GenericStringArray::::from(vec![ - "", - "aabca", - "abcabc", - "abcAbcab", - "abcabcAbc", - ]); - let regex = GenericStringArray::::from(vec!["", "abc", "a", "bc", "ab"]); + fn test_case_sensitive_regexp_count_array() + where + A: From> + Array + 'static, + { + let values = A::from(vec!["", "aabca", "abcabc", "abcAbcab", "abcabcAbc"]); + let regex = A::from(vec!["", "abc", "a", "bc", "ab"]); let expected = Int64Array::from(vec![0, 1, 2, 2, 2]); - let re = regexp_count::(&[Arc::new(values), Arc::new(regex)]).unwrap(); + let re = regexp_count_func(&[Arc::new(values), Arc::new(regex)]).unwrap(); assert_eq!(re.as_ref(), &expected); } - fn test_case_sensitive_regexp_count_array_start() { - let values = GenericStringArray::::from(vec![ - "", - "aAbca", - "abcabc", - "abcAbcab", - "abcabcAbc", - ]); - let regex = GenericStringArray::::from(vec!["", "abc", "a", "bc", "ab"]); + fn test_case_sensitive_regexp_count_array_start() + where + A: From> + Array + 'static, + { + let values = A::from(vec!["", "aAbca", "abcabc", "abcAbcab", "abcabcAbc"]); + let regex = A::from(vec!["", "abc", "a", "bc", "ab"]); let start = Int64Array::from(vec![1, 2, 3, 4, 5]); let expected = Int64Array::from(vec![0, 0, 1, 1, 0]); - let re = regexp_count::(&[Arc::new(values), Arc::new(regex), Arc::new(start)]) + let re = regexp_count_func(&[Arc::new(values), Arc::new(regex), Arc::new(start)]) .unwrap(); assert_eq!(re.as_ref(), &expected); } - fn test_case_insensitive_regexp_count_array_flags() { - let values = GenericStringArray::::from(vec![ - "", - "aAbca", - "abcabc", - "abcAbcab", - "abcabcAbc", - ]); - let regex = GenericStringArray::::from(vec!["", "abc", "a", "bc", "ab"]); + fn test_case_insensitive_regexp_count_array_flags() + where + A: From> + Array + 'static, + { + let values = A::from(vec!["", "aAbca", "abcabc", "abcAbcab", "abcabcAbc"]); + let regex = A::from(vec!["", "abc", "a", "bc", "ab"]); let start = Int64Array::from(vec![1]); - let flags = GenericStringArray::::from(vec!["", "i", "", "", "i"]); + let flags = A::from(vec!["", "i", "", "", "i"]); let expected = Int64Array::from(vec![0, 1, 2, 2, 3]); - let re = regexp_count::(&[ + let re = regexp_count_func(&[ Arc::new(values), Arc::new(regex), Arc::new(start), @@ -634,21 +722,18 @@ mod tests { assert_eq!(re.as_ref(), &expected); } - fn test_case_sensitive_regexp_count_start_scalar_complex() { - let values = GenericStringArray::::from(vec![ - "", - "aAbca", - "abcabc", - "abcAbcabc", - "abcabcAbc", - ]); - let regex = GenericStringArray::::from(vec!["", "abc", "a", "bc", "ab"]); + fn test_case_sensitive_regexp_count_start_scalar_complex() + where + A: From> + Array + 'static, + { + let values = A::from(vec!["", "aAbca", "abcabc", "abcAbcabc", "abcabcAbc"]); + let regex = A::from(vec!["", "abc", "a", "bc", "ab"]); let start = Int64Array::from(vec![5]); - let flags = GenericStringArray::::from(vec!["", "i", "", "", "i"]); + let flags = A::from(vec!["", "i", "", "", "i"]); let expected = Int64Array::from(vec![0, 0, 0, 2, 1]); - let re = regexp_count::(&[ + let re = regexp_count_func(&[ Arc::new(values), Arc::new(regex), Arc::new(start), @@ -658,21 +743,18 @@ mod tests { assert_eq!(re.as_ref(), &expected); } - fn test_case_sensitive_regexp_count_array_complex() { - let values = GenericStringArray::::from(vec![ - "", - "aAbca", - "abcabc", - "abcAbcab", - "abcabcAbc", - ]); - let regex = GenericStringArray::::from(vec!["", "abc", "a", "bc", "ab"]); + fn test_case_sensitive_regexp_count_array_complex() + where + A: From> + Array + 'static, + { + let values = A::from(vec!["", "aAbca", "abcabc", "abcAbcab", "abcabcAbc"]); + let regex = A::from(vec!["", "abc", "a", "bc", "ab"]); let start = Int64Array::from(vec![1, 2, 3, 4, 5]); - let flags = GenericStringArray::::from(vec!["", "i", "", "", "i"]); + let flags = A::from(vec!["", "i", "", "", "i"]); let expected = Int64Array::from(vec![0, 1, 1, 1, 1]); - let re = regexp_count::(&[ + let re = regexp_count_func(&[ Arc::new(values), Arc::new(regex), Arc::new(start), From 4eb7e6bcb6985cf9c344482d3821282974b2c29c Mon Sep 17 00:00:00 2001 From: Xin Li Date: Thu, 29 Aug 2024 12:15:32 +0800 Subject: [PATCH 11/12] Revert ci change --- .github/workflows/rust.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/rust.yml b/.github/workflows/rust.yml index 2f4b9eae59a2..809f3acd8374 100644 --- a/.github/workflows/rust.yml +++ b/.github/workflows/rust.yml @@ -569,9 +569,9 @@ jobs: # # To reproduce: # 1. Install the version of Rust that is failing. Example: - # rustup install 1.80.0 + # rustup install 1.76.0 # 2. Run the command that failed with that version. Example: - # cargo +1.80.0 check -p datafusion + # cargo +1.76.0 check -p datafusion # # To resolve, either: # 1. Change your code to use older Rust features, From cb135564be724761be9922eb7a34513595dd46c2 Mon Sep 17 00:00:00 2001 From: Xin Li Date: Thu, 29 Aug 2024 13:44:22 +0800 Subject: [PATCH 12/12] Fix ci --- datafusion-examples/examples/regex_count.rs | 33 --------------------- datafusion/functions/Cargo.toml | 2 +- 2 files changed, 1 insertion(+), 34 deletions(-) delete mode 100644 datafusion-examples/examples/regex_count.rs diff --git a/datafusion-examples/examples/regex_count.rs b/datafusion-examples/examples/regex_count.rs deleted file mode 100644 index 93ec705ff6cc..000000000000 --- a/datafusion-examples/examples/regex_count.rs +++ /dev/null @@ -1,33 +0,0 @@ -use datafusion::common::Result; -use datafusion::prelude::{CsvReadOptions, SessionContext}; - -#[tokio::main] -async fn main() -> Result<()> { - let ctx = SessionContext::new(); - ctx.register_csv( - "examples", - "../../datafusion/physical-expr/tests/data/regex.csv", - CsvReadOptions::new(), - ) - .await?; - - // - // - //regexp_count examples - // - // - // regexp_count format is (regexp_count(text, regex[, flags]) - // - - // use sql and regexp_count function to test col 'values', against patterns in col 'patterns' without flags - let result = ctx - .sql("select regexp_count(values, patterns) from examples") - .await? - .collect() - .await?; - - println!("{:?}", result); - assert_eq!(result.len(), 1); - - Ok(()) -} diff --git a/datafusion/functions/Cargo.toml b/datafusion/functions/Cargo.toml index 337379a74670..c793b8e6464e 100644 --- a/datafusion/functions/Cargo.toml +++ b/datafusion/functions/Cargo.toml @@ -52,7 +52,7 @@ encoding_expressions = ["base64", "hex"] # enable math functions math_expressions = [] # enable regular expressions -regex_expressions = ["regex"] +regex_expressions = ["regex", "string_expressions"] # enable string functions string_expressions = ["uuid"] # enable unicode functions