diff --git a/spark/src/main/scala/org/apache/comet/serde/structs.scala b/spark/src/main/scala/org/apache/comet/serde/structs.scala
index 55e031d346..b76c64bac9 100644
--- a/spark/src/main/scala/org/apache/comet/serde/structs.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/structs.scala
@@ -111,26 +111,6 @@ object CometStructsToJson extends CometExpressionSerde[StructsToJson] {
       withInfo(expr, "StructsToJson with options is not supported")
       None
     } else {
-
-      def isSupportedType(dt: DataType): Boolean = {
-        dt match {
-          case StructType(fields) =>
-            fields.forall(f => isSupportedType(f.dataType))
-          case DataTypes.BooleanType | DataTypes.ByteType | DataTypes.ShortType |
-              DataTypes.IntegerType | DataTypes.LongType | DataTypes.FloatType |
-              DataTypes.DoubleType | DataTypes.StringType =>
-            true
-          case DataTypes.DateType | DataTypes.TimestampType =>
-            // TODO implement these types with tests for formatting options and timezone
-            false
-          case _: MapType | _: ArrayType =>
-            // Spark supports map and array in StructsToJson but this is not yet
-            // implemented in Comet
-            false
-          case _ => false
-        }
-      }
-
       val isSupported = expr.child.dataType match {
         case s: StructType =>
           s.fields.forall(f => isSupportedType(f.dataType))
@@ -166,6 +146,25 @@ object CometStructsToJson extends CometExpressionSerde[StructsToJson] {
       }
     }
   }
+
+  def isSupportedType(dt: DataType): Boolean = {
+    dt match {
+      case StructType(fields) =>
+        fields.forall(f => isSupportedType(f.dataType))
+      case DataTypes.BooleanType | DataTypes.ByteType | DataTypes.ShortType |
+          DataTypes.IntegerType | DataTypes.LongType | DataTypes.FloatType |
+          DataTypes.DoubleType | DataTypes.StringType =>
+        true
+      case DataTypes.DateType | DataTypes.TimestampType =>
+        // TODO implement these types with tests for formatting options and timezone
+        false
+      case _: MapType | _: ArrayType =>
+        // Spark supports map and array in StructsToJson but this is not yet
+        // implemented in Comet
+        false
+      case _ => false
+    }
+  }
 }
 
 object CometJsonToStructs extends CometExpressionSerde[JsonToStructs] {
diff --git a/spark/src/main/scala/org/apache/comet/testing/FuzzDataGenerator.scala b/spark/src/main/scala/org/apache/comet/testing/FuzzDataGenerator.scala
index 00a85930ba..24daebe132 100644
--- a/spark/src/main/scala/org/apache/comet/testing/FuzzDataGenerator.scala
+++ b/spark/src/main/scala/org/apache/comet/testing/FuzzDataGenerator.scala
@@ -229,8 +229,8 @@ object FuzzDataGenerator {
       Range(0, numRows).map(_ => {
         r.nextInt(20) match {
           case 0 if options.allowNull => null
-          case 1 => Float.NegativeInfinity
-          case 2 => Float.PositiveInfinity
+          case 1 if options.generateInfinity => Float.NegativeInfinity
+          case 2 if options.generateInfinity => Float.PositiveInfinity
           case 3 => Float.MinValue
           case 4 => Float.MaxValue
           case 5 => 0.0f
@@ -243,8 +243,8 @@ object FuzzDataGenerator {
       Range(0, numRows).map(_ => {
         r.nextInt(20) match {
           case 0 if options.allowNull => null
-          case 1 => Double.NegativeInfinity
-          case 2 => Double.PositiveInfinity
+          case 1 if options.generateInfinity => Double.NegativeInfinity
+          case 2 if options.generateInfinity => Double.PositiveInfinity
           case 3 => Double.MinValue
           case 4 => Double.MaxValue
           case 5 => 0.0
@@ -329,4 +329,5 @@ case class DataGenOptions(
     generateNaN: Boolean = true,
     baseDate: Long = FuzzDataGenerator.defaultBaseDate,
     customStrings: Seq[String] = Seq.empty,
-    maxStringLength: Int = 8)
+    maxStringLength: Int = 8,
+    generateInfinity: Boolean = true)
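The FuzzDataGenerator change above adds a generateInfinity flag that defaults to true, so existing callers are unaffected. A minimal sketch of how a caller might use it, assuming only the DataGenOptions fields shown in this patch:

    // Sketch only (not part of the patch): request fully finite
    // float/double data by disabling both NaN and infinity generation.
    // All other fields keep the defaults from the case class above.
    import org.apache.comet.testing.DataGenOptions

    val finiteOnly: DataGenOptions =
      DataGenOptions(generateNaN = false, generateInfinity = false)

Note that when the new guards are false, the values 1 and 2 from r.nextInt(20) simply fall through to the later cases of the match; this assumes the match ends in a default case for ordinary random values, which is not shown in these hunks.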
diff --git a/spark/src/test/scala/org/apache/comet/CometJsonExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometJsonExpressionSuite.scala
index 38f5765268..64c330dbdd 100644
--- a/spark/src/test/scala/org/apache/comet/CometJsonExpressionSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/CometJsonExpressionSuite.scala
@@ -19,24 +19,59 @@
 
 package org.apache.comet
 
+import scala.util.Random
+
 import org.scalactic.source.Position
 import org.scalatest.Tag
 
+import org.apache.hadoop.fs.Path
 import org.apache.spark.sql.CometTestBase
-import org.apache.spark.sql.catalyst.expressions.JsonToStructs
+import org.apache.spark.sql.catalyst.expressions.{JsonToStructs, StructsToJson}
 import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
+import org.apache.spark.sql.functions._
+
+import org.apache.comet.CometSparkSessionExtensions.isSpark40Plus
+import org.apache.comet.serde.CometStructsToJson
+import org.apache.comet.testing.{DataGenOptions, ParquetGenerator, SchemaGenOptions}
 
 class CometJsonExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
 
   override protected def test(testName: String, testTags: Tag*)(testFun: => Any)(implicit
       pos: Position): Unit = {
     super.test(testName, testTags: _*) {
-      withSQLConf(CometConf.getExprAllowIncompatConfigKey(classOf[JsonToStructs]) -> "true") {
+      withSQLConf(
+        CometConf.getExprAllowIncompatConfigKey(classOf[JsonToStructs]) -> "true",
+        CometConf.getExprAllowIncompatConfigKey(classOf[StructsToJson]) -> "true") {
         testFun
       }
     }
   }
 
+  test("to_json - all supported types") {
+    assume(!isSpark40Plus)
+    withTempDir { dir =>
+      val path = new Path(dir.toURI.toString, "test.parquet")
+      val filename = path.toString
+      val random = new Random(42)
+      withSQLConf(CometConf.COMET_ENABLED.key -> "false") {
+        ParquetGenerator.makeParquetFile(
+          random,
+          spark,
+          filename,
+          100,
+          SchemaGenOptions(generateArray = false, generateStruct = false, generateMap = false),
+          DataGenOptions(generateNaN = false, generateInfinity = false))
+      }
+      val table = spark.read.parquet(filename)
+      val fieldsNames = table.schema.fields
+        .filter(sf => CometStructsToJson.isSupportedType(sf.dataType))
+        .map(sf => col(sf.name))
+        .toSeq
+      val df = table.select(to_json(struct(fieldsNames: _*)))
+      checkSparkAnswerAndOperator(df)
+    }
+  }
+
   test("from_json - basic primitives") {
     Seq(true, false).foreach { dictionaryEnabled =>
       withParquetTable(
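The new test is why isSupportedType moved from a local function to a public method on CometStructsToJson: the suite filters the fuzz-generated schema down to the columns Comet can serialize. The core pattern, reduced to stock Spark APIs (the path and column names here are hypothetical):

    // Sketch of the test's core pattern: wrap the chosen columns in a
    // single struct column and serialize it with to_json.
    import org.apache.spark.sql.functions.{col, struct, to_json}

    val table = spark.read.parquet("/tmp/test.parquet") // hypothetical path
    val df = table.select(to_json(struct(col("c0"), col("c1"))).as("json"))
    df.show(truncate = false)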
"str", concat("value_", CAST(value AS STRING)) + ) AS json_struct FROM $tbl + """) + + case "to_json - with nulls" => + spark.sql(s""" + SELECT + CASE + WHEN value % 10 = 0 THEN CAST(NULL AS STRUCT) + WHEN value % 5 = 0 THEN named_struct("a", CAST(NULL AS INT), "b", "test") + WHEN value % 3 = 0 THEN named_struct("a", CAST(123 AS INT), "b", CAST(NULL AS STRING)) + ELSE named_struct("a", CAST(value AS INT), "b", concat("str_", CAST(value AS STRING))) + END AS json_struct + FROM $tbl + """) + + case "to_json - nested struct" => + spark.sql(s""" + SELECT named_struct( + "outer", named_struct( + "inner_a", CAST(value AS INT), + "inner_b", concat("nested_", CAST(value AS STRING)) + ) + ) AS json_struct FROM $tbl + """) + case _ => spark.sql(s""" SELECT @@ -117,8 +155,9 @@ object CometJsonExpressionBenchmark extends CometBenchmarkBase { prepareTable(dir, jsonData) val extraConfigs = Map( + CometConf.getExprAllowIncompatConfigKey(classOf[JsonToStructs]) -> "true", CometConf.getExprAllowIncompatConfigKey( - classOf[JsonToStructs]) -> "true") ++ config.extraCometConfigs + classOf[StructsToJson]) -> "true") ++ config.extraCometConfigs runExpressionBenchmark(config.name, values, config.query, extraConfigs) } @@ -127,6 +166,7 @@ object CometJsonExpressionBenchmark extends CometBenchmarkBase { // Configuration for all JSON expression benchmarks private val jsonExpressions = List( + // from_json tests JsonExprConfig( "from_json - simple primitives", "a INT, b STRING", @@ -146,7 +186,25 @@ object CometJsonExpressionBenchmark extends CometBenchmarkBase { JsonExprConfig( "from_json - field access", "a INT, b STRING", - "SELECT from_json(json_str, 'a INT, b STRING').a FROM parquetV1Table")) + "SELECT from_json(json_str, 'a INT, b STRING').a FROM parquetV1Table"), + + // to_json tests + JsonExprConfig( + "to_json - simple primitives", + "a INT, b STRING", + "SELECT to_json(json_struct) FROM parquetV1Table"), + JsonExprConfig( + "to_json - all primitive types", + "i32 INT, i64 BIGINT, f32 FLOAT, f64 DOUBLE, bool BOOLEAN, str STRING", + "SELECT to_json(json_struct) FROM parquetV1Table"), + JsonExprConfig( + "to_json - with nulls", + "a INT, b STRING", + "SELECT to_json(json_struct) FROM parquetV1Table"), + JsonExprConfig( + "to_json - nested struct", + "outer STRUCT", + "SELECT to_json(json_struct) FROM parquetV1Table")) override def runCometBenchmark(mainArgs: Array[String]): Unit = { val values = 1024 * 1024