p2p, p2p/discover, p2p/discv5: add IP network restriction feature

The p2p packages can now be configured to restrict all communication to
a certain subset of IP networks. This feature is meant to be used for
private networks.
pull/3325/head
Felix Lange 8 years ago
parent e46bda5093
commit a47341cf96
  1. 45
      p2p/dial.go
  2. 41
      p2p/dial_test.go
  3. 25
      p2p/discover/udp.go
  4. 2
      p2p/discover/udp_test.go
  5. 12
      p2p/discv5/net.go
  6. 2
      p2p/discv5/net_test.go
  7. 2
      p2p/discv5/sim_test.go
  8. 4
      p2p/discv5/udp.go
  9. 25
      p2p/server.go

@ -19,6 +19,7 @@ package p2p
import ( import (
"container/heap" "container/heap"
"crypto/rand" "crypto/rand"
"errors"
"fmt" "fmt"
"net" "net"
"time" "time"
@ -26,6 +27,7 @@ import (
"github.com/ethereum/go-ethereum/logger" "github.com/ethereum/go-ethereum/logger"
"github.com/ethereum/go-ethereum/logger/glog" "github.com/ethereum/go-ethereum/logger/glog"
"github.com/ethereum/go-ethereum/p2p/discover" "github.com/ethereum/go-ethereum/p2p/discover"
"github.com/ethereum/go-ethereum/p2p/netutil"
) )
const ( const (
@ -48,6 +50,7 @@ const (
type dialstate struct { type dialstate struct {
maxDynDials int maxDynDials int
ntab discoverTable ntab discoverTable
netrestrict *netutil.Netlist
lookupRunning bool lookupRunning bool
dialing map[discover.NodeID]connFlag dialing map[discover.NodeID]connFlag
@ -100,10 +103,11 @@ type waitExpireTask struct {
time.Duration time.Duration
} }
func newDialState(static []*discover.Node, ntab discoverTable, maxdyn int) *dialstate { func newDialState(static []*discover.Node, ntab discoverTable, maxdyn int, netrestrict *netutil.Netlist) *dialstate {
s := &dialstate{ s := &dialstate{
maxDynDials: maxdyn, maxDynDials: maxdyn,
ntab: ntab, ntab: ntab,
netrestrict: netrestrict,
static: make(map[discover.NodeID]*dialTask), static: make(map[discover.NodeID]*dialTask),
dialing: make(map[discover.NodeID]connFlag), dialing: make(map[discover.NodeID]connFlag),
randomNodes: make([]*discover.Node, maxdyn/2), randomNodes: make([]*discover.Node, maxdyn/2),
@ -128,12 +132,9 @@ func (s *dialstate) removeStatic(n *discover.Node) {
func (s *dialstate) newTasks(nRunning int, peers map[discover.NodeID]*Peer, now time.Time) []task { func (s *dialstate) newTasks(nRunning int, peers map[discover.NodeID]*Peer, now time.Time) []task {
var newtasks []task var newtasks []task
isDialing := func(id discover.NodeID) bool {
_, found := s.dialing[id]
return found || peers[id] != nil || s.hist.contains(id)
}
addDial := func(flag connFlag, n *discover.Node) bool { addDial := func(flag connFlag, n *discover.Node) bool {
if isDialing(n.ID) { if err := s.checkDial(n, peers); err != nil {
glog.V(logger.Debug).Infof("skipping dial candidate %x@%v:%d: %v", n.ID[:8], n.IP, n.TCP, err)
return false return false
} }
s.dialing[n.ID] = flag s.dialing[n.ID] = flag
@ -159,7 +160,12 @@ func (s *dialstate) newTasks(nRunning int, peers map[discover.NodeID]*Peer, now
// Create dials for static nodes if they are not connected. // Create dials for static nodes if they are not connected.
for id, t := range s.static { for id, t := range s.static {
if !isDialing(id) { err := s.checkDial(t.dest, peers)
switch err {
case errNotWhitelisted, errSelf:
glog.V(logger.Debug).Infof("removing static dial candidate %x@%v:%d: %v", t.dest.ID[:8], t.dest.IP, t.dest.TCP, err)
delete(s.static, t.dest.ID)
case nil:
s.dialing[id] = t.flags s.dialing[id] = t.flags
newtasks = append(newtasks, t) newtasks = append(newtasks, t)
} }
@ -202,6 +208,31 @@ func (s *dialstate) newTasks(nRunning int, peers map[discover.NodeID]*Peer, now
return newtasks return newtasks
} }
var (
errSelf = errors.New("is self")
errAlreadyDialing = errors.New("already dialing")
errAlreadyConnected = errors.New("already connected")
errRecentlyDialed = errors.New("recently dialed")
errNotWhitelisted = errors.New("not contained in netrestrict whitelist")
)
func (s *dialstate) checkDial(n *discover.Node, peers map[discover.NodeID]*Peer) error {
_, dialing := s.dialing[n.ID]
switch {
case dialing:
return errAlreadyDialing
case peers[n.ID] != nil:
return errAlreadyConnected
case s.ntab != nil && n.ID == s.ntab.Self().ID:
return errSelf
case s.netrestrict != nil && !s.netrestrict.Contains(n.IP):
return errNotWhitelisted
case s.hist.contains(n.ID):
return errRecentlyDialed
}
return nil
}
func (s *dialstate) taskDone(t task, now time.Time) { func (s *dialstate) taskDone(t task, now time.Time) {
switch t := t.(type) { switch t := t.(type) {
case *dialTask: case *dialTask:

@ -25,6 +25,7 @@ import (
"github.com/davecgh/go-spew/spew" "github.com/davecgh/go-spew/spew"
"github.com/ethereum/go-ethereum/p2p/discover" "github.com/ethereum/go-ethereum/p2p/discover"
"github.com/ethereum/go-ethereum/p2p/netutil"
) )
func init() { func init() {
@ -86,7 +87,7 @@ func (t fakeTable) ReadRandomNodes(buf []*discover.Node) int { return copy(buf,
// This test checks that dynamic dials are launched from discovery results. // This test checks that dynamic dials are launched from discovery results.
func TestDialStateDynDial(t *testing.T) { func TestDialStateDynDial(t *testing.T) {
runDialTest(t, dialtest{ runDialTest(t, dialtest{
init: newDialState(nil, fakeTable{}, 5), init: newDialState(nil, fakeTable{}, 5, nil),
rounds: []round{ rounds: []round{
// A discovery query is launched. // A discovery query is launched.
{ {
@ -233,7 +234,7 @@ func TestDialStateDynDialFromTable(t *testing.T) {
} }
runDialTest(t, dialtest{ runDialTest(t, dialtest{
init: newDialState(nil, table, 10), init: newDialState(nil, table, 10, nil),
rounds: []round{ rounds: []round{
// 5 out of 8 of the nodes returned by ReadRandomNodes are dialed. // 5 out of 8 of the nodes returned by ReadRandomNodes are dialed.
{ {
@ -313,6 +314,36 @@ func TestDialStateDynDialFromTable(t *testing.T) {
}) })
} }
// This test checks that candidates that do not match the netrestrict list are not dialed.
func TestDialStateNetRestrict(t *testing.T) {
// This table always returns the same random nodes
// in the order given below.
table := fakeTable{
{ID: uintID(1), IP: net.ParseIP("127.0.0.1")},
{ID: uintID(2), IP: net.ParseIP("127.0.0.2")},
{ID: uintID(3), IP: net.ParseIP("127.0.0.3")},
{ID: uintID(4), IP: net.ParseIP("127.0.0.4")},
{ID: uintID(5), IP: net.ParseIP("127.0.2.5")},
{ID: uintID(6), IP: net.ParseIP("127.0.2.6")},
{ID: uintID(7), IP: net.ParseIP("127.0.2.7")},
{ID: uintID(8), IP: net.ParseIP("127.0.2.8")},
}
restrict := new(netutil.Netlist)
restrict.Add("127.0.2.0/24")
runDialTest(t, dialtest{
init: newDialState(nil, table, 10, restrict),
rounds: []round{
{
new: []task{
&dialTask{flags: dynDialedConn, dest: table[4]},
&discoverTask{},
},
},
},
})
}
// This test checks that static dials are launched. // This test checks that static dials are launched.
func TestDialStateStaticDial(t *testing.T) { func TestDialStateStaticDial(t *testing.T) {
wantStatic := []*discover.Node{ wantStatic := []*discover.Node{
@ -324,7 +355,7 @@ func TestDialStateStaticDial(t *testing.T) {
} }
runDialTest(t, dialtest{ runDialTest(t, dialtest{
init: newDialState(wantStatic, fakeTable{}, 0), init: newDialState(wantStatic, fakeTable{}, 0, nil),
rounds: []round{ rounds: []round{
// Static dials are launched for the nodes that // Static dials are launched for the nodes that
// aren't yet connected. // aren't yet connected.
@ -405,7 +436,7 @@ func TestDialStateCache(t *testing.T) {
} }
runDialTest(t, dialtest{ runDialTest(t, dialtest{
init: newDialState(wantStatic, fakeTable{}, 0), init: newDialState(wantStatic, fakeTable{}, 0, nil),
rounds: []round{ rounds: []round{
// Static dials are launched for the nodes that // Static dials are launched for the nodes that
// aren't yet connected. // aren't yet connected.
@ -467,7 +498,7 @@ func TestDialStateCache(t *testing.T) {
func TestDialResolve(t *testing.T) { func TestDialResolve(t *testing.T) {
resolved := discover.NewNode(uintID(1), net.IP{127, 0, 55, 234}, 3333, 4444) resolved := discover.NewNode(uintID(1), net.IP{127, 0, 55, 234}, 3333, 4444)
table := &resolveMock{answer: resolved} table := &resolveMock{answer: resolved}
state := newDialState(nil, table, 0) state := newDialState(nil, table, 0, nil)
// Check that the task is generated with an incomplete ID. // Check that the task is generated with an incomplete ID.
dest := discover.NewNode(uintID(1), nil, 0, 0) dest := discover.NewNode(uintID(1), nil, 0, 0)

@ -127,13 +127,16 @@ func makeEndpoint(addr *net.UDPAddr, tcpPort uint16) rpcEndpoint {
return rpcEndpoint{IP: ip, UDP: uint16(addr.Port), TCP: tcpPort} return rpcEndpoint{IP: ip, UDP: uint16(addr.Port), TCP: tcpPort}
} }
func nodeFromRPC(sender *net.UDPAddr, rn rpcNode) (*Node, error) { func (t *udp) nodeFromRPC(sender *net.UDPAddr, rn rpcNode) (*Node, error) {
if rn.UDP <= 1024 { if rn.UDP <= 1024 {
return nil, errors.New("low port") return nil, errors.New("low port")
} }
if err := netutil.CheckRelayIP(sender.IP, rn.IP); err != nil { if err := netutil.CheckRelayIP(sender.IP, rn.IP); err != nil {
return nil, err return nil, err
} }
if t.netrestrict != nil && !t.netrestrict.Contains(rn.IP) {
return nil, errors.New("not contained in netrestrict whitelist")
}
n := NewNode(rn.ID, rn.IP, rn.UDP, rn.TCP) n := NewNode(rn.ID, rn.IP, rn.UDP, rn.TCP)
err := n.validateComplete() err := n.validateComplete()
return n, err return n, err
@ -157,6 +160,7 @@ type conn interface {
// udp implements the RPC protocol. // udp implements the RPC protocol.
type udp struct { type udp struct {
conn conn conn conn
netrestrict *netutil.Netlist
priv *ecdsa.PrivateKey priv *ecdsa.PrivateKey
ourEndpoint rpcEndpoint ourEndpoint rpcEndpoint
@ -207,7 +211,7 @@ type reply struct {
} }
// ListenUDP returns a new table that listens for UDP packets on laddr. // ListenUDP returns a new table that listens for UDP packets on laddr.
func ListenUDP(priv *ecdsa.PrivateKey, laddr string, natm nat.Interface, nodeDBPath string) (*Table, error) { func ListenUDP(priv *ecdsa.PrivateKey, laddr string, natm nat.Interface, nodeDBPath string, netrestrict *netutil.Netlist) (*Table, error) {
addr, err := net.ResolveUDPAddr("udp", laddr) addr, err := net.ResolveUDPAddr("udp", laddr)
if err != nil { if err != nil {
return nil, err return nil, err
@ -216,7 +220,7 @@ func ListenUDP(priv *ecdsa.PrivateKey, laddr string, natm nat.Interface, nodeDBP
if err != nil { if err != nil {
return nil, err return nil, err
} }
tab, _, err := newUDP(priv, conn, natm, nodeDBPath) tab, _, err := newUDP(priv, conn, natm, nodeDBPath, netrestrict)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -224,13 +228,14 @@ func ListenUDP(priv *ecdsa.PrivateKey, laddr string, natm nat.Interface, nodeDBP
return tab, nil return tab, nil
} }
func newUDP(priv *ecdsa.PrivateKey, c conn, natm nat.Interface, nodeDBPath string) (*Table, *udp, error) { func newUDP(priv *ecdsa.PrivateKey, c conn, natm nat.Interface, nodeDBPath string, netrestrict *netutil.Netlist) (*Table, *udp, error) {
udp := &udp{ udp := &udp{
conn: c, conn: c,
priv: priv, priv: priv,
closing: make(chan struct{}), netrestrict: netrestrict,
gotreply: make(chan reply), closing: make(chan struct{}),
addpending: make(chan *pending), gotreply: make(chan reply),
addpending: make(chan *pending),
} }
realaddr := c.LocalAddr().(*net.UDPAddr) realaddr := c.LocalAddr().(*net.UDPAddr)
if natm != nil { if natm != nil {
@ -287,7 +292,7 @@ func (t *udp) findnode(toid NodeID, toaddr *net.UDPAddr, target NodeID) ([]*Node
reply := r.(*neighbors) reply := r.(*neighbors)
for _, rn := range reply.Nodes { for _, rn := range reply.Nodes {
nreceived++ nreceived++
n, err := nodeFromRPC(toaddr, rn) n, err := t.nodeFromRPC(toaddr, rn)
if err != nil { if err != nil {
glog.V(logger.Detail).Infof("invalid neighbor node (%v) from %v: %v", rn.IP, toaddr, err) glog.V(logger.Detail).Infof("invalid neighbor node (%v) from %v: %v", rn.IP, toaddr, err)
continue continue

@ -70,7 +70,7 @@ func newUDPTest(t *testing.T) *udpTest {
remotekey: newkey(), remotekey: newkey(),
remoteaddr: &net.UDPAddr{IP: net.IP{10, 0, 1, 99}, Port: 30303}, remoteaddr: &net.UDPAddr{IP: net.IP{10, 0, 1, 99}, Port: 30303},
} }
test.table, test.udp, _ = newUDP(test.localkey, test.pipe, nil, "") test.table, test.udp, _ = newUDP(test.localkey, test.pipe, nil, "", nil)
return test return test
} }

@ -31,6 +31,7 @@ import (
"github.com/ethereum/go-ethereum/logger" "github.com/ethereum/go-ethereum/logger"
"github.com/ethereum/go-ethereum/logger/glog" "github.com/ethereum/go-ethereum/logger/glog"
"github.com/ethereum/go-ethereum/p2p/nat" "github.com/ethereum/go-ethereum/p2p/nat"
"github.com/ethereum/go-ethereum/p2p/netutil"
"github.com/ethereum/go-ethereum/rlp" "github.com/ethereum/go-ethereum/rlp"
) )
@ -63,8 +64,9 @@ func debugLog(s string) {
// Network manages the table and all protocol interaction. // Network manages the table and all protocol interaction.
type Network struct { type Network struct {
db *nodeDB // database of known nodes db *nodeDB // database of known nodes
conn transport conn transport
netrestrict *netutil.Netlist
closed chan struct{} // closed when loop is done closed chan struct{} // closed when loop is done
closeReq chan struct{} // 'request to close' closeReq chan struct{} // 'request to close'
@ -133,7 +135,7 @@ type timeoutEvent struct {
node *Node node *Node
} }
func newNetwork(conn transport, ourPubkey ecdsa.PublicKey, natm nat.Interface, dbPath string) (*Network, error) { func newNetwork(conn transport, ourPubkey ecdsa.PublicKey, natm nat.Interface, dbPath string, netrestrict *netutil.Netlist) (*Network, error) {
ourID := PubkeyID(&ourPubkey) ourID := PubkeyID(&ourPubkey)
var db *nodeDB var db *nodeDB
@ -148,6 +150,7 @@ func newNetwork(conn transport, ourPubkey ecdsa.PublicKey, natm nat.Interface, d
net := &Network{ net := &Network{
db: db, db: db,
conn: conn, conn: conn,
netrestrict: netrestrict,
tab: tab, tab: tab,
topictab: newTopicTable(db, tab.self), topictab: newTopicTable(db, tab.self),
ticketStore: newTicketStore(), ticketStore: newTicketStore(),
@ -696,6 +699,9 @@ func (net *Network) internNodeFromNeighbours(sender *net.UDPAddr, rn rpcNode) (n
if n == nil { if n == nil {
// We haven't seen this node before. // We haven't seen this node before.
n, err = nodeFromRPC(sender, rn) n, err = nodeFromRPC(sender, rn)
if net.netrestrict != nil && !net.netrestrict.Contains(n.IP) {
return n, errors.New("not contained in netrestrict whitelist")
}
if err == nil { if err == nil {
n.state = unknown n.state = unknown
net.nodes[n.ID] = n net.nodes[n.ID] = n

@ -28,7 +28,7 @@ import (
func TestNetwork_Lookup(t *testing.T) { func TestNetwork_Lookup(t *testing.T) {
key, _ := crypto.GenerateKey() key, _ := crypto.GenerateKey()
network, err := newNetwork(lookupTestnet, key.PublicKey, nil, "") network, err := newNetwork(lookupTestnet, key.PublicKey, nil, "", nil)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }

@ -290,7 +290,7 @@ func (s *simulation) launchNode(log bool) *Network {
addr := &net.UDPAddr{IP: ip, Port: 30303} addr := &net.UDPAddr{IP: ip, Port: 30303}
transport := &simTransport{joinTime: time.Now(), sender: id, senderAddr: addr, sim: s, priv: key} transport := &simTransport{joinTime: time.Now(), sender: id, senderAddr: addr, sim: s, priv: key}
net, err := newNetwork(transport, key.PublicKey, nil, "<no database>") net, err := newNetwork(transport, key.PublicKey, nil, "<no database>", nil)
if err != nil { if err != nil {
panic("cannot launch new node: " + err.Error()) panic("cannot launch new node: " + err.Error())
} }

@ -238,12 +238,12 @@ type udp struct {
} }
// ListenUDP returns a new table that listens for UDP packets on laddr. // ListenUDP returns a new table that listens for UDP packets on laddr.
func ListenUDP(priv *ecdsa.PrivateKey, laddr string, natm nat.Interface, nodeDBPath string) (*Network, error) { func ListenUDP(priv *ecdsa.PrivateKey, laddr string, natm nat.Interface, nodeDBPath string, netrestrict *netutil.Netlist) (*Network, error) {
transport, err := listenUDP(priv, laddr) transport, err := listenUDP(priv, laddr)
if err != nil { if err != nil {
return nil, err return nil, err
} }
net, err := newNetwork(transport, priv.PublicKey, natm, nodeDBPath) net, err := newNetwork(transport, priv.PublicKey, natm, nodeDBPath, netrestrict)
if err != nil { if err != nil {
return nil, err return nil, err
} }

@ -30,6 +30,7 @@ import (
"github.com/ethereum/go-ethereum/p2p/discover" "github.com/ethereum/go-ethereum/p2p/discover"
"github.com/ethereum/go-ethereum/p2p/discv5" "github.com/ethereum/go-ethereum/p2p/discv5"
"github.com/ethereum/go-ethereum/p2p/nat" "github.com/ethereum/go-ethereum/p2p/nat"
"github.com/ethereum/go-ethereum/p2p/netutil"
) )
const ( const (
@ -101,6 +102,11 @@ type Config struct {
// allowed to connect, even above the peer limit. // allowed to connect, even above the peer limit.
TrustedNodes []*discover.Node TrustedNodes []*discover.Node
// Connectivity can be restricted to certain IP networks.
// If this option is set to a non-nil value, only hosts which match one of the
// IP networks contained in the list are considered.
NetRestrict *netutil.Netlist
// NodeDatabase is the path to the database containing the previously seen // NodeDatabase is the path to the database containing the previously seen
// live nodes in the network. // live nodes in the network.
NodeDatabase string NodeDatabase string
@ -356,7 +362,7 @@ func (srv *Server) Start() (err error) {
// node table // node table
if srv.Discovery { if srv.Discovery {
ntab, err := discover.ListenUDP(srv.PrivateKey, srv.ListenAddr, srv.NAT, srv.NodeDatabase) ntab, err := discover.ListenUDP(srv.PrivateKey, srv.ListenAddr, srv.NAT, srv.NodeDatabase, srv.NetRestrict)
if err != nil { if err != nil {
return err return err
} }
@ -367,7 +373,7 @@ func (srv *Server) Start() (err error) {
} }
if srv.DiscoveryV5 { if srv.DiscoveryV5 {
ntab, err := discv5.ListenUDP(srv.PrivateKey, srv.DiscoveryV5Addr, srv.NAT, "") //srv.NodeDatabase) ntab, err := discv5.ListenUDP(srv.PrivateKey, srv.DiscoveryV5Addr, srv.NAT, "", srv.NetRestrict) //srv.NodeDatabase)
if err != nil { if err != nil {
return err return err
} }
@ -381,7 +387,7 @@ func (srv *Server) Start() (err error) {
if !srv.Discovery { if !srv.Discovery {
dynPeers = 0 dynPeers = 0
} }
dialer := newDialState(srv.StaticNodes, srv.ntab, dynPeers) dialer := newDialState(srv.StaticNodes, srv.ntab, dynPeers, srv.NetRestrict)
// handshake // handshake
srv.ourHandshake = &protoHandshake{Version: baseProtocolVersion, Name: srv.Name, ID: discover.PubkeyID(&srv.PrivateKey.PublicKey)} srv.ourHandshake = &protoHandshake{Version: baseProtocolVersion, Name: srv.Name, ID: discover.PubkeyID(&srv.PrivateKey.PublicKey)}
@ -634,8 +640,19 @@ func (srv *Server) listenLoop() {
} }
break break
} }
// Reject connections that do not match NetRestrict.
if srv.NetRestrict != nil {
if tcp, ok := fd.RemoteAddr().(*net.TCPAddr); ok && !srv.NetRestrict.Contains(tcp.IP) {
glog.V(logger.Debug).Infof("Rejected conn %v because it is not whitelisted in NetRestrict", fd.RemoteAddr())
fd.Close()
slots <- struct{}{}
continue
}
}
fd = newMeteredConn(fd, true) fd = newMeteredConn(fd, true)
glog.V(logger.Debug).Infof("Accepted conn %v\n", fd.RemoteAddr()) glog.V(logger.Debug).Infof("Accepted conn %v", fd.RemoteAddr())
// Spawn the handler. It will give the slot back when the connection // Spawn the handler. It will give the slot back when the connection
// has been established. // has been established.

Loading…
Cancel
Save