Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 27 additions & 0 deletions chat.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ import (
"encoding/json"
"errors"
"net/http"

"github.com/sashabaranov/go-openai/jsonschema"
)

// Chat message role defined by the OpenAI API.
Expand Down Expand Up @@ -221,6 +223,31 @@ type ChatCompletionResponseFormatJSONSchema struct {
Strict bool `json:"strict"`
}

func (r *ChatCompletionResponseFormatJSONSchema) UnmarshalJSON(data []byte) error {
type rawJSONSchema struct {
Name string `json:"name"`
Description string `json:"description,omitempty"`
Schema json.RawMessage `json:"schema"`
Strict bool `json:"strict"`
}
var raw rawJSONSchema
if err := json.Unmarshal(data, &raw); err != nil {
return err
}
r.Name = raw.Name
r.Description = raw.Description
r.Strict = raw.Strict
if len(raw.Schema) > 0 && string(raw.Schema) != "null" {
var d jsonschema.Definition
err := json.Unmarshal(raw.Schema, &d)
if err != nil {
return err
}
r.Schema = &d
}
return nil
}

// ChatCompletionRequest represents a request structure for chat completion API.
type ChatCompletionRequest struct {
Model string `json:"model"`
Expand Down
139 changes: 139 additions & 0 deletions chat_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -946,3 +946,142 @@ func TestFinishReason(t *testing.T) {
}
}
}

func TestChatCompletionResponseFormatJSONSchema_UnmarshalJSON(t *testing.T) {
type args struct {
data []byte
}
tests := []struct {
name string
args args
wantErr bool
}{
{
"",
args{
data: []byte(`{
"name": "math_response",
"strict": true,
"schema": {
"type": "object",
"properties": {
"steps": {
"type": "array",
"items": {
"type": "object",
"properties": {
"explanation": { "type": "string" },
"output": { "type": "string" }
},
"required": ["explanation","output"],
"additionalProperties": false
}
},
"final_answer": { "type": "string" }
},
"required": ["steps","final_answer"],
"additionalProperties": false
}
}`),
},
false,
},
{
"",
args{
data: []byte(`{
"name": "math_response",
"strict": true,
"schema": null
}`),
},
false,
},
{
"",
args{
data: []byte(`[123,456]`),
},
true,
},
{
"",
args{
data: []byte(`{
"name": "math_response",
"strict": true,
"schema": 123456
}`),
},
true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var r openai.ChatCompletionResponseFormatJSONSchema
err := r.UnmarshalJSON(tt.args.data)
if (err != nil) != tt.wantErr {
t.Errorf("UnmarshalJSON() error = %v, wantErr %v", err, tt.wantErr)
}
})
}
}

func TestChatCompletionRequest_UnmarshalJSON(t *testing.T) {
type args struct {
bs []byte
}
tests := []struct {
name string
args args
wantErr bool
}{
{
"",
args{bs: []byte(`{
"model": "llama3-1b",
"messages": [
{ "role": "system", "content": "You are a helpful math tutor." },
{ "role": "user", "content": "solve 8x + 31 = 2" }
],
"response_format": {
"type": "json_schema",
"json_schema": {
"name": "math_response",
"strict": true,
"schema": {
"type": "object",
"properties": {
"steps": {
"type": "array",
"items": {
"type": "object",
"properties": {
"explanation": { "type": "string" },
"output": { "type": "string" }
},
"required": ["explanation","output"],
"additionalProperties": false
}
},
"final_answer": { "type": "string" }
},
"required": ["steps","final_answer"],
"additionalProperties": false
}
}
}
}`)},
false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var m openai.ChatCompletionRequest
err := json.Unmarshal(tt.args.bs, &m)
if err != nil {
t.Errorf("UnmarshalJSON() error = %v, wantErr %v", err, tt.wantErr)
}
})
}
}
Loading