From eb2a74d23ca94630c95425f13aa6820a35d6077c Mon Sep 17 00:00:00 2001 From: Jinjoo Seo Date: Thu, 17 Jul 2025 16:44:18 +0200 Subject: [PATCH 1/7] Implement ML-based response class (DQMlResponse) for dielectron DQ-analysis selections. Supports both binary and multiclass BDT evaluation using ONNX Example JSON files included for reference --- PWGDQ/Core/CMakeLists.txt | 2 +- PWGDQ/Core/CutsLibrary.cxx | 152 ++++++++++++++++++ PWGDQ/Core/CutsLibrary.h | 26 ++++ PWGDQ/Core/DQMlResponse.h | 213 ++++++++++++++++++++++++++ PWGDQ/Core/VarManager.cxx | 9 ++ PWGDQ/Core/VarManager.h | 25 +++ PWGDQ/Macros/bdtCut.json | 56 +++++++ PWGDQ/Macros/bdtCutMulti.json | 74 +++++++++ PWGDQ/Tasks/CMakeLists.txt | 6 +- PWGDQ/Tasks/tableReader.cxx | 82 +++++++++- PWGDQ/Tasks/tableReader_withAssoc.cxx | 83 +++++++++- 11 files changed, 722 insertions(+), 6 deletions(-) create mode 100644 PWGDQ/Core/DQMlResponse.h create mode 100644 PWGDQ/Macros/bdtCut.json create mode 100644 PWGDQ/Macros/bdtCutMulti.json diff --git a/PWGDQ/Core/CMakeLists.txt b/PWGDQ/Core/CMakeLists.txt index d19a66a68e6..41ceb661bf9 100644 --- a/PWGDQ/Core/CMakeLists.txt +++ b/PWGDQ/Core/CMakeLists.txt @@ -21,7 +21,7 @@ o2physics_add_library(PWGDQCore AnalysisCompositeCut.cxx MCProng.cxx MCSignal.cxx - PUBLIC_LINK_LIBRARIES O2::Framework O2::DCAFitter O2::GlobalTracking O2Physics::AnalysisCore KFParticle::KFParticle) + PUBLIC_LINK_LIBRARIES O2::Framework O2::DCAFitter O2::GlobalTracking O2Physics::AnalysisCore KFParticle::KFParticle O2Physics::MLCore) o2physics_target_root_dictionary(PWGDQCore HEADERS AnalysisCut.h diff --git a/PWGDQ/Core/CutsLibrary.cxx b/PWGDQ/Core/CutsLibrary.cxx index 5cf99d6f22f..9bc9fc58a3e 100644 --- a/PWGDQ/Core/CutsLibrary.cxx +++ b/PWGDQ/Core/CutsLibrary.cxx @@ -17,6 +17,7 @@ #include #include #include +#include #include "AnalysisCompositeCut.h" #include "VarManager.h" @@ -7100,3 +7101,154 @@ AnalysisCompositeCut* o2::aod::dqcuts::ParseJSONAnalysisCompositeCut(T cut, cons return retCut; } + 
+//________________________________________________________________________________________________ +o2::aod::dqmlcuts::BdtScoreConfig o2::aod::dqmlcuts::GetBdtScoreCutsAndConfigFromJSON(const char* json) +{ + LOG(info) << "========================================== interpreting JSON for analysis cuts"; + LOG(info) << "JSON string: " << json; + + rapidjson::Document document; + rapidjson::ParseResult ok = document.Parse(json); + if (!ok) { + LOG(fatal) << "JSON parse error: " << rapidjson::GetParseErrorFunc(ok.Code()) << " (" << ok.Offset() << ")"; + return {}; // empty variant + } + + for (auto it = document.MemberBegin(); it != document.MemberEnd(); ++it) { + const auto& obj = it->value; + + if (!obj.HasMember("type")) { + LOG(fatal) << "Missing type (Binary/MultiClass)"; + return {}; + } + + TString typeStr = obj["type"].GetString(); + // int nClasses = (typeStr == "MultiClass") ? 3 : 1; + + std::vector namesInputFeatures; + if (obj.HasMember("inputFeatures") && obj["inputFeatures"].IsArray()) { + for (auto& feature : obj["inputFeatures"].GetArray()) { + namesInputFeatures.emplace_back(feature.GetString()); + } + } + + std::vector onnxFileNames; + if (obj.HasMember("modelFiles") && obj["modelFiles"].IsArray()) { + for (const auto& model : obj["modelFiles"].GetArray()) { + onnxFileNames.emplace_back(model.GetString()); + } + } + + // Cut storage + std::vector> ptBins; + std::vector> cutsMl; + std::vector cutDirs; + bool cutDirsFilled = false; + + for (auto member = obj.MemberBegin(); member != obj.MemberEnd(); ++member) { + TString key = member->name.GetString(); + if (!key.Contains("AddCut")) + continue; + + const auto& cut = member->value; + + if (!cut.HasMember("pTMin") || !cut.HasMember("pTMax")) { + LOG(fatal) << "Missing pTMin/pTMax in ML cut"; + return {}; + } + + double pTMin = cut["pTMin"].GetDouble(); + double pTMax = cut["pTMax"].GetDouble(); + ptBins.emplace_back(pTMin, pTMax); + + std::vector binCuts; + bool exclude = false; + + for (auto& sub : 
cut.GetObject()) { + TString subKey = sub.name.GetString(); + if (!subKey.Contains("AddMLCut")) + continue; + + const auto& mlcut = sub.value; + // const char* var = mlcut["var"].GetString(); + double cutVal = mlcut.HasMember("cut") ? mlcut["cut"].GetDouble() : 0.5; + exclude = mlcut.HasMember("exclude") ? mlcut["exclude"].GetBool() : false; + + binCuts.push_back(cutVal); + + if (!cutDirsFilled) { + cutDirs.push_back(exclude ? 1 : 0); + cutDirsFilled = true; + } + } + + cutsMl.push_back(binCuts); + } + + // bin edges + std::vector binsPt; + if (!ptBins.empty()) { + std::set binEdges; + for (auto& b : ptBins) + binEdges.insert(b.first); + binEdges.insert(ptBins.back().second); + binsPt = std::vector(binEdges.begin(), binEdges.end()); + } else { + LOG(fatal) << "No pT bins found in ML cuts"; + return {}; + } + + std::vector labelsPt, labelsClass; + for (size_t i = 0; i < cutsMl.size(); ++i) { + labelsPt.push_back(Form("pT%.1f", binsPt[i])); + } + for (size_t j = 0; j < cutsMl[0].size(); ++j) { + labelsClass.push_back(Form("cls%d", static_cast(j))); + } + + // Binary + if (typeStr == "Binary") { + dqmlcuts::BinaryBdtScoreConfig binaryCfg; + binaryCfg.inputFeatures = namesInputFeatures; + binaryCfg.onnxFiles = onnxFileNames; + binaryCfg.binsPt = binsPt; + binaryCfg.cutDirs = cutDirs; + binaryCfg.cutsMl = makeLabeledCutsMl(cutsMl, labelsPt, labelsClass); + + return binaryCfg; + + // MultiClass + } else if (typeStr == "MultiClass") { + dqmlcuts::MultiClassBdtScoreConfig multiCfg; + multiCfg.inputFeatures = namesInputFeatures; + multiCfg.onnxFiles = onnxFileNames; + multiCfg.binsPt = binsPt; + multiCfg.cutDirs = cutDirs; + multiCfg.cutsMl = makeLabeledCutsMl(cutsMl, labelsPt, labelsClass); + + return multiCfg; + } + + LOG(fatal) << "Unsupported classification type: " << typeStr; + return {}; + } + + return {}; +} + +o2::framework::LabeledArray o2::aod::dqmlcuts::makeLabeledCutsMl(const std::vector>& cuts, + const std::vector& labelsPt, + const std::vector& labelsClass) +{ 
+ const size_t nRows = cuts.size(); + const size_t nCols = cuts.empty() ? 0 : cuts[0].size(); + std::vector flat; + + for (const auto& row : cuts) { + flat.insert(flat.end(), row.begin(), row.end()); + } + + o2::framework::Array2D arr(flat.data(), nRows, nCols); + return o2::framework::LabeledArray(arr, labelsPt, labelsClass); +} diff --git a/PWGDQ/Core/CutsLibrary.h b/PWGDQ/Core/CutsLibrary.h index c6ad4caded2..9418197b75e 100644 --- a/PWGDQ/Core/CutsLibrary.h +++ b/PWGDQ/Core/CutsLibrary.h @@ -119,6 +119,32 @@ bool ValidateJSONAnalysisCompositeCut(T cut); template AnalysisCompositeCut* ParseJSONAnalysisCompositeCut(T key, const char* cutName); } // namespace dqcuts +namespace dqmlcuts +{ +struct BinaryBdtScoreConfig { + std::vector inputFeatures; + std::vector onnxFiles; + std::vector binsPt; + o2::framework::LabeledArray cutsMl; + std::vector cutDirs; +}; + +struct MultiClassBdtScoreConfig { + std::vector inputFeatures; + std::vector onnxFiles; + std::vector binsPt; + o2::framework::LabeledArray cutsMl; + std::vector cutDirs; +}; + +using BdtScoreConfig = std::variant; + +BdtScoreConfig GetBdtScoreCutsAndConfigFromJSON(const char* json); + +o2::framework::LabeledArray makeLabeledCutsMl(const std::vector>& cuts, + const std::vector& labelsPt, + const std::vector& labelsClass); +} // namespace dqmlcuts } // namespace o2::aod AnalysisCompositeCut* o2::aod::dqcuts::GetCompositeCut(const char* cutName); diff --git a/PWGDQ/Core/DQMlResponse.h b/PWGDQ/Core/DQMlResponse.h new file mode 100644 index 00000000000..ab150800cf7 --- /dev/null +++ b/PWGDQ/Core/DQMlResponse.h @@ -0,0 +1,213 @@ +// Copyright 2019-2020 CERN and copyright holders of ALICE O2. +// See https://alice-o2.web.cern.ch/copyright for details of the copyright holders. +// All rights not expressly granted are reserved. +// +// This software is distributed under the terms of the GNU General Public +// License v3 (GPL Version 3), copied verbatim in the file "COPYING". 
+// +// In applying this license CERN does not waive the privileges and immunities +// granted to it by virtue of its status as an Intergovernmental Organization +// or submit itself to any jurisdiction. +// +// Contact: jseo@cern.ch +// +// Class to compute the ML response for DQ-analysis selections +// + +#ifndef PWGDQ_CORE_DQMLRESPONSE_H_ +#define PWGDQ_CORE_DQMLRESPONSE_H_ + +#include "Tools/ML/MlResponse.h" + +#include +#include +#include +#include + +// Fill the map of available input features +// the key is the feature's name (std::string) +// the value is the corresponding value in EnumInputFeatures +#define FILL_MAP(FEATURE) \ + { \ + #FEATURE, static_cast(InputFeatures::FEATURE) \ + } + +namespace o2::analysis +{ + +enum class InputFeatures : uint8_t { // refer to DielectronsAll + fMass = 0, + fPt, + fEta, + fPhi, + fPt1, + fITSChi2NCl1, + fTPCNClsCR1, + fTPCNClsFound1, + fTPCChi2NCl1, + fDcaXY1, + fDcaZ1, + fTPCNSigmaEl1, + fTPCNSigmaPi1, + fTPCNSigmaPr1, + fTOFNSigmaEl1, + fTOFNSigmaPi1, + fTOFNSigmaPr1, + fPt2, + fITSChi2NCl2, + fTPCNClsCR2, + fTPCNClsFound2, + fTPCChi2NCl2, + fDcaXY2, + fDcaZ2, + fTPCNSigmaEl2, + fTPCNSigmaPi2, + fTPCNSigmaPr2, + fTOFNSigmaEl2, + fTOFNSigmaPi2, + fTOFNSigmaPr2, +}; + +static const std::map gFeatureNameMap = { + {InputFeatures::fMass, "fMass"}, + {InputFeatures::fPt, "fPt"}, + {InputFeatures::fEta, "fEta"}, + {InputFeatures::fPhi, "fPhi"}, + {InputFeatures::fPt1, "fPt1"}, + {InputFeatures::fITSChi2NCl1, "fITSChi2NCl1"}, + {InputFeatures::fTPCNClsCR1, "fTPCNClsCR1"}, + {InputFeatures::fTPCNClsFound1, "fTPCNClsFound1"}, + {InputFeatures::fTPCChi2NCl1, "fTPCChi2NCl1"}, + {InputFeatures::fDcaXY1, "fDcaXY1"}, + {InputFeatures::fDcaZ1, "fDcaZ1"}, + {InputFeatures::fTPCNSigmaEl1, "fTPCNSigmaEl1"}, + {InputFeatures::fTPCNSigmaPi1, "fTPCNSigmaPi1"}, + {InputFeatures::fTPCNSigmaPr1, "fTPCNSigmaPr1"}, + {InputFeatures::fTOFNSigmaEl1, "fTOFNSigmaEl1"}, + {InputFeatures::fTOFNSigmaPi1, "fTOFNSigmaPi1"}, + 
{InputFeatures::fTOFNSigmaPr1, "fTOFNSigmaPr1"}, + {InputFeatures::fPt2, "fPt2"}, + {InputFeatures::fITSChi2NCl2, "fITSChi2NCl2"}, + {InputFeatures::fTPCNClsCR2, "fTPCNClsCR2"}, + {InputFeatures::fTPCNClsFound2, "fTPCNClsFound2"}, + {InputFeatures::fTPCChi2NCl2, "fTPCChi2NCl2"}, + {InputFeatures::fDcaXY2, "fDcaXY2"}, + {InputFeatures::fDcaZ2, "fDcaZ2"}, + {InputFeatures::fTPCNSigmaEl2, "fTPCNSigmaEl2"}, + {InputFeatures::fTPCNSigmaPi2, "fTPCNSigmaPi2"}, + {InputFeatures::fTPCNSigmaPr2, "fTPCNSigmaPr2"}, + {InputFeatures::fTOFNSigmaEl2, "fTOFNSigmaEl2"}, + {InputFeatures::fTOFNSigmaPi2, "fTOFNSigmaPi2"}, + {InputFeatures::fTOFNSigmaPr2, "fTOFNSigmaPr2"}}; + +template +class DQMlResponse : public MlResponse +{ + public: + /// Default constructor + DQMlResponse() = default; + /// Default destructor + virtual ~DQMlResponse() = default; + + /// Method to get the input features vector needed for ML inference + /// \return inputFeatures vector + template + std::vector getInputFeatures(const T1& t1, + const T2& t2, + const TValues& fg) const + { + using Accessor = std::function; + static const std::unordered_map featureMap{ + {"fMass", [](auto const&, auto const&, auto const& v) { return v[VarManager::kMass]; }}, + {"fPt", [](auto const&, auto const&, auto const& v) { return v[VarManager::kPt]; }}, + {"fEta", [](auto const&, auto const&, auto const& v) { return v[VarManager::kEta]; }}, + {"fPhi", [](auto const&, auto const&, auto const& v) { return v[VarManager::kPhi]; }}, + + {"fPt1", [](auto const& t1, auto const&, auto const&) { return t1.pt(); }}, + {"fITSChi2NCl1", [](auto const& t1, auto const&, auto const&) { return t1.itsChi2NCl(); }}, + {"fTPCNClsCR1", [](auto const& t1, auto const&, auto const&) { return t1.tpcNClsCrossedRows(); }}, + {"fTPCNClsFound1", [](auto const& t1, auto const&, auto const&) { return t1.tpcNClsFound(); }}, + {"fTPCChi2NCl1", [](auto const& t1, auto const&, auto const&) { return t1.tpcChi2NCl(); }}, + {"fDcaXY1", [](auto const& t1, auto 
const&, auto const&) { return t1.dcaXY(); }}, + {"fDcaZ1", [](auto const& t1, auto const&, auto const&) { return t1.dcaZ(); }}, + {"fTPCNSigmaEl1", [](auto const& t1, auto const&, auto const&) { return t1.tpcNSigmaEl(); }}, + {"fTPCNSigmaPi1", [](auto const& t1, auto const&, auto const&) { return t1.tpcNSigmaPi(); }}, + {"fTPCNSigmaPr1", [](auto const& t1, auto const&, auto const&) { return t1.tpcNSigmaPr(); }}, + {"fTOFNSigmaEl1", [](auto const& t1, auto const&, auto const&) { return t1.tofNSigmaEl(); }}, + {"fTOFNSigmaPi1", [](auto const& t1, auto const&, auto const&) { return t1.tofNSigmaPi(); }}, + {"fTOFNSigmaPr1", [](auto const& t1, auto const&, auto const&) { return t1.tofNSigmaPr(); }}, + + {"fPt2", [](auto const&, auto const& t2, auto const&) { return t2.pt(); }}, + {"fITSChi2NCl2", [](auto const&, auto const& t2, auto const&) { return t2.itsChi2NCl(); }}, + {"fTPCNClsCR2", [](auto const&, auto const& t2, auto const&) { return t2.tpcNClsCrossedRows(); }}, + {"fTPCNClsFound2", [](auto const&, auto const& t2, auto const&) { return t2.tpcNClsFound(); }}, + {"fTPCChi2NCl2", [](auto const&, auto const& t2, auto const&) { return t2.tpcChi2NCl(); }}, + {"fDcaXY2", [](auto const&, auto const& t2, auto const&) { return t2.dcaXY(); }}, + {"fDcaZ2", [](auto const&, auto const& t2, auto const&) { return t2.dcaZ(); }}, + {"fTPCNSigmaEl2", [](auto const&, auto const& t2, auto const&) { return t2.tpcNSigmaEl(); }}, + {"fTPCNSigmaPi2", [](auto const&, auto const& t2, auto const&) { return t2.tpcNSigmaPi(); }}, + {"fTPCNSigmaPr2", [](auto const&, auto const& t2, auto const&) { return t2.tpcNSigmaPr(); }}, + {"fTOFNSigmaEl2", [](auto const&, auto const& t2, auto const&) { return t2.tofNSigmaEl(); }}, + {"fTOFNSigmaPi2", [](auto const&, auto const& t2, auto const&) { return t2.tofNSigmaPi(); }}, + {"fTOFNSigmaPr2", [](auto const&, auto const& t2, auto const&) { return t2.tofNSigmaPr(); }}}; + + std::vector dqInputFeatures; + 
dqInputFeatures.reserve(MlResponse::mCachedIndices.size()); + + for (auto idx : MlResponse::mCachedIndices) { + auto enumIdx = static_cast(idx); + const auto& name = gFeatureNameMap.at(enumIdx); + + auto acc = featureMap.find(name); + if (acc == featureMap.end()) { + LOG(error) << "Missing accessor for " << name; + continue; + } else { + dqInputFeatures.push_back(acc->second(t1, t2, fg)); + } + } + return dqInputFeatures; + } + + protected: + void setAvailableInputFeatures() + { + MlResponse::mAvailableInputFeatures = { + FILL_MAP(fMass), + FILL_MAP(fPt), + FILL_MAP(fEta), + FILL_MAP(fPhi), + FILL_MAP(fPt1), + FILL_MAP(fITSChi2NCl1), + FILL_MAP(fTPCNClsCR1), + FILL_MAP(fTPCNClsFound1), + FILL_MAP(fTPCChi2NCl1), + FILL_MAP(fDcaXY1), + FILL_MAP(fDcaZ1), + FILL_MAP(fTPCNSigmaEl1), + FILL_MAP(fTPCNSigmaPi1), + FILL_MAP(fTPCNSigmaPr1), + FILL_MAP(fTOFNSigmaEl1), + FILL_MAP(fTOFNSigmaPi1), + FILL_MAP(fTOFNSigmaPr1), + FILL_MAP(fPt2), + FILL_MAP(fITSChi2NCl2), + FILL_MAP(fTPCNClsCR2), + FILL_MAP(fTPCNClsFound2), + FILL_MAP(fTPCChi2NCl2), + FILL_MAP(fDcaXY2), + FILL_MAP(fDcaZ2), + FILL_MAP(fTPCNSigmaEl2), + FILL_MAP(fTPCNSigmaPi2), + FILL_MAP(fTPCNSigmaPr2), + FILL_MAP(fTOFNSigmaEl2), + FILL_MAP(fTOFNSigmaPi2), + FILL_MAP(fTOFNSigmaPr2)}; + } +}; + +} // namespace o2::analysis + +#undef FILL_MAP + +#endif // PWGDQ_CORE_DQMLRESPONSE_H_ diff --git a/PWGDQ/Core/VarManager.cxx b/PWGDQ/Core/VarManager.cxx index 912dc06740a..4bb8c419ad7 100644 --- a/PWGDQ/Core/VarManager.cxx +++ b/PWGDQ/Core/VarManager.cxx @@ -1131,6 +1131,12 @@ void VarManager::SetDefaultVarNames() fgVariableUnits[kS13] = "GeV^{2}/c^{4}"; fgVariableNames[kS23] = "m_{23}^{2}"; fgVariableUnits[kS23] = "GeV^{2}/c^{4}"; + fgVariableNames[kBdtBackground] = "kBdtBackground"; + fgVariableUnits[kBdtBackground] = " "; + fgVariableNames[kBdtPrompt] = "kBdtPrompt"; + fgVariableUnits[kBdtPrompt] = " "; + fgVariableNames[kBdtNonprompt] = "kBdtNonprompt"; + fgVariableUnits[kBdtNonprompt] = " "; // Set the variables short 
names map. This is needed for dynamic configuration via JSON files fgVarNamesMap["kNothing"] = kNothing; @@ -1770,4 +1776,7 @@ void VarManager::SetDefaultVarNames() fgVarNamesMap["kV24ME"] = kV24ME; fgVarNamesMap["kWV22ME"] = kWV22ME; fgVarNamesMap["kWV24ME"] = kWV24ME; + fgVarNamesMap["kBdtBackground"] = kBdtBackground; + fgVarNamesMap["kBdtPrompt"] = kBdtPrompt; + fgVarNamesMap["kBdtNonprompt"] = kBdtNonprompt; } diff --git a/PWGDQ/Core/VarManager.h b/PWGDQ/Core/VarManager.h index ec2ec55b83b..2acf33cf8e8 100644 --- a/PWGDQ/Core/VarManager.h +++ b/PWGDQ/Core/VarManager.h @@ -855,6 +855,11 @@ class VarManager : public TObject // deltaMass_jpsi = kPairMass - kPairMassDau +3.096900 kDeltaMass_jpsi, + // BDT score + kBdtBackground, + kBdtPrompt, + kBdtNonprompt, + kNVars }; // end of Variables enumeration @@ -1127,6 +1132,8 @@ class VarManager : public TObject static void FillDileptonTrackTrackVertexing(C const& collision, T1 const& lepton1, T1 const& lepton2, T1 const& track1, T1 const& track2, float* values); template static void FillZDC(const T& zdc, float* values = nullptr); + template + static void FillBdtScore(const T& bdtScore, float* values = nullptr); static void SetCalibrationObject(CalibObjects calib, TObject* obj) { @@ -5524,4 +5531,22 @@ float VarManager::calculatePhiV(T1 const& t1, T2 const& t2) return pairPhiV; } +template +void VarManager::FillBdtScore(T1 const& bdtScore, float* values) +{ + if (!values) { + values = fgValues; + } + + if (bdtScore.size() == 1) { + values[kBdtBackground] = bdtScore[0]; + } else if (bdtScore.size() == 3) { + values[kBdtBackground] = bdtScore[0]; + values[kBdtPrompt] = bdtScore[1]; + values[kBdtNonprompt] = bdtScore[2]; + } else { + LOG(warning) << "Unexpected number of BDT outputs: " << bdtScore.size(); + } +} + #endif // PWGDQ_CORE_VARMANAGER_H_ diff --git a/PWGDQ/Macros/bdtCut.json b/PWGDQ/Macros/bdtCut.json new file mode 100644 index 00000000000..5ff0ada64b4 --- /dev/null +++ b/PWGDQ/Macros/bdtCut.json @@ -0,0 +1,56 
@@ +{ + "TestCut": { + "type": "Binary", + "title": "MyBDTModel", + "inputFeatures": [ + "fPt1", + "fITSChi2NCl1", + "fTPCNClsCR1", + "fTPCNClsFound1", + "fTPCChi2NCl1", + "fDcaXY1", + "fDcaZ1", + "fTPCNSigmaEl1", + "fTPCNSigmaPi1", + "fTPCNSigmaPr1", + "fTOFNSigmaEl1", + "fTOFNSigmaPi1", + "fTOFNSigmaPr1", + "fPt2", + "fITSChi2NCl2", + "fTPCNClsCR2", + "fTPCNClsFound2", + "fTPCChi2NCl2", + "fDcaXY2", + "fDcaZ2", + "fTPCNSigmaEl2", + "fTPCNSigmaPi2", + "fTPCNSigmaPr2", + "fTOFNSigmaEl2", + "fTOFNSigmaPi2", + "fTOFNSigmaPr2" + ], + "modelFiles": [ + "cent_30_50_pt0_2_onnx.onnx", + "cent_30_50_pt2_20_onnx.onnx" + ], + "AddCut-pTBin1": { + "pTMin": 0, + "pTMax": 2, + "AddMLCut-background": { + "var": "kBdtBackground", + "cut": 0.5, + "exclude": false + } + }, + "AddCut-pTBin2": { + "pTMin": 2, + "pTMax": 20, + "AddMLCut-background": { + "var": "kBdtBackground", + "cut": 0.5, + "exclude": false + } + } + } +} diff --git a/PWGDQ/Macros/bdtCutMulti.json b/PWGDQ/Macros/bdtCutMulti.json new file mode 100644 index 00000000000..3042e767c11 --- /dev/null +++ b/PWGDQ/Macros/bdtCutMulti.json @@ -0,0 +1,74 @@ +{ + "TestCut": { + "type": "MultiClass", + "title": "MyBDTModel", + "inputFeatures": [ + "fITSChi2NCl1", + "fTPCNClsCR1", + "fTPCNClsFound1", + "fTPCChi2NCl1", + "fDcaXY1", + "fDcaZ1", + "fTPCNSigmaEl1", + "fTPCNSigmaPi1", + "fTPCNSigmaPr1", + "fTOFNSigmaEl1", + "fTOFNSigmaPi1", + "fTOFNSigmaPr1", + "fITSChi2NCl2", + "fTPCNClsCR2", + "fTPCNClsFound2", + "fTPCChi2NCl2", + "fDcaXY2", + "fDcaZ2", + "fTPCNSigmaEl2", + "fTPCNSigmaPi2", + "fTPCNSigmaPr2", + "fTOFNSigmaEl2", + "fTOFNSigmaPi2", + "fTOFNSigmaPr2" + ], + "modelFiles": [ + "cent_30_50_pt0_2_onnx.onnx", + "cent_30_50_pt2_20_onnx.onnx" + ], + "AddCut-pTBin1": { + "pTMin": 0, + "pTMax": 2, + "AddMLCut-background": { + "var": "kBdtBackground", + "cut": 0.5, + "exclude": true + }, + "AddMLCut-prompt": { + "var": "kBdtPrompt", + "cut": 0.5, + "exclude": false + }, + "AddMLCut-nonprompt": { + "var": "kBdtNonprompt", + 
"cut": 0.5, + "exclude": false + } + }, + "AddCut-pTBin2": { + "pTMin": 2, + "pTMax": 20, + "AddMLCut-background": { + "var": "kBdtBackground", + "cut": 0.5, + "exclude": true + }, + "AddMLCut-prompt": { + "var": "kBdtPrompt", + "cut": 0.5, + "exclude": false + }, + "AddMLCut-nonprompt": { + "var": "kBdtNonprompt", + "cut": 0.5, + "exclude": false + } + } + } +} diff --git a/PWGDQ/Tasks/CMakeLists.txt b/PWGDQ/Tasks/CMakeLists.txt index 5095140a2b8..c3bb38bf955 100644 --- a/PWGDQ/Tasks/CMakeLists.txt +++ b/PWGDQ/Tasks/CMakeLists.txt @@ -11,12 +11,12 @@ o2physics_add_dpl_workflow(table-reader SOURCES tableReader.cxx - PUBLIC_LINK_LIBRARIES O2::Framework O2Physics::AnalysisCore O2Physics::PWGDQCore + PUBLIC_LINK_LIBRARIES O2::Framework O2Physics::AnalysisCore O2Physics::PWGDQCore O2Physics::MLCore COMPONENT_NAME Analysis) o2physics_add_dpl_workflow(table-reader-with-assoc SOURCES tableReader_withAssoc.cxx - PUBLIC_LINK_LIBRARIES O2::Framework O2::DetectorsBase O2Physics::AnalysisCore O2Physics::AnalysisCCDB O2Physics::PWGDQCore + PUBLIC_LINK_LIBRARIES O2::Framework O2::DetectorsBase O2Physics::AnalysisCore O2Physics::AnalysisCCDB O2Physics::PWGDQCore O2Physics::MLCore COMPONENT_NAME Analysis) o2physics_add_dpl_workflow(efficiency @@ -127,4 +127,4 @@ o2physics_add_dpl_workflow(model-converter-event-extended o2physics_add_dpl_workflow(tag-and-probe SOURCES TagAndProbe.cxx PUBLIC_LINK_LIBRARIES O2::Framework O2::DetectorsBase O2Physics::AnalysisCore O2Physics::AnalysisCCDB O2Physics::PWGDQCore - COMPONENT_NAME Analysis) \ No newline at end of file + COMPONENT_NAME Analysis) diff --git a/PWGDQ/Tasks/tableReader.cxx b/PWGDQ/Tasks/tableReader.cxx index 888be5a6d67..03c0fb8ee1a 100644 --- a/PWGDQ/Tasks/tableReader.cxx +++ b/PWGDQ/Tasks/tableReader.cxx @@ -20,6 +20,7 @@ #include "PWGDQ/Core/MixingLibrary.h" #include "PWGDQ/Core/VarManager.h" #include "PWGDQ/DataModel/ReducedInfoTables.h" +#include "PWGDQ/Core/DQMlResponse.h" #include "Common/CCDB/EventSelectionParams.h" @@ 
-1054,6 +1055,12 @@ struct AnalysisSameEventPairing { Configurable fCenterMassEnergy{"energy", 13600, "Center of mass energy in GeV"}; Configurable fConfigCumulants{"cfgCumulants", false, "If true, fill Cumulants with Weights different than 0"}; Configurable fConfigAddJSONHistograms{"cfgAddJSONHistograms", "", "Histograms in JSON format"}; + // ML inference + Configurable applyBDT{"applyBDT", false, "Flag to apply ML selections"}; + Configurable fConfigBdtCutsJSON{"fConfigBdtCutsJSON", "", "Additional list of BDT cuts in JSON format"}; + Configurable> modelPathsCCDB{"modelPathsCCDB", std::vector{"Users/j/jseo/ML/PbPbPsi/default/"}, "Paths of models on CCDB"}; + Configurable timestampCCDB{"timestampCCDB", -1, "timestamp of the ONNX file for ML model used to query in CCDB"}; + Configurable loadModelsFromCCDB{"loadModelsFromCCDB", false, "Flag to enable or disable the loading of models from CCDB"}; // Configurables to create output tree (flat tables or minitree) struct : ConfigurableGroup { @@ -1071,6 +1078,10 @@ struct AnalysisSameEventPairing { HistogramManager* fHistMan; + o2::analysis::DQMlResponse dqMlResponse; + std::vector outputMlPsi2ee = {}; // TODO: check this is needed or not + o2::ccdb::CcdbApi ccdbApi; + // NOTE: The track filter produced by the barrel track selection contain a number of electron cut decisions and one last cut for hadrons used in the // dilepton - hadron task downstream. 
So the bit mask is required to select pairs just based on the electron cuts // TODO: provide as Configurable the list and names of the cuts which should be used in pairing @@ -1121,7 +1132,47 @@ struct AnalysisSameEventPairing { } } - if (context.mOptions.get("processDecayToEESkimmed") || context.mOptions.get("processDecayToEESkimmedNoTwoProngFitter") || context.mOptions.get("processDecayToEESkimmedWithCov") || context.mOptions.get("processDecayToEESkimmedWithCovNoTwoProngFitter") || context.mOptions.get("processDecayToEEVertexingSkimmed") || context.mOptions.get("processVnDecayToEESkimmed") || context.mOptions.get("processDecayToEEPrefilterSkimmed") || context.mOptions.get("processDecayToEEPrefilterSkimmedNoTwoProngFitter") || context.mOptions.get("processDecayToEESkimmedWithColl") || context.mOptions.get("processDecayToEESkimmedWithCollNoTwoProngFitter") || context.mOptions.get("processDecayToPiPiSkimmed") || context.mOptions.get("processAllSkimmed")) { + if (applyBDT) { + // BDT cuts via JSON + std::vector binsPtMl; + o2::framework::LabeledArray cutsMl; + std::vector cutDirMl; + int nClassesMl = 1; // 1 for binary BDT, 3 for multiclass BDT + std::vector namesInputFeatures; + std::vector onnxFileNames; + + auto config = o2::aod::dqmlcuts::GetBdtScoreCutsAndConfigFromJSON(fConfigBdtCutsJSON.value.c_str()); + + if (std::holds_alternative(config)) { + auto& cfg = std::get(config); + binsPtMl = cfg.binsPt; + nClassesMl = 1; + cutsMl = cfg.cutsMl; + cutDirMl = cfg.cutDirs; + namesInputFeatures = cfg.inputFeatures; + onnxFileNames = cfg.onnxFiles; + } else { + auto& cfg = std::get(config); + binsPtMl = cfg.binsPt; + nClassesMl = 3; + cutsMl = cfg.cutsMl; + cutDirMl = cfg.cutDirs; + namesInputFeatures = cfg.inputFeatures; + onnxFileNames = cfg.onnxFiles; + } + + dqMlResponse.configure(binsPtMl, cutsMl, cutDirMl, nClassesMl); + if (loadModelsFromCCDB) { + ccdbApi.init(ccdburl); + dqMlResponse.setModelPathsCCDB(onnxFileNames, ccdbApi, modelPathsCCDB, timestampCCDB); + } 
else { + dqMlResponse.setModelPathsLocal(onnxFileNames); + } + dqMlResponse.cacheInputFeaturesIndices(namesInputFeatures); + dqMlResponse.init(); + } + + if (context.mOptions.get("processDecayToEESkimmed") || context.mOptions.get("processDecayToEESkimmedNoTwoProngFitter") || context.mOptions.get("processDecayToEESkimmedWithCov") || context.mOptions.get("processDecayToEESkimmedWithCovNoTwoProngFitter") || context.mOptions.get("processDecayToEEVertexingSkimmed") || context.mOptions.get("processVnDecayToEESkimmed") || context.mOptions.get("processDecayToEEPrefilterSkimmed") || context.mOptions.get("processDecayToEEPrefilterSkimmedNoTwoProngFitter") || context.mOptions.get("processDecayToEESkimmedWithColl") || context.mOptions.get("processDecayToEESkimmedWithCollNoTwoProngFitter") || context.mOptions.get("processDecayToPiPiSkimmed") || context.mOptions.get("processAllSkimmed") || context.mOptions.get("processDecayToEESkimmedBDT")) { TString cutNames = fConfigTrackCuts.value; if (!cutNames.IsNull()) { // if track cuts std::unique_ptr objArray(cutNames.Tokenize(",")); @@ -1317,6 +1368,8 @@ struct AnalysisSameEventPairing { dileptonMiniTree.reserve(1); } + bool isSelectedBDT = false; + if (fConfigMultDimuons.value) { uint32_t mult_dimuons = 0; @@ -1395,6 +1448,22 @@ struct AnalysisSameEventPairing { } } if constexpr ((TPairType == pairTypeEE) && (TTrackFillMap & VarManager::ObjTypes::ReducedTrackBarrelPID) > 0) { + if (applyBDT) { + std::vector dqInputFeatures = dqMlResponse.getInputFeatures(t1, t2, VarManager::fgValues); + + if (dqInputFeatures.empty()) { + LOG(fatal) << "Input features for ML selection are empty! 
Please check your configuration."; + return; + } + + // isSelectedBDT = dqMlResponse.isSelectedMl(dqInputFeatures, VarManager::fgValues[VarManager::kPt]); + isSelectedBDT = dqMlResponse.isSelectedMl(dqInputFeatures, VarManager::fgValues[VarManager::kPt], outputMlPsi2ee); + VarManager::FillBdtScore(outputMlPsi2ee); // TODO: check if this is needed or not + } + + if (applyBDT && !isSelectedBDT) + continue; + if (fConfigFlatTables.value) { dielectronAllList(VarManager::fgValues[VarManager::kMass], VarManager::fgValues[VarManager::kPt], VarManager::fgValues[VarManager::kEta], VarManager::fgValues[VarManager::kPhi], t1.sign() + t2.sign(), dileptonFilterMap, dileptonMcDecision, t1.pt(), t1.eta(), t1.phi(), t1.itsClusterMap(), t1.itsChi2NCl(), t1.tpcNClsCrossedRows(), t1.tpcNClsFound(), t1.tpcChi2NCl(), t1.dcaXY(), t1.dcaZ(), t1.tpcSignal(), t1.tpcNSigmaEl(), t1.tpcNSigmaPi(), t1.tpcNSigmaPr(), t1.beta(), t1.tofNSigmaEl(), t1.tofNSigmaPi(), t1.tofNSigmaPr(), @@ -1468,6 +1537,9 @@ struct AnalysisSameEventPairing { } } + if (applyBDT && !isSelectedBDT) + continue; + int iCut = 0; for (int icut = 0; icut < ncuts; icut++) { if (twoTrackFilter & (static_cast(1) << icut)) { @@ -1606,6 +1678,13 @@ struct AnalysisSameEventPairing { VarManager::FillEvent(event, VarManager::fgValues); runSameEventPairing(event, tracks, tracks); } + void processDecayToEESkimmedBDT(soa::Filtered::iterator const& event, soa::Filtered const& tracks) + { + // Reset the fValues array + VarManager::ResetValues(0, VarManager::kNVars); + VarManager::FillEvent(event, VarManager::fgValues); + runSameEventPairing(event, tracks, tracks); + } void processDecayToMuMuSkimmed(soa::Filtered::iterator const& event, soa::Filtered const& muons) { // Reset the fValues array @@ -1710,6 +1789,7 @@ struct AnalysisSameEventPairing { PROCESS_SWITCH(AnalysisSameEventPairing, processDecayToEEPrefilterSkimmedNoTwoProngFitter, "Run electron-electron pairing, with skimmed tracks and prefilter from AnalysisPrefilterSelection but 
no two prong fitter", false); PROCESS_SWITCH(AnalysisSameEventPairing, processDecayToEESkimmedWithColl, "Run electron-electron pairing, with skimmed tracks and with collision information", false); PROCESS_SWITCH(AnalysisSameEventPairing, processDecayToEESkimmedWithCollNoTwoProngFitter, "Run electron-electron pairing, with skimmed tracks and with collision information but no two prong fitter", false); + PROCESS_SWITCH(AnalysisSameEventPairing, processDecayToEESkimmedBDT, "Run electron-electron pairing, with skimmed tracks and BDT selection", false); PROCESS_SWITCH(AnalysisSameEventPairing, processDecayToMuMuSkimmed, "Run muon-muon pairing, with skimmed muons", false); PROCESS_SWITCH(AnalysisSameEventPairing, processDecayToMuMuSkimmedWithMult, "Run muon-muon pairing, with skimmed muons and multiplicity", false); PROCESS_SWITCH(AnalysisSameEventPairing, processDecayToMuMuVertexingSkimmed, "Run muon-muon pairing and vertexing, with skimmed muons", false); diff --git a/PWGDQ/Tasks/tableReader_withAssoc.cxx b/PWGDQ/Tasks/tableReader_withAssoc.cxx index 139476d3b71..fd3d1eb2ec7 100644 --- a/PWGDQ/Tasks/tableReader_withAssoc.cxx +++ b/PWGDQ/Tasks/tableReader_withAssoc.cxx @@ -21,6 +21,7 @@ #include "PWGDQ/Core/MixingLibrary.h" #include "PWGDQ/Core/VarManager.h" #include "PWGDQ/DataModel/ReducedInfoTables.h" +#include "PWGDQ/Core/DQMlResponse.h" #include "Common/CCDB/EventSelectionParams.h" #include "Common/Core/TableHelper.h" @@ -1246,6 +1247,14 @@ struct AnalysisSameEventPairing { Configurable propTrack{"cfgPropTrack", true, "Propgate tracks to associated collision to recalculate DCA and momentum vector"}; Configurable useRemoteCollisionInfo{"cfgUseRemoteCollisionInfo", false, "Use remote collision information from CCDB"}; } fConfigOptions; + struct : ConfigurableGroup { + Configurable applyBDT{"applyBDT", false, "Flag to apply ML selections"}; + Configurable fConfigBdtCutsJSON{"fConfigBdtCutsJSON", "", "Additional list of BDT cuts in JSON format"}; + + Configurable> 
modelPathsCCDB{"modelPathsCCDB", std::vector{"Users/j/jseo/ML/PbPbPsi/default/"}, "Paths of models on CCDB"}; + Configurable timestampCCDB{"timestampCCDB", -1, "timestamp of the ONNX file for ML model used to query in CCDB"}; + Configurable loadModelsFromCCDB{"loadModelsFromCCDB", false, "Flag to enable or disable the loading of models from CCDB"}; + } fConfigML; Service fCCDB; o2::ccdb::CcdbApi fCCDBApi; @@ -1254,6 +1263,9 @@ struct AnalysisSameEventPairing { HistogramManager* fHistMan; + o2::analysis::DQMlResponse dqMlResponse; + std::vector outputMlPsi2ee = {}; // TODO: check this is needed or not + // keep histogram class names in maps, so we don't have to buld their names in the pair loops std::map> fTrackHistNames; std::map> fMuonHistNames; @@ -1281,7 +1293,7 @@ struct AnalysisSameEventPairing { void init(o2::framework::InitContext& context) { LOG(info) << "Starting initialization of AnalysisSameEventPairing (idstoreh)"; - fEnableBarrelHistos = context.mOptions.get("processAllSkimmed") || context.mOptions.get("processBarrelOnlySkimmed") || context.mOptions.get("processBarrelOnlyWithCollSkimmed") || context.mOptions.get("processBarrelOnlySkimmedNoCov") || context.mOptions.get("processBarrelOnlySkimmedNoCovWithMultExtra") || context.mOptions.get("processBarrelOnlyWithQvectorCentrSkimmedNoCov"); + fEnableBarrelHistos = context.mOptions.get("processAllSkimmed") || context.mOptions.get("processBarrelOnlySkimmed") || context.mOptions.get("processBarrelOnlyWithCollSkimmed") || context.mOptions.get("processBarrelOnlySkimmedNoCov") || context.mOptions.get("processBarrelOnlySkimmedNoCovWithMultExtra") || context.mOptions.get("processBarrelOnlyWithQvectorCentrSkimmedNoCov") || context.mOptions.get("processBarrelOnlySkimmedBDT"); fEnableBarrelMixingHistos = context.mOptions.get("processMixingAllSkimmed") || context.mOptions.get("processMixingBarrelSkimmed"); fEnableMuonHistos = context.mOptions.get("processAllSkimmed") || context.mOptions.get("processMuonOnlySkimmed") || 
context.mOptions.get("processMuonOnlySkimmedMultExtra") || context.mOptions.get("processMixingMuonSkimmed"); fEnableMuonMixingHistos = context.mOptions.get("processMixingAllSkimmed") || context.mOptions.get("processMixingMuonSkimmed"); @@ -1321,6 +1333,46 @@ struct AnalysisSameEventPairing { objArrayMuonCuts = muonCutsStr.Tokenize(","); } + if (fConfigML.applyBDT) { + // BDT cuts via JSON + std::vector binsPtMl; + o2::framework::LabeledArray cutsMl; + std::vector cutDirMl; + int nClassesMl = 1; // 1 for binary BDT, 3 for multiclass BDT + std::vector namesInputFeatures; + std::vector onnxFileNames; + + auto config = o2::aod::dqmlcuts::GetBdtScoreCutsAndConfigFromJSON(fConfigML.fConfigBdtCutsJSON.value.c_str()); + + if (std::holds_alternative(config)) { + auto& cfg = std::get(config); + binsPtMl = cfg.binsPt; + nClassesMl = 1; + cutsMl = cfg.cutsMl; + cutDirMl = cfg.cutDirs; + namesInputFeatures = cfg.inputFeatures; + onnxFileNames = cfg.onnxFiles; + } else { + auto& cfg = std::get(config); + binsPtMl = cfg.binsPt; + nClassesMl = 3; + cutsMl = cfg.cutsMl; + cutDirMl = cfg.cutDirs; + namesInputFeatures = cfg.inputFeatures; + onnxFileNames = cfg.onnxFiles; + } + + dqMlResponse.configure(binsPtMl, cutsMl, cutDirMl, nClassesMl); + if (fConfigML.loadModelsFromCCDB) { + fCCDBApi.init(fConfigCCDB.url); + dqMlResponse.setModelPathsCCDB(onnxFileNames, fCCDBApi, fConfigML.modelPathsCCDB, fConfigML.timestampCCDB); + } else { + dqMlResponse.setModelPathsLocal(onnxFileNames); + } + dqMlResponse.cacheInputFeaturesIndices(namesInputFeatures); + dqMlResponse.init(); + } + // get the barrel track selection cuts string tempCuts; getTaskOptionValue(context, "analysis-track-selection", "cfgTrackCuts", tempCuts, false); @@ -1627,6 +1679,7 @@ struct AnalysisSameEventPairing { constexpr bool eventHasQvector = ((TEventFillMap & VarManager::ObjTypes::ReducedEventQvector) > 0); constexpr bool eventHasQvectorCentr = ((TEventFillMap & VarManager::ObjTypes::CollisionQvect) > 0); constexpr bool 
trackHasCov = ((TTrackFillMap & VarManager::ObjTypes::TrackCov) > 0 || (TTrackFillMap & VarManager::ObjTypes::ReducedTrackBarrelCov) > 0); + bool isSelectedBDT = false; for (auto& event : events) { if (!event.isEventSelected_bit(0)) { @@ -1707,6 +1760,22 @@ struct AnalysisSameEventPairing { if constexpr (trackHasCov && TTwoProngFitter) { dielectronsExtraList(t1.globalIndex(), t2.globalIndex(), VarManager::fgValues[VarManager::kVertexingTauzProjected], VarManager::fgValues[VarManager::kVertexingLzProjected], VarManager::fgValues[VarManager::kVertexingLxyProjected]); if constexpr ((TTrackFillMap & VarManager::ObjTypes::ReducedTrackBarrelPID) > 0) { + if (fConfigML.applyBDT) { + std::vector dqInputFeatures = dqMlResponse.getInputFeatures(t1, t2, VarManager::fgValues); + + if (dqInputFeatures.empty()) { + LOG(fatal) << "Input features for ML selection are empty! Please check your configuration."; + return; + } + + // isSelectedBDT = dqMlResponse.isSelectedMl(dqInputFeatures, VarManager::fgValues[VarManager::kPt]); + isSelectedBDT = dqMlResponse.isSelectedMl(dqInputFeatures, VarManager::fgValues[VarManager::kPt], outputMlPsi2ee); + VarManager::FillBdtScore(outputMlPsi2ee); // TODO: check if this is needed or not + } + + if (fConfigML.applyBDT && !isSelectedBDT) + continue; + if (fConfigOptions.flatTables.value) { dielectronAllList(VarManager::fgValues[VarManager::kMass], VarManager::fgValues[VarManager::kPt], VarManager::fgValues[VarManager::kEta], VarManager::fgValues[VarManager::kPhi], t1.sign() + t2.sign(), twoTrackFilter, dileptonMcDecision, t1.pt(), t1.eta(), t1.phi(), t1.itsClusterMap(), t1.itsChi2NCl(), t1.tpcNClsCrossedRows(), t1.tpcNClsFound(), t1.tpcChi2NCl(), t1.dcaXY(), t1.dcaZ(), t1.tpcSignal(), t1.tpcNSigmaEl(), t1.tpcNSigmaPi(), t1.tpcNSigmaPr(), t1.beta(), t1.tofNSigmaEl(), t1.tofNSigmaPi(), t1.tofNSigmaPr(), @@ -1831,6 +1900,10 @@ struct AnalysisSameEventPairing { bool isLeg1Ambi = false; bool isLeg2Ambi = false; bool isAmbiExtra = false; + + if 
(fConfigML.applyBDT && !isSelectedBDT) + continue; + for (int icut = 0; icut < ncuts; icut++) { if (twoTrackFilter & (static_cast(1) << icut)) { isAmbiInBunch = (twoTrackFilter & (static_cast(1) << 28)) || (twoTrackFilter & (static_cast(1) << 29)); @@ -2152,6 +2225,13 @@ struct AnalysisSameEventPairing { runSameEventPairing(events, trackAssocsPerCollision, barrelAssocs, barrelTracks); } + void processBarrelOnlySkimmedBDT(MyEventsVtxCovSelected const& events, + soa::Join const& barrelAssocs, + MyBarrelTracksWithCovWithAmbiguities const& barrelTracks) + { + runSameEventPairing(events, trackAssocsPerCollision, barrelAssocs, barrelTracks); + } + void processMuonOnlySkimmed(MyEventsVtxCovSelected const& events, soa::Join const& muonAssocs, MyMuonTracksWithCovWithAmbiguities const& muons) { @@ -2195,6 +2275,7 @@ struct AnalysisSameEventPairing { PROCESS_SWITCH(AnalysisSameEventPairing, processBarrelOnlySkimmedNoCov, "Run barrel only pairing (no covariances), with skimmed tracks and with collision information", false); PROCESS_SWITCH(AnalysisSameEventPairing, processBarrelOnlySkimmedNoCovWithMultExtra, "Run barrel only pairing (no covariances), with skimmed tracks, with collision information, with MultsExtra", false); PROCESS_SWITCH(AnalysisSameEventPairing, processBarrelOnlyWithQvectorCentrSkimmedNoCov, "Run barrel only pairing (no covariances), with skimmed tracks, with Qvector from central framework", false); + PROCESS_SWITCH(AnalysisSameEventPairing, processBarrelOnlySkimmedBDT, "Run electron-electron pairing, with skimmed tracks and BDT selection", false); PROCESS_SWITCH(AnalysisSameEventPairing, processMuonOnlySkimmed, "Run muon only pairing, with skimmed tracks", false); PROCESS_SWITCH(AnalysisSameEventPairing, processMuonOnlySkimmedMultExtra, "Run muon only pairing, with skimmed tracks", false); PROCESS_SWITCH(AnalysisSameEventPairing, processMixingAllSkimmed, "Run all types of mixed pairing, with skimmed tracks/muons", false); From 
9d6591d5a48c49eab19e9f8ad4196ce9c61fcf20 Mon Sep 17 00:00:00 2001 From: Jinjoo Seo Date: Thu, 17 Jul 2025 17:31:12 +0200 Subject: [PATCH 2/7] linter error resolve --- PWGDQ/Core/CutsLibrary.cxx | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/PWGDQ/Core/CutsLibrary.cxx b/PWGDQ/Core/CutsLibrary.cxx index 9bc9fc58a3e..08c9e32bceb 100644 --- a/PWGDQ/Core/CutsLibrary.cxx +++ b/PWGDQ/Core/CutsLibrary.cxx @@ -7128,7 +7128,7 @@ o2::aod::dqmlcuts::BdtScoreConfig o2::aod::dqmlcuts::GetBdtScoreCutsAndConfigFro std::vector namesInputFeatures; if (obj.HasMember("inputFeatures") && obj["inputFeatures"].IsArray()) { - for (auto& feature : obj["inputFeatures"].GetArray()) { + for (const auto& feature : obj["inputFeatures"].GetArray()) { namesInputFeatures.emplace_back(feature.GetString()); } } @@ -7165,7 +7165,7 @@ o2::aod::dqmlcuts::BdtScoreConfig o2::aod::dqmlcuts::GetBdtScoreCutsAndConfigFro std::vector binCuts; bool exclude = false; - for (auto& sub : cut.GetObject()) { + for (const auto& sub : cut.GetObject()) { TString subKey = sub.name.GetString(); if (!subKey.Contains("AddMLCut")) continue; @@ -7190,7 +7190,7 @@ o2::aod::dqmlcuts::BdtScoreConfig o2::aod::dqmlcuts::GetBdtScoreCutsAndConfigFro std::vector binsPt; if (!ptBins.empty()) { std::set binEdges; - for (auto& b : ptBins) + for (const auto& b : ptBins) binEdges.insert(b.first); binEdges.insert(ptBins.back().second); binsPt = std::vector(binEdges.begin(), binEdges.end()); From 2549ae29f833d28262b6b6f8e31be4fbb2f8a3b5 Mon Sep 17 00:00:00 2001 From: Jseo Date: Mon, 21 Jul 2025 13:05:21 +0200 Subject: [PATCH 3/7] clang resolve --- PWGDQ/Core/CutsLibrary.cxx | 14 +++++++++----- PWGDQ/Core/DQMlResponse.h | 7 +++---- 2 files changed, 12 insertions(+), 9 deletions(-) diff --git a/PWGDQ/Core/CutsLibrary.cxx b/PWGDQ/Core/CutsLibrary.cxx index 08c9e32bceb..96994e89621 100644 --- a/PWGDQ/Core/CutsLibrary.cxx +++ b/PWGDQ/Core/CutsLibrary.cxx @@ -12,14 +12,18 @@ // Contact: iarsene@cern.ch, 
i.c.arsene@fys.uio.no // #include "PWGDQ/Core/CutsLibrary.h" -#include + +#include "AnalysisCompositeCut.h" +#include "VarManager.h" + #include -#include -#include + +#include + #include #include -#include "AnalysisCompositeCut.h" -#include "VarManager.h" +#include +#include using std::cout; using std::endl; diff --git a/PWGDQ/Core/DQMlResponse.h b/PWGDQ/Core/DQMlResponse.h index ab150800cf7..361e0c1fad7 100644 --- a/PWGDQ/Core/DQMlResponse.h +++ b/PWGDQ/Core/DQMlResponse.h @@ -27,10 +27,9 @@ // Fill the map of available input features // the key is the feature's name (std::string) // the value is the corresponding value in EnumInputFeatures -#define FILL_MAP(FEATURE) \ - { \ - #FEATURE, static_cast(InputFeatures::FEATURE) \ - } +#define FILL_MAP(FEATURE) \ + { \ + #FEATURE, static_cast(InputFeatures::FEATURE)} namespace o2::analysis { From 03a38449efaf006419b0ec3ad28a3fd3da1f0209 Mon Sep 17 00:00:00 2001 From: Jseo Date: Mon, 21 Jul 2025 13:19:56 +0200 Subject: [PATCH 4/7] new PR --- PWGDQ/Core/DQMlResponse.h | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/PWGDQ/Core/DQMlResponse.h b/PWGDQ/Core/DQMlResponse.h index 361e0c1fad7..ab150800cf7 100644 --- a/PWGDQ/Core/DQMlResponse.h +++ b/PWGDQ/Core/DQMlResponse.h @@ -27,9 +27,10 @@ // Fill the map of available input features // the key is the feature's name (std::string) // the value is the corresponding value in EnumInputFeatures -#define FILL_MAP(FEATURE) \ - { \ - #FEATURE, static_cast(InputFeatures::FEATURE)} +#define FILL_MAP(FEATURE) \ + { \ + #FEATURE, static_cast(InputFeatures::FEATURE) \ + } namespace o2::analysis { From 5707bee04e788fa1767b4e6e7be8dfd757481119 Mon Sep 17 00:00:00 2001 From: Jseo Date: Mon, 4 Aug 2025 15:14:52 +0200 Subject: [PATCH 5/7] Address comments --- PWGDQ/Core/CutsLibrary.cxx | 196 +++++++++++----- PWGDQ/Core/CutsLibrary.h | 17 +- PWGDQ/Core/DQMlResponse.h | 314 ++++++++++++++------------ PWGDQ/Macros/bdtCut.json | 113 +++++---- PWGDQ/Macros/bdtCutMulti.json | 
173 +++++++++----- PWGDQ/Tasks/tableReader.cxx | 53 +++-- PWGDQ/Tasks/tableReader_withAssoc.cxx | 54 +++-- 7 files changed, 581 insertions(+), 339 deletions(-) diff --git a/PWGDQ/Core/CutsLibrary.cxx b/PWGDQ/Core/CutsLibrary.cxx index 96994e89621..6c26bfcac01 100644 --- a/PWGDQ/Core/CutsLibrary.cxx +++ b/PWGDQ/Core/CutsLibrary.cxx @@ -7109,140 +7109,202 @@ AnalysisCompositeCut* o2::aod::dqcuts::ParseJSONAnalysisCompositeCut(T cut, cons //________________________________________________________________________________________________ o2::aod::dqmlcuts::BdtScoreConfig o2::aod::dqmlcuts::GetBdtScoreCutsAndConfigFromJSON(const char* json) { - LOG(info) << "========================================== interpreting JSON for analysis cuts"; + LOG(info) << "========================================== interpreting JSON for ML analysis cuts"; + if (!json) { + LOG(fatal) << "JSON config string is null!"; + return {}; + } LOG(info) << "JSON string: " << json; rapidjson::Document document; + + // Check that the json is parsed correctly rapidjson::ParseResult ok = document.Parse(json); if (!ok) { LOG(fatal) << "JSON parse error: " << rapidjson::GetParseErrorFunc(ok.Code()) << " (" << ok.Offset() << ")"; - return {}; // empty variant + return {}; } for (auto it = document.MemberBegin(); it != document.MemberEnd(); ++it) { const auto& obj = it->value; + // Classification type if (!obj.HasMember("type")) { LOG(fatal) << "Missing type (Binary/MultiClass)"; return {}; } - TString typeStr = obj["type"].GetString(); - // int nClasses = (typeStr == "MultiClass") ? 
3 : 1; + if (typeStr != "Binary" && typeStr != "MultiClass") { + LOG(fatal) << "Unsupported classification type: " << typeStr; + return {}; + } + // Input features + if (!obj.HasMember("inputFeatures") || !obj["inputFeatures"].IsArray()) { + LOG(fatal) << "Missing inputFeatures member or array"; + return {}; + } std::vector namesInputFeatures; - if (obj.HasMember("inputFeatures") && obj["inputFeatures"].IsArray()) { - for (const auto& feature : obj["inputFeatures"].GetArray()) { - namesInputFeatures.emplace_back(feature.GetString()); - } + for (const auto& feature : obj["inputFeatures"].GetArray()) { + namesInputFeatures.emplace_back(feature.GetString()); + LOG(debug) << "Input features: " << feature.GetString(); } + // Model files + if (!obj.HasMember("modelFiles") || !obj["modelFiles"].IsArray()) { + LOG(fatal) << "Missing modelFiles member or array"; + return {}; + } std::vector onnxFileNames; - if (obj.HasMember("modelFiles") && obj["modelFiles"].IsArray()) { - for (const auto& model : obj["modelFiles"].GetArray()) { - onnxFileNames.emplace_back(model.GetString()); - } + for (const auto& model : obj["modelFiles"].GetArray()) { + onnxFileNames.emplace_back(model.GetString()); + LOG(debug) << "Model Files: " << model.GetString() << " "; + } + + // Centrality estimation type + if (!obj.HasMember("cent") || !obj["cent"].IsString()) { + LOG(fatal) << "Missing cent member"; + return {}; + } + std::string cent = obj["cent"].GetString(); + LOG(debug) << "Centrality type: " << cent; + if (cent != "kCentFT0C" && cent != "kCentFT0A" && cent != "kCentFT0M") { + LOG(fatal) << "Unsupported centrality type: " << cent; + return {}; } // Cut storage + std::vector> centBins; std::vector> ptBins; std::vector> cutsMl; std::vector cutDirs; + std::vector labelsFlatBin; bool cutDirsFilled = false; - for (auto member = obj.MemberBegin(); member != obj.MemberEnd(); ++member) { - TString key = member->name.GetString(); - if (!key.Contains("AddCut")) + for (auto centMember = 
obj.MemberBegin(); centMember != obj.MemberEnd(); ++centMember) { + TString centKey = centMember->name.GetString(); + if (!centKey.Contains("AddCentCut")) continue; - const auto& cut = member->value; + const auto& centCut = centMember->value; - if (!cut.HasMember("pTMin") || !cut.HasMember("pTMax")) { - LOG(fatal) << "Missing pTMin/pTMax in ML cut"; + // Centrality info + if (!centCut.HasMember("centMin") || !centCut.HasMember("centMax")) { + LOG(fatal) << "Missing centMin/centMax in " << centKey; return {}; } + double centMin = centCut["centMin"].GetDouble(); + double centMax = centCut["centMax"].GetDouble(); - double pTMin = cut["pTMin"].GetDouble(); - double pTMax = cut["pTMax"].GetDouble(); - ptBins.emplace_back(pTMin, pTMax); + for (auto ptMember = centCut.MemberBegin(); ptMember != centCut.MemberEnd(); ++ptMember) { + TString ptKey = ptMember->name.GetString(); + if (!ptKey.Contains("AddPtCut")) + continue; - std::vector binCuts; - bool exclude = false; + const auto& ptCut = ptMember->value; - for (const auto& sub : cut.GetObject()) { - TString subKey = sub.name.GetString(); - if (!subKey.Contains("AddMLCut")) - continue; + // Pt info + if (!ptCut.HasMember("pTMin") || !ptCut.HasMember("pTMax")) { + LOG(fatal) << "Missing pTMin/pTMax in " << ptKey; + return {}; + } + + double ptMin = ptCut["pTMin"].GetDouble(); + double ptMax = ptCut["pTMax"].GetDouble(); + + std::vector binCuts; + bool exclude = false; + + for (auto mlMember = ptCut.MemberBegin(); mlMember != ptCut.MemberEnd(); ++mlMember) { + TString mlKey = mlMember->name.GetString(); + if (!mlKey.Contains("AddMLCut")) + continue; - const auto& mlcut = sub.value; - // const char* var = mlcut["var"].GetString(); - double cutVal = mlcut.HasMember("cut") ? mlcut["cut"].GetDouble() : 0.5; - exclude = mlcut.HasMember("exclude") ? 
mlcut["exclude"].GetBool() : false; + const auto& mlcut = mlMember->value; - binCuts.push_back(cutVal); + if (!mlcut.HasMember("cut")) { + LOG(fatal) << "Missing cut (score) in " << mlKey; + return {}; + } + + double cutVal = mlcut["cut"].GetDouble(); + exclude = mlcut.HasMember("exclude") ? mlcut["exclude"].GetBool() : false; + + binCuts.push_back(cutVal); + + if (!cutDirsFilled) { + cutDirs.push_back(exclude ? 0 : 1); + } + } if (!cutDirsFilled) { - cutDirs.push_back(exclude ? 1 : 0); cutDirsFilled = true; } - } - cutsMl.push_back(binCuts); + centBins.emplace_back(centMin, centMax); + ptBins.emplace_back(ptMin, ptMax); + cutsMl.push_back(binCuts); + labelsFlatBin.push_back(Form("%s_cent%.0f_%.0f_pt%.1f_%.1f", cent.c_str(), centMin, centMax, ptMin, ptMax)); + LOG(info) << "Added cut for " << Form("%s_cent%.0f_%.0f_pt%.1f_%.1f", cent.c_str(), centMin, centMax, ptMin, ptMax) << " with cuts: ["; + for (size_t i = 0; i < binCuts.size(); ++i) { + std::cout << binCuts[i]; + if (i != binCuts.size() - 1) + std::cout << ", "; + } + std::cout << "] and direction: " << (exclude ? 
"CutGreater" : "CutSmaller") << std::endl; + } } - // bin edges - std::vector binsPt; - if (!ptBins.empty()) { - std::set binEdges; - for (const auto& b : ptBins) - binEdges.insert(b.first); - binEdges.insert(ptBins.back().second); - binsPt = std::vector(binEdges.begin(), binEdges.end()); - } else { - LOG(fatal) << "No pT bins found in ML cuts"; + if (cutDirs.size() != cutsMl[0].size()) { + LOG(fatal) << "Mismatch the cut size and direction size: cutsMl[0].size() = " << cutsMl[0].size() + << ", cutsMl[0].size() = " << cutDirs.size(); return {}; } - std::vector labelsPt, labelsClass; - for (size_t i = 0; i < cutsMl.size(); ++i) { - labelsPt.push_back(Form("pT%.1f", binsPt[i])); - } + std::vector labelsClass; for (size_t j = 0; j < cutsMl[0].size(); ++j) { - labelsClass.push_back(Form("cls%d", static_cast(j))); + labelsClass.push_back(Form("score class %d", static_cast(j))); } + size_t nFlatBins = cutsMl.size(); + std::vector binsMl(nFlatBins + 1); + std::iota(binsMl.begin(), binsMl.end(), 0); + // Binary if (typeStr == "Binary") { dqmlcuts::BinaryBdtScoreConfig binaryCfg; binaryCfg.inputFeatures = namesInputFeatures; binaryCfg.onnxFiles = onnxFileNames; - binaryCfg.binsPt = binsPt; + binaryCfg.binsCent = centBins; + binaryCfg.binsPt = ptBins; + binaryCfg.binsMl = binsMl; binaryCfg.cutDirs = cutDirs; - binaryCfg.cutsMl = makeLabeledCutsMl(cutsMl, labelsPt, labelsClass); + binaryCfg.centType = cent; + binaryCfg.cutsMl = makeLabeledCutsMl(cutsMl, labelsFlatBin, labelsClass); return binaryCfg; - // MultiClass + // MultiClass } else if (typeStr == "MultiClass") { dqmlcuts::MultiClassBdtScoreConfig multiCfg; multiCfg.inputFeatures = namesInputFeatures; multiCfg.onnxFiles = onnxFileNames; - multiCfg.binsPt = binsPt; + multiCfg.binsCent = centBins; + multiCfg.binsPt = ptBins; + multiCfg.binsMl = binsMl; multiCfg.cutDirs = cutDirs; - multiCfg.cutsMl = makeLabeledCutsMl(cutsMl, labelsPt, labelsClass); + multiCfg.centType = cent; + multiCfg.cutsMl = makeLabeledCutsMl(cutsMl, 
labelsFlatBin, labelsClass); return multiCfg; } - - LOG(fatal) << "Unsupported classification type: " << typeStr; - return {}; } return {}; } o2::framework::LabeledArray o2::aod::dqmlcuts::makeLabeledCutsMl(const std::vector>& cuts, - const std::vector& labelsPt, + const std::vector& labelsflatBin, const std::vector& labelsClass) { const size_t nRows = cuts.size(); @@ -7254,5 +7316,19 @@ o2::framework::LabeledArray o2::aod::dqmlcuts::makeLabeledCutsMl(const s } o2::framework::Array2D arr(flat.data(), nRows, nCols); - return o2::framework::LabeledArray(arr, labelsPt, labelsClass); + return o2::framework::LabeledArray(arr, labelsflatBin, labelsClass); } + +int o2::aod::dqmlcuts::getMlBinIndex(double cent, double pt, + const std::vector>& binsCent, + const std::vector>& binsPt) +{ + LOG(debug) << "Searching for Ml bin index for cent: " << cent << ", pt: " << pt; //here + for (size_t i = 0; i < binsCent.size(); ++i) { + if (cent >= binsCent[i].first && cent < binsCent[i].second && pt >= binsPt[i].first && pt < binsPt[i].second) { + LOG(debug) << " - Found at index: " << i; //here + return static_cast(i); + } + } + return -1; // not found +} \ No newline at end of file diff --git a/PWGDQ/Core/CutsLibrary.h b/PWGDQ/Core/CutsLibrary.h index 9418197b75e..f6874543d8e 100644 --- a/PWGDQ/Core/CutsLibrary.h +++ b/PWGDQ/Core/CutsLibrary.h @@ -124,15 +124,21 @@ namespace dqmlcuts struct BinaryBdtScoreConfig { std::vector inputFeatures; std::vector onnxFiles; - std::vector binsPt; - o2::framework::LabeledArray cutsMl; - std::vector cutDirs; + std::vector> binsCent; // bins for centrality + std::vector> binsPt; // bins for pT + std::vector binsMl; // bins for flattened binning + std::string centType; + o2::framework::LabeledArray cutsMl; // BDT score cuts for each bin + std::vector cutDirs; // direction of the cuts on the BDT score }; struct MultiClassBdtScoreConfig { std::vector inputFeatures; std::vector onnxFiles; - std::vector binsPt; + std::vector> binsCent; + std::vector> 
binsPt; + std::vector binsMl; + std::string centType; o2::framework::LabeledArray cutsMl; std::vector cutDirs; }; @@ -144,6 +150,9 @@ BdtScoreConfig GetBdtScoreCutsAndConfigFromJSON(const char* json); o2::framework::LabeledArray makeLabeledCutsMl(const std::vector>& cuts, const std::vector& labelsPt, const std::vector& labelsClass); +int getMlBinIndex(double cent, double pt, + const std::vector>& binsCent, + const std::vector>& binsPt); } // namespace dqmlcuts } // namespace o2::aod diff --git a/PWGDQ/Core/DQMlResponse.h b/PWGDQ/Core/DQMlResponse.h index ab150800cf7..64bfe233cc7 100644 --- a/PWGDQ/Core/DQMlResponse.h +++ b/PWGDQ/Core/DQMlResponse.h @@ -24,91 +24,89 @@ #include #include -// Fill the map of available input features -// the key is the feature's name (std::string) -// the value is the corresponding value in EnumInputFeatures -#define FILL_MAP(FEATURE) \ - { \ - #FEATURE, static_cast(InputFeatures::FEATURE) \ - } - namespace o2::analysis { -enum class InputFeatures : uint8_t { // refer to DielectronsAll - fMass = 0, - fPt, - fEta, - fPhi, - fPt1, - fITSChi2NCl1, - fTPCNClsCR1, - fTPCNClsFound1, - fTPCChi2NCl1, - fDcaXY1, - fDcaZ1, - fTPCNSigmaEl1, - fTPCNSigmaPi1, - fTPCNSigmaPr1, - fTOFNSigmaEl1, - fTOFNSigmaPi1, - fTOFNSigmaPr1, - fPt2, - fITSChi2NCl2, - fTPCNClsCR2, - fTPCNClsFound2, - fTPCChi2NCl2, - fDcaXY2, - fDcaZ2, - fTPCNSigmaEl2, - fTPCNSigmaPi2, - fTPCNSigmaPr2, - fTOFNSigmaEl2, - fTOFNSigmaPi2, - fTOFNSigmaPr2, +enum class InputFeatures : uint8_t { // refer to DielectronsAll, TODO: add more features if needed + kMass = 0, + kPt, + kEta, + kPhi, + kPt1, + kITSChi2NCl1, + kTPCNClsCR1, + kTPCNClsFound1, + kTPCChi2NCl1, + kDcaXY1, + kDcaZ1, + kTPCNSigmaEl1, + kTPCNSigmaPi1, + kTPCNSigmaPr1, + kTOFNSigmaEl1, + kTOFNSigmaPi1, + kTOFNSigmaPr1, + kPt2, + kITSChi2NCl2, + kTPCNClsCR2, + kTPCNClsFound2, + kTPCChi2NCl2, + kDcaXY2, + kDcaZ2, + kTPCNSigmaEl2, + kTPCNSigmaPi2, + kTPCNSigmaPr2, + kTOFNSigmaEl2, + kTOFNSigmaPi2, + kTOFNSigmaPr2 }; static 
const std::map gFeatureNameMap = { - {InputFeatures::fMass, "fMass"}, - {InputFeatures::fPt, "fPt"}, - {InputFeatures::fEta, "fEta"}, - {InputFeatures::fPhi, "fPhi"}, - {InputFeatures::fPt1, "fPt1"}, - {InputFeatures::fITSChi2NCl1, "fITSChi2NCl1"}, - {InputFeatures::fTPCNClsCR1, "fTPCNClsCR1"}, - {InputFeatures::fTPCNClsFound1, "fTPCNClsFound1"}, - {InputFeatures::fTPCChi2NCl1, "fTPCChi2NCl1"}, - {InputFeatures::fDcaXY1, "fDcaXY1"}, - {InputFeatures::fDcaZ1, "fDcaZ1"}, - {InputFeatures::fTPCNSigmaEl1, "fTPCNSigmaEl1"}, - {InputFeatures::fTPCNSigmaPi1, "fTPCNSigmaPi1"}, - {InputFeatures::fTPCNSigmaPr1, "fTPCNSigmaPr1"}, - {InputFeatures::fTOFNSigmaEl1, "fTOFNSigmaEl1"}, - {InputFeatures::fTOFNSigmaPi1, "fTOFNSigmaPi1"}, - {InputFeatures::fTOFNSigmaPr1, "fTOFNSigmaPr1"}, - {InputFeatures::fPt2, "fPt2"}, - {InputFeatures::fITSChi2NCl2, "fITSChi2NCl2"}, - {InputFeatures::fTPCNClsCR2, "fTPCNClsCR2"}, - {InputFeatures::fTPCNClsFound2, "fTPCNClsFound2"}, - {InputFeatures::fTPCChi2NCl2, "fTPCChi2NCl2"}, - {InputFeatures::fDcaXY2, "fDcaXY2"}, - {InputFeatures::fDcaZ2, "fDcaZ2"}, - {InputFeatures::fTPCNSigmaEl2, "fTPCNSigmaEl2"}, - {InputFeatures::fTPCNSigmaPi2, "fTPCNSigmaPi2"}, - {InputFeatures::fTPCNSigmaPr2, "fTPCNSigmaPr2"}, - {InputFeatures::fTOFNSigmaEl2, "fTOFNSigmaEl2"}, - {InputFeatures::fTOFNSigmaPi2, "fTOFNSigmaPi2"}, - {InputFeatures::fTOFNSigmaPr2, "fTOFNSigmaPr2"}}; + {InputFeatures::kMass, "kMass"}, + {InputFeatures::kPt, "kPt"}, + {InputFeatures::kEta, "kEta"}, + {InputFeatures::kPhi, "kPhi"}, + {InputFeatures::kPt1, "kPt1"}, + {InputFeatures::kITSChi2NCl1, "kITSChi2NCl1"}, + {InputFeatures::kTPCNClsCR1, "kTPCNClsCR1"}, + {InputFeatures::kTPCNClsFound1, "kTPCNClsFound1"}, + {InputFeatures::kTPCChi2NCl1, "kTPCChi2NCl1"}, + {InputFeatures::kDcaXY1, "kDcaXY1"}, + {InputFeatures::kDcaZ1, "kDcaZ1"}, + {InputFeatures::kTPCNSigmaEl1, "kTPCNSigmaEl1"}, + {InputFeatures::kTPCNSigmaPi1, "kTPCNSigmaPi1"}, + {InputFeatures::kTPCNSigmaPr1, "kTPCNSigmaPr1"}, + 
{InputFeatures::kTOFNSigmaEl1, "kTOFNSigmaEl1"}, + {InputFeatures::kTOFNSigmaPi1, "kTOFNSigmaPi1"}, + {InputFeatures::kTOFNSigmaPr1, "kTOFNSigmaPr1"}, + {InputFeatures::kPt2, "kPt2"}, + {InputFeatures::kITSChi2NCl2, "kITSChi2NCl2"}, + {InputFeatures::kTPCNClsCR2, "kTPCNClsCR2"}, + {InputFeatures::kTPCNClsFound2, "kTPCNClsFound2"}, + {InputFeatures::kTPCChi2NCl2, "kTPCChi2NCl2"}, + {InputFeatures::kDcaXY2, "kDcaXY2"}, + {InputFeatures::kDcaZ2, "kDcaZ2"}, + {InputFeatures::kTPCNSigmaEl2, "kTPCNSigmaEl2"}, + {InputFeatures::kTPCNSigmaPi2, "kTPCNSigmaPi2"}, + {InputFeatures::kTPCNSigmaPr2, "kTPCNSigmaPr2"}, + {InputFeatures::kTOFNSigmaEl2, "kTOFNSigmaEl2"}, + {InputFeatures::kTOFNSigmaPi2, "kTOFNSigmaPi2"}, + {InputFeatures::kTOFNSigmaPr2, "kTOFNSigmaPr2"}}; template class DQMlResponse : public MlResponse { public: - /// Default constructor DQMlResponse() = default; - /// Default destructor virtual ~DQMlResponse() = default; + void setBinsCent(const std::vector>& bins) { binsCent = bins; } + void setBinsPt(const std::vector>& bins) { binsPt = bins; } + void setCentType(std::string& type) { centType = type; } + + const std::vector>& getBinsCent() const { return binsCent; } + const std::vector>& getBinsPt() const { return binsPt; } + const std::string& getCentType() const { return centType; } + /// Method to get the input features vector needed for ML inference /// \return inputFeatures vector template @@ -116,98 +114,126 @@ class DQMlResponse : public MlResponse const T2& t2, const TValues& fg) const { - using Accessor = std::function; - static const std::unordered_map featureMap{ - {"fMass", [](auto const&, auto const&, auto const& v) { return v[VarManager::kMass]; }}, - {"fPt", [](auto const&, auto const&, auto const& v) { return v[VarManager::kPt]; }}, - {"fEta", [](auto const&, auto const&, auto const& v) { return v[VarManager::kEta]; }}, - {"fPhi", [](auto const&, auto const&, auto const& v) { return v[VarManager::kPhi]; }}, - - {"fPt1", [](auto const& t1, auto 
const&, auto const&) { return t1.pt(); }}, - {"fITSChi2NCl1", [](auto const& t1, auto const&, auto const&) { return t1.itsChi2NCl(); }}, - {"fTPCNClsCR1", [](auto const& t1, auto const&, auto const&) { return t1.tpcNClsCrossedRows(); }}, - {"fTPCNClsFound1", [](auto const& t1, auto const&, auto const&) { return t1.tpcNClsFound(); }}, - {"fTPCChi2NCl1", [](auto const& t1, auto const&, auto const&) { return t1.tpcChi2NCl(); }}, - {"fDcaXY1", [](auto const& t1, auto const&, auto const&) { return t1.dcaXY(); }}, - {"fDcaZ1", [](auto const& t1, auto const&, auto const&) { return t1.dcaZ(); }}, - {"fTPCNSigmaEl1", [](auto const& t1, auto const&, auto const&) { return t1.tpcNSigmaEl(); }}, - {"fTPCNSigmaPi1", [](auto const& t1, auto const&, auto const&) { return t1.tpcNSigmaPi(); }}, - {"fTPCNSigmaPr1", [](auto const& t1, auto const&, auto const&) { return t1.tpcNSigmaPr(); }}, - {"fTOFNSigmaEl1", [](auto const& t1, auto const&, auto const&) { return t1.tofNSigmaEl(); }}, - {"fTOFNSigmaPi1", [](auto const& t1, auto const&, auto const&) { return t1.tofNSigmaPi(); }}, - {"fTOFNSigmaPr1", [](auto const& t1, auto const&, auto const&) { return t1.tofNSigmaPr(); }}, - - {"fPt2", [](auto const&, auto const& t2, auto const&) { return t2.pt(); }}, - {"fITSChi2NCl2", [](auto const&, auto const& t2, auto const&) { return t2.itsChi2NCl(); }}, - {"fTPCNClsCR2", [](auto const&, auto const& t2, auto const&) { return t2.tpcNClsCrossedRows(); }}, - {"fTPCNClsFound2", [](auto const&, auto const& t2, auto const&) { return t2.tpcNClsFound(); }}, - {"fTPCChi2NCl2", [](auto const&, auto const& t2, auto const&) { return t2.tpcChi2NCl(); }}, - {"fDcaXY2", [](auto const&, auto const& t2, auto const&) { return t2.dcaXY(); }}, - {"fDcaZ2", [](auto const&, auto const& t2, auto const&) { return t2.dcaZ(); }}, - {"fTPCNSigmaEl2", [](auto const&, auto const& t2, auto const&) { return t2.tpcNSigmaEl(); }}, - {"fTPCNSigmaPi2", [](auto const&, auto const& t2, auto const&) { return t2.tpcNSigmaPi(); }}, - 
{"fTPCNSigmaPr2", [](auto const&, auto const& t2, auto const&) { return t2.tpcNSigmaPr(); }}, - {"fTOFNSigmaEl2", [](auto const&, auto const& t2, auto const&) { return t2.tofNSigmaEl(); }}, - {"fTOFNSigmaPi2", [](auto const&, auto const& t2, auto const&) { return t2.tofNSigmaPi(); }}, - {"fTOFNSigmaPr2", [](auto const&, auto const& t2, auto const&) { return t2.tofNSigmaPr(); }}}; - std::vector dqInputFeatures; dqInputFeatures.reserve(MlResponse::mCachedIndices.size()); for (auto idx : MlResponse::mCachedIndices) { auto enumIdx = static_cast(idx); - const auto& name = gFeatureNameMap.at(enumIdx); + auto mapIdx = gFeatureNameMap.find(enumIdx); + if (mapIdx == gFeatureNameMap.end()) { + LOG(fatal) << "Unknown InputFeatures index: " << static_cast(enumIdx); + } - auto acc = featureMap.find(name); - if (acc == featureMap.end()) { - LOG(error) << "Missing accessor for " << name; - continue; + const auto& name = mapIdx->second; + if (name == "kMass") { + dqInputFeatures.push_back(fg[VarManager::fgVarNamesMap["kMass"]]); + } else if (name == "kPt") { + dqInputFeatures.push_back(fg[VarManager::fgVarNamesMap["kPt"]]); + } else if (name == "kEta") { + dqInputFeatures.push_back(fg[VarManager::fgVarNamesMap["kEta"]]); + } else if (name == "kPhi") { + dqInputFeatures.push_back(fg[VarManager::fgVarNamesMap["kPhi"]]); + } else if (name == "kPt1") { + dqInputFeatures.push_back(t1.pt()); + } else if (name == "kITSChi2NCl1") { + dqInputFeatures.push_back(t1.itsChi2NCl()); + } else if (name == "kTPCNClsCR1") { + dqInputFeatures.push_back(t1.tpcNClsCrossedRows()); + } else if (name == "kTPCNClsFound1") { + dqInputFeatures.push_back(t1.tpcNClsFound()); + } else if (name == "kTPCChi2NCl1") { + dqInputFeatures.push_back(t1.tpcChi2NCl()); + } else if (name == "kDcaXY1") { + dqInputFeatures.push_back(t1.dcaXY()); + } else if (name == "kDcaZ1") { + dqInputFeatures.push_back(t1.dcaZ()); + } else if (name == "kTPCNSigmaEl1") { + dqInputFeatures.push_back(t1.tpcNSigmaEl()); + } else if (name == 
"kTPCNSigmaPi1") { + dqInputFeatures.push_back(t1.tpcNSigmaPi()); + } else if (name == "kTPCNSigmaPr1") { + dqInputFeatures.push_back(t1.tpcNSigmaPr()); + } else if (name == "kTOFNSigmaEl1") { + dqInputFeatures.push_back(t1.tofNSigmaEl()); + } else if (name == "kTOFNSigmaPi1") { + dqInputFeatures.push_back(t1.tofNSigmaPi()); + } else if (name == "kTOFNSigmaPr1") { + dqInputFeatures.push_back(t1.tofNSigmaPr()); + } else if (name == "kPt2") { + dqInputFeatures.push_back(t2.pt()); + } else if (name == "kITSChi2NCl2") { + dqInputFeatures.push_back(t2.itsChi2NCl()); + } else if (name == "kTPCNClsCR2") { + dqInputFeatures.push_back(t2.tpcNClsCrossedRows()); + } else if (name == "kTPCNClsFound2") { + dqInputFeatures.push_back(t2.tpcNClsFound()); + } else if (name == "kTPCChi2NCl2") { + dqInputFeatures.push_back(t2.tpcChi2NCl()); + } else if (name == "kDcaXY2") { + dqInputFeatures.push_back(t2.dcaXY()); + } else if (name == "kDcaZ2") { + dqInputFeatures.push_back(t2.dcaZ()); + } else if (name == "kTPCNSigmaEl2") { + dqInputFeatures.push_back(t2.tpcNSigmaEl()); + } else if (name == "kTPCNSigmaPi2") { + dqInputFeatures.push_back(t2.tpcNSigmaPi()); + } else if (name == "kTPCNSigmaPr2") { + dqInputFeatures.push_back(t2.tpcNSigmaPr()); + } else if (name == "kTOFNSigmaEl2") { + dqInputFeatures.push_back(t2.tofNSigmaEl()); + } else if (name == "kTOFNSigmaPi2") { + dqInputFeatures.push_back(t2.tofNSigmaPi()); + } else if (name == "kTOFNSigmaPr2") { + dqInputFeatures.push_back(t2.tofNSigmaPr()); } else { - dqInputFeatures.push_back(acc->second(t1, t2, fg)); + LOG(fatal) << "Missing accessor for feature: " << name; } } + LOG(debug) << "Total features collected: " << dqInputFeatures.size(); return dqInputFeatures; } protected: + std::vector> binsCent; + std::vector> binsPt; + std::string centType; + void setAvailableInputFeatures() { MlResponse::mAvailableInputFeatures = { - FILL_MAP(fMass), - FILL_MAP(fPt), - FILL_MAP(fEta), - FILL_MAP(fPhi), - FILL_MAP(fPt1), - 
FILL_MAP(fITSChi2NCl1), - FILL_MAP(fTPCNClsCR1), - FILL_MAP(fTPCNClsFound1), - FILL_MAP(fTPCChi2NCl1), - FILL_MAP(fDcaXY1), - FILL_MAP(fDcaZ1), - FILL_MAP(fTPCNSigmaEl1), - FILL_MAP(fTPCNSigmaPi1), - FILL_MAP(fTPCNSigmaPr1), - FILL_MAP(fTOFNSigmaEl1), - FILL_MAP(fTOFNSigmaPi1), - FILL_MAP(fTOFNSigmaPr1), - FILL_MAP(fPt2), - FILL_MAP(fITSChi2NCl2), - FILL_MAP(fTPCNClsCR2), - FILL_MAP(fTPCNClsFound2), - FILL_MAP(fTPCChi2NCl2), - FILL_MAP(fDcaXY2), - FILL_MAP(fDcaZ2), - FILL_MAP(fTPCNSigmaEl2), - FILL_MAP(fTPCNSigmaPi2), - FILL_MAP(fTPCNSigmaPr2), - FILL_MAP(fTOFNSigmaEl2), - FILL_MAP(fTOFNSigmaPi2), - FILL_MAP(fTOFNSigmaPr2)}; + {"kMass", static_cast(InputFeatures::kMass)}, + {"kPt", static_cast(InputFeatures::kPt)}, + {"kEta", static_cast(InputFeatures::kEta)}, + {"kPhi", static_cast(InputFeatures::kPhi)}, + {"kPt1", static_cast(InputFeatures::kPt1)}, + {"kITSChi2NCl1", static_cast(InputFeatures::kITSChi2NCl1)}, + {"kTPCNClsCR1", static_cast(InputFeatures::kTPCNClsCR1)}, + {"kTPCNClsFound1", static_cast(InputFeatures::kTPCNClsFound1)}, + {"kTPCChi2NCl1", static_cast(InputFeatures::kTPCChi2NCl1)}, + {"kDcaXY1", static_cast(InputFeatures::kDcaXY1)}, + {"kDcaZ1", static_cast(InputFeatures::kDcaZ1)}, + {"kTPCNSigmaEl1", static_cast(InputFeatures::kTPCNSigmaEl1)}, + {"kTPCNSigmaPi1", static_cast(InputFeatures::kTPCNSigmaPi1)}, + {"kTPCNSigmaPr1", static_cast(InputFeatures::kTPCNSigmaPr1)}, + {"kTOFNSigmaEl1", static_cast(InputFeatures::kTOFNSigmaEl1)}, + {"kTOFNSigmaPi1", static_cast(InputFeatures::kTOFNSigmaPi1)}, + {"kTOFNSigmaPr1", static_cast(InputFeatures::kTOFNSigmaPr1)}, + {"kPt2", static_cast(InputFeatures::kPt2)}, + {"kITSChi2NCl2", static_cast(InputFeatures::kITSChi2NCl2)}, + {"kTPCNClsCR2", static_cast(InputFeatures::kTPCNClsCR2)}, + {"kTPCNClsFound2", static_cast(InputFeatures::kTPCNClsFound2)}, + {"kTPCChi2NCl2", static_cast(InputFeatures::kTPCChi2NCl2)}, + {"kDcaXY2", static_cast(InputFeatures::kDcaXY2)}, + {"kDcaZ2", static_cast(InputFeatures::kDcaZ2)}, + 
{"kTPCNSigmaEl2", static_cast(InputFeatures::kTPCNSigmaEl2)}, + {"kTPCNSigmaPi2", static_cast(InputFeatures::kTPCNSigmaPi2)}, + {"kTPCNSigmaPr2", static_cast(InputFeatures::kTPCNSigmaPr2)}, + {"kTOFNSigmaEl2", static_cast(InputFeatures::kTOFNSigmaEl2)}, + {"kTOFNSigmaPi2", static_cast(InputFeatures::kTOFNSigmaPi2)}, + {"kTOFNSigmaPr2", static_cast(InputFeatures::kTOFNSigmaPr2)}}; } }; } // namespace o2::analysis -#undef FILL_MAP - #endif // PWGDQ_CORE_DQMLRESPONSE_H_ diff --git a/PWGDQ/Macros/bdtCut.json b/PWGDQ/Macros/bdtCut.json index 5ff0ada64b4..2e015fa8618 100644 --- a/PWGDQ/Macros/bdtCut.json +++ b/PWGDQ/Macros/bdtCut.json @@ -3,53 +3,86 @@ "type": "Binary", "title": "MyBDTModel", "inputFeatures": [ - "fPt1", - "fITSChi2NCl1", - "fTPCNClsCR1", - "fTPCNClsFound1", - "fTPCChi2NCl1", - "fDcaXY1", - "fDcaZ1", - "fTPCNSigmaEl1", - "fTPCNSigmaPi1", - "fTPCNSigmaPr1", - "fTOFNSigmaEl1", - "fTOFNSigmaPi1", - "fTOFNSigmaPr1", - "fPt2", - "fITSChi2NCl2", - "fTPCNClsCR2", - "fTPCNClsFound2", - "fTPCChi2NCl2", - "fDcaXY2", - "fDcaZ2", - "fTPCNSigmaEl2", - "fTPCNSigmaPi2", - "fTPCNSigmaPr2", - "fTOFNSigmaEl2", - "fTOFNSigmaPi2", - "fTOFNSigmaPr2" + "kMass", + "kPt", + "kEta", + "kPhi", + "kPt1", + "kITSChi2NCl1", + "kTPCNClsCR1", + "kTPCNClsFound1", + "kTPCChi2NCl1", + "kDcaXY1", + "kDcaZ1", + "kTPCNSigmaEl1", + "kTPCNSigmaPi1", + "kTPCNSigmaPr1", + "kTOFNSigmaEl1", + "kTOFNSigmaPi1", + "kTOFNSigmaPr1", + "kPt2", + "kITSChi2NCl2", + "kTPCNClsCR2", + "kTPCNClsFound2", + "kTPCChi2NCl2", + "kDcaXY2", + "kDcaZ2", + "kTPCNSigmaEl2", + "kTPCNSigmaPi2", + "kTPCNSigmaPr2", + "kTOFNSigmaEl2", + "kTOFNSigmaPi2", + "kTOFNSigmaPr2" ], "modelFiles": [ + "cent_10_30_pt0_2_onnx.onnx", + "cent_10_30_pt2_20_onnx.onnx", "cent_30_50_pt0_2_onnx.onnx", "cent_30_50_pt2_20_onnx.onnx" ], - "AddMLCut-pTBin1": { - "pTMin": 0, - "pTMax": 2, - "AddMLCut-background": { - "var": "kBdtBackground", - "cut": 0.5, - "exclude": false + "cent": "kCentFT0C", + "AddCentCut-Cent1030": { + "centMin": 10, + 
"centMax": 30, + "AddPtCut-pTBin1": { + "pTMin": 0, + "pTMax": 2, + "AddMLCut-background": { + "var": "kBdtBackground", + "cut": 0.5, + "exclude": false + } + }, + "AddPtCut-pTBin2": { + "pTMin": 2, + "pTMax": 20, + "AddMLCut-background": { + "var": "kBdtBackground", + "cut": 0.5, + "exclude": false + } } }, - "AddMLCut-pTBin2": { - "pTMin": 2, - "pTMax": 20, - "AddMLCut-background": { - "var": "kBdtBackground", - "cut": 0.5, - "exclude": false + "AddCentCut-Cent3050": { + "centMin": 30, + "centMax": 50, + "AddPtCut-pTBin1": { + "pTMin": 0, + "pTMax": 2, + "AddMLCut-background": { + "var": "kBdtBackground", + "cut": 0.5, + "exclude": false + } + }, + "AddPtCut-pTBin2": { + "pTMin": 2, + "pTMax": 20, + "AddMLCut-background": { + "var": "kBdtBackground", + "cut": 0.5, + "exclude": false + } } } } diff --git a/PWGDQ/Macros/bdtCutMulti.json b/PWGDQ/Macros/bdtCutMulti.json index 3042e767c11..ba687e36200 100644 --- a/PWGDQ/Macros/bdtCutMulti.json +++ b/PWGDQ/Macros/bdtCutMulti.json @@ -1,73 +1,128 @@ { "TestCut": { - "type": "Binary", + "type": "MultiClass", "title": "MyBDTModel", "inputFeatures": [ - "fITSChi2NCl1", - "fTPCNClsCR1", - "fTPCNClsFound1", - "fTPCChi2NCl1", - "fDcaXY1", - "fDcaZ1", - "fTPCNSigmaEl1", - "fTPCNSigmaPi1", - "fTPCNSigmaPr1", - "fTOFNSigmaEl1", - "fTOFNSigmaPi1", - "fTOFNSigmaPr1", - "fITSChi2NCl2", - "fTPCNClsCR2", - "fTPCNClsFound2", - "fTPCChi2NCl2", - "fDcaXY2", - "fDcaZ2", - "fTPCNSigmaEl2", - "fTPCNSigmaPi2", - "fTPCNSigmaPr2", - "fTOFNSigmaEl2", - "fTOFNSigmaPi2", - "fTOFNSigmaPr2" + "kMass", + "kPt", + "kEta", + "kPhi", + "kPt1", + "kITSChi2NCl1", + "kTPCNClsCR1", + "kTPCNClsFound1", + "kTPCChi2NCl1", + "kDcaXY1", + "kDcaZ1", + "kTPCNSigmaEl1", + "kTPCNSigmaPi1", + "kTPCNSigmaPr1", + "kTOFNSigmaEl1", + "kTOFNSigmaPi1", + "kTOFNSigmaPr1", + "kPt2", + "kITSChi2NCl2", + "kTPCNClsCR2", + "kTPCNClsFound2", + "kTPCChi2NCl2", + "kDcaXY2", + "kDcaZ2", + "kTPCNSigmaEl2", + "kTPCNSigmaPi2", + "kTPCNSigmaPr2", + "kTOFNSigmaEl2", + "kTOFNSigmaPi2", 
+ "kTOFNSigmaPr2" ], "modelFiles": [ - "cent_30_50_pt0_2_onnx.onnx", - "cent_30_50_pt2_20_onnx.onnx" + "cent_10_30_pt_1_2_onnx.onnx", + "cent_10_30_pt_2_20_onnx.onnx", + "cent_30_50_pt_1_2_onnx.onnx", + "cent_30_50_pt_2_20_onnx.onnx" ], - "AddMLCut-pTBin1": { - "pTMin": 0, - "pTMax": 2, - "AddMLCut-background": { - "var": "kBdtBackground", - "cut": 0.5, - "exclude": true + "cent": "kCentFT0C", + "AddCentCut-Cent1030": { + "centMin": 10, + "centMax": 30, + "AddPtCut-pTBin1": { + "pTMin": 1, + "pTMax": 2, + "AddMLCut-background": { + "var": "kBdtBackground", + "cut": 0.1, + "exclude": true + }, + "AddMLCut-prompt": { + "var": "kBdtPrompt", + "cut": 0.1, + "exclude": true + }, + "AddMLCut-nonprompt": { + "var": "kBdtNonprompt", + "cut": 0.5, + "exclude": false + } }, - "AddMLCut-prompt": { - "var": "kBdtPrompt", - "cut": 0.5, - "exclude": false - }, - "AddMLCut-nonprompt": { - "var": "kBdtNonprompt", - "cut": 0.5, - "exclude": false + "AddPtCut-pTBin2": { + "pTMin": 2, + "pTMax": 20, + "AddMLCut-background": { + "var": "kBdtBackground", + "cut": 0.1, + "exclude": true + }, + "AddMLCut-prompt": { + "var": "kBdtPrompt", + "cut": 0.1, + "exclude": true + }, + "AddMLCut-nonprompt": { + "var": "kBdtNonprompt", + "cut": 0.5, + "exclude": false + } } }, - "AddMLCut-pTBin2": { - "pTMin": 2, - "pTMax": 20, - "AddMLCut-background": { - "var": "kBdtBackground", - "cut": 0.5, - "exclude": true - }, - "AddMLCut-prompt": { - "var": "kBdtPrompt", - "cut": 0.5, - "exclude": false + "AddCentCut-Cent3050": { + "centMin": 30, + "centMax": 50, + "AddPtCut-pTBin1": { + "pTMin": 1, + "pTMax": 2, + "AddMLCut-background": { + "var": "kBdtBackground", + "cut": 0.1, + "exclude": true + }, + "AddMLCut-prompt": { + "var": "kBdtPrompt", + "cut": 0.1, + "exclude": true + }, + "AddMLCut-nonprompt": { + "var": "kBdtNonprompt", + "cut": 0.5, + "exclude": false + } }, - "AddMLCut-nonprompt": { - "var": "kBdtNonprompt", - "cut": 0.5, - "exclude": false + "AddPtCut-pTBin2": { + "pTMin": 2, + "pTMax": 
20, + "AddMLCut-background": { + "var": "kBdtBackground", + "cut": 0.1, + "exclude": true + }, + "AddMLCut-prompt": { + "var": "kBdtPrompt", + "cut": 0.1, + "exclude": true + }, + "AddMLCut-nonprompt": { + "var": "kBdtNonprompt", + "cut": 0.5, + "exclude": false + } } } } diff --git a/PWGDQ/Tasks/tableReader.cxx b/PWGDQ/Tasks/tableReader.cxx index 03c0fb8ee1a..2c96d5c0960 100644 --- a/PWGDQ/Tasks/tableReader.cxx +++ b/PWGDQ/Tasks/tableReader.cxx @@ -1134,10 +1134,10 @@ struct AnalysisSameEventPairing { if (applyBDT) { // BDT cuts via JSON - std::vector binsPtMl; + std::vector binsMl; o2::framework::LabeledArray cutsMl; std::vector cutDirMl; - int nClassesMl = 1; // 1 for binary BDT, 3 for multiclass BDT + int nClassesMl = 1; std::vector namesInputFeatures; std::vector onnxFileNames; @@ -1145,23 +1145,31 @@ struct AnalysisSameEventPairing { if (std::holds_alternative(config)) { auto& cfg = std::get(config); - binsPtMl = cfg.binsPt; + binsMl = cfg.binsMl; nClassesMl = 1; cutsMl = cfg.cutsMl; cutDirMl = cfg.cutDirs; namesInputFeatures = cfg.inputFeatures; onnxFileNames = cfg.onnxFiles; + dqMlResponse.setBinsCent(cfg.binsCent); + dqMlResponse.setBinsPt(cfg.binsPt); + dqMlResponse.setCentType(cfg.centType); + LOG(info) << "Using BDT cuts for binary classification"; } else { auto& cfg = std::get(config); - binsPtMl = cfg.binsPt; + binsMl = cfg.binsMl; nClassesMl = 3; cutsMl = cfg.cutsMl; cutDirMl = cfg.cutDirs; namesInputFeatures = cfg.inputFeatures; onnxFileNames = cfg.onnxFiles; + dqMlResponse.setBinsCent(cfg.binsCent); + dqMlResponse.setBinsPt(cfg.binsPt); + dqMlResponse.setCentType(cfg.centType); + LOG(info) << "Using BDT cuts for multiclass classification"; } - dqMlResponse.configure(binsPtMl, cutsMl, cutDirMl, nClassesMl); + dqMlResponse.configure(binsMl, cutsMl, cutDirMl, nClassesMl); if (loadModelsFromCCDB) { ccdbApi.init(ccdburl); dqMlResponse.setModelPathsCCDB(onnxFileNames, ccdbApi, modelPathsCCDB, timestampCCDB); @@ -1172,7 +1180,7 @@ struct 
AnalysisSameEventPairing { dqMlResponse.init(); } - if (context.mOptions.get("processDecayToEESkimmed") || context.mOptions.get("processDecayToEESkimmedNoTwoProngFitter") || context.mOptions.get("processDecayToEESkimmedWithCov") || context.mOptions.get("processDecayToEESkimmedWithCovNoTwoProngFitter") || context.mOptions.get("processDecayToEEVertexingSkimmed") || context.mOptions.get("processVnDecayToEESkimmed") || context.mOptions.get("processDecayToEEPrefilterSkimmed") || context.mOptions.get("processDecayToEEPrefilterSkimmedNoTwoProngFitter") || context.mOptions.get("processDecayToEESkimmedWithColl") || context.mOptions.get("processDecayToEESkimmedWithCollNoTwoProngFitter") || context.mOptions.get("processDecayToPiPiSkimmed") || context.mOptions.get("processAllSkimmed") || context.mOptions.get("processDecayToEESkimmedBDT")) { + if (context.mOptions.get("processDecayToEESkimmed") || context.mOptions.get("processDecayToEESkimmedNoTwoProngFitter") || context.mOptions.get("processDecayToEESkimmedWithCov") || context.mOptions.get("processDecayToEESkimmedWithCovNoTwoProngFitter") || context.mOptions.get("processDecayToEEVertexingSkimmed") || context.mOptions.get("processVnDecayToEESkimmed") || context.mOptions.get("processDecayToEEPrefilterSkimmed") || context.mOptions.get("processDecayToEEPrefilterSkimmedNoTwoProngFitter") || context.mOptions.get("processDecayToEESkimmedWithColl") || context.mOptions.get("processDecayToEESkimmedWithCollNoTwoProngFitter") || context.mOptions.get("processDecayToPiPiSkimmed") || context.mOptions.get("processAllSkimmed")) { TString cutNames = fConfigTrackCuts.value; if (!cutNames.IsNull()) { // if track cuts std::unique_ptr objArray(cutNames.Tokenize(",")); @@ -1456,8 +1464,29 @@ struct AnalysisSameEventPairing { return; } - // isSelectedBDT = dqMlResponse.isSelectedMl(dqInputFeatures, VarManager::fgValues[VarManager::kPt]); - isSelectedBDT = dqMlResponse.isSelectedMl(dqInputFeatures, VarManager::fgValues[VarManager::kPt], 
outputMlPsi2ee); + int modelIndex = -1; + const auto& binsCent = dqMlResponse.getBinsCent(); + const auto& binsPt = dqMlResponse.getBinsPt(); + const std::string& centType = dqMlResponse.getCentType(); + + if ("kCentFT0C" == centType) { + modelIndex = o2::aod::dqmlcuts::getMlBinIndex(VarManager::fgValues[VarManager::kCentFT0C], VarManager::fgValues[VarManager::kPt], binsCent, binsPt); + } else if ("kCentFT0A" == centType) { + modelIndex = o2::aod::dqmlcuts::getMlBinIndex(VarManager::fgValues[VarManager::kCentFT0A], VarManager::fgValues[VarManager::kPt], binsCent, binsPt); + } else if ("kCentFT0M" == centType) { + modelIndex = o2::aod::dqmlcuts::getMlBinIndex(VarManager::fgValues[VarManager::kCentFT0M], VarManager::fgValues[VarManager::kPt], binsCent, binsPt); + } else { + LOG(fatal) << "Unknown centrality estimation type: " << centType; + return; + } + + if (modelIndex < 0) { + LOG(debug) << "Ml index is negative! This means that the centrality/pt is not in the range of the model bins."; + continue; + } + + LOG(debug) << "Model index: " << modelIndex << ", pT: " << VarManager::fgValues[VarManager::kPt] << ", centrality (kCentFT0C): " << VarManager::fgValues[VarManager::kCentFT0C]; + isSelectedBDT = dqMlResponse.isSelectedMl(dqInputFeatures, modelIndex, outputMlPsi2ee); VarManager::FillBdtScore(outputMlPsi2ee); // TODO: check if this is needed or not } @@ -1678,13 +1707,6 @@ struct AnalysisSameEventPairing { VarManager::FillEvent(event, VarManager::fgValues); runSameEventPairing(event, tracks, tracks); } - void processDecayToEESkimmedBDT(soa::Filtered::iterator const& event, soa::Filtered const& tracks) - { - // Reset the fValues array - VarManager::ResetValues(0, VarManager::kNVars); - VarManager::FillEvent(event, VarManager::fgValues); - runSameEventPairing(event, tracks, tracks); - } void processDecayToMuMuSkimmed(soa::Filtered::iterator const& event, soa::Filtered const& muons) { // Reset the fValues array @@ -1789,7 +1811,6 @@ struct AnalysisSameEventPairing { 
PROCESS_SWITCH(AnalysisSameEventPairing, processDecayToEEPrefilterSkimmedNoTwoProngFitter, "Run electron-electron pairing, with skimmed tracks and prefilter from AnalysisPrefilterSelection but no two prong fitter", false); PROCESS_SWITCH(AnalysisSameEventPairing, processDecayToEESkimmedWithColl, "Run electron-electron pairing, with skimmed tracks and with collision information", false); PROCESS_SWITCH(AnalysisSameEventPairing, processDecayToEESkimmedWithCollNoTwoProngFitter, "Run electron-electron pairing, with skimmed tracks and with collision information but no two prong fitter", false); - PROCESS_SWITCH(AnalysisSameEventPairing, processDecayToEESkimmedBDT, "Run electron-electron pairing, with skimmed tracks and BDT selection", false); PROCESS_SWITCH(AnalysisSameEventPairing, processDecayToMuMuSkimmed, "Run muon-muon pairing, with skimmed muons", false); PROCESS_SWITCH(AnalysisSameEventPairing, processDecayToMuMuSkimmedWithMult, "Run muon-muon pairing, with skimmed muons and multiplicity", false); PROCESS_SWITCH(AnalysisSameEventPairing, processDecayToMuMuVertexingSkimmed, "Run muon-muon pairing and vertexing, with skimmed muons", false); diff --git a/PWGDQ/Tasks/tableReader_withAssoc.cxx b/PWGDQ/Tasks/tableReader_withAssoc.cxx index fd3d1eb2ec7..2d320f3619a 100644 --- a/PWGDQ/Tasks/tableReader_withAssoc.cxx +++ b/PWGDQ/Tasks/tableReader_withAssoc.cxx @@ -1293,7 +1293,7 @@ struct AnalysisSameEventPairing { void init(o2::framework::InitContext& context) { LOG(info) << "Starting initialization of AnalysisSameEventPairing (idstoreh)"; - fEnableBarrelHistos = context.mOptions.get("processAllSkimmed") || context.mOptions.get("processBarrelOnlySkimmed") || context.mOptions.get("processBarrelOnlyWithCollSkimmed") || context.mOptions.get("processBarrelOnlySkimmedNoCov") || context.mOptions.get("processBarrelOnlySkimmedNoCovWithMultExtra") || context.mOptions.get("processBarrelOnlyWithQvectorCentrSkimmedNoCov") || context.mOptions.get("processBarrelOnlySkimmedBDT"); + 
fEnableBarrelHistos = context.mOptions.get("processAllSkimmed") || context.mOptions.get("processBarrelOnlySkimmed") || context.mOptions.get("processBarrelOnlyWithCollSkimmed") || context.mOptions.get("processBarrelOnlySkimmedNoCov") || context.mOptions.get("processBarrelOnlySkimmedNoCovWithMultExtra") || context.mOptions.get("processBarrelOnlyWithQvectorCentrSkimmedNoCov"); fEnableBarrelMixingHistos = context.mOptions.get("processMixingAllSkimmed") || context.mOptions.get("processMixingBarrelSkimmed"); fEnableMuonHistos = context.mOptions.get("processAllSkimmed") || context.mOptions.get("processMuonOnlySkimmed") || context.mOptions.get("processMuonOnlySkimmedMultExtra") || context.mOptions.get("processMixingMuonSkimmed"); fEnableMuonMixingHistos = context.mOptions.get("processMixingAllSkimmed") || context.mOptions.get("processMixingMuonSkimmed"); @@ -1335,10 +1335,10 @@ struct AnalysisSameEventPairing { if (fConfigML.applyBDT) { // BDT cuts via JSON - std::vector binsPtMl; + std::vector binsMl; o2::framework::LabeledArray cutsMl; std::vector cutDirMl; - int nClassesMl = 1; // 1 for binary BDT, 3 for multiclass BDT + int nClassesMl = 1; std::vector namesInputFeatures; std::vector onnxFileNames; @@ -1346,23 +1346,31 @@ struct AnalysisSameEventPairing { if (std::holds_alternative(config)) { auto& cfg = std::get(config); - binsPtMl = cfg.binsPt; + binsMl = cfg.binsMl; nClassesMl = 1; cutsMl = cfg.cutsMl; cutDirMl = cfg.cutDirs; namesInputFeatures = cfg.inputFeatures; onnxFileNames = cfg.onnxFiles; + dqMlResponse.setBinsCent(cfg.binsCent); + dqMlResponse.setBinsPt(cfg.binsPt); + dqMlResponse.setCentType(cfg.centType); + LOG(info) << "Using BDT cuts for binary classification"; } else { auto& cfg = std::get(config); - binsPtMl = cfg.binsPt; + binsMl = cfg.binsMl; nClassesMl = 3; cutsMl = cfg.cutsMl; cutDirMl = cfg.cutDirs; namesInputFeatures = cfg.inputFeatures; onnxFileNames = cfg.onnxFiles; + dqMlResponse.setBinsCent(cfg.binsCent); + dqMlResponse.setBinsPt(cfg.binsPt); 
+ dqMlResponse.setCentType(cfg.centType); + LOG(info) << "Using BDT cuts for multiclass classification"; } - dqMlResponse.configure(binsPtMl, cutsMl, cutDirMl, nClassesMl); + dqMlResponse.configure(binsMl, cutsMl, cutDirMl, nClassesMl); if (fConfigML.loadModelsFromCCDB) { fCCDBApi.init(fConfigCCDB.url); dqMlResponse.setModelPathsCCDB(onnxFileNames, fCCDBApi, fConfigML.modelPathsCCDB, fConfigML.timestampCCDB); @@ -1762,14 +1770,36 @@ struct AnalysisSameEventPairing { if constexpr ((TTrackFillMap & VarManager::ObjTypes::ReducedTrackBarrelPID) > 0) { if (fConfigML.applyBDT) { std::vector dqInputFeatures = dqMlResponse.getInputFeatures(t1, t2, VarManager::fgValues); + LOG(debug) << "Input features size: " << dqInputFeatures.size(); if (dqInputFeatures.empty()) { LOG(fatal) << "Input features for ML selection are empty! Please check your configuration."; return; } - // isSelectedBDT = dqMlResponse.isSelectedMl(dqInputFeatures, VarManager::fgValues[VarManager::kPt]); - isSelectedBDT = dqMlResponse.isSelectedMl(dqInputFeatures, VarManager::fgValues[VarManager::kPt], outputMlPsi2ee); + int modelIndex = -1; + const auto& binsCent = dqMlResponse.getBinsCent(); + const auto& binsPt = dqMlResponse.getBinsPt(); + const std::string& centType = dqMlResponse.getCentType(); + + if ("kCentFT0C" == centType) { + modelIndex = o2::aod::dqmlcuts::getMlBinIndex(VarManager::fgValues[VarManager::kCentFT0C], VarManager::fgValues[VarManager::kPt], binsCent, binsPt); + } else if ("kCentFT0A" == centType) { + modelIndex = o2::aod::dqmlcuts::getMlBinIndex(VarManager::fgValues[VarManager::kCentFT0A], VarManager::fgValues[VarManager::kPt], binsCent, binsPt); + } else if ("kCentFT0M" == centType) { + modelIndex = o2::aod::dqmlcuts::getMlBinIndex(VarManager::fgValues[VarManager::kCentFT0M], VarManager::fgValues[VarManager::kPt], binsCent, binsPt); + } else { + LOG(fatal) << "Unknown centrality estimation type: " << centType; + return; + } + + if (modelIndex < 0) { + LOG(info) << "Ml index is 
negative! This means that the centrality/pt is not in the range of the model bins."; + continue; + } + + LOG(debug) << "Model index: " << modelIndex << ", pT: " << VarManager::fgValues[VarManager::kPt] << ", centrality (kCentFT0C): " << VarManager::fgValues[VarManager::kCentFT0C]; + isSelectedBDT = dqMlResponse.isSelectedMl(dqInputFeatures, modelIndex, outputMlPsi2ee); VarManager::FillBdtScore(outputMlPsi2ee); // TODO: check if this is needed or not } @@ -2225,13 +2255,6 @@ struct AnalysisSameEventPairing { runSameEventPairing(events, trackAssocsPerCollision, barrelAssocs, barrelTracks); } - void processBarrelOnlySkimmedBDT(MyEventsVtxCovSelected const& events, - soa::Join const& barrelAssocs, - MyBarrelTracksWithCovWithAmbiguities const& barrelTracks) - { - runSameEventPairing(events, trackAssocsPerCollision, barrelAssocs, barrelTracks); - } - void processMuonOnlySkimmed(MyEventsVtxCovSelected const& events, soa::Join const& muonAssocs, MyMuonTracksWithCovWithAmbiguities const& muons) { @@ -2275,7 +2298,6 @@ struct AnalysisSameEventPairing { PROCESS_SWITCH(AnalysisSameEventPairing, processBarrelOnlySkimmedNoCov, "Run barrel only pairing (no covariances), with skimmed tracks and with collision information", false); PROCESS_SWITCH(AnalysisSameEventPairing, processBarrelOnlySkimmedNoCovWithMultExtra, "Run barrel only pairing (no covariances), with skimmed tracks, with collision information, with MultsExtra", false); PROCESS_SWITCH(AnalysisSameEventPairing, processBarrelOnlyWithQvectorCentrSkimmedNoCov, "Run barrel only pairing (no covariances), with skimmed tracks, with Qvector from central framework", false); - PROCESS_SWITCH(AnalysisSameEventPairing, processBarrelOnlySkimmedBDT, "Run electron-electron pairing, with skimmed tracks and BDT selection", false); PROCESS_SWITCH(AnalysisSameEventPairing, processMuonOnlySkimmed, "Run muon only pairing, with skimmed tracks", false); PROCESS_SWITCH(AnalysisSameEventPairing, processMuonOnlySkimmedMultExtra, "Run muon only 
pairing, with skimmed tracks", false); PROCESS_SWITCH(AnalysisSameEventPairing, processMixingAllSkimmed, "Run all types of mixed pairing, with skimmed tracks/muons", false); From 6516a476e6fc8772af97f180303e662cd0e0bda9 Mon Sep 17 00:00:00 2001 From: Jseo Date: Mon, 4 Aug 2025 15:44:31 +0200 Subject: [PATCH 6/7] clang format --- PWGDQ/Core/CutsLibrary.cxx | 8 ++++---- PWGDQ/Core/CutsLibrary.h | 7 ++++--- PWGDQ/Tasks/tableReader.cxx | 2 +- PWGDQ/Tasks/tableReader_withAssoc.cxx | 3 +-- 4 files changed, 10 insertions(+), 10 deletions(-) diff --git a/PWGDQ/Core/CutsLibrary.cxx b/PWGDQ/Core/CutsLibrary.cxx index 6c26bfcac01..14d8b6aa256 100644 --- a/PWGDQ/Core/CutsLibrary.cxx +++ b/PWGDQ/Core/CutsLibrary.cxx @@ -7284,7 +7284,7 @@ o2::aod::dqmlcuts::BdtScoreConfig o2::aod::dqmlcuts::GetBdtScoreCutsAndConfigFro return binaryCfg; - // MultiClass + // MultiClass } else if (typeStr == "MultiClass") { dqmlcuts::MultiClassBdtScoreConfig multiCfg; multiCfg.inputFeatures = namesInputFeatures; @@ -7323,12 +7323,12 @@ int o2::aod::dqmlcuts::getMlBinIndex(double cent, double pt, const std::vector>& binsCent, const std::vector>& binsPt) { - LOG(debug) << "Searching for Ml bin index for cent: " << cent << ", pt: " << pt; //here + LOG(debug) << "Searching for Ml bin index for cent: " << cent << ", pt: " << pt; for (size_t i = 0; i < binsCent.size(); ++i) { if (cent >= binsCent[i].first && cent < binsCent[i].second && pt >= binsPt[i].first && pt < binsPt[i].second) { - LOG(debug) << " - Found at index: " << i; //here + LOG(debug) << " - Found at index: " << i; return static_cast(i); } } return -1; // not found -} \ No newline at end of file +} diff --git a/PWGDQ/Core/CutsLibrary.h b/PWGDQ/Core/CutsLibrary.h index f6874543d8e..b38577f3c2b 100644 --- a/PWGDQ/Core/CutsLibrary.h +++ b/PWGDQ/Core/CutsLibrary.h @@ -15,12 +15,13 @@ #ifndef PWGDQ_CORE_CUTSLIBRARY_H_ #define PWGDQ_CORE_CUTSLIBRARY_H_ -#include -#include -#include "PWGDQ/Core/AnalysisCut.h" #include 
"PWGDQ/Core/AnalysisCompositeCut.h" +#include "PWGDQ/Core/AnalysisCut.h" #include "PWGDQ/Core/VarManager.h" +#include +#include + // /////////////////////////////////////////////// // These are the Cuts used in the CEFP Task // // to select tracks in the event selection // diff --git a/PWGDQ/Tasks/tableReader.cxx b/PWGDQ/Tasks/tableReader.cxx index 2c96d5c0960..d96a40359e7 100644 --- a/PWGDQ/Tasks/tableReader.cxx +++ b/PWGDQ/Tasks/tableReader.cxx @@ -14,13 +14,13 @@ #include "PWGDQ/Core/AnalysisCompositeCut.h" #include "PWGDQ/Core/AnalysisCut.h" #include "PWGDQ/Core/CutsLibrary.h" +#include "PWGDQ/Core/DQMlResponse.h" #include "PWGDQ/Core/HistogramManager.h" #include "PWGDQ/Core/HistogramsLibrary.h" #include "PWGDQ/Core/MixingHandler.h" #include "PWGDQ/Core/MixingLibrary.h" #include "PWGDQ/Core/VarManager.h" #include "PWGDQ/DataModel/ReducedInfoTables.h" -#include "PWGDQ/Core/DQMlResponse.h" #include "Common/CCDB/EventSelectionParams.h" diff --git a/PWGDQ/Tasks/tableReader_withAssoc.cxx b/PWGDQ/Tasks/tableReader_withAssoc.cxx index 2d320f3619a..31aee1be49a 100644 --- a/PWGDQ/Tasks/tableReader_withAssoc.cxx +++ b/PWGDQ/Tasks/tableReader_withAssoc.cxx @@ -15,13 +15,13 @@ #include "PWGDQ/Core/AnalysisCompositeCut.h" #include "PWGDQ/Core/AnalysisCut.h" #include "PWGDQ/Core/CutsLibrary.h" +#include "PWGDQ/Core/DQMlResponse.h" #include "PWGDQ/Core/HistogramManager.h" #include "PWGDQ/Core/HistogramsLibrary.h" #include "PWGDQ/Core/MixingHandler.h" #include "PWGDQ/Core/MixingLibrary.h" #include "PWGDQ/Core/VarManager.h" #include "PWGDQ/DataModel/ReducedInfoTables.h" -#include "PWGDQ/Core/DQMlResponse.h" #include "Common/CCDB/EventSelectionParams.h" #include "Common/Core/TableHelper.h" @@ -1770,7 +1770,6 @@ struct AnalysisSameEventPairing { if constexpr ((TTrackFillMap & VarManager::ObjTypes::ReducedTrackBarrelPID) > 0) { if (fConfigML.applyBDT) { std::vector dqInputFeatures = dqMlResponse.getInputFeatures(t1, t2, VarManager::fgValues); - LOG(debug) << "Input features size: 
" << dqInputFeatures.size(); if (dqInputFeatures.empty()) { LOG(fatal) << "Input features for ML selection are empty! Please check your configuration."; From f90b7717ccb19f898c5791cc89507f57c125a6bf Mon Sep 17 00:00:00 2001 From: Jseo Date: Tue, 5 Aug 2025 16:08:58 +0200 Subject: [PATCH 7/7] apply comments --- PWGDQ/Core/VarManager.h | 2 ++ PWGDQ/Tasks/tableReader.cxx | 40 +++++++++++++-------------- PWGDQ/Tasks/tableReader_withAssoc.cxx | 40 +++++++++++++-------------- 3 files changed, 42 insertions(+), 40 deletions(-) diff --git a/PWGDQ/Core/VarManager.h b/PWGDQ/Core/VarManager.h index 2acf33cf8e8..d49e71677da 100644 --- a/PWGDQ/Core/VarManager.h +++ b/PWGDQ/Core/VarManager.h @@ -5531,6 +5531,8 @@ float VarManager::calculatePhiV(T1 const& t1, T2 const& t2) return pairPhiV; } +/// Fill BDT score values. +/// Supports binary (1 output) and multiclass (3 outputs) models. template void VarManager::FillBdtScore(T1 const& bdtScore, float* values) { diff --git a/PWGDQ/Tasks/tableReader.cxx b/PWGDQ/Tasks/tableReader.cxx index d96a40359e7..102eb26100f 100644 --- a/PWGDQ/Tasks/tableReader.cxx +++ b/PWGDQ/Tasks/tableReader.cxx @@ -1057,7 +1057,7 @@ struct AnalysisSameEventPairing { Configurable fConfigAddJSONHistograms{"cfgAddJSONHistograms", "", "Histograms in JSON format"}; // ML inference Configurable applyBDT{"applyBDT", false, "Flag to apply ML selections"}; - Configurable fConfigBdtCutsJSON{"fConfigBdtCutsJSON", "", "Additional list of BDT cuts in JSON format"}; + Configurable fConfigBdtCutsJSON{"fConfigBdtCutsJSON", "", "Additional list of BDT cuts in JSON format"}; Configurable> modelPathsCCDB{"modelPathsCCDB", std::vector{"Users/j/jseo/ML/PbPbPsi/default/"}, "Paths of models on CCDB"}; Configurable timestampCCDB{"timestampCCDB", -1, "timestamp of the ONNX file for ML model used to query in CCDB"}; Configurable loadModelsFromCCDB{"loadModelsFromCCDB", false, "Flag to enable or disable the loading of models from CCDB"}; @@ -1078,8 +1078,8 @@ struct 
AnalysisSameEventPairing { HistogramManager* fHistMan; - o2::analysis::DQMlResponse dqMlResponse; - std::vector outputMlPsi2ee = {}; // TODO: check this is needed or not + o2::analysis::DQMlResponse fDQMlResponse; + std::vector fOutputMlPsi2ee = {}; // TODO: check this is needed or not o2::ccdb::CcdbApi ccdbApi; // NOTE: The track filter produced by the barrel track selection contain a number of electron cut decisions and one last cut for hadrons used in the @@ -1151,9 +1151,9 @@ struct AnalysisSameEventPairing { cutDirMl = cfg.cutDirs; namesInputFeatures = cfg.inputFeatures; onnxFileNames = cfg.onnxFiles; - dqMlResponse.setBinsCent(cfg.binsCent); - dqMlResponse.setBinsPt(cfg.binsPt); - dqMlResponse.setCentType(cfg.centType); + fDQMlResponse.setBinsCent(cfg.binsCent); + fDQMlResponse.setBinsPt(cfg.binsPt); + fDQMlResponse.setCentType(cfg.centType); LOG(info) << "Using BDT cuts for binary classification"; } else { auto& cfg = std::get(config); @@ -1163,21 +1163,21 @@ struct AnalysisSameEventPairing { cutDirMl = cfg.cutDirs; namesInputFeatures = cfg.inputFeatures; onnxFileNames = cfg.onnxFiles; - dqMlResponse.setBinsCent(cfg.binsCent); - dqMlResponse.setBinsPt(cfg.binsPt); - dqMlResponse.setCentType(cfg.centType); + fDQMlResponse.setBinsCent(cfg.binsCent); + fDQMlResponse.setBinsPt(cfg.binsPt); + fDQMlResponse.setCentType(cfg.centType); LOG(info) << "Using BDT cuts for multiclass classification"; } - dqMlResponse.configure(binsMl, cutsMl, cutDirMl, nClassesMl); + fDQMlResponse.configure(binsMl, cutsMl, cutDirMl, nClassesMl); if (loadModelsFromCCDB) { ccdbApi.init(ccdburl); - dqMlResponse.setModelPathsCCDB(onnxFileNames, ccdbApi, modelPathsCCDB, timestampCCDB); + fDQMlResponse.setModelPathsCCDB(onnxFileNames, ccdbApi, modelPathsCCDB, timestampCCDB); } else { - dqMlResponse.setModelPathsLocal(onnxFileNames); + fDQMlResponse.setModelPathsLocal(onnxFileNames); } - dqMlResponse.cacheInputFeaturesIndices(namesInputFeatures); - dqMlResponse.init(); + 
fDQMlResponse.cacheInputFeaturesIndices(namesInputFeatures); + fDQMlResponse.init(); } if (context.mOptions.get("processDecayToEESkimmed") || context.mOptions.get("processDecayToEESkimmedNoTwoProngFitter") || context.mOptions.get("processDecayToEESkimmedWithCov") || context.mOptions.get("processDecayToEESkimmedWithCovNoTwoProngFitter") || context.mOptions.get("processDecayToEEVertexingSkimmed") || context.mOptions.get("processVnDecayToEESkimmed") || context.mOptions.get("processDecayToEEPrefilterSkimmed") || context.mOptions.get("processDecayToEEPrefilterSkimmedNoTwoProngFitter") || context.mOptions.get("processDecayToEESkimmedWithColl") || context.mOptions.get("processDecayToEESkimmedWithCollNoTwoProngFitter") || context.mOptions.get("processDecayToPiPiSkimmed") || context.mOptions.get("processAllSkimmed")) { @@ -1457,7 +1457,7 @@ struct AnalysisSameEventPairing { } if constexpr ((TPairType == pairTypeEE) && (TTrackFillMap & VarManager::ObjTypes::ReducedTrackBarrelPID) > 0) { if (applyBDT) { - std::vector dqInputFeatures = dqMlResponse.getInputFeatures(t1, t2, VarManager::fgValues); + std::vector dqInputFeatures = fDQMlResponse.getInputFeatures(t1, t2, VarManager::fgValues); if (dqInputFeatures.empty()) { LOG(fatal) << "Input features for ML selection are empty! 
Please check your configuration."; @@ -1465,9 +1465,9 @@ struct AnalysisSameEventPairing { } int modelIndex = -1; - const auto& binsCent = dqMlResponse.getBinsCent(); - const auto& binsPt = dqMlResponse.getBinsPt(); - const std::string& centType = dqMlResponse.getCentType(); + const auto& binsCent = fDQMlResponse.getBinsCent(); + const auto& binsPt = fDQMlResponse.getBinsPt(); + const std::string& centType = fDQMlResponse.getCentType(); if ("kCentFT0C" == centType) { modelIndex = o2::aod::dqmlcuts::getMlBinIndex(VarManager::fgValues[VarManager::kCentFT0C], VarManager::fgValues[VarManager::kPt], binsCent, binsPt); @@ -1486,8 +1486,8 @@ struct AnalysisSameEventPairing { } LOG(debug) << "Model index: " << modelIndex << ", pT: " << VarManager::fgValues[VarManager::kPt] << ", centrality (kCentFT0C): " << VarManager::fgValues[VarManager::kCentFT0C]; - isSelectedBDT = dqMlResponse.isSelectedMl(dqInputFeatures, modelIndex, outputMlPsi2ee); - VarManager::FillBdtScore(outputMlPsi2ee); // TODO: check if this is needed or not + isSelectedBDT = fDQMlResponse.isSelectedMl(dqInputFeatures, modelIndex, fOutputMlPsi2ee); + VarManager::FillBdtScore(fOutputMlPsi2ee); // TODO: check if this is needed or not } if (applyBDT && !isSelectedBDT) diff --git a/PWGDQ/Tasks/tableReader_withAssoc.cxx b/PWGDQ/Tasks/tableReader_withAssoc.cxx index 31aee1be49a..df3e4e2587d 100644 --- a/PWGDQ/Tasks/tableReader_withAssoc.cxx +++ b/PWGDQ/Tasks/tableReader_withAssoc.cxx @@ -1249,7 +1249,7 @@ struct AnalysisSameEventPairing { } fConfigOptions; struct : ConfigurableGroup { Configurable applyBDT{"applyBDT", false, "Flag to apply ML selections"}; - Configurable fConfigBdtCutsJSON{"fConfigBdtCutsJSON", "", "Additional list of BDT cuts in JSON format"}; + Configurable fConfigBdtCutsJSON{"fConfigBdtCutsJSON", "", "Additional list of BDT cuts in JSON format"}; Configurable> modelPathsCCDB{"modelPathsCCDB", std::vector{"Users/j/jseo/ML/PbPbPsi/default/"}, "Paths of models on CCDB"}; Configurable 
timestampCCDB{"timestampCCDB", -1, "timestamp of the ONNX file for ML model used to query in CCDB"}; @@ -1263,8 +1263,8 @@ struct AnalysisSameEventPairing { HistogramManager* fHistMan; - o2::analysis::DQMlResponse dqMlResponse; - std::vector outputMlPsi2ee = {}; // TODO: check this is needed or not + o2::analysis::DQMlResponse fDQMlResponse; + std::vector fOutputMlPsi2ee = {}; // TODO: check this is needed or not // keep histogram class names in maps, so we don't have to buld their names in the pair loops std::map> fTrackHistNames; @@ -1352,9 +1352,9 @@ struct AnalysisSameEventPairing { cutDirMl = cfg.cutDirs; namesInputFeatures = cfg.inputFeatures; onnxFileNames = cfg.onnxFiles; - dqMlResponse.setBinsCent(cfg.binsCent); - dqMlResponse.setBinsPt(cfg.binsPt); - dqMlResponse.setCentType(cfg.centType); + fDQMlResponse.setBinsCent(cfg.binsCent); + fDQMlResponse.setBinsPt(cfg.binsPt); + fDQMlResponse.setCentType(cfg.centType); LOG(info) << "Using BDT cuts for binary classification"; } else { auto& cfg = std::get(config); @@ -1364,21 +1364,21 @@ struct AnalysisSameEventPairing { cutDirMl = cfg.cutDirs; namesInputFeatures = cfg.inputFeatures; onnxFileNames = cfg.onnxFiles; - dqMlResponse.setBinsCent(cfg.binsCent); - dqMlResponse.setBinsPt(cfg.binsPt); - dqMlResponse.setCentType(cfg.centType); + fDQMlResponse.setBinsCent(cfg.binsCent); + fDQMlResponse.setBinsPt(cfg.binsPt); + fDQMlResponse.setCentType(cfg.centType); LOG(info) << "Using BDT cuts for multiclass classification"; } - dqMlResponse.configure(binsMl, cutsMl, cutDirMl, nClassesMl); + fDQMlResponse.configure(binsMl, cutsMl, cutDirMl, nClassesMl); if (fConfigML.loadModelsFromCCDB) { fCCDBApi.init(fConfigCCDB.url); - dqMlResponse.setModelPathsCCDB(onnxFileNames, fCCDBApi, fConfigML.modelPathsCCDB, fConfigML.timestampCCDB); + fDQMlResponse.setModelPathsCCDB(onnxFileNames, fCCDBApi, fConfigML.modelPathsCCDB, fConfigML.timestampCCDB); } else { - dqMlResponse.setModelPathsLocal(onnxFileNames); + 
fDQMlResponse.setModelPathsLocal(onnxFileNames); } - dqMlResponse.cacheInputFeaturesIndices(namesInputFeatures); - dqMlResponse.init(); + fDQMlResponse.cacheInputFeaturesIndices(namesInputFeatures); + fDQMlResponse.init(); } // get the barrel track selection cuts @@ -1769,7 +1769,7 @@ struct AnalysisSameEventPairing { dielectronsExtraList(t1.globalIndex(), t2.globalIndex(), VarManager::fgValues[VarManager::kVertexingTauzProjected], VarManager::fgValues[VarManager::kVertexingLzProjected], VarManager::fgValues[VarManager::kVertexingLxyProjected]); if constexpr ((TTrackFillMap & VarManager::ObjTypes::ReducedTrackBarrelPID) > 0) { if (fConfigML.applyBDT) { - std::vector dqInputFeatures = dqMlResponse.getInputFeatures(t1, t2, VarManager::fgValues); + std::vector dqInputFeatures = fDQMlResponse.getInputFeatures(t1, t2, VarManager::fgValues); if (dqInputFeatures.empty()) { LOG(fatal) << "Input features for ML selection are empty! Please check your configuration."; @@ -1777,9 +1777,9 @@ struct AnalysisSameEventPairing { } int modelIndex = -1; - const auto& binsCent = dqMlResponse.getBinsCent(); - const auto& binsPt = dqMlResponse.getBinsPt(); - const std::string& centType = dqMlResponse.getCentType(); + const auto& binsCent = fDQMlResponse.getBinsCent(); + const auto& binsPt = fDQMlResponse.getBinsPt(); + const std::string& centType = fDQMlResponse.getCentType(); if ("kCentFT0C" == centType) { modelIndex = o2::aod::dqmlcuts::getMlBinIndex(VarManager::fgValues[VarManager::kCentFT0C], VarManager::fgValues[VarManager::kPt], binsCent, binsPt); @@ -1798,8 +1798,8 @@ struct AnalysisSameEventPairing { } LOG(debug) << "Model index: " << modelIndex << ", pT: " << VarManager::fgValues[VarManager::kPt] << ", centrality (kCentFT0C): " << VarManager::fgValues[VarManager::kCentFT0C]; - isSelectedBDT = dqMlResponse.isSelectedMl(dqInputFeatures, modelIndex, outputMlPsi2ee); - VarManager::FillBdtScore(outputMlPsi2ee); // TODO: check if this is needed or not + isSelectedBDT = 
fDQMlResponse.isSelectedMl(dqInputFeatures, modelIndex, fOutputMlPsi2ee); + VarManager::FillBdtScore(fOutputMlPsi2ee); // TODO: check if this is needed or not } if (fConfigML.applyBDT && !isSelectedBDT)