From 748d1c171d74fbf6b6051fd629d3c2204dd930e3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?P=C3=A9ter=20Szil=C3=A1gyi?= Date: Thu, 19 May 2016 13:24:14 +0300 Subject: [PATCH] core, core/state, trie: enterprise hand-tuned multi-level caching --- core/blockchain.go | 7 +- core/state/statedb.go | 22 ++++ trie/iterator.go | 24 +++-- trie/node.go | 60 +++++++---- trie/proof.go | 8 +- trie/secure_trie.go | 4 +- trie/sync.go | 13 +-- trie/trie.go | 237 ++++++++++++++++++++++++++---------------- trie/trie_test.go | 2 +- 9 files changed, 244 insertions(+), 133 deletions(-) diff --git a/core/blockchain.go b/core/blockchain.go index 171a49e53d..bd84adfe9a 100644 --- a/core/blockchain.go +++ b/core/blockchain.go @@ -819,6 +819,7 @@ func (self *BlockChain) InsertChain(chain types.Blocks) (int, error) { tstart = time.Now() nonceChecked = make([]bool, len(chain)) + statedb *state.StateDB ) // Start the parallel nonce verifier. @@ -885,7 +886,11 @@ func (self *BlockChain) InsertChain(chain types.Blocks) (int, error) { // Create a new statedb using the parent block and report an // error if it fails. - statedb, err := state.New(self.GetBlock(block.ParentHash()).Root(), self.chainDb) + if statedb == nil { + statedb, err = state.New(self.GetBlock(block.ParentHash()).Root(), self.chainDb) + } else { + err = statedb.Reset(chain[i-1].Root()) + } if err != nil { reportBlock(block, err) return i, err diff --git a/core/state/statedb.go b/core/state/statedb.go index 22ffa36a06..cfcb82d974 100644 --- a/core/state/statedb.go +++ b/core/state/statedb.go @@ -68,6 +68,28 @@ func New(root common.Hash, db ethdb.Database) (*StateDB, error) { }, nil } +// Reset clears out all emphemeral state objects from the state db, but keeps +// the underlying state trie to avoid reloading data for the next operations. +func (self *StateDB) Reset(root common.Hash) error { + var ( + err error + tr = self.trie + ) + if self.trie.Hash() != root { + if tr, err = trie.NewSecure(root, self.db); err != nil { + return err + } + } + *self = StateDB{ + db: self.db, + trie: tr, + stateObjects: make(map[string]*StateObject), + refund: new(big.Int), + logs: make(map[common.Hash]vm.Logs), + } + return nil +} + func (self *StateDB) StartRecord(thash, bhash common.Hash, ti int) { self.thash = thash self.bhash = bhash diff --git a/trie/iterator.go b/trie/iterator.go index ceef52ec8d..88c4cee7fa 100644 --- a/trie/iterator.go +++ b/trie/iterator.go @@ -62,7 +62,7 @@ func (self *Iterator) next(node interface{}, key []byte, isIterStart bool) []byt switch node := node.(type) { case fullNode: if len(key) > 0 { - k := self.next(node[key[0]], key[1:], isIterStart) + k := self.next(node.Children[key[0]], key[1:], isIterStart) if k != nil { return append([]byte{key[0]}, k...) } @@ -74,7 +74,7 @@ func (self *Iterator) next(node interface{}, key []byte, isIterStart bool) []byt } for i := r; i < 16; i++ { - k := self.key(node[i]) + k := self.key(node.Children[i]) if k != nil { return append([]byte{i}, k...) } @@ -130,12 +130,12 @@ func (self *Iterator) key(node interface{}) []byte { } return append(k, self.key(node.Val)...) case fullNode: - if node[16] != nil { - self.Value = node[16].(valueNode) + if node.Children[16] != nil { + self.Value = node.Children[16].(valueNode) return []byte{16} } for i := 0; i < 16; i++ { - k := self.key(node[i]) + k := self.key(node.Children[i]) if k != nil { return append([]byte{byte(i)}, k...) } @@ -175,7 +175,7 @@ type NodeIterator struct { // NewNodeIterator creates an post-order trie iterator. func NewNodeIterator(trie *Trie) *NodeIterator { - if bytes.Compare(trie.Root(), emptyRoot.Bytes()) == 0 { + if trie.Hash() == emptyState { return new(NodeIterator) } return &NodeIterator{trie: trie} @@ -205,9 +205,11 @@ func (it *NodeIterator) step() error { } // Initialize the iterator if we've just started, or pop off the old node otherwise if len(it.stack) == 0 { - it.stack = append(it.stack, &nodeIteratorState{node: it.trie.root, child: -1}) + // Always start with a collapsed root + root := it.trie.Hash() + it.stack = append(it.stack, &nodeIteratorState{node: hashNode(root[:]), child: -1}) if it.stack[0].node == nil { - return fmt.Errorf("root node missing: %x", it.trie.Root()) + return fmt.Errorf("root node missing: %x", it.trie.Hash()) } } else { it.stack = it.stack[:len(it.stack)-1] @@ -225,11 +227,11 @@ func (it *NodeIterator) step() error { } if node, ok := parent.node.(fullNode); ok { // Full node, traverse all children, then the node itself - if parent.child >= len(node) { + if parent.child >= len(node.Children) { break } - for parent.child++; parent.child < len(node); parent.child++ { - if current := node[parent.child]; current != nil { + for parent.child++; parent.child < len(node.Children); parent.child++ { + if current := node.Children[parent.child]; current != nil { it.stack = append(it.stack, &nodeIteratorState{node: current, parent: ancestor, child: -1}) break } diff --git a/trie/node.go b/trie/node.go index 0bfa21dc43..b97d370be4 100644 --- a/trie/node.go +++ b/trie/node.go @@ -29,18 +29,36 @@ var indices = []string{"0", "1", "2", "3", "4", "5", "6", "7", "8", "9", "a", "b type node interface { fstring(string) string + cache() (hashNode, bool) } type ( - fullNode [17]node + fullNode struct { + Children [17]node // Actual trie node data to encode/decode (needs custom encoder) + hash hashNode // Cached hash of the node to prevent rehashing (may be nil) + dirty bool // Cached flag whether the node's new or already stored + } shortNode struct { - Key []byte - Val node + Key []byte + Val node + hash hashNode // Cached hash of the node to prevent rehashing (may be nil) + dirty bool // Cached flag whether the node's new or already stored } hashNode []byte valueNode []byte ) +// EncodeRLP encodes a full node into the consensus RLP format. +func (n fullNode) EncodeRLP(w io.Writer) error { + return rlp.Encode(w, n.Children) +} + +// Cache accessors to retrieve precalculated values (avoid lengthy type switches). +func (n fullNode) cache() (hashNode, bool) { return n.hash, n.dirty } +func (n shortNode) cache() (hashNode, bool) { return n.hash, n.dirty } +func (n hashNode) cache() (hashNode, bool) { return nil, true } +func (n valueNode) cache() (hashNode, bool) { return nil, true } + // Pretty printing. func (n fullNode) String() string { return n.fstring("") } func (n shortNode) String() string { return n.fstring("") } @@ -49,7 +67,7 @@ func (n valueNode) String() string { return n.fstring("") } func (n fullNode) fstring(ind string) string { resp := fmt.Sprintf("[\n%s ", ind) - for i, node := range n { + for i, node := range n.Children { if node == nil { resp += fmt.Sprintf("%s: ", indices[i]) } else { @@ -68,16 +86,16 @@ func (n valueNode) fstring(ind string) string { return fmt.Sprintf("%x ", []byte(n)) } -func mustDecodeNode(dbkey, buf []byte) node { - n, err := decodeNode(buf) +func mustDecodeNode(hash, buf []byte) node { + n, err := decodeNode(hash, buf) if err != nil { - panic(fmt.Sprintf("node %x: %v", dbkey, err)) + panic(fmt.Sprintf("node %x: %v", hash, err)) } return n } // decodeNode parses the RLP encoding of a trie node. -func decodeNode(buf []byte) (node, error) { +func decodeNode(hash, buf []byte) (node, error) { if len(buf) == 0 { return nil, io.ErrUnexpectedEOF } @@ -87,18 +105,18 @@ func decodeNode(buf []byte) (node, error) { } switch c, _ := rlp.CountValues(elems); c { case 2: - n, err := decodeShort(elems) + n, err := decodeShort(hash, buf, elems) return n, wrapError(err, "short") case 17: - n, err := decodeFull(elems) + n, err := decodeFull(hash, buf, elems) return n, wrapError(err, "full") default: return nil, fmt.Errorf("invalid number of list elements: %v", c) } } -func decodeShort(buf []byte) (node, error) { - kbuf, rest, err := rlp.SplitString(buf) +func decodeShort(hash, buf, elems []byte) (node, error) { + kbuf, rest, err := rlp.SplitString(elems) if err != nil { return nil, err } @@ -109,30 +127,30 @@ func decodeShort(buf []byte) (node, error) { if err != nil { return nil, fmt.Errorf("invalid value node: %v", err) } - return shortNode{key, valueNode(val)}, nil + return shortNode{key, valueNode(val), hash, false}, nil } r, _, err := decodeRef(rest) if err != nil { return nil, wrapError(err, "val") } - return shortNode{key, r}, nil + return shortNode{key, r, hash, false}, nil } -func decodeFull(buf []byte) (fullNode, error) { - var n fullNode +func decodeFull(hash, buf, elems []byte) (fullNode, error) { + n := fullNode{hash: hash} for i := 0; i < 16; i++ { - cld, rest, err := decodeRef(buf) + cld, rest, err := decodeRef(elems) if err != nil { return n, wrapError(err, fmt.Sprintf("[%d]", i)) } - n[i], buf = cld, rest + n.Children[i], elems = cld, rest } - val, _, err := rlp.SplitString(buf) + val, _, err := rlp.SplitString(elems) if err != nil { return n, err } if len(val) > 0 { - n[16] = valueNode(val) + n.Children[16] = valueNode(val) } return n, nil } @@ -152,7 +170,7 @@ func decodeRef(buf []byte) (node, []byte, error) { err := fmt.Errorf("oversized embedded node (size is %d bytes, want size < %d)", size, hashLen) return nil, buf, err } - n, err := decodeNode(buf) + n, err := decodeNode(nil, buf) return n, rest, err case kind == rlp.String && len(val) == 0: // empty node diff --git a/trie/proof.go b/trie/proof.go index 37a70fb34d..5135de0473 100644 --- a/trie/proof.go +++ b/trie/proof.go @@ -54,7 +54,7 @@ func (t *Trie) Prove(key []byte) []rlp.RawValue { } nodes = append(nodes, n) case fullNode: - tn = n[key[0]] + tn = n.Children[key[0]] key = key[1:] nodes = append(nodes, n) case hashNode: @@ -77,7 +77,7 @@ func (t *Trie) Prove(key []byte) []rlp.RawValue { for i, n := range nodes { // Don't bother checking for errors here since hasher panics // if encoding doesn't work and we're not writing to any database. - n, _ = t.hasher.replaceChildren(n, nil) + n, _, _ = t.hasher.hashChildren(n, nil) hn, _ := t.hasher.store(n, nil, false) if _, ok := hn.(hashNode); ok || i == 0 { // If the node's database encoding is a hash (or is the @@ -103,7 +103,7 @@ func VerifyProof(rootHash common.Hash, key []byte, proof []rlp.RawValue) (value if !bytes.Equal(sha.Sum(nil), wantHash) { return nil, fmt.Errorf("bad proof node %d: hash mismatch", i) } - n, err := decodeNode(buf) + n, err := decodeNode(wantHash, buf) if err != nil { return nil, fmt.Errorf("bad proof node %d: %v", i, err) } @@ -139,7 +139,7 @@ func get(tn node, key []byte) ([]byte, node) { tn = n.Val key = key[len(n.Key):] case fullNode: - tn = n[key[0]] + tn = n.Children[key[0]] key = key[1:] case hashNode: return key, n diff --git a/trie/secure_trie.go b/trie/secure_trie.go index be7defe83b..1d027c1027 100644 --- a/trie/secure_trie.go +++ b/trie/secure_trie.go @@ -162,11 +162,11 @@ func (t *SecureTrie) CommitTo(db DatabaseWriter) (root common.Hash, err error) { } t.secKeyCache = make(map[string][]byte) } - n, err := t.hashRoot(db) + n, clean, err := t.hashRoot(db) if err != nil { return (common.Hash{}), err } - t.root = n + t.root = clean return common.BytesToHash(n.(hashNode)), nil } diff --git a/trie/sync.go b/trie/sync.go index d55399d06b..a35478f837 100644 --- a/trie/sync.go +++ b/trie/sync.go @@ -75,8 +75,9 @@ func (s *TrieSync) AddSubTrie(root common.Hash, depth int, parent common.Hash, c if root == emptyRoot { return } - blob, _ := s.database.Get(root.Bytes()) - if local, err := decodeNode(blob); local != nil && err == nil { + key := root.Bytes() + blob, _ := s.database.Get(key) + if local, err := decodeNode(key, blob); local != nil && err == nil { return } // Assemble the new sub-trie sync request @@ -152,7 +153,7 @@ func (s *TrieSync) Process(results []SyncResult) (int, error) { continue } // Decode the node data content and update the request - node, err := decodeNode(item.Data) + node, err := decodeNode(item.Hash[:], item.Data) if err != nil { return i, err } @@ -213,9 +214,9 @@ func (s *TrieSync) children(req *request) ([]*request, error) { }} case fullNode: for i := 0; i < 17; i++ { - if node[i] != nil { + if node.Children[i] != nil { children = append(children, child{ - node: &node[i], + node: &node.Children[i], depth: req.depth + 1, }) } @@ -238,7 +239,7 @@ func (s *TrieSync) children(req *request) ([]*request, error) { if node, ok := (*child.node).(hashNode); ok { // Try to resolve the node from the local database blob, _ := s.database.Get(node) - if local, err := decodeNode(blob); local != nil && err == nil { + if local, err := decodeNode(node[:], blob); local != nil && err == nil { *child.node = local continue } diff --git a/trie/trie.go b/trie/trie.go index cc5dcf2a65..a530e7b2a3 100644 --- a/trie/trie.go +++ b/trie/trie.go @@ -129,7 +129,7 @@ func (t *Trie) TryGet(key []byte) ([]byte, error) { tn = n.Val pos += len(n.Key) case fullNode: - tn = n[key[pos]] + tn = n.Children[key[pos]] pos++ case nil: return nil, nil @@ -169,13 +169,13 @@ func (t *Trie) Update(key, value []byte) { func (t *Trie) TryUpdate(key, value []byte) error { k := compactHexDecode(key) if len(value) != 0 { - n, err := t.insert(t.root, nil, k, valueNode(value)) + _, n, err := t.insert(t.root, nil, k, valueNode(value)) if err != nil { return err } t.root = n } else { - n, err := t.delete(t.root, nil, k) + _, n, err := t.delete(t.root, nil, k) if err != nil { return err } @@ -184,9 +184,12 @@ func (t *Trie) TryUpdate(key, value []byte) error { return nil } -func (t *Trie) insert(n node, prefix, key []byte, value node) (node, error) { +func (t *Trie) insert(n node, prefix, key []byte, value node) (bool, node, error) { if len(key) == 0 { - return value, nil + if v, ok := n.(valueNode); ok { + return !bytes.Equal(v, value.(valueNode)), value, nil + } + return true, value, nil } switch n := n.(type) { case shortNode: @@ -194,53 +197,63 @@ func (t *Trie) insert(n node, prefix, key []byte, value node) (node, error) { // If the whole key matches, keep this short node as is // and only update the value. if matchlen == len(n.Key) { - nn, err := t.insert(n.Val, append(prefix, key[:matchlen]...), key[matchlen:], value) + dirty, nn, err := t.insert(n.Val, append(prefix, key[:matchlen]...), key[matchlen:], value) if err != nil { - return nil, err + return false, nil, err + } + if !dirty { + return false, n, nil } - return shortNode{n.Key, nn}, nil + return true, shortNode{n.Key, nn, nil, true}, nil } // Otherwise branch out at the index where they differ. - var branch fullNode + branch := fullNode{dirty: true} var err error - branch[n.Key[matchlen]], err = t.insert(nil, append(prefix, n.Key[:matchlen+1]...), n.Key[matchlen+1:], n.Val) + _, branch.Children[n.Key[matchlen]], err = t.insert(nil, append(prefix, n.Key[:matchlen+1]...), n.Key[matchlen+1:], n.Val) if err != nil { - return nil, err + return false, nil, err } - branch[key[matchlen]], err = t.insert(nil, append(prefix, key[:matchlen+1]...), key[matchlen+1:], value) + _, branch.Children[key[matchlen]], err = t.insert(nil, append(prefix, key[:matchlen+1]...), key[matchlen+1:], value) if err != nil { - return nil, err + return false, nil, err } // Replace this shortNode with the branch if it occurs at index 0. if matchlen == 0 { - return branch, nil + return true, branch, nil } // Otherwise, replace it with a short node leading up to the branch. - return shortNode{key[:matchlen], branch}, nil + return true, shortNode{key[:matchlen], branch, nil, true}, nil case fullNode: - nn, err := t.insert(n[key[0]], append(prefix, key[0]), key[1:], value) + dirty, nn, err := t.insert(n.Children[key[0]], append(prefix, key[0]), key[1:], value) if err != nil { - return nil, err + return false, nil, err } - n[key[0]] = nn - return n, nil + if !dirty { + return false, n, nil + } + n.Children[key[0]], n.hash, n.dirty = nn, nil, true + return true, n, nil case nil: - return shortNode{key, value}, nil + return true, shortNode{key, value, nil, true}, nil case hashNode: // We've hit a part of the trie that isn't loaded yet. Load // the node and insert into it. This leaves all child nodes on // the path to the value in the trie. - // - // TODO: track whether insertion changed the value and keep - // n as a hash node if it didn't. rn, err := t.resolveHash(n, prefix, key) if err != nil { - return nil, err + return false, nil, err + } + dirty, nn, err := t.insert(rn, prefix, key, value) + if err != nil { + return false, nil, err } - return t.insert(rn, prefix, key, value) + if !dirty { + return false, rn, nil + } + return true, nn, nil default: panic(fmt.Sprintf("%T: invalid node: %v", n, n)) @@ -258,7 +271,7 @@ func (t *Trie) Delete(key []byte) { // If a node was not found in the database, a MissingNodeError is returned. func (t *Trie) TryDelete(key []byte) error { k := compactHexDecode(key) - n, err := t.delete(t.root, nil, k) + _, n, err := t.delete(t.root, nil, k) if err != nil { return err } @@ -269,23 +282,26 @@ func (t *Trie) TryDelete(key []byte) error { // delete returns the new root of the trie with key deleted. // It reduces the trie to minimal form by simplifying // nodes on the way up after deleting recursively. -func (t *Trie) delete(n node, prefix, key []byte) (node, error) { +func (t *Trie) delete(n node, prefix, key []byte) (bool, node, error) { switch n := n.(type) { case shortNode: matchlen := prefixLen(key, n.Key) if matchlen < len(n.Key) { - return n, nil // don't replace n on mismatch + return false, n, nil // don't replace n on mismatch } if matchlen == len(key) { - return nil, nil // remove n entirely for whole matches + return true, nil, nil // remove n entirely for whole matches } // The key is longer than n.Key. Remove the remaining suffix // from the subtrie. Child can never be nil here since the // subtrie must contain at least two other values with keys // longer than n.Key. - child, err := t.delete(n.Val, append(prefix, key[:len(n.Key)]...), key[len(n.Key):]) + dirty, child, err := t.delete(n.Val, append(prefix, key[:len(n.Key)]...), key[len(n.Key):]) if err != nil { - return nil, err + return false, nil, err + } + if !dirty { + return false, n, nil } switch child := child.(type) { case shortNode: @@ -295,17 +311,21 @@ func (t *Trie) delete(n node, prefix, key []byte) (node, error) { // always creates a new slice) instead of append to // avoid modifying n.Key since it might be shared with // other nodes. - return shortNode{concat(n.Key, child.Key...), child.Val}, nil + return true, shortNode{concat(n.Key, child.Key...), child.Val, nil, true}, nil default: - return shortNode{n.Key, child}, nil + return true, shortNode{n.Key, child, nil, true}, nil } case fullNode: - nn, err := t.delete(n[key[0]], append(prefix, key[0]), key[1:]) + dirty, nn, err := t.delete(n.Children[key[0]], append(prefix, key[0]), key[1:]) if err != nil { - return nil, err + return false, nil, err } - n[key[0]] = nn + if !dirty { + return false, n, nil + } + n.Children[key[0]], n.hash, n.dirty = nn, nil, true + // Check how many non-nil entries are left after deleting and // reduce the full node to a short node if only one entry is // left. Since n must've contained at least two children @@ -316,7 +336,7 @@ func (t *Trie) delete(n node, prefix, key []byte) (node, error) { // value that is left in n or -2 if n contains at least two // values. pos := -1 - for i, cld := range n { + for i, cld := range n.Children { if cld != nil { if pos == -1 { pos = i @@ -334,37 +354,41 @@ func (t *Trie) delete(n node, prefix, key []byte) (node, error) { // shortNode{..., shortNode{...}}. Since the entry // might not be loaded yet, resolve it just for this // check. - cnode, err := t.resolve(n[pos], prefix, []byte{byte(pos)}) + cnode, err := t.resolve(n.Children[pos], prefix, []byte{byte(pos)}) if err != nil { - return nil, err + return false, nil, err } if cnode, ok := cnode.(shortNode); ok { k := append([]byte{byte(pos)}, cnode.Key...) - return shortNode{k, cnode.Val}, nil + return true, shortNode{k, cnode.Val, nil, true}, nil } } // Otherwise, n is replaced by a one-nibble short node // containing the child. - return shortNode{[]byte{byte(pos)}, n[pos]}, nil + return true, shortNode{[]byte{byte(pos)}, n.Children[pos], nil, true}, nil } // n still contains at least two values and cannot be reduced. - return n, nil + return true, n, nil case nil: - return nil, nil + return false, nil, nil case hashNode: // We've hit a part of the trie that isn't loaded yet. Load // the node and delete from it. This leaves all child nodes on // the path to the value in the trie. - // - // TODO: track whether deletion actually hit a key and keep - // n as a hash node if it didn't. rn, err := t.resolveHash(n, prefix, key) if err != nil { - return nil, err + return false, nil, err } - return t.delete(rn, prefix, key) + dirty, nn, err := t.delete(rn, prefix, key) + if err != nil { + return false, nil, err + } + if !dirty { + return false, rn, nil + } + return true, nn, nil default: panic(fmt.Sprintf("%T: invalid node: %v (%v)", n, n, key)) @@ -413,8 +437,9 @@ func (t *Trie) Root() []byte { return t.Hash().Bytes() } // Hash returns the root hash of the trie. It does not write to the // database and can be used even if the trie doesn't have one. func (t *Trie) Hash() common.Hash { - root, _ := t.hashRoot(nil) - return common.BytesToHash(root.(hashNode)) + hash, cached, _ := t.hashRoot(nil) + t.root = cached + return common.BytesToHash(hash.(hashNode)) } // Commit writes all nodes to the trie's database. @@ -437,17 +462,17 @@ func (t *Trie) Commit() (root common.Hash, err error) { // the changes made to db are written back to the trie's attached // database before using the trie. func (t *Trie) CommitTo(db DatabaseWriter) (root common.Hash, err error) { - n, err := t.hashRoot(db) + hash, cached, err := t.hashRoot(db) if err != nil { return (common.Hash{}), err } - t.root = n - return common.BytesToHash(n.(hashNode)), nil + t.root = cached + return common.BytesToHash(hash.(hashNode)), nil } -func (t *Trie) hashRoot(db DatabaseWriter) (node, error) { +func (t *Trie) hashRoot(db DatabaseWriter) (node, node, error) { if t.root == nil { - return hashNode(emptyRoot.Bytes()), nil + return hashNode(emptyRoot.Bytes()), nil, nil } if t.hasher == nil { t.hasher = newHasher() @@ -464,51 +489,87 @@ func newHasher() *hasher { return &hasher{tmp: new(bytes.Buffer), sha: sha3.NewKeccak256()} } -func (h *hasher) hash(n node, db DatabaseWriter, force bool) (node, error) { - hashed, err := h.replaceChildren(n, db) +// hash collapses a node down into a hash node, also returning a copy of the +// original node initialzied with the computed hash to replace the original one. +func (h *hasher) hash(n node, db DatabaseWriter, force bool) (node, node, error) { + // If we're not storing the node, just hashing, use avaialble cached data + if hash, dirty := n.cache(); hash != nil && (db == nil || !dirty) { + return hash, n, nil + } + // Trie not processed yet or needs storage, walk the children + collapsed, cached, err := h.hashChildren(n, db) if err != nil { - return hashNode{}, err + return hashNode{}, n, err } - if n, err = h.store(hashed, db, force); err != nil { - return hashNode{}, err + hashed, err := h.store(collapsed, db, force) + if err != nil { + return hashNode{}, n, err } - return n, nil + // Cache the hash and RLP blob of the ndoe for later reuse + if hash, ok := hashed.(hashNode); ok && !force { + switch cached := cached.(type) { + case shortNode: + cached.hash = hash + if db != nil { + cached.dirty = false + } + return hashed, cached, nil + case fullNode: + cached.hash = hash + if db != nil { + cached.dirty = false + } + return hashed, cached, nil + } + } + return hashed, cached, nil } -// hashChildren replaces child nodes of n with their hashes if the encoded -// size of the child is larger than a hash. -func (h *hasher) replaceChildren(n node, db DatabaseWriter) (node, error) { +// hashChildren replaces the children of a node with their hashes if the encoded +// size of the child is larger than a hash, returning the collapsed node as well +// as a replacement for the original node with the child hashes cached in. +func (h *hasher) hashChildren(original node, db DatabaseWriter) (node, node, error) { var err error - switch n := n.(type) { + + switch n := original.(type) { case shortNode: + // Hash the short node's child, caching the newly hashed subtree + cached := n + cached.Key = common.CopyBytes(cached.Key) + n.Key = compactEncode(n.Key) if _, ok := n.Val.(valueNode); !ok { - if n.Val, err = h.hash(n.Val, db, false); err != nil { - return n, err + if n.Val, cached.Val, err = h.hash(n.Val, db, false); err != nil { + return n, original, err } } if n.Val == nil { - // Ensure that nil children are encoded as empty strings. - n.Val = valueNode(nil) + n.Val = valueNode(nil) // Ensure that nil children are encoded as empty strings. } - return n, nil + return n, cached, nil + case fullNode: + // Hash the full node's children, caching the newly hashed subtrees + cached := fullNode{dirty: n.dirty} + for i := 0; i < 16; i++ { - if n[i] != nil { - if n[i], err = h.hash(n[i], db, false); err != nil { - return n, err + if n.Children[i] != nil { + if n.Children[i], cached.Children[i], err = h.hash(n.Children[i], db, false); err != nil { + return n, original, err } } else { - // Ensure that nil children are encoded as empty strings. - n[i] = valueNode(nil) + n.Children[i] = valueNode(nil) // Ensure that nil children are encoded as empty strings. } } - if n[16] == nil { - n[16] = valueNode(nil) + cached.Children[16] = n.Children[16] + if n.Children[16] == nil { + n.Children[16] = valueNode(nil) } - return n, nil + return n, cached, nil + default: - return n, nil + // Value and hash nodes don't have children so they're left as were + return n, original, nil } } @@ -517,21 +578,23 @@ func (h *hasher) store(n node, db DatabaseWriter, force bool) (node, error) { if _, isHash := n.(hashNode); n == nil || isHash { return n, nil } + // Generate the RLP encoding of the node h.tmp.Reset() if err := rlp.Encode(h.tmp, n); err != nil { panic("encode error: " + err.Error()) } if h.tmp.Len() < 32 && !force { - // Nodes smaller than 32 bytes are stored inside their parent. - return n, nil + return n, nil // Nodes smaller than 32 bytes are stored inside their parent } // Larger nodes are replaced by their hash and stored in the database. - h.sha.Reset() - h.sha.Write(h.tmp.Bytes()) - key := hashNode(h.sha.Sum(nil)) + hash, _ := n.cache() + if hash == nil { + h.sha.Reset() + h.sha.Write(h.tmp.Bytes()) + hash = hashNode(h.sha.Sum(nil)) + } if db != nil { - err := db.Put(key, h.tmp.Bytes()) - return key, err + return hash, db.Put(hash, h.tmp.Bytes()) } - return key, nil + return hash, nil } diff --git a/trie/trie_test.go b/trie/trie_test.go index bb761b5551..121ba24c1e 100644 --- a/trie/trie_test.go +++ b/trie/trie_test.go @@ -295,7 +295,7 @@ func TestReplication(t *testing.T) { for _, val := range vals2 { updateString(trie2, val.k, val.v) } - if trie2.Hash() != exp { + if hash := trie2.Hash(); hash != exp { t.Errorf("root failure. expected %x got %x", exp, hash) } }