From 6b0673c99242b9b94c087110af83bd4c823e71d5 Mon Sep 17 00:00:00 2001 From: Vincenzo Eduardo Padulano Date: Fri, 13 Feb 2026 20:48:01 +0100 Subject: [PATCH] [df] Support std::array data members in TTree schemas Fixes https://github.com/root-project/root/issues/14790 Include tests that both check the fixed behaviour and document limitations and currently unsupported schemas or operations on them that are closely related to the one described above. Reading of the same TTree schema with RNTuple is currently not supported within this context. --- roottest/root/dataframe/CMakeLists.txt | 18 + .../dataframe/ClassWithSequenceContainers.hxx | 39 + .../ClassWithSequenceContainersLinkDef.hxx | 6 + .../dataframe_sequence_containers.cxx | 916 ++++++++++++++++++ .../inc/ROOT/RDF/RTreeColumnReader.hxx | 10 +- tree/dataframe/src/RDFUtils.cxx | 11 +- tree/dataframe/src/RTTreeDS.cxx | 8 + tree/dataframe/src/RTreeColumnReader.cxx | 84 +- tree/treeplayer/inc/TTreeReaderArray.h | 12 +- tree/treeplayer/src/TTreeReaderArray.cxx | 100 +- 10 files changed, 1188 insertions(+), 16 deletions(-) create mode 100644 roottest/root/dataframe/ClassWithSequenceContainers.hxx create mode 100644 roottest/root/dataframe/ClassWithSequenceContainersLinkDef.hxx create mode 100644 roottest/root/dataframe/dataframe_sequence_containers.cxx diff --git a/roottest/root/dataframe/CMakeLists.txt b/roottest/root/dataframe/CMakeLists.txt index 7a376e330ddd1..a66a524e31ee0 100644 --- a/roottest/root/dataframe/CMakeLists.txt +++ b/roottest/root/dataframe/CMakeLists.txt @@ -330,3 +330,21 @@ ROOTTEST_ADD_TEST(test_snapshot_copyaddresses MACRO test_snapshot_copyaddresses.C+) ROOT_ADD_GTEST(test_norootextension test_norootextension.cxx LIBRARIES ROOT::ROOTDataFrame) + +ROOTTEST_GENERATE_DICTIONARY( + ClassWithSequenceContainersDict + ${CMAKE_CURRENT_SOURCE_DIR}/ClassWithSequenceContainers.hxx + LINKDEF ${CMAKE_CURRENT_SOURCE_DIR}/ClassWithSequenceContainersLinkDef.hxx + FIXTURES_SETUP 
ClassWithSequenceContainersDict_setup +) +ROOTTEST_GENERATE_EXECUTABLE( + dataframe_sequence_containers + dataframe_sequence_containers.cxx ClassWithSequenceContainersDict.cxx + LIBRARIES Core RIO Tree ROOTDataFrame GTest::gtest GTest::gtest_main GTest::gmock GTest::gmock_main ROOT::TestSupport + FIXTURES_REQUIRED ClassWithSequenceContainersDict_setup + FIXTURES_SETUP dataframe_sequence_containers_setup +) +ROOTTEST_ADD_TEST(dataframe_sequence_containers + EXEC ${CMAKE_CURRENT_BINARY_DIR}/dataframe_sequence_containers + FIXTURES_REQUIRED dataframe_sequence_containers_setup +) diff --git a/roottest/root/dataframe/ClassWithSequenceContainers.hxx b/roottest/root/dataframe/ClassWithSequenceContainers.hxx new file mode 100644 index 0000000000000..2f4ef62976cf4 --- /dev/null +++ b/roottest/root/dataframe/ClassWithSequenceContainers.hxx @@ -0,0 +1,39 @@ +#ifndef ROOT_DATAFRAME_TEST_ARRRAYCOMBINATIONS +#define ROOT_DATAFRAME_TEST_ARRRAYCOMBINATIONS + +#include +#include + +#include + +struct ClassWithSequenceContainers { + unsigned int fObjIndex{}; + std::array fArrFl{}; + std::array, 3> fArrArrFl{}; + std::array, 3> fArrVecFl{}; + + std::vector fVecFl{}; + std::vector> fVecArrFl{}; //! Not supported for TTree: could not find the real data member + //! 
'_M_elems[3]' when constructing the branch 'fVecArrFl' + std::vector> fVecVecFl{}; + + // For ROOT I/O + ClassWithSequenceContainers() = default; + + ClassWithSequenceContainers(unsigned int objIndex, std::array a1, std::array, 3> a2, + std::array, 3> a3, std::vector a4, + std::vector> a5, std::vector> a6) + : fObjIndex(objIndex), + fArrFl(std::move(a1)), + fArrArrFl(std::move(a2)), + fArrVecFl(std::move(a3)), + fVecFl(std::move(a4)), + fVecArrFl(std::move(a5)), + fVecVecFl(std::move(a6)) + { + } + + ClassDefNV(ClassWithSequenceContainers, 1) +}; + +#endif diff --git a/roottest/root/dataframe/ClassWithSequenceContainersLinkDef.hxx b/roottest/root/dataframe/ClassWithSequenceContainersLinkDef.hxx new file mode 100644 index 0000000000000..25a4a559190f0 --- /dev/null +++ b/roottest/root/dataframe/ClassWithSequenceContainersLinkDef.hxx @@ -0,0 +1,6 @@ +#ifdef __CLING__ + +#pragma link C++ class ClassWithSequenceContainers+; +#pragma link C++ class std::vector+; + +#endif diff --git a/roottest/root/dataframe/dataframe_sequence_containers.cxx b/roottest/root/dataframe/dataframe_sequence_containers.cxx new file mode 100644 index 0000000000000..117a0e738002c --- /dev/null +++ b/roottest/root/dataframe/dataframe_sequence_containers.cxx @@ -0,0 +1,916 @@ +#include + +#include + +#include + +#include +#include + +#include +#include + +#include +#include + +#include "ClassWithSequenceContainers.hxx" + +using ROOT::RNTupleModel; +using ROOT::RNTupleWriter; + +struct ClassWithSequenceContainersData { + std::array fArrFl{}; + std::array, 3> fArrArrFl{}; + std::array, 3> fArrVecFl{}; + + std::vector fVecFl{}; + std::vector> fVecArrFl{}; + std::vector> fVecVecFl{}; + + ClassWithSequenceContainers fClassWithArrays{}; + std::vector fVecClassWithArrays{}; +}; + +std::vector generateClassWithSequenceContainersData() +{ + // ClassWithSequenceContainers members + std::array topArrFl{1.1, 2.2, 3.3}; + std::array, 3> topArrArrFl{{{1.1, 2.2, 3.3}, {4.4, 5.5, 6.6}, {7.7, 8.8, 9.9}}}; + 
std::array, 3> topArrVecFl{{{11.11}, {12.12, 13.13}, {14.14, 15.15, 16.16}}}; + + std::vector topVecFl{17.17, 18.18, 19.19}; + std::vector> topVecArrFl{{{21.21, 22.22, 23.23}, + {24.24, 25.25, 26.26}, + {27.27, 28.28, 29.29}, + {31.31, 32.32, 33.33}, + {34.34, 35.35, 36.36}, + {37.37, 38.38, 39.39}}}; + std::vector> topVecVecFl{{}, {41.41}, {42.42, 43.43}, {44.44, 45.45, 46.46}}; + + // Class object + ClassWithSequenceContainers classWithArrays(0, topArrFl, topArrArrFl, topArrVecFl, topVecFl, topVecArrFl, + topVecVecFl); + + // std::vector of class objects + std::vector vecClassWithArrays; + vecClassWithArrays.reserve(5); + for (int i = 1; i < 6; i++) { + vecClassWithArrays.emplace_back(i, topArrFl, topArrArrFl, topArrVecFl, topVecFl, topVecArrFl, topVecVecFl); + } + return std::vector{ClassWithSequenceContainersData{topArrFl, topArrArrFl, topArrVecFl, topVecFl, topVecArrFl, + topVecVecFl, classWithArrays, vecClassWithArrays}}; +} + +std::vector generateClassWithSequenceContainersDataPlusOne() +{ + auto data = generateClassWithSequenceContainersData(); + for (auto &entry : data) { + for (auto &val : entry.fArrFl) { + val += 1; + } + for (auto &arr : entry.fArrArrFl) { + for (auto &val : arr) { + val += 1; + } + } + for (auto &vec : entry.fArrVecFl) { + for (auto &val : vec) { + val += 1; + } + } + for (auto &val : entry.fVecFl) { + val += 1; + } + for (auto &arr : entry.fVecArrFl) { + for (auto &val : arr) { + val += 1; + } + } + for (auto &vec : entry.fVecVecFl) { + for (auto &val : vec) { + val += 1; + } + } + auto classWithArraysPlusOne = [](ClassWithSequenceContainers &obj) { + obj.fObjIndex += 1; + for (auto &val : obj.fArrFl) { + val += 1; + } + for (auto &arr : obj.fArrArrFl) { + for (auto &val : arr) { + val += 1; + } + } + for (auto &vec : obj.fArrVecFl) { + for (auto &val : vec) { + val += 1; + } + } + for (auto &val : obj.fVecFl) { + val += 1; + } + for (auto &arr : obj.fVecArrFl) { + for (auto &val : arr) { + val += 1; + } + } + for (auto &vec : 
obj.fVecVecFl) { + for (auto &val : vec) { + val += 1; + } + } + }; + classWithArraysPlusOne(entry.fClassWithArrays); + for (auto &obj : entry.fVecClassWithArrays) { + classWithArraysPlusOne(obj); + } + } + return data; +} + +class ClassWithSequenceContainersTest : public ::testing::TestWithParam { +protected: + constexpr static const char *fFileName = "root_dataframe_sequence_containers_ClassWithSequenceContainers.root"; + constexpr static const char *fDatasetName = "root_dataframe_sequence_containers_ClassWithSequenceContainers"; + + void WriteTTree() + { + auto f = std::make_unique(fFileName, "RECREATE"); + auto t = std::make_unique(fDatasetName, fDatasetName); + + auto data = generateClassWithSequenceContainersData(); + + // Branches + t->Branch("topArrFl", &data[0].fArrFl); + // Not supported: std::array with T being a class type as top-level branch + // t->Branch("topArrArrFl", &data.topArrArrFl); + // t->Branch("topArrVecFl", &data.topArrVecFl); + t->Branch("topVecFl", &data[0].fVecFl); + // Not supported: Could not find the real data member '_M_elems[3]' when constructing the branch 'topVecArrFl' + // t->Branch("topVecArrFl", &data.fVecArrFl); + t->Branch("topVecVecFl", &data[0].fVecVecFl); + t->Branch("classWithArrays", &data[0].fClassWithArrays); + t->Branch("vecClassWithArrays", &data[0].fVecClassWithArrays); + + t->Fill(); + + f->Write(); + } + + void WriteRNTuple() + { + auto model = RNTupleModel::Create(); + auto topArrFl = model->MakeField>("topArrFl"); + auto topArrArrFl = model->MakeField, 3>>("topArrArrFl"); + auto topArrVecFl = model->MakeField, 3>>("topArrVecFl"); + + auto topVecFl = model->MakeField>("topVecFl"); + auto topVecArrFl = model->MakeField>>("topVecArrFl"); + auto topVecVecFl = model->MakeField>>("topVecVecFl"); + + auto classWithArrays = model->MakeField("classWithArrays"); + auto vecClassWithArrays = model->MakeField>("vecClassWithArrays"); + auto ntuple = RNTupleWriter::Recreate(std::move(model), fDatasetName, fFileName); + auto 
data = generateClassWithSequenceContainersData(); + for (const auto &entry : data) { + *topArrFl = entry.fArrFl; + *topArrArrFl = entry.fArrArrFl; + *topArrVecFl = entry.fArrVecFl; + *topVecFl = entry.fVecFl; + *topVecArrFl = entry.fVecArrFl; + *topVecVecFl = entry.fVecVecFl; + *classWithArrays = entry.fClassWithArrays; + *vecClassWithArrays = entry.fVecClassWithArrays; + ntuple->Fill(); + } + } + + ClassWithSequenceContainersTest() + { + if (GetParam()) { + WriteTTree(); + } else { + WriteRNTuple(); + } + } + + ~ClassWithSequenceContainersTest() override { std::remove(fFileName); } +}; + +template +void check_1d_coll(const T &coll1, const U &coll2) +{ + ASSERT_EQ(coll1.size(), coll2.size()); + for (size_t i = 0; i < coll1.size(); ++i) { + EXPECT_FLOAT_EQ(coll1[i], coll2[i]) << " at index " << i; + } +} + +template +void check_2d_coll(const T &coll1, const U &coll2) +{ + ASSERT_EQ(coll1.size(), coll2.size()); + for (size_t i = 0; i < coll1.size(); ++i) { + check_1d_coll(coll1[i], coll2[i]); + } +} + +void check_class_with_arrays(const ClassWithSequenceContainers &obj1, const ClassWithSequenceContainers &obj2) +{ + EXPECT_EQ(obj1.fObjIndex, obj2.fObjIndex); + check_1d_coll(obj1.fArrFl, obj2.fArrFl); + check_2d_coll(obj1.fArrArrFl, obj2.fArrArrFl); + check_2d_coll(obj1.fArrVecFl, obj2.fArrVecFl); + check_1d_coll(obj1.fVecFl, obj2.fVecFl); + // fVecArrFl is not supported: Could not find the real data member '_M_elems[3]' when constructing the branch + // 'fVecArrFl' + // check_2d_coll(obj1.fVecArrFl, obj2.fVecArrFl); + check_2d_coll(obj1.fVecVecFl, obj2.fVecVecFl); +} + +template +void check_coll_class_with_arrays(const T &coll1, const U &coll2) +{ + ASSERT_EQ(coll1.size(), coll2.size()); + for (size_t i = 0; i < coll1.size(); ++i) { + check_class_with_arrays(coll1[i], coll2[i]); + } +} + +TEST_P(ClassWithSequenceContainersTest, ExpectedTypes) +{ + // TODO: This test currently assumes spellings of column types when the + // data format is TTree + if (!GetParam()) + 
return; // The expected type names are different between the TTree and RNTuple data sources + + ROOT::RDataFrame df{fDatasetName, fFileName}; + + const std::unordered_map expectedColTypes{ + {"topArrFl", "ROOT::VecOps::RVec"}, + // Not supported: std::array with T being a class type as top-level branch + // {"topArrArrFl", "ROOT::VecOps::RVec>"}, + // {"topArrVecFl", "ROOT::VecOps::RVec>"}, + {"topVecFl", "ROOT::VecOps::RVec"}, + // Not supported: Could not find the real data member '_M_elems[3]' when constructing the branch 'topVecArrFl' + // {"topVecArrFl", "ROOT::VecOps::RVec>"}, + {"topVecVecFl", "ROOT::VecOps::RVec>"}, + {"classWithArrays", "ClassWithSequenceContainers"}, + {"classWithArrays.fObjIndex", "UInt_t"}, + {"classWithArrays.fArrFl[3]", "ROOT::VecOps::RVec"}, + // TODO: array of array is currently not properly handled + // {"classWithArrays.fArrArrFl[3][3]", "ROOT::VecOps::RVec>"}, + {"classWithArrays.fArrVecFl[3]", "ROOT::VecOps::RVec>"}, + {"classWithArrays.fVecFl", "ROOT::VecOps::RVec"}, + {"classWithArrays.fVecVecFl", "ROOT::VecOps::RVec>"}, + {"vecClassWithArrays", "ROOT::VecOps::RVec"}, + {"vecClassWithArrays.fArrFl[3]", "ROOT::VecOps::RVec>"}, + {"vecClassWithArrays.fArrVecFl[3]", "ROOT::VecOps::RVec, 3>>"}, + {"vecClassWithArrays.fVecFl", "ROOT::VecOps::RVec>"}, + {"vecClassWithArrays.fVecVecFl", "ROOT::VecOps::RVec >>"}, + }; + for (const auto &[colName, expectedType] : expectedColTypes) { + EXPECT_EQ(df.GetColumnType(colName), expectedType) << " for column " << colName; + } +} + +TEST_P(ClassWithSequenceContainersTest, TakeExpectedTypes) +{ +#ifndef NDEBUG + // The following warning is only for debugging purposes with the TTree data source. It happens in this test + // because we are reading a class partially, i.e. the branch type is std::vector but + // we are only reading the data member fArrFl with the column name "classWithArrays.fArrFl[3]" as an + // RVec>. 
+ ROOT::TestSupport::CheckDiagsRAII diagRAII; + diagRAII.optionalDiag( + kWarning, "RTreeColumnReader::Get", + "hangs from a non-split branch. A copy is being performed in order to properly read the content.", false); +#endif + + ROOT::RDataFrame df{fDatasetName, fFileName}; + + auto data = generateClassWithSequenceContainersData(); + + const std::string classWithArraysArrFlColName = GetParam() ? "classWithArrays.fArrFl[3]" : "classWithArrays.fArrFl"; + const std::string vecClassWithArraysArrFlColName = + GetParam() ? "vecClassWithArrays.fArrFl[3]" : "vecClassWithArrays.fArrFl"; + + // Take each column individually and check the content + // In this test, use the types as expected by the test "ExpectedTypes" + auto takeTopArrFl = df.Take>("topArrFl"); + // Not supported: std::array with T being a class type as top-level branch + // auto takeTopArrArrFl = df.Take>("topArrArrFl"); + // auto takeTopArrVecFl = df.Take>("topArrVecFl"); + auto takeTopVecFl = df.Take>("topVecFl"); + // Not supported: Could not find the real data member '_M_elems[3]' when constructing the branch 'topVecArrFl' + // auto takeTopVecArrFl = df.Take>>("topVecArrFl"); + auto takeTopVecVecFl = df.Take>>("topVecVecFl"); + auto takeClassWithArrays = df.Take("classWithArrays"); + auto takeClassWithArrays_fObjIndex = df.Take("classWithArrays.fObjIndex"); + auto takeClassWithArrays_fArrFl = df.Take>(classWithArraysArrFlColName); + // TODO: array of array is currently not properly handled + // auto takeClassWithArrays_fArrArrFl = df.Take, + // 3>>>("classWithArrays.fArrArrFl[3][3]"); + // TODO: array of vector as a data member is currently not properly handled + // auto takeClassWithArrays_fArrVecFl = + // df.Take>>("classWithArrays.fArrVecFl[3]"); + auto takeClassWithArrays_fVecFl = df.Take>("classWithArrays.fVecFl"); + // RNTuple currently fails with the following operation with error + // RNTupleDS: Could not create field with type "ROOT::VecOps::RVec>" for column + // "classWithArrays.fVecVecFl" + 
ROOT::RDF::RResultPtr>>> takeClassWithArrays_fVecVecFl; + if (GetParam()) { + takeClassWithArrays_fVecVecFl = df.Take>>("classWithArrays.fVecVecFl"); + } + auto takeVecClassWithArrays = df.Take>("vecClassWithArrays"); + // RNTuple currently fails with the following operation with error + // RNTupleDS: Could not create field with type "std::array" for column "classWithArrays.fArrFl" + ROOT::RDF::RResultPtr>>> takeVecClassWithArrays_fArrFl; + if (GetParam()) { + takeVecClassWithArrays_fArrFl = + df.Take>>(vecClassWithArraysArrFlColName); + } + // TODO: array of vector as a data member is currently not properly handled + // auto takeVecClassWithArrays_fArrVecFl = + // df.Take,3>>>("vecClassWithArrays.fArrVecFl[3]"); + // TODO: vector of vector throws `std::bad_alloc` currently + // auto takeVecClassWithArrays_fVecFl = df.Take>>("vecClassWithArrays.fVecFl"); + // auto takeVecClassWithArrays_fVecVecFl = + // df.Take>>>("vecClassWithArrays.fVecVecFl"); + + auto nEvents = takeTopArrFl->size(); + EXPECT_EQ(nEvents, 1); + + for (decltype(nEvents) i = 0; i < nEvents; ++i) { + check_1d_coll(takeTopArrFl->at(i), data[i].fArrFl); + check_1d_coll(takeTopVecFl->at(i), data[i].fVecFl); + check_2d_coll(takeTopVecVecFl->at(i), data[i].fVecVecFl); + check_class_with_arrays(takeClassWithArrays->at(i), data[i].fClassWithArrays); + EXPECT_EQ(takeClassWithArrays_fObjIndex->at(i), data[i].fClassWithArrays.fObjIndex); + check_1d_coll(takeClassWithArrays_fArrFl->at(i), data[i].fClassWithArrays.fArrFl); + check_1d_coll(takeClassWithArrays_fVecFl->at(i), data[i].fClassWithArrays.fVecFl); + if (GetParam()) { + check_2d_coll(takeClassWithArrays_fVecVecFl->at(i), data[i].fClassWithArrays.fVecVecFl); + } + check_coll_class_with_arrays(takeVecClassWithArrays->at(i), data[i].fVecClassWithArrays); + if (GetParam()) { + std::vector> expectedVecArrFl(data[i].fVecClassWithArrays.size()); + for (size_t j = 0; j < data[i].fVecClassWithArrays.size(); ++j) { + expectedVecArrFl[j] = 
data[i].fVecClassWithArrays[j].fArrFl; + } + check_2d_coll(takeVecClassWithArrays_fArrFl->at(i), expectedVecArrFl); + } + } +} + +TEST_P(ClassWithSequenceContainersTest, TakeOriginalTypes) +{ + ROOT::RDataFrame df{fDatasetName, fFileName}; + + const std::string classWithArraysArrFlColName = GetParam() ? "classWithArrays.fArrFl[3]" : "classWithArrays.fArrFl"; + const std::string vecClassWithArraysArrFlColName = + GetParam() ? "vecClassWithArrays.fArrFl[3]" : "vecClassWithArrays.fArrFl"; + + auto data = generateClassWithSequenceContainersData(); + + // Take each column individually and check the content + // In this test, call Take using only original types as written in the EDM + auto takeTopArrFl = df.Take>("topArrFl"); + // Not supported: std::array with T being a class type as top-level branch + // auto takeTopArrArrFl = df.Take>("topArrArrFl"); + // auto takeTopArrVecFl = df.Take>("topArrVecFl"); + auto takeTopVecFl = df.Take>("topVecFl"); + // Not supported: Could not find the real data member '_M_elems[3]' when constructing the branch 'topVecArrFl' + // auto takeTopVecArrFl = df.Take>>("topVecArrFl"); + auto takeTopVecVecFl = df.Take>>("topVecVecFl"); + auto takeClassWithArrays = df.Take("classWithArrays"); + auto takeClassWithArrays_fObjIndex = df.Take("classWithArrays.fObjIndex"); + // RNTuple currently fails with the following operation with error + // RNTupleDS: Could not create field with type "std::array" for column "classWithArrays.fArrFl" + ROOT::RDF::RResultPtr>> takeClassWithArrays_fArrFl; + if (GetParam()) { + takeClassWithArrays_fArrFl = df.Take>(classWithArraysArrFlColName); + } + // TODO: array of array is currently not properly handled + // auto takeClassWithArrays_fArrArrFl = df.Take, + // 3>>>("classWithArrays.fArrArrFl[3][3]"); + // TODO: array of vector as a data member is currently not properly handled + // auto takeClassWithArrays_fArrVecFl = + // df.Take>>("classWithArrays.fArrVecFl[3]"); + // RNTuple currently fails with the following 
operation with error + // RNTupleDS: Could not create field with type "std::vector" for column "classWithArrays.fVecFl" + ROOT::RDF::RResultPtr>> takeClassWithArrays_fVecFl; + if (GetParam()) { + takeClassWithArrays_fVecFl = df.Take>("classWithArrays.fVecFl"); + } + // RNTuple currently fails with the following operation with error + // RNTupleDS: Could not create field with type "std::vector>" for column + // "classWithArrays.fVecVecFl" + ROOT::RDF::RResultPtr>>> takeClassWithArrays_fVecVecFl; + if (GetParam()) { + takeClassWithArrays_fVecVecFl = df.Take>>("classWithArrays.fVecVecFl"); + } + auto takeVecClassWithArrays = df.Take>("vecClassWithArrays"); + // RNTuple currently fails with the following operation with error + // RNTupleDS: Could not create field with type "std::vector>" for column + // "vecClassWithArrays.fArrFl" + ROOT::RDF::RResultPtr>>> takeVecClassWithArrays_fArrFl; + if (GetParam()) { + takeVecClassWithArrays_fArrFl = df.Take>>(vecClassWithArraysArrFlColName); + } + // TODO: array of vector as a data member is currently not properly handled + // auto takeVecClassWithArrays_fArrVecFl = + // df.Take,3>>>("vecClassWithArrays.fArrVecFl[3]"); + // TODO: vector of vector throws `std::bad_alloc` currently + // auto takeVecClassWithArrays_fVecFl = df.Take>>("vecClassWithArrays.fVecFl"); + // auto takeVecClassWithArrays_fVecVecFl = + // df.Take>>>("vecClassWithArrays.fVecVecFl"); + + auto nEvents = takeTopArrFl->size(); + EXPECT_EQ(nEvents, 1); + + for (decltype(nEvents) i = 0; i < nEvents; ++i) { + check_1d_coll(takeTopArrFl->at(i), data[i].fArrFl); + check_1d_coll(takeTopVecFl->at(i), data[i].fVecFl); + check_2d_coll(takeTopVecVecFl->at(i), data[i].fVecVecFl); + check_class_with_arrays(takeClassWithArrays->at(i), data[i].fClassWithArrays); + EXPECT_EQ(takeClassWithArrays_fObjIndex->at(i), data[i].fClassWithArrays.fObjIndex); + if (GetParam()) { + check_1d_coll(takeClassWithArrays_fArrFl->at(i), data[i].fClassWithArrays.fArrFl); + 
check_1d_coll(takeClassWithArrays_fVecFl->at(i), data[i].fClassWithArrays.fVecFl); + check_2d_coll(takeClassWithArrays_fVecVecFl->at(i), data[i].fClassWithArrays.fVecVecFl); + } + check_coll_class_with_arrays(takeVecClassWithArrays->at(i), data[i].fVecClassWithArrays); + if (GetParam()) { + std::vector> expectedVecArrFl(data[i].fVecClassWithArrays.size()); + for (size_t j = 0; j < data[i].fVecClassWithArrays.size(); ++j) { + expectedVecArrFl[j] = data[i].fVecClassWithArrays[j].fArrFl; + } + check_2d_coll(takeVecClassWithArrays_fArrFl->at(i), expectedVecArrFl); + } + } +} + +TEST_P(ClassWithSequenceContainersTest, TemplatedOps) +{ + ROOT::RDataFrame df{fDatasetName, fFileName}; + ROOT::RDF::RNode node = df; + + const std::string classWithArraysArrFlColName = GetParam() ? "classWithArrays.fArrFl[3]" : "classWithArrays.fArrFl"; + const std::string vecClassWithArraysArrFlColName = + GetParam() ? "vecClassWithArrays.fArrFl[3]" : "vecClassWithArrays.fArrFl"; + + auto data = generateClassWithSequenceContainersDataPlusOne(); + + node = node.Define("topArrFl_plus_1", + [](const std::array &arr) { + std::array result; + for (size_t i = 0; i < arr.size(); ++i) { + result[i] = arr[i] + 1.0f; + } + return result; + }, + {"topArrFl"}); + node = node.Define("topVecFl_plus_1", + [](const std::vector &vec) { + std::vector result(vec.size()); + for (size_t i = 0; i < vec.size(); ++i) { + result[i] = vec[i] + 1.0f; + } + return result; + }, + {"topVecFl"}); + node = node.Define("topVecVecFl_plus_1", + [](const std::vector> &vecvec) { + std::vector> result(vecvec.size()); + for (size_t i = 0; i < vecvec.size(); ++i) { + result[i].resize(vecvec[i].size()); + for (size_t j = 0; j < vecvec[i].size(); ++j) { + result[i][j] = vecvec[i][j] + 1.0f; + } + } + return result; + }, + {"topVecVecFl"}); + node = node.Define("classWithArrays_plus_1", + [](const ClassWithSequenceContainers &obj) { + ClassWithSequenceContainers result = obj; + result.fObjIndex += 1; + for (auto &val : result.fArrFl) 
{ + val += 1; + } + for (auto &arr : result.fArrArrFl) { + for (auto &val : arr) { + val += 1; + } + } + for (auto &vec : result.fArrVecFl) { + for (auto &val : vec) { + val += 1; + } + } + for (auto &val : result.fVecFl) { + val += 1; + } + for (auto &arr : result.fVecArrFl) { + for (auto &val : arr) { + val += 1; + } + } + for (auto &vec : result.fVecVecFl) { + for (auto &val : vec) { + val += 1; + } + } + return result; + }, + {"classWithArrays"}); + node = node.Define("vecClassWithArrays_plus_1", + [](const std::vector &vec) { + std::vector result = vec; + for (auto &obj : result) { + obj.fObjIndex += 1; + for (auto &val : obj.fArrFl) { + val += 1; + } + for (auto &arr : obj.fArrArrFl) { + for (auto &val : arr) { + val += 1; + } + } + for (auto &vvec : obj.fArrVecFl) { + for (auto &val : vvec) { + val += 1; + } + } + for (auto &val : obj.fVecFl) { + val += 1; + } + for (auto &arr : obj.fVecArrFl) { + for (auto &val : arr) { + val += 1; + } + } + for (auto &vvec : obj.fVecVecFl) { + for (auto &val : vvec) { + val += 1; + } + } + } + return result; + }, + {"vecClassWithArrays"}); + // Also create modified values for data member columns + node = node.Define("classWithArrays_fObjIndex_plus_1", [](unsigned int objIndex) { return objIndex + 1; }, + {"classWithArrays.fObjIndex"}); + // RNTuple currently fails with the following operation with error + // RNTupleDS: Could not create field with type "std::array" for column "classWithArrays.fArrFl" + if (GetParam()) { + node = node.Define("classWithArrays_fArrFl_plus_1", + [](const std::array &arr) { + std::array result; + for (size_t i = 0; i < arr.size(); ++i) { + result[i] = arr[i] + 1.0f; + } + return result; + }, + {classWithArraysArrFlColName}); + } + // RNTuple currently fails with the following operation with error + // RNTupleDS: Could not create field with type "std::vector" for column "classWithArrays.fVecFl" + if (GetParam()) { + node = node.Define("classWithArrays_fVecFl_plus_1", + [](const std::vector &vec) 
{ + std::vector result(vec.size()); + for (size_t i = 0; i < vec.size(); ++i) { + result[i] = vec[i] + 1.0f; + } + return result; + }, + {"classWithArrays.fVecFl"}); + } + // RNTuple currently fails with the following operation with error + // RNTupleDS: Could not create field with type "std::vector>" for column + // "classWithArrays.fVecVecFl" + if (GetParam()) { + node = node.Define("classWithArrays_fVecVecFl_plus_1", + [](const std::vector> &vecVecFl) { + std::vector> result(vecVecFl.size()); + for (size_t i = 0; i < vecVecFl.size(); ++i) { + result[i].resize(vecVecFl[i].size()); + for (size_t j = 0; j < vecVecFl[i].size(); ++j) { + result[i][j] = vecVecFl[i][j] + 1.0f; + } + } + return result; + }, + {"classWithArrays.fVecVecFl"}); + } + // RNTuple currently fails with the following operation with error + // RNTupleDS: Could not create field with type "std::vector>" for column + // "vecClassWithArrays.fArrFl" + if (GetParam()) { + node = node.Define("vecClassWithArrays_fArrFl_plus_1", + [](const std::vector> &vecArrFl) { + std::vector> result(vecArrFl.size()); + for (size_t i = 0; i < vecArrFl.size(); ++i) { + for (size_t j = 0; j < vecArrFl[i].size(); ++j) { + result[i][j] = vecArrFl[i][j] + 1.0f; + } + } + return result; + }, + {vecClassWithArraysArrFlColName}); + } + // Take each column individually and check the content + auto takeTopArrFl = node.Take>("topArrFl_plus_1"); + // Not supported: std::array with T being a class type as top-level branch + // auto takeTopArrArrFl = node.Take>("topArrArrFl"); + // auto takeTopArrVecFl = node.Take>("topArrVecFl"); + auto takeTopVecFl = node.Take>("topVecFl_plus_1"); + // Not supported: Could not find the real data member '_M_elems[3]' when constructing the branch 'topVecArrFl' + // auto takeTopVecArrFl = node.Take>>("topVecArrFl"); + auto takeTopVecVecFl = node.Take>>("topVecVecFl_plus_1"); + auto takeClassWithArrays = node.Take("classWithArrays_plus_1"); + auto takeClassWithArrays_fObjIndex = 
node.Take("classWithArrays_fObjIndex_plus_1"); + // RNTuple currently fails with the following operation with error + // RNTupleDS: Could not create field with type "std::array" for column "classWithArrays.fArrFl" + ROOT::RDF::RResultPtr>> takeClassWithArrays_fArrFl; + if (GetParam()) { + takeClassWithArrays_fArrFl = node.Take>("classWithArrays_fArrFl_plus_1"); + } + // TODO: array of array is currently not properly handled + // auto takeClassWithArrays_fArrArrFl = node.Take, + // 3>>>("classWithArrays.fArrArrFl[3][3]"); + // TODO: array of vector as a data member is currently not properly handled + // auto takeClassWithArrays_fArrVecFl = + // node.Take>>("classWithArrays.fArrVecFl[3]"); + // RNTuple currently fails with the following operation with error + // RNTupleDS: Could not create field with type "std::vector" for column "classWithArrays.fVecFl" + ROOT::RDF::RResultPtr>> takeClassWithArrays_fVecFl; + if (GetParam()) { + takeClassWithArrays_fVecFl = node.Take>("classWithArrays_fVecFl_plus_1"); + } + // RNTuple currently fails with the following operation with error + // RNTupleDS: Could not create field with type "std::vector>" for column + // "classWithArrays.fVecVecFl" + ROOT::RDF::RResultPtr>>> takeClassWithArrays_fVecVecFl; + if (GetParam()) { + takeClassWithArrays_fVecVecFl = node.Take>>("classWithArrays_fVecVecFl_plus_1"); + } + auto takeVecClassWithArrays = node.Take>("vecClassWithArrays_plus_1"); + // RNTuple currently fails with the following operation with error + // RNTupleDS: Could not create field with type "std::vector>" for column + // "vecClassWithArrays.fArrFl" + ROOT::RDF::RResultPtr>>> takeVecClassWithArrays_fArrFl; + if (GetParam()) { + takeVecClassWithArrays_fArrFl = node.Take>>("vecClassWithArrays_fArrFl_plus_1"); + } + // TODO: array of vector as a data member is currently not properly handled + // auto takeVecClassWithArrays_fArrVecFl = + // node.Take,3>>>("vecClassWithArrays.fArrVecFl[3]"); + // TODO: vector of vector throws 
`std::bad_alloc` currently + // auto takeVecClassWithArrays_fVecFl = + // node.Take>>("vecClassWithArrays.fVecFl"); auto + // takeVecClassWithArrays_fVecVecFl = + // node.Take>>>("vecClassWithArrays.fVecVecFl"); + + auto nEvents = takeTopArrFl->size(); + EXPECT_EQ(nEvents, 1); + + for (decltype(nEvents) i = 0; i < nEvents; ++i) { + check_1d_coll(takeTopArrFl->at(i), data[i].fArrFl); + check_1d_coll(takeTopVecFl->at(i), data[i].fVecFl); + check_2d_coll(takeTopVecVecFl->at(i), data[i].fVecVecFl); + check_class_with_arrays(takeClassWithArrays->at(i), data[i].fClassWithArrays); + EXPECT_EQ(takeClassWithArrays_fObjIndex->at(i), data[i].fClassWithArrays.fObjIndex); + if (GetParam()) { + check_1d_coll(takeClassWithArrays_fArrFl->at(i), data[i].fClassWithArrays.fArrFl); + check_1d_coll(takeClassWithArrays_fVecFl->at(i), data[i].fClassWithArrays.fVecFl); + check_2d_coll(takeClassWithArrays_fVecVecFl->at(i), data[i].fClassWithArrays.fVecVecFl); + } + check_coll_class_with_arrays(takeVecClassWithArrays->at(i), data[i].fVecClassWithArrays); + if (GetParam()) { + std::vector> expectedVecArrFl(data[i].fVecClassWithArrays.size()); + for (size_t j = 0; j < data[i].fVecClassWithArrays.size(); ++j) { + expectedVecArrFl[j] = data[i].fVecClassWithArrays[j].fArrFl; + } + check_2d_coll(takeVecClassWithArrays_fArrFl->at(i), expectedVecArrFl); + } + } +} + +TEST_P(ClassWithSequenceContainersTest, JittedOps) +{ + ROOT::RDataFrame df{fDatasetName, fFileName}; + ROOT::RDF::RNode node = df; + + const std::string classWithArraysArrFlColName = GetParam() ? "classWithArrays.fArrFl[3]" : "classWithArrays.fArrFl"; + const std::string vecClassWithArraysArrFlColName = + GetParam() ? 
"vecClassWithArrays.fArrFl[3]" : "vecClassWithArrays.fArrFl"; + + auto data = generateClassWithSequenceContainersDataPlusOne(); + + // all the next define calls should be written with jitted code in strings instead of C++ lambdas + node = node.Define("topArrFl_plus_1", + R"CODE( + std::array result; + for (size_t i = 0; i < topArrFl.size(); ++i) { + result[i] = topArrFl[i] + 1.0f; + } + return result; + )CODE"); + node = node.Define("topVecFl_plus_1", + R"CODE( + std::vector result(topVecFl.size()); + for (size_t i = 0; i < topVecFl.size(); ++i) { + result[i] = topVecFl[i] + 1.0f; + } + return result; + )CODE"); + node = node.Define("topVecVecFl_plus_1", + R"CODE( + std::vector> result(topVecVecFl.size()); + for (size_t i = 0; i < topVecVecFl.size(); ++i) { + result[i].resize(topVecVecFl[i].size()); + for (size_t j = 0; j < topVecVecFl[i].size(); ++j) { + result[i][j] = topVecVecFl[i][j] + 1.0f; + } + } + return result; + )CODE"); + node = node.Define("classWithArrays_plus_1", + R"CODE( + ClassWithSequenceContainers result = classWithArrays; + result.fObjIndex += 1; + for (auto &val : result.fArrFl) { + val += 1; + } + for (auto &arr : result.fArrArrFl) { + for (auto &val : arr) { + val += 1; + } + } + for (auto &vec : result.fArrVecFl) { + for (auto &val : vec) { + val += 1; + } + } + for (auto &val : result.fVecFl) { + val += 1; + } + for (auto &arr : result.fVecArrFl) { + for (auto &val : arr) { + val += 1; + } + } + for (auto &vec : result.fVecVecFl) { + for (auto &val : vec) { + val += 1; + } + } + return result; + )CODE"); + node = node.Define("vecClassWithArrays_plus_1", + R"CODE( + ROOT::VecOps::RVec result = vecClassWithArrays; + for (auto &obj : result) { + obj.fObjIndex += 1; + for (auto &val : obj.fArrFl) { + val += 1; + } + for (auto &arr : obj.fArrArrFl) { + for (auto &val : arr) { + val += 1; + } + } + for (auto &vvec : obj.fArrVecFl) { + for (auto &val : vvec) { + val += 1; + } + } + for (auto &val : obj.fVecFl) { + val += 1; + } + for (auto &arr : 
obj.fVecArrFl) { + for (auto &val : arr) { + val += 1; + } + } + for (auto &vvec : obj.fVecVecFl) { + for (auto &val : vvec) { + val += 1; + } + } + } + return result; + )CODE"); + // Also create modified values for data member columns + node = node.Define("classWithArrays_fObjIndex_plus_1", "return classWithArrays.fObjIndex + 1;"); + // Using a branch with an invalid C++ name will break the jitted execution + // node = node.Alias("classWithArrays_fArrFl", "classWithArrays.fArrFl[3]"); + // node = node.Define("classWithArrays_fArrFl_plus_1", + // R"CODE( + // std::array result; + // for (size_t i = 0; i < classWithArrays_fArrFl.size(); ++i) { + // result[i] = classWithArrays_fArrFl[i] + 1.0f; + // } + // return result; + // )CODE"); + node = node.Define("classWithArrays_fVecFl_plus_1", + R"CODE( + std::vector result(classWithArrays.fVecFl.size()); + for (size_t i = 0; i < classWithArrays.fVecFl.size(); ++i) { + result[i] = classWithArrays.fVecFl[i] + 1.0f; + } + return result; + )CODE"); + node = node.Define("classWithArrays_fVecVecFl_plus_1", + R"CODE( + std::vector> result(classWithArrays.fVecVecFl.size()); + for (size_t i = 0; i < classWithArrays.fVecVecFl.size(); ++i) { + result[i].resize(classWithArrays.fVecVecFl[i].size()); + for (size_t j = 0; j < classWithArrays.fVecVecFl[i].size(); ++j) { + result[i][j] = classWithArrays.fVecVecFl[i][j] + 1.0f; + } + } + return result; + )CODE"); + // Using a branch with an invalid C++ name will break the jitted execution + // node = node.Alias("vecClassWithArrays_fArrFl", "vecClassWithArrays.fArrFl[3]"); + // node = node.Define("vecClassWithArrays_fArrFl_plus_1", + // R"CODE( + // std::vector> result(vecClassWithArrays_fArrFl.size()); + // for (size_t i = 0; i < vecClassWithArrays_fArrFl.size(); ++i) { + // for (size_t j = 0; j < vecClassWithArrays_fArrFl[i].size(); ++j) { + // result[i][j] = vecClassWithArrays_fArrFl[i][j] + 1.0f; + // } + // } + // return result; + // )CODE"); + + // Take each column individually and 
check the content + auto takeTopArrFl = node.Take>("topArrFl_plus_1"); + // Not supported: std::array with T being a class type as top-level branch + // auto takeTopArrArrFl = node.Take>("topArrArrFl"); + // auto takeTopArrVecFl = node.Take>("topArrVecFl"); + auto takeTopVecFl = node.Take>("topVecFl_plus_1"); + // Not supported: Could not find the real data member '_M_elems[3]' when constructing the branch 'topVecArrFl' + // auto takeTopVecArrFl = node.Take>>("topVecArrFl"); + auto takeTopVecVecFl = node.Take>>("topVecVecFl_plus_1"); + auto takeClassWithArrays = node.Take("classWithArrays_plus_1"); + auto takeClassWithArrays_fObjIndex = node.Take("classWithArrays_fObjIndex_plus_1"); + // Using a branch with an invalid C++ name will break the jitted execution + // auto takeClassWithArrays_fArrFl = node.Take>("classWithArrays_fArrFl_plus_1"); + // TODO: array of array is currently not properly handled + // auto takeClassWithArrays_fArrArrFl = node.Take, + // 3>>>("classWithArrays.fArrArrFl[3][3]"); + // TODO: array of vector as a data member is currently not properly handled + // auto takeClassWithArrays_fArrVecFl = + // node.Take>>("classWithArrays.fArrVecFl[3]"); + auto takeClassWithArrays_fVecFl = node.Take>("classWithArrays_fVecFl_plus_1"); + auto takeClassWithArrays_fVecVecFl = node.Take>>("classWithArrays_fVecVecFl_plus_1"); + auto takeVecClassWithArrays = + node.Take>("vecClassWithArrays_plus_1"); + // Using a branch with an invalid C++ name will break the jitted execution + // auto takeVecClassWithArrays_fArrFl = + // node.Take>>("vecClassWithArrays_fArrFl_plus_1"); + // TODO: array of vector as a data member is currently not properly handled + // auto takeVecClassWithArrays_fArrVecFl = + // node.Take,3>>>("vecClassWithArrays.fArrVecFl[3]"); + // TODO: vector of vector throws `std::bad_alloc` currently + // auto takeVecClassWithArrays_fVecFl = + // node.Take>>("vecClassWithArrays.fVecFl"); auto + // takeVecClassWithArrays_fVecVecFl = + // 
node.Take>>>("vecClassWithArrays.fVecVecFl"); + + auto nEvents = takeTopArrFl->size(); + EXPECT_EQ(nEvents, 1); + + for (decltype(nEvents) i = 0; i < nEvents; ++i) { + check_1d_coll(takeTopArrFl->at(i), data[i].fArrFl); + check_1d_coll(takeTopVecFl->at(i), data[i].fVecFl); + check_2d_coll(takeTopVecVecFl->at(i), data[i].fVecVecFl); + check_class_with_arrays(takeClassWithArrays->at(i), data[i].fClassWithArrays); + EXPECT_EQ(takeClassWithArrays_fObjIndex->at(i), data[i].fClassWithArrays.fObjIndex); + // Using a branch with an invalid C++ name will break the jitted execution + // check_1d_coll(takeClassWithArrays_fArrFl->at(i), data[i].fClassWithArrays.fArrFl); + check_1d_coll(takeClassWithArrays_fVecFl->at(i), data[i].fClassWithArrays.fVecFl); + check_2d_coll(takeClassWithArrays_fVecVecFl->at(i), data[i].fClassWithArrays.fVecVecFl); + check_coll_class_with_arrays(takeVecClassWithArrays->at(i), data[i].fVecClassWithArrays); + // Using a branch with an invalid C++ name will break the jitted execution + // std::vector> expectedVecArrFl(data[i].fVecClassWithArrays.size()); + // for (size_t j = 0; j < data[i].fVecClassWithArrays.size(); ++j) { + // expectedVecArrFl[j] = data[i].fVecClassWithArrays[j].fArrFl; + // } + // check_2d_coll(takeVecClassWithArrays_fArrFl->at(i), expectedVecArrFl); + } +} + +INSTANTIATE_TEST_SUITE_P(Run, ClassWithSequenceContainersTest, ::testing::Values(true, false)); + +int main(int argc, char **argv) +{ + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tree/dataframe/inc/ROOT/RDF/RTreeColumnReader.hxx b/tree/dataframe/inc/ROOT/RDF/RTreeColumnReader.hxx index ddffc47ae9188..b2dc37c8e85a1 100644 --- a/tree/dataframe/inc/ROOT/RDF/RTreeColumnReader.hxx +++ b/tree/dataframe/inc/ROOT/RDF/RTreeColumnReader.hxx @@ -77,7 +77,8 @@ public: enum class ECollectionType { kRVec, kStdArray, - kRVecBool + kRVecBool, + kStdVector }; RTreeUntypedArrayColumnReader(TTreeReader &r, std::string_view colName, std::string_view 
valueTypeName, @@ -99,6 +100,9 @@ private: /// We return a reference to this RVec to clients, to guarantee a stable address and contiguous memory layout. RVec fRVec{}; + /// When the user explicitly requests std::vector, we use this std::vector as a stable storage. + std::vector fStdVector{}; + Long64_t fLastEntry = -1; /// The size of the collection value type. @@ -108,6 +112,10 @@ private: bool fCopyWarningPrinted = false; void *GetImpl(Long64_t entry) override; + + void *ReadStdArray(Long64_t entry); + void *ReadStdVector(Long64_t entry); + void *ReadRVec(Long64_t entry); }; class R__CLING_PTRCHECK(off) RMaskedColumnReader : public ROOT::Detail::RDF::RColumnReaderBase { diff --git a/tree/dataframe/src/RDFUtils.cxx b/tree/dataframe/src/RDFUtils.cxx index 4ead72d824c29..5eca6fa0dc8d1 100644 --- a/tree/dataframe/src/RDFUtils.cxx +++ b/tree/dataframe/src/RDFUtils.cxx @@ -251,9 +251,14 @@ std::string GetLeafTypeName(TLeaf *leaf, const std::string &colName) // this is a fixed-sized array (we do not differentiate between variable- and fixed-sized arrays) colType = ComposeRVecTypeName(colType); } else if (leaf->GetLeafCount() != nullptr && leaf->GetLenStatic() > 1) { - // we do not know how to deal with this branch - throw std::runtime_error("TTree leaf " + colName + - " has both a leaf count and a static length. This is not supported."); + // This case is encountered when a branch is a collection (e.g. std::vector) of a user-defined class which has + // a data member that is a fixed-size array. Here, 'leaf' is said data member, and the user could read it + // partially as std::vector<std::array<T, N>>. We expose it as ROOT::RVec<std::array<T, N>> for consistency with + // other collection types. + // WARNING: Currently this considers only the possibility of a 1-dim array, as TLeaf does not expose information + // to get all dimension lengths of a multi-dim array in a straightforward way (e.g. with one API call).
+ auto valueType = colType; + colType = "ROOT::VecOps::RVecGetLenStatic()) + ">>"; } return colType; diff --git a/tree/dataframe/src/RTTreeDS.cxx b/tree/dataframe/src/RTTreeDS.cxx index 1a401ea495810..7d2ecd8aef91f 100644 --- a/tree/dataframe/src/RTTreeDS.cxx +++ b/tree/dataframe/src/RTTreeDS.cxx @@ -60,6 +60,14 @@ GetCollectionInfo(const std::string &typeName) ROOT::Internal::RDF::RTreeUntypedArrayColumnReader::ECollectionType::kStdArray}; } + // Find TYPE from std::vector + if (auto pos = beginType.find("vector<"); pos != std::string::npos) { + const auto begin = typeName.find_first_of('<', pos) + 1; + const auto end = typeName.find_last_of('>'); + const auto innerTypeName = typeName.substr(begin, end - begin); + return {true, innerTypeName, ROOT::Internal::RDF::RTreeUntypedArrayColumnReader::ECollectionType::kStdVector}; + } + return {false, "", ROOT::Internal::RDF::RTreeUntypedArrayColumnReader::ECollectionType::kRVec}; } diff --git a/tree/dataframe/src/RTreeColumnReader.cxx b/tree/dataframe/src/RTreeColumnReader.cxx index c7f1c929b7ecf..0786bb1a7c6af 100644 --- a/tree/dataframe/src/RTreeColumnReader.cxx +++ b/tree/dataframe/src/RTreeColumnReader.cxx @@ -31,12 +31,77 @@ ROOT::Internal::RDF::RTreeUntypedValueColumnReader::RTreeUntypedValueColumnReade ROOT::Internal::RDF::RTreeUntypedValueColumnReader::~RTreeUntypedValueColumnReader() = default; -void *ROOT::Internal::RDF::RTreeUntypedArrayColumnReader::GetImpl(Long64_t entry) +void *ROOT::Internal::RDF::RTreeUntypedArrayColumnReader::ReadStdArray(Long64_t entry) +{ + if (entry == fLastEntry) + return fRVec.data(); // We return the RVec we already created + + auto &readerArray = *fTreeArray; + // GetSize is called here to trigger actual reading of the branch proxy, which also sets the appropriate read status + const auto readerArraySize = readerArray.GetSize(); + + // The reader could not read an array, signal this back to the node requesting the value + if (R__unlikely(readerArray.GetReadStatus() == 
ROOT::Internal::TTreeReaderValueBase::EReadStatus::kReadError)) + return nullptr; + + // std::array storage should always be contiguous + assert(readerArray.IsContiguous() && "std::array storage should always be contiguous"); + + if (readerArraySize > 0) { + // trigger loading of the contents of the TTreeReaderArray + // the address of the first element in the reader array is not necessarily equal to + // the address returned by the GetAddress method + RVec rvec(readerArray.At(0), readerArraySize); + swap(fRVec, rvec); + } else { + fRVec.clear(); + } + + fLastEntry = entry; + return fRVec.data(); +} + +void *ROOT::Internal::RDF::RTreeUntypedArrayColumnReader::ReadStdVector(Long64_t entry) +{ + if (entry == fLastEntry) + return &fStdVector; // We return the std::vector we already created + + auto &readerArray = *fTreeArray; + // GetSize is called here to trigger actual reading of the branch proxy, which also sets the appropriate read status + const auto readerArraySize = readerArray.GetSize(); + + // The reader could not read an array, signal this back to the node requesting the value + if (R__unlikely(readerArray.GetReadStatus() == ROOT::Internal::TTreeReaderValueBase::EReadStatus::kReadError)) + return nullptr; + + // There is no zero-copy constructor for std::vector, so we need to copy the data in any case. + if (readerArraySize > 0) { + // Caching the value type size since GetValueSize might be expensive. + if (fValueSize == 0) + fValueSize = readerArray.GetValueSize(); + assert(fValueSize > 0 && "Could not retrieve size of collection value type."); + // Always copy element by element: this path does not check whether the reader storage is contiguous.
+ fStdVector.clear(); + fStdVector.reserve(readerArraySize * fValueSize); + for (std::size_t i{0}; i < readerArraySize; i++) { + auto val = readerArray.At(i); + std::copy(val, val + fValueSize, std::back_inserter(fStdVector)); + } + } else { + fStdVector.clear(); + } + + fLastEntry = entry; + return &fStdVector; +} + +void *ROOT::Internal::RDF::RTreeUntypedArrayColumnReader::ReadRVec(Long64_t entry) { if (entry == fLastEntry) - return &fRVec; // we already pointed our fRVec to the right address + return &fRVec; // We return the RVec we already created auto &readerArray = *fTreeArray; + // GetSize is called here to trigger actual reading of the branch proxy, which also sets the appropriate read status const auto readerArraySize = readerArray.GetSize(); // The reader could not read an array, signal this back to the node requesting the value @@ -90,11 +155,20 @@ void *ROOT::Internal::RDF::RTreeUntypedArrayColumnReader::GetImpl(Long64_t entry fRVec.clear(); } } + fLastEntry = entry; + return &fRVec; +} + +void *ROOT::Internal::RDF::RTreeUntypedArrayColumnReader::GetImpl(Long64_t entry) +{ if (fCollectionType == ECollectionType::kStdArray) - return fRVec.data(); - else - return &fRVec; + return ReadStdArray(entry); + + if (fCollectionType == ECollectionType::kStdVector) + return ReadStdVector(entry); + + return ReadRVec(entry); } ROOT::Internal::RDF::RTreeUntypedArrayColumnReader::RTreeUntypedArrayColumnReader(TTreeReader &r, diff --git a/tree/treeplayer/inc/TTreeReaderArray.h b/tree/treeplayer/inc/TTreeReaderArray.h index a799134bad788..e39a5cd79bd33 100644 --- a/tree/treeplayer/inc/TTreeReaderArray.h +++ b/tree/treeplayer/inc/TTreeReaderArray.h @@ -26,6 +26,15 @@ Base class of TTreeReaderArray. 
*/ class TTreeReaderArrayBase : public TTreeReaderValueBase { + struct StreamerElementArrayInfo { + int fArrayNDims{}; // Number of dimensions of the n-dim array + int fArrayCumulativeLength{}; // Total length of the n-dim array as if it were flattened + int fTDataTypeCode{}; // A number representing the value type of the array, queried via TDataType. + std::array fArrayDims{}; // Length of each dimension of the n-dim array. Max 5 dimensions, aligned with + // TStreamerElement::fMaxIndex + }; + StreamerElementArrayInfo FillStreamerElementArrayInfo(TStreamerElement *elem); + public: TTreeReaderArrayBase(TTreeReader *reader, const char *branchname, TDictionary *dict) : TTreeReaderValueBase(reader, branchname, dict) @@ -48,7 +57,8 @@ class TTreeReaderArrayBase : public TTreeReaderValueBase { bool GetBranchAndLeaf(TBranch *&branch, TLeaf *&myLeaf, TDictionary *&branchActualType, bool suppressErrorsForMissingBranch = false); void SetImpl(TBranch *branch, TLeaf *myLeaf); - const char *GetBranchContentDataType(TBranch *branch, TString &contentTypeName, TDictionary *&dict); + const char *GetBranchContentDataType(TBranch *branch, TString &contentTypeName, TDictionary *&dict, + StreamerElementArrayInfo &arrInfo); std::unique_ptr fImpl; // Common interface to collections diff --git a/tree/treeplayer/src/TTreeReaderArray.cxx b/tree/treeplayer/src/TTreeReaderArray.cxx index 4d0f9c85ca54d..ac99b7f3202a6 100644 --- a/tree/treeplayer/src/TTreeReaderArray.cxx +++ b/tree/treeplayer/src/TTreeReaderArray.cxx @@ -107,7 +107,9 @@ bool IsCPContiguous(const TVirtualCollectionProxy &cp) UInt_t GetCPValueSize(const TVirtualCollectionProxy &cp) { - // This works only if the collection proxy value type is a fundamental type + if (auto cl = cp.GetValueClass()) + return cl->Size(); + auto &&eDataType = cp.GetType(); auto *tDataType = TDataType::GetDataType(eDataType); return tDataType ? 
tDataType->Size() : 0; @@ -645,7 +647,8 @@ void ROOT::Internal::TTreeReaderArrayBase::CreateProxy() if (!myLeaf) { TString branchActualTypeName; - const char *nonCollTypeName = GetBranchContentDataType(branch, branchActualTypeName, branchActualType); + StreamerElementArrayInfo arrInfo; + const char *nonCollTypeName = GetBranchContentDataType(branch, branchActualTypeName, branchActualType, arrInfo); if (nonCollTypeName) { Error("TTreeReaderArrayBase::CreateContentProxy()", "The branch %s contains data of type %s, which should be accessed through a TTreeReaderValue< %s >.", @@ -671,7 +674,54 @@ void ROOT::Internal::TTreeReaderArrayBase::CreateProxy() return; } - auto matchingDataType = [](TDictionary *left, TDictionary *right) -> bool { + auto matchingArrayInfo = [](TDictionary *requestedDict, const StreamerElementArrayInfo &info) { + // Support the case of a data member of a class being a std::array (or generally an n-dim fixed size array). + // Here 'requestedDict' (the 'left' argument of matchingDataType) has been requested as a TClass (e.g. std::array<T, N>) and thus + // will not be a TDataType, while the branch's TDictionary (the 'right' argument of matchingDataType) is retrieved as a TDataType (e.g. int), because + // of how the fixed-size array data member is streamed. + // The branch's TDictionary in this case refers to a TBranchElement representing the data member. We already + // gathered the array info ('info') by accessing the TStreamerElement of the TBranchElement, now we need to do the same + // for 'requestedDict', which represents what the user requested. + auto cl = dynamic_cast(requestedDict); + if (!cl) + return false; + auto streamerInfo = cl->GetStreamerInfo(); + if (!streamerInfo) + return false; + + auto streamerElements = streamerInfo->GetElements(); + if (!streamerElements) + return false; + + // This part of the logic currently supports the specific use cases where there is only one streamer element + // corresponding to the array data member that we are trying to partially read into a collection.
Without + // further use cases surfacing, we do not support more than one streamer element + if (streamerElements->GetEntries() > 1) + return false; + + auto streamerElement = dynamic_cast(streamerElements->At(0)); + if (!streamerElement) + return false; + + int dataTypeCode{}; + if (auto *dataType = gROOT->GetType(streamerElement->GetTypeNameBasic())) + dataTypeCode = dataType->GetType(); + + auto checkArrayDims = [&]() { + for (int i = 0; i < info.fArrayNDims; i++) { + if (info.fArrayDims[i] != streamerElement->GetMaxIndex(i)) + return false; + } + return true; + }; + + return streamerElement->GetArrayDim() == info.fArrayNDims && + streamerElement->GetArrayLength() == info.fArrayCumulativeLength && checkArrayDims() && + dataTypeCode == info.fTDataTypeCode; + }; + + auto matchingDataType = [&matchingArrayInfo](TDictionary *left, TDictionary *right, + const StreamerElementArrayInfo &info) -> bool { if (left == right) return true; if (!left || !right) @@ -687,6 +737,22 @@ void ROOT::Internal::TTreeReaderArrayBase::CreateProxy() if ((left_datatype && right_enum && left_datatype->GetType() == right_enum->GetUnderlyingType()) || (right_datatype && left_enum && right_datatype->GetType() == left_enum->GetUnderlyingType())) return true; + + // Allow reading nested std::array data members of top-level std::vector types. The user has requested + // e.g. TTreeReaderArray<std::array<T, N>> and we allow partial reading of the std::vector as a + // collection of std::array<T, N> + if (matchingArrayInfo(left, info)) + return true; + + // Allow reading a std::array data member of a top-level class branch when requesting a TTreeReaderArray + // of the same type as the std::array data type (e.g. branch contains std::array<int, N> and user requests + // TTreeReaderArray<int>). In this case the 'left' dictionary is going to be a TDataType of int.
+ if (left_datatype) { + auto typeCode = left_datatype->GetType(); + if (typeCode > 0 && typeCode == info.fTDataTypeCode) + return true; + } + if (!left_datatype || !right_datatype) return false; auto l = left_datatype->GetType(); @@ -698,7 +764,7 @@ void ROOT::Internal::TTreeReaderArrayBase::CreateProxy() (l == kFloat16_t && r == kFloat_t) || (l == kFloat_t && r == kFloat16_t)); }; - if (!matchingDataType(fDict, branchActualType)) { + if (!matchingDataType(fDict, branchActualType, arrInfo)) { Error("TTreeReaderArrayBase::CreateContentProxy()", "The branch %s contains data of type %s. It cannot be accessed by a TTreeReaderArray<%s>", fBranchName.Data(), branchActualType->GetName(), fDict->GetName()); @@ -919,6 +985,24 @@ void ROOT::Internal::TTreeReaderArrayBase::SetImpl(TBranch *branch, TLeaf *myLea } } +ROOT::Internal::TTreeReaderArrayBase::StreamerElementArrayInfo +ROOT::Internal::TTreeReaderArrayBase::FillStreamerElementArrayInfo(TStreamerElement *element) +{ + StreamerElementArrayInfo arrInfo{}; + if (!element) + return arrInfo; + + arrInfo.fArrayNDims = element->GetArrayDim(); + arrInfo.fArrayCumulativeLength = element->GetArrayLength(); + if (auto *datatype = gROOT->GetType(element->GetTypeNameBasic())) { + arrInfo.fTDataTypeCode = datatype->GetType(); + } + for (int i = 0; i < arrInfo.fArrayNDims; ++i) + arrInfo.fArrayDims[i] = element->GetMaxIndex(i); + + return arrInfo; +} + //////////////////////////////////////////////////////////////////////////////// /// Access a branch's collection content (not the collection itself) /// through a proxy. @@ -930,7 +1014,8 @@ void ROOT::Internal::TTreeReaderArrayBase::SetImpl(TBranch *branch, TLeaf *myLea /// In all other cases, NULL is returned. 
const char *ROOT::Internal::TTreeReaderArrayBase::GetBranchContentDataType(TBranch *branch, TString &contentTypeName, - TDictionary *&dict) + TDictionary *&dict, + StreamerElementArrayInfo &arrInfo) { dict = nullptr; contentTypeName = ""; @@ -999,6 +1084,8 @@ const char *ROOT::Internal::TTreeReaderArrayBase::GetBranchContentDataType(TBran contentTypeName = TDataType::GetTypeName(dtData); return nullptr; } + // Fill information about the data member being an n-dim array + arrInfo = FillStreamerElementArrayInfo(brElement->GetInfo()->GetElement(brElement->GetID())); return nullptr; } else if (ExpectedTypeRet == 1) { int brID = brElement->GetID(); @@ -1026,7 +1113,8 @@ const char *ROOT::Internal::TTreeReaderArrayBase::GetBranchContentDataType(TBran if (id >= 0) { TStreamerElement *element = (TStreamerElement *)streamerInfo->GetElements()->At(id); - + // Fill information about the data member being an n-dim array + arrInfo = FillStreamerElementArrayInfo(element); if (element->IsA() == TStreamerSTL::Class()) { TClass *myClass = brElement->GetCurrentClass(); if (!myClass) {