Skip to content

Commit d5b079b

Browse files
committed
upstream handler support for model names with forward slash
The upstream handler would break on model IDs that contained a forward slash. Model IDs like "aaa/bbb" called at upstream/aaa/bbb would result in an error. This commit adds support for model IDs with a forward slash by iteratively searching the path for a match. Updates: #229
1 parent be6d42f commit d5b079b

File tree

1 file changed

+34
-6
lines changed

1 file changed

+34
-6
lines changed

proxy/proxymanager.go

Lines changed: 34 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -227,7 +227,7 @@ func (pm *ProxyManager) setupGinEngine() {
227227
pm.ginEngine.GET("/upstream", func(c *gin.Context) {
228228
c.Redirect(http.StatusFound, "/ui/models")
229229
})
230-
pm.ginEngine.Any("/upstream/:model_id/*upstreamPath", pm.proxyToUpstream)
230+
pm.ginEngine.Any("/upstream/*upstreamPath", pm.proxyToUpstream)
231231

232232
pm.ginEngine.GET("/unload", pm.unloadAllModelsHandler)
233233
pm.ginEngine.GET("/running", pm.listRunningProcessesHandler)
@@ -393,24 +393,52 @@ func (pm *ProxyManager) listModelsHandler(c *gin.Context) {
393393
}
394394

395395
func (pm *ProxyManager) proxyToUpstream(c *gin.Context) {
396-
requestedModel := c.Param("model_id")
396+
upstreamPath := c.Param("upstreamPath")
397397

398-
if requestedModel == "" {
398+
// split the upstream path by / and search for the model name
399+
parts := strings.Split(strings.TrimSpace(upstreamPath), "/")
400+
if len(parts) == 0 {
399401
pm.sendErrorResponse(c, http.StatusBadRequest, "model id required in path")
400402
return
401403
}
402404

403-
processGroup, realModelName, err := pm.swapProcessGroup(requestedModel)
405+
modelFound := false
406+
searchModelName := ""
407+
var modelName, remainingPath string
408+
for i, part := range parts {
409+
if parts[i] == "" {
410+
continue
411+
}
412+
413+
if searchModelName == "" {
414+
searchModelName = part
415+
} else {
416+
searchModelName = searchModelName + "/" + parts[i]
417+
}
418+
419+
if real, ok := pm.config.RealModelName(searchModelName); ok {
420+
modelName = real
421+
remainingPath = "/" + strings.Join(parts[i+1:], "/")
422+
modelFound = true
423+
break
424+
}
425+
}
426+
427+
if !modelFound {
428+
pm.sendErrorResponse(c, http.StatusBadRequest, "model id required in path")
429+
return
430+
}
431+
432+
processGroup, realModelName, err := pm.swapProcessGroup(modelName)
404433
if err != nil {
405434
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error swapping process group: %s", err.Error()))
406435
return
407436
}
408437

409438
// rewrite the path
410-
c.Request.URL.Path = c.Param("upstreamPath")
439+
c.Request.URL.Path = remainingPath
411440
processGroup.ProxyRequest(realModelName, c.Writer, c.Request)
412441
}
413-
414442
func (pm *ProxyManager) proxyOAIHandler(c *gin.Context) {
415443
bodyBytes, err := io.ReadAll(c.Request.Body)
416444
if err != nil {

0 commit comments

Comments
 (0)