|
|
|
@ -21,6 +21,7 @@ const ( |
|
|
|
|
baseProtocolMaxMsgSize = 10 * 1024 * 1024 |
|
|
|
|
|
|
|
|
|
disconnectGracePeriod = 2 * time.Second |
|
|
|
|
pingInterval = 15 * time.Second |
|
|
|
|
) |
|
|
|
|
|
|
|
|
|
const ( |
|
|
|
@ -33,37 +34,14 @@ const ( |
|
|
|
|
peersMsg = 0x05 |
|
|
|
|
) |
|
|
|
|
|
|
|
|
|
// handshake is the RLP structure of the protocol handshake.
|
|
|
|
|
type handshake struct { |
|
|
|
|
Version uint64 |
|
|
|
|
Name string |
|
|
|
|
Caps []Cap |
|
|
|
|
ListenPort uint64 |
|
|
|
|
NodeID discover.NodeID |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
// Peer represents a connected remote node.
|
|
|
|
|
type Peer struct { |
|
|
|
|
// Peers have all the log methods.
|
|
|
|
|
// Use them to display messages related to the peer.
|
|
|
|
|
*logger.Logger |
|
|
|
|
|
|
|
|
|
infoMu sync.Mutex |
|
|
|
|
name string |
|
|
|
|
caps []Cap |
|
|
|
|
|
|
|
|
|
ourID, remoteID *discover.NodeID |
|
|
|
|
ourName string |
|
|
|
|
|
|
|
|
|
rw *frameRW |
|
|
|
|
|
|
|
|
|
// These fields maintain the running protocols.
|
|
|
|
|
protocols []Protocol |
|
|
|
|
runlock sync.RWMutex // protects running
|
|
|
|
|
running map[string]*proto |
|
|
|
|
|
|
|
|
|
// disables protocol handshake, for testing
|
|
|
|
|
noHandshake bool |
|
|
|
|
rw *conn |
|
|
|
|
running map[string]*protoRW |
|
|
|
|
|
|
|
|
|
protoWG sync.WaitGroup |
|
|
|
|
protoErr chan error |
|
|
|
@ -73,36 +51,27 @@ type Peer struct { |
|
|
|
|
|
|
|
|
|
// NewPeer returns a peer for testing purposes.
|
|
|
|
|
func NewPeer(id discover.NodeID, name string, caps []Cap) *Peer { |
|
|
|
|
conn, _ := net.Pipe() |
|
|
|
|
peer := newPeer(conn, nil, "", nil, &id) |
|
|
|
|
peer.setHandshakeInfo(name, caps) |
|
|
|
|
pipe, _ := net.Pipe() |
|
|
|
|
conn := newConn(pipe, &protoHandshake{ID: id, Name: name, Caps: caps}) |
|
|
|
|
peer := newPeer(conn, nil) |
|
|
|
|
close(peer.closed) // ensures Disconnect doesn't block
|
|
|
|
|
return peer |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
// ID returns the node's public key.
|
|
|
|
|
func (p *Peer) ID() discover.NodeID { |
|
|
|
|
return *p.remoteID |
|
|
|
|
return p.rw.ID |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
// Name returns the node name that the remote node advertised.
|
|
|
|
|
func (p *Peer) Name() string { |
|
|
|
|
// this needs a lock because the information is part of the
|
|
|
|
|
// protocol handshake.
|
|
|
|
|
p.infoMu.Lock() |
|
|
|
|
name := p.name |
|
|
|
|
p.infoMu.Unlock() |
|
|
|
|
return name |
|
|
|
|
return p.rw.Name |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
// Caps returns the capabilities (supported subprotocols) of the remote peer.
|
|
|
|
|
func (p *Peer) Caps() []Cap { |
|
|
|
|
// this needs a lock because the information is part of the
|
|
|
|
|
// protocol handshake.
|
|
|
|
|
p.infoMu.Lock() |
|
|
|
|
caps := p.caps |
|
|
|
|
p.infoMu.Unlock() |
|
|
|
|
return caps |
|
|
|
|
// TODO: maybe return copy
|
|
|
|
|
return p.rw.Caps |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
// RemoteAddr returns the remote address of the network connection.
|
|
|
|
@ -126,30 +95,20 @@ func (p *Peer) Disconnect(reason DiscReason) { |
|
|
|
|
|
|
|
|
|
// String implements fmt.Stringer.
|
|
|
|
|
func (p *Peer) String() string { |
|
|
|
|
return fmt.Sprintf("Peer %.8x %v", p.remoteID[:], p.RemoteAddr()) |
|
|
|
|
return fmt.Sprintf("Peer %.8x %v", p.rw.ID[:], p.RemoteAddr()) |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
func newPeer(conn net.Conn, protocols []Protocol, ourName string, ourID, remoteID *discover.NodeID) *Peer { |
|
|
|
|
logtag := fmt.Sprintf("Peer %.8x %v", remoteID[:], conn.RemoteAddr()) |
|
|
|
|
return &Peer{ |
|
|
|
|
func newPeer(conn *conn, protocols []Protocol) *Peer { |
|
|
|
|
logtag := fmt.Sprintf("Peer %.8x %v", conn.ID[:], conn.RemoteAddr()) |
|
|
|
|
p := &Peer{ |
|
|
|
|
Logger: logger.NewLogger(logtag), |
|
|
|
|
rw: newFrameRW(conn, msgWriteTimeout), |
|
|
|
|
ourID: ourID, |
|
|
|
|
ourName: ourName, |
|
|
|
|
remoteID: remoteID, |
|
|
|
|
protocols: protocols, |
|
|
|
|
running: make(map[string]*proto), |
|
|
|
|
rw: conn, |
|
|
|
|
running: matchProtocols(protocols, conn.Caps, conn), |
|
|
|
|
disc: make(chan DiscReason), |
|
|
|
|
protoErr: make(chan error), |
|
|
|
|
closed: make(chan struct{}), |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
func (p *Peer) setHandshakeInfo(name string, caps []Cap) { |
|
|
|
|
p.infoMu.Lock() |
|
|
|
|
p.name = name |
|
|
|
|
p.caps = caps |
|
|
|
|
p.infoMu.Unlock() |
|
|
|
|
return p |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
func (p *Peer) run() DiscReason { |
|
|
|
@ -157,29 +116,36 @@ func (p *Peer) run() DiscReason { |
|
|
|
|
defer p.closeProtocols() |
|
|
|
|
defer close(p.closed) |
|
|
|
|
|
|
|
|
|
p.startProtocols() |
|
|
|
|
go func() { readErr <- p.readLoop() }() |
|
|
|
|
|
|
|
|
|
if !p.noHandshake { |
|
|
|
|
if err := writeProtocolHandshake(p.rw, p.ourName, *p.ourID, p.protocols); err != nil { |
|
|
|
|
p.DebugDetailf("Protocol handshake error: %v\n", err) |
|
|
|
|
p.rw.Close() |
|
|
|
|
return DiscProtocolError |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
ping := time.NewTicker(pingInterval) |
|
|
|
|
defer ping.Stop() |
|
|
|
|
|
|
|
|
|
// Wait for an error or disconnect.
|
|
|
|
|
var reason DiscReason |
|
|
|
|
loop: |
|
|
|
|
for { |
|
|
|
|
select { |
|
|
|
|
case <-ping.C: |
|
|
|
|
go func() { |
|
|
|
|
if err := EncodeMsg(p.rw, pingMsg, nil); err != nil { |
|
|
|
|
p.protoErr <- err |
|
|
|
|
return |
|
|
|
|
} |
|
|
|
|
}() |
|
|
|
|
case err := <-readErr: |
|
|
|
|
// We rely on protocols to abort if there is a write error. It
|
|
|
|
|
// might be more robust to handle them here as well.
|
|
|
|
|
p.DebugDetailf("Read error: %v\n", err) |
|
|
|
|
p.rw.Close() |
|
|
|
|
return DiscNetworkError |
|
|
|
|
|
|
|
|
|
case err := <-p.protoErr: |
|
|
|
|
reason = discReasonForError(err) |
|
|
|
|
break loop |
|
|
|
|
case reason = <-p.disc: |
|
|
|
|
break loop |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
p.politeDisconnect(reason) |
|
|
|
|
|
|
|
|
@ -206,11 +172,6 @@ func (p *Peer) politeDisconnect(reason DiscReason) { |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
func (p *Peer) readLoop() error { |
|
|
|
|
if !p.noHandshake { |
|
|
|
|
if err := readProtocolHandshake(p, p.rw); err != nil { |
|
|
|
|
return err |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
for { |
|
|
|
|
msg, err := p.rw.ReadMsg() |
|
|
|
|
if err != nil { |
|
|
|
@ -249,88 +210,36 @@ func (p *Peer) handle(msg Msg) error { |
|
|
|
|
return nil |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
func readProtocolHandshake(p *Peer, rw MsgReadWriter) error { |
|
|
|
|
// read and handle remote handshake
|
|
|
|
|
msg, err := rw.ReadMsg() |
|
|
|
|
if err != nil { |
|
|
|
|
return err |
|
|
|
|
} |
|
|
|
|
if msg.Code == discMsg { |
|
|
|
|
// disconnect before protocol handshake is valid according to the
|
|
|
|
|
// spec and we send it ourself if Server.addPeer fails.
|
|
|
|
|
var reason DiscReason |
|
|
|
|
rlp.Decode(msg.Payload, &reason) |
|
|
|
|
return discRequestedError(reason) |
|
|
|
|
} |
|
|
|
|
if msg.Code != handshakeMsg { |
|
|
|
|
return newPeerError(errProtocolBreach, "expected handshake, got %x", msg.Code) |
|
|
|
|
} |
|
|
|
|
if msg.Size > baseProtocolMaxMsgSize { |
|
|
|
|
return newPeerError(errInvalidMsg, "message too big") |
|
|
|
|
} |
|
|
|
|
var hs handshake |
|
|
|
|
if err := msg.Decode(&hs); err != nil { |
|
|
|
|
return err |
|
|
|
|
} |
|
|
|
|
// validate handshake info
|
|
|
|
|
if hs.Version != baseProtocolVersion { |
|
|
|
|
return newPeerError(errP2PVersionMismatch, "required version %d, received %d\n", |
|
|
|
|
baseProtocolVersion, hs.Version) |
|
|
|
|
} |
|
|
|
|
if hs.NodeID == *p.remoteID { |
|
|
|
|
return newPeerError(errPubkeyForbidden, "node ID mismatch") |
|
|
|
|
} |
|
|
|
|
// TODO: remove Caps with empty name
|
|
|
|
|
p.setHandshakeInfo(hs.Name, hs.Caps) |
|
|
|
|
p.startSubprotocols(hs.Caps) |
|
|
|
|
return nil |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
func writeProtocolHandshake(w MsgWriter, name string, id discover.NodeID, ps []Protocol) error { |
|
|
|
|
var caps []interface{} |
|
|
|
|
for _, proto := range ps { |
|
|
|
|
caps = append(caps, proto.cap()) |
|
|
|
|
} |
|
|
|
|
return EncodeMsg(w, handshakeMsg, baseProtocolVersion, name, caps, 0, id) |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
// startProtocols starts matching named subprotocols.
|
|
|
|
|
func (p *Peer) startSubprotocols(caps []Cap) { |
|
|
|
|
// matchProtocols creates structures for matching named subprotocols.
|
|
|
|
|
func matchProtocols(protocols []Protocol, caps []Cap, rw MsgReadWriter) map[string]*protoRW { |
|
|
|
|
sort.Sort(capsByName(caps)) |
|
|
|
|
p.runlock.Lock() |
|
|
|
|
defer p.runlock.Unlock() |
|
|
|
|
offset := baseProtocolLength |
|
|
|
|
result := make(map[string]*protoRW) |
|
|
|
|
outer: |
|
|
|
|
for _, cap := range caps { |
|
|
|
|
for _, proto := range p.protocols { |
|
|
|
|
if proto.Name == cap.Name && |
|
|
|
|
proto.Version == cap.Version && |
|
|
|
|
p.running[cap.Name] == nil { |
|
|
|
|
p.running[cap.Name] = p.startProto(offset, proto) |
|
|
|
|
for _, proto := range protocols { |
|
|
|
|
if proto.Name == cap.Name && proto.Version == cap.Version && result[cap.Name] == nil { |
|
|
|
|
result[cap.Name] = &protoRW{Protocol: proto, offset: offset, in: make(chan Msg), w: rw} |
|
|
|
|
offset += proto.Length |
|
|
|
|
continue outer |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
return result |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
func (p *Peer) startProto(offset uint64, impl Protocol) *proto { |
|
|
|
|
p.DebugDetailf("Starting protocol %s/%d\n", impl.Name, impl.Version) |
|
|
|
|
rw := &proto{ |
|
|
|
|
name: impl.Name, |
|
|
|
|
in: make(chan Msg), |
|
|
|
|
offset: offset, |
|
|
|
|
maxcode: impl.Length, |
|
|
|
|
w: p.rw, |
|
|
|
|
} |
|
|
|
|
func (p *Peer) startProtocols() { |
|
|
|
|
for _, proto := range p.running { |
|
|
|
|
proto := proto |
|
|
|
|
p.DebugDetailf("Starting protocol %s/%d\n", proto.Name, proto.Version) |
|
|
|
|
p.protoWG.Add(1) |
|
|
|
|
go func() { |
|
|
|
|
err := impl.Run(p, rw) |
|
|
|
|
err := proto.Run(p, proto) |
|
|
|
|
if err == nil { |
|
|
|
|
p.DebugDetailf("Protocol %s/%d returned\n", impl.Name, impl.Version) |
|
|
|
|
p.DebugDetailf("Protocol %s/%d returned\n", proto.Name, proto.Version) |
|
|
|
|
err = errors.New("protocol returned") |
|
|
|
|
} else { |
|
|
|
|
p.DebugDetailf("Protocol %s/%d error: %v\n", impl.Name, impl.Version, err) |
|
|
|
|
p.DebugDetailf("Protocol %s/%d error: %v\n", proto.Name, proto.Version, err) |
|
|
|
|
} |
|
|
|
|
select { |
|
|
|
|
case p.protoErr <- err: |
|
|
|
@ -338,16 +247,14 @@ func (p *Peer) startProto(offset uint64, impl Protocol) *proto { |
|
|
|
|
} |
|
|
|
|
p.protoWG.Done() |
|
|
|
|
}() |
|
|
|
|
return rw |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
// getProto finds the protocol responsible for handling
|
|
|
|
|
// the given message code.
|
|
|
|
|
func (p *Peer) getProto(code uint64) (*proto, error) { |
|
|
|
|
p.runlock.RLock() |
|
|
|
|
defer p.runlock.RUnlock() |
|
|
|
|
func (p *Peer) getProto(code uint64) (*protoRW, error) { |
|
|
|
|
for _, proto := range p.running { |
|
|
|
|
if code >= proto.offset && code < proto.offset+proto.maxcode { |
|
|
|
|
if code >= proto.offset && code < proto.offset+proto.Length { |
|
|
|
|
return proto, nil |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
@ -355,46 +262,43 @@ func (p *Peer) getProto(code uint64) (*proto, error) { |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
func (p *Peer) closeProtocols() { |
|
|
|
|
p.runlock.RLock() |
|
|
|
|
for _, p := range p.running { |
|
|
|
|
close(p.in) |
|
|
|
|
} |
|
|
|
|
p.runlock.RUnlock() |
|
|
|
|
p.protoWG.Wait() |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
// writeProtoMsg sends the given message on behalf of the given named protocol.
|
|
|
|
|
// this exists because of Server.Broadcast.
|
|
|
|
|
func (p *Peer) writeProtoMsg(protoName string, msg Msg) error { |
|
|
|
|
p.runlock.RLock() |
|
|
|
|
proto, ok := p.running[protoName] |
|
|
|
|
p.runlock.RUnlock() |
|
|
|
|
if !ok { |
|
|
|
|
return fmt.Errorf("protocol %s not handled by peer", protoName) |
|
|
|
|
} |
|
|
|
|
if msg.Code >= proto.maxcode { |
|
|
|
|
if msg.Code >= proto.Length { |
|
|
|
|
return newPeerError(errInvalidMsgCode, "code %x is out of range for protocol %q", msg.Code, protoName) |
|
|
|
|
} |
|
|
|
|
msg.Code += proto.offset |
|
|
|
|
return p.rw.WriteMsg(msg) |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
type proto struct { |
|
|
|
|
name string |
|
|
|
|
type protoRW struct { |
|
|
|
|
Protocol |
|
|
|
|
|
|
|
|
|
in chan Msg |
|
|
|
|
maxcode, offset uint64 |
|
|
|
|
offset uint64 |
|
|
|
|
w MsgWriter |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
func (rw *proto) WriteMsg(msg Msg) error { |
|
|
|
|
if msg.Code >= rw.maxcode { |
|
|
|
|
func (rw *protoRW) WriteMsg(msg Msg) error { |
|
|
|
|
if msg.Code >= rw.Length { |
|
|
|
|
return newPeerError(errInvalidMsgCode, "not handled") |
|
|
|
|
} |
|
|
|
|
msg.Code += rw.offset |
|
|
|
|
return rw.w.WriteMsg(msg) |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
func (rw *proto) ReadMsg() (Msg, error) { |
|
|
|
|
func (rw *protoRW) ReadMsg() (Msg, error) { |
|
|
|
|
msg, ok := <-rw.in |
|
|
|
|
if !ok { |
|
|
|
|
return msg, io.EOF |
|
|
|
|