core, eth, les, tests, trie: abstract node scheme (#25532)

This PR introduces a node scheme abstraction. The interface is only implemented by `hashScheme` at the moment, but will be extended by `pathScheme` very soon.

Apart from that, a few changes are also included which is worth mentioning:

-  port the changes in the stacktrie, tracking the path prefix of nodes during commit
-  use ethdb.Database for constructing trie.Database. This is not necessary right now, but it is required for path-based used to open reverse diff freezer
pull/26272/head
rjl493456442 2 years ago committed by GitHub
parent 0e06735201
commit 743e404906
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 14
      cmd/geth/chaincmd.go
  2. 52
      core/blockchain.go
  3. 6
      core/blockchain_reader.go
  4. 3
      core/chain_makers.go
  5. 37
      core/genesis.go
  6. 25
      core/genesis_test.go
  7. 3
      core/headerchain_test.go
  8. 20
      core/state/database.go
  9. 14
      core/state/iterator_test.go
  10. 25
      core/state/snapshot/conversion.go
  11. 5
      core/state/snapshot/generate.go
  12. 4
      core/state/snapshot/generate_test.go
  13. 4
      core/state/snapshot/snapshot.go
  14. 4
      core/state/sync.go
  15. 64
      core/state/sync_test.go
  16. 7
      eth/downloader/downloader.go
  17. 33
      eth/protocols/snap/sync.go
  18. 149
      eth/protocols/snap/sync_test.go
  19. 3
      les/client.go
  20. 2
      les/downloader/downloader.go
  21. 6
      les/downloader/statesync.go
  22. 6
      miner/miner_test.go
  23. 5
      tests/block_test_util.go
  24. 53
      tests/fuzzers/stacktrie/trie_fuzzer.go
  25. 4
      tests/fuzzers/trie/trie-fuzzer.go
  26. 11
      trie/database.go
  27. 4
      trie/database_test.go
  28. 8
      trie/iterator_test.go
  29. 96
      trie/schema.go
  30. 6
      trie/secure_trie_test.go
  31. 113
      trie/stacktrie.go
  32. 14
      trie/stacktrie_test.go
  33. 48
      trie/sync.go
  34. 37
      trie/sync_test.go
  35. 16
      trie/trie.go
  36. 23
      trie/trie_test.go

@ -39,6 +39,7 @@ import (
"github.com/ethereum/go-ethereum/log" "github.com/ethereum/go-ethereum/log"
"github.com/ethereum/go-ethereum/metrics" "github.com/ethereum/go-ethereum/metrics"
"github.com/ethereum/go-ethereum/node" "github.com/ethereum/go-ethereum/node"
"github.com/ethereum/go-ethereum/trie"
"github.com/urfave/cli/v2" "github.com/urfave/cli/v2"
) )
@ -48,7 +49,7 @@ var (
Name: "init", Name: "init",
Usage: "Bootstrap and initialize a new genesis block", Usage: "Bootstrap and initialize a new genesis block",
ArgsUsage: "<genesisPath>", ArgsUsage: "<genesisPath>",
Flags: utils.DatabasePathFlags, Flags: flags.Merge([]cli.Flag{utils.CachePreimagesFlag}, utils.DatabasePathFlags),
Description: ` Description: `
The init command initializes a new genesis block and definition for the network. The init command initializes a new genesis block and definition for the network.
This is a destructive action and changes the network in which you will be This is a destructive action and changes the network in which you will be
@ -188,12 +189,16 @@ func initGenesis(ctx *cli.Context) error {
// Open and initialise both full and light databases // Open and initialise both full and light databases
stack, _ := makeConfigNode(ctx) stack, _ := makeConfigNode(ctx)
defer stack.Close() defer stack.Close()
for _, name := range []string{"chaindata", "lightchaindata"} { for _, name := range []string{"chaindata", "lightchaindata"} {
chaindb, err := stack.OpenDatabaseWithFreezer(name, 0, 0, ctx.String(utils.AncientFlag.Name), "", false) chaindb, err := stack.OpenDatabaseWithFreezer(name, 0, 0, ctx.String(utils.AncientFlag.Name), "", false)
if err != nil { if err != nil {
utils.Fatalf("Failed to open database: %v", err) utils.Fatalf("Failed to open database: %v", err)
} }
_, hash, err := core.SetupGenesisBlock(chaindb, genesis) triedb := trie.NewDatabaseWithConfig(chaindb, &trie.Config{
Preimages: ctx.Bool(utils.CachePreimagesFlag.Name),
})
_, hash, err := core.SetupGenesisBlock(chaindb, triedb, genesis)
if err != nil { if err != nil {
utils.Fatalf("Failed to write genesis block: %v", err) utils.Fatalf("Failed to write genesis block: %v", err)
} }
@ -460,7 +465,10 @@ func dump(ctx *cli.Context) error {
if err != nil { if err != nil {
return err return err
} }
state, err := state.New(root, state.NewDatabase(db), nil) config := &trie.Config{
Preimages: true, // always enable preimage lookup
}
state, err := state.New(root, state.NewDatabaseWithConfig(db, config), nil)
if err != nil { if err != nil {
return err return err
} }

@ -173,6 +173,8 @@ type BlockChain struct {
snaps *snapshot.Tree // Snapshot tree for fast trie leaf access snaps *snapshot.Tree // Snapshot tree for fast trie leaf access
triegc *prque.Prque // Priority queue mapping block numbers to tries to gc triegc *prque.Prque // Priority queue mapping block numbers to tries to gc
gcproc time.Duration // Accumulates canonical block processing for trie dumping gcproc time.Duration // Accumulates canonical block processing for trie dumping
triedb *trie.Database // The database handler for maintaining trie nodes.
stateCache state.Database // State database to reuse between imports (contains state cache)
// txLookupLimit is the maximum number of blocks from head whose tx indices // txLookupLimit is the maximum number of blocks from head whose tx indices
// are reserved: // are reserved:
@ -200,7 +202,6 @@ type BlockChain struct {
currentFinalizedBlock atomic.Value // Current finalized head currentFinalizedBlock atomic.Value // Current finalized head
currentSafeBlock atomic.Value // Current safe head currentSafeBlock atomic.Value // Current safe head
stateCache state.Database // State database to reuse between imports (contains state cache)
bodyCache *lru.Cache[common.Hash, *types.Body] bodyCache *lru.Cache[common.Hash, *types.Body]
bodyRLPCache *lru.Cache[common.Hash, rlp.RawValue] bodyRLPCache *lru.Cache[common.Hash, rlp.RawValue]
receiptsCache *lru.Cache[common.Hash, []*types.Receipt] receiptsCache *lru.Cache[common.Hash, []*types.Receipt]
@ -231,10 +232,16 @@ func NewBlockChain(db ethdb.Database, cacheConfig *CacheConfig, genesis *Genesis
cacheConfig = defaultCacheConfig cacheConfig = defaultCacheConfig
} }
// Open trie database with provided config
triedb := trie.NewDatabaseWithConfig(db, &trie.Config{
Cache: cacheConfig.TrieCleanLimit,
Journal: cacheConfig.TrieCleanJournal,
Preimages: cacheConfig.Preimages,
})
// Setup the genesis block, commit the provided genesis specification // Setup the genesis block, commit the provided genesis specification
// to database if the genesis block is not present yet, or load the // to database if the genesis block is not present yet, or load the
// stored one from database. // stored one from database.
chainConfig, genesisHash, genesisErr := SetupGenesisBlockWithOverride(db, genesis, overrides) chainConfig, genesisHash, genesisErr := SetupGenesisBlockWithOverride(db, triedb, genesis, overrides)
if _, ok := genesisErr.(*params.ConfigCompatError); genesisErr != nil && !ok { if _, ok := genesisErr.(*params.ConfigCompatError); genesisErr != nil && !ok {
return nil, genesisErr return nil, genesisErr
} }
@ -250,12 +257,8 @@ func NewBlockChain(db ethdb.Database, cacheConfig *CacheConfig, genesis *Genesis
chainConfig: chainConfig, chainConfig: chainConfig,
cacheConfig: cacheConfig, cacheConfig: cacheConfig,
db: db, db: db,
triedb: triedb,
triegc: prque.New(nil), triegc: prque.New(nil),
stateCache: state.NewDatabaseWithConfig(db, &trie.Config{
Cache: cacheConfig.TrieCleanLimit,
Journal: cacheConfig.TrieCleanJournal,
Preimages: cacheConfig.Preimages,
}),
quit: make(chan struct{}), quit: make(chan struct{}),
chainmu: syncx.NewClosableMutex(), chainmu: syncx.NewClosableMutex(),
bodyCache: lru.NewCache[common.Hash, *types.Body](bodyCacheLimit), bodyCache: lru.NewCache[common.Hash, *types.Body](bodyCacheLimit),
@ -268,6 +271,7 @@ func NewBlockChain(db ethdb.Database, cacheConfig *CacheConfig, genesis *Genesis
vmConfig: vmConfig, vmConfig: vmConfig,
} }
bc.forker = NewForkChoice(bc, shouldPreserve) bc.forker = NewForkChoice(bc, shouldPreserve)
bc.stateCache = state.NewDatabaseWithNodeDB(bc.db, bc.triedb)
bc.validator = NewBlockValidator(chainConfig, bc, engine) bc.validator = NewBlockValidator(chainConfig, bc, engine)
bc.prefetcher = newStatePrefetcher(chainConfig, bc, engine) bc.prefetcher = newStatePrefetcher(chainConfig, bc, engine)
bc.processor = NewStateProcessor(chainConfig, bc, engine) bc.processor = NewStateProcessor(chainConfig, bc, engine)
@ -300,7 +304,7 @@ func NewBlockChain(db ethdb.Database, cacheConfig *CacheConfig, genesis *Genesis
} }
// Make sure the state associated with the block is available // Make sure the state associated with the block is available
head := bc.CurrentBlock() head := bc.CurrentBlock()
if _, err := state.New(head.Root(), bc.stateCache, bc.snaps); err != nil { if !bc.HasState(head.Root()) {
// Head state is missing, before the state recovery, find out the // Head state is missing, before the state recovery, find out the
// disk layer point of snapshot(if it's enabled). Make sure the // disk layer point of snapshot(if it's enabled). Make sure the
// rewound point is lower than disk layer. // rewound point is lower than disk layer.
@ -388,7 +392,7 @@ func NewBlockChain(db ethdb.Database, cacheConfig *CacheConfig, genesis *Genesis
var recover bool var recover bool
head := bc.CurrentBlock() head := bc.CurrentBlock()
if layer := rawdb.ReadSnapshotRecoveryNumber(bc.db); layer != nil && *layer > head.NumberU64() { if layer := rawdb.ReadSnapshotRecoveryNumber(bc.db); layer != nil && *layer >= head.NumberU64() {
log.Warn("Enabling snapshot recovery", "chainhead", head.NumberU64(), "diskbase", *layer) log.Warn("Enabling snapshot recovery", "chainhead", head.NumberU64(), "diskbase", *layer)
recover = true recover = true
} }
@ -398,7 +402,7 @@ func NewBlockChain(db ethdb.Database, cacheConfig *CacheConfig, genesis *Genesis
NoBuild: bc.cacheConfig.SnapshotNoBuild, NoBuild: bc.cacheConfig.SnapshotNoBuild,
AsyncBuild: !bc.cacheConfig.SnapshotWait, AsyncBuild: !bc.cacheConfig.SnapshotWait,
} }
bc.snaps, _ = snapshot.New(snapconfig, bc.db, bc.stateCache.TrieDB(), head.Root()) bc.snaps, _ = snapshot.New(snapconfig, bc.db, bc.triedb, head.Root())
} }
// Start future block processor. // Start future block processor.
@ -411,11 +415,10 @@ func NewBlockChain(db ethdb.Database, cacheConfig *CacheConfig, genesis *Genesis
log.Warn("Sanitizing invalid trie cache journal time", "provided", bc.cacheConfig.TrieCleanRejournal, "updated", time.Minute) log.Warn("Sanitizing invalid trie cache journal time", "provided", bc.cacheConfig.TrieCleanRejournal, "updated", time.Minute)
bc.cacheConfig.TrieCleanRejournal = time.Minute bc.cacheConfig.TrieCleanRejournal = time.Minute
} }
triedb := bc.stateCache.TrieDB()
bc.wg.Add(1) bc.wg.Add(1)
go func() { go func() {
defer bc.wg.Done() defer bc.wg.Done()
triedb.SaveCachePeriodically(bc.cacheConfig.TrieCleanJournal, bc.cacheConfig.TrieCleanRejournal, bc.quit) bc.triedb.SaveCachePeriodically(bc.cacheConfig.TrieCleanJournal, bc.cacheConfig.TrieCleanRejournal, bc.quit)
}() }()
} }
// Rewind the chain in case of an incompatible config upgrade. // Rewind the chain in case of an incompatible config upgrade.
@ -594,7 +597,7 @@ func (bc *BlockChain) setHeadBeyondRoot(head uint64, root common.Hash, repair bo
if root != (common.Hash{}) && !beyondRoot && newHeadBlock.Root() == root { if root != (common.Hash{}) && !beyondRoot && newHeadBlock.Root() == root {
beyondRoot, rootNumber = true, newHeadBlock.NumberU64() beyondRoot, rootNumber = true, newHeadBlock.NumberU64()
} }
if _, err := state.New(newHeadBlock.Root(), bc.stateCache, bc.snaps); err != nil { if !bc.HasState(newHeadBlock.Root()) {
log.Trace("Block state missing, rewinding further", "number", newHeadBlock.NumberU64(), "hash", newHeadBlock.Hash()) log.Trace("Block state missing, rewinding further", "number", newHeadBlock.NumberU64(), "hash", newHeadBlock.Hash())
if pivot == nil || newHeadBlock.NumberU64() > *pivot { if pivot == nil || newHeadBlock.NumberU64() > *pivot {
parent := bc.GetBlock(newHeadBlock.ParentHash(), newHeadBlock.NumberU64()-1) parent := bc.GetBlock(newHeadBlock.ParentHash(), newHeadBlock.NumberU64()-1)
@ -617,7 +620,7 @@ func (bc *BlockChain) setHeadBeyondRoot(head uint64, root common.Hash, repair bo
// if the historical chain pruning is enabled. In that case the logic // if the historical chain pruning is enabled. In that case the logic
// needs to be improved here. // needs to be improved here.
if !bc.HasState(bc.genesisBlock.Root()) { if !bc.HasState(bc.genesisBlock.Root()) {
if err := CommitGenesisState(bc.db, bc.genesisBlock.Hash()); err != nil { if err := CommitGenesisState(bc.db, bc.triedb, bc.genesisBlock.Hash()); err != nil {
log.Crit("Failed to commit genesis state", "err", err) log.Crit("Failed to commit genesis state", "err", err)
} }
log.Debug("Recommitted genesis state to disk") log.Debug("Recommitted genesis state to disk")
@ -900,7 +903,7 @@ func (bc *BlockChain) Stop() {
// - HEAD-1: So we don't do large reorgs if our HEAD becomes an uncle // - HEAD-1: So we don't do large reorgs if our HEAD becomes an uncle
// - HEAD-127: So we have a hard limit on the number of blocks reexecuted // - HEAD-127: So we have a hard limit on the number of blocks reexecuted
if !bc.cacheConfig.TrieDirtyDisabled { if !bc.cacheConfig.TrieDirtyDisabled {
triedb := bc.stateCache.TrieDB() triedb := bc.triedb
for _, offset := range []uint64{0, 1, TriesInMemory - 1} { for _, offset := range []uint64{0, 1, TriesInMemory - 1} {
if number := bc.CurrentBlock().NumberU64(); number > offset { if number := bc.CurrentBlock().NumberU64(); number > offset {
@ -932,8 +935,7 @@ func (bc *BlockChain) Stop() {
// Ensure all live cached entries be saved into disk, so that we can skip // Ensure all live cached entries be saved into disk, so that we can skip
// cache warmup when node restarts. // cache warmup when node restarts.
if bc.cacheConfig.TrieCleanJournal != "" { if bc.cacheConfig.TrieCleanJournal != "" {
triedb := bc.stateCache.TrieDB() bc.triedb.SaveCache(bc.cacheConfig.TrieCleanJournal)
triedb.SaveCache(bc.cacheConfig.TrieCleanJournal)
} }
log.Info("Blockchain stopped") log.Info("Blockchain stopped")
} }
@ -1306,24 +1308,22 @@ func (bc *BlockChain) writeBlockWithState(block *types.Block, receipts []*types.
if err != nil { if err != nil {
return err return err
} }
triedb := bc.stateCache.TrieDB()
// If we're running an archive node, always flush // If we're running an archive node, always flush
if bc.cacheConfig.TrieDirtyDisabled { if bc.cacheConfig.TrieDirtyDisabled {
return triedb.Commit(root, false, nil) return bc.triedb.Commit(root, false, nil)
} else { } else {
// Full but not archive node, do proper garbage collection // Full but not archive node, do proper garbage collection
triedb.Reference(root, common.Hash{}) // metadata reference to keep trie alive bc.triedb.Reference(root, common.Hash{}) // metadata reference to keep trie alive
bc.triegc.Push(root, -int64(block.NumberU64())) bc.triegc.Push(root, -int64(block.NumberU64()))
if current := block.NumberU64(); current > TriesInMemory { if current := block.NumberU64(); current > TriesInMemory {
// If we exceeded our memory allowance, flush matured singleton nodes to disk // If we exceeded our memory allowance, flush matured singleton nodes to disk
var ( var (
nodes, imgs = triedb.Size() nodes, imgs = bc.triedb.Size()
limit = common.StorageSize(bc.cacheConfig.TrieDirtyLimit) * 1024 * 1024 limit = common.StorageSize(bc.cacheConfig.TrieDirtyLimit) * 1024 * 1024
) )
if nodes > limit || imgs > 4*1024*1024 { if nodes > limit || imgs > 4*1024*1024 {
triedb.Cap(limit - ethdb.IdealBatchSize) bc.triedb.Cap(limit - ethdb.IdealBatchSize)
} }
// Find the next state trie we need to commit // Find the next state trie we need to commit
chosen := current - TriesInMemory chosen := current - TriesInMemory
@ -1342,7 +1342,7 @@ func (bc *BlockChain) writeBlockWithState(block *types.Block, receipts []*types.
log.Info("State in memory for too long, committing", "time", bc.gcproc, "allowance", bc.cacheConfig.TrieTimeLimit, "optimum", float64(chosen-lastWrite)/TriesInMemory) log.Info("State in memory for too long, committing", "time", bc.gcproc, "allowance", bc.cacheConfig.TrieTimeLimit, "optimum", float64(chosen-lastWrite)/TriesInMemory)
} }
// Flush an entire trie and restart the counters // Flush an entire trie and restart the counters
triedb.Commit(header.Root, true, nil) bc.triedb.Commit(header.Root, true, nil)
lastWrite = chosen lastWrite = chosen
bc.gcproc = 0 bc.gcproc = 0
} }
@ -1354,7 +1354,7 @@ func (bc *BlockChain) writeBlockWithState(block *types.Block, receipts []*types.
bc.triegc.Push(root, number) bc.triegc.Push(root, number)
break break
} }
triedb.Dereference(root.(common.Hash)) bc.triedb.Dereference(root.(common.Hash))
} }
} }
} }
@ -1760,7 +1760,7 @@ func (bc *BlockChain) insertChain(chain types.Blocks, verifySeals, setHead bool)
stats.processed++ stats.processed++
stats.usedGas += usedGas stats.usedGas += usedGas
dirty, _ := bc.stateCache.TrieDB().Size() dirty, _ := bc.triedb.Size()
stats.report(chain, it.index, dirty, setHead) stats.report(chain, it.index, dirty, setHead)
if !setHead { if !setHead {

@ -29,6 +29,7 @@ import (
"github.com/ethereum/go-ethereum/event" "github.com/ethereum/go-ethereum/event"
"github.com/ethereum/go-ethereum/params" "github.com/ethereum/go-ethereum/params"
"github.com/ethereum/go-ethereum/rlp" "github.com/ethereum/go-ethereum/rlp"
"github.com/ethereum/go-ethereum/trie"
) )
// CurrentHeader retrieves the current head header of the canonical chain. The // CurrentHeader retrieves the current head header of the canonical chain. The
@ -375,6 +376,11 @@ func (bc *BlockChain) TxLookupLimit() uint64 {
return bc.txLookupLimit return bc.txLookupLimit
} }
// TrieDB retrieves the low level trie database used for data storage.
func (bc *BlockChain) TrieDB() *trie.Database {
return bc.triedb
}
// SubscribeRemovedLogsEvent registers a subscription of RemovedLogsEvent. // SubscribeRemovedLogsEvent registers a subscription of RemovedLogsEvent.
func (bc *BlockChain) SubscribeRemovedLogsEvent(ch chan<- RemovedLogsEvent) event.Subscription { func (bc *BlockChain) SubscribeRemovedLogsEvent(ch chan<- RemovedLogsEvent) event.Subscription {
return bc.scope.Track(bc.rmLogsFeed.Subscribe(ch)) return bc.scope.Track(bc.rmLogsFeed.Subscribe(ch))

@ -29,6 +29,7 @@ import (
"github.com/ethereum/go-ethereum/core/vm" "github.com/ethereum/go-ethereum/core/vm"
"github.com/ethereum/go-ethereum/ethdb" "github.com/ethereum/go-ethereum/ethdb"
"github.com/ethereum/go-ethereum/params" "github.com/ethereum/go-ethereum/params"
"github.com/ethereum/go-ethereum/trie"
) )
// BlockGen creates blocks for testing. // BlockGen creates blocks for testing.
@ -308,7 +309,7 @@ func GenerateChain(config *params.ChainConfig, parent *types.Block, engine conse
// then generate chain on top. // then generate chain on top.
func GenerateChainWithGenesis(genesis *Genesis, engine consensus.Engine, n int, gen func(int, *BlockGen)) (ethdb.Database, []*types.Block, []types.Receipts) { func GenerateChainWithGenesis(genesis *Genesis, engine consensus.Engine, n int, gen func(int, *BlockGen)) (ethdb.Database, []*types.Block, []types.Receipts) {
db := rawdb.NewMemoryDatabase() db := rawdb.NewMemoryDatabase()
_, err := genesis.Commit(db) _, err := genesis.Commit(db, trie.NewDatabase(db))
if err != nil { if err != nil {
panic(err) panic(err)
} }

@ -138,8 +138,8 @@ func (ga *GenesisAlloc) deriveHash() (common.Hash, error) {
// flush is very similar with deriveHash, but the main difference is // flush is very similar with deriveHash, but the main difference is
// all the generated states will be persisted into the given database. // all the generated states will be persisted into the given database.
// Also, the genesis state specification will be flushed as well. // Also, the genesis state specification will be flushed as well.
func (ga *GenesisAlloc) flush(db ethdb.Database) error { func (ga *GenesisAlloc) flush(db ethdb.Database, triedb *trie.Database) error {
statedb, err := state.New(common.Hash{}, state.NewDatabaseWithConfig(db, &trie.Config{Preimages: true}), nil) statedb, err := state.New(common.Hash{}, state.NewDatabaseWithNodeDB(db, triedb), nil)
if err != nil { if err != nil {
return err return err
} }
@ -155,10 +155,12 @@ func (ga *GenesisAlloc) flush(db ethdb.Database) error {
if err != nil { if err != nil {
return err return err
} }
err = statedb.Database().TrieDB().Commit(root, true, nil) // Commit newly generated states into disk if it's not empty.
if err != nil { if root != types.EmptyRootHash {
if err := triedb.Commit(root, true, nil); err != nil {
return err return err
} }
}
// Marshal the genesis state specification and persist. // Marshal the genesis state specification and persist.
blob, err := json.Marshal(ga) blob, err := json.Marshal(ga)
if err != nil { if err != nil {
@ -169,8 +171,8 @@ func (ga *GenesisAlloc) flush(db ethdb.Database) error {
} }
// CommitGenesisState loads the stored genesis state with the given block // CommitGenesisState loads the stored genesis state with the given block
// hash and commits them into the given database handler. // hash and commits it into the provided trie database.
func CommitGenesisState(db ethdb.Database, hash common.Hash) error { func CommitGenesisState(db ethdb.Database, triedb *trie.Database, hash common.Hash) error {
var alloc GenesisAlloc var alloc GenesisAlloc
blob := rawdb.ReadGenesisStateSpec(db, hash) blob := rawdb.ReadGenesisStateSpec(db, hash)
if len(blob) != 0 { if len(blob) != 0 {
@ -202,7 +204,7 @@ func CommitGenesisState(db ethdb.Database, hash common.Hash) error {
return errors.New("not found") return errors.New("not found")
} }
} }
return alloc.flush(db) return alloc.flush(db, triedb)
} }
// GenesisAccount is an account in the state of the genesis block. // GenesisAccount is an account in the state of the genesis block.
@ -284,15 +286,14 @@ type ChainOverrides struct {
// error is a *params.ConfigCompatError and the new, unwritten config is returned. // error is a *params.ConfigCompatError and the new, unwritten config is returned.
// //
// The returned chain configuration is never nil. // The returned chain configuration is never nil.
func SetupGenesisBlock(db ethdb.Database, genesis *Genesis) (*params.ChainConfig, common.Hash, error) { func SetupGenesisBlock(db ethdb.Database, triedb *trie.Database, genesis *Genesis) (*params.ChainConfig, common.Hash, error) {
return SetupGenesisBlockWithOverride(db, genesis, nil) return SetupGenesisBlockWithOverride(db, triedb, genesis, nil)
} }
func SetupGenesisBlockWithOverride(db ethdb.Database, genesis *Genesis, overrides *ChainOverrides) (*params.ChainConfig, common.Hash, error) { func SetupGenesisBlockWithOverride(db ethdb.Database, triedb *trie.Database, genesis *Genesis, overrides *ChainOverrides) (*params.ChainConfig, common.Hash, error) {
if genesis != nil && genesis.Config == nil { if genesis != nil && genesis.Config == nil {
return params.AllEthashProtocolChanges, common.Hash{}, errGenesisNoConfig return params.AllEthashProtocolChanges, common.Hash{}, errGenesisNoConfig
} }
applyOverrides := func(config *params.ChainConfig) { applyOverrides := func(config *params.ChainConfig) {
if config != nil { if config != nil {
if overrides != nil && overrides.OverrideTerminalTotalDifficulty != nil { if overrides != nil && overrides.OverrideTerminalTotalDifficulty != nil {
@ -313,7 +314,7 @@ func SetupGenesisBlockWithOverride(db ethdb.Database, genesis *Genesis, override
} else { } else {
log.Info("Writing custom genesis block") log.Info("Writing custom genesis block")
} }
block, err := genesis.Commit(db) block, err := genesis.Commit(db, triedb)
if err != nil { if err != nil {
return genesis.Config, common.Hash{}, err return genesis.Config, common.Hash{}, err
} }
@ -323,7 +324,7 @@ func SetupGenesisBlockWithOverride(db ethdb.Database, genesis *Genesis, override
// We have the genesis block in database(perhaps in ancient database) // We have the genesis block in database(perhaps in ancient database)
// but the corresponding state is missing. // but the corresponding state is missing.
header := rawdb.ReadHeader(db, stored, 0) header := rawdb.ReadHeader(db, stored, 0)
if _, err := state.New(header.Root, state.NewDatabaseWithConfig(db, nil), nil); err != nil { if _, err := state.New(header.Root, state.NewDatabaseWithNodeDB(db, triedb), nil); err != nil {
if genesis == nil { if genesis == nil {
genesis = DefaultGenesisBlock() genesis = DefaultGenesisBlock()
} }
@ -332,7 +333,7 @@ func SetupGenesisBlockWithOverride(db ethdb.Database, genesis *Genesis, override
if hash != stored { if hash != stored {
return genesis.Config, hash, &GenesisMismatchError{stored, hash} return genesis.Config, hash, &GenesisMismatchError{stored, hash}
} }
block, err := genesis.Commit(db) block, err := genesis.Commit(db, triedb)
if err != nil { if err != nil {
return genesis.Config, hash, err return genesis.Config, hash, err
} }
@ -480,7 +481,7 @@ func (g *Genesis) ToBlock() *types.Block {
// Commit writes the block and state of a genesis specification to the database. // Commit writes the block and state of a genesis specification to the database.
// The block is committed as the canonical head block. // The block is committed as the canonical head block.
func (g *Genesis) Commit(db ethdb.Database) (*types.Block, error) { func (g *Genesis) Commit(db ethdb.Database, triedb *trie.Database) (*types.Block, error) {
block := g.ToBlock() block := g.ToBlock()
if block.Number().Sign() != 0 { if block.Number().Sign() != 0 {
return nil, errors.New("can't commit genesis block with number > 0") return nil, errors.New("can't commit genesis block with number > 0")
@ -498,7 +499,7 @@ func (g *Genesis) Commit(db ethdb.Database) (*types.Block, error) {
// All the checks has passed, flush the states derived from the genesis // All the checks has passed, flush the states derived from the genesis
// specification as well as the specification itself into the provided // specification as well as the specification itself into the provided
// database. // database.
if err := g.Alloc.flush(db); err != nil { if err := g.Alloc.flush(db, triedb); err != nil {
return nil, err return nil, err
} }
rawdb.WriteTd(db, block.Hash(), block.NumberU64(), block.Difficulty()) rawdb.WriteTd(db, block.Hash(), block.NumberU64(), block.Difficulty())
@ -514,8 +515,10 @@ func (g *Genesis) Commit(db ethdb.Database) (*types.Block, error) {
// MustCommit writes the genesis block and state to db, panicking on error. // MustCommit writes the genesis block and state to db, panicking on error.
// The block is committed as the canonical head block. // The block is committed as the canonical head block.
// Note the state changes will be committed in hash-based scheme, use Commit
// if path-scheme is preferred.
func (g *Genesis) MustCommit(db ethdb.Database) *types.Block { func (g *Genesis) MustCommit(db ethdb.Database) *types.Block {
block, err := g.Commit(db) block, err := g.Commit(db, trie.NewDatabase(db))
if err != nil { if err != nil {
panic(err) panic(err)
} }

@ -17,6 +17,7 @@
package core package core
import ( import (
"encoding/json"
"math/big" "math/big"
"reflect" "reflect"
"testing" "testing"
@ -28,12 +29,14 @@ import (
"github.com/ethereum/go-ethereum/core/vm" "github.com/ethereum/go-ethereum/core/vm"
"github.com/ethereum/go-ethereum/ethdb" "github.com/ethereum/go-ethereum/ethdb"
"github.com/ethereum/go-ethereum/params" "github.com/ethereum/go-ethereum/params"
"github.com/ethereum/go-ethereum/trie"
) )
func TestInvalidCliqueConfig(t *testing.T) { func TestInvalidCliqueConfig(t *testing.T) {
block := DefaultGoerliGenesisBlock() block := DefaultGoerliGenesisBlock()
block.ExtraData = []byte{} block.ExtraData = []byte{}
if _, err := block.Commit(nil); err == nil { db := rawdb.NewMemoryDatabase()
if _, err := block.Commit(db, trie.NewDatabase(db)); err == nil {
t.Fatal("Expected error on invalid clique config") t.Fatal("Expected error on invalid clique config")
} }
} }
@ -60,7 +63,7 @@ func TestSetupGenesis(t *testing.T) {
{ {
name: "genesis without ChainConfig", name: "genesis without ChainConfig",
fn: func(db ethdb.Database) (*params.ChainConfig, common.Hash, error) { fn: func(db ethdb.Database) (*params.ChainConfig, common.Hash, error) {
return SetupGenesisBlock(db, new(Genesis)) return SetupGenesisBlock(db, trie.NewDatabase(db), new(Genesis))
}, },
wantErr: errGenesisNoConfig, wantErr: errGenesisNoConfig,
wantConfig: params.AllEthashProtocolChanges, wantConfig: params.AllEthashProtocolChanges,
@ -68,7 +71,7 @@ func TestSetupGenesis(t *testing.T) {
{ {
name: "no block in DB, genesis == nil", name: "no block in DB, genesis == nil",
fn: func(db ethdb.Database) (*params.ChainConfig, common.Hash, error) { fn: func(db ethdb.Database) (*params.ChainConfig, common.Hash, error) {
return SetupGenesisBlock(db, nil) return SetupGenesisBlock(db, trie.NewDatabase(db), nil)
}, },
wantHash: params.MainnetGenesisHash, wantHash: params.MainnetGenesisHash,
wantConfig: params.MainnetChainConfig, wantConfig: params.MainnetChainConfig,
@ -77,7 +80,7 @@ func TestSetupGenesis(t *testing.T) {
name: "mainnet block in DB, genesis == nil", name: "mainnet block in DB, genesis == nil",
fn: func(db ethdb.Database) (*params.ChainConfig, common.Hash, error) { fn: func(db ethdb.Database) (*params.ChainConfig, common.Hash, error) {
DefaultGenesisBlock().MustCommit(db) DefaultGenesisBlock().MustCommit(db)
return SetupGenesisBlock(db, nil) return SetupGenesisBlock(db, trie.NewDatabase(db), nil)
}, },
wantHash: params.MainnetGenesisHash, wantHash: params.MainnetGenesisHash,
wantConfig: params.MainnetChainConfig, wantConfig: params.MainnetChainConfig,
@ -86,7 +89,7 @@ func TestSetupGenesis(t *testing.T) {
name: "custom block in DB, genesis == nil", name: "custom block in DB, genesis == nil",
fn: func(db ethdb.Database) (*params.ChainConfig, common.Hash, error) { fn: func(db ethdb.Database) (*params.ChainConfig, common.Hash, error) {
customg.MustCommit(db) customg.MustCommit(db)
return SetupGenesisBlock(db, nil) return SetupGenesisBlock(db, trie.NewDatabase(db), nil)
}, },
wantHash: customghash, wantHash: customghash,
wantConfig: customg.Config, wantConfig: customg.Config,
@ -95,7 +98,7 @@ func TestSetupGenesis(t *testing.T) {
name: "custom block in DB, genesis == ropsten", name: "custom block in DB, genesis == ropsten",
fn: func(db ethdb.Database) (*params.ChainConfig, common.Hash, error) { fn: func(db ethdb.Database) (*params.ChainConfig, common.Hash, error) {
customg.MustCommit(db) customg.MustCommit(db)
return SetupGenesisBlock(db, DefaultRopstenGenesisBlock()) return SetupGenesisBlock(db, trie.NewDatabase(db), DefaultRopstenGenesisBlock())
}, },
wantErr: &GenesisMismatchError{Stored: customghash, New: params.RopstenGenesisHash}, wantErr: &GenesisMismatchError{Stored: customghash, New: params.RopstenGenesisHash},
wantHash: params.RopstenGenesisHash, wantHash: params.RopstenGenesisHash,
@ -105,7 +108,7 @@ func TestSetupGenesis(t *testing.T) {
name: "compatible config in DB", name: "compatible config in DB",
fn: func(db ethdb.Database) (*params.ChainConfig, common.Hash, error) { fn: func(db ethdb.Database) (*params.ChainConfig, common.Hash, error) {
oldcustomg.MustCommit(db) oldcustomg.MustCommit(db)
return SetupGenesisBlock(db, &customg) return SetupGenesisBlock(db, trie.NewDatabase(db), &customg)
}, },
wantHash: customghash, wantHash: customghash,
wantConfig: customg.Config, wantConfig: customg.Config,
@ -122,9 +125,9 @@ func TestSetupGenesis(t *testing.T) {
blocks, _ := GenerateChain(oldcustomg.Config, genesis, ethash.NewFaker(), db, 4, nil) blocks, _ := GenerateChain(oldcustomg.Config, genesis, ethash.NewFaker(), db, 4, nil)
bc.InsertChain(blocks) bc.InsertChain(blocks)
bc.CurrentBlock()
// This should return a compatibility error. // This should return a compatibility error.
return SetupGenesisBlock(db, &customg) return SetupGenesisBlock(db, trie.NewDatabase(db), &customg)
}, },
wantHash: customghash, wantHash: customghash,
wantConfig: customg.Config, wantConfig: customg.Config,
@ -193,6 +196,7 @@ func TestGenesis_Commit(t *testing.T) {
db := rawdb.NewMemoryDatabase() db := rawdb.NewMemoryDatabase()
genesisBlock := genesis.MustCommit(db) genesisBlock := genesis.MustCommit(db)
if genesis.Difficulty != nil { if genesis.Difficulty != nil {
t.Fatalf("assumption wrong") t.Fatalf("assumption wrong")
} }
@ -219,7 +223,8 @@ func TestReadWriteGenesisAlloc(t *testing.T) {
} }
hash, _ = alloc.deriveHash() hash, _ = alloc.deriveHash()
) )
alloc.flush(db) blob, _ := json.Marshal(alloc)
rawdb.WriteGenesisStateSpec(db, hash, blob)
var reload GenesisAlloc var reload GenesisAlloc
err := reload.UnmarshalJSON(rawdb.ReadGenesisStateSpec(db, hash)) err := reload.UnmarshalJSON(rawdb.ReadGenesisStateSpec(db, hash))

@ -28,6 +28,7 @@ import (
"github.com/ethereum/go-ethereum/core/rawdb" "github.com/ethereum/go-ethereum/core/rawdb"
"github.com/ethereum/go-ethereum/core/types" "github.com/ethereum/go-ethereum/core/types"
"github.com/ethereum/go-ethereum/params" "github.com/ethereum/go-ethereum/params"
"github.com/ethereum/go-ethereum/trie"
) )
func verifyUnbrokenCanonchain(hc *HeaderChain) error { func verifyUnbrokenCanonchain(hc *HeaderChain) error {
@ -72,7 +73,7 @@ func TestHeaderInsertion(t *testing.T) {
db = rawdb.NewMemoryDatabase() db = rawdb.NewMemoryDatabase()
gspec = &Genesis{BaseFee: big.NewInt(params.InitialBaseFee), Config: params.AllEthashProtocolChanges} gspec = &Genesis{BaseFee: big.NewInt(params.InitialBaseFee), Config: params.AllEthashProtocolChanges}
) )
gspec.Commit(db) gspec.Commit(db, trie.NewDatabase(db))
hc, err := NewHeaderChain(db, gspec.Config, ethash.NewFaker(), func() bool { return false }) hc, err := NewHeaderChain(db, gspec.Config, ethash.NewFaker(), func() bool { return false })
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)

@ -130,23 +130,33 @@ func NewDatabase(db ethdb.Database) Database {
// large memory cache. // large memory cache.
func NewDatabaseWithConfig(db ethdb.Database, config *trie.Config) Database { func NewDatabaseWithConfig(db ethdb.Database, config *trie.Config) Database {
return &cachingDB{ return &cachingDB{
db: trie.NewDatabaseWithConfig(db, config),
disk: db, disk: db,
codeSizeCache: lru.NewCache[common.Hash, int](codeSizeCacheSize), codeSizeCache: lru.NewCache[common.Hash, int](codeSizeCacheSize),
codeCache: lru.NewSizeConstrainedCache[common.Hash, []byte](codeCacheSize), codeCache: lru.NewSizeConstrainedCache[common.Hash, []byte](codeCacheSize),
triedb: trie.NewDatabaseWithConfig(db, config),
}
}
// NewDatabaseWithNodeDB creates a state database with an already initialized node database.
func NewDatabaseWithNodeDB(db ethdb.Database, triedb *trie.Database) Database {
return &cachingDB{
disk: db,
codeSizeCache: lru.NewCache[common.Hash, int](codeSizeCacheSize),
codeCache: lru.NewSizeConstrainedCache[common.Hash, []byte](codeCacheSize),
triedb: triedb,
} }
} }
type cachingDB struct { type cachingDB struct {
db *trie.Database
disk ethdb.KeyValueStore disk ethdb.KeyValueStore
codeSizeCache *lru.Cache[common.Hash, int] codeSizeCache *lru.Cache[common.Hash, int]
codeCache *lru.SizeConstrainedCache[common.Hash, []byte] codeCache *lru.SizeConstrainedCache[common.Hash, []byte]
triedb *trie.Database
} }
// OpenTrie opens the main account trie at a specific root hash. // OpenTrie opens the main account trie at a specific root hash.
func (db *cachingDB) OpenTrie(root common.Hash) (Trie, error) { func (db *cachingDB) OpenTrie(root common.Hash) (Trie, error) {
tr, err := trie.NewStateTrie(trie.StateTrieID(root), db.db) tr, err := trie.NewStateTrie(trie.StateTrieID(root), db.triedb)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -155,7 +165,7 @@ func (db *cachingDB) OpenTrie(root common.Hash) (Trie, error) {
// OpenStorageTrie opens the storage trie of an account. // OpenStorageTrie opens the storage trie of an account.
func (db *cachingDB) OpenStorageTrie(stateRoot common.Hash, addrHash, root common.Hash) (Trie, error) { func (db *cachingDB) OpenStorageTrie(stateRoot common.Hash, addrHash, root common.Hash) (Trie, error) {
tr, err := trie.NewStateTrie(trie.StorageTrieID(stateRoot, addrHash, root), db.db) tr, err := trie.NewStateTrie(trie.StorageTrieID(stateRoot, addrHash, root), db.triedb)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -220,5 +230,5 @@ func (db *cachingDB) DiskDB() ethdb.KeyValueStore {
// TrieDB retrieves any intermediate trie-node caching layer. // TrieDB retrieves any intermediate trie-node caching layer.
func (db *cachingDB) TrieDB() *trie.Database { func (db *cachingDB) TrieDB() *trie.Database {
return db.db return db.triedb
} }

@ -26,10 +26,10 @@ import (
// Tests that the node iterator indeed walks over the entire database contents. // Tests that the node iterator indeed walks over the entire database contents.
func TestNodeIteratorCoverage(t *testing.T) { func TestNodeIteratorCoverage(t *testing.T) {
// Create some arbitrary test state to iterate // Create some arbitrary test state to iterate
db, root, _ := makeTestState() db, sdb, root, _ := makeTestState()
db.TrieDB().Commit(root, false, nil) sdb.TrieDB().Commit(root, false, nil)
state, err := New(root, db, nil) state, err := New(root, sdb, nil)
if err != nil { if err != nil {
t.Fatalf("failed to create state trie at %x: %v", root, err) t.Fatalf("failed to create state trie at %x: %v", root, err)
} }
@ -42,19 +42,19 @@ func TestNodeIteratorCoverage(t *testing.T) {
} }
// Cross check the iterated hashes and the database/nodepool content // Cross check the iterated hashes and the database/nodepool content
for hash := range hashes { for hash := range hashes {
if _, err = db.TrieDB().Node(hash); err != nil { if _, err = sdb.TrieDB().Node(hash); err != nil {
_, err = db.ContractCode(common.Hash{}, hash) _, err = sdb.ContractCode(common.Hash{}, hash)
} }
if err != nil { if err != nil {
t.Errorf("failed to retrieve reported node %x", hash) t.Errorf("failed to retrieve reported node %x", hash)
} }
} }
for _, hash := range db.TrieDB().Nodes() { for _, hash := range sdb.TrieDB().Nodes() {
if _, ok := hashes[hash]; !ok { if _, ok := hashes[hash]; !ok {
t.Errorf("state entry not reported %x", hash) t.Errorf("state entry not reported %x", hash)
} }
} }
it := db.DiskDB().NewIterator(nil, nil) it := db.NewIterator(nil, nil)
for it.Next() { for it.Next() {
key := it.Key() key := it.Key()
if bytes.HasPrefix(key, []byte("secure-key-")) { if bytes.HasPrefix(key, []byte("secure-key-")) {

@ -43,7 +43,7 @@ type trieKV struct {
type ( type (
// trieGeneratorFn is the interface of trie generation which can // trieGeneratorFn is the interface of trie generation which can
// be implemented by different trie algorithm. // be implemented by different trie algorithm.
trieGeneratorFn func(db ethdb.KeyValueWriter, owner common.Hash, in chan (trieKV), out chan (common.Hash)) trieGeneratorFn func(db ethdb.KeyValueWriter, scheme trie.NodeScheme, owner common.Hash, in chan (trieKV), out chan (common.Hash))
// leafCallbackFn is the callback invoked at the leaves of the trie, // leafCallbackFn is the callback invoked at the leaves of the trie,
// returns the subtrie root with the specified subtrie identifier. // returns the subtrie root with the specified subtrie identifier.
@ -52,12 +52,12 @@ type (
// GenerateAccountTrieRoot takes an account iterator and reproduces the root hash. // GenerateAccountTrieRoot takes an account iterator and reproduces the root hash.
func GenerateAccountTrieRoot(it AccountIterator) (common.Hash, error) { func GenerateAccountTrieRoot(it AccountIterator) (common.Hash, error) {
return generateTrieRoot(nil, it, common.Hash{}, stackTrieGenerate, nil, newGenerateStats(), true) return generateTrieRoot(nil, nil, it, common.Hash{}, stackTrieGenerate, nil, newGenerateStats(), true)
} }
// GenerateStorageTrieRoot takes a storage iterator and reproduces the root hash. // GenerateStorageTrieRoot takes a storage iterator and reproduces the root hash.
func GenerateStorageTrieRoot(account common.Hash, it StorageIterator) (common.Hash, error) { func GenerateStorageTrieRoot(account common.Hash, it StorageIterator) (common.Hash, error) {
return generateTrieRoot(nil, it, account, stackTrieGenerate, nil, newGenerateStats(), true) return generateTrieRoot(nil, nil, it, account, stackTrieGenerate, nil, newGenerateStats(), true)
} }
// GenerateTrie takes the whole snapshot tree as the input, traverses all the // GenerateTrie takes the whole snapshot tree as the input, traverses all the
@ -71,7 +71,8 @@ func GenerateTrie(snaptree *Tree, root common.Hash, src ethdb.Database, dst ethd
} }
defer acctIt.Release() defer acctIt.Release()
got, err := generateTrieRoot(dst, acctIt, common.Hash{}, stackTrieGenerate, func(dst ethdb.KeyValueWriter, accountHash, codeHash common.Hash, stat *generateStats) (common.Hash, error) { scheme := snaptree.triedb.Scheme()
got, err := generateTrieRoot(dst, scheme, acctIt, common.Hash{}, stackTrieGenerate, func(dst ethdb.KeyValueWriter, accountHash, codeHash common.Hash, stat *generateStats) (common.Hash, error) {
// Migrate the code first, commit the contract code into the tmp db. // Migrate the code first, commit the contract code into the tmp db.
if codeHash != emptyCode { if codeHash != emptyCode {
code := rawdb.ReadCode(src, codeHash) code := rawdb.ReadCode(src, codeHash)
@ -87,7 +88,7 @@ func GenerateTrie(snaptree *Tree, root common.Hash, src ethdb.Database, dst ethd
} }
defer storageIt.Release() defer storageIt.Release()
hash, err := generateTrieRoot(dst, storageIt, accountHash, stackTrieGenerate, nil, stat, false) hash, err := generateTrieRoot(dst, scheme, storageIt, accountHash, stackTrieGenerate, nil, stat, false)
if err != nil { if err != nil {
return common.Hash{}, err return common.Hash{}, err
} }
@ -242,7 +243,7 @@ func runReport(stats *generateStats, stop chan bool) {
// generateTrieRoot generates the trie hash based on the snapshot iterator. // generateTrieRoot generates the trie hash based on the snapshot iterator.
// It can be used for generating account trie, storage trie or even the // It can be used for generating account trie, storage trie or even the
// whole state which connects the accounts and the corresponding storages. // whole state which connects the accounts and the corresponding storages.
func generateTrieRoot(db ethdb.KeyValueWriter, it Iterator, account common.Hash, generatorFn trieGeneratorFn, leafCallback leafCallbackFn, stats *generateStats, report bool) (common.Hash, error) { func generateTrieRoot(db ethdb.KeyValueWriter, scheme trie.NodeScheme, it Iterator, account common.Hash, generatorFn trieGeneratorFn, leafCallback leafCallbackFn, stats *generateStats, report bool) (common.Hash, error) {
var ( var (
in = make(chan trieKV) // chan to pass leaves in = make(chan trieKV) // chan to pass leaves
out = make(chan common.Hash, 1) // chan to collect result out = make(chan common.Hash, 1) // chan to collect result
@ -253,7 +254,7 @@ func generateTrieRoot(db ethdb.KeyValueWriter, it Iterator, account common.Hash,
wg.Add(1) wg.Add(1)
go func() { go func() {
defer wg.Done() defer wg.Done()
generatorFn(db, account, in, out) generatorFn(db, scheme, account, in, out)
}() }()
// Spin up a go-routine for progress logging // Spin up a go-routine for progress logging
if report && stats != nil { if report && stats != nil {
@ -360,8 +361,14 @@ func generateTrieRoot(db ethdb.KeyValueWriter, it Iterator, account common.Hash,
return stop(nil) return stop(nil)
} }
func stackTrieGenerate(db ethdb.KeyValueWriter, owner common.Hash, in chan trieKV, out chan common.Hash) { func stackTrieGenerate(db ethdb.KeyValueWriter, scheme trie.NodeScheme, owner common.Hash, in chan trieKV, out chan common.Hash) {
t := trie.NewStackTrieWithOwner(db, owner) var nodeWriter trie.NodeWriteFunc
if db != nil {
nodeWriter = func(owner common.Hash, path []byte, hash common.Hash, blob []byte) {
scheme.WriteTrieNode(db, owner, path, hash, blob)
}
}
t := trie.NewStackTrieWithOwner(nodeWriter, owner)
for leaf := range in { for leaf := range in {
t.TryUpdate(leaf.key[:], leaf.value) t.TryUpdate(leaf.key[:], leaf.value)
} }

@ -29,7 +29,6 @@ import (
"github.com/ethereum/go-ethereum/core/rawdb" "github.com/ethereum/go-ethereum/core/rawdb"
"github.com/ethereum/go-ethereum/crypto" "github.com/ethereum/go-ethereum/crypto"
"github.com/ethereum/go-ethereum/ethdb" "github.com/ethereum/go-ethereum/ethdb"
"github.com/ethereum/go-ethereum/ethdb/memorydb"
"github.com/ethereum/go-ethereum/log" "github.com/ethereum/go-ethereum/log"
"github.com/ethereum/go-ethereum/rlp" "github.com/ethereum/go-ethereum/rlp"
"github.com/ethereum/go-ethereum/trie" "github.com/ethereum/go-ethereum/trie"
@ -360,9 +359,9 @@ func (dl *diskLayer) generateRange(ctx *generatorContext, trieId *trie.ID, prefi
} }
// We use the snap data to build up a cache which can be used by the // We use the snap data to build up a cache which can be used by the
// main account trie as a primary lookup when resolving hashes // main account trie as a primary lookup when resolving hashes
var snapNodeCache ethdb.KeyValueStore var snapNodeCache ethdb.Database
if len(result.keys) > 0 { if len(result.keys) > 0 {
snapNodeCache = memorydb.New() snapNodeCache = rawdb.NewMemoryDatabase()
snapTrieDb := trie.NewDatabase(snapNodeCache) snapTrieDb := trie.NewDatabase(snapNodeCache)
snapTrie := trie.NewEmpty(snapTrieDb) snapTrie := trie.NewEmpty(snapTrieDb)
for i, key := range result.keys { for i, key := range result.keys {

@ -117,12 +117,12 @@ func checkSnapRoot(t *testing.T, snap *diskLayer, trieRoot common.Hash) {
accIt := snap.AccountIterator(common.Hash{}) accIt := snap.AccountIterator(common.Hash{})
defer accIt.Release() defer accIt.Release()
snapRoot, err := generateTrieRoot(nil, accIt, common.Hash{}, stackTrieGenerate, snapRoot, err := generateTrieRoot(nil, nil, accIt, common.Hash{}, stackTrieGenerate,
func(db ethdb.KeyValueWriter, accountHash, codeHash common.Hash, stat *generateStats) (common.Hash, error) { func(db ethdb.KeyValueWriter, accountHash, codeHash common.Hash, stat *generateStats) (common.Hash, error) {
storageIt, _ := snap.StorageIterator(accountHash, common.Hash{}) storageIt, _ := snap.StorageIterator(accountHash, common.Hash{})
defer storageIt.Release() defer storageIt.Release()
hash, err := generateTrieRoot(nil, storageIt, accountHash, stackTrieGenerate, nil, stat, false) hash, err := generateTrieRoot(nil, nil, storageIt, accountHash, stackTrieGenerate, nil, stat, false)
if err != nil { if err != nil {
return common.Hash{}, err return common.Hash{}, err
} }

@ -776,14 +776,14 @@ func (t *Tree) Verify(root common.Hash) error {
} }
defer acctIt.Release() defer acctIt.Release()
got, err := generateTrieRoot(nil, acctIt, common.Hash{}, stackTrieGenerate, func(db ethdb.KeyValueWriter, accountHash, codeHash common.Hash, stat *generateStats) (common.Hash, error) { got, err := generateTrieRoot(nil, nil, acctIt, common.Hash{}, stackTrieGenerate, func(db ethdb.KeyValueWriter, accountHash, codeHash common.Hash, stat *generateStats) (common.Hash, error) {
storageIt, err := t.StorageIterator(root, accountHash, common.Hash{}) storageIt, err := t.StorageIterator(root, accountHash, common.Hash{})
if err != nil { if err != nil {
return common.Hash{}, err return common.Hash{}, err
} }
defer storageIt.Release() defer storageIt.Release()
hash, err := generateTrieRoot(nil, storageIt, accountHash, stackTrieGenerate, nil, stat, false) hash, err := generateTrieRoot(nil, nil, storageIt, accountHash, stackTrieGenerate, nil, stat, false)
if err != nil { if err != nil {
return common.Hash{}, err return common.Hash{}, err
} }

@ -27,7 +27,7 @@ import (
) )
// NewStateSync create a new state trie download scheduler. // NewStateSync create a new state trie download scheduler.
func NewStateSync(root common.Hash, database ethdb.KeyValueReader, onLeaf func(keys [][]byte, leaf []byte) error) *trie.Sync { func NewStateSync(root common.Hash, database ethdb.KeyValueReader, onLeaf func(keys [][]byte, leaf []byte) error, scheme trie.NodeScheme) *trie.Sync {
// Register the storage slot callback if the external callback is specified. // Register the storage slot callback if the external callback is specified.
var onSlot func(keys [][]byte, path []byte, leaf []byte, parent common.Hash, parentPath []byte) error var onSlot func(keys [][]byte, path []byte, leaf []byte, parent common.Hash, parentPath []byte) error
if onLeaf != nil { if onLeaf != nil {
@ -52,6 +52,6 @@ func NewStateSync(root common.Hash, database ethdb.KeyValueReader, onLeaf func(k
syncer.AddCodeEntry(common.BytesToHash(obj.CodeHash), path, parent, parentPath) syncer.AddCodeEntry(common.BytesToHash(obj.CodeHash), path, parent, parentPath)
return nil return nil
} }
syncer = trie.NewSync(root, database, onAccount) syncer = trie.NewSync(root, database, onAccount, scheme)
return syncer return syncer
} }

@ -39,10 +39,11 @@ type testAccount struct {
} }
// makeTestState create a sample test state to test node-wise reconstruction. // makeTestState create a sample test state to test node-wise reconstruction.
func makeTestState() (Database, common.Hash, []*testAccount) { func makeTestState() (ethdb.Database, Database, common.Hash, []*testAccount) {
// Create an empty state // Create an empty state
db := NewDatabase(rawdb.NewMemoryDatabase()) db := rawdb.NewMemoryDatabase()
state, _ := New(common.Hash{}, db, nil) sdb := NewDatabase(db)
state, _ := New(common.Hash{}, sdb, nil)
// Fill it with some arbitrary data // Fill it with some arbitrary data
var accounts []*testAccount var accounts []*testAccount
@ -63,7 +64,7 @@ func makeTestState() (Database, common.Hash, []*testAccount) {
if i%5 == 0 { if i%5 == 0 {
for j := byte(0); j < 5; j++ { for j := byte(0); j < 5; j++ {
hash := crypto.Keccak256Hash([]byte{i, i, i, i, i, j, j}) hash := crypto.Keccak256Hash([]byte{i, i, i, i, i, j, j})
obj.SetState(db, hash, hash) obj.SetState(sdb, hash, hash)
} }
} }
state.updateStateObject(obj) state.updateStateObject(obj)
@ -72,7 +73,7 @@ func makeTestState() (Database, common.Hash, []*testAccount) {
root, _ := state.Commit(false) root, _ := state.Commit(false)
// Return the generated state // Return the generated state
return db, root, accounts return db, sdb, root, accounts
} }
// checkStateAccounts cross references a reconstructed state with an expected // checkStateAccounts cross references a reconstructed state with an expected
@ -100,7 +101,7 @@ func checkStateAccounts(t *testing.T, db ethdb.Database, root common.Hash, accou
} }
// checkTrieConsistency checks that all nodes in a (sub-)trie are indeed present. // checkTrieConsistency checks that all nodes in a (sub-)trie are indeed present.
func checkTrieConsistency(db ethdb.KeyValueStore, root common.Hash) error { func checkTrieConsistency(db ethdb.Database, root common.Hash) error {
if v, _ := db.Get(root[:]); v == nil { if v, _ := db.Get(root[:]); v == nil {
return nil // Consider a non existent state consistent. return nil // Consider a non existent state consistent.
} }
@ -132,8 +133,9 @@ func checkStateConsistency(db ethdb.Database, root common.Hash) error {
// Tests that an empty state is not scheduled for syncing. // Tests that an empty state is not scheduled for syncing.
func TestEmptyStateSync(t *testing.T) { func TestEmptyStateSync(t *testing.T) {
db := trie.NewDatabase(rawdb.NewMemoryDatabase())
empty := common.HexToHash("56e81f171bcc55a6ff8345e692c0f86e5b48e01b996cadc001622fb5e363b421") empty := common.HexToHash("56e81f171bcc55a6ff8345e692c0f86e5b48e01b996cadc001622fb5e363b421")
sync := NewStateSync(empty, rawdb.NewMemoryDatabase(), nil) sync := NewStateSync(empty, rawdb.NewMemoryDatabase(), nil, db.Scheme())
if paths, nodes, codes := sync.Missing(1); len(paths) != 0 || len(nodes) != 0 || len(codes) != 0 { if paths, nodes, codes := sync.Missing(1); len(paths) != 0 || len(nodes) != 0 || len(codes) != 0 {
t.Errorf("content requested for empty state: %v, %v, %v", nodes, paths, codes) t.Errorf("content requested for empty state: %v, %v, %v", nodes, paths, codes)
} }
@ -170,7 +172,7 @@ type stateElement struct {
func testIterativeStateSync(t *testing.T, count int, commit bool, bypath bool) { func testIterativeStateSync(t *testing.T, count int, commit bool, bypath bool) {
// Create a random state to copy // Create a random state to copy
srcDb, srcRoot, srcAccounts := makeTestState() _, srcDb, srcRoot, srcAccounts := makeTestState()
if commit { if commit {
srcDb.TrieDB().Commit(srcRoot, false, nil) srcDb.TrieDB().Commit(srcRoot, false, nil)
} }
@ -178,7 +180,7 @@ func testIterativeStateSync(t *testing.T, count int, commit bool, bypath bool) {
// Create a destination state and sync with the scheduler // Create a destination state and sync with the scheduler
dstDb := rawdb.NewMemoryDatabase() dstDb := rawdb.NewMemoryDatabase()
sched := NewStateSync(srcRoot, dstDb, nil) sched := NewStateSync(srcRoot, dstDb, nil, srcDb.TrieDB().Scheme())
var ( var (
nodeElements []stateElement nodeElements []stateElement
@ -281,11 +283,11 @@ func testIterativeStateSync(t *testing.T, count int, commit bool, bypath bool) {
// partial results are returned, and the others sent only later. // partial results are returned, and the others sent only later.
func TestIterativeDelayedStateSync(t *testing.T) { func TestIterativeDelayedStateSync(t *testing.T) {
// Create a random state to copy // Create a random state to copy
srcDb, srcRoot, srcAccounts := makeTestState() _, srcDb, srcRoot, srcAccounts := makeTestState()
// Create a destination state and sync with the scheduler // Create a destination state and sync with the scheduler
dstDb := rawdb.NewMemoryDatabase() dstDb := rawdb.NewMemoryDatabase()
sched := NewStateSync(srcRoot, dstDb, nil) sched := NewStateSync(srcRoot, dstDb, nil, srcDb.TrieDB().Scheme())
var ( var (
nodeElements []stateElement nodeElements []stateElement
@ -374,11 +376,11 @@ func TestIterativeRandomStateSyncBatched(t *testing.T) { testIterativeRandomS
func testIterativeRandomStateSync(t *testing.T, count int) { func testIterativeRandomStateSync(t *testing.T, count int) {
// Create a random state to copy // Create a random state to copy
srcDb, srcRoot, srcAccounts := makeTestState() _, srcDb, srcRoot, srcAccounts := makeTestState()
// Create a destination state and sync with the scheduler // Create a destination state and sync with the scheduler
dstDb := rawdb.NewMemoryDatabase() dstDb := rawdb.NewMemoryDatabase()
sched := NewStateSync(srcRoot, dstDb, nil) sched := NewStateSync(srcRoot, dstDb, nil, srcDb.TrieDB().Scheme())
nodeQueue := make(map[string]stateElement) nodeQueue := make(map[string]stateElement)
codeQueue := make(map[common.Hash]struct{}) codeQueue := make(map[common.Hash]struct{})
@ -454,11 +456,11 @@ func testIterativeRandomStateSync(t *testing.T, count int) {
// partial results are returned (Even those randomly), others sent only later. // partial results are returned (Even those randomly), others sent only later.
func TestIterativeRandomDelayedStateSync(t *testing.T) { func TestIterativeRandomDelayedStateSync(t *testing.T) {
// Create a random state to copy // Create a random state to copy
srcDb, srcRoot, srcAccounts := makeTestState() _, srcDb, srcRoot, srcAccounts := makeTestState()
// Create a destination state and sync with the scheduler // Create a destination state and sync with the scheduler
dstDb := rawdb.NewMemoryDatabase() dstDb := rawdb.NewMemoryDatabase()
sched := NewStateSync(srcRoot, dstDb, nil) sched := NewStateSync(srcRoot, dstDb, nil, srcDb.TrieDB().Scheme())
nodeQueue := make(map[string]stateElement) nodeQueue := make(map[string]stateElement)
codeQueue := make(map[common.Hash]struct{}) codeQueue := make(map[common.Hash]struct{})
@ -544,7 +546,7 @@ func TestIterativeRandomDelayedStateSync(t *testing.T) {
// the database. // the database.
func TestIncompleteStateSync(t *testing.T) { func TestIncompleteStateSync(t *testing.T) {
// Create a random state to copy // Create a random state to copy
srcDb, srcRoot, srcAccounts := makeTestState() db, srcDb, srcRoot, srcAccounts := makeTestState()
// isCodeLookup to save some hashing // isCodeLookup to save some hashing
var isCode = make(map[common.Hash]struct{}) var isCode = make(map[common.Hash]struct{})
@ -554,15 +556,16 @@ func TestIncompleteStateSync(t *testing.T) {
} }
} }
isCode[common.BytesToHash(emptyCodeHash)] = struct{}{} isCode[common.BytesToHash(emptyCodeHash)] = struct{}{}
checkTrieConsistency(srcDb.DiskDB(), srcRoot) checkTrieConsistency(db, srcRoot)
// Create a destination state and sync with the scheduler // Create a destination state and sync with the scheduler
dstDb := rawdb.NewMemoryDatabase() dstDb := rawdb.NewMemoryDatabase()
sched := NewStateSync(srcRoot, dstDb, nil) sched := NewStateSync(srcRoot, dstDb, nil, srcDb.TrieDB().Scheme())
var ( var (
addedCodes []common.Hash addedCodes []common.Hash
addedNodes []common.Hash addedPaths []string
addedHashes []common.Hash
) )
nodeQueue := make(map[string]stateElement) nodeQueue := make(map[string]stateElement)
codeQueue := make(map[common.Hash]struct{}) codeQueue := make(map[common.Hash]struct{})
@ -599,15 +602,16 @@ func TestIncompleteStateSync(t *testing.T) {
var nodehashes []common.Hash var nodehashes []common.Hash
if len(nodeQueue) > 0 { if len(nodeQueue) > 0 {
results := make([]trie.NodeSyncResult, 0, len(nodeQueue)) results := make([]trie.NodeSyncResult, 0, len(nodeQueue))
for key, element := range nodeQueue { for path, element := range nodeQueue {
data, err := srcDb.TrieDB().Node(element.hash) data, err := srcDb.TrieDB().Node(element.hash)
if err != nil { if err != nil {
t.Fatalf("failed to retrieve node data for %x", element.hash) t.Fatalf("failed to retrieve node data for %x", element.hash)
} }
results = append(results, trie.NodeSyncResult{Path: key, Data: data}) results = append(results, trie.NodeSyncResult{Path: path, Data: data})
if element.hash != srcRoot { if element.hash != srcRoot {
addedNodes = append(addedNodes, element.hash) addedPaths = append(addedPaths, element.path)
addedHashes = append(addedHashes, element.hash)
} }
nodehashes = append(nodehashes, element.hash) nodehashes = append(nodehashes, element.hash)
} }
@ -655,12 +659,18 @@ func TestIncompleteStateSync(t *testing.T) {
} }
rawdb.WriteCode(dstDb, node, val) rawdb.WriteCode(dstDb, node, val)
} }
for _, node := range addedNodes { scheme := srcDb.TrieDB().Scheme()
val := rawdb.ReadTrieNode(dstDb, node) for i, path := range addedPaths {
rawdb.DeleteTrieNode(dstDb, node) owner, inner := trie.ResolvePath([]byte(path))
hash := addedHashes[i]
val := scheme.ReadTrieNode(dstDb, owner, inner, hash)
if val == nil {
t.Error("missing trie node")
}
scheme.DeleteTrieNode(dstDb, owner, inner, hash)
if err := checkStateConsistency(dstDb, srcRoot); err == nil { if err := checkStateConsistency(dstDb, srcRoot); err == nil {
t.Errorf("trie inconsistency not caught, missing: %v", node.Hex()) t.Errorf("trie inconsistency not caught, missing: %v", path)
} }
rawdb.WriteTrieNode(dstDb, node, val) scheme.WriteTrieNode(dstDb, owner, inner, hash, val)
} }
} }

@ -35,6 +35,7 @@ import (
"github.com/ethereum/go-ethereum/event" "github.com/ethereum/go-ethereum/event"
"github.com/ethereum/go-ethereum/log" "github.com/ethereum/go-ethereum/log"
"github.com/ethereum/go-ethereum/params" "github.com/ethereum/go-ethereum/params"
"github.com/ethereum/go-ethereum/trie"
) )
var ( var (
@ -206,6 +207,10 @@ type BlockChain interface {
// Snapshots returns the blockchain snapshot tree to paused it during sync. // Snapshots returns the blockchain snapshot tree to paused it during sync.
Snapshots() *snapshot.Tree Snapshots() *snapshot.Tree
// TrieDB retrieves the low level trie database used for interacting
// with trie nodes.
TrieDB() *trie.Database
} }
// New creates a new downloader to fetch hashes and blocks from remote peers. // New creates a new downloader to fetch hashes and blocks from remote peers.
@ -224,7 +229,7 @@ func New(checkpoint uint64, stateDb ethdb.Database, mux *event.TypeMux, chain Bl
dropPeer: dropPeer, dropPeer: dropPeer,
headerProcCh: make(chan *headerTask, 1), headerProcCh: make(chan *headerTask, 1),
quitCh: make(chan struct{}), quitCh: make(chan struct{}),
SnapSyncer: snap.NewSyncer(stateDb), SnapSyncer: snap.NewSyncer(stateDb, chain.TrieDB().Scheme()),
stateSyncStart: make(chan *stateSync), stateSyncStart: make(chan *stateSync),
} }
dl.skeleton = newSkeleton(stateDb, dl.peers, dropPeer, newBeaconBackfiller(dl, success)) dl.skeleton = newSkeleton(stateDb, dl.peers, dropPeer, newBeaconBackfiller(dl, success))

@ -418,6 +418,7 @@ type SyncPeer interface {
// - The peer delivers a refusal to serve the requested state // - The peer delivers a refusal to serve the requested state
type Syncer struct { type Syncer struct {
db ethdb.KeyValueStore // Database to store the trie nodes into (and dedup) db ethdb.KeyValueStore // Database to store the trie nodes into (and dedup)
scheme trie.NodeScheme // Node scheme used in node database
root common.Hash // Current state trie root being synced root common.Hash // Current state trie root being synced
tasks []*accountTask // Current account task set being synced tasks []*accountTask // Current account task set being synced
@ -485,9 +486,10 @@ type Syncer struct {
// NewSyncer creates a new snapshot syncer to download the Ethereum state over the // NewSyncer creates a new snapshot syncer to download the Ethereum state over the
// snap protocol. // snap protocol.
func NewSyncer(db ethdb.KeyValueStore) *Syncer { func NewSyncer(db ethdb.KeyValueStore, scheme trie.NodeScheme) *Syncer {
return &Syncer{ return &Syncer{
db: db, db: db,
scheme: scheme,
peers: make(map[string]SyncPeer), peers: make(map[string]SyncPeer),
peerJoin: new(event.Feed), peerJoin: new(event.Feed),
@ -581,7 +583,7 @@ func (s *Syncer) Sync(root common.Hash, cancel chan struct{}) error {
s.lock.Lock() s.lock.Lock()
s.root = root s.root = root
s.healer = &healTask{ s.healer = &healTask{
scheduler: state.NewStateSync(root, s.db, s.onHealState), scheduler: state.NewStateSync(root, s.db, s.onHealState, s.scheme),
trieTasks: make(map[string]common.Hash), trieTasks: make(map[string]common.Hash),
codeTasks: make(map[common.Hash]struct{}), codeTasks: make(map[common.Hash]struct{}),
} }
@ -743,8 +745,9 @@ func (s *Syncer) loadSyncStatus() {
s.accountBytes += common.StorageSize(len(key) + len(value)) s.accountBytes += common.StorageSize(len(key) + len(value))
}, },
} }
task.genTrie = trie.NewStackTrie(task.genBatch) task.genTrie = trie.NewStackTrie(func(owner common.Hash, path []byte, hash common.Hash, val []byte) {
s.scheme.WriteTrieNode(task.genBatch, owner, path, hash, val)
})
for accountHash, subtasks := range task.SubTasks { for accountHash, subtasks := range task.SubTasks {
for _, subtask := range subtasks { for _, subtask := range subtasks {
subtask.genBatch = ethdb.HookedBatch{ subtask.genBatch = ethdb.HookedBatch{
@ -753,7 +756,9 @@ func (s *Syncer) loadSyncStatus() {
s.storageBytes += common.StorageSize(len(key) + len(value)) s.storageBytes += common.StorageSize(len(key) + len(value))
}, },
} }
subtask.genTrie = trie.NewStackTrieWithOwner(subtask.genBatch, accountHash) subtask.genTrie = trie.NewStackTrieWithOwner(func(owner common.Hash, path []byte, hash common.Hash, val []byte) {
s.scheme.WriteTrieNode(subtask.genBatch, owner, path, hash, val)
}, accountHash)
} }
} }
} }
@ -810,7 +815,9 @@ func (s *Syncer) loadSyncStatus() {
Last: last, Last: last,
SubTasks: make(map[common.Hash][]*storageTask), SubTasks: make(map[common.Hash][]*storageTask),
genBatch: batch, genBatch: batch,
genTrie: trie.NewStackTrie(batch), genTrie: trie.NewStackTrie(func(owner common.Hash, path []byte, hash common.Hash, val []byte) {
s.scheme.WriteTrieNode(batch, owner, path, hash, val)
}),
}) })
log.Debug("Created account sync task", "from", next, "last", last) log.Debug("Created account sync task", "from", next, "last", last)
next = common.BigToHash(new(big.Int).Add(last.Big(), common.Big1)) next = common.BigToHash(new(big.Int).Add(last.Big(), common.Big1))
@ -1835,7 +1842,7 @@ func (s *Syncer) processAccountResponse(res *accountResponse) {
} }
// Check if the account is a contract with an unknown storage trie // Check if the account is a contract with an unknown storage trie
if account.Root != emptyRoot { if account.Root != emptyRoot {
if ok, err := s.db.Has(account.Root[:]); err != nil || !ok { if !s.scheme.HasTrieNode(s.db, res.hashes[i], nil, account.Root) {
// If there was a previous large state retrieval in progress, // If there was a previous large state retrieval in progress,
// don't restart it from scratch. This happens if a sync cycle // don't restart it from scratch. This happens if a sync cycle
// is interrupted and resumed later. However, *do* update the // is interrupted and resumed later. However, *do* update the
@ -2007,7 +2014,9 @@ func (s *Syncer) processStorageResponse(res *storageResponse) {
Last: r.End(), Last: r.End(),
root: acc.Root, root: acc.Root,
genBatch: batch, genBatch: batch,
genTrie: trie.NewStackTrieWithOwner(batch, account), genTrie: trie.NewStackTrieWithOwner(func(owner common.Hash, path []byte, hash common.Hash, val []byte) {
s.scheme.WriteTrieNode(batch, owner, path, hash, val)
}, account),
}) })
for r.Next() { for r.Next() {
batch := ethdb.HookedBatch{ batch := ethdb.HookedBatch{
@ -2021,7 +2030,9 @@ func (s *Syncer) processStorageResponse(res *storageResponse) {
Last: r.End(), Last: r.End(),
root: acc.Root, root: acc.Root,
genBatch: batch, genBatch: batch,
genTrie: trie.NewStackTrieWithOwner(batch, account), genTrie: trie.NewStackTrieWithOwner(func(owner common.Hash, path []byte, hash common.Hash, val []byte) {
s.scheme.WriteTrieNode(batch, owner, path, hash, val)
}, account),
}) })
} }
for _, task := range tasks { for _, task := range tasks {
@ -2066,7 +2077,9 @@ func (s *Syncer) processStorageResponse(res *storageResponse) {
slots += len(res.hashes[i]) slots += len(res.hashes[i])
if i < len(res.hashes)-1 || res.subTask == nil { if i < len(res.hashes)-1 || res.subTask == nil {
tr := trie.NewStackTrieWithOwner(batch, account) tr := trie.NewStackTrieWithOwner(func(owner common.Hash, path []byte, hash common.Hash, val []byte) {
s.scheme.WriteTrieNode(batch, owner, path, hash, val)
}, account)
for j := 0; j < len(res.hashes[i]); j++ { for j := 0; j < len(res.hashes[i]); j++ {
tr.Update(res.hashes[i][j][:], res.slots[i][j]) tr.Update(res.hashes[i][j][:], res.slots[i][j])
} }

@ -159,6 +159,13 @@ func newTestPeer(id string, t *testing.T, term func()) *testPeer {
return peer return peer
} }
func (t *testPeer) setStorageTries(tries map[common.Hash]*trie.Trie) {
t.storageTries = make(map[common.Hash]*trie.Trie)
for root, trie := range tries {
t.storageTries[root] = trie.Copy()
}
}
func (t *testPeer) ID() string { return t.id } func (t *testPeer) ID() string { return t.id }
func (t *testPeer) Log() log.Logger { return t.logger } func (t *testPeer) Log() log.Logger { return t.logger }
@ -562,9 +569,9 @@ func TestSyncBloatedProof(t *testing.T) {
}) })
} }
) )
sourceAccountTrie, elems := makeAccountTrieNoStorage(100) nodeScheme, sourceAccountTrie, elems := makeAccountTrieNoStorage(100)
source := newTestPeer("source", t, term) source := newTestPeer("source", t, term)
source.accountTrie = sourceAccountTrie source.accountTrie = sourceAccountTrie.Copy()
source.accountValues = elems source.accountValues = elems
source.accountRequestHandler = func(t *testPeer, requestId uint64, root common.Hash, origin common.Hash, limit common.Hash, cap uint64) error { source.accountRequestHandler = func(t *testPeer, requestId uint64, root common.Hash, origin common.Hash, limit common.Hash, cap uint64) error {
@ -610,15 +617,15 @@ func TestSyncBloatedProof(t *testing.T) {
} }
return nil return nil
} }
syncer := setupSyncer(source) syncer := setupSyncer(nodeScheme, source)
if err := syncer.Sync(sourceAccountTrie.Hash(), cancel); err == nil { if err := syncer.Sync(sourceAccountTrie.Hash(), cancel); err == nil {
t.Fatal("No error returned from incomplete/cancelled sync") t.Fatal("No error returned from incomplete/cancelled sync")
} }
} }
func setupSyncer(peers ...*testPeer) *Syncer { func setupSyncer(scheme trie.NodeScheme, peers ...*testPeer) *Syncer {
stateDb := rawdb.NewMemoryDatabase() stateDb := rawdb.NewMemoryDatabase()
syncer := NewSyncer(stateDb) syncer := NewSyncer(stateDb, scheme)
for _, peer := range peers { for _, peer := range peers {
syncer.Register(peer) syncer.Register(peer)
peer.remote = syncer peer.remote = syncer
@ -639,15 +646,15 @@ func TestSync(t *testing.T) {
}) })
} }
) )
sourceAccountTrie, elems := makeAccountTrieNoStorage(100) nodeScheme, sourceAccountTrie, elems := makeAccountTrieNoStorage(100)
mkSource := func(name string) *testPeer { mkSource := func(name string) *testPeer {
source := newTestPeer(name, t, term) source := newTestPeer(name, t, term)
source.accountTrie = sourceAccountTrie source.accountTrie = sourceAccountTrie.Copy()
source.accountValues = elems source.accountValues = elems
return source return source
} }
syncer := setupSyncer(mkSource("source")) syncer := setupSyncer(nodeScheme, mkSource("source"))
if err := syncer.Sync(sourceAccountTrie.Hash(), cancel); err != nil { if err := syncer.Sync(sourceAccountTrie.Hash(), cancel); err != nil {
t.Fatalf("sync failed: %v", err) t.Fatalf("sync failed: %v", err)
} }
@ -668,15 +675,15 @@ func TestSyncTinyTriePanic(t *testing.T) {
}) })
} }
) )
sourceAccountTrie, elems := makeAccountTrieNoStorage(1) nodeScheme, sourceAccountTrie, elems := makeAccountTrieNoStorage(1)
mkSource := func(name string) *testPeer { mkSource := func(name string) *testPeer {
source := newTestPeer(name, t, term) source := newTestPeer(name, t, term)
source.accountTrie = sourceAccountTrie source.accountTrie = sourceAccountTrie.Copy()
source.accountValues = elems source.accountValues = elems
return source return source
} }
syncer := setupSyncer(mkSource("source")) syncer := setupSyncer(nodeScheme, mkSource("source"))
done := checkStall(t, term) done := checkStall(t, term)
if err := syncer.Sync(sourceAccountTrie.Hash(), cancel); err != nil { if err := syncer.Sync(sourceAccountTrie.Hash(), cancel); err != nil {
t.Fatalf("sync failed: %v", err) t.Fatalf("sync failed: %v", err)
@ -698,15 +705,15 @@ func TestMultiSync(t *testing.T) {
}) })
} }
) )
sourceAccountTrie, elems := makeAccountTrieNoStorage(100) nodeScheme, sourceAccountTrie, elems := makeAccountTrieNoStorage(100)
mkSource := func(name string) *testPeer { mkSource := func(name string) *testPeer {
source := newTestPeer(name, t, term) source := newTestPeer(name, t, term)
source.accountTrie = sourceAccountTrie source.accountTrie = sourceAccountTrie.Copy()
source.accountValues = elems source.accountValues = elems
return source return source
} }
syncer := setupSyncer(mkSource("sourceA"), mkSource("sourceB")) syncer := setupSyncer(nodeScheme, mkSource("sourceA"), mkSource("sourceB"))
done := checkStall(t, term) done := checkStall(t, term)
if err := syncer.Sync(sourceAccountTrie.Hash(), cancel); err != nil { if err := syncer.Sync(sourceAccountTrie.Hash(), cancel); err != nil {
t.Fatalf("sync failed: %v", err) t.Fatalf("sync failed: %v", err)
@ -728,17 +735,17 @@ func TestSyncWithStorage(t *testing.T) {
}) })
} }
) )
sourceAccountTrie, elems, storageTries, storageElems := makeAccountTrieWithStorage(3, 3000, true, false) nodeScheme, sourceAccountTrie, elems, storageTries, storageElems := makeAccountTrieWithStorage(3, 3000, true, false)
mkSource := func(name string) *testPeer { mkSource := func(name string) *testPeer {
source := newTestPeer(name, t, term) source := newTestPeer(name, t, term)
source.accountTrie = sourceAccountTrie source.accountTrie = sourceAccountTrie.Copy()
source.accountValues = elems source.accountValues = elems
source.storageTries = storageTries source.setStorageTries(storageTries)
source.storageValues = storageElems source.storageValues = storageElems
return source return source
} }
syncer := setupSyncer(mkSource("sourceA")) syncer := setupSyncer(nodeScheme, mkSource("sourceA"))
done := checkStall(t, term) done := checkStall(t, term)
if err := syncer.Sync(sourceAccountTrie.Hash(), cancel); err != nil { if err := syncer.Sync(sourceAccountTrie.Hash(), cancel); err != nil {
t.Fatalf("sync failed: %v", err) t.Fatalf("sync failed: %v", err)
@ -760,13 +767,13 @@ func TestMultiSyncManyUseless(t *testing.T) {
}) })
} }
) )
sourceAccountTrie, elems, storageTries, storageElems := makeAccountTrieWithStorage(100, 3000, true, false) nodeScheme, sourceAccountTrie, elems, storageTries, storageElems := makeAccountTrieWithStorage(100, 3000, true, false)
mkSource := func(name string, noAccount, noStorage, noTrieNode bool) *testPeer { mkSource := func(name string, noAccount, noStorage, noTrieNode bool) *testPeer {
source := newTestPeer(name, t, term) source := newTestPeer(name, t, term)
source.accountTrie = sourceAccountTrie source.accountTrie = sourceAccountTrie.Copy()
source.accountValues = elems source.accountValues = elems
source.storageTries = storageTries source.setStorageTries(storageTries)
source.storageValues = storageElems source.storageValues = storageElems
if !noAccount { if !noAccount {
@ -782,6 +789,7 @@ func TestMultiSyncManyUseless(t *testing.T) {
} }
syncer := setupSyncer( syncer := setupSyncer(
nodeScheme,
mkSource("full", true, true, true), mkSource("full", true, true, true),
mkSource("noAccounts", false, true, true), mkSource("noAccounts", false, true, true),
mkSource("noStorage", true, false, true), mkSource("noStorage", true, false, true),
@ -806,13 +814,13 @@ func TestMultiSyncManyUselessWithLowTimeout(t *testing.T) {
}) })
} }
) )
sourceAccountTrie, elems, storageTries, storageElems := makeAccountTrieWithStorage(100, 3000, true, false) nodeScheme, sourceAccountTrie, elems, storageTries, storageElems := makeAccountTrieWithStorage(100, 3000, true, false)
mkSource := func(name string, noAccount, noStorage, noTrieNode bool) *testPeer { mkSource := func(name string, noAccount, noStorage, noTrieNode bool) *testPeer {
source := newTestPeer(name, t, term) source := newTestPeer(name, t, term)
source.accountTrie = sourceAccountTrie source.accountTrie = sourceAccountTrie.Copy()
source.accountValues = elems source.accountValues = elems
source.storageTries = storageTries source.setStorageTries(storageTries)
source.storageValues = storageElems source.storageValues = storageElems
if !noAccount { if !noAccount {
@ -828,6 +836,7 @@ func TestMultiSyncManyUselessWithLowTimeout(t *testing.T) {
} }
syncer := setupSyncer( syncer := setupSyncer(
nodeScheme,
mkSource("full", true, true, true), mkSource("full", true, true, true),
mkSource("noAccounts", false, true, true), mkSource("noAccounts", false, true, true),
mkSource("noStorage", true, false, true), mkSource("noStorage", true, false, true),
@ -857,13 +866,13 @@ func TestMultiSyncManyUnresponsive(t *testing.T) {
}) })
} }
) )
sourceAccountTrie, elems, storageTries, storageElems := makeAccountTrieWithStorage(100, 3000, true, false) nodeScheme, sourceAccountTrie, elems, storageTries, storageElems := makeAccountTrieWithStorage(100, 3000, true, false)
mkSource := func(name string, noAccount, noStorage, noTrieNode bool) *testPeer { mkSource := func(name string, noAccount, noStorage, noTrieNode bool) *testPeer {
source := newTestPeer(name, t, term) source := newTestPeer(name, t, term)
source.accountTrie = sourceAccountTrie source.accountTrie = sourceAccountTrie.Copy()
source.accountValues = elems source.accountValues = elems
source.storageTries = storageTries source.setStorageTries(storageTries)
source.storageValues = storageElems source.storageValues = storageElems
if !noAccount { if !noAccount {
@ -879,6 +888,7 @@ func TestMultiSyncManyUnresponsive(t *testing.T) {
} }
syncer := setupSyncer( syncer := setupSyncer(
nodeScheme,
mkSource("full", true, true, true), mkSource("full", true, true, true),
mkSource("noAccounts", false, true, true), mkSource("noAccounts", false, true, true),
mkSource("noStorage", true, false, true), mkSource("noStorage", true, false, true),
@ -923,15 +933,16 @@ func TestSyncBoundaryAccountTrie(t *testing.T) {
}) })
} }
) )
sourceAccountTrie, elems := makeBoundaryAccountTrie(3000) nodeScheme, sourceAccountTrie, elems := makeBoundaryAccountTrie(3000)
mkSource := func(name string) *testPeer { mkSource := func(name string) *testPeer {
source := newTestPeer(name, t, term) source := newTestPeer(name, t, term)
source.accountTrie = sourceAccountTrie source.accountTrie = sourceAccountTrie.Copy()
source.accountValues = elems source.accountValues = elems
return source return source
} }
syncer := setupSyncer( syncer := setupSyncer(
nodeScheme,
mkSource("peer-a"), mkSource("peer-a"),
mkSource("peer-b"), mkSource("peer-b"),
) )
@ -957,11 +968,11 @@ func TestSyncNoStorageAndOneCappedPeer(t *testing.T) {
}) })
} }
) )
sourceAccountTrie, elems := makeAccountTrieNoStorage(3000) nodeScheme, sourceAccountTrie, elems := makeAccountTrieNoStorage(3000)
mkSource := func(name string, slow bool) *testPeer { mkSource := func(name string, slow bool) *testPeer {
source := newTestPeer(name, t, term) source := newTestPeer(name, t, term)
source.accountTrie = sourceAccountTrie source.accountTrie = sourceAccountTrie.Copy()
source.accountValues = elems source.accountValues = elems
if slow { if slow {
@ -971,6 +982,7 @@ func TestSyncNoStorageAndOneCappedPeer(t *testing.T) {
} }
syncer := setupSyncer( syncer := setupSyncer(
nodeScheme,
mkSource("nice-a", false), mkSource("nice-a", false),
mkSource("nice-b", false), mkSource("nice-b", false),
mkSource("nice-c", false), mkSource("nice-c", false),
@ -998,11 +1010,11 @@ func TestSyncNoStorageAndOneCodeCorruptPeer(t *testing.T) {
}) })
} }
) )
sourceAccountTrie, elems := makeAccountTrieNoStorage(3000) nodeScheme, sourceAccountTrie, elems := makeAccountTrieNoStorage(3000)
mkSource := func(name string, codeFn codeHandlerFunc) *testPeer { mkSource := func(name string, codeFn codeHandlerFunc) *testPeer {
source := newTestPeer(name, t, term) source := newTestPeer(name, t, term)
source.accountTrie = sourceAccountTrie source.accountTrie = sourceAccountTrie.Copy()
source.accountValues = elems source.accountValues = elems
source.codeRequestHandler = codeFn source.codeRequestHandler = codeFn
return source return source
@ -1012,6 +1024,7 @@ func TestSyncNoStorageAndOneCodeCorruptPeer(t *testing.T) {
// non-corrupt peer, which delivers everything in one go, and makes the // non-corrupt peer, which delivers everything in one go, and makes the
// test moot // test moot
syncer := setupSyncer( syncer := setupSyncer(
nodeScheme,
mkSource("capped", cappedCodeRequestHandler), mkSource("capped", cappedCodeRequestHandler),
mkSource("corrupt", corruptCodeRequestHandler), mkSource("corrupt", corruptCodeRequestHandler),
) )
@ -1035,11 +1048,11 @@ func TestSyncNoStorageAndOneAccountCorruptPeer(t *testing.T) {
}) })
} }
) )
sourceAccountTrie, elems := makeAccountTrieNoStorage(3000) nodeScheme, sourceAccountTrie, elems := makeAccountTrieNoStorage(3000)
mkSource := func(name string, accFn accountHandlerFunc) *testPeer { mkSource := func(name string, accFn accountHandlerFunc) *testPeer {
source := newTestPeer(name, t, term) source := newTestPeer(name, t, term)
source.accountTrie = sourceAccountTrie source.accountTrie = sourceAccountTrie.Copy()
source.accountValues = elems source.accountValues = elems
source.accountRequestHandler = accFn source.accountRequestHandler = accFn
return source return source
@ -1049,6 +1062,7 @@ func TestSyncNoStorageAndOneAccountCorruptPeer(t *testing.T) {
// non-corrupt peer, which delivers everything in one go, and makes the // non-corrupt peer, which delivers everything in one go, and makes the
// test moot // test moot
syncer := setupSyncer( syncer := setupSyncer(
nodeScheme,
mkSource("capped", defaultAccountRequestHandler), mkSource("capped", defaultAccountRequestHandler),
mkSource("corrupt", corruptAccountRequestHandler), mkSource("corrupt", corruptAccountRequestHandler),
) )
@ -1074,11 +1088,11 @@ func TestSyncNoStorageAndOneCodeCappedPeer(t *testing.T) {
}) })
} }
) )
sourceAccountTrie, elems := makeAccountTrieNoStorage(3000) nodeScheme, sourceAccountTrie, elems := makeAccountTrieNoStorage(3000)
mkSource := func(name string, codeFn codeHandlerFunc) *testPeer { mkSource := func(name string, codeFn codeHandlerFunc) *testPeer {
source := newTestPeer(name, t, term) source := newTestPeer(name, t, term)
source.accountTrie = sourceAccountTrie source.accountTrie = sourceAccountTrie.Copy()
source.accountValues = elems source.accountValues = elems
source.codeRequestHandler = codeFn source.codeRequestHandler = codeFn
return source return source
@ -1087,6 +1101,7 @@ func TestSyncNoStorageAndOneCodeCappedPeer(t *testing.T) {
// so it shouldn't be more than that // so it shouldn't be more than that
var counter int var counter int
syncer := setupSyncer( syncer := setupSyncer(
nodeScheme,
mkSource("capped", func(t *testPeer, id uint64, hashes []common.Hash, max uint64) error { mkSource("capped", func(t *testPeer, id uint64, hashes []common.Hash, max uint64) error {
counter++ counter++
return cappedCodeRequestHandler(t, id, hashes, max) return cappedCodeRequestHandler(t, id, hashes, max)
@ -1124,17 +1139,18 @@ func TestSyncBoundaryStorageTrie(t *testing.T) {
}) })
} }
) )
sourceAccountTrie, elems, storageTries, storageElems := makeAccountTrieWithStorage(10, 1000, false, true) nodeScheme, sourceAccountTrie, elems, storageTries, storageElems := makeAccountTrieWithStorage(10, 1000, false, true)
mkSource := func(name string) *testPeer { mkSource := func(name string) *testPeer {
source := newTestPeer(name, t, term) source := newTestPeer(name, t, term)
source.accountTrie = sourceAccountTrie source.accountTrie = sourceAccountTrie.Copy()
source.accountValues = elems source.accountValues = elems
source.storageTries = storageTries source.setStorageTries(storageTries)
source.storageValues = storageElems source.storageValues = storageElems
return source return source
} }
syncer := setupSyncer( syncer := setupSyncer(
nodeScheme,
mkSource("peer-a"), mkSource("peer-a"),
mkSource("peer-b"), mkSource("peer-b"),
) )
@ -1160,13 +1176,13 @@ func TestSyncWithStorageAndOneCappedPeer(t *testing.T) {
}) })
} }
) )
sourceAccountTrie, elems, storageTries, storageElems := makeAccountTrieWithStorage(300, 1000, false, false) nodeScheme, sourceAccountTrie, elems, storageTries, storageElems := makeAccountTrieWithStorage(300, 1000, false, false)
mkSource := func(name string, slow bool) *testPeer { mkSource := func(name string, slow bool) *testPeer {
source := newTestPeer(name, t, term) source := newTestPeer(name, t, term)
source.accountTrie = sourceAccountTrie source.accountTrie = sourceAccountTrie.Copy()
source.accountValues = elems source.accountValues = elems
source.storageTries = storageTries source.setStorageTries(storageTries)
source.storageValues = storageElems source.storageValues = storageElems
if slow { if slow {
@ -1176,6 +1192,7 @@ func TestSyncWithStorageAndOneCappedPeer(t *testing.T) {
} }
syncer := setupSyncer( syncer := setupSyncer(
nodeScheme,
mkSource("nice-a", false), mkSource("nice-a", false),
mkSource("slow", true), mkSource("slow", true),
) )
@ -1201,19 +1218,20 @@ func TestSyncWithStorageAndCorruptPeer(t *testing.T) {
}) })
} }
) )
sourceAccountTrie, elems, storageTries, storageElems := makeAccountTrieWithStorage(100, 3000, true, false) nodeScheme, sourceAccountTrie, elems, storageTries, storageElems := makeAccountTrieWithStorage(100, 3000, true, false)
mkSource := func(name string, handler storageHandlerFunc) *testPeer { mkSource := func(name string, handler storageHandlerFunc) *testPeer {
source := newTestPeer(name, t, term) source := newTestPeer(name, t, term)
source.accountTrie = sourceAccountTrie source.accountTrie = sourceAccountTrie.Copy()
source.accountValues = elems source.accountValues = elems
source.storageTries = storageTries source.setStorageTries(storageTries)
source.storageValues = storageElems source.storageValues = storageElems
source.storageRequestHandler = handler source.storageRequestHandler = handler
return source return source
} }
syncer := setupSyncer( syncer := setupSyncer(
nodeScheme,
mkSource("nice-a", defaultStorageRequestHandler), mkSource("nice-a", defaultStorageRequestHandler),
mkSource("nice-b", defaultStorageRequestHandler), mkSource("nice-b", defaultStorageRequestHandler),
mkSource("nice-c", defaultStorageRequestHandler), mkSource("nice-c", defaultStorageRequestHandler),
@ -1239,18 +1257,19 @@ func TestSyncWithStorageAndNonProvingPeer(t *testing.T) {
}) })
} }
) )
sourceAccountTrie, elems, storageTries, storageElems := makeAccountTrieWithStorage(100, 3000, true, false) nodeScheme, sourceAccountTrie, elems, storageTries, storageElems := makeAccountTrieWithStorage(100, 3000, true, false)
mkSource := func(name string, handler storageHandlerFunc) *testPeer { mkSource := func(name string, handler storageHandlerFunc) *testPeer {
source := newTestPeer(name, t, term) source := newTestPeer(name, t, term)
source.accountTrie = sourceAccountTrie source.accountTrie = sourceAccountTrie.Copy()
source.accountValues = elems source.accountValues = elems
source.storageTries = storageTries source.setStorageTries(storageTries)
source.storageValues = storageElems source.storageValues = storageElems
source.storageRequestHandler = handler source.storageRequestHandler = handler
return source return source
} }
syncer := setupSyncer( syncer := setupSyncer(
nodeScheme,
mkSource("nice-a", defaultStorageRequestHandler), mkSource("nice-a", defaultStorageRequestHandler),
mkSource("nice-b", defaultStorageRequestHandler), mkSource("nice-b", defaultStorageRequestHandler),
mkSource("nice-c", defaultStorageRequestHandler), mkSource("nice-c", defaultStorageRequestHandler),
@ -1279,18 +1298,18 @@ func TestSyncWithStorageMisbehavingProve(t *testing.T) {
}) })
} }
) )
sourceAccountTrie, elems, storageTries, storageElems := makeAccountTrieWithStorageWithUniqueStorage(10, 30, false) nodeScheme, sourceAccountTrie, elems, storageTries, storageElems := makeAccountTrieWithStorageWithUniqueStorage(10, 30, false)
mkSource := func(name string) *testPeer { mkSource := func(name string) *testPeer {
source := newTestPeer(name, t, term) source := newTestPeer(name, t, term)
source.accountTrie = sourceAccountTrie source.accountTrie = sourceAccountTrie.Copy()
source.accountValues = elems source.accountValues = elems
source.storageTries = storageTries source.setStorageTries(storageTries)
source.storageValues = storageElems source.storageValues = storageElems
source.storageRequestHandler = proofHappyStorageRequestHandler source.storageRequestHandler = proofHappyStorageRequestHandler
return source return source
} }
syncer := setupSyncer(mkSource("sourceA")) syncer := setupSyncer(nodeScheme, mkSource("sourceA"))
if err := syncer.Sync(sourceAccountTrie.Hash(), cancel); err != nil { if err := syncer.Sync(sourceAccountTrie.Hash(), cancel); err != nil {
t.Fatalf("sync failed: %v", err) t.Fatalf("sync failed: %v", err)
} }
@ -1347,7 +1366,7 @@ func getCodeByHash(hash common.Hash) []byte {
} }
// makeAccountTrieNoStorage spits out a trie, along with the leafs // makeAccountTrieNoStorage spits out a trie, along with the leafs
func makeAccountTrieNoStorage(n int) (*trie.Trie, entrySlice) { func makeAccountTrieNoStorage(n int) (trie.NodeScheme, *trie.Trie, entrySlice) {
var ( var (
db = trie.NewDatabase(rawdb.NewMemoryDatabase()) db = trie.NewDatabase(rawdb.NewMemoryDatabase())
accTrie = trie.NewEmpty(db) accTrie = trie.NewEmpty(db)
@ -1373,13 +1392,13 @@ func makeAccountTrieNoStorage(n int) (*trie.Trie, entrySlice) {
db.Update(trie.NewWithNodeSet(nodes)) db.Update(trie.NewWithNodeSet(nodes))
accTrie, _ = trie.New(trie.StateTrieID(root), db) accTrie, _ = trie.New(trie.StateTrieID(root), db)
return accTrie, entries return db.Scheme(), accTrie, entries
} }
// makeBoundaryAccountTrie constructs an account trie. Instead of filling // makeBoundaryAccountTrie constructs an account trie. Instead of filling
// accounts normally, this function will fill a few accounts which have // accounts normally, this function will fill a few accounts which have
// boundary hash. // boundary hash.
func makeBoundaryAccountTrie(n int) (*trie.Trie, entrySlice) { func makeBoundaryAccountTrie(n int) (trie.NodeScheme, *trie.Trie, entrySlice) {
var ( var (
entries entrySlice entries entrySlice
boundaries []common.Hash boundaries []common.Hash
@ -1435,12 +1454,12 @@ func makeBoundaryAccountTrie(n int) (*trie.Trie, entrySlice) {
db.Update(trie.NewWithNodeSet(nodes)) db.Update(trie.NewWithNodeSet(nodes))
accTrie, _ = trie.New(trie.StateTrieID(root), db) accTrie, _ = trie.New(trie.StateTrieID(root), db)
return accTrie, entries return db.Scheme(), accTrie, entries
} }
// makeAccountTrieWithStorageWithUniqueStorage creates an account trie where each accounts // makeAccountTrieWithStorageWithUniqueStorage creates an account trie where each accounts
// has a unique storage set. // has a unique storage set.
func makeAccountTrieWithStorageWithUniqueStorage(accounts, slots int, code bool) (*trie.Trie, entrySlice, map[common.Hash]*trie.Trie, map[common.Hash]entrySlice) { func makeAccountTrieWithStorageWithUniqueStorage(accounts, slots int, code bool) (trie.NodeScheme, *trie.Trie, entrySlice, map[common.Hash]*trie.Trie, map[common.Hash]entrySlice) {
var ( var (
db = trie.NewDatabase(rawdb.NewMemoryDatabase()) db = trie.NewDatabase(rawdb.NewMemoryDatabase())
accTrie = trie.NewEmpty(db) accTrie = trie.NewEmpty(db)
@ -1491,11 +1510,11 @@ func makeAccountTrieWithStorageWithUniqueStorage(accounts, slots int, code bool)
trie, _ := trie.New(id, db) trie, _ := trie.New(id, db)
storageTries[common.BytesToHash(key)] = trie storageTries[common.BytesToHash(key)] = trie
} }
return accTrie, entries, storageTries, storageEntries return db.Scheme(), accTrie, entries, storageTries, storageEntries
} }
// makeAccountTrieWithStorage spits out a trie, along with the leafs // makeAccountTrieWithStorage spits out a trie, along with the leafs
func makeAccountTrieWithStorage(accounts, slots int, code, boundary bool) (*trie.Trie, entrySlice, map[common.Hash]*trie.Trie, map[common.Hash]entrySlice) { func makeAccountTrieWithStorage(accounts, slots int, code, boundary bool) (trie.NodeScheme, *trie.Trie, entrySlice, map[common.Hash]*trie.Trie, map[common.Hash]entrySlice) {
var ( var (
db = trie.NewDatabase(rawdb.NewMemoryDatabase()) db = trie.NewDatabase(rawdb.NewMemoryDatabase())
accTrie = trie.NewEmpty(db) accTrie = trie.NewEmpty(db)
@ -1562,7 +1581,7 @@ func makeAccountTrieWithStorage(accounts, slots int, code, boundary bool) (*trie
} }
storageTries[common.BytesToHash(key)] = trie storageTries[common.BytesToHash(key)] = trie
} }
return accTrie, entries, storageTries, storageEntries return db.Scheme(), accTrie, entries, storageTries, storageEntries
} }
// makeStorageTrieWithSeed fills a storage trie with n items, returning the // makeStorageTrieWithSeed fills a storage trie with n items, returning the
@ -1641,7 +1660,7 @@ func makeBoundaryStorageTrie(owner common.Hash, n int, db *trie.Database) (commo
func verifyTrie(db ethdb.KeyValueStore, root common.Hash, t *testing.T) { func verifyTrie(db ethdb.KeyValueStore, root common.Hash, t *testing.T) {
t.Helper() t.Helper()
triedb := trie.NewDatabase(db) triedb := trie.NewDatabase(rawdb.NewDatabase(db))
accTrie, err := trie.New(trie.StateTrieID(root), triedb) accTrie, err := trie.New(trie.StateTrieID(root), triedb)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
@ -1697,16 +1716,16 @@ func TestSyncAccountPerformance(t *testing.T) {
}) })
} }
) )
sourceAccountTrie, elems := makeAccountTrieNoStorage(100) nodeScheme, sourceAccountTrie, elems := makeAccountTrieNoStorage(100)
mkSource := func(name string) *testPeer { mkSource := func(name string) *testPeer {
source := newTestPeer(name, t, term) source := newTestPeer(name, t, term)
source.accountTrie = sourceAccountTrie source.accountTrie = sourceAccountTrie.Copy()
source.accountValues = elems source.accountValues = elems
return source return source
} }
src := mkSource("source") src := mkSource("source")
syncer := setupSyncer(src) syncer := setupSyncer(nodeScheme, src)
if err := syncer.Sync(sourceAccountTrie.Hash(), cancel); err != nil { if err := syncer.Sync(sourceAccountTrie.Hash(), cancel); err != nil {
t.Fatalf("sync failed: %v", err) t.Fatalf("sync failed: %v", err)
} }

@ -48,6 +48,7 @@ import (
"github.com/ethereum/go-ethereum/params" "github.com/ethereum/go-ethereum/params"
"github.com/ethereum/go-ethereum/rlp" "github.com/ethereum/go-ethereum/rlp"
"github.com/ethereum/go-ethereum/rpc" "github.com/ethereum/go-ethereum/rpc"
"github.com/ethereum/go-ethereum/trie"
) )
type LightEthereum struct { type LightEthereum struct {
@ -99,7 +100,7 @@ func New(stack *node.Node, config *ethconfig.Config) (*LightEthereum, error) {
if config.OverrideTerminalTotalDifficultyPassed != nil { if config.OverrideTerminalTotalDifficultyPassed != nil {
overrides.OverrideTerminalTotalDifficultyPassed = config.OverrideTerminalTotalDifficultyPassed overrides.OverrideTerminalTotalDifficultyPassed = config.OverrideTerminalTotalDifficultyPassed
} }
chainConfig, genesisHash, genesisErr := core.SetupGenesisBlockWithOverride(chainDb, config.Genesis, &overrides) chainConfig, genesisHash, genesisErr := core.SetupGenesisBlockWithOverride(chainDb, trie.NewDatabase(chainDb), config.Genesis, &overrides)
if _, isCompat := genesisErr.(*params.ConfigCompatError); genesisErr != nil && !isCompat { if _, isCompat := genesisErr.(*params.ConfigCompatError); genesisErr != nil && !isCompat {
return nil, genesisErr return nil, genesisErr
} }

@ -226,7 +226,7 @@ func New(checkpoint uint64, stateDb ethdb.Database, mux *event.TypeMux, chain Bl
headerProcCh: make(chan []*types.Header, 1), headerProcCh: make(chan []*types.Header, 1),
quitCh: make(chan struct{}), quitCh: make(chan struct{}),
stateCh: make(chan dataPack), stateCh: make(chan dataPack),
SnapSyncer: snap.NewSyncer(stateDb), SnapSyncer: snap.NewSyncer(stateDb, nil),
stateSyncStart: make(chan *stateSync), stateSyncStart: make(chan *stateSync),
//syncStatsState: stateSyncStats{ //syncStatsState: stateSyncStats{
// processed: rawdb.ReadFastTrieProgress(stateDb), // processed: rawdb.ReadFastTrieProgress(stateDb),

@ -22,6 +22,7 @@ import (
"time" "time"
"github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/core/rawdb"
"github.com/ethereum/go-ethereum/core/state" "github.com/ethereum/go-ethereum/core/state"
"github.com/ethereum/go-ethereum/crypto" "github.com/ethereum/go-ethereum/crypto"
"github.com/ethereum/go-ethereum/ethdb" "github.com/ethereum/go-ethereum/ethdb"
@ -295,10 +296,13 @@ type codeTask struct {
// newStateSync creates a new state trie download scheduler. This method does not // newStateSync creates a new state trie download scheduler. This method does not
// yet start the sync. The user needs to call run to initiate. // yet start the sync. The user needs to call run to initiate.
func newStateSync(d *Downloader, root common.Hash) *stateSync { func newStateSync(d *Downloader, root common.Hash) *stateSync {
// Hack the node scheme here. It's a dead code is not used
// by light client at all. Just aim for passing tests.
scheme := trie.NewDatabase(rawdb.NewMemoryDatabase()).Scheme()
return &stateSync{ return &stateSync{
d: d, d: d,
root: root, root: root,
sched: state.NewStateSync(root, d.stateDB, nil), sched: state.NewStateSync(root, d.stateDB, nil, scheme),
keccak: sha3.NewLegacyKeccak256().(crypto.KeccakState), keccak: sha3.NewLegacyKeccak256().(crypto.KeccakState),
trieTasks: make(map[string]*trieTask), trieTasks: make(map[string]*trieTask),
codeTasks: make(map[common.Hash]*codeTask), codeTasks: make(map[common.Hash]*codeTask),

@ -31,7 +31,6 @@ import (
"github.com/ethereum/go-ethereum/core/types" "github.com/ethereum/go-ethereum/core/types"
"github.com/ethereum/go-ethereum/core/vm" "github.com/ethereum/go-ethereum/core/vm"
"github.com/ethereum/go-ethereum/eth/downloader" "github.com/ethereum/go-ethereum/eth/downloader"
"github.com/ethereum/go-ethereum/ethdb/memorydb"
"github.com/ethereum/go-ethereum/event" "github.com/ethereum/go-ethereum/event"
"github.com/ethereum/go-ethereum/trie" "github.com/ethereum/go-ethereum/trie"
) )
@ -247,10 +246,9 @@ func createMiner(t *testing.T) (*Miner, *event.TypeMux, func(skipMiner bool)) {
Etherbase: common.HexToAddress("123456789"), Etherbase: common.HexToAddress("123456789"),
} }
// Create chainConfig // Create chainConfig
memdb := memorydb.New() chainDB := rawdb.NewMemoryDatabase()
chainDB := rawdb.NewDatabase(memdb)
genesis := core.DeveloperGenesisBlock(15, 11_500_000, common.HexToAddress("12345")) genesis := core.DeveloperGenesisBlock(15, 11_500_000, common.HexToAddress("12345"))
chainConfig, _, err := core.SetupGenesisBlock(chainDB, genesis) chainConfig, _, err := core.SetupGenesisBlock(chainDB, trie.NewDatabase(chainDB), genesis)
if err != nil { if err != nil {
t.Fatalf("can't create new chain config: %v", err) t.Fatalf("can't create new chain config: %v", err)
} }

@ -107,10 +107,7 @@ func (t *BlockTest) Run(snapshotter bool) error {
// import pre accounts & construct test genesis block & state root // import pre accounts & construct test genesis block & state root
db := rawdb.NewMemoryDatabase() db := rawdb.NewMemoryDatabase()
gspec := t.genesis(config) gspec := t.genesis(config)
gblock, err := gspec.Commit(db) gblock := gspec.MustCommit(db)
if err != nil {
return err
}
if gblock.Hash() != t.json.Genesis.Hash { if gblock.Hash() != t.json.Genesis.Hash {
return fmt.Errorf("genesis block hash doesn't match test: computed=%x, test=%x", gblock.Hash().Bytes()[:6], t.json.Genesis.Hash[:6]) return fmt.Errorf("genesis block hash doesn't match test: computed=%x, test=%x", gblock.Hash().Bytes()[:6], t.json.Genesis.Hash[:6])
} }

@ -25,6 +25,9 @@ import (
"io" "io"
"sort" "sort"
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/core/rawdb"
"github.com/ethereum/go-ethereum/crypto"
"github.com/ethereum/go-ethereum/ethdb" "github.com/ethereum/go-ethereum/ethdb"
"github.com/ethereum/go-ethereum/trie" "github.com/ethereum/go-ethereum/trie"
"golang.org/x/crypto/sha3" "golang.org/x/crypto/sha3"
@ -144,10 +147,13 @@ func (f *fuzzer) fuzz() int {
// This spongeDb is used to check the sequence of disk-db-writes // This spongeDb is used to check the sequence of disk-db-writes
var ( var (
spongeA = &spongeDb{sponge: sha3.NewLegacyKeccak256()} spongeA = &spongeDb{sponge: sha3.NewLegacyKeccak256()}
dbA = trie.NewDatabase(spongeA) dbA = trie.NewDatabase(rawdb.NewDatabase(spongeA))
trieA = trie.NewEmpty(dbA) trieA = trie.NewEmpty(dbA)
spongeB = &spongeDb{sponge: sha3.NewLegacyKeccak256()} spongeB = &spongeDb{sponge: sha3.NewLegacyKeccak256()}
trieB = trie.NewStackTrie(spongeB) dbB = trie.NewDatabase(rawdb.NewDatabase(spongeB))
trieB = trie.NewStackTrie(func(owner common.Hash, path []byte, hash common.Hash, blob []byte) {
dbB.Scheme().WriteTrieNode(spongeB, owner, path, hash, blob)
})
vals kvs vals kvs
useful bool useful bool
maxElements = 10000 maxElements = 10000
@ -206,5 +212,48 @@ func (f *fuzzer) fuzz() int {
if !bytes.Equal(sumA, sumB) { if !bytes.Equal(sumA, sumB) {
panic(fmt.Sprintf("sequence differ: (trie) %x != %x (stacktrie)", sumA, sumB)) panic(fmt.Sprintf("sequence differ: (trie) %x != %x (stacktrie)", sumA, sumB))
} }
// Ensure all the nodes are persisted correctly
var (
nodeset = make(map[string][]byte) // path -> blob
trieC = trie.NewStackTrie(func(owner common.Hash, path []byte, hash common.Hash, blob []byte) {
if crypto.Keccak256Hash(blob) != hash {
panic("invalid node blob")
}
if owner != (common.Hash{}) {
panic("invalid node owner")
}
nodeset[string(path)] = common.CopyBytes(blob)
})
checked int
)
for _, kv := range vals {
trieC.Update(kv.k, kv.v)
}
rootC, _ := trieC.Commit()
if rootA != rootC {
panic(fmt.Sprintf("roots differ: (trie) %x != %x (stacktrie)", rootA, rootC))
}
trieA, _ = trie.New(trie.TrieID(rootA), dbA)
iterA := trieA.NodeIterator(nil)
for iterA.Next(true) {
if iterA.Hash() == (common.Hash{}) {
if _, present := nodeset[string(iterA.Path())]; present {
panic("unexpected tiny node")
}
continue
}
nodeBlob, present := nodeset[string(iterA.Path())]
if !present {
panic("missing node")
}
if !bytes.Equal(nodeBlob, iterA.NodeBlob()) {
panic("node blob is not matched")
}
checked += 1
}
if checked != len(nodeset) {
panic("node number is not matched")
}
return 1 return 1
} }

@ -21,7 +21,7 @@ import (
"encoding/binary" "encoding/binary"
"fmt" "fmt"
"github.com/ethereum/go-ethereum/ethdb/memorydb" "github.com/ethereum/go-ethereum/core/rawdb"
"github.com/ethereum/go-ethereum/trie" "github.com/ethereum/go-ethereum/trie"
) )
@ -139,7 +139,7 @@ func Fuzz(input []byte) int {
} }
func runRandTest(rt randTest) error { func runRandTest(rt randTest) error {
triedb := trie.NewDatabase(memorydb.New()) triedb := trie.NewDatabase(rawdb.NewMemoryDatabase())
tr := trie.NewEmpty(triedb) tr := trie.NewEmpty(triedb)
values := make(map[string]string) // tracks content of the trie values := make(map[string]string) // tracks content of the trie

@ -68,7 +68,7 @@ var (
// behind this split design is to provide read access to RPC handlers and sync // behind this split design is to provide read access to RPC handlers and sync
// servers even while the trie is executing expensive garbage collection. // servers even while the trie is executing expensive garbage collection.
type Database struct { type Database struct {
diskdb ethdb.KeyValueStore // Persistent storage for matured trie nodes diskdb ethdb.Database // Persistent storage for matured trie nodes
cleans *fastcache.Cache // GC friendly memory cache of clean node RLPs cleans *fastcache.Cache // GC friendly memory cache of clean node RLPs
dirties map[common.Hash]*cachedNode // Data and references relationships of dirty trie nodes dirties map[common.Hash]*cachedNode // Data and references relationships of dirty trie nodes
@ -273,14 +273,14 @@ type Config struct {
// NewDatabase creates a new trie database to store ephemeral trie content before // NewDatabase creates a new trie database to store ephemeral trie content before
// its written out to disk or garbage collected. No read cache is created, so all // its written out to disk or garbage collected. No read cache is created, so all
// data retrievals will hit the underlying disk database. // data retrievals will hit the underlying disk database.
func NewDatabase(diskdb ethdb.KeyValueStore) *Database { func NewDatabase(diskdb ethdb.Database) *Database {
return NewDatabaseWithConfig(diskdb, nil) return NewDatabaseWithConfig(diskdb, nil)
} }
// NewDatabaseWithConfig creates a new trie database to store ephemeral trie content // NewDatabaseWithConfig creates a new trie database to store ephemeral trie content
// before its written out to disk or garbage collected. It also acts as a read cache // before its written out to disk or garbage collected. It also acts as a read cache
// for nodes loaded from disk. // for nodes loaded from disk.
func NewDatabaseWithConfig(diskdb ethdb.KeyValueStore, config *Config) *Database { func NewDatabaseWithConfig(diskdb ethdb.Database, config *Config) *Database {
var cleans *fastcache.Cache var cleans *fastcache.Cache
if config != nil && config.Cache > 0 { if config != nil && config.Cache > 0 {
if config.Journal == "" { if config.Journal == "" {
@ -917,3 +917,8 @@ func (db *Database) CommitPreimages() error {
} }
return db.preimages.commit(true) return db.preimages.commit(true)
} }
// Scheme returns the node scheme used in the database.
func (db *Database) Scheme() NodeScheme {
return &hashScheme{}
}

@ -20,13 +20,13 @@ import (
"testing" "testing"
"github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/ethdb/memorydb" "github.com/ethereum/go-ethereum/core/rawdb"
) )
// Tests that the trie database returns a missing trie node error if attempting // Tests that the trie database returns a missing trie node error if attempting
// to retrieve the meta root. // to retrieve the meta root.
func TestDatabaseMetarootFetch(t *testing.T) { func TestDatabaseMetarootFetch(t *testing.T) {
db := NewDatabase(memorydb.New()) db := NewDatabase(rawdb.NewMemoryDatabase())
if _, err := db.Node(common.Hash{}); err == nil { if _, err := db.Node(common.Hash{}); err == nil {
t.Fatalf("metaroot retrieval succeeded") t.Fatalf("metaroot retrieval succeeded")
} }

@ -327,7 +327,7 @@ func TestIteratorContinueAfterErrorDisk(t *testing.T) { testIteratorContinueA
func TestIteratorContinueAfterErrorMemonly(t *testing.T) { testIteratorContinueAfterError(t, true) } func TestIteratorContinueAfterErrorMemonly(t *testing.T) { testIteratorContinueAfterError(t, true) }
func testIteratorContinueAfterError(t *testing.T, memonly bool) { func testIteratorContinueAfterError(t *testing.T, memonly bool) {
diskdb := memorydb.New() diskdb := rawdb.NewMemoryDatabase()
triedb := NewDatabase(diskdb) triedb := NewDatabase(diskdb)
tr := NewEmpty(triedb) tr := NewEmpty(triedb)
@ -419,7 +419,7 @@ func TestIteratorContinueAfterSeekErrorMemonly(t *testing.T) {
func testIteratorContinueAfterSeekError(t *testing.T, memonly bool) { func testIteratorContinueAfterSeekError(t *testing.T, memonly bool) {
// Commit test trie to db, then remove the node containing "bars". // Commit test trie to db, then remove the node containing "bars".
diskdb := memorydb.New() diskdb := rawdb.NewMemoryDatabase()
triedb := NewDatabase(diskdb) triedb := NewDatabase(diskdb)
ctr := NewEmpty(triedb) ctr := NewEmpty(triedb)
@ -532,7 +532,7 @@ func (l *loggingDb) Close() error {
func makeLargeTestTrie() (*Database, *StateTrie, *loggingDb) { func makeLargeTestTrie() (*Database, *StateTrie, *loggingDb) {
// Create an empty trie // Create an empty trie
logDb := &loggingDb{0, memorydb.New()} logDb := &loggingDb{0, memorydb.New()}
triedb := NewDatabase(logDb) triedb := NewDatabase(rawdb.NewDatabase(logDb))
trie, _ := NewStateTrie(TrieID(common.Hash{}), triedb) trie, _ := NewStateTrie(TrieID(common.Hash{}), triedb)
// Fill it with some arbitrary data // Fill it with some arbitrary data
@ -567,7 +567,7 @@ func TestNodeIteratorLargeTrie(t *testing.T) {
func TestIteratorNodeBlob(t *testing.T) { func TestIteratorNodeBlob(t *testing.T) {
var ( var (
db = memorydb.New() db = rawdb.NewMemoryDatabase()
triedb = NewDatabase(db) triedb = NewDatabase(db)
trie = NewEmpty(triedb) trie = NewEmpty(triedb)
) )

@ -0,0 +1,96 @@
// Copyright 2021 The go-ethereum Authors
// This file is part of the go-ethereum library.
//
// The go-ethereum library is free software: you can redistribute it and/or modify
// it under the terms of the GNU Lesser General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// The go-ethereum library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Lesser General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public License
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
package trie
import (
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/core/rawdb"
"github.com/ethereum/go-ethereum/ethdb"
)
const (
HashScheme = "hashScheme" // Identifier of hash based node scheme
// Path-based scheme will be introduced in the following PRs.
// PathScheme = "pathScheme" // Identifier of path based node scheme
)
// NodeScheme describes the scheme for interacting nodes in disk.
type NodeScheme interface {
// Name returns the identifier of node scheme.
Name() string
// HasTrieNode checks the trie node presence with the provided node info and
// the associated node hash.
HasTrieNode(db ethdb.KeyValueReader, owner common.Hash, path []byte, hash common.Hash) bool
// ReadTrieNode retrieves the trie node from database with the provided node
// info and the associated node hash.
ReadTrieNode(db ethdb.KeyValueReader, owner common.Hash, path []byte, hash common.Hash) []byte
// WriteTrieNode writes the trie node into database with the provided node
// info and associated node hash.
WriteTrieNode(db ethdb.KeyValueWriter, owner common.Hash, path []byte, hash common.Hash, node []byte)
// DeleteTrieNode deletes the trie node from database with the provided node
// info and associated node hash.
DeleteTrieNode(db ethdb.KeyValueWriter, owner common.Hash, path []byte, hash common.Hash)
// IsTrieNode returns an indicator if the given database key is the key of
// trie node according to the scheme.
IsTrieNode(key []byte) (bool, []byte)
}
type hashScheme struct{}
// Name returns the identifier of hash based scheme.
func (scheme *hashScheme) Name() string {
return HashScheme
}
// HasTrieNode checks the trie node presence with the provided node info and
// the associated node hash.
func (scheme *hashScheme) HasTrieNode(db ethdb.KeyValueReader, owner common.Hash, path []byte, hash common.Hash) bool {
return rawdb.HasTrieNode(db, hash)
}
// ReadTrieNode retrieves the trie node from database with the provided node info
// and associated node hash.
func (scheme *hashScheme) ReadTrieNode(db ethdb.KeyValueReader, owner common.Hash, path []byte, hash common.Hash) []byte {
return rawdb.ReadTrieNode(db, hash)
}
// WriteTrieNode writes the trie node into database with the provided node info
// and associated node hash.
func (scheme *hashScheme) WriteTrieNode(db ethdb.KeyValueWriter, owner common.Hash, path []byte, hash common.Hash, node []byte) {
rawdb.WriteTrieNode(db, hash, node)
}
// DeleteTrieNode deletes the trie node from database with the provided node info
// and associated node hash.
func (scheme *hashScheme) DeleteTrieNode(db ethdb.KeyValueWriter, owner common.Hash, path []byte, hash common.Hash) {
rawdb.DeleteTrieNode(db, hash)
}
// IsTrieNode returns an indicator if the given database key is the key of trie
// node according to the scheme.
func (scheme *hashScheme) IsTrieNode(key []byte) (bool, []byte) {
if len(key) == common.HashLength {
return true, key
}
return false, nil
}

@ -24,19 +24,19 @@ import (
"testing" "testing"
"github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/core/rawdb"
"github.com/ethereum/go-ethereum/crypto" "github.com/ethereum/go-ethereum/crypto"
"github.com/ethereum/go-ethereum/ethdb/memorydb"
) )
func newEmptySecure() *StateTrie { func newEmptySecure() *StateTrie {
trie, _ := NewStateTrie(TrieID(common.Hash{}), NewDatabase(memorydb.New())) trie, _ := NewStateTrie(TrieID(common.Hash{}), NewDatabase(rawdb.NewMemoryDatabase()))
return trie return trie
} }
// makeTestStateTrie creates a large enough secure trie for testing. // makeTestStateTrie creates a large enough secure trie for testing.
func makeTestStateTrie() (*Database, *StateTrie, map[string][]byte) { func makeTestStateTrie() (*Database, *StateTrie, map[string][]byte) {
// Create an empty trie // Create an empty trie
triedb := NewDatabase(memorydb.New()) triedb := NewDatabase(rawdb.NewMemoryDatabase())
trie, _ := NewStateTrie(TrieID(common.Hash{}), triedb) trie, _ := NewStateTrie(TrieID(common.Hash{}), triedb)
// Fill it with some arbitrary data // Fill it with some arbitrary data

@ -25,7 +25,6 @@ import (
"sync" "sync"
"github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/ethdb"
"github.com/ethereum/go-ethereum/log" "github.com/ethereum/go-ethereum/log"
) )
@ -37,10 +36,14 @@ var stPool = sync.Pool{
}, },
} }
func stackTrieFromPool(db ethdb.KeyValueWriter, owner common.Hash) *StackTrie { // NodeWriteFunc is used to provide all information of a dirty node for committing
// so that callers can flush nodes into database with desired scheme.
type NodeWriteFunc = func(owner common.Hash, path []byte, hash common.Hash, blob []byte)
func stackTrieFromPool(writeFn NodeWriteFunc, owner common.Hash) *StackTrie {
st := stPool.Get().(*StackTrie) st := stPool.Get().(*StackTrie)
st.db = db
st.owner = owner st.owner = owner
st.writeFn = writeFn
return st return st
} }
@ -58,36 +61,36 @@ type StackTrie struct {
val []byte // value contained by this node if it's a leaf val []byte // value contained by this node if it's a leaf
key []byte // key chunk covered by this (leaf|ext) node key []byte // key chunk covered by this (leaf|ext) node
children [16]*StackTrie // list of children (for branch and exts) children [16]*StackTrie // list of children (for branch and exts)
db ethdb.KeyValueWriter // Pointer to the commit db, can be nil writeFn NodeWriteFunc // function for committing nodes, can be nil
} }
// NewStackTrie allocates and initializes an empty trie. // NewStackTrie allocates and initializes an empty trie.
func NewStackTrie(db ethdb.KeyValueWriter) *StackTrie { func NewStackTrie(writeFn NodeWriteFunc) *StackTrie {
return &StackTrie{ return &StackTrie{
nodeType: emptyNode, nodeType: emptyNode,
db: db, writeFn: writeFn,
} }
} }
// NewStackTrieWithOwner allocates and initializes an empty trie, but with // NewStackTrieWithOwner allocates and initializes an empty trie, but with
// the additional owner field. // the additional owner field.
func NewStackTrieWithOwner(db ethdb.KeyValueWriter, owner common.Hash) *StackTrie { func NewStackTrieWithOwner(writeFn NodeWriteFunc, owner common.Hash) *StackTrie {
return &StackTrie{ return &StackTrie{
owner: owner, owner: owner,
nodeType: emptyNode, nodeType: emptyNode,
db: db, writeFn: writeFn,
} }
} }
// NewFromBinary initialises a serialized stacktrie with the given db. // NewFromBinary initialises a serialized stacktrie with the given db.
func NewFromBinary(data []byte, db ethdb.KeyValueWriter) (*StackTrie, error) { func NewFromBinary(data []byte, writeFn NodeWriteFunc) (*StackTrie, error) {
var st StackTrie var st StackTrie
if err := st.UnmarshalBinary(data); err != nil { if err := st.UnmarshalBinary(data); err != nil {
return nil, err return nil, err
} }
// If a database is used, we need to recursively add it to every child // If a database is used, we need to recursively add it to every child
if db != nil { if writeFn != nil {
st.setDb(db) st.setWriter(writeFn)
} }
return &st, nil return &st, nil
} }
@ -160,25 +163,25 @@ func (st *StackTrie) unmarshalBinary(r io.Reader) error {
return nil return nil
} }
func (st *StackTrie) setDb(db ethdb.KeyValueWriter) { func (st *StackTrie) setWriter(writeFn NodeWriteFunc) {
st.db = db st.writeFn = writeFn
for _, child := range st.children { for _, child := range st.children {
if child != nil { if child != nil {
child.setDb(db) child.setWriter(writeFn)
} }
} }
} }
func newLeaf(owner common.Hash, key, val []byte, db ethdb.KeyValueWriter) *StackTrie { func newLeaf(owner common.Hash, key, val []byte, writeFn NodeWriteFunc) *StackTrie {
st := stackTrieFromPool(db, owner) st := stackTrieFromPool(writeFn, owner)
st.nodeType = leafNode st.nodeType = leafNode
st.key = append(st.key, key...) st.key = append(st.key, key...)
st.val = val st.val = val
return st return st
} }
func newExt(owner common.Hash, key []byte, child *StackTrie, db ethdb.KeyValueWriter) *StackTrie { func newExt(owner common.Hash, key []byte, child *StackTrie, writeFn NodeWriteFunc) *StackTrie {
st := stackTrieFromPool(db, owner) st := stackTrieFromPool(writeFn, owner)
st.nodeType = extNode st.nodeType = extNode
st.key = append(st.key, key...) st.key = append(st.key, key...)
st.children[0] = child st.children[0] = child
@ -200,7 +203,7 @@ func (st *StackTrie) TryUpdate(key, value []byte) error {
if len(value) == 0 { if len(value) == 0 {
panic("deletion not supported") panic("deletion not supported")
} }
st.insert(k[:len(k)-1], value) st.insert(k[:len(k)-1], value, nil)
return nil return nil
} }
@ -212,7 +215,7 @@ func (st *StackTrie) Update(key, value []byte) {
func (st *StackTrie) Reset() { func (st *StackTrie) Reset() {
st.owner = common.Hash{} st.owner = common.Hash{}
st.db = nil st.writeFn = nil
st.key = st.key[:0] st.key = st.key[:0]
st.val = nil st.val = nil
for i := range st.children { for i := range st.children {
@ -235,7 +238,7 @@ func (st *StackTrie) getDiffIndex(key []byte) int {
// Helper function to that inserts a (key, value) pair into // Helper function to that inserts a (key, value) pair into
// the trie. // the trie.
func (st *StackTrie) insert(key, value []byte) { func (st *StackTrie) insert(key, value []byte, prefix []byte) {
switch st.nodeType { switch st.nodeType {
case branchNode: /* Branch */ case branchNode: /* Branch */
idx := int(key[0]) idx := int(key[0])
@ -244,7 +247,7 @@ func (st *StackTrie) insert(key, value []byte) {
for i := idx - 1; i >= 0; i-- { for i := idx - 1; i >= 0; i-- {
if st.children[i] != nil { if st.children[i] != nil {
if st.children[i].nodeType != hashedNode { if st.children[i].nodeType != hashedNode {
st.children[i].hash() st.children[i].hash(append(prefix, byte(i)))
} }
break break
} }
@ -252,9 +255,9 @@ func (st *StackTrie) insert(key, value []byte) {
// Add new child // Add new child
if st.children[idx] == nil { if st.children[idx] == nil {
st.children[idx] = newLeaf(st.owner, key[1:], value, st.db) st.children[idx] = newLeaf(st.owner, key[1:], value, st.writeFn)
} else { } else {
st.children[idx].insert(key[1:], value) st.children[idx].insert(key[1:], value, append(prefix, key[0]))
} }
case extNode: /* Ext */ case extNode: /* Ext */
@ -269,7 +272,7 @@ func (st *StackTrie) insert(key, value []byte) {
if diffidx == len(st.key) { if diffidx == len(st.key) {
// Ext key and key segment are identical, recurse into // Ext key and key segment are identical, recurse into
// the child node. // the child node.
st.children[0].insert(key[diffidx:], value) st.children[0].insert(key[diffidx:], value, append(prefix, key[:diffidx]...))
return return
} }
// Save the original part. Depending if the break is // Save the original part. Depending if the break is
@ -278,14 +281,19 @@ func (st *StackTrie) insert(key, value []byte) {
// node directly. // node directly.
var n *StackTrie var n *StackTrie
if diffidx < len(st.key)-1 { if diffidx < len(st.key)-1 {
n = newExt(st.owner, st.key[diffidx+1:], st.children[0], st.db) // Break on the non-last byte, insert an intermediate
// extension. The path prefix of the newly-inserted
// extension should also contain the different byte.
n = newExt(st.owner, st.key[diffidx+1:], st.children[0], st.writeFn)
n.hash(append(prefix, st.key[:diffidx+1]...))
} else { } else {
// Break on the last byte, no need to insert // Break on the last byte, no need to insert
// an extension node: reuse the current node // an extension node: reuse the current node.
// The path prefix of the original part should
// still be same.
n = st.children[0] n = st.children[0]
n.hash(append(prefix, st.key...))
} }
// Convert to hash
n.hash()
var p *StackTrie var p *StackTrie
if diffidx == 0 { if diffidx == 0 {
// the break is on the first byte, so // the break is on the first byte, so
@ -298,12 +306,12 @@ func (st *StackTrie) insert(key, value []byte) {
// the common prefix is at least one byte // the common prefix is at least one byte
// long, insert a new intermediate branch // long, insert a new intermediate branch
// node. // node.
st.children[0] = stackTrieFromPool(st.db, st.owner) st.children[0] = stackTrieFromPool(st.writeFn, st.owner)
st.children[0].nodeType = branchNode st.children[0].nodeType = branchNode
p = st.children[0] p = st.children[0]
} }
// Create a leaf for the inserted part // Create a leaf for the inserted part
o := newLeaf(st.owner, key[diffidx+1:], value, st.db) o := newLeaf(st.owner, key[diffidx+1:], value, st.writeFn)
// Insert both child leaves where they belong: // Insert both child leaves where they belong:
origIdx := st.key[diffidx] origIdx := st.key[diffidx]
@ -339,7 +347,7 @@ func (st *StackTrie) insert(key, value []byte) {
// Convert current node into an ext, // Convert current node into an ext,
// and insert a child branch node. // and insert a child branch node.
st.nodeType = extNode st.nodeType = extNode
st.children[0] = NewStackTrieWithOwner(st.db, st.owner) st.children[0] = NewStackTrieWithOwner(st.writeFn, st.owner)
st.children[0].nodeType = branchNode st.children[0].nodeType = branchNode
p = st.children[0] p = st.children[0]
} }
@ -348,11 +356,11 @@ func (st *StackTrie) insert(key, value []byte) {
// value and another containing the new value. The child leaf // value and another containing the new value. The child leaf
// is hashed directly in order to free up some memory. // is hashed directly in order to free up some memory.
origIdx := st.key[diffidx] origIdx := st.key[diffidx]
p.children[origIdx] = newLeaf(st.owner, st.key[diffidx+1:], st.val, st.db) p.children[origIdx] = newLeaf(st.owner, st.key[diffidx+1:], st.val, st.writeFn)
p.children[origIdx].hash() p.children[origIdx].hash(append(prefix, st.key[:diffidx+1]...))
newIdx := key[diffidx] newIdx := key[diffidx]
p.children[newIdx] = newLeaf(st.owner, key[diffidx+1:], value, st.db) p.children[newIdx] = newLeaf(st.owner, key[diffidx+1:], value, st.writeFn)
// Finally, cut off the key part that has been passed // Finally, cut off the key part that has been passed
// over to the children. // over to the children.
@ -383,14 +391,14 @@ func (st *StackTrie) insert(key, value []byte) {
// - And the 'st.type' will be 'hashedNode' AGAIN // - And the 'st.type' will be 'hashedNode' AGAIN
// //
// This method also sets 'st.type' to hashedNode, and clears 'st.key'. // This method also sets 'st.type' to hashedNode, and clears 'st.key'.
func (st *StackTrie) hash() { func (st *StackTrie) hash(path []byte) {
h := newHasher(false) h := newHasher(false)
defer returnHasherToPool(h) defer returnHasherToPool(h)
st.hashRec(h) st.hashRec(h, path)
} }
func (st *StackTrie) hashRec(hasher *hasher) { func (st *StackTrie) hashRec(hasher *hasher, path []byte) {
// The switch below sets this to the RLP-encoding of this node. // The switch below sets this to the RLP-encoding of this node.
var encodedNode []byte var encodedNode []byte
@ -411,8 +419,7 @@ func (st *StackTrie) hashRec(hasher *hasher) {
nodes[i] = nilValueNode nodes[i] = nilValueNode
continue continue
} }
child.hashRec(hasher, append(path, byte(i)))
child.hashRec(hasher)
if len(child.val) < 32 { if len(child.val) < 32 {
nodes[i] = rawNode(child.val) nodes[i] = rawNode(child.val)
} else { } else {
@ -428,10 +435,9 @@ func (st *StackTrie) hashRec(hasher *hasher) {
encodedNode = hasher.encodedBytes() encodedNode = hasher.encodedBytes()
case extNode: case extNode:
st.children[0].hashRec(hasher) st.children[0].hashRec(hasher, append(path, st.key...))
sz := hexToCompactInPlace(st.key) n := rawShortNode{Key: hexToCompact(st.key)}
n := rawShortNode{Key: st.key[:sz]}
if len(st.children[0].val) < 32 { if len(st.children[0].val) < 32 {
n.Val = rawNode(st.children[0].val) n.Val = rawNode(st.children[0].val)
} else { } else {
@ -447,8 +453,7 @@ func (st *StackTrie) hashRec(hasher *hasher) {
case leafNode: case leafNode:
st.key = append(st.key, byte(16)) st.key = append(st.key, byte(16))
sz := hexToCompactInPlace(st.key) n := rawShortNode{Key: hexToCompact(st.key), Val: valueNode(st.val)}
n := rawShortNode{Key: st.key[:sz], Val: valueNode(st.val)}
n.encode(hasher.encbuf) n.encode(hasher.encbuf)
encodedNode = hasher.encodedBytes() encodedNode = hasher.encodedBytes()
@ -467,10 +472,8 @@ func (st *StackTrie) hashRec(hasher *hasher) {
// Write the hash to the 'val'. We allocate a new val here to not mutate // Write the hash to the 'val'. We allocate a new val here to not mutate
// input values // input values
st.val = hasher.hashData(encodedNode) st.val = hasher.hashData(encodedNode)
if st.db != nil { if st.writeFn != nil {
// TODO! Is it safe to Put the slice here? st.writeFn(st.owner, path, common.BytesToHash(st.val), encodedNode)
// Do all db implementations copy the value provided?
st.db.Put(st.val, encodedNode)
} }
} }
@ -479,12 +482,11 @@ func (st *StackTrie) Hash() (h common.Hash) {
hasher := newHasher(false) hasher := newHasher(false)
defer returnHasherToPool(hasher) defer returnHasherToPool(hasher)
st.hashRec(hasher) st.hashRec(hasher, nil)
if len(st.val) == 32 { if len(st.val) == 32 {
copy(h[:], st.val) copy(h[:], st.val)
return h return h
} }
// If the node's RLP isn't 32 bytes long, the node will not // If the node's RLP isn't 32 bytes long, the node will not
// be hashed, and instead contain the rlp-encoding of the // be hashed, and instead contain the rlp-encoding of the
// node. For the top level node, we need to force the hashing. // node. For the top level node, we need to force the hashing.
@ -502,25 +504,24 @@ func (st *StackTrie) Hash() (h common.Hash) {
// The associated database is expected, otherwise the whole commit // The associated database is expected, otherwise the whole commit
// functionality should be disabled. // functionality should be disabled.
func (st *StackTrie) Commit() (h common.Hash, err error) { func (st *StackTrie) Commit() (h common.Hash, err error) {
if st.db == nil { if st.writeFn == nil {
return common.Hash{}, ErrCommitDisabled return common.Hash{}, ErrCommitDisabled
} }
hasher := newHasher(false) hasher := newHasher(false)
defer returnHasherToPool(hasher) defer returnHasherToPool(hasher)
st.hashRec(hasher) st.hashRec(hasher, nil)
if len(st.val) == 32 { if len(st.val) == 32 {
copy(h[:], st.val) copy(h[:], st.val)
return h, nil return h, nil
} }
// If the node's RLP isn't 32 bytes long, the node will not // If the node's RLP isn't 32 bytes long, the node will not
// be hashed (and committed), and instead contain the rlp-encoding of the // be hashed (and committed), and instead contain the rlp-encoding of the
// node. For the top level node, we need to force the hashing+commit. // node. For the top level node, we need to force the hashing+commit.
hasher.sha.Reset() hasher.sha.Reset()
hasher.sha.Write(st.val) hasher.sha.Write(st.val)
hasher.sha.Read(h[:]) hasher.sha.Read(h[:])
st.db.Put(h[:], st.val)
st.writeFn(st.owner, nil, h, st.val)
return h, nil return h, nil
} }

@ -22,8 +22,8 @@ import (
"testing" "testing"
"github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/core/rawdb"
"github.com/ethereum/go-ethereum/crypto" "github.com/ethereum/go-ethereum/crypto"
"github.com/ethereum/go-ethereum/ethdb/memorydb"
) )
func TestStackTrieInsertAndHash(t *testing.T) { func TestStackTrieInsertAndHash(t *testing.T) {
@ -188,7 +188,7 @@ func TestStackTrieInsertAndHash(t *testing.T) {
func TestSizeBug(t *testing.T) { func TestSizeBug(t *testing.T) {
st := NewStackTrie(nil) st := NewStackTrie(nil)
nt := NewEmpty(NewDatabase(memorydb.New())) nt := NewEmpty(NewDatabase(rawdb.NewMemoryDatabase()))
leaf := common.FromHex("290decd9548b62a8d60345a988386fc84ba6bc95484008f6362f93160ef3e563") leaf := common.FromHex("290decd9548b62a8d60345a988386fc84ba6bc95484008f6362f93160ef3e563")
value := common.FromHex("94cf40d0d2b44f2b66e07cace1372ca42b73cf21a3") value := common.FromHex("94cf40d0d2b44f2b66e07cace1372ca42b73cf21a3")
@ -203,7 +203,7 @@ func TestSizeBug(t *testing.T) {
func TestEmptyBug(t *testing.T) { func TestEmptyBug(t *testing.T) {
st := NewStackTrie(nil) st := NewStackTrie(nil)
nt := NewEmpty(NewDatabase(memorydb.New())) nt := NewEmpty(NewDatabase(rawdb.NewMemoryDatabase()))
//leaf := common.FromHex("290decd9548b62a8d60345a988386fc84ba6bc95484008f6362f93160ef3e563") //leaf := common.FromHex("290decd9548b62a8d60345a988386fc84ba6bc95484008f6362f93160ef3e563")
//value := common.FromHex("94cf40d0d2b44f2b66e07cace1372ca42b73cf21a3") //value := common.FromHex("94cf40d0d2b44f2b66e07cace1372ca42b73cf21a3")
@ -229,7 +229,7 @@ func TestEmptyBug(t *testing.T) {
func TestValLength56(t *testing.T) { func TestValLength56(t *testing.T) {
st := NewStackTrie(nil) st := NewStackTrie(nil)
nt := NewEmpty(NewDatabase(memorydb.New())) nt := NewEmpty(NewDatabase(rawdb.NewMemoryDatabase()))
//leaf := common.FromHex("290decd9548b62a8d60345a988386fc84ba6bc95484008f6362f93160ef3e563") //leaf := common.FromHex("290decd9548b62a8d60345a988386fc84ba6bc95484008f6362f93160ef3e563")
//value := common.FromHex("94cf40d0d2b44f2b66e07cace1372ca42b73cf21a3") //value := common.FromHex("94cf40d0d2b44f2b66e07cace1372ca42b73cf21a3")
@ -254,7 +254,7 @@ func TestValLength56(t *testing.T) {
// which causes a lot of node-within-node. This case was found via fuzzing. // which causes a lot of node-within-node. This case was found via fuzzing.
func TestUpdateSmallNodes(t *testing.T) { func TestUpdateSmallNodes(t *testing.T) {
st := NewStackTrie(nil) st := NewStackTrie(nil)
nt := NewEmpty(NewDatabase(memorydb.New())) nt := NewEmpty(NewDatabase(rawdb.NewMemoryDatabase()))
kvs := []struct { kvs := []struct {
K string K string
@ -283,7 +283,7 @@ func TestUpdateSmallNodes(t *testing.T) {
func TestUpdateVariableKeys(t *testing.T) { func TestUpdateVariableKeys(t *testing.T) {
t.SkipNow() t.SkipNow()
st := NewStackTrie(nil) st := NewStackTrie(nil)
nt := NewEmpty(NewDatabase(memorydb.New())) nt := NewEmpty(NewDatabase(rawdb.NewMemoryDatabase()))
kvs := []struct { kvs := []struct {
K string K string
@ -353,7 +353,7 @@ func TestStacktrieNotModifyValues(t *testing.T) {
func TestStacktrieSerialization(t *testing.T) { func TestStacktrieSerialization(t *testing.T) {
var ( var (
st = NewStackTrie(nil) st = NewStackTrie(nil)
nt = NewEmpty(NewDatabase(memorydb.New())) nt = NewEmpty(NewDatabase(rawdb.NewMemoryDatabase()))
keyB = big.NewInt(1) keyB = big.NewInt(1)
keyDelta = big.NewInt(1) keyDelta = big.NewInt(1)
vals [][]byte vals [][]byte

@ -64,7 +64,7 @@ type SyncPath [][]byte
// version that can be sent over the network. // version that can be sent over the network.
func NewSyncPath(path []byte) SyncPath { func NewSyncPath(path []byte) SyncPath {
// If the hash is from the account trie, append a single item, if it // If the hash is from the account trie, append a single item, if it
// is from the a storage trie, append a tuple. Note, the length 64 is // is from a storage trie, append a tuple. Note, the length 64 is
// clashing between account leaf and storage root. It's fine though // clashing between account leaf and storage root. It's fine though
// because having a trie node at 64 depth means a hash collision was // because having a trie node at 64 depth means a hash collision was
// found and we're long dead. // found and we're long dead.
@ -74,6 +74,22 @@ func NewSyncPath(path []byte) SyncPath {
return SyncPath{hexToKeybytes(path[:64]), hexToCompact(path[64:])} return SyncPath{hexToKeybytes(path[:64]), hexToCompact(path[64:])}
} }
// LeafCallback is a callback type invoked when a trie operation reaches a leaf
// node.
//
// The keys is a path tuple identifying a particular trie node either in a single
// trie (account) or a layered trie (account -> storage). Each key in the tuple
// is in the raw format(32 bytes).
//
// The path is a composite hexary path identifying the trie node. All the key
// bytes are converted to the hexary nibbles and composited with the parent path
// if the trie node is in a layered trie.
//
// It's used by state sync and commit to allow handling external references
// between account and storage tries. And also it's used in the state healing
// for extracting the raw states(leaf nodes) with corresponding paths.
type LeafCallback func(keys [][]byte, path []byte, leaf []byte, parent common.Hash, parentPath []byte) error
// nodeRequest represents a scheduled or already in-flight trie node retrieval request. // nodeRequest represents a scheduled or already in-flight trie node retrieval request.
type nodeRequest struct { type nodeRequest struct {
hash common.Hash // Hash of the trie node to retrieve hash common.Hash // Hash of the trie node to retrieve
@ -139,6 +155,7 @@ func (batch *syncMemBatch) hasCode(hash common.Hash) bool {
// unknown trie hashes to retrieve, accepts node data associated with said hashes // unknown trie hashes to retrieve, accepts node data associated with said hashes
// and reconstructs the trie step by step until all is done. // and reconstructs the trie step by step until all is done.
type Sync struct { type Sync struct {
scheme NodeScheme // Node scheme descriptor used in database.
database ethdb.KeyValueReader // Persistent database to check for existing entries database ethdb.KeyValueReader // Persistent database to check for existing entries
membatch *syncMemBatch // Memory buffer to avoid frequent database writes membatch *syncMemBatch // Memory buffer to avoid frequent database writes
nodeReqs map[string]*nodeRequest // Pending requests pertaining to a trie node path nodeReqs map[string]*nodeRequest // Pending requests pertaining to a trie node path
@ -148,8 +165,9 @@ type Sync struct {
} }
// NewSync creates a new trie data download scheduler. // NewSync creates a new trie data download scheduler.
func NewSync(root common.Hash, database ethdb.KeyValueReader, callback LeafCallback) *Sync { func NewSync(root common.Hash, database ethdb.KeyValueReader, callback LeafCallback, scheme NodeScheme) *Sync {
ts := &Sync{ ts := &Sync{
scheme: scheme,
database: database, database: database,
membatch: newSyncMemBatch(), membatch: newSyncMemBatch(),
nodeReqs: make(map[string]*nodeRequest), nodeReqs: make(map[string]*nodeRequest),
@ -172,7 +190,8 @@ func (s *Sync) AddSubTrie(root common.Hash, path []byte, parent common.Hash, par
if s.membatch.hasNode(path) { if s.membatch.hasNode(path) {
return return
} }
if rawdb.HasTrieNode(s.database, root) { owner, inner := ResolvePath(path)
if s.scheme.HasTrieNode(s.database, owner, inner, root) {
return return
} }
// Assemble the new sub-trie sync request // Assemble the new sub-trie sync request
@ -205,7 +224,7 @@ func (s *Sync) AddCodeEntry(hash common.Hash, path []byte, parent common.Hash, p
return return
} }
// If database says duplicate, the blob is present for sure. // If database says duplicate, the blob is present for sure.
// Note we only check the existence with new code scheme, fast // Note we only check the existence with new code scheme, snap
// sync is expected to run with a fresh new node. Even there // sync is expected to run with a fresh new node. Even there
// exists the code with legacy format, fetch and store with // exists the code with legacy format, fetch and store with
// new scheme anyway. // new scheme anyway.
@ -329,7 +348,8 @@ func (s *Sync) ProcessNode(result NodeSyncResult) error {
func (s *Sync) Commit(dbw ethdb.Batch) error { func (s *Sync) Commit(dbw ethdb.Batch) error {
// Dump the membatch into a database dbw // Dump the membatch into a database dbw
for path, value := range s.membatch.nodes { for path, value := range s.membatch.nodes {
rawdb.WriteTrieNode(dbw, s.membatch.hashes[path], value) owner, inner := ResolvePath([]byte(path))
s.scheme.WriteTrieNode(dbw, owner, inner, s.membatch.hashes[path], value)
} }
for hash, value := range s.membatch.codes { for hash, value := range s.membatch.codes {
rawdb.WriteCode(dbw, hash, value) rawdb.WriteCode(dbw, hash, value)
@ -450,8 +470,11 @@ func (s *Sync) children(req *nodeRequest, object node) ([]*nodeRequest, error) {
// If database says duplicate, then at least the trie node is present // If database says duplicate, then at least the trie node is present
// and we hold the assumption that it's NOT legacy contract code. // and we hold the assumption that it's NOT legacy contract code.
chash := common.BytesToHash(node) var (
if rawdb.HasTrieNode(s.database, chash) { chash = common.BytesToHash(node)
owner, inner = ResolvePath(child.path)
)
if s.scheme.HasTrieNode(s.database, owner, inner, chash) {
return return
} }
// Locally unknown node, schedule for retrieval // Locally unknown node, schedule for retrieval
@ -525,3 +548,14 @@ func (s *Sync) commitCodeRequest(req *codeRequest) error {
} }
return nil return nil
} }
// ResolvePath resolves the provided composite node path by separating the
// path in account trie if it's existent.
func ResolvePath(path []byte) (common.Hash, []byte) {
var owner common.Hash
if len(path) >= 2*common.HashLength {
owner = common.BytesToHash(hexToKeybytes(path[:2*common.HashLength]))
path = path[2*common.HashLength:]
}
return owner, path
}

@ -22,6 +22,7 @@ import (
"testing" "testing"
"github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/core/rawdb"
"github.com/ethereum/go-ethereum/crypto" "github.com/ethereum/go-ethereum/crypto"
"github.com/ethereum/go-ethereum/ethdb/memorydb" "github.com/ethereum/go-ethereum/ethdb/memorydb"
) )
@ -29,7 +30,7 @@ import (
// makeTestTrie create a sample test trie to test node-wise reconstruction. // makeTestTrie create a sample test trie to test node-wise reconstruction.
func makeTestTrie() (*Database, *StateTrie, map[string][]byte) { func makeTestTrie() (*Database, *StateTrie, map[string][]byte) {
// Create an empty trie // Create an empty trie
triedb := NewDatabase(memorydb.New()) triedb := NewDatabase(rawdb.NewMemoryDatabase())
trie, _ := NewStateTrie(TrieID(common.Hash{}), triedb) trie, _ := NewStateTrie(TrieID(common.Hash{}), triedb)
// Fill it with some arbitrary data // Fill it with some arbitrary data
@ -103,13 +104,13 @@ type trieElement struct {
// Tests that an empty trie is not scheduled for syncing. // Tests that an empty trie is not scheduled for syncing.
func TestEmptySync(t *testing.T) { func TestEmptySync(t *testing.T) {
dbA := NewDatabase(memorydb.New()) dbA := NewDatabase(rawdb.NewMemoryDatabase())
dbB := NewDatabase(memorydb.New()) dbB := NewDatabase(rawdb.NewMemoryDatabase())
emptyA, _ := New(TrieID(common.Hash{}), dbA) emptyA, _ := New(TrieID(common.Hash{}), dbA)
emptyB, _ := New(TrieID(emptyRoot), dbB) emptyB, _ := New(TrieID(emptyRoot), dbB)
for i, trie := range []*Trie{emptyA, emptyB} { for i, trie := range []*Trie{emptyA, emptyB} {
sync := NewSync(trie.Hash(), memorydb.New(), nil) sync := NewSync(trie.Hash(), memorydb.New(), nil, []*Database{dbA, dbB}[i].Scheme())
if paths, nodes, codes := sync.Missing(1); len(paths) != 0 || len(nodes) != 0 || len(codes) != 0 { if paths, nodes, codes := sync.Missing(1); len(paths) != 0 || len(nodes) != 0 || len(codes) != 0 {
t.Errorf("test %d: content requested for empty trie: %v, %v, %v", i, paths, nodes, codes) t.Errorf("test %d: content requested for empty trie: %v, %v, %v", i, paths, nodes, codes)
} }
@ -128,9 +129,9 @@ func testIterativeSync(t *testing.T, count int, bypath bool) {
srcDb, srcTrie, srcData := makeTestTrie() srcDb, srcTrie, srcData := makeTestTrie()
// Create a destination trie and sync with the scheduler // Create a destination trie and sync with the scheduler
diskdb := memorydb.New() diskdb := rawdb.NewMemoryDatabase()
triedb := NewDatabase(diskdb) triedb := NewDatabase(diskdb)
sched := NewSync(srcTrie.Hash(), diskdb, nil) sched := NewSync(srcTrie.Hash(), diskdb, nil, srcDb.Scheme())
// The code requests are ignored here since there is no code // The code requests are ignored here since there is no code
// at the testing trie. // at the testing trie.
@ -194,9 +195,9 @@ func TestIterativeDelayedSync(t *testing.T) {
srcDb, srcTrie, srcData := makeTestTrie() srcDb, srcTrie, srcData := makeTestTrie()
// Create a destination trie and sync with the scheduler // Create a destination trie and sync with the scheduler
diskdb := memorydb.New() diskdb := rawdb.NewMemoryDatabase()
triedb := NewDatabase(diskdb) triedb := NewDatabase(diskdb)
sched := NewSync(srcTrie.Hash(), diskdb, nil) sched := NewSync(srcTrie.Hash(), diskdb, nil, srcDb.Scheme())
// The code requests are ignored here since there is no code // The code requests are ignored here since there is no code
// at the testing trie. // at the testing trie.
@ -255,9 +256,9 @@ func testIterativeRandomSync(t *testing.T, count int) {
srcDb, srcTrie, srcData := makeTestTrie() srcDb, srcTrie, srcData := makeTestTrie()
// Create a destination trie and sync with the scheduler // Create a destination trie and sync with the scheduler
diskdb := memorydb.New() diskdb := rawdb.NewMemoryDatabase()
triedb := NewDatabase(diskdb) triedb := NewDatabase(diskdb)
sched := NewSync(srcTrie.Hash(), diskdb, nil) sched := NewSync(srcTrie.Hash(), diskdb, nil, srcDb.Scheme())
// The code requests are ignored here since there is no code // The code requests are ignored here since there is no code
// at the testing trie. // at the testing trie.
@ -313,9 +314,9 @@ func TestIterativeRandomDelayedSync(t *testing.T) {
srcDb, srcTrie, srcData := makeTestTrie() srcDb, srcTrie, srcData := makeTestTrie()
// Create a destination trie and sync with the scheduler // Create a destination trie and sync with the scheduler
diskdb := memorydb.New() diskdb := rawdb.NewMemoryDatabase()
triedb := NewDatabase(diskdb) triedb := NewDatabase(diskdb)
sched := NewSync(srcTrie.Hash(), diskdb, nil) sched := NewSync(srcTrie.Hash(), diskdb, nil, srcDb.Scheme())
// The code requests are ignored here since there is no code // The code requests are ignored here since there is no code
// at the testing trie. // at the testing trie.
@ -376,9 +377,9 @@ func TestDuplicateAvoidanceSync(t *testing.T) {
srcDb, srcTrie, srcData := makeTestTrie() srcDb, srcTrie, srcData := makeTestTrie()
// Create a destination trie and sync with the scheduler // Create a destination trie and sync with the scheduler
diskdb := memorydb.New() diskdb := rawdb.NewMemoryDatabase()
triedb := NewDatabase(diskdb) triedb := NewDatabase(diskdb)
sched := NewSync(srcTrie.Hash(), diskdb, nil) sched := NewSync(srcTrie.Hash(), diskdb, nil, srcDb.Scheme())
// The code requests are ignored here since there is no code // The code requests are ignored here since there is no code
// at the testing trie. // at the testing trie.
@ -439,9 +440,9 @@ func TestIncompleteSync(t *testing.T) {
srcDb, srcTrie, _ := makeTestTrie() srcDb, srcTrie, _ := makeTestTrie()
// Create a destination trie and sync with the scheduler // Create a destination trie and sync with the scheduler
diskdb := memorydb.New() diskdb := rawdb.NewMemoryDatabase()
triedb := NewDatabase(diskdb) triedb := NewDatabase(diskdb)
sched := NewSync(srcTrie.Hash(), diskdb, nil) sched := NewSync(srcTrie.Hash(), diskdb, nil, srcDb.Scheme())
// The code requests are ignored here since there is no code // The code requests are ignored here since there is no code
// at the testing trie. // at the testing trie.
@ -519,9 +520,9 @@ func TestSyncOrdering(t *testing.T) {
srcDb, srcTrie, srcData := makeTestTrie() srcDb, srcTrie, srcData := makeTestTrie()
// Create a destination trie and sync with the scheduler, tracking the requests // Create a destination trie and sync with the scheduler, tracking the requests
diskdb := memorydb.New() diskdb := rawdb.NewMemoryDatabase()
triedb := NewDatabase(diskdb) triedb := NewDatabase(diskdb)
sched := NewSync(srcTrie.Hash(), diskdb, nil) sched := NewSync(srcTrie.Hash(), diskdb, nil, srcDb.Scheme())
// The code requests are ignored here since there is no code // The code requests are ignored here since there is no code
// at the testing trie. // at the testing trie.

@ -35,22 +35,6 @@ var (
emptyState = crypto.Keccak256Hash(nil) emptyState = crypto.Keccak256Hash(nil)
) )
// LeafCallback is a callback type invoked when a trie operation reaches a leaf
// node.
//
// The keys is a path tuple identifying a particular trie node either in a single
// trie (account) or a layered trie (account -> storage). Each key in the tuple
// is in the raw format(32 bytes).
//
// The path is a composite hexary path identifying the trie node. All the key
// bytes are converted to the hexary nibbles and composited with the parent path
// if the trie node is in a layered trie.
//
// It's used by state sync and commit to allow handling external references
// between account and storage tries. And also it's used in the state healing
// for extracting the raw states(leaf nodes) with corresponding paths.
type LeafCallback func(keys [][]byte, path []byte, leaf []byte, parent common.Hash, parentPath []byte) error
// Trie is a Merkle Patricia Trie. Use New to create a trie that sits on // Trie is a Merkle Patricia Trie. Use New to create a trie that sits on
// top of a database. Whenever trie performs a commit operation, the generated // top of a database. Whenever trie performs a commit operation, the generated
// nodes will be gathered and returned in a set. Once the trie is committed, // nodes will be gathered and returned in a set. Once the trie is committed,

@ -34,7 +34,6 @@ import (
"github.com/ethereum/go-ethereum/core/types" "github.com/ethereum/go-ethereum/core/types"
"github.com/ethereum/go-ethereum/crypto" "github.com/ethereum/go-ethereum/crypto"
"github.com/ethereum/go-ethereum/ethdb" "github.com/ethereum/go-ethereum/ethdb"
"github.com/ethereum/go-ethereum/ethdb/memorydb"
"github.com/ethereum/go-ethereum/rlp" "github.com/ethereum/go-ethereum/rlp"
"golang.org/x/crypto/sha3" "golang.org/x/crypto/sha3"
) )
@ -65,7 +64,7 @@ func TestNull(t *testing.T) {
func TestMissingRoot(t *testing.T) { func TestMissingRoot(t *testing.T) {
root := common.HexToHash("0beec7b5ea3f0fdbc95d0dd47f3c5bc275da8a33") root := common.HexToHash("0beec7b5ea3f0fdbc95d0dd47f3c5bc275da8a33")
trie, err := New(TrieID(root), NewDatabase(memorydb.New())) trie, err := New(TrieID(root), NewDatabase(rawdb.NewMemoryDatabase()))
if trie != nil { if trie != nil {
t.Error("New returned non-nil trie for invalid root") t.Error("New returned non-nil trie for invalid root")
} }
@ -78,7 +77,7 @@ func TestMissingNodeDisk(t *testing.T) { testMissingNode(t, false) }
func TestMissingNodeMemonly(t *testing.T) { testMissingNode(t, true) } func TestMissingNodeMemonly(t *testing.T) { testMissingNode(t, true) }
func testMissingNode(t *testing.T, memonly bool) { func testMissingNode(t *testing.T, memonly bool) {
diskdb := memorydb.New() diskdb := rawdb.NewMemoryDatabase()
triedb := NewDatabase(diskdb) triedb := NewDatabase(diskdb)
trie := NewEmpty(triedb) trie := NewEmpty(triedb)
@ -414,7 +413,7 @@ func (randTest) Generate(r *rand.Rand, size int) reflect.Value {
func runRandTest(rt randTest) bool { func runRandTest(rt randTest) bool {
var ( var (
triedb = NewDatabase(memorydb.New()) triedb = NewDatabase(rawdb.NewMemoryDatabase())
tr = NewEmpty(triedb) tr = NewEmpty(triedb)
values = make(map[string]string) // tracks content of the trie values = make(map[string]string) // tracks content of the trie
origTrie = NewEmpty(triedb) origTrie = NewEmpty(triedb)
@ -811,7 +810,7 @@ func TestCommitSequence(t *testing.T) {
addresses, accounts := makeAccounts(tc.count) addresses, accounts := makeAccounts(tc.count)
// This spongeDb is used to check the sequence of disk-db-writes // This spongeDb is used to check the sequence of disk-db-writes
s := &spongeDb{sponge: sha3.NewLegacyKeccak256()} s := &spongeDb{sponge: sha3.NewLegacyKeccak256()}
db := NewDatabase(s) db := NewDatabase(rawdb.NewDatabase(s))
trie := NewEmpty(db) trie := NewEmpty(db)
// Another sponge is used to check the callback-sequence // Another sponge is used to check the callback-sequence
callbackSponge := sha3.NewLegacyKeccak256() callbackSponge := sha3.NewLegacyKeccak256()
@ -854,7 +853,7 @@ func TestCommitSequenceRandomBlobs(t *testing.T) {
prng := rand.New(rand.NewSource(int64(i))) prng := rand.New(rand.NewSource(int64(i)))
// This spongeDb is used to check the sequence of disk-db-writes // This spongeDb is used to check the sequence of disk-db-writes
s := &spongeDb{sponge: sha3.NewLegacyKeccak256()} s := &spongeDb{sponge: sha3.NewLegacyKeccak256()}
db := NewDatabase(s) db := NewDatabase(rawdb.NewDatabase(s))
trie := NewEmpty(db) trie := NewEmpty(db)
// Another sponge is used to check the callback-sequence // Another sponge is used to check the callback-sequence
callbackSponge := sha3.NewLegacyKeccak256() callbackSponge := sha3.NewLegacyKeccak256()
@ -894,11 +893,13 @@ func TestCommitSequenceStackTrie(t *testing.T) {
prng := rand.New(rand.NewSource(int64(count))) prng := rand.New(rand.NewSource(int64(count)))
// This spongeDb is used to check the sequence of disk-db-writes // This spongeDb is used to check the sequence of disk-db-writes
s := &spongeDb{sponge: sha3.NewLegacyKeccak256(), id: "a"} s := &spongeDb{sponge: sha3.NewLegacyKeccak256(), id: "a"}
db := NewDatabase(s) db := NewDatabase(rawdb.NewDatabase(s))
trie := NewEmpty(db) trie := NewEmpty(db)
// Another sponge is used for the stacktrie commits // Another sponge is used for the stacktrie commits
stackTrieSponge := &spongeDb{sponge: sha3.NewLegacyKeccak256(), id: "b"} stackTrieSponge := &spongeDb{sponge: sha3.NewLegacyKeccak256(), id: "b"}
stTrie := NewStackTrie(stackTrieSponge) stTrie := NewStackTrie(func(owner common.Hash, path []byte, hash common.Hash, blob []byte) {
db.Scheme().WriteTrieNode(stackTrieSponge, owner, path, hash, blob)
})
// Fill the trie with elements // Fill the trie with elements
for i := 0; i < count; i++ { for i := 0; i < count; i++ {
// For the stack trie, we need to do inserts in proper order // For the stack trie, we need to do inserts in proper order
@ -951,11 +952,13 @@ func TestCommitSequenceStackTrie(t *testing.T) {
// not fit into 32 bytes, rlp-encoded. However, it's still the correct thing to do. // not fit into 32 bytes, rlp-encoded. However, it's still the correct thing to do.
func TestCommitSequenceSmallRoot(t *testing.T) { func TestCommitSequenceSmallRoot(t *testing.T) {
s := &spongeDb{sponge: sha3.NewLegacyKeccak256(), id: "a"} s := &spongeDb{sponge: sha3.NewLegacyKeccak256(), id: "a"}
db := NewDatabase(s) db := NewDatabase(rawdb.NewDatabase(s))
trie := NewEmpty(db) trie := NewEmpty(db)
// Another sponge is used for the stacktrie commits // Another sponge is used for the stacktrie commits
stackTrieSponge := &spongeDb{sponge: sha3.NewLegacyKeccak256(), id: "b"} stackTrieSponge := &spongeDb{sponge: sha3.NewLegacyKeccak256(), id: "b"}
stTrie := NewStackTrie(stackTrieSponge) stTrie := NewStackTrie(func(owner common.Hash, path []byte, hash common.Hash, blob []byte) {
db.Scheme().WriteTrieNode(stackTrieSponge, owner, path, hash, blob)
})
// Add a single small-element to the trie(s) // Add a single small-element to the trie(s)
key := make([]byte, 5) key := make([]byte, 5)
key[0] = 1 key[0] = 1

Loading…
Cancel
Save