@ -21,9 +21,11 @@ import (
"encoding/binary"
"encoding/binary"
"errors"
"errors"
"fmt"
"fmt"
"maps"
"math"
"math"
"math/rand"
"math/rand"
"reflect"
"reflect"
"slices"
"strings"
"strings"
"sync"
"sync"
"testing"
"testing"
@ -557,10 +559,14 @@ func forEachStorage(s *StateDB, addr common.Address, cb func(key, value common.H
if err != nil {
if err != nil {
return err
return err
}
}
it := trie . NewIterator ( trieIt )
var (
it = trie . NewIterator ( trieIt )
visited = make ( map [ common . Hash ] bool )
)
for it . Next ( ) {
for it . Next ( ) {
key := common . BytesToHash ( s . trie . GetKey ( it . Key ) )
key := common . BytesToHash ( s . trie . GetKey ( it . Key ) )
visited [ key ] = true
if value , dirty := so . dirtyStorage [ key ] ; dirty {
if value , dirty := so . dirtyStorage [ key ] ; dirty {
if ! cb ( key , value ) {
if ! cb ( key , value ) {
return nil
return nil
@ -600,6 +606,10 @@ func (test *snapshotTest) checkEqual(state, checkstate *StateDB) error {
checkeq ( "GetCode" , state . GetCode ( addr ) , checkstate . GetCode ( addr ) )
checkeq ( "GetCode" , state . GetCode ( addr ) , checkstate . GetCode ( addr ) )
checkeq ( "GetCodeHash" , state . GetCodeHash ( addr ) , checkstate . GetCodeHash ( addr ) )
checkeq ( "GetCodeHash" , state . GetCodeHash ( addr ) , checkstate . GetCodeHash ( addr ) )
checkeq ( "GetCodeSize" , state . GetCodeSize ( addr ) , checkstate . GetCodeSize ( addr ) )
checkeq ( "GetCodeSize" , state . GetCodeSize ( addr ) , checkstate . GetCodeSize ( addr ) )
// Check newContract-flag
if obj := state . getStateObject ( addr ) ; obj != nil {
checkeq ( "IsNewContract" , obj . newContract , checkstate . getStateObject ( addr ) . newContract )
}
// Check storage.
// Check storage.
if obj := state . getStateObject ( addr ) ; obj != nil {
if obj := state . getStateObject ( addr ) ; obj != nil {
forEachStorage ( state , addr , func ( key , value common . Hash ) bool {
forEachStorage ( state , addr , func ( key , value common . Hash ) bool {
@ -608,12 +618,49 @@ func (test *snapshotTest) checkEqual(state, checkstate *StateDB) error {
forEachStorage ( checkstate , addr , func ( key , value common . Hash ) bool {
forEachStorage ( checkstate , addr , func ( key , value common . Hash ) bool {
return checkeq ( "GetState(" + key . Hex ( ) + ")" , checkstate . GetState ( addr , key ) , value )
return checkeq ( "GetState(" + key . Hex ( ) + ")" , checkstate . GetState ( addr , key ) , value )
} )
} )
other := checkstate . getStateObject ( addr )
// Check dirty storage which is not in trie
if ! maps . Equal ( obj . dirtyStorage , other . dirtyStorage ) {
print := func ( dirty map [ common . Hash ] common . Hash ) string {
var keys [ ] common . Hash
out := new ( strings . Builder )
for key := range dirty {
keys = append ( keys , key )
}
slices . SortFunc ( keys , common . Hash . Cmp )
for i , key := range keys {
fmt . Fprintf ( out , " %d. %v %v\n" , i , key , dirty [ key ] )
}
return out . String ( )
}
return fmt . Errorf ( "dirty storage err, have\n%v\nwant\n%v" ,
print ( obj . dirtyStorage ) ,
print ( other . dirtyStorage ) )
}
}
// Check transient storage.
{
have := state . transientStorage
want := checkstate . transientStorage
eq := maps . EqualFunc ( have , want ,
func ( a Storage , b Storage ) bool {
return maps . Equal ( a , b )
} )
if ! eq {
return fmt . Errorf ( "transient storage differs ,have\n%v\nwant\n%v" ,
have . PrettyPrint ( ) ,
want . PrettyPrint ( ) )
}
}
}
if err != nil {
if err != nil {
return err
return err
}
}
}
}
if ! checkstate . accessList . Equal ( state . accessList ) { // Check access lists
return fmt . Errorf ( "AccessLists are wrong, have \n%v\nwant\n%v" ,
checkstate . accessList . PrettyPrint ( ) ,
state . accessList . PrettyPrint ( ) )
}
if state . GetRefund ( ) != checkstate . GetRefund ( ) {
if state . GetRefund ( ) != checkstate . GetRefund ( ) {
return fmt . Errorf ( "got GetRefund() == %d, want GetRefund() == %d" ,
return fmt . Errorf ( "got GetRefund() == %d, want GetRefund() == %d" ,
state . GetRefund ( ) , checkstate . GetRefund ( ) )
state . GetRefund ( ) , checkstate . GetRefund ( ) )
@ -622,6 +669,23 @@ func (test *snapshotTest) checkEqual(state, checkstate *StateDB) error {
return fmt . Errorf ( "got GetLogs(common.Hash{}) == %v, want GetLogs(common.Hash{}) == %v" ,
return fmt . Errorf ( "got GetLogs(common.Hash{}) == %v, want GetLogs(common.Hash{}) == %v" ,
state . GetLogs ( common . Hash { } , 0 , common . Hash { } ) , checkstate . GetLogs ( common . Hash { } , 0 , common . Hash { } ) )
state . GetLogs ( common . Hash { } , 0 , common . Hash { } ) , checkstate . GetLogs ( common . Hash { } , 0 , common . Hash { } ) )
}
}
if ! maps . Equal ( state . journal . dirties , checkstate . journal . dirties ) {
getKeys := func ( dirty map [ common . Address ] int ) string {
var keys [ ] common . Address
out := new ( strings . Builder )
for key := range dirty {
keys = append ( keys , key )
}
slices . SortFunc ( keys , common . Address . Cmp )
for i , key := range keys {
fmt . Fprintf ( out , " %d. %v\n" , i , key )
}
return out . String ( )
}
have := getKeys ( state . journal . dirties )
want := getKeys ( checkstate . journal . dirties )
return fmt . Errorf ( "dirty-journal set mismatch.\nhave:\n%v\nwant:\n%v\n" , have , want )
}
return nil
return nil
}
}