diff --git a/image.go b/image.go index 72077ce41..84b9daf02 100644 --- a/image.go +++ b/image.go @@ -132,7 +132,32 @@ func (c *Client) CreateImage(ctx context.Context, request ImageRequest) (respons return } +// WrapReader wraps an io.Reader with filename and Content-type. +func WrapReader(rdr io.Reader, filename string, contentType string) io.Reader { + return file{rdr, filename, contentType} +} + +type file struct { + io.Reader + name string + contentType string +} + +func (f file) Name() string { + if f.name != "" { + return f.name + } else if named, ok := f.Reader.(interface{ Name() string }); ok { + return named.Name() + } + return "" +} + +func (f file) ContentType() string { + return f.contentType +} + // ImageEditRequest represents the request structure for the image API. +// Use WrapReader to wrap an io.Reader with filename and Content-type. type ImageEditRequest struct { Image io.Reader `json:"image,omitempty"` Mask io.Reader `json:"mask,omitempty"` @@ -150,7 +175,7 @@ func (c *Client) CreateEditImage(ctx context.Context, request ImageEditRequest) body := &bytes.Buffer{} builder := c.createFormBuilder(body) - // image, filename is not required + // image, filename verification can be postponed err = builder.CreateFormFileReader("image", request.Image, "") if err != nil { return @@ -158,7 +183,7 @@ func (c *Client) CreateEditImage(ctx context.Context, request ImageEditRequest) // mask, it is optional if request.Mask != nil { - // mask, filename is not required + // filename verification can be postponed err = builder.CreateFormFileReader("mask", request.Mask, "") if err != nil { return @@ -206,6 +231,7 @@ func (c *Client) CreateEditImage(ctx context.Context, request ImageEditRequest) } // ImageVariRequest represents the request structure for the image API. +// Use WrapReader to wrap an io.Reader with filename and Content-type. type ImageVariRequest struct { Image io.Reader `json:"image,omitempty"` Model string `json:"model,omitempty"` @@ -221,7 +247,7 @@ func (c *Client) CreateVariImage(ctx context.Context, request ImageVariRequest) body := &bytes.Buffer{} builder := c.createFormBuilder(body) - // image, filename is not required + // image, filename verification can be postponed err = builder.CreateFormFileReader("image", request.Image, "") if err != nil { return diff --git a/internal/form_builder.go b/internal/form_builder.go index 1c2513dd9..5b382df20 100644 --- a/internal/form_builder.go +++ b/internal/form_builder.go @@ -39,9 +39,18 @@ func escapeQuotes(s string) string { } // CreateFormFileReader creates a form field with a file reader. -// The filename in parameters can be an empty string. -// The filename in Content-Disposition is required, But it can be an empty string. +// The filename in Content-Disposition is required. func (fb *DefaultFormBuilder) CreateFormFileReader(fieldname string, r io.Reader, filename string) error { + if filename == "" { + if f, ok := r.(interface{ Name() string }); ok { + filename = f.Name() + } + } + var contentType string + if f, ok := r.(interface{ ContentType() string }); ok { + contentType = f.ContentType() + } + h := make(textproto.MIMEHeader) h.Set( "Content-Disposition", @@ -51,6 +60,10 @@ func (fb *DefaultFormBuilder) CreateFormFileReader(fieldname string, r io.Reader escapeQuotes(filepath.Base(filename)), ), ) + // content type is optional, but it can be set + if contentType != "" { + h.Set("Content-Type", contentType) + } fieldWriter, err := fb.writer.CreatePart(h) if err != nil { diff --git a/internal/form_builder_test.go b/internal/form_builder_test.go index 76922c1ba..f4958ad5e 100644 --- a/internal/form_builder_test.go +++ b/internal/form_builder_test.go @@ -1,6 +1,8 @@ package openai //nolint:testpackage // testing private field import ( + "io" + "github.com/sashabaranov/go-openai/internal/test/checks" "bytes" @@ -53,6 +55,18 @@ func (*failingReader) Read([]byte) (int, error) { return 0, errMockFailingReaderError } +type readerWithNameAndContentType struct { + io.Reader +} + +func (*readerWithNameAndContentType) Name() string { + return "" +} + +func (*readerWithNameAndContentType) ContentType() string { + return "image/png" +} + func TestFormBuilderWithReader(t *testing.T) { file, err := os.CreateTemp(t.TempDir(), "") if err != nil { @@ -71,4 +85,8 @@ func TestFormBuilderWithReader(t *testing.T) { successReader := &bytes.Buffer{} err = builder.CreateFormFileReader("file", successReader, "") checks.NoError(t, err, "formbuilder should not return error") + + rnc := &readerWithNameAndContentType{Reader: &bytes.Buffer{}} + err = builder.CreateFormFileReader("file", rnc, "") + checks.NoError(t, err, "formbuilder should not return error") }