Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
89 changes: 84 additions & 5 deletions src/lib/bindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1050,16 +1050,95 @@ void bind_keys(py::module &m)
.def_readwrite("secretKey", &KeyPair<DCRTPoly>::secretKey)
.def("good", &KeyPair<DCRTPoly>::good,kp_good_docs);
py::class_<EvalKeyImpl<DCRTPoly>, std::shared_ptr<EvalKeyImpl<DCRTPoly>>>(m, "EvalKey")
.def(py::init<>())
.def(py::init<>())
.def("GetKeyTag", &EvalKeyImpl<DCRTPoly>::GetKeyTag)
.def("SetKeyTag", &EvalKeyImpl<DCRTPoly>::SetKeyTag);
py::class_<std::map<usint, EvalKey<DCRTPoly>>, std::shared_ptr<std::map<usint, EvalKey<DCRTPoly>>>>(m, "EvalKeyMap")
.def(py::init<>());
}

// PlaintextImpl is an abstract class, so we should use a helper (trampoline) class
class PlaintextImpl_helper : public PlaintextImpl
{
public:
using PlaintextImpl::PlaintextImpl; // inherited constructors

// the PlaintextImpl virtual functions' overrides
bool Encode() override {
PYBIND11_OVERRIDE_PURE(
bool, // return type
PlaintextImpl, // parent class
Encode // function name
// no arguments
);
}
bool Decode() override {
PYBIND11_OVERRIDE_PURE(
bool, // return type
PlaintextImpl, // parent class
Decode // function name
// no arguments
);
}
bool Decode(size_t depth, double scalingFactor, ScalingTechnique scalTech, ExecutionMode executionMode) override {
PYBIND11_OVERRIDE(
bool, // return type
PlaintextImpl, // parent class
Decode, // function name
depth, scalingFactor, scalTech, executionMode // arguments
);
}
size_t GetLength() const override {
PYBIND11_OVERRIDE_PURE(
size_t, // return type
PlaintextImpl, // parent class
GetLength // function name
// no arguments
);
}
void SetLength(size_t newSize) override {
PYBIND11_OVERRIDE(
void, // return type
PlaintextImpl, // parent class
SetLength, // function name
newSize // arguments
);
}
double GetLogError() const override {
PYBIND11_OVERRIDE(double, PlaintextImpl, GetLogError);
}
double GetLogPrecision() const override {
PYBIND11_OVERRIDE(double, PlaintextImpl, GetLogPrecision);
}
const std::string& GetStringValue() const override {
PYBIND11_OVERRIDE(const std::string&, PlaintextImpl, GetStringValue);
}
const std::vector<int64_t>& GetCoefPackedValue() const override {
PYBIND11_OVERRIDE(const std::vector<int64_t>&, PlaintextImpl, GetCoefPackedValue);
}
const std::vector<int64_t>& GetPackedValue() const override {
PYBIND11_OVERRIDE(const std::vector<int64_t>&, PlaintextImpl, GetPackedValue);
}
const std::vector<std::complex<double>>& GetCKKSPackedValue() const override {
PYBIND11_OVERRIDE(const std::vector<std::complex<double>>&, PlaintextImpl, GetCKKSPackedValue);
}
std::vector<double> GetRealPackedValue() const override {
PYBIND11_OVERRIDE(std::vector<double>, PlaintextImpl, GetRealPackedValue);
}
void SetStringValue(const std::string& str) override {
PYBIND11_OVERRIDE(void, PlaintextImpl, SetStringValue, str);
}
void SetIntVectorValue(const std::vector<int64_t>& vec) override {
PYBIND11_OVERRIDE(void, PlaintextImpl, SetIntVectorValue, vec);
}
std::string GetFormattedValues(int64_t precision) const override {
PYBIND11_OVERRIDE(std::string, PlaintextImpl, GetFormattedValues, precision);
}
};

void bind_encodings(py::module &m)
{
py::class_<PlaintextImpl, std::shared_ptr<PlaintextImpl>>(m, "Plaintext")
py::class_<PlaintextImpl, std::shared_ptr<PlaintextImpl>, PlaintextImpl_helper>(m, "Plaintext")
.def("GetScalingFactor", &PlaintextImpl::GetScalingFactor,
ptx_GetScalingFactor_docs)
.def("SetScalingFactor", &PlaintextImpl::SetScalingFactor,
Expand All @@ -1069,8 +1148,6 @@ void bind_encodings(py::module &m)
ptx_GetSchemeID_docs)
.def("GetLength", &PlaintextImpl::GetLength,
ptx_GetLength_docs)
.def("GetSchemeID", &PlaintextImpl::GetSchemeID,
ptx_GetSchemeID_docs)
.def("SetLength", &PlaintextImpl::SetLength,
ptx_SetLength_docs,
py::arg("newSize"))
Expand All @@ -1080,7 +1157,9 @@ void bind_encodings(py::module &m)
ptx_GetLogPrecision_docs)
.def("Encode", &PlaintextImpl::Encode,
ptx_Encode_docs)
.def("Decode", &PlaintextImpl::Decode,
.def("Decode", py::overload_cast<>(&PlaintextImpl::Decode),
ptx_Decode_docs)
.def("Decode", py::overload_cast<size_t, double, ScalingTechnique, ExecutionMode>(&PlaintextImpl::Decode),
ptx_Decode_docs)
.def("LowBound", &PlaintextImpl::LowBound,
ptx_LowBound_docs)
Expand Down
Loading