diff --git a/cmd/picoclaw/internal/onboard/command.go b/cmd/picoclaw/internal/onboard/command.go index ec10129594..843ace7e7e 100644 --- a/cmd/picoclaw/internal/onboard/command.go +++ b/cmd/picoclaw/internal/onboard/command.go @@ -1,15 +1,9 @@ package onboard import ( - "embed" - "github.com/spf13/cobra" ) -//go:generate cp -r ../../../../workspace . -//go:embed workspace -var embeddedFiles embed.FS - func NewOnboardCommand() *cobra.Command { cmd := &cobra.Command{ Use: "onboard", diff --git a/cmd/picoclaw/internal/onboard/helpers.go b/cmd/picoclaw/internal/onboard/helpers.go index 4db8bdc8ba..0c5324d628 100644 --- a/cmd/picoclaw/internal/onboard/helpers.go +++ b/cmd/picoclaw/internal/onboard/helpers.go @@ -6,6 +6,7 @@ import ( "os" "path/filepath" + picoclaw "github.com/sipeed/picoclaw" "github.com/sipeed/picoclaw/cmd/picoclaw/internal" "github.com/sipeed/picoclaw/pkg/config" ) @@ -47,20 +48,20 @@ func onboard() { } func createWorkspaceTemplates(workspace string) { - err := copyEmbeddedToTarget(workspace) + err := copyEmbeddedToTarget(picoclaw.EmbeddedWorkspace, workspace) if err != nil { fmt.Printf("Error copying workspace templates: %v\n", err) } } -func copyEmbeddedToTarget(targetDir string) error { +func copyEmbeddedToTarget(source fs.FS, targetDir string) error { // Ensure target directory exists if err := os.MkdirAll(targetDir, 0o755); err != nil { return fmt.Errorf("Failed to create target directory: %w", err) } // Walk through all files in embed.FS - err := fs.WalkDir(embeddedFiles, "workspace", func(path string, d fs.DirEntry, err error) error { + err := fs.WalkDir(source, "workspace", func(path string, d fs.DirEntry, err error) error { if err != nil { return err } @@ -71,7 +72,7 @@ func copyEmbeddedToTarget(targetDir string) error { } // Read embedded file - data, err := embeddedFiles.ReadFile(path) + data, err := fs.ReadFile(source, path) if err != nil { return fmt.Errorf("Failed to read embedded file %s: %w", path, err) } diff --git a/cmd/picoclaw/internal/onboard/helpers_test.go b/cmd/picoclaw/internal/onboard/helpers_test.go index f3e0c92e08..cf087c3993 100644 --- a/cmd/picoclaw/internal/onboard/helpers_test.go +++ b/cmd/picoclaw/internal/onboard/helpers_test.go @@ -4,12 +4,14 @@ import ( "os" "path/filepath" "testing" + + picoclaw "github.com/sipeed/picoclaw" ) func TestCopyEmbeddedToTargetUsesAgentsMarkdown(t *testing.T) { targetDir := t.TempDir() - if err := copyEmbeddedToTarget(targetDir); err != nil { + if err := copyEmbeddedToTarget(picoclaw.EmbeddedWorkspace, targetDir); err != nil { t.Fatalf("copyEmbeddedToTarget() error = %v", err) } diff --git a/embedded_workspace.go b/embedded_workspace.go new file mode 100644 index 0000000000..b70f5ecf0e --- /dev/null +++ b/embedded_workspace.go @@ -0,0 +1,8 @@ +package picoclaw + +import "embed" + +// EmbeddedWorkspace bundles the default workspace templates used by `picoclaw onboard`. +// +//go:embed workspace +var EmbeddedWorkspace embed.FS diff --git a/pkg/agent/loop.go b/pkg/agent/loop.go index f20a56b9c4..4a05a32194 100644 --- a/pkg/agent/loop.go +++ b/pkg/agent/loop.go @@ -68,7 +68,6 @@ type processOptions struct { const ( defaultResponse = "I've completed processing but have no response to give. Increase `max_tool_iterations` in config.json." - sessionKeyAgentPrefix = "agent:" metadataKeyAccountID = "account_id" metadataKeyGuildID = "guild_id" metadataKeyTeamID = "team_id" @@ -774,7 +773,7 @@ func (al *AgentLoop) resolveMessageRoute(msg bus.InboundMessage) (routing.Resolv } func resolveScopeKey(route routing.ResolvedRoute, msgSessionKey string) string { - if msgSessionKey != "" && strings.HasPrefix(msgSessionKey, sessionKeyAgentPrefix) { + if msgSessionKey != "" { return msgSessionKey } return route.SessionKey diff --git a/pkg/agent/loop_test.go b/pkg/agent/loop_test.go index a6604e87fd..0161273bf4 100644 --- a/pkg/agent/loop_test.go +++ b/pkg/agent/loop_test.go @@ -439,6 +439,59 @@ func TestProcessMessage_UsesRouteSessionKey(t *testing.T) { } } +func TestProcessDirectWithChannel_PreservesExplicitSessionKey(t *testing.T) { + tmpDir, err := os.MkdirTemp("", "agent-test-*") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + defer os.RemoveAll(tmpDir) + + cfg := &config.Config{ + Agents: config.AgentsConfig{ + Defaults: config.AgentDefaults{ + Workspace: tmpDir, + Model: "test-model", + MaxTokens: 4096, + MaxToolIterations: 10, + }, + }, + } + + msgBus := bus.NewMessageBus() + provider := &simpleMockProvider{response: "ok"} + al := NewAgentLoop(cfg, msgBus, provider) + + const sessionKey = "custom-session" + response, err := al.ProcessDirectWithChannel( + context.Background(), + "hello", + sessionKey, + "cli", + "direct", + ) + if err != nil { + t.Fatalf("ProcessDirectWithChannel() error = %v", err) + } + if response != "ok" { + t.Fatalf("response = %q, want %q", response, "ok") + } + + defaultAgent := al.registry.GetDefaultAgent() + if defaultAgent == nil { + t.Fatal("No default agent found") + } + history := defaultAgent.Sessions.GetHistory(sessionKey) + if len(history) != 2 { + t.Fatalf("history len = %d, want 2", len(history)) + } + if history[0].Role != "user" || history[0].Content != "hello" { + t.Fatalf("unexpected user message: %+v", history[0]) + } + if history[1].Role != "assistant" || history[1].Content != "ok" { + t.Fatalf("unexpected assistant message: %+v", history[1]) + } +} + func TestProcessMessage_CommandOutcomes(t *testing.T) { tmpDir, err := os.MkdirTemp("", "agent-test-*") if err != nil { diff --git a/pkg/config/config.go b/pkg/config/config.go index 1903412248..dccec4f940 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -958,7 +958,7 @@ func (c *Config) GetModelConfig(modelName string) (*ModelConfig, error) { } // Multiple configs - use round-robin for load balancing - idx := rrCounter.Add(1) % uint64(len(matches)) + idx := (rrCounter.Add(1) - 1) % uint64(len(matches)) return &matches[idx], nil } diff --git a/pkg/config/model_config_test.go b/pkg/config/model_config_test.go index da6e506f84..5bbb9d0f4b 100644 --- a/pkg/config/model_config_test.go +++ b/pkg/config/model_config_test.go @@ -54,6 +54,8 @@ func TestGetModelConfig_EmptyList(t *testing.T) { } func TestGetModelConfig_RoundRobin(t *testing.T) { + rrCounter.Store(0) + cfg := &Config{ ModelList: []ModelConfig{ {ModelName: "lb-model", Model: "openai/gpt-4o-1", APIKey: "key1"}, @@ -62,20 +64,22 @@ func TestGetModelConfig_RoundRobin(t *testing.T) { }, } - // Test round-robin distribution - results := make(map[string]int) - for range 30 { + want := []string{ + "openai/gpt-4o-1", + "openai/gpt-4o-2", + "openai/gpt-4o-3", + "openai/gpt-4o-1", + "openai/gpt-4o-2", + "openai/gpt-4o-3", + } + + for i, wantModel := range want { result, err := cfg.GetModelConfig("lb-model") if err != nil { t.Fatalf("GetModelConfig() error = %v", err) } - results[result.Model]++ - } - - // Each model should appear roughly 10 times (30 calls / 3 models) - for model, count := range results { - if count < 5 || count > 15 { - t.Errorf("Model %s appeared %d times, expected ~10", model, count) + if result.Model != wantModel { + t.Fatalf("call %d selected %q, want %q", i+1, result.Model, wantModel) } } } diff --git a/pkg/providers/openai_compat/provider.go b/pkg/providers/openai_compat/provider.go index f97bf3acd5..bef24bd18b 100644 --- a/pkg/providers/openai_compat/provider.go +++ b/pkg/providers/openai_compat/provider.go @@ -4,12 +4,15 @@ import ( "bufio" "bytes" "context" + "crypto/tls" + "crypto/x509" "encoding/json" "fmt" "io" "log" "net/http" "net/url" + "os" "strings" "time" @@ -54,22 +57,48 @@ func WithRequestTimeout(timeout time.Duration) Option { } } -func NewProvider(apiKey, apiBase, proxy string, opts ...Option) *Provider { - client := &http.Client{ - Timeout: defaultRequestTimeout, +// buildCertPool returns the system cert pool supplemented with CA bundles from +// well-known Termux paths. On Android/Termux, Go's x509.SystemCertPool returns +// an empty pool because it does not probe Termux-specific locations, causing +// TLS handshakes to fail with "certificate signed by unknown authority". +// InsecureSkipVerify is never set. +func buildCertPool() *x509.CertPool { + pool, err := x509.SystemCertPool() + if err != nil || pool == nil { + pool = x509.NewCertPool() + } + for _, p := range []string{ + "/data/data/com.termux/files/usr/etc/tls/cert.pem", + "/data/data/com.termux/files/usr/etc/ssl/certs/ca-bundle.crt", + "/data/data/com.termux/files/usr/etc/ssl/certs/ca-certificates.crt", + } { + if pem, e := os.ReadFile(p); e == nil { + pool.AppendCertsFromPEM(pem) + } } + return pool +} + +func NewProvider(apiKey, apiBase, proxy string, opts ...Option) *Provider { + // Clone preserves all http.DefaultTransport defaults (connection pooling, + // dial/TLS handshake timeouts, HTTP/2, env proxy). We only patch RootCAs. + transport := http.DefaultTransport.(*http.Transport).Clone() + transport.TLSClientConfig = &tls.Config{RootCAs: buildCertPool()} if proxy != "" { parsed, err := url.Parse(proxy) if err == nil { - client.Transport = &http.Transport{ - Proxy: http.ProxyURL(parsed), - } + transport.Proxy = http.ProxyURL(parsed) } else { log.Printf("openai_compat: invalid proxy URL %q: %v", proxy, err) } } + client := &http.Client{ + Timeout: defaultRequestTimeout, + Transport: transport, + } + p := &Provider{ apiKey: apiKey, apiBase: strings.TrimRight(apiBase, "/"), diff --git a/pkg/providers/openai_compat/provider_test.go b/pkg/providers/openai_compat/provider_test.go index 41f278a1b1..f4c1d50ee3 100644 --- a/pkg/providers/openai_compat/provider_test.go +++ b/pkg/providers/openai_compat/provider_test.go @@ -841,3 +841,31 @@ func TestSerializeMessages_StripsSystemParts(t *testing.T) { t.Fatal("system_parts should not appear in serialized output") } } + +// TestNewProvider_TLSTransportNeverInsecure ensures every provider — whether +// configured with a proxy or not — has an explicit TLS transport that never +// sets InsecureSkipVerify. This is the security guard for issue #1375. +func TestNewProvider_TLSTransportNeverInsecure(t *testing.T) { + tests := []struct { + name string + proxy string + }{ + {name: "no proxy", proxy: ""}, + {name: "with proxy", proxy: "http://127.0.0.1:8080"}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + p := NewProvider("key", "https://example.com", tt.proxy) + tr, ok := p.httpClient.Transport.(*http.Transport) + if !ok || tr == nil { + t.Fatalf("Transport = %T, want *http.Transport", p.httpClient.Transport) + } + if tr.TLSClientConfig == nil { + t.Fatal("TLSClientConfig is nil") + } + if tr.TLSClientConfig.InsecureSkipVerify { + t.Fatal("InsecureSkipVerify must never be true") + } + }) + } +}