diff --git a/errors.go b/errors.go index c4d341c..8220b57 100644 --- a/errors.go +++ b/errors.go @@ -7,7 +7,7 @@ package marshmallow import ( "errors" "fmt" - "github.com/mailru/easyjson/jlexer" + jlexer "github.com/perimeterx/marshmallow/internal/lexer" "reflect" "strings" ) diff --git a/example_test.go b/example_test.go index 8741988..11e9940 100644 --- a/example_test.go +++ b/example_test.go @@ -162,7 +162,7 @@ func ExampleUnmarshalErrorHandling() { // Output: // ModeFailOnFirstError and valid value: v={Foo:bar Boo:[1 2 3]}, result=map[boo:[1 2 3] foo:bar], err= - // ModeFailOnFirstError and invalid value: result=map[], err=*jlexer.LexerError + // ModeFailOnFirstError and invalid value: result=map[], err=*lexer.LexerError // ModeAllowMultipleErrors and valid value: v={Foo:bar Boo:[1 2 3]}, result=map[boo:[1 2 3] foo:bar], err= // ModeAllowMultipleErrors and invalid value: result=map[boo:[1 2 3]], err=*marshmallow.MultipleLexerError // ModeFailOverToOriginalValue and valid value: v={Foo:bar Boo:[1 2 3]}, result=map[boo:[1 2 3] foo:bar], err= diff --git a/go.mod b/go.mod index 6a80ba2..b99c03d 100644 --- a/go.mod +++ b/go.mod @@ -4,8 +4,5 @@ go 1.17 require ( github.com/go-test/deep v1.0.8 - github.com/mailru/easyjson v0.7.7 github.com/ugorji/go/codec v1.2.7 ) - -require github.com/josharian/intern v1.0.0 // indirect diff --git a/go.sum b/go.sum index 2b41360..85e77d0 100644 --- a/go.sum +++ b/go.sum @@ -1,9 +1,5 @@ github.com/go-test/deep v1.0.8 h1:TDsG77qcSprGbC6vTN8OuXp5g+J+b5Pcguhf7Zt61VM= github.com/go-test/deep v1.0.8/go.mod h1:5C2ZWiW0ErCdrYzpqxLbTX7MG14M9iiw8DgHncVwcsE= -github.com/josharian/intern v1.0.0 h1:vlS4z54oSdjm0bgjRigI+G1HpF+tI+9rE5LLzOg8HmY= -github.com/josharian/intern v1.0.0/go.mod h1:5DoeVV0s6jJacbCEi61lwdGj/aVlrQvzHFFd8Hwg//Y= -github.com/mailru/easyjson v0.7.7 h1:UGYAvKxe3sBsEDzO8ZeWOSlIQfWFlxbzLZe7hwFURr0= -github.com/mailru/easyjson v0.7.7/go.mod h1:xzfreul335JAWq5oZzymOObrkdz5UnU4kGfJJLY9Nlc= github.com/ugorji/go v1.2.7 h1:qYhyWUUd6WbiM+C6JZAUkIJt/1WrjzNHY9+KCIjVqTo= github.com/ugorji/go v1.2.7/go.mod h1:nF9osbDWLy6bDVv/Rtoh6QgnvNDpmCalQV5urGCCS6M= github.com/ugorji/go/codec v1.2.7 h1:YPXUKf7fYbp/y8xloBqZOw2qaVggbfwMlI8WM3wZUJ0= diff --git a/internal/lexer/lexer.go b/internal/lexer/lexer.go new file mode 100644 index 0000000..5a86eca --- /dev/null +++ b/internal/lexer/lexer.go @@ -0,0 +1,765 @@ +// Copyright 2026 PerimeterX. All rights reserved. +// Use of this source code is governed by a MIT style +// license that can be found in the LICENSE file. + +// Package lexer provides a JSON lexer adapted from github.com/mailru/easyjson/jlexer v0.7.7. +package lexer + +import ( + "bytes" + "encoding/json" + "errors" + "fmt" + "io" + "strconv" + "unicode" + "unicode/utf16" + "unicode/utf8" +) + +// LexerError represents a parse error encountered during JSON lexing. +type LexerError struct { + Reason string + Offset int + Data string +} + +func (l *LexerError) Error() string { + return fmt.Sprintf("parse error: %s near offset %d of '%s'", l.Reason, l.Offset, l.Data) +} + +// tokenKind determines type of a token. +type tokenKind byte + +const ( + tokenUndef tokenKind = iota // No token. + tokenDelim // Delimiter: one of '{', '}', '[' or ']'. + tokenString // A string literal, e.g. "abcሴ" + tokenNumber // Number literal, e.g. 1.5e5 + tokenBool // Boolean literal: true or false. + tokenNull // null keyword. +) + +// token describes a single token: type, position in the input and value. +type token struct { + kind tokenKind // Type of a token. + + boolValue bool // Value if a boolean literal token. + byteValueCloned bool // true if byteValue was allocated and does not refer to original json body + byteValue []byte // Raw value of a token. + delimValue byte +} + +// Lexer is a JSON lexer: it iterates over JSON tokens in a byte slice. +type Lexer struct { + Data []byte // Input data given to the lexer. + + start int // Start of the current token. + pos int // Current unscanned position in the input stream. + tok token // Last scanned token, if tok.kind != tokenUndef. + firstElement bool // Whether current element is the first in array or an object. + wantSep byte // A comma or a colon character, which need to occur before a token. + + UseMultipleErrors bool // If we want to use multiple errors. + fatalError error // Fatal error occurred during lexing. It is usually a syntax error. + multipleErrors []*LexerError // Semantic errors occurred during lexing. Marshalling will be continued after finding these errors. +} + +// fetchToken scans the input for the next token. +func (r *Lexer) fetchToken() { + r.tok.kind = tokenUndef + r.start = r.pos + + // Check if r.Data has r.pos element. + // If it doesn't, it means corrupted input data. + if len(r.Data) < r.pos { + r.errParse("Unexpected end of data") + return + } + // Determine the type of a token by skipping whitespace and reading the + // first character. + for _, c := range r.Data[r.pos:] { + switch c { + case ':', ',': + if r.wantSep == c { + r.pos++ + r.start++ + r.wantSep = 0 + } else { + r.errSyntax() + } + case ' ', '\t', '\r', '\n': + r.pos++ + r.start++ + case '"': + if r.wantSep != 0 { + r.errSyntax() + } + r.tok.kind = tokenString + r.fetchString() + return + case '{', '[': + if r.wantSep != 0 { + r.errSyntax() + } + r.firstElement = true + r.tok.kind = tokenDelim + r.tok.delimValue = r.Data[r.pos] + r.pos++ + return + case '}', ']': + if !r.firstElement && r.wantSep != ',' { + r.errSyntax() + } + r.wantSep = 0 + r.tok.kind = tokenDelim + r.tok.delimValue = r.Data[r.pos] + r.pos++ + return + case '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '-': + if r.wantSep != 0 { + r.errSyntax() + } + r.tok.kind = tokenNumber + r.fetchNumber() + return + case 'n': + if r.wantSep != 0 { + r.errSyntax() + } + r.tok.kind = tokenNull + r.fetchNull() + return + case 't': + if r.wantSep != 0 { + r.errSyntax() + } + r.tok.kind = tokenBool + r.tok.boolValue = true + r.fetchTrue() + return + case 'f': + if r.wantSep != 0 { + r.errSyntax() + } + r.tok.kind = tokenBool + r.tok.boolValue = false + r.fetchFalse() + return + default: + r.errSyntax() + return + } + } + r.fatalError = io.EOF +} + +// isTokenEnd returns true if the char can follow a non-delimiter token. +func isTokenEnd(c byte) bool { + return c == ' ' || c == '\t' || c == '\r' || c == '\n' || + c == '[' || c == ']' || c == '{' || c == '}' || c == ',' || c == ':' +} + +// fetchNull fetches and checks remaining bytes of null keyword. +func (r *Lexer) fetchNull() { + r.pos += 4 + if r.pos > len(r.Data) || + r.Data[r.pos-3] != 'u' || + r.Data[r.pos-2] != 'l' || + r.Data[r.pos-1] != 'l' || + (r.pos != len(r.Data) && !isTokenEnd(r.Data[r.pos])) { + r.pos -= 4 + r.errSyntax() + } +} + +// fetchTrue fetches and checks remaining bytes of true keyword. +func (r *Lexer) fetchTrue() { + r.pos += 4 + if r.pos > len(r.Data) || + r.Data[r.pos-3] != 'r' || + r.Data[r.pos-2] != 'u' || + r.Data[r.pos-1] != 'e' || + (r.pos != len(r.Data) && !isTokenEnd(r.Data[r.pos])) { + r.pos -= 4 + r.errSyntax() + } +} + +// fetchFalse fetches and checks remaining bytes of false keyword. +func (r *Lexer) fetchFalse() { + r.pos += 5 + if r.pos > len(r.Data) || + r.Data[r.pos-4] != 'a' || + r.Data[r.pos-3] != 'l' || + r.Data[r.pos-2] != 's' || + r.Data[r.pos-1] != 'e' || + (r.pos != len(r.Data) && !isTokenEnd(r.Data[r.pos])) { + r.pos -= 5 + r.errSyntax() + } +} + +// fetchNumber scans a number literal token. +func (r *Lexer) fetchNumber() { + hasE := false + afterE := false + hasDot := false + + r.pos++ + for i, c := range r.Data[r.pos:] { + switch { + case c >= '0' && c <= '9': + afterE = false + case c == '.' && !hasDot: + hasDot = true + case (c == 'e' || c == 'E') && !hasE: + hasE = true + hasDot = true + afterE = true + case (c == '+' || c == '-') && afterE: + afterE = false + default: + r.pos += i + if !isTokenEnd(c) { + r.errSyntax() + } else { + r.tok.byteValue = r.Data[r.start:r.pos] + } + return + } + } + r.pos = len(r.Data) + r.tok.byteValue = r.Data[r.start:] +} + +// findStringLen tries to scan into the string literal for ending quote char to determine required size. +// The size will be exact if no escapes are present and may be inexact if there are escaped chars. +func findStringLen(data []byte) (isValid bool, length int) { + for { + idx := bytes.IndexByte(data, '"') + if idx == -1 { + return false, len(data) + } + if idx == 0 || data[idx-1] != '\\' { + return true, length + idx + } + cnt := 1 + for idx-cnt-1 >= 0 && data[idx-cnt-1] == '\\' { + cnt++ + } + if cnt%2 == 0 { + return true, length + idx + } + length += idx + 1 + data = data[idx+1:] + } +} + +// getUnicode4 decodes \uXXXX from the beginning of s, returning the hex value, +// or it returns -1. +func getUnicode4(s []byte) rune { + if len(s) < 6 || s[0] != '\\' || s[1] != 'u' { + return -1 + } + var val rune + for i := 2; i < 6; i++ { + c := s[i] + var v byte + switch { + case c >= '0' && c <= '9': + v = c - '0' + case c >= 'a' && c <= 'f': + v = c - 'a' + 10 + case c >= 'A' && c <= 'F': + v = c - 'A' + 10 + default: + return -1 + } + val <<= 4 + val |= rune(v) + } + return val +} + +// decodeEscape processes a single escape sequence and returns number of bytes processed. +func decodeEscape(data []byte) (decoded rune, bytesProcessed int, err error) { + if len(data) < 2 { + return 0, 0, errors.New("incorrect escape symbol \\ at the end of token") + } + switch data[1] { + case '"', '/', '\\': + return rune(data[1]), 2, nil + case 'b': + return '\b', 2, nil + case 'f': + return '\f', 2, nil + case 'n': + return '\n', 2, nil + case 'r': + return '\r', 2, nil + case 't': + return '\t', 2, nil + case 'u': + rr := getUnicode4(data) + if rr < 0 { + return 0, 0, errors.New("incorrectly escaped \\uXXXX sequence") + } + read := 6 + if utf16.IsSurrogate(rr) { + rr1 := getUnicode4(data[read:]) + if dec := utf16.DecodeRune(rr, rr1); dec != unicode.ReplacementChar { + read += 6 + rr = dec + } else { + rr = unicode.ReplacementChar + } + } + return rr, read, nil + } + return 0, 0, errors.New("incorrectly escaped bytes") +} + +// fetchString scans a string literal token. +func (r *Lexer) fetchString() { + r.pos++ + data := r.Data[r.pos:] + isValid, length := findStringLen(data) + if !isValid { + r.pos += length + r.errParse("unterminated string literal") + return + } + r.tok.byteValue = data[:length] + r.pos += length + 1 +} + +// unescapeStringToken performs unescaping of string token. +// If no escaping is needed, the original string is returned, otherwise a new one is allocated. +func (r *Lexer) unescapeStringToken() error { + data := r.tok.byteValue + var unescapedData []byte + + for { + i := bytes.IndexByte(data, '\\') + if i == -1 { + break + } + escapedRune, escapedBytes, err := decodeEscape(data[i:]) + if err != nil { + r.errParse(err.Error()) + return err + } + if unescapedData == nil { + unescapedData = make([]byte, 0, len(r.tok.byteValue)) + } + var d [4]byte + s := utf8.EncodeRune(d[:], escapedRune) + unescapedData = append(unescapedData, data[:i]...) + unescapedData = append(unescapedData, d[:s]...) + data = data[i+escapedBytes:] + } + if unescapedData != nil { + r.tok.byteValue = append(unescapedData, data...) + r.tok.byteValueCloned = true + } + return nil +} + +// scanToken scans the next token if no token is currently available in the lexer. +func (r *Lexer) scanToken() { + if r.tok.kind != tokenUndef || r.fatalError != nil { + return + } + r.fetchToken() +} + +// consume resets the current token to allow scanning the next one. +func (r *Lexer) consume() { + r.tok.kind = tokenUndef + r.tok.byteValueCloned = false + r.tok.delimValue = 0 +} + +// Ok returns true if no error (including io.EOF) was encountered during scanning. +func (r *Lexer) Ok() bool { + return r.fatalError == nil +} + +const maxErrorContextLen = 13 + +func (r *Lexer) errParse(what string) { + if r.fatalError != nil { + return + } + var str string + if len(r.Data)-r.pos <= maxErrorContextLen { + str = string(r.Data) + } else { + str = string(r.Data[r.pos:r.pos+maxErrorContextLen-3]) + "..." + } + r.fatalError = &LexerError{ + Reason: what, + Offset: r.pos, + Data: str, + } +} + +func (r *Lexer) errSyntax() { + r.errParse("syntax error") +} + +func (r *Lexer) errInvalidToken(expected string) { + if r.fatalError != nil { + return + } + if r.UseMultipleErrors { + r.pos = r.start + r.consume() + r.SkipRecursive() + switch expected { + case "[": + r.tok.delimValue = ']' + r.tok.kind = tokenDelim + case "{": + r.tok.delimValue = '}' + r.tok.kind = tokenDelim + } + r.addNonfatalError(&LexerError{ + Reason: "expected " + expected, + Offset: r.start, + Data: string(r.Data[r.start:r.pos]), + }) + return + } + var str string + if len(r.tok.byteValue) <= maxErrorContextLen { + str = string(r.tok.byteValue) + } else { + str = string(r.tok.byteValue[:maxErrorContextLen-3]) + "..." + } + r.fatalError = &LexerError{ + Reason: "expected " + expected, + Offset: r.pos, + Data: str, + } +} + +// Delim consumes a token and verifies that it is the given delimiter. +func (r *Lexer) Delim(c byte) { + if r.tok.kind == tokenUndef && r.Ok() { + r.fetchToken() + } + if !r.Ok() || r.tok.delimValue != c { + r.consume() // errInvalidToken can change token if UseMultipleErrors is enabled. + r.errInvalidToken(string([]byte{c})) + } else { + r.consume() + } +} + +// IsDelim returns true if there was no scanning error and next token is the given delimiter. +func (r *Lexer) IsDelim(c byte) bool { + if r.tok.kind == tokenUndef && r.Ok() { + r.fetchToken() + } + return !r.Ok() || r.tok.delimValue == c +} + +// Null verifies that the next token is null and consumes it. +func (r *Lexer) Null() { + if r.tok.kind == tokenUndef && r.Ok() { + r.fetchToken() + } + if !r.Ok() || r.tok.kind != tokenNull { + r.errInvalidToken("null") + } + r.consume() +} + +// IsNull returns true if the next token is a null keyword. +func (r *Lexer) IsNull() bool { + if r.tok.kind == tokenUndef && r.Ok() { + r.fetchToken() + } + return r.Ok() && r.tok.kind == tokenNull +} + +// Skip skips a single token. +func (r *Lexer) Skip() { + if r.tok.kind == tokenUndef && r.Ok() { + r.fetchToken() + } + r.consume() +} + +// SkipRecursive skips next array or object completely, or just skips a single token if not +// an array/object. +// +// Note: no syntax validation is performed on the skipped data. +func (r *Lexer) SkipRecursive() { + r.scanToken() + var start, end byte + startPos := r.start + + switch r.tok.delimValue { + case '{': + start, end = '{', '}' + case '[': + start, end = '[', ']' + default: + r.consume() + return + } + r.consume() + + level := 1 + inQuotes := false + wasEscape := false + + for i, c := range r.Data[r.pos:] { + switch { + case c == start && !inQuotes: + level++ + case c == end && !inQuotes: + level-- + if level == 0 { + r.pos += i + 1 + if !json.Valid(r.Data[startPos:r.pos]) { + r.pos = len(r.Data) + r.fatalError = &LexerError{ + Reason: "skipped array/object json value is invalid", + Offset: r.pos, + Data: string(r.Data[r.pos:]), + } + } + return + } + case c == '\\' && inQuotes: + wasEscape = !wasEscape + continue + case c == '"' && inQuotes: + inQuotes = wasEscape + case c == '"': + inQuotes = true + } + wasEscape = false + } + r.pos = len(r.Data) + r.fatalError = &LexerError{ + Reason: "EOF reached while skipping array/object or token", + Offset: r.pos, + Data: string(r.Data[r.pos:]), + } +} + +// Raw fetches the next item recursively as a data slice. +func (r *Lexer) Raw() []byte { + r.SkipRecursive() + if !r.Ok() { + return nil + } + return r.Data[r.start:r.pos] +} + +// Consumed reads all remaining bytes from the input, publishing an error if +// there is anything but whitespace remaining. +func (r *Lexer) Consumed() { + if r.pos > len(r.Data) || !r.Ok() { + return + } + for _, c := range r.Data[r.pos:] { + if c != ' ' && c != '\t' && c != '\r' && c != '\n' { + r.AddError(&LexerError{ + Reason: "invalid character '" + string(c) + "' after top-level value", + Offset: r.pos, + Data: string(r.Data[r.pos:]), + }) + return + } + r.pos++ + r.start++ + } +} + +func (r *Lexer) unsafeString(skipUnescape bool) string { + if r.tok.kind == tokenUndef && r.Ok() { + r.fetchToken() + } + if !r.Ok() || r.tok.kind != tokenString { + r.errInvalidToken("string") + return "" + } + if !skipUnescape { + if err := r.unescapeStringToken(); err != nil { + r.errInvalidToken("string") + return "" + } + } + ret := string(r.tok.byteValue) + r.consume() + return ret +} + +// UnsafeFieldName returns current member name string token. +func (r *Lexer) UnsafeFieldName(skipUnescape bool) string { + return r.unsafeString(skipUnescape) +} + +// String reads a string literal. +func (r *Lexer) String() string { + if r.tok.kind == tokenUndef && r.Ok() { + r.fetchToken() + } + if !r.Ok() || r.tok.kind != tokenString { + r.errInvalidToken("string") + return "" + } + if err := r.unescapeStringToken(); err != nil { + r.errInvalidToken("string") + return "" + } + ret := string(r.tok.byteValue) + r.consume() + return ret +} + +// Bool reads a true or false boolean keyword. +func (r *Lexer) Bool() bool { + if r.tok.kind == tokenUndef && r.Ok() { + r.fetchToken() + } + if !r.Ok() || r.tok.kind != tokenBool { + r.errInvalidToken("bool") + return false + } + ret := r.tok.boolValue + r.consume() + return ret +} + +func (r *Lexer) number() string { + if r.tok.kind == tokenUndef && r.Ok() { + r.fetchToken() + } + if !r.Ok() || r.tok.kind != tokenNumber { + r.errInvalidToken("number") + return "" + } + ret := string(r.tok.byteValue) + r.consume() + return ret +} + +func (r *Lexer) Float64() float64 { + s := r.number() + if !r.Ok() { + return 0 + } + n, err := strconv.ParseFloat(s, 64) + if err != nil { + r.addNonfatalError(&LexerError{ + Offset: r.start, + Reason: err.Error(), + Data: s, + }) + } + return n +} + +// Interface fetches an interface{} analogous to the 'encoding/json' package. +func (r *Lexer) Interface() interface{} { + if r.tok.kind == tokenUndef && r.Ok() { + r.fetchToken() + } + if !r.Ok() { + return nil + } + switch r.tok.kind { + case tokenString: + return r.String() + case tokenNumber: + return r.Float64() + case tokenBool: + return r.Bool() + case tokenNull: + r.Null() + return nil + } + if r.tok.delimValue == '{' { + r.consume() + ret := map[string]interface{}{} + for !r.IsDelim('}') { + key := r.String() + r.WantColon() + ret[key] = r.Interface() + r.WantComma() + } + r.Delim('}') + if r.Ok() { + return ret + } + return nil + } + if r.tok.delimValue == '[' { + r.consume() + ret := []interface{}{} + for !r.IsDelim(']') { + ret = append(ret, r.Interface()) + r.WantComma() + } + r.Delim(']') + if r.Ok() { + return ret + } + return nil + } + r.errSyntax() + return nil +} + +// WantComma requires a comma to be present before fetching next token. +func (r *Lexer) WantComma() { + r.wantSep = ',' + r.firstElement = false +} + +// WantColon requires a colon to be present before fetching next token. +func (r *Lexer) WantColon() { + r.wantSep = ':' + r.firstElement = false +} + +func (r *Lexer) Error() error { + return r.fatalError +} + +func (r *Lexer) AddError(e error) { + if r.fatalError == nil { + r.fatalError = e + } +} + +func (r *Lexer) AddNonFatalError(e error) { + r.addNonfatalError(&LexerError{ + Offset: r.start, + Data: string(r.Data[r.start:r.pos]), + Reason: e.Error(), + }) +} + +func (r *Lexer) addNonfatalError(err *LexerError) { + if r.UseMultipleErrors { + // We don't want to add errors with the same offset. + if len(r.multipleErrors) != 0 && r.multipleErrors[len(r.multipleErrors)-1].Offset == err.Offset { + return + } + r.multipleErrors = append(r.multipleErrors, err) + return + } + r.fatalError = err +} + +func (r *Lexer) GetNonFatalErrors() []*LexerError { + return r.multipleErrors +} diff --git a/internal/lexer/lexer_test.go b/internal/lexer/lexer_test.go new file mode 100644 index 0000000..a680670 --- /dev/null +++ b/internal/lexer/lexer_test.go @@ -0,0 +1,470 @@ +// Copyright 2026 PerimeterX. All rights reserved. +// Use of this source code is governed by a MIT style +// license that can be found in the LICENSE file. + +package lexer + +import ( + "reflect" + "testing" +) + +func TestLexerError(t *testing.T) { + e := &LexerError{Reason: "syntax error", Offset: 5, Data: "hello"} + want := "parse error: syntax error near offset 5 of 'hello'" + if got := e.Error(); got != want { + t.Errorf("LexerError.Error() = %q; want %q", got, want) + } +} + +func TestString(t *testing.T) { + for i, tt := range []struct { + toParse string + want string + wantError bool + }{ + {toParse: `"simple string"`, want: "simple string"}, + {toParse: " \r\r\n\t " + `"test"`, want: "test"}, + {toParse: `"\n\t\"\/\\\f\r"`, want: "\n\t\"/\\\f\r"}, + {toParse: `" "`, want: " "}, + {toParse: `" -\t"`, want: " -\t"}, + {toParse: `"��"`, want: "��"}, + {toParse: `"😀"`, want: "😀"}, + {toParse: `"😈"`, want: "😈"}, + {toParse: `"\ud8"`, wantError: true}, + {toParse: `"test"junk`, want: "test"}, + {toParse: `5`, wantError: true}, // not a string + {toParse: `"\x"`, wantError: true}, // invalid escape + {toParse: `"\ud800"`, want: "�"}, // lone surrogate → Unicode replacement char + } { + l := Lexer{Data: []byte(tt.toParse)} + got := l.String() + if got != tt.want { + t.Errorf("[%d, %q] String() = %v; want %v", i, tt.toParse, got, tt.want) + } + if tt.wantError && l.Ok() { + t.Errorf("[%d, %q] String() ok; want error", i, tt.toParse) + } + if !tt.wantError && !l.Ok() { + t.Errorf("[%d, %q] String() error: %v", i, tt.toParse, l.Error()) + } + } +} + +func TestUnsafeFieldName(t *testing.T) { + for i, tt := range []struct { + toParse string + skipUnescape bool + want string + wantError bool + }{ + {toParse: `"field"`, want: "field"}, + {toParse: `"field\nname"`, want: "field\nname"}, + {toParse: `"field\nname"`, skipUnescape: true, want: `field\nname`}, + {toParse: `"A"`, want: "A"}, + {toParse: `123`, wantError: true}, + {toParse: `true`, wantError: true}, + } { + l := Lexer{Data: []byte(tt.toParse)} + got := l.UnsafeFieldName(tt.skipUnescape) + if !tt.wantError && got != tt.want { + t.Errorf("[%d, %q] UnsafeFieldName() = %q; want %q", i, tt.toParse, got, tt.want) + } + if tt.wantError && l.Ok() { + t.Errorf("[%d, %q] UnsafeFieldName() ok; want error", i, tt.toParse) + } + if !tt.wantError && !l.Ok() { + t.Errorf("[%d, %q] UnsafeFieldName() error: %v", i, tt.toParse, l.Error()) + } + } +} + +func TestNumber(t *testing.T) { + for i, tt := range []struct { + toParse string + want string + wantError bool + }{ + {toParse: "123", want: "123"}, + {toParse: "-123", want: "-123"}, + {toParse: "\r\n12.35", want: "12.35"}, + {toParse: "12.35e+1", want: "12.35e+1"}, + {toParse: "12.35e-15", want: "12.35e-15"}, + {toParse: "12.35E-15", want: "12.35E-15"}, + {toParse: "12.35E15", want: "12.35E15"}, + {toParse: `"a"`, wantError: true}, + {toParse: "123junk", wantError: true}, + {toParse: "1.2.3", wantError: true}, + {toParse: "1e2e3", wantError: true}, + {toParse: "1e2.3", wantError: true}, + } { + l := Lexer{Data: []byte(tt.toParse)} + got := l.number() + if got != tt.want { + t.Errorf("[%d, %q] number() = %v; want %v", i, tt.toParse, got, tt.want) + } + if tt.wantError && l.Ok() { + t.Errorf("[%d, %q] number() ok; want error", i, tt.toParse) + } + if !tt.wantError && !l.Ok() { + t.Errorf("[%d, %q] number() error: %v", i, tt.toParse, l.Error()) + } + } +} + +func TestBool(t *testing.T) { + for i, tt := range []struct { + toParse string + want bool + wantError bool + }{ + {toParse: "true", want: true}, + {toParse: "false", want: false}, + {toParse: "1", wantError: true}, + {toParse: "truejunk", wantError: true}, + {toParse: `false"junk"`, wantError: true}, + {toParse: "True", wantError: true}, + {toParse: "False", wantError: true}, + } { + l := Lexer{Data: []byte(tt.toParse)} + got := l.Bool() + if got != tt.want { + t.Errorf("[%d, %q] Bool() = %v; want %v", i, tt.toParse, got, tt.want) + } + if tt.wantError && l.Ok() { + t.Errorf("[%d, %q] Bool() ok; want error", i, tt.toParse) + } + if !tt.wantError && !l.Ok() { + t.Errorf("[%d, %q] Bool() error: %v", i, tt.toParse, l.Error()) + } + } +} + +func TestSkipRecursive(t *testing.T) { + for i, tt := range []struct { + toParse string + left string + wantError bool + }{ + {toParse: "5, 4", left: ", 4"}, + {toParse: "[5, 6], 4", left: ", 4"}, + {toParse: "[5, [7,8]]: 4", left: ": 4"}, + {toParse: `{"a":1}, 4`, left: ", 4"}, + {toParse: `{"a":1, "b":{"c": 5}, "e":[12,15]}, 4`, left: ", 4"}, + // array start/end chars in a string + {toParse: `[5, "]"], 4`, left: ", 4"}, + {toParse: `[5, "\"]"], 4`, left: ", 4"}, + {toParse: `[5, "["], 4`, left: ", 4"}, + {toParse: `[5, "\"["], 4`, left: ", 4"}, + // object start/end chars in a string + {toParse: `{"a}":1}, 4`, left: ", 4"}, + {toParse: `{"a\"}":1}, 4`, left: ", 4"}, + {toParse: `{"a{":1}, 4`, left: ", 4"}, + {toParse: `{"a\"{":1}, 4`, left: ", 4"}, + // object with double slashes at end of string + {toParse: `{"a":"hey\\"}, 4`, left: ", 4"}, + // invalid JSON inside nested structure + {toParse: `{"a": [ ##invalid json## ]}, 4`, wantError: true}, + {toParse: `{"a": [ [1], [ ##invalid json## ]]}, 4`, wantError: true}, + } { + l := Lexer{Data: []byte(tt.toParse)} + l.SkipRecursive() + got := string(l.Data[l.pos:]) + if got != tt.left { + t.Errorf("[%d, %q] SkipRecursive() left = %v; want %v", i, tt.toParse, got, tt.left) + } + if tt.wantError && l.Ok() { + t.Errorf("[%d, %q] SkipRecursive() ok; want error", i, tt.toParse) + } + if !tt.wantError && !l.Ok() { + t.Errorf("[%d, %q] SkipRecursive() error: %v", i, tt.toParse, l.Error()) + } + } +} + +func TestInterface(t *testing.T) { + for i, tt := range []struct { + toParse string + want interface{} + wantError bool + }{ + {toParse: "null", want: nil}, + {toParse: "true", want: true}, + {toParse: `"a"`, want: "a"}, + {toParse: "5", want: float64(5)}, + {toParse: `{}`, want: map[string]interface{}{}}, + {toParse: `[]`, want: []interface{}{}}, + {toParse: `{"a": "b"}`, want: map[string]interface{}{"a": "b"}}, + {toParse: `[5]`, want: []interface{}{float64(5)}}, + {toParse: `{"a":5 , "b" : "string"}`, want: map[string]interface{}{"a": float64(5), "b": "string"}}, + {toParse: `["a", 5 , null, true]`, want: []interface{}{"a", float64(5), nil, true}}, + {toParse: `{"a" "b"}`, wantError: true}, + {toParse: `{"a": "b",}`, wantError: true}, + {toParse: `{"a":"b","c" "b"}`, wantError: true}, + {toParse: `{"a": "b","c":"d",}`, wantError: true}, + {toParse: `{,}`, wantError: true}, + {toParse: `[1, 2,]`, wantError: true}, + {toParse: `[1 2]`, wantError: true}, + {toParse: `[,]`, wantError: true}, + } { + l := Lexer{Data: []byte(tt.toParse)} + got := l.Interface() + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("[%d, %q] Interface() = %v; want %v", i, tt.toParse, got, tt.want) + } + if tt.wantError && l.Ok() { + t.Errorf("[%d, %q] Interface() ok; want error", i, tt.toParse) + } + if !tt.wantError && !l.Ok() { + t.Errorf("[%d, %q] Interface() error: %v", i, tt.toParse, l.Error()) + } + } +} + +func TestConsumed(t *testing.T) { + for i, tt := range []struct { + toParse string + wantError bool + }{ + {toParse: "", wantError: false}, + {toParse: " ", wantError: false}, + {toParse: "\r\n", wantError: false}, + {toParse: "\t\t", wantError: false}, + {toParse: "{", wantError: true}, + } { + l := Lexer{Data: []byte(tt.toParse)} + l.Consumed() + if tt.wantError && l.Ok() { + t.Errorf("[%d, %q] Consumed() ok; want error", i, tt.toParse) + } + if !tt.wantError && !l.Ok() { + t.Errorf("[%d, %q] Consumed() error: %v", i, tt.toParse, l.Error()) + } + } +} + +func TestFetchStringUnterminated(t *testing.T) { + for _, tt := range []struct { + data []byte + }{ + {data: []byte(`"string without trailing quote`)}, + {data: []byte(`"\"`)}, + {data: []byte{'"'}}, + } { + l := Lexer{Data: tt.data} + l.fetchString() + if l.pos > len(l.Data) { + t.Errorf("fetchString(%s): pos=%v must not exceed len(Data)=%v", tt.data, l.pos, len(l.Data)) + } + if l.Error() == nil { + t.Errorf("fetchString(%s): expected parse error, got none", tt.data) + } + } +} + +func TestIsNull(t *testing.T) { + for i, tt := range []struct { + toParse string + wantNull bool + }{ + {toParse: "null", wantNull: true}, + {toParse: " null", wantNull: true}, + {toParse: `"null"`, wantNull: false}, + {toParse: "true", wantNull: false}, + {toParse: "{}", wantNull: false}, + } { + l := Lexer{Data: []byte(tt.toParse)} + got := l.IsNull() + if got != tt.wantNull { + t.Errorf("[%d, %q] IsNull() = %v; want %v", i, tt.toParse, got, tt.wantNull) + } + } +} + +func TestSkip(t *testing.T) { + for i, tt := range []struct { + toParse string + }{ + {toParse: "null"}, + {toParse: "true"}, + {toParse: "false"}, + {toParse: "42"}, + {toParse: `"str"`}, + {toParse: `{}`}, + {toParse: `[]`}, + } { + l := Lexer{Data: []byte(tt.toParse)} + l.Skip() + if !l.Ok() { + t.Errorf("[%d, %q] Skip() error: %v", i, tt.toParse, l.Error()) + } + } +} + +func TestIsDelim(t *testing.T) { + for i, tt := range []struct { + toParse string + delim byte + want bool + }{ + {toParse: "{}", delim: '{', want: true}, + {toParse: "[]", delim: '[', want: true}, + {toParse: "{}", delim: '[', want: false}, + {toParse: "42", delim: '{', want: false}, + {toParse: `"str"`, delim: '{', want: false}, + } { + l := Lexer{Data: []byte(tt.toParse)} + got := l.IsDelim(tt.delim) + if got != tt.want { + t.Errorf("[%d, %q] IsDelim(%c) = %v; want %v", i, tt.toParse, tt.delim, got, tt.want) + } + } +} + +func TestDelim(t *testing.T) { + t.Run("correct delimiter consumed without error", func(t *testing.T) { + l := Lexer{Data: []byte(`{"key":"val"}`)} + l.Delim('{') + if !l.Ok() { + t.Errorf("Delim('{'): unexpected error: %v", l.Error()) + } + }) + t.Run("wrong delimiter sets error", func(t *testing.T) { + l := Lexer{Data: []byte(`{"key":"val"}`)} + l.Delim('[') + if l.Ok() { + t.Error("Delim('['): expected error on '{' input, got none") + } + }) +} + +func TestRaw(t *testing.T) { + for i, tt := range []struct { + toParse string + want string + }{ + {toParse: `"hello"`, want: `"hello"`}, + {toParse: `42`, want: `42`}, + {toParse: `true`, want: `true`}, + {toParse: `null`, want: `null`}, + {toParse: `{"a":1}`, want: `{"a":1}`}, + {toParse: `[1,2,3]`, want: `[1,2,3]`}, + {toParse: `{"a":[1,{"b":2}]}`, want: `{"a":[1,{"b":2}]}`}, + } { + l := Lexer{Data: []byte(tt.toParse)} + got := l.Raw() + if string(got) != tt.want { + t.Errorf("[%d, %q] Raw() = %q; want %q", i, tt.toParse, got, tt.want) + } + if !l.Ok() { + t.Errorf("[%d, %q] Raw() error: %v", i, tt.toParse, l.Error()) + } + } +} + +func TestNonFatalErrors(t *testing.T) { + // With UseMultipleErrors=true, non-fatal errors are collected without halting. + // Use different offsets so de-duplication does not apply. + l := Lexer{Data: []byte(`{"a":1,"b":2}`), UseMultipleErrors: true} + l.addNonfatalError(&LexerError{Reason: "first", Offset: 0}) + l.addNonfatalError(&LexerError{Reason: "second", Offset: 5}) + + errs := l.GetNonFatalErrors() + if len(errs) != 2 { + t.Fatalf("GetNonFatalErrors() = %d errors; want 2", len(errs)) + } + if errs[0].Reason != "first" || errs[1].Reason != "second" { + t.Errorf("unexpected reasons: %q, %q", errs[0].Reason, errs[1].Reason) + } + // Non-fatal errors must not set fatalError. + if !l.Ok() { + t.Error("Ok() = false; want true when only non-fatal errors present") + } +} + +func TestNonFatalErrorsFallbackWhenDisabled(t *testing.T) { + // With UseMultipleErrors=false, a non-fatal error becomes the fatal error. + l := Lexer{Data: []byte(`{}`), UseMultipleErrors: false} + l.AddNonFatalError(&LexerError{Reason: "test"}) + if l.Ok() { + t.Error("Ok() = true; want false") + } + if l.Error() == nil { + t.Error("Error() = nil; want non-nil") + } +} + +func TestAddError(t *testing.T) { + l := Lexer{Data: []byte(`{}`)} + l.AddError(&LexerError{Reason: "fatal"}) + if l.Ok() { + t.Error("Ok() = true after AddError; want false") + } + // Second AddError must not overwrite the first. + l.AddError(&LexerError{Reason: "second"}) + lexErr, ok := l.Error().(*LexerError) + if !ok || lexErr.Reason != "fatal" { + t.Errorf("first error should be preserved; got %v", l.Error()) + } +} + +func TestDeduplicateNonFatalErrors(t *testing.T) { + // Two errors at the same offset must be stored only once. + l := Lexer{Data: []byte(`{}`), UseMultipleErrors: true} + l.addNonfatalError(&LexerError{Reason: "first", Offset: 5}) + l.addNonfatalError(&LexerError{Reason: "duplicate", Offset: 5}) + + if errs := l.GetNonFatalErrors(); len(errs) != 1 { + t.Errorf("expected 1 deduplicated error; got %d", len(errs)) + } +} + +func TestFullObjectParse(t *testing.T) { + // Simulate the token sequence marshmallow uses when walking a JSON object. + input := []byte(`{"name":"alice","age":30,"active":true,"score":null}`) + l := Lexer{Data: input} + + l.Delim('{') + + key := l.UnsafeFieldName(false) + l.WantColon() + val := l.Interface() + l.WantComma() + if key != "name" || val != "alice" { + t.Errorf(`field "name": got key=%q val=%v`, key, val) + } + + key = l.UnsafeFieldName(false) + l.WantColon() + val = l.Interface() + l.WantComma() + if key != "age" || val != float64(30) { + t.Errorf(`field "age": got key=%q val=%v`, key, val) + } + + key = l.UnsafeFieldName(false) + l.WantColon() + val = l.Interface() + l.WantComma() + if key != "active" || val != true { + t.Errorf(`field "active": got key=%q val=%v`, key, val) + } + + key = l.UnsafeFieldName(false) + l.WantColon() + if !l.IsNull() { + t.Error("IsNull() = false; want true for null value") + } + l.Skip() + l.WantComma() + if key != "score" { + t.Errorf(`field key = %q; want "score"`, key) + } + + l.Delim('}') + l.Consumed() + + if !l.Ok() { + t.Errorf("full object parse error: %v", l.Error()) + } +} diff --git a/unmarshal.go b/unmarshal.go index 160ea30..80869eb 100644 --- a/unmarshal.go +++ b/unmarshal.go @@ -6,7 +6,7 @@ package marshmallow import ( "encoding/json" - "github.com/mailru/easyjson/jlexer" + jlexer "github.com/perimeterx/marshmallow/internal/lexer" "reflect" ) diff --git a/unmarshal_test.go b/unmarshal_test.go index c579e34..48da62d 100644 --- a/unmarshal_test.go +++ b/unmarshal_test.go @@ -9,7 +9,7 @@ import ( "errors" "fmt" "github.com/go-test/deep" - "github.com/mailru/easyjson/jlexer" + jlexer "github.com/perimeterx/marshmallow/internal/lexer" "reflect" "strings" "testing"