diff --git a/datafusion/spark/src/function/aggregate/mod.rs b/datafusion/spark/src/function/aggregate/mod.rs index 3db72669d42bd..12f1478b01461 100644 --- a/datafusion/spark/src/function/aggregate/mod.rs +++ b/datafusion/spark/src/function/aggregate/mod.rs @@ -20,6 +20,8 @@ use std::sync::Arc; pub mod avg; pub mod try_sum; +mod sum; +mod sum_decimal; pub mod expr_fn { use datafusion_functions::export_functions; diff --git a/datafusion/spark/src/function/aggregate/sum.rs b/datafusion/spark/src/function/aggregate/sum.rs new file mode 100644 index 0000000000000..06a70134fcf0a --- /dev/null +++ b/datafusion/spark/src/function/aggregate/sum.rs @@ -0,0 +1,585 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::array::{ + as_primitive_array, cast::AsArray, Array, ArrayRef, ArrowNativeTypeOp, ArrowPrimitiveType, + BooleanArray, Int64Array, PrimitiveArray, +}; +use arrow::datatypes::{ + ArrowNativeType, DataType, Field, FieldRef, Int16Type, Int32Type, Int64Type, Int8Type, +}; +use std::{any::Any, sync::Arc}; +use datafusion_common::{DataFusionError, ScalarValue}; +use datafusion_expr::{Accumulator, AggregateUDFImpl, GroupsAccumulator, ReversedUDAF, Signature}; +use datafusion_expr::function::AccumulatorArgs; + +#[derive(Debug, PartialEq, Eq, Hash)] +pub struct SumInteger { + signature: Signature, + eval_mode: EvalMode, +} + +impl SumInteger { + pub fn try_new(data_type: DataType, eval_mode: EvalMode) -> DFResult { + match data_type { + DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64 => Ok(Self { + signature: Signature::user_defined(Immutable), + eval_mode, + }), + _ => Err(DataFusionError::Internal( + "Invalid data type for SumInteger".into(), + )), + } + } +} + +impl AggregateUDFImpl for SumInteger { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "sum" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> DFResult { + Ok(DataType::Int64) + } + + fn accumulator(&self, _acc_args: AccumulatorArgs) -> DFResult> { + Ok(Box::new(SumIntegerAccumulator::new(self.eval_mode))) + } + + fn state_fields(&self, _args: StateFieldsArgs) -> DFResult> { + if self.eval_mode == EvalMode::Try { + Ok(vec![ + Arc::new(Field::new("sum", DataType::Int64, true)), + Arc::new(Field::new("has_all_nulls", DataType::Boolean, false)), + ]) + } else { + Ok(vec![Arc::new(Field::new("sum", DataType::Int64, true))]) + } + } + + fn groups_accumulator_supported(&self, _args: AccumulatorArgs) -> bool { + true + } + + fn create_groups_accumulator( + &self, + _args: AccumulatorArgs, + ) -> DFResult> { + Ok(Box::new(SumIntGroupsAccumulator::new(self.eval_mode))) + } + + fn reverse_expr(&self) -> ReversedUDAF { + ReversedUDAF::Identical + } +} + +#[derive(Debug)] +struct SumIntegerAccumulator { + sum: Option, + eval_mode: EvalMode, + has_all_nulls: bool, +} + +impl SumIntegerAccumulator { + fn new(eval_mode: EvalMode) -> Self { + if eval_mode == EvalMode::Try { + Self { + // Try mode starts with 0 (because if this is init to None we cant say if it is none due to all nulls or due to an overflow) + sum: Some(0), + has_all_nulls: true, + eval_mode, + } + } else { + Self { + sum: None, + has_all_nulls: false, + eval_mode, + } + } + } +} + +impl Accumulator for SumIntegerAccumulator { + fn update_batch(&mut self, values: &[ArrayRef]) -> DFResult<()> { + // accumulator internal to add sum and return null sum (and has_nulls false) if there is an overflow in Try Eval mode + fn update_sum_internal( + int_array: &PrimitiveArray, + eval_mode: EvalMode, + mut sum: i64, + ) -> Result, DataFusionError> + where + T: ArrowPrimitiveType, + { + for i in 0..int_array.len() { + if !int_array.is_null(i) { + let v = int_array.value(i).to_i64().ok_or_else(|| { + DataFusionError::Internal(format!( + "Failed to convert value {:?} to i64", + int_array.value(i) + )) + })?; + match eval_mode { + EvalMode::Legacy => { + sum = v.add_wrapping(sum); + } + EvalMode::Ansi | EvalMode::Try => { + match v.add_checked(sum) { + Ok(v) => sum = v, + Err(_e) => { + return if eval_mode == EvalMode::Ansi { + Err(DataFusionError::from(arithmetic_overflow_error( + "integer", + ))) + } else { + Ok(None) + }; + } + }; + } + } + } + } + Ok(Some(sum)) + } + + if self.eval_mode == EvalMode::Try && !self.has_all_nulls && self.sum.is_none() { + // we saw an overflow earlier (Try eval mode). Skip processing + return Ok(()); + } + let values = &values[0]; + if values.len() == values.null_count() { + Ok(()) + } else { + // No nulls so there should be a non-null sum / null incase overflow in Try eval + let running_sum = self.sum.unwrap_or(0); + let sum = match values.data_type() { + DataType::Int64 => update_sum_internal( + as_primitive_array::(values), + self.eval_mode, + running_sum, + )?, + DataType::Int32 => update_sum_internal( + as_primitive_array::(values), + self.eval_mode, + running_sum, + )?, + DataType::Int16 => update_sum_internal( + as_primitive_array::(values), + self.eval_mode, + running_sum, + )?, + DataType::Int8 => update_sum_internal( + as_primitive_array::(values), + self.eval_mode, + running_sum, + )?, + _ => { + return Err(DataFusionError::Internal(format!( + "unsupported data type: {:?}", + values.data_type() + ))); + } + }; + self.sum = sum; + self.has_all_nulls = false; + Ok(()) + } + } + + fn evaluate(&mut self) -> DFResult { + if self.has_all_nulls { + Ok(ScalarValue::Int64(None)) + } else { + Ok(ScalarValue::Int64(self.sum)) + } + } + + fn size(&self) -> usize { + size_of_val(self) + } + + fn state(&mut self) -> DFResult> { + if self.eval_mode == EvalMode::Try { + Ok(vec![ + ScalarValue::Int64(self.sum), + ScalarValue::Boolean(Some(self.has_all_nulls)), + ]) + } else { + Ok(vec![ScalarValue::Int64(self.sum)]) + } + } + + fn merge_batch(&mut self, states: &[ArrayRef]) -> DFResult<()> { + let expected_state_len = if self.eval_mode == EvalMode::Try { + 2 + } else { + 1 + }; + if expected_state_len != states.len() { + return Err(DataFusionError::Internal(format!( + "Invalid state while merging batch. Expected {} elements but found {}", + expected_state_len, + states.len() + ))); + } + + let that_sum_array = states[0].as_primitive::(); + let that_sum = if that_sum_array.is_null(0) { + None + } else { + Some(that_sum_array.value(0)) + }; + + // Check for overflow for early termination + if self.eval_mode == EvalMode::Try { + let that_has_all_nulls = states[1].as_boolean().value(0); + let that_overflowed = !that_has_all_nulls && that_sum.is_none(); + let this_overflowed = !self.has_all_nulls && self.sum.is_none(); + if that_overflowed || this_overflowed { + self.sum = None; + self.has_all_nulls = false; + return Ok(()); + } + if that_has_all_nulls { + return Ok(()); + } + if self.has_all_nulls { + self.sum = that_sum; + self.has_all_nulls = false; + return Ok(()); + } + } else { + if that_sum.is_none() { + return Ok(()); + } + if self.sum.is_none() { + self.sum = that_sum; + return Ok(()); + } + } + + // safe to unwrap (since we checked nulls above) but handling error just in case state is corrupt + let left = self.sum.ok_or_else(|| { + DataFusionError::Internal( + "Invalid state in merging batch. Current batch's sum is None".to_string(), + ) + })?; + let right = that_sum.ok_or_else(|| { + DataFusionError::Internal( + "Invalid state in merging batch. Incoming sum is None".to_string(), + ) + })?; + + match self.eval_mode { + EvalMode::Legacy => { + self.sum = Some(left.add_wrapping(right)); + } + EvalMode::Ansi | EvalMode::Try => match left.add_checked(right) { + Ok(v) => self.sum = Some(v), + Err(_) => { + if self.eval_mode == EvalMode::Ansi { + return Err(DataFusionError::from(arithmetic_overflow_error("integer"))); + } else { + self.sum = None; + self.has_all_nulls = false; + } + } + }, + } + Ok(()) + } +} + +struct SumIntGroupsAccumulator { + sums: Vec>, + has_all_nulls: Vec, + eval_mode: EvalMode, +} + +impl SumIntGroupsAccumulator { + fn new(eval_mode: EvalMode) -> Self { + Self { + sums: Vec::new(), + eval_mode, + has_all_nulls: Vec::new(), + } + } + + fn resize_helper(&mut self, total_num_groups: usize) { + if self.eval_mode == EvalMode::Try { + self.sums.resize(total_num_groups, Some(0)); + self.has_all_nulls.resize(total_num_groups, true); + } else { + self.sums.resize(total_num_groups, None); + self.has_all_nulls.resize(total_num_groups, false); + } + } +} + +impl GroupsAccumulator for SumIntGroupsAccumulator { + fn update_batch( + &mut self, + values: &[ArrayRef], + group_indices: &[usize], + opt_filter: Option<&BooleanArray>, + total_num_groups: usize, + ) -> DFResult<()> { + fn update_groups_sum_internal( + int_array: &PrimitiveArray, + group_indices: &[usize], + sums: &mut [Option], + has_all_nulls: &mut [bool], + eval_mode: EvalMode, + ) -> DFResult<()> + where + T: ArrowPrimitiveType, + T::Native: ArrowNativeType, + { + for (i, &group_index) in group_indices.iter().enumerate() { + if !int_array.is_null(i) { + // there is an overflow in prev group in try eval. Skip processing + if eval_mode == EvalMode::Try + && !has_all_nulls[group_index] + && sums[group_index].is_none() + { + continue; + } + let v = int_array.value(i).to_i64().ok_or_else(|| { + DataFusionError::Internal("Failed to convert value to i64".to_string()) + })?; + match eval_mode { + EvalMode::Legacy => { + sums[group_index] = + Some(sums[group_index].unwrap_or(0).add_wrapping(v)); + } + EvalMode::Ansi | EvalMode::Try => { + match sums[group_index].unwrap_or(0).add_checked(v) { + Ok(new_sum) => { + sums[group_index] = Some(new_sum); + } + Err(_) => { + if eval_mode == EvalMode::Ansi { + return Err(DataFusionError::from( + arithmetic_overflow_error("integer"), + )); + } else { + sums[group_index] = None; + } + } + }; + } + } + has_all_nulls[group_index] = false + } + } + Ok(()) + } + + debug_assert!(opt_filter.is_none(), "opt_filter is not supported yet"); + let values = &values[0]; + self.resize_helper(total_num_groups); + + match values.data_type() { + DataType::Int64 => update_groups_sum_internal( + as_primitive_array::(values), + group_indices, + &mut self.sums, + &mut self.has_all_nulls, + self.eval_mode, + )?, + DataType::Int32 => update_groups_sum_internal( + as_primitive_array::(values), + group_indices, + &mut self.sums, + &mut self.has_all_nulls, + self.eval_mode, + )?, + DataType::Int16 => update_groups_sum_internal( + as_primitive_array::(values), + group_indices, + &mut self.sums, + &mut self.has_all_nulls, + self.eval_mode, + )?, + DataType::Int8 => update_groups_sum_internal( + as_primitive_array::(values), + group_indices, + &mut self.sums, + &mut self.has_all_nulls, + self.eval_mode, + )?, + _ => { + return Err(DataFusionError::Internal(format!( + "Unsupported data type for SumIntGroupsAccumulator: {:?}", + values.data_type() + ))) + } + }; + Ok(()) + } + + fn evaluate(&mut self, emit_to: EmitTo) -> DFResult { + match emit_to { + EmitTo::All => { + let result = Arc::new(Int64Array::from_iter( + self.sums + .iter() + .zip(self.has_all_nulls.iter()) + .map(|(&sum, &is_null)| if is_null { None } else { sum }), + )) as ArrayRef; + + self.sums.clear(); + self.has_all_nulls.clear(); + Ok(result) + } + EmitTo::First(n) => { + let result = Arc::new(Int64Array::from_iter( + self.sums + .drain(..n) + .zip(self.has_all_nulls.drain(..n)) + .map(|(sum, is_null)| if is_null { None } else { sum }), + )) as ArrayRef; + Ok(result) + } + } + } + + fn state(&mut self, emit_to: EmitTo) -> DFResult> { + let sums = emit_to.take_needed(&mut self.sums); + + if self.eval_mode == EvalMode::Try { + let has_all_nulls = emit_to.take_needed(&mut self.has_all_nulls); + Ok(vec![ + Arc::new(Int64Array::from(sums)), + Arc::new(BooleanArray::from(has_all_nulls)), + ]) + } else { + Ok(vec![Arc::new(Int64Array::from(sums))]) + } + } + + fn merge_batch( + &mut self, + values: &[ArrayRef], + group_indices: &[usize], + opt_filter: Option<&BooleanArray>, + total_num_groups: usize, + ) -> DFResult<()> { + debug_assert!(opt_filter.is_none(), "opt_filter is not supported yet"); + + let expected_state_len = if self.eval_mode == EvalMode::Try { + 2 + } else { + 1 + }; + if expected_state_len != values.len() { + return Err(DataFusionError::Internal(format!( + "Invalid state while merging batch. Expected {} elements but found {}", + expected_state_len, + values.len() + ))); + } + let that_sums = values[0].as_primitive::(); + + self.resize_helper(total_num_groups); + + let that_sums_is_all_nulls = if self.eval_mode == EvalMode::Try { + Some(values[1].as_boolean()) + } else { + None + }; + + for (idx, &group_index) in group_indices.iter().enumerate() { + let that_sum = if that_sums.is_null(idx) { + None + } else { + Some(that_sums.value(idx)) + }; + + if self.eval_mode == EvalMode::Try { + let that_has_all_nulls = that_sums_is_all_nulls.unwrap().value(idx); + + let that_overflowed = !that_has_all_nulls && that_sum.is_none(); + let this_overflowed = + !self.has_all_nulls[group_index] && self.sums[group_index].is_none(); + + if that_overflowed || this_overflowed { + self.sums[group_index] = None; + self.has_all_nulls[group_index] = false; + continue; + } + + if that_has_all_nulls { + continue; + } + + if self.has_all_nulls[group_index] { + self.sums[group_index] = that_sum; + self.has_all_nulls[group_index] = false; + continue; + } + } else { + if that_sum.is_none() { + continue; + } + if self.sums[group_index].is_none() { + self.sums[group_index] = that_sum; + continue; + } + } + + // Both sides have non-null. Update sums now + let left = self.sums[group_index].unwrap(); + let right = that_sum.unwrap(); + + match self.eval_mode { + EvalMode::Legacy => { + self.sums[group_index] = Some(left.add_wrapping(right)); + } + EvalMode::Ansi | EvalMode::Try => { + match left.add_checked(right) { + Ok(v) => self.sums[group_index] = Some(v), + Err(_) => { + if self.eval_mode == EvalMode::Ansi { + return Err(DataFusionError::from(arithmetic_overflow_error( + "integer", + ))); + } else { + // overflow. update flag accordingly + self.sums[group_index] = None; + self.has_all_nulls[group_index] = false; + } + } + } + } + } + } + Ok(()) + } + + fn size(&self) -> usize { + size_of_val(self) + } +} diff --git a/datafusion/spark/src/function/aggregate/sum_decimal.rs b/datafusion/spark/src/function/aggregate/sum_decimal.rs new file mode 100644 index 0000000000000..2fc1585fd77f9 --- /dev/null +++ b/datafusion/spark/src/function/aggregate/sum_decimal.rs @@ -0,0 +1,621 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::utils::is_valid_decimal_precision; +use crate::{arithmetic_overflow_error, EvalMode}; +use arrow::array::{ + cast::AsArray, types::Decimal128Type, Array, ArrayRef, BooleanArray, Decimal128Array, +}; +use arrow::datatypes::{DataType, Field, FieldRef}; +use datafusion::common::{DataFusionError, Result as DFResult, ScalarValue}; +use datafusion::logical_expr::function::{AccumulatorArgs, StateFieldsArgs}; +use datafusion::logical_expr::Volatility::Immutable; +use datafusion::logical_expr::{ + Accumulator, AggregateUDFImpl, EmitTo, GroupsAccumulator, ReversedUDAF, Signature, +}; +use std::{any::Any, sync::Arc}; + +#[derive(Debug, PartialEq, Eq, Hash)] +pub struct SumDecimal { + /// Aggregate function signature + signature: Signature, + /// The data type of the SUM result. This will always be a decimal type + /// with the same precision and scale as specified in this struct + result_type: DataType, + /// Decimal precision + precision: u8, + /// Decimal scale + scale: i8, + eval_mode: EvalMode, +} + +impl SumDecimal { + pub fn try_new(data_type: DataType, eval_mode: EvalMode) -> DFResult { + let (precision, scale) = match data_type { + DataType::Decimal128(p, s) => (p, s), + _ => { + return Err(DataFusionError::Internal( + "Invalid data type for SumDecimal".into(), + )) + } + }; + Ok(Self { + signature: Signature::user_defined(Immutable), + result_type: data_type, + precision, + scale, + eval_mode, + }) + } +} + +impl AggregateUDFImpl for SumDecimal { + fn as_any(&self) -> &dyn Any { + self + } + + fn accumulator(&self, _args: AccumulatorArgs) -> DFResult> { + Ok(Box::new(SumDecimalAccumulator::new( + self.precision, + self.scale, + self.eval_mode, + ))) + } + + fn state_fields(&self, _args: StateFieldsArgs) -> DFResult> { + // For decimal sum, we always track is_empty regardless of eval_mode + // This matches Spark's behavior where DecimalType always uses shouldTrackIsEmpty = true + let data_type = self.result_type.clone(); + Ok(vec![ + Arc::new(Field::new("sum", data_type, true)), + Arc::new(Field::new("is_empty", DataType::Boolean, false)), + ]) + } + + fn name(&self) -> &str { + "sum" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> DFResult { + Ok(self.result_type.clone()) + } + + fn groups_accumulator_supported(&self, _args: AccumulatorArgs) -> bool { + true + } + + fn create_groups_accumulator( + &self, + _args: AccumulatorArgs, + ) -> DFResult> { + Ok(Box::new(SumDecimalGroupsAccumulator::new( + self.result_type.clone(), + self.precision, + self.eval_mode, + ))) + } + + fn default_value(&self, _data_type: &DataType) -> DFResult { + ScalarValue::new_primitive::( + None, + &DataType::Decimal128(self.precision, self.scale), + ) + } + + fn reverse_expr(&self) -> ReversedUDAF { + ReversedUDAF::Identical + } + + fn is_nullable(&self) -> bool { + // SumDecimal is always nullable because overflows can cause null values + true + } +} + +#[derive(Debug)] +struct SumDecimalAccumulator { + sum: Option, + is_empty: bool, + precision: u8, + scale: i8, + eval_mode: EvalMode, +} + +impl SumDecimalAccumulator { + fn new(precision: u8, scale: i8, eval_mode: EvalMode) -> Self { + // For decimal sum, always track is_empty regardless of eval_mode + // This matches Spark's behavior where DecimalType always uses shouldTrackIsEmpty = true + Self { + sum: Some(0), + is_empty: true, + precision, + scale, + eval_mode, + } + } + + fn update_single(&mut self, values: &Decimal128Array, idx: usize) -> DFResult<()> { + // If already overflowed (sum is None but not empty), stay in overflow state + if !self.is_empty && self.sum.is_none() { + return Ok(()); + } + + let v = unsafe { values.value_unchecked(idx) }; + let running_sum = self.sum.unwrap_or(0); + let (new_sum, is_overflow) = running_sum.overflowing_add(v); + + if is_overflow || !is_valid_decimal_precision(new_sum, self.precision) { + if self.eval_mode == EvalMode::Ansi { + return Err(DataFusionError::from(arithmetic_overflow_error("decimal"))); + } + self.sum = None; + self.is_empty = false; + return Ok(()); + } + + self.sum = Some(new_sum); + self.is_empty = false; + Ok(()) + } +} + +impl Accumulator for SumDecimalAccumulator { + fn update_batch(&mut self, values: &[ArrayRef]) -> DFResult<()> { + assert_eq!( + values.len(), + 1, + "Expect only one element in 'values' but found {}", + values.len() + ); + + // For decimal sum, always check for overflow regardless of eval_mode (per Spark's expectation) + if !self.is_empty && self.sum.is_none() { + return Ok(()); + } + + let values = &values[0]; + let data = values.as_primitive::(); + + // Update is_empty: it remains true only if it was true AND all values are null + self.is_empty = self.is_empty && values.len() == values.null_count(); + + if self.is_empty { + return Ok(()); + } + + for i in 0..data.len() { + if data.is_null(i) { + continue; + } + self.update_single(data, i)?; + } + Ok(()) + } + + fn evaluate(&mut self) -> DFResult { + if self.is_empty { + ScalarValue::new_primitive::( + None, + &DataType::Decimal128(self.precision, self.scale), + ) + } else { + match self.sum { + Some(sum_value) if is_valid_decimal_precision(sum_value, self.precision) => { + ScalarValue::try_new_decimal128(sum_value, self.precision, self.scale) + } + _ => ScalarValue::new_primitive::( + None, + &DataType::Decimal128(self.precision, self.scale), + ), + } + } + } + + fn size(&self) -> usize { + std::mem::size_of_val(self) + } + + fn state(&mut self) -> DFResult> { + let sum = match self.sum { + Some(sum_value) => { + ScalarValue::try_new_decimal128(sum_value, self.precision, self.scale)? + } + None => ScalarValue::new_primitive::( + None, + &DataType::Decimal128(self.precision, self.scale), + )?, + }; + + // For decimal sum, always return 2 state values regardless of eval_mode + Ok(vec![sum, ScalarValue::from(self.is_empty)]) + } + + fn merge_batch(&mut self, states: &[ArrayRef]) -> DFResult<()> { + // For decimal sum, always expect 2 state arrays regardless of eval_mode + assert_eq!( + states.len(), + 2, + "Expect two elements in 'states' but found {}", + states.len() + ); + assert_eq!(states[0].len(), 1); + assert_eq!(states[1].len(), 1); + + let that_sum_array = states[0].as_primitive::(); + let that_sum = if that_sum_array.is_null(0) { + None + } else { + Some(that_sum_array.value(0)) + }; + + let that_is_empty = states[1].as_boolean().value(0); + let that_overflowed = !that_is_empty && that_sum.is_none(); + let this_overflowed = !self.is_empty && self.sum.is_none(); + + if that_overflowed || this_overflowed { + self.sum = None; + self.is_empty = false; + return Ok(()); + } + + if that_is_empty { + return Ok(()); + } + + if self.is_empty { + self.sum = that_sum; + self.is_empty = false; + return Ok(()); + } + + let left = self.sum.unwrap(); + let right = that_sum.unwrap(); + let (new_sum, is_overflow) = left.overflowing_add(right); + + if is_overflow || !is_valid_decimal_precision(new_sum, self.precision) { + if self.eval_mode == EvalMode::Ansi { + return Err(DataFusionError::from(arithmetic_overflow_error("decimal"))); + } else { + self.sum = None; + self.is_empty = false; + } + } else { + self.sum = Some(new_sum); + } + + Ok(()) + } +} + +struct SumDecimalGroupsAccumulator { + sum: Vec>, + is_empty: Vec, + result_type: DataType, + precision: u8, + eval_mode: EvalMode, +} + +impl SumDecimalGroupsAccumulator { + fn new(result_type: DataType, precision: u8, eval_mode: EvalMode) -> Self { + Self { + sum: Vec::new(), + is_empty: Vec::new(), + result_type, + precision, + eval_mode, + } + } + + fn resize_helper(&mut self, total_num_groups: usize) { + // For decimal sum, always initialize properly regardless of eval_mode + self.sum.resize(total_num_groups, Some(0)); + self.is_empty.resize(total_num_groups, true); + } + + #[inline] + fn update_single(&mut self, group_index: usize, value: i128) -> DFResult<()> { + // For decimal sum, always check for overflow regardless of eval_mode + if !self.is_empty[group_index] && self.sum[group_index].is_none() { + return Ok(()); + } + + let running_sum = self.sum[group_index].unwrap_or(0); + let (new_sum, is_overflow) = running_sum.overflowing_add(value); + + if is_overflow || !is_valid_decimal_precision(new_sum, self.precision) { + if self.eval_mode == EvalMode::Ansi { + return Err(DataFusionError::from(arithmetic_overflow_error("decimal"))); + } + self.sum[group_index] = None; + } else { + self.sum[group_index] = Some(new_sum); + } + self.is_empty[group_index] = false; + Ok(()) + } +} + +impl GroupsAccumulator for SumDecimalGroupsAccumulator { + fn update_batch( + &mut self, + values: &[ArrayRef], + group_indices: &[usize], + opt_filter: Option<&BooleanArray>, + total_num_groups: usize, + ) -> DFResult<()> { + assert!(opt_filter.is_none(), "opt_filter is not supported yet"); + assert_eq!(values.len(), 1); + let values = values[0].as_primitive::(); + let data = values.values(); + + self.resize_helper(total_num_groups); + + let iter = group_indices.iter().zip(data.iter()); + if values.null_count() == 0 { + for (&group_index, &value) in iter { + self.update_single(group_index, value)?; + } + } else { + for (idx, (&group_index, &value)) in iter.enumerate() { + if values.is_null(idx) { + continue; + } + self.update_single(group_index, value)?; + } + } + + Ok(()) + } + + fn evaluate(&mut self, emit_to: EmitTo) -> DFResult { + match emit_to { + EmitTo::All => { + let result = + Decimal128Array::from_iter(self.sum.iter().zip(self.is_empty.iter()).map( + |(&sum, &empty)| { + if empty { + None + } else { + match sum { + Some(v) if is_valid_decimal_precision(v, self.precision) => { + Some(v) + } + _ => None, + } + } + }, + )) + .with_data_type(self.result_type.clone()); + + self.sum.clear(); + self.is_empty.clear(); + Ok(Arc::new(result)) + } + EmitTo::First(n) => { + let result = Decimal128Array::from_iter( + self.sum + .drain(..n) + .zip(self.is_empty.drain(..n)) + .map(|(sum, empty)| { + if empty { + None + } else { + match sum { + Some(v) if is_valid_decimal_precision(v, self.precision) => { + Some(v) + } + _ => None, + } + } + }), + ) + .with_data_type(self.result_type.clone()); + + Ok(Arc::new(result)) + } + } + } + + fn state(&mut self, emit_to: EmitTo) -> DFResult> { + let sums = emit_to.take_needed(&mut self.sum); + + let sum_array = Decimal128Array::from_iter(sums.iter().copied()) + .with_data_type(self.result_type.clone()); + + // For decimal sum, always return 2 state arrays regardless of eval_mode + let is_empty = emit_to.take_needed(&mut self.is_empty); + Ok(vec![ + Arc::new(sum_array), + Arc::new(BooleanArray::from(is_empty)), + ]) + } + + fn merge_batch( + &mut self, + values: &[ArrayRef], + group_indices: &[usize], + opt_filter: Option<&BooleanArray>, + total_num_groups: usize, + ) -> DFResult<()> { + assert!(opt_filter.is_none(), "opt_filter is not supported yet"); + + self.resize_helper(total_num_groups); + + // For decimal sum, always expect 2 arrays regardless of eval_mode + assert_eq!( + values.len(), + 2, + "Expected two arrays: 'sum' and 'is_empty', but found {}", + values.len() + ); + + let that_sum = values[0].as_primitive::(); + let that_is_empty = values[1].as_boolean(); + + for (idx, &group_index) in group_indices.iter().enumerate() { + let that_sum_val = if that_sum.is_null(idx) { + None + } else { + Some(that_sum.value(idx)) + }; + + let that_is_empty_val = that_is_empty.value(idx); + let that_overflowed = !that_is_empty_val && that_sum_val.is_none(); + let this_overflowed = !self.is_empty[group_index] && self.sum[group_index].is_none(); + + if that_overflowed || this_overflowed { + self.sum[group_index] = None; + self.is_empty[group_index] = false; + continue; + } + + if that_is_empty_val { + continue; + } + + if self.is_empty[group_index] { + self.sum[group_index] = that_sum_val; + self.is_empty[group_index] = false; + continue; + } + + let left = self.sum[group_index].unwrap(); + let right = that_sum_val.unwrap(); + let (new_sum, is_overflow) = left.overflowing_add(right); + + if is_overflow || !is_valid_decimal_precision(new_sum, self.precision) { + if self.eval_mode == EvalMode::Ansi { + return Err(DataFusionError::from(arithmetic_overflow_error("decimal"))); + } else { + self.sum[group_index] = None; + self.is_empty[group_index] = false; + } + } else { + self.sum[group_index] = Some(new_sum); + } + } + + Ok(()) + } + + fn size(&self) -> usize { + self.sum.capacity() * std::mem::size_of::>() + + self.is_empty.capacity() * std::mem::size_of::() + } +} + +#[cfg(test)] +mod tests { + use super::*; + use arrow::array::builder::{Decimal128Builder, StringBuilder}; + use arrow::array::RecordBatch; + use arrow::datatypes::*; + use datafusion::common::Result; + use datafusion::datasource::memory::MemorySourceConfig; + use datafusion::datasource::source::DataSourceExec; + use datafusion::execution::TaskContext; + use datafusion::logical_expr::AggregateUDF; + use datafusion::physical_expr::aggregate::AggregateExprBuilder; + use datafusion::physical_expr::expressions::Column; + use datafusion::physical_expr::PhysicalExpr; + use datafusion::physical_plan::aggregates::{AggregateExec, AggregateMode, PhysicalGroupBy}; + use datafusion::physical_plan::ExecutionPlan; + use futures::StreamExt; + + #[test] + fn invalid_data_type() { + assert!(SumDecimal::try_new(DataType::Int32, EvalMode::Legacy).is_err()); + } + + #[tokio::test] + async fn sum_no_overflow() -> Result<()> { + let num_rows = 8192; + let batch = create_record_batch(num_rows); + let mut batches = Vec::new(); + for _ in 0..10 { + batches.push(batch.clone()); + } + let partitions = &[batches]; + let c0: Arc = Arc::new(Column::new("c0", 0)); + let c1: Arc = Arc::new(Column::new("c1", 1)); + + let data_type = DataType::Decimal128(8, 2); + let schema = Arc::clone(&partitions[0][0].schema()); + let scan: Arc = Arc::new(DataSourceExec::new(Arc::new( + MemorySourceConfig::try_new(partitions, Arc::clone(&schema), None).unwrap(), + ))); + + let aggregate_udf = Arc::new(AggregateUDF::new_from_impl(SumDecimal::try_new( + data_type.clone(), + EvalMode::Legacy, + )?)); + + let aggr_expr = AggregateExprBuilder::new(aggregate_udf, vec![c1]) + .schema(Arc::clone(&schema)) + .alias("sum") + .with_ignore_nulls(false) + .with_distinct(false) + .build()?; + + let aggregate = Arc::new(AggregateExec::try_new( + AggregateMode::Partial, + PhysicalGroupBy::new_single(vec![(c0, "c0".to_string())]), + vec![aggr_expr.into()], + vec![None], // no filter expressions + scan, + Arc::clone(&schema), + )?); + + let mut stream = aggregate + .execute(0, Arc::new(TaskContext::default())) + .unwrap(); + while let Some(batch) = stream.next().await { + let _batch = batch?; + } + + Ok(()) + } + + fn create_record_batch(num_rows: usize) -> RecordBatch { + let mut decimal_builder = Decimal128Builder::with_capacity(num_rows); + let mut string_builder = StringBuilder::with_capacity(num_rows, num_rows * 32); + for i in 0..num_rows { + decimal_builder.append_value(i as i128); + string_builder.append_value(format!("this is string #{}", i % 1024)); + } + let decimal_array = Arc::new(decimal_builder.finish()); + let string_array = Arc::new(string_builder.finish()); + + let mut fields = vec![]; + let mut columns: Vec = vec![]; + + // string column + fields.push(Field::new("c0", DataType::Utf8, false)); + columns.push(string_array); + + // decimal column + fields.push(Field::new("c1", DataType::Decimal128(38, 10), false)); + columns.push(decimal_array); + + let schema = Schema::new(fields); + RecordBatch::try_new(Arc::new(schema), columns).unwrap() + } +}