diff --git a/p2p/message.go b/p2p/message.go index ade39d25a3..845c832f09 100644 --- a/p2p/message.go +++ b/p2p/message.go @@ -41,14 +41,22 @@ func encodePayload(params ...interface{}) []byte { return buf.Bytes() } -// Data returns the decoded RLP payload items in a message. -func (msg Msg) Data() (*ethutil.Value, error) { - s := rlp.NewListStream(msg.Payload, uint64(msg.Size)) +// Value returns the decoded RLP payload items in a message. +func (msg Msg) Value() (*ethutil.Value, error) { var v []interface{} - err := s.Decode(&v) + err := msg.Decode(&v) return ethutil.NewValue(v), err } +// Decode parse the RLP content of a message into +// the given value, which must be a pointer. +// +// For the decoding rules, please see package rlp. +func (msg Msg) Decode(val interface{}) error { + s := rlp.NewListStream(msg.Payload, uint64(msg.Size)) + return s.Decode(val) +} + // Discard reads any remaining payload data into a black hole. func (msg Msg) Discard() error { _, err := io.Copy(ioutil.Discard, msg.Payload) @@ -91,7 +99,7 @@ func MsgLoop(r MsgReader, maxsize uint32, f func(code uint64, data *ethutil.Valu if msg.Size > maxsize { return newPeerError(errInvalidMsg, "size %d exceeds maximum size of %d", msg.Size, maxsize) } - value, err := msg.Data() + value, err := msg.Value() if err != nil { return err } diff --git a/p2p/message_test.go b/p2p/message_test.go index 02d70a28ba..0f51f759e6 100644 --- a/p2p/message_test.go +++ b/p2p/message_test.go @@ -42,7 +42,7 @@ func TestEncodeDecodeMsg(t *testing.T) { if decmsg.Size != 5 { t.Errorf("incorrect size %d, want %d", decmsg.Size, 5) } - data, err := decmsg.Data() + data, err := decmsg.Value() if err != nil { t.Fatalf("first payload item decode error: %v", err) } diff --git a/p2p/peer_test.go b/p2p/peer_test.go index 56cd4d8902..629475421b 100644 --- a/p2p/peer_test.go +++ b/p2p/peer_test.go @@ -53,7 +53,7 @@ func TestPeerProtoReadMsg(t *testing.T) { if msg.Code != 2 { t.Errorf("incorrect msg code %d relayed to protocol", msg.Code) } - data, err := msg.Data() + data, err := msg.Value() if err != nil { t.Errorf("data decoding error: %v", err) } diff --git a/p2p/protocol.go b/p2p/protocol.go index 169dcdb6e2..28eab87cd6 100644 --- a/p2p/protocol.go +++ b/p2p/protocol.go @@ -2,7 +2,6 @@ package p2p import ( "bytes" - "net" "time" "github.com/ethereum/go-ethereum/ethutil" @@ -90,30 +89,18 @@ type baseProtocol struct { func runBaseProtocol(peer *Peer, rw MsgReadWriter) error { bp := &baseProtocol{rw, peer} - - // do handshake - if err := rw.WriteMsg(bp.handshakeMsg()); err != nil { - return err - } - msg, err := rw.ReadMsg() - if err != nil { + if err := bp.doHandshake(rw); err != nil { return err } - if msg.Code != handshakeMsg { - return newPeerError(errProtocolBreach, "first message must be handshake, got %x", msg.Code) - } - data, err := msg.Data() - if err != nil { - return newPeerError(errInvalidMsg, "%v", err) - } - if err := bp.handleHandshake(data); err != nil { - return err - } - // run main loop quit := make(chan error, 1) go func() { - quit <- MsgLoop(rw, baseProtocolMaxMsgSize, bp.handle) + for { + if err := bp.handle(rw); err != nil { + quit <- err + break + } + } }() return bp.loop(quit) } @@ -151,13 +138,27 @@ func (bp *baseProtocol) loop(quit <-chan error) error { return err } -func (bp *baseProtocol) handle(code uint64, data *ethutil.Value) error { - switch code { +func (bp *baseProtocol) handle(rw MsgReadWriter) error { + msg, err := rw.ReadMsg() + if err != nil { + return err + } + if msg.Size > baseProtocolMaxMsgSize { + return newPeerError(errMisc, "message too big") + } + // make sure that the payload has been fully consumed + defer msg.Discard() + + switch msg.Code { case handshakeMsg: return newPeerError(errProtocolBreach, "extra handshake received") case discMsg: - bp.peer.Disconnect(DiscReason(data.Get(0).Uint())) + var reason DiscReason + if err := msg.Decode(&reason); err != nil { + return err + } + bp.peer.Disconnect(reason) return nil case pingMsg: @@ -178,35 +179,45 @@ func (bp *baseProtocol) handle(code uint64, data *ethutil.Value) error { } case peersMsg: - bp.handlePeers(data) + var peers []*peerAddr + if err := msg.Decode(&peers); err != nil { + return err + } + for _, addr := range peers { + bp.peer.Debugf("received peer suggestion: %v", addr) + bp.peer.newPeerAddr <- addr + } default: - return newPeerError(errInvalidMsgCode, "unknown message code %v", code) + return newPeerError(errInvalidMsgCode, "unknown message code %v", msg.Code) } return nil } -func (bp *baseProtocol) handlePeers(data *ethutil.Value) { - it := data.NewIterator() - for it.Next() { - addr := &peerAddr{ - IP: net.IP(it.Value().Get(0).Bytes()), - Port: it.Value().Get(1).Uint(), - Pubkey: it.Value().Get(2).Bytes(), - } - bp.peer.Debugf("received peer suggestion: %v", addr) - bp.peer.newPeerAddr <- addr +func (bp *baseProtocol) doHandshake(rw MsgReadWriter) error { + // send our handshake + if err := rw.WriteMsg(bp.handshakeMsg()); err != nil { + return err + } + + // read and handle remote handshake + msg, err := rw.ReadMsg() + if err != nil { + return err + } + if msg.Code != handshakeMsg { + return newPeerError(errProtocolBreach, "first message must be handshake, got %x", msg.Code) + } + if msg.Size > baseProtocolMaxMsgSize { + return newPeerError(errMisc, "message too big") } -} -func (bp *baseProtocol) handleHandshake(c *ethutil.Value) error { - hs := handshake{ - Version: c.Get(0).Uint(), - ID: c.Get(1).Str(), - Caps: nil, // decoded below - ListenPort: c.Get(3).Uint(), - NodeID: c.Get(4).Bytes(), + var hs handshake + if err := msg.Decode(&hs); err != nil { + return err } + + // validate handshake info if hs.Version != baseProtocolVersion { return newPeerError(errP2PVersionMismatch, "Require protocol %d, received %d\n", baseProtocolVersion, hs.Version) @@ -228,14 +239,8 @@ func (bp *baseProtocol) handleHandshake(c *ethutil.Value) error { if err := bp.peer.pubkeyHook(pa); err != nil { return newPeerError(errPubkeyForbidden, "%v", err) } - capsIt := c.Get(2).NewIterator() - for capsIt.Next() { - cap := capsIt.Value() - name := cap.Get(0).Str() - if name != "" { - hs.Caps = append(hs.Caps, Cap{Name: name, Version: uint(cap.Get(1).Uint())}) - } - } + + // TODO: remove Caps with empty name var addr *peerAddr if hs.ListenPort != 0 {