Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support for extra_body parameter for embeddings API #906

Open
wants to merge 10 commits into
base: master
Choose a base branch
from
14 changes: 14 additions & 0 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,20 @@ func withBody(body any) requestOption {
}
}

func withExtraBody(extraBody map[string]any) requestOption {
return func(args *requestOptions) {
// Assert that args.body is a map[string]any.
bodyMap, ok := args.body.(map[string]any)
if ok {
// If it's a map[string]any then only add extraBody
// fields to args.body otherwise keep only fields in request struct.
for key, value := range extraBody {
bodyMap[key] = value
}
}
}
}

func withContentType(contentType string) requestOption {
return func(args *requestOptions) {
args.header.Set("Content-Type", contentType)
Expand Down
31 changes: 30 additions & 1 deletion embeddings.go
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,9 @@ type EmbeddingRequest struct {
// Dimensions The number of dimensions the resulting output embeddings should have.
// Only supported in text-embedding-3 and later models.
Dimensions int `json:"dimensions,omitempty"`
// The ExtraBody field allows for the inclusion of arbitrary key-value pairs
// in the request body that may not be explicitly defined in this struct.
ExtraBody map[string]any `json:"extra_body,omitempty"`
}

func (r EmbeddingRequest) Convert() EmbeddingRequest {
Expand Down Expand Up @@ -187,6 +190,9 @@ type EmbeddingRequestStrings struct {
// Dimensions The number of dimensions the resulting output embeddings should have.
// Only supported in text-embedding-3 and later models.
Dimensions int `json:"dimensions,omitempty"`
// The ExtraBody field allows for the inclusion of arbitrary key-value pairs
// in the request body that may not be explicitly defined in this struct.
ExtraBody map[string]any `json:"extra_body,omitempty"`
}

func (r EmbeddingRequestStrings) Convert() EmbeddingRequest {
Expand All @@ -196,6 +202,7 @@ func (r EmbeddingRequestStrings) Convert() EmbeddingRequest {
User: r.User,
EncodingFormat: r.EncodingFormat,
Dimensions: r.Dimensions,
ExtraBody: r.ExtraBody,
}
}

Expand All @@ -219,6 +226,9 @@ type EmbeddingRequestTokens struct {
// Dimensions The number of dimensions the resulting output embeddings should have.
// Only supported in text-embedding-3 and later models.
Dimensions int `json:"dimensions,omitempty"`
// The ExtraBody field allows for the inclusion of arbitrary key-value pairs
// in the request body that may not be explicitly defined in this struct.
ExtraBody map[string]any `json:"extra_body,omitempty"`
}

func (r EmbeddingRequestTokens) Convert() EmbeddingRequest {
Expand All @@ -228,6 +238,7 @@ func (r EmbeddingRequestTokens) Convert() EmbeddingRequest {
User: r.User,
EncodingFormat: r.EncodingFormat,
Dimensions: r.Dimensions,
ExtraBody: r.ExtraBody,
}
}

Expand All @@ -241,11 +252,29 @@ func (c *Client) CreateEmbeddings(
conv EmbeddingRequestConverter,
) (res EmbeddingResponse, err error) {
baseReq := conv.Convert()

// Prepare the body with only the provided fields.
// The body map is used to dynamically construct the request payload for the embedding API.
// Instead of relying on a fixed struct, the body map allows for flexible inclusion of fields
// based on their presence, avoiding unnecessary or empty fields in the request.
body := make(map[string]any)
body["input"] = baseReq.Input
body["model"] = baseReq.Model
if baseReq.User != "" {
body["user"] = baseReq.User
}
if baseReq.EncodingFormat != "" {
body["encoding_format"] = baseReq.EncodingFormat
}
if baseReq.Dimensions > 0 { // Assuming 0 means the field is not set
body["dimensions"] = baseReq.Dimensions
}
sashabaranov marked this conversation as resolved.
Show resolved Hide resolved
req, err := c.newRequest(
ctx,
http.MethodPost,
c.fullURL("/embeddings", withModel(string(baseReq.Model))),
withBody(baseReq),
withBody(body), // Main request body.
withExtraBody(baseReq.ExtraBody), // Merge ExtraBody fields.
)
if err != nil {
return
Expand Down
34 changes: 34 additions & 0 deletions embeddings_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,24 @@ func TestEmbedding(t *testing.T) {
t.Fatalf("Expected embedding request to contain model field")
}

// test embedding request with strings and extra_body param
embeddingReqWithExtraBody := openai.EmbeddingRequest{
Input: []string{
"The food was delicious and the waiter",
"Other examples of embedding request",
},
Model: model,
ExtraBody: map[string]any{
"input_type": "query",
"truncate": "NONE",
},
}
marshaled, err = json.Marshal(embeddingReqWithExtraBody)
checks.NoError(t, err, "Could not marshal embedding request")
if !bytes.Contains(marshaled, []byte(`"model":"`+model+`"`)) {
t.Fatalf("Expected embedding request to contain model field")
}

// test embedding request with strings
embeddingReqStrings := openai.EmbeddingRequestStrings{
Input: []string{
Expand Down Expand Up @@ -124,6 +142,22 @@ func TestEmbeddingEndpoint(t *testing.T) {
t.Errorf("Expected %#v embeddings, got %#v", sampleEmbeddings, res.Data)
}

// test create embeddings with strings (ExtraBody in request)
res, err = client.CreateEmbeddings(
context.Background(),
openai.EmbeddingRequest{
ExtraBody: map[string]any{
"input_type": "query",
"truncate": "NONE",
},
Dimensions: 1,
},
)
checks.NoError(t, err, "CreateEmbeddings error")
if !reflect.DeepEqual(res.Data, sampleEmbeddings) {
t.Errorf("Expected %#v embeddings, got %#v", sampleEmbeddings, res.Data)
}

// test create embeddings with strings (simple embedding request)
res, err = client.CreateEmbeddings(
context.Background(),
Expand Down
Loading