diff --git a/buildSrc/build.gradle.kts b/buildSrc/build.gradle.kts index 42426cc9..ada8d837 100644 --- a/buildSrc/build.gradle.kts +++ b/buildSrc/build.gradle.kts @@ -1,4 +1,3 @@ - plugins { `kotlin-dsl` } diff --git a/kmp-grpc-core/src/androidJvmCommon/kotlin/io/github/timortel/kmpgrpc/core/message/Message.kt b/kmp-grpc-core/src/androidJvmCommon/kotlin/io/github/timortel/kmpgrpc/core/message/Message.kt index 10a461fb..ed4d1fed 100644 --- a/kmp-grpc-core/src/androidJvmCommon/kotlin/io/github/timortel/kmpgrpc/core/message/Message.kt +++ b/kmp-grpc-core/src/androidJvmCommon/kotlin/io/github/timortel/kmpgrpc/core/message/Message.kt @@ -11,6 +11,8 @@ actual interface Message { actual val fullName: String + actual val isInitialized: Boolean + actual fun serialize(): ByteArray { val buffer = Buffer() serialize(CodedOutputStreamImpl(buffer)) diff --git a/kmp-grpc-core/src/commonMain/kotlin/io/github/timortel/kmpgrpc/core/UninitializedMessageException.kt b/kmp-grpc-core/src/commonMain/kotlin/io/github/timortel/kmpgrpc/core/UninitializedMessageException.kt new file mode 100644 index 00000000..3a825fc3 --- /dev/null +++ b/kmp-grpc-core/src/commonMain/kotlin/io/github/timortel/kmpgrpc/core/UninitializedMessageException.kt @@ -0,0 +1,21 @@ +package io.github.timortel.kmpgrpc.core + +import io.github.timortel.kmpgrpc.core.message.Message + +/** + * Thrown when a Protocol Buffers message is missing one or more required fields. + * + * In `proto2`, fields marked as `required` must be populated before a message + * can be fully initialized or serialized. This exception typically occurs during + * a DSL `build()` operation or when parsing a message that violates these + * presence constraints. + * + * @property msg The incomplete [Message] instance that triggered this exception. + * Note that accessing fields on this instance is safe, but it is considered + * semantically invalid according to the schema. + */ +class UninitializedMessageException( + val msg: Message, +) : RuntimeException( + "Message ${msg::class.simpleName} is missing required fields." +) diff --git a/kmp-grpc-core/src/commonMain/kotlin/io/github/timortel/kmpgrpc/core/io/CodedInputStream.kt b/kmp-grpc-core/src/commonMain/kotlin/io/github/timortel/kmpgrpc/core/io/CodedInputStream.kt index d6f0005d..3ff3187a 100644 --- a/kmp-grpc-core/src/commonMain/kotlin/io/github/timortel/kmpgrpc/core/io/CodedInputStream.kt +++ b/kmp-grpc-core/src/commonMain/kotlin/io/github/timortel/kmpgrpc/core/io/CodedInputStream.kt @@ -64,6 +64,15 @@ abstract class CodedInputStream { return recursiveRead { deserializer.deserialize(this, extensionRegistry) } } + fun readGroup(deserializer: MessageDeserializer, extensionRegistry: ExtensionRegistry, fieldNumber: Int): M { + checkRecursionLimit() + recursionDepth++ + val message = deserializer.deserialize(this, extensionRegistry) + checkLastTagWas(wireFormatMakeTag(fieldNumber, WireFormat.END_GROUP)) + recursionDepth-- + return message + } + abstract fun readBytes(): ByteArray abstract fun readUInt32(): UInt @@ -89,6 +98,7 @@ abstract class CodedInputStream { extensionRegistry: ExtensionRegistry ): UnknownFieldOrExtension? { val number = wireFormatGetTagFieldNumber(tag) + if (wireFormatGetTagWireType(tag) == WireFormat.END_GROUP.value) return null val extension = extensionRegistry.getExtensionForFieldNumber(number) @Suppress("UNCHECKED_CAST") diff --git a/kmp-grpc-core/src/commonMain/kotlin/io/github/timortel/kmpgrpc/core/io/CodedOutputStream.kt b/kmp-grpc-core/src/commonMain/kotlin/io/github/timortel/kmpgrpc/core/io/CodedOutputStream.kt index d6e5ac89..66c6713d 100644 --- a/kmp-grpc-core/src/commonMain/kotlin/io/github/timortel/kmpgrpc/core/io/CodedOutputStream.kt +++ b/kmp-grpc-core/src/commonMain/kotlin/io/github/timortel/kmpgrpc/core/io/CodedOutputStream.kt @@ -77,8 +77,12 @@ interface CodedOutputStream { fun writeMessage(fieldNumber: Int, value: Message) + fun writeGroup(fieldNumber: Int, value: Message) + fun writeMessageArray(fieldNumber: Int, values: List) + fun writeGroupArray(fieldNumber: Int, values: List) + fun writeSFixed32(fieldNumber: Int, value: Int) fun writeSFixed32Array(fieldNumber: Int, values: List, tag: UInt) diff --git a/kmp-grpc-core/src/commonMain/kotlin/io/github/timortel/kmpgrpc/core/io/internal/CodedOutputStreamImpl.kt b/kmp-grpc-core/src/commonMain/kotlin/io/github/timortel/kmpgrpc/core/io/internal/CodedOutputStreamImpl.kt index b68f331f..e11690ef 100644 --- a/kmp-grpc-core/src/commonMain/kotlin/io/github/timortel/kmpgrpc/core/io/internal/CodedOutputStreamImpl.kt +++ b/kmp-grpc-core/src/commonMain/kotlin/io/github/timortel/kmpgrpc/core/io/internal/CodedOutputStreamImpl.kt @@ -217,6 +217,16 @@ internal class CodedOutputStreamImpl(private val sink: Sink) : CodedOutputStream values.forEach { writeMessage(fieldNumber, it) } } + override fun writeGroup(fieldNumber: Int, value: Message) { + writeTag(fieldNumber, WireFormat.START_GROUP) + value.serialize(this) + writeTag(fieldNumber, WireFormat.END_GROUP) + } + + override fun writeGroupArray(fieldNumber: Int, values: List) { + values.forEach { writeGroup(fieldNumber, it) } + } + override fun writeMap( fieldNumber: Int, map: Map, diff --git a/kmp-grpc-core/src/commonMain/kotlin/io/github/timortel/kmpgrpc/core/message/Message.kt b/kmp-grpc-core/src/commonMain/kotlin/io/github/timortel/kmpgrpc/core/message/Message.kt index 385f59d0..41729222 100644 --- a/kmp-grpc-core/src/commonMain/kotlin/io/github/timortel/kmpgrpc/core/message/Message.kt +++ b/kmp-grpc-core/src/commonMain/kotlin/io/github/timortel/kmpgrpc/core/message/Message.kt @@ -17,6 +17,11 @@ expect interface Message { */ val fullName: String + /** + * If all required fields for this message have been set. + */ + val isInitialized: Boolean + /** * Serializes this message and returns it as a [ByteArray]. * diff --git a/kmp-grpc-core/src/commonMain/kotlin/io/github/timortel/kmpgrpc/core/message/util.kt b/kmp-grpc-core/src/commonMain/kotlin/io/github/timortel/kmpgrpc/core/message/util.kt index 74141bbf..920061a2 100644 --- a/kmp-grpc-core/src/commonMain/kotlin/io/github/timortel/kmpgrpc/core/message/util.kt +++ b/kmp-grpc-core/src/commonMain/kotlin/io/github/timortel/kmpgrpc/core/message/util.kt @@ -9,7 +9,7 @@ fun mergeUnknownFieldOrExtension( fieldOrExtension: UnknownFieldOrExtension?, unknownFields: MutableList, extensionBuilder: MessageExtensionsBuilder -) { +): Boolean { when (fieldOrExtension) { is UnknownFieldOrExtension.UnknownField -> unknownFields.add(fieldOrExtension.field) is UnknownFieldOrExtension.RepeatedExtension -> { @@ -18,6 +18,8 @@ fun mergeUnknownFieldOrExtension( is UnknownFieldOrExtension.ScalarExtension -> { extensionBuilder[fieldOrExtension.extension] = fieldOrExtension.value } - null -> {} + null -> return false } + + return true } diff --git a/kmp-grpc-core/src/jsTargetCommon/kotlin/io/github/timortel/kmpgrpc/core/message/Message.kt b/kmp-grpc-core/src/jsTargetCommon/kotlin/io/github/timortel/kmpgrpc/core/message/Message.kt index ff8ccc0c..d9f8d184 100644 --- a/kmp-grpc-core/src/jsTargetCommon/kotlin/io/github/timortel/kmpgrpc/core/message/Message.kt +++ b/kmp-grpc-core/src/jsTargetCommon/kotlin/io/github/timortel/kmpgrpc/core/message/Message.kt @@ -11,6 +11,8 @@ actual interface Message { actual val requiredSize: Int + actual val isInitialized: Boolean + actual fun serialize(): ByteArray { val buffer = Buffer() serialize(CodedOutputStreamImpl(buffer)) diff --git a/kmp-grpc-core/src/nativeMain/kotlin/io/github/timortel/kmpgrpc/core/message/Message.kt b/kmp-grpc-core/src/nativeMain/kotlin/io/github/timortel/kmpgrpc/core/message/Message.kt index ff8ccc0c..d9f8d184 100644 --- a/kmp-grpc-core/src/nativeMain/kotlin/io/github/timortel/kmpgrpc/core/message/Message.kt +++ b/kmp-grpc-core/src/nativeMain/kotlin/io/github/timortel/kmpgrpc/core/message/Message.kt @@ -11,6 +11,8 @@ actual interface Message { actual val requiredSize: Int + actual val isInitialized: Boolean + actual fun serialize(): ByteArray { val buffer = Buffer() serialize(CodedOutputStreamImpl(buffer)) diff --git a/kmp-grpc-internal-test/build.gradle.kts b/kmp-grpc-internal-test/build.gradle.kts index 494bb502..4dc58cd9 100644 --- a/kmp-grpc-internal-test/build.gradle.kts +++ b/kmp-grpc-internal-test/build.gradle.kts @@ -114,11 +114,16 @@ kmpGrpc { includeWellKnownTypes = true - protoSourceFolders = project.files("src/commonMain/proto/general", "src/commonMain/proto/unknownfield", "src/commonMain/proto/editions") + protoSourceFolders = project.files( + "src/commonMain/proto/general", + "src/commonMain/proto/unknownfield", + "src/commonMain/proto/editions", + "src/commonMain/proto/proto2" + ) } buildConfig { - packageName("iio.github.timortel.kmpgrpc.internal.test") + packageName("io.github.timortel.kmpgrpc.internal.test") useKotlinOutput { internalVisibility = true diff --git a/kmp-grpc-internal-test/src/commonMain/proto/proto2/proto2-group-test.proto b/kmp-grpc-internal-test/src/commonMain/proto/proto2/proto2-group-test.proto new file mode 100644 index 00000000..74d3d8fe --- /dev/null +++ b/kmp-grpc-internal-test/src/commonMain/proto/proto2/proto2-group-test.proto @@ -0,0 +1,43 @@ +syntax = "proto2"; + +package io.github.timortel.kmpgrpc.test.proto2; + +option java_outer_classname = "Proto2GroupTest"; + +message A { + optional group B = 1 { + optional string field1 = 1; + optional int32 field2 = 2; + + optional group C = 3 { + optional string field1 = 1; + + optional string field2 = 2; + } + } + + repeated group D = 2 { + optional string field1 = 1; + optional int32 field2 = 2; + } +} + +message E { + optional string field1 = 1; + optional group G = 2 { + optional string field1 = 1; + extensions 2 to max; + } + + extensions 3 to max; +} + +extend E { + optional group F = 4 { + optional string field1 = 1; + } +} + +extend E.G { + optional string field2 = 2; +} diff --git a/kmp-grpc-internal-test/src/commonMain/proto/proto2/proto2-required-fields.proto b/kmp-grpc-internal-test/src/commonMain/proto/proto2/proto2-required-fields.proto new file mode 100644 index 00000000..02099df0 --- /dev/null +++ b/kmp-grpc-internal-test/src/commonMain/proto/proto2/proto2-required-fields.proto @@ -0,0 +1,22 @@ +syntax = "proto2"; + +package io.github.timortel.kmpgrpc.test.proto2; + +option java_outer_classname = "Proto2RequiredFields"; + +message Proto2MessageWithMixedFields { + required string field1 = 1; + optional int32 field2 = 2; +} + +message Proto2MessageWithRequiredFields { + required string field1 = 1; + required Proto2MessageWithMixedFields field2 = 2; + repeated Proto2MessageWithRequiredFields field3 = 3; + map field4 = 4; + + oneof x { + Proto2MessageWithMixedFields field5 = 5; + string field6 = 6; + } +} diff --git a/kmp-grpc-internal-test/src/commonMain/proto/proto2/proto2-test-service.proto b/kmp-grpc-internal-test/src/commonMain/proto/proto2/proto2-test-service.proto new file mode 100644 index 00000000..c7db32b3 --- /dev/null +++ b/kmp-grpc-internal-test/src/commonMain/proto/proto2/proto2-test-service.proto @@ -0,0 +1,11 @@ +syntax = "proto2"; + +package io.github.timortel.kmpgrpc.test.proto2; + +import "proto2-group-test.proto"; + +option java_multiple_files = true; + +service Proto2TestService { + rpc sendMessageWithNestedGroups (A) returns (A); +} diff --git a/kmp-grpc-internal-test/src/commonTest/kotlin/io/github/timortel/kotlin_multiplatform_grpc_plugin/test/defaults.kt b/kmp-grpc-internal-test/src/commonTest/kotlin/io/github/timortel/kotlin_multiplatform_grpc_plugin/test/defaults.kt index f4bdc570..71fe0d36 100644 --- a/kmp-grpc-internal-test/src/commonTest/kotlin/io/github/timortel/kotlin_multiplatform_grpc_plugin/test/defaults.kt +++ b/kmp-grpc-internal-test/src/commonTest/kotlin/io/github/timortel/kotlin_multiplatform_grpc_plugin/test/defaults.kt @@ -3,6 +3,7 @@ package io.github.timortel.kotlin_multiplatform_grpc_plugin.test import ExtensionsTest import io.github.timortel.kmpgrpc.core.message.extensions.buildExtensions import io.github.timortel.kmpgrpc.test.* +import io.github.timortel.kmpgrpc.test.proto2.Proto2GroupTest fun createScalarMessage() = scalarTypes { field1 = "Test" @@ -102,11 +103,13 @@ fun createMessageWithAllExtensions() = ExtensionsTest.MessageWithEveryExtension( set(ExtensionsTest.field33, listOf(0uL, 134uL, 353111345134uL)) set(ExtensionsTest.field34, listOf(-14, 0, 1241522)) set(ExtensionsTest.field35, listOf(-154L, 0L, 4514124121L)) - set(ExtensionsTest.field36, listOf( - byteArrayOf(0, -127, 127), - byteArrayOf(-123, 1, 2), - byteArrayOf(3, 3, -6) - )) + set( + ExtensionsTest.field36, listOf( + byteArrayOf(0, -127, 127), + byteArrayOf(-123, 1, 2), + byteArrayOf(3, 3, -6) + ) + ) } ) @@ -188,3 +191,32 @@ fun createEditionsNonPackedTypesMessage(): EditionsNonPackedTypesMessage = Editi field12List = field12, field13List = field13, ) + +fun createProto2NestedGroupMessage() = Proto2GroupTest.A( + b = Proto2GroupTest.A.B( + field1 = "text1", + field2 = 2, + c = Proto2GroupTest.A.B.C(field1 = "text3", field2 = "text4") + ), + dList = listOf( + Proto2GroupTest.A.D(field1 = "text5", field2 = 6), + Proto2GroupTest.A.D(field1 = "text7", field2 = 8) + ) +) + +fun createProto2GroupMessageWithExtensions() = Proto2GroupTest.E( + field1 = "text1", + g = Proto2GroupTest.E.G( + field1 = "text2", + extensions = buildExtensions { + set(Proto2GroupTest.field2, "text3") + } + ), + extensions = buildExtensions { + set( + Proto2GroupTest.f, Proto2GroupTest.F( + field1 = "text4" + ) + ) + } +) diff --git a/kmp-grpc-internal-test/src/commonTest/kotlin/io/github/timortel/kotlin_multiplatform_grpc_plugin/test/integration/EditionsRpcTest.kt b/kmp-grpc-internal-test/src/commonTest/kotlin/io/github/timortel/kotlin_multiplatform_grpc_plugin/test/integration/EditionsRpcTest.kt index 87c274c3..9d317d13 100644 --- a/kmp-grpc-internal-test/src/commonTest/kotlin/io/github/timortel/kotlin_multiplatform_grpc_plugin/test/integration/EditionsRpcTest.kt +++ b/kmp-grpc-internal-test/src/commonTest/kotlin/io/github/timortel/kotlin_multiplatform_grpc_plugin/test/integration/EditionsRpcTest.kt @@ -1,6 +1,8 @@ package io.github.timortel.kotlin_multiplatform_grpc_plugin.test.integration import io.github.timortel.kmpgrpc.core.Channel +import io.github.timortel.kmpgrpc.core.Code +import io.github.timortel.kmpgrpc.core.StatusException import io.github.timortel.kmpgrpc.test.EditionsTestServiceStub import io.github.timortel.kmpgrpc.test.editions.EditionsLegacyField import io.github.timortel.kotlin_multiplatform_grpc_plugin.test.createEditionsNonPackedTypesMessage @@ -9,6 +11,7 @@ import io.github.timortel.kotlin_multiplatform_grpc_plugin.test.createMessageWit import kotlinx.coroutines.test.runTest import kotlin.test.Test import kotlin.test.assertEquals +import kotlin.test.assertFailsWith abstract class EditionsRpcTest : ServerTest { @@ -39,9 +42,11 @@ abstract class EditionsRpcTest : ServerTest { @Test fun testLegacyRequiredFieldNoData() = runTest { val message = EditionsLegacyField() - val response = stub.sendLegacyRequiredField(message) + val exception = assertFailsWith { + stub.sendLegacyRequiredField(message) + } - assertEquals(message, response) + assertEquals(Code.INTERNAL, exception.status.code) } @Test diff --git a/kmp-grpc-internal-test/src/commonTest/kotlin/io/github/timortel/kotlin_multiplatform_grpc_plugin/test/integration/RpcTest.kt b/kmp-grpc-internal-test/src/commonTest/kotlin/io/github/timortel/kotlin_multiplatform_grpc_plugin/test/integration/RpcTest.kt index 69ffa9bc..9f803ec4 100644 --- a/kmp-grpc-internal-test/src/commonTest/kotlin/io/github/timortel/kotlin_multiplatform_grpc_plugin/test/integration/RpcTest.kt +++ b/kmp-grpc-internal-test/src/commonTest/kotlin/io/github/timortel/kotlin_multiplatform_grpc_plugin/test/integration/RpcTest.kt @@ -5,7 +5,9 @@ import io.github.timortel.kmpgrpc.core.Code import io.github.timortel.kmpgrpc.core.StatusException import io.github.timortel.kmpgrpc.core.message.UnknownField import io.github.timortel.kmpgrpc.test.* +import io.github.timortel.kmpgrpc.test.proto2.Proto2TestServiceStub import io.github.timortel.kotlin_multiplatform_grpc_plugin.test.createMessageWithAllTypes +import io.github.timortel.kotlin_multiplatform_grpc_plugin.test.createProto2NestedGroupMessage import io.github.timortel.kotlin_multiplatform_grpc_plugin.test.createScalarMessage import kotlinx.coroutines.* import kotlinx.coroutines.flow.Flow @@ -141,6 +143,15 @@ abstract class RpcTest : ServerTest { assertEquals(baseMessage, returnedMessage) } + @Test + fun testSendNestedGroupMessage() = runTest { + val message = createProto2NestedGroupMessage() + + val stub = Proto2TestServiceStub(channel) + val returnedMessage = stub.sendMessageWithNestedGroups(message) + assertEquals(message, returnedMessage) + } + @Test fun testFailedRpcThrowsKmStatusException() = runTest { val message = simpleMessage { } diff --git a/kmp-grpc-internal-test/src/commonTest/kotlin/io/github/timortel/kotlin_multiplatform_grpc_plugin/test/model/EqTest.kt b/kmp-grpc-internal-test/src/commonTest/kotlin/io/github/timortel/kotlin_multiplatform_grpc_plugin/test/model/EqTest.kt index c53568ae..f6f23f6d 100644 --- a/kmp-grpc-internal-test/src/commonTest/kotlin/io/github/timortel/kotlin_multiplatform_grpc_plugin/test/model/EqTest.kt +++ b/kmp-grpc-internal-test/src/commonTest/kotlin/io/github/timortel/kotlin_multiplatform_grpc_plugin/test/model/EqTest.kt @@ -5,9 +5,12 @@ import io.github.timortel.kmpgrpc.core.message.UnknownField import io.github.timortel.kmpgrpc.core.message.extensions.buildExtensions import io.github.timortel.kmpgrpc.test.Unknownfield import io.github.timortel.kmpgrpc.test.emptyMessage +import io.github.timortel.kmpgrpc.test.proto2.Proto2GroupTest import io.github.timortel.kmpgrpc.test.simpleMessage import io.github.timortel.kotlin_multiplatform_grpc_plugin.test.createMessageWithAllExtensions import io.github.timortel.kotlin_multiplatform_grpc_plugin.test.createMessageWithAllTypes +import io.github.timortel.kotlin_multiplatform_grpc_plugin.test.createProto2GroupMessageWithExtensions +import io.github.timortel.kotlin_multiplatform_grpc_plugin.test.createProto2NestedGroupMessage import io.github.timortel.kotlin_multiplatform_grpc_plugin.test.createScalarMessage import kotlin.test.Test import kotlin.test.assertEquals @@ -98,4 +101,36 @@ class EqTest { assertNotEquals(msg1, msg2) } + + @Test + fun messageWithNestedGroupsEqual() { + val msg1 = createProto2NestedGroupMessage() + val msg2 = createProto2NestedGroupMessage() + + assertEquals(msg1, msg2) + } + + @Test + fun messageWithNestedGroupsDiffer() { + val msg1 = createProto2NestedGroupMessage() + val msg2 = Proto2GroupTest.A() + + assertNotEquals(msg1, msg2) + } + + @Test + fun messageWithGroupMessageExtensionsEqual() { + val msg1 = createProto2GroupMessageWithExtensions() + val msg2 = createProto2GroupMessageWithExtensions() + + assertEquals(msg1, msg2) + } + + @Test + fun messageWithGroupMessageExtensionsDiffer() { + val msg1 = createProto2GroupMessageWithExtensions() + val msg2 = msg1.copy(extensions = buildExtensions { }) + + assertNotEquals(msg1, msg2) + } } diff --git a/kmp-grpc-internal-test/src/commonTest/kotlin/io/github/timortel/kotlin_multiplatform_grpc_plugin/test/model/IsInitializedTest.kt b/kmp-grpc-internal-test/src/commonTest/kotlin/io/github/timortel/kotlin_multiplatform_grpc_plugin/test/model/IsInitializedTest.kt new file mode 100644 index 00000000..7975686c --- /dev/null +++ b/kmp-grpc-internal-test/src/commonTest/kotlin/io/github/timortel/kotlin_multiplatform_grpc_plugin/test/model/IsInitializedTest.kt @@ -0,0 +1,73 @@ +package io.github.timortel.kotlin_multiplatform_grpc_plugin.test.model + +import io.github.timortel.kmpgrpc.test.proto2.Proto2RequiredFields +import io.github.timortel.kmpgrpc.test.proto2.Proto2RequiredFields.Proto2MessageWithRequiredFields +import kotlin.test.Test +import kotlin.test.assertFalse +import kotlin.test.assertTrue + +class IsInitializedTest { + + @Test + fun testDefaultIsInitialized() { + // invoke() uses default values that satisfy required fields + val msg = Proto2MessageWithRequiredFields() + assertTrue(msg.isInitialized, "Default message should be initialized") + } + + @Test + fun testMissingLocalRequiredField() { + // field1 is required. Passing null via createPartial should make it uninitialized. + val msg = Proto2MessageWithRequiredFields.createPartial(field1 = null) + assertFalse(msg.isInitialized, "Message should be uninitialized if local required field is missing") + } + + @Test + fun testUninitializedNestedMessage() { + // field2 is set, but the nested message itself is missing its own required field1 + val incompleteNested = Proto2RequiredFields.Proto2MessageWithMixedFields.createPartial(field1 = null) + val msg = Proto2MessageWithRequiredFields.createPartial( + field1 = "valid", + field2 = incompleteNested + ) + + assertFalse(msg.isInitialized, "Message should be uninitialized if a nested message is uninitialized") + } + + @Test + fun testUninitializedMessageInList() { + val incomplete = Proto2MessageWithRequiredFields.createPartial(field1 = null) + val msg = Proto2MessageWithRequiredFields( + field3List = listOf(incomplete) + ) + + assertFalse(msg.isInitialized, "Message should be uninitialized if any element in a repeated field is uninitialized") + } + + @Test + fun testUninitializedMessageInMap() { + val incomplete = Proto2MessageWithRequiredFields.createPartial(field1 = null) + val msg = Proto2MessageWithRequiredFields( + field4Map = mapOf("key" to incomplete) + ) + + assertFalse(msg.isInitialized, "Message should be uninitialized if any value in a map is uninitialized") + } + + @Test + fun testOneOfInitialization() { + // x.field5 is a message type. If that message is incomplete, the parent is incomplete. + val incompleteMixed = Proto2RequiredFields.Proto2MessageWithMixedFields.createPartial(field1 = null) + val msg = Proto2MessageWithRequiredFields( + x = Proto2MessageWithRequiredFields.X.Field5(incompleteMixed) + ) + + assertFalse(msg.isInitialized, "Message should be uninitialized if a message inside a OneOf is uninitialized") + + // x.field6 is a string (primitive-like), so it's always considered initialized if the case is set + val msg2 = Proto2MessageWithRequiredFields( + x = Proto2MessageWithRequiredFields.X.Field6("hello") + ) + assertTrue(msg2.isInitialized, "Message should be initialized if OneOf contains a valid string") + } +} diff --git a/kmp-grpc-internal-test/src/commonTest/kotlin/io/github/timortel/kotlin_multiplatform_grpc_plugin/test/model/UninitializedBuilderTest.kt b/kmp-grpc-internal-test/src/commonTest/kotlin/io/github/timortel/kotlin_multiplatform_grpc_plugin/test/model/UninitializedBuilderTest.kt new file mode 100644 index 00000000..c35c44f4 --- /dev/null +++ b/kmp-grpc-internal-test/src/commonTest/kotlin/io/github/timortel/kotlin_multiplatform_grpc_plugin/test/model/UninitializedBuilderTest.kt @@ -0,0 +1,88 @@ +package io.github.timortel.kotlin_multiplatform_grpc_plugin.test.model + +import io.github.timortel.kmpgrpc.core.UninitializedMessageException +import io.github.timortel.kmpgrpc.test.proto2.Proto2RequiredFields +import io.github.timortel.kmpgrpc.test.proto2.proto2MessageWithMixedFields +import io.github.timortel.kmpgrpc.test.proto2.proto2MessageWithRequiredFields +import kotlin.test.Test +import kotlin.test.assertFailsWith +import kotlin.test.assertNotNull + +class UninitializedBuilderTest { + + @Test + fun testSuccessfulBuild() { + // Should not throw because all required fields are set + val msg = proto2MessageWithRequiredFields { + field1 = "top level" + field2 = proto2MessageWithMixedFields { + field1 = "nested required" + } + } + assertNotNull(msg) + } + + @Test + fun testMissingTopLevelRequiredField() { + assertFailsWith("Should throw if top-level required field1 is missing") { + proto2MessageWithRequiredFields { + // field1 is missing + field2 = proto2MessageWithMixedFields { field1 = "valid" } + } + } + } + + @Test + fun testMissingNestedRequiredField() { + assertFailsWith("Should throw if a required field inside field2 is missing") { + proto2MessageWithRequiredFields { + field1 = "valid" + field2 = proto2MessageWithMixedFields { + // field1 is required in MixedFields but missing here + field2 = 123 + } + } + } + } + + @Test + fun testUninitializedInList() { + assertFailsWith("Should throw if an element in the list is uninitialized") { + proto2MessageWithRequiredFields { + field1 = "valid" + field2 = proto2MessageWithMixedFields { field1 = "valid" } + + // Add an incomplete message to the list + field3List.add(Proto2RequiredFields.Proto2MessageWithRequiredFields.createPartial(field1 = null)) + } + } + } + + @Test + fun testUninitializedInMap() { + assertFailsWith("Should throw if a map value is uninitialized") { + proto2MessageWithRequiredFields { + field1 = "valid" + field2 = proto2MessageWithMixedFields { field1 = "valid" } + + // Add an incomplete message to the map + field4Map["key"] = Proto2RequiredFields.Proto2MessageWithRequiredFields.createPartial(field1 = null) + } + } + } + + @Test + fun testUninitializedInOneOf() { + assertFailsWith("Should throw if the chosen OneOf case is uninitialized") { + proto2MessageWithRequiredFields { + field1 = "valid" + field2 = proto2MessageWithMixedFields { field1 = "valid" } + + // x is set to a Field5 which contains an uninitialized message + x = Proto2RequiredFields.Proto2MessageWithRequiredFields.X.Field5( + Proto2RequiredFields.Proto2MessageWithMixedFields.createPartial(field1 = null) + ) + } + } + } +} diff --git a/kmp-grpc-internal-test/src/commonTest/kotlin/io/github/timortel/kotlin_multiplatform_grpc_plugin/test/serialization/RequiredFieldSerializationTests.kt b/kmp-grpc-internal-test/src/commonTest/kotlin/io/github/timortel/kotlin_multiplatform_grpc_plugin/test/serialization/RequiredFieldSerializationTests.kt new file mode 100644 index 00000000..51efa493 --- /dev/null +++ b/kmp-grpc-internal-test/src/commonTest/kotlin/io/github/timortel/kotlin_multiplatform_grpc_plugin/test/serialization/RequiredFieldSerializationTests.kt @@ -0,0 +1,85 @@ +package io.github.timortel.kotlin_multiplatform_grpc_plugin.test.serialization + +import io.github.timortel.kmpgrpc.test.proto2.Proto2RequiredFields.Proto2MessageWithMixedFields +import io.github.timortel.kmpgrpc.test.proto2.Proto2RequiredFields.Proto2MessageWithRequiredFields +import kotlin.test.Test +import kotlin.test.assertEquals +import kotlin.test.assertFalse +import kotlin.test.assertTrue + +class RequiredFieldSerializationTests { + + @Test + fun testPartialMessageRoundTripRemainsUninitialized() { + // 1. Create a partial message missing a required field (field1) + val original = Proto2MessageWithRequiredFields.createPartial( + field1 = null, + field2 = Proto2MessageWithMixedFields("valid") + ) + assertFalse(original.isInitialized, "Original should be uninitialized") + + // 2. Serialize to bytes + val bytes = original.serialize() + + // 3. Deserialize back + val deserialized = Proto2MessageWithRequiredFields.deserialize(bytes) + + // 4. Verify state is preserved + assertFalse(deserialized.isInitialized, "Deserialized message should still be uninitialized") + assertEquals(original.field2, deserialized.field2, "Other data should remain intact") + } + + @Test + fun testNestedUninitializedMessageRoundTrip() { + // 1. Create a message where the parent is "complete" but the child is "partial" + val partialChild = Proto2MessageWithMixedFields.createPartial(field1 = null) + val original = Proto2MessageWithRequiredFields.createPartial( + field1 = "parent-valid", + field2 = partialChild + ) + assertFalse(original.isInitialized, "Parent should be uninitialized because child is uninitialized") + + // 2. Round trip + val bytes = original.serialize() + val deserialized = Proto2MessageWithRequiredFields.deserialize(bytes) + + // 3. Verify + assertFalse(deserialized.isInitialized, "Deserialized parent should still be uninitialized") + assertFalse(deserialized.field2.isInitialized, "Deserialized child should still be uninitialized") + } + + @Test + fun testFullyInitializedRoundTrip() { + // 1. Create a fully valid message + val original = Proto2MessageWithRequiredFields( + field1 = "valid", + field2 = Proto2MessageWithMixedFields(field1 = "nested-valid") + ) + assertTrue(original.isInitialized) + + // 2. Round trip + val bytes = original.serialize() + val deserialized = Proto2MessageWithRequiredFields.deserialize(bytes) + + // 3. Verify + assertTrue(deserialized.isInitialized, "Deserialized message should be fully initialized") + assertEquals("valid", deserialized.field1) + } + + @Test + fun testEmptyRepeatedAndMapRoundTrip() { + // In proto2, empty repeated/map fields are initialized by default + // as long as the local required fields are present. + val original = Proto2MessageWithRequiredFields.createPartial( + field1 = "valid", + field2 = Proto2MessageWithMixedFields("valid"), + field3List = emptyList(), + field4Map = emptyMap() + ) + + assertTrue(original.isInitialized) + + val deserialized = Proto2MessageWithRequiredFields.deserialize(original.serialize()) + assertTrue(deserialized.isInitialized, "Message with empty collections should stay initialized") + } +} diff --git a/kmp-grpc-internal-test/src/commonTest/kotlin/io/github/timortel/kotlin_multiplatform_grpc_plugin/test/serialization/SelfMessageSerializationTest.kt b/kmp-grpc-internal-test/src/commonTest/kotlin/io/github/timortel/kotlin_multiplatform_grpc_plugin/test/serialization/SelfMessageSerializationTest.kt index ee41f7a5..fa3a97c4 100644 --- a/kmp-grpc-internal-test/src/commonTest/kotlin/io/github/timortel/kotlin_multiplatform_grpc_plugin/test/serialization/SelfMessageSerializationTest.kt +++ b/kmp-grpc-internal-test/src/commonTest/kotlin/io/github/timortel/kotlin_multiplatform_grpc_plugin/test/serialization/SelfMessageSerializationTest.kt @@ -16,12 +16,15 @@ import io.github.timortel.kmpgrpc.test.Unknownfield import io.github.timortel.kmpgrpc.test.longMessage import io.github.timortel.kmpgrpc.test.messageWithSubMessage import io.github.timortel.kmpgrpc.test.oneOfMessage +import io.github.timortel.kmpgrpc.test.proto2.Proto2GroupTest import io.github.timortel.kmpgrpc.test.repeatedLongMessage import io.github.timortel.kmpgrpc.test.simpleMessage import io.github.timortel.kotlin_multiplatform_grpc_plugin.test.createComplexRepeated import io.github.timortel.kotlin_multiplatform_grpc_plugin.test.createMessageWithAllExtensions import io.github.timortel.kotlin_multiplatform_grpc_plugin.test.createMessageWithAllTypes import io.github.timortel.kotlin_multiplatform_grpc_plugin.test.createNonPackedTypesMessage +import io.github.timortel.kotlin_multiplatform_grpc_plugin.test.createProto2GroupMessageWithExtensions +import io.github.timortel.kotlin_multiplatform_grpc_plugin.test.createProto2NestedGroupMessage import io.github.timortel.kotlin_multiplatform_grpc_plugin.test.createScalarMessage import kotlin.test.Test import kotlin.test.assertEquals @@ -190,4 +193,22 @@ class SelfMessageSerializationTest { assertEquals(msg, reconstructed) } + + @Test + fun testGroupSerialization() { + val msg = createProto2NestedGroupMessage() + + val reconstructed = Proto2GroupTest.A.deserialize(msg.serialize()) + + assertEquals(msg, reconstructed) + } + + @Test + fun testGroupExtensionSerialization() { + val msg = createProto2GroupMessageWithExtensions() + + val reconstructed = Proto2GroupTest.E.deserialize(msg.serialize()) + + assertEquals(msg, reconstructed) + } } diff --git a/kmp-grpc-internal-test/src/nativeJvmTest/kotlin/io/github/timortel/kotlin_multiplatform_grpc_plugin/test/ClientCredentialsRpcTest.kt b/kmp-grpc-internal-test/src/nativeJvmTest/kotlin/io/github/timortel/kotlin_multiplatform_grpc_plugin/test/ClientCredentialsRpcTest.kt index 6e8cf616..8de0941c 100644 --- a/kmp-grpc-internal-test/src/nativeJvmTest/kotlin/io/github/timortel/kotlin_multiplatform_grpc_plugin/test/ClientCredentialsRpcTest.kt +++ b/kmp-grpc-internal-test/src/nativeJvmTest/kotlin/io/github/timortel/kotlin_multiplatform_grpc_plugin/test/ClientCredentialsRpcTest.kt @@ -1,8 +1,8 @@ package io.github.timortel.kotlin_multiplatform_grpc_plugin.test -import iio.github.timortel.kmpgrpc.internal.test.CA_CERTIFICATE -import iio.github.timortel.kmpgrpc.internal.test.CLIENT_CERTIFICATE -import iio.github.timortel.kmpgrpc.internal.test.CLIENT_KEY +import io.github.timortel.kmpgrpc.internal.test.CA_CERTIFICATE +import io.github.timortel.kmpgrpc.internal.test.CLIENT_CERTIFICATE +import io.github.timortel.kmpgrpc.internal.test.CLIENT_KEY import io.github.timortel.kmpgrpc.core.Certificate import io.github.timortel.kmpgrpc.core.Channel import io.github.timortel.kmpgrpc.core.PrivateKey diff --git a/kmp-grpc-internal-test/src/nativeJvmTest/kotlin/io/github/timortel/kotlin_multiplatform_grpc_plugin/test/CustomCertificatesRpcTest.kt b/kmp-grpc-internal-test/src/nativeJvmTest/kotlin/io/github/timortel/kotlin_multiplatform_grpc_plugin/test/CustomCertificatesRpcTest.kt index dccd27dd..a9be594c 100644 --- a/kmp-grpc-internal-test/src/nativeJvmTest/kotlin/io/github/timortel/kotlin_multiplatform_grpc_plugin/test/CustomCertificatesRpcTest.kt +++ b/kmp-grpc-internal-test/src/nativeJvmTest/kotlin/io/github/timortel/kotlin_multiplatform_grpc_plugin/test/CustomCertificatesRpcTest.kt @@ -1,7 +1,7 @@ package io.github.timortel.kotlin_multiplatform_grpc_plugin.test -import iio.github.timortel.kmpgrpc.internal.test.CA_CERTIFICATE -import iio.github.timortel.kmpgrpc.internal.test.STANDALONE_LEAF_CERTIFICATE +import io.github.timortel.kmpgrpc.internal.test.CA_CERTIFICATE +import io.github.timortel.kmpgrpc.internal.test.STANDALONE_LEAF_CERTIFICATE import io.github.timortel.kmpgrpc.core.Certificate import io.github.timortel.kmpgrpc.core.Channel import io.github.timortel.kmpgrpc.core.StatusException diff --git a/kmp-grpc-internal-test/test-server/build.gradle.kts b/kmp-grpc-internal-test/test-server/build.gradle.kts index 85ef6ff1..9c20c5c9 100644 --- a/kmp-grpc-internal-test/test-server/build.gradle.kts +++ b/kmp-grpc-internal-test/test-server/build.gradle.kts @@ -45,7 +45,7 @@ dependencies { sourceSets { main { proto { - srcDirs("../src/commonMain/proto/general") + srcDirs("../src/commonMain/proto/general", "../src/commonMain/proto/proto2") } kotlin.srcDir(layout.buildDirectory.dir("generated/source/proto/main/grpc")) kotlin.srcDir(layout.buildDirectory.dir("generated/source/proto/main/grpckt")) diff --git a/kmp-grpc-internal-test/test-server/src/main/kotlin/io/github/timortel/kmpgrpc/testserver/TestServer.kt b/kmp-grpc-internal-test/test-server/src/main/kotlin/io/github/timortel/kmpgrpc/testserver/TestServer.kt index 33eecdac..edbbffef 100644 --- a/kmp-grpc-internal-test/test-server/src/main/kotlin/io/github/timortel/kmpgrpc/testserver/TestServer.kt +++ b/kmp-grpc-internal-test/test-server/src/main/kotlin/io/github/timortel/kmpgrpc/testserver/TestServer.kt @@ -1,6 +1,8 @@ package io.github.timortel.kmpgrpc.testserver import io.github.timortel.kmpgrpc.test.* +import io.github.timortel.kmpgrpc.test.proto2.Proto2GroupTest +import io.github.timortel.kmpgrpc.test.proto2.Proto2TestServiceGrpcKt import io.grpc.* import io.grpc.netty.shaded.io.grpc.netty.GrpcSslContexts import io.grpc.netty.shaded.io.grpc.netty.NettyServerBuilder @@ -212,6 +214,11 @@ object TestServer { } } ) + .addService(object : Proto2TestServiceGrpcKt.Proto2TestServiceCoroutineImplBase() { + override suspend fun sendMessageWithNestedGroups(request: Proto2GroupTest.A): Proto2GroupTest.A { + return request + } + }) .intercept( object : ServerInterceptor { override fun interceptCall( diff --git a/kmp-grpc-plugin/build.gradle.kts b/kmp-grpc-plugin/build.gradle.kts index 3f85cd7c..387b40e3 100644 --- a/kmp-grpc-plugin/build.gradle.kts +++ b/kmp-grpc-plugin/build.gradle.kts @@ -1,3 +1,5 @@ +import io.github.timortel.kmpgrpc.plugin.DownloadWellKnownTypesTask + plugins { kotlin("jvm") version libs.versions.kotlin.get() id("java-gradle-plugin") @@ -47,6 +49,10 @@ kotlin { main { kotlin.srcDir(layout.projectDirectory.dir("../kmp-grpc-shared/src/commonMain")) } + + test { + resources.srcDir(layout.buildDirectory.dir("wkt")) + } } jvmToolchain(17) @@ -111,3 +117,11 @@ tasks.withType { exclude("**/Protobuf3BaseVisitor.java") exclude("**/Protobuf3BaseListener.java") } + +val downloadWellKnownTypesTask = tasks.register("downloadWellKnownTypes", DownloadWellKnownTypesTask::class.java) { + outputDir.set(layout.buildDirectory.dir("wkt")) +} + +tasks.named("processTestResources") { + dependsOn(downloadWellKnownTypesTask) +} diff --git a/kmp-grpc-plugin/buildSrc/build.gradle.kts b/kmp-grpc-plugin/buildSrc/build.gradle.kts new file mode 100644 index 00000000..270b3e29 --- /dev/null +++ b/kmp-grpc-plugin/buildSrc/build.gradle.kts @@ -0,0 +1,8 @@ +plugins { + `kotlin-dsl` +} + +repositories { + gradlePluginPortal() + google() +} diff --git a/kmp-grpc-plugin/buildSrc/src/main/java/io/github/timortel/kmpgrpc/plugin/DownloadWellKnownTypesTask.kt b/kmp-grpc-plugin/buildSrc/src/main/java/io/github/timortel/kmpgrpc/plugin/DownloadWellKnownTypesTask.kt new file mode 120000 index 00000000..6ad2929d --- /dev/null +++ b/kmp-grpc-plugin/buildSrc/src/main/java/io/github/timortel/kmpgrpc/plugin/DownloadWellKnownTypesTask.kt @@ -0,0 +1 @@ +../../../../../../../../../src/main/java/io/github/timortel/kmpgrpc/plugin/DownloadWellKnownTypesTask.kt \ No newline at end of file diff --git a/kmp-grpc-plugin/settings.gradle.kts b/kmp-grpc-plugin/settings.gradle.kts index fa8bc749..b5a0fabf 100644 --- a/kmp-grpc-plugin/settings.gradle.kts +++ b/kmp-grpc-plugin/settings.gradle.kts @@ -4,4 +4,4 @@ dependencyResolutionManagement { from(files("../gradle/libs.versions.toml")) } } -} \ No newline at end of file +} diff --git a/kmp-grpc-plugin/src/main/antlr/io/github/timortel/kmpgrpc/anltr/Protobuf2.g4 b/kmp-grpc-plugin/src/main/antlr/io/github/timortel/kmpgrpc/anltr/Protobuf2.g4 new file mode 100644 index 00000000..abbcc6f3 --- /dev/null +++ b/kmp-grpc-plugin/src/main/antlr/io/github/timortel/kmpgrpc/anltr/Protobuf2.g4 @@ -0,0 +1,720 @@ +/** + * A Protocol Buffers 2 grammar + * + * Original source: https://developers.google.com/protocol-buffers/docs/reference/proto2-spec + * + * follow by the style of Protobuf3.g4 written by the author @anatawa12 + * + * @author Boyce-Lee + * + * Direct copy from https://github.com/antlr/grammars-v4/blob/b3ba447883223e376ef003b1d1a32c80b2aad0f1/protobuf/protobuf2/Protobuf2.g4 + * Changes from the source above: + * - Added package header + * - Adapted rpc definition to expose clientStream and serverStream attributes + * - Added field options to group declarations + * - Added support for options on extension declarations. + * @author Tim Ortel + */ + +// $antlr-format alignTrailingComments true, columnLimit 150, minEmptyLines 1, maxEmptyLinesToKeep 1, reflowComments false, useTab false +// $antlr-format allowShortRulesOnASingleLine false, allowShortBlocksOnASingleLine true, alignSemicolons hanging, alignColons hanging + +grammar Protobuf2; + +@header { package io.github.timortel.kmpgrpc.anltr; } + +proto + : syntax? (importStatement | packageStatement | optionStatement | topLevelDef | emptyStatement_)* EOF + ; + +// Syntax + +syntax + : SYNTAX EQ (PROTO2_LIT_SINGLE | PROTO2_LIT_DOUBLE) SEMI + ; + +// Import Statement + +importStatement + : IMPORT (WEAK | PUBLIC)? strLit SEMI + ; + +// Package + +packageStatement + : PACKAGE fullIdent SEMI + ; + +// Option + +optionStatement + : OPTION optionName EQ constant SEMI + ; + +optionName + : fullIdent + | ( ident | LP fullIdent RP) ( DOT fullIdent)? + ; + +// Normal Field + +fieldLabel + : REQUIRED + | OPTIONAL + | REPEATED + ; + +field + : fieldLabel type_ fieldName EQ fieldNumber (LB fieldOptions RB)? SEMI + ; + +fieldOptions + : fieldOption (COMMA fieldOption)* + ; + +fieldOption + : optionName EQ constant + ; + +fieldNumber + : intLit + ; + +// Group field + +group + : fieldLabel GROUP groupName EQ fieldNumber (LB fieldOptions RB)? messageBody + ; + +// Oneof and oneof field + +oneof + : ONEOF oneofName LC (optionStatement | oneofField | emptyStatement_)* RC + ; + +oneofField + : type_ fieldName EQ fieldNumber (LB fieldOptions RB)? SEMI + ; + +// Map field + +mapField + : MAP LT keyType COMMA type_ GT mapName EQ fieldNumber (LB fieldOptions RB)? SEMI + ; + +keyType + : INT32 + | INT64 + | UINT32 + | UINT64 + | SINT32 + | SINT64 + | FIXED32 + | FIXED64 + | SFIXED32 + | SFIXED64 + | BOOL + | STRING + ; + +// field types + +type_ + : DOUBLE + | FLOAT + | INT32 + | INT64 + | UINT32 + | UINT64 + | SINT32 + | SINT64 + | FIXED32 + | FIXED64 + | SFIXED32 + | SFIXED64 + | BOOL + | STRING + | BYTES + | messageType + | enumType + ; + +// Extensions + +extensions + : EXTENSIONS ranges (LB fieldOptions RB)? SEMI + ; + +// Reserved + +reserved + : RESERVED (ranges | reservedFieldNames) SEMI + ; + +ranges + : range_ (COMMA range_)* + ; + +range_ + : intLit (TO ( intLit | MAX))? + ; + +reservedFieldNames + : strLit (COMMA strLit)* + ; + +// Top Level definitions + +topLevelDef + : messageDef + | enumDef + | serviceDef + | extendDef + ; + +// enum + +enumDef + : ENUM enumName enumBody + ; + +enumBody + : LC enumElement* RC + ; + +enumElement + : optionStatement + | enumField + | reserved + | emptyStatement_ + ; + +enumField + : ident EQ MINUS? intLit enumValueOptions? SEMI + ; + +enumValueOptions + : LB enumValueOption (COMMA enumValueOption)* RB + ; + +enumValueOption + : optionName EQ constant + ; + +// message + +messageDef + : MESSAGE messageName messageBody + ; + +messageBody + : LC messageElement* RC + ; + +messageElement + : field + | enumDef + | messageDef + | extendDef + | optionStatement + | oneof + | mapField + | extensions + | group + | reserved + | emptyStatement_ + ; + +// extend + +extendDef + : EXTEND messageType LC extendElement* RC + ; + +extendElement + : field + | group + | emptyStatement_ + ; + +// service + +serviceDef + : SERVICE serviceName LC serviceElement* RC + ; + +serviceElement + : optionStatement + | rpc + | stream + | emptyStatement_ + ; + +rpc + : RPC rpcName LP clientStream=STREAM? messageType RP RETURNS LP serverStream=STREAM? messageType RP ( + LC ( optionStatement | emptyStatement_)* RC + | SEMI + ) + ; + +stream + : STREAM streamName LP messageType COMMA messageType RP ( + LC ( optionStatement | emptyStatement_)* RC + | SEMI + ) + ; + +// lexical + +constant + : fullIdent + | (MINUS | PLUS)? intLit + | ( MINUS | PLUS)? floatLit + | strLit + | boolLit + | blockLit + ; + +// not specified in specification but used in tests +blockLit + : LC (ident COLON constant (COMMA)?)* RC + ; + +emptyStatement_ + : SEMI + ; + +// Lexical elements + +ident + : IDENTIFIER + | keywords + ; + +fullIdent + : ident (DOT ident)* + ; + +messageName + : ident + ; + +enumName + : ident + ; + +fieldName + : ident + ; + +groupName + : ident + ; + +oneofName + : ident + ; + +mapName + : ident + ; + +serviceName + : ident + ; + +rpcName + : ident + ; + +streamName + : ident + ; + +messageType + : DOT? (ident DOT)* messageName + ; + +enumType + : DOT? (ident DOT)* enumName + ; + +intLit + : INT_LIT + ; + +strLit + : STR_LIT+ + | PROTO2_LIT_SINGLE + | PROTO2_LIT_DOUBLE + ; + +boolLit + : BOOL_LIT + ; + +floatLit + : FLOAT_LIT + ; + +// keywords +SYNTAX + : 'syntax' + ; + +IMPORT + : 'import' + ; + +WEAK + : 'weak' + ; + +PUBLIC + : 'public' + ; + +PACKAGE + : 'package' + ; + +OPTION + : 'option' + ; + +REPEATED + : 'repeated' + ; + +OPTIONAL + : 'optional' + ; + +REQUIRED + : 'required' + ; + +GROUP + : 'group' + ; + +ONEOF + : 'oneof' + ; + +MAP + : 'map' + ; + +INT32 + : 'int32' + ; + +INT64 + : 'int64' + ; + +UINT32 + : 'uint32' + ; + +UINT64 + : 'uint64' + ; + +SINT32 + : 'sint32' + ; + +SINT64 + : 'sint64' + ; + +FIXED32 + : 'fixed32' + ; + +FIXED64 + : 'fixed64' + ; + +SFIXED32 + : 'sfixed32' + ; + +SFIXED64 + : 'sfixed64' + ; + +BOOL + : 'bool' + ; + +STRING + : 'string' + ; + +DOUBLE + : 'double' + ; + +FLOAT + : 'float' + ; + +BYTES + : 'bytes' + ; + +RESERVED + : 'reserved' + ; + +EXTENSIONS + : 'extensions' + ; + +TO + : 'to' + ; + +MAX + : 'max' + ; + +ENUM + : 'enum' + ; + +EXTEND + : 'extend' + ; + +MESSAGE + : 'message' + ; + +SERVICE + : 'service' + ; + +RPC + : 'rpc' + ; + +STREAM + : 'stream' + ; + +RETURNS + : 'returns' + ; + +PROTO2_LIT_SINGLE + : '"proto2"' + ; + +PROTO2_LIT_DOUBLE + : '\'proto2\'' + ; + +// symbols + +SEMI + : ';' + ; + +EQ + : '=' + ; + +LP + : '(' + ; + +RP + : ')' + ; + +LB + : '[' + ; + +RB + : ']' + ; + +LC + : '{' + ; + +RC + : '}' + ; + +LT + : '<' + ; + +GT + : '>' + ; + +DOT + : '.' + ; + +COMMA + : ',' + ; + +COLON + : ':' + ; + +PLUS + : '+' + ; + +MINUS + : '-' + ; + +STR_LIT + : '\'' CHAR_VALUE*? '\'' + | '"' CHAR_VALUE*? '"' + ; + +fragment CHAR_VALUE + : HEX_ESCAPE + | OCT_ESCAPE + | CHAR_ESCAPE + | ~[\u0000\n\\] + ; + +fragment HEX_ESCAPE + : '\\' ('x' | 'X') HEX_DIGIT HEX_DIGIT + ; + +fragment OCT_ESCAPE + : '\\' OCTAL_DIGIT OCTAL_DIGIT OCTAL_DIGIT + ; + +fragment CHAR_ESCAPE + : '\\' ('a' | 'b' | 'f' | 'n' | 'r' | 't' | 'v' | '\\' | '\'' | '"') + ; + +BOOL_LIT + : 'true' + | 'false' + ; + +FLOAT_LIT + : DECIMALS DOT DECIMALS? EXPONENT? + | DECIMALS EXPONENT + | DOT DECIMALS EXPONENT? + | 'inf' + | 'nan' + ; + +fragment EXPONENT + : ('e' | 'E') (PLUS | MINUS)? DECIMALS + ; + +fragment DECIMALS + : DECIMAL_DIGIT+ + ; + +INT_LIT + : DECIMAL_LIT + | OCTAL_LIT + | HEX_LIT + ; + +fragment DECIMAL_LIT + : [1-9] DECIMAL_DIGIT* + ; + +fragment OCTAL_LIT + : '0' OCTAL_DIGIT* + ; + +fragment HEX_LIT + : '0' ('x' | 'X') HEX_DIGIT+ + ; + +IDENTIFIER + : LETTER (LETTER | DECIMAL_DIGIT)* + ; + +fragment LETTER + : [A-Za-z_] + ; + +fragment DECIMAL_DIGIT + : [0-9] + ; + +fragment OCTAL_DIGIT + : [0-7] + ; + +fragment HEX_DIGIT + : [0-9A-Fa-f] + ; + +// comments +WS + : [ \t\r\n\u000C]+ -> skip + ; + +LINE_COMMENT + : '//' ~[\r\n]* -> skip + ; + +COMMENT + : '/*' .*? '*/' -> skip + ; + +keywords + : SYNTAX + | IMPORT + | WEAK + | PUBLIC + | PACKAGE + | OPTION + | REPEATED + | OPTIONAL + | REQUIRED + | GROUP + | ONEOF + | MAP + | INT32 + | INT64 + | UINT32 + | UINT64 + | SINT32 + | SINT64 + | FIXED32 + | FIXED64 + | SFIXED32 + | SFIXED64 + | BOOL + | STRING + | DOUBLE + | FLOAT + | BYTES + | RESERVED + | EXTENSIONS + | TO + | MAX + | ENUM + | MESSAGE + | EXTEND + | SERVICE + | RPC + | STREAM + | STREAM + | RETURNS + | BOOL_LIT + ; diff --git a/kmp-grpc-plugin/src/main/java/io/github/timortel/kmpgrpc/plugin/DownloadWellKnownTypesTask.kt b/kmp-grpc-plugin/src/main/java/io/github/timortel/kmpgrpc/plugin/DownloadWellKnownTypesTask.kt index d8f60267..a7836a23 100644 --- a/kmp-grpc-plugin/src/main/java/io/github/timortel/kmpgrpc/plugin/DownloadWellKnownTypesTask.kt +++ b/kmp-grpc-plugin/src/main/java/io/github/timortel/kmpgrpc/plugin/DownloadWellKnownTypesTask.kt @@ -5,6 +5,7 @@ import org.gradle.api.file.DirectoryProperty import org.gradle.api.model.ObjectFactory import org.gradle.api.tasks.OutputDirectory import org.gradle.api.tasks.TaskAction +import java.net.URI import java.net.URL import javax.inject.Inject @@ -19,6 +20,7 @@ abstract class DownloadWellKnownTypesTask @Inject constructor(objectFactory: Obj private val wellKnownTypes = listOf( "any.proto", "api.proto", + "descriptor.proto", "duration.proto", "empty.proto", "field_mask.proto", @@ -39,7 +41,7 @@ abstract class DownloadWellKnownTypesTask @Inject constructor(objectFactory: Obj wellKnownTypes.forEach { protoFile -> val protoFileUrl = "$WELL_KNOW_BASE_URL/$protoFile" - val url = URL(protoFileUrl) + val url = URI(protoFileUrl).toURL() val outputFile = outputDir.file("$WELL_KNOWN_TYPES_RELATIVE_PATH/$protoFile").get().asFile logger.info("Downloading $protoFileUrl into ${outputFile.path}") diff --git a/kmp-grpc-plugin/src/main/java/io/github/timortel/kmpgrpc/plugin/sourcegeneration/CompilationException.kt b/kmp-grpc-plugin/src/main/java/io/github/timortel/kmpgrpc/plugin/sourcegeneration/CompilationException.kt index 89c0ed89..92897954 100644 --- a/kmp-grpc-plugin/src/main/java/io/github/timortel/kmpgrpc/plugin/sourcegeneration/CompilationException.kt +++ b/kmp-grpc-plugin/src/main/java/io/github/timortel/kmpgrpc/plugin/sourcegeneration/CompilationException.kt @@ -26,7 +26,7 @@ sealed class CompilationException(val msg: String, val filePath: String, val ctx class EnumNoFields(message: String, file: ProtoFile, ctx: ParserRuleContext) : CompilationException(message, file, ctx) class IllegalClosedEnumImport(message: String, file: ProtoFile, ctx: ParserRuleContext) : CompilationException(message, file, ctx) - // Name Resolving + // Name Resolution class ResolvedToPackage(message: String, file: ProtoFile, ctx: ParserRuleContext) : CompilationException(message, file, ctx) class ConflictingResolution(message: String, file: ProtoFile, ctx: ParserRuleContext) : CompilationException(message, file, ctx) class UnresolvedReference(message: String, file: ProtoFile, ctx: ParserRuleContext) : CompilationException(message, file, ctx) diff --git a/kmp-grpc-plugin/src/main/java/io/github/timortel/kmpgrpc/plugin/sourcegeneration/ProtoSourceGenerator.kt b/kmp-grpc-plugin/src/main/java/io/github/timortel/kmpgrpc/plugin/sourcegeneration/ProtoSourceGenerator.kt index 3fb80022..cadafe13 100644 --- a/kmp-grpc-plugin/src/main/java/io/github/timortel/kmpgrpc/plugin/sourcegeneration/ProtoSourceGenerator.kt +++ b/kmp-grpc-plugin/src/main/java/io/github/timortel/kmpgrpc/plugin/sourcegeneration/ProtoSourceGenerator.kt @@ -1,6 +1,8 @@ package io.github.timortel.kmpgrpc.plugin.sourcegeneration import com.squareup.kotlinpoet.FileSpec +import io.github.timortel.kmpgrpc.anltr.Protobuf2Lexer +import io.github.timortel.kmpgrpc.anltr.Protobuf2Parser import io.github.timortel.kmpgrpc.anltr.Protobuf3Lexer import io.github.timortel.kmpgrpc.anltr.Protobuf3Parser import io.github.timortel.kmpgrpc.anltr.ProtobufEditionsLexer @@ -16,16 +18,19 @@ import io.github.timortel.kmpgrpc.plugin.sourcegeneration.model.Visibility import io.github.timortel.kmpgrpc.plugin.sourcegeneration.model.file.ProtoFile import io.github.timortel.kmpgrpc.plugin.sourcegeneration.model.structure.ProtoFolder import io.github.timortel.kmpgrpc.plugin.sourcegeneration.parsing.ProtobufModelBuilderVisitor +import org.antlr.v4.runtime.TokenStream +import org.antlr.v4.runtime.CharStream import org.antlr.v4.runtime.CharStreams import org.antlr.v4.runtime.CommonTokenStream +import org.antlr.v4.runtime.Lexer import org.slf4j.Logger import java.io.File object ProtoSourceGenerator { // regexes that allow for arbitrarily many whitespaces and comments in the proto file, but expect the syntax/edition statement first. - private val proto3Regex = - "^\\s*(?:(?://[^\\n]*|/\\*[\\s\\S]*?\\*/)\\s*|\\s*\\n)*syntax\\s*=\\s*[\"']proto3[\"']\\s*;[\\s\\S]*$".toRegex() + private val protoSyntaxRegex = + "^\\s*(?:(?://[^\\n]*|/\\*[\\s\\S]*?\\*/)\\s*|\\s*\\n)*syntax\\s*=\\s*[\"'](proto2|proto3)[\"']\\s*;[\\s\\S]*$".toRegex() private val protoEditionsRegex = "^\\s*(?:(?://[^\\n]*|/\\*[\\s\\S]*?\\*/)\\s*|\\s*\\n)*edition\\s*=\\s*[\"'](\\d+)[\"']\\s*;[\\s\\S]*$".toRegex() @@ -161,50 +166,65 @@ object ProtoSourceGenerator { val fileText = file.inputStream().use { inputStream -> inputStream.reader().use { reader -> reader.readText() } } - return when { - fileText.matches(proto3Regex) -> { - val visitor = ProtobufModelBuilderVisitor( - filePath = file.path, - fileNameWithoutExtension = file.nameWithoutExtension, - fileName = file.name, - protoLanguageVersion = ProtoLanguageVersion.PROTO3 - ) - - val proto3Lexer = Protobuf3Lexer(CharStreams.fromStream(file.inputStream())) - val proto3Parser = Protobuf3Parser(CommonTokenStream(proto3Lexer)) - val proto3File = proto3Parser.proto() - - visitor.visitProto(proto3File) + val languageVersion = when { + fileText.matches(protoSyntaxRegex) -> { + when (val versionGroup = protoSyntaxRegex.matchEntire(fileText)!!.groups[1]!!.value) { + "proto2" -> ProtoLanguageVersion.PROTO2 + "proto3" -> ProtoLanguageVersion.PROTO3 + else -> { + logger.warn("File $file uses unsupported proto language version $versionGroup. Only ${ProtoLanguageVersion.entries} are supported.") + return null + } + } } fileText.matches(protoEditionsRegex) -> { - val versionGroup = protoEditionsRegex.matchEntire(fileText)!!.groups[1]!!.value - - val visitor = ProtobufModelBuilderVisitor( - filePath = file.path, - fileNameWithoutExtension = file.nameWithoutExtension, - fileName = file.name, - protoLanguageVersion = when (versionGroup) { - "2023" -> ProtoLanguageVersion.EDITION2023 - "2024" -> ProtoLanguageVersion.EDITION2024 - else -> { - logger.warn("File $file uses unsupported proto editions $versionGroup. Only ${ProtoLanguageVersion.entries} are supported.") - return null - } + when (val versionGroup = protoEditionsRegex.matchEntire(fileText)!!.groups[1]!!.value) { + "2023" -> ProtoLanguageVersion.EDITION2023 + "2024" -> ProtoLanguageVersion.EDITION2024 + else -> { + logger.warn("File $file uses unsupported proto editions $versionGroup. Only ${ProtoLanguageVersion.entries} are supported.") + return null } - ) - - val protoEditionsLexer = ProtobufEditionsLexer(CharStreams.fromStream(file.inputStream())) - val protoEditionsParser = ProtobufEditionsParser(CommonTokenStream(protoEditionsLexer)) - val protoEditionsFile = protoEditionsParser.proto() - - visitor.visitProto(protoEditionsFile) + } } else -> { - logger.warn("File $file does not seem to conform to any known proto syntax. Only proto3 and proto editions files are currently supported. Ignoring file.") - null + logger.debug("Ignoring file {}, as it is not recognized as a valid proto file", file) + return null } } + + data class Toolchain( + val lexerFactory: (CharStream) -> Lexer, + val visitProto: (ProtobufModelBuilderVisitor, TokenStream) -> ProtoFile + ) + + val toolchain = when (languageVersion) { + ProtoLanguageVersion.PROTO2 -> Toolchain( + lexerFactory = ::Protobuf2Lexer, + visitProto = { visitor, tokenStream -> visitor.visitProto(Protobuf2Parser(tokenStream).proto()) } + ) + + ProtoLanguageVersion.PROTO3 -> Toolchain( + lexerFactory = ::Protobuf3Lexer, + visitProto = { visitor, tokenStream -> visitor.visitProto(Protobuf3Parser(tokenStream).proto()) } + ) + + ProtoLanguageVersion.EDITION2023, ProtoLanguageVersion.EDITION2024 -> Toolchain( + lexerFactory = ::ProtobufEditionsLexer, + visitProto = { visitor, tokenStream -> visitor.visitProto(ProtobufEditionsParser(tokenStream).proto()) } + ) + } + + val visitor = ProtobufModelBuilderVisitor( + filePath = file.path, + fileNameWithoutExtension = file.nameWithoutExtension, + fileName = file.name, + protoLanguageVersion = languageVersion + ) + + val lexer = toolchain.lexerFactory(CharStreams.fromStream(file.inputStream())) + return toolchain.visitProto(visitor, CommonTokenStream(lexer)) } } diff --git a/kmp-grpc-plugin/src/main/java/io/github/timortel/kmpgrpc/plugin/sourcegeneration/constants/Const.kt b/kmp-grpc-plugin/src/main/java/io/github/timortel/kmpgrpc/plugin/sourcegeneration/constants/Const.kt index c6fc6340..11d4b24c 100644 --- a/kmp-grpc-plugin/src/main/java/io/github/timortel/kmpgrpc/plugin/sourcegeneration/constants/Const.kt +++ b/kmp-grpc-plugin/src/main/java/io/github/timortel/kmpgrpc/plugin/sourcegeneration/constants/Const.kt @@ -1,5 +1,6 @@ package io.github.timortel.kmpgrpc.plugin.sourcegeneration.constants +import com.squareup.kotlinpoet.BOOLEAN import com.squareup.kotlinpoet.LIST import com.squareup.kotlinpoet.ParameterizedTypeName.Companion.parameterizedBy import com.squareup.kotlinpoet.STRING @@ -38,17 +39,10 @@ object Const { } object Message { - val reservedAttributeNames = setOf( - "fullName", - "requiredSize", - Companion.WrapperDeserializationFunction.TAG_LOCAL_VARIABLE, - Companion.WrapperDeserializationFunction.ENUM_NUMBER_VALUE_LOCAL_VARIABLE, - Companion.WrapperDeserializationFunction.ENUM_VALUE_LOCAL_VARIABLE, - Constructor.UnknownFields.name - ) - val fullNameProperty = Property.of("fullName", STRING) + val isInitializedProperty = Property.of("isInitialized", BOOLEAN) + object Constructor { val UnknownFields = Property.of("unknownFields", LIST.parameterizedBy(unknownField)) val MessageExtensions = Property.of("extensions", kmMessageExtensions) @@ -63,6 +57,7 @@ object Const { val reservedAttributeNames = setOf("requiredSize") const val REQUIRED_SIZE_PROPERTY_NAME = "requiredSize" + val isInitializedProperty = Property.of("isInitialized", BOOLEAN) const val SERIALIZE_FUNCTION_NAME = "serialize" const val SERIALIZE_FUNCTION_STREAM_PARAM_NAME = "stream" @@ -104,6 +99,16 @@ object Const { const val EXTENSION_BUILDER_LOCAL_VARIABLE = "extensionBuilder" } } + + val reservedAttributeNames = setOf( + fullNameProperty.name, + "requiredSize", + isInitializedProperty.name, + Companion.WrapperDeserializationFunction.TAG_LOCAL_VARIABLE, + Companion.WrapperDeserializationFunction.ENUM_NUMBER_VALUE_LOCAL_VARIABLE, + Companion.WrapperDeserializationFunction.ENUM_VALUE_LOCAL_VARIABLE, + Constructor.UnknownFields.name + ) } object DSL { diff --git a/kmp-grpc-plugin/src/main/java/io/github/timortel/kmpgrpc/plugin/sourcegeneration/constants/library_fields.kt b/kmp-grpc-plugin/src/main/java/io/github/timortel/kmpgrpc/plugin/sourcegeneration/constants/library_fields.kt index 9939dab5..51823959 100644 --- a/kmp-grpc-plugin/src/main/java/io/github/timortel/kmpgrpc/plugin/sourcegeneration/constants/library_fields.kt +++ b/kmp-grpc-plugin/src/main/java/io/github/timortel/kmpgrpc/plugin/sourcegeneration/constants/library_fields.kt @@ -70,3 +70,6 @@ val fieldTypeBytes = fieldType.nestedClass("Bytes") // util val mergeUnknownFieldOrExtension = MemberName(PACKAGE_MESSAGE, "mergeUnknownFieldOrExtension") val readMapEntry = MemberName(PACKAGE_IO, "readMapEntry") + +// exceptions +val uninitializedMessageException = ClassName(PACKAGE_BASE, "UninitializedMessageException") diff --git a/kmp-grpc-plugin/src/main/java/io/github/timortel/kmpgrpc/plugin/sourcegeneration/generators/MessageConstructorCallWriter.kt b/kmp-grpc-plugin/src/main/java/io/github/timortel/kmpgrpc/plugin/sourcegeneration/generators/MessageConstructorCallWriter.kt new file mode 100644 index 00000000..1d329f31 --- /dev/null +++ b/kmp-grpc-plugin/src/main/java/io/github/timortel/kmpgrpc/plugin/sourcegeneration/generators/MessageConstructorCallWriter.kt @@ -0,0 +1,91 @@ +package io.github.timortel.kmpgrpc.plugin.sourcegeneration.generators + +import com.squareup.kotlinpoet.CodeBlock +import com.squareup.kotlinpoet.MemberName.Companion.member +import io.github.timortel.kmpgrpc.plugin.sourcegeneration.constants.Const +import io.github.timortel.kmpgrpc.plugin.sourcegeneration.model.declaration.ProtoMessage +import io.github.timortel.kmpgrpc.plugin.sourcegeneration.model.declaration.message.ProtoOneOf +import io.github.timortel.kmpgrpc.plugin.sourcegeneration.model.declaration.message.field.ProtoMapField +import io.github.timortel.kmpgrpc.plugin.sourcegeneration.model.declaration.message.field.ProtoMessageField +import io.github.timortel.kmpgrpc.plugin.sourcegeneration.util.joinToCodeBlock + +object MessageConstructorCallWriter { + + enum class ConstructorType { + DIRECT, + BUILD, + BUILD_PARTIAL + } + + fun getConstructorCallCode( + message: ProtoMessage, + type: ConstructorType, + getFieldParameter: (ProtoMessageField) -> CodeBlock, + getMapFieldParameter: (ProtoMapField) -> CodeBlock, + getOneOfFieldParameter: (ProtoOneOf) -> CodeBlock, + getUnknownFieldsParameter: () -> CodeBlock?, + getExtensionParameter: () -> CodeBlock, + ): CodeBlock { + return CodeBlock.builder() + .apply { + val companion = message.className.nestedClass("Companion") + + when (type) { + ConstructorType.DIRECT -> add("%T(", message.className) + ConstructorType.BUILD -> add("%M(", companion.member("invoke")) + ConstructorType.BUILD_PARTIAL -> add("%M(", companion.member("createPartial")) + } + + add("\n") + indent() + + val separator = ",\n" + + val fields = message.fields.joinToCodeBlock(separator) { field -> + add("%N = ", field.attributeName) + add(getFieldParameter(field)) + } + + val mapFields = message.mapFields.joinToCodeBlock(separator) { field -> + add("%N = ", field.attributeName) + add(getMapFieldParameter(field)) + } + + val oneOfFields = message.oneOfs.joinToCodeBlock(separator) { oneOf -> + add("%N = ", oneOf.attributeName) + add(getOneOfFieldParameter(oneOf)) + } + + val extensionBlock = CodeBlock.builder() + .add("%N = ", Const.Message.Constructor.MessageExtensions.name) + .add(getExtensionParameter()) + .build() + + val unknownFields = getUnknownFieldsParameter()?.let { + listOf( + CodeBlock.builder() + .add("%N = ", Const.Message.Constructor.UnknownFields.name) + .add(it) + .build() + ) + }.orEmpty() + + val blocks = listOf(fields, mapFields, oneOfFields) + unknownFields + + if (message.isExtendable) listOf(extensionBlock) else emptyList() + + add( + blocks + .filter { it.isNotEmpty() } + .joinToCodeBlock(separator) { add(it) } + ) + + if (fields.isNotEmpty() || mapFields.isNotEmpty() || oneOfFields.isNotEmpty() || message.isExtendable) { + add("\n") + } + + unindent() + add(")") + } + .build() + } +} diff --git a/kmp-grpc-plugin/src/main/java/io/github/timortel/kmpgrpc/plugin/sourcegeneration/generators/dsl/ActualProtoDslWriter.kt b/kmp-grpc-plugin/src/main/java/io/github/timortel/kmpgrpc/plugin/sourcegeneration/generators/dsl/ActualProtoDslWriter.kt index f3f20ef9..81cd18f0 100644 --- a/kmp-grpc-plugin/src/main/java/io/github/timortel/kmpgrpc/plugin/sourcegeneration/generators/dsl/ActualProtoDslWriter.kt +++ b/kmp-grpc-plugin/src/main/java/io/github/timortel/kmpgrpc/plugin/sourcegeneration/generators/dsl/ActualProtoDslWriter.kt @@ -3,54 +3,47 @@ package io.github.timortel.kmpgrpc.plugin.sourcegeneration.generators.dsl import com.squareup.kotlinpoet.CodeBlock import com.squareup.kotlinpoet.FunSpec import io.github.timortel.kmpgrpc.plugin.sourcegeneration.constants.Const -import io.github.timortel.kmpgrpc.plugin.sourcegeneration.util.joinToCodeBlock +import io.github.timortel.kmpgrpc.plugin.sourcegeneration.constants.uninitializedMessageException +import io.github.timortel.kmpgrpc.plugin.sourcegeneration.generators.MessageConstructorCallWriter import io.github.timortel.kmpgrpc.plugin.sourcegeneration.model.declaration.ProtoMessage +import io.github.timortel.kmpgrpc.plugin.sourcegeneration.model.declaration.message.field.ProtoMessageField object ActualProtoDslWriter : ProtoDslWriter(true) { override fun modifyBuildFunction(builder: FunSpec.Builder, message: ProtoMessage) { builder.apply { - addCode("return %T(", message.className) + addCode("val msg = ") - val separator = ",\n" - - val fields = message.fields.joinToCodeBlock(separator) { field -> - add("%N = %N ?: ", field.attributeName, field.attributeName) - add(field.defaultValue()) - } - - val mapFields = message.mapFields.joinToCodeBlock(separator) { field -> - add("%N = %N ?: emptyMap()", field.attributeName, field.attributeName) - } - - val oneOfFields = message.oneOfs.joinToCodeBlock(separator) { oneOf -> - add( - "%N = %N", - oneOf.attributeName, - oneOf.attributeName + addCode( + MessageConstructorCallWriter.getConstructorCallCode( + message = message, + type = MessageConstructorCallWriter.ConstructorType.BUILD_PARTIAL, + getFieldParameter = { field -> + if (field.isConstructorParameterNullable(ProtoMessageField.ConstructorParameterType.CREATE_PARTIAL)) { + CodeBlock.of("%N", field.attributeName) + } else { + CodeBlock.builder() + .add("%N ?: ", field.attributeName) + .add(field.defaultValue()) + .build() + } + }, + getMapFieldParameter = { field -> + CodeBlock.of("%N ?: emptyMap()", field.attributeName) + }, + getOneOfFieldParameter = { oneOf -> + CodeBlock.of("%N", oneOf.attributeName) + }, + getUnknownFieldsParameter = { null }, + getExtensionParameter = { CodeBlock.of("%N.build()", Const.DSL.MessageExtensions.name) } ) - } - - val extensionBlock = CodeBlock.of( - "%N = %N.build()", - Const.Message.Constructor.MessageExtensions.name, - Const.DSL.MessageExtensions.name ) - val blocks = listOf(fields, mapFields, oneOfFields) + - if (message.isExtendable) listOf(extensionBlock) else emptyList() - - addCode( - blocks - .filter { it.isNotEmpty() } - .joinToCodeBlock(separator) { add(it) } - ) + addCode("\n") - if (fields.isNotEmpty() || mapFields.isNotEmpty() || oneOfFields.isNotEmpty() || message.isExtendable) { - addCode("\n") - } + addStatement("if (!msg.isInitialized) throw %T(msg)", uninitializedMessageException) - addCode(")") + addStatement("return msg") } } -} \ No newline at end of file +} diff --git a/kmp-grpc-plugin/src/main/java/io/github/timortel/kmpgrpc/plugin/sourcegeneration/generators/protofile/message/ProtoMessageWriter.kt b/kmp-grpc-plugin/src/main/java/io/github/timortel/kmpgrpc/plugin/sourcegeneration/generators/protofile/message/ProtoMessageWriter.kt index a6fc2c26..735cfe0b 100644 --- a/kmp-grpc-plugin/src/main/java/io/github/timortel/kmpgrpc/plugin/sourcegeneration/generators/protofile/message/ProtoMessageWriter.kt +++ b/kmp-grpc-plugin/src/main/java/io/github/timortel/kmpgrpc/plugin/sourcegeneration/generators/protofile/message/ProtoMessageWriter.kt @@ -13,6 +13,7 @@ import io.github.timortel.kmpgrpc.plugin.sourcegeneration.constants.kmMessageWit import io.github.timortel.kmpgrpc.plugin.sourcegeneration.generators.protofile.enumeration.ProtoEnumerationWriter import io.github.timortel.kmpgrpc.plugin.sourcegeneration.generators.protofile.field.ProtoFieldWriter import io.github.timortel.kmpgrpc.plugin.sourcegeneration.generators.protofile.message.extensions.FieldPropertyConstructorExtension +import io.github.timortel.kmpgrpc.plugin.sourcegeneration.generators.protofile.message.extensions.IsInitializedFieldExtension import io.github.timortel.kmpgrpc.plugin.sourcegeneration.generators.protofile.message.extensions.MessageWriterExtension import io.github.timortel.kmpgrpc.plugin.sourcegeneration.generators.protofile.message.extensions.UnknownFieldsExtension import io.github.timortel.kmpgrpc.plugin.sourcegeneration.generators.protofile.message.extensions.functions.CopyFunctionExtension @@ -55,7 +56,8 @@ abstract class ProtoMessageWriter(private val isActual: Boolean) { UnknownFieldsExtension, ExtensionsPropertyExtension, ExtensionDefinitionExtension, - DefaultExtensionRegistryExtension + DefaultExtensionRegistryExtension, + IsInitializedFieldExtension ) /** @@ -77,6 +79,7 @@ abstract class ProtoMessageWriter(private val isActual: Boolean) { primaryConstructor( FunSpec .constructorBuilder() + .addModifiers(KModifier.PRIVATE) .apply { if (isActual) { addModifiers(KModifier.ACTUAL) diff --git a/kmp-grpc-plugin/src/main/java/io/github/timortel/kmpgrpc/plugin/sourcegeneration/generators/protofile/message/extensions/FieldPropertyConstructorExtension.kt b/kmp-grpc-plugin/src/main/java/io/github/timortel/kmpgrpc/plugin/sourcegeneration/generators/protofile/message/extensions/FieldPropertyConstructorExtension.kt index 7ce72b21..33ee23c5 100644 --- a/kmp-grpc-plugin/src/main/java/io/github/timortel/kmpgrpc/plugin/sourcegeneration/generators/protofile/message/extensions/FieldPropertyConstructorExtension.kt +++ b/kmp-grpc-plugin/src/main/java/io/github/timortel/kmpgrpc/plugin/sourcegeneration/generators/protofile/message/extensions/FieldPropertyConstructorExtension.kt @@ -3,19 +3,136 @@ package io.github.timortel.kmpgrpc.plugin.sourcegeneration.generators.protofile. import com.squareup.kotlinpoet.* import com.squareup.kotlinpoet.ParameterizedTypeName.Companion.parameterizedBy import io.github.timortel.kmpgrpc.plugin.sourcegeneration.SourceTarget +import io.github.timortel.kmpgrpc.plugin.sourcegeneration.constants.Const +import io.github.timortel.kmpgrpc.plugin.sourcegeneration.constants.kmMessageExtensions +import io.github.timortel.kmpgrpc.plugin.sourcegeneration.generators.MessageConstructorCallWriter import io.github.timortel.kmpgrpc.plugin.sourcegeneration.model.declaration.ProtoMessage import io.github.timortel.kmpgrpc.plugin.sourcegeneration.model.declaration.message.field.ProtoFieldCardinality +import io.github.timortel.kmpgrpc.plugin.sourcegeneration.model.declaration.message.field.ProtoMessageField +import io.github.timortel.kmpgrpc.plugin.sourcegeneration.model.type.ProtoType object FieldPropertyConstructorExtension : MessageWriterExtension { override fun applyToConstructor(builder: FunSpec.Builder, message: ProtoMessage, sourceTarget: SourceTarget) { + addConstructorParameters( + builder = builder, + message = message, + sourceTarget = sourceTarget, + type = ProtoMessageField.ConstructorParameterType.CONSTRUCTOR + ) + } + + override fun applyToCompanionObject(builder: TypeSpec.Builder, message: ProtoMessage, sourceTarget: SourceTarget) { + addCompanionObjectBuildFunction( + name = "invoke", + type = ProtoMessageField.ConstructorParameterType.CREATE, + builder = builder, + message = message, + sourceTarget = sourceTarget, + modifiers = listOf(KModifier.OPERATOR) + ) + + addCompanionObjectBuildFunction( + name = "createPartial", + type = ProtoMessageField.ConstructorParameterType.CREATE_PARTIAL, + builder = builder, + message = message, + sourceTarget = sourceTarget, + modifiers = emptyList() + ) + } + + private fun addCompanionObjectBuildFunction( + name: String, + type: ProtoMessageField.ConstructorParameterType, + builder: TypeSpec.Builder, + message: ProtoMessage, + sourceTarget: SourceTarget, + modifiers: List + ) { val isActual = sourceTarget is SourceTarget.Actual + builder.addFunction( + FunSpec.builder(name) + .addModifiers(modifiers) + .returns(message.className) + .apply { + addConstructorParameters( + builder = this, + message = message, + sourceTarget = sourceTarget, + type = type + ) + + if (message.isExtendable) { + addParameter( + Const.Message.Constructor.MessageExtensions + .parametrizedBy(message.className) + .toParamSpecBuilder() + .apply { + if (!isActual) defaultValue(CodeBlock.of("%T()", kmMessageExtensions)) + } + .build() + ) + } + + addParameter(Const.Message.Constructor.UnknownFields + .toParamSpecBuilder() + .apply { + if (!isActual) defaultValue("emptyList()") + } + .build() + ) + + if (isActual) { + addModifiers(KModifier.ACTUAL) + + addCode("return ") + addCode( + MessageConstructorCallWriter.getConstructorCallCode( + message = message, + type = MessageConstructorCallWriter.ConstructorType.DIRECT, + getFieldParameter = { field -> CodeBlock.of("%N", field.attributeName) }, + getMapFieldParameter = { field -> CodeBlock.of("%N", field.attributeName) }, + getOneOfFieldParameter = { field -> CodeBlock.of("%N", field.attributeName) }, + getUnknownFieldsParameter = { + CodeBlock.of( + "%N", + Const.Message.Constructor.UnknownFields.name + ) + }, + getExtensionParameter = { + CodeBlock.of( + "%N", + Const.Message.Constructor.MessageExtensions.name + ) + }, + ) + ) + } + } + .build() + ) + } + + private fun addConstructorParameters( + builder: FunSpec.Builder, + message: ProtoMessage, + sourceTarget: SourceTarget, + type: ProtoMessageField.ConstructorParameterType + ) { + val isActual = sourceTarget is SourceTarget.Actual + + val addDefaultValues = when (type) { + ProtoMessageField.ConstructorParameterType.CONSTRUCTOR -> false + ProtoMessageField.ConstructorParameterType.CREATE, ProtoMessageField.ConstructorParameterType.CREATE_PARTIAL -> true + } + //one of attributes do not get a parameter, as they get the one of parameter message.fields.forEach { field -> when (field.cardinality) { is ProtoFieldCardinality.Singular -> { - val isParamNullable = field.needsIsSetProperty + val isParamNullable = field.isConstructorParameterNullable(type) val type = if (isParamNullable) field.type.resolve().copy(nullable = true) else field.type.resolve() @@ -24,13 +141,15 @@ object FieldPropertyConstructorExtension : MessageWriterExtension { ParameterSpec .builder(field.attributeName, type) .apply { - if (!isActual) { + if (!isActual && addDefaultValues) { defaultValue( // If the field needs a isSet property, then the constructor must pass null by default if (isParamNullable) { CodeBlock.of("null") } else { - field.type.defaultValue() + field.type.defaultValue( + messageDefaultValue = ProtoType.MessageDefaultValue.EMPTY + ) } ) } @@ -44,7 +163,7 @@ object FieldPropertyConstructorExtension : MessageWriterExtension { ParameterSpec .builder(field.attributeName, LIST.parameterizedBy(field.type.resolve())) .apply { - if (!isActual) defaultValue("emptyList()") + if (!isActual && addDefaultValues) defaultValue("emptyList()") } .build() ) @@ -63,7 +182,7 @@ object FieldPropertyConstructorExtension : MessageWriterExtension { ) ) .apply { - if (!isActual) defaultValue("emptyMap()") + if (!isActual && addDefaultValues) defaultValue("emptyMap()") } .build() ) @@ -77,7 +196,7 @@ object FieldPropertyConstructorExtension : MessageWriterExtension { oneOf.sealedClassName ) .apply { - if (!isActual) defaultValue("%T", oneOf.sealedClassNameNotSet) + if (!isActual && addDefaultValues) defaultValue("%T", oneOf.sealedClassNameNotSet) } .build() ) diff --git a/kmp-grpc-plugin/src/main/java/io/github/timortel/kmpgrpc/plugin/sourcegeneration/generators/protofile/message/extensions/IsInitializedFieldExtension.kt b/kmp-grpc-plugin/src/main/java/io/github/timortel/kmpgrpc/plugin/sourcegeneration/generators/protofile/message/extensions/IsInitializedFieldExtension.kt new file mode 100644 index 00000000..6ac021fc --- /dev/null +++ b/kmp-grpc-plugin/src/main/java/io/github/timortel/kmpgrpc/plugin/sourcegeneration/generators/protofile/message/extensions/IsInitializedFieldExtension.kt @@ -0,0 +1,92 @@ +package io.github.timortel.kmpgrpc.plugin.sourcegeneration.generators.protofile.message.extensions + +import com.squareup.kotlinpoet.CodeBlock +import com.squareup.kotlinpoet.KModifier +import com.squareup.kotlinpoet.TypeSpec +import io.github.timortel.kmpgrpc.plugin.sourcegeneration.SourceTarget +import io.github.timortel.kmpgrpc.plugin.sourcegeneration.constants.Const +import io.github.timortel.kmpgrpc.plugin.sourcegeneration.model.declaration.ProtoMessage +import io.github.timortel.kmpgrpc.plugin.sourcegeneration.model.declaration.message.field.ProtoFieldCardinality +import io.github.timortel.kmpgrpc.plugin.sourcegeneration.model.declaration.message.field.isLegacyRequired +import io.github.timortel.kmpgrpc.plugin.sourcegeneration.util.joinCodeBlocks +import io.github.timortel.kmpgrpc.plugin.sourcegeneration.util.joinToCodeBlock + +object IsInitializedFieldExtension : MessageWriterExtension { + + override fun applyToClass(builder: TypeSpec.Builder, message: ProtoMessage, sourceTarget: SourceTarget) { + val isActual = sourceTarget is SourceTarget.Actual + + builder.addProperty( + Const.Message.isInitializedProperty.toPropertySpecBuilder(KModifier.OVERRIDE) + .apply { + if (isActual) { + addModifiers(KModifier.ACTUAL) + + initializer( + CodeBlock.builder().apply { + val requiredFields = message.fields.filter { field -> + field.cardinality.isLegacyRequired + } + + val subMessageFields = message.fields.filter { it.type.isMessage } + val subMessageMapFields = message.mapFields.filter { it.valuesType.isMessage } + val oneOfs = message.oneOfs.filter { oneOf -> oneOf.fields.any { it.type.isMessage } } + + val subMessages = subMessageFields + subMessageMapFields + oneOfs + + if (requiredFields.isEmpty() && subMessages.isEmpty()) { + add("true") + } else { + val separator = " && " + val requiredFieldsBool = requiredFields.joinToCodeBlock(separator) { + add("%N", it.isSetProperty.name) + } + + val subMessageFieldsBool = subMessageFields.joinToCodeBlock(separator) { + when (it.cardinality) { + is ProtoFieldCardinality.Singular -> { + add( + "(%1N == null || %1N.%2N)", + it.attributeName, + Const.Message.isInitializedProperty.name + ) + } + ProtoFieldCardinality.Repeated -> { + add( + "%N.all { it.%N }", + it.attributeName, + Const.Message.isInitializedProperty.name + ) + } + } + } + + val subMessageOneOfFieldsBool = oneOfs.joinToCodeBlock(separator) { + add( + "%N.%N", + it.attributeName, + Const.Message.OneOf.isInitializedProperty.name + ) + } + + val subMessageMapFieldsBool = subMessageMapFields.joinToCodeBlock(separator) { + add( + "%N.values.all { it.%N }", + it.attributeName, + Const.Message.isInitializedProperty.name + ) + } + + val impl = listOf(requiredFieldsBool, subMessageFieldsBool, subMessageOneOfFieldsBool, subMessageMapFieldsBool).joinCodeBlocks(separator) + + add(impl) + } + } + .build() + ) + } + } + .build() + ) + } +} diff --git a/kmp-grpc-plugin/src/main/java/io/github/timortel/kmpgrpc/plugin/sourcegeneration/generators/protofile/message/extensions/serialization/DeserializationFunctionExtension.kt b/kmp-grpc-plugin/src/main/java/io/github/timortel/kmpgrpc/plugin/sourcegeneration/generators/protofile/message/extensions/serialization/DeserializationFunctionExtension.kt index 4ad66bcb..aa96d509 100644 --- a/kmp-grpc-plugin/src/main/java/io/github/timortel/kmpgrpc/plugin/sourcegeneration/generators/protofile/message/extensions/serialization/DeserializationFunctionExtension.kt +++ b/kmp-grpc-plugin/src/main/java/io/github/timortel/kmpgrpc/plugin/sourcegeneration/generators/protofile/message/extensions/serialization/DeserializationFunctionExtension.kt @@ -1,10 +1,10 @@ package io.github.timortel.kmpgrpc.plugin.sourcegeneration.generators.protofile.message.extensions.serialization import com.squareup.kotlinpoet.* -import com.squareup.kotlinpoet.MemberName.Companion.member import com.squareup.kotlinpoet.ParameterizedTypeName.Companion.parameterizedBy import io.github.timortel.kmpgrpc.plugin.sourcegeneration.SourceTarget import io.github.timortel.kmpgrpc.plugin.sourcegeneration.constants.* +import io.github.timortel.kmpgrpc.plugin.sourcegeneration.generators.MessageConstructorCallWriter import io.github.timortel.kmpgrpc.plugin.sourcegeneration.model.declaration.ProtoEnum import io.github.timortel.kmpgrpc.plugin.sourcegeneration.model.declaration.ProtoMessage import io.github.timortel.kmpgrpc.plugin.sourcegeneration.model.declaration.message.field.ProtoFieldCardinality @@ -12,8 +12,6 @@ import io.github.timortel.kmpgrpc.plugin.sourcegeneration.model.declaration.mess import io.github.timortel.kmpgrpc.plugin.sourcegeneration.model.declaration.message.field.ProtoRegularField import io.github.timortel.kmpgrpc.plugin.sourcegeneration.model.file.ProtoFile import io.github.timortel.kmpgrpc.plugin.sourcegeneration.model.type.ProtoType -import io.github.timortel.kmpgrpc.plugin.sourcegeneration.util.joinCodeBlocks -import io.github.timortel.kmpgrpc.plugin.sourcegeneration.util.joinToCodeBlock import io.github.timortel.kmpgrpc.shared.internal.io.DataType import io.github.timortel.kmpgrpc.shared.internal.io.wireFormatForType import io.github.timortel.kmpgrpc.shared.internal.io.wireFormatMakeTag @@ -102,7 +100,7 @@ class DeserializationFunctionExtension : BaseSerializationExtension() { // Unknown field or extension addStatement( - "else -> %M(%N.%N(%N, %N), %N, %N)", + "else -> if·(!%M(%N.%N(%N, %N), %N, %N))·break", mergeUnknownFieldOrExtension, wrapperParamName, "readUnknownFieldOrExtension", @@ -184,7 +182,7 @@ class DeserializationFunctionExtension : BaseSerializationExtension() { ) add("if·(%N·!=·null)·%N·$assignMode·", enumVar, variableName) - constructType { add("%N", enumVar)} + constructType { add("%N", enumVar) } add("\n") addStatement( @@ -201,7 +199,7 @@ class DeserializationFunctionExtension : BaseSerializationExtension() { when (val type = type) { is ProtoType.DefType -> when (val decl = type.resolveDeclaration()) { is ProtoEnum -> buildReadScalarFieldOpenEnumTypeCode(type) - is ProtoMessage -> buildReadScalarFieldMessageTypeCode(type, decl) + is ProtoMessage -> buildReadScalarFieldMessageTypeCode(type, decl, fieldNumber) } is ProtoType.NonDeclType -> { @@ -248,17 +246,9 @@ class DeserializationFunctionExtension : BaseSerializationExtension() { field.type is ProtoType.DefType && field.type.isMessage -> { val message = field.type.resolveDeclaration() as ProtoMessage - addCode( - "%N·+=·%N.%N(%T.Companion, ", - field.attributeName, - wrapperParamName, - "readMessage", - field.type.resolve() - ) - - addCode(buildExtensionRegistryCodeForMessage(message)) - - addCode(")\n") + addCode("%N·+=·", field.attributeName) + addCode(buildReadScalarFieldMessageTypeCode(field.type, message, field.number)) + addCode("\n") } isPacked -> { @@ -363,9 +353,9 @@ class DeserializationFunctionExtension : BaseSerializationExtension() { addCode(", ") addCode(getDefaultEntry(mapField.valuesType)) addCode(", ") - addCode(buildReadMapFieldDataCode(mapField.keyType)) + addCode(buildReadMapFieldDataCode(mapField.keyType, 1)) addCode(", ") - addCode(buildReadMapFieldDataCode(mapField.valuesType)) + addCode(buildReadMapFieldDataCode(mapField.valuesType, 2)) addCode(")\n") } } @@ -375,41 +365,29 @@ class DeserializationFunctionExtension : BaseSerializationExtension() { message: ProtoMessage ) { builder.apply { - addCode("return %T(", message.className) - - val separator = ",\n" + addCode("return ") - val fieldsBlock = (message.fields + message.mapFields + message.oneOfs) - .joinToCodeBlock(separator = separator) { field -> - add( - "%N·=·%N", - field.attributeName, - field.attributeName - ) - } - - val unknownFieldsBlock = - CodeBlock.of( - "%N·=·%N", - Const.Message.Constructor.UnknownFields.name, - Const.Message.Companion.WrapperDeserializationFunction.UNKNOWN_FIELDS_LOCAL_VARIABLE - ) - - val extensionsBlock = - CodeBlock.of( - "%N·=·%N.build()", - Const.Message.Constructor.MessageExtensions.name, - Const.Message.Companion.WrapperDeserializationFunction.EXTENSION_BUILDER_LOCAL_VARIABLE + addCode( + MessageConstructorCallWriter.getConstructorCallCode( + message = message, + type = MessageConstructorCallWriter.ConstructorType.BUILD_PARTIAL, + getFieldParameter = { CodeBlock.of("%N", it.attributeName) }, + getMapFieldParameter = { CodeBlock.of("%N", it.attributeName) }, + getOneOfFieldParameter = { CodeBlock.of("%N", it.attributeName) }, + getUnknownFieldsParameter = { + CodeBlock.of( + "%N", + Const.Message.Companion.WrapperDeserializationFunction.UNKNOWN_FIELDS_LOCAL_VARIABLE + ) + }, + getExtensionParameter = { + CodeBlock.of( + "%N.build()", + Const.Message.Companion.WrapperDeserializationFunction.EXTENSION_BUILDER_LOCAL_VARIABLE + ) + } ) - - val codeBlocks = listOf( - fieldsBlock, - unknownFieldsBlock - ) + if (message.isExtendable) listOf(extensionsBlock) else emptyList() - - addCode(codeBlocks.joinCodeBlocks(separator)) - - addCode(")\n") + ) } } @@ -499,15 +477,24 @@ class DeserializationFunctionExtension : BaseSerializationExtension() { ) } - private fun buildReadScalarFieldMessageTypeCode(type: ProtoType.DefType, message: ProtoMessage): CodeBlock { + private fun buildReadScalarFieldMessageTypeCode( + type: ProtoType.DefType, + message: ProtoMessage, + fieldNumber: Int + ): CodeBlock { return CodeBlock.builder() .add( "%N.%N(%T.Companion, ", wrapperParamName, - "readMessage", + getReadScalarFunctionName(type), type.resolve() ) .add(buildExtensionRegistryCodeForMessage(message)) + .apply { + if (message.type == ProtoMessage.Type.GROUP) { + add(", %L", fieldNumber) + } + } .add(")") .build() } @@ -521,7 +508,7 @@ class DeserializationFunctionExtension : BaseSerializationExtension() { ) } - private fun buildReadMapFieldDataCode(type: ProtoType): CodeBlock { + private fun buildReadMapFieldDataCode(type: ProtoType, fieldNumber: Int): CodeBlock { return when (type) { is ProtoType.NonDeclType -> { CodeBlock.of( @@ -534,13 +521,9 @@ class DeserializationFunctionExtension : BaseSerializationExtension() { when (val decl = type.resolveDeclaration()) { is ProtoMessage -> { CodeBlock.builder() - .add( - "{·%N(%T.Companion, ", - "readMessage", - type.resolve() - ) - .add(buildExtensionRegistryCodeForMessage(decl)) - .add(")}") + .add("{·") + .add(buildReadScalarFieldMessageTypeCode(type, decl, fieldNumber)) + .add("}") .build() } @@ -575,9 +558,12 @@ class DeserializationFunctionExtension : BaseSerializationExtension() { ProtoType.StringType -> "readString" ProtoType.BytesType -> "readBytes" is ProtoType.DefType -> { - when (protoType.declType) { - ProtoType.DefType.DeclarationType.MESSAGE -> "readMessage" - ProtoType.DefType.DeclarationType.ENUM -> "readEnum" + when (val decl = protoType.resolveDeclaration()) { + is ProtoEnum -> "readEnum" + is ProtoMessage -> when (decl.type) { + ProtoMessage.Type.DEFAULT -> "readMessage" + ProtoMessage.Type.GROUP -> "readGroup" + } } } } @@ -585,10 +571,11 @@ class DeserializationFunctionExtension : BaseSerializationExtension() { private fun buildExtensionRegistryCodeForMessage(message: ProtoMessage): CodeBlock { return if (message.isExtendable) { + // bug in KotlinPoet: Member declaration resolves incorrectly, so we use %T.%N CodeBlock.of( - "%M", - message.className.nestedClass("Companion") - .member(Const.Message.Companion.defaultExtensionRegistryProperty.name) + "%T.%N", + message.className.nestedClass("Companion"), + Const.Message.Companion.defaultExtensionRegistryProperty.name ) } else { CodeBlock.of("%T.empty()", kmExtensionRegistry) diff --git a/kmp-grpc-plugin/src/main/java/io/github/timortel/kmpgrpc/plugin/sourcegeneration/generators/protofile/message/extensions/serialization/SerializationFunctionExtension.kt b/kmp-grpc-plugin/src/main/java/io/github/timortel/kmpgrpc/plugin/sourcegeneration/generators/protofile/message/extensions/serialization/SerializationFunctionExtension.kt index 6ec9fd8b..0d987fe5 100644 --- a/kmp-grpc-plugin/src/main/java/io/github/timortel/kmpgrpc/plugin/sourcegeneration/generators/protofile/message/extensions/serialization/SerializationFunctionExtension.kt +++ b/kmp-grpc-plugin/src/main/java/io/github/timortel/kmpgrpc/plugin/sourcegeneration/generators/protofile/message/extensions/serialization/SerializationFunctionExtension.kt @@ -6,6 +6,7 @@ import com.squareup.kotlinpoet.KModifier import com.squareup.kotlinpoet.TypeSpec import io.github.timortel.kmpgrpc.plugin.sourcegeneration.SourceTarget import io.github.timortel.kmpgrpc.plugin.sourcegeneration.constants.* +import io.github.timortel.kmpgrpc.plugin.sourcegeneration.model.declaration.ProtoEnum import io.github.timortel.kmpgrpc.plugin.sourcegeneration.model.declaration.ProtoMessage import io.github.timortel.kmpgrpc.plugin.sourcegeneration.model.declaration.message.field.* import io.github.timortel.kmpgrpc.plugin.sourcegeneration.model.type.ProtoType @@ -193,7 +194,8 @@ class SerializationFunctionExtension : BaseSerializationExtension() { when (type.declType) { ProtoType.DefType.DeclarationType.MESSAGE -> { addCode( - "{·fieldNumber,·msg·-> writeMessage(fieldNumber, msg)·}" + "{·fieldNumber,·msg·-> %N(fieldNumber, msg)·}", + getWriteScalarFunctionName(type) ) } @@ -238,9 +240,12 @@ class SerializationFunctionExtension : BaseSerializationExtension() { ProtoType.StringType -> "writeStringArray" ProtoType.BytesType -> "writeBytesArray" is ProtoType.DefType -> { - when (protoType.declType) { - ProtoType.DefType.DeclarationType.MESSAGE -> "writeMessageArray" - ProtoType.DefType.DeclarationType.ENUM -> "writeEnumArray" + when (val decl = protoType.resolveDeclaration()) { + is ProtoMessage -> when (decl.type) { + ProtoMessage.Type.DEFAULT -> "writeMessageArray" + ProtoMessage.Type.GROUP -> "writeGroupArray" + } + is ProtoEnum -> "writeEnumArray" } } } @@ -266,17 +271,19 @@ class SerializationFunctionExtension : BaseSerializationExtension() { } is ProtoType.DefType -> { - when (type.declType) { - ProtoType.DefType.DeclarationType.MESSAGE -> { + when (type.resolveDeclaration()) { + is ProtoMessage -> { + val functionName = getWriteScalarFunctionName(type) + CodeBlock.of( - "%N.writeMessage(%L, %N)", + "%N.%N(%L, %N)", streamParam, + functionName, field.number, field.attributeName ) } - - ProtoType.DefType.DeclarationType.ENUM -> { + is ProtoEnum -> { CodeBlock.of( "%N.writeEnum(%L, %N.%N)\n", streamParam, @@ -321,9 +328,12 @@ class SerializationFunctionExtension : BaseSerializationExtension() { ProtoType.StringType -> "writeString" ProtoType.BytesType -> "writeBytes" is ProtoType.DefType -> { - when (protoType.declType) { - ProtoType.DefType.DeclarationType.MESSAGE -> "writeMessage" - ProtoType.DefType.DeclarationType.ENUM -> "writeEnum" + when (val decl = protoType.resolveDeclaration()) { + is ProtoEnum -> "writeEnum" + is ProtoMessage -> when (decl.type) { + ProtoMessage.Type.DEFAULT -> "writeMessage" + ProtoMessage.Type.GROUP -> "writeGroup" + } } } } diff --git a/kmp-grpc-plugin/src/main/java/io/github/timortel/kmpgrpc/plugin/sourcegeneration/generators/protofile/oneof/ActualProtoOneOfWriter.kt b/kmp-grpc-plugin/src/main/java/io/github/timortel/kmpgrpc/plugin/sourcegeneration/generators/protofile/oneof/ActualProtoOneOfWriter.kt index 129c52ca..e8cc343e 100644 --- a/kmp-grpc-plugin/src/main/java/io/github/timortel/kmpgrpc/plugin/sourcegeneration/generators/protofile/oneof/ActualProtoOneOfWriter.kt +++ b/kmp-grpc-plugin/src/main/java/io/github/timortel/kmpgrpc/plugin/sourcegeneration/generators/protofile/oneof/ActualProtoOneOfWriter.kt @@ -24,6 +24,7 @@ abstract class ActualProtoOneOfWriter : ProtoOneOfWriter(true) { addSerializeFunction(builder, listOf(KModifier.ABSTRACT)) {} builder.addProperty(Const.Message.OneOf.REQUIRED_SIZE_PROPERTY_NAME, INT, KModifier.ABSTRACT) + builder.addProperty(Const.Message.OneOf.isInitializedProperty.toPropertySpec(KModifier.ABSTRACT)) } override fun modifyChildClass(builder: TypeSpec.Builder, oneOf: ProtoOneOf, childClassType: ChildClassType) { @@ -65,6 +66,22 @@ abstract class ActualProtoOneOfWriter : ProtoOneOfWriter(true) { ) .build() ) + + builder.addProperty( + Const.Message.OneOf.isInitializedProperty.toPropertySpecBuilder(KModifier.OVERRIDE) + .apply { + when (childClassType) { + is ChildClassType.Normal -> if (childClassType.field.type.isMessage) { + initializer("%N.%N", childClassType.field.attributeName, Const.Message.isInitializedProperty.name) + } else { + initializer("true") + } + ChildClassType.NotSet -> initializer("true") + ChildClassType.Unknown -> initializer("true") + } + } + .build() + ) } private fun addSerializeFunction( @@ -81,4 +98,4 @@ abstract class ActualProtoOneOfWriter : ProtoOneOfWriter(true) { .build() ) } -} \ No newline at end of file +} diff --git a/kmp-grpc-plugin/src/main/java/io/github/timortel/kmpgrpc/plugin/sourcegeneration/model/DeclarationResolver.kt b/kmp-grpc-plugin/src/main/java/io/github/timortel/kmpgrpc/plugin/sourcegeneration/model/DeclarationResolver.kt index 0b1d605e..1cb23df7 100644 --- a/kmp-grpc-plugin/src/main/java/io/github/timortel/kmpgrpc/plugin/sourcegeneration/model/DeclarationResolver.kt +++ b/kmp-grpc-plugin/src/main/java/io/github/timortel/kmpgrpc/plugin/sourcegeneration/model/DeclarationResolver.kt @@ -4,12 +4,14 @@ import io.github.timortel.kmpgrpc.plugin.sourcegeneration.CompilationException import io.github.timortel.kmpgrpc.plugin.sourcegeneration.model.declaration.ProtoDeclaration import io.github.timortel.kmpgrpc.plugin.sourcegeneration.model.declaration.ProtoEnum import io.github.timortel.kmpgrpc.plugin.sourcegeneration.model.declaration.ProtoMessage +import io.github.timortel.kmpgrpc.plugin.sourcegeneration.model.file.ProtoFile +import io.github.timortel.kmpgrpc.plugin.sourcegeneration.model.file.ProtoImport import io.github.timortel.kmpgrpc.plugin.sourcegeneration.model.structure.ProtoPackage import io.github.timortel.kmpgrpc.plugin.sourcegeneration.model.type.ProtoType import io.github.timortel.kmpgrpc.plugin.sourcegeneration.util.toFilePositionString /** - * Implementation for resoling in both messages and enums + * Implementation for resolving declarations in both messages and enums */ interface DeclarationResolver : BaseDeclarationResolver { @@ -27,13 +29,27 @@ interface DeclarationResolver : BaseDeclarationResolver { // Search scope val identifier = type.declaration - val allowedFiles = listOf(type.file) + type.file.importedFiles + val fileToImport = type.file.imports.associateBy { import -> + type.file.project.rootFolder.resolveImport(import.path) + ?: throw CompilationException.UnresolvedImport( + "Unable to resolve import ${import.identifier}", + type.file, + import.ctx + ) + } // Only allow candidates from the file itself or from imported files val allowedCandidates = candidates.filter { candidate -> when (candidate) { - is Candidate.Message -> candidate.message.file in allowedFiles - is Candidate.Enum -> candidate.enum.file in allowedFiles + is Candidate.Message, is Candidate.Enum -> { + when { + candidate.file == type.file -> true + else -> { + val import = fileToImport[candidate.file] + import != null && import.type == ProtoImport.Type.DEFAULT + } + } + } is Candidate.Package -> true // Packages are always allowed } } @@ -47,11 +63,11 @@ interface DeclarationResolver : BaseDeclarationResolver { validateCandidates(type, matchingCandidates) - return when { + when { matchingCandidates.isNotEmpty() -> { val newType = type.copy(declaration = remainingIdentifier) - // Go deeper into the three. No turning back. There must be exactly one element in the list. + // Go deeper into the tree. No turning back. There must be exactly one element in the list. when (val candidate = matchingCandidates.first()) { is Candidate.Message -> candidate.message.resolveDeclaration(newType, false) is Candidate.Enum -> candidate.enum.resolveDeclaration(newType) @@ -128,17 +144,27 @@ interface DeclarationResolver : BaseDeclarationResolver { fun getLocation(): String - data class Message(val message: ProtoMessage) : Candidate { + sealed interface FileBasedCandidate : Candidate { + val file: ProtoFile + } + + data class Message(val message: ProtoMessage) : FileBasedCandidate { override val name: String get() = message.name + override val file: ProtoFile + get() = message.file + override fun getLocation(): String = message.ctx.toFilePositionString(message.file.path) } - data class Enum(val enum: ProtoEnum) : Candidate { + data class Enum(val enum: ProtoEnum) : FileBasedCandidate { override val name: String get() = enum.name + override val file: ProtoFile + get() = enum.file + override fun getLocation(): String = enum.ctx.toFilePositionString(enum.file.path) } diff --git a/kmp-grpc-plugin/src/main/java/io/github/timortel/kmpgrpc/plugin/sourcegeneration/model/ProtoLanguageVersion.kt b/kmp-grpc-plugin/src/main/java/io/github/timortel/kmpgrpc/plugin/sourcegeneration/model/ProtoLanguageVersion.kt index a6f688f9..1c529532 100644 --- a/kmp-grpc-plugin/src/main/java/io/github/timortel/kmpgrpc/plugin/sourcegeneration/model/ProtoLanguageVersion.kt +++ b/kmp-grpc-plugin/src/main/java/io/github/timortel/kmpgrpc/plugin/sourcegeneration/model/ProtoLanguageVersion.kt @@ -1,6 +1,7 @@ package io.github.timortel.kmpgrpc.plugin.sourcegeneration.model enum class ProtoLanguageVersion { + PROTO2, PROTO3, EDITION2023, EDITION2024 diff --git a/kmp-grpc-plugin/src/main/java/io/github/timortel/kmpgrpc/plugin/sourcegeneration/model/ProtoOptionsHolder.kt b/kmp-grpc-plugin/src/main/java/io/github/timortel/kmpgrpc/plugin/sourcegeneration/model/ProtoOptionsHolder.kt index a734dad6..48ddb5a5 100644 --- a/kmp-grpc-plugin/src/main/java/io/github/timortel/kmpgrpc/plugin/sourcegeneration/model/ProtoOptionsHolder.kt +++ b/kmp-grpc-plugin/src/main/java/io/github/timortel/kmpgrpc/plugin/sourcegeneration/model/ProtoOptionsHolder.kt @@ -10,6 +10,10 @@ import io.github.timortel.kmpgrpc.plugin.sourcegeneration.util.toFilePositionStr interface ProtoOptionsHolder : ProtoNode { + companion object { + private val ignoredPackages = listOf("google.protobuf") + } + val options: List val file: ProtoFile @@ -18,6 +22,8 @@ interface ProtoOptionsHolder : ProtoNode { val optionTarget: OptionTarget override fun validate() { + if (file.`package` in ignoredPackages) return + options.forEach { option -> val isIgnored = option.name in Options.ignoredOptions val relatedOption = Options.options.firstOrNull { it.name == option.name } diff --git a/kmp-grpc-plugin/src/main/java/io/github/timortel/kmpgrpc/plugin/sourcegeneration/model/declaration/ProtoBaseDeclaration.kt b/kmp-grpc-plugin/src/main/java/io/github/timortel/kmpgrpc/plugin/sourcegeneration/model/declaration/ProtoBaseDeclaration.kt index d6c0d762..3166456f 100644 --- a/kmp-grpc-plugin/src/main/java/io/github/timortel/kmpgrpc/plugin/sourcegeneration/model/declaration/ProtoBaseDeclaration.kt +++ b/kmp-grpc-plugin/src/main/java/io/github/timortel/kmpgrpc/plugin/sourcegeneration/model/declaration/ProtoBaseDeclaration.kt @@ -47,7 +47,7 @@ interface ProtoBaseDeclaration : ProtoOptionsHolder, ProtoVisibilityHolder { */ val isNested: Boolean get() = when (file.languageVersion) { - ProtoLanguageVersion.PROTO3, ProtoLanguageVersion.EDITION2023 -> !Options.Basic.javaMultipleFiles.get(file) + ProtoLanguageVersion.PROTO2, ProtoLanguageVersion.PROTO3, ProtoLanguageVersion.EDITION2023 -> !Options.Basic.javaMultipleFiles.get(file) ProtoLanguageVersion.EDITION2024 -> when (Options.Feature.nestInFileClass.get(this)) { ProtoNestInFileClass.YES -> true ProtoNestInFileClass.NO -> false diff --git a/kmp-grpc-plugin/src/main/java/io/github/timortel/kmpgrpc/plugin/sourcegeneration/model/declaration/ProtoEnum.kt b/kmp-grpc-plugin/src/main/java/io/github/timortel/kmpgrpc/plugin/sourcegeneration/model/declaration/ProtoEnum.kt index ed4779e8..445246df 100644 --- a/kmp-grpc-plugin/src/main/java/io/github/timortel/kmpgrpc/plugin/sourcegeneration/model/declaration/ProtoEnum.kt +++ b/kmp-grpc-plugin/src/main/java/io/github/timortel/kmpgrpc/plugin/sourcegeneration/model/declaration/ProtoEnum.kt @@ -39,16 +39,6 @@ data class ProtoEnum( is ProtoDeclParent.Message -> p.message.file } - val defaultField: ProtoEnumField - get() = - fields - .firstOrNull { it.number == 0 } - ?: throw CompilationException.EnumIllegalFirstField( - "Enumeration does not have field with value 0.", - file, - ctx - ) - override val heldFields: List = fields @@ -79,11 +69,20 @@ data class ProtoEnum( } return when (userLanguage) { + ProtoLanguageVersion.PROTO2 -> when (file.languageVersion) { + ProtoLanguageVersion.PROTO2 -> false + ProtoLanguageVersion.PROTO3 -> true + ProtoLanguageVersion.EDITION2023, ProtoLanguageVersion.EDITION2024 -> getFeatureIsOpen() + } + // Closed enum imports are illegal. We still return closed here and throw an exception in ProtoType#validate ProtoLanguageVersion.PROTO3 -> when (file.languageVersion) { + ProtoLanguageVersion.PROTO2 -> false ProtoLanguageVersion.PROTO3 -> true ProtoLanguageVersion.EDITION2023, ProtoLanguageVersion.EDITION2024 -> getFeatureIsOpen() } + ProtoLanguageVersion.EDITION2023, ProtoLanguageVersion.EDITION2024 -> when (file.languageVersion) { + ProtoLanguageVersion.PROTO2 -> false ProtoLanguageVersion.PROTO3 -> true ProtoLanguageVersion.EDITION2023, ProtoLanguageVersion.EDITION2024 -> getFeatureIsOpen() } @@ -100,11 +99,16 @@ data class ProtoEnum( ctx = ctx ) - if (fields.first().number != 0) throw CompilationException.EnumIllegalFirstField( - message = "The first value defined in an enumeration must have value 0", - file = file, - ctx = ctx - ) + when (file.languageVersion) { + ProtoLanguageVersion.PROTO2 -> {} + ProtoLanguageVersion.PROTO3, ProtoLanguageVersion.EDITION2023, ProtoLanguageVersion.EDITION2024 -> { + if (fields.first().number != 0) throw CompilationException.EnumIllegalFirstField( + message = "The first value defined in an enumeration must have value 0", + file = file, + ctx = ctx + ) + } + } val allowAlias = Options.Basic.allowAlias.get(this) fields diff --git a/kmp-grpc-plugin/src/main/java/io/github/timortel/kmpgrpc/plugin/sourcegeneration/model/declaration/ProtoMessage.kt b/kmp-grpc-plugin/src/main/java/io/github/timortel/kmpgrpc/plugin/sourcegeneration/model/declaration/ProtoMessage.kt index 0b87c75e..356274bb 100644 --- a/kmp-grpc-plugin/src/main/java/io/github/timortel/kmpgrpc/plugin/sourcegeneration/model/declaration/ProtoMessage.kt +++ b/kmp-grpc-plugin/src/main/java/io/github/timortel/kmpgrpc/plugin/sourcegeneration/model/declaration/ProtoMessage.kt @@ -31,6 +31,7 @@ data class ProtoMessage( override val extensionDefinitions: List, val extensionRange: ProtoExtensionRanges, override val symbolVisibility: ProtoSymbolVisibility?, + val type: Type, override val ctx: ParserRuleContext ) : ProtoDeclaration, FileBasedDeclarationResolver, ProtoFieldHolder, ProtoChildPropertyNameResolver, ProtoExtensionDefinitionHolder, ProtoExtensionDefinitionFinder { @@ -174,4 +175,9 @@ data class ProtoMessage( throw CompilationException.FieldNumberConflict(message, file, ctx) } } + + enum class Type { + DEFAULT, + GROUP + } } diff --git a/kmp-grpc-plugin/src/main/java/io/github/timortel/kmpgrpc/plugin/sourcegeneration/model/declaration/message/field/ProtoMessageField.kt b/kmp-grpc-plugin/src/main/java/io/github/timortel/kmpgrpc/plugin/sourcegeneration/model/declaration/message/field/ProtoMessageField.kt index 08d24f00..40c6f176 100644 --- a/kmp-grpc-plugin/src/main/java/io/github/timortel/kmpgrpc/plugin/sourcegeneration/model/declaration/message/field/ProtoMessageField.kt +++ b/kmp-grpc-plugin/src/main/java/io/github/timortel/kmpgrpc/plugin/sourcegeneration/model/declaration/message/field/ProtoMessageField.kt @@ -63,9 +63,10 @@ class ProtoMessageField( val cardinality: ProtoFieldCardinality get() = when (file.languageVersion) { - ProtoLanguageVersion.PROTO3 -> when (fieldCardinality) { + ProtoLanguageVersion.PROTO3, ProtoLanguageVersion.PROTO2 -> when (fieldCardinality) { FieldCardinality.SINGULAR -> ProtoFieldCardinality.Singular(ProtoFieldPresence.IMPLICIT) FieldCardinality.SINGULAR_OPTIONAL -> ProtoFieldCardinality.Singular(ProtoFieldPresence.EXPLICIT) + FieldCardinality.SINGULAR_REQUIRED -> ProtoFieldCardinality.Singular(ProtoFieldPresence.LEGACY_REQUIRED) FieldCardinality.REPEATED -> ProtoFieldCardinality.Repeated } @@ -75,7 +76,7 @@ class ProtoMessageField( ) FieldCardinality.REPEATED -> ProtoFieldCardinality.Repeated - FieldCardinality.SINGULAR_OPTIONAL -> throw IllegalArgumentException("FieldCardinality.SINGULAR_OPTIONAL is illegal for edition versions.") + FieldCardinality.SINGULAR_OPTIONAL, FieldCardinality.SINGULAR_REQUIRED -> throw IllegalArgumentException("field cardinality $fieldCardinality is illegal for edition versions.") } } @@ -100,7 +101,15 @@ class ProtoMessageField( * If cardinality is either explicit or legacy, or if the type is a message and it is not repeated */ val needsIsSetProperty: Boolean - get() = cardinality.isExplicit || (type is ProtoType.DefType && type.isMessage && cardinality != ProtoFieldCardinality.Repeated) + get() { + val isSingularMessage = type is ProtoType.DefType && type.isMessage && cardinality != ProtoFieldCardinality.Repeated + + return when (file.languageVersion) { + ProtoLanguageVersion.PROTO2 -> cardinality is ProtoFieldCardinality.Singular + ProtoLanguageVersion.PROTO3 -> cardinality.isExplicit || isSingularMessage + ProtoLanguageVersion.EDITION2023, ProtoLanguageVersion.EDITION2024 -> !cardinality.isImplicit && cardinality != ProtoFieldCardinality.Repeated + } + } val isSetProperty: ExtraProperty get() = ExtraProperty( @@ -118,7 +127,7 @@ class ProtoMessageField( */ override val isPacked: Boolean get() = cardinality == ProtoFieldCardinality.Repeated && type.isPackable && when (file.languageVersion) { - ProtoLanguageVersion.PROTO3 -> Options.Basic.packed.get(this) + ProtoLanguageVersion.PROTO3, ProtoLanguageVersion.PROTO2 -> Options.Basic.packed.get(this) ProtoLanguageVersion.EDITION2023, ProtoLanguageVersion.EDITION2024 -> when (Options.Feature.repeatedFieldEncoding.get(this)) { ProtoRepeatedFieldEncoding.PACKED -> true @@ -170,6 +179,22 @@ class ProtoMessageField( } } + fun isConstructorParameterNullable(type: ConstructorParameterType): Boolean { + return when (type) { + ConstructorParameterType.CONSTRUCTOR, ConstructorParameterType.CREATE_PARTIAL -> needsIsSetProperty + ConstructorParameterType.CREATE -> when (fieldCardinality) { + FieldCardinality.SINGULAR, FieldCardinality.SINGULAR_OPTIONAL, FieldCardinality.REPEATED -> needsIsSetProperty + FieldCardinality.SINGULAR_REQUIRED -> false + } + } + } + + enum class ConstructorParameterType { + CONSTRUCTOR, + CREATE, + CREATE_PARTIAL + } + data class ExtraProperty( override val desiredAttributeName: String, override val resolvingParent: ProtoChildPropertyNameResolver, @@ -191,6 +216,7 @@ class ProtoMessageField( enum class FieldCardinality { SINGULAR, SINGULAR_OPTIONAL, + SINGULAR_REQUIRED, REPEATED } } diff --git a/kmp-grpc-plugin/src/main/java/io/github/timortel/kmpgrpc/plugin/sourcegeneration/model/file/ProtoImport.kt b/kmp-grpc-plugin/src/main/java/io/github/timortel/kmpgrpc/plugin/sourcegeneration/model/file/ProtoImport.kt index 6567b153..e6afb8a0 100644 --- a/kmp-grpc-plugin/src/main/java/io/github/timortel/kmpgrpc/plugin/sourcegeneration/model/file/ProtoImport.kt +++ b/kmp-grpc-plugin/src/main/java/io/github/timortel/kmpgrpc/plugin/sourcegeneration/model/file/ProtoImport.kt @@ -20,7 +20,7 @@ data class ProtoImport(val identifier: String, val type: Type, val ctx: ParserRu } Type.OPTION -> { when (file.languageVersion) { - ProtoLanguageVersion.PROTO3, ProtoLanguageVersion.EDITION2023 -> throw CompilationException.UnsupportedLanguageFeatureUsed( + ProtoLanguageVersion.PROTO2, ProtoLanguageVersion.PROTO3, ProtoLanguageVersion.EDITION2023 -> throw CompilationException.UnsupportedLanguageFeatureUsed( message = "Option imports are not available in language version ${file.languageVersion}", file = file, ctx = ctx diff --git a/kmp-grpc-plugin/src/main/java/io/github/timortel/kmpgrpc/plugin/sourcegeneration/model/option/FeatureProtoOption.kt b/kmp-grpc-plugin/src/main/java/io/github/timortel/kmpgrpc/plugin/sourcegeneration/model/option/FeatureProtoOption.kt index 1adba00c..7c319959 100644 --- a/kmp-grpc-plugin/src/main/java/io/github/timortel/kmpgrpc/plugin/sourcegeneration/model/option/FeatureProtoOption.kt +++ b/kmp-grpc-plugin/src/main/java/io/github/timortel/kmpgrpc/plugin/sourcegeneration/model/option/FeatureProtoOption.kt @@ -31,6 +31,7 @@ class FeatureProtoOption( name = name, parse = parse, languageConfigurationMap = mapOf( + ProtoLanguageVersion.PROTO2 to LangConfig.Unavailable(), ProtoLanguageVersion.PROTO3 to LangConfig.Unavailable(), ProtoLanguageVersion.EDITION2023 to edition2023Config, ProtoLanguageVersion.EDITION2024 to edition2024Config, diff --git a/kmp-grpc-plugin/src/main/java/io/github/timortel/kmpgrpc/plugin/sourcegeneration/model/option/Options.kt b/kmp-grpc-plugin/src/main/java/io/github/timortel/kmpgrpc/plugin/sourcegeneration/model/option/Options.kt index cc98784c..bfe85875 100644 --- a/kmp-grpc-plugin/src/main/java/io/github/timortel/kmpgrpc/plugin/sourcegeneration/model/option/Options.kt +++ b/kmp-grpc-plugin/src/main/java/io/github/timortel/kmpgrpc/plugin/sourcegeneration/model/option/Options.kt @@ -12,6 +12,7 @@ object Options { name = "java_multiple_files", parse = String::toBooleanStrictOrNull, targets = listOf(OptionTargetMatcher.FILE), + proto2Config = LangConfig.Available(defaultValue = false), proto3Config = LangConfig.Available(defaultValue = false), edition2023Config = LangConfig.Available(defaultValue = false), edition2024Config = LangConfig.Available(defaultValue = false) @@ -21,6 +22,7 @@ object Options { name = "java_package", parse = { it }, targets = listOf(OptionTargetMatcher.FILE), + proto2Config = LangConfig.Available(defaultValue = null), proto3Config = LangConfig.Available(defaultValue = null), editionConfig = LangConfig.Available(defaultValue = null) ) @@ -29,6 +31,7 @@ object Options { name = "java_outer_classname", parse = { it }, targets = listOf(OptionTargetMatcher.FILE), + proto2Config = LangConfig.Available(defaultValue = null), proto3Config = LangConfig.Available(defaultValue = null), editionConfig = LangConfig.Available(defaultValue = null) ) @@ -37,6 +40,7 @@ object Options { name = "allow_alias", parse = String::toBooleanStrictOrNull, targets = listOf(OptionTargetMatcher.ENUM(restrictToTopLevel = false)), + proto2Config = LangConfig.Available(defaultValue = false), proto3Config = LangConfig.Available(defaultValue = false), editionConfig = LangConfig.Available(defaultValue = false) ) @@ -45,6 +49,7 @@ object Options { name = "deprecated", parse = String::toBooleanStrictOrNull, targets = listOf(OptionTargetMatcher.FIELD(), OptionTargetMatcher.ENUM_ENTRY), + proto2Config = LangConfig.Available(defaultValue = false), proto3Config = LangConfig.Available(defaultValue = false), editionConfig = LangConfig.Available(defaultValue = false), failOnInvalidTargetUsage = false @@ -54,9 +59,19 @@ object Options { name = "packed", parse = String::toBooleanStrictOrNull, targets = listOf(OptionTargetMatcher.FIELD(restriction = OptionTargetMatcher.FIELD.Restriction.OnlyOnRepeated(forcePackable = true))), + proto2Config = LangConfig.Available(defaultValue = true), proto3Config = LangConfig.Available(defaultValue = true), editionConfig = LangConfig.Unavailable() ) + + val default = SimpleProtoOption( + name = "default", + parse = { it }, + targets = listOf(OptionTargetMatcher.FIELD()), + proto2Config = LangConfig.Available(defaultValue = null), + proto3Config = LangConfig.Unavailable(), + editionConfig = LangConfig.Available(defaultValue = null) + ) } object Feature { @@ -80,6 +95,10 @@ object Options { name = "default_symbol_visibility", parse = { value -> ProtoDefaultSymbolVisibility.entries.firstOrNull { it.name == value } }, languageConfigurationMap = mapOf( + ProtoLanguageVersion.PROTO2 to LangConfig.Available( + defaultValue = ProtoDefaultSymbolVisibility.EXPORT_ALL, + isLocked = true + ), ProtoLanguageVersion.PROTO3 to LangConfig.Available( defaultValue = ProtoDefaultSymbolVisibility.EXPORT_ALL, isLocked = true @@ -123,6 +142,7 @@ object Options { Basic.allowAlias, Basic.deprecated, Basic.packed, + Basic.default, Feature.fieldPresence, Feature.repeatedFieldEncoding, Feature.defaultSymbolVisibility, @@ -140,7 +160,6 @@ object Options { "cc_enable_arenas" ) - sealed interface LangConfig { class Unavailable : LangConfig diff --git a/kmp-grpc-plugin/src/main/java/io/github/timortel/kmpgrpc/plugin/sourcegeneration/model/option/SimpleProtoOption.kt b/kmp-grpc-plugin/src/main/java/io/github/timortel/kmpgrpc/plugin/sourcegeneration/model/option/SimpleProtoOption.kt index f96a31b9..c924ac8e 100644 --- a/kmp-grpc-plugin/src/main/java/io/github/timortel/kmpgrpc/plugin/sourcegeneration/model/option/SimpleProtoOption.kt +++ b/kmp-grpc-plugin/src/main/java/io/github/timortel/kmpgrpc/plugin/sourcegeneration/model/option/SimpleProtoOption.kt @@ -19,12 +19,14 @@ class SimpleProtoOption( name: String, parse: (String) -> T?, targets: List, + proto2Config: LangConfig, proto3Config: LangConfig, editionConfig: LangConfig, failOnInvalidTargetUsage: Boolean = true ) : this( name = name, parse = parse, + proto2Config = proto2Config, proto3Config = proto3Config, edition2023Config = editionConfig, edition2024Config = editionConfig, @@ -36,6 +38,7 @@ class SimpleProtoOption( name: String, parse: (String) -> T?, targets: List, + proto2Config: LangConfig, proto3Config: LangConfig, edition2023Config: LangConfig, edition2024Config: LangConfig, @@ -44,6 +47,7 @@ class SimpleProtoOption( name = name, parse = parse, languageConfigurationMap = mapOf( + ProtoLanguageVersion.PROTO2 to proto2Config, ProtoLanguageVersion.PROTO3 to proto3Config, ProtoLanguageVersion.EDITION2023 to edition2023Config, ProtoLanguageVersion.EDITION2024 to edition2024Config diff --git a/kmp-grpc-plugin/src/main/java/io/github/timortel/kmpgrpc/plugin/sourcegeneration/model/type/ProtoType.kt b/kmp-grpc-plugin/src/main/java/io/github/timortel/kmpgrpc/plugin/sourcegeneration/model/type/ProtoType.kt index 9b0b7715..1ab4cd4d 100644 --- a/kmp-grpc-plugin/src/main/java/io/github/timortel/kmpgrpc/plugin/sourcegeneration/model/type/ProtoType.kt +++ b/kmp-grpc-plugin/src/main/java/io/github/timortel/kmpgrpc/plugin/sourcegeneration/model/type/ProtoType.kt @@ -6,6 +6,7 @@ import io.github.timortel.kmpgrpc.plugin.sourcegeneration.constants.* import io.github.timortel.kmpgrpc.plugin.sourcegeneration.model.ProtoExtensionDefinition import io.github.timortel.kmpgrpc.plugin.sourcegeneration.model.ProtoLanguageVersion import io.github.timortel.kmpgrpc.plugin.sourcegeneration.model.ProtoNode +import io.github.timortel.kmpgrpc.plugin.sourcegeneration.model.ProtoOptionsHolder import io.github.timortel.kmpgrpc.plugin.sourcegeneration.model.ProtoProject import io.github.timortel.kmpgrpc.plugin.sourcegeneration.model.declaration.ProtoDeclaration import io.github.timortel.kmpgrpc.plugin.sourcegeneration.model.declaration.ProtoEnum @@ -14,6 +15,7 @@ import io.github.timortel.kmpgrpc.plugin.sourcegeneration.model.declaration.mess import io.github.timortel.kmpgrpc.plugin.sourcegeneration.model.declaration.message.field.ProtoMessageField import io.github.timortel.kmpgrpc.plugin.sourcegeneration.model.declaration.message.field.ProtoOneOfField import io.github.timortel.kmpgrpc.plugin.sourcegeneration.model.file.ProtoFile +import io.github.timortel.kmpgrpc.plugin.sourcegeneration.model.option.Options import io.github.timortel.kmpgrpc.plugin.sourcegeneration.model.service.ProtoRpc import io.github.timortel.kmpgrpc.shared.internal.io.DataType import org.antlr.v4.runtime.ParserRuleContext @@ -196,9 +198,12 @@ sealed interface ProtoType : ProtoNode { override val isEnum: Boolean get() = declType == DeclarationType.ENUM override val wireType: DataType - get() = when (declType) { - DeclarationType.MESSAGE -> DataType.MESSAGE - DeclarationType.ENUM -> DataType.ENUM + get() = when (val decl = resolveDeclaration()) { + is ProtoEnum -> DataType.ENUM + is ProtoMessage -> when (decl.type) { + ProtoMessage.Type.DEFAULT -> DataType.MESSAGE + ProtoMessage.Type.GROUP -> DataType.GROUP + } } override val fieldType: TypeName @@ -210,8 +215,28 @@ sealed interface ProtoType : ProtoNode { override fun defaultValue(messageDefaultValue: MessageDefaultValue): CodeBlock { return when (val decl = resolveDeclaration()) { is ProtoEnum -> { - val defaultField = decl.defaultField - CodeBlock.of("%T.%N", decl.className, defaultField.name) + val optionsHolder: ProtoOptionsHolder = when (val p = parent) { + is Parent.MessageField -> p.field + is Parent.MapField -> p.field + is Parent.OneOfField -> p.field + is Parent.Rpc -> throw IllegalStateException("Enum cannot have rpc as parent") + is Parent.ExtensionDefinition -> throw IllegalStateException("Cannot get default value with extension definition parent") + } + + val defaultValueAsString = when (file.languageVersion) { + ProtoLanguageVersion.PROTO3 -> null + ProtoLanguageVersion.PROTO2, ProtoLanguageVersion.EDITION2023, ProtoLanguageVersion.EDITION2024 -> { + Options.Basic.default.get(optionsHolder) + } + } + + val defaultValue = if (defaultValueAsString == null) decl.fields.first() + else { + decl.fields.firstOrNull { it.name == defaultValueAsString } + ?: throw CompilationException.UnresolvedReference("Could not find enum entry with name $defaultValueAsString", file, ctx) + } + + CodeBlock.of("%T.%N", decl.className, defaultValue.name) } is ProtoMessage -> { diff --git a/kmp-grpc-plugin/src/main/java/io/github/timortel/kmpgrpc/plugin/sourcegeneration/parsing/ProtobufModelBuilderVisitor.kt b/kmp-grpc-plugin/src/main/java/io/github/timortel/kmpgrpc/plugin/sourcegeneration/parsing/ProtobufModelBuilderVisitor.kt index e4a2b355..455d634b 100644 --- a/kmp-grpc-plugin/src/main/java/io/github/timortel/kmpgrpc/plugin/sourcegeneration/parsing/ProtobufModelBuilderVisitor.kt +++ b/kmp-grpc-plugin/src/main/java/io/github/timortel/kmpgrpc/plugin/sourcegeneration/parsing/ProtobufModelBuilderVisitor.kt @@ -1,5 +1,7 @@ package io.github.timortel.kmpgrpc.plugin.sourcegeneration.parsing +import io.github.timortel.kmpgrpc.anltr.Protobuf2Parser +import io.github.timortel.kmpgrpc.anltr.Protobuf2Visitor import io.github.timortel.kmpgrpc.anltr.Protobuf3Parser import io.github.timortel.kmpgrpc.anltr.Protobuf3Visitor import io.github.timortel.kmpgrpc.anltr.ProtobufEditionsParser @@ -27,6 +29,7 @@ import io.github.timortel.kmpgrpc.plugin.sourcegeneration.model.declaration.mess import io.github.timortel.kmpgrpc.plugin.sourcegeneration.model.declaration.message.field.ProtoOneOfField import io.github.timortel.kmpgrpc.plugin.sourcegeneration.model.service.ProtoRpc import io.github.timortel.kmpgrpc.plugin.sourcegeneration.model.service.ProtoService +import io.github.timortel.kmpgrpc.plugin.sourcegeneration.util.decapitalize import org.antlr.v4.runtime.ParserRuleContext import org.antlr.v4.runtime.tree.ErrorNode import org.antlr.v4.runtime.tree.ParseTree @@ -40,7 +43,7 @@ class ProtobufModelBuilderVisitor( private val fileName: String, private val fileNameWithoutExtension: String, private val protoLanguageVersion: ProtoLanguageVersion -) : Protobuf3Visitor, ProtobufEditionsVisitor { +) : Protobuf3Visitor, ProtobufEditionsVisitor, Protobuf2Visitor { private fun visitProto( ctx: ParserRuleContext, @@ -118,6 +121,29 @@ class ProtobufModelBuilderVisitor( ) } + override fun visitProto(ctx: Protobuf2Parser.ProtoContext): ProtoFile { + val imports = ctx.importStatement().map { visitImportStatement(it) } + val options = ctx.optionStatement().map { visitOptionStatement(it) } + + val messages = ctx.topLevelDef().mapNotNull { it.messageDef() }.map { visitMessageDef(it) } + val topLevelEnums = ctx.topLevelDef().mapNotNull { it.enumDef() }.map { visitEnumDef(it) } + val services = ctx.topLevelDef().mapNotNull { it.serviceDef() }.map { visitServiceDef(it) } + val extensionDefinitionsData = ctx.topLevelDef().mapNotNull { it.extendDef() }.map { visitExtendDef(it) } + + val packages = ctx.packageStatement().mapNotNull { it.fullIdent()?.text } + + return visitProto( + ctx = ctx, + imports = imports, + options = options, + messages = messages + extensionDefinitionsData.flatMap { it.groupMessages }, + topLevelEnums = topLevelEnums, + services = services, + packages = packages, + extensionDefinitions = extensionDefinitionsData.map { it.extensionDefinition } + ) + } + private fun visitImportStatement(ctx: ParserRuleContext, identifier: String, type: ProtoImport.Type): ProtoImport { return ProtoImport(identifier = identifier, type = type, ctx = ctx) } @@ -141,6 +167,15 @@ class ProtobufModelBuilderVisitor( return visitImportStatement(ctx, ctx.strLit().text, type) } + override fun visitImportStatement(ctx: Protobuf2Parser.ImportStatementContext): ProtoImport { + val type = when { + ctx.PUBLIC() != null -> ProtoImport.Type.PUBLIC + else -> ProtoImport.Type.DEFAULT + } + + return visitImportStatement(ctx, ctx.strLit().text, type) + } + private fun visitOption(ctx: ParserRuleContext, name: String, constant: String): ProtoOption { val value = if (constant.startsWith("\"") && constant.endsWith("\"")) { constant.substring(1, constant.length - 1) @@ -157,6 +192,10 @@ class ProtobufModelBuilderVisitor( return visitOption(ctx, ctx.optionName().text, ctx.constant().text) } + override fun visitOptionStatement(ctx: Protobuf2Parser.OptionStatementContext): ProtoOption { + return visitOption(ctx, ctx.optionName().text, ctx.constant().text) + } + override fun visitSymbolVisibility(ctx: ProtobufEditionsParser.SymbolVisibilityContext?): ProtoSymbolVisibility? { return when { ctx?.LOCAL() != null -> ProtoSymbolVisibility.LOCAL @@ -165,7 +204,9 @@ class ProtobufModelBuilderVisitor( } } + // ------------------------- // Message parsing + // ------------------------- override fun visitMessageDef(ctx: ProtobufEditionsParser.MessageDefContext): ProtoMessage { val name = ctx.messageName().text @@ -199,6 +240,7 @@ class ProtobufModelBuilderVisitor( extensionDefinitions = extensionDefinitions, extensionRange = extensionRange, symbolVisibility = symbolVisibility, + type = ProtoMessage.Type.DEFAULT, ctx = ctx ) } @@ -233,15 +275,106 @@ class ProtobufModelBuilderVisitor( extensionDefinitions = extensionDefinitions, extensionRange = ProtoExtensionRanges(), symbolVisibility = null, + type = ProtoMessage.Type.DEFAULT, + ctx = ctx + ) + } + + private fun visitProto2MessageFromElements( + name: String, + elements: List, + type: ProtoMessage.Type, + ctx: ParserRuleContext + ): ProtoMessage { + val nestedMessages = elements.mapNotNull { it.messageDef() }.map { visitMessageDef(it) } + val nestedEnums = elements.mapNotNull { it.enumDef() }.map { visitEnumDef(it) } + + val directFields = elements.mapNotNull { it.field() }.map { visitField(it) } + val mapFields = elements.mapNotNull { it.mapField() }.map { visitMapField(it) } + val oneOfs = elements.mapNotNull { it.oneof() }.map { visitOneof(it) } + + val parsedGroups = elements.mapNotNull { it.group() }.map { visitGroup(it) } + val groupFields = parsedGroups.map { it.field } + val groupMessages = parsedGroups.map { it.message } + + val reservation = elements.mapNotNull { it.reserved() }.map { visitReserved(it) }.fold() + val options = elements.mapNotNull { it.optionStatement() }.map { visitOptionStatement(it) } + + val extensionDefinitionsData = elements.mapNotNull { it.extendDef() }.map { visitExtendDef(it) } + val extensionRange = elements.mapNotNull { it.extensions() }.map { visitExtensions(it) }.fold() + + return ProtoMessage( + name = name, + messages = nestedMessages + groupMessages + extensionDefinitionsData.flatMap { it.groupMessages }, + enums = nestedEnums, + fields = directFields + groupFields, + oneOfs = oneOfs, + mapFields = mapFields, + reservation = reservation, + options = options, + extensionDefinitions = extensionDefinitionsData.map { it.extensionDefinition }, + extensionRange = extensionRange, + symbolVisibility = null, + type = type, + ctx = ctx + ) + } + + override fun visitMessageDef(ctx: Protobuf2Parser.MessageDefContext): ProtoMessage { + val name = ctx.messageName().text + val elements = ctx.messageBody().messageElement() + + return visitProto2MessageFromElements( + name = name, + elements = elements, + type = ProtoMessage.Type.DEFAULT, + ctx = ctx + ) + } + + private fun visitGroupFieldCardinality(label: Protobuf2Parser.FieldLabelContext?): ProtoMessageField.FieldCardinality { + return when { + label?.REPEATED() != null -> ProtoMessageField.FieldCardinality.REPEATED + label?.OPTIONAL() != null -> ProtoMessageField.FieldCardinality.SINGULAR_OPTIONAL + label?.REQUIRED() != null -> ProtoMessageField.FieldCardinality.SINGULAR_REQUIRED + else -> throw IllegalStateException("field cardinality must be one of: repeated, optional, required") + } + } + + override fun visitGroup(ctx: Protobuf2Parser.GroupContext): ParsedGroup { + val label = ctx.fieldLabel() + val fieldCardinality = visitGroupFieldCardinality(label) + + val groupName = ctx.groupName().text + val number = visitIntLit(ctx.fieldNumber().intLit()) + + val options = visitFieldOptions(ctx.fieldOptions()) + + val field = ProtoMessageField( + type = ProtoType.DefType(groupName, ctx), + name = groupName.decapitalize(), + number = number, + options = options, + fieldCardinality = fieldCardinality, ctx = ctx ) + + val groupElements = ctx.messageBody().messageElement() + val groupMessage = visitProto2MessageFromElements( + name = groupName, + elements = groupElements, + type = ProtoMessage.Type.GROUP, + ctx = ctx + ) + + return ParsedGroup(field = field, message = groupMessage) } override fun visitReserved(ctx: ProtobufEditionsParser.ReservedContext): ProtoReservation { return when { ctx.ranges() != null -> ProtoReservation(ranges = visitRanges(ctx.ranges())) ctx.reservedFieldNames() != null -> visitReservedFieldNames(ctx.reservedFieldNames()) - else -> throw ParseException("Could not read reserved field", ctx) + else -> throw ParseException("Could not read reserved field", ctx, filePath) } } @@ -249,7 +382,15 @@ class ProtobufModelBuilderVisitor( return when { ctx.ranges() != null -> ProtoReservation(ranges = visitRanges(ctx.ranges())) ctx.reservedFieldNames() != null -> visitReservedFieldNames(ctx.reservedFieldNames()) - else -> throw ParseException("Could not read reserved field", ctx) + else -> throw ParseException("Could not read reserved field", ctx, filePath) + } + } + + override fun visitReserved(ctx: Protobuf2Parser.ReservedContext): ProtoReservation { + return when { + ctx.ranges() != null -> ProtoReservation(ranges = visitRanges(ctx.ranges())) + ctx.reservedFieldNames() != null -> visitReservedFieldNames(ctx.reservedFieldNames()) + else -> throw ParseException("Could not read reserved field", ctx, filePath) } } @@ -261,6 +402,10 @@ class ProtobufModelBuilderVisitor( return ctx.range_().map { visitRange_(it) } } + override fun visitRanges(ctx: Protobuf2Parser.RangesContext): List { + return ctx.range_().map { visitRange_(it) } + } + private fun visitRange(start: Int, end: Int?, isMax: Boolean, ctx: ParserRuleContext): ProtoRange { val end = when { isMax -> Const.FIELD_NUMBER_MAX_VALUE @@ -272,11 +417,15 @@ class ProtobufModelBuilderVisitor( } override fun visitRange_(ctx: ProtobufEditionsParser.Range_Context): ProtoRange { - return visitRange(ctx.intLit(0).parseInt(), ctx.intLit(1)?.parseInt(), ctx.MAX() != null, ctx) + return visitRange(visitIntLit(ctx.intLit(0)), ctx.intLit(1)?.let(::visitIntLit), ctx.MAX() != null, ctx) } override fun visitRange_(ctx: Protobuf3Parser.Range_Context): ProtoRange { - return visitRange(ctx.intLit(0).parseInt(), ctx.intLit(1)?.parseInt(), ctx.MAX() != null, ctx) + return visitRange(visitIntLit(ctx.intLit(0)), ctx.intLit(1)?.let(::visitIntLit), ctx.MAX() != null, ctx) + } + + override fun visitRange_(ctx: Protobuf2Parser.Range_Context): ProtoRange { + return visitRange(visitIntLit(ctx.intLit(0)), ctx.intLit(1)?.let(::visitIntLit), ctx.MAX() != null, ctx) } private fun visitReservedFieldNames(names: List): ProtoReservation { @@ -294,13 +443,21 @@ class ProtobufModelBuilderVisitor( return visitReservedFieldNames(ctx.strLit().map { it.text }) } + override fun visitReservedFieldNames(ctx: Protobuf2Parser.ReservedFieldNamesContext): ProtoReservation { + return visitReservedFieldNames(ctx.strLit().map { it.text }) + } + override fun visitMessageBody(ctx: ProtobufEditionsParser.MessageBodyContext): Any = Unit override fun visitMessageBody(ctx: Protobuf3Parser.MessageBodyContext): Any = Unit + override fun visitMessageBody(ctx: Protobuf2Parser.MessageBodyContext): Any = Unit override fun visitMessageElement(ctx: ProtobufEditionsParser.MessageElementContext): Any = Unit override fun visitMessageElement(ctx: Protobuf3Parser.MessageElementContext?): Any = Unit + override fun visitMessageElement(ctx: Protobuf2Parser.MessageElementContext?): Any = Unit + // ------------------------- // Enum Parsing + // ------------------------- override fun visitEnumDef(ctx: ProtobufEditionsParser.EnumDefContext): ProtoEnum { val name = ctx.enumName().text @@ -340,7 +497,27 @@ class ProtobufModelBuilderVisitor( ) } + override fun visitEnumDef(ctx: Protobuf2Parser.EnumDefContext): ProtoEnum { + val name = ctx.enumName().text + val elements = ctx.enumBody().enumElement() + + val options = elements.mapNotNull { it.optionStatement() }.map { visitOptionStatement(it) } + val fields = elements.mapNotNull { it.enumField() }.map { visitEnumField(it) } + val reservation = elements.mapNotNull { it.reserved() }.map { visitReserved(it) }.fold() + + return ProtoEnum( + name = name, + fields = fields, + options = options, + reservation = reservation, + symbolVisibility = null, + ctx = ctx + ) + } + + // ------------------------- // Field parsing + // ------------------------- override fun visitField(ctx: ProtobufEditionsParser.FieldContext): ProtoMessageField { val label = ctx.fieldLabel() @@ -352,7 +529,7 @@ class ProtobufModelBuilderVisitor( val type = visitType_(ctx.type_()) val name = ctx.fieldName().text - val number = ctx.fieldNumber().parseInt() + val number = visitIntLit(ctx.fieldNumber().intLit()) val options = visitFieldOptions(ctx.fieldOptions()) @@ -377,7 +554,33 @@ class ProtobufModelBuilderVisitor( val type = visitType_(ctx.type_()) val name = ctx.fieldName().text - val number = ctx.fieldNumber().parseInt() + val number = visitIntLit(ctx.fieldNumber().intLit()) + + val options = visitFieldOptions(ctx.fieldOptions()) + + return ProtoMessageField( + type = type, + name = name, + number = number, + options = options, + fieldCardinality = fieldCardinality, + ctx = ctx + ) + } + + override fun visitField(ctx: Protobuf2Parser.FieldContext): ProtoMessageField { + val label = ctx.fieldLabel() + + val fieldCardinality = when { + label?.REPEATED() != null -> ProtoMessageField.FieldCardinality.REPEATED + label?.OPTIONAL() != null -> ProtoMessageField.FieldCardinality.SINGULAR_OPTIONAL + label?.REQUIRED() != null -> ProtoMessageField.FieldCardinality.SINGULAR_REQUIRED + else -> throw IllegalStateException("field cardinality must be one of: repeated, optional, required") + } + + val type = visitType_(ctx.type_()) + val name = ctx.fieldName().text + val number = visitIntLit(ctx.fieldNumber().intLit()) val options = visitFieldOptions(ctx.fieldOptions()) @@ -396,7 +599,7 @@ class ProtobufModelBuilderVisitor( val valuesType = visitType_(ctx.type_()) val name = ctx.mapName().text - val number = ctx.fieldNumber().parseInt() + val number = visitIntLit(ctx.fieldNumber().intLit()) val options = visitFieldOptions(ctx.fieldOptions()) @@ -415,7 +618,26 @@ class ProtobufModelBuilderVisitor( val valuesType = visitType_(ctx.type_()) val name = ctx.mapName().text - val number = ctx.fieldNumber().parseInt() + val number = visitIntLit(ctx.fieldNumber().intLit()) + + val options = visitFieldOptions(ctx.fieldOptions()) + + return ProtoMapField( + name = name, + number = number, + options = options, + keyType = keyType, + valuesType = valuesType, + ctx = ctx + ) + } + + override fun visitMapField(ctx: Protobuf2Parser.MapFieldContext): ProtoMapField { + val keyType = visitKeyType(ctx.keyType()) + val valuesType = visitType_(ctx.type_()) + + val name = ctx.mapName().text + val number = visitIntLit(ctx.fieldNumber().intLit()) val options = visitFieldOptions(ctx.fieldOptions()) @@ -438,6 +660,10 @@ class ProtobufModelBuilderVisitor( return ctx?.fieldOption().orEmpty().map { visitFieldOption(it) } } + override fun visitFieldOptions(ctx: Protobuf2Parser.FieldOptionsContext?): List { + return ctx?.fieldOption().orEmpty().map { visitFieldOption(it) } + } + private fun visitFieldOption(ctx: ParserRuleContext, name: String, value: String): ProtoOption { return ProtoOption(name, value, ctx) } @@ -450,6 +676,10 @@ class ProtobufModelBuilderVisitor( return visitFieldOption(ctx, ctx.optionName().text, ctx.constant().text) } + override fun visitFieldOption(ctx: Protobuf2Parser.FieldOptionContext): ProtoOption { + return visitFieldOption(ctx, ctx.optionName().text, ctx.constant().text) + } + private fun visitEnumField( ctx: ParserRuleContext, name: String, @@ -467,7 +697,7 @@ class ProtobufModelBuilderVisitor( override fun visitEnumField(ctx: ProtobufEditionsParser.EnumFieldContext): ProtoEnumField { val name = ctx.ident().text - val number = ctx.intLit().parseInt() + val number = visitIntLit(ctx.intLit()) val options = visitEnumValueOptions(ctx.enumValueOptions()) return visitEnumField(ctx, name, number, options, ctx.MINUS() != null) @@ -475,7 +705,15 @@ class ProtobufModelBuilderVisitor( override fun visitEnumField(ctx: Protobuf3Parser.EnumFieldContext): ProtoEnumField { val name = ctx.ident().text - val number = ctx.intLit().parseInt() + val number = visitIntLit(ctx.intLit()) + val options = visitEnumValueOptions(ctx.enumValueOptions()) + + return visitEnumField(ctx, name, number, options, ctx.MINUS() != null) + } + + override fun visitEnumField(ctx: Protobuf2Parser.EnumFieldContext): ProtoEnumField { + val name = ctx.ident().text + val number = visitIntLit(ctx.intLit()) val options = visitEnumValueOptions(ctx.enumValueOptions()) return visitEnumField(ctx, name, number, options, ctx.MINUS() != null) @@ -490,6 +728,10 @@ class ProtobufModelBuilderVisitor( return ctx?.enumValueOption().orEmpty().map { visitEnumValueOption(it) } } + override fun visitEnumValueOptions(ctx: Protobuf2Parser.EnumValueOptionsContext?): List { + return ctx?.enumValueOption().orEmpty().map { visitEnumValueOption(it) } + } + override fun visitEnumValueOption(ctx: ProtobufEditionsParser.EnumValueOptionContext): ProtoOption { return ProtoOption(name = ctx.optionName().text, value = ctx.constant().text, ctx) } @@ -498,7 +740,13 @@ class ProtobufModelBuilderVisitor( return ProtoOption(name = ctx.optionName().text, value = ctx.constant().text, ctx) } + override fun visitEnumValueOption(ctx: Protobuf2Parser.EnumValueOptionContext): ProtoOption { + return ProtoOption(name = ctx.optionName().text, value = ctx.constant().text, ctx) + } + + // ------------------------- // One-Of Parsing + // ------------------------- override fun visitOneof(ctx: ProtobufEditionsParser.OneofContext): ProtoOneOf { val name = ctx.oneofName().text @@ -518,10 +766,19 @@ class ProtobufModelBuilderVisitor( return ProtoOneOf(name = name, fields = fields, options = options) } + override fun visitOneof(ctx: Protobuf2Parser.OneofContext): ProtoOneOf { + val name = ctx.oneofName().text + + val options = ctx.optionStatement().map { visitOptionStatement(it) } + val fields = ctx.oneofField().map { visitOneofField(it) } + + return ProtoOneOf(name = name, fields = fields, options = options) + } + override fun visitOneofField(ctx: ProtobufEditionsParser.OneofFieldContext): ProtoOneOfField { val type = visitType_(ctx.type_()) val name = ctx.fieldName().text - val number = ctx.fieldNumber().parseInt() + val number = visitIntLit(ctx.fieldNumber().intLit()) val options = visitFieldOptions(ctx.fieldOptions()) @@ -537,7 +794,23 @@ class ProtobufModelBuilderVisitor( override fun visitOneofField(ctx: Protobuf3Parser.OneofFieldContext): ProtoOneOfField { val type = visitType_(ctx.type_()) val name = ctx.fieldName().text - val number = ctx.fieldNumber().parseInt() + val number = visitIntLit(ctx.fieldNumber().intLit()) + + val options = visitFieldOptions(ctx.fieldOptions()) + + return ProtoOneOfField( + type = type, + name = name, + number = number, + options = options, + ctx = ctx + ) + } + + override fun visitOneofField(ctx: Protobuf2Parser.OneofFieldContext): ProtoOneOfField { + val type = visitType_(ctx.type_()) + val name = ctx.fieldName().text + val number = visitIntLit(ctx.fieldNumber().intLit()) val options = visitFieldOptions(ctx.fieldOptions()) @@ -550,7 +823,9 @@ class ProtobufModelBuilderVisitor( ) } + // ------------------------- // Service parsing + // ------------------------- override fun visitServiceDef(ctx: ProtobufEditionsParser.ServiceDefContext): ProtoService { val name = ctx.serviceName().text @@ -578,6 +853,19 @@ class ProtobufModelBuilderVisitor( ) } + override fun visitServiceDef(ctx: Protobuf2Parser.ServiceDefContext): ProtoService { + val name = ctx.serviceName().text + val options = ctx.serviceElement().mapNotNull { it.optionStatement() }.map { visitOptionStatement(it) } + val rpcs = ctx.serviceElement().mapNotNull { it.rpc() }.map { visitRpc(it) } + + return ProtoService( + name = name, + options = options, + rpcs = rpcs, + ctx = ctx + ) + } + override fun visitRpc(ctx: ProtobufEditionsParser.RpcContext): ProtoRpc { val name = ctx.rpcName().text val clientType = visitMessageType(ctx.messageType(0)) @@ -614,10 +902,31 @@ class ProtobufModelBuilderVisitor( ) } + override fun visitRpc(ctx: Protobuf2Parser.RpcContext): ProtoRpc { + val name = ctx.rpcName().text + val clientType = visitMessageType(ctx.messageType(0)) + val serverType = visitMessageType(ctx.messageType(1)) + val isClientStream = ctx.clientStream != null + val isServerStream = ctx.serverStream != null + val options = ctx.optionStatement().map { visitOptionStatement(it) } + + return ProtoRpc( + name = name, + sendType = clientType, + returnType = serverType, + isSendingStream = isClientStream, + isReceivingStream = isServerStream, + options = options + ) + } + override fun visitServiceElement(ctx: ProtobufEditionsParser.ServiceElementContext?): Any = Unit override fun visitServiceElement(ctx: Protobuf3Parser.ServiceElementContext?): Any = Unit + override fun visitServiceElement(ctx: Protobuf2Parser.ServiceElementContext): Any = Unit + // ------------------------- // Extensions + // ------------------------- override fun visitExtendDef(ctx: ProtobufEditionsParser.ExtendDefContext): ProtoExtensionDefinition { val messageDef = ctx.messageType().text @@ -633,11 +942,37 @@ class ProtobufModelBuilderVisitor( return ProtoExtensionDefinition(ProtoType.DefType(messageDef, ctx.messageType()), fields, ctx) } + override fun visitExtendDef(ctx: Protobuf2Parser.ExtendDefContext): Proto2ExtendDefinitionData { + val messageDef = ctx.messageType().text + + val elements = ctx.extendElement() + val fieldsFromFields = elements.mapNotNull { it.field() }.map { visitField(it) } + // Groups inside extend: best-effort mapping to a field (group message is not modeled in ProtoExtensionDefinition). + val groups = elements.mapNotNull { it.group() }.map { grp -> visitGroup(grp) } + + val extensionDefinition = ProtoExtensionDefinition( + messageType = ProtoType.DefType(messageDef, ctx.messageType()), + fields = fieldsFromFields + groups.map { it.field }, + ctx = ctx + ) + + return Proto2ExtendDefinitionData( + extensionDefinition = extensionDefinition, + groupMessages = groups.map { it.message } + ) + } + override fun visitExtensions(ctx: ProtobufEditionsParser.ExtensionsContext): ProtoExtensionRanges { return ProtoExtensionRanges(ranges = visitRanges(ctx.ranges())) } + override fun visitExtensions(ctx: Protobuf2Parser.ExtensionsContext): ProtoExtensionRanges { + return ProtoExtensionRanges(ranges = visitRanges(ctx.ranges())) + } + + // ------------------------- // Type parsing + // ------------------------- override fun visitType_(ctx: ProtobufEditionsParser.Type_Context): ProtoType { return when { @@ -657,7 +992,7 @@ class ProtobufModelBuilderVisitor( ctx.BOOL() != null -> ProtoType.BoolType ctx.STRING() != null -> ProtoType.StringType ctx.BYTES() != null -> ProtoType.BytesType - else -> throw ParseException("Unknown type found.", ctx) + else -> throw ParseException("Unknown type found.", ctx, filePath) } } @@ -679,7 +1014,29 @@ class ProtobufModelBuilderVisitor( ctx.BOOL() != null -> ProtoType.BoolType ctx.STRING() != null -> ProtoType.StringType ctx.BYTES() != null -> ProtoType.BytesType - else -> throw ParseException("Unknown type found.", ctx) + else -> throw ParseException("Unknown type found.", ctx, filePath) + } + } + + override fun visitType_(ctx: Protobuf2Parser.Type_Context): ProtoType { + return when { + ctx.messageType() != null || ctx.enumType() != null -> ProtoType.DefType(ctx.text, ctx) + ctx.DOUBLE() != null -> ProtoType.DoubleType + ctx.FLOAT() != null -> ProtoType.FloatType + ctx.INT32() != null -> ProtoType.Int32Type + ctx.INT64() != null -> ProtoType.Int64Type + ctx.UINT32() != null -> ProtoType.UInt32Type + ctx.UINT64() != null -> ProtoType.UInt64Type + ctx.SINT32() != null -> ProtoType.SInt32Type + ctx.SINT64() != null -> ProtoType.SInt64Type + ctx.FIXED32() != null -> ProtoType.Fixed32Type + ctx.FIXED64() != null -> ProtoType.Fixed64Type + ctx.SFIXED32() != null -> ProtoType.SFixed32Type + ctx.SFIXED64() != null -> ProtoType.SFixed64Type + ctx.BOOL() != null -> ProtoType.BoolType + ctx.STRING() != null -> ProtoType.StringType + ctx.BYTES() != null -> ProtoType.BytesType + else -> throw ParseException("Unknown type found.", ctx, filePath) } } @@ -697,7 +1054,7 @@ class ProtobufModelBuilderVisitor( ctx.SFIXED64() != null -> ProtoType.SFixed64Type ctx.BOOL() != null -> ProtoType.BoolType ctx.STRING() != null -> ProtoType.StringType - else -> throw ParseException("Unknown type found.", ctx) + else -> throw ParseException("Unknown type found.", ctx, filePath) } } @@ -715,7 +1072,25 @@ class ProtobufModelBuilderVisitor( ctx.SFIXED64() != null -> ProtoType.SFixed64Type ctx.BOOL() != null -> ProtoType.BoolType ctx.STRING() != null -> ProtoType.StringType - else -> throw ParseException("Unknown type found.", ctx) + else -> throw ParseException("Unknown type found.", ctx, filePath) + } + } + + override fun visitKeyType(ctx: Protobuf2Parser.KeyTypeContext): ProtoType.MapKeyType { + return when { + ctx.INT32() != null -> ProtoType.Int32Type + ctx.INT64() != null -> ProtoType.Int64Type + ctx.UINT32() != null -> ProtoType.UInt32Type + ctx.UINT64() != null -> ProtoType.UInt64Type + ctx.SINT32() != null -> ProtoType.SInt32Type + ctx.SINT64() != null -> ProtoType.SInt64Type + ctx.FIXED32() != null -> ProtoType.Fixed32Type + ctx.FIXED64() != null -> ProtoType.Fixed64Type + ctx.SFIXED32() != null -> ProtoType.SFixed32Type + ctx.SFIXED64() != null -> ProtoType.SFixed64Type + ctx.BOOL() != null -> ProtoType.BoolType + ctx.STRING() != null -> ProtoType.StringType + else -> throw ParseException("Unknown type found.", ctx, filePath) } } @@ -727,6 +1102,14 @@ class ProtobufModelBuilderVisitor( return ProtoType.DefType(ctx.text, ctx) } + override fun visitMessageType(ctx: Protobuf2Parser.MessageTypeContext): ProtoType.DefType { + return ProtoType.DefType(ctx.text, ctx) + } + + // ------------------------- + // Remaining visitor methods (stubs) + // ------------------------- + override fun visit(tree: ParseTree): Any = Unit override fun visitChildren(node: RuleNode?): Any = Unit @@ -737,83 +1120,133 @@ class ProtobufModelBuilderVisitor( override fun visitEdition(ctx: ProtobufEditionsParser.EditionContext?): Any = Unit override fun visitSyntax(ctx: Protobuf3Parser.SyntaxContext?): Any = Unit + override fun visitSyntax(ctx: Protobuf2Parser.SyntaxContext?): Any = Unit override fun visitPackageStatement(ctx: ProtobufEditionsParser.PackageStatementContext?): Any = Unit override fun visitPackageStatement(ctx: Protobuf3Parser.PackageStatementContext?): Any = Unit + override fun visitPackageStatement(ctx: Protobuf2Parser.PackageStatementContext?): Any = Unit override fun visitOptionName(ctx: ProtobufEditionsParser.OptionNameContext?): Any = Unit override fun visitOptionName(ctx: Protobuf3Parser.OptionNameContext?): Any = Unit + override fun visitOptionName(ctx: Protobuf2Parser.OptionNameContext?): Any = Unit override fun visitFieldLabel(ctx: ProtobufEditionsParser.FieldLabelContext?): Any = Unit override fun visitFieldLabel(ctx: Protobuf3Parser.FieldLabelContext?): Any = Unit + override fun visitFieldLabel(ctx: Protobuf2Parser.FieldLabelContext?): Any = Unit override fun visitFieldNumber(ctx: ProtobufEditionsParser.FieldNumberContext?): Any = Unit override fun visitFieldNumber(ctx: Protobuf3Parser.FieldNumberContext?): Any = Unit + override fun visitFieldNumber(ctx: Protobuf2Parser.FieldNumberContext?): Any = Unit override fun visitTopLevelDef(ctx: ProtobufEditionsParser.TopLevelDefContext?): Any = Unit override fun visitTopLevelDef(ctx: Protobuf3Parser.TopLevelDefContext?): Any = Unit + override fun visitTopLevelDef(ctx: Protobuf2Parser.TopLevelDefContext?): Any = Unit override fun visitEnumBody(ctx: ProtobufEditionsParser.EnumBodyContext?): Any = Unit override fun visitEnumBody(ctx: Protobuf3Parser.EnumBodyContext?): Any = Unit + override fun visitEnumBody(ctx: Protobuf2Parser.EnumBodyContext?): Any = Unit override fun visitEnumElement(ctx: ProtobufEditionsParser.EnumElementContext?): Any = Unit override fun visitEnumElement(ctx: Protobuf3Parser.EnumElementContext?): Any = Unit + override fun visitEnumElement(ctx: Protobuf2Parser.EnumElementContext?): Any = Unit override fun visitConstant(ctx: ProtobufEditionsParser.ConstantContext?): Any = Unit override fun visitConstant(ctx: Protobuf3Parser.ConstantContext?): Any = Unit + override fun visitConstant(ctx: Protobuf2Parser.ConstantContext?): Any = Unit override fun visitBlockLit(ctx: ProtobufEditionsParser.BlockLitContext?): Any = Unit override fun visitBlockLit(ctx: Protobuf3Parser.BlockLitContext?): Any = Unit + override fun visitBlockLit(ctx: Protobuf2Parser.BlockLitContext?): Any = Unit override fun visitEmptyStatement_(ctx: ProtobufEditionsParser.EmptyStatement_Context?): Any = Unit override fun visitEmptyStatement_(ctx: Protobuf3Parser.EmptyStatement_Context?): Any = Unit + override fun visitEmptyStatement_(ctx: Protobuf2Parser.EmptyStatement_Context?): Any = Unit override fun visitIdent(ctx: ProtobufEditionsParser.IdentContext?): Any = Unit override fun visitIdent(ctx: Protobuf3Parser.IdentContext?): Any = Unit + override fun visitIdent(ctx: Protobuf2Parser.IdentContext?): Any = Unit override fun visitFullIdent(ctx: ProtobufEditionsParser.FullIdentContext?): Any = Unit override fun visitFullIdent(ctx: Protobuf3Parser.FullIdentContext?): Any = Unit + override fun visitFullIdent(ctx: Protobuf2Parser.FullIdentContext?): Any = Unit override fun visitMessageName(ctx: ProtobufEditionsParser.MessageNameContext?): Any = Unit override fun visitMessageName(ctx: Protobuf3Parser.MessageNameContext?): Any = Unit + override fun visitMessageName(ctx: Protobuf2Parser.MessageNameContext?): Any = Unit override fun visitEnumName(ctx: ProtobufEditionsParser.EnumNameContext?): Any = Unit override fun visitEnumName(ctx: Protobuf3Parser.EnumNameContext?): Any = Unit + override fun visitEnumName(ctx: Protobuf2Parser.EnumNameContext?): Any = Unit override fun visitFieldName(ctx: ProtobufEditionsParser.FieldNameContext?): Any = Unit override fun visitFieldName(ctx: Protobuf3Parser.FieldNameContext?): Any = Unit + override fun visitFieldName(ctx: Protobuf2Parser.FieldNameContext?): Any = Unit override fun visitOneofName(ctx: ProtobufEditionsParser.OneofNameContext?): Any = Unit override fun visitOneofName(ctx: Protobuf3Parser.OneofNameContext?): Any = Unit + override fun visitOneofName(ctx: Protobuf2Parser.OneofNameContext?): Any = Unit override fun visitMapName(ctx: ProtobufEditionsParser.MapNameContext?): Any = Unit override fun visitMapName(ctx: Protobuf3Parser.MapNameContext?): Any = Unit + override fun visitMapName(ctx: Protobuf2Parser.MapNameContext?): Any = Unit override fun visitServiceName(ctx: ProtobufEditionsParser.ServiceNameContext?): Any = Unit override fun visitServiceName(ctx: Protobuf3Parser.ServiceNameContext?): Any = Unit + override fun visitServiceName(ctx: Protobuf2Parser.ServiceNameContext?): Any = Unit override fun visitRpcName(ctx: ProtobufEditionsParser.RpcNameContext?): Any = Unit override fun visitRpcName(ctx: Protobuf3Parser.RpcNameContext?): Any = Unit + override fun visitRpcName(ctx: Protobuf2Parser.RpcNameContext?): Any = Unit override fun visitEnumType(ctx: ProtobufEditionsParser.EnumTypeContext?): Any = Unit override fun visitEnumType(ctx: Protobuf3Parser.EnumTypeContext?): Any = Unit + override fun visitEnumType(ctx: Protobuf2Parser.EnumTypeContext?): Any = Unit + + override fun visitIntLit(ctx: ProtobufEditionsParser.IntLitContext): Int = visitIntLit(ctx.text, ctx) + override fun visitIntLit(ctx: Protobuf3Parser.IntLitContext): Int = visitIntLit(ctx.text, ctx) + override fun visitIntLit(ctx: Protobuf2Parser.IntLitContext): Int = visitIntLit(ctx.text, ctx) - override fun visitIntLit(ctx: ProtobufEditionsParser.IntLitContext?): Any = Unit - override fun visitIntLit(ctx: Protobuf3Parser.IntLitContext?): Any = Unit + private fun visitIntLit(text: String, ctx: ParserRuleContext): Int { + return when { + text.startsWith("0x") || text.startsWith("0X") -> { + text.substring(2).toIntOrNull(16) + } + text.startsWith('0') -> { + text.toIntOrNull(8) + } + else -> text.toIntOrNull() + } ?: throw ParseException("Could not parse integer", ctx, filePath) + } override fun visitStrLit(ctx: ProtobufEditionsParser.StrLitContext?): Any = Unit override fun visitStrLit(ctx: Protobuf3Parser.StrLitContext?): Any = Unit + override fun visitStrLit(ctx: Protobuf2Parser.StrLitContext?): Any = Unit override fun visitBoolLit(ctx: ProtobufEditionsParser.BoolLitContext?): Any = Unit override fun visitBoolLit(ctx: Protobuf3Parser.BoolLitContext?): Any = Unit + override fun visitBoolLit(ctx: Protobuf2Parser.BoolLitContext?): Any = Unit override fun visitFloatLit(ctx: ProtobufEditionsParser.FloatLitContext?): Any = Unit override fun visitFloatLit(ctx: Protobuf3Parser.FloatLitContext?): Any = Unit + override fun visitFloatLit(ctx: Protobuf2Parser.FloatLitContext?): Any = Unit override fun visitKeywords(ctx: ProtobufEditionsParser.KeywordsContext?): Any = Unit override fun visitKeywords(ctx: Protobuf3Parser.KeywordsContext?): Any = Unit - - private fun ParserRuleContext.parseInt(): Int { - return text.toIntOrNull() ?: throw ParseException("Could not parse integer", this) - } + override fun visitKeywords(ctx: Protobuf2Parser.KeywordsContext?): Any = Unit + + // Protobuf2-only nodes (stubs / safety) + override fun visitExtendElement(ctx: Protobuf2Parser.ExtendElementContext?): Any = Unit + override fun visitStream(ctx: Protobuf2Parser.StreamContext?): Any = Unit + override fun visitGroupName(ctx: Protobuf2Parser.GroupNameContext?): Any = Unit + override fun visitStreamName(ctx: Protobuf2Parser.StreamNameContext?): Any = Unit + + data class ParsedGroup( + val field: ProtoMessageField, + val message: ProtoMessage + ) + + data class Proto2ExtendDefinitionData( + val extensionDefinition: ProtoExtensionDefinition, + val groupMessages: List + ) } diff --git a/kmp-grpc-plugin/src/main/java/io/github/timortel/kmpgrpc/plugin/sourcegeneration/parsing/ProtobufParserException.kt b/kmp-grpc-plugin/src/main/java/io/github/timortel/kmpgrpc/plugin/sourcegeneration/parsing/ProtobufParserException.kt index 32b941ca..09e14e9d 100644 --- a/kmp-grpc-plugin/src/main/java/io/github/timortel/kmpgrpc/plugin/sourcegeneration/parsing/ProtobufParserException.kt +++ b/kmp-grpc-plugin/src/main/java/io/github/timortel/kmpgrpc/plugin/sourcegeneration/parsing/ProtobufParserException.kt @@ -1,5 +1,9 @@ package io.github.timortel.kmpgrpc.plugin.sourcegeneration.parsing +import io.github.timortel.kmpgrpc.plugin.sourcegeneration.util.toFilePositionString import org.antlr.v4.runtime.ParserRuleContext -class ProtobufParserException(override val message: String?, val ctx: ParserRuleContext) : Exception() +class ProtobufParserException(val msg: String?, val ctx: ParserRuleContext, val filePath: String) : Exception() { + override val message: String + get() ="${ctx.toFilePositionString(filePath)}: $msg" +} diff --git a/kmp-grpc-plugin/src/test/java/io/github/timortel/kotlin_multiplatform_grpc_plugin/WellKnownTypesFolder.kt b/kmp-grpc-plugin/src/test/java/io/github/timortel/kotlin_multiplatform_grpc_plugin/WellKnownTypesFolder.kt new file mode 100644 index 00000000..08098084 --- /dev/null +++ b/kmp-grpc-plugin/src/test/java/io/github/timortel/kotlin_multiplatform_grpc_plugin/WellKnownTypesFolder.kt @@ -0,0 +1,23 @@ +package io.github.timortel.kotlin_multiplatform_grpc_plugin + +val wellKnownTypesFolder = FakeInputDirectory( + name = "google", + path = "google", + files = listOf( + FakeInputDirectory( + name = "protobuf", + path = "protobuf", + files = listOf( + FakeInputFile( + name = "descriptor.proto", + content = Thread.currentThread().contextClassLoader.getResourceAsStream("google/protobuf/descriptor.proto") + .use { inputStream -> + inputStream!!.bufferedReader().use { bufferedReader -> + bufferedReader.readText() + } + } + ) + ) + ) + ) +) diff --git a/kmp-grpc-plugin/src/test/java/io/github/timortel/kotlin_multiplatform_grpc_plugin/modeltree/DefaultEnumValueTest.kt b/kmp-grpc-plugin/src/test/java/io/github/timortel/kotlin_multiplatform_grpc_plugin/modeltree/DefaultEnumValueTest.kt new file mode 100644 index 00000000..f7811432 --- /dev/null +++ b/kmp-grpc-plugin/src/test/java/io/github/timortel/kotlin_multiplatform_grpc_plugin/modeltree/DefaultEnumValueTest.kt @@ -0,0 +1,80 @@ +package io.github.timortel.kotlin_multiplatform_grpc_plugin.modeltree + +import com.google.testing.junit.testparameterinjector.junit5.TestParameter +import com.google.testing.junit.testparameterinjector.junit5.TestParameterInjectorTest +import io.github.timortel.kmpgrpc.plugin.sourcegeneration.model.declaration.message.field.ProtoMessageField +import io.github.timortel.kotlin_multiplatform_grpc_plugin.validation.BaseValidationTest +import org.junit.jupiter.api.Assertions + +class DefaultEnumValueTest : BaseModelTreeTest() { + + @TestParameterInjectorTest + fun `test USING proto langauge version WHEN proto enum field is used without default value THEN the first entry is the default value`( + @TestParameter version: BaseValidationTest.ProtoVersion + ) { + val fieldPrefix = when (version) { + BaseValidationTest.ProtoVersion.PROTO2 -> "required" + else -> "" + } + + assertDefaultEnumValue( + proto = """ + ${if (version == BaseValidationTest.ProtoVersion.EDITION2024) "option features.(pb.java).nest_in_file_class = YES;" else ""} + + enum A { + A = 0; + B = 1; + } + + message C { + $fieldPrefix A a = 1; + } + """, + version = version, + expectedDefaultValue = "A.A" + ) + } + + @TestParameterInjectorTest + fun `test USING langauge version WHEN proto enum field is used with default value THEN the correct default value is chosen`( + @TestParameter(value = ["PROTO2", "EDITION2023", "EDITION2024"]) version: BaseValidationTest.ProtoVersion + ) { + val fieldPrefix = when (version) { + BaseValidationTest.ProtoVersion.PROTO2 -> "required" + else -> "" + } + + assertDefaultEnumValue( + proto = """ + ${if (version == BaseValidationTest.ProtoVersion.EDITION2024) "option features.(pb.java).nest_in_file_class = YES;" else ""} + enum A { + A = 0; + B = 1; + } + + message C { + $fieldPrefix A a = 1 [default = B]; + } + """, + version = version, + expectedDefaultValue = "A.B" + ) + } + + private fun assertDefaultEnumValue( + proto: String, + version: BaseValidationTest.ProtoVersion, + expectedDefaultValue: String + ) { + val project = buildProject(proto.trimIndent(), version) + + val field = project + .findMessage("C") + .findField("a") + .assertIsInstance() + + val defaultValueCode = field.defaultValue().toString() + + Assertions.assertEquals("TestFile.$expectedDefaultValue", defaultValueCode) + } +} diff --git a/kmp-grpc-plugin/src/test/java/io/github/timortel/kotlin_multiplatform_grpc_plugin/validation/BaseValidationTest.kt b/kmp-grpc-plugin/src/test/java/io/github/timortel/kotlin_multiplatform_grpc_plugin/validation/BaseValidationTest.kt index 40086100..72fbf8df 100644 --- a/kmp-grpc-plugin/src/test/java/io/github/timortel/kotlin_multiplatform_grpc_plugin/validation/BaseValidationTest.kt +++ b/kmp-grpc-plugin/src/test/java/io/github/timortel/kotlin_multiplatform_grpc_plugin/validation/BaseValidationTest.kt @@ -32,8 +32,14 @@ abstract class BaseValidationTest { } enum class ProtoVersion(val header: String) { + PROTO2("syntax = \"proto2\";"), PROTO3("syntax = \"proto3\";"), EDITION2023("edition = \"2023\";"), - EDITION2024("edition = \"2024\";") + EDITION2024("edition = \"2024\";"); + + val fieldPrefix: String get() = when (this) { + PROTO2 -> "optional" + else -> "" + } } } diff --git a/kmp-grpc-plugin/src/test/java/io/github/timortel/kotlin_multiplatform_grpc_plugin/validation/EnumImportValidationTest.kt b/kmp-grpc-plugin/src/test/java/io/github/timortel/kotlin_multiplatform_grpc_plugin/validation/EnumImportValidationTest.kt index 050651ce..d2227ca3 100644 --- a/kmp-grpc-plugin/src/test/java/io/github/timortel/kotlin_multiplatform_grpc_plugin/validation/EnumImportValidationTest.kt +++ b/kmp-grpc-plugin/src/test/java/io/github/timortel/kotlin_multiplatform_grpc_plugin/validation/EnumImportValidationTest.kt @@ -9,7 +9,7 @@ import org.junit.jupiter.api.assertThrows class EnumImportValidationTest : BaseValidationTest() { @Test - fun `test WHEN proto3 imports closed enum THEN error is thrown`() { + fun `test WHEN proto3 imports editions closed enum THEN error is thrown`() { assertThrows { runGenerator( listOf( @@ -44,6 +44,41 @@ class EnumImportValidationTest : BaseValidationTest() { } } + @Test + fun `test WHEN proto3 imports proto2 enum THEN error is thrown`() { + assertThrows { + runGenerator( + listOf( + FakeInputDirectory( + name = "dir", + files = listOf( + createProtoFile( + fileHeader = ProtoVersion.PROTO2.header, + content = """ + enum A { + DEFAULT = 0; + } + """.trimIndent(), + name = "file1.proto" + ), + createProtoFile( + fileHeader = ProtoVersion.PROTO3.header, + content = """ + import "file1.proto"; + + message B { + A a = 1; + } + """.trimIndent(), + name = "file2.proto" + ) + ) + ) + ) + ) + } + } + @Test fun `test WHEN proto3 imports open enum THEN no error is thrown`() { runGenerator( diff --git a/kmp-grpc-plugin/src/test/java/io/github/timortel/kotlin_multiplatform_grpc_plugin/validation/EnumValidationTests.kt b/kmp-grpc-plugin/src/test/java/io/github/timortel/kotlin_multiplatform_grpc_plugin/validation/EnumValidationTests.kt index 0a481cb6..00e4eedd 100644 --- a/kmp-grpc-plugin/src/test/java/io/github/timortel/kotlin_multiplatform_grpc_plugin/validation/EnumValidationTests.kt +++ b/kmp-grpc-plugin/src/test/java/io/github/timortel/kotlin_multiplatform_grpc_plugin/validation/EnumValidationTests.kt @@ -1,16 +1,19 @@ package io.github.timortel.kotlin_multiplatform_grpc_plugin.validation -import io.github.timortel.kotlin_multiplatform_grpc_plugin.matchWarning +import com.google.testing.junit.testparameterinjector.junit5.TestParameter +import com.google.testing.junit.testparameterinjector.junit5.TestParameterInjectorTest import io.github.timortel.kmpgrpc.plugin.sourcegeneration.CompilationException import io.github.timortel.kmpgrpc.plugin.sourcegeneration.Warnings +import io.github.timortel.kotlin_multiplatform_grpc_plugin.matchWarning import io.mockk.verify -import org.junit.jupiter.api.Test import org.junit.jupiter.api.assertThrows class EnumValidationTests : BaseValidationTest() { - @Test - fun `test WHEN enum has two fields with the same name THEN a compilation exception is thrown`() { + @TestParameterInjectorTest + fun `test WHEN enum has two fields with the same name THEN a compilation exception is thrown`( + @TestParameter protoVersion: ProtoVersion + ) { assertThrows { runGenerator( """ @@ -19,13 +22,16 @@ class EnumValidationTests : BaseValidationTest() { b = 1; b = 2; } - """.trimIndent() + """.trimIndent(), + protoVersion ) } } - @Test - fun `test WHEN enum uses a directly reserved number THEN a compilation exception is thrown`() { + @TestParameterInjectorTest + fun `test WHEN enum uses a directly reserved number THEN a compilation exception is thrown`( + @TestParameter protoVersion: ProtoVersion + ) { assertThrows { runGenerator( """ @@ -34,13 +40,16 @@ class EnumValidationTests : BaseValidationTest() { a = 0; b = 1; } - """.trimIndent() + """.trimIndent(), + protoVersion ) } } - @Test - fun `test WHEN enum uses a reserved number in range THEN a compilation exception is thrown`() { + @TestParameterInjectorTest + fun `test WHEN enum uses a reserved number in range THEN a compilation exception is thrown`( + @TestParameter protoVersion: ProtoVersion + ) { assertThrows { runGenerator( """ @@ -49,13 +58,16 @@ class EnumValidationTests : BaseValidationTest() { a = 0; b = 14; } - """.trimIndent() + """.trimIndent(), + protoVersion ) } } - @Test - fun `test WHEN enum uses a reserved field name THEN a compilation exception is thrown`() { + @TestParameterInjectorTest + fun `test WHEN enum uses a reserved field name THEN a compilation exception is thrown`( + @TestParameter protoVersion: ProtoVersion + ) { assertThrows { runGenerator( """ @@ -64,39 +76,48 @@ class EnumValidationTests : BaseValidationTest() { a = 0; b = 1; } - """.trimIndent() + """.trimIndent(), + protoVersion ) } } - @Test - fun `test WHEN enum does not have default field THEN a compilation exception is thrown`() { + @TestParameterInjectorTest + fun `test WHEN enum does not have default field THEN a compilation exception is thrown`( + @TestParameter(value = ["PROTO3", "EDITION2023", "EDITION2024"]) protoVersion: ProtoVersion + ) { assertThrows { runGenerator( """ enum TestEnum { field = 1; } - """.trimIndent() + """.trimIndent(), + protoVersion ) } } - @Test - fun `test WHEN enum has no field THEN compilation exception is thrown`() { + @TestParameterInjectorTest + fun `test WHEN enum has no field THEN compilation exception is thrown`( + @TestParameter protoVersion: ProtoVersion + ) { assertThrows { runGenerator( """ enum TestEnum { } - """.trimIndent() + """.trimIndent(), + protoVersion ) } } - @Test - fun `test WHEN enum has enum aliases but without the option THEN a warning is printed`() { + @TestParameterInjectorTest + fun `test WHEN enum has enum aliases but without the option THEN a warning is printed`( + @TestParameter protoVersion: ProtoVersion + ) { runGenerator( """ enum TestEnum { @@ -104,14 +125,17 @@ class EnumValidationTests : BaseValidationTest() { field2 = 1; field3 = 1; } - """.trimIndent() + """.trimIndent(), + protoVersion ) verify(atLeast = 1) { logger.warn(matchWarning(Warnings.enumAliasWithoutOption)) } } - @Test - fun `test WHEN enum has enum aliases and the option THEN no warning is printed`() { + @TestParameterInjectorTest + fun `test WHEN enum has enum aliases and the option THEN no warning is printed`( + @TestParameter protoVersion: ProtoVersion + ) { runGenerator( """ enum TestEnum { @@ -120,9 +144,10 @@ class EnumValidationTests : BaseValidationTest() { field2 = 1; field3 = 1; } - """.trimIndent() + """.trimIndent(), + protoVersion ) verify(atLeast = 0) { logger.warn(matchWarning(Warnings.enumAliasWithoutOption)) } } -} \ No newline at end of file +} diff --git a/kmp-grpc-plugin/src/test/java/io/github/timortel/kotlin_multiplatform_grpc_plugin/validation/ExtensionDefinitionValidationTests.kt b/kmp-grpc-plugin/src/test/java/io/github/timortel/kotlin_multiplatform_grpc_plugin/validation/ExtensionDefinitionValidationTests.kt index eae90e8e..68effe7e 100644 --- a/kmp-grpc-plugin/src/test/java/io/github/timortel/kotlin_multiplatform_grpc_plugin/validation/ExtensionDefinitionValidationTests.kt +++ b/kmp-grpc-plugin/src/test/java/io/github/timortel/kotlin_multiplatform_grpc_plugin/validation/ExtensionDefinitionValidationTests.kt @@ -1,15 +1,20 @@ package io.github.timortel.kotlin_multiplatform_grpc_plugin.validation +import com.google.testing.junit.testparameterinjector.junit5.TestParameter +import com.google.testing.junit.testparameterinjector.junit5.TestParameterInjectorTest import io.github.timortel.kmpgrpc.plugin.sourcegeneration.CompilationException import io.github.timortel.kotlin_multiplatform_grpc_plugin.FakeInputDirectory import io.github.timortel.kotlin_multiplatform_grpc_plugin.createProtoFile -import org.junit.jupiter.api.Test import org.junit.jupiter.api.assertThrows class ExtensionDefinitionValidationTests : BaseValidationTest() { - @Test - fun `test GIVEN an extension references an enum WHEN generating the code THEN an error is thrown`() { + @TestParameterInjectorTest + fun `test GIVEN an extension references an enum WHEN generating the code THEN an error is thrown`( + @TestParameter(value = ["PROTO2", "EDITION2023", "EDITION2024"]) protoVersion: ProtoVersion + ) { + val declarationPrefix = getDeclarationPrefix(protoVersion) + assertThrows { runGenerator( """ @@ -18,16 +23,20 @@ class ExtensionDefinitionValidationTests : BaseValidationTest() { } extend A { - string a = 1; + $declarationPrefix string a = 1; } """.trimIndent(), - protoVersion = ProtoVersion.EDITION2023, + protoVersion = protoVersion ) } } - @Test - fun `test GIVEN duplicated extensions in the same extension WHEN generating the code THEN an error is thrown`() { + @TestParameterInjectorTest + fun `test GIVEN duplicated extensions in the same extension WHEN generating the code THEN an error is thrown`( + @TestParameter(value = ["PROTO2", "EDITION2023", "EDITION2024"]) protoVersion: ProtoVersion + ) { + val declarationPrefix = getDeclarationPrefix(protoVersion) + assertThrows { runGenerator( """ @@ -36,17 +45,21 @@ class ExtensionDefinitionValidationTests : BaseValidationTest() { } extend A { - string a = 1; - string a = 2; + $declarationPrefix string a = 1; + $declarationPrefix string a = 2; } """.trimIndent(), - protoVersion = ProtoVersion.EDITION2023, + protoVersion = protoVersion, ) } } - @Test - fun `test GIVEN duplicated extensions in different extensions WHEN generating the code THEN an error is thrown`() { + @TestParameterInjectorTest + fun `test GIVEN duplicated extensions in different extensions WHEN generating the code THEN an error is thrown`( + @TestParameter(value = ["PROTO2", "EDITION2023", "EDITION2024"]) protoVersion: ProtoVersion + ) { + val declarationPrefix = getDeclarationPrefix(protoVersion) + assertThrows { runGenerator( """ @@ -59,20 +72,24 @@ class ExtensionDefinitionValidationTests : BaseValidationTest() { } extend A { - string a = 1; + $declarationPrefix string a = 1; } extend B { - string a = 2; + $declarationPrefix string a = 2; } """.trimIndent(), - protoVersion = ProtoVersion.EDITION2023, + protoVersion = protoVersion, ) } } - @Test - fun `test GIVEN a message that is not extendable and a defined extension for the message WHEN generating the code THEN an error is thrown`() { + @TestParameterInjectorTest + fun `test GIVEN a message that is not extendable and a defined extension for the message WHEN generating the code THEN an error is thrown`( + @TestParameter(value = ["PROTO2", "EDITION2023", "EDITION2024"]) protoVersion: ProtoVersion + ) { + val declarationPrefix = getDeclarationPrefix(protoVersion) + assertThrows { runGenerator( """ @@ -81,16 +98,20 @@ class ExtensionDefinitionValidationTests : BaseValidationTest() { } extend A { - string a = 1; + $declarationPrefix string a = 1; } """.trimIndent(), - protoVersion = ProtoVersion.EDITION2023, + protoVersion = protoVersion, ) } } - @Test - fun `test GIVEN reused field numbers in extension in the extension definition WHEN generating the code THEN an error is thrown`() { + @TestParameterInjectorTest + fun `test GIVEN reused field numbers in extension in the extension definition WHEN generating the code THEN an error is thrown`( + @TestParameter(value = ["PROTO2", "EDITION2023", "EDITION2024"]) protoVersion: ProtoVersion + ) { + val declarationPrefix = getDeclarationPrefix(protoVersion) + assertThrows { runGenerator( """ @@ -99,40 +120,44 @@ class ExtensionDefinitionValidationTests : BaseValidationTest() { } extend A { - string a = 1; - string b = 1; + $declarationPrefix string a = 1; + $declarationPrefix string b = 1; } """.trimIndent(), - protoVersion = ProtoVersion.EDITION2023, + protoVersion = protoVersion, ) } } - @Test - fun `test GIVEN reused field numbers in extension in the extension definitions across multiple files WHEN generating the code THEN an error is thrown`() { + @TestParameterInjectorTest + fun `test GIVEN reused field numbers in extension in the extension definitions across multiple files WHEN generating the code THEN an error is thrown`( + @TestParameter(value = ["PROTO2", "EDITION2023", "EDITION2024"]) protoVersion: ProtoVersion + ) { + val declarationPrefix = getDeclarationPrefix(protoVersion) + val folder = FakeInputDirectory( name = "dir", files = listOf( createProtoFile( - fileHeader = ProtoVersion.EDITION2023.header, + fileHeader = protoVersion.header, """ message A { extensions 1 to 5; } extend A { - string a = 1; + $declarationPrefix string a = 1; } """.trimIndent(), name = "file1" ), createProtoFile( - fileHeader = ProtoVersion.EDITION2023.header, + fileHeader = protoVersion.header, """ import "file1"; extend A { - string b = 1; + $declarationPrefix string b = 1; } """.trimIndent(), name = "file2" @@ -145,8 +170,12 @@ class ExtensionDefinitionValidationTests : BaseValidationTest() { } } - @Test - fun `test WHEN message has a extension definition with a minimum field number smaller than 1 THEN a compilation exception is thrown`() { + @TestParameterInjectorTest + fun `test WHEN message has a extension definition with a minimum field number smaller than 1 THEN a compilation exception is thrown`( + @TestParameter(value = ["PROTO2", "EDITION2023", "EDITION2024"]) protoVersion: ProtoVersion + ) { + val declarationPrefix = getDeclarationPrefix(protoVersion) + assertThrows { runGenerator( """ @@ -155,16 +184,20 @@ class ExtensionDefinitionValidationTests : BaseValidationTest() { } extend A { - string a = 1; + $declarationPrefix string a = 1; } """.trimIndent(), - protoVersion = ProtoVersion.EDITION2023 + protoVersion = protoVersion ) } } - @Test - fun `test WHEN message has field with field number greater than max field number THEN a compilation exception is thrown`() { + @TestParameterInjectorTest + fun `test WHEN message has field with field number greater than max field number THEN a compilation exception is thrown`( + @TestParameter(value = ["PROTO2", "EDITION2023", "EDITION2024"]) protoVersion: ProtoVersion + ) { + val declarationPrefix = getDeclarationPrefix(protoVersion) + assertThrows { runGenerator( """ @@ -173,11 +206,16 @@ class ExtensionDefinitionValidationTests : BaseValidationTest() { } extend A { - string a = 1; + $declarationPrefix string a = 1; } """.trimIndent(), - protoVersion = ProtoVersion.EDITION2023 + protoVersion = protoVersion ) } } + + private fun getDeclarationPrefix(protoVersion: ProtoVersion): String = when (protoVersion) { + ProtoVersion.PROTO2 -> "optional" + else -> "" + } } diff --git a/kmp-grpc-plugin/src/test/java/io/github/timortel/kotlin_multiplatform_grpc_plugin/validation/OptionImportTest.kt b/kmp-grpc-plugin/src/test/java/io/github/timortel/kotlin_multiplatform_grpc_plugin/validation/OptionImportTest.kt new file mode 100644 index 00000000..5f591c94 --- /dev/null +++ b/kmp-grpc-plugin/src/test/java/io/github/timortel/kotlin_multiplatform_grpc_plugin/validation/OptionImportTest.kt @@ -0,0 +1,126 @@ +package io.github.timortel.kotlin_multiplatform_grpc_plugin.validation + +import io.github.timortel.kmpgrpc.plugin.sourcegeneration.CompilationException +import io.github.timortel.kotlin_multiplatform_grpc_plugin.FakeInputDirectory +import io.github.timortel.kotlin_multiplatform_grpc_plugin.createProtoFile +import io.github.timortel.kotlin_multiplatform_grpc_plugin.wellKnownTypesFolder +import org.junit.jupiter.api.Test +import org.junit.jupiter.api.assertThrows + +class OptionImportTest : BaseValidationTest() { + + @Test + fun `test GIVEN file with custom option WHEN importing it normally THEN all declarations are available`() { + runGenerator( + listOf( + FakeInputDirectory( + name = "dir", + files = listOf( + createProtoFile( + fileHeader = ProtoVersion.PROTO2.header, + content = """ + import "google/protobuf/descriptor.proto"; + extend google.protobuf.MessageOptions { + optional string custom_option = 51234; + } + + message B {} + """.trimIndent(), + name = "file1.proto" + ), + createProtoFile( + fileHeader = ProtoVersion.PROTO3.header, + content = """ + import "file1.proto"; + + message A { + option (custom_option) = "some value"; + B b = 1; + } + """.trimIndent(), + name = "file2.proto" + ), + wellKnownTypesFolder + ) + ) + ) + ) + } + + @Test + fun `test GIVEN file with custom option WHEN importing it as an option import and still using the declared message THEN an exception is thrown`() { + assertThrows { + runGenerator( + listOf( + FakeInputDirectory( + name = "dir", + files = listOf( + createProtoFile( + fileHeader = ProtoVersion.PROTO2.header, + content = """ + import "google/protobuf/descriptor.proto"; + extend google.protobuf.MessageOptions { + optional string custom_option = 51234; + } + + message B {} + """.trimIndent(), + name = "file1.proto" + ), + createProtoFile( + fileHeader = ProtoVersion.EDITION2024.header, + content = """ + import option "file1.proto"; + + message A { + option (custom_option) = "some value"; + B b = 1; + } + """.trimIndent(), + name = "file2.proto" + ), + wellKnownTypesFolder + ) + ) + ) + ) + } + } + + @Test + fun `test GIVEN file with custom option WHEN importing it as an option import and not using the declared message THEN no exception is thrown`() { + runGenerator( + listOf( + FakeInputDirectory( + name = "dir", + files = listOf( + createProtoFile( + fileHeader = ProtoVersion.PROTO2.header, + content = """ + import "google/protobuf/descriptor.proto"; + extend google.protobuf.MessageOptions { + optional string custom_option = 51234; + } + + message B {} + """.trimIndent(), + name = "file1.proto" + ), + createProtoFile( + fileHeader = ProtoVersion.EDITION2024.header, + content = """ + import option "file1.proto"; + + message A { + option (custom_option) = "some value"; + } + """.trimIndent(), + name = "file2.proto" + ), + wellKnownTypesFolder + ) + ) + ) + ) + } +} diff --git a/kmp-grpc-plugin/src/test/java/io/github/timortel/kotlin_multiplatform_grpc_plugin/validation/options/OptionHolderValidationTests.kt b/kmp-grpc-plugin/src/test/java/io/github/timortel/kotlin_multiplatform_grpc_plugin/validation/options/OptionHolderValidationTests.kt index cc76daaa..de5d63b1 100644 --- a/kmp-grpc-plugin/src/test/java/io/github/timortel/kotlin_multiplatform_grpc_plugin/validation/options/OptionHolderValidationTests.kt +++ b/kmp-grpc-plugin/src/test/java/io/github/timortel/kotlin_multiplatform_grpc_plugin/validation/options/OptionHolderValidationTests.kt @@ -87,7 +87,7 @@ class OptionHolderValidationTests : BaseValidationTest() { runGenerator( """ message TestMessage { - string field1 = 1 [foo="bar"]; + ${protoVersion.fieldPrefix} string field1 = 1 [foo="bar"]; } """.trimIndent(), protoVersion diff --git a/readme.md b/readme.md index 70b0d23f..b13b8b84 100644 --- a/readme.md +++ b/readme.md @@ -32,7 +32,7 @@ This projects implements client-side gRPC for Android, JVM, Native (including iO ### Supported protobuf versions | | Support status | |---------------|----------------| -| Proto2 | ⏳ Planned | +| Proto2 | ✅ Supported | | Proto3 | ✅ Supported | | Editions 2023 | ✅ Supported | | Editions 2024 | ✅ Supported | @@ -61,14 +61,15 @@ Please note that not all features may be available even if the protobuf version ### Supported proto options and features: ### Legacy options -| Proto Option | Proto3 | Edition 2023 | -|------------------------|--------|--------------| -| `java_package` | ✅ | ✅ | -| `java_outer_classname` | ✅ | ✅ | -| `java_multiple_files` | ✅ | ✅ | -| `deprecated` | ✅ | ✅ | -| `packed` | ✅ | ✅ | -| `optimize_for` | ❌ | ❌ | +| Proto Option | Proto2 | Proto3 | Edition 2023 | +|------------------------|--------|--------|--------------| +| `java_package` | ✅ | ✅ | ✅ | +| `java_outer_classname` | ✅ | ✅ | ✅ | +| `java_multiple_files` | ✅ | ✅ | ✅ | +| `deprecated` | ✅ | ✅ | ✅ | +| `packed` | ✅ | ✅ | ✅ | +| `default` enum-option | ✅ | ✅ | ✅ | +| `optimize_for` | ❌ | ❌ | ❌ | ### Features | Feature | Edition 2023 | Edition 2024 | @@ -95,18 +96,19 @@ Please note that not all features may be available even if the protobuf version ### Well-known types: For reference, see [the official documentation](https://protobuf.dev/reference/protobuf/google.protobuf/). Well-known types support must be enabled in your gradle config (see [Setup](#setup)). -| Protobuf Type | Supported | -|----------------------|---------------| -| `any.proto` | ✅ Supported | -| `api.proto` | ✅ Supported | -| `duration.proto` | ✅ Supported | -| `empty.proto` | ✅ Supported | -| `field_mask.proto` | ✅ Supported | +| Protobuf Type | Supported | +|------------------------|---------------| +| `any.proto` | ✅ Supported | +| `api.proto` | ✅ Supported | +| `descriptor.proto` | ✅ Supported | +| `duration.proto` | ✅ Supported | +| `empty.proto` | ✅ Supported | +| `field_mask.proto` | ✅ Supported | | `source_context.proto` | ✅ Supported | -| `struct.proto` | ✅ Supported | -| `timestamp.proto` | ✅ Supported | -| `type.proto` | ✅ Supported | -| `wrappers.proto` | ✅ Supported | +| `struct.proto` | ✅ Supported | +| `timestamp.proto` | ✅ Supported | +| `type.proto` | ✅ Supported | +| `wrappers.proto` | ✅ Supported | ### Additional Features - ✅ Generates DSL syntax to create messages @@ -456,7 +458,7 @@ The plugin generates kotlin code for all provided proto files. No `protoc` is ne by gRPC for JVM and by [tonic](https://github.com/hyperium/tonic) for all native targets. For JavaScript, the requests are handled by [ktor](https://github.com/ktorio/ktor). ## License -Copyright 2025 Tim Ortel +Copyright 2026 Tim Ortel Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at