p2p: fixes for actual connections

The unit test hooks were turned on 'in production'.
pull/292/head
Felix Lange 10 years ago
parent 8564eb9f7e
commit e34d134102
  1. 4
      p2p/message.go
  2. 37
      p2p/peer.go
  3. 2
      p2p/peer_error.go
  4. 19
      p2p/peer_test.go
  5. 4
      p2p/server.go
  6. 1
      p2p/server_test.go

@ -174,10 +174,10 @@ func (rw *frameRW) ReadMsg() (msg Msg, err error) {
// read magic and payload size // read magic and payload size
start := make([]byte, 8) start := make([]byte, 8)
if _, err = io.ReadFull(rw.bufconn, start); err != nil { if _, err = io.ReadFull(rw.bufconn, start); err != nil {
return msg, newPeerError(errRead, "%v", err) return msg, err
} }
if !bytes.HasPrefix(start, magicToken) { if !bytes.HasPrefix(start, magicToken) {
return msg, newPeerError(errMagicTokenMismatch, "got %x, want %x", start[:4], magicToken) return msg, fmt.Errorf("bad magic token %x", start[:4], magicToken)
} }
size := binary.BigEndian.Uint32(start[4:]) size := binary.BigEndian.Uint32(start[4:])

@ -1,6 +1,7 @@
package p2p package p2p
import ( import (
"errors"
"fmt" "fmt"
"io" "io"
"io/ioutil" "io/ioutil"
@ -71,7 +72,8 @@ type Peer struct {
runlock sync.RWMutex // protects running runlock sync.RWMutex // protects running
running map[string]*proto running map[string]*proto
protocolHandshakeEnabled bool // disables protocol handshake, for testing
noHandshake bool
protoWG sync.WaitGroup protoWG sync.WaitGroup
protoErr chan error protoErr chan error
@ -134,11 +136,11 @@ func (p *Peer) Disconnect(reason DiscReason) {
// String implements fmt.Stringer. // String implements fmt.Stringer.
func (p *Peer) String() string { func (p *Peer) String() string {
return fmt.Sprintf("Peer %.8x %v", p.remoteID, p.RemoteAddr()) return fmt.Sprintf("Peer %.8x %v", p.remoteID[:], p.RemoteAddr())
} }
func newPeer(conn net.Conn, protocols []Protocol, ourName string, ourID, remoteID *discover.NodeID) *Peer { func newPeer(conn net.Conn, protocols []Protocol, ourName string, ourID, remoteID *discover.NodeID) *Peer {
logtag := fmt.Sprintf("Peer %.8x %v", remoteID, conn.RemoteAddr()) logtag := fmt.Sprintf("Peer %.8x %v", remoteID[:], conn.RemoteAddr())
return &Peer{ return &Peer{
Logger: logger.NewLogger(logtag), Logger: logger.NewLogger(logtag),
rw: newFrameRW(conn, msgWriteTimeout), rw: newFrameRW(conn, msgWriteTimeout),
@ -164,33 +166,35 @@ func (p *Peer) run() DiscReason {
var readErr = make(chan error, 1) var readErr = make(chan error, 1)
defer p.closeProtocols() defer p.closeProtocols()
defer close(p.closed) defer close(p.closed)
defer p.rw.Close()
// start the read loop
go func() { readErr <- p.readLoop() }() go func() { readErr <- p.readLoop() }()
if p.protocolHandshakeEnabled { if !p.noHandshake {
if err := writeProtocolHandshake(p.rw, p.ourName, *p.ourID, p.protocols); err != nil { if err := writeProtocolHandshake(p.rw, p.ourName, *p.ourID, p.protocols); err != nil {
p.DebugDetailf("Protocol handshake error: %v\n", err) p.DebugDetailf("Protocol handshake error: %v\n", err)
p.rw.Close()
return DiscProtocolError return DiscProtocolError
} }
} }
// wait for an error or disconnect // Wait for an error or disconnect.
var reason DiscReason var reason DiscReason
select { select {
case err := <-readErr: case err := <-readErr:
// We rely on protocols to abort if there is a write error. It // We rely on protocols to abort if there is a write error. It
// might be more robust to handle them here as well. // might be more robust to handle them here as well.
p.DebugDetailf("Read error: %v\n", err) p.DebugDetailf("Read error: %v\n", err)
reason = DiscNetworkError p.rw.Close()
return DiscNetworkError
case err := <-p.protoErr: case err := <-p.protoErr:
reason = discReasonForError(err) reason = discReasonForError(err)
case reason = <-p.disc: case reason = <-p.disc:
} }
if reason != DiscNetworkError { p.politeDisconnect(reason)
p.politeDisconnect(reason)
} // Wait for readLoop. It will end because conn is now closed.
<-readErr
p.Debugf("Disconnected: %v\n", reason) p.Debugf("Disconnected: %v\n", reason)
return reason return reason
} }
@ -198,9 +202,9 @@ func (p *Peer) run() DiscReason {
func (p *Peer) politeDisconnect(reason DiscReason) { func (p *Peer) politeDisconnect(reason DiscReason) {
done := make(chan struct{}) done := make(chan struct{})
go func() { go func() {
// send reason
EncodeMsg(p.rw, discMsg, uint(reason)) EncodeMsg(p.rw, discMsg, uint(reason))
// discard any data that might arrive // Wait for the other side to close the connection.
// Discard any data that they send until then.
io.Copy(ioutil.Discard, p.rw) io.Copy(ioutil.Discard, p.rw)
close(done) close(done)
}() }()
@ -208,10 +212,11 @@ func (p *Peer) politeDisconnect(reason DiscReason) {
case <-done: case <-done:
case <-time.After(disconnectGracePeriod): case <-time.After(disconnectGracePeriod):
} }
p.rw.Close()
} }
func (p *Peer) readLoop() error { func (p *Peer) readLoop() error {
if p.protocolHandshakeEnabled { if !p.noHandshake {
if err := readProtocolHandshake(p, p.rw); err != nil { if err := readProtocolHandshake(p, p.rw); err != nil {
return err return err
} }
@ -264,7 +269,7 @@ func readProtocolHandshake(p *Peer, rw MsgReadWriter) error {
return newPeerError(errProtocolBreach, "expected handshake, got %x", msg.Code) return newPeerError(errProtocolBreach, "expected handshake, got %x", msg.Code)
} }
if msg.Size > baseProtocolMaxMsgSize { if msg.Size > baseProtocolMaxMsgSize {
return newPeerError(errMisc, "message too big") return newPeerError(errInvalidMsg, "message too big")
} }
var hs handshake var hs handshake
if err := msg.Decode(&hs); err != nil { if err := msg.Decode(&hs); err != nil {
@ -326,7 +331,7 @@ func (p *Peer) startProto(offset uint64, impl Protocol) *proto {
err := impl.Run(p, rw) err := impl.Run(p, rw)
if err == nil { if err == nil {
p.DebugDetailf("Protocol %s/%d returned\n", impl.Name, impl.Version) p.DebugDetailf("Protocol %s/%d returned\n", impl.Name, impl.Version)
err = newPeerError(errMisc, "protocol returned") err = errors.New("protocol returned")
} else { } else {
p.DebugDetailf("Protocol %s/%d error: %v\n", impl.Name, impl.Version, err) p.DebugDetailf("Protocol %s/%d error: %v\n", impl.Name, impl.Version, err)
} }

@ -123,7 +123,7 @@ func discReasonForError(err error) DiscReason {
return DiscProtocolError return DiscProtocolError
case errPingTimeout: case errPingTimeout:
return DiscReadTimeout return DiscReadTimeout
case errRead, errWrite, errMisc: case errRead, errWrite:
return DiscNetworkError return DiscNetworkError
default: default:
return DiscSubprotocolError return DiscSubprotocolError

@ -30,10 +30,10 @@ var discard = Protocol{
}, },
} }
func testPeer(handshake bool, protos []Protocol) (*frameRW, *Peer, <-chan DiscReason) { func testPeer(noHandshake bool, protos []Protocol) (*frameRW, *Peer, <-chan DiscReason) {
conn1, conn2 := net.Pipe() conn1, conn2 := net.Pipe()
peer := newPeer(conn1, protos, "name", &discover.NodeID{}, &discover.NodeID{}) peer := newPeer(conn1, protos, "name", &discover.NodeID{}, &discover.NodeID{})
peer.protocolHandshakeEnabled = handshake peer.noHandshake = noHandshake
errc := make(chan DiscReason, 1) errc := make(chan DiscReason, 1)
go func() { errc <- peer.run() }() go func() { errc <- peer.run() }()
return newFrameRW(conn2, msgWriteTimeout), peer, errc return newFrameRW(conn2, msgWriteTimeout), peer, errc
@ -61,7 +61,7 @@ func TestPeerProtoReadMsg(t *testing.T) {
}, },
} }
rw, peer, errc := testPeer(false, []Protocol{proto}) rw, peer, errc := testPeer(true, []Protocol{proto})
defer rw.Close() defer rw.Close()
peer.startSubprotocols([]Cap{proto.cap()}) peer.startSubprotocols([]Cap{proto.cap()})
@ -100,7 +100,7 @@ func TestPeerProtoReadLargeMsg(t *testing.T) {
}, },
} }
rw, peer, errc := testPeer(false, []Protocol{proto}) rw, peer, errc := testPeer(true, []Protocol{proto})
defer rw.Close() defer rw.Close()
peer.startSubprotocols([]Cap{proto.cap()}) peer.startSubprotocols([]Cap{proto.cap()})
@ -130,7 +130,7 @@ func TestPeerProtoEncodeMsg(t *testing.T) {
return nil return nil
}, },
} }
rw, peer, _ := testPeer(false, []Protocol{proto}) rw, peer, _ := testPeer(true, []Protocol{proto})
defer rw.Close() defer rw.Close()
peer.startSubprotocols([]Cap{proto.cap()}) peer.startSubprotocols([]Cap{proto.cap()})
@ -142,7 +142,7 @@ func TestPeerProtoEncodeMsg(t *testing.T) {
func TestPeerWriteForBroadcast(t *testing.T) { func TestPeerWriteForBroadcast(t *testing.T) {
defer testlog(t).detach() defer testlog(t).detach()
rw, peer, peerErr := testPeer(false, []Protocol{discard}) rw, peer, peerErr := testPeer(true, []Protocol{discard})
defer rw.Close() defer rw.Close()
peer.startSubprotocols([]Cap{discard.cap()}) peer.startSubprotocols([]Cap{discard.cap()})
@ -179,7 +179,7 @@ func TestPeerWriteForBroadcast(t *testing.T) {
func TestPeerPing(t *testing.T) { func TestPeerPing(t *testing.T) {
defer testlog(t).detach() defer testlog(t).detach()
rw, _, _ := testPeer(false, nil) rw, _, _ := testPeer(true, nil)
defer rw.Close() defer rw.Close()
if err := EncodeMsg(rw, pingMsg); err != nil { if err := EncodeMsg(rw, pingMsg); err != nil {
t.Fatal(err) t.Fatal(err)
@ -192,7 +192,7 @@ func TestPeerPing(t *testing.T) {
func TestPeerDisconnect(t *testing.T) { func TestPeerDisconnect(t *testing.T) {
defer testlog(t).detach() defer testlog(t).detach()
rw, _, disc := testPeer(false, nil) rw, _, disc := testPeer(true, nil)
defer rw.Close() defer rw.Close()
if err := EncodeMsg(rw, discMsg, DiscQuitting); err != nil { if err := EncodeMsg(rw, discMsg, DiscQuitting); err != nil {
t.Fatal(err) t.Fatal(err)
@ -233,7 +233,7 @@ func TestPeerHandshake(t *testing.T) {
{Name: "c", Version: 3, Length: 1, Run: run}, {Name: "c", Version: 3, Length: 1, Run: run},
{Name: "d", Version: 4, Length: 1, Run: run}, {Name: "d", Version: 4, Length: 1, Run: run},
} }
rw, p, disc := testPeer(true, protocols) rw, p, disc := testPeer(false, protocols)
p.remoteID = remote.ourID p.remoteID = remote.ourID
defer rw.Close() defer rw.Close()
@ -269,6 +269,7 @@ func TestPeerHandshake(t *testing.T) {
} }
close(stop) close(stop)
expectMsg(rw, discMsg, nil)
t.Logf("disc reason: %v", <-disc) t.Logf("disc reason: %v", <-disc)
} }

@ -408,7 +408,9 @@ func (srv *Server) startPeer(conn net.Conn, dest *discover.Node) {
return return
} }
srv.newPeerHook(p) if srv.newPeerHook != nil {
srv.newPeerHook(p)
}
p.run() p.run()
srv.removePeer(p) srv.removePeer(p)
} }

@ -118,6 +118,7 @@ func TestServerBroadcast(t *testing.T) {
srv := startTestServer(t, func(p *Peer) { srv := startTestServer(t, func(p *Peer) {
p.protocols = []Protocol{discard} p.protocols = []Protocol{discard}
p.startSubprotocols([]Cap{discard.cap()}) p.startSubprotocols([]Cap{discard.cap()})
p.noHandshake = true
connected.Done() connected.Done()
}) })
defer srv.Stop() defer srv.Stop()

Loading…
Cancel
Save