Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 0 additions & 6 deletions cmd/picoclaw/internal/onboard/command.go
Original file line number Diff line number Diff line change
@@ -1,15 +1,9 @@
package onboard

import (
"embed"

"github.com/spf13/cobra"
)

//go:generate cp -r ../../../../workspace .
//go:embed workspace
var embeddedFiles embed.FS

func NewOnboardCommand() *cobra.Command {
cmd := &cobra.Command{
Use: "onboard",
Expand Down
9 changes: 5 additions & 4 deletions cmd/picoclaw/internal/onboard/helpers.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"os"
"path/filepath"

picoclaw "github.com/sipeed/picoclaw"
"github.com/sipeed/picoclaw/cmd/picoclaw/internal"
"github.com/sipeed/picoclaw/pkg/config"
)
Expand Down Expand Up @@ -47,20 +48,20 @@ func onboard() {
}

func createWorkspaceTemplates(workspace string) {
err := copyEmbeddedToTarget(workspace)
err := copyEmbeddedToTarget(picoclaw.EmbeddedWorkspace, workspace)
if err != nil {
fmt.Printf("Error copying workspace templates: %v\n", err)
}
}

func copyEmbeddedToTarget(targetDir string) error {
func copyEmbeddedToTarget(source fs.FS, targetDir string) error {
// Ensure target directory exists
if err := os.MkdirAll(targetDir, 0o755); err != nil {
return fmt.Errorf("Failed to create target directory: %w", err)
}

// Walk through all files in embed.FS
err := fs.WalkDir(embeddedFiles, "workspace", func(path string, d fs.DirEntry, err error) error {
err := fs.WalkDir(source, "workspace", func(path string, d fs.DirEntry, err error) error {
if err != nil {
return err
}
Expand All @@ -71,7 +72,7 @@ func copyEmbeddedToTarget(targetDir string) error {
}

// Read embedded file
data, err := embeddedFiles.ReadFile(path)
data, err := fs.ReadFile(source, path)
if err != nil {
return fmt.Errorf("Failed to read embedded file %s: %w", path, err)
}
Expand Down
4 changes: 3 additions & 1 deletion cmd/picoclaw/internal/onboard/helpers_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,14 @@ import (
"os"
"path/filepath"
"testing"

picoclaw "github.com/sipeed/picoclaw"
)

func TestCopyEmbeddedToTargetUsesAgentsMarkdown(t *testing.T) {
targetDir := t.TempDir()

if err := copyEmbeddedToTarget(targetDir); err != nil {
if err := copyEmbeddedToTarget(picoclaw.EmbeddedWorkspace, targetDir); err != nil {
t.Fatalf("copyEmbeddedToTarget() error = %v", err)
}

Expand Down
8 changes: 8 additions & 0 deletions embedded_workspace.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
package picoclaw

import "embed"

// EmbeddedWorkspace bundles the default workspace templates used by `picoclaw onboard`.
//
//go:embed workspace
var EmbeddedWorkspace embed.FS
3 changes: 1 addition & 2 deletions pkg/agent/loop.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,6 @@ type processOptions struct {

const (
defaultResponse = "I've completed processing but have no response to give. Increase `max_tool_iterations` in config.json."
sessionKeyAgentPrefix = "agent:"
metadataKeyAccountID = "account_id"
metadataKeyGuildID = "guild_id"
metadataKeyTeamID = "team_id"
Expand Down Expand Up @@ -774,7 +773,7 @@ func (al *AgentLoop) resolveMessageRoute(msg bus.InboundMessage) (routing.Resolv
}

func resolveScopeKey(route routing.ResolvedRoute, msgSessionKey string) string {
if msgSessionKey != "" && strings.HasPrefix(msgSessionKey, sessionKeyAgentPrefix) {
if msgSessionKey != "" {
return msgSessionKey
}
return route.SessionKey
Expand Down
53 changes: 53 additions & 0 deletions pkg/agent/loop_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -439,6 +439,59 @@ func TestProcessMessage_UsesRouteSessionKey(t *testing.T) {
}
}

func TestProcessDirectWithChannel_PreservesExplicitSessionKey(t *testing.T) {
tmpDir, err := os.MkdirTemp("", "agent-test-*")
if err != nil {
t.Fatalf("Failed to create temp dir: %v", err)
}
defer os.RemoveAll(tmpDir)

cfg := &config.Config{
Agents: config.AgentsConfig{
Defaults: config.AgentDefaults{
Workspace: tmpDir,
Model: "test-model",
MaxTokens: 4096,
MaxToolIterations: 10,
},
},
}

msgBus := bus.NewMessageBus()
provider := &simpleMockProvider{response: "ok"}
al := NewAgentLoop(cfg, msgBus, provider)

const sessionKey = "custom-session"
response, err := al.ProcessDirectWithChannel(
context.Background(),
"hello",
sessionKey,
"cli",
"direct",
)
if err != nil {
t.Fatalf("ProcessDirectWithChannel() error = %v", err)
}
if response != "ok" {
t.Fatalf("response = %q, want %q", response, "ok")
}

defaultAgent := al.registry.GetDefaultAgent()
if defaultAgent == nil {
t.Fatal("No default agent found")
}
history := defaultAgent.Sessions.GetHistory(sessionKey)
if len(history) != 2 {
t.Fatalf("history len = %d, want 2", len(history))
}
if history[0].Role != "user" || history[0].Content != "hello" {
t.Fatalf("unexpected user message: %+v", history[0])
}
if history[1].Role != "assistant" || history[1].Content != "ok" {
t.Fatalf("unexpected assistant message: %+v", history[1])
}
}

func TestProcessMessage_CommandOutcomes(t *testing.T) {
tmpDir, err := os.MkdirTemp("", "agent-test-*")
if err != nil {
Expand Down
2 changes: 1 addition & 1 deletion pkg/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -958,7 +958,7 @@ func (c *Config) GetModelConfig(modelName string) (*ModelConfig, error) {
}

// Multiple configs - use round-robin for load balancing
idx := rrCounter.Add(1) % uint64(len(matches))
idx := (rrCounter.Add(1) - 1) % uint64(len(matches))
return &matches[idx], nil
}

Expand Down
24 changes: 14 additions & 10 deletions pkg/config/model_config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,8 @@ func TestGetModelConfig_EmptyList(t *testing.T) {
}

func TestGetModelConfig_RoundRobin(t *testing.T) {
rrCounter.Store(0)

cfg := &Config{
ModelList: []ModelConfig{
{ModelName: "lb-model", Model: "openai/gpt-4o-1", APIKey: "key1"},
Expand All @@ -62,20 +64,22 @@ func TestGetModelConfig_RoundRobin(t *testing.T) {
},
}

// Test round-robin distribution
results := make(map[string]int)
for range 30 {
want := []string{
"openai/gpt-4o-1",
"openai/gpt-4o-2",
"openai/gpt-4o-3",
"openai/gpt-4o-1",
"openai/gpt-4o-2",
"openai/gpt-4o-3",
}

for i, wantModel := range want {
result, err := cfg.GetModelConfig("lb-model")
if err != nil {
t.Fatalf("GetModelConfig() error = %v", err)
}
results[result.Model]++
}

// Each model should appear roughly 10 times (30 calls / 3 models)
for model, count := range results {
if count < 5 || count > 15 {
t.Errorf("Model %s appeared %d times, expected ~10", model, count)
if result.Model != wantModel {
t.Fatalf("call %d selected %q, want %q", i+1, result.Model, wantModel)
}
}
}
Expand Down
41 changes: 35 additions & 6 deletions pkg/providers/openai_compat/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,15 @@ import (
"bufio"
"bytes"
"context"
"crypto/tls"
"crypto/x509"
"encoding/json"
"fmt"
"io"
"log"
"net/http"
"net/url"
"os"
"strings"
"time"

Expand Down Expand Up @@ -54,22 +57,48 @@ func WithRequestTimeout(timeout time.Duration) Option {
}
}

func NewProvider(apiKey, apiBase, proxy string, opts ...Option) *Provider {
client := &http.Client{
Timeout: defaultRequestTimeout,
// buildCertPool returns the system cert pool supplemented with CA bundles from
// well-known Termux paths. On Android/Termux, Go's x509.SystemCertPool returns
// an empty pool because it does not probe Termux-specific locations, causing
// TLS handshakes to fail with "certificate signed by unknown authority".
// InsecureSkipVerify is never set.
func buildCertPool() *x509.CertPool {
pool, err := x509.SystemCertPool()
if err != nil || pool == nil {
pool = x509.NewCertPool()
}
for _, p := range []string{
"/data/data/com.termux/files/usr/etc/tls/cert.pem",
"/data/data/com.termux/files/usr/etc/ssl/certs/ca-bundle.crt",
"/data/data/com.termux/files/usr/etc/ssl/certs/ca-certificates.crt",
} {
if pem, e := os.ReadFile(p); e == nil {
pool.AppendCertsFromPEM(pem)
}
}
return pool
}

func NewProvider(apiKey, apiBase, proxy string, opts ...Option) *Provider {
// Clone preserves all http.DefaultTransport defaults (connection pooling,
// dial/TLS handshake timeouts, HTTP/2, env proxy). We only patch RootCAs.
transport := http.DefaultTransport.(*http.Transport).Clone()
transport.TLSClientConfig = &tls.Config{RootCAs: buildCertPool()}

if proxy != "" {
parsed, err := url.Parse(proxy)
if err == nil {
client.Transport = &http.Transport{
Proxy: http.ProxyURL(parsed),
}
transport.Proxy = http.ProxyURL(parsed)
} else {
log.Printf("openai_compat: invalid proxy URL %q: %v", proxy, err)
}
}

client := &http.Client{
Timeout: defaultRequestTimeout,
Transport: transport,
}

p := &Provider{
apiKey: apiKey,
apiBase: strings.TrimRight(apiBase, "/"),
Expand Down
28 changes: 28 additions & 0 deletions pkg/providers/openai_compat/provider_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -841,3 +841,31 @@ func TestSerializeMessages_StripsSystemParts(t *testing.T) {
t.Fatal("system_parts should not appear in serialized output")
}
}

// TestNewProvider_TLSTransportNeverInsecure ensures every provider — whether
// configured with a proxy or not — has an explicit TLS transport that never
// sets InsecureSkipVerify. This is the security guard for issue #1375.
func TestNewProvider_TLSTransportNeverInsecure(t *testing.T) {
tests := []struct {
name string
proxy string
}{
{name: "no proxy", proxy: ""},
{name: "with proxy", proxy: "http://127.0.0.1:8080"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
p := NewProvider("key", "https://example.com", tt.proxy)
tr, ok := p.httpClient.Transport.(*http.Transport)
if !ok || tr == nil {
t.Fatalf("Transport = %T, want *http.Transport", p.httpClient.Transport)
}
if tr.TLSClientConfig == nil {
t.Fatal("TLSClientConfig is nil")
}
if tr.TLSClientConfig.InsecureSkipVerify {
t.Fatal("InsecureSkipVerify must never be true")
}
})
}
}