diff --git a/cmd/picoclaw/internal/gateway/command.go b/cmd/picoclaw/internal/gateway/command.go
index bfa69f0723..4812f1beef 100644
--- a/cmd/picoclaw/internal/gateway/command.go
+++ b/cmd/picoclaw/internal/gateway/command.go
@@ -5,6 +5,8 @@ import (
"github.com/spf13/cobra"
+ "github.com/sipeed/picoclaw/cmd/picoclaw/internal"
+ "github.com/sipeed/picoclaw/pkg/gateway"
"github.com/sipeed/picoclaw/pkg/logger"
"github.com/sipeed/picoclaw/pkg/utils"
)
@@ -12,6 +14,7 @@ import (
func NewGatewayCommand() *cobra.Command {
var debug bool
var noTruncate bool
+ var allowEmpty bool
cmd := &cobra.Command{
Use: "gateway",
@@ -31,12 +34,19 @@ func NewGatewayCommand() *cobra.Command {
return nil
},
RunE: func(_ *cobra.Command, _ []string) error {
- return gatewayCmd(debug)
+ return gateway.Run(debug, internal.GetConfigPath(), allowEmpty)
},
}
cmd.Flags().BoolVarP(&debug, "debug", "d", false, "Enable debug logging")
cmd.Flags().BoolVarP(&noTruncate, "no-truncate", "T", false, "Disable string truncation in debug logs")
+ cmd.Flags().BoolVarP(
+ &allowEmpty,
+ "allow-empty",
+ "E",
+ false,
+ "Continue starting even when no default model is configured",
+ )
return cmd
}
diff --git a/cmd/picoclaw/internal/gateway/command_test.go b/cmd/picoclaw/internal/gateway/command_test.go
index 4d591ea672..839a7315a0 100644
--- a/cmd/picoclaw/internal/gateway/command_test.go
+++ b/cmd/picoclaw/internal/gateway/command_test.go
@@ -28,4 +28,5 @@ func TestNewGatewayCommand(t *testing.T) {
assert.True(t, cmd.HasFlags())
assert.NotNil(t, cmd.Flags().Lookup("debug"))
+ assert.NotNil(t, cmd.Flags().Lookup("allow-empty"))
}
diff --git a/config/config.example.json b/config/config.example.json
index 1c11cd42a9..14e2092598 100644
--- a/config/config.example.json
+++ b/config/config.example.json
@@ -518,6 +518,7 @@
},
"gateway": {
"host": "127.0.0.1",
- "port": 18790
+ "port": 18790,
+ "hot_reload": false
}
}
diff --git a/pkg/channels/manager.go b/pkg/channels/manager.go
index df430e4d3a..8121525ab3 100644
--- a/pkg/channels/manager.go
+++ b/pkg/channels/manager.go
@@ -357,7 +357,6 @@ func (m *Manager) StartAll(ctx context.Context) error {
if len(m.channels) == 0 {
logger.WarnC("channels", "No channels enabled")
- return errors.New("no channels enabled")
}
logger.InfoC("channels", "Starting all channels")
diff --git a/pkg/config/config.go b/pkg/config/config.go
index 35de48f23b..6694ef3a10 100644
--- a/pkg/config/config.go
+++ b/pkg/config/config.go
@@ -625,8 +625,9 @@ func (c *ModelConfig) Validate() error {
}
type GatewayConfig struct {
- Host string `json:"host" env:"PICOCLAW_GATEWAY_HOST"`
- Port int `json:"port" env:"PICOCLAW_GATEWAY_PORT"`
+ Host string `json:"host" env:"PICOCLAW_GATEWAY_HOST"`
+ Port int `json:"port" env:"PICOCLAW_GATEWAY_PORT"`
+ HotReload bool `json:"hot_reload" env:"PICOCLAW_GATEWAY_HOT_RELOAD"`
}
type ToolDiscoveryConfig struct {
diff --git a/pkg/config/config_test.go b/pkg/config/config_test.go
index fc835f78f6..f4f8979e13 100644
--- a/pkg/config/config_test.go
+++ b/pkg/config/config_test.go
@@ -267,6 +267,9 @@ func TestDefaultConfig_Gateway(t *testing.T) {
if cfg.Gateway.Port == 0 {
t.Error("Gateway port should have default value")
}
+ if cfg.Gateway.HotReload {
+ t.Error("Gateway hot reload should be disabled by default")
+ }
}
// TestDefaultConfig_Providers verifies provider structure
diff --git a/pkg/config/defaults.go b/pkg/config/defaults.go
index 2b177d5de5..90a99408e5 100644
--- a/pkg/config/defaults.go
+++ b/pkg/config/defaults.go
@@ -395,8 +395,9 @@ func DefaultConfig() *Config {
},
},
Gateway: GatewayConfig{
- Host: "127.0.0.1",
- Port: 18790,
+ Host: "127.0.0.1",
+ Port: 18790,
+ HotReload: false,
},
Tools: ToolsConfig{
MediaCleanup: MediaCleanupConfig{
diff --git a/cmd/picoclaw/internal/gateway/helpers.go b/pkg/gateway/gateway.go
similarity index 61%
rename from cmd/picoclaw/internal/gateway/helpers.go
rename to pkg/gateway/gateway.go
index 85e93bcf96..6745d1748b 100644
--- a/cmd/picoclaw/internal/gateway/helpers.go
+++ b/pkg/gateway/gateway.go
@@ -10,7 +10,6 @@ import (
"syscall"
"time"
- "github.com/sipeed/picoclaw/cmd/picoclaw/internal"
"github.com/sipeed/picoclaw/pkg/agent"
"github.com/sipeed/picoclaw/pkg/bus"
"github.com/sipeed/picoclaw/pkg/channels"
@@ -42,15 +41,13 @@ import (
"github.com/sipeed/picoclaw/pkg/voice"
)
-// Timeout constants for service operations
const (
serviceShutdownTimeout = 30 * time.Second
providerReloadTimeout = 30 * time.Second
gracefulShutdownTimeout = 15 * time.Second
)
-// gatewayServices holds references to all running services
-type gatewayServices struct {
+type services struct {
CronService *cron.CronService
HeartbeatService *heartbeat.HeartbeatService
MediaStore media.MediaStore
@@ -59,24 +56,41 @@ type gatewayServices struct {
HealthServer *health.Server
}
-func gatewayCmd(debug bool) error {
+type startupBlockedProvider struct {
+ reason string
+}
+
+func (p *startupBlockedProvider) Chat(
+ _ context.Context,
+ _ []providers.Message,
+ _ []providers.ToolDefinition,
+ _ string,
+ _ map[string]any,
+) (*providers.LLMResponse, error) {
+ return nil, fmt.Errorf("%s", p.reason)
+}
+
+func (p *startupBlockedProvider) GetDefaultModel() string {
+ return ""
+}
+
+// Run starts the gateway runtime using the configuration loaded from configPath.
+func Run(debug bool, configPath string, allowEmptyStartup bool) error {
if debug {
logger.SetLevel(logger.DEBUG)
fmt.Println("๐ Debug mode enabled")
}
- configPath := internal.GetConfigPath()
- cfg, err := internal.LoadConfig()
+ cfg, err := config.LoadConfig(configPath)
if err != nil {
return fmt.Errorf("error loading config: %w", err)
}
- provider, modelID, err := providers.CreateProvider(cfg)
+ provider, modelID, err := createStartupProvider(cfg, allowEmptyStartup)
if err != nil {
return fmt.Errorf("error creating provider: %w", err)
}
- // Use the resolved model ID from provider creation
if modelID != "" {
cfg.Agents.Defaults.ModelName = modelID
}
@@ -84,17 +98,13 @@ func gatewayCmd(debug bool) error {
msgBus := bus.NewMessageBus()
agentLoop := agent.NewAgentLoop(cfg, msgBus, provider)
- // Print agent startup info
fmt.Println("\n๐ฆ Agent Status:")
startupInfo := agentLoop.GetStartupInfo()
toolsInfo := startupInfo["tools"].(map[string]any)
skillsInfo := startupInfo["skills"].(map[string]any)
fmt.Printf(" โข Tools: %d loaded\n", toolsInfo["count"])
- fmt.Printf(" โข Skills: %d/%d available\n",
- skillsInfo["available"],
- skillsInfo["total"])
+ fmt.Printf(" โข Skills: %d/%d available\n", skillsInfo["available"], skillsInfo["total"])
- // Log to file as well
logger.InfoCF("agent", "Agent initialized",
map[string]any{
"tools_count": toolsInfo["count"],
@@ -102,8 +112,7 @@ func gatewayCmd(debug bool) error {
"skills_available": skillsInfo["available"],
})
- // Setup and start all services
- services, err := setupAndStartServices(cfg, agentLoop, msgBus)
+ runningServices, err := setupAndStartServices(cfg, agentLoop, msgBus)
if err != nil {
return err
}
@@ -116,23 +125,25 @@ func gatewayCmd(debug bool) error {
go agentLoop.Run(ctx)
- // Setup config file watcher for hot reload
- configReloadChan, stopWatch := setupConfigWatcherPolling(configPath, debug)
+ var configReloadChan <-chan *config.Config
+ stopWatch := func() {}
+ if cfg.Gateway.HotReload {
+ configReloadChan, stopWatch = setupConfigWatcherPolling(configPath, debug)
+ logger.Info("Config hot reload enabled")
+ }
defer stopWatch()
sigChan := make(chan os.Signal, 1)
signal.Notify(sigChan, os.Interrupt, syscall.SIGTERM)
- // Main event loop - wait for signals or config changes
for {
select {
case <-sigChan:
logger.Info("Shutting down...")
- shutdownGateway(services, agentLoop, provider, true)
+ shutdownGateway(runningServices, agentLoop, provider, true)
return nil
-
case newCfg := <-configReloadChan:
- err := handleConfigReload(ctx, agentLoop, newCfg, &provider, services, msgBus)
+ err := handleConfigReload(ctx, agentLoop, newCfg, &provider, runningServices, msgBus, allowEmptyStartup)
if err != nil {
logger.Errorf("Config reload failed: %v", err)
}
@@ -140,18 +151,33 @@ func gatewayCmd(debug bool) error {
}
}
-// setupAndStartServices initializes and starts all services
+func createStartupProvider(
+ cfg *config.Config,
+ allowEmptyStartup bool,
+) (providers.LLMProvider, string, error) {
+ modelName := cfg.Agents.Defaults.GetModelName()
+ if modelName == "" && allowEmptyStartup {
+ reason := "no default model configured; gateway started in limited mode"
+ fmt.Printf("โ Warning: %s\n", reason)
+ logger.WarnCF("gateway", "Gateway started without default model", map[string]any{
+ "limited_mode": true,
+ })
+ return &startupBlockedProvider{reason: reason}, "", nil
+ }
+
+ return providers.CreateProvider(cfg)
+}
+
func setupAndStartServices(
cfg *config.Config,
agentLoop *agent.AgentLoop,
msgBus *bus.MessageBus,
-) (*gatewayServices, error) {
- services := &gatewayServices{}
+) (*services, error) {
+ runningServices := &services{}
- // Setup cron tool and service
execTimeout := time.Duration(cfg.Tools.Cron.ExecTimeoutMinutes) * time.Minute
var err error
- services.CronService, err = setupCronTool(
+ runningServices.CronService, err = setupCronTool(
agentLoop,
msgBus,
cfg.WorkspacePath(),
@@ -162,120 +188,105 @@ func setupAndStartServices(
if err != nil {
return nil, fmt.Errorf("error setting up cron service: %w", err)
}
- if err = services.CronService.Start(); err != nil {
+ if err = runningServices.CronService.Start(); err != nil {
return nil, fmt.Errorf("error starting cron service: %w", err)
}
fmt.Println("โ Cron service started")
- // Setup heartbeat service
- services.HeartbeatService = heartbeat.NewHeartbeatService(
+ runningServices.HeartbeatService = heartbeat.NewHeartbeatService(
cfg.WorkspacePath(),
cfg.Heartbeat.Interval,
cfg.Heartbeat.Enabled,
)
- services.HeartbeatService.SetBus(msgBus)
- services.HeartbeatService.SetHandler(createHeartbeatHandler(agentLoop))
- if err = services.HeartbeatService.Start(); err != nil {
+ runningServices.HeartbeatService.SetBus(msgBus)
+ runningServices.HeartbeatService.SetHandler(createHeartbeatHandler(agentLoop))
+ if err = runningServices.HeartbeatService.Start(); err != nil {
return nil, fmt.Errorf("error starting heartbeat service: %w", err)
}
fmt.Println("โ Heartbeat service started")
- // Create media store for file lifecycle management with TTL cleanup
- services.MediaStore = media.NewFileMediaStoreWithCleanup(media.MediaCleanerConfig{
+ runningServices.MediaStore = media.NewFileMediaStoreWithCleanup(media.MediaCleanerConfig{
Enabled: cfg.Tools.MediaCleanup.Enabled,
MaxAge: time.Duration(cfg.Tools.MediaCleanup.MaxAge) * time.Minute,
Interval: time.Duration(cfg.Tools.MediaCleanup.Interval) * time.Minute,
})
- // Start the media store if it's a FileMediaStore with cleanup
- if fms, ok := services.MediaStore.(*media.FileMediaStore); ok {
+ if fms, ok := runningServices.MediaStore.(*media.FileMediaStore); ok {
fms.Start()
}
- // Create channel manager
- services.ChannelManager, err = channels.NewManager(cfg, msgBus, services.MediaStore)
+ runningServices.ChannelManager, err = channels.NewManager(cfg, msgBus, runningServices.MediaStore)
if err != nil {
- // Stop the media store if it's a FileMediaStore with cleanup
- if fms, ok := services.MediaStore.(*media.FileMediaStore); ok {
+ if fms, ok := runningServices.MediaStore.(*media.FileMediaStore); ok {
fms.Stop()
}
return nil, fmt.Errorf("error creating channel manager: %w", err)
}
- // Inject channel manager and media store into agent loop
- agentLoop.SetChannelManager(services.ChannelManager)
- agentLoop.SetMediaStore(services.MediaStore)
+ agentLoop.SetChannelManager(runningServices.ChannelManager)
+ agentLoop.SetMediaStore(runningServices.MediaStore)
- // Wire up voice transcription if a supported provider is configured.
if transcriber := voice.DetectTranscriber(cfg); transcriber != nil {
agentLoop.SetTranscriber(transcriber)
logger.InfoCF("voice", "Transcription enabled (agent-level)", map[string]any{"provider": transcriber.Name()})
}
- enabledChannels := services.ChannelManager.GetEnabledChannels()
+ enabledChannels := runningServices.ChannelManager.GetEnabledChannels()
if len(enabledChannels) > 0 {
fmt.Printf("โ Channels enabled: %s\n", enabledChannels)
} else {
fmt.Println("โ Warning: No channels enabled")
}
- // Setup shared HTTP server with health endpoints and webhook handlers
addr := fmt.Sprintf("%s:%d", cfg.Gateway.Host, cfg.Gateway.Port)
- services.HealthServer = health.NewServer(cfg.Gateway.Host, cfg.Gateway.Port)
- services.ChannelManager.SetupHTTPServer(addr, services.HealthServer)
+ runningServices.HealthServer = health.NewServer(cfg.Gateway.Host, cfg.Gateway.Port)
+ runningServices.ChannelManager.SetupHTTPServer(addr, runningServices.HealthServer)
- if err = services.ChannelManager.StartAll(context.Background()); err != nil {
+ if err = runningServices.ChannelManager.StartAll(context.Background()); err != nil {
return nil, fmt.Errorf("error starting channels: %w", err)
}
fmt.Printf("โ Health endpoints available at http://%s:%d/health and /ready\n", cfg.Gateway.Host, cfg.Gateway.Port)
- // Setup state manager and device service
stateManager := state.NewManager(cfg.WorkspacePath())
- services.DeviceService = devices.NewService(devices.Config{
+ runningServices.DeviceService = devices.NewService(devices.Config{
Enabled: cfg.Devices.Enabled,
MonitorUSB: cfg.Devices.MonitorUSB,
}, stateManager)
- services.DeviceService.SetBus(msgBus)
- if err = services.DeviceService.Start(context.Background()); err != nil {
+ runningServices.DeviceService.SetBus(msgBus)
+ if err = runningServices.DeviceService.Start(context.Background()); err != nil {
logger.ErrorCF("device", "Error starting device service", map[string]any{"error": err.Error()})
} else if cfg.Devices.Enabled {
fmt.Println("โ Device event service started")
}
- return services, nil
+ return runningServices, nil
}
-// stopAndCleanupServices stops all services and cleans up resources
-func stopAndCleanupServices(
- services *gatewayServices,
- shutdownTimeout time.Duration,
-) {
+func stopAndCleanupServices(runningServices *services, shutdownTimeout time.Duration) {
shutdownCtx, shutdownCancel := context.WithTimeout(context.Background(), shutdownTimeout)
defer shutdownCancel()
- if services.ChannelManager != nil {
- services.ChannelManager.StopAll(shutdownCtx)
+ if runningServices.ChannelManager != nil {
+ runningServices.ChannelManager.StopAll(shutdownCtx)
}
- if services.DeviceService != nil {
- services.DeviceService.Stop()
+ if runningServices.DeviceService != nil {
+ runningServices.DeviceService.Stop()
}
- if services.HeartbeatService != nil {
- services.HeartbeatService.Stop()
+ if runningServices.HeartbeatService != nil {
+ runningServices.HeartbeatService.Stop()
}
- if services.CronService != nil {
- services.CronService.Stop()
+ if runningServices.CronService != nil {
+ runningServices.CronService.Stop()
}
- if services.MediaStore != nil {
- // Stop the media store if it's a FileMediaStore with cleanup
- if fms, ok := services.MediaStore.(*media.FileMediaStore); ok {
+ if runningServices.MediaStore != nil {
+ if fms, ok := runningServices.MediaStore.(*media.FileMediaStore); ok {
fms.Stop()
}
}
}
-// shutdownGateway performs a complete gateway shutdown
func shutdownGateway(
- services *gatewayServices,
+ runningServices *services,
agentLoop *agent.AgentLoop,
provider providers.LLMProvider,
fullShutdown bool,
@@ -284,7 +295,7 @@ func shutdownGateway(
cp.Close()
}
- stopAndCleanupServices(services, gracefulShutdownTimeout)
+ stopAndCleanupServices(runningServices, gracefulShutdownTimeout)
agentLoop.Stop()
agentLoop.Close()
@@ -292,15 +303,14 @@ func shutdownGateway(
logger.Info("โ Gateway stopped")
}
-// handleConfigReload handles config file reload by stopping all services,
-// reloading the provider and config, and restarting services with the new config.
func handleConfigReload(
ctx context.Context,
al *agent.AgentLoop,
newCfg *config.Config,
providerRef *providers.LLMProvider,
- services *gatewayServices,
+ runningServices *services,
msgBus *bus.MessageBus,
+ allowEmptyStartup bool,
) error {
logger.Info("๐ Config file changed, reloading...")
@@ -311,18 +321,14 @@ func handleConfigReload(
logger.Infof(" New model is '%s', recreating provider...", newModel)
- // Stop all services before reloading
logger.Info(" Stopping all services...")
- stopAndCleanupServices(services, serviceShutdownTimeout)
+ stopAndCleanupServices(runningServices, serviceShutdownTimeout)
- // Create new provider from updated config first to ensure validity
- // This will use the correct API key and settings from newCfg.ModelList
- newProvider, newModelID, err := providers.CreateProvider(newCfg)
+ newProvider, newModelID, err := createStartupProvider(newCfg, allowEmptyStartup)
if err != nil {
logger.Errorf(" โ Error creating new provider: %v", err)
logger.Warn(" Attempting to restart services with old provider and config...")
- // Try to restart services with old configuration
- if restartErr := restartServices(al, services, msgBus); restartErr != nil {
+ if restartErr := restartServices(al, runningServices, msgBus); restartErr != nil {
logger.Errorf(" โ Failed to restart services: %v", restartErr)
}
return fmt.Errorf("error creating new provider: %w", err)
@@ -332,31 +338,25 @@ func handleConfigReload(
newCfg.Agents.Defaults.ModelName = newModelID
}
- // Use the atomic reload method on AgentLoop to safely swap provider and config.
- // This handles locking internally to prevent races with in-flight LLM calls
- // and concurrent reads of registry/config while the swap occurs.
reloadCtx, reloadCancel := context.WithTimeout(context.Background(), providerReloadTimeout)
defer reloadCancel()
if err := al.ReloadProviderAndConfig(reloadCtx, newProvider, newCfg); err != nil {
logger.Errorf(" โ Error reloading agent loop: %v", err)
- // Close the newly created provider since it wasn't adopted
if cp, ok := newProvider.(providers.StatefulProvider); ok {
cp.Close()
}
logger.Warn(" Attempting to restart services with old provider and config...")
- if restartErr := restartServices(al, services, msgBus); restartErr != nil {
+ if restartErr := restartServices(al, runningServices, msgBus); restartErr != nil {
logger.Errorf(" โ Failed to restart services: %v", restartErr)
}
return fmt.Errorf("error reloading agent loop: %w", err)
}
- // Update local provider reference only after successful atomic reload
*providerRef = newProvider
- // Restart all services with new config
logger.Info(" Restarting all services with new configuration...")
- if err := restartServices(al, services, msgBus); err != nil {
+ if err := restartServices(al, runningServices, msgBus); err != nil {
logger.Errorf(" โ Error restarting services: %v", err)
return fmt.Errorf("error restarting services: %w", err)
}
@@ -365,19 +365,16 @@ func handleConfigReload(
return nil
}
-// restartServices restarts all services after a config reload
func restartServices(
al *agent.AgentLoop,
- services *gatewayServices,
+ runningServices *services,
msgBus *bus.MessageBus,
) error {
- // Get current config from agent loop (which has been updated if this is a reload)
cfg := al.GetConfig()
- // Re-create and start cron service with new config
execTimeout := time.Duration(cfg.Tools.Cron.ExecTimeoutMinutes) * time.Minute
var err error
- services.CronService, err = setupCronTool(
+ runningServices.CronService, err = setupCronTool(
al,
msgBus,
cfg.WorkspacePath(),
@@ -388,57 +385,51 @@ func restartServices(
if err != nil {
return fmt.Errorf("error restarting cron service: %w", err)
}
- if err = services.CronService.Start(); err != nil {
+ if err = runningServices.CronService.Start(); err != nil {
return fmt.Errorf("error restarting cron service: %w", err)
}
fmt.Println(" โ Cron service restarted")
- // Re-create and start heartbeat service with new config
- services.HeartbeatService = heartbeat.NewHeartbeatService(
+ runningServices.HeartbeatService = heartbeat.NewHeartbeatService(
cfg.WorkspacePath(),
cfg.Heartbeat.Interval,
cfg.Heartbeat.Enabled,
)
- services.HeartbeatService.SetBus(msgBus)
- services.HeartbeatService.SetHandler(createHeartbeatHandler(al))
- if err = services.HeartbeatService.Start(); err != nil {
+ runningServices.HeartbeatService.SetBus(msgBus)
+ runningServices.HeartbeatService.SetHandler(createHeartbeatHandler(al))
+ if err = runningServices.HeartbeatService.Start(); err != nil {
return fmt.Errorf("error restarting heartbeat service: %w", err)
}
fmt.Println(" โ Heartbeat service restarted")
- // Re-create media store with new config
- services.MediaStore = media.NewFileMediaStoreWithCleanup(media.MediaCleanerConfig{
+ runningServices.MediaStore = media.NewFileMediaStoreWithCleanup(media.MediaCleanerConfig{
Enabled: cfg.Tools.MediaCleanup.Enabled,
MaxAge: time.Duration(cfg.Tools.MediaCleanup.MaxAge) * time.Minute,
Interval: time.Duration(cfg.Tools.MediaCleanup.Interval) * time.Minute,
})
- // Start the media store if it's a FileMediaStore with cleanup
- if fms, ok := services.MediaStore.(*media.FileMediaStore); ok {
+ if fms, ok := runningServices.MediaStore.(*media.FileMediaStore); ok {
fms.Start()
}
- al.SetMediaStore(services.MediaStore)
+ al.SetMediaStore(runningServices.MediaStore)
- // Re-create channel manager with new config
- services.ChannelManager, err = channels.NewManager(cfg, msgBus, services.MediaStore)
+ runningServices.ChannelManager, err = channels.NewManager(cfg, msgBus, runningServices.MediaStore)
if err != nil {
return fmt.Errorf("error recreating channel manager: %w", err)
}
- al.SetChannelManager(services.ChannelManager)
+ al.SetChannelManager(runningServices.ChannelManager)
- enabledChannels := services.ChannelManager.GetEnabledChannels()
+ enabledChannels := runningServices.ChannelManager.GetEnabledChannels()
if len(enabledChannels) > 0 {
fmt.Printf(" โ Channels enabled: %s\n", enabledChannels)
} else {
fmt.Println(" โ Warning: No channels enabled")
}
- // Setup HTTP server with new config
addr := fmt.Sprintf("%s:%d", cfg.Gateway.Host, cfg.Gateway.Port)
- services.HealthServer = health.NewServer(cfg.Gateway.Host, cfg.Gateway.Port)
- services.ChannelManager.SetupHTTPServer(addr, services.HealthServer)
+ runningServices.HealthServer = health.NewServer(cfg.Gateway.Host, cfg.Gateway.Port)
+ runningServices.ChannelManager.SetupHTTPServer(addr, runningServices.HealthServer)
- // Use background context for lifecycle to ensure services persist after restartServices returns
- if err = services.ChannelManager.StartAll(context.Background()); err != nil {
+ if err = runningServices.ChannelManager.StartAll(context.Background()); err != nil {
return fmt.Errorf("error restarting channels: %w", err)
}
fmt.Printf(
@@ -447,22 +438,20 @@ func restartServices(
cfg.Gateway.Port,
)
- // Re-create device service with new config
stateManager := state.NewManager(cfg.WorkspacePath())
- services.DeviceService = devices.NewService(devices.Config{
+ runningServices.DeviceService = devices.NewService(devices.Config{
Enabled: cfg.Devices.Enabled,
MonitorUSB: cfg.Devices.MonitorUSB,
}, stateManager)
- services.DeviceService.SetBus(msgBus)
- if err := services.DeviceService.Start(context.Background()); err != nil {
+ runningServices.DeviceService.SetBus(msgBus)
+ if err := runningServices.DeviceService.Start(context.Background()); err != nil {
logger.WarnCF("device", "Failed to restart device service", map[string]any{"error": err.Error()})
} else if cfg.Devices.Enabled {
fmt.Println(" โ Device event service restarted")
}
- // Wire up voice transcription with new config
transcriber := voice.DetectTranscriber(cfg)
- al.SetTranscriber(transcriber) // This will set it to nil if disabled
+ al.SetTranscriber(transcriber)
if transcriber != nil {
logger.InfoCF("voice", "Transcription re-enabled (agent-level)", map[string]any{"provider": transcriber.Name()})
} else {
@@ -472,8 +461,6 @@ func restartServices(
return nil
}
-// setupConfigWatcherPolling sets up a simple polling-based config file watcher
-// Returns a channel for config updates and a stop function
func setupConfigWatcherPolling(configPath string, debug bool) (chan *config.Config, func()) {
configChan := make(chan *config.Config, 1)
stop := make(chan struct{})
@@ -483,11 +470,10 @@ func setupConfigWatcherPolling(configPath string, debug bool) (chan *config.Conf
go func() {
defer wg.Done()
- // Get initial file info
lastModTime := getFileModTime(configPath)
lastSize := getFileSize(configPath)
- ticker := time.NewTicker(2 * time.Second) // Check every 2 seconds
+ ticker := time.NewTicker(2 * time.Second)
defer ticker.Stop()
for {
@@ -496,20 +482,16 @@ func setupConfigWatcherPolling(configPath string, debug bool) (chan *config.Conf
currentModTime := getFileModTime(configPath)
currentSize := getFileSize(configPath)
- // Check if file changed (modification time or size changed)
if currentModTime.After(lastModTime) || currentSize != lastSize {
if debug {
logger.Debugf("๐ Config file change detected")
}
- // Debounce - wait a bit to ensure file write is complete
time.Sleep(500 * time.Millisecond)
- // Update last known state to prevent repeated reload attempts on failure
lastModTime = currentModTime
lastSize = currentSize
- // Validate and load new config
newCfg, err := config.LoadConfig(configPath)
if err != nil {
logger.Errorf("โ Error loading new config: %v", err)
@@ -517,7 +499,6 @@ func setupConfigWatcherPolling(configPath string, debug bool) (chan *config.Conf
continue
}
- // Validate the new config
if err := newCfg.ValidateModelList(); err != nil {
logger.Errorf(" โ New config validation failed: %v", err)
logger.Warn(" Using previous valid config")
@@ -526,15 +507,12 @@ func setupConfigWatcherPolling(configPath string, debug bool) (chan *config.Conf
logger.Info("โ Config file validated and loaded")
- // Send new config to main loop (non-blocking)
select {
case configChan <- newCfg:
default:
- // Channel full, skip this update
logger.Warn("โ Previous config reload still in progress, skipping")
}
}
-
case <-stop:
return
}
@@ -549,7 +527,6 @@ func setupConfigWatcherPolling(configPath string, debug bool) (chan *config.Conf
return configChan, stopFunc
}
-// getFileModTime returns the modification time of a file, or zero time if file doesn't exist
func getFileModTime(path string) time.Time {
info, err := os.Stat(path)
if err != nil {
@@ -558,7 +535,6 @@ func getFileModTime(path string) time.Time {
return info.ModTime()
}
-// getFileSize returns the size of a file, or 0 if file doesn't exist
func getFileSize(path string) int64 {
info, err := os.Stat(path)
if err != nil {
@@ -577,10 +553,8 @@ func setupCronTool(
) (*cron.CronService, error) {
cronStorePath := filepath.Join(workspace, "cron", "jobs.json")
- // Create cron service
cronService := cron.NewCronService(cronStorePath, nil)
- // Create and register CronTool if enabled
var cronTool *tools.CronTool
if cfg.Tools.IsToolEnabled("cron") {
var err error
@@ -592,7 +566,6 @@ func setupCronTool(
agentLoop.RegisterTool(cronTool)
}
- // Set onJob handler
if cronTool != nil {
cronService.SetOnJob(func(job *cron.CronJob) (string, error) {
result := cronTool.ExecuteJob(context.Background(), job)
@@ -605,22 +578,17 @@ func setupCronTool(
func createHeartbeatHandler(agentLoop *agent.AgentLoop) func(prompt, channel, chatID string) *tools.ToolResult {
return func(prompt, channel, chatID string) *tools.ToolResult {
- // Use cli:direct as fallback if no valid channel
if channel == "" || chatID == "" {
channel, chatID = "cli", "direct"
}
- // Use ProcessHeartbeat - no session history, each heartbeat is independent
- var response string
- var err error
- response, err = agentLoop.ProcessHeartbeat(context.Background(), prompt, channel, chatID)
+
+ response, err := agentLoop.ProcessHeartbeat(context.Background(), prompt, channel, chatID)
if err != nil {
return tools.ErrorResult(fmt.Sprintf("Heartbeat error: %v", err))
}
if response == "HEARTBEAT_OK" {
return tools.SilentResult("Heartbeat OK")
}
- // For heartbeat, always return silent - the subagent result will be
- // sent to user via processSystemMessage when the async task completes
return tools.SilentResult(response)
}
}
diff --git a/web/backend/api/events.go b/web/backend/api/events.go
deleted file mode 100644
index 5c85b149ad..0000000000
--- a/web/backend/api/events.go
+++ /dev/null
@@ -1,80 +0,0 @@
-package api
-
-import (
- "encoding/json"
- "sync"
-)
-
-// GatewayEvent represents a state change event for the gateway process.
-type GatewayEvent struct {
- Status string `json:"gateway_status"` // "running", "starting", "restarting", "stopped", "error"
- PID int `json:"pid,omitempty"`
- BootDefaultModel string `json:"boot_default_model,omitempty"`
- ConfigDefaultModel string `json:"config_default_model,omitempty"`
- RestartRequired bool `json:"gateway_restart_required,omitempty"`
-}
-
-// EventBroadcaster manages SSE client subscriptions and broadcasts events.
-type EventBroadcaster struct {
- mu sync.RWMutex
- clients map[chan string]struct{}
-}
-
-// NewEventBroadcaster creates a new broadcaster.
-func NewEventBroadcaster() *EventBroadcaster {
- return &EventBroadcaster{
- clients: make(map[chan string]struct{}),
- }
-}
-
-// Subscribe adds a new listener channel and returns it.
-// The caller must call Unsubscribe when done.
-func (b *EventBroadcaster) Subscribe() chan string {
- ch := make(chan string, 8)
- b.mu.Lock()
- b.clients[ch] = struct{}{}
- b.mu.Unlock()
- return ch
-}
-
-// Unsubscribe removes a listener channel and closes it.
-func (b *EventBroadcaster) Unsubscribe(ch chan string) {
- b.mu.Lock()
- defer b.mu.Unlock()
-
- // Check if the channel is still registered before closing
- if _, exists := b.clients[ch]; exists {
- delete(b.clients, ch)
- close(ch)
- }
-}
-
-// Broadcast sends a GatewayEvent to all connected SSE clients.
-func (b *EventBroadcaster) Broadcast(event GatewayEvent) {
- data, err := json.Marshal(event)
- if err != nil {
- return
- }
-
- b.mu.RLock()
- defer b.mu.RUnlock()
-
- for ch := range b.clients {
- // Non-blocking send; drop event if client is slow
- select {
- case ch <- string(data):
- default:
- }
- }
-}
-
-// Shutdown closes all subscriber channels, notifying all SSE clients to disconnect.
-// This should be called when the server is shutting down.
-func (b *EventBroadcaster) Shutdown() {
- // Close all channels to notify listeners
- for ch := range b.clients {
- b.Unsubscribe(ch)
- }
- // Clear the map
- b.clients = make(map[chan string]struct{})
-}
diff --git a/web/backend/api/gateway.go b/web/backend/api/gateway.go
index 424b21e96c..16b793427a 100644
--- a/web/backend/api/gateway.go
+++ b/web/backend/api/gateway.go
@@ -7,6 +7,7 @@ import (
"fmt"
"io"
"log"
+ "net"
"net/http"
"os"
"os/exec"
@@ -30,11 +31,9 @@ var gateway = struct {
runtimeStatus string
startupDeadline time.Time
logs *LogBuffer
- events *EventBroadcaster
}{
runtimeStatus: "stopped",
logs: NewLogBuffer(200),
- events: NewEventBroadcaster(),
}
var (
@@ -51,11 +50,19 @@ var gatewayHealthGet = func(url string, timeout time.Duration) (*http.Response,
// getGatewayHealth checks the gateway health endpoint and returns the status response
// Returns (*health.StatusResponse, statusCode, error). If error is not nil, the other values are not valid.
-func getGatewayHealth(port int, timeout time.Duration) (*health.StatusResponse, int, error) {
- if port == 0 {
- port = 18790
+func (h *Handler) getGatewayHealth(cfg *config.Config, timeout time.Duration) (*health.StatusResponse, int, error) {
+ port := 18790
+ if cfg != nil && cfg.Gateway.Port != 0 {
+ port = cfg.Gateway.Port
}
- url := fmt.Sprintf("http://127.0.0.1:%d/health", port)
+
+ probeHost := gatewayProbeHost(h.effectiveGatewayBindHost(cfg))
+ url := "http://" + net.JoinHostPort(probeHost, strconv.Itoa(port)) + "/health"
+
+ return getGatewayHealthByURL(url, timeout)
+}
+
+func getGatewayHealthByURL(url string, timeout time.Duration) (*health.StatusResponse, int, error) {
resp, err := gatewayHealthGet(url, timeout)
if err != nil {
return nil, 0, err
@@ -73,7 +80,6 @@ func getGatewayHealth(port int, timeout time.Duration) (*health.StatusResponse,
// registerGatewayRoutes binds gateway lifecycle endpoints to the ServeMux.
func (h *Handler) registerGatewayRoutes(mux *http.ServeMux) {
mux.HandleFunc("GET /api/gateway/status", h.handleGatewayStatus)
- mux.HandleFunc("GET /api/gateway/events", h.handleGatewayEvents)
mux.HandleFunc("GET /api/gateway/logs", h.handleGatewayLogs)
mux.HandleFunc("POST /api/gateway/logs/clear", h.handleGatewayClearLogs)
mux.HandleFunc("POST /api/gateway/start", h.handleGatewayStart)
@@ -87,7 +93,7 @@ func (h *Handler) TryAutoStartGateway() {
// Check if gateway is already running via health endpoint
cfg, cfgErr := config.LoadConfig(h.configPath)
if cfgErr == nil && cfg != nil {
- healthResp, statusCode, err := getGatewayHealth(cfg.Gateway.Port, 2*time.Second)
+ healthResp, statusCode, err := h.getGatewayHealth(cfg, 2*time.Second)
if err == nil && statusCode == http.StatusOK {
// Gateway is already running, attach to the existing process
pid := healthResp.Pid
@@ -170,6 +176,16 @@ func lookupModelConfig(cfg *config.Config, modelName string) *config.ModelConfig
return modelCfg
}
+func gatewayRestartRequired(configDefaultModel, bootDefaultModel, gatewayStatus string) bool {
+ if gatewayStatus != "running" {
+ return false
+ }
+ if strings.TrimSpace(configDefaultModel) == "" || strings.TrimSpace(bootDefaultModel) == "" {
+ return false
+ }
+ return configDefaultModel != bootDefaultModel
+}
+
func isCmdProcessAliveLocked(cmd *exec.Cmd) bool {
if cmd == nil || cmd.Process == nil {
return false
@@ -220,7 +236,7 @@ func attachToGatewayProcessLocked(pid int, cfg *config.Config) error {
return nil
}
-func gatewayStatusOnHealthFailureLocked() string {
+func gatewayStatusWithoutHealthLocked() string {
if gateway.runtimeStatus == "starting" || gateway.runtimeStatus == "restarting" {
if gateway.startupDeadline.IsZero() || time.Now().Before(gateway.startupDeadline) {
return gateway.runtimeStatus
@@ -233,23 +249,7 @@ func gatewayStatusOnHealthFailureLocked() string {
if gateway.runtimeStatus == "error" {
return "error"
}
- return "error"
-}
-
-func currentGatewayStatusLocked(processAlive bool) string {
- if !processAlive {
- if gateway.runtimeStatus == "restarting" {
- if gateway.startupDeadline.IsZero() || time.Now().Before(gateway.startupDeadline) {
- return "restarting"
- }
- return "error"
- }
- if gateway.runtimeStatus == "error" {
- return "error"
- }
- return "stopped"
- }
- return gatewayStatusOnHealthFailureLocked()
+ return "stopped"
}
func waitForGatewayProcessExit(cmd *exec.Cmd, timeout time.Duration) bool {
@@ -319,15 +319,6 @@ func (h *Handler) startGatewayLocked(initialStatus string, existingPid int) (int
return 0, err
}
- // Broadcast the attached state
- gateway.events.Broadcast(GatewayEvent{
- Status: initialStatus,
- PID: pid,
- BootDefaultModel: defaultModelName,
- ConfigDefaultModel: defaultModelName,
- RestartRequired: false,
- })
-
return pid, nil
}
@@ -335,7 +326,7 @@ func (h *Handler) startGatewayLocked(initialStatus string, existingPid int) (int
// Locate the picoclaw executable
execPath := utils.FindPicoclawBinary()
- cmd = exec.Command(execPath, "gateway")
+ cmd = exec.Command(execPath, "gateway", "-E")
cmd.Env = os.Environ()
// Forward the launcher's config path via the environment variable that
// GetConfigPath() already reads, so the gateway sub-process uses the same
@@ -376,15 +367,6 @@ func (h *Handler) startGatewayLocked(initialStatus string, existingPid int) (int
pid = cmd.Process.Pid
log.Printf("Started picoclaw gateway (PID: %d) from %s", pid, execPath)
- // Broadcast the launch state immediately so clients can reflect it without polling.
- gateway.events.Broadcast(GatewayEvent{
- Status: initialStatus,
- PID: pid,
- BootDefaultModel: defaultModelName,
- ConfigDefaultModel: defaultModelName,
- RestartRequired: false,
- })
-
// Capture stdout/stderr in background
go scanPipe(stdoutPipe, gateway.logs)
go scanPipe(stderrPipe, gateway.logs)
@@ -398,26 +380,17 @@ func (h *Handler) startGatewayLocked(initialStatus string, existingPid int) (int
}
gateway.mu.Lock()
- shouldBroadcastStopped := false
if gateway.cmd == cmd {
gateway.cmd = nil
gateway.bootDefaultModel = ""
if gateway.runtimeStatus != "restarting" {
setGatewayRuntimeStatusLocked("stopped")
- shouldBroadcastStopped = true
}
}
gateway.mu.Unlock()
-
- if shouldBroadcastStopped {
- gateway.events.Broadcast(GatewayEvent{
- Status: "stopped",
- RestartRequired: false,
- })
- }
}()
- // Start a goroutine to probe health and broadcast "running" once ready
+ // Start a goroutine to probe health and update the runtime state once ready.
go func() {
for i := 0; i < 30; i++ { // try for up to 15 seconds
time.Sleep(500 * time.Millisecond)
@@ -431,7 +404,7 @@ func (h *Handler) startGatewayLocked(initialStatus string, existingPid int) (int
if err != nil {
continue
}
- healthResp, statusCode, err := getGatewayHealth(cfg.Gateway.Port, 1*time.Second)
+ healthResp, statusCode, err := h.getGatewayHealth(cfg, 1*time.Second)
if err == nil && statusCode == http.StatusOK && healthResp.Pid == pid {
// Verify the health endpoint returns the expected pid
gateway.mu.Lock()
@@ -439,13 +412,6 @@ func (h *Handler) startGatewayLocked(initialStatus string, existingPid int) (int
setGatewayRuntimeStatusLocked("running")
}
gateway.mu.Unlock()
- gateway.events.Broadcast(GatewayEvent{
- Status: "running",
- PID: pid,
- BootDefaultModel: defaultModelName,
- ConfigDefaultModel: defaultModelName,
- RestartRequired: false,
- })
return
}
}
@@ -461,7 +427,7 @@ func (h *Handler) handleGatewayStart(w http.ResponseWriter, r *http.Request) {
// Prevent duplicate starts by checking health endpoint
cfg, cfgErr := config.LoadConfig(h.configPath)
if cfgErr == nil && cfg != nil {
- healthResp, statusCode, err := getGatewayHealth(cfg.Gateway.Port, 2*time.Second)
+ healthResp, statusCode, err := h.getGatewayHealth(cfg, 2*time.Second)
if err == nil && statusCode == http.StatusOK {
// Gateway is already running, attach to the existing process
pid := healthResp.Pid
@@ -597,10 +563,6 @@ func (h *Handler) RestartGateway() (int, error) {
gateway.mu.Lock()
previousCmd := gateway.cmd
setGatewayRuntimeStatusLocked("restarting")
- gateway.events.Broadcast(GatewayEvent{
- Status: "restarting",
- RestartRequired: false,
- })
gateway.mu.Unlock()
if err = stopGatewayProcessForRestart(previousCmd); err != nil {
@@ -704,24 +666,20 @@ func (h *Handler) handleGatewayStatus(w http.ResponseWriter, r *http.Request) {
func (h *Handler) gatewayStatusData() map[string]any {
data := map[string]any{}
+ configDefaultModel := ""
cfg, cfgErr := config.LoadConfig(h.configPath)
if cfgErr == nil && cfg != nil {
- configDefaultModel := strings.TrimSpace(cfg.Agents.Defaults.GetModelName())
+ configDefaultModel = strings.TrimSpace(cfg.Agents.Defaults.GetModelName())
if configDefaultModel != "" {
data["config_default_model"] = configDefaultModel
}
}
// Probe health endpoint to get pid and status
- port := 0
- if cfgErr == nil && cfg != nil {
- port = cfg.Gateway.Port
- }
-
- healthResp, statusCode, err := getGatewayHealth(port, 2*time.Second)
+ healthResp, statusCode, err := h.getGatewayHealth(cfg, 2*time.Second)
if err != nil {
gateway.mu.Lock()
- data["gateway_status"] = currentGatewayStatusLocked(true)
+ data["gateway_status"] = gatewayStatusWithoutHealthLocked()
gateway.mu.Unlock()
log.Printf("Gateway health check failed: %v", err)
} else {
@@ -734,45 +692,43 @@ func (h *Handler) gatewayStatusData() map[string]any {
data["status_code"] = statusCode
} else {
gateway.mu.Lock()
- // Check if this pid matches our tracked process
- if gateway.cmd != nil && gateway.cmd.Process != nil && gateway.cmd.Process.Pid == healthResp.Pid {
- setGatewayRuntimeStatusLocked("running")
- bootDefaultModel := gateway.bootDefaultModel
- if bootDefaultModel != "" {
- data["boot_default_model"] = bootDefaultModel
- }
- data["gateway_status"] = "running"
- data["pid"] = healthResp.Pid
- } else {
- // Health endpoint responded with a different pid
- // This could be a manual restart, try to attach to the new process
+ setGatewayRuntimeStatusLocked("running")
+ if gateway.cmd == nil || gateway.cmd.Process == nil || gateway.cmd.Process.Pid != healthResp.Pid {
oldPid := "none"
if gateway.cmd != nil && gateway.cmd.Process != nil {
oldPid = fmt.Sprintf("%d", gateway.cmd.Process.Pid)
}
- log.Printf("Detected new gateway PID (old: %s, new: %d), attempting to attach", oldPid, healthResp.Pid)
-
+ log.Printf(
+ "Detected gateway PID from health (old: %s, new: %d), attempting to attach",
+ oldPid,
+ healthResp.Pid,
+ )
if err := attachToGatewayProcessLocked(healthResp.Pid, cfg); err != nil {
- // Failed to find the process, treat as error
- setGatewayRuntimeStatusLocked("error")
- data["gateway_status"] = "error"
- data["pid"] = healthResp.Pid
- log.Printf("Failed to attach to new gateway process (PID: %d): %v", healthResp.Pid, err)
- } else {
- // Successfully attached, update response data
- bootDefaultModel := gateway.bootDefaultModel
- if bootDefaultModel != "" {
- data["boot_default_model"] = bootDefaultModel
- }
- data["gateway_status"] = "running"
- data["pid"] = healthResp.Pid
+ log.Printf(
+ "Failed to attach to gateway process reported by health (PID: %d): %v",
+ healthResp.Pid,
+ err,
+ )
}
}
+
+ bootDefaultModel := gateway.bootDefaultModel
+ if bootDefaultModel != "" {
+ data["boot_default_model"] = bootDefaultModel
+ }
+ data["gateway_status"] = "running"
+ data["pid"] = healthResp.Pid
gateway.mu.Unlock()
}
}
- data["gateway_restart_required"] = false
+ bootDefaultModel, _ := data["boot_default_model"].(string)
+ gatewayStatus, _ := data["gateway_status"].(string)
+ data["gateway_restart_required"] = gatewayRestartRequired(
+ configDefaultModel,
+ bootDefaultModel,
+ gatewayStatus,
+ )
ready, reason, readyErr := h.gatewayStartReady()
if readyErr != nil {
@@ -842,51 +798,6 @@ func gatewayLogsData(r *http.Request) map[string]any {
return data
}
-// handleGatewayEvents serves an SSE stream of gateway state change events.
-//
-// GET /api/gateway/events
-func (h *Handler) handleGatewayEvents(w http.ResponseWriter, r *http.Request) {
- flusher, ok := w.(http.Flusher)
- if !ok {
- http.Error(w, "SSE not supported", http.StatusInternalServerError)
- return
- }
-
- w.Header().Set("Content-Type", "text/event-stream")
- w.Header().Set("Cache-Control", "no-cache")
- w.Header().Set("Connection", "keep-alive")
- w.Header().Set("Access-Control-Allow-Origin", "*")
-
- // Subscribe to gateway events
- ch := gateway.events.Subscribe()
- defer gateway.events.Unsubscribe(ch)
-
- // Send initial status so the client doesn't start blank
- initial := h.currentGatewayStatus()
- fmt.Fprintf(w, "data: %s\n\n", initial)
- flusher.Flush()
-
- for {
- select {
- case <-r.Context().Done():
- return
- case data, ok := <-ch:
- if !ok {
- return
- }
- fmt.Fprintf(w, "data: %s\n\n", data)
- flusher.Flush()
- }
- }
-}
-
-// currentGatewayStatus returns the current gateway status as a JSON string.
-func (h *Handler) currentGatewayStatus() string {
- data := h.gatewayStatusData()
- encoded, _ := json.Marshal(data)
- return string(encoded)
-}
-
// scanPipe reads lines from r and appends them to buf. Returns when r reaches EOF.
func scanPipe(r io.Reader, buf *LogBuffer) {
scanner := bufio.NewScanner(r)
diff --git a/web/backend/api/gateway_host.go b/web/backend/api/gateway_host.go
index 8dde29b76f..592571a286 100644
--- a/web/backend/api/gateway_host.go
+++ b/web/backend/api/gateway_host.go
@@ -3,6 +3,7 @@ package api
import (
"net"
"net/http"
+ "net/url"
"strconv"
"strings"
@@ -46,6 +47,23 @@ func gatewayProbeHost(bindHost string) string {
return bindHost
}
+func (h *Handler) gatewayProxyURL() *url.URL {
+ cfg, err := config.LoadConfig(h.configPath)
+ port := 18790
+ bindHost := ""
+ if err == nil && cfg != nil {
+ if cfg.Gateway.Port != 0 {
+ port = cfg.Gateway.Port
+ }
+ bindHost = h.effectiveGatewayBindHost(cfg)
+ }
+
+ return &url.URL{
+ Scheme: "http",
+ Host: net.JoinHostPort(gatewayProbeHost(bindHost), strconv.Itoa(port)),
+ }
+}
+
func requestHostName(r *http.Request) string {
reqHost, _, err := net.SplitHostPort(r.Host)
if err == nil {
diff --git a/web/backend/api/gateway_host_test.go b/web/backend/api/gateway_host_test.go
index 3fffeb8939..ae3434862f 100644
--- a/web/backend/api/gateway_host_test.go
+++ b/web/backend/api/gateway_host_test.go
@@ -2,9 +2,12 @@ package api
import (
"crypto/tls"
+ "errors"
+ "net/http"
"net/http/httptest"
"path/filepath"
"testing"
+ "time"
"github.com/sipeed/picoclaw/pkg/config"
"github.com/sipeed/picoclaw/web/backend/launcherconfig"
@@ -59,6 +62,79 @@ func TestGatewayProbeHostUsesLoopbackForWildcardBind(t *testing.T) {
}
}
+func TestGatewayProxyURLUsesConfiguredHost(t *testing.T) {
+ configPath := filepath.Join(t.TempDir(), "config.json")
+ h := NewHandler(configPath)
+
+ cfg := config.DefaultConfig()
+ cfg.Gateway.Host = "192.168.1.10"
+ cfg.Gateway.Port = 18791
+ if err := config.SaveConfig(configPath, cfg); err != nil {
+ t.Fatalf("SaveConfig() error = %v", err)
+ }
+
+ if got := h.gatewayProxyURL().String(); got != "http://192.168.1.10:18791" {
+ t.Fatalf("gatewayProxyURL() = %q, want %q", got, "http://192.168.1.10:18791")
+ }
+}
+
+func TestGetGatewayHealthUsesConfiguredHost(t *testing.T) {
+ configPath := filepath.Join(t.TempDir(), "config.json")
+ h := NewHandler(configPath)
+
+ cfg := config.DefaultConfig()
+ cfg.Gateway.Host = "192.168.1.10"
+ cfg.Gateway.Port = 18791
+
+ originalHealthGet := gatewayHealthGet
+ t.Cleanup(func() {
+ gatewayHealthGet = originalHealthGet
+ })
+
+ var requestedURL string
+ gatewayHealthGet = func(url string, timeout time.Duration) (*http.Response, error) {
+ requestedURL = url
+ return nil, errors.New("probe failed")
+ }
+
+ _, statusCode, err := h.getGatewayHealth(cfg, time.Second)
+ _ = statusCode
+ _ = err
+
+ if requestedURL != "http://192.168.1.10:18791/health" {
+ t.Fatalf("health url = %q, want %q", requestedURL, "http://192.168.1.10:18791/health")
+ }
+}
+
+func TestGetGatewayHealthUsesProbeHostForPublicLauncher(t *testing.T) {
+ configPath := filepath.Join(t.TempDir(), "config.json")
+ h := NewHandler(configPath)
+ h.SetServerOptions(18800, true, true, nil)
+
+ cfg := config.DefaultConfig()
+ cfg.Gateway.Host = "127.0.0.1"
+ cfg.Gateway.Port = 18791
+
+ originalHealthGet := gatewayHealthGet
+ t.Cleanup(func() {
+ gatewayHealthGet = originalHealthGet
+ })
+
+ var requestedURL string
+ gatewayHealthGet = func(url string, timeout time.Duration) (*http.Response, error) {
+ requestedURL = url
+ return nil, errors.New("probe failed")
+ }
+
+ _, statusCode, err := h.getGatewayHealth(cfg, time.Second)
+ _ = statusCode
+ _ = err
+
+ if requestedURL != "http://127.0.0.1:18791/health" {
+ t.Fatalf("health url = %q, want %q", requestedURL, "http://127.0.0.1:18791/health")
+ }
+}
+
func TestBuildWsURLUsesWSSWhenForwardedProtoIsHTTPS(t *testing.T) {
configPath := filepath.Join(t.TempDir(), "config.json")
h := NewHandler(configPath)
diff --git a/web/backend/api/gateway_test.go b/web/backend/api/gateway_test.go
index fb4f7d9438..5c94f0b891 100644
--- a/web/backend/api/gateway_test.go
+++ b/web/backend/api/gateway_test.go
@@ -3,6 +3,7 @@ package api
import (
"encoding/json"
"errors"
+ "io"
"net/http"
"net/http/httptest"
"os"
@@ -36,6 +37,15 @@ func startLongRunningProcess(t *testing.T) *exec.Cmd {
return cmd
}
+func mockGatewayHealthResponse(statusCode, pid int) *http.Response {
+ return &http.Response{
+ StatusCode: statusCode,
+ Body: io.NopCloser(strings.NewReader(
+ `{"status":"ok","uptime":"1s","pid":` + strconv.Itoa(pid) + `}`,
+ )),
+ }
+}
+
func startIgnoringTermProcess(t *testing.T) *exec.Cmd {
t.Helper()
@@ -419,6 +429,125 @@ func TestGatewayStatusKeepsRunningWhenHealthProbeFailsAfterRunning(t *testing.T)
}
}
+func TestGatewayStatusReportsRunningFromHealthProbe(t *testing.T) {
+ resetGatewayTestState(t)
+
+ configPath := filepath.Join(t.TempDir(), "config.json")
+ h := NewHandler(configPath)
+ mux := http.NewServeMux()
+ h.RegisterRoutes(mux)
+
+ cmd := startLongRunningProcess(t)
+ t.Cleanup(func() {
+ if cmd.Process != nil {
+ _ = cmd.Process.Kill()
+ }
+ _ = cmd.Wait()
+ })
+
+ gateway.mu.Lock()
+ setGatewayRuntimeStatusLocked("stopped")
+ gateway.mu.Unlock()
+
+ gatewayHealthGet = func(string, time.Duration) (*http.Response, error) {
+ return mockGatewayHealthResponse(http.StatusOK, cmd.Process.Pid), nil
+ }
+
+ rec := httptest.NewRecorder()
+ req := httptest.NewRequest(http.MethodGet, "/api/gateway/status", nil)
+ mux.ServeHTTP(rec, req)
+
+ if rec.Code != http.StatusOK {
+ t.Fatalf("status = %d, want %d", rec.Code, http.StatusOK)
+ }
+
+ var body map[string]any
+ if err := json.Unmarshal(rec.Body.Bytes(), &body); err != nil {
+ t.Fatalf("unmarshal response: %v", err)
+ }
+
+ if got := body["gateway_status"]; got != "running" {
+ t.Fatalf("gateway_status = %#v, want %q", got, "running")
+ }
+ if got := body["pid"]; got != float64(cmd.Process.Pid) {
+ t.Fatalf("pid = %#v, want %d", got, cmd.Process.Pid)
+ }
+ if got := body["gateway_restart_required"]; got != false {
+ t.Fatalf("gateway_restart_required = %#v, want false", got)
+ }
+}
+
+func TestGatewayStatusRequiresRestartAfterDefaultModelChange(t *testing.T) {
+ resetGatewayTestState(t)
+
+ configPath := filepath.Join(t.TempDir(), "config.json")
+ cfg := config.DefaultConfig()
+ cfg.Agents.Defaults.ModelName = cfg.ModelList[0].ModelName
+ cfg.ModelList[0].APIKey = "test-key"
+ cfg.ModelList = append(cfg.ModelList, config.ModelConfig{
+ ModelName: "second-model",
+ Model: "openai/gpt-4.1",
+ APIKey: "second-key",
+ })
+ if err := config.SaveConfig(configPath, cfg); err != nil {
+ t.Fatalf("SaveConfig() error = %v", err)
+ }
+
+ h := NewHandler(configPath)
+ mux := http.NewServeMux()
+ h.RegisterRoutes(mux)
+
+ process, err := os.FindProcess(os.Getpid())
+ if err != nil {
+ t.Fatalf("FindProcess() error = %v", err)
+ }
+
+ gateway.mu.Lock()
+ gateway.cmd = &exec.Cmd{Process: process}
+ gateway.bootDefaultModel = cfg.ModelList[0].ModelName
+ setGatewayRuntimeStatusLocked("running")
+ gateway.mu.Unlock()
+
+ updatedCfg, err := config.LoadConfig(configPath)
+ if err != nil {
+ t.Fatalf("LoadConfig() error = %v", err)
+ }
+ updatedCfg.Agents.Defaults.ModelName = "second-model"
+ if err := config.SaveConfig(configPath, updatedCfg); err != nil {
+ t.Fatalf("SaveConfig() error = %v", err)
+ }
+
+ gatewayHealthGet = func(string, time.Duration) (*http.Response, error) {
+ return mockGatewayHealthResponse(http.StatusOK, os.Getpid()), nil
+ }
+
+ rec := httptest.NewRecorder()
+ req := httptest.NewRequest(http.MethodGet, "/api/gateway/status", nil)
+ mux.ServeHTTP(rec, req)
+
+ if rec.Code != http.StatusOK {
+ t.Fatalf("status = %d, want %d", rec.Code, http.StatusOK)
+ }
+
+ var body map[string]any
+ if err := json.Unmarshal(rec.Body.Bytes(), &body); err != nil {
+ t.Fatalf("unmarshal response: %v", err)
+ }
+
+ if got := body["gateway_status"]; got != "running" {
+ t.Fatalf("gateway_status = %#v, want %q", got, "running")
+ }
+ if got := body["boot_default_model"]; got != cfg.ModelList[0].ModelName {
+ t.Fatalf("boot_default_model = %#v, want %q", got, cfg.ModelList[0].ModelName)
+ }
+ if got := body["config_default_model"]; got != "second-model" {
+ t.Fatalf("config_default_model = %#v, want %q", got, "second-model")
+ }
+ if got := body["gateway_restart_required"]; got != true {
+ t.Fatalf("gateway_restart_required = %#v, want true", got)
+ }
+}
+
func TestGatewayStatusReturnsErrorAfterStartupWindowExpires(t *testing.T) {
resetGatewayTestState(t)
diff --git a/web/backend/api/pico.go b/web/backend/api/pico.go
index d11f7bc5e9..a880f2f0c0 100644
--- a/web/backend/api/pico.go
+++ b/web/backend/api/pico.go
@@ -7,7 +7,6 @@ import (
"fmt"
"net/http"
"net/http/httputil"
- "net/url"
"time"
"github.com/sipeed/picoclaw/pkg/config"
@@ -22,20 +21,13 @@ func (h *Handler) registerPicoRoutes(mux *http.ServeMux) {
// WebSocket proxy: forward /pico/ws to gateway
// This allows the frontend to connect via the same port as the web UI,
// avoiding the need to expose extra ports for WebSocket communication.
- wsProxy := h.createWsProxy()
- mux.HandleFunc("GET /pico/ws", h.handleWebSocketProxy(wsProxy))
+ mux.HandleFunc("GET /pico/ws", h.handleWebSocketProxy())
}
-// createWsProxy creates a reverse proxy to the gateway WebSocket endpoint.
-// The gateway port is read from the configuration.
+// createWsProxy creates a reverse proxy to the current gateway WebSocket endpoint.
+// The gateway bind host and port are resolved from the latest configuration.
func (h *Handler) createWsProxy() *httputil.ReverseProxy {
- cfg, err := config.LoadConfig(h.configPath)
- gatewayPort := 18790 // default
- if err == nil && cfg.Gateway.Port != 0 {
- gatewayPort = cfg.Gateway.Port
- }
- gatewayURL, _ := url.Parse(fmt.Sprintf("http://127.0.0.1:%d", gatewayPort))
- wsProxy := httputil.NewSingleHostReverseProxy(gatewayURL)
+ wsProxy := httputil.NewSingleHostReverseProxy(h.gatewayProxyURL())
wsProxy.ErrorHandler = func(w http.ResponseWriter, r *http.Request, err error) {
http.Error(w, "Gateway unavailable: "+err.Error(), http.StatusBadGateway)
}
@@ -43,12 +35,10 @@ func (h *Handler) createWsProxy() *httputil.ReverseProxy {
}
// handleWebSocketProxy wraps a reverse proxy to handle WebSocket connections.
-// It ensures the Connection and Upgrade headers are properly forwarded.
-func (h *Handler) handleWebSocketProxy(proxy *httputil.ReverseProxy) http.HandlerFunc {
+// The reverse proxy forwards the incoming upgrade handshake as-is.
+func (h *Handler) handleWebSocketProxy() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
- // Set headers for WebSocket upgrade
- r.Header.Set("Connection", "upgrade")
- r.Header.Set("Upgrade", "websocket")
+ proxy := h.createWsProxy()
proxy.ServeHTTP(w, r)
}
}
diff --git a/web/backend/api/pico_test.go b/web/backend/api/pico_test.go
index 46149fa09e..075da4ddc9 100644
--- a/web/backend/api/pico_test.go
+++ b/web/backend/api/pico_test.go
@@ -2,9 +2,12 @@ package api
import (
"encoding/json"
+ "io"
"net/http"
"net/http/httptest"
+ "net/url"
"path/filepath"
+ "strconv"
"testing"
"github.com/sipeed/picoclaw/pkg/config"
@@ -235,3 +238,77 @@ func TestHandlePicoSetup_Response(t *testing.T) {
t.Error("response should have changed=true on first setup")
}
}
+
+func TestHandleWebSocketProxyReloadsGatewayTargetFromConfig(t *testing.T) {
+ configPath := filepath.Join(t.TempDir(), "config.json")
+ h := NewHandler(configPath)
+ handler := h.handleWebSocketProxy()
+
+ server1 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ if r.URL.Path != "/pico/ws" {
+ t.Fatalf("server1 path = %q, want %q", r.URL.Path, "/pico/ws")
+ }
+ w.WriteHeader(http.StatusOK)
+ _, _ = io.WriteString(w, "server1")
+ }))
+ defer server1.Close()
+
+ server2 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ if r.URL.Path != "/pico/ws" {
+ t.Fatalf("server2 path = %q, want %q", r.URL.Path, "/pico/ws")
+ }
+ w.WriteHeader(http.StatusOK)
+ _, _ = io.WriteString(w, "server2")
+ }))
+ defer server2.Close()
+
+ cfg := config.DefaultConfig()
+ cfg.Gateway.Host = "127.0.0.1"
+ cfg.Gateway.Port = mustGatewayTestPort(t, server1.URL)
+ if err := config.SaveConfig(configPath, cfg); err != nil {
+ t.Fatalf("SaveConfig() error = %v", err)
+ }
+
+ req1 := httptest.NewRequest(http.MethodGet, "/pico/ws", nil)
+ rec1 := httptest.NewRecorder()
+ handler(rec1, req1)
+
+ if rec1.Code != http.StatusOK {
+ t.Fatalf("first status = %d, want %d", rec1.Code, http.StatusOK)
+ }
+ if body := rec1.Body.String(); body != "server1" {
+ t.Fatalf("first body = %q, want %q", body, "server1")
+ }
+
+ cfg.Gateway.Port = mustGatewayTestPort(t, server2.URL)
+ if err := config.SaveConfig(configPath, cfg); err != nil {
+ t.Fatalf("SaveConfig() error = %v", err)
+ }
+
+ req2 := httptest.NewRequest(http.MethodGet, "/pico/ws", nil)
+ rec2 := httptest.NewRecorder()
+ handler(rec2, req2)
+
+ if rec2.Code != http.StatusOK {
+ t.Fatalf("second status = %d, want %d", rec2.Code, http.StatusOK)
+ }
+ if body := rec2.Body.String(); body != "server2" {
+ t.Fatalf("second body = %q, want %q", body, "server2")
+ }
+}
+
+func mustGatewayTestPort(t *testing.T, rawURL string) int {
+ t.Helper()
+
+ parsed, err := url.Parse(rawURL)
+ if err != nil {
+ t.Fatalf("url.Parse() error = %v", err)
+ }
+
+ port, err := strconv.Atoi(parsed.Port())
+ if err != nil {
+ t.Fatalf("Atoi(%q) error = %v", parsed.Port(), err)
+ }
+
+ return port
+}
diff --git a/web/backend/api/router.go b/web/backend/api/router.go
index b564387841..028a476f26 100644
--- a/web/backend/api/router.go
+++ b/web/backend/api/router.go
@@ -71,7 +71,4 @@ func (h *Handler) RegisterRoutes(mux *http.ServeMux) {
h.registerLauncherConfigRoutes(mux)
}
-// Shutdown gracefully shuts down the handler, closing all SSE connections.
-func (h *Handler) Shutdown() {
- gateway.events.Shutdown()
-}
+func (h *Handler) Shutdown() {}
diff --git a/web/backend/middleware/middleware.go b/web/backend/middleware/middleware.go
index de9e6d8702..e15da577bf 100644
--- a/web/backend/middleware/middleware.go
+++ b/web/backend/middleware/middleware.go
@@ -4,16 +4,14 @@ import (
"log"
"net/http"
"runtime/debug"
- "strings"
"time"
)
// JSONContentType sets the Content-Type header to application/json for
// API requests handled by the wrapped handler.
-// SSE endpoints (text/event-stream) are excluded.
func JSONContentType(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- if strings.HasPrefix(r.URL.Path, "/api/") && !strings.HasSuffix(r.URL.Path, "/events") {
+ if len(r.URL.Path) >= 5 && r.URL.Path[:5] == "/api/" {
w.Header().Set("Content-Type", "application/json")
}
next.ServeHTTP(w, r)
@@ -32,7 +30,6 @@ func (rr *responseRecorder) WriteHeader(code int) {
}
// Flush delegates to the underlying ResponseWriter if it implements http.Flusher.
-// This is required for SSE (Server-Sent Events) to work through the middleware.
func (rr *responseRecorder) Flush() {
if f, ok := rr.ResponseWriter.(http.Flusher); ok {
f.Flush()
diff --git a/web/backend/systray.go b/web/backend/systray.go
index 58ce4984fe..1ff98c71b8 100644
--- a/web/backend/systray.go
+++ b/web/backend/systray.go
@@ -94,7 +94,7 @@ func onReady() {
func onExit() {
fmt.Println(T(Exiting))
- // First, shutdown API handler to close all SSE connections
+ // First, shutdown API handler
if apiHandler != nil {
apiHandler.Shutdown()
}
diff --git a/web/frontend/src/components/app-header.tsx b/web/frontend/src/components/app-header.tsx
index fe0c84e69f..4f06880085 100644
--- a/web/frontend/src/components/app-header.tsx
+++ b/web/frontend/src/components/app-header.tsx
@@ -56,14 +56,20 @@ export function AppHeader() {
const isRunning = gwState === "running"
const isStarting = gwState === "starting"
const isRestarting = gwState === "restarting"
+ const isStopping = gwState === "stopping"
const isStopped = gwState === "stopped" || gwState === "unknown"
const showNotConnectedHint =
- !isRestarting && canStart && (gwState === "stopped" || gwState === "error")
+ !isRestarting &&
+ !isStopping &&
+ canStart &&
+ (gwState === "stopped" || gwState === "error")
const [showStopDialog, setShowStopDialog] = React.useState(false)
const handleGatewayToggle = () => {
- if (gwLoading || isRestarting || (!isRunning && !canStart)) return
+ if (gwLoading || isRestarting || isStopping || (!isRunning && !canStart)) {
+ return
+ }
if (isRunning) {
setShowStopDialog(true)
} else {
@@ -137,7 +143,7 @@ export function AppHeader() {
size="icon-sm"
className="bg-amber-500/15 text-amber-700 hover:bg-amber-500/25 hover:text-amber-800 dark:text-amber-300 dark:hover:bg-amber-500/25"
onClick={handleGatewayRestart}
- disabled={gwLoading || isRestarting || !canStart}
+ disabled={gwLoading || isRestarting || isStopping || !canStart}
aria-label={t("header.gateway.action.restart")}
>
@@ -168,25 +174,31 @@ export function AppHeader() {
) : (
)}
diff --git a/web/frontend/src/hooks/use-gateway-logs.ts b/web/frontend/src/hooks/use-gateway-logs.ts
index 15cbca4ae2..1de3611248 100644
--- a/web/frontend/src/hooks/use-gateway-logs.ts
+++ b/web/frontend/src/hooks/use-gateway-logs.ts
@@ -37,7 +37,9 @@ export function useGatewayLogs() {
const fetchLogs = async () => {
if (
!mounted ||
- !["running", "starting", "restarting"].includes(gateway.status)
+ !["running", "starting", "restarting", "stopping"].includes(
+ gateway.status,
+ )
) {
if (mounted) {
timeout = setTimeout(fetchLogs, 1000)
diff --git a/web/frontend/src/hooks/use-gateway.ts b/web/frontend/src/hooks/use-gateway.ts
index 65ec2b7767..b118b43da4 100644
--- a/web/frontend/src/hooks/use-gateway.ts
+++ b/web/frontend/src/hooks/use-gateway.ts
@@ -1,83 +1,24 @@
import { useAtomValue } from "jotai"
import { useCallback, useEffect, useState } from "react"
+import { restartGateway, startGateway, stopGateway } from "@/api/gateway"
import {
- type GatewayStatusResponse,
- getGatewayStatus,
- restartGateway,
- startGateway,
- stopGateway,
-} from "@/api/gateway"
-import {
- applyGatewayStatusToStore,
+ beginGatewayStoppingTransition,
+ cancelGatewayStoppingTransition,
gatewayAtom,
+ refreshGatewayState,
+ subscribeGatewayPolling,
updateGatewayStore,
} from "@/store"
-// Global variable to ensure we only have one SSE connection
-let sseInitialized = false
-
export function useGateway() {
const gateway = useAtomValue(gatewayAtom)
const { status: state, canStart, restartRequired } = gateway
const [loading, setLoading] = useState(false)
- const applyGatewayStatus = useCallback((data: GatewayStatusResponse) => {
- applyGatewayStatusToStore(data)
- }, [])
-
- // Initialize global SSE connection once
useEffect(() => {
- if (sseInitialized) return
- sseInitialized = true
-
- getGatewayStatus()
- .then((data) => applyGatewayStatus(data))
- .catch(() => {
- updateGatewayStore({
- status: "unknown",
- canStart: true,
- restartRequired: false,
- })
- })
-
- const statusPoll = window.setInterval(() => {
- getGatewayStatus()
- .then((data) => applyGatewayStatus(data))
- .catch(() => {
- // ignore polling errors
- })
- }, 5000)
-
- // Subscribe to SSE for real-time updates globally
- const es = new EventSource("/api/gateway/events")
-
- es.onmessage = (event) => {
- try {
- const data = JSON.parse(event.data)
- if (
- data.gateway_status ||
- typeof data.gateway_start_allowed === "boolean"
- ) {
- applyGatewayStatus(data)
- }
- } catch {
- // ignore
- }
- }
-
- es.onerror = () => {
- // EventSource will auto-reconnect. Preserve the last known gateway
- // status so transient SSE disconnects do not suppress chat websocket
- // reconnects while polling catches up.
- }
-
- return () => {
- window.clearInterval(statusPoll)
- es.close()
- sseInitialized = false
- }
- }, [applyGatewayStatus])
+ return subscribeGatewayPolling()
+ }, [])
const start = useCallback(async () => {
if (!canStart) return
@@ -85,33 +26,28 @@ export function useGateway() {
setLoading(true)
try {
await startGateway()
- // SSE will push the real state changes, but set optimistic state
- updateGatewayStore({ status: "starting" })
+ updateGatewayStore({
+ status: "starting",
+ restartRequired: false,
+ })
} catch (err) {
console.error("Failed to start gateway:", err)
- try {
- const status = await getGatewayStatus()
- applyGatewayStatus(status)
- } catch {
- updateGatewayStore({ status: "unknown" })
- }
} finally {
+ await refreshGatewayState({ force: true })
setLoading(false)
}
- }, [applyGatewayStatus, canStart])
+ }, [canStart])
const stop = useCallback(async () => {
setLoading(true)
+ beginGatewayStoppingTransition()
try {
await stopGateway()
- updateGatewayStore({
- status: "stopped",
- canStart: true,
- restartRequired: false,
- })
} catch (err) {
console.error("Failed to stop gateway:", err)
+ cancelGatewayStoppingTransition()
} finally {
+ await refreshGatewayState({ force: true })
setLoading(false)
}
}, [])
@@ -119,34 +55,20 @@ export function useGateway() {
const restart = useCallback(async () => {
if (state !== "running") return
- const previousState = state
- const previousCanStart = canStart
- const previousRestartRequired = restartRequired
-
setLoading(true)
- updateGatewayStore({
- status: "restarting",
- restartRequired: false,
- })
-
try {
await restartGateway()
+ updateGatewayStore({
+ status: "restarting",
+ restartRequired: false,
+ })
} catch (err) {
console.error("Failed to restart gateway:", err)
- try {
- const status = await getGatewayStatus()
- applyGatewayStatus(status)
- } catch {
- updateGatewayStore({
- status: previousState,
- canStart: previousCanStart,
- restartRequired: previousRestartRequired,
- })
- }
} finally {
+ await refreshGatewayState({ force: true })
setLoading(false)
}
- }, [applyGatewayStatus, canStart, restartRequired, state])
+ }, [state])
return { state, loading, canStart, restartRequired, start, stop, restart }
}
diff --git a/web/frontend/src/i18n/locales/en.json b/web/frontend/src/i18n/locales/en.json
index 2fa32ebb5a..327b4c6466 100644
--- a/web/frontend/src/i18n/locales/en.json
+++ b/web/frontend/src/i18n/locales/en.json
@@ -63,7 +63,8 @@
},
"status": {
"starting": "Starting Gateway...",
- "restarting": "Restarting Gateway..."
+ "restarting": "Restarting Gateway...",
+ "stopping": "Stopping Gateway..."
},
"restartRequired": "Model changes require a gateway restart to take effect."
}
diff --git a/web/frontend/src/i18n/locales/zh.json b/web/frontend/src/i18n/locales/zh.json
index badf5bb3d0..cd674ddc14 100644
--- a/web/frontend/src/i18n/locales/zh.json
+++ b/web/frontend/src/i18n/locales/zh.json
@@ -63,7 +63,8 @@
},
"status": {
"starting": "ๆๅกๅฏๅจไธญ...",
- "restarting": "ๆๅก้ๅฏไธญ..."
+ "restarting": "ๆๅก้ๅฏไธญ...",
+ "stopping": "ๆๅกๅๆญขไธญ..."
},
"restartRequired": "ๅๆข้ป่ฎคๆจกๅๅ้่ฆ้ๅฏๆๅกๆ่ฝ็ๆใ"
}
diff --git a/web/frontend/src/store/gateway.ts b/web/frontend/src/store/gateway.ts
index c5eee84516..1bdec6220c 100644
--- a/web/frontend/src/store/gateway.ts
+++ b/web/frontend/src/store/gateway.ts
@@ -6,6 +6,7 @@ export type GatewayState =
| "running"
| "starting"
| "restarting"
+ | "stopping"
| "stopped"
| "error"
| "unknown"
@@ -24,9 +25,29 @@ const DEFAULT_GATEWAY_STATE: GatewayStoreState = {
restartRequired: false,
}
+const GATEWAY_POLL_INTERVAL_MS = 2000
+const GATEWAY_TRANSIENT_POLL_INTERVAL_MS = 1000
+const GATEWAY_STOPPING_TIMEOUT_MS = 5000
+
+interface RefreshGatewayStateOptions {
+ force?: boolean
+}
+
// Global atom for gateway state
export const gatewayAtom = atom(DEFAULT_GATEWAY_STATE)
+let gatewayPollingSubscribers = 0
+let gatewayPollingTimer: ReturnType<typeof setTimeout> | null = null
+let gatewayPollingRequest: Promise<void> | null = null
+let gatewayStoppingTimer: ReturnType<typeof setTimeout> | null = null
+
+function clearGatewayStoppingTimeout() {
+ if (gatewayStoppingTimer !== null) {
+ clearTimeout(gatewayStoppingTimer)
+ gatewayStoppingTimer = null
+ }
+}
+
function normalizeGatewayStoreState(
prev: GatewayStoreState,
patch: GatewayStorePatch,
@@ -49,10 +70,38 @@ export function updateGatewayStore(
| GatewayStorePatch
| ((prev: GatewayStoreState) => GatewayStorePatch | GatewayStoreState),
) {
- getDefaultStore().set(gatewayAtom, (prev) => {
+ const store = getDefaultStore()
+ store.set(gatewayAtom, (prev) => {
const nextPatch = typeof patch === "function" ? patch(prev) : patch
return normalizeGatewayStoreState(prev, nextPatch)
})
+ const nextState = store.get(gatewayAtom)
+ if (nextState?.status !== "stopping") {
+ clearGatewayStoppingTimeout()
+ }
+}
+
+export function beginGatewayStoppingTransition() {
+ clearGatewayStoppingTimeout()
+ updateGatewayStore({
+ status: "stopping",
+ canStart: false,
+ restartRequired: false,
+ })
+ gatewayStoppingTimer = setTimeout(() => {
+ gatewayStoppingTimer = null
+ updateGatewayStore((prev) =>
+ prev.status === "stopping" ? { status: "running" } : prev,
+ )
+ void refreshGatewayState({ force: true })
+ }, GATEWAY_STOPPING_TIMEOUT_MS)
+}
+
+export function cancelGatewayStoppingTransition() {
+ clearGatewayStoppingTimeout()
+ updateGatewayStore((prev) =>
+ prev.status === "stopping" ? { status: "running" } : prev,
+ )
}
export function applyGatewayStatusToStore(
@@ -64,21 +113,92 @@ export function applyGatewayStatusToStore(
>,
) {
updateGatewayStore((prev) => ({
- status: data.gateway_status ?? prev.status,
- canStart: data.gateway_start_allowed ?? prev.canStart,
+ status:
+ prev.status === "stopping" && data.gateway_status === "running"
+ ? "stopping"
+ : (data.gateway_status ?? prev.status),
+ canStart:
+ prev.status === "stopping" && data.gateway_status === "running"
+ ? false
+ : (data.gateway_start_allowed ?? prev.canStart),
restartRequired:
- data.gateway_restart_required ??
- (data.gateway_status && data.gateway_status !== "running"
+ prev.status === "stopping" && data.gateway_status === "running"
? false
- : prev.restartRequired),
+ : (data.gateway_restart_required ?? prev.restartRequired),
}))
}
-export async function refreshGatewayState() {
+function nextGatewayPollInterval() {
+ const status = getDefaultStore().get(gatewayAtom).status
+ if (
+ status === "starting" ||
+ status === "restarting" ||
+ status === "stopping"
+ ) {
+ return GATEWAY_TRANSIENT_POLL_INTERVAL_MS
+ }
+ return GATEWAY_POLL_INTERVAL_MS
+}
+
+function scheduleGatewayPoll(delay = nextGatewayPollInterval()) {
+ if (gatewayPollingSubscribers === 0) {
+ return
+ }
+
+ if (gatewayPollingTimer !== null) {
+ clearTimeout(gatewayPollingTimer)
+ }
+
+ gatewayPollingTimer = setTimeout(() => {
+ gatewayPollingTimer = null
+ void refreshGatewayState()
+ }, delay)
+}
+
+export async function refreshGatewayState(
+ options: RefreshGatewayStateOptions = {},
+) {
+ if (gatewayPollingRequest) {
+ await gatewayPollingRequest
+ if (options.force) {
+ return refreshGatewayState()
+ }
+ return
+ }
+
+ gatewayPollingRequest = (async () => {
+ try {
+ const status = await getGatewayStatus()
+ applyGatewayStatusToStore(status)
+ } catch {
+ // Preserve the last known state when a poll fails.
+ } finally {
+ gatewayPollingRequest = null
+ scheduleGatewayPoll()
+ }
+ })()
+
try {
- const status = await getGatewayStatus()
- applyGatewayStatusToStore(status)
- } catch {
- updateGatewayStore(DEFAULT_GATEWAY_STATE)
+ await gatewayPollingRequest
+ } finally {
+ if (gatewayPollingSubscribers === 0 && gatewayPollingTimer !== null) {
+ clearTimeout(gatewayPollingTimer)
+ gatewayPollingTimer = null
+ }
+ }
+}
+
+export function subscribeGatewayPolling() {
+ gatewayPollingSubscribers += 1
+ if (gatewayPollingSubscribers === 1) {
+ void refreshGatewayState()
+ }
+
+ return () => {
+ gatewayPollingSubscribers = Math.max(0, gatewayPollingSubscribers - 1)
+ if (gatewayPollingSubscribers === 0 && gatewayPollingTimer !== null) {
+ clearTimeout(gatewayPollingTimer)
+ gatewayPollingTimer = null
+ }
}
}