Skip to content

Commit

Permalink
Merge pull request #10 from philippgille/add-embedding-providers
Browse files Browse the repository at this point in the history
Add embedding providers
  • Loading branch information
philippgille authored Feb 10, 2024
2 parents b3f5f9b + deb598a commit 88957cb
Show file tree
Hide file tree
Showing 6 changed files with 80 additions and 14 deletions.
10 changes: 6 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -78,11 +78,13 @@ Initially, only a minimal subset of all of Chroma's interface is implemented or
- [X] Zero dependencies on third party libraries
- [X] Concurrent processing (when adding and querying documents)
- Embedding creators:
- [X] [OpenAI text-embedding-3-small](https://platform.openai.com/docs/guides/embeddings/embedding-models) (default)
- [X] [OpenAI](https://platform.openai.com/docs/guides/embeddings/embedding-models) (default)
- [X] [Mistral](https://docs.mistral.ai/platform/endpoints/#embedding-models)
- [X] [Jina](https://jina.ai/embeddings)
- [X] [mixedbread.ai](https://www.mixedbread.ai/)
- [X] [LocalAI](https://github.com/mudler/LocalAI)
- [X] Bring your own
- [ ] [Mistral (API)](https://docs.mistral.ai/api/#operation/createEmbedding)
- [ ] [ollama](https://ollama.ai/)
- [ ] [LocalAI](https://github.com/mudler/LocalAI)
- [ ] [ollama](https://ollama.ai/) (As of 2024-02-10 their OpenAI compatible API doesn't support embeddings yet, but they have a custom API which does)
- Similarity search:
- [X] Exact nearest neighbor search using cosine similarity
- [ ] Approximate nearest neighbor search with index
Expand Down
18 changes: 18 additions & 0 deletions embed_jina.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
package chromem

const baseURLJina = "https://api.jina.ai/v1"

type EmbeddingModelJina string

const (
EmbeddingModelJina2BaseEN EmbeddingModelJina = "jina-embeddings-v2-base-en"
EmbeddingModelJina2BaseDE EmbeddingModelJina = "jina-embeddings-v2-base-de"
EmbeddingModelJina2BaseCode EmbeddingModelJina = "jina-embeddings-v2-base-code"
EmbeddingModelJina2BaseZH EmbeddingModelJina = "jina-embeddings-v2-base-zh"
)

// NewEmbeddingFuncJina returns a function that creates embeddings for a document
// using the Jina API.
func NewEmbeddingFuncJina(apiKey string, model EmbeddingModelJina) EmbeddingFunc {
return NewEmbeddingFuncOpenAICompat(baseURLJina, apiKey, string(model))
}
16 changes: 16 additions & 0 deletions embed_localai.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
package chromem

const baseURLLocalAI = "http://localhost:8080/v1"

// NewEmbeddingFuncLocalAI returns a function that creates embeddings for a document
// using the LocalAI API.
// You can start a LocalAI instance like this:
//
// docker run -it -p 127.0.0.1:8080:8080 localai/localai:v2.7.0-ffmpeg-core bert-cpp
//
// And then call this constructor with model "bert-cpp-minilm-v6".
// But other embedding models are supported as well. See the LocalAI documentation
// for details.
func NewEmbeddingFuncLocalAI(model string) EmbeddingFunc {
return NewEmbeddingFuncOpenAICompat(baseURLLocalAI, "", model)
}
15 changes: 15 additions & 0 deletions embed_mistral.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
package chromem

const (
baseURLMistral = "https://api.mistral.ai/v1"
// Currently there's only one. Let's turn this into a pseudo-enum as soon as there are more.
embeddingModelMistral = "mistral-embed"
)

// NewEmbeddingFuncMistral returns a function that creates embeddings for a document
// using the Mistral API.
func NewEmbeddingFuncMistral(apiKey string) EmbeddingFunc {
// The Mistral API docs don't mention the `encoding_format` as optional,
// but it seems to be, just like OpenAI. So we reuse the OpenAI function.
return NewEmbeddingFuncOpenAICompat(baseURLMistral, apiKey, embeddingModelMistral)
}
22 changes: 22 additions & 0 deletions embed_mixedbread.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
package chromem

const baseURLMixedbread = "https://api.mixedbread.ai"

type EmbeddingModelMixedbread string

const (
EmbeddingModelMixedbreadUAELargeV1 EmbeddingModelMixedbread = "UAE-Large-V1"
EmbeddingModelMixedbreadBGELargeENV15 EmbeddingModelMixedbread = "bge-large-en-v1.5"
EmbeddingModelMixedbreadGTELarge EmbeddingModelMixedbread = "gte-large"
EmbeddingModelMixedbreadE5LargeV2 EmbeddingModelMixedbread = "e5-large-v2"
EmbeddingModelMixedbreadMultilingualE5Large EmbeddingModelMixedbread = "multilingual-e5-large"
EmbeddingModelMixedbreadMultilingualE5Base EmbeddingModelMixedbread = "multilingual-e5-base"
EmbeddingModelMixedbreadAllMiniLML6V2 EmbeddingModelMixedbread = "all-MiniLM-L6-v2"
EmbeddingModelMixedbreadGTELargeZh EmbeddingModelMixedbread = "gte-large-zh"
)

// NewEmbeddingFuncMixedbread returns a function that creates embeddings for a document
// using the mixedbread.ai API.
func NewEmbeddingFuncMixedbread(apiKey string, model EmbeddingModelMixedbread) EmbeddingFunc {
return NewEmbeddingFuncOpenAICompat(baseURLMixedbread, apiKey, string(model))
}
13 changes: 3 additions & 10 deletions embedding.go → embed_openai.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,7 @@ import (
"os"
)

const (
BaseURLOpenAI = "https://api.openai.com/v1"
)
const BaseURLOpenAI = "https://api.openai.com/v1"

type EmbeddingModelOpenAI string

Expand Down Expand Up @@ -60,7 +58,7 @@ func NewEmbeddingFuncOpenAICompat(baseURL, apiKey, model string) EmbeddingFunc {
// Prepare the request body.
reqBody, err := json.Marshal(map[string]string{
"input": document,
"model": string(model),
"model": model,
})
if err != nil {
return nil, fmt.Errorf("couldn't marshal request body: %w", err)
Expand All @@ -84,12 +82,7 @@ func NewEmbeddingFuncOpenAICompat(baseURL, apiKey, model string) EmbeddingFunc {

// Check the response status.
if resp.StatusCode != http.StatusOK {
body, err := io.ReadAll(resp.Body)
if err != nil {
panic(err)
}
fmt.Println("========", string(body))
return nil, errors.New("error response from the OpenAI API: " + resp.Status)
return nil, errors.New("error response from the embedding API: " + resp.Status)
}

// Read and decode the response body.
Expand Down

0 comments on commit 88957cb

Please sign in to comment.