Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion pkg/agent/context.go
Original file line number Diff line number Diff line change
Expand Up @@ -465,10 +465,11 @@ func (cb *ContextBuilder) BuildMessages(
messages = append(messages, history...)

// Add current user message
if strings.TrimSpace(currentMessage) != "" {
if strings.TrimSpace(currentMessage) != "" || len(media) > 0 {
messages = append(messages, providers.Message{
Role: "user",
Content: currentMessage,
Media: media,
})
}

Expand Down
119 changes: 113 additions & 6 deletions pkg/agent/loop.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,11 @@ package agent

import (
"context"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"os"
"path/filepath"
"strings"
"sync"
Expand Down Expand Up @@ -46,11 +48,12 @@ type AgentLoop struct {

// processOptions configures how a message is processed
type processOptions struct {
SessionKey string // Session identifier for history/context
Channel string // Target channel for tool execution
ChatID string // Target chat ID for tool execution
UserMessage string // User message content (may include prefix)
DefaultResponse string // Response when LLM returns empty
SessionKey string // Session identifier for history/context
Channel string // Target channel for tool execution
ChatID string // Target chat ID for tool execution
UserMessage string // User message content (may include prefix)
Media []string // Media URLs attached to the user message
DefaultResponse string // Response when LLM returns empty
EnableSummary bool // Whether to trigger summarization
SendResponse bool // Whether to send response via bus
NoHistory bool // If true, don't load session history (for heartbeat)
Expand Down Expand Up @@ -417,6 +420,7 @@ func (al *AgentLoop) processMessage(ctx context.Context, msg bus.InboundMessage)
Channel: msg.Channel,
ChatID: msg.ChatID,
UserMessage: msg.Content,
Media: msg.Media,
DefaultResponse: defaultResponse,
EnableSummary: true,
SendResponse: false,
Expand Down Expand Up @@ -509,10 +513,11 @@ func (al *AgentLoop) runAgentLoop(ctx context.Context, agent *AgentInstance, opt
history,
summary,
opts.UserMessage,
nil,
opts.Media,
opts.Channel,
opts.ChatID,
)
messages = resolveMediaRefs(messages, al.mediaStore)

// 3. Save user message to session
agent.Sessions.AddMessage(opts.SessionKey, "user", opts.UserMessage)
Expand Down Expand Up @@ -1350,3 +1355,105 @@ func extractParentPeer(msg bus.InboundMessage) *routing.RoutePeer {
}
return &routing.RoutePeer{Kind: parentKind, ID: parentID}
}

// maxMediaFileSize is the maximum file size (20 MB) for media resolution.
// Files larger than this are skipped to prevent OOM under concurrent load.
const maxMediaFileSize = 20 * 1024 * 1024
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Bake size limitation into config.


// resolveMediaRefs replaces media:// refs in message Media fields with base64 data URLs.
// Returns a new slice with resolved URLs; original messages are not mutated.
func resolveMediaRefs(messages []providers.Message, store media.MediaStore) []providers.Message {
if store == nil {
return messages
}

result := make([]providers.Message, len(messages))
copy(result, messages)

for i, m := range result {
if len(m.Media) == 0 {
continue
}

resolved := make([]string, 0, len(m.Media))
for _, ref := range m.Media {
if !strings.HasPrefix(ref, "media://") {
resolved = append(resolved, ref)
continue
}

localPath, meta, err := store.ResolveWithMeta(ref)
if err != nil {
logger.WarnCF("agent", "Failed to resolve media ref", map[string]any{
"ref": ref,
"error": err.Error(),
})
continue
}

info, err := os.Stat(localPath)
if err != nil {
logger.WarnCF("agent", "Failed to stat media file", map[string]any{
"path": localPath,
"error": err.Error(),
})
continue
}
if info.Size() > maxMediaFileSize {
logger.WarnCF("agent", "Media file too large, skipping", map[string]any{
"path": localPath,
"size": info.Size(),
"max_size": maxMediaFileSize,
})
continue
}

data, err := os.ReadFile(localPath)
if err != nil {
logger.WarnCF("agent", "Failed to read media file", map[string]any{
"path": localPath,
"error": err.Error(),
})
continue
}

mime := meta.ContentType
if mime == "" {
mime = mimeFromExtension(filepath.Ext(localPath))
}
if mime == "" {
logger.WarnCF("agent", "Unknown media type, skipping", map[string]any{
"path": localPath,
"ext": filepath.Ext(localPath),
})
continue
}

dataURL := "data:" + mime + ";base64," + base64.StdEncoding.EncodeToString(data)
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could we use file handler and encoder for later use, instead of allocation 2X of memory for media files.

resolved = append(resolved, dataURL)
}

result[i].Media = resolved
}

return result
}

// mimeFromExtension returns a MIME type for common image extensions.
// Returns empty string for unrecognized extensions.
func mimeFromExtension(ext string) string {
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please use github.com/h2non/filetype for file type detection

switch strings.ToLower(ext) {
case ".jpg", ".jpeg":
return "image/jpeg"
case ".png":
return "image/png"
case ".gif":
return "image/gif"
case ".webp":
return "image/webp"
case ".bmp":
return "image/bmp"
default:
return ""
}
}
123 changes: 123 additions & 0 deletions pkg/agent/loop_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,14 @@ import (
"os"
"path/filepath"
"slices"
"strings"
"testing"
"time"

"github.com/sipeed/picoclaw/pkg/bus"
"github.com/sipeed/picoclaw/pkg/channels"
"github.com/sipeed/picoclaw/pkg/config"
"github.com/sipeed/picoclaw/pkg/media"
"github.com/sipeed/picoclaw/pkg/providers"
"github.com/sipeed/picoclaw/pkg/tools"
)
Expand Down Expand Up @@ -840,3 +842,124 @@ func TestHandleReasoning(t *testing.T) {
}
})
}

func TestMimeFromExtension(t *testing.T) {
tests := []struct {
ext string
want string
}{
{".jpg", "image/jpeg"},
{".JPEG", "image/jpeg"},
{".png", "image/png"},
{".gif", "image/gif"},
{".webp", "image/webp"},
{".bmp", "image/bmp"},
{".txt", ""},
{".pdf", ""},
{"", ""},
}
for _, tt := range tests {
if got := mimeFromExtension(tt.ext); got != tt.want {
t.Errorf("mimeFromExtension(%q) = %q, want %q", tt.ext, got, tt.want)
}
}
}

func TestResolveMediaRefs_NilStore(t *testing.T) {
msgs := []providers.Message{{Role: "user", Content: "hi", Media: []string{"media://abc"}}}
result := resolveMediaRefs(msgs, nil)
if result[0].Media[0] != "media://abc" {
t.Error("nil store should return messages unchanged")
}
}

func TestResolveMediaRefs_NonMediaRef(t *testing.T) {
msgs := []providers.Message{{Role: "user", Content: "hi", Media: []string{"https://example.com/img.png"}}}
result := resolveMediaRefs(msgs, media.NewFileMediaStore())
if result[0].Media[0] != "https://example.com/img.png" {
t.Error("non-media:// refs should be passed through unchanged")
}
}

func TestResolveMediaRefs_ResolvesToBase64(t *testing.T) {
store := media.NewFileMediaStore()

imgPath := filepath.Join(t.TempDir(), "test.png")
if err := os.WriteFile(imgPath, []byte("fake-png-data"), 0o644); err != nil {
t.Fatal(err)
}

ref, err := store.Store(imgPath, media.MediaMeta{ContentType: "image/png"}, "test")
if err != nil {
t.Fatal(err)
}

msgs := []providers.Message{{Role: "user", Content: "describe", Media: []string{ref}}}
result := resolveMediaRefs(msgs, store)

if len(result[0].Media) != 1 {
t.Fatalf("expected 1 resolved media, got %d", len(result[0].Media))
}
if !strings.HasPrefix(result[0].Media[0], "data:image/png;base64,") {
t.Errorf("expected data URL, got %s", result[0].Media[0][:40])
}
}

func TestResolveMediaRefs_SkipsOversizedFile(t *testing.T) {
store := media.NewFileMediaStore()

bigPath := filepath.Join(t.TempDir(), "big.jpg")
if err := os.WriteFile(bigPath, make([]byte, maxMediaFileSize+1), 0o644); err != nil {
t.Fatal(err)
}

ref, err := store.Store(bigPath, media.MediaMeta{ContentType: "image/jpeg"}, "test")
if err != nil {
t.Fatal(err)
}

msgs := []providers.Message{{Role: "user", Content: "hi", Media: []string{ref}}}
result := resolveMediaRefs(msgs, store)

if len(result[0].Media) != 0 {
t.Error("oversized file should be skipped")
}
}

func TestResolveMediaRefs_SkipsUnknownExtension(t *testing.T) {
store := media.NewFileMediaStore()

txtPath := filepath.Join(t.TempDir(), "readme.txt")
if err := os.WriteFile(txtPath, []byte("hello"), 0o644); err != nil {
t.Fatal(err)
}

ref, err := store.Store(txtPath, media.MediaMeta{}, "test")
if err != nil {
t.Fatal(err)
}

msgs := []providers.Message{{Role: "user", Content: "hi", Media: []string{ref}}}
result := resolveMediaRefs(msgs, store)

if len(result[0].Media) != 0 {
t.Error("unknown extension with no ContentType should be skipped")
}
}

func TestResolveMediaRefs_DoesNotMutateOriginal(t *testing.T) {
store := media.NewFileMediaStore()

imgPath := filepath.Join(t.TempDir(), "test.jpg")
if err := os.WriteFile(imgPath, []byte("data"), 0o644); err != nil {
t.Fatal(err)
}

ref, _ := store.Store(imgPath, media.MediaMeta{ContentType: "image/jpeg"}, "test")
original := []providers.Message{{Role: "user", Content: "hi", Media: []string{ref}}}
resolveMediaRefs(original, store)

if !strings.HasPrefix(original[0].Media[0], "media://") {
t.Error("original message should not be mutated")
}
}
56 changes: 55 additions & 1 deletion pkg/providers/openai_compat/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ func (p *Provider) Chat(

requestBody := map[string]any{
"model": model,
"messages": stripSystemParts(messages),
"messages": serializeMessages(messages),
}

if len(tools) > 0 {
Expand Down Expand Up @@ -195,6 +195,60 @@ func (p *Provider) Chat(
return parseResponse(body)
}

func serializeMessages(messages []Message) []map[string]interface{} {
result := make([]map[string]interface{}, 0, len(messages))
for _, m := range messages {
if len(m.Media) == 0 {
msg := map[string]interface{}{
"role": m.Role,
"content": m.Content,
}
if m.ToolCallID != "" {
msg["tool_call_id"] = m.ToolCallID
}
if len(m.ToolCalls) > 0 {
msg["tool_calls"] = m.ToolCalls
}
if m.ReasoningContent != "" {
msg["reasoning_content"] = m.ReasoningContent
}
result = append(result, msg)
continue
}

parts := make([]map[string]interface{}, 0, 1+len(m.Media))
if m.Content != "" {
parts = append(parts, map[string]interface{}{
"type": "text",
"text": m.Content,
})
}
for _, mediaURL := range m.Media {
parts = append(parts, map[string]interface{}{
"type": "image_url",
"image_url": map[string]interface{}{
"url": mediaURL,
},
})
}
msg := map[string]interface{}{
"role": m.Role,
"content": parts,
}
if m.ToolCallID != "" {
msg["tool_call_id"] = m.ToolCallID
}
if len(m.ToolCalls) > 0 {
msg["tool_calls"] = m.ToolCalls
}
if m.ReasoningContent != "" {
msg["reasoning_content"] = m.ReasoningContent
}
result = append(result, msg)
}
return result
}

func parseResponse(body []byte) (*LLMResponse, error) {
var apiResponse struct {
Choices []struct {
Expand Down
Loading