diff --git a/datafusion/functions/Cargo.toml b/datafusion/functions/Cargo.toml index 765f5d865a60..97b9d6135ac8 100644 --- a/datafusion/functions/Cargo.toml +++ b/datafusion/functions/Cargo.toml @@ -255,6 +255,11 @@ harness = false name = "find_in_set" required-features = ["unicode_expressions"] +[[bench]] +harness = false +name = "contains" +required-features = ["string_expressions"] + [[bench]] harness = false name = "starts_with" diff --git a/datafusion/functions/benches/contains.rs b/datafusion/functions/benches/contains.rs new file mode 100644 index 000000000000..052eff38869d --- /dev/null +++ b/datafusion/functions/benches/contains.rs @@ -0,0 +1,185 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +extern crate criterion; + +use arrow::array::{StringArray, StringViewArray}; +use arrow::datatypes::{DataType, Field}; +use criterion::{Criterion, criterion_group, criterion_main}; +use datafusion_common::ScalarValue; +use datafusion_common::config::ConfigOptions; +use datafusion_expr::{ColumnarValue, ScalarFunctionArgs}; +use rand::distr::Alphanumeric; +use rand::prelude::StdRng; +use rand::{Rng, SeedableRng}; +use std::hint::black_box; +use std::sync::Arc; + +/// Generate a StringArray/StringViewArray with random ASCII strings +fn gen_string_array( + n_rows: usize, + str_len: usize, + is_string_view: bool, +) -> ColumnarValue { + let mut rng = StdRng::seed_from_u64(42); + let strings: Vec> = (0..n_rows) + .map(|_| { + let s: String = (&mut rng) + .sample_iter(&Alphanumeric) + .take(str_len) + .map(char::from) + .collect(); + Some(s) + }) + .collect(); + + if is_string_view { + ColumnarValue::Array(Arc::new(StringViewArray::from(strings))) + } else { + ColumnarValue::Array(Arc::new(StringArray::from(strings))) + } +} + +/// Generate a scalar search string +fn gen_scalar_search(search_str: &str, is_string_view: bool) -> ColumnarValue { + if is_string_view { + ColumnarValue::Scalar(ScalarValue::Utf8View(Some(search_str.to_string()))) + } else { + ColumnarValue::Scalar(ScalarValue::Utf8(Some(search_str.to_string()))) + } +} + +/// Generate an array of search strings (same string repeated) +fn gen_array_search( + search_str: &str, + n_rows: usize, + is_string_view: bool, +) -> ColumnarValue { + let strings: Vec> = + (0..n_rows).map(|_| Some(search_str.to_string())).collect(); + + if is_string_view { + ColumnarValue::Array(Arc::new(StringViewArray::from(strings))) + } else { + ColumnarValue::Array(Arc::new(StringArray::from(strings))) + } +} + +fn criterion_benchmark(c: &mut Criterion) { + let contains = datafusion_functions::string::contains(); + let n_rows = 8192; + let str_len = 128; + let search_str = "xyz"; // A pattern that likely won't be found + + // Benchmark: StringArray with scalar search (the optimized path) + let str_array = gen_string_array(n_rows, str_len, false); + let scalar_search = gen_scalar_search(search_str, false); + let arg_fields = vec![ + Field::new("a", DataType::Utf8, true).into(), + Field::new("b", DataType::Utf8, true).into(), + ]; + let return_field = Field::new("f", DataType::Boolean, true).into(); + let config_options = Arc::new(ConfigOptions::default()); + + c.bench_function("contains_StringArray_scalar_search", |b| { + b.iter(|| { + black_box(contains.invoke_with_args(ScalarFunctionArgs { + args: vec![str_array.clone(), scalar_search.clone()], + arg_fields: arg_fields.clone(), + number_rows: n_rows, + return_field: Arc::clone(&return_field), + config_options: Arc::clone(&config_options), + })) + }) + }); + + // Benchmark: StringArray with array search (for comparison) + let array_search = gen_array_search(search_str, n_rows, false); + c.bench_function("contains_StringArray_array_search", |b| { + b.iter(|| { + black_box(contains.invoke_with_args(ScalarFunctionArgs { + args: vec![str_array.clone(), array_search.clone()], + arg_fields: arg_fields.clone(), + number_rows: n_rows, + return_field: Arc::clone(&return_field), + config_options: Arc::clone(&config_options), + })) + }) + }); + + // Benchmark: StringViewArray with scalar search (the optimized path) + let str_view_array = gen_string_array(n_rows, str_len, true); + let scalar_search_view = gen_scalar_search(search_str, true); + let arg_fields_view = vec![ + Field::new("a", DataType::Utf8View, true).into(), + Field::new("b", DataType::Utf8View, true).into(), + ]; + + c.bench_function("contains_StringViewArray_scalar_search", |b| { + b.iter(|| { + black_box(contains.invoke_with_args(ScalarFunctionArgs { + args: vec![str_view_array.clone(), scalar_search_view.clone()], + arg_fields: arg_fields_view.clone(), + number_rows: n_rows, + return_field: Arc::clone(&return_field), + config_options: Arc::clone(&config_options), + })) + }) + }); + + // Benchmark: StringViewArray with array search (for comparison) + let array_search_view = gen_array_search(search_str, n_rows, true); + c.bench_function("contains_StringViewArray_array_search", |b| { + b.iter(|| { + black_box(contains.invoke_with_args(ScalarFunctionArgs { + args: vec![str_view_array.clone(), array_search_view.clone()], + arg_fields: arg_fields_view.clone(), + number_rows: n_rows, + return_field: Arc::clone(&return_field), + config_options: Arc::clone(&config_options), + })) + }) + }); + + // Benchmark different string lengths with scalar search + for str_len in [8, 32, 128, 512] { + let str_array = gen_string_array(n_rows, str_len, true); + let scalar_search = gen_scalar_search(search_str, true); + let arg_fields = vec![ + Field::new("a", DataType::Utf8View, true).into(), + Field::new("b", DataType::Utf8View, true).into(), + ]; + + c.bench_function( + &format!("contains_StringViewArray_scalar_strlen_{str_len}"), + |b| { + b.iter(|| { + black_box(contains.invoke_with_args(ScalarFunctionArgs { + args: vec![str_array.clone(), scalar_search.clone()], + arg_fields: arg_fields.clone(), + number_rows: n_rows, + return_field: Arc::clone(&return_field), + config_options: Arc::clone(&config_options), + })) + }) + }, + ); + } +} + +criterion_group!(benches, criterion_benchmark); +criterion_main!(benches); diff --git a/datafusion/functions/src/string/contains.rs b/datafusion/functions/src/string/contains.rs index b85e0ed7966a..b7ec95be444c 100644 --- a/datafusion/functions/src/string/contains.rs +++ b/datafusion/functions/src/string/contains.rs @@ -15,13 +15,12 @@ // specific language governing permissions and limitations // under the License. -use crate::utils::make_scalar_function; -use arrow::array::{Array, ArrayRef, AsArray}; +use arrow::array::{Array, ArrayRef, Scalar}; use arrow::compute::contains as arrow_contains; use arrow::datatypes::DataType; use arrow::datatypes::DataType::{Boolean, LargeUtf8, Utf8, Utf8View}; use datafusion_common::types::logical_string; -use datafusion_common::{DataFusionError, Result, exec_err}; +use datafusion_common::{Result, exec_err}; use datafusion_expr::binary::{binary_to_string_coercion, string_coercion}; use datafusion_expr::{ Coercion, ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature, @@ -89,7 +88,7 @@ impl ScalarUDFImpl for ContainsFunc { } fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { - make_scalar_function(contains, vec![])(&args.args) + contains(args.args.as_slice()) } fn documentation(&self) -> Option<&Documentation> { @@ -97,43 +96,71 @@ impl ScalarUDFImpl for ContainsFunc { } } +fn to_array(value: &ColumnarValue) -> Result<(ArrayRef, bool)> { + match value { + ColumnarValue::Array(array) => Ok((Arc::clone(array), false)), + ColumnarValue::Scalar(scalar) => Ok((scalar.to_array()?, true)), + } +} + +/// Helper to call arrow_contains with proper Datum handling. +/// When an argument is marked as scalar, we wrap it in `Scalar` to tell arrow's +/// kernel to use the optimized single-value code path instead of iterating. +fn call_arrow_contains( + haystack: &ArrayRef, + haystack_is_scalar: bool, + needle: &ArrayRef, + needle_is_scalar: bool, +) -> Result { + // Arrow's Datum trait is implemented for ArrayRef, Arc, and Scalar + // We pass ArrayRef directly when not scalar, or wrap in Scalar when it is + let result = match (haystack_is_scalar, needle_is_scalar) { + (false, false) => arrow_contains(haystack, needle)?, + (false, true) => arrow_contains(haystack, &Scalar::new(Arc::clone(needle)))?, + (true, false) => arrow_contains(&Scalar::new(Arc::clone(haystack)), needle)?, + (true, true) => arrow_contains( + &Scalar::new(Arc::clone(haystack)), + &Scalar::new(Arc::clone(needle)), + )?, + }; + + // If both inputs were scalar, return a scalar result + if haystack_is_scalar && needle_is_scalar { + let scalar = datafusion_common::ScalarValue::try_from_array(&result, 0)?; + Ok(ColumnarValue::Scalar(scalar)) + } else { + Ok(ColumnarValue::Array(Arc::new(result))) + } +} + /// use `arrow::compute::contains` to do the calculation for contains -fn contains(args: &[ArrayRef]) -> Result { +fn contains(args: &[ColumnarValue]) -> Result { + let (haystack, haystack_is_scalar) = to_array(&args[0])?; + let (needle, needle_is_scalar) = to_array(&args[1])?; + if let Some(coercion_data_type) = - string_coercion(args[0].data_type(), args[1].data_type()).or_else(|| { - binary_to_string_coercion(args[0].data_type(), args[1].data_type()) + string_coercion(haystack.data_type(), needle.data_type()).or_else(|| { + binary_to_string_coercion(haystack.data_type(), needle.data_type()) }) { - let arg0 = if args[0].data_type() == &coercion_data_type { - Arc::clone(&args[0]) + let haystack = if haystack.data_type() == &coercion_data_type { + haystack } else { - arrow::compute::kernels::cast::cast(&args[0], &coercion_data_type)? + arrow::compute::kernels::cast::cast(&haystack, &coercion_data_type)? }; - let arg1 = if args[1].data_type() == &coercion_data_type { - Arc::clone(&args[1]) + let needle = if needle.data_type() == &coercion_data_type { + needle } else { - arrow::compute::kernels::cast::cast(&args[1], &coercion_data_type)? + arrow::compute::kernels::cast::cast(&needle, &coercion_data_type)? }; match coercion_data_type { - Utf8View => { - let mod_str = arg0.as_string_view(); - let match_str = arg1.as_string_view(); - let res = arrow_contains(mod_str, match_str)?; - Ok(Arc::new(res) as ArrayRef) - } - Utf8 => { - let mod_str = arg0.as_string::(); - let match_str = arg1.as_string::(); - let res = arrow_contains(mod_str, match_str)?; - Ok(Arc::new(res) as ArrayRef) - } - LargeUtf8 => { - let mod_str = arg0.as_string::(); - let match_str = arg1.as_string::(); - let res = arrow_contains(mod_str, match_str)?; - Ok(Arc::new(res) as ArrayRef) - } + Utf8View | Utf8 | LargeUtf8 => call_arrow_contains( + &haystack, + haystack_is_scalar, + &needle, + needle_is_scalar, + ), other => { exec_err!("Unsupported data type {other:?} for function `contains`.") }