@ -68,50 +68,61 @@ type protoHandshake struct {
// setupConn starts a protocol session on the given connection.
// setupConn starts a protocol session on the given connection.
// It runs the encryption handshake and the protocol handshake.
// It runs the encryption handshake and the protocol handshake.
// If dial is non-nil, the connection the local node is the initiator.
// If dial is non-nil, the connection the local node is the initiator.
func setupConn ( fd net . Conn , prv * ecdsa . PrivateKey , our * protoHandshake , dial * discover . Node ) ( * conn , error ) {
// If atcap is true, the connection will be disconnected with DiscTooManyPeers
// after the key exchange.
func setupConn ( fd net . Conn , prv * ecdsa . PrivateKey , our * protoHandshake , dial * discover . Node , atcap bool ) ( * conn , error ) {
if dial == nil {
if dial == nil {
return setupInboundConn ( fd , prv , our )
return setupInboundConn ( fd , prv , our , atcap )
} else {
} else {
return setupOutboundConn ( fd , prv , our , dial )
return setupOutboundConn ( fd , prv , our , dial , atcap )
}
}
}
}
func setupInboundConn ( fd net . Conn , prv * ecdsa . PrivateKey , our * protoHandshake ) ( * conn , error ) {
func setupInboundConn ( fd net . Conn , prv * ecdsa . PrivateKey , our * protoHandshake , atcap bool ) ( * conn , error ) {
secrets , err := receiverEncHandshake ( fd , prv , nil )
secrets , err := receiverEncHandshake ( fd , prv , nil )
if err != nil {
if err != nil {
return nil , fmt . Errorf ( "encryption handshake failed: %v" , err )
return nil , fmt . Errorf ( "encryption handshake failed: %v" , err )
}
}
// Run the protocol handshake using authenticated messages.
rw := newRlpxFrameRW ( fd , secrets )
rw := newRlpxFrameRW ( fd , secrets )
rhs , err := readProtocolHandshake ( rw , our )
if atcap {
SendItems ( rw , discMsg , DiscTooManyPeers )
return nil , errors . New ( "we have too many peers" )
}
// Run the protocol handshake using authenticated messages.
rhs , err := readProtocolHandshake ( rw , secrets . RemoteID , our )
if err != nil {
if err != nil {
return nil , err
return nil , err
}
}
if rhs . ID != secrets . RemoteID {
return nil , errors . New ( "node ID in protocol handshake does not match encryption handshake" )
}
// TODO: validate that handshake node ID matches
if err := Send ( rw , handshakeMsg , our ) ; err != nil {
if err := Send ( rw , handshakeMsg , our ) ; err != nil {
return nil , fmt . Errorf ( "protocol write error: %v" , err )
return nil , fmt . Errorf ( "protocol handshake write error: %v" , err )
}
}
return & conn { rw , rhs } , nil
return & conn { rw , rhs } , nil
}
}
func setupOutboundConn ( fd net . Conn , prv * ecdsa . PrivateKey , our * protoHandshake , dial * discover . Node ) ( * conn , error ) {
func setupOutboundConn ( fd net . Conn , prv * ecdsa . PrivateKey , our * protoHandshake , dial * discover . Node , atcap bool ) ( * conn , error ) {
secrets , err := initiatorEncHandshake ( fd , prv , dial . ID , nil )
secrets , err := initiatorEncHandshake ( fd , prv , dial . ID , nil )
if err != nil {
if err != nil {
return nil , fmt . Errorf ( "encryption handshake failed: %v" , err )
return nil , fmt . Errorf ( "encryption handshake failed: %v" , err )
}
}
// Run the protocol handshake using authenticated messages.
rw := newRlpxFrameRW ( fd , secrets )
rw := newRlpxFrameRW ( fd , secrets )
if err := Send ( rw , handshakeMsg , our ) ; err != nil {
if atcap {
return nil , fmt . Errorf ( "protocol write error: %v" , err )
SendItems ( rw , discMsg , DiscTooManyPeers )
return nil , errors . New ( "we have too many peers" )
}
}
rhs , err := readProtocolHandshake ( rw , our )
// Run the protocol handshake using authenticated messages.
//
// Note that even though writing the handshake is first, we prefer
// returning the handshake read error. If the remote side
// disconnects us early with a valid reason, we should return it
// as the error so it can be tracked elsewhere.
werr := make ( chan error )
go func ( ) { werr <- Send ( rw , handshakeMsg , our ) } ( )
rhs , err := readProtocolHandshake ( rw , secrets . RemoteID , our )
if err != nil {
if err != nil {
return nil , fmt . Errorf ( "protocol handshake read error: %v" , err )
return nil , err
}
if err := <- werr ; err != nil {
return nil , fmt . Errorf ( "protocol handshake write error: %v" , err )
}
}
if rhs . ID != dial . ID {
if rhs . ID != dial . ID {
return nil , errors . New ( "dialed node id mismatch" )
return nil , errors . New ( "dialed node id mismatch" )
@ -398,18 +409,17 @@ func xor(one, other []byte) (xor []byte) {
return xor
return xor
}
}
func readProtocolHandshake ( r MsgReader , our * protoHandshake ) ( * protoHandshake , error ) {
func readProtocolHandshake ( rw MsgReadWriter , wantID discover . NodeID , our * protoHandshake ) ( * protoHandshake , error ) {
// read and handle remote handshake
msg , err := rw . ReadMsg ( )
msg , err := r . ReadMsg ( )
if err != nil {
if err != nil {
return nil , err
return nil , err
}
}
if msg . Code == discMsg {
if msg . Code == discMsg {
// disconnect before protocol handshake is valid according to the
// disconnect before protocol handshake is valid according to the
// spec and we send it ourself if Server.addPeer fails.
// spec and we send it ourself if Server.addPeer fails.
var reason DiscReason
var reason [ 1 ] DiscReason
rlp . Decode ( msg . Payload , & reason )
rlp . Decode ( msg . Payload , & reason )
return nil , reason
return nil , reason [ 0 ]
}
}
if msg . Code != handshakeMsg {
if msg . Code != handshakeMsg {
return nil , fmt . Errorf ( "expected handshake, got %x" , msg . Code )
return nil , fmt . Errorf ( "expected handshake, got %x" , msg . Code )
@ -423,10 +433,16 @@ func readProtocolHandshake(r MsgReader, our *protoHandshake) (*protoHandshake, e
}
}
// validate handshake info
// validate handshake info
if hs . Version != our . Version {
if hs . Version != our . Version {
return nil , newPeerError ( errP2PVersionMismatch , "required version %d, received %d\n" , baseProtocolVersion , hs . Version )
SendItems ( rw , discMsg , DiscIncompatibleVersion )
return nil , fmt . Errorf ( "required version %d, received %d\n" , baseProtocolVersion , hs . Version )
}
}
if ( hs . ID == discover . NodeID { } ) {
if ( hs . ID == discover . NodeID { } ) {
return nil , newPeerError ( errPubkeyInvalid , "missing" )
SendItems ( rw , discMsg , DiscInvalidIdentity )
return nil , errors . New ( "invalid public key in handshake" )
}
if hs . ID != wantID {
SendItems ( rw , discMsg , DiscUnexpectedIdentity )
return nil , errors . New ( "handshake node ID does not match encryption handshake" )
}
}
return & hs , nil
return & hs , nil
}
}