diff --git a/dot/state/offline_pruner.go b/dot/state/offline_pruner.go index 5a69307f7b..62ceb7b86c 100644 --- a/dot/state/offline_pruner.go +++ b/dot/state/offline_pruner.go @@ -121,7 +121,7 @@ func (p *OfflinePruner) SetBloomFilter() (err error) { return err } - tr.PopulateMerkleValues(tr.RootNode(), merkleValues) + trie.PopulateNodeHashes(tr.RootNode(), merkleValues) // get parent header of current block header, err = p.blockState.GetHeader(header.ParentHash) diff --git a/lib/trie/database.go b/lib/trie/database.go index fe908df6ec..c2ca4bde70 100644 --- a/lib/trie/database.go +++ b/lib/trie/database.go @@ -185,22 +185,35 @@ func (t *Trie) loadNode(db Database, n *Node) error { return nil } -// PopulateMerkleValues writes the Merkle value of each children of the node given -// as keys to the map merkleValues. -func (t *Trie) PopulateMerkleValues(n *Node, merkleValues map[string]struct{}) { - if n.Kind() != node.Branch { +// PopulateNodeHashes writes the node hash values of the node given and of +// all its descendant nodes as keys to the nodeHashes map. +// It is assumed the node and its descendant nodes have their Merkle value already +// computed. +func PopulateNodeHashes(n *Node, nodeHashes map[string]struct{}) { + if n == nil { return } - branch := n - for _, child := range branch.Children { - if child == nil { - continue - } + switch { + case len(n.MerkleValue) == 0: + // TODO remove once lazy loading of nodes is implemented + // https://github.com/ChainSafe/gossamer/issues/2838 + panic(fmt.Sprintf("node with key 0x%x has no Merkle value computed", n.Key)) + case len(n.MerkleValue) < 32: + // Inlined node where its Merkle value is its + // encoding and not the encoding hash digest. + return + } - merkleValues[string(child.MerkleValue)] = struct{}{} + nodeHashes[string(n.MerkleValue)] = struct{}{} - t.PopulateMerkleValues(child, merkleValues) + if n.Kind() == node.Leaf { + return + } + + branch := n + for _, child := range branch.Children { + PopulateNodeHashes(child, nodeHashes) } } diff --git a/lib/trie/database_test.go b/lib/trie/database_test.go index 41fb537953..341eaa2e88 100644 --- a/lib/trie/database_test.go +++ b/lib/trie/database_test.go @@ -158,7 +158,103 @@ func Test_Trie_WriteDirty_ClearPrefix(t *testing.T) { assert.Equal(t, trie.String(), trieFromDB.String()) } -func Test_Trie_GetFromDB(t *testing.T) { +func Test_PopulateNodeHashes(t *testing.T) { + t.Parallel() + + const ( + merkleValue32Zeroes = "00000000000000000000000000000000" + merkleValue32Ones = "11111111111111111111111111111111" + merkleValue32Twos = "22222222222222222222222222222222" + merkleValue32Threes = "33333333333333333333333333333333" + ) + + testCases := map[string]struct { + node *Node + nodeHashes map[string]struct{} + panicValue interface{} + }{ + "nil node": { + nodeHashes: map[string]struct{}{}, + }, + "inlined leaf node": { + node: &Node{MerkleValue: []byte("a")}, + nodeHashes: map[string]struct{}{}, + }, + "leaf node": { + node: &Node{MerkleValue: []byte(merkleValue32Zeroes)}, + nodeHashes: map[string]struct{}{ + merkleValue32Zeroes: {}, + }, + }, + "leaf node without Merkle value": { + node: &Node{Key: []byte{1}, SubValue: []byte{2}}, + panicValue: "node with key 0x01 has no Merkle value computed", + }, + "inlined branch node": { + node: &Node{ + MerkleValue: []byte("a"), + Children: padRightChildren([]*Node{ + {MerkleValue: []byte("b")}, + }), + }, + nodeHashes: map[string]struct{}{}, + }, + "branch node": { + node: &Node{ + MerkleValue: []byte(merkleValue32Zeroes), + Children: padRightChildren([]*Node{ + {MerkleValue: []byte(merkleValue32Ones)}, + }), + }, + nodeHashes: map[string]struct{}{ + merkleValue32Zeroes: {}, + merkleValue32Ones: {}, + }, + }, + "nested branch node": { + node: &Node{ + MerkleValue: []byte(merkleValue32Zeroes), + Children: padRightChildren([]*Node{ + {MerkleValue: []byte(merkleValue32Ones)}, + { + MerkleValue: []byte(merkleValue32Twos), + Children: padRightChildren([]*Node{ + {MerkleValue: []byte(merkleValue32Threes)}, + }), + }, + }), + }, + nodeHashes: map[string]struct{}{ + merkleValue32Zeroes: {}, + merkleValue32Ones: {}, + merkleValue32Twos: {}, + merkleValue32Threes: {}, + }, + }, + } + + for name, testCase := range testCases { + testCase := testCase + t.Run(name, func(t *testing.T) { + t.Parallel() + + nodeHashes := make(map[string]struct{}) + + if testCase.panicValue != nil { + assert.PanicsWithValue(t, testCase.panicValue, func() { + PopulateNodeHashes(testCase.node, nodeHashes) + }) + return + } + + PopulateNodeHashes(testCase.node, nodeHashes) + + assert.Equal(t, testCase.nodeHashes, nodeHashes) + }) + } +} + +func Test_GetFromDB(t *testing.T) { t.Parallel() const size = 1000