diff --git a/src/mp/gen.cpp b/src/mp/gen.cpp index 603f9cc..53e7d0f 100644 --- a/src/mp/gen.cpp +++ b/src/mp/gen.cpp @@ -125,6 +125,109 @@ static bool BoxedType(const ::capnp::Type& type) type.isFloat64() || type.isEnum()); } +struct Field +{ + ::capnp::StructSchema::Field param; + bool param_is_set = false; + ::capnp::StructSchema::Field result; + bool result_is_set = false; + int args = 0; + bool retval = false; + bool optional = false; + bool requested = false; + bool skip = false; + kj::StringPtr exception; +}; + +struct FieldList +{ + std::vector fields; + std::map field_idx; // name -> args index + bool has_result = false; + + void addField(const ::capnp::StructSchema::Field& schema_field, bool param, bool result) + { + auto field_name = schema_field.getProto().getName(); + auto inserted = field_idx.emplace(field_name, fields.size()); + if (inserted.second) { + fields.emplace_back(); + } + auto& field = fields[inserted.first->second]; + if (param) { + field.param = schema_field; + field.param_is_set = true; + } + if (result) { + field.result = schema_field; + field.result_is_set = true; + } + + if (!param && field_name == kj::StringPtr{"result"}) { + field.retval = true; + has_result = true; + } + + if (AnnotationExists(schema_field.getProto(), SKIP_ANNOTATION_ID)) { + field.skip = true; + } + GetAnnotationText(schema_field.getProto(), EXCEPTION_ANNOTATION_ID, &field.exception); + + int32_t count = 1; + if (!GetAnnotationInt32(schema_field.getProto(), COUNT_ANNOTATION_ID, &count)) { + if (schema_field.getType().isStruct()) { + GetAnnotationInt32(schema_field.getType().asStruct().getProto(), + COUNT_ANNOTATION_ID, &count); + } else if (schema_field.getType().isInterface()) { + GetAnnotationInt32(schema_field.getType().asInterface().getProto(), + COUNT_ANNOTATION_ID, &count); + } + } + + + if (inserted.second && !field.retval && !field.exception.size()) { + field.args = count; + } + } + + void mergeFields() + { + for (auto& field : field_idx) { + auto has_field = field_idx.find("has" + Cap(field.first)); + if (has_field != field_idx.end()) { + fields[has_field->second].skip = true; + fields[field.second].optional = true; + } + auto want_field = field_idx.find("want" + Cap(field.first)); + if (want_field != field_idx.end() && fields[want_field->second].param_is_set) { + fields[want_field->second].skip = true; + fields[field.second].requested = true; + } + } + } +}; + +std::string AccessorType(kj::StringPtr base_name, const Field& field) +{ + const auto& f = field.param_is_set ? field.param : field.result; + const auto field_name = f.getProto().getName(); + const auto field_type = f.getType(); + + std::ostringstream out; + out << "Accessor<" << base_name << "_fields::" << Cap(field_name) << ", "; + if (!field.param_is_set) { + out << "FIELD_OUT"; + } else if (field.result_is_set) { + out << "FIELD_IN | FIELD_OUT"; + } else { + out << "FIELD_IN"; + } + if (field.optional) out << " | FIELD_OPTIONAL"; + if (field.requested) out << " | FIELD_REQUESTED"; + if (BoxedType(field_type)) out << " | FIELD_BOXED"; + out << ">"; + return out.str(); +} + // src_file is path to .capnp file to generate stub code from. // // src_prefix can be used to generate outputs in a different directory than the @@ -332,6 +435,13 @@ static void Generate(kj::StringPtr src_prefix, if (node.getProto().isStruct()) { const auto& struc = node.asStruct(); + + FieldList fields; + for (const auto schema_field : struc.getFields()) { + fields.addField(schema_field, true, true); + } + fields.mergeFields(); + std::ostringstream generic_name; generic_name << node_name; dec << "template<"; @@ -352,22 +462,18 @@ static void Generate(kj::StringPtr src_prefix, dec << "struct ProxyStruct<" << message_namespace << "::" << generic_name.str() << ">\n"; dec << "{\n"; dec << " using Struct = " << message_namespace << "::" << generic_name.str() << ";\n"; - for (const auto field : struc.getFields()) { - auto field_name = field.getProto().getName(); + for (const auto& field : fields.fields) { + auto field_name = field.param.getProto().getName(); add_accessor(field_name); - dec << " using " << Cap(field_name) << "Accessor = Accessor<" << base_name - << "_fields::" << Cap(field_name) << ", FIELD_IN | FIELD_OUT"; - if (BoxedType(field.getType())) dec << " | FIELD_BOXED"; - dec << ">;\n"; + dec << " using " << Cap(field_name) << "Accessor = " + << AccessorType(base_name, field) << ";\n"; } dec << " using Accessors = std::tuple<"; size_t i = 0; - for (const auto field : struc.getFields()) { - if (AnnotationExists(field.getProto(), SKIP_ANNOTATION_ID)) { - continue; - } + for (const auto& field : fields.fields) { + if (field.skip) continue; if (i) dec << ", "; - dec << Cap(field.getProto().getName()) << "Accessor"; + dec << Cap(field.param.getProto().getName()) << "Accessor"; ++i; } dec << ">;\n"; @@ -381,13 +487,11 @@ static void Generate(kj::StringPtr src_prefix, inl << "public:\n"; inl << " using Struct = " << message_namespace << "::" << node_name << ";\n"; size_t i = 0; - for (const auto field : struc.getFields()) { - if (AnnotationExists(field.getProto(), SKIP_ANNOTATION_ID)) { - continue; - } - auto field_name = field.getProto().getName(); + for (const auto& field : fields.fields) { + if (field.skip) continue; + auto field_name = field.param.getProto().getName(); auto member_name = field_name; - GetAnnotationText(field.getProto(), NAME_ANNOTATION_ID, &member_name); + GetAnnotationText(field.param.getProto(), NAME_ANNOTATION_ID, &member_name); inl << " static decltype(auto) get(std::integral_constant) { return " << "&" << proxied_class_type << "::" << member_name << "; }\n"; ++i; @@ -430,85 +534,14 @@ static void Generate(kj::StringPtr src_prefix, const bool is_construct = method_name == kj::StringPtr{"construct"}; const bool is_destroy = method_name == kj::StringPtr{"destroy"}; - struct Field - { - ::capnp::StructSchema::Field param; - bool param_is_set = false; - ::capnp::StructSchema::Field result; - bool result_is_set = false; - int args = 0; - bool retval = false; - bool optional = false; - bool requested = false; - bool skip = false; - kj::StringPtr exception; - }; - - std::vector fields; - std::map field_idx; // name -> args index - bool has_result = false; - - auto add_field = [&](const ::capnp::StructSchema::Field& schema_field, bool param) { - if (AnnotationExists(schema_field.getProto(), SKIP_ANNOTATION_ID)) { - return; - } - - auto field_name = schema_field.getProto().getName(); - auto inserted = field_idx.emplace(field_name, fields.size()); - if (inserted.second) { - fields.emplace_back(); - } - auto& field = fields[inserted.first->second]; - if (param) { - field.param = schema_field; - field.param_is_set = true; - } else { - field.result = schema_field; - field.result_is_set = true; - } - - if (!param && field_name == kj::StringPtr{"result"}) { - field.retval = true; - has_result = true; - } - - GetAnnotationText(schema_field.getProto(), EXCEPTION_ANNOTATION_ID, &field.exception); - - int32_t count = 1; - if (!GetAnnotationInt32(schema_field.getProto(), COUNT_ANNOTATION_ID, &count)) { - if (schema_field.getType().isStruct()) { - GetAnnotationInt32(schema_field.getType().asStruct().getProto(), - COUNT_ANNOTATION_ID, &count); - } else if (schema_field.getType().isInterface()) { - GetAnnotationInt32(schema_field.getType().asInterface().getProto(), - COUNT_ANNOTATION_ID, &count); - } - } - - - if (inserted.second && !field.retval && !field.exception.size()) { - field.args = count; - } - }; - + FieldList fields; for (const auto schema_field : method.getParamType().getFields()) { - add_field(schema_field, true); + fields.addField(schema_field, true, false); } for (const auto schema_field : method.getResultType().getFields()) { - add_field(schema_field, false); - } - for (auto& field : field_idx) { - auto has_field = field_idx.find("has" + Cap(field.first)); - if (has_field != field_idx.end()) { - fields[has_field->second].skip = true; - fields[field.second].optional = true; - } - auto want_field = field_idx.find("want" + Cap(field.first)); - if (want_field != field_idx.end() && fields[want_field->second].param_is_set) { - fields[want_field->second].skip = true; - fields[field.second].requested = true; - } + fields.addField(schema_field, false, true); } + fields.mergeFields(); if (!is_construct && !is_destroy && (&method_interface == &interface)) { methods << "template<>\n"; @@ -524,25 +557,11 @@ static void Generate(kj::StringPtr src_prefix, std::ostringstream server_invoke_start; std::ostringstream server_invoke_end; int argc = 0; - for (const auto& field : fields) { + for (const auto& field : fields.fields) { if (field.skip) continue; const auto& f = field.param_is_set ? field.param : field.result; auto field_name = f.getProto().getName(); - auto field_type = f.getType(); - - std::ostringstream field_flags; - if (!field.param_is_set) { - field_flags << "FIELD_OUT"; - } else if (field.result_is_set) { - field_flags << "FIELD_IN | FIELD_OUT"; - } else { - field_flags << "FIELD_IN"; - } - if (field.optional) field_flags << " | FIELD_OPTIONAL"; - if (field.requested) field_flags << " | FIELD_REQUESTED"; - if (BoxedType(field_type)) field_flags << " | FIELD_BOXED"; - add_accessor(field_name); std::ostringstream fwd_args; @@ -569,8 +588,7 @@ static void Generate(kj::StringPtr src_prefix, client_invoke << "MakeClientParam<"; } - client_invoke << "Accessor<" << base_name << "_fields::" << Cap(field_name) << ", " - << field_flags.str() << ">>("; + client_invoke << AccessorType(base_name, field) << ">("; if (field.retval) { client_invoke << field_name; @@ -586,8 +604,7 @@ static void Generate(kj::StringPtr src_prefix, } else { server_invoke_start << "MakeServerField<" << field.args; } - server_invoke_start << ", Accessor<" << base_name << "_fields::" << Cap(field_name) << ", " - << field_flags.str() << ">>("; + server_invoke_start << ", " << AccessorType(base_name, field) << ">("; server_invoke_end << ")"; } @@ -603,12 +620,12 @@ static void Generate(kj::StringPtr src_prefix, def_client << "ProxyClient<" << message_namespace << "::" << node_name << ">::M" << method_ordinal << "::Result ProxyClient<" << message_namespace << "::" << node_name << ">::" << method_name << "(" << super_str << client_args.str() << ") {\n"; - if (has_result) { + if (fields.has_result) { def_client << " typename M" << method_ordinal << "::Result result;\n"; } def_client << " clientInvoke(" << self_str << ", &" << message_namespace << "::" << node_name << "::Client::" << method_name << "Request" << client_invoke.str() << ");\n"; - if (has_result) def_client << " return result;\n"; + if (fields.has_result) def_client << " return result;\n"; def_client << "}\n"; server << " kj::Promise " << method_name << "(" << Cap(method_name) diff --git a/test/mp/test/foo-types.h b/test/mp/test/foo-types.h index 735adb7..bd5565a 100644 --- a/test/mp/test/foo-types.h +++ b/test/mp/test/foo-types.h @@ -20,6 +20,7 @@ #include #include #include +#include #include #include #include diff --git a/test/mp/test/foo.capnp b/test/mp/test/foo.capnp index 5bdee25..99e918a 100644 --- a/test/mp/test/foo.capnp +++ b/test/mp/test/foo.capnp @@ -55,6 +55,8 @@ struct FooStruct $Proxy.wrap("mp::test::FooStruct") { name @0 :Text; setint @1 :List(Int32); vbool @2 :List(Bool); + optionalInt @3 :Int32 $Proxy.name("optional_int"); + hasOptionalInt @4 :Bool; } struct FooCustom $Proxy.wrap("mp::test::FooCustom") { diff --git a/test/mp/test/foo.h b/test/mp/test/foo.h index 317af02..4d52fd0 100644 --- a/test/mp/test/foo.h +++ b/test/mp/test/foo.h @@ -9,6 +9,7 @@ #include #include #include +#include #include #include #include @@ -21,6 +22,7 @@ struct FooStruct std::string name; std::set setint; std::vector vbool; + std::optional optional_int; }; enum class FooEnum : uint8_t { ONE = 1, TWO = 2, }; diff --git a/test/mp/test/test.cpp b/test/mp/test/test.cpp index bf41663..f5f175e 100644 --- a/test/mp/test/test.cpp +++ b/test/mp/test/test.cpp @@ -140,6 +140,7 @@ KJ_TEST("Call FooInterface methods") in.vbool.push_back(false); in.vbool.push_back(true); in.vbool.push_back(false); + in.optional_int = 3; FooStruct out = foo->pass(in); KJ_EXPECT(in.name == out.name); KJ_EXPECT(in.setint.size() == out.setint.size()); @@ -150,6 +151,12 @@ KJ_TEST("Call FooInterface methods") for (size_t i = 0; i < in.vbool.size(); ++i) { KJ_EXPECT(in.vbool[i] == out.vbool[i]); } + KJ_EXPECT(in.optional_int == out.optional_int); + + // Additional checks for std::optional member + KJ_EXPECT(foo->pass(in).optional_int == 3); + in.optional_int.reset(); + KJ_EXPECT(!foo->pass(in).optional_int); FooStruct err; try {