diff --git a/index/contentprovider.go b/index/contentprovider.go
index 7bb52a719..ab05a1d0d 100644
--- a/index/contentprovider.go
+++ b/index/contentprovider.go
@@ -16,6 +16,7 @@ package index
 
 import (
 	"bytes"
+	"fmt"
 	"log"
 	"path"
 	"slices"
@@ -99,14 +100,21 @@ func (p *contentProvider) findOffset(filename bool, r uint32) uint32 {
 		return r
 	}
 
-	sample := p.id.runeOffsets
-	runeEnds := p.id.fileEndRunes
-	fileStartByte := p.id.boundaries[p.idx]
+	var sample runeOffsetMap
+	var runeEnds []uint32
+	var fileStartByte, fileEndByte uint32
 	if filename {
 		sample = p.id.fileNameRuneOffsets
 		runeEnds = p.id.fileNameEndRunes
 		fileStartByte = p.id.fileNameIndex[p.idx]
+		fileEndByte = p.id.fileNameIndex[p.idx+1]
+	} else {
+		sample = p.id.runeOffsets
+		runeEnds = p.id.fileEndRunes
+		fileStartByte = p.id.boundaries[p.idx]
+		fileEndByte = p.id.boundaries[p.idx+1]
 	}
+	fileSize := fileEndByte - fileStartByte
 
 	absR := r
 	if p.idx > 0 {
@@ -118,20 +126,37 @@ func (p *contentProvider) findOffset(filename bool, r uint32) uint32 {
 
 	var data []byte
 	if filename {
+		if byteOff > uint32(len(p.id.fileNameContent)) {
+			p.err = fmt.Errorf("corrupt index: filename rune offset %d maps to byte offset %d past filename data size %d", absR, byteOff, len(p.id.fileNameContent))
+			return fileSize
+		}
 		data = p.id.fileNameContent[byteOff:]
 	} else {
 		data, p.err = p.id.readContentSlice(byteOff, 3*runeOffsetFrequency)
 		if p.err != nil {
-			return 0
+			return fileSize
 		}
 	}
 	for left > 0 {
+		if len(data) == 0 {
+			p.err = fmt.Errorf("corrupt index: rune offset %d maps past available data", absR)
+			return fileSize
+		}
 		_, sz := utf8.DecodeRune(data)
 		byteOff += uint32(sz)
 		data = data[sz:]
 		left--
 	}
+	if byteOff < fileStartByte {
+		p.err = fmt.Errorf("corrupt index: rune offset %d maps to byte offset %d before file start %d", absR, byteOff, fileStartByte)
+		return fileSize
+	}
+	if byteOff > fileEndByte {
+		p.err = fmt.Errorf("corrupt index: rune offset %d maps to byte offset %d after file end %d", absR, byteOff, fileEndByte)
+		return fileSize
+	}
+	byteOff -= fileStartByte
 
 	return byteOff
 }
diff --git a/index/matchiter.go b/index/matchiter.go
index df75410b5..3a57cb5e1 100644
--- a/index/matchiter.go
+++ b/index/matchiter.go
@@ -48,7 +48,14 @@ type candidateMatch struct {
 
 // Matches content against the substring, and populates byteMatchSz on success
 func (m *candidateMatch) matchContent(content []byte) bool {
+	if int(m.byteOffset) > len(content) {
+		return false
+	}
+
 	if m.caseSensitive {
+		if int(m.byteOffset)+len(m.substrBytes) > len(content) {
+			return false
+		}
 		comp := bytes.Equal(m.substrBytes, content[m.byteOffset:m.byteOffset+uint32(len(m.substrBytes))])
 
 		m.byteMatchSz = uint32(len(m.substrBytes))
diff --git a/index/matchiter_test.go b/index/matchiter_test.go
index 5de9075e8..225a3ad70 100644
--- a/index/matchiter_test.go
+++ b/index/matchiter_test.go
@@ -14,6 +14,7 @@ package index
 
 import (
 	"reflect"
+	"strings"
 	"testing"
 )
@@ -38,3 +39,49 @@ tool fieldalignment then update this test.`, c.v, c.size, got)
 		}
 	}
 }
+
+func TestCandidateMatchContentOutOfBounds(t *testing.T) {
+	for _, tc := range []struct {
+		name  string
+		match candidateMatch
+	}{
+		{
+			name: "offset past content",
+			match: candidateMatch{
+				byteOffset:    4,
+				substrLowered: []byte("x"),
+			},
+		},
+		{
+			name: "case-sensitive match extends past content",
+			match: candidateMatch{
+				byteOffset:    2,
+				substrBytes:   []byte("cd"),
+				caseSensitive: true,
+			},
+		},
+	} {
+		t.Run(tc.name, func(t *testing.T) {
+			if tc.match.matchContent([]byte("abc")) {
+				t.Fatal("matchContent returned true for an out-of-bounds match")
+			}
+		})
+	}
+}
+
+func TestFindOffsetRejectsByteOffsetBeforeFileStart(t *testing.T) {
+	cp := contentProvider{
+		id: &indexData{
+			fileNameContent:  []byte("previous/current"),
+			fileNameIndex:    []uint32{9, 16},
+			fileNameEndRunes: []uint32{7},
+		},
+	}
+
+	if got, want := cp.findOffset(true, 0), uint32(7); got != want {
+		t.Fatalf("findOffset returned %d, want file size %d", got, want)
+	}
+	if cp.err == nil || !strings.Contains(cp.err.Error(), "before file start") {
+		t.Fatalf("findOffset error = %v", cp.err)
+	}
+}
diff --git a/index/read.go b/index/read.go
index 3a9c05b91..472b4139f 100644
--- a/index/read.go
+++ b/index/read.go
@@ -23,6 +23,7 @@ import (
 	"os"
 	"slices"
 	"sort"
+	"unicode/utf8"
 
 	"github.com/RoaringBitmap/roaring"
 	"github.com/prometheus/client_golang/prometheus"
@@ -513,11 +514,83 @@ func (d *indexData) verify() error {
 		"branch masks":      len(d.fileBranchMasks),
 		"doc section index": len(d.docSectionsIndex) - 1,
 		"newlines index":    len(d.newlinesIndex) - 1,
+		"file end runes":    len(d.fileEndRunes),
+		"name end runes":    len(d.fileNameEndRunes),
 	} {
 		if got != n {
 			return fmt.Errorf("got %s %d, want %d", what, got, n)
 		}
 	}
+
+	if err := d.verifyRuneBoundaryMapping("content", d.boundaries, d.fileEndRunes, d.runeOffsets, false); err != nil {
+		return err
+	}
+	if err := d.verifyRuneBoundaryMapping("filename", d.fileNameIndex, d.fileNameEndRunes, d.fileNameRuneOffsets, true); err != nil {
+		return err
+	}
+	return nil
+}
+
+func (d *indexData) verifyRuneBoundaryMapping(what string, byteEnds, runeEnds []uint32, sample runeOffsetMap, filename bool) error {
+	var prevRune uint32
+	for i, endRune := range runeEnds {
+		startByte := byteEnds[i]
+		endByte := byteEnds[i+1]
+		if endByte < startByte {
+			return fmt.Errorf("corrupt index: %s %d byte end %d before start %d", what, i, endByte, startByte)
+		}
+		if endRune < prevRune {
+			return fmt.Errorf("corrupt index: %s %d rune end %d before previous end %d", what, i, endRune, prevRune)
+		}
+
+		if d.metaData.PlainASCII {
+			wantRuneEnd := prevRune + endByte - startByte
+			if endRune != wantRuneEnd {
+				return fmt.Errorf("corrupt index: plain ASCII %s %d ends at rune %d, want %d from byte boundaries", what, i, endRune, wantRuneEnd)
+			}
+			prevRune = endRune
+			continue
+		}
+
+		byteOff, left := sample.lookup(endRune)
+		if byteOff > endByte {
+			return fmt.Errorf("corrupt index: %s %d rune end %d maps to byte offset %d after byte end %d", what, i, endRune, byteOff, endByte)
+		}
+
+		var data []byte
+		if filename {
+			if byteOff > uint32(len(d.fileNameContent)) {
+				return fmt.Errorf("corrupt index: %s %d rune end %d maps to byte offset %d past filename data size %d", what, i, endRune, byteOff, len(d.fileNameContent))
+			}
+			if endByte > uint32(len(d.fileNameContent)) {
+				return fmt.Errorf("corrupt index: %s %d byte end %d past filename data size %d", what, i, endByte, len(d.fileNameContent))
+			}
+			data = d.fileNameContent[byteOff:endByte]
+		} else {
+			blob, err := d.readSectionBlob(simpleSection{
+				off: d.boundariesStart + byteOff,
+				sz:  endByte - byteOff,
+			})
+			if err != nil {
+				return err
+			}
+			data = blob
+		}
+
+		for ; left > 0; left-- {
+			if len(data) == 0 {
+				return fmt.Errorf("corrupt index: %s %d rune end %d does not have enough bytes before byte end %d", what, i, endRune, endByte)
+			}
+			_, sz := utf8.DecodeRune(data)
+			byteOff += uint32(sz)
+			data = data[sz:]
+		}
+
+		if byteOff != endByte {
+			return fmt.Errorf("corrupt index: %s %d rune end %d maps to byte offset %d, want byte end %d", what, i, endRune, byteOff, endByte)
+		}
+		prevRune = endRune
+	}
 	return nil
 }
diff --git a/index/read_test.go b/index/read_test.go
index 087b8a65b..700f4a080 100644
--- a/index/read_test.go
+++ b/index/read_test.go
@@ -79,6 +79,69 @@ func TestReadWrite(t *testing.T) {
 	}
 }
 
+func TestVerifyRuneBoundaryMapping(t *testing.T) {
+	id := &indexData{
+		file:             &memSeeker{[]byte("éx")},
+		boundaries:       []uint32{0, 2, 3},
+		fileEndRunes:     []uint32{1, 2},
+		fileNameContent:  []byte("abé"),
+		fileNameIndex:    []uint32{0, 1, 4},
+		fileNameEndRunes: []uint32{1, 3},
+		fileBranchMasks:  []uint64{0, 0},
+		docSectionsIndex: []uint32{0, 0, 0},
+		newlinesIndex:    []uint32{0, 0, 0},
+	}
+
+	if err := id.verify(); err != nil {
+		t.Fatalf("verify: %v", err)
+	}
+
+	id.fileEndRunes[0] = 0
+	if err := id.verify(); err == nil || !strings.Contains(err.Error(), "content 0 rune end 0 maps to byte offset 0, want byte end 2") {
+		t.Fatalf("verify content corruption error = %v", err)
+	}
+	id.fileEndRunes[0] = 1
+
+	id.fileNameEndRunes[1] = 2
+	if err := id.verify(); err == nil || !strings.Contains(err.Error(), "filename 1 rune end 2 maps to byte offset 2, want byte end 4") {
+		t.Fatalf("verify filename corruption error = %v", err)
+	}
+}
+
+func TestVerifyPlainASCIIRuneBoundaryMapping(t *testing.T) {
+	id := &indexData{
+		file:             &memSeeker{[]byte("abc")},
+		boundaries:       []uint32{0, 1, 3},
+		fileEndRunes:     []uint32{1, 3},
+		fileNameContent:  []byte("fg"),
+		fileNameIndex:    []uint32{0, 1, 2},
+		fileNameEndRunes: []uint32{1, 2},
+		fileBranchMasks:  []uint64{0, 0},
+		docSectionsIndex: []uint32{0, 0, 0},
+		newlinesIndex:    []uint32{0, 0, 0},
+		metaData: zoekt.IndexMetadata{
+			PlainASCII: true,
+		},
+	}
+
+	if err := id.verify(); err != nil {
+		t.Fatalf("verify: %v", err)
+	}
+
+	id.file = &memSeeker{[]byte("é")}
+	id.boundaries = []uint32{0, 2}
+	id.fileEndRunes = []uint32{1}
+	id.fileNameContent = []byte("f")
+	id.fileNameIndex = []uint32{0, 1}
+	id.fileNameEndRunes = []uint32{1}
+	id.fileBranchMasks = []uint64{0}
+	id.docSectionsIndex = []uint32{0, 0}
+	id.newlinesIndex = []uint32{0, 0}
+	if err := id.verify(); err == nil || !strings.Contains(err.Error(), "plain ASCII content 0 ends at rune 1, want 2") {
+		t.Fatalf("verify plain ASCII corruption error = %v", err)
+	}
+}
+
 func TestReadWriteNames(t *testing.T) {
 	b, err := NewShardBuilder(nil)
 	if err != nil {