mirror of https://github.com/ethereum/go-ethereum
Whoa, one more big commit. I didn't manage to untangle the changes while working towards compatibility.pull/180/head
parent
e4a601c644
commit
59b63caf5e
@ -1,221 +0,0 @@ |
||||
package p2p |
||||
|
||||
import ( |
||||
"bufio" |
||||
"bytes" |
||||
"fmt" |
||||
"io" |
||||
"io/ioutil" |
||||
"net" |
||||
"sync" |
||||
"time" |
||||
) |
||||
|
||||
type Handlers map[string]Protocol |
||||
|
||||
type proto struct { |
||||
in chan Msg |
||||
maxcode, offset MsgCode |
||||
messenger *messenger |
||||
} |
||||
|
||||
func (rw *proto) WriteMsg(msg Msg) error { |
||||
if msg.Code >= rw.maxcode { |
||||
return NewPeerError(InvalidMsgCode, "not handled") |
||||
} |
||||
msg.Code += rw.offset |
||||
return rw.messenger.writeMsg(msg) |
||||
} |
||||
|
||||
func (rw *proto) ReadMsg() (Msg, error) { |
||||
msg, ok := <-rw.in |
||||
if !ok { |
||||
return msg, io.EOF |
||||
} |
||||
msg.Code -= rw.offset |
||||
return msg, nil |
||||
} |
||||
|
||||
// eofSignal wraps a reader with eof signaling.
|
||||
// the eof channel is closed when the wrapped reader
|
||||
// reaches EOF.
|
||||
type eofSignal struct { |
||||
wrapped io.Reader |
||||
eof chan struct{} |
||||
} |
||||
|
||||
func (r *eofSignal) Read(buf []byte) (int, error) { |
||||
n, err := r.wrapped.Read(buf) |
||||
if err != nil { |
||||
close(r.eof) // tell messenger that msg has been consumed
|
||||
} |
||||
return n, err |
||||
} |
||||
|
||||
// messenger represents a message-oriented peer connection.
|
||||
// It keeps track of the set of protocols understood
|
||||
// by the remote peer.
|
||||
type messenger struct { |
||||
peer *Peer |
||||
handlers Handlers |
||||
|
||||
// the mutex protects the connection
|
||||
// so only one protocol can write at a time.
|
||||
writeMu sync.Mutex |
||||
conn net.Conn |
||||
bufconn *bufio.ReadWriter |
||||
|
||||
protocolLock sync.RWMutex |
||||
protocols map[string]*proto |
||||
offsets map[MsgCode]*proto |
||||
protoWG sync.WaitGroup |
||||
|
||||
err chan error |
||||
pulse chan bool |
||||
} |
||||
|
||||
func newMessenger(peer *Peer, conn net.Conn, errchan chan error, handlers Handlers) *messenger { |
||||
return &messenger{ |
||||
conn: conn, |
||||
bufconn: bufio.NewReadWriter(bufio.NewReader(conn), bufio.NewWriter(conn)), |
||||
peer: peer, |
||||
handlers: handlers, |
||||
protocols: make(map[string]*proto), |
||||
err: errchan, |
||||
pulse: make(chan bool, 1), |
||||
} |
||||
} |
||||
|
||||
func (m *messenger) Start() { |
||||
m.protocols[""] = m.startProto(0, "", &baseProtocol{}) |
||||
go m.readLoop() |
||||
} |
||||
|
||||
func (m *messenger) Stop() { |
||||
m.conn.Close() |
||||
m.protoWG.Wait() |
||||
} |
||||
|
||||
const ( |
||||
// maximum amount of time allowed for reading a message
|
||||
msgReadTimeout = 5 * time.Second |
||||
|
||||
// messages smaller than this many bytes will be read at
|
||||
// once before passing them to a protocol.
|
||||
wholePayloadSize = 64 * 1024 |
||||
) |
||||
|
||||
func (m *messenger) readLoop() { |
||||
defer m.closeProtocols() |
||||
for { |
||||
m.conn.SetReadDeadline(time.Now().Add(msgReadTimeout)) |
||||
msg, err := readMsg(m.bufconn) |
||||
if err != nil { |
||||
m.err <- err |
||||
return |
||||
} |
||||
// send ping to heartbeat channel signalling time of last message
|
||||
m.pulse <- true |
||||
proto, err := m.getProto(msg.Code) |
||||
if err != nil { |
||||
m.err <- err |
||||
return |
||||
} |
||||
if msg.Size <= wholePayloadSize { |
||||
// optimization: msg is small enough, read all
|
||||
// of it and move on to the next message
|
||||
buf, err := ioutil.ReadAll(msg.Payload) |
||||
if err != nil { |
||||
m.err <- err |
||||
return |
||||
} |
||||
msg.Payload = bytes.NewReader(buf) |
||||
proto.in <- msg |
||||
} else { |
||||
pr := &eofSignal{msg.Payload, make(chan struct{})} |
||||
msg.Payload = pr |
||||
proto.in <- msg |
||||
<-pr.eof |
||||
} |
||||
} |
||||
} |
||||
|
||||
func (m *messenger) closeProtocols() { |
||||
m.protocolLock.RLock() |
||||
for _, p := range m.protocols { |
||||
close(p.in) |
||||
} |
||||
m.protocolLock.RUnlock() |
||||
} |
||||
|
||||
func (m *messenger) startProto(offset MsgCode, name string, impl Protocol) *proto { |
||||
proto := &proto{ |
||||
in: make(chan Msg), |
||||
offset: offset, |
||||
maxcode: impl.Offset(), |
||||
messenger: m, |
||||
} |
||||
m.protoWG.Add(1) |
||||
go func() { |
||||
if err := impl.Start(m.peer, proto); err != nil && err != io.EOF { |
||||
logger.Errorf("protocol %q error: %v\n", name, err) |
||||
m.err <- err |
||||
} |
||||
m.protoWG.Done() |
||||
}() |
||||
return proto |
||||
} |
||||
|
||||
// getProto finds the protocol responsible for handling
|
||||
// the given message code.
|
||||
func (m *messenger) getProto(code MsgCode) (*proto, error) { |
||||
m.protocolLock.RLock() |
||||
defer m.protocolLock.RUnlock() |
||||
for _, proto := range m.protocols { |
||||
if code >= proto.offset && code < proto.offset+proto.maxcode { |
||||
return proto, nil |
||||
} |
||||
} |
||||
return nil, NewPeerError(InvalidMsgCode, "%d", code) |
||||
} |
||||
|
||||
// setProtocols starts all subprotocols shared with the
|
||||
// remote peer. the protocols must be sorted alphabetically.
|
||||
func (m *messenger) setRemoteProtocols(protocols []string) { |
||||
m.protocolLock.Lock() |
||||
defer m.protocolLock.Unlock() |
||||
offset := baseProtocolOffset |
||||
for _, name := range protocols { |
||||
inst, ok := m.handlers[name] |
||||
if !ok { |
||||
continue // not handled
|
||||
} |
||||
m.protocols[name] = m.startProto(offset, name, inst) |
||||
offset += inst.Offset() |
||||
} |
||||
} |
||||
|
||||
// writeProtoMsg sends the given message on behalf of the given named protocol.
|
||||
func (m *messenger) writeProtoMsg(protoName string, msg Msg) error { |
||||
m.protocolLock.RLock() |
||||
proto, ok := m.protocols[protoName] |
||||
m.protocolLock.RUnlock() |
||||
if !ok { |
||||
return fmt.Errorf("protocol %s not handled by peer", protoName) |
||||
} |
||||
if msg.Code >= proto.maxcode { |
||||
return NewPeerError(InvalidMsgCode, "code %x is out of range for protocol %q", msg.Code, protoName) |
||||
} |
||||
msg.Code += proto.offset |
||||
return m.writeMsg(msg) |
||||
} |
||||
|
||||
// writeMsg writes a message to the connection.
|
||||
func (m *messenger) writeMsg(msg Msg) error { |
||||
m.writeMu.Lock() |
||||
defer m.writeMu.Unlock() |
||||
if err := writeMsg(m.bufconn, msg); err != nil { |
||||
return err |
||||
} |
||||
return m.bufconn.Flush() |
||||
} |
@ -1,203 +0,0 @@ |
||||
package p2p |
||||
|
||||
import ( |
||||
"bufio" |
||||
"fmt" |
||||
"io" |
||||
"log" |
||||
"net" |
||||
"os" |
||||
"reflect" |
||||
"testing" |
||||
"time" |
||||
|
||||
logpkg "github.com/ethereum/go-ethereum/logger" |
||||
) |
||||
|
||||
func init() { |
||||
logpkg.AddLogSystem(logpkg.NewStdLogSystem(os.Stdout, log.LstdFlags, logpkg.DebugLevel)) |
||||
} |
||||
|
||||
func testMessenger(handlers Handlers) (net.Conn, *Peer, *messenger) { |
||||
conn1, conn2 := net.Pipe() |
||||
id := NewSimpleClientIdentity("test", "0", "0", "public key") |
||||
server := New(nil, conn1.LocalAddr(), id, handlers, 10, NewBlacklist()) |
||||
peer := server.addPeer(conn1, conn1.RemoteAddr(), true, 0) |
||||
return conn2, peer, peer.messenger |
||||
} |
||||
|
||||
func performTestHandshake(r *bufio.Reader, w io.Writer) error { |
||||
// read remote handshake
|
||||
msg, err := readMsg(r) |
||||
if err != nil { |
||||
return fmt.Errorf("read error: %v", err) |
||||
} |
||||
if msg.Code != handshakeMsg { |
||||
return fmt.Errorf("first message should be handshake, got %d", msg.Code) |
||||
} |
||||
if err := msg.Discard(); err != nil { |
||||
return err |
||||
} |
||||
// send empty handshake
|
||||
pubkey := make([]byte, 64) |
||||
msg = NewMsg(handshakeMsg, p2pVersion, "testid", nil, 9999, pubkey) |
||||
return writeMsg(w, msg) |
||||
} |
||||
|
||||
type testProtocol struct { |
||||
offset MsgCode |
||||
f func(MsgReadWriter) |
||||
} |
||||
|
||||
func (p *testProtocol) Offset() MsgCode { |
||||
return p.offset |
||||
} |
||||
|
||||
func (p *testProtocol) Start(peer *Peer, rw MsgReadWriter) error { |
||||
p.f(rw) |
||||
return nil |
||||
} |
||||
|
||||
func TestRead(t *testing.T) { |
||||
done := make(chan struct{}) |
||||
handlers := Handlers{ |
||||
"a": &testProtocol{5, func(rw MsgReadWriter) { |
||||
msg, err := rw.ReadMsg() |
||||
if err != nil { |
||||
t.Errorf("read error: %v", err) |
||||
} |
||||
if msg.Code != 2 { |
||||
t.Errorf("incorrect msg code %d relayed to protocol", msg.Code) |
||||
} |
||||
data, err := msg.Data() |
||||
if err != nil { |
||||
t.Errorf("data decoding error: %v", err) |
||||
} |
||||
expdata := []interface{}{1, []byte{0x30, 0x30, 0x30}} |
||||
if !reflect.DeepEqual(data.Slice(), expdata) { |
||||
t.Errorf("incorrect msg data %#v", data.Slice()) |
||||
} |
||||
close(done) |
||||
}}, |
||||
} |
||||
|
||||
net, peer, m := testMessenger(handlers) |
||||
defer peer.Stop() |
||||
bufr := bufio.NewReader(net) |
||||
if err := performTestHandshake(bufr, net); err != nil { |
||||
t.Fatalf("handshake failed: %v", err) |
||||
} |
||||
m.setRemoteProtocols([]string{"a"}) |
||||
|
||||
writeMsg(net, NewMsg(18, 1, "000")) |
||||
select { |
||||
case <-done: |
||||
case <-time.After(2 * time.Second): |
||||
t.Errorf("receive timeout") |
||||
} |
||||
} |
||||
|
||||
func TestWriteFromProto(t *testing.T) { |
||||
handlers := Handlers{ |
||||
"a": &testProtocol{2, func(rw MsgReadWriter) { |
||||
if err := rw.WriteMsg(NewMsg(2)); err == nil { |
||||
t.Error("expected error for out-of-range msg code, got nil") |
||||
} |
||||
if err := rw.WriteMsg(NewMsg(1)); err != nil { |
||||
t.Errorf("write error: %v", err) |
||||
} |
||||
}}, |
||||
} |
||||
net, peer, mess := testMessenger(handlers) |
||||
defer peer.Stop() |
||||
bufr := bufio.NewReader(net) |
||||
if err := performTestHandshake(bufr, net); err != nil { |
||||
t.Fatalf("handshake failed: %v", err) |
||||
} |
||||
mess.setRemoteProtocols([]string{"a"}) |
||||
|
||||
msg, err := readMsg(bufr) |
||||
if err != nil { |
||||
t.Errorf("read error: %v") |
||||
} |
||||
if msg.Code != 17 { |
||||
t.Errorf("incorrect message code: got %d, expected %d", msg.Code, 17) |
||||
} |
||||
} |
||||
|
||||
var discardProto = &testProtocol{1, func(rw MsgReadWriter) { |
||||
for { |
||||
msg, err := rw.ReadMsg() |
||||
if err != nil { |
||||
return |
||||
} |
||||
if err = msg.Discard(); err != nil { |
||||
return |
||||
} |
||||
} |
||||
}} |
||||
|
||||
func TestMessengerWriteProtoMsg(t *testing.T) { |
||||
handlers := Handlers{"a": discardProto} |
||||
net, peer, mess := testMessenger(handlers) |
||||
defer peer.Stop() |
||||
bufr := bufio.NewReader(net) |
||||
if err := performTestHandshake(bufr, net); err != nil { |
||||
t.Fatalf("handshake failed: %v", err) |
||||
} |
||||
mess.setRemoteProtocols([]string{"a"}) |
||||
|
||||
// test write errors
|
||||
if err := mess.writeProtoMsg("b", NewMsg(3)); err == nil { |
||||
t.Errorf("expected error for unknown protocol, got nil") |
||||
} |
||||
if err := mess.writeProtoMsg("a", NewMsg(8)); err == nil { |
||||
t.Errorf("expected error for out-of-range msg code, got nil") |
||||
} else if perr, ok := err.(*PeerError); !ok || perr.Code != InvalidMsgCode { |
||||
t.Errorf("wrong error for out-of-range msg code, got %#v") |
||||
} |
||||
|
||||
// test succcessful write
|
||||
read, readerr := make(chan Msg), make(chan error) |
||||
go func() { |
||||
if msg, err := readMsg(bufr); err != nil { |
||||
readerr <- err |
||||
} else { |
||||
read <- msg |
||||
} |
||||
}() |
||||
if err := mess.writeProtoMsg("a", NewMsg(0)); err != nil { |
||||
t.Errorf("expect no error for known protocol: %v", err) |
||||
} |
||||
select { |
||||
case msg := <-read: |
||||
if msg.Code != 16 { |
||||
t.Errorf("wrong code, got %d, expected %d", msg.Code, 16) |
||||
} |
||||
msg.Discard() |
||||
case err := <-readerr: |
||||
t.Errorf("read error: %v", err) |
||||
} |
||||
} |
||||
|
||||
func TestPulse(t *testing.T) { |
||||
net, peer, _ := testMessenger(nil) |
||||
defer peer.Stop() |
||||
bufr := bufio.NewReader(net) |
||||
if err := performTestHandshake(bufr, net); err != nil { |
||||
t.Fatalf("handshake failed: %v", err) |
||||
} |
||||
|
||||
before := time.Now() |
||||
msg, err := readMsg(bufr) |
||||
if err != nil { |
||||
t.Fatalf("read error: %v", err) |
||||
} |
||||
after := time.Now() |
||||
if msg.Code != pingMsg { |
||||
t.Errorf("expected ping message, got %d", msg.Code) |
||||
} |
||||
if d := after.Sub(before); d < pingTimeout { |
||||
t.Errorf("ping sent too early after %v, expected at least %v", d, pingTimeout) |
||||
} |
||||
} |
@ -1,196 +0,0 @@ |
||||
package p2p |
||||
|
||||
import ( |
||||
"fmt" |
||||
"math/rand" |
||||
"net" |
||||
"strconv" |
||||
"time" |
||||
) |
||||
|
||||
const ( |
||||
DialerTimeout = 180 //seconds
|
||||
KeepAlivePeriod = 60 //minutes
|
||||
portMappingUpdateInterval = 900 // seconds = 15 mins
|
||||
upnpDiscoverAttempts = 3 |
||||
) |
||||
|
||||
// Dialer is not an interface in net, so we define one
|
||||
// *net.Dialer conforms to this
|
||||
type Dialer interface { |
||||
Dial(network, address string) (net.Conn, error) |
||||
} |
||||
|
||||
type Network interface { |
||||
Start() error |
||||
Listener(net.Addr) (net.Listener, error) |
||||
Dialer(net.Addr) (Dialer, error) |
||||
NewAddr(string, int) (addr net.Addr, err error) |
||||
ParseAddr(string) (addr net.Addr, err error) |
||||
} |
||||
|
||||
type NAT interface { |
||||
GetExternalAddress() (addr net.IP, err error) |
||||
AddPortMapping(protocol string, externalPort, internalPort int, description string, timeout int) (mappedExternalPort int, err error) |
||||
DeletePortMapping(protocol string, externalPort, internalPort int) (err error) |
||||
} |
||||
|
||||
type TCPNetwork struct { |
||||
nat NAT |
||||
natType NATType |
||||
quit chan chan bool |
||||
ports chan string |
||||
} |
||||
|
||||
type NATType int |
||||
|
||||
const ( |
||||
NONE = iota |
||||
UPNP |
||||
PMP |
||||
) |
||||
|
||||
const ( |
||||
portMappingTimeout = 1200 // 20 mins
|
||||
) |
||||
|
||||
func NewTCPNetwork(natType NATType) (net *TCPNetwork) { |
||||
return &TCPNetwork{ |
||||
natType: natType, |
||||
ports: make(chan string), |
||||
} |
||||
} |
||||
|
||||
func (self *TCPNetwork) Dialer(addr net.Addr) (Dialer, error) { |
||||
return &net.Dialer{ |
||||
Timeout: DialerTimeout * time.Second, |
||||
// KeepAlive: KeepAlivePeriod * time.Minute,
|
||||
LocalAddr: addr, |
||||
}, nil |
||||
} |
||||
|
||||
func (self *TCPNetwork) Listener(addr net.Addr) (net.Listener, error) { |
||||
if self.natType == UPNP { |
||||
_, port, _ := net.SplitHostPort(addr.String()) |
||||
if self.quit == nil { |
||||
self.quit = make(chan chan bool) |
||||
go self.updatePortMappings() |
||||
} |
||||
self.ports <- port |
||||
} |
||||
return net.Listen(addr.Network(), addr.String()) |
||||
} |
||||
|
||||
func (self *TCPNetwork) Start() (err error) { |
||||
switch self.natType { |
||||
case NONE: |
||||
case UPNP: |
||||
nat, uerr := upnpDiscover(upnpDiscoverAttempts) |
||||
if uerr != nil { |
||||
err = fmt.Errorf("UPNP failed: ", uerr) |
||||
} else { |
||||
self.nat = nat |
||||
} |
||||
case PMP: |
||||
err = fmt.Errorf("PMP not implemented") |
||||
default: |
||||
err = fmt.Errorf("Invalid NAT type: %v", self.natType) |
||||
} |
||||
return |
||||
} |
||||
|
||||
func (self *TCPNetwork) Stop() { |
||||
q := make(chan bool) |
||||
self.quit <- q |
||||
<-q |
||||
} |
||||
|
||||
func (self *TCPNetwork) addPortMapping(lport int) (err error) { |
||||
_, err = self.nat.AddPortMapping("TCP", lport, lport, "p2p listen port", portMappingTimeout) |
||||
if err != nil { |
||||
logger.Errorf("unable to add port mapping on %v: %v", lport, err) |
||||
} else { |
||||
logger.Debugf("succesfully added port mapping on %v", lport) |
||||
} |
||||
return |
||||
} |
||||
|
||||
func (self *TCPNetwork) updatePortMappings() { |
||||
timer := time.NewTimer(portMappingUpdateInterval * time.Second) |
||||
lports := []int{} |
||||
out: |
||||
for { |
||||
select { |
||||
case port := <-self.ports: |
||||
int64lport, _ := strconv.ParseInt(port, 10, 16) |
||||
lport := int(int64lport) |
||||
if err := self.addPortMapping(lport); err != nil { |
||||
lports = append(lports, lport) |
||||
} |
||||
case <-timer.C: |
||||
for lport := range lports { |
||||
if err := self.addPortMapping(lport); err != nil { |
||||
} |
||||
} |
||||
case errc := <-self.quit: |
||||
errc <- true |
||||
break out |
||||
} |
||||
} |
||||
|
||||
timer.Stop() |
||||
for lport := range lports { |
||||
if err := self.nat.DeletePortMapping("TCP", lport, lport); err != nil { |
||||
logger.Debugf("unable to remove port mapping on %v: %v", lport, err) |
||||
} else { |
||||
logger.Debugf("succesfully removed port mapping on %v", lport) |
||||
} |
||||
} |
||||
} |
||||
|
||||
func (self *TCPNetwork) NewAddr(host string, port int) (net.Addr, error) { |
||||
ip, err := self.lookupIP(host) |
||||
if err == nil { |
||||
return &net.TCPAddr{ |
||||
IP: ip, |
||||
Port: port, |
||||
}, nil |
||||
} |
||||
return nil, err |
||||
} |
||||
|
||||
func (self *TCPNetwork) ParseAddr(address string) (net.Addr, error) { |
||||
host, port, err := net.SplitHostPort(address) |
||||
if err == nil { |
||||
iport, _ := strconv.Atoi(port) |
||||
addr, e := self.NewAddr(host, iport) |
||||
return addr, e |
||||
} |
||||
return nil, err |
||||
} |
||||
|
||||
func (*TCPNetwork) lookupIP(host string) (ip net.IP, err error) { |
||||
if ip = net.ParseIP(host); ip != nil { |
||||
return |
||||
} |
||||
|
||||
var ips []net.IP |
||||
ips, err = net.LookupIP(host) |
||||
if err != nil { |
||||
logger.Warnln(err) |
||||
return |
||||
} |
||||
if len(ips) == 0 { |
||||
err = fmt.Errorf("No IP addresses available for %v", host) |
||||
logger.Warnln(err) |
||||
return |
||||
} |
||||
if len(ips) > 1 { |
||||
// Pick a random IP address, simulating round-robin DNS.
|
||||
rand.Seed(time.Now().UTC().UnixNano()) |
||||
ip = ips[rand.Intn(len(ips))] |
||||
} else { |
||||
ip = ips[0] |
||||
} |
||||
return |
||||
} |
@ -1,66 +1,454 @@ |
||||
package p2p |
||||
|
||||
import ( |
||||
"bufio" |
||||
"bytes" |
||||
"fmt" |
||||
"io" |
||||
"io/ioutil" |
||||
"net" |
||||
"strconv" |
||||
"sort" |
||||
"sync" |
||||
"time" |
||||
|
||||
"github.com/ethereum/go-ethereum/event" |
||||
"github.com/ethereum/go-ethereum/logger" |
||||
) |
||||
|
||||
// peerAddr is the structure of a peer list element.
|
||||
// It is also a valid net.Addr.
|
||||
type peerAddr struct { |
||||
IP net.IP |
||||
Port uint64 |
||||
Pubkey []byte // optional
|
||||
} |
||||
|
||||
func newPeerAddr(addr net.Addr, pubkey []byte) *peerAddr { |
||||
n := addr.Network() |
||||
if n != "tcp" && n != "tcp4" && n != "tcp6" { |
||||
// for testing with non-TCP
|
||||
return &peerAddr{net.ParseIP("127.0.0.1"), 30303, pubkey} |
||||
} |
||||
ta := addr.(*net.TCPAddr) |
||||
return &peerAddr{ta.IP, uint64(ta.Port), pubkey} |
||||
} |
||||
|
||||
func (d peerAddr) Network() string { |
||||
if d.IP.To4() != nil { |
||||
return "tcp4" |
||||
} else { |
||||
return "tcp6" |
||||
} |
||||
} |
||||
|
||||
func (d peerAddr) String() string { |
||||
return fmt.Sprintf("%v:%d", d.IP, d.Port) |
||||
} |
||||
|
||||
func (d peerAddr) RlpData() interface{} { |
||||
return []interface{}{d.IP, d.Port, d.Pubkey} |
||||
} |
||||
|
||||
// Peer represents a remote peer.
|
||||
type Peer struct { |
||||
Inbound bool // inbound (via listener) or outbound (via dialout)
|
||||
Address net.Addr |
||||
Host []byte |
||||
Port uint16 |
||||
Pubkey []byte |
||||
Id string |
||||
Caps []string |
||||
peerErrorChan chan error |
||||
messenger *messenger |
||||
peerErrorHandler *PeerErrorHandler |
||||
server *Server |
||||
} |
||||
|
||||
func NewPeer(conn net.Conn, address net.Addr, inbound bool, server *Server) *Peer { |
||||
peerErrorChan := NewPeerErrorChannel() |
||||
host, port, _ := net.SplitHostPort(address.String()) |
||||
intport, _ := strconv.Atoi(port) |
||||
peer := &Peer{ |
||||
Inbound: inbound, |
||||
Address: address, |
||||
Port: uint16(intport), |
||||
Host: net.ParseIP(host), |
||||
peerErrorChan: peerErrorChan, |
||||
server: server, |
||||
} |
||||
peer.messenger = newMessenger(peer, conn, peerErrorChan, server.Handlers()) |
||||
peer.peerErrorHandler = NewPeerErrorHandler(address, server.PeerDisconnect(), peerErrorChan) |
||||
// Peers have all the log methods.
|
||||
// Use them to display messages related to the peer.
|
||||
*logger.Logger |
||||
|
||||
infolock sync.Mutex |
||||
identity ClientIdentity |
||||
caps []Cap |
||||
listenAddr *peerAddr // what remote peer is listening on
|
||||
dialAddr *peerAddr // non-nil if dialing
|
||||
|
||||
// The mutex protects the connection
|
||||
// so only one protocol can write at a time.
|
||||
writeMu sync.Mutex |
||||
conn net.Conn |
||||
bufconn *bufio.ReadWriter |
||||
|
||||
// These fields maintain the running protocols.
|
||||
protocols []Protocol |
||||
runBaseProtocol bool // for testing
|
||||
|
||||
runlock sync.RWMutex // protects running
|
||||
running map[string]*proto |
||||
|
||||
protoWG sync.WaitGroup |
||||
protoErr chan error |
||||
closed chan struct{} |
||||
disc chan DiscReason |
||||
|
||||
activity event.TypeMux // for activity events
|
||||
|
||||
slot int // index into Server peer list
|
||||
|
||||
// These fields are kept so base protocol can access them.
|
||||
// TODO: this should be one or more interfaces
|
||||
ourID ClientIdentity // client id of the Server
|
||||
ourListenAddr *peerAddr // listen addr of Server, nil if not listening
|
||||
newPeerAddr chan<- *peerAddr // tell server about received peers
|
||||
otherPeers func() []*Peer // should return the list of all peers
|
||||
pubkeyHook func(*peerAddr) error // called at end of handshake to validate pubkey
|
||||
} |
||||
|
||||
// NewPeer returns a peer for testing purposes.
|
||||
func NewPeer(id ClientIdentity, caps []Cap) *Peer { |
||||
conn, _ := net.Pipe() |
||||
peer := newPeer(conn, nil, nil) |
||||
peer.setHandshakeInfo(id, nil, caps) |
||||
return peer |
||||
} |
||||
|
||||
func (self *Peer) String() string { |
||||
var kind string |
||||
if self.Inbound { |
||||
kind = "inbound" |
||||
} else { |
||||
func newServerPeer(server *Server, conn net.Conn, dialAddr *peerAddr) *Peer { |
||||
p := newPeer(conn, server.Protocols, dialAddr) |
||||
p.ourID = server.Identity |
||||
p.newPeerAddr = server.peerConnect |
||||
p.otherPeers = server.Peers |
||||
p.pubkeyHook = server.verifyPeer |
||||
p.runBaseProtocol = true |
||||
|
||||
// laddr can be updated concurrently by NAT traversal.
|
||||
// newServerPeer must be called with the server lock held.
|
||||
if server.laddr != nil { |
||||
p.ourListenAddr = newPeerAddr(server.laddr, server.Identity.Pubkey()) |
||||
} |
||||
return p |
||||
} |
||||
|
||||
func newPeer(conn net.Conn, protocols []Protocol, dialAddr *peerAddr) *Peer { |
||||
p := &Peer{ |
||||
Logger: logger.NewLogger("P2P " + conn.RemoteAddr().String()), |
||||
conn: conn, |
||||
dialAddr: dialAddr, |
||||
bufconn: bufio.NewReadWriter(bufio.NewReader(conn), bufio.NewWriter(conn)), |
||||
protocols: protocols, |
||||
running: make(map[string]*proto), |
||||
disc: make(chan DiscReason), |
||||
protoErr: make(chan error), |
||||
closed: make(chan struct{}), |
||||
} |
||||
return p |
||||
} |
||||
|
||||
// Identity returns the client identity of the remote peer. The
|
||||
// identity can be nil if the peer has not yet completed the
|
||||
// handshake.
|
||||
func (p *Peer) Identity() ClientIdentity { |
||||
p.infolock.Lock() |
||||
defer p.infolock.Unlock() |
||||
return p.identity |
||||
} |
||||
|
||||
// Caps returns the capabilities (supported subprotocols) of the remote peer.
|
||||
func (p *Peer) Caps() []Cap { |
||||
p.infolock.Lock() |
||||
defer p.infolock.Unlock() |
||||
return p.caps |
||||
} |
||||
|
||||
func (p *Peer) setHandshakeInfo(id ClientIdentity, laddr *peerAddr, caps []Cap) { |
||||
p.infolock.Lock() |
||||
p.identity = id |
||||
p.listenAddr = laddr |
||||
p.caps = caps |
||||
p.infolock.Unlock() |
||||
} |
||||
|
||||
// RemoteAddr returns the remote address of the network connection.
|
||||
func (p *Peer) RemoteAddr() net.Addr { |
||||
return p.conn.RemoteAddr() |
||||
} |
||||
|
||||
// LocalAddr returns the local address of the network connection.
|
||||
func (p *Peer) LocalAddr() net.Addr { |
||||
return p.conn.LocalAddr() |
||||
} |
||||
|
||||
// Disconnect terminates the peer connection with the given reason.
|
||||
// It returns immediately and does not wait until the connection is closed.
|
||||
func (p *Peer) Disconnect(reason DiscReason) { |
||||
select { |
||||
case p.disc <- reason: |
||||
case <-p.closed: |
||||
} |
||||
} |
||||
|
||||
// String implements fmt.Stringer.
|
||||
func (p *Peer) String() string { |
||||
kind := "inbound" |
||||
p.infolock.Lock() |
||||
if p.dialAddr != nil { |
||||
kind = "outbound" |
||||
} |
||||
return fmt.Sprintf("%v:%v (%s) v%v %v", self.Host, self.Port, kind, self.Id, self.Caps) |
||||
p.infolock.Unlock() |
||||
return fmt.Sprintf("Peer(%p %v %s)", p, p.conn.RemoteAddr(), kind) |
||||
} |
||||
|
||||
const ( |
||||
// maximum amount of time allowed for reading a message
|
||||
msgReadTimeout = 5 * time.Second |
||||
// maximum amount of time allowed for writing a message
|
||||
msgWriteTimeout = 5 * time.Second |
||||
// messages smaller than this many bytes will be read at
|
||||
// once before passing them to a protocol.
|
||||
wholePayloadSize = 64 * 1024 |
||||
) |
||||
|
||||
var ( |
||||
inactivityTimeout = 2 * time.Second |
||||
disconnectGracePeriod = 2 * time.Second |
||||
) |
||||
|
||||
func (p *Peer) loop() (reason DiscReason, err error) { |
||||
defer p.activity.Stop() |
||||
defer p.closeProtocols() |
||||
defer close(p.closed) |
||||
defer p.conn.Close() |
||||
|
||||
// read loop
|
||||
readMsg := make(chan Msg) |
||||
readErr := make(chan error) |
||||
readNext := make(chan bool, 1) |
||||
protoDone := make(chan struct{}, 1) |
||||
go p.readLoop(readMsg, readErr, readNext) |
||||
readNext <- true |
||||
|
||||
if p.runBaseProtocol { |
||||
p.startBaseProtocol() |
||||
} |
||||
|
||||
loop: |
||||
for { |
||||
select { |
||||
case msg := <-readMsg: |
||||
// a new message has arrived.
|
||||
var wait bool |
||||
if wait, err = p.dispatch(msg, protoDone); err != nil { |
||||
p.Errorf("msg dispatch error: %v\n", err) |
||||
reason = discReasonForError(err) |
||||
break loop |
||||
} |
||||
if !wait { |
||||
// Msg has already been read completely, continue with next message.
|
||||
readNext <- true |
||||
} |
||||
p.activity.Post(time.Now()) |
||||
case <-protoDone: |
||||
// protocol has consumed the message payload,
|
||||
// we can continue reading from the socket.
|
||||
readNext <- true |
||||
|
||||
case err := <-readErr: |
||||
// read failed. there is no need to run the
|
||||
// polite disconnect sequence because the connection
|
||||
// is probably dead anyway.
|
||||
// TODO: handle write errors as well
|
||||
return DiscNetworkError, err |
||||
case err = <-p.protoErr: |
||||
reason = discReasonForError(err) |
||||
break loop |
||||
case reason = <-p.disc: |
||||
break loop |
||||
} |
||||
} |
||||
|
||||
// wait for read loop to return.
|
||||
close(readNext) |
||||
<-readErr |
||||
// tell the remote end to disconnect
|
||||
done := make(chan struct{}) |
||||
go func() { |
||||
p.conn.SetDeadline(time.Now().Add(disconnectGracePeriod)) |
||||
p.writeMsg(NewMsg(discMsg, reason), disconnectGracePeriod) |
||||
io.Copy(ioutil.Discard, p.conn) |
||||
close(done) |
||||
}() |
||||
select { |
||||
case <-done: |
||||
case <-time.After(disconnectGracePeriod): |
||||
} |
||||
return reason, err |
||||
} |
||||
|
||||
func (p *Peer) readLoop(msgc chan<- Msg, errc chan<- error, unblock <-chan bool) { |
||||
for _ = range unblock { |
||||
p.conn.SetReadDeadline(time.Now().Add(msgReadTimeout)) |
||||
if msg, err := readMsg(p.bufconn); err != nil { |
||||
errc <- err |
||||
} else { |
||||
msgc <- msg |
||||
} |
||||
} |
||||
close(errc) |
||||
} |
||||
|
||||
func (p *Peer) dispatch(msg Msg, protoDone chan struct{}) (wait bool, err error) { |
||||
proto, err := p.getProto(msg.Code) |
||||
if err != nil { |
||||
return false, err |
||||
} |
||||
if msg.Size <= wholePayloadSize { |
||||
// optimization: msg is small enough, read all
|
||||
// of it and move on to the next message
|
||||
buf, err := ioutil.ReadAll(msg.Payload) |
||||
if err != nil { |
||||
return false, err |
||||
} |
||||
msg.Payload = bytes.NewReader(buf) |
||||
proto.in <- msg |
||||
} else { |
||||
wait = true |
||||
pr := &eofSignal{msg.Payload, protoDone} |
||||
msg.Payload = pr |
||||
proto.in <- msg |
||||
} |
||||
return wait, nil |
||||
} |
||||
|
||||
func (p *Peer) startBaseProtocol() { |
||||
p.runlock.Lock() |
||||
defer p.runlock.Unlock() |
||||
p.running[""] = p.startProto(0, Protocol{ |
||||
Length: baseProtocolLength, |
||||
Run: runBaseProtocol, |
||||
}) |
||||
} |
||||
|
||||
// startProtocols starts matching named subprotocols.
|
||||
func (p *Peer) startSubprotocols(caps []Cap) { |
||||
sort.Sort(capsByName(caps)) |
||||
|
||||
p.runlock.Lock() |
||||
defer p.runlock.Unlock() |
||||
offset := baseProtocolLength |
||||
outer: |
||||
for _, cap := range caps { |
||||
for _, proto := range p.protocols { |
||||
if proto.Name == cap.Name && |
||||
proto.Version == cap.Version && |
||||
p.running[cap.Name] == nil { |
||||
p.running[cap.Name] = p.startProto(offset, proto) |
||||
offset += proto.Length |
||||
continue outer |
||||
} |
||||
} |
||||
} |
||||
} |
||||
|
||||
func (p *Peer) startProto(offset uint64, impl Protocol) *proto { |
||||
rw := &proto{ |
||||
in: make(chan Msg), |
||||
offset: offset, |
||||
maxcode: impl.Length, |
||||
peer: p, |
||||
} |
||||
p.protoWG.Add(1) |
||||
go func() { |
||||
err := impl.Run(p, rw) |
||||
if err == nil { |
||||
p.Infof("protocol %q returned", impl.Name) |
||||
err = newPeerError(errMisc, "protocol returned") |
||||
} else { |
||||
p.Errorf("protocol %q error: %v\n", impl.Name, err) |
||||
} |
||||
select { |
||||
case p.protoErr <- err: |
||||
case <-p.closed: |
||||
} |
||||
p.protoWG.Done() |
||||
}() |
||||
return rw |
||||
} |
||||
|
||||
// getProto finds the protocol responsible for handling
|
||||
// the given message code.
|
||||
func (p *Peer) getProto(code uint64) (*proto, error) { |
||||
p.runlock.RLock() |
||||
defer p.runlock.RUnlock() |
||||
for _, proto := range p.running { |
||||
if code >= proto.offset && code < proto.offset+proto.maxcode { |
||||
return proto, nil |
||||
} |
||||
} |
||||
return nil, newPeerError(errInvalidMsgCode, "%d", code) |
||||
} |
||||
|
||||
func (p *Peer) closeProtocols() { |
||||
p.runlock.RLock() |
||||
for _, p := range p.running { |
||||
close(p.in) |
||||
} |
||||
p.runlock.RUnlock() |
||||
p.protoWG.Wait() |
||||
} |
||||
|
||||
// writeProtoMsg sends the given message on behalf of the given named protocol.
|
||||
func (p *Peer) writeProtoMsg(protoName string, msg Msg) error { |
||||
p.runlock.RLock() |
||||
proto, ok := p.running[protoName] |
||||
p.runlock.RUnlock() |
||||
if !ok { |
||||
return fmt.Errorf("protocol %s not handled by peer", protoName) |
||||
} |
||||
if msg.Code >= proto.maxcode { |
||||
return newPeerError(errInvalidMsgCode, "code %x is out of range for protocol %q", msg.Code, protoName) |
||||
} |
||||
msg.Code += proto.offset |
||||
return p.writeMsg(msg, msgWriteTimeout) |
||||
} |
||||
|
||||
// writeMsg writes a message to the connection.
|
||||
func (p *Peer) writeMsg(msg Msg, timeout time.Duration) error { |
||||
p.writeMu.Lock() |
||||
defer p.writeMu.Unlock() |
||||
p.conn.SetWriteDeadline(time.Now().Add(timeout)) |
||||
if err := writeMsg(p.bufconn, msg); err != nil { |
||||
return newPeerError(errWrite, "%v", err) |
||||
} |
||||
return p.bufconn.Flush() |
||||
} |
||||
|
||||
type proto struct { |
||||
name string |
||||
in chan Msg |
||||
maxcode, offset uint64 |
||||
peer *Peer |
||||
} |
||||
|
||||
func (rw *proto) WriteMsg(msg Msg) error { |
||||
if msg.Code >= rw.maxcode { |
||||
return newPeerError(errInvalidMsgCode, "not handled") |
||||
} |
||||
msg.Code += rw.offset |
||||
return rw.peer.writeMsg(msg, msgWriteTimeout) |
||||
} |
||||
|
||||
func (self *Peer) Write(protocol string, msg Msg) error { |
||||
return self.messenger.writeProtoMsg(protocol, msg) |
||||
func (rw *proto) EncodeMsg(code uint64, data ...interface{}) error { |
||||
return rw.WriteMsg(NewMsg(code, data)) |
||||
} |
||||
|
||||
func (self *Peer) Start() { |
||||
self.peerErrorHandler.Start() |
||||
self.messenger.Start() |
||||
func (rw *proto) ReadMsg() (Msg, error) { |
||||
msg, ok := <-rw.in |
||||
if !ok { |
||||
return msg, io.EOF |
||||
} |
||||
msg.Code -= rw.offset |
||||
return msg, nil |
||||
} |
||||
|
||||
func (self *Peer) Stop() { |
||||
self.peerErrorHandler.Stop() |
||||
self.messenger.Stop() |
||||
// eofSignal wraps a reader with eof signaling.
|
||||
// the eof channel is closed when the wrapped reader
|
||||
// reaches EOF.
|
||||
type eofSignal struct { |
||||
wrapped io.Reader |
||||
eof chan<- struct{} |
||||
} |
||||
|
||||
func (p *Peer) Encode() []interface{} { |
||||
return []interface{}{p.Host, p.Port, p.Pubkey} |
||||
func (r *eofSignal) Read(buf []byte) (int, error) { |
||||
n, err := r.wrapped.Read(buf) |
||||
if err != nil { |
||||
r.eof <- struct{}{} // tell Peer that msg has been consumed
|
||||
} |
||||
return n, err |
||||
} |
||||
|
@ -1,98 +0,0 @@ |
||||
package p2p |
||||
|
||||
import ( |
||||
"net" |
||||
) |
||||
|
||||
const ( |
||||
severityThreshold = 10 |
||||
) |
||||
|
||||
type DisconnectRequest struct { |
||||
addr net.Addr |
||||
reason DiscReason |
||||
} |
||||
|
||||
type PeerErrorHandler struct { |
||||
quit chan chan bool |
||||
address net.Addr |
||||
peerDisconnect chan DisconnectRequest |
||||
severity int |
||||
errc chan error |
||||
} |
||||
|
||||
func NewPeerErrorHandler(address net.Addr, peerDisconnect chan DisconnectRequest, errc chan error) *PeerErrorHandler { |
||||
return &PeerErrorHandler{ |
||||
quit: make(chan chan bool), |
||||
address: address, |
||||
peerDisconnect: peerDisconnect, |
||||
errc: errc, |
||||
} |
||||
} |
||||
|
||||
func (self *PeerErrorHandler) Start() { |
||||
go self.listen() |
||||
} |
||||
|
||||
func (self *PeerErrorHandler) Stop() { |
||||
q := make(chan bool) |
||||
self.quit <- q |
||||
<-q |
||||
} |
||||
|
||||
func (self *PeerErrorHandler) listen() { |
||||
for { |
||||
select { |
||||
case err, ok := <-self.errc: |
||||
if ok { |
||||
logger.Debugf("error %v\n", err) |
||||
go self.handle(err) |
||||
} else { |
||||
return |
||||
} |
||||
case q := <-self.quit: |
||||
q <- true |
||||
return |
||||
} |
||||
} |
||||
} |
||||
|
||||
func (self *PeerErrorHandler) handle(err error) { |
||||
reason := DiscReason(' ') |
||||
peerError, ok := err.(*PeerError) |
||||
if !ok { |
||||
peerError = NewPeerError(MiscError, " %v", err) |
||||
} |
||||
switch peerError.Code { |
||||
case P2PVersionMismatch: |
||||
reason = DiscIncompatibleVersion |
||||
case PubkeyMissing, PubkeyInvalid: |
||||
reason = DiscInvalidIdentity |
||||
case PubkeyForbidden: |
||||
reason = DiscUselessPeer |
||||
case InvalidMsgCode, PacketTooLong, PayloadTooShort, MagicTokenMismatch, ProtocolBreach: |
||||
reason = DiscProtocolError |
||||
case PingTimeout: |
||||
reason = DiscReadTimeout |
||||
case ReadError, WriteError, MiscError: |
||||
reason = DiscNetworkError |
||||
case InvalidGenesis, InvalidNetworkId, InvalidProtocolVersion: |
||||
reason = DiscSubprotocolError |
||||
default: |
||||
self.severity += self.getSeverity(peerError) |
||||
} |
||||
|
||||
if self.severity >= severityThreshold { |
||||
reason = DiscSubprotocolError |
||||
} |
||||
if reason != DiscReason(' ') { |
||||
self.peerDisconnect <- DisconnectRequest{ |
||||
addr: self.address, |
||||
reason: reason, |
||||
} |
||||
} |
||||
} |
||||
|
||||
func (self *PeerErrorHandler) getSeverity(peerError *PeerError) int { |
||||
return 1 |
||||
} |
@ -1,34 +0,0 @@ |
||||
package p2p |
||||
|
||||
import ( |
||||
// "fmt"
|
||||
"net" |
||||
"testing" |
||||
"time" |
||||
) |
||||
|
||||
func TestPeerErrorHandler(t *testing.T) { |
||||
address := &net.TCPAddr{IP: net.IP([]byte{1, 2, 3, 4}), Port: 30303} |
||||
peerDisconnect := make(chan DisconnectRequest) |
||||
peerErrorChan := NewPeerErrorChannel() |
||||
peh := NewPeerErrorHandler(address, peerDisconnect, peerErrorChan) |
||||
peh.Start() |
||||
defer peh.Stop() |
||||
for i := 0; i < 11; i++ { |
||||
select { |
||||
case <-peerDisconnect: |
||||
t.Errorf("expected no disconnect request") |
||||
default: |
||||
} |
||||
peerErrorChan <- NewPeerError(MiscError, "") |
||||
} |
||||
time.Sleep(1 * time.Millisecond) |
||||
select { |
||||
case request := <-peerDisconnect: |
||||
if request.addr.String() != address.String() { |
||||
t.Errorf("incorrect address %v != %v", request.addr, address) |
||||
} |
||||
default: |
||||
t.Errorf("expected disconnect request") |
||||
} |
||||
} |
@ -1,90 +1,222 @@ |
||||
package p2p |
||||
|
||||
// "net"
|
||||
|
||||
// func TestPeer(t *testing.T) {
|
||||
// handlers := make(Handlers)
|
||||
// testProtocol := &TestProtocol{recv: make(chan testMsg)}
|
||||
// handlers["aaa"] = func(p *Peer) Protocol { return testProtocol }
|
||||
// handlers["ccc"] = func(p *Peer) Protocol { return testProtocol }
|
||||
// addr := &TestAddr{"test:30"}
|
||||
// conn := NewTestNetworkConnection(addr)
|
||||
// _, server := SetupTestServer(handlers)
|
||||
// server.Handshake()
|
||||
// peer := NewPeer(conn, addr, true, server)
|
||||
// // peer.Messenger().AddProtocols([]string{"aaa", "ccc"})
|
||||
// peer.Start()
|
||||
// defer peer.Stop()
|
||||
// time.Sleep(2 * time.Millisecond)
|
||||
// if len(conn.Out) != 1 {
|
||||
// t.Errorf("handshake not sent")
|
||||
// } else {
|
||||
// out := conn.Out[0]
|
||||
// packet := Packet(0, HandshakeMsg, P2PVersion, []byte(peer.server.identity.String()), []interface{}{peer.server.protocols}, peer.server.port, peer.server.identity.Pubkey()[1:])
|
||||
// if bytes.Compare(out, packet) != 0 {
|
||||
// t.Errorf("incorrect handshake packet %v != %v", out, packet)
|
||||
// }
|
||||
// }
|
||||
|
||||
// packet := Packet(0, HandshakeMsg, P2PVersion, []byte("peer"), []interface{}{"bbb", "aaa", "ccc"}, 30, []byte("0000000000000000000000000000000000000000000000000000000000000000"))
|
||||
// conn.In(0, packet)
|
||||
// time.Sleep(10 * time.Millisecond)
|
||||
|
||||
// pro, _ := peer.Messenger().protocols[0].(*BaseProtocol)
|
||||
// if pro.state != handshakeReceived {
|
||||
// t.Errorf("handshake not received")
|
||||
// }
|
||||
// if peer.Port != 30 {
|
||||
// t.Errorf("port incorrectly set")
|
||||
// }
|
||||
// if peer.Id != "peer" {
|
||||
// t.Errorf("id incorrectly set")
|
||||
// }
|
||||
// if string(peer.Pubkey) != "0000000000000000000000000000000000000000000000000000000000000000" {
|
||||
// t.Errorf("pubkey incorrectly set")
|
||||
// }
|
||||
// fmt.Println(peer.Caps)
|
||||
// if len(peer.Caps) != 3 || peer.Caps[0] != "aaa" || peer.Caps[1] != "bbb" || peer.Caps[2] != "ccc" {
|
||||
// t.Errorf("protocols incorrectly set")
|
||||
// }
|
||||
|
||||
// msg := NewMsg(3)
|
||||
// err := peer.Write("aaa", msg)
|
||||
// if err != nil {
|
||||
// t.Errorf("expect no error for known protocol: %v", err)
|
||||
// } else {
|
||||
// time.Sleep(1 * time.Millisecond)
|
||||
// if len(conn.Out) != 2 {
|
||||
// t.Errorf("msg not written")
|
||||
// } else {
|
||||
// out := conn.Out[1]
|
||||
// packet := Packet(16, 3)
|
||||
// if bytes.Compare(out, packet) != 0 {
|
||||
// t.Errorf("incorrect packet %v != %v", out, packet)
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
|
||||
// msg = NewMsg(2)
|
||||
// err = peer.Write("ccc", msg)
|
||||
// if err != nil {
|
||||
// t.Errorf("expect no error for known protocol: %v", err)
|
||||
// } else {
|
||||
// time.Sleep(1 * time.Millisecond)
|
||||
// if len(conn.Out) != 3 {
|
||||
// t.Errorf("msg not written")
|
||||
// } else {
|
||||
// out := conn.Out[2]
|
||||
// packet := Packet(21, 2)
|
||||
// if bytes.Compare(out, packet) != 0 {
|
||||
// t.Errorf("incorrect packet %v != %v", out, packet)
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
|
||||
// err = peer.Write("bbb", msg)
|
||||
// time.Sleep(1 * time.Millisecond)
|
||||
// if err == nil {
|
||||
// t.Errorf("expect error for unknown protocol")
|
||||
// }
|
||||
// }
|
||||
import ( |
||||
"bufio" |
||||
"net" |
||||
"reflect" |
||||
"testing" |
||||
"time" |
||||
) |
||||
|
||||
var discard = Protocol{ |
||||
Name: "discard", |
||||
Length: 1, |
||||
Run: func(p *Peer, rw MsgReadWriter) error { |
||||
for { |
||||
msg, err := rw.ReadMsg() |
||||
if err != nil { |
||||
return err |
||||
} |
||||
if err = msg.Discard(); err != nil { |
||||
return err |
||||
} |
||||
} |
||||
}, |
||||
} |
||||
|
||||
func testPeer(protos []Protocol) (net.Conn, *Peer, <-chan error) { |
||||
conn1, conn2 := net.Pipe() |
||||
id := NewSimpleClientIdentity("test", "0", "0", "public key") |
||||
peer := newPeer(conn1, protos, nil) |
||||
peer.ourID = id |
||||
peer.pubkeyHook = func(*peerAddr) error { return nil } |
||||
errc := make(chan error, 1) |
||||
go func() { |
||||
_, err := peer.loop() |
||||
errc <- err |
||||
}() |
||||
return conn2, peer, errc |
||||
} |
||||
|
||||
func TestPeerProtoReadMsg(t *testing.T) { |
||||
defer testlog(t).detach() |
||||
|
||||
done := make(chan struct{}) |
||||
proto := Protocol{ |
||||
Name: "a", |
||||
Length: 5, |
||||
Run: func(peer *Peer, rw MsgReadWriter) error { |
||||
msg, err := rw.ReadMsg() |
||||
if err != nil { |
||||
t.Errorf("read error: %v", err) |
||||
} |
||||
if msg.Code != 2 { |
||||
t.Errorf("incorrect msg code %d relayed to protocol", msg.Code) |
||||
} |
||||
data, err := msg.Data() |
||||
if err != nil { |
||||
t.Errorf("data decoding error: %v", err) |
||||
} |
||||
expdata := []interface{}{1, []byte{0x30, 0x30, 0x30}} |
||||
if !reflect.DeepEqual(data.Slice(), expdata) { |
||||
t.Errorf("incorrect msg data %#v", data.Slice()) |
||||
} |
||||
close(done) |
||||
return nil |
||||
}, |
||||
} |
||||
|
||||
net, peer, errc := testPeer([]Protocol{proto}) |
||||
defer net.Close() |
||||
peer.startSubprotocols([]Cap{proto.cap()}) |
||||
|
||||
writeMsg(net, NewMsg(18, 1, "000")) |
||||
select { |
||||
case <-done: |
||||
case err := <-errc: |
||||
t.Errorf("peer returned: %v", err) |
||||
case <-time.After(2 * time.Second): |
||||
t.Errorf("receive timeout") |
||||
} |
||||
} |
||||
|
||||
func TestPeerProtoReadLargeMsg(t *testing.T) { |
||||
defer testlog(t).detach() |
||||
|
||||
msgsize := uint32(10 * 1024 * 1024) |
||||
done := make(chan struct{}) |
||||
proto := Protocol{ |
||||
Name: "a", |
||||
Length: 5, |
||||
Run: func(peer *Peer, rw MsgReadWriter) error { |
||||
msg, err := rw.ReadMsg() |
||||
if err != nil { |
||||
t.Errorf("read error: %v", err) |
||||
} |
||||
if msg.Size != msgsize+4 { |
||||
t.Errorf("incorrect msg.Size, got %d, expected %d", msg.Size, msgsize) |
||||
} |
||||
msg.Discard() |
||||
close(done) |
||||
return nil |
||||
}, |
||||
} |
||||
|
||||
net, peer, errc := testPeer([]Protocol{proto}) |
||||
defer net.Close() |
||||
peer.startSubprotocols([]Cap{proto.cap()}) |
||||
|
||||
writeMsg(net, NewMsg(18, make([]byte, msgsize))) |
||||
select { |
||||
case <-done: |
||||
case err := <-errc: |
||||
t.Errorf("peer returned: %v", err) |
||||
case <-time.After(2 * time.Second): |
||||
t.Errorf("receive timeout") |
||||
} |
||||
} |
||||
|
||||
func TestPeerProtoEncodeMsg(t *testing.T) { |
||||
defer testlog(t).detach() |
||||
|
||||
proto := Protocol{ |
||||
Name: "a", |
||||
Length: 2, |
||||
Run: func(peer *Peer, rw MsgReadWriter) error { |
||||
if err := rw.EncodeMsg(2); err == nil { |
||||
t.Error("expected error for out-of-range msg code, got nil") |
||||
} |
||||
if err := rw.EncodeMsg(1); err != nil { |
||||
t.Errorf("write error: %v", err) |
||||
} |
||||
return nil |
||||
}, |
||||
} |
||||
net, peer, _ := testPeer([]Protocol{proto}) |
||||
defer net.Close() |
||||
peer.startSubprotocols([]Cap{proto.cap()}) |
||||
|
||||
bufr := bufio.NewReader(net) |
||||
msg, err := readMsg(bufr) |
||||
if err != nil { |
||||
t.Errorf("read error: %v", err) |
||||
} |
||||
if msg.Code != 17 { |
||||
t.Errorf("incorrect message code: got %d, expected %d", msg.Code, 17) |
||||
} |
||||
} |
||||
|
||||
func TestPeerWrite(t *testing.T) { |
||||
defer testlog(t).detach() |
||||
|
||||
net, peer, peerErr := testPeer([]Protocol{discard}) |
||||
defer net.Close() |
||||
peer.startSubprotocols([]Cap{discard.cap()}) |
||||
|
||||
// test write errors
|
||||
if err := peer.writeProtoMsg("b", NewMsg(3)); err == nil { |
||||
t.Errorf("expected error for unknown protocol, got nil") |
||||
} |
||||
if err := peer.writeProtoMsg("discard", NewMsg(8)); err == nil { |
||||
t.Errorf("expected error for out-of-range msg code, got nil") |
||||
} else if perr, ok := err.(*peerError); !ok || perr.Code != errInvalidMsgCode { |
||||
t.Errorf("wrong error for out-of-range msg code, got %#v", err) |
||||
} |
||||
|
||||
// setup for reading the message on the other end
|
||||
read := make(chan struct{}) |
||||
go func() { |
||||
bufr := bufio.NewReader(net) |
||||
msg, err := readMsg(bufr) |
||||
if err != nil { |
||||
t.Errorf("read error: %v", err) |
||||
} else if msg.Code != 16 { |
||||
t.Errorf("wrong code, got %d, expected %d", msg.Code, 16) |
||||
} |
||||
msg.Discard() |
||||
close(read) |
||||
}() |
||||
|
||||
// test succcessful write
|
||||
if err := peer.writeProtoMsg("discard", NewMsg(0)); err != nil { |
||||
t.Errorf("expect no error for known protocol: %v", err) |
||||
} |
||||
select { |
||||
case <-read: |
||||
case err := <-peerErr: |
||||
t.Fatalf("peer stopped: %v", err) |
||||
} |
||||
} |
||||
|
||||
func TestPeerActivity(t *testing.T) { |
||||
// shorten inactivityTimeout while this test is running
|
||||
oldT := inactivityTimeout |
||||
defer func() { inactivityTimeout = oldT }() |
||||
inactivityTimeout = 20 * time.Millisecond |
||||
|
||||
net, peer, peerErr := testPeer([]Protocol{discard}) |
||||
defer net.Close() |
||||
peer.startSubprotocols([]Cap{discard.cap()}) |
||||
|
||||
sub := peer.activity.Subscribe(time.Time{}) |
||||
defer sub.Unsubscribe() |
||||
|
||||
for i := 0; i < 6; i++ { |
||||
writeMsg(net, NewMsg(16)) |
||||
select { |
||||
case <-sub.Chan(): |
||||
case <-time.After(inactivityTimeout / 2): |
||||
t.Fatal("no event within ", inactivityTimeout/2) |
||||
case err := <-peerErr: |
||||
t.Fatal("peer error", err) |
||||
} |
||||
} |
||||
|
||||
select { |
||||
case <-time.After(inactivityTimeout * 2): |
||||
case <-sub.Chan(): |
||||
t.Fatal("got activity event while connection was inactive") |
||||
case err := <-peerErr: |
||||
t.Fatal("peer error", err) |
||||
} |
||||
} |
||||
|
@ -0,0 +1,28 @@ |
||||
package p2p |
||||
|
||||
import ( |
||||
"testing" |
||||
|
||||
"github.com/ethereum/go-ethereum/logger" |
||||
) |
||||
|
||||
type testLogger struct{ t *testing.T } |
||||
|
||||
func testlog(t *testing.T) testLogger { |
||||
logger.Reset() |
||||
l := testLogger{t} |
||||
logger.AddLogSystem(l) |
||||
return l |
||||
} |
||||
|
||||
func (testLogger) GetLogLevel() logger.LogLevel { return logger.DebugLevel } |
||||
func (testLogger) SetLogLevel(logger.LogLevel) {} |
||||
|
||||
func (l testLogger) LogPrint(level logger.LogLevel, msg string) { |
||||
l.t.Logf("%s", msg) |
||||
} |
||||
|
||||
func (testLogger) detach() { |
||||
logger.Flush() |
||||
logger.Reset() |
||||
} |
@ -0,0 +1,40 @@ |
||||
// +build none
|
||||
|
||||
package main |
||||
|
||||
import ( |
||||
"fmt" |
||||
"log" |
||||
"net" |
||||
"os" |
||||
|
||||
"github.com/ethereum/go-ethereum/logger" |
||||
"github.com/ethereum/go-ethereum/p2p" |
||||
"github.com/obscuren/secp256k1-go" |
||||
) |
||||
|
||||
func main() { |
||||
logger.AddLogSystem(logger.NewStdLogSystem(os.Stdout, log.LstdFlags, logger.DebugLevel)) |
||||
|
||||
pub, _ := secp256k1.GenerateKeyPair() |
||||
srv := p2p.Server{ |
||||
MaxPeers: 10, |
||||
Identity: p2p.NewSimpleClientIdentity("test", "1.0", "", string(pub)), |
||||
ListenAddr: ":30303", |
||||
NAT: p2p.PMP(net.ParseIP("10.0.0.1")), |
||||
} |
||||
if err := srv.Start(); err != nil { |
||||
fmt.Println("could not start server:", err) |
||||
os.Exit(1) |
||||
} |
||||
|
||||
// add seed peers
|
||||
seed, err := net.ResolveTCPAddr("tcp", "poc-7.ethdev.com:30303") |
||||
if err != nil { |
||||
fmt.Println("couldn't resolve:", err) |
||||
os.Exit(1) |
||||
} |
||||
srv.SuggestPeer(seed.IP, seed.Port, nil) |
||||
|
||||
select {} |
||||
} |
Loading…
Reference in new issue