diff --git a/pkg/agent/loop.go b/pkg/agent/loop.go index e3f1824811..dab4762fb8 100644 --- a/pkg/agent/loop.go +++ b/pkg/agent/loop.go @@ -242,6 +242,13 @@ func NewAgentLoop(cfg *config.Config, msgBus *bus.MessageBus, provider providers contextBuilder.SetMCPManager(mcpManager) } + // Config-defined API endpoints (safe HTTP allowlist) + if len(cfg.Tools.APIs) > 0 { + apiCallTool := tools.NewAPICallTool(cfg.Tools.APIs) + toolsRegistry.Register(apiCallTool) + subagentTools.Register(apiCallTool) + } + return &AgentLoop{ bus: msgBus, provider: provider, diff --git a/pkg/config/config.go b/pkg/config/config.go index 48d3413efa..48f10f5b3f 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -167,6 +167,28 @@ type ExecToolsConfig struct { Enabled bool `json:"enabled" label:"Enabled" env:"CLAWDROID_TOOLS_EXEC_ENABLED"` } +// HTTPParam describes a parameter accepted by a user-defined API endpoint. +type HTTPParam struct { + Name string `json:"name"` + In string `json:"in"` // "query", "body", or "path" + Description string `json:"description"` // shown to the LLM + Required bool `json:"required"` +} + +// APIEndpointConfig defines a single pre-approved HTTP API endpoint. +// The LLM can only call endpoints registered here; it cannot reach +// arbitrary URLs. Headers (e.g. auth tokens) are stored in config +// and are never exposed to the LLM. +type APIEndpointConfig struct { + Name string `json:"name"` + Description string `json:"description"` + URL string `json:"url"` + Method string `json:"method"` // GET, POST, PUT, DELETE, PATCH + Headers map[string]string `json:"headers,omitempty"` // fixed request headers (invisible to LLM) + Params []HTTPParam `json:"params,omitempty"` // parameter schema exposed to LLM + Timeout int `json:"timeout,omitempty"` // seconds; 0 → 30 s default +} + type MCPServerConfig struct { // Stdio transport Command string `json:"command,omitempty"` @@ -363,6 +385,7 @@ type ToolsConfig struct { Android AndroidToolsConfig `json:"android" label:"Android"` Memory MemoryToolsConfig `json:"memory" label:"Memory"` MCP map[string]MCPServerConfig `json:"mcp,omitempty" label:"MCP Servers"` + APIs []APIEndpointConfig `json:"apis,omitempty" label:"Custom APIs"` } func DefaultConfig() *Config { diff --git a/pkg/tools/api_call.go b/pkg/tools/api_call.go new file mode 100644 index 0000000000..eefbf8e6c1 --- /dev/null +++ b/pkg/tools/api_call.go @@ -0,0 +1,289 @@ +package tools + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "net/url" + "strings" + "time" + + "github.com/KarakuriAgent/clawdroid/pkg/config" +) + +// APICallTool exposes a curated set of HTTP endpoints to the LLM. +// Each endpoint is defined in the config (tools.apis), so the LLM can +// only reach pre-approved URLs. Fixed headers such as auth tokens are +// stored in the config and are invisible to the LLM. +type APICallTool struct { + endpoints []config.APIEndpointConfig +} + +// NewAPICallTool creates an APICallTool from a slice of configured endpoints. +func NewAPICallTool(endpoints []config.APIEndpointConfig) *APICallTool { + return &APICallTool{endpoints: endpoints} +} + +// IsActive implements ActivatableTool: hide this tool when no endpoints +// are configured so it does not clutter the LLM's tool list. +func (t *APICallTool) IsActive() bool { + return len(t.endpoints) > 0 +} + +func (t *APICallTool) Name() string { + return "api_call" +} + +func (t *APICallTool) Description() string { + if len(t.endpoints) == 0 { + return "Call a pre-configured API endpoint." + } + + var sb strings.Builder + sb.WriteString("Call one of the pre-configured API endpoints listed below.\n") + sb.WriteString("Fixed authentication headers are applied automatically – do not pass them as params.\n\n") + sb.WriteString("Available endpoints:\n") + for _, ep := range t.endpoints { + method := ep.Method + if method == "" { + method = "GET" + } + sb.WriteString(fmt.Sprintf("- %s [%s %s]: %s\n", ep.Name, strings.ToUpper(method), ep.URL, ep.Description)) + } + return sb.String() +} + +func (t *APICallTool) Parameters() map[string]interface{} { + names := make([]interface{}, 0, len(t.endpoints)) + for _, ep := range t.endpoints { + names = append(names, ep.Name) + } + + return map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "api_name": map[string]interface{}{ + "type": "string", + "description": "Name of the API endpoint to call (must be one of the configured endpoints)", + "enum": names, + }, + "params": map[string]interface{}{ + "type": "object", + "description": "Parameters to pass to the endpoint (query string, path, or request body – depends on endpoint definition)", + "additionalProperties": true, + }, + }, + "required": []string{"api_name"}, + } +} + +func (t *APICallTool) Execute(ctx context.Context, args map[string]interface{}) *ToolResult { + apiName, ok := args["api_name"].(string) + if !ok || apiName == "" { + return ErrorResult("api_name is required") + } + + // Find the endpoint in the allowlist. + ep := t.findEndpoint(apiName) + if ep == nil { + return ErrorResult(fmt.Sprintf("unknown api_name %q – must be one of the configured endpoints", apiName)) + } + + // Collect caller-supplied params (may be nil / absent). + var params map[string]interface{} + if p, ok := args["params"].(map[string]interface{}); ok { + params = p + } else { + params = map[string]interface{}{} + } + + // Validate required params. + for _, pd := range ep.Params { + if pd.Required { + if _, exists := params[pd.Name]; !exists { + return ErrorResult(fmt.Sprintf("missing required parameter %q for endpoint %q", pd.Name, apiName)) + } + } + } + + // Build the request URL (expand {name} path placeholders). + rawURL, err := t.buildURL(ep, params) + if err != nil { + return ErrorResult(fmt.Sprintf("failed to build URL: %v", err)) + } + + // Separate remaining params by "in" location. + queryParams, bodyParams := t.splitParams(ep, params) + + // Append query parameters. + if len(queryParams) > 0 { + qv := url.Values{} + for k, v := range queryParams { + qv.Set(k, fmt.Sprintf("%v", v)) + } + separator := "?" + if strings.Contains(rawURL, "?") { + separator = "&" + } + rawURL += separator + qv.Encode() + } + + // Build request body. + var bodyReader io.Reader + method := strings.ToUpper(ep.Method) + if method == "" { + method = "GET" + } + if len(bodyParams) > 0 { + bodyJSON, err := json.Marshal(bodyParams) + if err != nil { + return ErrorResult(fmt.Sprintf("failed to marshal request body: %v", err)) + } + bodyReader = bytes.NewReader(bodyJSON) + } + + req, err := http.NewRequestWithContext(ctx, method, rawURL, bodyReader) + if err != nil { + return ErrorResult(fmt.Sprintf("failed to create request: %v", err)) + } + + // Apply fixed headers from config (these are invisible to the LLM). + for k, v := range ep.Headers { + req.Header.Set(k, v) + } + if len(bodyParams) > 0 && req.Header.Get("Content-Type") == "" { + req.Header.Set("Content-Type", "application/json") + } + + // Set timeout. + timeout := time.Duration(ep.Timeout) * time.Second + if timeout <= 0 { + timeout = 30 * time.Second + } + client := &http.Client{Timeout: timeout} + + resp, err := client.Do(req) + if err != nil { + return ErrorResult(fmt.Sprintf("request to %q failed: %v", apiName, err)) + } + defer func() { _ = resp.Body.Close() }() + + respBody, err := io.ReadAll(resp.Body) + if err != nil { + return ErrorResult(fmt.Sprintf("failed to read response from %q: %v", apiName, err)) + } + + // Try to pretty-print JSON responses. + var prettyBody string + var jsonData interface{} + if json.Unmarshal(respBody, &jsonData) == nil { + formatted, _ := json.MarshalIndent(jsonData, "", " ") + prettyBody = string(formatted) + } else { + prettyBody = string(respBody) + } + + summary := fmt.Sprintf("API %q responded with status %d (%d bytes)", apiName, resp.StatusCode, len(respBody)) + result := map[string]interface{}{ + "api_name": apiName, + "status_code": resp.StatusCode, + "body": prettyBody, + } + resultJSON, _ := json.MarshalIndent(result, "", " ") + + return &ToolResult{ + ForLLM: fmt.Sprintf("%s\n%s", summary, string(resultJSON)), + ForUser: string(resultJSON), + IsError: resp.StatusCode >= 400, + } +} + +// findEndpoint returns the endpoint config whose name matches, or nil. +func (t *APICallTool) findEndpoint(name string) *config.APIEndpointConfig { + for i := range t.endpoints { + if t.endpoints[i].Name == name { + return &t.endpoints[i] + } + } + return nil +} + +// buildURL expands {param} placeholders in the URL template with path params, +// returning the URL without query-string parameters. +func (t *APICallTool) buildURL(ep *config.APIEndpointConfig, params map[string]interface{}) (string, error) { + rawURL := ep.URL + for _, pd := range ep.Params { + if pd.In != "path" { + continue + } + val, exists := params[pd.Name] + if !exists { + if pd.Required { + return "", fmt.Errorf("missing required path parameter %q", pd.Name) + } + continue + } + placeholder := "{" + pd.Name + "}" + rawURL = strings.ReplaceAll(rawURL, placeholder, url.PathEscape(fmt.Sprintf("%v", val))) + } + + // Validate the resulting URL is still http/https and has a host. + parsed, err := url.Parse(rawURL) + if err != nil { + return "", fmt.Errorf("invalid URL after expansion: %w", err) + } + if parsed.Scheme != "http" && parsed.Scheme != "https" { + return "", fmt.Errorf("only http/https endpoints are allowed") + } + if parsed.Host == "" { + return "", fmt.Errorf("missing host in URL") + } + + return rawURL, nil +} + +// splitParams separates caller params into query-string and body buckets, +// skipping path params that have already been interpolated. +func (t *APICallTool) splitParams(ep *config.APIEndpointConfig, params map[string]interface{}) ( + queryParams map[string]interface{}, + bodyParams map[string]interface{}, +) { + queryParams = map[string]interface{}{} + bodyParams = map[string]interface{}{} + + // Build a lookup for the declared "in" location. + inMap := map[string]string{} + for _, pd := range ep.Params { + inMap[pd.Name] = pd.In + } + + method := strings.ToUpper(ep.Method) + if method == "" { + method = "GET" + } + supportsBody := method == "POST" || method == "PUT" || method == "PATCH" + + for k, v := range params { + location := inMap[k] // empty string if not declared + + switch location { + case "path": + // Already expanded – skip. + case "body": + bodyParams[k] = v + case "query": + queryParams[k] = v + default: + // Undeclared param: use body for POST/PUT/PATCH, else query. + if supportsBody { + bodyParams[k] = v + } else { + queryParams[k] = v + } + } + } + return queryParams, bodyParams +} diff --git a/pkg/tools/api_call_test.go b/pkg/tools/api_call_test.go new file mode 100644 index 0000000000..b9c40b0a97 --- /dev/null +++ b/pkg/tools/api_call_test.go @@ -0,0 +1,353 @@ +package tools + +import ( + "context" + "encoding/json" + "io" + "net/http" + "net/http/httptest" + "net/url" + "strings" + "testing" + + "github.com/KarakuriAgent/clawdroid/pkg/config" +) + +// helpers + +func newTestEndpoint(name, method, rawURL string, headers map[string]string, params []config.HTTPParam) config.APIEndpointConfig { + return config.APIEndpointConfig{ + Name: name, + Description: "test endpoint", + URL: rawURL, + Method: method, + Headers: headers, + Params: params, + Timeout: 5, + } +} + +// TestAPICallTool_IsActive checks that the tool is inactive when no endpoints +// are registered. +func TestAPICallTool_IsActive(t *testing.T) { + t.Run("no endpoints", func(t *testing.T) { + tool := NewAPICallTool(nil) + if tool.IsActive() { + t.Error("expected IsActive() == false when no endpoints are configured") + } + }) + + t.Run("with endpoints", func(t *testing.T) { + tool := NewAPICallTool([]config.APIEndpointConfig{ + newTestEndpoint("ep1", "GET", "http://example.com", nil, nil), + }) + if !tool.IsActive() { + t.Error("expected IsActive() == true when endpoints are configured") + } + }) +} + +// TestAPICallTool_Name checks the tool name. +func TestAPICallTool_Name(t *testing.T) { + tool := NewAPICallTool(nil) + if tool.Name() != "api_call" { + t.Errorf("expected name 'api_call', got %q", tool.Name()) + } +} + +// TestAPICallTool_Parameters_Enum verifies that the enum in the parameters +// schema matches the configured endpoint names. +func TestAPICallTool_Parameters_Enum(t *testing.T) { + tool := NewAPICallTool([]config.APIEndpointConfig{ + newTestEndpoint("ep1", "GET", "http://example.com", nil, nil), + newTestEndpoint("ep2", "POST", "http://example.com/post", nil, nil), + }) + + params := tool.Parameters() + props, ok := params["properties"].(map[string]interface{}) + if !ok { + t.Fatal("parameters should have a 'properties' map") + } + apiNameProp, ok := props["api_name"].(map[string]interface{}) + if !ok { + t.Fatal("parameters should have an 'api_name' property") + } + enum, ok := apiNameProp["enum"].([]interface{}) + if !ok { + t.Fatal("api_name property should have an 'enum' field") + } + if len(enum) != 2 { + t.Fatalf("expected 2 enum values, got %d", len(enum)) + } +} + +// TestAPICallTool_Execute_MissingAPIName checks that an error is returned when +// api_name is absent. +func TestAPICallTool_Execute_MissingAPIName(t *testing.T) { + tool := NewAPICallTool([]config.APIEndpointConfig{ + newTestEndpoint("ep1", "GET", "http://example.com", nil, nil), + }) + result := tool.Execute(context.Background(), map[string]interface{}{}) + if !result.IsError { + t.Error("expected error when api_name is missing") + } + if !strings.Contains(result.ForLLM, "api_name") { + t.Errorf("expected error to mention 'api_name', got: %s", result.ForLLM) + } +} + +// TestAPICallTool_Execute_UnknownAPIName checks that an error is returned for +// an unrecognised api_name. +func TestAPICallTool_Execute_UnknownAPIName(t *testing.T) { + tool := NewAPICallTool([]config.APIEndpointConfig{ + newTestEndpoint("ep1", "GET", "http://example.com", nil, nil), + }) + result := tool.Execute(context.Background(), map[string]interface{}{ + "api_name": "nonexistent", + }) + if !result.IsError { + t.Error("expected error for unknown api_name") + } + if !strings.Contains(result.ForLLM, "unknown api_name") { + t.Errorf("expected 'unknown api_name' in error, got: %s", result.ForLLM) + } +} + +// TestAPICallTool_Execute_GET performs a successful GET request. +func TestAPICallTool_Execute_GET(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet { + http.Error(w, "wrong method", http.StatusMethodNotAllowed) + return + } + w.Header().Set("Content-Type", "application/json") + _, _ = io.WriteString(w, `{"status":"ok"}`) + })) + defer server.Close() + + tool := NewAPICallTool([]config.APIEndpointConfig{ + newTestEndpoint("test_get", "GET", server.URL+"/api", nil, nil), + }) + + result := tool.Execute(context.Background(), map[string]interface{}{ + "api_name": "test_get", + }) + + if result.IsError { + t.Errorf("expected success, got error: %s", result.ForLLM) + } + if !strings.Contains(result.ForLLM, "200") { + t.Errorf("expected status 200 in ForLLM, got: %s", result.ForLLM) + } + if !strings.Contains(result.ForUser, "ok") { + t.Errorf("expected response body in ForUser, got: %s", result.ForUser) + } +} + +// TestAPICallTool_Execute_POST_BodyParams sends body params and checks that +// they arrive in the request body. +func TestAPICallTool_Execute_POST_BodyParams(t *testing.T) { + var receivedBody []byte + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + http.Error(w, "wrong method", http.StatusMethodNotAllowed) + return + } + receivedBody, _ = io.ReadAll(r.Body) + w.Header().Set("Content-Type", "application/json") + _, _ = io.WriteString(w, `{"received":true}`) + })) + defer server.Close() + + params := []config.HTTPParam{ + {Name: "message", In: "body", Description: "msg", Required: true}, + } + tool := NewAPICallTool([]config.APIEndpointConfig{ + newTestEndpoint("test_post", "POST", server.URL+"/api", nil, params), + }) + + result := tool.Execute(context.Background(), map[string]interface{}{ + "api_name": "test_post", + "params": map[string]interface{}{ + "message": "hello", + }, + }) + + if result.IsError { + t.Errorf("expected success, got error: %s", result.ForLLM) + } + + var body map[string]interface{} + if err := json.Unmarshal(receivedBody, &body); err != nil { + t.Fatalf("server received non-JSON body: %s", receivedBody) + } + if body["message"] != "hello" { + t.Errorf("expected body param 'message'='hello', got: %v", body) + } +} + +// TestAPICallTool_Execute_QueryParams verifies that query params are appended +// to the URL for GET requests. +func TestAPICallTool_Execute_QueryParams(t *testing.T) { + var receivedQuery url.Values + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + receivedQuery = r.URL.Query() + w.Header().Set("Content-Type", "application/json") + _, _ = io.WriteString(w, `{}`) + })) + defer server.Close() + + params := []config.HTTPParam{ + {Name: "q", In: "query", Description: "search query", Required: true}, + } + tool := NewAPICallTool([]config.APIEndpointConfig{ + newTestEndpoint("test_query", "GET", server.URL+"/search", nil, params), + }) + + result := tool.Execute(context.Background(), map[string]interface{}{ + "api_name": "test_query", + "params": map[string]interface{}{ + "q": "clawdroid", + }, + }) + + if result.IsError { + t.Errorf("expected success, got error: %s", result.ForLLM) + } + if receivedQuery.Get("q") != "clawdroid" { + t.Errorf("expected query param q=clawdroid, got: %v", receivedQuery) + } +} + +// TestAPICallTool_Execute_PathParams verifies that path params are interpolated +// into the URL template. +func TestAPICallTool_Execute_PathParams(t *testing.T) { + var receivedPath string + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + receivedPath = r.URL.Path + w.Header().Set("Content-Type", "application/json") + _, _ = io.WriteString(w, `{}`) + })) + defer server.Close() + + params := []config.HTTPParam{ + {Name: "id", In: "path", Description: "item id", Required: true}, + } + tool := NewAPICallTool([]config.APIEndpointConfig{ + newTestEndpoint("test_path", "GET", server.URL+"/items/{id}", nil, params), + }) + + result := tool.Execute(context.Background(), map[string]interface{}{ + "api_name": "test_path", + "params": map[string]interface{}{ + "id": "42", + }, + }) + + if result.IsError { + t.Errorf("expected success, got error: %s", result.ForLLM) + } + if receivedPath != "/items/42" { + t.Errorf("expected path /items/42, got: %s", receivedPath) + } +} + +// TestAPICallTool_Execute_FixedHeaders verifies that fixed headers from config +// are sent with the request and cannot be overridden by LLM params. +func TestAPICallTool_Execute_FixedHeaders(t *testing.T) { + var receivedAuthHeader string + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + receivedAuthHeader = r.Header.Get("Authorization") + w.Header().Set("Content-Type", "application/json") + _, _ = io.WriteString(w, `{}`) + })) + defer server.Close() + + headers := map[string]string{ + "Authorization": "Bearer secret-token", + } + tool := NewAPICallTool([]config.APIEndpointConfig{ + newTestEndpoint("test_auth", "GET", server.URL+"/secure", headers, nil), + }) + + result := tool.Execute(context.Background(), map[string]interface{}{ + "api_name": "test_auth", + }) + + if result.IsError { + t.Errorf("expected success, got error: %s", result.ForLLM) + } + if receivedAuthHeader != "Bearer secret-token" { + t.Errorf("expected auth header to be sent, got: %q", receivedAuthHeader) + } +} + +// TestAPICallTool_Execute_MissingRequiredParam checks that a missing required +// param produces an error. +func TestAPICallTool_Execute_MissingRequiredParam(t *testing.T) { + params := []config.HTTPParam{ + {Name: "required_field", In: "query", Description: "must be set", Required: true}, + } + tool := NewAPICallTool([]config.APIEndpointConfig{ + newTestEndpoint("test_required", "GET", "http://example.com/api", nil, params), + }) + + result := tool.Execute(context.Background(), map[string]interface{}{ + "api_name": "test_required", + // params omitted entirely + }) + + if !result.IsError { + t.Error("expected error for missing required param") + } + if !strings.Contains(result.ForLLM, "required_field") { + t.Errorf("expected error to mention 'required_field', got: %s", result.ForLLM) + } +} + +// TestAPICallTool_Execute_HTTPError checks that 4xx responses set IsError=true. +func TestAPICallTool_Execute_HTTPError(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + http.Error(w, `{"error":"not found"}`, http.StatusNotFound) + })) + defer server.Close() + + tool := NewAPICallTool([]config.APIEndpointConfig{ + newTestEndpoint("test_404", "GET", server.URL+"/missing", nil, nil), + }) + + result := tool.Execute(context.Background(), map[string]interface{}{ + "api_name": "test_404", + }) + + if !result.IsError { + t.Error("expected IsError=true for 404 response") + } + if !strings.Contains(result.ForLLM, "404") { + t.Errorf("expected 404 in ForLLM, got: %s", result.ForLLM) + } +} + +// TestAPICallTool_Execute_NoParams verifies a simple GET with no params works. +func TestAPICallTool_Execute_NoParams(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + _, _ = io.WriteString(w, `{"hello":"world"}`) + })) + defer server.Close() + + tool := NewAPICallTool([]config.APIEndpointConfig{ + newTestEndpoint("simple", "GET", server.URL, nil, nil), + }) + + result := tool.Execute(context.Background(), map[string]interface{}{ + "api_name": "simple", + }) + + if result.IsError { + t.Errorf("expected success, got error: %s", result.ForLLM) + } + if !strings.Contains(result.ForUser, "world") { + t.Errorf("expected 'world' in ForUser, got: %s", result.ForUser) + } +}