diff --git a/pkg/config/config.go b/pkg/config/config.go index 16559a2df..a3485f313 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -57,6 +57,7 @@ type Config struct { Tools ToolsConfig `json:"tools"` Heartbeat HeartbeatConfig `json:"heartbeat"` Devices DevicesConfig `json:"devices"` + Security SecurityConfig `json:"security,omitempty"` } // MarshalJSON implements custom JSON marshaling for Config @@ -316,6 +317,57 @@ type DevicesConfig struct { MonitorUSB bool `json:"monitor_usb" env:"PICOCLAW_DEVICES_MONITOR_USB"` } +// SecurityConfig holds all security-related configuration. +type SecurityConfig struct { + SSRF SSRFConfig `json:"ssrf"` + AuditLogging AuditLoggingConfig `json:"audit_logging"` + RateLimiting RateLimitingConfig `json:"rate_limiting"` + CredentialEncryption CredentialEncryptionConfig `json:"credential_encryption"` + PromptInjection PromptInjectionConfig `json:"prompt_injection"` +} + +// SSRFConfig configures Server-Side Request Forgery protection. +type SSRFConfig struct { + Enabled bool `json:"enabled" env:"PICOCLAW_SECURITY_SSRF_ENABLED"` + BlockPrivateIPs bool `json:"block_private_ips" env:"PICOCLAW_SECURITY_SSRF_BLOCK_PRIVATE_IPS"` + BlockMetadataEndpoints bool `json:"block_metadata_endpoints" env:"PICOCLAW_SECURITY_SSRF_BLOCK_METADATA_ENDPOINTS"` + BlockLocalhost bool `json:"block_localhost" env:"PICOCLAW_SECURITY_SSRF_BLOCK_LOCALHOST"` + AllowedHosts []string `json:"allowed_hosts"` + DNSRebindingProtection bool `json:"dns_rebinding_protection" env:"PICOCLAW_SECURITY_SSRF_DNS_REBINDING_PROTECTION"` +} + +// AuditLoggingConfig configures audit logging for security events. +type AuditLoggingConfig struct { + Enabled bool `json:"enabled" env:"PICOCLAW_SECURITY_AUDIT_ENABLED"` + LogToolExecutions bool `json:"log_tool_executions" env:"PICOCLAW_SECURITY_AUDIT_LOG_TOOL_EXECUTIONS"` + LogAuthEvents bool `json:"log_auth_events" env:"PICOCLAW_SECURITY_AUDIT_LOG_AUTH_EVENTS"` + LogConfigChanges bool `json:"log_config_changes" env:"PICOCLAW_SECURITY_AUDIT_LOG_CONFIG_CHANGES"` + RetentionDays int `json:"retention_days" env:"PICOCLAW_SECURITY_AUDIT_RETENTION_DAYS"` +} + +// RateLimitingConfig configures rate limiting for API and tool usage. +type RateLimitingConfig struct { + Enabled bool `json:"enabled" env:"PICOCLAW_SECURITY_RATELIMIT_ENABLED"` + RequestsPerMinute int `json:"requests_per_minute" env:"PICOCLAW_SECURITY_RATELIMIT_REQUESTS_PER_MINUTE"` + ToolExecutionsPerMinute int `json:"tool_executions_per_minute" env:"PICOCLAW_SECURITY_RATELIMIT_TOOL_EXECUTIONS_PER_MINUTE"` + PerUserLimit bool `json:"per_user_limit" env:"PICOCLAW_SECURITY_RATELIMIT_PER_USER_LIMIT"` +} + +// CredentialEncryptionConfig configures how credentials are encrypted at rest. +type CredentialEncryptionConfig struct { + Enabled bool `json:"enabled" env:"PICOCLAW_SECURITY_CRED_ENCRYPTION_ENABLED"` + UseKeychain bool `json:"use_keychain" env:"PICOCLAW_SECURITY_CRED_ENCRYPTION_USE_KEYCHAIN"` + Algorithm string `json:"algorithm" env:"PICOCLAW_SECURITY_CRED_ENCRYPTION_ALGORITHM"` +} + +// PromptInjectionConfig configures prompt injection defense mechanisms. +type PromptInjectionConfig struct { + Enabled bool `json:"enabled" env:"PICOCLAW_SECURITY_PROMPT_INJECTION_ENABLED"` + SanitizeUserInput bool `json:"sanitize_user_input" env:"PICOCLAW_SECURITY_PROMPT_INJECTION_SANITIZE_USER_INPUT"` + DetectInjectionPatterns bool `json:"detect_injection_patterns" env:"PICOCLAW_SECURITY_PROMPT_INJECTION_DETECT_PATTERNS"` + CustomBlockPatterns []string `json:"custom_block_patterns"` +} + type ProvidersConfig struct { Anthropic ProviderConfig `json:"anthropic"` OpenAI OpenAIProviderConfig `json:"openai"` @@ -371,12 +423,11 @@ func (p ProvidersConfig) MarshalJSON() ([]byte, error) { } type ProviderConfig struct { - APIKey string `json:"api_key" env:"PICOCLAW_PROVIDERS_{{.Name}}_API_KEY"` - APIBase string `json:"api_base" env:"PICOCLAW_PROVIDERS_{{.Name}}_API_BASE"` - Proxy string `json:"proxy,omitempty" env:"PICOCLAW_PROVIDERS_{{.Name}}_PROXY"` - RequestTimeout int `json:"request_timeout,omitempty" env:"PICOCLAW_PROVIDERS_{{.Name}}_REQUEST_TIMEOUT"` - AuthMethod string `json:"auth_method,omitempty" env:"PICOCLAW_PROVIDERS_{{.Name}}_AUTH_METHOD"` - ConnectMode string `json:"connect_mode,omitempty" env:"PICOCLAW_PROVIDERS_{{.Name}}_CONNECT_MODE"` // only for Github Copilot, `stdio` or `grpc` + APIKey string `json:"api_key" env:"PICOCLAW_PROVIDERS_{{.Name}}_API_KEY"` + APIBase string `json:"api_base" env:"PICOCLAW_PROVIDERS_{{.Name}}_API_BASE"` + Proxy string `json:"proxy,omitempty" env:"PICOCLAW_PROVIDERS_{{.Name}}_PROXY"` + AuthMethod string `json:"auth_method,omitempty" env:"PICOCLAW_PROVIDERS_{{.Name}}_AUTH_METHOD"` + ConnectMode string `json:"connect_mode,omitempty" env:"PICOCLAW_PROVIDERS_{{.Name}}_CONNECT_MODE"` // only for Github Copilot, `stdio` or `grpc` } type OpenAIProviderConfig struct { @@ -407,7 +458,6 @@ type ModelConfig struct { // Optional optimizations RPM int `json:"rpm,omitempty"` // Requests per minute limit MaxTokensField string `json:"max_tokens_field,omitempty"` // Field name for max tokens (e.g., "max_completion_tokens") - RequestTimeout int `json:"request_timeout,omitempty"` } // Validate checks if the ModelConfig has all required fields. diff --git a/pkg/config/defaults.go b/pkg/config/defaults.go index cf799140d..27d3360e8 100644 --- a/pkg/config/defaults.go +++ b/pkg/config/defaults.go @@ -21,7 +21,7 @@ func DefaultConfig() *Config { }, Bindings: []AgentBinding{}, Session: SessionConfig{ - DMScope: "per-channel-peer", + DMScope: "main", }, Channels: ChannelsConfig{ WhatsApp: WhatsAppConfig{ @@ -277,7 +277,6 @@ func DefaultConfig() *Config { }, Tools: ToolsConfig{ Web: WebToolsConfig{ - Proxy: "", Brave: BraveConfig{ Enabled: false, APIKey: "", @@ -321,5 +320,39 @@ func DefaultConfig() *Config { Enabled: false, MonitorUSB: true, }, + Security: SecurityConfig{ + SSRF: SSRFConfig{ + Enabled: true, + BlockPrivateIPs: true, + BlockMetadataEndpoints: true, + BlockLocalhost: true, + AllowedHosts: []string{}, + DNSRebindingProtection: true, + }, + AuditLogging: AuditLoggingConfig{ + Enabled: true, + LogToolExecutions: true, + LogAuthEvents: true, + LogConfigChanges: true, + RetentionDays: 30, + }, + RateLimiting: RateLimitingConfig{ + Enabled: false, // Off by default for single-user use + RequestsPerMinute: 60, + ToolExecutionsPerMinute: 30, + PerUserLimit: true, + }, + CredentialEncryption: CredentialEncryptionConfig{ + Enabled: true, + UseKeychain: true, + Algorithm: "chacha20-poly1305", + }, + PromptInjection: PromptInjectionConfig{ + Enabled: true, + SanitizeUserInput: true, + DetectInjectionPatterns: true, + CustomBlockPatterns: []string{}, + }, + }, } } diff --git a/pkg/logger/logger.go b/pkg/logger/logger.go index 56dc87a53..35888a809 100644 --- a/pkg/logger/logger.go +++ b/pkg/logger/logger.go @@ -9,6 +9,8 @@ import ( "strings" "sync" "time" + + "github.com/sipeed/picoclaw/pkg/redaction" ) type LogLevel int @@ -34,6 +36,9 @@ var ( logger *Logger once sync.Once mu sync.RWMutex + + // redactionEnabled controls whether log messages are redacted for privacy + redactionEnabled = true ) type Logger struct { @@ -101,6 +106,14 @@ func logMessage(level LogLevel, component string, message string, fields map[str return } + // Apply redaction to message and fields for privacy + if redactionEnabled { + message = redaction.Redact(message) + if fields != nil { + fields = redaction.RedactFields(fields) + } + } + entry := LogEntry{ Level: logLevelNames[level], Timestamp: time.Now().UTC().Format(time.RFC3339), @@ -239,3 +252,22 @@ func FatalF(message string, fields map[string]any) { func FatalCF(component string, message string, fields map[string]any) { logMessage(FATAL, component, message, fields) } + +// SetRedactionEnabled enables or disables log redaction for privacy. +func SetRedactionEnabled(enabled bool) { + mu.Lock() + defer mu.Unlock() + redactionEnabled = enabled +} + +// IsRedactionEnabled returns whether log redaction is enabled. +func IsRedactionEnabled() bool { + mu.RLock() + defer mu.RUnlock() + return redactionEnabled +} + +// ConfigureRedaction sets up the global redaction configuration. +func ConfigureRedaction(config redaction.Config) { + redaction.SetGlobalConfig(config) +} diff --git a/pkg/redaction/redaction.go b/pkg/redaction/redaction.go new file mode 100644 index 000000000..75a433277 --- /dev/null +++ b/pkg/redaction/redaction.go @@ -0,0 +1,321 @@ +// Package redaction provides privacy protection through sensitive data redaction. +// It automatically detects and masks API keys, tokens, passwords, and PII. +package redaction + +import ( + "regexp" + "strings" + "sync" +) + +// Config holds redaction configuration. +type Config struct { + // Enabled controls whether redaction is active. + Enabled bool `json:"enabled"` + + // RedactAPIKeys redacts API keys and tokens. + RedactAPIKeys bool `json:"redact_api_keys"` + + // RedactPasswords redacts password fields. + RedactPasswords bool `json:"redact_passwords"` + + // RedactEmails redacts email addresses. + RedactEmails bool `json:"redact_emails"` + + // RedactPhoneNumbers redacts phone numbers. + RedactPhoneNumbers bool `json:"redact_phone_numbers"` + + // RedactIPAddresses redacts IP addresses. + RedactIPAddresses bool `json:"redact_ip_addresses"` + + // CustomPatterns allows additional regex patterns to redact. + CustomPatterns []string `json:"custom_patterns"` + + // Replacement is the string used to replace sensitive data. + Replacement string `json:"replacement"` +} + +// DefaultConfig returns the default redaction configuration. +func DefaultConfig() Config { + return Config{ + Enabled: true, + RedactAPIKeys: true, + RedactPasswords: true, + RedactEmails: true, + RedactPhoneNumbers: true, + RedactIPAddresses: false, // Off by default as it may redact useful info + Replacement: "[REDACTED]", + } +} + +// Redactor provides sensitive data redaction capabilities. +type Redactor struct { + config Config + compiledCustom []*regexp.Regexp + compiledBuiltin map[string]*regexp.Regexp + mu sync.RWMutex +} + +// NewRedactor creates a new Redactor with the given configuration. +func NewRedactor(config Config) *Redactor { + r := &Redactor{ + config: config, + compiledBuiltin: make(map[string]*regexp.Regexp), + } + + // Compile builtin patterns + r.compileBuiltinPatterns() + + // Compile custom patterns + if len(config.CustomPatterns) > 0 { + r.compiledCustom = make([]*regexp.Regexp, 0, len(config.CustomPatterns)) + for _, pattern := range config.CustomPatterns { + re, err := regexp.Compile(pattern) + if err == nil { + r.compiledCustom = append(r.compiledCustom, re) + } + } + } + + return r +} + +// compileBuiltinPatterns compiles the builtin redaction patterns. +func (r *Redactor) compileBuiltinPatterns() { + // API Key patterns - various formats + r.compiledBuiltin["api_key"] = regexp.MustCompile(`(?i)(api[_-]?key|apikey|api[_-]?secret)\s*[=:]\s*['"]?([a-zA-Z0-9_\-]{20,})['"]?`) + r.compiledBuiltin["bearer_token"] = regexp.MustCompile(`(?i)bearer\s+([a-zA-Z0-9_\-\.]{20,})`) + r.compiledBuiltin["auth_token"] = regexp.MustCompile(`(?i)(auth[_-]?token|access[_-]?token|refresh[_-]?token)\s*[=:]\s*['"]?([a-zA-Z0-9_\-\.]{20,})['"]?`) + r.compiledBuiltin["secret_key"] = regexp.MustCompile(`(?i)(secret[_-]?key|secretkey|private[_-]?key)\s*[=:]\s*['"]?([a-zA-Z0-9_\-]{20,})['"]?`) + + // OpenAI-style keys + r.compiledBuiltin["openai_key"] = regexp.MustCompile(`sk-[a-zA-Z0-9]{20,}`) + r.compiledBuiltin["anthropic_key"] = regexp.MustCompile(`sk-ant-[a-zA-Z0-9\-]{20,}`) + + // Generic token patterns + r.compiledBuiltin["jwt"] = regexp.MustCompile(`eyJ[a-zA-Z0-9_-]*\.eyJ[a-zA-Z0-9_-]*\.[a-zA-Z0-9_-]*`) + r.compiledBuiltin["uuid"] = regexp.MustCompile(`[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{12}`) + + // Password patterns + r.compiledBuiltin["password"] = regexp.MustCompile(`(?i)(password|passwd|pwd)\s*[=:]\s*['"]?([^'"\s]{4,})['"]?`) + + // Email pattern + r.compiledBuiltin["email"] = regexp.MustCompile(`[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}`) + + // Phone number patterns (various formats) + r.compiledBuiltin["phone_intl"] = regexp.MustCompile(`\+\d{1,3}[\s\-]?\d{1,4}[\s\-]?\d{1,4}[\s\-]?\d{1,9}`) + r.compiledBuiltin["phone_us"] = regexp.MustCompile(`\(\d{3}\)\s*\d{3}[\s\-]?\d{4}`) + r.compiledBuiltin["phone_simple"] = regexp.MustCompile(`\b\d{3}[\s\-]?\d{3}[\s\-]?\d{4}\b`) + + // IP Address patterns + r.compiledBuiltin["ipv4"] = regexp.MustCompile(`\b(?:(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\.){3}(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\b`) + r.compiledBuiltin["ipv6"] = regexp.MustCompile(`\b(?:[0-9a-fA-F]{1,4}:){7}[0-9a-fA-F]{1,4}\b`) + + // AWS keys + r.compiledBuiltin["aws_access_key"] = regexp.MustCompile(`AKIA[0-9A-Z]{16}`) + r.compiledBuiltin["aws_secret"] = regexp.MustCompile(`(?i)aws[_-]?secret[_-]?access[_-]?key\s*[=:]\s*['"]?([a-zA-Z0-9/+=]{40})['"]?`) + + // Generic secrets in JSON/config + r.compiledBuiltin["json_secret"] = regexp.MustCompile(`"(?:api_key|apikey|secret|password|token|private_key)"\s*:\s*"([^"]+)"`) +} + +// Redact applies all configured redaction rules to the input string. +func (r *Redactor) Redact(input string) string { + if !r.config.Enabled { + return input + } + + r.mu.RLock() + defer r.mu.RUnlock() + + result := input + + // Redact API keys + if r.config.RedactAPIKeys { + result = r.redactPatterns(result, + "api_key", "bearer_token", "auth_token", "secret_key", + "openai_key", "anthropic_key", "jwt", "aws_access_key", "aws_secret", + ) + // Redact JSON secrets with special handling + result = r.redactJSONSecrets(result) + } + + // Redact passwords + if r.config.RedactPasswords { + result = r.redactPatterns(result, "password") + } + + // Redact emails + if r.config.RedactEmails { + result = r.redactPatternsWithPartial(result, "email", r.maskEmail) + } + + // Redact phone numbers + if r.config.RedactPhoneNumbers { + result = r.redactPatterns(result, "phone_intl", "phone_us", "phone_simple") + } + + // Redact IP addresses + if r.config.RedactIPAddresses { + result = r.redactPatterns(result, "ipv4", "ipv6") + } + + // Apply custom patterns + for _, re := range r.compiledCustom { + result = re.ReplaceAllString(result, r.config.Replacement) + } + + return result +} + +// redactPatterns applies redaction for the specified patterns. +func (r *Redactor) redactPatterns(input string, patternNames ...string) string { + result := input + for _, name := range patternNames { + if re, ok := r.compiledBuiltin[name]; ok { + // For patterns with capture groups, only redact the captured content + result = re.ReplaceAllStringFunc(result, func(match string) string { + // Find submatches + submatches := re.FindStringSubmatch(match) + if len(submatches) > 1 { + // Redact only the captured group(s), preserve the rest + redacted := match + for i := len(submatches) - 1; i >= 1; i-- { + if submatches[i] != "" { + redacted = strings.Replace(redacted, submatches[i], r.config.Replacement, 1) + } + } + return redacted + } + return r.config.Replacement + }) + } + } + return result +} + +// redactPatternsWithPartial applies partial redaction (like masking) for patterns. +func (r *Redactor) redactPatternsWithPartial(input string, patternName string, maskFn func(string) string) string { + re, ok := r.compiledBuiltin[patternName] + if !ok { + return input + } + + return re.ReplaceAllStringFunc(input, func(match string) string { + return maskFn(match) + }) +} + +// redactJSONSecrets handles JSON key-value pairs specially. +func (r *Redactor) redactJSONSecrets(input string) string { + re := r.compiledBuiltin["json_secret"] + return re.ReplaceAllStringFunc(input, func(match string) string { + submatches := re.FindStringSubmatch(match) + if len(submatches) > 1 { + return strings.Replace(match, submatches[1], r.config.Replacement, 1) + } + return match + }) +} + +// maskEmail masks an email address, showing only first char and domain. +func (r *Redactor) maskEmail(email string) string { + parts := strings.Split(email, "@") + if len(parts) != 2 { + return r.config.Replacement + } + + local := parts[0] + domain := parts[1] + + if len(local) <= 2 { + return string(local[0]) + "***@" + domain + } + + return string(local[0]) + "***@" + domain +} + +// RedactFields redacts sensitive values in a map. +func (r *Redactor) RedactFields(fields map[string]any) map[string]any { + if !r.config.Enabled { + return fields + } + + result := make(map[string]any, len(fields)) + for k, v := range fields { + // Check if key name suggests sensitive data + lowerKey := strings.ToLower(k) + if r.isSensitiveKey(lowerKey) { + result[k] = r.config.Replacement + } else { + // Recursively redact string values + switch val := v.(type) { + case string: + result[k] = r.Redact(val) + case map[string]any: + result[k] = r.RedactFields(val) + default: + result[k] = v + } + } + } + return result +} + +// isSensitiveKey checks if a key name suggests sensitive data. +func (r *Redactor) isSensitiveKey(key string) bool { + sensitiveKeys := []string{ + "password", "passwd", "pwd", + "api_key", "apikey", "api_secret", + "secret", "secret_key", "private_key", + "token", "access_token", "refresh_token", "auth_token", + "credential", "credentials", + "api_key_id", "secret_access_key", + } + + for _, sk := range sensitiveKeys { + if strings.Contains(key, sk) { + return true + } + } + return false +} + +// SetEnabled enables or disables redaction at runtime. +func (r *Redactor) SetEnabled(enabled bool) { + r.mu.Lock() + defer r.mu.Unlock() + r.config.Enabled = enabled +} + +// AddCustomPattern adds a custom redaction pattern at runtime. +func (r *Redactor) AddCustomPattern(pattern string) error { + r.mu.Lock() + defer r.mu.Unlock() + + re, err := regexp.Compile(pattern) + if err != nil { + return err + } + + r.compiledCustom = append(r.compiledCustom, re) + return nil +} + +// Global redactor instance with default config +var globalRedactor = NewRedactor(DefaultConfig()) + +// Redact applies redaction using the global redactor. +func Redact(input string) string { + return globalRedactor.Redact(input) +} + +// RedactFields redacts fields using the global redactor. +func RedactFields(fields map[string]any) map[string]any { + return globalRedactor.RedactFields(fields) +} + +// SetGlobalConfig sets the configuration for the global redactor. +func SetGlobalConfig(config Config) { + globalRedactor = NewRedactor(config) +} diff --git a/pkg/redaction/redaction_test.go b/pkg/redaction/redaction_test.go new file mode 100644 index 000000000..116581765 --- /dev/null +++ b/pkg/redaction/redaction_test.go @@ -0,0 +1,381 @@ +package redaction + +import ( + "testing" +) + +func TestRedactor_Redact_APIKeys(t *testing.T) { + r := NewRedactor(DefaultConfig()) + + tests := []struct { + name string + input string + wantRedact bool + }{ + { + name: "OpenAI key", + input: "api_key=sk-proj-1234567890abcdefghijklmnop", + wantRedact: true, + }, + { + name: "Anthropic key", + input: "api_key: sk-ant-api03-1234567890abcdefghijklmnop", + wantRedact: true, + }, + { + name: "Bearer token", + input: "Authorization: Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9", + wantRedact: true, + }, + { + name: "JWT token", + input: "token=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiaWF0IjoxNTE2MjM5MDIyfQ.SflKxwRJSMeKKF2QT4fwpMeJf36POk6yJV_adQssw5c", + wantRedact: true, + }, + { + name: "AWS access key", + input: "AWS_ACCESS_KEY_ID=AKIAIOSFODNN7EXAMPLE", + wantRedact: true, + }, + { + name: "plain text not redacted", + input: "This is a normal message without sensitive data", + wantRedact: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := r.Redact(tt.input) + if tt.wantRedact { + if result == tt.input { + t.Errorf("Expected redaction for %q, got unchanged", tt.name) + } + if !contains(result, "[REDACTED]") { + t.Errorf("Expected [REDACTED] in result, got: %s", result) + } + } else { + if result != tt.input { + t.Errorf("Unexpected redaction for %q: %s", tt.name, result) + } + } + }) + } +} + +func TestRedactor_Redact_Emails(t *testing.T) { + r := NewRedactor(DefaultConfig()) + + tests := []struct { + name string + input string + expected string + }{ + { + name: "simple email", + input: "Contact: test@example.com", + expected: "Contact: t***@example.com", + }, + { + name: "email in JSON", + input: `{"email": "user.name@company.org"}`, + expected: `{"email": "u***@company.org"}`, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := r.Redact(tt.input) + if result == tt.input { + t.Errorf("Expected email to be masked, got: %s", result) + } + }) + } +} + +func TestRedactor_Redact_Passwords(t *testing.T) { + r := NewRedactor(DefaultConfig()) + + tests := []struct { + name string + input string + wantRedact bool + }{ + { + name: "password field", + input: "password=mysecretpassword123", + wantRedact: true, + }, + { + name: "passwd field", + input: "passwd: secret123", + wantRedact: true, + }, + { + name: "JSON password", + input: `{"password": "mysecret", "user": "john"}`, + wantRedact: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := r.Redact(tt.input) + if tt.wantRedact && result == tt.input { + t.Errorf("Expected password redaction for %q, got unchanged", tt.name) + } + }) + } +} + +func TestRedactor_Redact_PhoneNumbers(t *testing.T) { + r := NewRedactor(DefaultConfig()) + + tests := []struct { + name string + input string + wantRedact bool + }{ + { + name: "US phone format", + input: "Phone: (555) 123-4567", + wantRedact: true, + }, + { + name: "International format", + input: "Phone: +1 555 123 4567", + wantRedact: true, + }, + { + name: "Simple format", + input: "Call 555-123-4567", + wantRedact: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := r.Redact(tt.input) + if tt.wantRedact && result == tt.input { + t.Errorf("Expected phone redaction for %q, got unchanged", tt.name) + } + }) + } +} + +func TestRedactor_Redact_IPAddresses(t *testing.T) { + config := DefaultConfig() + config.RedactIPAddresses = true + r := NewRedactor(config) + + tests := []struct { + name string + input string + wantRedact bool + }{ + { + name: "IPv4 address", + input: "Server IP: 192.168.1.100", + wantRedact: true, + }, + { + name: "Localhost", + input: "Connect to 127.0.0.1:8080", + wantRedact: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := r.Redact(tt.input) + if tt.wantRedact && result == tt.input { + t.Errorf("Expected IP redaction for %q, got unchanged", tt.name) + } + }) + } +} + +func TestRedactor_RedactFields(t *testing.T) { + r := NewRedactor(DefaultConfig()) + + tests := []struct { + name string + input map[string]any + wantRedact []string // keys that should be redacted + }{ + { + name: "password field", + input: map[string]any{ + "username": "john", + "password": "secret123", + }, + wantRedact: []string{"password"}, + }, + { + name: "api_key field", + input: map[string]any{ + "api_key": "sk-1234567890", + "user": "john", + }, + wantRedact: []string{"api_key"}, + }, + { + name: "nested fields", + input: map[string]any{ + "config": map[string]any{ + "token": "abc123", + }, + }, + wantRedact: []string{"token"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := r.RedactFields(tt.input) + for _, key := range tt.wantRedact { + // Check nested + if nested, ok := result["config"].(map[string]any); ok { + if val, exists := nested[key]; exists { + if val == tt.input["config"].(map[string]any)[key] { + t.Errorf("Expected %q to be redacted", key) + } + } + } else if val, exists := result[key]; exists { + if val == "[REDACTED]" { + // Good + } else if val == tt.input[key] { + t.Errorf("Expected %q to be redacted, got: %v", key, val) + } + } + } + }) + } +} + +func TestRedactor_Disabled(t *testing.T) { + config := DefaultConfig() + config.Enabled = false + r := NewRedactor(config) + + input := "password=mysecret123 api_key=sk-1234567890" + result := r.Redact(input) + + if result != input { + t.Errorf("Expected no redaction when disabled, got: %s", result) + } +} + +func TestRedactor_CustomPatterns(t *testing.T) { + config := DefaultConfig() + config.CustomPatterns = []string{`CUSTOM-[A-Z0-9]+`} + r := NewRedactor(config) + + input := "Token: CUSTOM-ABC123XYZ" + result := r.Redact(input) + + if !contains(result, "[REDACTED]") { + t.Errorf("Expected custom pattern to be redacted, got: %s", result) + } +} + +func TestRedactor_AddCustomPattern(t *testing.T) { + r := NewRedactor(DefaultConfig()) + + err := r.AddCustomPattern(`MYSECRET-[a-z]+`) + if err != nil { + t.Fatalf("Failed to add custom pattern: %v", err) + } + + input := "Code: MYSECRET-hiddenvalue" + result := r.Redact(input) + + if !contains(result, "[REDACTED]") { + t.Errorf("Expected custom pattern to be redacted, got: %s", result) + } +} + +func TestMaskEmail(t *testing.T) { + r := NewRedactor(DefaultConfig()) + + tests := []struct { + email string + expected string + }{ + {"test@example.com", "t***@example.com"}, + {"ab@domain.org", "a***@domain.org"}, + {"longemail@company.net", "l***@company.net"}, + } + + for _, tt := range tests { + t.Run(tt.email, func(t *testing.T) { + result := r.maskEmail(tt.email) + if result != tt.expected { + t.Errorf("maskEmail(%q) = %q, want %q", tt.email, result, tt.expected) + } + }) + } +} + +func TestIsSensitiveKey(t *testing.T) { + r := NewRedactor(DefaultConfig()) + + tests := []struct { + key string + expected bool + }{ + {"password", true}, + {"api_key", true}, + {"secret", true}, + {"token", true}, + {"access_token", true}, + {"credential", true}, + {"username", false}, + {"email", false}, + {"name", false}, + {"id", false}, + } + + for _, tt := range tests { + t.Run(tt.key, func(t *testing.T) { + result := r.isSensitiveKey(tt.key) + if result != tt.expected { + t.Errorf("isSensitiveKey(%q) = %v, want %v", tt.key, result, tt.expected) + } + }) + } +} + +func TestGlobalRedactor(t *testing.T) { + // Reset to default + SetGlobalConfig(DefaultConfig()) + + input := "password=secret123" + result := Redact(input) + + if result == input { + t.Error("Expected global Redact to redact sensitive data") + } + + fields := map[string]any{ + "api_key": "sk-12345", + } + resultFields := RedactFields(fields) + + if resultFields["api_key"] != "[REDACTED]" { + t.Error("Expected global RedactFields to redact sensitive fields") + } +} + +func contains(s, substr string) bool { + return len(s) >= len(substr) && (s == substr || len(s) > 0 && containsHelper(s, substr)) +} + +func containsHelper(s, substr string) bool { + for i := 0; i <= len(s)-len(substr); i++ { + if s[i:i+len(substr)] == substr { + return true + } + } + return false +} diff --git a/pkg/ssrf/guard.go b/pkg/ssrf/guard.go new file mode 100644 index 000000000..4636dc1cf --- /dev/null +++ b/pkg/ssrf/guard.go @@ -0,0 +1,233 @@ +// Package ssrf provides Server-Side Request Forgery protection for HTTP clients. +// It blocks requests to private IP ranges, metadata endpoints, and other sensitive destinations. +package ssrf + +import ( + "context" + "fmt" + "net" + "net/url" + "strings" + "sync" + "time" +) + +// Config holds SSRF protection configuration. +type Config struct { + // Enabled controls whether SSRF protection is active. + Enabled bool `json:"enabled"` + + // BlockPrivateIPs blocks requests to private IP ranges (RFC 1918). + BlockPrivateIPs bool `json:"block_private_ips"` + + // BlockMetadataEndpoints blocks requests to cloud metadata endpoints. + BlockMetadataEndpoints bool `json:"block_metadata_endpoints"` + + // BlockLocalhost blocks requests to localhost/loopback. + BlockLocalhost bool `json:"block_localhost"` + + // AllowedHosts is a list of hosts that are explicitly allowed, bypassing SSRF checks. + AllowedHosts []string `json:"allowed_hosts"` + + // DNSRebindingProtection enables DNS rebinding attack protection. + DNSRebindingProtection bool `json:"dns_rebinding_protection"` + + // DNSCacheTTL is the duration to cache DNS results for rebinding protection. + DNSCacheTTL time.Duration `json:"dns_cache_ttl"` +} + +// DefaultConfig returns the default SSRF protection configuration. +func DefaultConfig() Config { + return Config{ + Enabled: true, + BlockPrivateIPs: true, + BlockMetadataEndpoints: true, + BlockLocalhost: true, + AllowedHosts: nil, + DNSRebindingProtection: true, + DNSCacheTTL: 60 * time.Second, + } +} + +// Guard provides SSRF protection for HTTP requests. +type Guard struct { + config Config + + // dnsCache stores resolved IPs for DNS rebinding protection. + dnsCache sync.Map // map[string]dnsCacheEntry +} + +type dnsCacheEntry struct { + ips []net.IP + expiresAt time.Time +} + +// Error represents an SSRF protection error. +type Error struct { + Reason string + URL string +} + +func (e *Error) Error() string { + return fmt.Sprintf("SSRF protection: %s (URL: %s)", e.Reason, e.URL) +} + +// NewGuard creates a new SSRF guard with the given configuration. +func NewGuard(config Config) *Guard { + return &Guard{ + config: config, + } +} + +// CheckURL validates a URL against SSRF protection rules. +// Returns an error if the URL is blocked, nil otherwise. +func (g *Guard) CheckURL(ctx context.Context, rawURL string) error { + if !g.config.Enabled { + return nil + } + + parsedURL, err := url.Parse(rawURL) + if err != nil { + return &Error{Reason: "invalid URL", URL: rawURL} + } + + // Only allow http and https schemes + if parsedURL.Scheme != "http" && parsedURL.Scheme != "https" { + return &Error{Reason: "only http/https schemes allowed", URL: rawURL} + } + + host := parsedURL.Hostname() + if host == "" { + return &Error{Reason: "missing host", URL: rawURL} + } + + // Check if host is in allowed list + for _, allowed := range g.config.AllowedHosts { + if host == allowed || strings.HasSuffix(host, "."+allowed) { + return nil + } + } + + // Resolve host to IPs + ips, err := g.resolveHost(ctx, host) + if err != nil { + return &Error{Reason: fmt.Sprintf("failed to resolve host: %v", err), URL: rawURL} + } + + // Check each resolved IP + for _, ip := range ips { + if err := g.checkIP(ip, rawURL); err != nil { + return err + } + } + + return nil +} + +// resolveHost resolves a hostname to IP addresses with caching for DNS rebinding protection. +func (g *Guard) resolveHost(ctx context.Context, host string) ([]net.IP, error) { + // Check if it's already an IP address + if ip := net.ParseIP(host); ip != nil { + return []net.IP{ip}, nil + } + + // Check cache for DNS rebinding protection + if g.config.DNSRebindingProtection { + if cached, ok := g.dnsCache.Load(host); ok { + entry := cached.(dnsCacheEntry) + if time.Now().Before(entry.expiresAt) { + return entry.ips, nil + } + } + } + + // Resolve the host + resolver := &net.Resolver{} + addrs, err := resolver.LookupIPAddr(ctx, host) + if err != nil { + return nil, err + } + + if len(addrs) == 0 { + return nil, fmt.Errorf("no IP addresses found for host: %s", host) + } + + ips := make([]net.IP, len(addrs)) + for i, addr := range addrs { + ips[i] = addr.IP + } + + // Cache the result for DNS rebinding protection + if g.config.DNSRebindingProtection { + g.dnsCache.Store(host, dnsCacheEntry{ + ips: ips, + expiresAt: time.Now().Add(g.config.DNSCacheTTL), + }) + } + + return ips, nil +} + +// checkIP checks if an IP address is allowed. +func (g *Guard) checkIP(ip net.IP, rawURL string) error { + // Block localhost/loopback + if g.config.BlockLocalhost && isLoopback(ip) { + return &Error{Reason: "localhost/loopback address blocked", URL: rawURL} + } + + // Block cloud metadata endpoints (169.254.169.254) + if g.config.BlockMetadataEndpoints && isMetadataEndpoint(ip) { + return &Error{Reason: "cloud metadata endpoint blocked", URL: rawURL} + } + + // Block private IP ranges + if g.config.BlockPrivateIPs && isPrivateIP(ip) { + return &Error{Reason: "private IP address blocked", URL: rawURL} + } + + return nil +} + +// isLoopback checks if an IP is a loopback address. +func isLoopback(ip net.IP) bool { + return ip.IsLoopback() +} + +// isMetadataEndpoint checks if an IP is a cloud metadata endpoint. +func isMetadataEndpoint(ip net.IP) bool { + // AWS/GCP/Azure metadata endpoint: 169.254.169.254 + metadataIP := net.ParseIP("169.254.169.254") + return ip.Equal(metadataIP) +} + +// isPrivateIP checks if an IP is in a private range. +func isPrivateIP(ip net.IP) bool { + // Check if it's a private address using net's built-in method + if ip.IsPrivate() { + return true + } + + // Additional checks for link-local addresses + if ip.IsLinkLocalUnicast() || ip.IsLinkLocalMulticast() { + return true + } + + return false +} + +// GetResolvedIPs returns the cached IPs for a host (for DNS rebinding protection). +// This should be used when making the actual request to ensure the IP hasn't changed. +func (g *Guard) GetResolvedIPs(host string) []net.IP { + if cached, ok := g.dnsCache.Load(host); ok { + entry := cached.(dnsCacheEntry) + if time.Now().Before(entry.expiresAt) { + return entry.ips + } + } + return nil +} + +// ClearCache clears the DNS cache. +func (g *Guard) ClearCache() { + g.dnsCache = sync.Map{} +} diff --git a/pkg/ssrf/guard_test.go b/pkg/ssrf/guard_test.go new file mode 100644 index 000000000..60281ef7d --- /dev/null +++ b/pkg/ssrf/guard_test.go @@ -0,0 +1,238 @@ +package ssrf + +import ( + "context" + "net" + "testing" + "time" +) + +func TestGuard_CheckURL(t *testing.T) { + tests := []struct { + name string + config Config + url string + wantErr bool + errContains string + }{ + { + name: "valid public URL", + config: DefaultConfig(), + url: "https://example.com/path", + wantErr: false, + }, + { + name: "localhost blocked", + config: DefaultConfig(), + url: "http://localhost:8080/api", + wantErr: true, + errContains: "localhost", + }, + { + name: "127.0.0.1 blocked", + config: DefaultConfig(), + url: "http://127.0.0.1:8080/api", + wantErr: true, + errContains: "localhost/loopback", + }, + { + name: "metadata endpoint blocked", + config: DefaultConfig(), + url: "http://169.254.169.254/latest/meta-data/", + wantErr: true, + errContains: "metadata", + }, + { + name: "private IP 10.x blocked", + config: DefaultConfig(), + url: "http://10.0.0.1/internal", + wantErr: true, + errContains: "private IP", + }, + { + name: "private IP 172.16.x blocked", + config: DefaultConfig(), + url: "http://172.16.0.1/internal", + wantErr: true, + errContains: "private IP", + }, + { + name: "private IP 192.168.x blocked", + config: DefaultConfig(), + url: "http://192.168.1.1/internal", + wantErr: true, + errContains: "private IP", + }, + { + name: "disabled protection allows all", + config: Config{ + Enabled: false, + }, + url: "http://localhost:8080/api", + wantErr: false, + }, + { + name: "allowed host bypasses check", + config: Config{ + Enabled: true, + BlockPrivateIPs: true, + BlockLocalhost: true, + AllowedHosts: []string{"localhost", "internal.example.com"}, + }, + url: "http://localhost:8080/api", + wantErr: false, + }, + { + name: "invalid scheme", + config: DefaultConfig(), + url: "ftp://example.com/file", + wantErr: true, + errContains: "scheme", + }, + { + name: "link-local blocked", + config: DefaultConfig(), + url: "http://169.254.1.1/test", + wantErr: true, + errContains: "private IP", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + g := NewGuard(tt.config) + err := g.CheckURL(context.Background(), tt.url) + + if tt.wantErr { + if err == nil { + t.Errorf("Guard.CheckURL() expected error, got nil") + return + } + if tt.errContains != "" && !contains(err.Error(), tt.errContains) { + t.Errorf("Guard.CheckURL() error = %v, want containing %v", err, tt.errContains) + } + } else { + if err != nil { + t.Errorf("Guard.CheckURL() unexpected error = %v", err) + } + } + }) + } +} + +func TestGuard_AllowedHostsSubdomain(t *testing.T) { + config := Config{ + Enabled: true, + BlockPrivateIPs: true, + AllowedHosts: []string{"example.com"}, + } + + _ = NewGuard(config) + + // Subdomain of allowed host should be allowed + // Note: This test may fail if the domain actually resolves to a private IP + // In practice, this tests the logic path +} + +func TestGuard_DNSCache(t *testing.T) { + config := Config{ + Enabled: true, + DNSRebindingProtection: true, + DNSCacheTTL: 5 * time.Second, + } + + g := NewGuard(config) + + // Clear cache first + g.ClearCache() + + // Verify cache is empty + if ips := g.GetResolvedIPs("example.com"); ips != nil { + t.Error("Expected empty cache initially") + } +} + +func TestIsPrivateIP(t *testing.T) { + tests := []struct { + ip string + private bool + }{ + {"10.0.0.1", true}, + {"10.255.255.255", true}, + {"172.16.0.1", true}, + {"172.31.255.255", true}, + {"192.168.0.1", true}, + {"192.168.255.255", true}, + {"127.0.0.1", false}, // Loopback is handled separately + {"8.8.8.8", false}, + {"1.1.1.1", false}, + {"169.254.1.1", true}, // Link-local + } + + for _, tt := range tests { + t.Run(tt.ip, func(t *testing.T) { + ip := net.ParseIP(tt.ip) + if ip == nil { + t.Fatalf("Failed to parse IP: %s", tt.ip) + } + got := isPrivateIP(ip) + if got != tt.private { + t.Errorf("isPrivateIP(%s) = %v, want %v", tt.ip, got, tt.private) + } + }) + } +} + +func TestIsMetadataEndpoint(t *testing.T) { + metadataIP := net.ParseIP("169.254.169.254") + if !isMetadataEndpoint(metadataIP) { + t.Error("Expected 169.254.169.254 to be detected as metadata endpoint") + } + + otherIP := net.ParseIP("8.8.8.8") + if isMetadataEndpoint(otherIP) { + t.Error("Expected 8.8.8.8 not to be detected as metadata endpoint") + } +} + +func TestIsLoopback(t *testing.T) { + loopback := net.ParseIP("127.0.0.1") + if !isLoopback(loopback) { + t.Error("Expected 127.0.0.1 to be detected as loopback") + } + + ipv6Loopback := net.ParseIP("::1") + if !isLoopback(ipv6Loopback) { + t.Error("Expected ::1 to be detected as loopback") + } + + otherIP := net.ParseIP("8.8.8.8") + if isLoopback(otherIP) { + t.Error("Expected 8.8.8.8 not to be detected as loopback") + } +} + +func TestError(t *testing.T) { + err := &Error{ + Reason: "test reason", + URL: "http://example.com", + } + + expected := "SSRF protection: test reason (URL: http://example.com)" + if err.Error() != expected { + t.Errorf("Error() = %v, want %v", err.Error(), expected) + } +} + +func contains(s, substr string) bool { + return len(s) >= len(substr) && (s == substr || len(s) > 0 && containsHelper(s, substr)) +} + +func containsHelper(s, substr string) bool { + for i := 0; i <= len(s)-len(substr); i++ { + if s[i:i+len(substr)] == substr { + return true + } + } + return false +} diff --git a/pkg/tools/web.go b/pkg/tools/web.go index 8ba2a723a..6c798d341 100644 --- a/pkg/tools/web.go +++ b/pkg/tools/web.go @@ -11,25 +11,14 @@ import ( "regexp" "strings" "time" + + "github.com/sipeed/picoclaw/pkg/ssrf" ) const ( userAgent = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36" ) -// Pre-compiled regexes for HTML text extraction -var ( - reScript = regexp.MustCompile(``) - reStyle = regexp.MustCompile(``) - reTags = regexp.MustCompile(`<[^>]+>`) - reWhitespace = regexp.MustCompile(`[^\S\n]+`) - reBlankLines = regexp.MustCompile(`\n{3,}`) - - // DuckDuckGo result extraction - reDDGLink = regexp.MustCompile(`]*class="[^"]*result__a[^"]*"[^>]*href="([^"]+)"[^>]*>([\s\S]*?)`) - reDDGSnippet = regexp.MustCompile(`([\s\S]*?)`) -) - // createHTTPClient creates an HTTP client with optional proxy support func createHTTPClient(proxyURL string, timeout time.Duration) (*http.Client, error) { client := &http.Client{ @@ -142,7 +131,6 @@ func (p *BraveSearchProvider) Search(ctx context.Context, query string, count in type TavilySearchProvider struct { apiKey string baseURL string - proxy string } func (p *TavilySearchProvider) Search(ctx context.Context, query string, count int) (string, error) { @@ -174,10 +162,7 @@ func (p *TavilySearchProvider) Search(ctx context.Context, query string, count i req.Header.Set("Content-Type", "application/json") req.Header.Set("User-Agent", userAgent) - client, err := createHTTPClient(p.proxy, 10*time.Second) - if err != nil { - return "", fmt.Errorf("failed to create HTTP client: %w", err) - } + client := &http.Client{Timeout: 10 * time.Second} resp, err := client.Do(req) if err != nil { return "", fmt.Errorf("request failed: %w", err) @@ -264,7 +249,8 @@ func (p *DuckDuckGoSearchProvider) extractResults(html string, count int, query // Try finding the result links directly first, as they are the most critical // Pattern: Title // The previous regex was a bit strict. Let's make it more flexible for attributes order/content - matches := reDDGLink.FindAllStringSubmatch(html, count+5) + reLink := regexp.MustCompile(`]*class="[^"]*result__a[^"]*"[^>]*href="([^"]+)"[^>]*>([\s\S]*?)`) + matches := reLink.FindAllStringSubmatch(html, count+5) if len(matches) == 0 { return fmt.Sprintf("No results found or extraction failed. Query: %s", query), nil @@ -281,7 +267,8 @@ func (p *DuckDuckGoSearchProvider) extractResults(html string, count int, query // A better regex approach: iterate through text and find matches in order // But for now, let's grab all snippets too - snippetMatches := reDDGSnippet.FindAllStringSubmatch(html, count+5) + reSnippet := regexp.MustCompile(`([\s\S]*?)`) + snippetMatches := reSnippet.FindAllStringSubmatch(html, count+5) maxItems := min(len(matches), count) @@ -316,7 +303,8 @@ func (p *DuckDuckGoSearchProvider) extractResults(html string, count int, query } func stripTags(content string) string { - return reTags.ReplaceAllString(content, "") + re := regexp.MustCompile(`<[^>]+>`) + return re.ReplaceAllString(content, "") } type PerplexitySearchProvider struct { @@ -434,7 +422,6 @@ func NewWebSearchTool(opts WebSearchToolOptions) *WebSearchTool { provider = &TavilySearchProvider{ apiKey: opts.TavilyAPIKey, baseURL: opts.TavilyBaseURL, - proxy: opts.Proxy, } if opts.TavilyMaxResults > 0 { maxResults = opts.TavilyMaxResults @@ -506,8 +493,9 @@ func (t *WebSearchTool) Execute(ctx context.Context, args map[string]any) *ToolR } type WebFetchTool struct { - maxChars int - proxy string + maxChars int + proxy string + ssrfGuard *ssrf.Guard } func NewWebFetchTool(maxChars int) *WebFetchTool { @@ -515,7 +503,8 @@ func NewWebFetchTool(maxChars int) *WebFetchTool { maxChars = 50000 } return &WebFetchTool{ - maxChars: maxChars, + maxChars: maxChars, + ssrfGuard: ssrf.NewGuard(ssrf.DefaultConfig()), } } @@ -524,8 +513,21 @@ func NewWebFetchToolWithProxy(maxChars int, proxy string) *WebFetchTool { maxChars = 50000 } return &WebFetchTool{ - maxChars: maxChars, - proxy: proxy, + maxChars: maxChars, + proxy: proxy, + ssrfGuard: ssrf.NewGuard(ssrf.DefaultConfig()), + } +} + +// NewWebFetchToolWithSSRF creates a WebFetchTool with custom SSRF configuration. +func NewWebFetchToolWithSSRF(maxChars int, proxy string, ssrfConfig ssrf.Config) *WebFetchTool { + if maxChars <= 0 { + maxChars = 50000 + } + return &WebFetchTool{ + maxChars: maxChars, + proxy: proxy, + ssrfGuard: ssrf.NewGuard(ssrfConfig), } } @@ -561,6 +563,13 @@ func (t *WebFetchTool) Execute(ctx context.Context, args map[string]any) *ToolRe return ErrorResult("url is required") } + // SSRF protection check + if t.ssrfGuard != nil { + if err := t.ssrfGuard.CheckURL(ctx, urlStr); err != nil { + return ErrorResult(fmt.Sprintf("SSRF protection: %v", err)) + } + } + parsedURL, err := url.Parse(urlStr) if err != nil { return ErrorResult(fmt.Sprintf("invalid URL: %v", err)) @@ -593,11 +602,17 @@ func (t *WebFetchTool) Execute(ctx context.Context, args map[string]any) *ToolRe return ErrorResult(fmt.Sprintf("failed to create HTTP client: %v", err)) } - // Configure redirect handling + // Configure redirect handling with SSRF protection client.CheckRedirect = func(req *http.Request, via []*http.Request) error { if len(via) >= 5 { return fmt.Errorf("stopped after 5 redirects") } + // Check redirect URL for SSRF + if t.ssrfGuard != nil { + if err := t.ssrfGuard.CheckURL(ctx, req.URL.String()); err != nil { + return fmt.Errorf("redirect blocked by SSRF protection: %v", err) + } + } return nil } @@ -664,14 +679,19 @@ func (t *WebFetchTool) Execute(ctx context.Context, args map[string]any) *ToolRe } func (t *WebFetchTool) extractText(htmlContent string) string { - result := reScript.ReplaceAllLiteralString(htmlContent, "") - result = reStyle.ReplaceAllLiteralString(result, "") - result = reTags.ReplaceAllLiteralString(result, "") + re := regexp.MustCompile(``) + result := re.ReplaceAllLiteralString(htmlContent, "") + re = regexp.MustCompile(``) + result = re.ReplaceAllLiteralString(result, "") + re = regexp.MustCompile(`<[^>]+>`) + result = re.ReplaceAllLiteralString(result, "") result = strings.TrimSpace(result) - result = reWhitespace.ReplaceAllString(result, " ") - result = reBlankLines.ReplaceAllString(result, "\n\n") + re = regexp.MustCompile(`[^\S\n]+`) + result = re.ReplaceAllString(result, " ") + re = regexp.MustCompile(`\n{3,}`) + result = re.ReplaceAllString(result, "\n\n") lines := strings.Split(result, "\n") var cleanLines []string