Skip to content

Commit

Permalink
Merge pull request #88 from iwilltry42/feat/export-collection
Browse files Browse the repository at this point in the history
add: option to only import/export selected collections to/from a DB
  • Loading branch information
philippgille authored Jul 3, 2024
2 parents 4b532fa + a4a5653 commit 935ec30
Show file tree
Hide file tree
Showing 3 changed files with 208 additions and 40 deletions.
41 changes: 24 additions & 17 deletions collection.go
Original file line number Diff line number Diff line change
Expand Up @@ -111,23 +111,7 @@ func newCollection(name string, metadata map[string]string, embed EmbeddingFunc,
safeName := hash2hex(name)
c.persistDirectory = filepath.Join(dbDir, safeName)
c.compress = compress
// Persist name and metadata
metadataPath := filepath.Join(c.persistDirectory, metadataFileName)
metadataPath += ".gob"
if c.compress {
metadataPath += ".gz"
}
pc := struct {
Name string
Metadata map[string]string
}{
Name: name,
Metadata: m,
}
err := persistToFile(metadataPath, pc, compress, "")
if err != nil {
return nil, fmt.Errorf("couldn't persist collection metadata: %w", err)
}
return c, c.persistMetadata()
}

return c, nil
Expand Down Expand Up @@ -545,3 +529,26 @@ func (c *Collection) getDocPath(docID string) string {
}
return docPath
}

