Skip to content

Commit f4ddf59

Browse files
holepunch: add multiaddress filter (#1839)
* feat: add holepunch address filter option * fix: exit early if all addresses were filtered out * incorporate PR feedback Co-authored-by: Marten Seemann <[email protected]> * remove: holepunch default filter * fix: hole punch failing test * holepunch: fix race condition in test when adding holepunch service * improve holepunch filter interface comments Co-authored-by: Marten Seemann <[email protected]>
1 parent 84ded7d commit f4ddf59

File tree

4 files changed

+123
-19
lines changed

4 files changed

+123
-19
lines changed

p2p/protocol/holepunch/filter.go

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
package holepunch
2+
3+
import (
4+
"github.com/libp2p/go-libp2p/core/peer"
5+
ma "github.com/multiformats/go-multiaddr"
6+
)
7+
8+
// WithAddrFilter is a Service option that enables multiaddress filtering.
9+
// It allows to only send a subset of observed addresses to the remote
10+
// peer. E.g., only announce TCP or QUIC multi addresses instead of both.
11+
// It also allows to only consider a subset of received multi addresses
12+
// that remote peers announced to us.
13+
// Theoretically, this API also allows to add multi addresses in both cases.
14+
func WithAddrFilter(f AddrFilter) Option {
15+
return func(hps *Service) error {
16+
hps.filter = f
17+
return nil
18+
}
19+
}
20+
21+
// AddrFilter defines the interface for the multi address filtering.
22+
type AddrFilter interface {
23+
// FilterLocal filters the multi addresses that are sent to the remote peer.
24+
FilterLocal(remoteID peer.ID, maddrs []ma.Multiaddr) []ma.Multiaddr
25+
// FilterRemote filters the multi addresses received from the remote peer.
26+
FilterRemote(remoteID peer.ID, maddrs []ma.Multiaddr) []ma.Multiaddr
27+
}

p2p/protocol/holepunch/holepunch_test.go

Lines changed: 68 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,21 @@ func (m *mockEventTracer) getEvents() []*holepunch.Event {
4545

4646
var _ holepunch.EventTracer = &mockEventTracer{}
4747

48+
type mockMaddrFilter struct {
49+
filterLocal func(remoteID peer.ID, maddrs []ma.Multiaddr) []ma.Multiaddr
50+
filterRemote func(remoteID peer.ID, maddrs []ma.Multiaddr) []ma.Multiaddr
51+
}
52+
53+
func (m mockMaddrFilter) FilterLocal(remoteID peer.ID, maddrs []ma.Multiaddr) []ma.Multiaddr {
54+
return m.filterLocal(remoteID, maddrs)
55+
}
56+
57+
func (m mockMaddrFilter) FilterRemote(remoteID peer.ID, maddrs []ma.Multiaddr) []ma.Multiaddr {
58+
return m.filterRemote(remoteID, maddrs)
59+
}
60+
61+
var _ holepunch.AddrFilter = &mockMaddrFilter{}
62+
4863
type mockIDService struct {
4964
identify.IDService
5065
}
@@ -110,7 +125,7 @@ func TestDirectDialWorks(t *testing.T) {
110125
func TestEndToEndSimConnect(t *testing.T) {
111126
h1tr := &mockEventTracer{}
112127
h2tr := &mockEventTracer{}
113-
h1, h2, relay, _ := makeRelayedHosts(t, holepunch.WithTracer(h1tr), holepunch.WithTracer(h2tr), true)
128+
h1, h2, relay, _ := makeRelayedHosts(t, []holepunch.Option{holepunch.WithTracer(h1tr)}, []holepunch.Option{holepunch.WithTracer(h2tr)}, true)
114129
defer h1.Close()
115130
defer h2.Close()
116131
defer relay.Close()
@@ -151,6 +166,7 @@ func TestFailuresOnInitiator(t *testing.T) {
151166
rhandler func(s network.Stream)
152167
errMsg string
153168
holePunchTimeout time.Duration
169+
filter func(remoteID peer.ID, maddrs []ma.Multiaddr) []ma.Multiaddr
154170
}{
155171
"responder does NOT send a CONNECT message": {
156172
rhandler: func(s network.Stream) {
@@ -175,6 +191,12 @@ func TestFailuresOnInitiator(t *testing.T) {
175191
},
176192
errMsg: "i/o deadline reached",
177193
},
194+
"no addrs after filtering": {
195+
errMsg: "aborting hole punch initiation as we have no public address",
196+
filter: func(remoteID peer.ID, maddrs []ma.Multiaddr) []ma.Multiaddr {
197+
return []ma.Multiaddr{}
198+
},
199+
},
178200
}
179201

180202
for name, tc := range tcs {
@@ -190,7 +212,22 @@ func TestFailuresOnInitiator(t *testing.T) {
190212
defer h1.Close()
191213
defer h2.Close()
192214
defer relay.Close()
193-
hps := addHolePunchService(t, h2, holepunch.WithTracer(tr))
215+
216+
opts := []holepunch.Option{holepunch.WithTracer(tr)}
217+
if tc.filter != nil {
218+
f := mockMaddrFilter{
219+
filterLocal: tc.filter,
220+
filterRemote: tc.filter,
221+
}
222+
opts = append(opts, holepunch.WithAddrFilter(f))
223+
}
224+
225+
hps := addHolePunchService(t, h2, opts...)
226+
// wait until the hole punching protocol has actually started
227+
require.Eventually(t, func() bool {
228+
protos, _ := h2.Peerstore().SupportsProtocols(h1.ID(), string(holepunch.Protocol))
229+
return len(protos) > 0
230+
}, 200*time.Millisecond, 10*time.Millisecond)
194231

195232
if tc.rhandler != nil {
196233
h1.SetStreamHandler(holepunch.Protocol, tc.rhandler)
@@ -221,6 +258,7 @@ func TestFailuresOnResponder(t *testing.T) {
221258
initiator func(s network.Stream)
222259
errMsg string
223260
holePunchTimeout time.Duration
261+
filter func(remoteID peer.ID, maddrs []ma.Multiaddr) []ma.Multiaddr
224262
}{
225263
"initiator does NOT send a CONNECT message": {
226264
initiator: func(s network.Stream) {
@@ -258,6 +296,19 @@ func TestFailuresOnResponder(t *testing.T) {
258296
},
259297
errMsg: "expected CONNECT message to contain at least one address",
260298
},
299+
"no addrs after filtering": {
300+
errMsg: "rejecting hole punch request, as we don't have any public addresses",
301+
initiator: func(s network.Stream) {
302+
protoio.NewDelimitedWriter(s).WriteMsg(&holepunch_pb.HolePunch{
303+
Type: holepunch_pb.HolePunch_CONNECT.Enum(),
304+
ObsAddrs: addrsToBytes([]ma.Multiaddr{ma.StringCast("/ip4/127.0.0.1/tcp/1234")}),
305+
})
306+
time.Sleep(10 * time.Second)
307+
},
308+
filter: func(remoteID peer.ID, maddrs []ma.Multiaddr) []ma.Multiaddr {
309+
return []ma.Multiaddr{}
310+
},
311+
},
261312
}
262313

263314
for name, tc := range tcs {
@@ -267,9 +318,18 @@ func TestFailuresOnResponder(t *testing.T) {
267318
holepunch.StreamTimeout = tc.holePunchTimeout
268319
defer func() { holepunch.StreamTimeout = cpy }()
269320
}
270-
271321
tr := &mockEventTracer{}
272-
h1, h2, relay, _ := makeRelayedHosts(t, holepunch.WithTracer(tr), nil, false)
322+
323+
opts := []holepunch.Option{holepunch.WithTracer(tr)}
324+
if tc.filter != nil {
325+
f := mockMaddrFilter{
326+
filterLocal: tc.filter,
327+
filterRemote: tc.filter,
328+
}
329+
opts = append(opts, holepunch.WithAddrFilter(f))
330+
}
331+
332+
h1, h2, relay, _ := makeRelayedHosts(t, opts, nil, false)
273333
defer h1.Close()
274334
defer h2.Close()
275335
defer relay.Close()
@@ -379,13 +439,9 @@ func mkHostWithStaticAutoRelay(t *testing.T, relay host.Host) host.Host {
379439
return h
380440
}
381441

382-
func makeRelayedHosts(t *testing.T, h1opt, h2opt holepunch.Option, addHolePuncher bool) (h1, h2, relay host.Host, hps *holepunch.Service) {
442+
func makeRelayedHosts(t *testing.T, h1opt, h2opt []holepunch.Option, addHolePuncher bool) (h1, h2, relay host.Host, hps *holepunch.Service) {
383443
t.Helper()
384-
var h1opts []holepunch.Option
385-
if h1opt != nil {
386-
h1opts = append(h1opts, h1opt)
387-
}
388-
h1, _ = mkHostWithHolePunchSvc(t, h1opts...)
444+
h1, _ = mkHostWithHolePunchSvc(t, h1opt...)
389445
var err error
390446
relay, err = libp2p.New(libp2p.ListenAddrs(ma.StringCast("/ip4/127.0.0.1/tcp/0")), libp2p.DisableRelay())
391447
require.NoError(t, err)
@@ -395,7 +451,7 @@ func makeRelayedHosts(t *testing.T, h1opt, h2opt holepunch.Option, addHolePunche
395451

396452
h2 = mkHostWithStaticAutoRelay(t, relay)
397453
if addHolePuncher {
398-
hps = addHolePunchService(t, h2, h2opt)
454+
hps = addHolePunchService(t, h2, h2opt...)
399455
}
400456

401457
// h1 has a relay addr
@@ -415,12 +471,8 @@ func makeRelayedHosts(t *testing.T, h1opt, h2opt holepunch.Option, addHolePunche
415471
return
416472
}
417473

418-
func addHolePunchService(t *testing.T, h host.Host, opt holepunch.Option) *holepunch.Service {
474+
func addHolePunchService(t *testing.T, h host.Host, opts ...holepunch.Option) *holepunch.Service {
419475
t.Helper()
420-
var opts []holepunch.Option
421-
if opt != nil {
422-
opts = append(opts, opt)
423-
}
424476
hps, err := holepunch.NewService(h, newMockIDService(t, h), opts...)
425477
require.NoError(t, err)
426478
return hps

p2p/protocol/holepunch/holepuncher.go

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -49,14 +49,16 @@ type holePuncher struct {
4949
closed bool
5050

5151
tracer *tracer
52+
filter AddrFilter
5253
}
5354

54-
func newHolePuncher(h host.Host, ids identify.IDService, tracer *tracer) *holePuncher {
55+
func newHolePuncher(h host.Host, ids identify.IDService, tracer *tracer, filter AddrFilter) *holePuncher {
5556
hp := &holePuncher{
5657
host: h,
5758
ids: ids,
5859
active: make(map[peer.ID]struct{}),
5960
tracer: tracer,
61+
filter: filter,
6062
}
6163
hp.ctx, hp.ctxCancel = context.WithCancel(context.Background())
6264
h.Network().Notify((*netNotifiee)(hp))
@@ -204,10 +206,18 @@ func (hp *holePuncher) initiateHolePunchImpl(str network.Stream) ([]ma.Multiaddr
204206
str.SetDeadline(time.Now().Add(StreamTimeout))
205207

206208
// send a CONNECT and start RTT measurement.
209+
obsAddrs := removeRelayAddrs(hp.ids.OwnObservedAddrs())
210+
if hp.filter != nil {
211+
obsAddrs = hp.filter.FilterLocal(str.Conn().RemotePeer(), obsAddrs)
212+
}
213+
if len(obsAddrs) == 0 {
214+
return nil, 0, errors.New("aborting hole punch initiation as we have no public address")
215+
}
216+
207217
start := time.Now()
208218
if err := w.WriteMsg(&pb.HolePunch{
209219
Type: pb.HolePunch_CONNECT.Enum(),
210-
ObsAddrs: addrsToBytes(removeRelayAddrs(hp.ids.OwnObservedAddrs())),
220+
ObsAddrs: addrsToBytes(obsAddrs),
211221
}); err != nil {
212222
str.Reset()
213223
return nil, 0, err
@@ -222,7 +232,12 @@ func (hp *holePuncher) initiateHolePunchImpl(str network.Stream) ([]ma.Multiaddr
222232
if t := msg.GetType(); t != pb.HolePunch_CONNECT {
223233
return nil, 0, fmt.Errorf("expect CONNECT message, got %s", t)
224234
}
235+
225236
addrs := removeRelayAddrs(addrsFromBytes(msg.ObsAddrs))
237+
if hp.filter != nil {
238+
addrs = hp.filter.FilterRemote(str.Conn().RemotePeer(), addrs)
239+
}
240+
226241
if len(addrs) == 0 {
227242
return nil, 0, errors.New("didn't receive any public addresses in CONNECT")
228243
}

p2p/protocol/holepunch/svc.go

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ type Service struct {
5353
hasPublicAddrsChan chan struct{}
5454

5555
tracer *tracer
56+
filter AddrFilter
5657

5758
refCount sync.WaitGroup
5859
}
@@ -140,7 +141,7 @@ func (s *Service) watchForPublicAddr() {
140141
continue
141142
}
142143
s.holePuncherMx.Lock()
143-
s.holePuncher = newHolePuncher(s.host, s.ids, s.tracer)
144+
s.holePuncher = newHolePuncher(s.host, s.ids, s.tracer, s.filter)
144145
s.holePuncherMx.Unlock()
145146
close(s.hasPublicAddrsChan)
146147
return
@@ -169,6 +170,10 @@ func (s *Service) incomingHolePunch(str network.Stream) (rtt time.Duration, addr
169170
return 0, nil, fmt.Errorf("received hole punch stream: %s", str.Conn().RemoteMultiaddr())
170171
}
171172
ownAddrs := removeRelayAddrs(s.ids.OwnObservedAddrs())
173+
if s.filter != nil {
174+
ownAddrs = s.filter.FilterLocal(str.Conn().RemotePeer(), ownAddrs)
175+
}
176+
172177
// If we can't tell the peer where to dial us, there's no point in starting the hole punching.
173178
if len(ownAddrs) == 0 {
174179
return 0, nil, errors.New("rejecting hole punch request, as we don't have any public addresses")
@@ -194,7 +199,12 @@ func (s *Service) incomingHolePunch(str network.Stream) (rtt time.Duration, addr
194199
if t := msg.GetType(); t != pb.HolePunch_CONNECT {
195200
return 0, nil, fmt.Errorf("expected CONNECT message from initiator but got %d", t)
196201
}
202+
197203
obsDial := removeRelayAddrs(addrsFromBytes(msg.ObsAddrs))
204+
if s.filter != nil {
205+
obsDial = s.filter.FilterRemote(str.Conn().RemotePeer(), obsDial)
206+
}
207+
198208
log.Debugw("received hole punch request", "peer", str.Conn().RemotePeer(), "addrs", obsDial)
199209
if len(obsDial) == 0 {
200210
return 0, nil, errors.New("expected CONNECT message to contain at least one address")

0 commit comments

Comments
 (0)