diff --git a/core/state/journal.go b/core/state/journal.go index ad4a654fc6..f180a5dae4 100644 --- a/core/state/journal.go +++ b/core/state/journal.go @@ -17,12 +17,21 @@ package state import ( + "fmt" "maps" + "slices" + "sort" "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/core/types" "github.com/holiman/uint256" ) +type revision struct { + id int + journalIndex int +} + // journalEntry is a modification entry in the state change journal that can be // reverted on demand. type journalEntry interface { @@ -42,6 +51,9 @@ type journalEntry interface { type journal struct { entries []journalEntry // Current changes tracked by the journal dirties map[common.Address]int // Dirty accounts and the number of changes + + validRevisions []revision + nextRevisionId int } // newJournal creates a new initialized journal. @@ -51,6 +63,40 @@ func newJournal() *journal { } } +// reset clears the journal, after this operation the journal can be used anew. +// It is semantically similar to calling 'newJournal', but the underlying slices +// can be reused. +func (j *journal) reset() { + j.entries = j.entries[:0] + j.validRevisions = j.validRevisions[:0] + clear(j.dirties) + j.nextRevisionId = 0 +} + +// snapshot returns an identifier for the current revision of the state. +func (j *journal) snapshot() int { + id := j.nextRevisionId + j.nextRevisionId++ + j.validRevisions = append(j.validRevisions, revision{id, j.length()}) + return id +} + +// revertToSnapshot reverts all state changes made since the given revision. +func (j *journal) revertToSnapshot(revid int, s *StateDB) { + // Find the snapshot in the stack of valid snapshots. + idx := sort.Search(len(j.validRevisions), func(i int) bool { + return j.validRevisions[i].id >= revid + }) + if idx == len(j.validRevisions) || j.validRevisions[idx].id != revid { + panic(fmt.Errorf("revision id %v cannot be reverted", revid)) + } + snapshot := j.validRevisions[idx].journalIndex + + // Replay the journal to undo changes and remove invalidated snapshots + j.revert(s, snapshot) + j.validRevisions = j.validRevisions[:idx] +} + // append inserts a new modification entry to the end of the change journal. func (j *journal) append(entry journalEntry) { j.entries = append(j.entries, entry) @@ -95,11 +141,90 @@ func (j *journal) copy() *journal { entries = append(entries, j.entries[i].copy()) } return &journal{ - entries: entries, - dirties: maps.Clone(j.dirties), + entries: entries, + dirties: maps.Clone(j.dirties), + validRevisions: slices.Clone(j.validRevisions), + nextRevisionId: j.nextRevisionId, + } +} + +func (j *journal) logChange(txHash common.Hash) { + j.append(addLogChange{txhash: txHash}) +} + +func (j *journal) createObject(addr common.Address) { + j.append(createObjectChange{account: &addr}) +} + +func (j *journal) createContract(addr common.Address) { + j.append(createContractChange{account: addr}) +} + +func (j *journal) destruct(addr common.Address) { + j.append(selfDestructChange{account: &addr}) +} + +func (j *journal) storageChange(addr common.Address, key, prev, origin common.Hash) { + j.append(storageChange{ + account: &addr, + key: key, + prevvalue: prev, + origvalue: origin, + }) +} + +func (j *journal) transientStateChange(addr common.Address, key, prev common.Hash) { + j.append(transientStorageChange{ + account: &addr, + key: key, + prevalue: prev, + }) +} + +func (j *journal) refundChange(previous uint64) { + j.append(refundChange{prev: previous}) +} + +func (j *journal) balanceChange(addr common.Address, previous *uint256.Int) { + j.append(balanceChange{ + account: &addr, + prev: previous.Clone(), + }) +} + +func (j *journal) setCode(address common.Address) { + j.append(codeChange{account: &address}) +} + +func (j *journal) nonceChange(address common.Address, prev uint64) { + j.append(nonceChange{ + account: &address, + prev: prev, + }) +} + +func (j *journal) touchChange(address common.Address) { + j.append(touchChange{ + account: &address, + }) + if address == ripemd { + // Explicitly put it in the dirty-cache, which is otherwise generated from + // flattened journals. + j.dirty(address) } } +func (j *journal) accessListAddAccount(addr common.Address) { + j.append(accessListAddAccountChange{&addr}) +} + +func (j *journal) accessListAddSlot(addr common.Address, slot common.Hash) { + j.append(accessListAddSlotChange{ + address: &addr, + slot: &slot, + }) +} + type ( // Changes to the account trie. createObjectChange struct { @@ -114,9 +239,7 @@ type ( } selfDestructChange struct { - account *common.Address - prev bool // whether account had already self-destructed - prevbalance *uint256.Int + account *common.Address } // Changes to individual accounts. @@ -135,8 +258,7 @@ type ( origvalue common.Hash } codeChange struct { - account *common.Address - prevcode, prevhash []byte + account *common.Address } // Changes to other state values. @@ -146,9 +268,6 @@ type ( addLogChange struct { txhash common.Hash } - addPreimageChange struct { - hash common.Hash - } touchChange struct { account *common.Address } @@ -200,8 +319,7 @@ func (ch createContractChange) copy() journalEntry { func (ch selfDestructChange) revert(s *StateDB) { obj := s.getStateObject(*ch.account) if obj != nil { - obj.selfDestructed = ch.prev - obj.setBalance(ch.prevbalance) + obj.selfDestructed = false } } @@ -211,9 +329,7 @@ func (ch selfDestructChange) dirtied() *common.Address { func (ch selfDestructChange) copy() journalEntry { return selfDestructChange{ - account: ch.account, - prev: ch.prev, - prevbalance: new(uint256.Int).Set(ch.prevbalance), + account: ch.account, } } @@ -263,7 +379,7 @@ func (ch nonceChange) copy() journalEntry { } func (ch codeChange) revert(s *StateDB) { - s.getStateObject(*ch.account).setCode(common.BytesToHash(ch.prevhash), ch.prevcode) + s.getStateObject(*ch.account).setCode(types.EmptyCodeHash, nil) } func (ch codeChange) dirtied() *common.Address { @@ -271,11 +387,7 @@ func (ch codeChange) dirtied() *common.Address { } func (ch codeChange) copy() journalEntry { - return codeChange{ - account: ch.account, - prevhash: common.CopyBytes(ch.prevhash), - prevcode: common.CopyBytes(ch.prevcode), - } + return codeChange{account: ch.account} } func (ch storageChange) revert(s *StateDB) { @@ -344,20 +456,6 @@ func (ch addLogChange) copy() journalEntry { } } -func (ch addPreimageChange) revert(s *StateDB) { - delete(s.preimages, ch.hash) -} - -func (ch addPreimageChange) dirtied() *common.Address { - return nil -} - -func (ch addPreimageChange) copy() journalEntry { - return addPreimageChange{ - hash: ch.hash, - } -} - func (ch accessListAddAccountChange) revert(s *StateDB) { /* One important invariant here, is that whenever a (addr, slot) is added, if the diff --git a/core/state/state_object.go b/core/state/state_object.go index 1e4da105b2..5946683bc0 100644 --- a/core/state/state_object.go +++ b/core/state/state_object.go @@ -114,14 +114,7 @@ func (s *stateObject) markSelfdestructed() { } func (s *stateObject) touch() { - s.db.journal.append(touchChange{ - account: &s.address, - }) - if s.address == ripemd { - // Explicitly put it in the dirty-cache, which is otherwise generated from - // flattened journals. - s.db.journal.dirty(s.address) - } + s.db.journal.touchChange(s.address) } // getTrie returns the associated storage trie. The trie will be opened if it's @@ -252,16 +245,11 @@ func (s *stateObject) SetState(key, value common.Hash) { return } // New value is different, update and journal the change - s.db.journal.append(storageChange{ - account: &s.address, - key: key, - prevvalue: prev, - origvalue: origin, - }) + s.db.journal.storageChange(s.address, key, prev, origin) + s.setState(key, value, origin) if s.db.logger != nil && s.db.logger.OnStorageChange != nil { s.db.logger.OnStorageChange(s.address, key, prev, value) } - s.setState(key, value, origin) } // setState updates a value in account dirty storage. The dirtiness will be @@ -511,10 +499,7 @@ func (s *stateObject) SubBalance(amount *uint256.Int, reason tracing.BalanceChan } func (s *stateObject) SetBalance(amount *uint256.Int, reason tracing.BalanceChangeReason) { - s.db.journal.append(balanceChange{ - account: &s.address, - prev: new(uint256.Int).Set(s.data.Balance), - }) + s.db.journal.balanceChange(s.address, s.data.Balance) if s.db.logger != nil && s.db.logger.OnBalanceChange != nil { s.db.logger.OnBalanceChange(s.address, s.Balance().ToBig(), amount.ToBig(), reason) } @@ -590,14 +575,10 @@ func (s *stateObject) CodeSize() int { } func (s *stateObject) SetCode(codeHash common.Hash, code []byte) { - prevcode := s.Code() - s.db.journal.append(codeChange{ - account: &s.address, - prevhash: s.CodeHash(), - prevcode: prevcode, - }) + s.db.journal.setCode(s.address) if s.db.logger != nil && s.db.logger.OnCodeChange != nil { - s.db.logger.OnCodeChange(s.address, common.BytesToHash(s.CodeHash()), prevcode, codeHash, code) + // TODO remove prevcode from this callback + s.db.logger.OnCodeChange(s.address, common.BytesToHash(s.CodeHash()), nil, codeHash, code) } s.setCode(codeHash, code) } @@ -609,10 +590,7 @@ func (s *stateObject) setCode(codeHash common.Hash, code []byte) { } func (s *stateObject) SetNonce(nonce uint64) { - s.db.journal.append(nonceChange{ - account: &s.address, - prev: s.data.Nonce, - }) + s.db.journal.nonceChange(s.address, s.data.Nonce) if s.db.logger != nil && s.db.logger.OnNonceChange != nil { s.db.logger.OnNonceChange(s.address, s.data.Nonce, nonce) } diff --git a/core/state/statedb.go b/core/state/statedb.go index 32c20a3b7b..d93271d76b 100644 --- a/core/state/statedb.go +++ b/core/state/statedb.go @@ -23,7 +23,6 @@ import ( "maps" "math/big" "slices" - "sort" "sync" "sync/atomic" "time" @@ -48,11 +47,6 @@ import ( // TriesInMemory represents the number of layers that are kept in RAM. const TriesInMemory = 128 -type revision struct { - id int - journalIndex int -} - type mutationType int const ( @@ -143,9 +137,7 @@ type StateDB struct { // Journal of state modifications. This is the backbone of // Snapshot and RevertToSnapshot. - journal *journal - validRevisions []revision - nextRevisionId int + journal *journal // State witness if cross validation is needed witness *stateless.Witness @@ -258,7 +250,7 @@ func (s *StateDB) Error() error { } func (s *StateDB) AddLog(log *types.Log) { - s.journal.append(addLogChange{txhash: s.thash}) + s.journal.logChange(s.thash) log.TxHash = s.thash log.TxIndex = uint(s.txIndex) @@ -292,7 +284,6 @@ func (s *StateDB) Logs() []*types.Log { // AddPreimage records a SHA3 preimage seen by the VM. func (s *StateDB) AddPreimage(hash common.Hash, preimage []byte) { if _, ok := s.preimages[hash]; !ok { - s.journal.append(addPreimageChange{hash: hash}) s.preimages[hash] = slices.Clone(preimage) } } @@ -304,14 +295,14 @@ func (s *StateDB) Preimages() map[common.Hash][]byte { // AddRefund adds gas to the refund counter func (s *StateDB) AddRefund(gas uint64) { - s.journal.append(refundChange{prev: s.refund}) + s.journal.refundChange(s.refund) s.refund += gas } // SubRefund removes gas from the refund counter. // This method will panic if the refund counter goes below zero func (s *StateDB) SubRefund(gas uint64) { - s.journal.append(refundChange{prev: s.refund}) + s.journal.refundChange(s.refund) if gas > s.refund { panic(fmt.Sprintf("Refund counter below zero (gas: %d > refund: %d)", gas, s.refund)) } @@ -508,20 +499,17 @@ func (s *StateDB) SelfDestruct(addr common.Address) { if stateObject == nil { return } - var ( - prev = new(uint256.Int).Set(stateObject.Balance()) - n = new(uint256.Int) - ) - s.journal.append(selfDestructChange{ - account: &addr, - prev: stateObject.selfDestructed, - prevbalance: prev, - }) - if s.logger != nil && s.logger.OnBalanceChange != nil && prev.Sign() > 0 { - s.logger.OnBalanceChange(addr, prev.ToBig(), n.ToBig(), tracing.BalanceDecreaseSelfdestruct) + // Regardless of whether it is already destructed or not, we do have to + // journal the balance-change, if we set it to zero here. + if !stateObject.Balance().IsZero() { + stateObject.SetBalance(new(uint256.Int), tracing.BalanceDecreaseSelfdestruct) + } + // If it is already marked as self-destructed, we do not need to add it + // for journalling a second time. + if !stateObject.selfDestructed { + s.journal.destruct(addr) + stateObject.markSelfdestructed() } - stateObject.markSelfdestructed() - stateObject.data.Balance = n } func (s *StateDB) Selfdestruct6780(addr common.Address) { @@ -542,11 +530,7 @@ func (s *StateDB) SetTransientState(addr common.Address, key, value common.Hash) if prev == value { return } - s.journal.append(transientStorageChange{ - account: &addr, - key: key, - prevalue: prev, - }) + s.journal.transientStateChange(addr, key, prev) s.setTransientState(addr, key, value) } @@ -668,7 +652,7 @@ func (s *StateDB) getOrNewStateObject(addr common.Address) *stateObject { // existing account with the given address, otherwise it will be silently overwritten. func (s *StateDB) createObject(addr common.Address) *stateObject { obj := newObject(s, addr, nil) - s.journal.append(createObjectChange{account: &addr}) + s.journal.createObject(addr) s.setStateObject(obj) return obj } @@ -690,7 +674,7 @@ func (s *StateDB) CreateContract(addr common.Address) { obj := s.getStateObject(addr) if !obj.newContract { obj.newContract = true - s.journal.append(createContractChange{account: addr}) + s.journal.createContract(addr) } } @@ -714,8 +698,6 @@ func (s *StateDB) Copy() *StateDB { logSize: s.logSize, preimages: maps.Clone(s.preimages), journal: s.journal.copy(), - validRevisions: slices.Clone(s.validRevisions), - nextRevisionId: s.nextRevisionId, // In order for the block producer to be able to use and make additions // to the snapshot tree, we need to copy that as well. Otherwise, any @@ -761,26 +743,12 @@ func (s *StateDB) Copy() *StateDB { // Snapshot returns an identifier for the current revision of the state. func (s *StateDB) Snapshot() int { - id := s.nextRevisionId - s.nextRevisionId++ - s.validRevisions = append(s.validRevisions, revision{id, s.journal.length()}) - return id + return s.journal.snapshot() } // RevertToSnapshot reverts all state changes made since the given revision. func (s *StateDB) RevertToSnapshot(revid int) { - // Find the snapshot in the stack of valid snapshots. - idx := sort.Search(len(s.validRevisions), func(i int) bool { - return s.validRevisions[i].id >= revid - }) - if idx == len(s.validRevisions) || s.validRevisions[idx].id != revid { - panic(fmt.Errorf("revision id %v cannot be reverted", revid)) - } - snapshot := s.validRevisions[idx].journalIndex - - // Replay the journal to undo changes and remove invalidated snapshots - s.journal.revert(s, snapshot) - s.validRevisions = s.validRevisions[:idx] + s.journal.revertToSnapshot(revid, s) } // GetRefund returns the current value of the refund counter. @@ -996,11 +964,8 @@ func (s *StateDB) SetTxContext(thash common.Hash, ti int) { } func (s *StateDB) clearJournalAndRefund() { - if len(s.journal.entries) > 0 { - s.journal = newJournal() - s.refund = 0 - } - s.validRevisions = s.validRevisions[:0] // Snapshots can be created without journal entries + s.journal.reset() + s.refund = 0 } // fastDeleteStorage is the function that efficiently deletes the storage trie @@ -1431,7 +1396,7 @@ func (s *StateDB) Prepare(rules params.Rules, sender, coinbase common.Address, d // AddAddressToAccessList adds the given address to the access list func (s *StateDB) AddAddressToAccessList(addr common.Address) { if s.accessList.AddAddress(addr) { - s.journal.append(accessListAddAccountChange{&addr}) + s.journal.accessListAddAccount(addr) } } @@ -1443,13 +1408,10 @@ func (s *StateDB) AddSlotToAccessList(addr common.Address, slot common.Hash) { // scope of 'address' without having the 'address' become already added // to the access list (via call-variant, create, etc). // Better safe than sorry, though - s.journal.append(accessListAddAccountChange{&addr}) + s.journal.accessListAddAccount(addr) } if slotMod { - s.journal.append(accessListAddSlotChange{ - address: &addr, - slot: &slot, - }) + s.journal.accessListAddSlot(addr, slot) } } diff --git a/core/state/statedb_fuzz_test.go b/core/state/statedb_fuzz_test.go index 40b079cd8a..153035b9c1 100644 --- a/core/state/statedb_fuzz_test.go +++ b/core/state/statedb_fuzz_test.go @@ -73,7 +73,7 @@ func newStateTestAction(addr common.Address, r *rand.Rand, index int) testAction args: make([]int64, 1), }, { - name: "SetState", + name: "SetStorage", fn: func(a testAction, s *StateDB) { var key, val common.Hash binary.BigEndian.PutUint16(key[:], uint16(a.args[0])) diff --git a/core/state/statedb_test.go b/core/state/statedb_test.go index 2ce2b868fa..a8ae6eb6d3 100644 --- a/core/state/statedb_test.go +++ b/core/state/statedb_test.go @@ -360,7 +360,7 @@ func newTestAction(addr common.Address, r *rand.Rand) testAction { args: make([]int64, 1), }, { - name: "SetState", + name: "SetStorage", fn: func(a testAction, s *StateDB) { var key, val common.Hash binary.BigEndian.PutUint16(key[:], uint16(a.args[0])) @@ -372,6 +372,12 @@ func newTestAction(addr common.Address, r *rand.Rand) testAction { { name: "SetCode", fn: func(a testAction, s *StateDB) { + // SetCode can only be performed in case the addr does + // not already hold code + if c := s.GetCode(addr); len(c) > 0 { + // no-op + return + } code := make([]byte, 16) binary.BigEndian.PutUint64(code, uint64(a.args[0])) binary.BigEndian.PutUint64(code[8:], uint64(a.args[1]))