p2p: API cleanup and PoC 7 compatibility

Whoa, one more big commit. I didn't manage to untangle the
changes while working towards compatibility.
poc8
Felix Lange 10 years ago
parent e4a601c644
commit 59b63caf5e
  1. 6
      p2p/client_identity.go
  2. 62
      p2p/message.go
  3. 221
      p2p/messenger.go
  4. 203
      p2p/messenger_test.go
  5. 34
      p2p/natpmp.go
  6. 198
      p2p/natupnp.go
  7. 196
      p2p/network.go
  8. 476
      p2p/peer.go
  9. 150
      p2p/peer_error.go
  10. 98
      p2p/peer_error_handler.go
  11. 34
      p2p/peer_error_handler_test.go
  12. 308
      p2p/peer_test.go
  13. 412
      p2p/protocol.go
  14. 713
      p2p/server.go
  15. 388
      p2p/server_test.go
  16. 28
      p2p/testlog_test.go
  17. 40
      p2p/testpoc7.go

@ -5,10 +5,10 @@ import (
"runtime" "runtime"
) )
// should be used in Peer handleHandshake, incorporate Caps, ProtocolVersion, Pubkey etc. // ClientIdentity represents the identity of a peer.
type ClientIdentity interface { type ClientIdentity interface {
String() string String() string // human readable identity
Pubkey() []byte Pubkey() []byte // 512-bit public key
} }
type SimpleClientIdentity struct { type SimpleClientIdentity struct {

@ -11,8 +11,6 @@ import (
"github.com/ethereum/go-ethereum/ethutil" "github.com/ethereum/go-ethereum/ethutil"
) )
type MsgCode uint64
// Msg defines the structure of a p2p message. // Msg defines the structure of a p2p message.
// //
// Note that a Msg can only be sent once since the Payload reader is // Note that a Msg can only be sent once since the Payload reader is
@ -21,13 +19,13 @@ type MsgCode uint64
// structure, encode the payload into a byte array and create a // structure, encode the payload into a byte array and create a
// separate Msg with a bytes.Reader as Payload for each send. // separate Msg with a bytes.Reader as Payload for each send.
type Msg struct { type Msg struct {
Code MsgCode Code uint64
Size uint32 // size of the paylod Size uint32 // size of the paylod
Payload io.Reader Payload io.Reader
} }
// NewMsg creates an RLP-encoded message with the given code. // NewMsg creates an RLP-encoded message with the given code.
func NewMsg(code MsgCode, params ...interface{}) Msg { func NewMsg(code uint64, params ...interface{}) Msg {
buf := new(bytes.Buffer) buf := new(bytes.Buffer)
for _, p := range params { for _, p := range params {
buf.Write(ethutil.Encode(p)) buf.Write(ethutil.Encode(p))
@ -63,6 +61,52 @@ func (msg Msg) Discard() error {
return err return err
} }
type MsgReader interface {
ReadMsg() (Msg, error)
}
type MsgWriter interface {
// WriteMsg sends an existing message.
// The Payload reader of the message is consumed.
// Note that messages can be sent only once.
WriteMsg(Msg) error
// EncodeMsg writes an RLP-encoded message with the given
// code and data elements.
EncodeMsg(code uint64, data ...interface{}) error
}
// MsgReadWriter provides reading and writing of encoded messages.
type MsgReadWriter interface {
MsgReader
MsgWriter
}
// MsgLoop reads messages off the given reader and
// calls the handler function for each decoded message until
// it returns an error or the peer connection is closed.
//
// If a message is larger than the given maximum size,
// MsgLoop returns an appropriate error.
func MsgLoop(r MsgReader, maxsize uint32, f func(code uint64, data *ethutil.Value) error) error {
for {
msg, err := r.ReadMsg()
if err != nil {
return err
}
if msg.Size > maxsize {
return newPeerError(errInvalidMsg, "size %d exceeds maximum size of %d", msg.Size, maxsize)
}
value, err := msg.Data()
if err != nil {
return err
}
if err := f(msg.Code, value); err != nil {
return err
}
}
}
var magicToken = []byte{34, 64, 8, 145} var magicToken = []byte{34, 64, 8, 145}
func writeMsg(w io.Writer, msg Msg) error { func writeMsg(w io.Writer, msg Msg) error {
@ -103,10 +147,10 @@ func readMsg(r byteReader) (msg Msg, err error) {
// read magic and payload size // read magic and payload size
start := make([]byte, 8) start := make([]byte, 8)
if _, err = io.ReadFull(r, start); err != nil { if _, err = io.ReadFull(r, start); err != nil {
return msg, NewPeerError(ReadError, "%v", err) return msg, newPeerError(errRead, "%v", err)
} }
if !bytes.HasPrefix(start, magicToken) { if !bytes.HasPrefix(start, magicToken) {
return msg, NewPeerError(MagicTokenMismatch, "got %x, want %x", start[:4], magicToken) return msg, newPeerError(errMagicTokenMismatch, "got %x, want %x", start[:4], magicToken)
} }
size := binary.BigEndian.Uint32(start[4:]) size := binary.BigEndian.Uint32(start[4:])
@ -152,13 +196,13 @@ func readListHeader(r byteReader) (len uint64, hdrlen uint32, err error) {
} }
// readUint reads an RLP-encoded unsigned integer from r. // readUint reads an RLP-encoded unsigned integer from r.
func readMsgCode(r byteReader) (code MsgCode, codelen uint32, err error) { func readMsgCode(r byteReader) (code uint64, codelen uint32, err error) {
b, err := r.ReadByte() b, err := r.ReadByte()
if err != nil { if err != nil {
return 0, 0, err return 0, 0, err
} }
if b < 0x80 { if b < 0x80 {
return MsgCode(b), 1, nil return uint64(b), 1, nil
} else if b < 0x89 { // max length for uint64 is 8 bytes } else if b < 0x89 { // max length for uint64 is 8 bytes
codelen = uint32(b - 0x80) codelen = uint32(b - 0x80)
if codelen == 0 { if codelen == 0 {
@ -168,7 +212,7 @@ func readMsgCode(r byteReader) (code MsgCode, codelen uint32, err error) {
if _, err := io.ReadFull(r, buf[8-codelen:]); err != nil { if _, err := io.ReadFull(r, buf[8-codelen:]); err != nil {
return 0, 0, err return 0, 0, err
} }
return MsgCode(binary.BigEndian.Uint64(buf)), codelen, nil return binary.BigEndian.Uint64(buf), codelen, nil
} }
return 0, 0, fmt.Errorf("bad RLP type for message code: %x", b) return 0, 0, fmt.Errorf("bad RLP type for message code: %x", b)
} }

@ -1,221 +0,0 @@
package p2p
import (
"bufio"
"bytes"
"fmt"
"io"
"io/ioutil"
"net"
"sync"
"time"
)
type Handlers map[string]Protocol
type proto struct {
in chan Msg
maxcode, offset MsgCode
messenger *messenger
}
func (rw *proto) WriteMsg(msg Msg) error {
if msg.Code >= rw.maxcode {
return NewPeerError(InvalidMsgCode, "not handled")
}
msg.Code += rw.offset
return rw.messenger.writeMsg(msg)
}
func (rw *proto) ReadMsg() (Msg, error) {
msg, ok := <-rw.in
if !ok {
return msg, io.EOF
}
msg.Code -= rw.offset
return msg, nil
}
// eofSignal wraps a reader with eof signaling.
// the eof channel is closed when the wrapped reader
// reaches EOF.
type eofSignal struct {
wrapped io.Reader
eof chan struct{}
}
func (r *eofSignal) Read(buf []byte) (int, error) {
n, err := r.wrapped.Read(buf)
if err != nil {
close(r.eof) // tell messenger that msg has been consumed
}
return n, err
}
// messenger represents a message-oriented peer connection.
// It keeps track of the set of protocols understood
// by the remote peer.
type messenger struct {
peer *Peer
handlers Handlers
// the mutex protects the connection
// so only one protocol can write at a time.
writeMu sync.Mutex
conn net.Conn
bufconn *bufio.ReadWriter
protocolLock sync.RWMutex
protocols map[string]*proto
offsets map[MsgCode]*proto
protoWG sync.WaitGroup
err chan error
pulse chan bool
}
func newMessenger(peer *Peer, conn net.Conn, errchan chan error, handlers Handlers) *messenger {
return &messenger{
conn: conn,
bufconn: bufio.NewReadWriter(bufio.NewReader(conn), bufio.NewWriter(conn)),
peer: peer,
handlers: handlers,
protocols: make(map[string]*proto),
err: errchan,
pulse: make(chan bool, 1),
}
}
func (m *messenger) Start() {
m.protocols[""] = m.startProto(0, "", &baseProtocol{})
go m.readLoop()
}
func (m *messenger) Stop() {
m.conn.Close()
m.protoWG.Wait()
}
const (
// maximum amount of time allowed for reading a message
msgReadTimeout = 5 * time.Second
// messages smaller than this many bytes will be read at
// once before passing them to a protocol.
wholePayloadSize = 64 * 1024
)
func (m *messenger) readLoop() {
defer m.closeProtocols()
for {
m.conn.SetReadDeadline(time.Now().Add(msgReadTimeout))
msg, err := readMsg(m.bufconn)
if err != nil {
m.err <- err
return
}
// send ping to heartbeat channel signalling time of last message
m.pulse <- true
proto, err := m.getProto(msg.Code)
if err != nil {
m.err <- err
return
}
if msg.Size <= wholePayloadSize {
// optimization: msg is small enough, read all
// of it and move on to the next message
buf, err := ioutil.ReadAll(msg.Payload)
if err != nil {
m.err <- err
return
}
msg.Payload = bytes.NewReader(buf)
proto.in <- msg
} else {
pr := &eofSignal{msg.Payload, make(chan struct{})}
msg.Payload = pr
proto.in <- msg
<-pr.eof
}
}
}
func (m *messenger) closeProtocols() {
m.protocolLock.RLock()
for _, p := range m.protocols {
close(p.in)
}
m.protocolLock.RUnlock()
}
func (m *messenger) startProto(offset MsgCode, name string, impl Protocol) *proto {
proto := &proto{
in: make(chan Msg),
offset: offset,
maxcode: impl.Offset(),
messenger: m,
}
m.protoWG.Add(1)
go func() {
if err := impl.Start(m.peer, proto); err != nil && err != io.EOF {
logger.Errorf("protocol %q error: %v\n", name, err)
m.err <- err
}
m.protoWG.Done()
}()
return proto
}
// getProto finds the protocol responsible for handling
// the given message code.
func (m *messenger) getProto(code MsgCode) (*proto, error) {
m.protocolLock.RLock()
defer m.protocolLock.RUnlock()
for _, proto := range m.protocols {
if code >= proto.offset && code < proto.offset+proto.maxcode {
return proto, nil
}
}
return nil, NewPeerError(InvalidMsgCode, "%d", code)
}
// setProtocols starts all subprotocols shared with the
// remote peer. the protocols must be sorted alphabetically.
func (m *messenger) setRemoteProtocols(protocols []string) {
m.protocolLock.Lock()
defer m.protocolLock.Unlock()
offset := baseProtocolOffset
for _, name := range protocols {
inst, ok := m.handlers[name]
if !ok {
continue // not handled
}
m.protocols[name] = m.startProto(offset, name, inst)
offset += inst.Offset()
}
}
// writeProtoMsg sends the given message on behalf of the given named protocol.
func (m *messenger) writeProtoMsg(protoName string, msg Msg) error {
m.protocolLock.RLock()
proto, ok := m.protocols[protoName]
m.protocolLock.RUnlock()
if !ok {
return fmt.Errorf("protocol %s not handled by peer", protoName)
}
if msg.Code >= proto.maxcode {
return NewPeerError(InvalidMsgCode, "code %x is out of range for protocol %q", msg.Code, protoName)
}
msg.Code += proto.offset
return m.writeMsg(msg)
}
// writeMsg writes a message to the connection.
func (m *messenger) writeMsg(msg Msg) error {
m.writeMu.Lock()
defer m.writeMu.Unlock()
if err := writeMsg(m.bufconn, msg); err != nil {
return err
}
return m.bufconn.Flush()
}

@ -1,203 +0,0 @@
package p2p
import (
"bufio"
"fmt"
"io"
"log"
"net"
"os"
"reflect"
"testing"
"time"
logpkg "github.com/ethereum/go-ethereum/logger"
)
func init() {
logpkg.AddLogSystem(logpkg.NewStdLogSystem(os.Stdout, log.LstdFlags, logpkg.DebugLevel))
}
func testMessenger(handlers Handlers) (net.Conn, *Peer, *messenger) {
conn1, conn2 := net.Pipe()
id := NewSimpleClientIdentity("test", "0", "0", "public key")
server := New(nil, conn1.LocalAddr(), id, handlers, 10, NewBlacklist())
peer := server.addPeer(conn1, conn1.RemoteAddr(), true, 0)
return conn2, peer, peer.messenger
}
func performTestHandshake(r *bufio.Reader, w io.Writer) error {
// read remote handshake
msg, err := readMsg(r)
if err != nil {
return fmt.Errorf("read error: %v", err)
}
if msg.Code != handshakeMsg {
return fmt.Errorf("first message should be handshake, got %d", msg.Code)
}
if err := msg.Discard(); err != nil {
return err
}
// send empty handshake
pubkey := make([]byte, 64)
msg = NewMsg(handshakeMsg, p2pVersion, "testid", nil, 9999, pubkey)
return writeMsg(w, msg)
}
type testProtocol struct {
offset MsgCode
f func(MsgReadWriter)
}
func (p *testProtocol) Offset() MsgCode {
return p.offset
}
func (p *testProtocol) Start(peer *Peer, rw MsgReadWriter) error {
p.f(rw)
return nil
}
func TestRead(t *testing.T) {
done := make(chan struct{})
handlers := Handlers{
"a": &testProtocol{5, func(rw MsgReadWriter) {
msg, err := rw.ReadMsg()
if err != nil {
t.Errorf("read error: %v", err)
}
if msg.Code != 2 {
t.Errorf("incorrect msg code %d relayed to protocol", msg.Code)
}
data, err := msg.Data()
if err != nil {
t.Errorf("data decoding error: %v", err)
}
expdata := []interface{}{1, []byte{0x30, 0x30, 0x30}}
if !reflect.DeepEqual(data.Slice(), expdata) {
t.Errorf("incorrect msg data %#v", data.Slice())
}
close(done)
}},
}
net, peer, m := testMessenger(handlers)
defer peer.Stop()
bufr := bufio.NewReader(net)
if err := performTestHandshake(bufr, net); err != nil {
t.Fatalf("handshake failed: %v", err)
}
m.setRemoteProtocols([]string{"a"})
writeMsg(net, NewMsg(18, 1, "000"))
select {
case <-done:
case <-time.After(2 * time.Second):
t.Errorf("receive timeout")
}
}
func TestWriteFromProto(t *testing.T) {
handlers := Handlers{
"a": &testProtocol{2, func(rw MsgReadWriter) {
if err := rw.WriteMsg(NewMsg(2)); err == nil {
t.Error("expected error for out-of-range msg code, got nil")
}
if err := rw.WriteMsg(NewMsg(1)); err != nil {
t.Errorf("write error: %v", err)
}
}},
}
net, peer, mess := testMessenger(handlers)
defer peer.Stop()
bufr := bufio.NewReader(net)
if err := performTestHandshake(bufr, net); err != nil {
t.Fatalf("handshake failed: %v", err)
}
mess.setRemoteProtocols([]string{"a"})
msg, err := readMsg(bufr)
if err != nil {
t.Errorf("read error: %v")
}
if msg.Code != 17 {
t.Errorf("incorrect message code: got %d, expected %d", msg.Code, 17)
}
}
var discardProto = &testProtocol{1, func(rw MsgReadWriter) {
for {
msg, err := rw.ReadMsg()
if err != nil {
return
}
if err = msg.Discard(); err != nil {
return
}
}
}}
func TestMessengerWriteProtoMsg(t *testing.T) {
handlers := Handlers{"a": discardProto}
net, peer, mess := testMessenger(handlers)
defer peer.Stop()
bufr := bufio.NewReader(net)
if err := performTestHandshake(bufr, net); err != nil {
t.Fatalf("handshake failed: %v", err)
}
mess.setRemoteProtocols([]string{"a"})
// test write errors
if err := mess.writeProtoMsg("b", NewMsg(3)); err == nil {
t.Errorf("expected error for unknown protocol, got nil")
}
if err := mess.writeProtoMsg("a", NewMsg(8)); err == nil {
t.Errorf("expected error for out-of-range msg code, got nil")
} else if perr, ok := err.(*PeerError); !ok || perr.Code != InvalidMsgCode {
t.Errorf("wrong error for out-of-range msg code, got %#v")
}
// test succcessful write
read, readerr := make(chan Msg), make(chan error)
go func() {
if msg, err := readMsg(bufr); err != nil {
readerr <- err
} else {
read <- msg
}
}()
if err := mess.writeProtoMsg("a", NewMsg(0)); err != nil {
t.Errorf("expect no error for known protocol: %v", err)
}
select {
case msg := <-read:
if msg.Code != 16 {
t.Errorf("wrong code, got %d, expected %d", msg.Code, 16)
}
msg.Discard()
case err := <-readerr:
t.Errorf("read error: %v", err)
}
}
func TestPulse(t *testing.T) {
net, peer, _ := testMessenger(nil)
defer peer.Stop()
bufr := bufio.NewReader(net)
if err := performTestHandshake(bufr, net); err != nil {
t.Fatalf("handshake failed: %v", err)
}
before := time.Now()
msg, err := readMsg(bufr)
if err != nil {
t.Fatalf("read error: %v", err)
}
after := time.Now()
if msg.Code != pingMsg {
t.Errorf("expected ping message, got %d", msg.Code)
}
if d := after.Sub(before); d < pingTimeout {
t.Errorf("ping sent too early after %v, expected at least %v", d, pingTimeout)
}
}

@ -3,6 +3,7 @@ package p2p
import ( import (
"fmt" "fmt"
"net" "net"
"time"
natpmp "github.com/jackpal/go-nat-pmp" natpmp "github.com/jackpal/go-nat-pmp"
) )
@ -13,38 +14,37 @@ import (
// + Register for changes to the external address. // + Register for changes to the external address.
// + Re-register port mapping when router reboots. // + Re-register port mapping when router reboots.
// + A mechanism for keeping a port mapping registered. // + A mechanism for keeping a port mapping registered.
// + Discover gateway address automatically.
type natPMPClient struct { type natPMPClient struct {
client *natpmp.Client client *natpmp.Client
} }
func NewNatPMP(gateway net.IP) (nat NAT) { // PMP returns a NAT traverser that uses NAT-PMP. The provided gateway
// address should be the IP of your router.
func PMP(gateway net.IP) (nat NAT) {
return &natPMPClient{natpmp.NewClient(gateway)} return &natPMPClient{natpmp.NewClient(gateway)}
} }
func (n *natPMPClient) GetExternalAddress() (addr net.IP, err error) { func (*natPMPClient) String() string {
return "NAT-PMP"
}
func (n *natPMPClient) GetExternalAddress() (net.IP, error) {
response, err := n.client.GetExternalAddress() response, err := n.client.GetExternalAddress()
if err != nil { if err != nil {
return return nil, err
} }
ip := response.ExternalIPAddress return response.ExternalIPAddress[:], nil
addr = net.IPv4(ip[0], ip[1], ip[2], ip[3])
return
} }
func (n *natPMPClient) AddPortMapping(protocol string, externalPort, internalPort int, func (n *natPMPClient) AddPortMapping(protocol string, extport, intport int, name string, lifetime time.Duration) error {
description string, timeout int) (mappedExternalPort int, err error) { if lifetime <= 0 {
if timeout <= 0 { return fmt.Errorf("lifetime must not be <= 0")
err = fmt.Errorf("timeout must not be <= 0")
return
} }
// Note order of port arguments is switched between our AddPortMapping and the client's AddPortMapping. // Note order of port arguments is switched between our AddPortMapping and the client's AddPortMapping.
response, err := n.client.AddPortMapping(protocol, internalPort, externalPort, timeout) _, err := n.client.AddPortMapping(protocol, intport, extport, int(lifetime/time.Second))
if err != nil { return err
return
}
mappedExternalPort = int(response.MappedExternalPort)
return
} }
func (n *natPMPClient) DeletePortMapping(protocol string, externalPort, internalPort int) (err error) { func (n *natPMPClient) DeletePortMapping(protocol string, externalPort, internalPort int) (err error) {

@ -7,6 +7,7 @@ import (
"bytes" "bytes"
"encoding/xml" "encoding/xml"
"errors" "errors"
"fmt"
"net" "net"
"net/http" "net/http"
"os" "os"
@ -15,28 +16,46 @@ import (
"time" "time"
) )
const (
upnpDiscoverAttempts = 3
upnpDiscoverTimeout = 5 * time.Second
)
// UPNP returns a NAT port mapper that uses UPnP. It will attempt to
// discover the address of your router using UDP broadcasts.
func UPNP() NAT {
return &upnpNAT{}
}
type upnpNAT struct { type upnpNAT struct {
serviceURL string serviceURL string
ourIP string ourIP string
} }
func upnpDiscover(attempts int) (nat NAT, err error) { func (n *upnpNAT) String() string {
return "UPNP"
}
func (n *upnpNAT) discover() error {
if n.serviceURL != "" {
// already discovered
return nil
}
ssdp, err := net.ResolveUDPAddr("udp4", "239.255.255.250:1900") ssdp, err := net.ResolveUDPAddr("udp4", "239.255.255.250:1900")
if err != nil { if err != nil {
return return err
} }
// TODO: try on all network interfaces simultaneously.
// Broadcasting on 0.0.0.0 could select a random interface
// to send on (platform specific).
conn, err := net.ListenPacket("udp4", ":0") conn, err := net.ListenPacket("udp4", ":0")
if err != nil { if err != nil {
return return err
}
socket := conn.(*net.UDPConn)
defer socket.Close()
err = socket.SetDeadline(time.Now().Add(10 * time.Second))
if err != nil {
return
} }
defer conn.Close()
conn.SetDeadline(time.Now().Add(10 * time.Second))
st := "ST: urn:schemas-upnp-org:device:InternetGatewayDevice:1\r\n" st := "ST: urn:schemas-upnp-org:device:InternetGatewayDevice:1\r\n"
buf := bytes.NewBufferString( buf := bytes.NewBufferString(
"M-SEARCH * HTTP/1.1\r\n" + "M-SEARCH * HTTP/1.1\r\n" +
@ -46,19 +65,16 @@ func upnpDiscover(attempts int) (nat NAT, err error) {
"MX: 2\r\n\r\n") "MX: 2\r\n\r\n")
message := buf.Bytes() message := buf.Bytes()
answerBytes := make([]byte, 1024) answerBytes := make([]byte, 1024)
for i := 0; i < attempts; i++ { for i := 0; i < upnpDiscoverAttempts; i++ {
_, err = socket.WriteToUDP(message, ssdp) _, err = conn.WriteTo(message, ssdp)
if err != nil { if err != nil {
return return err
} }
var n int nn, _, err := conn.ReadFrom(answerBytes)
n, _, err = socket.ReadFromUDP(answerBytes)
if err != nil { if err != nil {
continue continue
// socket.Close()
// return
} }
answer := string(answerBytes[0:n]) answer := string(answerBytes[0:nn])
if strings.Index(answer, "\r\n"+st) < 0 { if strings.Index(answer, "\r\n"+st) < 0 {
continue continue
} }
@ -79,17 +95,81 @@ func upnpDiscover(attempts int) (nat NAT, err error) {
var serviceURL string var serviceURL string
serviceURL, err = getServiceURL(locURL) serviceURL, err = getServiceURL(locURL)
if err != nil { if err != nil {
return return err
} }
var ourIP string var ourIP string
ourIP, err = getOurIP() ourIP, err = getOurIP()
if err != nil { if err != nil {
return return err
} }
nat = &upnpNAT{serviceURL: serviceURL, ourIP: ourIP} n.serviceURL = serviceURL
n.ourIP = ourIP
return nil
}
return errors.New("UPnP port discovery failed.")
}
func (n *upnpNAT) GetExternalAddress() (addr net.IP, err error) {
if err := n.discover(); err != nil {
return nil, err
}
info, err := n.getStatusInfo()
return net.ParseIP(info.externalIpAddress), err
}
func (n *upnpNAT) AddPortMapping(protocol string, extport, intport int, description string, lifetime time.Duration) error {
if err := n.discover(); err != nil {
return err
}
// A single concatenation would break ARM compilation.
message := "<u:AddPortMapping xmlns:u=\"urn:schemas-upnp-org:service:WANIPConnection:1\">\r\n" +
"<NewRemoteHost></NewRemoteHost><NewExternalPort>" + strconv.Itoa(extport)
message += "</NewExternalPort><NewProtocol>" + protocol + "</NewProtocol>"
message += "<NewInternalPort>" + strconv.Itoa(extport) + "</NewInternalPort>" +
"<NewInternalClient>" + n.ourIP + "</NewInternalClient>" +
"<NewEnabled>1</NewEnabled><NewPortMappingDescription>"
message += description +
"</NewPortMappingDescription><NewLeaseDuration>" + fmt.Sprint(lifetime/time.Second) +
"</NewLeaseDuration></u:AddPortMapping>"
// TODO: check response to see if the port was forwarded
_, err := soapRequest(n.serviceURL, "AddPortMapping", message)
return err
}
func (n *upnpNAT) DeletePortMapping(protocol string, externalPort, internalPort int) error {
if err := n.discover(); err != nil {
return err
}
message := "<u:DeletePortMapping xmlns:u=\"urn:schemas-upnp-org:service:WANIPConnection:1\">\r\n" +
"<NewRemoteHost></NewRemoteHost><NewExternalPort>" + strconv.Itoa(externalPort) +
"</NewExternalPort><NewProtocol>" + protocol + "</NewProtocol>" +
"</u:DeletePortMapping>"
// TODO: check response to see if the port was deleted
_, err := soapRequest(n.serviceURL, "DeletePortMapping", message)
return err
}
type statusInfo struct {
externalIpAddress string
}
func (n *upnpNAT) getStatusInfo() (info statusInfo, err error) {
message := "<u:GetStatusInfo xmlns:u=\"urn:schemas-upnp-org:service:WANIPConnection:1\">\r\n" +
"</u:GetStatusInfo>"
var response *http.Response
response, err = soapRequest(n.serviceURL, "GetStatusInfo", message)
if err != nil {
return return
} }
err = errors.New("UPnP port discovery failed.")
// TODO: Write a soap reply parser. It has to eat the Body and envelope tags...
response.Body.Close()
return return
} }
@ -259,77 +339,3 @@ func soapRequest(url, function, message string) (r *http.Response, err error) {
} }
return return
} }
type statusInfo struct {
externalIpAddress string
}
func (n *upnpNAT) getStatusInfo() (info statusInfo, err error) {
message := "<u:GetStatusInfo xmlns:u=\"urn:schemas-upnp-org:service:WANIPConnection:1\">\r\n" +
"</u:GetStatusInfo>"
var response *http.Response
response, err = soapRequest(n.serviceURL, "GetStatusInfo", message)
if err != nil {
return
}
// TODO: Write a soap reply parser. It has to eat the Body and envelope tags...
response.Body.Close()
return
}
func (n *upnpNAT) GetExternalAddress() (addr net.IP, err error) {
info, err := n.getStatusInfo()
if err != nil {
return
}
addr = net.ParseIP(info.externalIpAddress)
return
}
func (n *upnpNAT) AddPortMapping(protocol string, externalPort, internalPort int, description string, timeout int) (mappedExternalPort int, err error) {
// A single concatenation would break ARM compilation.
message := "<u:AddPortMapping xmlns:u=\"urn:schemas-upnp-org:service:WANIPConnection:1\">\r\n" +
"<NewRemoteHost></NewRemoteHost><NewExternalPort>" + strconv.Itoa(externalPort)
message += "</NewExternalPort><NewProtocol>" + protocol + "</NewProtocol>"
message += "<NewInternalPort>" + strconv.Itoa(internalPort) + "</NewInternalPort>" +
"<NewInternalClient>" + n.ourIP + "</NewInternalClient>" +
"<NewEnabled>1</NewEnabled><NewPortMappingDescription>"
message += description +
"</NewPortMappingDescription><NewLeaseDuration>" + strconv.Itoa(timeout) +
"</NewLeaseDuration></u:AddPortMapping>"
var response *http.Response
response, err = soapRequest(n.serviceURL, "AddPortMapping", message)
if err != nil {
return
}
// TODO: check response to see if the port was forwarded
// log.Println(message, response)
mappedExternalPort = externalPort
_ = response
return
}
func (n *upnpNAT) DeletePortMapping(protocol string, externalPort, internalPort int) (err error) {
message := "<u:DeletePortMapping xmlns:u=\"urn:schemas-upnp-org:service:WANIPConnection:1\">\r\n" +
"<NewRemoteHost></NewRemoteHost><NewExternalPort>" + strconv.Itoa(externalPort) +
"</NewExternalPort><NewProtocol>" + protocol + "</NewProtocol>" +
"</u:DeletePortMapping>"
var response *http.Response
response, err = soapRequest(n.serviceURL, "DeletePortMapping", message)
if err != nil {
return
}
// TODO: check response to see if the port was deleted
// log.Println(message, response)
_ = response
return
}

@ -1,196 +0,0 @@
package p2p
import (
"fmt"
"math/rand"
"net"
"strconv"
"time"
)
const (
DialerTimeout = 180 //seconds
KeepAlivePeriod = 60 //minutes
portMappingUpdateInterval = 900 // seconds = 15 mins
upnpDiscoverAttempts = 3
)
// Dialer is not an interface in net, so we define one
// *net.Dialer conforms to this
type Dialer interface {
Dial(network, address string) (net.Conn, error)
}
type Network interface {
Start() error
Listener(net.Addr) (net.Listener, error)
Dialer(net.Addr) (Dialer, error)
NewAddr(string, int) (addr net.Addr, err error)
ParseAddr(string) (addr net.Addr, err error)
}
type NAT interface {
GetExternalAddress() (addr net.IP, err error)
AddPortMapping(protocol string, externalPort, internalPort int, description string, timeout int) (mappedExternalPort int, err error)
DeletePortMapping(protocol string, externalPort, internalPort int) (err error)
}
type TCPNetwork struct {
nat NAT
natType NATType
quit chan chan bool
ports chan string
}
type NATType int
const (
NONE = iota
UPNP
PMP
)
const (
portMappingTimeout = 1200 // 20 mins
)
func NewTCPNetwork(natType NATType) (net *TCPNetwork) {
return &TCPNetwork{
natType: natType,
ports: make(chan string),
}
}
func (self *TCPNetwork) Dialer(addr net.Addr) (Dialer, error) {
return &net.Dialer{
Timeout: DialerTimeout * time.Second,
// KeepAlive: KeepAlivePeriod * time.Minute,
LocalAddr: addr,
}, nil
}
func (self *TCPNetwork) Listener(addr net.Addr) (net.Listener, error) {
if self.natType == UPNP {
_, port, _ := net.SplitHostPort(addr.String())
if self.quit == nil {
self.quit = make(chan chan bool)
go self.updatePortMappings()
}
self.ports <- port
}
return net.Listen(addr.Network(), addr.String())
}
func (self *TCPNetwork) Start() (err error) {
switch self.natType {
case NONE:
case UPNP:
nat, uerr := upnpDiscover(upnpDiscoverAttempts)
if uerr != nil {
err = fmt.Errorf("UPNP failed: ", uerr)
} else {
self.nat = nat
}
case PMP:
err = fmt.Errorf("PMP not implemented")
default:
err = fmt.Errorf("Invalid NAT type: %v", self.natType)
}
return
}
func (self *TCPNetwork) Stop() {
q := make(chan bool)
self.quit <- q
<-q
}
func (self *TCPNetwork) addPortMapping(lport int) (err error) {
_, err = self.nat.AddPortMapping("TCP", lport, lport, "p2p listen port", portMappingTimeout)
if err != nil {
logger.Errorf("unable to add port mapping on %v: %v", lport, err)
} else {
logger.Debugf("succesfully added port mapping on %v", lport)
}
return
}
func (self *TCPNetwork) updatePortMappings() {
timer := time.NewTimer(portMappingUpdateInterval * time.Second)
lports := []int{}
out:
for {
select {
case port := <-self.ports:
int64lport, _ := strconv.ParseInt(port, 10, 16)
lport := int(int64lport)
if err := self.addPortMapping(lport); err != nil {
lports = append(lports, lport)
}
case <-timer.C:
for lport := range lports {
if err := self.addPortMapping(lport); err != nil {
}
}
case errc := <-self.quit:
errc <- true
break out
}
}
timer.Stop()
for lport := range lports {
if err := self.nat.DeletePortMapping("TCP", lport, lport); err != nil {
logger.Debugf("unable to remove port mapping on %v: %v", lport, err)
} else {
logger.Debugf("succesfully removed port mapping on %v", lport)
}
}
}
func (self *TCPNetwork) NewAddr(host string, port int) (net.Addr, error) {
ip, err := self.lookupIP(host)
if err == nil {
return &net.TCPAddr{
IP: ip,
Port: port,
}, nil
}
return nil, err
}
func (self *TCPNetwork) ParseAddr(address string) (net.Addr, error) {
host, port, err := net.SplitHostPort(address)
if err == nil {
iport, _ := strconv.Atoi(port)
addr, e := self.NewAddr(host, iport)
return addr, e
}
return nil, err
}
func (*TCPNetwork) lookupIP(host string) (ip net.IP, err error) {
if ip = net.ParseIP(host); ip != nil {
return
}
var ips []net.IP
ips, err = net.LookupIP(host)
if err != nil {
logger.Warnln(err)
return
}
if len(ips) == 0 {
err = fmt.Errorf("No IP addresses available for %v", host)
logger.Warnln(err)
return
}
if len(ips) > 1 {
// Pick a random IP address, simulating round-robin DNS.
rand.Seed(time.Now().UTC().UnixNano())
ip = ips[rand.Intn(len(ips))]
} else {
ip = ips[0]
}
return
}

@ -1,66 +1,454 @@
package p2p package p2p
import ( import (
"bufio"
"bytes"
"fmt" "fmt"
"io"
"io/ioutil"
"net" "net"
"strconv" "sort"
"sync"
"time"
"github.com/ethereum/go-ethereum/event"
"github.com/ethereum/go-ethereum/logger"
) )
// peerAddr is the structure of a peer list element.
// It is also a valid net.Addr.
type peerAddr struct {
IP net.IP
Port uint64
Pubkey []byte // optional
}
func newPeerAddr(addr net.Addr, pubkey []byte) *peerAddr {
n := addr.Network()
if n != "tcp" && n != "tcp4" && n != "tcp6" {
// for testing with non-TCP
return &peerAddr{net.ParseIP("127.0.0.1"), 30303, pubkey}
}
ta := addr.(*net.TCPAddr)
return &peerAddr{ta.IP, uint64(ta.Port), pubkey}
}
func (d peerAddr) Network() string {
if d.IP.To4() != nil {
return "tcp4"
} else {
return "tcp6"
}
}
func (d peerAddr) String() string {
return fmt.Sprintf("%v:%d", d.IP, d.Port)
}
func (d peerAddr) RlpData() interface{} {
return []interface{}{d.IP, d.Port, d.Pubkey}
}
// Peer represents a remote peer.
type Peer struct { type Peer struct {
Inbound bool // inbound (via listener) or outbound (via dialout) // Peers have all the log methods.
Address net.Addr // Use them to display messages related to the peer.
Host []byte *logger.Logger
Port uint16
Pubkey []byte infolock sync.Mutex
Id string identity ClientIdentity
Caps []string caps []Cap
peerErrorChan chan error listenAddr *peerAddr // what remote peer is listening on
messenger *messenger dialAddr *peerAddr // non-nil if dialing
peerErrorHandler *PeerErrorHandler
server *Server // The mutex protects the connection
} // so only one protocol can write at a time.
writeMu sync.Mutex
func NewPeer(conn net.Conn, address net.Addr, inbound bool, server *Server) *Peer { conn net.Conn
peerErrorChan := NewPeerErrorChannel() bufconn *bufio.ReadWriter
host, port, _ := net.SplitHostPort(address.String())
intport, _ := strconv.Atoi(port) // These fields maintain the running protocols.
peer := &Peer{ protocols []Protocol
Inbound: inbound, runBaseProtocol bool // for testing
Address: address,
Port: uint16(intport), runlock sync.RWMutex // protects running
Host: net.ParseIP(host), running map[string]*proto
peerErrorChan: peerErrorChan,
server: server, protoWG sync.WaitGroup
} protoErr chan error
peer.messenger = newMessenger(peer, conn, peerErrorChan, server.Handlers()) closed chan struct{}
peer.peerErrorHandler = NewPeerErrorHandler(address, server.PeerDisconnect(), peerErrorChan) disc chan DiscReason
activity event.TypeMux // for activity events
slot int // index into Server peer list
// These fields are kept so base protocol can access them.
// TODO: this should be one or more interfaces
ourID ClientIdentity // client id of the Server
ourListenAddr *peerAddr // listen addr of Server, nil if not listening
newPeerAddr chan<- *peerAddr // tell server about received peers
otherPeers func() []*Peer // should return the list of all peers
pubkeyHook func(*peerAddr) error // called at end of handshake to validate pubkey
}
// NewPeer returns a peer for testing purposes.
func NewPeer(id ClientIdentity, caps []Cap) *Peer {
conn, _ := net.Pipe()
peer := newPeer(conn, nil, nil)
peer.setHandshakeInfo(id, nil, caps)
return peer return peer
} }
func (self *Peer) String() string { func newServerPeer(server *Server, conn net.Conn, dialAddr *peerAddr) *Peer {
var kind string p := newPeer(conn, server.Protocols, dialAddr)
if self.Inbound { p.ourID = server.Identity
kind = "inbound" p.newPeerAddr = server.peerConnect
} else { p.otherPeers = server.Peers
p.pubkeyHook = server.verifyPeer
p.runBaseProtocol = true
// laddr can be updated concurrently by NAT traversal.
// newServerPeer must be called with the server lock held.
if server.laddr != nil {
p.ourListenAddr = newPeerAddr(server.laddr, server.Identity.Pubkey())
}
return p
}
func newPeer(conn net.Conn, protocols []Protocol, dialAddr *peerAddr) *Peer {
p := &Peer{
Logger: logger.NewLogger("P2P " + conn.RemoteAddr().String()),
conn: conn,
dialAddr: dialAddr,
bufconn: bufio.NewReadWriter(bufio.NewReader(conn), bufio.NewWriter(conn)),
protocols: protocols,
running: make(map[string]*proto),
disc: make(chan DiscReason),
protoErr: make(chan error),
closed: make(chan struct{}),
}
return p
}
// Identity returns the client identity of the remote peer. The
// identity can be nil if the peer has not yet completed the
// handshake.
func (p *Peer) Identity() ClientIdentity {
p.infolock.Lock()
defer p.infolock.Unlock()
return p.identity
}
// Caps returns the capabilities (supported subprotocols) of the remote peer.
func (p *Peer) Caps() []Cap {
p.infolock.Lock()
defer p.infolock.Unlock()
return p.caps
}
func (p *Peer) setHandshakeInfo(id ClientIdentity, laddr *peerAddr, caps []Cap) {
p.infolock.Lock()
p.identity = id
p.listenAddr = laddr
p.caps = caps
p.infolock.Unlock()
}
// RemoteAddr returns the remote address of the network connection.
func (p *Peer) RemoteAddr() net.Addr {
return p.conn.RemoteAddr()
}
// LocalAddr returns the local address of the network connection.
func (p *Peer) LocalAddr() net.Addr {
return p.conn.LocalAddr()
}
// Disconnect terminates the peer connection with the given reason.
// It returns immediately and does not wait until the connection is closed.
func (p *Peer) Disconnect(reason DiscReason) {
select {
case p.disc <- reason:
case <-p.closed:
}
}
// String implements fmt.Stringer.
func (p *Peer) String() string {
kind := "inbound"
p.infolock.Lock()
if p.dialAddr != nil {
kind = "outbound" kind = "outbound"
} }
return fmt.Sprintf("%v:%v (%s) v%v %v", self.Host, self.Port, kind, self.Id, self.Caps) p.infolock.Unlock()
return fmt.Sprintf("Peer(%p %v %s)", p, p.conn.RemoteAddr(), kind)
}
const (
// maximum amount of time allowed for reading a message
msgReadTimeout = 5 * time.Second
// maximum amount of time allowed for writing a message
msgWriteTimeout = 5 * time.Second
// messages smaller than this many bytes will be read at
// once before passing them to a protocol.
wholePayloadSize = 64 * 1024
)
var (
inactivityTimeout = 2 * time.Second
disconnectGracePeriod = 2 * time.Second
)
func (p *Peer) loop() (reason DiscReason, err error) {
defer p.activity.Stop()
defer p.closeProtocols()
defer close(p.closed)
defer p.conn.Close()
// read loop
readMsg := make(chan Msg)
readErr := make(chan error)
readNext := make(chan bool, 1)
protoDone := make(chan struct{}, 1)
go p.readLoop(readMsg, readErr, readNext)
readNext <- true
if p.runBaseProtocol {
p.startBaseProtocol()
}
loop:
for {
select {
case msg := <-readMsg:
// a new message has arrived.
var wait bool
if wait, err = p.dispatch(msg, protoDone); err != nil {
p.Errorf("msg dispatch error: %v\n", err)
reason = discReasonForError(err)
break loop
}
if !wait {
// Msg has already been read completely, continue with next message.
readNext <- true
}
p.activity.Post(time.Now())
case <-protoDone:
// protocol has consumed the message payload,
// we can continue reading from the socket.
readNext <- true
case err := <-readErr:
// read failed. there is no need to run the
// polite disconnect sequence because the connection
// is probably dead anyway.
// TODO: handle write errors as well
return DiscNetworkError, err
case err = <-p.protoErr:
reason = discReasonForError(err)
break loop
case reason = <-p.disc:
break loop
}
}
// wait for read loop to return.
close(readNext)
<-readErr
// tell the remote end to disconnect
done := make(chan struct{})
go func() {
p.conn.SetDeadline(time.Now().Add(disconnectGracePeriod))
p.writeMsg(NewMsg(discMsg, reason), disconnectGracePeriod)
io.Copy(ioutil.Discard, p.conn)
close(done)
}()
select {
case <-done:
case <-time.After(disconnectGracePeriod):
}
return reason, err
}
func (p *Peer) readLoop(msgc chan<- Msg, errc chan<- error, unblock <-chan bool) {
for _ = range unblock {
p.conn.SetReadDeadline(time.Now().Add(msgReadTimeout))
if msg, err := readMsg(p.bufconn); err != nil {
errc <- err
} else {
msgc <- msg
}
}
close(errc)
}
func (p *Peer) dispatch(msg Msg, protoDone chan struct{}) (wait bool, err error) {
proto, err := p.getProto(msg.Code)
if err != nil {
return false, err
}
if msg.Size <= wholePayloadSize {
// optimization: msg is small enough, read all
// of it and move on to the next message
buf, err := ioutil.ReadAll(msg.Payload)
if err != nil {
return false, err
}
msg.Payload = bytes.NewReader(buf)
proto.in <- msg
} else {
wait = true
pr := &eofSignal{msg.Payload, protoDone}
msg.Payload = pr
proto.in <- msg
}
return wait, nil
}
func (p *Peer) startBaseProtocol() {
p.runlock.Lock()
defer p.runlock.Unlock()
p.running[""] = p.startProto(0, Protocol{
Length: baseProtocolLength,
Run: runBaseProtocol,
})
}
// startProtocols starts matching named subprotocols.
func (p *Peer) startSubprotocols(caps []Cap) {
sort.Sort(capsByName(caps))
p.runlock.Lock()
defer p.runlock.Unlock()
offset := baseProtocolLength
outer:
for _, cap := range caps {
for _, proto := range p.protocols {
if proto.Name == cap.Name &&
proto.Version == cap.Version &&
p.running[cap.Name] == nil {
p.running[cap.Name] = p.startProto(offset, proto)
offset += proto.Length
continue outer
}
}
}
}
func (p *Peer) startProto(offset uint64, impl Protocol) *proto {
rw := &proto{
in: make(chan Msg),
offset: offset,
maxcode: impl.Length,
peer: p,
}
p.protoWG.Add(1)
go func() {
err := impl.Run(p, rw)
if err == nil {
p.Infof("protocol %q returned", impl.Name)
err = newPeerError(errMisc, "protocol returned")
} else {
p.Errorf("protocol %q error: %v\n", impl.Name, err)
}
select {
case p.protoErr <- err:
case <-p.closed:
}
p.protoWG.Done()
}()
return rw
}
// getProto finds the protocol responsible for handling
// the given message code.
func (p *Peer) getProto(code uint64) (*proto, error) {
p.runlock.RLock()
defer p.runlock.RUnlock()
for _, proto := range p.running {
if code >= proto.offset && code < proto.offset+proto.maxcode {
return proto, nil
}
}
return nil, newPeerError(errInvalidMsgCode, "%d", code)
}
func (p *Peer) closeProtocols() {
p.runlock.RLock()
for _, p := range p.running {
close(p.in)
}
p.runlock.RUnlock()
p.protoWG.Wait()
}
// writeProtoMsg sends the given message on behalf of the given named protocol.
func (p *Peer) writeProtoMsg(protoName string, msg Msg) error {
p.runlock.RLock()
proto, ok := p.running[protoName]
p.runlock.RUnlock()
if !ok {
return fmt.Errorf("protocol %s not handled by peer", protoName)
}
if msg.Code >= proto.maxcode {
return newPeerError(errInvalidMsgCode, "code %x is out of range for protocol %q", msg.Code, protoName)
}
msg.Code += proto.offset
return p.writeMsg(msg, msgWriteTimeout)
}
// writeMsg writes a message to the connection.
func (p *Peer) writeMsg(msg Msg, timeout time.Duration) error {
p.writeMu.Lock()
defer p.writeMu.Unlock()
p.conn.SetWriteDeadline(time.Now().Add(timeout))
if err := writeMsg(p.bufconn, msg); err != nil {
return newPeerError(errWrite, "%v", err)
}
return p.bufconn.Flush()
}
type proto struct {
name string
in chan Msg
maxcode, offset uint64
peer *Peer
}
func (rw *proto) WriteMsg(msg Msg) error {
if msg.Code >= rw.maxcode {
return newPeerError(errInvalidMsgCode, "not handled")
}
msg.Code += rw.offset
return rw.peer.writeMsg(msg, msgWriteTimeout)
} }
func (self *Peer) Write(protocol string, msg Msg) error { func (rw *proto) EncodeMsg(code uint64, data ...interface{}) error {
return self.messenger.writeProtoMsg(protocol, msg) return rw.WriteMsg(NewMsg(code, data))
} }
func (self *Peer) Start() { func (rw *proto) ReadMsg() (Msg, error) {
self.peerErrorHandler.Start() msg, ok := <-rw.in
self.messenger.Start() if !ok {
return msg, io.EOF
}
msg.Code -= rw.offset
return msg, nil
} }
func (self *Peer) Stop() { // eofSignal wraps a reader with eof signaling.
self.peerErrorHandler.Stop() // the eof channel is closed when the wrapped reader
self.messenger.Stop() // reaches EOF.
type eofSignal struct {
wrapped io.Reader
eof chan<- struct{}
} }
func (p *Peer) Encode() []interface{} { func (r *eofSignal) Read(buf []byte) (int, error) {
return []interface{}{p.Host, p.Port, p.Pubkey} n, err := r.wrapped.Read(buf)
if err != nil {
r.eof <- struct{}{} // tell Peer that msg has been consumed
}
return n, err
} }

@ -4,71 +4,121 @@ import (
"fmt" "fmt"
) )
type ErrorCode int
const errorChanCapacity = 10
const ( const (
PacketTooLong = iota errMagicTokenMismatch = iota
PayloadTooShort errRead
MagicTokenMismatch errWrite
ReadError errMisc
WriteError errInvalidMsgCode
MiscError errInvalidMsg
InvalidMsgCode errP2PVersionMismatch
InvalidMsg errPubkeyMissing
P2PVersionMismatch errPubkeyInvalid
PubkeyMissing errPubkeyForbidden
PubkeyInvalid errProtocolBreach
PubkeyForbidden errPingTimeout
ProtocolBreach errInvalidNetworkId
PortMismatch errInvalidProtocolVersion
PingTimeout
InvalidGenesis
InvalidNetworkId
InvalidProtocolVersion
) )
var errorToString = map[ErrorCode]string{ var errorToString = map[int]string{
PacketTooLong: "Packet too long", errMagicTokenMismatch: "Magic token mismatch",
PayloadTooShort: "Payload too short", errRead: "Read error",
MagicTokenMismatch: "Magic token mismatch", errWrite: "Write error",
ReadError: "Read error", errMisc: "Misc error",
WriteError: "Write error", errInvalidMsgCode: "Invalid message code",
MiscError: "Misc error", errInvalidMsg: "Invalid message",
InvalidMsgCode: "Invalid message code", errP2PVersionMismatch: "P2P Version Mismatch",
InvalidMsg: "Invalid message", errPubkeyMissing: "Public key missing",
P2PVersionMismatch: "P2P Version Mismatch", errPubkeyInvalid: "Public key invalid",
PubkeyMissing: "Public key missing", errPubkeyForbidden: "Public key forbidden",
PubkeyInvalid: "Public key invalid", errProtocolBreach: "Protocol Breach",
PubkeyForbidden: "Public key forbidden", errPingTimeout: "Ping timeout",
ProtocolBreach: "Protocol Breach", errInvalidNetworkId: "Invalid network id",
PortMismatch: "Port mismatch", errInvalidProtocolVersion: "Invalid protocol version",
PingTimeout: "Ping timeout",
InvalidGenesis: "Invalid genesis block",
InvalidNetworkId: "Invalid network id",
InvalidProtocolVersion: "Invalid protocol version",
} }
type PeerError struct { type peerError struct {
Code ErrorCode Code int
message string message string
} }
func NewPeerError(code ErrorCode, format string, v ...interface{}) *PeerError { func newPeerError(code int, format string, v ...interface{}) *peerError {
desc, ok := errorToString[code] desc, ok := errorToString[code]
if !ok { if !ok {
panic("invalid error code") panic("invalid error code")
} }
format = desc + ": " + format err := &peerError{code, desc}
message := fmt.Sprintf(format, v...) if format != "" {
return &PeerError{code, message} err.message += ": " + fmt.Sprintf(format, v...)
}
return err
} }
func (self *PeerError) Error() string { func (self *peerError) Error() string {
return self.message return self.message
} }
func NewPeerErrorChannel() chan error { type DiscReason byte
return make(chan error, errorChanCapacity)
const (
DiscRequested DiscReason = 0x00
DiscNetworkError = 0x01
DiscProtocolError = 0x02
DiscUselessPeer = 0x03
DiscTooManyPeers = 0x04
DiscAlreadyConnected = 0x05
DiscIncompatibleVersion = 0x06
DiscInvalidIdentity = 0x07
DiscQuitting = 0x08
DiscUnexpectedIdentity = 0x09
DiscSelf = 0x0a
DiscReadTimeout = 0x0b
DiscSubprotocolError = 0x10
)
var discReasonToString = [DiscSubprotocolError + 1]string{
DiscRequested: "Disconnect requested",
DiscNetworkError: "Network error",
DiscProtocolError: "Breach of protocol",
DiscUselessPeer: "Useless peer",
DiscTooManyPeers: "Too many peers",
DiscAlreadyConnected: "Already connected",
DiscIncompatibleVersion: "Incompatible P2P protocol version",
DiscInvalidIdentity: "Invalid node identity",
DiscQuitting: "Client quitting",
DiscUnexpectedIdentity: "Unexpected identity",
DiscSelf: "Connected to self",
DiscReadTimeout: "Read timeout",
DiscSubprotocolError: "Subprotocol error",
}
func (d DiscReason) String() string {
if len(discReasonToString) < int(d) {
return fmt.Sprintf("Unknown Reason(%d)", d)
}
return discReasonToString[d]
}
func discReasonForError(err error) DiscReason {
peerError, ok := err.(*peerError)
if !ok {
return DiscSubprotocolError
}
switch peerError.Code {
case errP2PVersionMismatch:
return DiscIncompatibleVersion
case errPubkeyMissing, errPubkeyInvalid:
return DiscInvalidIdentity
case errPubkeyForbidden:
return DiscUselessPeer
case errInvalidMsgCode, errMagicTokenMismatch, errProtocolBreach:
return DiscProtocolError
case errPingTimeout:
return DiscReadTimeout
case errRead, errWrite, errMisc:
return DiscNetworkError
default:
return DiscSubprotocolError
}
} }

@ -1,98 +0,0 @@
package p2p
import (
"net"
)
const (
severityThreshold = 10
)
type DisconnectRequest struct {
addr net.Addr
reason DiscReason
}
type PeerErrorHandler struct {
quit chan chan bool
address net.Addr
peerDisconnect chan DisconnectRequest
severity int
errc chan error
}
func NewPeerErrorHandler(address net.Addr, peerDisconnect chan DisconnectRequest, errc chan error) *PeerErrorHandler {
return &PeerErrorHandler{
quit: make(chan chan bool),
address: address,
peerDisconnect: peerDisconnect,
errc: errc,
}
}
func (self *PeerErrorHandler) Start() {
go self.listen()
}
func (self *PeerErrorHandler) Stop() {
q := make(chan bool)
self.quit <- q
<-q
}
func (self *PeerErrorHandler) listen() {
for {
select {
case err, ok := <-self.errc:
if ok {
logger.Debugf("error %v\n", err)
go self.handle(err)
} else {
return
}
case q := <-self.quit:
q <- true
return
}
}
}
func (self *PeerErrorHandler) handle(err error) {
reason := DiscReason(' ')
peerError, ok := err.(*PeerError)
if !ok {
peerError = NewPeerError(MiscError, " %v", err)
}
switch peerError.Code {
case P2PVersionMismatch:
reason = DiscIncompatibleVersion
case PubkeyMissing, PubkeyInvalid:
reason = DiscInvalidIdentity
case PubkeyForbidden:
reason = DiscUselessPeer
case InvalidMsgCode, PacketTooLong, PayloadTooShort, MagicTokenMismatch, ProtocolBreach:
reason = DiscProtocolError
case PingTimeout:
reason = DiscReadTimeout
case ReadError, WriteError, MiscError:
reason = DiscNetworkError
case InvalidGenesis, InvalidNetworkId, InvalidProtocolVersion:
reason = DiscSubprotocolError
default:
self.severity += self.getSeverity(peerError)
}
if self.severity >= severityThreshold {
reason = DiscSubprotocolError
}
if reason != DiscReason(' ') {
self.peerDisconnect <- DisconnectRequest{
addr: self.address,
reason: reason,
}
}
}
func (self *PeerErrorHandler) getSeverity(peerError *PeerError) int {
return 1
}

@ -1,34 +0,0 @@
package p2p
import (
// "fmt"
"net"
"testing"
"time"
)
func TestPeerErrorHandler(t *testing.T) {
address := &net.TCPAddr{IP: net.IP([]byte{1, 2, 3, 4}), Port: 30303}
peerDisconnect := make(chan DisconnectRequest)
peerErrorChan := NewPeerErrorChannel()
peh := NewPeerErrorHandler(address, peerDisconnect, peerErrorChan)
peh.Start()
defer peh.Stop()
for i := 0; i < 11; i++ {
select {
case <-peerDisconnect:
t.Errorf("expected no disconnect request")
default:
}
peerErrorChan <- NewPeerError(MiscError, "")
}
time.Sleep(1 * time.Millisecond)
select {
case request := <-peerDisconnect:
if request.addr.String() != address.String() {
t.Errorf("incorrect address %v != %v", request.addr, address)
}
default:
t.Errorf("expected disconnect request")
}
}

@ -1,90 +1,222 @@
package p2p package p2p
// "net" import (
"bufio"
// func TestPeer(t *testing.T) { "net"
// handlers := make(Handlers) "reflect"
// testProtocol := &TestProtocol{recv: make(chan testMsg)} "testing"
// handlers["aaa"] = func(p *Peer) Protocol { return testProtocol } "time"
// handlers["ccc"] = func(p *Peer) Protocol { return testProtocol } )
// addr := &TestAddr{"test:30"}
// conn := NewTestNetworkConnection(addr) var discard = Protocol{
// _, server := SetupTestServer(handlers) Name: "discard",
// server.Handshake() Length: 1,
// peer := NewPeer(conn, addr, true, server) Run: func(p *Peer, rw MsgReadWriter) error {
// // peer.Messenger().AddProtocols([]string{"aaa", "ccc"}) for {
// peer.Start() msg, err := rw.ReadMsg()
// defer peer.Stop() if err != nil {
// time.Sleep(2 * time.Millisecond) return err
// if len(conn.Out) != 1 { }
// t.Errorf("handshake not sent") if err = msg.Discard(); err != nil {
// } else { return err
// out := conn.Out[0] }
// packet := Packet(0, HandshakeMsg, P2PVersion, []byte(peer.server.identity.String()), []interface{}{peer.server.protocols}, peer.server.port, peer.server.identity.Pubkey()[1:]) }
// if bytes.Compare(out, packet) != 0 { },
// t.Errorf("incorrect handshake packet %v != %v", out, packet) }
// }
// } func testPeer(protos []Protocol) (net.Conn, *Peer, <-chan error) {
conn1, conn2 := net.Pipe()
// packet := Packet(0, HandshakeMsg, P2PVersion, []byte("peer"), []interface{}{"bbb", "aaa", "ccc"}, 30, []byte("0000000000000000000000000000000000000000000000000000000000000000")) id := NewSimpleClientIdentity("test", "0", "0", "public key")
// conn.In(0, packet) peer := newPeer(conn1, protos, nil)
// time.Sleep(10 * time.Millisecond) peer.ourID = id
peer.pubkeyHook = func(*peerAddr) error { return nil }
// pro, _ := peer.Messenger().protocols[0].(*BaseProtocol) errc := make(chan error, 1)
// if pro.state != handshakeReceived { go func() {
// t.Errorf("handshake not received") _, err := peer.loop()
// } errc <- err
// if peer.Port != 30 { }()
// t.Errorf("port incorrectly set") return conn2, peer, errc
// } }
// if peer.Id != "peer" {
// t.Errorf("id incorrectly set") func TestPeerProtoReadMsg(t *testing.T) {
// } defer testlog(t).detach()
// if string(peer.Pubkey) != "0000000000000000000000000000000000000000000000000000000000000000" {
// t.Errorf("pubkey incorrectly set") done := make(chan struct{})
// } proto := Protocol{
// fmt.Println(peer.Caps) Name: "a",
// if len(peer.Caps) != 3 || peer.Caps[0] != "aaa" || peer.Caps[1] != "bbb" || peer.Caps[2] != "ccc" { Length: 5,
// t.Errorf("protocols incorrectly set") Run: func(peer *Peer, rw MsgReadWriter) error {
// } msg, err := rw.ReadMsg()
if err != nil {
// msg := NewMsg(3) t.Errorf("read error: %v", err)
// err := peer.Write("aaa", msg) }
// if err != nil { if msg.Code != 2 {
// t.Errorf("expect no error for known protocol: %v", err) t.Errorf("incorrect msg code %d relayed to protocol", msg.Code)
// } else { }
// time.Sleep(1 * time.Millisecond) data, err := msg.Data()
// if len(conn.Out) != 2 { if err != nil {
// t.Errorf("msg not written") t.Errorf("data decoding error: %v", err)
// } else { }
// out := conn.Out[1] expdata := []interface{}{1, []byte{0x30, 0x30, 0x30}}
// packet := Packet(16, 3) if !reflect.DeepEqual(data.Slice(), expdata) {
// if bytes.Compare(out, packet) != 0 { t.Errorf("incorrect msg data %#v", data.Slice())
// t.Errorf("incorrect packet %v != %v", out, packet) }
// } close(done)
// } return nil
// } },
}
// msg = NewMsg(2)
// err = peer.Write("ccc", msg) net, peer, errc := testPeer([]Protocol{proto})
// if err != nil { defer net.Close()
// t.Errorf("expect no error for known protocol: %v", err) peer.startSubprotocols([]Cap{proto.cap()})
// } else {
// time.Sleep(1 * time.Millisecond) writeMsg(net, NewMsg(18, 1, "000"))
// if len(conn.Out) != 3 { select {
// t.Errorf("msg not written") case <-done:
// } else { case err := <-errc:
// out := conn.Out[2] t.Errorf("peer returned: %v", err)
// packet := Packet(21, 2) case <-time.After(2 * time.Second):
// if bytes.Compare(out, packet) != 0 { t.Errorf("receive timeout")
// t.Errorf("incorrect packet %v != %v", out, packet) }
// } }
// }
// } func TestPeerProtoReadLargeMsg(t *testing.T) {
defer testlog(t).detach()
// err = peer.Write("bbb", msg)
// time.Sleep(1 * time.Millisecond) msgsize := uint32(10 * 1024 * 1024)
// if err == nil { done := make(chan struct{})
// t.Errorf("expect error for unknown protocol") proto := Protocol{
// } Name: "a",
// } Length: 5,
Run: func(peer *Peer, rw MsgReadWriter) error {
msg, err := rw.ReadMsg()
if err != nil {
t.Errorf("read error: %v", err)
}
if msg.Size != msgsize+4 {
t.Errorf("incorrect msg.Size, got %d, expected %d", msg.Size, msgsize)
}
msg.Discard()
close(done)
return nil
},
}
net, peer, errc := testPeer([]Protocol{proto})
defer net.Close()
peer.startSubprotocols([]Cap{proto.cap()})
writeMsg(net, NewMsg(18, make([]byte, msgsize)))
select {
case <-done:
case err := <-errc:
t.Errorf("peer returned: %v", err)
case <-time.After(2 * time.Second):
t.Errorf("receive timeout")
}
}
func TestPeerProtoEncodeMsg(t *testing.T) {
defer testlog(t).detach()
proto := Protocol{
Name: "a",
Length: 2,
Run: func(peer *Peer, rw MsgReadWriter) error {
if err := rw.EncodeMsg(2); err == nil {
t.Error("expected error for out-of-range msg code, got nil")
}
if err := rw.EncodeMsg(1); err != nil {
t.Errorf("write error: %v", err)
}
return nil
},
}
net, peer, _ := testPeer([]Protocol{proto})
defer net.Close()
peer.startSubprotocols([]Cap{proto.cap()})
bufr := bufio.NewReader(net)
msg, err := readMsg(bufr)
if err != nil {
t.Errorf("read error: %v", err)
}
if msg.Code != 17 {
t.Errorf("incorrect message code: got %d, expected %d", msg.Code, 17)
}
}
func TestPeerWrite(t *testing.T) {
defer testlog(t).detach()
net, peer, peerErr := testPeer([]Protocol{discard})
defer net.Close()
peer.startSubprotocols([]Cap{discard.cap()})
// test write errors
if err := peer.writeProtoMsg("b", NewMsg(3)); err == nil {
t.Errorf("expected error for unknown protocol, got nil")
}
if err := peer.writeProtoMsg("discard", NewMsg(8)); err == nil {
t.Errorf("expected error for out-of-range msg code, got nil")
} else if perr, ok := err.(*peerError); !ok || perr.Code != errInvalidMsgCode {
t.Errorf("wrong error for out-of-range msg code, got %#v", err)
}
// setup for reading the message on the other end
read := make(chan struct{})
go func() {
bufr := bufio.NewReader(net)
msg, err := readMsg(bufr)
if err != nil {
t.Errorf("read error: %v", err)
} else if msg.Code != 16 {
t.Errorf("wrong code, got %d, expected %d", msg.Code, 16)
}
msg.Discard()
close(read)
}()
// test succcessful write
if err := peer.writeProtoMsg("discard", NewMsg(0)); err != nil {
t.Errorf("expect no error for known protocol: %v", err)
}
select {
case <-read:
case err := <-peerErr:
t.Fatalf("peer stopped: %v", err)
}
}
func TestPeerActivity(t *testing.T) {
// shorten inactivityTimeout while this test is running
oldT := inactivityTimeout
defer func() { inactivityTimeout = oldT }()
inactivityTimeout = 20 * time.Millisecond
net, peer, peerErr := testPeer([]Protocol{discard})
defer net.Close()
peer.startSubprotocols([]Cap{discard.cap()})
sub := peer.activity.Subscribe(time.Time{})
defer sub.Unsubscribe()
for i := 0; i < 6; i++ {
writeMsg(net, NewMsg(16))
select {
case <-sub.Chan():
case <-time.After(inactivityTimeout / 2):
t.Fatal("no event within ", inactivityTimeout/2)
case err := <-peerErr:
t.Fatal("peer error", err)
}
}
select {
case <-time.After(inactivityTimeout * 2):
case <-sub.Chan():
t.Fatal("got activity event while connection was inactive")
case err := <-peerErr:
t.Fatal("peer error", err)
}
}

@ -3,249 +3,185 @@ package p2p
import ( import (
"bytes" "bytes"
"net" "net"
"sort"
"time" "time"
"github.com/ethereum/go-ethereum/ethutil" "github.com/ethereum/go-ethereum/ethutil"
) )
// Protocol is implemented by P2P subprotocols. // Protocol represents a P2P subprotocol implementation.
type Protocol interface { type Protocol struct {
// Start is called when the protocol becomes active. // Name should contain the official protocol name,
// It should read and write messages from rw. // often a three-letter word.
// Messages must be fully consumed. Name string
//
// The connection is closed when Start returns. It should return
// any protocol-level error (such as an I/O error) that is
// encountered.
Start(peer *Peer, rw MsgReadWriter) error
// Offset should return the number of message codes // Version should contain the version number of the protocol.
// used by the protocol. Version uint
Offset() MsgCode
}
type MsgReader interface { // Length should contain the number of message codes used
ReadMsg() (Msg, error) // by the protocol.
} Length uint64
type MsgWriter interface {
WriteMsg(Msg) error
}
// MsgReadWriter is passed to protocols. Protocol implementations can
// use it to write messages back to a connected peer.
type MsgReadWriter interface {
MsgReader
MsgWriter
}
type MsgHandler func(code MsgCode, data *ethutil.Value) error // Run is called in a new groutine when the protocol has been
// negotiated with a peer. It should read and write messages from
// MsgLoop reads messages off the given reader and // rw. The Payload for each message must be fully consumed.
// calls the handler function for each decoded message until //
// it returns an error or the peer connection is closed. // The peer connection is closed when Start returns. It should return
// // any protocol-level error (such as an I/O error) that is
// If a message is larger than the given maximum size, RunProtocol // encountered.
// returns an appropriate error.n Run func(peer *Peer, rw MsgReadWriter) error
func MsgLoop(r MsgReader, maxsize uint32, handler MsgHandler) error {
for {
msg, err := r.ReadMsg()
if err != nil {
return err
}
if msg.Size > maxsize {
return NewPeerError(InvalidMsg, "size %d exceeds maximum size of %d", msg.Size, maxsize)
}
value, err := msg.Data()
if err != nil {
return err
}
if err := handler(msg.Code, value); err != nil {
return err
}
}
}
// the ÐΞVp2p base protocol
type baseProtocol struct {
rw MsgReadWriter
peer *Peer
} }
type bpMsg struct { func (p Protocol) cap() Cap {
code MsgCode return Cap{p.Name, p.Version}
data *ethutil.Value
} }
const ( const (
p2pVersion = 0 baseProtocolVersion = 2
pingTimeout = 2 * time.Second baseProtocolLength = uint64(16)
pingGracePeriod = 2 * time.Second baseProtocolMaxMsgSize = 10 * 1024 * 1024
) )
const ( const (
// message codes // devp2p message codes
handshakeMsg = iota handshakeMsg = 0x00
discMsg discMsg = 0x01
pingMsg pingMsg = 0x02
pongMsg pongMsg = 0x03
getPeersMsg getPeersMsg = 0x04
peersMsg peersMsg = 0x05
) )
const ( // handshake is the structure of a handshake list.
baseProtocolOffset MsgCode = 16 type handshake struct {
baseProtocolMaxMsgSize = 500 * 1024 Version uint64
) ID string
Caps []Cap
type DiscReason byte ListenPort uint64
NodeID []byte
}
const ( func (h *handshake) String() string {
// Values are given explicitly instead of by iota because these values are return h.ID
// defined by the wire protocol spec; it is easier for humans to ensure }
// correctness when values are explicit. func (h *handshake) Pubkey() []byte {
DiscRequested = 0x00 return h.NodeID
DiscNetworkError = 0x01 }
DiscProtocolError = 0x02
DiscUselessPeer = 0x03
DiscTooManyPeers = 0x04
DiscAlreadyConnected = 0x05
DiscIncompatibleVersion = 0x06
DiscInvalidIdentity = 0x07
DiscQuitting = 0x08
DiscUnexpectedIdentity = 0x09
DiscSelf = 0x0a
DiscReadTimeout = 0x0b
DiscSubprotocolError = 0x10
)
var discReasonToString = [DiscSubprotocolError + 1]string{ // Cap is the structure of a peer capability.
DiscRequested: "Disconnect requested", type Cap struct {
DiscNetworkError: "Network error", Name string
DiscProtocolError: "Breach of protocol", Version uint
DiscUselessPeer: "Useless peer",
DiscTooManyPeers: "Too many peers",
DiscAlreadyConnected: "Already connected",
DiscIncompatibleVersion: "Incompatible P2P protocol version",
DiscInvalidIdentity: "Invalid node identity",
DiscQuitting: "Client quitting",
DiscUnexpectedIdentity: "Unexpected identity",
DiscSelf: "Connected to self",
DiscReadTimeout: "Read timeout",
DiscSubprotocolError: "Subprotocol error",
} }
func (d DiscReason) String() string { func (cap Cap) RlpData() interface{} {
if len(discReasonToString) < int(d) { return []interface{}{cap.Name, cap.Version}
return "Unknown"
}
return discReasonToString[d]
} }
func (bp *baseProtocol) Offset() MsgCode { type capsByName []Cap
return baseProtocolOffset
func (cs capsByName) Len() int { return len(cs) }
func (cs capsByName) Less(i, j int) bool { return cs[i].Name < cs[j].Name }
func (cs capsByName) Swap(i, j int) { cs[i], cs[j] = cs[j], cs[i] }
type baseProtocol struct {
rw MsgReadWriter
peer *Peer
} }
func (bp *baseProtocol) Start(peer *Peer, rw MsgReadWriter) error { func runBaseProtocol(peer *Peer, rw MsgReadWriter) error {
bp.peer, bp.rw = peer, rw bp := &baseProtocol{rw, peer}
// Do the handshake. // do handshake
// TODO: disconnect is valid before handshake, too. if err := rw.WriteMsg(bp.handshakeMsg()); err != nil {
rw.WriteMsg(bp.peer.server.handshakeMsg()) return err
}
msg, err := rw.ReadMsg() msg, err := rw.ReadMsg()
if err != nil { if err != nil {
return err return err
} }
if msg.Code != handshakeMsg { if msg.Code != handshakeMsg {
return NewPeerError(ProtocolBreach, " first message must be handshake") return newPeerError(errProtocolBreach, "first message must be handshake, got %x", msg.Code)
} }
data, err := msg.Data() data, err := msg.Data()
if err != nil { if err != nil {
return NewPeerError(InvalidMsg, "%v", err) return newPeerError(errInvalidMsg, "%v", err)
} }
if err := bp.handleHandshake(data); err != nil { if err := bp.handleHandshake(data); err != nil {
return err return err
} }
msgin := make(chan bpMsg) // run main loop
done := make(chan error, 1) quit := make(chan error, 1)
go func() { go func() {
done <- MsgLoop(rw, baseProtocolMaxMsgSize, quit <- MsgLoop(rw, baseProtocolMaxMsgSize, bp.handle)
func(code MsgCode, data *ethutil.Value) error {
msgin <- bpMsg{code, data}
return nil
})
}() }()
return bp.loop(msgin, done) return bp.loop(quit)
} }
func (bp *baseProtocol) loop(msgin <-chan bpMsg, quit <-chan error) error { var pingTimeout = 2 * time.Second
logger.Debugf("pingpong keepalive started at %v\n", time.Now())
messenger := bp.rw.(*proto).messenger func (bp *baseProtocol) loop(quit <-chan error) error {
pingTimer := time.NewTimer(pingTimeout) ping := time.NewTimer(pingTimeout)
pinged := true activity := bp.peer.activity.Subscribe(time.Time{})
lastActive := time.Time{}
defer ping.Stop()
defer activity.Unsubscribe()
for { getPeersTick := time.NewTicker(10 * time.Second)
defer getPeersTick.Stop()
err := bp.rw.EncodeMsg(getPeersMsg)
for err == nil {
select { select {
case msg := <-msgin: case err = <-quit:
if err := bp.handle(msg.code, msg.data); err != nil {
return err
}
case err := <-quit:
return err return err
case <-messenger.pulse: case <-getPeersTick.C:
pingTimer.Reset(pingTimeout) err = bp.rw.EncodeMsg(getPeersMsg)
pinged = false case event := <-activity.Chan():
case <-pingTimer.C: ping.Reset(pingTimeout)
if pinged { lastActive = event.(time.Time)
return NewPeerError(PingTimeout, "") case t := <-ping.C:
if lastActive.Add(pingTimeout * 2).Before(t) {
err = newPeerError(errPingTimeout, "")
} else if lastActive.Add(pingTimeout).Before(t) {
err = bp.rw.EncodeMsg(pingMsg)
} }
logger.Debugf("pinging at %v\n", time.Now())
if err := bp.rw.WriteMsg(NewMsg(pingMsg)); err != nil {
return NewPeerError(WriteError, "%v", err)
}
pinged = true
pingTimer.Reset(pingTimeout)
} }
} }
return err
} }
func (bp *baseProtocol) handle(code MsgCode, data *ethutil.Value) error { func (bp *baseProtocol) handle(code uint64, data *ethutil.Value) error {
switch code { switch code {
case handshakeMsg: case handshakeMsg:
return NewPeerError(ProtocolBreach, " extra handshake received") return newPeerError(errProtocolBreach, "extra handshake received")
case discMsg: case discMsg:
logger.Infof("Disconnect requested from peer %v, reason", DiscReason(data.Get(0).Uint())) bp.peer.Disconnect(DiscReason(data.Get(0).Uint()))
bp.peer.server.PeerDisconnect() <- DisconnectRequest{ return nil
addr: bp.peer.Address,
reason: DiscRequested,
}
case pingMsg: case pingMsg:
return bp.rw.WriteMsg(NewMsg(pongMsg)) return bp.rw.EncodeMsg(pongMsg)
case pongMsg: case pongMsg:
// reply for ping
case getPeersMsg: case getPeersMsg:
// Peer asked for list of connected peers. peers := bp.peerList()
peersRLP := bp.peer.server.encodedPeerList() // this is dangerous. the spec says that we should _delay_
if peersRLP != nil { // sending the response if no new information is available.
msg := Msg{ // this means that would need to send a response later when
Code: peersMsg, // new peers become available.
Size: uint32(len(peersRLP)), //
Payload: bytes.NewReader(peersRLP), // TODO: add event mechanism to notify baseProtocol for new peers
} if len(peers) > 0 {
return bp.rw.WriteMsg(msg) return bp.rw.EncodeMsg(peersMsg, peers)
} }
case peersMsg: case peersMsg:
bp.handlePeers(data) bp.handlePeers(data)
default: default:
return NewPeerError(InvalidMsgCode, "unknown message code %v", code) return newPeerError(errInvalidMsgCode, "unknown message code %v", code)
} }
return nil return nil
} }
@ -253,62 +189,102 @@ func (bp *baseProtocol) handle(code MsgCode, data *ethutil.Value) error {
func (bp *baseProtocol) handlePeers(data *ethutil.Value) { func (bp *baseProtocol) handlePeers(data *ethutil.Value) {
it := data.NewIterator() it := data.NewIterator()
for it.Next() { for it.Next() {
ip := net.IP(it.Value().Get(0).Bytes()) addr := &peerAddr{
port := it.Value().Get(1).Uint() IP: net.IP(it.Value().Get(0).Bytes()),
address := &net.TCPAddr{IP: ip, Port: int(port)} Port: it.Value().Get(1).Uint(),
go bp.peer.server.PeerConnect(address) Pubkey: it.Value().Get(2).Bytes(),
}
bp.peer.Debugf("received peer suggestion: %v", addr)
bp.peer.newPeerAddr <- addr
} }
} }
func (bp *baseProtocol) handleHandshake(c *ethutil.Value) error { func (bp *baseProtocol) handleHandshake(c *ethutil.Value) error {
var ( hs := handshake{
remoteVersion = c.Get(0).Uint() Version: c.Get(0).Uint(),
id = c.Get(1).Str() ID: c.Get(1).Str(),
caps = c.Get(2) Caps: nil, // decoded below
port = c.Get(3).Uint() ListenPort: c.Get(3).Uint(),
pubkey = c.Get(4).Bytes() NodeID: c.Get(4).Bytes(),
)
// Check correctness of p2p protocol version
if remoteVersion != p2pVersion {
return NewPeerError(P2PVersionMismatch, "Require protocol %d, received %d\n", p2pVersion, remoteVersion)
} }
if hs.Version != baseProtocolVersion {
// Handle the pub key (validation, uniqueness) return newPeerError(errP2PVersionMismatch, "Require protocol %d, received %d\n",
if len(pubkey) == 0 { baseProtocolVersion, hs.Version)
return NewPeerError(PubkeyMissing, "not supplied in handshake.")
} }
if len(hs.NodeID) == 0 {
if len(pubkey) != 64 { return newPeerError(errPubkeyMissing, "")
return NewPeerError(PubkeyInvalid, "require 512 bit, got %v", len(pubkey)*8) }
if len(hs.NodeID) != 64 {
return newPeerError(errPubkeyInvalid, "require 512 bit, got %v", len(hs.NodeID)*8)
}
if da := bp.peer.dialAddr; da != nil {
// verify that the peer we wanted to connect to
// actually holds the target public key.
if da.Pubkey != nil && !bytes.Equal(da.Pubkey, hs.NodeID) {
return newPeerError(errPubkeyForbidden, "dial address pubkey mismatch")
}
}
pa := newPeerAddr(bp.peer.conn.RemoteAddr(), hs.NodeID)
if err := bp.peer.pubkeyHook(pa); err != nil {
return newPeerError(errPubkeyForbidden, "%v", err)
}
capsIt := c.Get(2).NewIterator()
for capsIt.Next() {
cap := capsIt.Value()
name := cap.Get(0).Str()
if name != "" {
hs.Caps = append(hs.Caps, Cap{Name: name, Version: uint(cap.Get(1).Uint())})
}
} }
// self connect detection var addr *peerAddr
if bytes.Compare(bp.peer.server.ClientIdentity().Pubkey()[1:], pubkey) == 0 { if hs.ListenPort != 0 {
return NewPeerError(PubkeyForbidden, "not allowed to connect to self") addr = newPeerAddr(bp.peer.conn.RemoteAddr(), hs.NodeID)
addr.Port = hs.ListenPort
} }
bp.peer.setHandshakeInfo(&hs, addr, hs.Caps)
bp.peer.startSubprotocols(hs.Caps)
return nil
}
// register pubkey on server. this also sets the pubkey on the peer (need lock) func (bp *baseProtocol) handshakeMsg() Msg {
if err := bp.peer.server.RegisterPubkey(bp.peer, pubkey); err != nil { var (
return NewPeerError(PubkeyForbidden, err.Error()) port uint64
caps []interface{}
)
if bp.peer.ourListenAddr != nil {
port = bp.peer.ourListenAddr.Port
} }
for _, proto := range bp.peer.protocols {
caps = append(caps, proto.cap())
}
return NewMsg(handshakeMsg,
baseProtocolVersion,
bp.peer.ourID.String(),
caps,
port,
bp.peer.ourID.Pubkey()[1:],
)
}
// check port func (bp *baseProtocol) peerList() []ethutil.RlpEncodable {
if bp.peer.Inbound { peers := bp.peer.otherPeers()
uint16port := uint16(port) ds := make([]ethutil.RlpEncodable, 0, len(peers))
if bp.peer.Port > 0 && bp.peer.Port != uint16port { for _, p := range peers {
return NewPeerError(PortMismatch, "port mismatch: %v != %v", bp.peer.Port, port) p.infolock.Lock()
} else { addr := p.listenAddr
bp.peer.Port = uint16port p.infolock.Unlock()
// filter out this peer and peers that are not listening or
// have not completed the handshake.
// TODO: track previously sent peers and exclude them as well.
if p == bp.peer || addr == nil {
continue
} }
ds = append(ds, addr)
} }
ourAddr := bp.peer.ourListenAddr
capsIt := caps.NewIterator() if ourAddr != nil && !ourAddr.IP.IsLoopback() && !ourAddr.IP.IsUnspecified() {
for capsIt.Next() { ds = append(ds, ourAddr)
cap := capsIt.Value().Str()
bp.peer.Caps = append(bp.peer.Caps, cap)
} }
sort.Strings(bp.peer.Caps) return ds
bp.rw.(*proto).messenger.setRemoteProtocols(bp.peer.Caps)
bp.peer.Id = id
return nil
} }

@ -2,155 +2,101 @@ package p2p
import ( import (
"bytes" "bytes"
"errors"
"fmt" "fmt"
"net" "net"
"sort"
"strconv"
"sync" "sync"
"time" "time"
logpkg "github.com/ethereum/go-ethereum/logger" "github.com/ethereum/go-ethereum/logger"
) )
const ( const (
outboundAddressPoolSize = 10 outboundAddressPoolSize = 500
disconnectGracePeriod = 2 defaultDialTimeout = 10 * time.Second
portMappingUpdateInterval = 15 * time.Minute
portMappingTimeout = 20 * time.Minute
) )
type Blacklist interface { var srvlog = logger.NewLogger("P2P Server")
Get([]byte) (bool, error)
Put([]byte) error
Delete([]byte) error
Exists(pubkey []byte) (ok bool)
}
type BlacklistMap struct {
blacklist map[string]bool
lock sync.RWMutex
}
func NewBlacklist() *BlacklistMap {
return &BlacklistMap{
blacklist: make(map[string]bool),
}
}
func (self *BlacklistMap) Get(pubkey []byte) (bool, error) {
self.lock.RLock()
defer self.lock.RUnlock()
v, ok := self.blacklist[string(pubkey)]
var err error
if !ok {
err = fmt.Errorf("not found")
}
return v, err
}
func (self *BlacklistMap) Exists(pubkey []byte) (ok bool) {
self.lock.RLock()
defer self.lock.RUnlock()
_, ok = self.blacklist[string(pubkey)]
return
}
func (self *BlacklistMap) Put(pubkey []byte) error {
self.lock.RLock()
defer self.lock.RUnlock()
self.blacklist[string(pubkey)] = true
return nil
}
func (self *BlacklistMap) Delete(pubkey []byte) error {
self.lock.RLock()
defer self.lock.RUnlock()
delete(self.blacklist, string(pubkey))
return nil
}
// Server manages all peer connections.
//
// The fields of Server are used as configuration parameters.
// You should set them before starting the Server. Fields may not be
// modified while the server is running.
type Server struct { type Server struct {
network Network // This field must be set to a valid client identity.
listening bool //needed? Identity ClientIdentity
dialing bool //needed?
closed bool // MaxPeers is the maximum number of peers that can be
identity ClientIdentity // connected. It must be greater than zero.
addr net.Addr MaxPeers int
port uint16
protocols []string // Protocols should contain the protocols supported
// by the server. Matching protocols are launched for
quit chan chan bool // each peer.
peersLock sync.RWMutex Protocols []Protocol
maxPeers int // If Blacklist is set to a non-nil value, the given Blacklist
peers []*Peer // is used to verify peer connections.
peerSlots chan int Blacklist Blacklist
peersTable map[string]int
peerCount int // If ListenAddr is set to a non-nil address, the server
cachedEncodedPeers []byte // will listen for incoming connections.
//
peerConnect chan net.Addr // If the port is zero, the operating system will pick a port. The
peerDisconnect chan DisconnectRequest // ListenAddr field will be updated with the actual address when
blacklist Blacklist // the server is started.
handlers Handlers ListenAddr string
}
// If set to a non-nil value, the given NAT port mapper
var logger = logpkg.NewLogger("P2P") // is used to make the listening port available to the
// Internet.
func New(network Network, addr net.Addr, identity ClientIdentity, handlers Handlers, maxPeers int, blacklist Blacklist) *Server { NAT NAT
// get alphabetical list of protocol names from handlers map
protocols := []string{} // If Dialer is set to a non-nil value, the given Dialer
for protocol := range handlers { // is used to dial outbound peer connections.
protocols = append(protocols, protocol) Dialer *net.Dialer
}
sort.Strings(protocols) // If NoDial is true, the server will not dial any peers.
NoDial bool
_, port, _ := net.SplitHostPort(addr.String())
intport, _ := strconv.Atoi(port) // Hook for testing. This is useful because we can inhibit
// the whole protocol stack.
self := &Server{ newPeerFunc peerFunc
// NewSimpleClientIdentity(clientIdentifier, version, customIdentifier)
network: network,
identity: identity,
addr: addr,
port: uint16(intport),
protocols: protocols,
quit: make(chan chan bool),
maxPeers: maxPeers,
peers: make([]*Peer, maxPeers),
peerSlots: make(chan int, maxPeers),
peersTable: make(map[string]int),
peerConnect: make(chan net.Addr, outboundAddressPoolSize), lock sync.RWMutex
peerDisconnect: make(chan DisconnectRequest), running bool
blacklist: blacklist, listener net.Listener
laddr *net.TCPAddr // real listen addr
handlers: handlers, peers []*Peer
} peerSlots chan int
for i := 0; i < maxPeers; i++ { peerCount int
self.peerSlots <- i // fill up with indexes
} quit chan struct{}
return self wg sync.WaitGroup
peerConnect chan *peerAddr
peerDisconnect chan *Peer
} }
func (self *Server) NewAddr(host string, port int) (addr net.Addr, err error) { // NAT is implemented by NAT traversal methods.
addr, err = self.network.NewAddr(host, port) type NAT interface {
return GetExternalAddress() (net.IP, error)
} AddPortMapping(protocol string, extport, intport int, name string, lifetime time.Duration) error
DeletePortMapping(protocol string, extport, intport int) error
func (self *Server) ParseAddr(address string) (addr net.Addr, err error) { // Should return name of the method.
addr, err = self.network.ParseAddr(address) String() string
return
} }
func (self *Server) ClientIdentity() ClientIdentity { type peerFunc func(srv *Server, c net.Conn, dialAddr *peerAddr) *Peer
return self.identity
}
func (self *Server) Peers() (peers []*Peer) { // Peers returns all connected peers.
self.peersLock.RLock() func (srv *Server) Peers() (peers []*Peer) {
defer self.peersLock.RUnlock() srv.lock.RLock()
for _, peer := range self.peers { defer srv.lock.RUnlock()
for _, peer := range srv.peers {
if peer != nil { if peer != nil {
peers = append(peers, peer) peers = append(peers, peer)
} }
@ -158,331 +104,364 @@ func (self *Server) Peers() (peers []*Peer) {
return return
} }
func (self *Server) PeerCount() int { // PeerCount returns the number of connected peers.
self.peersLock.RLock() func (srv *Server) PeerCount() int {
defer self.peersLock.RUnlock() srv.lock.RLock()
return self.peerCount defer srv.lock.RUnlock()
return srv.peerCount
} }
func (self *Server) PeerConnect(addr net.Addr) { // SuggestPeer injects an address into the outbound address pool.
// TODO: should buffer, filter and uniq func (srv *Server) SuggestPeer(ip net.IP, port int, nodeID []byte) {
// send GetPeersMsg if not blocking
select { select {
case self.peerConnect <- addr: // not enough peers case srv.peerConnect <- &peerAddr{ip, uint64(port), nodeID}:
self.Broadcast("", getPeersMsg) default: // don't block
default: // we dont care
} }
} }
func (self *Server) PeerDisconnect() chan DisconnectRequest { // Broadcast sends an RLP-encoded message to all connected peers.
return self.peerDisconnect // This method is deprecated and will be removed later.
} func (srv *Server) Broadcast(protocol string, code uint64, data ...interface{}) {
func (self *Server) Blacklist() Blacklist {
return self.blacklist
}
func (self *Server) Handlers() Handlers {
return self.handlers
}
func (self *Server) Broadcast(protocol string, code MsgCode, data ...interface{}) {
var payload []byte var payload []byte
if data != nil { if data != nil {
payload = encodePayload(data...) payload = encodePayload(data...)
} }
self.peersLock.RLock() srv.lock.RLock()
defer self.peersLock.RUnlock() defer srv.lock.RUnlock()
for _, peer := range self.peers { for _, peer := range srv.peers {
if peer != nil { if peer != nil {
var msg = Msg{Code: code} var msg = Msg{Code: code}
if data != nil { if data != nil {
msg.Payload = bytes.NewReader(payload) msg.Payload = bytes.NewReader(payload)
msg.Size = uint32(len(payload)) msg.Size = uint32(len(payload))
} }
peer.messenger.writeProtoMsg(protocol, msg) peer.writeProtoMsg(protocol, msg)
} }
} }
} }
// Start the server // Start starts running the server.
func (self *Server) Start(listen bool, dial bool) { // Servers can be re-used and started again after stopping.
self.network.Start() func (srv *Server) Start() (err error) {
if listen { srv.lock.Lock()
listener, err := self.network.Listener(self.addr) defer srv.lock.Unlock()
if err != nil { if srv.running {
logger.Warnf("Error initializing listener: %v", err) return errors.New("server already running")
logger.Warnf("Connection listening disabled") }
self.listening = false srvlog.Infoln("Starting Server")
} else {
self.listening = true // initialize fields
logger.Infoln("Listen on %v: ready and accepting connections", listener.Addr()) if srv.Identity == nil {
go self.inboundPeerHandler(listener) return fmt.Errorf("Server.Identity must be set to a non-nil identity")
}
} }
if dial { if srv.MaxPeers <= 0 {
dialer, err := self.network.Dialer(self.addr) return fmt.Errorf("Server.MaxPeers must be > 0")
if err != nil { }
logger.Warnf("Error initializing dialer: %v", err) srv.quit = make(chan struct{})
logger.Warnf("Connection dialout disabled") srv.peers = make([]*Peer, srv.MaxPeers)
self.dialing = false srv.peerSlots = make(chan int, srv.MaxPeers)
} else { srv.peerConnect = make(chan *peerAddr, outboundAddressPoolSize)
self.dialing = true srv.peerDisconnect = make(chan *Peer)
logger.Infoln("Dial peers watching outbound address pool") if srv.newPeerFunc == nil {
go self.outboundPeerHandler(dialer) srv.newPeerFunc = newServerPeer
}
if srv.Blacklist == nil {
srv.Blacklist = NewBlacklist()
}
if srv.Dialer == nil {
srv.Dialer = &net.Dialer{Timeout: defaultDialTimeout}
}
if srv.ListenAddr != "" {
if err := srv.startListening(); err != nil {
return err
} }
} }
logger.Infoln("server started") if !srv.NoDial {
srv.wg.Add(1)
go srv.dialLoop()
}
if srv.NoDial && srv.ListenAddr == "" {
srvlog.Warnln("I will be kind-of useless, neither dialing nor listening.")
}
// make all slots available
for i := range srv.peers {
srv.peerSlots <- i
}
// note: discLoop is not part of WaitGroup
go srv.discLoop()
srv.running = true
return nil
} }
func (self *Server) Stop() { func (srv *Server) startListening() error {
logger.Infoln("server stopping...") listener, err := net.Listen("tcp", srv.ListenAddr)
// // quit one loop if dialing if err != nil {
if self.dialing { return err
logger.Infoln("stop dialout...") }
dialq := make(chan bool) srv.ListenAddr = listener.Addr().String()
self.quit <- dialq srv.laddr = listener.Addr().(*net.TCPAddr)
<-dialq srv.listener = listener
fmt.Println("quit another") srv.wg.Add(1)
} go srv.listenLoop()
// quit the other loop if listening if !srv.laddr.IP.IsLoopback() && srv.NAT != nil {
if self.listening { srv.wg.Add(1)
logger.Infoln("stop listening...") go srv.natLoop(srv.laddr.Port)
listenq := make(chan bool) }
self.quit <- listenq return nil
<-listenq }
fmt.Println("quit one")
} // Stop terminates the server and all active peer connections.
// It blocks until all active connections have been closed.
fmt.Println("quit waited") func (srv *Server) Stop() {
srv.lock.Lock()
logger.Infoln("stopping peers...") if !srv.running {
peers := []net.Addr{} srv.lock.Unlock()
self.peersLock.RLock() return
self.closed = true
for _, peer := range self.peers {
if peer != nil {
peers = append(peers, peer.Address)
}
} }
self.peersLock.RUnlock() srv.running = false
for _, address := range peers { srv.lock.Unlock()
go self.removePeer(DisconnectRequest{
addr: address, srvlog.Infoln("Stopping server")
reason: DiscQuitting, if srv.listener != nil {
}) // this unblocks listener Accept
srv.listener.Close()
}
close(srv.quit)
for _, peer := range srv.Peers() {
peer.Disconnect(DiscQuitting)
} }
srv.wg.Wait()
// wait till they actually disconnect // wait till they actually disconnect
// this is checked by draining the peerSlots (slots are released back if a peer is removed) // this is checked by claiming all peerSlots.
i := 0 // slots become available as the peers disconnect.
fmt.Println("draining peers") for i := 0; i < cap(srv.peerSlots); i++ {
<-srv.peerSlots
}
// terminate discLoop
close(srv.peerDisconnect)
}
func (srv *Server) discLoop() {
for peer := range srv.peerDisconnect {
// peer has just disconnected. free up its slot.
srvlog.Infof("%v is gone", peer)
srv.peerSlots <- peer.slot
srv.lock.Lock()
srv.peers[peer.slot] = nil
srv.lock.Unlock()
}
}
FOR: // main loop for adding connections via listening
func (srv *Server) listenLoop() {
defer srv.wg.Done()
srvlog.Infoln("Listening on", srv.listener.Addr())
for { for {
select { select {
case slot := <-self.peerSlots: case slot := <-srv.peerSlots:
i++ conn, err := srv.listener.Accept()
fmt.Printf("%v: found slot %v\n", i, slot) if err != nil {
if i == self.maxPeers { srv.peerSlots <- slot
break FOR return
} }
srvlog.Debugf("Accepted conn %v (slot %d)\n", conn.RemoteAddr(), slot)
srv.addPeer(conn, nil, slot)
case <-srv.quit:
return
} }
} }
logger.Infoln("server stopped")
} }
// main loop for adding connections via listening func (srv *Server) natLoop(port int) {
func (self *Server) inboundPeerHandler(listener net.Listener) { defer srv.wg.Done()
for { for {
srv.updatePortMapping(port)
select { select {
case slot := <-self.peerSlots: case <-time.After(portMappingUpdateInterval):
go self.connectInboundPeer(listener, slot) // one more round
case errc := <-self.quit: case <-srv.quit:
listener.Close() srv.removePortMapping(port)
fmt.Println("quit listenloop")
errc <- true
return return
} }
} }
} }
// main loop for adding outbound peers based on peerConnect address pool func (srv *Server) updatePortMapping(port int) {
// this same loop handles peer disconnect requests as well srvlog.Infoln("Attempting to map port", port, "with", srv.NAT)
func (self *Server) outboundPeerHandler(dialer Dialer) { err := srv.NAT.AddPortMapping("tcp", port, port, "ethereum p2p", portMappingTimeout)
// addressChan initially set to nil (only watches peerConnect if we need more peers) if err != nil {
var addressChan chan net.Addr srvlog.Errorln("Port mapping error:", err)
slots := self.peerSlots return
var slot *int }
extip, err := srv.NAT.GetExternalAddress()
if err != nil {
srvlog.Errorln("Error getting external IP:", err)
return
}
srv.lock.Lock()
extaddr := *(srv.listener.Addr().(*net.TCPAddr))
extaddr.IP = extip
srvlog.Infoln("Mapped port, external addr is", &extaddr)
srv.laddr = &extaddr
srv.lock.Unlock()
}
func (srv *Server) removePortMapping(port int) {
srvlog.Infoln("Removing port mapping for", port, "with", srv.NAT)
srv.NAT.DeletePortMapping("tcp", port, port)
}
func (srv *Server) dialLoop() {
defer srv.wg.Done()
var (
suggest chan *peerAddr
slot *int
slots = srv.peerSlots
)
for { for {
select { select {
case i := <-slots: case i := <-slots:
// we need a peer in slot i, slot reserved // we need a peer in slot i, slot reserved
slot = &i slot = &i
// now we can watch for candidate peers in the next loop // now we can watch for candidate peers in the next loop
addressChan = self.peerConnect suggest = srv.peerConnect
// do not consume more until candidate peer is found // do not consume more until candidate peer is found
slots = nil slots = nil
case address := <-addressChan:
case desc := <-suggest:
// candidate peer found, will dial out asyncronously // candidate peer found, will dial out asyncronously
// if connection fails slot will be released // if connection fails slot will be released
go self.connectOutboundPeer(dialer, address, *slot) go srv.dialPeer(desc, *slot)
// we can watch if more peers needed in the next loop // we can watch if more peers needed in the next loop
slots = self.peerSlots slots = srv.peerSlots
// until then we dont care about candidate peers // until then we dont care about candidate peers
addressChan = nil suggest = nil
case request := <-self.peerDisconnect:
go self.removePeer(request) case <-srv.quit:
case errc := <-self.quit: // give back the currently reserved slot
if addressChan != nil && slot != nil { if slot != nil {
self.peerSlots <- *slot srv.peerSlots <- *slot
} }
fmt.Println("quit dialloop")
errc <- true
return return
} }
} }
} }
// check if peer address already connected
func (self *Server) isConnected(address net.Addr) bool {
self.peersLock.RLock()
defer self.peersLock.RUnlock()
_, found := self.peersTable[address.String()]
return found
}
// connect to peer via listener.Accept()
func (self *Server) connectInboundPeer(listener net.Listener, slot int) {
var address net.Addr
conn, err := listener.Accept()
if err != nil {
logger.Debugln(err)
self.peerSlots <- slot
return
}
address = conn.RemoteAddr()
// XXX: this won't work because the remote socket
// address does not identify the peer. we should
// probably get rid of this check and rely on public
// key detection in the base protocol.
if self.isConnected(address) {
conn.Close()
self.peerSlots <- slot
return
}
fmt.Printf("adding %v\n", address)
go self.addPeer(conn, address, true, slot)
}
// connect to peer via dial out // connect to peer via dial out
func (self *Server) connectOutboundPeer(dialer Dialer, address net.Addr, slot int) { func (srv *Server) dialPeer(desc *peerAddr, slot int) {
if self.isConnected(address) { srvlog.Debugf("Dialing %v (slot %d)\n", desc, slot)
return conn, err := srv.Dialer.Dial(desc.Network(), desc.String())
}
conn, err := dialer.Dial(address.Network(), address.String())
if err != nil { if err != nil {
self.peerSlots <- slot srvlog.Errorf("Dial error: %v", err)
srv.peerSlots <- slot
return return
} }
go self.addPeer(conn, address, false, slot) go srv.addPeer(conn, desc, slot)
} }
// creates the new peer object and inserts it into its slot // creates the new peer object and inserts it into its slot
func (self *Server) addPeer(conn net.Conn, address net.Addr, inbound bool, slot int) *Peer { func (srv *Server) addPeer(conn net.Conn, desc *peerAddr, slot int) *Peer {
self.peersLock.Lock() srv.lock.Lock()
defer self.peersLock.Unlock() defer srv.lock.Unlock()
if self.closed { if !srv.running {
fmt.Println("oopsy, not no longer need peer") conn.Close()
conn.Close() //oopsy our bad srv.peerSlots <- slot // release slot
self.peerSlots <- slot // release slot
return nil return nil
} }
logger.Infoln("adding new peer", address) peer := srv.newPeerFunc(srv, conn, desc)
peer := NewPeer(conn, address, inbound, self) peer.slot = slot
self.peers[slot] = peer srv.peers[slot] = peer
self.peersTable[address.String()] = slot srv.peerCount++
self.peerCount++ go func() { peer.loop(); srv.peerDisconnect <- peer }()
self.cachedEncodedPeers = nil
fmt.Printf("added peer %v %v (slot %v)\n", address, peer, slot)
peer.Start()
return peer return peer
} }
// removes peer: sending disconnect msg, stop peer, remove rom list/table, release slot // removes peer: sending disconnect msg, stop peer, remove rom list/table, release slot
func (self *Server) removePeer(request DisconnectRequest) { func (srv *Server) removePeer(peer *Peer) {
self.peersLock.Lock() srv.lock.Lock()
defer srv.lock.Unlock()
address := request.addr srvlog.Debugf("Removing peer %v %v (slot %v)\n", peer, peer.slot)
slot := self.peersTable[address.String()] if srv.peers[peer.slot] != peer {
peer := self.peers[slot] srvlog.Warnln("Invalid peer to remove:", peer)
fmt.Printf("removing peer %v %v (slot %v)\n", address, peer, slot)
if peer == nil {
logger.Debugf("already removed peer on %v", address)
self.peersLock.Unlock()
return return
} }
// remove from list and index // remove from list and index
self.peerCount-- srv.peerCount--
self.peers[slot] = nil srv.peers[peer.slot] = nil
delete(self.peersTable, address.String())
self.cachedEncodedPeers = nil
fmt.Printf("removed peer %v (slot %v)\n", peer, slot)
self.peersLock.Unlock()
// sending disconnect message
disconnectMsg := NewMsg(discMsg, request.reason)
peer.Write("", disconnectMsg)
// be nice and wait
time.Sleep(disconnectGracePeriod * time.Second)
// switch off peer and close connections etc.
fmt.Println("stopping peer")
peer.Stop()
fmt.Println("stopped peer")
// release slot to signal need for a new peer, last! // release slot to signal need for a new peer, last!
self.peerSlots <- slot srv.peerSlots <- peer.slot
} }
// encodedPeerList returns an RLP-encoded list of peers. func (srv *Server) verifyPeer(addr *peerAddr) error {
// the returned slice will be nil if there are no peers. if srv.Blacklist.Exists(addr.Pubkey) {
func (self *Server) encodedPeerList() []byte { return errors.New("blacklisted")
// TODO: memoize and reset when peers change }
self.peersLock.RLock() if bytes.Equal(srv.Identity.Pubkey()[1:], addr.Pubkey) {
defer self.peersLock.RUnlock() return newPeerError(errPubkeyForbidden, "not allowed to connect to srv")
if self.cachedEncodedPeers == nil && self.peerCount > 0 { }
var peerData []interface{} srv.lock.RLock()
for _, i := range self.peersTable { defer srv.lock.RUnlock()
peer := self.peers[i] for _, peer := range srv.peers {
peerData = append(peerData, peer.Encode()) if peer != nil {
id := peer.Identity()
if id != nil && bytes.Equal(id.Pubkey(), addr.Pubkey) {
return errors.New("already connected")
}
} }
self.cachedEncodedPeers = encodePayload(peerData)
} }
return self.cachedEncodedPeers return nil
} }
// fix handshake message to push to peers type Blacklist interface {
func (self *Server) handshakeMsg() Msg { Get([]byte) (bool, error)
return NewMsg(handshakeMsg, Put([]byte) error
p2pVersion, Delete([]byte) error
[]byte(self.identity.String()), Exists(pubkey []byte) (ok bool)
[]interface{}{self.protocols}, }
self.port,
self.identity.Pubkey()[1:], type BlacklistMap struct {
) blacklist map[string]bool
lock sync.RWMutex
} }
func (self *Server) RegisterPubkey(candidate *Peer, pubkey []byte) error { func NewBlacklist() *BlacklistMap {
// Check for blacklisting return &BlacklistMap{
if self.blacklist.Exists(pubkey) { blacklist: make(map[string]bool),
return fmt.Errorf("blacklisted")
} }
}
self.peersLock.RLock() func (self *BlacklistMap) Get(pubkey []byte) (bool, error) {
defer self.peersLock.RUnlock() self.lock.RLock()
for _, peer := range self.peers { defer self.lock.RUnlock()
if peer != nil && peer != candidate && bytes.Compare(peer.Pubkey, pubkey) == 0 { v, ok := self.blacklist[string(pubkey)]
return fmt.Errorf("already connected") var err error
} if !ok {
err = fmt.Errorf("not found")
} }
candidate.Pubkey = pubkey return v, err
}
func (self *BlacklistMap) Exists(pubkey []byte) (ok bool) {
self.lock.RLock()
defer self.lock.RUnlock()
_, ok = self.blacklist[string(pubkey)]
return
}
func (self *BlacklistMap) Put(pubkey []byte) error {
self.lock.RLock()
defer self.lock.RUnlock()
self.blacklist[string(pubkey)] = true
return nil
}
func (self *BlacklistMap) Delete(pubkey []byte) error {
self.lock.RLock()
defer self.lock.RUnlock()
delete(self.blacklist, string(pubkey))
return nil return nil
} }