// persistMetadata writes the collection's name and metadata to the metadata
// file inside the collection's persist directory.
//
// The file is named after metadataFileName with a ".gob" extension, followed
// by ".gz" when the collection is configured for compression. Any error from
// persistToFile is returned unchanged, so callers are expected to add their
// own context when wrapping it.
func (c *Collection) persistMetadata() error {
	// Work out the extension first so the path is assembled in one place.
	ext := ".gob"
	if c.compress {
		ext += ".gz"
	}
	target := filepath.Join(c.persistDirectory, metadataFileName) + ext

	// Only the name and the metadata are written here.
	payload := struct {
		Name     string
		Metadata map[string]string
	}{
		Name:     c.Name,
		Metadata: c.metadata,
	}

	// NOTE(review): the trailing "" is presumably the encryption key (i.e. no
	// encryption) - confirm against persistToFile.
	return persistToFile(target, payload, c.compress, "")
}
77 changes: 61 additions & 16 deletions db.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"io/fs"
"os"
"path/filepath"
"slices"
"strings"
"sync"
)
Expand Down Expand Up @@ -198,9 +199,12 @@ func (db *DB) Import(filePath string, encryptionKey string) error {
// This works for both the in-memory and persistent DBs.
// Existing collections are overwritten.
//
// - filePath: Mandatory, must not be empty
// - encryptionKey: Optional, must be 32 bytes long if provided
func (db *DB) ImportFromFile(filePath string, encryptionKey string) error {
// - filePath: Mandatory, must not be empty
// - encryptionKey: Optional, must be 32 bytes long if provided
// - collections: Optional. If provided, only the collections with the given names
// are imported. Non-existing collections are ignored.
// If not provided, all collections are imported.
func (db *DB) ImportFromFile(filePath string, encryptionKey string, collections ...string) error {
if filePath == "" {
return fmt.Errorf("file path is empty")
}
Expand Down Expand Up @@ -244,6 +248,9 @@ func (db *DB) ImportFromFile(filePath string, encryptionKey string) error {
}

for _, pc := range persistenceDB.Collections {
if len(collections) > 0 && !slices.Contains(collections, pc.Name) {
continue
}
c := &Collection{
Name: pc.Name,

Expand All @@ -253,6 +260,17 @@ func (db *DB) ImportFromFile(filePath string, encryptionKey string) error {
if db.persistDirectory != "" {
c.persistDirectory = filepath.Join(db.persistDirectory, hash2hex(pc.Name))
c.compress = db.compress
err = c.persistMetadata()
if err != nil {
return fmt.Errorf("couldn't persist collection metadata: %w", err)
}
for _, doc := range c.documents {
docPath := c.getDocPath(doc.ID)
err = persistToFile(docPath, doc, c.compress, "")
if err != nil {
return fmt.Errorf("couldn't persist document to %q: %w", docPath, err)
}
}
}
db.collections[c.Name] = c
}
Expand All @@ -267,9 +285,12 @@ func (db *DB) ImportFromFile(filePath string, encryptionKey string) error {
// Existing collections are overwritten.
// If the reader has to be closed, it's the caller's responsibility.
//
// - reader: An implementation of [io.ReadSeeker]
// - encryptionKey: Optional, must be 32 bytes long if provided
func (db *DB) ImportFromReader(reader io.ReadSeeker, encryptionKey string) error {
// - reader: An implementation of [io.ReadSeeker]
// - encryptionKey: Optional, must be 32 bytes long if provided
// - collections: Optional. If provided, only the collections with the given names
// are imported. Non-existing collections are ignored.
// If not provided, all collections are imported.
func (db *DB) ImportFromReader(reader io.ReadSeeker, encryptionKey string, collections ...string) error {
if encryptionKey != "" {
// AES 256 requires a 32 byte key
if len(encryptionKey) != 32 {
Expand Down Expand Up @@ -299,6 +320,9 @@ func (db *DB) ImportFromReader(reader io.ReadSeeker, encryptionKey string) error
}

for _, pc := range persistenceDB.Collections {
if len(collections) > 0 && !slices.Contains(collections, pc.Name) {
continue
}
c := &Collection{
Name: pc.Name,

Expand All @@ -308,6 +332,17 @@ func (db *DB) ImportFromReader(reader io.ReadSeeker, encryptionKey string) error
if db.persistDirectory != "" {
c.persistDirectory = filepath.Join(db.persistDirectory, hash2hex(pc.Name))
c.compress = db.compress
err = c.persistMetadata()
if err != nil {
return fmt.Errorf("couldn't persist collection metadata: %w", err)
}
for _, doc := range c.documents {
docPath := c.getDocPath(doc.ID)
err := persistToFile(docPath, doc, c.compress, "")
if err != nil {
return fmt.Errorf("couldn't persist document to %q: %w", docPath, err)
}
}
}
db.collections[c.Name] = c
}
Expand Down Expand Up @@ -339,7 +374,10 @@ func (db *DB) Export(filePath string, compress bool, encryptionKey string) error
// - compress: Optional. Compresses as gzip if true.
// - encryptionKey: Optional. Encrypts with AES-GCM if provided. Must be 32 bytes
// long if provided.
func (db *DB) ExportToFile(filePath string, compress bool, encryptionKey string) error {
// - collections: Optional. If provided, only the collections with the given names
// are exported. Non-existing collections are ignored.
// If not provided, all collections are exported.
func (db *DB) ExportToFile(filePath string, compress bool, encryptionKey string, collections ...string) error {
if filePath == "" {
filePath = "./chromem-go.gob"
if compress {
Expand Down Expand Up @@ -373,10 +411,12 @@ func (db *DB) ExportToFile(filePath string, compress bool, encryptionKey string)
defer db.collectionsLock.RUnlock()

for k, v := range db.collections {
persistenceDB.Collections[k] = &persistenceCollection{
Name: v.Name,
Metadata: v.metadata,
Documents: v.documents,
if len(collections) == 0 || slices.Contains(collections, k) {
persistenceDB.Collections[k] = &persistenceCollection{
Name: v.Name,
Metadata: v.metadata,
Documents: v.documents,
}
}
}

Expand All @@ -397,7 +437,10 @@ func (db *DB) ExportToFile(filePath string, compress bool, encryptionKey string)
// - compress: Optional. Compresses as gzip if true.
// - encryptionKey: Optional. Encrypts with AES-GCM if provided. Must be 32 bytes
// long if provided.
func (db *DB) ExportToWriter(writer io.Writer, compress bool, encryptionKey string) error {
// - collections: Optional. If provided, only the collections with the given names
// are exported. Non-existing collections are ignored.
// If not provided, all collections are exported.
func (db *DB) ExportToWriter(writer io.Writer, compress bool, encryptionKey string, collections ...string) error {
if encryptionKey != "" {
// AES 256 requires a 32 byte key
if len(encryptionKey) != 32 {
Expand All @@ -422,10 +465,12 @@ func (db *DB) ExportToWriter(writer io.Writer, compress bool, encryptionKey stri
defer db.collectionsLock.RUnlock()

for k, v := range db.collections {
persistenceDB.Collections[k] = &persistenceCollection{
Name: v.Name,
Metadata: v.metadata,
Documents: v.documents,
if len(collections) == 0 || slices.Contains(collections, k) {
persistenceDB.Collections[k] = &persistenceCollection{
Name: v.Name,
Metadata: v.metadata,
Documents: v.documents,
}
}
}

Expand Down
130 changes: 123 additions & 7 deletions db_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -116,10 +116,10 @@ func TestDB_ImportExport(t *testing.T) {
for _, tc := range tt {
t.Run(tc.name, func(t *testing.T) {
// Create DB, can just be in-memory
orig := NewDB()
origDB := NewDB()

// Create collection
c, err := orig.CreateCollection(name, metadata, embeddingFunc)
c, err := origDB.CreateCollection(name, metadata, embeddingFunc)
if err != nil {
t.Fatal("expected no error, got", err)
}
Expand All @@ -139,15 +139,15 @@ func TestDB_ImportExport(t *testing.T) {
}

// Export
err = orig.ExportToFile(tc.filePath, tc.compress, tc.encryptionKey)
err = origDB.ExportToFile(tc.filePath, tc.compress, tc.encryptionKey)
if err != nil {
t.Fatal("expected no error, got", err)
}

new := NewDB()
newDB := NewDB()

// Import
err = new.ImportFromFile(tc.filePath, tc.encryptionKey)
err = newDB.ImportFromFile(tc.filePath, tc.encryptionKey)
if err != nil {
t.Fatal("expected no error, got", err)
}
Expand All @@ -156,13 +156,129 @@ func TestDB_ImportExport(t *testing.T) {
// We have to reset the embed function, but otherwise the DB objects
// should be deep equal.
c.embed = nil
if !reflect.DeepEqual(orig, new) {
t.Fatalf("expected DB %+v, got %+v", orig, new)
if !reflect.DeepEqual(origDB, newDB) {
t.Fatalf("expected DB %+v, got %+v", origDB, newDB)
}
})
}
}

// TestDB_ImportExportSpecificCollections checks that exporting and importing
// can be restricted to selected collections.
//
// It covers three scenarios:
//  1. Exporting only one of two collections and importing the result into a
//     persistent DB, verifying that exactly that collection arrives and that
//     its documents are written to disk.
//  2. Exporting all collections and importing them into the same persistent
//     DB, overwriting the previously imported collection.
//  3. Importing only one named collection from a file that contains both.
func TestDB_ImportExportSpecificCollections(t *testing.T) {
	r := rand.New(rand.NewSource(rand.Int63()))
	randString := randomString(r, 10)
	path := filepath.Join(os.TempDir(), randString)
	filePath := path + ".gob"
	defer os.RemoveAll(path)
	// filePath is a sibling of path, not a child of it, so removing path does
	// not remove the exported file. Clean it up separately.
	defer os.Remove(filePath)

	// Values in the collection
	name := "test"
	name2 := "test2"
	metadata := map[string]string{"foo": "bar"}
	vectors := []float32{-0.40824828, 0.40824828, 0.81649655} // normalized version of `{-0.1, 0.1, 0.2}`
	embeddingFunc := func(_ context.Context, _ string) ([]float32, error) {
		return vectors, nil
	}

	// Create DB, can just be in-memory
	origDB := NewDB()

	// Create collections
	c, err := origDB.CreateCollection(name, metadata, embeddingFunc)
	if err != nil {
		t.Fatal("expected no error, got", err)
	}

	c2, err := origDB.CreateCollection(name2, metadata, embeddingFunc)
	if err != nil {
		t.Fatal("expected no error, got", err)
	}

	// Add documents
	doc := Document{
		ID:        name,
		Metadata:  metadata,
		Embedding: vectors,
		Content:   "test",
	}

	doc2 := Document{
		ID:        name2,
		Metadata:  metadata,
		Embedding: vectors,
		Content:   "test2",
	}

	err = c.AddDocument(context.Background(), doc)
	if err != nil {
		t.Fatal("expected no error, got", err)
	}

	err = c2.AddDocument(context.Background(), doc2)
	if err != nil {
		t.Fatal("expected no error, got", err)
	}

	// Export only one of the two collections
	err = origDB.ExportToFile(filePath, false, "", name2)
	if err != nil {
		t.Fatal("expected no error, got", err)
	}

	dir := filepath.Join(path, randomString(r, 10))
	defer os.RemoveAll(dir)

	// Import into a persistent DB rather than an in-memory one, so that the
	// test also covers the imported data being written to disk immediately.
	newPDB, err := NewPersistentDB(dir, false)
	if err != nil {
		t.Fatal("expected no error, got", err)
	}

	err = newPDB.ImportFromFile(filePath, "")
	if err != nil {
		t.Fatal("expected no error, got", err)
	}

	if len(newPDB.collections) != 1 {
		t.Fatalf("expected 1 collection, got %d", len(newPDB.collections))
	}
	// The count alone doesn't prove the export filter picked the right one.
	if _, ok := newPDB.collections[name2]; !ok {
		t.Fatalf("expected collection %q to be imported, but it's missing", name2)
	}

	// Make sure that the imported documents are actually persisted on disk
	for _, col := range newPDB.collections {
		for _, d := range col.documents {
			_, err = os.Stat(col.getDocPath(d.ID))
			if err != nil {
				t.Fatalf("expected no error when looking up persistent file for doc %q, got %v", d.ID, err)
			}
		}
	}

	// Now export both collections and import them into the same persistent DB (overwriting the one we just imported)
	filePath2 := filepath.Join(path, "2.gob")
	err = origDB.ExportToFile(filePath2, false, "")
	if err != nil {
		t.Fatal("expected no error, got", err)
	}

	err = newPDB.ImportFromFile(filePath2, "")
	if err != nil {
		t.Fatal("expected no error, got", err)
	}

	if len(newPDB.collections) != 2 {
		t.Fatalf("expected 2 collections, got %d", len(newPDB.collections))
	}

	// Make sure that the imported documents are actually persisted on disk
	for _, col := range newPDB.collections {
		for _, d := range col.documents {
			_, err = os.Stat(col.getDocPath(d.ID))
			if err != nil {
				t.Fatalf("expected no error when looking up persistent file for doc %q, got %v", d.ID, err)
			}
		}
	}

	// Finally, import only one named collection from the file that holds both,
	// to cover the collection filter on the import side as well.
	filteredDB := NewDB()
	err = filteredDB.ImportFromFile(filePath2, "", name)
	if err != nil {
		t.Fatal("expected no error, got", err)
	}

	if len(filteredDB.collections) != 1 {
		t.Fatalf("expected 1 collection, got %d", len(filteredDB.collections))
	}
	if _, ok := filteredDB.collections[name]; !ok {
		t.Fatalf("expected collection %q to be imported, but it's missing", name)
	}
}

func TestDB_CreateCollection(t *testing.T) {
// Values in the collection
name := "test"
Expand Down

0 comments on commit 935ec30

Please sign in to comment.