Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
119 changes: 84 additions & 35 deletions huff0/decompress.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,7 @@ type dTable struct {

// dEntrySingle is one entry of the single-symbol decoding table: looking up
// the peeked bits yields the decoded symbol and how many bits to consume.
//
// NOTE(review): this listing is a diff view without +/- markers. The hunk
// header (-15,8 +15,7) suggests byte and nBits are the removed fields and
// entry is their packed replacement — confirm against the actual file.
type dEntrySingle struct {
// byte is the decoded symbol returned by the decoder.
byte uint8
// nBits is the number of bits added to the bit reader's bitsRead counter
// when this entry is used.
nBits uint8
// entry packs both values into one word: the low 8 bits are the bit count
// (read as uint8(entry)) and the high 8 bits are the decoded symbol
// (read as uint8(entry >> 8)).
entry uint16
}

// double-symbols decoding
Expand Down Expand Up @@ -76,14 +75,15 @@ func ReadTable(in []byte, s *Scratch) (s2 *Scratch, remain []byte, err error) {
}

// collect weight stats
var rankStats [tableLogMax + 1]uint32
var rankStats [16]uint32
weightTotal := uint32(0)
for _, v := range s.huffWeight[:s.symbolLen] {
if v > tableLogMax {
return s, nil, errors.New("corrupt input: weight too large")
}
rankStats[v]++
weightTotal += (1 << (v & 15)) >> 1
v2 := v & 15
rankStats[v2]++
weightTotal += (1 << v2) >> 1
}
if weightTotal == 0 {
return s, nil, errors.New("corrupt input: weights zero")
Expand Down Expand Up @@ -134,15 +134,17 @@ func ReadTable(in []byte, s *Scratch) (s2 *Scratch, remain []byte, err error) {
if len(s.dt.single) != tSize {
s.dt.single = make([]dEntrySingle, tSize)
}

for n, w := range s.huffWeight[:s.symbolLen] {
if w == 0 {
continue
}
length := (uint32(1) << w) >> 1
d := dEntrySingle{
byte: uint8(n),
nBits: s.actualTableLog + 1 - w,
entry: uint16(s.actualTableLog+1-w) | (uint16(n) << 8),
}
for u := rankStats[w]; u < rankStats[w]+length; u++ {
s.dt.single[u] = d
single := s.dt.single[rankStats[w] : rankStats[w]+length]
for i := range single {
single[i] = d
}
rankStats[w] += length
}
Expand All @@ -167,12 +169,12 @@ func (s *Scratch) Decompress1X(in []byte) (out []byte, err error) {
decode := func() byte {
val := br.peekBitsFast(s.actualTableLog) /* note : actualTableLog >= 1 */
v := s.dt.single[val]
br.bitsRead += v.nBits
return v.byte
br.bitsRead += uint8(v.entry)
return uint8(v.entry >> 8)
}
hasDec := func(v dEntrySingle) byte {
br.bitsRead += v.nBits
return v.byte
br.bitsRead += uint8(v.entry)
return uint8(v.entry >> 8)
}

// Avoid bounds check by always having full sized table.
Expand Down Expand Up @@ -269,8 +271,8 @@ func (s *Scratch) Decompress4X(in []byte, dstSize int) (out []byte, err error) {
decode := func(br *bitReader) byte {
val := br.peekBitsFast(s.actualTableLog) /* note : actualTableLog >= 1 */
v := single[val&tlMask]
br.bitsRead += v.nBits
return v.byte
br.bitsRead += uint8(v.entry)
return uint8(v.entry >> 8)
}

// Use temp table to avoid bound checks/append penalty.
Expand All @@ -283,20 +285,67 @@ func (s *Scratch) Decompress4X(in []byte, dstSize int) (out []byte, err error) {
bigloop:
for {
for i := range br {
if br[i].off < 4 {
br := &br[i]
if br.off < 4 {
break bigloop
}
br[i].fillFast()
br.fillFast()
}

{
const stream = 0
val := br[stream].peekBitsFast(s.actualTableLog)
v := single[val&tlMask]
br[stream].bitsRead += uint8(v.entry)

val2 := br[stream].peekBitsFast(s.actualTableLog)
v2 := single[val2&tlMask]
tmp[off+bufoff*stream+1] = uint8(v2.entry >> 8)
tmp[off+bufoff*stream] = uint8(v.entry >> 8)
br[stream].bitsRead += uint8(v2.entry)
}

{
const stream = 1
val := br[stream].peekBitsFast(s.actualTableLog)
v := single[val&tlMask]
br[stream].bitsRead += uint8(v.entry)

val2 := br[stream].peekBitsFast(s.actualTableLog)
v2 := single[val2&tlMask]
tmp[off+bufoff*stream+1] = uint8(v2.entry >> 8)
tmp[off+bufoff*stream] = uint8(v.entry >> 8)
br[stream].bitsRead += uint8(v2.entry)
}

{
const stream = 2
val := br[stream].peekBitsFast(s.actualTableLog)
v := single[val&tlMask]
br[stream].bitsRead += uint8(v.entry)

val2 := br[stream].peekBitsFast(s.actualTableLog)
v2 := single[val2&tlMask]
tmp[off+bufoff*stream+1] = uint8(v2.entry >> 8)
tmp[off+bufoff*stream] = uint8(v.entry >> 8)
br[stream].bitsRead += uint8(v2.entry)
}
tmp[off] = decode(&br[0])
tmp[off+bufoff] = decode(&br[1])
tmp[off+bufoff*2] = decode(&br[2])
tmp[off+bufoff*3] = decode(&br[3])
tmp[off+1] = decode(&br[0])
tmp[off+1+bufoff] = decode(&br[1])
tmp[off+1+bufoff*2] = decode(&br[2])
tmp[off+1+bufoff*3] = decode(&br[3])

{
const stream = 3
val := br[stream].peekBitsFast(s.actualTableLog)
v := single[val&tlMask]
br[stream].bitsRead += uint8(v.entry)

val2 := br[stream].peekBitsFast(s.actualTableLog)
v2 := single[val2&tlMask]
tmp[off+bufoff*stream+1] = uint8(v2.entry >> 8)
tmp[off+bufoff*stream] = uint8(v.entry >> 8)
br[stream].bitsRead += uint8(v2.entry)
}

off += 2

if off == bufoff {
if bufoff > dstEvery {
return nil, errors.New("corruption detected: stream overrun 1")
Expand Down Expand Up @@ -367,7 +416,7 @@ func (s *Scratch) matches(ct cTable, w io.Writer) {
broken++
if enc.nBits == 0 {
for _, dec := range dt {
if dec.byte == byte(sym) {
if uint8(dec.entry>>8) == byte(sym) {
fmt.Fprintf(w, "symbol %x has decoder, but no encoder\n", sym)
errs++
break
Expand All @@ -383,12 +432,12 @@ func (s *Scratch) matches(ct cTable, w io.Writer) {
top := enc.val << ub
// decoder looks at top bits.
dec := dt[top]
if dec.nBits != enc.nBits {
fmt.Fprintf(w, "symbol 0x%x bit size mismatch (enc: %d, dec:%d).\n", sym, enc.nBits, dec.nBits)
if uint8(dec.entry) != enc.nBits {
fmt.Fprintf(w, "symbol 0x%x bit size mismatch (enc: %d, dec:%d).\n", sym, enc.nBits, uint8(dec.entry))
errs++
}
if dec.byte != uint8(sym) {
fmt.Fprintf(w, "symbol 0x%x decoder output mismatch (enc: %d, dec:%d).\n", sym, sym, dec.byte)
if uint8(dec.entry>>8) != uint8(sym) {
fmt.Fprintf(w, "symbol 0x%x decoder output mismatch (enc: %d, dec:%d).\n", sym, sym, uint8(dec.entry>>8))
errs++
}
if errs > 0 {
Expand All @@ -399,12 +448,12 @@ func (s *Scratch) matches(ct cTable, w io.Writer) {
for i := uint16(0); i < (1 << ub); i++ {
vval := top | i
dec := dt[vval]
if dec.nBits != enc.nBits {
fmt.Fprintf(w, "symbol 0x%x bit size mismatch (enc: %d, dec:%d).\n", vval, enc.nBits, dec.nBits)
if uint8(dec.entry) != enc.nBits {
fmt.Fprintf(w, "symbol 0x%x bit size mismatch (enc: %d, dec:%d).\n", vval, enc.nBits, uint8(dec.entry))
errs++
}
if dec.byte != uint8(sym) {
fmt.Fprintf(w, "symbol 0x%x decoder output mismatch (enc: %d, dec:%d).\n", vval, sym, dec.byte)
if uint8(dec.entry>>8) != uint8(sym) {
fmt.Fprintf(w, "symbol 0x%x decoder output mismatch (enc: %d, dec:%d).\n", vval, sym, uint8(dec.entry>>8))
errs++
}
if errs > 20 {
Expand Down
38 changes: 38 additions & 0 deletions huff0/decompress_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -478,3 +478,41 @@ func BenchmarkDecompress4XNoTable(b *testing.B) {
})
}
}

// BenchmarkDecompress4XTable measures 4X decompression including the cost of
// rebuilding the decoding table: every iteration calls ReadTable on the
// compressed block and then Decompress4X on the remaining payload.
// Throughput is reported relative to the uncompressed input size.
func BenchmarkDecompress4XTable(b *testing.B) {
	for _, tt := range testfiles {
		test := tt
		// Inputs that are expected to fail 4X compression cannot be benchmarked.
		if test.err4X != nil {
			continue
		}
		b.Run(test.name, func(b *testing.B) {
			s := &Scratch{}
			s.Reuse = ReusePolicyNone
			buf0, err := test.fn()
			if err != nil {
				b.Fatal(err)
			}
			if len(buf0) > BlockSizeMax {
				buf0 = buf0[:BlockSizeMax]
			}
			compressed, _, err := Compress4X(buf0, s)
			// Cases with an expected 4X error were skipped above, so any error
			// here is unexpected. (The 1X expectation is irrelevant for a
			// Compress4X call.)
			if err != nil {
				b.Fatal("unexpected error:", err)
			}
			s.Out = nil
			b.ResetTimer()
			b.ReportAllocs()
			b.SetBytes(int64(len(buf0)))
			for i := 0; i < b.N; i++ {
				// Re-read the table each iteration so its cost is part of the
				// measurement.
				dec, remain, err := ReadTable(compressed, s)
				if err != nil {
					b.Fatal(err)
				}
				if _, err = dec.Decompress4X(remain, len(buf0)); err != nil {
					b.Fatal(err)
				}
			}
		})
	}
}
Binary file modified zstd/testdata/benchdecoder.zip
Binary file not shown.