Skip to content

Commit

Permalink
v2 impl
Browse files Browse the repository at this point in the history
  • Loading branch information
xinlifoobar committed Aug 17, 2024
1 parent 042d725 commit 894e797
Show file tree
Hide file tree
Showing 2 changed files with 155 additions and 19 deletions.
29 changes: 29 additions & 0 deletions arrow-array/src/array/byte_view_array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -261,6 +261,21 @@ impl<T: ByteViewType + ?Sized> GenericByteViewArray<T> {
unsafe { self.value_unchecked(i) }
}

/// Returns the inline view data at index `i`
pub unsafe fn prefix_bytes_unchecked(&self, prefix_len: usize, idx: usize) -> &[u8] {
let v = self.views.get_unchecked(idx);
let len = (*v as u32) as usize;

if prefix_len <= 4 || (prefix_len <= 12 && len <= 12) {
Self::inline_value(v, prefix_len)
} else {
let view = ByteView::from(*v);
let data = self.buffers.get_unchecked(view.buffer_index as usize);
let offset = view.offset as usize;
data.get_unchecked(offset..offset + prefix_len)
}
}

/// Returns the element at index `i`
/// # Safety
/// Caller is responsible for ensuring that the index is within the bounds of the array
Expand All @@ -278,6 +293,20 @@ impl<T: ByteViewType + ?Sized> GenericByteViewArray<T> {
T::Native::from_bytes_unchecked(b)
}

/// Returns the bytes at index `i`
pub unsafe fn bytes_unchecked(&self, idx: usize) -> &[u8] {
let v = self.views.get_unchecked(idx);
let len = *v as u32;
if len <= 12 {
Self::inline_value(v, len as usize)
} else {
let view = ByteView::from(*v);
let data = self.buffers.get_unchecked(view.buffer_index as usize);
let offset = view.offset as usize;
data.get_unchecked(offset..offset + len as usize)
}
}

/// Returns the inline value of the view.
///
/// # Safety
Expand Down
145 changes: 126 additions & 19 deletions arrow-string/src/predicate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@
// specific language governing permissions and limitations
// under the License.

