Skip to content

Commit 308a751

Browse files
authored
Refactor zstd decoder (#498)
# TLDR; * Streams can now be decoded without goroutines using `WithDecoderConcurrency(1)`. * `WithDecoderConcurrency(4)` is now default. If you need more concurrent `DecodeAll` operations, use `WithDecoderConcurrency(0)`. Goroutines exit when streams have finished reading (either error or EOF). Designed and tested to be compatible, but test before committing upgrade. # Changes Goroutines will now only be created on demand, and `WithDecoderConcurrency(1)` is now strictly synchronized. Decompression will typically be about 2x faster when using multiple goroutines, and will prepare input for the upstream reader async to reads. This can lead to ~3x faster input in total than using no goroutines. New default is now `WithDecoderConcurrency(4)` (or less, if GOMAXPROCS is less). Beyond 4, there is little benefit for streaming decompression. * No goroutines created, unless streaming, and auto-closed at error/EOF. * Synchronous stream decoding with `WithDecoderConcurrency(1)`. * Split sequence decoding/execution for streams up to 50% faster. * Simplified error flow. * Speedup on streams. * More consistent error reporting. * Improved error detection/compliance with reference decoder. * Improved test coverage. Fixes #477
1 parent 0f44aaa commit 308a751

File tree

18 files changed

+1447
-890
lines changed

18 files changed

+1447
-890
lines changed

huff0/bitreader.go

Lines changed: 15 additions & 106 deletions
Original file line numberDiff line numberDiff line change
@@ -8,115 +8,10 @@ package huff0
88
import (
99
"encoding/binary"
1010
"errors"
11+
"fmt"
1112
"io"
1213
)
1314

14-
// bitReader reads a bitstream in reverse.
15-
// The last set bit indicates the start of the stream and is used
16-
// for aligning the input.
17-
type bitReader struct {
18-
in []byte
19-
off uint // next byte to read is at in[off - 1]
20-
value uint64
21-
bitsRead uint8
22-
}
23-
24-
// init initializes and resets the bit reader.
25-
func (b *bitReader) init(in []byte) error {
26-
if len(in) < 1 {
27-
return errors.New("corrupt stream: too short")
28-
}
29-
b.in = in
30-
b.off = uint(len(in))
31-
// The highest bit of the last byte indicates where to start
32-
v := in[len(in)-1]
33-
if v == 0 {
34-
return errors.New("corrupt stream, did not find end of stream")
35-
}
36-
b.bitsRead = 64
37-
b.value = 0
38-
if len(in) >= 8 {
39-
b.fillFastStart()
40-
} else {
41-
b.fill()
42-
b.fill()
43-
}
44-
b.bitsRead += 8 - uint8(highBit32(uint32(v)))
45-
return nil
46-
}
47-
48-
// peekBitsFast requires that at least one bit is requested every time.
49-
// There are no checks if the buffer is filled.
50-
func (b *bitReader) peekBitsFast(n uint8) uint16 {
51-
const regMask = 64 - 1
52-
v := uint16((b.value << (b.bitsRead & regMask)) >> ((regMask + 1 - n) & regMask))
53-
return v
54-
}
55-
56-
// fillFast() will make sure at least 32 bits are available.
57-
// There must be at least 4 bytes available.
58-
func (b *bitReader) fillFast() {
59-
if b.bitsRead < 32 {
60-
return
61-
}
62-
63-
// 2 bounds checks.
64-
v := b.in[b.off-4 : b.off]
65-
v = v[:4]
66-
low := (uint32(v[0])) | (uint32(v[1]) << 8) | (uint32(v[2]) << 16) | (uint32(v[3]) << 24)
67-
b.value = (b.value << 32) | uint64(low)
68-
b.bitsRead -= 32
69-
b.off -= 4
70-
}
71-
72-
func (b *bitReader) advance(n uint8) {
73-
b.bitsRead += n
74-
}
75-
76-
// fillFastStart() assumes the bitreader is empty and there is at least 8 bytes to read.
77-
func (b *bitReader) fillFastStart() {
78-
// Do single re-slice to avoid bounds checks.
79-
b.value = binary.LittleEndian.Uint64(b.in[b.off-8:])
80-
b.bitsRead = 0
81-
b.off -= 8
82-
}
83-
84-
// fill() will make sure at least 32 bits are available.
85-
func (b *bitReader) fill() {
86-
if b.bitsRead < 32 {
87-
return
88-
}
89-
if b.off > 4 {
90-
v := b.in[b.off-4:]
91-
v = v[:4]
92-
low := (uint32(v[0])) | (uint32(v[1]) << 8) | (uint32(v[2]) << 16) | (uint32(v[3]) << 24)
93-
b.value = (b.value << 32) | uint64(low)
94-
b.bitsRead -= 32
95-
b.off -= 4
96-
return
97-
}
98-
for b.off > 0 {
99-
b.value = (b.value << 8) | uint64(b.in[b.off-1])
100-
b.bitsRead -= 8
101-
b.off--
102-
}
103-
}
104-
105-
// finished returns true if all bits have been read from the bit stream.
106-
func (b *bitReader) finished() bool {
107-
return b.off == 0 && b.bitsRead >= 64
108-
}
109-
110-
// close the bitstream and returns an error if out-of-buffer reads occurred.
111-
func (b *bitReader) close() error {
112-
// Release reference.
113-
b.in = nil
114-
if b.bitsRead > 64 {
115-
return io.ErrUnexpectedEOF
116-
}
117-
return nil
118-
}
119-
12015
// bitReader reads a bitstream in reverse.
12116
// The last set bit indicates the start of the stream and is used
12217
// for aligning the input.
@@ -213,10 +108,17 @@ func (b *bitReaderBytes) finished() bool {
213108
return b.off == 0 && b.bitsRead >= 64
214109
}
215110

111+
func (b *bitReaderBytes) remaining() uint {
112+
return b.off*8 + uint(64-b.bitsRead)
113+
}
114+
216115
// close the bitstream and returns an error if out-of-buffer reads occurred.
217116
func (b *bitReaderBytes) close() error {
218117
// Release reference.
219118
b.in = nil
119+
if b.remaining() > 0 {
120+
return fmt.Errorf("corrupt input: %d bits remain on stream", b.remaining())
121+
}
220122
if b.bitsRead > 64 {
221123
return io.ErrUnexpectedEOF
222124
}
@@ -318,10 +220,17 @@ func (b *bitReaderShifted) finished() bool {
318220
return b.off == 0 && b.bitsRead >= 64
319221
}
320222

223+
func (b *bitReaderShifted) remaining() uint {
224+
return b.off*8 + uint(64-b.bitsRead)
225+
}
226+
321227
// close the bitstream and returns an error if out-of-buffer reads occurred.
322228
func (b *bitReaderShifted) close() error {
323229
// Release reference.
324230
b.in = nil
231+
if b.remaining() > 0 {
232+
return fmt.Errorf("corrupt input: %d bits remain on stream", b.remaining())
233+
}
325234
if b.bitsRead > 64 {
326235
return io.ErrUnexpectedEOF
327236
}

huff0/decompress.go

Lines changed: 38 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -741,6 +741,7 @@ func (d *Decoder) Decompress4X(dst, src []byte) ([]byte, error) {
741741
}
742742

743743
var br [4]bitReaderShifted
744+
// Decode "jump table"
744745
start := 6
745746
for i := 0; i < 3; i++ {
746747
length := int(src[i*2]) | (int(src[i*2+1]) << 8)
@@ -865,30 +866,18 @@ func (d *Decoder) Decompress4X(dst, src []byte) ([]byte, error) {
865866
}
866867

867868
// Decode remaining.
869+
remainBytes := dstEvery - (decoded / 4)
868870
for i := range br {
869871
offset := dstEvery * i
872+
endsAt := offset + remainBytes
873+
if endsAt > len(out) {
874+
endsAt = len(out)
875+
}
870876
br := &br[i]
871-
bitsLeft := br.off*8 + uint(64-br.bitsRead)
877+
bitsLeft := br.remaining()
872878
for bitsLeft > 0 {
873879
br.fill()
874-
if false && br.bitsRead >= 32 {
875-
if br.off >= 4 {
876-
v := br.in[br.off-4:]
877-
v = v[:4]
878-
low := (uint32(v[0])) | (uint32(v[1]) << 8) | (uint32(v[2]) << 16) | (uint32(v[3]) << 24)
879-
br.value = (br.value << 32) | uint64(low)
880-
br.bitsRead -= 32
881-
br.off -= 4
882-
} else {
883-
for br.off > 0 {
884-
br.value = (br.value << 8) | uint64(br.in[br.off-1])
885-
br.bitsRead -= 8
886-
br.off--
887-
}
888-
}
889-
}
890-
// end inline...
891-
if offset >= len(out) {
880+
if offset >= endsAt {
892881
d.bufs.Put(buf)
893882
return nil, errors.New("corruption detected: stream overrun 4")
894883
}
@@ -902,6 +891,10 @@ func (d *Decoder) Decompress4X(dst, src []byte) ([]byte, error) {
902891
out[offset] = uint8(v >> 8)
903892
offset++
904893
}
894+
if offset != endsAt {
895+
d.bufs.Put(buf)
896+
return nil, fmt.Errorf("corruption detected: short output block %d, end %d != %d", i, offset, endsAt)
897+
}
905898
decoded += offset - dstEvery*i
906899
err = br.close()
907900
if err != nil {
@@ -1091,10 +1084,16 @@ func (d *Decoder) decompress4X8bit(dst, src []byte) ([]byte, error) {
10911084
}
10921085

10931086
// Decode remaining.
1087+
// Decode remaining.
1088+
remainBytes := dstEvery - (decoded / 4)
10941089
for i := range br {
10951090
offset := dstEvery * i
1091+
endsAt := offset + remainBytes
1092+
if endsAt > len(out) {
1093+
endsAt = len(out)
1094+
}
10961095
br := &br[i]
1097-
bitsLeft := int(br.off*8) + int(64-br.bitsRead)
1096+
bitsLeft := br.remaining()
10981097
for bitsLeft > 0 {
10991098
if br.finished() {
11001099
d.bufs.Put(buf)
@@ -1117,7 +1116,7 @@ func (d *Decoder) decompress4X8bit(dst, src []byte) ([]byte, error) {
11171116
}
11181117
}
11191118
// end inline...
1120-
if offset >= len(out) {
1119+
if offset >= endsAt {
11211120
d.bufs.Put(buf)
11221121
return nil, errors.New("corruption detected: stream overrun 4")
11231122
}
@@ -1126,10 +1125,14 @@ func (d *Decoder) decompress4X8bit(dst, src []byte) ([]byte, error) {
11261125
v := single[uint8(br.value>>shift)].entry
11271126
nBits := uint8(v)
11281127
br.advance(nBits)
1129-
bitsLeft -= int(nBits)
1128+
bitsLeft -= uint(nBits)
11301129
out[offset] = uint8(v >> 8)
11311130
offset++
11321131
}
1132+
if offset != endsAt {
1133+
d.bufs.Put(buf)
1134+
return nil, fmt.Errorf("corruption detected: short output block %d, end %d != %d", i, offset, endsAt)
1135+
}
11331136
decoded += offset - dstEvery*i
11341137
err = br.close()
11351138
if err != nil {
@@ -1315,10 +1318,15 @@ func (d *Decoder) decompress4X8bitExactly(dst, src []byte) ([]byte, error) {
13151318
}
13161319

13171320
// Decode remaining.
1321+
remainBytes := dstEvery - (decoded / 4)
13181322
for i := range br {
13191323
offset := dstEvery * i
1324+
endsAt := offset + remainBytes
1325+
if endsAt > len(out) {
1326+
endsAt = len(out)
1327+
}
13201328
br := &br[i]
1321-
bitsLeft := int(br.off*8) + int(64-br.bitsRead)
1329+
bitsLeft := br.remaining()
13221330
for bitsLeft > 0 {
13231331
if br.finished() {
13241332
d.bufs.Put(buf)
@@ -1341,7 +1349,7 @@ func (d *Decoder) decompress4X8bitExactly(dst, src []byte) ([]byte, error) {
13411349
}
13421350
}
13431351
// end inline...
1344-
if offset >= len(out) {
1352+
if offset >= endsAt {
13451353
d.bufs.Put(buf)
13461354
return nil, errors.New("corruption detected: stream overrun 4")
13471355
}
@@ -1350,10 +1358,15 @@ func (d *Decoder) decompress4X8bitExactly(dst, src []byte) ([]byte, error) {
13501358
v := single[br.peekByteFast()].entry
13511359
nBits := uint8(v)
13521360
br.advance(nBits)
1353-
bitsLeft -= int(nBits)
1361+
bitsLeft -= uint(nBits)
13541362
out[offset] = uint8(v >> 8)
13551363
offset++
13561364
}
1365+
if offset != endsAt {
1366+
d.bufs.Put(buf)
1367+
return nil, fmt.Errorf("corruption detected: short output block %d, end %d != %d", i, offset, endsAt)
1368+
}
1369+
13571370
decoded += offset - dstEvery*i
13581371
err = br.close()
13591372
if err != nil {

zstd/bitreader.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ package zstd
77
import (
88
"encoding/binary"
99
"errors"
10+
"fmt"
1011
"io"
1112
"math/bits"
1213
)
@@ -132,6 +133,9 @@ func (b *bitReader) remain() uint {
132133
func (b *bitReader) close() error {
133134
// Release reference.
134135
b.in = nil
136+
if !b.finished() {
137+
return fmt.Errorf("%d extra bits on block, should be 0", b.remain())
138+
}
135139
if b.bitsRead > 64 {
136140
return io.ErrUnexpectedEOF
137141
}

0 commit comments

Comments
 (0)