From c79d8c270fcd7637ddee740f712213c347647996 Mon Sep 17 00:00:00 2001 From: Mathieu Virbel Date: Sun, 8 Mar 2026 10:52:01 -0600 Subject: [PATCH 01/13] feat: MITM TLS interception for HTTP content inspection Add TLS man-in-the-middle capability so greyproxy can decrypt and log HTTPS request/response content flowing through the proxy. - Add `greyproxy cert generate/install` CLI for CA certificate management - Auto-inject MITM cert paths into HTTP/SOCKS5 handler metadata on startup - Enable sniffing by default in greyproxy.yml for both proxy services - Add OnHTTPRoundTrip callback to Sniffer for decrypted traffic hooks - Wire [MITM] log output in both HTTP and SOCKS5 handlers - Fix GenerateCertificate to auto-detect key type (ECDSA/Ed25519/RSA) instead of hardcoding SHA256WithRSA --- cmd/greyproxy/cert.go | 176 ++++++++++++++++++ cmd/greyproxy/main.go | 4 + cmd/greyproxy/program.go | 23 +++ greyproxy.yml | 4 + internal/gostx/handler/http/handler.go | 19 ++ internal/gostx/handler/socks/v5/connect.go | 19 ++ .../gostx/internal/util/sniffing/sniffer.go | 90 +++++++-- internal/gostx/internal/util/tls/tls.go | 18 +- 8 files changed, 336 insertions(+), 17 deletions(-) create mode 100644 cmd/greyproxy/cert.go diff --git a/cmd/greyproxy/cert.go b/cmd/greyproxy/cert.go new file mode 100644 index 0000000..e516b43 --- /dev/null +++ b/cmd/greyproxy/cert.go @@ -0,0 +1,176 @@ +package main + +import ( + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" + "crypto/x509" + "crypto/x509/pkix" + "encoding/pem" + "fmt" + "math/big" + "os" + "os/exec" + "path/filepath" + "runtime" + "time" +) + +func handleCert(args []string) { + if len(args) == 0 { + fmt.Fprintf(os.Stderr, `Usage: greyproxy cert + +Commands: + generate Generate CA certificate and key pair + install Trust the CA certificate on the OS + +Options: + -f Force overwrite existing files (generate only) +`) + os.Exit(1) + } + + switch args[0] { + case "generate": + force := len(args) > 1 && args[1] == "-f" + handleCertGenerate(force) + case "install": + handleCertInstall() + default: + fmt.Fprintf(os.Stderr, "unknown cert command: %s\n", args[0]) + os.Exit(1) + } +} + +func handleCertGenerate(force bool) { + dataDir := greyproxyDataHome() + certFile := filepath.Join(dataDir, "ca-cert.pem") + keyFile := filepath.Join(dataDir, "ca-key.pem") + + if !force { + if _, err := os.Stat(certFile); err == nil { + fmt.Fprintf(os.Stderr, "CA certificate already exists: %s\nUse -f to overwrite.\n", certFile) + os.Exit(1) + } + if _, err := os.Stat(keyFile); err == nil { + fmt.Fprintf(os.Stderr, "CA key already exists: %s\nUse -f to overwrite.\n", keyFile) + os.Exit(1) + } + } + + // Generate ECDSA P-256 key + privateKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + if err != nil { + fmt.Fprintf(os.Stderr, "failed to generate private key: %v\n", err) + os.Exit(1) + } + + // Create self-signed CA certificate + serialNumber, err := rand.Int(rand.Reader, new(big.Int).Lsh(big.NewInt(1), 128)) + if err != nil { + fmt.Fprintf(os.Stderr, "failed to generate serial number: %v\n", err) + os.Exit(1) + } + + template := &x509.Certificate{ + SerialNumber: serialNumber, + Subject: pkix.Name{ + CommonName: "Greyproxy CA", + }, + NotBefore: time.Now(), + NotAfter: time.Now().Add(10 * 365 * 24 * time.Hour), + KeyUsage: x509.KeyUsageCertSign | x509.KeyUsageCRLSign, + BasicConstraintsValid: true, + IsCA: true, + MaxPathLen: 0, + MaxPathLenZero: true, + } + + certDER, err := x509.CreateCertificate(rand.Reader, template, template, &privateKey.PublicKey, privateKey) + if err != nil { + fmt.Fprintf(os.Stderr, "failed to create certificate: %v\n", err) + os.Exit(1) + } + + // Ensure data directory exists + if err := os.MkdirAll(dataDir, 0700); err != nil { + fmt.Fprintf(os.Stderr, "failed to create data directory: %v\n", err) + os.Exit(1) + } + + // Write certificate + certOut, err := os.OpenFile(certFile, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0644) + if err != nil { + fmt.Fprintf(os.Stderr, "failed to write certificate: %v\n", err) + os.Exit(1) + } + if err := pem.Encode(certOut, &pem.Block{Type: "CERTIFICATE", Bytes: certDER}); err != nil { + certOut.Close() + fmt.Fprintf(os.Stderr, "failed to encode certificate: %v\n", err) + os.Exit(1) + } + certOut.Close() + + // Write private key + keyBytes, err := x509.MarshalECPrivateKey(privateKey) + if err != nil { + fmt.Fprintf(os.Stderr, "failed to marshal private key: %v\n", err) + os.Exit(1) + } + + keyOut, err := os.OpenFile(keyFile, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0600) + if err != nil { + fmt.Fprintf(os.Stderr, "failed to write key: %v\n", err) + os.Exit(1) + } + if err := pem.Encode(keyOut, &pem.Block{Type: "EC PRIVATE KEY", Bytes: keyBytes}); err != nil { + keyOut.Close() + fmt.Fprintf(os.Stderr, "failed to encode key: %v\n", err) + os.Exit(1) + } + keyOut.Close() + + fmt.Printf("CA certificate: %s\n", certFile) + fmt.Printf("CA private key: %s\n", keyFile) + fmt.Println("\nRun 'greyproxy cert install' to trust this CA on your system.") +} + +func handleCertInstall() { + certFile := filepath.Join(greyproxyDataHome(), "ca-cert.pem") + + if _, err := os.Stat(certFile); os.IsNotExist(err) { + fmt.Fprintf(os.Stderr, "CA certificate not found: %s\nRun 'greyproxy cert generate' first.\n", certFile) + os.Exit(1) + } + + switch runtime.GOOS { + case "darwin": + // Remove any existing Greyproxy CA cert to avoid errSecDuplicateItem (-25294) + exec.Command("security", "delete-certificate", "-c", "Greyproxy CA").Run() + + fmt.Println("Installing CA certificate into system trust store (requires sudo)...") + cmd := exec.Command("sudo", "security", "add-trusted-cert", + "-d", "-r", "trustRoot", + "-k", "/Library/Keychains/System.keychain", + certFile, + ) + cmd.Stdout = os.Stdout + cmd.Stderr = os.Stderr + cmd.Stdin = os.Stdin + if err := cmd.Run(); err != nil { + fmt.Fprintf(os.Stderr, "\nAutomatic install failed. Please run manually:\n\n") + fmt.Fprintf(os.Stderr, " sudo security add-trusted-cert -d -r trustRoot -k /Library/Keychains/System.keychain \"%s\"\n\n", certFile) + os.Exit(1) + } + fmt.Println("CA certificate installed and trusted.") + + case "linux": + fmt.Printf("To trust the CA certificate on Linux, run:\n\n") + fmt.Printf(" sudo cp %s /usr/local/share/ca-certificates/greyproxy-ca.crt\n", certFile) + fmt.Printf(" sudo update-ca-certificates\n") + + default: + fmt.Printf("CA certificate is at: %s\n", certFile) + fmt.Printf("Please install it manually in your OS trust store.\n") + } +} diff --git a/cmd/greyproxy/main.go b/cmd/greyproxy/main.go index 75b83d3..37c256c 100644 --- a/cmd/greyproxy/main.go +++ b/cmd/greyproxy/main.go @@ -142,6 +142,9 @@ func main() { case "uninstall": handleUninstall(os.Args[2:]) + case "cert": + handleCert(os.Args[2:]) + case "-V", "--version": fmt.Fprintf(os.Stdout, "greyproxy %s (%s %s/%s)\n built: %s\n commit: %s\n", version, runtime.Version(), runtime.GOOS, runtime.GOARCH, buildTime, gitCommit) @@ -159,6 +162,7 @@ Usage: greyproxy Commands: serve Run the proxy server in foreground + cert Manage MITM CA certificate (generate/install) install Install binary and register as a background service [-f] uninstall Stop service, remove registration and binary [-f] service Manage the OS service (start/stop/restart/status/...) diff --git a/cmd/greyproxy/program.go b/cmd/greyproxy/program.go index 4e89f76..d53024a 100644 --- a/cmd/greyproxy/program.go +++ b/cmd/greyproxy/program.go @@ -64,6 +64,29 @@ func (p *program) Start(s service.Service) error { os.Exit(0) } + // Auto-inject MITM cert paths if CA files exist + certFile := filepath.Join(greyproxyDataHome(), "ca-cert.pem") + keyFile := filepath.Join(greyproxyDataHome(), "ca-key.pem") + if _, err := os.Stat(certFile); err == nil { + if _, err := os.Stat(keyFile); err == nil { + for _, svc := range cfg.Services { + if svc.Handler == nil { + continue + } + if svc.Handler.Type != "http" && svc.Handler.Type != "socks5" { + continue + } + if svc.Handler.Metadata == nil { + svc.Handler.Metadata = make(map[string]any) + } + if _, ok := svc.Handler.Metadata["mitm.certFile"]; !ok { + svc.Handler.Metadata["mitm.certFile"] = certFile + svc.Handler.Metadata["mitm.keyFile"] = keyFile + } + } + } + } + config.Set(cfg) // Override DNS handler to capture responses for DNS cache population. diff --git a/greyproxy.yml b/greyproxy.yml index 0d741c8..cb0e2e8 100644 --- a/greyproxy.yml +++ b/greyproxy.yml @@ -31,6 +31,8 @@ services: handler: type: http auther: auther-0 + metadata: + sniffing: true listener: type: tcp admission: admission-0 @@ -44,6 +46,8 @@ services: handler: type: socks5 auther: auther-0 + metadata: + sniffing: true listener: type: tcp admission: admission-0 diff --git a/internal/gostx/handler/http/handler.go b/internal/gostx/handler/http/handler.go index 409211e..1deac3f 100644 --- a/internal/gostx/handler/http/handler.go +++ b/internal/gostx/handler/http/handler.go @@ -417,6 +417,7 @@ func (h *httpHandler) handleRequest(ctx context.Context, conn net.Conn, req *htt CertPool: h.certPool, MitmBypass: h.md.mitmBypass, ReadTimeout: h.md.readTimeout, + OnHTTPRoundTrip: mitmLogHook(log), } conn = xnet.NewReadWriteConn(br, conn, conn) @@ -1002,3 +1003,21 @@ func (h *httpHandler) observeStats(ctx context.Context) { } } } + +func mitmLogHook(log logger.Logger) func(info sniffing.HTTPRoundTripInfo) { + return func(info sniffing.HTTPRoundTripInfo) { + log.Infof("[MITM] %s %s%s → %d", info.Method, info.Host, info.URI, info.StatusCode) + log.Debugf("[MITM] Request Headers: %v", info.RequestHeaders) + if len(info.RequestBody) > 0 { + log.Debugf("[MITM] Request Body: %s", info.RequestBody) + } + log.Debugf("[MITM] Response Headers: %v", info.ResponseHeaders) + if len(info.ResponseBody) > 0 { + bodyPreview := info.ResponseBody + if len(bodyPreview) > 512 { + bodyPreview = bodyPreview[:512] + } + log.Debugf("[MITM] Response Body (%d bytes): %s", len(info.ResponseBody), bodyPreview) + } + } +} diff --git a/internal/gostx/handler/socks/v5/connect.go b/internal/gostx/handler/socks/v5/connect.go index 727eb6e..51a2938 100644 --- a/internal/gostx/handler/socks/v5/connect.go +++ b/internal/gostx/handler/socks/v5/connect.go @@ -168,6 +168,7 @@ func (h *socks5Handler) handleConnect(ctx context.Context, conn net.Conn, networ CertPool: h.certPool, MitmBypass: h.md.mitmBypass, ReadTimeout: h.md.readTimeout, + OnHTTPRoundTrip: mitmLogHook(log), } conn = xnet.NewReadWriteConn(br, conn, conn) @@ -201,3 +202,21 @@ func (h *socks5Handler) handleConnect(ctx context.Context, conn net.Conn, networ return nil } + +func mitmLogHook(log logger.Logger) func(info sniffing.HTTPRoundTripInfo) { + return func(info sniffing.HTTPRoundTripInfo) { + log.Infof("[MITM] %s %s%s → %d", info.Method, info.Host, info.URI, info.StatusCode) + log.Debugf("[MITM] Request Headers: %v", info.RequestHeaders) + if len(info.RequestBody) > 0 { + log.Debugf("[MITM] Request Body: %s", info.RequestBody) + } + log.Debugf("[MITM] Response Headers: %v", info.ResponseHeaders) + if len(info.ResponseBody) > 0 { + bodyPreview := info.ResponseBody + if len(bodyPreview) > 512 { + bodyPreview = bodyPreview[:512] + } + log.Debugf("[MITM] Response Body (%d bytes): %s", len(info.ResponseBody), bodyPreview) + } + } +} diff --git a/internal/gostx/internal/util/sniffing/sniffer.go b/internal/gostx/internal/util/sniffing/sniffer.go index ab78e5c..371a6d7 100644 --- a/internal/gostx/internal/util/sniffing/sniffer.go +++ b/internal/gostx/internal/util/sniffing/sniffer.go @@ -101,6 +101,19 @@ func WithLog(log logger.Logger) HandleOption { } } +// HTTPRoundTripInfo contains decrypted HTTP request/response data from a MITM round-trip. +type HTTPRoundTripInfo struct { + Host string + Method string + URI string + Proto string + StatusCode int + RequestHeaders http.Header + RequestBody []byte + ResponseHeaders http.Header + ResponseBody []byte +} + type Sniffer struct { Websocket bool WebsocketSampleRate float64 @@ -116,6 +129,9 @@ type Sniffer struct { MitmBypass bypass.Bypass ReadTimeout time.Duration + + // OnHTTPRoundTrip is called after each decrypted HTTP round-trip with request/response details. + OnHTTPRoundTrip func(info HTTPRoundTripInfo) } func (h *Sniffer) HandleHTTP(ctx context.Context, network string, conn net.Conn, opts ...HandleOption) error { @@ -272,6 +288,7 @@ func (h *Sniffer) serveH2(ctx context.Context, network string, conn net.Conn, ho recorderOptions: h.RecorderOptions, recorderObject: ro, log: log, + onHTTPRoundTrip: h.OnHTTPRoundTrip, }, }) return nil @@ -326,11 +343,12 @@ func (h *Sniffer) httpRoundTrip(ctx context.Context, rw, cc io.ReadWriteCloser, } var reqBody *xhttp.Body - if opts := h.RecorderOptions; opts != nil && opts.HTTPBody { + captureBody := (h.RecorderOptions != nil && h.RecorderOptions.HTTPBody) || h.OnHTTPRoundTrip != nil + if captureBody { if req.Body != nil { - bodySize := opts.MaxBodySize - if bodySize <= 0 { - bodySize = DefaultBodySize + bodySize := DefaultBodySize + if opts := h.RecorderOptions; opts != nil && opts.MaxBodySize > 0 { + bodySize = opts.MaxBodySize } if bodySize > MaxBodySize { bodySize = MaxBodySize @@ -395,10 +413,10 @@ func (h *Sniffer) httpRoundTrip(ctx context.Context, rw, cc io.ReadWriteCloser, } var respBody *xhttp.Body - if opts := h.RecorderOptions; opts != nil && opts.HTTPBody { - bodySize := opts.MaxBodySize - if bodySize <= 0 { - bodySize = DefaultBodySize + if captureBody { + bodySize := DefaultBodySize + if opts := h.RecorderOptions; opts != nil && opts.MaxBodySize > 0 { + bodySize = opts.MaxBodySize } if bodySize > MaxBodySize { bodySize = MaxBodySize @@ -419,6 +437,25 @@ func (h *Sniffer) httpRoundTrip(ctx context.Context, rw, cc io.ReadWriteCloser, return } + if h.OnHTTPRoundTrip != nil { + info := HTTPRoundTripInfo{ + Host: req.Host, + Method: req.Method, + URI: req.RequestURI, + Proto: req.Proto, + StatusCode: resp.StatusCode, + RequestHeaders: ro.HTTP.Request.Header, + ResponseHeaders: ro.HTTP.Response.Header, + } + if reqBody != nil { + info.RequestBody = reqBody.Content() + } + if respBody != nil { + info.ResponseBody = respBody.Content() + } + h.OnHTTPRoundTrip(info) + } + if resp.ContentLength >= 0 { close = resp.Close } @@ -789,6 +826,7 @@ type h2Handler struct { recorderOptions *recorder.Options recorderObject *xrecorder.HandlerRecorderObject log logger.Logger + onHTTPRoundTrip func(info HTTPRoundTripInfo) } func (h *h2Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { @@ -848,11 +886,12 @@ func (h *h2Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { } var reqBody *xhttp.Body - if opts := h.recorderOptions; opts != nil && opts.HTTPBody { + h2CaptureBody := (h.recorderOptions != nil && h.recorderOptions.HTTPBody) || h.onHTTPRoundTrip != nil + if h2CaptureBody { if req.Body != nil { - bodySize := opts.MaxBodySize - if bodySize <= 0 { - bodySize = DefaultBodySize + bodySize := DefaultBodySize + if opts := h.recorderOptions; opts != nil && opts.MaxBodySize > 0 { + bodySize = opts.MaxBodySize } if bodySize > MaxBodySize { bodySize = MaxBodySize @@ -888,10 +927,10 @@ func (h *h2Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { w.WriteHeader(resp.StatusCode) var respBody *xhttp.Body - if opts := h.recorderOptions; opts != nil && opts.HTTPBody { - bodySize := opts.MaxBodySize - if bodySize <= 0 { - bodySize = DefaultBodySize + if h2CaptureBody { + bodySize := DefaultBodySize + if opts := h.recorderOptions; opts != nil && opts.MaxBodySize > 0 { + bodySize = opts.MaxBodySize } if bodySize > MaxBodySize { bodySize = MaxBodySize @@ -906,6 +945,25 @@ func (h *h2Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { ro.HTTP.Response.Body = respBody.Content() ro.HTTP.Response.ContentLength = respBody.Length() } + + if h.onHTTPRoundTrip != nil { + info := HTTPRoundTripInfo{ + Host: r.Host, + Method: r.Method, + URI: r.RequestURI, + Proto: r.Proto, + StatusCode: resp.StatusCode, + RequestHeaders: ro.HTTP.Request.Header, + ResponseHeaders: ro.HTTP.Response.Header, + } + if reqBody != nil { + info.RequestBody = reqBody.Content() + } + if respBody != nil { + info.ResponseBody = respBody.Content() + } + h.onHTTPRoundTrip(info) + } } func (h *h2Handler) setHeader(w http.ResponseWriter, header http.Header) { diff --git a/internal/gostx/internal/util/tls/tls.go b/internal/gostx/internal/util/tls/tls.go index dbf3d53..21f9fc6 100644 --- a/internal/gostx/internal/util/tls/tls.go +++ b/internal/gostx/internal/util/tls/tls.go @@ -2,6 +2,8 @@ package tls import ( "crypto" + "crypto/ecdsa" + "crypto/ed25519" "crypto/rand" "crypto/tls" "crypto/x509" @@ -423,6 +425,8 @@ func GenerateCertificate(serverName string, validity time.Duration, caCert *x509 serverName = host } + sigAlg := sigAlgorithm(caKey) + tmpl := &x509.Certificate{ SerialNumber: big.NewInt(time.Now().UnixNano() / 100000), Subject: pkix.Name{ @@ -430,7 +434,7 @@ func GenerateCertificate(serverName string, validity time.Duration, caCert *x509 }, NotBefore: time.Now().Add(-validity), NotAfter: time.Now().Add(validity), - SignatureAlgorithm: x509.SHA256WithRSA, + SignatureAlgorithm: sigAlg, ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth, x509.ExtKeyUsageClientAuth}, } @@ -454,6 +458,18 @@ func GenerateCertificate(serverName string, validity time.Duration, caCert *x509 return x509.ParseCertificate(raw) } +// sigAlgorithm returns the appropriate x509.SignatureAlgorithm for the given private key. +func sigAlgorithm(key crypto.PrivateKey) x509.SignatureAlgorithm { + switch key.(type) { + case *ecdsa.PrivateKey: + return x509.ECDSAWithSHA256 + case ed25519.PrivateKey: + return x509.PureEd25519 + default: + return x509.SHA256WithRSA + } +} + // https://pkg.go.dev/crypto#PrivateKey type privateKey interface { Public() crypto.PublicKey From 2b96f42d25f6b8cbb92812a1c0cf8844b1226f47 Mon Sep 17 00:00:00 2001 From: Mathieu Virbel Date: Sun, 8 Mar 2026 14:26:22 -0600 Subject: [PATCH 02/13] =?UTF-8?q?feat:=20Phase=201=20observability=20?= =?UTF-8?q?=E2=80=94=20capture=20and=20display=20MITM=20HTTP=20transaction?= =?UTF-8?q?s?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Add end-to-end HTTP transaction recording for MITM-intercepted requests: - DB: migration for http_transactions table with indexes - Models: HttpTransaction, HttpTransactionJSON, HttpTransactionCreateInput - CRUD: create, get, query with filtering (container, destination, method, date range) - API: GET /api/transactions (list) and GET /api/transactions/:id (detail with body) - UI: Traffic page with HTMX table, filters, pagination, expandable row details - Hook: GlobalHTTPRoundTripHook in sniffer + bridge via gostx.SetGlobalMitmHook - Wire: program.go connects MITM hook to DB storage and EventBus - Fix: auther srcAddrKey type mismatch — use xctx.SrcAddrFromContext for correct client IP resolution (was causing unknown-unknown container names) - Fix: remove custom lt/gt template funcs that shadowed Go builtins and broke int64 comparisons in traffic table (only first row rendered) - Tests: API handler tests (9), HTMX route tests (10), CRUD tests (5), plugin tests --- cmd/greyproxy/program.go | 40 +++ .../gostx/internal/util/sniffing/sniffer.go | 36 +- internal/gostx/mitm_hook.go | 46 +++ internal/greyproxy/api/router.go | 3 + internal/greyproxy/api/transactions.go | 90 +++++ internal/greyproxy/api/transactions_test.go | 281 +++++++++++++++ internal/greyproxy/crud.go | 163 +++++++++ internal/greyproxy/crud_test.go | 235 ++++++++++++- internal/greyproxy/events.go | 1 + internal/greyproxy/migrations.go | 28 ++ internal/greyproxy/models.go | 123 +++++++ internal/greyproxy/plugins/auther.go | 9 +- internal/greyproxy/plugins/bypass.go | 5 +- internal/greyproxy/plugins/plugins_test.go | 2 +- internal/greyproxy/ui/pages.go | 69 +++- internal/greyproxy/ui/pages_test.go | 322 ++++++++++++++++++ internal/greyproxy/ui/templates/base.html | 14 + .../ui/templates/partials/traffic_table.html | 102 ++++++ internal/greyproxy/ui/templates/traffic.html | 152 +++++++++ 19 files changed, 1699 insertions(+), 22 deletions(-) create mode 100644 internal/gostx/mitm_hook.go create mode 100644 internal/greyproxy/api/transactions.go create mode 100644 internal/greyproxy/api/transactions_test.go create mode 100644 internal/greyproxy/ui/pages_test.go create mode 100644 internal/greyproxy/ui/templates/partials/traffic_table.html create mode 100644 internal/greyproxy/ui/templates/traffic.html diff --git a/cmd/greyproxy/program.go b/cmd/greyproxy/program.go index d53024a..eaae8bc 100644 --- a/cmd/greyproxy/program.go +++ b/cmd/greyproxy/program.go @@ -20,6 +20,7 @@ import ( greyproxy_api "github.com/greyhavenhq/greyproxy/internal/greyproxy/api" greyproxy_plugins "github.com/greyhavenhq/greyproxy/internal/greyproxy/plugins" greyproxy_ui "github.com/greyhavenhq/greyproxy/internal/greyproxy/ui" + "github.com/greyhavenhq/greyproxy/internal/gostx" "github.com/greyhavenhq/greyproxy/internal/gostx/config" "github.com/greyhavenhq/greyproxy/internal/gostx/config/loader" auth_parser "github.com/greyhavenhq/greyproxy/internal/gostx/config/parsing/auth" @@ -327,6 +328,45 @@ func (p *program) buildGreyproxyService() error { // Set the shared DNS cache so the DNS handler wrapper can populate it greyproxy_plugins.SetSharedDNSCache(shared.Cache) + // Wire MITM HTTP round-trip hook to store transactions in the database + gostx.SetGlobalMitmHook(func(info gostx.MitmRoundTripInfo) { + host, portStr, _ := net.SplitHostPort(info.Host) + if host == "" { + host = info.Host + } + port, _ := strconv.Atoi(portStr) + if port == 0 { + port = 443 + } + containerName, _ := greyproxy_plugins.ResolveIdentity(info.ContainerName) + go func() { + txn, err := greyproxy.CreateHttpTransaction(shared.DB, greyproxy.HttpTransactionCreateInput{ + ContainerName: containerName, + DestinationHost: host, + DestinationPort: port, + Method: info.Method, + URL: "https://" + info.Host + info.URI, + RequestHeaders: info.RequestHeaders, + RequestBody: info.RequestBody, + RequestContentType: info.RequestHeaders.Get("Content-Type"), + StatusCode: info.StatusCode, + ResponseHeaders: info.ResponseHeaders, + ResponseBody: info.ResponseBody, + ResponseContentType: info.ResponseHeaders.Get("Content-Type"), + DurationMs: info.DurationMs, + Result: "auto", + }) + if err != nil { + log.Warnf("failed to store HTTP transaction: %v", err) + return + } + shared.Bus.Publish(greyproxy.Event{ + Type: greyproxy.EventTransactionNew, + Data: txn.ToJSON(false), + }) + }() + }) + // Create and register gost plugins autherPlugin := greyproxy_plugins.NewAuther() admissionPlugin := greyproxy_plugins.NewAdmission() diff --git a/internal/gostx/internal/util/sniffing/sniffer.go b/internal/gostx/internal/util/sniffing/sniffer.go index 371a6d7..f7686d9 100644 --- a/internal/gostx/internal/util/sniffing/sniffer.go +++ b/internal/gostx/internal/util/sniffing/sniffer.go @@ -112,8 +112,14 @@ type HTTPRoundTripInfo struct { RequestBody []byte ResponseHeaders http.Header ResponseBody []byte + ContainerName string + DurationMs int64 } +// GlobalHTTPRoundTripHook is called (if set) after each MITM-intercepted HTTP round-trip. +// Set this from program initialization to record transactions to the database. +var GlobalHTTPRoundTripHook func(info HTTPRoundTripInfo) + type Sniffer struct { Websocket bool WebsocketSampleRate float64 @@ -437,7 +443,11 @@ func (h *Sniffer) httpRoundTrip(ctx context.Context, rw, cc io.ReadWriteCloser, return } - if h.OnHTTPRoundTrip != nil { + if h.OnHTTPRoundTrip != nil || GlobalHTTPRoundTripHook != nil { + containerName := string(xctx.ClientIDFromContext(ctx)) + if containerName == "" { + containerName = ro.ClientID + } info := HTTPRoundTripInfo{ Host: req.Host, Method: req.Method, @@ -446,6 +456,8 @@ func (h *Sniffer) httpRoundTrip(ctx context.Context, rw, cc io.ReadWriteCloser, StatusCode: resp.StatusCode, RequestHeaders: ro.HTTP.Request.Header, ResponseHeaders: ro.HTTP.Response.Header, + ContainerName: containerName, + DurationMs: time.Since(ro.Time).Milliseconds(), } if reqBody != nil { info.RequestBody = reqBody.Content() @@ -453,7 +465,12 @@ func (h *Sniffer) httpRoundTrip(ctx context.Context, rw, cc io.ReadWriteCloser, if respBody != nil { info.ResponseBody = respBody.Content() } - h.OnHTTPRoundTrip(info) + if h.OnHTTPRoundTrip != nil { + h.OnHTTPRoundTrip(info) + } + if GlobalHTTPRoundTripHook != nil { + GlobalHTTPRoundTripHook(info) + } } if resp.ContentLength >= 0 { @@ -946,7 +963,11 @@ func (h *h2Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { ro.HTTP.Response.ContentLength = respBody.Length() } - if h.onHTTPRoundTrip != nil { + if h.onHTTPRoundTrip != nil || GlobalHTTPRoundTripHook != nil { + containerName := string(xctx.ClientIDFromContext(r.Context())) + if containerName == "" { + containerName = ro.ClientID + } info := HTTPRoundTripInfo{ Host: r.Host, Method: r.Method, @@ -955,6 +976,8 @@ func (h *h2Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { StatusCode: resp.StatusCode, RequestHeaders: ro.HTTP.Request.Header, ResponseHeaders: ro.HTTP.Response.Header, + ContainerName: containerName, + DurationMs: time.Since(ro.Time).Milliseconds(), } if reqBody != nil { info.RequestBody = reqBody.Content() @@ -962,7 +985,12 @@ func (h *h2Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { if respBody != nil { info.ResponseBody = respBody.Content() } - h.onHTTPRoundTrip(info) + if h.onHTTPRoundTrip != nil { + h.onHTTPRoundTrip(info) + } + if GlobalHTTPRoundTripHook != nil { + GlobalHTTPRoundTripHook(info) + } } } diff --git a/internal/gostx/mitm_hook.go b/internal/gostx/mitm_hook.go new file mode 100644 index 0000000..7fab4c6 --- /dev/null +++ b/internal/gostx/mitm_hook.go @@ -0,0 +1,46 @@ +package gostx + +import ( + "net/http" + + "github.com/greyhavenhq/greyproxy/internal/gostx/internal/util/sniffing" +) + +// MitmRoundTripInfo contains decrypted HTTP request/response data from a MITM round-trip. +// This re-exports the internal sniffing type for use outside the gostx/internal package. +type MitmRoundTripInfo struct { + Host string + Method string + URI string + Proto string + StatusCode int + RequestHeaders http.Header + RequestBody []byte + ResponseHeaders http.Header + ResponseBody []byte + ContainerName string + DurationMs int64 +} + +// SetGlobalMitmHook sets a global callback that fires after every MITM-intercepted HTTP round-trip. +func SetGlobalMitmHook(hook func(info MitmRoundTripInfo)) { + if hook == nil { + sniffing.GlobalHTTPRoundTripHook = nil + return + } + sniffing.GlobalHTTPRoundTripHook = func(info sniffing.HTTPRoundTripInfo) { + hook(MitmRoundTripInfo{ + Host: info.Host, + Method: info.Method, + URI: info.URI, + Proto: info.Proto, + StatusCode: info.StatusCode, + RequestHeaders: info.RequestHeaders, + RequestBody: info.RequestBody, + ResponseHeaders: info.ResponseHeaders, + ResponseBody: info.ResponseBody, + ContainerName: info.ContainerName, + DurationMs: info.DurationMs, + }) + } +} diff --git a/internal/greyproxy/api/router.go b/internal/greyproxy/api/router.go index 64a0ae4..5b25387 100644 --- a/internal/greyproxy/api/router.go +++ b/internal/greyproxy/api/router.go @@ -74,6 +74,9 @@ func NewRouter(s *Shared, pathPrefix string) (*gin.Engine, *gin.RouterGroup) { api.GET("/settings", SettingsGetHandler(s)) api.PUT("/settings", SettingsUpdateHandler(s)) + + api.GET("/transactions", TransactionsListHandler(s)) + api.GET("/transactions/:id", TransactionsDetailHandler(s)) } // WebSocket diff --git a/internal/greyproxy/api/transactions.go b/internal/greyproxy/api/transactions.go new file mode 100644 index 0000000..6ef121a --- /dev/null +++ b/internal/greyproxy/api/transactions.go @@ -0,0 +1,90 @@ +package api + +import ( + "math" + "net/http" + "strconv" + "time" + + "github.com/gin-gonic/gin" + greyproxy "github.com/greyhavenhq/greyproxy/internal/greyproxy" +) + +func TransactionsListHandler(s *Shared) gin.HandlerFunc { + return func(c *gin.Context) { + limit, _ := strconv.Atoi(c.DefaultQuery("limit", "50")) + offset, _ := strconv.Atoi(c.DefaultQuery("offset", "0")) + + f := greyproxy.TransactionFilter{ + Container: c.Query("container"), + Destination: c.Query("destination"), + Method: c.Query("method"), + Limit: limit, + Offset: offset, + } + + if v := c.Query("from_date"); v != "" { + if t, err := time.Parse(time.RFC3339, v); err == nil { + f.FromDate = &t + } else if t, err := time.Parse("2006-01-02T15:04", v); err == nil { + f.FromDate = &t + } + } + if v := c.Query("to_date"); v != "" { + if t, err := time.Parse(time.RFC3339, v); err == nil { + f.ToDate = &t + } else if t, err := time.Parse("2006-01-02T15:04", v); err == nil { + f.ToDate = &t + } + } + + items, total, err := greyproxy.QueryHttpTransactions(s.DB, f) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + + jsonItems := make([]greyproxy.HttpTransactionJSON, len(items)) + for i, item := range items { + jsonItems[i] = item.ToJSON(false) + } + + page := 1 + if limit > 0 && offset > 0 { + page = offset/limit + 1 + } + pages := 1 + if limit > 0 && total > 0 { + pages = int(math.Ceil(float64(total) / float64(limit))) + } + + c.JSON(http.StatusOK, gin.H{ + "items": jsonItems, + "total": total, + "page": page, + "pages": pages, + }) + } +} + +func TransactionsDetailHandler(s *Shared) gin.HandlerFunc { + return func(c *gin.Context) { + id, err := strconv.ParseInt(c.Param("id"), 10, 64) + if err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": "invalid id"}) + return + } + + txn, err := greyproxy.GetHttpTransaction(s.DB, id) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + if txn == nil { + c.JSON(http.StatusNotFound, gin.H{"error": "transaction not found"}) + return + } + + c.JSON(http.StatusOK, txn.ToJSON(true)) + } +} diff --git a/internal/greyproxy/api/transactions_test.go b/internal/greyproxy/api/transactions_test.go new file mode 100644 index 0000000..baedbb9 --- /dev/null +++ b/internal/greyproxy/api/transactions_test.go @@ -0,0 +1,281 @@ +package api + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "os" + "testing" + + "github.com/gin-gonic/gin" + greyproxy "github.com/greyhavenhq/greyproxy/internal/greyproxy" + _ "modernc.org/sqlite" +) + +func setupTestShared(t *testing.T) *Shared { + t.Helper() + + tmpFile, err := os.CreateTemp("", "greyproxy_api_test_*.db") + if err != nil { + t.Fatal(err) + } + tmpFile.Close() + t.Cleanup(func() { os.Remove(tmpFile.Name()) }) + + db, err := greyproxy.OpenDB(tmpFile.Name()) + if err != nil { + t.Fatal(err) + } + t.Cleanup(func() { db.Close() }) + + if err := db.Migrate(); err != nil { + t.Fatal(err) + } + + return &Shared{ + DB: db, + Bus: greyproxy.NewEventBus(), + } +} + +func seedTransactions(t *testing.T, s *Shared) { + t.Helper() + txns := []greyproxy.HttpTransactionCreateInput{ + { + ContainerName: "webapp", + DestinationHost: "api.example.com", + DestinationPort: 443, + Method: "GET", + URL: "https://api.example.com/users", + RequestHeaders: http.Header{"Accept": {"application/json"}}, + StatusCode: 200, + ResponseHeaders: http.Header{"Content-Type": {"application/json"}}, + ResponseBody: []byte(`{"users":[]}`), + ResponseContentType: "application/json", + DurationMs: 42, + Result: "auto", + }, + { + ContainerName: "webapp", + DestinationHost: "api.example.com", + DestinationPort: 443, + Method: "POST", + URL: "https://api.example.com/users", + RequestHeaders: http.Header{"Content-Type": {"application/json"}}, + RequestBody: []byte(`{"name":"alice"}`), + RequestContentType: "application/json", + StatusCode: 201, + ResponseHeaders: http.Header{"Content-Type": {"application/json"}}, + ResponseBody: []byte(`{"id":1,"name":"alice"}`), + ResponseContentType: "application/json", + DurationMs: 85, + Result: "auto", + }, + { + ContainerName: "worker", + DestinationHost: "storage.example.com", + DestinationPort: 443, + Method: "PUT", + URL: "https://storage.example.com/files/report.pdf", + StatusCode: 500, + DurationMs: 300, + Result: "auto", + }, + } + for _, input := range txns { + if _, err := greyproxy.CreateHttpTransaction(s.DB, input); err != nil { + t.Fatal(err) + } + } +} + +func TestTransactionsListAPI(t *testing.T) { + gin.SetMode(gin.TestMode) + s := setupTestShared(t) + seedTransactions(t, s) + + r := gin.New() + r.GET("/api/transactions", TransactionsListHandler(s)) + + t.Run("returns all transactions", func(t *testing.T) { + w := httptest.NewRecorder() + req, _ := http.NewRequest("GET", "/api/transactions", nil) + r.ServeHTTP(w, req) + + if w.Code != 200 { + t.Fatalf("status: got %d, want 200", w.Code) + } + + var resp struct { + Items []greyproxy.HttpTransactionJSON `json:"items"` + Total int `json:"total"` + Page int `json:"page"` + Pages int `json:"pages"` + } + if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { + t.Fatal(err) + } + if resp.Total != 3 { + t.Errorf("total: got %d, want 3", resp.Total) + } + if len(resp.Items) != 3 { + t.Fatalf("items: got %d, want 3", len(resp.Items)) + } + // Most recent first + if resp.Items[0].Method != "PUT" { + t.Errorf("first item method: got %q, want PUT", resp.Items[0].Method) + } + // List view should NOT include bodies + if resp.Items[0].RequestBody != nil { + t.Error("list view should not include request_body") + } + if resp.Items[0].ResponseBody != nil { + t.Error("list view should not include response_body") + } + }) + + t.Run("filter by method", func(t *testing.T) { + w := httptest.NewRecorder() + req, _ := http.NewRequest("GET", "/api/transactions?method=GET", nil) + r.ServeHTTP(w, req) + + var resp struct { + Items []greyproxy.HttpTransactionJSON `json:"items"` + Total int `json:"total"` + } + json.Unmarshal(w.Body.Bytes(), &resp) + if resp.Total != 1 { + t.Errorf("total: got %d, want 1", resp.Total) + } + if len(resp.Items) != 1 || resp.Items[0].Method != "GET" { + t.Error("expected single GET transaction") + } + }) + + t.Run("filter by container", func(t *testing.T) { + w := httptest.NewRecorder() + req, _ := http.NewRequest("GET", "/api/transactions?container=worker", nil) + r.ServeHTTP(w, req) + + var resp struct { + Items []greyproxy.HttpTransactionJSON `json:"items"` + Total int `json:"total"` + } + json.Unmarshal(w.Body.Bytes(), &resp) + if resp.Total != 1 { + t.Errorf("total: got %d, want 1", resp.Total) + } + }) + + t.Run("filter by destination", func(t *testing.T) { + w := httptest.NewRecorder() + req, _ := http.NewRequest("GET", "/api/transactions?destination=storage", nil) + r.ServeHTTP(w, req) + + var resp struct { + Items []greyproxy.HttpTransactionJSON `json:"items"` + Total int `json:"total"` + } + json.Unmarshal(w.Body.Bytes(), &resp) + if resp.Total != 1 { + t.Errorf("total: got %d, want 1", resp.Total) + } + }) + + t.Run("pagination", func(t *testing.T) { + w := httptest.NewRecorder() + req, _ := http.NewRequest("GET", "/api/transactions?limit=2", nil) + r.ServeHTTP(w, req) + + var resp struct { + Items []greyproxy.HttpTransactionJSON `json:"items"` + Total int `json:"total"` + Page int `json:"page"` + Pages int `json:"pages"` + } + json.Unmarshal(w.Body.Bytes(), &resp) + if len(resp.Items) != 2 { + t.Errorf("items: got %d, want 2", len(resp.Items)) + } + if resp.Total != 3 { + t.Errorf("total: got %d, want 3", resp.Total) + } + if resp.Pages != 2 { + t.Errorf("pages: got %d, want 2", resp.Pages) + } + }) + + t.Run("empty result", func(t *testing.T) { + w := httptest.NewRecorder() + req, _ := http.NewRequest("GET", "/api/transactions?method=DELETE", nil) + r.ServeHTTP(w, req) + + var resp struct { + Items []greyproxy.HttpTransactionJSON `json:"items"` + Total int `json:"total"` + } + json.Unmarshal(w.Body.Bytes(), &resp) + if resp.Total != 0 { + t.Errorf("total: got %d, want 0", resp.Total) + } + if len(resp.Items) != 0 { + t.Errorf("items: got %d, want 0", len(resp.Items)) + } + }) +} + +func TestTransactionsDetailAPI(t *testing.T) { + gin.SetMode(gin.TestMode) + s := setupTestShared(t) + seedTransactions(t, s) + + r := gin.New() + r.GET("/api/transactions/:id", TransactionsDetailHandler(s)) + + t.Run("returns full transaction with body", func(t *testing.T) { + w := httptest.NewRecorder() + req, _ := http.NewRequest("GET", "/api/transactions/2", nil) + r.ServeHTTP(w, req) + + if w.Code != 200 { + t.Fatalf("status: got %d, want 200", w.Code) + } + + var txn greyproxy.HttpTransactionJSON + if err := json.Unmarshal(w.Body.Bytes(), &txn); err != nil { + t.Fatal(err) + } + if txn.Method != "POST" { + t.Errorf("method: got %q, want POST", txn.Method) + } + if txn.RequestBody == nil || *txn.RequestBody != `{"name":"alice"}` { + t.Errorf("request_body missing or wrong: %v", txn.RequestBody) + } + if txn.ResponseBody == nil || *txn.ResponseBody != `{"id":1,"name":"alice"}` { + t.Errorf("response_body missing or wrong: %v", txn.ResponseBody) + } + if txn.StatusCode == nil || *txn.StatusCode != 201 { + t.Errorf("status_code: got %v, want 201", txn.StatusCode) + } + }) + + t.Run("not found", func(t *testing.T) { + w := httptest.NewRecorder() + req, _ := http.NewRequest("GET", "/api/transactions/999", nil) + r.ServeHTTP(w, req) + + if w.Code != 404 { + t.Errorf("status: got %d, want 404", w.Code) + } + }) + + t.Run("invalid id", func(t *testing.T) { + w := httptest.NewRecorder() + req, _ := http.NewRequest("GET", "/api/transactions/abc", nil) + r.ServeHTTP(w, req) + + if w.Code != 400 { + t.Errorf("status: got %d, want 400", w.Code) + } + }) +} diff --git a/internal/greyproxy/crud.go b/internal/greyproxy/crud.go index 816438c..7d08836 100644 --- a/internal/greyproxy/crud.go +++ b/internal/greyproxy/crud.go @@ -2,6 +2,7 @@ package greyproxy import ( "database/sql" + "encoding/json" "fmt" "sort" "strings" @@ -1054,3 +1055,165 @@ func GetDashboardStats(db *DB, fromDate, toDate time.Time, groupBy string, recen return stats, nil } + +// --- HTTP Transactions --- + +// MaxBodyCapture is the default max bytes to store per request/response body. +const MaxBodyCapture = 1048576 // 1MB + +func CreateHttpTransaction(db *DB, input HttpTransactionCreateInput) (*HttpTransaction, error) { + db.Lock() + defer db.Unlock() + + if input.Result == "" { + input.Result = "auto" + } + + var reqHeadersJSON sql.NullString + if input.RequestHeaders != nil { + b, _ := json.Marshal(input.RequestHeaders) + reqHeadersJSON = sql.NullString{String: string(b), Valid: true} + } + + var respHeadersJSON sql.NullString + if input.ResponseHeaders != nil { + b, _ := json.Marshal(input.ResponseHeaders) + respHeadersJSON = sql.NullString{String: string(b), Valid: true} + } + + reqBody := input.RequestBody + reqBodySize := int64(len(reqBody)) + if len(reqBody) > MaxBodyCapture { + reqBody = reqBody[:MaxBodyCapture] + } + + respBody := input.ResponseBody + respBodySize := int64(len(respBody)) + if len(respBody) > MaxBodyCapture { + respBody = respBody[:MaxBodyCapture] + } + + result, err := db.WriteDB().Exec( + `INSERT INTO http_transactions (container_name, destination_host, destination_port, + method, url, request_headers, request_body, request_body_size, request_content_type, + status_code, response_headers, response_body, response_body_size, response_content_type, + duration_ms, rule_id, result) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`, + input.ContainerName, input.DestinationHost, input.DestinationPort, + input.Method, input.URL, + reqHeadersJSON, reqBody, reqBodySize, + sql.NullString{String: input.RequestContentType, Valid: input.RequestContentType != ""}, + sql.NullInt64{Int64: int64(input.StatusCode), Valid: input.StatusCode != 0}, + respHeadersJSON, respBody, respBodySize, + sql.NullString{String: input.ResponseContentType, Valid: input.ResponseContentType != ""}, + sql.NullInt64{Int64: input.DurationMs, Valid: input.DurationMs > 0}, + sql.NullInt64{Int64: ptrInt64OrZero(input.RuleID), Valid: input.RuleID != nil}, + input.Result, + ) + if err != nil { + return nil, fmt.Errorf("insert http_transaction: %w", err) + } + + id, _ := result.LastInsertId() + return getHttpTransactionByID(db.WriteDB(), id) +} + +func getHttpTransactionByID(conn *sql.DB, id int64) (*HttpTransaction, error) { + var t HttpTransaction + err := conn.QueryRow( + `SELECT id, timestamp, container_name, destination_host, destination_port, + method, url, request_headers, request_body, request_body_size, request_content_type, + status_code, response_headers, response_body, response_body_size, response_content_type, + duration_ms, rule_id, result + FROM http_transactions WHERE id = ?`, id, + ).Scan(&t.ID, &t.Timestamp, &t.ContainerName, &t.DestinationHost, &t.DestinationPort, + &t.Method, &t.URL, &t.RequestHeaders, &t.RequestBody, &t.RequestBodySize, &t.RequestContentType, + &t.StatusCode, &t.ResponseHeaders, &t.ResponseBody, &t.ResponseBodySize, &t.ResponseContentType, + &t.DurationMs, &t.RuleID, &t.Result) + if err != nil { + if err == sql.ErrNoRows { + return nil, nil + } + return nil, err + } + return &t, nil +} + +func GetHttpTransaction(db *DB, id int64) (*HttpTransaction, error) { + return getHttpTransactionByID(db.ReadDB(), id) +} + +type TransactionFilter struct { + Container string + Destination string + Method string + FromDate *time.Time + ToDate *time.Time + Limit int + Offset int +} + +func QueryHttpTransactions(db *DB, f TransactionFilter) ([]HttpTransaction, int, error) { + if f.Limit <= 0 { + f.Limit = 50 + } + + where := []string{"1=1"} + args := []any{} + + if f.Container != "" { + where = append(where, "container_name LIKE ?") + args = append(args, "%"+f.Container+"%") + } + if f.Destination != "" { + where = append(where, "destination_host LIKE ?") + args = append(args, "%"+f.Destination+"%") + } + if f.Method != "" { + where = append(where, "method = ?") + args = append(args, f.Method) + } + if f.FromDate != nil { + where = append(where, "timestamp >= ?") + args = append(args, f.FromDate.UTC().Format("2006-01-02 15:04:05")) + } + if f.ToDate != nil { + where = append(where, "timestamp <= ?") + args = append(args, f.ToDate.UTC().Format("2006-01-02 15:04:05")) + } + + whereClause := strings.Join(where, " AND ") + + var total int + err := db.ReadDB().QueryRow("SELECT COUNT(*) FROM http_transactions WHERE "+whereClause, args...).Scan(&total) + if err != nil { + return nil, 0, err + } + + // List query excludes body blobs for performance + rows, err := db.ReadDB().Query( + `SELECT id, timestamp, container_name, destination_host, destination_port, + method, url, request_headers, NULL, request_body_size, request_content_type, + status_code, response_headers, NULL, response_body_size, response_content_type, + duration_ms, rule_id, result + FROM http_transactions WHERE `+whereClause+` ORDER BY timestamp DESC LIMIT ? OFFSET ?`, + append(args, f.Limit, f.Offset)..., + ) + if err != nil { + return nil, 0, err + } + defer rows.Close() + + var txns []HttpTransaction + for rows.Next() { + var t HttpTransaction + if err := rows.Scan(&t.ID, &t.Timestamp, &t.ContainerName, &t.DestinationHost, &t.DestinationPort, + &t.Method, &t.URL, &t.RequestHeaders, &t.RequestBody, &t.RequestBodySize, &t.RequestContentType, + &t.StatusCode, &t.ResponseHeaders, &t.ResponseBody, &t.ResponseBodySize, &t.ResponseContentType, + &t.DurationMs, &t.RuleID, &t.Result); err != nil { + return nil, 0, err + } + txns = append(txns, t) + } + return txns, total, nil +} diff --git a/internal/greyproxy/crud_test.go b/internal/greyproxy/crud_test.go index 71a03ce..b09d21b 100644 --- a/internal/greyproxy/crud_test.go +++ b/internal/greyproxy/crud_test.go @@ -2,6 +2,7 @@ package greyproxy import ( "database/sql" + "net/http" "os" "testing" "time" @@ -617,7 +618,7 @@ func TestMigrations(t *testing.T) { db := setupTestDB(t) // Verify tables exist - tables := []string{"rules", "pending_requests", "request_logs", "schema_migrations"} + tables := []string{"rules", "pending_requests", "request_logs", "http_transactions", "schema_migrations"} for _, table := range tables { var name string err := db.ReadDB().QueryRow( @@ -636,8 +637,8 @@ func TestMigrations(t *testing.T) { // Verify migration versions were recorded var count int db.ReadDB().QueryRow("SELECT COUNT(*) FROM schema_migrations").Scan(&count) - if count != 3 { - t.Errorf("expected 3 migration versions, got %d", count) + if count != 4 { + t.Errorf("expected 4 migration versions, got %d", count) } } @@ -742,3 +743,231 @@ func TestRuleToJSON(t *testing.T) { t.Error("expected IsActive to be true for rule with future expiration") } } + +func TestCreateHttpTransaction(t *testing.T) { + db := setupTestDB(t) + + txn, err := CreateHttpTransaction(db, HttpTransactionCreateInput{ + ContainerName: "claude-code", + DestinationHost: "api.anthropic.com", + DestinationPort: 443, + Method: "POST", + URL: "https://api.anthropic.com/v1/messages", + RequestHeaders: http.Header{"Content-Type": {"application/json"}, "Authorization": {"Bearer sk-ant-xxx"}}, + RequestBody: []byte(`{"model":"claude-sonnet-4-20250514","messages":[{"role":"user","content":"hello"}]}`), + RequestContentType: "application/json", + StatusCode: 200, + ResponseHeaders: http.Header{"Content-Type": {"application/json"}}, + ResponseBody: []byte(`{"content":[{"text":"Hello!"}]}`), + ResponseContentType: "application/json", + DurationMs: 150, + Result: "auto", + }) + if err != nil { + t.Fatalf("CreateHttpTransaction: %v", err) + } + if txn == nil { + t.Fatal("expected non-nil transaction") + } + if txn.ContainerName != "claude-code" { + t.Errorf("container_name = %q, want %q", txn.ContainerName, "claude-code") + } + if txn.Method != "POST" { + t.Errorf("method = %q, want %q", txn.Method, "POST") + } + if txn.DestinationHost != "api.anthropic.com" { + t.Errorf("destination_host = %q, want %q", txn.DestinationHost, "api.anthropic.com") + } + if !txn.StatusCode.Valid || txn.StatusCode.Int64 != 200 { + t.Errorf("status_code = %v, want 200", txn.StatusCode) + } + if !txn.DurationMs.Valid || txn.DurationMs.Int64 != 150 { + t.Errorf("duration_ms = %v, want 150", txn.DurationMs) + } +} + +func TestGetHttpTransaction(t *testing.T) { + db := setupTestDB(t) + + created, _ := CreateHttpTransaction(db, HttpTransactionCreateInput{ + ContainerName: "test-app", + DestinationHost: "example.com", + DestinationPort: 443, + Method: "GET", + URL: "https://example.com/api/data", + RequestBody: nil, + StatusCode: 200, + ResponseBody: []byte("response body content"), + DurationMs: 50, + Result: "auto", + }) + + got, err := GetHttpTransaction(db, created.ID) + if err != nil { + t.Fatalf("GetHttpTransaction: %v", err) + } + if got == nil { + t.Fatal("expected non-nil transaction") + } + if got.Method != "GET" { + t.Errorf("method = %q, want %q", got.Method, "GET") + } + if string(got.ResponseBody) != "response body content" { + t.Errorf("response_body = %q, want %q", string(got.ResponseBody), "response body content") + } + + // Not found + missing, err := GetHttpTransaction(db, 9999) + if err != nil { + t.Fatalf("GetHttpTransaction for missing: %v", err) + } + if missing != nil { + t.Error("expected nil for non-existent transaction") + } +} + +func TestQueryHttpTransactions(t *testing.T) { + db := setupTestDB(t) + + // Create several transactions + for _, m := range []string{"GET", "POST", "DELETE"} { + CreateHttpTransaction(db, HttpTransactionCreateInput{ + ContainerName: "app1", + DestinationHost: "api.example.com", + DestinationPort: 443, + Method: m, + URL: "https://api.example.com/test", + StatusCode: 200, + Result: "auto", + }) + } + CreateHttpTransaction(db, HttpTransactionCreateInput{ + ContainerName: "app2", + DestinationHost: "other.example.com", + DestinationPort: 443, + Method: "GET", + URL: "https://other.example.com/", + StatusCode: 200, + Result: "auto", + }) + + // List all + txns, total, err := QueryHttpTransactions(db, TransactionFilter{}) + if err != nil { + t.Fatalf("QueryHttpTransactions: %v", err) + } + if total != 4 { + t.Errorf("total = %d, want 4", total) + } + if len(txns) != 4 { + t.Errorf("len(txns) = %d, want 4", len(txns)) + } + + // Filter by method + txns, total, _ = QueryHttpTransactions(db, TransactionFilter{Method: "POST"}) + if total != 1 { + t.Errorf("total with method=POST = %d, want 1", total) + } + + // Filter by destination + txns, total, _ = QueryHttpTransactions(db, TransactionFilter{Destination: "other"}) + if total != 1 { + t.Errorf("total with destination=other = %d, want 1", total) + } + + // Filter by container + txns, total, _ = QueryHttpTransactions(db, TransactionFilter{Container: "app2"}) + if total != 1 { + t.Errorf("total with container=app2 = %d, want 1", total) + } + if txns[0].ContainerName != "app2" { + t.Errorf("container_name = %q, want %q", txns[0].ContainerName, "app2") + } + + // List query should NOT include body blobs + txns, _, _ = QueryHttpTransactions(db, TransactionFilter{}) + for _, tx := range txns { + if tx.RequestBody != nil { + t.Error("list query should not include request_body") + } + if tx.ResponseBody != nil { + t.Error("list query should not include response_body") + } + } +} + +func TestHttpTransactionBodyTruncation(t *testing.T) { + db := setupTestDB(t) + + // Create a transaction with body larger than MaxBodyCapture + largeBody := make([]byte, MaxBodyCapture+1000) + for i := range largeBody { + largeBody[i] = 'A' + } + + txn, err := CreateHttpTransaction(db, HttpTransactionCreateInput{ + ContainerName: "test", + DestinationHost: "example.com", + DestinationPort: 443, + Method: "POST", + URL: "https://example.com/upload", + RequestBody: largeBody, + StatusCode: 200, + Result: "auto", + }) + if err != nil { + t.Fatalf("CreateHttpTransaction: %v", err) + } + + got, _ := GetHttpTransaction(db, txn.ID) + + // Stored body should be truncated to MaxBodyCapture + if len(got.RequestBody) != MaxBodyCapture { + t.Errorf("stored body length = %d, want %d", len(got.RequestBody), MaxBodyCapture) + } + + // But request_body_size should reflect the original size + if !got.RequestBodySize.Valid || got.RequestBodySize.Int64 != int64(MaxBodyCapture+1000) { + t.Errorf("request_body_size = %v, want %d", got.RequestBodySize, MaxBodyCapture+1000) + } +} + +func TestHttpTransactionToJSON(t *testing.T) { + txn := HttpTransaction{ + ID: 1, + Timestamp: time.Date(2024, 6, 15, 10, 30, 0, 0, time.UTC), + ContainerName: "claude-code", + DestinationHost: "api.anthropic.com", + DestinationPort: 443, + Method: "POST", + URL: "https://api.anthropic.com/v1/messages", + RequestHeaders: sql.NullString{String: `{"Content-Type":["application/json"]}`, Valid: true}, + RequestBody: []byte(`{"model":"claude-sonnet-4-20250514"}`), + RequestBodySize: sql.NullInt64{Int64: 30, Valid: true}, + StatusCode: sql.NullInt64{Int64: 200, Valid: true}, + ResponseBody: []byte(`{"content":"hello"}`), + DurationMs: sql.NullInt64{Int64: 150, Valid: true}, + Result: "auto", + } + + // Without body + j := txn.ToJSON(false) + if j.RequestBody != nil { + t.Error("ToJSON(false) should not include request_body") + } + if j.ResponseBody != nil { + t.Error("ToJSON(false) should not include response_body") + } + if j.ContainerName != "claude-code" { + t.Errorf("container_name = %q, want %q", j.ContainerName, "claude-code") + } + + // With body + j = txn.ToJSON(true) + if j.RequestBody == nil || *j.RequestBody != `{"model":"claude-sonnet-4-20250514"}` { + t.Errorf("ToJSON(true) request_body = %v, want the body content", j.RequestBody) + } + if j.ResponseBody == nil || *j.ResponseBody != `{"content":"hello"}` { + t.Errorf("ToJSON(true) response_body = %v, want the body content", j.ResponseBody) + } +} diff --git a/internal/greyproxy/events.go b/internal/greyproxy/events.go index 16568bb..876959e 100644 --- a/internal/greyproxy/events.go +++ b/internal/greyproxy/events.go @@ -12,6 +12,7 @@ const ( EventPendingAllowed = "pending_request.allowed" EventPendingDismissed = "pending_request.dismissed" EventWaitersChanged = "waiters.changed" + EventTransactionNew = "transaction.new" ) // Event represents a broadcast event. diff --git a/internal/greyproxy/migrations.go b/internal/greyproxy/migrations.go index ec0a8dc..8a66b38 100644 --- a/internal/greyproxy/migrations.go +++ b/internal/greyproxy/migrations.go @@ -59,6 +59,34 @@ var migrations = []string{ CREATE INDEX IF NOT EXISTS idx_logs_container ON request_logs(container_name); CREATE INDEX IF NOT EXISTS idx_logs_destination ON request_logs(destination_host); CREATE INDEX IF NOT EXISTS idx_logs_result ON request_logs(result);`, + + // Migration 4: Create http_transactions table for MITM-captured HTTP request/response data + `CREATE TABLE IF NOT EXISTS http_transactions ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + timestamp DATETIME NOT NULL DEFAULT (datetime('now')), + container_name TEXT NOT NULL, + destination_host TEXT NOT NULL, + destination_port INTEGER NOT NULL, + + method TEXT NOT NULL, + url TEXT NOT NULL, + request_headers TEXT, + request_body BLOB, + request_body_size INTEGER, + request_content_type TEXT, + + status_code INTEGER, + response_headers TEXT, + response_body BLOB, + response_body_size INTEGER, + response_content_type TEXT, + + duration_ms INTEGER, + rule_id INTEGER, + result TEXT NOT NULL DEFAULT 'auto' + ); + CREATE INDEX IF NOT EXISTS idx_http_transactions_ts ON http_transactions(timestamp); + CREATE INDEX IF NOT EXISTS idx_http_transactions_dest ON http_transactions(destination_host, destination_port);`, } func runMigrations(db *sql.DB) error { diff --git a/internal/greyproxy/models.go b/internal/greyproxy/models.go index f2d27ce..07e86bf 100644 --- a/internal/greyproxy/models.go +++ b/internal/greyproxy/models.go @@ -2,6 +2,8 @@ package greyproxy import ( "database/sql" + "encoding/json" + "net/http" "time" ) @@ -180,6 +182,127 @@ func (l *RequestLog) DisplayHost() string { return l.DestinationHost } +// HttpTransaction represents a MITM-captured HTTP request/response pair. +type HttpTransaction struct { + ID int64 `json:"id"` + Timestamp time.Time `json:"timestamp"` + ContainerName string `json:"container_name"` + DestinationHost string `json:"destination_host"` + DestinationPort int `json:"destination_port"` + Method string `json:"method"` + URL string `json:"url"` + RequestHeaders sql.NullString `json:"-"` + RequestBody []byte `json:"-"` + RequestBodySize sql.NullInt64 `json:"-"` + RequestContentType sql.NullString `json:"-"` + StatusCode sql.NullInt64 `json:"status_code"` + ResponseHeaders sql.NullString `json:"-"` + ResponseBody []byte `json:"-"` + ResponseBodySize sql.NullInt64 `json:"-"` + ResponseContentType sql.NullString `json:"-"` + DurationMs sql.NullInt64 `json:"duration_ms"` + RuleID sql.NullInt64 `json:"rule_id"` + Result string `json:"result"` +} + +type HttpTransactionJSON struct { + ID int64 `json:"id"` + Timestamp string `json:"timestamp"` + ContainerName string `json:"container_name"` + DestinationHost string `json:"destination_host"` + DestinationPort int `json:"destination_port"` + Method string `json:"method"` + URL string `json:"url"` + RequestHeaders any `json:"request_headers,omitempty"` + RequestBody *string `json:"request_body,omitempty"` + RequestBodySize *int64 `json:"request_body_size,omitempty"` + RequestContentType *string `json:"request_content_type,omitempty"` + StatusCode *int64 `json:"status_code,omitempty"` + ResponseHeaders any `json:"response_headers,omitempty"` + ResponseBody *string `json:"response_body,omitempty"` + ResponseBodySize *int64 `json:"response_body_size,omitempty"` + ResponseContentType *string `json:"response_content_type,omitempty"` + DurationMs *int64 `json:"duration_ms,omitempty"` + RuleID *int64 `json:"rule_id,omitempty"` + Result string `json:"result"` +} + +func (t *HttpTransaction) ToJSON(includeBody bool) HttpTransactionJSON { + j := HttpTransactionJSON{ + ID: t.ID, + Timestamp: t.Timestamp.UTC().Format(time.RFC3339), + ContainerName: t.ContainerName, + DestinationHost: t.DestinationHost, + DestinationPort: t.DestinationPort, + Method: t.Method, + URL: t.URL, + Result: t.Result, + } + if t.RequestHeaders.Valid { + var h map[string]any + if json.Unmarshal([]byte(t.RequestHeaders.String), &h) == nil { + j.RequestHeaders = h + } + } + if t.RequestBodySize.Valid { + j.RequestBodySize = &t.RequestBodySize.Int64 + } + if t.RequestContentType.Valid { + j.RequestContentType = &t.RequestContentType.String + } + if t.StatusCode.Valid { + j.StatusCode = &t.StatusCode.Int64 + } + if t.ResponseHeaders.Valid { + var h map[string]any + if json.Unmarshal([]byte(t.ResponseHeaders.String), &h) == nil { + j.ResponseHeaders = h + } + } + if t.ResponseBodySize.Valid { + j.ResponseBodySize = &t.ResponseBodySize.Int64 + } + if t.ResponseContentType.Valid { + j.ResponseContentType = &t.ResponseContentType.String + } + if t.DurationMs.Valid { + j.DurationMs = &t.DurationMs.Int64 + } + if t.RuleID.Valid { + j.RuleID = &t.RuleID.Int64 + } + if includeBody { + if len(t.RequestBody) > 0 { + s := string(t.RequestBody) + j.RequestBody = &s + } + if len(t.ResponseBody) > 0 { + s := string(t.ResponseBody) + j.ResponseBody = &s + } + } + return j +} + +// HttpTransactionCreateInput holds the data needed to create a transaction record. +type HttpTransactionCreateInput struct { + ContainerName string + DestinationHost string + DestinationPort int + Method string + URL string + RequestHeaders http.Header + RequestBody []byte + RequestContentType string + StatusCode int + ResponseHeaders http.Header + ResponseBody []byte + ResponseContentType string + DurationMs int64 + RuleID *int64 + Result string +} + // DashboardStats holds aggregated data for the dashboard. type DashboardStats struct { Period Period `json:"period"` diff --git a/internal/greyproxy/plugins/auther.go b/internal/greyproxy/plugins/auther.go index 64d9e17..011103a 100644 --- a/internal/greyproxy/plugins/auther.go +++ b/internal/greyproxy/plugins/auther.go @@ -8,6 +8,7 @@ import ( "github.com/greyhavenhq/greyproxy/internal/gostcore/auth" "github.com/greyhavenhq/greyproxy/internal/gostcore/logger" + xctx "github.com/greyhavenhq/greyproxy/internal/gostx/ctx" ) // Auther implements auth.Authenticator. @@ -46,9 +47,8 @@ func (a *Auther) Authenticate(ctx context.Context, user, password string, opts . } func extractClientIP(ctx context.Context) string { - // Try to get source address from context - // The SOCKS5 handler sets this in the context - if addr, ok := ctx.Value(srcAddrKey{}).(net.Addr); ok && addr != nil { + // Get source address from context using the canonical key from gostx/ctx + if addr := xctx.SrcAddrFromContext(ctx); addr != nil { host, _, err := net.SplitHostPort(addr.String()) if err == nil { return host @@ -58,9 +58,6 @@ func extractClientIP(ctx context.Context) string { return "unknown" } -// srcAddrKey matches the key used in github.com/go-gost/x/ctx -type srcAddrKey = struct{} - // ParseClientID splits a composite client ID "ip|username" into its components. func ParseClientID(clientID string) (ip, username string) { parts := strings.SplitN(clientID, "|", 2) diff --git a/internal/greyproxy/plugins/bypass.go b/internal/greyproxy/plugins/bypass.go index 7a9a574..e3e528c 100644 --- a/internal/greyproxy/plugins/bypass.go +++ b/internal/greyproxy/plugins/bypass.go @@ -65,7 +65,7 @@ func (b *Bypass) Contains(ctx context.Context, network, addr string, opts ...byp // Get client identity from context (set by auther) clientID := string(ctxvalue.ClientIDFromContext(ctx)) - containerName, containerID := resolveIdentity(clientID) + containerName, containerID := ResolveIdentity(clientID) // Resolve hostname resolvedHostname := b.resolveHostname(host) @@ -191,7 +191,8 @@ func (b *Bypass) resolveHostname(host string) string { return b.cache.ResolveIP(host) } -func resolveIdentity(clientID string) (containerName, containerID string) { +// ResolveIdentity maps a composite client ID ("ip|username") to a container name and ID. +func ResolveIdentity(clientID string) (containerName, containerID string) { ip, username := ParseClientID(clientID) if username != "" && username != "proxy" { diff --git a/internal/greyproxy/plugins/plugins_test.go b/internal/greyproxy/plugins/plugins_test.go index 1bbc2c8..5d28dfe 100644 --- a/internal/greyproxy/plugins/plugins_test.go +++ b/internal/greyproxy/plugins/plugins_test.go @@ -45,7 +45,7 @@ func TestResolveIdentity(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - container, _ := resolveIdentity(tt.clientID) + container, _ := ResolveIdentity(tt.clientID) if container != tt.wantContainer { t.Errorf("got %q, want %q", container, tt.wantContainer) } diff --git a/internal/greyproxy/ui/pages.go b/internal/greyproxy/ui/pages.go index 9fb67b9..4523cf2 100644 --- a/internal/greyproxy/ui/pages.go +++ b/internal/greyproxy/ui/pages.go @@ -66,12 +66,6 @@ var funcMap = template.FuncMap{ "sub": func(a, b int) int { return a - b }, - "gt": func(a, b int) bool { - return a > b - }, - "lt": func(a, b int) bool { - return a < b - }, "formatFloat": func(f float64) string { return fmt.Sprintf("%.1f", f) }, @@ -176,10 +170,13 @@ var ( logsTmpl = parseTemplate("base.html", "base.html", "logs.html") settingsTmpl = parseTemplate("base.html", "base.html", "settings.html") + trafficTmpl = parseTemplate("base.html", "base.html", "traffic.html") + dashboardStatsTmpl = parseTemplate("dashboard_stats.html", "partials/dashboard_stats.html") pendingListTmpl = parseTemplate("pending_list.html", "partials/pending_list.html") rulesListTmpl = parseTemplate("rules_list.html", "partials/rules_list.html") logsTableTmpl = parseTemplate("logs_table.html", "partials/logs_table.html") + trafficTableTmpl = parseTemplate("traffic_table.html", "partials/traffic_table.html") ) // cacheBuster is set once at startup for static asset cache busting. @@ -198,6 +195,7 @@ func getContainers(db *greyproxy.DB) []string { rows, err := db.ReadDB().Query( `SELECT DISTINCT container_name FROM pending_requests UNION SELECT DISTINCT container_name FROM request_logs + UNION SELECT DISTINCT container_name FROM http_transactions ORDER BY container_name`) if err != nil { return nil @@ -268,6 +266,17 @@ func RegisterPageRoutes(r *gin.RouterGroup, db *greyproxy.DB, bus *greyproxy.Eve Prefix: prefix, CacheBuster: cacheBuster, Title: "Settings - Greyproxy", + Containers: getContainers(db), + }) + }) + + r.GET("/traffic", func(c *gin.Context) { + trafficTmpl.Execute(c.Writer, PageData{ + CurrentPath: c.Request.URL.Path, + Prefix: prefix, + CacheBuster: cacheBuster, + Title: "HTTP Traffic - Greyproxy", + Containers: getContainers(db), }) }) } @@ -577,6 +586,54 @@ func RegisterHTMXRoutes(r *gin.RouterGroup, db *greyproxy.DB, bus *greyproxy.Eve "HasFilters": hasFilters, }) }) + + htmx.GET("/traffic-table", func(c *gin.Context) { + limit, _ := strconv.Atoi(c.DefaultQuery("limit", "50")) + offset, _ := strconv.Atoi(c.DefaultQuery("offset", "0")) + + if page, err := strconv.Atoi(c.Query("page")); err == nil && page > 1 { + offset = (page - 1) * limit + } + + container := c.Query("container") + destination := c.Query("destination") + method := c.Query("method") + + f := greyproxy.TransactionFilter{ + Container: container, + Destination: destination, + Method: method, + Limit: limit, + Offset: offset, + } + + items, total, err := greyproxy.QueryHttpTransactions(db, f) + if err != nil { + c.String(http.StatusInternalServerError, "Error: %v", err) + return + } + + page := 1 + if limit > 0 && offset > 0 { + page = offset/limit + 1 + } + pages := 1 + if limit > 0 && total > 0 { + pages = int(math.Ceil(float64(total) / float64(limit))) + } + + hasFilters := container != "" || destination != "" || method != "" + + c.Writer.Header().Set("Content-Type", "text/html; charset=utf-8") + trafficTableTmpl.Execute(c.Writer, gin.H{ + "Prefix": prefix, + "Items": items, + "Total": total, + "Page": page, + "Pages": pages, + "HasFilters": hasFilters, + }) + }) } func enrichWaitingCounts(items []greyproxy.PendingRequest, waiters *greyproxy.WaiterTracker) { diff --git a/internal/greyproxy/ui/pages_test.go b/internal/greyproxy/ui/pages_test.go new file mode 100644 index 0000000..b281742 --- /dev/null +++ b/internal/greyproxy/ui/pages_test.go @@ -0,0 +1,322 @@ +package ui + +import ( + "net/http" + "net/http/httptest" + "os" + "strings" + "testing" + + "github.com/gin-gonic/gin" + greyproxy "github.com/greyhavenhq/greyproxy/internal/greyproxy" + _ "modernc.org/sqlite" +) + +func setupTestDB(t *testing.T) *greyproxy.DB { + t.Helper() + + tmpFile, err := os.CreateTemp("", "greyproxy_ui_test_*.db") + if err != nil { + t.Fatal(err) + } + tmpFile.Close() + t.Cleanup(func() { os.Remove(tmpFile.Name()) }) + + db, err := greyproxy.OpenDB(tmpFile.Name()) + if err != nil { + t.Fatal(err) + } + t.Cleanup(func() { db.Close() }) + + if err := db.Migrate(); err != nil { + t.Fatal(err) + } + return db +} + +func seedTransactions(t *testing.T, db *greyproxy.DB) { + t.Helper() + txns := []greyproxy.HttpTransactionCreateInput{ + { + ContainerName: "webapp", + DestinationHost: "api.example.com", + DestinationPort: 443, + Method: "GET", + URL: "https://api.example.com/users", + RequestHeaders: http.Header{"Accept": {"application/json"}}, + StatusCode: 200, + ResponseContentType: "application/json", + DurationMs: 42, + Result: "auto", + }, + { + ContainerName: "webapp", + DestinationHost: "api.example.com", + DestinationPort: 443, + Method: "POST", + URL: "https://api.example.com/users", + RequestHeaders: http.Header{"Content-Type": {"application/json"}}, + RequestBody: []byte(`{"name":"alice"}`), + RequestContentType: "application/json", + StatusCode: 201, + ResponseContentType: "application/json", + DurationMs: 85, + Result: "auto", + }, + { + ContainerName: "worker", + DestinationHost: "storage.example.com", + DestinationPort: 443, + Method: "PUT", + URL: "https://storage.example.com/files/report.pdf", + StatusCode: 500, + DurationMs: 300, + Result: "auto", + }, + } + for _, input := range txns { + if _, err := greyproxy.CreateHttpTransaction(db, input); err != nil { + t.Fatal(err) + } + } +} + +func setupRouter(t *testing.T, db *greyproxy.DB) *gin.Engine { + t.Helper() + gin.SetMode(gin.TestMode) + r := gin.New() + g := r.Group("") + bus := greyproxy.NewEventBus() + RegisterPageRoutes(g, db, bus) + RegisterHTMXRoutes(g, db, bus, nil, nil) + return r +} + +func TestTrafficPageRoute(t *testing.T) { + db := setupTestDB(t) + r := setupRouter(t, db) + + w := httptest.NewRecorder() + req, _ := http.NewRequest("GET", "/traffic", nil) + r.ServeHTTP(w, req) + + if w.Code != 200 { + t.Fatalf("status: got %d, want 200", w.Code) + } + + body := w.Body.String() + // Page should contain the traffic page structure + if !strings.Contains(body, "HTTP Traffic") { + t.Error("page missing title 'HTTP Traffic'") + } + if !strings.Contains(body, "traffic-table") { + t.Error("page missing traffic-table container") + } + if !strings.Contains(body, "traffic-filter-form") { + t.Error("page missing traffic filter form") + } + // Navigation should have active Traffic link + if !strings.Contains(body, `href="/traffic"`) { + t.Error("page missing traffic nav link") + } +} + +func TestTrafficTableHTMXRoute(t *testing.T) { + db := setupTestDB(t) + seedTransactions(t, db) + r := setupRouter(t, db) + + t.Run("renders all transactions", func(t *testing.T) { + w := httptest.NewRecorder() + req, _ := http.NewRequest("GET", "/htmx/traffic-table", nil) + r.ServeHTTP(w, req) + + if w.Code != 200 { + t.Fatalf("status: got %d, want 200", w.Code) + } + + body := w.Body.String() + // Should render all 3 rows (each has toggleTxnDetails) + count := strings.Count(body, "toggleTxnDetails") + if count != 3 { + t.Errorf("rendered rows: got %d, want 3", count) + } + // Should show all methods + if !strings.Contains(body, ">GET") { + t.Error("missing GET method badge") + } + if !strings.Contains(body, ">POST") { + t.Error("missing POST method badge") + } + if !strings.Contains(body, ">PUT") { + t.Error("missing PUT method badge") + } + // Should show status codes + if !strings.Contains(body, ">200") { + t.Error("missing 200 status code") + } + if !strings.Contains(body, ">201") { + t.Error("missing 201 status code") + } + if !strings.Contains(body, ">500") { + t.Error("missing 500 status code") + } + // Should show container names + if !strings.Contains(body, "webapp") { + t.Error("missing container name 'webapp'") + } + if !strings.Contains(body, "worker") { + t.Error("missing container name 'worker'") + } + // Should show URLs + if !strings.Contains(body, "api.example.com/users") { + t.Error("missing URL") + } + // Should show transaction count + if !strings.Contains(body, "Showing 3 of 3 transactions") { + t.Error("missing or wrong transaction count text") + } + // Status code colors: 200 should be green, 500 should be red + if !strings.Contains(body, "text-green-600\">200") { + t.Error("200 status should have green color class") + } + if !strings.Contains(body, "text-red-600\">500") { + t.Error("500 status should have red color class") + } + }) + + t.Run("filter by method", func(t *testing.T) { + w := httptest.NewRecorder() + req, _ := http.NewRequest("GET", "/htmx/traffic-table?method=POST", nil) + r.ServeHTTP(w, req) + + body := w.Body.String() + count := strings.Count(body, "toggleTxnDetails") + if count != 1 { + t.Errorf("rendered rows: got %d, want 1", count) + } + if !strings.Contains(body, ">POST") { + t.Error("missing POST method") + } + if !strings.Contains(body, "Showing 1 of 1 transactions") { + t.Error("wrong count text") + } + }) + + t.Run("filter by container", func(t *testing.T) { + w := httptest.NewRecorder() + req, _ := http.NewRequest("GET", "/htmx/traffic-table?container=worker", nil) + r.ServeHTTP(w, req) + + body := w.Body.String() + count := strings.Count(body, "toggleTxnDetails") + if count != 1 { + t.Errorf("rendered rows: got %d, want 1", count) + } + if !strings.Contains(body, "worker") { + t.Error("missing container name 'worker'") + } + }) + + t.Run("filter by destination", func(t *testing.T) { + w := httptest.NewRecorder() + req, _ := http.NewRequest("GET", "/htmx/traffic-table?destination=storage", nil) + r.ServeHTTP(w, req) + + body := w.Body.String() + count := strings.Count(body, "toggleTxnDetails") + if count != 1 { + t.Errorf("rendered rows: got %d, want 1", count) + } + }) + + t.Run("pagination", func(t *testing.T) { + w := httptest.NewRecorder() + req, _ := http.NewRequest("GET", "/htmx/traffic-table?limit=2", nil) + r.ServeHTTP(w, req) + + body := w.Body.String() + count := strings.Count(body, "toggleTxnDetails") + if count != 2 { + t.Errorf("rendered rows: got %d, want 2", count) + } + if !strings.Contains(body, "Showing 2 of 3 transactions") { + t.Error("wrong count text for paginated view") + } + if !strings.Contains(body, "Page 1 of 2") { + t.Error("missing pagination info") + } + }) + + t.Run("page 2", func(t *testing.T) { + w := httptest.NewRecorder() + req, _ := http.NewRequest("GET", "/htmx/traffic-table?limit=2&page=2", nil) + r.ServeHTTP(w, req) + + body := w.Body.String() + count := strings.Count(body, "toggleTxnDetails") + if count != 1 { + t.Errorf("rendered rows on page 2: got %d, want 1", count) + } + }) + + t.Run("empty result shows message", func(t *testing.T) { + w := httptest.NewRecorder() + req, _ := http.NewRequest("GET", "/htmx/traffic-table?method=DELETE", nil) + r.ServeHTTP(w, req) + + body := w.Body.String() + if !strings.Contains(body, "No transactions match your filters") { + t.Error("missing empty state message for filtered view") + } + }) + + t.Run("no data shows empty state", func(t *testing.T) { + emptyDB := setupTestDB(t) + emptyRouter := setupRouter(t, emptyDB) + + w := httptest.NewRecorder() + req, _ := http.NewRequest("GET", "/htmx/traffic-table", nil) + emptyRouter.ServeHTTP(w, req) + + body := w.Body.String() + if !strings.Contains(body, "No HTTP transactions") { + t.Error("missing empty state message") + } + }) + + t.Run("method badges use primary color", func(t *testing.T) { + w := httptest.NewRecorder() + req, _ := http.NewRequest("GET", "/htmx/traffic-table", nil) + r.ServeHTTP(w, req) + + body := w.Body.String() + // All method badges should use the same primary/orange color + if !strings.Contains(body, "bg-primary/10 text-primary\">GET") { + t.Error("GET badge missing primary color classes") + } + if !strings.Contains(body, "bg-primary/10 text-primary\">POST") { + t.Error("POST badge missing primary color classes") + } + if !strings.Contains(body, "bg-primary/10 text-primary\">PUT") { + t.Error("PUT badge missing primary color classes") + } + }) + + t.Run("expandable details section present", func(t *testing.T) { + w := httptest.NewRecorder() + req, _ := http.NewRequest("GET", "/htmx/traffic-table", nil) + r.ServeHTTP(w, req) + + body := w.Body.String() + // Each transaction should have a hidden details row + detailCount := strings.Count(body, "txn-details-") + // 3 transactions × 2 occurrences each (id attr + onclick ref) = but details rows have id="txn-details-N" + if detailCount < 3 { + t.Errorf("detail rows: got %d, want at least 3", detailCount) + } + if !strings.Contains(body, "Destination:") { + t.Error("missing destination info in detail section") + } + }) +} diff --git a/internal/greyproxy/ui/templates/base.html b/internal/greyproxy/ui/templates/base.html index 1e1dc85..56e4736 100644 --- a/internal/greyproxy/ui/templates/base.html +++ b/internal/greyproxy/ui/templates/base.html @@ -23,6 +23,9 @@ +{{end}} From 32b57685e63fa30a12545c61ffb423737751b1a6 Mon Sep 17 00:00:00 2001 From: Mathieu Virbel Date: Sun, 8 Mar 2026 14:27:22 -0600 Subject: [PATCH 03/13] docs: add vibedocs with MITM proposal as 001-mitm.md --- vibedocs/001-mitm.md | 557 +++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 557 insertions(+) create mode 100644 vibedocs/001-mitm.md diff --git a/vibedocs/001-mitm.md b/vibedocs/001-mitm.md new file mode 100644 index 0000000..fc7c07f --- /dev/null +++ b/vibedocs/001-mitm.md @@ -0,0 +1,557 @@ +# Greyproxy v2: Human-in-the-Middle for AI Agent Loops + +## The Problem + +Today greyproxy operates at the **TCP connection level**: you see "claude-code wants to connect to api.anthropic.com:443" and you either allow or deny. That tells you nothing about **what** the agent is actually doing. You're approving a domain blindly — is it a normal LLM request? Posting credentials? A DELETE to production? + +## The Vision + +Greyproxy becomes a **human-in-the-middle** proxy. When a new destination appears, instead of just showing "api.anthropic.com:443", you see the actual HTTP request: method, URL, headers, body. You approve with full context. Every transaction is logged. Dangerous patterns are auto-held. + +--- + +## Core Idea: Deferred Connect + +Today's flow: +``` +1. SOCKS5 CONNECT host:port +2. Bypass check → no rule → HOLD (show "host:port" pending) +3. User approves → dial upstream → MITM → HTTP flows +``` + +The problem: at step 2, we don't have HTTP details yet. The TLS handshake hasn't happened. + +**New flow** (when MITM is available): +``` +1. SOCKS5 CONNECT host:port +2. Quick check → explicit deny rule? → BLOCK immediately +3. No rule or allow rule → send SOCKS5 "Succeeded" to client +4. Client starts TLS → we do MITM handshake (client side only, no upstream yet) +5. Client sends HTTP request → we read and buffer it +6. NOW we have full context: method, URL, headers, body +7. Evaluate request rules: + a. Auto-allow → connect upstream, forward request + b. Hold → show rich pending with full HTTP details, wait for user + c. Deny → send HTTP 403 back, never connect upstream +8. On approval → connect upstream, TLS handshake with server, forward buffered request +``` + +**One stone, two birds**: the pending request now shows everything. The user approves with full context. The same mechanism works for subsequent requests on the same connection (each request is evaluated). + +**Fallback** (no MITM — non-HTTP, cert not installed, MITM bypass): +``` +Same as today: pending shows host:port only, approval is connection-level. +``` + +No degradation for non-MITM scenarios. + +--- + +## What the User Sees + +### Pending Request (today) +``` +claude-code → api.anthropic.com:443 (3 attempts) [Allow] [Deny] +``` + +### Pending Request (with MITM) +``` +claude-code → api.anthropic.com:443 +POST /v1/messages HTTP/1.1 +Content-Type: application/json +Authorization: Bearer sk-ant-***REDACTED*** + +{ + "model": "claude-sonnet-4-20250514", + "messages": [{"role": "user", "content": "Read src/main.go..."}], + "tools": [...] +} + +[Allow Once] [Allow Destination] [Allow Pattern] [Deny] +``` + +### Approval Actions + +| Action | What it does | +|--------|-------------| +| **Allow Once** | Forward this request only. Next request re-evaluates. | +| **Allow Destination** | Create a destination rule (same as today's Allow). All future requests to this host:port auto-forward + log. | +| **Allow Pattern** | Create a request-level rule: e.g. "allow POST /v1/messages to api.anthropic.com". Other methods/paths still held. | +| **Deny** | Return HTTP 403. Optionally create deny rule. | + +### Safe Methods: Auto-Allow After Destination Approval + +When a user clicks "Allow Destination", a sensible default: **GET requests auto-forward** (logged), **mutating requests (POST/PUT/DELETE/PATCH) are held** for review. This balances speed with safety. Configurable per-rule. + +### Global Kill Switch + +A prominent button in the dashboard header: **KILL ALL** +- Immediately denies all pending requests +- Closes all active connections (via ConnTracker) +- Pauses the proxy (new connections get instant deny) +- Requires explicit "Resume" to re-enable + +--- + +## Request Rules + +Extend the existing rules with optional HTTP-level fields: + +### Schema Changes + +```sql +ALTER TABLE rules ADD COLUMN method_pattern TEXT NOT NULL DEFAULT '*'; +ALTER TABLE rules ADD COLUMN path_pattern TEXT NOT NULL DEFAULT '*'; +ALTER TABLE rules ADD COLUMN content_action TEXT NOT NULL DEFAULT 'allow'; +-- content_action: 'allow' (forward+log), 'hold' (pause for approval), 'deny' (block) +``` + +### Rule Matching (extended) + +Existing specificity system stays. New dimensions add specificity: +- Method exact match: +4 points +- Method wildcard: +0 +- Path exact match: +3 points +- Path with glob: +2 points +- Path wildcard: +0 + +### Rule Examples + +``` +# After "Allow Destination" for anthropic: all requests auto-forward, logged +container=claude-code dest=api.anthropic.com port=443 method=* path=* action=allow + +# After "Allow Pattern" for chat completions: only this endpoint auto-forwards +container=claude-code dest=api.anthropic.com port=443 method=POST path=/v1/messages action=allow + +# Global: hold any DELETE request anywhere +container=* dest=* port=* method=DELETE path=* action=hold + +# Global: hold PUT requests to non-API destinations +container=* dest=* port=* method=PUT path=* action=hold +``` + +### Backward Compatibility + +- Existing rules get `method_pattern='*'`, `path_pattern='*'`, `content_action='allow'` +- Behavior is identical to today: destination allowed = everything flows through +- New fields only matter when MITM is active and request-level evaluation runs + +--- + +## HTTP Transaction Logging + +Every MITM-intercepted HTTP request/response is stored: + +### New Table: `http_transactions` + +```sql +CREATE TABLE http_transactions ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + timestamp DATETIME NOT NULL, + container_name TEXT NOT NULL, + destination_host TEXT NOT NULL, + destination_port INTEGER NOT NULL, + + -- Request + method TEXT NOT NULL, + url TEXT NOT NULL, + request_headers TEXT, -- JSON object + request_body BLOB, -- up to max_body_capture bytes + request_body_size INTEGER, -- actual size before truncation + request_content_type TEXT, + + -- Response + status_code INTEGER, + response_headers TEXT, -- JSON object + response_body BLOB, -- up to max_body_capture bytes + response_body_size INTEGER, + response_content_type TEXT, + + -- Metadata + duration_ms INTEGER, + rule_id INTEGER, + result TEXT NOT NULL DEFAULT 'auto' -- 'auto', 'allowed', 'held', 'denied' +); + +CREATE INDEX idx_http_transactions_ts ON http_transactions(timestamp); +CREATE INDEX idx_http_transactions_dest ON http_transactions(destination_host, destination_port); +``` + +### Body Capture Config + +```yaml +greyproxy: + max_body_capture: 1048576 # 1MB default, configurable +``` + +- Bodies larger than max are truncated; `request_body_size` stores the real size +- Binary content stored as-is; UI shows "binary, N bytes" or hex preview +- JSON bodies: stored raw, UI renders with syntax highlighting and collapsible sections + +### Data Retention + +Logs and transactions are retained for **2 weeks** by default (configurable). + +```yaml +greyproxy: + log_retention_days: 14 # delete logs + transactions older than this +``` + +A cleanup routine runs on startup and then every hour: +```sql +DELETE FROM http_transactions WHERE timestamp < datetime('now', '-14 days'); +DELETE FROM request_logs WHERE timestamp < datetime('now', '-14 days'); +``` + +The dashboard shows a storage indicator: "Logs: 12,430 transactions, 847 MB, oldest 13 days". + +### SSE / Streaming Responses + +LLM APIs stream via Server-Sent Events. Approach: **tee the stream**. + +- Forward SSE chunks to the client in real-time (don't buffer/block the agent) +- Simultaneously capture chunks into a buffer +- When the stream ends, store the assembled response body in `http_transactions` +- The agent experiences no latency impact from logging + +For now, response-side inspection is **post-hoc** (visible in logs after it happened). Request-side inspection happens **before forwarding** (hold/deny). + +**Future: response-side pre-hoc inspection.** The architecture is designed to support this. Step 5's deferred connect establishes the pattern: "buffer first, decide, then forward." The same applies to responses — buffer the full response before writing to the client, run content filters, then forward or block. When we add response-side rules, we swap the tee approach for full response buffering on filtered destinations. This is how we'd catch dangerous tool calls in LLM responses before they reach the agent. The SSE tee is the interim optimization for unfiltered traffic; filtered traffic will use buffer-then-decide. + +--- + +## Content Filters (Phase 3) + +Regex-based rules that auto-trigger on request content: + +```yaml +greyproxy: + content_filters: + - name: "Credential leak" + field: body + regex: "private_key|BEGIN RSA|BEGIN EC" + action: hold + exclude_destinations: + - "api.anthropic.com" + - "api.openai.com" + + - name: "Dangerous methods" + method: "DELETE" + action: hold +``` + +When a filter matches: +- `hold` → request pending appears with filter name highlighted and matched content shown +- `deny` → HTTP 403 returned immediately +- `flag` → forwarded but log entry gets a warning badge + +--- + +## Schema Migrations + +The project already uses a versioned migration system (`schema_migrations` table in `migrations.go`). All new tables and column additions are added as new numbered migrations, so existing databases upgrade safely on restart. + +| Migration | What | +|-----------|------| +| 4 | Create `http_transactions` table + indexes | +| 5 | Add `method_pattern`, `path_pattern`, `content_action` columns to `rules` (with defaults for existing rows) | +| 6 | Create `pending_http_requests` table | + +Existing data is never modified or deleted by migrations. New columns use `DEFAULT` values so existing rules keep working identically. The migration runner skips already-applied versions via `schema_migrations`. + +--- + +## Dashboard Additions + +The dashboard gets new sections reflecting the HTTP transaction data: + +### Stats Panel (existing, extended) +- **HTTP Transactions today**: count of MITM-captured requests +- **Top endpoints**: most-hit method+path combinations (e.g., "POST /v1/messages — 342 calls") +- **Storage usage**: "12,430 transactions, 847 MB, oldest 13 days" with retention indicator + +### New: Live Activity Feed +A real-time stream of HTTP transactions as they happen (via WebSocket): +``` +10:32:05 POST api.anthropic.com/v1/messages → 200 (1.2s) +10:32:04 GET api.github.com/repos/... → 200 (0.3s) +10:32:01 POST api.openai.com/v1/chat/completions → 200 (2.1s) +``` +Click any entry to expand full request/response details. + +### New: Request Pending Section +When request-level holds are active, the dashboard shows them prominently — same data as the Pending page but inline for quick action without navigating away. + +--- + +## Implementation Plan: Tiny Verifiable Steps + +Each step produces a working, testable increment. Every step has: +- **Unit tests** for the new logic +- **Live test** you can run against the running proxy with curl + +### Step 1: `http_transactions` table + model + CRUD + +**What**: Create the new table, Go model, and basic CRUD operations. + +**Test**: +- Unit: TestCreateHttpTransaction, TestGetHttpTransaction, TestListHttpTransactions +- Unit: TestHttpTransactionBodyTruncation (verify max_body_capture works) + +--- + +### Step 2: Wire MITM callback to store transactions + +**What**: Replace the current `mitmLogHook` (which only logs to console) with a hook that writes to `http_transactions`. Pass DB + config through to the sniffer setup. + +**Test**: +- Unit: TestMitmCallbackCreatesTransaction (mock DB, verify fields) +- Live: + ```bash + # Terminal 1: proxy is running with MITM enabled, destination already allowed + # Terminal 2: + curl --proxy socks5h://localhost:43052 --insecure https://httpbin.org/get + # Terminal 3: check it was captured + curl http://localhost:43080/api/transactions | jq '.[] | {method, url, status_code}' + # Expected: GET https://httpbin.org/get → 200 + ``` + +--- + +### Step 3: API endpoint for transactions + +**What**: `GET /api/transactions` — list transactions with filters (destination, method, time range). `GET /api/transactions/:id` — full detail including body. + +**Test**: +- Unit: TestTransactionsAPIList, TestTransactionsAPIDetail +- Live: + ```bash + # After making a few requests through the proxy: + curl http://localhost:43080/api/transactions?destination=httpbin.org | jq + curl http://localhost:43080/api/transactions/1 | jq '.request_body' + ``` + +--- + +### Step 4: Transaction detail in Logs UI + +**What**: Update the Logs tab to show HTTP transaction info. Each log entry expands to show method, URL, status, headers, body (collapsible). + +**Test**: +- Live: Open dashboard, make requests through proxy, verify Logs tab shows HTTP details with expandable body sections. + +--- + +### Step 5: Deferred connect — refactor sniffer for split handshake + +**What**: This is the core architectural change. Refactor `terminateTLS` so it can: +1. Do TLS handshake with client (MITM) WITHOUT connecting to upstream +2. Return the decrypted client connection for reading +3. Later, when approved, connect to upstream and forward + +Currently `terminateTLS` dials upstream first, then does client handshake. We reverse this. + +**Test**: +- Unit: TestDeferredTlsHandshake — verify we can MITM-handshake with client, read HTTP request bytes, then separately connect upstream and forward +- Live: + ```bash + # With a test destination not yet in rules: + curl --proxy socks5h://localhost:43052 --insecure https://httpbin.org/post \ + -X POST -d '{"test": "hello"}' + # In dashboard: pending should show POST /post with body {"test": "hello"} + ``` + +--- + +### Step 6: Request-level pending model + API + +**What**: New `pending_http_requests` table. API endpoints to list/approve/deny request-level pendings. WebSocket events for real-time updates. + +```sql +CREATE TABLE pending_http_requests ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + container_name TEXT NOT NULL, + destination_host TEXT NOT NULL, + destination_port INTEGER NOT NULL, + method TEXT NOT NULL, + url TEXT NOT NULL, + request_headers TEXT, + request_body BLOB, + request_body_size INTEGER, + created_at DATETIME NOT NULL, + UNIQUE(container_name, destination_host, destination_port, method, url) +); +``` + +**Test**: +- Unit: TestCreatePendingHttpRequest, TestApprovePendingHttpRequest, TestDenyPendingHttpRequest +- Live: + ```bash + # API shows request pendings + curl http://localhost:43080/api/pending/requests | jq + # Approve via API + curl -X POST http://localhost:43080/api/pending/requests/1/allow + ``` + +--- + +### Step 7: Request-level hold in sniffer + +**What**: In `httpRoundTrip`, before forwarding the request upstream, evaluate request-level rules. If no matching allow rule and MITM is active → buffer request, create pending, wait for approval via EventBus. + +**Test**: +- Unit: TestRequestHoldAndApprove, TestRequestHoldAndDeny, TestRequestAutoAllow +- Live: + ```bash + # Destination allowed but no request-level rule: + curl --proxy socks5h://localhost:43052 --insecure https://httpbin.org/delete -X DELETE + # Dashboard shows request pending with DELETE method + # Approve → curl completes + # Deny → curl gets 403 + ``` + +--- + +### Step 8: Pending UI for request-level holds + +**What**: Update the Pending page to show two sections: +1. Connection pendings (existing, top) +2. Request pendings (new, below, with HTTP detail) + +Rich display: method badge, URL, headers, body preview with syntax highlighting. + +**Test**: +- Live: Open dashboard, trigger a request-level hold, verify the UI shows full HTTP details with approve/deny buttons. + +--- + +### Step 9: Extended rules — method + path patterns + +**What**: Add `method_pattern` and `path_pattern` columns to rules. Update rule matching to include these dimensions. Update Rules UI to show/edit these fields. + +**Test**: +- Unit: TestRuleMatchesMethod, TestRuleMatchesPath, TestRuleSpecificityWithHttpFields +- Live: + ```bash + # Create a rule that allows GET but holds POST: + curl -X POST http://localhost:43080/api/rules -d '{ + "container_pattern": "*", + "destination_pattern": "httpbin.org", + "port_pattern": "443", + "method_pattern": "GET", + "path_pattern": "*", + "action": "allow" + }' + # GET flows through: + curl --proxy socks5h://localhost:43052 --insecure https://httpbin.org/get + # POST gets held: + curl --proxy socks5h://localhost:43052 --insecure https://httpbin.org/post -X POST -d 'test' + ``` + +--- + +### Step 10: Global kill switch + +**What**: API endpoint `POST /api/killswitch` that: +1. Denies all pending requests (connection + request level) +2. Cancels all active connections via ConnTracker +3. Sets a "paused" flag that makes Bypass.Contains() deny everything +4. `POST /api/killswitch/resume` to re-enable + +Dashboard header gets a red KILL button and green RESUME button. + +**Test**: +- Unit: TestKillSwitchDeniesAll, TestKillSwitchClosesConnections, TestResume +- Live: + ```bash + # Start a long-running request: + curl --proxy socks5h://localhost:43052 --insecure https://httpbin.org/delay/30 & + # Hit kill switch: + curl -X POST http://localhost:43080/api/killswitch + # Background curl should fail immediately + # New requests should fail: + curl --proxy socks5h://localhost:43052 --insecure https://httpbin.org/get + # Resume: + curl -X POST http://localhost:43080/api/killswitch/resume + ``` + +--- + +### Step 11: Content filters + +**What**: Config-based regex filters that run on request body/headers/URL. When matched, override the content_action to hold/deny/flag. + +**Test**: +- Unit: TestContentFilterMatchesBody, TestContentFilterExcludesDestination +- Live: + ```bash + # With filter configured for "private_key": + curl --proxy socks5h://localhost:43052 --insecure https://httpbin.org/post \ + -X POST -d '{"data": "my private_key is abc123"}' + # Dashboard shows held request with "Credential leak" filter highlighted + ``` + +--- + +### Step 12: UI polish — redaction, syntax highlighting, body preview + +**What**: Auto-redact sensitive patterns (API keys, bearer tokens) in UI display. JSON syntax highlighting for bodies. Collapsible sections for headers. + +--- + +## Priority Order + +Steps 1-4 are **Phase 1 (Observability)**: see everything, no behavior changes. +Steps 5-9 are **Phase 2 (Control)**: hold and approve individual HTTP requests. +Step 10 is **Emergency Control**: kill switch. +Steps 11-12 are **Phase 3 (Automation)**: content filters, polish. + +I recommend implementing in this order. Phase 1 alone is immediately valuable — you can see every HTTP request your agent makes. Phase 2 gives granular control. Phase 3 automates the tedious parts. + +--- + +## Open Questions + +1. **Hold timeout for request pendings**: Proposal: **60 seconds** (longer than connection-level 30s because the TCP connection is alive and the client is more patient at HTTP level). Configurable. + +2. **HTTP/2 multiplexing**: For the POC (Step 5-7), we can start with HTTP/1.1 only. HTTP/2 streams are already handled per-request in `h2Handler.ServeHTTP`, so adding the hold logic there should be straightforward in a follow-up. + +3. **What if MITM cert is not installed?**: Fall back to connection-level only. No request details shown. The pending says "MITM unavailable — install CA cert for request details" with a link to `greyproxy cert install`. + +--- + +## Config Changes Summary + +```yaml +greyproxy: + addr: ":43080" + db: "greyproxy.db" + + # New + max_body_capture: 1048576 # 1MB, max bytes per request/response body + request_hold_timeout: 60 # seconds to wait for request-level approval + log_retention_days: 14 # auto-delete logs + transactions older than this + + content_filters: # Phase 3 + - name: "Credential leak" + field: body + regex: "private_key|BEGIN RSA" + action: hold + exclude_destinations: ["api.anthropic.com"] +``` + +--- + +## Summary + +| What | Before | After | +|------|--------|-------| +| **Pending shows** | host:port | Full HTTP request with body | +| **Approval means** | "Allow this domain" | "Allow this request" or "Allow this domain" | +| **Visibility** | Connection events | Full HTTP transactions | +| **Control** | Domain allow/deny | Method/path/content rules | +| **Emergency** | Close browser tab | Kill switch | +| **Safety net** | None | Content filters | + +**Core principle**: When you approve a request, you see exactly what you're approving. From f60d30bd29e78dd3fb2f7c020cb5828f3b5359a6 Mon Sep 17 00:00:00 2001 From: Mathieu Virbel Date: Tue, 10 Mar 2026 12:06:42 -0600 Subject: [PATCH 04/13] feat: Phase 2 MITM - pending HTTP request approval, URL pattern matching, and UI --- cmd/greyproxy/program.go | 130 ++++ .../gostx/internal/util/sniffing/sniffer.go | 185 +++++- internal/gostx/mitm_hook.go | 34 ++ internal/greyproxy/api/pending_http.go | 114 ++++ internal/greyproxy/api/router.go | 7 + internal/greyproxy/crud.go | 357 ++++++++++- internal/greyproxy/crud_test.go | 6 +- internal/greyproxy/events.go | 5 + internal/greyproxy/migrations.go | 23 + internal/greyproxy/models.go | 65 ++ internal/greyproxy/patterns.go | 56 ++ internal/greyproxy/ui/pages.go | 106 +++- internal/greyproxy/ui/templates/base.html | 15 +- .../templates/partials/pending_http_list.html | 84 +++ .../ui/templates/partials/rules_list.html | 25 +- internal/greyproxy/ui/templates/pending.html | 14 + internal/greyproxy/ui/templates/rules.html | 43 +- latest | 1 + plan.md | 33 ++ proposal.md | 557 ++++++++++++++++++ research-001-mitmproxy.md | 200 +++++++ research-002-greyproxy.md | 410 +++++++++++++ 22 files changed, 2431 insertions(+), 39 deletions(-) create mode 100644 internal/greyproxy/api/pending_http.go create mode 100644 internal/greyproxy/ui/templates/partials/pending_http_list.html create mode 120000 latest create mode 100644 plan.md create mode 100644 proposal.md create mode 100644 research-001-mitmproxy.md create mode 100644 research-002-greyproxy.md diff --git a/cmd/greyproxy/program.go b/cmd/greyproxy/program.go index eaae8bc..28c546d 100644 --- a/cmd/greyproxy/program.go +++ b/cmd/greyproxy/program.go @@ -12,6 +12,7 @@ import ( "strconv" "strings" "syscall" + "time" defaults "github.com/greyhavenhq/greyproxy" "github.com/greyhavenhq/greyproxy/internal/gostcore/logger" @@ -367,6 +368,94 @@ func (p *program) buildGreyproxyService() error { }() }) + // Wire MITM request-level hold hook for request approval before forwarding + gostx.SetGlobalMitmHoldHook(func(ctx context.Context, info gostx.MitmRequestHoldInfo) error { + host, portStr, _ := net.SplitHostPort(info.Host) + if host == "" { + host = info.Host + } + port, _ := strconv.Atoi(portStr) + if port == 0 { + port = 443 + } + containerName, _ := greyproxy_plugins.ResolveIdentity(info.ContainerName) + + // Resolve hostname from cache + resolvedHostname := shared.Cache.ResolveIP(host) + if resolvedHostname == "" { + resolvedHostname = host + } + + // Two-pass rule evaluation: + // 1. Check for request-specific rules (method/path non-wildcard) first + // 2. Fall back to destination-level rules + + // Pass 1: Find a rule with specific method or path patterns + requestRule := greyproxy.FindRequestSpecificRule(shared.DB, containerName, host, port, resolvedHostname, info.Method, info.URI) + if requestRule != nil { + if requestRule.Action == "deny" { + return gostx.ErrRequestDenied + } + switch requestRule.ContentAction { + case "allow": + return nil + case "deny": + return gostx.ErrRequestDenied + case "hold": + // Fall through to hold logic below + goto hold + } + } + + // Pass 2: Destination-level rule (backward compatible) + { + destRule := greyproxy.FindMatchingRule(shared.DB, containerName, host, port, resolvedHostname) + if destRule != nil { + if destRule.Action == "deny" { + return gostx.ErrRequestDenied + } + // Existing allow rules with default content_action auto-forward everything + if destRule.ContentAction == "allow" || destRule.ContentAction == "" { + return nil + } + if destRule.ContentAction == "deny" { + return gostx.ErrRequestDenied + } + // content_action == "hold" — fall through + } else { + // No rule at all — this shouldn't happen since connection was already allowed, + // but allow to avoid blocking + return nil + } + } + + hold: + // Hold: create pending HTTP request and wait for approval + pending, isNew, err := greyproxy.CreatePendingHttpRequest(shared.DB, greyproxy.PendingHttpRequestCreateInput{ + ContainerName: containerName, + DestinationHost: host, + DestinationPort: port, + Method: info.Method, + URL: "https://" + info.Host + info.URI, + RequestHeaders: info.RequestHeaders, + RequestBody: info.RequestBody, + }) + if err != nil { + log.Warnf("failed to create pending http request: %v", err) + return nil // Allow on error to avoid blocking the agent + } + + if isNew { + shared.Bus.Publish(greyproxy.Event{ + Type: greyproxy.EventHttpPendingCreated, + Data: pending.ToJSON(false), + }) + } + + // Wait for approval + return waitForHttpApproval(ctx, shared, pending.ID, log) + }) + // Create and register gost plugins autherPlugin := greyproxy_plugins.NewAuther() admissionPlugin := greyproxy_plugins.NewAdmission() @@ -404,6 +493,47 @@ func (p *program) buildGreyproxyService() error { return nil } +// requestHoldTimeout is how long to wait for user approval on a held HTTP request. +const requestHoldTimeout = 60 * time.Second + +func waitForHttpApproval(ctx context.Context, shared *greyproxy_api.Shared, pendingID int64, log logger.Logger) error { + ch := shared.Bus.Subscribe(16) + defer shared.Bus.Unsubscribe(ch) + + timer := time.NewTimer(requestHoldTimeout) + defer timer.Stop() + + log.Debugf("HOLD HTTP request %d, waiting up to %s for approval", pendingID, requestHoldTimeout) + + for { + select { + case <-ctx.Done(): + return gostx.ErrRequestDenied + case <-timer.C: + // Timeout — deny + greyproxy.ResolvePendingHttpRequest(shared.DB, pendingID, "denied") + return gostx.ErrRequestDenied + case evt := <-ch: + switch evt.Type { + case greyproxy.EventHttpPendingAllowed: + if data, ok := evt.Data.(map[string]any); ok { + if id, ok := data["pending_id"].(int64); ok && id == pendingID { + log.Debugf("APPROVED HTTP request %d", pendingID) + return nil + } + } + case greyproxy.EventHttpPendingDenied: + if data, ok := evt.Data.(map[string]any); ok { + if id, ok := data["pending_id"].(int64); ok && id == pendingID { + log.Debugf("DENIED HTTP request %d", pendingID) + return gostx.ErrRequestDenied + } + } + } + } + } +} + func buildMetricsService(cfg *config.MetricsConfig) (svccore.Service, error) { auther := auth_parser.ParseAutherFromAuth(cfg.Auth) if cfg.Auther != "" { diff --git a/internal/gostx/internal/util/sniffing/sniffer.go b/internal/gostx/internal/util/sniffing/sniffer.go index f7686d9..e328909 100644 --- a/internal/gostx/internal/util/sniffing/sniffer.go +++ b/internal/gostx/internal/util/sniffing/sniffer.go @@ -13,6 +13,7 @@ import ( "io" "math" "net" + "sync" "net/http" "net/http/httputil" "strings" @@ -120,6 +121,23 @@ type HTTPRoundTripInfo struct { // Set this from program initialization to record transactions to the database. var GlobalHTTPRoundTripHook func(info HTTPRoundTripInfo) +// ErrRequestDenied is returned by the hold hook to indicate the request should be denied. +var ErrRequestDenied = errors.New("request denied") + +// HTTPRequestHoldInfo contains request details for the hold hook to evaluate. +type HTTPRequestHoldInfo struct { + Host string + Method string + URI string + RequestHeaders http.Header + RequestBody []byte + ContainerName string +} + +// GlobalHTTPRequestHoldHook is called (if set) before forwarding a MITM-intercepted HTTP request upstream. +// Return nil to allow, ErrRequestDenied to send 403, or block until approval. +var GlobalHTTPRequestHoldHook func(ctx context.Context, info HTTPRequestHoldInfo) error + type Sniffer struct { Websocket bool WebsocketSampleRate float64 @@ -349,7 +367,7 @@ func (h *Sniffer) httpRoundTrip(ctx context.Context, rw, cc io.ReadWriteCloser, } var reqBody *xhttp.Body - captureBody := (h.RecorderOptions != nil && h.RecorderOptions.HTTPBody) || h.OnHTTPRoundTrip != nil + captureBody := (h.RecorderOptions != nil && h.RecorderOptions.HTTPBody) || h.OnHTTPRoundTrip != nil || GlobalHTTPRequestHoldHook != nil if captureBody { if req.Body != nil { bodySize := DefaultBodySize @@ -364,6 +382,51 @@ func (h *Sniffer) httpRoundTrip(ctx context.Context, rw, cc io.ReadWriteCloser, } } + // Request-level hold: evaluate before forwarding upstream + if GlobalHTTPRequestHoldHook != nil { + containerName := string(xctx.ClientIDFromContext(ctx)) + if containerName == "" { + containerName = ro.ClientID + } + // Read the body first so it's captured for the hook + var holdBody []byte + if reqBody != nil { + // Force body to be read by reading through the tee + bodyBuf := new(bytes.Buffer) + if req.Body != nil { + bodyBuf.ReadFrom(req.Body) + // Reconstruct body for forwarding + req.Body = io.NopCloser(bodyBuf) + req.ContentLength = int64(bodyBuf.Len()) + } + holdBody = reqBody.Content() + } + + holdInfo := HTTPRequestHoldInfo{ + Host: req.Host, + Method: req.Method, + URI: req.RequestURI, + RequestHeaders: req.Header.Clone(), + RequestBody: holdBody, + ContainerName: containerName, + } + if holdErr := GlobalHTTPRequestHoldHook(ctx, holdInfo); holdErr != nil { + // Request denied — send 403 to client + denyResp := &http.Response{ + StatusCode: http.StatusForbidden, + Proto: req.Proto, + ProtoMajor: req.ProtoMajor, + ProtoMinor: req.ProtoMinor, + Header: http.Header{"Content-Type": {"text/plain"}}, + Body: io.NopCloser(strings.NewReader("Request denied by proxy")), + } + denyResp.ContentLength = 22 + denyResp.Write(rw) + close = true + return + } + } + err = req.Write(cc) if reqBody != nil { @@ -731,6 +794,13 @@ func (h *Sniffer) terminateTLS(ctx context.Context, network string, conn, cc net ro := ho.recorderObject log := ho.log + // Deferred connect mode: if a hold hook is set, we do client-side TLS first + // (without upstream) so we can read and evaluate HTTP requests before connecting. + if GlobalHTTPRequestHoldHook != nil { + return h.terminateTLSDeferred(ctx, network, conn, cc, clientHello, ho) + } + + // Original flow: connect upstream first, then client nextProtos := clientHello.SupportedProtos if h.NegotiatedProtocol != "" { nextProtos = []string{h.NegotiatedProtocol} @@ -837,6 +907,119 @@ func (h *Sniffer) terminateTLS(ctx context.Context, network string, conn, cc net return h.HandleHTTP(ctx, network, serverConn, opts...) } +// terminateTLSDeferred performs the client-side TLS handshake FIRST (without upstream), +// allowing us to read HTTP requests before deciding whether to connect upstream. +// This enables request-level hold/approval: the user sees the full HTTP request +// before any data reaches the destination. +func (h *Sniffer) terminateTLSDeferred(ctx context.Context, network string, conn, cc net.Conn, clientHello *dissector.ClientHelloInfo, ho *HandleOptions) error { + ro := ho.recorderObject + log := ho.log + + host := clientHello.ServerName + if host == "" { + host = ro.Host + } + if hostPart, _, _ := net.SplitHostPort(host); hostPart != "" { + host = hostPart + } + + // For deferred mode, negotiate http/1.1 with the client + // (HTTP/2 deferred connect is a future enhancement) + nextProtos := []string{"http/1.1"} + + ro.TLS.Proto = "http/1.1" + + // Step 1: TLS handshake with client (MITM) — no upstream connection yet + wb := &bytes.Buffer{} + conn = xnet.NewReadWriteConn(conn, io.MultiWriter(wb, conn), conn) + + serverConn := tls.Server(conn, &tls.Config{ + NextProtos: nextProtos, + GetCertificate: func(chi *tls.ClientHelloInfo) (*tls.Certificate, error) { + certPool := h.CertPool + if certPool == nil { + certPool = DefaultCertPool + } + serverName := chi.ServerName + if serverName == "" { + serverName = host + } + cert, err := certPool.Get(serverName) + if cert != nil { + pool := x509.NewCertPool() + pool.AddCert(h.Certificate) + if _, err = cert.Verify(x509.VerifyOptions{ + DNSName: serverName, + Roots: pool, + }); err != nil { + log.Warnf("verify cached certificate for %s: %v", serverName, err) + cert = nil + } + } + if cert == nil { + cert, err = tls_util.GenerateCertificate(serverName, 7*24*time.Hour, h.Certificate, h.PrivateKey) + certPool.Put(serverName, cert) + } + if err != nil { + return nil, err + } + return &tls.Certificate{ + Certificate: [][]byte{cert.Raw}, + PrivateKey: h.PrivateKey, + }, nil + }, + }) + if err := serverConn.HandshakeContext(ctx); err != nil { + return err + } + if record, _ := dissector.ReadRecord(wb); record != nil { + wb.Reset() + record.WriteTo(wb) + ro.TLS.ServerHello = hex.EncodeToString(wb.Bytes()) + } + + // Step 2: Lazy upstream connection — established on first dial + var upstreamOnce sync.Once + var upstreamConn net.Conn + var upstreamErr error + + lazyDial := func(ctx context.Context, network, address string) (net.Conn, error) { + upstreamOnce.Do(func() { + // TLS handshake with upstream — force HTTP/1.1 to match our client-side negotiation + upstreamCfg := &tls.Config{ + ServerName: clientHello.ServerName, + NextProtos: []string{"http/1.1"}, + CipherSuites: clientHello.CipherSuites, + } + if upstreamCfg.ServerName == "" { + upstreamCfg.InsecureSkipVerify = true + } + tlsConn := tls.Client(cc, upstreamCfg) + if err := tlsConn.HandshakeContext(ctx); err != nil { + upstreamErr = err + return + } + cs := tlsConn.ConnectionState() + ro.TLS.CipherSuite = tls_util.CipherSuite(cs.CipherSuite).String() + ro.TLS.Version = tls_util.Version(cs.Version).String() + upstreamConn = tlsConn + }) + return upstreamConn, upstreamErr + } + + // Step 3: HandleHTTP reads requests from the decrypted client connection + // and forwards them via the lazy dialer + opts := []HandleOption{ + WithDial(lazyDial), + WithDialTLS(func(ctx context.Context, network, address string, cfg *tls.Config) (net.Conn, error) { + return lazyDial(ctx, network, address) + }), + WithRecorderObject(ro), + WithLog(log), + } + return h.HandleHTTP(ctx, network, serverConn, opts...) +} + type h2Handler struct { transport http.RoundTripper recorder recorder.Recorder diff --git a/internal/gostx/mitm_hook.go b/internal/gostx/mitm_hook.go index 7fab4c6..da11635 100644 --- a/internal/gostx/mitm_hook.go +++ b/internal/gostx/mitm_hook.go @@ -1,6 +1,7 @@ package gostx import ( + "context" "net/http" "github.com/greyhavenhq/greyproxy/internal/gostx/internal/util/sniffing" @@ -22,6 +23,19 @@ type MitmRoundTripInfo struct { DurationMs int64 } +// MitmRequestHoldInfo contains request details for the hold hook to evaluate. +type MitmRequestHoldInfo struct { + Host string + Method string + URI string + RequestHeaders http.Header + RequestBody []byte + ContainerName string +} + +// ErrRequestDenied is returned by the hold hook to deny a request. +var ErrRequestDenied = sniffing.ErrRequestDenied + // SetGlobalMitmHook sets a global callback that fires after every MITM-intercepted HTTP round-trip. func SetGlobalMitmHook(hook func(info MitmRoundTripInfo)) { if hook == nil { @@ -44,3 +58,23 @@ func SetGlobalMitmHook(hook func(info MitmRoundTripInfo)) { }) } } + +// SetGlobalMitmHoldHook sets a global callback that fires BEFORE forwarding a MITM-intercepted +// HTTP request upstream. Return nil to allow, ErrRequestDenied to deny with 403. +// The hook may block (e.g., waiting for user approval). +func SetGlobalMitmHoldHook(hook func(ctx context.Context, info MitmRequestHoldInfo) error) { + if hook == nil { + sniffing.GlobalHTTPRequestHoldHook = nil + return + } + sniffing.GlobalHTTPRequestHoldHook = func(ctx context.Context, info sniffing.HTTPRequestHoldInfo) error { + return hook(ctx, MitmRequestHoldInfo{ + Host: info.Host, + Method: info.Method, + URI: info.URI, + RequestHeaders: info.RequestHeaders, + RequestBody: info.RequestBody, + ContainerName: info.ContainerName, + }) + } +} diff --git a/internal/greyproxy/api/pending_http.go b/internal/greyproxy/api/pending_http.go new file mode 100644 index 0000000..8213563 --- /dev/null +++ b/internal/greyproxy/api/pending_http.go @@ -0,0 +1,114 @@ +package api + +import ( + "net/http" + "strconv" + + "github.com/gin-gonic/gin" + greyproxy "github.com/greyhavenhq/greyproxy/internal/greyproxy" +) + +func PendingHttpCountHandler(s *Shared) gin.HandlerFunc { + return func(c *gin.Context) { + count, err := greyproxy.GetPendingHttpCount(s.DB) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + c.JSON(http.StatusOK, gin.H{"count": count}) + } +} + +func PendingHttpListHandler(s *Shared) gin.HandlerFunc { + return func(c *gin.Context) { + limit, _ := strconv.Atoi(c.DefaultQuery("limit", "100")) + offset, _ := strconv.Atoi(c.DefaultQuery("offset", "0")) + + items, total, err := greyproxy.GetPendingHttpRequests(s.DB, greyproxy.PendingHttpFilter{ + Container: c.Query("container"), + Destination: c.Query("destination"), + Method: c.Query("method"), + Status: c.DefaultQuery("status", "pending"), + Limit: limit, + Offset: offset, + }) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + + jsonItems := make([]greyproxy.PendingHttpRequestJSON, len(items)) + for i, item := range items { + jsonItems[i] = item.ToJSON(false) + } + + c.JSON(http.StatusOK, gin.H{ + "items": jsonItems, + "total": total, + }) + } +} + +func PendingHttpDetailHandler(s *Shared) gin.HandlerFunc { + return func(c *gin.Context) { + id, err := strconv.ParseInt(c.Param("id"), 10, 64) + if err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": "invalid id"}) + return + } + + p, err := greyproxy.GetPendingHttpRequest(s.DB, id) + if err != nil || p == nil { + c.JSON(http.StatusNotFound, gin.H{"error": "not found"}) + return + } + + c.JSON(http.StatusOK, p.ToJSON(true)) + } +} + +func PendingHttpAllowHandler(s *Shared) gin.HandlerFunc { + return func(c *gin.Context) { + id, err := strconv.ParseInt(c.Param("id"), 10, 64) + if err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": "invalid id"}) + return + } + + p, err := greyproxy.ResolvePendingHttpRequest(s.DB, id, "allowed") + if err != nil || p == nil { + c.JSON(http.StatusNotFound, gin.H{"error": "not found or already resolved"}) + return + } + + s.Bus.Publish(greyproxy.Event{ + Type: greyproxy.EventHttpPendingAllowed, + Data: map[string]any{"pending_id": id}, + }) + + c.JSON(http.StatusOK, gin.H{"status": "allowed", "pending": p.ToJSON(false)}) + } +} + +func PendingHttpDenyHandler(s *Shared) gin.HandlerFunc { + return func(c *gin.Context) { + id, err := strconv.ParseInt(c.Param("id"), 10, 64) + if err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": "invalid id"}) + return + } + + p, err := greyproxy.ResolvePendingHttpRequest(s.DB, id, "denied") + if err != nil || p == nil { + c.JSON(http.StatusNotFound, gin.H{"error": "not found or already resolved"}) + return + } + + s.Bus.Publish(greyproxy.Event{ + Type: greyproxy.EventHttpPendingDenied, + Data: map[string]any{"pending_id": id}, + }) + + c.JSON(http.StatusOK, gin.H{"status": "denied", "pending": p.ToJSON(false)}) + } +} diff --git a/internal/greyproxy/api/router.go b/internal/greyproxy/api/router.go index 5b25387..56405f1 100644 --- a/internal/greyproxy/api/router.go +++ b/internal/greyproxy/api/router.go @@ -77,6 +77,13 @@ func NewRouter(s *Shared, pathPrefix string) (*gin.Engine, *gin.RouterGroup) { api.GET("/transactions", TransactionsListHandler(s)) api.GET("/transactions/:id", TransactionsDetailHandler(s)) + + // Request-level pending (MITM HTTP requests held for approval) + api.GET("/pending/requests/count", PendingHttpCountHandler(s)) + api.GET("/pending/requests", PendingHttpListHandler(s)) + api.GET("/pending/requests/:id", PendingHttpDetailHandler(s)) + api.POST("/pending/requests/:id/allow", PendingHttpAllowHandler(s)) + api.POST("/pending/requests/:id/deny", PendingHttpDenyHandler(s)) } // WebSocket diff --git a/internal/greyproxy/crud.go b/internal/greyproxy/crud.go index 7d08836..d0e3129 100644 --- a/internal/greyproxy/crud.go +++ b/internal/greyproxy/crud.go @@ -15,6 +15,9 @@ type RuleCreateInput struct { ContainerPattern string `json:"container_pattern"` DestinationPattern string `json:"destination_pattern"` PortPattern string `json:"port_pattern"` + MethodPattern string `json:"method_pattern"` + PathPattern string `json:"path_pattern"` + ContentAction string `json:"content_action"` RuleType string `json:"rule_type"` Action string `json:"action"` ExpiresInSeconds *int64 `json:"expires_in_seconds"` @@ -26,6 +29,9 @@ type RuleUpdateInput struct { ContainerPattern *string `json:"container_pattern"` DestinationPattern *string `json:"destination_pattern"` PortPattern *string `json:"port_pattern"` + MethodPattern *string `json:"method_pattern"` + PathPattern *string `json:"path_pattern"` + ContentAction *string `json:"content_action"` Action *string `json:"action"` Notes *string `json:"notes"` ExpiresAt *string `json:"expires_at"` @@ -47,6 +53,15 @@ func CreateRule(db *DB, input RuleCreateInput) (*Rule, error) { if input.CreatedBy == "" { input.CreatedBy = "admin" } + if input.MethodPattern == "" { + input.MethodPattern = "*" + } + if input.PathPattern == "" { + input.PathPattern = "*" + } + if input.ContentAction == "" { + input.ContentAction = "allow" + } var expiresAt sql.NullString if input.ExpiresInSeconds != nil && *input.ExpiresInSeconds > 0 { @@ -68,9 +83,10 @@ func CreateRule(db *DB, input RuleCreateInput) (*Rule, error) { ) result, err := db.WriteDB().Exec( - `INSERT INTO rules (container_pattern, destination_pattern, port_pattern, rule_type, action, expires_at, created_by, notes) - VALUES (?, ?, ?, ?, ?, ?, ?, ?)`, + `INSERT INTO rules (container_pattern, destination_pattern, port_pattern, method_pattern, path_pattern, content_action, rule_type, action, expires_at, created_by, notes) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`, input.ContainerPattern, input.DestinationPattern, input.PortPattern, + input.MethodPattern, input.PathPattern, input.ContentAction, input.RuleType, input.Action, expiresAt, input.CreatedBy, notes, ) if err != nil { @@ -88,11 +104,14 @@ func CreateRule(db *DB, input RuleCreateInput) (*Rule, error) { return GetRule(db, id) } +// ruleColumns is the SELECT list for all rule queries. +const ruleColumns = `id, container_pattern, destination_pattern, port_pattern, + method_pattern, path_pattern, content_action, + rule_type, action, created_at, expires_at, last_used_at, created_by, notes` + func GetRule(db *DB, id int64) (*Rule, error) { row := db.ReadDB().QueryRow( - `SELECT id, container_pattern, destination_pattern, port_pattern, rule_type, action, - created_at, expires_at, last_used_at, created_by, notes - FROM rules WHERE id = ?`, id, + `SELECT `+ruleColumns+` FROM rules WHERE id = ?`, id, ) return scanRule(row) } @@ -100,6 +119,7 @@ func GetRule(db *DB, id int64) (*Rule, error) { func scanRule(row interface{ Scan(...any) error }) (*Rule, error) { var r Rule err := row.Scan(&r.ID, &r.ContainerPattern, &r.DestinationPattern, &r.PortPattern, + &r.MethodPattern, &r.PathPattern, &r.ContentAction, &r.RuleType, &r.Action, &r.CreatedAt, &r.ExpiresAt, &r.LastUsedAt, &r.CreatedBy, &r.Notes) if err != nil { if err == sql.ErrNoRows { @@ -152,7 +172,7 @@ func GetRules(db *DB, f RuleFilter) ([]Rule, int, error) { } rows, err := db.ReadDB().Query( - "SELECT id, container_pattern, destination_pattern, port_pattern, rule_type, action, created_at, expires_at, last_used_at, created_by, notes FROM rules WHERE "+whereClause+" ORDER BY created_at DESC LIMIT ? OFFSET ?", + "SELECT "+ruleColumns+" FROM rules WHERE "+whereClause+" ORDER BY created_at DESC LIMIT ? OFFSET ?", append(args, f.Limit, f.Offset)..., ) if err != nil { @@ -164,6 +184,7 @@ func GetRules(db *DB, f RuleFilter) ([]Rule, int, error) { for rows.Next() { var r Rule if err := rows.Scan(&r.ID, &r.ContainerPattern, &r.DestinationPattern, &r.PortPattern, + &r.MethodPattern, &r.PathPattern, &r.ContentAction, &r.RuleType, &r.Action, &r.CreatedAt, &r.ExpiresAt, &r.LastUsedAt, &r.CreatedBy, &r.Notes); err != nil { return nil, 0, err } @@ -191,6 +212,18 @@ func UpdateRule(db *DB, id int64, input RuleUpdateInput) (*Rule, error) { sets = append(sets, "port_pattern = ?") args = append(args, *input.PortPattern) } + if input.MethodPattern != nil { + sets = append(sets, "method_pattern = ?") + args = append(args, *input.MethodPattern) + } + if input.PathPattern != nil { + sets = append(sets, "path_pattern = ?") + args = append(args, *input.PathPattern) + } + if input.ContentAction != nil { + sets = append(sets, "content_action = ?") + args = append(args, *input.ContentAction) + } if input.Action != nil { sets = append(sets, "action = ?") args = append(args, *input.Action) @@ -306,12 +339,98 @@ func IngestRules(db *DB, rules []IngestRuleInput) (*IngestResult, error) { // FindMatchingRule finds the most specific matching rule for the given request. // Returns nil if no rule matches (default-deny). +// Only considers destination-level matching (container, host, port). func FindMatchingRule(db *DB, containerName, destHost string, destPort int, resolvedHostname string) *Rule { + return findMatchingRuleInternal(db, containerName, destHost, destPort, resolvedHostname, "", "") +} + +// FindMatchingRequestRule finds the most specific matching rule including method/path dimensions. +// Used for request-level evaluation in MITM mode. +func FindMatchingRequestRule(db *DB, containerName, destHost string, destPort int, resolvedHostname, method, path string) *Rule { + return findMatchingRuleInternal(db, containerName, destHost, destPort, resolvedHostname, method, path) +} + +// FindRequestSpecificRule finds the most specific rule that has non-wildcard method or path patterns. +// Used as Pass 1 in two-pass evaluation: request-specific rules take precedence over destination-level rules. +// Returns nil if no request-specific rule matches. +func FindRequestSpecificRule(db *DB, containerName, destHost string, destPort int, resolvedHostname, method, path string) *Rule { + // Get all non-expired rules that have specific method or path patterns + rows, err := db.ReadDB().Query( + `SELECT `+ruleColumns+` FROM rules + WHERE (expires_at IS NULL OR expires_at > datetime('now')) + AND (method_pattern != '*' OR path_pattern != '*')`, + ) + if err != nil { + return nil + } + defer rows.Close() + + type scored struct { + rule Rule + specificity int + } + + var matches []scored + for rows.Next() { + var r Rule + if err := rows.Scan(&r.ID, &r.ContainerPattern, &r.DestinationPattern, &r.PortPattern, + &r.MethodPattern, &r.PathPattern, &r.ContentAction, + &r.RuleType, &r.Action, &r.CreatedAt, &r.ExpiresAt, &r.LastUsedAt, &r.CreatedBy, &r.Notes); err != nil { + continue + } + + // Check destination-level match + matched := MatchesRule(containerName, destHost, destPort, r.ContainerPattern, r.DestinationPattern, r.PortPattern) + if !matched && resolvedHostname != "" { + matched = MatchesRule(containerName, resolvedHostname, destPort, r.ContainerPattern, r.DestinationPattern, r.PortPattern) + } + if !matched { + continue + } + + // Check method/path match + if method != "" && r.MethodPattern != "*" && !MatchesMethod(method, r.MethodPattern) { + continue + } + if path != "" && r.PathPattern != "*" && !MatchesPath(path, r.PathPattern) { + continue + } + + matches = append(matches, scored{ + rule: r, + specificity: CalculateSpecificity(r.ContainerPattern, r.DestinationPattern, r.PortPattern) + CalculateHTTPSpecificity(r.MethodPattern, r.PathPattern), + }) + } + + if len(matches) == 0 { + return nil + } + + sort.Slice(matches, func(i, j int) bool { + if matches[i].specificity != matches[j].specificity { + return matches[i].specificity > matches[j].specificity + } + if matches[i].rule.Action != matches[j].rule.Action { + return matches[i].rule.Action == "deny" + } + return false + }) + + winner := &matches[0].rule + + go func() { + db.Lock() + defer db.Unlock() + db.WriteDB().Exec("UPDATE rules SET last_used_at = datetime('now') WHERE id = ?", winner.ID) + }() + + return winner +} + +func findMatchingRuleInternal(db *DB, containerName, destHost string, destPort int, resolvedHostname, method, path string) *Rule { // Get all non-expired rules rows, err := db.ReadDB().Query( - `SELECT id, container_pattern, destination_pattern, port_pattern, rule_type, action, - created_at, expires_at, last_used_at, created_by, notes - FROM rules + `SELECT `+ruleColumns+` FROM rules WHERE expires_at IS NULL OR expires_at > datetime('now')`, ) if err != nil { @@ -328,6 +447,7 @@ func FindMatchingRule(db *DB, containerName, destHost string, destPort int, reso for rows.Next() { var r Rule if err := rows.Scan(&r.ID, &r.ContainerPattern, &r.DestinationPattern, &r.PortPattern, + &r.MethodPattern, &r.PathPattern, &r.ContentAction, &r.RuleType, &r.Action, &r.CreatedAt, &r.ExpiresAt, &r.LastUsedAt, &r.CreatedBy, &r.Notes); err != nil { continue } @@ -340,12 +460,26 @@ func FindMatchingRule(db *DB, containerName, destHost string, destPort int, reso matched = MatchesRule(containerName, resolvedHostname, destPort, r.ContainerPattern, r.DestinationPattern, r.PortPattern) } - if matched { - matches = append(matches, scored{ - rule: r, - specificity: CalculateSpecificity(r.ContainerPattern, r.DestinationPattern, r.PortPattern), - }) + if !matched { + continue + } + + // If method/path are provided, check those dimensions too + if method != "" && r.MethodPattern != "*" { + if !MatchesMethod(method, r.MethodPattern) { + continue + } + } + if path != "" && r.PathPattern != "*" { + if !MatchesPath(path, r.PathPattern) { + continue + } } + + matches = append(matches, scored{ + rule: r, + specificity: CalculateSpecificity(r.ContainerPattern, r.DestinationPattern, r.PortPattern) + CalculateHTTPSpecificity(r.MethodPattern, r.PathPattern), + }) } if len(matches) == 0 { @@ -778,13 +912,13 @@ func parseDuration(duration string) (ruleType string, expiresIn *int64) { func findExistingRule(db *DB, containerPattern, destPattern, portPattern, action string) *Rule { var r Rule err := db.ReadDB().QueryRow( - `SELECT id, container_pattern, destination_pattern, port_pattern, rule_type, action, - created_at, expires_at, last_used_at, created_by, notes + `SELECT `+ruleColumns+` FROM rules WHERE container_pattern = ? AND destination_pattern = ? AND port_pattern = ? AND action = ? AND (expires_at IS NULL OR expires_at > datetime('now'))`, containerPattern, destPattern, portPattern, action, ).Scan(&r.ID, &r.ContainerPattern, &r.DestinationPattern, &r.PortPattern, + &r.MethodPattern, &r.PathPattern, &r.ContentAction, &r.RuleType, &r.Action, &r.CreatedAt, &r.ExpiresAt, &r.LastUsedAt, &r.CreatedBy, &r.Notes) if err != nil { return nil @@ -1056,6 +1190,197 @@ func GetDashboardStats(db *DB, fromDate, toDate time.Time, groupBy string, recen return stats, nil } +// --- Pending HTTP Requests --- + +type PendingHttpRequestCreateInput struct { + ContainerName string + DestinationHost string + DestinationPort int + Method string + URL string + RequestHeaders map[string][]string + RequestBody []byte +} + +func CreatePendingHttpRequest(db *DB, input PendingHttpRequestCreateInput) (*PendingHttpRequest, bool, error) { + db.Lock() + defer db.Unlock() + + var headersJSON sql.NullString + if input.RequestHeaders != nil { + b, _ := json.Marshal(input.RequestHeaders) + headersJSON = sql.NullString{String: string(b), Valid: true} + } + + body := input.RequestBody + bodySize := int64(len(body)) + if len(body) > MaxBodyCapture { + body = body[:MaxBodyCapture] + } + + // Check for existing pending with same key + var existingID int64 + err := db.WriteDB().QueryRow( + `SELECT id FROM pending_http_requests + WHERE container_name = ? AND destination_host = ? AND destination_port = ? AND method = ? AND url = ? AND status = 'pending'`, + input.ContainerName, input.DestinationHost, input.DestinationPort, input.Method, input.URL, + ).Scan(&existingID) + if err == nil { + p, err := getPendingHttpRequestByID(db.WriteDB(), existingID) + return p, false, err + } + + // Delete any resolved entries with the same key to avoid UNIQUE constraint violation + db.WriteDB().Exec( + `DELETE FROM pending_http_requests + WHERE container_name = ? AND destination_host = ? AND destination_port = ? AND method = ? AND url = ? AND status != 'pending'`, + input.ContainerName, input.DestinationHost, input.DestinationPort, input.Method, input.URL, + ) + + result, err := db.WriteDB().Exec( + `INSERT INTO pending_http_requests (container_name, destination_host, destination_port, method, url, request_headers, request_body, request_body_size) + VALUES (?, ?, ?, ?, ?, ?, ?, ?)`, + input.ContainerName, input.DestinationHost, input.DestinationPort, + input.Method, input.URL, headersJSON, body, bodySize, + ) + if err != nil { + // UNIQUE constraint — return existing + if strings.Contains(err.Error(), "UNIQUE") { + var id int64 + db.WriteDB().QueryRow( + `SELECT id FROM pending_http_requests WHERE container_name = ? AND destination_host = ? AND destination_port = ? AND method = ? AND url = ?`, + input.ContainerName, input.DestinationHost, input.DestinationPort, input.Method, input.URL, + ).Scan(&id) + p, err := getPendingHttpRequestByID(db.WriteDB(), id) + return p, false, err + } + return nil, false, fmt.Errorf("insert pending_http_request: %w", err) + } + + id, _ := result.LastInsertId() + p, err := getPendingHttpRequestByID(db.WriteDB(), id) + return p, true, err +} + +func getPendingHttpRequestByID(conn *sql.DB, id int64) (*PendingHttpRequest, error) { + var p PendingHttpRequest + err := conn.QueryRow( + `SELECT id, container_name, destination_host, destination_port, + method, url, request_headers, request_body, request_body_size, + created_at, status + FROM pending_http_requests WHERE id = ?`, id, + ).Scan(&p.ID, &p.ContainerName, &p.DestinationHost, &p.DestinationPort, + &p.Method, &p.URL, &p.RequestHeaders, &p.RequestBody, &p.RequestBodySize, + &p.CreatedAt, &p.Status) + if err != nil { + if err == sql.ErrNoRows { + return nil, nil + } + return nil, err + } + return &p, nil +} + +func GetPendingHttpRequest(db *DB, id int64) (*PendingHttpRequest, error) { + return getPendingHttpRequestByID(db.ReadDB(), id) +} + +type PendingHttpFilter struct { + Container string + Destination string + Method string + Status string + Limit int + Offset int +} + +func GetPendingHttpRequests(db *DB, f PendingHttpFilter) ([]PendingHttpRequest, int, error) { + if f.Limit <= 0 { + f.Limit = 100 + } + if f.Status == "" { + f.Status = "pending" + } + + where := []string{"status = ?"} + args := []any{f.Status} + + if f.Container != "" { + where = append(where, "container_name LIKE ?") + args = append(args, "%"+f.Container+"%") + } + if f.Destination != "" { + where = append(where, "destination_host LIKE ?") + args = append(args, "%"+f.Destination+"%") + } + if f.Method != "" { + where = append(where, "method = ?") + args = append(args, f.Method) + } + + whereClause := strings.Join(where, " AND ") + + var total int + err := db.ReadDB().QueryRow("SELECT COUNT(*) FROM pending_http_requests WHERE "+whereClause, args...).Scan(&total) + if err != nil { + return nil, 0, err + } + + rows, err := db.ReadDB().Query( + `SELECT id, container_name, destination_host, destination_port, + method, url, request_headers, NULL, request_body_size, + created_at, status + FROM pending_http_requests WHERE `+whereClause+` ORDER BY created_at DESC LIMIT ? OFFSET ?`, + append(args, f.Limit, f.Offset)..., + ) + if err != nil { + return nil, 0, err + } + defer rows.Close() + + var items []PendingHttpRequest + for rows.Next() { + var p PendingHttpRequest + if err := rows.Scan(&p.ID, &p.ContainerName, &p.DestinationHost, &p.DestinationPort, + &p.Method, &p.URL, &p.RequestHeaders, &p.RequestBody, &p.RequestBodySize, + &p.CreatedAt, &p.Status); err != nil { + return nil, 0, err + } + items = append(items, p) + } + return items, total, nil +} + +func GetPendingHttpCount(db *DB) (int, error) { + var count int + err := db.ReadDB().QueryRow( + `SELECT COUNT(*) FROM pending_http_requests WHERE status = 'pending'`, + ).Scan(&count) + return count, err +} + +// ResolvePendingHttpRequest sets the status of a pending HTTP request to "allowed" or "denied". +func ResolvePendingHttpRequest(db *DB, id int64, status string) (*PendingHttpRequest, error) { + db.Lock() + defer db.Unlock() + + _, err := db.WriteDB().Exec( + `UPDATE pending_http_requests SET status = ? WHERE id = ? AND status = 'pending'`, + status, id, + ) + if err != nil { + return nil, fmt.Errorf("resolve pending http request: %w", err) + } + return getPendingHttpRequestByID(db.WriteDB(), id) +} + +// CleanupResolvedHttpPending deletes non-pending (resolved) records older than 5 minutes. +func CleanupResolvedHttpPending(db *DB) { + db.Lock() + defer db.Unlock() + db.WriteDB().Exec(`DELETE FROM pending_http_requests WHERE status != 'pending' AND created_at < datetime('now', '-5 minutes')`) +} + // --- HTTP Transactions --- // MaxBodyCapture is the default max bytes to store per request/response body. diff --git a/internal/greyproxy/crud_test.go b/internal/greyproxy/crud_test.go index b09d21b..788963e 100644 --- a/internal/greyproxy/crud_test.go +++ b/internal/greyproxy/crud_test.go @@ -618,7 +618,7 @@ func TestMigrations(t *testing.T) { db := setupTestDB(t) // Verify tables exist - tables := []string{"rules", "pending_requests", "request_logs", "http_transactions", "schema_migrations"} + tables := []string{"rules", "pending_requests", "request_logs", "http_transactions", "pending_http_requests", "schema_migrations"} for _, table := range tables { var name string err := db.ReadDB().QueryRow( @@ -637,8 +637,8 @@ func TestMigrations(t *testing.T) { // Verify migration versions were recorded var count int db.ReadDB().QueryRow("SELECT COUNT(*) FROM schema_migrations").Scan(&count) - if count != 4 { - t.Errorf("expected 4 migration versions, got %d", count) + if count != 6 { + t.Errorf("expected 6 migration versions, got %d", count) } } diff --git a/internal/greyproxy/events.go b/internal/greyproxy/events.go index 876959e..76ca7a3 100644 --- a/internal/greyproxy/events.go +++ b/internal/greyproxy/events.go @@ -13,6 +13,11 @@ const ( EventPendingDismissed = "pending_request.dismissed" EventWaitersChanged = "waiters.changed" EventTransactionNew = "transaction.new" + + // Request-level pending events (MITM HTTP requests held for approval) + EventHttpPendingCreated = "http_pending.created" + EventHttpPendingAllowed = "http_pending.allowed" + EventHttpPendingDenied = "http_pending.denied" ) // Event represents a broadcast event. diff --git a/internal/greyproxy/migrations.go b/internal/greyproxy/migrations.go index 8a66b38..03476ba 100644 --- a/internal/greyproxy/migrations.go +++ b/internal/greyproxy/migrations.go @@ -87,6 +87,29 @@ var migrations = []string{ ); CREATE INDEX IF NOT EXISTS idx_http_transactions_ts ON http_transactions(timestamp); CREATE INDEX IF NOT EXISTS idx_http_transactions_dest ON http_transactions(destination_host, destination_port);`, + + // Migration 5: Add method_pattern, path_pattern, content_action to rules for request-level control + `ALTER TABLE rules ADD COLUMN method_pattern TEXT NOT NULL DEFAULT '*'; + ALTER TABLE rules ADD COLUMN path_pattern TEXT NOT NULL DEFAULT '*'; + ALTER TABLE rules ADD COLUMN content_action TEXT NOT NULL DEFAULT 'allow';`, + + // Migration 6: Create pending_http_requests table for request-level holds + `CREATE TABLE IF NOT EXISTS pending_http_requests ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + container_name TEXT NOT NULL, + destination_host TEXT NOT NULL, + destination_port INTEGER NOT NULL, + method TEXT NOT NULL, + url TEXT NOT NULL, + request_headers TEXT, + request_body BLOB, + request_body_size INTEGER, + created_at DATETIME NOT NULL DEFAULT (datetime('now')), + status TEXT NOT NULL DEFAULT 'pending' CHECK (status IN ('pending', 'allowed', 'denied')), + UNIQUE(container_name, destination_host, destination_port, method, url) + ); + CREATE INDEX IF NOT EXISTS idx_pending_http_container ON pending_http_requests(container_name); + CREATE INDEX IF NOT EXISTS idx_pending_http_status ON pending_http_requests(status);`, } func runMigrations(db *sql.DB) error { diff --git a/internal/greyproxy/models.go b/internal/greyproxy/models.go index 07e86bf..cb24457 100644 --- a/internal/greyproxy/models.go +++ b/internal/greyproxy/models.go @@ -12,6 +12,9 @@ type Rule struct { ContainerPattern string `json:"container_pattern"` DestinationPattern string `json:"destination_pattern"` PortPattern string `json:"port_pattern"` + MethodPattern string `json:"method_pattern"` + PathPattern string `json:"path_pattern"` + ContentAction string `json:"content_action"` RuleType string `json:"rule_type"` Action string `json:"action"` CreatedAt time.Time `json:"created_at"` @@ -26,6 +29,9 @@ type RuleJSON struct { ContainerPattern string `json:"container_pattern"` DestinationPattern string `json:"destination_pattern"` PortPattern string `json:"port_pattern"` + MethodPattern string `json:"method_pattern"` + PathPattern string `json:"path_pattern"` + ContentAction string `json:"content_action"` RuleType string `json:"rule_type"` Action string `json:"action"` CreatedAt string `json:"created_at"` @@ -42,6 +48,9 @@ func (r *Rule) ToJSON() RuleJSON { ContainerPattern: r.ContainerPattern, DestinationPattern: r.DestinationPattern, PortPattern: r.PortPattern, + MethodPattern: r.MethodPattern, + PathPattern: r.PathPattern, + ContentAction: r.ContentAction, RuleType: r.RuleType, Action: r.Action, CreatedAt: r.CreatedAt.UTC().Format(time.RFC3339), @@ -284,6 +293,62 @@ func (t *HttpTransaction) ToJSON(includeBody bool) HttpTransactionJSON { return j } +// PendingHttpRequest represents an HTTP request held for user approval. +type PendingHttpRequest struct { + ID int64 `json:"id"` + ContainerName string `json:"container_name"` + DestinationHost string `json:"destination_host"` + DestinationPort int `json:"destination_port"` + Method string `json:"method"` + URL string `json:"url"` + RequestHeaders sql.NullString `json:"-"` + RequestBody []byte `json:"-"` + RequestBodySize sql.NullInt64 `json:"-"` + CreatedAt time.Time `json:"created_at"` + Status string `json:"status"` +} + +type PendingHttpRequestJSON struct { + ID int64 `json:"id"` + ContainerName string `json:"container_name"` + DestinationHost string `json:"destination_host"` + DestinationPort int `json:"destination_port"` + Method string `json:"method"` + URL string `json:"url"` + RequestHeaders any `json:"request_headers,omitempty"` + RequestBody *string `json:"request_body,omitempty"` + RequestBodySize *int64 `json:"request_body_size,omitempty"` + CreatedAt string `json:"created_at"` + Status string `json:"status"` +} + +func (p *PendingHttpRequest) ToJSON(includeBody bool) PendingHttpRequestJSON { + j := PendingHttpRequestJSON{ + ID: p.ID, + ContainerName: p.ContainerName, + DestinationHost: p.DestinationHost, + DestinationPort: p.DestinationPort, + Method: p.Method, + URL: p.URL, + CreatedAt: p.CreatedAt.UTC().Format(time.RFC3339), + Status: p.Status, + } + if p.RequestHeaders.Valid { + var h map[string]any + if json.Unmarshal([]byte(p.RequestHeaders.String), &h) == nil { + j.RequestHeaders = h + } + } + if p.RequestBodySize.Valid { + j.RequestBodySize = &p.RequestBodySize.Int64 + } + if includeBody && len(p.RequestBody) > 0 { + s := string(p.RequestBody) + j.RequestBody = &s + } + return j +} + // HttpTransactionCreateInput holds the data needed to create a transaction record. type HttpTransactionCreateInput struct { ContainerName string diff --git a/internal/greyproxy/patterns.go b/internal/greyproxy/patterns.go index 42db15d..489f6a1 100644 --- a/internal/greyproxy/patterns.go +++ b/internal/greyproxy/patterns.go @@ -138,3 +138,59 @@ func CalculateSpecificity(containerPattern, destinationPattern, portPattern stri return score } + +// CalculateHTTPSpecificity returns additional specificity points for method/path matching. +func CalculateHTTPSpecificity(methodPattern, pathPattern string) int { + score := 0 + if methodPattern != "*" { + score += 4 // Method exact match + } + if pathPattern != "*" { + if !strings.ContainsAny(pathPattern, "*?[") { + score += 3 // Path exact match + } else { + score += 2 // Path with glob + } + } + return score +} + +// MatchesMethod checks if an HTTP method matches the given pattern. +// Supports: exact match (case-insensitive), wildcard "*", comma-separated list. +func MatchesMethod(method, pattern string) bool { + if pattern == "*" { + return true + } + method = strings.ToUpper(method) + if strings.Contains(pattern, ",") { + for _, p := range strings.Split(pattern, ",") { + if strings.ToUpper(strings.TrimSpace(p)) == method { + return true + } + } + return false + } + return strings.ToUpper(pattern) == method +} + +// MatchesPath checks if a URL path matches the given pattern. +// Supports: exact match, glob patterns via filepath.Match, prefix with trailing *. +func MatchesPath(path, pattern string) bool { + if pattern == "*" { + return true + } + // Prefix match: /api/* matches /api/anything + if strings.HasSuffix(pattern, "/*") { + prefix := pattern[:len(pattern)-1] // "/api/" + return strings.HasPrefix(path, prefix) || path == pattern[:len(pattern)-2] + } + // Exact match + if path == pattern { + return true + } + // Glob match + if matched, err := filepath.Match(pattern, path); err == nil && matched { + return true + } + return false +} diff --git a/internal/greyproxy/ui/pages.go b/internal/greyproxy/ui/pages.go index 4523cf2..f70daef 100644 --- a/internal/greyproxy/ui/pages.go +++ b/internal/greyproxy/ui/pages.go @@ -172,11 +172,12 @@ var ( trafficTmpl = parseTemplate("base.html", "base.html", "traffic.html") - dashboardStatsTmpl = parseTemplate("dashboard_stats.html", "partials/dashboard_stats.html") - pendingListTmpl = parseTemplate("pending_list.html", "partials/pending_list.html") - rulesListTmpl = parseTemplate("rules_list.html", "partials/rules_list.html") - logsTableTmpl = parseTemplate("logs_table.html", "partials/logs_table.html") - trafficTableTmpl = parseTemplate("traffic_table.html", "partials/traffic_table.html") + dashboardStatsTmpl = parseTemplate("dashboard_stats.html", "partials/dashboard_stats.html") + pendingListTmpl = parseTemplate("pending_list.html", "partials/pending_list.html") + pendingHttpListTmpl = parseTemplate("pending_http_list.html", "partials/pending_http_list.html") + rulesListTmpl = parseTemplate("rules_list.html", "partials/rules_list.html") + logsTableTmpl = parseTemplate("logs_table.html", "partials/logs_table.html") + trafficTableTmpl = parseTemplate("traffic_table.html", "partials/traffic_table.html") ) // cacheBuster is set once at startup for static asset cache busting. @@ -444,6 +445,54 @@ func RegisterHTMXRoutes(r *gin.RouterGroup, db *greyproxy.DB, bus *greyproxy.Eve renderPendingList(c, db, prefix, waiters) }) + // Request-level pending HTMX + htmx.GET("/pending-http-list", func(c *gin.Context) { + container := c.Query("container") + destination := c.Query("destination") + + items, total, err := greyproxy.GetPendingHttpRequests(db, greyproxy.PendingHttpFilter{ + Container: container, + Destination: destination, + Status: "pending", + Limit: 100, + }) + if err != nil { + c.String(http.StatusInternalServerError, "Error: %v", err) + return + } + + c.Writer.Header().Set("Content-Type", "text/html; charset=utf-8") + pendingHttpListTmpl.Execute(c.Writer, gin.H{ + "Prefix": prefix, + "Items": items, + "Total": total, + }) + }) + + htmx.POST("/pending-http/:id/allow", func(c *gin.Context) { + id, _ := strconv.ParseInt(c.Param("id"), 10, 64) + p, err := greyproxy.ResolvePendingHttpRequest(db, id, "allowed") + if err == nil && p != nil { + bus.Publish(greyproxy.Event{ + Type: greyproxy.EventHttpPendingAllowed, + Data: map[string]any{"pending_id": id}, + }) + } + renderPendingHttpList(c, db, prefix) + }) + + htmx.POST("/pending-http/:id/deny", func(c *gin.Context) { + id, _ := strconv.ParseInt(c.Param("id"), 10, 64) + p, err := greyproxy.ResolvePendingHttpRequest(db, id, "denied") + if err == nil && p != nil { + bus.Publish(greyproxy.Event{ + Type: greyproxy.EventHttpPendingDenied, + Data: map[string]any{"pending_id": id}, + }) + } + renderPendingHttpList(c, db, prefix) + }) + // Rules HTMX htmx.GET("/rules-list", func(c *gin.Context) { renderRulesList(c, db, prefix) @@ -465,10 +514,26 @@ func RegisterHTMXRoutes(r *gin.RouterGroup, db *greyproxy.DB, bus *greyproxy.Eve notesPtr = ¬es } + methodPattern := c.PostForm("method_pattern") + if methodPattern == "" { + methodPattern = "*" + } + pathPattern := c.PostForm("path_pattern") + if pathPattern == "" { + pathPattern = "*" + } + contentAction := c.PostForm("content_action") + if contentAction == "" { + contentAction = "allow" + } + _, err := greyproxy.CreateRule(db, greyproxy.RuleCreateInput{ ContainerPattern: c.PostForm("container_pattern"), DestinationPattern: c.PostForm("destination_pattern"), PortPattern: portPattern, + MethodPattern: methodPattern, + PathPattern: pathPattern, + ContentAction: contentAction, RuleType: ruleType, Action: action, ExpiresInSeconds: expiresIn, @@ -487,6 +552,9 @@ func RegisterHTMXRoutes(r *gin.RouterGroup, db *greyproxy.DB, bus *greyproxy.Eve cp := c.PostForm("container_pattern") dp := c.PostForm("destination_pattern") pp := c.PostForm("port_pattern") + mp := c.PostForm("method_pattern") + pathP := c.PostForm("path_pattern") + ca := c.PostForm("content_action") action := c.PostForm("action") notes := c.PostForm("notes") @@ -500,6 +568,15 @@ func RegisterHTMXRoutes(r *gin.RouterGroup, db *greyproxy.DB, bus *greyproxy.Eve if pp != "" { input.PortPattern = &pp } + if mp != "" { + input.MethodPattern = &mp + } + if pathP != "" { + input.PathPattern = &pathP + } + if ca != "" { + input.ContentAction = &ca + } if action != "" { input.Action = &action } @@ -668,6 +745,25 @@ func renderPendingList(c *gin.Context, db *greyproxy.DB, prefix string, waiters }) } +func renderPendingHttpList(c *gin.Context, db *greyproxy.DB, prefix string) { + container := c.Query("container") + destination := c.Query("destination") + + items, total, _ := greyproxy.GetPendingHttpRequests(db, greyproxy.PendingHttpFilter{ + Container: container, + Destination: destination, + Status: "pending", + Limit: 100, + }) + + c.Writer.Header().Set("Content-Type", "text/html; charset=utf-8") + pendingHttpListTmpl.Execute(c.Writer, gin.H{ + "Prefix": prefix, + "Items": items, + "Total": total, + }) +} + func renderRulesList(c *gin.Context, db *greyproxy.DB, prefix string) { container := c.Query("container") destination := c.Query("destination") diff --git a/internal/greyproxy/ui/templates/base.html b/internal/greyproxy/ui/templates/base.html index 56e4736..1dc96a8 100644 --- a/internal/greyproxy/ui/templates/base.html +++ b/internal/greyproxy/ui/templates/base.html @@ -278,10 +278,13 @@ } function fetchPendingCount() { - fetch('{{.Prefix}}/api/pending/count') - .then(function(r) { return r.json(); }) - .then(function(data) { updateBadge(data.count); }) - .catch(function(err) { console.warn('Failed to fetch pending count:', err); }); + Promise.all([ + fetch('{{.Prefix}}/api/pending/count').then(function(r) { return r.json(); }), + fetch('{{.Prefix}}/api/pending/requests/count').then(function(r) { return r.json(); }) + ]).then(function(results) { + var total = (results[0].count || 0) + (results[1].count || 0); + updateBadge(total); + }).catch(function(err) { console.warn('Failed to fetch pending count:', err); }); } function connect() { @@ -300,6 +303,10 @@ fetchPendingCount(); window.dispatchEvent(new CustomEvent('proxy:pending-event', { detail: msg })); } + if (msg.type && msg.type.indexOf('http_pending.') === 0) { + fetchPendingCount(); + window.dispatchEvent(new CustomEvent('proxy:http-pending-event', { detail: msg })); + } if (msg.type === 'transaction.new') { window.dispatchEvent(new CustomEvent('proxy:transaction-event', { detail: msg })); } diff --git a/internal/greyproxy/ui/templates/partials/pending_http_list.html b/internal/greyproxy/ui/templates/partials/pending_http_list.html new file mode 100644 index 0000000..3aa3a7d --- /dev/null +++ b/internal/greyproxy/ui/templates/partials/pending_http_list.html @@ -0,0 +1,84 @@ +{{$items := .Items}} +{{$total := .Total}} +{{if $items}} +
+
+ + {{$total}} + HTTP Request{{if ne $total 1}}s{{end}} Held for Approval + +
+ +
+ {{range $items}} +
+
+
+
+ + {{.Method}} + + {{.URL}} +
+
+ {{.ContainerName}} + + {{.DestinationHost}}:{{.DestinationPort}} + · + {{formatTime .CreatedAt}} +
+ + + {{if .RequestHeaders.Valid}} +
+ + +
+ {{end}} +
+ +
+ + +
+
+
+ {{end}} +
+
+ + +{{else}} +
+

No HTTP requests currently held for approval.

+
+{{end}} diff --git a/internal/greyproxy/ui/templates/partials/rules_list.html b/internal/greyproxy/ui/templates/partials/rules_list.html index 21d203a..d6a1413 100644 --- a/internal/greyproxy/ui/templates/partials/rules_list.html +++ b/internal/greyproxy/ui/templates/partials/rules_list.html @@ -42,11 +42,19 @@ {{.DestinationPattern}}:{{.PortPattern}} + {{if or (ne .MethodPattern "*") (ne .PathPattern "*")}} +
{{.MethodPattern}} {{.PathPattern}} + {{end}} + {{if ne .ContentAction "allow"}} + {{.ContentAction}} + {{end}} - @@ -86,6 +94,14 @@ {{.ContainerPattern}} {{.DestinationPattern}}:{{.PortPattern}} + {{if or (ne .MethodPattern "*") (ne .PathPattern "*")}} +
{{.MethodPattern}} {{.PathPattern}} + {{end}} + {{if ne .ContentAction "allow"}} + {{.ContentAction}} + {{end}}
@@ -93,7 +109,7 @@
-
@@ -120,11 +136,14 @@ var form = document.getElementById('rules-filter-form'); if (form) form.dispatchEvent(new Event('change', { bubbles: true })); } - function editRule(id, container, destination, port, type, action, notes) { + function editRule(id, container, destination, port, method, path, contentAction, type, action, notes) { document.getElementById('modal-title').textContent = 'Edit Rule'; document.getElementById('container_pattern').value = container; document.getElementById('destination_pattern').value = destination; document.getElementById('port_pattern').value = port; + document.getElementById('method_pattern').value = method; + document.getElementById('path_pattern').value = path; + document.getElementById('content_action').value = contentAction; document.getElementById('rule_type').value = type; document.getElementById('action').value = action; document.getElementById('notes').value = notes; diff --git a/internal/greyproxy/ui/templates/pending.html b/internal/greyproxy/ui/templates/pending.html index f542c6b..54ba364 100644 --- a/internal/greyproxy/ui/templates/pending.html +++ b/internal/greyproxy/ui/templates/pending.html @@ -57,6 +57,13 @@

Pending Requests

+ +
+
+ +
@@ -126,5 +133,12 @@

Pending Requests

highlightPending(); } }); + var httpPendingRefreshTimer = null; + window.addEventListener('proxy:http-pending-event', function(e) { + clearTimeout(httpPendingRefreshTimer); + httpPendingRefreshTimer = setTimeout(function() { + document.body.dispatchEvent(new CustomEvent('pending-http-refresh')); + }, 500); + }); {{end}} diff --git a/internal/greyproxy/ui/templates/rules.html b/internal/greyproxy/ui/templates/rules.html index c62a6fc..797f3bf 100644 --- a/internal/greyproxy/ui/templates/rules.html +++ b/internal/greyproxy/ui/templates/rules.html @@ -82,13 +82,39 @@
-
- - +
+
+ + +

* for any, or comma-separated

+
+
+ + +

* for any, supports glob

+
+
+
+
+ + +
+
+ + +

Controls MITM request inspection

+
@@ -140,6 +166,9 @@