diff --git a/cmd/picoclaw/internal/gateway/helpers.go b/cmd/picoclaw/internal/gateway/helpers.go index 174f5db627..4ffacacee4 100644 --- a/cmd/picoclaw/internal/gateway/helpers.go +++ b/cmd/picoclaw/internal/gateway/helpers.go @@ -3,12 +3,14 @@ package gateway import ( "context" "fmt" - "log" "os" "os/signal" "path/filepath" + "sync" "time" + "github.com/fsnotify/fsnotify" + "github.com/sipeed/picoclaw/cmd/picoclaw/internal" "github.com/sipeed/picoclaw/pkg/agent" "github.com/sipeed/picoclaw/pkg/bus" @@ -45,6 +47,7 @@ func gatewayCmd(debug bool) error { fmt.Println("🔍 Debug mode enabled") } + configPath := internal.GetConfigPath() cfg, err := internal.LoadConfig() if err != nil { return fmt.Errorf("error loading config: %w", err) @@ -190,31 +193,167 @@ func gatewayCmd(debug bool) error { go agentLoop.Run(ctx) + // Setup config file watcher for hot reload + configWatcher, configReloadChan, watchErr := setupConfigWatcher(configPath, debug) + if watchErr != nil { + logger.Errorf("⚠ Warning: Could not start config file watcher: %v", watchErr) + logger.Warn(" Config changes will require manual restart") + } else { + logger.Info("✓ Config file watcher started (auto-reload on change)") + defer configWatcher.Close() + } + sigChan := make(chan os.Signal, 1) signal.Notify(sigChan, os.Interrupt) - <-sigChan - fmt.Println("\nShutting down...") - if cp, ok := provider.(providers.StatefulProvider); ok { - cp.Close() + // Main event loop - wait for signals or config changes + for { + select { + case <-sigChan: + logger.Info("Shutting down...") + if cp, ok := provider.(providers.StatefulProvider); ok { + cp.Close() + } + cancel() + msgBus.Close() + + // Use a fresh context with timeout for graceful shutdown, + // since the original ctx is already canceled. + shutdownCtx, shutdownCancel := context.WithTimeout(context.Background(), 15*time.Second) + defer shutdownCancel() + + channelManager.StopAll(shutdownCtx) + deviceService.Stop() + heartbeatService.Stop() + cronService.Stop() + mediaStore.Stop() + agentLoop.Stop() + logger.Info("✓ Gateway stopped") + + return nil + + case newCfg := <-configReloadChan: + logger.Info("🔄 Config file changed, reloading...") + + newModel := newCfg.Agents.Defaults.ModelName + if newModel == "" { + newModel = newCfg.Agents.Defaults.Model + } + + logger.Infof(" New model is '%s', recreating provider...", newModel) + if cp, ok := provider.(providers.StatefulProvider); ok { + cp.Close() + } + + // Create new provider from updated config + // This will use the correct API key and settings from newCfg.ModelList + newProvider, newModelID, err := providers.CreateProvider(newCfg) + if err != nil { + logger.Errorf(" ⚠ Error creating new provider: %v", err) + logger.Warn(" Continuing with old provider") + continue + } + + provider = newProvider + if newModelID != "" { + newCfg.Agents.Defaults.ModelName = newModelID + } + + // Update agent loop provider and models + agentLoop.SetProvider(provider, newCfg) + + logger.Info(" ✓ Provider and agents updated successfully") + + // Update the config reference for other operations + // Note: Some changes (like channel configs) may require restart to take full effect + cfg = newCfg + logger.Info(" ✓ Configuration reloaded successfully") + } + } +} + +// setupConfigWatcher sets up a file watcher for the config file +// Returns the watcher, a channel for config updates, and any error +func setupConfigWatcher(configPath string, debug bool) (*fsnotify.Watcher, chan *config.Config, error) { + watcher, err := fsnotify.NewWatcher() + if err != nil { + return nil, nil, err } - cancel() - msgBus.Close() - // Use a fresh context with timeout for graceful shutdown, - // since the original ctx is already canceled. - shutdownCtx, shutdownCancel := context.WithTimeout(context.Background(), 15*time.Second) - defer shutdownCancel() + configDir := filepath.Dir(configPath) + if err := watcher.Add(configDir); err != nil { + watcher.Close() + return nil, nil, err + } - channelManager.StopAll(shutdownCtx) - deviceService.Stop() - heartbeatService.Stop() - cronService.Stop() - mediaStore.Stop() - agentLoop.Stop() - fmt.Println("✓ Gateway stopped") + configChan := make(chan *config.Config, 1) + var mu sync.Mutex + var debounceTimer *time.Timer + + go func() { + for { + select { + case event, ok := <-watcher.Events: + if !ok { + return + } + + // Only process config.json changes + if event.Name != configPath { + continue + } + + // Debounce rapid file changes (some editors write multiple times) + mu.Lock() + if debounceTimer != nil { + debounceTimer.Stop() + } + + debounceTimer = time.AfterFunc(500*time.Millisecond, func() { + mu.Unlock() + + if debug { + logger.DebugSF(" 🔍 Config file event: %v", event) + } + + // Validate and load new config + newCfg, err := config.LoadConfig(configPath) + if err != nil { + logger.Errorf(" ⚠ Error loading new config: %v", err) + logger.Warn(" Using previous valid config") + return + } + + // Validate the new config + if err := newCfg.ValidateModelList(); err != nil { + logger.Errorf(" ⚠ New config validation failed: %v", err) + logger.Warn(" Using previous valid config") + return + } + + logger.Info(" ✓ Config file validated and loaded") + + // Send new config to main loop (non-blocking) + select { + case configChan <- newCfg: + default: + // Channel full, skip this update + logger.Warn(" ⚠ Previous config reload still in progress, skipping") + } + }) + mu.Lock() // Keep lock until timer is set + mu.Unlock() + + case err, ok := <-watcher.Errors: + if !ok { + return + } + logger.Errorf(" ⚠ Config watcher error: %v", err) + } + } + }() - return nil + return watcher, configChan, nil } func setupCronTool( @@ -227,16 +366,13 @@ func setupCronTool( ) *cron.CronService { cronStorePath := filepath.Join(workspace, "cron", "jobs.json") - // Create cron service - cronService := cron.NewCronService(cronStorePath, nil) - // Create and register CronTool if enabled var cronTool *tools.CronTool if cfg.Tools.IsToolEnabled("cron") { var err error cronTool, err = tools.NewCronTool(cronService, agentLoop, msgBus, workspace, restrict, execTimeout, cfg) if err != nil { - log.Fatalf("Critical error during CronTool initialization: %v", err) + logger.Fatalf("Critical error during CronTool initialization: %v", err) } agentLoop.RegisterTool(cronTool) diff --git a/go.mod b/go.mod index 238bd405cf..ce1ed298e4 100644 --- a/go.mod +++ b/go.mod @@ -8,9 +8,11 @@ require ( github.com/bwmarrin/discordgo v0.29.0 github.com/caarlos0/env/v11 v11.3.1 github.com/chzyer/readline v1.5.1 + github.com/fsnotify/fsnotify v1.4.9 github.com/gdamore/tcell/v2 v2.13.8 github.com/google/uuid v1.6.0 github.com/gorilla/websocket v1.5.3 + github.com/h2non/filetype v1.1.3 github.com/larksuite/oapi-sdk-go/v3 v3.5.3 github.com/mdp/qrterminal/v3 v3.2.1 github.com/modelcontextprotocol/go-sdk v1.3.0 @@ -37,7 +39,6 @@ require ( github.com/dustin/go-humanize v1.0.1 // indirect github.com/elliotchance/orderedmap/v3 v3.1.0 // indirect github.com/gdamore/encoding v1.0.1 // indirect - github.com/h2non/filetype v1.1.3 // indirect github.com/inconshreveable/mousetrap v1.1.0 // indirect github.com/lucasb-eyer/go-colorful v1.3.0 // indirect github.com/mattn/go-colorable v0.1.14 // indirect diff --git a/go.sum b/go.sum index 060594d06b..0a8915b457 100644 --- a/go.sum +++ b/go.sum @@ -49,6 +49,7 @@ github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+m github.com/elliotchance/orderedmap/v3 v3.1.0 h1:j4DJ5ObEmMBt/lcwIecKcoRxIQUEnw0L804lXYDt/pg= github.com/elliotchance/orderedmap/v3 v3.1.0/go.mod h1:G+Hc2RwaZvJMcS4JpGCOyViCnGeKf0bTYCGTO4uhjSo= github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo= +github.com/fsnotify/fsnotify v1.4.9 h1:hsms1Qyu0jgnwNXIxa+/V/PDsU6CfLf6CNO8H7IWoS4= github.com/fsnotify/fsnotify v1.4.9/go.mod h1:znqG4EE+3YCdAaPaxE2ZRY/06pZUdp0tY4IgpuI1SZQ= github.com/gdamore/encoding v1.0.1 h1:YzKZckdBL6jVt2Gc+5p82qhrGiqMdG/eNs6Wy0u3Uhw= github.com/gdamore/encoding v1.0.1/go.mod h1:0Z0cMFinngz9kS1QfMjCP8TY7em3bZYeeklsSDPivEo= diff --git a/pkg/agent/loop.go b/pkg/agent/loop.go index 685b346e69..e5e7511725 100644 --- a/pkg/agent/loop.go +++ b/pkg/agent/loop.go @@ -368,6 +368,14 @@ func (al *AgentLoop) SetChannelManager(cm *channels.Manager) { al.channelManager = cm } +// SetProvider updates the LLM provider for all agents in the registry +// and updates their model configurations. +func (al *AgentLoop) SetProvider(provider providers.LLMProvider, cfg *config.Config) { + al.cfg = cfg + registry := NewAgentRegistry(cfg, provider) + al.registry = registry +} + // SetMediaStore injects a MediaStore for media lifecycle management. func (al *AgentLoop) SetMediaStore(s media.MediaStore) { al.mediaStore = s diff --git a/pkg/logger/logger.go b/pkg/logger/logger.go index 56dc87a536..08c0ea6ff5 100644 --- a/pkg/logger/logger.go +++ b/pkg/logger/logger.go @@ -168,6 +168,10 @@ func DebugC(component string, message string) { logMessage(DEBUG, component, message, nil) } +func DebugSF(message string, ss ...any) { + logMessage(DEBUG, "", fmt.Sprintf(message, ss...), nil) +} + func DebugF(message string, fields map[string]any) { logMessage(DEBUG, "", message, fields) } @@ -188,6 +192,10 @@ func InfoF(message string, fields map[string]any) { logMessage(INFO, "", message, fields) } +func Infof(message string, ss ...any) { + logMessage(INFO, "", fmt.Sprintf(message, ss...), nil) +} + func InfoCF(component string, message string, fields map[string]any) { logMessage(INFO, component, message, fields) } @@ -216,6 +224,10 @@ func ErrorC(component string, message string) { logMessage(ERROR, component, message, nil) } +func Errorf(message string, ss ...any) { + logMessage(ERROR, "", fmt.Sprintf(message, ss...), nil) +} + func ErrorF(message string, fields map[string]any) { logMessage(ERROR, "", message, fields) } @@ -232,6 +244,10 @@ func FatalC(component string, message string) { logMessage(FATAL, component, message, nil) } +func Fatalf(message string, ss ...any) { + logMessage(FATAL, "", fmt.Sprintf(message, ss...), nil) +} + func FatalF(message string, fields map[string]any) { logMessage(FATAL, "", message, fields) }