@ -1,289 +1,161 @@
package p2p package p2p
import ( import (
"fmt" "bytes"
"io" "io"
"net" "net"
"sync"
"testing" "testing"
"time" "time"
) )
type TestNetwork struct { func startTestServer(t *testing.T, pf peerFunc) *Server {
connections map[string]*TestNetworkConnection server := &Server{
dialer Dialer Identity: NewSimpleClientIdentity("clientIdentifier", "version", "customIdentifier", "pubkey"),
maxinbound int MaxPeers: 10,
} ListenAddr: "127.0.0.1:0",
newPeerFunc: pf,
func NewTestNetwork(maxinbound int) *TestNetwork {
connections := make(map[string]*TestNetworkConnection)
return &TestNetwork{
connections: connections,
dialer: &TestDialer{connections},
maxinbound: maxinbound,
} }
} if err := server.Start(); err != nil {
t.Fatalf("Could not start server: %v", err)
func (self *TestNetwork) Dialer(addr net.Addr) (Dialer, error) {
return self.dialer, nil
}
func (self *TestNetwork) Listener(addr net.Addr) (net.Listener, error) {
return &TestListener{
connections: self.connections,
addr: addr,
max: self.maxinbound,
close: make(chan struct{}),
}, nil
}
func (self *TestNetwork) Start() error {
return nil
}
func (self *TestNetwork) NewAddr(string, int) (addr net.Addr, err error) {
return
}
func (self *TestNetwork) ParseAddr(string) (addr net.Addr, err error) {
return
}
type TestAddr struct {
name string
}
func (self *TestAddr) String() string {
return self.name
}
func (*TestAddr) Network() string {
return "test"
}
type TestDialer struct {
connections map[string]*TestNetworkConnection
}
func (self *TestDialer) Dial(network string, addr string) (conn net.Conn, err error) {
address := &TestAddr{addr}
tconn := NewTestNetworkConnection(address)
self.connections[addr] = tconn
conn = net.Conn(tconn)
return
}
type TestListener struct {
connections map[string]*TestNetworkConnection
addr net.Addr
max int
i int
close chan struct{}
}
func (self *TestListener) Accept() (net.Conn, error) {
self.i++
if self.i > self.max {
<-self.close
return nil, io.EOF
} }
addr := &TestAddr{fmt.Sprintf("inboundpeer-%d", self.i)} return server
tconn := NewTestNetworkConnection(addr)
key := tconn.RemoteAddr().String()
self.connections[key] = tconn
fmt.Printf("accepted connection from: %v \n", addr)
return tconn, nil
}
func (self *TestListener) Close() error {
close(self.close)
return nil
}
func (self *TestListener) Addr() net.Addr {
return self.addr
} }
type TestNetworkConnection struct { func TestServerListen(t *testing.T) {
in chan []byte defer testlog(t).detach()
close chan struct{}
current []byte
Out [][]byte
addr net.Addr
}
func NewTestNetworkConnection(addr net.Addr) *TestNetworkConnection { // start the test server
return &TestNetworkConnection{ connected := make(chan *Peer)
in: make(chan []byte), srv := startTestServer(t, func(srv *Server, conn net.Conn, dialAddr *peerAddr) *Peer {
close: make(chan struct{}), if conn == nil {
current: []byte{}, t.Error("peer func called with nil conn")
Out: [][]byte{}, }
addr: addr, if dialAddr != nil {
t.Error("peer func called with non-nil dialAddr")
}
peer := newPeer(conn, nil, dialAddr)
connected <- peer
return peer
})
defer close(connected)
defer srv.Stop()
// dial the test server
conn, err := net.DialTimeout("tcp", srv.ListenAddr, 5*time.Second)
if err != nil {
t.Fatalf("could not dial: %v", err)
} }
} defer conn.Close()
func (self *TestNetworkConnection) In(latency time.Duration, packets ...[]byte) { select {
time.Sleep(latency) case peer := <-connected:
for _, s := range packets { if peer.conn.LocalAddr().String() != conn.RemoteAddr().String() {
self.in <- s t.Errorf("peer started with wrong conn: got %v, want %v",
peer.conn.LocalAddr(), conn.RemoteAddr())
}
case <-time.After(1 * time.Second):
t.Error("server did not accept within one second")
} }
} }
func (self *TestNetworkConnection) Read(buff []byte) (n int, err error) { func TestServerDial(t *testing.T) {
if len(self.current) == 0 { defer testlog(t).detach()
var ok bool
// run a fake TCP server to handle the connection.
listener, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("could not setup listener: %v")
}
defer listener.Close()
accepted := make(chan net.Conn)
go func() {
conn, err := listener.Accept()
if err != nil {
t.Error("acccept error:", err)
}
conn.Close()
accepted <- conn
}()
// start the test server
connected := make(chan *Peer)
srv := startTestServer(t, func(srv *Server, conn net.Conn, dialAddr *peerAddr) *Peer {
if conn == nil {
t.Error("peer func called with nil conn")
}
peer := newPeer(conn, nil, dialAddr)
connected <- peer
return peer
})
defer close(connected)
defer srv.Stop()
// tell the server to connect.
connAddr := newPeerAddr(listener.Addr(), nil)
srv.peerConnect <- connAddr
select {
case conn := <-accepted:
select { select {
case self.current, ok = <-self.in: case peer := <-connected:
if !ok { if peer.conn.RemoteAddr().String() != conn.LocalAddr().String() {
return 0, io.EOF t.Errorf("peer started with wrong conn: got %v, want %v",
peer.conn.RemoteAddr(), conn.LocalAddr())
}
if peer.dialAddr != connAddr {
t.Errorf("peer started with wrong dialAddr: got %v, want %v",
peer.dialAddr, connAddr)
} }
case <-self.close: case <-time.After(1 * time.Second):
return 0, io.EOF t.Error("server did not launch peer within one second")
} }
}
length := len(self.current)
if length > len(buff) {
copy(buff[:], self.current[:len(buff)])
self.current = self.current[len(buff):]
return len(buff), nil
} else {
copy(buff[:length], self.current[:])
self.current = []byte{}
return length, io.EOF
}
}
func (self *TestNetworkConnection) Write(buff []byte) (n int, err error) {
self.Out = append(self.Out, buff)
fmt.Printf("net write(%d): %x\n", len(self.Out), buff)
return len(buff), nil
}
func (self *TestNetworkConnection) Close() error {
close(self.close)
return nil
}
func (self *TestNetworkConnection) LocalAddr() (addr net.Addr) {
return
}
func (self *TestNetworkConnection) RemoteAddr() (addr net.Addr) { case <-time.After(1 * time.Second):
return self.addr t.Error("server did not connect within one second")
}
func (self *TestNetworkConnection) SetDeadline(t time.Time) (err error) {
return
}
func (self *TestNetworkConnection) SetReadDeadline(t time.Time) (err error) {
return
}
func (self *TestNetworkConnection) SetWriteDeadline(t time.Time) (err error) {
return
}
func SetupTestServer(handlers Handlers) (network *TestNetwork, server *Server) {
network = NewTestNetwork(1)
addr := &TestAddr{"test:30303"}
identity := NewSimpleClientIdentity("clientIdentifier", "version", "customIdentifier", "pubkey")
maxPeers := 2
if handlers == nil {
handlers = make(Handlers)
} }
blackist := NewBlacklist()
server = New(network, addr, identity, handlers, maxPeers, blackist)
fmt.Println(server.identity.Pubkey())
return
} }
func TestServerListener(t *testing.T) { func TestServerBroadcast(t *testing.T) {
t.SkipNow() defer testlog(t).detach()
var connected sync.WaitGroup
network, server := SetupTestServer(nil) srv := startTestServer(t, func(srv *Server, c net.Conn, dialAddr *peerAddr) *Peer {
server.Start(true, false) peer := newPeer(c, []Protocol{discard}, dialAddr)
time.Sleep(10 * time.Millisecond) peer.startSubprotocols([]Cap{discard.cap()})
server.Stop() connected.Done()
peer1, ok := network.connections["inboundpeer-1"] return peer
if !ok { })
t.Error("not found inbound peer 1") defer srv.Stop()
} else {
if len(peer1.Out) != 2 { // dial a bunch of conns
t.Errorf("wrong number of writes to peer 1: got %d, want %d", len(peer1.Out), 2) var conns = make([]net.Conn, 8)
connected.Add(len(conns))
deadline := time.Now().Add(3 * time.Second)
dialer := &net.Dialer{Deadline: deadline}
for i := range conns {
conn, err := dialer.Dial("tcp", srv.ListenAddr)
if err != nil {
t.Fatalf("conn %d: dial error: %v", i, err)
} }
defer conn.Close()
conn.SetDeadline(deadline)
conns[i] = conn
} }
} connected.Wait()
func TestServerDialer(t *testing.T) { // broadcast one message
network, server := SetupTestServer(nil) srv.Broadcast("discard", 0, "foo")
server.Start(false, true) goldbuf := new(bytes.Buffer)
server.peerConnect <- &TestAddr{"outboundpeer-1"} writeMsg(goldbuf, NewMsg(16, "foo"))
time.Sleep(10 * time.Millisecond) golden := goldbuf.Bytes()
server.Stop()
peer1, ok := network.connections["outboundpeer-1"] // check that the message has been written everywhere
if !ok { for i, conn := range conns {
t.Error("not found outbound peer 1") buf := make([]byte, len(golden))
} else { if _, err := io.ReadFull(conn, buf); err != nil {
if len(peer1.Out) != 2 { t.Errorf("conn %d: read error: %v", i, err)
t.Errorf("wrong number of writes to peer 1: got %d, want %d", len(peer1.Out), 2) } else if !bytes.Equal(buf, golden) {
t.Errorf("conn %d: msg mismatch\ngot: %x\nwant: %x", i, buf, golden)
} }
} }
} }
// func TestServerBroadcast(t *testing.T) {
// handlers := make(Handlers)
// testProtocol := &TestProtocol{Msgs: []*Msg{}}
// handlers["aaa"] = func(p *Peer) Protocol { return testProtocol }
// network, server := SetupTestServer(handlers)
// server.Start(true, true)
// server.peerConnect <- &TestAddr{"outboundpeer-1"}
// time.Sleep(10 * time.Millisecond)
// msg := NewMsg(0)
// server.Broadcast("", msg)
// packet := Packet(0, 0)
// time.Sleep(10 * time.Millisecond)
// server.Stop()
// peer1, ok := network.connections["outboundpeer-1"]
// if !ok {
// t.Error("not found outbound peer 1")
// } else {
// fmt.Printf("out: %v\n", peer1.Out)
// if len(peer1.Out) != 3 {
// t.Errorf("not enough messages sent to peer 1: %v ", len(peer1.Out))
// } else {
// if bytes.Compare(peer1.Out[1], packet) != 0 {
// t.Errorf("incorrect broadcast packet %v != %v", peer1.Out[1], packet)
// }
// }
// }
// peer2, ok := network.connections["inboundpeer-1"]
// if !ok {
// t.Error("not found inbound peer 2")
// } else {
// fmt.Printf("out: %v\n", peer2.Out)
// if len(peer1.Out) != 3 {
// t.Errorf("not enough messages sent to peer 2: %v ", len(peer2.Out))
// } else {
// if bytes.Compare(peer2.Out[1], packet) != 0 {
// t.Errorf("incorrect broadcast packet %v != %v", peer2.Out[1], packet)
// }
// }
// }
// }
func TestServerPeersMessage(t *testing.T) {
t.SkipNow()
_, server := SetupTestServer(nil)
server.Start(true, true)
defer server.Stop()
server.peerConnect <- &TestAddr{"outboundpeer-1"}
time.Sleep(2000 * time.Millisecond)
pl := server.encodedPeerList()
if pl == nil {
t.Errorf("expect non-nil peer list")
}
if c := server.PeerCount(); c != 2 {
t.Errorf("expect 2 peers, got %v", c)
}
}

