diff --git a/PWGHF/DataModel/CandidateSelectionTables.h b/PWGHF/DataModel/CandidateSelectionTables.h index 7f5786ee1f6..125b0fb7286 100644 --- a/PWGHF/DataModel/CandidateSelectionTables.h +++ b/PWGHF/DataModel/CandidateSelectionTables.h @@ -231,10 +231,12 @@ DECLARE_SOA_TABLE(HfSelJpsi, "AOD", "HFSELJPSI", //! namespace hf_sel_candidate_lc_to_k0s_p { DECLARE_SOA_COLUMN(IsSelLcToK0sP, isSelLcToK0sP, int); +DECLARE_SOA_COLUMN(MlProbLcToK0sP, mlProbLcToK0sP, std::vector); //! } // namespace hf_sel_candidate_lc_to_k0s_p - DECLARE_SOA_TABLE(HfSelLcToK0sP, "AOD", "HFSELLCK0SP", //! hf_sel_candidate_lc_to_k0s_p::IsSelLcToK0sP); +DECLARE_SOA_TABLE(HfMlLcToK0sP, "AOD", "HFMLLcK0sP", //! + hf_sel_candidate_lc_to_k0s_p::MlProbLcToK0sP); namespace hf_sel_candidate_b0 { diff --git a/PWGHF/TableProducer/candidateSelectorLcToK0sP.cxx b/PWGHF/TableProducer/candidateSelectorLcToK0sP.cxx index a4a026a005b..d91ab2f718e 100644 --- a/PWGHF/TableProducer/candidateSelectorLcToK0sP.cxx +++ b/PWGHF/TableProducer/candidateSelectorLcToK0sP.cxx @@ -54,6 +54,7 @@ using namespace o2::framework; struct HfCandidateSelectorLcToK0sP { Produces hfSelLcToK0sPCandidate; + Produces hfMlLcToK0sPCandidate; Configurable ptCandMin{"ptCandMin", 0., "Lower bound of candidate pT"}; Configurable ptCandMax{"ptCandMax", 50., "Upper bound of candidate pT"}; @@ -95,6 +96,7 @@ struct HfCandidateSelectorLcToK0sP { TrackSelectorPr selectorProtonHighP; o2::analysis::HfMlResponseLcToK0sP hfMlResponse; + std::vector outputMl = {}; o2::ccdb::CcdbApi ccdbApi; @@ -239,12 +241,11 @@ struct HfCandidateSelectorLcToK0sP { } template - bool selectionMl(const T& hfCandCascade, const U& bach) + bool selectionMl(const T& hfCandCascade, const U& bach, std::vector& outputMl) { auto ptCand = hfCandCascade.pt(); std::vector inputFeatures = hfMlResponse.getInputFeatures(hfCandCascade, bach); - std::vector outputMl = {}; bool isSelectedMl = hfMlResponse.isSelectedMl(inputFeatures, ptCand, outputMl); @@ -265,26 +266,37 @@ struct HfCandidateSelectorLcToK0sP { const auto& bach = candidate.prong0_as(); // bachelor track statusLc = 0; + outputMl.clear(); // implement filter bit 4 cut - should be done before this task at the track selection level // need to add special cuts (additional cuts on decay length and d0 norm) if (!selectionTopol(candidate)) { hfSelLcToK0sPCandidate(statusLc); + if (applyMl) { + hfMlLcToK0sPCandidate(outputMl); + } continue; } if (!selectionStandardPID(bach)) { hfSelLcToK0sPCandidate(statusLc); + if (applyMl) { + hfMlLcToK0sPCandidate(outputMl); + } continue; } - if (applyMl && !selectionMl(candidate, bach)) { - hfSelLcToK0sPCandidate(statusLc); - continue; + if (applyMl) { + bool isSelectedMlLcToK0sP = selectionMl(candidate, bach, outputMl); + hfMlLcToK0sPCandidate(outputMl); + + if (!isSelectedMlLcToK0sP) { + hfSelLcToK0sPCandidate(statusLc); + continue; + } } statusLc = 1; - hfSelLcToK0sPCandidate(statusLc); } } @@ -299,24 +311,35 @@ struct HfCandidateSelectorLcToK0sP { const auto& bach = candidate.prong0_as(); // bachelor track statusLc = 0; + outputMl.clear(); if (!selectionTopol(candidate)) { hfSelLcToK0sPCandidate(statusLc); + if (applyMl) { + hfMlLcToK0sPCandidate(outputMl); + } continue; } if (!selectionBayesPID(bach)) { hfSelLcToK0sPCandidate(statusLc); + if (applyMl) { + hfMlLcToK0sPCandidate(outputMl); + } continue; } - if (applyMl && !selectionMl(candidate, bach)) { - hfSelLcToK0sPCandidate(statusLc); - continue; + if (applyMl) { + bool isSelectedMlLcToK0sP = selectionMl(candidate, bach, outputMl); + hfMlLcToK0sPCandidate(outputMl); + + if (!isSelectedMlLcToK0sP) { + hfSelLcToK0sPCandidate(statusLc); + continue; + } } statusLc = 1; - hfSelLcToK0sPCandidate(statusLc); } } diff --git a/PWGHF/TableProducer/treeCreatorLcToK0sP.cxx b/PWGHF/TableProducer/treeCreatorLcToK0sP.cxx index 0012108a45c..1e2b2484c13 100644 --- a/PWGHF/TableProducer/treeCreatorLcToK0sP.cxx +++ b/PWGHF/TableProducer/treeCreatorLcToK0sP.cxx @@ -33,8 +33,10 @@ #include #include +#include #include #include +#include using namespace o2; using namespace o2::framework; @@ -68,8 +70,8 @@ DECLARE_SOA_COLUMN(DecayLength, decayLength, float); DECLARE_SOA_COLUMN(DecayLengthXY, decayLengthXY, float); DECLARE_SOA_COLUMN(DecayLengthNormalised, decayLengthNormalised, float); DECLARE_SOA_COLUMN(DecayLengthXYNormalised, decayLengthXYNormalised, float); -DECLARE_SOA_COLUMN(CPA, cpa, float); -DECLARE_SOA_COLUMN(CPAXY, cpaXY, float); +DECLARE_SOA_COLUMN(Cpa, cpa, float); +DECLARE_SOA_COLUMN(CpaXY, cpaXY, float); DECLARE_SOA_COLUMN(Ct, ct, float); DECLARE_SOA_COLUMN(PtV0Pos, ptV0Pos, float); DECLARE_SOA_COLUMN(PtV0Neg, ptV0Neg, float); @@ -84,6 +86,9 @@ DECLARE_SOA_COLUMN(V0CtLambda, v0CtLambda, float); DECLARE_SOA_COLUMN(FlagMc, flagMc, int8_t); DECLARE_SOA_COLUMN(OriginMcRec, originMcRec, int8_t); DECLARE_SOA_COLUMN(OriginMcGen, originMcGen, int8_t); +DECLARE_SOA_COLUMN(MlScoreFirstClass, mlScoreFirstClass, float); +DECLARE_SOA_COLUMN(MlScoreSecondClass, mlScoreSecondClass, float); +DECLARE_SOA_COLUMN(MlScoreThirdClass, mlScoreThirdClass, float); // Events DECLARE_SOA_COLUMN(IsEventReject, isEventReject, int); DECLARE_SOA_COLUMN(RunNumber, runNumber, int); @@ -118,15 +123,18 @@ DECLARE_SOA_TABLE(HfCandCascLites, "AOD", "HFCANDCASCLITE", full::NSigmaTOFPr0, full::M, full::Pt, - full::CPA, - full::CPAXY, + full::Cpa, + full::CpaXY, full::Ct, full::Eta, full::Phi, full::Y, full::E, full::FlagMc, - full::OriginMcRec); + full::OriginMcRec, + full::MlScoreFirstClass, + full::MlScoreSecondClass, + full::MlScoreThirdClass); DECLARE_SOA_TABLE(HfCandCascFulls, "AOD", "HFCANDCASCFULL", collision::BCId, @@ -188,15 +196,18 @@ DECLARE_SOA_TABLE(HfCandCascFulls, "AOD", "HFCANDCASCFULL", full::M, full::Pt, full::P, - full::CPA, - full::CPAXY, + full::Cpa, + full::CpaXY, full::Ct, full::Eta, full::Phi, full::Y, full::E, full::FlagMc, - full::OriginMcRec); + full::OriginMcRec, + full::MlScoreFirstClass, + full::MlScoreSecondClass, + full::MlScoreThirdClass); DECLARE_SOA_TABLE(HfCandCascFullEs, "AOD", "HFCANDCASCFULLE", collision::BCId, @@ -228,23 +239,56 @@ struct HfTreeCreatorLcToK0sP { Configurable ptMaxForDownSample{"ptMaxForDownSample", 24., "Maximum pt for the application of the downsampling factor"}; Configurable fillOnlySignal{"fillOnlySignal", false, "Flag to fill derived tables with signal for ML trainings"}; Configurable fillOnlyBackground{"fillOnlyBackground", false, "Flag to fill derived tables with background for ML trainings"}; + Configurable applyMl{"applyMl", false, "Whether ML was used in candidateSelectorLc"}; + + constexpr static float UndefValueFloat = -999.f; HfHelper hfHelper; - Filter filterSelectCandidates = aod::hf_sel_candidate_lc_to_k0s_p::isSelLcToK0sP >= 1; using TracksWPid = soa::Join; using SelectedCandidatesMc = soa::Filtered>; - - Partition recSig = nabs(aod::hf_cand_casc::flagMcMatchRec) != int8_t(0); - Partition recBkg = nabs(aod::hf_cand_casc::flagMcMatchRec) == int8_t(0); + Filter filterSelectCandidates = aod::hf_sel_candidate_lc_to_k0s_p::isSelLcToK0sP >= 1; void init(InitContext const&) { } + /// \brief function to get ML score values for the current candidate and assign them to input parameters + /// \param candidate candidate instance + /// \param candidateMlScore instance of handler of vectors with ML scores associated with the current candidate + /// \param mlScoreFirstClass ML score for belonging to the first class + /// \param mlScoreSecondClass ML score for belonging to the second class + /// \param mlScoreThirdClass ML score for belonging to the third class + void assignMlScores(aod::HfMlLcToK0sP::iterator const& candidateMlScore, float& mlScoreFirstClass, float& mlScoreSecondClass, float& mlScoreThirdClass) + { + std::vector mlScores; + std::copy(candidateMlScore.mlProbLcToK0sP().begin(), candidateMlScore.mlProbLcToK0sP().end(), std::back_inserter(mlScores)); + + constexpr int IndexFirstClass{0}; + constexpr int IndexSecondClass{1}; + constexpr int IndexThirdClass{2}; + if (mlScores.size() == 0) { + return; // when candidateSelectorLcK0sP rejects a candidate by "usual", non-ML cut, the ml score vector remains empty + } + mlScoreFirstClass = mlScores.at(IndexFirstClass); + mlScoreSecondClass = mlScores.at(IndexSecondClass); + if (mlScores.size() > IndexThirdClass) { + mlScoreThirdClass = mlScores.at(IndexThirdClass); + } + } + template - void fillCandidate(const T& candidate, const U& bach, int8_t flagMc, int8_t originMcRec) + void fillCandidate(const T& candidate, const U& bach, int8_t flagMc, int8_t originMcRec, aod::HfMlLcToK0sP::iterator const& candidateMlScore) { + + float mlScoreFirstClass{UndefValueFloat}; + float mlScoreSecondClass{UndefValueFloat}; + float mlScoreThirdClass{UndefValueFloat}; + + if (applyMl) { + assignMlScores(candidateMlScore, mlScoreFirstClass, mlScoreSecondClass, mlScoreThirdClass); + } + if (fillCandidateLiteTable) { rowCandidateLite( candidate.chi2PCA(), @@ -283,7 +327,10 @@ struct HfTreeCreatorLcToK0sP { hfHelper.yLc(candidate), hfHelper.eLc(candidate), flagMc, - originMcRec); + originMcRec, + mlScoreFirstClass, + mlScoreSecondClass, + mlScoreThirdClass); } else { rowCandidateFull( bach.collision().bcId(), @@ -353,7 +400,10 @@ struct HfTreeCreatorLcToK0sP { hfHelper.yLc(candidate), hfHelper.eLc(candidate), flagMc, - originMcRec); + originMcRec, + mlScoreFirstClass, + mlScoreSecondClass, + mlScoreThirdClass); } } template @@ -370,52 +420,41 @@ struct HfTreeCreatorLcToK0sP { void processMc(aod::Collisions const& collisions, aod::McCollisions const&, SelectedCandidatesMc const& candidates, + aod::HfMlLcToK0sP const& candidateMlScores, soa::Join const& particles, TracksWPid const&) { + if (applyMl && candidateMlScores.size() == 0) { + LOG(fatal) << "ML enabled but table with the ML scores is empty! Please check your configurables."; + return; + } + // Filling event properties rowCandidateFullEvents.reserve(collisions.size()); for (const auto& collision : collisions) { fillEvent(collision); } - if (fillOnlySignal) { - if (fillCandidateLiteTable) { - rowCandidateLite.reserve(recSig.size()); - } else { - rowCandidateFull.reserve(recSig.size()); - } - for (const auto& candidate : recSig) { - auto bach = candidate.prong0_as(); // bachelor - fillCandidate(candidate, bach, candidate.flagMcMatchRec(), candidate.originMcRec()); - } - } else if (fillOnlyBackground) { - if (fillCandidateLiteTable) { - rowCandidateLite.reserve(recBkg.size()); - } else { - rowCandidateFull.reserve(recBkg.size()); - } - for (const auto& candidate : recBkg) { - if (downSampleBkgFactor < 1.) { - float pseudoRndm = candidate.ptProng0() * 1000. - static_cast(candidate.ptProng0() * 1000); - if (candidate.pt() < ptMaxForDownSample && pseudoRndm >= downSampleBkgFactor) { - continue; - } - } - auto bach = candidate.prong0_as(); // bachelor - fillCandidate(candidate, bach, candidate.flagMcMatchRec(), candidate.originMcRec()); - } + if (fillCandidateLiteTable) { + rowCandidateLite.reserve(candidates.size()); } else { - // Filling candidate properties - if (fillCandidateLiteTable) { - rowCandidateLite.reserve(candidates.size()); + rowCandidateFull.reserve(candidates.size()); + } + + int iCand{0}; + for (const auto& candidate : candidates) { + auto candidateMlScore = candidateMlScores.rawIteratorAt(iCand); + ++iCand; + auto bach = candidate.prong0_as(); // bachelor + const int flag = candidate.flagMcMatchRec(); + + if (fillOnlySignal && flag != 0) { + fillCandidate(candidate, bach, candidate.flagMcMatchRec(), candidate.originMcRec(), candidateMlScore); + } else if (fillOnlyBackground && flag == 0) { + fillCandidate(candidate, bach, candidate.flagMcMatchRec(), candidate.originMcRec(), candidateMlScore); } else { - rowCandidateFull.reserve(candidates.size()); - } - for (const auto& candidate : candidates) { - auto bach = candidate.prong0_as(); // bachelor - fillCandidate(candidate, bach, candidate.flagMcMatchRec(), candidate.originMcRec()); + fillCandidate(candidate, bach, candidate.flagMcMatchRec(), candidate.originMcRec(), candidateMlScore); } } @@ -439,9 +478,15 @@ struct HfTreeCreatorLcToK0sP { void processData(aod::Collisions const& collisions, soa::Join const& candidates, + aod::HfMlLcToK0sP const& candidateMlScores, TracksWPid const&) { + if (applyMl && candidateMlScores.size() == 0) { + LOG(fatal) << "ML enabled but table with the ML scores is empty! Please check your configurables."; + return; + } + // Filling event properties rowCandidateFullEvents.reserve(collisions.size()); for (const auto& collision : collisions) { @@ -454,11 +499,15 @@ struct HfTreeCreatorLcToK0sP { } else { rowCandidateFull.reserve(candidates.size()); } + + int iCand{0}; for (const auto& candidate : candidates) { + auto candidateMlScore = candidateMlScores.rawIteratorAt(iCand); + ++iCand; auto bach = candidate.prong0_as(); // bachelor double pseudoRndm = bach.pt() * 1000. - static_cast(bach.pt() * 1000); if (candidate.isSelLcToK0sP() >= 1 && pseudoRndm < downSampleBkgFactor) { - fillCandidate(candidate, bach, 0, 0); + fillCandidate(candidate, bach, 0, 0, candidateMlScore); } } }