Skip to content

Commit ac85a6f

Browse files
authored
rlp: add back Iterator.Count, with fixes (#33841)
I removed `Iterator.Count` in #33840 because it appeared to be unused and did not provide the documented invariant: the returned count should always be an upper bound on the number of iterations allowed by `Next`. In order to make `Count` work, the semantics of `CountValues` have to change so that, on error, it returns the number of items up to and including the invalid one. I have reviewed all call sites of `CountValues` to assess whether changing this is safe. There aren't that many, and the only call that doesn't check the error and return is in the trie node parser, `trie.decodeNodeUnsafe`. There, we distinguish the node type based on the number of items, and it previously returned an error for item count zero. In order to avoid any potential issue that could result from this change, I'm adding an error check in that function, though it isn't strictly necessary.
1 parent 4f38a76 commit ac85a6f

5 files changed

Lines changed: 85 additions & 7 deletions

File tree

rlp/iterator.go

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,15 @@ func (it *Iterator) Value() []byte {
6969
return it.next
7070
}
7171

72+
// Count returns the remaining number of items.
73+
// Note this is O(n). If the list data is invalid, counting stops at, and includes, the first invalid item.
74+
// The returned count is always an upper bound on the remaining items
75+
// that will be visited by the iterator.
76+
func (it *Iterator) Count() int {
77+
count, _ := CountValues(it.data)
78+
return count
79+
}
80+
7281
// Offset returns the offset of the current value into the list data.
7382
func (it *Iterator) Offset() int {
7483
return it.offset - len(it.next)

rlp/iterator_test.go

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
package rlp
1818

1919
import (
20+
"io"
2021
"testing"
2122

2223
"github.com/ethereum/go-ethereum/common/hexutil"
@@ -54,6 +55,9 @@ func TestIterator(t *testing.T) {
5455
if err != nil {
5556
t.Fatal(err)
5657
}
58+
if c := txit.Count(); c != 2 {
59+
t.Fatal("wrong Count:", c)
60+
}
5761
var i = 0
5862
for txit.Next() {
5963
if txit.err != nil {
@@ -65,3 +69,65 @@ func TestIterator(t *testing.T) {
6569
t.Errorf("count wrong, expected %d got %d", i, exp)
6670
}
6771
}
72+
73+
func TestIteratorErrors(t *testing.T) {
74+
tests := []struct {
75+
input []byte
76+
wantCount int // expected Count before iterating
77+
wantErr error
78+
}{
79+
// Second item string header claims 3 bytes content, but only 2 remain.
80+
{unhex("C4 01 83AABB"), 2, ErrValueTooLarge},
81+
// Second item truncated: B9 requires 2 size bytes, none available.
82+
{unhex("C2 01 B9"), 2, io.ErrUnexpectedEOF},
83+
// 0x05 should be encoded directly, not as 81 05.
84+
{unhex("C3 01 8105"), 2, ErrCanonSize},
85+
// Long-form string header B8 used for 1-byte content (< 56).
86+
{unhex("C4 01 B801AA"), 2, ErrCanonSize},
87+
// Long-form list header F8 used for 1-byte content (< 56).
88+
{unhex("C4 01 F80101"), 2, ErrCanonSize},
89+
}
90+
for _, tt := range tests {
91+
it, err := NewListIterator(tt.input)
92+
if err != nil {
93+
t.Fatal("NewListIterator error:", err)
94+
}
95+
if c := it.Count(); c != tt.wantCount {
96+
t.Fatalf("%x: Count = %d, want %d", tt.input, c, tt.wantCount)
97+
}
98+
n := 0
99+
for it.Next() {
100+
if it.Err() != nil {
101+
break
102+
}
103+
n++
104+
}
105+
if wantN := tt.wantCount - 1; n != wantN {
106+
t.Fatalf("%x: got %d valid items, want %d", tt.input, n, wantN)
107+
}
108+
if it.Err() != tt.wantErr {
109+
t.Fatalf("%x: got error %v, want %v", tt.input, it.Err(), tt.wantErr)
110+
}
111+
if it.Next() {
112+
t.Fatalf("%x: Next returned true after error", tt.input)
113+
}
114+
}
115+
}
116+
117+
func FuzzIteratorCount(f *testing.F) {
118+
examples := [][]byte{unhex("010203"), unhex("018142"), unhex("01830202")}
119+
for _, e := range examples {
120+
f.Add(e)
121+
}
122+
f.Fuzz(func(t *testing.T, in []byte) {
123+
it := newIterator(in, 0)
124+
count := it.Count()
125+
i := 0
126+
for it.Next() {
127+
i++
128+
}
129+
if i != count {
130+
t.Fatalf("%x: count %d not equal to %d iterations", in, count, i)
131+
}
132+
})
133+
}

rlp/raw.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -285,7 +285,7 @@ func CountValues(b []byte) (int, error) {
285285
for ; len(b) > 0; i++ {
286286
_, tagsize, size, err := readKind(b)
287287
if err != nil {
288-
return 0, err
288+
return i + 1, err
289289
}
290290
b = b[tagsize+size:]
291291
}

rlp/raw_test.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -288,9 +288,9 @@ func TestCountValues(t *testing.T) {
288288
{"820101 820202 8403030303 04", 4, nil},
289289

290290
// size errors
291-
{"8142", 0, ErrCanonSize},
292-
{"01 01 8142", 0, ErrCanonSize},
293-
{"02 84020202", 0, ErrValueTooLarge},
291+
{"8142", 1, ErrCanonSize},
292+
{"01 01 8142", 3, ErrCanonSize},
293+
{"02 84020202", 2, ErrValueTooLarge},
294294

295295
{
296296
input: "A12000BF49F440A1CD0527E4D06E2765654C0F56452257516D793A9B8D604DCFDF2AB853F851808D10000000000000000000000000A056E81F171BCC55A6FF8345E692C0F86E5B48E01B996CADC001622FB5E363B421A0C5D2460186F7233C927E7DB2DCC703C0E500B653CA82273B7BFAD8045D85A470",

trie/node.go

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -161,11 +161,14 @@ func decodeNodeUnsafe(hash, buf []byte) (node, error) {
161161
if err != nil {
162162
return nil, fmt.Errorf("decode error: %v", err)
163163
}
164-
switch c, _ := rlp.CountValues(elems); c {
165-
case 2:
164+
c, err := rlp.CountValues(elems)
165+
switch {
166+
case err != nil:
167+
return nil, fmt.Errorf("invalid node list: %v", err)
168+
case c == 2:
166169
n, err := decodeShort(hash, elems)
167170
return n, wrapError(err, "short")
168-
case 17:
171+
case c == 17:
169172
n, err := decodeFull(hash, elems)
170173
return n, wrapError(err, "full")
171174
default:

0 commit comments

Comments
 (0)