diff --git a/PWGHF/TableProducer/trackIndexSkimCreator.cxx b/PWGHF/TableProducer/trackIndexSkimCreator.cxx index db69cb40b07..0e93c3e1337 100644 --- a/PWGHF/TableProducer/trackIndexSkimCreator.cxx +++ b/PWGHF/TableProducer/trackIndexSkimCreator.cxx @@ -1693,6 +1693,9 @@ struct HfTrackIndexSkimCreator { const std::vector ptBinsMl{0., 1.e10}; const std::vector cutDirMl{o2::cuts_ml::CutDirection::CutGreater, o2::cuts_ml::CutDirection::CutSmaller, o2::cuts_ml::CutDirection::CutSmaller}; const std::array, kN3ProngDecaysUsedMlForHfFilters> thresholdMlScore3Prongs{config.thresholdMlScoreDplusToPiKPi, config.thresholdMlScoreLcToPiKP, config.thresholdMlScoreDsToPiKK, config.thresholdMlScoreXicToPiKP}; + const std::vector inputFeatures2Prongs = {"ptProng0", "dcaXyProng0", "dcaZProng0", "ptProng1", "dcaXyProng1", "dcaZProng1"}; + const std::vector inputFeatures3Prongs = {"ptProng0", "dcaXyProng0", "dcaZProng0", "ptProng1", "dcaXyProng1", "dcaZProng1", "ptProng2", "dcaXyProng2", "dcaZProng2"}; + const std::vector inputFeatures3ProngsWithPid = {"ptProng0", "dcaXyProng0", "dcaZProng0", "ptProng1", "dcaXyProng1", "dcaZProng1", "ptProng2", "dcaXyProng2", "dcaZProng2", "tpcNSigmaPrProng0", "tpcNSigmaPrProng2", "tpcNSigmaPiProng0", "tpcNSigmaPiProng2", "tpcNSigmaKaProng1"}; // initialise 2-prong ML response hfMlResponse2Prongs.configure(ptBinsMl, config.thresholdMlScoreD0ToKPi, cutDirMl, 3); @@ -1702,6 +1705,7 @@ struct HfTrackIndexSkimCreator { } else { hfMlResponse2Prongs.setModelPathsLocal(onnxFileNames2Prongs); } + hfMlResponse2Prongs.cacheInputFeaturesIndices(inputFeatures2Prongs); hfMlResponse2Prongs.init(); // initialise 3-prong ML responses @@ -1717,6 +1721,11 @@ struct HfTrackIndexSkimCreator { } else { hfMlResponse3Prongs[iDecay3P].setModelPathsLocal(onnxFileNames3Prongs[iDecay3P]); } + if ((doprocess2And3ProngsWithPvRefitWithPidForHfFiltersBdt || doprocess2And3ProngsNoPvRefitWithPidForHfFiltersBdt) && iDecay3P == aod::hf_cand_3prong::DecayType::LcToPKPi) { + hfMlResponse3Prongs[iDecay3P].cacheInputFeaturesIndices(inputFeatures3ProngsWithPid); + } else { + hfMlResponse3Prongs[iDecay3P].cacheInputFeaturesIndices(inputFeatures3Prongs); + } hfMlResponse3Prongs[iDecay3P].init(); } }