@ -0,0 +1,28 @@
package p2p
import (
"testing"
"github.com/ethereum/go-ethereum/logger"
)
type testLogger struct{ t *testing.T }
func testlog(t *testing.T) testLogger {
logger.Reset()
l := testLogger{t}
logger.AddLogSystem(l)
return l
}
func (testLogger) GetLogLevel() logger.LogLevel { return logger.DebugLevel }
func (testLogger) SetLogLevel(logger.LogLevel) {}
func (l testLogger) LogPrint(level logger.LogLevel, msg string) {
l.t.Logf("%s", msg)
}
func (testLogger) detach() {
logger.Flush()
logger.Reset()
}

@ -0,0 +1,40 @@
// +build none
package main
import (
"fmt"
"log"
"net"
"os"
"github.com/ethereum/go-ethereum/logger"
"github.com/ethereum/go-ethereum/p2p"
"github.com/obscuren/secp256k1-go"
)
func main() {
logger.AddLogSystem(logger.NewStdLogSystem(os.Stdout, log.LstdFlags, logger.DebugLevel))
pub, _ := secp256k1.GenerateKeyPair()
srv := p2p.Server{
MaxPeers: 10,
Identity: p2p.NewSimpleClientIdentity("test", "1.0", "", string(pub)),
ListenAddr: ":30303",
NAT: p2p.PMP(net.ParseIP("10.0.0.1")),
}
if err := srv.Start(); err != nil {
fmt.Println("could not start server:", err)
os.Exit(1)
}
// add seed peers
seed, err := net.ResolveTCPAddr("tcp", "poc-7.ethdev.com:30303")
if err != nil {
fmt.Println("couldn't resolve:", err)
os.Exit(1)
}
srv.SuggestPeer(seed.IP, seed.Port, nil)
select {}
}
Loading…
Cancel
Save