diff --git a/core/state/snapshot/snapshot.go b/core/state/snapshot/snapshot.go index 8889fa9fe..6ee6b06bb 100644 --- a/core/state/snapshot/snapshot.go +++ b/core/state/snapshot/snapshot.go @@ -163,6 +163,9 @@ type Tree struct { cache int // Megabytes permitted to use for read caches layers map[common.Hash]snapshot // Collection of all known layers lock sync.RWMutex + + // Test hooks + onFlatten func() // Hook invoked when the bottom most diff layers are flattened } // New attempts to load an already existing snapshot from a persistent key-value @@ -463,14 +466,21 @@ func (t *Tree) cap(diff *diffLayer, layers int) *diskLayer { return nil case *diffLayer: + // Hold the write lock until the flattened parent is linked correctly. + // Otherwise, the stale layer may be accessed by external reads in the + // meantime. + diff.lock.Lock() + defer diff.lock.Unlock() + // Flatten the parent into the grandparent. The flattening internally obtains a // write lock on grandparent. flattened := parent.flatten().(*diffLayer) t.layers[flattened.root] = flattened - diff.lock.Lock() - defer diff.lock.Unlock() - + // Invoke the hook if it's registered. Ugly hack. + if t.onFlatten != nil { + t.onFlatten() + } diff.parent = flattened if flattened.memory < aggregatorMemoryLimit { // Accumulator layer is smaller than the limit, so we can abort, unless diff --git a/core/state/snapshot/snapshot_test.go b/core/state/snapshot/snapshot_test.go index 4b787cfe2..12f2765b3 100644 --- a/core/state/snapshot/snapshot_test.go +++ b/core/state/snapshot/snapshot_test.go @@ -22,6 +22,7 @@ import ( "math/big" "math/rand" "testing" + "time" "github.com/VictoriaMetrics/fastcache" "github.com/ethereum/go-ethereum/common" @@ -324,7 +325,7 @@ func TestPostCapBasicDataAccess(t *testing.T) { } } -// TestSnaphots tests the functionality for retrieveing the snapshot +// TestSnaphots tests the functionality for retrieving the snapshot // with given head root and the desired depth. func TestSnaphots(t *testing.T) { // setAccount is a helper to construct a random account entry and assign it to @@ -423,3 +424,63 @@ func TestSnaphots(t *testing.T) { } } } + +// TestReadStateDuringFlattening tests the scenario that, during the +// bottom diff layers are merging which tags these as stale, the read +// happens via a pre-created top snapshot layer which tries to access +// the state in these stale layers. Ensure this read can retrieve the +// right state back(block until the flattening is finished) instead of +// an unexpected error(snapshot layer is stale). +func TestReadStateDuringFlattening(t *testing.T) { + // setAccount is a helper to construct a random account entry and assign it to + // an account slot in a snapshot + setAccount := func(accKey string) map[common.Hash][]byte { + return map[common.Hash][]byte{ + common.HexToHash(accKey): randomAccount(), + } + } + // Create a starting base layer and a snapshot tree out of it + base := &diskLayer{ + diskdb: rawdb.NewMemoryDatabase(), + root: common.HexToHash("0x01"), + cache: fastcache.New(1024 * 500), + } + snaps := &Tree{ + layers: map[common.Hash]snapshot{ + base.root: base, + }, + } + // 4 layers in total, 3 diff layers and 1 disk layers + snaps.Update(common.HexToHash("0xa1"), common.HexToHash("0x01"), nil, setAccount("0xa1"), nil) + snaps.Update(common.HexToHash("0xa2"), common.HexToHash("0xa1"), nil, setAccount("0xa2"), nil) + snaps.Update(common.HexToHash("0xa3"), common.HexToHash("0xa2"), nil, setAccount("0xa3"), nil) + + // Obtain the topmost snapshot handler for state accessing + snap := snaps.Snapshot(common.HexToHash("0xa3")) + + // Register the testing hook to access the state after flattening + var result = make(chan *Account) + snaps.onFlatten = func() { + // Spin up a thread to read the account from the pre-created + // snapshot handler. It's expected to be blocked. + go func() { + account, _ := snap.Account(common.HexToHash("0xa1")) + result <- account + }() + select { + case res := <-result: + t.Fatalf("Unexpected return %v", res) + case <-time.NewTimer(time.Millisecond * 300).C: + } + } + // Cap the snap tree, which will mark the bottom-most layer as stale. + snaps.Cap(common.HexToHash("0xa3"), 1) + select { + case account := <-result: + if account == nil { + t.Fatal("Failed to retrieve account") + } + case <-time.NewTimer(time.Millisecond * 300).C: + t.Fatal("Unexpected blocker") + } +}