Skip to content

Commit

Permalink
Merge pull request #72 from philippgille/import-from-reader
Browse files Browse the repository at this point in the history
Import from io.Reader
  • Loading branch information
philippgille authored May 5, 2024
2 parents 82f4efe + ff98d0a commit 041e3c0
Show file tree
Hide file tree
Showing 3 changed files with 109 additions and 19 deletions.
69 changes: 69 additions & 0 deletions db.go
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,21 @@ func NewPersistentDB(path string, compress bool) (*DB, error) {
//
// - filePath: Mandatory, must not be empty
// - encryptionKey: Optional, must be 32 bytes long if provided
//
// Deprecated: Use [DB.ImportFromFile] instead.
func (db *DB) Import(filePath string, encryptionKey string) error {
	// Thin backward-compatibility shim; all validation and the actual
	// import logic live in ImportFromFile.
	err := db.ImportFromFile(filePath, encryptionKey)
	return err
}

// ImportFromFile imports the DB from a file at the given path. The file must be
// encoded as gob and can optionally be compressed with flate (as gzip) and encrypted
// with AES-GCM.
// This works for both the in-memory and persistent DBs.
// Existing collections are overwritten.
//
// - filePath: Mandatory, must not be empty
// - encryptionKey: Optional, must be 32 bytes long if provided
func (db *DB) ImportFromFile(filePath string, encryptionKey string) error {
if filePath == "" {
return fmt.Errorf("file path is empty")
}
Expand Down Expand Up @@ -246,6 +260,61 @@ func (db *DB) Import(filePath string, encryptionKey string) error {
return nil
}

// ImportFromReader imports the DB from a reader. The stream must be encoded as
// gob and can optionally be compressed with flate (as gzip) and encrypted with
// AES-GCM.
// This works for both the in-memory and persistent DBs.
// Existing collections are overwritten.
// If the reader has to be closed, it's the caller's responsibility.
//
//   - reader: Mandatory, an implementation of [io.ReadSeeker]
//   - encryptionKey: Optional, must be 32 bytes long if provided
func (db *DB) ImportFromReader(reader io.ReadSeeker, encryptionKey string) error {
	// Guard against a nil reader so the caller gets a descriptive error
	// instead of a panic inside readFromReader.
	if reader == nil {
		return errors.New("reader is nil")
	}
	if encryptionKey != "" {
		// AES 256 requires a 32 byte key
		if len(encryptionKey) != 32 {
			return errors.New("encryption key must be 32 bytes long")
		}
	}

	// Create persistence structs with exported fields so that they can be decoded
	// from gob.
	type persistenceCollection struct {
		Name      string
		Metadata  map[string]string
		Documents map[string]*Document
	}
	persistenceDB := struct {
		Collections map[string]*persistenceCollection
	}{
		Collections: make(map[string]*persistenceCollection, len(db.collections)),
	}

	// Hold the lock for the whole import so concurrent readers never observe
	// a partially imported set of collections.
	db.collectionsLock.Lock()
	defer db.collectionsLock.Unlock()

	err := readFromReader(reader, &persistenceDB, encryptionKey)
	if err != nil {
		return fmt.Errorf("couldn't read stream: %w", err)
	}

	for _, pc := range persistenceDB.Collections {
		c := &Collection{
			Name: pc.Name,

			metadata:  pc.Metadata,
			documents: pc.Documents,
		}
		// For persistent DBs, wire up the collection's own persistence
		// directory so future writes land on disk.
		if db.persistDirectory != "" {
			c.persistDirectory = filepath.Join(db.persistDirectory, hash2hex(pc.Name))
			c.compress = db.compress
		}
		db.collections[c.Name] = c
	}

	return nil
}

// Export exports the DB to a file at the given path. The file is encoded as gob,
// optionally compressed with flate (as gzip) and optionally encrypted with AES-GCM.
// This works for both the in-memory and persistent DBs.
Expand Down
2 changes: 1 addition & 1 deletion db_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ func TestDB_ImportExport(t *testing.T) {
new := NewDB()

// Import
err = new.Import(tc.filePath, tc.encryptionKey)
err = new.ImportFromFile(tc.filePath, tc.encryptionKey)
if err != nil {
t.Fatal("expected no error, got", err)
}
Expand Down
57 changes: 39 additions & 18 deletions persistence.go
Original file line number Diff line number Diff line change
Expand Up @@ -163,18 +163,43 @@ func readFromFile(filePath string, obj any, encryptionKey string) error {
}
}

r, err := os.Open(filePath)
if err != nil {
return fmt.Errorf("couldn't open file: %w", err)
}
defer r.Close()

return readFromReader(r, obj, encryptionKey)
}

// readFromReader reads an object from a Reader. The object is deserialized from gob.
// `obj` must be a pointer to an instantiated object. The stream may optionally
// be compressed as gzip and/or encrypted with AES-GCM. The encryption key must
// be 32 bytes long.
// If the reader has to be closed, it's the caller's responsibility.
func readFromReader(r io.ReadSeeker, obj any, encryptionKey string) error {
// AES 256 requires a 32 byte key
if encryptionKey != "" {
if len(encryptionKey) != 32 {
return errors.New("encryption key must be 32 bytes long")
}
}

// We want to:
// Read file -> decrypt with AES-GCM -> decompress with flate -> decode as gob
// Read from reader -> decrypt with AES-GCM -> decompress with flate -> decode
// as gob.
// To reduce memory usage we chain the readers instead of buffering, so we start
// from the end. For the decryption there's no reader though.

var r io.Reader
// For the chainedReader we don't declare it as ReadSeeker so we can reassign
// the gzip reader to it.
var chainedReader io.Reader

// Decrypt if an encryption key is provided
if encryptionKey != "" {
encrypted, err := os.ReadFile(filePath)
encrypted, err := io.ReadAll(r)
if err != nil {
return fmt.Errorf("couldn't read file: %w", err)
return fmt.Errorf("couldn't read from reader: %w", err)
}
block, err := aes.NewCipher([]byte(encryptionKey))
if err != nil {
Expand All @@ -194,28 +219,24 @@ func readFromFile(filePath string, obj any, encryptionKey string) error {
return fmt.Errorf("couldn't decrypt data: %w", err)
}

r = bytes.NewReader(data)
chainedReader = bytes.NewReader(data)
} else {
var err error
r, err = os.Open(filePath)
if err != nil {
return fmt.Errorf("couldn't open file: %w", err)
}
chainedReader = r
}

// Determine if the file is compressed
// Determine if the stream is compressed
magicNumber := make([]byte, 2)
_, err := r.Read(magicNumber)
_, err := chainedReader.Read(magicNumber)
if err != nil {
return fmt.Errorf("couldn't read magic number to determine whether the file is compressed: %w", err)
return fmt.Errorf("couldn't read magic number to determine whether the stream is compressed: %w", err)
}
var compressed bool
if magicNumber[0] == 0x1f && magicNumber[1] == 0x8b {
compressed = true
}

// Reset reader. Both file and bytes.Reader support seeking.
if s, ok := r.(io.Seeker); !ok {
// Reset reader. Both the reader from the param and bytes.Reader support seeking.
if s, ok := chainedReader.(io.Seeker); !ok {
return fmt.Errorf("reader doesn't support seeking")
} else {
_, err := s.Seek(0, 0)
Expand All @@ -225,15 +246,15 @@ func readFromFile(filePath string, obj any, encryptionKey string) error {
}

if compressed {
gzr, err := gzip.NewReader(r)
gzr, err := gzip.NewReader(chainedReader)
if err != nil {
return fmt.Errorf("couldn't create gzip reader: %w", err)
}
defer gzr.Close()
r = gzr
chainedReader = gzr
}

dec := gob.NewDecoder(r)
dec := gob.NewDecoder(chainedReader)
err = dec.Decode(obj)
if err != nil {
return fmt.Errorf("couldn't decode object: %w", err)
Expand Down

0 comments on commit 041e3c0

Please sign in to comment.