Skip to content
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 27 additions & 1 deletion cmd/server/mcp.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ var (
registerVetSQLQueryTool bool
vetSQLQueryToolDBPath string
registerPackageRegistryTool bool
sseServerAllowedOrigins []string
sseServerAllowedHosts []string
)

func newMcpServerCommand() *cobra.Command {
Expand All @@ -42,6 +44,19 @@ func newMcpServerCommand() *cobra.Command {
cmd.Flags().StringVar(&mcpServerSseServerAddr, "sse-server-addr", "localhost:9988", "The address to listen for SSE connections")
cmd.Flags().StringVar(&mcpServerServerType, "server-type", "stdio", "The type of server to start (stdio, sse)")

cmd.Flags().StringSliceVar(
&sseServerAllowedOrigins,
"sse-allowed-origins",
nil,
"List of allowed origin prefixes for SSE connections. By default, we allow http://localhost:, http://127.0.0.1: and https://localhost:.",
)
cmd.Flags().StringSliceVar(
&sseServerAllowedHosts,
"sse-allowed-hosts",
nil,
"List of allowed hosts for SSE connections. By default, we allow localhost:9988, 127.0.0.1:9988 and [::1]:9988.",
)

// We allow skipping default tools to allow for custom tools to be registered when the server starts.
// This is useful for agents to avoid unnecessary tool registration.
cmd.Flags().BoolVar(&skipDefaultTools, "skip-default-tools", false, "Skip registering default tools")
Expand Down Expand Up @@ -75,7 +90,18 @@ func startMcpServer() error {
case "stdio":
mcpSrv, err = server.NewMcpServerWithStdioTransport(server.DefaultMcpServerConfig())
case "sse":
mcpSrv, err = server.NewMcpServerWithSseTransport(server.DefaultMcpServerConfig())
config := server.DefaultMcpServerConfig()

// Override with user supplied config
config.SseServerAddr = mcpServerSseServerAddr
if len(sseServerAllowedOrigins) > 0 {
config.SseServerAllowedOrigins = sseServerAllowedOrigins
}
if len(sseServerAllowedHosts) > 0 {
config.SseServerAllowedHosts = sseServerAllowedHosts
}

mcpSrv, err = server.NewMcpServerWithSseTransport(config)
default:
return fmt.Errorf("invalid server type: %s", mcpServerServerType)
}
Expand Down
1 change: 0 additions & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,6 @@ require (
4d63.com/gochecknoglobals v0.2.2 // indirect
ariga.io/atlas v0.34.0 // indirect
buf.build/gen/go/bufbuild/protovalidate/protocolbuffers/go v1.36.8-20240508200655-46a4cf4ba109.1 // indirect
buf.build/gen/go/safedep/api/connectrpc/go v1.18.1-20250822112533-a008e1948f1d.1 // indirect
cel.dev/expr v0.24.0 // indirect
cloud.google.com/go v0.121.2 // indirect
cloud.google.com/go/auth v0.16.1 // indirect
Expand Down
12 changes: 0 additions & 12 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -4,20 +4,10 @@
4d63.com/gochecknoglobals v0.2.2/go.mod h1:lLxwTQjL5eIesRbvnzIP3jZtG140FnTdz+AlMa+ogt0=
ariga.io/atlas v0.34.0 h1:4hdy+2x+xNs6Lx2anuJ/4Q7lCaqddbEj5CtRDVOBu0M=
ariga.io/atlas v0.34.0/go.mod h1:WJesu2UCpGQvgUh3oVP94EiRT61nNy1W/VN5g+vqP1I=
buf.build/gen/go/bufbuild/protovalidate/protocolbuffers/go v1.36.6-20250425153114-8976f5be98c1.1 h1:YhMSc48s25kr7kv31Z8vf7sPUIq5YJva9z1mn/hAt0M=
buf.build/gen/go/bufbuild/protovalidate/protocolbuffers/go v1.36.6-20250425153114-8976f5be98c1.1/go.mod h1:avRlCjnFzl98VPaeCtJ24RrV/wwHFzB8sWXhj26+n/U=
buf.build/gen/go/bufbuild/protovalidate/protocolbuffers/go v1.36.8-20240508200655-46a4cf4ba109.1 h1:7JbSS7TE2PJR4d/qRtynipwLl/CBFoTB69pX7xlhcJM=
buf.build/gen/go/bufbuild/protovalidate/protocolbuffers/go v1.36.8-20240508200655-46a4cf4ba109.1/go.mod h1:8EQ5GzyGJQ5tEIwMSxCl8RKJYsjCpAwkdcENoioXT6g=
buf.build/gen/go/safedep/api/connectrpc/go v1.18.1-20250822112533-a008e1948f1d.1 h1:l2Fuy7PMz0wR8sQVQlhMnm6fxr6ZLdWeAR7NzZ5w2jI=
buf.build/gen/go/safedep/api/connectrpc/go v1.18.1-20250822112533-a008e1948f1d.1/go.mod h1:W2eqH9M5zldL2cDL9xqE/HsWe2FoWw+Bwzaul95RQko=
buf.build/gen/go/safedep/api/grpc/go v1.5.1-20250610075857-7cfdb61a0bfa.2 h1:ENbt9SmU2gh4YhjcFqzceJRlg80hsD28M+Oon9l752A=
buf.build/gen/go/safedep/api/grpc/go v1.5.1-20250610075857-7cfdb61a0bfa.2/go.mod h1:WDOWZglnweQ4njVEJpLYYpLMx9fD+e94KbKdt8oJrxY=
buf.build/gen/go/safedep/api/grpc/go v1.5.1-20250819072717-b69aa2c62a0d.2 h1:A4enKVmVf69uVSG88POR59z5YE6dhATNLpL8+DmZtsg=
buf.build/gen/go/safedep/api/grpc/go v1.5.1-20250819072717-b69aa2c62a0d.2/go.mod h1:Raps9oq+lWS0tdif5yUy8MS6UGc2pr6NMSrv3Jz4avM=
buf.build/gen/go/safedep/api/protocolbuffers/go v1.36.6-20250705071048-7ad8e6be7c05.1 h1:4sM5O5dx0yUucJ1trjZ8Cm9IGX2loEc4cUyh3Xy+5eU=
buf.build/gen/go/safedep/api/protocolbuffers/go v1.36.6-20250705071048-7ad8e6be7c05.1/go.mod h1:uR95GqsnNCRn6cTyRBte6uMJMm0rEBRxTGpakKCNL9I=
buf.build/gen/go/safedep/api/protocolbuffers/go v1.36.8-20250819072717-b69aa2c62a0d.1 h1:fRdyfm5aiolcZmJuWPzbbI4cSYJlssvBZXi/BQUfMWc=
buf.build/gen/go/safedep/api/protocolbuffers/go v1.36.8-20250819072717-b69aa2c62a0d.1/go.mod h1:Q5oZou54kSUyZHl4RSPY93qr3b1ssj3ZvdBAhRAdlJA=
buf.build/gen/go/safedep/api/protocolbuffers/go v1.36.8-20250822112533-a008e1948f1d.1 h1:XqV9omaTxxXaI9VvS87PX4Uw6h927UycRR7SfwENSHU=
buf.build/gen/go/safedep/api/protocolbuffers/go v1.36.8-20250822112533-a008e1948f1d.1/go.mod h1:Q5oZou54kSUyZHl4RSPY93qr3b1ssj3ZvdBAhRAdlJA=
cel.dev/expr v0.24.0 h1:56OvJKSH3hDGL0ml5uSxZmz3/3Pq4tJ+fb1unVLAFcY=
Expand Down Expand Up @@ -2066,8 +2056,6 @@ google.golang.org/protobuf v1.25.0/go.mod h1:9JNX74DMeImyA3h4bdi1ymwjUzf21/xIlba
google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw=
google.golang.org/protobuf v1.26.0/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc=
google.golang.org/protobuf v1.27.1/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc=
google.golang.org/protobuf v1.36.6 h1:z1NpPI8ku2WgiWnf+t9wTPsn6eP1L7ksHUlkfLvd9xY=
google.golang.org/protobuf v1.36.6/go.mod h1:jduwjTPXsFjZGTmRluh+L6NjiWu7pchiJ2/5YcXBHnY=
google.golang.org/protobuf v1.36.8 h1:xHScyCOEuuwZEc6UtSOvPbAT4zRh0xcNRYekJwfqyMc=
google.golang.org/protobuf v1.36.8/go.mod h1:fuxRtAxBytpl4zzqUh6/eyUujkJdNiuEkXntxiD/uRU=
gopkg.in/alecthomas/kingpin.v2 v2.2.6/go.mod h1:FMv+mEhP44yOT+4EoQTLFTRgOQ1FBLkstjWtayDeSgw=
Expand Down
70 changes: 70 additions & 0 deletions mcp/server/guard.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
package server

import (
"net/http"
"slices"
"strings"
)

var (
defaultAllowedOriginsPrefix = []string{
"http://localhost:",
"http://127.0.0.1:",
"https://localhost:",
}
defaultAllowedHosts = []string{"localhost:9988", "127.0.0.1:9988", "[::1]:9988"}
)

// hostGuard is a middleware that allows only the allowed hosts to access the
// MCP server. nil allowedHosts will use the default allowed hosts. Empty
// allowedHosts will block all hosts.
func hostGuard(allowedHosts []string, next http.Handler) http.Handler {
if allowedHosts == nil {
allowedHosts = defaultAllowedHosts
}

return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// contains is faster than a map lookup for small lists
if !slices.Contains(allowedHosts, r.Host) {
// 421 (misdirected request) is ideal; 403 (forbidden) is fine too.
w.WriteHeader(http.StatusMisdirectedRequest)
return
}
next.ServeHTTP(w, r)
})
}

// originGuard is a middleware that allows only the allowed origins to access
// the MCP server. nil allowedOriginsPrefix will use the default allowed origins
// prefix. Empty allowedOriginsPrefix will block all origins.
func originGuard(allowedOriginsPrefix []string, next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
o := r.Header.Get("Origin")
if o == "" {
// Non-browser/same-origin fetches may omit Origin. Don't block
// solely on this.
next.ServeHTTP(w, r)
return
}

if allowedOriginsPrefix == nil {
allowedOriginsPrefix = defaultAllowedOriginsPrefix
}
if !isAllowedOrigin(o, allowedOriginsPrefix) {
http.Error(w, "forbidden origin", http.StatusForbidden)
return
}

next.ServeHTTP(w, r)
})
}

// isAllowedOrigin checks if the origin is in the allowed origins prefix list.
func isAllowedOrigin(origin string, allowedOriginsPrefix []string) bool {
for _, allowedOriginPrefix := range allowedOriginsPrefix {
if strings.HasPrefix(origin, allowedOriginPrefix) {
return true
}
}
return false
}
Loading
Loading