diff --git a/trie/iterator.go b/trie/iterator.go index ddc674d2b..42149a7d3 100644 --- a/trie/iterator.go +++ b/trie/iterator.go @@ -18,7 +18,7 @@ package trie import ( "bytes" - + "container/heap" "github.com/ethereum/go-ethereum/common" ) @@ -268,6 +268,26 @@ outer: return nil } +func compareNodes(a, b NodeIterator) int { + cmp := bytes.Compare(a.Path(), b.Path()) + if cmp != 0 { + return cmp + } + + if a.Leaf() && !b.Leaf() { + return -1 + } else if b.Leaf() && !a.Leaf() { + return 1 + } + + cmp = bytes.Compare(a.Hash().Bytes(), b.Hash().Bytes()) + if cmp != 0 { + return cmp + } + + return bytes.Compare(a.LeafBlob(), b.LeafBlob()) +} + type differenceIterator struct { a, b NodeIterator // Nodes returned are those in b - a. eof bool // Indicates a has run out of elements @@ -321,8 +341,7 @@ func (it *differenceIterator) Next(bool) bool { } for { - apath, bpath := it.a.Path(), it.b.Path() - switch bytes.Compare(apath, bpath) { + switch compareNodes(it.a, it.b) { case -1: // b jumped past a; advance a if !it.a.Next(true) { @@ -334,15 +353,6 @@ func (it *differenceIterator) Next(bool) bool { // b is before a return true case 0: - if it.a.Hash() != it.b.Hash() || it.a.Leaf() != it.b.Leaf() { - // Keys are identical, but hashes or leaf status differs - return true - } - if it.a.Leaf() && it.b.Leaf() && !bytes.Equal(it.a.LeafBlob(), it.b.LeafBlob()) { - // Both are leaf nodes, but with different values - return true - } - // a and b are identical; skip this whole subtree if the nodes have hashes hasHash := it.a.Hash() == common.Hash{} if !it.b.Next(hasHash) { @@ -364,3 +374,107 @@ func (it *differenceIterator) Error() error { } return it.b.Error() } + +type nodeIteratorHeap []NodeIterator + +func (h nodeIteratorHeap) Len() int { return len(h) } +func (h nodeIteratorHeap) Less(i, j int) bool { return compareNodes(h[i], h[j]) < 0 } +func (h nodeIteratorHeap) Swap(i, j int) { h[i], h[j] = h[j], h[i] } +func (h *nodeIteratorHeap) Push(x interface{}) { *h = append(*h, x.(NodeIterator)) } +func (h *nodeIteratorHeap) Pop() interface{} { + n := len(*h) + x := (*h)[n-1] + *h = (*h)[0 : n-1] + return x +} + +type unionIterator struct { + items *nodeIteratorHeap // Nodes returned are the union of the ones in these iterators + count int // Number of nodes scanned across all tries + err error // The error, if one has been encountered +} + +// NewUnionIterator constructs a NodeIterator that iterates over elements in the union +// of the provided NodeIterators. Returns the iterator, and a pointer to an integer +// recording the number of nodes visited. +func NewUnionIterator(iters []NodeIterator) (NodeIterator, *int) { + h := make(nodeIteratorHeap, len(iters)) + copy(h, iters) + heap.Init(&h) + + ui := &unionIterator{ + items: &h, + } + return ui, &ui.count +} + +func (it *unionIterator) Hash() common.Hash { + return (*it.items)[0].Hash() +} + +func (it *unionIterator) Parent() common.Hash { + return (*it.items)[0].Parent() +} + +func (it *unionIterator) Leaf() bool { + return (*it.items)[0].Leaf() +} + +func (it *unionIterator) LeafBlob() []byte { + return (*it.items)[0].LeafBlob() +} + +func (it *unionIterator) Path() []byte { + return (*it.items)[0].Path() +} + +// Next returns the next node in the union of tries being iterated over. +// +// It does this by maintaining a heap of iterators, sorted by the iteration +// order of their next elements, with one entry for each source trie. Each +// time Next() is called, it takes the least element from the heap to return, +// advancing any other iterators that also point to that same element. These +// iterators are called with descend=false, since we know that any nodes under +// these nodes will also be duplicates, found in the currently selected iterator. +// Whenever an iterator is advanced, it is pushed back into the heap if it still +// has elements remaining. +// +// In the case that descend=false - eg, we're asked to ignore all subnodes of the +// current node - we also advance any iterators in the heap that have the current +// path as a prefix. +func (it *unionIterator) Next(descend bool) bool { + if len(*it.items) == 0 { + return false + } + + // Get the next key from the union + least := heap.Pop(it.items).(NodeIterator) + + // Skip over other nodes as long as they're identical, or, if we're not descending, as + // long as they have the same prefix as the current node. + for len(*it.items) > 0 && ((!descend && bytes.HasPrefix((*it.items)[0].Path(), least.Path())) || compareNodes(least, (*it.items)[0]) == 0) { + skipped := heap.Pop(it.items).(NodeIterator) + // Skip the whole subtree if the nodes have hashes; otherwise just skip this node + if skipped.Next(skipped.Hash() == common.Hash{}) { + it.count += 1 + // If there are more elements, push the iterator back on the heap + heap.Push(it.items, skipped) + } + } + + if least.Next(descend) { + it.count += 1 + heap.Push(it.items, least) + } + + return len(*it.items) > 0 +} + +func (it *unionIterator) Error() error { + for i := 0; i < len(*it.items); i++ { + if err := (*it.items)[i].Error(); err != nil { + return err + } + } + return nil +} diff --git a/trie/iterator_test.go b/trie/iterator_test.go index 0ad9711ed..c101bb7b0 100644 --- a/trie/iterator_test.go +++ b/trie/iterator_test.go @@ -117,36 +117,38 @@ func TestNodeIteratorCoverage(t *testing.T) { } } +var testdata1 = []struct{ k, v string }{ + {"bar", "b"}, + {"barb", "ba"}, + {"bars", "bb"}, + {"bard", "bc"}, + {"fab", "z"}, + {"foo", "a"}, + {"food", "ab"}, + {"foos", "aa"}, +} + +var testdata2 = []struct{ k, v string }{ + {"aardvark", "c"}, + {"bar", "b"}, + {"barb", "bd"}, + {"bars", "be"}, + {"fab", "z"}, + {"foo", "a"}, + {"foos", "aa"}, + {"food", "ab"}, + {"jars", "d"}, +} + func TestDifferenceIterator(t *testing.T) { triea := newEmpty() - valsa := []struct{ k, v string }{ - {"bar", "b"}, - {"barb", "ba"}, - {"bars", "bb"}, - {"bard", "bc"}, - {"fab", "z"}, - {"foo", "a"}, - {"food", "ab"}, - {"foos", "aa"}, - } - for _, val := range valsa { + for _, val := range testdata1 { triea.Update([]byte(val.k), []byte(val.v)) } triea.Commit() trieb := newEmpty() - valsb := []struct{ k, v string }{ - {"aardvark", "c"}, - {"bar", "b"}, - {"barb", "bd"}, - {"bars", "be"}, - {"fab", "z"}, - {"foo", "a"}, - {"foos", "aa"}, - {"food", "ab"}, - {"jars", "d"}, - } - for _, val := range valsb { + for _, val := range testdata2 { trieb.Update([]byte(val.k), []byte(val.v)) } trieb.Commit() @@ -166,10 +168,57 @@ func TestDifferenceIterator(t *testing.T) { } for _, item := range all { if found[item.k] != item.v { - t.Errorf("iterator value mismatch for %s: got %q want %q", item.k, found[item.k], item.v) + t.Errorf("iterator value mismatch for %s: got %v want %v", item.k, found[item.k], item.v) } } if len(found) != len(all) { t.Errorf("iterator count mismatch: got %d values, want %d", len(found), len(all)) } } + +func TestUnionIterator(t *testing.T) { + triea := newEmpty() + for _, val := range testdata1 { + triea.Update([]byte(val.k), []byte(val.v)) + } + triea.Commit() + + trieb := newEmpty() + for _, val := range testdata2 { + trieb.Update([]byte(val.k), []byte(val.v)) + } + trieb.Commit() + + di, _ := NewUnionIterator([]NodeIterator{NewNodeIterator(triea), NewNodeIterator(trieb)}) + it := NewIteratorFromNodeIterator(di) + + all := []struct{ k, v string }{ + {"aardvark", "c"}, + {"barb", "bd"}, + {"barb", "ba"}, + {"bard", "bc"}, + {"bars", "bb"}, + {"bars", "be"}, + {"bar", "b"}, + {"fab", "z"}, + {"food", "ab"}, + {"foos", "aa"}, + {"foo", "a"}, + {"jars", "d"}, + } + + for i, kv := range all { + if !it.Next() { + t.Errorf("Iterator ends prematurely at element %d", i) + } + if kv.k != string(it.Key) { + t.Errorf("iterator value mismatch for element %d: got key %s want %s", i, it.Key, kv.k) + } + if kv.v != string(it.Value) { + t.Errorf("iterator value mismatch for element %d: got value %s want %s", i, it.Value, kv.v) + } + } + if it.Next() { + t.Errorf("Iterator returned extra values.") + } +}