core, trie: intermediate mempool between trie and database (#15857)

This commit reduces database I/O by not writing every state trie to disk.
pull/16037/head
Péter Szilágyi 7 years ago committed by Felix Lange
parent 59336283c0
commit 55599ee95d
  1. 14
      accounts/abi/bind/backends/simulated.go
  2. 4
      cmd/evm/runner.go
  3. 4
      cmd/geth/chaincmd.go
  4. 3
      cmd/geth/main.go
  5. 6
      cmd/geth/usage.go
  6. 26
      cmd/utils/cmd.go
  7. 47
      cmd/utils/flags.go
  8. 27
      common/size.go
  9. 4
      consensus/errors.go
  10. 4
      core/bench_test.go
  11. 7
      core/block_validator.go
  12. 8
      core/block_validator_test.go
  13. 318
      core/blockchain.go
  14. 145
      core/blockchain_test.go
  15. 3
      core/chain_indexer.go
  16. 9
      core/chain_makers.go
  17. 2
      core/chain_makers_test.go
  18. 24
      core/dao_test.go
  19. 24
      core/genesis.go
  20. 6
      core/genesis_test.go
  21. 51
      core/state/database.go
  22. 17
      core/state/iterator_test.go
  23. 5
      core/state/state_object.go
  24. 6
      core/state/state_test.go
  25. 38
      core/state/statedb.go
  26. 14
      core/state/statedb_test.go
  27. 44
      core/state/sync_test.go
  28. 4
      core/tx_pool_test.go
  29. 9
      core/types/block.go
  30. 13
      core/types/receipt.go
  31. 2
      core/types/transaction.go
  32. 4
      eth/api.go
  33. 135
      eth/api_tracer.go
  34. 8
      eth/backend.go
  35. 8
      eth/config.go
  36. 229
      eth/downloader/downloader.go
  37. 190
      eth/downloader/downloader_test.go
  38. 153
      eth/downloader/queue.go
  39. 29
      eth/downloader/statesync.go
  40. 6
      eth/handler.go
  41. 16
      eth/handler_test.go
  42. 14
      eth/helper_test.go
  43. 6
      eth/protocol_test.go
  44. 4
      eth/sync_test.go
  45. 2
      internal/ethapi/api.go
  46. 165
      les/handler.go
  47. 2
      les/handler_test.go
  48. 2
      les/helper_test.go
  49. 1
      les/odr_test.go
  50. 7
      light/lightchain.go
  51. 8
      light/nodeset.go
  52. 4
      light/odr_test.go
  53. 58
      light/postprocess.go
  54. 18
      light/trie.go
  55. 2
      light/trie_test.go
  56. 2
      light/txpool_test.go
  57. 2
      miner/worker.go
  58. 2
      tests/block_test_util.go
  59. 6
      tests/state_test_util.go
  60. 355
      trie/database.go
  61. 57
      trie/hasher.go
  62. 125
      trie/iterator_test.go
  63. 47
      trie/proof.go
  64. 62
      trie/secure_trie.go
  65. 20
      trie/secure_trie_test.go
  66. 14
      trie/sync.go
  67. 103
      trie/sync_test.go
  68. 88
      trie/trie.go
  69. 104
      trie/trie_test.go

@ -68,7 +68,7 @@ func NewSimulatedBackend(alloc core.GenesisAlloc) *SimulatedBackend {
database, _ := ethdb.NewMemDatabase() database, _ := ethdb.NewMemDatabase()
genesis := core.Genesis{Config: params.AllEthashProtocolChanges, Alloc: alloc} genesis := core.Genesis{Config: params.AllEthashProtocolChanges, Alloc: alloc}
genesis.MustCommit(database) genesis.MustCommit(database)
blockchain, _ := core.NewBlockChain(database, genesis.Config, ethash.NewFaker(), vm.Config{}) blockchain, _ := core.NewBlockChain(database, nil, genesis.Config, ethash.NewFaker(), vm.Config{})
backend := &SimulatedBackend{ backend := &SimulatedBackend{
database: database, database: database,
@ -102,8 +102,10 @@ func (b *SimulatedBackend) Rollback() {
func (b *SimulatedBackend) rollback() { func (b *SimulatedBackend) rollback() {
blocks, _ := core.GenerateChain(b.config, b.blockchain.CurrentBlock(), ethash.NewFaker(), b.database, 1, func(int, *core.BlockGen) {}) blocks, _ := core.GenerateChain(b.config, b.blockchain.CurrentBlock(), ethash.NewFaker(), b.database, 1, func(int, *core.BlockGen) {})
statedb, _ := b.blockchain.State()
b.pendingBlock = blocks[0] b.pendingBlock = blocks[0]
b.pendingState, _ = state.New(b.pendingBlock.Root(), state.NewDatabase(b.database)) b.pendingState, _ = state.New(b.pendingBlock.Root(), statedb.Database())
} }
// CodeAt returns the code associated with a certain account in the blockchain. // CodeAt returns the code associated with a certain account in the blockchain.
@ -309,8 +311,10 @@ func (b *SimulatedBackend) SendTransaction(ctx context.Context, tx *types.Transa
} }
block.AddTx(tx) block.AddTx(tx)
}) })
statedb, _ := b.blockchain.State()
b.pendingBlock = blocks[0] b.pendingBlock = blocks[0]
b.pendingState, _ = state.New(b.pendingBlock.Root(), state.NewDatabase(b.database)) b.pendingState, _ = state.New(b.pendingBlock.Root(), statedb.Database())
return nil return nil
} }
@ -386,8 +390,10 @@ func (b *SimulatedBackend) AdjustTime(adjustment time.Duration) error {
} }
block.OffsetTime(int64(adjustment.Seconds())) block.OffsetTime(int64(adjustment.Seconds()))
}) })
statedb, _ := b.blockchain.State()
b.pendingBlock = blocks[0] b.pendingBlock = blocks[0]
b.pendingState, _ = state.New(b.pendingBlock.Root(), state.NewDatabase(b.database)) b.pendingState, _ = state.New(b.pendingBlock.Root(), statedb.Database())
return nil return nil
} }

@ -96,7 +96,9 @@ func runCmd(ctx *cli.Context) error {
} }
if ctx.GlobalString(GenesisFlag.Name) != "" { if ctx.GlobalString(GenesisFlag.Name) != "" {
gen := readGenesis(ctx.GlobalString(GenesisFlag.Name)) gen := readGenesis(ctx.GlobalString(GenesisFlag.Name))
_, statedb = gen.ToBlock() db, _ := ethdb.NewMemDatabase()
genesis := gen.ToBlock(db)
statedb, _ = state.New(genesis.Root(), state.NewDatabase(db))
chainConfig = gen.Config chainConfig = gen.Config
} else { } else {
db, _ := ethdb.NewMemDatabase() db, _ := ethdb.NewMemDatabase()

@ -202,7 +202,7 @@ func importChain(ctx *cli.Context) error {
if len(ctx.Args()) == 1 { if len(ctx.Args()) == 1 {
if err := utils.ImportChain(chain, ctx.Args().First()); err != nil { if err := utils.ImportChain(chain, ctx.Args().First()); err != nil {
utils.Fatalf("Import error: %v", err) log.Error("Import error", "err", err)
} }
} else { } else {
for _, arg := range ctx.Args() { for _, arg := range ctx.Args() {
@ -211,7 +211,7 @@ func importChain(ctx *cli.Context) error {
} }
} }
} }
chain.Stop()
fmt.Printf("Import done in %v.\n\n", time.Since(start)) fmt.Printf("Import done in %v.\n\n", time.Since(start))
// Output pre-compaction stats mostly to see the import trashing // Output pre-compaction stats mostly to see the import trashing

@ -85,10 +85,13 @@ var (
utils.FastSyncFlag, utils.FastSyncFlag,
utils.LightModeFlag, utils.LightModeFlag,
utils.SyncModeFlag, utils.SyncModeFlag,
utils.GCModeFlag,
utils.LightServFlag, utils.LightServFlag,
utils.LightPeersFlag, utils.LightPeersFlag,
utils.LightKDFFlag, utils.LightKDFFlag,
utils.CacheFlag, utils.CacheFlag,
utils.CacheDatabaseFlag,
utils.CacheGCFlag,
utils.TrieCacheGenFlag, utils.TrieCacheGenFlag,
utils.ListenPortFlag, utils.ListenPortFlag,
utils.MaxPeersFlag, utils.MaxPeersFlag,

@ -22,10 +22,11 @@ import (
"io" "io"
"sort" "sort"
"strings"
"github.com/ethereum/go-ethereum/cmd/utils" "github.com/ethereum/go-ethereum/cmd/utils"
"github.com/ethereum/go-ethereum/internal/debug" "github.com/ethereum/go-ethereum/internal/debug"
"gopkg.in/urfave/cli.v1" "gopkg.in/urfave/cli.v1"
"strings"
) )
// AppHelpTemplate is the test template for the default, global app help topic. // AppHelpTemplate is the test template for the default, global app help topic.
@ -74,6 +75,7 @@ var AppHelpFlagGroups = []flagGroup{
utils.TestnetFlag, utils.TestnetFlag,
utils.RinkebyFlag, utils.RinkebyFlag,
utils.SyncModeFlag, utils.SyncModeFlag,
utils.GCModeFlag,
utils.EthStatsURLFlag, utils.EthStatsURLFlag,
utils.IdentityFlag, utils.IdentityFlag,
utils.LightServFlag, utils.LightServFlag,
@ -127,6 +129,8 @@ var AppHelpFlagGroups = []flagGroup{
Name: "PERFORMANCE TUNING", Name: "PERFORMANCE TUNING",
Flags: []cli.Flag{ Flags: []cli.Flag{
utils.CacheFlag, utils.CacheFlag,
utils.CacheDatabaseFlag,
utils.CacheGCFlag,
utils.TrieCacheGenFlag, utils.TrieCacheGenFlag,
}, },
}, },

@ -116,7 +116,6 @@ func ImportChain(chain *core.BlockChain, fn string) error {
return err return err
} }
} }
stream := rlp.NewStream(reader, 0) stream := rlp.NewStream(reader, 0)
// Run actual the import. // Run actual the import.
@ -150,25 +149,34 @@ func ImportChain(chain *core.BlockChain, fn string) error {
if checkInterrupt() { if checkInterrupt() {
return fmt.Errorf("interrupted") return fmt.Errorf("interrupted")
} }
if hasAllBlocks(chain, blocks[:i]) { missing := missingBlocks(chain, blocks[:i])
if len(missing) == 0 {
log.Info("Skipping batch as all blocks present", "batch", batch, "first", blocks[0].Hash(), "last", blocks[i-1].Hash()) log.Info("Skipping batch as all blocks present", "batch", batch, "first", blocks[0].Hash(), "last", blocks[i-1].Hash())
continue continue
} }
if _, err := chain.InsertChain(missing); err != nil {
if _, err := chain.InsertChain(blocks[:i]); err != nil {
return fmt.Errorf("invalid block %d: %v", n, err) return fmt.Errorf("invalid block %d: %v", n, err)
} }
} }
return nil return nil
} }
func hasAllBlocks(chain *core.BlockChain, bs []*types.Block) bool { func missingBlocks(chain *core.BlockChain, blocks []*types.Block) []*types.Block {
for _, b := range bs { head := chain.CurrentBlock()
if !chain.HasBlock(b.Hash(), b.NumberU64()) { for i, block := range blocks {
return false // If we're behind the chain head, only check block, state is available at head
if head.NumberU64() > block.NumberU64() {
if !chain.HasBlock(block.Hash(), block.NumberU64()) {
return blocks[i:]
}
continue
} }
// If we're above the chain head, state availability is a must
if !chain.HasBlockAndState(block.Hash(), block.NumberU64()) {
return blocks[i:]
} }
return true }
return nil
} }
func ExportChain(blockchain *core.BlockChain, fn string) error { func ExportChain(blockchain *core.BlockChain, fn string) error {

@ -170,7 +170,11 @@ var (
Usage: `Blockchain sync mode ("fast", "full", or "light")`, Usage: `Blockchain sync mode ("fast", "full", or "light")`,
Value: &defaultSyncMode, Value: &defaultSyncMode,
} }
GCModeFlag = cli.StringFlag{
Name: "gcmode",
Usage: `Blockchain garbage collection mode ("full", "archive")`,
Value: "full",
}
LightServFlag = cli.IntFlag{ LightServFlag = cli.IntFlag{
Name: "lightserv", Name: "lightserv",
Usage: "Maximum percentage of time allowed for serving LES requests (0-90)", Usage: "Maximum percentage of time allowed for serving LES requests (0-90)",
@ -293,8 +297,18 @@ var (
// Performance tuning settings // Performance tuning settings
CacheFlag = cli.IntFlag{ CacheFlag = cli.IntFlag{
Name: "cache", Name: "cache",
Usage: "Megabytes of memory allocated to internal caching (min 16MB / database forced)", Usage: "Megabytes of memory allocated to internal caching",
Value: 128, Value: 1024,
}
CacheDatabaseFlag = cli.IntFlag{
Name: "cache.database",
Usage: "Percentage of cache memory allowance to use for database io",
Value: 75,
}
CacheGCFlag = cli.IntFlag{
Name: "cache.gc",
Usage: "Percentage of cache memory allowance to use for trie pruning",
Value: 25,
} }
TrieCacheGenFlag = cli.IntFlag{ TrieCacheGenFlag = cli.IntFlag{
Name: "trie-cache-gens", Name: "trie-cache-gens",
@ -1021,11 +1035,19 @@ func SetEthConfig(ctx *cli.Context, stack *node.Node, cfg *eth.Config) {
cfg.NetworkId = ctx.GlobalUint64(NetworkIdFlag.Name) cfg.NetworkId = ctx.GlobalUint64(NetworkIdFlag.Name)
} }
if ctx.GlobalIsSet(CacheFlag.Name) { if ctx.GlobalIsSet(CacheFlag.Name) || ctx.GlobalIsSet(CacheDatabaseFlag.Name) {
cfg.DatabaseCache = ctx.GlobalInt(CacheFlag.Name) cfg.DatabaseCache = ctx.GlobalInt(CacheFlag.Name) * ctx.GlobalInt(CacheDatabaseFlag.Name) / 100
} }
cfg.DatabaseHandles = makeDatabaseHandles() cfg.DatabaseHandles = makeDatabaseHandles()
if gcmode := ctx.GlobalString(GCModeFlag.Name); gcmode != "full" && gcmode != "archive" {
Fatalf("--%s must be either 'full' or 'archive'", GCModeFlag.Name)
}
cfg.NoPruning = ctx.GlobalString(GCModeFlag.Name) == "archive"
if ctx.GlobalIsSet(CacheFlag.Name) || ctx.GlobalIsSet(CacheGCFlag.Name) {
cfg.TrieCache = ctx.GlobalInt(CacheFlag.Name) * ctx.GlobalInt(CacheGCFlag.Name) / 100
}
if ctx.GlobalIsSet(MinerThreadsFlag.Name) { if ctx.GlobalIsSet(MinerThreadsFlag.Name) {
cfg.MinerThreads = ctx.GlobalInt(MinerThreadsFlag.Name) cfg.MinerThreads = ctx.GlobalInt(MinerThreadsFlag.Name)
} }
@ -1157,7 +1179,7 @@ func SetupNetwork(ctx *cli.Context) {
// MakeChainDatabase open an LevelDB using the flags passed to the client and will hard crash if it fails. // MakeChainDatabase open an LevelDB using the flags passed to the client and will hard crash if it fails.
func MakeChainDatabase(ctx *cli.Context, stack *node.Node) ethdb.Database { func MakeChainDatabase(ctx *cli.Context, stack *node.Node) ethdb.Database {
var ( var (
cache = ctx.GlobalInt(CacheFlag.Name) cache = ctx.GlobalInt(CacheFlag.Name) * ctx.GlobalInt(CacheDatabaseFlag.Name) / 100
handles = makeDatabaseHandles() handles = makeDatabaseHandles()
) )
name := "chaindata" name := "chaindata"
@ -1209,8 +1231,19 @@ func MakeChain(ctx *cli.Context, stack *node.Node) (chain *core.BlockChain, chai
}) })
} }
} }
if gcmode := ctx.GlobalString(GCModeFlag.Name); gcmode != "full" && gcmode != "archive" {
Fatalf("--%s must be either 'full' or 'archive'", GCModeFlag.Name)
}
cache := &core.CacheConfig{
Disabled: ctx.GlobalString(GCModeFlag.Name) == "archive",
TrieNodeLimit: eth.DefaultConfig.TrieCache,
TrieTimeLimit: eth.DefaultConfig.TrieTimeout,
}
if ctx.GlobalIsSet(CacheFlag.Name) || ctx.GlobalIsSet(CacheGCFlag.Name) {
cache.TrieNodeLimit = ctx.GlobalInt(CacheFlag.Name) * ctx.GlobalInt(CacheGCFlag.Name) / 100
}
vmcfg := vm.Config{EnablePreimageRecording: ctx.GlobalBool(VMEnableDebugFlag.Name)} vmcfg := vm.Config{EnablePreimageRecording: ctx.GlobalBool(VMEnableDebugFlag.Name)}
chain, err = core.NewBlockChain(chainDb, config, engine, vmcfg) chain, err = core.NewBlockChain(chainDb, cache, config, engine, vmcfg)
if err != nil { if err != nil {
Fatalf("Can't create BlockChain: %v", err) Fatalf("Can't create BlockChain: %v", err)
} }

@ -20,18 +20,29 @@ import (
"fmt" "fmt"
) )
// StorageSize is a wrapper around a float value that supports user friendly
// formatting.
type StorageSize float64 type StorageSize float64
func (self StorageSize) String() string { // String implements the stringer interface.
if self > 1000000 { func (s StorageSize) String() string {
return fmt.Sprintf("%.2f mB", self/1000000) if s > 1000000 {
} else if self > 1000 { return fmt.Sprintf("%.2f mB", s/1000000)
return fmt.Sprintf("%.2f kB", self/1000) } else if s > 1000 {
return fmt.Sprintf("%.2f kB", s/1000)
} else { } else {
return fmt.Sprintf("%.2f B", self) return fmt.Sprintf("%.2f B", s)
} }
} }
func (self StorageSize) Int64() int64 { // TerminalString implements log.TerminalStringer, formatting a string for console
return int64(self) // output during logging.
func (s StorageSize) TerminalString() string {
if s > 1000000 {
return fmt.Sprintf("%.2fmB", s/1000000)
} else if s > 1000 {
return fmt.Sprintf("%.2fkB", s/1000)
} else {
return fmt.Sprintf("%.2fB", s)
}
} }

@ -23,6 +23,10 @@ var (
// that is unknown. // that is unknown.
ErrUnknownAncestor = errors.New("unknown ancestor") ErrUnknownAncestor = errors.New("unknown ancestor")
// ErrPrunedAncestor is returned when validating a block requires an ancestor
// that is known, but the state of which is not available.
ErrPrunedAncestor = errors.New("pruned ancestor")
// ErrFutureBlock is returned when a block's timestamp is in the future according // ErrFutureBlock is returned when a block's timestamp is in the future according
// to the current node. // to the current node.
ErrFutureBlock = errors.New("block in the future") ErrFutureBlock = errors.New("block in the future")

@ -173,7 +173,7 @@ func benchInsertChain(b *testing.B, disk bool, gen func(int, *BlockGen)) {
// Time the insertion of the new chain. // Time the insertion of the new chain.
// State and blocks are stored in the same DB. // State and blocks are stored in the same DB.
chainman, _ := NewBlockChain(db, gspec.Config, ethash.NewFaker(), vm.Config{}) chainman, _ := NewBlockChain(db, nil, gspec.Config, ethash.NewFaker(), vm.Config{})
defer chainman.Stop() defer chainman.Stop()
b.ReportAllocs() b.ReportAllocs()
b.ResetTimer() b.ResetTimer()
@ -283,7 +283,7 @@ func benchReadChain(b *testing.B, full bool, count uint64) {
if err != nil { if err != nil {
b.Fatalf("error opening database at %v: %v", dir, err) b.Fatalf("error opening database at %v: %v", dir, err)
} }
chain, err := NewBlockChain(db, params.TestChainConfig, ethash.NewFaker(), vm.Config{}) chain, err := NewBlockChain(db, nil, params.TestChainConfig, ethash.NewFaker(), vm.Config{})
if err != nil { if err != nil {
b.Fatalf("error creating chain: %v", err) b.Fatalf("error creating chain: %v", err)
} }

@ -50,12 +50,15 @@ func NewBlockValidator(config *params.ChainConfig, blockchain *BlockChain, engin
// validated at this point. // validated at this point.
func (v *BlockValidator) ValidateBody(block *types.Block) error { func (v *BlockValidator) ValidateBody(block *types.Block) error {
// Check whether the block's known, and if not, that it's linkable // Check whether the block's known, and if not, that it's linkable
if v.bc.HasBlockAndState(block.Hash()) { if v.bc.HasBlockAndState(block.Hash(), block.NumberU64()) {
return ErrKnownBlock return ErrKnownBlock
} }
if !v.bc.HasBlockAndState(block.ParentHash()) { if !v.bc.HasBlockAndState(block.ParentHash(), block.NumberU64()-1) {
if !v.bc.HasBlock(block.ParentHash(), block.NumberU64()-1) {
return consensus.ErrUnknownAncestor return consensus.ErrUnknownAncestor
} }
return consensus.ErrPrunedAncestor
}
// Header validity is known at this point, check the uncles and transactions // Header validity is known at this point, check the uncles and transactions
header := block.Header() header := block.Header()
if err := v.engine.VerifyUncles(v.bc, block); err != nil { if err := v.engine.VerifyUncles(v.bc, block); err != nil {

@ -42,7 +42,7 @@ func TestHeaderVerification(t *testing.T) {
headers[i] = block.Header() headers[i] = block.Header()
} }
// Run the header checker for blocks one-by-one, checking for both valid and invalid nonces // Run the header checker for blocks one-by-one, checking for both valid and invalid nonces
chain, _ := NewBlockChain(testdb, params.TestChainConfig, ethash.NewFaker(), vm.Config{}) chain, _ := NewBlockChain(testdb, nil, params.TestChainConfig, ethash.NewFaker(), vm.Config{})
defer chain.Stop() defer chain.Stop()
for i := 0; i < len(blocks); i++ { for i := 0; i < len(blocks); i++ {
@ -106,11 +106,11 @@ func testHeaderConcurrentVerification(t *testing.T, threads int) {
var results <-chan error var results <-chan error
if valid { if valid {
chain, _ := NewBlockChain(testdb, params.TestChainConfig, ethash.NewFaker(), vm.Config{}) chain, _ := NewBlockChain(testdb, nil, params.TestChainConfig, ethash.NewFaker(), vm.Config{})
_, results = chain.engine.VerifyHeaders(chain, headers, seals) _, results = chain.engine.VerifyHeaders(chain, headers, seals)
chain.Stop() chain.Stop()
} else { } else {
chain, _ := NewBlockChain(testdb, params.TestChainConfig, ethash.NewFakeFailer(uint64(len(headers)-1)), vm.Config{}) chain, _ := NewBlockChain(testdb, nil, params.TestChainConfig, ethash.NewFakeFailer(uint64(len(headers)-1)), vm.Config{})
_, results = chain.engine.VerifyHeaders(chain, headers, seals) _, results = chain.engine.VerifyHeaders(chain, headers, seals)
chain.Stop() chain.Stop()
} }
@ -173,7 +173,7 @@ func testHeaderConcurrentAbortion(t *testing.T, threads int) {
defer runtime.GOMAXPROCS(old) defer runtime.GOMAXPROCS(old)
// Start the verifications and immediately abort // Start the verifications and immediately abort
chain, _ := NewBlockChain(testdb, params.TestChainConfig, ethash.NewFakeDelayer(time.Millisecond), vm.Config{}) chain, _ := NewBlockChain(testdb, nil, params.TestChainConfig, ethash.NewFakeDelayer(time.Millisecond), vm.Config{})
defer chain.Stop() defer chain.Stop()
abort, results := chain.engine.VerifyHeaders(chain, headers, seals) abort, results := chain.engine.VerifyHeaders(chain, headers, seals)

@ -42,6 +42,7 @@ import (
"github.com/ethereum/go-ethereum/rlp" "github.com/ethereum/go-ethereum/rlp"
"github.com/ethereum/go-ethereum/trie" "github.com/ethereum/go-ethereum/trie"
"github.com/hashicorp/golang-lru" "github.com/hashicorp/golang-lru"
"gopkg.in/karalabe/cookiejar.v2/collections/prque"
) )
var ( var (
@ -56,11 +57,20 @@ const (
maxFutureBlocks = 256 maxFutureBlocks = 256
maxTimeFutureBlocks = 30 maxTimeFutureBlocks = 30
badBlockLimit = 10 badBlockLimit = 10
triesInMemory = 128
// BlockChainVersion ensures that an incompatible database forces a resync from scratch. // BlockChainVersion ensures that an incompatible database forces a resync from scratch.
BlockChainVersion = 3 BlockChainVersion = 3
) )
// CacheConfig contains the configuration values for the trie caching/pruning
// that's resident in a blockchain.
type CacheConfig struct {
Disabled bool // Whether to disable trie write caching (archive node)
TrieNodeLimit int // Memory limit (MB) at which to flush the current in-memory trie to disk
TrieTimeLimit time.Duration // Time limit after which to flush the current in-memory trie to disk
}
// BlockChain represents the canonical chain given a database with a genesis // BlockChain represents the canonical chain given a database with a genesis
// block. The Blockchain manages chain imports, reverts, chain reorganisations. // block. The Blockchain manages chain imports, reverts, chain reorganisations.
// //
@ -76,10 +86,14 @@ const (
// included in the canonical one where as GetBlockByNumber always represents the // included in the canonical one where as GetBlockByNumber always represents the
// canonical chain. // canonical chain.
type BlockChain struct { type BlockChain struct {
config *params.ChainConfig // chain & network configuration chainConfig *params.ChainConfig // Chain & network configuration
cacheConfig *CacheConfig // Cache configuration for pruning
db ethdb.Database // Low level persistent database to store final content in
triegc *prque.Prque // Priority queue mapping block numbers to tries to gc
gcproc time.Duration // Accumulates canonical block processing for trie dumping
hc *HeaderChain hc *HeaderChain
chainDb ethdb.Database
rmLogsFeed event.Feed rmLogsFeed event.Feed
chainFeed event.Feed chainFeed event.Feed
chainSideFeed event.Feed chainSideFeed event.Feed
@ -119,7 +133,13 @@ type BlockChain struct {
// NewBlockChain returns a fully initialised block chain using information // NewBlockChain returns a fully initialised block chain using information
// available in the database. It initialises the default Ethereum Validator and // available in the database. It initialises the default Ethereum Validator and
// Processor. // Processor.
func NewBlockChain(chainDb ethdb.Database, config *params.ChainConfig, engine consensus.Engine, vmConfig vm.Config) (*BlockChain, error) { func NewBlockChain(db ethdb.Database, cacheConfig *CacheConfig, chainConfig *params.ChainConfig, engine consensus.Engine, vmConfig vm.Config) (*BlockChain, error) {
if cacheConfig == nil {
cacheConfig = &CacheConfig{
TrieNodeLimit: 256 * 1024 * 1024,
TrieTimeLimit: 5 * time.Minute,
}
}
bodyCache, _ := lru.New(bodyCacheLimit) bodyCache, _ := lru.New(bodyCacheLimit)
bodyRLPCache, _ := lru.New(bodyCacheLimit) bodyRLPCache, _ := lru.New(bodyCacheLimit)
blockCache, _ := lru.New(blockCacheLimit) blockCache, _ := lru.New(blockCacheLimit)
@ -127,9 +147,11 @@ func NewBlockChain(chainDb ethdb.Database, config *params.ChainConfig, engine co
badBlocks, _ := lru.New(badBlockLimit) badBlocks, _ := lru.New(badBlockLimit)
bc := &BlockChain{ bc := &BlockChain{
config: config, chainConfig: chainConfig,
chainDb: chainDb, cacheConfig: cacheConfig,
stateCache: state.NewDatabase(chainDb), db: db,
triegc: prque.New(),
stateCache: state.NewDatabase(db),
quit: make(chan struct{}), quit: make(chan struct{}),
bodyCache: bodyCache, bodyCache: bodyCache,
bodyRLPCache: bodyRLPCache, bodyRLPCache: bodyRLPCache,
@ -139,11 +161,11 @@ func NewBlockChain(chainDb ethdb.Database, config *params.ChainConfig, engine co
vmConfig: vmConfig, vmConfig: vmConfig,
badBlocks: badBlocks, badBlocks: badBlocks,
} }
bc.SetValidator(NewBlockValidator(config, bc, engine)) bc.SetValidator(NewBlockValidator(chainConfig, bc, engine))
bc.SetProcessor(NewStateProcessor(config, bc, engine)) bc.SetProcessor(NewStateProcessor(chainConfig, bc, engine))
var err error var err error
bc.hc, err = NewHeaderChain(chainDb, config, engine, bc.getProcInterrupt) bc.hc, err = NewHeaderChain(db, chainConfig, engine, bc.getProcInterrupt)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -180,7 +202,7 @@ func (bc *BlockChain) getProcInterrupt() bool {
// assumes that the chain manager mutex is held. // assumes that the chain manager mutex is held.
func (bc *BlockChain) loadLastState() error { func (bc *BlockChain) loadLastState() error {
// Restore the last known head block // Restore the last known head block
head := GetHeadBlockHash(bc.chainDb) head := GetHeadBlockHash(bc.db)
if head == (common.Hash{}) { if head == (common.Hash{}) {
// Corrupt or empty database, init from scratch // Corrupt or empty database, init from scratch
log.Warn("Empty database, resetting chain") log.Warn("Empty database, resetting chain")
@ -196,15 +218,17 @@ func (bc *BlockChain) loadLastState() error {
// Make sure the state associated with the block is available // Make sure the state associated with the block is available
if _, err := state.New(currentBlock.Root(), bc.stateCache); err != nil { if _, err := state.New(currentBlock.Root(), bc.stateCache); err != nil {
// Dangling block without a state associated, init from scratch // Dangling block without a state associated, init from scratch
log.Warn("Head state missing, resetting chain", "number", currentBlock.Number(), "hash", currentBlock.Hash()) log.Warn("Head state missing, repairing chain", "number", currentBlock.Number(), "hash", currentBlock.Hash())
return bc.Reset() if err := bc.repair(&currentBlock); err != nil {
return err
}
} }
// Everything seems to be fine, set as the head block // Everything seems to be fine, set as the head block
bc.currentBlock = currentBlock bc.currentBlock = currentBlock
// Restore the last known head header // Restore the last known head header
currentHeader := bc.currentBlock.Header() currentHeader := bc.currentBlock.Header()
if head := GetHeadHeaderHash(bc.chainDb); head != (common.Hash{}) { if head := GetHeadHeaderHash(bc.db); head != (common.Hash{}) {
if header := bc.GetHeaderByHash(head); header != nil { if header := bc.GetHeaderByHash(head); header != nil {
currentHeader = header currentHeader = header
} }
@ -213,7 +237,7 @@ func (bc *BlockChain) loadLastState() error {
// Restore the last known head fast block // Restore the last known head fast block
bc.currentFastBlock = bc.currentBlock bc.currentFastBlock = bc.currentBlock
if head := GetHeadFastBlockHash(bc.chainDb); head != (common.Hash{}) { if head := GetHeadFastBlockHash(bc.db); head != (common.Hash{}) {
if block := bc.GetBlockByHash(head); block != nil { if block := bc.GetBlockByHash(head); block != nil {
bc.currentFastBlock = block bc.currentFastBlock = block
} }
@ -243,7 +267,7 @@ func (bc *BlockChain) SetHead(head uint64) error {
// Rewind the header chain, deleting all block bodies until then // Rewind the header chain, deleting all block bodies until then
delFn := func(hash common.Hash, num uint64) { delFn := func(hash common.Hash, num uint64) {
DeleteBody(bc.chainDb, hash, num) DeleteBody(bc.db, hash, num)
} }
bc.hc.SetHead(head, delFn) bc.hc.SetHead(head, delFn)
currentHeader := bc.hc.CurrentHeader() currentHeader := bc.hc.CurrentHeader()
@ -275,10 +299,10 @@ func (bc *BlockChain) SetHead(head uint64) error {
if bc.currentFastBlock == nil { if bc.currentFastBlock == nil {
bc.currentFastBlock = bc.genesisBlock bc.currentFastBlock = bc.genesisBlock
} }
if err := WriteHeadBlockHash(bc.chainDb, bc.currentBlock.Hash()); err != nil { if err := WriteHeadBlockHash(bc.db, bc.currentBlock.Hash()); err != nil {
log.Crit("Failed to reset head full block", "err", err) log.Crit("Failed to reset head full block", "err", err)
} }
if err := WriteHeadFastBlockHash(bc.chainDb, bc.currentFastBlock.Hash()); err != nil { if err := WriteHeadFastBlockHash(bc.db, bc.currentFastBlock.Hash()); err != nil {
log.Crit("Failed to reset head fast block", "err", err) log.Crit("Failed to reset head fast block", "err", err)
} }
return bc.loadLastState() return bc.loadLastState()
@ -292,7 +316,7 @@ func (bc *BlockChain) FastSyncCommitHead(hash common.Hash) error {
if block == nil { if block == nil {
return fmt.Errorf("non existent block [%x…]", hash[:4]) return fmt.Errorf("non existent block [%x…]", hash[:4])
} }
if _, err := trie.NewSecure(block.Root(), bc.chainDb, 0); err != nil { if _, err := trie.NewSecure(block.Root(), bc.stateCache.TrieDB(), 0); err != nil {
return err return err
} }
// If all checks out, manually set the head block // If all checks out, manually set the head block
@ -387,7 +411,7 @@ func (bc *BlockChain) ResetWithGenesisBlock(genesis *types.Block) error {
if err := bc.hc.WriteTd(genesis.Hash(), genesis.NumberU64(), genesis.Difficulty()); err != nil { if err := bc.hc.WriteTd(genesis.Hash(), genesis.NumberU64(), genesis.Difficulty()); err != nil {
log.Crit("Failed to write genesis block TD", "err", err) log.Crit("Failed to write genesis block TD", "err", err)
} }
if err := WriteBlock(bc.chainDb, genesis); err != nil { if err := WriteBlock(bc.db, genesis); err != nil {
log.Crit("Failed to write genesis block", "err", err) log.Crit("Failed to write genesis block", "err", err)
} }
bc.genesisBlock = genesis bc.genesisBlock = genesis
@ -400,6 +424,24 @@ func (bc *BlockChain) ResetWithGenesisBlock(genesis *types.Block) error {
return nil return nil
} }
// repair tries to repair the current blockchain by rolling back the current block
// until one with associated state is found. This is needed to fix incomplete db
// writes caused either by crashes/power outages, or simply non-committed tries.
//
// This method only rolls back the current block. The current header and current
// fast block are left intact.
func (bc *BlockChain) repair(head **types.Block) error {
for {
// Abort if we've rewound to a head block that does have associated state
if _, err := state.New((*head).Root(), bc.stateCache); err == nil {
log.Info("Rewound blockchain to past state", "number", (*head).Number(), "hash", (*head).Hash())
return nil
}
// Otherwise rewind one block and recheck state availability there
(*head) = bc.GetBlock((*head).ParentHash(), (*head).NumberU64()-1)
}
}
// Export writes the active chain to the given writer. // Export writes the active chain to the given writer.
func (bc *BlockChain) Export(w io.Writer) error { func (bc *BlockChain) Export(w io.Writer) error {
return bc.ExportN(w, uint64(0), bc.currentBlock.NumberU64()) return bc.ExportN(w, uint64(0), bc.currentBlock.NumberU64())
@ -437,13 +479,13 @@ func (bc *BlockChain) ExportN(w io.Writer, first uint64, last uint64) error {
// Note, this function assumes that the `mu` mutex is held! // Note, this function assumes that the `mu` mutex is held!
func (bc *BlockChain) insert(block *types.Block) { func (bc *BlockChain) insert(block *types.Block) {
// If the block is on a side chain or an unknown one, force other heads onto it too // If the block is on a side chain or an unknown one, force other heads onto it too
updateHeads := GetCanonicalHash(bc.chainDb, block.NumberU64()) != block.Hash() updateHeads := GetCanonicalHash(bc.db, block.NumberU64()) != block.Hash()
// Add the block to the canonical chain number scheme and mark as the head // Add the block to the canonical chain number scheme and mark as the head
if err := WriteCanonicalHash(bc.chainDb, block.Hash(), block.NumberU64()); err != nil { if err := WriteCanonicalHash(bc.db, block.Hash(), block.NumberU64()); err != nil {
log.Crit("Failed to insert block number", "err", err) log.Crit("Failed to insert block number", "err", err)
} }
if err := WriteHeadBlockHash(bc.chainDb, block.Hash()); err != nil { if err := WriteHeadBlockHash(bc.db, block.Hash()); err != nil {
log.Crit("Failed to insert head block hash", "err", err) log.Crit("Failed to insert head block hash", "err", err)
} }
bc.currentBlock = block bc.currentBlock = block
@ -452,7 +494,7 @@ func (bc *BlockChain) insert(block *types.Block) {
if updateHeads { if updateHeads {
bc.hc.SetCurrentHeader(block.Header()) bc.hc.SetCurrentHeader(block.Header())
if err := WriteHeadFastBlockHash(bc.chainDb, block.Hash()); err != nil { if err := WriteHeadFastBlockHash(bc.db, block.Hash()); err != nil {
log.Crit("Failed to insert head fast block hash", "err", err) log.Crit("Failed to insert head fast block hash", "err", err)
} }
bc.currentFastBlock = block bc.currentFastBlock = block
@ -472,7 +514,7 @@ func (bc *BlockChain) GetBody(hash common.Hash) *types.Body {
body := cached.(*types.Body) body := cached.(*types.Body)
return body return body
} }
body := GetBody(bc.chainDb, hash, bc.hc.GetBlockNumber(hash)) body := GetBody(bc.db, hash, bc.hc.GetBlockNumber(hash))
if body == nil { if body == nil {
return nil return nil
} }
@ -488,7 +530,7 @@ func (bc *BlockChain) GetBodyRLP(hash common.Hash) rlp.RawValue {
if cached, ok := bc.bodyRLPCache.Get(hash); ok { if cached, ok := bc.bodyRLPCache.Get(hash); ok {
return cached.(rlp.RawValue) return cached.(rlp.RawValue)
} }
body := GetBodyRLP(bc.chainDb, hash, bc.hc.GetBlockNumber(hash)) body := GetBodyRLP(bc.db, hash, bc.hc.GetBlockNumber(hash))
if len(body) == 0 { if len(body) == 0 {
return nil return nil
} }
@ -502,21 +544,25 @@ func (bc *BlockChain) HasBlock(hash common.Hash, number uint64) bool {
if bc.blockCache.Contains(hash) { if bc.blockCache.Contains(hash) {
return true return true
} }
ok, _ := bc.chainDb.Has(blockBodyKey(hash, number)) ok, _ := bc.db.Has(blockBodyKey(hash, number))
return ok return ok
} }
// HasState checks if state trie is fully present in the database or not.
func (bc *BlockChain) HasState(hash common.Hash) bool {
_, err := bc.stateCache.OpenTrie(hash)
return err == nil
}
// HasBlockAndState checks if a block and associated state trie is fully present // HasBlockAndState checks if a block and associated state trie is fully present
// in the database or not, caching it if present. // in the database or not, caching it if present.
func (bc *BlockChain) HasBlockAndState(hash common.Hash) bool { func (bc *BlockChain) HasBlockAndState(hash common.Hash, number uint64) bool {
// Check first that the block itself is known // Check first that the block itself is known
block := bc.GetBlockByHash(hash) block := bc.GetBlock(hash, number)
if block == nil { if block == nil {
return false return false
} }
// Ensure the associated state is also present return bc.HasState(block.Root())
_, err := bc.stateCache.OpenTrie(block.Root())
return err == nil
} }
// GetBlock retrieves a block from the database by hash and number, // GetBlock retrieves a block from the database by hash and number,
@ -526,7 +572,7 @@ func (bc *BlockChain) GetBlock(hash common.Hash, number uint64) *types.Block {
if block, ok := bc.blockCache.Get(hash); ok { if block, ok := bc.blockCache.Get(hash); ok {
return block.(*types.Block) return block.(*types.Block)
} }
block := GetBlock(bc.chainDb, hash, number) block := GetBlock(bc.db, hash, number)
if block == nil { if block == nil {
return nil return nil
} }
@ -543,13 +589,18 @@ func (bc *BlockChain) GetBlockByHash(hash common.Hash) *types.Block {
// GetBlockByNumber retrieves a block from the database by number, caching it // GetBlockByNumber retrieves a block from the database by number, caching it
// (associated with its hash) if found. // (associated with its hash) if found.
func (bc *BlockChain) GetBlockByNumber(number uint64) *types.Block { func (bc *BlockChain) GetBlockByNumber(number uint64) *types.Block {
hash := GetCanonicalHash(bc.chainDb, number) hash := GetCanonicalHash(bc.db, number)
if hash == (common.Hash{}) { if hash == (common.Hash{}) {
return nil return nil
} }
return bc.GetBlock(hash, number) return bc.GetBlock(hash, number)
} }
// GetReceiptsByHash retrieves the receipts for all transactions in a given block.
func (bc *BlockChain) GetReceiptsByHash(hash common.Hash) types.Receipts {
return GetBlockReceipts(bc.db, hash, GetBlockNumber(bc.db, hash))
}
// GetBlocksFromHash returns the block corresponding to hash and up to n-1 ancestors. // GetBlocksFromHash returns the block corresponding to hash and up to n-1 ancestors.
// [deprecated by eth/62] // [deprecated by eth/62]
func (bc *BlockChain) GetBlocksFromHash(hash common.Hash, n int) (blocks []*types.Block) { func (bc *BlockChain) GetBlocksFromHash(hash common.Hash, n int) (blocks []*types.Block) {
@ -577,6 +628,12 @@ func (bc *BlockChain) GetUnclesInChain(block *types.Block, length int) []*types.
return uncles return uncles
} }
// TrieNode retrieves a blob of data associated with a trie node (or code hash)
// either from ephemeral in-memory cache, or from persistent storage.
func (bc *BlockChain) TrieNode(hash common.Hash) ([]byte, error) {
return bc.stateCache.TrieDB().Node(hash)
}
// Stop stops the blockchain service. If any imports are currently in progress // Stop stops the blockchain service. If any imports are currently in progress
// it will abort them using the procInterrupt. // it will abort them using the procInterrupt.
func (bc *BlockChain) Stop() { func (bc *BlockChain) Stop() {
@ -589,6 +646,33 @@ func (bc *BlockChain) Stop() {
atomic.StoreInt32(&bc.procInterrupt, 1) atomic.StoreInt32(&bc.procInterrupt, 1)
bc.wg.Wait() bc.wg.Wait()
// Ensure the state of a recent block is also stored to disk before exiting.
// It is fine if this state does not exist (fast start/stop cycle), but it is
// advisable to leave an N block gap from the head so 1) a restart loads up
// the last N blocks as sync assistance to remote nodes; 2) a restart during
// a (small) reorg doesn't require deep reprocesses; 3) chain "repair" from
// missing states are constantly tested.
//
// This may be tuned a bit on mainnet if its too annoying to reprocess the last
// N blocks.
if !bc.cacheConfig.Disabled {
triedb := bc.stateCache.TrieDB()
if number := bc.CurrentBlock().NumberU64(); number >= triesInMemory {
recent := bc.GetBlockByNumber(bc.CurrentBlock().NumberU64() - triesInMemory + 1)
log.Info("Writing cached state to disk", "block", recent.Number(), "hash", recent.Hash(), "root", recent.Root())
if err := triedb.Commit(recent.Root(), true); err != nil {
log.Error("Failed to commit recent state trie", "err", err)
}
}
for !bc.triegc.Empty() {
triedb.Dereference(bc.triegc.PopItem().(common.Hash), common.Hash{})
}
if size := triedb.Size(); size != 0 {
log.Error("Dangling trie nodes after full cleanup")
}
}
log.Info("Blockchain manager stopped") log.Info("Blockchain manager stopped")
} }
@ -633,11 +717,11 @@ func (bc *BlockChain) Rollback(chain []common.Hash) {
} }
if bc.currentFastBlock.Hash() == hash { if bc.currentFastBlock.Hash() == hash {
bc.currentFastBlock = bc.GetBlock(bc.currentFastBlock.ParentHash(), bc.currentFastBlock.NumberU64()-1) bc.currentFastBlock = bc.GetBlock(bc.currentFastBlock.ParentHash(), bc.currentFastBlock.NumberU64()-1)
WriteHeadFastBlockHash(bc.chainDb, bc.currentFastBlock.Hash()) WriteHeadFastBlockHash(bc.db, bc.currentFastBlock.Hash())
} }
if bc.currentBlock.Hash() == hash { if bc.currentBlock.Hash() == hash {
bc.currentBlock = bc.GetBlock(bc.currentBlock.ParentHash(), bc.currentBlock.NumberU64()-1) bc.currentBlock = bc.GetBlock(bc.currentBlock.ParentHash(), bc.currentBlock.NumberU64()-1)
WriteHeadBlockHash(bc.chainDb, bc.currentBlock.Hash()) WriteHeadBlockHash(bc.db, bc.currentBlock.Hash())
} }
} }
} }
@ -696,7 +780,7 @@ func (bc *BlockChain) InsertReceiptChain(blockChain types.Blocks, receiptChain [
stats = struct{ processed, ignored int32 }{} stats = struct{ processed, ignored int32 }{}
start = time.Now() start = time.Now()
bytes = 0 bytes = 0
batch = bc.chainDb.NewBatch() batch = bc.db.NewBatch()
) )
for i, block := range blockChain { for i, block := range blockChain {
receipts := receiptChain[i] receipts := receiptChain[i]
@ -714,7 +798,7 @@ func (bc *BlockChain) InsertReceiptChain(blockChain types.Blocks, receiptChain [
continue continue
} }
// Compute all the non-consensus fields of the receipts // Compute all the non-consensus fields of the receipts
SetReceiptsData(bc.config, block, receipts) SetReceiptsData(bc.chainConfig, block, receipts)
// Write all the data out into the database // Write all the data out into the database
if err := WriteBody(batch, block.Hash(), block.NumberU64(), block.Body()); err != nil { if err := WriteBody(batch, block.Hash(), block.NumberU64(), block.Body()); err != nil {
return i, fmt.Errorf("failed to write block body: %v", err) return i, fmt.Errorf("failed to write block body: %v", err)
@ -747,7 +831,7 @@ func (bc *BlockChain) InsertReceiptChain(blockChain types.Blocks, receiptChain [
head := blockChain[len(blockChain)-1] head := blockChain[len(blockChain)-1]
if td := bc.GetTd(head.Hash(), head.NumberU64()); td != nil { // Rewind may have occurred, skip in that case if td := bc.GetTd(head.Hash(), head.NumberU64()); td != nil { // Rewind may have occurred, skip in that case
if bc.GetTd(bc.currentFastBlock.Hash(), bc.currentFastBlock.NumberU64()).Cmp(td) < 0 { if bc.GetTd(bc.currentFastBlock.Hash(), bc.currentFastBlock.NumberU64()).Cmp(td) < 0 {
if err := WriteHeadFastBlockHash(bc.chainDb, head.Hash()); err != nil { if err := WriteHeadFastBlockHash(bc.db, head.Hash()); err != nil {
log.Crit("Failed to update head fast block hash", "err", err) log.Crit("Failed to update head fast block hash", "err", err)
} }
bc.currentFastBlock = head bc.currentFastBlock = head
@ -758,15 +842,33 @@ func (bc *BlockChain) InsertReceiptChain(blockChain types.Blocks, receiptChain [
log.Info("Imported new block receipts", log.Info("Imported new block receipts",
"count", stats.processed, "count", stats.processed,
"elapsed", common.PrettyDuration(time.Since(start)), "elapsed", common.PrettyDuration(time.Since(start)),
"bytes", bytes,
"number", head.Number(), "number", head.Number(),
"hash", head.Hash(), "hash", head.Hash(),
"size", common.StorageSize(bytes),
"ignored", stats.ignored) "ignored", stats.ignored)
return 0, nil return 0, nil
} }
// WriteBlock writes the block to the chain. var lastWrite uint64
func (bc *BlockChain) WriteBlockAndState(block *types.Block, receipts []*types.Receipt, state *state.StateDB) (status WriteStatus, err error) {
// WriteBlockWithoutState writes only the block and its metadata to the database,
// but does not write any state. This is used to construct competing side forks
// up to the point where they exceed the canonical total difficulty.
func (bc *BlockChain) WriteBlockWithoutState(block *types.Block, td *big.Int) (err error) {
bc.wg.Add(1)
defer bc.wg.Done()
if err := bc.hc.WriteTd(block.Hash(), block.NumberU64(), td); err != nil {
return err
}
if err := WriteBlock(bc.db, block); err != nil {
return err
}
return nil
}
// WriteBlockWithState writes the block and all associated state to the database.
func (bc *BlockChain) WriteBlockWithState(block *types.Block, receipts []*types.Receipt, state *state.StateDB) (status WriteStatus, err error) {
bc.wg.Add(1) bc.wg.Add(1)
defer bc.wg.Done() defer bc.wg.Done()
@ -787,17 +889,73 @@ func (bc *BlockChain) WriteBlockAndState(block *types.Block, receipts []*types.R
return NonStatTy, err return NonStatTy, err
} }
// Write other block data using a batch. // Write other block data using a batch.
batch := bc.chainDb.NewBatch() batch := bc.db.NewBatch()
if err := WriteBlock(batch, block); err != nil { if err := WriteBlock(batch, block); err != nil {
return NonStatTy, err return NonStatTy, err
} }
if _, err := state.CommitTo(batch, bc.config.IsEIP158(block.Number())); err != nil { root, err := state.Commit(bc.chainConfig.IsEIP158(block.Number()))
if err != nil {
return NonStatTy, err return NonStatTy, err
} }
if err := WriteBlockReceipts(batch, block.Hash(), block.NumberU64(), receipts); err != nil { triedb := bc.stateCache.TrieDB()
// If we're running an archive node, always flush
if bc.cacheConfig.Disabled {
if err := triedb.Commit(root, false); err != nil {
return NonStatTy, err return NonStatTy, err
} }
} else {
// Full but not archive node, do proper garbage collection
triedb.Reference(root, common.Hash{}) // metadata reference to keep trie alive
bc.triegc.Push(root, -float32(block.NumberU64()))
if current := block.NumberU64(); current > triesInMemory {
// Find the next state trie we need to commit
header := bc.GetHeaderByNumber(current - triesInMemory)
chosen := header.Number.Uint64()
// Only write to disk if we exceeded our memory allowance *and* also have at
// least a given number of tries gapped.
var (
size = triedb.Size()
limit = common.StorageSize(bc.cacheConfig.TrieNodeLimit) * 1024 * 1024
)
if size > limit || bc.gcproc > bc.cacheConfig.TrieTimeLimit {
// If we're exceeding limits but haven't reached a large enough memory gap,
// warn the user that the system is becoming unstable.
if chosen < lastWrite+triesInMemory {
switch {
case size >= 2*limit:
log.Error("Trie memory critical, forcing to disk", "size", size, "limit", limit, "optimum", float64(chosen-lastWrite)/triesInMemory)
case bc.gcproc >= 2*bc.cacheConfig.TrieTimeLimit:
log.Error("Trie timing critical, forcing to disk", "time", bc.gcproc, "allowance", bc.cacheConfig.TrieTimeLimit, "optimum", float64(chosen-lastWrite)/triesInMemory)
case size > limit:
log.Warn("Trie memory at dangerous levels", "size", size, "limit", limit, "optimum", float64(chosen-lastWrite)/triesInMemory)
case bc.gcproc > bc.cacheConfig.TrieTimeLimit:
log.Warn("Trie timing at dangerous levels", "time", bc.gcproc, "limit", bc.cacheConfig.TrieTimeLimit, "optimum", float64(chosen-lastWrite)/triesInMemory)
}
}
// If optimum or critical limits reached, write to disk
if chosen >= lastWrite+triesInMemory || size >= 2*limit || bc.gcproc >= 2*bc.cacheConfig.TrieTimeLimit {
triedb.Commit(header.Root, true)
lastWrite = chosen
bc.gcproc = 0
}
}
// Garbage collect anything below our required write retention
for !bc.triegc.Empty() {
root, number := bc.triegc.Pop()
if uint64(-number) > chosen {
bc.triegc.Push(root, number)
break
}
triedb.Dereference(root.(common.Hash), common.Hash{})
}
}
}
if err := WriteBlockReceipts(batch, block.Hash(), block.NumberU64(), receipts); err != nil {
return NonStatTy, err
}
// If the total difficulty is higher than our known, add it to the canonical chain // If the total difficulty is higher than our known, add it to the canonical chain
// Second clause in the if statement reduces the vulnerability to selfish mining. // Second clause in the if statement reduces the vulnerability to selfish mining.
// Please refer to http://www.cs.cornell.edu/~ie53/publications/btcProcFC.pdf // Please refer to http://www.cs.cornell.edu/~ie53/publications/btcProcFC.pdf
@ -818,7 +976,7 @@ func (bc *BlockChain) WriteBlockAndState(block *types.Block, receipts []*types.R
return NonStatTy, err return NonStatTy, err
} }
// Write hash preimages // Write hash preimages
if err := WritePreimages(bc.chainDb, block.NumberU64(), state.Preimages()); err != nil { if err := WritePreimages(bc.db, block.NumberU64(), state.Preimages()); err != nil {
return NonStatTy, err return NonStatTy, err
} }
status = CanonStatTy status = CanonStatTy
@ -910,16 +1068,14 @@ func (bc *BlockChain) insertChain(chain types.Blocks) (int, []interface{}, []*ty
if err == nil { if err == nil {
err = bc.Validator().ValidateBody(block) err = bc.Validator().ValidateBody(block)
} }
if err != nil { switch {
if err == ErrKnownBlock { case err == ErrKnownBlock:
stats.ignored++ stats.ignored++
continue continue
}
if err == consensus.ErrFutureBlock { case err == consensus.ErrFutureBlock:
// Allow up to MaxFuture second in the future blocks. If this limit // Allow up to MaxFuture second in the future blocks. If this limit is exceeded
// is exceeded the chain is discarded and processed at a later time // the chain is discarded and processed at a later time if given.
// if given.
max := big.NewInt(time.Now().Unix() + maxTimeFutureBlocks) max := big.NewInt(time.Now().Unix() + maxTimeFutureBlocks)
if block.Time().Cmp(max) > 0 { if block.Time().Cmp(max) > 0 {
return i, events, coalescedLogs, fmt.Errorf("future block: %v > %v", block.Time(), max) return i, events, coalescedLogs, fmt.Errorf("future block: %v > %v", block.Time(), max)
@ -927,14 +1083,45 @@ func (bc *BlockChain) insertChain(chain types.Blocks) (int, []interface{}, []*ty
bc.futureBlocks.Add(block.Hash(), block) bc.futureBlocks.Add(block.Hash(), block)
stats.queued++ stats.queued++
continue continue
}
if err == consensus.ErrUnknownAncestor && bc.futureBlocks.Contains(block.ParentHash()) { case err == consensus.ErrUnknownAncestor && bc.futureBlocks.Contains(block.ParentHash()):
bc.futureBlocks.Add(block.Hash(), block) bc.futureBlocks.Add(block.Hash(), block)
stats.queued++ stats.queued++
continue continue
case err == consensus.ErrPrunedAncestor:
// Block competing with the canonical chain, store in the db, but don't process
// until the competitor TD goes above the canonical TD
localTd := bc.GetTd(bc.currentBlock.Hash(), bc.currentBlock.NumberU64())
externTd := new(big.Int).Add(bc.GetTd(block.ParentHash(), block.NumberU64()-1), block.Difficulty())
if localTd.Cmp(externTd) > 0 {
if err = bc.WriteBlockWithoutState(block, externTd); err != nil {
return i, events, coalescedLogs, err
}
continue
} }
// Competitor chain beat canonical, gather all blocks from the common ancestor
var winner []*types.Block
parent := bc.GetBlock(block.ParentHash(), block.NumberU64()-1)
for !bc.HasState(parent.Root()) {
winner = append(winner, parent)
parent = bc.GetBlock(parent.ParentHash(), parent.NumberU64()-1)
}
for j := 0; j < len(winner)/2; j++ {
winner[j], winner[len(winner)-1-j] = winner[len(winner)-1-j], winner[j]
}
// Import all the pruned blocks to make the state available
bc.chainmu.Unlock()
_, evs, logs, err := bc.insertChain(winner)
bc.chainmu.Lock()
events, coalescedLogs = evs, logs
if err != nil {
return i, events, coalescedLogs, err
}
case err != nil:
bc.reportBlock(block, nil, err) bc.reportBlock(block, nil, err)
return i, events, coalescedLogs, err return i, events, coalescedLogs, err
} }
@ -962,8 +1149,10 @@ func (bc *BlockChain) insertChain(chain types.Blocks) (int, []interface{}, []*ty
bc.reportBlock(block, receipts, err) bc.reportBlock(block, receipts, err)
return i, events, coalescedLogs, err return i, events, coalescedLogs, err
} }
proctime := time.Since(bstart)
// Write the block to the chain and get the status. // Write the block to the chain and get the status.
status, err := bc.WriteBlockAndState(block, receipts, state) status, err := bc.WriteBlockWithState(block, receipts, state)
if err != nil { if err != nil {
return i, events, coalescedLogs, err return i, events, coalescedLogs, err
} }
@ -977,6 +1166,9 @@ func (bc *BlockChain) insertChain(chain types.Blocks) (int, []interface{}, []*ty
events = append(events, ChainEvent{block, block.Hash(), logs}) events = append(events, ChainEvent{block, block.Hash(), logs})
lastCanon = block lastCanon = block
// Only count canonical blocks for GC processing time
bc.gcproc += proctime
case SideStatTy: case SideStatTy:
log.Debug("Inserted forked block", "number", block.Number(), "hash", block.Hash(), "diff", block.Difficulty(), "elapsed", log.Debug("Inserted forked block", "number", block.Number(), "hash", block.Hash(), "diff", block.Difficulty(), "elapsed",
common.PrettyDuration(time.Since(bstart)), "txs", len(block.Transactions()), "gas", block.GasUsed(), "uncles", len(block.Uncles())) common.PrettyDuration(time.Since(bstart)), "txs", len(block.Transactions()), "gas", block.GasUsed(), "uncles", len(block.Uncles()))
@ -986,7 +1178,7 @@ func (bc *BlockChain) insertChain(chain types.Blocks) (int, []interface{}, []*ty
} }
stats.processed++ stats.processed++
stats.usedGas += usedGas stats.usedGas += usedGas
stats.report(chain, i) stats.report(chain, i, bc.stateCache.TrieDB().Size())
} }
// Append a single chain head event if we've progressed the chain // Append a single chain head event if we've progressed the chain
if lastCanon != nil && bc.CurrentBlock().Hash() == lastCanon.Hash() { if lastCanon != nil && bc.CurrentBlock().Hash() == lastCanon.Hash() {
@ -1009,7 +1201,7 @@ const statsReportLimit = 8 * time.Second
// report prints statistics if some number of blocks have been processed // report prints statistics if some number of blocks have been processed
// or more than a few seconds have passed since the last message. // or more than a few seconds have passed since the last message.
func (st *insertStats) report(chain []*types.Block, index int) { func (st *insertStats) report(chain []*types.Block, index int, cache common.StorageSize) {
// Fetch the timings for the batch // Fetch the timings for the batch
var ( var (
now = mclock.Now() now = mclock.Now()
@ -1024,7 +1216,7 @@ func (st *insertStats) report(chain []*types.Block, index int) {
context := []interface{}{ context := []interface{}{
"blocks", st.processed, "txs", txs, "mgas", float64(st.usedGas) / 1000000, "blocks", st.processed, "txs", txs, "mgas", float64(st.usedGas) / 1000000,
"elapsed", common.PrettyDuration(elapsed), "mgasps", float64(st.usedGas) * 1000 / float64(elapsed), "elapsed", common.PrettyDuration(elapsed), "mgasps", float64(st.usedGas) * 1000 / float64(elapsed),
"number", end.Number(), "hash", end.Hash(), "number", end.Number(), "hash", end.Hash(), "cache", cache,
} }
if st.queued > 0 { if st.queued > 0 {
context = append(context, []interface{}{"queued", st.queued}...) context = append(context, []interface{}{"queued", st.queued}...)
@ -1060,7 +1252,7 @@ func (bc *BlockChain) reorg(oldBlock, newBlock *types.Block) error {
// These logs are later announced as deleted. // These logs are later announced as deleted.
collectLogs = func(h common.Hash) { collectLogs = func(h common.Hash) {
// Coalesce logs and set 'Removed'. // Coalesce logs and set 'Removed'.
receipts := GetBlockReceipts(bc.chainDb, h, bc.hc.GetBlockNumber(h)) receipts := GetBlockReceipts(bc.db, h, bc.hc.GetBlockNumber(h))
for _, receipt := range receipts { for _, receipt := range receipts {
for _, log := range receipt.Logs { for _, log := range receipt.Logs {
del := *log del := *log
@ -1129,7 +1321,7 @@ func (bc *BlockChain) reorg(oldBlock, newBlock *types.Block) error {
// insert the block in the canonical way, re-writing history // insert the block in the canonical way, re-writing history
bc.insert(newChain[i]) bc.insert(newChain[i])
// write lookup entries for hash based transaction/receipt searches // write lookup entries for hash based transaction/receipt searches
if err := WriteTxLookupEntries(bc.chainDb, newChain[i]); err != nil { if err := WriteTxLookupEntries(bc.db, newChain[i]); err != nil {
return err return err
} }
addedTxs = append(addedTxs, newChain[i].Transactions()...) addedTxs = append(addedTxs, newChain[i].Transactions()...)
@ -1139,7 +1331,7 @@ func (bc *BlockChain) reorg(oldBlock, newBlock *types.Block) error {
// When transactions get deleted from the database that means the // When transactions get deleted from the database that means the
// receipts that were created in the fork must also be deleted // receipts that were created in the fork must also be deleted
for _, tx := range diff { for _, tx := range diff {
DeleteTxLookupEntry(bc.chainDb, tx.Hash()) DeleteTxLookupEntry(bc.db, tx.Hash())
} }
if len(deletedLogs) > 0 { if len(deletedLogs) > 0 {
go bc.rmLogsFeed.Send(RemovedLogsEvent{deletedLogs}) go bc.rmLogsFeed.Send(RemovedLogsEvent{deletedLogs})
@ -1231,7 +1423,7 @@ Hash: 0x%x
Error: %v Error: %v
############################## ##############################
`, bc.config, block.Number(), block.Hash(), receiptString, err)) `, bc.chainConfig, block.Number(), block.Hash(), receiptString, err))
} }
// InsertHeaderChain attempts to insert the given header chain in to the local // InsertHeaderChain attempts to insert the given header chain in to the local
@ -1338,7 +1530,7 @@ func (bc *BlockChain) GetHeaderByNumber(number uint64) *types.Header {
} }
// Config retrieves the blockchain's chain configuration. // Config retrieves the blockchain's chain configuration.
func (bc *BlockChain) Config() *params.ChainConfig { return bc.config } func (bc *BlockChain) Config() *params.ChainConfig { return bc.chainConfig }
// Engine retrieves the blockchain's consensus engine. // Engine retrieves the blockchain's consensus engine.
func (bc *BlockChain) Engine() consensus.Engine { return bc.engine } func (bc *BlockChain) Engine() consensus.Engine { return bc.engine }

@ -46,7 +46,7 @@ func newTestBlockChain(fake bool) *BlockChain {
if !fake { if !fake {
engine = ethash.NewTester() engine = ethash.NewTester()
} }
blockchain, err := NewBlockChain(db, gspec.Config, engine, vm.Config{}) blockchain, err := NewBlockChain(db, nil, gspec.Config, engine, vm.Config{})
if err != nil { if err != nil {
panic(err) panic(err)
} }
@ -148,9 +148,9 @@ func testBlockChainImport(chain types.Blocks, blockchain *BlockChain) error {
return err return err
} }
blockchain.mu.Lock() blockchain.mu.Lock()
WriteTd(blockchain.chainDb, block.Hash(), block.NumberU64(), new(big.Int).Add(block.Difficulty(), blockchain.GetTdByHash(block.ParentHash()))) WriteTd(blockchain.db, block.Hash(), block.NumberU64(), new(big.Int).Add(block.Difficulty(), blockchain.GetTdByHash(block.ParentHash())))
WriteBlock(blockchain.chainDb, block) WriteBlock(blockchain.db, block)
statedb.CommitTo(blockchain.chainDb, false) statedb.Commit(false)
blockchain.mu.Unlock() blockchain.mu.Unlock()
} }
return nil return nil
@ -166,8 +166,8 @@ func testHeaderChainImport(chain []*types.Header, blockchain *BlockChain) error
} }
// Manually insert the header into the database, but don't reorganise (allows subsequent testing) // Manually insert the header into the database, but don't reorganise (allows subsequent testing)
blockchain.mu.Lock() blockchain.mu.Lock()
WriteTd(blockchain.chainDb, header.Hash(), header.Number.Uint64(), new(big.Int).Add(header.Difficulty, blockchain.GetTdByHash(header.ParentHash))) WriteTd(blockchain.db, header.Hash(), header.Number.Uint64(), new(big.Int).Add(header.Difficulty, blockchain.GetTdByHash(header.ParentHash)))
WriteHeader(blockchain.chainDb, header) WriteHeader(blockchain.db, header)
blockchain.mu.Unlock() blockchain.mu.Unlock()
} }
return nil return nil
@ -186,9 +186,9 @@ func TestLastBlock(t *testing.T) {
bchain := newTestBlockChain(false) bchain := newTestBlockChain(false)
defer bchain.Stop() defer bchain.Stop()
block := makeBlockChain(bchain.CurrentBlock(), 1, ethash.NewFaker(), bchain.chainDb, 0)[0] block := makeBlockChain(bchain.CurrentBlock(), 1, ethash.NewFaker(), bchain.db, 0)[0]
bchain.insert(block) bchain.insert(block)
if block.Hash() != GetHeadBlockHash(bchain.chainDb) { if block.Hash() != GetHeadBlockHash(bchain.db) {
t.Errorf("Write/Get HeadBlockHash failed") t.Errorf("Write/Get HeadBlockHash failed")
} }
} }
@ -496,7 +496,7 @@ func testReorgBadHashes(t *testing.T, full bool) {
} }
// Create a new BlockChain and check that it rolled back the state. // Create a new BlockChain and check that it rolled back the state.
ncm, err := NewBlockChain(bc.chainDb, bc.config, ethash.NewFaker(), vm.Config{}) ncm, err := NewBlockChain(bc.db, nil, bc.chainConfig, ethash.NewFaker(), vm.Config{})
if err != nil { if err != nil {
t.Fatalf("failed to create new chain manager: %v", err) t.Fatalf("failed to create new chain manager: %v", err)
} }
@ -609,7 +609,7 @@ func TestFastVsFullChains(t *testing.T) {
// Import the chain as an archive node for the comparison baseline // Import the chain as an archive node for the comparison baseline
archiveDb, _ := ethdb.NewMemDatabase() archiveDb, _ := ethdb.NewMemDatabase()
gspec.MustCommit(archiveDb) gspec.MustCommit(archiveDb)
archive, _ := NewBlockChain(archiveDb, gspec.Config, ethash.NewFaker(), vm.Config{}) archive, _ := NewBlockChain(archiveDb, nil, gspec.Config, ethash.NewFaker(), vm.Config{})
defer archive.Stop() defer archive.Stop()
if n, err := archive.InsertChain(blocks); err != nil { if n, err := archive.InsertChain(blocks); err != nil {
@ -618,7 +618,7 @@ func TestFastVsFullChains(t *testing.T) {
// Fast import the chain as a non-archive node to test // Fast import the chain as a non-archive node to test
fastDb, _ := ethdb.NewMemDatabase() fastDb, _ := ethdb.NewMemDatabase()
gspec.MustCommit(fastDb) gspec.MustCommit(fastDb)
fast, _ := NewBlockChain(fastDb, gspec.Config, ethash.NewFaker(), vm.Config{}) fast, _ := NewBlockChain(fastDb, nil, gspec.Config, ethash.NewFaker(), vm.Config{})
defer fast.Stop() defer fast.Stop()
headers := make([]*types.Header, len(blocks)) headers := make([]*types.Header, len(blocks))
@ -696,7 +696,7 @@ func TestLightVsFastVsFullChainHeads(t *testing.T) {
archiveDb, _ := ethdb.NewMemDatabase() archiveDb, _ := ethdb.NewMemDatabase()
gspec.MustCommit(archiveDb) gspec.MustCommit(archiveDb)
archive, _ := NewBlockChain(archiveDb, gspec.Config, ethash.NewFaker(), vm.Config{}) archive, _ := NewBlockChain(archiveDb, nil, gspec.Config, ethash.NewFaker(), vm.Config{})
if n, err := archive.InsertChain(blocks); err != nil { if n, err := archive.InsertChain(blocks); err != nil {
t.Fatalf("failed to process block %d: %v", n, err) t.Fatalf("failed to process block %d: %v", n, err)
} }
@ -709,7 +709,7 @@ func TestLightVsFastVsFullChainHeads(t *testing.T) {
// Import the chain as a non-archive node and ensure all pointers are updated // Import the chain as a non-archive node and ensure all pointers are updated
fastDb, _ := ethdb.NewMemDatabase() fastDb, _ := ethdb.NewMemDatabase()
gspec.MustCommit(fastDb) gspec.MustCommit(fastDb)
fast, _ := NewBlockChain(fastDb, gspec.Config, ethash.NewFaker(), vm.Config{}) fast, _ := NewBlockChain(fastDb, nil, gspec.Config, ethash.NewFaker(), vm.Config{})
defer fast.Stop() defer fast.Stop()
headers := make([]*types.Header, len(blocks)) headers := make([]*types.Header, len(blocks))
@ -730,7 +730,7 @@ func TestLightVsFastVsFullChainHeads(t *testing.T) {
lightDb, _ := ethdb.NewMemDatabase() lightDb, _ := ethdb.NewMemDatabase()
gspec.MustCommit(lightDb) gspec.MustCommit(lightDb)
light, _ := NewBlockChain(lightDb, gspec.Config, ethash.NewFaker(), vm.Config{}) light, _ := NewBlockChain(lightDb, nil, gspec.Config, ethash.NewFaker(), vm.Config{})
if n, err := light.InsertHeaderChain(headers, 1); err != nil { if n, err := light.InsertHeaderChain(headers, 1); err != nil {
t.Fatalf("failed to insert header %d: %v", n, err) t.Fatalf("failed to insert header %d: %v", n, err)
} }
@ -799,7 +799,7 @@ func TestChainTxReorgs(t *testing.T) {
} }
}) })
// Import the chain. This runs all block validation rules. // Import the chain. This runs all block validation rules.
blockchain, _ := NewBlockChain(db, gspec.Config, ethash.NewFaker(), vm.Config{}) blockchain, _ := NewBlockChain(db, nil, gspec.Config, ethash.NewFaker(), vm.Config{})
if i, err := blockchain.InsertChain(chain); err != nil { if i, err := blockchain.InsertChain(chain); err != nil {
t.Fatalf("failed to insert original chain[%d]: %v", i, err) t.Fatalf("failed to insert original chain[%d]: %v", i, err)
} }
@ -870,7 +870,7 @@ func TestLogReorgs(t *testing.T) {
signer = types.NewEIP155Signer(gspec.Config.ChainId) signer = types.NewEIP155Signer(gspec.Config.ChainId)
) )
blockchain, _ := NewBlockChain(db, gspec.Config, ethash.NewFaker(), vm.Config{}) blockchain, _ := NewBlockChain(db, nil, gspec.Config, ethash.NewFaker(), vm.Config{})
defer blockchain.Stop() defer blockchain.Stop()
rmLogsCh := make(chan RemovedLogsEvent) rmLogsCh := make(chan RemovedLogsEvent)
@ -917,7 +917,7 @@ func TestReorgSideEvent(t *testing.T) {
signer = types.NewEIP155Signer(gspec.Config.ChainId) signer = types.NewEIP155Signer(gspec.Config.ChainId)
) )
blockchain, _ := NewBlockChain(db, gspec.Config, ethash.NewFaker(), vm.Config{}) blockchain, _ := NewBlockChain(db, nil, gspec.Config, ethash.NewFaker(), vm.Config{})
defer blockchain.Stop() defer blockchain.Stop()
chain, _ := GenerateChain(gspec.Config, genesis, ethash.NewFaker(), db, 3, func(i int, gen *BlockGen) {}) chain, _ := GenerateChain(gspec.Config, genesis, ethash.NewFaker(), db, 3, func(i int, gen *BlockGen) {})
@ -992,7 +992,7 @@ func TestCanonicalBlockRetrieval(t *testing.T) {
bc := newTestBlockChain(true) bc := newTestBlockChain(true)
defer bc.Stop() defer bc.Stop()
chain, _ := GenerateChain(bc.config, bc.genesisBlock, ethash.NewFaker(), bc.chainDb, 10, func(i int, gen *BlockGen) {}) chain, _ := GenerateChain(bc.chainConfig, bc.genesisBlock, ethash.NewFaker(), bc.db, 10, func(i int, gen *BlockGen) {})
var pend sync.WaitGroup var pend sync.WaitGroup
pend.Add(len(chain)) pend.Add(len(chain))
@ -1003,14 +1003,14 @@ func TestCanonicalBlockRetrieval(t *testing.T) {
// try to retrieve a block by its canonical hash and see if the block data can be retrieved. // try to retrieve a block by its canonical hash and see if the block data can be retrieved.
for { for {
ch := GetCanonicalHash(bc.chainDb, block.NumberU64()) ch := GetCanonicalHash(bc.db, block.NumberU64())
if ch == (common.Hash{}) { if ch == (common.Hash{}) {
continue // busy wait for canonical hash to be written continue // busy wait for canonical hash to be written
} }
if ch != block.Hash() { if ch != block.Hash() {
t.Fatalf("unknown canonical hash, want %s, got %s", block.Hash().Hex(), ch.Hex()) t.Fatalf("unknown canonical hash, want %s, got %s", block.Hash().Hex(), ch.Hex())
} }
fb := GetBlock(bc.chainDb, ch, block.NumberU64()) fb := GetBlock(bc.db, ch, block.NumberU64())
if fb == nil { if fb == nil {
t.Fatalf("unable to retrieve block %d for canonical hash: %s", block.NumberU64(), ch.Hex()) t.Fatalf("unable to retrieve block %d for canonical hash: %s", block.NumberU64(), ch.Hex())
} }
@ -1043,7 +1043,7 @@ func TestEIP155Transition(t *testing.T) {
genesis = gspec.MustCommit(db) genesis = gspec.MustCommit(db)
) )
blockchain, _ := NewBlockChain(db, gspec.Config, ethash.NewFaker(), vm.Config{}) blockchain, _ := NewBlockChain(db, nil, gspec.Config, ethash.NewFaker(), vm.Config{})
defer blockchain.Stop() defer blockchain.Stop()
blocks, _ := GenerateChain(gspec.Config, genesis, ethash.NewFaker(), db, 4, func(i int, block *BlockGen) { blocks, _ := GenerateChain(gspec.Config, genesis, ethash.NewFaker(), db, 4, func(i int, block *BlockGen) {
@ -1151,7 +1151,7 @@ func TestEIP161AccountRemoval(t *testing.T) {
} }
genesis = gspec.MustCommit(db) genesis = gspec.MustCommit(db)
) )
blockchain, _ := NewBlockChain(db, gspec.Config, ethash.NewFaker(), vm.Config{}) blockchain, _ := NewBlockChain(db, nil, gspec.Config, ethash.NewFaker(), vm.Config{})
defer blockchain.Stop() defer blockchain.Stop()
blocks, _ := GenerateChain(gspec.Config, genesis, ethash.NewFaker(), db, 3, func(i int, block *BlockGen) { blocks, _ := GenerateChain(gspec.Config, genesis, ethash.NewFaker(), db, 3, func(i int, block *BlockGen) {
@ -1226,7 +1226,7 @@ func TestBlockchainHeaderchainReorgConsistency(t *testing.T) {
diskdb, _ := ethdb.NewMemDatabase() diskdb, _ := ethdb.NewMemDatabase()
new(Genesis).MustCommit(diskdb) new(Genesis).MustCommit(diskdb)
chain, err := NewBlockChain(diskdb, params.TestChainConfig, engine, vm.Config{}) chain, err := NewBlockChain(diskdb, nil, params.TestChainConfig, engine, vm.Config{})
if err != nil { if err != nil {
t.Fatalf("failed to create tester chain: %v", err) t.Fatalf("failed to create tester chain: %v", err)
} }
@ -1245,3 +1245,102 @@ func TestBlockchainHeaderchainReorgConsistency(t *testing.T) {
} }
} }
} }
// Tests that importing small side forks doesn't leave junk in the trie database
// cache (which would eventually cause memory issues).
func TestTrieForkGC(t *testing.T) {
// Generate a canonical chain to act as the main dataset
engine := ethash.NewFaker()
db, _ := ethdb.NewMemDatabase()
genesis := new(Genesis).MustCommit(db)
blocks, _ := GenerateChain(params.TestChainConfig, genesis, engine, db, 2*triesInMemory, func(i int, b *BlockGen) { b.SetCoinbase(common.Address{1}) })
// Generate a bunch of fork blocks, each side forking from the canonical chain
forks := make([]*types.Block, len(blocks))
for i := 0; i < len(forks); i++ {
parent := genesis
if i > 0 {
parent = blocks[i-1]
}
fork, _ := GenerateChain(params.TestChainConfig, parent, engine, db, 1, func(i int, b *BlockGen) { b.SetCoinbase(common.Address{2}) })
forks[i] = fork[0]
}
// Import the canonical and fork chain side by side, forcing the trie cache to cache both
diskdb, _ := ethdb.NewMemDatabase()
new(Genesis).MustCommit(diskdb)
chain, err := NewBlockChain(diskdb, nil, params.TestChainConfig, engine, vm.Config{})
if err != nil {
t.Fatalf("failed to create tester chain: %v", err)
}
for i := 0; i < len(blocks); i++ {
if _, err := chain.InsertChain(blocks[i : i+1]); err != nil {
t.Fatalf("block %d: failed to insert into chain: %v", i, err)
}
if _, err := chain.InsertChain(forks[i : i+1]); err != nil {
t.Fatalf("fork %d: failed to insert into chain: %v", i, err)
}
}
// Dereference all the recent tries and ensure no past trie is left in
for i := 0; i < triesInMemory; i++ {
chain.stateCache.TrieDB().Dereference(blocks[len(blocks)-1-i].Root(), common.Hash{})
chain.stateCache.TrieDB().Dereference(forks[len(blocks)-1-i].Root(), common.Hash{})
}
if len(chain.stateCache.TrieDB().Nodes()) > 0 {
t.Fatalf("stale tries still alive after garbase collection")
}
}
// Tests that doing large reorgs works even if the state associated with the
// forking point is not available any more.
func TestLargeReorgTrieGC(t *testing.T) {
// Generate the original common chain segment and the two competing forks
engine := ethash.NewFaker()
db, _ := ethdb.NewMemDatabase()
genesis := new(Genesis).MustCommit(db)
shared, _ := GenerateChain(params.TestChainConfig, genesis, engine, db, 64, func(i int, b *BlockGen) { b.SetCoinbase(common.Address{1}) })
original, _ := GenerateChain(params.TestChainConfig, shared[len(shared)-1], engine, db, 2*triesInMemory, func(i int, b *BlockGen) { b.SetCoinbase(common.Address{2}) })
competitor, _ := GenerateChain(params.TestChainConfig, shared[len(shared)-1], engine, db, 2*triesInMemory+1, func(i int, b *BlockGen) { b.SetCoinbase(common.Address{3}) })
// Import the shared chain and the original canonical one
diskdb, _ := ethdb.NewMemDatabase()
new(Genesis).MustCommit(diskdb)
chain, err := NewBlockChain(diskdb, nil, params.TestChainConfig, engine, vm.Config{})
if err != nil {
t.Fatalf("failed to create tester chain: %v", err)
}
if _, err := chain.InsertChain(shared); err != nil {
t.Fatalf("failed to insert shared chain: %v", err)
}
if _, err := chain.InsertChain(original); err != nil {
t.Fatalf("failed to insert shared chain: %v", err)
}
// Ensure that the state associated with the forking point is pruned away
if node, _ := chain.stateCache.TrieDB().Node(shared[len(shared)-1].Root()); node != nil {
t.Fatalf("common-but-old ancestor still cache")
}
// Import the competitor chain without exceeding the canonical's TD and ensure
// we have not processed any of the blocks (protection against malicious blocks)
if _, err := chain.InsertChain(competitor[:len(competitor)-2]); err != nil {
t.Fatalf("failed to insert competitor chain: %v", err)
}
for i, block := range competitor[:len(competitor)-2] {
if node, _ := chain.stateCache.TrieDB().Node(block.Root()); node != nil {
t.Fatalf("competitor %d: low TD chain became processed", i)
}
}
// Import the head of the competitor chain, triggering the reorg and ensure we
// successfully reprocess all the stashed away blocks.
if _, err := chain.InsertChain(competitor[len(competitor)-2:]); err != nil {
t.Fatalf("failed to finalize competitor chain: %v", err)
}
for i, block := range competitor[:len(competitor)-triesInMemory] {
if node, _ := chain.stateCache.TrieDB().Node(block.Root()); node != nil {
t.Fatalf("competitor %d: competing chain state missing", i)
}
}
}

@ -203,6 +203,9 @@ func (c *ChainIndexer) eventLoop(currentHeader *types.Header, events chan ChainE
if header.ParentHash != prevHash { if header.ParentHash != prevHash {
// Reorg to the common ancestor (might not exist in light sync mode, skip reorg then) // Reorg to the common ancestor (might not exist in light sync mode, skip reorg then)
// TODO(karalabe, zsfelfoldi): This seems a bit brittle, can we detect this case explicitly? // TODO(karalabe, zsfelfoldi): This seems a bit brittle, can we detect this case explicitly?
// TODO(karalabe): This operation is expensive and might block, causing the event system to
// potentially also lock up. We need to do with on a different thread somehow.
if h := FindCommonAncestor(c.chainDb, prevHeader, header); h != nil { if h := FindCommonAncestor(c.chainDb, prevHeader, header); h != nil {
c.newHead(h.Number.Uint64(), true) c.newHead(h.Number.Uint64(), true)
} }

@ -166,7 +166,7 @@ func GenerateChain(config *params.ChainConfig, parent *types.Block, engine conse
genblock := func(i int, parent *types.Block, statedb *state.StateDB) (*types.Block, types.Receipts) { genblock := func(i int, parent *types.Block, statedb *state.StateDB) (*types.Block, types.Receipts) {
// TODO(karalabe): This is needed for clique, which depends on multiple blocks. // TODO(karalabe): This is needed for clique, which depends on multiple blocks.
// It's nonetheless ugly to spin up a blockchain here. Get rid of this somehow. // It's nonetheless ugly to spin up a blockchain here. Get rid of this somehow.
blockchain, _ := NewBlockChain(db, config, engine, vm.Config{}) blockchain, _ := NewBlockChain(db, nil, config, engine, vm.Config{})
defer blockchain.Stop() defer blockchain.Stop()
b := &BlockGen{i: i, parent: parent, chain: blocks, chainReader: blockchain, statedb: statedb, config: config, engine: engine} b := &BlockGen{i: i, parent: parent, chain: blocks, chainReader: blockchain, statedb: statedb, config: config, engine: engine}
@ -192,10 +192,13 @@ func GenerateChain(config *params.ChainConfig, parent *types.Block, engine conse
if b.engine != nil { if b.engine != nil {
block, _ := b.engine.Finalize(b.chainReader, b.header, statedb, b.txs, b.uncles, b.receipts) block, _ := b.engine.Finalize(b.chainReader, b.header, statedb, b.txs, b.uncles, b.receipts)
// Write state changes to db // Write state changes to db
_, err := statedb.CommitTo(db, config.IsEIP158(b.header.Number)) root, err := statedb.Commit(config.IsEIP158(b.header.Number))
if err != nil { if err != nil {
panic(fmt.Sprintf("state write error: %v", err)) panic(fmt.Sprintf("state write error: %v", err))
} }
if err := statedb.Database().TrieDB().Commit(root, false); err != nil {
panic(fmt.Sprintf("trie write error: %v", err))
}
return block, b.receipts return block, b.receipts
} }
return nil, nil return nil, nil
@ -246,7 +249,7 @@ func newCanonical(engine consensus.Engine, n int, full bool) (ethdb.Database, *B
db, _ := ethdb.NewMemDatabase() db, _ := ethdb.NewMemDatabase()
genesis := gspec.MustCommit(db) genesis := gspec.MustCommit(db)
blockchain, _ := NewBlockChain(db, params.AllEthashProtocolChanges, engine, vm.Config{}) blockchain, _ := NewBlockChain(db, nil, params.AllEthashProtocolChanges, engine, vm.Config{})
// Create and inject the requested chain // Create and inject the requested chain
if n == 0 { if n == 0 {
return db, blockchain, nil return db, blockchain, nil

@ -79,7 +79,7 @@ func ExampleGenerateChain() {
}) })
// Import the chain. This runs all block validation rules. // Import the chain. This runs all block validation rules.
blockchain, _ := NewBlockChain(db, gspec.Config, ethash.NewFaker(), vm.Config{}) blockchain, _ := NewBlockChain(db, nil, gspec.Config, ethash.NewFaker(), vm.Config{})
defer blockchain.Stop() defer blockchain.Stop()
if i, err := blockchain.InsertChain(chain); err != nil { if i, err := blockchain.InsertChain(chain); err != nil {

@ -45,7 +45,7 @@ func TestDAOForkRangeExtradata(t *testing.T) {
proConf.DAOForkBlock = forkBlock proConf.DAOForkBlock = forkBlock
proConf.DAOForkSupport = true proConf.DAOForkSupport = true
proBc, _ := NewBlockChain(proDb, &proConf, ethash.NewFaker(), vm.Config{}) proBc, _ := NewBlockChain(proDb, nil, &proConf, ethash.NewFaker(), vm.Config{})
defer proBc.Stop() defer proBc.Stop()
conDb, _ := ethdb.NewMemDatabase() conDb, _ := ethdb.NewMemDatabase()
@ -55,7 +55,7 @@ func TestDAOForkRangeExtradata(t *testing.T) {
conConf.DAOForkBlock = forkBlock conConf.DAOForkBlock = forkBlock
conConf.DAOForkSupport = false conConf.DAOForkSupport = false
conBc, _ := NewBlockChain(conDb, &conConf, ethash.NewFaker(), vm.Config{}) conBc, _ := NewBlockChain(conDb, nil, &conConf, ethash.NewFaker(), vm.Config{})
defer conBc.Stop() defer conBc.Stop()
if _, err := proBc.InsertChain(prefix); err != nil { if _, err := proBc.InsertChain(prefix); err != nil {
@ -69,7 +69,7 @@ func TestDAOForkRangeExtradata(t *testing.T) {
// Create a pro-fork block, and try to feed into the no-fork chain // Create a pro-fork block, and try to feed into the no-fork chain
db, _ = ethdb.NewMemDatabase() db, _ = ethdb.NewMemDatabase()
gspec.MustCommit(db) gspec.MustCommit(db)
bc, _ := NewBlockChain(db, &conConf, ethash.NewFaker(), vm.Config{}) bc, _ := NewBlockChain(db, nil, &conConf, ethash.NewFaker(), vm.Config{})
defer bc.Stop() defer bc.Stop()
blocks := conBc.GetBlocksFromHash(conBc.CurrentBlock().Hash(), int(conBc.CurrentBlock().NumberU64())) blocks := conBc.GetBlocksFromHash(conBc.CurrentBlock().Hash(), int(conBc.CurrentBlock().NumberU64()))
@ -79,6 +79,9 @@ func TestDAOForkRangeExtradata(t *testing.T) {
if _, err := bc.InsertChain(blocks); err != nil { if _, err := bc.InsertChain(blocks); err != nil {
t.Fatalf("failed to import contra-fork chain for expansion: %v", err) t.Fatalf("failed to import contra-fork chain for expansion: %v", err)
} }
if err := bc.stateCache.TrieDB().Commit(bc.CurrentHeader().Root, true); err != nil {
t.Fatalf("failed to commit contra-fork head for expansion: %v", err)
}
blocks, _ = GenerateChain(&proConf, conBc.CurrentBlock(), ethash.NewFaker(), db, 1, func(i int, gen *BlockGen) {}) blocks, _ = GenerateChain(&proConf, conBc.CurrentBlock(), ethash.NewFaker(), db, 1, func(i int, gen *BlockGen) {})
if _, err := conBc.InsertChain(blocks); err == nil { if _, err := conBc.InsertChain(blocks); err == nil {
t.Fatalf("contra-fork chain accepted pro-fork block: %v", blocks[0]) t.Fatalf("contra-fork chain accepted pro-fork block: %v", blocks[0])
@ -91,7 +94,7 @@ func TestDAOForkRangeExtradata(t *testing.T) {
// Create a no-fork block, and try to feed into the pro-fork chain // Create a no-fork block, and try to feed into the pro-fork chain
db, _ = ethdb.NewMemDatabase() db, _ = ethdb.NewMemDatabase()
gspec.MustCommit(db) gspec.MustCommit(db)
bc, _ = NewBlockChain(db, &proConf, ethash.NewFaker(), vm.Config{}) bc, _ = NewBlockChain(db, nil, &proConf, ethash.NewFaker(), vm.Config{})
defer bc.Stop() defer bc.Stop()
blocks = proBc.GetBlocksFromHash(proBc.CurrentBlock().Hash(), int(proBc.CurrentBlock().NumberU64())) blocks = proBc.GetBlocksFromHash(proBc.CurrentBlock().Hash(), int(proBc.CurrentBlock().NumberU64()))
@ -101,6 +104,9 @@ func TestDAOForkRangeExtradata(t *testing.T) {
if _, err := bc.InsertChain(blocks); err != nil { if _, err := bc.InsertChain(blocks); err != nil {
t.Fatalf("failed to import pro-fork chain for expansion: %v", err) t.Fatalf("failed to import pro-fork chain for expansion: %v", err)
} }
if err := bc.stateCache.TrieDB().Commit(bc.CurrentHeader().Root, true); err != nil {
t.Fatalf("failed to commit pro-fork head for expansion: %v", err)
}
blocks, _ = GenerateChain(&conConf, proBc.CurrentBlock(), ethash.NewFaker(), db, 1, func(i int, gen *BlockGen) {}) blocks, _ = GenerateChain(&conConf, proBc.CurrentBlock(), ethash.NewFaker(), db, 1, func(i int, gen *BlockGen) {})
if _, err := proBc.InsertChain(blocks); err == nil { if _, err := proBc.InsertChain(blocks); err == nil {
t.Fatalf("pro-fork chain accepted contra-fork block: %v", blocks[0]) t.Fatalf("pro-fork chain accepted contra-fork block: %v", blocks[0])
@ -114,7 +120,7 @@ func TestDAOForkRangeExtradata(t *testing.T) {
// Verify that contra-forkers accept pro-fork extra-datas after forking finishes // Verify that contra-forkers accept pro-fork extra-datas after forking finishes
db, _ = ethdb.NewMemDatabase() db, _ = ethdb.NewMemDatabase()
gspec.MustCommit(db) gspec.MustCommit(db)
bc, _ := NewBlockChain(db, &conConf, ethash.NewFaker(), vm.Config{}) bc, _ := NewBlockChain(db, nil, &conConf, ethash.NewFaker(), vm.Config{})
defer bc.Stop() defer bc.Stop()
blocks := conBc.GetBlocksFromHash(conBc.CurrentBlock().Hash(), int(conBc.CurrentBlock().NumberU64())) blocks := conBc.GetBlocksFromHash(conBc.CurrentBlock().Hash(), int(conBc.CurrentBlock().NumberU64()))
@ -124,6 +130,9 @@ func TestDAOForkRangeExtradata(t *testing.T) {
if _, err := bc.InsertChain(blocks); err != nil { if _, err := bc.InsertChain(blocks); err != nil {
t.Fatalf("failed to import contra-fork chain for expansion: %v", err) t.Fatalf("failed to import contra-fork chain for expansion: %v", err)
} }
if err := bc.stateCache.TrieDB().Commit(bc.CurrentHeader().Root, true); err != nil {
t.Fatalf("failed to commit contra-fork head for expansion: %v", err)
}
blocks, _ = GenerateChain(&proConf, conBc.CurrentBlock(), ethash.NewFaker(), db, 1, func(i int, gen *BlockGen) {}) blocks, _ = GenerateChain(&proConf, conBc.CurrentBlock(), ethash.NewFaker(), db, 1, func(i int, gen *BlockGen) {})
if _, err := conBc.InsertChain(blocks); err != nil { if _, err := conBc.InsertChain(blocks); err != nil {
t.Fatalf("contra-fork chain didn't accept pro-fork block post-fork: %v", err) t.Fatalf("contra-fork chain didn't accept pro-fork block post-fork: %v", err)
@ -131,7 +140,7 @@ func TestDAOForkRangeExtradata(t *testing.T) {
// Verify that pro-forkers accept contra-fork extra-datas after forking finishes // Verify that pro-forkers accept contra-fork extra-datas after forking finishes
db, _ = ethdb.NewMemDatabase() db, _ = ethdb.NewMemDatabase()
gspec.MustCommit(db) gspec.MustCommit(db)
bc, _ = NewBlockChain(db, &proConf, ethash.NewFaker(), vm.Config{}) bc, _ = NewBlockChain(db, nil, &proConf, ethash.NewFaker(), vm.Config{})
defer bc.Stop() defer bc.Stop()
blocks = proBc.GetBlocksFromHash(proBc.CurrentBlock().Hash(), int(proBc.CurrentBlock().NumberU64())) blocks = proBc.GetBlocksFromHash(proBc.CurrentBlock().Hash(), int(proBc.CurrentBlock().NumberU64()))
@ -141,6 +150,9 @@ func TestDAOForkRangeExtradata(t *testing.T) {
if _, err := bc.InsertChain(blocks); err != nil { if _, err := bc.InsertChain(blocks); err != nil {
t.Fatalf("failed to import pro-fork chain for expansion: %v", err) t.Fatalf("failed to import pro-fork chain for expansion: %v", err)
} }
if err := bc.stateCache.TrieDB().Commit(bc.CurrentHeader().Root, true); err != nil {
t.Fatalf("failed to commit pro-fork head for expansion: %v", err)
}
blocks, _ = GenerateChain(&conConf, proBc.CurrentBlock(), ethash.NewFaker(), db, 1, func(i int, gen *BlockGen) {}) blocks, _ = GenerateChain(&conConf, proBc.CurrentBlock(), ethash.NewFaker(), db, 1, func(i int, gen *BlockGen) {})
if _, err := proBc.InsertChain(blocks); err != nil { if _, err := proBc.InsertChain(blocks); err != nil {
t.Fatalf("pro-fork chain didn't accept contra-fork block post-fork: %v", err) t.Fatalf("pro-fork chain didn't accept contra-fork block post-fork: %v", err)

@ -169,10 +169,9 @@ func SetupGenesisBlock(db ethdb.Database, genesis *Genesis) (*params.ChainConfig
// Check whether the genesis block is already written. // Check whether the genesis block is already written.
if genesis != nil { if genesis != nil {
block, _ := genesis.ToBlock() hash := genesis.ToBlock(nil).Hash()
hash := block.Hash()
if hash != stored { if hash != stored {
return genesis.Config, block.Hash(), &GenesisMismatchError{stored, hash} return genesis.Config, hash, &GenesisMismatchError{stored, hash}
} }
} }
@ -220,9 +219,12 @@ func (g *Genesis) configOrDefault(ghash common.Hash) *params.ChainConfig {
} }
} }
// ToBlock creates the block and state of a genesis specification. // ToBlock creates the genesis block and writes state of a genesis specification
func (g *Genesis) ToBlock() (*types.Block, *state.StateDB) { // to the given database (or discards it if nil).
db, _ := ethdb.NewMemDatabase() func (g *Genesis) ToBlock(db ethdb.Database) *types.Block {
if db == nil {
db, _ = ethdb.NewMemDatabase()
}
statedb, _ := state.New(common.Hash{}, state.NewDatabase(db)) statedb, _ := state.New(common.Hash{}, state.NewDatabase(db))
for addr, account := range g.Alloc { for addr, account := range g.Alloc {
statedb.AddBalance(addr, account.Balance) statedb.AddBalance(addr, account.Balance)
@ -252,19 +254,19 @@ func (g *Genesis) ToBlock() (*types.Block, *state.StateDB) {
if g.Difficulty == nil { if g.Difficulty == nil {
head.Difficulty = params.GenesisDifficulty head.Difficulty = params.GenesisDifficulty
} }
return types.NewBlock(head, nil, nil, nil), statedb statedb.Commit(false)
statedb.Database().TrieDB().Commit(root, true)
return types.NewBlock(head, nil, nil, nil)
} }
// Commit writes the block and state of a genesis specification to the database. // Commit writes the block and state of a genesis specification to the database.
// The block is committed as the canonical head block. // The block is committed as the canonical head block.
func (g *Genesis) Commit(db ethdb.Database) (*types.Block, error) { func (g *Genesis) Commit(db ethdb.Database) (*types.Block, error) {
block, statedb := g.ToBlock() block := g.ToBlock(db)
if block.Number().Sign() != 0 { if block.Number().Sign() != 0 {
return nil, fmt.Errorf("can't commit genesis block with number > 0") return nil, fmt.Errorf("can't commit genesis block with number > 0")
} }
if _, err := statedb.CommitTo(db, false); err != nil {
return nil, fmt.Errorf("cannot write state: %v", err)
}
if err := WriteTd(db, block.Hash(), block.NumberU64(), g.Difficulty); err != nil { if err := WriteTd(db, block.Hash(), block.NumberU64(), g.Difficulty); err != nil {
return nil, err return nil, err
} }

@ -30,11 +30,11 @@ import (
) )
func TestDefaultGenesisBlock(t *testing.T) { func TestDefaultGenesisBlock(t *testing.T) {
block, _ := DefaultGenesisBlock().ToBlock() block := DefaultGenesisBlock().ToBlock(nil)
if block.Hash() != params.MainnetGenesisHash { if block.Hash() != params.MainnetGenesisHash {
t.Errorf("wrong mainnet genesis hash, got %v, want %v", block.Hash(), params.MainnetGenesisHash) t.Errorf("wrong mainnet genesis hash, got %v, want %v", block.Hash(), params.MainnetGenesisHash)
} }
block, _ = DefaultTestnetGenesisBlock().ToBlock() block = DefaultTestnetGenesisBlock().ToBlock(nil)
if block.Hash() != params.TestnetGenesisHash { if block.Hash() != params.TestnetGenesisHash {
t.Errorf("wrong testnet genesis hash, got %v, want %v", block.Hash(), params.TestnetGenesisHash) t.Errorf("wrong testnet genesis hash, got %v, want %v", block.Hash(), params.TestnetGenesisHash)
} }
@ -118,7 +118,7 @@ func TestSetupGenesis(t *testing.T) {
// Commit the 'old' genesis block with Homestead transition at #2. // Commit the 'old' genesis block with Homestead transition at #2.
// Advance to block #4, past the homestead transition block of customg. // Advance to block #4, past the homestead transition block of customg.
genesis := oldcustomg.MustCommit(db) genesis := oldcustomg.MustCommit(db)
bc, _ := NewBlockChain(db, oldcustomg.Config, ethash.NewFullFaker(), vm.Config{}) bc, _ := NewBlockChain(db, nil, oldcustomg.Config, ethash.NewFullFaker(), vm.Config{})
defer bc.Stop() defer bc.Stop()
bc.SetValidator(bproc{}) bc.SetValidator(bproc{})
bc.InsertChain(makeBlockChainWithDiff(genesis, []int{2, 3, 4, 5}, 0)) bc.InsertChain(makeBlockChainWithDiff(genesis, []int{2, 3, 4, 5}, 0))

@ -40,16 +40,23 @@ const (
// Database wraps access to tries and contract code. // Database wraps access to tries and contract code.
type Database interface { type Database interface {
// Accessing tries:
// OpenTrie opens the main account trie. // OpenTrie opens the main account trie.
// OpenStorageTrie opens the storage trie of an account.
OpenTrie(root common.Hash) (Trie, error) OpenTrie(root common.Hash) (Trie, error)
// OpenStorageTrie opens the storage trie of an account.
OpenStorageTrie(addrHash, root common.Hash) (Trie, error) OpenStorageTrie(addrHash, root common.Hash) (Trie, error)
// Accessing contract code:
ContractCode(addrHash, codeHash common.Hash) ([]byte, error)
ContractCodeSize(addrHash, codeHash common.Hash) (int, error)
// CopyTrie returns an independent copy of the given trie. // CopyTrie returns an independent copy of the given trie.
CopyTrie(Trie) Trie CopyTrie(Trie) Trie
// ContractCode retrieves a particular contract's code.
ContractCode(addrHash, codeHash common.Hash) ([]byte, error)
// ContractCodeSize retrieves a particular contracts code's size.
ContractCodeSize(addrHash, codeHash common.Hash) (int, error)
// TrieDB retrieves the low level trie database used for data storage.
TrieDB() *trie.Database
} }
// Trie is a Ethereum Merkle Trie. // Trie is a Ethereum Merkle Trie.
@ -57,26 +64,33 @@ type Trie interface {
TryGet(key []byte) ([]byte, error) TryGet(key []byte) ([]byte, error)
TryUpdate(key, value []byte) error TryUpdate(key, value []byte) error
TryDelete(key []byte) error TryDelete(key []byte) error
CommitTo(trie.DatabaseWriter) (common.Hash, error) Commit(onleaf trie.LeafCallback) (common.Hash, error)
Hash() common.Hash Hash() common.Hash
NodeIterator(startKey []byte) trie.NodeIterator NodeIterator(startKey []byte) trie.NodeIterator
GetKey([]byte) []byte // TODO(fjl): remove this when SecureTrie is removed GetKey([]byte) []byte // TODO(fjl): remove this when SecureTrie is removed
Prove(key []byte, fromLevel uint, proofDb ethdb.Putter) error
} }
// NewDatabase creates a backing store for state. The returned database is safe for // NewDatabase creates a backing store for state. The returned database is safe for
// concurrent use and retains cached trie nodes in memory. // concurrent use and retains cached trie nodes in memory. The pool is an optional
// intermediate trie-node memory pool between the low level storage layer and the
// high level trie abstraction.
func NewDatabase(db ethdb.Database) Database { func NewDatabase(db ethdb.Database) Database {
csc, _ := lru.New(codeSizeCacheSize) csc, _ := lru.New(codeSizeCacheSize)
return &cachingDB{db: db, codeSizeCache: csc} return &cachingDB{
db: trie.NewDatabase(db),
codeSizeCache: csc,
}
} }
type cachingDB struct { type cachingDB struct {
db ethdb.Database db *trie.Database
mu sync.Mutex mu sync.Mutex
pastTries []*trie.SecureTrie pastTries []*trie.SecureTrie
codeSizeCache *lru.Cache codeSizeCache *lru.Cache
} }
// OpenTrie opens the main account trie.
func (db *cachingDB) OpenTrie(root common.Hash) (Trie, error) { func (db *cachingDB) OpenTrie(root common.Hash) (Trie, error) {
db.mu.Lock() db.mu.Lock()
defer db.mu.Unlock() defer db.mu.Unlock()
@ -105,10 +119,12 @@ func (db *cachingDB) pushTrie(t *trie.SecureTrie) {
} }
} }
// OpenStorageTrie opens the storage trie of an account.
func (db *cachingDB) OpenStorageTrie(addrHash, root common.Hash) (Trie, error) { func (db *cachingDB) OpenStorageTrie(addrHash, root common.Hash) (Trie, error) {
return trie.NewSecure(root, db.db, 0) return trie.NewSecure(root, db.db, 0)
} }
// CopyTrie returns an independent copy of the given trie.
func (db *cachingDB) CopyTrie(t Trie) Trie { func (db *cachingDB) CopyTrie(t Trie) Trie {
switch t := t.(type) { switch t := t.(type) {
case cachedTrie: case cachedTrie:
@ -120,14 +136,16 @@ func (db *cachingDB) CopyTrie(t Trie) Trie {
} }
} }
// ContractCode retrieves a particular contract's code.
func (db *cachingDB) ContractCode(addrHash, codeHash common.Hash) ([]byte, error) { func (db *cachingDB) ContractCode(addrHash, codeHash common.Hash) ([]byte, error) {
code, err := db.db.Get(codeHash[:]) code, err := db.db.Node(codeHash)
if err == nil { if err == nil {
db.codeSizeCache.Add(codeHash, len(code)) db.codeSizeCache.Add(codeHash, len(code))
} }
return code, err return code, err
} }
// ContractCodeSize retrieves a particular contracts code's size.
func (db *cachingDB) ContractCodeSize(addrHash, codeHash common.Hash) (int, error) { func (db *cachingDB) ContractCodeSize(addrHash, codeHash common.Hash) (int, error) {
if cached, ok := db.codeSizeCache.Get(codeHash); ok { if cached, ok := db.codeSizeCache.Get(codeHash); ok {
return cached.(int), nil return cached.(int), nil
@ -139,16 +157,25 @@ func (db *cachingDB) ContractCodeSize(addrHash, codeHash common.Hash) (int, erro
return len(code), err return len(code), err
} }
// TrieDB retrieves any intermediate trie-node caching layer.
func (db *cachingDB) TrieDB() *trie.Database {
return db.db
}
// cachedTrie inserts its trie into a cachingDB on commit. // cachedTrie inserts its trie into a cachingDB on commit.
type cachedTrie struct { type cachedTrie struct {
*trie.SecureTrie *trie.SecureTrie
db *cachingDB db *cachingDB
} }
func (m cachedTrie) CommitTo(dbw trie.DatabaseWriter) (common.Hash, error) { func (m cachedTrie) Commit(onleaf trie.LeafCallback) (common.Hash, error) {
root, err := m.SecureTrie.CommitTo(dbw) root, err := m.SecureTrie.Commit(onleaf)
if err == nil { if err == nil {
m.db.pushTrie(m.SecureTrie) m.db.pushTrie(m.SecureTrie)
} }
return root, err return root, err
} }
func (m cachedTrie) Prove(key []byte, fromLevel uint, proofDb ethdb.Putter) error {
return m.SecureTrie.Prove(key, fromLevel, proofDb)
}

@ -21,12 +21,13 @@ import (
"testing" "testing"
"github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/ethdb"
) )
// Tests that the node iterator indeed walks over the entire database contents. // Tests that the node iterator indeed walks over the entire database contents.
func TestNodeIteratorCoverage(t *testing.T) { func TestNodeIteratorCoverage(t *testing.T) {
// Create some arbitrary test state to iterate // Create some arbitrary test state to iterate
db, mem, root, _ := makeTestState() db, root, _ := makeTestState()
state, err := New(root, db) state, err := New(root, db)
if err != nil { if err != nil {
@ -39,14 +40,18 @@ func TestNodeIteratorCoverage(t *testing.T) {
hashes[it.Hash] = struct{}{} hashes[it.Hash] = struct{}{}
} }
} }
// Cross check the iterated hashes and the database/nodepool content
// Cross check the hashes and the database itself
for hash := range hashes { for hash := range hashes {
if _, err := mem.Get(hash.Bytes()); err != nil { if _, err := db.TrieDB().Node(hash); err != nil {
t.Errorf("failed to retrieve reported node %x: %v", hash, err) t.Errorf("failed to retrieve reported node %x", hash)
}
}
for _, hash := range db.TrieDB().Nodes() {
if _, ok := hashes[hash]; !ok {
t.Errorf("state entry not reported %x", hash)
} }
} }
for _, key := range mem.Keys() { for _, key := range db.TrieDB().DiskDB().(*ethdb.MemDatabase).Keys() {
if bytes.HasPrefix(key, []byte("secure-key-")) { if bytes.HasPrefix(key, []byte("secure-key-")) {
continue continue
} }

@ -25,7 +25,6 @@ import (
"github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/crypto" "github.com/ethereum/go-ethereum/crypto"
"github.com/ethereum/go-ethereum/rlp" "github.com/ethereum/go-ethereum/rlp"
"github.com/ethereum/go-ethereum/trie"
) )
var emptyCodeHash = crypto.Keccak256(nil) var emptyCodeHash = crypto.Keccak256(nil)
@ -238,12 +237,12 @@ func (self *stateObject) updateRoot(db Database) {
// CommitTrie the storage trie of the object to dwb. // CommitTrie the storage trie of the object to dwb.
// This updates the trie root. // This updates the trie root.
func (self *stateObject) CommitTrie(db Database, dbw trie.DatabaseWriter) error { func (self *stateObject) CommitTrie(db Database) error {
self.updateTrie(db) self.updateTrie(db)
if self.dbErr != nil { if self.dbErr != nil {
return self.dbErr return self.dbErr
} }
root, err := self.trie.CommitTo(dbw) root, err := self.trie.Commit(nil)
if err == nil { if err == nil {
self.data.Root = root self.data.Root = root
} }

@ -48,7 +48,7 @@ func (s *StateSuite) TestDump(c *checker.C) {
// write some of them to the trie // write some of them to the trie
s.state.updateStateObject(obj1) s.state.updateStateObject(obj1)
s.state.updateStateObject(obj2) s.state.updateStateObject(obj2)
s.state.CommitTo(s.db, false) s.state.Commit(false)
// check that dump contains the state objects that are in trie // check that dump contains the state objects that are in trie
got := string(s.state.Dump()) got := string(s.state.Dump())
@ -97,7 +97,7 @@ func (s *StateSuite) TestNull(c *checker.C) {
//value := common.FromHex("0x823140710bf13990e4500136726d8b55") //value := common.FromHex("0x823140710bf13990e4500136726d8b55")
var value common.Hash var value common.Hash
s.state.SetState(address, common.Hash{}, value) s.state.SetState(address, common.Hash{}, value)
s.state.CommitTo(s.db, false) s.state.Commit(false)
value = s.state.GetState(address, common.Hash{}) value = s.state.GetState(address, common.Hash{})
if !common.EmptyHash(value) { if !common.EmptyHash(value) {
c.Errorf("expected empty hash. got %x", value) c.Errorf("expected empty hash. got %x", value)
@ -155,7 +155,7 @@ func TestSnapshot2(t *testing.T) {
so0.deleted = false so0.deleted = false
state.setStateObject(so0) state.setStateObject(so0)
root, _ := state.CommitTo(db, false) root, _ := state.Commit(false)
state.Reset(root) state.Reset(root)
// and one with deleted == true // and one with deleted == true

@ -36,6 +36,14 @@ type revision struct {
journalIndex int journalIndex int
} }
var (
// emptyState is the known hash of an empty state trie entry.
emptyState = crypto.Keccak256Hash(nil)
// emptyCode is the known hash of the empty EVM bytecode.
emptyCode = crypto.Keccak256Hash(nil)
)
// StateDBs within the ethereum protocol are used to store anything // StateDBs within the ethereum protocol are used to store anything
// within the merkle trie. StateDBs take care of caching and storing // within the merkle trie. StateDBs take care of caching and storing
// nested states. It's the general query interface to retrieve: // nested states. It's the general query interface to retrieve:
@ -235,6 +243,11 @@ func (self *StateDB) GetState(a common.Address, b common.Hash) common.Hash {
return common.Hash{} return common.Hash{}
} }
// Database retrieves the low level database supporting the lower level trie ops.
func (self *StateDB) Database() Database {
return self.db
}
// StorageTrie returns the storage trie of an account. // StorageTrie returns the storage trie of an account.
// The return value is a copy and is nil for non-existent accounts. // The return value is a copy and is nil for non-existent accounts.
func (self *StateDB) StorageTrie(a common.Address) Trie { func (self *StateDB) StorageTrie(a common.Address) Trie {
@ -568,8 +581,8 @@ func (s *StateDB) clearJournalAndRefund() {
s.refund = 0 s.refund = 0
} }
// CommitTo writes the state to the given database. // Commit writes the state to the underlying in-memory trie database.
func (s *StateDB) CommitTo(dbw trie.DatabaseWriter, deleteEmptyObjects bool) (root common.Hash, err error) { func (s *StateDB) Commit(deleteEmptyObjects bool) (root common.Hash, err error) {
defer s.clearJournalAndRefund() defer s.clearJournalAndRefund()
// Commit objects to the trie. // Commit objects to the trie.
@ -583,13 +596,11 @@ func (s *StateDB) CommitTo(dbw trie.DatabaseWriter, deleteEmptyObjects bool) (ro
case isDirty: case isDirty:
// Write any contract code associated with the state object // Write any contract code associated with the state object
if stateObject.code != nil && stateObject.dirtyCode { if stateObject.code != nil && stateObject.dirtyCode {
if err := dbw.Put(stateObject.CodeHash(), stateObject.code); err != nil { s.db.TrieDB().Insert(common.BytesToHash(stateObject.CodeHash()), stateObject.code)
return common.Hash{}, err
}
stateObject.dirtyCode = false stateObject.dirtyCode = false
} }
// Write any storage changes in the state object to its storage trie. // Write any storage changes in the state object to its storage trie.
if err := stateObject.CommitTrie(s.db, dbw); err != nil { if err := stateObject.CommitTrie(s.db); err != nil {
return common.Hash{}, err return common.Hash{}, err
} }
// Update the object in the main account trie. // Update the object in the main account trie.
@ -598,7 +609,20 @@ func (s *StateDB) CommitTo(dbw trie.DatabaseWriter, deleteEmptyObjects bool) (ro
delete(s.stateObjectsDirty, addr) delete(s.stateObjectsDirty, addr)
} }
// Write trie changes. // Write trie changes.
root, err = s.trie.CommitTo(dbw) root, err = s.trie.Commit(func(leaf []byte, parent common.Hash) error {
var account Account
if err := rlp.DecodeBytes(leaf, &account); err != nil {
return nil
}
if account.Root != emptyState {
s.db.TrieDB().Reference(account.Root, parent)
}
code := common.BytesToHash(account.CodeHash)
if code != emptyCode {
s.db.TrieDB().Reference(code, parent)
}
return nil
})
log.Debug("Trie cache stats after commit", "misses", trie.CacheMisses(), "unloads", trie.CacheUnloads()) log.Debug("Trie cache stats after commit", "misses", trie.CacheMisses(), "unloads", trie.CacheUnloads())
return root, err return root, err
} }

@ -97,10 +97,10 @@ func TestIntermediateLeaks(t *testing.T) {
} }
// Commit and cross check the databases. // Commit and cross check the databases.
if _, err := transState.CommitTo(transDb, false); err != nil { if _, err := transState.Commit(false); err != nil {
t.Fatalf("failed to commit transition state: %v", err) t.Fatalf("failed to commit transition state: %v", err)
} }
if _, err := finalState.CommitTo(finalDb, false); err != nil { if _, err := finalState.Commit(false); err != nil {
t.Fatalf("failed to commit final state: %v", err) t.Fatalf("failed to commit final state: %v", err)
} }
for _, key := range finalDb.Keys() { for _, key := range finalDb.Keys() {
@ -122,8 +122,8 @@ func TestIntermediateLeaks(t *testing.T) {
// https://github.com/ethereum/go-ethereum/pull/15549. // https://github.com/ethereum/go-ethereum/pull/15549.
func TestCopy(t *testing.T) { func TestCopy(t *testing.T) {
// Create a random state test to copy and modify "independently" // Create a random state test to copy and modify "independently"
mem, _ := ethdb.NewMemDatabase() db, _ := ethdb.NewMemDatabase()
orig, _ := New(common.Hash{}, NewDatabase(mem)) orig, _ := New(common.Hash{}, NewDatabase(db))
for i := byte(0); i < 255; i++ { for i := byte(0); i < 255; i++ {
obj := orig.GetOrNewStateObject(common.BytesToAddress([]byte{i})) obj := orig.GetOrNewStateObject(common.BytesToAddress([]byte{i}))
@ -346,11 +346,10 @@ func (test *snapshotTest) run() bool {
} }
action.fn(action, state) action.fn(action, state)
} }
// Revert all snapshots in reverse order. Each revert must yield a state // Revert all snapshots in reverse order. Each revert must yield a state
// that is equivalent to fresh state with all actions up the snapshot applied. // that is equivalent to fresh state with all actions up the snapshot applied.
for sindex--; sindex >= 0; sindex-- { for sindex--; sindex >= 0; sindex-- {
checkstate, _ := New(common.Hash{}, NewDatabase(db)) checkstate, _ := New(common.Hash{}, state.Database())
for _, action := range test.actions[:test.snapshots[sindex]] { for _, action := range test.actions[:test.snapshots[sindex]] {
action.fn(action, checkstate) action.fn(action, checkstate)
} }
@ -409,7 +408,7 @@ func (test *snapshotTest) checkEqual(state, checkstate *StateDB) error {
func (s *StateSuite) TestTouchDelete(c *check.C) { func (s *StateSuite) TestTouchDelete(c *check.C) {
s.state.GetOrNewStateObject(common.Address{}) s.state.GetOrNewStateObject(common.Address{})
root, _ := s.state.CommitTo(s.db, false) root, _ := s.state.Commit(false)
s.state.Reset(root) s.state.Reset(root)
snapshot := s.state.Snapshot() snapshot := s.state.Snapshot()
@ -417,7 +416,6 @@ func (s *StateSuite) TestTouchDelete(c *check.C) {
if len(s.state.stateObjectsDirty) != 1 { if len(s.state.stateObjectsDirty) != 1 {
c.Fatal("expected one dirty state object") c.Fatal("expected one dirty state object")
} }
s.state.RevertToSnapshot(snapshot) s.state.RevertToSnapshot(snapshot)
if len(s.state.stateObjectsDirty) != 0 { if len(s.state.stateObjectsDirty) != 0 {
c.Fatal("expected no dirty state object") c.Fatal("expected no dirty state object")

@ -36,10 +36,10 @@ type testAccount struct {
} }
// makeTestState create a sample test state to test node-wise reconstruction. // makeTestState create a sample test state to test node-wise reconstruction.
func makeTestState() (Database, *ethdb.MemDatabase, common.Hash, []*testAccount) { func makeTestState() (Database, common.Hash, []*testAccount) {
// Create an empty state // Create an empty state
mem, _ := ethdb.NewMemDatabase() diskdb, _ := ethdb.NewMemDatabase()
db := NewDatabase(mem) db := NewDatabase(diskdb)
state, _ := New(common.Hash{}, db) state, _ := New(common.Hash{}, db)
// Fill it with some arbitrary data // Fill it with some arbitrary data
@ -61,10 +61,10 @@ func makeTestState() (Database, *ethdb.MemDatabase, common.Hash, []*testAccount)
state.updateStateObject(obj) state.updateStateObject(obj)
accounts = append(accounts, acc) accounts = append(accounts, acc)
} }
root, _ := state.CommitTo(mem, false) root, _ := state.Commit(false)
// Return the generated state // Return the generated state
return db, mem, root, accounts return db, root, accounts
} }
// checkStateAccounts cross references a reconstructed state with an expected // checkStateAccounts cross references a reconstructed state with an expected
@ -96,7 +96,7 @@ func checkTrieConsistency(db ethdb.Database, root common.Hash) error {
if v, _ := db.Get(root[:]); v == nil { if v, _ := db.Get(root[:]); v == nil {
return nil // Consider a non existent state consistent. return nil // Consider a non existent state consistent.
} }
trie, err := trie.New(root, db) trie, err := trie.New(root, trie.NewDatabase(db))
if err != nil { if err != nil {
return err return err
} }
@ -138,7 +138,7 @@ func TestIterativeStateSyncBatched(t *testing.T) { testIterativeStateSync(t,
func testIterativeStateSync(t *testing.T, batch int) { func testIterativeStateSync(t *testing.T, batch int) {
// Create a random state to copy // Create a random state to copy
_, srcMem, srcRoot, srcAccounts := makeTestState() srcDb, srcRoot, srcAccounts := makeTestState()
// Create a destination state and sync with the scheduler // Create a destination state and sync with the scheduler
dstDb, _ := ethdb.NewMemDatabase() dstDb, _ := ethdb.NewMemDatabase()
@ -148,9 +148,9 @@ func testIterativeStateSync(t *testing.T, batch int) {
for len(queue) > 0 { for len(queue) > 0 {
results := make([]trie.SyncResult, len(queue)) results := make([]trie.SyncResult, len(queue))
for i, hash := range queue { for i, hash := range queue {
data, err := srcMem.Get(hash.Bytes()) data, err := srcDb.TrieDB().Node(hash)
if err != nil { if err != nil {
t.Fatalf("failed to retrieve node data for %x: %v", hash, err) t.Fatalf("failed to retrieve node data for %x", hash)
} }
results[i] = trie.SyncResult{Hash: hash, Data: data} results[i] = trie.SyncResult{Hash: hash, Data: data}
} }
@ -170,7 +170,7 @@ func testIterativeStateSync(t *testing.T, batch int) {
// partial results are returned, and the others sent only later. // partial results are returned, and the others sent only later.
func TestIterativeDelayedStateSync(t *testing.T) { func TestIterativeDelayedStateSync(t *testing.T) {
// Create a random state to copy // Create a random state to copy
_, srcMem, srcRoot, srcAccounts := makeTestState() srcDb, srcRoot, srcAccounts := makeTestState()
// Create a destination state and sync with the scheduler // Create a destination state and sync with the scheduler
dstDb, _ := ethdb.NewMemDatabase() dstDb, _ := ethdb.NewMemDatabase()
@ -181,9 +181,9 @@ func TestIterativeDelayedStateSync(t *testing.T) {
// Sync only half of the scheduled nodes // Sync only half of the scheduled nodes
results := make([]trie.SyncResult, len(queue)/2+1) results := make([]trie.SyncResult, len(queue)/2+1)
for i, hash := range queue[:len(results)] { for i, hash := range queue[:len(results)] {
data, err := srcMem.Get(hash.Bytes()) data, err := srcDb.TrieDB().Node(hash)
if err != nil { if err != nil {
t.Fatalf("failed to retrieve node data for %x: %v", hash, err) t.Fatalf("failed to retrieve node data for %x", hash)
} }
results[i] = trie.SyncResult{Hash: hash, Data: data} results[i] = trie.SyncResult{Hash: hash, Data: data}
} }
@ -207,7 +207,7 @@ func TestIterativeRandomStateSyncBatched(t *testing.T) { testIterativeRandomS
func testIterativeRandomStateSync(t *testing.T, batch int) { func testIterativeRandomStateSync(t *testing.T, batch int) {
// Create a random state to copy // Create a random state to copy
_, srcMem, srcRoot, srcAccounts := makeTestState() srcDb, srcRoot, srcAccounts := makeTestState()
// Create a destination state and sync with the scheduler // Create a destination state and sync with the scheduler
dstDb, _ := ethdb.NewMemDatabase() dstDb, _ := ethdb.NewMemDatabase()
@ -221,9 +221,9 @@ func testIterativeRandomStateSync(t *testing.T, batch int) {
// Fetch all the queued nodes in a random order // Fetch all the queued nodes in a random order
results := make([]trie.SyncResult, 0, len(queue)) results := make([]trie.SyncResult, 0, len(queue))
for hash := range queue { for hash := range queue {
data, err := srcMem.Get(hash.Bytes()) data, err := srcDb.TrieDB().Node(hash)
if err != nil { if err != nil {
t.Fatalf("failed to retrieve node data for %x: %v", hash, err) t.Fatalf("failed to retrieve node data for %x", hash)
} }
results = append(results, trie.SyncResult{Hash: hash, Data: data}) results = append(results, trie.SyncResult{Hash: hash, Data: data})
} }
@ -247,7 +247,7 @@ func testIterativeRandomStateSync(t *testing.T, batch int) {
// partial results are returned (Even those randomly), others sent only later. // partial results are returned (Even those randomly), others sent only later.
func TestIterativeRandomDelayedStateSync(t *testing.T) { func TestIterativeRandomDelayedStateSync(t *testing.T) {
// Create a random state to copy // Create a random state to copy
_, srcMem, srcRoot, srcAccounts := makeTestState() srcDb, srcRoot, srcAccounts := makeTestState()
// Create a destination state and sync with the scheduler // Create a destination state and sync with the scheduler
dstDb, _ := ethdb.NewMemDatabase() dstDb, _ := ethdb.NewMemDatabase()
@ -263,9 +263,9 @@ func TestIterativeRandomDelayedStateSync(t *testing.T) {
for hash := range queue { for hash := range queue {
delete(queue, hash) delete(queue, hash)
data, err := srcMem.Get(hash.Bytes()) data, err := srcDb.TrieDB().Node(hash)
if err != nil { if err != nil {
t.Fatalf("failed to retrieve node data for %x: %v", hash, err) t.Fatalf("failed to retrieve node data for %x", hash)
} }
results = append(results, trie.SyncResult{Hash: hash, Data: data}) results = append(results, trie.SyncResult{Hash: hash, Data: data})
@ -292,9 +292,9 @@ func TestIterativeRandomDelayedStateSync(t *testing.T) {
// the database. // the database.
func TestIncompleteStateSync(t *testing.T) { func TestIncompleteStateSync(t *testing.T) {
// Create a random state to copy // Create a random state to copy
_, srcMem, srcRoot, srcAccounts := makeTestState() srcDb, srcRoot, srcAccounts := makeTestState()
checkTrieConsistency(srcMem, srcRoot) checkTrieConsistency(srcDb.TrieDB().DiskDB().(ethdb.Database), srcRoot)
// Create a destination state and sync with the scheduler // Create a destination state and sync with the scheduler
dstDb, _ := ethdb.NewMemDatabase() dstDb, _ := ethdb.NewMemDatabase()
@ -306,9 +306,9 @@ func TestIncompleteStateSync(t *testing.T) {
// Fetch a batch of state nodes // Fetch a batch of state nodes
results := make([]trie.SyncResult, len(queue)) results := make([]trie.SyncResult, len(queue))
for i, hash := range queue { for i, hash := range queue {
data, err := srcMem.Get(hash.Bytes()) data, err := srcDb.TrieDB().Node(hash)
if err != nil { if err != nil {
t.Fatalf("failed to retrieve node data for %x: %v", hash, err) t.Fatalf("failed to retrieve node data for %x", hash)
} }
results[i] = trie.SyncResult{Hash: hash, Data: data} results[i] = trie.SyncResult{Hash: hash, Data: data}
} }

@ -78,8 +78,8 @@ func pricedTransaction(nonce uint64, gaslimit uint64, gasprice *big.Int, key *ec
} }
func setupTxPool() (*TxPool, *ecdsa.PrivateKey) { func setupTxPool() (*TxPool, *ecdsa.PrivateKey) {
db, _ := ethdb.NewMemDatabase() diskdb, _ := ethdb.NewMemDatabase()
statedb, _ := state.New(common.Hash{}, state.NewDatabase(db)) statedb, _ := state.New(common.Hash{}, state.NewDatabase(diskdb))
blockchain := &testBlockChain{statedb, 1000000, new(event.Feed)} blockchain := &testBlockChain{statedb, 1000000, new(event.Feed)}
key, _ := crypto.GenerateKey() key, _ := crypto.GenerateKey()

@ -25,6 +25,7 @@ import (
"sort" "sort"
"sync/atomic" "sync/atomic"
"time" "time"
"unsafe"
"github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/common/hexutil" "github.com/ethereum/go-ethereum/common/hexutil"
@ -121,6 +122,12 @@ func (h *Header) HashNoNonce() common.Hash {
}) })
} }
// Size returns the approximate memory used by all internal contents. It is used
// to approximate and limit the memory consumption of various caches.
func (h *Header) Size() common.StorageSize {
return common.StorageSize(unsafe.Sizeof(*h)) + common.StorageSize(len(h.Extra)+(h.Difficulty.BitLen()+h.Number.BitLen()+h.Time.BitLen())/8)
}
func rlpHash(x interface{}) (h common.Hash) { func rlpHash(x interface{}) (h common.Hash) {
hw := sha3.NewKeccak256() hw := sha3.NewKeccak256()
rlp.Encode(hw, x) rlp.Encode(hw, x)
@ -322,6 +329,8 @@ func (b *Block) HashNoNonce() common.Hash {
return b.header.HashNoNonce() return b.header.HashNoNonce()
} }
// Size returns the true RLP encoded storage size of the block, either by encoding
// and returning it, or returning a previsouly cached value.
func (b *Block) Size() common.StorageSize { func (b *Block) Size() common.StorageSize {
if size := b.size.Load(); size != nil { if size := b.size.Load(); size != nil {
return size.(common.StorageSize) return size.(common.StorageSize)

@ -20,6 +20,7 @@ import (
"bytes" "bytes"
"fmt" "fmt"
"io" "io"
"unsafe"
"github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/common/hexutil" "github.com/ethereum/go-ethereum/common/hexutil"
@ -136,6 +137,18 @@ func (r *Receipt) statusEncoding() []byte {
return r.PostState return r.PostState
} }
// Size returns the approximate memory used by all internal contents. It is used
// to approximate and limit the memory consumption of various caches.
func (r *Receipt) Size() common.StorageSize {
size := common.StorageSize(unsafe.Sizeof(*r)) + common.StorageSize(len(r.PostState))
size += common.StorageSize(len(r.Logs)) * common.StorageSize(unsafe.Sizeof(Log{}))
for _, log := range r.Logs {
size += common.StorageSize(len(log.Topics)*common.HashLength + len(log.Data))
}
return size
}
// String implements the Stringer interface. // String implements the Stringer interface.
func (r *Receipt) String() string { func (r *Receipt) String() string {
if len(r.PostState) == 0 { if len(r.PostState) == 0 {

@ -206,6 +206,8 @@ func (tx *Transaction) Hash() common.Hash {
return v return v
} }
// Size returns the true RLP encoded storage size of the transaction, either by
// encoding and returning it, or returning a previsouly cached value.
func (tx *Transaction) Size() common.StorageSize { func (tx *Transaction) Size() common.StorageSize {
if size := tx.size.Load(); size != nil { if size := tx.size.Load(); size != nil {
return size.(common.StorageSize) return size.(common.StorageSize)

@ -462,11 +462,11 @@ func (api *PrivateDebugAPI) getModifiedAccounts(startBlock, endBlock *types.Bloc
return nil, fmt.Errorf("start block height (%d) must be less than end block height (%d)", startBlock.Number().Uint64(), endBlock.Number().Uint64()) return nil, fmt.Errorf("start block height (%d) must be less than end block height (%d)", startBlock.Number().Uint64(), endBlock.Number().Uint64())
} }
oldTrie, err := trie.NewSecure(startBlock.Root(), api.eth.chainDb, 0) oldTrie, err := trie.NewSecure(startBlock.Root(), trie.NewDatabase(api.eth.chainDb), 0)
if err != nil { if err != nil {
return nil, err return nil, err
} }
newTrie, err := trie.NewSecure(endBlock.Root(), api.eth.chainDb, 0) newTrie, err := trie.NewSecure(endBlock.Root(), trie.NewDatabase(api.eth.chainDb), 0)
if err != nil { if err != nil {
return nil, err return nil, err
} }

@ -24,7 +24,6 @@ import (
"io/ioutil" "io/ioutil"
"runtime" "runtime"
"sync" "sync"
"sync/atomic"
"time" "time"
"github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/common"
@ -34,7 +33,6 @@ import (
"github.com/ethereum/go-ethereum/core/types" "github.com/ethereum/go-ethereum/core/types"
"github.com/ethereum/go-ethereum/core/vm" "github.com/ethereum/go-ethereum/core/vm"
"github.com/ethereum/go-ethereum/eth/tracers" "github.com/ethereum/go-ethereum/eth/tracers"
"github.com/ethereum/go-ethereum/ethdb"
"github.com/ethereum/go-ethereum/internal/ethapi" "github.com/ethereum/go-ethereum/internal/ethapi"
"github.com/ethereum/go-ethereum/log" "github.com/ethereum/go-ethereum/log"
"github.com/ethereum/go-ethereum/rlp" "github.com/ethereum/go-ethereum/rlp"
@ -72,6 +70,7 @@ type txTraceResult struct {
type blockTraceTask struct { type blockTraceTask struct {
statedb *state.StateDB // Intermediate state prepped for tracing statedb *state.StateDB // Intermediate state prepped for tracing
block *types.Block // Block to trace the transactions from block *types.Block // Block to trace the transactions from
rootref common.Hash // Trie root reference held for this task
results []*txTraceResult // Trace results procudes by the task results []*txTraceResult // Trace results procudes by the task
} }
@ -90,59 +89,6 @@ type txTraceTask struct {
index int // Transaction offset in the block index int // Transaction offset in the block
} }
// ephemeralDatabase is a memory wrapper around a proper database, which acts as
// an ephemeral write layer. This construct is used by the chain tracer to write
// state tries for intermediate blocks without serializing to disk, but at the
// same time to allow disk fallback for reads that do no hit the memory layer.
type ephemeralDatabase struct {
diskdb ethdb.Database // Persistent disk database to fall back to with reads
memdb *ethdb.MemDatabase // Ephemeral memory database for primary reads and writes
}
func (db *ephemeralDatabase) Put(key []byte, value []byte) error { return db.memdb.Put(key, value) }
func (db *ephemeralDatabase) Delete(key []byte) error { return errors.New("delete not supported") }
func (db *ephemeralDatabase) Close() { db.memdb.Close() }
func (db *ephemeralDatabase) NewBatch() ethdb.Batch {
return db.memdb.NewBatch()
}
func (db *ephemeralDatabase) Has(key []byte) (bool, error) {
if has, _ := db.memdb.Has(key); has {
return has, nil
}
return db.diskdb.Has(key)
}
func (db *ephemeralDatabase) Get(key []byte) ([]byte, error) {
if blob, _ := db.memdb.Get(key); blob != nil {
return blob, nil
}
return db.diskdb.Get(key)
}
// Prune does a state sync into a new memory write layer and replaces the old one.
// This allows us to discard entries that are no longer referenced from the current
// state.
func (db *ephemeralDatabase) Prune(root common.Hash) {
// Pull the still relevant state data into memory
sync := state.NewStateSync(root, db.diskdb)
for sync.Pending() > 0 {
hash := sync.Missing(1)[0]
// Move the next trie node from the memory layer into a sync struct
node, err := db.memdb.Get(hash[:])
if err != nil {
panic(err) // memdb must have the data
}
if _, _, err := sync.Process([]trie.SyncResult{{Hash: hash, Data: node}}); err != nil {
panic(err) // it's not possible to fail processing a node
}
}
// Discard the old memory layer and write a new one
db.memdb, _ = ethdb.NewMemDatabaseWithCap(db.memdb.Len())
if _, err := sync.Commit(db); err != nil {
panic(err) // writing into a memdb cannot fail
}
}
// TraceChain returns the structured logs created during the execution of EVM // TraceChain returns the structured logs created during the execution of EVM
// between two blocks (excluding start) and returns them as a JSON object. // between two blocks (excluding start) and returns them as a JSON object.
func (api *PrivateDebugAPI) TraceChain(ctx context.Context, start, end rpc.BlockNumber, config *TraceConfig) (*rpc.Subscription, error) { func (api *PrivateDebugAPI) TraceChain(ctx context.Context, start, end rpc.BlockNumber, config *TraceConfig) (*rpc.Subscription, error) {
@ -188,19 +134,15 @@ func (api *PrivateDebugAPI) traceChain(ctx context.Context, start, end *types.Bl
// Ensure we have a valid starting state before doing any work // Ensure we have a valid starting state before doing any work
origin := start.NumberU64() origin := start.NumberU64()
database := state.NewDatabase(api.eth.ChainDb())
memdb, _ := ethdb.NewMemDatabase()
db := &ephemeralDatabase{
diskdb: api.eth.ChainDb(),
memdb: memdb,
}
if number := start.NumberU64(); number > 0 { if number := start.NumberU64(); number > 0 {
start = api.eth.blockchain.GetBlock(start.ParentHash(), start.NumberU64()-1) start = api.eth.blockchain.GetBlock(start.ParentHash(), start.NumberU64()-1)
if start == nil { if start == nil {
return nil, fmt.Errorf("parent block #%d not found", number-1) return nil, fmt.Errorf("parent block #%d not found", number-1)
} }
} }
statedb, err := state.New(start.Root(), state.NewDatabase(db)) statedb, err := state.New(start.Root(), database)
if err != nil { if err != nil {
// If the starting state is missing, allow some number of blocks to be reexecuted // If the starting state is missing, allow some number of blocks to be reexecuted
reexec := defaultTraceReexec reexec := defaultTraceReexec
@ -213,7 +155,7 @@ func (api *PrivateDebugAPI) traceChain(ctx context.Context, start, end *types.Bl
if start == nil { if start == nil {
break break
} }
if statedb, err = state.New(start.Root(), state.NewDatabase(db)); err == nil { if statedb, err = state.New(start.Root(), database); err == nil {
break break
} }
} }
@ -256,7 +198,7 @@ func (api *PrivateDebugAPI) traceChain(ctx context.Context, start, end *types.Bl
res, err := api.traceTx(ctx, msg, vmctx, task.statedb, config) res, err := api.traceTx(ctx, msg, vmctx, task.statedb, config)
if err != nil { if err != nil {
task.results[i] = &txTraceResult{Error: err.Error()} task.results[i] = &txTraceResult{Error: err.Error()}
log.Warn("Tracing failed", "err", err) log.Warn("Tracing failed", "hash", tx.Hash(), "block", task.block.NumberU64(), "err", err)
break break
} }
task.statedb.DeleteSuicides() task.statedb.DeleteSuicides()
@ -273,7 +215,6 @@ func (api *PrivateDebugAPI) traceChain(ctx context.Context, start, end *types.Bl
} }
// Start a goroutine to feed all the blocks into the tracers // Start a goroutine to feed all the blocks into the tracers
begin := time.Now() begin := time.Now()
complete := start.NumberU64()
go func() { go func() {
var ( var (
@ -281,6 +222,7 @@ func (api *PrivateDebugAPI) traceChain(ctx context.Context, start, end *types.Bl
number uint64 number uint64
traced uint64 traced uint64
failed error failed error
proot common.Hash
) )
// Ensure everything is properly cleaned up on any exit path // Ensure everything is properly cleaned up on any exit path
defer func() { defer func() {
@ -308,7 +250,7 @@ func (api *PrivateDebugAPI) traceChain(ctx context.Context, start, end *types.Bl
// Print progress logs if long enough time elapsed // Print progress logs if long enough time elapsed
if time.Since(logged) > 8*time.Second { if time.Since(logged) > 8*time.Second {
if number > origin { if number > origin {
log.Info("Tracing chain segment", "start", origin, "end", end.NumberU64(), "current", number, "transactions", traced, "elapsed", time.Since(begin)) log.Info("Tracing chain segment", "start", origin, "end", end.NumberU64(), "current", number, "transactions", traced, "elapsed", time.Since(begin), "memory", database.TrieDB().Size())
} else { } else {
log.Info("Preparing state for chain trace", "block", number, "start", origin, "elapsed", time.Since(begin)) log.Info("Preparing state for chain trace", "block", number, "start", origin, "elapsed", time.Since(begin))
} }
@ -325,13 +267,11 @@ func (api *PrivateDebugAPI) traceChain(ctx context.Context, start, end *types.Bl
txs := block.Transactions() txs := block.Transactions()
select { select {
case tasks <- &blockTraceTask{statedb: statedb.Copy(), block: block, results: make([]*txTraceResult, len(txs))}: case tasks <- &blockTraceTask{statedb: statedb.Copy(), block: block, rootref: proot, results: make([]*txTraceResult, len(txs))}:
case <-notifier.Closed(): case <-notifier.Closed():
return return
} }
traced += uint64(len(txs)) traced += uint64(len(txs))
} else {
atomic.StoreUint64(&complete, number)
} }
// Generate the next state snapshot fast without tracing // Generate the next state snapshot fast without tracing
_, _, _, err := api.eth.blockchain.Processor().Process(block, statedb, vm.Config{}) _, _, _, err := api.eth.blockchain.Processor().Process(block, statedb, vm.Config{})
@ -340,7 +280,7 @@ func (api *PrivateDebugAPI) traceChain(ctx context.Context, start, end *types.Bl
break break
} }
// Finalize the state so any modifications are written to the trie // Finalize the state so any modifications are written to the trie
root, err := statedb.CommitTo(db, true) root, err := statedb.Commit(true)
if err != nil { if err != nil {
failed = err failed = err
break break
@ -349,26 +289,14 @@ func (api *PrivateDebugAPI) traceChain(ctx context.Context, start, end *types.Bl
failed = err failed = err
break break
} }
// After every N blocks, prune the database to only retain relevant data // Reference the trie twice, once for us, once for the trancer
if (number-start.NumberU64())%4096 == 0 { database.TrieDB().Reference(root, common.Hash{})
// Wait until currently pending trace jobs finish if number >= origin {
for atomic.LoadUint64(&complete) != number { database.TrieDB().Reference(root, common.Hash{})
select {
case <-time.After(100 * time.Millisecond):
case <-notifier.Closed():
return
}
}
// No more concurrent access at this point, prune the database
var (
nodes = db.memdb.Len()
start = time.Now()
)
db.Prune(root)
log.Info("Pruned tracer state entries", "deleted", nodes-db.memdb.Len(), "left", db.memdb.Len(), "elapsed", time.Since(start))
statedb, _ = state.New(root, state.NewDatabase(db))
} }
// Dereference all past tries we ourselves are done working with
database.TrieDB().Dereference(proot, common.Hash{})
proot = root
} }
}() }()
@ -387,12 +315,14 @@ func (api *PrivateDebugAPI) traceChain(ctx context.Context, start, end *types.Bl
} }
done[uint64(result.Block)] = result done[uint64(result.Block)] = result
// Dereference any paret tries held in memory by this task
database.TrieDB().Dereference(res.rootref, common.Hash{})
// Stream completed traces to the user, aborting on the first error // Stream completed traces to the user, aborting on the first error
for result, ok := done[next]; ok; result, ok = done[next] { for result, ok := done[next]; ok; result, ok = done[next] {
if len(result.Traces) > 0 || next == end.NumberU64() { if len(result.Traces) > 0 || next == end.NumberU64() {
notifier.Notify(sub.ID, result) notifier.Notify(sub.ID, result)
} }
atomic.StoreUint64(&complete, next)
delete(done, next) delete(done, next)
next++ next++
} }
@ -544,18 +474,14 @@ func (api *PrivateDebugAPI) computeStateDB(block *types.Block, reexec uint64) (*
} }
// Otherwise try to reexec blocks until we find a state or reach our limit // Otherwise try to reexec blocks until we find a state or reach our limit
origin := block.NumberU64() origin := block.NumberU64()
database := state.NewDatabase(api.eth.ChainDb())
memdb, _ := ethdb.NewMemDatabase()
db := &ephemeralDatabase{
diskdb: api.eth.ChainDb(),
memdb: memdb,
}
for i := uint64(0); i < reexec; i++ { for i := uint64(0); i < reexec; i++ {
block = api.eth.blockchain.GetBlock(block.ParentHash(), block.NumberU64()-1) block = api.eth.blockchain.GetBlock(block.ParentHash(), block.NumberU64()-1)
if block == nil { if block == nil {
break break
} }
if statedb, err = state.New(block.Root(), state.NewDatabase(db)); err == nil { if statedb, err = state.New(block.Root(), database); err == nil {
break break
} }
} }
@ -571,6 +497,7 @@ func (api *PrivateDebugAPI) computeStateDB(block *types.Block, reexec uint64) (*
var ( var (
start = time.Now() start = time.Now()
logged time.Time logged time.Time
proot common.Hash
) )
for block.NumberU64() < origin { for block.NumberU64() < origin {
// Print progress logs if long enough time elapsed // Print progress logs if long enough time elapsed
@ -587,26 +514,18 @@ func (api *PrivateDebugAPI) computeStateDB(block *types.Block, reexec uint64) (*
return nil, err return nil, err
} }
// Finalize the state so any modifications are written to the trie // Finalize the state so any modifications are written to the trie
root, err := statedb.CommitTo(db, true) root, err := statedb.Commit(true)
if err != nil { if err != nil {
return nil, err return nil, err
} }
if err := statedb.Reset(root); err != nil { if err := statedb.Reset(root); err != nil {
return nil, err return nil, err
} }
// After every N blocks, prune the database to only retain relevant data database.TrieDB().Reference(root, common.Hash{})
if block.NumberU64()%4096 == 0 || block.NumberU64() == origin { database.TrieDB().Dereference(proot, common.Hash{})
var ( proot = root
nodes = db.memdb.Len()
begin = time.Now()
)
db.Prune(root)
log.Info("Pruned tracer state entries", "deleted", nodes-db.memdb.Len(), "left", db.memdb.Len(), "elapsed", time.Since(begin))
statedb, _ = state.New(root, state.NewDatabase(db))
}
} }
log.Info("Historical state regenerated", "block", block.NumberU64(), "elapsed", time.Since(start)) log.Info("Historical state regenerated", "block", block.NumberU64(), "elapsed", time.Since(start), "size", database.TrieDB().Size())
return statedb, nil return statedb, nil
} }

@ -144,9 +144,11 @@ func New(ctx *node.ServiceContext, config *Config) (*Ethereum, error) {
} }
core.WriteBlockChainVersion(chainDb, core.BlockChainVersion) core.WriteBlockChainVersion(chainDb, core.BlockChainVersion)
} }
var (
vmConfig := vm.Config{EnablePreimageRecording: config.EnablePreimageRecording} vmConfig = vm.Config{EnablePreimageRecording: config.EnablePreimageRecording}
eth.blockchain, err = core.NewBlockChain(chainDb, eth.chainConfig, eth.engine, vmConfig) cacheConfig = &core.CacheConfig{Disabled: config.NoPruning, TrieNodeLimit: config.TrieCache, TrieTimeLimit: config.TrieTimeout}
)
eth.blockchain, err = core.NewBlockChain(chainDb, cacheConfig, eth.chainConfig, eth.engine, vmConfig)
if err != nil { if err != nil {
return nil, err return nil, err
} }

@ -22,6 +22,7 @@ import (
"os/user" "os/user"
"path/filepath" "path/filepath"
"runtime" "runtime"
"time"
"github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/common/hexutil" "github.com/ethereum/go-ethereum/common/hexutil"
@ -44,7 +45,9 @@ var DefaultConfig = Config{
}, },
NetworkId: 1, NetworkId: 1,
LightPeers: 100, LightPeers: 100,
DatabaseCache: 128, DatabaseCache: 768,
TrieCache: 256,
TrieTimeout: 5 * time.Minute,
GasPrice: big.NewInt(18 * params.Shannon), GasPrice: big.NewInt(18 * params.Shannon),
TxPool: core.DefaultTxPoolConfig, TxPool: core.DefaultTxPoolConfig,
@ -78,6 +81,7 @@ type Config struct {
// Protocol options // Protocol options
NetworkId uint64 // Network ID to use for selecting peers to connect to NetworkId uint64 // Network ID to use for selecting peers to connect to
SyncMode downloader.SyncMode SyncMode downloader.SyncMode
NoPruning bool
// Light client options // Light client options
LightServ int `toml:",omitempty"` // Maximum percentage of time allowed for serving LES requests LightServ int `toml:",omitempty"` // Maximum percentage of time allowed for serving LES requests
@ -87,6 +91,8 @@ type Config struct {
SkipBcVersionCheck bool `toml:"-"` SkipBcVersionCheck bool `toml:"-"`
DatabaseHandles int `toml:"-"` DatabaseHandles int `toml:"-"`
DatabaseCache int DatabaseCache int
TrieCache int
TrieTimeout time.Duration
// Mining-related options // Mining-related options
Etherbase common.Address `toml:",omitempty"` Etherbase common.Address `toml:",omitempty"`

@ -18,10 +18,8 @@
package downloader package downloader
import ( import (
"crypto/rand"
"errors" "errors"
"fmt" "fmt"
"math"
"math/big" "math/big"
"sync" "sync"
"sync/atomic" "sync/atomic"
@ -64,9 +62,8 @@ var (
fsHeaderCheckFrequency = 100 // Verification frequency of the downloaded headers during fast sync fsHeaderCheckFrequency = 100 // Verification frequency of the downloaded headers during fast sync
fsHeaderSafetyNet = 2048 // Number of headers to discard in case a chain violation is detected fsHeaderSafetyNet = 2048 // Number of headers to discard in case a chain violation is detected
fsHeaderForceVerify = 24 // Number of headers to verify before and after the pivot to accept it fsHeaderForceVerify = 24 // Number of headers to verify before and after the pivot to accept it
fsPivotInterval = 256 // Number of headers out of which to randomize the pivot point fsHeaderContCheck = 3 * time.Second // Time interval to check for header continuations during state download
fsMinFullBlocks = 64 // Number of blocks to retrieve fully even in fast sync fsMinFullBlocks = 64 // Number of blocks to retrieve fully even in fast sync
fsCriticalTrials = uint32(32) // Number of times to retry in the cricical section before bailing
) )
var ( var (
@ -102,9 +99,6 @@ type Downloader struct {
peers *peerSet // Set of active peers from which download can proceed peers *peerSet // Set of active peers from which download can proceed
stateDB ethdb.Database stateDB ethdb.Database
fsPivotLock *types.Header // Pivot header on critical section entry (cannot change between retries)
fsPivotFails uint32 // Number of subsequent fast sync failures in the critical section
rttEstimate uint64 // Round trip time to target for download requests rttEstimate uint64 // Round trip time to target for download requests
rttConfidence uint64 // Confidence in the estimated RTT (unit: millionths to allow atomic ops) rttConfidence uint64 // Confidence in the estimated RTT (unit: millionths to allow atomic ops)
@ -124,6 +118,7 @@ type Downloader struct {
synchroniseMock func(id string, hash common.Hash) error // Replacement for synchronise during testing synchroniseMock func(id string, hash common.Hash) error // Replacement for synchronise during testing
synchronising int32 synchronising int32
notified int32 notified int32
committed int32
// Channels // Channels
headerCh chan dataPack // [eth/62] Channel receiving inbound block headers headerCh chan dataPack // [eth/62] Channel receiving inbound block headers
@ -156,7 +151,7 @@ type Downloader struct {
// LightChain encapsulates functions required to synchronise a light chain. // LightChain encapsulates functions required to synchronise a light chain.
type LightChain interface { type LightChain interface {
// HasHeader verifies a header's presence in the local chain. // HasHeader verifies a header's presence in the local chain.
HasHeader(h common.Hash, number uint64) bool HasHeader(common.Hash, uint64) bool
// GetHeaderByHash retrieves a header from the local chain. // GetHeaderByHash retrieves a header from the local chain.
GetHeaderByHash(common.Hash) *types.Header GetHeaderByHash(common.Hash) *types.Header
@ -179,7 +174,7 @@ type BlockChain interface {
LightChain LightChain
// HasBlockAndState verifies block and associated states' presence in the local chain. // HasBlockAndState verifies block and associated states' presence in the local chain.
HasBlockAndState(common.Hash) bool HasBlockAndState(common.Hash, uint64) bool
// GetBlockByHash retrieves a block from the local chain. // GetBlockByHash retrieves a block from the local chain.
GetBlockByHash(common.Hash) *types.Block GetBlockByHash(common.Hash) *types.Block
@ -391,9 +386,7 @@ func (d *Downloader) synchronise(id string, hash common.Hash, td *big.Int, mode
// Set the requested sync mode, unless it's forbidden // Set the requested sync mode, unless it's forbidden
d.mode = mode d.mode = mode
if d.mode == FastSync && atomic.LoadUint32(&d.fsPivotFails) >= fsCriticalTrials {
d.mode = FullSync
}
// Retrieve the origin peer and initiate the downloading process // Retrieve the origin peer and initiate the downloading process
p := d.peers.Peer(id) p := d.peers.Peer(id)
if p == nil { if p == nil {
@ -441,57 +434,40 @@ func (d *Downloader) syncWithPeer(p *peerConnection, hash common.Hash, td *big.I
d.syncStatsChainHeight = height d.syncStatsChainHeight = height
d.syncStatsLock.Unlock() d.syncStatsLock.Unlock()
// Initiate the sync using a concurrent header and content retrieval algorithm // Ensure our origin point is below any fast sync pivot point
pivot := uint64(0) pivot := uint64(0)
switch d.mode { if d.mode == FastSync {
case LightSync: if height <= uint64(fsMinFullBlocks) {
pivot = height origin = 0
case FastSync:
// Calculate the new fast/slow sync pivot point
if d.fsPivotLock == nil {
pivotOffset, err := rand.Int(rand.Reader, big.NewInt(int64(fsPivotInterval)))
if err != nil {
panic(fmt.Sprintf("Failed to access crypto random source: %v", err))
}
if height > uint64(fsMinFullBlocks)+pivotOffset.Uint64() {
pivot = height - uint64(fsMinFullBlocks) - pivotOffset.Uint64()
}
} else { } else {
// Pivot point locked in, use this and do not pick a new one! pivot = height - uint64(fsMinFullBlocks)
pivot = d.fsPivotLock.Number.Uint64() if pivot <= origin {
}
// If the point is below the origin, move origin back to ensure state download
if pivot < origin {
if pivot > 0 {
origin = pivot - 1 origin = pivot - 1
} else {
origin = 0
} }
} }
log.Debug("Fast syncing until pivot block", "pivot", pivot)
} }
d.queue.Prepare(origin+1, d.mode, pivot, latest) d.committed = 1
if d.mode == FastSync && pivot != 0 {
d.committed = 0
}
// Initiate the sync using a concurrent header and content retrieval algorithm
d.queue.Prepare(origin+1, d.mode)
if d.syncInitHook != nil { if d.syncInitHook != nil {
d.syncInitHook(origin, height) d.syncInitHook(origin, height)
} }
fetchers := []func() error{ fetchers := []func() error{
func() error { return d.fetchHeaders(p, origin+1) }, // Headers are always retrieved func() error { return d.fetchHeaders(p, origin+1, pivot) }, // Headers are always retrieved
func() error { return d.fetchBodies(origin + 1) }, // Bodies are retrieved during normal and fast sync func() error { return d.fetchBodies(origin + 1) }, // Bodies are retrieved during normal and fast sync
func() error { return d.fetchReceipts(origin + 1) }, // Receipts are retrieved during fast sync func() error { return d.fetchReceipts(origin + 1) }, // Receipts are retrieved during fast sync
func() error { return d.processHeaders(origin+1, td) }, func() error { return d.processHeaders(origin+1, pivot, td) },
} }
if d.mode == FastSync { if d.mode == FastSync {
fetchers = append(fetchers, func() error { return d.processFastSyncContent(latest) }) fetchers = append(fetchers, func() error { return d.processFastSyncContent(latest) })
} else if d.mode == FullSync { } else if d.mode == FullSync {
fetchers = append(fetchers, d.processFullSyncContent) fetchers = append(fetchers, d.processFullSyncContent)
} }
err = d.spawnSync(fetchers) return d.spawnSync(fetchers)
if err != nil && d.mode == FastSync && d.fsPivotLock != nil {
// If sync failed in the critical section, bump the fail counter.
atomic.AddUint32(&d.fsPivotFails, 1)
}
return err
} }
// spawnSync runs d.process and all given fetcher functions to completion in // spawnSync runs d.process and all given fetcher functions to completion in
@ -671,7 +647,7 @@ func (d *Downloader) findAncestor(p *peerConnection, height uint64) (uint64, err
continue continue
} }
// Otherwise check if we already know the header or not // Otherwise check if we already know the header or not
if (d.mode == FullSync && d.blockchain.HasBlockAndState(headers[i].Hash())) || (d.mode != FullSync && d.lightchain.HasHeader(headers[i].Hash(), headers[i].Number.Uint64())) { if (d.mode == FullSync && d.blockchain.HasBlockAndState(headers[i].Hash(), headers[i].Number.Uint64())) || (d.mode != FullSync && d.lightchain.HasHeader(headers[i].Hash(), headers[i].Number.Uint64())) {
number, hash = headers[i].Number.Uint64(), headers[i].Hash() number, hash = headers[i].Number.Uint64(), headers[i].Hash()
// If every header is known, even future ones, the peer straight out lied about its head // If every header is known, even future ones, the peer straight out lied about its head
@ -736,7 +712,7 @@ func (d *Downloader) findAncestor(p *peerConnection, height uint64) (uint64, err
arrived = true arrived = true
// Modify the search interval based on the response // Modify the search interval based on the response
if (d.mode == FullSync && !d.blockchain.HasBlockAndState(headers[0].Hash())) || (d.mode != FullSync && !d.lightchain.HasHeader(headers[0].Hash(), headers[0].Number.Uint64())) { if (d.mode == FullSync && !d.blockchain.HasBlockAndState(headers[0].Hash(), headers[0].Number.Uint64())) || (d.mode != FullSync && !d.lightchain.HasHeader(headers[0].Hash(), headers[0].Number.Uint64())) {
end = check end = check
break break
} }
@ -774,7 +750,7 @@ func (d *Downloader) findAncestor(p *peerConnection, height uint64) (uint64, err
// other peers are only accepted if they map cleanly to the skeleton. If no one // other peers are only accepted if they map cleanly to the skeleton. If no one
// can fill in the skeleton - not even the origin peer - it's assumed invalid and // can fill in the skeleton - not even the origin peer - it's assumed invalid and
// the origin is dropped. // the origin is dropped.
func (d *Downloader) fetchHeaders(p *peerConnection, from uint64) error { func (d *Downloader) fetchHeaders(p *peerConnection, from uint64, pivot uint64) error {
p.log.Debug("Directing header downloads", "origin", from) p.log.Debug("Directing header downloads", "origin", from)
defer p.log.Debug("Header download terminated") defer p.log.Debug("Header download terminated")
@ -825,6 +801,18 @@ func (d *Downloader) fetchHeaders(p *peerConnection, from uint64) error {
} }
// If no more headers are inbound, notify the content fetchers and return // If no more headers are inbound, notify the content fetchers and return
if packet.Items() == 0 { if packet.Items() == 0 {
// Don't abort header fetches while the pivot is downloading
if atomic.LoadInt32(&d.committed) == 0 && pivot <= from {
p.log.Debug("No headers, waiting for pivot commit")
select {
case <-time.After(fsHeaderContCheck):
getHeaders(from)
continue
case <-d.cancelCh:
return errCancelHeaderFetch
}
}
// Pivot done (or not in fast sync) and no more headers, terminate the process
p.log.Debug("No more headers available") p.log.Debug("No more headers available")
select { select {
case d.headerProcCh <- nil: case d.headerProcCh <- nil:
@ -1129,10 +1117,8 @@ func (d *Downloader) fetchParts(errCancel error, deliveryCh chan dataPack, deliv
} }
if request.From > 0 { if request.From > 0 {
peer.log.Trace("Requesting new batch of data", "type", kind, "from", request.From) peer.log.Trace("Requesting new batch of data", "type", kind, "from", request.From)
} else if len(request.Headers) > 0 {
peer.log.Trace("Requesting new batch of data", "type", kind, "count", len(request.Headers), "from", request.Headers[0].Number)
} else { } else {
peer.log.Trace("Requesting new batch of data", "type", kind, "count", len(request.Hashes)) peer.log.Trace("Requesting new batch of data", "type", kind, "count", len(request.Headers), "from", request.Headers[0].Number)
} }
// Fetch the chunk and make sure any errors return the hashes to the queue // Fetch the chunk and make sure any errors return the hashes to the queue
if fetchHook != nil { if fetchHook != nil {
@ -1160,10 +1146,7 @@ func (d *Downloader) fetchParts(errCancel error, deliveryCh chan dataPack, deliv
// processHeaders takes batches of retrieved headers from an input channel and // processHeaders takes batches of retrieved headers from an input channel and
// keeps processing and scheduling them into the header chain and downloader's // keeps processing and scheduling them into the header chain and downloader's
// queue until the stream ends or a failure occurs. // queue until the stream ends or a failure occurs.
func (d *Downloader) processHeaders(origin uint64, td *big.Int) error { func (d *Downloader) processHeaders(origin uint64, pivot uint64, td *big.Int) error {
// Calculate the pivoting point for switching from fast to slow sync
pivot := d.queue.FastSyncPivot()
// Keep a count of uncertain headers to roll back // Keep a count of uncertain headers to roll back
rollback := []*types.Header{} rollback := []*types.Header{}
defer func() { defer func() {
@ -1188,19 +1171,6 @@ func (d *Downloader) processHeaders(origin uint64, td *big.Int) error {
"header", fmt.Sprintf("%d->%d", lastHeader, d.lightchain.CurrentHeader().Number), "header", fmt.Sprintf("%d->%d", lastHeader, d.lightchain.CurrentHeader().Number),
"fast", fmt.Sprintf("%d->%d", lastFastBlock, curFastBlock), "fast", fmt.Sprintf("%d->%d", lastFastBlock, curFastBlock),
"block", fmt.Sprintf("%d->%d", lastBlock, curBlock)) "block", fmt.Sprintf("%d->%d", lastBlock, curBlock))
// If we're already past the pivot point, this could be an attack, thread carefully
if rollback[len(rollback)-1].Number.Uint64() > pivot {
// If we didn't ever fail, lock in the pivot header (must! not! change!)
if atomic.LoadUint32(&d.fsPivotFails) == 0 {
for _, header := range rollback {
if header.Number.Uint64() == pivot {
log.Warn("Fast-sync pivot locked in", "number", pivot, "hash", header.Hash())
d.fsPivotLock = header
}
}
}
}
} }
}() }()
@ -1302,13 +1272,6 @@ func (d *Downloader) processHeaders(origin uint64, td *big.Int) error {
rollback = append(rollback[:0], rollback[len(rollback)-fsHeaderSafetyNet:]...) rollback = append(rollback[:0], rollback[len(rollback)-fsHeaderSafetyNet:]...)
} }
} }
// If we're fast syncing and just pulled in the pivot, make sure it's the one locked in
if d.mode == FastSync && d.fsPivotLock != nil && chunk[0].Number.Uint64() <= pivot && chunk[len(chunk)-1].Number.Uint64() >= pivot {
if pivot := chunk[int(pivot-chunk[0].Number.Uint64())]; pivot.Hash() != d.fsPivotLock.Hash() {
log.Warn("Pivot doesn't match locked in one", "remoteNumber", pivot.Number, "remoteHash", pivot.Hash(), "localNumber", d.fsPivotLock.Number, "localHash", d.fsPivotLock.Hash())
return errInvalidChain
}
}
// Unless we're doing light chains, schedule the headers for associated content retrieval // Unless we're doing light chains, schedule the headers for associated content retrieval
if d.mode == FullSync || d.mode == FastSync { if d.mode == FullSync || d.mode == FastSync {
// If we've reached the allowed number of pending headers, stall a bit // If we've reached the allowed number of pending headers, stall a bit
@ -1343,7 +1306,7 @@ func (d *Downloader) processHeaders(origin uint64, td *big.Int) error {
// processFullSyncContent takes fetch results from the queue and imports them into the chain. // processFullSyncContent takes fetch results from the queue and imports them into the chain.
func (d *Downloader) processFullSyncContent() error { func (d *Downloader) processFullSyncContent() error {
for { for {
results := d.queue.WaitResults() results := d.queue.Results(true)
if len(results) == 0 { if len(results) == 0 {
return nil return nil
} }
@ -1357,66 +1320,121 @@ func (d *Downloader) processFullSyncContent() error {
} }
func (d *Downloader) importBlockResults(results []*fetchResult) error { func (d *Downloader) importBlockResults(results []*fetchResult) error {
for len(results) != 0 { // Check for any early termination requests
// Check for any termination requests. This makes clean shutdown faster. if len(results) == 0 {
return nil
}
select { select {
case <-d.quitCh: case <-d.quitCh:
return errCancelContentProcessing return errCancelContentProcessing
default: default:
} }
// Retrieve the a batch of results to import // Retrieve the a batch of results to import
items := int(math.Min(float64(len(results)), float64(maxResultsProcess))) first, last := results[0].Header, results[len(results)-1].Header
first, last := results[0].Header, results[items-1].Header
log.Debug("Inserting downloaded chain", "items", len(results), log.Debug("Inserting downloaded chain", "items", len(results),
"firstnum", first.Number, "firsthash", first.Hash(), "firstnum", first.Number, "firsthash", first.Hash(),
"lastnum", last.Number, "lasthash", last.Hash(), "lastnum", last.Number, "lasthash", last.Hash(),
) )
blocks := make([]*types.Block, items) blocks := make([]*types.Block, len(results))
for i, result := range results[:items] { for i, result := range results {
blocks[i] = types.NewBlockWithHeader(result.Header).WithBody(result.Transactions, result.Uncles) blocks[i] = types.NewBlockWithHeader(result.Header).WithBody(result.Transactions, result.Uncles)
} }
if index, err := d.blockchain.InsertChain(blocks); err != nil { if index, err := d.blockchain.InsertChain(blocks); err != nil {
log.Debug("Downloaded item processing failed", "number", results[index].Header.Number, "hash", results[index].Header.Hash(), "err", err) log.Debug("Downloaded item processing failed", "number", results[index].Header.Number, "hash", results[index].Header.Hash(), "err", err)
return errInvalidChain return errInvalidChain
} }
// Shift the results to the next batch
results = results[items:]
}
return nil return nil
} }
// processFastSyncContent takes fetch results from the queue and writes them to the // processFastSyncContent takes fetch results from the queue and writes them to the
// database. It also controls the synchronisation of state nodes of the pivot block. // database. It also controls the synchronisation of state nodes of the pivot block.
func (d *Downloader) processFastSyncContent(latest *types.Header) error { func (d *Downloader) processFastSyncContent(latest *types.Header) error {
// Start syncing state of the reported head block. // Start syncing state of the reported head block. This should get us most of
// This should get us most of the state of the pivot block. // the state of the pivot block.
stateSync := d.syncState(latest.Root) stateSync := d.syncState(latest.Root)
defer stateSync.Cancel() defer stateSync.Cancel()
go func() { go func() {
if err := stateSync.Wait(); err != nil { if err := stateSync.Wait(); err != nil && err != errCancelStateFetch {
d.queue.Close() // wake up WaitResults d.queue.Close() // wake up WaitResults
} }
}() }()
// Figure out the ideal pivot block. Note, that this goalpost may move if the
pivot := d.queue.FastSyncPivot() // sync takes long enough for the chain head to move significantly.
pivot := uint64(0)
if height := latest.Number.Uint64(); height > uint64(fsMinFullBlocks) {
pivot = height - uint64(fsMinFullBlocks)
}
// To cater for moving pivot points, track the pivot block and subsequently
// accumulated download results separatey.
var (
oldPivot *fetchResult // Locked in pivot block, might change eventually
oldTail []*fetchResult // Downloaded content after the pivot
)
for { for {
results := d.queue.WaitResults() // Wait for the next batch of downloaded data to be available, and if the pivot
// block became stale, move the goalpost
results := d.queue.Results(oldPivot == nil) // Block if we're not monitoring pivot staleness
if len(results) == 0 { if len(results) == 0 {
// If pivot sync is done, stop
if oldPivot == nil {
return stateSync.Cancel() return stateSync.Cancel()
} }
// If sync failed, stop
select {
case <-d.cancelCh:
return stateSync.Cancel()
default:
}
}
if d.chainInsertHook != nil { if d.chainInsertHook != nil {
d.chainInsertHook(results) d.chainInsertHook(results)
} }
if oldPivot != nil {
results = append(append([]*fetchResult{oldPivot}, oldTail...), results...)
}
// Split around the pivot block and process the two sides via fast/full sync
if atomic.LoadInt32(&d.committed) == 0 {
latest = results[len(results)-1].Header
if height := latest.Number.Uint64(); height > pivot+2*uint64(fsMinFullBlocks) {
log.Warn("Pivot became stale, moving", "old", pivot, "new", height-uint64(fsMinFullBlocks))
pivot = height - uint64(fsMinFullBlocks)
}
}
P, beforeP, afterP := splitAroundPivot(pivot, results) P, beforeP, afterP := splitAroundPivot(pivot, results)
if err := d.commitFastSyncData(beforeP, stateSync); err != nil { if err := d.commitFastSyncData(beforeP, stateSync); err != nil {
return err return err
} }
if P != nil { if P != nil {
// If new pivot block found, cancel old state retrieval and restart
if oldPivot != P {
stateSync.Cancel() stateSync.Cancel()
stateSync = d.syncState(P.Header.Root)
defer stateSync.Cancel()
go func() {
if err := stateSync.Wait(); err != nil && err != errCancelStateFetch {
d.queue.Close() // wake up WaitResults
}
}()
oldPivot = P
}
// Wait for completion, occasionally checking for pivot staleness
select {
case <-stateSync.done:
if stateSync.err != nil {
return stateSync.err
}
if err := d.commitPivotBlock(P); err != nil { if err := d.commitPivotBlock(P); err != nil {
return err return err
} }
oldPivot = nil
case <-time.After(time.Second):
oldTail = afterP
continue
} }
}
// Fast sync done, pivot commit done, full import
if err := d.importBlockResults(afterP); err != nil { if err := d.importBlockResults(afterP); err != nil {
return err return err
} }
@ -1439,8 +1457,10 @@ func splitAroundPivot(pivot uint64, results []*fetchResult) (p *fetchResult, bef
} }
func (d *Downloader) commitFastSyncData(results []*fetchResult, stateSync *stateSync) error { func (d *Downloader) commitFastSyncData(results []*fetchResult, stateSync *stateSync) error {
for len(results) != 0 { // Check for any early termination requests
// Check for any termination requests. if len(results) == 0 {
return nil
}
select { select {
case <-d.quitCh: case <-d.quitCh:
return errCancelContentProcessing return errCancelContentProcessing
@ -1451,15 +1471,14 @@ func (d *Downloader) commitFastSyncData(results []*fetchResult, stateSync *state
default: default:
} }
// Retrieve the a batch of results to import // Retrieve the a batch of results to import
items := int(math.Min(float64(len(results)), float64(maxResultsProcess))) first, last := results[0].Header, results[len(results)-1].Header
first, last := results[0].Header, results[items-1].Header
log.Debug("Inserting fast-sync blocks", "items", len(results), log.Debug("Inserting fast-sync blocks", "items", len(results),
"firstnum", first.Number, "firsthash", first.Hash(), "firstnum", first.Number, "firsthash", first.Hash(),
"lastnumn", last.Number, "lasthash", last.Hash(), "lastnumn", last.Number, "lasthash", last.Hash(),
) )
blocks := make([]*types.Block, items) blocks := make([]*types.Block, len(results))
receipts := make([]types.Receipts, items) receipts := make([]types.Receipts, len(results))
for i, result := range results[:items] { for i, result := range results {
blocks[i] = types.NewBlockWithHeader(result.Header).WithBody(result.Transactions, result.Uncles) blocks[i] = types.NewBlockWithHeader(result.Header).WithBody(result.Transactions, result.Uncles)
receipts[i] = result.Receipts receipts[i] = result.Receipts
} }
@ -1467,24 +1486,20 @@ func (d *Downloader) commitFastSyncData(results []*fetchResult, stateSync *state
log.Debug("Downloaded item processing failed", "number", results[index].Header.Number, "hash", results[index].Header.Hash(), "err", err) log.Debug("Downloaded item processing failed", "number", results[index].Header.Number, "hash", results[index].Header.Hash(), "err", err)
return errInvalidChain return errInvalidChain
} }
// Shift the results to the next batch
results = results[items:]
}
return nil return nil
} }
func (d *Downloader) commitPivotBlock(result *fetchResult) error { func (d *Downloader) commitPivotBlock(result *fetchResult) error {
b := types.NewBlockWithHeader(result.Header).WithBody(result.Transactions, result.Uncles) block := types.NewBlockWithHeader(result.Header).WithBody(result.Transactions, result.Uncles)
// Sync the pivot block state. This should complete reasonably quickly because log.Debug("Committing fast sync pivot as new head", "number", block.Number(), "hash", block.Hash())
// we've already synced up to the reported head block state earlier. if _, err := d.blockchain.InsertReceiptChain([]*types.Block{block}, []types.Receipts{result.Receipts}); err != nil {
if err := d.syncState(b.Root()).Wait(); err != nil {
return err return err
} }
log.Debug("Committing fast sync pivot as new head", "number", b.Number(), "hash", b.Hash()) if err := d.blockchain.FastSyncCommitHead(block.Hash()); err != nil {
if _, err := d.blockchain.InsertReceiptChain([]*types.Block{b}, []types.Receipts{result.Receipts}); err != nil {
return err return err
} }
return d.blockchain.FastSyncCommitHead(b.Hash()) atomic.StoreInt32(&d.committed, 1)
return nil
} }
// DeliverHeaders injects a new batch of block headers received from a remote // DeliverHeaders injects a new batch of block headers received from a remote

@ -28,7 +28,6 @@ import (
"github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/consensus/ethash" "github.com/ethereum/go-ethereum/consensus/ethash"
"github.com/ethereum/go-ethereum/core" "github.com/ethereum/go-ethereum/core"
"github.com/ethereum/go-ethereum/core/state"
"github.com/ethereum/go-ethereum/core/types" "github.com/ethereum/go-ethereum/core/types"
"github.com/ethereum/go-ethereum/crypto" "github.com/ethereum/go-ethereum/crypto"
"github.com/ethereum/go-ethereum/ethdb" "github.com/ethereum/go-ethereum/ethdb"
@ -45,8 +44,8 @@ var (
// Reduce some of the parameters to make the tester faster. // Reduce some of the parameters to make the tester faster.
func init() { func init() {
MaxForkAncestry = uint64(10000) MaxForkAncestry = uint64(10000)
blockCacheLimit = 1024 blockCacheItems = 1024
fsCriticalTrials = 10 fsHeaderContCheck = 500 * time.Millisecond
} }
// downloadTester is a test simulator for mocking out local block chain. // downloadTester is a test simulator for mocking out local block chain.
@ -223,7 +222,7 @@ func (dl *downloadTester) HasHeader(hash common.Hash, number uint64) bool {
} }
// HasBlockAndState checks if a block and associated state is present in the testers canonical chain. // HasBlockAndState checks if a block and associated state is present in the testers canonical chain.
func (dl *downloadTester) HasBlockAndState(hash common.Hash) bool { func (dl *downloadTester) HasBlockAndState(hash common.Hash, number uint64) bool {
block := dl.GetBlockByHash(hash) block := dl.GetBlockByHash(hash)
if block == nil { if block == nil {
return false return false
@ -293,7 +292,7 @@ func (dl *downloadTester) CurrentFastBlock() *types.Block {
func (dl *downloadTester) FastSyncCommitHead(hash common.Hash) error { func (dl *downloadTester) FastSyncCommitHead(hash common.Hash) error {
// For now only check that the state trie is correct // For now only check that the state trie is correct
if block := dl.GetBlockByHash(hash); block != nil { if block := dl.GetBlockByHash(hash); block != nil {
_, err := trie.NewSecure(block.Root(), dl.stateDb, 0) _, err := trie.NewSecure(block.Root(), trie.NewDatabase(dl.stateDb), 0)
return err return err
} }
return fmt.Errorf("non existent block: %x", hash[:4]) return fmt.Errorf("non existent block: %x", hash[:4])
@ -619,28 +618,22 @@ func assertOwnChain(t *testing.T, tester *downloadTester, length int) {
// number of items of the various chain components. // number of items of the various chain components.
func assertOwnForkedChain(t *testing.T, tester *downloadTester, common int, lengths []int) { func assertOwnForkedChain(t *testing.T, tester *downloadTester, common int, lengths []int) {
// Initialize the counters for the first fork // Initialize the counters for the first fork
headers, blocks := lengths[0], lengths[0] headers, blocks, receipts := lengths[0], lengths[0], lengths[0]-fsMinFullBlocks
minReceipts, maxReceipts := lengths[0]-fsMinFullBlocks-fsPivotInterval, lengths[0]-fsMinFullBlocks if receipts < 0 {
if minReceipts < 0 { receipts = 1
minReceipts = 1
}
if maxReceipts < 0 {
maxReceipts = 1
} }
// Update the counters for each subsequent fork // Update the counters for each subsequent fork
for _, length := range lengths[1:] { for _, length := range lengths[1:] {
headers += length - common headers += length - common
blocks += length - common blocks += length - common
receipts += length - common - fsMinFullBlocks
minReceipts += length - common - fsMinFullBlocks - fsPivotInterval
maxReceipts += length - common - fsMinFullBlocks
} }
switch tester.downloader.mode { switch tester.downloader.mode {
case FullSync: case FullSync:
minReceipts, maxReceipts = 1, 1 receipts = 1
case LightSync: case LightSync:
blocks, minReceipts, maxReceipts = 1, 1, 1 blocks, receipts = 1, 1
} }
if hs := len(tester.ownHeaders); hs != headers { if hs := len(tester.ownHeaders); hs != headers {
t.Fatalf("synchronised headers mismatch: have %v, want %v", hs, headers) t.Fatalf("synchronised headers mismatch: have %v, want %v", hs, headers)
@ -648,11 +641,12 @@ func assertOwnForkedChain(t *testing.T, tester *downloadTester, common int, leng
if bs := len(tester.ownBlocks); bs != blocks { if bs := len(tester.ownBlocks); bs != blocks {
t.Fatalf("synchronised blocks mismatch: have %v, want %v", bs, blocks) t.Fatalf("synchronised blocks mismatch: have %v, want %v", bs, blocks)
} }
if rs := len(tester.ownReceipts); rs < minReceipts || rs > maxReceipts { if rs := len(tester.ownReceipts); rs != receipts {
t.Fatalf("synchronised receipts mismatch: have %v, want between [%v, %v]", rs, minReceipts, maxReceipts) t.Fatalf("synchronised receipts mismatch: have %v, want %v", rs, receipts)
} }
// Verify the state trie too for fast syncs // Verify the state trie too for fast syncs
if tester.downloader.mode == FastSync { /*if tester.downloader.mode == FastSync {
pivot := uint64(0)
var index int var index int
if pivot := int(tester.downloader.queue.fastSyncPivot); pivot < common { if pivot := int(tester.downloader.queue.fastSyncPivot); pivot < common {
index = pivot index = pivot
@ -660,11 +654,11 @@ func assertOwnForkedChain(t *testing.T, tester *downloadTester, common int, leng
index = len(tester.ownHashes) - lengths[len(lengths)-1] + int(tester.downloader.queue.fastSyncPivot) index = len(tester.ownHashes) - lengths[len(lengths)-1] + int(tester.downloader.queue.fastSyncPivot)
} }
if index > 0 { if index > 0 {
if statedb, err := state.New(tester.ownHeaders[tester.ownHashes[index]].Root, state.NewDatabase(tester.stateDb)); statedb == nil || err != nil { if statedb, err := state.New(tester.ownHeaders[tester.ownHashes[index]].Root, state.NewDatabase(trie.NewDatabase(tester.stateDb))); statedb == nil || err != nil {
t.Fatalf("state reconstruction failed: %v", err) t.Fatalf("state reconstruction failed: %v", err)
} }
} }
} }*/
} }
// Tests that simple synchronization against a canonical chain works correctly. // Tests that simple synchronization against a canonical chain works correctly.
@ -684,7 +678,7 @@ func testCanonicalSynchronisation(t *testing.T, protocol int, mode SyncMode) {
defer tester.terminate() defer tester.terminate()
// Create a small enough block chain to download // Create a small enough block chain to download
targetBlocks := blockCacheLimit - 15 targetBlocks := blockCacheItems - 15
hashes, headers, blocks, receipts := tester.makeChain(targetBlocks, 0, tester.genesis, nil, false) hashes, headers, blocks, receipts := tester.makeChain(targetBlocks, 0, tester.genesis, nil, false)
tester.newPeer("peer", protocol, hashes, headers, blocks, receipts) tester.newPeer("peer", protocol, hashes, headers, blocks, receipts)
@ -710,7 +704,7 @@ func testThrottling(t *testing.T, protocol int, mode SyncMode) {
defer tester.terminate() defer tester.terminate()
// Create a long block chain to download and the tester // Create a long block chain to download and the tester
targetBlocks := 8 * blockCacheLimit targetBlocks := 8 * blockCacheItems
hashes, headers, blocks, receipts := tester.makeChain(targetBlocks, 0, tester.genesis, nil, false) hashes, headers, blocks, receipts := tester.makeChain(targetBlocks, 0, tester.genesis, nil, false)
tester.newPeer("peer", protocol, hashes, headers, blocks, receipts) tester.newPeer("peer", protocol, hashes, headers, blocks, receipts)
@ -745,9 +739,9 @@ func testThrottling(t *testing.T, protocol int, mode SyncMode) {
cached = len(tester.downloader.queue.blockDonePool) cached = len(tester.downloader.queue.blockDonePool)
if mode == FastSync { if mode == FastSync {
if receipts := len(tester.downloader.queue.receiptDonePool); receipts < cached { if receipts := len(tester.downloader.queue.receiptDonePool); receipts < cached {
if tester.downloader.queue.resultCache[receipts].Header.Number.Uint64() < tester.downloader.queue.fastSyncPivot { //if tester.downloader.queue.resultCache[receipts].Header.Number.Uint64() < tester.downloader.queue.fastSyncPivot {
cached = receipts cached = receipts
} //}
} }
} }
frozen = int(atomic.LoadUint32(&blocked)) frozen = int(atomic.LoadUint32(&blocked))
@ -755,7 +749,7 @@ func testThrottling(t *testing.T, protocol int, mode SyncMode) {
tester.downloader.queue.lock.Unlock() tester.downloader.queue.lock.Unlock()
tester.lock.Unlock() tester.lock.Unlock()
if cached == blockCacheLimit || retrieved+cached+frozen == targetBlocks+1 { if cached == blockCacheItems || retrieved+cached+frozen == targetBlocks+1 {
break break
} }
} }
@ -765,8 +759,8 @@ func testThrottling(t *testing.T, protocol int, mode SyncMode) {
tester.lock.RLock() tester.lock.RLock()
retrieved = len(tester.ownBlocks) retrieved = len(tester.ownBlocks)
tester.lock.RUnlock() tester.lock.RUnlock()
if cached != blockCacheLimit && retrieved+cached+frozen != targetBlocks+1 { if cached != blockCacheItems && retrieved+cached+frozen != targetBlocks+1 {
t.Fatalf("block count mismatch: have %v, want %v (owned %v, blocked %v, target %v)", cached, blockCacheLimit, retrieved, frozen, targetBlocks+1) t.Fatalf("block count mismatch: have %v, want %v (owned %v, blocked %v, target %v)", cached, blockCacheItems, retrieved, frozen, targetBlocks+1)
} }
// Permit the blocked blocks to import // Permit the blocked blocks to import
if atomic.LoadUint32(&blocked) > 0 { if atomic.LoadUint32(&blocked) > 0 {
@ -974,7 +968,7 @@ func testCancel(t *testing.T, protocol int, mode SyncMode) {
defer tester.terminate() defer tester.terminate()
// Create a small enough block chain to download and the tester // Create a small enough block chain to download and the tester
targetBlocks := blockCacheLimit - 15 targetBlocks := blockCacheItems - 15
if targetBlocks >= MaxHashFetch { if targetBlocks >= MaxHashFetch {
targetBlocks = MaxHashFetch - 15 targetBlocks = MaxHashFetch - 15
} }
@ -1016,12 +1010,12 @@ func testMultiSynchronisation(t *testing.T, protocol int, mode SyncMode) {
// Create various peers with various parts of the chain // Create various peers with various parts of the chain
targetPeers := 8 targetPeers := 8
targetBlocks := targetPeers*blockCacheLimit - 15 targetBlocks := targetPeers*blockCacheItems - 15
hashes, headers, blocks, receipts := tester.makeChain(targetBlocks, 0, tester.genesis, nil, false) hashes, headers, blocks, receipts := tester.makeChain(targetBlocks, 0, tester.genesis, nil, false)
for i := 0; i < targetPeers; i++ { for i := 0; i < targetPeers; i++ {
id := fmt.Sprintf("peer #%d", i) id := fmt.Sprintf("peer #%d", i)
tester.newPeer(id, protocol, hashes[i*blockCacheLimit:], headers, blocks, receipts) tester.newPeer(id, protocol, hashes[i*blockCacheItems:], headers, blocks, receipts)
} }
if err := tester.sync("peer #0", nil, mode); err != nil { if err := tester.sync("peer #0", nil, mode); err != nil {
t.Fatalf("failed to synchronise blocks: %v", err) t.Fatalf("failed to synchronise blocks: %v", err)
@ -1045,7 +1039,7 @@ func testMultiProtoSync(t *testing.T, protocol int, mode SyncMode) {
defer tester.terminate() defer tester.terminate()
// Create a small enough block chain to download // Create a small enough block chain to download
targetBlocks := blockCacheLimit - 15 targetBlocks := blockCacheItems - 15
hashes, headers, blocks, receipts := tester.makeChain(targetBlocks, 0, tester.genesis, nil, false) hashes, headers, blocks, receipts := tester.makeChain(targetBlocks, 0, tester.genesis, nil, false)
// Create peers of every type // Create peers of every type
@ -1084,7 +1078,7 @@ func testEmptyShortCircuit(t *testing.T, protocol int, mode SyncMode) {
defer tester.terminate() defer tester.terminate()
// Create a block chain to download // Create a block chain to download
targetBlocks := 2*blockCacheLimit - 15 targetBlocks := 2*blockCacheItems - 15
hashes, headers, blocks, receipts := tester.makeChain(targetBlocks, 0, tester.genesis, nil, false) hashes, headers, blocks, receipts := tester.makeChain(targetBlocks, 0, tester.genesis, nil, false)
tester.newPeer("peer", protocol, hashes, headers, blocks, receipts) tester.newPeer("peer", protocol, hashes, headers, blocks, receipts)
@ -1110,8 +1104,8 @@ func testEmptyShortCircuit(t *testing.T, protocol int, mode SyncMode) {
bodiesNeeded++ bodiesNeeded++
} }
} }
for hash, receipt := range receipts { for _, receipt := range receipts {
if mode == FastSync && len(receipt) > 0 && headers[hash].Number.Uint64() <= tester.downloader.queue.fastSyncPivot { if mode == FastSync && len(receipt) > 0 {
receiptsNeeded++ receiptsNeeded++
} }
} }
@ -1139,7 +1133,7 @@ func testMissingHeaderAttack(t *testing.T, protocol int, mode SyncMode) {
defer tester.terminate() defer tester.terminate()
// Create a small enough block chain to download // Create a small enough block chain to download
targetBlocks := blockCacheLimit - 15 targetBlocks := blockCacheItems - 15
hashes, headers, blocks, receipts := tester.makeChain(targetBlocks, 0, tester.genesis, nil, false) hashes, headers, blocks, receipts := tester.makeChain(targetBlocks, 0, tester.genesis, nil, false)
// Attempt a full sync with an attacker feeding gapped headers // Attempt a full sync with an attacker feeding gapped headers
@ -1174,7 +1168,7 @@ func testShiftedHeaderAttack(t *testing.T, protocol int, mode SyncMode) {
defer tester.terminate() defer tester.terminate()
// Create a small enough block chain to download // Create a small enough block chain to download
targetBlocks := blockCacheLimit - 15 targetBlocks := blockCacheItems - 15
hashes, headers, blocks, receipts := tester.makeChain(targetBlocks, 0, tester.genesis, nil, false) hashes, headers, blocks, receipts := tester.makeChain(targetBlocks, 0, tester.genesis, nil, false)
// Attempt a full sync with an attacker feeding shifted headers // Attempt a full sync with an attacker feeding shifted headers
@ -1208,7 +1202,7 @@ func testInvalidHeaderRollback(t *testing.T, protocol int, mode SyncMode) {
defer tester.terminate() defer tester.terminate()
// Create a small enough block chain to download // Create a small enough block chain to download
targetBlocks := 3*fsHeaderSafetyNet + fsPivotInterval + fsMinFullBlocks targetBlocks := 3*fsHeaderSafetyNet + 256 + fsMinFullBlocks
hashes, headers, blocks, receipts := tester.makeChain(targetBlocks, 0, tester.genesis, nil, false) hashes, headers, blocks, receipts := tester.makeChain(targetBlocks, 0, tester.genesis, nil, false)
// Attempt to sync with an attacker that feeds junk during the fast sync phase. // Attempt to sync with an attacker that feeds junk during the fast sync phase.
@ -1248,7 +1242,6 @@ func testInvalidHeaderRollback(t *testing.T, protocol int, mode SyncMode) {
tester.newPeer("withhold-attack", protocol, hashes, headers, blocks, receipts) tester.newPeer("withhold-attack", protocol, hashes, headers, blocks, receipts)
missing = 3*fsHeaderSafetyNet + MaxHeaderFetch + 1 missing = 3*fsHeaderSafetyNet + MaxHeaderFetch + 1
tester.downloader.fsPivotFails = 0
tester.downloader.syncInitHook = func(uint64, uint64) { tester.downloader.syncInitHook = func(uint64, uint64) {
for i := missing; i <= len(hashes); i++ { for i := missing; i <= len(hashes); i++ {
delete(tester.peerHeaders["withhold-attack"], hashes[len(hashes)-i]) delete(tester.peerHeaders["withhold-attack"], hashes[len(hashes)-i])
@ -1267,8 +1260,6 @@ func testInvalidHeaderRollback(t *testing.T, protocol int, mode SyncMode) {
t.Errorf("fast sync pivot block #%d not rolled back", head) t.Errorf("fast sync pivot block #%d not rolled back", head)
} }
} }
tester.downloader.fsPivotFails = fsCriticalTrials
// Synchronise with the valid peer and make sure sync succeeds. Since the last // Synchronise with the valid peer and make sure sync succeeds. Since the last
// rollback should also disable fast syncing for this process, verify that we // rollback should also disable fast syncing for this process, verify that we
// did a fresh full sync. Note, we can't assert anything about the receipts // did a fresh full sync. Note, we can't assert anything about the receipts
@ -1383,7 +1374,7 @@ func testSyncProgress(t *testing.T, protocol int, mode SyncMode) {
defer tester.terminate() defer tester.terminate()
// Create a small enough block chain to download // Create a small enough block chain to download
targetBlocks := blockCacheLimit - 15 targetBlocks := blockCacheItems - 15
hashes, headers, blocks, receipts := tester.makeChain(targetBlocks, 0, tester.genesis, nil, false) hashes, headers, blocks, receipts := tester.makeChain(targetBlocks, 0, tester.genesis, nil, false)
// Set a sync init hook to catch progress changes // Set a sync init hook to catch progress changes
@ -1532,7 +1523,7 @@ func testFailedSyncProgress(t *testing.T, protocol int, mode SyncMode) {
defer tester.terminate() defer tester.terminate()
// Create a small enough block chain to download // Create a small enough block chain to download
targetBlocks := blockCacheLimit - 15 targetBlocks := blockCacheItems - 15
hashes, headers, blocks, receipts := tester.makeChain(targetBlocks, 0, tester.genesis, nil, false) hashes, headers, blocks, receipts := tester.makeChain(targetBlocks, 0, tester.genesis, nil, false)
// Set a sync init hook to catch progress changes // Set a sync init hook to catch progress changes
@ -1609,7 +1600,7 @@ func testFakedSyncProgress(t *testing.T, protocol int, mode SyncMode) {
defer tester.terminate() defer tester.terminate()
// Create a small block chain // Create a small block chain
targetBlocks := blockCacheLimit - 15 targetBlocks := blockCacheItems - 15
hashes, headers, blocks, receipts := tester.makeChain(targetBlocks+3, 0, tester.genesis, nil, false) hashes, headers, blocks, receipts := tester.makeChain(targetBlocks+3, 0, tester.genesis, nil, false)
// Set a sync init hook to catch progress changes // Set a sync init hook to catch progress changes
@ -1697,6 +1688,7 @@ func TestDeliverHeadersHang(t *testing.T) {
type floodingTestPeer struct { type floodingTestPeer struct {
peer Peer peer Peer
tester *downloadTester tester *downloadTester
pend sync.WaitGroup
} }
func (ftp *floodingTestPeer) Head() (common.Hash, *big.Int) { return ftp.peer.Head() } func (ftp *floodingTestPeer) Head() (common.Hash, *big.Int) { return ftp.peer.Head() }
@ -1717,9 +1709,12 @@ func (ftp *floodingTestPeer) RequestHeadersByNumber(from uint64, count, skip int
deliveriesDone := make(chan struct{}, 500) deliveriesDone := make(chan struct{}, 500)
for i := 0; i < cap(deliveriesDone); i++ { for i := 0; i < cap(deliveriesDone); i++ {
peer := fmt.Sprintf("fake-peer%d", i) peer := fmt.Sprintf("fake-peer%d", i)
ftp.pend.Add(1)
go func() { go func() {
ftp.tester.downloader.DeliverHeaders(peer, []*types.Header{{}, {}, {}, {}}) ftp.tester.downloader.DeliverHeaders(peer, []*types.Header{{}, {}, {}, {}})
deliveriesDone <- struct{}{} deliveriesDone <- struct{}{}
ftp.pend.Done()
}() }()
} }
// Deliver the actual requested headers. // Deliver the actual requested headers.
@ -1751,110 +1746,15 @@ func testDeliverHeadersHang(t *testing.T, protocol int, mode SyncMode) {
// Whenever the downloader requests headers, flood it with // Whenever the downloader requests headers, flood it with
// a lot of unrequested header deliveries. // a lot of unrequested header deliveries.
tester.downloader.peers.peers["peer"].peer = &floodingTestPeer{ tester.downloader.peers.peers["peer"].peer = &floodingTestPeer{
tester.downloader.peers.peers["peer"].peer, peer: tester.downloader.peers.peers["peer"].peer,
tester, tester: tester,
} }
if err := tester.sync("peer", nil, mode); err != nil { if err := tester.sync("peer", nil, mode); err != nil {
t.Errorf("sync failed: %v", err) t.Errorf("test %d: sync failed: %v", i, err)
} }
tester.terminate() tester.terminate()
}
}
// Tests that if fast sync aborts in the critical section, it can restart a few
// times before giving up.
// We use data driven subtests to manage this so that it will be parallel on its own
// and not with the other tests, avoiding intermittent failures.
func TestFastCriticalRestarts(t *testing.T) {
testCases := []struct {
protocol int
progress bool
}{
{63, false},
{64, false},
{63, true},
{64, true},
}
for _, tc := range testCases {
t.Run(fmt.Sprintf("protocol %d progress %v", tc.protocol, tc.progress), func(t *testing.T) {
testFastCriticalRestarts(t, tc.protocol, tc.progress)
})
}
}
func testFastCriticalRestarts(t *testing.T, protocol int, progress bool) {
t.Parallel()
tester := newTester()
defer tester.terminate()
// Create a large enough blockchin to actually fast sync on // Flush all goroutines to prevent messing with subsequent tests
targetBlocks := fsMinFullBlocks + 2*fsPivotInterval - 15 tester.downloader.peers.peers["peer"].peer.(*floodingTestPeer).pend.Wait()
hashes, headers, blocks, receipts := tester.makeChain(targetBlocks, 0, tester.genesis, nil, false)
// Create a tester peer with a critical section header missing (force failures)
tester.newPeer("peer", protocol, hashes, headers, blocks, receipts)
delete(tester.peerHeaders["peer"], hashes[fsMinFullBlocks-1])
tester.downloader.dropPeer = func(id string) {} // We reuse the same "faulty" peer throughout the test
// Remove all possible pivot state roots and slow down replies (test failure resets later)
for i := 0; i < fsPivotInterval; i++ {
tester.peerMissingStates["peer"][headers[hashes[fsMinFullBlocks+i]].Root] = true
}
(tester.downloader.peers.peers["peer"].peer).(*downloadTesterPeer).setDelay(500 * time.Millisecond) // Enough to reach the critical section
// Synchronise with the peer a few times and make sure they fail until the retry limit
for i := 0; i < int(fsCriticalTrials)-1; i++ {
// Attempt a sync and ensure it fails properly
if err := tester.sync("peer", nil, FastSync); err == nil {
t.Fatalf("failing fast sync succeeded: %v", err)
}
time.Sleep(150 * time.Millisecond) // Make sure no in-flight requests remain
// If it's the first failure, pivot should be locked => reenable all others to detect pivot changes
if i == 0 {
time.Sleep(150 * time.Millisecond) // Make sure no in-flight requests remain
if tester.downloader.fsPivotLock == nil {
time.Sleep(400 * time.Millisecond) // Make sure the first huge timeout expires too
t.Fatalf("pivot block not locked in after critical section failure")
}
tester.lock.Lock()
tester.peerHeaders["peer"][hashes[fsMinFullBlocks-1]] = headers[hashes[fsMinFullBlocks-1]]
tester.peerMissingStates["peer"] = map[common.Hash]bool{tester.downloader.fsPivotLock.Root: true}
(tester.downloader.peers.peers["peer"].peer).(*downloadTesterPeer).setDelay(0)
tester.lock.Unlock()
}
}
// Return all nodes if we're testing fast sync progression
if progress {
tester.lock.Lock()
tester.peerMissingStates["peer"] = map[common.Hash]bool{}
tester.lock.Unlock()
if err := tester.sync("peer", nil, FastSync); err != nil {
t.Fatalf("failed to synchronise blocks in progressed fast sync: %v", err)
}
time.Sleep(150 * time.Millisecond) // Make sure no in-flight requests remain
if fails := atomic.LoadUint32(&tester.downloader.fsPivotFails); fails != 1 {
t.Fatalf("progressed pivot trial count mismatch: have %v, want %v", fails, 1)
}
assertOwnChain(t, tester, targetBlocks+1)
} else {
if err := tester.sync("peer", nil, FastSync); err == nil {
t.Fatalf("succeeded to synchronise blocks in failed fast sync")
}
time.Sleep(150 * time.Millisecond) // Make sure no in-flight requests remain
if fails := atomic.LoadUint32(&tester.downloader.fsPivotFails); fails != fsCriticalTrials {
t.Fatalf("failed pivot trial count mismatch: have %v, want %v", fails, fsCriticalTrials)
}
}
// Retry limit exhausted, downloader will switch to full sync, should succeed
if err := tester.sync("peer", nil, FastSync); err != nil {
t.Fatalf("failed to synchronise blocks in slow sync: %v", err)
} }
// Note, we can't assert the chain here because the test asserter assumes sync
// completed using a single mode of operation, whereas fast-then-slow can result
// in arbitrary intermediate state that's not cleanly verifiable.
} }

@ -32,7 +32,11 @@ import (
"gopkg.in/karalabe/cookiejar.v2/collections/prque" "gopkg.in/karalabe/cookiejar.v2/collections/prque"
) )
var blockCacheLimit = 8192 // Maximum number of blocks to cache before throttling the download var (
blockCacheItems = 8192 // Maximum number of blocks to cache before throttling the download
blockCacheMemory = 64 * 1024 * 1024 // Maximum amount of memory to use for block caching
blockCacheSizeWeight = 0.1 // Multiplier to approximate the average block size based on past ones
)
var ( var (
errNoFetchesPending = errors.New("no fetches pending") errNoFetchesPending = errors.New("no fetches pending")
@ -43,7 +47,6 @@ var (
type fetchRequest struct { type fetchRequest struct {
Peer *peerConnection // Peer to which the request was sent Peer *peerConnection // Peer to which the request was sent
From uint64 // [eth/62] Requested chain element index (used for skeleton fills only) From uint64 // [eth/62] Requested chain element index (used for skeleton fills only)
Hashes map[common.Hash]int // [eth/61] Requested hashes with their insertion index (priority)
Headers []*types.Header // [eth/62] Requested headers, sorted by request order Headers []*types.Header // [eth/62] Requested headers, sorted by request order
Time time.Time // Time when the request was made Time time.Time // Time when the request was made
} }
@ -52,6 +55,7 @@ type fetchRequest struct {
// all outstanding pieces complete and the result as a whole can be processed. // all outstanding pieces complete and the result as a whole can be processed.
type fetchResult struct { type fetchResult struct {
Pending int // Number of data fetches still pending Pending int // Number of data fetches still pending
Hash common.Hash // Hash of the header to prevent recalculating
Header *types.Header Header *types.Header
Uncles []*types.Header Uncles []*types.Header
@ -62,11 +66,9 @@ type fetchResult struct {
// queue represents hashes that are either need fetching or are being fetched // queue represents hashes that are either need fetching or are being fetched
type queue struct { type queue struct {
mode SyncMode // Synchronisation mode to decide on the block parts to schedule for fetching mode SyncMode // Synchronisation mode to decide on the block parts to schedule for fetching
fastSyncPivot uint64 // Block number where the fast sync pivots into archive synchronisation mode
headerHead common.Hash // [eth/62] Hash of the last queued header to verify order
// Headers are "special", they download in batches, supported by a skeleton chain // Headers are "special", they download in batches, supported by a skeleton chain
headerHead common.Hash // [eth/62] Hash of the last queued header to verify order
headerTaskPool map[uint64]*types.Header // [eth/62] Pending header retrieval tasks, mapping starting indexes to skeleton headers headerTaskPool map[uint64]*types.Header // [eth/62] Pending header retrieval tasks, mapping starting indexes to skeleton headers
headerTaskQueue *prque.Prque // [eth/62] Priority queue of the skeleton indexes to fetch the filling headers for headerTaskQueue *prque.Prque // [eth/62] Priority queue of the skeleton indexes to fetch the filling headers for
headerPeerMiss map[string]map[uint64]struct{} // [eth/62] Set of per-peer header batches known to be unavailable headerPeerMiss map[string]map[uint64]struct{} // [eth/62] Set of per-peer header batches known to be unavailable
@ -89,6 +91,7 @@ type queue struct {
resultCache []*fetchResult // Downloaded but not yet delivered fetch results resultCache []*fetchResult // Downloaded but not yet delivered fetch results
resultOffset uint64 // Offset of the first cached fetch result in the block chain resultOffset uint64 // Offset of the first cached fetch result in the block chain
resultSize common.StorageSize // Approximate size of a block (exponential moving average)
lock *sync.Mutex lock *sync.Mutex
active *sync.Cond active *sync.Cond
@ -109,7 +112,7 @@ func newQueue() *queue {
receiptTaskQueue: prque.New(), receiptTaskQueue: prque.New(),
receiptPendPool: make(map[string]*fetchRequest), receiptPendPool: make(map[string]*fetchRequest),
receiptDonePool: make(map[common.Hash]struct{}), receiptDonePool: make(map[common.Hash]struct{}),
resultCache: make([]*fetchResult, blockCacheLimit), resultCache: make([]*fetchResult, blockCacheItems),
active: sync.NewCond(lock), active: sync.NewCond(lock),
lock: lock, lock: lock,
} }
@ -122,10 +125,8 @@ func (q *queue) Reset() {
q.closed = false q.closed = false
q.mode = FullSync q.mode = FullSync
q.fastSyncPivot = 0
q.headerHead = common.Hash{} q.headerHead = common.Hash{}
q.headerPendPool = make(map[string]*fetchRequest) q.headerPendPool = make(map[string]*fetchRequest)
q.blockTaskPool = make(map[common.Hash]*types.Header) q.blockTaskPool = make(map[common.Hash]*types.Header)
@ -138,7 +139,7 @@ func (q *queue) Reset() {
q.receiptPendPool = make(map[string]*fetchRequest) q.receiptPendPool = make(map[string]*fetchRequest)
q.receiptDonePool = make(map[common.Hash]struct{}) q.receiptDonePool = make(map[common.Hash]struct{})
q.resultCache = make([]*fetchResult, blockCacheLimit) q.resultCache = make([]*fetchResult, blockCacheItems)
q.resultOffset = 0 q.resultOffset = 0
} }
@ -214,27 +215,13 @@ func (q *queue) Idle() bool {
return (queued + pending + cached) == 0 return (queued + pending + cached) == 0
} }
// FastSyncPivot retrieves the currently used fast sync pivot point.
func (q *queue) FastSyncPivot() uint64 {
q.lock.Lock()
defer q.lock.Unlock()
return q.fastSyncPivot
}
// ShouldThrottleBlocks checks if the download should be throttled (active block (body) // ShouldThrottleBlocks checks if the download should be throttled (active block (body)
// fetches exceed block cache). // fetches exceed block cache).
func (q *queue) ShouldThrottleBlocks() bool { func (q *queue) ShouldThrottleBlocks() bool {
q.lock.Lock() q.lock.Lock()
defer q.lock.Unlock() defer q.lock.Unlock()
// Calculate the currently in-flight block (body) requests return q.resultSlots(q.blockPendPool, q.blockDonePool) <= 0
pending := 0
for _, request := range q.blockPendPool {
pending += len(request.Hashes) + len(request.Headers)
}
// Throttle if more blocks (bodies) are in-flight than free space in the cache
return pending >= len(q.resultCache)-len(q.blockDonePool)
} }
// ShouldThrottleReceipts checks if the download should be throttled (active receipt // ShouldThrottleReceipts checks if the download should be throttled (active receipt
@ -243,13 +230,39 @@ func (q *queue) ShouldThrottleReceipts() bool {
q.lock.Lock() q.lock.Lock()
defer q.lock.Unlock() defer q.lock.Unlock()
// Calculate the currently in-flight receipt requests return q.resultSlots(q.receiptPendPool, q.receiptDonePool) <= 0
}
// resultSlots calculates the number of results slots available for requests
// whilst adhering to both the item and the memory limit too of the results
// cache.
func (q *queue) resultSlots(pendPool map[string]*fetchRequest, donePool map[common.Hash]struct{}) int {
// Calculate the maximum length capped by the memory limit
limit := len(q.resultCache)
if common.StorageSize(len(q.resultCache))*q.resultSize > common.StorageSize(blockCacheMemory) {
limit = int((common.StorageSize(blockCacheMemory) + q.resultSize - 1) / q.resultSize)
}
// Calculate the number of slots already finished
finished := 0
for _, result := range q.resultCache[:limit] {
if result == nil {
break
}
if _, ok := donePool[result.Hash]; ok {
finished++
}
}
// Calculate the number of slots currently downloading
pending := 0 pending := 0
for _, request := range q.receiptPendPool { for _, request := range pendPool {
pending += len(request.Headers) for _, header := range request.Headers {
if header.Number.Uint64() < q.resultOffset+uint64(limit) {
pending++
}
} }
// Throttle if more receipts are in-flight than free space in the cache }
return pending >= len(q.resultCache)-len(q.receiptDonePool) // Return the free slots to distribute
return limit - finished - pending
} }
// ScheduleSkeleton adds a batch of header retrieval tasks to the queue to fill // ScheduleSkeleton adds a batch of header retrieval tasks to the queue to fill
@ -323,8 +336,7 @@ func (q *queue) Schedule(headers []*types.Header, from uint64) []*types.Header {
q.blockTaskPool[hash] = header q.blockTaskPool[hash] = header
q.blockTaskQueue.Push(header, -float32(header.Number.Uint64())) q.blockTaskQueue.Push(header, -float32(header.Number.Uint64()))
if q.mode == FastSync && header.Number.Uint64() <= q.fastSyncPivot { if q.mode == FastSync {
// Fast phase of the fast sync, retrieve receipts too
q.receiptTaskPool[hash] = header q.receiptTaskPool[hash] = header
q.receiptTaskQueue.Push(header, -float32(header.Number.Uint64())) q.receiptTaskQueue.Push(header, -float32(header.Number.Uint64()))
} }
@ -335,18 +347,25 @@ func (q *queue) Schedule(headers []*types.Header, from uint64) []*types.Header {
return inserts return inserts
} }
// WaitResults retrieves and permanently removes a batch of fetch // Results retrieves and permanently removes a batch of fetch results from
// results from the cache. the result slice will be empty if the queue // the cache. the result slice will be empty if the queue has been closed.
// has been closed. func (q *queue) Results(block bool) []*fetchResult {
func (q *queue) WaitResults() []*fetchResult {
q.lock.Lock() q.lock.Lock()
defer q.lock.Unlock() defer q.lock.Unlock()
// Count the number of items available for processing
nproc := q.countProcessableItems() nproc := q.countProcessableItems()
for nproc == 0 && !q.closed { for nproc == 0 && !q.closed {
if !block {
return nil
}
q.active.Wait() q.active.Wait()
nproc = q.countProcessableItems() nproc = q.countProcessableItems()
} }
// Since we have a batch limit, don't pull more into "dangling" memory
if nproc > maxResultsProcess {
nproc = maxResultsProcess
}
results := make([]*fetchResult, nproc) results := make([]*fetchResult, nproc)
copy(results, q.resultCache[:nproc]) copy(results, q.resultCache[:nproc])
if len(results) > 0 { if len(results) > 0 {
@ -363,6 +382,21 @@ func (q *queue) WaitResults() []*fetchResult {
} }
// Advance the expected block number of the first cache entry. // Advance the expected block number of the first cache entry.
q.resultOffset += uint64(nproc) q.resultOffset += uint64(nproc)
// Recalculate the result item weights to prevent memory exhaustion
for _, result := range results {
size := result.Header.Size()
for _, uncle := range result.Uncles {
size += uncle.Size()
}
for _, receipt := range result.Receipts {
size += receipt.Size()
}
for _, tx := range result.Transactions {
size += tx.Size()
}
q.resultSize = common.StorageSize(blockCacheSizeWeight)*size + (1-common.StorageSize(blockCacheSizeWeight))*q.resultSize
}
} }
return results return results
} }
@ -370,21 +404,9 @@ func (q *queue) WaitResults() []*fetchResult {
// countProcessableItems counts the processable items. // countProcessableItems counts the processable items.
func (q *queue) countProcessableItems() int { func (q *queue) countProcessableItems() int {
for i, result := range q.resultCache { for i, result := range q.resultCache {
// Don't process incomplete or unavailable items.
if result == nil || result.Pending > 0 { if result == nil || result.Pending > 0 {
return i return i
} }
// Stop before processing the pivot block to ensure that
// resultCache has space for fsHeaderForceVerify items. Not
// doing this could leave us unable to download the required
// amount of headers.
if q.mode == FastSync && result.Header.Number.Uint64() == q.fastSyncPivot {
for j := 0; j < fsHeaderForceVerify; j++ {
if i+j+1 >= len(q.resultCache) || q.resultCache[i+j+1] == nil {
return i
}
}
}
} }
return len(q.resultCache) return len(q.resultCache)
} }
@ -473,10 +495,8 @@ func (q *queue) reserveHeaders(p *peerConnection, count int, taskPool map[common
return nil, false, nil return nil, false, nil
} }
// Calculate an upper limit on the items we might fetch (i.e. throttling) // Calculate an upper limit on the items we might fetch (i.e. throttling)
space := len(q.resultCache) - len(donePool) space := q.resultSlots(pendPool, donePool)
for _, request := range pendPool {
space -= len(request.Headers)
}
// Retrieve a batch of tasks, skipping previously failed ones // Retrieve a batch of tasks, skipping previously failed ones
send := make([]*types.Header, 0, count) send := make([]*types.Header, 0, count)
skip := make([]*types.Header, 0) skip := make([]*types.Header, 0)
@ -484,6 +504,7 @@ func (q *queue) reserveHeaders(p *peerConnection, count int, taskPool map[common
progress := false progress := false
for proc := 0; proc < space && len(send) < count && !taskQueue.Empty(); proc++ { for proc := 0; proc < space && len(send) < count && !taskQueue.Empty(); proc++ {
header := taskQueue.PopItem().(*types.Header) header := taskQueue.PopItem().(*types.Header)
hash := header.Hash()
// If we're the first to request this task, initialise the result container // If we're the first to request this task, initialise the result container
index := int(header.Number.Int64() - int64(q.resultOffset)) index := int(header.Number.Int64() - int64(q.resultOffset))
@ -493,18 +514,19 @@ func (q *queue) reserveHeaders(p *peerConnection, count int, taskPool map[common
} }
if q.resultCache[index] == nil { if q.resultCache[index] == nil {
components := 1 components := 1
if q.mode == FastSync && header.Number.Uint64() <= q.fastSyncPivot { if q.mode == FastSync {
components = 2 components = 2
} }
q.resultCache[index] = &fetchResult{ q.resultCache[index] = &fetchResult{
Pending: components, Pending: components,
Hash: hash,
Header: header, Header: header,
} }
} }
// If this fetch task is a noop, skip this fetch operation // If this fetch task is a noop, skip this fetch operation
if isNoop(header) { if isNoop(header) {
donePool[header.Hash()] = struct{}{} donePool[hash] = struct{}{}
delete(taskPool, header.Hash()) delete(taskPool, hash)
space, proc = space-1, proc-1 space, proc = space-1, proc-1
q.resultCache[index].Pending-- q.resultCache[index].Pending--
@ -512,7 +534,7 @@ func (q *queue) reserveHeaders(p *peerConnection, count int, taskPool map[common
continue continue
} }
// Otherwise unless the peer is known not to have the data, add to the retrieve list // Otherwise unless the peer is known not to have the data, add to the retrieve list
if p.Lacks(header.Hash()) { if p.Lacks(hash) {
skip = append(skip, header) skip = append(skip, header)
} else { } else {
send = append(send, header) send = append(send, header)
@ -565,9 +587,6 @@ func (q *queue) cancel(request *fetchRequest, taskQueue *prque.Prque, pendPool m
if request.From > 0 { if request.From > 0 {
taskQueue.Push(request.From, -float32(request.From)) taskQueue.Push(request.From, -float32(request.From))
} }
for hash, index := range request.Hashes {
taskQueue.Push(hash, float32(index))
}
for _, header := range request.Headers { for _, header := range request.Headers {
taskQueue.Push(header, -float32(header.Number.Uint64())) taskQueue.Push(header, -float32(header.Number.Uint64()))
} }
@ -640,18 +659,11 @@ func (q *queue) expire(timeout time.Duration, pendPool map[string]*fetchRequest,
if request.From > 0 { if request.From > 0 {
taskQueue.Push(request.From, -float32(request.From)) taskQueue.Push(request.From, -float32(request.From))
} }
for hash, index := range request.Hashes {
taskQueue.Push(hash, float32(index))
}
for _, header := range request.Headers { for _, header := range request.Headers {
taskQueue.Push(header, -float32(header.Number.Uint64())) taskQueue.Push(header, -float32(header.Number.Uint64()))
} }
// Add the peer to the expiry report along the the number of failed requests // Add the peer to the expiry report along the the number of failed requests
expirations := len(request.Hashes) expiries[id] = len(request.Headers)
if expirations < len(request.Headers) {
expirations = len(request.Headers)
}
expiries[id] = expirations
} }
} }
// Remove the expired requests from the pending pool // Remove the expired requests from the pending pool
@ -828,14 +840,16 @@ func (q *queue) deliver(id string, taskPool map[common.Hash]*types.Header, taskQ
failure = err failure = err
break break
} }
donePool[header.Hash()] = struct{}{} hash := header.Hash()
donePool[hash] = struct{}{}
q.resultCache[index].Pending-- q.resultCache[index].Pending--
useful = true useful = true
accepted++ accepted++
// Clean up a successful fetch // Clean up a successful fetch
request.Headers[i] = nil request.Headers[i] = nil
delete(taskPool, header.Hash()) delete(taskPool, hash)
} }
// Return all failed or missing fetches to the queue // Return all failed or missing fetches to the queue
for _, header := range request.Headers { for _, header := range request.Headers {
@ -860,7 +874,7 @@ func (q *queue) deliver(id string, taskPool map[common.Hash]*types.Header, taskQ
// Prepare configures the result cache to allow accepting and caching inbound // Prepare configures the result cache to allow accepting and caching inbound
// fetch results. // fetch results.
func (q *queue) Prepare(offset uint64, mode SyncMode, pivot uint64, head *types.Header) { func (q *queue) Prepare(offset uint64, mode SyncMode) {
q.lock.Lock() q.lock.Lock()
defer q.lock.Unlock() defer q.lock.Unlock()
@ -868,6 +882,5 @@ func (q *queue) Prepare(offset uint64, mode SyncMode, pivot uint64, head *types.
if q.resultOffset < offset { if q.resultOffset < offset {
q.resultOffset = offset q.resultOffset = offset
} }
q.fastSyncPivot = pivot
q.mode = mode q.mode = mode
} }

@ -20,7 +20,6 @@ import (
"fmt" "fmt"
"hash" "hash"
"sync" "sync"
"sync/atomic"
"time" "time"
"github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/common"
@ -294,6 +293,9 @@ func (s *stateSync) loop() error {
case <-s.cancel: case <-s.cancel:
return errCancelStateFetch return errCancelStateFetch
case <-s.d.cancelCh:
return errCancelStateFetch
case req := <-s.deliver: case req := <-s.deliver:
// Response, disconnect or timeout triggered, drop the peer if stalling // Response, disconnect or timeout triggered, drop the peer if stalling
log.Trace("Received node data response", "peer", req.peer.id, "count", len(req.response), "dropped", req.dropped, "timeout", !req.dropped && req.timedOut()) log.Trace("Received node data response", "peer", req.peer.id, "count", len(req.response), "dropped", req.dropped, "timeout", !req.dropped && req.timedOut())
@ -304,17 +306,13 @@ func (s *stateSync) loop() error {
s.d.dropPeer(req.peer.id) s.d.dropPeer(req.peer.id)
} }
// Process all the received blobs and check for stale delivery // Process all the received blobs and check for stale delivery
stale, err := s.process(req) if err := s.process(req); err != nil {
if err != nil {
log.Warn("Node data write error", "err", err) log.Warn("Node data write error", "err", err)
return err return err
} }
// The the delivery contains requested data, mark the node idle (otherwise it's a timed out delivery)
if !stale {
req.peer.SetNodeDataIdle(len(req.response)) req.peer.SetNodeDataIdle(len(req.response))
} }
} }
}
return s.commit(true) return s.commit(true)
} }
@ -352,6 +350,7 @@ func (s *stateSync) assignTasks() {
case s.d.trackStateReq <- req: case s.d.trackStateReq <- req:
req.peer.FetchNodeData(req.items) req.peer.FetchNodeData(req.items)
case <-s.cancel: case <-s.cancel:
case <-s.d.cancelCh:
} }
} }
} }
@ -390,7 +389,7 @@ func (s *stateSync) fillTasks(n int, req *stateReq) {
// process iterates over a batch of delivered state data, injecting each item // process iterates over a batch of delivered state data, injecting each item
// into a running state sync, re-queuing any items that were requested but not // into a running state sync, re-queuing any items that were requested but not
// delivered. // delivered.
func (s *stateSync) process(req *stateReq) (bool, error) { func (s *stateSync) process(req *stateReq) error {
// Collect processing stats and update progress if valid data was received // Collect processing stats and update progress if valid data was received
duplicate, unexpected := 0, 0 duplicate, unexpected := 0, 0
@ -401,7 +400,7 @@ func (s *stateSync) process(req *stateReq) (bool, error) {
}(time.Now()) }(time.Now())
// Iterate over all the delivered data and inject one-by-one into the trie // Iterate over all the delivered data and inject one-by-one into the trie
progress, stale := false, len(req.response) > 0 progress := false
for _, blob := range req.response { for _, blob := range req.response {
prog, hash, err := s.processNodeData(blob) prog, hash, err := s.processNodeData(blob)
@ -415,20 +414,12 @@ func (s *stateSync) process(req *stateReq) (bool, error) {
case trie.ErrAlreadyProcessed: case trie.ErrAlreadyProcessed:
duplicate++ duplicate++
default: default:
return stale, fmt.Errorf("invalid state node %s: %v", hash.TerminalString(), err) return fmt.Errorf("invalid state node %s: %v", hash.TerminalString(), err)
} }
// If the node delivered a requested item, mark the delivery non-stale
if _, ok := req.tasks[hash]; ok { if _, ok := req.tasks[hash]; ok {
delete(req.tasks, hash) delete(req.tasks, hash)
stale = false
} }
} }
// If we're inside the critical section, reset fail counter since we progressed.
if progress && atomic.LoadUint32(&s.d.fsPivotFails) > 1 {
log.Trace("Fast-sync progressed, resetting fail counter", "previous", atomic.LoadUint32(&s.d.fsPivotFails))
atomic.StoreUint32(&s.d.fsPivotFails, 1) // Don't ever reset to 0, as that will unlock the pivot block
}
// Put unfulfilled tasks back into the retry queue // Put unfulfilled tasks back into the retry queue
npeers := s.d.peers.Len() npeers := s.d.peers.Len()
for hash, task := range req.tasks { for hash, task := range req.tasks {
@ -441,12 +432,12 @@ func (s *stateSync) process(req *stateReq) (bool, error) {
// If we've requested the node too many times already, it may be a malicious // If we've requested the node too many times already, it may be a malicious
// sync where nobody has the right data. Abort. // sync where nobody has the right data. Abort.
if len(task.attempts) >= npeers { if len(task.attempts) >= npeers {
return stale, fmt.Errorf("state node %s failed with all peers (%d tries, %d peers)", hash.TerminalString(), len(task.attempts), npeers) return fmt.Errorf("state node %s failed with all peers (%d tries, %d peers)", hash.TerminalString(), len(task.attempts), npeers)
} }
// Missing item, place into the retry queue. // Missing item, place into the retry queue.
s.tasks[hash] = task s.tasks[hash] = task
} }
return stale, nil return nil
} }
// processNodeData tries to inject a trie node data blob delivered from a remote // processNodeData tries to inject a trie node data blob delivered from a remote

@ -71,7 +71,6 @@ type ProtocolManager struct {
txpool txPool txpool txPool
blockchain *core.BlockChain blockchain *core.BlockChain
chaindb ethdb.Database
chainconfig *params.ChainConfig chainconfig *params.ChainConfig
maxPeers int maxPeers int
@ -106,7 +105,6 @@ func NewProtocolManager(config *params.ChainConfig, mode downloader.SyncMode, ne
eventMux: mux, eventMux: mux,
txpool: txpool, txpool: txpool,
blockchain: blockchain, blockchain: blockchain,
chaindb: chaindb,
chainconfig: config, chainconfig: config,
peers: newPeerSet(), peers: newPeerSet(),
newPeerCh: make(chan *peer), newPeerCh: make(chan *peer),
@ -538,7 +536,7 @@ func (pm *ProtocolManager) handleMsg(p *peer) error {
return errResp(ErrDecode, "msg %v: %v", msg, err) return errResp(ErrDecode, "msg %v: %v", msg, err)
} }
// Retrieve the requested state entry, stopping if enough was found // Retrieve the requested state entry, stopping if enough was found
if entry, err := pm.chaindb.Get(hash.Bytes()); err == nil { if entry, err := pm.blockchain.TrieNode(hash); err == nil {
data = append(data, entry) data = append(data, entry)
bytes += len(entry) bytes += len(entry)
} }
@ -576,7 +574,7 @@ func (pm *ProtocolManager) handleMsg(p *peer) error {
return errResp(ErrDecode, "msg %v: %v", msg, err) return errResp(ErrDecode, "msg %v: %v", msg, err)
} }
// Retrieve the requested block's receipts, skipping if unknown to us // Retrieve the requested block's receipts, skipping if unknown to us
results := core.GetBlockReceipts(pm.chaindb, hash, core.GetBlockNumber(pm.chaindb, hash)) results := pm.blockchain.GetReceiptsByHash(hash)
if results == nil { if results == nil {
if header := pm.blockchain.GetHeaderByHash(hash); header == nil || header.ReceiptHash != types.EmptyRootHash { if header := pm.blockchain.GetHeaderByHash(hash); header == nil || header.ReceiptHash != types.EmptyRootHash {
continue continue

@ -56,7 +56,7 @@ func TestProtocolCompatibility(t *testing.T) {
for i, tt := range tests { for i, tt := range tests {
ProtocolVersions = []uint{tt.version} ProtocolVersions = []uint{tt.version}
pm, err := newTestProtocolManager(tt.mode, 0, nil, nil) pm, _, err := newTestProtocolManager(tt.mode, 0, nil, nil)
if pm != nil { if pm != nil {
defer pm.Stop() defer pm.Stop()
} }
@ -71,7 +71,7 @@ func TestGetBlockHeaders62(t *testing.T) { testGetBlockHeaders(t, 62) }
func TestGetBlockHeaders63(t *testing.T) { testGetBlockHeaders(t, 63) } func TestGetBlockHeaders63(t *testing.T) { testGetBlockHeaders(t, 63) }
func testGetBlockHeaders(t *testing.T, protocol int) { func testGetBlockHeaders(t *testing.T, protocol int) {
pm := newTestProtocolManagerMust(t, downloader.FullSync, downloader.MaxHashFetch+15, nil, nil) pm, _ := newTestProtocolManagerMust(t, downloader.FullSync, downloader.MaxHashFetch+15, nil, nil)
peer, _ := newTestPeer("peer", protocol, pm, true) peer, _ := newTestPeer("peer", protocol, pm, true)
defer peer.close() defer peer.close()
@ -230,7 +230,7 @@ func TestGetBlockBodies62(t *testing.T) { testGetBlockBodies(t, 62) }
func TestGetBlockBodies63(t *testing.T) { testGetBlockBodies(t, 63) } func TestGetBlockBodies63(t *testing.T) { testGetBlockBodies(t, 63) }
func testGetBlockBodies(t *testing.T, protocol int) { func testGetBlockBodies(t *testing.T, protocol int) {
pm := newTestProtocolManagerMust(t, downloader.FullSync, downloader.MaxBlockFetch+15, nil, nil) pm, _ := newTestProtocolManagerMust(t, downloader.FullSync, downloader.MaxBlockFetch+15, nil, nil)
peer, _ := newTestPeer("peer", protocol, pm, true) peer, _ := newTestPeer("peer", protocol, pm, true)
defer peer.close() defer peer.close()
@ -337,13 +337,13 @@ func testGetNodeData(t *testing.T, protocol int) {
} }
} }
// Assemble the test environment // Assemble the test environment
pm := newTestProtocolManagerMust(t, downloader.FullSync, 4, generator, nil) pm, db := newTestProtocolManagerMust(t, downloader.FullSync, 4, generator, nil)
peer, _ := newTestPeer("peer", protocol, pm, true) peer, _ := newTestPeer("peer", protocol, pm, true)
defer peer.close() defer peer.close()
// Fetch for now the entire chain db // Fetch for now the entire chain db
hashes := []common.Hash{} hashes := []common.Hash{}
for _, key := range pm.chaindb.(*ethdb.MemDatabase).Keys() { for _, key := range db.Keys() {
if len(key) == len(common.Hash{}) { if len(key) == len(common.Hash{}) {
hashes = append(hashes, common.BytesToHash(key)) hashes = append(hashes, common.BytesToHash(key))
} }
@ -429,7 +429,7 @@ func testGetReceipt(t *testing.T, protocol int) {
} }
} }
// Assemble the test environment // Assemble the test environment
pm := newTestProtocolManagerMust(t, downloader.FullSync, 4, generator, nil) pm, _ := newTestProtocolManagerMust(t, downloader.FullSync, 4, generator, nil)
peer, _ := newTestPeer("peer", protocol, pm, true) peer, _ := newTestPeer("peer", protocol, pm, true)
defer peer.close() defer peer.close()
@ -439,7 +439,7 @@ func testGetReceipt(t *testing.T, protocol int) {
block := pm.blockchain.GetBlockByNumber(i) block := pm.blockchain.GetBlockByNumber(i)
hashes = append(hashes, block.Hash()) hashes = append(hashes, block.Hash())
receipts = append(receipts, core.GetBlockReceipts(pm.chaindb, block.Hash(), block.NumberU64())) receipts = append(receipts, pm.blockchain.GetReceiptsByHash(block.Hash()))
} }
// Send the hash request and verify the response // Send the hash request and verify the response
p2p.Send(peer.app, 0x0f, hashes) p2p.Send(peer.app, 0x0f, hashes)
@ -472,7 +472,7 @@ func testDAOChallenge(t *testing.T, localForked, remoteForked bool, timeout bool
config = &params.ChainConfig{DAOForkBlock: big.NewInt(1), DAOForkSupport: localForked} config = &params.ChainConfig{DAOForkBlock: big.NewInt(1), DAOForkSupport: localForked}
gspec = &core.Genesis{Config: config} gspec = &core.Genesis{Config: config}
genesis = gspec.MustCommit(db) genesis = gspec.MustCommit(db)
blockchain, _ = core.NewBlockChain(db, config, pow, vm.Config{}) blockchain, _ = core.NewBlockChain(db, nil, config, pow, vm.Config{})
) )
pm, err := NewProtocolManager(config, downloader.FullSync, DefaultConfig.NetworkId, evmux, new(testTxPool), pow, blockchain, db) pm, err := NewProtocolManager(config, downloader.FullSync, DefaultConfig.NetworkId, evmux, new(testTxPool), pow, blockchain, db)
if err != nil { if err != nil {

@ -49,7 +49,7 @@ var (
// newTestProtocolManager creates a new protocol manager for testing purposes, // newTestProtocolManager creates a new protocol manager for testing purposes,
// with the given number of blocks already known, and potential notification // with the given number of blocks already known, and potential notification
// channels for different events. // channels for different events.
func newTestProtocolManager(mode downloader.SyncMode, blocks int, generator func(int, *core.BlockGen), newtx chan<- []*types.Transaction) (*ProtocolManager, error) { func newTestProtocolManager(mode downloader.SyncMode, blocks int, generator func(int, *core.BlockGen), newtx chan<- []*types.Transaction) (*ProtocolManager, *ethdb.MemDatabase, error) {
var ( var (
evmux = new(event.TypeMux) evmux = new(event.TypeMux)
engine = ethash.NewFaker() engine = ethash.NewFaker()
@ -59,7 +59,7 @@ func newTestProtocolManager(mode downloader.SyncMode, blocks int, generator func
Alloc: core.GenesisAlloc{testBank: {Balance: big.NewInt(1000000)}}, Alloc: core.GenesisAlloc{testBank: {Balance: big.NewInt(1000000)}},
} }
genesis = gspec.MustCommit(db) genesis = gspec.MustCommit(db)
blockchain, _ = core.NewBlockChain(db, gspec.Config, engine, vm.Config{}) blockchain, _ = core.NewBlockChain(db, nil, gspec.Config, engine, vm.Config{})
) )
chain, _ := core.GenerateChain(gspec.Config, genesis, ethash.NewFaker(), db, blocks, generator) chain, _ := core.GenerateChain(gspec.Config, genesis, ethash.NewFaker(), db, blocks, generator)
if _, err := blockchain.InsertChain(chain); err != nil { if _, err := blockchain.InsertChain(chain); err != nil {
@ -68,22 +68,22 @@ func newTestProtocolManager(mode downloader.SyncMode, blocks int, generator func
pm, err := NewProtocolManager(gspec.Config, mode, DefaultConfig.NetworkId, evmux, &testTxPool{added: newtx}, engine, blockchain, db) pm, err := NewProtocolManager(gspec.Config, mode, DefaultConfig.NetworkId, evmux, &testTxPool{added: newtx}, engine, blockchain, db)
if err != nil { if err != nil {
return nil, err return nil, nil, err
} }
pm.Start(1000) pm.Start(1000)
return pm, nil return pm, db, nil
} }
// newTestProtocolManagerMust creates a new protocol manager for testing purposes, // newTestProtocolManagerMust creates a new protocol manager for testing purposes,
// with the given number of blocks already known, and potential notification // with the given number of blocks already known, and potential notification
// channels for different events. In case of an error, the constructor force- // channels for different events. In case of an error, the constructor force-
// fails the test. // fails the test.
func newTestProtocolManagerMust(t *testing.T, mode downloader.SyncMode, blocks int, generator func(int, *core.BlockGen), newtx chan<- []*types.Transaction) *ProtocolManager { func newTestProtocolManagerMust(t *testing.T, mode downloader.SyncMode, blocks int, generator func(int, *core.BlockGen), newtx chan<- []*types.Transaction) (*ProtocolManager, *ethdb.MemDatabase) {
pm, err := newTestProtocolManager(mode, blocks, generator, newtx) pm, db, err := newTestProtocolManager(mode, blocks, generator, newtx)
if err != nil { if err != nil {
t.Fatalf("Failed to create protocol manager: %v", err) t.Fatalf("Failed to create protocol manager: %v", err)
} }
return pm return pm, db
} }
// testTxPool is a fake, helper transaction pool for testing purposes // testTxPool is a fake, helper transaction pool for testing purposes

@ -41,7 +41,7 @@ func TestStatusMsgErrors62(t *testing.T) { testStatusMsgErrors(t, 62) }
func TestStatusMsgErrors63(t *testing.T) { testStatusMsgErrors(t, 63) } func TestStatusMsgErrors63(t *testing.T) { testStatusMsgErrors(t, 63) }
func testStatusMsgErrors(t *testing.T, protocol int) { func testStatusMsgErrors(t *testing.T, protocol int) {
pm := newTestProtocolManagerMust(t, downloader.FullSync, 0, nil, nil) pm, _ := newTestProtocolManagerMust(t, downloader.FullSync, 0, nil, nil)
var ( var (
genesis = pm.blockchain.Genesis() genesis = pm.blockchain.Genesis()
head = pm.blockchain.CurrentHeader() head = pm.blockchain.CurrentHeader()
@ -98,7 +98,7 @@ func TestRecvTransactions63(t *testing.T) { testRecvTransactions(t, 63) }
func testRecvTransactions(t *testing.T, protocol int) { func testRecvTransactions(t *testing.T, protocol int) {
txAdded := make(chan []*types.Transaction) txAdded := make(chan []*types.Transaction)
pm := newTestProtocolManagerMust(t, downloader.FullSync, 0, nil, txAdded) pm, _ := newTestProtocolManagerMust(t, downloader.FullSync, 0, nil, txAdded)
pm.acceptTxs = 1 // mark synced to accept transactions pm.acceptTxs = 1 // mark synced to accept transactions
p, _ := newTestPeer("peer", protocol, pm, true) p, _ := newTestPeer("peer", protocol, pm, true)
defer pm.Stop() defer pm.Stop()
@ -125,7 +125,7 @@ func TestSendTransactions62(t *testing.T) { testSendTransactions(t, 62) }
func TestSendTransactions63(t *testing.T) { testSendTransactions(t, 63) } func TestSendTransactions63(t *testing.T) { testSendTransactions(t, 63) }
func testSendTransactions(t *testing.T, protocol int) { func testSendTransactions(t *testing.T, protocol int) {
pm := newTestProtocolManagerMust(t, downloader.FullSync, 0, nil, nil) pm, _ := newTestProtocolManagerMust(t, downloader.FullSync, 0, nil, nil)
defer pm.Stop() defer pm.Stop()
// Fill the pool with big transactions. // Fill the pool with big transactions.

@ -30,12 +30,12 @@ import (
// imported into the blockchain. // imported into the blockchain.
func TestFastSyncDisabling(t *testing.T) { func TestFastSyncDisabling(t *testing.T) {
// Create a pristine protocol manager, check that fast sync is left enabled // Create a pristine protocol manager, check that fast sync is left enabled
pmEmpty := newTestProtocolManagerMust(t, downloader.FastSync, 0, nil, nil) pmEmpty, _ := newTestProtocolManagerMust(t, downloader.FastSync, 0, nil, nil)
if atomic.LoadUint32(&pmEmpty.fastSync) == 0 { if atomic.LoadUint32(&pmEmpty.fastSync) == 0 {
t.Fatalf("fast sync disabled on pristine blockchain") t.Fatalf("fast sync disabled on pristine blockchain")
} }
// Create a full protocol manager, check that fast sync gets disabled // Create a full protocol manager, check that fast sync gets disabled
pmFull := newTestProtocolManagerMust(t, downloader.FastSync, 1024, nil, nil) pmFull, _ := newTestProtocolManagerMust(t, downloader.FastSync, 1024, nil, nil)
if atomic.LoadUint32(&pmFull.fastSync) == 1 { if atomic.LoadUint32(&pmFull.fastSync) == 1 {
t.Fatalf("fast sync not disabled on non-empty blockchain") t.Fatalf("fast sync not disabled on non-empty blockchain")
} }

@ -808,7 +808,7 @@ func (s *PublicBlockChainAPI) rpcOutputBlock(b *types.Block, inclTx bool, fullTx
"difficulty": (*hexutil.Big)(head.Difficulty), "difficulty": (*hexutil.Big)(head.Difficulty),
"totalDifficulty": (*hexutil.Big)(s.b.GetTd(b.Hash())), "totalDifficulty": (*hexutil.Big)(s.b.GetTd(b.Hash())),
"extraData": hexutil.Bytes(head.Extra), "extraData": hexutil.Bytes(head.Extra),
"size": hexutil.Uint64(uint64(b.Size().Int64())), "size": hexutil.Uint64(b.Size()),
"gasLimit": hexutil.Uint64(head.GasLimit), "gasLimit": hexutil.Uint64(head.GasLimit),
"gasUsed": hexutil.Uint64(head.GasUsed), "gasUsed": hexutil.Uint64(head.GasUsed),
"timestamp": (*hexutil.Big)(head.Time), "timestamp": (*hexutil.Big)(head.Time),

@ -18,7 +18,6 @@
package les package les
import ( import (
"bytes"
"encoding/binary" "encoding/binary"
"errors" "errors"
"fmt" "fmt"
@ -78,6 +77,7 @@ type BlockChain interface {
GetHeaderByHash(hash common.Hash) *types.Header GetHeaderByHash(hash common.Hash) *types.Header
CurrentHeader() *types.Header CurrentHeader() *types.Header
GetTd(hash common.Hash, number uint64) *big.Int GetTd(hash common.Hash, number uint64) *big.Int
State() (*state.StateDB, error)
InsertHeaderChain(chain []*types.Header, checkFreq int) (int, error) InsertHeaderChain(chain []*types.Header, checkFreq int) (int, error)
Rollback(chain []common.Hash) Rollback(chain []common.Hash)
GetHeaderByNumber(number uint64) *types.Header GetHeaderByNumber(number uint64) *types.Header
@ -579,17 +579,19 @@ func (pm *ProtocolManager) handleMsg(p *peer) error {
for _, req := range req.Reqs { for _, req := range req.Reqs {
// Retrieve the requested state entry, stopping if enough was found // Retrieve the requested state entry, stopping if enough was found
if header := core.GetHeader(pm.chainDb, req.BHash, core.GetBlockNumber(pm.chainDb, req.BHash)); header != nil { if header := core.GetHeader(pm.chainDb, req.BHash, core.GetBlockNumber(pm.chainDb, req.BHash)); header != nil {
if trie, _ := trie.New(header.Root, pm.chainDb); trie != nil { statedb, err := pm.blockchain.State()
sdata := trie.Get(req.AccKey) if err != nil {
var acc state.Account continue
if err := rlp.DecodeBytes(sdata, &acc); err == nil {
entry, _ := pm.chainDb.Get(acc.CodeHash)
if bytes+len(entry) >= softResponseLimit {
break
} }
data = append(data, entry) account, err := pm.getAccount(statedb, header.Root, common.BytesToHash(req.AccKey))
bytes += len(entry) if err != nil {
continue
} }
code, _ := statedb.Database().TrieDB().Node(common.BytesToHash(account.CodeHash))
data = append(data, code)
if bytes += len(code); bytes >= softResponseLimit {
break
} }
} }
} }
@ -701,25 +703,29 @@ func (pm *ProtocolManager) handleMsg(p *peer) error {
return errResp(ErrRequestRejected, "") return errResp(ErrRequestRejected, "")
} }
for _, req := range req.Reqs { for _, req := range req.Reqs {
if bytes >= softResponseLimit {
break
}
// Retrieve the requested state entry, stopping if enough was found // Retrieve the requested state entry, stopping if enough was found
if header := core.GetHeader(pm.chainDb, req.BHash, core.GetBlockNumber(pm.chainDb, req.BHash)); header != nil { if header := core.GetHeader(pm.chainDb, req.BHash, core.GetBlockNumber(pm.chainDb, req.BHash)); header != nil {
if tr, _ := trie.New(header.Root, pm.chainDb); tr != nil { statedb, err := pm.blockchain.State()
if err != nil {
continue
}
var trie state.Trie
if len(req.AccKey) > 0 { if len(req.AccKey) > 0 {
sdata := tr.Get(req.AccKey) account, err := pm.getAccount(statedb, header.Root, common.BytesToHash(req.AccKey))
tr = nil if err != nil {
var acc state.Account continue
if err := rlp.DecodeBytes(sdata, &acc); err == nil {
tr, _ = trie.New(acc.Root, pm.chainDb)
} }
trie, _ = statedb.Database().OpenStorageTrie(common.BytesToHash(req.AccKey), account.Root)
} else {
trie, _ = statedb.Database().OpenTrie(header.Root)
} }
if tr != nil { if trie != nil {
var proof light.NodeList var proof light.NodeList
tr.Prove(req.Key, 0, &proof) trie.Prove(req.Key, 0, &proof)
proofs = append(proofs, proof) proofs = append(proofs, proof)
bytes += proof.DataSize() if bytes += proof.DataSize(); bytes >= softResponseLimit {
break
} }
} }
} }
@ -741,8 +747,8 @@ func (pm *ProtocolManager) handleMsg(p *peer) error {
// Gather state data until the fetch or network limits is reached // Gather state data until the fetch or network limits is reached
var ( var (
lastBHash common.Hash lastBHash common.Hash
lastAccKey []byte statedb *state.StateDB
tr, str *trie.Trie root common.Hash
) )
reqCnt := len(req.Reqs) reqCnt := len(req.Reqs)
if reject(uint64(reqCnt), MaxProofsFetch) { if reject(uint64(reqCnt), MaxProofsFetch) {
@ -752,35 +758,36 @@ func (pm *ProtocolManager) handleMsg(p *peer) error {
nodes := light.NewNodeSet() nodes := light.NewNodeSet()
for _, req := range req.Reqs { for _, req := range req.Reqs {
if nodes.DataSize() >= softResponseLimit { // Look up the state belonging to the request
break if statedb == nil || req.BHash != lastBHash {
} statedb, root, lastBHash = nil, common.Hash{}, req.BHash
if tr == nil || req.BHash != lastBHash {
if header := core.GetHeader(pm.chainDb, req.BHash, core.GetBlockNumber(pm.chainDb, req.BHash)); header != nil { if header := core.GetHeader(pm.chainDb, req.BHash, core.GetBlockNumber(pm.chainDb, req.BHash)); header != nil {
tr, _ = trie.New(header.Root, pm.chainDb) statedb, _ = pm.blockchain.State()
} else { root = header.Root
tr = nil
} }
lastBHash = req.BHash
str = nil
} }
if tr != nil { if statedb == nil {
if len(req.AccKey) > 0 { continue
if str == nil || !bytes.Equal(req.AccKey, lastAccKey) {
sdata := tr.Get(req.AccKey)
str = nil
var acc state.Account
if err := rlp.DecodeBytes(sdata, &acc); err == nil {
str, _ = trie.New(acc.Root, pm.chainDb)
}
lastAccKey = common.CopyBytes(req.AccKey)
} }
if str != nil { // Pull the account or storage trie of the request
str.Prove(req.Key, req.FromLevel, nodes) var trie state.Trie
if len(req.AccKey) > 0 {
account, err := pm.getAccount(statedb, root, common.BytesToHash(req.AccKey))
if err != nil {
continue
} }
trie, _ = statedb.Database().OpenStorageTrie(common.BytesToHash(req.AccKey), account.Root)
} else { } else {
tr.Prove(req.Key, req.FromLevel, nodes) trie, _ = statedb.Database().OpenTrie(root)
} }
if trie == nil {
continue
}
// Prove the user's request from the account or stroage trie
trie.Prove(req.Key, req.FromLevel, nodes)
if nodes.DataSize() >= softResponseLimit {
break
} }
} }
proofs := nodes.NodeList() proofs := nodes.NodeList()
@ -849,23 +856,29 @@ func (pm *ProtocolManager) handleMsg(p *peer) error {
if reject(uint64(reqCnt), MaxHelperTrieProofsFetch) { if reject(uint64(reqCnt), MaxHelperTrieProofsFetch) {
return errResp(ErrRequestRejected, "") return errResp(ErrRequestRejected, "")
} }
trieDb := ethdb.NewTable(pm.chainDb, light.ChtTablePrefix)
for _, req := range req.Reqs { for _, req := range req.Reqs {
if bytes >= softResponseLimit {
break
}
if header := pm.blockchain.GetHeaderByNumber(req.BlockNum); header != nil { if header := pm.blockchain.GetHeaderByNumber(req.BlockNum); header != nil {
sectionHead := core.GetCanonicalHash(pm.chainDb, req.ChtNum*light.ChtV1Frequency-1) sectionHead := core.GetCanonicalHash(pm.chainDb, req.ChtNum*light.ChtV1Frequency-1)
if root := light.GetChtRoot(pm.chainDb, req.ChtNum-1, sectionHead); root != (common.Hash{}) { if root := light.GetChtRoot(pm.chainDb, req.ChtNum-1, sectionHead); root != (common.Hash{}) {
if tr, _ := trie.New(root, trieDb); tr != nil { statedb, err := pm.blockchain.State()
if err != nil {
continue
}
trie, err := statedb.Database().OpenTrie(root)
if err != nil {
continue
}
var encNumber [8]byte var encNumber [8]byte
binary.BigEndian.PutUint64(encNumber[:], req.BlockNum) binary.BigEndian.PutUint64(encNumber[:], req.BlockNum)
var proof light.NodeList var proof light.NodeList
tr.Prove(encNumber[:], 0, &proof) trie.Prove(encNumber[:], 0, &proof)
proofs = append(proofs, ChtResp{Header: header, Proof: proof}) proofs = append(proofs, ChtResp{Header: header, Proof: proof})
bytes += proof.DataSize() + estHeaderRlpSize if bytes += proof.DataSize() + estHeaderRlpSize; bytes >= softResponseLimit {
break
} }
} }
} }
} }
@ -897,25 +910,21 @@ func (pm *ProtocolManager) handleMsg(p *peer) error {
lastIdx uint64 lastIdx uint64
lastType uint lastType uint
root common.Hash root common.Hash
tr *trie.Trie statedb *state.StateDB
trie state.Trie
) )
nodes := light.NewNodeSet() nodes := light.NewNodeSet()
for _, req := range req.Reqs { for _, req := range req.Reqs {
if nodes.DataSize()+auxBytes >= softResponseLimit { if trie == nil || req.HelperTrieType != lastType || req.TrieIdx != lastIdx {
break statedb, trie, lastType, lastIdx = nil, nil, req.HelperTrieType, req.TrieIdx
}
if tr == nil || req.HelperTrieType != lastType || req.TrieIdx != lastIdx { if root, _ = pm.getHelperTrie(req.HelperTrieType, req.TrieIdx); root != (common.Hash{}) {
var prefix string if statedb, _ = pm.blockchain.State(); statedb != nil {
root, prefix = pm.getHelperTrie(req.HelperTrieType, req.TrieIdx) trie, _ = statedb.Database().OpenTrie(root)
if root != (common.Hash{}) {
if t, err := trie.New(root, ethdb.NewTable(pm.chainDb, prefix)); err == nil {
tr = t
} }
} }
lastType = req.HelperTrieType
lastIdx = req.TrieIdx
} }
if req.AuxReq == auxRoot { if req.AuxReq == auxRoot {
var data []byte var data []byte
@ -925,8 +934,8 @@ func (pm *ProtocolManager) handleMsg(p *peer) error {
auxData = append(auxData, data) auxData = append(auxData, data)
auxBytes += len(data) auxBytes += len(data)
} else { } else {
if tr != nil { if trie != nil {
tr.Prove(req.Key, req.FromLevel, nodes) trie.Prove(req.Key, req.FromLevel, nodes)
} }
if req.AuxReq != 0 { if req.AuxReq != 0 {
data := pm.getHelperTrieAuxData(req) data := pm.getHelperTrieAuxData(req)
@ -934,6 +943,9 @@ func (pm *ProtocolManager) handleMsg(p *peer) error {
auxBytes += len(data) auxBytes += len(data)
} }
} }
if nodes.DataSize()+auxBytes >= softResponseLimit {
break
}
} }
proofs := nodes.NodeList() proofs := nodes.NodeList()
bv, rcost := p.fcClient.RequestProcessed(costs.baseCost + uint64(reqCnt)*costs.reqCost) bv, rcost := p.fcClient.RequestProcessed(costs.baseCost + uint64(reqCnt)*costs.reqCost)
@ -1090,6 +1102,23 @@ func (pm *ProtocolManager) handleMsg(p *peer) error {
return nil return nil
} }
// getAccount retrieves an account from the state based at root.
func (pm *ProtocolManager) getAccount(statedb *state.StateDB, root, hash common.Hash) (state.Account, error) {
trie, err := trie.New(root, statedb.Database().TrieDB())
if err != nil {
return state.Account{}, err
}
blob, err := trie.TryGet(hash[:])
if err != nil {
return state.Account{}, err
}
var account state.Account
if err = rlp.DecodeBytes(blob, &account); err != nil {
return state.Account{}, err
}
return account, nil
}
// getHelperTrie returns the post-processed trie root for the given trie ID and section index // getHelperTrie returns the post-processed trie root for the given trie ID and section index
func (pm *ProtocolManager) getHelperTrie(id uint, idx uint64) (common.Hash, string) { func (pm *ProtocolManager) getHelperTrie(id uint, idx uint64) (common.Hash, string) {
switch id { switch id {

@ -359,7 +359,7 @@ func testGetProofs(t *testing.T, protocol int) {
for i := uint64(0); i <= bc.CurrentBlock().NumberU64(); i++ { for i := uint64(0); i <= bc.CurrentBlock().NumberU64(); i++ {
header := bc.GetHeaderByNumber(i) header := bc.GetHeaderByNumber(i)
root := header.Root root := header.Root
trie, _ := trie.New(root, db) trie, _ := trie.New(root, trie.NewDatabase(db))
for _, acc := range accounts { for _, acc := range accounts {
req := ProofReq{ req := ProofReq{

@ -146,7 +146,7 @@ func newTestProtocolManager(lightSync bool, blocks int, generator func(int, *cor
if lightSync { if lightSync {
chain, _ = light.NewLightChain(odr, gspec.Config, engine) chain, _ = light.NewLightChain(odr, gspec.Config, engine)
} else { } else {
blockchain, _ := core.NewBlockChain(db, gspec.Config, engine, vm.Config{}) blockchain, _ := core.NewBlockChain(db, nil, gspec.Config, engine, vm.Config{})
gchain, _ := core.GenerateChain(gspec.Config, genesis, ethash.NewFaker(), db, blocks, generator) gchain, _ := core.GenerateChain(gspec.Config, genesis, ethash.NewFaker(), db, blocks, generator)
if _, err := blockchain.InsertChain(gchain); err != nil { if _, err := blockchain.InsertChain(gchain); err != nil {
panic(err) panic(err)

@ -101,7 +101,6 @@ func odrAccounts(ctx context.Context, db ethdb.Database, config *params.ChainCon
res = append(res, rlp...) res = append(res, rlp...)
} }
} }
return res return res
} }

@ -18,6 +18,7 @@ package light
import ( import (
"context" "context"
"errors"
"math/big" "math/big"
"sync" "sync"
"sync/atomic" "sync/atomic"
@ -26,6 +27,7 @@ import (
"github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/consensus" "github.com/ethereum/go-ethereum/consensus"
"github.com/ethereum/go-ethereum/core" "github.com/ethereum/go-ethereum/core"
"github.com/ethereum/go-ethereum/core/state"
"github.com/ethereum/go-ethereum/core/types" "github.com/ethereum/go-ethereum/core/types"
"github.com/ethereum/go-ethereum/ethdb" "github.com/ethereum/go-ethereum/ethdb"
"github.com/ethereum/go-ethereum/event" "github.com/ethereum/go-ethereum/event"
@ -212,6 +214,11 @@ func (bc *LightChain) Genesis() *types.Block {
return bc.genesisBlock return bc.genesisBlock
} }
// State returns a new mutable state based on the current HEAD block.
func (bc *LightChain) State() (*state.StateDB, error) {
return nil, errors.New("not implemented, needs client/server interface split")
}
// GetBody retrieves a block body (transactions and uncles) from the database // GetBody retrieves a block body (transactions and uncles) from the database
// or ODR service by hash, caching it if found. // or ODR service by hash, caching it if found.
func (self *LightChain) GetBody(ctx context.Context, hash common.Hash) (*types.Body, error) { func (self *LightChain) GetBody(ctx context.Context, hash common.Hash) (*types.Body, error) {

@ -22,8 +22,8 @@ import (
"github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/crypto" "github.com/ethereum/go-ethereum/crypto"
"github.com/ethereum/go-ethereum/ethdb"
"github.com/ethereum/go-ethereum/rlp" "github.com/ethereum/go-ethereum/rlp"
"github.com/ethereum/go-ethereum/trie"
) )
// NodeSet stores a set of trie nodes. It implements trie.Database and can also // NodeSet stores a set of trie nodes. It implements trie.Database and can also
@ -99,7 +99,7 @@ func (db *NodeSet) NodeList() NodeList {
} }
// Store writes the contents of the set to the given database // Store writes the contents of the set to the given database
func (db *NodeSet) Store(target trie.Database) { func (db *NodeSet) Store(target ethdb.Putter) {
db.lock.RLock() db.lock.RLock()
defer db.lock.RUnlock() defer db.lock.RUnlock()
@ -108,11 +108,11 @@ func (db *NodeSet) Store(target trie.Database) {
} }
} }
// NodeList stores an ordered list of trie nodes. It implements trie.DatabaseWriter. // NodeList stores an ordered list of trie nodes. It implements ethdb.Putter.
type NodeList []rlp.RawValue type NodeList []rlp.RawValue
// Store writes the contents of the list to the given database // Store writes the contents of the list to the given database
func (n NodeList) Store(db trie.Database) { func (n NodeList) Store(db ethdb.Putter) {
for _, node := range n { for _, node := range n {
db.Put(crypto.Keccak256(node), node) db.Put(crypto.Keccak256(node), node)
} }

@ -74,7 +74,7 @@ func (odr *testOdr) Retrieve(ctx context.Context, req OdrRequest) error {
case *ReceiptsRequest: case *ReceiptsRequest:
req.Receipts = core.GetBlockReceipts(odr.sdb, req.Hash, core.GetBlockNumber(odr.sdb, req.Hash)) req.Receipts = core.GetBlockReceipts(odr.sdb, req.Hash, core.GetBlockNumber(odr.sdb, req.Hash))
case *TrieRequest: case *TrieRequest:
t, _ := trie.New(req.Id.Root, odr.sdb) t, _ := trie.New(req.Id.Root, trie.NewDatabase(odr.sdb))
nodes := NewNodeSet() nodes := NewNodeSet()
t.Prove(req.Key, 0, nodes) t.Prove(req.Key, 0, nodes)
req.Proof = nodes req.Proof = nodes
@ -239,7 +239,7 @@ func testChainOdr(t *testing.T, protocol int, fn odrTestFn) {
) )
gspec.MustCommit(ldb) gspec.MustCommit(ldb)
// Assemble the test environment // Assemble the test environment
blockchain, _ := core.NewBlockChain(sdb, params.TestChainConfig, ethash.NewFullFaker(), vm.Config{}) blockchain, _ := core.NewBlockChain(sdb, nil, params.TestChainConfig, ethash.NewFullFaker(), vm.Config{})
gchain, _ := core.GenerateChain(params.TestChainConfig, genesis, ethash.NewFaker(), sdb, 4, testChainGen) gchain, _ := core.GenerateChain(params.TestChainConfig, genesis, ethash.NewFaker(), sdb, 4, testChainGen)
if _, err := blockchain.InsertChain(gchain); err != nil { if _, err := blockchain.InsertChain(gchain); err != nil {
t.Fatal(err) t.Fatal(err)

@ -113,7 +113,8 @@ func StoreChtRoot(db ethdb.Database, sectionIdx uint64, sectionHead, root common
// ChtIndexerBackend implements core.ChainIndexerBackend // ChtIndexerBackend implements core.ChainIndexerBackend
type ChtIndexerBackend struct { type ChtIndexerBackend struct {
db, cdb ethdb.Database diskdb ethdb.Database
triedb *trie.Database
section, sectionSize uint64 section, sectionSize uint64
lastHash common.Hash lastHash common.Hash
trie *trie.Trie trie *trie.Trie
@ -121,8 +122,6 @@ type ChtIndexerBackend struct {
// NewBloomTrieIndexer creates a BloomTrie chain indexer // NewBloomTrieIndexer creates a BloomTrie chain indexer
func NewChtIndexer(db ethdb.Database, clientMode bool) *core.ChainIndexer { func NewChtIndexer(db ethdb.Database, clientMode bool) *core.ChainIndexer {
cdb := ethdb.NewTable(db, ChtTablePrefix)
idb := ethdb.NewTable(db, "chtIndex-")
var sectionSize, confirmReq uint64 var sectionSize, confirmReq uint64
if clientMode { if clientMode {
sectionSize = ChtFrequency sectionSize = ChtFrequency
@ -131,17 +130,23 @@ func NewChtIndexer(db ethdb.Database, clientMode bool) *core.ChainIndexer {
sectionSize = ChtV1Frequency sectionSize = ChtV1Frequency
confirmReq = HelperTrieProcessConfirmations confirmReq = HelperTrieProcessConfirmations
} }
return core.NewChainIndexer(db, idb, &ChtIndexerBackend{db: db, cdb: cdb, sectionSize: sectionSize}, sectionSize, confirmReq, time.Millisecond*100, "cht") idb := ethdb.NewTable(db, "chtIndex-")
backend := &ChtIndexerBackend{
diskdb: db,
triedb: trie.NewDatabase(ethdb.NewTable(db, ChtTablePrefix)),
sectionSize: sectionSize,
}
return core.NewChainIndexer(db, idb, backend, sectionSize, confirmReq, time.Millisecond*100, "cht")
} }
// Reset implements core.ChainIndexerBackend // Reset implements core.ChainIndexerBackend
func (c *ChtIndexerBackend) Reset(section uint64, lastSectionHead common.Hash) error { func (c *ChtIndexerBackend) Reset(section uint64, lastSectionHead common.Hash) error {
var root common.Hash var root common.Hash
if section > 0 { if section > 0 {
root = GetChtRoot(c.db, section-1, lastSectionHead) root = GetChtRoot(c.diskdb, section-1, lastSectionHead)
} }
var err error var err error
c.trie, err = trie.New(root, c.cdb) c.trie, err = trie.New(root, c.triedb)
c.section = section c.section = section
return err return err
} }
@ -151,7 +156,7 @@ func (c *ChtIndexerBackend) Process(header *types.Header) {
hash, num := header.Hash(), header.Number.Uint64() hash, num := header.Hash(), header.Number.Uint64()
c.lastHash = hash c.lastHash = hash
td := core.GetTd(c.db, hash, num) td := core.GetTd(c.diskdb, hash, num)
if td == nil { if td == nil {
panic(nil) panic(nil)
} }
@ -163,17 +168,16 @@ func (c *ChtIndexerBackend) Process(header *types.Header) {
// Commit implements core.ChainIndexerBackend // Commit implements core.ChainIndexerBackend
func (c *ChtIndexerBackend) Commit() error { func (c *ChtIndexerBackend) Commit() error {
batch := c.cdb.NewBatch() root, err := c.trie.Commit(nil)
root, err := c.trie.CommitTo(batch)
if err != nil { if err != nil {
return err return err
} else { }
batch.Write() c.triedb.Commit(root, false)
if ((c.section+1)*c.sectionSize)%ChtFrequency == 0 { if ((c.section+1)*c.sectionSize)%ChtFrequency == 0 {
log.Info("Storing CHT", "idx", c.section*c.sectionSize/ChtFrequency, "sectionHead", fmt.Sprintf("%064x", c.lastHash), "root", fmt.Sprintf("%064x", root)) log.Info("Storing CHT", "idx", c.section*c.sectionSize/ChtFrequency, "sectionHead", fmt.Sprintf("%064x", c.lastHash), "root", fmt.Sprintf("%064x", root))
} }
StoreChtRoot(c.db, c.section, c.lastHash, root) StoreChtRoot(c.diskdb, c.section, c.lastHash, root)
}
return nil return nil
} }
@ -205,7 +209,8 @@ func StoreBloomTrieRoot(db ethdb.Database, sectionIdx uint64, sectionHead, root
// BloomTrieIndexerBackend implements core.ChainIndexerBackend // BloomTrieIndexerBackend implements core.ChainIndexerBackend
type BloomTrieIndexerBackend struct { type BloomTrieIndexerBackend struct {
db, cdb ethdb.Database diskdb ethdb.Database
triedb *trie.Database
section, parentSectionSize, bloomTrieRatio uint64 section, parentSectionSize, bloomTrieRatio uint64
trie *trie.Trie trie *trie.Trie
sectionHeads []common.Hash sectionHeads []common.Hash
@ -213,9 +218,12 @@ type BloomTrieIndexerBackend struct {
// NewBloomTrieIndexer creates a BloomTrie chain indexer // NewBloomTrieIndexer creates a BloomTrie chain indexer
func NewBloomTrieIndexer(db ethdb.Database, clientMode bool) *core.ChainIndexer { func NewBloomTrieIndexer(db ethdb.Database, clientMode bool) *core.ChainIndexer {
cdb := ethdb.NewTable(db, BloomTrieTablePrefix) backend := &BloomTrieIndexerBackend{
diskdb: db,
triedb: trie.NewDatabase(ethdb.NewTable(db, BloomTrieTablePrefix)),
}
idb := ethdb.NewTable(db, "bltIndex-") idb := ethdb.NewTable(db, "bltIndex-")
backend := &BloomTrieIndexerBackend{db: db, cdb: cdb}
var confirmReq uint64 var confirmReq uint64
if clientMode { if clientMode {
backend.parentSectionSize = BloomTrieFrequency backend.parentSectionSize = BloomTrieFrequency
@ -233,10 +241,10 @@ func NewBloomTrieIndexer(db ethdb.Database, clientMode bool) *core.ChainIndexer
func (b *BloomTrieIndexerBackend) Reset(section uint64, lastSectionHead common.Hash) error { func (b *BloomTrieIndexerBackend) Reset(section uint64, lastSectionHead common.Hash) error {
var root common.Hash var root common.Hash
if section > 0 { if section > 0 {
root = GetBloomTrieRoot(b.db, section-1, lastSectionHead) root = GetBloomTrieRoot(b.diskdb, section-1, lastSectionHead)
} }
var err error var err error
b.trie, err = trie.New(root, b.cdb) b.trie, err = trie.New(root, b.triedb)
b.section = section b.section = section
return err return err
} }
@ -259,7 +267,7 @@ func (b *BloomTrieIndexerBackend) Commit() error {
binary.BigEndian.PutUint64(encKey[2:10], b.section) binary.BigEndian.PutUint64(encKey[2:10], b.section)
var decomp []byte var decomp []byte
for j := uint64(0); j < b.bloomTrieRatio; j++ { for j := uint64(0); j < b.bloomTrieRatio; j++ {
data, err := core.GetBloomBits(b.db, i, b.section*b.bloomTrieRatio+j, b.sectionHeads[j]) data, err := core.GetBloomBits(b.diskdb, i, b.section*b.bloomTrieRatio+j, b.sectionHeads[j])
if err != nil { if err != nil {
return err return err
} }
@ -279,17 +287,15 @@ func (b *BloomTrieIndexerBackend) Commit() error {
b.trie.Delete(encKey[:]) b.trie.Delete(encKey[:])
} }
} }
root, err := b.trie.Commit(nil)
batch := b.cdb.NewBatch()
root, err := b.trie.CommitTo(batch)
if err != nil { if err != nil {
return err return err
} else { }
batch.Write() b.triedb.Commit(root, false)
sectionHead := b.sectionHeads[b.bloomTrieRatio-1] sectionHead := b.sectionHeads[b.bloomTrieRatio-1]
log.Info("Storing BloomTrie", "section", b.section, "sectionHead", fmt.Sprintf("%064x", sectionHead), "root", fmt.Sprintf("%064x", root), "compression ratio", float64(compSize)/float64(decompSize)) log.Info("Storing BloomTrie", "section", b.section, "sectionHead", fmt.Sprintf("%064x", sectionHead), "root", fmt.Sprintf("%064x", root), "compression ratio", float64(compSize)/float64(decompSize))
StoreBloomTrieRoot(b.db, b.section, sectionHead, root) StoreBloomTrieRoot(b.diskdb, b.section, sectionHead, root)
}
return nil return nil
} }

@ -18,12 +18,14 @@ package light
import ( import (
"context" "context"
"errors"
"fmt" "fmt"
"github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/core/state" "github.com/ethereum/go-ethereum/core/state"
"github.com/ethereum/go-ethereum/core/types" "github.com/ethereum/go-ethereum/core/types"
"github.com/ethereum/go-ethereum/crypto" "github.com/ethereum/go-ethereum/crypto"
"github.com/ethereum/go-ethereum/ethdb"
"github.com/ethereum/go-ethereum/trie" "github.com/ethereum/go-ethereum/trie"
) )
@ -83,6 +85,10 @@ func (db *odrDatabase) ContractCodeSize(addrHash, codeHash common.Hash) (int, er
return len(code), err return len(code), err
} }
func (db *odrDatabase) TrieDB() *trie.Database {
return nil
}
type odrTrie struct { type odrTrie struct {
db *odrDatabase db *odrDatabase
id *TrieID id *TrieID
@ -113,11 +119,11 @@ func (t *odrTrie) TryDelete(key []byte) error {
}) })
} }
func (t *odrTrie) CommitTo(db trie.DatabaseWriter) (common.Hash, error) { func (t *odrTrie) Commit(onleaf trie.LeafCallback) (common.Hash, error) {
if t.trie == nil { if t.trie == nil {
return t.id.Root, nil return t.id.Root, nil
} }
return t.trie.CommitTo(db) return t.trie.Commit(onleaf)
} }
func (t *odrTrie) Hash() common.Hash { func (t *odrTrie) Hash() common.Hash {
@ -135,13 +141,17 @@ func (t *odrTrie) GetKey(sha []byte) []byte {
return nil return nil
} }
func (t *odrTrie) Prove(key []byte, fromLevel uint, proofDb ethdb.Putter) error {
return errors.New("not implemented, needs client/server interface split")
}
// do tries and retries to execute a function until it returns with no error or // do tries and retries to execute a function until it returns with no error or
// an error type other than MissingNodeError // an error type other than MissingNodeError
func (t *odrTrie) do(key []byte, fn func() error) error { func (t *odrTrie) do(key []byte, fn func() error) error {
for { for {
var err error var err error
if t.trie == nil { if t.trie == nil {
t.trie, err = trie.New(t.id.Root, t.db.backend.Database()) t.trie, err = trie.New(t.id.Root, trie.NewDatabase(t.db.backend.Database()))
} }
if err == nil { if err == nil {
err = fn() err = fn()
@ -167,7 +177,7 @@ func newNodeIterator(t *odrTrie, startkey []byte) trie.NodeIterator {
// Open the actual non-ODR trie if that hasn't happened yet. // Open the actual non-ODR trie if that hasn't happened yet.
if t.trie == nil { if t.trie == nil {
it.do(func() error { it.do(func() error {
t, err := trie.New(t.id.Root, t.db.backend.Database()) t, err := trie.New(t.id.Root, trie.NewDatabase(t.db.backend.Database()))
if err == nil { if err == nil {
it.t.trie = t it.t.trie = t
} }

@ -40,7 +40,7 @@ func TestNodeIterator(t *testing.T) {
genesis = gspec.MustCommit(fulldb) genesis = gspec.MustCommit(fulldb)
) )
gspec.MustCommit(lightdb) gspec.MustCommit(lightdb)
blockchain, _ := core.NewBlockChain(fulldb, params.TestChainConfig, ethash.NewFullFaker(), vm.Config{}) blockchain, _ := core.NewBlockChain(fulldb, nil, params.TestChainConfig, ethash.NewFullFaker(), vm.Config{})
gchain, _ := core.GenerateChain(params.TestChainConfig, genesis, ethash.NewFaker(), fulldb, 4, testChainGen) gchain, _ := core.GenerateChain(params.TestChainConfig, genesis, ethash.NewFaker(), fulldb, 4, testChainGen)
if _, err := blockchain.InsertChain(gchain); err != nil { if _, err := blockchain.InsertChain(gchain); err != nil {
panic(err) panic(err)

@ -88,7 +88,7 @@ func TestTxPool(t *testing.T) {
) )
gspec.MustCommit(ldb) gspec.MustCommit(ldb)
// Assemble the test environment // Assemble the test environment
blockchain, _ := core.NewBlockChain(sdb, params.TestChainConfig, ethash.NewFullFaker(), vm.Config{}) blockchain, _ := core.NewBlockChain(sdb, nil, params.TestChainConfig, ethash.NewFullFaker(), vm.Config{})
gchain, _ := core.GenerateChain(params.TestChainConfig, genesis, ethash.NewFaker(), sdb, poolTestBlocks, txPoolTestChainGen) gchain, _ := core.GenerateChain(params.TestChainConfig, genesis, ethash.NewFaker(), sdb, poolTestBlocks, txPoolTestChainGen)
if _, err := blockchain.InsertChain(gchain); err != nil { if _, err := blockchain.InsertChain(gchain); err != nil {
panic(err) panic(err)

@ -309,7 +309,7 @@ func (self *worker) wait() {
for _, log := range work.state.Logs() { for _, log := range work.state.Logs() {
log.BlockHash = block.Hash() log.BlockHash = block.Hash()
} }
stat, err := self.chain.WriteBlockAndState(block, work.receipts, work.state) stat, err := self.chain.WriteBlockWithState(block, work.receipts, work.state)
if err != nil { if err != nil {
log.Error("Failed writing block to chain", "err", err) log.Error("Failed writing block to chain", "err", err)
continue continue

@ -110,7 +110,7 @@ func (t *BlockTest) Run() error {
return fmt.Errorf("genesis block state root does not match test: computed=%x, test=%x", gblock.Root().Bytes()[:6], t.json.Genesis.StateRoot[:6]) return fmt.Errorf("genesis block state root does not match test: computed=%x, test=%x", gblock.Root().Bytes()[:6], t.json.Genesis.StateRoot[:6])
} }
chain, err := core.NewBlockChain(db, config, ethash.NewShared(), vm.Config{}) chain, err := core.NewBlockChain(db, nil, config, ethash.NewShared(), vm.Config{})
if err != nil { if err != nil {
return err return err
} }

@ -125,7 +125,7 @@ func (t *StateTest) Run(subtest StateSubtest, vmconfig vm.Config) (*state.StateD
if !ok { if !ok {
return nil, UnsupportedForkError{subtest.Fork} return nil, UnsupportedForkError{subtest.Fork}
} }
block, _ := t.genesis(config).ToBlock() block := t.genesis(config).ToBlock(nil)
db, _ := ethdb.NewMemDatabase() db, _ := ethdb.NewMemDatabase()
statedb := MakePreState(db, t.json.Pre) statedb := MakePreState(db, t.json.Pre)
@ -147,7 +147,7 @@ func (t *StateTest) Run(subtest StateSubtest, vmconfig vm.Config) (*state.StateD
if logs := rlpHash(statedb.Logs()); logs != common.Hash(post.Logs) { if logs := rlpHash(statedb.Logs()); logs != common.Hash(post.Logs) {
return statedb, fmt.Errorf("post state logs hash mismatch: got %x, want %x", logs, post.Logs) return statedb, fmt.Errorf("post state logs hash mismatch: got %x, want %x", logs, post.Logs)
} }
root, _ := statedb.CommitTo(db, config.IsEIP158(block.Number())) root, _ := statedb.Commit(config.IsEIP158(block.Number()))
if root != common.Hash(post.Root) { if root != common.Hash(post.Root) {
return statedb, fmt.Errorf("post state root mismatch: got %x, want %x", root, post.Root) return statedb, fmt.Errorf("post state root mismatch: got %x, want %x", root, post.Root)
} }
@ -170,7 +170,7 @@ func MakePreState(db ethdb.Database, accounts core.GenesisAlloc) *state.StateDB
} }
} }
// Commit and re-open to start with a clean state. // Commit and re-open to start with a clean state.
root, _ := statedb.CommitTo(db, false) root, _ := statedb.Commit(false)
statedb, _ = state.New(root, sdb) statedb, _ = state.New(root, sdb)
return statedb return statedb
} }

@ -0,0 +1,355 @@
// Copyright 2017 The go-ethereum Authors
// This file is part of the go-ethereum library.
//
// The go-ethereum library is free software: you can redistribute it and/or modify
// it under the terms of the GNU Lesser General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// The go-ethereum library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Lesser General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public License
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
package trie
import (
"sync"
"time"
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/ethdb"
"github.com/ethereum/go-ethereum/log"
)
// secureKeyPrefix is the database key prefix used to store trie node preimages.
var secureKeyPrefix = []byte("secure-key-")
// secureKeyLength is the length of the above prefix + 32byte hash.
const secureKeyLength = 11 + 32
// DatabaseReader wraps the Get and Has method of a backing store for the trie.
type DatabaseReader interface {
// Get retrieves the value associated with key form the database.
Get(key []byte) (value []byte, err error)
// Has retrieves whether a key is present in the database.
Has(key []byte) (bool, error)
}
// Database is an intermediate write layer between the trie data structures and
// the disk database. The aim is to accumulate trie writes in-memory and only
// periodically flush a couple tries to disk, garbage collecting the remainder.
type Database struct {
diskdb ethdb.Database // Persistent storage for matured trie nodes
nodes map[common.Hash]*cachedNode // Data and references relationships of a node
preimages map[common.Hash][]byte // Preimages of nodes from the secure trie
seckeybuf [secureKeyLength]byte // Ephemeral buffer for calculating preimage keys
gctime time.Duration // Time spent on garbage collection since last commit
gcnodes uint64 // Nodes garbage collected since last commit
gcsize common.StorageSize // Data storage garbage collected since last commit
nodesSize common.StorageSize // Storage size of the nodes cache
preimagesSize common.StorageSize // Storage size of the preimages cache
lock sync.RWMutex
}
// cachedNode is all the information we know about a single cached node in the
// memory database write layer.
type cachedNode struct {
blob []byte // Cached data block of the trie node
parents int // Number of live nodes referencing this one
children map[common.Hash]int // Children referenced by this nodes
}
// NewDatabase creates a new trie database to store ephemeral trie content before
// its written out to disk or garbage collected.
func NewDatabase(diskdb ethdb.Database) *Database {
return &Database{
diskdb: diskdb,
nodes: map[common.Hash]*cachedNode{
{}: {children: make(map[common.Hash]int)},
},
preimages: make(map[common.Hash][]byte),
}
}
// DiskDB retrieves the persistent storage backing the trie database.
func (db *Database) DiskDB() DatabaseReader {
return db.diskdb
}
// Insert writes a new trie node to the memory database if it's yet unknown. The
// method will make a copy of the slice.
func (db *Database) Insert(hash common.Hash, blob []byte) {
db.lock.Lock()
defer db.lock.Unlock()
db.insert(hash, blob)
}
// insert is the private locked version of Insert.
func (db *Database) insert(hash common.Hash, blob []byte) {
if _, ok := db.nodes[hash]; ok {
return
}
db.nodes[hash] = &cachedNode{
blob: common.CopyBytes(blob),
children: make(map[common.Hash]int),
}
db.nodesSize += common.StorageSize(common.HashLength + len(blob))
}
// insertPreimage writes a new trie node pre-image to the memory database if it's
// yet unknown. The method will make a copy of the slice.
//
// Note, this method assumes that the database's lock is held!
func (db *Database) insertPreimage(hash common.Hash, preimage []byte) {
if _, ok := db.preimages[hash]; ok {
return
}
db.preimages[hash] = common.CopyBytes(preimage)
db.preimagesSize += common.StorageSize(common.HashLength + len(preimage))
}
// Node retrieves a cached trie node from memory. If it cannot be found cached,
// the method queries the persistent database for the content.
func (db *Database) Node(hash common.Hash) ([]byte, error) {
// Retrieve the node from cache if available
db.lock.RLock()
node := db.nodes[hash]
db.lock.RUnlock()
if node != nil {
return node.blob, nil
}
// Content unavailable in memory, attempt to retrieve from disk
return db.diskdb.Get(hash[:])
}
// preimage retrieves a cached trie node pre-image from memory. If it cannot be
// found cached, the method queries the persistent database for the content.
func (db *Database) preimage(hash common.Hash) ([]byte, error) {
// Retrieve the node from cache if available
db.lock.RLock()
preimage := db.preimages[hash]
db.lock.RUnlock()
if preimage != nil {
return preimage, nil
}
// Content unavailable in memory, attempt to retrieve from disk
return db.diskdb.Get(db.secureKey(hash[:]))
}
// secureKey returns the database key for the preimage of key, as an ephemeral
// buffer. The caller must not hold onto the return value because it will become
// invalid on the next call.
func (db *Database) secureKey(key []byte) []byte {
buf := append(db.seckeybuf[:0], secureKeyPrefix...)
buf = append(buf, key...)
return buf
}
// Nodes retrieves the hashes of all the nodes cached within the memory database.
// This method is extremely expensive and should only be used to validate internal
// states in test code.
func (db *Database) Nodes() []common.Hash {
db.lock.RLock()
defer db.lock.RUnlock()
var hashes = make([]common.Hash, 0, len(db.nodes))
for hash := range db.nodes {
if hash != (common.Hash{}) { // Special case for "root" references/nodes
hashes = append(hashes, hash)
}
}
return hashes
}
// Reference adds a new reference from a parent node to a child node.
func (db *Database) Reference(child common.Hash, parent common.Hash) {
db.lock.RLock()
defer db.lock.RUnlock()
db.reference(child, parent)
}
// reference is the private locked version of Reference.
func (db *Database) reference(child common.Hash, parent common.Hash) {
// If the node does not exist, it's a node pulled from disk, skip
node, ok := db.nodes[child]
if !ok {
return
}
// If the reference already exists, only duplicate for roots
if _, ok = db.nodes[parent].children[child]; ok && parent != (common.Hash{}) {
return
}
node.parents++
db.nodes[parent].children[child]++
}
// Dereference removes an existing reference from a parent node to a child node.
func (db *Database) Dereference(child common.Hash, parent common.Hash) {
db.lock.Lock()
defer db.lock.Unlock()
nodes, storage, start := len(db.nodes), db.nodesSize, time.Now()
db.dereference(child, parent)
db.gcnodes += uint64(nodes - len(db.nodes))
db.gcsize += storage - db.nodesSize
db.gctime += time.Since(start)
log.Debug("Dereferenced trie from memory database", "nodes", nodes-len(db.nodes), "size", storage-db.nodesSize, "time", time.Since(start),
"gcnodes", db.gcnodes, "gcsize", db.gcsize, "gctime", db.gctime, "livenodes", len(db.nodes), "livesize", db.nodesSize)
}
// dereference is the private locked version of Dereference.
func (db *Database) dereference(child common.Hash, parent common.Hash) {
// Dereference the parent-child
node := db.nodes[parent]
node.children[child]--
if node.children[child] == 0 {
delete(node.children, child)
}
// If the node does not exist, it's a previously committed node.
node, ok := db.nodes[child]
if !ok {
return
}
// If there are no more references to the child, delete it and cascade
node.parents--
if node.parents == 0 {
for hash := range node.children {
db.dereference(hash, child)
}
delete(db.nodes, child)
db.nodesSize -= common.StorageSize(common.HashLength + len(node.blob))
}
}
// Commit iterates over all the children of a particular node, writes them out
// to disk, forcefully tearing down all references in both directions.
//
// As a side effect, all pre-images accumulated up to this point are also written.
func (db *Database) Commit(node common.Hash, report bool) error {
// Create a database batch to flush persistent data out. It is important that
// outside code doesn't see an inconsistent state (referenced data removed from
// memory cache during commit but not yet in persistent storage). This is ensured
// by only uncaching existing data when the database write finalizes.
db.lock.RLock()
start := time.Now()
batch := db.diskdb.NewBatch()
// Move all of the accumulated preimages into a write batch
for hash, preimage := range db.preimages {
if err := batch.Put(db.secureKey(hash[:]), preimage); err != nil {
log.Error("Failed to commit preimage from trie database", "err", err)
db.lock.RUnlock()
return err
}
if batch.ValueSize() > ethdb.IdealBatchSize {
if err := batch.Write(); err != nil {
return err
}
batch.Reset()
}
}
// Move the trie itself into the batch, flushing if enough data is accumulated
nodes, storage := len(db.nodes), db.nodesSize+db.preimagesSize
if err := db.commit(node, batch); err != nil {
log.Error("Failed to commit trie from trie database", "err", err)
db.lock.RUnlock()
return err
}
// Write batch ready, unlock for readers during persistence
if err := batch.Write(); err != nil {
log.Error("Failed to write trie to disk", "err", err)
db.lock.RUnlock()
return err
}
db.lock.RUnlock()
// Write successful, clear out the flushed data
db.lock.Lock()
defer db.lock.Unlock()
db.preimages = make(map[common.Hash][]byte)
db.preimagesSize = 0
db.uncache(node)
logger := log.Info
if !report {
logger = log.Debug
}
logger("Persisted trie from memory database", "nodes", nodes-len(db.nodes), "size", storage-db.nodesSize, "time", time.Since(start),
"gcnodes", db.gcnodes, "gcsize", db.gcsize, "gctime", db.gctime, "livenodes", len(db.nodes), "livesize", db.nodesSize)
// Reset the garbage collection statistics
db.gcnodes, db.gcsize, db.gctime = 0, 0, 0
return nil
}
// commit is the private locked version of Commit.
func (db *Database) commit(hash common.Hash, batch ethdb.Batch) error {
// If the node does not exist, it's a previously committed node
node, ok := db.nodes[hash]
if !ok {
return nil
}
for child := range node.children {
if err := db.commit(child, batch); err != nil {
return err
}
}
if err := batch.Put(hash[:], node.blob); err != nil {
return err
}
// If we've reached an optimal match size, commit and start over
if batch.ValueSize() >= ethdb.IdealBatchSize {
if err := batch.Write(); err != nil {
return err
}
batch.Reset()
}
return nil
}
// uncache is the post-processing step of a commit operation where the already
// persisted trie is removed from the cache. The reason behind the two-phase
// commit is to ensure consistent data availability while moving from memory
// to disk.
func (db *Database) uncache(hash common.Hash) {
// If the node does not exist, we're done on this path
node, ok := db.nodes[hash]
if !ok {
return
}
// Otherwise uncache the node's subtries and remove the node itself too
for child := range node.children {
db.uncache(child)
}
delete(db.nodes, hash)
db.nodesSize -= common.StorageSize(common.HashLength + len(node.blob))
}
// Size returns the current storage size of the memory cache in front of the
// persistent database layer.
func (db *Database) Size() common.StorageSize {
db.lock.RLock()
defer db.lock.RUnlock()
return db.nodesSize + db.preimagesSize
}

@ -29,19 +29,21 @@ import (
type hasher struct { type hasher struct {
tmp *bytes.Buffer tmp *bytes.Buffer
sha hash.Hash sha hash.Hash
cachegen, cachelimit uint16 cachegen uint16
cachelimit uint16
onleaf LeafCallback
} }
// hashers live in a global pool. // hashers live in a global db.
var hasherPool = sync.Pool{ var hasherPool = sync.Pool{
New: func() interface{} { New: func() interface{} {
return &hasher{tmp: new(bytes.Buffer), sha: sha3.NewKeccak256()} return &hasher{tmp: new(bytes.Buffer), sha: sha3.NewKeccak256()}
}, },
} }
func newHasher(cachegen, cachelimit uint16) *hasher { func newHasher(cachegen, cachelimit uint16, onleaf LeafCallback) *hasher {
h := hasherPool.Get().(*hasher) h := hasherPool.Get().(*hasher)
h.cachegen, h.cachelimit = cachegen, cachelimit h.cachegen, h.cachelimit, h.onleaf = cachegen, cachelimit, onleaf
return h return h
} }
@ -51,7 +53,7 @@ func returnHasherToPool(h *hasher) {
// hash collapses a node down into a hash node, also returning a copy of the // hash collapses a node down into a hash node, also returning a copy of the
// original node initialized with the computed hash to replace the original one. // original node initialized with the computed hash to replace the original one.
func (h *hasher) hash(n node, db DatabaseWriter, force bool) (node, node, error) { func (h *hasher) hash(n node, db *Database, force bool) (node, node, error) {
// If we're not storing the node, just hashing, use available cached data // If we're not storing the node, just hashing, use available cached data
if hash, dirty := n.cache(); hash != nil { if hash, dirty := n.cache(); hash != nil {
if db == nil { if db == nil {
@ -98,7 +100,7 @@ func (h *hasher) hash(n node, db DatabaseWriter, force bool) (node, node, error)
// hashChildren replaces the children of a node with their hashes if the encoded // hashChildren replaces the children of a node with their hashes if the encoded
// size of the child is larger than a hash, returning the collapsed node as well // size of the child is larger than a hash, returning the collapsed node as well
// as a replacement for the original node with the child hashes cached in. // as a replacement for the original node with the child hashes cached in.
func (h *hasher) hashChildren(original node, db DatabaseWriter) (node, node, error) { func (h *hasher) hashChildren(original node, db *Database) (node, node, error) {
var err error var err error
switch n := original.(type) { switch n := original.(type) {
@ -145,7 +147,10 @@ func (h *hasher) hashChildren(original node, db DatabaseWriter) (node, node, err
} }
} }
func (h *hasher) store(n node, db DatabaseWriter, force bool) (node, error) { // store hashes the node n and if we have a storage layer specified, it writes
// the key/value pair to it and tracks any node->child references as well as any
// node->external trie references.
func (h *hasher) store(n node, db *Database, force bool) (node, error) {
// Don't store hashes or empty nodes. // Don't store hashes or empty nodes.
if _, isHash := n.(hashNode); n == nil || isHash { if _, isHash := n.(hashNode); n == nil || isHash {
return n, nil return n, nil
@ -155,7 +160,6 @@ func (h *hasher) store(n node, db DatabaseWriter, force bool) (node, error) {
if err := rlp.Encode(h.tmp, n); err != nil { if err := rlp.Encode(h.tmp, n); err != nil {
panic("encode error: " + err.Error()) panic("encode error: " + err.Error())
} }
if h.tmp.Len() < 32 && !force { if h.tmp.Len() < 32 && !force {
return n, nil // Nodes smaller than 32 bytes are stored inside their parent return n, nil // Nodes smaller than 32 bytes are stored inside their parent
} }
@ -167,7 +171,42 @@ func (h *hasher) store(n node, db DatabaseWriter, force bool) (node, error) {
hash = hashNode(h.sha.Sum(nil)) hash = hashNode(h.sha.Sum(nil))
} }
if db != nil { if db != nil {
return hash, db.Put(hash, h.tmp.Bytes()) // We are pooling the trie nodes into an intermediate memory cache
db.lock.Lock()
hash := common.BytesToHash(hash)
db.insert(hash, h.tmp.Bytes())
// Track all direct parent->child node references
switch n := n.(type) {
case *shortNode:
if child, ok := n.Val.(hashNode); ok {
db.reference(common.BytesToHash(child), hash)
}
case *fullNode:
for i := 0; i < 16; i++ {
if child, ok := n.Children[i].(hashNode); ok {
db.reference(common.BytesToHash(child), hash)
}
}
}
db.lock.Unlock()
// Track external references from account->storage trie
if h.onleaf != nil {
switch n := n.(type) {
case *shortNode:
if child, ok := n.Val.(valueNode); ok {
h.onleaf(child, hash)
}
case *fullNode:
for i := 0; i < 16; i++ {
if child, ok := n.Children[i].(valueNode); ok {
h.onleaf(child, hash)
}
}
}
}
} }
return hash, nil return hash, nil
} }

@ -42,7 +42,7 @@ func TestIterator(t *testing.T) {
all[val.k] = val.v all[val.k] = val.v
trie.Update([]byte(val.k), []byte(val.v)) trie.Update([]byte(val.k), []byte(val.v))
} }
trie.Commit() trie.Commit(nil)
found := make(map[string]string) found := make(map[string]string)
it := NewIterator(trie.NodeIterator(nil)) it := NewIterator(trie.NodeIterator(nil))
@ -109,11 +109,18 @@ func TestNodeIteratorCoverage(t *testing.T) {
} }
// Cross check the hashes and the database itself // Cross check the hashes and the database itself
for hash := range hashes { for hash := range hashes {
if _, err := db.Get(hash.Bytes()); err != nil { if _, err := db.Node(hash); err != nil {
t.Errorf("failed to retrieve reported node %x: %v", hash, err) t.Errorf("failed to retrieve reported node %x: %v", hash, err)
} }
} }
for _, key := range db.(*ethdb.MemDatabase).Keys() { for hash, obj := range db.nodes {
if obj != nil && hash != (common.Hash{}) {
if _, ok := hashes[hash]; !ok {
t.Errorf("state entry not reported %x", hash)
}
}
}
for _, key := range db.diskdb.(*ethdb.MemDatabase).Keys() {
if _, ok := hashes[common.BytesToHash(key)]; !ok { if _, ok := hashes[common.BytesToHash(key)]; !ok {
t.Errorf("state entry not reported %x", key) t.Errorf("state entry not reported %x", key)
} }
@ -191,13 +198,13 @@ func TestDifferenceIterator(t *testing.T) {
for _, val := range testdata1 { for _, val := range testdata1 {
triea.Update([]byte(val.k), []byte(val.v)) triea.Update([]byte(val.k), []byte(val.v))
} }
triea.Commit() triea.Commit(nil)
trieb := newEmpty() trieb := newEmpty()
for _, val := range testdata2 { for _, val := range testdata2 {
trieb.Update([]byte(val.k), []byte(val.v)) trieb.Update([]byte(val.k), []byte(val.v))
} }
trieb.Commit() trieb.Commit(nil)
found := make(map[string]string) found := make(map[string]string)
di, _ := NewDifferenceIterator(triea.NodeIterator(nil), trieb.NodeIterator(nil)) di, _ := NewDifferenceIterator(triea.NodeIterator(nil), trieb.NodeIterator(nil))
@ -227,13 +234,13 @@ func TestUnionIterator(t *testing.T) {
for _, val := range testdata1 { for _, val := range testdata1 {
triea.Update([]byte(val.k), []byte(val.v)) triea.Update([]byte(val.k), []byte(val.v))
} }
triea.Commit() triea.Commit(nil)
trieb := newEmpty() trieb := newEmpty()
for _, val := range testdata2 { for _, val := range testdata2 {
trieb.Update([]byte(val.k), []byte(val.v)) trieb.Update([]byte(val.k), []byte(val.v))
} }
trieb.Commit() trieb.Commit(nil)
di, _ := NewUnionIterator([]NodeIterator{triea.NodeIterator(nil), trieb.NodeIterator(nil)}) di, _ := NewUnionIterator([]NodeIterator{triea.NodeIterator(nil), trieb.NodeIterator(nil)})
it := NewIterator(di) it := NewIterator(di)
@ -278,43 +285,75 @@ func TestIteratorNoDups(t *testing.T) {
} }
// This test checks that nodeIterator.Next can be retried after inserting missing trie nodes. // This test checks that nodeIterator.Next can be retried after inserting missing trie nodes.
func TestIteratorContinueAfterError(t *testing.T) { func TestIteratorContinueAfterErrorDisk(t *testing.T) { testIteratorContinueAfterError(t, false) }
db, _ := ethdb.NewMemDatabase() func TestIteratorContinueAfterErrorMemonly(t *testing.T) { testIteratorContinueAfterError(t, true) }
tr, _ := New(common.Hash{}, db)
func testIteratorContinueAfterError(t *testing.T, memonly bool) {
diskdb, _ := ethdb.NewMemDatabase()
triedb := NewDatabase(diskdb)
tr, _ := New(common.Hash{}, triedb)
for _, val := range testdata1 { for _, val := range testdata1 {
tr.Update([]byte(val.k), []byte(val.v)) tr.Update([]byte(val.k), []byte(val.v))
} }
tr.Commit() tr.Commit(nil)
if !memonly {
triedb.Commit(tr.Hash(), true)
}
wantNodeCount := checkIteratorNoDups(t, tr.NodeIterator(nil), nil) wantNodeCount := checkIteratorNoDups(t, tr.NodeIterator(nil), nil)
keys := db.Keys()
t.Log("node count", wantNodeCount)
var (
diskKeys [][]byte
memKeys []common.Hash
)
if memonly {
memKeys = triedb.Nodes()
} else {
diskKeys = diskdb.Keys()
}
for i := 0; i < 20; i++ { for i := 0; i < 20; i++ {
// Create trie that will load all nodes from DB. // Create trie that will load all nodes from DB.
tr, _ := New(tr.Hash(), db) tr, _ := New(tr.Hash(), triedb)
// Remove a random node from the database. It can't be the root node // Remove a random node from the database. It can't be the root node
// because that one is already loaded. // because that one is already loaded.
var rkey []byte var (
rkey common.Hash
rval []byte
robj *cachedNode
)
for { for {
if rkey = keys[rand.Intn(len(keys))]; !bytes.Equal(rkey, tr.Hash().Bytes()) { if memonly {
rkey = memKeys[rand.Intn(len(memKeys))]
} else {
copy(rkey[:], diskKeys[rand.Intn(len(diskKeys))])
}
if rkey != tr.Hash() {
break break
} }
} }
rval, _ := db.Get(rkey) if memonly {
db.Delete(rkey) robj = triedb.nodes[rkey]
delete(triedb.nodes, rkey)
} else {
rval, _ = diskdb.Get(rkey[:])
diskdb.Delete(rkey[:])
}
// Iterate until the error is hit. // Iterate until the error is hit.
seen := make(map[string]bool) seen := make(map[string]bool)
it := tr.NodeIterator(nil) it := tr.NodeIterator(nil)
checkIteratorNoDups(t, it, seen) checkIteratorNoDups(t, it, seen)
missing, ok := it.Error().(*MissingNodeError) missing, ok := it.Error().(*MissingNodeError)
if !ok || !bytes.Equal(missing.NodeHash[:], rkey) { if !ok || missing.NodeHash != rkey {
t.Fatal("didn't hit missing node, got", it.Error()) t.Fatal("didn't hit missing node, got", it.Error())
} }
// Add the node back and continue iteration. // Add the node back and continue iteration.
db.Put(rkey, rval) if memonly {
triedb.nodes[rkey] = robj
} else {
diskdb.Put(rkey[:], rval)
}
checkIteratorNoDups(t, it, seen) checkIteratorNoDups(t, it, seen)
if it.Error() != nil { if it.Error() != nil {
t.Fatal("unexpected error", it.Error()) t.Fatal("unexpected error", it.Error())
@ -328,21 +367,41 @@ func TestIteratorContinueAfterError(t *testing.T) {
// Similar to the test above, this one checks that failure to create nodeIterator at a // Similar to the test above, this one checks that failure to create nodeIterator at a
// certain key prefix behaves correctly when Next is called. The expectation is that Next // certain key prefix behaves correctly when Next is called. The expectation is that Next
// should retry seeking before returning true for the first time. // should retry seeking before returning true for the first time.
func TestIteratorContinueAfterSeekError(t *testing.T) { func TestIteratorContinueAfterSeekErrorDisk(t *testing.T) {
testIteratorContinueAfterSeekError(t, false)
}
func TestIteratorContinueAfterSeekErrorMemonly(t *testing.T) {
testIteratorContinueAfterSeekError(t, true)
}
func testIteratorContinueAfterSeekError(t *testing.T, memonly bool) {
// Commit test trie to db, then remove the node containing "bars". // Commit test trie to db, then remove the node containing "bars".
db, _ := ethdb.NewMemDatabase() diskdb, _ := ethdb.NewMemDatabase()
ctr, _ := New(common.Hash{}, db) triedb := NewDatabase(diskdb)
ctr, _ := New(common.Hash{}, triedb)
for _, val := range testdata1 { for _, val := range testdata1 {
ctr.Update([]byte(val.k), []byte(val.v)) ctr.Update([]byte(val.k), []byte(val.v))
} }
root, _ := ctr.Commit() root, _ := ctr.Commit(nil)
if !memonly {
triedb.Commit(root, true)
}
barNodeHash := common.HexToHash("05041990364eb72fcb1127652ce40d8bab765f2bfe53225b1170d276cc101c2e") barNodeHash := common.HexToHash("05041990364eb72fcb1127652ce40d8bab765f2bfe53225b1170d276cc101c2e")
barNode, _ := db.Get(barNodeHash[:]) var (
db.Delete(barNodeHash[:]) barNodeBlob []byte
barNodeObj *cachedNode
)
if memonly {
barNodeObj = triedb.nodes[barNodeHash]
delete(triedb.nodes, barNodeHash)
} else {
barNodeBlob, _ = diskdb.Get(barNodeHash[:])
diskdb.Delete(barNodeHash[:])
}
// Create a new iterator that seeks to "bars". Seeking can't proceed because // Create a new iterator that seeks to "bars". Seeking can't proceed because
// the node is missing. // the node is missing.
tr, _ := New(root, db) tr, _ := New(root, triedb)
it := tr.NodeIterator([]byte("bars")) it := tr.NodeIterator([]byte("bars"))
missing, ok := it.Error().(*MissingNodeError) missing, ok := it.Error().(*MissingNodeError)
if !ok { if !ok {
@ -350,10 +409,12 @@ func TestIteratorContinueAfterSeekError(t *testing.T) {
} else if missing.NodeHash != barNodeHash { } else if missing.NodeHash != barNodeHash {
t.Fatal("wrong node missing") t.Fatal("wrong node missing")
} }
// Reinsert the missing node. // Reinsert the missing node.
db.Put(barNodeHash[:], barNode[:]) if memonly {
triedb.nodes[barNodeHash] = barNodeObj
} else {
diskdb.Put(barNodeHash[:], barNodeBlob)
}
// Check that iteration produces the right set of values. // Check that iteration produces the right set of values.
if err := checkIteratorOrder(testdata1[2:], NewIterator(it)); err != nil { if err := checkIteratorOrder(testdata1[2:], NewIterator(it)); err != nil {
t.Fatal(err) t.Fatal(err)

@ -22,20 +22,19 @@ import (
"github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/crypto" "github.com/ethereum/go-ethereum/crypto"
"github.com/ethereum/go-ethereum/ethdb"
"github.com/ethereum/go-ethereum/log" "github.com/ethereum/go-ethereum/log"
"github.com/ethereum/go-ethereum/rlp" "github.com/ethereum/go-ethereum/rlp"
) )
// Prove constructs a merkle proof for key. The result contains all // Prove constructs a merkle proof for key. The result contains all encoded nodes
// encoded nodes on the path to the value at key. The value itself is // on the path to the value at key. The value itself is also included in the last
// also included in the last node and can be retrieved by verifying // node and can be retrieved by verifying the proof.
// the proof.
// //
// If the trie does not contain a value for key, the returned proof // If the trie does not contain a value for key, the returned proof contains all
// contains all nodes of the longest existing prefix of the key // nodes of the longest existing prefix of the key (at least the root node), ending
// (at least the root node), ending with the node that proves the // with the node that proves the absence of the key.
// absence of the key. func (t *Trie) Prove(key []byte, fromLevel uint, proofDb ethdb.Putter) error {
func (t *Trie) Prove(key []byte, fromLevel uint, proofDb DatabaseWriter) error {
// Collect all nodes on the path to key. // Collect all nodes on the path to key.
key = keybytesToHex(key) key = keybytesToHex(key)
nodes := []node{} nodes := []node{}
@ -66,7 +65,7 @@ func (t *Trie) Prove(key []byte, fromLevel uint, proofDb DatabaseWriter) error {
panic(fmt.Sprintf("%T: invalid node: %v", tn, tn)) panic(fmt.Sprintf("%T: invalid node: %v", tn, tn))
} }
} }
hasher := newHasher(0, 0) hasher := newHasher(0, 0, nil)
for i, n := range nodes { for i, n := range nodes {
// Don't bother checking for errors here since hasher panics // Don't bother checking for errors here since hasher panics
// if encoding doesn't work and we're not writing to any database. // if encoding doesn't work and we're not writing to any database.
@ -89,19 +88,29 @@ func (t *Trie) Prove(key []byte, fromLevel uint, proofDb DatabaseWriter) error {
return nil return nil
} }
// VerifyProof checks merkle proofs. The given proof must contain the // Prove constructs a merkle proof for key. The result contains all encoded nodes
// value for key in a trie with the given root hash. VerifyProof // on the path to the value at key. The value itself is also included in the last
// returns an error if the proof contains invalid trie nodes or the // node and can be retrieved by verifying the proof.
// wrong value. //
// If the trie does not contain a value for key, the returned proof contains all
// nodes of the longest existing prefix of the key (at least the root node), ending
// with the node that proves the absence of the key.
func (t *SecureTrie) Prove(key []byte, fromLevel uint, proofDb ethdb.Putter) error {
return t.trie.Prove(key, fromLevel, proofDb)
}
// VerifyProof checks merkle proofs. The given proof must contain the value for
// key in a trie with the given root hash. VerifyProof returns an error if the
// proof contains invalid trie nodes or the wrong value.
func VerifyProof(rootHash common.Hash, key []byte, proofDb DatabaseReader) (value []byte, err error, nodes int) { func VerifyProof(rootHash common.Hash, key []byte, proofDb DatabaseReader) (value []byte, err error, nodes int) {
key = keybytesToHex(key) key = keybytesToHex(key)
wantHash := rootHash[:] wantHash := rootHash
for i := 0; ; i++ { for i := 0; ; i++ {
buf, _ := proofDb.Get(wantHash) buf, _ := proofDb.Get(wantHash[:])
if buf == nil { if buf == nil {
return nil, fmt.Errorf("proof node %d (hash %064x) missing", i, wantHash[:]), i return nil, fmt.Errorf("proof node %d (hash %064x) missing", i, wantHash), i
} }
n, err := decodeNode(wantHash, buf, 0) n, err := decodeNode(wantHash[:], buf, 0)
if err != nil { if err != nil {
return nil, fmt.Errorf("bad proof node %d: %v", i, err), i return nil, fmt.Errorf("bad proof node %d: %v", i, err), i
} }
@ -112,7 +121,7 @@ func VerifyProof(rootHash common.Hash, key []byte, proofDb DatabaseReader) (valu
return nil, nil, i return nil, nil, i
case hashNode: case hashNode:
key = keyrest key = keyrest
wantHash = cld copy(wantHash[:], cld)
case valueNode: case valueNode:
return cld, nil, i + 1 return cld, nil, i + 1
} }

@ -23,10 +23,6 @@ import (
"github.com/ethereum/go-ethereum/log" "github.com/ethereum/go-ethereum/log"
) )
var secureKeyPrefix = []byte("secure-key-")
const secureKeyLength = 11 + 32 // Length of the above prefix + 32byte hash
// SecureTrie wraps a trie with key hashing. In a secure trie, all // SecureTrie wraps a trie with key hashing. In a secure trie, all
// access operations hash the key using keccak256. This prevents // access operations hash the key using keccak256. This prevents
// calling code from creating long chains of nodes that // calling code from creating long chains of nodes that
@ -39,25 +35,25 @@ const secureKeyLength = 11 + 32 // Length of the above prefix + 32byte hash
// SecureTrie is not safe for concurrent use. // SecureTrie is not safe for concurrent use.
type SecureTrie struct { type SecureTrie struct {
trie Trie trie Trie
hashKeyBuf [secureKeyLength]byte hashKeyBuf [common.HashLength]byte
secKeyBuf [200]byte
secKeyCache map[string][]byte secKeyCache map[string][]byte
secKeyCacheOwner *SecureTrie // Pointer to self, replace the key cache on mismatch secKeyCacheOwner *SecureTrie // Pointer to self, replace the key cache on mismatch
} }
// NewSecure creates a trie with an existing root node from db. // NewSecure creates a trie with an existing root node from a backing database
// and optional intermediate in-memory node pool.
// //
// If root is the zero hash or the sha3 hash of an empty string, the // If root is the zero hash or the sha3 hash of an empty string, the
// trie is initially empty. Otherwise, New will panic if db is nil // trie is initially empty. Otherwise, New will panic if db is nil
// and returns MissingNodeError if the root node cannot be found. // and returns MissingNodeError if the root node cannot be found.
// //
// Accessing the trie loads nodes from db on demand. // Accessing the trie loads nodes from the database or node pool on demand.
// Loaded nodes are kept around until their 'cache generation' expires. // Loaded nodes are kept around until their 'cache generation' expires.
// A new cache generation is created by each call to Commit. // A new cache generation is created by each call to Commit.
// cachelimit sets the number of past cache generations to keep. // cachelimit sets the number of past cache generations to keep.
func NewSecure(root common.Hash, db Database, cachelimit uint16) (*SecureTrie, error) { func NewSecure(root common.Hash, db *Database, cachelimit uint16) (*SecureTrie, error) {
if db == nil { if db == nil {
panic("NewSecure called with nil database") panic("trie.NewSecure called without a database")
} }
trie, err := New(root, db) trie, err := New(root, db)
if err != nil { if err != nil {
@ -135,7 +131,7 @@ func (t *SecureTrie) GetKey(shaKey []byte) []byte {
if key, ok := t.getSecKeyCache()[string(shaKey)]; ok { if key, ok := t.getSecKeyCache()[string(shaKey)]; ok {
return key return key
} }
key, _ := t.trie.db.Get(t.secKey(shaKey)) key, _ := t.trie.db.preimage(common.BytesToHash(shaKey))
return key return key
} }
@ -144,8 +140,19 @@ func (t *SecureTrie) GetKey(shaKey []byte) []byte {
// //
// Committing flushes nodes from memory. Subsequent Get calls will load nodes // Committing flushes nodes from memory. Subsequent Get calls will load nodes
// from the database. // from the database.
func (t *SecureTrie) Commit() (root common.Hash, err error) { func (t *SecureTrie) Commit(onleaf LeafCallback) (root common.Hash, err error) {
return t.CommitTo(t.trie.db) // Write all the pre-images to the actual disk database
if len(t.getSecKeyCache()) > 0 {
t.trie.db.lock.Lock()
for hk, key := range t.secKeyCache {
t.trie.db.insertPreimage(common.BytesToHash([]byte(hk)), key)
}
t.trie.db.lock.Unlock()
t.secKeyCache = make(map[string][]byte)
}
// Commit the trie to its intermediate node database
return t.trie.Commit(onleaf)
} }
func (t *SecureTrie) Hash() common.Hash { func (t *SecureTrie) Hash() common.Hash {
@ -167,38 +174,11 @@ func (t *SecureTrie) NodeIterator(start []byte) NodeIterator {
return t.trie.NodeIterator(start) return t.trie.NodeIterator(start)
} }
// CommitTo writes all nodes and the secure hash pre-images to the given database.
// Nodes are stored with their sha3 hash as the key.
//
// Committing flushes nodes from memory. Subsequent Get calls will load nodes from
// the trie's database. Calling code must ensure that the changes made to db are
// written back to the trie's attached database before using the trie.
func (t *SecureTrie) CommitTo(db DatabaseWriter) (root common.Hash, err error) {
if len(t.getSecKeyCache()) > 0 {
for hk, key := range t.secKeyCache {
if err := db.Put(t.secKey([]byte(hk)), key); err != nil {
return common.Hash{}, err
}
}
t.secKeyCache = make(map[string][]byte)
}
return t.trie.CommitTo(db)
}
// secKey returns the database key for the preimage of key, as an ephemeral buffer.
// The caller must not hold onto the return value because it will become
// invalid on the next call to hashKey or secKey.
func (t *SecureTrie) secKey(key []byte) []byte {
buf := append(t.secKeyBuf[:0], secureKeyPrefix...)
buf = append(buf, key...)
return buf
}
// hashKey returns the hash of key as an ephemeral buffer. // hashKey returns the hash of key as an ephemeral buffer.
// The caller must not hold onto the return value because it will become // The caller must not hold onto the return value because it will become
// invalid on the next call to hashKey or secKey. // invalid on the next call to hashKey or secKey.
func (t *SecureTrie) hashKey(key []byte) []byte { func (t *SecureTrie) hashKey(key []byte) []byte {
h := newHasher(0, 0) h := newHasher(0, 0, nil)
h.sha.Reset() h.sha.Reset()
h.sha.Write(key) h.sha.Write(key)
buf := h.sha.Sum(t.hashKeyBuf[:0]) buf := h.sha.Sum(t.hashKeyBuf[:0])

@ -28,16 +28,20 @@ import (
) )
func newEmptySecure() *SecureTrie { func newEmptySecure() *SecureTrie {
db, _ := ethdb.NewMemDatabase() diskdb, _ := ethdb.NewMemDatabase()
trie, _ := NewSecure(common.Hash{}, db, 0) triedb := NewDatabase(diskdb)
trie, _ := NewSecure(common.Hash{}, triedb, 0)
return trie return trie
} }
// makeTestSecureTrie creates a large enough secure trie for testing. // makeTestSecureTrie creates a large enough secure trie for testing.
func makeTestSecureTrie() (ethdb.Database, *SecureTrie, map[string][]byte) { func makeTestSecureTrie() (*Database, *SecureTrie, map[string][]byte) {
// Create an empty trie // Create an empty trie
db, _ := ethdb.NewMemDatabase() diskdb, _ := ethdb.NewMemDatabase()
trie, _ := NewSecure(common.Hash{}, db, 0) triedb := NewDatabase(diskdb)
trie, _ := NewSecure(common.Hash{}, triedb, 0)
// Fill it with some arbitrary data // Fill it with some arbitrary data
content := make(map[string][]byte) content := make(map[string][]byte)
@ -58,10 +62,10 @@ func makeTestSecureTrie() (ethdb.Database, *SecureTrie, map[string][]byte) {
trie.Update(key, val) trie.Update(key, val)
} }
} }
trie.Commit() trie.Commit(nil)
// Return the generated trie // Return the generated trie
return db, trie, content return triedb, trie, content
} }
func TestSecureDelete(t *testing.T) { func TestSecureDelete(t *testing.T) {
@ -137,7 +141,7 @@ func TestSecureTrieConcurrency(t *testing.T) {
tries[index].Update(key, val) tries[index].Update(key, val)
} }
} }
tries[index].Commit() tries[index].Commit(nil)
}(i) }(i)
} }
// Wait for all threads to finish // Wait for all threads to finish

@ -21,6 +21,7 @@ import (
"fmt" "fmt"
"github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/ethdb"
"gopkg.in/karalabe/cookiejar.v2/collections/prque" "gopkg.in/karalabe/cookiejar.v2/collections/prque"
) )
@ -42,7 +43,7 @@ type request struct {
depth int // Depth level within the trie the node is located to prioritise DFS depth int // Depth level within the trie the node is located to prioritise DFS
deps int // Number of dependencies before allowed to commit this node deps int // Number of dependencies before allowed to commit this node
callback TrieSyncLeafCallback // Callback to invoke if a leaf node it reached on this branch callback LeafCallback // Callback to invoke if a leaf node it reached on this branch
} }
// SyncResult is a simple list to return missing nodes along with their request // SyncResult is a simple list to return missing nodes along with their request
@ -67,11 +68,6 @@ func newSyncMemBatch() *syncMemBatch {
} }
} }
// TrieSyncLeafCallback is a callback type invoked when a trie sync reaches a
// leaf node. It's used by state syncing to check if the leaf node requires some
// further data syncing.
type TrieSyncLeafCallback func(leaf []byte, parent common.Hash) error
// TrieSync is the main state trie synchronisation scheduler, which provides yet // TrieSync is the main state trie synchronisation scheduler, which provides yet
// unknown trie hashes to retrieve, accepts node data associated with said hashes // unknown trie hashes to retrieve, accepts node data associated with said hashes
// and reconstructs the trie step by step until all is done. // and reconstructs the trie step by step until all is done.
@ -83,7 +79,7 @@ type TrieSync struct {
} }
// NewTrieSync creates a new trie data download scheduler. // NewTrieSync creates a new trie data download scheduler.
func NewTrieSync(root common.Hash, database DatabaseReader, callback TrieSyncLeafCallback) *TrieSync { func NewTrieSync(root common.Hash, database DatabaseReader, callback LeafCallback) *TrieSync {
ts := &TrieSync{ ts := &TrieSync{
database: database, database: database,
membatch: newSyncMemBatch(), membatch: newSyncMemBatch(),
@ -95,7 +91,7 @@ func NewTrieSync(root common.Hash, database DatabaseReader, callback TrieSyncLea
} }
// AddSubTrie registers a new trie to the sync code, rooted at the designated parent. // AddSubTrie registers a new trie to the sync code, rooted at the designated parent.
func (s *TrieSync) AddSubTrie(root common.Hash, depth int, parent common.Hash, callback TrieSyncLeafCallback) { func (s *TrieSync) AddSubTrie(root common.Hash, depth int, parent common.Hash, callback LeafCallback) {
// Short circuit if the trie is empty or already known // Short circuit if the trie is empty or already known
if root == emptyRoot { if root == emptyRoot {
return return
@ -217,7 +213,7 @@ func (s *TrieSync) Process(results []SyncResult) (bool, int, error) {
// Commit flushes the data stored in the internal membatch out to persistent // Commit flushes the data stored in the internal membatch out to persistent
// storage, returning th enumber of items written and any occurred error. // storage, returning th enumber of items written and any occurred error.
func (s *TrieSync) Commit(dbw DatabaseWriter) (int, error) { func (s *TrieSync) Commit(dbw ethdb.Putter) (int, error) {
// Dump the membatch into a database dbw // Dump the membatch into a database dbw
for i, key := range s.membatch.order { for i, key := range s.membatch.order {
if err := dbw.Put(key[:], s.membatch.batch[key]); err != nil { if err := dbw.Put(key[:], s.membatch.batch[key]); err != nil {

@ -25,10 +25,11 @@ import (
) )
// makeTestTrie create a sample test trie to test node-wise reconstruction. // makeTestTrie create a sample test trie to test node-wise reconstruction.
func makeTestTrie() (ethdb.Database, *Trie, map[string][]byte) { func makeTestTrie() (*Database, *Trie, map[string][]byte) {
// Create an empty trie // Create an empty trie
db, _ := ethdb.NewMemDatabase() diskdb, _ := ethdb.NewMemDatabase()
trie, _ := New(common.Hash{}, db) triedb := NewDatabase(diskdb)
trie, _ := New(common.Hash{}, triedb)
// Fill it with some arbitrary data // Fill it with some arbitrary data
content := make(map[string][]byte) content := make(map[string][]byte)
@ -49,15 +50,15 @@ func makeTestTrie() (ethdb.Database, *Trie, map[string][]byte) {
trie.Update(key, val) trie.Update(key, val)
} }
} }
trie.Commit() trie.Commit(nil)
// Return the generated trie // Return the generated trie
return db, trie, content return triedb, trie, content
} }
// checkTrieContents cross references a reconstructed trie with an expected data // checkTrieContents cross references a reconstructed trie with an expected data
// content map. // content map.
func checkTrieContents(t *testing.T, db Database, root []byte, content map[string][]byte) { func checkTrieContents(t *testing.T, db *Database, root []byte, content map[string][]byte) {
// Check root availability and trie contents // Check root availability and trie contents
trie, err := New(common.BytesToHash(root), db) trie, err := New(common.BytesToHash(root), db)
if err != nil { if err != nil {
@ -74,7 +75,7 @@ func checkTrieContents(t *testing.T, db Database, root []byte, content map[strin
} }
// checkTrieConsistency checks that all nodes in a trie are indeed present. // checkTrieConsistency checks that all nodes in a trie are indeed present.
func checkTrieConsistency(db Database, root common.Hash) error { func checkTrieConsistency(db *Database, root common.Hash) error {
// Create and iterate a trie rooted in a subnode // Create and iterate a trie rooted in a subnode
trie, err := New(root, db) trie, err := New(root, db)
if err != nil { if err != nil {
@ -88,12 +89,18 @@ func checkTrieConsistency(db Database, root common.Hash) error {
// Tests that an empty trie is not scheduled for syncing. // Tests that an empty trie is not scheduled for syncing.
func TestEmptyTrieSync(t *testing.T) { func TestEmptyTrieSync(t *testing.T) {
emptyA, _ := New(common.Hash{}, nil) diskdbA, _ := ethdb.NewMemDatabase()
emptyB, _ := New(emptyRoot, nil) triedbA := NewDatabase(diskdbA)
diskdbB, _ := ethdb.NewMemDatabase()
triedbB := NewDatabase(diskdbB)
emptyA, _ := New(common.Hash{}, triedbA)
emptyB, _ := New(emptyRoot, triedbB)
for i, trie := range []*Trie{emptyA, emptyB} { for i, trie := range []*Trie{emptyA, emptyB} {
db, _ := ethdb.NewMemDatabase() diskdb, _ := ethdb.NewMemDatabase()
if req := NewTrieSync(common.BytesToHash(trie.Root()), db, nil).Missing(1); len(req) != 0 { if req := NewTrieSync(trie.Hash(), diskdb, nil).Missing(1); len(req) != 0 {
t.Errorf("test %d: content requested for empty trie: %v", i, req) t.Errorf("test %d: content requested for empty trie: %v", i, req)
} }
} }
@ -109,14 +116,15 @@ func testIterativeTrieSync(t *testing.T, batch int) {
srcDb, srcTrie, srcData := makeTestTrie() srcDb, srcTrie, srcData := makeTestTrie()
// Create a destination trie and sync with the scheduler // Create a destination trie and sync with the scheduler
dstDb, _ := ethdb.NewMemDatabase() diskdb, _ := ethdb.NewMemDatabase()
sched := NewTrieSync(common.BytesToHash(srcTrie.Root()), dstDb, nil) triedb := NewDatabase(diskdb)
sched := NewTrieSync(srcTrie.Hash(), diskdb, nil)
queue := append([]common.Hash{}, sched.Missing(batch)...) queue := append([]common.Hash{}, sched.Missing(batch)...)
for len(queue) > 0 { for len(queue) > 0 {
results := make([]SyncResult, len(queue)) results := make([]SyncResult, len(queue))
for i, hash := range queue { for i, hash := range queue {
data, err := srcDb.Get(hash.Bytes()) data, err := srcDb.Node(hash)
if err != nil { if err != nil {
t.Fatalf("failed to retrieve node data for %x: %v", hash, err) t.Fatalf("failed to retrieve node data for %x: %v", hash, err)
} }
@ -125,13 +133,13 @@ func testIterativeTrieSync(t *testing.T, batch int) {
if _, index, err := sched.Process(results); err != nil { if _, index, err := sched.Process(results); err != nil {
t.Fatalf("failed to process result #%d: %v", index, err) t.Fatalf("failed to process result #%d: %v", index, err)
} }
if index, err := sched.Commit(dstDb); err != nil { if index, err := sched.Commit(diskdb); err != nil {
t.Fatalf("failed to commit data #%d: %v", index, err) t.Fatalf("failed to commit data #%d: %v", index, err)
} }
queue = append(queue[:0], sched.Missing(batch)...) queue = append(queue[:0], sched.Missing(batch)...)
} }
// Cross check that the two tries are in sync // Cross check that the two tries are in sync
checkTrieContents(t, dstDb, srcTrie.Root(), srcData) checkTrieContents(t, triedb, srcTrie.Root(), srcData)
} }
// Tests that the trie scheduler can correctly reconstruct the state even if only // Tests that the trie scheduler can correctly reconstruct the state even if only
@ -141,15 +149,16 @@ func TestIterativeDelayedTrieSync(t *testing.T) {
srcDb, srcTrie, srcData := makeTestTrie() srcDb, srcTrie, srcData := makeTestTrie()
// Create a destination trie and sync with the scheduler // Create a destination trie and sync with the scheduler
dstDb, _ := ethdb.NewMemDatabase() diskdb, _ := ethdb.NewMemDatabase()
sched := NewTrieSync(common.BytesToHash(srcTrie.Root()), dstDb, nil) triedb := NewDatabase(diskdb)
sched := NewTrieSync(srcTrie.Hash(), diskdb, nil)
queue := append([]common.Hash{}, sched.Missing(10000)...) queue := append([]common.Hash{}, sched.Missing(10000)...)
for len(queue) > 0 { for len(queue) > 0 {
// Sync only half of the scheduled nodes // Sync only half of the scheduled nodes
results := make([]SyncResult, len(queue)/2+1) results := make([]SyncResult, len(queue)/2+1)
for i, hash := range queue[:len(results)] { for i, hash := range queue[:len(results)] {
data, err := srcDb.Get(hash.Bytes()) data, err := srcDb.Node(hash)
if err != nil { if err != nil {
t.Fatalf("failed to retrieve node data for %x: %v", hash, err) t.Fatalf("failed to retrieve node data for %x: %v", hash, err)
} }
@ -158,13 +167,13 @@ func TestIterativeDelayedTrieSync(t *testing.T) {
if _, index, err := sched.Process(results); err != nil { if _, index, err := sched.Process(results); err != nil {
t.Fatalf("failed to process result #%d: %v", index, err) t.Fatalf("failed to process result #%d: %v", index, err)
} }
if index, err := sched.Commit(dstDb); err != nil { if index, err := sched.Commit(diskdb); err != nil {
t.Fatalf("failed to commit data #%d: %v", index, err) t.Fatalf("failed to commit data #%d: %v", index, err)
} }
queue = append(queue[len(results):], sched.Missing(10000)...) queue = append(queue[len(results):], sched.Missing(10000)...)
} }
// Cross check that the two tries are in sync // Cross check that the two tries are in sync
checkTrieContents(t, dstDb, srcTrie.Root(), srcData) checkTrieContents(t, triedb, srcTrie.Root(), srcData)
} }
// Tests that given a root hash, a trie can sync iteratively on a single thread, // Tests that given a root hash, a trie can sync iteratively on a single thread,
@ -178,8 +187,9 @@ func testIterativeRandomTrieSync(t *testing.T, batch int) {
srcDb, srcTrie, srcData := makeTestTrie() srcDb, srcTrie, srcData := makeTestTrie()
// Create a destination trie and sync with the scheduler // Create a destination trie and sync with the scheduler
dstDb, _ := ethdb.NewMemDatabase() diskdb, _ := ethdb.NewMemDatabase()
sched := NewTrieSync(common.BytesToHash(srcTrie.Root()), dstDb, nil) triedb := NewDatabase(diskdb)
sched := NewTrieSync(srcTrie.Hash(), diskdb, nil)
queue := make(map[common.Hash]struct{}) queue := make(map[common.Hash]struct{})
for _, hash := range sched.Missing(batch) { for _, hash := range sched.Missing(batch) {
@ -189,7 +199,7 @@ func testIterativeRandomTrieSync(t *testing.T, batch int) {
// Fetch all the queued nodes in a random order // Fetch all the queued nodes in a random order
results := make([]SyncResult, 0, len(queue)) results := make([]SyncResult, 0, len(queue))
for hash := range queue { for hash := range queue {
data, err := srcDb.Get(hash.Bytes()) data, err := srcDb.Node(hash)
if err != nil { if err != nil {
t.Fatalf("failed to retrieve node data for %x: %v", hash, err) t.Fatalf("failed to retrieve node data for %x: %v", hash, err)
} }
@ -199,7 +209,7 @@ func testIterativeRandomTrieSync(t *testing.T, batch int) {
if _, index, err := sched.Process(results); err != nil { if _, index, err := sched.Process(results); err != nil {
t.Fatalf("failed to process result #%d: %v", index, err) t.Fatalf("failed to process result #%d: %v", index, err)
} }
if index, err := sched.Commit(dstDb); err != nil { if index, err := sched.Commit(diskdb); err != nil {
t.Fatalf("failed to commit data #%d: %v", index, err) t.Fatalf("failed to commit data #%d: %v", index, err)
} }
queue = make(map[common.Hash]struct{}) queue = make(map[common.Hash]struct{})
@ -208,7 +218,7 @@ func testIterativeRandomTrieSync(t *testing.T, batch int) {
} }
} }
// Cross check that the two tries are in sync // Cross check that the two tries are in sync
checkTrieContents(t, dstDb, srcTrie.Root(), srcData) checkTrieContents(t, triedb, srcTrie.Root(), srcData)
} }
// Tests that the trie scheduler can correctly reconstruct the state even if only // Tests that the trie scheduler can correctly reconstruct the state even if only
@ -218,8 +228,9 @@ func TestIterativeRandomDelayedTrieSync(t *testing.T) {
srcDb, srcTrie, srcData := makeTestTrie() srcDb, srcTrie, srcData := makeTestTrie()
// Create a destination trie and sync with the scheduler // Create a destination trie and sync with the scheduler
dstDb, _ := ethdb.NewMemDatabase() diskdb, _ := ethdb.NewMemDatabase()
sched := NewTrieSync(common.BytesToHash(srcTrie.Root()), dstDb, nil) triedb := NewDatabase(diskdb)
sched := NewTrieSync(srcTrie.Hash(), diskdb, nil)
queue := make(map[common.Hash]struct{}) queue := make(map[common.Hash]struct{})
for _, hash := range sched.Missing(10000) { for _, hash := range sched.Missing(10000) {
@ -229,7 +240,7 @@ func TestIterativeRandomDelayedTrieSync(t *testing.T) {
// Sync only half of the scheduled nodes, even those in random order // Sync only half of the scheduled nodes, even those in random order
results := make([]SyncResult, 0, len(queue)/2+1) results := make([]SyncResult, 0, len(queue)/2+1)
for hash := range queue { for hash := range queue {
data, err := srcDb.Get(hash.Bytes()) data, err := srcDb.Node(hash)
if err != nil { if err != nil {
t.Fatalf("failed to retrieve node data for %x: %v", hash, err) t.Fatalf("failed to retrieve node data for %x: %v", hash, err)
} }
@ -243,7 +254,7 @@ func TestIterativeRandomDelayedTrieSync(t *testing.T) {
if _, index, err := sched.Process(results); err != nil { if _, index, err := sched.Process(results); err != nil {
t.Fatalf("failed to process result #%d: %v", index, err) t.Fatalf("failed to process result #%d: %v", index, err)
} }
if index, err := sched.Commit(dstDb); err != nil { if index, err := sched.Commit(diskdb); err != nil {
t.Fatalf("failed to commit data #%d: %v", index, err) t.Fatalf("failed to commit data #%d: %v", index, err)
} }
for _, result := range results { for _, result := range results {
@ -254,7 +265,7 @@ func TestIterativeRandomDelayedTrieSync(t *testing.T) {
} }
} }
// Cross check that the two tries are in sync // Cross check that the two tries are in sync
checkTrieContents(t, dstDb, srcTrie.Root(), srcData) checkTrieContents(t, triedb, srcTrie.Root(), srcData)
} }
// Tests that a trie sync will not request nodes multiple times, even if they // Tests that a trie sync will not request nodes multiple times, even if they
@ -264,8 +275,9 @@ func TestDuplicateAvoidanceTrieSync(t *testing.T) {
srcDb, srcTrie, srcData := makeTestTrie() srcDb, srcTrie, srcData := makeTestTrie()
// Create a destination trie and sync with the scheduler // Create a destination trie and sync with the scheduler
dstDb, _ := ethdb.NewMemDatabase() diskdb, _ := ethdb.NewMemDatabase()
sched := NewTrieSync(common.BytesToHash(srcTrie.Root()), dstDb, nil) triedb := NewDatabase(diskdb)
sched := NewTrieSync(srcTrie.Hash(), diskdb, nil)
queue := append([]common.Hash{}, sched.Missing(0)...) queue := append([]common.Hash{}, sched.Missing(0)...)
requested := make(map[common.Hash]struct{}) requested := make(map[common.Hash]struct{})
@ -273,7 +285,7 @@ func TestDuplicateAvoidanceTrieSync(t *testing.T) {
for len(queue) > 0 { for len(queue) > 0 {
results := make([]SyncResult, len(queue)) results := make([]SyncResult, len(queue))
for i, hash := range queue { for i, hash := range queue {
data, err := srcDb.Get(hash.Bytes()) data, err := srcDb.Node(hash)
if err != nil { if err != nil {
t.Fatalf("failed to retrieve node data for %x: %v", hash, err) t.Fatalf("failed to retrieve node data for %x: %v", hash, err)
} }
@ -287,13 +299,13 @@ func TestDuplicateAvoidanceTrieSync(t *testing.T) {
if _, index, err := sched.Process(results); err != nil { if _, index, err := sched.Process(results); err != nil {
t.Fatalf("failed to process result #%d: %v", index, err) t.Fatalf("failed to process result #%d: %v", index, err)
} }
if index, err := sched.Commit(dstDb); err != nil { if index, err := sched.Commit(diskdb); err != nil {
t.Fatalf("failed to commit data #%d: %v", index, err) t.Fatalf("failed to commit data #%d: %v", index, err)
} }
queue = append(queue[:0], sched.Missing(0)...) queue = append(queue[:0], sched.Missing(0)...)
} }
// Cross check that the two tries are in sync // Cross check that the two tries are in sync
checkTrieContents(t, dstDb, srcTrie.Root(), srcData) checkTrieContents(t, triedb, srcTrie.Root(), srcData)
} }
// Tests that at any point in time during a sync, only complete sub-tries are in // Tests that at any point in time during a sync, only complete sub-tries are in
@ -303,8 +315,9 @@ func TestIncompleteTrieSync(t *testing.T) {
srcDb, srcTrie, _ := makeTestTrie() srcDb, srcTrie, _ := makeTestTrie()
// Create a destination trie and sync with the scheduler // Create a destination trie and sync with the scheduler
dstDb, _ := ethdb.NewMemDatabase() diskdb, _ := ethdb.NewMemDatabase()
sched := NewTrieSync(common.BytesToHash(srcTrie.Root()), dstDb, nil) triedb := NewDatabase(diskdb)
sched := NewTrieSync(srcTrie.Hash(), diskdb, nil)
added := []common.Hash{} added := []common.Hash{}
queue := append([]common.Hash{}, sched.Missing(1)...) queue := append([]common.Hash{}, sched.Missing(1)...)
@ -312,7 +325,7 @@ func TestIncompleteTrieSync(t *testing.T) {
// Fetch a batch of trie nodes // Fetch a batch of trie nodes
results := make([]SyncResult, len(queue)) results := make([]SyncResult, len(queue))
for i, hash := range queue { for i, hash := range queue {
data, err := srcDb.Get(hash.Bytes()) data, err := srcDb.Node(hash)
if err != nil { if err != nil {
t.Fatalf("failed to retrieve node data for %x: %v", hash, err) t.Fatalf("failed to retrieve node data for %x: %v", hash, err)
} }
@ -322,7 +335,7 @@ func TestIncompleteTrieSync(t *testing.T) {
if _, index, err := sched.Process(results); err != nil { if _, index, err := sched.Process(results); err != nil {
t.Fatalf("failed to process result #%d: %v", index, err) t.Fatalf("failed to process result #%d: %v", index, err)
} }
if index, err := sched.Commit(dstDb); err != nil { if index, err := sched.Commit(diskdb); err != nil {
t.Fatalf("failed to commit data #%d: %v", index, err) t.Fatalf("failed to commit data #%d: %v", index, err)
} }
for _, result := range results { for _, result := range results {
@ -330,7 +343,7 @@ func TestIncompleteTrieSync(t *testing.T) {
} }
// Check that all known sub-tries in the synced trie are complete // Check that all known sub-tries in the synced trie are complete
for _, root := range added { for _, root := range added {
if err := checkTrieConsistency(dstDb, root); err != nil { if err := checkTrieConsistency(triedb, root); err != nil {
t.Fatalf("trie inconsistent: %v", err) t.Fatalf("trie inconsistent: %v", err)
} }
} }
@ -340,12 +353,12 @@ func TestIncompleteTrieSync(t *testing.T) {
// Sanity check that removing any node from the database is detected // Sanity check that removing any node from the database is detected
for _, node := range added[1:] { for _, node := range added[1:] {
key := node.Bytes() key := node.Bytes()
value, _ := dstDb.Get(key) value, _ := diskdb.Get(key)
dstDb.Delete(key) diskdb.Delete(key)
if err := checkTrieConsistency(dstDb, added[0]); err == nil { if err := checkTrieConsistency(triedb, added[0]); err == nil {
t.Fatalf("trie inconsistency not caught, missing: %x", key) t.Fatalf("trie inconsistency not caught, missing: %x", key)
} }
dstDb.Put(key, value) diskdb.Put(key, value)
} }
} }

@ -22,16 +22,17 @@ import (
"fmt" "fmt"
"github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/crypto/sha3" "github.com/ethereum/go-ethereum/crypto"
"github.com/ethereum/go-ethereum/log" "github.com/ethereum/go-ethereum/log"
"github.com/rcrowley/go-metrics" "github.com/rcrowley/go-metrics"
) )
var ( var (
// This is the known root hash of an empty trie. // emptyRoot is the known root hash of an empty trie.
emptyRoot = common.HexToHash("56e81f171bcc55a6ff8345e692c0f86e5b48e01b996cadc001622fb5e363b421") emptyRoot = common.HexToHash("56e81f171bcc55a6ff8345e692c0f86e5b48e01b996cadc001622fb5e363b421")
// This is the known hash of an empty state trie entry.
emptyState common.Hash // emptyState is the known hash of an empty state trie entry.
emptyState = crypto.Keccak256Hash(nil)
) )
var ( var (
@ -53,29 +54,10 @@ func CacheUnloads() int64 {
return cacheUnloadCounter.Count() return cacheUnloadCounter.Count()
} }
func init() { // LeafCallback is a callback type invoked when a trie operation reaches a leaf
sha3.NewKeccak256().Sum(emptyState[:0]) // node. It's used by state sync and commit to allow handling external references
} // between account and storage tries.
type LeafCallback func(leaf []byte, parent common.Hash) error
// Database must be implemented by backing stores for the trie.
type Database interface {
DatabaseReader
DatabaseWriter
}
// DatabaseReader wraps the Get method of a backing store for the trie.
type DatabaseReader interface {
Get(key []byte) (value []byte, err error)
Has(key []byte) (bool, error)
}
// DatabaseWriter wraps the Put method of a backing store for the trie.
type DatabaseWriter interface {
// Put stores the mapping key->value in the database.
// Implementations must not hold onto the value bytes, the trie
// will reuse the slice across calls to Put.
Put(key, value []byte) error
}
// Trie is a Merkle Patricia Trie. // Trie is a Merkle Patricia Trie.
// The zero value is an empty trie with no database. // The zero value is an empty trie with no database.
@ -83,8 +65,8 @@ type DatabaseWriter interface {
// //
// Trie is not safe for concurrent use. // Trie is not safe for concurrent use.
type Trie struct { type Trie struct {
db *Database
root node root node
db Database
originalRoot common.Hash originalRoot common.Hash
// Cache generation values. // Cache generation values.
@ -111,12 +93,15 @@ func (t *Trie) newFlag() nodeFlag {
// trie is initially empty and does not require a database. Otherwise, // trie is initially empty and does not require a database. Otherwise,
// New will panic if db is nil and returns a MissingNodeError if root does // New will panic if db is nil and returns a MissingNodeError if root does
// not exist in the database. Accessing the trie loads nodes from db on demand. // not exist in the database. Accessing the trie loads nodes from db on demand.
func New(root common.Hash, db Database) (*Trie, error) { func New(root common.Hash, db *Database) (*Trie, error) {
trie := &Trie{db: db, originalRoot: root}
if (root != common.Hash{}) && root != emptyRoot {
if db == nil { if db == nil {
panic("trie.New: cannot use existing root without a database") panic("trie.New called without a database")
}
trie := &Trie{
db: db,
originalRoot: root,
} }
if (root != common.Hash{}) && root != emptyRoot {
rootnode, err := trie.resolveHash(root[:], nil) rootnode, err := trie.resolveHash(root[:], nil)
if err != nil { if err != nil {
return nil, err return nil, err
@ -447,12 +432,13 @@ func (t *Trie) resolve(n node, prefix []byte) (node, error) {
func (t *Trie) resolveHash(n hashNode, prefix []byte) (node, error) { func (t *Trie) resolveHash(n hashNode, prefix []byte) (node, error) {
cacheMissCounter.Inc(1) cacheMissCounter.Inc(1)
enc, err := t.db.Get(n) hash := common.BytesToHash(n)
enc, err := t.db.Node(hash)
if err != nil || enc == nil { if err != nil || enc == nil {
return nil, &MissingNodeError{NodeHash: common.BytesToHash(n), Path: prefix} return nil, &MissingNodeError{NodeHash: hash, Path: prefix}
} }
dec := mustDecodeNode(n, enc, t.cachegen) return mustDecodeNode(n, enc, t.cachegen), nil
return dec, nil
} }
// Root returns the root hash of the trie. // Root returns the root hash of the trie.
@ -462,32 +448,18 @@ func (t *Trie) Root() []byte { return t.Hash().Bytes() }
// Hash returns the root hash of the trie. It does not write to the // Hash returns the root hash of the trie. It does not write to the
// database and can be used even if the trie doesn't have one. // database and can be used even if the trie doesn't have one.
func (t *Trie) Hash() common.Hash { func (t *Trie) Hash() common.Hash {
hash, cached, _ := t.hashRoot(nil) hash, cached, _ := t.hashRoot(nil, nil)
t.root = cached t.root = cached
return common.BytesToHash(hash.(hashNode)) return common.BytesToHash(hash.(hashNode))
} }
// Commit writes all nodes to the trie's database. // Commit writes all nodes to the trie's memory database, tracking the internal
// Nodes are stored with their sha3 hash as the key. // and external (for account tries) references.
// func (t *Trie) Commit(onleaf LeafCallback) (root common.Hash, err error) {
// Committing flushes nodes from memory.
// Subsequent Get calls will load nodes from the database.
func (t *Trie) Commit() (root common.Hash, err error) {
if t.db == nil { if t.db == nil {
panic("Commit called on trie with nil database") panic("commit called on trie with nil database")
}
return t.CommitTo(t.db)
} }
hash, cached, err := t.hashRoot(t.db, onleaf)
// CommitTo writes all nodes to the given database.
// Nodes are stored with their sha3 hash as the key.
//
// Committing flushes nodes from memory. Subsequent Get calls will
// load nodes from the trie's database. Calling code must ensure that
// the changes made to db are written back to the trie's attached
// database before using the trie.
func (t *Trie) CommitTo(db DatabaseWriter) (root common.Hash, err error) {
hash, cached, err := t.hashRoot(db)
if err != nil { if err != nil {
return common.Hash{}, err return common.Hash{}, err
} }
@ -496,11 +468,11 @@ func (t *Trie) CommitTo(db DatabaseWriter) (root common.Hash, err error) {
return common.BytesToHash(hash.(hashNode)), nil return common.BytesToHash(hash.(hashNode)), nil
} }
func (t *Trie) hashRoot(db DatabaseWriter) (node, node, error) { func (t *Trie) hashRoot(db *Database, onleaf LeafCallback) (node, node, error) {
if t.root == nil { if t.root == nil {
return hashNode(emptyRoot.Bytes()), nil, nil return hashNode(emptyRoot.Bytes()), nil, nil
} }
h := newHasher(t.cachegen, t.cachelimit) h := newHasher(t.cachegen, t.cachelimit, onleaf)
defer returnHasherToPool(h) defer returnHasherToPool(h)
return h.hash(t.root, db, true) return h.hash(t.root, db, true)
} }

@ -43,8 +43,8 @@ func init() {
// Used for testing // Used for testing
func newEmpty() *Trie { func newEmpty() *Trie {
db, _ := ethdb.NewMemDatabase() diskdb, _ := ethdb.NewMemDatabase()
trie, _ := New(common.Hash{}, db) trie, _ := New(common.Hash{}, NewDatabase(diskdb))
return trie return trie
} }
@ -68,8 +68,8 @@ func TestNull(t *testing.T) {
} }
func TestMissingRoot(t *testing.T) { func TestMissingRoot(t *testing.T) {
db, _ := ethdb.NewMemDatabase() diskdb, _ := ethdb.NewMemDatabase()
trie, err := New(common.HexToHash("0beec7b5ea3f0fdbc95d0dd47f3c5bc275da8a33"), db) trie, err := New(common.HexToHash("0beec7b5ea3f0fdbc95d0dd47f3c5bc275da8a33"), NewDatabase(diskdb))
if trie != nil { if trie != nil {
t.Error("New returned non-nil trie for invalid root") t.Error("New returned non-nil trie for invalid root")
} }
@ -78,70 +78,75 @@ func TestMissingRoot(t *testing.T) {
} }
} }
func TestMissingNode(t *testing.T) { func TestMissingNodeDisk(t *testing.T) { testMissingNode(t, false) }
db, _ := ethdb.NewMemDatabase() func TestMissingNodeMemonly(t *testing.T) { testMissingNode(t, true) }
trie, _ := New(common.Hash{}, db)
func testMissingNode(t *testing.T, memonly bool) {
diskdb, _ := ethdb.NewMemDatabase()
triedb := NewDatabase(diskdb)
trie, _ := New(common.Hash{}, triedb)
updateString(trie, "120000", "qwerqwerqwerqwerqwerqwerqwerqwer") updateString(trie, "120000", "qwerqwerqwerqwerqwerqwerqwerqwer")
updateString(trie, "123456", "asdfasdfasdfasdfasdfasdfasdfasdf") updateString(trie, "123456", "asdfasdfasdfasdfasdfasdfasdfasdf")
root, _ := trie.Commit() root, _ := trie.Commit(nil)
if !memonly {
triedb.Commit(root, true)
}
trie, _ = New(root, db) trie, _ = New(root, triedb)
_, err := trie.TryGet([]byte("120000")) _, err := trie.TryGet([]byte("120000"))
if err != nil { if err != nil {
t.Errorf("Unexpected error: %v", err) t.Errorf("Unexpected error: %v", err)
} }
trie, _ = New(root, triedb)
trie, _ = New(root, db)
_, err = trie.TryGet([]byte("120099")) _, err = trie.TryGet([]byte("120099"))
if err != nil { if err != nil {
t.Errorf("Unexpected error: %v", err) t.Errorf("Unexpected error: %v", err)
} }
trie, _ = New(root, triedb)
trie, _ = New(root, db)
_, err = trie.TryGet([]byte("123456")) _, err = trie.TryGet([]byte("123456"))
if err != nil { if err != nil {
t.Errorf("Unexpected error: %v", err) t.Errorf("Unexpected error: %v", err)
} }
trie, _ = New(root, triedb)
trie, _ = New(root, db)
err = trie.TryUpdate([]byte("120099"), []byte("zxcvzxcvzxcvzxcvzxcvzxcvzxcvzxcv")) err = trie.TryUpdate([]byte("120099"), []byte("zxcvzxcvzxcvzxcvzxcvzxcvzxcvzxcv"))
if err != nil { if err != nil {
t.Errorf("Unexpected error: %v", err) t.Errorf("Unexpected error: %v", err)
} }
trie, _ = New(root, triedb)
trie, _ = New(root, db)
err = trie.TryDelete([]byte("123456")) err = trie.TryDelete([]byte("123456"))
if err != nil { if err != nil {
t.Errorf("Unexpected error: %v", err) t.Errorf("Unexpected error: %v", err)
} }
db.Delete(common.FromHex("e1d943cc8f061a0c0b98162830b970395ac9315654824bf21b73b891365262f9")) hash := common.HexToHash("0xe1d943cc8f061a0c0b98162830b970395ac9315654824bf21b73b891365262f9")
if memonly {
delete(triedb.nodes, hash)
} else {
diskdb.Delete(hash[:])
}
trie, _ = New(root, db) trie, _ = New(root, triedb)
_, err = trie.TryGet([]byte("120000")) _, err = trie.TryGet([]byte("120000"))
if _, ok := err.(*MissingNodeError); !ok { if _, ok := err.(*MissingNodeError); !ok {
t.Errorf("Wrong error: %v", err) t.Errorf("Wrong error: %v", err)
} }
trie, _ = New(root, triedb)
trie, _ = New(root, db)
_, err = trie.TryGet([]byte("120099")) _, err = trie.TryGet([]byte("120099"))
if _, ok := err.(*MissingNodeError); !ok { if _, ok := err.(*MissingNodeError); !ok {
t.Errorf("Wrong error: %v", err) t.Errorf("Wrong error: %v", err)
} }
trie, _ = New(root, triedb)
trie, _ = New(root, db)
_, err = trie.TryGet([]byte("123456")) _, err = trie.TryGet([]byte("123456"))
if err != nil { if err != nil {
t.Errorf("Unexpected error: %v", err) t.Errorf("Unexpected error: %v", err)
} }
trie, _ = New(root, triedb)
trie, _ = New(root, db)
err = trie.TryUpdate([]byte("120099"), []byte("zxcv")) err = trie.TryUpdate([]byte("120099"), []byte("zxcv"))
if _, ok := err.(*MissingNodeError); !ok { if _, ok := err.(*MissingNodeError); !ok {
t.Errorf("Wrong error: %v", err) t.Errorf("Wrong error: %v", err)
} }
trie, _ = New(root, triedb)
trie, _ = New(root, db)
err = trie.TryDelete([]byte("123456")) err = trie.TryDelete([]byte("123456"))
if _, ok := err.(*MissingNodeError); !ok { if _, ok := err.(*MissingNodeError); !ok {
t.Errorf("Wrong error: %v", err) t.Errorf("Wrong error: %v", err)
@ -165,7 +170,7 @@ func TestInsert(t *testing.T) {
updateString(trie, "A", "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa") updateString(trie, "A", "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa")
exp = common.HexToHash("d23786fb4a010da3ce639d66d5e904a11dbc02746d1ce25029e53290cabf28ab") exp = common.HexToHash("d23786fb4a010da3ce639d66d5e904a11dbc02746d1ce25029e53290cabf28ab")
root, err := trie.Commit() root, err := trie.Commit(nil)
if err != nil { if err != nil {
t.Fatalf("commit error: %v", err) t.Fatalf("commit error: %v", err)
} }
@ -194,7 +199,7 @@ func TestGet(t *testing.T) {
if i == 1 { if i == 1 {
return return
} }
trie.Commit() trie.Commit(nil)
} }
} }
@ -263,7 +268,7 @@ func TestReplication(t *testing.T) {
for _, val := range vals { for _, val := range vals {
updateString(trie, val.k, val.v) updateString(trie, val.k, val.v)
} }
exp, err := trie.Commit() exp, err := trie.Commit(nil)
if err != nil { if err != nil {
t.Fatalf("commit error: %v", err) t.Fatalf("commit error: %v", err)
} }
@ -278,7 +283,7 @@ func TestReplication(t *testing.T) {
t.Errorf("trie2 doesn't have %q => %q", kv.k, kv.v) t.Errorf("trie2 doesn't have %q => %q", kv.k, kv.v)
} }
} }
hash, err := trie2.Commit() hash, err := trie2.Commit(nil)
if err != nil { if err != nil {
t.Fatalf("commit error: %v", err) t.Fatalf("commit error: %v", err)
} }
@ -314,7 +319,7 @@ func TestLargeValue(t *testing.T) {
} }
type countingDB struct { type countingDB struct {
Database ethdb.Database
gets map[string]int gets map[string]int
} }
@ -332,19 +337,20 @@ func TestCacheUnload(t *testing.T) {
key2 := "---some other branch" key2 := "---some other branch"
updateString(trie, key1, "this is the branch of key1.") updateString(trie, key1, "this is the branch of key1.")
updateString(trie, key2, "this is the branch of key2.") updateString(trie, key2, "this is the branch of key2.")
root, _ := trie.Commit()
root, _ := trie.Commit(nil)
trie.db.Commit(root, true)
// Commit the trie repeatedly and access key1. // Commit the trie repeatedly and access key1.
// The branch containing it is loaded from DB exactly two times: // The branch containing it is loaded from DB exactly two times:
// in the 0th and 6th iteration. // in the 0th and 6th iteration.
db := &countingDB{Database: trie.db, gets: make(map[string]int)} db := &countingDB{Database: trie.db.diskdb, gets: make(map[string]int)}
trie, _ = New(root, db) trie, _ = New(root, NewDatabase(db))
trie.SetCacheLimit(5) trie.SetCacheLimit(5)
for i := 0; i < 12; i++ { for i := 0; i < 12; i++ {
getString(trie, key1) getString(trie, key1)
trie.Commit() trie.Commit(nil)
} }
// Check that it got loaded two times. // Check that it got loaded two times.
for dbkey, count := range db.gets { for dbkey, count := range db.gets {
if count != 2 { if count != 2 {
@ -407,8 +413,10 @@ func (randTest) Generate(r *rand.Rand, size int) reflect.Value {
} }
func runRandTest(rt randTest) bool { func runRandTest(rt randTest) bool {
db, _ := ethdb.NewMemDatabase() diskdb, _ := ethdb.NewMemDatabase()
tr, _ := New(common.Hash{}, db) triedb := NewDatabase(diskdb)
tr, _ := New(common.Hash{}, triedb)
values := make(map[string]string) // tracks content of the trie values := make(map[string]string) // tracks content of the trie
for i, step := range rt { for i, step := range rt {
@ -426,23 +434,23 @@ func runRandTest(rt randTest) bool {
rt[i].err = fmt.Errorf("mismatch for key 0x%x, got 0x%x want 0x%x", step.key, v, want) rt[i].err = fmt.Errorf("mismatch for key 0x%x, got 0x%x want 0x%x", step.key, v, want)
} }
case opCommit: case opCommit:
_, rt[i].err = tr.Commit() _, rt[i].err = tr.Commit(nil)
case opHash: case opHash:
tr.Hash() tr.Hash()
case opReset: case opReset:
hash, err := tr.Commit() hash, err := tr.Commit(nil)
if err != nil { if err != nil {
rt[i].err = err rt[i].err = err
return false return false
} }
newtr, err := New(hash, db) newtr, err := New(hash, triedb)
if err != nil { if err != nil {
rt[i].err = err rt[i].err = err
return false return false
} }
tr = newtr tr = newtr
case opItercheckhash: case opItercheckhash:
checktr, _ := New(common.Hash{}, nil) checktr, _ := New(common.Hash{}, triedb)
it := NewIterator(tr.NodeIterator(nil)) it := NewIterator(tr.NodeIterator(nil))
for it.Next() { for it.Next() {
checktr.Update(it.Key, it.Value) checktr.Update(it.Key, it.Value)
@ -524,7 +532,7 @@ func benchGet(b *testing.B, commit bool) {
} }
binary.LittleEndian.PutUint64(k, benchElemCount/2) binary.LittleEndian.PutUint64(k, benchElemCount/2)
if commit { if commit {
trie.Commit() trie.Commit(nil)
} }
b.ResetTimer() b.ResetTimer()
@ -534,7 +542,7 @@ func benchGet(b *testing.B, commit bool) {
b.StopTimer() b.StopTimer()
if commit { if commit {
ldb := trie.db.(*ethdb.LDBDatabase) ldb := trie.db.diskdb.(*ethdb.LDBDatabase)
ldb.Close() ldb.Close()
os.RemoveAll(ldb.Path()) os.RemoveAll(ldb.Path())
} }
@ -585,16 +593,16 @@ func BenchmarkHash(b *testing.B) {
trie.Hash() trie.Hash()
} }
func tempDB() (string, Database) { func tempDB() (string, *Database) {
dir, err := ioutil.TempDir("", "trie-bench") dir, err := ioutil.TempDir("", "trie-bench")
if err != nil { if err != nil {
panic(fmt.Sprintf("can't create temporary directory: %v", err)) panic(fmt.Sprintf("can't create temporary directory: %v", err))
} }
db, err := ethdb.NewLDBDatabase(dir, 256, 0) diskdb, err := ethdb.NewLDBDatabase(dir, 256, 0)
if err != nil { if err != nil {
panic(fmt.Sprintf("can't create temporary database: %v", err)) panic(fmt.Sprintf("can't create temporary database: %v", err))
} }
return dir, db return dir, NewDatabase(diskdb)
} }
func getString(trie *Trie, k string) []byte { func getString(trie *Trie, k string) []byte {

Loading…
Cancel
Save