Skip to content

Commit 440e500

Browse files
authored
Revert "DPL Analysis: prevent slice cache from updating when not required by …" (#14252)
1 parent 689970d commit 440e500

File tree

9 files changed

+73
-96
lines changed

9 files changed

+73
-96
lines changed

Framework/Core/include/Framework/ASoA.h

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1400,10 +1400,10 @@ namespace o2::framework
14001400

14011401
struct PreslicePolicyBase {
14021402
const std::string binding;
1403-
Entry bindingKey;
1403+
StringPair bindingKey;
14041404

14051405
bool isMissing() const;
1406-
Entry const& getBindingKey() const;
1406+
StringPair const& getBindingKey() const;
14071407
};
14081408

14091409
struct PreslicePolicySorted : public PreslicePolicyBase {
@@ -1428,7 +1428,7 @@ struct PresliceBase : public Policy {
14281428
const std::string binding;
14291429

14301430
PresliceBase(expressions::BindingNode index_)
1431-
: Policy{PreslicePolicyBase{{o2::soa::getLabelFromTypeForKey<T, OPT>(std::string{index_.name})}, Entry(o2::soa::getLabelFromTypeForKey<T, OPT>(std::string{index_.name}), std::string{index_.name})}, {}}
1431+
: Policy{PreslicePolicyBase{{o2::soa::getLabelFromTypeForKey<T, OPT>(std::string{index_.name})}, std::make_pair(o2::soa::getLabelFromTypeForKey<T, OPT>(std::string{index_.name}), std::string{index_.name})}, {}}
14321432
{
14331433
}
14341434

@@ -1508,7 +1508,7 @@ auto doSliceBy(T const* table, o2::framework::PresliceBase<C, Policy, OPT> const
15081508
{
15091509
if constexpr (OPT) {
15101510
if (container.isMissing()) {
1511-
missingOptionalPreslice(getLabelFromType<std::decay_t<T>>().data(), container.bindingKey.key.c_str());
1511+
missingOptionalPreslice(getLabelFromType<std::decay_t<T>>().data(), container.bindingKey.second.c_str());
15121512
}
15131513
}
15141514
uint64_t offset = 0;
@@ -1545,7 +1545,7 @@ auto doSliceBy(T const* table, o2::framework::PresliceBase<C, Policy, OPT> const
15451545
{
15461546
if constexpr (OPT) {
15471547
if (container.isMissing()) {
1548-
missingOptionalPreslice(getLabelFromType<std::decay_t<T>>().data(), container.bindingKey.key.c_str());
1548+
missingOptionalPreslice(getLabelFromType<std::decay_t<T>>().data(), container.bindingKey.second.c_str());
15491549
}
15501550
}
15511551
auto selection = container.getSliceFor(value);
@@ -1574,7 +1574,7 @@ auto doFilteredSliceBy(T const* table, o2::framework::PresliceBase<C, framework:
15741574
{
15751575
if constexpr (OPT) {
15761576
if (container.isMissing()) {
1577-
missingOptionalPreslice(getLabelFromType<T>().data(), container.bindingKey.key.c_str());
1577+
missingOptionalPreslice(getLabelFromType<T>().data(), container.bindingKey.second.c_str());
15781578
}
15791579
}
15801580
uint64_t offset = 0;

Framework/Core/include/Framework/AnalysisManagers.h

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -534,43 +534,39 @@ static void setGroupedCombination(C& comb, TG& grouping, std::tuple<Ts...>& asso
534534
/// Preslice handling
535535
template <typename T>
536536
requires(!is_preslice<T>)
537-
bool registerCache(T&, Cache&, Cache&)
537+
bool registerCache(T&, std::vector<StringPair>&, std::vector<StringPair>&)
538538
{
539539
return false;
540540
}
541541

542542
template <is_preslice T>
543543
requires std::same_as<typename T::policy_t, framework::PreslicePolicySorted>
544-
bool registerCache(T& preslice, Cache& bsks, Cache&)
544+
bool registerCache(T& preslice, std::vector<StringPair>& bsks, std::vector<StringPair>&)
545545
{
546546
if constexpr (T::optional) {
547547
if (preslice.binding == "[MISSING]") {
548548
return true;
549549
}
550550
}
551-
auto locate = std::find_if(bsks.begin(), bsks.end(), [&](auto const& entry) { return (entry.binding == preslice.bindingKey.binding) && (entry.key == preslice.bindingKey.key); });
551+
auto locate = std::find_if(bsks.begin(), bsks.end(), [&](auto const& entry) { return (entry.first == preslice.bindingKey.first) && (entry.second == preslice.bindingKey.second); });
552552
if (locate == bsks.end()) {
553553
bsks.emplace_back(preslice.getBindingKey());
554-
} else if (locate->enabled == false) {
555-
locate->enabled = true;
556554
}
557555
return true;
558556
}
559557

560558
template <is_preslice T>
561559
requires std::same_as<typename T::policy_t, framework::PreslicePolicyGeneral>
562-
bool registerCache(T& preslice, Cache&, Cache& bsksU)
560+
bool registerCache(T& preslice, std::vector<StringPair>&, std::vector<StringPair>& bsksU)
563561
{
564562
if constexpr (T::optional) {
565563
if (preslice.binding == "[MISSING]") {
566564
return true;
567565
}
568566
}
569-
auto locate = std::find_if(bsksU.begin(), bsksU.end(), [&](auto const& entry) { return (entry.binding == preslice.bindingKey.binding) && (entry.key == preslice.bindingKey.key); });
567+
auto locate = std::find_if(bsksU.begin(), bsksU.end(), [&](auto const& entry) { return (entry.first == preslice.bindingKey.first) && (entry.second == preslice.bindingKey.second); });
570568
if (locate == bsksU.end()) {
571569
bsksU.emplace_back(preslice.getBindingKey());
572-
} else if (locate->enabled == false) {
573-
locate->enabled = true;
574570
}
575571
return true;
576572
}

Framework/Core/include/Framework/AnalysisTask.h

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -66,20 +66,20 @@ concept is_enumeration = is_enumeration_v<std::decay_t<T>>;
6666
namespace {
6767
struct AnalysisDataProcessorBuilder {
6868
template <typename G, typename... Args>
69-
static void addGroupingCandidates(Cache& bk, Cache& bku, bool enabled)
69+
static void addGroupingCandidates(std::vector<StringPair>& bk, std::vector<StringPair>& bku)
7070
{
71-
[&bk, &bku, enabled]<typename... As>(framework::pack<As...>) mutable {
71+
[&bk, &bku]<typename... As>(framework::pack<As...>) mutable {
7272
std::string key;
7373
if constexpr (soa::is_iterator<std::decay_t<G>>) {
7474
key = std::string{"fIndex"} + o2::framework::cutString(soa::getLabelFromType<std::decay_t<G>>());
7575
}
76-
([&bk, &bku, &key, enabled]() mutable {
76+
([&bk, &bku, &key]() mutable {
7777
if constexpr (soa::relatedByIndex<std::decay_t<G>, std::decay_t<As>>()) {
7878
auto binding = soa::getLabelFromTypeForKey<std::decay_t<As>>(key);
7979
if constexpr (o2::soa::is_smallgroups<std::decay_t<As>>) {
80-
framework::updatePairList(bku, binding, key, enabled);
80+
framework::updatePairList(bku, binding, key);
8181
} else {
82-
framework::updatePairList(bk, binding, key, enabled);
82+
framework::updatePairList(bk, binding, key);
8383
}
8484
}
8585
}(),
@@ -147,7 +147,7 @@ struct AnalysisDataProcessorBuilder {
147147
/// helper to parse the process arguments
148148
/// 1. enumeration (must be the only argument)
149149
template <typename R, typename C, is_enumeration A>
150-
static void inputsFromArgs(R (C::*)(A), const char* /*name*/, bool /*value*/, std::vector<InputSpec>& inputs, std::vector<ExpressionInfo>&, Cache&, Cache&)
150+
static void inputsFromArgs(R (C::*)(A), const char* /*name*/, bool /*value*/, std::vector<InputSpec>& inputs, std::vector<ExpressionInfo>&, std::vector<StringPair>&, std::vector<StringPair>&)
151151
{
152152
std::vector<ConfigParamSpec> inputMetadata;
153153
// FIXME: for the moment we do not support begin, end and step.
@@ -156,17 +156,17 @@ struct AnalysisDataProcessorBuilder {
156156

157157
/// 2. grouping case - 1st argument is an iterator
158158
template <typename R, typename C, soa::is_iterator A, soa::is_table... Args>
159-
static void inputsFromArgs(R (C::*)(A, Args...), const char* name, bool value, std::vector<InputSpec>& inputs, std::vector<ExpressionInfo>& eInfos, Cache& bk, Cache& bku)
159+
static void inputsFromArgs(R (C::*)(A, Args...), const char* name, bool value, std::vector<InputSpec>& inputs, std::vector<ExpressionInfo>& eInfos, std::vector<StringPair>& bk, std::vector<StringPair>& bku)
160160
requires(std::is_lvalue_reference_v<A> && (std::is_lvalue_reference_v<Args> && ...))
161161
{
162-
addGroupingCandidates<A, Args...>(bk, bku, value);
162+
addGroupingCandidates<A, Args...>(bk, bku);
163163
constexpr auto hash = o2::framework::TypeIdHelpers::uniqueId<R (C::*)(A, Args...)>();
164164
addInputsAndExpressions<typename std::decay_t<A>::parent_t, Args...>(hash, name, value, inputs, eInfos);
165165
}
166166

167167
/// 3. generic case
168168
template <typename R, typename C, soa::is_table... Args>
169-
static void inputsFromArgs(R (C::*)(Args...), const char* name, bool value, std::vector<InputSpec>& inputs, std::vector<ExpressionInfo>& eInfos, Cache&, Cache&)
169+
static void inputsFromArgs(R (C::*)(Args...), const char* name, bool value, std::vector<InputSpec>& inputs, std::vector<ExpressionInfo>& eInfos, std::vector<StringPair>&, std::vector<StringPair>&)
170170
requires(std::is_lvalue_reference_v<Args> && ...)
171171
{
172172
constexpr auto hash = o2::framework::TypeIdHelpers::uniqueId<R (C::*)(Args...)>();
@@ -480,8 +480,8 @@ DataProcessorSpec adaptAnalysisTask(ConfigContext const& ctx, Args&&... args)
480480
std::vector<InputSpec> inputs;
481481
std::vector<ConfigParamSpec> options;
482482
std::vector<ExpressionInfo> expressionInfos;
483-
Cache bindingsKeys;
484-
Cache bindingsKeysUnsorted;
483+
std::vector<StringPair> bindingsKeys;
484+
std::vector<StringPair> bindingsKeysUnsorted;
485485

486486
/// make sure options and configurables are set before expression infos are created
487487
homogeneous_apply_refs([&options, &hash](auto& element) { return analysis_task_parsers::appendOption(options, element); }, *task.get());

Framework/Core/include/Framework/ArrowTableSlicingCache.h

Lines changed: 16 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -34,64 +34,51 @@ struct SliceInfoUnsortedPtr {
3434
gsl::span<int64_t const> getSliceFor(int value) const;
3535
};
3636

37-
struct Entry {
38-
std::string binding;
39-
std::string key;
40-
bool enabled;
41-
42-
Entry(std::string b, std::string k, bool e = true)
43-
: binding{b},
44-
key{k},
45-
enabled{e}
46-
{
47-
}
48-
};
49-
50-
using Cache = std::vector<Entry>;
37+
using StringPair = std::pair<std::string, std::string>;
5138

52-
void updatePairList(Cache& list, std::string const& binding, std::string const& key, bool enabled);
39+
void updatePairList(std::vector<StringPair>& list, std::string const& binding, std::string const& key);
5340

5441
struct ArrowTableSlicingCacheDef {
5542
constexpr static ServiceKind service_kind = ServiceKind::Global;
56-
Cache bindingsKeys;
57-
Cache bindingsKeysUnsorted;
43+
std::vector<StringPair> bindingsKeys;
44+
std::vector<StringPair> bindingsKeysUnsorted;
5845

59-
void setCaches(Cache&& bsks);
60-
void setCachesUnsorted(Cache&& bsks);
46+
void setCaches(std::vector<StringPair>&& bsks);
47+
void setCachesUnsorted(std::vector<StringPair>&& bsks);
6148
};
6249

6350
struct ArrowTableSlicingCache {
6451
constexpr static ServiceKind service_kind = ServiceKind::Stream;
6552

66-
Cache bindingsKeys;
53+
std::vector<StringPair> bindingsKeys;
6754
std::vector<std::shared_ptr<arrow::NumericArray<arrow::Int32Type>>> values;
6855
std::vector<std::shared_ptr<arrow::NumericArray<arrow::Int64Type>>> counts;
6956

70-
Cache bindingsKeysUnsorted;
57+
std::vector<StringPair> bindingsKeysUnsorted;
7158
std::vector<std::vector<int>> valuesUnsorted;
7259
std::vector<ListVector> groups;
7360

74-
ArrowTableSlicingCache(Cache&& bsks, Cache&& bsksUnsorted = {});
61+
ArrowTableSlicingCache(std::vector<StringPair>&& bsks, std::vector<StringPair>&& bsksUnsorted = {});
7562

7663
// set caching information externally
77-
void setCaches(Cache&& bsks, Cache&& bsksUnsorted = {});
64+
void setCaches(std::vector<StringPair>&& bsks, std::vector<StringPair>&& bsksUnsorted = {});
7865

7966
// update slicing info cache entry (assumes it is already present)
8067
arrow::Status updateCacheEntry(int pos, std::shared_ptr<arrow::Table> const& table);
8168
arrow::Status updateCacheEntryUnsorted(int pos, std::shared_ptr<arrow::Table> const& table);
8269

8370
// helper to locate cache position
84-
std::pair<int, bool> getCachePos(Entry const& bindingKey) const;
85-
int getCachePosSortedFor(Entry const& bindingKey) const;
86-
int getCachePosUnsortedFor(Entry const& bindingKey) const;
71+
std::pair<int, bool> getCachePos(StringPair const& bindingKey) const;
72+
int getCachePosSortedFor(StringPair const& bindingKey) const;
73+
int getCachePosUnsortedFor(StringPair const& bindingKey) const;
8774

8875
// get slice from cache for a given value
89-
SliceInfoPtr getCacheFor(Entry const& bindingKey) const;
90-
SliceInfoUnsortedPtr getCacheUnsortedFor(Entry const& bindingKey) const;
76+
SliceInfoPtr getCacheFor(StringPair const& bindingKey) const;
77+
SliceInfoUnsortedPtr getCacheUnsortedFor(StringPair const& bindingKey) const;
9178
SliceInfoPtr getCacheForPos(int pos) const;
9279
SliceInfoUnsortedPtr getCacheUnsortedForPos(int pos) const;
9380

94-
static void validateOrder(Entry const& bindingKey, std::shared_ptr<arrow::Table> const& input);
81+
static void validateOrder(StringPair const& bindingKey, std::shared_ptr<arrow::Table> const& input);
9582
};
9683
} // namespace o2::framework
9784

Framework/Core/include/Framework/GroupSlicer.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ struct GroupSlicer {
5555
{
5656
constexpr auto index = framework::has_type_at_v<std::decay_t<T>>(associated_pack_t{});
5757
auto binding = o2::soa::getLabelFromTypeForKey<std::decay_t<T>>(mIndexColumnName);
58-
auto bk = Entry(binding, mIndexColumnName);
58+
auto bk = std::make_pair(binding, mIndexColumnName);
5959
if constexpr (!o2::soa::is_smallgroups<std::decay_t<T>>) {
6060
if (table.size() == 0) {
6161
return;

Framework/Core/src/ASoA.cxx

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -197,7 +197,7 @@ bool PreslicePolicyBase::isMissing() const
197197
return binding == "[MISSING]";
198198
}
199199

200-
Entry const& PreslicePolicyBase::getBindingKey() const
200+
StringPair const& PreslicePolicyBase::getBindingKey() const
201201
{
202202
return bindingKey;
203203
}

Framework/Core/src/ArrowSupport.cxx

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -567,27 +567,26 @@ o2::framework::ServiceSpec ArrowSupport::arrowTableSlicingCacheSpec()
567567
.name = "arrow-slicing-cache",
568568
.uniqueId = CommonServices::simpleServiceId<ArrowTableSlicingCache>(),
569569
.init = [](ServiceRegistryRef services, DeviceState&, fair::mq::ProgOptions&) { return ServiceHandle{TypeIdHelpers::uniqueId<ArrowTableSlicingCache>(),
570-
new ArrowTableSlicingCache(Cache{services.get<ArrowTableSlicingCacheDef>().bindingsKeys},
571-
Cache{services.get<ArrowTableSlicingCacheDef>().bindingsKeysUnsorted}),
570+
new ArrowTableSlicingCache(std::vector<std::pair<std::string, std::string>>{services.get<ArrowTableSlicingCacheDef>().bindingsKeys}, std::vector{services.get<ArrowTableSlicingCacheDef>().bindingsKeysUnsorted}),
572571
ServiceKind::Stream, typeid(ArrowTableSlicingCache).name()}; },
573572
.configure = CommonServices::noConfiguration(),
574573
.preProcessing = [](ProcessingContext& pc, void* service_ptr) {
575574
auto* service = static_cast<ArrowTableSlicingCache*>(service_ptr);
576575
auto& caches = service->bindingsKeys;
577-
for (auto i = 0u; i < caches.size(); ++i) {
578-
if (caches[i].enabled && pc.inputs().getPos(caches[i].binding.c_str()) >= 0) {
579-
auto status = service->updateCacheEntry(i, pc.inputs().get<TableConsumer>(caches[i].binding.c_str())->asArrowTable());
576+
for (auto i = 0; i < caches.size(); ++i) {
577+
if (pc.inputs().getPos(caches[i].first.c_str()) >= 0) {
578+
auto status = service->updateCacheEntry(i, pc.inputs().get<TableConsumer>(caches[i].first.c_str())->asArrowTable());
580579
if (!status.ok()) {
581-
throw runtime_error_f("Failed to update slice cache for %s/%s", caches[i].binding.c_str(), caches[i].key.c_str());
580+
throw runtime_error_f("Failed to update slice cache for %s/%s", caches[i].first.c_str(), caches[i].second.c_str());
582581
}
583582
}
584583
}
585584
auto& unsortedCaches = service->bindingsKeysUnsorted;
586-
for (auto i = 0u; i < unsortedCaches.size(); ++i) {
587-
if (unsortedCaches[i].enabled && pc.inputs().getPos(unsortedCaches[i].binding.c_str()) >= 0) {
588-
auto status = service->updateCacheEntryUnsorted(i, pc.inputs().get<TableConsumer>(unsortedCaches[i].binding.c_str())->asArrowTable());
585+
for (auto i = 0; i < unsortedCaches.size(); ++i) {
586+
if (pc.inputs().getPos(unsortedCaches[i].first.c_str()) >= 0) {
587+
auto status = service->updateCacheEntryUnsorted(i, pc.inputs().get<TableConsumer>(unsortedCaches[i].first.c_str())->asArrowTable());
589588
if (!status.ok()) {
590-
throw runtime_error_f("failed to update slice cache (unsorted) for %s/%s", unsortedCaches[i].binding.c_str(), unsortedCaches[i].key.c_str());
589+
throw runtime_error_f("failed to update slice cache (unsorted) for %s/%s", unsortedCaches[i].first.c_str(), unsortedCaches[i].second.c_str());
591590
}
592591
}
593592
} },

0 commit comments

Comments
 (0)