106 changes: 106 additions & 0 deletions paimon-python/pypaimon/tests/reader_base_test.py
@@ -700,6 +700,112 @@ def _test_value_stats_cols_case(self, manifest_manager, table, value_stats_cols,

self.assertEqual(read_entry.file.value_stats.null_counts, null_counts)

def test_primary_key_value_stats(self):
pa_schema = pa.schema([
('id', pa.int64()),
('name', pa.string()),
('price', pa.float64()),
('category', pa.string())
])
schema = Schema.from_pyarrow_schema(
pa_schema,
primary_keys=['id'],
options={'metadata.stats-mode': 'full', 'bucket': '2'}
)
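# Options above: 'metadata.stats-mode' = 'full' keeps full min/max column stats in the
# file metadata; 'bucket' = '2' makes this a fixed-bucket primary key table.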
self.catalog.create_table('default.test_pk_value_stats', schema, False)
table = self.catalog.get_table('default.test_pk_value_stats')

test_data = pa.Table.from_pydict({
'id': [1, 2, 3, 4, 5],
'name': ['Alice', 'Bob', 'Charlie', 'David', 'Eve'],
'price': [10.5, 20.3, 30.7, 40.1, 50.9],
'category': ['A', 'B', 'C', 'D', 'E']
}, schema=pa_schema)

write_builder = table.new_batch_write_builder()
writer = write_builder.new_write()
writer.write_arrow(test_data)
commit_messages = writer.prepare_commit()
commit = write_builder.new_commit()
commit.commit(commit_messages)
writer.close()

read_builder = table.new_read_builder()
table_scan = read_builder.new_scan()
latest_snapshot = SnapshotManager(table).get_latest_snapshot()
manifest_files = table_scan.starting_scanner.manifest_list_manager.read_all(latest_snapshot)
manifest_entries = table_scan.starting_scanner.manifest_file_manager.read(
manifest_files[0].file_name,
lambda row: table_scan.starting_scanner._filter_manifest_entry(row),
False
)

self.assertGreater(len(manifest_entries), 0, "Should have at least one manifest entry")
file_meta = manifest_entries[0].file

key_stats = file_meta.key_stats
self.assertIsNotNone(key_stats, "key_stats should not be None")
self.assertGreater(key_stats.min_values.arity, 0, "key_stats should contain key fields")
self.assertEqual(key_stats.min_values.arity, 1, "key_stats should contain exactly 1 key field (id)")

value_stats = file_meta.value_stats
self.assertIsNotNone(value_stats, "value_stats should not be None")

if file_meta.value_stats_cols is None:
expected_value_fields = ['name', 'price', 'category']
self.assertGreaterEqual(value_stats.min_values.arity, len(expected_value_fields),
f"value_stats should contain at least {len(expected_value_fields)} value fields")
else:
self.assertNotIn('id', file_meta.value_stats_cols,
"Key field 'id' should NOT be in value_stats_cols")

expected_value_fields = ['name', 'price', 'category']
self.assertTrue(set(expected_value_fields).issubset(set(file_meta.value_stats_cols)),
f"value_stats_cols should contain value fields: {expected_value_fields}, "
f"but got: {file_meta.value_stats_cols}")

expected_arity = len(file_meta.value_stats_cols)
self.assertEqual(value_stats.min_values.arity, expected_arity,
f"value_stats should contain {expected_arity} fields (matching value_stats_cols), "
f"but got {value_stats.min_values.arity}")
self.assertEqual(value_stats.max_values.arity, expected_arity,
f"value_stats should contain {expected_arity} fields (matching value_stats_cols), "
f"but got {value_stats.max_values.arity}")
self.assertEqual(len(value_stats.null_counts), expected_arity,
f"value_stats null_counts should have {expected_arity} elements, "
f"but got {len(value_stats.null_counts)}")

self.assertEqual(value_stats.min_values.arity, len(file_meta.value_stats_cols),
f"value_stats.min_values.arity ({value_stats.min_values.arity}) must match "
f"value_stats_cols length ({len(file_meta.value_stats_cols)})")

for field_name in file_meta.value_stats_cols:
is_system_field = (field_name.startswith('_KEY_') or
field_name in ['_SEQUENCE_NUMBER', '_VALUE_KIND', '_ROW_ID'])
self.assertFalse(is_system_field,
f"value_stats_cols should not contain system field: {field_name}")

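# Resolve the per-file stats schema from value_stats_cols, then deserialize the packed
# min/max rows so the individual value-column statistics can be asserted below.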
value_stats_fields = table_scan.starting_scanner.manifest_file_manager._get_value_stats_fields(
{'_VALUE_STATS_COLS': file_meta.value_stats_cols},
table.fields
)
min_value_stats = GenericRowDeserializer.from_bytes(
value_stats.min_values.data,
value_stats_fields
).values
max_value_stats = GenericRowDeserializer.from_bytes(
value_stats.max_values.data,
value_stats_fields
).values

self.assertEqual(len(min_value_stats), 3, "min_value_stats should have 3 values")
self.assertEqual(len(max_value_stats), 3, "max_value_stats should have 3 values")

actual_data = read_builder.new_read().to_arrow(table_scan.plan().splits())
self.assertEqual(actual_data.num_rows, 5, "Should have 5 rows")
actual_ids = sorted(actual_data.column('id').to_pylist())
self.assertEqual(actual_ids, [1, 2, 3, 4, 5], "All IDs should be present")

def test_split_target_size(self):
"""Test source.split.target-size configuration effect on split generation."""
from pypaimon.common.options.core_options import CoreOptions
19 changes: 3 additions & 16 deletions paimon-python/pypaimon/write/writer/data_blob_writer.py
@@ -276,14 +276,7 @@ def _create_data_file_meta(self, file_name: str, file_path: str, data: pa.Table,
# Column stats (only for normal columns)
metadata_stats_enabled = self.options.metadata_stats_enabled()
stats_columns = self.normal_columns if metadata_stats_enabled else []
column_stats = {
field.name: self._get_column_stats(data, field.name)
for field in stats_columns
}

min_value_stats = [column_stats[field.name]['min_values'] for field in stats_columns]
max_value_stats = [column_stats[field.name]['max_values'] for field in stats_columns]
value_null_counts = [column_stats[field.name]['null_counts'] for field in stats_columns]
value_stats = self._collect_value_stats(data, stats_columns)

self.sequence_generator.start = self.sequence_generator.current

@@ -293,14 +286,8 @@ def _create_data_file_meta(self, file_name: str, file_path: str, data: pa.Table,
row_count=data.num_rows,
min_key=GenericRow([], []),
max_key=GenericRow([], []),
key_stats=SimpleStats(
GenericRow([], []),
GenericRow([], []),
[]),
value_stats=SimpleStats(
GenericRow(min_value_stats, stats_columns),
GenericRow(max_value_stats, stats_columns),
value_null_counts),
key_stats=SimpleStats.empty_stats(),
value_stats=value_stats,
min_sequence_number=-1,
max_sequence_number=-1,
schema_id=self.table.table_schema.id,
65 changes: 46 additions & 19 deletions paimon-python/pypaimon/write/writer/data_writer.py
@@ -29,6 +29,7 @@
from pypaimon.schema.data_types import PyarrowFieldParser
from pypaimon.table.bucket_mode import BucketMode
from pypaimon.table.row.generic_row import GenericRow
from pypaimon.table.special_fields import SpecialFields


class DataWriter(ABC):
@@ -198,36 +199,36 @@ def _write_data_to_file(self, data: pa.Table):
field.name: self._get_column_stats(data, field.name)
for field in stats_fields
}
data_fields = stats_fields if value_stats_enabled else []
min_value_stats = [column_stats[field.name]['min_values'] for field in data_fields]
max_value_stats = [column_stats[field.name]['max_values'] for field in data_fields]
value_null_counts = [column_stats[field.name]['null_counts'] for field in data_fields]
key_fields = self.trimmed_primary_keys_fields
min_key_stats = [column_stats[field.name]['min_values'] for field in key_fields]
max_key_stats = [column_stats[field.name]['max_values'] for field in key_fields]
key_null_counts = [column_stats[field.name]['null_counts'] for field in key_fields]
if not all(count == 0 for count in key_null_counts):
key_field_names = {field.name for field in key_fields}
value_fields = self._filter_value_fields(stats_fields, key_field_names) if value_stats_enabled else []

key_stats = self._collect_value_stats(data, key_fields, column_stats)
if not all(count == 0 for count in key_stats.null_counts):
raise RuntimeError("Primary key should not be null")

value_stats = self._collect_value_stats(
data, value_fields, column_stats if value_stats_enabled else None)

Review comment: This will collect twice for key fields.

min_seq = self.sequence_generator.start
max_seq = self.sequence_generator.current
self.sequence_generator.start = self.sequence_generator.current
if value_stats_enabled:
all_table_value_fields = self._filter_value_fields(self.table.fields, key_field_names)
if len(value_fields) == len(all_table_value_fields):
value_stats_cols = None
else:
value_stats_cols = [field.name for field in value_fields]
else:
value_stats_cols = []
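# value_stats_cols convention: None means the stats cover every non-key value column
# of the table, an empty list means value stats are disabled, and a non-empty list
# names the subset of value columns that the stats describe.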
self.committed_files.append(DataFileMeta.create(
file_name=file_name,
file_size=self.file_io.get_file_size(file_path),
row_count=data.num_rows,
min_key=GenericRow(min_key, self.trimmed_primary_keys_fields),
max_key=GenericRow(max_key, self.trimmed_primary_keys_fields),
key_stats=SimpleStats(
GenericRow(min_key_stats, self.trimmed_primary_keys_fields),
GenericRow(max_key_stats, self.trimmed_primary_keys_fields),
key_null_counts,
),
value_stats=SimpleStats(
GenericRow(min_value_stats, data_fields),
GenericRow(max_value_stats, data_fields),
value_null_counts,
),
key_stats=key_stats,
value_stats=value_stats,
min_sequence_number=min_seq,
max_sequence_number=max_seq,
schema_id=self.table.table_schema.id,
@@ -236,14 +237,19 @@ def _write_data_to_file(self, data: pa.Table):
creation_time=Timestamp.now(),
delete_row_count=0,
file_source=0,
value_stats_cols=None if value_stats_enabled else [],
value_stats_cols=value_stats_cols,
external_path=external_path_str, # Set external path if using external paths
first_row_id=None,
write_cols=self.write_cols,
# None means all columns in the table have been written
file_path=file_path,
))

def _filter_value_fields(self, fields: List, key_field_names: set) -> List:
"""Keep only user-visible value columns: drop key columns and system columns
(e.g. _KEY_*-prefixed fields and other system fields such as _SEQUENCE_NUMBER)."""
return [field for field in fields if
field.name not in key_field_names and
not (field.name.startswith('_KEY_') or SpecialFields.is_system_field(field.name))]

def _generate_file_path(self, file_name: str) -> str:
if self.external_path_provider:
external_path = self.external_path_provider.get_next_external_data_path(file_name)
@@ -253,6 +259,27 @@ def _generate_file_path(self, file_name: str) -> str:
bucket_path = self.path_factory.bucket_path(self.partition, self.bucket)
return f"{bucket_path.rstrip('/')}/{file_name}"

def _collect_value_stats(self, data: pa.Table, value_fields: List,
column_stats: Optional[Dict[str, Dict]] = None) -> SimpleStats:
"""Build SimpleStats (min values, max values, null counts) for value_fields,
reusing pre-computed column_stats when provided so the data is not scanned twice."""
if not value_fields:
return SimpleStats.empty_stats()

if column_stats is None:
column_stats = {
field.name: self._get_column_stats(data, field.name)
for field in value_fields
}

min_value_stats = [column_stats[field.name]['min_values'] for field in value_fields]
max_value_stats = [column_stats[field.name]['max_values'] for field in value_fields]
value_null_counts = [column_stats[field.name]['null_counts'] for field in value_fields]

return SimpleStats(
GenericRow(min_value_stats, value_fields),
GenericRow(max_value_stats, value_fields),
value_null_counts
)
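# Illustrative call patterns for the helper above (a sketch, not exhaustive):
#   self._collect_value_stats(data, key_fields, column_stats)   # reuse precomputed stats
#   self._collect_value_stats(data, value_fields)                # compute stats from `data`
#   self._collect_value_stats(data, [])                          # -> SimpleStats.empty_stats()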

@staticmethod
def _find_optimal_split_point(data: pa.RecordBatch, target_size: int) -> int:
total_rows = data.num_rows