Skip to content

Commit 850ff50

Browse files
committed
Add Filters to model Config #174
1 parent c6d7fd2 commit 850ff50

File tree

7 files changed

+126
-0
lines changed

7 files changed

+126
-0
lines changed

misc/simple-responder/simple-responder.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,9 +42,12 @@ func main() {
4242
time.Sleep(wait)
4343
}
4444

45+
bodyBytes, _ := io.ReadAll(c.Request.Body)
46+
4547
c.JSON(http.StatusOK, gin.H{
4648
"responseMessage": *responseMessage,
4749
"h_content_length": c.Request.Header.Get("Content-Length"),
50+
"request_body": string(bodyBytes),
4851
})
4952
})
5053

proxy/config.go

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ import (
66
"os"
77
"regexp"
88
"runtime"
9+
"slices"
910
"sort"
1011
"strconv"
1112
"strings"
@@ -29,6 +30,9 @@ type ModelConfig struct {
2930

3031
// Limit concurrency of HTTP requests to process
3132
ConcurrencyLimit int `yaml:"concurrencyLimit"`
33+
34+
// Model filters see issue #174
35+
Filters ModelFilters `yaml:"filters"`
3236
}
3337

3438
func (m *ModelConfig) UnmarshalYAML(unmarshal func(interface{}) error) error {
@@ -63,6 +67,46 @@ func (m *ModelConfig) SanitizedCommand() ([]string, error) {
6367
return SanitizeCommand(m.Cmd)
6468
}
6569

70+
// ModelFilters groups the per-model request filters introduced for issue #174.
type ModelFilters struct {
	// StripParams is a comma-separated list of parameter names, read from the
	// `strip_params` YAML key. Use SanitizedStripParams to obtain the cleaned
	// list of names.
	StripParams string `yaml:"strip_params"`
}
74+
75+
func (m *ModelFilters) UnmarshalYAML(unmarshal func(interface{}) error) error {
76+
type rawModelFilters ModelFilters
77+
defaults := rawModelFilters{
78+
StripParams: "",
79+
}
80+
81+
if err := unmarshal(&defaults); err != nil {
82+
return err
83+
}
84+
85+
*m = ModelFilters(defaults)
86+
return nil
87+
}
88+
89+
func (f ModelFilters) SanitizedStripParams() ([]string, error) {
90+
if f.StripParams == "" {
91+
return nil, nil
92+
}
93+
94+
params := strings.Split(f.StripParams, ",")
95+
cleaned := make([]string, 0, len(params))
96+
97+
for _, param := range params {
98+
trimmed := strings.TrimSpace(param)
99+
if trimmed == "model" || trimmed == "" {
100+
continue
101+
}
102+
cleaned = append(cleaned, strings.TrimSpace(param))
103+
}
104+
105+
// sort cleaned
106+
slices.Sort(cleaned)
107+
return cleaned, nil
108+
}
109+
66110
type GroupConfig struct {
67111
Swap bool `yaml:"swap"`
68112
Exclusive bool `yaml:"exclusive"`
@@ -212,6 +256,7 @@ func LoadConfigFromReader(r io.Reader) (Config, error) {
212256
modelConfig.CmdStop = strings.ReplaceAll(modelConfig.CmdStop, macroSlug, macroValue)
213257
modelConfig.Proxy = strings.ReplaceAll(modelConfig.Proxy, macroSlug, macroValue)
214258
modelConfig.CheckEndpoint = strings.ReplaceAll(modelConfig.CheckEndpoint, macroSlug, macroValue)
259+
modelConfig.Filters.StripParams = strings.ReplaceAll(modelConfig.Filters.StripParams, macroSlug, macroValue)
215260
}
216261

217262
// enforce ${PORT} used in both cmd and proxy

proxy/config_posix_test.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,9 @@ models:
8383
assert.Equal(t, "", model1.UseModelName)
8484
assert.Equal(t, 0, model1.ConcurrencyLimit)
8585
}
86+
87+
// default empty filter exists
88+
assert.Equal(t, "", model1.Filters.StripParams)
8689
}
8790

8891
func TestConfig_LoadPosix(t *testing.T) {

proxy/config_test.go

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -300,3 +300,28 @@ models:
300300
})
301301
}
302302
}
303+
304+
func TestConfig_ModelFilters(t *testing.T) {
305+
content := `
306+
macros:
307+
default_strip: "temperature, top_p"
308+
models:
309+
model1:
310+
cmd: path/to/cmd --port ${PORT}
311+
filters:
312+
strip_params: "model, top_k, ${default_strip}, , ,"
313+
`
314+
config, err := LoadConfigFromReader(strings.NewReader(content))
315+
assert.NoError(t, err)
316+
modelConfig, ok := config.Models["model1"]
317+
if !assert.True(t, ok) {
318+
t.FailNow()
319+
}
320+
321+
// make sure `model` and enmpty strings are not in the list
322+
assert.Equal(t, "model, top_k, temperature, top_p, , ,", modelConfig.Filters.StripParams)
323+
sanitized, err := modelConfig.Filters.SanitizedStripParams()
324+
if assert.NoError(t, err) {
325+
assert.Equal(t, []string{"temperature", "top_k", "top_p"}, sanitized)
326+
}
327+
}

