Skip to content

Commit 59173b5

Browse files
authored
Remove writer reference on close (#224)
Furthermore: * Clean up dictionary handling & test. * Fixes dictionary use in levels 1-6. Fixes #219
1 parent b949da4 commit 59173b5

File tree

2 files changed

+66
-64
lines changed

2 files changed

+66
-64
lines changed

flate/deflate.go

Lines changed: 13 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -59,15 +59,13 @@ type compressionLevel struct {
5959
// See https://blog.klauspost.com/rebalancing-deflate-compression-levels/
6060
var levels = []compressionLevel{
6161
{}, // 0
62-
// Level 1-4 uses specialized algorithm - values not used
62+
// Level 1-6 uses specialized algorithm - values not used
6363
{0, 0, 0, 0, 0, 1},
6464
{0, 0, 0, 0, 0, 2},
6565
{0, 0, 0, 0, 0, 3},
6666
{0, 0, 0, 0, 0, 4},
67-
// For levels 5-6 we don't bother trying with lazy matches.
68-
// Lazy matching is at least 30% slower, with 1.5% increase.
69-
{6, 0, 12, 8, 12, 5},
70-
{8, 0, 24, 16, 16, 6},
67+
{0, 0, 0, 0, 0, 5},
68+
{0, 0, 0, 0, 0, 6},
7169
// Levels 7-9 use increasingly more lazy matching
7270
// and increasingly stringent conditions for "good enough".
7371
{8, 8, 24, 16, skipNever, 7},
@@ -203,9 +201,8 @@ func (d *compressor) writeBlockSkip(tok *tokens, index int, eof bool) error {
203201
// This is much faster than doing a full encode.
204202
// Should only be used after a start/reset.
205203
func (d *compressor) fillWindow(b []byte) {
206-
// Do not fill window if we are in store-only mode,
207-
// use constant or Snappy compression.
208-
if d.level == 0 {
204+
// Do not fill window if we are in store-only or huffman mode.
205+
if d.level <= 0 {
209206
return
210207
}
211208
if d.fast != nil {
@@ -667,6 +664,7 @@ func (d *compressor) init(w io.Writer, level int) (err error) {
667664
default:
668665
return fmt.Errorf("flate: invalid compression level %d: want value in range [-2, 9]", level)
669666
}
667+
d.level = level
670668
return nil
671669
}
672670

@@ -720,6 +718,7 @@ func (d *compressor) close() error {
720718
return d.w.err
721719
}
722720
d.w.flush()
721+
d.w.reset(nil)
723722
return d.w.err
724723
}
725724

@@ -750,8 +749,7 @@ func NewWriter(w io.Writer, level int) (*Writer, error) {
750749
// can only be decompressed by a Reader initialized with the
751750
// same dictionary.
752751
func NewWriterDict(w io.Writer, level int, dict []byte) (*Writer, error) {
753-
dw := &dictWriter{w}
754-
zw, err := NewWriter(dw, level)
752+
zw, err := NewWriter(w, level)
755753
if err != nil {
756754
return nil, err
757755
}
@@ -760,14 +758,6 @@ func NewWriterDict(w io.Writer, level int, dict []byte) (*Writer, error) {
760758
return zw, err
761759
}
762760

763-
type dictWriter struct {
764-
w io.Writer
765-
}
766-
767-
func (w *dictWriter) Write(b []byte) (n int, err error) {
768-
return w.w.Write(b)
769-
}
770-
771761
// A Writer takes data written to it and writes the compressed
772762
// form of that data to an underlying writer (see NewWriter).
773763
type Writer struct {
@@ -805,11 +795,12 @@ func (w *Writer) Close() error {
805795
// the result of NewWriter or NewWriterDict called with dst
806796
// and w's level and dictionary.
807797
func (w *Writer) Reset(dst io.Writer) {
808-
if dw, ok := w.d.w.writer.(*dictWriter); ok {
798+
if len(w.dict) > 0 {
809799
// w was created with NewWriterDict
810-
dw.w = dst
811-
w.d.reset(dw)
812-
w.d.fillWindow(w.dict)
800+
w.d.reset(dst)
801+
if dst != nil {
802+
w.d.fillWindow(w.dict)
803+
}
813804
} else {
814805
// w was created with NewWriter
815806
w.d.reset(dst)

flate/deflate_test.go

Lines changed: 53 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -516,54 +516,65 @@ func TestWriterReset(t *testing.T) {
516516
t.Errorf("level %d Writer not reset after Reset", level)
517517
}
518518
}
519-
testResetOutput(t, func(w io.Writer) (*Writer, error) { return NewWriter(w, NoCompression) })
520-
testResetOutput(t, func(w io.Writer) (*Writer, error) { return NewWriter(w, DefaultCompression) })
521-
testResetOutput(t, func(w io.Writer) (*Writer, error) { return NewWriter(w, BestCompression) })
522-
testResetOutput(t, func(w io.Writer) (*Writer, error) { return NewWriter(w, ConstantCompression) })
523-
dict := []byte("we are the world")
524-
testResetOutput(t, func(w io.Writer) (*Writer, error) { return NewWriterDict(w, NoCompression, dict) })
525-
testResetOutput(t, func(w io.Writer) (*Writer, error) { return NewWriterDict(w, DefaultCompression, dict) })
526-
testResetOutput(t, func(w io.Writer) (*Writer, error) { return NewWriterDict(w, BestCompression, dict) })
527-
testResetOutput(t, func(w io.Writer) (*Writer, error) { return NewWriterDict(w, ConstantCompression, dict) })
528-
}
529-
530-
func testResetOutput(t *testing.T, newWriter func(w io.Writer) (*Writer, error)) {
531-
buf := new(bytes.Buffer)
532-
w, err := newWriter(buf)
533-
if err != nil {
534-
t.Fatalf("NewWriter: %v", err)
535-
}
536-
b := []byte("hello world")
537-
for i := 0; i < 1024; i++ {
538-
w.Write(b)
539-
}
540-
w.Close()
541-
out1 := buf.Bytes()
542519

543-
buf2 := new(bytes.Buffer)
544-
w.Reset(buf2)
545-
for i := 0; i < 1024; i++ {
546-
w.Write(b)
520+
for i := HuffmanOnly; i <= BestCompression; i++ {
521+
testResetOutput(t, fmt.Sprint("level-", i), func(w io.Writer) (*Writer, error) { return NewWriter(w, i) })
547522
}
548-
w.Close()
549-
out2 := buf2.Bytes()
550-
551-
if len(out1) != len(out2) {
552-
t.Errorf("got %d, expected %d bytes", len(out2), len(out1))
523+
dict := []byte(strings.Repeat("we are the world - how are you?", 3))
524+
for i := HuffmanOnly; i <= BestCompression; i++ {
525+
testResetOutput(t, fmt.Sprint("dict-level-", i), func(w io.Writer) (*Writer, error) { return NewWriterDict(w, i, dict) })
553526
}
554-
if bytes.Compare(out1, out2) != 0 {
555-
mm := 0
556-
for i, b := range out1[:len(out2)] {
557-
if b != out2[i] {
558-
t.Errorf("mismatch index %d: %02x, expected %02x", i, out2[i], b)
527+
for i := HuffmanOnly; i <= BestCompression; i++ {
528+
testResetOutput(t, fmt.Sprint("dict-reset-level-", i), func(w io.Writer) (*Writer, error) {
529+
w2, err := NewWriter(nil, i)
530+
if err != nil {
531+
return w2, err
559532
}
560-
mm++
561-
if mm == 10 {
562-
t.Fatal("Stopping")
533+
w2.ResetDict(w, dict)
534+
return w2, nil
535+
})
536+
}
537+
}
538+
539+
func testResetOutput(t *testing.T, name string, newWriter func(w io.Writer) (*Writer, error)) {
540+
t.Run(name, func(t *testing.T) {
541+
buf := new(bytes.Buffer)
542+
w, err := newWriter(buf)
543+
if err != nil {
544+
t.Fatalf("NewWriter: %v", err)
545+
}
546+
b := []byte("hello world - how are you doing?")
547+
for i := 0; i < 1024; i++ {
548+
w.Write(b)
549+
}
550+
w.Close()
551+
out1 := buf.Bytes()
552+
553+
buf2 := new(bytes.Buffer)
554+
w.Reset(buf2)
555+
for i := 0; i < 1024; i++ {
556+
w.Write(b)
557+
}
558+
w.Close()
559+
out2 := buf2.Bytes()
560+
561+
if len(out1) != len(out2) {
562+
t.Errorf("got %d, expected %d bytes", len(out2), len(out1))
563+
}
564+
if bytes.Compare(out1, out2) != 0 {
565+
mm := 0
566+
for i, b := range out1[:len(out2)] {
567+
if b != out2[i] {
568+
t.Errorf("mismatch index %d: %02x, expected %02x", i, out2[i], b)
569+
}
570+
mm++
571+
if mm == 10 {
572+
t.Fatal("Stopping")
573+
}
563574
}
564575
}
565-
}
566-
t.Logf("got %d bytes", len(out1))
576+
t.Logf("got %d bytes", len(out1))
577+
})
567578
}
568579

569580
// TestBestSpeed tests that round-tripping through deflate and then inflate

0 commit comments

Comments
 (0)