Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions proxy/processgroup.go
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,19 @@ func (pg *ProcessGroup) HasMember(modelName string) bool {
return slices.Contains(pg.config.Groups[pg.id].Members, modelName)
}

func (pg *ProcessGroup) StopProcess(modelID string) error {
pg.Lock()
defer pg.Unlock()

process, exists := pg.processes[modelID]
if !exists {
return fmt.Errorf("process not found for %s", modelID)
}

process.StopImmediately()
return nil
}

func (pg *ProcessGroup) StopProcesses(strategy StopStrategy) {
pg.Lock()
defer pg.Unlock()
Expand Down
24 changes: 24 additions & 0 deletions proxy/proxymanager.go
Original file line number Diff line number Diff line change
Expand Up @@ -229,6 +229,7 @@ func (pm *ProxyManager) setupGinEngine() {
})
pm.ginEngine.Any("/upstream/*upstreamPath", pm.proxyToUpstream)

pm.ginEngine.GET("/unload/*model", pm.unloadSingleModelHandler)
pm.ginEngine.GET("/unload", pm.unloadAllModelsHandler)
pm.ginEngine.GET("/running", pm.listRunningProcessesHandler)
pm.ginEngine.GET("/health", func(c *gin.Context) {
Expand Down Expand Up @@ -628,6 +629,29 @@ func (pm *ProxyManager) unloadAllModelsHandler(c *gin.Context) {
c.String(http.StatusOK, "OK")
}

func (pm *ProxyManager) unloadSingleModelHandler(c *gin.Context) {
requestedModel := strings.TrimPrefix(c.Param("model"), "/")

realModelName, found := pm.config.RealModelName(requestedModel)
if !found {
c.String(http.StatusNotFound, "Model not found")
return
}

processGroup := pm.findGroupByModelName(realModelName)
if processGroup == nil {
c.String(http.StatusInternalServerError, "process group not found for model %s", requestedModel)
return
}

if err := processGroup.StopProcess(realModelName); err != nil {
c.String(http.StatusInternalServerError, "error stopping process: %s", err.Error())
return
} else {
c.String(http.StatusOK, "OK")
}
}

func (pm *ProxyManager) listRunningProcessesHandler(context *gin.Context) {
context.Header("Content-Type", "application/json")
runningProcesses := make([]gin.H, 0) // Default to an empty response.
Expand Down
54 changes: 52 additions & 2 deletions proxy/proxymanager_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -401,11 +401,61 @@ func TestProxyManager_Unload(t *testing.T) {
assert.Equal(t, http.StatusOK, w.Code)
assert.Equal(t, w.Body.String(), "OK")

// give it a bit of time to stop
<-time.After(time.Millisecond * 250)
select {
case <-proxy.processGroups[DEFAULT_GROUP_ID].processes["model1"].cmdWaitChan:
// good
case <-time.After(2 * time.Second):
t.Fatal("timeout waiting for model1 to stop")
}
assert.Equal(t, proxy.processGroups[DEFAULT_GROUP_ID].processes["model1"].CurrentState(), StateStopped)
}

func TestProxyManager_UnloadSingleModel(t *testing.T) {
const testGroupId = "testGroup"
config := AddDefaultGroupToConfig(Config{
HealthCheckTimeout: 15,
Models: map[string]ModelConfig{
"model1": getTestSimpleResponderConfig("model1"),
"model2": getTestSimpleResponderConfig("model2"),
},
Groups: map[string]GroupConfig{
testGroupId: {
Swap: false,
Members: []string{"model1", "model2"},
},
},
LogLevel: "error",
})

proxy := New(config)

// start both model
for _, modelName := range []string{"model1", "model2"} {
reqBody := fmt.Sprintf(`{"model":"%s"}`, modelName)
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
w := httptest.NewRecorder()
proxy.ServeHTTP(w, req)
}

assert.Equal(t, proxy.processGroups[testGroupId].processes["model1"].CurrentState(), StateReady)
assert.Equal(t, proxy.processGroups[testGroupId].processes["model2"].CurrentState(), StateReady)

req := httptest.NewRequest("GET", "/unload/model1", nil)
w := httptest.NewRecorder()
proxy.ServeHTTP(w, req)
assert.Equal(t, http.StatusOK, w.Code)
assert.Equal(t, w.Body.String(), "OK")

select {
case <-proxy.processGroups[testGroupId].processes["model1"].cmdWaitChan:
// good
case <-time.After(2 * time.Second):
t.Fatal("timeout waiting for model1 to stop")
}

assert.Equal(t, proxy.processGroups[testGroupId].processes["model1"].CurrentState(), StateStopped)
}

// Test issue #61 `Listing the current list of models and the loaded model.`
func TestProxyManager_RunningEndpoint(t *testing.T) {
// Shared configuration
Expand Down
16 changes: 16 additions & 0 deletions ui/src/contexts/APIProvider.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ interface APIProviderType {
models: Model[];
listModels: () => Promise<Model[]>;
unloadAllModels: () => Promise<void>;
unloadSingleModel: (model: string) => Promise<void>;
loadModel: (model: string) => Promise<void>;
enableAPIEvents: (enabled: boolean) => void;
proxyLogs: string;
Expand Down Expand Up @@ -189,6 +190,20 @@ export function APIProvider({ children, autoStartAPIEvents = true }: APIProvider
}
}, []);

const unloadSingleModel = useCallback(async (model: string) => {
try {
const response = await fetch(`/unload/${model}`, {
method: "GET",
});
if (!response.ok) {
throw new Error(`Failed to unload model: ${response.status}`);
}
} catch (error) {
console.error("Failed to unload model", model, error);
throw error;
}
}, []);

const loadModel = useCallback(async (model: string) => {
try {
const response = await fetch(`/upstream/${model}/`, {
Expand All @@ -208,6 +223,7 @@ export function APIProvider({ children, autoStartAPIEvents = true }: APIProvider
models,
listModels,
unloadAllModels,
unloadSingleModel,
loadModel,
enableAPIEvents,
proxyLogs,
Expand Down
22 changes: 14 additions & 8 deletions ui/src/pages/Models.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ export default function ModelsPage() {
}

function ModelsPanel() {
const { models, loadModel, unloadAllModels } = useAPI();
const { models, loadModel, unloadAllModels, unloadSingleModel } = useAPI();
const [isUnloading, setIsUnloading] = useState(false);
const [showUnlisted, setShowUnlisted] = usePersistentState("showUnlisted", true);
const [showIdorName, setShowIdorName] = usePersistentState<"id" | "name">("showIdorName", "id"); // true = show ID, false = show name
Expand Down Expand Up @@ -119,13 +119,19 @@ function ModelsPanel() {
)}
</td>
<td className="w-12">
<button
className="btn btn--sm"
disabled={model.state !== "stopped"}
onClick={() => loadModel(model.id)}
>
Load
</button>
{model.state === "stopped" ? (
<button className="btn btn--sm" onClick={() => loadModel(model.id)}>
Load
</button>
) : (
<button
className="btn btn--sm"
onClick={() => unloadSingleModel(model.id)}
disabled={model.state !== "ready"}
>
Unload
</button>
)}
</td>
<td className="w-20">
<span className={`w-16 text-center status status--${model.state}`}>{model.state}</span>
Expand Down
1 change: 1 addition & 0 deletions ui/vite.config.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ export default defineConfig({
"/api": "http://localhost:8080", // Proxy API calls to Go backend during development
"/logs": "http://localhost:8080",
"/upstream": "http://localhost:8080",
"/unload": "http://localhost:8080",
},
},
});
Loading