mirror of https://github.com/ethereum/go-ethereum
commit
195b2d2ebd
@ -1,275 +0,0 @@ |
|||||||
package p2p |
|
||||||
|
|
||||||
import ( |
|
||||||
"bytes" |
|
||||||
// "fmt"
|
|
||||||
"net" |
|
||||||
"time" |
|
||||||
|
|
||||||
"github.com/ethereum/go-ethereum/ethutil" |
|
||||||
) |
|
||||||
|
|
||||||
type Connection struct { |
|
||||||
conn net.Conn |
|
||||||
// conn NetworkConnection
|
|
||||||
timeout time.Duration |
|
||||||
in chan []byte |
|
||||||
out chan []byte |
|
||||||
err chan *PeerError |
|
||||||
closingIn chan chan bool |
|
||||||
closingOut chan chan bool |
|
||||||
} |
|
||||||
|
|
||||||
// const readBufferLength = 2 //for testing
|
|
||||||
|
|
||||||
const readBufferLength = 1440 |
|
||||||
const partialsQueueSize = 10 |
|
||||||
const maxPendingQueueSize = 1 |
|
||||||
const defaultTimeout = 500 |
|
||||||
|
|
||||||
var magicToken = []byte{34, 64, 8, 145} |
|
||||||
|
|
||||||
func (self *Connection) Open() { |
|
||||||
go self.startRead() |
|
||||||
go self.startWrite() |
|
||||||
} |
|
||||||
|
|
||||||
func (self *Connection) Close() { |
|
||||||
self.closeIn() |
|
||||||
self.closeOut() |
|
||||||
} |
|
||||||
|
|
||||||
func (self *Connection) closeIn() { |
|
||||||
errc := make(chan bool) |
|
||||||
self.closingIn <- errc |
|
||||||
<-errc |
|
||||||
} |
|
||||||
|
|
||||||
func (self *Connection) closeOut() { |
|
||||||
errc := make(chan bool) |
|
||||||
self.closingOut <- errc |
|
||||||
<-errc |
|
||||||
} |
|
||||||
|
|
||||||
func NewConnection(conn net.Conn, errchan chan *PeerError) *Connection { |
|
||||||
return &Connection{ |
|
||||||
conn: conn, |
|
||||||
timeout: defaultTimeout, |
|
||||||
in: make(chan []byte), |
|
||||||
out: make(chan []byte), |
|
||||||
err: errchan, |
|
||||||
closingIn: make(chan chan bool, 1), |
|
||||||
closingOut: make(chan chan bool, 1), |
|
||||||
} |
|
||||||
} |
|
||||||
|
|
||||||
func (self *Connection) Read() <-chan []byte { |
|
||||||
return self.in |
|
||||||
} |
|
||||||
|
|
||||||
func (self *Connection) Write() chan<- []byte { |
|
||||||
return self.out |
|
||||||
} |
|
||||||
|
|
||||||
func (self *Connection) Error() <-chan *PeerError { |
|
||||||
return self.err |
|
||||||
} |
|
||||||
|
|
||||||
func (self *Connection) startRead() { |
|
||||||
payloads := make(chan []byte) |
|
||||||
done := make(chan *PeerError) |
|
||||||
pending := [][]byte{} |
|
||||||
var head []byte |
|
||||||
var wait time.Duration // initally 0 (no delay)
|
|
||||||
read := time.After(wait * time.Millisecond) |
|
||||||
|
|
||||||
for { |
|
||||||
// if pending empty, nil channel blocks
|
|
||||||
var in chan []byte |
|
||||||
if len(pending) > 0 { |
|
||||||
in = self.in // enable send case
|
|
||||||
head = pending[0] |
|
||||||
} else { |
|
||||||
in = nil |
|
||||||
} |
|
||||||
|
|
||||||
select { |
|
||||||
case <-read: |
|
||||||
go self.read(payloads, done) |
|
||||||
case err := <-done: |
|
||||||
if err == nil { // no error but nothing to read
|
|
||||||
if len(pending) < maxPendingQueueSize { |
|
||||||
wait = 100 |
|
||||||
} else if wait == 0 { |
|
||||||
wait = 100 |
|
||||||
} else { |
|
||||||
wait = 2 * wait |
|
||||||
} |
|
||||||
} else { |
|
||||||
self.err <- err // report error
|
|
||||||
wait = 100 |
|
||||||
} |
|
||||||
read = time.After(wait * time.Millisecond) |
|
||||||
case payload := <-payloads: |
|
||||||
pending = append(pending, payload) |
|
||||||
if len(pending) < maxPendingQueueSize { |
|
||||||
wait = 0 |
|
||||||
} else { |
|
||||||
wait = 100 |
|
||||||
} |
|
||||||
read = time.After(wait * time.Millisecond) |
|
||||||
case in <- head: |
|
||||||
pending = pending[1:] |
|
||||||
case errc := <-self.closingIn: |
|
||||||
errc <- true |
|
||||||
close(self.in) |
|
||||||
return |
|
||||||
} |
|
||||||
|
|
||||||
} |
|
||||||
} |
|
||||||
|
|
||||||
func (self *Connection) startWrite() { |
|
||||||
pending := [][]byte{} |
|
||||||
done := make(chan *PeerError) |
|
||||||
writing := false |
|
||||||
for { |
|
||||||
if len(pending) > 0 && !writing { |
|
||||||
writing = true |
|
||||||
go self.write(pending[0], done) |
|
||||||
} |
|
||||||
select { |
|
||||||
case payload := <-self.out: |
|
||||||
pending = append(pending, payload) |
|
||||||
case err := <-done: |
|
||||||
if err == nil { |
|
||||||
pending = pending[1:] |
|
||||||
writing = false |
|
||||||
} else { |
|
||||||
self.err <- err // report error
|
|
||||||
} |
|
||||||
case errc := <-self.closingOut: |
|
||||||
errc <- true |
|
||||||
close(self.out) |
|
||||||
return |
|
||||||
} |
|
||||||
} |
|
||||||
} |
|
||||||
|
|
||||||
func pack(payload []byte) (packet []byte) { |
|
||||||
length := ethutil.NumberToBytes(uint32(len(payload)), 32) |
|
||||||
// return error if too long?
|
|
||||||
// Write magic token and payload length (first 8 bytes)
|
|
||||||
packet = append(magicToken, length...) |
|
||||||
packet = append(packet, payload...) |
|
||||||
return |
|
||||||
} |
|
||||||
|
|
||||||
func avoidPanic(done chan *PeerError) { |
|
||||||
if rec := recover(); rec != nil { |
|
||||||
err := NewPeerError(MiscError, " %v", rec) |
|
||||||
logger.Debugln(err) |
|
||||||
done <- err |
|
||||||
} |
|
||||||
} |
|
||||||
|
|
||||||
func (self *Connection) write(payload []byte, done chan *PeerError) { |
|
||||||
defer avoidPanic(done) |
|
||||||
var err *PeerError |
|
||||||
_, ok := self.conn.Write(pack(payload)) |
|
||||||
if ok != nil { |
|
||||||
err = NewPeerError(WriteError, " %v", ok) |
|
||||||
logger.Debugln(err) |
|
||||||
} |
|
||||||
done <- err |
|
||||||
} |
|
||||||
|
|
||||||
func (self *Connection) read(payloads chan []byte, done chan *PeerError) { |
|
||||||
//defer avoidPanic(done)
|
|
||||||
|
|
||||||
partials := make(chan []byte, partialsQueueSize) |
|
||||||
errc := make(chan *PeerError) |
|
||||||
go self.readPartials(partials, errc) |
|
||||||
|
|
||||||
packet := []byte{} |
|
||||||
length := 8 |
|
||||||
start := true |
|
||||||
var err *PeerError |
|
||||||
out: |
|
||||||
for { |
|
||||||
// appends partials read via connection until packet is
|
|
||||||
// - either parseable (>=8bytes)
|
|
||||||
// - or complete (payload fully consumed)
|
|
||||||
for len(packet) < length { |
|
||||||
partial, ok := <-partials |
|
||||||
if !ok { // partials channel is closed
|
|
||||||
err = <-errc |
|
||||||
if err == nil && len(packet) > 0 { |
|
||||||
if start { |
|
||||||
err = NewPeerError(PacketTooShort, "%v", packet) |
|
||||||
} else { |
|
||||||
err = NewPeerError(PayloadTooShort, "%d < %d", len(packet), length) |
|
||||||
} |
|
||||||
} |
|
||||||
break out |
|
||||||
} |
|
||||||
packet = append(packet, partial...) |
|
||||||
} |
|
||||||
if start { |
|
||||||
// at least 8 bytes read, can validate packet
|
|
||||||
if bytes.Compare(magicToken, packet[:4]) != 0 { |
|
||||||
err = NewPeerError(MagicTokenMismatch, " received %v", packet[:4]) |
|
||||||
break |
|
||||||
} |
|
||||||
length = int(ethutil.BytesToNumber(packet[4:8])) |
|
||||||
packet = packet[8:] |
|
||||||
|
|
||||||
if length > 0 { |
|
||||||
start = false // now consuming payload
|
|
||||||
} else { //penalize peer but read on
|
|
||||||
self.err <- NewPeerError(EmptyPayload, "") |
|
||||||
length = 8 |
|
||||||
} |
|
||||||
} else { |
|
||||||
// packet complete (payload fully consumed)
|
|
||||||
payloads <- packet[:length] |
|
||||||
packet = packet[length:] // resclice packet
|
|
||||||
start = true |
|
||||||
length = 8 |
|
||||||
} |
|
||||||
} |
|
||||||
|
|
||||||
// this stops partials read via the connection, should we?
|
|
||||||
//if err != nil {
|
|
||||||
// select {
|
|
||||||
// case errc <- err
|
|
||||||
// default:
|
|
||||||
//}
|
|
||||||
done <- err |
|
||||||
} |
|
||||||
|
|
||||||
func (self *Connection) readPartials(partials chan []byte, errc chan *PeerError) { |
|
||||||
defer close(partials) |
|
||||||
for { |
|
||||||
// Give buffering some time
|
|
||||||
self.conn.SetReadDeadline(time.Now().Add(self.timeout * time.Millisecond)) |
|
||||||
buffer := make([]byte, readBufferLength) |
|
||||||
// read partial from connection
|
|
||||||
bytesRead, err := self.conn.Read(buffer) |
|
||||||
if err == nil || err.Error() == "EOF" { |
|
||||||
if bytesRead > 0 { |
|
||||||
partials <- buffer[:bytesRead] |
|
||||||
} |
|
||||||
if err != nil && err.Error() == "EOF" { |
|
||||||
break |
|
||||||
} |
|
||||||
} else { |
|
||||||
// unexpected error, report to errc
|
|
||||||
err := NewPeerError(ReadError, " %v", err) |
|
||||||
logger.Debugln(err) |
|
||||||
errc <- err |
|
||||||
return // will close partials channel
|
|
||||||
} |
|
||||||
} |
|
||||||
close(errc) |
|
||||||
} |
|
@ -1,222 +0,0 @@ |
|||||||
package p2p |
|
||||||
|
|
||||||
import ( |
|
||||||
"bytes" |
|
||||||
"fmt" |
|
||||||
"io" |
|
||||||
"net" |
|
||||||
"testing" |
|
||||||
"time" |
|
||||||
) |
|
||||||
|
|
||||||
type TestNetworkConnection struct { |
|
||||||
in chan []byte |
|
||||||
current []byte |
|
||||||
Out [][]byte |
|
||||||
addr net.Addr |
|
||||||
} |
|
||||||
|
|
||||||
func NewTestNetworkConnection(addr net.Addr) *TestNetworkConnection { |
|
||||||
return &TestNetworkConnection{ |
|
||||||
in: make(chan []byte), |
|
||||||
current: []byte{}, |
|
||||||
Out: [][]byte{}, |
|
||||||
addr: addr, |
|
||||||
} |
|
||||||
} |
|
||||||
|
|
||||||
func (self *TestNetworkConnection) In(latency time.Duration, packets ...[]byte) { |
|
||||||
time.Sleep(latency) |
|
||||||
for _, s := range packets { |
|
||||||
self.in <- s |
|
||||||
} |
|
||||||
} |
|
||||||
|
|
||||||
func (self *TestNetworkConnection) Read(buff []byte) (n int, err error) { |
|
||||||
if len(self.current) == 0 { |
|
||||||
select { |
|
||||||
case self.current = <-self.in: |
|
||||||
default: |
|
||||||
return 0, io.EOF |
|
||||||
} |
|
||||||
} |
|
||||||
length := len(self.current) |
|
||||||
if length > len(buff) { |
|
||||||
copy(buff[:], self.current[:len(buff)]) |
|
||||||
self.current = self.current[len(buff):] |
|
||||||
return len(buff), nil |
|
||||||
} else { |
|
||||||
copy(buff[:length], self.current[:]) |
|
||||||
self.current = []byte{} |
|
||||||
return length, io.EOF |
|
||||||
} |
|
||||||
} |
|
||||||
|
|
||||||
func (self *TestNetworkConnection) Write(buff []byte) (n int, err error) { |
|
||||||
self.Out = append(self.Out, buff) |
|
||||||
fmt.Printf("net write %v\n%v\n", len(self.Out), buff) |
|
||||||
return len(buff), nil |
|
||||||
} |
|
||||||
|
|
||||||
func (self *TestNetworkConnection) Close() (err error) { |
|
||||||
return |
|
||||||
} |
|
||||||
|
|
||||||
func (self *TestNetworkConnection) LocalAddr() (addr net.Addr) { |
|
||||||
return |
|
||||||
} |
|
||||||
|
|
||||||
func (self *TestNetworkConnection) RemoteAddr() (addr net.Addr) { |
|
||||||
return self.addr |
|
||||||
} |
|
||||||
|
|
||||||
func (self *TestNetworkConnection) SetDeadline(t time.Time) (err error) { |
|
||||||
return |
|
||||||
} |
|
||||||
|
|
||||||
func (self *TestNetworkConnection) SetReadDeadline(t time.Time) (err error) { |
|
||||||
return |
|
||||||
} |
|
||||||
|
|
||||||
func (self *TestNetworkConnection) SetWriteDeadline(t time.Time) (err error) { |
|
||||||
return |
|
||||||
} |
|
||||||
|
|
||||||
func setupConnection() (*Connection, *TestNetworkConnection) { |
|
||||||
addr := &TestAddr{"test:30303"} |
|
||||||
net := NewTestNetworkConnection(addr) |
|
||||||
conn := NewConnection(net, NewPeerErrorChannel()) |
|
||||||
conn.Open() |
|
||||||
return conn, net |
|
||||||
} |
|
||||||
|
|
||||||
func TestReadingNilPacket(t *testing.T) { |
|
||||||
conn, net := setupConnection() |
|
||||||
go net.In(0, []byte{}) |
|
||||||
// time.Sleep(10 * time.Millisecond)
|
|
||||||
select { |
|
||||||
case packet := <-conn.Read(): |
|
||||||
t.Errorf("read %v", packet) |
|
||||||
case err := <-conn.Error(): |
|
||||||
t.Errorf("incorrect error %v", err) |
|
||||||
default: |
|
||||||
} |
|
||||||
conn.Close() |
|
||||||
} |
|
||||||
|
|
||||||
func TestReadingShortPacket(t *testing.T) { |
|
||||||
conn, net := setupConnection() |
|
||||||
go net.In(0, []byte{0}) |
|
||||||
select { |
|
||||||
case packet := <-conn.Read(): |
|
||||||
t.Errorf("read %v", packet) |
|
||||||
case err := <-conn.Error(): |
|
||||||
if err.Code != PacketTooShort { |
|
||||||
t.Errorf("incorrect error %v, expected %v", err.Code, PacketTooShort) |
|
||||||
} |
|
||||||
} |
|
||||||
conn.Close() |
|
||||||
} |
|
||||||
|
|
||||||
func TestReadingInvalidPacket(t *testing.T) { |
|
||||||
conn, net := setupConnection() |
|
||||||
go net.In(0, []byte{1, 0, 0, 0, 0, 0, 0, 0}) |
|
||||||
select { |
|
||||||
case packet := <-conn.Read(): |
|
||||||
t.Errorf("read %v", packet) |
|
||||||
case err := <-conn.Error(): |
|
||||||
if err.Code != MagicTokenMismatch { |
|
||||||
t.Errorf("incorrect error %v, expected %v", err.Code, MagicTokenMismatch) |
|
||||||
} |
|
||||||
} |
|
||||||
conn.Close() |
|
||||||
} |
|
||||||
|
|
||||||
func TestReadingInvalidPayload(t *testing.T) { |
|
||||||
conn, net := setupConnection() |
|
||||||
go net.In(0, []byte{34, 64, 8, 145, 0, 0, 0, 2, 0}) |
|
||||||
select { |
|
||||||
case packet := <-conn.Read(): |
|
||||||
t.Errorf("read %v", packet) |
|
||||||
case err := <-conn.Error(): |
|
||||||
if err.Code != PayloadTooShort { |
|
||||||
t.Errorf("incorrect error %v, expected %v", err.Code, PayloadTooShort) |
|
||||||
} |
|
||||||
} |
|
||||||
conn.Close() |
|
||||||
} |
|
||||||
|
|
||||||
func TestReadingEmptyPayload(t *testing.T) { |
|
||||||
conn, net := setupConnection() |
|
||||||
go net.In(0, []byte{34, 64, 8, 145, 0, 0, 0, 0}) |
|
||||||
time.Sleep(10 * time.Millisecond) |
|
||||||
select { |
|
||||||
case packet := <-conn.Read(): |
|
||||||
t.Errorf("read %v", packet) |
|
||||||
default: |
|
||||||
} |
|
||||||
select { |
|
||||||
case err := <-conn.Error(): |
|
||||||
code := err.Code |
|
||||||
if code != EmptyPayload { |
|
||||||
t.Errorf("incorrect error, expected EmptyPayload, got %v", code) |
|
||||||
} |
|
||||||
default: |
|
||||||
t.Errorf("no error, expected EmptyPayload") |
|
||||||
} |
|
||||||
conn.Close() |
|
||||||
} |
|
||||||
|
|
||||||
func TestReadingCompletePacket(t *testing.T) { |
|
||||||
conn, net := setupConnection() |
|
||||||
go net.In(0, []byte{34, 64, 8, 145, 0, 0, 0, 1, 1}) |
|
||||||
time.Sleep(10 * time.Millisecond) |
|
||||||
select { |
|
||||||
case packet := <-conn.Read(): |
|
||||||
if bytes.Compare(packet, []byte{1}) != 0 { |
|
||||||
t.Errorf("incorrect payload read") |
|
||||||
} |
|
||||||
case err := <-conn.Error(): |
|
||||||
t.Errorf("incorrect error %v", err) |
|
||||||
default: |
|
||||||
t.Errorf("nothing read") |
|
||||||
} |
|
||||||
conn.Close() |
|
||||||
} |
|
||||||
|
|
||||||
func TestReadingTwoCompletePackets(t *testing.T) { |
|
||||||
conn, net := setupConnection() |
|
||||||
go net.In(0, []byte{34, 64, 8, 145, 0, 0, 0, 1, 0, 34, 64, 8, 145, 0, 0, 0, 1, 1}) |
|
||||||
|
|
||||||
for i := 0; i < 2; i++ { |
|
||||||
time.Sleep(10 * time.Millisecond) |
|
||||||
select { |
|
||||||
case packet := <-conn.Read(): |
|
||||||
if bytes.Compare(packet, []byte{byte(i)}) != 0 { |
|
||||||
t.Errorf("incorrect payload read") |
|
||||||
} |
|
||||||
case err := <-conn.Error(): |
|
||||||
t.Errorf("incorrect error %v", err) |
|
||||||
default: |
|
||||||
t.Errorf("nothing read") |
|
||||||
} |
|
||||||
} |
|
||||||
conn.Close() |
|
||||||
} |
|
||||||
|
|
||||||
func TestWriting(t *testing.T) { |
|
||||||
conn, net := setupConnection() |
|
||||||
conn.Write() <- []byte{0} |
|
||||||
time.Sleep(10 * time.Millisecond) |
|
||||||
if len(net.Out) == 0 { |
|
||||||
t.Errorf("no output") |
|
||||||
} else { |
|
||||||
out := net.Out[0] |
|
||||||
if bytes.Compare(out, []byte{34, 64, 8, 145, 0, 0, 0, 1, 0}) != 0 { |
|
||||||
t.Errorf("incorrect packet %v", out) |
|
||||||
} |
|
||||||
} |
|
||||||
conn.Close() |
|
||||||
} |
|
||||||
|
|
||||||
// hello packet with client id ABC: 0x22 40 08 91 00 00 00 08 84 00 00 00 43414243
|
|
@ -1,75 +1,155 @@ |
|||||||
package p2p |
package p2p |
||||||
|
|
||||||
import ( |
import ( |
||||||
// "fmt"
|
"bytes" |
||||||
|
"encoding/binary" |
||||||
|
"io" |
||||||
|
"io/ioutil" |
||||||
|
"math/big" |
||||||
|
|
||||||
"github.com/ethereum/go-ethereum/ethutil" |
"github.com/ethereum/go-ethereum/ethutil" |
||||||
|
"github.com/ethereum/go-ethereum/rlp" |
||||||
) |
) |
||||||
|
|
||||||
type MsgCode uint8 |
// Msg defines the structure of a p2p message.
|
||||||
|
//
|
||||||
|
// Note that a Msg can only be sent once since the Payload reader is
|
||||||
|
// consumed during sending. It is not possible to create a Msg and
|
||||||
|
// send it any number of times. If you want to reuse an encoded
|
||||||
|
// structure, encode the payload into a byte array and create a
|
||||||
|
// separate Msg with a bytes.Reader as Payload for each send.
|
||||||
type Msg struct { |
type Msg struct { |
||||||
code MsgCode // this is the raw code as per adaptive msg code scheme
|
Code uint64 |
||||||
data *ethutil.Value |
Size uint32 // size of the paylod
|
||||||
encoded []byte |
Payload io.Reader |
||||||
|
} |
||||||
|
|
||||||
|
// NewMsg creates an RLP-encoded message with the given code.
|
||||||
|
func NewMsg(code uint64, params ...interface{}) Msg { |
||||||
|
buf := new(bytes.Buffer) |
||||||
|
for _, p := range params { |
||||||
|
buf.Write(ethutil.Encode(p)) |
||||||
|
} |
||||||
|
return Msg{Code: code, Size: uint32(buf.Len()), Payload: buf} |
||||||
|
} |
||||||
|
|
||||||
|
func encodePayload(params ...interface{}) []byte { |
||||||
|
buf := new(bytes.Buffer) |
||||||
|
for _, p := range params { |
||||||
|
buf.Write(ethutil.Encode(p)) |
||||||
|
} |
||||||
|
return buf.Bytes() |
||||||
|
} |
||||||
|
|
||||||
|
// Decode parse the RLP content of a message into
|
||||||
|
// the given value, which must be a pointer.
|
||||||
|
//
|
||||||
|
// For the decoding rules, please see package rlp.
|
||||||
|
func (msg Msg) Decode(val interface{}) error { |
||||||
|
s := rlp.NewListStream(msg.Payload, uint64(msg.Size)) |
||||||
|
return s.Decode(val) |
||||||
} |
} |
||||||
|
|
||||||
func (self *Msg) Code() MsgCode { |
// Discard reads any remaining payload data into a black hole.
|
||||||
return self.code |
func (msg Msg) Discard() error { |
||||||
|
_, err := io.Copy(ioutil.Discard, msg.Payload) |
||||||
|
return err |
||||||
} |
} |
||||||
|
|
||||||
func (self *Msg) Data() *ethutil.Value { |
type MsgReader interface { |
||||||
return self.data |
ReadMsg() (Msg, error) |
||||||
} |
} |
||||||
|
|
||||||
func NewMsg(code MsgCode, params ...interface{}) (msg *Msg, err error) { |
type MsgWriter interface { |
||||||
|
// WriteMsg sends an existing message.
|
||||||
|
// The Payload reader of the message is consumed.
|
||||||
|
// Note that messages can be sent only once.
|
||||||
|
WriteMsg(Msg) error |
||||||
|
|
||||||
// // data := [][]interface{}{}
|
// EncodeMsg writes an RLP-encoded message with the given
|
||||||
// data := []interface{}{}
|
// code and data elements.
|
||||||
// for _, value := range params {
|
EncodeMsg(code uint64, data ...interface{}) error |
||||||
// if encodable, ok := value.(ethutil.RlpEncodeDecode); ok {
|
|
||||||
// data = append(data, encodable.RlpValue())
|
|
||||||
// } else if raw, ok := value.([]interface{}); ok {
|
|
||||||
// data = append(data, raw)
|
|
||||||
// } else {
|
|
||||||
// // data = append(data, interface{}(raw))
|
|
||||||
// err = fmt.Errorf("Unable to encode object of type %T", value)
|
|
||||||
// return
|
|
||||||
// }
|
|
||||||
// }
|
|
||||||
return &Msg{ |
|
||||||
code: code, |
|
||||||
data: ethutil.NewValue(interface{}(params)), |
|
||||||
}, nil |
|
||||||
} |
} |
||||||
|
|
||||||
func NewMsgFromBytes(encoded []byte) (msg *Msg, err error) { |
// MsgReadWriter provides reading and writing of encoded messages.
|
||||||
value := ethutil.NewValueFromBytes(encoded) |
type MsgReadWriter interface { |
||||||
// Type of message
|
MsgReader |
||||||
code := value.Get(0).Uint() |
MsgWriter |
||||||
// Actual data
|
} |
||||||
data := value.SliceFrom(1) |
|
||||||
|
var magicToken = []byte{34, 64, 8, 145} |
||||||
|
|
||||||
|
func writeMsg(w io.Writer, msg Msg) error { |
||||||
|
// TODO: handle case when Size + len(code) + len(listhdr) overflows uint32
|
||||||
|
code := ethutil.Encode(uint32(msg.Code)) |
||||||
|
listhdr := makeListHeader(msg.Size + uint32(len(code))) |
||||||
|
payloadLen := uint32(len(listhdr)) + uint32(len(code)) + msg.Size |
||||||
|
|
||||||
|
start := make([]byte, 8) |
||||||
|
copy(start, magicToken) |
||||||
|
binary.BigEndian.PutUint32(start[4:], payloadLen) |
||||||
|
|
||||||
msg = &Msg{ |
for _, b := range [][]byte{start, listhdr, code} { |
||||||
code: MsgCode(code), |
if _, err := w.Write(b); err != nil { |
||||||
data: data, |
return err |
||||||
// data: ethutil.NewValue(data),
|
|
||||||
encoded: encoded, |
|
||||||
} |
} |
||||||
return |
} |
||||||
|
_, err := io.CopyN(w, msg.Payload, int64(msg.Size)) |
||||||
|
return err |
||||||
|
} |
||||||
|
|
||||||
|
func makeListHeader(length uint32) []byte { |
||||||
|
if length < 56 { |
||||||
|
return []byte{byte(length + 0xc0)} |
||||||
|
} |
||||||
|
enc := big.NewInt(int64(length)).Bytes() |
||||||
|
lenb := byte(len(enc)) + 0xf7 |
||||||
|
return append([]byte{lenb}, enc...) |
||||||
|
} |
||||||
|
|
||||||
|
// readMsg reads a message header from r.
|
||||||
|
// It takes an rlp.ByteReader to ensure that the decoding doesn't buffer.
|
||||||
|
func readMsg(r rlp.ByteReader) (msg Msg, err error) { |
||||||
|
// read magic and payload size
|
||||||
|
start := make([]byte, 8) |
||||||
|
if _, err = io.ReadFull(r, start); err != nil { |
||||||
|
return msg, newPeerError(errRead, "%v", err) |
||||||
|
} |
||||||
|
if !bytes.HasPrefix(start, magicToken) { |
||||||
|
return msg, newPeerError(errMagicTokenMismatch, "got %x, want %x", start[:4], magicToken) |
||||||
|
} |
||||||
|
size := binary.BigEndian.Uint32(start[4:]) |
||||||
|
|
||||||
|
// decode start of RLP message to get the message code
|
||||||
|
posr := &postrack{r, 0} |
||||||
|
s := rlp.NewStream(posr) |
||||||
|
if _, err := s.List(); err != nil { |
||||||
|
return msg, err |
||||||
|
} |
||||||
|
code, err := s.Uint() |
||||||
|
if err != nil { |
||||||
|
return msg, err |
||||||
|
} |
||||||
|
payloadsize := size - posr.p |
||||||
|
return Msg{code, payloadsize, io.LimitReader(r, int64(payloadsize))}, nil |
||||||
|
} |
||||||
|
|
||||||
|
// postrack wraps an rlp.ByteReader with a position counter.
|
||||||
|
type postrack struct { |
||||||
|
r rlp.ByteReader |
||||||
|
p uint32 |
||||||
} |
} |
||||||
|
|
||||||
func (self *Msg) Decode(offset MsgCode) { |
func (r *postrack) Read(buf []byte) (int, error) { |
||||||
self.code = self.code - offset |
n, err := r.r.Read(buf) |
||||||
|
r.p += uint32(n) |
||||||
|
return n, err |
||||||
} |
} |
||||||
|
|
||||||
// encode takes an offset argument to implement adaptive message coding
|
func (r *postrack) ReadByte() (byte, error) { |
||||||
// the encoded message is memoized to make msgs relayed to several peers more efficient
|
b, err := r.r.ReadByte() |
||||||
func (self *Msg) Encode(offset MsgCode) (res []byte) { |
if err == nil { |
||||||
if len(self.encoded) == 0 { |
r.p++ |
||||||
res = ethutil.NewValue(append([]interface{}{byte(self.code + offset)}, self.data.Slice()...)).Encode() |
|
||||||
self.encoded = res |
|
||||||
} else { |
|
||||||
res = self.encoded |
|
||||||
} |
} |
||||||
return |
return b, err |
||||||
} |
} |
||||||
|
@ -1,38 +1,70 @@ |
|||||||
package p2p |
package p2p |
||||||
|
|
||||||
import ( |
import ( |
||||||
|
"bytes" |
||||||
|
"io/ioutil" |
||||||
"testing" |
"testing" |
||||||
|
|
||||||
|
"github.com/ethereum/go-ethereum/ethutil" |
||||||
) |
) |
||||||
|
|
||||||
func TestNewMsg(t *testing.T) { |
func TestNewMsg(t *testing.T) { |
||||||
msg, _ := NewMsg(3, 1, "000") |
msg := NewMsg(3, 1, "000") |
||||||
if msg.Code() != 3 { |
if msg.Code != 3 { |
||||||
t.Errorf("incorrect code %v", msg.Code()) |
t.Errorf("incorrect code %d, want %d", msg.Code) |
||||||
} |
} |
||||||
data0 := msg.Data().Get(0).Uint() |
if msg.Size != 5 { |
||||||
data1 := string(msg.Data().Get(1).Bytes()) |
t.Errorf("incorrect size %d, want %d", msg.Size, 5) |
||||||
if data0 != 1 { |
|
||||||
t.Errorf("incorrect data %v", data0) |
|
||||||
} |
} |
||||||
if data1 != "000" { |
pl, _ := ioutil.ReadAll(msg.Payload) |
||||||
t.Errorf("incorrect data %v", data1) |
expect := []byte{0x01, 0x83, 0x30, 0x30, 0x30} |
||||||
|
if !bytes.Equal(pl, expect) { |
||||||
|
t.Errorf("incorrect payload content, got %x, want %x", pl, expect) |
||||||
} |
} |
||||||
} |
} |
||||||
|
|
||||||
func TestEncodeDecodeMsg(t *testing.T) { |
func TestEncodeDecodeMsg(t *testing.T) { |
||||||
msg, _ := NewMsg(3, 1, "000") |
msg := NewMsg(3, 1, "000") |
||||||
encoded := msg.Encode(3) |
buf := new(bytes.Buffer) |
||||||
msg, _ = NewMsgFromBytes(encoded) |
if err := writeMsg(buf, msg); err != nil { |
||||||
msg.Decode(3) |
t.Fatalf("encodeMsg error: %v", err) |
||||||
if msg.Code() != 3 { |
} |
||||||
t.Errorf("incorrect code %v", msg.Code()) |
// t.Logf("encoded: %x", buf.Bytes())
|
||||||
} |
|
||||||
data0 := msg.Data().Get(0).Uint() |
decmsg, err := readMsg(buf) |
||||||
data1 := msg.Data().Get(1).Str() |
if err != nil { |
||||||
if data0 != 1 { |
t.Fatalf("readMsg error: %v", err) |
||||||
t.Errorf("incorrect data %v", data0) |
} |
||||||
} |
if decmsg.Code != 3 { |
||||||
if data1 != "000" { |
t.Errorf("incorrect code %d, want %d", decmsg.Code, 3) |
||||||
t.Errorf("incorrect data %v", data1) |
} |
||||||
|
if decmsg.Size != 5 { |
||||||
|
t.Errorf("incorrect size %d, want %d", decmsg.Size, 5) |
||||||
|
} |
||||||
|
|
||||||
|
var data struct { |
||||||
|
I int |
||||||
|
S string |
||||||
|
} |
||||||
|
if err := decmsg.Decode(&data); err != nil { |
||||||
|
t.Fatalf("Decode error: %v", err) |
||||||
|
} |
||||||
|
if data.I != 1 { |
||||||
|
t.Errorf("incorrect data.I: got %v, expected %d", data.I, 1) |
||||||
|
} |
||||||
|
if data.S != "000" { |
||||||
|
t.Errorf("incorrect data.S: got %q, expected %q", data.S, "000") |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
func TestDecodeRealMsg(t *testing.T) { |
||||||
|
data := ethutil.Hex2Bytes("2240089100000080f87e8002b5457468657265756d282b2b292f5065657220536572766572204f6e652f76302e372e382f52656c656173652f4c696e75782f672b2bc082765fb84086dd80b7aefd6a6d2e3b93f4f300a86bfb6ef7bdc97cb03f793db6bb") |
||||||
|
msg, err := readMsg(bytes.NewReader(data)) |
||||||
|
if err != nil { |
||||||
|
t.Fatalf("unexpected error: %v", err) |
||||||
|
} |
||||||
|
|
||||||
|
if msg.Code != 0 { |
||||||
|
t.Errorf("incorrect code %d, want %d", msg.Code, 0) |
||||||
} |
} |
||||||
} |
} |
||||||
|
@ -1,220 +0,0 @@ |
|||||||
package p2p |
|
||||||
|
|
||||||
import ( |
|
||||||
"fmt" |
|
||||||
"sync" |
|
||||||
"time" |
|
||||||
) |
|
||||||
|
|
||||||
const ( |
|
||||||
handlerTimeout = 1000 |
|
||||||
) |
|
||||||
|
|
||||||
type Handlers map[string](func(p *Peer) Protocol) |
|
||||||
|
|
||||||
type Messenger struct { |
|
||||||
conn *Connection |
|
||||||
peer *Peer |
|
||||||
handlers Handlers |
|
||||||
protocolLock sync.RWMutex |
|
||||||
protocols []Protocol |
|
||||||
offsets []MsgCode // offsets for adaptive message idss
|
|
||||||
protocolTable map[string]int |
|
||||||
quit chan chan bool |
|
||||||
err chan *PeerError |
|
||||||
pulse chan bool |
|
||||||
} |
|
||||||
|
|
||||||
func NewMessenger(peer *Peer, conn *Connection, errchan chan *PeerError, handlers Handlers) *Messenger { |
|
||||||
baseProtocol := NewBaseProtocol(peer) |
|
||||||
return &Messenger{ |
|
||||||
conn: conn, |
|
||||||
peer: peer, |
|
||||||
offsets: []MsgCode{baseProtocol.Offset()}, |
|
||||||
handlers: handlers, |
|
||||||
protocols: []Protocol{baseProtocol}, |
|
||||||
protocolTable: make(map[string]int), |
|
||||||
err: errchan, |
|
||||||
pulse: make(chan bool, 1), |
|
||||||
quit: make(chan chan bool, 1), |
|
||||||
} |
|
||||||
} |
|
||||||
|
|
||||||
func (self *Messenger) Start() { |
|
||||||
self.conn.Open() |
|
||||||
go self.messenger() |
|
||||||
self.protocolLock.RLock() |
|
||||||
defer self.protocolLock.RUnlock() |
|
||||||
self.protocols[0].Start() |
|
||||||
} |
|
||||||
|
|
||||||
func (self *Messenger) Stop() { |
|
||||||
// close pulse to stop ping pong monitoring
|
|
||||||
close(self.pulse) |
|
||||||
self.protocolLock.RLock() |
|
||||||
defer self.protocolLock.RUnlock() |
|
||||||
for _, protocol := range self.protocols { |
|
||||||
protocol.Stop() // could be parallel
|
|
||||||
} |
|
||||||
q := make(chan bool) |
|
||||||
self.quit <- q |
|
||||||
<-q |
|
||||||
self.conn.Close() |
|
||||||
} |
|
||||||
|
|
||||||
func (self *Messenger) messenger() { |
|
||||||
in := self.conn.Read() |
|
||||||
for { |
|
||||||
select { |
|
||||||
case payload, ok := <-in: |
|
||||||
//dispatches message to the protocol asynchronously
|
|
||||||
if ok { |
|
||||||
go self.handle(payload) |
|
||||||
} else { |
|
||||||
return |
|
||||||
} |
|
||||||
case q := <-self.quit: |
|
||||||
q <- true |
|
||||||
return |
|
||||||
} |
|
||||||
} |
|
||||||
} |
|
||||||
|
|
||||||
// handles each message by dispatching to the appropriate protocol
|
|
||||||
// using adaptive message codes
|
|
||||||
// this function is started as a separate go routine for each message
|
|
||||||
// it waits for the protocol response
|
|
||||||
// then encodes and sends outgoing messages to the connection's write channel
|
|
||||||
func (self *Messenger) handle(payload []byte) { |
|
||||||
// send ping to heartbeat channel signalling time of last message
|
|
||||||
// select {
|
|
||||||
// case self.pulse <- true:
|
|
||||||
// default:
|
|
||||||
// }
|
|
||||||
self.pulse <- true |
|
||||||
// initialise message from payload
|
|
||||||
msg, err := NewMsgFromBytes(payload) |
|
||||||
if err != nil { |
|
||||||
self.err <- NewPeerError(MiscError, " %v", err) |
|
||||||
return |
|
||||||
} |
|
||||||
// retrieves protocol based on message Code
|
|
||||||
protocol, offset, peerErr := self.getProtocol(msg.Code()) |
|
||||||
if err != nil { |
|
||||||
self.err <- peerErr |
|
||||||
return |
|
||||||
} |
|
||||||
// reset message code based on adaptive offset
|
|
||||||
msg.Decode(offset) |
|
||||||
// dispatches
|
|
||||||
response := make(chan *Msg) |
|
||||||
go protocol.HandleIn(msg, response) |
|
||||||
// protocol reponse timeout to prevent leaks
|
|
||||||
timer := time.After(handlerTimeout * time.Millisecond) |
|
||||||
for { |
|
||||||
select { |
|
||||||
case outgoing, ok := <-response: |
|
||||||
// we check if response channel is not closed
|
|
||||||
if ok { |
|
||||||
self.conn.Write() <- outgoing.Encode(offset) |
|
||||||
} else { |
|
||||||
return |
|
||||||
} |
|
||||||
case <-timer: |
|
||||||
return |
|
||||||
} |
|
||||||
} |
|
||||||
} |
|
||||||
|
|
||||||
// negotiated protocols
|
|
||||||
// stores offsets needed for adaptive message id scheme
|
|
||||||
|
|
||||||
// based on offsets set at handshake
|
|
||||||
// get the right protocol to handle the message
|
|
||||||
func (self *Messenger) getProtocol(code MsgCode) (Protocol, MsgCode, *PeerError) { |
|
||||||
self.protocolLock.RLock() |
|
||||||
defer self.protocolLock.RUnlock() |
|
||||||
base := MsgCode(0) |
|
||||||
for index, offset := range self.offsets { |
|
||||||
if code < offset { |
|
||||||
return self.protocols[index], base, nil |
|
||||||
} |
|
||||||
base = offset |
|
||||||
} |
|
||||||
return nil, MsgCode(0), NewPeerError(InvalidMsgCode, " %v", code) |
|
||||||
} |
|
||||||
|
|
||||||
func (self *Messenger) PingPong(timeout time.Duration, gracePeriod time.Duration, pingCallback func(), timeoutCallback func()) { |
|
||||||
fmt.Printf("pingpong keepalive started at %v", time.Now()) |
|
||||||
|
|
||||||
timer := time.After(timeout) |
|
||||||
pinged := false |
|
||||||
for { |
|
||||||
select { |
|
||||||
case _, ok := <-self.pulse: |
|
||||||
if ok { |
|
||||||
pinged = false |
|
||||||
timer = time.After(timeout) |
|
||||||
} else { |
|
||||||
// pulse is closed, stop monitoring
|
|
||||||
return |
|
||||||
} |
|
||||||
case <-timer: |
|
||||||
if pinged { |
|
||||||
fmt.Printf("timeout at %v", time.Now()) |
|
||||||
timeoutCallback() |
|
||||||
return |
|
||||||
} else { |
|
||||||
fmt.Printf("pinged at %v", time.Now()) |
|
||||||
pingCallback() |
|
||||||
timer = time.After(gracePeriod) |
|
||||||
pinged = true |
|
||||||
} |
|
||||||
} |
|
||||||
} |
|
||||||
} |
|
||||||
|
|
||||||
func (self *Messenger) AddProtocols(protocols []string) { |
|
||||||
self.protocolLock.Lock() |
|
||||||
defer self.protocolLock.Unlock() |
|
||||||
i := len(self.offsets) |
|
||||||
offset := self.offsets[i-1] |
|
||||||
for _, name := range protocols { |
|
||||||
protocolFunc, ok := self.handlers[name] |
|
||||||
if ok { |
|
||||||
protocol := protocolFunc(self.peer) |
|
||||||
self.protocolTable[name] = i |
|
||||||
i++ |
|
||||||
offset += protocol.Offset() |
|
||||||
fmt.Println("offset ", name, offset) |
|
||||||
|
|
||||||
self.offsets = append(self.offsets, offset) |
|
||||||
self.protocols = append(self.protocols, protocol) |
|
||||||
protocol.Start() |
|
||||||
} else { |
|
||||||
fmt.Println("no ", name) |
|
||||||
// protocol not handled
|
|
||||||
} |
|
||||||
} |
|
||||||
} |
|
||||||
|
|
||||||
func (self *Messenger) Write(protocol string, msg *Msg) error { |
|
||||||
self.protocolLock.RLock() |
|
||||||
defer self.protocolLock.RUnlock() |
|
||||||
i := 0 |
|
||||||
offset := MsgCode(0) |
|
||||||
if len(protocol) > 0 { |
|
||||||
var ok bool |
|
||||||
i, ok = self.protocolTable[protocol] |
|
||||||
if !ok { |
|
||||||
return fmt.Errorf("protocol %v not handled by peer", protocol) |
|
||||||
} |
|
||||||
offset = self.offsets[i-1] |
|
||||||
} |
|
||||||
handler := self.protocols[i] |
|
||||||
// checking if protocol status/caps allows the message to be sent out
|
|
||||||
if handler.HandleOut(msg) { |
|
||||||
self.conn.Write() <- msg.Encode(offset) |
|
||||||
} |
|
||||||
return nil |
|
||||||
} |
|
@ -1,147 +0,0 @@ |
|||||||
package p2p |
|
||||||
|
|
||||||
import ( |
|
||||||
// "fmt"
|
|
||||||
"bytes" |
|
||||||
"testing" |
|
||||||
"time" |
|
||||||
|
|
||||||
"github.com/ethereum/go-ethereum/ethutil" |
|
||||||
) |
|
||||||
|
|
||||||
func setupMessenger(handlers Handlers) (*TestNetworkConnection, chan *PeerError, *Messenger) { |
|
||||||
errchan := NewPeerErrorChannel() |
|
||||||
addr := &TestAddr{"test:30303"} |
|
||||||
net := NewTestNetworkConnection(addr) |
|
||||||
conn := NewConnection(net, errchan) |
|
||||||
mess := NewMessenger(nil, conn, errchan, handlers) |
|
||||||
mess.Start() |
|
||||||
return net, errchan, mess |
|
||||||
} |
|
||||||
|
|
||||||
type TestProtocol struct { |
|
||||||
Msgs []*Msg |
|
||||||
} |
|
||||||
|
|
||||||
func (self *TestProtocol) Start() { |
|
||||||
} |
|
||||||
|
|
||||||
func (self *TestProtocol) Stop() { |
|
||||||
} |
|
||||||
|
|
||||||
func (self *TestProtocol) Offset() MsgCode { |
|
||||||
return MsgCode(5) |
|
||||||
} |
|
||||||
|
|
||||||
func (self *TestProtocol) HandleIn(msg *Msg, response chan *Msg) { |
|
||||||
self.Msgs = append(self.Msgs, msg) |
|
||||||
close(response) |
|
||||||
} |
|
||||||
|
|
||||||
func (self *TestProtocol) HandleOut(msg *Msg) bool { |
|
||||||
if msg.Code() > 3 { |
|
||||||
return false |
|
||||||
} else { |
|
||||||
return true |
|
||||||
} |
|
||||||
} |
|
||||||
|
|
||||||
func (self *TestProtocol) Name() string { |
|
||||||
return "a" |
|
||||||
} |
|
||||||
|
|
||||||
func Packet(offset MsgCode, code MsgCode, params ...interface{}) []byte { |
|
||||||
msg, _ := NewMsg(code, params...) |
|
||||||
encoded := msg.Encode(offset) |
|
||||||
packet := []byte{34, 64, 8, 145} |
|
||||||
packet = append(packet, ethutil.NumberToBytes(uint32(len(encoded)), 32)...) |
|
||||||
return append(packet, encoded...) |
|
||||||
} |
|
||||||
|
|
||||||
func TestRead(t *testing.T) { |
|
||||||
handlers := make(Handlers) |
|
||||||
testProtocol := &TestProtocol{Msgs: []*Msg{}} |
|
||||||
handlers["a"] = func(p *Peer) Protocol { return testProtocol } |
|
||||||
net, _, mess := setupMessenger(handlers) |
|
||||||
mess.AddProtocols([]string{"a"}) |
|
||||||
defer mess.Stop() |
|
||||||
wait := 1 * time.Millisecond |
|
||||||
packet := Packet(16, 1, uint32(1), "000") |
|
||||||
go net.In(0, packet) |
|
||||||
time.Sleep(wait) |
|
||||||
if len(testProtocol.Msgs) != 1 { |
|
||||||
t.Errorf("msg not relayed to correct protocol") |
|
||||||
} else { |
|
||||||
if testProtocol.Msgs[0].Code() != 1 { |
|
||||||
t.Errorf("incorrect msg code relayed to protocol") |
|
||||||
} |
|
||||||
} |
|
||||||
} |
|
||||||
|
|
||||||
func TestWrite(t *testing.T) { |
|
||||||
handlers := make(Handlers) |
|
||||||
testProtocol := &TestProtocol{Msgs: []*Msg{}} |
|
||||||
handlers["a"] = func(p *Peer) Protocol { return testProtocol } |
|
||||||
net, _, mess := setupMessenger(handlers) |
|
||||||
mess.AddProtocols([]string{"a"}) |
|
||||||
defer mess.Stop() |
|
||||||
wait := 1 * time.Millisecond |
|
||||||
msg, _ := NewMsg(3, uint32(1), "000") |
|
||||||
err := mess.Write("b", msg) |
|
||||||
if err == nil { |
|
||||||
t.Errorf("expect error for unknown protocol") |
|
||||||
} |
|
||||||
err = mess.Write("a", msg) |
|
||||||
if err != nil { |
|
||||||
t.Errorf("expect no error for known protocol: %v", err) |
|
||||||
} else { |
|
||||||
time.Sleep(wait) |
|
||||||
if len(net.Out) != 1 { |
|
||||||
t.Errorf("msg not written") |
|
||||||
} else { |
|
||||||
out := net.Out[0] |
|
||||||
packet := Packet(16, 3, uint32(1), "000") |
|
||||||
if bytes.Compare(out, packet) != 0 { |
|
||||||
t.Errorf("incorrect packet %v", out) |
|
||||||
} |
|
||||||
} |
|
||||||
} |
|
||||||
} |
|
||||||
|
|
||||||
func TestPulse(t *testing.T) { |
|
||||||
net, _, mess := setupMessenger(make(Handlers)) |
|
||||||
defer mess.Stop() |
|
||||||
ping := false |
|
||||||
timeout := false |
|
||||||
pingTimeout := 10 * time.Millisecond |
|
||||||
gracePeriod := 200 * time.Millisecond |
|
||||||
go mess.PingPong(pingTimeout, gracePeriod, func() { ping = true }, func() { timeout = true }) |
|
||||||
net.In(0, Packet(0, 1)) |
|
||||||
if ping { |
|
||||||
t.Errorf("ping sent too early") |
|
||||||
} |
|
||||||
time.Sleep(pingTimeout + 100*time.Millisecond) |
|
||||||
if !ping { |
|
||||||
t.Errorf("no ping sent after timeout") |
|
||||||
} |
|
||||||
if timeout { |
|
||||||
t.Errorf("timeout too early") |
|
||||||
} |
|
||||||
ping = false |
|
||||||
net.In(0, Packet(0, 1)) |
|
||||||
time.Sleep(pingTimeout + 100*time.Millisecond) |
|
||||||
if !ping { |
|
||||||
t.Errorf("no ping sent after timeout") |
|
||||||
} |
|
||||||
if timeout { |
|
||||||
t.Errorf("timeout too early") |
|
||||||
} |
|
||||||
ping = false |
|
||||||
time.Sleep(gracePeriod) |
|
||||||
if ping { |
|
||||||
t.Errorf("ping called twice") |
|
||||||
} |
|
||||||
if !timeout { |
|
||||||
t.Errorf("no timeout after grace period") |
|
||||||
} |
|
||||||
} |
|
@ -1,196 +0,0 @@ |
|||||||
package p2p |
|
||||||
|
|
||||||
import ( |
|
||||||
"fmt" |
|
||||||
"math/rand" |
|
||||||
"net" |
|
||||||
"strconv" |
|
||||||
"time" |
|
||||||
) |
|
||||||
|
|
||||||
const ( |
|
||||||
DialerTimeout = 180 //seconds
|
|
||||||
KeepAlivePeriod = 60 //minutes
|
|
||||||
portMappingUpdateInterval = 900 // seconds = 15 mins
|
|
||||||
upnpDiscoverAttempts = 3 |
|
||||||
) |
|
||||||
|
|
||||||
// Dialer is not an interface in net, so we define one
|
|
||||||
// *net.Dialer conforms to this
|
|
||||||
type Dialer interface { |
|
||||||
Dial(network, address string) (net.Conn, error) |
|
||||||
} |
|
||||||
|
|
||||||
type Network interface { |
|
||||||
Start() error |
|
||||||
Listener(net.Addr) (net.Listener, error) |
|
||||||
Dialer(net.Addr) (Dialer, error) |
|
||||||
NewAddr(string, int) (addr net.Addr, err error) |
|
||||||
ParseAddr(string) (addr net.Addr, err error) |
|
||||||
} |
|
||||||
|
|
||||||
type NAT interface { |
|
||||||
GetExternalAddress() (addr net.IP, err error) |
|
||||||
AddPortMapping(protocol string, externalPort, internalPort int, description string, timeout int) (mappedExternalPort int, err error) |
|
||||||
DeletePortMapping(protocol string, externalPort, internalPort int) (err error) |
|
||||||
} |
|
||||||
|
|
||||||
type TCPNetwork struct { |
|
||||||
nat NAT |
|
||||||
natType NATType |
|
||||||
quit chan chan bool |
|
||||||
ports chan string |
|
||||||
} |
|
||||||
|
|
||||||
type NATType int |
|
||||||
|
|
||||||
const ( |
|
||||||
NONE = iota |
|
||||||
UPNP |
|
||||||
PMP |
|
||||||
) |
|
||||||
|
|
||||||
const ( |
|
||||||
portMappingTimeout = 1200 // 20 mins
|
|
||||||
) |
|
||||||
|
|
||||||
func NewTCPNetwork(natType NATType) (net *TCPNetwork) { |
|
||||||
return &TCPNetwork{ |
|
||||||
natType: natType, |
|
||||||
ports: make(chan string), |
|
||||||
} |
|
||||||
} |
|
||||||
|
|
||||||
func (self *TCPNetwork) Dialer(addr net.Addr) (Dialer, error) { |
|
||||||
return &net.Dialer{ |
|
||||||
Timeout: DialerTimeout * time.Second, |
|
||||||
// KeepAlive: KeepAlivePeriod * time.Minute,
|
|
||||||
LocalAddr: addr, |
|
||||||
}, nil |
|
||||||
} |
|
||||||
|
|
||||||
func (self *TCPNetwork) Listener(addr net.Addr) (net.Listener, error) { |
|
||||||
if self.natType == UPNP { |
|
||||||
_, port, _ := net.SplitHostPort(addr.String()) |
|
||||||
if self.quit == nil { |
|
||||||
self.quit = make(chan chan bool) |
|
||||||
go self.updatePortMappings() |
|
||||||
} |
|
||||||
self.ports <- port |
|
||||||
} |
|
||||||
return net.Listen(addr.Network(), addr.String()) |
|
||||||
} |
|
||||||
|
|
||||||
func (self *TCPNetwork) Start() (err error) { |
|
||||||
switch self.natType { |
|
||||||
case NONE: |
|
||||||
case UPNP: |
|
||||||
nat, uerr := upnpDiscover(upnpDiscoverAttempts) |
|
||||||
if uerr != nil { |
|
||||||
err = fmt.Errorf("UPNP failed: ", uerr) |
|
||||||
} else { |
|
||||||
self.nat = nat |
|
||||||
} |
|
||||||
case PMP: |
|
||||||
err = fmt.Errorf("PMP not implemented") |
|
||||||
default: |
|
||||||
err = fmt.Errorf("Invalid NAT type: %v", self.natType) |
|
||||||
} |
|
||||||
return |
|
||||||
} |
|
||||||
|
|
||||||
func (self *TCPNetwork) Stop() { |
|
||||||
q := make(chan bool) |
|
||||||
self.quit <- q |
|
||||||
<-q |
|
||||||
} |
|
||||||
|
|
||||||
func (self *TCPNetwork) addPortMapping(lport int) (err error) { |
|
||||||
_, err = self.nat.AddPortMapping("TCP", lport, lport, "p2p listen port", portMappingTimeout) |
|
||||||
if err != nil { |
|
||||||
logger.Errorf("unable to add port mapping on %v: %v", lport, err) |
|
||||||
} else { |
|
||||||
logger.Debugf("succesfully added port mapping on %v", lport) |
|
||||||
} |
|
||||||
return |
|
||||||
} |
|
||||||
|
|
||||||
func (self *TCPNetwork) updatePortMappings() { |
|
||||||
timer := time.NewTimer(portMappingUpdateInterval * time.Second) |
|
||||||
lports := []int{} |
|
||||||
out: |
|
||||||
for { |
|
||||||
select { |
|
||||||
case port := <-self.ports: |
|
||||||
int64lport, _ := strconv.ParseInt(port, 10, 16) |
|
||||||
lport := int(int64lport) |
|
||||||
if err := self.addPortMapping(lport); err != nil { |
|
||||||
lports = append(lports, lport) |
|
||||||
} |
|
||||||
case <-timer.C: |
|
||||||
for lport := range lports { |
|
||||||
if err := self.addPortMapping(lport); err != nil { |
|
||||||
} |
|
||||||
} |
|
||||||
case errc := <-self.quit: |
|
||||||
errc <- true |
|
||||||
break out |
|
||||||
} |
|
||||||
} |
|
||||||
|
|
||||||
timer.Stop() |
|
||||||
for lport := range lports { |
|
||||||
if err := self.nat.DeletePortMapping("TCP", lport, lport); err != nil { |
|
||||||
logger.Debugf("unable to remove port mapping on %v: %v", lport, err) |
|
||||||
} else { |
|
||||||
logger.Debugf("succesfully removed port mapping on %v", lport) |
|
||||||
} |
|
||||||
} |
|
||||||
} |
|
||||||
|
|
||||||
func (self *TCPNetwork) NewAddr(host string, port int) (net.Addr, error) { |
|
||||||
ip, err := self.lookupIP(host) |
|
||||||
if err == nil { |
|
||||||
return &net.TCPAddr{ |
|
||||||
IP: ip, |
|
||||||
Port: port, |
|
||||||
}, nil |
|
||||||
} |
|
||||||
return nil, err |
|
||||||
} |
|
||||||
|
|
||||||
func (self *TCPNetwork) ParseAddr(address string) (net.Addr, error) { |
|
||||||
host, port, err := net.SplitHostPort(address) |
|
||||||
if err == nil { |
|
||||||
iport, _ := strconv.Atoi(port) |
|
||||||
addr, e := self.NewAddr(host, iport) |
|
||||||
return addr, e |
|
||||||
} |
|
||||||
return nil, err |
|
||||||
} |
|
||||||
|
|
||||||
func (*TCPNetwork) lookupIP(host string) (ip net.IP, err error) { |
|
||||||
if ip = net.ParseIP(host); ip != nil { |
|
||||||
return |
|
||||||
} |
|
||||||
|
|
||||||
var ips []net.IP |
|
||||||
ips, err = net.LookupIP(host) |
|
||||||
if err != nil { |
|
||||||
logger.Warnln(err) |
|
||||||
return |
|
||||||
} |
|
||||||
if len(ips) == 0 { |
|
||||||
err = fmt.Errorf("No IP addresses available for %v", host) |
|
||||||
logger.Warnln(err) |
|
||||||
return |
|
||||||
} |
|
||||||
if len(ips) > 1 { |
|
||||||
// Pick a random IP address, simulating round-robin DNS.
|
|
||||||
rand.Seed(time.Now().UTC().UnixNano()) |
|
||||||
ip = ips[rand.Intn(len(ips))] |
|
||||||
} else { |
|
||||||
ip = ips[0] |
|
||||||
} |
|
||||||
return |
|
||||||
} |
|
@ -1,83 +1,455 @@ |
|||||||
package p2p |
package p2p |
||||||
|
|
||||||
import ( |
import ( |
||||||
|
"bufio" |
||||||
|
"bytes" |
||||||
"fmt" |
"fmt" |
||||||
|
"io" |
||||||
|
"io/ioutil" |
||||||
"net" |
"net" |
||||||
"strconv" |
"sort" |
||||||
|
"sync" |
||||||
|
"time" |
||||||
|
|
||||||
|
"github.com/ethereum/go-ethereum/event" |
||||||
|
"github.com/ethereum/go-ethereum/logger" |
||||||
) |
) |
||||||
|
|
||||||
|
// peerAddr is the structure of a peer list element.
|
||||||
|
// It is also a valid net.Addr.
|
||||||
|
type peerAddr struct { |
||||||
|
IP net.IP |
||||||
|
Port uint64 |
||||||
|
Pubkey []byte // optional
|
||||||
|
} |
||||||
|
|
||||||
|
func newPeerAddr(addr net.Addr, pubkey []byte) *peerAddr { |
||||||
|
n := addr.Network() |
||||||
|
if n != "tcp" && n != "tcp4" && n != "tcp6" { |
||||||
|
// for testing with non-TCP
|
||||||
|
return &peerAddr{net.ParseIP("127.0.0.1"), 30303, pubkey} |
||||||
|
} |
||||||
|
ta := addr.(*net.TCPAddr) |
||||||
|
return &peerAddr{ta.IP, uint64(ta.Port), pubkey} |
||||||
|
} |
||||||
|
|
||||||
|
func (d peerAddr) Network() string { |
||||||
|
if d.IP.To4() != nil { |
||||||
|
return "tcp4" |
||||||
|
} else { |
||||||
|
return "tcp6" |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
func (d peerAddr) String() string { |
||||||
|
return fmt.Sprintf("%v:%d", d.IP, d.Port) |
||||||
|
} |
||||||
|
|
||||||
|
func (d peerAddr) RlpData() interface{} { |
||||||
|
return []interface{}{d.IP, d.Port, d.Pubkey} |
||||||
|
} |
||||||
|
|
||||||
|
// Peer represents a remote peer.
|
||||||
type Peer struct { |
type Peer struct { |
||||||
// quit chan chan bool
|
// Peers have all the log methods.
|
||||||
Inbound bool // inbound (via listener) or outbound (via dialout)
|
// Use them to display messages related to the peer.
|
||||||
Address net.Addr |
*logger.Logger |
||||||
Host []byte |
|
||||||
Port uint16 |
infolock sync.Mutex |
||||||
Pubkey []byte |
identity ClientIdentity |
||||||
Id string |
caps []Cap |
||||||
Caps []string |
listenAddr *peerAddr // what remote peer is listening on
|
||||||
peerErrorChan chan *PeerError |
dialAddr *peerAddr // non-nil if dialing
|
||||||
messenger *Messenger |
|
||||||
peerErrorHandler *PeerErrorHandler |
// The mutex protects the connection
|
||||||
server *Server |
// so only one protocol can write at a time.
|
||||||
} |
writeMu sync.Mutex |
||||||
|
conn net.Conn |
||||||
func (self *Peer) Messenger() *Messenger { |
bufconn *bufio.ReadWriter |
||||||
return self.messenger |
|
||||||
} |
// These fields maintain the running protocols.
|
||||||
|
protocols []Protocol |
||||||
func (self *Peer) PeerErrorChan() chan *PeerError { |
runBaseProtocol bool // for testing
|
||||||
return self.peerErrorChan |
|
||||||
} |
runlock sync.RWMutex // protects running
|
||||||
|
running map[string]*proto |
||||||
func (self *Peer) Server() *Server { |
|
||||||
return self.server |
protoWG sync.WaitGroup |
||||||
} |
protoErr chan error |
||||||
|
closed chan struct{} |
||||||
func NewPeer(conn net.Conn, address net.Addr, inbound bool, server *Server) *Peer { |
disc chan DiscReason |
||||||
peerErrorChan := NewPeerErrorChannel() |
|
||||||
host, port, _ := net.SplitHostPort(address.String()) |
activity event.TypeMux // for activity events
|
||||||
intport, _ := strconv.Atoi(port) |
|
||||||
peer := &Peer{ |
slot int // index into Server peer list
|
||||||
Inbound: inbound, |
|
||||||
Address: address, |
// These fields are kept so base protocol can access them.
|
||||||
Port: uint16(intport), |
// TODO: this should be one or more interfaces
|
||||||
Host: net.ParseIP(host), |
ourID ClientIdentity // client id of the Server
|
||||||
peerErrorChan: peerErrorChan, |
ourListenAddr *peerAddr // listen addr of Server, nil if not listening
|
||||||
server: server, |
newPeerAddr chan<- *peerAddr // tell server about received peers
|
||||||
} |
otherPeers func() []*Peer // should return the list of all peers
|
||||||
connection := NewConnection(conn, peerErrorChan) |
pubkeyHook func(*peerAddr) error // called at end of handshake to validate pubkey
|
||||||
peer.messenger = NewMessenger(peer, connection, peerErrorChan, server.Handlers()) |
} |
||||||
peer.peerErrorHandler = NewPeerErrorHandler(address, server.PeerDisconnect(), peerErrorChan, server.Blacklist()) |
|
||||||
|
// NewPeer returns a peer for testing purposes.
|
||||||
|
func NewPeer(id ClientIdentity, caps []Cap) *Peer { |
||||||
|
conn, _ := net.Pipe() |
||||||
|
peer := newPeer(conn, nil, nil) |
||||||
|
peer.setHandshakeInfo(id, nil, caps) |
||||||
|
close(peer.closed) |
||||||
return peer |
return peer |
||||||
} |
} |
||||||
|
|
||||||
func (self *Peer) String() string { |
func newServerPeer(server *Server, conn net.Conn, dialAddr *peerAddr) *Peer { |
||||||
var kind string |
p := newPeer(conn, server.Protocols, dialAddr) |
||||||
if self.Inbound { |
p.ourID = server.Identity |
||||||
kind = "inbound" |
p.newPeerAddr = server.peerConnect |
||||||
} else { |
p.otherPeers = server.Peers |
||||||
|
p.pubkeyHook = server.verifyPeer |
||||||
|
p.runBaseProtocol = true |
||||||
|
|
||||||
|
// laddr can be updated concurrently by NAT traversal.
|
||||||
|
// newServerPeer must be called with the server lock held.
|
||||||
|
if server.laddr != nil { |
||||||
|
p.ourListenAddr = newPeerAddr(server.laddr, server.Identity.Pubkey()) |
||||||
|
} |
||||||
|
return p |
||||||
|
} |
||||||
|
|
||||||
|
func newPeer(conn net.Conn, protocols []Protocol, dialAddr *peerAddr) *Peer { |
||||||
|
p := &Peer{ |
||||||
|
Logger: logger.NewLogger("P2P " + conn.RemoteAddr().String()), |
||||||
|
conn: conn, |
||||||
|
dialAddr: dialAddr, |
||||||
|
bufconn: bufio.NewReadWriter(bufio.NewReader(conn), bufio.NewWriter(conn)), |
||||||
|
protocols: protocols, |
||||||
|
running: make(map[string]*proto), |
||||||
|
disc: make(chan DiscReason), |
||||||
|
protoErr: make(chan error), |
||||||
|
closed: make(chan struct{}), |
||||||
|
} |
||||||
|
return p |
||||||
|
} |
||||||
|
|
||||||
|
// Identity returns the client identity of the remote peer. The
|
||||||
|
// identity can be nil if the peer has not yet completed the
|
||||||
|
// handshake.
|
||||||
|
func (p *Peer) Identity() ClientIdentity { |
||||||
|
p.infolock.Lock() |
||||||
|
defer p.infolock.Unlock() |
||||||
|
return p.identity |
||||||
|
} |
||||||
|
|
||||||
|
// Caps returns the capabilities (supported subprotocols) of the remote peer.
|
||||||
|
func (p *Peer) Caps() []Cap { |
||||||
|
p.infolock.Lock() |
||||||
|
defer p.infolock.Unlock() |
||||||
|
return p.caps |
||||||
|
} |
||||||
|
|
||||||
|
func (p *Peer) setHandshakeInfo(id ClientIdentity, laddr *peerAddr, caps []Cap) { |
||||||
|
p.infolock.Lock() |
||||||
|
p.identity = id |
||||||
|
p.listenAddr = laddr |
||||||
|
p.caps = caps |
||||||
|
p.infolock.Unlock() |
||||||
|
} |
||||||
|
|
||||||
|
// RemoteAddr returns the remote address of the network connection.
|
||||||
|
func (p *Peer) RemoteAddr() net.Addr { |
||||||
|
return p.conn.RemoteAddr() |
||||||
|
} |
||||||
|
|
||||||
|
// LocalAddr returns the local address of the network connection.
|
||||||
|
func (p *Peer) LocalAddr() net.Addr { |
||||||
|
return p.conn.LocalAddr() |
||||||
|
} |
||||||
|
|
||||||
|
// Disconnect terminates the peer connection with the given reason.
|
||||||
|
// It returns immediately and does not wait until the connection is closed.
|
||||||
|
func (p *Peer) Disconnect(reason DiscReason) { |
||||||
|
select { |
||||||
|
case p.disc <- reason: |
||||||
|
case <-p.closed: |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
// String implements fmt.Stringer.
|
||||||
|
func (p *Peer) String() string { |
||||||
|
kind := "inbound" |
||||||
|
p.infolock.Lock() |
||||||
|
if p.dialAddr != nil { |
||||||
kind = "outbound" |
kind = "outbound" |
||||||
} |
} |
||||||
return fmt.Sprintf("%v:%v (%s) v%v %v", self.Host, self.Port, kind, self.Id, self.Caps) |
p.infolock.Unlock() |
||||||
|
return fmt.Sprintf("Peer(%p %v %s)", p, p.conn.RemoteAddr(), kind) |
||||||
|
} |
||||||
|
|
||||||
|
const ( |
||||||
|
// maximum amount of time allowed for reading a message
|
||||||
|
msgReadTimeout = 5 * time.Second |
||||||
|
// maximum amount of time allowed for writing a message
|
||||||
|
msgWriteTimeout = 5 * time.Second |
||||||
|
// messages smaller than this many bytes will be read at
|
||||||
|
// once before passing them to a protocol.
|
||||||
|
wholePayloadSize = 64 * 1024 |
||||||
|
) |
||||||
|
|
||||||
|
var ( |
||||||
|
inactivityTimeout = 2 * time.Second |
||||||
|
disconnectGracePeriod = 2 * time.Second |
||||||
|
) |
||||||
|
|
||||||
|
func (p *Peer) loop() (reason DiscReason, err error) { |
||||||
|
defer p.activity.Stop() |
||||||
|
defer p.closeProtocols() |
||||||
|
defer close(p.closed) |
||||||
|
defer p.conn.Close() |
||||||
|
|
||||||
|
// read loop
|
||||||
|
readMsg := make(chan Msg) |
||||||
|
readErr := make(chan error) |
||||||
|
readNext := make(chan bool, 1) |
||||||
|
protoDone := make(chan struct{}, 1) |
||||||
|
go p.readLoop(readMsg, readErr, readNext) |
||||||
|
readNext <- true |
||||||
|
|
||||||
|
if p.runBaseProtocol { |
||||||
|
p.startBaseProtocol() |
||||||
|
} |
||||||
|
|
||||||
|
loop: |
||||||
|
for { |
||||||
|
select { |
||||||
|
case msg := <-readMsg: |
||||||
|
// a new message has arrived.
|
||||||
|
var wait bool |
||||||
|
if wait, err = p.dispatch(msg, protoDone); err != nil { |
||||||
|
p.Errorf("msg dispatch error: %v\n", err) |
||||||
|
reason = discReasonForError(err) |
||||||
|
break loop |
||||||
|
} |
||||||
|
if !wait { |
||||||
|
// Msg has already been read completely, continue with next message.
|
||||||
|
readNext <- true |
||||||
|
} |
||||||
|
p.activity.Post(time.Now()) |
||||||
|
case <-protoDone: |
||||||
|
// protocol has consumed the message payload,
|
||||||
|
// we can continue reading from the socket.
|
||||||
|
readNext <- true |
||||||
|
|
||||||
|
case err := <-readErr: |
||||||
|
// read failed. there is no need to run the
|
||||||
|
// polite disconnect sequence because the connection
|
||||||
|
// is probably dead anyway.
|
||||||
|
// TODO: handle write errors as well
|
||||||
|
return DiscNetworkError, err |
||||||
|
case err = <-p.protoErr: |
||||||
|
reason = discReasonForError(err) |
||||||
|
break loop |
||||||
|
case reason = <-p.disc: |
||||||
|
break loop |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
// wait for read loop to return.
|
||||||
|
close(readNext) |
||||||
|
<-readErr |
||||||
|
// tell the remote end to disconnect
|
||||||
|
done := make(chan struct{}) |
||||||
|
go func() { |
||||||
|
p.conn.SetDeadline(time.Now().Add(disconnectGracePeriod)) |
||||||
|
p.writeMsg(NewMsg(discMsg, reason), disconnectGracePeriod) |
||||||
|
io.Copy(ioutil.Discard, p.conn) |
||||||
|
close(done) |
||||||
|
}() |
||||||
|
select { |
||||||
|
case <-done: |
||||||
|
case <-time.After(disconnectGracePeriod): |
||||||
|
} |
||||||
|
return reason, err |
||||||
|
} |
||||||
|
|
||||||
|
func (p *Peer) readLoop(msgc chan<- Msg, errc chan<- error, unblock <-chan bool) { |
||||||
|
for _ = range unblock { |
||||||
|
p.conn.SetReadDeadline(time.Now().Add(msgReadTimeout)) |
||||||
|
if msg, err := readMsg(p.bufconn); err != nil { |
||||||
|
errc <- err |
||||||
|
} else { |
||||||
|
msgc <- msg |
||||||
|
} |
||||||
|
} |
||||||
|
close(errc) |
||||||
|
} |
||||||
|
|
||||||
|
func (p *Peer) dispatch(msg Msg, protoDone chan struct{}) (wait bool, err error) { |
||||||
|
proto, err := p.getProto(msg.Code) |
||||||
|
if err != nil { |
||||||
|
return false, err |
||||||
|
} |
||||||
|
if msg.Size <= wholePayloadSize { |
||||||
|
// optimization: msg is small enough, read all
|
||||||
|
// of it and move on to the next message
|
||||||
|
buf, err := ioutil.ReadAll(msg.Payload) |
||||||
|
if err != nil { |
||||||
|
return false, err |
||||||
|
} |
||||||
|
msg.Payload = bytes.NewReader(buf) |
||||||
|
proto.in <- msg |
||||||
|
} else { |
||||||
|
wait = true |
||||||
|
pr := &eofSignal{msg.Payload, protoDone} |
||||||
|
msg.Payload = pr |
||||||
|
proto.in <- msg |
||||||
|
} |
||||||
|
return wait, nil |
||||||
|
} |
||||||
|
|
||||||
|
func (p *Peer) startBaseProtocol() { |
||||||
|
p.runlock.Lock() |
||||||
|
defer p.runlock.Unlock() |
||||||
|
p.running[""] = p.startProto(0, Protocol{ |
||||||
|
Length: baseProtocolLength, |
||||||
|
Run: runBaseProtocol, |
||||||
|
}) |
||||||
|
} |
||||||
|
|
||||||
|
// startProtocols starts matching named subprotocols.
|
||||||
|
func (p *Peer) startSubprotocols(caps []Cap) { |
||||||
|
sort.Sort(capsByName(caps)) |
||||||
|
|
||||||
|
p.runlock.Lock() |
||||||
|
defer p.runlock.Unlock() |
||||||
|
offset := baseProtocolLength |
||||||
|
outer: |
||||||
|
for _, cap := range caps { |
||||||
|
for _, proto := range p.protocols { |
||||||
|
if proto.Name == cap.Name && |
||||||
|
proto.Version == cap.Version && |
||||||
|
p.running[cap.Name] == nil { |
||||||
|
p.running[cap.Name] = p.startProto(offset, proto) |
||||||
|
offset += proto.Length |
||||||
|
continue outer |
||||||
|
} |
||||||
|
} |
||||||
|
} |
||||||
} |
} |
||||||
|
|
||||||
func (self *Peer) Write(protocol string, msg *Msg) error { |
func (p *Peer) startProto(offset uint64, impl Protocol) *proto { |
||||||
return self.messenger.Write(protocol, msg) |
rw := &proto{ |
||||||
|
in: make(chan Msg), |
||||||
|
offset: offset, |
||||||
|
maxcode: impl.Length, |
||||||
|
peer: p, |
||||||
|
} |
||||||
|
p.protoWG.Add(1) |
||||||
|
go func() { |
||||||
|
err := impl.Run(p, rw) |
||||||
|
if err == nil { |
||||||
|
p.Infof("protocol %q returned", impl.Name) |
||||||
|
err = newPeerError(errMisc, "protocol returned") |
||||||
|
} else { |
||||||
|
p.Errorf("protocol %q error: %v\n", impl.Name, err) |
||||||
|
} |
||||||
|
select { |
||||||
|
case p.protoErr <- err: |
||||||
|
case <-p.closed: |
||||||
|
} |
||||||
|
p.protoWG.Done() |
||||||
|
}() |
||||||
|
return rw |
||||||
|
} |
||||||
|
|
||||||
|
// getProto finds the protocol responsible for handling
|
||||||
|
// the given message code.
|
||||||
|
func (p *Peer) getProto(code uint64) (*proto, error) { |
||||||
|
p.runlock.RLock() |
||||||
|
defer p.runlock.RUnlock() |
||||||
|
for _, proto := range p.running { |
||||||
|
if code >= proto.offset && code < proto.offset+proto.maxcode { |
||||||
|
return proto, nil |
||||||
|
} |
||||||
|
} |
||||||
|
return nil, newPeerError(errInvalidMsgCode, "%d", code) |
||||||
} |
} |
||||||
|
|
||||||
func (self *Peer) Start() { |
func (p *Peer) closeProtocols() { |
||||||
self.peerErrorHandler.Start() |
p.runlock.RLock() |
||||||
self.messenger.Start() |
for _, p := range p.running { |
||||||
|
close(p.in) |
||||||
|
} |
||||||
|
p.runlock.RUnlock() |
||||||
|
p.protoWG.Wait() |
||||||
} |
} |
||||||
|
|
||||||
func (self *Peer) Stop() { |
// writeProtoMsg sends the given message on behalf of the given named protocol.
|
||||||
self.peerErrorHandler.Stop() |
func (p *Peer) writeProtoMsg(protoName string, msg Msg) error { |
||||||
self.messenger.Stop() |
p.runlock.RLock() |
||||||
// q := make(chan bool)
|
proto, ok := p.running[protoName] |
||||||
// self.quit <- q
|
p.runlock.RUnlock() |
||||||
// <-q
|
if !ok { |
||||||
|
return fmt.Errorf("protocol %s not handled by peer", protoName) |
||||||
|
} |
||||||
|
if msg.Code >= proto.maxcode { |
||||||
|
return newPeerError(errInvalidMsgCode, "code %x is out of range for protocol %q", msg.Code, protoName) |
||||||
|
} |
||||||
|
msg.Code += proto.offset |
||||||
|
return p.writeMsg(msg, msgWriteTimeout) |
||||||
|
} |
||||||
|
|
||||||
|
// writeMsg writes a message to the connection.
|
||||||
|
func (p *Peer) writeMsg(msg Msg, timeout time.Duration) error { |
||||||
|
p.writeMu.Lock() |
||||||
|
defer p.writeMu.Unlock() |
||||||
|
p.conn.SetWriteDeadline(time.Now().Add(timeout)) |
||||||
|
if err := writeMsg(p.bufconn, msg); err != nil { |
||||||
|
return newPeerError(errWrite, "%v", err) |
||||||
|
} |
||||||
|
return p.bufconn.Flush() |
||||||
} |
} |
||||||
|
|
||||||
func (p *Peer) Encode() []interface{} { |
type proto struct { |
||||||
return []interface{}{p.Host, p.Port, p.Pubkey} |
name string |
||||||
|
in chan Msg |
||||||
|
maxcode, offset uint64 |
||||||
|
peer *Peer |
||||||
|
} |
||||||
|
|
||||||
|
func (rw *proto) WriteMsg(msg Msg) error { |
||||||
|
if msg.Code >= rw.maxcode { |
||||||
|
return newPeerError(errInvalidMsgCode, "not handled") |
||||||
|
} |
||||||
|
msg.Code += rw.offset |
||||||
|
return rw.peer.writeMsg(msg, msgWriteTimeout) |
||||||
|
} |
||||||
|
|
||||||
|
func (rw *proto) EncodeMsg(code uint64, data ...interface{}) error { |
||||||
|
return rw.WriteMsg(NewMsg(code, data)) |
||||||
|
} |
||||||
|
|
||||||
|
func (rw *proto) ReadMsg() (Msg, error) { |
||||||
|
msg, ok := <-rw.in |
||||||
|
if !ok { |
||||||
|
return msg, io.EOF |
||||||
|
} |
||||||
|
msg.Code -= rw.offset |
||||||
|
return msg, nil |
||||||
|
} |
||||||
|
|
||||||
|
// eofSignal wraps a reader with eof signaling.
|
||||||
|
// the eof channel is closed when the wrapped reader
|
||||||
|
// reaches EOF.
|
||||||
|
type eofSignal struct { |
||||||
|
wrapped io.Reader |
||||||
|
eof chan<- struct{} |
||||||
|
} |
||||||
|
|
||||||
|
func (r *eofSignal) Read(buf []byte) (int, error) { |
||||||
|
n, err := r.wrapped.Read(buf) |
||||||
|
if err != nil { |
||||||
|
r.eof <- struct{}{} // tell Peer that msg has been consumed
|
||||||
|
} |
||||||
|
return n, err |
||||||
} |
} |
||||||
|
@ -1,101 +0,0 @@ |
|||||||
package p2p |
|
||||||
|
|
||||||
import ( |
|
||||||
"net" |
|
||||||
) |
|
||||||
|
|
||||||
const ( |
|
||||||
severityThreshold = 10 |
|
||||||
) |
|
||||||
|
|
||||||
type DisconnectRequest struct { |
|
||||||
addr net.Addr |
|
||||||
reason DiscReason |
|
||||||
} |
|
||||||
|
|
||||||
type PeerErrorHandler struct { |
|
||||||
quit chan chan bool |
|
||||||
address net.Addr |
|
||||||
peerDisconnect chan DisconnectRequest |
|
||||||
severity int |
|
||||||
peerErrorChan chan *PeerError |
|
||||||
blacklist Blacklist |
|
||||||
} |
|
||||||
|
|
||||||
func NewPeerErrorHandler(address net.Addr, peerDisconnect chan DisconnectRequest, peerErrorChan chan *PeerError, blacklist Blacklist) *PeerErrorHandler { |
|
||||||
return &PeerErrorHandler{ |
|
||||||
quit: make(chan chan bool), |
|
||||||
address: address, |
|
||||||
peerDisconnect: peerDisconnect, |
|
||||||
peerErrorChan: peerErrorChan, |
|
||||||
blacklist: blacklist, |
|
||||||
} |
|
||||||
} |
|
||||||
|
|
||||||
func (self *PeerErrorHandler) Start() { |
|
||||||
go self.listen() |
|
||||||
} |
|
||||||
|
|
||||||
func (self *PeerErrorHandler) Stop() { |
|
||||||
q := make(chan bool) |
|
||||||
self.quit <- q |
|
||||||
<-q |
|
||||||
} |
|
||||||
|
|
||||||
func (self *PeerErrorHandler) listen() { |
|
||||||
for { |
|
||||||
select { |
|
||||||
case peerError, ok := <-self.peerErrorChan: |
|
||||||
if ok { |
|
||||||
logger.Debugf("error %v\n", peerError) |
|
||||||
go self.handle(peerError) |
|
||||||
} else { |
|
||||||
return |
|
||||||
} |
|
||||||
case q := <-self.quit: |
|
||||||
q <- true |
|
||||||
return |
|
||||||
} |
|
||||||
} |
|
||||||
} |
|
||||||
|
|
||||||
func (self *PeerErrorHandler) handle(peerError *PeerError) { |
|
||||||
reason := DiscReason(' ') |
|
||||||
switch peerError.Code { |
|
||||||
case P2PVersionMismatch: |
|
||||||
reason = DiscIncompatibleVersion |
|
||||||
case PubkeyMissing, PubkeyInvalid: |
|
||||||
reason = DiscInvalidIdentity |
|
||||||
case PubkeyForbidden: |
|
||||||
reason = DiscUselessPeer |
|
||||||
case InvalidMsgCode, PacketTooShort, PayloadTooShort, MagicTokenMismatch, EmptyPayload, ProtocolBreach: |
|
||||||
reason = DiscProtocolError |
|
||||||
case PingTimeout: |
|
||||||
reason = DiscReadTimeout |
|
||||||
case WriteError, MiscError: |
|
||||||
reason = DiscNetworkError |
|
||||||
case InvalidGenesis, InvalidNetworkId, InvalidProtocolVersion: |
|
||||||
reason = DiscSubprotocolError |
|
||||||
default: |
|
||||||
self.severity += self.getSeverity(peerError) |
|
||||||
} |
|
||||||
|
|
||||||
if self.severity >= severityThreshold { |
|
||||||
reason = DiscSubprotocolError |
|
||||||
} |
|
||||||
if reason != DiscReason(' ') { |
|
||||||
self.peerDisconnect <- DisconnectRequest{ |
|
||||||
addr: self.address, |
|
||||||
reason: reason, |
|
||||||
} |
|
||||||
} |
|
||||||
} |
|
||||||
|
|
||||||
func (self *PeerErrorHandler) getSeverity(peerError *PeerError) int { |
|
||||||
switch peerError.Code { |
|
||||||
case ReadError: |
|
||||||
return 4 //tolerate 3 :)
|
|
||||||
default: |
|
||||||
return 1 |
|
||||||
} |
|
||||||
} |
|
@ -1,34 +0,0 @@ |
|||||||
package p2p |
|
||||||
|
|
||||||
import ( |
|
||||||
// "fmt"
|
|
||||||
"net" |
|
||||||
"testing" |
|
||||||
"time" |
|
||||||
) |
|
||||||
|
|
||||||
func TestPeerErrorHandler(t *testing.T) { |
|
||||||
address := &net.TCPAddr{IP: net.IP([]byte{1, 2, 3, 4}), Port: 30303} |
|
||||||
peerDisconnect := make(chan DisconnectRequest) |
|
||||||
peerErrorChan := NewPeerErrorChannel() |
|
||||||
peh := NewPeerErrorHandler(address, peerDisconnect, peerErrorChan, NewBlacklist()) |
|
||||||
peh.Start() |
|
||||||
defer peh.Stop() |
|
||||||
for i := 0; i < 11; i++ { |
|
||||||
select { |
|
||||||
case <-peerDisconnect: |
|
||||||
t.Errorf("expected no disconnect request") |
|
||||||
default: |
|
||||||
} |
|
||||||
peerErrorChan <- NewPeerError(MiscError, "") |
|
||||||
} |
|
||||||
time.Sleep(1 * time.Millisecond) |
|
||||||
select { |
|
||||||
case request := <-peerDisconnect: |
|
||||||
if request.addr.String() != address.String() { |
|
||||||
t.Errorf("incorrect address %v != %v", request.addr, address) |
|
||||||
} |
|
||||||
default: |
|
||||||
t.Errorf("expected disconnect request") |
|
||||||
} |
|
||||||
} |
|
@ -1,96 +1,239 @@ |
|||||||
package p2p |
package p2p |
||||||
|
|
||||||
import ( |
import ( |
||||||
|
"bufio" |
||||||
"bytes" |
"bytes" |
||||||
"fmt" |
"encoding/hex" |
||||||
// "net"
|
"io/ioutil" |
||||||
|
"net" |
||||||
|
"reflect" |
||||||
"testing" |
"testing" |
||||||
"time" |
"time" |
||||||
) |
) |
||||||
|
|
||||||
func TestPeer(t *testing.T) { |
var discard = Protocol{ |
||||||
handlers := make(Handlers) |
Name: "discard", |
||||||
testProtocol := &TestProtocol{Msgs: []*Msg{}} |
Length: 1, |
||||||
handlers["aaa"] = func(p *Peer) Protocol { return testProtocol } |
Run: func(p *Peer, rw MsgReadWriter) error { |
||||||
handlers["ccc"] = func(p *Peer) Protocol { return testProtocol } |
for { |
||||||
addr := &TestAddr{"test:30"} |
msg, err := rw.ReadMsg() |
||||||
conn := NewTestNetworkConnection(addr) |
|
||||||
_, server := SetupTestServer(handlers) |
|
||||||
server.Handshake() |
|
||||||
peer := NewPeer(conn, addr, true, server) |
|
||||||
// peer.Messenger().AddProtocols([]string{"aaa", "ccc"})
|
|
||||||
peer.Start() |
|
||||||
defer peer.Stop() |
|
||||||
time.Sleep(2 * time.Millisecond) |
|
||||||
if len(conn.Out) != 1 { |
|
||||||
t.Errorf("handshake not sent") |
|
||||||
} else { |
|
||||||
out := conn.Out[0] |
|
||||||
packet := Packet(0, HandshakeMsg, P2PVersion, []byte(peer.server.identity.String()), []interface{}{peer.server.protocols}, peer.server.port, peer.server.identity.Pubkey()[1:]) |
|
||||||
if bytes.Compare(out, packet) != 0 { |
|
||||||
t.Errorf("incorrect handshake packet %v != %v", out, packet) |
|
||||||
} |
|
||||||
} |
|
||||||
|
|
||||||
packet := Packet(0, HandshakeMsg, P2PVersion, []byte("peer"), []interface{}{"bbb", "aaa", "ccc"}, 30, []byte("0000000000000000000000000000000000000000000000000000000000000000")) |
|
||||||
conn.In(0, packet) |
|
||||||
time.Sleep(10 * time.Millisecond) |
|
||||||
|
|
||||||
pro, _ := peer.Messenger().protocols[0].(*BaseProtocol) |
|
||||||
if pro.state != handshakeReceived { |
|
||||||
t.Errorf("handshake not received") |
|
||||||
} |
|
||||||
if peer.Port != 30 { |
|
||||||
t.Errorf("port incorrectly set") |
|
||||||
} |
|
||||||
if peer.Id != "peer" { |
|
||||||
t.Errorf("id incorrectly set") |
|
||||||
} |
|
||||||
if string(peer.Pubkey) != "0000000000000000000000000000000000000000000000000000000000000000" { |
|
||||||
t.Errorf("pubkey incorrectly set") |
|
||||||
} |
|
||||||
fmt.Println(peer.Caps) |
|
||||||
if len(peer.Caps) != 3 || peer.Caps[0] != "aaa" || peer.Caps[1] != "bbb" || peer.Caps[2] != "ccc" { |
|
||||||
t.Errorf("protocols incorrectly set") |
|
||||||
} |
|
||||||
|
|
||||||
msg, _ := NewMsg(3) |
|
||||||
err := peer.Write("aaa", msg) |
|
||||||
if err != nil { |
if err != nil { |
||||||
t.Errorf("expect no error for known protocol: %v", err) |
return err |
||||||
} else { |
} |
||||||
time.Sleep(1 * time.Millisecond) |
if err = msg.Discard(); err != nil { |
||||||
if len(conn.Out) != 2 { |
return err |
||||||
t.Errorf("msg not written") |
} |
||||||
} else { |
} |
||||||
out := conn.Out[1] |
}, |
||||||
packet := Packet(16, 3) |
} |
||||||
if bytes.Compare(out, packet) != 0 { |
|
||||||
t.Errorf("incorrect packet %v != %v", out, packet) |
func testPeer(protos []Protocol) (net.Conn, *Peer, <-chan error) { |
||||||
|
conn1, conn2 := net.Pipe() |
||||||
|
id := NewSimpleClientIdentity("test", "0", "0", "public key") |
||||||
|
peer := newPeer(conn1, protos, nil) |
||||||
|
peer.ourID = id |
||||||
|
peer.pubkeyHook = func(*peerAddr) error { return nil } |
||||||
|
errc := make(chan error, 1) |
||||||
|
go func() { |
||||||
|
_, err := peer.loop() |
||||||
|
errc <- err |
||||||
|
}() |
||||||
|
return conn2, peer, errc |
||||||
|
} |
||||||
|
|
||||||
|
func TestPeerProtoReadMsg(t *testing.T) { |
||||||
|
defer testlog(t).detach() |
||||||
|
|
||||||
|
done := make(chan struct{}) |
||||||
|
proto := Protocol{ |
||||||
|
Name: "a", |
||||||
|
Length: 5, |
||||||
|
Run: func(peer *Peer, rw MsgReadWriter) error { |
||||||
|
msg, err := rw.ReadMsg() |
||||||
|
if err != nil { |
||||||
|
t.Errorf("read error: %v", err) |
||||||
|
} |
||||||
|
if msg.Code != 2 { |
||||||
|
t.Errorf("incorrect msg code %d relayed to protocol", msg.Code) |
||||||
|
} |
||||||
|
data, err := ioutil.ReadAll(msg.Payload) |
||||||
|
if err != nil { |
||||||
|
t.Errorf("payload read error: %v", err) |
||||||
|
} |
||||||
|
expdata, _ := hex.DecodeString("0183303030") |
||||||
|
if !bytes.Equal(expdata, data) { |
||||||
|
t.Errorf("incorrect msg data %x", data) |
||||||
|
} |
||||||
|
close(done) |
||||||
|
return nil |
||||||
|
}, |
||||||
|
} |
||||||
|
|
||||||
|
net, peer, errc := testPeer([]Protocol{proto}) |
||||||
|
defer net.Close() |
||||||
|
peer.startSubprotocols([]Cap{proto.cap()}) |
||||||
|
|
||||||
|
writeMsg(net, NewMsg(18, 1, "000")) |
||||||
|
select { |
||||||
|
case <-done: |
||||||
|
case err := <-errc: |
||||||
|
t.Errorf("peer returned: %v", err) |
||||||
|
case <-time.After(2 * time.Second): |
||||||
|
t.Errorf("receive timeout") |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
func TestPeerProtoReadLargeMsg(t *testing.T) { |
||||||
|
defer testlog(t).detach() |
||||||
|
|
||||||
|
msgsize := uint32(10 * 1024 * 1024) |
||||||
|
done := make(chan struct{}) |
||||||
|
proto := Protocol{ |
||||||
|
Name: "a", |
||||||
|
Length: 5, |
||||||
|
Run: func(peer *Peer, rw MsgReadWriter) error { |
||||||
|
msg, err := rw.ReadMsg() |
||||||
|
if err != nil { |
||||||
|
t.Errorf("read error: %v", err) |
||||||
|
} |
||||||
|
if msg.Size != msgsize+4 { |
||||||
|
t.Errorf("incorrect msg.Size, got %d, expected %d", msg.Size, msgsize) |
||||||
|
} |
||||||
|
msg.Discard() |
||||||
|
close(done) |
||||||
|
return nil |
||||||
|
}, |
||||||
|
} |
||||||
|
|
||||||
|
net, peer, errc := testPeer([]Protocol{proto}) |
||||||
|
defer net.Close() |
||||||
|
peer.startSubprotocols([]Cap{proto.cap()}) |
||||||
|
|
||||||
|
writeMsg(net, NewMsg(18, make([]byte, msgsize))) |
||||||
|
select { |
||||||
|
case <-done: |
||||||
|
case err := <-errc: |
||||||
|
t.Errorf("peer returned: %v", err) |
||||||
|
case <-time.After(2 * time.Second): |
||||||
|
t.Errorf("receive timeout") |
||||||
|
} |
||||||
} |
} |
||||||
|
|
||||||
|
func TestPeerProtoEncodeMsg(t *testing.T) { |
||||||
|
defer testlog(t).detach() |
||||||
|
|
||||||
|
proto := Protocol{ |
||||||
|
Name: "a", |
||||||
|
Length: 2, |
||||||
|
Run: func(peer *Peer, rw MsgReadWriter) error { |
||||||
|
if err := rw.EncodeMsg(2); err == nil { |
||||||
|
t.Error("expected error for out-of-range msg code, got nil") |
||||||
} |
} |
||||||
|
if err := rw.EncodeMsg(1); err != nil { |
||||||
|
t.Errorf("write error: %v", err) |
||||||
} |
} |
||||||
|
return nil |
||||||
|
}, |
||||||
|
} |
||||||
|
net, peer, _ := testPeer([]Protocol{proto}) |
||||||
|
defer net.Close() |
||||||
|
peer.startSubprotocols([]Cap{proto.cap()}) |
||||||
|
|
||||||
msg, _ = NewMsg(2) |
bufr := bufio.NewReader(net) |
||||||
err = peer.Write("ccc", msg) |
msg, err := readMsg(bufr) |
||||||
if err != nil { |
if err != nil { |
||||||
|
t.Errorf("read error: %v", err) |
||||||
|
} |
||||||
|
if msg.Code != 17 { |
||||||
|
t.Errorf("incorrect message code: got %d, expected %d", msg.Code, 17) |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
func TestPeerWrite(t *testing.T) { |
||||||
|
defer testlog(t).detach() |
||||||
|
|
||||||
|
net, peer, peerErr := testPeer([]Protocol{discard}) |
||||||
|
defer net.Close() |
||||||
|
peer.startSubprotocols([]Cap{discard.cap()}) |
||||||
|
|
||||||
|
// test write errors
|
||||||
|
if err := peer.writeProtoMsg("b", NewMsg(3)); err == nil { |
||||||
|
t.Errorf("expected error for unknown protocol, got nil") |
||||||
|
} |
||||||
|
if err := peer.writeProtoMsg("discard", NewMsg(8)); err == nil { |
||||||
|
t.Errorf("expected error for out-of-range msg code, got nil") |
||||||
|
} else if perr, ok := err.(*peerError); !ok || perr.Code != errInvalidMsgCode { |
||||||
|
t.Errorf("wrong error for out-of-range msg code, got %#v", err) |
||||||
|
} |
||||||
|
|
||||||
|
// setup for reading the message on the other end
|
||||||
|
read := make(chan struct{}) |
||||||
|
go func() { |
||||||
|
bufr := bufio.NewReader(net) |
||||||
|
msg, err := readMsg(bufr) |
||||||
|
if err != nil { |
||||||
|
t.Errorf("read error: %v", err) |
||||||
|
} else if msg.Code != 16 { |
||||||
|
t.Errorf("wrong code, got %d, expected %d", msg.Code, 16) |
||||||
|
} |
||||||
|
msg.Discard() |
||||||
|
close(read) |
||||||
|
}() |
||||||
|
|
||||||
|
// test succcessful write
|
||||||
|
if err := peer.writeProtoMsg("discard", NewMsg(0)); err != nil { |
||||||
t.Errorf("expect no error for known protocol: %v", err) |
t.Errorf("expect no error for known protocol: %v", err) |
||||||
} else { |
|
||||||
time.Sleep(1 * time.Millisecond) |
|
||||||
if len(conn.Out) != 3 { |
|
||||||
t.Errorf("msg not written") |
|
||||||
} else { |
|
||||||
out := conn.Out[2] |
|
||||||
packet := Packet(21, 2) |
|
||||||
if bytes.Compare(out, packet) != 0 { |
|
||||||
t.Errorf("incorrect packet %v != %v", out, packet) |
|
||||||
} |
} |
||||||
|
select { |
||||||
|
case <-read: |
||||||
|
case err := <-peerErr: |
||||||
|
t.Fatalf("peer stopped: %v", err) |
||||||
} |
} |
||||||
} |
} |
||||||
|
|
||||||
err = peer.Write("bbb", msg) |
func TestPeerActivity(t *testing.T) { |
||||||
time.Sleep(1 * time.Millisecond) |
// shorten inactivityTimeout while this test is running
|
||||||
if err == nil { |
oldT := inactivityTimeout |
||||||
t.Errorf("expect error for unknown protocol") |
defer func() { inactivityTimeout = oldT }() |
||||||
|
inactivityTimeout = 20 * time.Millisecond |
||||||
|
|
||||||
|
net, peer, peerErr := testPeer([]Protocol{discard}) |
||||||
|
defer net.Close() |
||||||
|
peer.startSubprotocols([]Cap{discard.cap()}) |
||||||
|
|
||||||
|
sub := peer.activity.Subscribe(time.Time{}) |
||||||
|
defer sub.Unsubscribe() |
||||||
|
|
||||||
|
for i := 0; i < 6; i++ { |
||||||
|
writeMsg(net, NewMsg(16)) |
||||||
|
select { |
||||||
|
case <-sub.Chan(): |
||||||
|
case <-time.After(inactivityTimeout / 2): |
||||||
|
t.Fatal("no event within ", inactivityTimeout/2) |
||||||
|
case err := <-peerErr: |
||||||
|
t.Fatal("peer error", err) |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
select { |
||||||
|
case <-time.After(inactivityTimeout * 2): |
||||||
|
case <-sub.Chan(): |
||||||
|
t.Fatal("got activity event while connection was inactive") |
||||||
|
case err := <-peerErr: |
||||||
|
t.Fatal("peer error", err) |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
func TestNewPeer(t *testing.T) { |
||||||
|
id := NewSimpleClientIdentity("clientid", "version", "customid", "pubkey") |
||||||
|
caps := []Cap{{"foo", 2}, {"bar", 3}} |
||||||
|
p := NewPeer(id, caps) |
||||||
|
if !reflect.DeepEqual(p.Caps(), caps) { |
||||||
|
t.Errorf("Caps mismatch: got %v, expected %v", p.Caps(), caps) |
||||||
|
} |
||||||
|
if p.Identity() != id { |
||||||
|
t.Errorf("Identity mismatch: got %v, expected %v", p.Identity(), id) |
||||||
} |
} |
||||||
|
// Should not hang.
|
||||||
|
p.Disconnect(DiscAlreadyConnected) |
||||||
} |
} |
||||||
|
@ -0,0 +1,28 @@ |
|||||||
|
package p2p |
||||||
|
|
||||||
|
import ( |
||||||
|
"testing" |
||||||
|
|
||||||
|
"github.com/ethereum/go-ethereum/logger" |
||||||
|
) |
||||||
|
|
||||||
|
type testLogger struct{ t *testing.T } |
||||||
|
|
||||||
|
func testlog(t *testing.T) testLogger { |
||||||
|
logger.Reset() |
||||||
|
l := testLogger{t} |
||||||
|
logger.AddLogSystem(l) |
||||||
|
return l |
||||||
|
} |
||||||
|
|
||||||
|
func (testLogger) GetLogLevel() logger.LogLevel { return logger.DebugLevel } |
||||||
|
func (testLogger) SetLogLevel(logger.LogLevel) {} |
||||||
|
|
||||||
|
func (l testLogger) LogPrint(level logger.LogLevel, msg string) { |
||||||
|
l.t.Logf("%s", msg) |
||||||
|
} |
||||||
|
|
||||||
|
func (testLogger) detach() { |
||||||
|
logger.Flush() |
||||||
|
logger.Reset() |
||||||
|
} |
@ -0,0 +1,40 @@ |
|||||||
|
// +build none
|
||||||
|
|
||||||
|
package main |
||||||
|
|
||||||
|
import ( |
||||||
|
"fmt" |
||||||
|
"log" |
||||||
|
"net" |
||||||
|
"os" |
||||||
|
|
||||||
|
"github.com/ethereum/go-ethereum/logger" |
||||||
|
"github.com/ethereum/go-ethereum/p2p" |
||||||
|
"github.com/obscuren/secp256k1-go" |
||||||
|
) |
||||||
|
|
||||||
|
func main() { |
||||||
|
logger.AddLogSystem(logger.NewStdLogSystem(os.Stdout, log.LstdFlags, logger.DebugLevel)) |
||||||
|
|
||||||
|
pub, _ := secp256k1.GenerateKeyPair() |
||||||
|
srv := p2p.Server{ |
||||||
|
MaxPeers: 10, |
||||||
|
Identity: p2p.NewSimpleClientIdentity("test", "1.0", "", string(pub)), |
||||||
|
ListenAddr: ":30303", |
||||||
|
NAT: p2p.PMP(net.ParseIP("10.0.0.1")), |
||||||
|
} |
||||||
|
if err := srv.Start(); err != nil { |
||||||
|
fmt.Println("could not start server:", err) |
||||||
|
os.Exit(1) |
||||||
|
} |
||||||
|
|
||||||
|
// add seed peers
|
||||||
|
seed, err := net.ResolveTCPAddr("tcp", "poc-7.ethdev.com:30303") |
||||||
|
if err != nil { |
||||||
|
fmt.Println("couldn't resolve:", err) |
||||||
|
os.Exit(1) |
||||||
|
} |
||||||
|
srv.SuggestPeer(seed.IP, seed.Port, nil) |
||||||
|
|
||||||
|
select {} |
||||||
|
} |
Loading…
Reference in new issue