diff --git a/vindex/map.go b/vindex/map.go index 1b2b39a..b95b4f1 100644 --- a/vindex/map.go +++ b/vindex/map.go @@ -19,18 +19,12 @@ package vindex import ( - "bufio" "context" "crypto/sha256" "encoding/binary" - "encoding/hex" - "errors" "fmt" "io" "iter" - "os" - "strconv" - "strings" "sync" "time" @@ -72,10 +66,7 @@ type OpenCheckpointFn func(cpRaw []byte) (*log.Checkpoint, error) // Note that only one IndexBuilder should exist for any given walPath at any time. The behaviour is unspecified, // but likely broken, if multiple processes are writing to the same file at any given time. func NewVerifiableIndex(ctx context.Context, inputLog InputLog, inputLogParseFn OpenCheckpointFn, mapFn MapFn, walPath string) (*VerifiableIndex, error) { - wal := &walWriter{ - walPath: walPath, - } - ws, err := wal.init() + wal, ws, err := newWalWriter(walPath) if err != nil { return nil, err } @@ -333,200 +324,3 @@ func (b *VerifiableIndex) buildMap(ctx context.Context, toSize uint64) error { klog.Infof("buildMap: total=%s (wal=%s, vindex=%s)", durationTotal, durationWal, durationVIndex) return nil } - -// walWriter provides the methods needed by the processor of the Input Log when interacting -// with the WAL. init provides the index that this processor should start from, and append -// allows new mapped entries to be added to the WAL. -type walWriter struct { - walPath string - f *os.File -} - -// init verifies that the log is in good shape, and returns the index that is expected next. -// It also opens the log for appending to. -// -// Note that it returns the next expected index to avoid awkwardness with the meaning of 0, -// which could mean 0 was successfully read from a previous run, or that there was no log. 
-func (l *walWriter) init() (uint64, error) { - ffs := os.O_WRONLY | os.O_APPEND - - idx, err := l.validate() - if err != nil { - if !errors.Is(err, os.ErrNotExist) { - return idx, err - } - ffs |= os.O_CREATE | os.O_EXCL - } else { - // If the file exists, then we expect the next index to be returned - idx++ - } - // Open the file for writing in append-only, creating it if needed - l.f, err = os.OpenFile(l.walPath, ffs, 0o644) - if err != nil { - return 0, fmt.Errorf("failed to open file for writing: %s", err) - } - return idx, err -} - -func (l *walWriter) close() error { - return l.f.Close() -} - -// validate reads the file and determines what the last mapped log index was, and returns it. -// The assumption is that all lines ending with a newline were written correctly. -// If there are any errors in the file then this throws an error. -func (l *walWriter) validate() (uint64, error) { - f, err := os.Open(l.walPath) - if err != nil { - return 0, err - } - defer func() { - _ = f.Close() - }() - fi, err := f.Stat() - if err != nil { - return 0, err - } - - // Handle trivial case of empty file - size := fi.Size() - if size == 0 { - if err := os.Remove(l.walPath); err != nil { - return 0, fmt.Errorf("failed to delete empty file: %s", err) - } - return 0, os.ErrNotExist - } - - // Confirm last character is a newline - // TODO(mhutchinson): support ignoring incomplete lines - lastChar := make([]byte, 1) - if _, err := f.ReadAt(lastChar, size-1); err != nil { - return 0, err - } - if lastChar[0] != '\n' { - return 0, fmt.Errorf("expected final newline but got '%x'", lastChar[0]) - } - - // Read from the end of the file in stripes, terminating when we either: - // a) find another newline; or - // b) we have read from the beginning of the file - var lastLine string - const stripeSize = 1024 - readStripe := make([]byte, stripeSize) - // Set it up so we read all but the last character (we know it's a newline) - currOffset := size - 1 - stripeSize - - for { - if currOffset 
< 0 { - // If the stripe is bigger than the remaining file contents, adjust the offset - // and scale down what we'll read to avoid reading duplicates. - readStripe = readStripe[:stripeSize+currOffset] - currOffset = 0 - } - if _, err := f.ReadAt(readStripe, currOffset); err != nil { - return 0, err - } - lastLine = string(readStripe) + lastLine - if idx := strings.LastIndexByte(lastLine, '\n'); idx > 0 { - lastLine = lastLine[idx+1:] - break - } - if currOffset == 0 { - // We read from the start of the file so lastLine is full - break - } - currOffset = currOffset - stripeSize - } - - idx, _, err := unmarshalWalEntry(lastLine) - - return idx, err -} - -func (l *walWriter) append(idx uint64, hashes [][32]byte) error { - e, err := marshalWalEntry(idx, hashes) - if err != nil { - return fmt.Errorf("failed to marshal entry: %v", err) - } - _, err = fmt.Fprintf(l.f, "%s\n", e) - return err -} - -func newWalReader(path string) (*walReader, error) { - f, err := os.Open(path) - if err != nil { - return nil, err - } - return &walReader{ - f: f, - r: bufio.NewReader(f), - }, nil -} - -type walReader struct { - f *os.File - r *bufio.Reader - partial string -} - -// next returns the next index, hashes, and any error. -// TODO(mhutchinson): change this as it's inconvenient with EOF handling, -// which should be common when reader hits the end of the file but more is -// to be written. -func (r *walReader) next() (uint64, [][32]byte, error) { - line, err := r.r.ReadString('\n') - if err != nil { - if err == io.EOF { - r.partial = line - } - return 0, nil, err - } - - // Make sure any partial lines are prepended, and drop the final newline - line = r.partial + line[:len(line)-1] - r.partial = "" - return unmarshalWalEntry(line) -} - -func (r *walReader) close() error { - return r.f.Close() -} - -// unmarshalWalEntry parses a line from the WAL. -// This is the reverse of marshalWalEntry. 
-func unmarshalWalEntry(e string) (uint64, [][32]byte, error) { - tokens := strings.Split(e, " ") - idx, err := strconv.ParseUint(tokens[0], 10, 64) - if err != nil { - return 0, nil, fmt.Errorf("failed to parse idx from %q", e) - } - - hashes := make([][32]byte, 0, len(tokens)-1) - for i, h := range tokens[1:] { - parsed, err := hex.DecodeString(h) - if err != nil { - return 0, nil, fmt.Errorf("failed to parse hex token %d from %q", i, e) - } - if got, want := len(parsed), 32; got != want { - return 0, nil, fmt.Errorf("expected 32 byte hash but got %d bytes at idx %d", got, i) - } - hashes = append(hashes, [32]byte(parsed)) - } - - return idx, hashes, nil -} - -// unmarshalWalEntry converts an index and the hashes it affects into a line for the WAL. -// This is the reverse of unmarshalWalEntry. -func marshalWalEntry(idx uint64, hashes [][32]byte) (string, error) { - sb := strings.Builder{} - if _, err := sb.WriteString(strconv.FormatUint(idx, 10)); err != nil { - return "", err - } - for _, h := range hashes { - if _, err := sb.WriteString(" " + hex.EncodeToString(h[:])); err != nil { - return "", err - } - } - return sb.String(), nil -} diff --git a/vindex/map_test.go b/vindex/map_test.go index f31330e..8262ee5 100644 --- a/vindex/map_test.go +++ b/vindex/map_test.go @@ -21,12 +21,9 @@ import ( "context" "crypto/sha256" "encoding/hex" - "fmt" - "io" "iter" "os" "testing" - "time" "github.com/google/go-cmp/cmp" "github.com/transparency-dev/formats/log" @@ -34,7 +31,6 @@ import ( "github.com/transparency-dev/merkle/rfc6962" "github.com/transparency-dev/merkle/testonly" "golang.org/x/mod/sumdb/note" - "golang.org/x/sync/errgroup" ) const ( @@ -42,241 +38,6 @@ const ( vkey = "logandmap+38581672+Ab/PCr1eCclRPRMBqw/r5An1xO71MCnImLiospEq6b4l" ) -func TestWriteAheadLog_init(t *testing.T) { - testCases := []struct { - desc string - fileContents string - wantIdx uint64 - wantErr bool - }{ - { - desc: "empty file", - fileContents: "", - wantIdx: 0, - wantErr: false, - }, { 
- desc: "0 file", - fileContents: "0\n", - wantIdx: 1, - wantErr: false, - }, { - desc: "just indexes", - fileContents: "0\n1\n2\n", - wantIdx: 3, - wantErr: false, - }, { - desc: "indexes and hashes", - fileContents: fmt.Sprintf("1 %s %s\n", mustHashEncode("1a"), mustHashEncode("1b")), - wantIdx: 2, - wantErr: false, - }, { - desc: "trailing corruption", - fileContents: "1\n2 fdfxx", - wantErr: true, - }, { - desc: "lots of newlines", - fileContents: "1\n2\n3\n\n", - wantErr: true, - }, { - desc: "no trailing newlines", - fileContents: "1\n2\n3", - wantErr: true, - }, - } - for _, tC := range testCases { - t.Run(tC.desc, func(t *testing.T) { - f, err := os.CreateTemp("", "testWal") - if err != nil { - t.Fatal(err) - } - if _, err := f.WriteString(tC.fileContents); err != nil { - t.Fatal(err) - } - if err := f.Close(); err != nil { - t.Fatal(err) - } - wal := &walWriter{ - walPath: f.Name(), - } - idx, err := wal.init() - if gotErr := err != nil; gotErr != tC.wantErr { - t.Fatalf("wantErr != gotErr (%t != %t) %v", tC.wantErr, gotErr, err) - } - defer func() { - _ = wal.close() - }() - if tC.wantErr { - return - } - if idx != tC.wantIdx { - t.Errorf("want idx %v but got %v", tC.wantIdx, idx) - } - }) - } -} - -func TestWriteAheadLog_roundtrip(t *testing.T) { - f, err := os.CreateTemp("", "testWal") - if err != nil { - t.Fatal(err) - } - if err := f.Close(); err != nil { - t.Fatal(err) - } - if err := os.Remove(f.Name()); err != nil { - t.Fatal(err) - } - - wal := &walWriter{ - walPath: f.Name(), - } - idx, err := wal.init() - if err != nil { - t.Fatal(err) - } - if got, want := idx, uint64(0); got != want { - t.Fatalf("expected index %d, got %d", want, got) - } - - for i := range 33 { - hash := sha256.Sum256([]byte{byte(i)}) - if err := wal.append(uint64(i), [][32]byte{hash}); err != nil { - t.Error(err) - } - } - - if err := wal.close(); err != nil { - t.Error(err) - } - - idx, err = wal.init() - if err != nil { - t.Fatal(err) - } - if got, want := idx, uint64(33); 
got != want { - t.Fatalf("expected index %d, got %d", want, got) - } - - if err := wal.close(); err != nil { - t.Error(err) - } -} - -func TestWriteAndWriteLog(t *testing.T) { - f, err := os.CreateTemp("", "testWal") - if err != nil { - t.Fatal(err) - } - if err := f.Close(); err != nil { - t.Fatal(err) - } - if err := os.Remove(f.Name()); err != nil { - t.Fatal(err) - } - - wal := &walWriter{ - walPath: f.Name(), - } - idx, err := wal.init() - if err != nil { - t.Fatal(err) - } - if got, want := idx, uint64(0); got != want { - t.Fatalf("expected index %d, got %d", want, got) - } - - reader, err := newWalReader(f.Name()) - if err != nil { - t.Fatal(err) - } - - const count = 2056 - var eg errgroup.Group - eg.Go(func() error { - for i := range count { - hash := sha256.Sum256([]byte{byte(i)}) - if err := wal.append(uint64(i), [][32]byte{hash}); err != nil { - return err - } - } - return nil - }) - eg.Go(func() error { - var expect uint64 - for expect < count { - idx, _, err := reader.next() - if err != nil { - if err != io.EOF { - return err - } - // Wait a small amount of time for more data to become available - time.Sleep(10 * time.Millisecond) - continue - } - if got, want := idx, expect; got != want { - return fmt.Errorf("expected index %d, got %d", want, got) - } - expect++ - } - return nil - }) - if err := eg.Wait(); err != nil { - t.Fatal(err) - } - - if err := wal.close(); err != nil { - t.Error(err) - } - if err := reader.close(); err != nil { - t.Error(err) - } -} - -func TestUnmarshal(t *testing.T) { - testCases := []struct { - desc string - entry string - wantErr bool - wantIdx uint64 - wantHashes int - }{ - { - desc: "just index", - entry: "1", - wantErr: false, - wantIdx: 1, - wantHashes: 0, - }, { - desc: "index and hashes", - entry: fmt.Sprintf("1 %s %s", mustHashEncode("1a"), mustHashEncode("1b")), - wantErr: false, - wantIdx: 1, - wantHashes: 2, - }, { - desc: "corruption at the end", - entry: "1 deadbeef feed01xxx", - wantErr: true, - }, - } - for 
_, tC := range testCases { - t.Run(tC.desc, func(t *testing.T) { - idx, hashes, err := unmarshalWalEntry(tC.entry) - if gotErr := err != nil; gotErr != tC.wantErr { - t.Fatalf("wantErr != gotErr (%t != %t) %v", tC.wantErr, gotErr, err) - } - if tC.wantErr { - return - } - if idx != tC.wantIdx { - t.Errorf("want idx %v but got %v", tC.wantIdx, idx) - } - if got, want := len(hashes), tC.wantHashes; got != want { - t.Errorf("want %v hashes but got %v: %q", want, got, hashes) - } - }) - } -} - func TestVerifiableIndex(t *testing.T) { ctx := context.Background() s, v, err := fnote.NewEd25519SignerVerifier(skey) diff --git a/vindex/wal.go b/vindex/wal.go new file mode 100644 index 0000000..7ad3dd2 --- /dev/null +++ b/vindex/wal.go @@ -0,0 +1,233 @@ +// Copyright 2025 Google LLC. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package vindex + +import ( + "bufio" + "encoding/hex" + "errors" + "fmt" + "io" + "os" + "strconv" + "strings" + + "k8s.io/klog/v2" +) + +func newWalWriter(walPath string) (*walWriter, uint64, error) { + w := &walWriter{ + walPath: walPath, + } + idx, err := w.init() + return w, idx, err +} + +// walWriter provides the methods needed by the processor of the Input Log when interacting +// with the WAL. init provides the index that this processor should start from, and append +// allows new mapped entries to be added to the WAL. 
+type walWriter struct { + walPath string + f *os.File +} + +// init verifies that the log is in good shape, and returns the index that is expected next. +// It also opens the log for appending to. +// +// Note that it returns the next expected index to avoid awkwardness with the meaning of 0, +// which could mean 0 was successfully read from a previous run, or that there was no log. +func (l *walWriter) init() (uint64, error) { + ffs := os.O_WRONLY | os.O_APPEND + + idx, err := validate(l.walPath) + if err != nil { + if !errors.Is(err, os.ErrNotExist) { + return idx, err + } + ffs |= os.O_CREATE | os.O_EXCL + } else { + // If the file exists, then we expect the next index to be returned + idx++ + } + // Open the file for writing in append-only, creating it if needed + l.f, err = os.OpenFile(l.walPath, ffs, 0o644) + if err != nil { + return 0, fmt.Errorf("failed to open file for writing: %s", err) + } + return idx, err +} + +func (l *walWriter) close() error { + return l.f.Close() +} + +// validate reads the file and determines what the last mapped log index was, and returns it. +// The assumption is that all lines ending with a newline were written correctly. +// Any text after the final newline is treated as an incomplete entry and truncated from the file (with a warning); other errors in the file are returned. 
+func validate(walPath string) (uint64, error) { + f, err := os.OpenFile(walPath, os.O_RDWR, 0o644) + if err != nil { + return 0, err + } + defer func() { + _ = f.Close() + }() + fi, err := f.Stat() + if err != nil { + return 0, err + } + + // Handle trivial case of empty file + size := fi.Size() + if size == 0 { + if err := os.Remove(walPath); err != nil { + return 0, fmt.Errorf("failed to delete empty file: %s", err) + } + return 0, os.ErrNotExist + } + + // Read from the end of the file in stripes, terminating when we either: + // a) find another newline; or + // b) we have read from the beginning of the file + var buffer string + const stripeSize = 1024 + readStripe := make([]byte, stripeSize) + seekPos := size - stripeSize + droppedTail := false + + for { + if seekPos < 0 { + // If the stripe is bigger than the remaining file contents, adjust the offset + // and scale down what we'll read to avoid reading duplicates. + readStripe = readStripe[:stripeSize+seekPos] + seekPos = 0 + } + if _, err := f.ReadAt(readStripe, seekPos); err != nil { + return 0, err + } + buffer = string(readStripe) + buffer + + for i := strings.LastIndex(buffer, "\n"); i > 0; i = strings.LastIndex(buffer, "\n") { + p := buffer[i+1:] + buffer = buffer[:i] + if !droppedTail { + droppedTail = true + if len(p) > 0 { + truncPos := seekPos + int64(i) + 1 + klog.Warningf("Dropping tail part from WAL: %q", p) + if err := f.Truncate(truncPos); err != nil { + return 0, fmt.Errorf("failed to truncate WAL: %v", err) + } + } + continue + } + idx, _, err := unmarshalWalEntry(p) + return idx, err + } + if seekPos == 0 { + idx, _, err := unmarshalWalEntry(buffer) + return idx, err + } + seekPos = seekPos - stripeSize + } +} + +func (l *walWriter) append(idx uint64, hashes [][32]byte) error { + e, err := marshalWalEntry(idx, hashes) + if err != nil { + return fmt.Errorf("failed to marshal entry: %v", err) + } + _, err = fmt.Fprintf(l.f, "%s\n", e) + return err +} + +func newWalReader(path string) 
(*walReader, error) { + f, err := os.Open(path) + if err != nil { + return nil, err + } + return &walReader{ + f: f, + r: bufio.NewReader(f), + }, nil +} + +type walReader struct { + f *os.File + r *bufio.Reader + partial string +} + +// next returns the next index, hashes, and any error. +// TODO(mhutchinson): change this as it's inconvenient with EOF handling, +// which should be common when reader hits the end of the file but more is +// to be written. NOTE(review): on io.EOF, partial is overwritten rather than appended to, so a second EOF read mid-line looks like it would discard the earlier fragment - confirm. +func (r *walReader) next() (uint64, [][32]byte, error) { + line, err := r.r.ReadString('\n') + if err != nil { + if err == io.EOF { + r.partial = line + } + return 0, nil, err + } + + // Make sure any partial lines are prepended, and drop the final newline + line = r.partial + line[:len(line)-1] + r.partial = "" + return unmarshalWalEntry(line) +} + +func (r *walReader) close() error { + return r.f.Close() +} + +// unmarshalWalEntry parses a line from the WAL. +// This is the reverse of marshalWalEntry. +func unmarshalWalEntry(e string) (uint64, [][32]byte, error) { + tokens := strings.Split(e, " ") + idx, err := strconv.ParseUint(tokens[0], 10, 64) + if err != nil { + return 0, nil, fmt.Errorf("failed to parse idx from %q", e) + } + + hashes := make([][32]byte, 0, len(tokens)-1) + for i, h := range tokens[1:] { + parsed, err := hex.DecodeString(h) + if err != nil { + return 0, nil, fmt.Errorf("failed to parse hex token %d from %q", i, e) + } + if got, want := len(parsed), 32; got != want { + return 0, nil, fmt.Errorf("expected 32 byte hash but got %d bytes at idx %d", got, i) + } + hashes = append(hashes, [32]byte(parsed)) + } + + return idx, hashes, nil +} + +// marshalWalEntry converts an index and the hashes it affects into a line for the WAL. +// This is the reverse of unmarshalWalEntry. 
+func marshalWalEntry(idx uint64, hashes [][32]byte) (string, error) { + sb := strings.Builder{} + if _, err := sb.WriteString(strconv.FormatUint(idx, 10)); err != nil { + return "", err + } + for _, h := range hashes { + if _, err := sb.WriteString(" " + hex.EncodeToString(h[:])); err != nil { + return "", err + } + } + return sb.String(), nil +} diff --git a/vindex/wal_test.go b/vindex/wal_test.go new file mode 100644 index 0000000..17d60b1 --- /dev/null +++ b/vindex/wal_test.go @@ -0,0 +1,287 @@ +// Copyright 2025 Google LLC. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// vindex contains a prototype of an in-memory verifiable index. +// This version uses the clone tool DB as the log source. 
+package vindex + +import ( + "bytes" + "crypto/sha256" + "fmt" + "io" + "os" + "testing" + "time" + + "golang.org/x/sync/errgroup" +) + +func TestWriteAheadLog_init(t *testing.T) { + testCases := []struct { + desc string + fileContents string + wantIdx uint64 + wantErr bool + }{ + { + desc: "empty file", + fileContents: "", + wantIdx: 0, + wantErr: false, + }, { + desc: "0 file", + fileContents: "0\n", + wantIdx: 1, + wantErr: false, + }, { + desc: "just indexes", + fileContents: "0\n1\n2\n", + wantIdx: 3, + wantErr: false, + }, { + desc: "indexes and hashes", + fileContents: fmt.Sprintf("1 %s %s\n", mustHashEncode("1a"), mustHashEncode("1b")), + wantIdx: 2, + wantErr: false, + }, { + desc: "trailing corruption", + fileContents: "1\n2 fdfxx", + wantIdx: 2, + wantErr: false, + }, { + desc: "lots of newlines", + fileContents: "1\n2\n3\n\n", + wantErr: true, + }, { + desc: "no trailing newlines", + fileContents: "1\n2\n3", + wantIdx: 3, + }, + } + for _, tC := range testCases { + t.Run(tC.desc, func(t *testing.T) { + f, err := os.CreateTemp("", "testWal") + if err != nil { + t.Fatal(err) + } + if _, err := f.WriteString(tC.fileContents); err != nil { + t.Fatal(err) + } + if err := f.Close(); err != nil { + t.Fatal(err) + } + wal, idx, err := newWalWriter(f.Name()) + defer func() { + _ = wal.close() + }() + if gotErr := err != nil; gotErr != tC.wantErr { + t.Fatalf("wantErr != gotErr (%t != %t) %v", tC.wantErr, gotErr, err) + } + if tC.wantErr { + return + } + if idx != tC.wantIdx { + t.Errorf("want idx %v but got %v", tC.wantIdx, idx) + } + }) + } +} + +func TestWriteAheadLog_truncate(t *testing.T) { + f, err := os.CreateTemp("", "testWal") + if err != nil { + t.Fatal(err) + } + if _, err := f.WriteString("0\n2\n5 xxabcdeadbeef"); err != nil { + t.Fatal(err) + } + if err := f.Close(); err != nil { + t.Fatal(err) + } + wal, idx, err := newWalWriter(f.Name()) + if err != nil { + t.Fatal(err) + } + defer func() { + _ = wal.close() + }() + if got, want := idx, uint64(3); 
got != want { + t.Errorf("expected next index %d, but got %d", want, got) + } + + contents, err := os.ReadFile(f.Name()) + if err != nil { + t.Fatal(err) + } + if got, want := contents, []byte("0\n2\n"); !bytes.Equal(got, want) { + t.Errorf("expected %v but got %v", want, got) + } +} + +func TestWriteAheadLog_roundtrip(t *testing.T) { + f, err := os.CreateTemp("", "testWal") + if err != nil { + t.Fatal(err) + } + if err := f.Close(); err != nil { + t.Fatal(err) + } + if err := os.Remove(f.Name()); err != nil { + t.Fatal(err) + } + + wal, idx, err := newWalWriter(f.Name()) + if err != nil { + t.Fatal(err) + } + if got, want := idx, uint64(0); got != want { + t.Fatalf("expected index %d, got %d", want, got) + } + + for i := range 33 { + hash := sha256.Sum256([]byte{byte(i)}) + if err := wal.append(uint64(i), [][32]byte{hash}); err != nil { + t.Error(err) + } + } + + if err := wal.close(); err != nil { + t.Error(err) + } + + wal, idx, err = newWalWriter(f.Name()) + if err != nil { + t.Fatal(err) + } + if got, want := idx, uint64(33); got != want { + t.Fatalf("expected index %d, got %d", want, got) + } + + if err := wal.close(); err != nil { + t.Error(err) + } +} + +func TestWriteAndWriteLog(t *testing.T) { + f, err := os.CreateTemp("", "testWal") + if err != nil { + t.Fatal(err) + } + if err := f.Close(); err != nil { + t.Fatal(err) + } + if err := os.Remove(f.Name()); err != nil { + t.Fatal(err) + } + + wal, idx, err := newWalWriter(f.Name()) + if err != nil { + t.Fatal(err) + } + if got, want := idx, uint64(0); got != want { + t.Fatalf("expected index %d, got %d", want, got) + } + + reader, err := newWalReader(f.Name()) + if err != nil { + t.Fatal(err) + } + + const count = 2056 + var eg errgroup.Group + eg.Go(func() error { + for i := range count { + hash := sha256.Sum256([]byte{byte(i)}) + if err := wal.append(uint64(i), [][32]byte{hash}); err != nil { + return err + } + } + return nil + }) + eg.Go(func() error { + var expect uint64 + for expect < count { + idx, 
_, err := reader.next() + if err != nil { + if err != io.EOF { + return err + } + // Wait a small amount of time for more data to become available + time.Sleep(10 * time.Millisecond) + continue + } + if got, want := idx, expect; got != want { + return fmt.Errorf("expected index %d, got %d", want, got) + } + expect++ + } + return nil + }) + if err := eg.Wait(); err != nil { + t.Fatal(err) + } + + if err := wal.close(); err != nil { + t.Error(err) + } + if err := reader.close(); err != nil { + t.Error(err) + } +} + +func TestUnmarshal(t *testing.T) { + testCases := []struct { + desc string + entry string + wantErr bool + wantIdx uint64 + wantHashes int + }{ + { + desc: "just index", + entry: "1", + wantErr: false, + wantIdx: 1, + wantHashes: 0, + }, { + desc: "index and hashes", + entry: fmt.Sprintf("1 %s %s", mustHashEncode("1a"), mustHashEncode("1b")), + wantErr: false, + wantIdx: 1, + wantHashes: 2, + }, { + desc: "corruption at the end", + entry: "1 deadbeef feed01xxx", + wantErr: true, + }, + } + for _, tC := range testCases { + t.Run(tC.desc, func(t *testing.T) { + idx, hashes, err := unmarshalWalEntry(tC.entry) + if gotErr := err != nil; gotErr != tC.wantErr { + t.Fatalf("wantErr != gotErr (%t != %t) %v", tC.wantErr, gotErr, err) + } + if tC.wantErr { + return + } + if idx != tC.wantIdx { + t.Errorf("want idx %v but got %v", tC.wantIdx, idx) + } + if got, want := len(hashes), tC.wantHashes; got != want { + t.Errorf("want %v hashes but got %v: %q", want, got, hashes) + } + }) + } +}