Skip to content
This repository was archived by the owner on Aug 19, 2022. It is now read-only.

Commit f3ae7f2

Browse files
handle TCP simultaneous open (option 4)
1 parent 8afeaef commit f3ae7f2

File tree

4 files changed

+206
-1
lines changed

4 files changed

+206
-1
lines changed

conn.go

Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,12 @@
11
package libp2ptls
22

33
import (
4+
"bytes"
45
"crypto/tls"
6+
"errors"
7+
"fmt"
8+
"io"
9+
"net"
510

611
ci "github.com/libp2p/go-libp2p-core/crypto"
712
"github.com/libp2p/go-libp2p-core/peer"
@@ -35,3 +40,110 @@ func (c *conn) RemotePeer() peer.ID {
3540
func (c *conn) RemotePublicKey() ci.PubKey {
3641
return c.remotePubKey
3742
}
43+
44+
const (
45+
recordTypeHandshake byte = 22
46+
versionTLS13 = 0x0304
47+
maxCiphertextTLS13 = 16384 + 256 // maximum ciphertext length in TLS 1.3
48+
maxHandshake = 65536 // maximum handshake we support (protocol max is 16 MB)
49+
)
50+
51+
var errSimultaneousConnect = errors.New("detected TCP simultaneous connect")
52+
53+
type teeConn struct {
54+
net.Conn
55+
buf *bytes.Buffer
56+
}
57+
58+
func newTeeConn(c net.Conn, buf *bytes.Buffer) net.Conn {
59+
return &teeConn{Conn: c, buf: buf}
60+
}
61+
62+
func (c *teeConn) Read(b []byte) (int, error) {
63+
n, err := c.Conn.Read(b)
64+
c.buf.Write(b[:n])
65+
return n, err
66+
}
67+
68+
type wrappedConn struct {
69+
net.Conn
70+
71+
hasReadFirstMessage bool
72+
raw bytes.Buffer // contains a copy of every byte of the first handshake message we read from the wire
73+
74+
hand bytes.Buffer // used to store the first handshake message until we've completely read it
75+
76+
}
77+
78+
func newWrappedConn(c net.Conn) net.Conn {
79+
wc := &wrappedConn{}
80+
wc.Conn = newTeeConn(c, &wc.raw)
81+
return wc
82+
}
83+
84+
func (c *wrappedConn) Read(b []byte) (int, error) {
85+
if c.hasReadFirstMessage {
86+
return c.Conn.Read(b)
87+
}
88+
89+
// We read the first handshake message, and it was not a ClientHello.
90+
// We now need to feed all the bytes we read from the wire into the TLS stack,
91+
// so it can proceed with the handshake.
92+
if c.raw.Len() > 0 {
93+
n, err := c.raw.Read(b)
94+
if err == io.EOF || c.raw.Len() == 0 {
95+
c.hasReadFirstMessage = true
96+
err = nil
97+
}
98+
return n, err
99+
}
100+
101+
mes, err := c.readFirstHandshakeMessage()
102+
if err != nil {
103+
return 0, err
104+
}
105+
106+
switch mes[0] {
107+
case 1: // ClientHello
108+
return 0, errSimultaneousConnect
109+
case 2: // ServerHello
110+
return c.Read(b)
111+
default:
112+
return 0, fmt.Errorf("unexpected message type: %d", mes[0])
113+
}
114+
}
115+
116+
func (c *wrappedConn) readFirstHandshakeMessage() ([]byte, error) {
117+
for c.hand.Len() < 4 {
118+
if err := c.readRecord(); err != nil {
119+
return nil, err
120+
}
121+
}
122+
data := c.hand.Bytes()
123+
n := int(data[1])<<16 | int(data[2])<<8 | int(data[3])
124+
if n > maxHandshake {
125+
return nil, fmt.Errorf("tls: handshake message of length %d bytes exceeds maximum of %d bytes", n, maxHandshake)
126+
}
127+
for c.hand.Len() < 4+n {
128+
if err := c.readRecord(); err != nil {
129+
return nil, err
130+
}
131+
}
132+
return c.hand.Next(4 + n), nil
133+
}
134+
135+
func (c *wrappedConn) readRecord() error {
136+
hdr := make([]byte, 5)
137+
if _, err := io.ReadFull(c.Conn, hdr); err != nil {
138+
return err
139+
}
140+
if hdr[0] != recordTypeHandshake {
141+
return errors.New("expected a handshake record")
142+
}
143+
n := int(hdr[3])<<8 | int(hdr[4])
144+
if n > maxCiphertextTLS13 {
145+
return fmt.Errorf("oversized record received with length %d", n)
146+
}
147+
_, err := io.CopyN(&c.hand, c.Conn, int64(n))
148+
return err
149+
}

crypto.go

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import (
44
"crypto/ecdsa"
55
"crypto/elliptic"
66
"crypto/rand"
7+
"crypto/sha256"
78
"crypto/tls"
89
"crypto/x509"
910
"crypto/x509/pkix"
@@ -220,3 +221,19 @@ func preferServerCipherSuites() bool {
220221
)
221222
return !hasGCMAsm
222223
}
224+
225+
// Compare two peer IDs by their SHA256 hash.
226+
// The result will be 0 if H(a) == H(b), -1 if H(a) < H(b), and +1 if H(a) > H(b).
227+
func comparePeerIDs(p1, p2 peer.ID) int {
228+
p1Hash := sha256.Sum256([]byte(p1))
229+
p2Hash := sha256.Sum256([]byte(p2))
230+
for i := 0; i < sha256.Size; i++ {
231+
if p1Hash[i] < p2Hash[i] {
232+
return -1
233+
}
234+
if p1Hash[i] > p2Hash[i] {
235+
return 1
236+
}
237+
}
238+
return 0
239+
}

transport.go

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,26 @@ func (t *Transport) SecureInbound(ctx context.Context, insecure net.Conn) (sec.S
6666
// notice this after 1 RTT when calling Read.
6767
func (t *Transport) SecureOutbound(ctx context.Context, insecure net.Conn, p peer.ID) (sec.SecureConn, error) {
6868
config, keyCh := t.identity.ConfigForPeer(p)
69-
return t.handshake(ctx, tls.Client(insecure, config), keyCh)
69+
conn, err := t.handshake(ctx, tls.Client(newWrappedConn(insecure), config), keyCh)
70+
if err == errSimultaneousConnect {
71+
switch comparePeerIDs(t.localPeer, p) {
72+
case 0:
73+
return nil, errors.New("tried to simultaneous connect to oneself")
74+
case -1:
75+
// SHA256(our peer ID) is smaller than SHA256(their peer ID).
76+
// We're the client in the next connection attempt.
77+
config, keyCh := t.identity.ConfigForPeer(p)
78+
return t.handshake(ctx, tls.Client(insecure, config), keyCh)
79+
case 1:
80+
// SHA256(our peer ID) is larger than SHA256(their peer ID).
81+
// We're the server in the next connection attempt.
82+
config, keyCh := t.identity.ConfigForPeer(p)
83+
return t.handshake(ctx, tls.Server(insecure, config), keyCh)
84+
default:
85+
panic("unexpected peer ID comparison result")
86+
}
87+
}
88+
return conn, err
7089
}
7190

7291
func (t *Transport) handshake(

transport_test.go

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ import (
1515
"math/big"
1616
mrand "math/rand"
1717
"net"
18+
"reflect"
1819
"time"
1920

2021
"github.com/onsi/gomega/gbytes"
@@ -188,6 +189,62 @@ var _ = Describe("Transport", func() {
188189
Eventually(done).Should(BeClosed())
189190
})
190191

192+
It("handles simultaneous open", func() {
193+
// Avoid confusion regarding the naming.
194+
p1, p1Key := serverID, serverKey
195+
p2, p2Key := clientID, clientKey
196+
197+
// We use a normal dial / listen to establish the TCP connection,
198+
// but we then start two clients.
199+
c1raw, c2raw := connect()
200+
201+
c1Transport, err := New(p1Key)
202+
Expect(err).ToNot(HaveOccurred())
203+
c2Transport, err := New(p2Key)
204+
Expect(err).ToNot(HaveOccurred())
205+
206+
c1ConnChan := make(chan sec.SecureConn, 1)
207+
go func() {
208+
defer GinkgoRecover()
209+
conn, err := c1Transport.SecureOutbound(context.Background(), c1raw, p2)
210+
Expect(err).ToNot(HaveOccurred())
211+
c1ConnChan <- conn
212+
}()
213+
214+
c2, err := c2Transport.SecureOutbound(context.Background(), c2raw, p1)
215+
Expect(err).ToNot(HaveOccurred())
216+
defer c2.Close()
217+
var c1 sec.SecureConn
218+
Eventually(c1ConnChan).Should(Receive(&c1))
219+
defer c1.Close()
220+
221+
// check that the peers are in the correct roles
222+
isClient := func(c sec.SecureConn) bool {
223+
// the isClient field of the tls.Conn will tell us who is client and server
224+
return reflect.ValueOf(c.(*conn).Conn).Elem().FieldByName("isClient").Bool()
225+
}
226+
switch comparePeerIDs(p1, p2) {
227+
case -1:
228+
// H(p1) < H(p2) => p1 acts as a client, p2 as a server
229+
Expect(isClient(c1)).To(BeTrue())
230+
Expect(isClient(c2)).To(BeFalse())
231+
case 1:
232+
// H(p1) > H(p2) => p1 acts as a server, p2 as a client
233+
Expect(isClient(c1)).To(BeFalse())
234+
Expect(isClient(c2)).To(BeTrue())
235+
default:
236+
Fail("unexpected peer comparison result")
237+
}
238+
239+
// exchange some data
240+
_, err = c1.Write([]byte("foobar"))
241+
Expect(err).ToNot(HaveOccurred())
242+
b := make([]byte, 6)
243+
_, err = c2.Read(b)
244+
Expect(err).ToNot(HaveOccurred())
245+
Expect(string(b)).To(Equal("foobar"))
246+
})
247+
191248
Context("invalid certificates", func() {
192249
invalidateCertChain := func(identity *Identity) {
193250
switch identity.config.Certificates[0].PrivateKey.(type) {

0 commit comments

Comments
 (0)