Skip to content

Commit 874b052

Browse files
committed
feat(inpainting): add inpainting endpoint with automatic model selection
- Add inpainting endpoint supporting OpenAI-compatible API - Implement automatic model selection for image generation endpoints - Add comprehensive tests for inpainting functionality - Update Swagger documentation for new endpoint - Wire ImageGenerationFunc to backend Signed-off-by: Greg <[email protected]> ci: re-trigger tests-apple workflow
1 parent 3a23244 commit 874b052

File tree

6 files changed

+429
-3
lines changed

6 files changed

+429
-3
lines changed

core/backend/image.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,3 +40,7 @@ func ImageGeneration(height, width, mode, step, seed int, positive_prompt, negat
4040

4141
return fn, nil
4242
}
43+
44+
// ImageGenerationFunc is a test-friendly indirection to call image generation logic.
45+
// Tests can override this variable to provide a stub implementation.
//
// It defaults to ImageGeneration. Because this is package-level state, a test
// that replaces it should restore the previous value when it finishes.
46+
var ImageGenerationFunc = ImageGeneration
Lines changed: 269 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,269 @@
1+
package openai
2+
3+
import (
4+
"encoding/base64"
5+
"encoding/json"
6+
"fmt"
7+
"io"
8+
"net/http"
9+
"net/url"
10+
"os"
11+
"path/filepath"
12+
"strconv"
13+
"time"
14+
15+
"github.com/google/uuid"
16+
"github.com/labstack/echo/v4"
17+
"github.com/rs/zerolog/log"
18+
19+
"github.com/mudler/LocalAI/core/backend"
20+
"github.com/mudler/LocalAI/core/config"
21+
"github.com/mudler/LocalAI/core/http/middleware"
22+
"github.com/mudler/LocalAI/core/schema"
23+
model "github.com/mudler/LocalAI/pkg/model"
24+
)
25+
26+
// InpaintingEndpoint handles POST /v1/images/inpainting
27+
//
28+
// Swagger / OpenAPI docstring (swaggo):
29+
// @Summary Image inpainting
30+
// @Description Perform image inpainting. Accepts multipart/form-data with `image` and `mask` files.
31+
// @Tags images
32+
// @Accept multipart/form-data
33+
// @Produce application/json
34+
// @Param model formData string true "Model identifier"
35+
// @Param prompt formData string true "Text prompt guiding the generation"
36+
// @Param steps formData int false "Number of inference steps (default 25)"
37+
// @Param image formData file true "Original image file"
38+
// @Param mask formData file true "Mask image file (white = area to inpaint)"
39+
// @Success 200 {object} schema.OpenAIResponse
40+
// @Failure 400 {object} map[string]string
41+
// @Failure 500 {object} map[string]string
42+
// @Router /v1/images/inpainting [post]
43+
func InpaintingEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) echo.HandlerFunc {
44+
return func(c echo.Context) error {
45+
// Parse basic form values
46+
modelName := c.FormValue("model")
47+
prompt := c.FormValue("prompt")
48+
stepsStr := c.FormValue("steps")
49+
50+
if modelName == "" || prompt == "" {
51+
log.Error().Msg("Inpainting Endpoint - missing model or prompt")
52+
return echo.ErrBadRequest
53+
}
54+
55+
// steps default
56+
steps := 25
57+
if stepsStr != "" {
58+
if v, err := strconv.Atoi(stepsStr); err == nil {
59+
steps = v
60+
}
61+
}
62+
63+
// Get uploaded files
64+
imageFile, err := c.FormFile("image")
65+
if err != nil {
66+
log.Error().Err(err).Msg("Inpainting Endpoint - missing image file")
67+
return echo.NewHTTPError(http.StatusBadRequest, "missing image file")
68+
}
69+
maskFile, err := c.FormFile("mask")
70+
if err != nil {
71+
log.Error().Err(err).Msg("Inpainting Endpoint - missing mask file")
72+
return echo.NewHTTPError(http.StatusBadRequest, "missing mask file")
73+
}
74+
75+
// Read files into memory (small files expected)
76+
imgSrc, err := imageFile.Open()
77+
if err != nil {
78+
return err
79+
}
80+
defer imgSrc.Close()
81+
imgBytes, err := io.ReadAll(imgSrc)
82+
if err != nil {
83+
return err
84+
}
85+
86+
maskSrc, err := maskFile.Open()
87+
if err != nil {
88+
return err
89+
}
90+
defer maskSrc.Close()
91+
maskBytes, err := io.ReadAll(maskSrc)
92+
if err != nil {
93+
return err
94+
}
95+
96+
// Create JSON with base64 fields expected by backend
97+
b64Image := base64.StdEncoding.EncodeToString(imgBytes)
98+
b64Mask := base64.StdEncoding.EncodeToString(maskBytes)
99+
100+
// get model config from context (middleware set it)
101+
cfg, ok := c.Get("MODEL_CONFIG").(*config.ModelConfig)
102+
if !ok || cfg == nil {
103+
log.Error().Msg("Inpainting Endpoint - model config not found in context")
104+
return echo.ErrBadRequest
105+
}
106+
107+
// Use the GeneratedContentDir so the generated PNG is placed where the
108+
// HTTP static handler serves `/generated-images`.
109+
tmpDir := appConfig.GeneratedContentDir
110+
// Ensure the directory exists
111+
if err := os.MkdirAll(tmpDir, 0750); err != nil {
112+
log.Error().Err(err).Msgf("Inpainting Endpoint - failed to create generated content dir: %s", tmpDir)
113+
return echo.NewHTTPError(http.StatusInternalServerError, "failed to prepare storage")
114+
}
115+
id := uuid.New().String()
116+
jsonName := fmt.Sprintf("inpaint_%s.json", id)
117+
jsonPath := filepath.Join(tmpDir, jsonName)
118+
jsonFile := map[string]string{
119+
"image": b64Image,
120+
"mask_image": b64Mask,
121+
}
122+
jf, err := os.CreateTemp(tmpDir, "inpaint_")
123+
if err != nil {
124+
return err
125+
}
126+
// setup cleanup on error; if everything succeeds we set success = true
127+
success := false
128+
var dst string
129+
var origRef string
130+
var maskRef string
131+
defer func() {
132+
if !success {
133+
// Best-effort cleanup; log any failures
134+
if jf != nil {
135+
if cerr := jf.Close(); cerr != nil {
136+
log.Warn().Err(cerr).Msg("Inpainting Endpoint - failed to close temp json file in cleanup")
137+
}
138+
if name := jf.Name(); name != "" {
139+
if rerr := os.Remove(name); rerr != nil && !os.IsNotExist(rerr) {
140+
log.Warn().Err(rerr).Msgf("Inpainting Endpoint - failed to remove temp json file %s in cleanup", name)
141+
}
142+
}
143+
}
144+
if jsonPath != "" {
145+
if rerr := os.Remove(jsonPath); rerr != nil && !os.IsNotExist(rerr) {
146+
log.Warn().Err(rerr).Msgf("Inpainting Endpoint - failed to remove json file %s in cleanup", jsonPath)
147+
}
148+
}
149+
if dst != "" {
150+
if rerr := os.Remove(dst); rerr != nil && !os.IsNotExist(rerr) {
151+
log.Warn().Err(rerr).Msgf("Inpainting Endpoint - failed to remove dst file %s in cleanup", dst)
152+
}
153+
}
154+
if origRef != "" {
155+
if rerr := os.Remove(origRef); rerr != nil && !os.IsNotExist(rerr) {
156+
log.Warn().Err(rerr).Msgf("Inpainting Endpoint - failed to remove orig ref file %s in cleanup", origRef)
157+
}
158+
}
159+
if maskRef != "" {
160+
if rerr := os.Remove(maskRef); rerr != nil && !os.IsNotExist(rerr) {
161+
log.Warn().Err(rerr).Msgf("Inpainting Endpoint - failed to remove mask ref file %s in cleanup", maskRef)
162+
}
163+
}
164+
}
165+
}()
166+
167+
// write original image and mask to disk as ref images so backends that
168+
// accept reference image files can use them (maintainer request).
169+
origTmp, err := os.CreateTemp(tmpDir, "refimg_")
170+
if err != nil {
171+
return err
172+
}
173+
if _, err := origTmp.Write(imgBytes); err != nil {
174+
_ = origTmp.Close()
175+
_ = os.Remove(origTmp.Name())
176+
return err
177+
}
178+
if cerr := origTmp.Close(); cerr != nil {
179+
log.Warn().Err(cerr).Msg("Inpainting Endpoint - failed to close orig temp file")
180+
}
181+
origRef = origTmp.Name()
182+
183+
maskTmp, err := os.CreateTemp(tmpDir, "refmask_")
184+
if err != nil {
185+
// cleanup origTmp on error
186+
_ = os.Remove(origRef)
187+
return err
188+
}
189+
if _, err := maskTmp.Write(maskBytes); err != nil {
190+
_ = maskTmp.Close()
191+
_ = os.Remove(maskTmp.Name())
192+
_ = os.Remove(origRef)
193+
return err
194+
}
195+
if cerr := maskTmp.Close(); cerr != nil {
196+
log.Warn().Err(cerr).Msg("Inpainting Endpoint - failed to close mask temp file")
197+
}
198+
maskRef = maskTmp.Name()
199+
// write JSON
200+
enc := json.NewEncoder(jf)
201+
if err := enc.Encode(jsonFile); err != nil {
202+
if cerr := jf.Close(); cerr != nil {
203+
log.Warn().Err(cerr).Msg("Inpainting Endpoint - failed to close temp json file after encode error")
204+
}
205+
return err
206+
}
207+
if cerr := jf.Close(); cerr != nil {
208+
log.Warn().Err(cerr).Msg("Inpainting Endpoint - failed to close temp json file")
209+
}
210+
// rename to desired name
211+
if err := os.Rename(jf.Name(), jsonPath); err != nil {
212+
return err
213+
}
214+
// prepare dst
215+
outTmp, err := os.CreateTemp(tmpDir, "out_")
216+
if err != nil {
217+
return err
218+
}
219+
if cerr := outTmp.Close(); cerr != nil {
220+
log.Warn().Err(cerr).Msg("Inpainting Endpoint - failed to close out temp file")
221+
}
222+
dst = outTmp.Name() + ".png"
223+
if err := os.Rename(outTmp.Name(), dst); err != nil {
224+
return err
225+
}
226+
227+
// Determine width/height default
228+
width := 512
229+
height := 512
230+
231+
// Call backend image generation via indirection so tests can stub it
232+
// Note: ImageGenerationFunc will call into the loaded model's GenerateImage which expects src JSON
233+
// Also pass ref images (orig + mask) so backends that support ref images can use them.
234+
refImages := []string{origRef, maskRef}
235+
fn, err := backend.ImageGenerationFunc(height, width, 0, steps, 0, prompt, "", jsonPath, dst, ml, *cfg, appConfig, refImages)
236+
if err != nil {
237+
return err
238+
}
239+
240+
// Execute generation function (blocking)
241+
if err := fn(); err != nil {
242+
return err
243+
}
244+
245+
// On success, build response URL using BaseURL middleware helper and
246+
// the same `generated-images` prefix used by the server static mount.
247+
baseURL := middleware.BaseURL(c)
248+
249+
// Build response using url.JoinPath for correct URL escaping
250+
imgPath, err := url.JoinPath(baseURL, "generated-images", filepath.Base(dst))
251+
if err != nil {
252+
return err
253+
}
254+
255+
created := int(time.Now().Unix())
256+
resp := &schema.OpenAIResponse{
257+
ID: id,
258+
Created: created,
259+
Data: []schema.Item{{
260+
URL: imgPath,
261+
}},
262+
}
263+
264+
// mark success so defer cleanup will not remove output files
265+
success = true
266+
267+
return c.JSON(http.StatusOK, resp)
268+
}
269+
}
Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,106 @@
1+
package openai
2+
3+
import (
4+
"bytes"
5+
"mime/multipart"
6+
"net/http"
7+
"net/http/httptest"
8+
"os"
9+
"path/filepath"
10+
"testing"
11+
12+
"github.com/labstack/echo/v4"
13+
"github.com/mudler/LocalAI/core/backend"
14+
"github.com/mudler/LocalAI/core/config"
15+
model "github.com/mudler/LocalAI/pkg/model"
16+
"github.com/stretchr/testify/require"
17+
)
18+
19+
func makeMultipartRequest(t *testing.T, fields map[string]string, files map[string][]byte) (*http.Request, string) {
20+
b := &bytes.Buffer{}
21+
w := multipart.NewWriter(b)
22+
for k, v := range fields {
23+
_ = w.WriteField(k, v)
24+
}
25+
for fname, content := range files {
26+
fw, err := w.CreateFormFile(fname, fname+".png")
27+
require.NoError(t, err)
28+
_, err = fw.Write(content)
29+
require.NoError(t, err)
30+
}
31+
require.NoError(t, w.Close())
32+
req := httptest.NewRequest(http.MethodPost, "/v1/images/inpainting", b)
33+
req.Header.Set("Content-Type", w.FormDataContentType())
34+
return req, w.FormDataContentType()
35+
}
36+
37+
func TestInpainting_MissingFiles(t *testing.T) {
38+
e := echo.New()
39+
// handler requires cl, ml, appConfig but this test verifies missing files early
40+
h := InpaintingEndpoint(nil, nil, config.NewApplicationConfig())
41+
42+
req := httptest.NewRequest(http.MethodPost, "/v1/images/inpainting", nil)
43+
rec := httptest.NewRecorder()
44+
c := e.NewContext(req, rec)
45+
46+
err := h(c)
47+
require.Error(t, err)
48+
}
49+
50+
func TestInpainting_HappyPath(t *testing.T) {
51+
// Setup temp generated content dir
52+
tmpDir, err := os.MkdirTemp("", "gencontent")
53+
require.NoError(t, err)
54+
defer os.RemoveAll(tmpDir)
55+
56+
appConf := config.NewApplicationConfig(config.WithGeneratedContentDir(tmpDir))
57+
58+
// stub the backend.ImageGenerationFunc
59+
orig := backend.ImageGenerationFunc
60+
backend.ImageGenerationFunc = func(height, width, mode, step, seed int, positive_prompt, negative_prompt, src, dst string, loader *model.ModelLoader, modelConfig config.ModelConfig, appConfig *config.ApplicationConfig, refImages []string) (func() error, error) {
61+
fn := func() error {
62+
// write a fake png file to dst
63+
return os.WriteFile(dst, []byte("PNGDATA"), 0644)
64+
}
65+
return fn, nil
66+
}
67+
defer func() { backend.ImageGenerationFunc = orig }()
68+
69+
// prepare multipart request with image and mask
70+
fields := map[string]string{"model": "dreamshaper-8-inpainting", "prompt": "A test"}
71+
files := map[string][]byte{"image": []byte("IMAGEDATA"), "mask": []byte("MASKDATA")}
72+
reqBuf, _ := makeMultipartRequest(t, fields, files)
73+
74+
rec := httptest.NewRecorder()
75+
e := echo.New()
76+
c := e.NewContext(reqBuf, rec)
77+
78+
// set a minimal model config in context as handler expects
79+
c.Set("MODEL_CONFIG", &config.ModelConfig{Backend: "diffusers"})
80+
81+
h := InpaintingEndpoint(nil, nil, appConf)
82+
83+
// call handler
84+
err = h(c)
85+
require.NoError(t, err)
86+
require.Equal(t, http.StatusOK, rec.Code)
87+
88+
// verify response body contains generated-images path
89+
body := rec.Body.String()
90+
require.Contains(t, body, "generated-images")
91+
92+
// confirm the file was created in tmpDir
93+
// parse out filename from response (naive search)
94+
// find "generated-images/" and extract until closing quote or brace
95+
idx := bytes.Index(rec.Body.Bytes(), []byte("generated-images/"))
96+
require.True(t, idx >= 0)
97+
rest := rec.Body.Bytes()[idx:]
98+
end := bytes.IndexAny(rest, "\",}\n")
99+
if end == -1 {
100+
end = len(rest)
101+
}
102+
fname := string(rest[len("generated-images/"):end])
103+
// ensure file exists
104+
_, err = os.Stat(filepath.Join(tmpDir, fname))
105+
require.NoError(t, err)
106+
}

core/http/endpoints/openai/mcp.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -108,11 +108,11 @@ func MCPCompletionEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader,
108108
log.Debug().Msgf("[model agent] [model: %s] Reasoning: %s", config.Name, s)
109109
}),
110110
cogito.WithToolCallBack(func(t *cogito.ToolChoice) bool {
111-
log.Debug().Msgf("[model agent] [model: %s] Tool call: %s, reasoning: %s, arguments: %+v", t.Name, t.Reasoning, t.Arguments)
111+
log.Debug().Msgf("[model agent] [model: %s] Tool call: %s, reasoning: %s, arguments: %+v", config.Name, t.Name, t.Reasoning, t.Arguments)
112112
return true
113113
}),
114114
cogito.WithToolCallResultCallback(func(t cogito.ToolStatus) {
115-
log.Debug().Msgf("[model agent] [model: %s] Tool call result: %s, tool arguments: %+v", t.Name, t.Result, t.ToolArguments)
115+
log.Debug().Msgf("[model agent] [model: %s] Tool call result: %s, result: %s, tool arguments: %+v", config.Name, t.Name, t.Result, t.ToolArguments)
116116
}),
117117
)
118118

0 commit comments

Comments
 (0)