diff --git a/core/rawdb/table.go b/core/rawdb/table.go index 4daa6b5349..323ef6293c 100644 --- a/core/rawdb/table.go +++ b/core/rawdb/table.go @@ -176,11 +176,6 @@ func (b *tableBatch) Delete(key []byte) error { return b.batch.Delete(append([]byte(b.prefix), key...)) } -// KeyCount retrieves the number of keys queued up for writing. -func (b *tableBatch) KeyCount() int { - return b.batch.KeyCount() -} - // ValueSize retrieves the amount of data queued up for writing. func (b *tableBatch) ValueSize() int { return b.batch.ValueSize() diff --git a/core/state/snapshot/generate.go b/core/state/snapshot/generate.go index 78fca45e44..8992d3f91b 100644 --- a/core/state/snapshot/generate.go +++ b/core/state/snapshot/generate.go @@ -368,7 +368,7 @@ func (dl *diskLayer) proveRange(stats *generatorStats, root common.Hash, prefix } // Verify the snapshot segment with range prover, ensure that all flat states // in this range correspond to merkle trie. - _, cont, err := trie.VerifyRangeProof(root, origin, last, keys, vals, proof) + cont, err := trie.VerifyRangeProof(root, origin, last, keys, vals, proof) return &proofResult{ keys: keys, vals: vals, diff --git a/eth/protocols/snap/sync.go b/eth/protocols/snap/sync.go index 287ac8d727..d9c0cb9b1b 100644 --- a/eth/protocols/snap/sync.go +++ b/eth/protocols/snap/sync.go @@ -202,9 +202,8 @@ type storageResponse struct { accounts []common.Hash // Account hashes requested, may be only partially filled roots []common.Hash // Storage roots requested, may be only partially filled - hashes [][]common.Hash // Storage slot hashes in the returned range - slots [][][]byte // Storage slot values in the returned range - nodes []ethdb.KeyValueStore // Database containing the reconstructed trie nodes + hashes [][]common.Hash // Storage slot hashes in the returned range + slots [][][]byte // Storage slot values in the returned range cont bool // Whether the last storage range has a continuation } @@ -680,12 +679,22 @@ func (s *Syncer) loadSyncStatus() { } s.tasks = progress.Tasks for _, task := range s.tasks { - task.genBatch = s.db.NewBatch() + task.genBatch = ethdb.HookedBatch{ + Batch: s.db.NewBatch(), + OnPut: func(key []byte, value []byte) { + s.accountBytes += common.StorageSize(len(key) + len(value)) + }, + } task.genTrie = trie.NewStackTrie(task.genBatch) for _, subtasks := range task.SubTasks { for _, subtask := range subtasks { - subtask.genBatch = s.db.NewBatch() + subtask.genBatch = ethdb.HookedBatch{ + Batch: s.db.NewBatch(), + OnPut: func(key []byte, value []byte) { + s.storageBytes += common.StorageSize(len(key) + len(value)) + }, + } subtask.genTrie = trie.NewStackTrie(task.genBatch) } } @@ -729,7 +738,12 @@ func (s *Syncer) loadSyncStatus() { // Make sure we don't overflow if the step is not a proper divisor last = common.HexToHash("0xffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff") } - batch := s.db.NewBatch() + batch := ethdb.HookedBatch{ + Batch: s.db.NewBatch(), + OnPut: func(key []byte, value []byte) { + s.accountBytes += common.StorageSize(len(key) + len(value)) + }, + } s.tasks = append(s.tasks, &accountTask{ Next: next, Last: last, @@ -746,19 +760,14 @@ func (s *Syncer) loadSyncStatus() { func (s *Syncer) saveSyncStatus() { // Serialize any partial progress to disk before spinning down for _, task := range s.tasks { - keys, bytes := task.genBatch.KeyCount(), task.genBatch.ValueSize() if err := task.genBatch.Write(); err != nil { log.Error("Failed to persist account slots", "err", err) } - s.accountBytes += common.StorageSize(keys*common.HashLength + bytes) - for _, subtasks := range task.SubTasks { for _, subtask := range subtasks { - keys, bytes := subtask.genBatch.KeyCount(), subtask.genBatch.ValueSize() if err := subtask.genBatch.Write(); err != nil { log.Error("Failed to persist storage slots", "err", err) } - s.accountBytes += common.StorageSize(keys*common.HashLength + bytes) } } } @@ -1763,12 +1772,15 @@ func (s *Syncer) processStorageResponse(res *storageResponse) { if res.subTask != nil { res.subTask.req = nil } - batch := s.db.NewBatch() - + batch := ethdb.HookedBatch{ + Batch: s.db.NewBatch(), + OnPut: func(key []byte, value []byte) { + s.storageBytes += common.StorageSize(len(key) + len(value)) + }, + } var ( - slots int - nodes int - bytes common.StorageSize + slots int + oldStorageBytes = s.storageBytes ) // Iterate over all the accounts and reconstruct their storage tries from the // delivered slots @@ -1829,7 +1841,12 @@ func (s *Syncer) processStorageResponse(res *storageResponse) { r := newHashRange(lastKey, chunks) // Our first task is the one that was just filled by this response. - batch := s.db.NewBatch() + batch := ethdb.HookedBatch{ + Batch: s.db.NewBatch(), + OnPut: func(key []byte, value []byte) { + s.storageBytes += common.StorageSize(len(key) + len(value)) + }, + } tasks = append(tasks, &storageTask{ Next: common.Hash{}, Last: r.End(), @@ -1838,7 +1855,12 @@ func (s *Syncer) processStorageResponse(res *storageResponse) { genTrie: trie.NewStackTrie(batch), }) for r.Next() { - batch := s.db.NewBatch() + batch := ethdb.HookedBatch{ + Batch: s.db.NewBatch(), + OnPut: func(key []byte, value []byte) { + s.storageBytes += common.StorageSize(len(key) + len(value)) + }, + } tasks = append(tasks, &storageTask{ Next: r.Start(), Last: r.End(), @@ -1883,27 +1905,23 @@ func (s *Syncer) processStorageResponse(res *storageResponse) { } } } - // Iterate over all the reconstructed trie nodes and push them to disk - // if the contract is fully delivered. If it's chunked, the trie nodes - // will be reconstructed later. + // Iterate over all the complete contracts, reconstruct the trie nodes and + // push them to disk. If the contract is chunked, the trie nodes will be + // reconstructed later. slots += len(res.hashes[i]) if i < len(res.hashes)-1 || res.subTask == nil { - it := res.nodes[i].NewIterator(nil, nil) - for it.Next() { - batch.Put(it.Key(), it.Value()) - - bytes += common.StorageSize(common.HashLength + len(it.Value())) - nodes++ + tr := trie.NewStackTrie(batch) + for j := 0; j < len(res.hashes[i]); j++ { + tr.Update(res.hashes[i][j][:], res.slots[i][j]) } - it.Release() + tr.Commit() } // Persist the received storage segements. These flat state maybe // outdated during the sync, but it can be fixed later during the // snapshot generation. for j := 0; j < len(res.hashes[i]); j++ { rawdb.WriteStorageSnapshot(batch, account, res.hashes[i][j], res.slots[i][j]) - bytes += common.StorageSize(1 + 2*common.HashLength + len(res.slots[i][j])) // If we're storing large contracts, generate the trie nodes // on the fly to not trash the gluing points @@ -1926,15 +1944,11 @@ func (s *Syncer) processStorageResponse(res *storageResponse) { } } } - if data := res.subTask.genBatch.ValueSize(); data > ethdb.IdealBatchSize || res.subTask.done { - keys := res.subTask.genBatch.KeyCount() + if res.subTask.genBatch.ValueSize() > ethdb.IdealBatchSize || res.subTask.done { if err := res.subTask.genBatch.Write(); err != nil { log.Error("Failed to persist stack slots", "err", err) } res.subTask.genBatch.Reset() - - bytes += common.StorageSize(keys*common.HashLength + data) - nodes += keys } } // Flush anything written just now and update the stats @@ -1942,9 +1956,8 @@ func (s *Syncer) processStorageResponse(res *storageResponse) { log.Crit("Failed to persist storage slots", "err", err) } s.storageSynced += uint64(slots) - s.storageBytes += bytes - log.Debug("Persisted set of storage slots", "accounts", len(res.hashes), "slots", slots, "nodes", nodes, "bytes", bytes) + log.Debug("Persisted set of storage slots", "accounts", len(res.hashes), "slots", slots, "bytes", s.storageBytes-oldStorageBytes) // If this delivery completed the last pending task, forward the account task // to the next chunk @@ -2042,18 +2055,20 @@ func (s *Syncer) forwardAccountTask(task *accountTask) { // Persist the received account segements. These flat state maybe // outdated during the sync, but it can be fixed later during the // snapshot generation. - var ( - nodes int - bytes common.StorageSize - ) - batch := s.db.NewBatch() + oldAccountBytes := s.accountBytes + + batch := ethdb.HookedBatch{ + Batch: s.db.NewBatch(), + OnPut: func(key []byte, value []byte) { + s.accountBytes += common.StorageSize(len(key) + len(value)) + }, + } for i, hash := range res.hashes { if task.needCode[i] || task.needState[i] { break } slim := snapshot.SlimAccountRLP(res.accounts[i].Nonce, res.accounts[i].Balance, res.accounts[i].Root, res.accounts[i].CodeHash) rawdb.WriteAccountSnapshot(batch, hash, slim) - bytes += common.StorageSize(1 + common.HashLength + len(slim)) // If the task is complete, drop it into the stack trie to generate // account trie nodes for it @@ -2069,7 +2084,6 @@ func (s *Syncer) forwardAccountTask(task *accountTask) { if err := batch.Write(); err != nil { log.Crit("Failed to persist accounts", "err", err) } - s.accountBytes += bytes s.accountSynced += uint64(len(res.accounts)) // Task filling persisted, push it the chunk marker forward to the first @@ -2091,17 +2105,13 @@ func (s *Syncer) forwardAccountTask(task *accountTask) { log.Error("Failed to commit stack account", "err", err) } } - if data := task.genBatch.ValueSize(); data > ethdb.IdealBatchSize || task.done { - keys := task.genBatch.KeyCount() + if task.genBatch.ValueSize() > ethdb.IdealBatchSize || task.done { if err := task.genBatch.Write(); err != nil { log.Error("Failed to persist stack account", "err", err) } task.genBatch.Reset() - - nodes += keys - bytes += common.StorageSize(keys*common.HashLength + data) } - log.Debug("Persisted range of accounts", "accounts", len(res.accounts), "nodes", nodes, "bytes", bytes) + log.Debug("Persisted range of accounts", "accounts", len(res.accounts), "bytes", s.accountBytes-oldAccountBytes) } // OnAccounts is a callback method to invoke when a range of accounts are @@ -2176,7 +2186,7 @@ func (s *Syncer) OnAccounts(peer SyncPeer, id uint64, hashes []common.Hash, acco if len(keys) > 0 { end = keys[len(keys)-1] } - _, cont, err := trie.VerifyRangeProof(root, req.origin[:], end, keys, accounts, proofdb) + cont, err := trie.VerifyRangeProof(root, req.origin[:], end, keys, accounts, proofdb) if err != nil { logger.Warn("Account range failed proof", "err", err) // Signal this request as failed, and ready for rescheduling @@ -2393,10 +2403,8 @@ func (s *Syncer) OnStorage(peer SyncPeer, id uint64, hashes [][]common.Hash, slo s.lock.Unlock() // Reconstruct the partial tries from the response and verify them - var ( - dbs = make([]ethdb.KeyValueStore, len(hashes)) - cont bool - ) + var cont bool + for i := 0; i < len(hashes); i++ { // Convert the keys and proofs into an internal format keys := make([][]byte, len(hashes[i])) @@ -2413,7 +2421,7 @@ func (s *Syncer) OnStorage(peer SyncPeer, id uint64, hashes [][]common.Hash, slo if len(nodes) == 0 { // No proof has been attached, the response must cover the entire key // space and hash to the origin root. - dbs[i], _, err = trie.VerifyRangeProof(req.roots[i], nil, nil, keys, slots[i], nil) + _, err = trie.VerifyRangeProof(req.roots[i], nil, nil, keys, slots[i], nil) if err != nil { s.scheduleRevertStorageRequest(req) // reschedule request logger.Warn("Storage slots failed proof", "err", err) @@ -2428,7 +2436,7 @@ func (s *Syncer) OnStorage(peer SyncPeer, id uint64, hashes [][]common.Hash, slo if len(keys) > 0 { end = keys[len(keys)-1] } - dbs[i], cont, err = trie.VerifyRangeProof(req.roots[i], req.origin[:], end, keys, slots[i], proofdb) + cont, err = trie.VerifyRangeProof(req.roots[i], req.origin[:], end, keys, slots[i], proofdb) if err != nil { s.scheduleRevertStorageRequest(req) // reschedule request logger.Warn("Storage range failed proof", "err", err) @@ -2444,7 +2452,6 @@ func (s *Syncer) OnStorage(peer SyncPeer, id uint64, hashes [][]common.Hash, slo roots: req.roots, hashes: hashes, slots: slots, - nodes: dbs, cont: cont, } select { diff --git a/ethdb/batch.go b/ethdb/batch.go index 5f8207fc46..1353693318 100644 --- a/ethdb/batch.go +++ b/ethdb/batch.go @@ -25,9 +25,6 @@ const IdealBatchSize = 100 * 1024 type Batch interface { KeyValueWriter - // KeyCount retrieves the number of keys queued up for writing. - KeyCount() int - // ValueSize retrieves the amount of data queued up for writing. ValueSize() int @@ -47,3 +44,28 @@ type Batcher interface { // until a final write is called. NewBatch() Batch } + +// HookedBatch wraps an arbitrary batch where each operation may be hooked into +// to monitor from black box code. +type HookedBatch struct { + Batch + + OnPut func(key []byte, value []byte) // Callback if a key is inserted + OnDelete func(key []byte) // Callback if a key is deleted +} + +// Put inserts the given value into the key-value data store. +func (b HookedBatch) Put(key []byte, value []byte) error { + if b.OnPut != nil { + b.OnPut(key, value) + } + return b.Batch.Put(key, value) +} + +// Delete removes the key from the key-value data store. +func (b HookedBatch) Delete(key []byte) error { + if b.OnDelete != nil { + b.OnDelete(key) + } + return b.Batch.Delete(key) +} diff --git a/ethdb/leveldb/leveldb.go b/ethdb/leveldb/leveldb.go index da00226e95..5d19cc3577 100644 --- a/ethdb/leveldb/leveldb.go +++ b/ethdb/leveldb/leveldb.go @@ -448,7 +448,6 @@ func (db *Database) meter(refresh time.Duration) { type batch struct { db *leveldb.DB b *leveldb.Batch - keys int size int } @@ -462,16 +461,10 @@ func (b *batch) Put(key, value []byte) error { // Delete inserts the a key removal into the batch for later committing. func (b *batch) Delete(key []byte) error { b.b.Delete(key) - b.keys++ b.size += len(key) return nil } -// KeyCount retrieves the number of keys queued up for writing. -func (b *batch) KeyCount() int { - return b.keys -} - // ValueSize retrieves the amount of data queued up for writing. func (b *batch) ValueSize() int { return b.size @@ -485,7 +478,7 @@ func (b *batch) Write() error { // Reset resets the batch for reuse. func (b *batch) Reset() { b.b.Reset() - b.keys, b.size = 0, 0 + b.size = 0 } // Replay replays the batch contents. diff --git a/ethdb/memorydb/memorydb.go b/ethdb/memorydb/memorydb.go index ded9f5e66c..fedc9e326c 100644 --- a/ethdb/memorydb/memorydb.go +++ b/ethdb/memorydb/memorydb.go @@ -198,7 +198,6 @@ type keyvalue struct { type batch struct { db *Database writes []keyvalue - keys int size int } @@ -212,16 +211,10 @@ func (b *batch) Put(key, value []byte) error { // Delete inserts the a key removal into the batch for later committing. func (b *batch) Delete(key []byte) error { b.writes = append(b.writes, keyvalue{common.CopyBytes(key), nil, true}) - b.keys++ b.size += len(key) return nil } -// KeyCount retrieves the number of keys queued up for writing. -func (b *batch) KeyCount() int { - return b.keys -} - // ValueSize retrieves the amount of data queued up for writing. func (b *batch) ValueSize() int { return b.size @@ -245,7 +238,7 @@ func (b *batch) Write() error { // Reset resets the batch for reuse. func (b *batch) Reset() { b.writes = b.writes[:0] - b.keys, b.size = 0, 0 + b.size = 0 } // Replay replays the batch contents. diff --git a/tests/fuzzers/rangeproof/rangeproof-fuzzer.go b/tests/fuzzers/rangeproof/rangeproof-fuzzer.go index 984bb9d0a8..09ee6bb9c7 100644 --- a/tests/fuzzers/rangeproof/rangeproof-fuzzer.go +++ b/tests/fuzzers/rangeproof/rangeproof-fuzzer.go @@ -170,18 +170,11 @@ func (f *fuzzer) fuzz() int { } ok = 1 //nodes, subtrie - nodes, hasMore, err := trie.VerifyRangeProof(tr.Hash(), first, last, keys, vals, proof) + hasMore, err := trie.VerifyRangeProof(tr.Hash(), first, last, keys, vals, proof) if err != nil { - if nodes != nil { - panic("err != nil && nodes != nil") - } if hasMore { panic("err != nil && hasMore == true") } - } else { - if nodes == nil { - panic("err == nil && nodes == nil") - } } } return ok diff --git a/tests/fuzzers/stacktrie/trie_fuzzer.go b/tests/fuzzers/stacktrie/trie_fuzzer.go index 0013c919c9..5cea7769c2 100644 --- a/tests/fuzzers/stacktrie/trie_fuzzer.go +++ b/tests/fuzzers/stacktrie/trie_fuzzer.go @@ -90,7 +90,6 @@ func (b *spongeBatch) Put(key, value []byte) error { return nil } func (b *spongeBatch) Delete(key []byte) error { panic("implement me") } -func (b *spongeBatch) KeyCount() int { panic("not implemented") } func (b *spongeBatch) ValueSize() int { return 100 } func (b *spongeBatch) Write() error { return nil } func (b *spongeBatch) Reset() {} diff --git a/trie/notary.go b/trie/notary.go deleted file mode 100644 index 10c7628f55..0000000000 --- a/trie/notary.go +++ /dev/null @@ -1,57 +0,0 @@ -// Copyright 2020 The go-ethereum Authors -// This file is part of the go-ethereum library. -// -// The go-ethereum library is free software: you can redistribute it and/or modify -// it under the terms of the GNU Lesser General Public License as published by -// the Free Software Foundation, either version 3 of the License, or -// (at your option) any later version. -// -// The go-ethereum library is distributed in the hope that it will be useful, -// but WITHOUT ANY WARRANTY; without even the implied warranty of -// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -// GNU Lesser General Public License for more details. -// -// You should have received a copy of the GNU Lesser General Public License -// along with the go-ethereum library. If not, see . - -package trie - -import ( - "github.com/ethereum/go-ethereum/ethdb" - "github.com/ethereum/go-ethereum/ethdb/memorydb" -) - -// keyValueNotary tracks which keys have been accessed through a key-value reader -// with te scope of verifying if certain proof datasets are maliciously bloated. -type keyValueNotary struct { - ethdb.KeyValueReader - reads map[string]struct{} -} - -// newKeyValueNotary wraps a key-value database with an access notary to track -// which items have bene accessed. -func newKeyValueNotary(db ethdb.KeyValueReader) *keyValueNotary { - return &keyValueNotary{ - KeyValueReader: db, - reads: make(map[string]struct{}), - } -} - -// Get retrieves an item from the underlying database, but also tracks it as an -// accessed slot for bloat checks. -func (k *keyValueNotary) Get(key []byte) ([]byte, error) { - k.reads[string(key)] = struct{}{} - return k.KeyValueReader.Get(key) -} - -// Accessed returns s snapshot of the original key-value store containing only the -// data accessed through the notary. -func (k *keyValueNotary) Accessed() ethdb.KeyValueStore { - db := memorydb.New() - for keystr := range k.reads { - key := []byte(keystr) - val, _ := k.KeyValueReader.Get(key) - db.Put(key, val) - } - return db -} diff --git a/trie/proof.go b/trie/proof.go index 2feed24de4..08a9e40422 100644 --- a/trie/proof.go +++ b/trie/proof.go @@ -464,108 +464,91 @@ func hasRightElement(node node, key []byte) bool { // // Except returning the error to indicate the proof is valid or not, the function will // also return a flag to indicate whether there exists more accounts/slots in the trie. -func VerifyRangeProof(rootHash common.Hash, firstKey []byte, lastKey []byte, keys [][]byte, values [][]byte, proof ethdb.KeyValueReader) (ethdb.KeyValueStore, bool, error) { +// +// Note: This method does not verify that the proof is of minimal form. If the input +// proofs are 'bloated' with neighbour leaves or random data, aside from the 'useful' +// data, then the proof will still be accepted. +func VerifyRangeProof(rootHash common.Hash, firstKey []byte, lastKey []byte, keys [][]byte, values [][]byte, proof ethdb.KeyValueReader) (bool, error) { if len(keys) != len(values) { - return nil, false, fmt.Errorf("inconsistent proof data, keys: %d, values: %d", len(keys), len(values)) + return false, fmt.Errorf("inconsistent proof data, keys: %d, values: %d", len(keys), len(values)) } // Ensure the received batch is monotonic increasing. for i := 0; i < len(keys)-1; i++ { if bytes.Compare(keys[i], keys[i+1]) >= 0 { - return nil, false, errors.New("range is not monotonically increasing") + return false, errors.New("range is not monotonically increasing") } } - // Create a key-value notary to track which items from the given proof the - // range prover actually needed to verify the data - notary := newKeyValueNotary(proof) - // Special case, there is no edge proof at all. The given range is expected // to be the whole leaf-set in the trie. if proof == nil { - var ( - diskdb = memorydb.New() - tr = NewStackTrie(diskdb) - ) + tr := NewStackTrie(nil) for index, key := range keys { tr.TryUpdate(key, values[index]) } if have, want := tr.Hash(), rootHash; have != want { - return nil, false, fmt.Errorf("invalid proof, want hash %x, got %x", want, have) - } - // Proof seems valid, serialize remaining nodes into the database - if _, err := tr.Commit(); err != nil { - return nil, false, err + return false, fmt.Errorf("invalid proof, want hash %x, got %x", want, have) } - return diskdb, false, nil // No more elements + return false, nil // No more elements } // Special case, there is a provided edge proof but zero key/value // pairs, ensure there are no more accounts / slots in the trie. if len(keys) == 0 { - root, val, err := proofToPath(rootHash, nil, firstKey, notary, true) + root, val, err := proofToPath(rootHash, nil, firstKey, proof, true) if err != nil { - return nil, false, err + return false, err } if val != nil || hasRightElement(root, firstKey) { - return nil, false, errors.New("more entries available") + return false, errors.New("more entries available") } - // Since the entire proof is a single path, we can construct a trie and a - // node database directly out of the inputs, no need to generate them - diskdb := notary.Accessed() - return diskdb, hasRightElement(root, firstKey), nil + return hasRightElement(root, firstKey), nil } // Special case, there is only one element and two edge keys are same. // In this case, we can't construct two edge paths. So handle it here. if len(keys) == 1 && bytes.Equal(firstKey, lastKey) { - root, val, err := proofToPath(rootHash, nil, firstKey, notary, false) + root, val, err := proofToPath(rootHash, nil, firstKey, proof, false) if err != nil { - return nil, false, err + return false, err } if !bytes.Equal(firstKey, keys[0]) { - return nil, false, errors.New("correct proof but invalid key") + return false, errors.New("correct proof but invalid key") } if !bytes.Equal(val, values[0]) { - return nil, false, errors.New("correct proof but invalid data") + return false, errors.New("correct proof but invalid data") } - // Since the entire proof is a single path, we can construct a trie and a - // node database directly out of the inputs, no need to generate them - diskdb := notary.Accessed() - return diskdb, hasRightElement(root, firstKey), nil + return hasRightElement(root, firstKey), nil } // Ok, in all other cases, we require two edge paths available. // First check the validity of edge keys. if bytes.Compare(firstKey, lastKey) >= 0 { - return nil, false, errors.New("invalid edge keys") + return false, errors.New("invalid edge keys") } // todo(rjl493456442) different length edge keys should be supported if len(firstKey) != len(lastKey) { - return nil, false, errors.New("inconsistent edge keys") + return false, errors.New("inconsistent edge keys") } // Convert the edge proofs to edge trie paths. Then we can // have the same tree architecture with the original one. // For the first edge proof, non-existent proof is allowed. - root, _, err := proofToPath(rootHash, nil, firstKey, notary, true) + root, _, err := proofToPath(rootHash, nil, firstKey, proof, true) if err != nil { - return nil, false, err + return false, err } // Pass the root node here, the second path will be merged // with the first one. For the last edge proof, non-existent // proof is also allowed. - root, _, err = proofToPath(rootHash, root, lastKey, notary, true) + root, _, err = proofToPath(rootHash, root, lastKey, proof, true) if err != nil { - return nil, false, err + return false, err } // Remove all internal references. All the removed parts should // be re-filled(or re-constructed) by the given leaves range. empty, err := unsetInternal(root, firstKey, lastKey) if err != nil { - return nil, false, err + return false, err } // Rebuild the trie with the leaf stream, the shape of trie // should be same with the original one. - var ( - diskdb = memorydb.New() - triedb = NewDatabase(diskdb) - ) - tr := &Trie{root: root, db: triedb} + tr := &Trie{root: root, db: NewDatabase(memorydb.New())} if empty { tr.root = nil } @@ -573,16 +556,9 @@ func VerifyRangeProof(rootHash common.Hash, firstKey []byte, lastKey []byte, key tr.TryUpdate(key, values[index]) } if tr.Hash() != rootHash { - return nil, false, fmt.Errorf("invalid proof, want hash %x, got %x", rootHash, tr.Hash()) - } - // Proof seems valid, serialize all the nodes into the database - if _, err := tr.Commit(nil); err != nil { - return nil, false, err - } - if err := triedb.Commit(rootHash, false, nil); err != nil { - return nil, false, err + return false, fmt.Errorf("invalid proof, want hash %x, got %x", rootHash, tr.Hash()) } - return diskdb, hasRightElement(root, keys[len(keys)-1]), nil + return hasRightElement(root, keys[len(keys)-1]), nil } // get returns the child of the given node. Return nil if the diff --git a/trie/proof_test.go b/trie/proof_test.go index 7a906e2540..a35b7144c0 100644 --- a/trie/proof_test.go +++ b/trie/proof_test.go @@ -182,7 +182,7 @@ func TestRangeProof(t *testing.T) { keys = append(keys, entries[i].k) vals = append(vals, entries[i].v) } - _, _, err := VerifyRangeProof(trie.Hash(), keys[0], keys[len(keys)-1], keys, vals, proof) + _, err := VerifyRangeProof(trie.Hash(), keys[0], keys[len(keys)-1], keys, vals, proof) if err != nil { t.Fatalf("Case %d(%d->%d) expect no error, got %v", i, start, end-1, err) } @@ -233,7 +233,7 @@ func TestRangeProofWithNonExistentProof(t *testing.T) { keys = append(keys, entries[i].k) vals = append(vals, entries[i].v) } - _, _, err := VerifyRangeProof(trie.Hash(), first, last, keys, vals, proof) + _, err := VerifyRangeProof(trie.Hash(), first, last, keys, vals, proof) if err != nil { t.Fatalf("Case %d(%d->%d) expect no error, got %v", i, start, end-1, err) } @@ -254,7 +254,7 @@ func TestRangeProofWithNonExistentProof(t *testing.T) { k = append(k, entries[i].k) v = append(v, entries[i].v) } - _, _, err := VerifyRangeProof(trie.Hash(), first, last, k, v, proof) + _, err := VerifyRangeProof(trie.Hash(), first, last, k, v, proof) if err != nil { t.Fatal("Failed to verify whole rang with non-existent edges") } @@ -289,7 +289,7 @@ func TestRangeProofWithInvalidNonExistentProof(t *testing.T) { k = append(k, entries[i].k) v = append(v, entries[i].v) } - _, _, err := VerifyRangeProof(trie.Hash(), first, k[len(k)-1], k, v, proof) + _, err := VerifyRangeProof(trie.Hash(), first, k[len(k)-1], k, v, proof) if err == nil { t.Fatalf("Expected to detect the error, got nil") } @@ -311,7 +311,7 @@ func TestRangeProofWithInvalidNonExistentProof(t *testing.T) { k = append(k, entries[i].k) v = append(v, entries[i].v) } - _, _, err = VerifyRangeProof(trie.Hash(), k[0], last, k, v, proof) + _, err = VerifyRangeProof(trie.Hash(), k[0], last, k, v, proof) if err == nil { t.Fatalf("Expected to detect the error, got nil") } @@ -335,7 +335,7 @@ func TestOneElementRangeProof(t *testing.T) { if err := trie.Prove(entries[start].k, 0, proof); err != nil { t.Fatalf("Failed to prove the first node %v", err) } - _, _, err := VerifyRangeProof(trie.Hash(), entries[start].k, entries[start].k, [][]byte{entries[start].k}, [][]byte{entries[start].v}, proof) + _, err := VerifyRangeProof(trie.Hash(), entries[start].k, entries[start].k, [][]byte{entries[start].k}, [][]byte{entries[start].v}, proof) if err != nil { t.Fatalf("Expected no error, got %v", err) } @@ -350,7 +350,7 @@ func TestOneElementRangeProof(t *testing.T) { if err := trie.Prove(entries[start].k, 0, proof); err != nil { t.Fatalf("Failed to prove the last node %v", err) } - _, _, err = VerifyRangeProof(trie.Hash(), first, entries[start].k, [][]byte{entries[start].k}, [][]byte{entries[start].v}, proof) + _, err = VerifyRangeProof(trie.Hash(), first, entries[start].k, [][]byte{entries[start].k}, [][]byte{entries[start].v}, proof) if err != nil { t.Fatalf("Expected no error, got %v", err) } @@ -365,7 +365,7 @@ func TestOneElementRangeProof(t *testing.T) { if err := trie.Prove(last, 0, proof); err != nil { t.Fatalf("Failed to prove the last node %v", err) } - _, _, err = VerifyRangeProof(trie.Hash(), entries[start].k, last, [][]byte{entries[start].k}, [][]byte{entries[start].v}, proof) + _, err = VerifyRangeProof(trie.Hash(), entries[start].k, last, [][]byte{entries[start].k}, [][]byte{entries[start].v}, proof) if err != nil { t.Fatalf("Expected no error, got %v", err) } @@ -380,7 +380,7 @@ func TestOneElementRangeProof(t *testing.T) { if err := trie.Prove(last, 0, proof); err != nil { t.Fatalf("Failed to prove the last node %v", err) } - _, _, err = VerifyRangeProof(trie.Hash(), first, last, [][]byte{entries[start].k}, [][]byte{entries[start].v}, proof) + _, err = VerifyRangeProof(trie.Hash(), first, last, [][]byte{entries[start].k}, [][]byte{entries[start].v}, proof) if err != nil { t.Fatalf("Expected no error, got %v", err) } @@ -399,7 +399,7 @@ func TestOneElementRangeProof(t *testing.T) { if err := tinyTrie.Prove(last, 0, proof); err != nil { t.Fatalf("Failed to prove the last node %v", err) } - _, _, err = VerifyRangeProof(tinyTrie.Hash(), first, last, [][]byte{entry.k}, [][]byte{entry.v}, proof) + _, err = VerifyRangeProof(tinyTrie.Hash(), first, last, [][]byte{entry.k}, [][]byte{entry.v}, proof) if err != nil { t.Fatalf("Expected no error, got %v", err) } @@ -421,7 +421,7 @@ func TestAllElementsProof(t *testing.T) { k = append(k, entries[i].k) v = append(v, entries[i].v) } - _, _, err := VerifyRangeProof(trie.Hash(), nil, nil, k, v, nil) + _, err := VerifyRangeProof(trie.Hash(), nil, nil, k, v, nil) if err != nil { t.Fatalf("Expected no error, got %v", err) } @@ -434,7 +434,7 @@ func TestAllElementsProof(t *testing.T) { if err := trie.Prove(entries[len(entries)-1].k, 0, proof); err != nil { t.Fatalf("Failed to prove the last node %v", err) } - _, _, err = VerifyRangeProof(trie.Hash(), k[0], k[len(k)-1], k, v, proof) + _, err = VerifyRangeProof(trie.Hash(), k[0], k[len(k)-1], k, v, proof) if err != nil { t.Fatalf("Expected no error, got %v", err) } @@ -449,7 +449,7 @@ func TestAllElementsProof(t *testing.T) { if err := trie.Prove(last, 0, proof); err != nil { t.Fatalf("Failed to prove the last node %v", err) } - _, _, err = VerifyRangeProof(trie.Hash(), first, last, k, v, proof) + _, err = VerifyRangeProof(trie.Hash(), first, last, k, v, proof) if err != nil { t.Fatalf("Expected no error, got %v", err) } @@ -482,7 +482,7 @@ func TestSingleSideRangeProof(t *testing.T) { k = append(k, entries[i].k) v = append(v, entries[i].v) } - _, _, err := VerifyRangeProof(trie.Hash(), common.Hash{}.Bytes(), k[len(k)-1], k, v, proof) + _, err := VerifyRangeProof(trie.Hash(), common.Hash{}.Bytes(), k[len(k)-1], k, v, proof) if err != nil { t.Fatalf("Expected no error, got %v", err) } @@ -518,7 +518,7 @@ func TestReverseSingleSideRangeProof(t *testing.T) { k = append(k, entries[i].k) v = append(v, entries[i].v) } - _, _, err := VerifyRangeProof(trie.Hash(), k[0], last.Bytes(), k, v, proof) + _, err := VerifyRangeProof(trie.Hash(), k[0], last.Bytes(), k, v, proof) if err != nil { t.Fatalf("Expected no error, got %v", err) } @@ -590,7 +590,7 @@ func TestBadRangeProof(t *testing.T) { index = mrand.Intn(end - start) vals[index] = nil } - _, _, err := VerifyRangeProof(trie.Hash(), first, last, keys, vals, proof) + _, err := VerifyRangeProof(trie.Hash(), first, last, keys, vals, proof) if err == nil { t.Fatalf("%d Case %d index %d range: (%d->%d) expect error, got nil", i, testcase, index, start, end-1) } @@ -624,7 +624,7 @@ func TestGappedRangeProof(t *testing.T) { keys = append(keys, entries[i].k) vals = append(vals, entries[i].v) } - _, _, err := VerifyRangeProof(trie.Hash(), keys[0], keys[len(keys)-1], keys, vals, proof) + _, err := VerifyRangeProof(trie.Hash(), keys[0], keys[len(keys)-1], keys, vals, proof) if err == nil { t.Fatal("expect error, got nil") } @@ -651,7 +651,7 @@ func TestSameSideProofs(t *testing.T) { if err := trie.Prove(last, 0, proof); err != nil { t.Fatalf("Failed to prove the last node %v", err) } - _, _, err := VerifyRangeProof(trie.Hash(), first, last, [][]byte{entries[pos].k}, [][]byte{entries[pos].v}, proof) + _, err := VerifyRangeProof(trie.Hash(), first, last, [][]byte{entries[pos].k}, [][]byte{entries[pos].v}, proof) if err == nil { t.Fatalf("Expected error, got nil") } @@ -667,7 +667,7 @@ func TestSameSideProofs(t *testing.T) { if err := trie.Prove(last, 0, proof); err != nil { t.Fatalf("Failed to prove the last node %v", err) } - _, _, err = VerifyRangeProof(trie.Hash(), first, last, [][]byte{entries[pos].k}, [][]byte{entries[pos].v}, proof) + _, err = VerifyRangeProof(trie.Hash(), first, last, [][]byte{entries[pos].k}, [][]byte{entries[pos].v}, proof) if err == nil { t.Fatalf("Expected error, got nil") } @@ -735,7 +735,7 @@ func TestHasRightElement(t *testing.T) { k = append(k, entries[i].k) v = append(v, entries[i].v) } - _, hasMore, err := VerifyRangeProof(trie.Hash(), firstKey, lastKey, k, v, proof) + hasMore, err := VerifyRangeProof(trie.Hash(), firstKey, lastKey, k, v, proof) if err != nil { t.Fatalf("Expected no error, got %v", err) } @@ -768,25 +768,19 @@ func TestEmptyRangeProof(t *testing.T) { if err := trie.Prove(first, 0, proof); err != nil { t.Fatalf("Failed to prove the first node %v", err) } - db, _, err := VerifyRangeProof(trie.Hash(), first, nil, nil, nil, proof) + _, err := VerifyRangeProof(trie.Hash(), first, nil, nil, nil, proof) if c.err && err == nil { t.Fatalf("Expected error, got nil") } if !c.err && err != nil { t.Fatalf("Expected no error, got %v", err) } - // If no error was returned, ensure the returned database contains - // the entire proof, since there's no value - if !c.err { - if memdb := db.(*memorydb.Database); memdb.Len() != proof.Len() { - t.Errorf("database entry count mismatch: have %d, want %d", memdb.Len(), proof.Len()) - } - } } } // TestBloatedProof tests a malicious proof, where the proof is more or less the -// whole trie. +// whole trie. Previously we didn't accept such packets, but the new APIs do, so +// lets leave this test as a bit weird, but present. func TestBloatedProof(t *testing.T) { // Use a small trie trie, kvs := nonRandomTrie(100) @@ -814,10 +808,8 @@ func TestBloatedProof(t *testing.T) { trie.Prove(keys[0], 0, want) trie.Prove(keys[len(keys)-1], 0, want) - db, _, _ := VerifyRangeProof(trie.Hash(), keys[0], keys[len(keys)-1], keys, vals, proof) - // The db should not contain anything of the bloated data - if used := db.(*memorydb.Database); used.Len() != want.Len() { - t.Fatalf("notary proof size mismatch: have %d, want %d", used.Len(), want.Len()) + if _, err := VerifyRangeProof(trie.Hash(), keys[0], keys[len(keys)-1], keys, vals, proof); err != nil { + t.Fatalf("expected bloated proof to succeed, got %v", err) } } @@ -921,7 +913,7 @@ func benchmarkVerifyRangeProof(b *testing.B, size int) { b.ResetTimer() for i := 0; i < b.N; i++ { - _, _, err := VerifyRangeProof(trie.Hash(), keys[0], keys[len(keys)-1], keys, values, proof) + _, err := VerifyRangeProof(trie.Hash(), keys[0], keys[len(keys)-1], keys, values, proof) if err != nil { b.Fatalf("Case %d(%d->%d) expect no error, got %v", i, start, end-1, err) } @@ -948,7 +940,7 @@ func benchmarkVerifyRangeNoProof(b *testing.B, size int) { } b.ResetTimer() for i := 0; i < b.N; i++ { - _, _, err := VerifyRangeProof(trie.Hash(), keys[0], keys[len(keys)-1], keys, values, nil) + _, err := VerifyRangeProof(trie.Hash(), keys[0], keys[len(keys)-1], keys, values, nil) if err != nil { b.Fatalf("Expected no error, got %v", err) } diff --git a/trie/trie_test.go b/trie/trie_test.go index 44fddf87e4..492b423c2f 100644 --- a/trie/trie_test.go +++ b/trie/trie_test.go @@ -706,7 +706,6 @@ func (b *spongeBatch) Put(key, value []byte) error { return nil } func (b *spongeBatch) Delete(key []byte) error { panic("implement me") } -func (b *spongeBatch) KeyCount() int { return 100 } func (b *spongeBatch) ValueSize() int { return 100 } func (b *spongeBatch) Write() error { return nil } func (b *spongeBatch) Reset() {}