diff --git a/cmd/picoclaw/internal/gateway/command.go b/cmd/picoclaw/internal/gateway/command.go index bfa69f0723..4812f1beef 100644 --- a/cmd/picoclaw/internal/gateway/command.go +++ b/cmd/picoclaw/internal/gateway/command.go @@ -5,6 +5,8 @@ import ( "github.com/spf13/cobra" + "github.com/sipeed/picoclaw/cmd/picoclaw/internal" + "github.com/sipeed/picoclaw/pkg/gateway" "github.com/sipeed/picoclaw/pkg/logger" "github.com/sipeed/picoclaw/pkg/utils" ) @@ -12,6 +14,7 @@ import ( func NewGatewayCommand() *cobra.Command { var debug bool var noTruncate bool + var allowEmpty bool cmd := &cobra.Command{ Use: "gateway", @@ -31,12 +34,19 @@ func NewGatewayCommand() *cobra.Command { return nil }, RunE: func(_ *cobra.Command, _ []string) error { - return gatewayCmd(debug) + return gateway.Run(debug, internal.GetConfigPath(), allowEmpty) }, } cmd.Flags().BoolVarP(&debug, "debug", "d", false, "Enable debug logging") cmd.Flags().BoolVarP(&noTruncate, "no-truncate", "T", false, "Disable string truncation in debug logs") + cmd.Flags().BoolVarP( + &allowEmpty, + "allow-empty", + "E", + false, + "Continue starting even when no default model is configured", + ) return cmd } diff --git a/cmd/picoclaw/internal/gateway/command_test.go b/cmd/picoclaw/internal/gateway/command_test.go index 4d591ea672..839a7315a0 100644 --- a/cmd/picoclaw/internal/gateway/command_test.go +++ b/cmd/picoclaw/internal/gateway/command_test.go @@ -28,4 +28,5 @@ func TestNewGatewayCommand(t *testing.T) { assert.True(t, cmd.HasFlags()) assert.NotNil(t, cmd.Flags().Lookup("debug")) + assert.NotNil(t, cmd.Flags().Lookup("allow-empty")) } diff --git a/config/config.example.json b/config/config.example.json index 1c11cd42a9..14e2092598 100644 --- a/config/config.example.json +++ b/config/config.example.json @@ -518,6 +518,7 @@ }, "gateway": { "host": "127.0.0.1", - "port": 18790 + "port": 18790, + "hot_reload": false } } diff --git 
a/pkg/channels/manager.go b/pkg/channels/manager.go index df430e4d3a..8121525ab3 100644 --- a/pkg/channels/manager.go +++ b/pkg/channels/manager.go @@ -357,7 +357,6 @@ func (m *Manager) StartAll(ctx context.Context) error { if len(m.channels) == 0 { logger.WarnC("channels", "No channels enabled") - return errors.New("no channels enabled") } logger.InfoC("channels", "Starting all channels") diff --git a/pkg/config/config.go b/pkg/config/config.go index 35de48f23b..6694ef3a10 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -625,8 +625,9 @@ func (c *ModelConfig) Validate() error { } type GatewayConfig struct { - Host string `json:"host" env:"PICOCLAW_GATEWAY_HOST"` - Port int `json:"port" env:"PICOCLAW_GATEWAY_PORT"` + Host string `json:"host" env:"PICOCLAW_GATEWAY_HOST"` + Port int `json:"port" env:"PICOCLAW_GATEWAY_PORT"` + HotReload bool `json:"hot_reload" env:"PICOCLAW_GATEWAY_HOT_RELOAD"` } type ToolDiscoveryConfig struct { diff --git a/pkg/config/config_test.go b/pkg/config/config_test.go index fc835f78f6..f4f8979e13 100644 --- a/pkg/config/config_test.go +++ b/pkg/config/config_test.go @@ -267,6 +267,9 @@ func TestDefaultConfig_Gateway(t *testing.T) { if cfg.Gateway.Port == 0 { t.Error("Gateway port should have default value") } + if cfg.Gateway.HotReload { + t.Error("Gateway hot reload should be disabled by default") + } } // TestDefaultConfig_Providers verifies provider structure diff --git a/pkg/config/defaults.go b/pkg/config/defaults.go index 2b177d5de5..90a99408e5 100644 --- a/pkg/config/defaults.go +++ b/pkg/config/defaults.go @@ -395,8 +395,9 @@ func DefaultConfig() *Config { }, }, Gateway: GatewayConfig{ - Host: "127.0.0.1", - Port: 18790, + Host: "127.0.0.1", + Port: 18790, + HotReload: false, }, Tools: ToolsConfig{ MediaCleanup: MediaCleanupConfig{ diff --git a/cmd/picoclaw/internal/gateway/helpers.go b/pkg/gateway/gateway.go similarity index 61% rename from cmd/picoclaw/internal/gateway/helpers.go rename to pkg/gateway/gateway.go 
index 85e93bcf96..6745d1748b 100644 --- a/cmd/picoclaw/internal/gateway/helpers.go +++ b/pkg/gateway/gateway.go @@ -10,7 +10,6 @@ import ( "syscall" "time" - "github.com/sipeed/picoclaw/cmd/picoclaw/internal" "github.com/sipeed/picoclaw/pkg/agent" "github.com/sipeed/picoclaw/pkg/bus" "github.com/sipeed/picoclaw/pkg/channels" @@ -42,15 +41,13 @@ import ( "github.com/sipeed/picoclaw/pkg/voice" ) -// Timeout constants for service operations const ( serviceShutdownTimeout = 30 * time.Second providerReloadTimeout = 30 * time.Second gracefulShutdownTimeout = 15 * time.Second ) -// gatewayServices holds references to all running services -type gatewayServices struct { +type services struct { CronService *cron.CronService HeartbeatService *heartbeat.HeartbeatService MediaStore media.MediaStore @@ -59,24 +56,41 @@ type gatewayServices struct { HealthServer *health.Server } -func gatewayCmd(debug bool) error { +type startupBlockedProvider struct { + reason string +} + +func (p *startupBlockedProvider) Chat( + _ context.Context, + _ []providers.Message, + _ []providers.ToolDefinition, + _ string, + _ map[string]any, +) (*providers.LLMResponse, error) { + return nil, fmt.Errorf("%s", p.reason) +} + +func (p *startupBlockedProvider) GetDefaultModel() string { + return "" +} + +// Run starts the gateway runtime using the configuration loaded from configPath. 
+func Run(debug bool, configPath string, allowEmptyStartup bool) error { if debug { logger.SetLevel(logger.DEBUG) fmt.Println("๐Ÿ” Debug mode enabled") } - configPath := internal.GetConfigPath() - cfg, err := internal.LoadConfig() + cfg, err := config.LoadConfig(configPath) if err != nil { return fmt.Errorf("error loading config: %w", err) } - provider, modelID, err := providers.CreateProvider(cfg) + provider, modelID, err := createStartupProvider(cfg, allowEmptyStartup) if err != nil { return fmt.Errorf("error creating provider: %w", err) } - // Use the resolved model ID from provider creation if modelID != "" { cfg.Agents.Defaults.ModelName = modelID } @@ -84,17 +98,13 @@ func gatewayCmd(debug bool) error { msgBus := bus.NewMessageBus() agentLoop := agent.NewAgentLoop(cfg, msgBus, provider) - // Print agent startup info fmt.Println("\n๐Ÿ“ฆ Agent Status:") startupInfo := agentLoop.GetStartupInfo() toolsInfo := startupInfo["tools"].(map[string]any) skillsInfo := startupInfo["skills"].(map[string]any) fmt.Printf(" โ€ข Tools: %d loaded\n", toolsInfo["count"]) - fmt.Printf(" โ€ข Skills: %d/%d available\n", - skillsInfo["available"], - skillsInfo["total"]) + fmt.Printf(" โ€ข Skills: %d/%d available\n", skillsInfo["available"], skillsInfo["total"]) - // Log to file as well logger.InfoCF("agent", "Agent initialized", map[string]any{ "tools_count": toolsInfo["count"], @@ -102,8 +112,7 @@ func gatewayCmd(debug bool) error { "skills_available": skillsInfo["available"], }) - // Setup and start all services - services, err := setupAndStartServices(cfg, agentLoop, msgBus) + runningServices, err := setupAndStartServices(cfg, agentLoop, msgBus) if err != nil { return err } @@ -116,23 +125,25 @@ func gatewayCmd(debug bool) error { go agentLoop.Run(ctx) - // Setup config file watcher for hot reload - configReloadChan, stopWatch := setupConfigWatcherPolling(configPath, debug) + var configReloadChan <-chan *config.Config + stopWatch := func() {} + if cfg.Gateway.HotReload { + 
configReloadChan, stopWatch = setupConfigWatcherPolling(configPath, debug) + logger.Info("Config hot reload enabled") + } defer stopWatch() sigChan := make(chan os.Signal, 1) signal.Notify(sigChan, os.Interrupt, syscall.SIGTERM) - // Main event loop - wait for signals or config changes for { select { case <-sigChan: logger.Info("Shutting down...") - shutdownGateway(services, agentLoop, provider, true) + shutdownGateway(runningServices, agentLoop, provider, true) return nil - case newCfg := <-configReloadChan: - err := handleConfigReload(ctx, agentLoop, newCfg, &provider, services, msgBus) + err := handleConfigReload(ctx, agentLoop, newCfg, &provider, runningServices, msgBus, allowEmptyStartup) if err != nil { logger.Errorf("Config reload failed: %v", err) } @@ -140,18 +151,33 @@ func gatewayCmd(debug bool) error { } } -// setupAndStartServices initializes and starts all services +func createStartupProvider( + cfg *config.Config, + allowEmptyStartup bool, +) (providers.LLMProvider, string, error) { + modelName := cfg.Agents.Defaults.GetModelName() + if modelName == "" && allowEmptyStartup { + reason := "no default model configured; gateway started in limited mode" + fmt.Printf("โš  Warning: %s\n", reason) + logger.WarnCF("gateway", "Gateway started without default model", map[string]any{ + "limited_mode": true, + }) + return &startupBlockedProvider{reason: reason}, "", nil + } + + return providers.CreateProvider(cfg) +} + func setupAndStartServices( cfg *config.Config, agentLoop *agent.AgentLoop, msgBus *bus.MessageBus, -) (*gatewayServices, error) { - services := &gatewayServices{} +) (*services, error) { + runningServices := &services{} - // Setup cron tool and service execTimeout := time.Duration(cfg.Tools.Cron.ExecTimeoutMinutes) * time.Minute var err error - services.CronService, err = setupCronTool( + runningServices.CronService, err = setupCronTool( agentLoop, msgBus, cfg.WorkspacePath(), @@ -162,120 +188,105 @@ func setupAndStartServices( if err != nil { 
return nil, fmt.Errorf("error setting up cron service: %w", err) } - if err = services.CronService.Start(); err != nil { + if err = runningServices.CronService.Start(); err != nil { return nil, fmt.Errorf("error starting cron service: %w", err) } fmt.Println("โœ“ Cron service started") - // Setup heartbeat service - services.HeartbeatService = heartbeat.NewHeartbeatService( + runningServices.HeartbeatService = heartbeat.NewHeartbeatService( cfg.WorkspacePath(), cfg.Heartbeat.Interval, cfg.Heartbeat.Enabled, ) - services.HeartbeatService.SetBus(msgBus) - services.HeartbeatService.SetHandler(createHeartbeatHandler(agentLoop)) - if err = services.HeartbeatService.Start(); err != nil { + runningServices.HeartbeatService.SetBus(msgBus) + runningServices.HeartbeatService.SetHandler(createHeartbeatHandler(agentLoop)) + if err = runningServices.HeartbeatService.Start(); err != nil { return nil, fmt.Errorf("error starting heartbeat service: %w", err) } fmt.Println("โœ“ Heartbeat service started") - // Create media store for file lifecycle management with TTL cleanup - services.MediaStore = media.NewFileMediaStoreWithCleanup(media.MediaCleanerConfig{ + runningServices.MediaStore = media.NewFileMediaStoreWithCleanup(media.MediaCleanerConfig{ Enabled: cfg.Tools.MediaCleanup.Enabled, MaxAge: time.Duration(cfg.Tools.MediaCleanup.MaxAge) * time.Minute, Interval: time.Duration(cfg.Tools.MediaCleanup.Interval) * time.Minute, }) - // Start the media store if it's a FileMediaStore with cleanup - if fms, ok := services.MediaStore.(*media.FileMediaStore); ok { + if fms, ok := runningServices.MediaStore.(*media.FileMediaStore); ok { fms.Start() } - // Create channel manager - services.ChannelManager, err = channels.NewManager(cfg, msgBus, services.MediaStore) + runningServices.ChannelManager, err = channels.NewManager(cfg, msgBus, runningServices.MediaStore) if err != nil { - // Stop the media store if it's a FileMediaStore with cleanup - if fms, ok := 
services.MediaStore.(*media.FileMediaStore); ok { + if fms, ok := runningServices.MediaStore.(*media.FileMediaStore); ok { fms.Stop() } return nil, fmt.Errorf("error creating channel manager: %w", err) } - // Inject channel manager and media store into agent loop - agentLoop.SetChannelManager(services.ChannelManager) - agentLoop.SetMediaStore(services.MediaStore) + agentLoop.SetChannelManager(runningServices.ChannelManager) + agentLoop.SetMediaStore(runningServices.MediaStore) - // Wire up voice transcription if a supported provider is configured. if transcriber := voice.DetectTranscriber(cfg); transcriber != nil { agentLoop.SetTranscriber(transcriber) logger.InfoCF("voice", "Transcription enabled (agent-level)", map[string]any{"provider": transcriber.Name()}) } - enabledChannels := services.ChannelManager.GetEnabledChannels() + enabledChannels := runningServices.ChannelManager.GetEnabledChannels() if len(enabledChannels) > 0 { fmt.Printf("โœ“ Channels enabled: %s\n", enabledChannels) } else { fmt.Println("โš  Warning: No channels enabled") } - // Setup shared HTTP server with health endpoints and webhook handlers addr := fmt.Sprintf("%s:%d", cfg.Gateway.Host, cfg.Gateway.Port) - services.HealthServer = health.NewServer(cfg.Gateway.Host, cfg.Gateway.Port) - services.ChannelManager.SetupHTTPServer(addr, services.HealthServer) + runningServices.HealthServer = health.NewServer(cfg.Gateway.Host, cfg.Gateway.Port) + runningServices.ChannelManager.SetupHTTPServer(addr, runningServices.HealthServer) - if err = services.ChannelManager.StartAll(context.Background()); err != nil { + if err = runningServices.ChannelManager.StartAll(context.Background()); err != nil { return nil, fmt.Errorf("error starting channels: %w", err) } fmt.Printf("โœ“ Health endpoints available at http://%s:%d/health and /ready\n", cfg.Gateway.Host, cfg.Gateway.Port) - // Setup state manager and device service stateManager := state.NewManager(cfg.WorkspacePath()) - services.DeviceService = 
devices.NewService(devices.Config{ + runningServices.DeviceService = devices.NewService(devices.Config{ Enabled: cfg.Devices.Enabled, MonitorUSB: cfg.Devices.MonitorUSB, }, stateManager) - services.DeviceService.SetBus(msgBus) - if err = services.DeviceService.Start(context.Background()); err != nil { + runningServices.DeviceService.SetBus(msgBus) + if err = runningServices.DeviceService.Start(context.Background()); err != nil { logger.ErrorCF("device", "Error starting device service", map[string]any{"error": err.Error()}) } else if cfg.Devices.Enabled { fmt.Println("โœ“ Device event service started") } - return services, nil + return runningServices, nil } -// stopAndCleanupServices stops all services and cleans up resources -func stopAndCleanupServices( - services *gatewayServices, - shutdownTimeout time.Duration, -) { +func stopAndCleanupServices(runningServices *services, shutdownTimeout time.Duration) { shutdownCtx, shutdownCancel := context.WithTimeout(context.Background(), shutdownTimeout) defer shutdownCancel() - if services.ChannelManager != nil { - services.ChannelManager.StopAll(shutdownCtx) + if runningServices.ChannelManager != nil { + runningServices.ChannelManager.StopAll(shutdownCtx) } - if services.DeviceService != nil { - services.DeviceService.Stop() + if runningServices.DeviceService != nil { + runningServices.DeviceService.Stop() } - if services.HeartbeatService != nil { - services.HeartbeatService.Stop() + if runningServices.HeartbeatService != nil { + runningServices.HeartbeatService.Stop() } - if services.CronService != nil { - services.CronService.Stop() + if runningServices.CronService != nil { + runningServices.CronService.Stop() } - if services.MediaStore != nil { - // Stop the media store if it's a FileMediaStore with cleanup - if fms, ok := services.MediaStore.(*media.FileMediaStore); ok { + if runningServices.MediaStore != nil { + if fms, ok := runningServices.MediaStore.(*media.FileMediaStore); ok { fms.Stop() } } } -// 
shutdownGateway performs a complete gateway shutdown func shutdownGateway( - services *gatewayServices, + runningServices *services, agentLoop *agent.AgentLoop, provider providers.LLMProvider, fullShutdown bool, @@ -284,7 +295,7 @@ func shutdownGateway( cp.Close() } - stopAndCleanupServices(services, gracefulShutdownTimeout) + stopAndCleanupServices(runningServices, gracefulShutdownTimeout) agentLoop.Stop() agentLoop.Close() @@ -292,15 +303,14 @@ func shutdownGateway( logger.Info("โœ“ Gateway stopped") } -// handleConfigReload handles config file reload by stopping all services, -// reloading the provider and config, and restarting services with the new config. func handleConfigReload( ctx context.Context, al *agent.AgentLoop, newCfg *config.Config, providerRef *providers.LLMProvider, - services *gatewayServices, + runningServices *services, msgBus *bus.MessageBus, + allowEmptyStartup bool, ) error { logger.Info("๐Ÿ”„ Config file changed, reloading...") @@ -311,18 +321,14 @@ func handleConfigReload( logger.Infof(" New model is '%s', recreating provider...", newModel) - // Stop all services before reloading logger.Info(" Stopping all services...") - stopAndCleanupServices(services, serviceShutdownTimeout) + stopAndCleanupServices(runningServices, serviceShutdownTimeout) - // Create new provider from updated config first to ensure validity - // This will use the correct API key and settings from newCfg.ModelList - newProvider, newModelID, err := providers.CreateProvider(newCfg) + newProvider, newModelID, err := createStartupProvider(newCfg, allowEmptyStartup) if err != nil { logger.Errorf(" โš  Error creating new provider: %v", err) logger.Warn(" Attempting to restart services with old provider and config...") - // Try to restart services with old configuration - if restartErr := restartServices(al, services, msgBus); restartErr != nil { + if restartErr := restartServices(al, runningServices, msgBus); restartErr != nil { logger.Errorf(" โš  Failed to restart 
services: %v", restartErr) } return fmt.Errorf("error creating new provider: %w", err) @@ -332,31 +338,25 @@ func handleConfigReload( newCfg.Agents.Defaults.ModelName = newModelID } - // Use the atomic reload method on AgentLoop to safely swap provider and config. - // This handles locking internally to prevent races with in-flight LLM calls - // and concurrent reads of registry/config while the swap occurs. reloadCtx, reloadCancel := context.WithTimeout(context.Background(), providerReloadTimeout) defer reloadCancel() if err := al.ReloadProviderAndConfig(reloadCtx, newProvider, newCfg); err != nil { logger.Errorf(" โš  Error reloading agent loop: %v", err) - // Close the newly created provider since it wasn't adopted if cp, ok := newProvider.(providers.StatefulProvider); ok { cp.Close() } logger.Warn(" Attempting to restart services with old provider and config...") - if restartErr := restartServices(al, services, msgBus); restartErr != nil { + if restartErr := restartServices(al, runningServices, msgBus); restartErr != nil { logger.Errorf(" โš  Failed to restart services: %v", restartErr) } return fmt.Errorf("error reloading agent loop: %w", err) } - // Update local provider reference only after successful atomic reload *providerRef = newProvider - // Restart all services with new config logger.Info(" Restarting all services with new configuration...") - if err := restartServices(al, services, msgBus); err != nil { + if err := restartServices(al, runningServices, msgBus); err != nil { logger.Errorf(" โš  Error restarting services: %v", err) return fmt.Errorf("error restarting services: %w", err) } @@ -365,19 +365,16 @@ func handleConfigReload( return nil } -// restartServices restarts all services after a config reload func restartServices( al *agent.AgentLoop, - services *gatewayServices, + runningServices *services, msgBus *bus.MessageBus, ) error { - // Get current config from agent loop (which has been updated if this is a reload) cfg := al.GetConfig() - // 
Re-create and start cron service with new config execTimeout := time.Duration(cfg.Tools.Cron.ExecTimeoutMinutes) * time.Minute var err error - services.CronService, err = setupCronTool( + runningServices.CronService, err = setupCronTool( al, msgBus, cfg.WorkspacePath(), @@ -388,57 +385,51 @@ func restartServices( if err != nil { return fmt.Errorf("error restarting cron service: %w", err) } - if err = services.CronService.Start(); err != nil { + if err = runningServices.CronService.Start(); err != nil { return fmt.Errorf("error restarting cron service: %w", err) } fmt.Println(" โœ“ Cron service restarted") - // Re-create and start heartbeat service with new config - services.HeartbeatService = heartbeat.NewHeartbeatService( + runningServices.HeartbeatService = heartbeat.NewHeartbeatService( cfg.WorkspacePath(), cfg.Heartbeat.Interval, cfg.Heartbeat.Enabled, ) - services.HeartbeatService.SetBus(msgBus) - services.HeartbeatService.SetHandler(createHeartbeatHandler(al)) - if err = services.HeartbeatService.Start(); err != nil { + runningServices.HeartbeatService.SetBus(msgBus) + runningServices.HeartbeatService.SetHandler(createHeartbeatHandler(al)) + if err = runningServices.HeartbeatService.Start(); err != nil { return fmt.Errorf("error restarting heartbeat service: %w", err) } fmt.Println(" โœ“ Heartbeat service restarted") - // Re-create media store with new config - services.MediaStore = media.NewFileMediaStoreWithCleanup(media.MediaCleanerConfig{ + runningServices.MediaStore = media.NewFileMediaStoreWithCleanup(media.MediaCleanerConfig{ Enabled: cfg.Tools.MediaCleanup.Enabled, MaxAge: time.Duration(cfg.Tools.MediaCleanup.MaxAge) * time.Minute, Interval: time.Duration(cfg.Tools.MediaCleanup.Interval) * time.Minute, }) - // Start the media store if it's a FileMediaStore with cleanup - if fms, ok := services.MediaStore.(*media.FileMediaStore); ok { + if fms, ok := runningServices.MediaStore.(*media.FileMediaStore); ok { fms.Start() } - 
al.SetMediaStore(services.MediaStore) + al.SetMediaStore(runningServices.MediaStore) - // Re-create channel manager with new config - services.ChannelManager, err = channels.NewManager(cfg, msgBus, services.MediaStore) + runningServices.ChannelManager, err = channels.NewManager(cfg, msgBus, runningServices.MediaStore) if err != nil { return fmt.Errorf("error recreating channel manager: %w", err) } - al.SetChannelManager(services.ChannelManager) + al.SetChannelManager(runningServices.ChannelManager) - enabledChannels := services.ChannelManager.GetEnabledChannels() + enabledChannels := runningServices.ChannelManager.GetEnabledChannels() if len(enabledChannels) > 0 { fmt.Printf(" โœ“ Channels enabled: %s\n", enabledChannels) } else { fmt.Println(" โš  Warning: No channels enabled") } - // Setup HTTP server with new config addr := fmt.Sprintf("%s:%d", cfg.Gateway.Host, cfg.Gateway.Port) - services.HealthServer = health.NewServer(cfg.Gateway.Host, cfg.Gateway.Port) - services.ChannelManager.SetupHTTPServer(addr, services.HealthServer) + runningServices.HealthServer = health.NewServer(cfg.Gateway.Host, cfg.Gateway.Port) + runningServices.ChannelManager.SetupHTTPServer(addr, runningServices.HealthServer) - // Use background context for lifecycle to ensure services persist after restartServices returns - if err = services.ChannelManager.StartAll(context.Background()); err != nil { + if err = runningServices.ChannelManager.StartAll(context.Background()); err != nil { return fmt.Errorf("error restarting channels: %w", err) } fmt.Printf( @@ -447,22 +438,20 @@ func restartServices( cfg.Gateway.Port, ) - // Re-create device service with new config stateManager := state.NewManager(cfg.WorkspacePath()) - services.DeviceService = devices.NewService(devices.Config{ + runningServices.DeviceService = devices.NewService(devices.Config{ Enabled: cfg.Devices.Enabled, MonitorUSB: cfg.Devices.MonitorUSB, }, stateManager) - services.DeviceService.SetBus(msgBus) - if err := 
services.DeviceService.Start(context.Background()); err != nil { + runningServices.DeviceService.SetBus(msgBus) + if err := runningServices.DeviceService.Start(context.Background()); err != nil { logger.WarnCF("device", "Failed to restart device service", map[string]any{"error": err.Error()}) } else if cfg.Devices.Enabled { fmt.Println(" โœ“ Device event service restarted") } - // Wire up voice transcription with new config transcriber := voice.DetectTranscriber(cfg) - al.SetTranscriber(transcriber) // This will set it to nil if disabled + al.SetTranscriber(transcriber) if transcriber != nil { logger.InfoCF("voice", "Transcription re-enabled (agent-level)", map[string]any{"provider": transcriber.Name()}) } else { @@ -472,8 +461,6 @@ func restartServices( return nil } -// setupConfigWatcherPolling sets up a simple polling-based config file watcher -// Returns a channel for config updates and a stop function func setupConfigWatcherPolling(configPath string, debug bool) (chan *config.Config, func()) { configChan := make(chan *config.Config, 1) stop := make(chan struct{}) @@ -483,11 +470,10 @@ func setupConfigWatcherPolling(configPath string, debug bool) (chan *config.Conf go func() { defer wg.Done() - // Get initial file info lastModTime := getFileModTime(configPath) lastSize := getFileSize(configPath) - ticker := time.NewTicker(2 * time.Second) // Check every 2 seconds + ticker := time.NewTicker(2 * time.Second) defer ticker.Stop() for { @@ -496,20 +482,16 @@ func setupConfigWatcherPolling(configPath string, debug bool) (chan *config.Conf currentModTime := getFileModTime(configPath) currentSize := getFileSize(configPath) - // Check if file changed (modification time or size changed) if currentModTime.After(lastModTime) || currentSize != lastSize { if debug { logger.Debugf("๐Ÿ” Config file change detected") } - // Debounce - wait a bit to ensure file write is complete time.Sleep(500 * time.Millisecond) - // Update last known state to prevent repeated reload attempts 
on failure lastModTime = currentModTime lastSize = currentSize - // Validate and load new config newCfg, err := config.LoadConfig(configPath) if err != nil { logger.Errorf("โš  Error loading new config: %v", err) @@ -517,7 +499,6 @@ func setupConfigWatcherPolling(configPath string, debug bool) (chan *config.Conf continue } - // Validate the new config if err := newCfg.ValidateModelList(); err != nil { logger.Errorf(" โš  New config validation failed: %v", err) logger.Warn(" Using previous valid config") @@ -526,15 +507,12 @@ func setupConfigWatcherPolling(configPath string, debug bool) (chan *config.Conf logger.Info("โœ“ Config file validated and loaded") - // Send new config to main loop (non-blocking) select { case configChan <- newCfg: default: - // Channel full, skip this update logger.Warn("โš  Previous config reload still in progress, skipping") } } - case <-stop: return } @@ -549,7 +527,6 @@ func setupConfigWatcherPolling(configPath string, debug bool) (chan *config.Conf return configChan, stopFunc } -// getFileModTime returns the modification time of a file, or zero time if file doesn't exist func getFileModTime(path string) time.Time { info, err := os.Stat(path) if err != nil { @@ -558,7 +535,6 @@ func getFileModTime(path string) time.Time { return info.ModTime() } -// getFileSize returns the size of a file, or 0 if file doesn't exist func getFileSize(path string) int64 { info, err := os.Stat(path) if err != nil { @@ -577,10 +553,8 @@ func setupCronTool( ) (*cron.CronService, error) { cronStorePath := filepath.Join(workspace, "cron", "jobs.json") - // Create cron service cronService := cron.NewCronService(cronStorePath, nil) - // Create and register CronTool if enabled var cronTool *tools.CronTool if cfg.Tools.IsToolEnabled("cron") { var err error @@ -592,7 +566,6 @@ func setupCronTool( agentLoop.RegisterTool(cronTool) } - // Set onJob handler if cronTool != nil { cronService.SetOnJob(func(job *cron.CronJob) (string, error) { result := 
cronTool.ExecuteJob(context.Background(), job) @@ -605,22 +578,17 @@ func setupCronTool( func createHeartbeatHandler(agentLoop *agent.AgentLoop) func(prompt, channel, chatID string) *tools.ToolResult { return func(prompt, channel, chatID string) *tools.ToolResult { - // Use cli:direct as fallback if no valid channel if channel == "" || chatID == "" { channel, chatID = "cli", "direct" } - // Use ProcessHeartbeat - no session history, each heartbeat is independent - var response string - var err error - response, err = agentLoop.ProcessHeartbeat(context.Background(), prompt, channel, chatID) + + response, err := agentLoop.ProcessHeartbeat(context.Background(), prompt, channel, chatID) if err != nil { return tools.ErrorResult(fmt.Sprintf("Heartbeat error: %v", err)) } if response == "HEARTBEAT_OK" { return tools.SilentResult("Heartbeat OK") } - // For heartbeat, always return silent - the subagent result will be - // sent to user via processSystemMessage when the async task completes return tools.SilentResult(response) } } diff --git a/web/backend/api/events.go b/web/backend/api/events.go deleted file mode 100644 index 5c85b149ad..0000000000 --- a/web/backend/api/events.go +++ /dev/null @@ -1,80 +0,0 @@ -package api - -import ( - "encoding/json" - "sync" -) - -// GatewayEvent represents a state change event for the gateway process. -type GatewayEvent struct { - Status string `json:"gateway_status"` // "running", "starting", "restarting", "stopped", "error" - PID int `json:"pid,omitempty"` - BootDefaultModel string `json:"boot_default_model,omitempty"` - ConfigDefaultModel string `json:"config_default_model,omitempty"` - RestartRequired bool `json:"gateway_restart_required,omitempty"` -} - -// EventBroadcaster manages SSE client subscriptions and broadcasts events. -type EventBroadcaster struct { - mu sync.RWMutex - clients map[chan string]struct{} -} - -// NewEventBroadcaster creates a new broadcaster. 
-func NewEventBroadcaster() *EventBroadcaster { - return &EventBroadcaster{ - clients: make(map[chan string]struct{}), - } -} - -// Subscribe adds a new listener channel and returns it. -// The caller must call Unsubscribe when done. -func (b *EventBroadcaster) Subscribe() chan string { - ch := make(chan string, 8) - b.mu.Lock() - b.clients[ch] = struct{}{} - b.mu.Unlock() - return ch -} - -// Unsubscribe removes a listener channel and closes it. -func (b *EventBroadcaster) Unsubscribe(ch chan string) { - b.mu.Lock() - defer b.mu.Unlock() - - // Check if the channel is still registered before closing - if _, exists := b.clients[ch]; exists { - delete(b.clients, ch) - close(ch) - } -} - -// Broadcast sends a GatewayEvent to all connected SSE clients. -func (b *EventBroadcaster) Broadcast(event GatewayEvent) { - data, err := json.Marshal(event) - if err != nil { - return - } - - b.mu.RLock() - defer b.mu.RUnlock() - - for ch := range b.clients { - // Non-blocking send; drop event if client is slow - select { - case ch <- string(data): - default: - } - } -} - -// Shutdown closes all subscriber channels, notifying all SSE clients to disconnect. -// This should be called when the server is shutting down. 
-func (b *EventBroadcaster) Shutdown() { - // Close all channels to notify listeners - for ch := range b.clients { - b.Unsubscribe(ch) - } - // Clear the map - b.clients = make(map[chan string]struct{}) -} diff --git a/web/backend/api/gateway.go b/web/backend/api/gateway.go index 424b21e96c..16b793427a 100644 --- a/web/backend/api/gateway.go +++ b/web/backend/api/gateway.go @@ -7,6 +7,7 @@ import ( "fmt" "io" "log" + "net" "net/http" "os" "os/exec" @@ -30,11 +31,9 @@ var gateway = struct { runtimeStatus string startupDeadline time.Time logs *LogBuffer - events *EventBroadcaster }{ runtimeStatus: "stopped", logs: NewLogBuffer(200), - events: NewEventBroadcaster(), } var ( @@ -51,11 +50,19 @@ var gatewayHealthGet = func(url string, timeout time.Duration) (*http.Response, // getGatewayHealth checks the gateway health endpoint and returns the status response // Returns (*health.StatusResponse, statusCode, error). If error is not nil, the other values are not valid. -func getGatewayHealth(port int, timeout time.Duration) (*health.StatusResponse, int, error) { - if port == 0 { - port = 18790 +func (h *Handler) getGatewayHealth(cfg *config.Config, timeout time.Duration) (*health.StatusResponse, int, error) { + port := 18790 + if cfg != nil && cfg.Gateway.Port != 0 { + port = cfg.Gateway.Port } - url := fmt.Sprintf("http://127.0.0.1:%d/health", port) + + probeHost := gatewayProbeHost(h.effectiveGatewayBindHost(cfg)) + url := "http://" + net.JoinHostPort(probeHost, strconv.Itoa(port)) + "/health" + + return getGatewayHealthByURL(url, timeout) +} + +func getGatewayHealthByURL(url string, timeout time.Duration) (*health.StatusResponse, int, error) { resp, err := gatewayHealthGet(url, timeout) if err != nil { return nil, 0, err @@ -73,7 +80,6 @@ func getGatewayHealth(port int, timeout time.Duration) (*health.StatusResponse, // registerGatewayRoutes binds gateway lifecycle endpoints to the ServeMux. 
func (h *Handler) registerGatewayRoutes(mux *http.ServeMux) { mux.HandleFunc("GET /api/gateway/status", h.handleGatewayStatus) - mux.HandleFunc("GET /api/gateway/events", h.handleGatewayEvents) mux.HandleFunc("GET /api/gateway/logs", h.handleGatewayLogs) mux.HandleFunc("POST /api/gateway/logs/clear", h.handleGatewayClearLogs) mux.HandleFunc("POST /api/gateway/start", h.handleGatewayStart) @@ -87,7 +93,7 @@ func (h *Handler) TryAutoStartGateway() { // Check if gateway is already running via health endpoint cfg, cfgErr := config.LoadConfig(h.configPath) if cfgErr == nil && cfg != nil { - healthResp, statusCode, err := getGatewayHealth(cfg.Gateway.Port, 2*time.Second) + healthResp, statusCode, err := h.getGatewayHealth(cfg, 2*time.Second) if err == nil && statusCode == http.StatusOK { // Gateway is already running, attach to the existing process pid := healthResp.Pid @@ -170,6 +176,16 @@ func lookupModelConfig(cfg *config.Config, modelName string) *config.ModelConfig return modelCfg } +func gatewayRestartRequired(configDefaultModel, bootDefaultModel, gatewayStatus string) bool { + if gatewayStatus != "running" { + return false + } + if strings.TrimSpace(configDefaultModel) == "" || strings.TrimSpace(bootDefaultModel) == "" { + return false + } + return configDefaultModel != bootDefaultModel +} + func isCmdProcessAliveLocked(cmd *exec.Cmd) bool { if cmd == nil || cmd.Process == nil { return false @@ -220,7 +236,7 @@ func attachToGatewayProcessLocked(pid int, cfg *config.Config) error { return nil } -func gatewayStatusOnHealthFailureLocked() string { +func gatewayStatusWithoutHealthLocked() string { if gateway.runtimeStatus == "starting" || gateway.runtimeStatus == "restarting" { if gateway.startupDeadline.IsZero() || time.Now().Before(gateway.startupDeadline) { return gateway.runtimeStatus @@ -233,23 +249,7 @@ func gatewayStatusOnHealthFailureLocked() string { if gateway.runtimeStatus == "error" { return "error" } - return "error" -} - -func 
currentGatewayStatusLocked(processAlive bool) string { - if !processAlive { - if gateway.runtimeStatus == "restarting" { - if gateway.startupDeadline.IsZero() || time.Now().Before(gateway.startupDeadline) { - return "restarting" - } - return "error" - } - if gateway.runtimeStatus == "error" { - return "error" - } - return "stopped" - } - return gatewayStatusOnHealthFailureLocked() + return "stopped" } func waitForGatewayProcessExit(cmd *exec.Cmd, timeout time.Duration) bool { @@ -319,15 +319,6 @@ func (h *Handler) startGatewayLocked(initialStatus string, existingPid int) (int return 0, err } - // Broadcast the attached state - gateway.events.Broadcast(GatewayEvent{ - Status: initialStatus, - PID: pid, - BootDefaultModel: defaultModelName, - ConfigDefaultModel: defaultModelName, - RestartRequired: false, - }) - return pid, nil } @@ -335,7 +326,7 @@ func (h *Handler) startGatewayLocked(initialStatus string, existingPid int) (int // Locate the picoclaw executable execPath := utils.FindPicoclawBinary() - cmd = exec.Command(execPath, "gateway") + cmd = exec.Command(execPath, "gateway", "-E") cmd.Env = os.Environ() // Forward the launcher's config path via the environment variable that // GetConfigPath() already reads, so the gateway sub-process uses the same @@ -376,15 +367,6 @@ func (h *Handler) startGatewayLocked(initialStatus string, existingPid int) (int pid = cmd.Process.Pid log.Printf("Started picoclaw gateway (PID: %d) from %s", pid, execPath) - // Broadcast the launch state immediately so clients can reflect it without polling. 
- gateway.events.Broadcast(GatewayEvent{ - Status: initialStatus, - PID: pid, - BootDefaultModel: defaultModelName, - ConfigDefaultModel: defaultModelName, - RestartRequired: false, - }) - // Capture stdout/stderr in background go scanPipe(stdoutPipe, gateway.logs) go scanPipe(stderrPipe, gateway.logs) @@ -398,26 +380,17 @@ func (h *Handler) startGatewayLocked(initialStatus string, existingPid int) (int } gateway.mu.Lock() - shouldBroadcastStopped := false if gateway.cmd == cmd { gateway.cmd = nil gateway.bootDefaultModel = "" if gateway.runtimeStatus != "restarting" { setGatewayRuntimeStatusLocked("stopped") - shouldBroadcastStopped = true } } gateway.mu.Unlock() - - if shouldBroadcastStopped { - gateway.events.Broadcast(GatewayEvent{ - Status: "stopped", - RestartRequired: false, - }) - } }() - // Start a goroutine to probe health and broadcast "running" once ready + // Start a goroutine to probe health and update the runtime state once ready. go func() { for i := 0; i < 30; i++ { // try for up to 15 seconds time.Sleep(500 * time.Millisecond) @@ -431,7 +404,7 @@ func (h *Handler) startGatewayLocked(initialStatus string, existingPid int) (int if err != nil { continue } - healthResp, statusCode, err := getGatewayHealth(cfg.Gateway.Port, 1*time.Second) + healthResp, statusCode, err := h.getGatewayHealth(cfg, 1*time.Second) if err == nil && statusCode == http.StatusOK && healthResp.Pid == pid { // Verify the health endpoint returns the expected pid gateway.mu.Lock() @@ -439,13 +412,6 @@ func (h *Handler) startGatewayLocked(initialStatus string, existingPid int) (int setGatewayRuntimeStatusLocked("running") } gateway.mu.Unlock() - gateway.events.Broadcast(GatewayEvent{ - Status: "running", - PID: pid, - BootDefaultModel: defaultModelName, - ConfigDefaultModel: defaultModelName, - RestartRequired: false, - }) return } } @@ -461,7 +427,7 @@ func (h *Handler) handleGatewayStart(w http.ResponseWriter, r *http.Request) { // Prevent duplicate starts by checking health 
endpoint cfg, cfgErr := config.LoadConfig(h.configPath) if cfgErr == nil && cfg != nil { - healthResp, statusCode, err := getGatewayHealth(cfg.Gateway.Port, 2*time.Second) + healthResp, statusCode, err := h.getGatewayHealth(cfg, 2*time.Second) if err == nil && statusCode == http.StatusOK { // Gateway is already running, attach to the existing process pid := healthResp.Pid @@ -597,10 +563,6 @@ func (h *Handler) RestartGateway() (int, error) { gateway.mu.Lock() previousCmd := gateway.cmd setGatewayRuntimeStatusLocked("restarting") - gateway.events.Broadcast(GatewayEvent{ - Status: "restarting", - RestartRequired: false, - }) gateway.mu.Unlock() if err = stopGatewayProcessForRestart(previousCmd); err != nil { @@ -704,24 +666,20 @@ func (h *Handler) handleGatewayStatus(w http.ResponseWriter, r *http.Request) { func (h *Handler) gatewayStatusData() map[string]any { data := map[string]any{} + configDefaultModel := "" cfg, cfgErr := config.LoadConfig(h.configPath) if cfgErr == nil && cfg != nil { - configDefaultModel := strings.TrimSpace(cfg.Agents.Defaults.GetModelName()) + configDefaultModel = strings.TrimSpace(cfg.Agents.Defaults.GetModelName()) if configDefaultModel != "" { data["config_default_model"] = configDefaultModel } } // Probe health endpoint to get pid and status - port := 0 - if cfgErr == nil && cfg != nil { - port = cfg.Gateway.Port - } - - healthResp, statusCode, err := getGatewayHealth(port, 2*time.Second) + healthResp, statusCode, err := h.getGatewayHealth(cfg, 2*time.Second) if err != nil { gateway.mu.Lock() - data["gateway_status"] = currentGatewayStatusLocked(true) + data["gateway_status"] = gatewayStatusWithoutHealthLocked() gateway.mu.Unlock() log.Printf("Gateway health check failed: %v", err) } else { @@ -734,45 +692,43 @@ func (h *Handler) gatewayStatusData() map[string]any { data["status_code"] = statusCode } else { gateway.mu.Lock() - // Check if this pid matches our tracked process - if gateway.cmd != nil && gateway.cmd.Process != nil && 
gateway.cmd.Process.Pid == healthResp.Pid { - setGatewayRuntimeStatusLocked("running") - bootDefaultModel := gateway.bootDefaultModel - if bootDefaultModel != "" { - data["boot_default_model"] = bootDefaultModel - } - data["gateway_status"] = "running" - data["pid"] = healthResp.Pid - } else { - // Health endpoint responded with a different pid - // This could be a manual restart, try to attach to the new process + setGatewayRuntimeStatusLocked("running") + if gateway.cmd == nil || gateway.cmd.Process == nil || gateway.cmd.Process.Pid != healthResp.Pid { oldPid := "none" if gateway.cmd != nil && gateway.cmd.Process != nil { oldPid = fmt.Sprintf("%d", gateway.cmd.Process.Pid) } - log.Printf("Detected new gateway PID (old: %s, new: %d), attempting to attach", oldPid, healthResp.Pid) - + log.Printf( + "Detected gateway PID from health (old: %s, new: %d), attempting to attach", + oldPid, + healthResp.Pid, + ) if err := attachToGatewayProcessLocked(healthResp.Pid, cfg); err != nil { - // Failed to find the process, treat as error - setGatewayRuntimeStatusLocked("error") - data["gateway_status"] = "error" - data["pid"] = healthResp.Pid - log.Printf("Failed to attach to new gateway process (PID: %d): %v", healthResp.Pid, err) - } else { - // Successfully attached, update response data - bootDefaultModel := gateway.bootDefaultModel - if bootDefaultModel != "" { - data["boot_default_model"] = bootDefaultModel - } - data["gateway_status"] = "running" - data["pid"] = healthResp.Pid + log.Printf( + "Failed to attach to gateway process reported by health (PID: %d): %v", + healthResp.Pid, + err, + ) } } + + bootDefaultModel := gateway.bootDefaultModel + if bootDefaultModel != "" { + data["boot_default_model"] = bootDefaultModel + } + data["gateway_status"] = "running" + data["pid"] = healthResp.Pid gateway.mu.Unlock() } } - data["gateway_restart_required"] = false + bootDefaultModel, _ := data["boot_default_model"].(string) + gatewayStatus, _ := data["gateway_status"].(string) + 
data["gateway_restart_required"] = gatewayRestartRequired( + configDefaultModel, + bootDefaultModel, + gatewayStatus, + ) ready, reason, readyErr := h.gatewayStartReady() if readyErr != nil { @@ -842,51 +798,6 @@ func gatewayLogsData(r *http.Request) map[string]any { return data } -// handleGatewayEvents serves an SSE stream of gateway state change events. -// -// GET /api/gateway/events -func (h *Handler) handleGatewayEvents(w http.ResponseWriter, r *http.Request) { - flusher, ok := w.(http.Flusher) - if !ok { - http.Error(w, "SSE not supported", http.StatusInternalServerError) - return - } - - w.Header().Set("Content-Type", "text/event-stream") - w.Header().Set("Cache-Control", "no-cache") - w.Header().Set("Connection", "keep-alive") - w.Header().Set("Access-Control-Allow-Origin", "*") - - // Subscribe to gateway events - ch := gateway.events.Subscribe() - defer gateway.events.Unsubscribe(ch) - - // Send initial status so the client doesn't start blank - initial := h.currentGatewayStatus() - fmt.Fprintf(w, "data: %s\n\n", initial) - flusher.Flush() - - for { - select { - case <-r.Context().Done(): - return - case data, ok := <-ch: - if !ok { - return - } - fmt.Fprintf(w, "data: %s\n\n", data) - flusher.Flush() - } - } -} - -// currentGatewayStatus returns the current gateway status as a JSON string. -func (h *Handler) currentGatewayStatus() string { - data := h.gatewayStatusData() - encoded, _ := json.Marshal(data) - return string(encoded) -} - // scanPipe reads lines from r and appends them to buf. Returns when r reaches EOF. 
func scanPipe(r io.Reader, buf *LogBuffer) { scanner := bufio.NewScanner(r) diff --git a/web/backend/api/gateway_host.go b/web/backend/api/gateway_host.go index 8dde29b76f..592571a286 100644 --- a/web/backend/api/gateway_host.go +++ b/web/backend/api/gateway_host.go @@ -3,6 +3,7 @@ package api import ( "net" "net/http" + "net/url" "strconv" "strings" @@ -46,6 +47,23 @@ func gatewayProbeHost(bindHost string) string { return bindHost } +func (h *Handler) gatewayProxyURL() *url.URL { + cfg, err := config.LoadConfig(h.configPath) + port := 18790 + bindHost := "" + if err == nil && cfg != nil { + if cfg.Gateway.Port != 0 { + port = cfg.Gateway.Port + } + bindHost = h.effectiveGatewayBindHost(cfg) + } + + return &url.URL{ + Scheme: "http", + Host: net.JoinHostPort(gatewayProbeHost(bindHost), strconv.Itoa(port)), + } +} + func requestHostName(r *http.Request) string { reqHost, _, err := net.SplitHostPort(r.Host) if err == nil { diff --git a/web/backend/api/gateway_host_test.go b/web/backend/api/gateway_host_test.go index 3fffeb8939..ae3434862f 100644 --- a/web/backend/api/gateway_host_test.go +++ b/web/backend/api/gateway_host_test.go @@ -2,9 +2,12 @@ package api import ( "crypto/tls" + "errors" + "net/http" "net/http/httptest" "path/filepath" "testing" + "time" "github.com/sipeed/picoclaw/pkg/config" "github.com/sipeed/picoclaw/web/backend/launcherconfig" @@ -59,6 +62,79 @@ func TestGatewayProbeHostUsesLoopbackForWildcardBind(t *testing.T) { } } +func TestGatewayProxyURLUsesConfiguredHost(t *testing.T) { + configPath := filepath.Join(t.TempDir(), "config.json") + h := NewHandler(configPath) + + cfg := config.DefaultConfig() + cfg.Gateway.Host = "192.168.1.10" + cfg.Gateway.Port = 18791 + if err := config.SaveConfig(configPath, cfg); err != nil { + t.Fatalf("SaveConfig() error = %v", err) + } + + if got := h.gatewayProxyURL().String(); got != "http://192.168.1.10:18791" { + t.Fatalf("gatewayProxyURL() = %q, want %q", got, 
"http://192.168.1.10:18791") + } +} + +func TestGetGatewayHealthUsesConfiguredHost(t *testing.T) { + configPath := filepath.Join(t.TempDir(), "config.json") + h := NewHandler(configPath) + + cfg := config.DefaultConfig() + cfg.Gateway.Host = "192.168.1.10" + cfg.Gateway.Port = 18791 + + originalHealthGet := gatewayHealthGet + t.Cleanup(func() { + gatewayHealthGet = originalHealthGet + }) + + var requestedURL string + gatewayHealthGet = func(url string, timeout time.Duration) (*http.Response, error) { + requestedURL = url + return nil, errors.New("probe failed") + } + + _, statusCode, err := h.getGatewayHealth(cfg, time.Second) + _ = statusCode + _ = err + + if requestedURL != "http://192.168.1.10:18791/health" { + t.Fatalf("health url = %q, want %q", requestedURL, "http://192.168.1.10:18791/health") + } +} + +func TestGetGatewayHealthUsesProbeHostForPublicLauncher(t *testing.T) { + configPath := filepath.Join(t.TempDir(), "config.json") + h := NewHandler(configPath) + h.SetServerOptions(18800, true, true, nil) + + cfg := config.DefaultConfig() + cfg.Gateway.Host = "127.0.0.1" + cfg.Gateway.Port = 18791 + + originalHealthGet := gatewayHealthGet + t.Cleanup(func() { + gatewayHealthGet = originalHealthGet + }) + + var requestedURL string + gatewayHealthGet = func(url string, timeout time.Duration) (*http.Response, error) { + requestedURL = url + return nil, errors.New("probe failed") + } + + _, statusCode, err := h.getGatewayHealth(cfg, time.Second) + _ = statusCode + _ = err + + if requestedURL != "http://127.0.0.1:18791/health" { + t.Fatalf("health url = %q, want %q", requestedURL, "http://127.0.0.1:18791/health") + } +} + func TestBuildWsURLUsesWSSWhenForwardedProtoIsHTTPS(t *testing.T) { configPath := filepath.Join(t.TempDir(), "config.json") h := NewHandler(configPath) diff --git a/web/backend/api/gateway_test.go b/web/backend/api/gateway_test.go index fb4f7d9438..5c94f0b891 100644 --- a/web/backend/api/gateway_test.go +++ b/web/backend/api/gateway_test.go @@ 
-3,6 +3,7 @@ package api import ( "encoding/json" "errors" + "io" "net/http" "net/http/httptest" "os" @@ -36,6 +37,15 @@ func startLongRunningProcess(t *testing.T) *exec.Cmd { return cmd } +func mockGatewayHealthResponse(statusCode, pid int) *http.Response { + return &http.Response{ + StatusCode: statusCode, + Body: io.NopCloser(strings.NewReader( + `{"status":"ok","uptime":"1s","pid":` + strconv.Itoa(pid) + `}`, + )), + } +} + func startIgnoringTermProcess(t *testing.T) *exec.Cmd { t.Helper() @@ -419,6 +429,125 @@ func TestGatewayStatusKeepsRunningWhenHealthProbeFailsAfterRunning(t *testing.T) } } +func TestGatewayStatusReportsRunningFromHealthProbe(t *testing.T) { + resetGatewayTestState(t) + + configPath := filepath.Join(t.TempDir(), "config.json") + h := NewHandler(configPath) + mux := http.NewServeMux() + h.RegisterRoutes(mux) + + cmd := startLongRunningProcess(t) + t.Cleanup(func() { + if cmd.Process != nil { + _ = cmd.Process.Kill() + } + _ = cmd.Wait() + }) + + gateway.mu.Lock() + setGatewayRuntimeStatusLocked("stopped") + gateway.mu.Unlock() + + gatewayHealthGet = func(string, time.Duration) (*http.Response, error) { + return mockGatewayHealthResponse(http.StatusOK, cmd.Process.Pid), nil + } + + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/api/gateway/status", nil) + mux.ServeHTTP(rec, req) + + if rec.Code != http.StatusOK { + t.Fatalf("status = %d, want %d", rec.Code, http.StatusOK) + } + + var body map[string]any + if err := json.Unmarshal(rec.Body.Bytes(), &body); err != nil { + t.Fatalf("unmarshal response: %v", err) + } + + if got := body["gateway_status"]; got != "running" { + t.Fatalf("gateway_status = %#v, want %q", got, "running") + } + if got := body["pid"]; got != float64(cmd.Process.Pid) { + t.Fatalf("pid = %#v, want %d", got, cmd.Process.Pid) + } + if got := body["gateway_restart_required"]; got != false { + t.Fatalf("gateway_restart_required = %#v, want false", got) + } +} + +func 
TestGatewayStatusRequiresRestartAfterDefaultModelChange(t *testing.T) { + resetGatewayTestState(t) + + configPath := filepath.Join(t.TempDir(), "config.json") + cfg := config.DefaultConfig() + cfg.Agents.Defaults.ModelName = cfg.ModelList[0].ModelName + cfg.ModelList[0].APIKey = "test-key" + cfg.ModelList = append(cfg.ModelList, config.ModelConfig{ + ModelName: "second-model", + Model: "openai/gpt-4.1", + APIKey: "second-key", + }) + if err := config.SaveConfig(configPath, cfg); err != nil { + t.Fatalf("SaveConfig() error = %v", err) + } + + h := NewHandler(configPath) + mux := http.NewServeMux() + h.RegisterRoutes(mux) + + process, err := os.FindProcess(os.Getpid()) + if err != nil { + t.Fatalf("FindProcess() error = %v", err) + } + + gateway.mu.Lock() + gateway.cmd = &exec.Cmd{Process: process} + gateway.bootDefaultModel = cfg.ModelList[0].ModelName + setGatewayRuntimeStatusLocked("running") + gateway.mu.Unlock() + + updatedCfg, err := config.LoadConfig(configPath) + if err != nil { + t.Fatalf("LoadConfig() error = %v", err) + } + updatedCfg.Agents.Defaults.ModelName = "second-model" + if err := config.SaveConfig(configPath, updatedCfg); err != nil { + t.Fatalf("SaveConfig() error = %v", err) + } + + gatewayHealthGet = func(string, time.Duration) (*http.Response, error) { + return mockGatewayHealthResponse(http.StatusOK, os.Getpid()), nil + } + + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/api/gateway/status", nil) + mux.ServeHTTP(rec, req) + + if rec.Code != http.StatusOK { + t.Fatalf("status = %d, want %d", rec.Code, http.StatusOK) + } + + var body map[string]any + if err := json.Unmarshal(rec.Body.Bytes(), &body); err != nil { + t.Fatalf("unmarshal response: %v", err) + } + + if got := body["gateway_status"]; got != "running" { + t.Fatalf("gateway_status = %#v, want %q", got, "running") + } + if got := body["boot_default_model"]; got != cfg.ModelList[0].ModelName { + t.Fatalf("boot_default_model = %#v, want %q", got, 
cfg.ModelList[0].ModelName) + } + if got := body["config_default_model"]; got != "second-model" { + t.Fatalf("config_default_model = %#v, want %q", got, "second-model") + } + if got := body["gateway_restart_required"]; got != true { + t.Fatalf("gateway_restart_required = %#v, want true", got) + } +} + func TestGatewayStatusReturnsErrorAfterStartupWindowExpires(t *testing.T) { resetGatewayTestState(t) diff --git a/web/backend/api/pico.go b/web/backend/api/pico.go index d11f7bc5e9..a880f2f0c0 100644 --- a/web/backend/api/pico.go +++ b/web/backend/api/pico.go @@ -7,7 +7,6 @@ import ( "fmt" "net/http" "net/http/httputil" - "net/url" "time" "github.com/sipeed/picoclaw/pkg/config" @@ -22,20 +21,13 @@ func (h *Handler) registerPicoRoutes(mux *http.ServeMux) { // WebSocket proxy: forward /pico/ws to gateway // This allows the frontend to connect via the same port as the web UI, // avoiding the need to expose extra ports for WebSocket communication. - wsProxy := h.createWsProxy() - mux.HandleFunc("GET /pico/ws", h.handleWebSocketProxy(wsProxy)) + mux.HandleFunc("GET /pico/ws", h.handleWebSocketProxy()) } -// createWsProxy creates a reverse proxy to the gateway WebSocket endpoint. -// The gateway port is read from the configuration. +// createWsProxy creates a reverse proxy to the current gateway WebSocket endpoint. +// The gateway bind host and port are resolved from the latest configuration. 
func (h *Handler) createWsProxy() *httputil.ReverseProxy { - cfg, err := config.LoadConfig(h.configPath) - gatewayPort := 18790 // default - if err == nil && cfg.Gateway.Port != 0 { - gatewayPort = cfg.Gateway.Port - } - gatewayURL, _ := url.Parse(fmt.Sprintf("http://127.0.0.1:%d", gatewayPort)) - wsProxy := httputil.NewSingleHostReverseProxy(gatewayURL) + wsProxy := httputil.NewSingleHostReverseProxy(h.gatewayProxyURL()) wsProxy.ErrorHandler = func(w http.ResponseWriter, r *http.Request, err error) { http.Error(w, "Gateway unavailable: "+err.Error(), http.StatusBadGateway) } @@ -43,12 +35,10 @@ func (h *Handler) createWsProxy() *httputil.ReverseProxy { } // handleWebSocketProxy wraps a reverse proxy to handle WebSocket connections. -// It ensures the Connection and Upgrade headers are properly forwarded. -func (h *Handler) handleWebSocketProxy(proxy *httputil.ReverseProxy) http.HandlerFunc { +// The reverse proxy forwards the incoming upgrade handshake as-is. +func (h *Handler) handleWebSocketProxy() http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { - // Set headers for WebSocket upgrade - r.Header.Set("Connection", "upgrade") - r.Header.Set("Upgrade", "websocket") + proxy := h.createWsProxy() proxy.ServeHTTP(w, r) } } diff --git a/web/backend/api/pico_test.go b/web/backend/api/pico_test.go index 46149fa09e..075da4ddc9 100644 --- a/web/backend/api/pico_test.go +++ b/web/backend/api/pico_test.go @@ -2,9 +2,12 @@ package api import ( "encoding/json" + "io" "net/http" "net/http/httptest" + "net/url" "path/filepath" + "strconv" "testing" "github.com/sipeed/picoclaw/pkg/config" @@ -235,3 +238,77 @@ func TestHandlePicoSetup_Response(t *testing.T) { t.Error("response should have changed=true on first setup") } } + +func TestHandleWebSocketProxyReloadsGatewayTargetFromConfig(t *testing.T) { + configPath := filepath.Join(t.TempDir(), "config.json") + h := NewHandler(configPath) + handler := h.handleWebSocketProxy() + + server1 := 
httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/pico/ws" { + t.Fatalf("server1 path = %q, want %q", r.URL.Path, "/pico/ws") + } + w.WriteHeader(http.StatusOK) + _, _ = io.WriteString(w, "server1") + })) + defer server1.Close() + + server2 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/pico/ws" { + t.Fatalf("server2 path = %q, want %q", r.URL.Path, "/pico/ws") + } + w.WriteHeader(http.StatusOK) + _, _ = io.WriteString(w, "server2") + })) + defer server2.Close() + + cfg := config.DefaultConfig() + cfg.Gateway.Host = "127.0.0.1" + cfg.Gateway.Port = mustGatewayTestPort(t, server1.URL) + if err := config.SaveConfig(configPath, cfg); err != nil { + t.Fatalf("SaveConfig() error = %v", err) + } + + req1 := httptest.NewRequest(http.MethodGet, "/pico/ws", nil) + rec1 := httptest.NewRecorder() + handler(rec1, req1) + + if rec1.Code != http.StatusOK { + t.Fatalf("first status = %d, want %d", rec1.Code, http.StatusOK) + } + if body := rec1.Body.String(); body != "server1" { + t.Fatalf("first body = %q, want %q", body, "server1") + } + + cfg.Gateway.Port = mustGatewayTestPort(t, server2.URL) + if err := config.SaveConfig(configPath, cfg); err != nil { + t.Fatalf("SaveConfig() error = %v", err) + } + + req2 := httptest.NewRequest(http.MethodGet, "/pico/ws", nil) + rec2 := httptest.NewRecorder() + handler(rec2, req2) + + if rec2.Code != http.StatusOK { + t.Fatalf("second status = %d, want %d", rec2.Code, http.StatusOK) + } + if body := rec2.Body.String(); body != "server2" { + t.Fatalf("second body = %q, want %q", body, "server2") + } +} + +func mustGatewayTestPort(t *testing.T, rawURL string) int { + t.Helper() + + parsed, err := url.Parse(rawURL) + if err != nil { + t.Fatalf("url.Parse() error = %v", err) + } + + port, err := strconv.Atoi(parsed.Port()) + if err != nil { + t.Fatalf("Atoi(%q) error = %v", parsed.Port(), err) + } + + return port +} diff --git 
a/web/backend/api/router.go b/web/backend/api/router.go index b564387841..028a476f26 100644 --- a/web/backend/api/router.go +++ b/web/backend/api/router.go @@ -71,7 +71,4 @@ func (h *Handler) RegisterRoutes(mux *http.ServeMux) { h.registerLauncherConfigRoutes(mux) } -// Shutdown gracefully shuts down the handler, closing all SSE connections. -func (h *Handler) Shutdown() { - gateway.events.Shutdown() -} +func (h *Handler) Shutdown() {} diff --git a/web/backend/middleware/middleware.go b/web/backend/middleware/middleware.go index de9e6d8702..e15da577bf 100644 --- a/web/backend/middleware/middleware.go +++ b/web/backend/middleware/middleware.go @@ -4,16 +4,14 @@ import ( "log" "net/http" "runtime/debug" - "strings" "time" ) // JSONContentType sets the Content-Type header to application/json for // API requests handled by the wrapped handler. -// SSE endpoints (text/event-stream) are excluded. func JSONContentType(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if strings.HasPrefix(r.URL.Path, "/api/") && !strings.HasSuffix(r.URL.Path, "/events") { + if len(r.URL.Path) >= 5 && r.URL.Path[:5] == "/api/" { w.Header().Set("Content-Type", "application/json") } next.ServeHTTP(w, r) @@ -32,7 +30,6 @@ func (rr *responseRecorder) WriteHeader(code int) { } // Flush delegates to the underlying ResponseWriter if it implements http.Flusher. -// This is required for SSE (Server-Sent Events) to work through the middleware. 
func (rr *responseRecorder) Flush() { if f, ok := rr.ResponseWriter.(http.Flusher); ok { f.Flush() diff --git a/web/backend/systray.go b/web/backend/systray.go index 58ce4984fe..1ff98c71b8 100644 --- a/web/backend/systray.go +++ b/web/backend/systray.go @@ -94,7 +94,7 @@ func onReady() { func onExit() { fmt.Println(T(Exiting)) - // First, shutdown API handler to close all SSE connections + // First, shutdown API handler if apiHandler != nil { apiHandler.Shutdown() } diff --git a/web/frontend/src/components/app-header.tsx b/web/frontend/src/components/app-header.tsx index fe0c84e69f..4f06880085 100644 --- a/web/frontend/src/components/app-header.tsx +++ b/web/frontend/src/components/app-header.tsx @@ -56,14 +56,20 @@ export function AppHeader() { const isRunning = gwState === "running" const isStarting = gwState === "starting" const isRestarting = gwState === "restarting" + const isStopping = gwState === "stopping" const isStopped = gwState === "stopped" || gwState === "unknown" const showNotConnectedHint = - !isRestarting && canStart && (gwState === "stopped" || gwState === "error") + !isRestarting && + !isStopping && + canStart && + (gwState === "stopped" || gwState === "error") const [showStopDialog, setShowStopDialog] = React.useState(false) const handleGatewayToggle = () => { - if (gwLoading || isRestarting || (!isRunning && !canStart)) return + if (gwLoading || isRestarting || isStopping || (!isRunning && !canStart)) { + return + } if (isRunning) { setShowStopDialog(true) } else { @@ -137,7 +143,7 @@ export function AppHeader() { size="icon-sm" className="bg-amber-500/15 text-amber-700 hover:bg-amber-500/25 hover:text-amber-800 dark:text-amber-300 dark:hover:bg-amber-500/25" onClick={handleGatewayRestart} - disabled={gwLoading || isRestarting || !canStart} + disabled={gwLoading || isRestarting || isStopping || !canStart} aria-label={t("header.gateway.action.restart")} > @@ -168,25 +174,31 @@ export function AppHeader() { ) : ( )} diff --git 
a/web/frontend/src/hooks/use-gateway-logs.ts b/web/frontend/src/hooks/use-gateway-logs.ts index 15cbca4ae2..1de3611248 100644 --- a/web/frontend/src/hooks/use-gateway-logs.ts +++ b/web/frontend/src/hooks/use-gateway-logs.ts @@ -37,7 +37,9 @@ export function useGatewayLogs() { const fetchLogs = async () => { if ( !mounted || - !["running", "starting", "restarting"].includes(gateway.status) + !["running", "starting", "restarting", "stopping"].includes( + gateway.status, + ) ) { if (mounted) { timeout = setTimeout(fetchLogs, 1000) diff --git a/web/frontend/src/hooks/use-gateway.ts b/web/frontend/src/hooks/use-gateway.ts index 65ec2b7767..b118b43da4 100644 --- a/web/frontend/src/hooks/use-gateway.ts +++ b/web/frontend/src/hooks/use-gateway.ts @@ -1,83 +1,24 @@ import { useAtomValue } from "jotai" import { useCallback, useEffect, useState } from "react" +import { restartGateway, startGateway, stopGateway } from "@/api/gateway" import { - type GatewayStatusResponse, - getGatewayStatus, - restartGateway, - startGateway, - stopGateway, -} from "@/api/gateway" -import { - applyGatewayStatusToStore, + beginGatewayStoppingTransition, + cancelGatewayStoppingTransition, gatewayAtom, + refreshGatewayState, + subscribeGatewayPolling, updateGatewayStore, } from "@/store" -// Global variable to ensure we only have one SSE connection -let sseInitialized = false - export function useGateway() { const gateway = useAtomValue(gatewayAtom) const { status: state, canStart, restartRequired } = gateway const [loading, setLoading] = useState(false) - const applyGatewayStatus = useCallback((data: GatewayStatusResponse) => { - applyGatewayStatusToStore(data) - }, []) - - // Initialize global SSE connection once useEffect(() => { - if (sseInitialized) return - sseInitialized = true - - getGatewayStatus() - .then((data) => applyGatewayStatus(data)) - .catch(() => { - updateGatewayStore({ - status: "unknown", - canStart: true, - restartRequired: false, - }) - }) - - const statusPoll = 
window.setInterval(() => { - getGatewayStatus() - .then((data) => applyGatewayStatus(data)) - .catch(() => { - // ignore polling errors - }) - }, 5000) - - // Subscribe to SSE for real-time updates globally - const es = new EventSource("/api/gateway/events") - - es.onmessage = (event) => { - try { - const data = JSON.parse(event.data) - if ( - data.gateway_status || - typeof data.gateway_start_allowed === "boolean" - ) { - applyGatewayStatus(data) - } - } catch { - // ignore - } - } - - es.onerror = () => { - // EventSource will auto-reconnect. Preserve the last known gateway - // status so transient SSE disconnects do not suppress chat websocket - // reconnects while polling catches up. - } - - return () => { - window.clearInterval(statusPoll) - es.close() - sseInitialized = false - } - }, [applyGatewayStatus]) + return subscribeGatewayPolling() + }, []) const start = useCallback(async () => { if (!canStart) return @@ -85,33 +26,28 @@ export function useGateway() { setLoading(true) try { await startGateway() - // SSE will push the real state changes, but set optimistic state - updateGatewayStore({ status: "starting" }) + updateGatewayStore({ + status: "starting", + restartRequired: false, + }) } catch (err) { console.error("Failed to start gateway:", err) - try { - const status = await getGatewayStatus() - applyGatewayStatus(status) - } catch { - updateGatewayStore({ status: "unknown" }) - } } finally { + await refreshGatewayState({ force: true }) setLoading(false) } - }, [applyGatewayStatus, canStart]) + }, [canStart]) const stop = useCallback(async () => { setLoading(true) + beginGatewayStoppingTransition() try { await stopGateway() - updateGatewayStore({ - status: "stopped", - canStart: true, - restartRequired: false, - }) } catch (err) { console.error("Failed to stop gateway:", err) + cancelGatewayStoppingTransition() } finally { + await refreshGatewayState({ force: true }) setLoading(false) } }, []) @@ -119,34 +55,20 @@ export function useGateway() { const 
restart = useCallback(async () => { if (state !== "running") return - const previousState = state - const previousCanStart = canStart - const previousRestartRequired = restartRequired - setLoading(true) - updateGatewayStore({ - status: "restarting", - restartRequired: false, - }) - try { await restartGateway() + updateGatewayStore({ + status: "restarting", + restartRequired: false, + }) } catch (err) { console.error("Failed to restart gateway:", err) - try { - const status = await getGatewayStatus() - applyGatewayStatus(status) - } catch { - updateGatewayStore({ - status: previousState, - canStart: previousCanStart, - restartRequired: previousRestartRequired, - }) - } } finally { + await refreshGatewayState({ force: true }) setLoading(false) } - }, [applyGatewayStatus, canStart, restartRequired, state]) + }, [state]) return { state, loading, canStart, restartRequired, start, stop, restart } } diff --git a/web/frontend/src/i18n/locales/en.json b/web/frontend/src/i18n/locales/en.json index 2fa32ebb5a..327b4c6466 100644 --- a/web/frontend/src/i18n/locales/en.json +++ b/web/frontend/src/i18n/locales/en.json @@ -63,7 +63,8 @@ }, "status": { "starting": "Starting Gateway...", - "restarting": "Restarting Gateway..." + "restarting": "Restarting Gateway...", + "stopping": "Stopping Gateway..." }, "restartRequired": "Model changes require a gateway restart to take effect." } diff --git a/web/frontend/src/i18n/locales/zh.json b/web/frontend/src/i18n/locales/zh.json index badf5bb3d0..cd674ddc14 100644 --- a/web/frontend/src/i18n/locales/zh.json +++ b/web/frontend/src/i18n/locales/zh.json @@ -63,7 +63,8 @@ }, "status": { "starting": "ๆœๅŠกๅฏๅŠจไธญ...", - "restarting": "ๆœๅŠก้‡ๅฏไธญ..." + "restarting": "ๆœๅŠก้‡ๅฏไธญ...", + "stopping": "ๆœๅŠกๅœๆญขไธญ..." 
}, "restartRequired": "切换默认模型后需要重启服务才能生效。" } diff --git a/web/frontend/src/store/gateway.ts b/web/frontend/src/store/gateway.ts index c5eee84516..1bdec6220c 100644 --- a/web/frontend/src/store/gateway.ts +++ b/web/frontend/src/store/gateway.ts @@ -6,6 +6,7 @@ export type GatewayState = | "running" | "starting" | "restarting" + | "stopping" | "stopped" | "error" | "unknown" @@ -24,9 +25,29 @@ const DEFAULT_GATEWAY_STATE: GatewayStoreState = { restartRequired: false, } +const GATEWAY_POLL_INTERVAL_MS = 2000 +const GATEWAY_TRANSIENT_POLL_INTERVAL_MS = 1000 +const GATEWAY_STOPPING_TIMEOUT_MS = 5000 + +interface RefreshGatewayStateOptions { + force?: boolean +} + // Global atom for gateway state export const gatewayAtom = atom(DEFAULT_GATEWAY_STATE) +let gatewayPollingSubscribers = 0 +let gatewayPollingTimer: ReturnType<typeof setTimeout> | null = null +let gatewayPollingRequest: Promise<void> | null = null +let gatewayStoppingTimer: ReturnType<typeof setTimeout> | null = null + +function clearGatewayStoppingTimeout() { + if (gatewayStoppingTimer !== null) { + clearTimeout(gatewayStoppingTimer) + gatewayStoppingTimer = null + } +} + function normalizeGatewayStoreState( prev: GatewayStoreState, patch: GatewayStorePatch, @@ -49,10 +70,38 @@ export function updateGatewayStore( | GatewayStorePatch | ((prev: GatewayStoreState) => GatewayStorePatch | GatewayStoreState), ) { - getDefaultStore().set(gatewayAtom, (prev) => { + const store = getDefaultStore() + store.set(gatewayAtom, (prev) => { const nextPatch = typeof patch === 
patch(prev) : patch return normalizeGatewayStoreState(prev, nextPatch) }) + const nextState = store.get(gatewayAtom) + if (nextState?.status !== "stopping") { + clearGatewayStoppingTimeout() + } +} + +export function beginGatewayStoppingTransition() { + clearGatewayStoppingTimeout() + updateGatewayStore({ + status: "stopping", + canStart: false, + restartRequired: false, + }) + gatewayStoppingTimer = setTimeout(() => { + gatewayStoppingTimer = null + updateGatewayStore((prev) => + prev.status === "stopping" ? { status: "running" } : prev, + ) + void refreshGatewayState({ force: true }) + }, GATEWAY_STOPPING_TIMEOUT_MS) +} + +export function cancelGatewayStoppingTransition() { + clearGatewayStoppingTimeout() + updateGatewayStore((prev) => + prev.status === "stopping" ? { status: "running" } : prev, + ) } export function applyGatewayStatusToStore( @@ -64,21 +113,92 @@ export function applyGatewayStatusToStore( >, ) { updateGatewayStore((prev) => ({ - status: data.gateway_status ?? prev.status, - canStart: data.gateway_start_allowed ?? prev.canStart, + status: + prev.status === "stopping" && data.gateway_status === "running" + ? "stopping" + : (data.gateway_status ?? prev.status), + canStart: + prev.status === "stopping" && data.gateway_status === "running" + ? false + : (data.gateway_start_allowed ?? prev.canStart), restartRequired: - data.gateway_restart_required ?? - (data.gateway_status && data.gateway_status !== "running" + prev.status === "stopping" && data.gateway_status === "running" ? false - : prev.restartRequired), + : (data.gateway_restart_required ?? 
prev.restartRequired), })) } -export async function refreshGatewayState() { +function nextGatewayPollInterval() { + const status = getDefaultStore().get(gatewayAtom).status + if ( + status === "starting" || + status === "restarting" || + status === "stopping" + ) { + return GATEWAY_TRANSIENT_POLL_INTERVAL_MS + } + return GATEWAY_POLL_INTERVAL_MS +} + +function scheduleGatewayPoll(delay = nextGatewayPollInterval()) { + if (gatewayPollingSubscribers === 0) { + return + } + + if (gatewayPollingTimer !== null) { + clearTimeout(gatewayPollingTimer) + } + + gatewayPollingTimer = setTimeout(() => { + gatewayPollingTimer = null + void refreshGatewayState() + }, delay) +} + +export async function refreshGatewayState( + options: RefreshGatewayStateOptions = {}, +) { + if (gatewayPollingRequest) { + await gatewayPollingRequest + if (options.force) { + return refreshGatewayState() + } + return + } + + gatewayPollingRequest = (async () => { + try { + const status = await getGatewayStatus() + applyGatewayStatusToStore(status) + } catch { + // Preserve the last known state when a poll fails. + } finally { + gatewayPollingRequest = null + scheduleGatewayPoll() + } + })() + try { - const status = await getGatewayStatus() - applyGatewayStatusToStore(status) - } catch { - updateGatewayStore(DEFAULT_GATEWAY_STATE) + await gatewayPollingRequest + } finally { + if (gatewayPollingSubscribers === 0 && gatewayPollingTimer !== null) { + clearTimeout(gatewayPollingTimer) + gatewayPollingTimer = null + } + } +} + +export function subscribeGatewayPolling() { + gatewayPollingSubscribers += 1 + if (gatewayPollingSubscribers === 1) { + void refreshGatewayState() + } + + return () => { + gatewayPollingSubscribers = Math.max(0, gatewayPollingSubscribers - 1) + if (gatewayPollingSubscribers === 0 && gatewayPollingTimer !== null) { + clearTimeout(gatewayPollingTimer) + gatewayPollingTimer = null + } } }