From 86dd005544179818edd78ef6c9396b9574e8a614 Mon Sep 17 00:00:00 2001 From: gary rong Date: Mon, 12 Oct 2020 18:08:04 +0800 Subject: [PATCH] trie: polish commit function (#21692) * trie: polish commit function * trie: fix typo --- trie/stacktrie.go | 23 +++++++++++++++-------- trie/trie_test.go | 5 ++++- 2 files changed, 19 insertions(+), 9 deletions(-) diff --git a/trie/stacktrie.go b/trie/stacktrie.go index fc653101ae..33fa990077 100644 --- a/trie/stacktrie.go +++ b/trie/stacktrie.go @@ -17,6 +17,7 @@ package trie import ( + "errors" "fmt" "sync" @@ -26,6 +27,8 @@ import ( "github.com/ethereum/go-ethereum/rlp" ) +var ErrCommitDisabled = errors.New("no database for committing") + var stPool = sync.Pool{ New: func() interface{} { return NewStackTrie(nil) @@ -391,14 +394,18 @@ func (st *StackTrie) Hash() (h common.Hash) { return common.BytesToHash(st.val) } -// Commit will commit the current node to database db -func (st *StackTrie) Commit(db ethdb.KeyValueStore) common.Hash { - oldDb := st.db - st.db = db - defer func() { - st.db = oldDb - }() +// Commit will firstly hash the entrie trie if it's still not hashed +// and then commit all nodes to the associated database. Actually most +// of the trie nodes MAY have been committed already. The main purpose +// here is to commit the root node. +// +// The associated database is expected, otherwise the whole commit +// functionality should be disabled. +func (st *StackTrie) Commit() (common.Hash, error) { + if st.db == nil { + return common.Hash{}, ErrCommitDisabled + } st.hash() h := common.BytesToHash(st.val) - return h + return h, nil } diff --git a/trie/trie_test.go b/trie/trie_test.go index 03ec0cab89..539451fbf4 100644 --- a/trie/trie_test.go +++ b/trie/trie_test.go @@ -831,7 +831,10 @@ func TestCommitSequenceStackTrie(t *testing.T) { // Flush memdb -> disk (sponge) db.Commit(root, false, nil) // And flush stacktrie -> disk - stRoot := stTrie.Commit(stTrie.db) + stRoot, err := stTrie.Commit() + if err != nil { + t.Fatalf("Failed to commit stack trie %v", err) + } if stRoot != root { t.Fatalf("root wrong, got %x exp %x", stRoot, root) }