Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -20,3 +20,15 @@ message Proto2MessageWithRequiredFields {
string field6 = 6;
}
}

message Proto2MessageWithRequiredExtension {
extensions 1 to max;
}

extend Proto2MessageWithRequiredExtension {
optional string extension1 = 1;

optional Proto2MessageWithMixedFields extensionRequiredMsg = 2;

repeated Proto2MessageWithMixedFields extensionRepeatedMsg = 3;
}
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
package io.github.timortel.kotlin_multiplatform_grpc_plugin.test.model

import io.github.timortel.kmpgrpc.core.message.extensions.buildExtensions
import io.github.timortel.kmpgrpc.test.proto2.Proto2RequiredFields
import io.github.timortel.kmpgrpc.test.proto2.Proto2RequiredFields.Proto2MessageWithMixedFields
import io.github.timortel.kmpgrpc.test.proto2.Proto2RequiredFields.Proto2MessageWithRequiredExtension
import io.github.timortel.kmpgrpc.test.proto2.Proto2RequiredFields.Proto2MessageWithRequiredFields
import kotlin.test.Test
import kotlin.test.assertFalse
Expand All @@ -25,7 +28,7 @@ class IsInitializedTest {
@Test
fun testUninitializedNestedMessage() {
// field2 is set, but the nested message itself is missing its own required field1
val incompleteNested = Proto2RequiredFields.Proto2MessageWithMixedFields.createPartial(field1 = null)
val incompleteNested = Proto2MessageWithMixedFields.createPartial(field1 = null)
val msg = Proto2MessageWithRequiredFields.createPartial(
field1 = "valid",
field2 = incompleteNested
Expand All @@ -41,7 +44,10 @@ class IsInitializedTest {
field3List = listOf(incomplete)
)

assertFalse(msg.isInitialized, "Message should be uninitialized if any element in a repeated field is uninitialized")
assertFalse(
msg.isInitialized,
"Message should be uninitialized if any element in a repeated field is uninitialized"
)
}

@Test
Expand All @@ -57,7 +63,7 @@ class IsInitializedTest {
@Test
fun testOneOfInitialization() {
// x.field5 is a message type. If that message is incomplete, the parent is incomplete.
val incompleteMixed = Proto2RequiredFields.Proto2MessageWithMixedFields.createPartial(field1 = null)
val incompleteMixed = Proto2MessageWithMixedFields.createPartial(field1 = null)
val msg = Proto2MessageWithRequiredFields(
x = Proto2MessageWithRequiredFields.X.Field5(incompleteMixed)
)
Expand All @@ -70,4 +76,61 @@ class IsInitializedTest {
)
assertTrue(msg2.isInitialized, "Message should be initialized if OneOf contains a valid string")
}

@Test
fun testRequiredMessageExtensionInitialization() {
// 1. Missing both required extensions
val emptyMsg = Proto2MessageWithRequiredExtension.createPartial()
assertFalse(emptyMsg.isInitialized, "Should be uninitialized: missing extension1 and extensionRequiredMsg")

// 2. extension1 is present, but extensionRequiredMsg is missing
val partialExt1 = buildExtensions {
set(Proto2RequiredFields.extension1, "valid")
}
val msgOnlyExt1 = Proto2MessageWithRequiredExtension.createPartial(extensions = partialExt1)
assertFalse(msgOnlyExt1.isInitialized, "Should be uninitialized: missing required message extension")

// 3. Both present, but the required message extension is itself uninitialized
val incompleteNested = Proto2MessageWithMixedFields.createPartial(field1 = null)
val partialExt2 = buildExtensions {
set(Proto2RequiredFields.extension1, "valid")
set(Proto2RequiredFields.extensionRequiredMsg, incompleteNested)
}
val msgIncompleteMsg = Proto2MessageWithRequiredExtension.createPartial(extensions = partialExt2)
assertFalse(
msgIncompleteMsg.isInitialized,
"Should be uninitialized: required message extension is missing field1"
)

// 4. Fully initialized
val completeExt = buildExtensions {
set(Proto2RequiredFields.extension1, "valid")
set(Proto2RequiredFields.extensionRequiredMsg, Proto2MessageWithMixedFields(field1 = "valid"))
}
val validMsg = Proto2MessageWithRequiredExtension(extensions = completeExt)
assertTrue(
validMsg.isInitialized,
"Should be initialized: all required extensions and their fields are present"
)
}

@Test
fun testRepeatedMessageExtensionInitialization() {
val validNested = Proto2MessageWithMixedFields(field1 = "ok")
val incompleteNested = Proto2MessageWithMixedFields.createPartial(field1 = null)

// Base valid extensions so the parent's 'required' constraints are met
val baseExtensions = buildExtensions {
set(Proto2RequiredFields.extension1, "valid")
set(Proto2RequiredFields.extensionRequiredMsg, validNested)
set(Proto2RequiredFields.extensionRepeatedMsgList, listOf(validNested, incompleteNested))
}

val msg = Proto2MessageWithRequiredExtension(extensions = baseExtensions)

assertFalse(
msg.isInitialized,
"Should be uninitialized: one element in the repeated message extension is uninitialized"
)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package io.github.timortel.kmpgrpc.plugin.sourcegeneration.generators.protofile.

import com.squareup.kotlinpoet.CodeBlock
import com.squareup.kotlinpoet.KModifier
import com.squareup.kotlinpoet.MemberName
import com.squareup.kotlinpoet.TypeSpec
import io.github.timortel.kmpgrpc.plugin.sourcegeneration.SourceTarget
import io.github.timortel.kmpgrpc.plugin.sourcegeneration.constants.Const
Expand Down Expand Up @@ -34,7 +35,13 @@ object IsInitializedFieldExtension : MessageWriterExtension {

val subMessages = subMessageFields + subMessageMapFields + oneOfs

if (requiredFields.isEmpty() && subMessages.isEmpty()) {
val consideredExtensionFields = message.extensionsInProject.flatMap { extensions ->
extensions.fields.filter { field ->
field.cardinality.isLegacyRequired || field.type.isMessage
}
}

if (requiredFields.isEmpty() && subMessages.isEmpty() && consideredExtensionFields.isEmpty()) {
add("true")
} else {
val separator = " && "
Expand All @@ -51,6 +58,7 @@ object IsInitializedFieldExtension : MessageWriterExtension {
Const.Message.isInitializedProperty.name
)
}

ProtoFieldCardinality.Repeated -> {
add(
"%N.all { it.%N }",
Expand All @@ -77,7 +85,48 @@ object IsInitializedFieldExtension : MessageWriterExtension {
)
}

val impl = listOf(requiredFieldsBool, subMessageFieldsBool, subMessageOneOfFieldsBool, subMessageMapFieldsBool).joinCodeBlocks(separator)
val requiredExtensionFieldsBool =
consideredExtensionFields.joinToCodeBlock(separator) { field ->
val extensionMember = MemberName(field.file.className, field.codeName)

when (field.cardinality) {
is ProtoFieldCardinality.Singular -> {
if (field.type.isMessage) {
add(
"%N[%M]?.%N == true",
Const.Message.Constructor.MessageExtensions.name,
extensionMember,
Const.Message.isInitializedProperty.name
)
} else {
add(
"%N[%M] != null",
Const.Message.Constructor.MessageExtensions.name,
extensionMember
)
}
}

ProtoFieldCardinality.Repeated -> {
if (field.type.isMessage) {
add(
"%N[%M].all { it.%N }",
Const.Message.Constructor.MessageExtensions.name,
extensionMember,
Const.Message.isInitializedProperty.name
)
}
}
}
}

val impl = listOf(
requiredFieldsBool,
subMessageFieldsBool,
subMessageOneOfFieldsBool,
subMessageMapFieldsBool,
requiredExtensionFieldsBool
).joinCodeBlocks(separator)

add(impl)
}
Expand Down
2 changes: 1 addition & 1 deletion readme.md
Original file line number Diff line number Diff line change
Expand Up @@ -367,7 +367,7 @@ You can construct a message of type `MyMessage` like this:
val msg = Sample.MyMessage(
regularField = "val1",
extensions = buildExtensions {
set[Sample.myExtension] = "val2"
set(Sample.myExtension, "val2")
}
)
```
Expand Down