proxy/config_windows_test.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,9 @@ models:
8080
assert.Equal(t, "", model1.UseModelName)
8181
assert.Equal(t, 0, model1.ConcurrencyLimit)
8282
}
83+
84+
// default empty filter exists
85+
assert.Equal(t, "", model1.Filters.StripParams)
8386
}
8487

8588
func TestConfig_LoadWindows(t *testing.T) {

proxy/proxymanager.go

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -365,6 +365,19 @@ func (pm *ProxyManager) proxyOAIHandler(c *gin.Context) {
365365
}
366366
}
367367

368+
// issue #174 strip parameters from the JSON body
369+
stripParams, err := pm.config.Models[realModelName].Filters.SanitizedStripParams()
370+
if err != nil { // just log it and continue
371+
pm.proxyLogger.Errorf("Error sanitizing strip params string: %s, %s", pm.config.Models[realModelName].Filters.StripParams, err.Error())
372+
} else {
373+
for _, param := range stripParams {
374+
bodyBytes, err = sjson.DeleteBytes(bodyBytes, param)
375+
if err != nil {
376+
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error deleting parameter %s from request", param))
377+
}
378+
}
379+
}
380+
368381
c.Request.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
369382

370383
// dechunk it as we already have all the body bytes see issue #11

proxy/proxymanager_test.go

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -623,3 +623,37 @@ func TestProxyManager_ChatContentLength(t *testing.T) {
623623
assert.Equal(t, "81", response["h_content_length"])
624624
assert.Equal(t, "model1", response["responseMessage"])
625625
}
626+
627+
func TestProxyManager_FiltersStripParams(t *testing.T) {
628+
modelConfig := getTestSimpleResponderConfig("model1")
629+
modelConfig.Filters = ModelFilters{
630+
StripParams: "temperature, model, stream",
631+
}
632+
633+
config := AddDefaultGroupToConfig(Config{
634+
HealthCheckTimeout: 15,
635+
LogLevel: "error",
636+
Models: map[string]ModelConfig{
637+
"model1": modelConfig,
638+
},
639+
})
640+
641+
proxy := New(config)
642+
defer proxy.StopProcesses(StopWaitForInflightRequest)
643+
reqBody := `{"model":"model1", "temperature":0.1, "x_param":"123", "y_param":"abc", "stream":true}`
644+
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
645+
w := httptest.NewRecorder()
646+
647+
proxy.ServeHTTP(w, req)
648+
assert.Equal(t, http.StatusOK, w.Code)
649+
var response map[string]string
650+
assert.NoError(t, json.Unmarshal(w.Body.Bytes(), &response))
651+
652+
// `temperature` and `stream` are gone but model remains
653+
assert.Equal(t, `{"model":"model1", "x_param":"123", "y_param":"abc"}`, response["request_body"])
654+
655+
// assert.Nil(t, response["temperature"])
656+
// assert.Equal(t, "123", response["x_param"])
657+
// assert.Equal(t, "abc", response["y_param"])
658+
// t.Logf("%v", response)
659+
}

0 commit comments

Comments
 (0)