Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions common/src/main/scala/org/apache/comet/CometConf.scala
Original file line number Diff line number Diff line change
Expand Up @@ -376,6 +376,17 @@ object CometConf extends ShimCometConf {
.intConf
.createWithDefault(1)

// Feature flag gating Comet columnar shuffle for complex types (struct, array, map).
// Defaults to false, i.e. complex-typed shuffle columns fall back to Spark's shuffle;
// supportedSerializableDataType consults this flag when deciding per-column support.
// NOTE(review): the "better performance" rationale comes from the .doc text below —
// confirm against current benchmarks before changing the default.
val COMET_COLUMNAR_SHUFFLE_COMPLEX_TYPES_ENABLED: ConfigEntry[Boolean] =
  conf("spark.comet.columnar.shuffle.complexTypes.enabled")
    .category(CATEGORY_SHUFFLE)
    .doc(
      "Whether to enable Comet columnar shuffle for complex types (struct, array, map). " +
      "When disabled (default), queries with complex types will fall back to Spark shuffle " +
      "for better performance. Enable this only if you need columnar shuffle features for " +
      "complex types and accept potential performance tradeoffs.")
    .booleanConf
    .createWithDefault(false)

val COMET_COLUMNAR_SHUFFLE_ASYNC_ENABLED: ConfigEntry[Boolean] =
conf("spark.comet.columnar.shuffle.async.enabled")
.category(CATEGORY_SHUFFLE)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ import org.apache.spark.util.random.XORShiftRandom
import com.google.common.base.Objects

import org.apache.comet.CometConf
import org.apache.comet.CometConf.{COMET_EXEC_SHUFFLE_ENABLED, COMET_SHUFFLE_MODE}
import org.apache.comet.CometConf.{COMET_COLUMNAR_SHUFFLE_COMPLEX_TYPES_ENABLED, COMET_EXEC_SHUFFLE_ENABLED, COMET_SHUFFLE_MODE}
import org.apache.comet.CometSparkSessionExtensions.{isCometShuffleManagerEnabled, withInfo}
import org.apache.comet.serde.{Compatible, OperatorOuterClass, QueryPlanSerde, SupportLevel, Unsupported}
import org.apache.comet.serde.operator.CometSink
Expand Down Expand Up @@ -403,23 +403,39 @@ object CometShuffleExchangeExec
/**
 * Checks whether a shuffle input column of the given data type can be serialized by Comet
 * columnar shuffle.
 *
 * Comet columnar shuffle uses native code to convert Spark unsafe rows to Arrow batches,
 * see shuffle/row.rs.
 *
 * Complex types (struct, array, map) are only supported when
 * `spark.comet.columnar.shuffle.complexTypes.enabled` is set; by default they are rejected
 * so the query falls back to Spark shuffle.
 *
 * NOTE(review): `s` is captured from the enclosing method's scope (the shuffle plan node
 * whose SQLConf is consulted) — confirm against the enclosing definition.
 *
 * @param dt the data type to check
 * @return None if supported, or Some(reason) describing why it is not supported
 */
def supportedSerializableDataType(dt: DataType): Option[String] = dt match {
  case _: BooleanType | _: ByteType | _: ShortType | _: IntegerType | _: LongType |
      _: FloatType | _: DoubleType | _: StringType | _: BinaryType | _: TimestampType |
      _: TimestampNTZType | _: DecimalType | _: DateType =>
    None
  case StructType(fields) =>
    if (!COMET_COLUMNAR_SHUFFLE_COMPLEX_TYPES_ENABLED.get(s.conf)) {
      Some(s"${COMET_COLUMNAR_SHUFFLE_COMPLEX_TYPES_ENABLED.key} is not enabled")
    } else if (fields.isEmpty) {
      Some("struct type with no fields is not supported")
    } else if (fields.map(f => f.name).distinct.length != fields.length) {
      // Java Arrow stream reader cannot work on duplicate field name
      Some("struct type with duplicate field names is not supported")
    } else {
      // Recurse into the fields; the first unsupported child (if any) is the reason.
      fields.flatMap(f => supportedSerializableDataType(f.dataType)).headOption
    }
  case ArrayType(elementType, _) =>
    if (!COMET_COLUMNAR_SHUFFLE_COMPLEX_TYPES_ENABLED.get(s.conf)) {
      Some(s"${COMET_COLUMNAR_SHUFFLE_COMPLEX_TYPES_ENABLED.key} is not enabled")
    } else {
      supportedSerializableDataType(elementType)
    }
  case MapType(keyType, valueType, _) =>
    if (!COMET_COLUMNAR_SHUFFLE_COMPLEX_TYPES_ENABLED.get(s.conf)) {
      Some(s"${COMET_COLUMNAR_SHUFFLE_COMPLEX_TYPES_ENABLED.key} is not enabled")
    } else {
      // Key is checked first; fall through to the value only if the key is supported.
      supportedSerializableDataType(keyType).orElse(supportedSerializableDataType(valueType))
    }
  case _ =>
    Some(s"unsupported data type: $dt")
}

if (!isCometShuffleEnabledWithInfo(s)) {
Expand All @@ -444,9 +460,13 @@ object CometShuffleExchangeExec
val inputs = s.child.output

// Reject the whole shuffle if any input column's type is unsupported. The reason is
// recorded on the plan via withInfo so users can see why Comet fell back to Spark.
for (input <- inputs) {
  supportedSerializableDataType(input.dataType) match {
    case Some(reason) =>
      withInfo(
        s,
        s"unsupported data type ${input.dataType} for column ${input.name}: $reason")
      return false
    case None => // supported
  }
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -112,12 +112,42 @@ abstract class CometColumnarShuffleSuite extends CometTestBase with AdaptiveSpar
checkSparkAnswer(df)
}

test("Fallback to Spark for complex types when config is disabled (default)") {
  // https://github.com/apache/datafusion-comet/issues/2904
  // With the flag off (the default), shuffles over complex-typed columns must use
  // Spark's shuffle rather than Comet's columnar shuffle.
  withSQLConf(CometConf.COMET_COLUMNAR_SHUFFLE_COMPLEX_TYPES_ENABLED.key -> "false") {
    // Shared check: zero Comet exchanges in the plan, and results match Spark.
    def assertSparkShuffle(df: org.apache.spark.sql.DataFrame): Unit = {
      checkCometExchange(df, 0, false)
      checkSparkAnswer(df)
    }

    // struct
    withParquetTable(Seq((1, (0, "1")), (2, (3, "3"))), "tbl") {
      assertSparkShuffle(sql("SELECT * FROM tbl").repartition(10, $"_1", $"_2"))
    }

    // array
    withParquetTable((0 until 10).map(i => (Seq(i, i + 1), i + 1)), "tbl2") {
      assertSparkShuffle(sql("SELECT * FROM tbl2").repartition(10, $"_1", $"_2"))
    }

    // map
    withParquetTable((0 until 10).map(i => (Map(i -> i.toString), i + 1)), "tbl3") {
      assertSparkShuffle(sql("SELECT * FROM tbl3").repartition(10, $"_1", $"_2"))
    }
  }
}

test("columnar shuffle on nested struct including nulls") {
// https://github.com/apache/datafusion-comet/issues/1538
assume(CometConf.COMET_NATIVE_SCAN_IMPL.get() != CometConf.SCAN_NATIVE_DATAFUSION)
Seq(10, 201).foreach { numPartitions =>
Seq("1.0", "10.0").foreach { ratio =>
withSQLConf(CometConf.COMET_SHUFFLE_PREFER_DICTIONARY_RATIO.key -> ratio) {
withSQLConf(
CometConf.COMET_COLUMNAR_SHUFFLE_COMPLEX_TYPES_ENABLED.key -> "true",
CometConf.COMET_SHUFFLE_PREFER_DICTIONARY_RATIO.key -> ratio) {
withParquetTable(
(0 until 50).map(i =>
(i, Seq((i + 1, i.toString), null, (i + 3, (i + 3).toString)), i + 1)),
Expand All @@ -137,7 +167,9 @@ abstract class CometColumnarShuffleSuite extends CometTestBase with AdaptiveSpar
test("columnar shuffle on struct including nulls") {
Seq(10, 201).foreach { numPartitions =>
Seq("1.0", "10.0").foreach { ratio =>
withSQLConf(CometConf.COMET_SHUFFLE_PREFER_DICTIONARY_RATIO.key -> ratio) {
withSQLConf(
CometConf.COMET_COLUMNAR_SHUFFLE_COMPLEX_TYPES_ENABLED.key -> "true",
CometConf.COMET_SHUFFLE_PREFER_DICTIONARY_RATIO.key -> ratio) {
val data: Seq[(Int, (Int, String))] =
Seq((1, (0, "1")), (2, (3, "3")), (3, null))
withParquetTable(data, "tbl") {
Expand All @@ -158,6 +190,7 @@ abstract class CometColumnarShuffleSuite extends CometTestBase with AdaptiveSpar
Seq(10, 201).foreach { numPartitions =>
Seq("1.0", "10.0").foreach { ratio =>
withSQLConf(
CometConf.COMET_COLUMNAR_SHUFFLE_COMPLEX_TYPES_ENABLED.key -> "true",
CometConf.COMET_EXEC_ENABLED.key -> execEnabled,
CometConf.COMET_SHUFFLE_PREFER_DICTIONARY_RATIO.key -> ratio) {
withParquetTable((0 until 50).map(i => (Map(Seq(i, i + 1) -> i), i + 1)), "tbl") {
Expand Down Expand Up @@ -230,6 +263,7 @@ abstract class CometColumnarShuffleSuite extends CometTestBase with AdaptiveSpar
Seq(10, 201).foreach { numPartitions =>
Seq("1.0", "10.0").foreach { ratio =>
withSQLConf(
CometConf.COMET_COLUMNAR_SHUFFLE_COMPLEX_TYPES_ENABLED.key -> "true",
CometConf.COMET_EXEC_ENABLED.key -> execEnabled,
CometConf.COMET_SHUFFLE_PREFER_DICTIONARY_RATIO.key -> ratio) {
withParquetTable(
Expand Down Expand Up @@ -336,7 +370,9 @@ abstract class CometColumnarShuffleSuite extends CometTestBase with AdaptiveSpar
def columnarShuffleOnMapTest[K: TypeTag](num: Int, keys: Seq[K]): Unit = {
Seq(10, 201).foreach { numPartitions =>
Seq("1.0", "10.0").foreach { ratio =>
withSQLConf(CometConf.COMET_SHUFFLE_PREFER_DICTIONARY_RATIO.key -> ratio) {
withSQLConf(
CometConf.COMET_COLUMNAR_SHUFFLE_COMPLEX_TYPES_ENABLED.key -> "true",
CometConf.COMET_SHUFFLE_PREFER_DICTIONARY_RATIO.key -> ratio) {
withParquetTable(genTuples(num, keys), "tbl") {
repartitionAndSort(numPartitions)
}
Expand Down Expand Up @@ -451,7 +487,9 @@ abstract class CometColumnarShuffleSuite extends CometTestBase with AdaptiveSpar

Seq(10, 201).foreach { numPartitions =>
Seq("1.0", "10.0").foreach { ratio =>
withSQLConf(CometConf.COMET_SHUFFLE_PREFER_DICTIONARY_RATIO.key -> ratio) {
withSQLConf(
CometConf.COMET_COLUMNAR_SHUFFLE_COMPLEX_TYPES_ENABLED.key -> "true",
CometConf.COMET_SHUFFLE_PREFER_DICTIONARY_RATIO.key -> ratio) {
withParquetTable(
(0 until 50).map(i =>
(
Expand Down Expand Up @@ -483,7 +521,9 @@ abstract class CometColumnarShuffleSuite extends CometTestBase with AdaptiveSpar
Seq("false", "true").foreach { _ =>
Seq(10, 201).foreach { numPartitions =>
Seq("1.0", "10.0").foreach { ratio =>
withSQLConf(CometConf.COMET_SHUFFLE_PREFER_DICTIONARY_RATIO.key -> ratio) {
withSQLConf(
CometConf.COMET_COLUMNAR_SHUFFLE_COMPLEX_TYPES_ENABLED.key -> "true",
CometConf.COMET_SHUFFLE_PREFER_DICTIONARY_RATIO.key -> ratio) {
withParquetTable(
(0 until 50).map(i => (Seq(Seq(i + 1), Seq(i + 2), Seq(i + 3)), i + 1)),
"tbl") {
Expand All @@ -503,7 +543,9 @@ abstract class CometColumnarShuffleSuite extends CometTestBase with AdaptiveSpar
test("columnar shuffle on nested struct") {
Seq(10, 201).foreach { numPartitions =>
Seq("1.0", "10.0").foreach { ratio =>
withSQLConf(CometConf.COMET_SHUFFLE_PREFER_DICTIONARY_RATIO.key -> ratio) {
withSQLConf(
CometConf.COMET_COLUMNAR_SHUFFLE_COMPLEX_TYPES_ENABLED.key -> "true",
CometConf.COMET_SHUFFLE_PREFER_DICTIONARY_RATIO.key -> ratio) {
withParquetTable(
(0 until 50).map(i =>
((i, 2.toString, (i + 1).toLong, (3.toString, i + 1, (i + 2).toLong)), i + 1)),
Expand Down Expand Up @@ -871,29 +913,31 @@ abstract class CometColumnarShuffleSuite extends CometTestBase with AdaptiveSpar
}

test("columnar shuffle on null struct fields") {
  // Complex-type columnar shuffle is opt-in, so enable it for this test; otherwise
  // the struct column would fall back to Spark shuffle and not exercise row.rs.
  withSQLConf(CometConf.COMET_COLUMNAR_SHUFFLE_COMPLEX_TYPES_ENABLED.key -> "true") {
    withTempDir { dir =>
      // A single empty JSON object: every field in the read schema resolves to null.
      val testData = "{}\n"
      val path = Paths.get(dir.toString, "test.json")
      Files.write(path, testData.getBytes)

      // Define the nested struct schema
      val readSchema = StructType(
        Array(
          StructField(
            "metaData",
            StructType(
              Array(StructField(
                "format",
                StructType(Array(StructField("provider", StringType, nullable = true))),
                nullable = true))),
            nullable = true)))

      // Read JSON with custom schema and repartition, this will repartition rows that contain
      // null struct fields.
      val df = spark.read.format("json").schema(readSchema).load(path.toString).repartition(2)
      assert(df.count() == 1)
      val row = df.collect()(0)
      assert(row.getAs[org.apache.spark.sql.Row]("metaData") == null)
    }
  }
}

Expand Down
Loading