diff --git a/crates/common/rlp/encode.rs b/crates/common/rlp/encode.rs index 1de929a8b9d..7835955dc49 100644 --- a/crates/common/rlp/encode.rs +++ b/crates/common/rlp/encode.rs @@ -139,6 +139,25 @@ impl RLPEncode for [u8] { buf.put_slice(self); } } + + fn length(&self) -> usize { + const U8_MAX_PLUS_ONE: usize = u8::MAX as usize + 1; + const U16_MAX_PLUS_ONE: usize = u16::MAX as usize + 1; + + match self.len() { + 0 => 1, // encodes to RLP_NULL + 1 if self[0] < 0x80 => 1, // `self` is its own encoding + 1..56 => 1 + self.len(), // single byte prefix + 56..U8_MAX_PLUS_ONE => 1 + 1 + self.len(), // single byte prefix + payload len bytes + U8_MAX_PLUS_ONE..U16_MAX_PLUS_ONE => 1 + 2 + self.len(), // single byte prefix + payload len bytes + _ => { + // fallback if `self` is longer than 2^16 - 1 bytes + let payload_len_bytes = + ((usize::BITS - self.len().leading_zeros()) as usize).div_ceil(8); + 1 + payload_len_bytes + self.len() + } + } + } } impl RLPEncode for [u8; N] { diff --git a/crates/common/trie/node.rs b/crates/common/trie/node.rs index 3bcb65d1bbd..ee670652f26 100644 --- a/crates/common/trie/node.rs +++ b/crates/common/trie/node.rs @@ -135,9 +135,22 @@ impl NodeRef { } pub fn compute_hash(&self) -> NodeHash { + *self.compute_hash_ref() + } + + pub fn compute_hash_ref(&self) -> &NodeHash { match self { - NodeRef::Node(node, hash) => *hash.get_or_init(|| node.compute_hash()), - NodeRef::Hash(hash) => *hash, + NodeRef::Node(node, hash) => hash.get_or_init(|| node.compute_hash()), + NodeRef::Hash(hash) => hash, + } + } + + pub fn memoize_hashes(&self) { + if let NodeRef::Node(node, hash) = &self + && hash.get().is_none() + { + node.memoize_hashes(); + let _ = hash.set(node.compute_hash()); } } @@ -294,12 +307,27 @@ impl Node { /// Computes the node's hash pub fn compute_hash(&self) -> NodeHash { + self.memoize_hashes(); match self { Node::Branch(n) => n.compute_hash(), Node::Extension(n) => n.compute_hash(), Node::Leaf(n) => n.compute_hash(), } } + + /// Recursively memoizes the hashes of all nodes of the subtrie that has + /// `self` as root (post-order traversal) + pub fn memoize_hashes(&self) { + match self { + Node::Branch(n) => { + for child in &n.choices { + child.memoize_hashes(); + } + } + Node::Extension(n) => n.child.memoize_hashes(), + _ => {} + } + } } /// Used as return type for `Node` remove operations that may resolve into either: diff --git a/crates/common/trie/node_hash.rs b/crates/common/trie/node_hash.rs index f7743bdf205..5a4f7480281 100644 --- a/crates/common/trie/node_hash.rs +++ b/crates/common/trie/node_hash.rs @@ -120,6 +120,14 @@ impl RLPEncode for NodeHash { fn encode(&self, buf: &mut dyn bytes::BufMut) { RLPEncode::encode(&Into::>::into(self), buf) } + + fn length(&self) -> usize { + match self { + NodeHash::Hashed(_) => 33, // 1 byte prefix + 32 bytes + NodeHash::Inline((_, 0)) => 1, // if empty then it's encoded to RLP_NULL + NodeHash::Inline((_, len)) => *len as usize, // already encoded + } + } } impl RLPDecode for NodeHash { diff --git a/crates/common/trie/rlp.rs b/crates/common/trie/rlp.rs index d09dc2853b8..1b4950572c2 100644 --- a/crates/common/trie/rlp.rs +++ b/crates/common/trie/rlp.rs @@ -3,8 +3,9 @@ use std::array; // Contains RLP encoding and decoding implementations for Trie Nodes // This encoding is only used to store the nodes in the DB, it is not the encoding used for hash computation use ethrex_rlp::{ + constants::RLP_NULL, decode::{RLPDecode, decode_bytes}, - encode::RLPEncode, + encode::{RLPEncode, encode_length}, error::RLPDecodeError, structs::{Decoder, Encoder}, }; @@ -14,18 +15,45 @@ use crate::{Nibbles, NodeHash}; impl RLPEncode for BranchNode { fn encode(&self, buf: &mut dyn bytes::BufMut) { - let mut encoder = Encoder::new(buf); + let value_len = <[u8] as RLPEncode>::length(&self.value); + let payload_len = self.choices.iter().fold(value_len, |acc, child| { + acc + RLPEncode::length(child.compute_hash_ref()) + }); + + encode_length(payload_len, buf); + for child in self.choices.iter() { + match child.compute_hash_ref() { + NodeHash::Hashed(hash) => hash.0.encode(buf), + NodeHash::Inline((_, 0)) => buf.put_u8(RLP_NULL), + NodeHash::Inline((encoded, len)) => buf.put_slice(&encoded[..*len as usize]), + } + } + <[u8] as RLPEncode>::encode(&self.value, buf); + } + + // Duplicated to prealloc the buffer and avoid calculating the payload length twice + fn encode_to_vec(&self) -> Vec { + let value_len = <[u8] as RLPEncode>::length(&self.value); + let choices_len = self.choices.iter().fold(0, |acc, child| { + acc + RLPEncode::length(child.compute_hash_ref()) + }); + let payload_len = choices_len + value_len; + + let mut buf: Vec = Vec::with_capacity(payload_len + 3); // 3 byte prefix headroom + + encode_length(payload_len, &mut buf); for child in self.choices.iter() { - match child.compute_hash() { - NodeHash::Hashed(hash) => encoder = encoder.encode_bytes(&hash.0), - child @ NodeHash::Inline(raw) if raw.1 != 0 => { - encoder = encoder.encode_raw(child.as_ref()) + match child.compute_hash_ref() { + NodeHash::Hashed(hash) => hash.0.encode(&mut buf), + NodeHash::Inline((_, 0)) => buf.push(RLP_NULL), + NodeHash::Inline((encoded, len)) => { + buf.extend_from_slice(&encoded[..*len as usize]) } - _ => encoder = encoder.encode_bytes(&[]), } } - encoder = encoder.encode_bytes(&self.value); - encoder.finish(); + <[u8] as RLPEncode>::encode(&self.value, &mut buf); + + buf } }