diff --git a/core/vm/instructions.go b/core/vm/instructions.go index 1e91ff255..4eda3bf53 100644 --- a/core/vm/instructions.go +++ b/core/vm/instructions.go @@ -17,6 +17,8 @@ package vm import ( + "sync/atomic" + "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/core/types" "github.com/ethereum/go-ethereum/params" @@ -525,6 +527,9 @@ func opSstore(pc *uint64, interpreter *EVMInterpreter, scope *ScopeContext) ([]b } func opJump(pc *uint64, interpreter *EVMInterpreter, scope *ScopeContext) ([]byte, error) { + if atomic.LoadInt32(&interpreter.evm.abort) != 0 { + return nil, errStopToken + } pos := scope.Stack.pop() if !scope.Contract.validJumpdest(&pos) { return nil, ErrInvalidJump @@ -534,6 +539,9 @@ func opJump(pc *uint64, interpreter *EVMInterpreter, scope *ScopeContext) ([]byt } func opJumpi(pc *uint64, interpreter *EVMInterpreter, scope *ScopeContext) ([]byte, error) { + if atomic.LoadInt32(&interpreter.evm.abort) != 0 { + return nil, errStopToken + } pos, cond := scope.Stack.pop(), scope.Stack.pop() if !cond.IsZero() { if !scope.Contract.validJumpdest(&pos) { diff --git a/core/vm/interpreter.go b/core/vm/interpreter.go index a4c54b1fb..673f046a4 100644 --- a/core/vm/interpreter.go +++ b/core/vm/interpreter.go @@ -18,7 +18,6 @@ package vm import ( "hash" - "sync/atomic" "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/common/math" @@ -178,12 +177,7 @@ func (in *EVMInterpreter) Run(contract *Contract, input []byte, readOnly bool) ( // explicit STOP, RETURN or SELFDESTRUCT is executed, an error occurred during // the execution of one of the operations or until the done flag is set by the // parent context. - steps := 0 for { - steps++ - if steps%1000 == 0 && atomic.LoadInt32(&in.evm.abort) != 0 { - break - } if in.cfg.Debug { // Capture pre-execution values for tracing. logged, pcCopy, gasCopy = false, pc, contract.Gas diff --git a/core/vm/interpreter_test.go b/core/vm/interpreter_test.go new file mode 100644 index 000000000..dfae0f2e2 --- /dev/null +++ b/core/vm/interpreter_test.go @@ -0,0 +1,77 @@ +// Copyright 2021 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package vm + +import ( + "math/big" + "testing" + "time" + + "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/common/math" + "github.com/ethereum/go-ethereum/core/rawdb" + "github.com/ethereum/go-ethereum/core/state" + "github.com/ethereum/go-ethereum/params" +) + +var loopInterruptTests = []string{ + // infinite loop using JUMP: push(2) jumpdest dup1 jump + "60025b8056", + // infinite loop using JUMPI: push(1) push(4) jumpdest dup2 dup2 jumpi + "600160045b818157", +} + +func TestLoopInterrupt(t *testing.T) { + address := common.BytesToAddress([]byte("contract")) + vmctx := BlockContext{ + Transfer: func(StateDB, common.Address, common.Address, *big.Int) {}, + } + + for i, tt := range loopInterruptTests { + statedb, _ := state.New(common.Hash{}, state.NewDatabase(rawdb.NewMemoryDatabase()), nil) + statedb.CreateAccount(address) + statedb.SetCode(address, common.Hex2Bytes(tt)) + statedb.Finalise(true) + + evm := NewEVM(vmctx, TxContext{}, statedb, params.AllEthashProtocolChanges, Config{}) + + errChannel := make(chan error) + timeout := make(chan bool) + + go func(evm *EVM) { + _, _, err := evm.Call(AccountRef(common.Address{}), address, nil, math.MaxUint64, new(big.Int)) + errChannel <- err + }(evm) + + go func() { + <-time.After(time.Second) + timeout <- true + }() + + evm.Cancel() + + select { + case <-timeout: + t.Errorf("test %d timed out", i) + case err := <-errChannel: + if err != nil { + t.Errorf("test %d failure: %v", i, err) + } + } + } + +}