From 59bb3dfca087bede3aabd01edf0badb0bc5a54cd Mon Sep 17 00:00:00 2001
From: lucylq
Date: Fri, 27 Jun 2025 17:03:47 -0700
Subject: [PATCH] Introduce MergedDataMap

Differential Revision: [D76529405](https://our.internmc.facebook.com/intern/diff/D76529405/)

[ghstack-poisoned]
---
 .../test/flat_tensor_data_map_test.cpp        |   4 +-
 runtime/executor/merged_data_map.h            | 218 ++++++++++++++++++
 runtime/executor/targets.bzl                  |  10 +
 .../executor/test/merged_data_map_test.cpp    | 176 ++++++++++++++
 runtime/executor/test/targets.bzl             |  14 ++
 5 files changed, 420 insertions(+), 2 deletions(-)
 create mode 100644 runtime/executor/merged_data_map.h
 create mode 100644 runtime/executor/test/merged_data_map_test.cpp

diff --git a/extension/flat_tensor/test/flat_tensor_data_map_test.cpp b/extension/flat_tensor/test/flat_tensor_data_map_test.cpp
index 5a94b47b954..37e1cd2edac 100644
--- a/extension/flat_tensor/test/flat_tensor_data_map_test.cpp
+++ b/extension/flat_tensor/test/flat_tensor_data_map_test.cpp
@@ -33,8 +33,8 @@ class FlatTensorDataMapTest : public ::testing::Test {
     // first.
     executorch::runtime::runtime_init();
 
-    // Load data map. The eager linear model is defined at:
-    // //executorch/test/models/linear_model.py
+    // Load data map. The eager addmul model is defined at:
+    // //executorch/test/models/export_program.py
     const char* path = std::getenv("ET_MODULE_ADD_MUL_DATA_PATH");
     Result<FileDataLoader> loader = FileDataLoader::from(path);
     ASSERT_EQ(loader.error(), Error::Ok);
diff --git a/runtime/executor/merged_data_map.h b/runtime/executor/merged_data_map.h
new file mode 100644
index 00000000000..ed9d29ca7d6
--- /dev/null
+++ b/runtime/executor/merged_data_map.h
@@ -0,0 +1,218 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#pragma once
+
+#include <array>
+#include <cstddef>
+#include <cstdint>
+
+#include <executorch/runtime/core/named_data_map.h>
+
+namespace executorch {
+namespace runtime {
+/**
+ * A NamedDataMap implementation that wraps other NamedDataMaps.
+ *
+ * Lookups are forwarded to the wrapped maps in the order they were given to
+ * load(); the first map that does not report Error::NotFound answers the call.
+ *
+ * @tparam N The maximum number of data maps that can be merged.
+ */
+template <size_t N>
+class MergedDataMap final
+    : public executorch::ET_RUNTIME_NAMESPACE::NamedDataMap {
+ public:
+  /**
+   * Creates a new NamedDataMap that takes in other data maps.
+   *
+   * @param[in] data_maps Array of NamedDataMap pointers to merge. Null entries
+   * are skipped.
+   * Note: the data maps must outlive the MergedDataMap instance.
+   *
+   * @return The merged map, Error::InvalidArgument if every entry is null or
+   * if two maps share a key, or the error reported by a data map while its
+   * keys are being enumerated.
+   */
+  static executorch::runtime::Result<MergedDataMap> load(
+      const std::array<const NamedDataMap*, N>& data_maps) {
+    // Pack the non-null entries at the front. Value-initialize the array so
+    // that unused trailing slots hold nullptr instead of indeterminate values.
+    std::array<const NamedDataMap*, N> valid_data_maps{};
+    size_t num_data_maps = 0;
+    for (size_t i = 0; i < data_maps.size(); i++) {
+      if (data_maps[i] != nullptr) {
+        valid_data_maps[num_data_maps++] = data_maps[i];
+      }
+    }
+    ET_CHECK_OR_RETURN_ERROR(
+        num_data_maps > 0, InvalidArgument, "All provided data maps are null");
+
+    // Check for duplicate keys: no key of map i may be found in a later map j.
+    for (size_t i = 0; i < num_data_maps; i++) {
+      auto num_keys = valid_data_maps[i]->get_num_keys();
+      if (!num_keys.ok()) {
+        return num_keys.error();
+      }
+      for (uint32_t k = 0; k < num_keys.get(); k++) {
+        auto key = valid_data_maps[i]->get_key(k);
+        if (!key.ok()) {
+          return key.error();
+        }
+        for (size_t j = i + 1; j < num_data_maps; j++) {
+          ET_CHECK_OR_RETURN_ERROR(
+              valid_data_maps[j]->get_tensor_layout(key.get()).error() ==
+                  executorch::runtime::Error::NotFound,
+              InvalidArgument,
+              "Duplicate key %s in data maps at index %zu and %zu",
+              key.get(),
+              i,
+              j);
+        }
+      }
+    }
+    return MergedDataMap(valid_data_maps, num_data_maps);
+  }
+
+  /**
+   * Retrieve the tensor_layout for the specified key.
+   *
+   * @param[in] key The name of the tensor to get metadata on.
+   *
+   * @return Error::NotFound if the key is not present.
+   */
+  ET_NODISCARD
+  executorch::runtime::Result<
+      const executorch::ET_RUNTIME_NAMESPACE::TensorLayout>
+  get_tensor_layout(executorch::aten::string_view key) const override {
+    for (size_t i = 0; i < num_data_maps_; i++) {
+      auto layout = data_maps_[i]->get_tensor_layout(key);
+      if (layout.ok()) {
+        return layout.get();
+      }
+      // Keep searching only on NotFound; surface any other error immediately.
+      if (layout.error() != executorch::runtime::Error::NotFound) {
+        return layout.error();
+      }
+    }
+    return executorch::runtime::Error::NotFound;
+  }
+
+  /**
+   * Retrieve read-only data for the specified key.
+   *
+   * @param[in] key The name of the tensor to get data on.
+   *
+   * @return error if the key is not present or data cannot be loaded.
+   */
+  ET_NODISCARD
+  executorch::runtime::Result<executorch::runtime::FreeableBuffer> get_data(
+      executorch::aten::string_view key) const override {
+    for (size_t i = 0; i < num_data_maps_; i++) {
+      auto data = data_maps_[i]->get_data(key);
+      if (data.error() != executorch::runtime::Error::NotFound) {
+        return data;
+      }
+    }
+    return executorch::runtime::Error::NotFound;
+  }
+
+  /**
+   * Loads the data of the specified tensor into the provided buffer.
+   *
+   * @param[in] key The name of the tensor to get the data of.
+   * @param[in] buffer The buffer to load data into. Must point to at least
+   * `size` bytes of memory.
+   * @param[in] size The number of bytes to load.
+   *
+   * @returns an Error indicating if the load was successful.
+   */
+  ET_NODISCARD executorch::runtime::Error load_data_into(
+      executorch::aten::string_view key,
+      void* buffer,
+      size_t size) const override {
+    for (size_t i = 0; i < num_data_maps_; i++) {
+      auto error = data_maps_[i]->load_data_into(key, buffer, size);
+      if (error != executorch::runtime::Error::NotFound) {
+        return error;
+      }
+    }
+    return executorch::runtime::Error::NotFound;
+  }
+
+  /**
+   * @returns The total number of keys across all merged maps, or the error
+   * reported by a merged map that fails to return its key count.
+   */
+  ET_NODISCARD executorch::runtime::Result<uint32_t> get_num_keys()
+      const override {
+    uint32_t num_keys = 0;
+    for (size_t i = 0; i < num_data_maps_; i++) {
+      auto map_num_keys = data_maps_[i]->get_num_keys();
+      if (!map_num_keys.ok()) {
+        return map_num_keys.error();
+      }
+      num_keys += map_num_keys.get();
+    }
+    return num_keys;
+  }
+
+  /**
+   * @returns The key at the specified index, error if index out of bounds.
+   * Indices run through the keys of each merged map in turn.
+   */
+  ET_NODISCARD executorch::runtime::Result<const char*> get_key(
+      uint32_t index) const override {
+    auto total_num_keys = get_num_keys();
+    if (!total_num_keys.ok()) {
+      return total_num_keys.error();
+    }
+    // `index` is unsigned, so only the upper bound needs checking.
+    ET_CHECK_OR_RETURN_ERROR(
+        index < total_num_keys.get(),
+        InvalidArgument,
+        "Index %u out of range of size %u",
+        index,
+        total_num_keys.get());
+    for (size_t i = 0; i < num_data_maps_; i++) {
+      auto num_keys = data_maps_[i]->get_num_keys();
+      if (!num_keys.ok()) {
+        return num_keys.error();
+      }
+      if (index < num_keys.get()) {
+        return data_maps_[i]->get_key(index);
+      }
+      // Rebase the index so it is relative to the next map's first key.
+      index -= num_keys.get();
+    }
+    // Shouldn't reach here.
+    return executorch::runtime::Error::Internal;
+  }
+
+  MergedDataMap(MergedDataMap&&) noexcept = default;
+
+  ~MergedDataMap() override = default;
+
+ private:
+  MergedDataMap(
+      const std::array<const NamedDataMap*, N>& data_maps,
+      size_t num_data_maps)
+      : data_maps_(data_maps), num_data_maps_(num_data_maps) {}
+
+  // Not copyable or assignable.
+  MergedDataMap(const MergedDataMap& rhs) = delete;
+  MergedDataMap& operator=(MergedDataMap&& rhs) noexcept = delete;
+  MergedDataMap& operator=(const MergedDataMap& rhs) = delete;
+
+  // Non-null maps occupy the first `num_data_maps_` slots.
+  const std::array<const NamedDataMap*, N> data_maps_;
+  const size_t num_data_maps_;
+};
+
+} // namespace runtime
+} // namespace executorch
diff --git a/runtime/executor/targets.bzl b/runtime/executor/targets.bzl
index 649b2c13cc1..98165373b73 100644
--- a/runtime/executor/targets.bzl
+++ b/runtime/executor/targets.bzl
@@ -69,6 +69,16 @@ def define_common_targets():
             exported_preprocessor_flags = [] if runtime.is_oss else ["-DEXECUTORCH_INTERNAL_FLATBUFFERS=1"],
         )
 
+        runtime.cxx_library(
+            name = "merged_data_map" + aten_suffix,
+            exported_headers = [
+                "merged_data_map.h",
+            ],
+            exported_deps = [
+                "//executorch/runtime/core:named_data_map" + aten_suffix,
+            ],
+        )
+
         runtime.cxx_library(
             name = "program" + aten_suffix,
             exported_deps = [
diff --git a/runtime/executor/test/merged_data_map_test.cpp b/runtime/executor/test/merged_data_map_test.cpp
new file mode 100644
index 00000000000..6e65ef9a558
--- /dev/null
+++ b/runtime/executor/test/merged_data_map_test.cpp
@@ -0,0 +1,176 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <executorch/extension/data_loader/file_data_loader.h>
+#include <executorch/extension/flat_tensor/flat_tensor_data_map.h>
+#include <executorch/runtime/core/error.h>
+#include <executorch/runtime/core/result.h>
+#include <executorch/runtime/executor/merged_data_map.h>
+#include <executorch/runtime/platform/runtime.h>
+
+#include <gtest/gtest.h>
+
+using namespace ::testing;
+using executorch::extension::FileDataLoader;
+using executorch::extension::FlatTensorDataMap;
+using executorch::runtime::Error;
+using executorch::runtime::MergedDataMap;
+using executorch::runtime::NamedDataMap;
+using executorch::runtime::Result;
+using executorch::runtime::TensorLayout;
+
+// Fixture that loads the addmul and linear FlatTensor data maps for each test.
+class MergedDataMapTest : public ::testing::Test {
+ protected:
+  void load_flat_tensor_data_map(const char* path, const char* module_name) {
+    ASSERT_NE(path, nullptr) << "data path env var is not set";
+    Result<FileDataLoader> loader = FileDataLoader::from(path);
+    ASSERT_EQ(loader.error(), Error::Ok);
+    loaders_.insert(
+        {module_name,
+         std::make_unique<FileDataLoader>(std::move(loader.get()))});
+
+    Result<FlatTensorDataMap> data_map =
+        FlatTensorDataMap::load(loaders_[module_name].get());
+    ASSERT_EQ(data_map.error(), Error::Ok);
+
+    data_maps_.insert(
+        {module_name,
+         std::make_unique<FlatTensorDataMap>(std::move(data_map.get()))});
+  }
+
+  void SetUp() override {
+    // Since these tests cause ET_LOG to be called, the PAL must be initialized
+    // first.
+    executorch::runtime::runtime_init();
+
+    // Load FlatTensor data maps.
+    // The eager addmul and linear models are defined at:
+    // //executorch/test/models/export_program.py
+    load_flat_tensor_data_map(
+        std::getenv("ET_MODULE_ADD_MUL_DATA_PATH"), "addmul");
+    load_flat_tensor_data_map(
+        std::getenv("ET_MODULE_LINEAR_DATA_PATH"), "linear");
+  }
+
+ private:
+  // Must outlive data_maps_, but tests shouldn't need to touch it.
+  std::unordered_map<std::string, std::unique_ptr<FileDataLoader>> loaders_;
+
+ protected:
+  std::unordered_map<std::string, std::unique_ptr<NamedDataMap>> data_maps_;
+};
+
+// Check that two tensor layouts are equivalent.
+void check_tensor_layout(TensorLayout& layout1, TensorLayout& layout2) {
+  EXPECT_EQ(layout1.scalar_type(), layout2.scalar_type());
+  EXPECT_EQ(layout1.nbytes(), layout2.nbytes());
+  EXPECT_EQ(layout1.sizes().size(), layout2.sizes().size());
+  for (size_t i = 0; i < layout1.sizes().size(); i++) {
+    EXPECT_EQ(layout1.sizes()[i], layout2.sizes()[i]);
+  }
+  EXPECT_EQ(layout1.dim_order().size(), layout2.dim_order().size());
+  for (size_t i = 0; i < layout1.dim_order().size(); i++) {
+    EXPECT_EQ(layout1.dim_order()[i], layout2.dim_order()[i]);
+  }
+}
+
+// Given that ndm is part of merged, check that all the API calls on ndm produce
+// the same results as merged.
+void compare_ndm_api_calls(
+    const NamedDataMap* ndm,
+    const NamedDataMap* merged) {
+  size_t num_keys = ndm->get_num_keys().get();
+  for (size_t i = 0; i < num_keys; i++) {
+    auto key = ndm->get_key(i).get();
+
+    // Compare get_tensor_layout.
+    auto ndm_meta = ndm->get_tensor_layout(key).get();
+    auto merged_meta = merged->get_tensor_layout(key).get();
+    check_tensor_layout(ndm_meta, merged_meta);
+    ASSERT_EQ(ndm_meta.nbytes(), merged_meta.nbytes()); // Guards reads below.
+
+    // Compare get_data. Assert (not expect) before dereferencing the results.
+    auto ndm_data = ndm->get_data(key);
+    auto merged_data = merged->get_data(key);
+    ASSERT_EQ(ndm_data.error(), Error::Ok);
+    ASSERT_EQ(merged_data.error(), Error::Ok);
+    ASSERT_EQ(ndm_data->size(), merged_data->size());
+    for (size_t j = 0; j < ndm_data->size(); j++) {
+      EXPECT_EQ(
+          static_cast<const uint8_t*>(ndm_data->data())[j],
+          static_cast<const uint8_t*>(merged_data->data())[j]);
+    }
+    ndm_data->Free();
+    merged_data->Free();
+
+    // Compare load_data_into. Owning buffers are released on every exit path,
+    // including the early return taken when an ASSERT_* fails.
+    auto ndm_buf = std::make_unique<uint8_t[]>(ndm_meta.nbytes());
+    ASSERT_EQ(
+        Error::Ok, ndm->load_data_into(key, ndm_buf.get(), ndm_meta.nbytes()));
+    auto merged_buf = std::make_unique<uint8_t[]>(merged_meta.nbytes());
+    ASSERT_EQ(
+        Error::Ok,
+        merged->load_data_into(key, merged_buf.get(), merged_meta.nbytes()));
+    for (size_t j = 0; j < ndm_meta.nbytes(); j++) {
+      EXPECT_EQ(ndm_buf[j], merged_buf[j]);
+    }
+  }
+}
+
+TEST_F(MergedDataMapTest, LoadSingleDataMap) {
+  const std::array<const NamedDataMap*, 1> data_map = {
+      data_maps_["addmul"].get()};
+  Result<MergedDataMap<1>> merged_map = MergedDataMap<1>::load(data_map);
+  EXPECT_EQ(merged_map.error(), Error::Ok);
+
+  // Load one data map into a merged one with storage for up to 5 data maps.
+  const std::array<const NamedDataMap*, 5> data_maps = {
+      data_maps_["addmul"].get(), nullptr, nullptr, nullptr, nullptr};
+  Result<MergedDataMap<5>> merged_map2 = MergedDataMap<5>::load(data_maps);
+  EXPECT_EQ(merged_map2.error(), Error::Ok);
+}
+
+TEST_F(MergedDataMapTest, LoadNullDataMap) {
+  const std::array<const NamedDataMap*, 2> data_maps = {nullptr, nullptr};
+  Result<MergedDataMap<2>> merged_map = MergedDataMap<2>::load(data_maps);
+  EXPECT_EQ(merged_map.error(), Error::InvalidArgument);
+}
+
+TEST_F(MergedDataMapTest, LoadMultipleDataMaps) {
+  // Add pte data map here.
+  const std::array<const NamedDataMap*, 2> data_maps = {
+      data_maps_["addmul"].get(), data_maps_["linear"].get()};
+  Result<MergedDataMap<2>> merged_map = MergedDataMap<2>::load(data_maps);
+  EXPECT_EQ(merged_map.error(), Error::Ok);
+}
+
+TEST_F(MergedDataMapTest, LoadDuplicateDataMapsFail) {
+  const std::array<const NamedDataMap*, 2> data_maps = {
+      data_maps_["addmul"].get(), data_maps_["addmul"].get()};
+  Result<MergedDataMap<2>> merged_map = MergedDataMap<2>::load(data_maps);
+  EXPECT_EQ(merged_map.error(), Error::InvalidArgument);
+}
+
+TEST_F(MergedDataMapTest, CheckDataMapContents) {
+  const std::array<const NamedDataMap*, 2> data_maps = {
+      data_maps_["addmul"].get(), data_maps_["linear"].get()};
+  Result<MergedDataMap<2>> merged_map = MergedDataMap<2>::load(data_maps);
+  ASSERT_EQ(merged_map.error(), Error::Ok);
+
+  // Num keys.
+  size_t addmul_num_keys = data_maps_["addmul"]->get_num_keys().get();
+  size_t linear_num_keys = data_maps_["linear"]->get_num_keys().get();
+  EXPECT_EQ(
+      merged_map->get_num_keys().get(), addmul_num_keys + linear_num_keys);
+
+  // API calls produce equivalent results.
+  compare_ndm_api_calls(data_maps_["addmul"].get(), &merged_map.get());
+  compare_ndm_api_calls(data_maps_["linear"].get(), &merged_map.get());
+}
diff --git a/runtime/executor/test/targets.bzl b/runtime/executor/test/targets.bzl
index 39ff0668d5d..7b4672e4414 100644
--- a/runtime/executor/test/targets.bzl
+++ b/runtime/executor/test/targets.bzl
@@ -125,6 +126,7 @@ def define_common_targets(is_fbcode = False):
             "ET_MODULE_STATEFUL_PATH": "$(location fbcode//executorch/test/models:exported_programs[ModuleStateful.pte])",
             "ET_MODULE_ADD_MUL_PROGRAM_PATH": "$(location fbcode//executorch/test/models:exported_program_and_data[ModuleAddMul.pte])",
             "ET_MODULE_ADD_MUL_DATA_PATH": "$(location fbcode//executorch/test/models:exported_program_and_data[ModuleAddMul.ptd])",
+            "ET_MODULE_LINEAR_DATA_PATH": "$(location fbcode//executorch/test/models:exported_program_and_data[ModuleLinear.ptd])",
         }
 
         runtime.cxx_test(
@@ -142,6 +143,19 @@ def define_common_targets(is_fbcode = False):
             env = modules_env,
) + runtime.cxx_test( + name = "merged_data_map_test", + srcs = [ + "merged_data_map_test.cpp", + ], + deps = [ + "//executorch/extension/data_loader:file_data_loader", + "//executorch/extension/flat_tensor:flat_tensor_data_map", + "//executorch/runtime/executor:merged_data_map", + ], + env = modules_env, + ) + runtime.cxx_test( name = "method_test", srcs = [