diff --git a/eth/protocol.go b/eth/protocol.go index 663af43fe..b86f33614 100644 --- a/eth/protocol.go +++ b/eth/protocol.go @@ -3,7 +3,6 @@ package eth import ( "bytes" "fmt" - "io" "math/big" "github.com/ethereum/go-ethereum/core/types" @@ -188,33 +187,37 @@ func (self *ethProtocol) handle() error { case BlockHashesMsg: msgStream := rlp.NewStream(msg.Payload) - var err error - var i int + if _, err := msgStream.List(); err != nil { + return err + } + var i int iter := func() (hash []byte, ok bool) { - hash, err = msgStream.Bytes() - if err == nil { - i++ - ok = true - } else { - if err != io.EOF { - self.protoError(ErrDecode, "msg %v: after %v hashes : %v", msg, i, err) - } + hash, err := msgStream.Bytes() + if err == rlp.EOL { + return nil, false + } else if err != nil { + self.protoError(ErrDecode, "msg %v: after %v hashes : %v", msg, i, err) + return nil, false } - return + i++ + return hash, true } - self.blockPool.AddBlockHashes(iter, self.id) case GetBlocksMsg: msgStream := rlp.NewStream(msg.Payload) + if _, err := msgStream.List(); err != nil { + return err + } + var blocks []interface{} var i int for { i++ var hash []byte if err := msgStream.Decode(&hash); err != nil { - if err == io.EOF { + if err == rlp.EOL { break } else { return self.protoError(ErrDecode, "msg %v: %v", msg, err) @@ -232,10 +235,13 @@ func (self *ethProtocol) handle() error { case BlocksMsg: msgStream := rlp.NewStream(msg.Payload) + if _, err := msgStream.List(); err != nil { + return err + } for { var block types.Block if err := msgStream.Decode(&block); err != nil { - if err == io.EOF { + if err == rlp.EOL { break } else { return self.protoError(ErrDecode, "msg %v: %v", msg, err) diff --git a/whisper/peer.go b/whisper/peer.go index 332ddd22a..66cfec88c 100644 --- a/whisper/peer.go +++ b/whisper/peer.go @@ -2,10 +2,10 @@ package whisper import ( "fmt" - "io/ioutil" "time" "github.com/ethereum/go-ethereum/p2p" + "github.com/ethereum/go-ethereum/rlp" "gopkg.in/fatih/set.v0" ) @@ -77,8 +77,7 @@ func (self *peer) broadcast(envelopes []*Envelope) error { } if i > 0 { - msg := p2p.NewMsg(envelopesMsg, envs[:i]...) - if err := self.ws.WriteMsg(msg); err != nil { + if err := p2p.EncodeMsg(self.ws, envelopesMsg, envs[:i]...); err != nil { return err } self.peer.DebugDetailln("broadcasted", i, "message(s)") @@ -93,34 +92,28 @@ func (self *peer) addKnown(envelope *Envelope) { func (self *peer) handleStatus() error { ws := self.ws - if err := ws.WriteMsg(self.statusMsg()); err != nil { return err } - msg, err := ws.ReadMsg() if err != nil { return err } - if msg.Code != statusMsg { return fmt.Errorf("peer send %x before status msg", msg.Code) } - - data, err := ioutil.ReadAll(msg.Payload) - if err != nil { - return err + s := rlp.NewStream(msg.Payload) + if _, err := s.List(); err != nil { + return fmt.Errorf("bad status message: %v", err) } - - if len(data) == 0 { - return fmt.Errorf("malformed status. data len = 0") + pv, err := s.Uint() + if err != nil { + return fmt.Errorf("bad status message: %v", err) } - - if pv := data[0]; pv != protocolVersion { + if pv != protocolVersion { return fmt.Errorf("protocol version mismatch %d != %d", pv, protocolVersion) } - - return nil + return msg.Discard() // ignore anything after protocol version } func (self *peer) statusMsg() p2p.Msg {