Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 23 additions & 0 deletions proxy/processgroup.go
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,29 @@ func (pg *ProcessGroup) HasMember(modelName string) bool {
return slices.Contains(pg.config.Groups[pg.id].Members, modelName)
}

func (pg *ProcessGroup) StopProcess(modelID string, strategy StopStrategy) error {
pg.Lock()

process, exists := pg.processes[modelID]
if !exists {
pg.Unlock()
return fmt.Errorf("process not found for %s", modelID)
}

if pg.lastUsedProcess == modelID {
pg.lastUsedProcess = ""
}
pg.Unlock()

switch strategy {
case StopImmediately:
process.StopImmediately()
default:
process.Stop()
}
return nil
}

func (pg *ProcessGroup) StopProcesses(strategy StopStrategy) {
pg.Lock()
defer pg.Unlock()
Expand Down
1 change: 0 additions & 1 deletion proxy/proxymanager.go
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,6 @@ func (pm *ProxyManager) setupGinEngine() {
c.Redirect(http.StatusFound, "/ui/models")
})
pm.ginEngine.Any("/upstream/*upstreamPath", pm.proxyToUpstream)

pm.ginEngine.GET("/unload", pm.unloadAllModelsHandler)
pm.ginEngine.GET("/running", pm.listRunningProcessesHandler)
pm.ginEngine.GET("/health", func(c *gin.Context) {
Expand Down
25 changes: 25 additions & 0 deletions proxy/proxymanager_api.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,10 @@ package proxy
import (
"context"
"encoding/json"
"fmt"
"net/http"
"sort"
"strings"

"github.com/gin-gonic/gin"
"github.com/mostlygeek/llama-swap/event"
Expand All @@ -23,6 +25,7 @@ func addApiHandlers(pm *ProxyManager) {
apiGroup := pm.ginEngine.Group("/api")
{
apiGroup.POST("/models/unload", pm.apiUnloadAllModels)
apiGroup.POST("/models/unload/*model", pm.apiUnloadSingleModelHandler)
apiGroup.GET("/events", pm.apiSendEvents)
apiGroup.GET("/metrics", pm.apiGetMetrics)
}
Expand Down Expand Up @@ -202,3 +205,25 @@ func (pm *ProxyManager) apiGetMetrics(c *gin.Context) {
}
c.Data(http.StatusOK, "application/json", jsonData)
}

func (pm *ProxyManager) apiUnloadSingleModelHandler(c *gin.Context) {
requestedModel := strings.TrimPrefix(c.Param("model"), "/")
realModelName, found := pm.config.RealModelName(requestedModel)
if !found {
pm.sendErrorResponse(c, http.StatusNotFound, "Model not found")
return
}

processGroup := pm.findGroupByModelName(realModelName)
if processGroup == nil {
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("process group not found for model %s", requestedModel))
return
}

if err := processGroup.StopProcess(realModelName, StopImmediately); err != nil {
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error stopping process: %s", err.Error()))
return
} else {
c.String(http.StatusOK, "OK")
}
}
58 changes: 56 additions & 2 deletions proxy/proxymanager_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -401,11 +401,65 @@ func TestProxyManager_Unload(t *testing.T) {
assert.Equal(t, http.StatusOK, w.Code)
assert.Equal(t, w.Body.String(), "OK")

// give it a bit of time to stop
<-time.After(time.Millisecond * 250)
select {
case <-proxy.processGroups[DEFAULT_GROUP_ID].processes["model1"].cmdWaitChan:
// good
case <-time.After(2 * time.Second):
t.Fatal("timeout waiting for model1 to stop")
}
assert.Equal(t, proxy.processGroups[DEFAULT_GROUP_ID].processes["model1"].CurrentState(), StateStopped)
}

func TestProxyManager_UnloadSingleModel(t *testing.T) {
const testGroupId = "testGroup"
config := AddDefaultGroupToConfig(Config{
HealthCheckTimeout: 15,
Models: map[string]ModelConfig{
"model1": getTestSimpleResponderConfig("model1"),
"model2": getTestSimpleResponderConfig("model2"),
},
Groups: map[string]GroupConfig{
testGroupId: {
Swap: false,
Members: []string{"model1", "model2"},
},
},
LogLevel: "error",
})

proxy := New(config)
defer proxy.StopProcesses(StopImmediately)

// start both model
for _, modelName := range []string{"model1", "model2"} {
reqBody := fmt.Sprintf(`{"model":"%s"}`, modelName)
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
w := httptest.NewRecorder()
proxy.ServeHTTP(w, req)
}

assert.Equal(t, StateReady, proxy.processGroups[testGroupId].processes["model1"].CurrentState())
assert.Equal(t, StateReady, proxy.processGroups[testGroupId].processes["model2"].CurrentState())

req := httptest.NewRequest("POST", "/api/models/unload/model1", nil)
w := httptest.NewRecorder()
proxy.ServeHTTP(w, req)
assert.Equal(t, http.StatusOK, w.Code)
if !assert.Equal(t, w.Body.String(), "OK") {
t.FailNow()
}

select {
case <-proxy.processGroups[testGroupId].processes["model1"].cmdWaitChan:
// good
case <-time.After(2 * time.Second):
t.Fatal("timeout waiting for model1 to stop")
}

assert.Equal(t, proxy.processGroups[testGroupId].processes["model1"].CurrentState(), StateStopped)
assert.Equal(t, proxy.processGroups[testGroupId].processes["model2"].CurrentState(), StateReady)
}

// Test issue #61 `Listing the current list of models and the loaded model.`
func TestProxyManager_RunningEndpoint(t *testing.T) {
// Shared configuration
Expand Down
18 changes: 17 additions & 1 deletion ui/src/contexts/APIProvider.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ interface APIProviderType {
models: Model[];
listModels: () => Promise<Model[]>;
unloadAllModels: () => Promise<void>;
unloadSingleModel: (model: string) => Promise<void>;
loadModel: (model: string) => Promise<void>;
enableAPIEvents: (enabled: boolean) => void;
proxyLogs: string;
Expand Down Expand Up @@ -177,7 +178,7 @@ export function APIProvider({ children, autoStartAPIEvents = true }: APIProvider

const unloadAllModels = useCallback(async () => {
try {
const response = await fetch(`/api/models/unload/`, {
const response = await fetch(`/api/models/unload`, {
method: "POST",
});
if (!response.ok) {
Expand All @@ -189,6 +190,20 @@ export function APIProvider({ children, autoStartAPIEvents = true }: APIProvider
}
}, []);

const unloadSingleModel = useCallback(async (model: string) => {
try {
const response = await fetch(`/api/models/unload/${model}`, {
method: "POST",
});
if (!response.ok) {
throw new Error(`Failed to unload model: ${response.status}`);
}
} catch (error) {
console.error("Failed to unload model", model, error);
throw error;
}
}, []);

const loadModel = useCallback(async (model: string) => {
try {
const response = await fetch(`/upstream/${model}/`, {
Expand All @@ -208,6 +223,7 @@ export function APIProvider({ children, autoStartAPIEvents = true }: APIProvider
models,
listModels,
unloadAllModels,
unloadSingleModel,
loadModel,
enableAPIEvents,
proxyLogs,
Expand Down
26 changes: 16 additions & 10 deletions ui/src/pages/Models.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ import { LogPanel } from "./LogViewer";
import { usePersistentState } from "../hooks/usePersistentState";
import { Panel, PanelGroup, PanelResizeHandle } from "react-resizable-panels";
import { useTheme } from "../contexts/ThemeProvider";
import { RiEyeFill, RiEyeOffFill, RiStopCircleLine, RiSwapBoxFill } from "react-icons/ri";
import { RiEyeFill, RiEyeOffFill, RiSwapBoxFill, RiEjectLine } from "react-icons/ri";

export default function ModelsPage() {
const { isNarrow } = useTheme();
Expand Down Expand Up @@ -37,7 +37,7 @@ export default function ModelsPage() {
}

function ModelsPanel() {
const { models, loadModel, unloadAllModels } = useAPI();
const { models, loadModel, unloadAllModels, unloadSingleModel } = useAPI();
const [isUnloading, setIsUnloading] = useState(false);
const [showUnlisted, setShowUnlisted] = usePersistentState("showUnlisted", true);
const [showIdorName, setShowIdorName] = usePersistentState<"id" | "name">("showIdorName", "id"); // true = show ID, false = show name
Expand Down Expand Up @@ -90,7 +90,7 @@ function ModelsPanel() {
onClick={handleUnloadAllModels}
disabled={isUnloading}
>
<RiStopCircleLine size="24" /> {isUnloading ? "Unloading..." : "Unload"}
<RiEjectLine size="24" /> {isUnloading ? "Unloading..." : "Unload All"}
</button>
</div>
</div>
Expand Down Expand Up @@ -119,13 +119,19 @@ function ModelsPanel() {
)}
</td>
<td className="w-12">
<button
className="btn btn--sm"
disabled={model.state !== "stopped"}
onClick={() => loadModel(model.id)}
>
Load
</button>
{model.state === "stopped" ? (
<button className="btn btn--sm" onClick={() => loadModel(model.id)}>
Load
</button>
) : (
<button
className="btn btn--sm"
onClick={() => unloadSingleModel(model.id)}
disabled={model.state !== "ready"}
>
Unload
</button>
)}
</td>
<td className="w-20">
<span className={`w-16 text-center status status--${model.state}`}>{model.state}</span>
Expand Down
1 change: 1 addition & 0 deletions ui/vite.config.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ export default defineConfig({
"/api": "http://localhost:8080", // Proxy API calls to Go backend during development
"/logs": "http://localhost:8080",
"/upstream": "http://localhost:8080",
"/unload": "http://localhost:8080",
},
},
});
Loading