Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 51 additions & 0 deletions asv_bench/benchmarks/groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

from pandas import (
NA,
ArrowDtype,
Categorical,
DataFrame,
Index,
Expand Down Expand Up @@ -1152,4 +1153,54 @@ def time_resample_multiindex(self):
).mean()


class GroupByAggregateArrowDtypes:
    """
    Time ``Series.groupby(key).agg(method)`` for PyArrow-backed dtypes.

    Parametrized over the dtype of the grouped Series and the name of the
    reduction passed to ``agg``.
    """

    param_names = ["dtype", "method"]
    params = [
        [
            "int32[pyarrow]",
            "int64[pyarrow]",
            "float32[pyarrow]",
            "float64[pyarrow]",
            "decimal128",
            "string[pyarrow]",
        ],
        ["sum", "prod", "min", "max", "mean", "std", "var", "count"],
    ]

    # Reductions that are not benchmarked for the string dtype; only
    # min, max and count are run for strings.
    _string_unsupported = {"sum", "prod", "mean", "std", "var"}

    def setup(self, dtype, method):
        """
        Build the Series and the integer group keys used by the benchmark.

        Raises ``NotImplementedError("skipped")`` for string data combined
        with a reduction listed in ``_string_unsupported``.
        """
        import pyarrow as pa

        from pandas.api.types import is_string_dtype

        string_case = dtype == "string[pyarrow]"
        if string_case and method in self._string_unsupported:
            raise NotImplementedError("skipped")

        n_rows, n_groups = 100_000, 1000

        # Dtype handed to the Series constructor; only the "decimal128"
        # parameter needs translating into an actual ArrowDtype.
        series_dtype = dtype

        if dtype == "decimal128":
            from decimal import Decimal

            values = [
                Decimal(str(round(draw, 3))) for draw in np.random.randn(n_rows)
            ]
            series_dtype = ArrowDtype(pa.decimal128(10, 3))
        elif string_case:
            values = np.random.choice(list(ascii_letters), n_rows)
        elif dtype in ("float32[pyarrow]", "float64[pyarrow]"):
            values = np.random.randn(n_rows)
        elif dtype in ("int32[pyarrow]", "int64[pyarrow]"):
            values = np.random.randint(0, 10_000, n_rows)

        series = Series(values, dtype=series_dtype)
        if not is_string_dtype(series.dtype):
            # Every 10th element becomes missing for the non-string dtypes.
            series.iloc[::10] = NA

        self.ser = series
        self.key = np.random.randint(0, n_groups, n_rows)

    def time_series_agg(self, dtype, method):
        """Run the grouped aggregation on the Series prepared in ``setup``."""
        grouped = self.ser.groupby(self.key)
        grouped.agg(method)


from .pandas_vb_common import setup # noqa: F401 isort:skip
171 changes: 157 additions & 14 deletions pandas/core/arrays/arrow/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -2603,6 +2603,127 @@ def _to_masked(self):
arr = self.to_numpy(dtype=dtype.numpy_dtype, na_value=na_value)
return dtype.construct_array_type()(arr, mask)

# Lookup table: pandas groupby ``how`` name -> PyArrow aggregation function
# name. _groupby_op_pyarrow returns None for any ``how`` that is not a key
# here, so the caller can fall back to another implementation.
_PYARROW_AGG_FUNCS: dict[str, str] = {
    "sum": "sum",
    "prod": "product",
    "min": "min",
    "max": "max",
    "mean": "mean",
    "std": "stddev",
    "var": "variance",
    # sem = stddev / sqrt(count); only the stddev part is computed by the
    # aggregation, the division is done afterwards in _groupby_op_pyarrow.
    "sem": "stddev",
    "count": "count",
    "any": "any",
    "all": "all",
}

# Identity elements per operation. When min_count == 0 they are used as the
# value for groups that are missing from the aggregation result; for sum and
# prod they also replace null results (e.g. all-null groups). Operations
# without an entry get null for such groups instead.
_PYARROW_AGG_DEFAULTS: dict[str, int | bool] = {
    "sum": 0,
    "prod": 1,
    "count": 0,
    "any": False,
    "all": True,
}

