diff --git a/fbpcf/engine/communication/SocketPartyCommunicationAgentFactory.h b/fbpcf/engine/communication/SocketPartyCommunicationAgentFactory.h index 04b69859..319825c4 100644 --- a/fbpcf/engine/communication/SocketPartyCommunicationAgentFactory.h +++ b/fbpcf/engine/communication/SocketPartyCommunicationAgentFactory.h @@ -51,8 +51,7 @@ class SocketPartyCommunicationAgentFactory final if (id == myId_) { throw std::runtime_error("No need to talk to myself!"); } else { - auto serverId = id < myId_ ? id : myId_; - auto iter = partyInfos_.find(serverId); + auto iter = partyInfos_.find(id); if (iter == partyInfos_.end()) { throw std::runtime_error("Don't know how to connect to this party!"); } diff --git a/fbpcf/engine/communication/test/AgentFactoryCreationHelper.h b/fbpcf/engine/communication/test/AgentFactoryCreationHelper.h index 4559f01c..350c4fa4 100644 --- a/fbpcf/engine/communication/test/AgentFactoryCreationHelper.h +++ b/fbpcf/engine/communication/test/AgentFactoryCreationHelper.h @@ -7,8 +7,11 @@ #pragma once +#include +#include #include #include "fbpcf/engine/communication/InMemoryPartyCommunicationAgentFactory.h" +#include "folly/Random.h" namespace fbpcf::engine::communication { @@ -18,7 +21,7 @@ getInMemoryAgentFactory(int numberOfParty) { int, std::shared_ptr>>( numberOfParty); - std::vector> rst; + std::vector> result; for (int i = 0; i < numberOfParty; i++) { for (int j = i + 1; j < numberOfParty; j++) { @@ -30,10 +33,46 @@ getInMemoryAgentFactory(int numberOfParty) { } } for (int i = 0; i < numberOfParty; i++) { - rst.push_back(std::make_unique( + result.push_back(std::make_unique( i, std::move(maps[i]))); } - return rst; + return result; +} + +inline std::vector> +getSocketAgentFactory(int numberOfParty) { + auto maps = std::vector< + std::map>( + numberOfParty); + for (int i = 0; i < numberOfParty; i++) { + int port = 5000 + folly::Random::rand32() % 1000; + for (int j = i; j < numberOfParty; j++) { + SocketPartyCommunicationAgentFactory::PartyInfo partyInfo = { + "localhost", port}; + if (i == j) { + continue; + } + maps[j].emplace(i, partyInfo); + maps[i].emplace(j, partyInfo); + } + } + + std::vector> result; + for (int i = 0; i < numberOfParty; i++) { + auto m = maps[i]; + std::cerr << "entry " << i << "\n"; + for (auto const& x : m) { + std::cerr << x.first << ":[" << x.second.address << "," << x.second.portNo + << "]"; + } + std::cerr << "\n"; + } + + for (int i = 0; i < numberOfParty; i++) { + result.push_back( + std::make_unique(i, maps[i])); + } + return result; } } // namespace fbpcf::engine::communication diff --git a/fbpcf/engine/communication/test/PartyCommunicationAgentTest.cpp b/fbpcf/engine/communication/test/PartyCommunicationAgentTest.cpp index ada92c59..e8589832 100644 --- a/fbpcf/engine/communication/test/PartyCommunicationAgentTest.cpp +++ b/fbpcf/engine/communication/test/PartyCommunicationAgentTest.cpp @@ -73,14 +73,16 @@ TEST(SocketPartyCommunicationAgentTest, testSendAndReceiveWithTls) { * stress runs, we get errors when trying to bind to the * same port multiple times. */ - std::map partyInfo = { - {0, {"127.0.0.1", intDistro(defEngine)}}, - {1, {"127.0.0.1", intDistro(defEngine)}}}; + auto port = intDistro(defEngine); + std::map partyInfo0 = { + {1, {"127.0.0.1", port}}}; + std::map partyInfo1 = { + {0, {"127.0.0.1", port}}}; auto factory0 = std::make_unique( - 0, partyInfo, true, createdDir); + 0, partyInfo0, true, createdDir); auto factory1 = std::make_unique( - 1, partyInfo, true, createdDir); + 1, partyInfo1, true, createdDir); int size = 1048576; // 1024 ^ 2 auto thread0 = std::thread(testAgentFactory, 0, size, std::move(factory0)); @@ -97,14 +99,16 @@ TEST(SocketPartyCommunicationAgentTest, testSendAndReceiveWithoutTls) { std::default_random_engine defEngine(rd()); std::uniform_int_distribution intDistro(10000, 25000); - std::map partyInfo = { - {0, {"127.0.0.1", intDistro(defEngine)}}, - {1, {"127.0.0.1", intDistro(defEngine)}}}; + auto port = intDistro(defEngine); + std::map partyInfo0 = { + {1, {"127.0.0.1", port}}}; + std::map partyInfo1 = { + {0, {"127.0.0.1", port}}}; auto factory0 = - std::make_unique(0, partyInfo); + std::make_unique(0, partyInfo0); auto factory1 = - std::make_unique(1, partyInfo); + std::make_unique(1, partyInfo1); int size = 1048576; // 1024 ^ 2 auto thread0 = std::thread(testAgentFactory, 0, size, std::move(factory0)); diff --git a/fbpcf/engine/util/test/benchmarks/BenchmarkHelper.h b/fbpcf/engine/util/test/benchmarks/BenchmarkHelper.h index bf6c676e..a8ae562d 100644 --- a/fbpcf/engine/util/test/benchmarks/BenchmarkHelper.h +++ b/fbpcf/engine/util/test/benchmarks/BenchmarkHelper.h @@ -44,18 +44,22 @@ getSocketAgents() { auto retries = 5; while (retries--) { try { + auto port = intDistro(e); std::map< int, communication::SocketPartyCommunicationAgentFactory::PartyInfo> - partyInfo = { - {0, {"127.0.0.1", intDistro(e)}}, - {1, {"127.0.0.1", intDistro(e)}}}; + partyInfo0 = {{1, {"127.0.0.1", port}}}; + std::map< + int, + communication::SocketPartyCommunicationAgentFactory::PartyInfo> + partyInfo1 = {{0, {"127.0.0.1", port}}}; + auto factory0 = std::make_unique( - 0, partyInfo); + 0, partyInfo0); auto factory1 = std::make_unique( - 1, partyInfo); + 1, partyInfo1); auto task = [](std::unique_ptr diff --git a/fbpcf/frontend/test/IntE2ETest.cpp b/fbpcf/frontend/test/IntE2ETest.cpp new file mode 100644 index 00000000..c003df09 --- /dev/null +++ b/fbpcf/frontend/test/IntE2ETest.cpp @@ -0,0 +1,47 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include +#include "fbpcf/engine/communication/SocketPartyCommunicationAgentFactory.h" +#include "fbpcf/engine/communication/test/AgentFactoryCreationHelper.h" +#include "fbpcf/frontend/Int.h" + +namespace fbpcf::frontend { + +template +static int runGame() { + Integer>, schedulerId> int1(PARTY == 0 ? -45 : 10, 0); + Integer>, schedulerId> int2(PARTY == 1 ? 15 : -90, 1); + + auto sum = int1 + int2; + + return sum.openToParty(PARTY + 1 % 2).getValue(); +} + +class IntE2ETest : public ::testing::Test { + protected: + void SetUp() override { + auto factories = fbpcf::engine::communication::getSocketAgentFactory(2); + fbpcf::setupRealBackend<5, 10>(*factories[0], *factories[1]); + } + + void TearDown() override {} +}; + +TEST_F(IntE2ETest, TestCorrectness) { + auto futureAlice = std::async(runGame<0, 5>); + auto futureBob = std::async(runGame<1, 10>); + + int ans1 = futureAlice.get(); + int ans2 = futureBob.get(); + + EXPECT_EQ(ans1, 30); + EXPECT_EQ(ans2, 30); +} + +} // namespace fbpcf::frontend