diff --git a/cmd/picoclaw/internal/gateway/helpers.go b/cmd/picoclaw/internal/gateway/helpers.go index fed3d5ffbe..3562f03ef0 100644 --- a/cmd/picoclaw/internal/gateway/helpers.go +++ b/cmd/picoclaw/internal/gateway/helpers.go @@ -3,10 +3,10 @@ package gateway import ( "context" "fmt" - "log" "os" "os/signal" "path/filepath" + "sync" "time" "github.com/sipeed/picoclaw/cmd/picoclaw/internal" @@ -41,12 +41,31 @@ import ( "github.com/sipeed/picoclaw/pkg/voice" ) +// Timeout constants for service operations +const ( + serviceRestartTimeout = 30 * time.Second + serviceShutdownTimeout = 30 * time.Second + providerReloadTimeout = 30 * time.Second + gracefulShutdownTimeout = 15 * time.Second +) + +// gatewayServices holds references to all running services +type gatewayServices struct { + CronService *cron.CronService + HeartbeatService *heartbeat.HeartbeatService + MediaStore media.MediaStore + ChannelManager *channels.Manager + DeviceService *devices.Service + HealthServer *health.Server +} + func gatewayCmd(debug bool) error { if debug { logger.SetLevel(logger.DEBUG) fmt.Println("🔍 Debug mode enabled") } + configPath := internal.GetConfigPath() cfg, err := internal.LoadConfig() if err != nil { return fmt.Errorf("error loading config: %w", err) @@ -83,9 +102,55 @@ func gatewayCmd(debug bool) error { "skills_available": skillsInfo["available"], }) + // Setup and start all services + services, err := setupAndStartServices(cfg, agentLoop, msgBus) + if err != nil { + return err + } + + fmt.Printf("✓ Gateway started on %s:%d\n", cfg.Gateway.Host, cfg.Gateway.Port) + fmt.Println("Press Ctrl+C to stop") + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + go agentLoop.Run(ctx) + + // Setup config file watcher for hot reload + configReloadChan, stopWatch := setupConfigWatcherPolling(configPath, debug) + defer stopWatch() + + sigChan := make(chan os.Signal, 1) + signal.Notify(sigChan, os.Interrupt) + + // Main event loop - wait for signals or config changes + for { + select { + case <-sigChan: + logger.Info("Shutting down...") + shutdownGateway(services, agentLoop, provider, true) + return nil + + case newCfg := <-configReloadChan: + err := handleConfigReload(ctx, agentLoop, newCfg, &provider, services, msgBus) + if err != nil { + logger.Errorf("Config reload failed: %v", err) + } + } + } +} + +// setupAndStartServices initializes and starts all services +func setupAndStartServices( + cfg *config.Config, + agentLoop *agent.AgentLoop, + msgBus *bus.MessageBus, +) (*gatewayServices, error) { + services := &gatewayServices{} + // Setup cron tool and service execTimeout := time.Duration(cfg.Tools.Cron.ExecTimeoutMinutes) * time.Minute - cronService := setupCronTool( + services.CronService = setupCronTool( agentLoop, msgBus, cfg.WorkspacePath(), @@ -93,20 +158,26 @@ func gatewayCmd(debug bool) error { execTimeout, cfg, ) + if err := services.CronService.Start(); err != nil { + return nil, fmt.Errorf("error starting cron service: %w", err) + } + fmt.Println("✓ Cron service started") - heartbeatService := heartbeat.NewHeartbeatService( + // Setup heartbeat service + services.HeartbeatService = heartbeat.NewHeartbeatService( cfg.WorkspacePath(), cfg.Heartbeat.Interval, cfg.Heartbeat.Enabled, ) - heartbeatService.SetBus(msgBus) - heartbeatService.SetHandler(func(prompt, channel, chatID string) *tools.ToolResult { + services.HeartbeatService.SetBus(msgBus) + services.HeartbeatService.SetHandler(func(prompt, channel, chatID string) *tools.ToolResult { // Use cli:direct as fallback if no valid channel if channel == "" || chatID == "" { channel, chatID = "cli", "direct" } // Use ProcessHeartbeat - no session history, each heartbeat is independent var response string + var err error response, err = agentLoop.ProcessHeartbeat(context.Background(), prompt, channel, chatID) if err != nil { return tools.ErrorResult(fmt.Sprintf("Heartbeat error: %v", err)) @@ -118,24 +189,36 @@ func gatewayCmd(debug bool) error { // sent to user via processSystemMessage when the async task completes return tools.SilentResult(response) }) + if err := services.HeartbeatService.Start(); err != nil { + return nil, fmt.Errorf("error starting heartbeat service: %w", err) + } + fmt.Println("✓ Heartbeat service started") // Create media store for file lifecycle management with TTL cleanup - mediaStore := media.NewFileMediaStoreWithCleanup(media.MediaCleanerConfig{ + services.MediaStore = media.NewFileMediaStoreWithCleanup(media.MediaCleanerConfig{ Enabled: cfg.Tools.MediaCleanup.Enabled, MaxAge: time.Duration(cfg.Tools.MediaCleanup.MaxAge) * time.Minute, Interval: time.Duration(cfg.Tools.MediaCleanup.Interval) * time.Minute, }) - mediaStore.Start() + // Start the media store if it's a FileMediaStore with cleanup + if fms, ok := services.MediaStore.(*media.FileMediaStore); ok { + fms.Start() + } - channelManager, err := channels.NewManager(cfg, msgBus, mediaStore) + // Create channel manager + var err error + services.ChannelManager, err = channels.NewManager(cfg, msgBus, services.MediaStore) if err != nil { - mediaStore.Stop() - return fmt.Errorf("error creating channel manager: %w", err) + // Stop the media store if it's a FileMediaStore with cleanup + if fms, ok := services.MediaStore.(*media.FileMediaStore); ok { + fms.Stop() + } + return nil, fmt.Errorf("error creating channel manager: %w", err) } // Inject channel manager and media store into agent loop - agentLoop.SetChannelManager(channelManager) - agentLoop.SetMediaStore(mediaStore) + agentLoop.SetChannelManager(services.ChannelManager) + agentLoop.SetMediaStore(services.MediaStore) // Wire up voice transcription if a supported provider is configured. if transcriber := voice.DetectTranscriber(cfg); transcriber != nil { @@ -143,83 +226,386 @@ func gatewayCmd(debug bool) error { logger.InfoCF("voice", "Transcription enabled (agent-level)", map[string]any{"provider": transcriber.Name()}) } - enabledChannels := channelManager.GetEnabledChannels() + enabledChannels := services.ChannelManager.GetEnabledChannels() if len(enabledChannels) > 0 { fmt.Printf("✓ Channels enabled: %s\n", enabledChannels) } else { fmt.Println("⚠ Warning: No channels enabled") } - fmt.Printf("✓ Gateway started on %s:%d\n", cfg.Gateway.Host, cfg.Gateway.Port) - fmt.Println("Press Ctrl+C to stop") - - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() + // Setup shared HTTP server with health endpoints and webhook handlers + addr := fmt.Sprintf("%s:%d", cfg.Gateway.Host, cfg.Gateway.Port) + services.HealthServer = health.NewServer(cfg.Gateway.Host, cfg.Gateway.Port) + services.ChannelManager.SetupHTTPServer(addr, services.HealthServer) - if err := cronService.Start(); err != nil { - fmt.Printf("Error starting cron service: %v\n", err) + if err := services.ChannelManager.StartAll(context.Background()); err != nil { + return nil, fmt.Errorf("error starting channels: %w", err) } - fmt.Println("✓ Cron service started") - if err := heartbeatService.Start(); err != nil { - fmt.Printf("Error starting heartbeat service: %v\n", err) - } - fmt.Println("✓ Heartbeat service started") + fmt.Printf("✓ Health endpoints available at http://%s:%d/health and /ready\n", cfg.Gateway.Host, cfg.Gateway.Port) + // Setup state manager and device service stateManager := state.NewManager(cfg.WorkspacePath()) - deviceService := devices.NewService(devices.Config{ + services.DeviceService = devices.NewService(devices.Config{ Enabled: cfg.Devices.Enabled, MonitorUSB: cfg.Devices.MonitorUSB, }, stateManager) - deviceService.SetBus(msgBus) - if err := deviceService.Start(ctx); err != nil { - fmt.Printf("Error starting device service: %v\n", err) + services.DeviceService.SetBus(msgBus) + if err := services.DeviceService.Start(context.Background()); err != nil { + logger.ErrorCF("device", "Error starting device service", map[string]any{"error": err.Error()}) } else if cfg.Devices.Enabled { fmt.Println("✓ Device event service started") } - // Setup shared HTTP server with health endpoints and webhook handlers - healthServer := health.NewServer(cfg.Gateway.Host, cfg.Gateway.Port) - addr := fmt.Sprintf("%s:%d", cfg.Gateway.Host, cfg.Gateway.Port) - channelManager.SetupHTTPServer(addr, healthServer) - - if err := channelManager.StartAll(ctx); err != nil { - fmt.Printf("Error starting channels: %v\n", err) - return err - } - - fmt.Printf("✓ Health endpoints available at http://%s:%d/health and /ready\n", cfg.Gateway.Host, cfg.Gateway.Port) + return services, nil +} - go agentLoop.Run(ctx) +// stopAndCleanupServices stops all services and cleans up resources +func stopAndCleanupServices( + services *gatewayServices, + shutdownTimeout time.Duration, +) { + shutdownCtx, shutdownCancel := context.WithTimeout(context.Background(), shutdownTimeout) + defer shutdownCancel() - sigChan := make(chan os.Signal, 1) - signal.Notify(sigChan, os.Interrupt) - <-sigChan + if services.ChannelManager != nil { + services.ChannelManager.StopAll(shutdownCtx) + } + if services.DeviceService != nil { + services.DeviceService.Stop() + } + if services.HeartbeatService != nil { + services.HeartbeatService.Stop() + } + if services.CronService != nil { + services.CronService.Stop() + } + if services.MediaStore != nil { + // Stop the media store if it's a FileMediaStore with cleanup + if fms, ok := services.MediaStore.(*media.FileMediaStore); ok { + fms.Stop() + } + } +} - fmt.Println("\nShutting down...") - if cp, ok := provider.(providers.StatefulProvider); ok { +// shutdownGateway performs a complete gateway shutdown +func shutdownGateway( + services *gatewayServices, + agentLoop *agent.AgentLoop, + provider providers.LLMProvider, + fullShutdown bool, +) { + if cp, ok := provider.(providers.StatefulProvider); ok && fullShutdown { cp.Close() } - cancel() - msgBus.Close() - // Use a fresh context with timeout for graceful shutdown, - // since the original ctx is already canceled. - shutdownCtx, shutdownCancel := context.WithTimeout(context.Background(), 15*time.Second) - defer shutdownCancel() + stopAndCleanupServices(services, gracefulShutdownTimeout) - channelManager.StopAll(shutdownCtx) - deviceService.Stop() - heartbeatService.Stop() - cronService.Stop() - mediaStore.Stop() agentLoop.Stop() agentLoop.Close() - fmt.Println("✓ Gateway stopped") + + logger.Info("✓ Gateway stopped") +} + +// handleConfigReload handles config file reload by stopping all services, +// reloading the provider and config, and restarting services with the new config. +func handleConfigReload( + ctx context.Context, + al *agent.AgentLoop, + newCfg *config.Config, + providerRef *providers.LLMProvider, + services *gatewayServices, + msgBus *bus.MessageBus, +) error { + logger.Info("🔄 Config file changed, reloading...") + + newModel := newCfg.Agents.Defaults.ModelName + if newModel == "" { + newModel = newCfg.Agents.Defaults.Model + } + + logger.Infof(" New model is '%s', recreating provider...", newModel) + + // Stop all services before reloading + logger.Info(" Stopping all services...") + stopAndCleanupServices(services, serviceShutdownTimeout) + + // Create new provider from updated config first to ensure validity + // This will use the correct API key and settings from newCfg.ModelList + newProvider, newModelID, err := providers.CreateProvider(newCfg) + if err != nil { + logger.Errorf(" ⚠ Error creating new provider: %v", err) + logger.Warn(" Attempting to restart services with old provider and config...") + // Try to restart services with old configuration + if restartErr := restartServices(al, services, msgBus); restartErr != nil { + logger.Errorf(" ⚠ Failed to restart services: %v", restartErr) + } + return fmt.Errorf("error creating new provider: %w", err) + } + + if newModelID != "" { + newCfg.Agents.Defaults.ModelName = newModelID + } + + // Use the atomic reload method on AgentLoop to safely swap provider and config. + // This handles locking internally to prevent races with in-flight LLM calls + // and concurrent reads of registry/config while the swap occurs. + reloadCtx, reloadCancel := context.WithTimeout(context.Background(), providerReloadTimeout) + defer reloadCancel() + + if err := al.ReloadProviderAndConfig(reloadCtx, newProvider, newCfg); err != nil { + logger.Errorf(" ⚠ Error reloading agent loop: %v", err) + // Close the newly created provider since it wasn't adopted + if cp, ok := newProvider.(providers.StatefulProvider); ok { + cp.Close() + } + logger.Warn(" Attempting to restart services with old provider and config...") + if restartErr := restartServices(al, services, msgBus); restartErr != nil { + logger.Errorf(" ⚠ Failed to restart services: %v", restartErr) + } + return fmt.Errorf("error reloading agent loop: %w", err) + } + + // Update local provider reference only after successful atomic reload + *providerRef = newProvider + + // Restart all services with new config + logger.Info(" Restarting all services with new configuration...") + if err := restartServices(al, services, msgBus); err != nil { + logger.Errorf(" ⚠ Error restarting services: %v", err) + return fmt.Errorf("error restarting services: %w", err) + } + + logger.Info(" ✓ Provider, configuration, and services reloaded successfully (thread-safe)") + return nil +} + +// restartServices restarts all services after a config reload +func restartServices( + al *agent.AgentLoop, + services *gatewayServices, + msgBus *bus.MessageBus, +) error { + // Create an independent context with timeout for service restart + // This prevents cancellation from the main loop context during reload + ctx, cancel := context.WithTimeout(context.Background(), serviceRestartTimeout) + defer cancel() + + // Get current config from agent loop (which has been updated if this is a reload) + cfg := al.GetConfig() + + // Re-create and start cron service with new config + execTimeout := time.Duration(cfg.Tools.Cron.ExecTimeoutMinutes) * time.Minute + services.CronService = setupCronTool( + al, + msgBus, + cfg.WorkspacePath(), + cfg.Agents.Defaults.RestrictToWorkspace, + execTimeout, + cfg, + ) + if err := services.CronService.Start(); err != nil { + return fmt.Errorf("error restarting cron service: %w", err) + } + fmt.Println(" ✓ Cron service restarted") + + // Re-create and start heartbeat service with new config + services.HeartbeatService = heartbeat.NewHeartbeatService( + cfg.WorkspacePath(), + cfg.Heartbeat.Interval, + cfg.Heartbeat.Enabled, + ) + services.HeartbeatService.SetBus(msgBus) + services.HeartbeatService.SetHandler(func(prompt, channel, chatID string) *tools.ToolResult { + if channel == "" || chatID == "" { + channel, chatID = "cli", "direct" + } + var response string + var err error + response, err = al.ProcessHeartbeat(context.Background(), prompt, channel, chatID) + if err != nil { + return tools.ErrorResult(fmt.Sprintf("Heartbeat error: %v", err)) + } + if response == "HEARTBEAT_OK" { + return tools.SilentResult("Heartbeat OK") + } + return tools.SilentResult(response) + }) + if err := services.HeartbeatService.Start(); err != nil { + return fmt.Errorf("error restarting heartbeat service: %w", err) + } + fmt.Println(" ✓ Heartbeat service restarted") + + // Stop the old media store before creating a new one + if fms, ok := services.MediaStore.(*media.FileMediaStore); ok { + fms.Stop() + } + + // Re-create media store with new config + services.MediaStore = media.NewFileMediaStoreWithCleanup(media.MediaCleanerConfig{ + Enabled: cfg.Tools.MediaCleanup.Enabled, + MaxAge: time.Duration(cfg.Tools.MediaCleanup.MaxAge) * time.Minute, + Interval: time.Duration(cfg.Tools.MediaCleanup.Interval) * time.Minute, + }) + // Start the media store if it's a FileMediaStore with cleanup + if fms, ok := services.MediaStore.(*media.FileMediaStore); ok { + fms.Start() + } + al.SetMediaStore(services.MediaStore) + + // Re-create channel manager with new config + var err error + services.ChannelManager, err = channels.NewManager(cfg, msgBus, services.MediaStore) + if err != nil { + // Stop the media store if it's a FileMediaStore with cleanup + if fms, ok := services.MediaStore.(*media.FileMediaStore); ok { + fms.Stop() + } + return fmt.Errorf("error recreating channel manager: %w", err) + } + al.SetChannelManager(services.ChannelManager) + + enabledChannels := services.ChannelManager.GetEnabledChannels() + if len(enabledChannels) > 0 { + fmt.Printf(" ✓ Channels enabled: %s\n", enabledChannels) + } else { + fmt.Println(" ⚠ Warning: No channels enabled") + } + + // Setup HTTP server with new config + addr := fmt.Sprintf("%s:%d", cfg.Gateway.Host, cfg.Gateway.Port) + services.HealthServer = health.NewServer(cfg.Gateway.Host, cfg.Gateway.Port) + services.ChannelManager.SetupHTTPServer(addr, services.HealthServer) + + if err := services.ChannelManager.StartAll(ctx); err != nil { + return fmt.Errorf("error restarting channels: %w", err) + } + fmt.Printf( + " ✓ Channels restarted, health endpoints at http://%s:%d/health and ready\n", + cfg.Gateway.Host, + cfg.Gateway.Port, + ) + + // Re-create device service with new config + stateManager := state.NewManager(cfg.WorkspacePath()) + services.DeviceService = devices.NewService(devices.Config{ + Enabled: cfg.Devices.Enabled, + MonitorUSB: cfg.Devices.MonitorUSB, + }, stateManager) + services.DeviceService.SetBus(msgBus) + if err := services.DeviceService.Start(ctx); err != nil { + logger.WarnCF("device", "Failed to restart device service", map[string]any{"error": err.Error()}) + } else if cfg.Devices.Enabled { + fmt.Println(" ✓ Device event service restarted") + } + + // Wire up voice transcription with new config + transcriber := voice.DetectTranscriber(cfg) + al.SetTranscriber(transcriber) // This will set it to nil if disabled + if transcriber != nil { + logger.InfoCF("voice", "Transcription re-enabled (agent-level)", map[string]any{"provider": transcriber.Name()}) + } else { + logger.InfoCF("voice", "Transcription disabled", nil) + } return nil } +// setupConfigWatcherPolling sets up a simple polling-based config file watcher +// Returns a channel for config updates and a stop function +func setupConfigWatcherPolling(configPath string, debug bool) (chan *config.Config, func()) { + configChan := make(chan *config.Config, 1) + stop := make(chan struct{}) + var wg sync.WaitGroup + + wg.Add(1) + go func() { + defer wg.Done() + + // Get initial file info + lastModTime := getFileModTime(configPath) + lastSize := getFileSize(configPath) + + ticker := time.NewTicker(2 * time.Second) // Check every 2 seconds + defer ticker.Stop() + + for { + select { + case <-ticker.C: + currentModTime := getFileModTime(configPath) + currentSize := getFileSize(configPath) + + // Check if file changed (modification time or size changed) + if currentModTime.After(lastModTime) || currentSize != lastSize { + if debug { + logger.Debugf("🔍 Config file change detected") + } + + // Debounce - wait a bit to ensure file write is complete + time.Sleep(500 * time.Millisecond) + + // Validate and load new config + newCfg, err := config.LoadConfig(configPath) + if err != nil { + logger.Errorf("⚠ Error loading new config: %v", err) + logger.Warn(" Using previous valid config") + continue + } + + // Validate the new config + if err := newCfg.ValidateModelList(); err != nil { + logger.Errorf(" ⚠ New config validation failed: %v", err) + logger.Warn(" Using previous valid config") + continue + } + + logger.Info("✓ Config file validated and loaded") + + // Update last known state + lastModTime = currentModTime + lastSize = currentSize + + // Send new config to main loop (non-blocking) + select { + case configChan <- newCfg: + default: + // Channel full, skip this update + logger.Warn("⚠ Previous config reload still in progress, skipping") + } + } + + case <-stop: + return + } + } + }() + + stopFunc := func() { + close(stop) + wg.Wait() + } + + return configChan, stopFunc +} + +// getFileModTime returns the modification time of a file, or zero time if file doesn't exist +func getFileModTime(path string) time.Time { + info, err := os.Stat(path) + if err != nil { + return time.Time{} + } + return info.ModTime() +} + +// getFileSize returns the size of a file, or 0 if file doesn't exist +func getFileSize(path string) int64 { + info, err := os.Stat(path) + if err != nil { + return 0 + } + return info.Size() +} + func setupCronTool( agentLoop *agent.AgentLoop, msgBus *bus.MessageBus, @@ -239,7 +625,7 @@ func setupCronTool( var err error cronTool, err = tools.NewCronTool(cronService, agentLoop, msgBus, workspace, restrict, execTimeout, cfg) if err != nil { - log.Fatalf("Critical error during CronTool initialization: %v", err) + logger.Fatalf("Critical error during CronTool initialization: %v", err) } agentLoop.RegisterTool(cronTool) diff --git a/pkg/agent/loop.go b/pkg/agent/loop.go index 28e549ce03..dfa339dee1 100644 --- a/pkg/agent/loop.go +++ b/pkg/agent/loop.go @@ -48,6 +48,9 @@ type AgentLoop struct { transcriber voice.Transcriber cmdRegistry *commands.Registry mcp mcpRuntime + mu sync.RWMutex + // Track active requests for safe provider cleanup + activeRequests sync.WaitGroup } // processOptions configures how a message is processed @@ -239,6 +242,7 @@ func registerSharedTools( func (al *AgentLoop) Run(ctx context.Context) error { al.running.Store(true) + if err := al.ensureMCPInitialized(ctx); err != nil { return err } @@ -278,7 +282,7 @@ func (al *AgentLoop) Run(ctx context.Context) error { // If so, skip publishing to avoid duplicate messages to the user. // Use default agent's tools to check (message tool is shared). alreadySent := false - defaultAgent := al.registry.GetDefaultAgent() + defaultAgent := al.GetRegistry().GetDefaultAgent() if defaultAgent != nil { if tool, ok := defaultAgent.Tools.Get("message"); ok { if mt, ok := tool.(*tools.MessageTool); ok { @@ -331,12 +335,13 @@ func (al *AgentLoop) Close() { } } - al.registry.Close() + al.GetRegistry().Close() } func (al *AgentLoop) RegisterTool(tool tools.Tool) { - for _, agentID := range al.registry.ListAgentIDs() { - if agent, ok := al.registry.GetAgent(agentID); ok { + registry := al.GetRegistry() + for _, agentID := range registry.ListAgentIDs() { + if agent, ok := registry.GetAgent(agentID); ok { agent.Tools.Register(tool) } } @@ -346,12 +351,123 @@ func (al *AgentLoop) SetChannelManager(cm *channels.Manager) { al.channelManager = cm } +// ReloadProviderAndConfig atomically swaps the provider and config with proper synchronization. +// It uses a context to allow timeout control from the caller. +// Returns an error if the reload fails or context is canceled. +func (al *AgentLoop) ReloadProviderAndConfig( + ctx context.Context, + provider providers.LLMProvider, + cfg *config.Config, +) error { + // Validate inputs + if provider == nil { + return fmt.Errorf("provider cannot be nil") + } + if cfg == nil { + return fmt.Errorf("config cannot be nil") + } + + // Create new registry with updated config and provider + // Wrap in defer/recover to handle any panics gracefully + var registry *AgentRegistry + var panicErr error + done := make(chan struct{}, 1) + + go func() { + defer func() { + if r := recover(); r != nil { + panicErr = fmt.Errorf("panic during registry creation: %v", r) + logger.ErrorCF("agent", "Panic during registry creation", + map[string]any{"panic": r}) + } + close(done) + }() + + registry = NewAgentRegistry(cfg, provider) + }() + + // Wait for completion or context cancellation + select { + case <-done: + if registry == nil { + if panicErr != nil { + return fmt.Errorf("registry creation failed: %w", panicErr) + } + return fmt.Errorf("registry creation failed (nil result)") + } + case <-ctx.Done(): + return fmt.Errorf("context canceled during registry creation: %w", ctx.Err()) + } + + // Check context again before proceeding + if err := ctx.Err(); err != nil { + return fmt.Errorf("context canceled after registry creation: %w", err) + } + + // Ensure shared tools are re-registered on the new registry + registerSharedTools(cfg, al.bus, registry, provider) + + // Atomically swap the config and registry under write lock + // This ensures readers see a consistent pair + al.mu.Lock() + oldRegistry := al.registry + + // Store new values + al.cfg = cfg + al.registry = registry + + // Also update fallback chain with new config + al.fallback = providers.NewFallbackChain(providers.NewCooldownTracker()) + + al.mu.Unlock() + + // Close old provider after releasing the lock + // This prevents blocking readers while closing + if oldProvider, ok := extractProvider(oldRegistry); ok { + if stateful, ok := oldProvider.(providers.StatefulProvider); ok { + // Give in-flight requests a moment to complete + // Use a reasonable timeout that balances cleanup vs resource usage + select { + case <-time.After(100 * time.Millisecond): + stateful.Close() + case <-ctx.Done(): + // Context canceled, close immediately but log warning + logger.WarnCF("agent", "Context canceled during provider cleanup, forcing close", + map[string]any{"error": ctx.Err()}) + stateful.Close() + } + } + } + + logger.InfoCF("agent", "Provider and config reloaded successfully", + map[string]any{ + "model": cfg.Agents.Defaults.GetModelName(), + }) + + return nil +} + +// GetRegistry returns the current registry (thread-safe) +func (al *AgentLoop) GetRegistry() *AgentRegistry { + al.mu.RLock() + defer al.mu.RUnlock() + return al.registry +} + +// GetConfig returns the current config (thread-safe) +func (al *AgentLoop) GetConfig() *config.Config { + al.mu.RLock() + defer al.mu.RUnlock() + return al.cfg +} + // SetMediaStore injects a MediaStore for media lifecycle management. func (al *AgentLoop) SetMediaStore(s media.MediaStore) { al.mediaStore = s // Propagate store to send_file tools in all agents. - al.registry.ForEachTool("send_file", func(t tools.Tool) { + registry := al.GetRegistry() + registry.ForEachTool("send_file", func(t tools.Tool) { if sf, ok := t.(*tools.SendFileTool); ok { sf.SetMediaStore(s) } @@ -540,7 +656,7 @@ func (al *AgentLoop) ProcessHeartbeat( ctx context.Context, content, channel, chatID string, ) (string, error) { - agent := al.registry.GetDefaultAgent() + agent := al.GetRegistry().GetDefaultAgent() if agent == nil { return "", fmt.Errorf("no default agent for heartbeat") } @@ -636,7 +752,8 @@ func (al *AgentLoop) processMessage(ctx context.Context, msg bus.InboundMessage) } func (al *AgentLoop) resolveMessageRoute(msg bus.InboundMessage) (routing.ResolvedRoute, *AgentInstance, error) { - route := al.registry.ResolveRoute(routing.RouteInput{ + registry := al.GetRegistry() + route := registry.ResolveRoute(routing.RouteInput{ Channel: msg.Channel, AccountID: inboundMetadata(msg, metadataKeyAccountID), Peer: extractPeer(msg), @@ -645,9 +762,9 @@ func (al *AgentLoop) resolveMessageRoute(msg bus.InboundMessage) (routing.Resolv TeamID: inboundMetadata(msg, metadataKeyTeamID), }) - agent, ok := al.registry.GetAgent(route.AgentID) + agent, ok := registry.GetAgent(route.AgentID) if !ok { - agent = al.registry.GetDefaultAgent() + agent = registry.GetDefaultAgent() } if agent == nil { return routing.ResolvedRoute{}, nil, fmt.Errorf("no agent available for route (agent_id=%s)", route.AgentID) @@ -709,7 +826,7 @@ func (al *AgentLoop) processSystemMessage( } // Use default agent for system messages - agent := al.registry.GetDefaultAgent() + agent := al.GetRegistry().GetDefaultAgent() if agent == nil { return "", fmt.Errorf("no default agent for system message") } @@ -765,7 +882,8 @@ func (al *AgentLoop) runAgentLoop( ) // Resolve media:// refs to base64 data URLs (streaming) - maxMediaSize := al.cfg.Agents.Defaults.GetMaxMediaSize() + cfg := al.GetConfig() + maxMediaSize := cfg.Agents.Defaults.GetMaxMediaSize() messages = resolveMediaRefs(messages, al.mediaStore, maxMediaSize) // 2. Save user message to session @@ -943,6 +1061,9 @@ func (al *AgentLoop) runLLMIteration( } callLLM := func() (*providers.LLMResponse, error) { + al.activeRequests.Add(1) + defer al.activeRequests.Done() + if len(activeCandidates) > 1 && al.fallback != nil { fbResult, fbErr := al.fallback.Execute( ctx, @@ -1041,6 +1162,7 @@ func (al *AgentLoop) runLLMIteration( map[string]any{ "agent_id": agent.ID, "iteration": iteration, + "model": activeModel, "error": err.Error(), }) return "", iteration, fmt.Errorf("LLM call failed after retries: %w", err) @@ -1392,7 +1514,8 @@ func (al *AgentLoop) forceCompression(agent *AgentInstance, sessionKey string) { func (al *AgentLoop) GetStartupInfo() map[string]any { info := make(map[string]any) - agent := al.registry.GetDefaultAgent() + registry := al.GetRegistry() + agent := registry.GetDefaultAgent() if agent == nil { return info } @@ -1409,8 +1532,8 @@ func (al *AgentLoop) GetStartupInfo() map[string]any { // Agents info info["agents"] = map[string]any{ - "count": len(al.registry.ListAgentIDs()), - "ids": al.registry.ListAgentIDs(), + "count": len(registry.ListAgentIDs()), + "ids": registry.ListAgentIDs(), } return info @@ -1598,17 +1721,22 @@ func (al *AgentLoop) retryLLMCall( var err error for attempt := 0; attempt < maxRetries; attempt++ { - resp, err = agent.Provider.Chat( - ctx, - []providers.Message{{Role: "user", Content: prompt}}, - nil, - agent.Model, - map[string]any{ - "max_tokens": agent.MaxTokens, - "temperature": llmTemperature, - "prompt_cache_key": agent.ID, - }, - ) + al.activeRequests.Add(1) + resp, err = func() (*providers.LLMResponse, error) { + defer al.activeRequests.Done() + return agent.Provider.Chat( + ctx, + []providers.Message{{Role: "user", Content: prompt}}, + nil, + agent.Model, + map[string]any{ + "max_tokens": agent.MaxTokens, + "temperature": llmTemperature, + "prompt_cache_key": agent.ID, + }, + ) + }() + if err == nil && resp != nil && resp.Content != "" { return resp, nil } @@ -1741,9 +1869,11 @@ func (al *AgentLoop) handleCommand( } func (al *AgentLoop) buildCommandsRuntime(agent *AgentInstance, opts *processOptions) *commands.Runtime { + registry := al.GetRegistry() + cfg := al.GetConfig() rt := &commands.Runtime{ - Config: al.cfg, - ListAgentIDs: al.registry.ListAgentIDs, + Config: cfg, + ListAgentIDs: registry.ListAgentIDs, ListDefinitions: al.cmdRegistry.Definitions, GetEnabledChannels: func() []string { if al.channelManager == nil { @@ -1763,7 +1893,7 @@ func (al *AgentLoop) buildCommandsRuntime(agent *AgentInstance, opts *processOpt } if agent != nil { rt.GetModelInfo = func() (string, string) { - return agent.Model, al.cfg.Agents.Defaults.Provider + return agent.Model, cfg.Agents.Defaults.Provider } rt.SwitchModel = func(value string) (string, error) { oldModel := agent.Model @@ -1827,3 +1957,16 @@ func extractParentPeer(msg bus.InboundMessage) *routing.RoutePeer { } return &routing.RoutePeer{Kind: parentKind, ID: parentID} } + +// Helper to extract provider from registry for cleanup +func extractProvider(registry *AgentRegistry) (providers.LLMProvider, bool) { + if registry == nil { + return nil, false + } + // Get any agent to access the provider + defaultAgent := registry.GetDefaultAgent() + if defaultAgent == nil { + return nil, false + } + return defaultAgent.Provider, true +} diff --git a/pkg/logger/logger.go b/pkg/logger/logger.go index 80adcf86cf..ef75612c39 100644 --- a/pkg/logger/logger.go +++ b/pkg/logger/logger.go @@ -194,6 +194,10 @@ func DebugC(component string, message string) { logMessage(DEBUG, component, message, nil) } +func Debugf(message string, ss ...any) { + logMessage(DEBUG, "", fmt.Sprintf(message, ss...), nil) +} + func DebugF(message string, fields map[string]any) { logMessage(DEBUG, "", message, fields) } @@ -214,6 +218,10 @@ func InfoF(message string, fields map[string]any) { logMessage(INFO, "", message, fields) } +func Infof(message string, ss ...any) { + logMessage(INFO, "", fmt.Sprintf(message, ss...), nil) +} + func InfoCF(component string, message string, fields map[string]any) { logMessage(INFO, component, message, fields) } @@ -242,6 +250,10 @@ func ErrorC(component string, message string) { logMessage(ERROR, component, message, nil) } +func Errorf(message string, ss ...any) { + logMessage(ERROR, "", fmt.Sprintf(message, ss...), nil) +} + func ErrorF(message string, fields map[string]any) { logMessage(ERROR, "", message, fields) } diff --git a/pkg/logger/logger_test.go b/pkg/logger/logger_test.go index 6e6f8dfa8d..8170a618ba 100644 --- a/pkg/logger/logger_test.go +++ b/pkg/logger/logger_test.go @@ -123,17 +123,21 @@ func TestLoggerHelperFunctions(t *testing.T) { SetLevel(INFO) Debug("This should not log") + Debugf("this should not log") Info("This should log") Warn("This should log") Error("This should log") InfoC("test", "Component message") InfoF("Fields message", map[string]any{"key": "value"}) + Infof("test from %v", "Infof") WarnC("test", "Warning with component") ErrorF("Error with fields", map[string]any{"error": "test"}) + Errorf("test from %v", "Errorf") SetLevel(DEBUG) DebugC("test", "Debug with component") + Debugf("test from %v", "Debugf") WarnF("Warning with fields", map[string]any{"key": "value"}) }