@ -17,11 +17,19 @@
package state
import (
"bytes"
"encoding/binary"
"fmt"
"math"
"math/big"
"math/rand"
"reflect"
"strings"
"testing"
"testing/quick"
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/crypto "
"github.com/ethereum/go-ethereum/core/vm "
"github.com/ethereum/go-ethereum/ethdb"
)
@ -34,16 +42,16 @@ func TestUpdateLeaks(t *testing.T) {
// Update it with some accounts
for i := byte ( 0 ) ; i < 255 ; i ++ {
obj := state . GetOrNewStateObject ( common . BytesToAddress ( [ ] byte { i } ) )
obj . AddBalance ( big . NewInt ( int64 ( 11 * i ) ) )
obj . SetNonce ( uint64 ( 42 * i ) )
addr := common . BytesToAddress ( [ ] byte { i } )
state . AddBalance ( addr , big . NewInt ( int64 ( 11 * i ) ) )
state . SetNonce ( addr , uint64 ( 42 * i ) )
if i % 2 == 0 {
obj . SetState ( common . BytesToHash ( [ ] byte { i , i , i } ) , common . BytesToHash ( [ ] byte { i , i , i , i } ) )
state . SetState ( addr , common . BytesToHash ( [ ] byte { i , i , i } ) , common . BytesToHash ( [ ] byte { i , i , i , i } ) )
}
if i % 3 == 0 {
obj . SetCode ( crypto . Keccak256Hash ( [ ] byte { i , i , i , i , i } ) , [ ] byte { i , i , i , i , i } )
state . SetCode ( addr , [ ] byte { i , i , i , i , i } )
}
state . UpdateStateObject ( obj )
state . IntermediateRoot ( )
}
// Ensure that no data was leaked into the database
for _ , key := range db . Keys ( ) {
@ -61,51 +69,38 @@ func TestIntermediateLeaks(t *testing.T) {
transState , _ := New ( common . Hash { } , transDb )
finalState , _ := New ( common . Hash { } , finalDb )
// Update the states with some objects
for i := byte ( 0 ) ; i < 255 ; i ++ {
// Create a new state object with some data into the transition database
obj := transState . GetOrNewStateObject ( common . BytesToAddress ( [ ] byte { i } ) )
obj . SetBalance ( big . NewInt ( int64 ( 11 * i ) ) )
obj . SetNonce ( uint64 ( 42 * i ) )
modify := func ( state * StateDB , addr common . Address , i , tweak byte ) {
state . SetBalance ( addr , big . NewInt ( int64 ( 11 * i ) + int64 ( tweak ) ) )
state . SetNonce ( addr , uint64 ( 42 * i + tweak ) )
if i % 2 == 0 {
obj . SetState ( common . BytesToHash ( [ ] byte { i , i , i , 0 } ) , common . BytesToHash ( [ ] byte { i , i , i , i , 0 } ) )
state . SetState ( addr , common . Hash { i , i , i , 0 } , common . Hash { } )
state . SetState ( addr , common . Hash { i , i , i , tweak } , common . Hash { i , i , i , i , tweak } )
}
if i % 3 == 0 {
obj . SetCode ( crypto . Keccak256Hash ( [ ] byte { i , i , i , i , i , 0 } ) , [ ] byte { i , i , i , i , i , 0 } )
state . SetCode ( addr , [ ] byte { i , i , i , i , i , tweak } )
}
transState . UpdateStateObject ( obj )
}
// Overwrite all the data with new values in the transition database
obj . SetBalance ( big . NewInt ( int64 ( 11 * i + 1 ) ) )
obj . SetNonce ( uint64 ( 42 * i + 1 ) )
if i % 2 == 0 {
obj . SetState ( common . BytesToHash ( [ ] byte { i , i , i , 0 } ) , common . Hash { } )
obj . SetState ( common . BytesToHash ( [ ] byte { i , i , i , 1 } ) , common . BytesToHash ( [ ] byte { i , i , i , i , 1 } ) )
}
if i % 3 == 0 {
obj . SetCode ( crypto . Keccak256Hash ( [ ] byte { i , i , i , i , i , 1 } ) , [ ] byte { i , i , i , i , i , 1 } )
}
transState . UpdateStateObject ( obj )
// Modify the transient state.
for i := byte ( 0 ) ; i < 255 ; i ++ {
modify ( transState , common . Address { byte ( i ) } , i , 0 )
}
// Write modifications to trie.
transState . IntermediateRoot ( )
// Create the final state object directly in the final database
obj = finalState . GetOrNewStateObject ( common . BytesToAddress ( [ ] byte { i } ) )
obj . SetBalance ( big . NewInt ( int64 ( 11 * i + 1 ) ) )
obj . SetNonce ( uint64 ( 42 * i + 1 ) )
if i % 2 == 0 {
obj . SetState ( common . BytesToHash ( [ ] byte { i , i , i , 1 } ) , common . BytesToHash ( [ ] byte { i , i , i , i , 1 } ) )
}
if i % 3 == 0 {
obj . SetCode ( crypto . Keccak256Hash ( [ ] byte { i , i , i , i , i , 1 } ) , [ ] byte { i , i , i , i , i , 1 } )
}
finalState . UpdateStateObject ( obj )
// Overwrite all the data with new values in the transient database.
for i := byte ( 0 ) ; i < 255 ; i ++ {
modify ( transState , common . Address { byte ( i ) } , i , 99 )
modify ( finalState , common . Address { byte ( i ) } , i , 99 )
}
// Commit and cross check the databases.
if _ , err := transState . Commit ( ) ; err != nil {
t . Fatalf ( "failed to commit transition state: %v" , err )
}
if _ , err := finalState . Commit ( ) ; err != nil {
t . Fatalf ( "failed to commit final state: %v" , err )
}
// Cross check the databases to ensure they are the same
for _ , key := range finalDb . Keys ( ) {
if _ , err := transDb . Get ( key ) ; err != nil {
val , _ := finalDb . Get ( key )
@ -119,3 +114,243 @@ func TestIntermediateLeaks(t *testing.T) {
}
}
}
func TestSnapshotRandom ( t * testing . T ) {
config := & quick . Config { MaxCount : 1000 }
err := quick . Check ( ( * snapshotTest ) . run , config )
if cerr , ok := err . ( * quick . CheckError ) ; ok {
test := cerr . In [ 0 ] . ( * snapshotTest )
t . Errorf ( "%v:\n%s" , test . err , test )
} else if err != nil {
t . Error ( err )
}
}
// A snapshotTest checks that reverting StateDB snapshots properly undoes all changes
// captured by the snapshot. Instances of this test with pseudorandom content are created
// by Generate.
//
// The test works as follows:
//
// A new state is created and all actions are applied to it. Several snapshots are taken
// in between actions. The test then reverts each snapshot. For each snapshot the actions
// leading up to it are replayed on a fresh, empty state. The behaviour of all public
// accessor methods on the reverted state must match the return value of the equivalent
// methods on the replayed state.
type snapshotTest struct {
addrs [ ] common . Address // all account addresses
actions [ ] testAction // modifications to the state
snapshots [ ] int // actions indexes at which snapshot is taken
err error // failure details are reported through this field
}
type testAction struct {
name string
fn func ( testAction , * StateDB )
args [ ] int64
noAddr bool
}
// newTestAction creates a random action that changes state.
func newTestAction ( addr common . Address , r * rand . Rand ) testAction {
actions := [ ] testAction {
{
name : "SetBalance" ,
fn : func ( a testAction , s * StateDB ) {
s . SetBalance ( addr , big . NewInt ( a . args [ 0 ] ) )
} ,
args : make ( [ ] int64 , 1 ) ,
} ,
{
name : "AddBalance" ,
fn : func ( a testAction , s * StateDB ) {
s . AddBalance ( addr , big . NewInt ( a . args [ 0 ] ) )
} ,
args : make ( [ ] int64 , 1 ) ,
} ,
{
name : "SetNonce" ,
fn : func ( a testAction , s * StateDB ) {
s . SetNonce ( addr , uint64 ( a . args [ 0 ] ) )
} ,
args : make ( [ ] int64 , 1 ) ,
} ,
{
name : "SetState" ,
fn : func ( a testAction , s * StateDB ) {
var key , val common . Hash
binary . BigEndian . PutUint16 ( key [ : ] , uint16 ( a . args [ 0 ] ) )
binary . BigEndian . PutUint16 ( val [ : ] , uint16 ( a . args [ 1 ] ) )
s . SetState ( addr , key , val )
} ,
args : make ( [ ] int64 , 2 ) ,
} ,
{
name : "SetCode" ,
fn : func ( a testAction , s * StateDB ) {
code := make ( [ ] byte , 16 )
binary . BigEndian . PutUint64 ( code , uint64 ( a . args [ 0 ] ) )
binary . BigEndian . PutUint64 ( code [ 8 : ] , uint64 ( a . args [ 1 ] ) )
s . SetCode ( addr , code )
} ,
args : make ( [ ] int64 , 2 ) ,
} ,
{
name : "CreateAccount" ,
fn : func ( a testAction , s * StateDB ) {
s . CreateAccount ( addr )
} ,
} ,
{
name : "Delete" ,
fn : func ( a testAction , s * StateDB ) {
s . Delete ( addr )
} ,
} ,
{
name : "AddRefund" ,
fn : func ( a testAction , s * StateDB ) {
s . AddRefund ( big . NewInt ( a . args [ 0 ] ) )
} ,
args : make ( [ ] int64 , 1 ) ,
noAddr : true ,
} ,
{
name : "AddLog" ,
fn : func ( a testAction , s * StateDB ) {
data := make ( [ ] byte , 2 )
binary . BigEndian . PutUint16 ( data , uint16 ( a . args [ 0 ] ) )
s . AddLog ( & vm . Log { Address : addr , Data : data } )
} ,
args : make ( [ ] int64 , 1 ) ,
} ,
}
action := actions [ r . Intn ( len ( actions ) ) ]
var nameargs [ ] string
if ! action . noAddr {
nameargs = append ( nameargs , addr . Hex ( ) )
}
for _ , i := range action . args {
action . args [ i ] = rand . Int63n ( 100 )
nameargs = append ( nameargs , fmt . Sprint ( action . args [ i ] ) )
}
action . name += strings . Join ( nameargs , ", " )
return action
}
// Generate returns a new snapshot test of the given size. All randomness is
// derived from r.
func ( * snapshotTest ) Generate ( r * rand . Rand , size int ) reflect . Value {
// Generate random actions.
addrs := make ( [ ] common . Address , 50 )
for i := range addrs {
addrs [ i ] [ 0 ] = byte ( i )
}
actions := make ( [ ] testAction , size )
for i := range actions {
addr := addrs [ r . Intn ( len ( addrs ) ) ]
actions [ i ] = newTestAction ( addr , r )
}
// Generate snapshot indexes.
nsnapshots := int ( math . Sqrt ( float64 ( size ) ) )
if size > 0 && nsnapshots == 0 {
nsnapshots = 1
}
snapshots := make ( [ ] int , nsnapshots )
snaplen := len ( actions ) / nsnapshots
for i := range snapshots {
// Try to place the snapshots some number of actions apart from each other.
snapshots [ i ] = ( i * snaplen ) + r . Intn ( snaplen )
}
return reflect . ValueOf ( & snapshotTest { addrs , actions , snapshots , nil } )
}
func ( test * snapshotTest ) String ( ) string {
out := new ( bytes . Buffer )
sindex := 0
for i , action := range test . actions {
if len ( test . snapshots ) > sindex && i == test . snapshots [ sindex ] {
fmt . Fprintf ( out , "---- snapshot %d ----\n" , sindex )
sindex ++
}
fmt . Fprintf ( out , "%4d: %s\n" , i , action . name )
}
return out . String ( )
}
func ( test * snapshotTest ) run ( ) bool {
// Run all actions and create snapshots.
var (
db , _ = ethdb . NewMemDatabase ( )
state , _ = New ( common . Hash { } , db )
snapshotRevs = make ( [ ] int , len ( test . snapshots ) )
sindex = 0
)
for i , action := range test . actions {
if len ( test . snapshots ) > sindex && i == test . snapshots [ sindex ] {
snapshotRevs [ sindex ] = state . Snapshot ( )
sindex ++
}
action . fn ( action , state )
}
// Revert all snapshots in reverse order. Each revert must yield a state
// that is equivalent to fresh state with all actions up the snapshot applied.
for sindex -- ; sindex >= 0 ; sindex -- {
checkstate , _ := New ( common . Hash { } , db )
for _ , action := range test . actions [ : test . snapshots [ sindex ] ] {
action . fn ( action , checkstate )
}
state . RevertToSnapshot ( snapshotRevs [ sindex ] )
if err := test . checkEqual ( state , checkstate ) ; err != nil {
test . err = fmt . Errorf ( "state mismatch after revert to snapshot %d\n%v" , sindex , err )
return false
}
}
return true
}
// checkEqual checks that methods of state and checkstate return the same values.
func ( test * snapshotTest ) checkEqual ( state , checkstate * StateDB ) error {
for _ , addr := range test . addrs {
var err error
checkeq := func ( op string , a , b interface { } ) bool {
if err == nil && ! reflect . DeepEqual ( a , b ) {
err = fmt . Errorf ( "got %s(%s) == %v, want %v" , op , addr . Hex ( ) , a , b )
return false
}
return true
}
// Check basic accessor methods.
checkeq ( "Exist" , state . Exist ( addr ) , checkstate . Exist ( addr ) )
checkeq ( "IsDeleted" , state . IsDeleted ( addr ) , checkstate . IsDeleted ( addr ) )
checkeq ( "GetBalance" , state . GetBalance ( addr ) , checkstate . GetBalance ( addr ) )
checkeq ( "GetNonce" , state . GetNonce ( addr ) , checkstate . GetNonce ( addr ) )
checkeq ( "GetCode" , state . GetCode ( addr ) , checkstate . GetCode ( addr ) )
checkeq ( "GetCodeHash" , state . GetCodeHash ( addr ) , checkstate . GetCodeHash ( addr ) )
checkeq ( "GetCodeSize" , state . GetCodeSize ( addr ) , checkstate . GetCodeSize ( addr ) )
// Check storage.
if obj := state . GetStateObject ( addr ) ; obj != nil {
obj . ForEachStorage ( func ( key , val common . Hash ) bool {
return checkeq ( "GetState(" + key . Hex ( ) + ")" , val , checkstate . GetState ( addr , key ) )
} )
checkobj := checkstate . GetStateObject ( addr )
checkobj . ForEachStorage ( func ( key , checkval common . Hash ) bool {
return checkeq ( "GetState(" + key . Hex ( ) + ")" , state . GetState ( addr , key ) , checkval )
} )
}
if err != nil {
return err
}
}
if state . GetRefund ( ) . Cmp ( checkstate . GetRefund ( ) ) != 0 {
return fmt . Errorf ( "got GetRefund() == %d, want GetRefund() == %d" ,
state . GetRefund ( ) , checkstate . GetRefund ( ) )
}
if ! reflect . DeepEqual ( state . GetLogs ( common . Hash { } ) , checkstate . GetLogs ( common . Hash { } ) ) {
return fmt . Errorf ( "got GetLogs(common.Hash{}) == %v, want GetLogs(common.Hash{}) == %v" ,
state . GetLogs ( common . Hash { } ) , checkstate . GetLogs ( common . Hash { } ) )
}
return nil
}