diff --git a/datafusion/functions/Cargo.toml b/datafusion/functions/Cargo.toml index 765f5d865a60e..835c5ed80b1d6 100644 --- a/datafusion/functions/Cargo.toml +++ b/datafusion/functions/Cargo.toml @@ -187,7 +187,7 @@ required-features = ["unicode_expressions"] [[bench]] harness = false -name = "ltrim" +name = "trim" required-features = ["string_expressions"] [[bench]] diff --git a/datafusion/functions/benches/ltrim.rs b/datafusion/functions/benches/trim.rs similarity index 56% rename from datafusion/functions/benches/ltrim.rs rename to datafusion/functions/benches/trim.rs index 3fce426a917fe..29bbc3f7dcb48 100644 --- a/datafusion/functions/benches/ltrim.rs +++ b/datafusion/functions/benches/trim.rs @@ -48,28 +48,58 @@ impl fmt::Display for StringArrayType { } } -/// returns an array of strings, and `characters` as a ScalarValue -pub fn create_string_array_and_characters( +#[derive(Clone, Copy)] +pub enum TrimType { + Ltrim, + Rtrim, + Btrim, +} + +impl fmt::Display for TrimType { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + TrimType::Ltrim => f.write_str("ltrim"), + TrimType::Rtrim => f.write_str("rtrim"), + TrimType::Btrim => f.write_str("btrim"), + } + } +} + +/// Returns an array of strings with trim characters positioned according to trim type, +/// and `characters` as a ScalarValue. +/// +/// For ltrim: trim characters are at the start (prefix) +/// For rtrim: trim characters are at the end (suffix) +/// For btrim: trim characters are at both start and end +fn create_string_array_and_characters( size: usize, characters: &str, trimmed: &str, remaining_len: usize, string_array_type: StringArrayType, + trim_type: TrimType, ) -> (ArrayRef, ScalarValue) { let rng = &mut StdRng::seed_from_u64(42); // Create `size` rows: // - 10% rows will be `None` - // - Other 90% will be strings with same `remaining_len` lengths - // We will build the string array on it later. + // - Other 90% will be strings with `remaining_len` content length let string_iter = (0..size).map(|_| { if rng.random::() < 0.1 { None } else { - let mut value = trimmed.as_bytes().to_vec(); - let generated = rng.sample_iter(&Alphanumeric).take(remaining_len); - value.extend(generated); - Some(String::from_utf8(value).unwrap()) + let content: String = rng + .sample_iter(&Alphanumeric) + .take(remaining_len) + .map(char::from) + .collect(); + + let value = match trim_type { + TrimType::Ltrim => format!("{trimmed}{content}"), + TrimType::Rtrim => format!("{content}{trimmed}"), + TrimType::Btrim => format!("{trimmed}{content}{trimmed}"), + }; + Some(value) } }); @@ -90,23 +120,14 @@ pub fn create_string_array_and_characters( } } -/// Create args for the ltrim benchmark -/// Inputs: -/// - size: rows num of the test array -/// - characters: the characters we need to trim -/// - trimmed: the part in the testing string that will be trimmed -/// - remaining_len: the len of the remaining part of testing string after trimming -/// - string_array_type: the method used to store the testing strings -/// -/// Outputs: -/// - testing string array -/// - trimmed characters +/// Create args for the trim benchmark fn create_args( size: usize, characters: &str, trimmed: &str, remaining_len: usize, string_array_type: StringArrayType, + trim_type: TrimType, ) -> Vec { let (string_array, pattern) = create_string_array_and_characters( size, @@ -114,6 +135,7 @@ fn create_args( trimmed, remaining_len, string_array_type, + trim_type, ); vec![ ColumnarValue::Array(string_array), @@ -124,15 +146,23 @@ fn create_args( #[allow(clippy::too_many_arguments)] fn run_with_string_type( group: &mut BenchmarkGroup<'_, M>, - ltrim: &ScalarUDF, + trim_func: &ScalarUDF, + trim_type: TrimType, size: usize, - len: usize, + total_len: usize, characters: &str, trimmed: &str, remaining_len: usize, string_type: StringArrayType, ) { - let args = create_args(size, characters, trimmed, remaining_len, string_type); + let args = create_args( + size, + characters, + trimmed, + remaining_len, + string_type, + trim_type, + ); let arg_fields = args .iter() .enumerate() @@ -142,12 +172,12 @@ fn run_with_string_type( group.bench_function( format!( - "{string_type} [size={size}, len_before={len}, len_after={remaining_len}]", + "{trim_type} {string_type} [size={size}, len={total_len}, remaining={remaining_len}]", ), |b| { b.iter(|| { let args_cloned = args.clone(); - black_box(ltrim.invoke_with_args(ScalarFunctionArgs { + black_box(trim_func.invoke_with_args(ScalarFunctionArgs { args: args_cloned, arg_fields: arg_fields.clone(), number_rows: size, @@ -160,13 +190,14 @@ fn run_with_string_type( } #[allow(clippy::too_many_arguments)] -fn run_one_group( +fn run_trim_benchmark( c: &mut Criterion, group_name: &str, - ltrim: &ScalarUDF, + trim_func: &ScalarUDF, + trim_type: TrimType, string_types: &[StringArrayType], size: usize, - len: usize, + total_len: usize, characters: &str, trimmed: &str, remaining_len: usize, @@ -178,9 +209,10 @@ fn run_one_group( for string_type in string_types { run_with_string_type( &mut group, - ltrim, + trim_func, + trim_type, size, - len, + total_len, characters, trimmed, remaining_len, @@ -193,6 +225,9 @@ fn run_one_group( fn criterion_benchmark(c: &mut Criterion) { let ltrim = string::ltrim(); + let rtrim = string::rtrim(); + let btrim = string::btrim(); + let characters = ",!()"; let string_types = [ @@ -200,54 +235,69 @@ fn criterion_benchmark(c: &mut Criterion) { StringArrayType::Utf8, StringArrayType::LargeUtf8, ]; - for size in [1024, 4096, 8192] { - // len=12, trimmed_len=4, len_after_ltrim=8 - let len = 12; - let trimmed = characters; - let remaining_len = len - trimmed.len(); - run_one_group( - c, - "INPUT LEN <= 12", - <rim, - &string_types, - size, - len, - characters, - trimmed, - remaining_len, - ); - // len=64, trimmed_len=4, len_after_ltrim=60 - let len = 64; - let trimmed = characters; - let remaining_len = len - trimmed.len(); - run_one_group( - c, - "INPUT LEN > 12, OUTPUT LEN > 12", - <rim, - &string_types, - size, - len, - characters, - trimmed, - remaining_len, - ); + let trim_funcs = [ + (<rim, TrimType::Ltrim), + (&rtrim, TrimType::Rtrim), + (&btrim, TrimType::Btrim), + ]; - // len=64, trimmed_len=56, len_after_ltrim=8 - let len = 64; - let trimmed = characters.repeat(15); - let remaining_len = len - trimmed.len(); - run_one_group( - c, - "INPUT LEN > 12, OUTPUT LEN <= 12", - <rim, - &string_types, - size, - len, - characters, - &trimmed, - remaining_len, - ); + for size in [4096] { + for (trim_func, trim_type) in &trim_funcs { + // Scenario 1: Short strings (len <= 12, inline in StringView) + // trimmed_len=4, remaining_len=8 + let total_len = 12; + let trimmed = characters; + let remaining_len = total_len - trimmed.len(); + run_trim_benchmark( + c, + "short strings (len <= 12)", + trim_func, + *trim_type, + &string_types, + size, + total_len, + characters, + trimmed, + remaining_len, + ); + + // Scenario 2: Long strings, short trim (len > 12, output > 12) + // trimmed_len=4, remaining_len=60 + let total_len = 64; + let trimmed = characters; + let remaining_len = total_len - trimmed.len(); + run_trim_benchmark( + c, + "long strings, short trim", + trim_func, + *trim_type, + &string_types, + size, + total_len, + characters, + trimmed, + remaining_len, + ); + + // Scenario 3: Long strings, long trim (len > 12, output <= 12) + // trimmed_len=56, remaining_len=8 + let total_len = 64; + let trimmed = characters.repeat(14); + let remaining_len = total_len - trimmed.len(); + run_trim_benchmark( + c, + "long strings, long trim", + trim_func, + *trim_type, + &string_types, + size, + total_len, + characters, + &trimmed, + remaining_len, + ); + } } } diff --git a/datafusion/functions/src/string/common.rs b/datafusion/functions/src/string/common.rs index ebfada9536fa4..e404d1f5f633f 100644 --- a/datafusion/functions/src/string/common.rs +++ b/datafusion/functions/src/string/common.rs @@ -49,90 +49,69 @@ impl Display for TrimType { } } +/// Perform trim operation on input string with given pattern characters. +/// +/// Returns (trimmed_str, start_offset) where start_offset is the byte offset +/// from the beginning of the input string where the trimmed result starts. +#[inline] +fn perform_trim<'a>( + input: &'a str, + pattern: &[char], + trim_type: TrimType, +) -> (&'a str, u32) { + match trim_type { + TrimType::Left => { + let trimmed = input.trim_start_matches(pattern); + let offset = (input.len() - trimmed.len()) as u32; + (trimmed, offset) + } + TrimType::Right => { + let trimmed = input.trim_end_matches(pattern); + (trimmed, 0) + } + TrimType::Both => { + let left_trimmed = input.trim_start_matches(pattern); + let offset = (input.len() - left_trimmed.len()) as u32; + let trimmed = left_trimmed.trim_end_matches(pattern); + (trimmed, offset) + } + } +} + pub(crate) fn general_trim( args: &[ArrayRef], trim_type: TrimType, use_string_view: bool, ) -> Result { - let func = match trim_type { - TrimType::Left => |input, pattern: &str| { - let pattern = pattern.chars().collect::>(); - let ltrimmed_str = - str::trim_start_matches::<&[char]>(input, pattern.as_ref()); - // `ltrimmed_str` is actually `input`[start_offset..], - // so `start_offset` = len(`input`) - len(`ltrimmed_str`) - let start_offset = input.len() - ltrimmed_str.len(); - - (ltrimmed_str, start_offset as u32) - }, - TrimType::Right => |input, pattern: &str| { - let pattern = pattern.chars().collect::>(); - let rtrimmed_str = str::trim_end_matches::<&[char]>(input, pattern.as_ref()); - - // `ltrimmed_str` is actually `input`[0..new_len], so `start_offset` is 0 - (rtrimmed_str, 0) - }, - TrimType::Both => |input, pattern: &str| { - let pattern = pattern.chars().collect::>(); - let ltrimmed_str = - str::trim_start_matches::<&[char]>(input, pattern.as_ref()); - // `btrimmed_str` can be got by rtrim(ltrim(`input`)), - // so its `start_offset` should be same as ltrim situation above - let start_offset = input.len() - ltrimmed_str.len(); - let btrimmed_str = - str::trim_end_matches::<&[char]>(ltrimmed_str, pattern.as_ref()); - - (btrimmed_str, start_offset as u32) - }, - }; - if use_string_view { - string_view_trim(func, args) + string_view_trim(trim_type, args) } else { - string_trim::(func, args) + string_trim::(trim_type, args) } } /// Applies the trim function to the given string view array(s) /// and returns a new string view array with the trimmed values. /// -/// # `trim_func`: The function to apply to each string view. -/// -/// ## Arguments -/// - The original string -/// - the pattern to trim -/// -/// ## Returns -/// - trimmed str (must be a substring of the first argument) -/// - start offset, needed in `string_view_trim` -/// -/// ## Examples -/// -/// For `ltrim`: -/// - `fn(" abc", " ") -> ("abc", 2)` -/// - `fn("abd", " ") -> ("abd", 0)` -/// -/// For `btrim`: -/// - `fn(" abc ", " ") -> ("abc", 2)` -/// - `fn("abd", " ") -> ("abd", 0)` -// removing 'a will cause compiler complaining lifetime of `func` -fn string_view_trim<'a>( - trim_func: fn(&'a str, &'a str) -> (&'a str, u32), - args: &'a [ArrayRef], -) -> Result { +/// Pre-computes the pattern characters once for scalar patterns to avoid +/// repeated allocations per row. +fn string_view_trim(trim_type: TrimType, args: &[ArrayRef]) -> Result { let string_view_array = as_string_view_array(&args[0])?; let mut views_buf = Vec::with_capacity(string_view_array.len()); let mut null_builder = NullBufferBuilder::new(string_view_array.len()); match args.len() { 1 => { - let array_iter = string_view_array.iter(); - let views_iter = string_view_array.views().iter(); - for (src_str_opt, raw_view) in array_iter.zip(views_iter) { - trim_and_append_str( + // Default whitespace trim - pattern is just space + let pattern = [' ']; + for (src_str_opt, raw_view) in string_view_array + .iter() + .zip(string_view_array.views().iter()) + { + trim_and_append_view( src_str_opt, - Some(" "), - trim_func, + &pattern, + trim_type, &mut views_buf, &mut null_builder, raw_view, @@ -143,44 +122,52 @@ fn string_view_trim<'a>( let characters_array = as_string_view_array(&args[1])?; if characters_array.len() == 1 { - // Only one `trim characters` exist + // Scalar pattern - pre-compute pattern chars once if characters_array.is_null(0) { return Ok(new_null_array( - // The schema is expecting utf8 as null &DataType::Utf8View, string_view_array.len(), )); } - let characters = characters_array.value(0); - let array_iter = string_view_array.iter(); - let views_iter = string_view_array.views().iter(); - for (src_str_opt, raw_view) in array_iter.zip(views_iter) { - trim_and_append_str( + let pattern: Vec = characters_array.value(0).chars().collect(); + for (src_str_opt, raw_view) in string_view_array + .iter() + .zip(string_view_array.views().iter()) + { + trim_and_append_view( src_str_opt, - Some(characters), - trim_func, + &pattern, + trim_type, &mut views_buf, &mut null_builder, raw_view, ); } } else { - // A specific `trim characters` for a row in the string view array - let characters_iter = characters_array.iter(); - let array_iter = string_view_array.iter(); - let views_iter = string_view_array.views().iter(); - for ((src_str_opt, raw_view), characters_opt) in - array_iter.zip(views_iter).zip(characters_iter) + // Per-row pattern - must compute pattern chars for each row + for ((src_str_opt, raw_view), characters_opt) in string_view_array + .iter() + .zip(string_view_array.views().iter()) + .zip(characters_array.iter()) { - trim_and_append_str( - src_str_opt, - characters_opt, - trim_func, - &mut views_buf, - &mut null_builder, - raw_view, - ); + if let (Some(src_str), Some(characters)) = + (src_str_opt, characters_opt) + { + let pattern: Vec = characters.chars().collect(); + let (trimmed, offset) = + perform_trim(src_str, &pattern, trim_type); + make_and_append_view( + &mut views_buf, + &mut null_builder, + raw_view, + trimmed, + offset, + ); + } else { + null_builder.append_null(); + views_buf.push(0); + } } } } @@ -211,33 +198,25 @@ fn string_view_trim<'a>( /// Trims the given string and appends the trimmed string to the views buffer /// and the null buffer. /// -/// Calls `trim_func` on the string value in `original_view`, for non_null -/// values and appends the updated view to the views buffer / null_builder. -/// /// Arguments /// - `src_str_opt`: The original string value (represented by the view) -/// - `trim_characters_opt`: The characters to trim from the string -/// - `trim_func`: The function to apply to the string (see [`string_view_trim`] for details) +/// - `pattern`: Pre-computed character pattern to trim +/// - `trim_type`: Type of trim operation (left, right, or both) /// - `views_buf`: The buffer to append the updated views to /// - `null_builder`: The buffer to append the null values to /// - `original_view`: The original view value (that contains src_str_opt) -fn trim_and_append_str<'a>( - src_str_opt: Option<&'a str>, - trim_characters_opt: Option<&'a str>, - trim_func: fn(&'a str, &'a str) -> (&'a str, u32), +#[inline] +fn trim_and_append_view( + src_str_opt: Option<&str>, + pattern: &[char], + trim_type: TrimType, views_buf: &mut Vec, null_builder: &mut NullBufferBuilder, original_view: &u128, ) { - if let (Some(src_str), Some(characters)) = (src_str_opt, trim_characters_opt) { - let (trim_str, start_offset) = trim_func(src_str, characters); - make_and_append_view( - views_buf, - null_builder, - original_view, - trim_str, - start_offset, - ); + if let Some(src_str) = src_str_opt { + let (trimmed, offset) = perform_trim(src_str, pattern, trim_type); + make_and_append_view(views_buf, null_builder, original_view, trimmed, offset); } else { null_builder.append_null(); views_buf.push(0); @@ -247,18 +226,21 @@ fn trim_and_append_str<'a>( /// Applies the trim function to the given string array(s) /// and returns a new string array with the trimmed values. /// -/// See [`string_view_trim`] for details on `func` -fn string_trim<'a, T: OffsetSizeTrait>( - func: fn(&'a str, &'a str) -> (&'a str, u32), - args: &'a [ArrayRef], +/// Pre-computes the pattern characters once for scalar patterns to avoid +/// repeated allocations per row. +fn string_trim( + trim_type: TrimType, + args: &[ArrayRef], ) -> Result { let string_array = as_generic_string_array::(&args[0])?; match args.len() { 1 => { + // Default whitespace trim - pattern is just space + let pattern = [' ']; let result = string_array .iter() - .map(|string| string.map(|string: &str| func(string, " ").0)) + .map(|string| string.map(|s| perform_trim(s, &pattern, trim_type).0)) .collect::>(); Ok(Arc::new(result) as ArrayRef) @@ -267,6 +249,7 @@ fn string_trim<'a, T: OffsetSizeTrait>( let characters_array = as_generic_string_array::(&args[1])?; if characters_array.len() == 1 { + // Scalar pattern - pre-compute pattern chars once if characters_array.is_null(0) { return Ok(new_null_array( string_array.data_type(), @@ -274,19 +257,23 @@ fn string_trim<'a, T: OffsetSizeTrait>( )); } - let characters = characters_array.value(0); + let pattern: Vec = characters_array.value(0).chars().collect(); let result = string_array .iter() - .map(|item| item.map(|string| func(string, characters).0)) + .map(|item| item.map(|s| perform_trim(s, &pattern, trim_type).0)) .collect::>(); return Ok(Arc::new(result) as ArrayRef); } + // Per-row pattern - must compute pattern chars for each row let result = string_array .iter() .zip(characters_array.iter()) .map(|(string, characters)| match (string, characters) { - (Some(string), Some(characters)) => Some(func(string, characters).0), + (Some(s), Some(c)) => { + let pattern: Vec = c.chars().collect(); + Some(perform_trim(s, &pattern, trim_type).0) + } _ => None, }) .collect::>();