diff --git a/src/apps/CMakeLists.txt b/src/apps/CMakeLists.txt index ab7df28..7893e54 100644 --- a/src/apps/CMakeLists.txt +++ b/src/apps/CMakeLists.txt @@ -1 +1,2 @@ -add_subdirectory(cost_matrix_based_matching) \ No newline at end of file +add_subdirectory(cost_matrix_based_matching) +add_subdirectory(feature_based_matching) \ No newline at end of file diff --git a/src/apps/feature_based_matching/CMakeLists.txt b/src/apps/feature_based_matching/CMakeLists.txt new file mode 100644 index 0000000..eb999fd --- /dev/null +++ b/src/apps/feature_based_matching/CMakeLists.txt @@ -0,0 +1,15 @@ +find_package(OpenCV REQUIRED) + +add_executable(feature_based_localizer feature_based_localizer.cpp) +target_link_libraries(feature_based_localizer + cxx_flags + glog::glog + path_element + online_database + online_localizer + successor_manager + config_parser + lsh_cv_hashing + cnn_feature + ${OpenCV_LIBS} +) \ No newline at end of file diff --git a/src/apps/feature_based_matching/feature_based_localizer.cpp b/src/apps/feature_based_matching/feature_based_localizer.cpp new file mode 100644 index 0000000..2e67ee8 --- /dev/null +++ b/src/apps/feature_based_matching/feature_based_localizer.cpp @@ -0,0 +1,81 @@ +// Created by O.Vysotska in 2023 + +#include "online_localizer/online_localizer.h" +#include "database/idatabase.h" +#include "database/list_dir.h" +#include "database/online_database.h" +#include "features/cnn_feature.h" +#include "features/ifeature.h" +#include "online_localizer/path_element.h" +#include "relocalizers/lsh_cv_hashing.h" +#include "tools/config_parser/config_parser.h" + +#include + +#include +#include +#include + +namespace loc = localization; + +std::vector> +loadFeatures(const std::string &path2folder) { + LOG(INFO) << "Loading the features to hash with LSH."; + std::vector featureNames = + loc::database::listProtoDir(path2folder, ".Feature"); + std::vector> features; + + for (size_t i = 0; i < featureNames.size(); ++i) { + features.emplace_back( + 
std::make_unique(featureNames[i])); + fprintf(stderr, "."); + } + fprintf(stderr, "\n"); + LOG(INFO) << "Features were loaded and binarized"; + return features; +} + +int main(int argc, char *argv[]) { + google::InitGoogleLogging(argv[0]); + FLAGS_logtostderr = 1; + LOG(INFO) << "===== Online place recognition with LSH ====\n"; + + if (argc < 2) { + printf("[ERROR] Not enough input parameters.\n"); + printf("Proper usage: ./cost_matrix_based_matching_lsh config_file.yaml\n"); + exit(0); + } + + std::string config_file = argv[1]; + ConfigParser parser; + parser.parseYaml(config_file); + parser.print(); + + const auto database = std::make_unique( + /*queryFeaturesDir=*/parser.path2qu, + /*refFeaturesDir=*/parser.path2ref, + /*type=*/loc::features::FeatureType::Cnn_Feature, + /*bufferSize=*/parser.bufferSize); + + auto relocalizer = std::make_unique( + /*onlineDatabase=*/database.get(), + /*tableNum=*/1, + /*keySize=*/12, + /*multiProbeLevel=*/2); + relocalizer->train(loadFeatures(parser.path2ref)); + + auto successorManager = + std::make_unique( + database.get(), relocalizer.get(), parser.fanOut); + loc::online_localizer::OnlineLocalizer localizer{ + successorManager.get(), parser.expansionRate, parser.nonMatchCost}; + const loc::online_localizer::Matches imageMatches = + localizer.findMatchesTill(parser.querySize); + loc::online_localizer::storeMatchesAsProto(imageMatches, + parser.matchingResult); + loc::database::storeCostsAsProto(database->getEstimatedCosts(), + parser.costsOutputName); + + LOG(INFO) << "Done."; + return 0; +} diff --git a/src/localization/database/online_database.cpp b/src/localization/database/online_database.cpp index 094c966..75eed65 100644 --- a/src/localization/database/online_database.cpp +++ b/src/localization/database/online_database.cpp @@ -27,6 +27,7 @@ #include "database/list_dir.h" #include "features/feature_buffer.h" #include "features/ifeature.h" +#include "localization_protos.pb.h" #include @@ -102,4 +103,31 @@ const 
features::iFeature &OnlineDatabase::getQueryFeature(int quId) { return addFeatureIfNeeded(*queryBuffer_, quFeaturesNames_, featureType_, quId); } + +void storeCostsAsProto( + const localization::database::OnlineDatabase::MatchingCosts &costs, + const std::string &protoFilename) { + if (costs.empty()) { + LOG(WARNING) << "Matching costs are empty. Nothing to store."; + return; + } + image_sequence_localizer::MatchingCosts costsProto; + for (const auto &[queryId, refValueMap] : costs) { + for (const auto &[refId, value] : refValueMap) { + image_sequence_localizer::MatchingCosts::Element *element = + costsProto.add_elements(); + element->set_query_id(queryId); + element->set_ref_id(refId); + element->set_value(value); + } + } + std::fstream out(protoFilename, + std::ios::out | std::ios::trunc | std::ios::binary); + if (!costsProto.SerializeToOstream(&out)) { + LOG(ERROR) << "Couldn't open the file" << protoFilename; + return; + } + out.close(); + LOG(INFO) << "Wrote matching costs to: " << protoFilename; +} } // namespace localization::database diff --git a/src/localization/database/online_database.h b/src/localization/database/online_database.h index bf16b62..8ea27de 100644 --- a/src/localization/database/online_database.h +++ b/src/localization/database/online_database.h @@ -38,12 +38,15 @@ #include namespace localization::database { + /** * @brief Database for loading and matching features. Caches the computed * matching costs. 
*/ class OnlineDatabase : public iDatabase { public: + using MatchingCosts = + std::unordered_map>; OnlineDatabase(const std::string &queryFeaturesDir, const std::string &refFeaturesDir, features::FeatureType type, int bufferSize, const std::string &costMatrixFile = ""); @@ -54,6 +57,7 @@ class OnlineDatabase : public iDatabase { double computeMatchingCost(int quId, int refId); const features::iFeature &getQueryFeature(int quId); + const MatchingCosts &getEstimatedCosts() const { return costs_; } protected: std::vector quFeaturesNames_; @@ -64,10 +68,14 @@ class OnlineDatabase : public iDatabase { private: std::unique_ptr refBuffer_{}; std::unique_ptr queryBuffer_{}; - std::unordered_map> costs_; + MatchingCosts costs_; std::optional precomputedCosts_ = {}; }; + +void storeCostsAsProto( + const localization::database::OnlineDatabase::MatchingCosts &matches, + const std::string &protoFilename); } // namespace localization::database #endif // SRC_DATABASE_ONLINE_DATABASE_H_ diff --git a/src/localization/online_localizer/online_localizer.cpp b/src/localization/online_localizer/online_localizer.cpp index 36d9b4c..aed15dd 100644 --- a/src/localization/online_localizer/online_localizer.cpp +++ b/src/localization/online_localizer/online_localizer.cpp @@ -92,21 +92,21 @@ Matches OnlineLocalizer::findMatchesTill(int queryId) { } void OnlineLocalizer::writeOutExpanded(const std::string &filename) const { - image_sequence_localizer::Patch patch; + image_sequence_localizer::MatchingCosts matchedCosts; for (const auto &node : expandedRecently_) { - image_sequence_localizer::Patch::Element *element = patch.add_elements(); - element->set_row(node.quId); - element->set_col(node.refId); - element->set_similarity_value(node.idvCost); + image_sequence_localizer::MatchingCosts::Element *element = matchedCosts.add_elements(); + element->set_query_id(node.quId); + element->set_ref_id(node.refId); + element->set_value(node.idvCost); } std::fstream out(filename, std::ios::out | 
std::ios::trunc | std::ios::binary); - if (!patch.SerializeToOstream(&out)) { + if (!matchedCosts.SerializeToOstream(&out)) { LOG(ERROR) << "Couldn't open the file" << filename; return; } out.close(); - LOG(INFO) << "Wrote patch " << filename; + LOG(INFO) << "Wrote matched costs to: " << filename; } // frontier picking up routine diff --git a/src/localization/tools/config_parser/config_parser.cpp b/src/localization/tools/config_parser/config_parser.cpp index fba2395..3f2a591 100644 --- a/src/localization/tools/config_parser/config_parser.cpp +++ b/src/localization/tools/config_parser/config_parser.cpp @@ -108,9 +108,9 @@ bool ConfigParser::parse(const std::string &iniFile) { ss >> costMatrix; continue; } - if (header == "costOutputName") { + if (header == "costsOutputName") { ss >> header; // reads "=" - ss >> costOutputName; + ss >> costsOutputName; continue; } if (header == "simPlaces") { @@ -139,7 +139,7 @@ void ConfigParser::print() const { printf("== Buffer size: %d\n", bufferSize); printf("== CostMatrix: %s\n", costMatrix.c_str()); - printf("== costOutputName: %s\n", costOutputName.c_str()); + printf("== costsOutputName: %s\n", costsOutputName.c_str()); printf("== matchingResult: %s\n", matchingResult.c_str()); printf("== simPlaces: %s\n", simPlaces.c_str()); } @@ -187,8 +187,8 @@ bool ConfigParser::parseYaml(const std::string &yamlFile) { if (config["costMatrix"]) { costMatrix = config["costMatrix"].as(); } - if (config["costOutputName"]) { - costOutputName = config["costOutputName"].as(); + if (config["costsOutputName"]) { + costsOutputName = config["costsOutputName"].as(); } if (config["simPlaces"]) { simPlaces = config["simPlaces"].as(); diff --git a/src/localization/tools/config_parser/config_parser.h b/src/localization/tools/config_parser/config_parser.h index b089ea6..da146a8 100644 --- a/src/localization/tools/config_parser/config_parser.h +++ b/src/localization/tools/config_parser/config_parser.h @@ -30,28 +30,28 @@ * @brief Class for storing the 
configuration parameters. */ class ConfigParser { - public: - ConfigParser() {} - bool parse(const std::string &iniFile); - bool parseYaml(const std::string &yamlFile); - void print() const; +public: + ConfigParser() {} + bool parse(const std::string &iniFile); + bool parseYaml(const std::string &yamlFile); + void print() const; - std::string path2qu = ""; - std::string path2ref = ""; - std::string path2quImg = ""; - std::string path2refImg = ""; - std::string imgExt = ""; - std::string costMatrix = ""; - std::string costOutputName = ""; - std::string simPlaces = ""; - std::string hashTable = ""; - std::string matchingResult = "matches.MatchingResult.pb"; + std::string path2qu = ""; + std::string path2ref = ""; + std::string path2quImg = ""; + std::string path2refImg = ""; + std::string imgExt = ""; + std::string costMatrix = ""; + std::string costsOutputName = "costs.MatchingCosts.pb"; + std::string simPlaces = ""; + std::string hashTable = ""; + std::string matchingResult = "matches.MatchingResult.pb"; - int querySize = -1; - int fanOut = -1; - int bufferSize = -1; - double nonMatchCost = -1.0; - double expansionRate = -1.0; + int querySize = -1; + int fanOut = -1; + int bufferSize = -1; + double nonMatchCost = -1.0; + double expansionRate = -1.0; }; /*! \var std::string ConfigParser::path2qu @@ -73,7 +73,7 @@ class ConfigParser { /*! \var std::string ConfigParser::costMatrix \brief stores path to precomputed cost/similarity matrix. */ -/*! \var std::string ConfigParser::costOutputName +/*! \var std::string ConfigParser::costsOutputName \brief stores the name of the produced result for the cost_matrix_based matching. */ @@ -111,4 +111,4 @@ class ConfigParser { typically be selected from 0.5 - 0.7. 
*/ -#endif // SRC_TOOLS_CONFIG_PARSER_CONFIG_PARSER_H_ +#endif // SRC_TOOLS_CONFIG_PARSER_CONFIG_PARSER_H_ diff --git a/src/localization_protos.proto b/src/localization_protos.proto index 0bc3152..5365930 100644 --- a/src/localization_protos.proto +++ b/src/localization_protos.proto @@ -2,12 +2,25 @@ syntax = "proto2"; package image_sequence_localizer; +// Assumes that all values of the matrix are known. +// values.size() == rows*cols. No checks for this though. message CostMatrix { optional int32 rows = 20; optional int32 cols = 21; repeated double values = 1; } +// Is used to store matching costs with associated query and reference id. +// Does not require all elements of the cost matrix to be present +message MatchingCosts{ + message Element { + optional int32 query_id = 1; + optional int32 ref_id = 2; + optional double value = 3; + } + repeated Element elements = 1; +} + message MatchingResult { message Match { optional int32 query_id = 1; @@ -25,12 +38,3 @@ message Feature { optional int32 size = 2; optional string type = 3; } - -message Patch { - message Element { - optional int32 row = 1; - optional int32 col = 2; - optional int32 similarity_value = 3; - } - repeated Element elements = 1; -} \ No newline at end of file diff --git a/src/python/protos_io.py b/src/python/protos_io.py index 9643194..69de81c 100644 --- a/src/python/protos_io.py +++ b/src/python/protos_io.py @@ -18,7 +18,6 @@ def write_feature(filename, proto): def write_cost_matrix(cost_matrix, cost_matrix_file): - cost_matrix_proto = loc_protos.CostMatrix() cost_matrix_proto.rows = cost_matrix.shape[0] cost_matrix_proto.cols = cost_matrix.shape[1] @@ -51,14 +50,14 @@ def read_matching_result(filename): def read_expanded_mask(expanded_patches_dir): - patch_files = list(expanded_patches_dir.glob("*.Patch.pb")) - patch_files.sort() + matching_costs_files = list(expanded_patches_dir.glob("*.MatchingCosts.pb")) + matching_costs_files.sort() mask = [] - for patch_file in patch_files: - f = 
open(patch_file, "rb") - patch_proto = loc_protos.Patch() - patch_proto.ParseFromString(f.read()) + for matching_costs_file in matching_costs_files: + f = open(matching_costs_file, "rb") + matching_costs_proto = loc_protos.MatchingCosts() + matching_costs_proto.ParseFromString(f.read()) f.close() - mask.extend(patch_proto.elements) + mask.extend(matching_costs_proto.elements) return mask diff --git a/src/python/visualize_localization_result.py b/src/python/visualize_localization_result.py index 48c9e0f..f4ad582 100644 --- a/src/python/visualize_localization_result.py +++ b/src/python/visualize_localization_result.py @@ -6,7 +6,6 @@ def create_combined_image(matching_result, cost_matrix, expanded_mask=None): - rgb_costs = np.zeros((cost_matrix.shape[0], cost_matrix.shape[1], 3)) rgb_costs[:, :, 0] = cost_matrix rgb_costs[:, :, 1] = cost_matrix @@ -48,7 +47,7 @@ def main(): "--expanded_patches_dir", required=False, type=Path, - help="Path to directory with expanded nodes files of type .Patch.pb", + help="Path to directory with expanded nodes files of type .MatchingCosts.pb", ) parser.add_argument( "--image_name", diff --git a/src/viewer/src/App.tsx b/src/viewer/src/App.tsx index 596a834..638aa63 100644 --- a/src/viewer/src/App.tsx +++ b/src/viewer/src/App.tsx @@ -10,6 +10,7 @@ import { ElementProvider } from "./context/ElementContext"; function App() { const [costMatrixProtoFile, setCostMatrixProtoFile] = useState(); + const [matchingCostsProtoFile, setMatchingCostsProtoFile] = useState(); const [matchingResultProtoFile, setMatchingResultProtoFile] = useState(); const [queryImageFiles, setQueryImageFiles] = useState(); @@ -31,6 +32,7 @@ function App() { @@ -42,6 +44,7 @@ function App() { )} diff --git a/src/viewer/src/components/CostMatrixComponent.tsx b/src/viewer/src/components/CostMatrixComponent.tsx index 0801219..e4e31d1 100644 --- a/src/viewer/src/components/CostMatrixComponent.tsx +++ b/src/viewer/src/components/CostMatrixComponent.tsx @@ -29,6 +29,7 @@ 
function getMatchingResultInZoomBlock( type CostMatrixProps = { costMatrixProtoFile: File; matchingResultProtoFile?: File; + matchingCostsProtoFile?: File; }; function CostMatrixComponent(props: CostMatrixProps): React.ReactElement { @@ -36,6 +37,9 @@ function CostMatrixComponent(props: CostMatrixProps): React.ReactElement { const [image, setImage] = useState(); const [matchingResult, setMatchingResult] = useState(); + // This should be some useful data structure to propagate the expanded costs. + // For now just checking that this could be loaded. + const [matchingCostsProto, setMatchingCostsProto] = useState(); const [zoomParams, setZoomParams] = useState(); const [zoomedCostMatrix, setZoomedCostMatrix] = useState(); const [matchingResultVisible, setMatchingResultVisible] = @@ -80,6 +84,24 @@ function CostMatrixComponent(props: CostMatrixProps): React.ReactElement { }); }, [props.matchingResultProtoFile]); + + // Read matching costs -> for expanded costs proto file + useEffect(() => { + if (props.matchingCostsProtoFile == null) { + return; + } + readProtoFromFile( + props.matchingCostsProtoFile, + ProtoMessageType.MatchingCosts + ) + .then((matchingCostsProto) => { + setMatchingCostsProto(matchingCostsProto); + }) + .catch((e) => { + console.log("Couldn't read file", props.matchingCostsProtoFile); + }); + }, [props.matchingCostsProtoFile]); + useEffect(() => { if (zoomParams == null || costMatrix == null) { return; diff --git a/src/viewer/src/components/DataLoader.tsx b/src/viewer/src/components/DataLoader.tsx index e69e1ee..b21f72a 100644 --- a/src/viewer/src/components/DataLoader.tsx +++ b/src/viewer/src/components/DataLoader.tsx @@ -1,10 +1,23 @@ type DataLoaderProps = { setMatchingResultProtoFile: (file: File) => void; + setMatchingCostsProtoFile: (file: File) => void; setCostMatrixProtoFile: (file: File) => void; setQueryImageFiles: (files: File[]) => void; setReferenceImageFiles: (file: File[]) => void; }; +function findAndSetFileByType(files: File[], type: 
string, setter: (file: File)=> void){ + const protoFile = files.find((file) => { + return file.webkitRelativePath.split("/")[1].endsWith(type); + }); + if (protoFile == null) { + console.warn("File of type", type, "was not found"); + } else { + console.log("Found file", protoFile) + setter(protoFile); + } +} + function DataLoader(props: DataLoaderProps): React.ReactElement { async function onChange(event: React.ChangeEvent) { const fileList = event.target.files; @@ -13,26 +26,9 @@ function DataLoader(props: DataLoaderProps): React.ReactElement { } const files = Array.from(fileList); - const costMatrixProtoFile = files.find((file) => { - return file.webkitRelativePath.split("/")[1].endsWith(".CostMatrix.pb"); - }); - if (costMatrixProtoFile == null) { - console.warn("CostMatrix proto file was not found"); - } else { - props.setCostMatrixProtoFile(costMatrixProtoFile); - } - - const matchingResultProtoFile = files.find((file) => { - return file.webkitRelativePath - .split("/")[1] - .endsWith(".MatchingResult.pb"); - }); - - if (matchingResultProtoFile == null) { - console.warn("MatchingResult proto file was not found"); - } else { - props.setMatchingResultProtoFile(matchingResultProtoFile); - } + findAndSetFileByType(files, ".CostMatrix.pb", props.setCostMatrixProtoFile); + findAndSetFileByType(files, ".MatchingResult.pb", props.setMatchingResultProtoFile); + findAndSetFileByType(files, ".MatchingCosts.pb", props.setMatchingCostsProtoFile); const queryImageFiles = files.filter((file) => { return file.webkitRelativePath.split("/")[1] === "query_images"; diff --git a/src/viewer/src/components/ImageCarousel.tsx b/src/viewer/src/components/ImageCarousel.tsx index 66b9230..1fdfa01 100644 --- a/src/viewer/src/components/ImageCarousel.tsx +++ b/src/viewer/src/components/ImageCarousel.tsx @@ -127,7 +127,7 @@ function ImageCarousel(props: ImageCarouselProps) { )} - {images && currentImageId !== undefined && + {images && currentImageId !== null && images.length > 
currentImageId && images[currentImageId] !== undefined && ( diff --git a/src/viewer/src/resources/readers.ts b/src/viewer/src/resources/readers.ts index 7651931..722b30d 100644 --- a/src/viewer/src/resources/readers.ts +++ b/src/viewer/src/resources/readers.ts @@ -23,6 +23,7 @@ function readImageAsync(file: Blob) { enum ProtoMessageType { CostMatrix = "CostMatrix", MatchingResult = "MatchingResult", + MatchingCosts = "MatchingCosts", } function readProtoFromBuffer(buffer: Uint8Array, protoMessageType: string) { @@ -34,11 +35,11 @@ function readProtoFromBuffer(buffer: Uint8Array, protoMessageType: string) { // Get the message type from the root object const message = root.lookupType(protoMessageType); const decodedMessage = message.decode(buffer); - console.log("Message", decodedMessage); resolve(decodedMessage); }) .catch((error: any) => { console.log("ERROR, proto couldn't be loaded", error); + console.log("For type", protoMessageType); reject(); }); });