Skip to content

Commit 3f5f3f4

Browse files
fbivillerobsdedude
andauthored
Close socket if Bolt handshake fails
Co-authored-by: Rouven Bauer <[email protected]>
1 parent 62bae2d commit 3f5f3f4

File tree

2 files changed

+198
-19
lines changed

2 files changed

+198
-19
lines changed

neo4j/internal/connector/connector.go

Lines changed: 39 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -34,16 +34,17 @@ import (
3434
)
3535

3636
type Connector struct {
37-
SkipEncryption bool
38-
SkipVerify bool
39-
RootCAs *x509.CertPool
40-
DialTimeout time.Duration
41-
SocketKeepAlive bool
42-
Auth map[string]interface{}
43-
Log log.Logger
44-
UserAgent string
45-
RoutingContext map[string]string
46-
Network string
37+
SkipEncryption bool
38+
SkipVerify bool
39+
RootCAs *x509.CertPool
40+
DialTimeout time.Duration
41+
SocketKeepAlive bool
42+
Auth map[string]interface{}
43+
Log log.Logger
44+
UserAgent string
45+
RoutingContext map[string]string
46+
Network string
47+
SupplyConnection func(address string) (net.Conn, error)
4748
}
4849

4950
type ConnectError struct {
@@ -63,19 +64,24 @@ func (e *TlsError) Error() string {
6364
}
6465

6566
func (c Connector) Connect(address string, boltLogger log.BoltLogger) (db.Connection, error) {
66-
dialer := net.Dialer{Timeout: c.DialTimeout}
67-
if !c.SocketKeepAlive {
68-
dialer.KeepAlive = -1 * time.Second // Turns keep-alive off
67+
if c.SupplyConnection == nil {
68+
c.SupplyConnection = c.createConnection
6969
}
70-
71-
conn, err := dialer.Dial(c.Network, address)
70+
conn, err := c.SupplyConnection(address)
7271
if err != nil {
7372
return nil, &ConnectError{inner: err}
7473
}
7574

76-
// TLS not requested, perform Bolt handshake
75+
// TLS not requested
7776
if c.SkipEncryption {
78-
return bolt.Connect(address, conn, c.Auth, c.UserAgent, c.RoutingContext, c.Log, boltLogger)
77+
connection, err := bolt.Connect(address, conn, c.Auth, c.UserAgent, c.RoutingContext, c.Log, boltLogger)
78+
if err != nil {
79+
if connErr := conn.Close(); connErr != nil {
80+
c.Log.Warnf(log.Driver, "", "Could not close underlying socket after Bolt handshake error")
81+
}
82+
return nil, err
83+
}
84+
return connection, err
7985
}
8086

8187
// TLS requested, continue with handshake
@@ -100,6 +106,20 @@ func (c Connector) Connect(address string, boltLogger log.BoltLogger) (db.Connec
100106
conn.Close()
101107
return nil, &TlsError{inner: err}
102108
}
103-
// Perform Bolt handshake
104-
return bolt.Connect(address, tlsconn, c.Auth, c.UserAgent, c.RoutingContext, c.Log, boltLogger)
109+
connection, err := bolt.Connect(address, tlsconn, c.Auth, c.UserAgent, c.RoutingContext, c.Log, boltLogger)
110+
if err != nil {
111+
if connErr := conn.Close(); connErr != nil {
112+
c.Log.Warnf(log.Driver, "", "Could not close underlying socket after Bolt handshake error")
113+
}
114+
return nil, err
115+
}
116+
return connection, nil
117+
}
118+
119+
func (c Connector) createConnection(address string) (net.Conn, error) {
120+
dialer := net.Dialer{Timeout: c.DialTimeout}
121+
if !c.SocketKeepAlive {
122+
dialer.KeepAlive = -1 * time.Second // Turns keep-alive off
123+
}
124+
return dialer.Dial(c.Network, address)
105125
}
Lines changed: 159 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,159 @@
1+
/*
2+
* Copyright (c) "Neo4j"
3+
* Neo4j Sweden AB [https://neo4j.com]
4+
*
5+
* This file is part of Neo4j.
6+
*
7+
* Licensed under the Apache License, Version 2.0 (the "License");
8+
* you may not use this file except in compliance with the License.
9+
* You may obtain a copy of the License at
10+
*
11+
* https://www.apache.org/licenses/LICENSE-2.0
12+
*
13+
* Unless required by applicable law or agreed to in writing, software
14+
* distributed under the License is distributed on an "AS IS" BASIS,
15+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16+
* See the License for the specific language governing permissions and
17+
* limitations under the License.
18+
*/
19+
20+
package connector_test
21+
22+
import (
23+
"github.com/neo4j/neo4j-go-driver/v4/neo4j/internal/connector"
24+
. "github.com/neo4j/neo4j-go-driver/v4/neo4j/internal/testutil"
25+
"io"
26+
"net"
27+
"testing"
28+
"time"
29+
)
30+
31+
func TestConnect(outer *testing.T) {
32+
outer.Parallel()
33+
34+
outer.Run("closes connection if Bolt handshake does not reach agreement", func(t *testing.T) {
35+
clientConnection, server := setUp(t)
36+
go func() {
37+
server.acceptVersion(1, 0)
38+
}()
39+
connectionDelegate := &ConnDelegate{Delegate: clientConnection}
40+
connector := &connector.Connector{SupplyConnection: supplyThis(connectionDelegate), SkipEncryption: true}
41+
42+
connection, err := connector.Connect("irrelevant", nil)
43+
44+
AssertNil(t, connection)
45+
AssertErrorMessageContains(t, err, "unsupported version 1.0")
46+
AssertTrue(t, connectionDelegate.Closed)
47+
})
48+
49+
outer.Run("closes connection if Bolt handshake errors", func(t *testing.T) {
50+
clientConnection, server := setUp(t)
51+
go func() {
52+
server.failAcceptingVersion()
53+
}()
54+
connectionDelegate := &ConnDelegate{Delegate: clientConnection}
55+
connector := &connector.Connector{SupplyConnection: supplyThis(connectionDelegate), SkipEncryption: true}
56+
57+
connection, err := connector.Connect("irrelevant", nil)
58+
59+
AssertNil(t, connection)
60+
AssertError(t, err)
61+
AssertTrue(t, connectionDelegate.Closed)
62+
})
63+
}
64+
65+
func setUp(t *testing.T) (net.Conn, *boltHandshakeServer) {
66+
listener, err := net.Listen("tcp", ":0")
67+
if err != nil {
68+
t.Fatalf("Unable to listen: %s", err)
69+
}
70+
t.Cleanup(func() {
71+
_ = listener.Close()
72+
})
73+
74+
address := listener.Addr()
75+
clientConnection, err := net.Dial(address.Network(), address.String())
76+
if err != nil {
77+
t.Fatalf("Dial error: %s", err)
78+
}
79+
t.Cleanup(func() {
80+
_ = clientConnection.Close()
81+
})
82+
serverConnection, err := listener.Accept()
83+
if err != nil {
84+
t.Fatalf("Accept error: %s", err)
85+
}
86+
t.Cleanup(func() {
87+
_ = serverConnection.Close()
88+
})
89+
handshakeServer := &boltHandshakeServer{t, serverConnection}
90+
return clientConnection, handshakeServer
91+
}
92+
93+
func supplyThis(connection net.Conn) func(address string) (net.Conn, error) {
94+
return func(address string) (net.Conn, error) {
95+
return connection, nil
96+
}
97+
}
98+
99+
type boltHandshakeServer struct {
100+
t *testing.T
101+
conn net.Conn
102+
}
103+
104+
func (server *boltHandshakeServer) waitForHandshake() []byte {
105+
handshake := make([]byte, 4*5)
106+
if _, err := io.ReadFull(server.conn, handshake); err != nil {
107+
server.t.Fatalf("Unable to read client versions: %s", err)
108+
}
109+
return handshake
110+
}
111+
112+
func (server *boltHandshakeServer) acceptVersion(major, minor byte) {
113+
server.waitForHandshake()
114+
if _, err := server.conn.Write([]byte{0x00, 0x00, minor, major}); err != nil {
115+
panic(err)
116+
}
117+
}
118+
119+
func (server *boltHandshakeServer) failAcceptingVersion() {
120+
_ = server.conn.Close()
121+
}
122+
123+
type ConnDelegate struct {
124+
Closed bool
125+
Delegate net.Conn
126+
}
127+
128+
func (cd *ConnDelegate) Read(b []byte) (n int, err error) {
129+
return cd.Delegate.Read(b)
130+
}
131+
132+
func (cd *ConnDelegate) Write(b []byte) (n int, err error) {
133+
return cd.Delegate.Write(b)
134+
}
135+
136+
func (cd *ConnDelegate) Close() error {
137+
cd.Closed = true
138+
return cd.Delegate.Close()
139+
}
140+
141+
func (cd *ConnDelegate) LocalAddr() net.Addr {
142+
return cd.Delegate.LocalAddr()
143+
}
144+
145+
func (cd *ConnDelegate) RemoteAddr() net.Addr {
146+
return cd.Delegate.RemoteAddr()
147+
}
148+
149+
func (cd *ConnDelegate) SetDeadline(t time.Time) error {
150+
return cd.Delegate.SetDeadline(t)
151+
}
152+
153+
func (cd *ConnDelegate) SetReadDeadline(t time.Time) error {
154+
return cd.Delegate.SetReadDeadline(t)
155+
}
156+
157+
func (cd *ConnDelegate) SetWriteDeadline(t time.Time) error {
158+
return cd.Delegate.SetWriteDeadline(t)
159+
}

0 commit comments

Comments
 (0)