diff --git a/common/types.go b/common/types.go index d05c21eecc..e41112a77a 100644 --- a/common/types.go +++ b/common/types.go @@ -62,6 +62,10 @@ func (h Hash) Generate(rand *rand.Rand, size int) reflect.Value { return reflect.ValueOf(h) } +func EmptyHash(h Hash) bool { + return h == Hash{} +} + /////////// Address func BytesToAddress(b []byte) Address { var a Address diff --git a/core/state/state_object.go b/core/state/state_object.go index 6d2455d792..42dac632bf 100644 --- a/core/state/state_object.go +++ b/core/state/state_object.go @@ -19,11 +19,11 @@ func (self Code) String() string { return string(self) //strings.Join(Disassemble(self), " ") } -type Storage map[string]*common.Value +type Storage map[string]common.Hash func (self Storage) String() (str string) { for key, value := range self { - str += fmt.Sprintf("%X : %X\n", key, value.Bytes()) + str += fmt.Sprintf("%X : %X\n", key, value) } return @@ -32,7 +32,6 @@ func (self Storage) String() (str string) { func (self Storage) Copy() Storage { cpy := make(Storage) for key, value := range self { - // XXX Do we need a 'value' copy or is this sufficient? cpy[key] = value } @@ -112,7 +111,7 @@ func NewStateObjectFromBytes(address common.Address, data []byte, db common.Data object.balance = extobject.Balance object.codeHash = extobject.CodeHash object.State = New(extobject.Root, db) - object.storage = make(map[string]*common.Value) + object.storage = make(map[string]common.Hash) object.gasPool = new(big.Int) object.prepaid = new(big.Int) object.code, _ = db.Get(extobject.CodeHash) @@ -129,35 +128,29 @@ func (self *StateObject) MarkForDeletion() { } } -func (c *StateObject) getAddr(addr common.Hash) *common.Value { - return common.NewValueFromBytes([]byte(c.State.trie.Get(addr[:]))) +func (c *StateObject) getAddr(addr common.Hash) (ret common.Hash) { + return common.BytesToHash(common.NewValueFromBytes([]byte(c.State.trie.Get(addr[:]))).Bytes()) } -func (c *StateObject) setAddr(addr []byte, value interface{}) { - c.State.trie.Update(addr, common.Encode(value)) -} - -func (self *StateObject) GetStorage(key *big.Int) *common.Value { - fmt.Printf("%v: get %v %v", self.address.Hex(), key) - return self.GetState(common.BytesToHash(key.Bytes())) -} - -func (self *StateObject) SetStorage(key *big.Int, value *common.Value) { - fmt.Printf("%v: set %v -> %v", self.address.Hex(), key, value) - self.SetState(common.BytesToHash(key.Bytes()), value) +func (c *StateObject) setAddr(addr []byte, value common.Hash) { + v, err := rlp.EncodeToBytes(bytes.TrimLeft(value[:], "\x00")) + if err != nil { + // if RLPing failed we better panic and not fail silently. This would be considered a consensus issue + panic(err) + } + c.State.trie.Update(addr, v) } func (self *StateObject) Storage() Storage { return self.storage } -func (self *StateObject) GetState(key common.Hash) *common.Value { +func (self *StateObject) GetState(key common.Hash) common.Hash { strkey := key.Str() - value := self.storage[strkey] - if value == nil { + value, exists := self.storage[strkey] + if !exists { value = self.getAddr(key) - - if !value.IsNil() { + if (value != common.Hash{}) { self.storage[strkey] = value } } @@ -165,14 +158,14 @@ func (self *StateObject) GetState(key common.Hash) *common.Value { return value } -func (self *StateObject) SetState(k common.Hash, value *common.Value) { - self.storage[k.Str()] = value.Copy() +func (self *StateObject) SetState(k, value common.Hash) { + self.storage[k.Str()] = value self.dirty = true } func (self *StateObject) Sync() { for key, value := range self.storage { - if value.Len() == 0 { + if (value == common.Hash{}) { self.State.trie.Delete([]byte(key)) continue } @@ -370,7 +363,7 @@ func (c *StateObject) RlpDecode(data []byte) { c.nonce = decoder.Get(0).Uint() c.balance = decoder.Get(1).BigInt() c.State = New(common.BytesToHash(decoder.Get(2).Bytes()), c.db) //New(trie.New(common.Config.Db, decoder.Get(2).Interface())) - c.storage = make(map[string]*common.Value) + c.storage = make(map[string]common.Hash) c.gasPool = new(big.Int) c.codeHash = decoder.Get(3).Bytes() diff --git a/core/state/statedb.go b/core/state/statedb.go index 895d9fe8ba..1c75ee4db9 100644 --- a/core/state/statedb.go +++ b/core/state/statedb.go @@ -103,13 +103,13 @@ func (self *StateDB) GetCode(addr common.Address) []byte { return nil } -func (self *StateDB) GetState(a common.Address, b common.Hash) []byte { +func (self *StateDB) GetState(a common.Address, b common.Hash) common.Hash { stateObject := self.GetStateObject(a) if stateObject != nil { - return stateObject.GetState(b).Bytes() + return stateObject.GetState(b) } - return nil + return common.Hash{} } func (self *StateDB) IsDeleted(addr common.Address) bool { @@ -145,10 +145,10 @@ func (self *StateDB) SetCode(addr common.Address, code []byte) { } } -func (self *StateDB) SetState(addr common.Address, key common.Hash, value interface{}) { +func (self *StateDB) SetState(addr common.Address, key common.Hash, value common.Hash) { stateObject := self.GetOrNewStateObject(addr) if stateObject != nil { - stateObject.SetState(key, common.NewValue(value)) + stateObject.SetState(key, value) } } diff --git a/core/vm/vm.go b/core/vm/vm.go index 0486fbbc7c..336f6cf958 100644 --- a/core/vm/vm.go +++ b/core/vm/vm.go @@ -506,14 +506,14 @@ func (self *Vm) Run(context *Context, input []byte) (ret []byte, err error) { case SLOAD: loc := common.BigToHash(stack.pop()) - val := common.Bytes2Big(statedb.GetState(context.Address(), loc)) + val := statedb.GetState(context.Address(), loc).Big() stack.push(val) case SSTORE: loc := common.BigToHash(stack.pop()) val := stack.pop() - statedb.SetState(context.Address(), loc, val) + statedb.SetState(context.Address(), loc, common.BigToHash(val)) case JUMP: if err := jump(pc, stack.pop()); err != nil { @@ -686,10 +686,10 @@ func (self *Vm) calculateGasAndSize(context *Context, caller ContextRef, op OpCo var g *big.Int y, x := stack.data[stack.len()-2], stack.data[stack.len()-1] val := statedb.GetState(context.Address(), common.BigToHash(x)) - if len(val) == 0 && len(y.Bytes()) > 0 { + if common.EmptyHash(val) && !common.EmptyHash(common.BigToHash(y)) { // 0 => non 0 g = params.SstoreSetGas - } else if len(val) > 0 && len(y.Bytes()) == 0 { + } else if !common.EmptyHash(val) && common.EmptyHash(common.BigToHash(y)) { statedb.Refund(params.SstoreRefundGas) g = params.SstoreClearGas @@ -697,6 +697,13 @@ func (self *Vm) calculateGasAndSize(context *Context, caller ContextRef, op OpCo // non 0 => non 0 (or 0 => 0) g = params.SstoreClearGas } + + /* + if len(val) == 0 && len(y.Bytes()) > 0 { + } else if len(val) > 0 && len(y.Bytes()) == 0 { + } else { + } + */ gas.Set(g) case SUICIDE: if !statedb.IsDeleted(context.Address()) { diff --git a/tests/vm/gh_test.go b/tests/vm/gh_test.go index a95d02576a..be9e89d9c0 100644 --- a/tests/vm/gh_test.go +++ b/tests/vm/gh_test.go @@ -97,7 +97,7 @@ func RunVmTest(p string, t *testing.T) { obj := StateObjectFromAccount(db, addr, account) statedb.SetStateObject(obj) for a, v := range account.Storage { - obj.SetState(common.HexToHash(a), common.NewValue(helper.FromHex(v))) + obj.SetState(common.HexToHash(a), common.HexToHash(v)) } } @@ -168,11 +168,11 @@ func RunVmTest(p string, t *testing.T) { } for addr, value := range account.Storage { - v := obj.GetState(common.HexToHash(addr)).Bytes() - vexp := helper.FromHex(value) + v := obj.GetState(common.HexToHash(addr)) + vexp := common.HexToHash(value) - if bytes.Compare(v, vexp) != 0 { - t.Errorf("%s's : (%x: %s) storage failed. Expected %x, got %x (%v %v)\n", name, obj.Address().Bytes()[0:4], addr, vexp, v, common.BigD(vexp), common.BigD(v)) + if v != vexp { + t.Errorf("%s's : (%x: %s) storage failed. Expected %x, got %x (%v %v)\n", name, obj.Address().Bytes()[0:4], addr, vexp, v, vexp.Big(), v.Big()) } } }