Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 41 additions & 17 deletions proxy/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,34 @@ type ModelConfig struct {
ConcurrencyLimit int `yaml:"concurrencyLimit"`
}

func (m *ModelConfig) UnmarshalYAML(unmarshal func(interface{}) error) error {
type rawModelConfig ModelConfig
defaults := rawModelConfig{
Cmd: "",
CmdStop: "",
Proxy: "http://localhost:${PORT}",
Aliases: []string{},
Env: []string{},
CheckEndpoint: "/health",
UnloadAfter: 0,
Unlisted: false,
UseModelName: "",
ConcurrencyLimit: 0,
}

// the default cmdStop to taskkill /f /t /pid ${PID}
if runtime.GOOS == "windows" {
defaults.CmdStop = "taskkill /f /t /pid ${PID}"
}

if err := unmarshal(&defaults); err != nil {
return err
}

*m = ModelConfig(defaults)
return nil
}

func (m *ModelConfig) SanitizedCommand() ([]string, error) {
return SanitizeCommand(m.Cmd)
}
Expand Down Expand Up @@ -111,26 +139,23 @@ func LoadConfigFromReader(r io.Reader) (Config, error) {
return Config{}, err
}

var config Config
// default configuration values
config := Config{
HealthCheckTimeout: 120,
StartPort: 5800,
LogLevel: "info",
}
err = yaml.Unmarshal(data, &config)
if err != nil {
return Config{}, err
}

