Skip to content
14 changes: 14 additions & 0 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,20 @@ func withBody(body any) requestOption {
}
}

func withExtraBody(extraBody map[string]any) requestOption {
return func(args *requestOptions) {
// Assert that args.body is a map[string]any.
bodyMap, ok := args.body.(map[string]any)
if ok {
// If it's a map[string]any then only add extraBody
// fields to args.body otherwise keep only fields in request struct.
for key, value := range extraBody {
bodyMap[key] = value
}
}
}
}

func withContentType(contentType string) requestOption {
return func(args *requestOptions) {
args.header.Set("Content-Type", contentType)
Expand Down
32 changes: 31 additions & 1 deletion embeddings.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"encoding/base64"
"encoding/binary"
"encoding/json"
"errors"
"math"
"net/http"
Expand Down Expand Up @@ -160,6 +161,9 @@ type EmbeddingRequest struct {
// Dimensions The number of dimensions the resulting output embeddings should have.
// Only supported in text-embedding-3 and later models.
Dimensions int `json:"dimensions,omitempty"`
// The ExtraBody field allows for the inclusion of arbitrary key-value pairs
// in the request body that may not be explicitly defined in this struct.
ExtraBody map[string]any `json:"extra_body,omitempty"`
}

func (r EmbeddingRequest) Convert() EmbeddingRequest {
Expand Down Expand Up @@ -187,6 +191,9 @@ type EmbeddingRequestStrings struct {
// Dimensions The number of dimensions the resulting output embeddings should have.
// Only supported in text-embedding-3 and later models.
Dimensions int `json:"dimensions,omitempty"`
// The ExtraBody field allows for the inclusion of arbitrary key-value pairs
// in the request body that may not be explicitly defined in this struct.
ExtraBody map[string]any `json:"extra_body,omitempty"`
}

func (r EmbeddingRequestStrings) Convert() EmbeddingRequest {
Expand All @@ -196,6 +203,7 @@ func (r EmbeddingRequestStrings) Convert() EmbeddingRequest {
User: r.User,
EncodingFormat: r.EncodingFormat,
Dimensions: r.Dimensions,
ExtraBody: r.ExtraBody,
}
}

Expand All @@ -219,6 +227,9 @@ type EmbeddingRequestTokens struct {
// Dimensions The number of dimensions the resulting output embeddings should have.
// Only supported in text-embedding-3 and later models.
Dimensions int `json:"dimensions,omitempty"`
// The ExtraBody field allows for the inclusion of arbitrary key-value pairs
// in the request body that may not be explicitly defined in this struct.
ExtraBody map[string]any `json:"extra_body,omitempty"`
}

func (r EmbeddingRequestTokens) Convert() EmbeddingRequest {
Expand All @@ -228,6 +239,7 @@ func (r EmbeddingRequestTokens) Convert() EmbeddingRequest {
User: r.User,
EncodingFormat: r.EncodingFormat,
Dimensions: r.Dimensions,
ExtraBody: r.ExtraBody,
}
}

Expand All @@ -241,11 +253,29 @@ func (c *Client) CreateEmbeddings(
conv EmbeddingRequestConverter,
) (res EmbeddingResponse, err error) {
baseReq := conv.Convert()

// The body map is used to dynamically construct the request payload for the embedding API.
// Instead of relying on a fixed struct, the body map allows for flexible inclusion of fields
// based on their presence, avoiding unnecessary or empty fields in the request.
extraBody := baseReq.ExtraBody
baseReq.ExtraBody = nil

// Serialize baseReq to JSON
jsonData, err := json.Marshal(baseReq)
if err != nil {
return
}

// Deserialize JSON to map[string]any
var body map[string]any
_ = json.Unmarshal(jsonData, &body)
Copy link

Copilot AI Jul 11, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Avoid ignoring the error returned by json.Unmarshal. Capture and handle the error (e.g., if err := json.Unmarshal(...); err != nil { return res, err }) to prevent silent failures.

Suggested change
_ = json.Unmarshal(jsonData, &body)
if err := json.Unmarshal(jsonData, &body); err != nil {
return res, err
}

Copilot uses AI. Check for mistakes.
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@AyushSawant18588 just some tech debt :)


req, err := c.newRequest(
ctx,
http.MethodPost,
c.fullURL("/embeddings", withModel(string(baseReq.Model))),
withBody(baseReq),
withBody(body), // Main request body.
withExtraBody(extraBody), // Merge ExtraBody fields.
)
if err != nil {
return
Expand Down
46 changes: 45 additions & 1 deletion embeddings_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,24 @@ func TestEmbedding(t *testing.T) {
t.Fatalf("Expected embedding request to contain model field")
}

// test embedding request with strings and extra_body param
embeddingReqWithExtraBody := openai.EmbeddingRequest{
Input: []string{
"The food was delicious and the waiter",
"Other examples of embedding request",
},
Model: model,
ExtraBody: map[string]any{
"input_type": "query",
"truncate": "NONE",
},
}
marshaled, err = json.Marshal(embeddingReqWithExtraBody)
checks.NoError(t, err, "Could not marshal embedding request")
if !bytes.Contains(marshaled, []byte(`"model":"`+model+`"`)) {
t.Fatalf("Expected embedding request to contain model field")
}

// test embedding request with strings
embeddingReqStrings := openai.EmbeddingRequestStrings{
Input: []string{
Expand Down Expand Up @@ -124,7 +142,33 @@ func TestEmbeddingEndpoint(t *testing.T) {
t.Errorf("Expected %#v embeddings, got %#v", sampleEmbeddings, res.Data)
}

// test create embeddings with strings (simple embedding request)
// test create embeddings with strings (ExtraBody in request)
res, err = client.CreateEmbeddings(
context.Background(),
openai.EmbeddingRequest{
ExtraBody: map[string]any{
"input_type": "query",
"truncate": "NONE",
},
Dimensions: 1,
},
)
checks.NoError(t, err, "CreateEmbeddings error")
if !reflect.DeepEqual(res.Data, sampleEmbeddings) {
t.Errorf("Expected %#v embeddings, got %#v", sampleEmbeddings, res.Data)
}

// test create embeddings with strings (ExtraBody in request and )
_, err = client.CreateEmbeddings(
context.Background(),
openai.EmbeddingRequest{
Input: make(chan int), // Channels are not serializable
Model: "example_model",
},
)
checks.HasError(t, err, "CreateEmbeddings error")

// test failed (Serialize JSON error)
res, err = client.CreateEmbeddings(
context.Background(),
openai.EmbeddingRequest{
Expand Down
Loading