Skip to content

Commit e0b126b

Browse files
authored
fix(dot/network): fix memory allocations with sizedBufferPool (ChainSafe#1963)
1 parent 173f04e commit e0b126b

4 files changed

Lines changed: 131 additions & 22 deletions

File tree

dot/network/notifications.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -389,7 +389,7 @@ func (s *Service) readHandshake(stream libp2pnetwork.Stream, decoder HandshakeDe
389389
go func() {
390390
msgBytes := s.bufPool.get()
391391
defer func() {
392-
s.bufPool.put(&msgBytes)
392+
s.bufPool.put(msgBytes)
393393
close(hsC)
394394
}()
395395

dot/network/pool.go

Lines changed: 12 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -5,15 +5,15 @@ package network
55

66
// sizedBufferPool is a pool of buffers used for reading from streams
77
type sizedBufferPool struct {
8-
c chan *[maxMessageSize]byte
8+
c chan []byte
99
}
1010

11-
func newSizedBufferPool(min, max int) (bp *sizedBufferPool) {
12-
bufferCh := make(chan *[maxMessageSize]byte, max)
11+
func newSizedBufferPool(preAllocate, size int) (bp *sizedBufferPool) {
12+
bufferCh := make(chan []byte, size)
1313

14-
for i := 0; i < min; i++ {
15-
buf := [maxMessageSize]byte{}
16-
bufferCh <- &buf
14+
for i := 0; i < preAllocate; i++ {
15+
buf := make([]byte, maxMessageSize)
16+
bufferCh <- buf
1717
}
1818

1919
return &sizedBufferPool{
@@ -23,20 +23,19 @@ func newSizedBufferPool(min, max int) (bp *sizedBufferPool) {
2323

2424
// get gets a buffer from the sizedBufferPool, or creates a new one if none are
2525
// available in the pool. Buffers have a pre-allocated capacity.
26-
func (bp *sizedBufferPool) get() [maxMessageSize]byte {
27-
var buff *[maxMessageSize]byte
26+
func (bp *sizedBufferPool) get() (b []byte) {
2827
select {
29-
case buff = <-bp.c:
30-
// reuse existing buffer
28+
case b = <-bp.c:
29+
// reuse existing buffer
30+
return b
3131
default:
3232
// create new buffer
33-
buff = &[maxMessageSize]byte{}
33+
return make([]byte, maxMessageSize)
3434
}
35-
return *buff
3635
}
3736

3837
// put returns the given buffer to the sizedBufferPool.
39-
func (bp *sizedBufferPool) put(b *[maxMessageSize]byte) {
38+
func (bp *sizedBufferPool) put(b []byte) {
4039
select {
4140
case bp.c <- b:
4241
default: // Discard the buffer if the pool is full.

dot/network/pool_test.go

Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,111 @@
1+
package network
2+
3+
import (
4+
"context"
5+
"sync"
6+
"testing"
7+
"time"
8+
9+
"github.com/stretchr/testify/assert"
10+
)
11+
12+
func Benchmark_sizedBufferPool(b *testing.B) {
13+
const preAllocate = 100
14+
const poolSize = 200
15+
sbp := newSizedBufferPool(preAllocate, poolSize)
16+
17+
b.RunParallel(func(p *testing.PB) {
18+
for p.Next() {
19+
buffer := sbp.get()
20+
buffer[0] = 1
21+
buffer[len(buffer)-1] = 1
22+
sbp.put(buffer)
23+
}
24+
})
25+
}
26+
27+
// Before: 104853 11119 ns/op 65598 B/op 1 allocs/op
28+
// Array ptr: 2742781 438.3 ns/op 2 B/op 0 allocs/op
29+
// Slices: 2560960 463.8 ns/op 2 B/op 0 allocs/op
30+
// Slice pointer: 2683528 460.8 ns/op 2 B/op 0 allocs/op
31+
32+
func Test_sizedBufferPool(t *testing.T) {
33+
t.Parallel()
34+
35+
const preAlloc = 1
36+
const poolSize = 2
37+
const maxIndex = maxMessageSize - 1
38+
39+
pool := newSizedBufferPool(preAlloc, poolSize)
40+
41+
first := pool.get() // pre-allocated one
42+
first[maxIndex] = 1
43+
44+
second := pool.get() // new one
45+
second[maxIndex] = 2
46+
47+
third := pool.get() // new one
48+
third[maxIndex] = 3
49+
50+
fourth := pool.get() // new one
51+
fourth[maxIndex] = 4
52+
53+
pool.put(fourth)
54+
pool.put(third)
55+
pool.put(second) // discarded
56+
pool.put(first) // discarded
57+
58+
b := pool.get() // fourth
59+
assert.Equal(t, byte(4), b[maxIndex])
60+
61+
b = pool.get() // third
62+
assert.Equal(t, byte(3), b[maxIndex])
63+
}
64+
65+
func Test_sizedBufferPool_race(t *testing.T) {
66+
t.Parallel()
67+
68+
const preAlloc = 1
69+
const poolSize = 2
70+
71+
pool := newSizedBufferPool(preAlloc, poolSize)
72+
73+
const parallelism = 4
74+
75+
readyWait := new(sync.WaitGroup)
76+
readyWait.Add(parallelism)
77+
78+
doneWait := new(sync.WaitGroup)
79+
doneWait.Add(parallelism)
80+
81+
// run for 50ms
82+
ctxTimerStarted := make(chan struct{})
83+
ctx := context.Background()
84+
ctx, cancel := context.WithCancel(ctx)
85+
go func() {
86+
const timeout = 50 * time.Millisecond
87+
readyWait.Wait()
88+
ctx, cancel = context.WithTimeout(ctx, timeout)
89+
close(ctxTimerStarted)
90+
}()
91+
defer cancel()
92+
93+
for i := 0; i < parallelism; i++ {
94+
go func() {
95+
defer doneWait.Done()
96+
readyWait.Done()
97+
readyWait.Wait()
98+
<-ctxTimerStarted
99+
100+
for ctx.Err() != nil {
101+
// test relies on the -race detector
102+
// to detect concurrent writes to the buffer.
103+
b := pool.get()
104+
b[0] = 1
105+
pool.put(b)
106+
}
107+
}()
108+
}
109+
110+
doneWait.Wait()
111+
}

dot/network/service.go

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -135,14 +135,13 @@ func NewService(cfg *Config) (*Service, error) {
135135
// pre-allocate pool of buffers used to read from streams.
136136
// initially allocate as many buffers as liekly necessary which is the number inbound streams we will have,
137137
// which should equal average number of peers times the number of notifications protocols, which is currently 3.
138-
var bufPool *sizedBufferPool
139-
if cfg.noPreAllocate {
140-
bufPool = &sizedBufferPool{
141-
c: make(chan *[maxMessageSize]byte, cfg.MinPeers*3),
142-
}
143-
} else {
144-
bufPool = newSizedBufferPool(cfg.MinPeers*3, cfg.MaxPeers*3)
138+
preAllocateInPool := cfg.MinPeers * 3
139+
poolSize := cfg.MaxPeers * 3
140+
if cfg.noPreAllocate { // testing
141+
preAllocateInPool = 0
142+
poolSize = cfg.MinPeers * 3
145143
}
144+
bufPool := newSizedBufferPool(preAllocateInPool, poolSize)
146145

147146
network := &Service{
148147
ctx: ctx,
@@ -550,7 +549,7 @@ func (s *Service) readStream(stream libp2pnetwork.Stream, decoder messageDecoder
550549

551550
peer := stream.Conn().RemotePeer()
552551
msgBytes := s.bufPool.get()
553-
defer s.bufPool.put(&msgBytes)
552+
defer s.bufPool.put(msgBytes)
554553

555554
for {
556555
tot, err := readStream(stream, msgBytes[:])

0 commit comments

Comments
 (0)