diff --git a/PWGHF/Core/HfMlResponseDstarToD0Pi.h b/PWGHF/Core/HfMlResponseDstarToD0Pi.h index 4d9e2bff488..af0351b2a73 100644 --- a/PWGHF/Core/HfMlResponseDstarToD0Pi.h +++ b/PWGHF/Core/HfMlResponseDstarToD0Pi.h @@ -113,6 +113,8 @@ enum class InputFeaturesDstarToD0Pi : uint8_t { ptSoftPi, impactParameter0, impactParameter1, + impactParameterXY0, + impactParameterXY1, impactParameterZ0, impactParameterZ1, impParamSoftPi, @@ -219,6 +221,28 @@ class HfMlResponseDstarToD0Pi : public HfMlResponse return inputFeatures; } + /// Method to get the input features used for D0 in HF triggers + /// \param candidate is the D* candidate + /// \return inputFeatures vector + template + std::vector getInputFeaturesTrigger(T1 const& candidate) + { + std::vector inputFeatures; + + for (const auto& idx : MlResponse::mCachedIndices) { + switch (idx) { + CHECK_AND_FILL_VEC_DSTAR(ptProng0); + CHECK_AND_FILL_VEC_DSTAR_GETTER(impactParameterXY0, impactParameter0); + CHECK_AND_FILL_VEC_DSTAR(impactParameterZ0); + CHECK_AND_FILL_VEC_DSTAR(ptProng1); + CHECK_AND_FILL_VEC_DSTAR_GETTER(impactParameterXY1, impactParameter1); + CHECK_AND_FILL_VEC_DSTAR(impactParameterZ1); + } + } + + return inputFeatures; + } + protected: /// Method to fill the map of available input features void setAvailableInputFeatures() @@ -238,6 +262,8 @@ class HfMlResponseDstarToD0Pi : public HfMlResponse FILL_MAP_DSTAR(ptSoftPi), FILL_MAP_DSTAR(impactParameter0), FILL_MAP_DSTAR(impactParameter1), + FILL_MAP_DSTAR(impactParameterXY0), + FILL_MAP_DSTAR(impactParameterXY1), FILL_MAP_DSTAR(impactParameterZ0), FILL_MAP_DSTAR(impactParameterZ1), FILL_MAP_DSTAR(impParamSoftPi), diff --git a/PWGHF/TableProducer/candidateSelectorDstarToD0Pi.cxx b/PWGHF/TableProducer/candidateSelectorDstarToD0Pi.cxx index 3df427ad2bb..dbc06282ce0 100644 --- a/PWGHF/TableProducer/candidateSelectorDstarToD0Pi.cxx +++ b/PWGHF/TableProducer/candidateSelectorDstarToD0Pi.cxx @@ -16,6 +16,7 @@ /// \author Fabrizio Grosa , CERN #include "PWGHF/Core/HfHelper.h" +#include "PWGHF/Core/HfMlResponseD0ToKPi.h" #include "PWGHF/Core/HfMlResponseDstarToD0Pi.h" #include "PWGHF/Core/SelectorCuts.h" #include "PWGHF/DataModel/CandidateReconstructionTables.h" @@ -90,18 +91,27 @@ struct HfCandidateSelectorDstarToD0Pi { // QA switch Configurable activateQA{"activateQA", false, "Flag to enable QA histogram"}; - // ML inference + // ML inference D* Configurable applyMl{"applyMl", false, "Flag to apply ML selections"}; Configurable> binsPtMl{"binsPtMl", std::vector{hf_cuts_ml::vecBinsPt}, "pT bin limits for ML application"}; Configurable> cutDirMl{"cutDirMl", std::vector{hf_cuts_ml::vecCutDir}, "Whether to reject score values greater or smaller than the threshold"}; Configurable> cutsMl{"cutsMl", {hf_cuts_ml::Cuts[0], hf_cuts_ml::NBinsPt, hf_cuts_ml::NCutScores, hf_cuts_ml::labelsPt, hf_cuts_ml::labelsCutScore}, "ML selections per pT bin"}; Configurable nClassesMl{"nClassesMl", static_cast(hf_cuts_ml::NCutScores), "Number of classes in ML model"}; Configurable> namesInputFeatures{"namesInputFeatures", std::vector{"feature1", "feature2"}, "Names of ML model input features"}; + // ML inference D0 + Configurable applyMlD0Daug{"applyMlD0Daug", false, "Flag to apply ML selections on D0 daughter"}; + Configurable> binsPtMlD0Daug{"binsPtMlD0Daug", std::vector{hf_cuts_ml::vecBinsPt}, "pT bin limits for ML application on D0 daughter"}; + Configurable> cutDirMlD0Daug{"cutDirMlD0Daug", std::vector{hf_cuts_ml::vecCutDir}, "Whether to reject score values greater or smaller than the threshold on D0 daughter"}; + Configurable> cutsMlD0Daug{"cutsMlD0Daug", {hf_cuts_ml::Cuts[0], hf_cuts_ml::NBinsPt, hf_cuts_ml::NCutScores, hf_cuts_ml::labelsPt, hf_cuts_ml::labelsCutScore}, "ML selections per pT bin on D0 daughter"}; + Configurable nClassesMlD0Daug{"nClassesMlD0Daug", static_cast(hf_cuts_ml::NCutScores), "Number of classes in ML model on D0 daughter"}; + Configurable> namesInputFeaturesD0Daug{"namesInputFeaturesD0Daug", std::vector{"feature1", "feature2"}, "Names of ML model input features on D0 daughter"}; // CCDB configuration Configurable ccdbUrl{"ccdbUrl", "http://alice-ccdb.cern.ch", "url of the ccdb repository"}; Configurable> modelPathsCCDB{"modelPathsCCDB", std::vector{""}, "Paths of models on CCDB"}; + Configurable> modelPathsCCDBD0Daug{"modelPathsCCDBD0Daug", std::vector{""}, "Paths of models on CCDB for D0 daughter"}; Configurable> onnxFileNames{"onnxFileNames", std::vector{"Model.onnx"}, "ONNX file names for each pT bin (if not from CCDB full path)"}; + Configurable> onnxFileNamesD0Daug{"onnxFileNamesD0Daug", std::vector{"Model.onnx"}, "ONNX file names for each pT bin (if not from CCDB full path) for D0 daughter"}; Configurable timestampCCDB{"timestampCCDB", -1, "timestamp of the ONNX file for ML model used to query in CCDB"}; Configurable loadModelsFromCCDB{"loadModelsFromCCDB", false, "Flag to enable or disable the loading of models from CCDB"}; @@ -110,7 +120,9 @@ struct HfCandidateSelectorDstarToD0Pi { HfHelper hfHelper; o2::analysis::HfMlResponseDstarToD0Pi hfMlResponse; + o2::analysis::HfMlResponseDstarToD0Pi hfMlResponseD0Daughter; std::vector outputMlDstarToD0Pi = {}; + std::vector outputMlD0ToKPi = {}; o2::ccdb::CcdbApi ccdbApi; TrackSelectorPi selectorPion; @@ -172,6 +184,18 @@ struct HfCandidateSelectorDstarToD0Pi { hfMlResponse.cacheInputFeaturesIndices(namesInputFeatures); hfMlResponse.init(); } + + if (applyMlD0Daug) { + hfMlResponseD0Daughter.configure(binsPtMlD0Daug, cutsMlD0Daug, cutDirMlD0Daug, nClassesMlD0Daug); + if (loadModelsFromCCDB) { + ccdbApi.init(ccdbUrl); + hfMlResponseD0Daughter.setModelPathsCCDB(onnxFileNamesD0Daug, ccdbApi, modelPathsCCDBD0Daug, timestampCCDB); + } else { + hfMlResponseD0Daughter.setModelPathsLocal(onnxFileNamesD0Daug); + } + hfMlResponseD0Daughter.cacheInputFeaturesIndices(namesInputFeaturesD0Daug); + hfMlResponseD0Daughter.init(); + } } /// Conjugate-independent topological cuts on D0 @@ -234,6 +258,15 @@ struct HfCandidateSelectorDstarToD0Pi { if (candidate.decayLengthXYD0() > cutsD0->get(binPt, "max decay length XY")) { return false; } + + if (applyMlD0Daug) { + outputMlD0ToKPi.clear(); + std::vector inputFeaturesD0 = hfMlResponseD0Daughter.getInputFeaturesTrigger(candidate); + bool isSelectedMlD0 = hfMlResponseD0Daughter.isSelectedMl(inputFeaturesD0, candpT, outputMlD0ToKPi); + if (!isSelectedMlD0) { + return false; + } + } return true; }