diff --git a/PWGDQ/Core/CMakeLists.txt b/PWGDQ/Core/CMakeLists.txt index d19a66a68e6..41ceb661bf9 100644 --- a/PWGDQ/Core/CMakeLists.txt +++ b/PWGDQ/Core/CMakeLists.txt @@ -21,7 +21,7 @@ o2physics_add_library(PWGDQCore AnalysisCompositeCut.cxx MCProng.cxx MCSignal.cxx - PUBLIC_LINK_LIBRARIES O2::Framework O2::DCAFitter O2::GlobalTracking O2Physics::AnalysisCore KFParticle::KFParticle) + PUBLIC_LINK_LIBRARIES O2::Framework O2::DCAFitter O2::GlobalTracking O2Physics::AnalysisCore KFParticle::KFParticle O2Physics::MLCore) o2physics_target_root_dictionary(PWGDQCore HEADERS AnalysisCut.h diff --git a/PWGDQ/Core/CutsLibrary.cxx b/PWGDQ/Core/CutsLibrary.cxx index 5cf99d6f22f..14d8b6aa256 100644 --- a/PWGDQ/Core/CutsLibrary.cxx +++ b/PWGDQ/Core/CutsLibrary.cxx @@ -12,14 +12,19 @@ // Contact: iarsene@cern.ch, i.c.arsene@fys.uio.no // #include "PWGDQ/Core/CutsLibrary.h" -#include -#include -#include -#include -#include + #include "AnalysisCompositeCut.h" #include "VarManager.h" +#include + +#include + +#include +#include +#include +#include + using std::cout; using std::endl; @@ -7100,3 +7105,230 @@ AnalysisCompositeCut* o2::aod::dqcuts::ParseJSONAnalysisCompositeCut(T cut, cons return retCut; } + +//________________________________________________________________________________________________ +o2::aod::dqmlcuts::BdtScoreConfig o2::aod::dqmlcuts::GetBdtScoreCutsAndConfigFromJSON(const char* json) +{ + LOG(info) << "========================================== interpreting JSON for ML analysis cuts"; + if (!json) { + LOG(fatal) << "JSON config string is null!"; + return {}; + } + LOG(info) << "JSON string: " << json; + + rapidjson::Document document; + + // Check that the json is parsed correctly + rapidjson::ParseResult ok = document.Parse(json); + if (!ok) { + LOG(fatal) << "JSON parse error: " << rapidjson::GetParseErrorFunc(ok.Code()) << " (" << ok.Offset() << ")"; + return {}; + } + + for (auto it = document.MemberBegin(); it != document.MemberEnd(); ++it) { + const auto& obj = it->value; + + // Classification type + if (!obj.HasMember("type")) { + LOG(fatal) << "Missing type (Binary/MultiClass)"; + return {}; + } + TString typeStr = obj["type"].GetString(); + if (typeStr != "Binary" && typeStr != "MultiClass") { + LOG(fatal) << "Unsupported classification type: " << typeStr; + return {}; + } + + // Input features + if (!obj.HasMember("inputFeatures") || !obj["inputFeatures"].IsArray()) { + LOG(fatal) << "Missing inputFeatures member or array"; + return {}; + } + std::vector namesInputFeatures; + for (const auto& feature : obj["inputFeatures"].GetArray()) { + namesInputFeatures.emplace_back(feature.GetString()); + LOG(debug) << "Input features: " << feature.GetString(); + } + + // Model files + if (!obj.HasMember("modelFiles") || !obj["modelFiles"].IsArray()) { + LOG(fatal) << "Missing modelFiles member or array"; + return {}; + } + std::vector onnxFileNames; + for (const auto& model : obj["modelFiles"].GetArray()) { + onnxFileNames.emplace_back(model.GetString()); + LOG(debug) << "Model Files: " << model.GetString() << " "; + } + + // Centrality estimation type + if (!obj.HasMember("cent") || !obj["cent"].IsString()) { + LOG(fatal) << "Missing cent member"; + return {}; + } + std::string cent = obj["cent"].GetString(); + LOG(debug) << "Centrality type: " << cent; + if (cent != "kCentFT0C" && cent != "kCentFT0A" && cent != "kCentFT0M") { + LOG(fatal) << "Unsupported centrality type: " << cent; + return {}; + } + + // Cut storage + std::vector> centBins; + std::vector> ptBins; + std::vector> cutsMl; + std::vector cutDirs; + std::vector labelsFlatBin; + bool cutDirsFilled = false; + + for (auto centMember = obj.MemberBegin(); centMember != obj.MemberEnd(); ++centMember) { + TString centKey = centMember->name.GetString(); + if (!centKey.Contains("AddCentCut")) + continue; + + const auto& centCut = centMember->value; + + // Centrality info + if (!centCut.HasMember("centMin") || !centCut.HasMember("centMax")) { + LOG(fatal) << "Missing centMin/centMax in " << centKey; + return {}; + } + double centMin = centCut["centMin"].GetDouble(); + double centMax = centCut["centMax"].GetDouble(); + + for (auto ptMember = centCut.MemberBegin(); ptMember != centCut.MemberEnd(); ++ptMember) { + TString ptKey = ptMember->name.GetString(); + if (!ptKey.Contains("AddPtCut")) + continue; + + const auto& ptCut = ptMember->value; + + // Pt info + if (!ptCut.HasMember("pTMin") || !ptCut.HasMember("pTMax")) { + LOG(fatal) << "Missing pTMin/pTMax in " << ptKey; + return {}; + } + + double ptMin = ptCut["pTMin"].GetDouble(); + double ptMax = ptCut["pTMax"].GetDouble(); + + std::vector binCuts; + bool exclude = false; + + for (auto mlMember = ptCut.MemberBegin(); mlMember != ptCut.MemberEnd(); ++mlMember) { + TString mlKey = mlMember->name.GetString(); + if (!mlKey.Contains("AddMLCut")) + continue; + + const auto& mlcut = mlMember->value; + + if (!mlcut.HasMember("cut")) { + LOG(fatal) << "Missing cut (score) in " << mlKey; + return {}; + } + + double cutVal = mlcut["cut"].GetDouble(); + exclude = mlcut.HasMember("exclude") ? mlcut["exclude"].GetBool() : false; + + binCuts.push_back(cutVal); + + if (!cutDirsFilled) { + cutDirs.push_back(exclude ? 0 : 1); + } + } + + if (!cutDirsFilled) { + cutDirsFilled = true; + } + + centBins.emplace_back(centMin, centMax); + ptBins.emplace_back(ptMin, ptMax); + cutsMl.push_back(binCuts); + labelsFlatBin.push_back(Form("%s_cent%.0f_%.0f_pt%.1f_%.1f", cent.c_str(), centMin, centMax, ptMin, ptMax)); + LOG(info) << "Added cut for " << Form("%s_cent%.0f_%.0f_pt%.1f_%.1f", cent.c_str(), centMin, centMax, ptMin, ptMax) << " with cuts: ["; + for (size_t i = 0; i < binCuts.size(); ++i) { + std::cout << binCuts[i]; + if (i != binCuts.size() - 1) + std::cout << ", "; + } + std::cout << "] and direction: " << (exclude ? "CutGreater" : "CutSmaller") << std::endl; + } + } + + if (cutDirs.size() != cutsMl[0].size()) { + LOG(fatal) << "Mismatch the cut size and direction size: cutsMl[0].size() = " << cutsMl[0].size() + << ", cutsMl[0].size() = " << cutDirs.size(); + return {}; + } + + std::vector labelsClass; + for (size_t j = 0; j < cutsMl[0].size(); ++j) { + labelsClass.push_back(Form("score class %d", static_cast(j))); + } + + size_t nFlatBins = cutsMl.size(); + std::vector binsMl(nFlatBins + 1); + std::iota(binsMl.begin(), binsMl.end(), 0); + + // Binary + if (typeStr == "Binary") { + dqmlcuts::BinaryBdtScoreConfig binaryCfg; + binaryCfg.inputFeatures = namesInputFeatures; + binaryCfg.onnxFiles = onnxFileNames; + binaryCfg.binsCent = centBins; + binaryCfg.binsPt = ptBins; + binaryCfg.binsMl = binsMl; + binaryCfg.cutDirs = cutDirs; + binaryCfg.centType = cent; + binaryCfg.cutsMl = makeLabeledCutsMl(cutsMl, labelsFlatBin, labelsClass); + + return binaryCfg; + + // MultiClass + } else if (typeStr == "MultiClass") { + dqmlcuts::MultiClassBdtScoreConfig multiCfg; + multiCfg.inputFeatures = namesInputFeatures; + multiCfg.onnxFiles = onnxFileNames; + multiCfg.binsCent = centBins; + multiCfg.binsPt = ptBins; + multiCfg.binsMl = binsMl; + multiCfg.cutDirs = cutDirs; + multiCfg.centType = cent; + multiCfg.cutsMl = makeLabeledCutsMl(cutsMl, labelsFlatBin, labelsClass); + + return multiCfg; + } + } + + return {}; +} + +o2::framework::LabeledArray o2::aod::dqmlcuts::makeLabeledCutsMl(const std::vector>& cuts, + const std::vector& labelsflatBin, + const std::vector& labelsClass) +{ + const size_t nRows = cuts.size(); + const size_t nCols = cuts.empty() ? 0 : cuts[0].size(); + std::vector flat; + + for (const auto& row : cuts) { + flat.insert(flat.end(), row.begin(), row.end()); + } + + o2::framework::Array2D arr(flat.data(), nRows, nCols); + return o2::framework::LabeledArray(arr, labelsflatBin, labelsClass); +} + +int o2::aod::dqmlcuts::getMlBinIndex(double cent, double pt, + const std::vector>& binsCent, + const std::vector>& binsPt) +{ + LOG(debug) << "Searching for Ml bin index for cent: " << cent << ", pt: " << pt; + for (size_t i = 0; i < binsCent.size(); ++i) { + if (cent >= binsCent[i].first && cent < binsCent[i].second && pt >= binsPt[i].first && pt < binsPt[i].second) { + LOG(debug) << " - Found at index: " << i; + return static_cast(i); + } + } + return -1; // not found +} diff --git a/PWGDQ/Core/CutsLibrary.h b/PWGDQ/Core/CutsLibrary.h index c6ad4caded2..b38577f3c2b 100644 --- a/PWGDQ/Core/CutsLibrary.h +++ b/PWGDQ/Core/CutsLibrary.h @@ -15,12 +15,13 @@ #ifndef PWGDQ_CORE_CUTSLIBRARY_H_ #define PWGDQ_CORE_CUTSLIBRARY_H_ -#include -#include -#include "PWGDQ/Core/AnalysisCut.h" #include "PWGDQ/Core/AnalysisCompositeCut.h" +#include "PWGDQ/Core/AnalysisCut.h" #include "PWGDQ/Core/VarManager.h" +#include +#include + // /////////////////////////////////////////////// // These are the Cuts used in the CEFP Task // // to select tracks in the event selection // @@ -119,6 +120,41 @@ bool ValidateJSONAnalysisCompositeCut(T cut); template AnalysisCompositeCut* ParseJSONAnalysisCompositeCut(T key, const char* cutName); } // namespace dqcuts +namespace dqmlcuts +{ +struct BinaryBdtScoreConfig { + std::vector inputFeatures; + std::vector onnxFiles; + std::vector> binsCent; // bins for centrality + std::vector> binsPt; // bins for pT + std::vector binsMl; // bins for flattened binning + std::string centType; + o2::framework::LabeledArray cutsMl; // BDT score cuts for each bin + std::vector cutDirs; // direction of the cuts on the BDT score +}; + +struct MultiClassBdtScoreConfig { + std::vector inputFeatures; + std::vector onnxFiles; + std::vector> binsCent; + std::vector> binsPt; + std::vector binsMl; + std::string centType; + o2::framework::LabeledArray cutsMl; + std::vector cutDirs; +}; + +using BdtScoreConfig = std::variant; + +BdtScoreConfig GetBdtScoreCutsAndConfigFromJSON(const char* json); + +o2::framework::LabeledArray makeLabeledCutsMl(const std::vector>& cuts, + const std::vector& labelsPt, + const std::vector& labelsClass); +int getMlBinIndex(double cent, double pt, + const std::vector>& binsCent, + const std::vector>& binsPt); +} // namespace dqmlcuts } // namespace o2::aod AnalysisCompositeCut* o2::aod::dqcuts::GetCompositeCut(const char* cutName); diff --git a/PWGDQ/Core/DQMlResponse.h b/PWGDQ/Core/DQMlResponse.h new file mode 100644 index 00000000000..64bfe233cc7 --- /dev/null +++ b/PWGDQ/Core/DQMlResponse.h @@ -0,0 +1,239 @@ +// Copyright 2019-2020 CERN and copyright holders of ALICE O2. +// See https://alice-o2.web.cern.ch/copyright for details of the copyright holders. +// All rights not expressly granted are reserved. +// +// This software is distributed under the terms of the GNU General Public +// License v3 (GPL Version 3), copied verbatim in the file "COPYING". +// +// In applying this license CERN does not waive the privileges and immunities +// granted to it by virtue of its status as an Intergovernmental Organization +// or submit itself to any jurisdiction. +// +// Contact: jseo@cern.ch +// +// Class to compute the ML response for DQ-analysis selections +// + +#ifndef PWGDQ_CORE_DQMLRESPONSE_H_ +#define PWGDQ_CORE_DQMLRESPONSE_H_ + +#include "Tools/ML/MlResponse.h" + +#include +#include +#include +#include + +namespace o2::analysis +{ + +enum class InputFeatures : uint8_t { // refer to DielectronsAll, TODO: add more features if needed + kMass = 0, + kPt, + kEta, + kPhi, + kPt1, + kITSChi2NCl1, + kTPCNClsCR1, + kTPCNClsFound1, + kTPCChi2NCl1, + kDcaXY1, + kDcaZ1, + kTPCNSigmaEl1, + kTPCNSigmaPi1, + kTPCNSigmaPr1, + kTOFNSigmaEl1, + kTOFNSigmaPi1, + kTOFNSigmaPr1, + kPt2, + kITSChi2NCl2, + kTPCNClsCR2, + kTPCNClsFound2, + kTPCChi2NCl2, + kDcaXY2, + kDcaZ2, + kTPCNSigmaEl2, + kTPCNSigmaPi2, + kTPCNSigmaPr2, + kTOFNSigmaEl2, + kTOFNSigmaPi2, + kTOFNSigmaPr2 +}; + +static const std::map gFeatureNameMap = { + {InputFeatures::kMass, "kMass"}, + {InputFeatures::kPt, "kPt"}, + {InputFeatures::kEta, "kEta"}, + {InputFeatures::kPhi, "kPhi"}, + {InputFeatures::kPt1, "kPt1"}, + {InputFeatures::kITSChi2NCl1, "kITSChi2NCl1"}, + {InputFeatures::kTPCNClsCR1, "kTPCNClsCR1"}, + {InputFeatures::kTPCNClsFound1, "kTPCNClsFound1"}, + {InputFeatures::kTPCChi2NCl1, "kTPCChi2NCl1"}, + {InputFeatures::kDcaXY1, "kDcaXY1"}, + {InputFeatures::kDcaZ1, "kDcaZ1"}, + {InputFeatures::kTPCNSigmaEl1, "kTPCNSigmaEl1"}, + {InputFeatures::kTPCNSigmaPi1, "kTPCNSigmaPi1"}, + {InputFeatures::kTPCNSigmaPr1, "kTPCNSigmaPr1"}, + {InputFeatures::kTOFNSigmaEl1, "kTOFNSigmaEl1"}, + {InputFeatures::kTOFNSigmaPi1, "kTOFNSigmaPi1"}, + {InputFeatures::kTOFNSigmaPr1, "kTOFNSigmaPr1"}, + {InputFeatures::kPt2, "kPt2"}, + {InputFeatures::kITSChi2NCl2, "kITSChi2NCl2"}, + {InputFeatures::kTPCNClsCR2, "kTPCNClsCR2"}, + {InputFeatures::kTPCNClsFound2, "kTPCNClsFound2"}, + {InputFeatures::kTPCChi2NCl2, "kTPCChi2NCl2"}, + {InputFeatures::kDcaXY2, "kDcaXY2"}, + {InputFeatures::kDcaZ2, "kDcaZ2"}, + {InputFeatures::kTPCNSigmaEl2, "kTPCNSigmaEl2"}, + {InputFeatures::kTPCNSigmaPi2, "kTPCNSigmaPi2"}, + {InputFeatures::kTPCNSigmaPr2, "kTPCNSigmaPr2"}, + {InputFeatures::kTOFNSigmaEl2, "kTOFNSigmaEl2"}, + {InputFeatures::kTOFNSigmaPi2, "kTOFNSigmaPi2"}, + {InputFeatures::kTOFNSigmaPr2, "kTOFNSigmaPr2"}}; + +template +class DQMlResponse : public MlResponse +{ + public: + DQMlResponse() = default; + virtual ~DQMlResponse() = default; + + void setBinsCent(const std::vector>& bins) { binsCent = bins; } + void setBinsPt(const std::vector>& bins) { binsPt = bins; } + void setCentType(std::string& type) { centType = type; } + + const std::vector>& getBinsCent() const { return binsCent; } + const std::vector>& getBinsPt() const { return binsPt; } + const std::string& getCentType() const { return centType; } + + /// Method to get the input features vector needed for ML inference + /// \return inputFeatures vector + template + std::vector getInputFeatures(const T1& t1, + const T2& t2, + const TValues& fg) const + { + std::vector dqInputFeatures; + dqInputFeatures.reserve(MlResponse::mCachedIndices.size()); + + for (auto idx : MlResponse::mCachedIndices) { + auto enumIdx = static_cast(idx); + auto mapIdx = gFeatureNameMap.find(enumIdx); + if (mapIdx == gFeatureNameMap.end()) { + LOG(fatal) << "Unknown InputFeatures index: " << static_cast(enumIdx); + } + + const auto& name = mapIdx->second; + if (name == "kMass") { + dqInputFeatures.push_back(fg[VarManager::fgVarNamesMap["kMass"]]); + } else if (name == "kPt") { + dqInputFeatures.push_back(fg[VarManager::fgVarNamesMap["kPt"]]); + } else if (name == "kEta") { + dqInputFeatures.push_back(fg[VarManager::fgVarNamesMap["kEta"]]); + } else if (name == "kPhi") { + dqInputFeatures.push_back(fg[VarManager::fgVarNamesMap["kPhi"]]); + } else if (name == "kPt1") { + dqInputFeatures.push_back(t1.pt()); + } else if (name == "kITSChi2NCl1") { + dqInputFeatures.push_back(t1.itsChi2NCl()); + } else if (name == "kTPCNClsCR1") { + dqInputFeatures.push_back(t1.tpcNClsCrossedRows()); + } else if (name == "kTPCNClsFound1") { + dqInputFeatures.push_back(t1.tpcNClsFound()); + } else if (name == "kTPCChi2NCl1") { + dqInputFeatures.push_back(t1.tpcChi2NCl()); + } else if (name == "kDcaXY1") { + dqInputFeatures.push_back(t1.dcaXY()); + } else if (name == "kDcaZ1") { + dqInputFeatures.push_back(t1.dcaZ()); + } else if (name == "kTPCNSigmaEl1") { + dqInputFeatures.push_back(t1.tpcNSigmaEl()); + } else if (name == "kTPCNSigmaPi1") { + dqInputFeatures.push_back(t1.tpcNSigmaPi()); + } else if (name == "kTPCNSigmaPr1") { + dqInputFeatures.push_back(t1.tpcNSigmaPr()); + } else if (name == "kTOFNSigmaEl1") { + dqInputFeatures.push_back(t1.tofNSigmaEl()); + } else if (name == "kTOFNSigmaPi1") { + dqInputFeatures.push_back(t1.tofNSigmaPi()); + } else if (name == "kTOFNSigmaPr1") { + dqInputFeatures.push_back(t1.tofNSigmaPr()); + } else if (name == "kPt2") { + dqInputFeatures.push_back(t2.pt()); + } else if (name == "kITSChi2NCl2") { + dqInputFeatures.push_back(t2.itsChi2NCl()); + } else if (name == "kTPCNClsCR2") { + dqInputFeatures.push_back(t2.tpcNClsCrossedRows()); + } else if (name == "kTPCNClsFound2") { + dqInputFeatures.push_back(t2.tpcNClsFound()); + } else if (name == "kTPCChi2NCl2") { + dqInputFeatures.push_back(t2.tpcChi2NCl()); + } else if (name == "kDcaXY2") { + dqInputFeatures.push_back(t2.dcaXY()); + } else if (name == "kDcaZ2") { + dqInputFeatures.push_back(t2.dcaZ()); + } else if (name == "kTPCNSigmaEl2") { + dqInputFeatures.push_back(t2.tpcNSigmaEl()); + } else if (name == "kTPCNSigmaPi2") { + dqInputFeatures.push_back(t2.tpcNSigmaPi()); + } else if (name == "kTPCNSigmaPr2") { + dqInputFeatures.push_back(t2.tpcNSigmaPr()); + } else if (name == "kTOFNSigmaEl2") { + dqInputFeatures.push_back(t2.tofNSigmaEl()); + } else if (name == "kTOFNSigmaPi2") { + dqInputFeatures.push_back(t2.tofNSigmaPi()); + } else if (name == "kTOFNSigmaPr2") { + dqInputFeatures.push_back(t2.tofNSigmaPr()); + } else { + LOG(fatal) << "Missing accessor for feature: " << name; + } + } + LOG(debug) << "Total features collected: " << dqInputFeatures.size(); + return dqInputFeatures; + } + + protected: + std::vector> binsCent; + std::vector> binsPt; + std::string centType; + + void setAvailableInputFeatures() + { + MlResponse::mAvailableInputFeatures = { + {"kMass", static_cast(InputFeatures::kMass)}, + {"kPt", static_cast(InputFeatures::kPt)}, + {"kEta", static_cast(InputFeatures::kEta)}, + {"kPhi", static_cast(InputFeatures::kPhi)}, + {"kPt1", static_cast(InputFeatures::kPt1)}, + {"kITSChi2NCl1", static_cast(InputFeatures::kITSChi2NCl1)}, + {"kTPCNClsCR1", static_cast(InputFeatures::kTPCNClsCR1)}, + {"kTPCNClsFound1", static_cast(InputFeatures::kTPCNClsFound1)}, + {"kTPCChi2NCl1", static_cast(InputFeatures::kTPCChi2NCl1)}, + {"kDcaXY1", static_cast(InputFeatures::kDcaXY1)}, + {"kDcaZ1", static_cast(InputFeatures::kDcaZ1)}, + {"kTPCNSigmaEl1", static_cast(InputFeatures::kTPCNSigmaEl1)}, + {"kTPCNSigmaPi1", static_cast(InputFeatures::kTPCNSigmaPi1)}, + {"kTPCNSigmaPr1", static_cast(InputFeatures::kTPCNSigmaPr1)}, + {"kTOFNSigmaEl1", static_cast(InputFeatures::kTOFNSigmaEl1)}, + {"kTOFNSigmaPi1", static_cast(InputFeatures::kTOFNSigmaPi1)}, + {"kTOFNSigmaPr1", static_cast(InputFeatures::kTOFNSigmaPr1)}, + {"kPt2", static_cast(InputFeatures::kPt2)}, + {"kITSChi2NCl2", static_cast(InputFeatures::kITSChi2NCl2)}, + {"kTPCNClsCR2", static_cast(InputFeatures::kTPCNClsCR2)}, + {"kTPCNClsFound2", static_cast(InputFeatures::kTPCNClsFound2)}, + {"kTPCChi2NCl2", static_cast(InputFeatures::kTPCChi2NCl2)}, + {"kDcaXY2", static_cast(InputFeatures::kDcaXY2)}, + {"kDcaZ2", static_cast(InputFeatures::kDcaZ2)}, + {"kTPCNSigmaEl2", static_cast(InputFeatures::kTPCNSigmaEl2)}, + {"kTPCNSigmaPi2", static_cast(InputFeatures::kTPCNSigmaPi2)}, + {"kTPCNSigmaPr2", static_cast(InputFeatures::kTPCNSigmaPr2)}, + {"kTOFNSigmaEl2", static_cast(InputFeatures::kTOFNSigmaEl2)}, + {"kTOFNSigmaPi2", static_cast(InputFeatures::kTOFNSigmaPi2)}, + {"kTOFNSigmaPr2", static_cast(InputFeatures::kTOFNSigmaPr2)}}; + } +}; + +} // namespace o2::analysis + +#endif // PWGDQ_CORE_DQMLRESPONSE_H_ diff --git a/PWGDQ/Core/VarManager.cxx b/PWGDQ/Core/VarManager.cxx index 912dc06740a..4bb8c419ad7 100644 --- a/PWGDQ/Core/VarManager.cxx +++ b/PWGDQ/Core/VarManager.cxx @@ -1131,6 +1131,12 @@ void VarManager::SetDefaultVarNames() fgVariableUnits[kS13] = "GeV^{2}/c^{4}"; fgVariableNames[kS23] = "m_{23}^{2}"; fgVariableUnits[kS23] = "GeV^{2}/c^{4}"; + fgVariableNames[kBdtBackground] = "kBdtBackground"; + fgVariableUnits[kBdtBackground] = " "; + fgVariableNames[kBdtPrompt] = "kBdtPrompt"; + fgVariableUnits[kBdtPrompt] = " "; + fgVariableNames[kBdtNonprompt] = "kBdtNonprompt"; + fgVariableUnits[kBdtNonprompt] = " "; // Set the variables short names map. This is needed for dynamic configuration via JSON files fgVarNamesMap["kNothing"] = kNothing; @@ -1770,4 +1776,7 @@ void VarManager::SetDefaultVarNames() fgVarNamesMap["kV24ME"] = kV24ME; fgVarNamesMap["kWV22ME"] = kWV22ME; fgVarNamesMap["kWV24ME"] = kWV24ME; + fgVarNamesMap["kBdtBackground"] = kBdtBackground; + fgVarNamesMap["kBdtPrompt"] = kBdtPrompt; + fgVarNamesMap["kBdtNonprompt"] = kBdtNonprompt; } diff --git a/PWGDQ/Core/VarManager.h b/PWGDQ/Core/VarManager.h index ec2ec55b83b..d49e71677da 100644 --- a/PWGDQ/Core/VarManager.h +++ b/PWGDQ/Core/VarManager.h @@ -855,6 +855,11 @@ class VarManager : public TObject // deltaMass_jpsi = kPairMass - kPairMassDau +3.096900 kDeltaMass_jpsi, + // BDT score + kBdtBackground, + kBdtPrompt, + kBdtNonprompt, + kNVars }; // end of Variables enumeration @@ -1127,6 +1132,8 @@ class VarManager : public TObject static void FillDileptonTrackTrackVertexing(C const& collision, T1 const& lepton1, T1 const& lepton2, T1 const& track1, T1 const& track2, float* values); template static void FillZDC(const T& zdc, float* values = nullptr); + template + static void FillBdtScore(const T& bdtScore, float* values = nullptr); static void SetCalibrationObject(CalibObjects calib, TObject* obj) { @@ -5524,4 +5531,24 @@ float VarManager::calculatePhiV(T1 const& t1, T2 const& t2) return pairPhiV; } +/// Fill BDT score values. +/// Supports binary (1 output) and multiclass (3 outputs) models. +template +void VarManager::FillBdtScore(T1 const& bdtScore, float* values) +{ + if (!values) { + values = fgValues; + } + + if (bdtScore.size() == 1) { + values[kBdtBackground] = bdtScore[0]; + } else if (bdtScore.size() == 3) { + values[kBdtBackground] = bdtScore[0]; + values[kBdtPrompt] = bdtScore[1]; + values[kBdtNonprompt] = bdtScore[2]; + } else { + LOG(warning) << "Unexpected number of BDT outputs: " << bdtScore.size(); + } +} + #endif // PWGDQ_CORE_VARMANAGER_H_ diff --git a/PWGDQ/Macros/bdtCut.json b/PWGDQ/Macros/bdtCut.json new file mode 100644 index 00000000000..2e015fa8618 --- /dev/null +++ b/PWGDQ/Macros/bdtCut.json @@ -0,0 +1,89 @@ +{ + "TestCut": { + "type": "Binary", + "title": "MyBDTModel", + "inputFeatures": [ + "kMass", + "kPt", + "kEta", + "kPhi", + "kPt1", + "kITSChi2NCl1", + "kTPCNClsCR1", + "kTPCNClsFound1", + "kTPCChi2NCl1", + "kDcaXY1", + "kDcaZ1", + "kTPCNSigmaEl1", + "kTPCNSigmaPi1", + "kTPCNSigmaPr1", + "kTOFNSigmaEl1", + "kTOFNSigmaPi1", + "kTOFNSigmaPr1", + "kPt2", + "kITSChi2NCl2", + "kTPCNClsCR2", + "kTPCNClsFound2", + "kTPCChi2NCl2", + "kDcaXY2", + "kDcaZ2", + "kTPCNSigmaEl2", + "kTPCNSigmaPi2", + "kTPCNSigmaPr2", + "kTOFNSigmaEl2", + "kTOFNSigmaPi2", + "kTOFNSigmaPr2" + ], + "modelFiles": [ + "cent_10_30_pt0_2_onnx.onnx", + "cent_10_30_pt2_20_onnx.onnx", + "cent_30_50_pt0_2_onnx.onnx", + "cent_30_50_pt2_20_onnx.onnx" + ], + "cent": "kCentFT0C", + "AddCentCut-Cent1030": { + "centMin": 10, + "centMax": 30, + "AddPtCut-pTBin1": { + "pTMin": 0, + "pTMax": 2, + "AddMLCut-background": { + "var": "kBdtBackground", + "cut": 0.5, + "exclude": false + } + }, + "AddPtCut-pTBin2": { + "pTMin": 2, + "pTMax": 20, + "AddMLCut-background": { + "var": "kBdtBackground", + "cut": 0.5, + "exclude": false + } + } + }, + "AddCentCut-Cent3050": { + "centMin": 30, + "centMax": 50, + "AddPtCut-pTBin1": { + "pTMin": 0, + "pTMax": 2, + "AddMLCut-background": { + "var": "kBdtBackground", + "cut": 0.5, + "exclude": false + } + }, + "AddPtCut-pTBin2": { + "pTMin": 2, + "pTMax": 20, + "AddMLCut-background": { + "var": "kBdtBackground", + "cut": 0.5, + "exclude": false + } + } + } + } +} diff --git a/PWGDQ/Macros/bdtCutMulti.json b/PWGDQ/Macros/bdtCutMulti.json new file mode 100644 index 00000000000..ba687e36200 --- /dev/null +++ b/PWGDQ/Macros/bdtCutMulti.json @@ -0,0 +1,129 @@ +{ + "TestCut": { + "type": "MultiClass", + "title": "MyBDTModel", + "inputFeatures": [ + "kMass", + "kPt", + "kEta", + "kPhi", + "kPt1", + "kITSChi2NCl1", + "kTPCNClsCR1", + "kTPCNClsFound1", + "kTPCChi2NCl1", + "kDcaXY1", + "kDcaZ1", + "kTPCNSigmaEl1", + "kTPCNSigmaPi1", + "kTPCNSigmaPr1", + "kTOFNSigmaEl1", + "kTOFNSigmaPi1", + "kTOFNSigmaPr1", + "kPt2", + "kITSChi2NCl2", + "kTPCNClsCR2", + "kTPCNClsFound2", + "kTPCChi2NCl2", + "kDcaXY2", + "kDcaZ2", + "kTPCNSigmaEl2", + "kTPCNSigmaPi2", + "kTPCNSigmaPr2", + "kTOFNSigmaEl2", + "kTOFNSigmaPi2", + "kTOFNSigmaPr2" + ], + "modelFiles": [ + "cent_10_30_pt_1_2_onnx.onnx", + "cent_10_30_pt_2_20_onnx.onnx", + "cent_30_50_pt_1_2_onnx.onnx", + "cent_30_50_pt_2_20_onnx.onnx" + ], + "cent": "kCentFT0C", + "AddCentCut-Cent1030": { + "centMin": 10, + "centMax": 30, + "AddPtCut-pTBin1": { + "pTMin": 1, + "pTMax": 2, + "AddMLCut-background": { + "var": "kBdtBackground", + "cut": 0.1, + "exclude": true + }, + "AddMLCut-prompt": { + "var": "kBdtPrompt", + "cut": 0.1, + "exclude": true + }, + "AddMLCut-nonprompt": { + "var": "kBdtNonprompt", + "cut": 0.5, + "exclude": false + } + }, + "AddPtCut-pTBin2": { + "pTMin": 2, + "pTMax": 20, + "AddMLCut-background": { + "var": "kBdtBackground", + "cut": 0.1, + "exclude": true + }, + "AddMLCut-prompt": { + "var": "kBdtPrompt", + "cut": 0.1, + "exclude": true + }, + "AddMLCut-nonprompt": { + "var": "kBdtNonprompt", + "cut": 0.5, + "exclude": false + } + } + }, + "AddCentCut-Cent3050": { + "centMin": 30, + "centMax": 50, + "AddPtCut-pTBin1": { + "pTMin": 1, + "pTMax": 2, + "AddMLCut-background": { + "var": "kBdtBackground", + "cut": 0.1, + "exclude": true + }, + "AddMLCut-prompt": { + "var": "kBdtPrompt", + "cut": 0.1, + "exclude": true + }, + "AddMLCut-nonprompt": { + "var": "kBdtNonprompt", + "cut": 0.5, + "exclude": false + } + }, + "AddPtCut-pTBin2": { + "pTMin": 2, + "pTMax": 20, + "AddMLCut-background": { + "var": "kBdtBackground", + "cut": 0.1, + "exclude": true + }, + "AddMLCut-prompt": { + "var": "kBdtPrompt", + "cut": 0.1, + "exclude": true + }, + "AddMLCut-nonprompt": { + "var": "kBdtNonprompt", + "cut": 0.5, + "exclude": false + } + } + } + } +} diff --git a/PWGDQ/Tasks/CMakeLists.txt b/PWGDQ/Tasks/CMakeLists.txt index 5095140a2b8..c3bb38bf955 100644 --- a/PWGDQ/Tasks/CMakeLists.txt +++ b/PWGDQ/Tasks/CMakeLists.txt @@ -11,12 +11,12 @@ o2physics_add_dpl_workflow(table-reader SOURCES tableReader.cxx - PUBLIC_LINK_LIBRARIES O2::Framework O2Physics::AnalysisCore O2Physics::PWGDQCore + PUBLIC_LINK_LIBRARIES O2::Framework O2Physics::AnalysisCore O2Physics::PWGDQCore O2Physics::MLCore COMPONENT_NAME Analysis) o2physics_add_dpl_workflow(table-reader-with-assoc SOURCES tableReader_withAssoc.cxx - PUBLIC_LINK_LIBRARIES O2::Framework O2::DetectorsBase O2Physics::AnalysisCore O2Physics::AnalysisCCDB O2Physics::PWGDQCore + PUBLIC_LINK_LIBRARIES O2::Framework O2::DetectorsBase O2Physics::AnalysisCore O2Physics::AnalysisCCDB O2Physics::PWGDQCore O2Physics::MLCore COMPONENT_NAME Analysis) o2physics_add_dpl_workflow(efficiency @@ -127,4 +127,4 @@ o2physics_add_dpl_workflow(model-converter-event-extended o2physics_add_dpl_workflow(tag-and-probe SOURCES TagAndProbe.cxx PUBLIC_LINK_LIBRARIES O2::Framework O2::DetectorsBase O2Physics::AnalysisCore O2Physics::AnalysisCCDB O2Physics::PWGDQCore - COMPONENT_NAME Analysis) \ No newline at end of file + COMPONENT_NAME Analysis) diff --git a/PWGDQ/Tasks/tableReader.cxx b/PWGDQ/Tasks/tableReader.cxx index 888be5a6d67..102eb26100f 100644 --- a/PWGDQ/Tasks/tableReader.cxx +++ b/PWGDQ/Tasks/tableReader.cxx @@ -14,6 +14,7 @@ #include "PWGDQ/Core/AnalysisCompositeCut.h" #include "PWGDQ/Core/AnalysisCut.h" #include "PWGDQ/Core/CutsLibrary.h" +#include "PWGDQ/Core/DQMlResponse.h" #include "PWGDQ/Core/HistogramManager.h" #include "PWGDQ/Core/HistogramsLibrary.h" #include "PWGDQ/Core/MixingHandler.h" @@ -1054,6 +1055,12 @@ struct AnalysisSameEventPairing { Configurable fCenterMassEnergy{"energy", 13600, "Center of mass energy in GeV"}; Configurable fConfigCumulants{"cfgCumulants", false, "If true, fill Cumulants with Weights different than 0"}; Configurable fConfigAddJSONHistograms{"cfgAddJSONHistograms", "", "Histograms in JSON format"}; + // ML inference + Configurable applyBDT{"applyBDT", false, "Flag to apply ML selections"}; + Configurable fConfigBdtCutsJSON{"fConfigBdtCutsJSON", "", "Additional list of BDT cuts in JSON format"}; + Configurable> modelPathsCCDB{"modelPathsCCDB", std::vector{"Users/j/jseo/ML/PbPbPsi/default/"}, "Paths of models on CCDB"}; + Configurable timestampCCDB{"timestampCCDB", -1, "timestamp of the ONNX file for ML model used to query in CCDB"}; + Configurable loadModelsFromCCDB{"loadModelsFromCCDB", false, "Flag to enable or disable the loading of models from CCDB"}; // Configurables to create output tree (flat tables or minitree) struct : ConfigurableGroup { @@ -1071,6 +1078,10 @@ struct AnalysisSameEventPairing { HistogramManager* fHistMan; + o2::analysis::DQMlResponse fDQMlResponse; + std::vector fOutputMlPsi2ee = {}; // TODO: check this is needed or not + o2::ccdb::CcdbApi ccdbApi; + // NOTE: The track filter produced by the barrel track selection contain a number of electron cut decisions and one last cut for hadrons used in the // dilepton - hadron task downstream. So the bit mask is required to select pairs just based on the electron cuts // TODO: provide as Configurable the list and names of the cuts which should be used in pairing @@ -1121,6 +1132,54 @@ struct AnalysisSameEventPairing { } } + if (applyBDT) { + // BDT cuts via JSON + std::vector binsMl; + o2::framework::LabeledArray cutsMl; + std::vector cutDirMl; + int nClassesMl = 1; + std::vector namesInputFeatures; + std::vector onnxFileNames; + + auto config = o2::aod::dqmlcuts::GetBdtScoreCutsAndConfigFromJSON(fConfigBdtCutsJSON.value.c_str()); + + if (std::holds_alternative(config)) { + auto& cfg = std::get(config); + binsMl = cfg.binsMl; + nClassesMl = 1; + cutsMl = cfg.cutsMl; + cutDirMl = cfg.cutDirs; + namesInputFeatures = cfg.inputFeatures; + onnxFileNames = cfg.onnxFiles; + fDQMlResponse.setBinsCent(cfg.binsCent); + fDQMlResponse.setBinsPt(cfg.binsPt); + fDQMlResponse.setCentType(cfg.centType); + LOG(info) << "Using BDT cuts for binary classification"; + } else { + auto& cfg = std::get(config); + binsMl = cfg.binsMl; + nClassesMl = 3; + cutsMl = cfg.cutsMl; + cutDirMl = cfg.cutDirs; + namesInputFeatures = cfg.inputFeatures; + onnxFileNames = cfg.onnxFiles; + fDQMlResponse.setBinsCent(cfg.binsCent); + fDQMlResponse.setBinsPt(cfg.binsPt); + fDQMlResponse.setCentType(cfg.centType); + LOG(info) << "Using BDT cuts for multiclass classification"; + } + + fDQMlResponse.configure(binsMl, cutsMl, cutDirMl, nClassesMl); + if (loadModelsFromCCDB) { + ccdbApi.init(ccdburl); + fDQMlResponse.setModelPathsCCDB(onnxFileNames, ccdbApi, modelPathsCCDB, timestampCCDB); + } else { + fDQMlResponse.setModelPathsLocal(onnxFileNames); + } + fDQMlResponse.cacheInputFeaturesIndices(namesInputFeatures); + fDQMlResponse.init(); + } + if (context.mOptions.get("processDecayToEESkimmed") || context.mOptions.get("processDecayToEESkimmedNoTwoProngFitter") || context.mOptions.get("processDecayToEESkimmedWithCov") || context.mOptions.get("processDecayToEESkimmedWithCovNoTwoProngFitter") || context.mOptions.get("processDecayToEEVertexingSkimmed") || context.mOptions.get("processVnDecayToEESkimmed") || context.mOptions.get("processDecayToEEPrefilterSkimmed") || context.mOptions.get("processDecayToEEPrefilterSkimmedNoTwoProngFitter") || context.mOptions.get("processDecayToEESkimmedWithColl") || context.mOptions.get("processDecayToEESkimmedWithCollNoTwoProngFitter") || context.mOptions.get("processDecayToPiPiSkimmed") || context.mOptions.get("processAllSkimmed")) { TString cutNames = fConfigTrackCuts.value; if (!cutNames.IsNull()) { // if track cuts @@ -1317,6 +1376,8 @@ struct AnalysisSameEventPairing { dileptonMiniTree.reserve(1); } + bool isSelectedBDT = false; + if (fConfigMultDimuons.value) { uint32_t mult_dimuons = 0; @@ -1395,6 +1456,43 @@ struct AnalysisSameEventPairing { } } if constexpr ((TPairType == pairTypeEE) && (TTrackFillMap & VarManager::ObjTypes::ReducedTrackBarrelPID) > 0) { + if (applyBDT) { + std::vector dqInputFeatures = fDQMlResponse.getInputFeatures(t1, t2, VarManager::fgValues); + + if (dqInputFeatures.empty()) { + LOG(fatal) << "Input features for ML selection are empty! Please check your configuration."; + return; + } + + int modelIndex = -1; + const auto& binsCent = fDQMlResponse.getBinsCent(); + const auto& binsPt = fDQMlResponse.getBinsPt(); + const std::string& centType = fDQMlResponse.getCentType(); + + if ("kCentFT0C" == centType) { + modelIndex = o2::aod::dqmlcuts::getMlBinIndex(VarManager::fgValues[VarManager::kCentFT0C], VarManager::fgValues[VarManager::kPt], binsCent, binsPt); + } else if ("kCentFT0A" == centType) { + modelIndex = o2::aod::dqmlcuts::getMlBinIndex(VarManager::fgValues[VarManager::kCentFT0A], VarManager::fgValues[VarManager::kPt], binsCent, binsPt); + } else if ("kCentFT0M" == centType) { + modelIndex = o2::aod::dqmlcuts::getMlBinIndex(VarManager::fgValues[VarManager::kCentFT0M], VarManager::fgValues[VarManager::kPt], binsCent, binsPt); + } else { + LOG(fatal) << "Unknown centrality estimation type: " << centType; + return; + } + + if (modelIndex < 0) { + LOG(debug) << "Ml index is negative! This means that the centrality/pt is not in the range of the model bins."; + continue; + } + + LOG(debug) << "Model index: " << modelIndex << ", pT: " << VarManager::fgValues[VarManager::kPt] << ", centrality (kCentFT0C): " << VarManager::fgValues[VarManager::kCentFT0C]; + isSelectedBDT = fDQMlResponse.isSelectedMl(dqInputFeatures, modelIndex, fOutputMlPsi2ee); + VarManager::FillBdtScore(fOutputMlPsi2ee); // TODO: check if this is needed or not + } + + if (applyBDT && !isSelectedBDT) + continue; + if (fConfigFlatTables.value) { dielectronAllList(VarManager::fgValues[VarManager::kMass], VarManager::fgValues[VarManager::kPt], VarManager::fgValues[VarManager::kEta], VarManager::fgValues[VarManager::kPhi], t1.sign() + t2.sign(), dileptonFilterMap, dileptonMcDecision, t1.pt(), t1.eta(), t1.phi(), t1.itsClusterMap(), t1.itsChi2NCl(), t1.tpcNClsCrossedRows(), t1.tpcNClsFound(), t1.tpcChi2NCl(), t1.dcaXY(), t1.dcaZ(), t1.tpcSignal(), t1.tpcNSigmaEl(), t1.tpcNSigmaPi(), t1.tpcNSigmaPr(), t1.beta(), t1.tofNSigmaEl(), t1.tofNSigmaPi(), t1.tofNSigmaPr(), @@ -1468,6 +1566,9 @@ struct AnalysisSameEventPairing { } } + if (applyBDT && !isSelectedBDT) + continue; + int iCut = 0; for (int icut = 0; icut < ncuts; icut++) { if (twoTrackFilter & (static_cast(1) << icut)) { diff --git a/PWGDQ/Tasks/tableReader_withAssoc.cxx b/PWGDQ/Tasks/tableReader_withAssoc.cxx index 139476d3b71..df3e4e2587d 100644 --- a/PWGDQ/Tasks/tableReader_withAssoc.cxx +++ b/PWGDQ/Tasks/tableReader_withAssoc.cxx @@ -15,6 +15,7 @@ #include "PWGDQ/Core/AnalysisCompositeCut.h" #include "PWGDQ/Core/AnalysisCut.h" #include "PWGDQ/Core/CutsLibrary.h" +#include "PWGDQ/Core/DQMlResponse.h" #include "PWGDQ/Core/HistogramManager.h" #include "PWGDQ/Core/HistogramsLibrary.h" #include "PWGDQ/Core/MixingHandler.h" @@ -1246,6 +1247,14 @@ struct AnalysisSameEventPairing { Configurable propTrack{"cfgPropTrack", true, "Propgate tracks to associated collision to recalculate DCA and momentum vector"}; Configurable useRemoteCollisionInfo{"cfgUseRemoteCollisionInfo", false, "Use remote collision information from CCDB"}; } fConfigOptions; + struct : ConfigurableGroup { + Configurable applyBDT{"applyBDT", false, "Flag to apply ML selections"}; + Configurable fConfigBdtCutsJSON{"fConfigBdtCutsJSON", "", "Additional list of BDT cuts in JSON format"}; + + Configurable> modelPathsCCDB{"modelPathsCCDB", std::vector{"Users/j/jseo/ML/PbPbPsi/default/"}, "Paths of models on CCDB"}; + Configurable timestampCCDB{"timestampCCDB", -1, "timestamp of the ONNX file for ML model used to query in CCDB"}; + Configurable loadModelsFromCCDB{"loadModelsFromCCDB", false, "Flag to enable or disable the loading of models from CCDB"}; + } fConfigML; Service fCCDB; o2::ccdb::CcdbApi fCCDBApi; @@ -1254,6 +1263,9 @@ struct AnalysisSameEventPairing { HistogramManager* fHistMan; + o2::analysis::DQMlResponse fDQMlResponse; + std::vector fOutputMlPsi2ee = {}; // TODO: check this is needed or not + // keep histogram class names in maps, so we don't have to buld their names in the pair loops std::map> fTrackHistNames; std::map> fMuonHistNames; @@ -1321,6 +1333,54 @@ struct AnalysisSameEventPairing { objArrayMuonCuts = muonCutsStr.Tokenize(","); } + if (fConfigML.applyBDT) { + // BDT cuts via JSON + std::vector binsMl; + o2::framework::LabeledArray cutsMl; + std::vector cutDirMl; + int nClassesMl = 1; + std::vector namesInputFeatures; + std::vector onnxFileNames; + + auto config = o2::aod::dqmlcuts::GetBdtScoreCutsAndConfigFromJSON(fConfigML.fConfigBdtCutsJSON.value.c_str()); + + if (std::holds_alternative(config)) { + auto& cfg = std::get(config); + binsMl = cfg.binsMl; + nClassesMl = 1; + cutsMl = cfg.cutsMl; + cutDirMl = cfg.cutDirs; + namesInputFeatures = cfg.inputFeatures; + onnxFileNames = cfg.onnxFiles; + fDQMlResponse.setBinsCent(cfg.binsCent); + fDQMlResponse.setBinsPt(cfg.binsPt); + fDQMlResponse.setCentType(cfg.centType); + LOG(info) << "Using BDT cuts for binary classification"; + } else { + auto& cfg = std::get(config); + binsMl = cfg.binsMl; + nClassesMl = 3; + cutsMl = cfg.cutsMl; + cutDirMl = cfg.cutDirs; + namesInputFeatures = cfg.inputFeatures; + onnxFileNames = cfg.onnxFiles; + fDQMlResponse.setBinsCent(cfg.binsCent); + fDQMlResponse.setBinsPt(cfg.binsPt); + fDQMlResponse.setCentType(cfg.centType); + LOG(info) << "Using BDT cuts for multiclass classification"; + } + + fDQMlResponse.configure(binsMl, cutsMl, cutDirMl, nClassesMl); + if (fConfigML.loadModelsFromCCDB) { + fCCDBApi.init(fConfigCCDB.url); + fDQMlResponse.setModelPathsCCDB(onnxFileNames, fCCDBApi, fConfigML.modelPathsCCDB, fConfigML.timestampCCDB); + } else { + fDQMlResponse.setModelPathsLocal(onnxFileNames); + } + fDQMlResponse.cacheInputFeaturesIndices(namesInputFeatures); + fDQMlResponse.init(); + } + // get the barrel track selection cuts string tempCuts; getTaskOptionValue(context, "analysis-track-selection", "cfgTrackCuts", tempCuts, false); @@ -1627,6 +1687,7 @@ struct AnalysisSameEventPairing { constexpr bool eventHasQvector = ((TEventFillMap & VarManager::ObjTypes::ReducedEventQvector) > 0); constexpr bool eventHasQvectorCentr = ((TEventFillMap & VarManager::ObjTypes::CollisionQvect) > 0); constexpr bool trackHasCov = ((TTrackFillMap & VarManager::ObjTypes::TrackCov) > 0 || (TTrackFillMap & VarManager::ObjTypes::ReducedTrackBarrelCov) > 0); + bool isSelectedBDT = false; for (auto& event : events) { if (!event.isEventSelected_bit(0)) { @@ -1707,6 +1768,43 @@ struct AnalysisSameEventPairing { if constexpr (trackHasCov && TTwoProngFitter) { dielectronsExtraList(t1.globalIndex(), t2.globalIndex(), VarManager::fgValues[VarManager::kVertexingTauzProjected], VarManager::fgValues[VarManager::kVertexingLzProjected], VarManager::fgValues[VarManager::kVertexingLxyProjected]); if constexpr ((TTrackFillMap & VarManager::ObjTypes::ReducedTrackBarrelPID) > 0) { + if (fConfigML.applyBDT) { + std::vector dqInputFeatures = fDQMlResponse.getInputFeatures(t1, t2, VarManager::fgValues); + + if (dqInputFeatures.empty()) { + LOG(fatal) << "Input features for ML selection are empty! Please check your configuration."; + return; + } + + int modelIndex = -1; + const auto& binsCent = fDQMlResponse.getBinsCent(); + const auto& binsPt = fDQMlResponse.getBinsPt(); + const std::string& centType = fDQMlResponse.getCentType(); + + if ("kCentFT0C" == centType) { + modelIndex = o2::aod::dqmlcuts::getMlBinIndex(VarManager::fgValues[VarManager::kCentFT0C], VarManager::fgValues[VarManager::kPt], binsCent, binsPt); + } else if ("kCentFT0A" == centType) { + modelIndex = o2::aod::dqmlcuts::getMlBinIndex(VarManager::fgValues[VarManager::kCentFT0A], VarManager::fgValues[VarManager::kPt], binsCent, binsPt); + } else if ("kCentFT0M" == centType) { + modelIndex = o2::aod::dqmlcuts::getMlBinIndex(VarManager::fgValues[VarManager::kCentFT0M], VarManager::fgValues[VarManager::kPt], binsCent, binsPt); + } else { + LOG(fatal) << "Unknown centrality estimation type: " << centType; + return; + } + + if (modelIndex < 0) { + LOG(info) << "Ml index is negative! This means that the centrality/pt is not in the range of the model bins."; + continue; + } + + LOG(debug) << "Model index: " << modelIndex << ", pT: " << VarManager::fgValues[VarManager::kPt] << ", centrality (kCentFT0C): " << VarManager::fgValues[VarManager::kCentFT0C]; + isSelectedBDT = fDQMlResponse.isSelectedMl(dqInputFeatures, modelIndex, fOutputMlPsi2ee); + VarManager::FillBdtScore(fOutputMlPsi2ee); // TODO: check if this is needed or not + } + + if (fConfigML.applyBDT && !isSelectedBDT) + continue; + if (fConfigOptions.flatTables.value) { dielectronAllList(VarManager::fgValues[VarManager::kMass], VarManager::fgValues[VarManager::kPt], VarManager::fgValues[VarManager::kEta], VarManager::fgValues[VarManager::kPhi], t1.sign() + t2.sign(), twoTrackFilter, dileptonMcDecision, t1.pt(), t1.eta(), t1.phi(), t1.itsClusterMap(), t1.itsChi2NCl(), t1.tpcNClsCrossedRows(), t1.tpcNClsFound(), t1.tpcChi2NCl(), t1.dcaXY(), t1.dcaZ(), t1.tpcSignal(), t1.tpcNSigmaEl(), t1.tpcNSigmaPi(), t1.tpcNSigmaPr(), t1.beta(), t1.tofNSigmaEl(), t1.tofNSigmaPi(), t1.tofNSigmaPr(), @@ -1831,6 +1929,10 @@ struct AnalysisSameEventPairing { bool isLeg1Ambi = false; bool isLeg2Ambi = false; bool isAmbiExtra = false; + + if (fConfigML.applyBDT && !isSelectedBDT) + continue; + for (int icut = 0; icut < ncuts; icut++) { if (twoTrackFilter & (static_cast(1) << icut)) { isAmbiInBunch = (twoTrackFilter & (static_cast(1) << 28)) || (twoTrackFilter & (static_cast(1) << 29));