Skip to content

Commit 592abbf

Browse files
authored
DPL Analysis: implemented grouping for Filtered argument (#3003)
1 parent 6727943 commit 592abbf

File tree

3 files changed

+53
-7
lines changed

3 files changed

+53
-7
lines changed

Framework/Core/include/Framework/ASoA.h

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -844,7 +844,8 @@ class Table
844844
}
845845
std::shared_ptr<arrow::Table> mTable;
846846
/// This is a cached lookup of the column index in a given
847-
std::tuple<std::pair<C*, arrow::Column*>...> mColumnIndex;
847+
std::tuple<std::pair<C*, arrow::Column*>...>
848+
mColumnIndex;
848849
/// Cached begin iterator for this table.
849850
unfiltered_iterator mBegin;
850851
/// Cached end iterator for this table.
@@ -1196,6 +1197,11 @@ class Filtered : public T
11961197
return mSelection->GetNumSlots();
11971198
}
11981199

1200+
int64_t tableSize() const
1201+
{
1202+
return table_t::asArrowTable()->num_rows();
1203+
}
1204+
11991205
framework::expressions::Selection getSelection() const
12001206
{
12011207
return mSelection;

Framework/Core/include/Framework/AnalysisTask.h

Lines changed: 44 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -398,10 +398,50 @@ struct AnalysisDataProcessorBuilder {
398398
++const_cast<std::decay_t<Grouping>&>(groupingElement);
399399
++oi;
400400
}
401-
}
402-
if constexpr (is_specialization<std::decay_t<AssociatedType>, o2::soa::Filtered>::value) {
403-
// FIXME: we need to implement the case for the grouped filtered case.
404-
static_assert(always_static_assert_v<std::decay_t<AssociatedType>>, "Grouping Filtered is not yet supported");
401+
} else if constexpr (is_specialization<std::decay_t<AssociatedType>, o2::soa::Filtered>::value) {
402+
auto& fullFiltered = std::get<0>(associatedTables);
403+
auto selectionBuffer = std::shared_ptr<arrow::Buffer>(&(fullFiltered.getSelection()->GetBuffer()));
404+
auto selectionArray = fullFiltered.getSelection()->ToArray();
405+
offsets.push_back(fullFiltered.tableSize());
406+
uint64_t selectionIndex = 0;
407+
uint64_t sliceStart = 0;
408+
uint64_t sliceStop = 0;
409+
410+
auto findSliceBounds = [&](int64_t l, int64_t h) {
411+
size_t s = 0;
412+
for (auto i = selectionIndex; i < selectionArray->length(); ++i) {
413+
auto value = selectionArray->data()->template GetValues<uint64_t>(i);
414+
if (*value == l) {
415+
sliceStart = i;
416+
s = i;
417+
break;
418+
}
419+
}
420+
for (auto i = s + 1; i < selectionArray->length(); ++i) {
421+
auto value = selectionArray->data()->template GetValues<uint64_t>(i);
422+
if (*value == h) {
423+
sliceStop = i;
424+
selectionIndex = i;
425+
break;
426+
}
427+
}
428+
};
429+
430+
for (auto& groupedDatum : groupsCollection) {
431+
auto groupedElementsTable = arrow::util::get<std::shared_ptr<arrow::Table>>(groupedDatum.value);
432+
// for each grouping element we need to slice the selection vector
433+
findSliceBounds(offsets[oi], offsets[oi + 1]);
434+
auto slicedBuffer = arrow::SliceBuffer(selectionBuffer, sliceStart, sliceStop - sliceStart + 1);
435+
expressions::Selection slicedSelection;
436+
if (!gandiva::SelectionVector::MakeInt64(sliceStop - sliceStart + 1, slicedBuffer, &slicedSelection).ok()) {
437+
throw std::runtime_error("Cannot create sliced selection");
438+
}
439+
std::decay_t<AssociatedType> typedTable{{groupedElementsTable}, slicedSelection, offsets[oi]};
440+
typedTable.bindExternalIndices(&groupingTable);
441+
task.process(groupingElement, typedTable);
442+
++const_cast<std::decay_t<Grouping>&>(groupingElement);
443+
++oi;
444+
}
405445
} else if constexpr (is_specialization<std::decay_t<AssociatedType>, o2::soa::Join>::value || is_specialization<std::decay_t<AssociatedType>, o2::soa::Concat>::value) {
406446
for (auto& groupedDatum : groupsCollection) {
407447
auto groupedElementsTable = arrow::util::get<std::shared_ptr<arrow::Table>>(groupedDatum.value);

Framework/Core/test/test_AnalysisTask.cxx

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,7 @@ struct GTask {
123123
//};
124124

125125
struct ITask {
126-
void process(o2::soa::Filtered<o2::soa::Join<o2::aod::Foos, o2::aod::Bars, o2::aod::XYZ>> const& foobars)
126+
void process(o2::aod::Collision const&, o2::soa::Filtered<o2::soa::Join<o2::aod::Foos, o2::aod::Bars, o2::aod::XYZ>> const& foobars)
127127
{
128128
for (auto foobar : foobars) {
129129
foobar.x();
@@ -170,5 +170,5 @@ BOOST_AUTO_TEST_CASE(AdaptorCompilation)
170170
// BOOST_CHECK_EQUAL(task8.inputs.size(), 3);
171171

172172
auto task9 = adaptAnalysisTask<ITask>("test9");
173-
BOOST_CHECK_EQUAL(task9.inputs.size(), 3);
173+
BOOST_CHECK_EQUAL(task9.inputs.size(), 4);
174174
}

0 commit comments

Comments
 (0)