diff --git a/MC/config/common/TPCloopers/generatorWGAN_compton.onnx b/MC/config/common/TPCloopers/generatorWGAN_compton.onnx deleted file mode 100644 index 9a5233361..000000000 Binary files a/MC/config/common/TPCloopers/generatorWGAN_compton.onnx and /dev/null differ diff --git a/MC/config/common/TPCloopers/generatorWGAN_pair.onnx b/MC/config/common/TPCloopers/generatorWGAN_pair.onnx deleted file mode 100644 index 0ec23eb5e..000000000 Binary files a/MC/config/common/TPCloopers/generatorWGAN_pair.onnx and /dev/null differ diff --git a/MC/config/common/external/generator/TPCLoopers.C b/MC/config/common/external/generator/TPCLoopers.C index 76bd2417f..fb53c7a50 100644 --- a/MC/config/common/external/generator/TPCLoopers.C +++ b/MC/config/common/external/generator/TPCLoopers.C @@ -3,6 +3,8 @@ #include #include #include +#include "CCDB/CCDBTimeStampUtils.h" +#include "CCDB/CcdbApi.h" // Static Ort::Env instance for multiple onnx model loading static Ort::Env global_env(ORT_LOGGING_LEVEL_WARNING, "GlobalEnv"); @@ -377,6 +379,11 @@ class GenTPCLoopers : public Generator } // namespace eventgen } // namespace o2 +// ONNX model files can be local, on AliEn or in the ALICE CCDB. +// For local and alien files it is mandatory to provide the filenames, for the CCDB instead the +// path to the object in the CCDB is sufficient. The model files will be downloaded locally. +// Example of CCDB path: "ccdb://Users/n/name/test" +// Example of alien path: "alien:///alice/cern.ch/user/n/name/test/test.onnx" FairGenerator * Generator_TPCLoopers(std::string model_pairs = "tpcloopmodel.onnx", std::string model_compton = "tpcloopmodelcompton.onnx", std::string poisson = "poisson.csv", std::string gauss = "gauss.csv", std::string scaler_pair = "scaler_pair.json", @@ -390,6 +397,55 @@ FairGenerator * gauss = gSystem->ExpandPathName(gauss.c_str()); scaler_pair = gSystem->ExpandPathName(scaler_pair.c_str()); scaler_compton = gSystem->ExpandPathName(scaler_compton.c_str()); + const std::array models = {model_pairs, model_compton}; + const std::array local_names = {"WGANpair.onnx", "WGANcompton.onnx"}; + const std::array isAlien = {models[0].starts_with("alien://"), models[1].starts_with("alien://")}; + const std::array isCCDB = {models[0].starts_with("ccdb://"), models[1].starts_with("ccdb://")}; + if (std::any_of(isAlien.begin(), isAlien.end(), [](bool v) { return v; })) + { + if (!gGrid) { + TGrid::Connect("alien://"); + if (!gGrid) { + LOG(fatal) << "AliEn connection failed, check token."; + exit(1); + } + } + for (size_t i = 0; i < models.size(); ++i) + { + if (isAlien[i] && !TFile::Cp(models[i].c_str(), local_names[i].c_str())) + { + LOG(fatal) << "Error: Model file " << models[i] << " does not exist!"; + exit(1); + } + } + } + if (std::any_of(isCCDB.begin(), isCCDB.end(), [](bool v) { return v; })) + { + o2::ccdb::CcdbApi ccdb_api; + ccdb_api.init("http://alice-ccdb.cern.ch"); + for (size_t i = 0; i < models.size(); ++i) + { + if (isCCDB[i]) + { + auto model_path = models[i].substr(7); // Remove "ccdb://" + // Treat filename if provided in the CCDB path + auto extension = model_path.find(".onnx"); + if (extension != std::string::npos) + { + auto last_slash = model_path.find_last_of('/'); + model_path = model_path.substr(0, last_slash); + } + std::map filter; + if(!ccdb_api.retrieveBlob(model_path, "./" , filter, o2::ccdb::getCurrentTimestamp(), false, local_names[i].c_str())) + { + LOG(fatal) << "Error: issues in retrieving " << model_path << " from CCDB!"; + exit(1); + } + } + } + } + model_pairs = isAlien[0] || isCCDB[0] ? local_names[0] : model_pairs; + model_compton = isAlien[1] || isCCDB[1] ? local_names[1] : model_compton; auto generator = new o2::eventgen::GenTPCLoopers(model_pairs, model_compton, poisson, gauss, scaler_pair, scaler_compton); generator->SetNLoopers(nloopers_pairs, nloopers_compton); generator->SetMultiplier(mult); diff --git a/MC/config/common/ini/GeneratorTPCloopers.ini b/MC/config/common/ini/GeneratorTPCloopers.ini index 838c52fc9..276a086c8 100644 --- a/MC/config/common/ini/GeneratorTPCloopers.ini +++ b/MC/config/common/ini/GeneratorTPCloopers.ini @@ -1,4 +1,4 @@ # Example of tpc loopers generator with a poisson distribution of pairs and gaussian distribution of compton electrons [GeneratorExternal] fileName = ${O2DPG_MC_CONFIG_ROOT}/MC/config/common/external/generator/TPCLoopers.C -funcName = Generator_TPCLoopers("${O2DPG_MC_CONFIG_ROOT}/MC/config/common/TPCloopers/generatorWGAN_pair.onnx", "${O2DPG_MC_CONFIG_ROOT}/MC/config/common/TPCloopers/generatorWGAN_compton.onnx", "${O2DPG_MC_CONFIG_ROOT}/MC/config/common/TPCloopers/poisson_params.csv", "${O2DPG_MC_CONFIG_ROOT}/MC/config/common/TPCloopers/gaussian_params.csv", "${O2DPG_MC_CONFIG_ROOT}/MC/config/common/TPCloopers/ScalerPairParams.json", "${O2DPG_MC_CONFIG_ROOT}/MC/config/common/TPCloopers/ScalerComptonParams.json") \ No newline at end of file +funcName = Generator_TPCLoopers("ccdb://Users/m/mgiacalo/WGAN_ExtGenPair", "ccdb://Users/m/mgiacalo/WGAN_ExtGenCompton", "${O2DPG_MC_CONFIG_ROOT}/MC/config/common/TPCloopers/poisson_params.csv", "${O2DPG_MC_CONFIG_ROOT}/MC/config/common/TPCloopers/gaussian_params.csv", "${O2DPG_MC_CONFIG_ROOT}/MC/config/common/TPCloopers/ScalerPairParams.json", "${O2DPG_MC_CONFIG_ROOT}/MC/config/common/TPCloopers/ScalerComptonParams.json") \ No newline at end of file diff --git a/MC/config/common/ini/GeneratorTPCloopers_fixNPairs.ini b/MC/config/common/ini/GeneratorTPCloopers_fixNPairs.ini index 0ed296e7c..4223ecf40 100644 --- a/MC/config/common/ini/GeneratorTPCloopers_fixNPairs.ini +++ b/MC/config/common/ini/GeneratorTPCloopers_fixNPairs.ini @@ -1,5 +1,6 @@ -# Example of tpc loopers generator with a fixed number of pairs (10) +# Example of tpc loopers generator with a fixed number of pairs and compton electrons (10) +# Multiplier values are ignored in this case, but kept to 1 for consistency #---> GeneratorTPCloopers [GeneratorExternal] fileName = ${O2DPG_MC_CONFIG_ROOT}/MC/config/common/external/generator/TPCLoopers.C -funcName = Generator_TPCLoopers("${O2DPG_MC_CONFIG_ROOT}/MC/config/common/TPCloopers/generatorWGAN_pair.onnx", "", "${O2DPG_MC_CONFIG_ROOT}/MC/config/common/TPCloopers/ScalerPairParams.json",10) +funcName = Generator_TPCLoopers("ccdb://Users/m/mgiacalo/WGAN_ExtGenPair", "ccdb://Users/m/mgiacalo/WGAN_ExtGenCompton", "", "", "${O2DPG_MC_CONFIG_ROOT}/MC/config/common/TPCloopers/ScalerPairParams.json", "${O2DPG_MC_CONFIG_ROOT}/MC/config/common/TPCloopers/ScalerComptonParams.json",{1.,1.}, 10,10)