if config.HealthCheckTimeout == 0 {
// this high default timeout helps avoid failing health checks
// for configurations that wait for docker or have slower startup
config.HealthCheckTimeout = 120
} else if config.HealthCheckTimeout < 15 {
if config.HealthCheckTimeout < 15 {
// set a minimum of 15 seconds
config.HealthCheckTimeout = 15
}

// set default port ranges
if config.StartPort == 0 {
// default to 5800
config.StartPort = 5800
} else if config.StartPort < 1 {
if config.StartPort < 1 {
return Config{}, fmt.Errorf("startPort must be greater than 1")
}

Expand Down Expand Up @@ -180,6 +205,11 @@ func LoadConfigFromReader(r io.Reader) (Config, error) {
for _, modelId := range modelIds {
modelConfig := config.Models[modelId]

// enforce ${PORT} used in both cmd and proxy
if !strings.Contains(modelConfig.Cmd, "${PORT}") && strings.Contains(modelConfig.Proxy, "${PORT}") {
return Config{}, fmt.Errorf("model %s requires a proxy value when not using automatic ${PORT}", modelId)
}

// go through model config fields: cmd, cmdStop, proxy, checkEndPoint and replace macros with macro values
for macroName, macroValue := range config.Macros {
macroSlug := fmt.Sprintf("${%s}", macroName)
Expand All @@ -191,17 +221,11 @@ func LoadConfigFromReader(r io.Reader) (Config, error) {

// only iterate over models that use ${PORT} to keep port numbers from increasing unnecessarily
if strings.Contains(modelConfig.Cmd, "${PORT}") || strings.Contains(modelConfig.Proxy, "${PORT}") || strings.Contains(modelConfig.CmdStop, "${PORT}") {
if modelConfig.Proxy == "" {
modelConfig.Proxy = "http://localhost:${PORT}"
}

nextPortStr := strconv.Itoa(nextPort)
modelConfig.Cmd = strings.ReplaceAll(modelConfig.Cmd, "${PORT}", nextPortStr)
modelConfig.CmdStop = strings.ReplaceAll(modelConfig.CmdStop, "${PORT}", nextPortStr)
modelConfig.Proxy = strings.ReplaceAll(modelConfig.Proxy, "${PORT}", nextPortStr)
nextPort++
} else if modelConfig.Proxy == "" {
return Config{}, fmt.Errorf("model %s requires a proxy value when not using automatic ${PORT}", modelId)
}

// make sure there are no unknown macros that have not been replaced
Expand Down
42 changes: 42 additions & 0 deletions proxy/config_posix_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
package proxy

import (
"strings"
"testing"

"github.com/stretchr/testify/assert"
Expand Down Expand Up @@ -40,3 +41,44 @@ func TestConfig_SanitizeCommand(t *testing.T) {
assert.Error(t, err)
assert.Nil(t, args)
}

// Test the default values are automatically set for global, model and group configurations
// after loading the configuration
func TestConfig_DefaultValues(t *testing.T) {
content := `
models:
model1:
cmd: path/to/cmd --port ${PORT}
`

config, err := LoadConfigFromReader(strings.NewReader(content))
assert.NoError(t, err)
assert.Equal(t, 120, config.HealthCheckTimeout)
assert.Equal(t, 5800, config.StartPort)
assert.Equal(t, "info", config.LogLevel)

// Test default group exists
defaultGroup, exists := config.Groups["(default)"]
assert.True(t, exists, "default group should exist")
if assert.NotNil(t, defaultGroup, "default group should not be nil") {
assert.Equal(t, true, defaultGroup.Swap)
assert.Equal(t, true, defaultGroup.Exclusive)
assert.Equal(t, false, defaultGroup.Persistent)
assert.Equal(t, []string{"model1"}, defaultGroup.Members)
}

model1, exists := config.Models["model1"]
assert.True(t, exists, "model1 should exist")
if assert.NotNil(t, model1, "model1 should not be nil") {
assert.Equal(t, "path/to/cmd --port 5800", model1.Cmd) // has the port replaced
assert.Equal(t, "", model1.CmdStop)
assert.Equal(t, "http://localhost:5800", model1.Proxy)
assert.Equal(t, "/health", model1.CheckEndpoint)
assert.Equal(t, []string{}, model1.Aliases)
assert.Equal(t, []string{}, model1.Env)
assert.Equal(t, 0, model1.UnloadAfter)
assert.Equal(t, false, model1.Unlisted)
assert.Equal(t, "", model1.UseModelName)
assert.Equal(t, 0, model1.ConcurrencyLimit)
}
}
7 changes: 5 additions & 2 deletions proxy/config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ groups:
}

expected := Config{
LogLevel: "info",
StartPort: 5800,
Macros: map[string]string{
"svr-path": "path/to/server",
Expand All @@ -93,20 +94,22 @@ groups:
Cmd: "path/to/server --arg1 one",
Proxy: "http://localhost:8081",
Aliases: []string{"m2"},
Env: nil,
Env: []string{},
CheckEndpoint: "/",
},
"model3": {
Cmd: "path/to/cmd --arg1 one",
Proxy: "http://localhost:8081",
Aliases: []string{"mthree"},
Env: nil,
Env: []string{},
CheckEndpoint: "/",
},
"model4": {
Cmd: "path/to/cmd --arg1 one",
Proxy: "http://localhost:8082",
CheckEndpoint: "/",
Aliases: []string{},
Env: []string{},
},
},
HealthCheckTimeout: 15,
Expand Down
40 changes: 40 additions & 0 deletions proxy/config_windows_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
package proxy

import (
"strings"
"testing"

"github.com/stretchr/testify/assert"
Expand Down Expand Up @@ -39,3 +40,42 @@ func TestConfig_SanitizeCommand(t *testing.T) {
assert.Error(t, err)
assert.Nil(t, args)
}

func TestConfig_DefaultValuesWindows(t *testing.T) {
content := `
models:
model1:
cmd: path/to/cmd --port ${PORT}
`

config, err := LoadConfigFromReader(strings.NewReader(content))
assert.NoError(t, err)
assert.Equal(t, 120, config.HealthCheckTimeout)
assert.Equal(t, 5800, config.StartPort)
assert.Equal(t, "info", config.LogLevel)

// Test default group exists
defaultGroup, exists := config.Groups["(default)"]
assert.True(t, exists, "default group should exist")
if assert.NotNil(t, defaultGroup, "default group should not be nil") {
assert.Equal(t, true, defaultGroup.Swap)
assert.Equal(t, true, defaultGroup.Exclusive)
assert.Equal(t, false, defaultGroup.Persistent)
assert.Equal(t, []string{"model1"}, defaultGroup.Members)
}

model1, exists := config.Models["model1"]
assert.True(t, exists, "model1 should exist")
if assert.NotNil(t, model1, "model1 should not be nil") {
assert.Equal(t, "path/to/cmd --port 5800", model1.Cmd) // has the port replaced
assert.Equal(t, "", model1.CmdStop)
assert.Equal(t, "http://localhost:5800", model1.Proxy)
assert.Equal(t, "/health", model1.CheckEndpoint)
assert.Equal(t, []string{}, model1.Aliases)
assert.Equal(t, []string{}, model1.Env)
assert.Equal(t, 0, model1.UnloadAfter)
assert.Equal(t, false, model1.Unlisted)
assert.Equal(t, "", model1.UseModelName)
assert.Equal(t, 0, model1.ConcurrencyLimit)
}
}
11 changes: 0 additions & 11 deletions proxy/process.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ import (
"net/http"
"net/url"
"os/exec"
"runtime"
"strconv"
"strings"
"sync"
Expand Down Expand Up @@ -232,11 +231,6 @@ func (p *Process) start() error {

// a "none" means don't check for health ... I could have picked a better word :facepalm:
if checkEndpoint != "none" {
// keep default behaviour
if checkEndpoint == "" {
checkEndpoint = "/health"
}

proxyTo := p.config.Proxy
healthURL, err := url.JoinPath(proxyTo, checkEndpoint)
if err != nil {
Expand Down Expand Up @@ -523,11 +517,6 @@ func (p *Process) cmdStopUpstreamProcess() error {
return fmt.Errorf("<%s> process is nil or cmd is nil, skipping graceful stop", p.ID)
}

// the default cmdStop to taskkill /f /t /pid ${PID}
if runtime.GOOS == "windows" && strings.TrimSpace(p.config.CmdStop) == "" {
p.config.CmdStop = "taskkill /f /t /pid ${PID}"
}

if p.config.CmdStop != "" {
// replace ${PID} with the pid of the process
stopArgs, err := SanitizeCommand(strings.ReplaceAll(p.config.CmdStop, "${PID}", fmt.Sprintf("%d", p.cmd.Process.Pid)))
Expand Down
Loading