mirror of https://github.com/ethereum/go-ethereum
parent
a19d2c2278
commit
f7417d3552
@ -0,0 +1,59 @@ |
||||
package ptrie |
||||
|
||||
type FullNode struct { |
||||
trie *Trie |
||||
nodes [17]Node |
||||
} |
||||
|
||||
func NewFullNode(t *Trie) *FullNode { |
||||
return &FullNode{trie: t} |
||||
} |
||||
|
||||
func (self *FullNode) Dirty() bool { return true } |
||||
func (self *FullNode) Value() Node { |
||||
self.nodes[16] = self.trie.trans(self.nodes[16]) |
||||
return self.nodes[16] |
||||
} |
||||
|
||||
func (self *FullNode) Copy() Node { return self } |
||||
|
||||
// Returns the length of non-nil nodes
|
||||
func (self *FullNode) Len() (amount int) { |
||||
for _, node := range self.nodes { |
||||
if node != nil { |
||||
amount++ |
||||
} |
||||
} |
||||
|
||||
return |
||||
} |
||||
|
||||
func (self *FullNode) Hash() interface{} { |
||||
return self.trie.store(self) |
||||
} |
||||
|
||||
func (self *FullNode) RlpData() interface{} { |
||||
t := make([]interface{}, 17) |
||||
for i, node := range self.nodes { |
||||
if node != nil { |
||||
t[i] = node.Hash() |
||||
} else { |
||||
t[i] = "" |
||||
} |
||||
} |
||||
|
||||
return t |
||||
} |
||||
|
||||
func (self *FullNode) set(k byte, value Node) { |
||||
self.nodes[int(k)] = value |
||||
} |
||||
|
||||
func (self *FullNode) get(i byte) Node { |
||||
if self.nodes[int(i)] != nil { |
||||
self.nodes[int(i)] = self.trie.trans(self.nodes[int(i)]) |
||||
|
||||
return self.nodes[int(i)] |
||||
} |
||||
return nil |
||||
} |
@ -0,0 +1,22 @@ |
||||
package ptrie |
||||
|
||||
type HashNode struct { |
||||
key []byte |
||||
} |
||||
|
||||
func NewHash(key []byte) *HashNode { |
||||
return &HashNode{key} |
||||
} |
||||
|
||||
func (self *HashNode) RlpData() interface{} { |
||||
return self.key |
||||
} |
||||
|
||||
func (self *HashNode) Hash() interface{} { |
||||
return self.key |
||||
} |
||||
|
||||
// These methods will never be called but we have to satisfy Node interface
|
||||
func (self *HashNode) Value() Node { return nil } |
||||
func (self *HashNode) Dirty() bool { return true } |
||||
func (self *HashNode) Copy() Node { return self } |
@ -0,0 +1,40 @@ |
||||
package ptrie |
||||
|
||||
import "fmt" |
||||
|
||||
var indices = []string{"0", "1", "2", "3", "4", "5", "6", "7", "8", "9", "a", "b", "c", "d", "e", "f", "[17]"} |
||||
|
||||
type Node interface { |
||||
Value() Node |
||||
Copy() Node // All nodes, for now, return them self
|
||||
Dirty() bool |
||||
fstring(string) string |
||||
Hash() interface{} |
||||
RlpData() interface{} |
||||
} |
||||
|
||||
// Value node
|
||||
func (self *ValueNode) String() string { return self.fstring("") } |
||||
func (self *FullNode) String() string { return self.fstring("") } |
||||
func (self *ShortNode) String() string { return self.fstring("") } |
||||
func (self *ValueNode) fstring(ind string) string { return fmt.Sprintf("%s ", self.data) } |
||||
func (self *HashNode) fstring(ind string) string { return fmt.Sprintf("%x ", self.key) } |
||||
|
||||
// Full node
|
||||
func (self *FullNode) fstring(ind string) string { |
||||
resp := fmt.Sprintf("[\n%s ", ind) |
||||
for i, node := range self.nodes { |
||||
if node == nil { |
||||
resp += fmt.Sprintf("%s: <nil> ", indices[i]) |
||||
} else { |
||||
resp += fmt.Sprintf("%s: %v", indices[i], node.fstring(ind+" ")) |
||||
} |
||||
} |
||||
|
||||
return resp + fmt.Sprintf("\n%s] ", ind) |
||||
} |
||||
|
||||
// Short node
|
||||
func (self *ShortNode) fstring(ind string) string { |
||||
return fmt.Sprintf("[ %s: %v ] ", self.key, self.value.fstring(ind+" ")) |
||||
} |
@ -0,0 +1,31 @@ |
||||
package ptrie |
||||
|
||||
import "github.com/ethereum/go-ethereum/trie" |
||||
|
||||
type ShortNode struct { |
||||
trie *Trie |
||||
key []byte |
||||
value Node |
||||
} |
||||
|
||||
func NewShortNode(t *Trie, key []byte, value Node) *ShortNode { |
||||
return &ShortNode{t, []byte(trie.CompactEncode(key)), value} |
||||
} |
||||
func (self *ShortNode) Value() Node { |
||||
self.value = self.trie.trans(self.value) |
||||
|
||||
return self.value |
||||
} |
||||
func (self *ShortNode) Dirty() bool { return true } |
||||
func (self *ShortNode) Copy() Node { return self } |
||||
|
||||
func (self *ShortNode) RlpData() interface{} { |
||||
return []interface{}{self.key, self.value.Hash()} |
||||
} |
||||
func (self *ShortNode) Hash() interface{} { |
||||
return self.trie.store(self) |
||||
} |
||||
|
||||
func (self *ShortNode) Key() []byte { |
||||
return trie.CompactDecode(string(self.key)) |
||||
} |
@ -0,0 +1,286 @@ |
||||
package ptrie |
||||
|
||||
import ( |
||||
"bytes" |
||||
"sync" |
||||
|
||||
"github.com/ethereum/go-ethereum/crypto" |
||||
"github.com/ethereum/go-ethereum/ethutil" |
||||
"github.com/ethereum/go-ethereum/trie" |
||||
) |
||||
|
||||
type Backend interface { |
||||
Get([]byte) []byte |
||||
Set([]byte, []byte) |
||||
} |
||||
|
||||
type Cache map[string][]byte |
||||
|
||||
func (self Cache) Get(key []byte) []byte { |
||||
return self[string(key)] |
||||
} |
||||
func (self Cache) Set(key []byte, data []byte) { |
||||
self[string(key)] = data |
||||
} |
||||
|
||||
type Trie struct { |
||||
mu sync.Mutex |
||||
root Node |
||||
roothash []byte |
||||
backend Backend |
||||
} |
||||
|
||||
func NewEmpty() *Trie { |
||||
return &Trie{sync.Mutex{}, nil, nil, make(Cache)} |
||||
} |
||||
|
||||
func New(root []byte, backend Backend) *Trie { |
||||
trie := &Trie{} |
||||
trie.roothash = root |
||||
trie.backend = backend |
||||
|
||||
value := ethutil.NewValueFromBytes(trie.backend.Get(root)) |
||||
trie.root = trie.mknode(value) |
||||
|
||||
return trie |
||||
} |
||||
|
||||
func (self *Trie) Hash() []byte { |
||||
var hash []byte |
||||
if self.root != nil { |
||||
t := self.root.Hash() |
||||
if byts, ok := t.([]byte); ok { |
||||
hash = byts |
||||
} else { |
||||
hash = crypto.Sha3(ethutil.Encode(self.root.RlpData())) |
||||
} |
||||
} else { |
||||
hash = crypto.Sha3(ethutil.Encode(self.root)) |
||||
} |
||||
|
||||
self.roothash = hash |
||||
|
||||
return hash |
||||
} |
||||
|
||||
func (self *Trie) UpdateString(key, value string) Node { return self.Update([]byte(key), []byte(value)) } |
||||
func (self *Trie) Update(key, value []byte) Node { |
||||
self.mu.Lock() |
||||
defer self.mu.Unlock() |
||||
|
||||
k := trie.CompactHexDecode(string(key)) |
||||
|
||||
if len(value) != 0 { |
||||
self.root = self.insert(self.root, k, &ValueNode{self, value}) |
||||
} else { |
||||
self.root = self.delete(self.root, k) |
||||
} |
||||
|
||||
return self.root |
||||
} |
||||
|
||||
func (self *Trie) GetString(key string) []byte { return self.Get([]byte(key)) } |
||||
func (self *Trie) Get(key []byte) []byte { |
||||
self.mu.Lock() |
||||
defer self.mu.Unlock() |
||||
|
||||
k := trie.CompactHexDecode(string(key)) |
||||
|
||||
n := self.get(self.root, k) |
||||
if n != nil { |
||||
return n.(*ValueNode).Val() |
||||
} |
||||
|
||||
return nil |
||||
} |
||||
|
||||
func (self *Trie) DeleteString(key string) Node { return self.Delete([]byte(key)) } |
||||
func (self *Trie) Delete(key []byte) Node { |
||||
self.mu.Lock() |
||||
defer self.mu.Unlock() |
||||
|
||||
k := trie.CompactHexDecode(string(key)) |
||||
self.root = self.delete(self.root, k) |
||||
|
||||
return self.root |
||||
} |
||||
|
||||
func (self *Trie) insert(node Node, key []byte, value Node) Node { |
||||
if len(key) == 0 { |
||||
return value |
||||
} |
||||
|
||||
if node == nil { |
||||
return NewShortNode(self, key, value) |
||||
} |
||||
|
||||
switch node := node.(type) { |
||||
case *ShortNode: |
||||
k := node.Key() |
||||
cnode := node.Value() |
||||
if bytes.Equal(k, key) { |
||||
return NewShortNode(self, key, value) |
||||
} |
||||
|
||||
var n Node |
||||
matchlength := trie.MatchingNibbleLength(key, k) |
||||
if matchlength == len(k) { |
||||
n = self.insert(cnode, key[matchlength:], value) |
||||
} else { |
||||
pnode := self.insert(nil, k[matchlength+1:], cnode) |
||||
nnode := self.insert(nil, key[matchlength+1:], value) |
||||
fulln := NewFullNode(self) |
||||
fulln.set(k[matchlength], pnode) |
||||
fulln.set(key[matchlength], nnode) |
||||
n = fulln |
||||
} |
||||
if matchlength == 0 { |
||||
return n |
||||
} |
||||
|
||||
return NewShortNode(self, key[:matchlength], n) |
||||
|
||||
case *FullNode: |
||||
cpy := node.Copy().(*FullNode) |
||||
cpy.set(key[0], self.insert(node.get(key[0]), key[1:], value)) |
||||
|
||||
return cpy |
||||
|
||||
default: |
||||
panic("Invalid node") |
||||
} |
||||
} |
||||
|
||||
func (self *Trie) get(node Node, key []byte) Node { |
||||
if len(key) == 0 { |
||||
return node |
||||
} |
||||
|
||||
if node == nil { |
||||
return nil |
||||
} |
||||
|
||||
switch node := node.(type) { |
||||
case *ShortNode: |
||||
k := node.Key() |
||||
cnode := node.Value() |
||||
|
||||
if len(key) >= len(k) && bytes.Equal(k, key[:len(k)]) { |
||||
return self.get(cnode, key[len(k):]) |
||||
} |
||||
|
||||
return nil |
||||
case *FullNode: |
||||
return self.get(node.get(key[0]), key[1:]) |
||||
default: |
||||
panic("Invalid node") |
||||
} |
||||
} |
||||
|
||||
func (self *Trie) delete(node Node, key []byte) Node { |
||||
if len(key) == 0 { |
||||
return nil |
||||
} |
||||
|
||||
switch node := node.(type) { |
||||
case *ShortNode: |
||||
k := node.Key() |
||||
cnode := node.Value() |
||||
if bytes.Equal(key, k) { |
||||
return nil |
||||
} else if bytes.Equal(key[:len(k)], k) { |
||||
child := self.delete(cnode, key[len(k):]) |
||||
|
||||
var n Node |
||||
switch child := child.(type) { |
||||
case *ShortNode: |
||||
nkey := append(k, child.Key()...) |
||||
n = NewShortNode(self, nkey, child.Value()) |
||||
case *FullNode: |
||||
n = NewShortNode(self, node.key, child) |
||||
} |
||||
|
||||
return n |
||||
} else { |
||||
return node |
||||
} |
||||
|
||||
case *FullNode: |
||||
n := node.Copy().(*FullNode) |
||||
n.set(key[0], self.delete(n.get(key[0]), key[1:])) |
||||
|
||||
pos := -1 |
||||
for i := 0; i < 17; i++ { |
||||
if n.get(byte(i)) != nil { |
||||
if pos == -1 { |
||||
pos = i |
||||
} else { |
||||
pos = -2 |
||||
} |
||||
} |
||||
} |
||||
|
||||
var nnode Node |
||||
if pos == 16 { |
||||
nnode = NewShortNode(self, []byte{16}, n.get(byte(pos))) |
||||
} else if pos >= 0 { |
||||
cnode := n.get(byte(pos)) |
||||
switch cnode := cnode.(type) { |
||||
case *ShortNode: |
||||
// Stitch keys
|
||||
k := append([]byte{byte(pos)}, cnode.Key()...) |
||||
nnode = NewShortNode(self, k, cnode.Value()) |
||||
case *FullNode: |
||||
nnode = NewShortNode(self, []byte{byte(pos)}, n.get(byte(pos))) |
||||
} |
||||
} else { |
||||
nnode = n |
||||
} |
||||
|
||||
return nnode |
||||
|
||||
default: |
||||
panic("Invalid node") |
||||
} |
||||
} |
||||
|
||||
// casting functions and cache storing
|
||||
func (self *Trie) mknode(value *ethutil.Value) Node { |
||||
l := value.Len() |
||||
switch l { |
||||
case 2: |
||||
return NewShortNode(self, trie.CompactDecode(string(value.Get(0).Bytes())), self.mknode(value.Get(1))) |
||||
case 17: |
||||
fnode := NewFullNode(self) |
||||
for i := 0; i < l; i++ { |
||||
fnode.set(byte(i), self.mknode(value.Get(i))) |
||||
} |
||||
return fnode |
||||
case 32: |
||||
return &HashNode{value.Bytes()} |
||||
default: |
||||
return &ValueNode{self, value.Bytes()} |
||||
} |
||||
} |
||||
|
||||
func (self *Trie) trans(node Node) Node { |
||||
switch node := node.(type) { |
||||
case *HashNode: |
||||
value := ethutil.NewValueFromBytes(self.backend.Get(node.key)) |
||||
return self.mknode(value) |
||||
default: |
||||
return node |
||||
} |
||||
} |
||||
|
||||
func (self *Trie) store(node Node) interface{} { |
||||
data := ethutil.Encode(node) |
||||
if len(data) >= 32 { |
||||
key := crypto.Sha3(data) |
||||
self.backend.Set(key, data) |
||||
|
||||
return key |
||||
} |
||||
|
||||
return node.RlpData() |
||||
} |
@ -0,0 +1,138 @@ |
||||
package ptrie |
||||
|
||||
import ( |
||||
"bytes" |
||||
"fmt" |
||||
"testing" |
||||
|
||||
"github.com/ethereum/go-ethereum/ethutil" |
||||
) |
||||
|
||||
func TestInsert(t *testing.T) { |
||||
trie := NewEmpty() |
||||
|
||||
trie.UpdateString("doe", "reindeer") |
||||
trie.UpdateString("dog", "puppy") |
||||
trie.UpdateString("dogglesworth", "cat") |
||||
|
||||
exp := ethutil.Hex2Bytes("8aad789dff2f538bca5d8ea56e8abe10f4c7ba3a5dea95fea4cd6e7c3a1168d3") |
||||
root := trie.Hash() |
||||
if !bytes.Equal(root, exp) { |
||||
t.Errorf("exp %x got %x", exp, root) |
||||
} |
||||
|
||||
trie = NewEmpty() |
||||
trie.UpdateString("A", "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa") |
||||
|
||||
exp = ethutil.Hex2Bytes("d23786fb4a010da3ce639d66d5e904a11dbc02746d1ce25029e53290cabf28ab") |
||||
root = trie.Hash() |
||||
if !bytes.Equal(root, exp) { |
||||
t.Errorf("exp %x got %x", exp, root) |
||||
} |
||||
} |
||||
|
||||
func TestGet(t *testing.T) { |
||||
trie := NewEmpty() |
||||
|
||||
trie.UpdateString("doe", "reindeer") |
||||
trie.UpdateString("dog", "puppy") |
||||
trie.UpdateString("dogglesworth", "cat") |
||||
|
||||
res := trie.GetString("dog") |
||||
if !bytes.Equal(res, []byte("puppy")) { |
||||
t.Errorf("expected puppy got %x", res) |
||||
} |
||||
|
||||
unknown := trie.GetString("unknown") |
||||
if unknown != nil { |
||||
t.Errorf("expected nil got %x", unknown) |
||||
} |
||||
} |
||||
|
||||
func TestDelete(t *testing.T) { |
||||
trie := NewEmpty() |
||||
|
||||
vals := []struct{ k, v string }{ |
||||
{"do", "verb"}, |
||||
{"ether", "wookiedoo"}, |
||||
{"horse", "stallion"}, |
||||
{"shaman", "horse"}, |
||||
{"doge", "coin"}, |
||||
{"ether", ""}, |
||||
{"dog", "puppy"}, |
||||
{"shaman", ""}, |
||||
} |
||||
for _, val := range vals { |
||||
trie.UpdateString(val.k, val.v) |
||||
} |
||||
|
||||
hash := trie.Hash() |
||||
exp := ethutil.Hex2Bytes("5991bb8c6514148a29db676a14ac506cd2cd5775ace63c30a4fe457715e9ac84") |
||||
if !bytes.Equal(hash, exp) { |
||||
t.Errorf("expected %x got %x", exp, hash) |
||||
} |
||||
} |
||||
|
||||
func TestReplication(t *testing.T) { |
||||
trie := NewEmpty() |
||||
vals := []struct{ k, v string }{ |
||||
{"do", "verb"}, |
||||
{"ether", "wookiedoo"}, |
||||
{"horse", "stallion"}, |
||||
{"shaman", "horse"}, |
||||
{"doge", "coin"}, |
||||
{"ether", ""}, |
||||
{"dog", "puppy"}, |
||||
{"shaman", ""}, |
||||
{"somethingveryoddindeedthis is", "myothernodedata"}, |
||||
} |
||||
for _, val := range vals { |
||||
trie.UpdateString(val.k, val.v) |
||||
} |
||||
trie.Hash() |
||||
|
||||
trie2 := New(trie.roothash, trie.backend) |
||||
if string(trie2.GetString("horse")) != "stallion" { |
||||
t.Error("expected to have harse => stallion") |
||||
} |
||||
|
||||
hash := trie2.Hash() |
||||
exp := trie.Hash() |
||||
if !bytes.Equal(hash, exp) { |
||||
t.Errorf("root failure. expected %x got %x", exp, hash) |
||||
} |
||||
|
||||
} |
||||
|
||||
func BenchmarkGets(b *testing.B) { |
||||
trie := NewEmpty() |
||||
vals := []struct{ k, v string }{ |
||||
{"do", "verb"}, |
||||
{"ether", "wookiedoo"}, |
||||
{"horse", "stallion"}, |
||||
{"shaman", "horse"}, |
||||
{"doge", "coin"}, |
||||
{"ether", ""}, |
||||
{"dog", "puppy"}, |
||||
{"shaman", ""}, |
||||
{"somethingveryoddindeedthis is", "myothernodedata"}, |
||||
} |
||||
for _, val := range vals { |
||||
trie.UpdateString(val.k, val.v) |
||||
} |
||||
|
||||
b.ResetTimer() |
||||
for i := 0; i < b.N; i++ { |
||||
trie.Get([]byte("horse")) |
||||
} |
||||
} |
||||
|
||||
func BenchmarkUpdate(b *testing.B) { |
||||
trie := NewEmpty() |
||||
|
||||
b.ResetTimer() |
||||
for i := 0; i < b.N; i++ { |
||||
trie.UpdateString(fmt.Sprintf("aaaaaaaaaaaaaaa%d", j), "value") |
||||
} |
||||
trie.Hash() |
||||
} |
@ -0,0 +1,13 @@ |
||||
package ptrie |
||||
|
||||
type ValueNode struct { |
||||
trie *Trie |
||||
data []byte |
||||
} |
||||
|
||||
func (self *ValueNode) Value() Node { return self } // Best not to call :-)
|
||||
func (self *ValueNode) Val() []byte { return self.data } |
||||
func (self *ValueNode) Dirty() bool { return true } |
||||
func (self *ValueNode) Copy() Node { return self } |
||||
func (self *ValueNode) RlpData() interface{} { return self.data } |
||||
func (self *ValueNode) Hash() interface{} { return self.data } |
Loading…
Reference in new issue