Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions pkg/agent/context.go
Original file line number Diff line number Diff line change
Expand Up @@ -465,11 +465,10 @@ func (cb *ContextBuilder) BuildMessages(
messages = append(messages, history...)

// Add current user message
if strings.TrimSpace(currentMessage) != "" || len(media) > 0 {
if strings.TrimSpace(currentMessage) != "" {
messages = append(messages, providers.Message{
Role: "user",
Content: currentMessage,
Media: media,
})
}

Expand Down
119 changes: 6 additions & 113 deletions pkg/agent/loop.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,9 @@ package agent

import (
"context"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"os"
"path/filepath"
"strings"
"sync"
Expand Down Expand Up @@ -49,12 +47,11 @@ type AgentLoop struct {

// processOptions configures how a message is processed
type processOptions struct {
SessionKey string // Session identifier for history/context
Channel string // Target channel for tool execution
ChatID string // Target chat ID for tool execution
UserMessage string // User message content (may include prefix)
Media []string // Media URLs attached to the user message
DefaultResponse string // Response when LLM returns empty
SessionKey string // Session identifier for history/context
Channel string // Target channel for tool execution
ChatID string // Target chat ID for tool execution
UserMessage string // User message content (may include prefix)
DefaultResponse string // Response when LLM returns empty
EnableSummary bool // Whether to trigger summarization
SendResponse bool // Whether to send response via bus
NoHistory bool // If true, don't load session history (for heartbeat)
Expand Down Expand Up @@ -499,7 +496,6 @@ func (al *AgentLoop) processMessage(ctx context.Context, msg bus.InboundMessage)
Channel: msg.Channel,
ChatID: msg.ChatID,
UserMessage: msg.Content,
Media: msg.Media,
DefaultResponse: defaultResponse,
EnableSummary: true,
SendResponse: false,
Expand Down Expand Up @@ -606,11 +602,10 @@ func (al *AgentLoop) runAgentLoop(
history,
summary,
opts.UserMessage,
opts.Media,
nil,
opts.Channel,
opts.ChatID,
)
messages = resolveMediaRefs(messages, al.mediaStore)

// 3. Save user message to session
agent.Sessions.AddMessage(opts.SessionKey, "user", opts.UserMessage)
Expand Down Expand Up @@ -1481,105 +1476,3 @@ func extractParentPeer(msg bus.InboundMessage) *routing.RoutePeer {
}
return &routing.RoutePeer{Kind: parentKind, ID: parentID}
}

// maxMediaFileSize is the largest media file, in bytes (20 MB), that
// resolveMediaRefs will read and inline as a base64 data URL.
// Files larger than this are skipped (with a warning) instead of being loaded,
// to prevent OOM under concurrent load.
const maxMediaFileSize = 20 * 1024 * 1024

// resolveMediaRefs returns a copy of messages in which every media:// entry of
// a message's Media list has been replaced by a base64 data URL.
//
// Entries without the media:// prefix are carried over untouched. A media://
// entry that cannot be turned into a data URL (see mediaRefToDataURL) is dropped
// from the list after a warning is logged. When store is nil the input slice is
// returned as-is. The Media slices of the input messages are never modified;
// each affected message in the returned slice gets a freshly allocated list.
func resolveMediaRefs(messages []providers.Message, store media.MediaStore) []providers.Message {
	if store == nil {
		return messages
	}

	out := make([]providers.Message, len(messages))
	copy(out, messages)

	for idx := range out {
		refs := out[idx].Media
		if len(refs) == 0 {
			continue
		}

		urls := make([]string, 0, len(refs))
		for _, ref := range refs {
			if !strings.HasPrefix(ref, "media://") {
				// Not a store reference: pass it through unchanged.
				urls = append(urls, ref)
				continue
			}
			if dataURL, ok := mediaRefToDataURL(ref, store); ok {
				urls = append(urls, dataURL)
			}
		}

		out[idx].Media = urls
	}

	return out
}

// mediaRefToDataURL resolves a single media:// ref through store and encodes the
// referenced file as a "data:<mime>;base64,<payload>" URL.
//
// It reports false, after logging a warning, when the ref cannot be resolved,
// the file cannot be stat'ed or read, the file exceeds maxMediaFileSize, or no
// MIME type can be determined from the stored metadata or the file extension.
func mediaRefToDataURL(ref string, store media.MediaStore) (string, bool) {
	localPath, meta, err := store.ResolveWithMeta(ref)
	if err != nil {
		logger.WarnCF("agent", "Failed to resolve media ref", map[string]any{
			"ref":   ref,
			"error": err.Error(),
		})
		return "", false
	}

	info, err := os.Stat(localPath)
	if err != nil {
		logger.WarnCF("agent", "Failed to stat media file", map[string]any{
			"path":  localPath,
			"error": err.Error(),
		})
		return "", false
	}
	// Check the size before reading so oversized files are never loaded.
	if info.Size() > maxMediaFileSize {
		logger.WarnCF("agent", "Media file too large, skipping", map[string]any{
			"path":     localPath,
			"size":     info.Size(),
			"max_size": maxMediaFileSize,
		})
		return "", false
	}

	payload, err := os.ReadFile(localPath)
	if err != nil {
		logger.WarnCF("agent", "Failed to read media file", map[string]any{
			"path":  localPath,
			"error": err.Error(),
		})
		return "", false
	}

	// Prefer the content type recorded in the store; fall back to the extension.
	contentType := meta.ContentType
	if contentType == "" {
		contentType = mimeFromExtension(filepath.Ext(localPath))
	}
	if contentType == "" {
		logger.WarnCF("agent", "Unknown media type, skipping", map[string]any{
			"path": localPath,
			"ext":  filepath.Ext(localPath),
		})
		return "", false
	}

	return "data:" + contentType + ";base64," + base64.StdEncoding.EncodeToString(payload), true
}

// mimeFromExtension maps a file extension (with its leading dot, matched
// case-insensitively) to the MIME type of a few common image formats:
// JPEG, PNG, GIF, WebP and BMP. Any other extension, including the empty
// string, yields "".
func mimeFromExtension(ext string) string {
	var mimeType string
	switch lowered := strings.ToLower(ext); lowered {
	case ".png":
		mimeType = "image/png"
	case ".jpeg", ".jpg":
		mimeType = "image/jpeg"
	case ".webp":
		mimeType = "image/webp"
	case ".gif":
		mimeType = "image/gif"
	case ".bmp":
		mimeType = "image/bmp"
	}
	return mimeType
}
123 changes: 0 additions & 123 deletions pkg/agent/loop_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,12 @@ import (
"os"
"path/filepath"
"slices"
"strings"
"testing"
"time"

"github.com/sipeed/picoclaw/pkg/bus"
"github.com/sipeed/picoclaw/pkg/channels"
"github.com/sipeed/picoclaw/pkg/config"
"github.com/sipeed/picoclaw/pkg/media"
"github.com/sipeed/picoclaw/pkg/providers"
"github.com/sipeed/picoclaw/pkg/tools"
)
Expand Down Expand Up @@ -810,124 +808,3 @@ func TestHandleReasoning(t *testing.T) {
}
})
}

// TestMimeFromExtension verifies the extension-to-MIME mapping for the
// supported image types (including an upper-case extension) and checks that
// unsupported or empty extensions map to the empty string.
func TestMimeFromExtension(t *testing.T) {
	type testCase struct {
		ext  string
		want string
	}
	cases := []testCase{
		{ext: ".jpg", want: "image/jpeg"},
		{ext: ".JPEG", want: "image/jpeg"},
		{ext: ".png", want: "image/png"},
		{ext: ".gif", want: "image/gif"},
		{ext: ".webp", want: "image/webp"},
		{ext: ".bmp", want: "image/bmp"},
		{ext: ".txt", want: ""},
		{ext: ".pdf", want: ""},
		{ext: "", want: ""},
	}

	for idx := range cases {
		tc := cases[idx]
		got := mimeFromExtension(tc.ext)
		if got == tc.want {
			continue
		}
		t.Errorf("mimeFromExtension(%q) = %q, want %q", tc.ext, got, tc.want)
	}
}

// TestResolveMediaRefs_NilStore verifies that a nil store leaves media://
// refs untouched.
func TestResolveMediaRefs_NilStore(t *testing.T) {
	const ref = "media://abc"
	input := []providers.Message{{Role: "user", Content: "hi", Media: []string{ref}}}

	got := resolveMediaRefs(input, nil)

	if got[0].Media[0] != ref {
		t.Error("nil store should return messages unchanged")
	}
}

// TestResolveMediaRefs_NonMediaRef verifies that an entry without the
// media:// prefix (here an https URL) is passed through unchanged.
func TestResolveMediaRefs_NonMediaRef(t *testing.T) {
	const url = "https://example.com/img.png"
	input := []providers.Message{{Role: "user", Content: "hi", Media: []string{url}}}

	got := resolveMediaRefs(input, media.NewFileMediaStore())

	if got[0].Media[0] != url {
		t.Error("non-media:// refs should be passed through unchanged")
	}
}

// TestResolveMediaRefs_ResolvesToBase64 stores a small file with an explicit
// image/png content type and verifies that its media:// ref is replaced by a
// data URL carrying that content type.
func TestResolveMediaRefs_ResolvesToBase64(t *testing.T) {
	store := media.NewFileMediaStore()

	imgPath := filepath.Join(t.TempDir(), "test.png")
	if err := os.WriteFile(imgPath, []byte("fake-png-data"), 0o644); err != nil {
		t.Fatal(err)
	}

	ref, err := store.Store(imgPath, media.MediaMeta{ContentType: "image/png"}, "test")
	if err != nil {
		t.Fatal(err)
	}

	msgs := []providers.Message{{Role: "user", Content: "describe", Media: []string{ref}}}
	result := resolveMediaRefs(msgs, store)

	if len(result[0].Media) != 1 {
		t.Fatalf("expected 1 resolved media, got %d", len(result[0].Media))
	}

	got := result[0].Media[0]
	if !strings.HasPrefix(got, "data:image/png;base64,") {
		// Truncate for readability, but never slice past the end: an
		// unresolved or malformed value can be shorter than 40 bytes, and a
		// panic here would mask the actual failure.
		preview := got
		if len(preview) > 40 {
			preview = preview[:40]
		}
		t.Errorf("expected data URL, got %s", preview)
	}
}

// TestResolveMediaRefs_SkipsOversizedFile verifies that a file one byte larger
// than maxMediaFileSize is dropped from the resolved Media list.
func TestResolveMediaRefs_SkipsOversizedFile(t *testing.T) {
	mediaStore := media.NewFileMediaStore()

	oversized := make([]byte, maxMediaFileSize+1)
	path := filepath.Join(t.TempDir(), "big.jpg")
	if err := os.WriteFile(path, oversized, 0o644); err != nil {
		t.Fatal(err)
	}

	ref, err := mediaStore.Store(path, media.MediaMeta{ContentType: "image/jpeg"}, "test")
	if err != nil {
		t.Fatal(err)
	}

	input := []providers.Message{{Role: "user", Content: "hi", Media: []string{ref}}}
	got := resolveMediaRefs(input, mediaStore)

	if len(got[0].Media) != 0 {
		t.Error("oversized file should be skipped")
	}
}

// TestResolveMediaRefs_SkipsUnknownExtension verifies that a file stored
// without a content type and with an unrecognized extension (.txt) is dropped
// from the resolved Media list.
func TestResolveMediaRefs_SkipsUnknownExtension(t *testing.T) {
	mediaStore := media.NewFileMediaStore()

	path := filepath.Join(t.TempDir(), "readme.txt")
	if err := os.WriteFile(path, []byte("hello"), 0o644); err != nil {
		t.Fatal(err)
	}

	// Empty metadata forces the extension-based fallback.
	ref, err := mediaStore.Store(path, media.MediaMeta{}, "test")
	if err != nil {
		t.Fatal(err)
	}

	input := []providers.Message{{Role: "user", Content: "hi", Media: []string{ref}}}
	got := resolveMediaRefs(input, mediaStore)

	if len(got[0].Media) != 0 {
		t.Error("unknown extension with no ContentType should be skipped")
	}
}

// TestResolveMediaRefs_DoesNotMutateOriginal verifies that resolving refs
// leaves the caller's message slice untouched: the original Media entry must
// still be the media:// ref after the call.
func TestResolveMediaRefs_DoesNotMutateOriginal(t *testing.T) {
	store := media.NewFileMediaStore()

	imgPath := filepath.Join(t.TempDir(), "test.jpg")
	if err := os.WriteFile(imgPath, []byte("data"), 0o644); err != nil {
		t.Fatal(err)
	}

	// Fail fast if storing fails: an empty ref would otherwise trip the
	// prefix check below and be misreported as a mutation.
	ref, err := store.Store(imgPath, media.MediaMeta{ContentType: "image/jpeg"}, "test")
	if err != nil {
		t.Fatal(err)
	}

	original := []providers.Message{{Role: "user", Content: "hi", Media: []string{ref}}}
	resolveMediaRefs(original, store)

	if !strings.HasPrefix(original[0].Media[0], "media://") {
		t.Error("original message should not be mutated")
	}
}
56 changes: 1 addition & 55 deletions pkg/providers/openai_compat/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ func (p *Provider) Chat(

requestBody := map[string]any{
"model": model,
"messages": serializeMessages(messages),
"messages": stripSystemParts(messages),
}

if len(tools) > 0 {
Expand Down Expand Up @@ -195,60 +195,6 @@ func (p *Provider) Chat(
return parseResponse(body)
}

// serializeMessages converts messages into the generic map form placed in the
// request body, one map per message, in the same order.
//
// Every map carries "role" and "content". For a message without media,
// "content" is the plain Content string. For a message with media, "content"
// is a list of parts: a {"type": "text"} part holding Content (omitted when
// Content is empty) followed by one {"type": "image_url"} part per Media
// entry. "tool_call_id", "tool_calls" and "reasoning_content" are added only
// when the corresponding field is non-empty.
func serializeMessages(messages []Message) []map[string]any {
	result := make([]map[string]any, 0, len(messages))
	for _, m := range messages {
		msg := map[string]any{"role": m.Role}

		if len(m.Media) == 0 {
			msg["content"] = m.Content
		} else {
			parts := make([]map[string]any, 0, 1+len(m.Media))
			if m.Content != "" {
				parts = append(parts, map[string]any{
					"type": "text",
					"text": m.Content,
				})
			}
			for _, mediaURL := range m.Media {
				parts = append(parts, map[string]any{
					"type": "image_url",
					"image_url": map[string]any{
						"url": mediaURL,
					},
				})
			}
			msg["content"] = parts
		}

		// Optional fields are shared by both content shapes; emit each only
		// when it is set so absent values do not appear in the payload.
		if m.ToolCallID != "" {
			msg["tool_call_id"] = m.ToolCallID
		}
		if len(m.ToolCalls) > 0 {
			msg["tool_calls"] = m.ToolCalls
		}
		if m.ReasoningContent != "" {
			msg["reasoning_content"] = m.ReasoningContent
		}

		result = append(result, msg)
	}
	return result
}

func parseResponse(body []byte) (*LLMResponse, error) {
var apiResponse struct {
Choices []struct {
Expand Down
Loading