diff --git a/.github/workflows/pr_build_linux.yml b/.github/workflows/pr_build_linux.yml index beb5f9dcf7..17f66b6630 100644 --- a/.github/workflows/pr_build_linux.yml +++ b/.github/workflows/pr_build_linux.yml @@ -167,6 +167,7 @@ jobs: org.apache.comet.CometStringExpressionSuite org.apache.comet.CometBitwiseExpressionSuite org.apache.comet.CometMapExpressionSuite + org.apache.comet.CometCsvExpressionSuite org.apache.comet.CometJsonExpressionSuite org.apache.comet.expressions.conditional.CometIfSuite org.apache.comet.expressions.conditional.CometCoalesceSuite diff --git a/.github/workflows/pr_build_macos.yml b/.github/workflows/pr_build_macos.yml index 9a45fe022d..80e8854ef6 100644 --- a/.github/workflows/pr_build_macos.yml +++ b/.github/workflows/pr_build_macos.yml @@ -131,6 +131,7 @@ jobs: org.apache.comet.CometBitwiseExpressionSuite org.apache.comet.CometMapExpressionSuite org.apache.comet.CometJsonExpressionSuite + org.apache.comet.CometCsvExpressionSuite org.apache.comet.expressions.conditional.CometIfSuite org.apache.comet.expressions.conditional.CometCoalesceSuite org.apache.comet.expressions.conditional.CometCaseWhenSuite diff --git a/docs/source/user-guide/latest/configs.md b/docs/source/user-guide/latest/configs.md index 1a273ad033..bd062ec587 100644 --- a/docs/source/user-guide/latest/configs.md +++ b/docs/source/user-guide/latest/configs.md @@ -324,6 +324,7 @@ These settings can be used to determine which parts of the plan are accelerated | `spark.comet.expression.StringTrimBoth.enabled` | Enable Comet acceleration for `StringTrimBoth` | true | | `spark.comet.expression.StringTrimLeft.enabled` | Enable Comet acceleration for `StringTrimLeft` | true | | `spark.comet.expression.StringTrimRight.enabled` | Enable Comet acceleration for `StringTrimRight` | true | +| `spark.comet.expression.StructsToCsv.enabled` | Enable Comet acceleration for `StructsToCsv` | true | | `spark.comet.expression.StructsToJson.enabled` | Enable Comet acceleration for `StructsToJson` | true | | `spark.comet.expression.Substring.enabled` | Enable Comet acceleration for `Substring` | true | | `spark.comet.expression.Subtract.enabled` | Enable Comet acceleration for `Subtract` | true | diff --git a/native/core/src/execution/planner.rs b/native/core/src/execution/planner.rs index 93fbb59c11..06decf332a 100644 --- a/native/core/src/execution/planner.rs +++ b/native/core/src/execution/planner.rs @@ -71,7 +71,7 @@ use datafusion::{ use datafusion_comet_spark_expr::{ create_comet_physical_fun, create_comet_physical_fun_with_eval_mode, BinaryOutputStyle, BloomFilterAgg, BloomFilterMightContain, EvalMode, SparkHour, SparkMinute, SparkSecond, - SumInteger, + SumInteger, ToCsv, }; use iceberg::expr::Bind; @@ -644,6 +644,18 @@ impl PhysicalPlanner { ExprStruct::MonotonicallyIncreasingId(_) => Ok(Arc::new( MonotonicallyIncreasingId::from_partition_id(self.partition), )), + ExprStruct::ToCsv(expr) => { + let csv_struct_expr = + self.create_expr(expr.child.as_ref().unwrap(), Arc::clone(&input_schema))?; + let options = expr.options.clone().unwrap(); + Ok(Arc::new(ToCsv::new( + csv_struct_expr, + &options.delimiter, + &options.quote, + &options.escape, + &options.null_value, + ))) + } expr => Err(GeneralError(format!("Not implemented: {expr:?}"))), } } diff --git a/native/proto/src/proto/expr.proto b/native/proto/src/proto/expr.proto index 5f258fd677..f3b27380be 100644 --- a/native/proto/src/proto/expr.proto +++ b/native/proto/src/proto/expr.proto @@ -86,6 +86,7 @@ message Expr { EmptyExpr spark_partition_id = 63; 
EmptyExpr monotonically_increasing_id = 64; FromJson from_json = 89; + ToCsv to_csv = 90; } } @@ -275,6 +276,25 @@ message FromJson { string timezone = 3; } +message ToCsv { + Expr child = 1; + CsvWriteOptions options = 2; +} + +message CsvWriteOptions { + string delimiter = 1; + string quote = 2; + string escape = 3; + string null_value = 4; + bool quote_all = 5; + bool ignore_leading_white_space = 6; + bool ignore_trailing_white_space = 7; + string date_format = 8; + string timestamp_format = 9; + string timestamp_ntz_format = 10; + string timezone = 11; +} + enum BinaryOutputStyle { UTF8 = 0; BASIC = 1; diff --git a/native/spark-expr/Cargo.toml b/native/spark-expr/Cargo.toml index 94653d8864..fd0a211b29 100644 --- a/native/spark-expr/Cargo.toml +++ b/native/spark-expr/Cargo.toml @@ -88,6 +88,10 @@ harness = false name = "normalize_nan" harness = false +[[bench]] +name = "to_csv" +harness = false + [[test]] name = "test_udf_registration" path = "tests/spark_expr_reg.rs" diff --git a/native/spark-expr/benches/to_csv.rs b/native/spark-expr/benches/to_csv.rs new file mode 100644 index 0000000000..55cd9af7cb --- /dev/null +++ b/native/spark-expr/benches/to_csv.rs @@ -0,0 +1,94 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
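+
+// A descriptive note on the workload (see `create_struct_array` below): an
+// 8192-row struct of mixed primitive columns, with roughly one null per ten
+// values in each column.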
+
+use arrow::array::{
+    BooleanBuilder, Int16Builder, Int32Builder, Int64Builder, Int8Builder, StringBuilder,
+    StructArray, StructBuilder,
+};
+use arrow::datatypes::{DataType, Field};
+use criterion::{criterion_group, criterion_main, Criterion};
+use datafusion_comet_spark_expr::struct_to_csv;
+use std::hint::black_box;
+
+fn create_struct_array(array_size: usize) -> StructArray {
+    let fields = vec![
+        Field::new("f1", DataType::Boolean, true),
+        Field::new("f2", DataType::Int8, true),
+        Field::new("f3", DataType::Int16, true),
+        Field::new("f4", DataType::Int32, true),
+        Field::new("f5", DataType::Int64, true),
+        Field::new("f6", DataType::Utf8, true),
+    ];
+    let mut struct_builder = StructBuilder::from_fields(fields, array_size);
+    for i in 0..array_size {
+        struct_builder
+            .field_builder::<BooleanBuilder>(0)
+            .unwrap()
+            .append_option(if i % 10 == 0 { None } else { Some(i % 2 == 0) });
+
+        struct_builder
+            .field_builder::<Int8Builder>(1)
+            .unwrap()
+            .append_option(if i % 10 == 0 {
+                None
+            } else {
+                Some((i % 128) as i8)
+            });
+
+        struct_builder
+            .field_builder::<Int16Builder>(2)
+            .unwrap()
+            .append_option(if i % 10 == 0 { None } else { Some(i as i16) });
+
+        struct_builder
+            .field_builder::<Int32Builder>(3)
+            .unwrap()
+            .append_option(if i % 10 == 0 { None } else { Some(i as i32) });
+
+        struct_builder
+            .field_builder::<Int64Builder>(4)
+            .unwrap()
+            .append_option(if i % 10 == 0 { None } else { Some(i as i64) });
+
+        struct_builder
+            .field_builder::<StringBuilder>(5)
+            .unwrap()
+            .append_option(if i % 10 == 0 {
+                None
+            } else {
+                Some(format!("string_{}", i))
+            });
+
+        struct_builder.append(true);
+    }
+    struct_builder.finish()
+}
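+
+// A usage note (assuming the standard criterion workflow used by the other
+// spark-expr benches): run with `cargo bench --bench to_csv` from the
+// `native/spark-expr` directory.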
+
+fn criterion_benchmark(c: &mut Criterion) {
+    let array_size = 8192;
+    let struct_array = create_struct_array(array_size);
+    let default_delimiter = ",";
+    let default_null_value = "";
+    c.bench_function("to_csv", |b| {
+        b.iter(|| {
+            black_box(struct_to_csv(&struct_array, default_delimiter, default_null_value).unwrap())
+        })
+    });
+}
+
+criterion_group!(benches, criterion_benchmark);
+criterion_main!(benches);
diff --git a/native/spark-expr/src/csv_funcs/mod.rs b/native/spark-expr/src/csv_funcs/mod.rs
new file mode 100644
index 0000000000..311b509297
--- /dev/null
+++ b/native/spark-expr/src/csv_funcs/mod.rs
@@ -0,0 +1,20 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+mod to_csv;
+
+pub use to_csv::{ToCsv, struct_to_csv};
diff --git a/native/spark-expr/src/csv_funcs/to_csv.rs b/native/spark-expr/src/csv_funcs/to_csv.rs
new file mode 100644
index 0000000000..68f09c7ebe
--- /dev/null
+++ b/native/spark-expr/src/csv_funcs/to_csv.rs
@@ -0,0 +1,230 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use arrow::array::{
+    as_boolean_array, as_largestring_array, as_string_array, as_struct_array, Array, ArrayRef,
+    StringBuilder,
+};
+use arrow::array::{RecordBatch, StructArray};
+use arrow::datatypes::{DataType, Schema};
+use datafusion::common::cast::{as_int16_array, as_int32_array, as_int64_array, as_int8_array};
+use datafusion::common::{exec_err, Result};
+use datafusion::logical_expr::ColumnarValue;
+use datafusion::physical_expr::PhysicalExpr;
+use std::any::Any;
+use std::fmt::{Display, Formatter};
+use std::hash::{Hash, Hasher};
+use std::sync::Arc;
+
+/// Spark-compatible `to_csv` expression: formats a struct column as CSV text.
+#[derive(Debug, Eq)]
+pub struct ToCsv {
+    expr: Arc<dyn PhysicalExpr>,
+    delimiter: String,
+    // `quote` and `escape` are carried through from the serialized options but
+    // are not yet applied during formatting (quoting is unimplemented).
+    quote: String,
+    escape: String,
+    null_value: String,
+}
+
+impl Hash for ToCsv {
+    fn hash<H: Hasher>(&self, state: &mut H) {
+        self.expr.hash(state);
+        self.delimiter.hash(state);
+        self.quote.hash(state);
+        self.escape.hash(state);
+        self.null_value.hash(state);
+    }
+}
+
+impl PartialEq for ToCsv {
+    fn eq(&self, other: &Self) -> bool {
+        self.expr.eq(&other.expr)
+            && self.delimiter.eq(&other.delimiter)
+            && self.quote.eq(&other.quote)
+            && self.escape.eq(&other.escape)
+            && self.null_value.eq(&other.null_value)
+    }
+}
+
+impl ToCsv {
+    pub fn new(
+        expr: Arc<dyn PhysicalExpr>,
+        delimiter: &str,
+        quote: &str,
+        escape: &str,
+        null_value: &str,
+    ) -> Self {
+        Self {
+            expr,
+            delimiter: delimiter.to_owned(),
+            quote: quote.to_owned(),
+            escape: escape.to_owned(),
+            null_value: null_value.to_owned(),
+        }
+    }
+}
+
+impl Display for ToCsv {
+    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
+        write!(
+            f,
+            "to_csv({}, delimiter={}, quote={}, escape={}, null_value={})",
+            self.expr, self.delimiter, self.quote, self.escape, self.null_value
+        )
+    }
+}
+
+impl PhysicalExpr for ToCsv {
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    fn data_type(&self, _: &Schema) -> Result<DataType> {
+        Ok(DataType::Utf8)
+    }
+
+    fn nullable(&self, input_schema: &Schema) -> Result<bool> {
+        self.expr.nullable(input_schema)
+    }
+
+    fn evaluate(&self, batch: &RecordBatch) -> Result<ColumnarValue> {
+        let input_value = self.expr.evaluate(batch)?.into_array(batch.num_rows())?;
+
+        let struct_array = as_struct_array(&input_value);
+
+        let result = struct_to_csv(struct_array, &self.delimiter, &self.null_value)?;
+
+        Ok(ColumnarValue::Array(result))
+    }
+
+    fn children(&self) -> Vec<&Arc<dyn PhysicalExpr>> {
+        vec![&self.expr]
+    }
+
+    fn with_new_children(
+        self: Arc<Self>,
+        children: Vec<Arc<dyn PhysicalExpr>>,
+    ) -> Result<Arc<dyn PhysicalExpr>> {
+        Ok(Arc::new(Self::new(
+            Arc::clone(&children[0]),
+            &self.delimiter,
+            &self.quote,
+            &self.escape,
+            &self.null_value,
+        )))
+    }
+
+    fn fmt_sql(&self, _: &mut Formatter<'_>) -> std::fmt::Result {
+        unimplemented!()
+    }
+}
+
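+/// Formats each row of `array` as one delimiter-separated line. A sketch of
+/// the expected mapping, assuming `delimiter = ","` and `null_value = ""`:
+///
+/// ```text
+/// {a: 1, b: "foo"}  -> "1,foo"
+/// {a: 2, b: null}   -> "2,"
+/// null struct row   -> null
+/// ```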
+pub fn struct_to_csv(array: &StructArray, delimiter: &str, null_value: &str) -> Result<ArrayRef> {
+    let mut builder = StringBuilder::with_capacity(array.len(), array.len() * 16);
+    let mut csv_string = String::with_capacity(array.len() * 16);
+
+    for row_idx in 0..array.len() {
+        if array.is_null(row_idx) {
+            builder.append_null();
+        } else {
+            csv_string.clear();
+            for (col_idx, column) in array.columns().iter().enumerate() {
+                if col_idx > 0 {
+                    csv_string.push_str(delimiter);
+                }
+                if column.is_null(row_idx) {
+                    csv_string.push_str(null_value);
+                } else {
+                    convert_to_string(column, &mut csv_string, row_idx)?;
+                }
+            }
+            builder.append_value(&csv_string);
+        }
+    }
+    Ok(Arc::new(builder.finish()))
+}
+
+#[inline]
+fn convert_to_string(array: &ArrayRef, csv_string: &mut String, row_idx: usize) -> Result<()> {
+    match array.data_type() {
+        DataType::Boolean => {
+            let array = as_boolean_array(array);
+            csv_string.push_str(&array.value(row_idx).to_string())
+        }
+        DataType::Int8 => {
+            let array = as_int8_array(array)?;
+            csv_string.push_str(&array.value(row_idx).to_string())
+        }
+        DataType::Int16 => {
+            let array = as_int16_array(array)?;
+            csv_string.push_str(&array.value(row_idx).to_string())
+        }
+        DataType::Int32 => {
+            let array = as_int32_array(array)?;
+            csv_string.push_str(&array.value(row_idx).to_string())
+        }
+        DataType::Int64 => {
+            let array = as_int64_array(array)?;
+            csv_string.push_str(&array.value(row_idx).to_string())
+        }
+        DataType::Utf8 => {
+            let array = as_string_array(array);
+            csv_string.push_str(array.value(row_idx))
+        }
+        DataType::LargeUtf8 => {
+            let array = as_largestring_array(array);
+            csv_string.push_str(array.value(row_idx))
+        }
+        _ => return exec_err!("to_csv not implemented for type: {:?}", array.data_type()),
+    }
+    Ok(())
+}
+
+#[cfg(test)]
+mod tests {
+    use crate::csv_funcs::to_csv::struct_to_csv;
+    use arrow::array::{as_string_array, ArrayRef, Int32Array, StringArray, StructArray};
+    use arrow::datatypes::{DataType, Field};
+    use datafusion::common::Result;
+    use std::sync::Arc;
+
+    #[test]
+    fn test_to_csv_basic() -> Result<()> {
+        let struct_array = StructArray::from(vec![
+            (
+                Arc::new(Field::new("a", DataType::Int32, false)),
+                Arc::new(Int32Array::from(vec![1, 2, 3])) as ArrayRef,
+            ),
+            (
+                Arc::new(Field::new("b", DataType::Utf8, true)),
+                Arc::new(StringArray::from(vec![Some("foo"), None, Some("baz")])) as ArrayRef,
+            ),
+        ]);
+
+        let expected = &StringArray::from(vec!["1,foo", "2,", "3,baz"]);
+
+        let result = struct_to_csv(&struct_array, ",", "")?;
+        let result = as_string_array(&result);
+
+        assert_eq!(result, expected);
+
+        Ok(())
+    }
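+
+    // A hedged addition: exercises the top-level null branch of
+    // `struct_to_csv` (a null struct row must yield a null output row,
+    // not an empty string).
+    #[test]
+    fn test_to_csv_null_struct_row() -> Result<()> {
+        use arrow::array::{Int32Builder, StructBuilder};
+
+        let fields = vec![Field::new("a", DataType::Int32, true)];
+        let mut builder = StructBuilder::from_fields(fields, 3);
+        builder
+            .field_builder::<Int32Builder>(0)
+            .unwrap()
+            .append_value(1);
+        builder.append(true);
+        builder
+            .field_builder::<Int32Builder>(0)
+            .unwrap()
+            .append_null();
+        builder.append(false); // the whole struct row is null
+        builder
+            .field_builder::<Int32Builder>(0)
+            .unwrap()
+            .append_value(3);
+        builder.append(true);
+        let struct_array = builder.finish();
+
+        let result = struct_to_csv(&struct_array, ",", "")?;
+        let result = as_string_array(&result);
+
+        assert_eq!(result.value(0), "1");
+        assert!(result.is_null(1));
+        assert_eq!(result.value(2), "3");
+
+        Ok(())
+    }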
+}
diff --git a/native/spark-expr/src/lib.rs b/native/spark-expr/src/lib.rs
index f26fd911d8..d770338eaa 100644
--- a/native/spark-expr/src/lib.rs
+++ b/native/spark-expr/src/lib.rs
@@ -56,6 +56,7 @@ pub use bloom_filter::{BloomFilterAgg, BloomFilterMightContain};
 
 mod conditional_funcs;
 mod conversion_funcs;
+mod csv_funcs;
 mod math_funcs;
 mod nondetermenistic_funcs;
 
@@ -69,6 +70,7 @@ pub use comet_scalar_funcs::{
     create_comet_physical_fun, create_comet_physical_fun_with_eval_mode,
     register_all_comet_functions,
 };
+pub use csv_funcs::*;
 pub use datetime_funcs::{SparkDateTrunc, SparkHour, SparkMinute, SparkSecond, TimestampTruncExpr};
 pub use error::{SparkError, SparkResult};
 pub use hash_funcs::*;
diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
index e50b1d80e6..47c96d10cf 100644
--- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
@@ -133,7 +133,8 @@ object QueryPlanSerde extends Logging with CometExprShim {
     classOf[GetArrayStructFields] -> CometGetArrayStructFields,
     classOf[GetStructField] -> CometGetStructField,
     classOf[JsonToStructs] -> CometJsonToStructs,
-    classOf[StructsToJson] -> CometStructsToJson)
+    classOf[StructsToJson] -> CometStructsToJson,
+    classOf[StructsToCsv] -> CometStructsToCsv)
 
   private val hashExpressions: Map[Class[_ <: Expression], CometExpressionSerde[_]] = Map(
     classOf[Md5] -> CometScalarFunction("md5"),
diff --git a/spark/src/main/scala/org/apache/comet/serde/structs.scala b/spark/src/main/scala/org/apache/comet/serde/structs.scala
index b76c64bac9..f606c5aa24 100644
--- a/spark/src/main/scala/org/apache/comet/serde/structs.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/structs.scala
@@ -20,9 +20,10 @@ package org.apache.comet.serde
 
 import scala.jdk.CollectionConverters._
+import scala.util.Try
 
-import org.apache.spark.sql.catalyst.expressions.{Attribute, CreateNamedStruct, GetArrayStructFields, GetStructField, JsonToStructs, StructsToJson}
-import org.apache.spark.sql.types.{ArrayType, DataType, DataTypes, MapType, StructType}
+import org.apache.spark.sql.catalyst.expressions.{Attribute, CreateNamedStruct, GetArrayStructFields, GetStructField, JsonToStructs, StructsToCsv, StructsToJson}
+import org.apache.spark.sql.types._
 
 import org.apache.comet.CometSparkSessionExtensions.withInfo
 import org.apache.comet.serde.QueryPlanSerde.{exprToProtoInternal, serializeDataType}
@@ -230,3 +231,62 @@ object CometJsonToStructs extends CometExpressionSerde[JsonToStructs] {
     }
   }
 }
+
+object CometStructsToCsv extends CometExpressionSerde[StructsToCsv] {
+
+  override def getSupportLevel(expr: StructsToCsv): SupportLevel = {
+    val isSupportedSchema = expr.inputSchema.fields
+      .forall(sf => QueryPlanSerde.supportedDataType(sf.dataType))
+    if (!isSupportedSchema) {
+      return Unsupported(Some(s"Unsupported data type: ${expr.inputSchema}"))
+    }
+    Incompatible()
+  }
+
+  override def convert(
+      expr: StructsToCsv,
+      inputs: Seq[Attribute],
+      binding: Boolean): Option[ExprOuterClass.Expr] = {
+    for {
+      childProto <- exprToProtoInternal(expr.child, inputs, binding)
+    } yield {
+      val optionsProto = options2Proto(expr.options, expr.timeZoneId)
+      val toCsv = ExprOuterClass.ToCsv
+        .newBuilder()
+        .setChild(childProto)
+        .setOptions(optionsProto)
+        .build()
+      ExprOuterClass.Expr.newBuilder().setToCsv(toCsv).build()
+    }
+  }
+
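+  /**
+   * Maps Spark's `to_csv` writer options onto the CsvWriteOptions proto,
+   * falling back to Spark's documented CSV defaults. For example (a sketch),
+   * `to_csv(struct, Map("delimiter" -> ";").asJava)` should serialize with
+   * delimiter ";" and every other field left at the defaults below.
+   */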
"yyyy-MM-dd'T'HH:mm:ss[.SSS][XXX]")) + .setTimestampNtzFormat(options + .getOrElse("timestampNTZFormat", "yyyy-MM-dd'T'HH:mm:ss[.SSS]")) + .build() + } +} diff --git a/spark/src/test/scala/org/apache/comet/CometCsvExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometCsvExpressionSuite.scala new file mode 100644 index 0000000000..421bdc9625 --- /dev/null +++ b/spark/src/test/scala/org/apache/comet/CometCsvExpressionSuite.scala @@ -0,0 +1,58 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.comet + +import scala.util.Random + +import org.apache.hadoop.fs.Path +import org.apache.spark.sql.CometTestBase +import org.apache.spark.sql.catalyst.expressions.StructsToCsv +import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper +import org.apache.spark.sql.functions._ + +import org.apache.comet.testing.{DataGenOptions, ParquetGenerator, SchemaGenOptions} + +class CometCsvExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper { + + test("to_csv") { + withTempDir { dir => + val path = new Path(dir.toURI.toString, "test.parquet") + val filename = path.toString + val random = new Random(42) + withSQLConf(CometConf.COMET_ENABLED.key -> "false") { + ParquetGenerator.makeParquetFile( + random, + spark, + filename, + 100, + SchemaGenOptions(generateArray = false, generateStruct = false, generateMap = false), + DataGenOptions(allowNull = true, generateNegativeZero = true)) + } + withSQLConf(CometConf.getExprAllowIncompatConfigKey(classOf[StructsToCsv]) -> "true") { + val df = spark.read + .parquet(filename) + .select(to_csv(struct(col("c0"), col("c1"), col("c2")))) + df.explain(true) + df.printSchema() + checkSparkAnswer(df) + } + } + } +} diff --git a/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala b/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala index 81ac72247f..e2c6151f90 100644 --- a/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala +++ b/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala @@ -22,6 +22,7 @@ package org.apache.spark.sql import java.util.concurrent.atomic.AtomicInteger import scala.concurrent.duration._ +import scala.jdk.CollectionConverters._ import scala.reflect.ClassTag import scala.reflect.runtime.universe.TypeTag import scala.util.{Success, Try} @@ -43,7 +44,7 @@ import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper import org.apache.spark.sql.internal._ import org.apache.spark.sql.test._ -import org.apache.spark.sql.types.{DecimalType, StructType} +import org.apache.spark.sql.types.{DecimalType, StringType, StructType} import org.apache.comet._ import org.apache.comet.shims.ShimCometSparkSessionExtensions @@ -128,6 +129,10 @@ abstract class CometTestBase if 
+}
diff --git a/spark/src/test/scala/org/apache/spark/sql/benchmark/CometCsvExpressionBenchmark.scala b/spark/src/test/scala/org/apache/spark/sql/benchmark/CometCsvExpressionBenchmark.scala
new file mode 100644
index 0000000000..94288eb9cb
--- /dev/null
+++ b/spark/src/test/scala/org/apache/spark/sql/benchmark/CometCsvExpressionBenchmark.scala
@@ -0,0 +1,84 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.spark.sql.benchmark
+
+import org.apache.spark.sql.catalyst.expressions.StructsToCsv
+
+import org.apache.comet.CometConf
+
+/**
+ * Configuration for a CSV expression benchmark.
+ *
+ * @param name
+ *   Name for the benchmark
+ * @param query
+ *   SQL query to benchmark
+ * @param extraCometConfigs
+ *   Additional Comet configurations for the scan+exec case
+ */
+case class CsvExprConfig(
+    name: String,
+    query: String,
+    extraCometConfigs: Map[String, String] = Map.empty)
+
+// spotless:off
+/**
+ * Benchmark to measure performance of Comet CSV expressions. To run this benchmark:
+ * `SPARK_GENERATE_BENCHMARK_FILES=1 make
+ * benchmark-org.apache.spark.sql.benchmark.CometCsvExpressionBenchmark`
+ * Results will be written to "spark/benchmarks/CometCsvExpressionBenchmark-**results.txt".
+ */
+// spotless:on
+object CometCsvExpressionBenchmark extends CometBenchmarkBase {
+
+  /**
+   * Generic method to run a CSV expression benchmark with the given configuration.
+   */
+  def runCsvExprBenchmark(config: CsvExprConfig, values: Int): Unit = {
+    withTempPath { dir =>
+      withTempTable("parquetV1Table") {
+        prepareTable(
+          dir,
+          spark.sql(
+            s"SELECT CAST(value AS STRING) AS c1, CAST(value AS INT) AS c2, CAST(value AS LONG) AS c3 FROM $tbl"))
+
+        val extraConfigs = Map(
+          CometConf.getExprAllowIncompatConfigKey(
+            classOf[StructsToCsv]) -> "true") ++ config.extraCometConfigs
+
+        runExpressionBenchmark(config.name, values, config.query, extraConfigs)
+      }
+    }
+  }
+
+  // Configuration for all CSV expression benchmarks
+  private val csvExpressions = List(
+    CsvExprConfig("to_csv", "SELECT to_csv(struct(c1, c2, c3)) FROM parquetV1Table"))
+
+  override def runCometBenchmark(args: Array[String]): Unit = {
+    val values = 1024 * 1024
+
+    csvExpressions.foreach { config =>
+      runBenchmarkWithTable(config.name, values) { value =>
+        runCsvExprBenchmark(config, value)
+      }
+    }
+  }
+}