diff --git a/core/bloombits/generator.go b/core/bloombits/generator.go index 540085450d..ae07481ada 100644 --- a/core/bloombits/generator.go +++ b/core/bloombits/generator.go @@ -22,16 +22,22 @@ import ( "github.com/ethereum/go-ethereum/core/types" ) -// errSectionOutOfBounds is returned if the user tried to add more bloom filters -// to the batch than available space, or if tries to retrieve above the capacity, -var errSectionOutOfBounds = errors.New("section out of bounds") +var ( + // errSectionOutOfBounds is returned if the user tried to add more bloom filters + // to the batch than available space, or if tries to retrieve above the capacity. + errSectionOutOfBounds = errors.New("section out of bounds") + + // errBloomBitOutOfBounds is returned if the user tried to retrieve specified + // bit bloom above the capacity. + errBloomBitOutOfBounds = errors.New("bloom bit out of bounds") +) // Generator takes a number of bloom filters and generates the rotated bloom bits // to be used for batched filtering. type Generator struct { blooms [types.BloomBitLength][]byte // Rotated blooms for per-bit matching sections uint // Number of sections to batch together - nextBit uint // Next bit to set when adding a bloom + nextSec uint // Next section to set when adding a bloom } // NewGenerator creates a rotated bloom generator that can iteratively fill a @@ -51,15 +57,15 @@ func NewGenerator(sections uint) (*Generator, error) { // in memory accordingly. func (b *Generator) AddBloom(index uint, bloom types.Bloom) error { // Make sure we're not adding more bloom filters than our capacity - if b.nextBit >= b.sections { + if b.nextSec >= b.sections { return errSectionOutOfBounds } - if b.nextBit != index { + if b.nextSec != index { return errors.New("bloom filter with unexpected index") } // Rotate the bloom and insert into our collection - byteIndex := b.nextBit / 8 - bitMask := byte(1) << byte(7-b.nextBit%8) + byteIndex := b.nextSec / 8 + bitMask := byte(1) << byte(7-b.nextSec%8) for i := 0; i < types.BloomBitLength; i++ { bloomByteIndex := types.BloomByteLength - 1 - i/8 @@ -69,7 +75,7 @@ func (b *Generator) AddBloom(index uint, bloom types.Bloom) error { b.blooms[i][byteIndex] |= bitMask } } - b.nextBit++ + b.nextSec++ return nil } @@ -77,11 +83,11 @@ func (b *Generator) AddBloom(index uint, bloom types.Bloom) error { // Bitset returns the bit vector belonging to the given bit index after all // blooms have been added. func (b *Generator) Bitset(idx uint) ([]byte, error) { - if b.nextBit != b.sections { + if b.nextSec != b.sections { return nil, errors.New("bloom not fully generated yet") } - if idx >= b.sections { - return nil, errSectionOutOfBounds + if idx >= types.BloomBitLength { + return nil, errBloomBitOutOfBounds } return b.blooms[idx], nil } diff --git a/ethdb/database_test.go b/ethdb/database_test.go index 2deb50988c..74675cbe63 100644 --- a/ethdb/database_test.go +++ b/ethdb/database_test.go @@ -59,6 +59,28 @@ func TestMemoryDB_PutGet(t *testing.T) { func testPutGet(db ethdb.Database, t *testing.T) { t.Parallel() + for _, k := range test_values { + err := db.Put([]byte(k), nil) + if err != nil { + t.Fatalf("put failed: %v", err) + } + } + + for _, k := range test_values { + data, err := db.Get([]byte(k)) + if err != nil { + t.Fatalf("get failed: %v", err) + } + if len(data) != 0 { + t.Fatalf("get returned wrong result, got %q expected nil", string(data)) + } + } + + _, err := db.Get([]byte("non-exist-key")) + if err == nil { + t.Fatalf("expect to return a not found error") + } + for _, v := range test_values { err := db.Put([]byte(v), []byte(v)) if err != nil { diff --git a/ethdb/memory_database.go b/ethdb/memory_database.go index f28ff54818..727f2f7ca3 100644 --- a/ethdb/memory_database.go +++ b/ethdb/memory_database.go @@ -96,7 +96,10 @@ func (db *MemDatabase) NewBatch() Batch { func (db *MemDatabase) Len() int { return len(db.db) } -type kv struct{ k, v []byte } +type kv struct { + k, v []byte + del bool +} type memBatch struct { db *MemDatabase @@ -105,13 +108,14 @@ type memBatch struct { } func (b *memBatch) Put(key, value []byte) error { - b.writes = append(b.writes, kv{common.CopyBytes(key), common.CopyBytes(value)}) + b.writes = append(b.writes, kv{common.CopyBytes(key), common.CopyBytes(value), false}) b.size += len(value) return nil } func (b *memBatch) Delete(key []byte) error { - b.writes = append(b.writes, kv{common.CopyBytes(key), nil}) + b.writes = append(b.writes, kv{common.CopyBytes(key), nil, true}) + b.size += 1 return nil } @@ -120,7 +124,7 @@ func (b *memBatch) Write() error { defer b.db.lock.Unlock() for _, kv := range b.writes { - if kv.v == nil { + if kv.del { delete(b.db.db, string(kv.k)) continue }