diff --git a/whisper/peer_test.go b/whisper/peer_test.go index de67b2463a..9008cdc593 100644 --- a/whisper/peer_test.go +++ b/whisper/peer_test.go @@ -190,3 +190,53 @@ func TestPeerDeliver(t *testing.T) { t.Fatalf("repeating message arrived") } } + +func TestPeerMessageExpiration(t *testing.T) { + // Start a tester and execute the handshake + tester, err := startTestPeerInited() + if err != nil { + t.Fatalf("failed to start initialized peer: %v", err) + } + defer tester.stream.Close() + + // Fetch the peer instance for later inspection + tester.client.peerMu.RLock() + if peers := len(tester.client.peers); peers != 1 { + t.Fatalf("peer pool size mismatch: have %v, want %v", peers, 1) + } + var peer *peer + for peer, _ = range tester.client.peers { + break + } + tester.client.peerMu.RUnlock() + + // Construct a message and pass it through the tester + message := NewMessage([]byte("peer test message")) + envelope, err := message.Wrap(DefaultPoW, Options{ + TTL: time.Second, + }) + if err != nil { + t.Fatalf("failed to wrap message: %v", err) + } + if err := tester.client.Send(envelope); err != nil { + t.Fatalf("failed to send message: %v", err) + } + payload := []interface{}{envelope} + if err := p2p.ExpectMsg(tester.stream, messagesCode, payload); err != nil { + t.Fatalf("message mismatch: %v", err) + } + // Check that the message is inside the cache + if !peer.known.Has(envelope.Hash()) { + t.Fatalf("message not found in cache") + } + // Discard messages until expiration and check cache again + exp := time.Now().Add(time.Second + expirationCycle) + for time.Now().Before(exp) { + if err := p2p.ExpectMsg(tester.stream, messagesCode, []interface{}{}); err != nil { + t.Fatalf("message mismatch: %v", err) + } + } + if peer.known.Has(envelope.Hash()) { + t.Fatalf("message not expired from cache") + } +} diff --git a/whisper/whisper.go b/whisper/whisper.go index f04075e1f4..48efff6229 100644 --- a/whisper/whisper.go +++ b/whisper/whisper.go @@ -46,22 +46,26 @@ type Whisper struct { protocol p2p.Protocol filters *filter.Filters - mmu sync.RWMutex // Message mutex to sync the below pool - messages map[common.Hash]*Envelope // Pool of messages currently tracked by this node - expiry map[uint32]*set.SetNonTS // Message expiration pool (TODO: something lighter) + keys map[string]*ecdsa.PrivateKey - quit chan struct{} + messages map[common.Hash]*Envelope // Pool of messages currently tracked by this node + expirations map[uint32]*set.SetNonTS // Message expiration pool (TODO: something lighter) + poolMu sync.RWMutex // Mutex to sync the message and expiration pools - keys map[string]*ecdsa.PrivateKey + peers map[*peer]struct{} // Set of currently active peers + peerMu sync.RWMutex // Mutex to sync the active peer set + + quit chan struct{} } func New() *Whisper { whisper := &Whisper{ - messages: make(map[common.Hash]*Envelope), - filters: filter.New(), - expiry: make(map[uint32]*set.SetNonTS), - quit: make(chan struct{}), - keys: make(map[string]*ecdsa.PrivateKey), + filters: filter.New(), + keys: make(map[string]*ecdsa.PrivateKey), + messages: make(map[common.Hash]*Envelope), + expirations: make(map[uint32]*set.SetNonTS), + peers: make(map[*peer]struct{}), + quit: make(chan struct{}), } whisper.filters.Start() @@ -179,6 +183,16 @@ func (self *Whisper) handlePeer(peer *p2p.Peer, rw p2p.MsgReadWriter) error { whisperPeer.start() defer whisperPeer.stop() + // Start tracking the active peer + self.peerMu.Lock() + self.peers[whisperPeer] = struct{}{} + self.peerMu.Unlock() + + defer func() { + self.peerMu.Lock() + delete(self.peers, whisperPeer) + self.peerMu.Unlock() + }() // Read and process inbound messages directly to merge into client-global state for { // Fetch the next packet and decode the contained envelopes @@ -206,8 +220,8 @@ func (self *Whisper) handlePeer(peer *p2p.Peer, rw p2p.MsgReadWriter) error { // whisper network. It also inserts the envelope into the expiration pool at the // appropriate time-stamp. func (self *Whisper) add(envelope *Envelope) error { - self.mmu.Lock() - defer self.mmu.Unlock() + self.poolMu.Lock() + defer self.poolMu.Unlock() // Insert the message into the tracked pool hash := envelope.Hash() @@ -218,11 +232,11 @@ func (self *Whisper) add(envelope *Envelope) error { self.messages[hash] = envelope // Insert the message into the expiration pool for later removal - if self.expiry[envelope.Expiry] == nil { - self.expiry[envelope.Expiry] = set.NewNonTS() + if self.expirations[envelope.Expiry] == nil { + self.expirations[envelope.Expiry] = set.NewNonTS() } - if !self.expiry[envelope.Expiry].Has(hash) { - self.expiry[envelope.Expiry].Add(hash) + if !self.expirations[envelope.Expiry].Has(hash) { + self.expirations[envelope.Expiry].Add(hash) // Notify the local node of a message arrival go self.postEvent(envelope) @@ -292,11 +306,11 @@ func (self *Whisper) update() { // expire iterates over all the expiration timestamps, removing all stale // messages from the pools. func (self *Whisper) expire() { - self.mmu.Lock() - defer self.mmu.Unlock() + self.poolMu.Lock() + defer self.poolMu.Unlock() now := uint32(time.Now().Unix()) - for then, hashSet := range self.expiry { + for then, hashSet := range self.expirations { // Short circuit if a future time if then > now { continue @@ -306,14 +320,14 @@ func (self *Whisper) expire() { delete(self.messages, v.(common.Hash)) return true }) - self.expiry[then].Clear() + self.expirations[then].Clear() } } // envelopes retrieves all the messages currently pooled by the node. func (self *Whisper) envelopes() []*Envelope { - self.mmu.RLock() - defer self.mmu.RUnlock() + self.poolMu.RLock() + defer self.poolMu.RUnlock() envelopes := make([]*Envelope, 0, len(self.messages)) for _, envelope := range self.messages {