diff --git a/roottest/root/dataframe/CMakeLists.txt b/roottest/root/dataframe/CMakeLists.txt index 1322a8551da94..80265cd8aabd5 100644 --- a/roottest/root/dataframe/CMakeLists.txt +++ b/roottest/root/dataframe/CMakeLists.txt @@ -330,3 +330,26 @@ ROOTTEST_ADD_TEST(test_snapshot_copyaddresses MACRO test_snapshot_copyaddresses.C+) ROOT_ADD_GTEST(test_norootextension test_norootextension.cxx LIBRARIES ROOT::ROOTDataFrame) + +# In 6.38 on Windows the following test would not be able to compile because it can't find the ROOT/TestSupport.hxx +# header even though the dependency is specified in ROOTTEST_GENERATE_EXECUTABLE. Given this test properly compiles in +# later versions of ROOT on Windows, we disable it for the 6.38 release. +if(NOT MSVC) + ROOTTEST_GENERATE_DICTIONARY( + ClassWithSequenceContainersDict + ${CMAKE_CURRENT_SOURCE_DIR}/ClassWithSequenceContainers.hxx + LINKDEF ${CMAKE_CURRENT_SOURCE_DIR}/ClassWithSequenceContainersLinkDef.hxx + FIXTURES_SETUP ClassWithSequenceContainersDict_setup + ) + ROOTTEST_GENERATE_EXECUTABLE( + dataframe_sequence_containers + dataframe_sequence_containers.cxx ClassWithSequenceContainersDict.cxx + LIBRARIES Core RIO Tree ROOTDataFrame GTest::gtest GTest::gtest_main GTest::gmock GTest::gmock_main ROOT::TestSupport + FIXTURES_REQUIRED ClassWithSequenceContainersDict_setup + FIXTURES_SETUP dataframe_sequence_containers_setup + ) + ROOTTEST_ADD_TEST(dataframe_sequence_containers + EXEC ${CMAKE_CURRENT_BINARY_DIR}/dataframe_sequence_containers + FIXTURES_REQUIRED dataframe_sequence_containers_setup + ) +endif() diff --git a/roottest/root/dataframe/ClassWithSequenceContainers.hxx b/roottest/root/dataframe/ClassWithSequenceContainers.hxx new file mode 100644 index 0000000000000..2f4ef62976cf4 --- /dev/null +++ b/roottest/root/dataframe/ClassWithSequenceContainers.hxx @@ -0,0 +1,39 @@ +#ifndef ROOT_DATAFRAME_TEST_ARRRAYCOMBINATIONS +#define ROOT_DATAFRAME_TEST_ARRRAYCOMBINATIONS + +#include +#include + +#include + +struct ClassWithSequenceContainers { + unsigned int fObjIndex{}; + std::array fArrFl{}; + std::array, 3> fArrArrFl{}; + std::array, 3> fArrVecFl{}; + + std::vector fVecFl{}; + std::vector> fVecArrFl{}; //! Not supported for TTree: could not find the real data member + //! '_M_elems[3]' when constructing the branch 'fVecArrFl' + std::vector> fVecVecFl{}; + + // For ROOT I/O + ClassWithSequenceContainers() = default; + + ClassWithSequenceContainers(unsigned int objIndex, std::array a1, std::array, 3> a2, + std::array, 3> a3, std::vector a4, + std::vector> a5, std::vector> a6) + : fObjIndex(objIndex), + fArrFl(std::move(a1)), + fArrArrFl(std::move(a2)), + fArrVecFl(std::move(a3)), + fVecFl(std::move(a4)), + fVecArrFl(std::move(a5)), + fVecVecFl(std::move(a6)) + { + } + + ClassDefNV(ClassWithSequenceContainers, 1) +}; + +#endif diff --git a/roottest/root/dataframe/ClassWithSequenceContainersLinkDef.hxx b/roottest/root/dataframe/ClassWithSequenceContainersLinkDef.hxx new file mode 100644 index 0000000000000..25a4a559190f0 --- /dev/null +++ b/roottest/root/dataframe/ClassWithSequenceContainersLinkDef.hxx @@ -0,0 +1,6 @@ +#ifdef __CLING__ + +#pragma link C++ class ClassWithSequenceContainers+; +#pragma link C++ class std::vector+; + +#endif diff --git a/roottest/root/dataframe/dataframe_sequence_containers.cxx b/roottest/root/dataframe/dataframe_sequence_containers.cxx new file mode 100644 index 0000000000000..117a0e738002c --- /dev/null +++ b/roottest/root/dataframe/dataframe_sequence_containers.cxx @@ -0,0 +1,916 @@ +#include + +#include + +#include + +#include +#include + +#include +#include + +#include +#include + +#include "ClassWithSequenceContainers.hxx" + +using ROOT::RNTupleModel; +using ROOT::RNTupleWriter; + +struct ClassWithSequenceContainersData { + std::array fArrFl{}; + std::array, 3> fArrArrFl{}; + std::array, 3> fArrVecFl{}; + + std::vector fVecFl{}; + std::vector> fVecArrFl{}; + std::vector> fVecVecFl{}; + + ClassWithSequenceContainers fClassWithArrays{}; + std::vector fVecClassWithArrays{}; +}; + +std::vector generateClassWithSequenceContainersData() +{ + // ClassWithSequenceContainers members + std::array topArrFl{1.1, 2.2, 3.3}; + std::array, 3> topArrArrFl{{{1.1, 2.2, 3.3}, {4.4, 5.5, 6.6}, {7.7, 8.8, 9.9}}}; + std::array, 3> topArrVecFl{{{11.11}, {12.12, 13.13}, {14.14, 15.15, 16.16}}}; + + std::vector topVecFl{17.17, 18.18, 19.19}; + std::vector> topVecArrFl{{{21.21, 22.22, 23.23}, + {24.24, 25.25, 26.26}, + {27.27, 28.28, 29.29}, + {31.31, 32.32, 33.33}, + {34.34, 35.35, 36.36}, + {37.37, 38.38, 39.39}}}; + std::vector> topVecVecFl{{}, {41.41}, {42.42, 43.43}, {44.44, 45.45, 46.46}}; + + // Class object + ClassWithSequenceContainers classWithArrays(0, topArrFl, topArrArrFl, topArrVecFl, topVecFl, topVecArrFl, + topVecVecFl); + + // std::vector of class objects + std::vector vecClassWithArrays; + vecClassWithArrays.reserve(5); + for (int i = 1; i < 6; i++) { + vecClassWithArrays.emplace_back(i, topArrFl, topArrArrFl, topArrVecFl, topVecFl, topVecArrFl, topVecVecFl); + } + return std::vector{ClassWithSequenceContainersData{topArrFl, topArrArrFl, topArrVecFl, topVecFl, topVecArrFl, + topVecVecFl, classWithArrays, vecClassWithArrays}}; +} + +std::vector generateClassWithSequenceContainersDataPlusOne() +{ + auto data = generateClassWithSequenceContainersData(); + for (auto &entry : data) { + for (auto &val : entry.fArrFl) { + val += 1; + } + for (auto &arr : entry.fArrArrFl) { + for (auto &val : arr) { + val += 1; + } + } + for (auto &vec : entry.fArrVecFl) { + for (auto &val : vec) { + val += 1; + } + } + for (auto &val : entry.fVecFl) { + val += 1; + } + for (auto &arr : entry.fVecArrFl) { + for (auto &val : arr) { + val += 1; + } + } + for (auto &vec : entry.fVecVecFl) { + for (auto &val : vec) { + val += 1; + } + } + auto classWithArraysPlusOne = [](ClassWithSequenceContainers &obj) { + obj.fObjIndex += 1; + for (auto &val : obj.fArrFl) { + val += 1; + } + for (auto &arr : obj.fArrArrFl) { + for (auto &val : arr) { + val += 1; + } + } + for (auto &vec : obj.fArrVecFl) { + for (auto &val : vec) { + val += 1; + } + } + for (auto &val : obj.fVecFl) { + val += 1; + } + for (auto &arr : obj.fVecArrFl) { + for (auto &val : arr) { + val += 1; + } + } + for (auto &vec : obj.fVecVecFl) { + for (auto &val : vec) { + val += 1; + } + } + }; + classWithArraysPlusOne(entry.fClassWithArrays); + for (auto &obj : entry.fVecClassWithArrays) { + classWithArraysPlusOne(obj); + } + } + return data; +} + +class ClassWithSequenceContainersTest : public ::testing::TestWithParam { +protected: + constexpr static const char *fFileName = "root_dataframe_sequence_containers_ClassWithSequenceContainers.root"; + constexpr static const char *fDatasetName = "root_dataframe_sequence_containers_ClassWithSequenceContainers"; + + void WriteTTree() + { + auto f = std::make_unique(fFileName, "RECREATE"); + auto t = std::make_unique(fDatasetName, fDatasetName); + + auto data = generateClassWithSequenceContainersData(); + + // Branches + t->Branch("topArrFl", &data[0].fArrFl); + // Not supported: std::array with T being a class type as top-level branch + // t->Branch("topArrArrFl", &data.topArrArrFl); + // t->Branch("topArrVecFl", &data.topArrVecFl); + t->Branch("topVecFl", &data[0].fVecFl); + // Not supported: Could not find the real data member '_M_elems[3]' when constructing the branch 'topVecArrFl' + // t->Branch("topVecArrFl", &data.fVecArrFl); + t->Branch("topVecVecFl", &data[0].fVecVecFl); + t->Branch("classWithArrays", &data[0].fClassWithArrays); + t->Branch("vecClassWithArrays", &data[0].fVecClassWithArrays); + + t->Fill(); + + f->Write(); + } + + void WriteRNTuple() + { + auto model = RNTupleModel::Create(); + auto topArrFl = model->MakeField>("topArrFl"); + auto topArrArrFl = model->MakeField, 3>>("topArrArrFl"); + auto topArrVecFl = model->MakeField, 3>>("topArrVecFl"); + + auto topVecFl = model->MakeField>("topVecFl"); + auto topVecArrFl = model->MakeField>>("topVecArrFl"); + auto topVecVecFl = model->MakeField>>("topVecVecFl"); + + auto classWithArrays = model->MakeField("classWithArrays"); + auto vecClassWithArrays = model->MakeField>("vecClassWithArrays"); + auto ntuple = RNTupleWriter::Recreate(std::move(model), fDatasetName, fFileName); + auto data = generateClassWithSequenceContainersData(); + for (const auto &entry : data) { + *topArrFl = entry.fArrFl; + *topArrArrFl = entry.fArrArrFl; + *topArrVecFl = entry.fArrVecFl; + *topVecFl = entry.fVecFl; + *topVecArrFl = entry.fVecArrFl; + *topVecVecFl = entry.fVecVecFl; + *classWithArrays = entry.fClassWithArrays; + *vecClassWithArrays = entry.fVecClassWithArrays; + ntuple->Fill(); + } + } + + ClassWithSequenceContainersTest() + { + if (GetParam()) { + WriteTTree(); + } else { + WriteRNTuple(); + } + } + + ~ClassWithSequenceContainersTest() override { std::remove(fFileName); } +}; + +template +void check_1d_coll(const T &coll1, const U &coll2) +{ + ASSERT_EQ(coll1.size(), coll2.size()); + for (size_t i = 0; i < coll1.size(); ++i) { + EXPECT_FLOAT_EQ(coll1[i], coll2[i]) << " at index " << i; + } +} + +template +void check_2d_coll(const T &coll1, const U &coll2) +{ + ASSERT_EQ(coll1.size(), coll2.size()); + for (size_t i = 0; i < coll1.size(); ++i) { + check_1d_coll(coll1[i], coll2[i]); + } +} + +void check_class_with_arrays(const ClassWithSequenceContainers &obj1, const ClassWithSequenceContainers &obj2) +{ + EXPECT_EQ(obj1.fObjIndex, obj2.fObjIndex); + check_1d_coll(obj1.fArrFl, obj2.fArrFl); + check_2d_coll(obj1.fArrArrFl, obj2.fArrArrFl); + check_2d_coll(obj1.fArrVecFl, obj2.fArrVecFl); + check_1d_coll(obj1.fVecFl, obj2.fVecFl); + // fVecArrFl is not supported: Could not find the real data member '_M_elems[3]' when constructing the branch + // 'fVecArrFl' + // check_2d_coll(obj1.fVecArrFl, obj2.fVecArrFl); + check_2d_coll(obj1.fVecVecFl, obj2.fVecVecFl); +} + +template +void check_coll_class_with_arrays(const T &coll1, const U &coll2) +{ + ASSERT_EQ(coll1.size(), coll2.size()); + for (size_t i = 0; i < coll1.size(); ++i) { + check_class_with_arrays(coll1[i], coll2[i]); + } +} + +TEST_P(ClassWithSequenceContainersTest, ExpectedTypes) +{ + // TODO: This test currently assumes spellings of column types when the + // data format is TTree + if (!GetParam()) + return; // The expected type names are different between the TTree and RNTuple data sources + + ROOT::RDataFrame df{fDatasetName, fFileName}; + + const std::unordered_map expectedColTypes{ + {"topArrFl", "ROOT::VecOps::RVec"}, + // Not supported: std::array with T being a class type as top-level branch + // {"topArrArrFl", "ROOT::VecOps::RVec>"}, + // {"topArrVecFl", "ROOT::VecOps::RVec>"}, + {"topVecFl", "ROOT::VecOps::RVec"}, + // Not supported: Could not find the real data member '_M_elems[3]' when constructing the branch 'topVecArrFl' + // {"topVecArrFl", "ROOT::VecOps::RVec>"}, + {"topVecVecFl", "ROOT::VecOps::RVec>"}, + {"classWithArrays", "ClassWithSequenceContainers"}, + {"classWithArrays.fObjIndex", "UInt_t"}, + {"classWithArrays.fArrFl[3]", "ROOT::VecOps::RVec"}, + // TODO: array of array is currently not properly handled + // {"classWithArrays.fArrArrFl[3][3]", "ROOT::VecOps::RVec>"}, + {"classWithArrays.fArrVecFl[3]", "ROOT::VecOps::RVec>"}, + {"classWithArrays.fVecFl", "ROOT::VecOps::RVec"}, + {"classWithArrays.fVecVecFl", "ROOT::VecOps::RVec>"}, + {"vecClassWithArrays", "ROOT::VecOps::RVec"}, + {"vecClassWithArrays.fArrFl[3]", "ROOT::VecOps::RVec>"}, + {"vecClassWithArrays.fArrVecFl[3]", "ROOT::VecOps::RVec, 3>>"}, + {"vecClassWithArrays.fVecFl", "ROOT::VecOps::RVec>"}, + {"vecClassWithArrays.fVecVecFl", "ROOT::VecOps::RVec >>"}, + }; + for (const auto &[colName, expectedType] : expectedColTypes) { + EXPECT_EQ(df.GetColumnType(colName), expectedType) << " for column " << colName; + } +} + +TEST_P(ClassWithSequenceContainersTest, TakeExpectedTypes) +{ +#ifndef NDEBUG + // The following warning is only for debugging purposes with the TTree data source. It happens in this test + // because we are reading a class partially, i.e. the branch type is std::vector but + // we are only reading the data member fArrFl with the column name "classWithArrays.fArrFl[3]" as an + // RVec>. + ROOT::TestSupport::CheckDiagsRAII diagRAII; + diagRAII.optionalDiag( + kWarning, "RTreeColumnReader::Get", + "hangs from a non-split branch. A copy is being performed in order to properly read the content.", false); +#endif + + ROOT::RDataFrame df{fDatasetName, fFileName}; + + auto data = generateClassWithSequenceContainersData(); + + const std::string classWithArraysArrFlColName = GetParam() ? "classWithArrays.fArrFl[3]" : "classWithArrays.fArrFl"; + const std::string vecClassWithArraysArrFlColName = + GetParam() ? "vecClassWithArrays.fArrFl[3]" : "vecClassWithArrays.fArrFl"; + + // Take each column individually and check the content + // In this test, use the types as expected by the test "ExpectedTypes" + auto takeTopArrFl = df.Take>("topArrFl"); + // Not supported: std::array with T being a class type as top-level branch + // auto takeTopArrArrFl = df.Take>("topArrArrFl"); + // auto takeTopArrVecFl = df.Take>("topArrVecFl"); + auto takeTopVecFl = df.Take>("topVecFl"); + // Not supported: Could not find the real data member '_M_elems[3]' when constructing the branch 'topVecArrFl' + // auto takeTopVecArrFl = df.Take>>("topVecArrFl"); + auto takeTopVecVecFl = df.Take>>("topVecVecFl"); + auto takeClassWithArrays = df.Take("classWithArrays"); + auto takeClassWithArrays_fObjIndex = df.Take("classWithArrays.fObjIndex"); + auto takeClassWithArrays_fArrFl = df.Take>(classWithArraysArrFlColName); + // TODO: array of array is currently not properly handled + // auto takeClassWithArrays_fArrArrFl = df.Take, + // 3>>>("classWithArrays.fArrArrFl[3][3]"); + // TODO: array of vector as a data member is currently not properly handled + // auto takeClassWithArrays_fArrVecFl = + // df.Take>>("classWithArrays.fArrVecFl[3]"); + auto takeClassWithArrays_fVecFl = df.Take>("classWithArrays.fVecFl"); + // RNTuple currently fails with the following operation with error + // RNTupleDS: Could not create field with type "ROOT::VecOps::RVec>" for column + // "classWithArrays.fVecVecFl" + ROOT::RDF::RResultPtr>>> takeClassWithArrays_fVecVecFl; + if (GetParam()) { + takeClassWithArrays_fVecVecFl = df.Take>>("classWithArrays.fVecVecFl"); + } + auto takeVecClassWithArrays = df.Take>("vecClassWithArrays"); + // RNTuple currently fails with the following operation with error + // RNTupleDS: Could not create field with type "std::array" for column "classWithArrays.fArrFl" + ROOT::RDF::RResultPtr>>> takeVecClassWithArrays_fArrFl; + if (GetParam()) { + takeVecClassWithArrays_fArrFl = + df.Take>>(vecClassWithArraysArrFlColName); + } + // TODO: array of vector as a data member is currently not properly handled + // auto takeVecClassWithArrays_fArrVecFl = + // df.Take,3>>>("vecClassWithArrays.fArrVecFl[3]"); + // TODO: vector of vector throws `std::bad_alloc` currently + // auto takeVecClassWithArrays_fVecFl = df.Take>>("vecClassWithArrays.fVecFl"); + // auto takeVecClassWithArrays_fVecVecFl = + // df.Take>>>("vecClassWithArrays.fVecVecFl"); + + auto nEvents = takeTopArrFl->size(); + EXPECT_EQ(nEvents, 1); + + for (decltype(nEvents) i = 0; i < nEvents; ++i) { + check_1d_coll(takeTopArrFl->at(i), data[i].fArrFl); + check_1d_coll(takeTopVecFl->at(i), data[i].fVecFl); + check_2d_coll(takeTopVecVecFl->at(i), data[i].fVecVecFl); + check_class_with_arrays(takeClassWithArrays->at(i), data[i].fClassWithArrays); + EXPECT_EQ(takeClassWithArrays_fObjIndex->at(i), data[i].fClassWithArrays.fObjIndex); + check_1d_coll(takeClassWithArrays_fArrFl->at(i), data[i].fClassWithArrays.fArrFl); + check_1d_coll(takeClassWithArrays_fVecFl->at(i), data[i].fClassWithArrays.fVecFl); + if (GetParam()) { + check_2d_coll(takeClassWithArrays_fVecVecFl->at(i), data[i].fClassWithArrays.fVecVecFl); + } + check_coll_class_with_arrays(takeVecClassWithArrays->at(i), data[i].fVecClassWithArrays); + if (GetParam()) { + std::vector> expectedVecArrFl(data[i].fVecClassWithArrays.size()); + for (size_t j = 0; j < data[i].fVecClassWithArrays.size(); ++j) { + expectedVecArrFl[j] = data[i].fVecClassWithArrays[j].fArrFl; + } + check_2d_coll(takeVecClassWithArrays_fArrFl->at(i), expectedVecArrFl); + } + } +} + +TEST_P(ClassWithSequenceContainersTest, TakeOriginalTypes) +{ + ROOT::RDataFrame df{fDatasetName, fFileName}; + + const std::string classWithArraysArrFlColName = GetParam() ? "classWithArrays.fArrFl[3]" : "classWithArrays.fArrFl"; + const std::string vecClassWithArraysArrFlColName = + GetParam() ? "vecClassWithArrays.fArrFl[3]" : "vecClassWithArrays.fArrFl"; + + auto data = generateClassWithSequenceContainersData(); + + // Take each column individually and check the content + // In this test, call Take using only original types as written in the EDM + auto takeTopArrFl = df.Take>("topArrFl"); + // Not supported: std::array with T being a class type as top-level branch + // auto takeTopArrArrFl = df.Take>("topArrArrFl"); + // auto takeTopArrVecFl = df.Take>("topArrVecFl"); + auto takeTopVecFl = df.Take>("topVecFl"); + // Not supported: Could not find the real data member '_M_elems[3]' when constructing the branch 'topVecArrFl' + // auto takeTopVecArrFl = df.Take>>("topVecArrFl"); + auto takeTopVecVecFl = df.Take>>("topVecVecFl"); + auto takeClassWithArrays = df.Take("classWithArrays"); + auto takeClassWithArrays_fObjIndex = df.Take("classWithArrays.fObjIndex"); + // RNTuple currently fails with the following operation with error + // RNTupleDS: Could not create field with type "std::array" for column "classWithArrays.fArrFl" + ROOT::RDF::RResultPtr>> takeClassWithArrays_fArrFl; + if (GetParam()) { + takeClassWithArrays_fArrFl = df.Take>(classWithArraysArrFlColName); + } + // TODO: array of array is currently not properly handled + // auto takeClassWithArrays_fArrArrFl = df.Take, + // 3>>>("classWithArrays.fArrArrFl[3][3]"); + // TODO: array of vector as a data member is currently not properly handled + // auto takeClassWithArrays_fArrVecFl = + // df.Take>>("classWithArrays.fArrVecFl[3]"); + // RNTuple currently fails with the following operation with error + // RNTupleDS: Could not create field with type "std::vector" for column "classWithArrays.fVecFl" + ROOT::RDF::RResultPtr>> takeClassWithArrays_fVecFl; + if (GetParam()) { + takeClassWithArrays_fVecFl = df.Take>("classWithArrays.fVecFl"); + } + // RNTuple currently fails with the following operation with error + // RNTupleDS: Could not create field with type "std::vector>" for column + // "classWithArrays.fVecVecFl" + ROOT::RDF::RResultPtr>>> takeClassWithArrays_fVecVecFl; + if (GetParam()) { + takeClassWithArrays_fVecVecFl = df.Take>>("classWithArrays.fVecVecFl"); + } + auto takeVecClassWithArrays = df.Take>("vecClassWithArrays"); + // RNTuple currently fails with the following operation with error + // RNTupleDS: Could not create field with type "std::vector>" for column + // "vecClassWithArrays.fArrFl" + ROOT::RDF::RResultPtr>>> takeVecClassWithArrays_fArrFl; + if (GetParam()) { + takeVecClassWithArrays_fArrFl = df.Take>>(vecClassWithArraysArrFlColName); + } + // TODO: array of vector as a data member is currently not properly handled + // auto takeVecClassWithArrays_fArrVecFl = + // df.Take,3>>>("vecClassWithArrays.fArrVecFl[3]"); + // TODO: vector of vector throws `std::bad_alloc` currently + // auto takeVecClassWithArrays_fVecFl = df.Take>>("vecClassWithArrays.fVecFl"); + // auto takeVecClassWithArrays_fVecVecFl = + // df.Take>>>("vecClassWithArrays.fVecVecFl"); + + auto nEvents = takeTopArrFl->size(); + EXPECT_EQ(nEvents, 1); + + for (decltype(nEvents) i = 0; i < nEvents; ++i) { + check_1d_coll(takeTopArrFl->at(i), data[i].fArrFl); + check_1d_coll(takeTopVecFl->at(i), data[i].fVecFl); + check_2d_coll(takeTopVecVecFl->at(i), data[i].fVecVecFl); + check_class_with_arrays(takeClassWithArrays->at(i), data[i].fClassWithArrays); + EXPECT_EQ(takeClassWithArrays_fObjIndex->at(i), data[i].fClassWithArrays.fObjIndex); + if (GetParam()) { + check_1d_coll(takeClassWithArrays_fArrFl->at(i), data[i].fClassWithArrays.fArrFl); + check_1d_coll(takeClassWithArrays_fVecFl->at(i), data[i].fClassWithArrays.fVecFl); + check_2d_coll(takeClassWithArrays_fVecVecFl->at(i), data[i].fClassWithArrays.fVecVecFl); + } + check_coll_class_with_arrays(takeVecClassWithArrays->at(i), data[i].fVecClassWithArrays); + if (GetParam()) { + std::vector> expectedVecArrFl(data[i].fVecClassWithArrays.size()); + for (size_t j = 0; j < data[i].fVecClassWithArrays.size(); ++j) { + expectedVecArrFl[j] = data[i].fVecClassWithArrays[j].fArrFl; + } + check_2d_coll(takeVecClassWithArrays_fArrFl->at(i), expectedVecArrFl); + } + } +} + +TEST_P(ClassWithSequenceContainersTest, TemplatedOps) +{ + ROOT::RDataFrame df{fDatasetName, fFileName}; + ROOT::RDF::RNode node = df; + + const std::string classWithArraysArrFlColName = GetParam() ? "classWithArrays.fArrFl[3]" : "classWithArrays.fArrFl"; + const std::string vecClassWithArraysArrFlColName = + GetParam() ? "vecClassWithArrays.fArrFl[3]" : "vecClassWithArrays.fArrFl"; + + auto data = generateClassWithSequenceContainersDataPlusOne(); + + node = node.Define("topArrFl_plus_1", + [](const std::array &arr) { + std::array result; + for (size_t i = 0; i < arr.size(); ++i) { + result[i] = arr[i] + 1.0f; + } + return result; + }, + {"topArrFl"}); + node = node.Define("topVecFl_plus_1", + [](const std::vector &vec) { + std::vector result(vec.size()); + for (size_t i = 0; i < vec.size(); ++i) { + result[i] = vec[i] + 1.0f; + } + return result; + }, + {"topVecFl"}); + node = node.Define("topVecVecFl_plus_1", + [](const std::vector> &vecvec) { + std::vector> result(vecvec.size()); + for (size_t i = 0; i < vecvec.size(); ++i) { + result[i].resize(vecvec[i].size()); + for (size_t j = 0; j < vecvec[i].size(); ++j) { + result[i][j] = vecvec[i][j] + 1.0f; + } + } + return result; + }, + {"topVecVecFl"}); + node = node.Define("classWithArrays_plus_1", + [](const ClassWithSequenceContainers &obj) { + ClassWithSequenceContainers result = obj; + result.fObjIndex += 1; + for (auto &val : result.fArrFl) { + val += 1; + } + for (auto &arr : result.fArrArrFl) { + for (auto &val : arr) { + val += 1; + } + } + for (auto &vec : result.fArrVecFl) { + for (auto &val : vec) { + val += 1; + } + } + for (auto &val : result.fVecFl) { + val += 1; + } + for (auto &arr : result.fVecArrFl) { + for (auto &val : arr) { + val += 1; + } + } + for (auto &vec : result.fVecVecFl) { + for (auto &val : vec) { + val += 1; + } + } + return result; + }, + {"classWithArrays"}); + node = node.Define("vecClassWithArrays_plus_1", + [](const std::vector &vec) { + std::vector result = vec; + for (auto &obj : result) { + obj.fObjIndex += 1; + for (auto &val : obj.fArrFl) { + val += 1; + } + for (auto &arr : obj.fArrArrFl) { + for (auto &val : arr) { + val += 1; + } + } + for (auto &vvec : obj.fArrVecFl) { + for (auto &val : vvec) { + val += 1; + } + } + for (auto &val : obj.fVecFl) { + val += 1; + } + for (auto &arr : obj.fVecArrFl) { + for (auto &val : arr) { + val += 1; + } + } + for (auto &vvec : obj.fVecVecFl) { + for (auto &val : vvec) { + val += 1; + } + } + } + return result; + }, + {"vecClassWithArrays"}); + // Also create modified values for data member columns + node = node.Define("classWithArrays_fObjIndex_plus_1", [](unsigned int objIndex) { return objIndex + 1; }, + {"classWithArrays.fObjIndex"}); + // RNTuple currently fails with the following operation with error + // RNTupleDS: Could not create field with type "std::array" for column "classWithArrays.fArrFl" + if (GetParam()) { + node = node.Define("classWithArrays_fArrFl_plus_1", + [](const std::array &arr) { + std::array result; + for (size_t i = 0; i < arr.size(); ++i) { + result[i] = arr[i] + 1.0f; + } + return result; + }, + {classWithArraysArrFlColName}); + } + // RNTuple currently fails with the following operation with error + // RNTupleDS: Could not create field with type "std::vector" for column "classWithArrays.fVecFl" + if (GetParam()) { + node = node.Define("classWithArrays_fVecFl_plus_1", + [](const std::vector &vec) { + std::vector result(vec.size()); + for (size_t i = 0; i < vec.size(); ++i) { + result[i] = vec[i] + 1.0f; + } + return result; + }, + {"classWithArrays.fVecFl"}); + } + // RNTuple currently fails with the following operation with error + // RNTupleDS: Could not create field with type "std::vector>" for column + // "classWithArrays.fVecVecFl" + if (GetParam()) { + node = node.Define("classWithArrays_fVecVecFl_plus_1", + [](const std::vector> &vecVecFl) { + std::vector> result(vecVecFl.size()); + for (size_t i = 0; i < vecVecFl.size(); ++i) { + result[i].resize(vecVecFl[i].size()); + for (size_t j = 0; j < vecVecFl[i].size(); ++j) { + result[i][j] = vecVecFl[i][j] + 1.0f; + } + } + return result; + }, + {"classWithArrays.fVecVecFl"}); + } + // RNTuple currently fails with the following operation with error + // RNTupleDS: Could not create field with type "std::vector>" for column + // "vecClassWithArrays.fArrFl" + if (GetParam()) { + node = node.Define("vecClassWithArrays_fArrFl_plus_1", + [](const std::vector> &vecArrFl) { + std::vector> result(vecArrFl.size()); + for (size_t i = 0; i < vecArrFl.size(); ++i) { + for (size_t j = 0; j < vecArrFl[i].size(); ++j) { + result[i][j] = vecArrFl[i][j] + 1.0f; + } + } + return result; + }, + {vecClassWithArraysArrFlColName}); + } + // Take each column individually and check the content + auto takeTopArrFl = node.Take>("topArrFl_plus_1"); + // Not supported: std::array with T being a class type as top-level branch + // auto takeTopArrArrFl = node.Take>("topArrArrFl"); + // auto takeTopArrVecFl = node.Take>("topArrVecFl"); + auto takeTopVecFl = node.Take>("topVecFl_plus_1"); + // Not supported: Could not find the real data member '_M_elems[3]' when constructing the branch 'topVecArrFl' + // auto takeTopVecArrFl = node.Take>>("topVecArrFl"); + auto takeTopVecVecFl = node.Take>>("topVecVecFl_plus_1"); + auto takeClassWithArrays = node.Take("classWithArrays_plus_1"); + auto takeClassWithArrays_fObjIndex = node.Take("classWithArrays_fObjIndex_plus_1"); + // RNTuple currently fails with the following operation with error + // RNTupleDS: Could not create field with type "std::array" for column "classWithArrays.fArrFl" + ROOT::RDF::RResultPtr>> takeClassWithArrays_fArrFl; + if (GetParam()) { + takeClassWithArrays_fArrFl = node.Take>("classWithArrays_fArrFl_plus_1"); + } + // TODO: array of array is currently not properly handled + // auto takeClassWithArrays_fArrArrFl = node.Take, + // 3>>>("classWithArrays.fArrArrFl[3][3]"); + // TODO: array of vector as a data member is currently not properly handled + // auto takeClassWithArrays_fArrVecFl = + // node.Take>>("classWithArrays.fArrVecFl[3]"); + // RNTuple currently fails with the following operation with error + // RNTupleDS: Could not create field with type "std::vector" for column "classWithArrays.fVecFl" + ROOT::RDF::RResultPtr>> takeClassWithArrays_fVecFl; + if (GetParam()) { + takeClassWithArrays_fVecFl = node.Take>("classWithArrays_fVecFl_plus_1"); + } + // RNTuple currently fails with the following operation with error + // RNTupleDS: Could not create field with type "std::vector>" for column + // "classWithArrays.fVecVecFl" + ROOT::RDF::RResultPtr>>> takeClassWithArrays_fVecVecFl; + if (GetParam()) { + takeClassWithArrays_fVecVecFl = node.Take>>("classWithArrays_fVecVecFl_plus_1"); + } + auto takeVecClassWithArrays = node.Take>("vecClassWithArrays_plus_1"); + // RNTuple currently fails with the following operation with error + // RNTupleDS: Could not create field with type "std::vector>" for column + // "vecClassWithArrays.fArrFl" + ROOT::RDF::RResultPtr>>> takeVecClassWithArrays_fArrFl; + if (GetParam()) { + takeVecClassWithArrays_fArrFl = node.Take>>("vecClassWithArrays_fArrFl_plus_1"); + } + // TODO: array of vector as a data member is currently not properly handled + // auto takeVecClassWithArrays_fArrVecFl = + // node.Take,3>>>("vecClassWithArrays.fArrVecFl[3]"); + // TODO: vector of vector throws `std::bad_alloc` currently + // auto takeVecClassWithArrays_fVecFl = + // node.Take>>("vecClassWithArrays.fVecFl"); auto + // takeVecClassWithArrays_fVecVecFl = + // node.Take>>>("vecClassWithArrays.fVecVecFl"); + + auto nEvents = takeTopArrFl->size(); + EXPECT_EQ(nEvents, 1); + + for (decltype(nEvents) i = 0; i < nEvents; ++i) { + check_1d_coll(takeTopArrFl->at(i), data[i].fArrFl); + check_1d_coll(takeTopVecFl->at(i), data[i].fVecFl); + check_2d_coll(takeTopVecVecFl->at(i), data[i].fVecVecFl); + check_class_with_arrays(takeClassWithArrays->at(i), data[i].fClassWithArrays); + EXPECT_EQ(takeClassWithArrays_fObjIndex->at(i), data[i].fClassWithArrays.fObjIndex); + if (GetParam()) { + check_1d_coll(takeClassWithArrays_fArrFl->at(i), data[i].fClassWithArrays.fArrFl); + check_1d_coll(takeClassWithArrays_fVecFl->at(i), data[i].fClassWithArrays.fVecFl); + check_2d_coll(takeClassWithArrays_fVecVecFl->at(i), data[i].fClassWithArrays.fVecVecFl); + } + check_coll_class_with_arrays(takeVecClassWithArrays->at(i), data[i].fVecClassWithArrays); + if (GetParam()) { + std::vector> expectedVecArrFl(data[i].fVecClassWithArrays.size()); + for (size_t j = 0; j < data[i].fVecClassWithArrays.size(); ++j) { + expectedVecArrFl[j] = data[i].fVecClassWithArrays[j].fArrFl; + } + check_2d_coll(takeVecClassWithArrays_fArrFl->at(i), expectedVecArrFl); + } + } +} + +TEST_P(ClassWithSequenceContainersTest, JittedOps) +{ + ROOT::RDataFrame df{fDatasetName, fFileName}; + ROOT::RDF::RNode node = df; + + const std::string classWithArraysArrFlColName = GetParam() ? "classWithArrays.fArrFl[3]" : "classWithArrays.fArrFl"; + const std::string vecClassWithArraysArrFlColName = + GetParam() ? "vecClassWithArrays.fArrFl[3]" : "vecClassWithArrays.fArrFl"; + + auto data = generateClassWithSequenceContainersDataPlusOne(); + + // all the next define calls should be written with jitted code in strings instead of C++ lambdas + node = node.Define("topArrFl_plus_1", + R"CODE( + std::array result; + for (size_t i = 0; i < topArrFl.size(); ++i) { + result[i] = topArrFl[i] + 1.0f; + } + return result; + )CODE"); + node = node.Define("topVecFl_plus_1", + R"CODE( + std::vector result(topVecFl.size()); + for (size_t i = 0; i < topVecFl.size(); ++i) { + result[i] = topVecFl[i] + 1.0f; + } + return result; + )CODE"); + node = node.Define("topVecVecFl_plus_1", + R"CODE( + std::vector> result(topVecVecFl.size()); + for (size_t i = 0; i < topVecVecFl.size(); ++i) { + result[i].resize(topVecVecFl[i].size()); + for (size_t j = 0; j < topVecVecFl[i].size(); ++j) { + result[i][j] = topVecVecFl[i][j] + 1.0f; + } + } + return result; + )CODE"); + node = node.Define("classWithArrays_plus_1", + R"CODE( + ClassWithSequenceContainers result = classWithArrays; + result.fObjIndex += 1; + for (auto &val : result.fArrFl) { + val += 1; + } + for (auto &arr : result.fArrArrFl) { + for (auto &val : arr) { + val += 1; + } + } + for (auto &vec : result.fArrVecFl) { + for (auto &val : vec) { + val += 1; + } + } + for (auto &val : result.fVecFl) { + val += 1; + } + for (auto &arr : result.fVecArrFl) { + for (auto &val : arr) { + val += 1; + } + } + for (auto &vec : result.fVecVecFl) { + for (auto &val : vec) { + val += 1; + } + } + return result; + )CODE"); + node = node.Define("vecClassWithArrays_plus_1", + R"CODE( + ROOT::VecOps::RVec result = vecClassWithArrays; + for (auto &obj : result) { + obj.fObjIndex += 1; + for (auto &val : obj.fArrFl) { + val += 1; + } + for (auto &arr : obj.fArrArrFl) { + for (auto &val : arr) { + val += 1; + } + } + for (auto &vvec : obj.fArrVecFl) { + for (auto &val : vvec) { + val += 1; + } + } + for (auto &val : obj.fVecFl) { + val += 1; + } + for (auto &arr : obj.fVecArrFl) { + for (auto &val : arr) { + val += 1; + } + } + for (auto &vvec : obj.fVecVecFl) { + for (auto &val : vvec) { + val += 1; + } + } + } + return result; + )CODE"); + // Also create modified values for data member columns + node = node.Define("classWithArrays_fObjIndex_plus_1", "return classWithArrays.fObjIndex + 1;"); + // Using a branch with an invalid C++ name will break the jitted execution + // node = node.Alias("classWithArrays_fArrFl", "classWithArrays.fArrFl[3]"); + // node = node.Define("classWithArrays_fArrFl_plus_1", + // R"CODE( + // std::array result; + // for (size_t i = 0; i < classWithArrays_fArrFl.size(); ++i) { + // result[i] = classWithArrays_fArrFl[i] + 1.0f; + // } + // return result; + // )CODE"); + node = node.Define("classWithArrays_fVecFl_plus_1", + R"CODE( + std::vector result(classWithArrays.fVecFl.size()); + for (size_t i = 0; i < classWithArrays.fVecFl.size(); ++i) { + result[i] = classWithArrays.fVecFl[i] + 1.0f; + } + return result; + )CODE"); + node = node.Define("classWithArrays_fVecVecFl_plus_1", + R"CODE( + std::vector> result(classWithArrays.fVecVecFl.size()); + for (size_t i = 0; i < classWithArrays.fVecVecFl.size(); ++i) { + result[i].resize(classWithArrays.fVecVecFl[i].size()); + for (size_t j = 0; j < classWithArrays.fVecVecFl[i].size(); ++j) { + result[i][j] = classWithArrays.fVecVecFl[i][j] + 1.0f; + } + } + return result; + )CODE"); + // Using a branch with an invalid C++ name will break the jitted execution + // node = node.Alias("vecClassWithArrays_fArrFl", "vecClassWithArrays.fArrFl[3]"); + // node = node.Define("vecClassWithArrays_fArrFl_plus_1", + // R"CODE( + // std::vector> result(vecClassWithArrays_fArrFl.size()); + // for (size_t i = 0; i < vecClassWithArrays_fArrFl.size(); ++i) { + // for (size_t j = 0; j < vecClassWithArrays_fArrFl[i].size(); ++j) { + // result[i][j] = vecClassWithArrays_fArrFl[i][j] + 1.0f; + // } + // } + // return result; + // )CODE"); + + // Take each column individually and check the content + auto takeTopArrFl = node.Take>("topArrFl_plus_1"); + // Not supported: std::array with T being a class type as top-level branch + // auto takeTopArrArrFl = node.Take>("topArrArrFl"); + // auto takeTopArrVecFl = node.Take>("topArrVecFl"); + auto takeTopVecFl = node.Take>("topVecFl_plus_1"); + // Not supported: Could not find the real data member '_M_elems[3]' when constructing the branch 'topVecArrFl' + // auto takeTopVecArrFl = node.Take>>("topVecArrFl"); + auto takeTopVecVecFl = node.Take>>("topVecVecFl_plus_1"); + auto takeClassWithArrays = node.Take("classWithArrays_plus_1"); + auto takeClassWithArrays_fObjIndex = node.Take("classWithArrays_fObjIndex_plus_1"); + // Using a branch with an invalid C++ name will break the jitted execution + // auto takeClassWithArrays_fArrFl = node.Take>("classWithArrays_fArrFl_plus_1"); + // TODO: array of array is currently not properly handled + // auto takeClassWithArrays_fArrArrFl = node.Take, + // 3>>>("classWithArrays.fArrArrFl[3][3]"); + // TODO: array of vector as a data member is currently not properly handled + // auto takeClassWithArrays_fArrVecFl = + // node.Take>>("classWithArrays.fArrVecFl[3]"); + auto takeClassWithArrays_fVecFl = node.Take>("classWithArrays_fVecFl_plus_1"); + auto takeClassWithArrays_fVecVecFl = node.Take>>("classWithArrays_fVecVecFl_plus_1"); + auto takeVecClassWithArrays = + node.Take>("vecClassWithArrays_plus_1"); + // Using a branch with an invalid C++ name will break the jitted execution + // auto takeVecClassWithArrays_fArrFl = + // node.Take>>("vecClassWithArrays_fArrFl_plus_1"); + // TODO: array of vector as a data member is currently not properly handled + // auto takeVecClassWithArrays_fArrVecFl = + // node.Take,3>>>("vecClassWithArrays.fArrVecFl[3]"); + // TODO: vector of vector throws `std::bad_alloc` currently + // auto takeVecClassWithArrays_fVecFl = + // node.Take>>("vecClassWithArrays.fVecFl"); auto + // takeVecClassWithArrays_fVecVecFl = + // node.Take>>>("vecClassWithArrays.fVecVecFl"); + + auto nEvents = takeTopArrFl->size(); + EXPECT_EQ(nEvents, 1); + + for (decltype(nEvents) i = 0; i < nEvents; ++i) { + check_1d_coll(takeTopArrFl->at(i), data[i].fArrFl); + check_1d_coll(takeTopVecFl->at(i), data[i].fVecFl); + check_2d_coll(takeTopVecVecFl->at(i), data[i].fVecVecFl); + check_class_with_arrays(takeClassWithArrays->at(i), data[i].fClassWithArrays); + EXPECT_EQ(takeClassWithArrays_fObjIndex->at(i), data[i].fClassWithArrays.fObjIndex); + // Using a branch with an invalid C++ name will break the jitted execution + // check_1d_coll(takeClassWithArrays_fArrFl->at(i), data[i].fClassWithArrays.fArrFl); + check_1d_coll(takeClassWithArrays_fVecFl->at(i), data[i].fClassWithArrays.fVecFl); + check_2d_coll(takeClassWithArrays_fVecVecFl->at(i), data[i].fClassWithArrays.fVecVecFl); + check_coll_class_with_arrays(takeVecClassWithArrays->at(i), data[i].fVecClassWithArrays); + // Using a branch with an invalid C++ name will break the jitted execution + // std::vector> expectedVecArrFl(data[i].fVecClassWithArrays.size()); + // for (size_t j = 0; j < data[i].fVecClassWithArrays.size(); ++j) { + // expectedVecArrFl[j] = data[i].fVecClassWithArrays[j].fArrFl; + // } + // check_2d_coll(takeVecClassWithArrays_fArrFl->at(i), expectedVecArrFl); + } +} + +INSTANTIATE_TEST_SUITE_P(Run, ClassWithSequenceContainersTest, ::testing::Values(true, false)); + +int main(int argc, char **argv) +{ + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tree/dataframe/inc/ROOT/RDF/RTreeColumnReader.hxx b/tree/dataframe/inc/ROOT/RDF/RTreeColumnReader.hxx index ddffc47ae9188..b2dc37c8e85a1 100644 --- a/tree/dataframe/inc/ROOT/RDF/RTreeColumnReader.hxx +++ b/tree/dataframe/inc/ROOT/RDF/RTreeColumnReader.hxx @@ -77,7 +77,8 @@ public: enum class ECollectionType { kRVec, kStdArray, - kRVecBool + kRVecBool, + kStdVector }; RTreeUntypedArrayColumnReader(TTreeReader &r, std::string_view colName, std::string_view valueTypeName, @@ -99,6 +100,9 @@ private: /// We return a reference to this RVec to clients, to guarantee a stable address and contiguous memory layout. RVec fRVec{}; + /// When the user explicitly requests std::vector, we use this std::vector as a stable storage. + std::vector fStdVector{}; + Long64_t fLastEntry = -1; /// The size of the collection value type. @@ -108,6 +112,10 @@ private: bool fCopyWarningPrinted = false; void *GetImpl(Long64_t entry) override; + + void *ReadStdArray(Long64_t entry); + void *ReadStdVector(Long64_t entry); + void *ReadRVec(Long64_t entry); }; class R__CLING_PTRCHECK(off) RMaskedColumnReader : public ROOT::Detail::RDF::RColumnReaderBase { diff --git a/tree/dataframe/src/RDFUtils.cxx b/tree/dataframe/src/RDFUtils.cxx index 2639e281f8ec0..e62c996323d66 100644 --- a/tree/dataframe/src/RDFUtils.cxx +++ b/tree/dataframe/src/RDFUtils.cxx @@ -238,9 +238,14 @@ std::string GetLeafTypeName(TLeaf *leaf, const std::string &colName) // this is a fixed-sized array (we do not differentiate between variable- and fixed-sized arrays) colType = ComposeRVecTypeName(colType); } else if (leaf->GetLeafCount() != nullptr && leaf->GetLenStatic() > 1) { - // we do not know how to deal with this branch - throw std::runtime_error("TTree leaf " + colName + - " has both a leaf count and a static length. This is not supported."); + // This case is encountered when a branch is a collection (e.g. std::vector) of a user-defined class which has + // a data member that is a fixed-size array. Here, 'leaf' is said data member, and the user could read it + // partially as std::vector>. We expose it as ROOT::RVec> for consistency with + // other collection types. + // WARNING: Currently this considers only the possibility of a 1-dim array, as TLeaf does not expose information + // to get all dimension lengths of a multi-dim array in a straightforward way (e.g. with one API call). + auto valueType = colType; + colType = "ROOT::VecOps::RVecGetLenStatic()) + ">>"; } return colType; diff --git a/tree/dataframe/src/RTTreeDS.cxx b/tree/dataframe/src/RTTreeDS.cxx index 824ff39c565a4..28af66c224d16 100644 --- a/tree/dataframe/src/RTTreeDS.cxx +++ b/tree/dataframe/src/RTTreeDS.cxx @@ -60,6 +60,14 @@ GetCollectionInfo(const std::string &typeName) ROOT::Internal::RDF::RTreeUntypedArrayColumnReader::ECollectionType::kStdArray}; } + // Find TYPE from std::vector + if (auto pos = beginType.find("vector<"); pos != std::string::npos) { + const auto begin = typeName.find_first_of('<', pos) + 1; + const auto end = typeName.find_last_of('>'); + const auto innerTypeName = typeName.substr(begin, end - begin); + return {true, innerTypeName, ROOT::Internal::RDF::RTreeUntypedArrayColumnReader::ECollectionType::kStdVector}; + } + return {false, "", ROOT::Internal::RDF::RTreeUntypedArrayColumnReader::ECollectionType::kRVec}; } diff --git a/tree/dataframe/src/RTreeColumnReader.cxx b/tree/dataframe/src/RTreeColumnReader.cxx index c7f1c929b7ecf..0786bb1a7c6af 100644 --- a/tree/dataframe/src/RTreeColumnReader.cxx +++ b/tree/dataframe/src/RTreeColumnReader.cxx @@ -31,12 +31,77 @@ ROOT::Internal::RDF::RTreeUntypedValueColumnReader::RTreeUntypedValueColumnReade ROOT::Internal::RDF::RTreeUntypedValueColumnReader::~RTreeUntypedValueColumnReader() = default; -void *ROOT::Internal::RDF::RTreeUntypedArrayColumnReader::GetImpl(Long64_t entry) +void *ROOT::Internal::RDF::RTreeUntypedArrayColumnReader::ReadStdArray(Long64_t entry) +{ + if (entry == fLastEntry) + return fRVec.data(); // We return the RVec we already created + + auto &readerArray = *fTreeArray; + // GetSize is called here to trigger actual reading of the branch proxy, which also sets the appropriate read status + const auto readerArraySize = readerArray.GetSize(); + + // The reader could not read an array, signal this back to the node requesting the value + if (R__unlikely(readerArray.GetReadStatus() == ROOT::Internal::TTreeReaderValueBase::EReadStatus::kReadError)) + return nullptr; + + // std::array storage should always be contiguous + assert(readerArray.IsContiguous() && "std::array storage should always be contiguous"); + + if (readerArraySize > 0) { + // trigger loading of the contents of the TTreeReaderArray + // the address of the first element in the reader array is not necessarily equal to + // the address returned by the GetAddress method + RVec rvec(readerArray.At(0), readerArraySize); + swap(fRVec, rvec); + } else { + fRVec.clear(); + } + + fLastEntry = entry; + return fRVec.data(); +} + +void *ROOT::Internal::RDF::RTreeUntypedArrayColumnReader::ReadStdVector(Long64_t entry) +{ + if (entry == fLastEntry) + return &fStdVector; // We return the std::vector we already created + + auto &readerArray = *fTreeArray; + // GetSize is called here to trigger actual reading of the branch proxy, which also sets the appropriate read status + const auto readerArraySize = readerArray.GetSize(); + + // The reader could not read an array, signal this back to the node requesting the value + if (R__unlikely(readerArray.GetReadStatus() == ROOT::Internal::TTreeReaderValueBase::EReadStatus::kReadError)) + return nullptr; + + // There is no zero-copy constructor for std::vector, so we need to copy the data in any case. + if (readerArraySize > 0) { + // Caching the value type size since GetValueSize might be expensive. + if (fValueSize == 0) + fValueSize = readerArray.GetValueSize(); + assert(fValueSize > 0 && "Could not retrieve size of collection value type."); + // Array is not contiguous, make a full copy of it. + fStdVector.clear(); + fStdVector.reserve(readerArraySize * fValueSize); + for (std::size_t i{0}; i < readerArraySize; i++) { + auto val = readerArray.At(i); + std::copy(val, val + fValueSize, std::back_inserter(fStdVector)); + } + } else { + fStdVector.clear(); + } + + fLastEntry = entry; + return &fStdVector; +} + +void *ROOT::Internal::RDF::RTreeUntypedArrayColumnReader::ReadRVec(Long64_t entry) { if (entry == fLastEntry) - return &fRVec; // we already pointed our fRVec to the right address + return &fRVec; // We return the RVec we already created auto &readerArray = *fTreeArray; + // GetSize is called here to trigger actual reading of the branch proxy, which also sets the appropriate read status const auto readerArraySize = readerArray.GetSize(); // The reader could not read an array, signal this back to the node requesting the value @@ -90,11 +155,20 @@ void *ROOT::Internal::RDF::RTreeUntypedArrayColumnReader::GetImpl(Long64_t entry fRVec.clear(); } } + fLastEntry = entry; + return &fRVec; +} + +void *ROOT::Internal::RDF::RTreeUntypedArrayColumnReader::GetImpl(Long64_t entry) +{ if (fCollectionType == ECollectionType::kStdArray) - return fRVec.data(); - else - return &fRVec; + return ReadStdArray(entry); + + if (fCollectionType == ECollectionType::kStdVector) + return ReadStdVector(entry); + + return ReadRVec(entry); } ROOT::Internal::RDF::RTreeUntypedArrayColumnReader::RTreeUntypedArrayColumnReader(TTreeReader &r, diff --git a/tree/treeplayer/inc/TTreeReaderArray.h b/tree/treeplayer/inc/TTreeReaderArray.h index a799134bad788..e39a5cd79bd33 100644 --- a/tree/treeplayer/inc/TTreeReaderArray.h +++ b/tree/treeplayer/inc/TTreeReaderArray.h @@ -26,6 +26,15 @@ Base class of TTreeReaderArray. */ class TTreeReaderArrayBase : public TTreeReaderValueBase { + struct StreamerElementArrayInfo { + int fArrayNDims{}; // Number of dimensions of the n-dim array + int fArrayCumulativeLength{}; // Total length of the n-dim array as if it were flattened + int fTDataTypeCode{}; // A number representing the value type of the array, queried via TDataType. + std::array fArrayDims{}; // Length of each dimension of the n-dim array. Max 5 dimensions, aligned with + // TStreamerElement::fMaxIndex + }; + StreamerElementArrayInfo FillStreamerElementArrayInfo(TStreamerElement *elem); + public: TTreeReaderArrayBase(TTreeReader *reader, const char *branchname, TDictionary *dict) : TTreeReaderValueBase(reader, branchname, dict) @@ -48,7 +57,8 @@ class TTreeReaderArrayBase : public TTreeReaderValueBase { bool GetBranchAndLeaf(TBranch *&branch, TLeaf *&myLeaf, TDictionary *&branchActualType, bool suppressErrorsForMissingBranch = false); void SetImpl(TBranch *branch, TLeaf *myLeaf); - const char *GetBranchContentDataType(TBranch *branch, TString &contentTypeName, TDictionary *&dict); + const char *GetBranchContentDataType(TBranch *branch, TString &contentTypeName, TDictionary *&dict, + StreamerElementArrayInfo &arrInfo); std::unique_ptr fImpl; // Common interface to collections diff --git a/tree/treeplayer/src/TTreeReaderArray.cxx b/tree/treeplayer/src/TTreeReaderArray.cxx index 4d0f9c85ca54d..ac99b7f3202a6 100644 --- a/tree/treeplayer/src/TTreeReaderArray.cxx +++ b/tree/treeplayer/src/TTreeReaderArray.cxx @@ -107,7 +107,9 @@ bool IsCPContiguous(const TVirtualCollectionProxy &cp) UInt_t GetCPValueSize(const TVirtualCollectionProxy &cp) { - // This works only if the collection proxy value type is a fundamental type + if (auto cl = cp.GetValueClass()) + return cl->Size(); + auto &&eDataType = cp.GetType(); auto *tDataType = TDataType::GetDataType(eDataType); return tDataType ? tDataType->Size() : 0; @@ -645,7 +647,8 @@ void ROOT::Internal::TTreeReaderArrayBase::CreateProxy() if (!myLeaf) { TString branchActualTypeName; - const char *nonCollTypeName = GetBranchContentDataType(branch, branchActualTypeName, branchActualType); + StreamerElementArrayInfo arrInfo; + const char *nonCollTypeName = GetBranchContentDataType(branch, branchActualTypeName, branchActualType, arrInfo); if (nonCollTypeName) { Error("TTreeReaderArrayBase::CreateContentProxy()", "The branch %s contains data of type %s, which should be accessed through a TTreeReaderValue< %s >.", @@ -671,7 +674,54 @@ void ROOT::Internal::TTreeReaderArrayBase::CreateProxy() return; } - auto matchingDataType = [](TDictionary *left, TDictionary *right) -> bool { + auto matchingArrayInfo = [](TDictionary *requestedDict, const StreamerElementArrayInfo &info) { + // Support the case of a data member of a class being a std::array (or generally an n-dim fixed size array) + // In this case the 'left' TDictionary has been requested as a TClass (e.g. std::array) and thus + // will not be a TDataType, but the 'right' TDictionary will be retrieved as a TDataType (e.g. int), because + // of how the fixed-size array data member is streamed. + // The 'right' TDictionary in this case refers to a TBranchElement representing the data member. We already + // gathered the array info by accessing the TStreamerElement of the TBranchElement, now we need to do the same + // for the 'left' TDictionary, which represents what the user requested. + auto cl = dynamic_cast(requestedDict); + if (!cl) + return false; + auto streamerInfo = cl->GetStreamerInfo(); + if (!streamerInfo) + return false; + + auto streamerElements = streamerInfo->GetElements(); + if (!streamerElements) + return false; + + // This part of the logic currently supports the specific use cases where there is only one streamer element + // corresponding to the array data member that we are trying to partially read into a collection. Without + // further use cases surfacing, we do not support more than one streamer element + if (streamerElements->GetEntries() > 1) + return false; + + auto streamerElement = dynamic_cast(streamerElements->At(0)); + if (!streamerElement) + return false; + + int dataTypeCode{}; + if (auto *dataType = gROOT->GetType(streamerElement->GetTypeNameBasic())) + dataTypeCode = dataType->GetType(); + + auto checkArrayDims = [&]() { + for (int i = 0; i < info.fArrayNDims; i++) { + if (info.fArrayDims[i] != streamerElement->GetMaxIndex(i)) + return false; + } + return true; + }; + + return streamerElement->GetArrayDim() == info.fArrayNDims && + streamerElement->GetArrayLength() == info.fArrayCumulativeLength && checkArrayDims() && + dataTypeCode == info.fTDataTypeCode; + }; + + auto matchingDataType = [&matchingArrayInfo](TDictionary *left, TDictionary *right, + const StreamerElementArrayInfo &info) -> bool { if (left == right) return true; if (!left || !right) @@ -687,6 +737,22 @@ void ROOT::Internal::TTreeReaderArrayBase::CreateProxy() if ((left_datatype && right_enum && left_datatype->GetType() == right_enum->GetUnderlyingType()) || (right_datatype && left_enum && right_datatype->GetType() == left_enum->GetUnderlyingType())) return true; + + // Allow reading nested std::array data members of top-level std::vector types. The user has requested + // e.g. TTreeReaderArray> and we allow partial reading of the std::vector as a + // collection of std::array + if (matchingArrayInfo(left, info)) + return true; + + // Allow reading a std::array data member of a top-level class branch when requesting a TTreeReaderArray + // of the same type as the std::array data type (e.g. branch contains std::array and user requests + // TTreeReaderArray). In this case the 'left' dictionary is going to be a TDataType of int. + if (left_datatype) { + auto typeCode = left_datatype->GetType(); + if (typeCode > 0 && typeCode == info.fTDataTypeCode) + return true; + } + if (!left_datatype || !right_datatype) return false; auto l = left_datatype->GetType(); @@ -698,7 +764,7 @@ void ROOT::Internal::TTreeReaderArrayBase::CreateProxy() (l == kFloat16_t && r == kFloat_t) || (l == kFloat_t && r == kFloat16_t)); }; - if (!matchingDataType(fDict, branchActualType)) { + if (!matchingDataType(fDict, branchActualType, arrInfo)) { Error("TTreeReaderArrayBase::CreateContentProxy()", "The branch %s contains data of type %s. It cannot be accessed by a TTreeReaderArray<%s>", fBranchName.Data(), branchActualType->GetName(), fDict->GetName()); @@ -919,6 +985,24 @@ void ROOT::Internal::TTreeReaderArrayBase::SetImpl(TBranch *branch, TLeaf *myLea } } +ROOT::Internal::TTreeReaderArrayBase::StreamerElementArrayInfo +ROOT::Internal::TTreeReaderArrayBase::FillStreamerElementArrayInfo(TStreamerElement *element) +{ + StreamerElementArrayInfo arrInfo{}; + if (!element) + return arrInfo; + + arrInfo.fArrayNDims = element->GetArrayDim(); + arrInfo.fArrayCumulativeLength = element->GetArrayLength(); + if (auto *datatype = gROOT->GetType(element->GetTypeNameBasic())) { + arrInfo.fTDataTypeCode = datatype->GetType(); + } + for (int i = 0; i < arrInfo.fArrayNDims; ++i) + arrInfo.fArrayDims[i] = element->GetMaxIndex(i); + + return arrInfo; +} + //////////////////////////////////////////////////////////////////////////////// /// Access a branch's collection content (not the collection itself) /// through a proxy. @@ -930,7 +1014,8 @@ void ROOT::Internal::TTreeReaderArrayBase::SetImpl(TBranch *branch, TLeaf *myLea /// In all other cases, NULL is returned. const char *ROOT::Internal::TTreeReaderArrayBase::GetBranchContentDataType(TBranch *branch, TString &contentTypeName, - TDictionary *&dict) + TDictionary *&dict, + StreamerElementArrayInfo &arrInfo) { dict = nullptr; contentTypeName = ""; @@ -999,6 +1084,8 @@ const char *ROOT::Internal::TTreeReaderArrayBase::GetBranchContentDataType(TBran contentTypeName = TDataType::GetTypeName(dtData); return nullptr; } + // Fill information about the data member being an n-dim array + arrInfo = FillStreamerElementArrayInfo(brElement->GetInfo()->GetElement(brElement->GetID())); return nullptr; } else if (ExpectedTypeRet == 1) { int brID = brElement->GetID(); @@ -1026,7 +1113,8 @@ const char *ROOT::Internal::TTreeReaderArrayBase::GetBranchContentDataType(TBran if (id >= 0) { TStreamerElement *element = (TStreamerElement *)streamerInfo->GetElements()->At(id); - + // Fill information about the data member being an n-dim array + arrInfo = FillStreamerElementArrayInfo(element); if (element->IsA() == TStreamerSTL::Class()) { TClass *myClass = brElement->GetCurrentClass(); if (!myClass) {