|
|
@ -1,6 +1,7 @@ |
|
|
|
package p2p |
|
|
|
package p2p |
|
|
|
|
|
|
|
|
|
|
|
import ( |
|
|
|
import ( |
|
|
|
|
|
|
|
"errors" |
|
|
|
"fmt" |
|
|
|
"fmt" |
|
|
|
"io" |
|
|
|
"io" |
|
|
|
"io/ioutil" |
|
|
|
"io/ioutil" |
|
|
@ -71,7 +72,8 @@ type Peer struct { |
|
|
|
runlock sync.RWMutex // protects running
|
|
|
|
runlock sync.RWMutex // protects running
|
|
|
|
running map[string]*proto |
|
|
|
running map[string]*proto |
|
|
|
|
|
|
|
|
|
|
|
protocolHandshakeEnabled bool |
|
|
|
// disables protocol handshake, for testing
|
|
|
|
|
|
|
|
noHandshake bool |
|
|
|
|
|
|
|
|
|
|
|
protoWG sync.WaitGroup |
|
|
|
protoWG sync.WaitGroup |
|
|
|
protoErr chan error |
|
|
|
protoErr chan error |
|
|
@ -134,11 +136,11 @@ func (p *Peer) Disconnect(reason DiscReason) { |
|
|
|
|
|
|
|
|
|
|
|
// String implements fmt.Stringer.
|
|
|
|
// String implements fmt.Stringer.
|
|
|
|
func (p *Peer) String() string { |
|
|
|
func (p *Peer) String() string { |
|
|
|
return fmt.Sprintf("Peer %.8x %v", p.remoteID, p.RemoteAddr()) |
|
|
|
return fmt.Sprintf("Peer %.8x %v", p.remoteID[:], p.RemoteAddr()) |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
func newPeer(conn net.Conn, protocols []Protocol, ourName string, ourID, remoteID *discover.NodeID) *Peer { |
|
|
|
func newPeer(conn net.Conn, protocols []Protocol, ourName string, ourID, remoteID *discover.NodeID) *Peer { |
|
|
|
logtag := fmt.Sprintf("Peer %.8x %v", remoteID, conn.RemoteAddr()) |
|
|
|
logtag := fmt.Sprintf("Peer %.8x %v", remoteID[:], conn.RemoteAddr()) |
|
|
|
return &Peer{ |
|
|
|
return &Peer{ |
|
|
|
Logger: logger.NewLogger(logtag), |
|
|
|
Logger: logger.NewLogger(logtag), |
|
|
|
rw: newFrameRW(conn, msgWriteTimeout), |
|
|
|
rw: newFrameRW(conn, msgWriteTimeout), |
|
|
@ -164,33 +166,35 @@ func (p *Peer) run() DiscReason { |
|
|
|
var readErr = make(chan error, 1) |
|
|
|
var readErr = make(chan error, 1) |
|
|
|
defer p.closeProtocols() |
|
|
|
defer p.closeProtocols() |
|
|
|
defer close(p.closed) |
|
|
|
defer close(p.closed) |
|
|
|
defer p.rw.Close() |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// start the read loop
|
|
|
|
|
|
|
|
go func() { readErr <- p.readLoop() }() |
|
|
|
go func() { readErr <- p.readLoop() }() |
|
|
|
|
|
|
|
|
|
|
|
if p.protocolHandshakeEnabled { |
|
|
|
if !p.noHandshake { |
|
|
|
if err := writeProtocolHandshake(p.rw, p.ourName, *p.ourID, p.protocols); err != nil { |
|
|
|
if err := writeProtocolHandshake(p.rw, p.ourName, *p.ourID, p.protocols); err != nil { |
|
|
|
p.DebugDetailf("Protocol handshake error: %v\n", err) |
|
|
|
p.DebugDetailf("Protocol handshake error: %v\n", err) |
|
|
|
|
|
|
|
p.rw.Close() |
|
|
|
return DiscProtocolError |
|
|
|
return DiscProtocolError |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
// wait for an error or disconnect
|
|
|
|
// Wait for an error or disconnect.
|
|
|
|
var reason DiscReason |
|
|
|
var reason DiscReason |
|
|
|
select { |
|
|
|
select { |
|
|
|
case err := <-readErr: |
|
|
|
case err := <-readErr: |
|
|
|
// We rely on protocols to abort if there is a write error. It
|
|
|
|
// We rely on protocols to abort if there is a write error. It
|
|
|
|
// might be more robust to handle them here as well.
|
|
|
|
// might be more robust to handle them here as well.
|
|
|
|
p.DebugDetailf("Read error: %v\n", err) |
|
|
|
p.DebugDetailf("Read error: %v\n", err) |
|
|
|
reason = DiscNetworkError |
|
|
|
p.rw.Close() |
|
|
|
|
|
|
|
return DiscNetworkError |
|
|
|
|
|
|
|
|
|
|
|
case err := <-p.protoErr: |
|
|
|
case err := <-p.protoErr: |
|
|
|
reason = discReasonForError(err) |
|
|
|
reason = discReasonForError(err) |
|
|
|
case reason = <-p.disc: |
|
|
|
case reason = <-p.disc: |
|
|
|
} |
|
|
|
} |
|
|
|
if reason != DiscNetworkError { |
|
|
|
p.politeDisconnect(reason) |
|
|
|
p.politeDisconnect(reason) |
|
|
|
|
|
|
|
} |
|
|
|
// Wait for readLoop. It will end because conn is now closed.
|
|
|
|
|
|
|
|
<-readErr |
|
|
|
p.Debugf("Disconnected: %v\n", reason) |
|
|
|
p.Debugf("Disconnected: %v\n", reason) |
|
|
|
return reason |
|
|
|
return reason |
|
|
|
} |
|
|
|
} |
|
|
@ -198,9 +202,9 @@ func (p *Peer) run() DiscReason { |
|
|
|
func (p *Peer) politeDisconnect(reason DiscReason) { |
|
|
|
func (p *Peer) politeDisconnect(reason DiscReason) { |
|
|
|
done := make(chan struct{}) |
|
|
|
done := make(chan struct{}) |
|
|
|
go func() { |
|
|
|
go func() { |
|
|
|
// send reason
|
|
|
|
|
|
|
|
EncodeMsg(p.rw, discMsg, uint(reason)) |
|
|
|
EncodeMsg(p.rw, discMsg, uint(reason)) |
|
|
|
// discard any data that might arrive
|
|
|
|
// Wait for the other side to close the connection.
|
|
|
|
|
|
|
|
// Discard any data that they send until then.
|
|
|
|
io.Copy(ioutil.Discard, p.rw) |
|
|
|
io.Copy(ioutil.Discard, p.rw) |
|
|
|
close(done) |
|
|
|
close(done) |
|
|
|
}() |
|
|
|
}() |
|
|
@ -208,10 +212,11 @@ func (p *Peer) politeDisconnect(reason DiscReason) { |
|
|
|
case <-done: |
|
|
|
case <-done: |
|
|
|
case <-time.After(disconnectGracePeriod): |
|
|
|
case <-time.After(disconnectGracePeriod): |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
p.rw.Close() |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
func (p *Peer) readLoop() error { |
|
|
|
func (p *Peer) readLoop() error { |
|
|
|
if p.protocolHandshakeEnabled { |
|
|
|
if !p.noHandshake { |
|
|
|
if err := readProtocolHandshake(p, p.rw); err != nil { |
|
|
|
if err := readProtocolHandshake(p, p.rw); err != nil { |
|
|
|
return err |
|
|
|
return err |
|
|
|
} |
|
|
|
} |
|
|
@ -264,7 +269,7 @@ func readProtocolHandshake(p *Peer, rw MsgReadWriter) error { |
|
|
|
return newPeerError(errProtocolBreach, "expected handshake, got %x", msg.Code) |
|
|
|
return newPeerError(errProtocolBreach, "expected handshake, got %x", msg.Code) |
|
|
|
} |
|
|
|
} |
|
|
|
if msg.Size > baseProtocolMaxMsgSize { |
|
|
|
if msg.Size > baseProtocolMaxMsgSize { |
|
|
|
return newPeerError(errMisc, "message too big") |
|
|
|
return newPeerError(errInvalidMsg, "message too big") |
|
|
|
} |
|
|
|
} |
|
|
|
var hs handshake |
|
|
|
var hs handshake |
|
|
|
if err := msg.Decode(&hs); err != nil { |
|
|
|
if err := msg.Decode(&hs); err != nil { |
|
|
@ -326,7 +331,7 @@ func (p *Peer) startProto(offset uint64, impl Protocol) *proto { |
|
|
|
err := impl.Run(p, rw) |
|
|
|
err := impl.Run(p, rw) |
|
|
|
if err == nil { |
|
|
|
if err == nil { |
|
|
|
p.DebugDetailf("Protocol %s/%d returned\n", impl.Name, impl.Version) |
|
|
|
p.DebugDetailf("Protocol %s/%d returned\n", impl.Name, impl.Version) |
|
|
|
err = newPeerError(errMisc, "protocol returned") |
|
|
|
err = errors.New("protocol returned") |
|
|
|
} else { |
|
|
|
} else { |
|
|
|
p.DebugDetailf("Protocol %s/%d error: %v\n", impl.Name, impl.Version, err) |
|
|
|
p.DebugDetailf("Protocol %s/%d error: %v\n", impl.Name, impl.Version, err) |
|
|
|
} |
|
|
|
} |
|
|
|