diff --git a/datafusion/functions/Cargo.toml b/datafusion/functions/Cargo.toml index 3e832691f96b..711727a9d756 100644 --- a/datafusion/functions/Cargo.toml +++ b/datafusion/functions/Cargo.toml @@ -254,3 +254,8 @@ required-features = ["unicode_expressions"] harness = false name = "find_in_set" required-features = ["unicode_expressions"] + +[[bench]] +harness = false +name = "contains" +required-features = ["string_expressions"] diff --git a/datafusion/functions/benches/contains.rs b/datafusion/functions/benches/contains.rs new file mode 100644 index 000000000000..052eff38869d --- /dev/null +++ b/datafusion/functions/benches/contains.rs @@ -0,0 +1,185 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +extern crate criterion; + +use arrow::array::{StringArray, StringViewArray}; +use arrow::datatypes::{DataType, Field}; +use criterion::{Criterion, criterion_group, criterion_main}; +use datafusion_common::ScalarValue; +use datafusion_common::config::ConfigOptions; +use datafusion_expr::{ColumnarValue, ScalarFunctionArgs}; +use rand::distr::Alphanumeric; +use rand::prelude::StdRng; +use rand::{Rng, SeedableRng}; +use std::hint::black_box; +use std::sync::Arc; + +/// Generate a StringArray/StringViewArray with random ASCII strings +fn gen_string_array( + n_rows: usize, + str_len: usize, + is_string_view: bool, +) -> ColumnarValue { + let mut rng = StdRng::seed_from_u64(42); + let strings: Vec> = (0..n_rows) + .map(|_| { + let s: String = (&mut rng) + .sample_iter(&Alphanumeric) + .take(str_len) + .map(char::from) + .collect(); + Some(s) + }) + .collect(); + + if is_string_view { + ColumnarValue::Array(Arc::new(StringViewArray::from(strings))) + } else { + ColumnarValue::Array(Arc::new(StringArray::from(strings))) + } +} + +/// Generate a scalar search string +fn gen_scalar_search(search_str: &str, is_string_view: bool) -> ColumnarValue { + if is_string_view { + ColumnarValue::Scalar(ScalarValue::Utf8View(Some(search_str.to_string()))) + } else { + ColumnarValue::Scalar(ScalarValue::Utf8(Some(search_str.to_string()))) + } +} + +/// Generate an array of search strings (same string repeated) +fn gen_array_search( + search_str: &str, + n_rows: usize, + is_string_view: bool, +) -> ColumnarValue { + let strings: Vec> = + (0..n_rows).map(|_| Some(search_str.to_string())).collect(); + + if is_string_view { + ColumnarValue::Array(Arc::new(StringViewArray::from(strings))) + } else { + ColumnarValue::Array(Arc::new(StringArray::from(strings))) + } +} + +fn criterion_benchmark(c: &mut Criterion) { + let contains = datafusion_functions::string::contains(); + let n_rows = 8192; + let str_len = 128; + let search_str = "xyz"; // A pattern that likely won't be found + + // Benchmark: StringArray with scalar search (the optimized path) + let str_array = gen_string_array(n_rows, str_len, false); + let scalar_search = gen_scalar_search(search_str, false); + let arg_fields = vec![ + Field::new("a", DataType::Utf8, true).into(), + Field::new("b", DataType::Utf8, true).into(), + ]; + let return_field = Field::new("f", DataType::Boolean, true).into(); + let config_options = Arc::new(ConfigOptions::default()); + + c.bench_function("contains_StringArray_scalar_search", |b| { + b.iter(|| { + black_box(contains.invoke_with_args(ScalarFunctionArgs { + args: vec![str_array.clone(), scalar_search.clone()], + arg_fields: arg_fields.clone(), + number_rows: n_rows, + return_field: Arc::clone(&return_field), + config_options: Arc::clone(&config_options), + })) + }) + }); + + // Benchmark: StringArray with array search (for comparison) + let array_search = gen_array_search(search_str, n_rows, false); + c.bench_function("contains_StringArray_array_search", |b| { + b.iter(|| { + black_box(contains.invoke_with_args(ScalarFunctionArgs { + args: vec![str_array.clone(), array_search.clone()], + arg_fields: arg_fields.clone(), + number_rows: n_rows, + return_field: Arc::clone(&return_field), + config_options: Arc::clone(&config_options), + })) + }) + }); + + // Benchmark: StringViewArray with scalar search (the optimized path) + let str_view_array = gen_string_array(n_rows, str_len, true); + let scalar_search_view = gen_scalar_search(search_str, true); + let arg_fields_view = vec![ + Field::new("a", DataType::Utf8View, true).into(), + Field::new("b", DataType::Utf8View, true).into(), + ]; + + c.bench_function("contains_StringViewArray_scalar_search", |b| { + b.iter(|| { + black_box(contains.invoke_with_args(ScalarFunctionArgs { + args: vec![str_view_array.clone(), scalar_search_view.clone()], + arg_fields: arg_fields_view.clone(), + number_rows: n_rows, + return_field: Arc::clone(&return_field), + config_options: Arc::clone(&config_options), + })) + }) + }); + + // Benchmark: StringViewArray with array search (for comparison) + let array_search_view = gen_array_search(search_str, n_rows, true); + c.bench_function("contains_StringViewArray_array_search", |b| { + b.iter(|| { + black_box(contains.invoke_with_args(ScalarFunctionArgs { + args: vec![str_view_array.clone(), array_search_view.clone()], + arg_fields: arg_fields_view.clone(), + number_rows: n_rows, + return_field: Arc::clone(&return_field), + config_options: Arc::clone(&config_options), + })) + }) + }); + + // Benchmark different string lengths with scalar search + for str_len in [8, 32, 128, 512] { + let str_array = gen_string_array(n_rows, str_len, true); + let scalar_search = gen_scalar_search(search_str, true); + let arg_fields = vec![ + Field::new("a", DataType::Utf8View, true).into(), + Field::new("b", DataType::Utf8View, true).into(), + ]; + + c.bench_function( + &format!("contains_StringViewArray_scalar_strlen_{str_len}"), + |b| { + b.iter(|| { + black_box(contains.invoke_with_args(ScalarFunctionArgs { + args: vec![str_array.clone(), scalar_search.clone()], + arg_fields: arg_fields.clone(), + number_rows: n_rows, + return_field: Arc::clone(&return_field), + config_options: Arc::clone(&config_options), + })) + }) + }, + ); + } +} + +criterion_group!(benches, criterion_benchmark); +criterion_main!(benches); diff --git a/datafusion/functions/src/string/contains.rs b/datafusion/functions/src/string/contains.rs index b85e0ed7966a..b44f1858fdfd 100644 --- a/datafusion/functions/src/string/contains.rs +++ b/datafusion/functions/src/string/contains.rs @@ -15,13 +15,12 @@ // specific language governing permissions and limitations // under the License. -use crate::utils::make_scalar_function; -use arrow::array::{Array, ArrayRef, AsArray}; +use arrow::array::{ArrayRef, Scalar}; use arrow::compute::contains as arrow_contains; use arrow::datatypes::DataType; -use arrow::datatypes::DataType::{Boolean, LargeUtf8, Utf8, Utf8View}; +use arrow::datatypes::DataType::Boolean; use datafusion_common::types::logical_string; -use datafusion_common::{DataFusionError, Result, exec_err}; +use datafusion_common::{Result, ScalarValue, exec_err}; use datafusion_expr::binary::{binary_to_string_coercion, string_coercion}; use datafusion_expr::{ Coercion, ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature, @@ -89,61 +88,81 @@ impl ScalarUDFImpl for ContainsFunc { } fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { - make_scalar_function(contains, vec![])(&args.args) - } - - fn documentation(&self) -> Option<&Documentation> { - self.doc() - } -} + let [str_arg, search_arg] = args.args.as_slice() else { + return exec_err!( + "contains was called with {} arguments, expected 2", + args.args.len() + ); + }; -/// use `arrow::compute::contains` to do the calculation for contains -fn contains(args: &[ArrayRef]) -> Result { - if let Some(coercion_data_type) = - string_coercion(args[0].data_type(), args[1].data_type()).or_else(|| { - binary_to_string_coercion(args[0].data_type(), args[1].data_type()) - }) - { - let arg0 = if args[0].data_type() == &coercion_data_type { - Arc::clone(&args[0]) - } else { - arrow::compute::kernels::cast::cast(&args[0], &coercion_data_type)? + // Determine the common type for coercion + let coercion_type = string_coercion( + &str_arg.data_type(), + &search_arg.data_type(), + ) + .or_else(|| { + binary_to_string_coercion(&str_arg.data_type(), &search_arg.data_type()) + }); + + let Some(coercion_type) = coercion_type else { + return exec_err!( + "Unsupported data types {:?}, {:?} for function `contains`.", + str_arg.data_type(), + search_arg.data_type() + ); }; - let arg1 = if args[1].data_type() == &coercion_data_type { - Arc::clone(&args[1]) - } else { - arrow::compute::kernels::cast::cast(&args[1], &coercion_data_type)? + + // Helper to cast an array if needed + let maybe_cast = |arr: &ArrayRef, target: &DataType| -> Result { + if arr.data_type() == target { + Ok(Arc::clone(arr)) + } else { + Ok(arrow::compute::kernels::cast::cast(arr, target)?) + } }; - match coercion_data_type { - Utf8View => { - let mod_str = arg0.as_string_view(); - let match_str = arg1.as_string_view(); - let res = arrow_contains(mod_str, match_str)?; - Ok(Arc::new(res) as ArrayRef) + match (str_arg, search_arg) { + // Both scalars - just compute directly + (ColumnarValue::Scalar(str_scalar), ColumnarValue::Scalar(search_scalar)) => { + let str_arr = str_scalar.to_array_of_size(1)?; + let search_arr = search_scalar.to_array_of_size(1)?; + let str_arr = maybe_cast(&str_arr, &coercion_type)?; + let search_arr = maybe_cast(&search_arr, &coercion_type)?; + let result = arrow_contains(&str_arr, &search_arr)?; + Ok(ColumnarValue::Scalar(ScalarValue::try_from_array( + &result, 0, + )?)) } - Utf8 => { - let mod_str = arg0.as_string::(); - let match_str = arg1.as_string::(); - let res = arrow_contains(mod_str, match_str)?; - Ok(Arc::new(res) as ArrayRef) + // String is array, search is scalar - use Scalar wrapper for optimization + (ColumnarValue::Array(str_arr), ColumnarValue::Scalar(search_scalar)) => { + let str_arr = maybe_cast(str_arr, &coercion_type)?; + let search_arr = search_scalar.to_array_of_size(1)?; + let search_arr = maybe_cast(&search_arr, &coercion_type)?; + let search_scalar = Scalar::new(search_arr); + let result = arrow_contains(&str_arr, &search_scalar)?; + Ok(ColumnarValue::Array(Arc::new(result))) } - LargeUtf8 => { - let mod_str = arg0.as_string::(); - let match_str = arg1.as_string::(); - let res = arrow_contains(mod_str, match_str)?; - Ok(Arc::new(res) as ArrayRef) + // String is scalar, search is array - use Scalar wrapper for string + (ColumnarValue::Scalar(str_scalar), ColumnarValue::Array(search_arr)) => { + let str_arr = str_scalar.to_array_of_size(1)?; + let str_arr = maybe_cast(&str_arr, &coercion_type)?; + let str_scalar = Scalar::new(str_arr); + let search_arr = maybe_cast(search_arr, &coercion_type)?; + let result = arrow_contains(&str_scalar, &search_arr)?; + Ok(ColumnarValue::Array(Arc::new(result))) } - other => { - exec_err!("Unsupported data type {other:?} for function `contains`.") + // Both arrays - pass directly + (ColumnarValue::Array(str_arr), ColumnarValue::Array(search_arr)) => { + let str_arr = maybe_cast(str_arr, &coercion_type)?; + let search_arr = maybe_cast(search_arr, &coercion_type)?; + let result = arrow_contains(&str_arr, &search_arr)?; + Ok(ColumnarValue::Array(Arc::new(result))) } } - } else { - exec_err!( - "Unsupported data type {}, {:?} for function `contains`.", - args[0].data_type(), - args[1].data_type() - ) + } + + fn documentation(&self) -> Option<&Documentation> { + self.doc() } }