diff --git a/PWGDQ/Core/CMakeLists.txt b/PWGDQ/Core/CMakeLists.txt index d19a66a68e6..41ceb661bf9 100644 --- a/PWGDQ/Core/CMakeLists.txt +++ b/PWGDQ/Core/CMakeLists.txt @@ -21,7 +21,7 @@ o2physics_add_library(PWGDQCore AnalysisCompositeCut.cxx MCProng.cxx MCSignal.cxx - PUBLIC_LINK_LIBRARIES O2::Framework O2::DCAFitter O2::GlobalTracking O2Physics::AnalysisCore KFParticle::KFParticle) + PUBLIC_LINK_LIBRARIES O2::Framework O2::DCAFitter O2::GlobalTracking O2Physics::AnalysisCore KFParticle::KFParticle O2Physics::MLCore) o2physics_target_root_dictionary(PWGDQCore HEADERS AnalysisCut.h diff --git a/PWGDQ/Core/CutsLibrary.cxx b/PWGDQ/Core/CutsLibrary.cxx index 8eacab59397..ce9a3de312f 100644 --- a/PWGDQ/Core/CutsLibrary.cxx +++ b/PWGDQ/Core/CutsLibrary.cxx @@ -12,14 +12,19 @@ // Contact: iarsene@cern.ch, i.c.arsene@fys.uio.no // #include "PWGDQ/Core/CutsLibrary.h" -#include -#include -#include -#include -#include + #include "AnalysisCompositeCut.h" #include "VarManager.h" +#include + +#include + +#include +#include +#include +#include + using std::cout; using std::endl; @@ -7099,3 +7104,154 @@ AnalysisCompositeCut* o2::aod::dqcuts::ParseJSONAnalysisCompositeCut(T cut, cons return retCut; } + +//________________________________________________________________________________________________ +o2::aod::dqmlcuts::BdtScoreConfig o2::aod::dqmlcuts::GetBdtScoreCutsAndConfigFromJSON(const char* json) +{ + LOG(info) << "========================================== interpreting JSON for analysis cuts"; + LOG(info) << "JSON string: " << json; + + rapidjson::Document document; + rapidjson::ParseResult ok = document.Parse(json); + if (!ok) { + LOG(fatal) << "JSON parse error: " << rapidjson::GetParseErrorFunc(ok.Code()) << " (" << ok.Offset() << ")"; + return {}; // empty variant + } + + for (auto it = document.MemberBegin(); it != document.MemberEnd(); ++it) { + const auto& obj = it->value; + + if (!obj.HasMember("type")) { + LOG(fatal) << "Missing type (Binary/MultiClass)"; + return {}; + } + + TString typeStr = obj["type"].GetString(); + // int nClasses = (typeStr == "MultiClass") ? 3 : 1; + + std::vector namesInputFeatures; + if (obj.HasMember("inputFeatures") && obj["inputFeatures"].IsArray()) { + for (auto& feature : obj["inputFeatures"].GetArray()) { + namesInputFeatures.emplace_back(feature.GetString()); + } + } + + std::vector onnxFileNames; + if (obj.HasMember("modelFiles") && obj["modelFiles"].IsArray()) { + for (const auto& model : obj["modelFiles"].GetArray()) { + onnxFileNames.emplace_back(model.GetString()); + } + } + + // Cut storage + std::vector> ptBins; + std::vector> cutsMl; + std::vector cutDirs; + bool cutDirsFilled = false; + + for (auto member = obj.MemberBegin(); member != obj.MemberEnd(); ++member) { + TString key = member->name.GetString(); + if (!key.Contains("AddCut")) + continue; + + const auto& cut = member->value; + + if (!cut.HasMember("pTMin") || !cut.HasMember("pTMax")) { + LOG(fatal) << "Missing pTMin/pTMax in ML cut"; + return {}; + } + + double pTMin = cut["pTMin"].GetDouble(); + double pTMax = cut["pTMax"].GetDouble(); + ptBins.emplace_back(pTMin, pTMax); + + std::vector binCuts; + bool exclude = false; + + for (auto& sub : cut.GetObject()) { + TString subKey = sub.name.GetString(); + if (!subKey.Contains("AddMLCut")) + continue; + + const auto& mlcut = sub.value; + // const char* var = mlcut["var"].GetString(); + double cutVal = mlcut.HasMember("cut") ? mlcut["cut"].GetDouble() : 0.5; + exclude = mlcut.HasMember("exclude") ? mlcut["exclude"].GetBool() : false; + + binCuts.push_back(cutVal); + + if (!cutDirsFilled) { + cutDirs.push_back(exclude ? 1 : 0); + cutDirsFilled = true; + } + } + + cutsMl.push_back(binCuts); + } + + // bin edges + std::vector binsPt; + if (!ptBins.empty()) { + std::set binEdges; + for (auto& b : ptBins) + binEdges.insert(b.first); + binEdges.insert(ptBins.back().second); + binsPt = std::vector(binEdges.begin(), binEdges.end()); + } else { + LOG(fatal) << "No pT bins found in ML cuts"; + return {}; + } + + std::vector labelsPt, labelsClass; + for (size_t i = 0; i < cutsMl.size(); ++i) { + labelsPt.push_back(Form("pT%.1f", binsPt[i])); + } + for (size_t j = 0; j < cutsMl[0].size(); ++j) { + labelsClass.push_back(Form("cls%d", static_cast(j))); + } + + // Binary + if (typeStr == "Binary") { + dqmlcuts::BinaryBdtScoreConfig binaryCfg; + binaryCfg.inputFeatures = namesInputFeatures; + binaryCfg.onnxFiles = onnxFileNames; + binaryCfg.binsPt = binsPt; + binaryCfg.cutDirs = cutDirs; + binaryCfg.cutsMl = makeLabeledCutsMl(cutsMl, labelsPt, labelsClass); + + return binaryCfg; + + // MultiClass + } else if (typeStr == "MultiClass") { + dqmlcuts::MultiClassBdtScoreConfig multiCfg; + multiCfg.inputFeatures = namesInputFeatures; + multiCfg.onnxFiles = onnxFileNames; + multiCfg.binsPt = binsPt; + multiCfg.cutDirs = cutDirs; + multiCfg.cutsMl = makeLabeledCutsMl(cutsMl, labelsPt, labelsClass); + + return multiCfg; + } + + LOG(fatal) << "Unsupported classification type: " << typeStr; + return {}; + } + + return {}; +} + +o2::framework::LabeledArray o2::aod::dqmlcuts::makeLabeledCutsMl(const std::vector>& cuts, + const std::vector& labelsPt, + const std::vector& labelsClass) +{ + const size_t nRows = cuts.size(); + const size_t nCols = cuts.empty() ? 0 : cuts[0].size(); + std::vector flat; + + for (const auto& row : cuts) { + flat.insert(flat.end(), row.begin(), row.end()); + } + + o2::framework::Array2D arr(flat.data(), nRows, nCols); + return o2::framework::LabeledArray(arr, labelsPt, labelsClass); +} diff --git a/PWGDQ/Core/CutsLibrary.h b/PWGDQ/Core/CutsLibrary.h index c6ad4caded2..9418197b75e 100644 --- a/PWGDQ/Core/CutsLibrary.h +++ b/PWGDQ/Core/CutsLibrary.h @@ -119,6 +119,32 @@ bool ValidateJSONAnalysisCompositeCut(T cut); template AnalysisCompositeCut* ParseJSONAnalysisCompositeCut(T key, const char* cutName); } // namespace dqcuts +namespace dqmlcuts +{ +struct BinaryBdtScoreConfig { + std::vector inputFeatures; + std::vector onnxFiles; + std::vector binsPt; + o2::framework::LabeledArray cutsMl; + std::vector cutDirs; +}; + +struct MultiClassBdtScoreConfig { + std::vector inputFeatures; + std::vector onnxFiles; + std::vector binsPt; + o2::framework::LabeledArray cutsMl; + std::vector cutDirs; +}; + +using BdtScoreConfig = std::variant; + +BdtScoreConfig GetBdtScoreCutsAndConfigFromJSON(const char* json); + +o2::framework::LabeledArray makeLabeledCutsMl(const std::vector>& cuts, + const std::vector& labelsPt, + const std::vector& labelsClass); +} // namespace dqmlcuts } // namespace o2::aod AnalysisCompositeCut* o2::aod::dqcuts::GetCompositeCut(const char* cutName); diff --git a/PWGDQ/Core/DQMlResponse.h b/PWGDQ/Core/DQMlResponse.h new file mode 100644 index 00000000000..ab150800cf7 --- /dev/null +++ b/PWGDQ/Core/DQMlResponse.h @@ -0,0 +1,213 @@ +// Copyright 2019-2020 CERN and copyright holders of ALICE O2. +// See https://alice-o2.web.cern.ch/copyright for details of the copyright holders. +// All rights not expressly granted are reserved. +// +// This software is distributed under the terms of the GNU General Public +// License v3 (GPL Version 3), copied verbatim in the file "COPYING". +// +// In applying this license CERN does not waive the privileges and immunities +// granted to it by virtue of its status as an Intergovernmental Organization +// or submit itself to any jurisdiction. +// +// Contact: jseo@cern.ch +// +// Class to compute the ML response for DQ-analysis selections +// + +#ifndef PWGDQ_CORE_DQMLRESPONSE_H_ +#define PWGDQ_CORE_DQMLRESPONSE_H_ + +#include "Tools/ML/MlResponse.h" + +#include +#include +#include +#include + +// Fill the map of available input features +// the key is the feature's name (std::string) +// the value is the corresponding value in EnumInputFeatures +#define FILL_MAP(FEATURE) \ + { \ + #FEATURE, static_cast(InputFeatures::FEATURE) \ + } + +namespace o2::analysis +{ + +enum class InputFeatures : uint8_t { // refer to DielectronsAll + fMass = 0, + fPt, + fEta, + fPhi, + fPt1, + fITSChi2NCl1, + fTPCNClsCR1, + fTPCNClsFound1, + fTPCChi2NCl1, + fDcaXY1, + fDcaZ1, + fTPCNSigmaEl1, + fTPCNSigmaPi1, + fTPCNSigmaPr1, + fTOFNSigmaEl1, + fTOFNSigmaPi1, + fTOFNSigmaPr1, + fPt2, + fITSChi2NCl2, + fTPCNClsCR2, + fTPCNClsFound2, + fTPCChi2NCl2, + fDcaXY2, + fDcaZ2, + fTPCNSigmaEl2, + fTPCNSigmaPi2, + fTPCNSigmaPr2, + fTOFNSigmaEl2, + fTOFNSigmaPi2, + fTOFNSigmaPr2, +}; + +static const std::map gFeatureNameMap = { + {InputFeatures::fMass, "fMass"}, + {InputFeatures::fPt, "fPt"}, + {InputFeatures::fEta, "fEta"}, + {InputFeatures::fPhi, "fPhi"}, + {InputFeatures::fPt1, "fPt1"}, + {InputFeatures::fITSChi2NCl1, "fITSChi2NCl1"}, + {InputFeatures::fTPCNClsCR1, "fTPCNClsCR1"}, + {InputFeatures::fTPCNClsFound1, "fTPCNClsFound1"}, + {InputFeatures::fTPCChi2NCl1, "fTPCChi2NCl1"}, + {InputFeatures::fDcaXY1, "fDcaXY1"}, + {InputFeatures::fDcaZ1, "fDcaZ1"}, + {InputFeatures::fTPCNSigmaEl1, "fTPCNSigmaEl1"}, + {InputFeatures::fTPCNSigmaPi1, "fTPCNSigmaPi1"}, + {InputFeatures::fTPCNSigmaPr1, "fTPCNSigmaPr1"}, + {InputFeatures::fTOFNSigmaEl1, "fTOFNSigmaEl1"}, + {InputFeatures::fTOFNSigmaPi1, "fTOFNSigmaPi1"}, + {InputFeatures::fTOFNSigmaPr1, "fTOFNSigmaPr1"}, + {InputFeatures::fPt2, "fPt2"}, + {InputFeatures::fITSChi2NCl2, "fITSChi2NCl2"}, + {InputFeatures::fTPCNClsCR2, "fTPCNClsCR2"}, + {InputFeatures::fTPCNClsFound2, "fTPCNClsFound2"}, + {InputFeatures::fTPCChi2NCl2, "fTPCChi2NCl2"}, + {InputFeatures::fDcaXY2, "fDcaXY2"}, + {InputFeatures::fDcaZ2, "fDcaZ2"}, + {InputFeatures::fTPCNSigmaEl2, "fTPCNSigmaEl2"}, + {InputFeatures::fTPCNSigmaPi2, "fTPCNSigmaPi2"}, + {InputFeatures::fTPCNSigmaPr2, "fTPCNSigmaPr2"}, + {InputFeatures::fTOFNSigmaEl2, "fTOFNSigmaEl2"}, + {InputFeatures::fTOFNSigmaPi2, "fTOFNSigmaPi2"}, + {InputFeatures::fTOFNSigmaPr2, "fTOFNSigmaPr2"}}; + +template +class DQMlResponse : public MlResponse +{ + public: + /// Default constructor + DQMlResponse() = default; + /// Default destructor + virtual ~DQMlResponse() = default; + + /// Method to get the input features vector needed for ML inference + /// \return inputFeatures vector + template + std::vector getInputFeatures(const T1& t1, + const T2& t2, + const TValues& fg) const + { + using Accessor = std::function; + static const std::unordered_map featureMap{ + {"fMass", [](auto const&, auto const&, auto const& v) { return v[VarManager::kMass]; }}, + {"fPt", [](auto const&, auto const&, auto const& v) { return v[VarManager::kPt]; }}, + {"fEta", [](auto const&, auto const&, auto const& v) { return v[VarManager::kEta]; }}, + {"fPhi", [](auto const&, auto const&, auto const& v) { return v[VarManager::kPhi]; }}, + + {"fPt1", [](auto const& t1, auto const&, auto const&) { return t1.pt(); }}, + {"fITSChi2NCl1", [](auto const& t1, auto const&, auto const&) { return t1.itsChi2NCl(); }}, + {"fTPCNClsCR1", [](auto const& t1, auto const&, auto const&) { return t1.tpcNClsCrossedRows(); }}, + {"fTPCNClsFound1", [](auto const& t1, auto const&, auto const&) { return t1.tpcNClsFound(); }}, + {"fTPCChi2NCl1", [](auto const& t1, auto const&, auto const&) { return t1.tpcChi2NCl(); }}, + {"fDcaXY1", [](auto const& t1, auto const&, auto const&) { return t1.dcaXY(); }}, + {"fDcaZ1", [](auto const& t1, auto const&, auto const&) { return t1.dcaZ(); }}, + {"fTPCNSigmaEl1", [](auto const& t1, auto const&, auto const&) { return t1.tpcNSigmaEl(); }}, + {"fTPCNSigmaPi1", [](auto const& t1, auto const&, auto const&) { return t1.tpcNSigmaPi(); }}, + {"fTPCNSigmaPr1", [](auto const& t1, auto const&, auto const&) { return t1.tpcNSigmaPr(); }}, + {"fTOFNSigmaEl1", [](auto const& t1, auto const&, auto const&) { return t1.tofNSigmaEl(); }}, + {"fTOFNSigmaPi1", [](auto const& t1, auto const&, auto const&) { return t1.tofNSigmaPi(); }}, + {"fTOFNSigmaPr1", [](auto const& t1, auto const&, auto const&) { return t1.tofNSigmaPr(); }}, + + {"fPt2", [](auto const&, auto const& t2, auto const&) { return t2.pt(); }}, + {"fITSChi2NCl2", [](auto const&, auto const& t2, auto const&) { return t2.itsChi2NCl(); }}, + {"fTPCNClsCR2", [](auto const&, auto const& t2, auto const&) { return t2.tpcNClsCrossedRows(); }}, + {"fTPCNClsFound2", [](auto const&, auto const& t2, auto const&) { return t2.tpcNClsFound(); }}, + {"fTPCChi2NCl2", [](auto const&, auto const& t2, auto const&) { return t2.tpcChi2NCl(); }}, + {"fDcaXY2", [](auto const&, auto const& t2, auto const&) { return t2.dcaXY(); }}, + {"fDcaZ2", [](auto const&, auto const& t2, auto const&) { return t2.dcaZ(); }}, + {"fTPCNSigmaEl2", [](auto const&, auto const& t2, auto const&) { return t2.tpcNSigmaEl(); }}, + {"fTPCNSigmaPi2", [](auto const&, auto const& t2, auto const&) { return t2.tpcNSigmaPi(); }}, + {"fTPCNSigmaPr2", [](auto const&, auto const& t2, auto const&) { return t2.tpcNSigmaPr(); }}, + {"fTOFNSigmaEl2", [](auto const&, auto const& t2, auto const&) { return t2.tofNSigmaEl(); }}, + {"fTOFNSigmaPi2", [](auto const&, auto const& t2, auto const&) { return t2.tofNSigmaPi(); }}, + {"fTOFNSigmaPr2", [](auto const&, auto const& t2, auto const&) { return t2.tofNSigmaPr(); }}}; + + std::vector dqInputFeatures; + dqInputFeatures.reserve(MlResponse::mCachedIndices.size()); + + for (auto idx : MlResponse::mCachedIndices) { + auto enumIdx = static_cast(idx); + const auto& name = gFeatureNameMap.at(enumIdx); + + auto acc = featureMap.find(name); + if (acc == featureMap.end()) { + LOG(error) << "Missing accessor for " << name; + continue; + } else { + dqInputFeatures.push_back(acc->second(t1, t2, fg)); + } + } + return dqInputFeatures; + } + + protected: + void setAvailableInputFeatures() + { + MlResponse::mAvailableInputFeatures = { + FILL_MAP(fMass), + FILL_MAP(fPt), + FILL_MAP(fEta), + FILL_MAP(fPhi), + FILL_MAP(fPt1), + FILL_MAP(fITSChi2NCl1), + FILL_MAP(fTPCNClsCR1), + FILL_MAP(fTPCNClsFound1), + FILL_MAP(fTPCChi2NCl1), + FILL_MAP(fDcaXY1), + FILL_MAP(fDcaZ1), + FILL_MAP(fTPCNSigmaEl1), + FILL_MAP(fTPCNSigmaPi1), + FILL_MAP(fTPCNSigmaPr1), + FILL_MAP(fTOFNSigmaEl1), + FILL_MAP(fTOFNSigmaPi1), + FILL_MAP(fTOFNSigmaPr1), + FILL_MAP(fPt2), + FILL_MAP(fITSChi2NCl2), + FILL_MAP(fTPCNClsCR2), + FILL_MAP(fTPCNClsFound2), + FILL_MAP(fTPCChi2NCl2), + FILL_MAP(fDcaXY2), + FILL_MAP(fDcaZ2), + FILL_MAP(fTPCNSigmaEl2), + FILL_MAP(fTPCNSigmaPi2), + FILL_MAP(fTPCNSigmaPr2), + FILL_MAP(fTOFNSigmaEl2), + FILL_MAP(fTOFNSigmaPi2), + FILL_MAP(fTOFNSigmaPr2)}; + } +}; + +} // namespace o2::analysis + +#undef FILL_MAP + +#endif // PWGDQ_CORE_DQMLRESPONSE_H_ diff --git a/PWGDQ/Core/VarManager.cxx b/PWGDQ/Core/VarManager.cxx index 912dc06740a..4bb8c419ad7 100644 --- a/PWGDQ/Core/VarManager.cxx +++ b/PWGDQ/Core/VarManager.cxx @@ -1131,6 +1131,12 @@ void VarManager::SetDefaultVarNames() fgVariableUnits[kS13] = "GeV^{2}/c^{4}"; fgVariableNames[kS23] = "m_{23}^{2}"; fgVariableUnits[kS23] = "GeV^{2}/c^{4}"; + fgVariableNames[kBdtBackground] = "kBdtBackground"; + fgVariableUnits[kBdtBackground] = " "; + fgVariableNames[kBdtPrompt] = "kBdtPrompt"; + fgVariableUnits[kBdtPrompt] = " "; + fgVariableNames[kBdtNonprompt] = "kBdtNonprompt"; + fgVariableUnits[kBdtNonprompt] = " "; // Set the variables short names map. This is needed for dynamic configuration via JSON files fgVarNamesMap["kNothing"] = kNothing; @@ -1770,4 +1776,7 @@ void VarManager::SetDefaultVarNames() fgVarNamesMap["kV24ME"] = kV24ME; fgVarNamesMap["kWV22ME"] = kWV22ME; fgVarNamesMap["kWV24ME"] = kWV24ME; + fgVarNamesMap["kBdtBackground"] = kBdtBackground; + fgVarNamesMap["kBdtPrompt"] = kBdtPrompt; + fgVarNamesMap["kBdtNonprompt"] = kBdtNonprompt; } diff --git a/PWGDQ/Core/VarManager.h b/PWGDQ/Core/VarManager.h index dfcba891bfc..f748f16ce35 100644 --- a/PWGDQ/Core/VarManager.h +++ b/PWGDQ/Core/VarManager.h @@ -855,6 +855,11 @@ class VarManager : public TObject // deltaMass_jpsi = kPairMass - kPairMassDau +3.096900 kDeltaMass_jpsi, + // BDT score + kBdtBackground, + kBdtPrompt, + kBdtNonprompt, + kNVars }; // end of Variables enumeration @@ -1127,6 +1132,8 @@ class VarManager : public TObject static void FillDileptonTrackTrackVertexing(C const& collision, T1 const& lepton1, T1 const& lepton2, T1 const& track1, T1 const& track2, float* values); template static void FillZDC(const T& zdc, float* values = nullptr); + template + static void FillBdtScore(const T& bdtScore, float* values = nullptr); static void SetCalibrationObject(CalibObjects calib, TObject* obj) { @@ -5533,4 +5540,22 @@ float VarManager::calculatePhiV(T1 const& t1, T2 const& t2) return pairPhiV; } +template +void VarManager::FillBdtScore(T1 const& bdtScore, float* values) +{ + if (!values) { + values = fgValues; + } + + if (bdtScore.size() == 1) { + values[kBdtBackground] = bdtScore[0]; + } else if (bdtScore.size() == 3) { + values[kBdtBackground] = bdtScore[0]; + values[kBdtPrompt] = bdtScore[1]; + values[kBdtNonprompt] = bdtScore[2]; + } else { + LOG(warning) << "Unexpected number of BDT outputs: " << bdtScore.size(); + } +} + #endif // PWGDQ_CORE_VARMANAGER_H_ diff --git a/PWGDQ/Macros/bdtCut.json b/PWGDQ/Macros/bdtCut.json new file mode 100644 index 00000000000..5ff0ada64b4 --- /dev/null +++ b/PWGDQ/Macros/bdtCut.json @@ -0,0 +1,56 @@ +{ + "TestCut": { + "type": "Binary", + "title": "MyBDTModel", + "inputFeatures": [ + "fPt1", + "fITSChi2NCl1", + "fTPCNClsCR1", + "fTPCNClsFound1", + "fTPCChi2NCl1", + "fDcaXY1", + "fDcaZ1", + "fTPCNSigmaEl1", + "fTPCNSigmaPi1", + "fTPCNSigmaPr1", + "fTOFNSigmaEl1", + "fTOFNSigmaPi1", + "fTOFNSigmaPr1", + "fPt2", + "fITSChi2NCl2", + "fTPCNClsCR2", + "fTPCNClsFound2", + "fTPCChi2NCl2", + "fDcaXY2", + "fDcaZ2", + "fTPCNSigmaEl2", + "fTPCNSigmaPi2", + "fTPCNSigmaPr2", + "fTOFNSigmaEl2", + "fTOFNSigmaPi2", + "fTOFNSigmaPr2" + ], + "modelFiles": [ + "cent_30_50_pt0_2_onnx.onnx", + "cent_30_50_pt2_20_onnx.onnx" + ], + "AddMLCut-pTBin1": { + "pTMin": 0, + "pTMax": 2, + "AddMLCut-background": { + "var": "kBdtBackground", + "cut": 0.5, + "exclude": false + } + }, + "AddMLCut-pTBin2": { + "pTMin": 2, + "pTMax": 20, + "AddMLCut-background": { + "var": "kBdtBackground", + "cut": 0.5, + "exclude": false + } + } + } +} diff --git a/PWGDQ/Macros/bdtCutMulti.json b/PWGDQ/Macros/bdtCutMulti.json new file mode 100644 index 00000000000..3042e767c11 --- /dev/null +++ b/PWGDQ/Macros/bdtCutMulti.json @@ -0,0 +1,74 @@ +{ + "TestCut": { + "type": "Binary", + "title": "MyBDTModel", + "inputFeatures": [ + "fITSChi2NCl1", + "fTPCNClsCR1", + "fTPCNClsFound1", + "fTPCChi2NCl1", + "fDcaXY1", + "fDcaZ1", + "fTPCNSigmaEl1", + "fTPCNSigmaPi1", + "fTPCNSigmaPr1", + "fTOFNSigmaEl1", + "fTOFNSigmaPi1", + "fTOFNSigmaPr1", + "fITSChi2NCl2", + "fTPCNClsCR2", + "fTPCNClsFound2", + "fTPCChi2NCl2", + "fDcaXY2", + "fDcaZ2", + "fTPCNSigmaEl2", + "fTPCNSigmaPi2", + "fTPCNSigmaPr2", + "fTOFNSigmaEl2", + "fTOFNSigmaPi2", + "fTOFNSigmaPr2" + ], + "modelFiles": [ + "cent_30_50_pt0_2_onnx.onnx", + "cent_30_50_pt2_20_onnx.onnx" + ], + "AddMLCut-pTBin1": { + "pTMin": 0, + "pTMax": 2, + "AddMLCut-background": { + "var": "kBdtBackground", + "cut": 0.5, + "exclude": true + }, + "AddMLCut-prompt": { + "var": "kBdtPrompt", + "cut": 0.5, + "exclude": false + }, + "AddMLCut-nonprompt": { + "var": "kBdtNonprompt", + "cut": 0.5, + "exclude": false + } + }, + "AddMLCut-pTBin2": { + "pTMin": 2, + "pTMax": 20, + "AddMLCut-background": { + "var": "kBdtBackground", + "cut": 0.5, + "exclude": true + }, + "AddMLCut-prompt": { + "var": "kBdtPrompt", + "cut": 0.5, + "exclude": false + }, + "AddMLCut-nonprompt": { + "var": "kBdtNonprompt", + "cut": 0.5, + "exclude": false + } + } + } +} diff --git a/PWGDQ/Tasks/CMakeLists.txt b/PWGDQ/Tasks/CMakeLists.txt index 5095140a2b8..3492d815bb2 100644 --- a/PWGDQ/Tasks/CMakeLists.txt +++ b/PWGDQ/Tasks/CMakeLists.txt @@ -11,7 +11,7 @@ o2physics_add_dpl_workflow(table-reader SOURCES tableReader.cxx - PUBLIC_LINK_LIBRARIES O2::Framework O2Physics::AnalysisCore O2Physics::PWGDQCore + PUBLIC_LINK_LIBRARIES O2::Framework O2Physics::AnalysisCore O2Physics::PWGDQCore O2Physics::MLCore COMPONENT_NAME Analysis) o2physics_add_dpl_workflow(table-reader-with-assoc diff --git a/PWGDQ/Tasks/tableReader.cxx b/PWGDQ/Tasks/tableReader.cxx index 22a8d82e2d8..6625bcc5037 100644 --- a/PWGDQ/Tasks/tableReader.cxx +++ b/PWGDQ/Tasks/tableReader.cxx @@ -11,39 +11,44 @@ // // Contact: iarsene@cern.ch, i.c.arsene@fys.uio.no // -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include "CCDB/BasicCCDBManager.h" -#include "DataFormatsParameters/GRPObject.h" -#include "Framework/runDataProcessing.h" -#include "Framework/AnalysisTask.h" -#include "Framework/AnalysisDataModel.h" -#include "Framework/ASoAHelpers.h" -#include "PWGDQ/DataModel/ReducedInfoTables.h" -#include "PWGDQ/Core/VarManager.h" -#include "PWGDQ/Core/HistogramManager.h" -#include "PWGDQ/Core/MixingHandler.h" -#include "PWGDQ/Core/AnalysisCut.h" #include "PWGDQ/Core/AnalysisCompositeCut.h" -#include "PWGDQ/Core/HistogramsLibrary.h" +#include "PWGDQ/Core/AnalysisCut.h" #include "PWGDQ/Core/CutsLibrary.h" +#include "PWGDQ/Core/DQMlResponse.h" +#include "PWGDQ/Core/HistogramManager.h" +#include "PWGDQ/Core/HistogramsLibrary.h" +#include "PWGDQ/Core/MixingHandler.h" #include "PWGDQ/Core/MixingLibrary.h" +#include "PWGDQ/Core/VarManager.h" +#include "PWGDQ/DataModel/ReducedInfoTables.h" + +#include "Common/CCDB/EventSelectionParams.h" + +#include "CCDB/BasicCCDBManager.h" #include "DataFormatsParameters/GRPMagField.h" -#include "Field/MagneticField.h" -#include "TGeoGlobalMagField.h" -#include "DetectorsBase/Propagator.h" +#include "DataFormatsParameters/GRPObject.h" #include "DetectorsBase/GeometryManager.h" +#include "DetectorsBase/Propagator.h" +#include "Field/MagneticField.h" +#include "Framework/ASoAHelpers.h" +#include "Framework/AnalysisDataModel.h" +#include "Framework/AnalysisTask.h" +#include "Framework/runDataProcessing.h" #include "ITSMFTBase/DPLAlpideParam.h" -#include "Common/CCDB/EventSelectionParams.h" + +#include "TGeoGlobalMagField.h" +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include using std::cout; using std::endl; @@ -1050,6 +1055,12 @@ struct AnalysisSameEventPairing { Configurable fCenterMassEnergy{"energy", 13600, "Center of mass energy in GeV"}; Configurable fConfigCumulants{"cfgCumulants", false, "If true, fill Cumulants with Weights different than 0"}; Configurable fConfigAddJSONHistograms{"cfgAddJSONHistograms", "", "Histograms in JSON format"}; + // ML inference + Configurable applyBDT{"applyBDT", false, "Flag to apply ML selections"}; + Configurable fConfigBdtCutsJSON{"fConfigBdtCutsJSON", "", "Additional list of BDT cuts in JSON format"}; + Configurable> modelPathsCCDB{"modelPathsCCDB", std::vector{"Users/j/jseo/ML/PbPbPsi/default/"}, "Paths of models on CCDB"}; + Configurable timestampCCDB{"timestampCCDB", -1, "timestamp of the ONNX file for ML model used to query in CCDB"}; + Configurable loadModelsFromCCDB{"loadModelsFromCCDB", false, "Flag to enable or disable the loading of models from CCDB"}; // Configurables to create output tree (flat tables or minitree) struct : ConfigurableGroup { @@ -1067,6 +1078,10 @@ struct AnalysisSameEventPairing { HistogramManager* fHistMan; + o2::analysis::DQMlResponse dqMlResponse; + std::vector outputMlPsi2ee = {}; // TODO: check this is needed or not + o2::ccdb::CcdbApi ccdbApi; + // NOTE: The track filter produced by the barrel track selection contain a number of electron cut decisions and one last cut for hadrons used in the // dilepton - hadron task downstream. So the bit mask is required to select pairs just based on the electron cuts // TODO: provide as Configurable the list and names of the cuts which should be used in pairing @@ -1117,7 +1132,47 @@ struct AnalysisSameEventPairing { } } - if (context.mOptions.get("processDecayToEESkimmed") || context.mOptions.get("processDecayToEESkimmedNoTwoProngFitter") || context.mOptions.get("processDecayToEESkimmedWithCov") || context.mOptions.get("processDecayToEESkimmedWithCovNoTwoProngFitter") || context.mOptions.get("processDecayToEEVertexingSkimmed") || context.mOptions.get("processVnDecayToEESkimmed") || context.mOptions.get("processDecayToEEPrefilterSkimmed") || context.mOptions.get("processDecayToEEPrefilterSkimmedNoTwoProngFitter") || context.mOptions.get("processDecayToEESkimmedWithColl") || context.mOptions.get("processDecayToEESkimmedWithCollNoTwoProngFitter") || context.mOptions.get("processDecayToPiPiSkimmed") || context.mOptions.get("processAllSkimmed")) { + if (applyBDT) { + // BDT cuts via JSON + std::vector binsPtMl; + o2::framework::LabeledArray cutsMl; + std::vector cutDirMl; + int nClassesMl = 1; // 1 for binary BDT, 3 for multiclass BDT + std::vector namesInputFeatures; + std::vector onnxFileNames; + + auto config = o2::aod::dqmlcuts::GetBdtScoreCutsAndConfigFromJSON(fConfigBdtCutsJSON.value.c_str()); + + if (std::holds_alternative(config)) { + auto& cfg = std::get(config); + binsPtMl = cfg.binsPt; + nClassesMl = 1; + cutsMl = cfg.cutsMl; + cutDirMl = cfg.cutDirs; + namesInputFeatures = cfg.inputFeatures; + onnxFileNames = cfg.onnxFiles; + } else { + auto& cfg = std::get(config); + binsPtMl = cfg.binsPt; + nClassesMl = 3; + cutsMl = cfg.cutsMl; + cutDirMl = cfg.cutDirs; + namesInputFeatures = cfg.inputFeatures; + onnxFileNames = cfg.onnxFiles; + } + + dqMlResponse.configure(binsPtMl, cutsMl, cutDirMl, nClassesMl); + if (loadModelsFromCCDB) { + ccdbApi.init(ccdburl); + dqMlResponse.setModelPathsCCDB(onnxFileNames, ccdbApi, modelPathsCCDB, timestampCCDB); + } else { + dqMlResponse.setModelPathsLocal(onnxFileNames); + } + dqMlResponse.cacheInputFeaturesIndices(namesInputFeatures); + dqMlResponse.init(); + } + + if (context.mOptions.get("processDecayToEESkimmed") || context.mOptions.get("processDecayToEESkimmedNoTwoProngFitter") || context.mOptions.get("processDecayToEESkimmedWithCov") || context.mOptions.get("processDecayToEESkimmedWithCovNoTwoProngFitter") || context.mOptions.get("processDecayToEEVertexingSkimmed") || context.mOptions.get("processVnDecayToEESkimmed") || context.mOptions.get("processDecayToEEPrefilterSkimmed") || context.mOptions.get("processDecayToEEPrefilterSkimmedNoTwoProngFitter") || context.mOptions.get("processDecayToEESkimmedWithColl") || context.mOptions.get("processDecayToEESkimmedWithCollNoTwoProngFitter") || context.mOptions.get("processDecayToPiPiSkimmed") || context.mOptions.get("processAllSkimmed") || context.mOptions.get("processDecayToEESkimmedBDT")) { TString cutNames = fConfigTrackCuts.value; if (!cutNames.IsNull()) { // if track cuts std::unique_ptr objArray(cutNames.Tokenize(",")); @@ -1313,6 +1368,8 @@ struct AnalysisSameEventPairing { dileptonMiniTree.reserve(1); } + bool isSelectedBDT = false; + if (fConfigMultDimuons.value) { uint32_t mult_dimuons = 0; @@ -1391,6 +1448,22 @@ struct AnalysisSameEventPairing { } } if constexpr ((TPairType == pairTypeEE) && (TTrackFillMap & VarManager::ObjTypes::ReducedTrackBarrelPID) > 0) { + if (applyBDT) { + std::vector dqInputFeatures = dqMlResponse.getInputFeatures(t1, t2, VarManager::fgValues); + + if (dqInputFeatures.empty()) { + LOG(fatal) << "Input features for ML selection are empty! Please check your configuration."; + return; + } + + // isSelectedBDT = dqMlResponse.isSelectedMl(dqInputFeatures, VarManager::fgValues[VarManager::kPt]); + isSelectedBDT = dqMlResponse.isSelectedMl(dqInputFeatures, VarManager::fgValues[VarManager::kPt], outputMlPsi2ee); + VarManager::FillBdtScore(outputMlPsi2ee); // TODO: check if this is needed or not + } + + if (applyBDT && !isSelectedBDT) + continue; + if (fConfigFlatTables.value) { dielectronAllList(VarManager::fgValues[VarManager::kMass], VarManager::fgValues[VarManager::kPt], VarManager::fgValues[VarManager::kEta], VarManager::fgValues[VarManager::kPhi], t1.sign() + t2.sign(), dileptonFilterMap, dileptonMcDecision, t1.pt(), t1.eta(), t1.phi(), t1.itsClusterMap(), t1.itsChi2NCl(), t1.tpcNClsCrossedRows(), t1.tpcNClsFound(), t1.tpcChi2NCl(), t1.dcaXY(), t1.dcaZ(), t1.tpcSignal(), t1.tpcNSigmaEl(), t1.tpcNSigmaPi(), t1.tpcNSigmaPr(), t1.beta(), t1.tofNSigmaEl(), t1.tofNSigmaPi(), t1.tofNSigmaPr(), @@ -1464,6 +1537,9 @@ struct AnalysisSameEventPairing { } } + if (applyBDT && !isSelectedBDT) + continue; + int iCut = 0; for (int icut = 0; icut < ncuts; icut++) { if (twoTrackFilter & (static_cast(1) << icut)) { @@ -1603,6 +1679,13 @@ struct AnalysisSameEventPairing { VarManager::FillEvent(event, VarManager::fgValues); runSameEventPairing(event, tracks, tracks); } + void processDecayToEESkimmedBDT(soa::Filtered::iterator const& event, soa::Filtered const& tracks) + { + // Reset the fValues array + VarManager::ResetValues(0, VarManager::kNVars); + VarManager::FillEvent(event, VarManager::fgValues); + runSameEventPairing(event, tracks, tracks); + } void processDecayToMuMuSkimmed(soa::Filtered::iterator const& event, soa::Filtered const& muons) { // Reset the fValues array @@ -1707,6 +1790,7 @@ struct AnalysisSameEventPairing { PROCESS_SWITCH(AnalysisSameEventPairing, processDecayToEEPrefilterSkimmedNoTwoProngFitter, "Run electron-electron pairing, with skimmed tracks and prefilter from AnalysisPrefilterSelection but no two prong fitter", false); PROCESS_SWITCH(AnalysisSameEventPairing, processDecayToEESkimmedWithColl, "Run electron-electron pairing, with skimmed tracks and with collision information", false); PROCESS_SWITCH(AnalysisSameEventPairing, processDecayToEESkimmedWithCollNoTwoProngFitter, "Run electron-electron pairing, with skimmed tracks and with collision information but no two prong fitter", false); + PROCESS_SWITCH(AnalysisSameEventPairing, processDecayToEESkimmedBDT, "Run electron-electron pairing, with skimmed tracks and BDT selection", false); PROCESS_SWITCH(AnalysisSameEventPairing, processDecayToMuMuSkimmed, "Run muon-muon pairing, with skimmed muons", false); PROCESS_SWITCH(AnalysisSameEventPairing, processDecayToMuMuSkimmedWithMult, "Run muon-muon pairing, with skimmed muons and multiplicity", false); PROCESS_SWITCH(AnalysisSameEventPairing, processDecayToMuMuVertexingSkimmed, "Run muon-muon pairing and vertexing, with skimmed muons", false); diff --git a/PWGDQ/Tasks/tableReader_withAssoc.cxx b/PWGDQ/Tasks/tableReader_withAssoc.cxx index d6b29b461d2..0aa383c46e3 100644 --- a/PWGDQ/Tasks/tableReader_withAssoc.cxx +++ b/PWGDQ/Tasks/tableReader_withAssoc.cxx @@ -12,51 +12,55 @@ // Contact: iarsene@cern.ch, i.c.arsene@fys.uio.no // Configurable workflow for running several DQ or other PWG analyses +#include "PWGDQ/Core/AnalysisCompositeCut.h" +#include "PWGDQ/Core/AnalysisCut.h" +#include "PWGDQ/Core/CutsLibrary.h" +#include "PWGDQ/Core/DQMlResponse.h" +#include "PWGDQ/Core/HistogramManager.h" +#include "PWGDQ/Core/HistogramsLibrary.h" +#include "PWGDQ/Core/MixingHandler.h" +#include "PWGDQ/Core/MixingLibrary.h" +#include "PWGDQ/Core/VarManager.h" +#include "PWGDQ/DataModel/ReducedInfoTables.h" + +#include "Common/CCDB/EventSelectionParams.h" +#include "Common/Core/TableHelper.h" + +#include "CCDB/BasicCCDBManager.h" +#include "DataFormatsParameters/GRPMagField.h" +#include "DataFormatsParameters/GRPObject.h" +#include "DetectorsBase/GeometryManager.h" +#include "DetectorsBase/Propagator.h" +#include "Field/MagneticField.h" +#include "Framework/ASoAHelpers.h" +#include "Framework/AnalysisDataModel.h" +#include "Framework/AnalysisHelpers.h" +#include "Framework/AnalysisTask.h" +#include "Framework/Configurable.h" +#include "Framework/OutputObjHeader.h" +#include "Framework/runDataProcessing.h" +#include "ITSMFTBase/DPLAlpideParam.h" + +#include "TGeoGlobalMagField.h" +#include +#include +#include +#include +#include +#include + +#include #include #include #include #include +#include +#include #include -#include -#include -#include #include -#include #include -#include #include -#include -#include -#include -#include -#include -#include -#include "CCDB/BasicCCDBManager.h" -#include "DataFormatsParameters/GRPObject.h" -#include "Framework/AnalysisHelpers.h" -#include "Framework/Configurable.h" -#include "Framework/OutputObjHeader.h" -#include "Framework/runDataProcessing.h" -#include "Framework/AnalysisTask.h" -#include "Framework/AnalysisDataModel.h" -#include "Framework/ASoAHelpers.h" -#include "PWGDQ/DataModel/ReducedInfoTables.h" -#include "PWGDQ/Core/VarManager.h" -#include "PWGDQ/Core/HistogramManager.h" -#include "PWGDQ/Core/MixingHandler.h" -#include "PWGDQ/Core/AnalysisCut.h" -#include "PWGDQ/Core/AnalysisCompositeCut.h" -#include "PWGDQ/Core/HistogramsLibrary.h" -#include "PWGDQ/Core/CutsLibrary.h" -#include "PWGDQ/Core/MixingLibrary.h" -#include "DataFormatsParameters/GRPMagField.h" -#include "Field/MagneticField.h" -#include "TGeoGlobalMagField.h" -#include "DetectorsBase/Propagator.h" -#include "DetectorsBase/GeometryManager.h" -#include "Common/Core/TableHelper.h" -#include "ITSMFTBase/DPLAlpideParam.h" -#include "Common/CCDB/EventSelectionParams.h" +#include using std::cout; using std::endl; @@ -1236,6 +1240,14 @@ struct AnalysisSameEventPairing { Configurable propTrack{"cfgPropTrack", true, "Propgate tracks to associated collision to recalculate DCA and momentum vector"}; Configurable useRemoteCollisionInfo{"cfgUseRemoteCollisionInfo", false, "Use remote collision information from CCDB"}; } fConfigOptions; + struct : ConfigurableGroup { + Configurable applyBDT{"applyBDT", false, "Flag to apply ML selections"}; + Configurable fConfigBdtCutsJSON{"fConfigBdtCutsJSON", "", "Additional list of BDT cuts in JSON format"}; + + Configurable> modelPathsCCDB{"modelPathsCCDB", std::vector{"Users/j/jseo/ML/PbPbPsi/default/"}, "Paths of models on CCDB"}; + Configurable timestampCCDB{"timestampCCDB", -1, "timestamp of the ONNX file for ML model used to query in CCDB"}; + Configurable loadModelsFromCCDB{"loadModelsFromCCDB", false, "Flag to enable or disable the loading of models from CCDB"}; + } fConfigML; Service fCCDB; o2::ccdb::CcdbApi fCCDBApi; @@ -1244,6 +1256,9 @@ struct AnalysisSameEventPairing { HistogramManager* fHistMan; + o2::analysis::DQMlResponse dqMlResponse; + std::vector outputMlPsi2ee = {}; // TODO: check this is needed or not + // keep histogram class names in maps, so we don't have to buld their names in the pair loops std::map> fTrackHistNames; std::map> fMuonHistNames; @@ -1271,7 +1286,7 @@ struct AnalysisSameEventPairing { void init(o2::framework::InitContext& context) { LOG(info) << "Starting initialization of AnalysisSameEventPairing (idstoreh)"; - fEnableBarrelHistos = context.mOptions.get("processAllSkimmed") || context.mOptions.get("processBarrelOnlySkimmed") || context.mOptions.get("processBarrelOnlyWithCollSkimmed") || context.mOptions.get("processBarrelOnlySkimmedNoCov") || context.mOptions.get("processBarrelOnlySkimmedNoCovWithMultExtra") || context.mOptions.get("processBarrelOnlyWithQvectorCentrSkimmedNoCov"); + fEnableBarrelHistos = context.mOptions.get("processAllSkimmed") || context.mOptions.get("processBarrelOnlySkimmed") || context.mOptions.get("processBarrelOnlyWithCollSkimmed") || context.mOptions.get("processBarrelOnlySkimmedNoCov") || context.mOptions.get("processBarrelOnlySkimmedNoCovWithMultExtra") || context.mOptions.get("processBarrelOnlyWithQvectorCentrSkimmedNoCov") || context.mOptions.get("processBarrelOnlySkimmedBDT"); fEnableBarrelMixingHistos = context.mOptions.get("processMixingAllSkimmed") || context.mOptions.get("processMixingBarrelSkimmed"); fEnableMuonHistos = context.mOptions.get("processAllSkimmed") || context.mOptions.get("processMuonOnlySkimmed") || context.mOptions.get("processMuonOnlySkimmedMultExtra") || context.mOptions.get("processMixingMuonSkimmed"); fEnableMuonMixingHistos = context.mOptions.get("processMixingAllSkimmed") || context.mOptions.get("processMixingMuonSkimmed"); @@ -1311,6 +1326,46 @@ struct AnalysisSameEventPairing { objArrayMuonCuts = muonCutsStr.Tokenize(","); } + if (fConfigML.applyBDT) { + // BDT cuts via JSON + std::vector binsPtMl; + o2::framework::LabeledArray cutsMl; + std::vector cutDirMl; + int nClassesMl = 1; // 1 for binary BDT, 3 for multiclass BDT + std::vector namesInputFeatures; + std::vector onnxFileNames; + + auto config = o2::aod::dqmlcuts::GetBdtScoreCutsAndConfigFromJSON(fConfigML.fConfigBdtCutsJSON.value.c_str()); + + if (std::holds_alternative(config)) { + auto& cfg = std::get(config); + binsPtMl = cfg.binsPt; + nClassesMl = 1; + cutsMl = cfg.cutsMl; + cutDirMl = cfg.cutDirs; + namesInputFeatures = cfg.inputFeatures; + onnxFileNames = cfg.onnxFiles; + } else { + auto& cfg = std::get(config); + binsPtMl = cfg.binsPt; + nClassesMl = 3; + cutsMl = cfg.cutsMl; + cutDirMl = cfg.cutDirs; + namesInputFeatures = cfg.inputFeatures; + onnxFileNames = cfg.onnxFiles; + } + + dqMlResponse.configure(binsPtMl, cutsMl, cutDirMl, nClassesMl); + if (fConfigML.loadModelsFromCCDB) { + fCCDBApi.init(fConfigCCDB.url); + dqMlResponse.setModelPathsCCDB(onnxFileNames, fCCDBApi, fConfigML.modelPathsCCDB, fConfigML.timestampCCDB); + } else { + dqMlResponse.setModelPathsLocal(onnxFileNames); + } + dqMlResponse.cacheInputFeaturesIndices(namesInputFeatures); + dqMlResponse.init(); + } + // get the barrel track selection cuts string tempCuts; getTaskOptionValue(context, "analysis-track-selection", "cfgTrackCuts", tempCuts, false); @@ -1617,6 +1672,7 @@ struct AnalysisSameEventPairing { constexpr bool eventHasQvector = ((TEventFillMap & VarManager::ObjTypes::ReducedEventQvector) > 0); constexpr bool eventHasQvectorCentr = ((TEventFillMap & VarManager::ObjTypes::CollisionQvect) > 0); constexpr bool trackHasCov = ((TTrackFillMap & VarManager::ObjTypes::TrackCov) > 0 || (TTrackFillMap & VarManager::ObjTypes::ReducedTrackBarrelCov) > 0); + bool isSelectedBDT = false; for (auto& event : events) { if (!event.isEventSelected_bit(0)) { @@ -1697,6 +1753,22 @@ struct AnalysisSameEventPairing { if constexpr (trackHasCov && TTwoProngFitter) { dielectronsExtraList(t1.globalIndex(), t2.globalIndex(), VarManager::fgValues[VarManager::kVertexingTauzProjected], VarManager::fgValues[VarManager::kVertexingLzProjected], VarManager::fgValues[VarManager::kVertexingLxyProjected]); if constexpr ((TTrackFillMap & VarManager::ObjTypes::ReducedTrackBarrelPID) > 0) { + if (fConfigML.applyBDT) { + std::vector dqInputFeatures = dqMlResponse.getInputFeatures(t1, t2, VarManager::fgValues); + + if (dqInputFeatures.empty()) { + LOG(fatal) << "Input features for ML selection are empty! Please check your configuration."; + return; + } + + // isSelectedBDT = dqMlResponse.isSelectedMl(dqInputFeatures, VarManager::fgValues[VarManager::kPt]); + isSelectedBDT = dqMlResponse.isSelectedMl(dqInputFeatures, VarManager::fgValues[VarManager::kPt], outputMlPsi2ee); + VarManager::FillBdtScore(outputMlPsi2ee); // TODO: check if this is needed or not + } + + if (fConfigML.applyBDT && !isSelectedBDT) + continue; + if (fConfigOptions.flatTables.value) { dielectronAllList(VarManager::fgValues[VarManager::kMass], VarManager::fgValues[VarManager::kPt], VarManager::fgValues[VarManager::kEta], VarManager::fgValues[VarManager::kPhi], t1.sign() + t2.sign(), twoTrackFilter, dileptonMcDecision, t1.pt(), t1.eta(), t1.phi(), t1.itsClusterMap(), t1.itsChi2NCl(), t1.tpcNClsCrossedRows(), t1.tpcNClsFound(), t1.tpcChi2NCl(), t1.dcaXY(), t1.dcaZ(), t1.tpcSignal(), t1.tpcNSigmaEl(), t1.tpcNSigmaPi(), t1.tpcNSigmaPr(), t1.beta(), t1.tofNSigmaEl(), t1.tofNSigmaPi(), t1.tofNSigmaPr(), @@ -1821,6 +1893,10 @@ struct AnalysisSameEventPairing { bool isLeg1Ambi = false; bool isLeg2Ambi = false; bool isAmbiExtra = false; + + if (fConfigML.applyBDT && !isSelectedBDT) + continue; + for (int icut = 0; icut < ncuts; icut++) { if (twoTrackFilter & (static_cast(1) << icut)) { isAmbiInBunch = (twoTrackFilter & (static_cast(1) << 28)) || (twoTrackFilter & (static_cast(1) << 29)); @@ -2142,6 +2218,13 @@ struct AnalysisSameEventPairing { runSameEventPairing(events, trackAssocsPerCollision, barrelAssocs, barrelTracks); } + void processBarrelOnlySkimmedBDT(MyEventsVtxCovSelected const& events, + soa::Join const& barrelAssocs, + MyBarrelTracksWithCovWithAmbiguities const& barrelTracks) + { + runSameEventPairing(events, trackAssocsPerCollision, barrelAssocs, barrelTracks); + } + void processMuonOnlySkimmed(MyEventsVtxCovSelected const& events, soa::Join const& muonAssocs, MyMuonTracksWithCovWithAmbiguities const& muons) { @@ -2185,6 +2268,7 @@ struct AnalysisSameEventPairing { PROCESS_SWITCH(AnalysisSameEventPairing, processBarrelOnlySkimmedNoCov, "Run barrel only pairing (no covariances), with skimmed tracks and with collision information", false); PROCESS_SWITCH(AnalysisSameEventPairing, processBarrelOnlySkimmedNoCovWithMultExtra, "Run barrel only pairing (no covariances), with skimmed tracks, with collision information, with MultsExtra", false); PROCESS_SWITCH(AnalysisSameEventPairing, processBarrelOnlyWithQvectorCentrSkimmedNoCov, "Run barrel only pairing (no covariances), with skimmed tracks, with Qvector from central framework", false); + PROCESS_SWITCH(AnalysisSameEventPairing, processBarrelOnlySkimmedBDT, "Run electron-electron pairing, with skimmed tracks and BDT selection", false); PROCESS_SWITCH(AnalysisSameEventPairing, processMuonOnlySkimmed, "Run muon only pairing, with skimmed tracks", false); PROCESS_SWITCH(AnalysisSameEventPairing, processMuonOnlySkimmedMultExtra, "Run muon only pairing, with skimmed tracks", false); PROCESS_SWITCH(AnalysisSameEventPairing, processMixingAllSkimmed, "Run all types of mixed pairing, with skimmed tracks/muons", false);