Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
354 changes: 329 additions & 25 deletions pkg/agent/context.go

Large diffs are not rendered by default.

513 changes: 513 additions & 0 deletions pkg/agent/context_cache_test.go

Large diffs are not rendered by default.

20 changes: 12 additions & 8 deletions pkg/agent/loop.go
Original file line number Diff line number Diff line change
Expand Up @@ -524,8 +524,9 @@ func (al *AgentLoop) runLLMIteration(
fbResult, fbErr := al.fallback.Execute(ctx, agent.Candidates,
func(ctx context.Context, provider, model string) (*providers.LLMResponse, error) {
return agent.Provider.Chat(ctx, messages, providerToolDefs, model, map[string]any{
"max_tokens": agent.MaxTokens,
"temperature": agent.Temperature,
"max_tokens": agent.MaxTokens,
"temperature": agent.Temperature,
"prompt_cache_key": agent.ID,
})
},
)
Expand All @@ -540,8 +541,9 @@ func (al *AgentLoop) runLLMIteration(
return fbResult.Response, nil
}
return agent.Provider.Chat(ctx, messages, providerToolDefs, agent.Model, map[string]any{
"max_tokens": agent.MaxTokens,
"temperature": agent.Temperature,
"max_tokens": agent.MaxTokens,
"temperature": agent.Temperature,
"prompt_cache_key": agent.ID,
})
}

