From 03fae4f0de76604dd2b9a9f5e16573a2d395181c Mon Sep 17 00:00:00 2001 From: Ada Zhang Date: Tue, 26 May 2026 16:45:55 -0700 Subject: [PATCH] Support non-canonical extensions in message copy/freeze/encode operations. PiperOrigin-RevId: 921759833 --- upb/message/BUILD | 3 + upb/message/copy.c | 28 +++- upb/message/copy_test.cc | 51 +++++++ upb/message/internal/accessors.h | 13 ++ upb/message/internal/extension.c | 21 ++- upb/message/internal/extension.h | 8 + upb/message/internal/map_sorter.h | 2 +- upb/message/internal/message.h | 43 ++++-- upb/message/map_sorter.c | 12 +- upb/message/message.c | 61 ++++++-- upb/message/message.h | 56 ++++++- upb/message/promote.c | 166 ++++++++++++++------- upb/message/promote.h | 15 ++ upb/message/promote_test.cc | 207 +++++++++++++++++++++++++- upb/message/test.cc | 83 +++++++++++ upb/mini_table/internal/extension.h | 2 + upb/wire/BUILD | 4 + upb/wire/decode_test.cc | 176 ++++++++++++++++++++++ upb/wire/encode_test.cc | 222 ++++++++++++++++++++++++++++ upb/wire/internal/encoder.c | 49 ++++-- 20 files changed, 1115 insertions(+), 107 deletions(-) diff --git a/upb/message/BUILD b/upb/message/BUILD index 0cca4318a7602..c57e744f7c03d 100644 --- a/upb/message/BUILD +++ b/upb/message/BUILD @@ -186,6 +186,7 @@ cc_library( features = UPB_DEFAULT_FEATURES, visibility = ["//visibility:public"], deps = [ + ":copy", ":internal", ":message", "//upb/base", @@ -193,6 +194,7 @@ cc_library( "//upb/mini_table", "//upb/port", "//upb/wire", + "//upb/wire:encoder", "//upb/wire:eps_copy_input_stream", "//upb/wire:reader", ], @@ -381,6 +383,7 @@ cc_test( "//upb/mini_descriptor", "//upb/mini_descriptor:internal", "//upb/mini_table", + "//upb/port", "//upb/test:test_messages_proto2_upb_proto", "//upb/test:test_messages_proto3_upb_proto", "//upb/test:test_proto_upb_minitable", diff --git a/upb/message/copy.c b/upb/message/copy.c index 8c7f2397e9f95..c214f6125766c 100644 --- a/upb/message/copy.c +++ b/upb/message/copy.c @@ -245,12 +245,19 @@ upb_Message* 
_upb_Message_Copy(upb_Message* dst, const upb_Message* src, for (size_t i = 0; i < in->size; i++) { upb_TaggedAuxPtr tagged_ptr = in->aux_data[i]; - if (upb_TaggedAuxPtr_IsExtension(tagged_ptr)) { - // Clone extension - const upb_Extension* msg_ext = upb_TaggedAuxPtr_Extension(tagged_ptr); + if (upb_TaggedAuxPtr_IsExtension(tagged_ptr) || + upb_TaggedAuxPtr_IsNonCanonicalExtension(tagged_ptr)) { + // Clone a canonical or non-canonical extension + bool canonical = upb_TaggedAuxPtr_IsExtension(tagged_ptr); + const upb_Extension* msg_ext = + canonical ? upb_TaggedAuxPtr_Extension(tagged_ptr) + : upb_TaggedAuxPtr_NonCanonicalExtension(tagged_ptr); const upb_MiniTableField* field = &msg_ext->ext->UPB_PRIVATE(field); - upb_Extension* dst_ext = UPB_PRIVATE(_upb_Message_GetOrCreateExtension)( - dst, msg_ext->ext, arena); + upb_Extension* dst_ext = + canonical ? UPB_PRIVATE(_upb_Message_GetOrCreateExtension)( + dst, msg_ext->ext, arena) + : UPB_PRIVATE(_upb_Message_CreateNonCanonicalExtension)( + dst, msg_ext->ext, arena); if (!dst_ext) return NULL; if (upb_MiniTableField_IsScalar(field)) { if (!upb_Clone_ExtensionValue(msg_ext->ext, msg_ext, dst_ext, arena)) { @@ -268,7 +275,7 @@ upb_Message* _upb_Message_Copy(upb_Message* dst, const upb_Message* src, dst_ext->data.array_val = cloned_array; } } else if (upb_TaggedAuxPtr_IsUnknown(tagged_ptr)) { - // Clone unknown + // Clone a raw unknown field upb_StringView* unknown = upb_TaggedAuxPtr_UnknownData(tagged_ptr); // Make a copy into destination arena. 
if (!UPB_PRIVATE(_upb_Message_AddUnknown)( @@ -333,6 +340,15 @@ bool upb_Message_ShallowCopy(upb_Message* dst, const upb_Message* src, dst_in->aux_data[i] = upb_TaggedAuxPtr_MakeUnknownDataAliased(dst_sv); break; } + case kUpb_TaggedAuxType_NonCanonicalExtension: { + const upb_Extension* msg_ext = aux.extension; + upb_Extension* dst_ext = upb_Arena_Malloc(arena, sizeof(upb_Extension)); + if (!dst_ext) return false; + *dst_ext = *msg_ext; + dst_in->aux_data[i] = + upb_TaggedAuxPtr_MakeNonCanonicalExtension(dst_ext); + break; + } } } diff --git a/upb/message/copy_test.cc b/upb/message/copy_test.cc index 397d30e1ccc71..dff8b67e999e5 100644 --- a/upb/message/copy_test.cc +++ b/upb/message/copy_test.cc @@ -27,6 +27,8 @@ #include "upb/base/upcast.h" #include "upb/mem/arena.h" #include "upb/message/accessors.h" +#include "upb/message/internal/accessors.h" +#include "upb/message/internal/extension.h" #include "upb/message/internal/message.h" #include "upb/message/map.h" #include "upb/message/message.h" @@ -533,4 +535,53 @@ TEST(GeneratedCode, ShallowCloneMessage) { upb_Arena_Free(arena); } +TEST(GeneratedCode, DeepCloneMessageNonCanonicalExtensions) { + upb_Arena* source_arena = upb_Arena_New(); + upb_test_ModelWithExtensions* msg = + upb_test_ModelWithExtensions_new(source_arena); + upb_test_ModelExtension1* ext1 = upb_test_ModelExtension1_new(source_arena); + upb_test_ModelExtension1_set_str(ext1, + upb_StringView_FromString("LifecycleValue")); + + // Attach as non-canonical extension + UPB_PRIVATE(_upb_Message_SetNonCanonicalExtension)( + UPB_UPCAST(msg), upb_test_ModelExtension1_model_ext_ext, &ext1, + source_arena); + + // Deep clone msg to clone + upb_Arena* arena = upb_Arena_New(); + upb_test_ModelWithExtensions* clone = + (upb_test_ModelWithExtensions*)upb_Message_DeepClone( + UPB_UPCAST(msg), &upb_0test__ModelWithExtensions_msg_init, arena); + ASSERT_NE(clone, nullptr); + + // Mutate original + upb_test_ModelExtension1_set_str(ext1, 
upb_StringView_FromString("Mutated")); + upb_Arena_Free(source_arena); + + // Check if clone has the non-canonical extension and it's unmodified + upb_MessageUnknown data; + uintptr_t iter = kUpb_Message_UnknownBegin; + bool has_non_canonical = false; + const upb_Extension* ext_found = nullptr; + while (upb_Message_NextUnknown2(UPB_UPCAST(clone), &data, &iter)) { + if (data.type == kUpb_MessageUnknownType_NonCanonicalExtension) { + has_non_canonical = true; + ext_found = (const upb_Extension*)data.value.extension; + } + } + EXPECT_TRUE(has_non_canonical); + ASSERT_NE(ext_found, nullptr); + + const upb_test_ModelExtension1* cloned_ext = + (const upb_test_ModelExtension1*)ext_found->data.msg_val; + EXPECT_TRUE( + upb_StringView_IsEqual(upb_test_ModelExtension1_str(cloned_ext), + upb_StringView_FromString("LifecycleValue"))); + + upb_Arena_Free(arena); +} + } // namespace + +#include "upb/port/undef.inc" diff --git a/upb/message/internal/accessors.h b/upb/message/internal/accessors.h index 69cc639862e6a..7a90001fb9f2d 100644 --- a/upb/message/internal/accessors.h +++ b/upb/message/internal/accessors.h @@ -331,6 +331,19 @@ UPB_API_INLINE bool upb_Message_SetExtension(struct upb_Message* msg, return true; } +UPB_API_INLINE bool UPB_PRIVATE(_upb_Message_SetNonCanonicalExtension)( + struct upb_Message* msg, const upb_MiniTableExtension* e, const void* val, + upb_Arena* a) { + UPB_ASSERT(!upb_Message_IsFrozen(msg)); + UPB_ASSERT(a); + upb_Extension* ext = + UPB_PRIVATE(_upb_Message_CreateNonCanonicalExtension)(msg, e, a); + if (!ext) return false; + UPB_PRIVATE(_upb_MiniTableField_DataCopy) + (&e->UPB_PRIVATE(field), &ext->data, val); + return true; +} + // Sets the value of the given field in the given msg. The return value is true // if the operation completed successfully, or false if memory allocation // failed. 
diff --git a/upb/message/internal/extension.c b/upb/message/internal/extension.c index ca8a5865fd625..8b04234235050 100644 --- a/upb/message/internal/extension.c +++ b/upb/message/internal/extension.c @@ -37,8 +37,9 @@ const upb_Extension* UPB_PRIVATE(_upb_Message_Getext)( return NULL; } -upb_Extension* UPB_PRIVATE(_upb_Message_GetOrCreateExtension)( - struct upb_Message* msg, const upb_MiniTableExtension* e, upb_Arena* a) { +UPB_INLINE upb_Extension* _upb_Message_GetOrCreateExtensionInternal( + struct upb_Message* msg, const upb_MiniTableExtension* e, upb_Arena* a, + bool canonical) { UPB_ASSERT(!upb_Message_IsFrozen(msg)); upb_Extension* ext = (upb_Extension*)UPB_PRIVATE(_upb_Message_Getext)(msg, e); if (ext) return ext; @@ -49,6 +50,20 @@ upb_Extension* UPB_PRIVATE(_upb_Message_GetOrCreateExtension)( if (!ext) return NULL; memset(ext, 0, sizeof(upb_Extension)); ext->ext = e; - in->aux_data[in->size++] = upb_TaggedAuxPtr_MakeExtension(ext); + in->aux_data[in->size++] = + canonical ? upb_TaggedAuxPtr_MakeExtension(ext) + : upb_TaggedAuxPtr_MakeNonCanonicalExtension(ext); return ext; } + +upb_Extension* UPB_PRIVATE(_upb_Message_GetOrCreateExtension)( + struct upb_Message* msg, const upb_MiniTableExtension* e, upb_Arena* a) { + return _upb_Message_GetOrCreateExtensionInternal(msg, e, a, + /*canonical=*/true); +} + +upb_Extension* UPB_PRIVATE(_upb_Message_CreateNonCanonicalExtension)( + struct upb_Message* msg, const upb_MiniTableExtension* e, upb_Arena* a) { + return _upb_Message_GetOrCreateExtensionInternal(msg, e, a, + /*canonical=*/false); +} diff --git a/upb/message/internal/extension.h b/upb/message/internal/extension.h index d0dc11cf8cd90..0c551aad233ba 100644 --- a/upb/message/internal/extension.h +++ b/upb/message/internal/extension.h @@ -45,6 +45,14 @@ UPB_NODISCARD upb_Extension* UPB_PRIVATE(_upb_Message_GetOrCreateExtension)( struct upb_Message* msg, const upb_MiniTableExtension* ext, upb_Arena* arena); +// Adds the given non-canonical extension data to 
the given message. +// |ext| is copied into the message instance. +// This logically replaces any previously-added extension with this number. +UPB_NODISCARD upb_Extension* UPB_PRIVATE( + _upb_Message_CreateNonCanonicalExtension)(struct upb_Message* msg, + const upb_MiniTableExtension* ext, + upb_Arena* arena); + // Returns an extension for a message with a given mini table, // or NULL if no extension exists with this mini table. const upb_Extension* UPB_PRIVATE(_upb_Message_Getext)( diff --git a/upb/message/internal/map_sorter.h b/upb/message/internal/map_sorter.h index bfc8f5775b5c3..abf3e7df177d5 100644 --- a/upb/message/internal/map_sorter.h +++ b/upb/message/internal/map_sorter.h @@ -90,7 +90,7 @@ bool _upb_mapsorter_pushmap(_upb_mapsorter* s, upb_FieldType key_type, const struct upb_Map* map, _upb_sortedmap* sorted); bool _upb_mapsorter_pushexts(_upb_mapsorter* s, const upb_Message_Internal* in, - _upb_sortedmap* sorted); + _upb_sortedmap* sorted, bool include_noncanonical); #ifdef __cplusplus } /* extern "C" */ diff --git a/upb/message/internal/message.h b/upb/message/internal/message.h index 59050be0b594c..88f5ea8925513 100644 --- a/upb/message/internal/message.h +++ b/upb/message/internal/message.h @@ -46,7 +46,8 @@ typedef struct upb_TaggedAuxPtr { // Two lowest bits form a tag: // 00 - non-aliased unknown data // 10 - aliased unknown data - // 01 - extension + // 01 - canonical extension + // 11 - unknown as a non-canonical extension // // The main semantic difference between aliased and non-aliased unknown data // is that non-aliased unknown data can be assumed to have the following @@ -64,11 +65,19 @@ typedef struct upb_TaggedAuxPtr { // For aliased unknown data, this layout is _not_ guaranteed, since the // pointer to the StringView can be anywhere in the allocation, and the // StringView may point to non-data memory. 
+ // + // For a non-canonical extension, its schema is known but not + // the one expected by the message, so it should be treated like an unknown + // field, but is stored as an extension for efficiency. uintptr_t ptr; } upb_TaggedAuxPtr; UPB_INLINE bool upb_TaggedAuxPtr_IsExtension(upb_TaggedAuxPtr ptr) { - return ptr.ptr & 1; + return (ptr.ptr & 3) == 1; +} + +UPB_INLINE bool upb_TaggedAuxPtr_IsNonCanonicalExtension(upb_TaggedAuxPtr ptr) { + return (ptr.ptr & 3) == 3; } UPB_INLINE bool upb_TaggedAuxPtr_IsUnknown(upb_TaggedAuxPtr ptr) { @@ -84,6 +93,12 @@ UPB_INLINE upb_Extension* upb_TaggedAuxPtr_Extension(upb_TaggedAuxPtr ptr) { return (upb_Extension*)(ptr.ptr & ~3ULL); } +UPB_INLINE upb_Extension* upb_TaggedAuxPtr_NonCanonicalExtension( + upb_TaggedAuxPtr ptr) { + UPB_ASSERT(upb_TaggedAuxPtr_IsNonCanonicalExtension(ptr)); + return (upb_Extension*)(ptr.ptr & ~3ULL); +} + UPB_INLINE upb_StringView* upb_TaggedAuxPtr_UnknownData(upb_TaggedAuxPtr ptr) { UPB_ASSERT(!upb_TaggedAuxPtr_IsExtension(ptr)); return (upb_StringView*)(ptr.ptr & ~3ULL); @@ -91,6 +106,7 @@ UPB_INLINE upb_StringView* upb_TaggedAuxPtr_UnknownData(upb_TaggedAuxPtr ptr) { typedef enum { kUpb_TaggedAuxType_Extension, + kUpb_TaggedAuxType_NonCanonicalExtension, kUpb_TaggedAuxType_Unknown, kUpb_TaggedAuxType_AliasedUnknown } upb_TaggedAuxType; @@ -108,10 +124,13 @@ UPB_INLINE upb_TaggedAuxType upb_TaggedAux_Get(upb_TaggedAuxPtr ptr, } else if (upb_TaggedAuxPtr_IsUnknownAliased(ptr)) { data->unknown_data = *upb_TaggedAuxPtr_UnknownData(ptr); return kUpb_TaggedAuxType_AliasedUnknown; - } else { - UPB_ASSERT(upb_TaggedAuxPtr_IsUnknown(ptr)); + } else if (upb_TaggedAuxPtr_IsUnknown(ptr)) { data->unknown_data = *upb_TaggedAuxPtr_UnknownData(ptr); return kUpb_TaggedAuxType_Unknown; + } else { + UPB_ASSERT(upb_TaggedAuxPtr_IsNonCanonicalExtension(ptr)); + data->extension = upb_TaggedAuxPtr_NonCanonicalExtension(ptr); + return kUpb_TaggedAuxType_NonCanonicalExtension; } } @@ -128,6 +147,13 @@ 
upb_TaggedAuxPtr_MakeExtension(const upb_Extension* e) { return ptr; } +UPB_INLINE upb_TaggedAuxPtr +upb_TaggedAuxPtr_MakeNonCanonicalExtension(const upb_Extension* e) { + upb_TaggedAuxPtr ptr; + ptr.ptr = (uintptr_t)e | 3; + return ptr; +} + // This tag means that the original allocation for this field starts with the // string view and ends with the end of the content referenced by the string // view. @@ -234,7 +260,8 @@ UPB_NODISCARD UPB_INLINE struct upb_Message* _upb_Message_New( return msg; } -// Discards the unknown fields for this message only. +// Discards the unknown (including non-canonical extensions) for this message +// only. void _upb_Message_DiscardUnknown_shallow(struct upb_Message* msg); UPB_NODISCARD UPB_NOINLINE bool UPB_PRIVATE(_upb_Message_AddUnknownSlowPath)( @@ -342,12 +369,6 @@ UPB_INLINE bool upb_Message_NextUnknown(const struct upb_Message* msg, return false; } -UPB_INLINE bool upb_Message_HasUnknown(const struct upb_Message* msg) { - upb_StringView data; - uintptr_t iter = kUpb_Message_UnknownBegin; - return upb_Message_NextUnknown(msg, &data, &iter); -} - UPB_INLINE bool upb_Message_NextExtension(const struct upb_Message* msg, const upb_MiniTableExtension** out_e, upb_MessageValue* out_v, diff --git a/upb/message/map_sorter.c b/upb/message/map_sorter.c index 968b583af651e..fc5d691407ab5 100644 --- a/upb/message/map_sorter.c +++ b/upb/message/map_sorter.c @@ -164,10 +164,15 @@ static int _upb_mapsorter_cmpext(const void* _a, const void* _b) { } bool _upb_mapsorter_pushexts(_upb_mapsorter* s, const upb_Message_Internal* in, - _upb_sortedmap* sorted) { + _upb_sortedmap* sorted, + bool include_noncanonical) { size_t count = 0; for (size_t i = 0; i < in->size; i++) { - count += upb_TaggedAuxPtr_IsExtension(in->aux_data[i]); + bool is_any_extension = + upb_TaggedAuxPtr_IsExtension(in->aux_data[i]) || + (include_noncanonical && + upb_TaggedAuxPtr_IsNonCanonicalExtension(in->aux_data[i])); + count += is_any_extension; } if 
(!_upb_mapsorter_resize(s, sorted, count)) return false; if (count == 0) return true; @@ -177,6 +182,9 @@ bool _upb_mapsorter_pushexts(_upb_mapsorter* s, const upb_Message_Internal* in, upb_TaggedAuxPtr tagged_ptr = in->aux_data[i]; if (upb_TaggedAuxPtr_IsExtension(tagged_ptr)) { *entry++ = upb_TaggedAuxPtr_Extension(tagged_ptr); + } else if (include_noncanonical && + upb_TaggedAuxPtr_IsNonCanonicalExtension(tagged_ptr)) { + *entry++ = upb_TaggedAuxPtr_NonCanonicalExtension(tagged_ptr); } } qsort(&s->entries[sorted->start], count, sizeof(*s->entries), diff --git a/upb/message/message.c b/upb/message/message.c index e013c04fc8b49..5f1190f3ec801 100644 --- a/upb/message/message.c +++ b/upb/message/message.c @@ -164,6 +164,7 @@ void _upb_Message_DiscardUnknown_shallow(upb_Message* msg) { uint32_t size = 0; for (uint32_t i = 0; i < in->size; i++) { upb_TaggedAuxPtr tagged_ptr = in->aux_data[i]; + // Only keep canonical extensions. if (upb_TaggedAuxPtr_IsExtension(tagged_ptr)) { in->aux_data[size++] = tagged_ptr; } @@ -175,33 +176,61 @@ upb_Message_DeleteUnknownStatus upb_Message_DeleteUnknown(upb_Message* msg, upb_StringView* data, uintptr_t* iter, upb_Arena* arena) { + upb_MessageUnknown unknown; + unknown.type = kUpb_MessageUnknownType_Bytes; + unknown.value.bytes = *data; + + upb_Message_DeleteUnknownStatus res = + upb_Message_DeleteUnknown2(msg, &unknown, iter, arena); + *data = unknown.value.bytes; + return res; +} + +upb_Message_DeleteUnknownStatus upb_Message_DeleteUnknown2( + upb_Message* msg, upb_MessageUnknown* data, uintptr_t* iter, + upb_Arena* arena) { UPB_ASSERT(!upb_Message_IsFrozen(msg)); UPB_ASSERT(*iter != kUpb_Message_UnknownBegin); upb_Message_Internal* in = UPB_PRIVATE(_upb_Message_GetInternal)(msg); UPB_ASSERT(in); UPB_ASSERT(*iter <= in->size); upb_TaggedAuxPtr unknown_ptr = in->aux_data[*iter - 1]; + + if (data->type == kUpb_MessageUnknownType_NonCanonicalExtension) { + UPB_ASSERT(upb_TaggedAuxPtr_IsNonCanonicalExtension(unknown_ptr)); + // When 
the unknown is a non-canonical extension, we just remove it from the + // aux data array. + in->aux_data[*iter - 1] = upb_TaggedAuxPtr_Null(); + return upb_Message_NextUnknown2(msg, data, iter) + ? kUpb_DeleteUnknown_IterUpdated + : kUpb_DeleteUnknown_DeletedLast; + } + UPB_ASSERT(upb_TaggedAuxPtr_IsUnknown(unknown_ptr)); upb_StringView* unknown = upb_TaggedAuxPtr_UnknownData(unknown_ptr); - if (unknown->data == data->data && unknown->size == data->size) { + UPB_ASSERT(data->type == kUpb_MessageUnknownType_Bytes); + upb_StringView* data_bytes = &data->value.bytes; + if (unknown->data == data_bytes->data && unknown->size == data_bytes->size) { // Remove whole field in->aux_data[*iter - 1] = upb_TaggedAuxPtr_Null(); - } else if (unknown->data == data->data) { + } else if (unknown->data == data_bytes->data) { // Strip prefix - unknown->data += data->size; - unknown->size -= data->size; - *data = *unknown; + unknown->data += data_bytes->size; + unknown->size -= data_bytes->size; + *data_bytes = *unknown; return kUpb_DeleteUnknown_IterUpdated; - } else if (unknown->data + unknown->size == data->data + data->size) { + } else if (unknown->data + unknown->size == + data_bytes->data + data_bytes->size) { // Truncate existing field - unknown->size -= data->size; + unknown->size -= data_bytes->size; if (!upb_TaggedAuxPtr_IsUnknownAliased(unknown_ptr)) { in->aux_data[*iter - 1] = upb_TaggedAuxPtr_MakeUnknownDataAliased(unknown); } } else { - UPB_ASSERT(unknown->data < data->data && - unknown->data + unknown->size > data->data + data->size); + UPB_ASSERT(unknown->data < data_bytes->data && + unknown->data + unknown->size > + data_bytes->data + data_bytes->size); // Split in the middle upb_StringView* prefix = unknown; upb_StringView* suffix = upb_Arena_Malloc(arena, sizeof(upb_StringView)); @@ -222,11 +251,11 @@ upb_Message_DeleteUnknownStatus upb_Message_DeleteUnknown(upb_Message* msg, in->aux_data[*iter - 1] = upb_TaggedAuxPtr_MakeUnknownDataAliased(prefix); } in->size++; - 
suffix->data = data->data + data->size; + suffix->data = data_bytes->data + data_bytes->size; suffix->size = (prefix->data + prefix->size) - suffix->data; - prefix->size = data->data - prefix->data; + prefix->size = data_bytes->data - prefix->data; } - return upb_Message_NextUnknown(msg, data, iter) + return upb_Message_NextUnknown2(msg, data, iter) ? kUpb_DeleteUnknown_IterUpdated : kUpb_DeleteUnknown_DeletedLast; } @@ -286,10 +315,14 @@ void upb_Message_Freeze(upb_Message* msg, const upb_MiniTable* m) { uint32_t size = in ? in->size : 0; for (size_t i = 0; i < size; i++) { upb_TaggedAuxPtr tagged_ptr = in->aux_data[i]; - if (!upb_TaggedAuxPtr_IsExtension(tagged_ptr)) { + const upb_Extension* ext = NULL; + if (upb_TaggedAuxPtr_IsExtension(tagged_ptr)) { + ext = upb_TaggedAuxPtr_Extension(tagged_ptr); + } else if (upb_TaggedAuxPtr_IsNonCanonicalExtension(tagged_ptr)) { + ext = upb_TaggedAuxPtr_NonCanonicalExtension(tagged_ptr); + } else { continue; } - const upb_Extension* ext = upb_TaggedAuxPtr_Extension(tagged_ptr); const upb_MiniTableExtension* e = ext->ext; const upb_MiniTableField* f = &e->UPB_PRIVATE(field); const upb_MiniTable* m2 = upb_MiniTableExtension_GetSubMessage(e); diff --git a/upb/message/message.h b/upb/message/message.h index bcceebd46dfc7..3c0f08cc3f5d4 100644 --- a/upb/message/message.h +++ b/upb/message/message.h @@ -18,7 +18,6 @@ #include "upb/base/string_view.h" #include "upb/mem/arena.h" #include "upb/message/array.h" -#include "upb/message/internal/extension.h" #include "upb/message/internal/message.h" #include "upb/message/internal/types.h" #include "upb/mini_table/extension.h" @@ -51,10 +50,57 @@ UPB_NODISCARD UPB_API upb_Message* upb_Message_New(const upb_MiniTable* m, #define kUpb_Message_UnknownBegin 0 #define kUpb_Message_ExtensionBegin 0 +typedef enum { + kUpb_MessageUnknownType_Bytes, + kUpb_MessageUnknownType_NonCanonicalExtension, +} upb_MessageUnknownType; + +typedef struct upb_MessageUnknown { + uint8_t type; + union { + 
upb_StringView bytes; + const void* extension; + } value; +} upb_MessageUnknown; + +// TODO - Rename to upb_Message_NextUnknownField, and make this +// API private. +// Iterates over unknown fields. UPB_INLINE bool upb_Message_NextUnknown(const upb_Message* msg, upb_StringView* data, uintptr_t* iter); -UPB_INLINE bool upb_Message_HasUnknown(const upb_Message* msg); +// Support iteration over unknown, including non-canonical extensions. +UPB_INLINE bool upb_Message_NextUnknown2(const upb_Message* msg, + upb_MessageUnknown* data, + uintptr_t* iter) { + const upb_Message_Internal* in = UPB_PRIVATE(_upb_Message_GetInternal)(msg); + size_t i = *iter; + if (in) { + while (i < in->size) { + upb_TaggedAuxPtr tagged_ptr = in->aux_data[i++]; + if (upb_TaggedAuxPtr_IsUnknown(tagged_ptr)) { + data->type = kUpb_MessageUnknownType_Bytes; + data->value.bytes = *upb_TaggedAuxPtr_UnknownData(tagged_ptr); + *iter = i; + return true; + } else if (upb_TaggedAuxPtr_IsNonCanonicalExtension(tagged_ptr)) { + data->type = kUpb_MessageUnknownType_NonCanonicalExtension; + data->value.extension = + upb_TaggedAuxPtr_NonCanonicalExtension(tagged_ptr); + *iter = i; + return true; + } + } + } + *iter = i; + return false; +} + +UPB_INLINE bool upb_Message_HasUnknown(const upb_Message* msg) { + upb_MessageUnknown data; + uintptr_t iter = kUpb_Message_UnknownBegin; + return upb_Message_NextUnknown2(msg, &data, &iter); +} // Removes a segment of unknown data from the message, advancing to the next // segment. Returns false if the removed segment was at the end of the last @@ -87,9 +133,15 @@ typedef enum upb_Message_DeleteUnknownStatus { kUpb_DeleteUnknown_IterUpdated, kUpb_DeleteUnknown_AllocFail, } upb_Message_DeleteUnknownStatus; + UPB_NODISCARD upb_Message_DeleteUnknownStatus upb_Message_DeleteUnknown( upb_Message* msg, upb_StringView* data, uintptr_t* iter, upb_Arena* arena); +// Support deletion of unknown, including non-canonical extensions. 
+UPB_NODISCARD upb_Message_DeleteUnknownStatus +upb_Message_DeleteUnknown2(upb_Message* msg, upb_MessageUnknown* data, + uintptr_t* iter, upb_Arena* arena); + // Returns the number of extensions present in this message. size_t upb_Message_ExtensionCount(const upb_Message* msg); diff --git a/upb/message/promote.c b/upb/message/promote.c index 4e1706e4be801..dc0b8704b22db 100644 --- a/upb/message/promote.c +++ b/upb/message/promote.c @@ -16,7 +16,6 @@ #include "upb/mem/arena.h" #include "upb/message/accessors.h" #include "upb/message/array.h" -#include "upb/message/internal/array.h" #include "upb/message/internal/extension.h" #include "upb/message/internal/message.h" #include "upb/message/map.h" @@ -25,6 +24,7 @@ #include "upb/mini_table/field.h" #include "upb/mini_table/message.h" #include "upb/wire/decode.h" +#include "upb/wire/encode.h" #include "upb/wire/eps_copy_input_stream.h" #include "upb/wire/reader.h" @@ -83,42 +83,74 @@ upb_GetExtension_Status upb_Message_GetOrPromoteExtension( upb_Message* extension_msg = NULL; int depth_limit = 100; uintptr_t iter = kUpb_Message_UnknownBegin; - upb_StringView data; - while (upb_Message_NextUnknown(msg, &data, &iter)) { - const char* ptr = data.data; - upb_EpsCopyInputStream stream; - upb_EpsCopyInputStream_Init(&stream, &ptr, data.size); - while (!upb_EpsCopyInputStream_IsDone(&stream, &ptr)) { - uint32_t tag; - const char* unknown_begin = ptr; - ptr = upb_WireReader_ReadTag(ptr, &tag, &stream); - if (!ptr) return kUpb_GetExtension_ParseError; - if (field_number == upb_WireReader_GetFieldNumber(tag)) { - upb_StringView data; + upb_MessageUnknown data; + while (upb_Message_NextUnknown2(msg, &data, &iter)) { + if (data.type == kUpb_MessageUnknownType_NonCanonicalExtension) { + const upb_Extension* ext = (const upb_Extension*)data.value.extension; + if (upb_MiniTableExtension_Number(ext->ext) == field_number) { + // Encode and then decode to promote. 
+ char* buf; + size_t size; + upb_EncodeStatus enc_status = upb_Encode( + ext->data.msg_val, upb_MiniTableExtension_GetSubMessage(ext->ext), + 0, arena, &buf, &size); + if (enc_status != kUpb_EncodeStatus_Ok) { + return enc_status == kUpb_EncodeStatus_OutOfMemory + ? kUpb_GetExtension_OutOfMemory + : kUpb_GetExtension_ParseError; + } found_count++; - upb_EpsCopyCapture capture; - upb_EpsCopyCapture_Start(&capture, &stream, unknown_begin); - ptr = _upb_WireReader_SkipValue(ptr, tag, depth_limit, &stream); - if (!ptr || !upb_EpsCopyCapture_End(&capture, &stream, ptr, &data)) { + if (!extension_msg) { + extension_msg = _upb_Message_New(extension_table, arena); + if (!extension_msg) return kUpb_GetExtension_OutOfMemory; + } + upb_DecodeStatus status = + upb_Decode(buf, size, extension_msg, extension_table, NULL, + decode_options, arena); + if (status == kUpb_DecodeStatus_OutOfMemory) { + return kUpb_GetExtension_OutOfMemory; + } else if (status != kUpb_DecodeStatus_Ok) { return kUpb_GetExtension_ParseError; } - upb_UnknownToMessageRet parse_result = - upb_MiniTable_ParseUnknownMessage( - data.data, data.size, extension_table, - /* base_message= */ extension_msg, decode_options, arena); - switch (parse_result.status) { - case kUpb_UnknownToMessage_OutOfMemory: - return kUpb_GetExtension_OutOfMemory; - case kUpb_UnknownToMessage_ParseError: + } + } else { + UPB_ASSERT(data.type == kUpb_MessageUnknownType_Bytes); + upb_StringView unknown_bytes = data.value.bytes; + const char* ptr = unknown_bytes.data; + upb_EpsCopyInputStream stream; + upb_EpsCopyInputStream_Init(&stream, &ptr, unknown_bytes.size); + while (!upb_EpsCopyInputStream_IsDone(&stream, &ptr)) { + uint32_t tag; + const char* unknown_begin = ptr; + ptr = upb_WireReader_ReadTag(ptr, &tag, &stream); + if (!ptr) return kUpb_GetExtension_ParseError; + if (field_number == upb_WireReader_GetFieldNumber(tag)) { + upb_StringView data; + found_count++; + upb_EpsCopyCapture capture; + upb_EpsCopyCapture_Start(&capture, 
&stream, unknown_begin); + ptr = _upb_WireReader_SkipValue(ptr, tag, depth_limit, &stream); + if (!ptr || !upb_EpsCopyCapture_End(&capture, &stream, ptr, &data)) { return kUpb_GetExtension_ParseError; - case kUpb_UnknownToMessage_NotFound: - return kUpb_GetExtension_NotPresent; - case kUpb_UnknownToMessage_Ok: - extension_msg = parse_result.message; + } + upb_UnknownToMessageRet parse_result = + upb_MiniTable_ParseUnknownMessage( + data.data, data.size, extension_table, + /* base_message= */ extension_msg, decode_options, arena); + switch (parse_result.status) { + case kUpb_UnknownToMessage_OutOfMemory: + return kUpb_GetExtension_OutOfMemory; + case kUpb_UnknownToMessage_ParseError: + return kUpb_GetExtension_ParseError; + case kUpb_UnknownToMessage_NotFound: + return kUpb_GetExtension_NotPresent; + case kUpb_UnknownToMessage_Ok: + extension_msg = parse_result.message; + } + } else { + ptr = _upb_WireReader_SkipValue(ptr, tag, depth_limit, &stream); + if (!ptr) return kUpb_GetExtension_ParseError; } - } else { - ptr = _upb_WireReader_SkipValue(ptr, tag, depth_limit, &stream); - if (!ptr) return kUpb_GetExtension_ParseError; } } } @@ -134,10 +166,9 @@ upb_GetExtension_Status upb_Message_GetOrPromoteExtension( ext->data.msg_val = extension_msg; while (found_count > 0) { - upb_FindUnknownRet found = upb_Message_FindUnknown(msg, field_number, 0); + upb_FindUnknownRet2 found = upb_Message_FindUnknown2(msg, field_number, 0); UPB_ASSERT(found.status == kUpb_FindUnknown_Ok); - upb_StringView view = {.data = found.ptr, .size = found.len}; - if (upb_Message_DeleteUnknown(msg, &view, &found.iter, arena) == + if (upb_Message_DeleteUnknown2(msg, &found.unknown, &found.iter, arena) == kUpb_DeleteUnknown_AllocFail) { return kUpb_GetExtension_OutOfMemory; } @@ -147,48 +178,77 @@ upb_GetExtension_Status upb_Message_GetOrPromoteExtension( return kUpb_GetExtension_Ok; } -static upb_FindUnknownRet upb_FindUnknownRet_ParseError(void) { - return (upb_FindUnknownRet){.status = 
kUpb_FindUnknown_ParseError}; -} - upb_FindUnknownRet upb_Message_FindUnknown(const upb_Message* msg, uint32_t field_number, int depth_limit) { - depth_limit = depth_limit ? depth_limit : 100; + upb_FindUnknownRet2 ret2 = + upb_Message_FindUnknown2(msg, field_number, depth_limit); + if (ret2.status != kUpb_FindUnknown_Ok) { + return (upb_FindUnknownRet){.status = ret2.status}; + } + UPB_ASSERT(ret2.unknown.type == kUpb_MessageUnknownType_Bytes); upb_FindUnknownRet ret; + ret.status = ret2.status; + ret.ptr = ret2.unknown.value.bytes.data; + ret.len = ret2.unknown.value.bytes.size; + ret.iter = ret2.iter; + return ret; +} + +static upb_FindUnknownRet2 upb_FindUnknownRet2_ParseError(void) { + return (upb_FindUnknownRet2){.status = kUpb_FindUnknown_ParseError}; +} + +upb_FindUnknownRet2 upb_Message_FindUnknown2(const upb_Message* msg, + uint32_t field_number, + int depth_limit) { + depth_limit = depth_limit ? depth_limit : 100; + upb_FindUnknownRet2 ret; ret.iter = kUpb_Message_UnknownBegin; - upb_StringView data; - while (upb_Message_NextUnknown(msg, &data, &ret.iter)) { + upb_MessageUnknown data; + while (upb_Message_NextUnknown2(msg, &data, &ret.iter)) { + if (data.type == kUpb_MessageUnknownType_NonCanonicalExtension) { + const upb_Extension* ext = (const upb_Extension*)data.value.extension; + if (upb_MiniTableExtension_Number(ext->ext) == field_number) { + ret.status = kUpb_FindUnknown_Ok; + ret.unknown = data; + return ret; + } + continue; + } + + UPB_ASSERT(data.type == kUpb_MessageUnknownType_Bytes); + upb_StringView bytes = data.value.bytes; upb_EpsCopyInputStream stream; - const char* ptr = data.data; - upb_EpsCopyInputStream_Init(&stream, &ptr, data.size); + const char* ptr = bytes.data; + upb_EpsCopyInputStream_Init(&stream, &ptr, bytes.size); while (!upb_EpsCopyInputStream_IsDone(&stream, &ptr)) { uint32_t tag; const char* unknown_begin = ptr; ptr = upb_WireReader_ReadTag(ptr, &tag, &stream); - if (!ptr) return upb_FindUnknownRet_ParseError(); + if (!ptr) 
return upb_FindUnknownRet2_ParseError(); if (field_number == upb_WireReader_GetFieldNumber(tag)) { - upb_StringView data; + upb_StringView cap_data; ret.status = kUpb_FindUnknown_Ok; upb_EpsCopyCapture capture; upb_EpsCopyCapture_Start(&capture, &stream, unknown_begin); ptr = _upb_WireReader_SkipValue(ptr, tag, depth_limit, &stream); - if (!ptr || !upb_EpsCopyCapture_End(&capture, &stream, ptr, &data)) { - return upb_FindUnknownRet_ParseError(); + if (!ptr || + !upb_EpsCopyCapture_End(&capture, &stream, ptr, &cap_data)) { + return upb_FindUnknownRet2_ParseError(); } - ret.ptr = data.data; - ret.len = data.size; + ret.unknown.type = kUpb_MessageUnknownType_Bytes; + ret.unknown.value.bytes = cap_data; return ret; } ptr = _upb_WireReader_SkipValue(ptr, tag, depth_limit, &stream); - if (!ptr) return upb_FindUnknownRet_ParseError(); + if (!ptr) return upb_FindUnknownRet2_ParseError(); } } ret.status = kUpb_FindUnknown_NotPresent; - ret.ptr = NULL; - ret.len = 0; + memset(&ret.unknown, 0, sizeof(ret.unknown)); ret.iter = kUpb_Message_UnknownBegin; return ret; } diff --git a/upb/message/promote.h b/upb/message/promote.h index b2c5073de6dcc..c2eef0a4e815e 100644 --- a/upb/message/promote.h +++ b/upb/message/promote.h @@ -62,12 +62,27 @@ typedef struct { uintptr_t iter; } upb_FindUnknownRet; +// TODO: Migrate callers to use upb_Message_FindUnknown2. +// // Finds first occurrence of unknown data by tag id in message. // A depth_limit of zero means to just use the upb default depth limit. upb_FindUnknownRet upb_Message_FindUnknown(const upb_Message* msg, uint32_t field_number, int depth_limit); +typedef struct { + upb_FindUnknown_Status status; + upb_MessageUnknown unknown; + uintptr_t iter; +} upb_FindUnknownRet2; + +// Finds first occurrence of unknown data (including non-canonical extensions) +// by tag id in message. A depth_limit of zero means to just use the upb default +// depth limit. 
+upb_FindUnknownRet2 upb_Message_FindUnknown2(const upb_Message* msg, + uint32_t field_number, + int depth_limit); + typedef enum { kUpb_UnknownToMessage_Ok, kUpb_UnknownToMessage_ParseError, diff --git a/upb/message/promote_test.cc b/upb/message/promote_test.cc index 21906c12382e2..39874b55c371f 100644 --- a/upb/message/promote_test.cc +++ b/upb/message/promote_test.cc @@ -29,7 +29,7 @@ #include "upb/mem/arena.hpp" #include "upb/message/accessors.h" #include "upb/message/array.h" -#include "upb/message/copy.h" +#include "upb/message/internal/accessors.h" #include "upb/message/internal/message.h" #include "upb/message/map.h" #include "upb/message/message.h" @@ -44,7 +44,9 @@ #include "upb/test/test.upb.h" #include "upb/test/test.upb_minitable.h" #include "upb/wire/decode.h" -#include "upb/wire/encode.h" + +// Must be last. +#include "upb/port/def.inc" namespace { @@ -505,4 +507,205 @@ TEST(GeneratedCode, PromoteUnknownToMapOld) { upb_Arena_Free(arena); } +TEST(GeneratedCode, PromoteNonCanonicalExtension) { + upb::Arena arena; + + // 1. Build custom different mini-table + upb::MtDataEncoder e; + e.StartMessage(0); + e.PutField(kUpb_FieldType_String, 25, 0); + + upb_Status status; + upb_Status_Clear(&status); + upb_MiniTable* custom_sub_table = upb_MiniTable_Build( + e.data().data(), e.data().size(), arena.ptr(), &status); + ASSERT_TRUE(status.ok); + + upb_MiniTableExtension custom_ext = *upb_test_ModelExtension1_model_ext_ext; + upb_MiniTableExtension_SetSubMessage(&custom_ext, custom_sub_table); + + // 2. Create base message msg to hold our non-canonical extension + upb_test_ModelWithExtensions* msg = + upb_test_ModelWithExtensions_new(arena.ptr()); + + // 3. 
Create submsg parsed under custom_sub_table ("World") + upb_Message* extension1 = _upb_Message_New(custom_sub_table, arena.ptr()); + upb_MessageValue val_str; + val_str.str_val = upb_StringView_FromString("World"); + const upb_MiniTableField* custom_f = + upb_MiniTable_GetFieldByIndex(custom_sub_table, 0); + upb_Message_SetString(extension1, custom_f, val_str.str_val, arena.ptr()); + + // 4. Attach custom parsed submessage "World" to msg as a non-canonical + // extension under the different custom mini-table layout. + UPB_PRIVATE(_upb_Message_SetNonCanonicalExtension)( + UPB_UPCAST(msg), &custom_ext, &extension1, arena.ptr()); + + // 5. Promote the extension using standard compiled mini-table ModelExtension1 + upb_MessageValue val; + upb_GetExtension_Status promote_status = upb_Message_GetOrPromoteExtension( + UPB_UPCAST(msg), upb_test_ModelExtension1_model_ext_ext, + kUpb_DecodeOption_AliasString, arena.ptr(), &val); + + ASSERT_EQ(kUpb_GetExtension_Ok, promote_status); + + // 6. Verify that the engine correctly converted the shape and promoted the + // value! + upb_test_ModelExtension1* ext_msg = (upb_test_ModelExtension1*)val.msg_val; + upb_StringView field = upb_test_ModelExtension1_str(ext_msg); + EXPECT_EQ(absl::string_view(field.data, field.size), "World"); + + EXPECT_EQ(1, upb_Message_ExtensionCount(UPB_UPCAST(msg))); + + // 7. Verify that upb_Message_NextExtension works and iterates over the + // promoted extension! 
+ uintptr_t ext_iter = kUpb_Message_ExtensionBegin; + const upb_MiniTableExtension* ext_out = nullptr; + upb_MessageValue val_out; + EXPECT_TRUE(upb_Message_NextExtension(UPB_UPCAST(msg), &ext_out, &val_out, + &ext_iter)); + EXPECT_EQ( + upb_MiniTableExtension_Number(ext_out), + upb_MiniTableExtension_Number(upb_test_ModelExtension1_model_ext_ext)); + + upb_test_ModelExtension1* ext_msg_iter = + (upb_test_ModelExtension1*)val_out.msg_val; + upb_StringView field_iter = upb_test_ModelExtension1_str(ext_msg_iter); + EXPECT_EQ(absl::string_view(field_iter.data, field_iter.size), "World"); + + EXPECT_FALSE(upb_Message_NextExtension(UPB_UPCAST(msg), &ext_out, &val_out, + &ext_iter)); + + // 8. Verify that the promoted non-canonical extension is indeed no longer + // present in unknowns + upb_FindUnknownRet found = upb_Message_FindUnknown(UPB_UPCAST(msg), 1547, 0); + EXPECT_EQ(kUpb_FindUnknown_NotPresent, found.status); +} + +TEST(GeneratedCode, PromoteNonCanonicalExtensionWithSameMinitable) { + upb::Arena arena; + upb_test_ModelWithExtensions* msg = + upb_test_ModelWithExtensions_new(arena.ptr()); + upb_test_ModelExtension1* extension1 = + upb_test_ModelExtension1_new(arena.ptr()); + upb_test_ModelExtension1_set_str(extension1, + upb_StringView_FromString("World")); + + UPB_PRIVATE(_upb_Message_SetNonCanonicalExtension)( + UPB_UPCAST(msg), upb_test_ModelExtension1_model_ext_ext, + (upb_Message**)&extension1, arena.ptr()); + + upb_MessageValue val; + upb_GetExtension_Status promote_status = upb_Message_GetOrPromoteExtension( + UPB_UPCAST(msg), upb_test_ModelExtension1_model_ext_ext, + kUpb_DecodeOption_AliasString, arena.ptr(), &val); + + EXPECT_EQ(kUpb_GetExtension_Ok, promote_status); + upb_test_ModelExtension1* ext_msg = (upb_test_ModelExtension1*)val.msg_val; + upb_StringView field = upb_test_ModelExtension1_str(ext_msg); + EXPECT_EQ(absl::string_view(field.data, field.size), "World"); + EXPECT_EQ(1, upb_Message_ExtensionCount(UPB_UPCAST(msg))); + uintptr_t ext_iter = 
kUpb_Message_ExtensionBegin; + const upb_MiniTableExtension* ext_out = nullptr; + upb_MessageValue val_out; + EXPECT_TRUE(upb_Message_NextExtension(UPB_UPCAST(msg), &ext_out, &val_out, + &ext_iter)); + EXPECT_FALSE(upb_Message_NextExtension(UPB_UPCAST(msg), &ext_out, &val_out, + &ext_iter)); + upb_FindUnknownRet found = upb_Message_FindUnknown(UPB_UPCAST(msg), 1547, 0); + EXPECT_EQ(kUpb_FindUnknown_NotPresent, found.status); +} + +TEST(GeneratedCode, PromoteNonCanonicalExtensionWithDifferentMinitable) { + upb::Arena arena; + + // 1. Build custom different mini-table for the non-canonical extension ("ext" + // layout) It has an int32 field at tag 1. + upb_Status status; + upb_Status_Clear(&status); + upb::MtDataEncoder e_ext; + e_ext.StartMessage(0); + e_ext.PutField(kUpb_FieldType_Int32, 1, 0); + + upb_MiniTable* custom_sub_table_ext = upb_MiniTable_Build( + e_ext.data().data(), e_ext.data().size(), arena.ptr(), &status); + ASSERT_TRUE(status.ok); + + // 2. Build target mini-table for the base field ("base" layout, matching + // field tag 1 int32) + upb::MtDataEncoder e_base; + e_base.StartMessage(0); + e_base.PutField(kUpb_FieldType_Int32, 1, 0); + + upb_MiniTable* custom_sub_table_base = upb_MiniTable_Build( + e_base.data().data(), e_base.data().size(), arena.ptr(), &status); + ASSERT_TRUE(status.ok); + + // 3. Create target extension descriptor pointing to custom_sub_table_base + upb_MiniTableExtension target_ext = *upb_test_ModelExtension1_model_ext_ext; + upb_MiniTableExtension_SetSubMessage(&target_ext, custom_sub_table_base); + + // 4. Create a custom extension descriptor matching field number 1547 and + // pointing to custom_sub_table_ext + upb_MiniTableExtension custom_ext = *upb_test_ModelExtension1_model_ext_ext; + upb_MiniTableExtension_SetSubMessage(&custom_ext, custom_sub_table_ext); + + // 5. Create base msg + upb_test_ModelWithExtensions* msg = + upb_test_ModelWithExtensions_new(arena.ptr()); + + // 6. 
Populate the submsg parsed under custom_sub_table_ext with value 42 at + // tag 1 + upb_Message* extension1 = _upb_Message_New(custom_sub_table_ext, arena.ptr()); + const upb_MiniTableField* custom_f = + upb_MiniTable_GetFieldByIndex(custom_sub_table_ext, 0); + upb_Message_SetInt32(extension1, custom_f, 42, arena.ptr()); + + // 7. Attach it as a non-canonical extension to msg using field 1547 + UPB_PRIVATE(_upb_Message_SetNonCanonicalExtension)( + UPB_UPCAST(msg), &custom_ext, &extension1, arena.ptr()); + + // 8. Run extension promotion targeting the target_ext layout + upb_MessageValue val; + upb_GetExtension_Status promote_status = upb_Message_GetOrPromoteExtension( + UPB_UPCAST(msg), &target_ext, kUpb_DecodeOption_AliasString, arena.ptr(), + &val); + + EXPECT_EQ(promote_status, kUpb_GetExtension_Ok); + + // 9. Retrieve and verify that it successfully converted and promoted the + // actual value + upb_Message* promoted_message = (upb_Message*)val.msg_val; + ASSERT_NE(promoted_message, nullptr); + + const upb_MiniTableField* base_f = + upb_MiniTable_GetFieldByIndex(custom_sub_table_base, 0); + int32_t promoted_value = upb_Message_GetInt32(promoted_message, base_f, 0); + EXPECT_EQ(promoted_value, 42); + + EXPECT_EQ(1, upb_Message_ExtensionCount(UPB_UPCAST(msg))); + + // 10. Verify that upb_Message_NextExtension successfully works and returns + // true! + uintptr_t ext_iter = kUpb_Message_ExtensionBegin; + const upb_MiniTableExtension* ext_out = nullptr; + upb_MessageValue val_out; + EXPECT_TRUE(upb_Message_NextExtension(UPB_UPCAST(msg), &ext_out, &val_out, + &ext_iter)); + EXPECT_EQ(upb_MiniTableExtension_Number(ext_out), 1547); + + upb_Message* ext_msg_iter = (upb_Message*)val_out.msg_val; + int32_t promoted_value_iter = upb_Message_GetInt32(ext_msg_iter, base_f, 0); + EXPECT_EQ(promoted_value_iter, 42); + + EXPECT_FALSE(upb_Message_NextExtension(UPB_UPCAST(msg), &ext_out, &val_out, + &ext_iter)); + + // 11. 
Verify that the promoted non-canonical extension is indeed no longer + // present in unknowns + upb_FindUnknownRet found = upb_Message_FindUnknown(UPB_UPCAST(msg), 1547, 0); + EXPECT_EQ(kUpb_FindUnknown_NotPresent, found.status); +} } // namespace + +#include "upb/port/undef.inc" diff --git a/upb/message/test.cc b/upb/message/test.cc index 5bb5abf3d40c7..324cc523411fe 100644 --- a/upb/message/test.cc +++ b/upb/message/test.cc @@ -34,6 +34,8 @@ #include "upb/message/accessors.h" #include "upb/message/array.h" #include "upb/message/compare.h" +#include "upb/message/internal/accessors.h" +#include "upb/message/internal/message.h" #include "upb/message/map.h" #include "upb/message/message.h" #include "upb/message/test.upb.h" @@ -879,6 +881,87 @@ TEST(MessageTest, Freeze) { } } +TEST(MessageTest, FreezeNonCanonicalExtensions) { + upb::Arena arena; + upb_test_TestExtensions* msg = upb_test_TestExtensions_new(arena.ptr()); + + // Create sub-message + protobuf_test_messages_proto3_TestAllTypesProto3* ext_submsg = + protobuf_test_messages_proto3_TestAllTypesProto3_new(arena.ptr()); + protobuf_test_messages_proto3_TestAllTypesProto3_set_optional_int32( + ext_submsg, 456); + + // Attach as non-canonical extension + UPB_PRIVATE(_upb_Message_SetNonCanonicalExtension)( + UPB_UPCAST(msg), upb_test_optional_msg_ext_ext, &ext_submsg, arena.ptr()); + + EXPECT_FALSE(upb_Message_IsFrozen(UPB_UPCAST(msg))); + EXPECT_FALSE(upb_Message_IsFrozen(UPB_UPCAST(ext_submsg))); + + // Freeze the parent message + upb_Message_Freeze(UPB_UPCAST(msg), &upb_0test__TestExtensions_msg_init); + + EXPECT_TRUE(upb_Message_IsFrozen(UPB_UPCAST(msg))); + // The non-canonical extension sub-message must be recursively frozen too! 
+ EXPECT_TRUE(upb_Message_IsFrozen(UPB_UPCAST(ext_submsg))); +} + +TEST(MessageTest, DiscardUnknownsNonCanonicalExtensions) { + upb::Arena arena; + upb_test_TestExtensions* msg = upb_test_TestExtensions_new(arena.ptr()); + + // Create sub-message + protobuf_test_messages_proto3_TestAllTypesProto3* ext_submsg = + protobuf_test_messages_proto3_TestAllTypesProto3_new(arena.ptr()); + + // Attach as non-canonical extension + UPB_PRIVATE(_upb_Message_SetNonCanonicalExtension)( + UPB_UPCAST(msg), upb_test_optional_msg_ext_ext, &ext_submsg, arena.ptr()); + + // Add some standard raw unknown bytes + char raw_unknown[] = "\x08\x96\x01"; // tag 1 = 150 + UPB_PRIVATE(_upb_Message_AddUnknown)(UPB_UPCAST(msg), raw_unknown, + sizeof(raw_unknown) - 1, arena.ptr(), + kUpb_AddUnknown_Copy); + + // Verify both are present initially + { + upb_MessageUnknown data; + uintptr_t iter = kUpb_Message_UnknownBegin; + bool has_non_canonical = false; + bool has_bytes = false; + while (upb_Message_NextUnknown2(UPB_UPCAST(msg), &data, &iter)) { + if (data.type == kUpb_MessageUnknownType_NonCanonicalExtension) { + has_non_canonical = true; + } else if (data.type == kUpb_MessageUnknownType_Bytes) { + has_bytes = true; + } + } + EXPECT_TRUE(has_non_canonical); + EXPECT_TRUE(has_bytes); + } + + // Discard unknown fields on the message + _upb_Message_DiscardUnknown_shallow(UPB_UPCAST(msg)); + + // Verify both non-canonical extension and raw unknown bytes are discarded! 
+ { + upb_MessageUnknown data; + uintptr_t iter = kUpb_Message_UnknownBegin; + bool has_non_canonical = false; + bool has_bytes = false; + while (upb_Message_NextUnknown2(UPB_UPCAST(msg), &data, &iter)) { + if (data.type == kUpb_MessageUnknownType_NonCanonicalExtension) { + has_non_canonical = true; + } else if (data.type == kUpb_MessageUnknownType_Bytes) { + has_bytes = true; + } + } + EXPECT_FALSE(has_non_canonical); + EXPECT_FALSE(has_bytes); + } +} + /* Tests some somewhat tricky math used in size calculations while encoding */ TEST(MessageTest, SkippedVarintSize) { for (uint32_t clz = 0; clz <= 64; clz++) { diff --git a/upb/mini_table/internal/extension.h b/upb/mini_table/internal/extension.h index c0cb8c484e58a..1f6bc39709d0d 100644 --- a/upb/mini_table/internal/extension.h +++ b/upb/mini_table/internal/extension.h @@ -23,6 +23,8 @@ struct upb_MiniTableExtension { struct upb_MiniTableField UPB_PRIVATE(field); union upb_MiniTableSub UPB_PRIVATE(sub); // NULL unless submsg or proto2 enum + // A known extendee schema for a canonical extension. Otherwise, the + // `extendee` info should be ignored/NULL for a non-canonical one. 
const struct upb_MiniTable* UPB_PRIVATE(extendee); }; diff --git a/upb/wire/BUILD b/upb/wire/BUILD index b52de7aa6a5df..48f50ae5ca4d1 100644 --- a/upb/wire/BUILD +++ b/upb/wire/BUILD @@ -132,10 +132,14 @@ cc_test( "//upb/base", "//upb/mem", "//upb/message", + "//upb/message:internal", "//upb/message:message_cc", "//upb/mini_descriptor", + "//upb/mini_descriptor:internal", "//upb/mini_table", "//upb/port", + "//upb/test:test_proto_upb_minitable", + "//upb/test:test_upb_proto", "//upb/wire/decode_fast:combinations", "//upb/wire/test_util:field_types", "//upb/wire/test_util:make_mini_table", diff --git a/upb/wire/decode_test.cc b/upb/wire/decode_test.cc index fae6fe7e62e2e..9d784fe6d6b47 100644 --- a/upb/wire/decode_test.cc +++ b/upb/wire/decode_test.cc @@ -21,17 +21,28 @@ #include "absl/strings/ascii.h" #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" +#include "upb/base/descriptor_constants.h" +#include "upb/base/status.h" #include "upb/base/string_view.h" +#include "upb/base/upcast.h" #include "upb/mem/arena.h" #include "upb/mem/arena.hpp" #include "upb/message/accessors.h" #include "upb/message/accessors.hpp" #include "upb/message/array.h" +#include "upb/message/internal/accessors.h" +#include "upb/message/internal/message.h" #include "upb/message/message.h" +#include "upb/mini_descriptor/decode.h" +#include "upb/mini_descriptor/internal/encode.hpp" #include "upb/mini_descriptor/link.h" +#include "upb/mini_table/extension.h" #include "upb/mini_table/field.h" #include "upb/mini_table/message.h" +#include "upb/test/test.upb.h" +#include "upb/test/test.upb_minitable.h" #include "upb/wire/decode_fast/combinations.h" +#include "upb/wire/encode.h" #include "upb/wire/test_util/field_types.h" #include "upb/wire/test_util/make_mini_table.h" #include "upb/wire/test_util/wire_message.h" @@ -471,6 +482,171 @@ TEST(DecodeTest, MaxDepthPayloadParsesSuccessfully) { } } +TEST(DecodeTest, DecodeNonCanonicalExtensionAsUnknown) { + upb::Arena arena; + + // 1. 
Build custom different mini-table for the submessage layout ("extension + // A"). + upb::MtDataEncoder e; + e.StartMessage(0); + e.PutField(kUpb_FieldType_String, 25, 0); + + upb_Status status; + upb_Status_Clear(&status); + upb_MiniTable* custom_sub_table = upb_MiniTable_Build( + e.data().data(), e.data().size(), arena.ptr(), &status); + ASSERT_TRUE(status.ok); + + upb_MiniTableExtension custom_ext = *upb_test_ModelExtension1_model_ext_ext; + upb_MiniTableExtension_SetSubMessage(&custom_ext, custom_sub_table); + + // 2. Create base msg which starts empty + upb_test_ModelWithExtensions* msg = + upb_test_ModelWithExtensions_new(arena.ptr()); + + // 3. Create parsed submessage ("World") under custom_sub_table + upb_Message* extension1 = _upb_Message_New(custom_sub_table, arena.ptr()); + upb_MessageValue val_str; + val_str.str_val = upb_StringView_FromString("World"); + const upb_MiniTableField* custom_f = + upb_MiniTable_GetFieldByIndex(custom_sub_table, 0); + upb_Message_SetString(extension1, custom_f, val_str.str_val, arena.ptr()); + + // 4. msg has a non-canonical extension A + UPB_PRIVATE(_upb_Message_SetNonCanonicalExtension)( + UPB_UPCAST(msg), &custom_ext, &extension1, arena.ptr()); + + // Verify extension count is 0 before encoding/decoding. + EXPECT_EQ((int)upb_Message_ExtensionCount(UPB_UPCAST(msg)), 0); + + // 5. Obtain encoded non-canonical extension A by serializing msg + char* buf; + size_t size; + upb_EncodeStatus enc_status = + upb_Encode(UPB_UPCAST(msg), &upb_0test__ModelWithExtensions_msg_init, 0, + arena.ptr(), &buf, &size); + ASSERT_EQ(enc_status, kUpb_EncodeStatus_Ok); + ASSERT_GT(size, 0u); + + // 6. Decode with extreg = nullptr (so the encoded extension A is decoded as + // unknown bytes) + upb_DecodeStatus dec_status = upb_Decode( + buf, size, UPB_UPCAST(msg), &upb_0test__ModelWithExtensions_msg_init, + /*extreg=*/nullptr, 0, arena.ptr()); + ASSERT_EQ(dec_status, kUpb_DecodeStatus_Ok); + + // 7. 
Verify that we end up with exactly one non-canonical extension A + one + // unknown bytes block representing A + int non_canonical_count = 0; + int unknown_bytes_count = 0; + uintptr_t iter = kUpb_Message_UnknownBegin; + upb_MessageUnknown data; + while (upb_Message_NextUnknown2(UPB_UPCAST(msg), &data, &iter)) { + if (data.type == kUpb_MessageUnknownType_NonCanonicalExtension) { + non_canonical_count++; + } else if (data.type == kUpb_MessageUnknownType_Bytes) { + unknown_bytes_count++; + } + } + EXPECT_EQ(non_canonical_count, 1); + EXPECT_EQ(unknown_bytes_count, 1); + + // Verify extension APIs: there are zero canonical extensions. + EXPECT_EQ((int)upb_Message_ExtensionCount(UPB_UPCAST(msg)), 0); + uintptr_t ext_iter = kUpb_Message_ExtensionBegin; + const upb_MiniTableExtension* ext_out = nullptr; + upb_MessageValue val_out; + EXPECT_FALSE(upb_Message_NextExtension(UPB_UPCAST(msg), &ext_out, &val_out, + &ext_iter)); +} + +TEST(DecodeTest, DecodeExtensionAsUnknownWithPreexistingUnknown) { + upb::Arena arena; + + // 1. Build custom different mini-table for the submessage layout ("extension + // A"). + upb::MtDataEncoder e; + e.StartMessage(0); + e.PutField(kUpb_FieldType_String, 25, 0); + + upb_Status status; + upb_Status_Clear(&status); + upb_MiniTable* custom_sub_table = upb_MiniTable_Build( + e.data().data(), e.data().size(), arena.ptr(), &status); + ASSERT_TRUE(status.ok); + + upb_MiniTableExtension custom_ext = *upb_test_ModelExtension1_model_ext_ext; + upb_MiniTableExtension_SetSubMessage(&custom_ext, custom_sub_table); + + // 2. Create a temporary message to serialize the extension + upb_test_ModelWithExtensions* tmp_msg = + upb_test_ModelWithExtensions_new(arena.ptr()); + + // 3. 
Create parsed submessage ("World") under custom_sub_table + upb_Message* extension1 = _upb_Message_New(custom_sub_table, arena.ptr()); + upb_MessageValue val_str; + val_str.str_val = upb_StringView_FromString("World"); + const upb_MiniTableField* custom_f = + upb_MiniTable_GetFieldByIndex(custom_sub_table, 0); + upb_Message_SetString(extension1, custom_f, val_str.str_val, arena.ptr()); + + // 4. Attach to tmp_msg as a non-canonical extension so we can serialize it to + // get the bytes + UPB_PRIVATE(_upb_Message_SetNonCanonicalExtension)( + UPB_UPCAST(tmp_msg), &custom_ext, &extension1, arena.ptr()); + + // 5. Obtain encoded extension A by serializing tmp_msg + char* buf; + size_t size; + upb_EncodeStatus enc_status = + upb_Encode(UPB_UPCAST(tmp_msg), &upb_0test__ModelWithExtensions_msg_init, + 0, arena.ptr(), &buf, &size); + ASSERT_EQ(enc_status, kUpb_EncodeStatus_Ok); + ASSERT_GT(size, 0u); + + // 6. Create destination message and put the serialized bytes as an unknown + // field on msg + upb_test_ModelWithExtensions* msg = + upb_test_ModelWithExtensions_new(arena.ptr()); + bool add_ok = UPB_PRIVATE(_upb_Message_AddUnknown)( + UPB_UPCAST(msg), buf, size, arena.ptr(), kUpb_AddUnknown_Alias); + ASSERT_TRUE(add_ok); + + // Verify extension count is 0 before decoding. + EXPECT_EQ((int)upb_Message_ExtensionCount(UPB_UPCAST(msg)), 0); + + // 7. Decode with extreg = nullptr (so the encoded extension A is decoded as + // unknown bytes) + upb_DecodeStatus dec_status = upb_Decode( + buf, size, UPB_UPCAST(msg), &upb_0test__ModelWithExtensions_msg_init, + /*extreg=*/nullptr, 0, arena.ptr()); + ASSERT_EQ(dec_status, kUpb_DecodeStatus_Ok); + + // 8. 
Verify that we end up with exactly two unknown bytes blocks representing + // A + int non_canonical_count = 0; + int unknown_bytes_count = 0; + uintptr_t iter = kUpb_Message_UnknownBegin; + upb_MessageUnknown data; + while (upb_Message_NextUnknown2(UPB_UPCAST(msg), &data, &iter)) { + if (data.type == kUpb_MessageUnknownType_NonCanonicalExtension) { + non_canonical_count++; + } else if (data.type == kUpb_MessageUnknownType_Bytes) { + unknown_bytes_count++; + } + } + EXPECT_EQ(non_canonical_count, 0); + EXPECT_EQ(unknown_bytes_count, 2); + + // Verify extension APIs: there are zero canonical extensions. + EXPECT_EQ((int)upb_Message_ExtensionCount(UPB_UPCAST(msg)), 0); + uintptr_t ext_iter = kUpb_Message_ExtensionBegin; + const upb_MiniTableExtension* ext_out = nullptr; + upb_MessageValue val_out; + EXPECT_FALSE(upb_Message_NextExtension(UPB_UPCAST(msg), &ext_out, &val_out, + &ext_iter)); +} + TEST(DecodeTest, DecodeGroupFieldFromDelimitedWireFormatAsUnknown) { upb::Arena mt_arena; upb::Arena msg_arena; diff --git a/upb/wire/encode_test.cc b/upb/wire/encode_test.cc index 6da6d5d25abc9..c70ab08267164 100644 --- a/upb/wire/encode_test.cc +++ b/upb/wire/encode_test.cc @@ -6,6 +6,8 @@ #include #include "upb/mem/arena.h" #include "upb/message/array.h" +#include "upb/message/internal/accessors.h" +#include "upb/message/internal/extension.h" #include "upb/message/internal/map_sorter.h" #include "upb/message/message.h" #include "upb/mini_table/extension.h" @@ -203,6 +205,226 @@ TEST(EncodeTest, EncodeExtensionMaxDepthExceeded) { upb_Arena_Free(arena); } +TEST(EncodeTest, EncodeNonCanonicalExtensionSuccess) { + upb_Arena* arena = upb_Arena_New(); + + upb_wire_test_TestExtensions* msg = upb_wire_test_TestExtensions_new(arena); + + // Attach scalar extension as non-canonical + int32_t val = 42; + UPB_PRIVATE(_upb_Message_SetNonCanonicalExtension)( + (upb_Message*)msg, upb_wire_test_ext_i32_ext, &val, arena); + + // Encode the message. 
+ char* buf; + size_t size; + upb_EncodeStatus status = + upb_Encode((upb_Message*)msg, &upb_0wire_0test__TestExtensions_msg_init, + 0, arena, &buf, &size); + EXPECT_EQ(status, kUpb_EncodeStatus_Ok); + EXPECT_GT(size, 0u); + + // Verify that the encoded bytes can be decoded back using the registry! + upb_ExtensionRegistry* ext_reg = upb_ExtensionRegistry_New(arena); + const upb_MiniTableExtension* ext_array[1] = {upb_wire_test_ext_i32_ext}; + upb_ExtensionRegistry_AddArray(ext_reg, ext_array, 1); + + upb_wire_test_TestExtensions* decoded_msg = + upb_wire_test_TestExtensions_parse_ex(buf, size, ext_reg, 0, arena); + EXPECT_NE(decoded_msg, nullptr); + EXPECT_TRUE(upb_wire_test_has_ext_i32(decoded_msg)); + EXPECT_EQ(upb_wire_test_ext_i32(decoded_msg), 42); + + upb_Arena_Free(arena); +} + +TEST(EncodeTest, SkipUnknownNonCanonicalExtensionSuccess) { + upb_Arena* arena = upb_Arena_New(); + + upb_wire_test_TestExtensions* msg = upb_wire_test_TestExtensions_new(arena); + + // 1. Add a canonical extension (ext_i32, tag 100) to msg + upb_Extension* canonical_ext = UPB_PRIVATE(_upb_Message_GetOrCreateExtension)( + (upb_Message*)msg, upb_wire_test_ext_i32_ext, arena); + canonical_ext->data.int32_val = 1000; + + // 2. Attach a non-canonical extension (ext_recursive, tag 101) to msg + upb_wire_test_TestRecursive* sub_msg = upb_wire_test_TestRecursive_new(arena); + UPB_PRIVATE(_upb_Message_SetNonCanonicalExtension)( + (upb_Message*)msg, upb_wire_test_ext_recursive_ext, &sub_msg, arena); + + // 3. Also add some standard raw unknown bytes (field 1, varint value 150) + char raw_unknown[] = "\x08\x96\x01"; // tag 1 = 150 + UPB_PRIVATE(_upb_Message_AddUnknown)((upb_Message*)msg, raw_unknown, + sizeof(raw_unknown) - 1, arena, + kUpb_AddUnknown_Copy); + + // Encode the message WITH kUpb_EncodeOption_SkipUnknown option! 
+ char* buf; + size_t size; + upb_EncodeStatus status = + upb_Encode((upb_Message*)msg, &upb_0wire_0test__TestExtensions_msg_init, + kUpb_EncodeOption_SkipUnknown, arena, &buf, &size); + EXPECT_EQ(status, kUpb_EncodeStatus_Ok); + + // Parse back the serialized bytes. + // It MUST contain the canonical extension, + // but the non-canonical extension and standard raw unknown bytes MUST be + // successfully skipped. + upb_ExtensionRegistry* ext_reg = upb_ExtensionRegistry_New(arena); + const upb_MiniTableExtension* ext_array[2] = { + upb_wire_test_ext_i32_ext, upb_wire_test_ext_recursive_ext}; + upb_ExtensionRegistry_AddArray(ext_reg, ext_array, 2); + + upb_wire_test_TestExtensions* decoded_msg = + upb_wire_test_TestExtensions_parse_ex(buf, size, ext_reg, 0, arena); + EXPECT_NE(decoded_msg, nullptr); + + // Verify canonical extension was NOT skipped and is present + EXPECT_TRUE(upb_wire_test_has_ext_i32(decoded_msg)); + EXPECT_EQ(upb_wire_test_ext_i32(decoded_msg), 1000); + + // Verify non-canonical extension WAS skipped + EXPECT_FALSE(upb_Message_HasExtension((const upb_Message*)decoded_msg, + upb_wire_test_ext_recursive_ext)); + + // Verify raw unknown bytes WERE skipped and are discarded + upb_MessageUnknown data; + uintptr_t iter = kUpb_Message_UnknownBegin; + EXPECT_FALSE( + upb_Message_NextUnknown2((const upb_Message*)decoded_msg, &data, &iter)); + + upb_Arena_Free(arena); +} + +TEST(EncodeTest, EncodeNonCanonicalExtensionDeterministicSuccess) { + upb_Arena* arena = upb_Arena_New(); + + upb_wire_test_TestExtensions* msg = upb_wire_test_TestExtensions_new(arena); + + // 1. Attach scalar extension as non-canonical (tag 100) + int32_t val = 42; + UPB_PRIVATE(_upb_Message_SetNonCanonicalExtension)( + (upb_Message*)msg, upb_wire_test_ext_i32_ext, &val, arena); + + // 2. 
Attach recursive extension as non-canonical (tag 101) + upb_wire_test_TestRecursive* sub_msg = upb_wire_test_TestRecursive_new(arena); + UPB_PRIVATE(_upb_Message_SetNonCanonicalExtension)( + (upb_Message*)msg, upb_wire_test_ext_recursive_ext, &sub_msg, arena); + + // Encode the message with deterministic option! + char* buf; + size_t size; + upb_EncodeStatus status = + upb_Encode((upb_Message*)msg, &upb_0wire_0test__TestExtensions_msg_init, + kUpb_EncodeOption_Deterministic, arena, &buf, &size); + EXPECT_EQ(status, kUpb_EncodeStatus_Ok); + EXPECT_GT(size, 0u); + + // Verify that the encoded bytes can be decoded back using the registry! + upb_ExtensionRegistry* ext_reg = upb_ExtensionRegistry_New(arena); + const upb_MiniTableExtension* ext_array[2] = { + upb_wire_test_ext_i32_ext, upb_wire_test_ext_recursive_ext}; + upb_ExtensionRegistry_AddArray(ext_reg, ext_array, 2); + + upb_wire_test_TestExtensions* decoded_msg = + upb_wire_test_TestExtensions_parse_ex(buf, size, ext_reg, 0, arena); + EXPECT_NE(decoded_msg, nullptr); + + EXPECT_TRUE(upb_wire_test_has_ext_i32(decoded_msg)); + EXPECT_EQ(upb_wire_test_ext_i32(decoded_msg), 42); + EXPECT_TRUE(upb_wire_test_has_ext_recursive(decoded_msg)); + + // Explicitly verify that the extensions are successfully serialized and + // resolved from the encoded message payload. + EXPECT_EQ((int)upb_Message_ExtensionCount((const upb_Message*)decoded_msg), + 2); + + // Verify that if we decode without a registry, the non-canonical extensions + // remain as raw unknown bytes inside the decoded message. 
+ upb_wire_test_TestExtensions* decoded_as_unknown = + upb_wire_test_TestExtensions_parse_ex(buf, size, nullptr, 0, arena); + EXPECT_NE(decoded_as_unknown, nullptr); + EXPECT_EQ( + (int)upb_Message_ExtensionCount((const upb_Message*)decoded_as_unknown), + 0); + EXPECT_TRUE(upb_Message_HasUnknown((const upb_Message*)decoded_as_unknown)); + + int unknown_bytes_count = 0; + upb_MessageUnknown udata; + uintptr_t uiter = kUpb_Message_UnknownBegin; + while (upb_Message_NextUnknown2((const upb_Message*)decoded_as_unknown, + &udata, &uiter)) { + if (udata.type == kUpb_MessageUnknownType_Bytes) { + unknown_bytes_count++; + } + } + EXPECT_GT(unknown_bytes_count, 0); + + upb_Arena_Free(arena); +} + +TEST(EncodeTest, SkipUnknownNonCanonicalExtensionDeterministicSuccess) { + upb_Arena* arena = upb_Arena_New(); + + upb_wire_test_TestExtensions* msg = upb_wire_test_TestExtensions_new(arena); + + // 1. Add a canonical extension (ext_i32, tag 100) to msg + upb_Extension* canonical_ext = UPB_PRIVATE(_upb_Message_GetOrCreateExtension)( + (upb_Message*)msg, upb_wire_test_ext_i32_ext, arena); + canonical_ext->data.int32_val = 1000; + + // 2. Attach a non-canonical extension (ext_recursive, tag 101) to msg + upb_wire_test_TestRecursive* sub_msg = upb_wire_test_TestRecursive_new(arena); + UPB_PRIVATE(_upb_Message_SetNonCanonicalExtension)( + (upb_Message*)msg, upb_wire_test_ext_recursive_ext, &sub_msg, arena); + + // 3. Also add some standard raw unknown bytes (field 1, varint value 150) + char raw_unknown[] = "\x08\x96\x01"; // tag 1 = 150 + UPB_PRIVATE(_upb_Message_AddUnknown)((upb_Message*)msg, raw_unknown, + sizeof(raw_unknown) - 1, arena, + kUpb_AddUnknown_Copy); + + // Encode the message WITH kUpb_EncodeOption_SkipUnknown | + // kUpb_EncodeOption_Deterministic! 
+ char* buf; + size_t size; + upb_EncodeStatus status = upb_Encode( + (upb_Message*)msg, &upb_0wire_0test__TestExtensions_msg_init, + kUpb_EncodeOption_SkipUnknown | kUpb_EncodeOption_Deterministic, arena, + &buf, &size); + EXPECT_EQ(status, kUpb_EncodeStatus_Ok); + + // Parse back the serialized bytes. + // It MUST contain the canonical extension, + // but the non-canonical extension and standard raw unknown bytes MUST be + // successfully skipped. + upb_ExtensionRegistry* ext_reg = upb_ExtensionRegistry_New(arena); + const upb_MiniTableExtension* ext_array[2] = { + upb_wire_test_ext_i32_ext, upb_wire_test_ext_recursive_ext}; + upb_ExtensionRegistry_AddArray(ext_reg, ext_array, 2); + + upb_wire_test_TestExtensions* decoded_msg = + upb_wire_test_TestExtensions_parse_ex(buf, size, ext_reg, 0, arena); + EXPECT_NE(decoded_msg, nullptr); + + // Verify canonical extension was NOT skipped and is present + EXPECT_TRUE(upb_wire_test_has_ext_i32(decoded_msg)); + EXPECT_EQ(upb_wire_test_ext_i32(decoded_msg), 1000); + + // Verify non-canonical extension WAS skipped + EXPECT_FALSE(upb_Message_HasExtension((const upb_Message*)decoded_msg, + upb_wire_test_ext_recursive_ext)); + + // Verify raw unknown bytes WERE skipped and are discarded + upb_MessageUnknown data; + uintptr_t iter = kUpb_Message_UnknownBegin; + EXPECT_FALSE( + upb_Message_NextUnknown2((const upb_Message*)decoded_msg, &data, &iter)); + + upb_Arena_Free(arena); +} + } // namespace } // namespace upb diff --git a/upb/wire/internal/encoder.c b/upb/wire/internal/encoder.c index 3b67547f520de..3dd2e4e6d35c8 100644 --- a/upb/wire/internal/encoder.c +++ b/upb/wire/internal/encoder.c @@ -724,21 +724,26 @@ static char* encode_exts(char* ptr, upb_encstate* e, const upb_MiniTable* m, upb_Message_Internal* in = UPB_PRIVATE(_upb_Message_GetInternal)(msg); if (!in) return ptr; + bool skip_unknown = (e->options & kUpb_EncodeOption_SkipUnknown) != 0; /* Encode all extensions together. 
Unlike C++, we do not attempt to keep * these in field number order relative to normal fields or even to each * other. */ - uintptr_t iter = kUpb_Message_ExtensionBegin; - const upb_MiniTableExtension* ext; - upb_MessageValue ext_val; - if (!UPB_PRIVATE(_upb_Message_NextExtensionReverse)(msg, &ext, &ext_val, - &iter)) { + size_t count = 0; + for (size_t i = 0; i < in->size; i++) { + bool is_any_extension = + upb_TaggedAuxPtr_IsExtension(in->aux_data[i]) || + (!skip_unknown && + upb_TaggedAuxPtr_IsNonCanonicalExtension(in->aux_data[i])); + count += is_any_extension; + } + if (count == 0) { // Message has no extensions. return ptr; } if (e->options & kUpb_EncodeOption_Deterministic) { _upb_sortedmap sorted; - if (!_upb_mapsorter_pushexts(&e->sorter, in, &sorted)) { + if (!_upb_mapsorter_pushexts(&e->sorter, in, &sorted, !skip_unknown)) { // TODO: b/378744096 - handle alloc failure } const upb_Extension* ext; @@ -749,11 +754,25 @@ static char* encode_exts(char* ptr, upb_encstate* e, const upb_MiniTable* m, } _upb_mapsorter_popmap(&e->sorter, &sorted); } else { - do { - ptr = encode_ext(ptr, e, ext, ext_val, - m->UPB_PRIVATE(ext) == kUpb_ExtMode_IsMessageSet); - } while (UPB_PRIVATE(_upb_Message_NextExtensionReverse)(msg, &ext, &ext_val, - &iter)); + size_t i = in->size; + while (i > 0) { + i--; + upb_TaggedAuxPtr tagged_ptr = in->aux_data[i]; + if (upb_TaggedAuxPtr_IsExtension(tagged_ptr)) { + const upb_Extension* ext = upb_TaggedAuxPtr_Extension(tagged_ptr); + ptr = encode_ext(ptr, e, ext->ext, ext->data, + UPB_PRIVATE(_upb_MiniTable_ExtModeBase)(m) == + kUpb_ExtMode_IsMessageSet); + } else if (!skip_unknown && + upb_TaggedAuxPtr_IsNonCanonicalExtension(tagged_ptr)) { + // Encode non-canonical extensions if not skipping unknown fields. 
+ const upb_Extension* ext = + upb_TaggedAuxPtr_NonCanonicalExtension(tagged_ptr); + ptr = encode_ext(ptr, e, ext->ext, ext->data, + UPB_PRIVATE(_upb_MiniTable_ExtModeBase)(m) == + kUpb_ExtMode_IsMessageSet); + } + } } return ptr; } @@ -770,12 +789,16 @@ char* encode_message(char* ptr, upb_encstate* e, const upb_Message* msg, } } - if ((e->options & kUpb_EncodeOption_SkipUnknown) == 0) { + bool skip_unknown = (e->options & kUpb_EncodeOption_SkipUnknown) != 0; + if (!skip_unknown) { size_t unknown_size = 0; uintptr_t iter = kUpb_Message_UnknownBegin; upb_StringView unknown; // Need to write in reverse order, but iteration is in-order; scan to - // reserve capacity up front, then write in-order + // reserve capacity up front, then write in-order. + // + // Encode unknown fields only. Non-canonical extension encoding is handled + // in encode_exts below. while (upb_Message_NextUnknown(msg, &unknown, &iter)) { unknown_size += unknown.size; }