diff --git a/datafusion/functions/Cargo.toml b/datafusion/functions/Cargo.toml index 765f5d865a60..74964267f7a1 100644 --- a/datafusion/functions/Cargo.toml +++ b/datafusion/functions/Cargo.toml @@ -106,6 +106,11 @@ harness = false name = "concat" required-features = ["string_expressions"] +[[bench]] +harness = false +name = "concat_ws" +required-features = ["string_expressions"] + [[bench]] harness = false name = "to_timestamp" diff --git a/datafusion/functions/benches/concat.rs b/datafusion/functions/benches/concat.rs index f7ef97892090..b01f3b257012 100644 --- a/datafusion/functions/benches/concat.rs +++ b/datafusion/functions/benches/concat.rs @@ -23,10 +23,12 @@ use datafusion_common::ScalarValue; use datafusion_common::config::ConfigOptions; use datafusion_expr::{ColumnarValue, ScalarFunctionArgs}; use datafusion_functions::string::concat; +use rand::Rng; +use rand::distr::Alphanumeric; use std::hint::black_box; use std::sync::Arc; -fn create_args(size: usize, str_len: usize) -> Vec { +fn create_array_args(size: usize, str_len: usize) -> Vec { let array = Arc::new(create_string_array_with_len::(size, 0.2, str_len)); let scalar = ScalarValue::Utf8(Some(", ".to_string())); vec![ @@ -36,9 +38,27 @@ fn create_args(size: usize, str_len: usize) -> Vec { ] } +fn generate_random_string(str_len: usize) -> String { + rand::rng() + .sample_iter(&Alphanumeric) + .take(str_len) + .map(char::from) + .collect() +} + +fn create_scalar_args(count: usize, str_len: usize) -> Vec { + std::iter::repeat_with(|| { + let s = generate_random_string(str_len); + ColumnarValue::Scalar(ScalarValue::Utf8(Some(s))) + }) + .take(count) + .collect() +} + fn criterion_benchmark(c: &mut Criterion) { + // Benchmark for array concat for size in [1024, 4096, 8192] { - let args = create_args(size, 32); + let args = create_array_args(size, 32); let arg_fields = args .iter() .enumerate() @@ -67,6 +87,31 @@ fn criterion_benchmark(c: &mut Criterion) { }); group.finish(); } + + // Benchmark for scalar concat + let scalar_args = create_scalar_args(10, 100); + let scalar_arg_fields = scalar_args + .iter() + .enumerate() + .map(|(idx, arg)| Field::new(format!("arg_{idx}"), arg.data_type(), true).into()) + .collect::>(); + let mut group = c.benchmark_group("concat function"); + group.bench_function(BenchmarkId::new("concat", "scalar"), |b| { + b.iter(|| { + let args_cloned = scalar_args.clone(); + black_box( + concat() + .invoke_with_args(ScalarFunctionArgs { + args: args_cloned, + arg_fields: scalar_arg_fields.clone(), + number_rows: 1, + return_field: Field::new("f", DataType::Utf8, true).into(), + config_options: Arc::new(ConfigOptions::default()), + }) + .unwrap(), + ) + }) + }); } criterion_group!(benches, criterion_benchmark); diff --git a/datafusion/functions/benches/concat_ws.rs b/datafusion/functions/benches/concat_ws.rs new file mode 100644 index 000000000000..97d6d96411d7 --- /dev/null +++ b/datafusion/functions/benches/concat_ws.rs @@ -0,0 +1,123 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::array::ArrayRef; +use arrow::datatypes::{DataType, Field}; +use arrow::util::bench_util::create_string_array_with_len; +use criterion::{BenchmarkId, Criterion, criterion_group, criterion_main}; +use datafusion_common::ScalarValue; +use datafusion_common::config::ConfigOptions; +use datafusion_expr::{ColumnarValue, ScalarFunctionArgs}; +use datafusion_functions::string::concat_ws; +use rand::Rng; +use rand::distr::Alphanumeric; +use std::hint::black_box; +use std::sync::Arc; + +fn create_array_args(size: usize, str_len: usize) -> Vec { + let array = Arc::new(create_string_array_with_len::(size, 0.2, str_len)); + let scalar = ScalarValue::Utf8(Some(", ".to_string())); + vec![ + ColumnarValue::Scalar(scalar), + ColumnarValue::Array(Arc::clone(&array) as ArrayRef), + ColumnarValue::Array(array), + ] +} + +fn generate_random_string(str_len: usize) -> String { + rand::rng() + .sample_iter(&Alphanumeric) + .take(str_len) + .map(char::from) + .collect() +} + +fn create_scalar_args(count: usize, str_len: usize) -> Vec { + let mut args = Vec::with_capacity(count + 1); + + args.push(ColumnarValue::Scalar(ScalarValue::Utf8(Some( + ",".to_string(), + )))); + + for _ in 0..count { + let s = generate_random_string(str_len); + args.push(ColumnarValue::Scalar(ScalarValue::Utf8(Some(s)))); + } + args +} + +fn criterion_benchmark(c: &mut Criterion) { + // Benchmark for array concat_ws + for size in [1024, 4096, 8192] { + let args = create_array_args(size, 32); + let arg_fields = args + .iter() + .enumerate() + .map(|(idx, arg)| { + Field::new(format!("arg_{idx}"), arg.data_type(), true).into() + }) + .collect::>(); + let config_options = Arc::new(ConfigOptions::default()); + + let mut group = c.benchmark_group("concat_ws function"); + group.bench_function(BenchmarkId::new("concat_ws", size), |b| { + b.iter(|| { + let args_cloned = args.clone(); + black_box( + concat_ws() + .invoke_with_args(ScalarFunctionArgs { + args: args_cloned, + arg_fields: arg_fields.clone(), + number_rows: size, + return_field: Field::new("f", DataType::Utf8, true).into(), + config_options: Arc::clone(&config_options), + }) + .unwrap(), + ) + }) + }); + group.finish(); + } + + // Benchmark for scalar concat_ws + let scalar_args = create_scalar_args(10, 100); + let scalar_arg_fields = scalar_args + .iter() + .enumerate() + .map(|(idx, arg)| Field::new(format!("arg_{idx}"), arg.data_type(), true).into()) + .collect::>(); + let mut group = c.benchmark_group("concat_ws function"); + group.bench_function(BenchmarkId::new("concat_ws", "scalar"), |b| { + b.iter(|| { + let args_cloned = scalar_args.clone(); + black_box( + concat_ws() + .invoke_with_args(ScalarFunctionArgs { + args: args_cloned, + arg_fields: scalar_arg_fields.clone(), + number_rows: 1, + return_field: Field::new("f", DataType::Utf8, true).into(), + config_options: Arc::new(ConfigOptions::default()), + }) + .unwrap(), + ) + }) + }); +} + +criterion_group!(benches, criterion_benchmark); +criterion_main!(benches); diff --git a/datafusion/functions/src/string/concat.rs b/datafusion/functions/src/string/concat.rs index 42d455a05760..811ea0d8d258 100644 --- a/datafusion/functions/src/string/concat.rs +++ b/datafusion/functions/src/string/concat.rs @@ -130,14 +130,14 @@ impl ScalarUDFImpl for ConcatFunc { // Scalar if array_len.is_none() { - let mut result = String::new(); - for arg in args { + let mut values = Vec::with_capacity(args.len()); + for arg in &args { let ColumnarValue::Scalar(scalar) = arg else { return internal_err!("concat expected scalar value, got {arg:?}"); }; match scalar.try_as_str() { - Some(Some(v)) => result.push_str(v), + Some(Some(v)) => values.push(v), Some(None) => {} // null literal None => plan_err!( "Concat function does not support scalar type {}", @@ -145,6 +145,7 @@ impl ScalarUDFImpl for ConcatFunc { )?, } } + let result = values.concat(); return match return_datatype { DataType::Utf8View => { @@ -206,7 +207,11 @@ impl ScalarUDFImpl for ConcatFunc { DataType::Utf8View => { let string_array = as_string_view_array(array)?; - data_size += string_array.len(); + data_size += string_array + .data_buffers() + .iter() + .map(|buf| buf.len()) + .sum::(); let column = if array.is_nullable() { ColumnarValueRef::NullableStringViewArray(string_array) } else { diff --git a/datafusion/functions/src/string/concat_ws.rs b/datafusion/functions/src/string/concat_ws.rs index 8fe095c5ce2b..e292ecb2b8ee 100644 --- a/datafusion/functions/src/string/concat_ws.rs +++ b/datafusion/functions/src/string/concat_ws.rs @@ -136,43 +136,22 @@ impl ScalarUDFImpl for ConcatWsFunc { None => return internal_err!("Expected string literal, got {scalar:?}"), }; - let mut result = String::new(); - // iterator over Option - let iter = &mut args[1..].iter().map(|arg| { + let mut values = Vec::with_capacity(args.len() - 1); + for arg in &args[1..] { let ColumnarValue::Scalar(scalar) = arg else { // loop above checks for all args being scalar unreachable!() }; - scalar.try_as_str() - }); - - // append first non null arg - for scalar in iter.by_ref() { - match scalar { - Some(Some(s)) => { - result.push_str(s); - break; - } - Some(None) => {} // null literal string - None => { - return internal_err!("Expected string literal, got {scalar:?}"); - } - } - } - // handle subsequent non null args - for scalar in iter.by_ref() { - match scalar { - Some(Some(s)) => { - result.push_str(sep); - result.push_str(s); - } + match scalar.try_as_str() { + Some(Some(v)) => values.push(v), Some(None) => {} // null literal string None => { return internal_err!("Expected string literal, got {scalar:?}"); } } } + let result = values.join(sep); return Ok(ColumnarValue::Scalar(ScalarValue::Utf8(Some(result)))); }