Skip to content
This repository was archived by the owner on Sep 18, 2025. It is now read-only.

Commit 2de5127

Browse files
committed
initial tool call stream
1 parent 2b5a33e commit 2de5127

File tree

11 files changed

+261
-136
lines changed

11 files changed

+261
-136
lines changed

internal/llm/agent/agent.go

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -380,6 +380,21 @@ func (a *agent) processEvent(ctx context.Context, sessionID string, assistantMsg
380380
case provider.EventContentDelta:
381381
assistantMsg.AppendContent(event.Content)
382382
return a.messages.Update(ctx, *assistantMsg)
383+
case provider.EventToolUseStart:
384+
assistantMsg.AddToolCall(*event.ToolCall)
385+
return a.messages.Update(ctx, *assistantMsg)
386+
// TODO: see how to handle this
387+
// case provider.EventToolUseDelta:
388+
// tm := time.Unix(assistantMsg.UpdatedAt, 0)
389+
// assistantMsg.AppendToolCallInput(event.ToolCall.ID, event.ToolCall.Input)
390+
// if time.Since(tm) > 1000*time.Millisecond {
391+
// err := a.messages.Update(ctx, *assistantMsg)
392+
// assistantMsg.UpdatedAt = time.Now().Unix()
393+
// return err
394+
// }
395+
case provider.EventToolUseStop:
396+
assistantMsg.FinishToolCall(event.ToolCall.ID)
397+
return a.messages.Update(ctx, *assistantMsg)
383398
case provider.EventError:
384399
if errors.Is(event.Error, context.Canceled) {
385400
logging.InfoPersist(fmt.Sprintf("Event processing canceled for session: %s", sessionID))
@@ -456,6 +471,13 @@ func createAgentProvider(agentName config.AgentName) (provider.Provider, error)
456471
provider.WithReasoningEffort(agentConfig.ReasoningEffort),
457472
),
458473
)
474+
} else if model.Provider == models.ProviderAnthropic && model.CanReason {
475+
opts = append(
476+
opts,
477+
provider.WithAnthropicOptions(
478+
provider.WithAnthropicShouldThinkFn(provider.DefaultShouldThinkFn),
479+
),
480+
)
459481
}
460482
agentProvider, err := provider.NewProvider(
461483
model.Provider,

internal/llm/provider/anthropic.go

Lines changed: 46 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -93,8 +93,7 @@ func (a *anthropicClient) convertMessages(messages []message.Message) (anthropic
9393
}
9494

9595
if len(blocks) == 0 {
96-
logging.Warn("There is a message without content, investigate")
97-
// This should never happend but we log this because we might have a bug in our cleanup method
96+
logging.Warn("There is a message without content, investigate, this should not happen")
9897
continue
9998
}
10099
anthropicMessages = append(anthropicMessages, anthropic.NewAssistantMessage(blocks...))
@@ -196,8 +195,8 @@ func (a *anthropicClient) send(ctx context.Context, messages []message.Message,
196195
preparedMessages := a.preparedMessages(a.convertMessages(messages), a.convertTools(tools))
197196
cfg := config.Get()
198197
if cfg.Debug {
199-
jsonData, _ := json.Marshal(preparedMessages)
200-
logging.Debug("Prepared messages", "messages", string(jsonData))
198+
// jsonData, _ := json.Marshal(preparedMessages)
199+
// logging.Debug("Prepared messages", "messages", string(jsonData))
201200
}
202201
attempts := 0
203202
for {
@@ -243,8 +242,8 @@ func (a *anthropicClient) stream(ctx context.Context, messages []message.Message
243242
preparedMessages := a.preparedMessages(a.convertMessages(messages), a.convertTools(tools))
244243
cfg := config.Get()
245244
if cfg.Debug {
246-
jsonData, _ := json.Marshal(preparedMessages)
247-
logging.Debug("Prepared messages", "messages", string(jsonData))
245+
// jsonData, _ := json.Marshal(preparedMessages)
246+
// logging.Debug("Prepared messages", "messages", string(jsonData))
248247
}
249248
attempts := 0
250249
eventChan := make(chan ProviderEvent)
@@ -257,6 +256,7 @@ func (a *anthropicClient) stream(ctx context.Context, messages []message.Message
257256
)
258257
accumulatedMessage := anthropic.Message{}
259258

259+
currentToolCallID := ""
260260
for anthropicStream.Next() {
261261
event := anthropicStream.Current()
262262
err := accumulatedMessage.Accumulate(event)
@@ -267,7 +267,19 @@ func (a *anthropicClient) stream(ctx context.Context, messages []message.Message
267267

268268
switch event := event.AsAny().(type) {
269269
case anthropic.ContentBlockStartEvent:
270-
eventChan <- ProviderEvent{Type: EventContentStart}
270+
if event.ContentBlock.Type == "text" {
271+
eventChan <- ProviderEvent{Type: EventContentStart}
272+
} else if event.ContentBlock.Type == "tool_use" {
273+
currentToolCallID = event.ContentBlock.ID
274+
eventChan <- ProviderEvent{
275+
Type: EventToolUseStart,
276+
ToolCall: &message.ToolCall{
277+
ID: event.ContentBlock.ID,
278+
Name: event.ContentBlock.Name,
279+
Finished: false,
280+
},
281+
}
282+
}
271283

272284
case anthropic.ContentBlockDeltaEvent:
273285
if event.Delta.Type == "thinking_delta" && event.Delta.Thinking != "" {
@@ -280,11 +292,30 @@ func (a *anthropicClient) stream(ctx context.Context, messages []message.Message
280292
Type: EventContentDelta,
281293
Content: event.Delta.Text,
282294
}
295+
} else if event.Delta.Type == "input_json_delta" {
296+
if currentToolCallID != "" {
297+
eventChan <- ProviderEvent{
298+
Type: EventToolUseDelta,
299+
ToolCall: &message.ToolCall{
300+
ID: currentToolCallID,
301+
Finished: false,
302+
Input: event.Delta.JSON.PartialJSON.Raw(),
303+
},
304+
}
305+
}
283306
}
284-
// TODO: check if we can somehow stream tool calls
285-
286307
case anthropic.ContentBlockStopEvent:
287-
eventChan <- ProviderEvent{Type: EventContentStop}
308+
if currentToolCallID != "" {
309+
eventChan <- ProviderEvent{
310+
Type: EventToolUseStop,
311+
ToolCall: &message.ToolCall{
312+
ID: currentToolCallID,
313+
},
314+
}
315+
currentToolCallID = ""
316+
} else {
317+
eventChan <- ProviderEvent{Type: EventContentStop}
318+
}
288319

289320
case anthropic.MessageStopEvent:
290321
content := ""
@@ -378,10 +409,11 @@ func (a *anthropicClient) toolCalls(msg anthropic.Message) []message.ToolCall {
378409
switch variant := block.AsAny().(type) {
379410
case anthropic.ToolUseBlock:
380411
toolCall := message.ToolCall{
381-
ID: variant.ID,
382-
Name: variant.Name,
383-
Input: string(variant.Input),
384-
Type: string(variant.Type),
412+
ID: variant.ID,
413+
Name: variant.Name,
414+
Input: string(variant.Input),
415+
Type: string(variant.Type),
416+
Finished: true,
385417
}
386418
toolCalls = append(toolCalls, toolCall)
387419
}

internal/llm/provider/openai.go

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -344,10 +344,11 @@ func (o *openaiClient) toolCalls(completion openai.ChatCompletion) []message.Too
344344
if len(completion.Choices) > 0 && len(completion.Choices[0].Message.ToolCalls) > 0 {
345345
for _, call := range completion.Choices[0].Message.ToolCalls {
346346
toolCall := message.ToolCall{
347-
ID: call.ID,
348-
Name: call.Function.Name,
349-
Input: call.Function.Arguments,
350-
Type: "function",
347+
ID: call.ID,
348+
Name: call.Function.Name,
349+
Input: call.Function.Arguments,
350+
Type: "function",
351+
Finished: true,
351352
}
352353
toolCalls = append(toolCalls, toolCall)
353354
}

internal/llm/provider/provider.go

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,9 @@ const maxRetries = 8
1515

1616
const (
1717
EventContentStart EventType = "content_start"
18+
EventToolUseStart EventType = "tool_use_start"
19+
EventToolUseDelta EventType = "tool_use_delta"
20+
EventToolUseStop EventType = "tool_use_stop"
1821
EventContentDelta EventType = "content_delta"
1922
EventThinkingDelta EventType = "thinking_delta"
2023
EventContentStop EventType = "content_stop"
@@ -43,8 +46,8 @@ type ProviderEvent struct {
4346
Content string
4447
Thinking string
4548
Response *ProviderResponse
46-
47-
Error error
49+
ToolCall *message.ToolCall
50+
Error error
4851
}
4952
type Provider interface {
5053
SendMessages(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error)

internal/message/content.go

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -233,6 +233,40 @@ func (m *Message) AppendReasoningContent(delta string) {
233233
}
234234
}
235235

236+
func (m *Message) FinishToolCall(toolCallID string) {
237+
for i, part := range m.Parts {
238+
if c, ok := part.(ToolCall); ok {
239+
if c.ID == toolCallID {
240+
m.Parts[i] = ToolCall{
241+
ID: c.ID,
242+
Name: c.Name,
243+
Input: c.Input,
244+
Type: c.Type,
245+
Finished: true,
246+
}
247+
return
248+
}
249+
}
250+
}
251+
}
252+
253+
func (m *Message) AppendToolCallInput(toolCallID string, inputDelta string) {
254+
for i, part := range m.Parts {
255+
if c, ok := part.(ToolCall); ok {
256+
if c.ID == toolCallID {
257+
m.Parts[i] = ToolCall{
258+
ID: c.ID,
259+
Name: c.Name,
260+
Input: c.Input + inputDelta,
261+
Type: c.Type,
262+
Finished: c.Finished,
263+
}
264+
return
265+
}
266+
}
267+
}
268+
}
269+
236270
func (m *Message) AddToolCall(tc ToolCall) {
237271
for i, part := range m.Parts {
238272
if c, ok := part.(ToolCall); ok {
@@ -246,6 +280,15 @@ func (m *Message) AddToolCall(tc ToolCall) {
246280
}
247281

248282
func (m *Message) SetToolCalls(tc []ToolCall) {
283+
// remove any existing tool call part it could have multiple
284+
parts := make([]ContentPart, 0)
285+
for _, part := range m.Parts {
286+
if _, ok := part.(ToolCall); ok {
287+
continue
288+
}
289+
parts = append(parts, part)
290+
}
291+
m.Parts = parts
249292
for _, toolCall := range tc {
250293
m.Parts = append(m.Parts, toolCall)
251294
}

internal/message/message.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ import (
55
"database/sql"
66
"encoding/json"
77
"fmt"
8+
"time"
89

910
"github.com/google/uuid"
1011
"github.com/kujtimiihoxha/opencode/internal/db"
@@ -116,6 +117,7 @@ func (s *service) Update(ctx context.Context, message Message) error {
116117
if err != nil {
117118
return err
118119
}
120+
message.UpdatedAt = time.Now().Unix()
119121
s.Publish(pubsub.UpdatedEvent, message)
120122
return nil
121123
}

internal/pubsub/broker.go

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -7,13 +7,6 @@ import (
77

88
const bufferSize = 1024
99

10-
type Logger interface {
11-
Debug(msg string, args ...any)
12-
Info(msg string, args ...any)
13-
Warn(msg string, args ...any)
14-
Error(msg string, args ...any)
15-
}
16-
1710
// Broker allows clients to publish events and subscribe to events
1811
type Broker[T any] struct {
1912
subs map[chan Event[T]]struct{} // subscriptions

0 commit comments

Comments
 (0)