diff --git a/kmp-grpc-internal-test/src/commonMain/proto/proto2/proto2-required-fields.proto b/kmp-grpc-internal-test/src/commonMain/proto/proto2/proto2-required-fields.proto index 02099df0..abe72bbe 100644 --- a/kmp-grpc-internal-test/src/commonMain/proto/proto2/proto2-required-fields.proto +++ b/kmp-grpc-internal-test/src/commonMain/proto/proto2/proto2-required-fields.proto @@ -20,3 +20,15 @@ message Proto2MessageWithRequiredFields { string field6 = 6; } } + +message Proto2MessageWithRequiredExtension { + extensions 1 to max; +} + +extend Proto2MessageWithRequiredExtension { + optional string extension1 = 1; + + optional Proto2MessageWithMixedFields extensionRequiredMsg = 2; + + repeated Proto2MessageWithMixedFields extensionRepeatedMsg = 3; +} diff --git a/kmp-grpc-internal-test/src/commonTest/kotlin/io/github/timortel/kotlin_multiplatform_grpc_plugin/test/model/IsInitializedTest.kt b/kmp-grpc-internal-test/src/commonTest/kotlin/io/github/timortel/kotlin_multiplatform_grpc_plugin/test/model/IsInitializedTest.kt index 7975686c..f220a28a 100644 --- a/kmp-grpc-internal-test/src/commonTest/kotlin/io/github/timortel/kotlin_multiplatform_grpc_plugin/test/model/IsInitializedTest.kt +++ b/kmp-grpc-internal-test/src/commonTest/kotlin/io/github/timortel/kotlin_multiplatform_grpc_plugin/test/model/IsInitializedTest.kt @@ -1,6 +1,9 @@ package io.github.timortel.kotlin_multiplatform_grpc_plugin.test.model +import io.github.timortel.kmpgrpc.core.message.extensions.buildExtensions import io.github.timortel.kmpgrpc.test.proto2.Proto2RequiredFields +import io.github.timortel.kmpgrpc.test.proto2.Proto2RequiredFields.Proto2MessageWithMixedFields +import io.github.timortel.kmpgrpc.test.proto2.Proto2RequiredFields.Proto2MessageWithRequiredExtension import io.github.timortel.kmpgrpc.test.proto2.Proto2RequiredFields.Proto2MessageWithRequiredFields import kotlin.test.Test import kotlin.test.assertFalse @@ -25,7 +28,7 @@ class IsInitializedTest { @Test fun testUninitializedNestedMessage() { // field2 is set, but the nested message itself is missing its own required field1 - val incompleteNested = Proto2RequiredFields.Proto2MessageWithMixedFields.createPartial(field1 = null) + val incompleteNested = Proto2MessageWithMixedFields.createPartial(field1 = null) val msg = Proto2MessageWithRequiredFields.createPartial( field1 = "valid", field2 = incompleteNested @@ -41,7 +44,10 @@ class IsInitializedTest { field3List = listOf(incomplete) ) - assertFalse(msg.isInitialized, "Message should be uninitialized if any element in a repeated field is uninitialized") + assertFalse( + msg.isInitialized, + "Message should be uninitialized if any element in a repeated field is uninitialized" + ) } @Test @@ -57,7 +63,7 @@ class IsInitializedTest { @Test fun testOneOfInitialization() { // x.field5 is a message type. If that message is incomplete, the parent is incomplete. - val incompleteMixed = Proto2RequiredFields.Proto2MessageWithMixedFields.createPartial(field1 = null) + val incompleteMixed = Proto2MessageWithMixedFields.createPartial(field1 = null) val msg = Proto2MessageWithRequiredFields( x = Proto2MessageWithRequiredFields.X.Field5(incompleteMixed) ) @@ -70,4 +76,61 @@ class IsInitializedTest { ) assertTrue(msg2.isInitialized, "Message should be initialized if OneOf contains a valid string") } + + @Test + fun testRequiredMessageExtensionInitialization() { + // 1. Missing both required extensions + val emptyMsg = Proto2MessageWithRequiredExtension.createPartial() + assertFalse(emptyMsg.isInitialized, "Should be uninitialized: missing extension1 and extensionRequiredMsg") + + // 2. extension1 is present, but extensionRequiredMsg is missing + val partialExt1 = buildExtensions { + set(Proto2RequiredFields.extension1, "valid") + } + val msgOnlyExt1 = Proto2MessageWithRequiredExtension.createPartial(extensions = partialExt1) + assertFalse(msgOnlyExt1.isInitialized, "Should be uninitialized: missing required message extension") + + // 3. Both present, but the required message extension is itself uninitialized + val incompleteNested = Proto2MessageWithMixedFields.createPartial(field1 = null) + val partialExt2 = buildExtensions { + set(Proto2RequiredFields.extension1, "valid") + set(Proto2RequiredFields.extensionRequiredMsg, incompleteNested) + } + val msgIncompleteMsg = Proto2MessageWithRequiredExtension.createPartial(extensions = partialExt2) + assertFalse( + msgIncompleteMsg.isInitialized, + "Should be uninitialized: required message extension is missing field1" + ) + + // 4. Fully initialized + val completeExt = buildExtensions { + set(Proto2RequiredFields.extension1, "valid") + set(Proto2RequiredFields.extensionRequiredMsg, Proto2MessageWithMixedFields(field1 = "valid")) + } + val validMsg = Proto2MessageWithRequiredExtension(extensions = completeExt) + assertTrue( + validMsg.isInitialized, + "Should be initialized: all required extensions and their fields are present" + ) + } + + @Test + fun testRepeatedMessageExtensionInitialization() { + val validNested = Proto2MessageWithMixedFields(field1 = "ok") + val incompleteNested = Proto2MessageWithMixedFields.createPartial(field1 = null) + + // Base valid extensions so the parent's 'required' constraints are met + val baseExtensions = buildExtensions { + set(Proto2RequiredFields.extension1, "valid") + set(Proto2RequiredFields.extensionRequiredMsg, validNested) + set(Proto2RequiredFields.extensionRepeatedMsgList, listOf(validNested, incompleteNested)) + } + + val msg = Proto2MessageWithRequiredExtension(extensions = baseExtensions) + + assertFalse( + msg.isInitialized, + "Should be uninitialized: one element in the repeated message extension is uninitialized" + ) + } } diff --git a/kmp-grpc-plugin/src/main/java/io/github/timortel/kmpgrpc/plugin/sourcegeneration/generators/protofile/message/extensions/IsInitializedFieldExtension.kt b/kmp-grpc-plugin/src/main/java/io/github/timortel/kmpgrpc/plugin/sourcegeneration/generators/protofile/message/extensions/IsInitializedFieldExtension.kt index 251cfe28..d213bf1b 100644 --- a/kmp-grpc-plugin/src/main/java/io/github/timortel/kmpgrpc/plugin/sourcegeneration/generators/protofile/message/extensions/IsInitializedFieldExtension.kt +++ b/kmp-grpc-plugin/src/main/java/io/github/timortel/kmpgrpc/plugin/sourcegeneration/generators/protofile/message/extensions/IsInitializedFieldExtension.kt @@ -2,6 +2,7 @@ package io.github.timortel.kmpgrpc.plugin.sourcegeneration.generators.protofile. import com.squareup.kotlinpoet.CodeBlock import com.squareup.kotlinpoet.KModifier +import com.squareup.kotlinpoet.MemberName import com.squareup.kotlinpoet.TypeSpec import io.github.timortel.kmpgrpc.plugin.sourcegeneration.SourceTarget import io.github.timortel.kmpgrpc.plugin.sourcegeneration.constants.Const @@ -34,7 +35,13 @@ object IsInitializedFieldExtension : MessageWriterExtension { val subMessages = subMessageFields + subMessageMapFields + oneOfs - if (requiredFields.isEmpty() && subMessages.isEmpty()) { + val consideredExtensionFields = message.extensionsInProject.flatMap { extensions -> + extensions.fields.filter { field -> + field.cardinality.isLegacyRequired || field.type.isMessage + } + } + + if (requiredFields.isEmpty() && subMessages.isEmpty() && consideredExtensionFields.isEmpty()) { add("true") } else { val separator = " && " @@ -51,6 +58,7 @@ object IsInitializedFieldExtension : MessageWriterExtension { Const.Message.isInitializedProperty.name ) } + ProtoFieldCardinality.Repeated -> { add( "%N.all { it.%N }", @@ -77,7 +85,48 @@ object IsInitializedFieldExtension : MessageWriterExtension { ) } - val impl = listOf(requiredFieldsBool, subMessageFieldsBool, subMessageOneOfFieldsBool, subMessageMapFieldsBool).joinCodeBlocks(separator) + val requiredExtensionFieldsBool = + consideredExtensionFields.joinToCodeBlock(separator) { field -> + val extensionMember = MemberName(field.file.className, field.codeName) + + when (field.cardinality) { + is ProtoFieldCardinality.Singular -> { + if (field.type.isMessage) { + add( + "%N[%M]?.%N == true", + Const.Message.Constructor.MessageExtensions.name, + extensionMember, + Const.Message.isInitializedProperty.name + ) + } else { + add( + "%N[%M] != null", + Const.Message.Constructor.MessageExtensions.name, + extensionMember + ) + } + } + + ProtoFieldCardinality.Repeated -> { + if (field.type.isMessage) { + add( + "%N[%M].all { it.%N }", + Const.Message.Constructor.MessageExtensions.name, + extensionMember, + Const.Message.isInitializedProperty.name + ) + } + } + } + } + + val impl = listOf( + requiredFieldsBool, + subMessageFieldsBool, + subMessageOneOfFieldsBool, + subMessageMapFieldsBool, + requiredExtensionFieldsBool + ).joinCodeBlocks(separator) add(impl) } diff --git a/readme.md b/readme.md index 94e5e0ae..fef6a1ee 100644 --- a/readme.md +++ b/readme.md @@ -367,7 +367,7 @@ You can construct a message of type `MyMessage` like this: val msg = Sample.MyMessage( regularField = "val1", extensions = buildExtensions { - set[Sample.myExtension] = "val2" + set(Sample.myExtension, "val2") } ) ```