@ -24,6 +24,8 @@ var (
errBadHash = errors . New ( "bad hash" )
errBadHash = errors . New ( "bad hash" )
errExpired = errors . New ( "expired" )
errExpired = errors . New ( "expired" )
errBadVersion = errors . New ( "version mismatch" )
errBadVersion = errors . New ( "version mismatch" )
errUnsolicitedReply = errors . New ( "unsolicited reply" )
errUnknownNode = errors . New ( "unknown node" )
errTimeout = errors . New ( "RPC timeout" )
errTimeout = errors . New ( "RPC timeout" )
errClosed = errors . New ( "socket closed" )
errClosed = errors . New ( "socket closed" )
)
)
@ -80,12 +82,25 @@ type rpcNode struct {
ID NodeID
ID NodeID
}
}
type packet interface {
handle ( t * udp , from * net . UDPAddr , fromID NodeID , mac [ ] byte ) error
}
type conn interface {
ReadFromUDP ( b [ ] byte ) ( n int , addr * net . UDPAddr , err error )
WriteToUDP ( b [ ] byte , addr * net . UDPAddr ) ( n int , err error )
Close ( ) error
LocalAddr ( ) net . Addr
}
// udp implements the RPC protocol.
// udp implements the RPC protocol.
type udp struct {
type udp struct {
conn * net . UDPConn
conn c onn
priv * ecdsa . PrivateKey
priv * ecdsa . PrivateKey
addpending chan * pending
addpending chan * pending
replies chan reply
gotreply chan reply
closing chan struct { }
closing chan struct { }
nat nat . Interface
nat nat . Interface
@ -124,6 +139,9 @@ type reply struct {
from NodeID
from NodeID
ptype byte
ptype byte
data interface { }
data interface { }
// loop indicates whether there was
// a matching request by sending on this channel.
matched chan <- bool
}
}
// ListenUDP returns a new table that listens for UDP packets on laddr.
// ListenUDP returns a new table that listens for UDP packets on laddr.
@ -136,15 +154,20 @@ func ListenUDP(priv *ecdsa.PrivateKey, laddr string, natm nat.Interface) (*Table
if err != nil {
if err != nil {
return nil , err
return nil , err
}
}
tab , _ := newUDP ( priv , conn , natm )
log . Infoln ( "Listening," , tab . self )
return tab , nil
}
func newUDP ( priv * ecdsa . PrivateKey , c conn , natm nat . Interface ) ( * Table , * udp ) {
udp := & udp {
udp := & udp {
conn : conn ,
conn : c ,
priv : priv ,
priv : priv ,
closing : make ( chan struct { } ) ,
closing : make ( chan struct { } ) ,
gotreply : make ( chan reply ) ,
addpending : make ( chan * pending ) ,
addpending : make ( chan * pending ) ,
replies : make ( chan reply ) ,
}
}
realaddr := c . LocalAddr ( ) . ( * net . UDPAddr )
realaddr := conn . LocalAddr ( ) . ( * net . UDPAddr )
if natm != nil {
if natm != nil {
if ! realaddr . IP . IsLoopback ( ) {
if ! realaddr . IP . IsLoopback ( ) {
go nat . Map ( natm , udp . closing , "udp" , realaddr . Port , realaddr . Port , "ethereum discovery" )
go nat . Map ( natm , udp . closing , "udp" , realaddr . Port , realaddr . Port , "ethereum discovery" )
@ -155,11 +178,9 @@ func ListenUDP(priv *ecdsa.PrivateKey, laddr string, natm nat.Interface) (*Table
}
}
}
}
udp . Table = newTable ( udp , PubkeyID ( & priv . PublicKey ) , realaddr )
udp . Table = newTable ( udp , PubkeyID ( & priv . PublicKey ) , realaddr )
go udp . loop ( )
go udp . loop ( )
go udp . readLoop ( )
go udp . readLoop ( )
log . Infoln ( "Listening, " , udp . self )
return udp . Table , udp
return udp . Table , nil
}
}
func ( t * udp ) close ( ) {
func ( t * udp ) close ( ) {
@ -169,10 +190,10 @@ func (t *udp) close() {
}
}
// ping sends a ping message to the given node and waits for a reply.
// ping sends a ping message to the given node and waits for a reply.
func ( t * udp ) ping ( e * Node ) error {
func ( t * udp ) ping ( toid NodeID , toaddr * net . UDPAddr ) error {
// TODO: maybe check for ReplyTo field in callback to measure RTT
// TODO: maybe check for ReplyTo field in callback to measure RTT
errc := t . pending ( e . ID , pongPacket , func ( interface { } ) bool { return true } )
errc := t . pending ( toid , pongPacket , func ( interface { } ) bool { return true } )
t . send ( e , pingPacket , ping {
t . send ( toaddr , pingPacket , ping {
Version : Version ,
Version : Version ,
IP : t . self . IP . String ( ) ,
IP : t . self . IP . String ( ) ,
Port : uint16 ( t . self . TCPPort ) ,
Port : uint16 ( t . self . TCPPort ) ,
@ -181,12 +202,16 @@ func (t *udp) ping(e *Node) error {
return <- errc
return <- errc
}
}
func ( t * udp ) waitping ( from NodeID ) error {
return <- t . pending ( from , pingPacket , func ( interface { } ) bool { return true } )
}
// findnode sends a findnode request to the given node and waits until
// findnode sends a findnode request to the given node and waits until
// the node has sent up to k neighbors.
// the node has sent up to k neighbors.
func ( t * udp ) findnode ( to * Node , target NodeID ) ( [ ] * Node , error ) {
func ( t * udp ) findnode ( toid NodeID , toaddr * net . UDPAddr , target NodeID ) ( [ ] * Node , error ) {
nodes := make ( [ ] * Node , 0 , bucketSize )
nodes := make ( [ ] * Node , 0 , bucketSize )
nreceived := 0
nreceived := 0
errc := t . pending ( to . ID , neighborsPacket , func ( r interface { } ) bool {
errc := t . pending ( toid , neighborsPacket , func ( r interface { } ) bool {
reply := r . ( * neighbors )
reply := r . ( * neighbors )
for _ , n := range reply . Nodes {
for _ , n := range reply . Nodes {
nreceived ++
nreceived ++
@ -196,8 +221,7 @@ func (t *udp) findnode(to *Node, target NodeID) ([]*Node, error) {
}
}
return nreceived >= bucketSize
return nreceived >= bucketSize
} )
} )
t . send ( toaddr , findnodePacket , findnode {
t . send ( to , findnodePacket , findnode {
Target : target ,
Target : target ,
Expiration : uint64 ( time . Now ( ) . Add ( expiration ) . Unix ( ) ) ,
Expiration : uint64 ( time . Now ( ) . Add ( expiration ) . Unix ( ) ) ,
} )
} )
@ -219,6 +243,17 @@ func (t *udp) pending(id NodeID, ptype byte, callback func(interface{}) bool) <-
return ch
return ch
}
}
func ( t * udp ) handleReply ( from NodeID , ptype byte , req packet ) bool {
matched := make ( chan bool )
select {
case t . gotreply <- reply { from , ptype , req , matched } :
// loop will handle it
return <- matched
case <- t . closing :
return false
}
}
// loop runs in its own goroutin. it keeps track of
// loop runs in its own goroutin. it keeps track of
// the refresh timer and the pending reply queue.
// the refresh timer and the pending reply queue.
func ( t * udp ) loop ( ) {
func ( t * udp ) loop ( ) {
@ -249,6 +284,7 @@ func (t *udp) loop() {
for _ , p := range pending {
for _ , p := range pending {
p . errc <- errClosed
p . errc <- errClosed
}
}
pending = nil
return
return
case p := <- t . addpending :
case p := <- t . addpending :
@ -256,18 +292,21 @@ func (t *udp) loop() {
pending = append ( pending , p )
pending = append ( pending , p )
rearmTimeout ( )
rearmTimeout ( )
case reply := <- t . replies :
case r := <- t . gotreply :
// run matching callbacks, remove if they return false.
var matched bool
for i := 0 ; i < len ( pending ) ; i ++ {
for i := 0 ; i < len ( pending ) ; i ++ {
p := pending [ i ]
if p := pending [ i ] ; p . from == r . from && p . ptype == r . ptype {
if reply . from == p . from && reply . ptype == p . ptype && p . callback ( reply . data ) {
matched = true
if p . callback ( r . data ) {
// callback indicates the request is done, remove it.
p . errc <- nil
p . errc <- nil
copy ( pending [ i : ] , pending [ i + 1 : ] )
copy ( pending [ i : ] , pending [ i + 1 : ] )
pending = pending [ : len ( pending ) - 1 ]
pending = pending [ : len ( pending ) - 1 ]
i --
i --
}
}
}
}
rearmTimeout ( )
}
r . matched <- matched
case now := <- timeout . C :
case now := <- timeout . C :
// notify and remove callbacks whose deadline is in the past.
// notify and remove callbacks whose deadline is in the past.
@ -292,33 +331,38 @@ const (
var headSpace = make ( [ ] byte , headSize )
var headSpace = make ( [ ] byte , headSize )
func ( t * udp ) send ( to * Node , ptype byte , req interface { } ) error {
func ( t * udp ) send ( toaddr * net . UDPAddr , ptype byte , req interface { } ) error {
packet , err := encodePacket ( t . priv , ptype , req )
if err != nil {
return err
}
log . DebugDetailf ( ">>> %v %T %v\n" , toaddr , req , req )
if _ , err = t . conn . WriteToUDP ( packet , toaddr ) ; err != nil {
log . DebugDetailln ( "UDP send failed:" , err )
}
return err
}
func encodePacket ( priv * ecdsa . PrivateKey , ptype byte , req interface { } ) ( [ ] byte , error ) {
b := new ( bytes . Buffer )
b := new ( bytes . Buffer )
b . Write ( headSpace )
b . Write ( headSpace )
b . WriteByte ( ptype )
b . WriteByte ( ptype )
if err := rlp . Encode ( b , req ) ; err != nil {
if err := rlp . Encode ( b , req ) ; err != nil {
log . Errorln ( "error encoding packet:" , err )
log . Errorln ( "error encoding packet:" , err )
return err
return nil , err
}
}
packet := b . Bytes ( )
packet := b . Bytes ( )
sig , err := crypto . Sign ( crypto . Sha3 ( packet [ headSize : ] ) , t . priv )
sig , err := crypto . Sign ( crypto . Sha3 ( packet [ headSize : ] ) , priv )
if err != nil {
if err != nil {
log . Errorln ( "could not sign packet:" , err )
log . Errorln ( "could not sign packet:" , err )
return err
return nil , err
}
}
copy ( packet [ macSize : ] , sig )
copy ( packet [ macSize : ] , sig )
// add the hash to the front. Note: this doesn't protect the
// add the hash to the front. Note: this doesn't protect the
// packet in any way. Our public key will be part of this hash in
// packet in any way. Our public key will be part of this hash in
// t he future.
// T he future.
copy ( packet , crypto . Sha3 ( packet [ macSize : ] ) )
copy ( packet , crypto . Sha3 ( packet [ macSize : ] ) )
return packet , nil
toaddr := & net . UDPAddr { IP : to . IP , Port : to . DiscPort }
log . DebugDetailf ( ">>> %v %T %v\n" , toaddr , req , req )
if _ , err = t . conn . WriteToUDP ( packet , toaddr ) ; err != nil {
log . DebugDetailln ( "UDP send failed:" , err )
}
return err
}
}
// readLoop runs in its own goroutine. it handles incoming UDP packets.
// readLoop runs in its own goroutine. it handles incoming UDP packets.
@ -330,29 +374,34 @@ func (t *udp) readLoop() {
if err != nil {
if err != nil {
return
return
}
}
if err := t . packetIn ( from , buf [ : nbytes ] ) ; err != nil {
packet , fromID , hash , err := decodePacket ( buf [ : nbytes ] )
if err != nil {
log . Debugf ( "Bad packet from %v: %v\n" , from , err )
log . Debugf ( "Bad packet from %v: %v\n" , from , err )
continue
}
log . DebugDetailf ( "<<< %v %T %v\n" , from , packet , packet )
go func ( ) {
if err := packet . handle ( t , from , fromID , hash ) ; err != nil {
log . Debugf ( "error handling %T from %v: %v" , packet , from , err )
}
}
} ( )
}
}
}
}
func ( t * udp ) packetIn ( from * net . UDPAddr , buf [ ] byte ) error {
func decodePacket ( buf [ ] byte ) ( packet , NodeID , [ ] byte , error ) {
if len ( buf ) < headSize + 1 {
if len ( buf ) < headSize + 1 {
return errPacketTooSmall
return nil , NodeID { } , nil , errPacketTooSmall
}
}
hash , sig , sigdata := buf [ : macSize ] , buf [ macSize : headSize ] , buf [ headSize : ]
hash , sig , sigdata := buf [ : macSize ] , buf [ macSize : headSize ] , buf [ headSize : ]
shouldhash := crypto . Sha3 ( buf [ macSize : ] )
shouldhash := crypto . Sha3 ( buf [ macSize : ] )
if ! bytes . Equal ( hash , shouldhash ) {
if ! bytes . Equal ( hash , shouldhash ) {
return errBadHash
return nil , NodeID { } , nil , errBadHash
}
}
fromID , err := recoverNodeID ( crypto . Sha3 ( buf [ headSize : ] ) , sig )
fromID , err := recoverNodeID ( crypto . Sha3 ( buf [ headSize : ] ) , sig )
if err != nil {
if err != nil {
return err
return nil , NodeID { } , hash , err
}
var req interface {
handle ( t * udp , from * net . UDPAddr , fromID NodeID , mac [ ] byte ) error
}
}
var req packet
switch ptype := sigdata [ 0 ] ; ptype {
switch ptype := sigdata [ 0 ] ; ptype {
case pingPacket :
case pingPacket :
req = new ( ping )
req = new ( ping )
@ -363,13 +412,10 @@ func (t *udp) packetIn(from *net.UDPAddr, buf []byte) error {
case neighborsPacket :
case neighborsPacket :
req = new ( neighbors )
req = new ( neighbors )
default :
default :
return fmt . Errorf ( "unknown type: %d" , ptype )
return nil , fromID , hash , fmt . Errorf ( "unknown type: %d" , ptype )
}
if err := rlp . Decode ( bytes . NewReader ( sigdata [ 1 : ] ) , req ) ; err != nil {
return err
}
}
log . DebugDetailf ( "<<< %v %T %v\n" , from , req , req )
err = rlp . Decode ( bytes . NewReader ( sigdata [ 1 : ] ) , req )
return req . handle ( t , from , fromID , hash )
return req , fromID , hash , err
}
}
func ( req * ping ) handle ( t * udp , from * net . UDPAddr , fromID NodeID , mac [ ] byte ) error {
func ( req * ping ) handle ( t * udp , from * net . UDPAddr , fromID NodeID , mac [ ] byte ) error {
@ -379,18 +425,14 @@ func (req *ping) handle(t *udp, from *net.UDPAddr, fromID NodeID, mac []byte) er
if req . Version != Version {
if req . Version != Version {
return errBadVersion
return errBadVersion
}
}
t . mutex . Lock ( )
t . send ( from , pongPacket , pong {
// Note: we're ignoring the provided IP address right now
n := t . bumpOrAdd ( fromID , from )
if req . Port != 0 {
n . TCPPort = int ( req . Port )
}
t . mutex . Unlock ( )
t . send ( n , pongPacket , pong {
ReplyTok : mac ,
ReplyTok : mac ,
Expiration : uint64 ( time . Now ( ) . Add ( expiration ) . Unix ( ) ) ,
Expiration : uint64 ( time . Now ( ) . Add ( expiration ) . Unix ( ) ) ,
} )
} )
if ! t . handleReply ( fromID , pingPacket , req ) {
// Note: we're ignoring the provided IP address right now
t . bond ( true , fromID , from , req . Port )
}
return nil
return nil
}
}
@ -398,11 +440,9 @@ func (req *pong) handle(t *udp, from *net.UDPAddr, fromID NodeID, mac []byte) er
if expired ( req . Expiration ) {
if expired ( req . Expiration ) {
return errExpired
return errExpired
}
}
t . mutex . Lock ( )
if ! t . handleReply ( fromID , pongPacket , req ) {
t . bump ( fromID )
return errUnsolicitedReply
t . mutex . Unlock ( )
}
t . replies <- reply { fromID , pongPacket , req }
return nil
return nil
}
}
@ -410,12 +450,21 @@ func (req *findnode) handle(t *udp, from *net.UDPAddr, fromID NodeID, mac []byte
if expired ( req . Expiration ) {
if expired ( req . Expiration ) {
return errExpired
return errExpired
}
}
if t . db . get ( fromID ) == nil {
// No bond exists, we don't process the packet. This prevents
// an attack vector where the discovery protocol could be used
// to amplify traffic in a DDOS attack. A malicious actor
// would send a findnode request with the IP address and UDP
// port of the target as the source address. The recipient of
// the findnode packet would then send a neighbors packet
// (which is a much bigger packet than findnode) to the victim.
return errUnknownNode
}
t . mutex . Lock ( )
t . mutex . Lock ( )
e := t . bumpOrAdd ( fromID , from )
closest := t . closest ( req . Target , bucketSize ) . entries
closest := t . closest ( req . Target , bucketSize ) . entries
t . mutex . Unlock ( )
t . mutex . Unlock ( )
t . send ( e , neighborsPacket , neighbors {
t . send ( from , neighborsPacket , neighbors {
Nodes : closest ,
Nodes : closest ,
Expiration : uint64 ( time . Now ( ) . Add ( expiration ) . Unix ( ) ) ,
Expiration : uint64 ( time . Now ( ) . Add ( expiration ) . Unix ( ) ) ,
} )
} )
@ -426,12 +475,9 @@ func (req *neighbors) handle(t *udp, from *net.UDPAddr, fromID NodeID, mac []byt
if expired ( req . Expiration ) {
if expired ( req . Expiration ) {
return errExpired
return errExpired
}
}
t . mutex . Lock ( )
if ! t . handleReply ( fromID , neighborsPacket , req ) {
t . bump ( fromID )
return errUnsolicitedReply
t . add ( req . Nodes )
}
t . mutex . Unlock ( )
t . replies <- reply { fromID , neighborsPacket , req }
return nil
return nil
}
}