diff --git a/src/lib/bindings.cpp b/src/lib/bindings.cpp index 6464425..0949217 100644 --- a/src/lib/bindings.cpp +++ b/src/lib/bindings.cpp @@ -1050,16 +1050,95 @@ void bind_keys(py::module &m) .def_readwrite("secretKey", &KeyPair::secretKey) .def("good", &KeyPair::good,kp_good_docs); py::class_, std::shared_ptr>>(m, "EvalKey") - .def(py::init<>()) + .def(py::init<>()) .def("GetKeyTag", &EvalKeyImpl::GetKeyTag) .def("SetKeyTag", &EvalKeyImpl::SetKeyTag); py::class_>, std::shared_ptr>>>(m, "EvalKeyMap") .def(py::init<>()); } +// PlaintextImpl is an abstract class, so we should use a helper (trampoline) class +class PlaintextImpl_helper : public PlaintextImpl +{ +public: + using PlaintextImpl::PlaintextImpl; // inherited constructors + + // the PlaintextImpl virtual functions' overrides + bool Encode() override { + PYBIND11_OVERRIDE_PURE( + bool, // return type + PlaintextImpl, // parent class + Encode // function name + // no arguments + ); + } + bool Decode() override { + PYBIND11_OVERRIDE_PURE( + bool, // return type + PlaintextImpl, // parent class + Decode // function name + // no arguments + ); + } + bool Decode(size_t depth, double scalingFactor, ScalingTechnique scalTech, ExecutionMode executionMode) override { + PYBIND11_OVERRIDE( + bool, // return type + PlaintextImpl, // parent class + Decode, // function name + depth, scalingFactor, scalTech, executionMode // arguments + ); + } + size_t GetLength() const override { + PYBIND11_OVERRIDE_PURE( + size_t, // return type + PlaintextImpl, // parent class + GetLength // function name + // no arguments + ); + } + void SetLength(size_t newSize) override { + PYBIND11_OVERRIDE( + void, // return type + PlaintextImpl, // parent class + SetLength, // function name + newSize // arguments + ); + } + double GetLogError() const override { + PYBIND11_OVERRIDE(double, PlaintextImpl, GetLogError); + } + double GetLogPrecision() const override { + PYBIND11_OVERRIDE(double, PlaintextImpl, GetLogPrecision); + } + const std::string& GetStringValue() const override { + PYBIND11_OVERRIDE(const std::string&, PlaintextImpl, GetStringValue); + } + const std::vector& GetCoefPackedValue() const override { + PYBIND11_OVERRIDE(const std::vector&, PlaintextImpl, GetCoefPackedValue); + } + const std::vector& GetPackedValue() const override { + PYBIND11_OVERRIDE(const std::vector&, PlaintextImpl, GetPackedValue); + } + const std::vector>& GetCKKSPackedValue() const override { + PYBIND11_OVERRIDE(const std::vector>&, PlaintextImpl, GetCKKSPackedValue); + } + std::vector GetRealPackedValue() const override { + PYBIND11_OVERRIDE(std::vector, PlaintextImpl, GetRealPackedValue); + } + void SetStringValue(const std::string& str) override { + PYBIND11_OVERRIDE(void, PlaintextImpl, SetStringValue, str); + } + void SetIntVectorValue(const std::vector& vec) override { + PYBIND11_OVERRIDE(void, PlaintextImpl, SetIntVectorValue, vec); + } + std::string GetFormattedValues(int64_t precision) const override { + PYBIND11_OVERRIDE(std::string, PlaintextImpl, GetFormattedValues, precision); + } +}; + void bind_encodings(py::module &m) { - py::class_>(m, "Plaintext") + py::class_, PlaintextImpl_helper>(m, "Plaintext") .def("GetScalingFactor", &PlaintextImpl::GetScalingFactor, ptx_GetScalingFactor_docs) .def("SetScalingFactor", &PlaintextImpl::SetScalingFactor, @@ -1069,8 +1148,6 @@ void bind_encodings(py::module &m) ptx_GetSchemeID_docs) .def("GetLength", &PlaintextImpl::GetLength, ptx_GetLength_docs) - .def("GetSchemeID", &PlaintextImpl::GetSchemeID, - ptx_GetSchemeID_docs) .def("SetLength", &PlaintextImpl::SetLength, ptx_SetLength_docs, py::arg("newSize")) @@ -1080,7 +1157,9 @@ void bind_encodings(py::module &m) ptx_GetLogPrecision_docs) .def("Encode", &PlaintextImpl::Encode, ptx_Encode_docs) - .def("Decode", &PlaintextImpl::Decode, + .def("Decode", py::overload_cast<>(&PlaintextImpl::Decode), + ptx_Decode_docs) + .def("Decode", py::overload_cast(&PlaintextImpl::Decode), ptx_Decode_docs) .def("LowBound", &PlaintextImpl::LowBound, ptx_LowBound_docs)