def _groupby_op_pyarrow(
    self,
    *,
    how: str,
    min_count: int,
    ngroups: int,
    ids: npt.NDArray[np.intp],
    **kwargs,
) -> Self | None:
    """
    Perform groupby aggregation using PyArrow's native Table.group_by.

    Parameters
    ----------
    how : str
        Name of the pandas groupby operation. Must be a key of
        ``_PYARROW_AGG_FUNCS`` to be handled here.
    min_count : int
        If positive, groups with fewer non-null values than this are set
        to null. If exactly 0, operations that have an identity element in
        ``_PYARROW_AGG_DEFAULTS`` use it for groups that are absent from
        the data and (sum/prod only) for groups whose result is null.
    ngroups : int
        Length of the returned array; position ``i`` holds the result for
        group id ``i``.
    ids : np.ndarray
        Group id for each element of this array. Negative ids mark rows
        that belong to no group and are dropped before aggregating.
    **kwargs
        Only ``ddof`` (default 1) is read, for ``std``/``var``/``sem``.
        Any other entries are ignored.

    Returns
    -------
    Self or None
        The aggregated values, or None if the combination of ``how`` and
        the array's type is not handled here, in which case the caller
        should fall back to another implementation.
    """
    pa_agg_func = self._PYARROW_AGG_FUNCS.get(how)
    if pa_agg_func is None:
        return None

    pa_type = self._pa_array.type
    if pa.types.is_temporal(pa_type) and how in ["std", "var", "sem"]:
        return None
    if how in ["any", "all"] and not pa.types.is_boolean(pa_type):
        return None
    # PyArrow doesn't support sum/prod/mean/std/var/sem on strings
    is_str = pa.types.is_string(pa_type) or pa.types.is_large_string(pa_type)
    if is_str and how in ["sum", "prod", "mean", "std", "var", "sem"]:
        return None

    # Filter out NA group (ids == -1)
    mask = ids >= 0
    if not mask.all():
        ids = ids[mask]
        values = pc.filter(self._pa_array, mask)
    else:
        values = self._pa_array

    # Build table and run aggregation (cast ids to int64 for portability)
    group_id_arr = pa.array(ids, type=pa.int64())
    table = pa.table({"value": values, "group_id": group_id_arr})

    aggs: list[tuple[str, str] | tuple[str, str, pc.VarianceOptions]]
    if how in ["std", "var", "sem"]:
        ddof = kwargs.get("ddof", 1)
        aggs = [("value", pa_agg_func, pc.VarianceOptions(ddof=ddof))]
    else:
        aggs = [("value", pa_agg_func)]
    # The per-group count of non-null values is needed for min_count and sem.
    # Output columns are named "<column>_<function>", so requesting "count"
    # a second time when it is already the primary aggregation would yield
    # two "value_count" columns and make the by-name lookup below ambiguous.
    if pa_agg_func != "count":
        aggs.append(("value", "count"))

    result_table = table.group_by("group_id").aggregate(aggs)
    result_group_ids = result_table.column("group_id")
    result_values = result_table.column(f"value_{pa_agg_func}")
    result_counts = result_table.column("value_count")

    if how == "sem":
        result_values = pc.divide(result_values, pc.sqrt(result_counts))

    output_type = result_values.type
    # Null scalar when ``how`` has no identity element.
    default_value = pa.scalar(self._PYARROW_AGG_DEFAULTS.get(how), type=output_type)

    # Replace nulls from all-null groups with identity element
    if result_values.null_count > 0 and how in ["sum", "prod"] and min_count == 0:
        result_values = pc.if_else(
            pc.is_null(result_values), default_value, result_values
        )

    # Null out groups below min_count
    if min_count > 0:
        below_min_count = pc.less(result_counts, pa.scalar(min_count))
        result_values = pc.if_else(below_min_count, None, result_values)

    # Scatter results into output array ordered by group id.
    # NumPy scatter is O(n) vs O(n log n) for join+sort or pc.scatter workaround.
    result_group_ids_np = result_group_ids.to_numpy(zero_copy_only=False).astype(
        np.int64, copy=False
    )
    result_values_np = result_values.to_numpy(zero_copy_only=False)

    default_py = default_value.as_py()
    if default_py is not None and min_count == 0:
        # Fill missing groups with identity element
        output_np = np.full(ngroups, default_py, dtype=result_values_np.dtype)
        output_np[result_group_ids_np] = result_values_np
        pa_result = pa.array(output_np, type=output_type)
    else:
        # Fill missing groups with null
        output_np = np.empty(ngroups, dtype=result_values_np.dtype)
        # Start with every group null; groups present in the result are
        # unmasked, then re-masked where the aggregated value is itself null.
        null_mask = np.ones(ngroups, dtype=bool)
        output_np[result_group_ids_np] = result_values_np
        null_mask[result_group_ids_np] = False
        if result_values.null_count > 0:
            result_nulls = pc.is_null(result_values).to_numpy()
            null_mask[result_group_ids_np[result_nulls]] = True
        pa_result = pa.array(output_np, type=output_type, mask=null_mask)

    return self._from_pyarrow_array(pa_result)

