diff --git a/stream/stream.go b/stream/stream.go index 699f30f..9771997 100644 --- a/stream/stream.go +++ b/stream/stream.go @@ -3,6 +3,7 @@ package stream import ( "bytes" "crypto/sha512" + "fmt" "hash" "io" "math" @@ -150,6 +151,12 @@ func NewEncoderFromFile(file *os.File) *Encoder { return e } +// SetFilename saves provided source filename into stream metadata +func (e *Encoder) SetFilename(filename string) { + e.sd.StreamName = filename + e.sd.SuggestedFileName = sanitizeFilename(filename) +} + // WithIVs sets preset cryptographic material for encoding func (e *Encoder) WithIVs(key []byte, ivs [][]byte) *Encoder { e.sd.Key = key @@ -188,6 +195,7 @@ func (e *Encoder) Next() (Blob, error) { } // Stream creates the whole stream in one call +// TODO: Can be refactored to use Encode method func (e *Encoder) Stream() (Stream, error) { s := make(Stream, 1, 1+int(math.Ceil(float64(e.srcSizeHint)/maxBlobDataSize))) // len starts at 1 and cap is +1 to leave room for sd blob @@ -214,6 +222,37 @@ func (e *Encoder) Stream() (Stream, error) { return s, nil } +// Encode splits the source into blobs and feeds them into handler function +func (e *Encoder) Encode(handler func(string, []byte) error) ([]string, error) { + manifest := []string{} + + for { + blob, err := e.Next() + if err != nil { + if errors.Is(err, io.EOF) { + break + } + return nil, err + } + + err = handler(blob.HashHex(), blob) + if err != nil { + return nil, fmt.Errorf("cannot process blob: %w", err) + } + manifest = append(manifest, blob.HashHex()) + } + + sdb := e.SDBlob().ToBlob() + h := sdb.HashHex() + err := handler(h, sdb) + if err != nil { + return nil, fmt.Errorf("cannot handle SD blob: %w", err) + } + manifest = append([]string{h}, manifest...) + + return manifest, nil +} + // SDBlob returns the sd blob so far func (e *Encoder) SDBlob() *SDBlob { e.sd.updateStreamHash() diff --git a/stream/stream_test.go b/stream/stream_test.go index 697f75d..4ff7127 100644 --- a/stream/stream_test.go +++ b/stream/stream_test.go @@ -7,11 +7,15 @@ import ( "crypto/sha512" "encoding/hex" "io" + "io/ioutil" "os" + "path" "path/filepath" "testing" "github.com/cockroachdb/errors" + "github.com/stretchr/testify/require" + "gotest.tools/assert" ) var testdataBlobHashes = []string{ @@ -57,6 +61,9 @@ func TestStreamToFile(t *testing.T) { enc := NewEncoderFromSD(bytes.NewBuffer(data), sdBlob) newStream, err := enc.Stream() + if err != nil { + t.Fatal(err) + } if len(newStream) != len(testdataBlobHashes) { t.Fatalf("stream length mismatch. got %d blobs, expected %d", len(newStream), len(testdataBlobHashes)) @@ -140,6 +147,78 @@ func TestMakeStream(t *testing.T) { } } +func TestEncode(t *testing.T) { + blobsToRead := 3 + totalBlobs := blobsToRead + 3 + + data := make([]byte, ((totalBlobs-1)*maxBlobDataSize)+1000) // last blob is partial + _, err := rand.Read(data) + if err != nil { + t.Fatal(err) + } + + buf := bytes.NewBuffer(data) + + enc := NewEncoder(buf) + + stream := make(Stream, blobsToRead+1) // +1 for sd blob + for i := 1; i < blobsToRead+1; i++ { // start at 1 to skip sd blob + stream[i], err = enc.Next() + if err != nil { + t.Fatal(err) + } + } + + sdBlob := enc.SDBlob() + + if len(sdBlob.BlobInfos) != blobsToRead { + t.Errorf("expected %d blobs in partial sdblob, got %d", blobsToRead, len(sdBlob.BlobInfos)) + } + if enc.SourceLen() != maxBlobDataSize*blobsToRead { + t.Errorf("expected length of %d , got %d", maxBlobDataSize*blobsToRead, enc.SourceLen()) + } + + // now finish the stream, reusing key and IVs + buf = bytes.NewBuffer(data) // rewind to the beginning of the data + + enc = NewEncoderFromSD(buf, sdBlob) + + outPath := t.TempDir() + handler := func(h string, b []byte) error { + return os.WriteFile(path.Join(outPath, h), b, os.ModePerm) + } + writtenManifest, err := enc.Encode(handler) + if err != nil { + t.Fatal(err) + } + + if len(writtenManifest) != totalBlobs+1 { // +1 for the terminating blob at the end + t.Errorf("expected %d blobs in stream, got %d", totalBlobs+1, len(writtenManifest)) + } + if enc.SourceLen() != len(data) { + t.Errorf("expected length of %d , got %d", len(data), enc.SourceLen()) + } + + sdb, err := ioutil.ReadFile(path.Join(outPath, writtenManifest[0])) + if err != nil { + t.Fatal(err) + } + osdb := enc.SDBlob().ToBlob() + + if !bytes.Equal(osdb, sdb) { + t.Errorf("written sd blob does not match original sd blob") + } + for i := 1; i < len(stream); i++ { // start at 1 to skip sd blob + b, err := ioutil.ReadFile(path.Join(outPath, writtenManifest[i])) + if err != nil { + t.Fatal(err) + } + if !bytes.Equal(stream[i], b) { + t.Errorf("blob %d of reconstructed stream does not match original stream", i) + } + } +} + func TestEmptyStream(t *testing.T) { enc := NewEncoder(bytes.NewBuffer(nil)) _, err := enc.Next() @@ -204,15 +283,24 @@ func TestNew(t *testing.T) { } func TestNewEncoderFromFile(t *testing.T) { - f, err := os.Open(filepath.Join("testdata", `new "encoder" from file.whatever`)) - if err != nil { - t.Error(err) - return - } + sketchyFile := filepath.Join(t.TempDir(), `new "encoder" from file.whatever...`) + file, err := os.OpenFile(sketchyFile, os.O_RDONLY|os.O_CREATE, 0644) + require.NoError(t, err) + file.Close() + file, err = os.Open(sketchyFile) + require.NoError(t, err) - e := NewEncoderFromFile(f) + e := NewEncoderFromFile(file) if e.sd.SuggestedFileName != "new encoder from file.whatever" { t.Error("wrong or missing suggested_file_name in sd blob") } } + +func TestSetFilename(t *testing.T) { + enc := NewEncoder(bytes.NewBuffer(nil)) + enc.SetFilename(`filename "sketchy" string`) + + assert.Equal(t, "filename sketchy string", enc.sd.SuggestedFileName) + assert.Equal(t, `filename "sketchy" string`, enc.sd.StreamName) +} diff --git "a/stream/testdata/new \"encoder\" from file.whatever" "b/stream/testdata/new \"encoder\" from file.whatever" deleted file mode 100644 index e69de29..0000000