Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
182 changes: 162 additions & 20 deletions cmd/picoclaw/internal/gateway/helpers.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,10 @@
import (
"context"
"fmt"
"log"
"os"
"os/signal"
"path/filepath"
"sync"
"time"

"github.com/sipeed/picoclaw/cmd/picoclaw/internal"
Expand Down Expand Up @@ -45,6 +45,7 @@
fmt.Println("πŸ” Debug mode enabled")
}

configPath := internal.GetConfigPath()
cfg, err := internal.LoadConfig()
if err != nil {
return fmt.Errorf("error loading config: %w", err)
Expand Down Expand Up @@ -190,31 +191,172 @@

go agentLoop.Run(ctx)

// Setup config file watcher for hot reload
configReloadChan, stopWatch := setupConfigWatcherPolling(configPath, debug)
defer stopWatch()

sigChan := make(chan os.Signal, 1)
signal.Notify(sigChan, os.Interrupt)
<-sigChan

fmt.Println("\nShutting down...")
if cp, ok := provider.(providers.StatefulProvider); ok {
cp.Close()
// Main event loop - wait for signals or config changes
for {
select {
case <-sigChan:
logger.Info("Shutting down...")
if cp, ok := provider.(providers.StatefulProvider); ok {
cp.Close()
}
cancel()
msgBus.Close()

// Use a fresh context with timeout for graceful shutdown,
// since the original ctx is already canceled.
shutdownCtx, shutdownCancel := context.WithTimeout(context.Background(), 15*time.Second)
defer shutdownCancel()

channelManager.StopAll(shutdownCtx)
deviceService.Stop()
heartbeatService.Stop()
cronService.Stop()
mediaStore.Stop()
agentLoop.Stop()
logger.Info("βœ“ Gateway stopped")

return nil

case newCfg := <-configReloadChan:
logger.Info("πŸ”„ Config file changed, reloading...")

newModel := newCfg.Agents.Defaults.ModelName
if newModel == "" {
newModel = newCfg.Agents.Defaults.Model
}

logger.Infof(" New model is '%s', recreating provider...", newModel)
if cp, ok := provider.(providers.StatefulProvider); ok {
cp.Close()
}

// Create new provider from updated config
// This will use the correct API key and settings from newCfg.ModelList
newProvider, newModelID, err := providers.CreateProvider(newCfg)
if err != nil {
logger.Errorf(" ⚠ Error creating new provider: %v", err)
logger.Warn(" Continuing with old provider")
continue
}

provider = newProvider
if newModelID != "" {
newCfg.Agents.Defaults.ModelName = newModelID
}

// Update agent loop provider and models
agentLoop.SetProvider(provider, newCfg)

logger.Info(" βœ“ Provider and agents updated successfully")

// Update the config reference for other operations
// Note: Some changes (like channel configs) may require restart to take full effect
cfg = newCfg

Check failure on line 261 in cmd/picoclaw/internal/gateway/helpers.go

View workflow job for this annotation

GitHub Actions / Linter

assigned to cfg, but reassigned without using the value (wastedassign)
logger.Info(" βœ“ Configuration reloaded successfully")
}
}
cancel()
msgBus.Close()
}

// Use a fresh context with timeout for graceful shutdown,
// since the original ctx is already canceled.
shutdownCtx, shutdownCancel := context.WithTimeout(context.Background(), 15*time.Second)
defer shutdownCancel()
// setupConfigWatcherPolling sets up a simple polling-based config file watcher
// Returns a channel for config updates and a stop function
func setupConfigWatcherPolling(configPath string, debug bool) (chan *config.Config, func()) {
configChan := make(chan *config.Config, 1)
stop := make(chan struct{})
var wg sync.WaitGroup

wg.Add(1)
go func() {
defer wg.Done()

// Get initial file info
lastModTime := getFileModTime(configPath)
lastSize := getFileSize(configPath)

ticker := time.NewTicker(2 * time.Second) // Check every 2 seconds
defer ticker.Stop()

for {
select {
case <-ticker.C:
currentModTime := getFileModTime(configPath)
currentSize := getFileSize(configPath)

// Check if file changed (modification time or size changed)
if currentModTime.After(lastModTime) || currentSize != lastSize {
if debug {
logger.DebugSF("πŸ” Config file change detected")
}

// Debounce - wait a bit to ensure file write is complete
time.Sleep(500 * time.Millisecond)

// Validate and load new config
newCfg, err := config.LoadConfig(configPath)
if err != nil {
logger.Errorf("⚠ Error loading new config: %v", err)
logger.Warn(" Using previous valid config")
continue
}

// Validate the new config
if err := newCfg.ValidateModelList(); err != nil {
logger.Errorf(" ⚠ New config validation failed: %v", err)
logger.Warn(" Using previous valid config")
continue
}

logger.Info("βœ“ Config file validated and loaded")

// Update last known state
lastModTime = currentModTime
lastSize = currentSize

// Send new config to main loop (non-blocking)
select {
case configChan <- newCfg:
default:
// Channel full, skip this update
logger.Warn("⚠ Previous config reload still in progress, skipping")
}
}

case <-stop:
return
}
}
}()

channelManager.StopAll(shutdownCtx)
deviceService.Stop()
heartbeatService.Stop()
cronService.Stop()
mediaStore.Stop()
agentLoop.Stop()
fmt.Println("βœ“ Gateway stopped")
stopFunc := func() {
close(stop)
wg.Wait()
}

return configChan, stopFunc
}

return nil
// getFileModTime returns the modification time of a file, or zero time if file doesn't exist
func getFileModTime(path string) time.Time {
info, err := os.Stat(path)
if err != nil {
return time.Time{}
}
return info.ModTime()
}

// getFileSize returns the size of a file, or 0 if file doesn't exist
func getFileSize(path string) int64 {
info, err := os.Stat(path)
if err != nil {
return 0
}
return info.Size()
}

func setupCronTool(
Expand All @@ -236,7 +378,7 @@
var err error
cronTool, err = tools.NewCronTool(cronService, agentLoop, msgBus, workspace, restrict, execTimeout, cfg)
if err != nil {
log.Fatalf("Critical error during CronTool initialization: %v", err)
logger.Fatalf("Critical error during CronTool initialization: %v", err)
}

agentLoop.RegisterTool(cronTool)
Expand Down
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ require (
github.com/gdamore/tcell/v2 v2.13.8
github.com/google/uuid v1.6.0
github.com/gorilla/websocket v1.5.3
github.com/h2non/filetype v1.1.3
github.com/larksuite/oapi-sdk-go/v3 v3.5.3
github.com/mdp/qrterminal/v3 v3.2.1
github.com/modelcontextprotocol/go-sdk v1.3.0
Expand All @@ -37,7 +38,6 @@ require (
github.com/dustin/go-humanize v1.0.1 // indirect
github.com/elliotchance/orderedmap/v3 v3.1.0 // indirect
github.com/gdamore/encoding v1.0.1 // indirect
github.com/h2non/filetype v1.1.3 // indirect
github.com/inconshreveable/mousetrap v1.1.0 // indirect
github.com/lucasb-eyer/go-colorful v1.3.0 // indirect
github.com/mattn/go-colorable v0.1.14 // indirect
Expand Down
8 changes: 8 additions & 0 deletions pkg/agent/loop.go
Original file line number Diff line number Diff line change
Expand Up @@ -368,6 +368,14 @@ func (al *AgentLoop) SetChannelManager(cm *channels.Manager) {
al.channelManager = cm
}

// SetProvider updates the LLM provider for all agents in the registry
// and updates their model configurations.
func (al *AgentLoop) SetProvider(provider providers.LLMProvider, cfg *config.Config) {
al.cfg = cfg
registry := NewAgentRegistry(cfg, provider)
al.registry = registry
}

// SetMediaStore injects a MediaStore for media lifecycle management.
func (al *AgentLoop) SetMediaStore(s media.MediaStore) {
al.mediaStore = s
Expand Down
16 changes: 16 additions & 0 deletions pkg/logger/logger.go
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,10 @@ func DebugC(component string, message string) {
logMessage(DEBUG, component, message, nil)
}

func DebugSF(message string, ss ...any) {
logMessage(DEBUG, "", fmt.Sprintf(message, ss...), nil)
}

func DebugF(message string, fields map[string]any) {
logMessage(DEBUG, "", message, fields)
}
Expand All @@ -188,6 +192,10 @@ func InfoF(message string, fields map[string]any) {
logMessage(INFO, "", message, fields)
}

func Infof(message string, ss ...any) {
logMessage(INFO, "", fmt.Sprintf(message, ss...), nil)
}

func InfoCF(component string, message string, fields map[string]any) {
logMessage(INFO, component, message, fields)
}
Expand Down Expand Up @@ -216,6 +224,10 @@ func ErrorC(component string, message string) {
logMessage(ERROR, component, message, nil)
}

func Errorf(message string, ss ...any) {
logMessage(ERROR, "", fmt.Sprintf(message, ss...), nil)
}

func ErrorF(message string, fields map[string]any) {
logMessage(ERROR, "", message, fields)
}
Expand All @@ -232,6 +244,10 @@ func FatalC(component string, message string) {
logMessage(FATAL, component, message, nil)
}

func Fatalf(message string, ss ...any) {
logMessage(FATAL, "", fmt.Sprintf(message, ss...), nil)
}

func FatalF(message string, fields map[string]any) {
logMessage(FATAL, "", message, fields)
}
Expand Down
Loading