diff --git a/native/proto/src/proto/operator.proto b/native/proto/src/proto/operator.proto
index 015b5d96b6..1511a107c8 100644
--- a/native/proto/src/proto/operator.proto
+++ b/native/proto/src/proto/operator.proto
@@ -245,6 +245,8 @@ message ParquetWriter {
   optional string job_id = 6;
   // Task attempt ID for this specific task
   optional int32 task_attempt_id = 7;
+  // Partition column names for partitioned writes
+  repeated string partition_columns = 8;
 }
 
 enum AggregateMode {
diff --git a/spark/src/main/scala/org/apache/comet/serde/operator/CometDataWritingCommand.scala b/spark/src/main/scala/org/apache/comet/serde/operator/CometDataWritingCommand.scala
index 8349329841..6c1d9ce5b9 100644
--- a/spark/src/main/scala/org/apache/comet/serde/operator/CometDataWritingCommand.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/operator/CometDataWritingCommand.scala
@@ -29,6 +29,7 @@ import org.apache.spark.sql.execution.command.DataWritingCommandExec
 import org.apache.spark.sql.execution.datasources.{InsertIntoHadoopFsRelationCommand, WriteFilesExec}
 import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat
 import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.internal.SQLConf.PartitionOverwriteMode
 
 import org.apache.comet.{CometConf, ConfigEntry, DataTypeSupport}
 import org.apache.comet.CometSparkSessionExtensions.withInfo
@@ -62,7 +63,7 @@ object CometDataWritingCommand extends CometOperatorSerde[DataWritingCommandExec
     }
 
     if (cmd.partitionColumns.nonEmpty || cmd.staticPartitions.nonEmpty) {
-      return Unsupported(Some("Partitioned writes are not supported"))
+      return Incompatible(Some("Partitioned writes are not supported"))
     }
 
     if (cmd.query.output.exists(attr => DataTypeSupport.isComplexType(attr.dataType))) {
@@ -167,6 +168,9 @@ object CometDataWritingCommand extends CometOperatorSerde[DataWritingCommandExec
         other
     }
 
+    val isDynamicOverWriteMode = cmd.partitionColumns.nonEmpty &&
+      SQLConf.get.partitionOverwriteMode == PartitionOverwriteMode.DYNAMIC
+
     // Create FileCommitProtocol for atomic writes
     val jobId = java.util.UUID.randomUUID().toString
     val committer =
@@ -178,11 +182,7 @@ object CometDataWritingCommand extends CometOperatorSerde[DataWritingCommandExec
            committerClass.getConstructor(classOf[String], classOf[String], classOf[Boolean])
          Some(
            constructor
-             .newInstance(
-               jobId,
-               outputPath,
-               java.lang.Boolean.FALSE // dynamicPartitionOverwrite = false for now
-             )
+             .newInstance(jobId, outputPath, isDynamicOverWriteMode)
              .asInstanceOf[org.apache.spark.internal.io.FileCommitProtocol])
        } catch {
          case e: Exception =>
diff --git a/spark/src/test/scala/org/apache/comet/parquet/CometParquetWriterSuite.scala b/spark/src/test/scala/org/apache/comet/parquet/CometParquetWriterSuite.scala
index 3ae7f949ab..f1f5276aef 100644
--- a/spark/src/test/scala/org/apache/comet/parquet/CometParquetWriterSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/parquet/CometParquetWriterSuite.scala
@@ -228,4 +228,32 @@ class CometParquetWriterSuite extends CometTestBase {
       }
     }
   }
+
+  test("parquet write with mode overwrite") {
+    withTempPath { dir =>
+      val outputPath = new File(dir, "output.parquet").getAbsolutePath
+
+      withTempPath { inputDir =>
+        val inputPath = createTestData(inputDir)
+
+        withSQLConf(
+          CometConf.COMET_NATIVE_PARQUET_WRITE_ENABLED.key -> "true",
+          SQLConf.SESSION_LOCAL_TIMEZONE.key -> "America/Halifax",
+          CometConf.getOperatorAllowIncompatConfigKey(classOf[DataWritingCommandExec]) -> "true",
+          CometConf.COMET_EXEC_ENABLED.key -> "true") {
+
+          val df = spark.read.parquet(inputPath)
+
+          // First write
+          df.repartition(2).write.parquet(outputPath)
+          // verifyWrittenFile(outputPath)
+          // Second write, in overwrite mode, with fewer rows to distinguish old from new data
+          df.limit(500).repartition(2).write.mode("overwrite").parquet(outputPath)
+          // Verify the overwritten data
+          val resultDf = spark.read.parquet(outputPath)
+          assert(resultDf.count() == 500, "Expected 500 rows after overwrite")
+        }
+      }
+    }
+  }
 }