Expand Down Expand Up @@ -962,8 +964,9 @@ func (al *AgentLoop) summarizeSession(agent *AgentInstance, sessionKey string) {
nil,
agent.Model,
map[string]any{
"max_tokens": 1024,
"temperature": 0.3,
"max_tokens": 1024,
"temperature": 0.3,
"prompt_cache_key": agent.ID,
},
)
if err == nil {
Expand Down Expand Up @@ -1012,8 +1015,9 @@ func (al *AgentLoop) summarizeBatch(
nil,
agent.Model,
map[string]any{
"max_tokens": 1024,
"temperature": 0.3,
"max_tokens": 1024,
"temperature": 0.3,
"prompt_cache_key": agent.ID,
},
)
if err != nil {
Expand Down
15 changes: 14 additions & 1 deletion pkg/providers/anthropic/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,20 @@ func buildParams(
for _, msg := range messages {
switch msg.Role {
case "system":
system = append(system, anthropic.TextBlockParam{Text: msg.Content})
// Prefer structured SystemParts for per-block cache_control.
// This enables LLM-side KV cache reuse: the static block's prefix
// hash stays stable across requests while dynamic parts change freely.
if len(msg.SystemParts) > 0 {
for _, part := range msg.SystemParts {
block := anthropic.TextBlockParam{Text: part.Text}
if part.CacheControl != nil && part.CacheControl.Type == "ephemeral" {
block.CacheControl = anthropic.NewCacheControlEphemeralParam()
}
system = append(system, block)
}
} else {
system = append(system, anthropic.TextBlockParam{Text: msg.Content})
}
case "user":
if msg.ToolCallID != "" {
anthropicMessages = append(anthropicMessages,
Expand Down
12 changes: 12 additions & 0 deletions pkg/providers/codex_provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,11 @@ func buildCodexParams(
for _, msg := range messages {
switch msg.Role {
case "system":
// Use the full concatenated system prompt (static + dynamic + summary)
// as instructions. This keeps behavior consistent with Anthropic and
// OpenAI-compat adapters where the complete system context lives in
// one place. Prefix caching is handled by prompt_cache_key below,
// not by splitting content across instructions vs input messages.
instructions = msg.Content
case "user":
if msg.ToolCallID != "" {
Expand Down Expand Up @@ -289,6 +294,13 @@ func buildCodexParams(
params.Instructions = openai.Opt(defaultCodexInstructions)
}

// Prompt caching: pass a stable cache key so OpenAI can bucket requests
// and reuse prefix KV cache across calls with the same key.
// See: https://platform.openai.com/docs/guides/prompt-caching
if cacheKey, ok := options["prompt_cache_key"].(string); ok && cacheKey != "" {
params.PromptCacheKey = openai.Opt(cacheKey)
}

if len(tools) > 0 || enableWebSearch {
params.Tools = translateToolsForCodex(tools, enableWebSearch)
}
Expand Down
36 changes: 35 additions & 1 deletion pkg/providers/openai_compat/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ func (p *Provider) Chat(

requestBody := map[string]any{
"model": model,
"messages": messages,
"messages": stripSystemParts(messages),
}

if len(tools) > 0 {
Expand Down Expand Up @@ -111,6 +111,14 @@ func (p *Provider) Chat(
}
}

// Prompt caching: pass a stable cache key so OpenAI can bucket requests
// with the same key and reuse prefix KV cache across calls.
// The key is typically the agent ID — stable per agent, shared across requests.
// See: https://platform.openai.com/docs/guides/prompt-caching
if cacheKey, ok := options["prompt_cache_key"].(string); ok && cacheKey != "" {
requestBody["prompt_cache_key"] = cacheKey
}

jsonData, err := json.Marshal(requestBody)
if err != nil {
return nil, fmt.Errorf("failed to marshal request: %w", err)
Expand Down Expand Up @@ -230,6 +238,32 @@ func parseResponse(body []byte) (*LLMResponse, error) {
}, nil
}

// openaiMessage is the wire-format message for OpenAI-compatible APIs.
// It mirrors protocoltypes.Message but omits SystemParts, which is an
// internal field that would be unknown to third-party endpoints.
// NOTE(review): field names and json tags must stay in sync with
// protocoltypes.Message so stripSystemParts remains a lossless copy
// of everything except SystemParts.
type openaiMessage struct {
	Role       string     `json:"role"`
	Content    string     `json:"content"`
	ToolCalls  []ToolCall `json:"tool_calls,omitempty"`
	ToolCallID string     `json:"tool_call_id,omitempty"`
}

// stripSystemParts converts []Message into the wire-format []openaiMessage,
// dropping the internal SystemParts field so it never leaks into the JSON
// payload (some strict OpenAI-compatible endpoints reject unknown fields).
// All other fields are copied through unchanged.
func stripSystemParts(messages []Message) []openaiMessage {
	wire := make([]openaiMessage, 0, len(messages))
	for _, src := range messages {
		wire = append(wire, openaiMessage{
			Role:       src.Role,
			Content:    src.Content,
			ToolCalls:  src.ToolCalls,
			ToolCallID: src.ToolCallID,
		})
	}
	return wire
}

func normalizeModel(model, apiBase string) string {
idx := strings.Index(model, "/")
if idx == -1 {
Expand Down
26 changes: 21 additions & 5 deletions pkg/providers/protocoltypes/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,12 +38,28 @@ type UsageInfo struct {
TotalTokens int `json:"total_tokens"`
}

// CacheControl marks a content block for LLM-side prefix caching.
// Currently only "ephemeral" is supported (used by Anthropic).
type CacheControl struct {
	Type string `json:"type"` // "ephemeral"
}

// ContentBlock represents a structured segment of a system message.
// Adapters that understand SystemParts can use these blocks to set
// per-block cache control (e.g. Anthropic's cache_control: ephemeral).
// Adapters that do not understand SystemParts fall back to the flat
// Content string instead.
type ContentBlock struct {
	Type         string        `json:"type"` // "text"
	Text         string        `json:"text"`
	CacheControl *CacheControl `json:"cache_control,omitempty"` // nil means no cache marker
}

type Message struct {
Role string `json:"role"`
Content string `json:"content"`
ReasoningContent string `json:"reasoning_content,omitempty"`
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
ToolCallID string `json:"tool_call_id,omitempty"`
Role string `json:"role"`
Content string `json:"content"`
ReasoningContent string `json:"reasoning_content,omitempty"`
SystemParts []ContentBlock `json:"system_parts,omitempty"` // structured system blocks for cache-aware adapters
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
ToolCallID string `json:"tool_call_id,omitempty"`
}

type ToolDefinition struct {
Expand Down
2 changes: 2 additions & 0 deletions pkg/providers/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ type (
ToolFunctionDefinition = protocoltypes.ToolFunctionDefinition
ExtraContent = protocoltypes.ExtraContent
GoogleExtra = protocoltypes.GoogleExtra
ContentBlock = protocoltypes.ContentBlock
CacheControl = protocoltypes.CacheControl
)

type LLMProvider interface {
Expand Down
39 changes: 27 additions & 12 deletions pkg/tools/registry.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package tools
import (
"context"
"fmt"
"sort"
"sync"
"time"

Expand Down Expand Up @@ -107,13 +108,27 @@ func (r *ToolRegistry) ExecuteWithContext(
return result
}

// sortedToolNames returns the registered tool names in ascending order for
// deterministic iteration. Stable ordering is critical for KV cache
// reuse: random map iteration would reshuffle system prompts and tool
// definitions on every call, invalidating the LLM's prefix cache even
// when nothing actually changed. Callers are expected to hold r.mu.
func (r *ToolRegistry) sortedToolNames() []string {
	out := make([]string, 0, len(r.tools))
	for name := range r.tools {
		out = append(out, name)
	}
	sort.Strings(out)
	return out
}

func (r *ToolRegistry) GetDefinitions() []map[string]any {
r.mu.RLock()
defer r.mu.RUnlock()

definitions := make([]map[string]any, 0, len(r.tools))
for _, tool := range r.tools {
definitions = append(definitions, ToolToSchema(tool))
sorted := r.sortedToolNames()
definitions := make([]map[string]any, 0, len(sorted))
for _, name := range sorted {
definitions = append(definitions, ToolToSchema(r.tools[name]))
}
return definitions
}
Expand All @@ -124,8 +139,10 @@ func (r *ToolRegistry) ToProviderDefs() []providers.ToolDefinition {
r.mu.RLock()
defer r.mu.RUnlock()

definitions := make([]providers.ToolDefinition, 0, len(r.tools))
for _, tool := range r.tools {
sorted := r.sortedToolNames()
definitions := make([]providers.ToolDefinition, 0, len(sorted))
for _, name := range sorted {
tool := r.tools[name]
schema := ToolToSchema(tool)

// Safely extract nested values with type checks
Expand Down Expand Up @@ -155,11 +172,7 @@ func (r *ToolRegistry) List() []string {
r.mu.RLock()
defer r.mu.RUnlock()

names := make([]string, 0, len(r.tools))
for name := range r.tools {
names = append(names, name)
}
return names
return r.sortedToolNames()
}

// Count returns the number of registered tools.
Expand All @@ -175,8 +188,10 @@ func (r *ToolRegistry) GetSummaries() []string {
r.mu.RLock()
defer r.mu.RUnlock()

summaries := make([]string, 0, len(r.tools))
for _, tool := range r.tools {
sorted := r.sortedToolNames()
summaries := make([]string, 0, len(sorted))
for _, name := range sorted {
tool := r.tools[name]
summaries = append(summaries, fmt.Sprintf("- `%s` - %s", tool.Name(), tool.Description()))
}
return summaries
Expand Down