diff --git a/paimon-python/pypaimon/tests/reader_base_test.py b/paimon-python/pypaimon/tests/reader_base_test.py
index 92a275585ccc..e81fd182b432 100644
--- a/paimon-python/pypaimon/tests/reader_base_test.py
+++ b/paimon-python/pypaimon/tests/reader_base_test.py
@@ -700,6 +700,112 @@ def _test_value_stats_cols_case(self, manifest_manager, table, value_stats_cols,
         self.assertEqual(read_entry.file.value_stats.null_counts, null_counts)
 
+    def test_primary_key_value_stats(self):
+        pa_schema = pa.schema([
+            ('id', pa.int64()),
+            ('name', pa.string()),
+            ('price', pa.float64()),
+            ('category', pa.string())
+        ])
+        schema = Schema.from_pyarrow_schema(
+            pa_schema,
+            primary_keys=['id'],
+            options={'metadata.stats-mode': 'full', 'bucket': '2'}
+        )
+        self.catalog.create_table('default.test_pk_value_stats', schema, False)
+        table = self.catalog.get_table('default.test_pk_value_stats')
+
+        test_data = pa.Table.from_pydict({
+            'id': [1, 2, 3, 4, 5],
+            'name': ['Alice', 'Bob', 'Charlie', 'David', 'Eve'],
+            'price': [10.5, 20.3, 30.7, 40.1, 50.9],
+            'category': ['A', 'B', 'C', 'D', 'E']
+        }, schema=pa_schema)
+
+        write_builder = table.new_batch_write_builder()
+        writer = write_builder.new_write()
+        writer.write_arrow(test_data)
+        commit_messages = writer.prepare_commit()
+        commit = write_builder.new_commit()
+        commit.commit(commit_messages)
+        writer.close()
+
+        read_builder = table.new_read_builder()
+        table_scan = read_builder.new_scan()
+        latest_snapshot = SnapshotManager(table).get_latest_snapshot()
+        manifest_files = table_scan.starting_scanner.manifest_list_manager.read_all(latest_snapshot)
+        manifest_entries = table_scan.starting_scanner.manifest_file_manager.read(
+            manifest_files[0].file_name,
+            lambda row: table_scan.starting_scanner._filter_manifest_entry(row),
+            False
+        )
+
+        self.assertGreater(len(manifest_entries), 0, "Should have at least one manifest entry")
+        file_meta = manifest_entries[0].file
+
+        key_stats = file_meta.key_stats
+        self.assertIsNotNone(key_stats, "key_stats should not be None")
+        self.assertGreater(key_stats.min_values.arity, 0, "key_stats should contain key fields")
+        self.assertEqual(key_stats.min_values.arity, 1, "key_stats should contain exactly 1 key field (id)")
+
+        value_stats = file_meta.value_stats
+        self.assertIsNotNone(value_stats, "value_stats should not be None")
+
+        if file_meta.value_stats_cols is None:
+            expected_value_fields = ['name', 'price', 'category']
+            self.assertGreaterEqual(value_stats.min_values.arity, len(expected_value_fields),
+                                    f"value_stats should contain at least {len(expected_value_fields)} value fields")
+        else:
+            self.assertNotIn('id', file_meta.value_stats_cols,
+                             "Key field 'id' should NOT be in value_stats_cols")
+
+            expected_value_fields = ['name', 'price', 'category']
+            self.assertTrue(set(expected_value_fields).issubset(set(file_meta.value_stats_cols)),
+                            f"value_stats_cols should contain value fields: {expected_value_fields}, "
+                            f"but got: {file_meta.value_stats_cols}")
+
+            expected_arity = len(file_meta.value_stats_cols)
+            self.assertEqual(value_stats.min_values.arity, expected_arity,
+                             f"value_stats should contain {expected_arity} fields (matching value_stats_cols), "
+                             f"but got {value_stats.min_values.arity}")
+            self.assertEqual(value_stats.max_values.arity, expected_arity,
+                             f"value_stats should contain {expected_arity} fields (matching value_stats_cols), "
+                             f"but got {value_stats.max_values.arity}")
+            self.assertEqual(len(value_stats.null_counts), expected_arity,
+                             f"value_stats null_counts should have {expected_arity} elements, "
+                             f"but got {len(value_stats.null_counts)}")
+
+            self.assertEqual(value_stats.min_values.arity, len(file_meta.value_stats_cols),
+                             f"value_stats.min_values.arity ({value_stats.min_values.arity}) must match "
+                             f"value_stats_cols length ({len(file_meta.value_stats_cols)})")
+
+            for field_name in file_meta.value_stats_cols:
+                is_system_field = (field_name.startswith('_KEY_') or
+                                   field_name in ['_SEQUENCE_NUMBER', '_VALUE_KIND', '_ROW_ID'])
+                self.assertFalse(is_system_field,
+                                 f"value_stats_cols should not contain system field: {field_name}")
+
+        value_stats_fields = table_scan.starting_scanner.manifest_file_manager._get_value_stats_fields(
+            {'_VALUE_STATS_COLS': file_meta.value_stats_cols},
+            table.fields
+        )
+        min_value_stats = GenericRowDeserializer.from_bytes(
+            value_stats.min_values.data,
+            value_stats_fields
+        ).values
+        max_value_stats = GenericRowDeserializer.from_bytes(
+            value_stats.max_values.data,
+            value_stats_fields
+        ).values
+
+        self.assertEqual(len(min_value_stats), 3, "min_value_stats should have 3 values")
+        self.assertEqual(len(max_value_stats), 3, "max_value_stats should have 3 values")
+
+        actual_data = read_builder.new_read().to_arrow(table_scan.plan().splits())
+        self.assertEqual(actual_data.num_rows, 5, "Should have 5 rows")
+        actual_ids = sorted(actual_data.column('id').to_pylist())
+        self.assertEqual(actual_ids, [1, 2, 3, 4, 5], "All IDs should be present")
+
     def test_split_target_size(self):
         """Test source.split.target-size configuration effect on split generation."""
         from pypaimon.common.options.core_options import CoreOptions
diff --git a/paimon-python/pypaimon/write/writer/data_blob_writer.py b/paimon-python/pypaimon/write/writer/data_blob_writer.py
index eaf2b9483cd7..8cdd7428dcbc 100644
--- a/paimon-python/pypaimon/write/writer/data_blob_writer.py
+++ b/paimon-python/pypaimon/write/writer/data_blob_writer.py
@@ -276,14 +276,7 @@ def _create_data_file_meta(self, file_name: str, file_path: str, data: pa.Table,
         # Column stats (only for normal columns)
         metadata_stats_enabled = self.options.metadata_stats_enabled()
         stats_columns = self.normal_columns if metadata_stats_enabled else []
-        column_stats = {
-            field.name: self._get_column_stats(data, field.name)
-            for field in stats_columns
-        }
-
-        min_value_stats = [column_stats[field.name]['min_values'] for field in stats_columns]
-        max_value_stats = [column_stats[field.name]['max_values'] for field in stats_columns]
-        value_null_counts = [column_stats[field.name]['null_counts'] for field in stats_columns]
+        value_stats = self._collect_value_stats(data, stats_columns)
 
         self.sequence_generator.start = self.sequence_generator.current
 
@@ -293,14 +286,8 @@ def _create_data_file_meta(self, file_name: str, file_path: str, data: pa.Table,
             row_count=data.num_rows,
             min_key=GenericRow([], []),
             max_key=GenericRow([], []),
-            key_stats=SimpleStats(
-                GenericRow([], []),
-                GenericRow([], []),
-                []),
-            value_stats=SimpleStats(
-                GenericRow(min_value_stats, stats_columns),
-                GenericRow(max_value_stats, stats_columns),
-                value_null_counts),
+            key_stats=SimpleStats.empty_stats(),
+            value_stats=value_stats,
             min_sequence_number=-1,
             max_sequence_number=-1,
             schema_id=self.table.table_schema.id,
diff --git a/paimon-python/pypaimon/write/writer/data_writer.py b/paimon-python/pypaimon/write/writer/data_writer.py
index 73609ed91282..eba7802d7ada 100644
--- a/paimon-python/pypaimon/write/writer/data_writer.py
+++ b/paimon-python/pypaimon/write/writer/data_writer.py
@@ -29,6 +29,7 @@
 from pypaimon.schema.data_types import PyarrowFieldParser
 from pypaimon.table.bucket_mode import BucketMode
 from pypaimon.table.row.generic_row import GenericRow
+from pypaimon.table.special_fields import SpecialFields
 
 
 class DataWriter(ABC):
@@ -198,36 +199,36 @@ def _write_data_to_file(self, data: pa.Table):
             field.name: self._get_column_stats(data, field.name)
             for field in stats_fields
         }
-        data_fields = stats_fields if value_stats_enabled else []
-        min_value_stats = [column_stats[field.name]['min_values'] for field in data_fields]
-        max_value_stats = [column_stats[field.name]['max_values'] for field in data_fields]
-        value_null_counts = [column_stats[field.name]['null_counts'] for field in data_fields]
 
         key_fields = self.trimmed_primary_keys_fields
-        min_key_stats = [column_stats[field.name]['min_values'] for field in key_fields]
-        max_key_stats = [column_stats[field.name]['max_values'] for field in key_fields]
-        key_null_counts = [column_stats[field.name]['null_counts'] for field in key_fields]
-        if not all(count == 0 for count in key_null_counts):
+        key_field_names = {field.name for field in key_fields}
+        value_fields = self._filter_value_fields(stats_fields, key_field_names) if value_stats_enabled else []
+
+        key_stats = self._collect_value_stats(data, key_fields, column_stats)
+        if not all(count == 0 for count in key_stats.null_counts):
             raise RuntimeError("Primary key should not be null")
+
+        value_stats = self._collect_value_stats(
+            data, value_fields, column_stats if value_stats_enabled else None)
 
         min_seq = self.sequence_generator.start
         max_seq = self.sequence_generator.current
         self.sequence_generator.start = self.sequence_generator.current
+        if value_stats_enabled:
+            all_table_value_fields = self._filter_value_fields(self.table.fields, key_field_names)
+            if len(value_fields) == len(all_table_value_fields):
+                value_stats_cols = None
+            else:
+                value_stats_cols = [field.name for field in value_fields]
+        else:
+            value_stats_cols = []
 
         self.committed_files.append(DataFileMeta.create(
             file_name=file_name,
             file_size=self.file_io.get_file_size(file_path),
             row_count=data.num_rows,
             min_key=GenericRow(min_key, self.trimmed_primary_keys_fields),
             max_key=GenericRow(max_key, self.trimmed_primary_keys_fields),
-            key_stats=SimpleStats(
-                GenericRow(min_key_stats, self.trimmed_primary_keys_fields),
-                GenericRow(max_key_stats, self.trimmed_primary_keys_fields),
-                key_null_counts,
-            ),
-            value_stats=SimpleStats(
-                GenericRow(min_value_stats, data_fields),
-                GenericRow(max_value_stats, data_fields),
-                value_null_counts,
-            ),
+            key_stats=key_stats,
+            value_stats=value_stats,
             min_sequence_number=min_seq,
             max_sequence_number=max_seq,
             schema_id=self.table.table_schema.id,
@@ -236,7 +237,7 @@ def _write_data_to_file(self, data: pa.Table):
             creation_time=Timestamp.now(),
             delete_row_count=0,
             file_source=0,
-            value_stats_cols=None if value_stats_enabled else [],
+            value_stats_cols=value_stats_cols,
             external_path=external_path_str,  # Set external path if using external paths
             first_row_id=None,
             write_cols=self.write_cols,
@@ -244,6 +245,11 @@ def _write_data_to_file(self, data: pa.Table):
             file_path=file_path,
         ))
 
+    def _filter_value_fields(self, fields: List, key_field_names: set) -> List:
+        return [field for field in fields if
+                field.name not in key_field_names and
+                not (field.name.startswith('_KEY_') or SpecialFields.is_system_field(field.name))]
+
     def _generate_file_path(self, file_name: str) -> str:
         if self.external_path_provider:
             external_path = self.external_path_provider.get_next_external_data_path(file_name)
@@ -253,6 +259,27 @@ def _generate_file_path(self, file_name: str) -> str:
         bucket_path = self.path_factory.bucket_path(self.partition, self.bucket)
         return f"{bucket_path.rstrip('/')}/{file_name}"
 
+    def _collect_value_stats(self, data: pa.Table, value_fields: List,
+                             column_stats: Optional[Dict[str, Dict]] = None) -> SimpleStats:
+        if not value_fields:
+            return SimpleStats.empty_stats()
+
+        if column_stats is None:
+            column_stats = {
+                field.name: self._get_column_stats(data, field.name)
+                for field in value_fields
+            }
+
+        min_value_stats = [column_stats[field.name]['min_values'] for field in value_fields]
+        max_value_stats = [column_stats[field.name]['max_values'] for field in value_fields]
+        value_null_counts = [column_stats[field.name]['null_counts'] for field in value_fields]
+
+        return SimpleStats(
+            GenericRow(min_value_stats, value_fields),
+            GenericRow(max_value_stats, value_fields),
+            value_null_counts
+        )
+
     @staticmethod
     def _find_optimal_split_point(data: pa.RecordBatch, target_size: int) -> int:
         total_rows = data.num_rows
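
Reviewer note: the standalone sketch below restates, outside the patch, the value_stats_cols convention that _filter_value_fields and the new branch in _write_data_to_file encode: None means the stats cover every non-key, non-system value field, a list names a partial subset, and [] means stats are disabled. The Field dataclass and the hard-coded system-field names are simplified stand-ins for pypaimon's own DataField and SpecialFields, not the real API.

from dataclasses import dataclass
from typing import List, Optional, Set

SYSTEM_FIELDS = {'_SEQUENCE_NUMBER', '_VALUE_KIND', '_ROW_ID'}  # simplified stand-in


@dataclass
class Field:  # placeholder for pypaimon's DataField
    name: str


def filter_value_fields(fields: List[Field], key_field_names: Set[str]) -> List[Field]:
    # Mirrors DataWriter._filter_value_fields: drop key fields and system fields.
    return [f for f in fields
            if f.name not in key_field_names
            and not f.name.startswith('_KEY_')
            and f.name not in SYSTEM_FIELDS]


def decide_value_stats_cols(stats_fields: List[Field], table_fields: List[Field],
                            key_field_names: Set[str], value_stats_enabled: bool) -> Optional[List[str]]:
    # Mirrors the new value_stats_cols decision in _write_data_to_file.
    if not value_stats_enabled:
        return []  # stats disabled
    value_fields = filter_value_fields(stats_fields, key_field_names)
    all_value_fields = filter_value_fields(table_fields, key_field_names)
    if len(value_fields) == len(all_value_fields):
        return None  # stats cover every value field
    return [f.name for f in value_fields]  # partial coverage: list the covered columns


# Example: PK table (id) with value columns name/price/category, as in the test above.
fields = [Field('id'), Field('name'), Field('price'), Field('category')]
print(decide_value_stats_cols(fields, fields, {'id'}, True))       # None
print(decide_value_stats_cols(fields[:2], fields, {'id'}, True))   # ['name']
print(decide_value_stats_cols(fields, fields, {'id'}, False))      # []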