From 27929a30b172afeb6a962defe717a5c2353e76de Mon Sep 17 00:00:00 2001
From: shekharrajak
Date: Sat, 27 Dec 2025 12:12:22 +0530
Subject: [PATCH 1/5] perf: Optimize contains expression with SIMD-based scalar pattern search (#2972)

---
 native/Cargo.lock | 2 +
 native/spark-expr/Cargo.toml | 2 +
 native/spark-expr/src/comet_scalar_funcs.rs | 5 +-
 .../spark-expr/src/string_funcs/contains.rs | 282 ++++++++++++++++++
 native/spark-expr/src/string_funcs/mod.rs | 2 +
 .../apache/comet/CometExpressionSuite.scala | 19 +-
 .../CometStringExpressionBenchmark.scala | 1 +
 7 files changed, 310 insertions(+), 3 deletions(-)
 create mode 100644 native/spark-expr/src/string_funcs/contains.rs

diff --git a/native/Cargo.lock b/native/Cargo.lock
index bf9a7ea2da..7369a97d6b 100644
--- a/native/Cargo.lock
+++ b/native/Cargo.lock
@@ -1872,6 +1872,7 @@ name = "datafusion-comet-spark-expr"
 version = "0.13.0"
 dependencies = [
  "arrow",
+ "arrow-string",
  "base64",
  "chrono",
  "chrono-tz",
@@ -1879,6 +1880,7 @@ dependencies = [
  "datafusion",
  "futures",
  "hex",
+ "memchr",
  "num",
  "rand 0.9.2",
  "regex",
diff --git a/native/spark-expr/Cargo.toml b/native/spark-expr/Cargo.toml
index ea89c43204..a0476b2a32 100644
--- a/native/spark-expr/Cargo.toml
+++ b/native/spark-expr/Cargo.toml
@@ -28,9 +28,11 @@ edition = { workspace = true }
 
 [dependencies]
 arrow = { workspace = true }
+arrow-string = "57.0.0"
 chrono = { workspace = true }
 datafusion = { workspace = true }
 chrono-tz = { workspace = true }
+memchr = "2.7"
 num = { workspace = true }
 regex = { workspace = true }
 serde_json = "1.0"
diff --git a/native/spark-expr/src/comet_scalar_funcs.rs b/native/spark-expr/src/comet_scalar_funcs.rs
index 8384a4646a..2ff355369e 100644
--- a/native/spark-expr/src/comet_scalar_funcs.rs
+++ b/native/spark-expr/src/comet_scalar_funcs.rs
@@ -22,8 +22,8 @@ use crate::math_funcs::modulo_expr::spark_modulo;
 use crate::{
     spark_array_repeat, spark_ceil, spark_decimal_div, spark_decimal_integral_div, spark_floor,
     spark_isnan, spark_lpad, spark_make_decimal, spark_read_side_padding, spark_round, spark_rpad,
-    spark_unhex, spark_unscaled_value, EvalMode, SparkBitwiseCount, SparkDateTrunc, SparkSizeFunc,
-    SparkStringSpace,
+    spark_unhex, spark_unscaled_value, EvalMode, SparkBitwiseCount, SparkContains, SparkDateTrunc,
+    SparkSizeFunc, SparkStringSpace,
 };
 use arrow::datatypes::DataType;
 use datafusion::common::{DataFusionError, Result as DataFusionResult};
@@ -192,6 +192,7 @@ pub fn create_comet_physical_fun_with_eval_mode(
 fn all_scalar_functions() -> Vec<Arc<ScalarUDF>> {
     vec![
         Arc::new(ScalarUDF::new_from_impl(SparkBitwiseCount::default())),
+        Arc::new(ScalarUDF::new_from_impl(SparkContains::default())),
         Arc::new(ScalarUDF::new_from_impl(SparkDateTrunc::default())),
         Arc::new(ScalarUDF::new_from_impl(SparkStringSpace::default())),
         Arc::new(ScalarUDF::new_from_impl(SparkSizeFunc::default())),
diff --git a/native/spark-expr/src/string_funcs/contains.rs b/native/spark-expr/src/string_funcs/contains.rs
new file mode 100644
index 0000000000..c4662ba9d3
--- /dev/null
+++ b/native/spark-expr/src/string_funcs/contains.rs
@@ -0,0 +1,282 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! Optimized `contains` string function for Spark compatibility.
+//!
+//! This implementation is optimized for the common case where the pattern
+//! (second argument) is a scalar value. In this case, we use `memchr::memmem::Finder`
+//! which is SIMD-optimized and reuses a single finder instance across all rows.
+//!
+//! The DataFusion built-in `contains` function uses `make_scalar_function` which
+//! expands scalar values to arrays, losing the performance benefit of the optimized
+//! scalar path in arrow-rs.
+
+use arrow::array::{Array, ArrayRef, AsArray, BooleanArray};
+use arrow::datatypes::DataType;
+use arrow_string::like::contains as arrow_contains;
+use datafusion::common::{exec_err, Result, ScalarValue};
+use datafusion::logical_expr::{
+    ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility,
+};
+use memchr::memmem::Finder;
+use std::any::Any;
+use std::sync::Arc;
+
+/// Spark-optimized contains function.
+///
+/// Returns true if the first string argument contains the second string argument.
+/// Optimized for the common case where the pattern is a scalar constant.
+#[derive(Debug, PartialEq, Eq, Hash)]
+pub struct SparkContains {
+    signature: Signature,
+}
+
+impl Default for SparkContains {
+    fn default() -> Self {
+        Self::new()
+    }
+}
+
+impl SparkContains {
+    pub fn new() -> Self {
+        Self {
+            signature: Signature::variadic_any(Volatility::Immutable),
+        }
+    }
+}
+
+impl ScalarUDFImpl for SparkContains {
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    fn name(&self) -> &str {
+        "contains"
+    }
+
+    fn signature(&self) -> &Signature {
+        &self.signature
+    }
+
+    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
+        Ok(DataType::Boolean)
+    }
+
+    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
+        if args.args.len() != 2 {
+            return exec_err!("contains function requires exactly 2 arguments");
+        }
+        spark_contains(&args.args[0], &args.args[1])
+    }
+}
+
+/// Execute the contains function with optimized scalar pattern handling.
+fn spark_contains(haystack: &ColumnarValue, needle: &ColumnarValue) -> Result<ColumnarValue> {
+    match (haystack, needle) {
+        // Case 1: Both are arrays - use arrow's contains directly
+        (ColumnarValue::Array(haystack_array), ColumnarValue::Array(needle_array)) => {
+            let result = arrow_contains(haystack_array, needle_array)?;
+            Ok(ColumnarValue::Array(Arc::new(result)))
+        }
+
+        // Case 2: Haystack is array, needle is scalar - OPTIMIZED PATH
+        // This is the common case in SQL like: WHERE col CONTAINS 'pattern'
+        (ColumnarValue::Array(haystack_array), ColumnarValue::Scalar(needle_scalar)) => {
+            let result = contains_with_scalar_pattern(haystack_array, needle_scalar)?;
+            Ok(ColumnarValue::Array(result))
+        }
+
+        // Case 3: Haystack is scalar, needle is array - less common
+        (ColumnarValue::Scalar(haystack_scalar), ColumnarValue::Array(needle_array)) => {
+            // Convert scalar to array and use arrow's contains
+            let haystack_array = haystack_scalar.to_array_of_size(needle_array.len())?;
+            let result = arrow_contains(&haystack_array, needle_array)?;
+            Ok(ColumnarValue::Array(Arc::new(result)))
+        }
+
+        // Case 4: Both are scalars - compute single result
+        (ColumnarValue::Scalar(haystack_scalar), ColumnarValue::Scalar(needle_scalar)) => {
+            let result = contains_scalar_scalar(haystack_scalar, needle_scalar)?;
+            Ok(ColumnarValue::Scalar(result))
+        }
+    }
+}
+
+/// Optimized contains for array haystack with scalar needle pattern.
+/// Uses memchr's SIMD-optimized Finder for efficient repeated searches.
+fn contains_with_scalar_pattern(
+    haystack_array: &ArrayRef,
+    needle_scalar: &ScalarValue,
+) -> Result<ArrayRef> {
+    // Handle null needle
+    if needle_scalar.is_null() {
+        return Ok(Arc::new(BooleanArray::new_null(haystack_array.len())));
+    }
+
+    // Extract the needle string
+    let needle_str = match needle_scalar {
+        ScalarValue::Utf8(Some(s))
+        | ScalarValue::LargeUtf8(Some(s))
+        | ScalarValue::Utf8View(Some(s)) => s.as_str(),
+        _ => {
+            return exec_err!(
+                "contains function requires string type for needle, got {:?}",
+                needle_scalar.data_type()
+            )
+        }
+    };
+
+    // Create a reusable Finder for efficient SIMD-optimized searching
+    let finder = Finder::new(needle_str.as_bytes());
+
+    match haystack_array.data_type() {
+        DataType::Utf8 => {
+            let array = haystack_array.as_string::<i32>();
+            let result: BooleanArray = array
+                .iter()
+                .map(|opt_haystack| opt_haystack.map(|h| finder.find(h.as_bytes()).is_some()))
+                .collect();
+            Ok(Arc::new(result))
+        }
+        DataType::LargeUtf8 => {
+            let array = haystack_array.as_string::<i64>();
+            let result: BooleanArray = array
+                .iter()
+                .map(|opt_haystack| opt_haystack.map(|h| finder.find(h.as_bytes()).is_some()))
+                .collect();
+            Ok(Arc::new(result))
+        }
+        DataType::Utf8View => {
+            let array = haystack_array.as_string_view();
+            let result: BooleanArray = array
+                .iter()
+                .map(|opt_haystack| opt_haystack.map(|h| finder.find(h.as_bytes()).is_some()))
+                .collect();
+            Ok(Arc::new(result))
+        }
+        other => exec_err!(
+            "contains function requires string type for haystack, got {:?}",
+            other
+        ),
+    }
+}
+
+/// Contains for two scalar values.
+fn contains_scalar_scalar(
+    haystack_scalar: &ScalarValue,
+    needle_scalar: &ScalarValue,
+) -> Result<ScalarValue> {
+    // Handle nulls
+    if haystack_scalar.is_null() || needle_scalar.is_null() {
+        return Ok(ScalarValue::Boolean(None));
+    }
+
+    let haystack_str = match haystack_scalar {
+        ScalarValue::Utf8(Some(s))
+        | ScalarValue::LargeUtf8(Some(s))
+        | ScalarValue::Utf8View(Some(s)) => s.as_str(),
+        _ => {
+            return exec_err!(
+                "contains function requires string type for haystack, got {:?}",
+                haystack_scalar.data_type()
+            )
+        }
+    };
+
+    let needle_str = match needle_scalar {
+        ScalarValue::Utf8(Some(s))
+        | ScalarValue::LargeUtf8(Some(s))
+        | ScalarValue::Utf8View(Some(s)) => s.as_str(),
+        _ => {
+            return exec_err!(
+                "contains function requires string type for needle, got {:?}",
+                needle_scalar.data_type()
+            )
+        }
+    };
+
+    Ok(ScalarValue::Boolean(Some(
+        haystack_str.contains(needle_str),
+    )))
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use arrow::array::StringArray;
+
+    #[test]
+    fn test_contains_array_scalar() {
+        let haystack = Arc::new(StringArray::from(vec![
+            Some("hello world"),
+            Some("foo bar"),
+            Some("testing"),
+            None,
+        ])) as ArrayRef;
+        let needle = ScalarValue::Utf8(Some("world".to_string()));
+
+        let result = contains_with_scalar_pattern(&haystack, &needle).unwrap();
+        let bool_array = result.as_any().downcast_ref::<BooleanArray>().unwrap();
+
+        assert!(bool_array.value(0)); // "hello world" contains "world"
+        assert!(!bool_array.value(1)); // "foo bar" does not contain "world"
+        assert!(!bool_array.value(2)); // "testing" does not contain "world"
+        assert!(bool_array.is_null(3)); // null input => null output
+    }
+
+    #[test]
+    fn test_contains_scalar_scalar() {
+        let haystack = ScalarValue::Utf8(Some("hello world".to_string()));
+        let needle = ScalarValue::Utf8(Some("world".to_string()));
+
+        let result = contains_scalar_scalar(&haystack, &needle).unwrap();
+        assert_eq!(result, ScalarValue::Boolean(Some(true)));
+
+        let needle_not_found = ScalarValue::Utf8(Some("xyz".to_string()));
+        let result = contains_scalar_scalar(&haystack, &needle_not_found).unwrap();
+        assert_eq!(result, ScalarValue::Boolean(Some(false)));
+    }
+
+    #[test]
+    fn test_contains_null_needle() {
+        let haystack = Arc::new(StringArray::from(vec![
+            Some("hello world"),
+            Some("foo bar"),
+        ])) as ArrayRef;
+        let needle = ScalarValue::Utf8(None);
+
+        let result = contains_with_scalar_pattern(&haystack, &needle).unwrap();
+        let bool_array = result.as_any().downcast_ref::<BooleanArray>().unwrap();
+
+        // Null needle should produce null results
+        assert!(bool_array.is_null(0));
+        assert!(bool_array.is_null(1));
+    }
+
+    #[test]
+    fn test_contains_empty_needle() {
+        let haystack = Arc::new(StringArray::from(vec![Some("hello world"), Some("")])) as ArrayRef;
+        let needle = ScalarValue::Utf8(Some("".to_string()));
+
+        let result = contains_with_scalar_pattern(&haystack, &needle).unwrap();
+        let bool_array = result.as_any().downcast_ref::<BooleanArray>().unwrap();
+
+        // Empty string is contained in any string
+        assert!(bool_array.value(0));
+        assert!(bool_array.value(1));
+    }
+}
diff --git a/native/spark-expr/src/string_funcs/mod.rs b/native/spark-expr/src/string_funcs/mod.rs
index aac8204e29..abdd0cc89b 100644
--- a/native/spark-expr/src/string_funcs/mod.rs
+++ b/native/spark-expr/src/string_funcs/mod.rs
@@ -15,8 +15,10 @@
 // specific language governing permissions and limitations
 // under the License.
+mod contains; mod string_space; mod substring; +pub use contains::SparkContains; pub use string_space::SparkStringSpace; pub use substring::SubstringExpr; diff --git a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala index 0352da7850..93b184ad7f 100644 --- a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala @@ -1107,7 +1107,24 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper { // Filter rows that contains 'rose' in 'name' column val queryContains = sql(s"select id from $table where contains (name, 'rose')") - checkAnswer(queryContains, Row(5) :: Nil) + checkSparkAnswerAndOperator(queryContains) + + // Additional test cases for optimized contains implementation + // Test with empty pattern (should match all non-null rows) + val queryEmptyPattern = sql(s"select id from $table where contains (name, '')") + checkSparkAnswerAndOperator(queryEmptyPattern) + + // Test with pattern not found + val queryNotFound = sql(s"select id from $table where contains (name, 'xyz')") + checkSparkAnswerAndOperator(queryNotFound) + + // Test with pattern at start + val queryStart = sql(s"select id from $table where contains (name, 'James')") + checkSparkAnswerAndOperator(queryStart) + + // Test with pattern at end + val queryEnd = sql(s"select id from $table where contains (name, 'Smith')") + checkSparkAnswerAndOperator(queryEnd) } } diff --git a/spark/src/test/scala/org/apache/spark/sql/benchmark/CometStringExpressionBenchmark.scala b/spark/src/test/scala/org/apache/spark/sql/benchmark/CometStringExpressionBenchmark.scala index 41eabb8513..c96cd83438 100644 --- a/spark/src/test/scala/org/apache/spark/sql/benchmark/CometStringExpressionBenchmark.scala +++ b/spark/src/test/scala/org/apache/spark/sql/benchmark/CometStringExpressionBenchmark.scala @@ -72,6 +72,7 @@ object CometStringExpressionBenchmark extends CometBenchmarkBase { StringExprConfig("initCap", "select initCap(c1) from parquetV1Table"), StringExprConfig("trim", "select trim(c1) from parquetV1Table"), StringExprConfig("concatws", "select concat_ws(' ', c1, c1) from parquetV1Table"), + StringExprConfig("contains", "select contains(c1, '123') from parquetV1Table"), StringExprConfig("length", "select length(c1) from parquetV1Table"), StringExprConfig("repeat", "select repeat(c1, 3) from parquetV1Table"), StringExprConfig("reverse", "select reverse(c1) from parquetV1Table"), From a4492975524ccda0ae1ecc1a87a8f81830ff4272 Mon Sep 17 00:00:00 2001 From: shekharrajak Date: Sat, 27 Dec 2025 23:47:16 +0530 Subject: [PATCH 2/5] Use arrow re-export instead of direct arrow-string dependency --- native/spark-expr/Cargo.toml | 1 - native/spark-expr/src/string_funcs/contains.rs | 2 +- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/native/spark-expr/Cargo.toml b/native/spark-expr/Cargo.toml index a0476b2a32..7621c0c974 100644 --- a/native/spark-expr/Cargo.toml +++ b/native/spark-expr/Cargo.toml @@ -28,7 +28,6 @@ edition = { workspace = true } [dependencies] arrow = { workspace = true } -arrow-string = "57.0.0" chrono = { workspace = true } datafusion = { workspace = true } chrono-tz = { workspace = true } diff --git a/native/spark-expr/src/string_funcs/contains.rs b/native/spark-expr/src/string_funcs/contains.rs index c4662ba9d3..9925319880 100644 --- a/native/spark-expr/src/string_funcs/contains.rs +++ b/native/spark-expr/src/string_funcs/contains.rs @@ 
-27,7 +27,7 @@ use arrow::array::{Array, ArrayRef, AsArray, BooleanArray}; use arrow::datatypes::DataType; -use arrow_string::like::contains as arrow_contains; +use arrow::compute::kernels::comparison::contains as arrow_contains; use datafusion::common::{exec_err, Result, ScalarValue}; use datafusion::logical_expr::{ ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility, From 70163ec5a43e66d216add8c85e107b195e0d36d2 Mon Sep 17 00:00:00 2001 From: Shekhar Prasad Rajak <5774448+Shekharrajak@users.noreply.github.com> Date: Sun, 28 Dec 2025 11:18:22 +0530 Subject: [PATCH 3/5] Update native/Cargo.lock Co-authored-by: Andy Grove --- native/Cargo.lock | 1 - 1 file changed, 1 deletion(-) diff --git a/native/Cargo.lock b/native/Cargo.lock index 7369a97d6b..d430e16785 100644 --- a/native/Cargo.lock +++ b/native/Cargo.lock @@ -1872,7 +1872,6 @@ name = "datafusion-comet-spark-expr" version = "0.13.0" dependencies = [ "arrow", - "arrow-string", "base64", "chrono", "chrono-tz", From e50e0777fb16ff33cbf860b791ebd772da131a85 Mon Sep 17 00:00:00 2001 From: shekharrajak Date: Mon, 29 Dec 2025 23:28:24 +0530 Subject: [PATCH 4/5] test fix: org.apache.spark.sql.comet.ParquetEncryptionITCase.SPARK-37117 --- .../org/apache/spark/sql/CometTestBase.scala | 24 +++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala b/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala index 7dba24bff7..6524f677f0 100644 --- a/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala +++ b/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala @@ -528,6 +528,30 @@ abstract class CometTestBase } } + /** + * Override waitForTasksToFinish to ensure SparkContext is active before checking tasks. This + * fixes the issue where waitForTasksToFinish returns -1 when SparkContext is not active. + */ + override protected def waitForTasksToFinish(): Unit = { + // Ensure SparkContext is active before checking tasks + // The parent implementation uses SparkContext.getActive.map(_.activeTasks).getOrElse(-1) + // If SparkContext is not active, it returns -1 which causes the assertion to fail. + // We ensure we have an active SparkContext before calling the parent method. + if (SparkContext.getActive.isEmpty) { + // Ensure we have a SparkContext from the spark session + if (_spark != null) { + // SparkContext from spark session should already be active + // but if not, getOrCreate will make it active + val _ = _spark.sparkContext + } else { + // Fallback to sparkContext which will get or create one + val _ = sparkContext + } + } + // Now call parent implementation which should find an active SparkContext + super.waitForTasksToFinish() + } + protected def readResourceParquetFile(name: String): DataFrame = { spark.read.parquet(getResourceParquetFilePath(name)) } From 83bcafb5a83bdf91bf97c30ab25b95b48ac2e187 Mon Sep 17 00:00:00 2001 From: shekharrajak Date: Tue, 30 Dec 2025 09:47:45 +0530 Subject: [PATCH 5/5] Fix Rust import order and skip INSERT OVERWRITE DIRECTORY in Comet --- native/spark-expr/src/string_funcs/contains.rs | 2 +- .../comet/serde/operator/CometDataWritingCommand.scala | 5 +++++ 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/native/spark-expr/src/string_funcs/contains.rs b/native/spark-expr/src/string_funcs/contains.rs index 9925319880..e1ae756e53 100644 --- a/native/spark-expr/src/string_funcs/contains.rs +++ b/native/spark-expr/src/string_funcs/contains.rs @@ -26,8 +26,8 @@ //! scalar path in arrow-rs. 
use arrow::array::{Array, ArrayRef, AsArray, BooleanArray}; -use arrow::datatypes::DataType; use arrow::compute::kernels::comparison::contains as arrow_contains; +use arrow::datatypes::DataType; use datafusion::common::{exec_err, Result, ScalarValue}; use datafusion::logical_expr::{ ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility, diff --git a/spark/src/main/scala/org/apache/comet/serde/operator/CometDataWritingCommand.scala b/spark/src/main/scala/org/apache/comet/serde/operator/CometDataWritingCommand.scala index 7fdf055217..9cec88a934 100644 --- a/spark/src/main/scala/org/apache/comet/serde/operator/CometDataWritingCommand.scala +++ b/spark/src/main/scala/org/apache/comet/serde/operator/CometDataWritingCommand.scala @@ -50,6 +50,11 @@ object CometDataWritingCommand extends CometOperatorSerde[DataWritingCommandExec override def getSupportLevel(op: DataWritingCommandExec): SupportLevel = { op.cmd match { case cmd: InsertIntoHadoopFsRelationCommand => + // Skip INSERT OVERWRITE DIRECTORY operations (catalogTable is None for directory writes) + if (cmd.catalogTable.isEmpty) { + return Unsupported(Some("INSERT OVERWRITE DIRECTORY is not supported")) + } + cmd.fileFormat match { case _: ParquetFileFormat => if (!cmd.outputPath.toString.startsWith("file:")) {