diff --git a/example/test.lua b/example/test.lua index 6a7bf87..aef72cd 100644 --- a/example/test.lua +++ b/example/test.lua @@ -1,4 +1,4 @@ -require 'person_pb' +person_pb = require 'person_pb' local person = person_pb.Person() person.id = 1000 diff --git a/protobuf/containers.lua b/protobuf/containers.lua index 3f9f878..893dc1f 100644 --- a/protobuf/containers.lua +++ b/protobuf/containers.lua @@ -21,7 +21,7 @@ local table = table local rawset = rawset local error = error -module "protobuf.containers" +local containers = {} local _RCFC_meta = { add = function(self) @@ -47,7 +47,7 @@ local _RCFC_meta = { } _RCFC_meta.__index = _RCFC_meta -function RepeatedCompositeFieldContainer(listener, message_descriptor) +function containers.RepeatedCompositeFieldContainer(listener, message_descriptor) local o = { _listener = listener, _message_descriptor = message_descriptor @@ -71,9 +71,11 @@ local _RSFC_meta = { } _RSFC_meta.__index = _RSFC_meta -function RepeatedScalarFieldContainer(listener, type_checker) +function containers.RepeatedScalarFieldContainer(listener, type_checker) local o = {} o._listener = listener o._type_checker = type_checker return setmetatable(o, _RSFC_meta) end + +return containers diff --git a/protobuf/decoder.lua b/protobuf/decoder.lua index f927830..aabef69 100644 --- a/protobuf/decoder.lua +++ b/protobuf/decoder.lua @@ -27,7 +27,7 @@ local pb = require "protobuf.pb" local encoder = require "protobuf.encoder" local wire_format = require "protobuf.wire_format" -module "protobuf.decoder" +local decoder = {} local _DecodeVarint = pb.varint_decoder local _DecodeSignedVarint = pb.signed_varint_decoder @@ -35,7 +35,7 @@ local _DecodeSignedVarint = pb.signed_varint_decoder local _DecodeVarint32 = pb.varint_decoder local _DecodeSignedVarint32 = pb.signed_varint_decoder -ReadTag = pb.read_tag +decoder.ReadTag = pb.read_tag local function _SimpleDecoder(wire_type, decode_value) return function(field_number, is_repeated, is_packed, key, new_default) @@ -110,7 +110,7 @@ end local function _StructPackDecoder(wire_type, value_size, format) local struct_unpack = pb.struct_unpack - function InnerDecode(buffer, pos) + local function InnerDecode(buffer, pos) local new_pos = pos + value_size local result = struct_unpack(format, buffer, pos) return result, new_pos @@ -122,27 +122,27 @@ local function _Boolean(value) return value ~= 0 end -Int32Decoder = _SimpleDecoder(wire_format.WIRETYPE_VARINT, _DecodeSignedVarint32) -EnumDecoder = Int32Decoder +decoder.Int32Decoder = _SimpleDecoder(wire_format.WIRETYPE_VARINT, _DecodeSignedVarint32) +decoder.EnumDecoder = decoder.Int32Decoder -Int64Decoder = _SimpleDecoder(wire_format.WIRETYPE_VARINT, _DecodeSignedVarint) +decoder.Int64Decoder = _SimpleDecoder(wire_format.WIRETYPE_VARINT, _DecodeSignedVarint) -UInt32Decoder = _SimpleDecoder(wire_format.WIRETYPE_VARINT, _DecodeVarint32) -UInt64Decoder = _SimpleDecoder(wire_format.WIRETYPE_VARINT, _DecodeVarint) +decoder.UInt32Decoder = _SimpleDecoder(wire_format.WIRETYPE_VARINT, _DecodeVarint32) +decoder.UInt64Decoder = _SimpleDecoder(wire_format.WIRETYPE_VARINT, _DecodeVarint) -SInt32Decoder = _ModifiedDecoder(wire_format.WIRETYPE_VARINT, _DecodeVarint32, wire_format.ZigZagDecode32) -SInt64Decoder = _ModifiedDecoder(wire_format.WIRETYPE_VARINT, _DecodeVarint, wire_format.ZigZagDecode64) +decoder.SInt32Decoder = _ModifiedDecoder(wire_format.WIRETYPE_VARINT, _DecodeVarint32, wire_format.ZigZagDecode32) +decoder.SInt64Decoder = _ModifiedDecoder(wire_format.WIRETYPE_VARINT, _DecodeVarint, wire_format.ZigZagDecode64) -Fixed32Decoder = _StructPackDecoder(wire_format.WIRETYPE_FIXED32, 4, string.byte('I')) -Fixed64Decoder = _StructPackDecoder(wire_format.WIRETYPE_FIXED64, 8, string.byte('Q')) -SFixed32Decoder = _StructPackDecoder(wire_format.WIRETYPE_FIXED32, 4, string.byte('i')) -SFixed64Decoder = _StructPackDecoder(wire_format.WIRETYPE_FIXED64, 8, string.byte('q')) -FloatDecoder = _StructPackDecoder(wire_format.WIRETYPE_FIXED32, 4, string.byte('f')) -DoubleDecoder = _StructPackDecoder(wire_format.WIRETYPE_FIXED64, 8, string.byte('d')) +decoder.Fixed32Decoder = _StructPackDecoder(wire_format.WIRETYPE_FIXED32, 4, string.byte('I')) +decoder.Fixed64Decoder = _StructPackDecoder(wire_format.WIRETYPE_FIXED64, 8, string.byte('Q')) +decoder.SFixed32Decoder = _StructPackDecoder(wire_format.WIRETYPE_FIXED32, 4, string.byte('i')) +decoder.SFixed64Decoder = _StructPackDecoder(wire_format.WIRETYPE_FIXED64, 8, string.byte('q')) +decoder.FloatDecoder = _StructPackDecoder(wire_format.WIRETYPE_FIXED32, 4, string.byte('f')) +decoder.DoubleDecoder = _StructPackDecoder(wire_format.WIRETYPE_FIXED64, 8, string.byte('d')) -BoolDecoder = _ModifiedDecoder(wire_format.WIRETYPE_VARINT, _DecodeVarint, _Boolean) +decoder.BoolDecoder = _ModifiedDecoder(wire_format.WIRETYPE_VARINT, _DecodeVarint, _Boolean) -function StringDecoder(field_number, is_repeated, is_packed, key, new_default) +function decoder.StringDecoder(field_number, is_repeated, is_packed, key, new_default) local DecodeVarint = _DecodeVarint local sub = string.sub --local unicode = unicode @@ -184,7 +184,7 @@ function StringDecoder(field_number, is_repeated, is_packed, key, new_default) end end -function BytesDecoder(field_number, is_repeated, is_packed, key, new_default) +function decoder.BytesDecoder(field_number, is_repeated, is_packed, key, new_default) local DecodeVarint = _DecodeVarint local sub = string.sub assert(not is_packed) @@ -225,7 +225,7 @@ function BytesDecoder(field_number, is_repeated, is_packed, key, new_default) end end -function MessageDecoder(field_number, is_repeated, is_packed, key, new_default) +function decoder.MessageDecoder(field_number, is_repeated, is_packed, key, new_default) local DecodeVarint = _DecodeVarint local sub = string.sub assert(not is_packed) @@ -275,13 +275,13 @@ function MessageDecoder(field_number, is_repeated, is_packed, key, new_default) end end -function _SkipVarint(buffer, pos, pend) +local function _SkipVarint(buffer, pos, pend) local value value, pos = _DecodeVarint(buffer, pos) return pos end -function _SkipFixed64(buffer, pos, pend) +local function _SkipFixed64(buffer, pos, pend) pos = pos + 8 if pos > pend then error('Truncated message.') @@ -289,7 +289,7 @@ function _SkipFixed64(buffer, pos, pend) return pos end -function _SkipLengthDelimited(buffer, pos, pend) +local function _SkipLengthDelimited(buffer, pos, pend) local size size, pos = _DecodeVarint(buffer, pos) pos = pos + size @@ -299,7 +299,7 @@ function _SkipLengthDelimited(buffer, pos, pend) return pos end -function _SkipFixed32(buffer, pos, pend) +local function _SkipFixed32(buffer, pos, pend) pos = pos + 4 if pos > pend then error('Truncated message.') @@ -307,17 +307,21 @@ function _SkipFixed32(buffer, pos, pend) return pos end -function _RaiseInvalidWireType(buffer, pos, pend) +local function _Unsupported(buffer, pos, pend) + error('Field not supported.') +end + +local function _RaiseInvalidWireType(buffer, pos, pend) error('Tag had invalid wire type.') end -function _FieldSkipper() - WIRETYPE_TO_SKIPPER = { +local function _FieldSkipper() + local WIRETYPE_TO_SKIPPER = { _SkipVarint, _SkipFixed64, _SkipLengthDelimited, - _SkipGroup, - _EndGroup, + _Unsupported, --_SkipGroup, + _Unsupported, --_EndGroup, _SkipFixed32, _RaiseInvalidWireType, _RaiseInvalidWireType, @@ -333,4 +337,6 @@ function _FieldSkipper() end end -SkipField = _FieldSkipper() +decoder.SkipField = _FieldSkipper() + +return decoder \ No newline at end of file diff --git a/protobuf/descriptor.lua b/protobuf/descriptor.lua index c11da5d..38d1dc1 100644 --- a/protobuf/descriptor.lua +++ b/protobuf/descriptor.lua @@ -16,9 +16,9 @@ -------------------------------------------------------------------------------- -- -module "protobuf.descriptor" +local descriptor = {} -FieldDescriptor = { +descriptor.FieldDescriptor = { TYPE_DOUBLE = 1, TYPE_FLOAT = 2, TYPE_INT64 = 3, @@ -58,3 +58,5 @@ FieldDescriptor = { LABEL_REPEATED = 3, MAX_LABEL = 3 } + +return descriptor \ No newline at end of file diff --git a/protobuf/encoder.lua b/protobuf/encoder.lua index f013f4f..cdc5064 100644 --- a/protobuf/encoder.lua +++ b/protobuf/encoder.lua @@ -24,9 +24,9 @@ local assert = assert local pb = require "protobuf.pb" local wire_format = require "protobuf.wire_format" -module "protobuf.encoder" +local encoder = {} -function _VarintSize(value) +local function _VarintSize(value) if value <= 0x7f then return 1 end if value <= 0x3fff then return 2 end if value <= 0x1fffff then return 3 end @@ -34,7 +34,7 @@ function _VarintSize(value) return 5 end -function _SignedVarintSize(value) +local function _SignedVarintSize(value) if value < 0 then return 10 end if value <= 0x7f then return 1 end if value <= 0x3fff then return 2 end @@ -43,11 +43,11 @@ function _SignedVarintSize(value) return 5 end -function _TagSize(field_number) +local function _TagSize(field_number) return _VarintSize(wire_format.PackTag(field_number, 0)) end -function _SimpleSizer(compute_value_size) +local function _SimpleSizer(compute_value_size) return function(field_number, is_repeated, is_packed) local tag_size = _TagSize(field_number) if is_packed then @@ -75,7 +75,7 @@ function _SimpleSizer(compute_value_size) end end -function _ModifiedSizer(compute_value_size, modify_value) +local function _ModifiedSizer(compute_value_size, modify_value) return function (field_number, is_repeated, is_packed) local tag_size = _TagSize(field_number) if is_packed then @@ -103,7 +103,7 @@ function _ModifiedSizer(compute_value_size, modify_value) end end -function _FixedSizer(value_size) +local function _FixedSizer(value_size) return function (field_number, is_repeated, is_packed) local tag_size = _TagSize(field_number) if is_packed then @@ -126,27 +126,27 @@ function _FixedSizer(value_size) end end -Int32Sizer = _SimpleSizer(_SignedVarintSize) -Int64Sizer = Int32Sizer -EnumSizer = Int32Sizer +encoder.Int32Sizer = _SimpleSizer(_SignedVarintSize) +encoder.Int64Sizer = encoder.Int32Sizer +encoder.EnumSizer = encoder.Int32Sizer -UInt32Sizer = _SimpleSizer(_VarintSize) -UInt64Sizer = UInt32Sizer +encoder.UInt32Sizer = _SimpleSizer(_VarintSize) +encoder.UInt64Sizer = encoder.UInt32Sizer -SInt32Sizer = _ModifiedSizer(_SignedVarintSize, wire_format.ZigZagEncode) -SInt64Sizer = SInt32Sizer +encoder.SInt32Sizer = _ModifiedSizer(_SignedVarintSize, wire_format.ZigZagEncode) +encoder.SInt64Sizer = encoder.SInt32Sizer -Fixed32Sizer = _FixedSizer(4) -SFixed32Sizer = Fixed32Sizer -FloatSizer = Fixed32Sizer +encoder.Fixed32Sizer = _FixedSizer(4) +encoder.SFixed32Sizer = encoder.Fixed32Sizer +encoder.FloatSizer = encoder.Fixed32Sizer -Fixed64Sizer = _FixedSizer(8) -SFixed64Sizer = Fixed64Sizer -DoubleSizer = Fixed64Sizer +encoder.Fixed64Sizer = _FixedSizer(8) +encoder.SFixed64Sizer = encoder.Fixed64Sizer +encoder.DoubleSizer = encoder.Fixed64Sizer -BoolSizer = _FixedSizer(1) +encoder.BoolSizer = _FixedSizer(1) -function StringSizer(field_number, is_repeated, is_packed) +function encoder.StringSizer(field_number, is_repeated, is_packed) local tag_size = _TagSize(field_number) local VarintSize = _VarintSize assert(not is_packed) @@ -167,7 +167,7 @@ function StringSizer(field_number, is_repeated, is_packed) end end -function BytesSizer(field_number, is_repeated, is_packed) +function encoder.BytesSizer(field_number, is_repeated, is_packed) local tag_size = _TagSize(field_number) local VarintSize = _VarintSize assert(not is_packed) @@ -188,7 +188,7 @@ function BytesSizer(field_number, is_repeated, is_packed) end end -function MessageSizer(field_number, is_repeated, is_packed) +function encoder.MessageSizer(field_number, is_repeated, is_packed) local tag_size = _TagSize(field_number) local VarintSize = _VarintSize assert(not is_packed) @@ -212,23 +212,23 @@ end local _EncodeVarint = pb.varint_encoder local _EncodeSignedVarint = pb.signed_varint_encoder -function _VarintBytes(value) +local function _VarintBytes(value) local out = {} - local write = function(value) - out[#out + 1 ] = value + local write = function(v) + out[#out + 1 ] = v end _EncodeSignedVarint(write, value) return table.concat(out) end -function TagBytes(field_number, wire_type) +function encoder.TagBytes(field_number, wire_type) return _VarintBytes(wire_format.PackTag(field_number, wire_type)) end -function _SimpleEncoder(wire_type, encode_value, compute_value_size) +local function _SimpleEncoder(wire_type, encode_value, compute_value_size) return function(field_number, is_repeated, is_packed) if is_packed then - local tag_bytes = TagBytes(field_number, wire_format.WIRETYPE_LENGTH_DELIMITED) + local tag_bytes = encoder.TagBytes(field_number, wire_format.WIRETYPE_LENGTH_DELIMITED) local EncodeVarint = _EncodeVarint return function(write, value) write(tag_bytes) @@ -242,7 +242,7 @@ function _SimpleEncoder(wire_type, encode_value, compute_value_size) end end elseif is_repeated then - local tag_bytes = TagBytes(field_number, wire_type) + local tag_bytes = encoder.TagBytes(field_number, wire_type) return function(write, value) for _, element in ipairs(value) do write(tag_bytes) @@ -250,7 +250,7 @@ function _SimpleEncoder(wire_type, encode_value, compute_value_size) end end else - local tag_bytes = TagBytes(field_number, wire_type) + local tag_bytes = encoder.TagBytes(field_number, wire_type) return function(write, value) write(tag_bytes) encode_value(write, value) @@ -259,10 +259,10 @@ function _SimpleEncoder(wire_type, encode_value, compute_value_size) end end -function _ModifiedEncoder(wire_type, encode_value, compute_value_size, modify_value) +local function _ModifiedEncoder(wire_type, encode_value, compute_value_size, modify_value) return function (field_number, is_repeated, is_packed) if is_packed then - local tag_bytes = TagBytes(field_number, wire_format.WIRETYPE_LENGTH_DELIMITED) + local tag_bytes = encoder.TagBytes(field_number, wire_format.WIRETYPE_LENGTH_DELIMITED) local EncodeVarint = _EncodeVarint return function (write, value) write(tag_bytes) @@ -276,7 +276,7 @@ function _ModifiedEncoder(wire_type, encode_value, compute_value_size, modify_va end end elseif is_repeated then - local tag_bytes = TagBytes(field_number, wire_type) + local tag_bytes = encoder.TagBytes(field_number, wire_type) return function (write, value) for _, element in ipairs(value) do write(tag_bytes) @@ -284,7 +284,7 @@ function _ModifiedEncoder(wire_type, encode_value, compute_value_size, modify_va end end else - local tag_bytes = TagBytes(field_number, wire_type) + local tag_bytes = encoder.TagBytes(field_number, wire_type) return function (write, value) write(tag_bytes) encode_value(write, modify_value(value)) @@ -293,11 +293,11 @@ function _ModifiedEncoder(wire_type, encode_value, compute_value_size, modify_va end end -function _StructPackEncoder(wire_type, value_size, format) +local function _StructPackEncoder(wire_type, value_size, format) return function(field_number, is_repeated, is_packed) local struct_pack = pb.struct_pack if is_packed then - local tag_bytes = TagBytes(field_number, wire_format.WIRETYPE_LENGTH_DELIMITED) + local tag_bytes = encoder.TagBytes(field_number, wire_format.WIRETYPE_LENGTH_DELIMITED) local EncodeVarint = _EncodeVarint return function (write, value) write(tag_bytes) @@ -307,7 +307,7 @@ function _StructPackEncoder(wire_type, value_size, format) end end elseif is_repeated then - local tag_bytes = TagBytes(field_number, wire_type) + local tag_bytes = encoder.TagBytes(field_number, wire_type) return function (write, value) for _, element in ipairs(value) do write(tag_bytes) @@ -315,7 +315,7 @@ function _StructPackEncoder(wire_type, value_size, format) end end else - local tag_bytes = TagBytes(field_number, wire_type) + local tag_bytes = encoder.TagBytes(field_number, wire_type) return function (write, value) write(tag_bytes) struct_pack(write, format, value) @@ -324,35 +324,35 @@ function _StructPackEncoder(wire_type, value_size, format) end end -Int32Encoder = _SimpleEncoder(wire_format.WIRETYPE_VARINT, _EncodeSignedVarint, _SignedVarintSize) -Int64Encoder = Int32Encoder -EnumEncoder = Int32Encoder +encoder.Int32Encoder = _SimpleEncoder(wire_format.WIRETYPE_VARINT, _EncodeSignedVarint, _SignedVarintSize) +encoder.Int64Encoder = encoder.Int32Encoder +encoder.EnumEncoder = encoder.Int32Encoder -UInt32Encoder = _SimpleEncoder(wire_format.WIRETYPE_VARINT, _EncodeVarint, _VarintSize) -UInt64Encoder = UInt32Encoder +encoder.UInt32Encoder = _SimpleEncoder(wire_format.WIRETYPE_VARINT, _EncodeVarint, _VarintSize) +encoder.UInt64Encoder = encoder.UInt32Encoder -SInt32Encoder = _ModifiedEncoder( +encoder.SInt32Encoder = _ModifiedEncoder( wire_format.WIRETYPE_VARINT, _EncodeVarint, _VarintSize, wire_format.ZigZagEncode32 ) -SInt64Encoder = _ModifiedEncoder( +encoder.SInt64Encoder = _ModifiedEncoder( wire_format.WIRETYPE_VARINT, _EncodeVarint, _VarintSize, wire_format.ZigZagEncode64 ) -Fixed32Encoder = _StructPackEncoder(wire_format.WIRETYPE_FIXED32, 4, string.byte('I')) -Fixed64Encoder = _StructPackEncoder(wire_format.WIRETYPE_FIXED64, 8, string.byte('Q')) -SFixed32Encoder = _StructPackEncoder(wire_format.WIRETYPE_FIXED32, 4, string.byte('i')) -SFixed64Encoder = _StructPackEncoder(wire_format.WIRETYPE_FIXED64, 8, string.byte('q')) -FloatEncoder = _StructPackEncoder(wire_format.WIRETYPE_FIXED32, 4, string.byte('f')) -DoubleEncoder = _StructPackEncoder(wire_format.WIRETYPE_FIXED64, 8, string.byte('d')) +encoder.Fixed32Encoder = _StructPackEncoder(wire_format.WIRETYPE_FIXED32, 4, string.byte('I')) +encoder.Fixed64Encoder = _StructPackEncoder(wire_format.WIRETYPE_FIXED64, 8, string.byte('Q')) +encoder.SFixed32Encoder = _StructPackEncoder(wire_format.WIRETYPE_FIXED32, 4, string.byte('i')) +encoder.SFixed64Encoder = _StructPackEncoder(wire_format.WIRETYPE_FIXED64, 8, string.byte('q')) +encoder.FloatEncoder = _StructPackEncoder(wire_format.WIRETYPE_FIXED32, 4, string.byte('f')) +encoder.DoubleEncoder = _StructPackEncoder(wire_format.WIRETYPE_FIXED64, 8, string.byte('d')) -function BoolEncoder(field_number, is_repeated, is_packed) +function encoder.BoolEncoder(field_number, is_repeated, is_packed) local false_byte = '\0' local true_byte = '\1' if is_packed then - local tag_bytes = TagBytes(field_number, wire_format.WIRETYPE_LENGTH_DELIMITED) + local tag_bytes = encoder.TagBytes(field_number, wire_format.WIRETYPE_LENGTH_DELIMITED) local EncodeVarint = _EncodeVarint return function (write, value) write(tag_bytes) @@ -366,7 +366,7 @@ function BoolEncoder(field_number, is_repeated, is_packed) end end elseif is_repeated then - local tag_bytes = TagBytes(field_number, wire_format.WIRETYPE_VARINT) + local tag_bytes = encoder.TagBytes(field_number, wire_format.WIRETYPE_VARINT) return function(write, value) for _, element in ipairs(value) do write(tag_bytes) @@ -378,7 +378,7 @@ function BoolEncoder(field_number, is_repeated, is_packed) end end else - local tag_bytes = TagBytes(field_number, wire_format.WIRETYPE_VARINT) + local tag_bytes = encoder.TagBytes(field_number, wire_format.WIRETYPE_VARINT) return function (write, value) write(tag_bytes) if value then @@ -389,8 +389,8 @@ function BoolEncoder(field_number, is_repeated, is_packed) end end -function StringEncoder(field_number, is_repeated, is_packed) - local tag = TagBytes(field_number, wire_format.WIRETYPE_LENGTH_DELIMITED) +function encoder.StringEncoder(field_number, is_repeated, is_packed) + local tag = encoder.TagBytes(field_number, wire_format.WIRETYPE_LENGTH_DELIMITED) local EncodeVarint = _EncodeVarint assert(not is_packed) if is_repeated then @@ -412,8 +412,8 @@ function StringEncoder(field_number, is_repeated, is_packed) end end -function BytesEncoder(field_number, is_repeated, is_packed) - local tag = TagBytes(field_number, wire_format.WIRETYPE_LENGTH_DELIMITED) +function encoder.BytesEncoder(field_number, is_repeated, is_packed) + local tag = encoder.TagBytes(field_number, wire_format.WIRETYPE_LENGTH_DELIMITED) local EncodeVarint = _EncodeVarint assert(not is_packed) if is_repeated then @@ -433,8 +433,8 @@ function BytesEncoder(field_number, is_repeated, is_packed) end end -function MessageEncoder(field_number, is_repeated, is_packed) - local tag = TagBytes(field_number, wire_format.WIRETYPE_LENGTH_DELIMITED) +function encoder.MessageEncoder(field_number, is_repeated, is_packed) + local tag = encoder.TagBytes(field_number, wire_format.WIRETYPE_LENGTH_DELIMITED) local EncodeVarint = _EncodeVarint assert(not is_packed) if is_repeated then @@ -453,3 +453,5 @@ function MessageEncoder(field_number, is_repeated, is_packed) end end end + +return encoder \ No newline at end of file diff --git a/protobuf/init.lua b/protobuf/init.lua index f7d6fb3..263a618 100644 --- a/protobuf/init.lua +++ b/protobuf/init.lua @@ -35,11 +35,11 @@ local encoder = require "protobuf.encoder" local decoder = require "protobuf.decoder" local listener_mod = require "protobuf.listener" local containers = require "protobuf.containers" -local descriptor = require "protobuf.descriptor" -local FieldDescriptor = descriptor.FieldDescriptor +local pb_descriptor = require "protobuf.descriptor" +local FieldDescriptor = pb_descriptor.FieldDescriptor local text_format = require "protobuf.text_format" -module "protobuf" +local protobuf = {} local function make_descriptor(name, descriptor, usable_key) local meta = { @@ -56,7 +56,7 @@ local function make_descriptor(name, descriptor, usable_key) return setmetatable({}, meta) end - _M[name] = setmetatable(descriptor, meta); + protobuf[name] = setmetatable(descriptor, meta); end make_descriptor("Descriptor", {}, { @@ -231,27 +231,6 @@ local TYPE_TO_DECODER = { [FieldDescriptor.TYPE_SINT64] = decoder.SInt64Decoder } -local FIELD_TYPE_TO_WIRE_TYPE = { - [FieldDescriptor.TYPE_DOUBLE] = wire_format.WIRETYPE_FIXED64, - [FieldDescriptor.TYPE_FLOAT] = wire_format.WIRETYPE_FIXED32, - [FieldDescriptor.TYPE_INT64] = wire_format.WIRETYPE_VARINT, - [FieldDescriptor.TYPE_UINT64] = wire_format.WIRETYPE_VARINT, - [FieldDescriptor.TYPE_INT32] = wire_format.WIRETYPE_VARINT, - [FieldDescriptor.TYPE_FIXED64] = wire_format.WIRETYPE_FIXED64, - [FieldDescriptor.TYPE_FIXED32] = wire_format.WIRETYPE_FIXED32, - [FieldDescriptor.TYPE_BOOL] = wire_format.WIRETYPE_VARINT, - [FieldDescriptor.TYPE_STRING] = wire_format.WIRETYPE_LENGTH_DELIMITED, - [FieldDescriptor.TYPE_GROUP] = wire_format.WIRETYPE_START_GROUP, - [FieldDescriptor.TYPE_MESSAGE] = wire_format.WIRETYPE_LENGTH_DELIMITED, - [FieldDescriptor.TYPE_BYTES] = wire_format.WIRETYPE_LENGTH_DELIMITED, - [FieldDescriptor.TYPE_UINT32] = wire_format.WIRETYPE_VARINT, - [FieldDescriptor.TYPE_ENUM] = wire_format.WIRETYPE_VARINT, - [FieldDescriptor.TYPE_SFIXED32] = wire_format.WIRETYPE_FIXED32, - [FieldDescriptor.TYPE_SFIXED64] = wire_format.WIRETYPE_FIXED64, - [FieldDescriptor.TYPE_SINT32] = wire_format.WIRETYPE_VARINT, - [FieldDescriptor.TYPE_SINT64] = wire_format.WIRETYPE_VARINT -} - local function IsTypePackable(field_type) return NON_PACKABLE_TYPES[field_type] == nil end @@ -283,7 +262,7 @@ local function _DefaultValueConstructorForField(field) if field.cpp_type == FieldDescriptor.CPPTYPE_MESSAGE then local message_type = field.message_type return function (message) - result = message_type._concrete_class() + local result = message_type._concrete_class() result._SetListener(message._listener_for_children) return result end @@ -297,18 +276,25 @@ local function _AttachFieldHelpers(message_meta, field_descriptor) local is_repeated = (field_descriptor.label == FieldDescriptor.LABEL_REPEATED) local is_packed = (field_descriptor.has_options and field_descriptor.GetOptions().packed) - rawset(field_descriptor, "_encoder", TYPE_TO_ENCODER[field_descriptor.type](field_descriptor.number, is_repeated, is_packed)) - rawset(field_descriptor, "_sizer", TYPE_TO_SIZER[field_descriptor.type](field_descriptor.number, is_repeated, is_packed)) + rawset(field_descriptor, "_encoder", + TYPE_TO_ENCODER[field_descriptor.type](field_descriptor.number, is_repeated, is_packed)) + rawset(field_descriptor, "_sizer", + TYPE_TO_SIZER[field_descriptor.type](field_descriptor.number, is_repeated, is_packed)) rawset(field_descriptor, "_default_constructor", _DefaultValueConstructorForField(field_descriptor)) - local AddDecoder = function(wiretype, is_packed) + local AddDecoder = function(wiretype, _is_packed) local tag_bytes = encoder.TagBytes(field_descriptor.number, wiretype) - message_meta._decoders_by_tag[tag_bytes] = TYPE_TO_DECODER[field_descriptor.type](field_descriptor.number, is_repeated, is_packed, field_descriptor, field_descriptor._default_constructor) + message_meta._decoders_by_tag[tag_bytes] = + TYPE_TO_DECODER[field_descriptor.type](field_descriptor.number, + is_repeated, + _is_packed, + field_descriptor, + field_descriptor._default_constructor) end - AddDecoder(FIELD_TYPE_TO_WIRE_TYPE[field_descriptor.type], False) + AddDecoder(FIELD_TYPE_TO_WIRE_TYPE[field_descriptor.type], false) if is_repeated and IsTypePackable(field_descriptor.type) then - AddDecoder(wire_format.WIRETYPE_LENGTH_DELIMITED, True) + AddDecoder(wire_format.WIRETYPE_LENGTH_DELIMITED, true) end end @@ -391,7 +377,7 @@ local function _AddPropertiesForNonRepeatedScalarField(field, message) end local function _AddPropertiesForField(field, message_meta) - constant_name = field.name:upper() .. "_FIELD_NUMBER" + local constant_name = field.name:upper() .. "_FIELD_NUMBER" message_meta._member[constant_name] = field.number if field.label == FieldDescriptor.LABEL_REPEATED then @@ -426,7 +412,8 @@ local _ED_meta = { local _extended_message = rawget(self, "_extended_message") if (extension_handle.label == FieldDescriptor.LABEL_REPEATED or extension_handle.cpp_type == FieldDescriptor.CPPTYPE_MESSAGE) then - error('Cannot assign to extension "'.. extension_handle.full_name .. '" because it is a repeated or composite type.') + error('Cannot assign to extension "'.. extension_handle.full_name .. + '" because it is a repeated or composite type.') end local type_checker = GetTypeChecker(extension_handle.cpp_type, extension_handle.type) type_checker.CheckValue(value) @@ -493,18 +480,18 @@ end local function _AddListFieldsMethod(message_descriptor, message_meta) message_meta._member.ListFields = function (self) local list_field = function(fields) - local f, s, v = pairs(self._fields) - local iter = function(a, i) + local s, v + local iter = function() while true do - local descriptor, value = f(a, i) - if descriptor == nil then + s, v = next(fields, s) + if s == nil then return - elseif _IsPresent(descriptor, value) then - return descriptor, value + elseif _IsPresent(s, v) then + return s, v end end end - return iter, s, v + return iter end return list_field(self._fields) end @@ -518,12 +505,12 @@ local function _AddHasFieldMethod(message_descriptor, message_meta) end end message_meta._member.HasField = function (self, field_name) - field = singular_fields[field_name] + local field = singular_fields[field_name] if field == nil then error('Protocol message has no singular "'.. field_name.. '" field.') end if field.cpp_type == FieldDescriptor.CPPTYPE_MESSAGE then - value = self._fields[field] + local value = self._fields[field] return value ~= nil and value._is_present_in_parent else return self._fields[field] @@ -573,7 +560,7 @@ local function _AddHasExtensionMethod(message_meta) error(extension_handle.full_name .. ' is repeated.') end if extension_handle.cpp_type == FieldDescriptor.CPPTYPE_MESSAGE then - value = self._fields[extension_handle] + local value = self._fields[extension_handle] return value ~= nil and value._is_present_in_parent else return self._fields[extension_handle] @@ -749,6 +736,7 @@ local function _AddIsInitializedMethod(message_descriptor, message_meta) for field, value in message_meta._member.ListFields(self) do if field.cpp_type == FieldDescriptor.CPPTYPE_MESSAGE then + local name if field.is_extension then name = string.format("(%s)", field.full_name) else @@ -756,15 +744,15 @@ local function _AddIsInitializedMethod(message_descriptor, message_meta) end if field.label == FieldDescriptor.LABEL_REPEATED then for i, element in ipairs(value) do - prefix = string.format("%s[%d].", name, i) - sub_errors = element:FindInitializationErrors() + local prefix = string.format("%s[%d].", name, i) + local sub_errors = element:FindInitializationErrors() for _, e in ipairs(sub_errors) do errors[#errors + 1] = prefix .. e end end else - prefix = name .. "." - sub_errors = value:FindInitializationErrors() + local prefix = name .. "." + local sub_errors = value:FindInitializationErrors() for _, e in ipairs(sub_errors) do errors[#errors + 1] = prefix .. e end @@ -784,6 +772,7 @@ local function _AddMergeFromMethod(message_meta) message_meta._member._Modified(self) local fields = self._fields + local field_value for field, value in pairs(msg._fields) do if field.label == LABEL_REPEATED or field.cpp_type == CPPTYPE_MESSAGE then @@ -860,7 +849,7 @@ local function property_setter(message_meta) end end -function _AddClassAttributesForNestedExtensions(descriptor, message_meta) +local function _AddClassAttributesForNestedExtensions(descriptor, message_meta) local extension_dict = descriptor._extensions_by_name for extension_name, extension_field in pairs(extension_dict) do message_meta._member[extension_name] = extension_field @@ -894,7 +883,7 @@ local function Message(descriptor) if rawget(descriptor, "_concrete_class") == nil then rawset(descriptor, "_concrete_class", ns) - for k, field in ipairs(descriptor.fields) do + for _, field in ipairs(descriptor.fields) do _AttachFieldHelpers(message_meta, field) end end @@ -912,4 +901,6 @@ local function Message(descriptor) return ns end -_M.Message = Message +protobuf.Message = Message + +return protobuf diff --git a/protobuf/listener.lua b/protobuf/listener.lua index b8ff5ca..c9375fa 100644 --- a/protobuf/listener.lua +++ b/protobuf/listener.lua @@ -18,14 +18,14 @@ local setmetatable = setmetatable -module "protobuf.listener" +local listener = {} local _null_listener = { Modified = function() end } -function NullMessageListener() +function listener.NullMessageListener() return _null_listener end @@ -41,10 +41,12 @@ local _listener_meta = { } _listener_meta.__index = _listener_meta -function Listener(parent_message) +function listener.Listener(parent_message) local o = {} o.__mode = "v" o._parent_message = parent_message o.dirty = false return setmetatable(o, _listener_meta) end + +return listener \ No newline at end of file diff --git a/protobuf/text_format.lua b/protobuf/text_format.lua index 51fcde7..4902d6e 100644 --- a/protobuf/text_format.lua +++ b/protobuf/text_format.lua @@ -26,9 +26,9 @@ local tostring = tostring local descriptor = require "protobuf.descriptor" -module "protobuf.text_format" +local text_format = {} -function format(buffer) +function text_format.format(buffer) local len = string.len( buffer ) for i = 1, len, 16 do local text = "" @@ -41,7 +41,7 @@ end local FieldDescriptor = descriptor.FieldDescriptor -msg_format_indent = function(write, msg, indent) +text_format.msg_format_indent = function(write, msg, indent) for field, value in msg:ListFields() do local print_field = function(field_value) local name = field.name @@ -53,7 +53,7 @@ msg_format_indent = function(write, msg, indent) else write(name .. " {\n") end - msg_format_indent(write, field_value, indent + 4) + text_format.msg_format_indent(write, field_value, indent + 4) write(string.rep(" ", indent)) write("}\n") else @@ -70,11 +70,13 @@ msg_format_indent = function(write, msg, indent) end end -function msg_format(msg) +function text_format.msg_format(msg) local out = {} local write = function(value) out[#out + 1] = value end - msg_format_indent(write, msg, 0) + text_format.msg_format_indent(write, msg, 0) return table.concat(out) end + +return text_format \ No newline at end of file diff --git a/protobuf/type_checkers.lua b/protobuf/type_checkers.lua index a00b342..4c26d1a 100644 --- a/protobuf/type_checkers.lua +++ b/protobuf/type_checkers.lua @@ -20,9 +20,9 @@ local type = type local error = error local string = string -module "protobuf.type_checkers" +local type_checkers = {} -function TypeChecker(acceptable_types) +function type_checkers.TypeChecker(acceptable_types) local acceptable_types = acceptable_types return function(proposed_value) local t = type(proposed_value) @@ -33,7 +33,7 @@ function TypeChecker(acceptable_types) end end -function Int32ValueChecker() +function type_checkers.Int32ValueChecker() local _MIN = -2147483648 local _MAX = 2147483647 return function(proposed_value) @@ -47,7 +47,7 @@ function Int32ValueChecker() end end -function Uint32ValueChecker(IntValueChecker) +function type_checkers.Uint32ValueChecker(IntValueChecker) local _MIN = 0 local _MAX = 0xffffffff return function(proposed_value) @@ -61,10 +61,12 @@ function Uint32ValueChecker(IntValueChecker) end end -function UnicodeValueChecker() +function type_checkers.UnicodeValueChecker() return function (proposed_value) if type(proposed_value) ~= 'string' then error(string.format('%s has type %s, but expected one of: string', proposed_value, type(proposed_value))) end end end + +return type_checkers \ No newline at end of file diff --git a/protobuf/wire_format.lua b/protobuf/wire_format.lua index 810ff43..6b04125 100644 --- a/protobuf/wire_format.lua +++ b/protobuf/wire_format.lua @@ -18,15 +18,15 @@ local pb = require "protobuf.pb" -module "protobuf.wire_format" +local wire_format = {} -WIRETYPE_VARINT = 0 -WIRETYPE_FIXED64 = 1 -WIRETYPE_LENGTH_DELIMITED = 2 -WIRETYPE_START_GROUP = 3 -WIRETYPE_END_GROUP = 4 -WIRETYPE_FIXED32 = 5 -_WIRETYPE_MAX = 5 +wire_format.WIRETYPE_VARINT = 0 +wire_format.WIRETYPE_FIXED64 = 1 +wire_format.WIRETYPE_LENGTH_DELIMITED = 2 +wire_format.WIRETYPE_START_GROUP = 3 +wire_format.WIRETYPE_END_GROUP = 4 +wire_format.WIRETYPE_FIXED32 = 5 +wire_format._WIRETYPE_MAX = 5 -- we don't need uint64 local function _VarUInt64ByteSizeNoTag(uint64) @@ -37,94 +37,94 @@ local function _VarUInt64ByteSizeNoTag(uint64) return 5 end -function PackTag(field_number, wire_type) +function wire_format.PackTag(field_number, wire_type) return field_number * 8 + wire_type end -function UnpackTag(tag) +function wire_format.UnpackTag(tag) local wire_type = tag % 8 return (tag - wire_type) / 8, wire_type end -ZigZagEncode32 = pb.zig_zag_encode32 -ZigZagDecode32 = pb.zig_zag_decode32 -ZigZagEncode64 = pb.zig_zag_encode64 -ZigZagDecode64 = pb.zig_zag_decode64 +wire_format.ZigZagEncode32 = pb.zig_zag_encode32 +wire_format.ZigZagDecode32 = pb.zig_zag_decode32 +wire_format.ZigZagEncode64 = pb.zig_zag_encode64 +wire_format.ZigZagDecode64 = pb.zig_zag_decode64 -function Int32ByteSize(field_number, int32) - return Int64ByteSize(field_number, int32) +function wire_format.Int32ByteSize(field_number, int32) + return wire_format.Int64ByteSize(field_number, int32) end -function Int32ByteSizeNoTag(int32) +function wire_format.Int32ByteSizeNoTag(int32) return _VarUInt64ByteSizeNoTag(int32) end -function Int64ByteSize(field_number, int64) - return UInt64ByteSize(field_number, int64) +function wire_format.Int64ByteSize(field_number, int64) + return wire_format.UInt64ByteSize(field_number, int64) end -function UInt32ByteSize(field_number, uint32) - return UInt64ByteSize(field_number, uint32) +function wire_format.UInt32ByteSize(field_number, uint32) + return wire_format.UInt64ByteSize(field_number, uint32) end -function UInt64ByteSize(field_number, uint64) - return TagByteSize(field_number) + _VarUInt64ByteSizeNoTag(uint64) +function wire_format.UInt64ByteSize(field_number, uint64) + return wire_format.TagByteSize(field_number) + _VarUInt64ByteSizeNoTag(uint64) end -function SInt32ByteSize(field_number, int32) - return UInt32ByteSize(field_number, ZigZagEncode(int32)) +function wire_format.SInt32ByteSize(field_number, int32) + return wire_format.UInt32ByteSize(field_number, wire_format.ZigZagEncode(int32)) end -function SInt64ByteSize(field_number, int64) - return UInt64ByteSize(field_number, ZigZagEncode(int64)) +function wire_format.SInt64ByteSize(field_number, int64) + return wire_format.UInt64ByteSize(field_number, wire_format.ZigZagEncode(int64)) end -function Fixed32ByteSize(field_number, fixed32) - return TagByteSize(field_number) + 4 +function wire_format.Fixed32ByteSize(field_number, fixed32) + return wire_format.TagByteSize(field_number) + 4 end -function Fixed64ByteSize(field_number, fixed64) - return TagByteSize(field_number) + 8 +function wire_format.Fixed64ByteSize(field_number, fixed64) + return wire_format.TagByteSize(field_number) + 8 end -function SFixed32ByteSize(field_number, sfixed32) - return TagByteSize(field_number) + 4 +function wire_format.SFixed32ByteSize(field_number, sfixed32) + return wire_format.TagByteSize(field_number) + 4 end -function SFixed64ByteSize(field_number, sfixed64) - return TagByteSize(field_number) + 8 +function wire_format.SFixed64ByteSize(field_number, sfixed64) + return wire_format.TagByteSize(field_number) + 8 end -function FloatByteSize(field_number, flt) - return TagByteSize(field_number) + 4 +function wire_format.FloatByteSize(field_number, flt) + return wire_format.TagByteSize(field_number) + 4 end -function DoubleByteSize(field_number, double) - return TagByteSize(field_number) + 8 +function wire_format.DoubleByteSize(field_number, double) + return wire_format.TagByteSize(field_number) + 8 end -function BoolByteSize(field_number, b) - return TagByteSize(field_number) + 1 +function wire_format.BoolByteSize(field_number, b) + return wire_format.TagByteSize(field_number) + 1 end -function EnumByteSize(field_number, enum) - return UInt32ByteSize(field_number, enum) +function wire_format.EnumByteSize(field_number, enum) + return wire_format.UInt32ByteSize(field_number, enum) end -function StringByteSize(field_number, string) - return BytesByteSize(field_number, string) +function wire_format.StringByteSize(field_number, string) + return wire_format.BytesByteSize(field_number, string) end -function BytesByteSize(field_number, b) - return TagByteSize(field_number) + _VarUInt64ByteSizeNoTag(#b) + #b +function wire_format.BytesByteSize(field_number, b) + return wire_format.TagByteSize(field_number) + _VarUInt64ByteSizeNoTag(#b) + #b end -function MessageByteSize(field_number, message) - return TagByteSize(field_number) + _VarUInt64ByteSizeNoTag(message.ByteSize()) + message.ByteSize() +function wire_format.MessageByteSize(field_number, message) + return wire_format.TagByteSize(field_number) + _VarUInt64ByteSizeNoTag(message.ByteSize()) + message.ByteSize() end -function MessageSetItemByteSize(field_number, msg) - local total_size = 2 * TagByteSize(1) + TagByteSize(2) + TagByteSize(3) +function wire_format.MessageSetItemByteSize(field_number, msg) + local total_size = 2 * wire_format.TagByteSize(1) + wire_format.TagByteSize(2) + wire_format.TagByteSize(3) total_size = total_size + _VarUInt64ByteSizeNoTag(field_number) local message_size = msg.ByteSize() total_size = total_size + _VarUInt64ByteSizeNoTag(message_size) @@ -132,6 +132,8 @@ function MessageSetItemByteSize(field_number, msg) return total_size end -function TagByteSize(field_number) - return _VarUInt64ByteSizeNoTag(PackTag(field_number, 0)) +function wire_format.TagByteSize(field_number) + return _VarUInt64ByteSizeNoTag(wire_format.PackTag(field_number, 0)) end + +return wire_format \ No newline at end of file diff --git a/protoc-plugin/protoc-gen-lua b/protoc-plugin/protoc-gen-lua index b9bdeb3..efa2dd4 100755 --- a/protoc-plugin/protoc-gen-lua +++ b/protoc-plugin/protoc-gen-lua @@ -275,7 +275,7 @@ def code_gen_enum(enum_desc, env): values = [] for i, enum_value in enumerate(enum_desc.value): - values.append(code_gen_enum_item(i, enum_value, env)) + values.append('module.'+code_gen_enum_item(i, enum_value, env)) context('.values = {%s}\n' % ','.join(values)) env.context.append(context.getvalue()) @@ -303,6 +303,8 @@ def code_gen_field(index, field_desc, env): value = field_desc.default_value if field_desc.type == FDP.TYPE_STRING: context('.default_value = \'%s\'\n'%value) + elif field_desc.type == FDP.TYPE_ENUM: + context('.default_value = module.%s\n'%value) else: context('.default_value = %s\n'%value) else: @@ -326,7 +328,7 @@ def code_gen_field(index, field_desc, env): if field_desc.HasField('extendee'): type_name = env.get_ref_name(field_desc.extendee) env.register.append( - '%s.RegisterExtension(%s)\n' % (type_name, obj_name) + 'module.%s.RegisterExtension(module.%s)\n' % (type_name, obj_name) ) context('.type = %d\n' % field_desc.type) @@ -373,7 +375,7 @@ def code_gen_message(message_descriptor, env, containing_type = None): context('.extensions = {%s}\n' % ', '.join(extensions)) if containing_type: - context('.containing_type = %s\n' % containing_type) + context('.containing_type = module.%s\n' % containing_type) env.message.append('module.%s = protobuf.Message(%s)\n' % (full_name, 'module.%s' % obj_name))