def _groupby_op(
self,
*,
Expand All @@ -2628,41 +2749,63 @@ def _groupby_op(
raise TypeError(
f"dtype '{self.dtype}' does not support operation '{how}'"
)
return super()._groupby_op(
# Fall through to Arrow-native path below

pa_type = self._pa_array.type

# Try PyArrow-native path for decimal and string types where it's faster.
# For integer/float/boolean, the fallback path via _to_masked() is faster.
if (
pa.types.is_decimal(pa_type)
or pa.types.is_string(pa_type)
or pa.types.is_large_string(pa_type)
):
result = self._groupby_op_pyarrow(
how=how,
has_dropped_na=has_dropped_na,
min_count=min_count,
ngroups=ngroups,
ids=ids,
**kwargs,
)
if result is not None:
return result
# For string types, fall back to parent implementation (Python path)
# since _to_masked() doesn't support strings
if pa.types.is_string(pa_type) or pa.types.is_large_string(pa_type):
return super()._groupby_op(
how=how,
has_dropped_na=has_dropped_na,
min_count=min_count,
ngroups=ngroups,
ids=ids,
**kwargs,
)

# maybe convert to a compatible dtype optimized for groupby
values: ExtensionArray
pa_type = self._pa_array.type
# Fall back to converting to masked/datetime array and using Cython
fallback_values: ExtensionArray
if pa.types.is_timestamp(pa_type):
values = self._to_datetimearray()
fallback_values = self._to_datetimearray()
elif pa.types.is_duration(pa_type):
values = self._to_timedeltaarray()
fallback_values = self._to_timedeltaarray()
else:
values = self._to_masked()
fallback_values = self._to_masked()

result = values._groupby_op(
fallback_result = fallback_values._groupby_op(
how=how,
has_dropped_na=has_dropped_na,
min_count=min_count,
ngroups=ngroups,
ids=ids,
**kwargs,
)
if isinstance(result, np.ndarray):
return result
elif isinstance(result, BaseMaskedArray):
pa_result = result.__arrow_array__()
if isinstance(fallback_result, np.ndarray):
return fallback_result
elif isinstance(fallback_result, BaseMaskedArray):
pa_result = fallback_result.__arrow_array__()
return self._from_pyarrow_array(pa_result)
else:
# DatetimeArray, TimedeltaArray
pa_result = pa.array(result)
pa_result = pa.array(fallback_result)
return self._from_pyarrow_array(pa_result)

def _apply_elementwise(self, func: Callable) -> list[list[Any]]:
Expand Down
Loading
Loading