Skip to content

Commit

Permalink
Initial support for regex_replace on StringViewArray (#11556)
Browse files Browse the repository at this point in the history
* initial support for string view regex

* update tests
  • Loading branch information
XiangpengHao authored Jul 22, 2024
1 parent efcf5c6 commit 34d42bc
Showing 1 changed file with 151 additions and 66 deletions.
217 changes: 151 additions & 66 deletions datafusion/functions/src/regex/regexpreplace.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,14 @@

//! Regx expressions
use arrow::array::new_null_array;
use arrow::array::ArrayAccessor;
use arrow::array::ArrayDataBuilder;
use arrow::array::BufferBuilder;
use arrow::array::GenericStringArray;
use arrow::array::StringViewBuilder;
use arrow::array::{Array, ArrayRef, OffsetSizeTrait};
use arrow::datatypes::DataType;
use datafusion_common::cast::as_string_view_array;
use datafusion_common::exec_err;
use datafusion_common::plan_err;
use datafusion_common::ScalarValue;
Expand Down Expand Up @@ -54,6 +57,7 @@ impl RegexpReplaceFunc {
signature: Signature::one_of(
vec![
Exact(vec![Utf8, Utf8, Utf8]),
Exact(vec![Utf8View, Utf8, Utf8]),
Exact(vec![Utf8, Utf8, Utf8, Utf8]),
],
Volatility::Immutable,
Expand All @@ -80,6 +84,7 @@ impl ScalarUDFImpl for RegexpReplaceFunc {
Ok(match &arg_types[0] {
LargeUtf8 | LargeBinary => LargeUtf8,
Utf8 | Binary => Utf8,
Utf8View | BinaryView => Utf8View,
Null => Null,
Dictionary(_, t) => match **t {
LargeUtf8 | LargeBinary => LargeUtf8,
Expand Down Expand Up @@ -118,15 +123,18 @@ impl ScalarUDFImpl for RegexpReplaceFunc {
}
}
}

fn regexp_replace_func(args: &[ColumnarValue]) -> Result<ArrayRef> {
match args[0].data_type() {
DataType::Utf8 => specialize_regexp_replace::<i32>(args),
DataType::LargeUtf8 => specialize_regexp_replace::<i64>(args),
DataType::Utf8View => specialize_regexp_replace::<i32>(args),
other => {
internal_err!("Unsupported data type {other:?} for function regexp_replace")
}
}
}

/// replace POSIX capture groups (like \1) with Rust Regex group (like ${1})
/// used by regexp_replace
fn regex_replace_posix_groups(replacement: &str) -> String {
Expand Down Expand Up @@ -280,8 +288,8 @@ pub fn regexp_replace<T: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef>
}
}

fn _regexp_replace_early_abort<T: OffsetSizeTrait>(
input_array: &GenericStringArray<T>,
fn _regexp_replace_early_abort<T: ArrayAccessor>(
input_array: T,
sz: usize,
) -> Result<ArrayRef> {
// Mimicking the existing behavior of regexp_replace, if any of the scalar arguments
Expand All @@ -290,13 +298,14 @@ fn _regexp_replace_early_abort<T: OffsetSizeTrait>(
// Also acts like an early abort mechanism when the input array is empty.
Ok(new_null_array(input_array.data_type(), sz))
}

/// Get the first argument from the given string array.
///
/// Note: If the array is empty or the first argument is null,
/// then calls the given early abort function.
macro_rules! fetch_string_arg {
($ARG:expr, $NAME:expr, $T:ident, $EARLY_ABORT:ident, $ARRAY_SIZE:expr) => {{
let array = as_generic_string_array::<T>($ARG)?;
let array = as_generic_string_array::<$T>($ARG)?;
if array.len() == 0 || array.is_null(0) {
return $EARLY_ABORT(array, $ARRAY_SIZE);
} else {
Expand All @@ -313,25 +322,24 @@ macro_rules! fetch_string_arg {
fn _regexp_replace_static_pattern_replace<T: OffsetSizeTrait>(
args: &[ArrayRef],
) -> Result<ArrayRef> {
let string_array = as_generic_string_array::<T>(&args[0])?;
let array_size = string_array.len();
let array_size = args[0].len();
let pattern = fetch_string_arg!(
&args[1],
"pattern",
T,
i32,
_regexp_replace_early_abort,
array_size
);
let replacement = fetch_string_arg!(
&args[2],
"replacement",
T,
i32,
_regexp_replace_early_abort,
array_size
);
let flags = match args.len() {
3 => None,
4 => Some(fetch_string_arg!(&args[3], "flags", T, _regexp_replace_early_abort, array_size)),
4 => Some(fetch_string_arg!(&args[3], "flags", i32, _regexp_replace_early_abort, array_size)),
other => {
return exec_err!(
"regexp_replace was called with {other} arguments. It requires at least 3 and at most 4."
Expand All @@ -358,32 +366,61 @@ fn _regexp_replace_static_pattern_replace<T: OffsetSizeTrait>(
// with rust ones.
let replacement = regex_replace_posix_groups(replacement);

// We are going to create the underlying string buffer from its parts
// to be able to re-use the existing null buffer for sparse arrays.
let mut vals = BufferBuilder::<u8>::new({
let offsets = string_array.value_offsets();
(offsets[string_array.len()] - offsets[0])
.to_usize()
.expect("Failed to convert usize")
});
let mut new_offsets = BufferBuilder::<T>::new(string_array.len() + 1);
new_offsets.append(T::zero());

string_array.iter().for_each(|val| {
if let Some(val) = val {
let result = re.replacen(val, limit, replacement.as_str());
vals.append_slice(result.as_bytes());
let string_array_type = args[0].data_type();
match string_array_type {
DataType::Utf8 | DataType::LargeUtf8 => {
let string_array = as_generic_string_array::<T>(&args[0])?;

// We are going to create the underlying string buffer from its parts
// to be able to re-use the existing null buffer for sparse arrays.
let mut vals = BufferBuilder::<u8>::new({
let offsets = string_array.value_offsets();
(offsets[string_array.len()] - offsets[0])
.to_usize()
.unwrap()
});
let mut new_offsets = BufferBuilder::<T>::new(string_array.len() + 1);
new_offsets.append(T::zero());

string_array.iter().for_each(|val| {
if let Some(val) = val {
let result = re.replacen(val, limit, replacement.as_str());
vals.append_slice(result.as_bytes());
}
new_offsets.append(T::from_usize(vals.len()).unwrap());
});

let data = ArrayDataBuilder::new(GenericStringArray::<T>::DATA_TYPE)
.len(string_array.len())
.nulls(string_array.nulls().cloned())
.buffers(vec![new_offsets.finish(), vals.finish()])
.build()?;
let result_array = GenericStringArray::<T>::from(data);
Ok(Arc::new(result_array) as ArrayRef)
}
new_offsets.append(T::from_usize(vals.len()).unwrap());
});

let data = ArrayDataBuilder::new(GenericStringArray::<T>::DATA_TYPE)
.len(string_array.len())
.nulls(string_array.nulls().cloned())
.buffers(vec![new_offsets.finish(), vals.finish()])
.build()?;
let result_array = GenericStringArray::<T>::from(data);
Ok(Arc::new(result_array) as ArrayRef)
DataType::Utf8View => {
let string_view_array = as_string_view_array(&args[0])?;

let mut builder = StringViewBuilder::with_capacity(string_view_array.len())
.with_block_size(1024 * 1024 * 2);

for val in string_view_array.iter() {
if let Some(val) = val {
let result = re.replacen(val, limit, replacement.as_str());
builder.append_value(result);
} else {
builder.append_null();
}
}

let result = builder.finish();
Ok(Arc::new(result) as ArrayRef)
}
_ => unreachable!(
"Invalid data type for regexp_replace: {}",
string_array_type
),
}
}

/// Determine which implementation of the regexp_replace to use based
Expand Down Expand Up @@ -469,43 +506,91 @@ mod tests {

use super::*;

#[test]
fn test_static_pattern_regexp_replace() {
let values = StringArray::from(vec!["abc"; 5]);
let patterns = StringArray::from(vec!["b"; 5]);
let replacements = StringArray::from(vec!["foo"; 5]);
let expected = StringArray::from(vec!["afooc"; 5]);

let re = _regexp_replace_static_pattern_replace::<i32>(&[
Arc::new(values),
Arc::new(patterns),
Arc::new(replacements),
])
.unwrap();

assert_eq!(re.as_ref(), &expected);
macro_rules! static_pattern_regexp_replace {
($name:ident, $T:ty, $O:ty) => {
#[test]
fn $name() {
let values = vec!["abc", "acd", "abcd1234567890123", "123456789012abc"];
let patterns = vec!["b"; 4];
let replacement = vec!["foo"; 4];
let expected =
vec!["afooc", "acd", "afoocd1234567890123", "123456789012afooc"];

let values = <$T>::from(values);
let patterns = StringArray::from(patterns);
let replacements = StringArray::from(replacement);
let expected = <$T>::from(expected);

let re = _regexp_replace_static_pattern_replace::<$O>(&[
Arc::new(values),
Arc::new(patterns),
Arc::new(replacements),
])
.unwrap();

assert_eq!(re.as_ref(), &expected);
}
};
}

#[test]
fn test_static_pattern_regexp_replace_with_flags() {
let values = StringArray::from(vec!["abc", "ABC", "aBc", "AbC", "aBC"]);
let patterns = StringArray::from(vec!["b"; 5]);
let replacements = StringArray::from(vec!["foo"; 5]);
let flags = StringArray::from(vec!["i"; 5]);
let expected =
StringArray::from(vec!["afooc", "AfooC", "afooc", "AfooC", "afooC"]);

let re = _regexp_replace_static_pattern_replace::<i32>(&[
Arc::new(values),
Arc::new(patterns),
Arc::new(replacements),
Arc::new(flags),
])
.unwrap();

assert_eq!(re.as_ref(), &expected);
static_pattern_regexp_replace!(string_array, StringArray, i32);
static_pattern_regexp_replace!(string_view_array, StringViewArray, i32);
static_pattern_regexp_replace!(large_string_array, LargeStringArray, i64);

macro_rules! static_pattern_regexp_replace_with_flags {
($name:ident, $T:ty, $O: ty) => {
#[test]
fn $name() {
let values = vec![
"abc",
"aBc",
"acd",
"abcd1234567890123",
"aBcd1234567890123",
"123456789012abc",
"123456789012aBc",
];
let expected = vec![
"afooc",
"afooc",
"acd",
"afoocd1234567890123",
"afoocd1234567890123",
"123456789012afooc",
"123456789012afooc",
];

let values = <$T>::from(values);
let patterns = StringArray::from(vec!["b"; 7]);
let replacements = StringArray::from(vec!["foo"; 7]);
let flags = StringArray::from(vec!["i"; 5]);
let expected = <$T>::from(expected);

let re = _regexp_replace_static_pattern_replace::<$O>(&[
Arc::new(values),
Arc::new(patterns),
Arc::new(replacements),
Arc::new(flags),
])
.unwrap();

assert_eq!(re.as_ref(), &expected);
}
};
}

static_pattern_regexp_replace_with_flags!(string_array_with_flags, StringArray, i32);
static_pattern_regexp_replace_with_flags!(
string_view_array_with_flags,
StringViewArray,
i32
);
static_pattern_regexp_replace_with_flags!(
large_string_array_with_flags,
LargeStringArray,
i64
);

#[test]
fn test_static_pattern_regexp_replace_early_abort() {
let values = StringArray::from(vec!["abc"; 5]);
Expand Down

0 comments on commit 34d42bc

Please sign in to comment.