Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add: option to only import/export selected collections to/from a DB #88

Merged
merged 4 commits into from
Jul 3, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 24 additions & 17 deletions collection.go
Original file line number Diff line number Diff line change
Expand Up @@ -111,23 +111,7 @@ func newCollection(name string, metadata map[string]string, embed EmbeddingFunc,
safeName := hash2hex(name)
c.persistDirectory = filepath.Join(dbDir, safeName)
c.compress = compress
// Persist name and metadata
metadataPath := filepath.Join(c.persistDirectory, metadataFileName)
metadataPath += ".gob"
if c.compress {
metadataPath += ".gz"
}
pc := struct {
Name string
Metadata map[string]string
}{
Name: name,
Metadata: m,
}
err := persistToFile(metadataPath, pc, compress, "")
if err != nil {
return nil, fmt.Errorf("couldn't persist collection metadata: %w", err)
}
return c, c.persistMetadata()
}

return c, nil
Expand Down Expand Up @@ -545,3 +529,26 @@ func (c *Collection) getDocPath(docID string) string {
}
return docPath
}

// persistMetadata writes the collection's name and metadata to its metadata
// file on disk.
//
// The target is metadataFileName inside c.persistDirectory with a ".gob"
// suffix, followed by ".gz" when c.compress is set. The error returned by
// persistToFile, if any, is passed back to the caller unchanged.
func (c *Collection) persistMetadata() error {
	// Assemble the suffix up front so the path is built in a single expression.
	suffix := ".gob"
	if c.compress {
		suffix += ".gz"
	}
	target := filepath.Join(c.persistDirectory, metadataFileName) + suffix

	// Only the name and the metadata map are written here.
	payload := struct {
		Name     string
		Metadata map[string]string
	}{
		Name:     c.Name,
		Metadata: c.metadata,
	}

	return persistToFile(target, payload, c.compress, "")
}
73 changes: 57 additions & 16 deletions db.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"io/fs"
"os"
"path/filepath"
"slices"
"strings"
"sync"
)
Expand Down Expand Up @@ -198,9 +199,11 @@ func (db *DB) Import(filePath string, encryptionKey string) error {
// This works for both the in-memory and persistent DBs.
// Existing collections are overwritten.
//
// - filePath: Mandatory, must not be empty
// - encryptionKey: Optional, must be 32 bytes long if provided
func (db *DB) ImportFromFile(filePath string, encryptionKey string) error {
// - filePath: Mandatory, must not be empty
// - encryptionKey: Optional, must be 32 bytes long if provided
// - collections: Optional. If provided, only the collections with the given names
// are imported. If not provided, all collections are imported.
iwilltry42 marked this conversation as resolved.
Show resolved Hide resolved
func (db *DB) ImportFromFile(filePath string, encryptionKey string, collections ...string) error {
if filePath == "" {
return fmt.Errorf("file path is empty")
}
Expand Down Expand Up @@ -244,6 +247,9 @@ func (db *DB) ImportFromFile(filePath string, encryptionKey string) error {
}

for _, pc := range persistenceDB.Collections {
if len(collections) > 0 && !slices.Contains(collections, pc.Name) {
continue
}
c := &Collection{
Name: pc.Name,

Expand All @@ -253,6 +259,17 @@ func (db *DB) ImportFromFile(filePath string, encryptionKey string) error {
if db.persistDirectory != "" {
c.persistDirectory = filepath.Join(db.persistDirectory, hash2hex(pc.Name))
c.compress = db.compress
err = c.persistMetadata()
if err != nil {
return fmt.Errorf("couldn't persist collection metadata: %w", err)
}
for _, doc := range c.documents {
docPath := c.getDocPath(doc.ID)
err = persistToFile(docPath, doc, c.compress, "")
if err != nil {
return fmt.Errorf("couldn't persist document to %q: %w", docPath, err)
}
}
Comment on lines +263 to +273
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Right, good point! 👍

If you want you can split this into two PRs, one for the fix/improvement, and one for the new feature. But up to you.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If you don't have any other change requests apart from the now-resolved ones, I'd leave this in one PR, if you don't mind :)

}
db.collections[c.Name] = c
}
Expand All @@ -267,9 +284,11 @@ func (db *DB) ImportFromFile(filePath string, encryptionKey string) error {
// Existing collections are overwritten.
// If the reader has to be closed, it's the caller's responsibility.
//
// - reader: An implementation of [io.ReadSeeker]
// - encryptionKey: Optional, must be 32 bytes long if provided
func (db *DB) ImportFromReader(reader io.ReadSeeker, encryptionKey string) error {
// - reader: An implementation of [io.ReadSeeker]
// - encryptionKey: Optional, must be 32 bytes long if provided
// - collections: Optional. If provided, only the collections with the given names
// are imported. If not provided, all collections are imported.
iwilltry42 marked this conversation as resolved.
Show resolved Hide resolved
func (db *DB) ImportFromReader(reader io.ReadSeeker, encryptionKey string, collections ...string) error {
if encryptionKey != "" {
// AES 256 requires a 32 byte key
if len(encryptionKey) != 32 {
Expand Down Expand Up @@ -299,6 +318,9 @@ func (db *DB) ImportFromReader(reader io.ReadSeeker, encryptionKey string) error
}

for _, pc := range persistenceDB.Collections {
if len(collections) > 0 && !slices.Contains(collections, pc.Name) {
continue
}
c := &Collection{
Name: pc.Name,

Expand All @@ -308,6 +330,17 @@ func (db *DB) ImportFromReader(reader io.ReadSeeker, encryptionKey string) error
if db.persistDirectory != "" {
c.persistDirectory = filepath.Join(db.persistDirectory, hash2hex(pc.Name))
c.compress = db.compress
err = c.persistMetadata()
if err != nil {
return fmt.Errorf("couldn't persist collection metadata: %w", err)
}
for _, doc := range c.documents {
docPath := c.getDocPath(doc.ID)
err := persistToFile(docPath, doc, c.compress, "")
if err != nil {
return fmt.Errorf("couldn't persist document to %q: %w", docPath, err)
}
}
iwilltry42 marked this conversation as resolved.
Show resolved Hide resolved
}
db.collections[c.Name] = c
}
Expand Down Expand Up @@ -339,7 +372,9 @@ func (db *DB) Export(filePath string, compress bool, encryptionKey string) error
// - compress: Optional. Compresses as gzip if true.
// - encryptionKey: Optional. Encrypts with AES-GCM if provided. Must be 32 bytes
// long if provided.
func (db *DB) ExportToFile(filePath string, compress bool, encryptionKey string) error {
// - collections: Optional. If provided, only the collections with the given names
// are exported. If not provided, all collections are exported.
func (db *DB) ExportToFile(filePath string, compress bool, encryptionKey string, collections ...string) error {
if filePath == "" {
filePath = "./chromem-go.gob"
if compress {
Expand Down Expand Up @@ -373,10 +408,12 @@ func (db *DB) ExportToFile(filePath string, compress bool, encryptionKey string)
defer db.collectionsLock.RUnlock()

for k, v := range db.collections {
persistenceDB.Collections[k] = &persistenceCollection{
Name: v.Name,
Metadata: v.metadata,
Documents: v.documents,
if len(collections) == 0 || slices.Contains(collections, k) {
persistenceDB.Collections[k] = &persistenceCollection{
Name: v.Name,
Metadata: v.metadata,
Documents: v.documents,
}
}
}

Expand All @@ -397,7 +434,9 @@ func (db *DB) ExportToFile(filePath string, compress bool, encryptionKey string)
// - compress: Optional. Compresses as gzip if true.
// - encryptionKey: Optional. Encrypts with AES-GCM if provided. Must be 32 bytes
// long if provided.
func (db *DB) ExportToWriter(writer io.Writer, compress bool, encryptionKey string) error {
// - collections: Optional. If provided, only the collections with the given names
// are exported. If not provided, all collections are exported.
func (db *DB) ExportToWriter(writer io.Writer, compress bool, encryptionKey string, collections ...string) error {
if encryptionKey != "" {
// AES 256 requires a 32 byte key
if len(encryptionKey) != 32 {
Expand All @@ -422,10 +461,12 @@ func (db *DB) ExportToWriter(writer io.Writer, compress bool, encryptionKey stri
defer db.collectionsLock.RUnlock()

for k, v := range db.collections {
persistenceDB.Collections[k] = &persistenceCollection{
Name: v.Name,
Metadata: v.metadata,
Documents: v.documents,
if len(collections) == 0 || slices.Contains(collections, k) {
persistenceDB.Collections[k] = &persistenceCollection{
Name: v.Name,
Metadata: v.metadata,
Documents: v.documents,
}
}
}

Expand Down
123 changes: 119 additions & 4 deletions db_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -144,10 +144,10 @@ func TestDB_ImportExport(t *testing.T) {
t.Fatal("expected no error, got", err)
}

new := NewDB()
newDB := NewDB()

// Import
err = new.ImportFromFile(tc.filePath, tc.encryptionKey)
err = newDB.ImportFromFile(tc.filePath, tc.encryptionKey)
if err != nil {
t.Fatal("expected no error, got", err)
}
Expand All @@ -156,13 +156,128 @@ func TestDB_ImportExport(t *testing.T) {
// We have to reset the embed function, but otherwise the DB objects
// should be deep equal.
c.embed = nil
if !reflect.DeepEqual(orig, new) {
t.Fatalf("expected DB %+v, got %+v", orig, new)
if !reflect.DeepEqual(orig, newDB) {
t.Fatalf("expected DB %+v, got %+v", orig, newDB)
iwilltry42 marked this conversation as resolved.
Show resolved Hide resolved
}
})
}
}

func TestDB_ImportExportSpecificCollections(t *testing.T) {
r := rand.New(rand.NewSource(rand.Int63()))
randString := randomString(r, 10)
path := filepath.Join(os.TempDir(), randString)
filePath := path + ".gob"
defer os.RemoveAll(path)

// Values in the collection
name := "test"
name2 := "test2"
metadata := map[string]string{"foo": "bar"}
vectors := []float32{-0.40824828, 0.40824828, 0.81649655} // normalized version of `{-0.1, 0.1, 0.2}`
embeddingFunc := func(_ context.Context, _ string) ([]float32, error) {
return vectors, nil
}

// Create DB, can just be in-memory
orig := NewDB()

// Create collections
c, err := orig.CreateCollection(name, metadata, embeddingFunc)
if err != nil {
t.Fatal("expected no error, got", err)
}

c2, err := orig.CreateCollection(name2, metadata, embeddingFunc)
if err != nil {
t.Fatal("expected no error, got", err)
}

// Add documents
doc := Document{
ID: name,
Metadata: metadata,
Embedding: vectors,
Content: "test",
}

doc2 := Document{
ID: name2,
Metadata: metadata,
Embedding: vectors,
Content: "test2",
}

err = c.AddDocument(context.Background(), doc)
if err != nil {
t.Fatal("expected no error, got", err)
}

err = c2.AddDocument(context.Background(), doc2)
if err != nil {
t.Fatal("expected no error, got", err)
}

// Export
iwilltry42 marked this conversation as resolved.
Show resolved Hide resolved
err = orig.ExportToFile(filePath, false, "", name2)
if err != nil {
t.Fatal("expected no error, got", err)
}

dir := filepath.Join(path, randomString(r, 10))
defer os.RemoveAll(dir)

newPDB, err := NewPersistentDB(dir, false)
iwilltry42 marked this conversation as resolved.
Show resolved Hide resolved
if err != nil {
t.Fatal("expected no error, got", err)
}

err = newPDB.ImportFromFile(filePath, "")
if err != nil {
t.Fatal("expected no error, got", err)
}

if len(newPDB.collections) != 1 {
t.Fatalf("expected 1 collection, got %d", len(newPDB.collections))
}

// Make sure that the imported documents are actually persisted on disk
for _, col := range newPDB.collections {
for _, d := range col.documents {
_, err = os.Stat(col.getDocPath(d.ID))
if err != nil {
t.Fatalf("expected no error when looking up persistent file for doc %q, got %v", d.ID, err)
}
}
}

// Now export both collections and import them into the same persistent DB (overwriting the one we just imported)
filePath2 := path + "2.gob"
iwilltry42 marked this conversation as resolved.
Show resolved Hide resolved
err = orig.ExportToFile(filePath2, false, "")
if err != nil {
t.Fatal("expected no error, got", err)
}

err = newPDB.ImportFromFile(filePath2, "")
if err != nil {
t.Fatal("expected no error, got", err)
}

if len(newPDB.collections) != 2 {
t.Fatalf("expected 2 collections, got %d", len(newPDB.collections))
}

// Make sure that the imported documents are actually persisted on disk
for _, col := range newPDB.collections {
for _, d := range col.documents {
_, err = os.Stat(col.getDocPath(d.ID))
if err != nil {
t.Fatalf("expected no error when looking up persistent file for doc %q, got %v", d.ID, err)
}
}
}
}

func TestDB_CreateCollection(t *testing.T) {
// Values in the collection
name := "test"
Expand Down