From b9c8f093f48698b33f5734b635af58fae90f8710 Mon Sep 17 00:00:00 2001 From: Zhuo Wang Date: Fri, 26 Dec 2025 09:56:11 +0800 Subject: [PATCH 1/2] feat: impl metrics config --- src/iceberg/metrics_config.cc | 228 +++++++++++++++++++ src/iceberg/metrics_config.h | 96 ++++++++ src/iceberg/sort_order.cc | 14 ++ src/iceberg/sort_order.h | 4 + src/iceberg/test/metrics_config_test.cc | 280 ++++++++++++++++++++++-- src/iceberg/util/type_util.cc | 25 +++ src/iceberg/util/type_util.h | 14 ++ 7 files changed, 645 insertions(+), 16 deletions(-) diff --git a/src/iceberg/metrics_config.cc b/src/iceberg/metrics_config.cc index f78cadf2b..2d56b804e 100644 --- a/src/iceberg/metrics_config.cc +++ b/src/iceberg/metrics_config.cc @@ -19,15 +19,235 @@ #include "iceberg/metrics_config.h" +#include #include #include #include "iceberg/result.h" #include "iceberg/schema.h" +#include "iceberg/sort_order.h" +#include "iceberg/table.h" #include "iceberg/table_properties.h" +#include "iceberg/util/checked_cast.h" +#include "iceberg/util/type_util.h" namespace iceberg { +namespace { + +constexpr std::string_view kNoneName = "none"; +constexpr std::string_view kCountsName = "counts"; +constexpr std::string_view kFullName = "full"; +constexpr std::string_view kTruncatePrefix = "truncate("; +constexpr int32_t kDefaultTruncateLength = 16; +const std::shared_ptr kDefaultMetricsMode = + std::make_shared(kDefaultTruncateLength); + +std::shared_ptr SortedColumnDefaultMode( + std::shared_ptr default_mode) { + if (default_mode->kind() == MetricsMode::Kind::kNone || + default_mode->kind() == MetricsMode::Kind::kCounts) { + return kDefaultMetricsMode; + } else { + return std::move(default_mode); + } +} + +int32_t MaxInferredColumns(const TableProperties& properties) { + int32_t max_inferred_columns = + properties.Get(TableProperties::kMetricsMaxInferredColumnDefaults); + if (max_inferred_columns < 0) { + // fallback to default + return TableProperties::kMetricsMaxInferredColumnDefaults.value(); + } + return max_inferred_columns; +} + +Result> ParseMode(const std::string& mode, + std::shared_ptr fallback) { + if (auto metrics_mode = MetricsMode::FromString(mode); metrics_mode.has_value()) { + return std::move(metrics_mode.value()); + } + return std::move(fallback); +} + +} // namespace + +const std::shared_ptr& MetricsMode::None() { + static const std::shared_ptr none = std::make_shared(); + return none; +} + +const std::shared_ptr& MetricsMode::Counts() { + static const std::shared_ptr counts = + std::make_shared(); + return counts; +} + +const std::shared_ptr& MetricsMode::Full() { + static const std::shared_ptr full = std::make_shared(); + return full; +} + +const std::shared_ptr& MetricsMode::Truncate() { + return kDefaultMetricsMode; +} + +Result> MetricsMode::FromString(const std::string& mode) { + if (StringUtils::EqualsIgnoreCase(mode, kNoneName)) { + return MetricsMode::None(); + } else if (StringUtils::EqualsIgnoreCase(mode, kCountsName)) { + return MetricsMode::Counts(); + } else if (StringUtils::EqualsIgnoreCase(mode, kFullName)) { + return MetricsMode::Full(); + } + + if (mode.starts_with(kTruncatePrefix) && mode.ends_with(")")) { + int32_t length; + auto [ptr, ec] = std::from_chars(mode.data() + 9 /* "truncate(" length */, + mode.data() + mode.size() - 1, length); + if (ec != std::errc{}) { + return InvalidArgument("Invalid truncate mode: {}", mode); + } + if (length == kDefaultTruncateLength) { + return kDefaultMetricsMode; + } + return TruncateMetricsMode::Make(length); + } + return InvalidArgument("Invalid metrics mode: {}", mode); +} + +std::string NoneMetricsMode::ToString() const { return std::string(kNoneName); } +std::string CountsMetricsMode::ToString() const { return std::string(kCountsName); } +std::string FullMetricsMode::ToString() const { return std::string(kFullName); } +std::string TruncateMetricsMode::ToString() const { + return std::format("truncate({})", length_); +} + +Result> TruncateMetricsMode::Make(int32_t length) { + ICEBERG_PRECHECK(length > 0, "Truncate length should be positive."); + return std::make_shared(length); +} + +MetricsConfig::MetricsConfig( + std::unordered_map> column_modes, + std::shared_ptr default_mode) + : column_modes_(std::move(column_modes)), default_mode_(std::move(default_mode)) {} + +const std::shared_ptr& MetricsConfig::Default() { + static const auto default_config = std::make_shared( + std::unordered_map>{}, + kDefaultMetricsMode); + return default_config; +} + +Result> MetricsConfig::Make(std::shared_ptr table) { + ICEBERG_PRECHECK(table != nullptr, "table cannot be null"); + ICEBERG_ASSIGN_OR_RAISE(auto schema, table->schema()); + + auto sort_order = table->sort_order(); + return MakeInternal( + table->properties(), *schema, + sort_order.has_value() ? *sort_order.value() : *SortOrder::Unsorted()); +} + +Result> MetricsConfig::MakeInternal( + const TableProperties& props, const Schema& schema, const SortOrder& order) { + std::unordered_map> column_modes; + + std::shared_ptr default_mode = kDefaultMetricsMode; + if (props.configs().contains(TableProperties::kDefaultWriteMetricsMode.key())) { + std::string configured_metrics_mode = + props.Get(TableProperties::kDefaultWriteMetricsMode); + ICEBERG_ASSIGN_OR_RAISE(default_mode, + ParseMode(configured_metrics_mode, kDefaultMetricsMode)); + } else { + int32_t max_inferred_columns = MaxInferredColumns(props); + GetProjectedIdsVisitor visitor(true); + ICEBERG_RETURN_UNEXPECTED( + visitor.Visit(internal::checked_cast(schema))); + int32_t projected_columns = visitor.Finish().size(); + if (max_inferred_columns < projected_columns) { + ICEBERG_ASSIGN_OR_RAISE(auto limit_field_ids, + LimitFieldIds(schema, max_inferred_columns)); + for (auto id : limit_field_ids) { + ICEBERG_ASSIGN_OR_RAISE(auto column_name, schema.FindColumnNameById(id)); + ICEBERG_CHECK(column_name.has_value(), "Field id {} not found in schema", id); + column_modes[std::string(column_name.value())] = kDefaultMetricsMode; + } + // All other columns don't use metrics + default_mode = MetricsMode::None(); + } + } + + // First set sorted column with sorted column default (can be overridden by user) + auto sorted_col_default_mode = SortedColumnDefaultMode(default_mode); + auto sorted_columns = SortOrder::OrderPreservingSortedColumns(schema, order); + for (const auto& sc : sorted_columns) { + column_modes[std::string(sc)] = sorted_col_default_mode; + } + + // Handle user overrides of defaults + for (const auto& prop : props.configs()) { + if (prop.first.starts_with(TableProperties::kMetricModeColumnConfPrefix)) { + std::string column_alias = + prop.first.substr(TableProperties::kMetricModeColumnConfPrefix.size()); + ICEBERG_ASSIGN_OR_RAISE(auto mode, ParseMode(prop.second, default_mode)); + column_modes[std::move(column_alias)] = mode; + } + } + + return std::make_shared(std::move(column_modes), + std::move(default_mode)); +} + +Result> MetricsConfig::LimitFieldIds(const Schema& schema, + int32_t limit) { + class Visitor { + public: + explicit Visitor(int32_t limit) : limit_(limit) {} + + Status Visit(const std::shared_ptr& type) { + if (type->is_nested()) { + return Visit(internal::checked_cast(*type)); + } + return {}; + } + + Status Visit(const NestedType& type) { + for (auto& field : type.fields()) { + if (!ShouldContinue()) { + break; + } + if (field.type()->is_primitive()) { + ids_.insert(field.field_id()); + } + } + + for (auto& field : type.fields()) { + if (ShouldContinue()) { + ICEBERG_RETURN_UNEXPECTED(Visit(field.type())); + } + } + return {}; + } + + std::unordered_set Finish() { return ids_; } + + private: + bool ShouldContinue() { return ids_.size() < limit_; } + + private: + std::unordered_set ids_; + int32_t limit_; + }; + + Visitor visitor(limit); + ICEBERG_RETURN_UNEXPECTED( + visitor.Visit(internal::checked_cast(schema))); + return visitor.Finish(); +} + Status MetricsConfig::VerifyReferencedColumns( const std::unordered_map& updates, const Schema& schema) { for (const auto& [key, value] : updates) { @@ -47,4 +267,12 @@ Status MetricsConfig::VerifyReferencedColumns( return {}; } +std::shared_ptr MetricsConfig::ColumnMode( + const std::string& column_name) const { + if (auto it = column_modes_.find(column_name); it != column_modes_.end()) { + return it->second; + } + return default_mode_; +} + } // namespace iceberg diff --git a/src/iceberg/metrics_config.h b/src/iceberg/metrics_config.h index c42539d64..f0ec9cdac 100644 --- a/src/iceberg/metrics_config.h +++ b/src/iceberg/metrics_config.h @@ -22,24 +22,120 @@ /// \file iceberg/metrics_config.h /// \brief Metrics configuration for Iceberg tables +#include #include #include +#include #include "iceberg/iceberg_export.h" #include "iceberg/result.h" #include "iceberg/type_fwd.h" +#include "iceberg/util/formattable.h" namespace iceberg { +class ICEBERG_EXPORT MetricsMode : public util::Formattable { + public: + enum class Kind : uint8_t { + kNone, + kCounts, + kTruncate, + kFull, + }; + + static Result> FromString(const std::string& mode); + + static const std::shared_ptr& None(); + static const std::shared_ptr& Counts(); + static const std::shared_ptr& Truncate(); + static const std::shared_ptr& Full(); + + /// \brief Return the kind of this metrics mode. + virtual Kind kind() const = 0; + + std::string ToString() const override = 0; +}; + +class ICEBERG_EXPORT NoneMetricsMode : public MetricsMode { + public: + constexpr Kind kind() const override { return Kind::kNone; } + + std::string ToString() const override; +}; + +class ICEBERG_EXPORT CountsMetricsMode : public MetricsMode { + public: + constexpr Kind kind() const override { return Kind::kCounts; } + + std::string ToString() const override; +}; + +class ICEBERG_EXPORT TruncateMetricsMode : public MetricsMode { + public: + explicit TruncateMetricsMode(int32_t length) : length_(length) {} + + constexpr Kind kind() const override { return Kind::kTruncate; } + + std::string ToString() const override; + + static Result> Make(int32_t length); + + private: + const int32_t length_; +}; + +class ICEBERG_EXPORT FullMetricsMode : public MetricsMode { + public: + constexpr Kind kind() const override { return Kind::kFull; } + + std::string ToString() const override; +}; + /// \brief Configuration utilities for table metrics class ICEBERG_EXPORT MetricsConfig { public: + MetricsConfig( + std::unordered_map> column_modes, + std::shared_ptr default_mode); + + /// \brief Get the default metrics config. + static const std::shared_ptr& Default(); + + /// \brief Creates a metrics config from a table. + static Result> Make(std::shared_ptr
table); + + /// \brief Get `limit` num of primitive field ids from schema + static Result> LimitFieldIds(const Schema& schema, + int32_t limit); + /// \brief Verify that all referenced columns are valid /// \param updates The updates to verify /// \param schema The schema to verify against /// \return OK if all referenced columns are valid static Status VerifyReferencedColumns( const std::unordered_map& updates, const Schema& schema); + + /// \brief Get the metrics mode for a specific column + /// \param column_name The full name of the column + /// \return The metrics mode for the column + std::shared_ptr ColumnMode(const std::string& column_name) const; + + private: + /// \brief Generate a MetricsConfig for all columns based on overrides, schema, and sort + /// order. + /// + /// \param props will be read for metrics overrides (write.metadata.metrics.column.*) + /// and default(write.metadata.metrics.default) + /// \param schema table schema + /// \param order sort order columns, will be promoted to truncate(16) + /// \return metrics configuration + static Result> MakeInternal(const TableProperties& props, + const Schema& schema, + const SortOrder& order); + + private: + std::unordered_map> column_modes_; + std::shared_ptr default_mode_; }; } // namespace iceberg diff --git a/src/iceberg/sort_order.cc b/src/iceberg/sort_order.cc index fca138a6c..b317efb90 100644 --- a/src/iceberg/sort_order.cc +++ b/src/iceberg/sort_order.cc @@ -132,4 +132,18 @@ Result> SortOrder::Make(int32_t sort_id, return std::unique_ptr(new SortOrder(sort_id, std::move(fields))); } +std::unordered_set SortOrder::OrderPreservingSortedColumns( + const Schema& schema, const SortOrder& order) { + return order.fields() | std::views::filter([&schema](const SortField& field) { + return field.transform()->PreservesOrder(); + }) | + std::views::transform([&schema](const SortField& field) { + return schema.FindColumnNameById(field.source_id()) + .value_or(std::nullopt) + .value_or(""); + }) | + std::views::filter([](std::string_view name) { return !name.empty(); }) | + std::ranges::to>(); +} + } // namespace iceberg diff --git a/src/iceberg/sort_order.h b/src/iceberg/sort_order.h index 1e7285d32..7c9b799fb 100644 --- a/src/iceberg/sort_order.h +++ b/src/iceberg/sort_order.h @@ -22,6 +22,7 @@ #include #include #include +#include #include #include "iceberg/iceberg_export.h" @@ -91,6 +92,9 @@ class ICEBERG_EXPORT SortOrder : public util::Formattable { static Result> Make(int32_t sort_id, std::vector fields); + static std::unordered_set OrderPreservingSortedColumns( + const Schema& schema, const SortOrder& order); + private: /// \brief Constructs a SortOrder instance. /// \param order_id The sort order id. diff --git a/src/iceberg/test/metrics_config_test.cc b/src/iceberg/test/metrics_config_test.cc index e6a8e0f51..bcd9421ce 100644 --- a/src/iceberg/test/metrics_config_test.cc +++ b/src/iceberg/test/metrics_config_test.cc @@ -28,30 +28,278 @@ #include "iceberg/result.h" #include "iceberg/schema.h" #include "iceberg/schema_field.h" +#include "iceberg/sort_order.h" +#include "iceberg/table.h" +#include "iceberg/table_metadata.h" #include "iceberg/table_properties.h" #include "iceberg/test/matchers.h" +#include "iceberg/test/mock_catalog.h" +#include "iceberg/test/mock_io.h" +#include "iceberg/transform.h" namespace iceberg { -class MetricsConfigTest : public ::testing::Test { - protected: - void SetUp() override { - SchemaField field1(1, "col1", std::make_shared(), false); - SchemaField field2(2, "col2", std::make_shared(), true); - SchemaField field3(3, "col3", std::make_shared(), false); - schema_ = - std::make_unique(std::vector{field1, field2, field3}, 100); +TEST(MetricsConfigTest, MetricsMode) { + EXPECT_EQ(MetricsMode::Kind::kNone, MetricsMode::None()->kind()); + EXPECT_EQ(MetricsMode::Kind::kCounts, MetricsMode::Counts()->kind()); + EXPECT_EQ(MetricsMode::Kind::kFull, MetricsMode::Full()->kind()); + EXPECT_EQ(MetricsMode::Kind::kTruncate, MetricsMode::Truncate()->kind()); + + EXPECT_EQ("none", MetricsMode::None()->ToString()); + EXPECT_EQ("counts", MetricsMode::Counts()->ToString()); + EXPECT_EQ("full", MetricsMode::Full()->ToString()); + EXPECT_EQ("truncate(16)", MetricsMode::Truncate()->ToString()); + + EXPECT_EQ(MetricsMode::Kind::kNone, MetricsMode::FromString("none").value()->kind()); + EXPECT_EQ(MetricsMode::Kind::kCounts, + MetricsMode::FromString("counts").value()->kind()); + EXPECT_EQ(MetricsMode::Kind::kFull, MetricsMode::FromString("full").value()->kind()); + EXPECT_EQ(MetricsMode::Kind::kTruncate, + MetricsMode::FromString("truncate(32)").value()->kind()); + + EXPECT_EQ("none", MetricsMode::FromString("none").value()->ToString()); + EXPECT_EQ("counts", MetricsMode::FromString("counts").value()->ToString()); + EXPECT_EQ("full", MetricsMode::FromString("full").value()->ToString()); + EXPECT_EQ("truncate(32)", MetricsMode::FromString("truncate(32)").value()->ToString()); + + auto result = MetricsMode::FromString("truncate(abc)"); + EXPECT_THAT(result, IsError(ErrorKind::kInvalidArgument)); + EXPECT_THAT(result, HasErrorMessage("Invalid truncate mode")); + + result = MetricsMode::FromString("truncate(-1)"); + EXPECT_THAT(result, IsError(ErrorKind::kInvalidArgument)); + EXPECT_THAT(result, HasErrorMessage("Truncate length should be positive")); + + result = MetricsMode::FromString("invalid"); + EXPECT_THAT(result, IsError(ErrorKind::kInvalidArgument)); + EXPECT_THAT(result, HasErrorMessage("Invalid metrics mode")); +} + +TEST(MetricsConfigTest, ForTable) { + { + // table is nullptr + auto result = MetricsConfig::Make(nullptr); + EXPECT_THAT(result, IsError(ErrorKind::kInvalidArgument)); + EXPECT_THAT(result, HasErrorMessage("table cannot be null")); + } + + auto io = std::make_shared(); + auto catalog = std::make_shared(); + auto schema = std::make_shared( + std::vector{SchemaField::MakeRequired(1, "id", int64()), + SchemaField::MakeOptional(2, "name", string()), + SchemaField::MakeOptional(3, "addr", string())}, + 1); + TableIdentifier ident{.ns = Namespace{.levels = {"db"}}, .name = "t"}; + + { + // Default + auto metadata = std::make_shared( + TableMetadata{.format_version = 2, .schemas = {schema}, .current_schema_id = 1}); + ICEBERG_UNWRAP_OR_FAIL( + auto table, Table::Make(ident, metadata, "s3://bucket/meta.json", io, catalog)); + + ICEBERG_UNWRAP_OR_FAIL(auto config, MetricsConfig::Make(table)); + auto mode = config->ColumnMode("id"); + EXPECT_EQ(MetricsMode::Kind::kTruncate, mode->kind()); + EXPECT_EQ("truncate(16)", mode->ToString()); + + mode = config->ColumnMode("name"); + EXPECT_EQ(MetricsMode::Kind::kTruncate, mode->kind()); + EXPECT_EQ("truncate(16)", mode->ToString()); + + mode = config->ColumnMode("addr"); + EXPECT_EQ(MetricsMode::Kind::kTruncate, mode->kind()); + EXPECT_EQ("truncate(16)", mode->ToString()); } - std::unique_ptr schema_; -}; + { + // Custom metrics mode by set default metrics mode properties + auto metadata = std::make_shared( + TableMetadata{.format_version = 2, + .schemas = {schema}, + .current_schema_id = 1, + .properties = TableProperties::FromMap( + {{TableProperties::kDefaultWriteMetricsMode.key(), "full"}})}); + ICEBERG_UNWRAP_OR_FAIL( + auto table, Table::Make(ident, metadata, "s3://bucket/meta.json", io, catalog)); + + ICEBERG_UNWRAP_OR_FAIL(auto config, MetricsConfig::Make(table)); + auto mode = config->ColumnMode("id"); + EXPECT_EQ(MetricsMode::Kind::kFull, mode->kind()); + EXPECT_EQ("full", mode->ToString()); + + mode = config->ColumnMode("name"); + EXPECT_EQ(MetricsMode::Kind::kFull, mode->kind()); + EXPECT_EQ("full", mode->ToString()); + + mode = config->ColumnMode("addr"); + EXPECT_EQ(MetricsMode::Kind::kFull, mode->kind()); + EXPECT_EQ("full", mode->ToString()); + } + + { + // Custom metrics mode by set column's metrics mode + ICEBERG_UNWRAP_OR_FAIL( + std::shared_ptr sort_order, + SortOrder::Make(*schema, 1, + std::vector( + {SortField(1, Transform::Identity(), + SortDirection::kAscending, NullOrder::kLast)}))); + + auto metadata = std::make_shared(TableMetadata{ + .format_version = 2, + .schemas = {schema}, + .current_schema_id = 1, + + .properties = TableProperties::FromMap( + {{TableProperties::kDefaultWriteMetricsMode.key(), "none"}, + {TableProperties::kMetricsMaxInferredColumnDefaults.key(), "2"}, + {std::string(TableProperties::kMetricModeColumnConfPrefix) + "name", + "full"}}), + .sort_orders = {sort_order}, + .default_sort_order_id = 1, + }); + + ICEBERG_UNWRAP_OR_FAIL( + auto table, Table::Make(ident, metadata, "s3://bucket/meta.json", io, catalog)); + + ICEBERG_UNWRAP_OR_FAIL(auto config, MetricsConfig::Make(table)); + auto mode = config->ColumnMode("id"); + EXPECT_EQ(MetricsMode::Kind::kTruncate, mode->kind()); + EXPECT_EQ("truncate(16)", mode->ToString()); + + mode = config->ColumnMode("name"); + EXPECT_EQ(MetricsMode::Kind::kFull, mode->kind()); + EXPECT_EQ("full", mode->ToString()); + + mode = config->ColumnMode("addr"); + EXPECT_EQ(MetricsMode::Kind::kNone, mode->kind()); + EXPECT_EQ("none", mode->ToString()); + } +} + +TEST(MetricsConfigTest, LimitFieldIds) { + { + // Nested struct type + // Create nested struct type for level1_struct_a + auto level2_struct_a_type = std::make_shared(std::vector{ + SchemaField(31, "level3_primitive_s", std::make_shared(), true)}); + + auto level1_struct_a_type = std::make_shared(std::vector{ + SchemaField(21, "level2_primitive_i", std::make_shared(), false), + SchemaField(22, "level2_struct_a", level2_struct_a_type, false), + SchemaField(23, "level2_primitive_b", std::make_shared(), true)}); + + // Create nested struct type for level1_struct_b + auto level2_struct_b_type = std::make_shared(std::vector{ + SchemaField(32, "level3_primitive_s", std::make_shared(), true)}); + + auto level1_struct_b_type = std::make_shared(std::vector{ + SchemaField(24, "level2_primitive_i", std::make_shared(), false), + SchemaField(25, "level2_struct_b", level2_struct_b_type, false)}); + + // Create the main schema + Schema schema( + std::vector{ + SchemaField(11, "level1_struct_a", level1_struct_a_type, false), + SchemaField(12, "level1_struct_b", level1_struct_b_type, false), + SchemaField(13, "level1_primitive_i", std::make_shared(), false)}, + 100); + + auto result1 = MetricsConfig::LimitFieldIds(schema, 1); + EXPECT_EQ(result1, (std::unordered_set{13})) + << "Should only include top level primitive field"; + + auto result2 = MetricsConfig::LimitFieldIds(schema, 2); + EXPECT_EQ(result2, (std::unordered_set{13, 21})) + << "Should include level 2 primitive field before nested struct"; + + auto result3 = MetricsConfig::LimitFieldIds(schema, 3); + EXPECT_EQ(result3, (std::unordered_set{13, 21, 23})) + << "Should include all of level 2 primitive fields of struct a before nested " + "struct"; + + auto result4 = MetricsConfig::LimitFieldIds(schema, 4); + EXPECT_EQ(result4, (std::unordered_set{13, 21, 23, 31})) + << "Should include all eligible fields in struct a"; + + auto result5 = MetricsConfig::LimitFieldIds(schema, 5); + EXPECT_EQ(result5, (std::unordered_set{13, 21, 23, 31, 24})) + << "Should include first primitive field in struct b"; + + auto result6 = MetricsConfig::LimitFieldIds(schema, 6); + EXPECT_EQ(result6, (std::unordered_set{13, 21, 23, 31, 24, 32})) + << "Should include all primitive fields"; + + auto result7 = MetricsConfig::LimitFieldIds(schema, 7); + EXPECT_EQ(result7, (std::unordered_set{13, 21, 23, 31, 24, 32})) + << "Should return all primitive fields when limit is higher"; + } + + { + // Nested map + auto map_type = std::make_shared( + SchemaField(2, "key", std::make_shared(), false), + SchemaField(3, "value", std::make_shared(), false)); + + Schema schema( + std::vector{ + SchemaField(1, "map", map_type, false), + SchemaField(4, "top", std::make_shared(), false)}, + 100); + + auto result1 = MetricsConfig::LimitFieldIds(schema, 1); + EXPECT_EQ(result1, (std::unordered_set{4})); + + auto result2 = MetricsConfig::LimitFieldIds(schema, 2); + EXPECT_EQ(result2, (std::unordered_set{4, 2})); + + auto result3 = MetricsConfig::LimitFieldIds(schema, 3); + EXPECT_EQ(result3, (std::unordered_set{4, 2, 3})); + + auto result4 = MetricsConfig::LimitFieldIds(schema, 4); + EXPECT_EQ(result4, (std::unordered_set{4, 2, 3})); + } + + { + // Nested list of maps + auto map_type = std::make_shared( + SchemaField(3, "key", std::make_shared(), false), + SchemaField(4, "value", std::make_shared(), false)); + auto list_type = std::make_shared(2, map_type, false); + + Schema schema( + std::vector{ + SchemaField(1, "array_of_maps", list_type, false), + SchemaField(5, "top", std::make_shared(), false)}, + 100); + + auto result1 = MetricsConfig::LimitFieldIds(schema, 1); + EXPECT_EQ(result1, (std::unordered_set{5})); + + auto result2 = MetricsConfig::LimitFieldIds(schema, 2); + EXPECT_EQ(result2, (std::unordered_set{5, 3})); + + auto result3 = MetricsConfig::LimitFieldIds(schema, 3); + EXPECT_EQ(result3, (std::unordered_set{5, 3, 4})); + + auto result4 = MetricsConfig::LimitFieldIds(schema, 4); + EXPECT_EQ(result4, (std::unordered_set{5, 3, 4})); + } +} + +TEST(MetricsConfigTest, ValidateColumnReferences) { + SchemaField field1(1, "col1", std::make_shared(), false); + SchemaField field2(2, "col2", std::make_shared(), true); + SchemaField field3(3, "col3", std::make_shared(), false); + Schema schema(std::vector{field1, field2, field3}, 100); -TEST_F(MetricsConfigTest, ValidateColumnReferences) { { // Empty updates should be valid std::unordered_map updates; - auto result = MetricsConfig::VerifyReferencedColumns(updates, *schema_); + auto result = MetricsConfig::VerifyReferencedColumns(updates, schema); EXPECT_THAT(result, IsOk()) << "Validation should pass for empty updates"; } @@ -61,7 +309,7 @@ TEST_F(MetricsConfigTest, ValidateColumnReferences) { updates["write.format.default"] = "parquet"; updates["write.target-file-size-bytes"] = "524288000"; - auto result = MetricsConfig::VerifyReferencedColumns(updates, *schema_); + auto result = MetricsConfig::VerifyReferencedColumns(updates, schema); EXPECT_THAT(result, IsOk()) << "Validation should pass when no column references exist"; } @@ -74,7 +322,7 @@ TEST_F(MetricsConfigTest, ValidateColumnReferences) { updates[std::string(TableProperties::kMetricModeColumnConfPrefix) + "col2"] = "full"; updates["some.other.property"] = "value"; - auto result = MetricsConfig::VerifyReferencedColumns(updates, *schema_); + auto result = MetricsConfig::VerifyReferencedColumns(updates, schema); EXPECT_THAT(result, IsOk()) << "Validation should pass for valid column references"; } @@ -84,7 +332,7 @@ TEST_F(MetricsConfigTest, ValidateColumnReferences) { updates[std::string(TableProperties::kMetricModeColumnConfPrefix) + "nonexistent"] = "counts"; - auto result = MetricsConfig::VerifyReferencedColumns(updates, *schema_); + auto result = MetricsConfig::VerifyReferencedColumns(updates, schema); EXPECT_THAT(result, IsError(ErrorKind::kValidationFailed)) << "Validation should fail for invalid column references"; } @@ -97,7 +345,7 @@ TEST_F(MetricsConfigTest, ValidateColumnReferences) { updates[std::string(TableProperties::kMetricModeColumnConfPrefix) + "nonexistent"] = "full"; - auto result = MetricsConfig::VerifyReferencedColumns(updates, *schema_); + auto result = MetricsConfig::VerifyReferencedColumns(updates, schema); EXPECT_THAT(result, IsError(ErrorKind::kValidationFailed)) << "Validation should fail when any column reference is invalid"; } diff --git a/src/iceberg/util/type_util.cc b/src/iceberg/util/type_util.cc index a6cfd645a..f3372bd1d 100644 --- a/src/iceberg/util/type_util.cc +++ b/src/iceberg/util/type_util.cc @@ -271,6 +271,31 @@ Result> PruneColumnVisitor::Visit( MakeField(value_field, std::move(value_type))); } +GetProjectedIdsVisitor::GetProjectedIdsVisitor(bool include_struct_ids) + : include_struct_ids_(include_struct_ids) {} + +Status GetProjectedIdsVisitor::Visit(const std::shared_ptr& type) { + if (type->is_nested()) { + return Visit(internal::checked_cast(*type)); + } + return {}; +} + +Status GetProjectedIdsVisitor::Visit(const NestedType& type) { + for (auto& field : type.fields()) { + ICEBERG_RETURN_UNEXPECTED(Visit(field.type())); + } + for (auto& field : type.fields()) { + if ((include_struct_ids_ && field.type()->type_id() == TypeId::kStruct) || + field.type()->is_primitive()) { + ids_.insert(field.field_id()); + } + } + return {}; +} + +std::unordered_set GetProjectedIdsVisitor::Finish() const { return ids_; } + std::unordered_map IndexParents(const StructType& root_struct) { std::unordered_map id_to_parent; std::stack parent_id_stack; diff --git a/src/iceberg/util/type_util.h b/src/iceberg/util/type_util.h index 959bdb9f9..4151406b7 100644 --- a/src/iceberg/util/type_util.h +++ b/src/iceberg/util/type_util.h @@ -122,6 +122,20 @@ class PruneColumnVisitor { const bool select_full_types_; }; +/// \brief Visitor for get field IDs which could be used for projection. +class GetProjectedIdsVisitor { + public: + explicit GetProjectedIdsVisitor(bool include_struct_ids = false); + + Status Visit(const std::shared_ptr& type); + Status Visit(const NestedType& type); + std::unordered_set Finish() const; + + private: + const bool include_struct_ids_; + std::unordered_set ids_; +}; + /// \brief Index parent field IDs for all fields in a struct hierarchy. /// \param root_struct The root struct type to analyze /// \return A map from field ID to its parent struct field ID From 57dbec1f2e1e0c256817dc213a4390a77c511525 Mon Sep 17 00:00:00 2001 From: Zhuo Wang Date: Fri, 9 Jan 2026 14:37:29 +0800 Subject: [PATCH 2/2] fix comments --- src/iceberg/metrics_config.cc | 82 +++++++++---------------- src/iceberg/metrics_config.h | 59 ++++-------------- src/iceberg/test/metrics_config_test.cc | 73 ++++++++-------------- src/iceberg/util/string_util.h | 7 +++ src/iceberg/util/type_util.cc | 14 ++--- src/iceberg/util/type_util.h | 5 +- src/iceberg/util/visit_type.h | 10 +++ 7 files changed, 92 insertions(+), 158 deletions(-) diff --git a/src/iceberg/metrics_config.cc b/src/iceberg/metrics_config.cc index 2d56b804e..d783e1568 100644 --- a/src/iceberg/metrics_config.cc +++ b/src/iceberg/metrics_config.cc @@ -30,6 +30,7 @@ #include "iceberg/table_properties.h" #include "iceberg/util/checked_cast.h" #include "iceberg/util/type_util.h" +#include "iceberg/util/visit_type.h" namespace iceberg { @@ -41,12 +42,12 @@ constexpr std::string_view kFullName = "full"; constexpr std::string_view kTruncatePrefix = "truncate("; constexpr int32_t kDefaultTruncateLength = 16; const std::shared_ptr kDefaultMetricsMode = - std::make_shared(kDefaultTruncateLength); + std::make_shared(MetricsMode::Kind::kTruncate, kDefaultTruncateLength); std::shared_ptr SortedColumnDefaultMode( std::shared_ptr default_mode) { - if (default_mode->kind() == MetricsMode::Kind::kNone || - default_mode->kind() == MetricsMode::Kind::kCounts) { + if (default_mode->kind == MetricsMode::Kind::kNone || + default_mode->kind == MetricsMode::Kind::kCounts) { return kDefaultMetricsMode; } else { return std::move(default_mode); @@ -74,26 +75,21 @@ Result> ParseMode(const std::string& mode, } // namespace const std::shared_ptr& MetricsMode::None() { - static const std::shared_ptr none = std::make_shared(); + static const auto none = std::make_shared(Kind::kNone); return none; } const std::shared_ptr& MetricsMode::Counts() { - static const std::shared_ptr counts = - std::make_shared(); + static const auto counts = std::make_shared(Kind::kCounts); return counts; } const std::shared_ptr& MetricsMode::Full() { - static const std::shared_ptr full = std::make_shared(); + static const auto full = std::make_shared(Kind::kFull); return full; } -const std::shared_ptr& MetricsMode::Truncate() { - return kDefaultMetricsMode; -} - -Result> MetricsMode::FromString(const std::string& mode) { +Result> MetricsMode::FromString(std::string_view mode) { if (StringUtils::EqualsIgnoreCase(mode, kNoneName)) { return MetricsMode::None(); } else if (StringUtils::EqualsIgnoreCase(mode, kCountsName)) { @@ -102,7 +98,7 @@ Result> MetricsMode::FromString(const std::string& return MetricsMode::Full(); } - if (mode.starts_with(kTruncatePrefix) && mode.ends_with(")")) { + if (StringUtils::StartsWithIgnoreCase(mode, kTruncatePrefix) && mode.ends_with(")")) { int32_t length; auto [ptr, ec] = std::from_chars(mode.data() + 9 /* "truncate(" length */, mode.data() + mode.size() - 1, length); @@ -112,42 +108,29 @@ Result> MetricsMode::FromString(const std::string& if (length == kDefaultTruncateLength) { return kDefaultMetricsMode; } - return TruncateMetricsMode::Make(length); + ICEBERG_PRECHECK(length > 0, "Truncate length should be positive."); + return std::make_shared(Kind::kTruncate, length); } return InvalidArgument("Invalid metrics mode: {}", mode); } -std::string NoneMetricsMode::ToString() const { return std::string(kNoneName); } -std::string CountsMetricsMode::ToString() const { return std::string(kCountsName); } -std::string FullMetricsMode::ToString() const { return std::string(kFullName); } -std::string TruncateMetricsMode::ToString() const { - return std::format("truncate({})", length_); -} - -Result> TruncateMetricsMode::Make(int32_t length) { - ICEBERG_PRECHECK(length > 0, "Truncate length should be positive."); - return std::make_shared(length); -} - MetricsConfig::MetricsConfig( std::unordered_map> column_modes, std::shared_ptr default_mode) : column_modes_(std::move(column_modes)), default_mode_(std::move(default_mode)) {} const std::shared_ptr& MetricsConfig::Default() { - static const auto default_config = std::make_shared( - std::unordered_map>{}, - kDefaultMetricsMode); + static const std::shared_ptr default_config( + new MetricsConfig({}, kDefaultMetricsMode)); return default_config; } -Result> MetricsConfig::Make(std::shared_ptr
table) { - ICEBERG_PRECHECK(table != nullptr, "table cannot be null"); - ICEBERG_ASSIGN_OR_RAISE(auto schema, table->schema()); +Result> MetricsConfig::Make(const Table& table) { + ICEBERG_ASSIGN_OR_RAISE(auto schema, table.schema()); - auto sort_order = table->sort_order(); + auto sort_order = table.sort_order(); return MakeInternal( - table->properties(), *schema, + table.properties(), *schema, sort_order.has_value() ? *sort_order.value() : *SortOrder::Unsorted()); } @@ -197,8 +180,8 @@ Result> MetricsConfig::MakeInternal( } } - return std::make_shared(std::move(column_modes), - std::move(default_mode)); + return std::shared_ptr( + new MetricsConfig(std::move(column_modes), std::move(default_mode))); } Result> MetricsConfig::LimitFieldIds(const Schema& schema, @@ -207,18 +190,14 @@ Result> MetricsConfig::LimitFieldIds(const Schema& s public: explicit Visitor(int32_t limit) : limit_(limit) {} - Status Visit(const std::shared_ptr& type) { - if (type->is_nested()) { - return Visit(internal::checked_cast(*type)); - } - return {}; - } + Status Visit(const Type& type) { return VisitNestedType(type, this); } - Status Visit(const NestedType& type) { + Status VisitNested(const NestedType& type) { for (auto& field : type.fields()) { if (!ShouldContinue()) { break; } + // TODO(zhuo.wang) or is_variant if (field.type()->is_primitive()) { ids_.insert(field.field_id()); } @@ -226,12 +205,14 @@ Result> MetricsConfig::LimitFieldIds(const Schema& s for (auto& field : type.fields()) { if (ShouldContinue()) { - ICEBERG_RETURN_UNEXPECTED(Visit(field.type())); + ICEBERG_RETURN_UNEXPECTED(Visit(*field.type())); } } return {}; } + Status VisitNonNested(const Type& type) { return {}; } + std::unordered_set Finish() { return ids_; } private: @@ -243,8 +224,7 @@ Result> MetricsConfig::LimitFieldIds(const Schema& s }; Visitor visitor(limit); - ICEBERG_RETURN_UNEXPECTED( - visitor.Visit(internal::checked_cast(schema))); + ICEBERG_RETURN_UNEXPECTED(visitor.Visit(internal::checked_cast(schema))); return visitor.Finish(); } @@ -257,12 +237,10 @@ Status MetricsConfig::VerifyReferencedColumns( auto field_name = std::string_view(key).substr(TableProperties::kMetricModeColumnConfPrefix.size()); ICEBERG_ASSIGN_OR_RAISE(auto field, schema.FindFieldByName(field_name)); - if (!field.has_value()) { - return ValidationFailed( - "Invalid metrics config, could not find column {} from table prop {} in " - "schema {}", - field_name, key, schema.ToString()); - } + ICEBERG_CHECK(field.has_value(), + "Invalid metrics config, could not find column {} from table prop {} " + "in schema {}", + field_name, key, schema.ToString()); } return {}; } diff --git a/src/iceberg/metrics_config.h b/src/iceberg/metrics_config.h index f0ec9cdac..ab4942a05 100644 --- a/src/iceberg/metrics_config.h +++ b/src/iceberg/metrics_config.h @@ -24,17 +24,18 @@ #include #include +#include #include #include +#include #include "iceberg/iceberg_export.h" #include "iceberg/result.h" #include "iceberg/type_fwd.h" -#include "iceberg/util/formattable.h" namespace iceberg { -class ICEBERG_EXPORT MetricsMode : public util::Formattable { +struct ICEBERG_EXPORT MetricsMode { public: enum class Kind : uint8_t { kNone, @@ -43,66 +44,24 @@ class ICEBERG_EXPORT MetricsMode : public util::Formattable { kFull, }; - static Result> FromString(const std::string& mode); + static Result> FromString(std::string_view mode); static const std::shared_ptr& None(); static const std::shared_ptr& Counts(); - static const std::shared_ptr& Truncate(); static const std::shared_ptr& Full(); - /// \brief Return the kind of this metrics mode. - virtual Kind kind() const = 0; - - std::string ToString() const override = 0; -}; - -class ICEBERG_EXPORT NoneMetricsMode : public MetricsMode { - public: - constexpr Kind kind() const override { return Kind::kNone; } - - std::string ToString() const override; -}; - -class ICEBERG_EXPORT CountsMetricsMode : public MetricsMode { - public: - constexpr Kind kind() const override { return Kind::kCounts; } - - std::string ToString() const override; -}; - -class ICEBERG_EXPORT TruncateMetricsMode : public MetricsMode { - public: - explicit TruncateMetricsMode(int32_t length) : length_(length) {} - - constexpr Kind kind() const override { return Kind::kTruncate; } - - std::string ToString() const override; - - static Result> Make(int32_t length); - - private: - const int32_t length_; -}; - -class ICEBERG_EXPORT FullMetricsMode : public MetricsMode { - public: - constexpr Kind kind() const override { return Kind::kFull; } - - std::string ToString() const override; + Kind kind; + std::variant length; }; /// \brief Configuration utilities for table metrics class ICEBERG_EXPORT MetricsConfig { public: - MetricsConfig( - std::unordered_map> column_modes, - std::shared_ptr default_mode); - /// \brief Get the default metrics config. static const std::shared_ptr& Default(); /// \brief Creates a metrics config from a table. - static Result> Make(std::shared_ptr
table); + static Result> Make(const Table& table); /// \brief Get `limit` num of primitive field ids from schema static Result> LimitFieldIds(const Schema& schema, @@ -121,6 +80,10 @@ class ICEBERG_EXPORT MetricsConfig { std::shared_ptr ColumnMode(const std::string& column_name) const; private: + MetricsConfig( + std::unordered_map> column_modes, + std::shared_ptr default_mode); + /// \brief Generate a MetricsConfig for all columns based on overrides, schema, and sort /// order. /// diff --git a/src/iceberg/test/metrics_config_test.cc b/src/iceberg/test/metrics_config_test.cc index bcd9421ce..5dd80cc22 100644 --- a/src/iceberg/test/metrics_config_test.cc +++ b/src/iceberg/test/metrics_config_test.cc @@ -40,27 +40,15 @@ namespace iceberg { TEST(MetricsConfigTest, MetricsMode) { - EXPECT_EQ(MetricsMode::Kind::kNone, MetricsMode::None()->kind()); - EXPECT_EQ(MetricsMode::Kind::kCounts, MetricsMode::Counts()->kind()); - EXPECT_EQ(MetricsMode::Kind::kFull, MetricsMode::Full()->kind()); - EXPECT_EQ(MetricsMode::Kind::kTruncate, MetricsMode::Truncate()->kind()); - - EXPECT_EQ("none", MetricsMode::None()->ToString()); - EXPECT_EQ("counts", MetricsMode::Counts()->ToString()); - EXPECT_EQ("full", MetricsMode::Full()->ToString()); - EXPECT_EQ("truncate(16)", MetricsMode::Truncate()->ToString()); - - EXPECT_EQ(MetricsMode::Kind::kNone, MetricsMode::FromString("none").value()->kind()); - EXPECT_EQ(MetricsMode::Kind::kCounts, - MetricsMode::FromString("counts").value()->kind()); - EXPECT_EQ(MetricsMode::Kind::kFull, MetricsMode::FromString("full").value()->kind()); - EXPECT_EQ(MetricsMode::Kind::kTruncate, - MetricsMode::FromString("truncate(32)").value()->kind()); + EXPECT_EQ(MetricsMode::Kind::kNone, MetricsMode::None()->kind); + EXPECT_EQ(MetricsMode::Kind::kCounts, MetricsMode::Counts()->kind); + EXPECT_EQ(MetricsMode::Kind::kFull, MetricsMode::Full()->kind); - EXPECT_EQ("none", MetricsMode::FromString("none").value()->ToString()); - EXPECT_EQ("counts", MetricsMode::FromString("counts").value()->ToString()); - EXPECT_EQ("full", MetricsMode::FromString("full").value()->ToString()); - EXPECT_EQ("truncate(32)", MetricsMode::FromString("truncate(32)").value()->ToString()); + EXPECT_EQ(MetricsMode::Kind::kNone, MetricsMode::FromString("none").value()->kind); + EXPECT_EQ(MetricsMode::Kind::kCounts, MetricsMode::FromString("counts").value()->kind); + EXPECT_EQ(MetricsMode::Kind::kFull, MetricsMode::FromString("full").value()->kind); + EXPECT_EQ(MetricsMode::Kind::kTruncate, + MetricsMode::FromString("truncate(32)").value()->kind); auto result = MetricsMode::FromString("truncate(abc)"); EXPECT_THAT(result, IsError(ErrorKind::kInvalidArgument)); @@ -76,13 +64,6 @@ TEST(MetricsConfigTest, MetricsMode) { } TEST(MetricsConfigTest, ForTable) { - { - // table is nullptr - auto result = MetricsConfig::Make(nullptr); - EXPECT_THAT(result, IsError(ErrorKind::kInvalidArgument)); - EXPECT_THAT(result, HasErrorMessage("table cannot be null")); - } - auto io = std::make_shared(); auto catalog = std::make_shared(); auto schema = std::make_shared( @@ -99,18 +80,17 @@ TEST(MetricsConfigTest, ForTable) { ICEBERG_UNWRAP_OR_FAIL( auto table, Table::Make(ident, metadata, "s3://bucket/meta.json", io, catalog)); - ICEBERG_UNWRAP_OR_FAIL(auto config, MetricsConfig::Make(table)); + ICEBERG_UNWRAP_OR_FAIL(auto config, MetricsConfig::Make(*table)); auto mode = config->ColumnMode("id"); - EXPECT_EQ(MetricsMode::Kind::kTruncate, mode->kind()); - EXPECT_EQ("truncate(16)", mode->ToString()); + EXPECT_EQ(MetricsMode::Kind::kTruncate, mode->kind); + EXPECT_EQ(16, std::get(mode->length)); mode = config->ColumnMode("name"); - EXPECT_EQ(MetricsMode::Kind::kTruncate, mode->kind()); - EXPECT_EQ("truncate(16)", mode->ToString()); - + EXPECT_EQ(MetricsMode::Kind::kTruncate, mode->kind); + EXPECT_EQ(16, std::get(mode->length)); mode = config->ColumnMode("addr"); - EXPECT_EQ(MetricsMode::Kind::kTruncate, mode->kind()); - EXPECT_EQ("truncate(16)", mode->ToString()); + EXPECT_EQ(MetricsMode::Kind::kTruncate, mode->kind); + EXPECT_EQ(16, std::get(mode->length)); } { @@ -124,18 +104,15 @@ TEST(MetricsConfigTest, ForTable) { ICEBERG_UNWRAP_OR_FAIL( auto table, Table::Make(ident, metadata, "s3://bucket/meta.json", io, catalog)); - ICEBERG_UNWRAP_OR_FAIL(auto config, MetricsConfig::Make(table)); + ICEBERG_UNWRAP_OR_FAIL(auto config, MetricsConfig::Make(*table)); auto mode = config->ColumnMode("id"); - EXPECT_EQ(MetricsMode::Kind::kFull, mode->kind()); - EXPECT_EQ("full", mode->ToString()); + EXPECT_EQ(MetricsMode::Kind::kFull, mode->kind); mode = config->ColumnMode("name"); - EXPECT_EQ(MetricsMode::Kind::kFull, mode->kind()); - EXPECT_EQ("full", mode->ToString()); + EXPECT_EQ(MetricsMode::Kind::kFull, mode->kind); mode = config->ColumnMode("addr"); - EXPECT_EQ(MetricsMode::Kind::kFull, mode->kind()); - EXPECT_EQ("full", mode->ToString()); + EXPECT_EQ(MetricsMode::Kind::kFull, mode->kind); } { @@ -164,18 +141,16 @@ TEST(MetricsConfigTest, ForTable) { ICEBERG_UNWRAP_OR_FAIL( auto table, Table::Make(ident, metadata, "s3://bucket/meta.json", io, catalog)); - ICEBERG_UNWRAP_OR_FAIL(auto config, MetricsConfig::Make(table)); + ICEBERG_UNWRAP_OR_FAIL(auto config, MetricsConfig::Make(*table)); auto mode = config->ColumnMode("id"); - EXPECT_EQ(MetricsMode::Kind::kTruncate, mode->kind()); - EXPECT_EQ("truncate(16)", mode->ToString()); + EXPECT_EQ(MetricsMode::Kind::kTruncate, mode->kind); + EXPECT_EQ(16, std::get(mode->length)); mode = config->ColumnMode("name"); - EXPECT_EQ(MetricsMode::Kind::kFull, mode->kind()); - EXPECT_EQ("full", mode->ToString()); + EXPECT_EQ(MetricsMode::Kind::kFull, mode->kind); mode = config->ColumnMode("addr"); - EXPECT_EQ(MetricsMode::Kind::kNone, mode->kind()); - EXPECT_EQ("none", mode->ToString()); + EXPECT_EQ(MetricsMode::Kind::kNone, mode->kind); } } diff --git a/src/iceberg/util/string_util.h b/src/iceberg/util/string_util.h index 8aa209c94..0c9e89bc7 100644 --- a/src/iceberg/util/string_util.h +++ b/src/iceberg/util/string_util.h @@ -44,6 +44,13 @@ class ICEBERG_EXPORT StringUtils { lhs, rhs, [](char lc, char rc) { return std::tolower(lc) == std::tolower(rc); }); } + static bool StartsWithIgnoreCase(std::string_view str, std::string_view prefix) { + if (str.size() < prefix.size()) { + return false; + } + return EqualsIgnoreCase(str.substr(0, prefix.size()), prefix); + } + /// \brief Count the number of code points in a UTF-8 string. static size_t CodePointCount(std::string_view str) { size_t count = 0; diff --git a/src/iceberg/util/type_util.cc b/src/iceberg/util/type_util.cc index f3372bd1d..0c598d416 100644 --- a/src/iceberg/util/type_util.cc +++ b/src/iceberg/util/type_util.cc @@ -274,18 +274,16 @@ Result> PruneColumnVisitor::Visit( GetProjectedIdsVisitor::GetProjectedIdsVisitor(bool include_struct_ids) : include_struct_ids_(include_struct_ids) {} -Status GetProjectedIdsVisitor::Visit(const std::shared_ptr& type) { - if (type->is_nested()) { - return Visit(internal::checked_cast(*type)); - } - return {}; +Status GetProjectedIdsVisitor::Visit(const Type& type) { + return VisitNestedType(type, this); } -Status GetProjectedIdsVisitor::Visit(const NestedType& type) { +Status GetProjectedIdsVisitor::VisitNested(const NestedType& type) { for (auto& field : type.fields()) { - ICEBERG_RETURN_UNEXPECTED(Visit(field.type())); + ICEBERG_RETURN_UNEXPECTED(Visit(*field.type())); } for (auto& field : type.fields()) { + // TODO(zhuo.wang) or is_variant if ((include_struct_ids_ && field.type()->type_id() == TypeId::kStruct) || field.type()->is_primitive()) { ids_.insert(field.field_id()); @@ -294,6 +292,8 @@ Status GetProjectedIdsVisitor::Visit(const NestedType& type) { return {}; } +Status GetProjectedIdsVisitor::VisitNonNested(const Type& type) { return {}; } + std::unordered_set GetProjectedIdsVisitor::Finish() const { return ids_; } std::unordered_map IndexParents(const StructType& root_struct) { diff --git a/src/iceberg/util/type_util.h b/src/iceberg/util/type_util.h index 4151406b7..be16c3ee8 100644 --- a/src/iceberg/util/type_util.h +++ b/src/iceberg/util/type_util.h @@ -127,8 +127,9 @@ class GetProjectedIdsVisitor { public: explicit GetProjectedIdsVisitor(bool include_struct_ids = false); - Status Visit(const std::shared_ptr& type); - Status Visit(const NestedType& type); + Status Visit(const Type& type); + Status VisitNested(const NestedType& type); + Status VisitNonNested(const Type& type); std::unordered_set Finish() const; private: diff --git a/src/iceberg/util/visit_type.h b/src/iceberg/util/visit_type.h index bf52d2e9a..c2231408e 100644 --- a/src/iceberg/util/visit_type.h +++ b/src/iceberg/util/visit_type.h @@ -158,8 +158,18 @@ inline auto VisitTypeCategory(const Type& type, VISITOR* visitor, ARGS&&... args switch (type.type_id()) { ICEBERG_TYPE_SWITCH_WITH_PRIMITIVE_DEFAULT(SCHEMA_VISIT_ACTION) } +} #undef SCHEMA_VISIT_ACTION + +template +inline auto VisitNestedType(const Type& type, VISITOR* visitor, ARGS&&... args) { + if (type.is_nested()) { + return visitor->VisitNested(internal::checked_cast(type), + std::forward(args)...); + } else { + return visitor->VisitNonNested(type, std::forward(args)...); + } } } // namespace iceberg