use arrow_array::{ArrayAccessor, BooleanArray};
use arrow_array::{Array, ArrayAccessor, BooleanArray, StringViewArray};
use arrow_buffer::BooleanBuffer;
use arrow_schema::ArrowError;
use memchr::memchr2;
use memchr::memmem::Finder;
Expand Down Expand Up @@ -111,24 +112,130 @@ impl<'a> Predicate<'a> {
Predicate::Eq(v) => BooleanArray::from_unary(array, |haystack| {
(haystack.len() == v.len() && haystack == *v) != negate
}),
Predicate::IEqAscii(v) => BooleanArray::from_unary(array, |haystack| {
haystack.eq_ignore_ascii_case(v) != negate
}),
Predicate::Contains(finder) => BooleanArray::from_unary(array, |haystack| {
finder.find(haystack.as_bytes()).is_some() != negate
}),
Predicate::StartsWith(v) => BooleanArray::from_unary(array, |haystack| {
starts_with(haystack, v, equals_kernel) != negate
}),
Predicate::IStartsWithAscii(v) => BooleanArray::from_unary(array, |haystack| {
starts_with(haystack, v, equals_ignore_ascii_case_kernel) != negate
}),
Predicate::EndsWith(v) => BooleanArray::from_unary(array, |haystack| {
ends_with(haystack, v, equals_kernel) != negate
}),
Predicate::IEndsWithAscii(v) => BooleanArray::from_unary(array, |haystack| {
ends_with(haystack, v, equals_ignore_ascii_case_kernel) != negate
}),
Predicate::IEqAscii(v) => {
if let Some(string_view_array) = array.as_any().downcast_ref::<StringViewArray>() {
let neddle_bytes = v.as_bytes();
let null_buffer = string_view_array.logical_nulls();
let boolean_buffer =
BooleanBuffer::collect_bool(string_view_array.len(), |i| {
unsafe { string_view_array.bytes_unchecked(i) }
.eq_ignore_ascii_case(neddle_bytes)
!= negate
});

BooleanArray::new(boolean_buffer, null_buffer)
} else {
BooleanArray::from_unary(array, |haystack| {
haystack.eq_ignore_ascii_case(v) != negate
})
}
}
Predicate::Contains(finder) => {
if let Some(string_view_array) = array.as_any().downcast_ref::<StringViewArray>() {
let null_buffer = string_view_array.logical_nulls();
let boolean_buffer =
BooleanBuffer::collect_bool(string_view_array.len(), |i| {
finder
.find(unsafe { string_view_array.bytes_unchecked(i) })
.is_some()
!= negate
});

BooleanArray::new(boolean_buffer, null_buffer)
} else {
BooleanArray::from_unary(array, |haystack| {
finder.find(haystack.as_bytes()).is_some() != negate
})
}
}
Predicate::StartsWith(v) => {
if let Some(string_view_array) = array.as_any().downcast_ref::<StringViewArray>() {
let needle_bytes = v.as_bytes();
let needle_len = needle_bytes.len();
let null_buffer = string_view_array.logical_nulls();
let boolean_buffer =
BooleanBuffer::collect_bool(string_view_array.len(), |i| {
zip(
unsafe { string_view_array.prefix_bytes_unchecked(needle_len, i) },
needle_bytes,
)
.all(equals_kernel)
});

BooleanArray::new(boolean_buffer, null_buffer)
} else {
BooleanArray::from_unary(array, |haystack| {
starts_with(haystack, v, equals_kernel) != negate
})
}
}
Predicate::IStartsWithAscii(v) => {
if let Some(string_view_array) = array.as_any().downcast_ref::<StringViewArray>() {
let needle_bytes = v.as_bytes();
let needle_len = needle_bytes.len();
let null_buffer = string_view_array.logical_nulls();
let boolean_buffer =
BooleanBuffer::collect_bool(string_view_array.len(), |i| {
zip(
unsafe { string_view_array.prefix_bytes_unchecked(needle_len, i) },
needle_bytes,
)
.all(equals_ignore_ascii_case_kernel)
});

BooleanArray::new(boolean_buffer, null_buffer)
} else {
BooleanArray::from_unary(array, |haystack| {
starts_with(haystack, v, equals_ignore_ascii_case_kernel) != negate
})
}
}
Predicate::EndsWith(v) => {
if let Some(string_view_array) = array.as_any().downcast_ref::<StringViewArray>() {
let needle_bytes = v.as_bytes();
let needle_len = needle_bytes.len();
let null_buffer = string_view_array.logical_nulls();
let boolean_buffer =
BooleanBuffer::collect_bool(string_view_array.len(), |i| {
zip(
unsafe { string_view_array.prefix_bytes_unchecked(needle_len, i) }
.iter()
.rev(),
needle_bytes.iter().rev(),
)
.all(equals_kernel)
});

BooleanArray::new(boolean_buffer, null_buffer)
} else {
BooleanArray::from_unary(array, |haystack| {
ends_with(haystack, v, equals_kernel) != negate
})
}
}
Predicate::IEndsWithAscii(v) => {
if let Some(string_view_array) = array.as_any().downcast_ref::<StringViewArray>() {
let needle_bytes = v.as_bytes();
let needle_len = needle_bytes.len();
let null_buffer = string_view_array.logical_nulls();
let boolean_buffer =
BooleanBuffer::collect_bool(string_view_array.len(), |i| {
zip(
unsafe { string_view_array.prefix_bytes_unchecked(needle_len, i) }
.iter()
.rev(),
needle_bytes.iter().rev(),
)
.all(equals_ignore_ascii_case_kernel)
});

BooleanArray::new(boolean_buffer, null_buffer)
} else {
BooleanArray::from_unary(array, |haystack| {
ends_with(haystack, v, equals_ignore_ascii_case_kernel) != negate
})
}
}
Predicate::Regex(v) => {
BooleanArray::from_unary(array, |haystack| v.is_match(haystack) != negate)
}
Expand Down

0 comments on commit 894e797

Please sign in to comment.