Skip to content

Commit d370ca7

Browse files
committed
fix: resolve multiple bugs from code review #116
Fixes four issues identified in the community code review: - Session persistence broken on Windows: session keys like "telegram:123456" contain ':', which is illegal in Windows filenames. filepath.Base() strips drive-letter prefixes on Windows, causing Save() to silently fail. Added sanitizeFilename() to replace invalid chars in the filename while keeping the original key in the JSON payload. - HTTP client with no timeout: HTTPProvider used Timeout: 0 (infinite wait), which can hang the entire agent if an API endpoint becomes unresponsive. Set a 120s safety timeout. - Slack AllowFrom type mismatch: SlackConfig used plain []string while every other channel uses FlexibleStringSlice, so numeric user IDs in Slack config would fail to parse. - Token estimation wrong for CJK: estimateTokens() divided byte length by 4, but CJK characters are 3 bytes each, causing ~3x overestimation and premature summarization. Switched to utf8.RuneCountInString() / 3 for better cross-language accuracy. Also added unit tests for the session filename sanitization. Ref #116
1 parent 1cff7d4 commit d370ca7

File tree

5 files changed

+113
-10
lines changed

5 files changed

+113
-10
lines changed

pkg/agent/loop.go

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ import (
1616
"sync"
1717
"sync/atomic"
1818
"time"
19+
"unicode/utf8"
1920

2021
"github.com/sipeed/picoclaw/pkg/bus"
2122
"github.com/sipeed/picoclaw/pkg/config"
@@ -768,10 +769,13 @@ func (al *AgentLoop) summarizeBatch(ctx context.Context, batch []providers.Messa
768769
}
769770

770771
// estimateTokens estimates the number of tokens in a message list.
772+
// Uses rune count instead of byte length so that CJK and other multi-byte
773+
// characters are not over-counted (a Chinese character is 3 bytes but roughly
774+
// one token).
771775
func (al *AgentLoop) estimateTokens(messages []providers.Message) int {
772776
total := 0
773777
for _, m := range messages {
774-
total += len(m.Content) / 4 // Simple heuristic: 4 chars per token
778+
total += utf8.RuneCountInString(m.Content) / 3
775779
}
776780
return total
777781
}

pkg/config/config.go

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -130,10 +130,10 @@ type DingTalkConfig struct {
130130
}
131131

132132
type SlackConfig struct {
133-
Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_SLACK_ENABLED"`
134-
BotToken string `json:"bot_token" env:"PICOCLAW_CHANNELS_SLACK_BOT_TOKEN"`
135-
AppToken string `json:"app_token" env:"PICOCLAW_CHANNELS_SLACK_APP_TOKEN"`
136-
AllowFrom []string `json:"allow_from" env:"PICOCLAW_CHANNELS_SLACK_ALLOW_FROM"`
133+
Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_SLACK_ENABLED"`
134+
BotToken string `json:"bot_token" env:"PICOCLAW_CHANNELS_SLACK_BOT_TOKEN"`
135+
AppToken string `json:"app_token" env:"PICOCLAW_CHANNELS_SLACK_APP_TOKEN"`
136+
AllowFrom FlexibleStringSlice `json:"allow_from" env:"PICOCLAW_CHANNELS_SLACK_ALLOW_FROM"`
137137
}
138138

139139
type LINEConfig struct {
@@ -261,7 +261,7 @@ func DefaultConfig() *Config {
261261
Enabled: false,
262262
BotToken: "",
263263
AppToken: "",
264-
AllowFrom: []string{},
264+
AllowFrom: FlexibleStringSlice{},
265265
},
266266
LINE: LINEConfig{
267267
Enabled: false,

pkg/providers/http_provider.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ import (
1515
"net/http"
1616
"net/url"
1717
"strings"
18+
"time"
1819

1920
"github.com/sipeed/picoclaw/pkg/auth"
2021
"github.com/sipeed/picoclaw/pkg/config"
@@ -28,7 +29,7 @@ type HTTPProvider struct {
2829

2930
func NewHTTPProvider(apiKey, apiBase, proxy string) *HTTPProvider {
3031
client := &http.Client{
31-
Timeout: 0,
32+
Timeout: 120 * time.Second,
3233
}
3334

3435
if proxy != "" {

pkg/session/manager.go

Lines changed: 23 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -145,13 +145,33 @@ func (sm *SessionManager) TruncateHistory(key string, keepLast int) {
145145
session.Updated = time.Now()
146146
}
147147

148+
// sanitizeFilename converts a session key into a cross-platform safe filename.
149+
// Characters like ':' are valid on Linux but illegal on Windows; we replace
150+
// them so the same key works everywhere. The original key is preserved inside
151+
// the JSON file, so loadSessions still maps back to the right in-memory key.
152+
func sanitizeFilename(key string) string {
153+
r := strings.NewReplacer(
154+
":", "_",
155+
"<", "_",
156+
">", "_",
157+
"\"", "_",
158+
"|", "_",
159+
"?", "_",
160+
"*", "_",
161+
)
162+
return r.Replace(key)
163+
}
164+
148165
func (sm *SessionManager) Save(key string) error {
149166
if sm.storage == "" {
150167
return nil
151168
}
152169

153-
// Validate key to avoid invalid filenames and path traversal.
154-
if key == "" || key == "." || key == ".." || key != filepath.Base(key) || strings.Contains(key, "/") || strings.Contains(key, "\\") {
170+
// Sanitize key into a cross-platform safe filename (e.g. "telegram:123" -> "telegram_123").
171+
filename := sanitizeFilename(key)
172+
173+
// Validate the sanitized filename to avoid path traversal.
174+
if filename == "" || filename == "." || filename == ".." || strings.ContainsAny(filename, "/\\") {
155175
return os.ErrInvalid
156176
}
157177

@@ -182,7 +202,7 @@ func (sm *SessionManager) Save(key string) error {
182202
return err
183203
}
184204

185-
sessionPath := filepath.Join(sm.storage, key+".json")
205+
sessionPath := filepath.Join(sm.storage, filename+".json")
186206
tmpFile, err := os.CreateTemp(sm.storage, "session-*.tmp")
187207
if err != nil {
188208
return err

pkg/session/manager_test.go

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
package session
2+
3+
import (
4+
"os"
5+
"path/filepath"
6+
"testing"
7+
)
8+
9+
func TestSanitizeFilename(t *testing.T) {
10+
tests := []struct {
11+
input string
12+
expected string
13+
}{
14+
{"simple", "simple"},
15+
{"telegram:123456", "telegram_123456"},
16+
{"discord:987654321", "discord_987654321"},
17+
{"slack:C01234", "slack_C01234"},
18+
{"no-colons-here", "no-colons-here"},
19+
{"multiple:colons:here", "multiple_colons_here"},
20+
{"with<angle>brackets", "with_angle_brackets"},
21+
{"pipe|char", "pipe_char"},
22+
{"question?mark", "question_mark"},
23+
{"star*char", "star_char"},
24+
}
25+
26+
for _, tt := range tests {
27+
t.Run(tt.input, func(t *testing.T) {
28+
got := sanitizeFilename(tt.input)
29+
if got != tt.expected {
30+
t.Errorf("sanitizeFilename(%q) = %q, want %q", tt.input, got, tt.expected)
31+
}
32+
})
33+
}
34+
}
35+
36+
func TestSave_WithColonInKey(t *testing.T) {
37+
tmpDir := t.TempDir()
38+
sm := NewSessionManager(tmpDir)
39+
40+
// Create a session with a key containing colon (typical channel session key).
41+
key := "telegram:123456"
42+
sm.GetOrCreate(key)
43+
sm.AddMessage(key, "user", "hello")
44+
45+
// Save should succeed even though the key contains ':'
46+
if err := sm.Save(key); err != nil {
47+
t.Fatalf("Save(%q) failed: %v", key, err)
48+
}
49+
50+
// The file on disk should use sanitized name.
51+
expectedFile := filepath.Join(tmpDir, "telegram_123456.json")
52+
if _, err := os.Stat(expectedFile); os.IsNotExist(err) {
53+
t.Fatalf("expected session file %s to exist", expectedFile)
54+
}
55+
56+
// Load into a fresh manager and verify the session round-trips.
57+
sm2 := NewSessionManager(tmpDir)
58+
history := sm2.GetHistory(key)
59+
if len(history) != 1 {
60+
t.Fatalf("expected 1 message after reload, got %d", len(history))
61+
}
62+
if history[0].Content != "hello" {
63+
t.Errorf("expected message content %q, got %q", "hello", history[0].Content)
64+
}
65+
}
66+
67+
func TestSave_RejectsPathTraversal(t *testing.T) {
68+
tmpDir := t.TempDir()
69+
sm := NewSessionManager(tmpDir)
70+
71+
badKeys := []string{"", ".", "..", "foo/bar", "foo\\bar"}
72+
for _, key := range badKeys {
73+
sm.GetOrCreate(key)
74+
if err := sm.Save(key); err == nil {
75+
t.Errorf("Save(%q) should have failed but didn't", key)
76+
}
77+
}
78+
}

0 commit comments

Comments
 (0)