diff --git a/CHANGELOG.md b/CHANGELOG.md index 42ab6c89324..c14395bfab4 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -39,6 +39,9 @@ BUG FIXES: * serviceregistration: Fix a regression for Consul service registration that ignored using the listener address as the redirect address unless api_addr was provided. It now properly uses the same redirect address as the one used by Vault's Core object. [[GH-8976](https://github.com/hashicorp/vault/pull/8976)] +* storage/raft: Advertise the configured cluster address to the rest of the nodes in the raft cluster. This fixes + an issue where a node advertising 0.0.0.0 is not using a unique hostname. [[GH-9008](https://github.com/hashicorp/vault/pull/9008)] +* storage/raft: Fix panic when multiple nodes attempt to join the cluster at once. [[GH-9008](https://github.com/hashicorp/vault/pull/9008)] * sys: The path provided in `sys/internal/ui/mounts/:path` is now namespace-aware. This fixes an issue with `vault kv` subcommands that had namespaces provided in the path returning permission denied all the time. [[GH-8962](https://github.com/hashicorp/vault/pull/8962)] diff --git a/physical/raft/streamlayer.go b/physical/raft/streamlayer.go index fcf0a0be57f..71ce991eb08 100644 --- a/physical/raft/streamlayer.go +++ b/physical/raft/streamlayer.go @@ -15,6 +15,7 @@ import ( "math/big" mathrand "math/rand" "net" + "net/url" "sync" "time" @@ -135,13 +136,15 @@ func GenerateTLSKey(reader io.Reader) (*TLSKey, error) { }, nil } -// Make sure raftLayer satisfies the raft.StreamLayer interface -var _ raft.StreamLayer = (*raftLayer)(nil) +var ( + // Make sure raftLayer satisfies the raft.StreamLayer interface + _ raft.StreamLayer = (*raftLayer)(nil) -// Make sure raftLayer satisfies the cluster.Handler and cluster.Client -// interfaces -var _ cluster.Handler = (*raftLayer)(nil) -var _ cluster.Client = (*raftLayer)(nil) + // Make sure raftLayer satisfies the cluster.Handler and cluster.Client + // interfaces + _ cluster.Handler = (*raftLayer)(nil) + _ cluster.Client = (*raftLayer)(nil) +) // RaftLayer implements the raft.StreamLayer interface, // so that we can use a single RPC layer for Raft and Vault @@ -170,12 +173,21 @@ type raftLayer struct { // from the network config. func NewRaftLayer(logger log.Logger, raftTLSKeyring *TLSKeyring, clusterListener cluster.ClusterHook) (*raftLayer, error) { clusterAddr := clusterListener.Addr() - switch { - case clusterAddr == nil: - // Clustering disabled on the server, don't try to look for params + if clusterAddr == nil { return nil, errors.New("no raft addr found") } + { + // Test the advertised address to make sure it's not an unspecified IP + u := url.URL{ + Host: clusterAddr.String(), + } + ip := net.ParseIP(u.Hostname()) + if ip != nil && ip.IsUnspecified() { + return nil, fmt.Errorf("cannot use unspecified IP with raft storage: %s", clusterAddr.String()) + } + } + layer := &raftLayer{ addr: clusterAddr, connCh: make(chan net.Conn), diff --git a/physical/raft/streamlayer_test.go b/physical/raft/streamlayer_test.go new file mode 100644 index 00000000000..51a26f83226 --- /dev/null +++ b/physical/raft/streamlayer_test.go @@ -0,0 +1,70 @@ +package raft + +import ( + "context" + "crypto/rand" + "crypto/tls" + "net" + "testing" + "time" + + "github.com/hashicorp/vault/vault/cluster" +) + +type mockClusterHook struct { + address net.Addr +} + +func (*mockClusterHook) AddClient(alpn string, client cluster.Client) {} +func (*mockClusterHook) RemoveClient(alpn string) {} +func (*mockClusterHook) AddHandler(alpn string, handler cluster.Handler) {} +func (*mockClusterHook) StopHandler(alpn string) {} +func (*mockClusterHook) TLSConfig(ctx context.Context) (*tls.Config, error) { return nil, nil } +func (m *mockClusterHook) Addr() net.Addr { return m.address } +func (*mockClusterHook) GetDialerFunc(ctx context.Context, alpnProto string) func(string, time.Duration) (net.Conn, error) { + return func(string, time.Duration) (net.Conn, error) { + return nil, nil + } +} + +func TestStreamLayer_UnspecifiedIP(t *testing.T) { + m := &mockClusterHook{ + address: &cluster.NetAddr{ + Host: "0.0.0.0:8200", + }, + } + + raftTLSKey, err := GenerateTLSKey(rand.Reader) + if err != nil { + t.Fatal(err) + } + + raftTLS := &TLSKeyring{ + Keys: []*TLSKey{raftTLSKey}, + ActiveKeyID: raftTLSKey.ID, + } + + layer, err := NewRaftLayer(nil, raftTLS, m) + if err == nil { + t.Fatal("expected error") + } + + if err.Error() != "cannot use unspecified IP with raft storage: 0.0.0.0:8200" { + t.Fatalf("unexpected error: %s", err.Error()) + } + + if layer != nil { + t.Fatal("expected nil layer") + } + + m.address.(*cluster.NetAddr).Host = "10.0.0.1:8200" + + layer, err = NewRaftLayer(nil, raftTLS, m) + if err != nil { + t.Fatal(err) + } + + if layer == nil { + t.Fatal("nil layer") + } +} diff --git a/vault/cluster.go b/vault/cluster.go index b189bc8ca8c..b48355541c8 100644 --- a/vault/cluster.go +++ b/vault/cluster.go @@ -321,6 +321,13 @@ func (c *Core) startClusterListener(ctx context.Context) error { // If we listened on port 0, record the port the OS gave us. c.clusterAddr.Store(fmt.Sprintf("https://%s", c.getClusterListener().Addr())) } + + if len(c.ClusterAddr()) != 0 { + if err := c.getClusterListener().SetAdvertiseAddr(c.ClusterAddr()); err != nil { + return err + } + } + return nil } diff --git a/vault/cluster/cluster.go b/vault/cluster/cluster.go index b7b74857cd1..5c60b268a3e 100644 --- a/vault/cluster/cluster.go +++ b/vault/cluster/cluster.go @@ -7,10 +7,12 @@ import ( "errors" "fmt" "net" + "net/url" "sync" "sync/atomic" "time" + "github.com/hashicorp/errwrap" log "github.com/hashicorp/go-hclog" "github.com/hashicorp/vault/sdk/helper/consts" "golang.org/x/net/http2" @@ -66,6 +68,7 @@ type Listener struct { networkLayer NetworkLayer cipherSuites []uint16 + advertise net.Addr logger log.Logger l sync.RWMutex } @@ -94,7 +97,23 @@ func NewListener(networkLayer NetworkLayer, cipherSuites []uint16, logger log.Lo } } +func (cl *Listener) SetAdvertiseAddr(addr string) error { + u, err := url.ParseRequestURI(addr) + if err != nil { + return errwrap.Wrapf("failed to parse advertise address: {{err}}", err) + } + cl.advertise = &NetAddr{ + Host: u.Host, + } + + return nil +} + func (cl *Listener) Addr() net.Addr { + if cl.advertise != nil { + return cl.advertise + } + addrs := cl.Addrs() if len(addrs) == 0 { return nil @@ -422,3 +441,15 @@ type NetworkLayer interface { type NetworkLayerSet interface { Layers() []NetworkLayer } + +type NetAddr struct { + Host string +} + +func (c *NetAddr) String() string { + return c.Host +} + +func (*NetAddr) Network() string { + return "tcp" +} diff --git a/vault/core.go b/vault/core.go index f6503d983a9..a64d33591a3 100644 --- a/vault/core.go +++ b/vault/core.go @@ -502,7 +502,7 @@ type Core struct { // Stop channel for raft TLS rotations raftTLSRotationStopCh chan struct{} // Stores the pending peers we are waiting to give answers - pendingRaftPeers map[string][]byte + pendingRaftPeers *sync.Map // rawConfig stores the config as-is from the provided server configuration. rawConfig *atomic.Value diff --git a/vault/logical_system_raft.go b/vault/logical_system_raft.go index 5a589ffe6c7..ea4ae023b30 100644 --- a/vault/logical_system_raft.go +++ b/vault/logical_system_raft.go @@ -181,14 +181,17 @@ func (b *SystemBackend) handleRaftBootstrapChallengeWrite() framework.OperationF return logical.ErrorResponse("no server id provided"), logical.ErrInvalidRequest } - answer, ok := b.Core.pendingRaftPeers[serverID] + var answer []byte + answerRaw, ok := b.Core.pendingRaftPeers.Load(serverID) if !ok { var err error answer, err = uuid.GenerateRandomBytes(16) if err != nil { return nil, err } - b.Core.pendingRaftPeers[serverID] = answer + b.Core.pendingRaftPeers.Store(serverID, answer) + } else { + answer = answerRaw.([]byte) } sealAccess := b.Core.seal.GetAccess() @@ -243,14 +246,14 @@ func (b *SystemBackend) handleRaftBootstrapAnswerWrite() framework.OperationFunc return logical.ErrorResponse("could not base64 decode answer"), logical.ErrInvalidRequest } - expectedAnswer, ok := b.Core.pendingRaftPeers[serverID] + expectedAnswerRaw, ok := b.Core.pendingRaftPeers.Load(serverID) if !ok { return logical.ErrorResponse("no expected answer for the server id provided"), logical.ErrInvalidRequest } - delete(b.Core.pendingRaftPeers, serverID) + b.Core.pendingRaftPeers.Delete(serverID) - if subtle.ConstantTimeCompare(answer, expectedAnswer) == 0 { + if subtle.ConstantTimeCompare(answer, expectedAnswerRaw.([]byte)) == 0 { return logical.ErrorResponse("invalid answer given"), logical.ErrInvalidRequest } diff --git a/vault/raft.go b/vault/raft.go index 950d510af41..1f42ecc4b32 100644 --- a/vault/raft.go +++ b/vault/raft.go @@ -159,7 +159,7 @@ func (c *Core) startRaftStorage(ctx context.Context) (retErr error) { } func (c *Core) setupRaftActiveNode(ctx context.Context) error { - c.pendingRaftPeers = make(map[string][]byte) + c.pendingRaftPeers = &sync.Map{} return c.startPeriodicRaftTLSRotate(ctx) }