Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
38 commits
Select commit Hold shift + click to select a range
b429261
refactor(mcp): use the new mcp library
caarlos0 Oct 9, 2025
b6effb8
fix: some cleanups
caarlos0 Oct 9, 2025
c129a1c
fix: keep alive
caarlos0 Oct 9, 2025
c325df9
feat(mcp): support prompts
caarlos0 Oct 9, 2025
726dac8
fix: lint
caarlos0 Oct 9, 2025
42e87b6
fix: lint
caarlos0 Oct 9, 2025
3432ae4
Merge remote-tracking branch 'origin/main' into mcp
caarlos0 Oct 9, 2025
ee7b248
Merge branch 'mcp' into prompts
caarlos0 Oct 9, 2025
c60d0c1
refactor: improvements
caarlos0 Oct 9, 2025
c162a9e
refactor: cleanup
caarlos0 Oct 9, 2025
3362d6b
fix(ui): improvements
caarlos0 Oct 9, 2025
0a16068
refactor: reuse same dialog
caarlos0 Oct 9, 2025
3bf3067
refactor: cleanup
caarlos0 Oct 9, 2025
b68c5bb
refactor: clean
caarlos0 Oct 9, 2025
875f38c
refactor: more cleanup
caarlos0 Oct 9, 2025
ed2f608
chore: smaller diff
caarlos0 Oct 9, 2025
7b2f166
fixup! chore: smaller diff
caarlos0 Oct 9, 2025
64f50e2
fix: improve submit
caarlos0 Oct 9, 2025
7c15017
fixup! fix: improve submit
caarlos0 Oct 9, 2025
d856d29
fix: dont oversubmit
caarlos0 Oct 9, 2025
40bf11d
Merge remote-tracking branch 'origin/main' into mcp
caarlos0 Oct 10, 2025
fd2d966
fix(mcp): improve error handling
caarlos0 Oct 10, 2025
57f9208
Merge branch 'mcp' into prompts
caarlos0 Oct 10, 2025
c48ebad
chore: run with TEA_DEBUG
caarlos0 Oct 10, 2025
c9400eb
chore: taskfile
caarlos0 Oct 10, 2025
94d4570
fix: panic
caarlos0 Oct 10, 2025
fbbdfed
Merge remote-tracking branch 'origin/main' into prompts
caarlos0 Oct 10, 2025
dc58091
chore: revert taskfile change
caarlos0 Oct 10, 2025
82b6f8f
fix(cmds): small improvements in user commands
caarlos0 Oct 10, 2025
ea34b81
fixup! fix(cmds): small improvements in user commands
caarlos0 Oct 10, 2025
412008f
fixup! fixup! fix(cmds): small improvements in user commands
caarlos0 Oct 10, 2025
eccaa13
fixup! fixup! fixup! fix(cmds): small improvements in user commands
caarlos0 Oct 10, 2025
d04f053
Merge remote-tracking branch 'origin/main' into user-cmds
caarlos0 Oct 15, 2025
422762c
Merge remote-tracking branch 'origin/main' into user-cmds
caarlos0 Oct 15, 2025
a479e9f
Merge branch 'main' into prompts
tauraamui Oct 16, 2025
c6bc8a0
fix: use renamed fields for focused element id
tauraamui Oct 16, 2025
9e7e1d8
Merge remote-tracking branch 'origin/user-cmds' into prompts
caarlos0 Oct 17, 2025
34fd531
Merge remote-tracking branch 'origin/main' into prompts
caarlos0 Oct 17, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 20 additions & 8 deletions internal/llm/agent/agent.go
Original file line number Diff line number Diff line change
Expand Up @@ -1101,28 +1101,40 @@ func (a *agent) setupEvents(ctx context.Context) {
slog.Debug("MCPEvents subscription channel closed")
return
}
name := event.Payload.Name
c, ok := mcpClients.Get(name)
if !ok {
slog.Warn("MCP client not found for tools/prompts update", "name", name)
continue
}
switch event.Payload.Type {
case MCPEventToolsListChanged:
name := event.Payload.Name
c, ok := mcpClients.Get(name)
if !ok {
slog.Warn("MCP client not found for tools update", "name", name)
continue
}
cfg := config.Get()
tools, err := getTools(ctx, name, a.permissions, c, cfg.WorkingDir())
if err != nil {
slog.Error("error listing tools", "error", err)
updateMCPState(name, MCPStateError, err, nil, 0)
updateMCPState(name, MCPStateError, err, nil, MCPCounts{})
_ = c.Close()
continue
}
updateMcpTools(name, tools)
a.mcpTools.Reset(maps.Collect(mcpTools.Seq2()))
updateMCPState(name, MCPStateConnected, nil, c, a.mcpTools.Len())
case MCPEventPromptsListChanged:
prompts, err := getPrompts(ctx, c)
if err != nil {
slog.Error("error listing prompts", "error", err)
updateMCPState(name, MCPStateError, err, nil, MCPCounts{})
_ = c.Close()
continue
}
updateMcpPrompts(name, prompts)
default:
continue
}
updateMCPState(name, MCPStateConnected, nil, c, MCPCounts{
Tools: mcpTools.Len(),
Prompts: mcpPrompts.Len(),
})
case <-ctx.Done():
slog.Debug("MCPEvents subscription cancelled")
return
Expand Down
194 changes: 135 additions & 59 deletions internal/llm/agent/mcp-tools.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,17 +55,23 @@ func (s MCPState) String() string {
type MCPEventType string

const (
MCPEventStateChanged MCPEventType = "state_changed"
MCPEventToolsListChanged MCPEventType = "tools_list_changed"
MCPEventToolsListChanged MCPEventType = "tools_list_changed"
MCPEventPromptsListChanged MCPEventType = "prompts_list_changed"
)

// MCPEvent represents an event in the MCP system
type MCPEvent struct {
Type MCPEventType
Name string
State MCPState
Error error
ToolCount int
Type MCPEventType
Name string
State MCPState
Error error
Counts MCPCounts
}

// MCPCounts number of available tools, prompts, etc.
type MCPCounts struct {
Tools int
Prompts int
}

// MCPClientInfo holds information about an MCP client's state
Expand All @@ -74,17 +80,19 @@ type MCPClientInfo struct {
State MCPState
Error error
Client *mcp.ClientSession
ToolCount int
Counts MCPCounts
ConnectedAt time.Time
}

var (
mcpToolsOnce sync.Once
mcpTools = csync.NewMap[string, tools.BaseTool]()
mcpClient2Tools = csync.NewMap[string, []tools.BaseTool]()
mcpClients = csync.NewMap[string, *mcp.ClientSession]()
mcpStates = csync.NewMap[string, MCPClientInfo]()
mcpBroker = pubsub.NewBroker[MCPEvent]()
mcpToolsOnce sync.Once
mcpTools = csync.NewMap[string, tools.BaseTool]()
mcpClient2Tools = csync.NewMap[string, []tools.BaseTool]()
mcpClients = csync.NewMap[string, *mcp.ClientSession]()
mcpStates = csync.NewMap[string, MCPClientInfo]()
mcpBroker = pubsub.NewBroker[MCPEvent]()
mcpPrompts = csync.NewMap[string, *mcp.Prompt]()
mcpClient2Prompts = csync.NewMap[string, []*mcp.Prompt]()
)

type McpTool struct {
Expand Down Expand Up @@ -173,14 +181,14 @@ func getOrRenewClient(ctx context.Context, name string) (*mcp.ClientSession, err
if err == nil {
return sess, nil
}
updateMCPState(name, MCPStateError, maybeTimeoutErr(err, timeout), nil, state.ToolCount)
updateMCPState(name, MCPStateError, maybeTimeoutErr(err, timeout), nil, state.Counts)

sess, err = createMCPSession(ctx, name, m, cfg.Resolver())
if err != nil {
return nil, err
}

updateMCPState(name, MCPStateConnected, nil, sess, state.ToolCount)
updateMCPState(name, MCPStateConnected, nil, sess, state.Counts)
mcpClients.Set(name, sess)
return sess, nil
}
Expand Down Expand Up @@ -210,6 +218,9 @@ func (b *McpTool) Run(ctx context.Context, params tools.ToolCall) (tools.ToolRes
}

func getTools(ctx context.Context, name string, permissions permission.Service, c *mcp.ClientSession, workingDir string) ([]tools.BaseTool, error) {
if c.InitializeResult().Capabilities.Tools == nil {
return nil, nil
}
result, err := c.ListTools(ctx, &mcp.ListToolsParams{})
if err != nil {
return nil, err
Expand Down Expand Up @@ -242,31 +253,23 @@ func GetMCPState(name string) (MCPClientInfo, bool) {
}

// updateMCPState updates the state of an MCP client and publishes an event
func updateMCPState(name string, state MCPState, err error, client *mcp.ClientSession, toolCount int) {
func updateMCPState(name string, state MCPState, err error, client *mcp.ClientSession, counts MCPCounts) {
info := MCPClientInfo{
Name: name,
State: state,
Error: err,
Client: client,
ToolCount: toolCount,
Name: name,
State: state,
Error: err,
Client: client,
Counts: counts,
}
switch state {
case MCPStateConnected:
info.ConnectedAt = time.Now()
case MCPStateError:
updateMcpTools(name, nil)
updateMcpPrompts(name, nil)
mcpClients.Del(name)
}
mcpStates.Set(name, info)

// Publish state change event
mcpBroker.Publish(pubsub.UpdatedEvent, MCPEvent{
Type: MCPEventStateChanged,
Name: name,
State: state,
Error: err,
ToolCount: toolCount,
})
}

// CloseMCPClients closes all MCP clients. This should be called during application shutdown.
Expand All @@ -289,13 +292,13 @@ func doGetMCPTools(ctx context.Context, permissions permission.Service, cfg *con
// Initialize states for all configured MCPs
for name, m := range cfg.MCP {
if m.Disabled {
updateMCPState(name, MCPStateDisabled, nil, nil, 0)
updateMCPState(name, MCPStateDisabled, nil, nil, MCPCounts{})
slog.Debug("skipping disabled mcp", "name", name)
continue
}

// Set initial starting state
updateMCPState(name, MCPStateStarting, nil, nil, 0)
updateMCPState(name, MCPStateStarting, nil, nil, MCPCounts{})

wg.Add(1)
go func(name string, m config.MCPConfig) {
Expand All @@ -311,7 +314,7 @@ func doGetMCPTools(ctx context.Context, permissions permission.Service, cfg *con
default:
err = fmt.Errorf("panic: %v", v)
}
updateMCPState(name, MCPStateError, err, nil, 0)
updateMCPState(name, MCPStateError, err, nil, MCPCounts{})
slog.Error("panic in mcp client initialization", "error", err, "name", name)
}
}()
Expand All @@ -329,14 +332,27 @@ func doGetMCPTools(ctx context.Context, permissions permission.Service, cfg *con
tools, err := getTools(ctx, name, permissions, c, cfg.WorkingDir())
if err != nil {
slog.Error("error listing tools", "error", err)
updateMCPState(name, MCPStateError, err, nil, 0)
updateMCPState(name, MCPStateError, err, nil, MCPCounts{})
c.Close()
return
}

prompts, err := getPrompts(ctx, c)
if err != nil {
slog.Error("error listing prompts", "error", err)
updateMCPState(name, MCPStateError, err, nil, MCPCounts{})
c.Close()
return
}

updateMcpTools(name, tools)
updateMcpPrompts(name, prompts)
mcpClients.Set(name, c)
updateMCPState(name, MCPStateConnected, nil, c, len(tools))
counts := MCPCounts{
Tools: len(tools),
Prompts: len(prompts),
}
updateMCPState(name, MCPStateConnected, nil, c, counts)
}(name, m)
}
wg.Wait()
Expand All @@ -363,7 +379,7 @@ func createMCPSession(ctx context.Context, name string, m config.MCPConfig, reso

transport, err := createMCPTransport(mcpCtx, m, resolver)
if err != nil {
updateMCPState(name, MCPStateError, err, nil, 0)
updateMCPState(name, MCPStateError, err, nil, MCPCounts{})
slog.Error("error creating mcp client", "error", err, "name", name)
cancel()
cancelTimer.Stop()
Expand All @@ -383,14 +399,20 @@ func createMCPSession(ctx context.Context, name string, m config.MCPConfig, reso
Name: name,
})
},
PromptListChangedHandler: func(context.Context, *mcp.PromptListChangedRequest) {
mcpBroker.Publish(pubsub.UpdatedEvent, MCPEvent{
Type: MCPEventPromptsListChanged,
Name: name,
})
},
KeepAlive: time.Minute * 10,
},
)

session, err := client.Connect(mcpCtx, transport, nil)
if err != nil {
err = maybeStdioErr(err, transport)
updateMCPState(name, MCPStateError, maybeTimeoutErr(err, timeout), nil, 0)
updateMCPState(name, MCPStateError, maybeTimeoutErr(err, timeout), nil, MCPCounts{})
slog.Error("error starting mcp client", "error", err, "name", name)
cancel()
cancelTimer.Stop()
Expand All @@ -402,27 +424,6 @@ func createMCPSession(ctx context.Context, name string, m config.MCPConfig, reso
return session, nil
}

// maybeStdioErr if a stdio mcp prints an error in non-json format, it'll fail
// to parse, and the cli will then close it, causing the EOF error.
// so, if we got an EOF err, and the transport is STDIO, we try to exec it
// again with a timeout and collect the output so we can add details to the
// error.
// this happens particularly when starting things with npx, e.g. if node can't
// be found or some other error like that.
func maybeStdioErr(err error, transport mcp.Transport) error {
if !errors.Is(err, io.EOF) {
return err
}
ct, ok := transport.(*mcp.CommandTransport)
if !ok {
return err
}
if err2 := stdioMCPCheck(ct.Command); err2 != nil {
err = errors.Join(err, err2)
}
return err
}

func maybeTimeoutErr(err error, timeout time.Duration) error {
if errors.Is(err, context.Canceled) {
return fmt.Errorf("timed out after %s", timeout)
Expand Down Expand Up @@ -491,6 +492,81 @@ func mcpTimeout(m config.MCPConfig) time.Duration {
return time.Duration(cmp.Or(m.Timeout, 15)) * time.Second
}

func getPrompts(ctx context.Context, c *mcp.ClientSession) ([]*mcp.Prompt, error) {
if c.InitializeResult().Capabilities.Prompts == nil {
return nil, nil
}
result, err := c.ListPrompts(ctx, &mcp.ListPromptsParams{})
if err != nil {
return nil, err
}
return result.Prompts, nil
}

// updateMcpPrompts updates the global mcpPrompts and mcpClient2Prompts maps.
func updateMcpPrompts(mcpName string, prompts []*mcp.Prompt) {
if len(prompts) == 0 {
mcpClient2Prompts.Del(mcpName)
} else {
mcpClient2Prompts.Set(mcpName, prompts)
}
for clientName, prompts := range mcpClient2Prompts.Seq2() {
for _, p := range prompts {
key := clientName + ":" + p.Name
mcpPrompts.Set(key, p)
}
}
}

// GetMCPPrompts returns all available MCP prompts.
func GetMCPPrompts() map[string]*mcp.Prompt {
return maps.Collect(mcpPrompts.Seq2())
}

// GetMCPPrompt returns a specific MCP prompt by name.
func GetMCPPrompt(name string) (*mcp.Prompt, bool) {
return mcpPrompts.Get(name)
}

// GetMCPPromptsByClient returns all prompts for a specific MCP client.
func GetMCPPromptsByClient(clientName string) ([]*mcp.Prompt, bool) {
return mcpClient2Prompts.Get(clientName)
}

// GetMCPPromptContent retrieves the content of an MCP prompt with the given arguments.
func GetMCPPromptContent(ctx context.Context, clientName, promptName string, args map[string]string) (*mcp.GetPromptResult, error) {
c, err := getOrRenewClient(ctx, clientName)
if err != nil {
return nil, err
}

return c.GetPrompt(ctx, &mcp.GetPromptParams{
Name: promptName,
Arguments: args,
})
}

// maybeStdioErr if a stdio mcp prints an error in non-json format, it'll fail
// to parse, and the cli will then close it, causing the EOF error.
// so, if we got an EOF err, and the transport is STDIO, we try to exec it
// again with a timeout and collect the output so we can add details to the
// error.
// this happens particularly when starting things with npx, e.g. if node can't
// be found or some other error like that.
func maybeStdioErr(err error, transport mcp.Transport) error {
if !errors.Is(err, io.EOF) {
return err
}
ct, ok := transport.(*mcp.CommandTransport)
if !ok {
return err
}
if err2 := stdioMCPCheck(ct.Command); err2 != nil {
err = errors.Join(err, err2)
}
return err
}

func stdioMCPCheck(old *exec.Cmd) error {
ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
defer cancel()
Expand Down
Loading
Loading