From f72c11f9c14f87760fd2db0ab7b7beb3ea7c0910 Mon Sep 17 00:00:00 2001 From: Xin Li Date: Tue, 20 Aug 2024 21:34:41 +0800 Subject: [PATCH] Implement regexp_ccount --- datafusion/functions/src/regex/mod.rs | 17 +- datafusion/functions/src/regex/regexpcount.rs | 561 ++++++++++++++++++ datafusion/sqllogictest/test_files/regexp.slt | 175 +++++- 3 files changed, 740 insertions(+), 13 deletions(-) create mode 100644 datafusion/functions/src/regex/regexpcount.rs diff --git a/datafusion/functions/src/regex/mod.rs b/datafusion/functions/src/regex/mod.rs index 4ac162290ddb..a0b4d0fe6960 100644 --- a/datafusion/functions/src/regex/mod.rs +++ b/datafusion/functions/src/regex/mod.rs @@ -17,10 +17,12 @@ //! "regx" DataFusion functions +pub mod regexpcount; pub mod regexplike; pub mod regexpmatch; pub mod regexpreplace; // create UDFs +make_udf_function!(regexpcount::RegexpCountFunc, REGEXP_COUNT, regexp_count); make_udf_function!(regexpmatch::RegexpMatchFunc, REGEXP_MATCH, regexp_match); make_udf_function!(regexplike::RegexpLikeFunc, REGEXP_LIKE, regexp_like); make_udf_function!( @@ -32,6 +34,14 @@ make_udf_function!( pub mod expr_fn { use datafusion_expr::Expr; + pub fn regexp_count(values: Expr, regex: Expr, flags: Option) -> Expr { + let mut args = vec![values, regex]; + if let Some(flags) = flags { + args.push(flags); + }; + super::regexp_count().call(args) + } + /// Returns a list of regular expression matches in a string. pub fn regexp_match(values: Expr, regex: Expr, flags: Option) -> Expr { let mut args = vec![values, regex]; @@ -67,5 +77,10 @@ pub mod expr_fn { /// Returns all DataFusion functions defined in this package pub fn functions() -> Vec> { - vec![regexp_match(), regexp_like(), regexp_replace()] + vec![ + regexp_count(), + regexp_match(), + regexp_like(), + regexp_replace(), + ] } diff --git a/datafusion/functions/src/regex/regexpcount.rs b/datafusion/functions/src/regex/regexpcount.rs new file mode 100644 index 000000000000..103f934aa903 --- /dev/null +++ b/datafusion/functions/src/regex/regexpcount.rs @@ -0,0 +1,561 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::iter::repeat; +use std::sync::Arc; + +use arrow::array::{ + Array, ArrayRef, AsArray, Datum, GenericStringArray, Int64Array, OffsetSizeTrait, + Scalar, +}; +use arrow::datatypes::DataType::{self, Int64, LargeUtf8, Utf8}; +use arrow::datatypes::Int64Type; +use arrow::error::ArrowError; +use datafusion_common::cast::{as_generic_string_array, as_primitive_array}; +use datafusion_common::{ + arrow_err, exec_err, internal_err, DataFusionError, Result, ScalarValue, +}; +use datafusion_expr::TypeSignature::Exact; +use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; +use itertools::izip; +use regex::Regex; + +/// regexp_count(string, pattern [, start [, flags ]]) -> int +#[derive(Debug)] +pub struct RegexpCountFunc { + signature: Signature, +} + +impl Default for RegexpCountFunc { + fn default() -> Self { + Self::new() + } +} + +impl RegexpCountFunc { + pub fn new() -> Self { + Self { + signature: Signature::one_of( + vec![ + Exact(vec![Utf8, Utf8]), + Exact(vec![Utf8, Utf8, Int64]), + Exact(vec![Utf8, Utf8, Int64, Utf8]), + Exact(vec![Utf8, Utf8, Int64, LargeUtf8]), + Exact(vec![LargeUtf8, LargeUtf8]), + Exact(vec![LargeUtf8, LargeUtf8, Int64]), + Exact(vec![LargeUtf8, LargeUtf8, Int64, Utf8]), + Exact(vec![LargeUtf8, LargeUtf8, Int64, LargeUtf8]), + ], + Volatility::Immutable, + ), + } + } +} + +impl ScalarUDFImpl for RegexpCountFunc { + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn name(&self) -> &str { + "regexp_count" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + use DataType::*; + + Ok(match &arg_types[0] { + Null => Null, + _ => Int64, + }) + } + + fn invoke(&self, args: &[datafusion_expr::ColumnarValue]) -> Result { + let len = args + .iter() + .fold(Option::::None, |acc, arg| match arg { + ColumnarValue::Scalar(_) => acc, + ColumnarValue::Array(a) => Some(a.len()), + }); + + let is_scalar = len.is_none(); + let inferred_length = len.unwrap_or(1); + let args = args + .iter() + .map(|arg| arg.clone().into_array(inferred_length)) + .collect::>>()?; + + let result = regexp_count_func(&args); + if is_scalar { + // If all inputs are scalar, keeps output as scalar + let result = result.and_then(|arr| ScalarValue::try_from_array(&arr, 0)); + result.map(ColumnarValue::Scalar) + } else { + result.map(ColumnarValue::Array) + } + } +} + +fn regexp_count_func(args: &[ArrayRef]) -> Result { + match args[0].data_type() { + DataType::Utf8 => regexp_count::(args), + DataType::LargeUtf8 => regexp_count::(args), + other => { + internal_err!("Unsupported data type {other:?} for function regexp_like") + } + } +} + +pub fn regexp_count(args: &[ArrayRef]) -> Result { + let arg_len = args.len(); + + match arg_len { + 2..=4 => { + let values = as_generic_string_array::(&args[0])?; + let regex = as_generic_string_array::(&args[1])?; + let regex_datum: &dyn Datum = if regex.len() != 1 { + regex + } else { + &Scalar::new(regex) + }; + let start_scalar: Scalar<&Int64Array>; + let start_array_datum: Option<&dyn Datum> = if arg_len > 2 { + let start_array = as_primitive_array::(&args[2])?; + if start_array.len() != 1 { + Some(start_array as &dyn Datum) + } else { + start_scalar = Scalar::new(start_array); + Some(&start_scalar as &dyn Datum) + } + } else { + None + }; + + let flags_scalar: Scalar<&GenericStringArray>; + let flags_array_datum: Option<&dyn Datum> = if arg_len > 3 { + let flags_array = as_generic_string_array::(&args[3])?; + if flags_array.len() != 1 { + Some(flags_array as &dyn Datum) + } else { + flags_scalar = Scalar::new(flags_array); + Some(&flags_scalar as &dyn Datum) + } + } else { + None + }; + + Ok(regexp_array_count( + values, + regex_datum, + start_array_datum, + flags_array_datum, + ) + .map(|x| Arc::new(x) as ArrayRef)?) + } + other => { + exec_err!( + "regexp_count was called with {other} arguments. It requires at least 2 and at most 4." + ) + } + } +} + +pub fn regexp_array_count( + array: &GenericStringArray, + regex_array: &dyn Datum, + start_array: Option<&dyn Datum>, + flags_array: Option<&dyn Datum>, +) -> Result { + let (rhs, is_rhs_scalar) = regex_array.get(); + + if array.data_type() != rhs.data_type() { + return arrow_err!(ArrowError::ComputeError( + "regexp_count() requires pattern to be either Utf8 or LargeUtf8".to_string(), + )); + } + + if !is_rhs_scalar && array.len() != rhs.len() { + return arrow_err!( + ArrowError::ComputeError( + "regexp_count() requires pattern to be either a scalar or the same length as the input array".to_string(), + ) + ); + } + + let (starts, is_starts_scalar) = match start_array { + Some(starts) => starts.get(), + None => (&Int64Array::from(vec![1]) as &dyn Array, true), + }; + + if *starts.data_type() != Int64 { + return arrow_err!(ArrowError::ComputeError( + "regexp_count() requires start to be Int64".to_string(), + )); + } + + if !is_starts_scalar && array.len() != starts.len() { + return arrow_err!( + ArrowError::ComputeError( + "regexp_count() requires start to be either a scalar or the same length as the input array".to_string(), + ) + ); + } + + let (flags, is_flags_scalar) = match flags_array { + Some(flags) => flags.get(), + None => ( + &GenericStringArray::::from(vec![None as Option<&str>]) as &dyn Array, + true, + ), + }; + + if !is_flags_scalar && array.len() != flags.len() { + return arrow_err!( + ArrowError::ComputeError( + "regexp_count() requires flags to be either a scalar or the same length as the input array".to_string(), + ) + ); + } + + let regex_iter: Box>> = if is_rhs_scalar { + let regex: &arrow::array::GenericByteArray< + arrow::datatypes::GenericStringType, + > = rhs.as_string::(); + let regex = regex.is_valid(0).then(|| regex.value(0)); + if regex.is_none() { + return Ok(Int64Array::from( + repeat(0).take(array.len()).collect::>(), + )); + } + + Box::new(repeat(regex)) + } else { + Box::new(rhs.as_string::().iter()) + }; + + let start_iter: Box> = if is_starts_scalar { + let start = starts.as_primitive::(); + let start = start.is_valid(0).then(|| start.value(0)); + Box::new(repeat(start.unwrap_or(1))) + } else { + Box::new( + starts + .as_primitive::() + .iter() + .map(|x| x.unwrap_or(1)), + ) + }; + + let flags_iter: Box>> = if is_flags_scalar { + let flags = flags.as_string::(); + let flags = flags + .is_valid(0) + .then(|| flags.value(0)) + .map(|x| { + if x.contains('g') { + return arrow_err!(ArrowError::ComputeError( + "regexp_count() does not support global flag".to_string(), + )); + } + Ok(x) + }) + .transpose()?; + + Box::new(repeat(flags)) + } else { + Box::new(flags.as_string::().iter()) + }; + + regex_iter_count(array.iter(), regex_iter, start_iter, flags_iter) +} + +fn regex_iter_count<'a>( + array: impl Iterator>, + regex: impl Iterator>, + start: impl Iterator, + flags: impl Iterator>, +) -> Result { + Ok(Int64Array::from( + izip!(array, regex, start, flags) + .map(|(array, regex, start, flags)| { + if array.is_none() || regex.is_none() { + return Ok(0); + } + + let regex = regex.unwrap(); + if regex.is_empty() { + return Ok(0); + } + + if regex.contains('g') { + return Err(ArrowError::ComputeError( + "regexp_count() does not support global flag".to_string(), + )); + } + + if start <= 0 { + return Err(ArrowError::ComputeError( + "regexp_count() requires start to be 1 based".to_string(), + )); + } + + let array = array.unwrap(); + let start = start as usize; + if start > array.len() { + return Ok(0); + } + + let pattern = if let Some(Some(flags)) = + flags.map(|x| if x.is_empty() { None } else { Some(x) }) + { + format!("(?{flags}){regex}") + } else { + regex.to_string() + }; + + let Ok(re) = Regex::new(pattern.as_str()) else { + return Err(ArrowError::ComputeError(format!( + "Regular expression did not compile: {}", + pattern + ))); + }; + + Ok(re + .find_iter(&array.chars().skip(start - 1).collect::()) + .count() as i64) + }) + .collect::, ArrowError>>()?, + )) +} + +#[cfg(test)] +mod tests { + use crate::regex::regexpcount::regexp_count; + use arrow::array::{ArrayRef, GenericStringArray, Int64Array, OffsetSizeTrait}; + use std::sync::Arc; + + #[test] + fn test_regexp_count() { + test_case_sensitive_regexp_count_scalar::(); + test_case_sensitive_regexp_count_scalar::(); + + test_case_sensitive_regexp_count_scalar_start::(); + test_case_sensitive_regexp_count_scalar_start::(); + + test_case_insensitive_regexp_count_scalar_flags::(); + test_case_insensitive_regexp_count_scalar_flags::(); + + test_case_sensitive_regexp_count_array::(); + test_case_sensitive_regexp_count_array::(); + + test_case_sensitive_regexp_count_array_start::(); + test_case_sensitive_regexp_count_array_start::(); + + test_case_insensitive_regexp_count_array_flags::(); + test_case_insensitive_regexp_count_array_flags::(); + + test_case_sensitive_regexp_count_start_scalar_complex::(); + test_case_sensitive_regexp_count_start_scalar_complex::(); + + test_case_sensitive_regexp_count_array_complex::(); + test_case_sensitive_regexp_count_array_complex::(); + } + + fn test_case_sensitive_regexp_count_scalar() { + let values = GenericStringArray::::from(vec![ + "", + "aabca", + "abcabc", + "abcAbcab", + "abcabcabc", + ]); + let regex = GenericStringArray::::from(vec!["abc"; 1]); + + let expected = Int64Array::from(vec![0, 1, 2, 1, 3]); + + let re = regexp_count::(&[ + Arc::new(values) as ArrayRef, + Arc::new(regex) as ArrayRef, + ]) + .unwrap(); + assert_eq!(re.as_ref(), &expected); + } + + fn test_case_sensitive_regexp_count_scalar_start() { + let values = GenericStringArray::::from(vec![ + "", + "aabca", + "abcabc", + "abcAbcab", + "abcabcabc", + ]); + let regex = GenericStringArray::::from(vec!["abc"; 1]); + let start = Int64Array::from(vec![2]); + + let expected = Int64Array::from(vec![0, 1, 1, 0, 2]); + + let re = regexp_count::(&[ + Arc::new(values) as ArrayRef, + Arc::new(regex) as ArrayRef, + Arc::new(start) as ArrayRef, + ]) + .unwrap(); + assert_eq!(re.as_ref(), &expected); + } + + fn test_case_insensitive_regexp_count_scalar_flags() { + let values = GenericStringArray::::from(vec![ + "", + "aabca", + "abcabc", + "abcAbcab", + "abcabcabc", + ]); + let regex = GenericStringArray::::from(vec!["abc"; 1]); + let start = Int64Array::from(vec![1]); + let flags = GenericStringArray::::from(vec!["i"]); + + let expected = Int64Array::from(vec![0, 1, 2, 2, 3]); + + let re = regexp_count::(&[ + Arc::new(values) as ArrayRef, + Arc::new(regex) as ArrayRef, + Arc::new(start) as ArrayRef, + Arc::new(flags) as ArrayRef, + ]) + .unwrap(); + assert_eq!(re.as_ref(), &expected); + } + + fn test_case_sensitive_regexp_count_array() { + let values = GenericStringArray::::from(vec![ + "", + "aabca", + "abcabc", + "abcAbcab", + "abcabcAbc", + ]); + let regex = GenericStringArray::::from(vec!["", "abc", "a", "bc", "ab"]); + + let expected = Int64Array::from(vec![0, 1, 2, 2, 2]); + + let re = regexp_count::(&[ + Arc::new(values) as ArrayRef, + Arc::new(regex) as ArrayRef, + ]) + .unwrap(); + assert_eq!(re.as_ref(), &expected); + } + + fn test_case_sensitive_regexp_count_array_start() { + let values = GenericStringArray::::from(vec![ + "", + "aAbca", + "abcabc", + "abcAbcab", + "abcabcAbc", + ]); + let regex = GenericStringArray::::from(vec!["", "abc", "a", "bc", "ab"]); + let start = Int64Array::from(vec![1, 2, 3, 4, 5]); + + let expected = Int64Array::from(vec![0, 0, 1, 1, 0]); + + let re = regexp_count::(&[ + Arc::new(values) as ArrayRef, + Arc::new(regex) as ArrayRef, + Arc::new(start) as ArrayRef, + ]) + .unwrap(); + assert_eq!(re.as_ref(), &expected); + } + + fn test_case_insensitive_regexp_count_array_flags() { + let values = GenericStringArray::::from(vec![ + "", + "aAbca", + "abcabc", + "abcAbcab", + "abcabcAbc", + ]); + let regex = GenericStringArray::::from(vec!["", "abc", "a", "bc", "ab"]); + let start = Int64Array::from(vec![1]); + let flags = GenericStringArray::::from(vec!["", "i", "", "", "i"]); + + let expected = Int64Array::from(vec![0, 1, 2, 2, 3]); + + let re = regexp_count::(&[ + Arc::new(values) as ArrayRef, + Arc::new(regex) as ArrayRef, + Arc::new(start) as ArrayRef, + Arc::new(flags) as ArrayRef, + ]) + .unwrap(); + assert_eq!(re.as_ref(), &expected); + } + + fn test_case_sensitive_regexp_count_start_scalar_complex() { + let values = GenericStringArray::::from(vec![ + "", + "aAbca", + "abcabc", + "abcAbcabc", + "abcabcAbc", + ]); + let regex = GenericStringArray::::from(vec!["", "abc", "a", "bc", "ab"]); + let start = Int64Array::from(vec![5]); + let flags = GenericStringArray::::from(vec!["", "i", "", "", "i"]); + + let expected = Int64Array::from(vec![0, 0, 0, 2, 1]); + + let re = regexp_count::(&[ + Arc::new(values) as ArrayRef, + Arc::new(regex) as ArrayRef, + Arc::new(start) as ArrayRef, + Arc::new(flags) as ArrayRef, + ]) + .unwrap(); + assert_eq!(re.as_ref(), &expected); + } + + fn test_case_sensitive_regexp_count_array_complex() { + let values = GenericStringArray::::from(vec![ + "", + "aAbca", + "abcabc", + "abcAbcab", + "abcabcAbc", + ]); + let regex = GenericStringArray::::from(vec!["", "abc", "a", "bc", "ab"]); + let start = Int64Array::from(vec![1, 2, 3, 4, 5]); + let flags = GenericStringArray::::from(vec!["", "i", "", "", "i"]); + + let expected = Int64Array::from(vec![0, 1, 1, 1, 1]); + + let re = regexp_count::(&[ + Arc::new(values) as ArrayRef, + Arc::new(regex) as ArrayRef, + Arc::new(start) as ArrayRef, + Arc::new(flags) as ArrayRef, + ]) + .unwrap(); + assert_eq!(re.as_ref(), &expected); + } +} diff --git a/datafusion/sqllogictest/test_files/regexp.slt b/datafusion/sqllogictest/test_files/regexp.slt index 22322d79ccfe..aa1b350d3661 100644 --- a/datafusion/sqllogictest/test_files/regexp.slt +++ b/datafusion/sqllogictest/test_files/regexp.slt @@ -16,18 +16,18 @@ # under the License. statement ok -CREATE TABLE t (str varchar, pattern varchar, flags varchar) AS VALUES - ('abc', '^(a)', 'i'), - ('ABC', '^(A).*', 'i'), - ('aBc', '(b|d)', 'i'), - ('AbC', '(B|D)', null), - ('aBC', '^(b|c)', null), - ('4000', '\b4([1-9]\d\d|\d[1-9]\d|\d\d[1-9])\b', null), - ('4010', '\b4([1-9]\d\d|\d[1-9]\d|\d\d[1-9])\b', null), - ('Düsseldorf','[\p{Letter}-]+', null), - ('Москва', '[\p{L}-]+', null), - ('Köln', '[a-zA-Z]ö[a-zA-Z]{2}', null), - ('إسرائيل', '^\p{Arabic}+$', null); +CREATE TABLE t (str varchar, pattern varchar, start int, flags varchar) AS VALUES + ('abc', '^(a)', 1, 'i'), + ('ABC', '^(A).*', 1, 'i'), + ('aBc', '(b|d)', 1, 'i'), + ('AbC', '(B|D)', 2, null), + ('aBC', '^(b|c)', 3, null), + ('4000', '\b4([1-9]\d\d|\d[1-9]\d|\d\d[1-9])\b', 1, null), + ('4010', '\b4([1-9]\d\d|\d[1-9]\d|\d\d[1-9])\b', 2, null), + ('Düsseldorf','[\p{Letter}-]+', 3, null), + ('Москва', '[\p{L}-]+', 4, null), + ('Köln', '[a-zA-Z]ö[a-zA-Z]{2}', 1, null), + ('إسرائيل', '^\p{Arabic}+$', 2, null); # # regexp_like tests @@ -395,6 +395,157 @@ SELECT 'foo\nbar\nbaz' LIKE '%bar%'; ---- true + +# regexp_count tests + +# regexp_count tests from postgresql +# https://github.com/postgres/postgres/blob/56d23855c864b7384970724f3ad93fb0fc319e51/src/test/regress/sql/strings.sql#L226-L235 + +query I +SELECT regexp_count('123123123123123', '(12)3'); +---- +5 + +query I +SELECT regexp_count('123123123123', '123', 1); +---- +4 + +query I +SELECT regexp_count('123123123123', '123', 3); +---- +3 + +query I +SELECT regexp_count('123123123123', '123', 33); +---- +0 + +query I +SELECT regexp_count('ABCABCABCABC', 'Abc', 1, ''); +---- +0 + +query I +SELECT regexp_count('ABCABCABCABC', 'Abc', 1, 'i'); +---- +4 + +statement error +External error: query failed: DataFusion error: Arrow error: Compute error: regexp_count() requires start to be 1 based +SELECT regexp_count('123123123123', '123', 0); + +statement error +External error: query failed: DataFusion error: Arrow error: Compute error: regexp_count() requires start to be 1 based +SELECT regexp_count('123123123123', '123', -3); + +query I +SELECT regexp_count(str, '\w') from t; +---- +3 +3 +3 +3 +3 +4 +4 +10 +6 +4 +7 + +query I +SELECT regexp_count(str, '\w{2}', start) from t; +---- +1 +1 +1 +1 +0 +2 +1 +4 +1 +2 +3 + +query I +SELECT regexp_count(str, 'ab', 1, 'i') from t; +---- +1 +1 +1 +1 +1 +0 +0 +0 +0 +0 +0 + + +query I +SELECT regexp_count(str, pattern) from t; +---- +1 +1 +0 +0 +0 +0 +1 +1 +1 +1 +1 + +query I +SELECT regexp_count(str, pattern, start) from t; +---- +1 +1 +0 +0 +0 +0 +0 +1 +1 +1 +1 + +query I +SELECT regexp_count(str, pattern, start, flags) from t; +---- +1 +1 +1 +0 +0 +0 +0 +1 +1 +1 +1 + +# test type coercion +query I +SELECT regexp_count(arrow_cast(str, 'Utf8'), arrow_cast(pattern, 'LargeUtf8'), arrow_cast(start, 'Int32'), flags) from t; +---- +1 +1 +1 +0 +0 +0 +0 +1 +1 +1 +1 + statement ok drop table t;