diff --git a/internal/llm/agent/agent.go b/internal/llm/agent/agent.go index b2b222db1..224d2850c 100644 --- a/internal/llm/agent/agent.go +++ b/internal/llm/agent/agent.go @@ -1101,28 +1101,40 @@ func (a *agent) setupEvents(ctx context.Context) { slog.Debug("MCPEvents subscription channel closed") return } + name := event.Payload.Name + c, ok := mcpClients.Get(name) + if !ok { + slog.Warn("MCP client not found for tools/prompts update", "name", name) + continue + } switch event.Payload.Type { case MCPEventToolsListChanged: - name := event.Payload.Name - c, ok := mcpClients.Get(name) - if !ok { - slog.Warn("MCP client not found for tools update", "name", name) - continue - } cfg := config.Get() tools, err := getTools(ctx, name, a.permissions, c, cfg.WorkingDir()) if err != nil { slog.Error("error listing tools", "error", err) - updateMCPState(name, MCPStateError, err, nil, 0) + updateMCPState(name, MCPStateError, err, nil, MCPCounts{}) _ = c.Close() continue } updateMcpTools(name, tools) a.mcpTools.Reset(maps.Collect(mcpTools.Seq2())) - updateMCPState(name, MCPStateConnected, nil, c, a.mcpTools.Len()) + case MCPEventPromptsListChanged: + prompts, err := getPrompts(ctx, c) + if err != nil { + slog.Error("error listing prompts", "error", err) + updateMCPState(name, MCPStateError, err, nil, MCPCounts{}) + _ = c.Close() + continue + } + updateMcpPrompts(name, prompts) default: continue } + updateMCPState(name, MCPStateConnected, nil, c, MCPCounts{ + Tools: mcpTools.Len(), + Prompts: mcpPrompts.Len(), + }) case <-ctx.Done(): slog.Debug("MCPEvents subscription cancelled") return diff --git a/internal/llm/agent/mcp-tools.go b/internal/llm/agent/mcp-tools.go index 6838c54ab..a1077623a 100644 --- a/internal/llm/agent/mcp-tools.go +++ b/internal/llm/agent/mcp-tools.go @@ -55,17 +55,23 @@ func (s MCPState) String() string { type MCPEventType string const ( - MCPEventStateChanged MCPEventType = "state_changed" - MCPEventToolsListChanged MCPEventType = "tools_list_changed" + MCPEventToolsListChanged MCPEventType = "tools_list_changed" + MCPEventPromptsListChanged MCPEventType = "prompts_list_changed" ) // MCPEvent represents an event in the MCP system type MCPEvent struct { - Type MCPEventType - Name string - State MCPState - Error error - ToolCount int + Type MCPEventType + Name string + State MCPState + Error error + Counts MCPCounts +} + +// MCPCounts number of available tools, prompts, etc. +type MCPCounts struct { + Tools int + Prompts int } // MCPClientInfo holds information about an MCP client's state @@ -74,17 +80,19 @@ type MCPClientInfo struct { State MCPState Error error Client *mcp.ClientSession - ToolCount int + Counts MCPCounts ConnectedAt time.Time } var ( - mcpToolsOnce sync.Once - mcpTools = csync.NewMap[string, tools.BaseTool]() - mcpClient2Tools = csync.NewMap[string, []tools.BaseTool]() - mcpClients = csync.NewMap[string, *mcp.ClientSession]() - mcpStates = csync.NewMap[string, MCPClientInfo]() - mcpBroker = pubsub.NewBroker[MCPEvent]() + mcpToolsOnce sync.Once + mcpTools = csync.NewMap[string, tools.BaseTool]() + mcpClient2Tools = csync.NewMap[string, []tools.BaseTool]() + mcpClients = csync.NewMap[string, *mcp.ClientSession]() + mcpStates = csync.NewMap[string, MCPClientInfo]() + mcpBroker = pubsub.NewBroker[MCPEvent]() + mcpPrompts = csync.NewMap[string, *mcp.Prompt]() + mcpClient2Prompts = csync.NewMap[string, []*mcp.Prompt]() ) type McpTool struct { @@ -173,14 +181,14 @@ func getOrRenewClient(ctx context.Context, name string) (*mcp.ClientSession, err if err == nil { return sess, nil } - updateMCPState(name, MCPStateError, maybeTimeoutErr(err, timeout), nil, state.ToolCount) + updateMCPState(name, MCPStateError, maybeTimeoutErr(err, timeout), nil, state.Counts) sess, err = createMCPSession(ctx, name, m, cfg.Resolver()) if err != nil { return nil, err } - updateMCPState(name, MCPStateConnected, nil, sess, state.ToolCount) + updateMCPState(name, MCPStateConnected, nil, sess, state.Counts) mcpClients.Set(name, sess) return sess, nil } @@ -210,6 +218,9 @@ func (b *McpTool) Run(ctx context.Context, params tools.ToolCall) (tools.ToolRes } func getTools(ctx context.Context, name string, permissions permission.Service, c *mcp.ClientSession, workingDir string) ([]tools.BaseTool, error) { + if c.InitializeResult().Capabilities.Tools == nil { + return nil, nil + } result, err := c.ListTools(ctx, &mcp.ListToolsParams{}) if err != nil { return nil, err @@ -242,31 +253,23 @@ func GetMCPState(name string) (MCPClientInfo, bool) { } // updateMCPState updates the state of an MCP client and publishes an event -func updateMCPState(name string, state MCPState, err error, client *mcp.ClientSession, toolCount int) { +func updateMCPState(name string, state MCPState, err error, client *mcp.ClientSession, counts MCPCounts) { info := MCPClientInfo{ - Name: name, - State: state, - Error: err, - Client: client, - ToolCount: toolCount, + Name: name, + State: state, + Error: err, + Client: client, + Counts: counts, } switch state { case MCPStateConnected: info.ConnectedAt = time.Now() case MCPStateError: updateMcpTools(name, nil) + updateMcpPrompts(name, nil) mcpClients.Del(name) } mcpStates.Set(name, info) - - // Publish state change event - mcpBroker.Publish(pubsub.UpdatedEvent, MCPEvent{ - Type: MCPEventStateChanged, - Name: name, - State: state, - Error: err, - ToolCount: toolCount, - }) } // CloseMCPClients closes all MCP clients. This should be called during application shutdown. @@ -289,13 +292,13 @@ func doGetMCPTools(ctx context.Context, permissions permission.Service, cfg *con // Initialize states for all configured MCPs for name, m := range cfg.MCP { if m.Disabled { - updateMCPState(name, MCPStateDisabled, nil, nil, 0) + updateMCPState(name, MCPStateDisabled, nil, nil, MCPCounts{}) slog.Debug("skipping disabled mcp", "name", name) continue } // Set initial starting state - updateMCPState(name, MCPStateStarting, nil, nil, 0) + updateMCPState(name, MCPStateStarting, nil, nil, MCPCounts{}) wg.Add(1) go func(name string, m config.MCPConfig) { @@ -311,7 +314,7 @@ func doGetMCPTools(ctx context.Context, permissions permission.Service, cfg *con default: err = fmt.Errorf("panic: %v", v) } - updateMCPState(name, MCPStateError, err, nil, 0) + updateMCPState(name, MCPStateError, err, nil, MCPCounts{}) slog.Error("panic in mcp client initialization", "error", err, "name", name) } }() @@ -329,14 +332,27 @@ func doGetMCPTools(ctx context.Context, permissions permission.Service, cfg *con tools, err := getTools(ctx, name, permissions, c, cfg.WorkingDir()) if err != nil { slog.Error("error listing tools", "error", err) - updateMCPState(name, MCPStateError, err, nil, 0) + updateMCPState(name, MCPStateError, err, nil, MCPCounts{}) + c.Close() + return + } + + prompts, err := getPrompts(ctx, c) + if err != nil { + slog.Error("error listing prompts", "error", err) + updateMCPState(name, MCPStateError, err, nil, MCPCounts{}) c.Close() return } updateMcpTools(name, tools) + updateMcpPrompts(name, prompts) mcpClients.Set(name, c) - updateMCPState(name, MCPStateConnected, nil, c, len(tools)) + counts := MCPCounts{ + Tools: len(tools), + Prompts: len(prompts), + } + updateMCPState(name, MCPStateConnected, nil, c, counts) }(name, m) } wg.Wait() @@ -363,7 +379,7 @@ func createMCPSession(ctx context.Context, name string, m config.MCPConfig, reso transport, err := createMCPTransport(mcpCtx, m, resolver) if err != nil { - updateMCPState(name, MCPStateError, err, nil, 0) + updateMCPState(name, MCPStateError, err, nil, MCPCounts{}) slog.Error("error creating mcp client", "error", err, "name", name) cancel() cancelTimer.Stop() @@ -383,6 +399,12 @@ func createMCPSession(ctx context.Context, name string, m config.MCPConfig, reso Name: name, }) }, + PromptListChangedHandler: func(context.Context, *mcp.PromptListChangedRequest) { + mcpBroker.Publish(pubsub.UpdatedEvent, MCPEvent{ + Type: MCPEventPromptsListChanged, + Name: name, + }) + }, KeepAlive: time.Minute * 10, }, ) @@ -390,7 +412,7 @@ func createMCPSession(ctx context.Context, name string, m config.MCPConfig, reso session, err := client.Connect(mcpCtx, transport, nil) if err != nil { err = maybeStdioErr(err, transport) - updateMCPState(name, MCPStateError, maybeTimeoutErr(err, timeout), nil, 0) + updateMCPState(name, MCPStateError, maybeTimeoutErr(err, timeout), nil, MCPCounts{}) slog.Error("error starting mcp client", "error", err, "name", name) cancel() cancelTimer.Stop() @@ -402,27 +424,6 @@ func createMCPSession(ctx context.Context, name string, m config.MCPConfig, reso return session, nil } -// maybeStdioErr if a stdio mcp prints an error in non-json format, it'll fail -// to parse, and the cli will then close it, causing the EOF error. -// so, if we got an EOF err, and the transport is STDIO, we try to exec it -// again with a timeout and collect the output so we can add details to the -// error. -// this happens particularly when starting things with npx, e.g. if node can't -// be found or some other error like that. -func maybeStdioErr(err error, transport mcp.Transport) error { - if !errors.Is(err, io.EOF) { - return err - } - ct, ok := transport.(*mcp.CommandTransport) - if !ok { - return err - } - if err2 := stdioMCPCheck(ct.Command); err2 != nil { - err = errors.Join(err, err2) - } - return err -} - func maybeTimeoutErr(err error, timeout time.Duration) error { if errors.Is(err, context.Canceled) { return fmt.Errorf("timed out after %s", timeout) @@ -491,6 +492,81 @@ func mcpTimeout(m config.MCPConfig) time.Duration { return time.Duration(cmp.Or(m.Timeout, 15)) * time.Second } +func getPrompts(ctx context.Context, c *mcp.ClientSession) ([]*mcp.Prompt, error) { + if c.InitializeResult().Capabilities.Prompts == nil { + return nil, nil + } + result, err := c.ListPrompts(ctx, &mcp.ListPromptsParams{}) + if err != nil { + return nil, err + } + return result.Prompts, nil +} + +// updateMcpPrompts updates the global mcpPrompts and mcpClient2Prompts maps. +func updateMcpPrompts(mcpName string, prompts []*mcp.Prompt) { + if len(prompts) == 0 { + mcpClient2Prompts.Del(mcpName) + } else { + mcpClient2Prompts.Set(mcpName, prompts) + } + for clientName, prompts := range mcpClient2Prompts.Seq2() { + for _, p := range prompts { + key := clientName + ":" + p.Name + mcpPrompts.Set(key, p) + } + } +} + +// GetMCPPrompts returns all available MCP prompts. +func GetMCPPrompts() map[string]*mcp.Prompt { + return maps.Collect(mcpPrompts.Seq2()) +} + +// GetMCPPrompt returns a specific MCP prompt by name. +func GetMCPPrompt(name string) (*mcp.Prompt, bool) { + return mcpPrompts.Get(name) +} + +// GetMCPPromptsByClient returns all prompts for a specific MCP client. +func GetMCPPromptsByClient(clientName string) ([]*mcp.Prompt, bool) { + return mcpClient2Prompts.Get(clientName) +} + +// GetMCPPromptContent retrieves the content of an MCP prompt with the given arguments. +func GetMCPPromptContent(ctx context.Context, clientName, promptName string, args map[string]string) (*mcp.GetPromptResult, error) { + c, err := getOrRenewClient(ctx, clientName) + if err != nil { + return nil, err + } + + return c.GetPrompt(ctx, &mcp.GetPromptParams{ + Name: promptName, + Arguments: args, + }) +} + +// maybeStdioErr if a stdio mcp prints an error in non-json format, it'll fail +// to parse, and the cli will then close it, causing the EOF error. +// so, if we got an EOF err, and the transport is STDIO, we try to exec it +// again with a timeout and collect the output so we can add details to the +// error. +// this happens particularly when starting things with npx, e.g. if node can't +// be found or some other error like that. +func maybeStdioErr(err error, transport mcp.Transport) error { + if !errors.Is(err, io.EOF) { + return err + } + ct, ok := transport.(*mcp.CommandTransport) + if !ok { + return err + } + if err2 := stdioMCPCheck(ct.Command); err2 != nil { + err = errors.Join(err, err2) + } + return err +} + func stdioMCPCheck(old *exec.Cmd) error { ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) defer cancel() diff --git a/internal/tui/components/dialogs/commands/arguments.go b/internal/tui/components/dialogs/commands/arguments.go index 72677bc93..26ca00b74 100644 --- a/internal/tui/components/dialogs/commands/arguments.go +++ b/internal/tui/components/dialogs/commands/arguments.go @@ -1,8 +1,7 @@ package commands import ( - "fmt" - "strings" + "cmp" "github.com/charmbracelet/bubbles/v2/help" "github.com/charmbracelet/bubbles/v2/key" @@ -20,9 +19,10 @@ const ( // ShowArgumentsDialogMsg is a message that is sent to show the arguments dialog. type ShowArgumentsDialogMsg struct { - CommandID string - Content string - ArgNames []string + CommandID string + Description string + ArgNames []string + OnSubmit func(args map[string]string) tea.Cmd } // CloseArgumentsDialogMsg is a message that is sent when the arguments dialog is closed. @@ -39,26 +39,39 @@ type CommandArgumentsDialog interface { } type commandArgumentsDialogCmp struct { - width int - wWidth int // Width of the terminal window - wHeight int // Height of the terminal window - - inputs []textinput.Model - focusIndex int - keys ArgumentsDialogKeyMap - commandID string - content string - argNames []string - help help.Model + wWidth, wHeight int + width, height int + + inputs []textinput.Model + focused int + keys ArgumentsDialogKeyMap + arguments []Argument + help help.Model + + id string + title string + name string + description string + + onSubmit func(args map[string]string) tea.Cmd +} + +type Argument struct { + Name, Title, Description string + Required bool } -func NewCommandArgumentsDialog(commandID, content string, argNames []string) CommandArgumentsDialog { +func NewCommandArgumentsDialog( + id, title, name, description string, + arguments []Argument, + onSubmit func(args map[string]string) tea.Cmd, +) CommandArgumentsDialog { t := styles.CurrentTheme() - inputs := make([]textinput.Model, len(argNames)) + inputs := make([]textinput.Model, len(arguments)) - for i, name := range argNames { + for i, arg := range arguments { ti := textinput.New() - ti.Placeholder = fmt.Sprintf("Enter value for %s...", name) + ti.Placeholder = cmp.Or(arg.Description, "Enter value for "+arg.Title) ti.SetWidth(40) ti.SetVirtualCursor(false) ti.Prompt = "" @@ -75,14 +88,16 @@ func NewCommandArgumentsDialog(commandID, content string, argNames []string) Com } return &commandArgumentsDialogCmp{ - inputs: inputs, - keys: DefaultArgumentsDialogKeyMap(), - commandID: commandID, - content: content, - argNames: argNames, - focusIndex: 0, - width: 60, - help: help.New(), + inputs: inputs, + keys: DefaultArgumentsDialogKeyMap(), + id: id, + name: name, + title: title, + description: description, + arguments: arguments, + width: 60, + help: help.New(), + onSubmit: onSubmit, } } @@ -97,47 +112,51 @@ func (c *commandArgumentsDialogCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) { case tea.WindowSizeMsg: c.wWidth = msg.Width c.wHeight = msg.Height + c.width = min(90, c.wWidth) + c.height = min(15, c.wHeight) + for i := range c.inputs { + c.inputs[i].SetWidth(c.width - (paddingHorizontal * 2)) + } case tea.KeyPressMsg: switch { + case key.Matches(msg, c.keys.Close): + return c, util.CmdHandler(dialogs.CloseDialogMsg{}) case key.Matches(msg, c.keys.Confirm): - if c.focusIndex == len(c.inputs)-1 { - content := c.content - for i, name := range c.argNames { + if c.focused == len(c.inputs)-1 { + args := make(map[string]string) + for i, arg := range c.arguments { value := c.inputs[i].Value() - placeholder := "$" + name - content = strings.ReplaceAll(content, placeholder, value) + args[arg.Name] = value } return c, tea.Sequence( util.CmdHandler(dialogs.CloseDialogMsg{}), - util.CmdHandler(CommandRunCustomMsg{ - Content: content, - }), + c.onSubmit(args), ) } // Otherwise, move to the next input - c.inputs[c.focusIndex].Blur() - c.focusIndex++ - c.inputs[c.focusIndex].Focus() + c.inputs[c.focused].Blur() + c.focused++ + c.inputs[c.focused].Focus() case key.Matches(msg, c.keys.Next): // Move to the next input - c.inputs[c.focusIndex].Blur() - c.focusIndex = (c.focusIndex + 1) % len(c.inputs) - c.inputs[c.focusIndex].Focus() + c.inputs[c.focused].Blur() + c.focused = (c.focused + 1) % len(c.inputs) + c.inputs[c.focused].Focus() case key.Matches(msg, c.keys.Previous): // Move to the previous input - c.inputs[c.focusIndex].Blur() - c.focusIndex = (c.focusIndex - 1 + len(c.inputs)) % len(c.inputs) - c.inputs[c.focusIndex].Focus() + c.inputs[c.focused].Blur() + c.focused = (c.focused - 1 + len(c.inputs)) % len(c.inputs) + c.inputs[c.focused].Focus() case key.Matches(msg, c.keys.Close): return c, util.CmdHandler(dialogs.CloseDialogMsg{}) default: var cmd tea.Cmd - c.inputs[c.focusIndex], cmd = c.inputs[c.focusIndex].Update(msg) + c.inputs[c.focused], cmd = c.inputs[c.focused].Update(msg) return c, cmd } case tea.PasteMsg: var cmd tea.Cmd - c.inputs[c.focusIndex], cmd = c.inputs[c.focusIndex].Update(msg) + c.inputs[c.focused], cmd = c.inputs[c.focused].Update(msg) return c, cmd } return c, nil @@ -152,26 +171,28 @@ func (c *commandArgumentsDialogCmp) View() string { Foreground(t.Primary). Bold(true). Padding(0, 1). - Render("Command Arguments") + Render(cmp.Or(c.title, c.name)) - explanation := t.S().Text. + promptName := t.S().Text. Padding(0, 1). - Render("This command requires arguments.") + Render(c.description) - // Create input fields for each argument inputFields := make([]string, len(c.inputs)) for i, input := range c.inputs { - // Highlight the label of the focused input - labelStyle := baseStyle. - Padding(1, 1, 0, 1) + labelStyle := baseStyle.Padding(1, 1, 0, 1) - if i == c.focusIndex { + if i == c.focused { labelStyle = labelStyle.Foreground(t.FgBase).Bold(true) } else { labelStyle = labelStyle.Foreground(t.FgMuted) } - label := labelStyle.Render(c.argNames[i] + ":") + arg := c.arguments[i] + argName := cmp.Or(arg.Title, arg.Name) + if arg.Required { + argName += "*" + } + label := labelStyle.Render(argName + ":") field := t.S().Text. Padding(0, 1). @@ -180,18 +201,14 @@ func (c *commandArgumentsDialogCmp) View() string { inputFields[i] = lipgloss.JoinVertical(lipgloss.Left, label, field) } - // Join all elements vertically - elements := []string{title, explanation} + elements := []string{title, promptName} elements = append(elements, inputFields...) c.help.ShowAll = false helpText := baseStyle.Padding(0, 1).Render(c.help.View(c.keys)) elements = append(elements, "", helpText) - content := lipgloss.JoinVertical( - lipgloss.Left, - elements..., - ) + content := lipgloss.JoinVertical(lipgloss.Left, elements...) return baseStyle.Padding(1, 1, 0, 1). Border(lipgloss.RoundedBorder()). @@ -201,26 +218,33 @@ func (c *commandArgumentsDialogCmp) View() string { } func (c *commandArgumentsDialogCmp) Cursor() *tea.Cursor { - cursor := c.inputs[c.focusIndex].Cursor() + if len(c.inputs) == 0 { + return nil + } + cursor := c.inputs[c.focused].Cursor() if cursor != nil { cursor = c.moveCursor(cursor) } return cursor } +const ( + headerHeight = 3 + itemHeight = 3 + paddingHorizontal = 3 +) + func (c *commandArgumentsDialogCmp) moveCursor(cursor *tea.Cursor) *tea.Cursor { row, col := c.Position() - offset := row + 3 + (1+c.focusIndex)*3 + offset := row + headerHeight + (1+c.focused)*itemHeight cursor.Y += offset - cursor.X = cursor.X + col + 3 + cursor.X = cursor.X + col + paddingHorizontal return cursor } func (c *commandArgumentsDialogCmp) Position() (int, int) { - row := c.wHeight / 2 - row -= c.wHeight / 2 - col := c.wWidth / 2 - col -= c.width / 2 + row := (c.wHeight / 2) - (c.height / 2) + col := (c.wWidth / 2) - (c.width / 2) return row, col } diff --git a/internal/tui/components/dialogs/commands/commands.go b/internal/tui/components/dialogs/commands/commands.go index 664158fc3..747184914 100644 --- a/internal/tui/components/dialogs/commands/commands.go +++ b/internal/tui/components/dialogs/commands/commands.go @@ -2,6 +2,8 @@ package commands import ( "os" + "slices" + "strings" "github.com/charmbracelet/bubbles/v2/help" "github.com/charmbracelet/bubbles/v2/key" @@ -10,7 +12,10 @@ import ( "github.com/charmbracelet/lipgloss/v2" "github.com/charmbracelet/crush/internal/config" + "github.com/charmbracelet/crush/internal/csync" + "github.com/charmbracelet/crush/internal/llm/agent" "github.com/charmbracelet/crush/internal/llm/prompt" + "github.com/charmbracelet/crush/internal/pubsub" "github.com/charmbracelet/crush/internal/tui/components/chat" "github.com/charmbracelet/crush/internal/tui/components/core" "github.com/charmbracelet/crush/internal/tui/components/dialogs" @@ -25,9 +30,14 @@ const ( defaultWidth int = 70 ) +type commandType uint + +func (c commandType) String() string { return []string{"System", "User", "MCP"}[c] } + const ( - SystemCommands int = iota + SystemCommands commandType = iota UserCommands + MCPPrompts ) type listModel = list.FilterableList[list.CompletionItem[Command]] @@ -54,9 +64,10 @@ type commandDialogCmp struct { commandList listModel keyMap CommandsDialogKeyMap help help.Model - commandType int // SystemCommands or UserCommands - userCommands []Command // User-defined commands - sessionID string // Current session ID + selected commandType // Selected SystemCommands, UserCommands, or MCPPrompts + userCommands []Command // User-defined commands + mcpPrompts *csync.Slice[Command] // MCP prompts + sessionID string // Current session ID } type ( @@ -102,8 +113,9 @@ func NewCommandDialog(sessionID string) CommandsDialog { width: defaultWidth, keyMap: DefaultCommandsDialogKeyMap(), help: help, - commandType: SystemCommands, + selected: SystemCommands, sessionID: sessionID, + mcpPrompts: csync.NewSlice[Command](), } } @@ -113,7 +125,8 @@ func (c *commandDialogCmp) Init() tea.Cmd { return util.ReportError(err) } c.userCommands = commands - return c.SetCommandType(c.commandType) + c.mcpPrompts.SetSlice(loadMCPPrompts()) + return c.setCommandType(c.selected) } func (c *commandDialogCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) { @@ -122,9 +135,19 @@ func (c *commandDialogCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) { c.wWidth = msg.Width c.wHeight = msg.Height return c, tea.Batch( - c.SetCommandType(c.commandType), + c.setCommandType(c.selected), c.commandList.SetSize(c.listWidth(), c.listHeight()), ) + case pubsub.Event[agent.MCPEvent]: + // Reload MCP prompts when MCP state changes + if msg.Type == pubsub.UpdatedEvent { + c.mcpPrompts.SetSlice(loadMCPPrompts()) + // If we're currently viewing MCP prompts, refresh the list + if c.selected == MCPPrompts { + return c, c.setCommandType(MCPPrompts) + } + return c, nil + } case tea.KeyPressMsg: switch { case key.Matches(msg, c.keyMap.Select): @@ -138,15 +161,10 @@ func (c *commandDialogCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) { command.Handler(command), ) case key.Matches(msg, c.keyMap.Tab): - if len(c.userCommands) == 0 { + if len(c.userCommands) == 0 && c.mcpPrompts.Len() == 0 { return c, nil } - // Toggle command type between System and User commands - if c.commandType == SystemCommands { - return c, c.SetCommandType(UserCommands) - } else { - return c, c.SetCommandType(SystemCommands) - } + return c, c.setCommandType(c.next()) case key.Matches(msg, c.keyMap.Close): return c, util.CmdHandler(dialogs.CloseDialogMsg{}) default: @@ -158,13 +176,35 @@ func (c *commandDialogCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) { return c, nil } +func (c *commandDialogCmp) next() commandType { + switch c.selected { + case SystemCommands: + if len(c.userCommands) > 0 { + return UserCommands + } + if c.mcpPrompts.Len() > 0 { + return MCPPrompts + } + fallthrough + case UserCommands: + if c.mcpPrompts.Len() > 0 { + return MCPPrompts + } + fallthrough + case MCPPrompts: + return SystemCommands + default: + return SystemCommands + } +} + func (c *commandDialogCmp) View() string { t := styles.CurrentTheme() listView := c.commandList radio := c.commandTypeRadio() header := t.S().Base.Padding(0, 1, 1, 1).Render(core.Title("Commands", c.width-lipgloss.Width(radio)-5) + " " + radio) - if len(c.userCommands) == 0 { + if len(c.userCommands) == 0 && c.mcpPrompts.Len() == 0 { header = t.S().Base.Padding(0, 1, 1, 1).Render(core.Title("Commands", c.width-4)) } content := lipgloss.JoinVertical( @@ -190,27 +230,41 @@ func (c *commandDialogCmp) Cursor() *tea.Cursor { func (c *commandDialogCmp) commandTypeRadio() string { t := styles.CurrentTheme() - choices := []string{"System", "User"} - iconSelected := "◉" - iconUnselected := "○" - if c.commandType == SystemCommands { - return t.S().Base.Foreground(t.FgHalfMuted).Render(iconSelected + " " + choices[0] + " " + iconUnselected + " " + choices[1]) + + fn := func(i commandType) string { + if i == c.selected { + return "◉ " + i.String() + } + return "○ " + i.String() + } + + parts := []string{ + fn(SystemCommands), + } + if len(c.userCommands) > 0 { + parts = append(parts, fn(UserCommands)) + } + if c.mcpPrompts.Len() > 0 { + parts = append(parts, fn(MCPPrompts)) } - return t.S().Base.Foreground(t.FgHalfMuted).Render(iconUnselected + " " + choices[0] + " " + iconSelected + " " + choices[1]) + return t.S().Base.Foreground(t.FgHalfMuted).Render(strings.Join(parts, " ")) } func (c *commandDialogCmp) listWidth() int { return defaultWidth - 2 // 4 for padding } -func (c *commandDialogCmp) SetCommandType(commandType int) tea.Cmd { - c.commandType = commandType +func (c *commandDialogCmp) setCommandType(commandType commandType) tea.Cmd { + c.selected = commandType var commands []Command - if c.commandType == SystemCommands { + switch c.selected { + case SystemCommands: commands = c.defaultCommands() - } else { + case UserCommands: commands = c.userCommands + case MCPPrompts: + commands = slices.Collect(c.mcpPrompts.Seq()) } commandItems := []list.CompletionItem[Command]{} diff --git a/internal/tui/components/dialogs/commands/loader.go b/internal/tui/components/dialogs/commands/loader.go index 74d9c7e4b..31bbee94e 100644 --- a/internal/tui/components/dialogs/commands/loader.go +++ b/internal/tui/components/dialogs/commands/loader.go @@ -1,22 +1,28 @@ package commands import ( + "context" "fmt" "io/fs" + "log/slog" "os" "path/filepath" "regexp" "strings" tea "github.com/charmbracelet/bubbletea/v2" + "github.com/modelcontextprotocol/go-sdk/mcp" + "github.com/charmbracelet/crush/internal/config" "github.com/charmbracelet/crush/internal/home" + "github.com/charmbracelet/crush/internal/llm/agent" + "github.com/charmbracelet/crush/internal/tui/components/chat" "github.com/charmbracelet/crush/internal/tui/util" ) const ( - UserCommandPrefix = "user:" - ProjectCommandPrefix = "project:" + userCommandPrefix = "user:" + projectCommandPrefix = "project:" ) var namedArgPattern = regexp.MustCompile(`\$([A-Z][A-Z0-9_]*)`) @@ -50,7 +56,7 @@ func buildCommandSources(cfg *config.Config) []commandSource { if dir := getXDGCommandsDir(); dir != "" { sources = append(sources, commandSource{ path: dir, - prefix: UserCommandPrefix, + prefix: userCommandPrefix, }) } @@ -58,14 +64,14 @@ func buildCommandSources(cfg *config.Config) []commandSource { if home := home.Dir(); home != "" { sources = append(sources, commandSource{ path: filepath.Join(home, ".crush", "commands"), - prefix: UserCommandPrefix, + prefix: userCommandPrefix, }) } // Project directory sources = append(sources, commandSource{ path: filepath.Join(cfg.Options.DataDirectory, "commands"), - prefix: ProjectCommandPrefix, + prefix: projectCommandPrefix, }) return sources @@ -127,12 +133,13 @@ func (l *commandLoader) loadCommand(path, baseDir, prefix string) (Command, erro } id := buildCommandID(path, baseDir, prefix) + desc := fmt.Sprintf("Custom command from %s", filepath.Base(path)) return Command{ ID: id, Title: id, - Description: fmt.Sprintf("Custom command from %s", filepath.Base(path)), - Handler: createCommandHandler(id, string(content)), + Description: desc, + Handler: createCommandHandler(id, desc, string(content)), }, nil } @@ -149,21 +156,35 @@ func buildCommandID(path, baseDir, prefix string) string { return prefix + strings.Join(parts, ":") } -func createCommandHandler(id string, content string) func(Command) tea.Cmd { +func createCommandHandler(id, desc, content string) func(Command) tea.Cmd { return func(cmd Command) tea.Cmd { args := extractArgNames(content) - if len(args) > 0 { - return util.CmdHandler(ShowArgumentsDialogMsg{ - CommandID: id, - Content: content, - ArgNames: args, + if len(args) == 0 { + return util.CmdHandler(CommandRunCustomMsg{ + Content: content, }) } + return util.CmdHandler(ShowArgumentsDialogMsg{ + CommandID: id, + Description: desc, + ArgNames: args, + OnSubmit: func(args map[string]string) tea.Cmd { + return execUserPrompt(content, args) + }, + }) + } +} - return util.CmdHandler(CommandRunCustomMsg{ +func execUserPrompt(content string, args map[string]string) tea.Cmd { + return func() tea.Msg { + for name, value := range args { + placeholder := "$" + name + content = strings.ReplaceAll(content, placeholder, value) + } + return CommandRunCustomMsg{ Content: content, - }) + } } } @@ -201,3 +222,67 @@ func isMarkdownFile(name string) bool { type CommandRunCustomMsg struct { Content string } + +func loadMCPPrompts() []Command { + prompts := agent.GetMCPPrompts() + commands := make([]Command, 0, len(prompts)) + + for key, prompt := range prompts { + clientName, promptName, ok := strings.Cut(key, ":") + if !ok { + slog.Warn("prompt not found", "key", key) + continue + } + commands = append(commands, Command{ + ID: key, + Title: clientName + ":" + promptName, + Description: prompt.Description, + Handler: createMCPPromptHandler(clientName, promptName, prompt), + }) + } + + return commands +} + +func createMCPPromptHandler(clientName, promptName string, prompt *mcp.Prompt) func(Command) tea.Cmd { + return func(cmd Command) tea.Cmd { + if len(prompt.Arguments) == 0 { + return execMCPPrompt(clientName, promptName, nil) + } + return util.CmdHandler(ShowMCPPromptArgumentsDialogMsg{ + Prompt: prompt, + OnSubmit: func(args map[string]string) tea.Cmd { + return execMCPPrompt(clientName, promptName, args) + }, + }) + } +} + +func execMCPPrompt(clientName, promptName string, args map[string]string) tea.Cmd { + return func() tea.Msg { + ctx := context.Background() + result, err := agent.GetMCPPromptContent(ctx, clientName, promptName, args) + if err != nil { + return util.ReportError(err) + } + + var content strings.Builder + for _, msg := range result.Messages { + if msg.Role == "user" { + if textContent, ok := msg.Content.(*mcp.TextContent); ok { + content.WriteString(textContent.Text) + content.WriteString("\n") + } + } + } + + return chat.SendMsg{ + Text: content.String(), + } + } +} + +type ShowMCPPromptArgumentsDialogMsg struct { + Prompt *mcp.Prompt + OnSubmit func(arg map[string]string) tea.Cmd +} diff --git a/internal/tui/components/mcp/mcp.go b/internal/tui/components/mcp/mcp.go index fd3bd0127..91afa66c1 100644 --- a/internal/tui/components/mcp/mcp.go +++ b/internal/tui/components/mcp/mcp.go @@ -2,6 +2,7 @@ package mcp import ( "fmt" + "strings" "github.com/charmbracelet/lipgloss/v2" @@ -56,7 +57,7 @@ func RenderMCPList(opts RenderOptions) []string { // Determine icon and color based on state icon := t.ItemOfflineIcon description := "" - extraContent := "" + extraContent := []string{} if state, exists := mcpStates[l.Name]; exists { switch state.State { @@ -67,8 +68,11 @@ func RenderMCPList(opts RenderOptions) []string { description = t.S().Subtle.Render("starting...") case agent.MCPStateConnected: icon = t.ItemOnlineIcon - if state.ToolCount > 0 { - extraContent = t.S().Subtle.Render(fmt.Sprintf("%d tools", state.ToolCount)) + if count := state.Counts.Tools; count > 0 { + extraContent = append(extraContent, t.S().Subtle.Render(fmt.Sprintf("%d tools", count))) + } + if count := state.Counts.Prompts; count > 0 { + extraContent = append(extraContent, t.S().Subtle.Render(fmt.Sprintf("%d prompts", count))) } case agent.MCPStateError: icon = t.ItemErrorIcon @@ -88,7 +92,7 @@ func RenderMCPList(opts RenderOptions) []string { Icon: icon.String(), Title: l.Name, Description: description, - ExtraContent: extraContent, + ExtraContent: strings.Join(extraContent, " "), }, opts.MaxWidth, ), diff --git a/internal/tui/exp/list/filterable.go b/internal/tui/exp/list/filterable.go index e639786db..7e0075318 100644 --- a/internal/tui/exp/list/filterable.go +++ b/internal/tui/exp/list/filterable.go @@ -276,7 +276,7 @@ func (f *filterableList[T]) Filter(query string) tea.Cmd { func (f *filterableList[T]) SetItems(items []T) tea.Cmd { f.items = items - return f.list.SetItems(items) + return f.Filter(f.query) } func (f *filterableList[T]) Cursor() *tea.Cursor { diff --git a/internal/tui/tui.go b/internal/tui/tui.go index 26d23f46e..8252950c9 100644 --- a/internal/tui/tui.go +++ b/internal/tui/tui.go @@ -34,6 +34,8 @@ import ( "github.com/charmbracelet/crush/internal/tui/styles" "github.com/charmbracelet/crush/internal/tui/util" "github.com/charmbracelet/lipgloss/v2" + "golang.org/x/text/cases" + "golang.org/x/text/language" ) var lastMouseEvent time.Time @@ -138,15 +140,44 @@ func (a *appModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) { a.dialog = u.(dialogs.DialogCmp) return a, tea.Batch(completionCmd, dialogCmd) case commands.ShowArgumentsDialogMsg: + var args []commands.Argument + for _, arg := range msg.ArgNames { + args = append(args, commands.Argument{ + Name: arg, + Title: cases.Title(language.English).String(arg), + Required: true, + }) + } return a, util.CmdHandler( dialogs.OpenDialogMsg{ Model: commands.NewCommandArgumentsDialog( msg.CommandID, - msg.Content, - msg.ArgNames, + msg.CommandID, + msg.CommandID, + msg.Description, + args, + msg.OnSubmit, ), }, ) + case commands.ShowMCPPromptArgumentsDialogMsg: + args := make([]commands.Argument, 0, len(msg.Prompt.Arguments)) + for _, arg := range msg.Prompt.Arguments { + args = append(args, commands.Argument(*arg)) + } + dialog := commands.NewCommandArgumentsDialog( + msg.Prompt.Name, + msg.Prompt.Title, + msg.Prompt.Name, + msg.Prompt.Description, + args, + msg.OnSubmit, + ) + return a, util.CmdHandler( + dialogs.OpenDialogMsg{ + Model: dialog, + }, + ) // Page change messages case page.PageChangeMsg: return a, a.moveToPage(msg.ID)