mirror of https://github.com/ethereum/go-ethereum
parent
8cf9ed0ea5
commit
f38052c499
@ -1,275 +0,0 @@ |
||||
package p2p |
||||
|
||||
import ( |
||||
"bytes" |
||||
// "fmt"
|
||||
"net" |
||||
"time" |
||||
|
||||
"github.com/ethereum/go-ethereum/ethutil" |
||||
) |
||||
|
||||
type Connection struct { |
||||
conn net.Conn |
||||
// conn NetworkConnection
|
||||
timeout time.Duration |
||||
in chan []byte |
||||
out chan []byte |
||||
err chan *PeerError |
||||
closingIn chan chan bool |
||||
closingOut chan chan bool |
||||
} |
||||
|
||||
// const readBufferLength = 2 //for testing
|
||||
|
||||
const readBufferLength = 1440 |
||||
const partialsQueueSize = 10 |
||||
const maxPendingQueueSize = 1 |
||||
const defaultTimeout = 500 |
||||
|
||||
var magicToken = []byte{34, 64, 8, 145} |
||||
|
||||
func (self *Connection) Open() { |
||||
go self.startRead() |
||||
go self.startWrite() |
||||
} |
||||
|
||||
func (self *Connection) Close() { |
||||
self.closeIn() |
||||
self.closeOut() |
||||
} |
||||
|
||||
func (self *Connection) closeIn() { |
||||
errc := make(chan bool) |
||||
self.closingIn <- errc |
||||
<-errc |
||||
} |
||||
|
||||
func (self *Connection) closeOut() { |
||||
errc := make(chan bool) |
||||
self.closingOut <- errc |
||||
<-errc |
||||
} |
||||
|
||||
func NewConnection(conn net.Conn, errchan chan *PeerError) *Connection { |
||||
return &Connection{ |
||||
conn: conn, |
||||
timeout: defaultTimeout, |
||||
in: make(chan []byte), |
||||
out: make(chan []byte), |
||||
err: errchan, |
||||
closingIn: make(chan chan bool, 1), |
||||
closingOut: make(chan chan bool, 1), |
||||
} |
||||
} |
||||
|
||||
func (self *Connection) Read() <-chan []byte { |
||||
return self.in |
||||
} |
||||
|
||||
func (self *Connection) Write() chan<- []byte { |
||||
return self.out |
||||
} |
||||
|
||||
func (self *Connection) Error() <-chan *PeerError { |
||||
return self.err |
||||
} |
||||
|
||||
func (self *Connection) startRead() { |
||||
payloads := make(chan []byte) |
||||
done := make(chan *PeerError) |
||||
pending := [][]byte{} |
||||
var head []byte |
||||
var wait time.Duration // initally 0 (no delay)
|
||||
read := time.After(wait * time.Millisecond) |
||||
|
||||
for { |
||||
// if pending empty, nil channel blocks
|
||||
var in chan []byte |
||||
if len(pending) > 0 { |
||||
in = self.in // enable send case
|
||||
head = pending[0] |
||||
} else { |
||||
in = nil |
||||
} |
||||
|
||||
select { |
||||
case <-read: |
||||
go self.read(payloads, done) |
||||
case err := <-done: |
||||
if err == nil { // no error but nothing to read
|
||||
if len(pending) < maxPendingQueueSize { |
||||
wait = 100 |
||||
} else if wait == 0 { |
||||
wait = 100 |
||||
} else { |
||||
wait = 2 * wait |
||||
} |
||||
} else { |
||||
self.err <- err // report error
|
||||
wait = 100 |
||||
} |
||||
read = time.After(wait * time.Millisecond) |
||||
case payload := <-payloads: |
||||
pending = append(pending, payload) |
||||
if len(pending) < maxPendingQueueSize { |
||||
wait = 0 |
||||
} else { |
||||
wait = 100 |
||||
} |
||||
read = time.After(wait * time.Millisecond) |
||||
case in <- head: |
||||
pending = pending[1:] |
||||
case errc := <-self.closingIn: |
||||
errc <- true |
||||
close(self.in) |
||||
return |
||||
} |
||||
|
||||
} |
||||
} |
||||
|
||||
func (self *Connection) startWrite() { |
||||
pending := [][]byte{} |
||||
done := make(chan *PeerError) |
||||
writing := false |
||||
for { |
||||
if len(pending) > 0 && !writing { |
||||
writing = true |
||||
go self.write(pending[0], done) |
||||
} |
||||
select { |
||||
case payload := <-self.out: |
||||
pending = append(pending, payload) |
||||
case err := <-done: |
||||
if err == nil { |
||||
pending = pending[1:] |
||||
writing = false |
||||
} else { |
||||
self.err <- err // report error
|
||||
} |
||||
case errc := <-self.closingOut: |
||||
errc <- true |
||||
close(self.out) |
||||
return |
||||
} |
||||
} |
||||
} |
||||
|
||||
func pack(payload []byte) (packet []byte) { |
||||
length := ethutil.NumberToBytes(uint32(len(payload)), 32) |
||||
// return error if too long?
|
||||
// Write magic token and payload length (first 8 bytes)
|
||||
packet = append(magicToken, length...) |
||||
packet = append(packet, payload...) |
||||
return |
||||
} |
||||
|
||||
func avoidPanic(done chan *PeerError) { |
||||
if rec := recover(); rec != nil { |
||||
err := NewPeerError(MiscError, " %v", rec) |
||||
logger.Debugln(err) |
||||
done <- err |
||||
} |
||||
} |
||||
|
||||
func (self *Connection) write(payload []byte, done chan *PeerError) { |
||||
defer avoidPanic(done) |
||||
var err *PeerError |
||||
_, ok := self.conn.Write(pack(payload)) |
||||
if ok != nil { |
||||
err = NewPeerError(WriteError, " %v", ok) |
||||
logger.Debugln(err) |
||||
} |
||||
done <- err |
||||
} |
||||
|
||||
func (self *Connection) read(payloads chan []byte, done chan *PeerError) { |
||||
//defer avoidPanic(done)
|
||||
|
||||
partials := make(chan []byte, partialsQueueSize) |
||||
errc := make(chan *PeerError) |
||||
go self.readPartials(partials, errc) |
||||
|
||||
packet := []byte{} |
||||
length := 8 |
||||
start := true |
||||
var err *PeerError |
||||
out: |
||||
for { |
||||
// appends partials read via connection until packet is
|
||||
// - either parseable (>=8bytes)
|
||||
// - or complete (payload fully consumed)
|
||||
for len(packet) < length { |
||||
partial, ok := <-partials |
||||
if !ok { // partials channel is closed
|
||||
err = <-errc |
||||
if err == nil && len(packet) > 0 { |
||||
if start { |
||||
err = NewPeerError(PacketTooShort, "%v", packet) |
||||
} else { |
||||
err = NewPeerError(PayloadTooShort, "%d < %d", len(packet), length) |
||||
} |
||||
} |
||||
break out |
||||
} |
||||
packet = append(packet, partial...) |
||||
} |
||||
if start { |
||||
// at least 8 bytes read, can validate packet
|
||||
if bytes.Compare(magicToken, packet[:4]) != 0 { |
||||
err = NewPeerError(MagicTokenMismatch, " received %v", packet[:4]) |
||||
break |
||||
} |
||||
length = int(ethutil.BytesToNumber(packet[4:8])) |
||||
packet = packet[8:] |
||||
|
||||
if length > 0 { |
||||
start = false // now consuming payload
|
||||
} else { //penalize peer but read on
|
||||
self.err <- NewPeerError(EmptyPayload, "") |
||||
length = 8 |
||||
} |
||||
} else { |
||||
// packet complete (payload fully consumed)
|
||||
payloads <- packet[:length] |
||||
packet = packet[length:] // resclice packet
|
||||
start = true |
||||
length = 8 |
||||
} |
||||
} |
||||
|
||||
// this stops partials read via the connection, should we?
|
||||
//if err != nil {
|
||||
// select {
|
||||
// case errc <- err
|
||||
// default:
|
||||
//}
|
||||
done <- err |
||||
} |
||||
|
||||
func (self *Connection) readPartials(partials chan []byte, errc chan *PeerError) { |
||||
defer close(partials) |
||||
for { |
||||
// Give buffering some time
|
||||
self.conn.SetReadDeadline(time.Now().Add(self.timeout * time.Millisecond)) |
||||
buffer := make([]byte, readBufferLength) |
||||
// read partial from connection
|
||||
bytesRead, err := self.conn.Read(buffer) |
||||
if err == nil || err.Error() == "EOF" { |
||||
if bytesRead > 0 { |
||||
partials <- buffer[:bytesRead] |
||||
} |
||||
if err != nil && err.Error() == "EOF" { |
||||
break |
||||
} |
||||
} else { |
||||
// unexpected error, report to errc
|
||||
err := NewPeerError(ReadError, " %v", err) |
||||
logger.Debugln(err) |
||||
errc <- err |
||||
return // will close partials channel
|
||||
} |
||||
} |
||||
close(errc) |
||||
} |
@ -1,222 +0,0 @@ |
||||
package p2p |
||||
|
||||
import ( |
||||
"bytes" |
||||
"fmt" |
||||
"io" |
||||
"net" |
||||
"testing" |
||||
"time" |
||||
) |
||||
|
||||
type TestNetworkConnection struct { |
||||
in chan []byte |
||||
current []byte |
||||
Out [][]byte |
||||
addr net.Addr |
||||
} |
||||
|
||||
func NewTestNetworkConnection(addr net.Addr) *TestNetworkConnection { |
||||
return &TestNetworkConnection{ |
||||
in: make(chan []byte), |
||||
current: []byte{}, |
||||
Out: [][]byte{}, |
||||
addr: addr, |
||||
} |
||||
} |
||||
|
||||
func (self *TestNetworkConnection) In(latency time.Duration, packets ...[]byte) { |
||||
time.Sleep(latency) |
||||
for _, s := range packets { |
||||
self.in <- s |
||||
} |
||||
} |
||||
|
||||
func (self *TestNetworkConnection) Read(buff []byte) (n int, err error) { |
||||
if len(self.current) == 0 { |
||||
select { |
||||
case self.current = <-self.in: |
||||
default: |
||||
return 0, io.EOF |
||||
} |
||||
} |
||||
length := len(self.current) |
||||
if length > len(buff) { |
||||
copy(buff[:], self.current[:len(buff)]) |
||||
self.current = self.current[len(buff):] |
||||
return len(buff), nil |
||||
} else { |
||||
copy(buff[:length], self.current[:]) |
||||
self.current = []byte{} |
||||
return length, io.EOF |
||||
} |
||||
} |
||||
|
||||
func (self *TestNetworkConnection) Write(buff []byte) (n int, err error) { |
||||
self.Out = append(self.Out, buff) |
||||
fmt.Printf("net write %v\n%v\n", len(self.Out), buff) |
||||
return len(buff), nil |
||||
} |
||||
|
||||
func (self *TestNetworkConnection) Close() (err error) { |
||||
return |
||||
} |
||||
|
||||
func (self *TestNetworkConnection) LocalAddr() (addr net.Addr) { |
||||
return |
||||
} |
||||
|
||||
func (self *TestNetworkConnection) RemoteAddr() (addr net.Addr) { |
||||
return self.addr |
||||
} |
||||
|
||||
func (self *TestNetworkConnection) SetDeadline(t time.Time) (err error) { |
||||
return |
||||
} |
||||
|
||||
func (self *TestNetworkConnection) SetReadDeadline(t time.Time) (err error) { |
||||
return |
||||
} |
||||
|
||||
func (self *TestNetworkConnection) SetWriteDeadline(t time.Time) (err error) { |
||||
return |
||||
} |
||||
|
||||
func setupConnection() (*Connection, *TestNetworkConnection) { |
||||
addr := &TestAddr{"test:30303"} |
||||
net := NewTestNetworkConnection(addr) |
||||
conn := NewConnection(net, NewPeerErrorChannel()) |
||||
conn.Open() |
||||
return conn, net |
||||
} |
||||
|
||||
func TestReadingNilPacket(t *testing.T) { |
||||
conn, net := setupConnection() |
||||
go net.In(0, []byte{}) |
||||
// time.Sleep(10 * time.Millisecond)
|
||||
select { |
||||
case packet := <-conn.Read(): |
||||
t.Errorf("read %v", packet) |
||||
case err := <-conn.Error(): |
||||
t.Errorf("incorrect error %v", err) |
||||
default: |
||||
} |
||||
conn.Close() |
||||
} |
||||
|
||||
func TestReadingShortPacket(t *testing.T) { |
||||
conn, net := setupConnection() |
||||
go net.In(0, []byte{0}) |
||||
select { |
||||
case packet := <-conn.Read(): |
||||
t.Errorf("read %v", packet) |
||||
case err := <-conn.Error(): |
||||
if err.Code != PacketTooShort { |
||||
t.Errorf("incorrect error %v, expected %v", err.Code, PacketTooShort) |
||||
} |
||||
} |
||||
conn.Close() |
||||
} |
||||
|
||||
func TestReadingInvalidPacket(t *testing.T) { |
||||
conn, net := setupConnection() |
||||
go net.In(0, []byte{1, 0, 0, 0, 0, 0, 0, 0}) |
||||
select { |
||||
case packet := <-conn.Read(): |
||||
t.Errorf("read %v", packet) |
||||
case err := <-conn.Error(): |
||||
if err.Code != MagicTokenMismatch { |
||||
t.Errorf("incorrect error %v, expected %v", err.Code, MagicTokenMismatch) |
||||
} |
||||
} |
||||
conn.Close() |
||||
} |
||||
|
||||
func TestReadingInvalidPayload(t *testing.T) { |
||||
conn, net := setupConnection() |
||||
go net.In(0, []byte{34, 64, 8, 145, 0, 0, 0, 2, 0}) |
||||
select { |
||||
case packet := <-conn.Read(): |
||||
t.Errorf("read %v", packet) |
||||
case err := <-conn.Error(): |
||||
if err.Code != PayloadTooShort { |
||||
t.Errorf("incorrect error %v, expected %v", err.Code, PayloadTooShort) |
||||
} |
||||
} |
||||
conn.Close() |
||||
} |
||||
|
||||
func TestReadingEmptyPayload(t *testing.T) { |
||||
conn, net := setupConnection() |
||||
go net.In(0, []byte{34, 64, 8, 145, 0, 0, 0, 0}) |
||||
time.Sleep(10 * time.Millisecond) |
||||
select { |
||||
case packet := <-conn.Read(): |
||||
t.Errorf("read %v", packet) |
||||
default: |
||||
} |
||||
select { |
||||
case err := <-conn.Error(): |
||||
code := err.Code |
||||
if code != EmptyPayload { |
||||
t.Errorf("incorrect error, expected EmptyPayload, got %v", code) |
||||
} |
||||
default: |
||||
t.Errorf("no error, expected EmptyPayload") |
||||
} |
||||
conn.Close() |
||||
} |
||||
|
||||
func TestReadingCompletePacket(t *testing.T) { |
||||
conn, net := setupConnection() |
||||
go net.In(0, []byte{34, 64, 8, 145, 0, 0, 0, 1, 1}) |
||||
time.Sleep(10 * time.Millisecond) |
||||
select { |
||||
case packet := <-conn.Read(): |
||||
if bytes.Compare(packet, []byte{1}) != 0 { |
||||
t.Errorf("incorrect payload read") |
||||
} |
||||
case err := <-conn.Error(): |
||||
t.Errorf("incorrect error %v", err) |
||||
default: |
||||
t.Errorf("nothing read") |
||||
} |
||||
conn.Close() |
||||
} |
||||
|
||||
func TestReadingTwoCompletePackets(t *testing.T) { |
||||
conn, net := setupConnection() |
||||
go net.In(0, []byte{34, 64, 8, 145, 0, 0, 0, 1, 0, 34, 64, 8, 145, 0, 0, 0, 1, 1}) |
||||
|
||||
for i := 0; i < 2; i++ { |
||||
time.Sleep(10 * time.Millisecond) |
||||
select { |
||||
case packet := <-conn.Read(): |
||||
if bytes.Compare(packet, []byte{byte(i)}) != 0 { |
||||
t.Errorf("incorrect payload read") |
||||
} |
||||
case err := <-conn.Error(): |
||||
t.Errorf("incorrect error %v", err) |
||||
default: |
||||
t.Errorf("nothing read") |
||||
} |
||||
} |
||||
conn.Close() |
||||
} |
||||
|
||||
func TestWriting(t *testing.T) { |
||||
conn, net := setupConnection() |
||||
conn.Write() <- []byte{0} |
||||
time.Sleep(10 * time.Millisecond) |
||||
if len(net.Out) == 0 { |
||||
t.Errorf("no output") |
||||
} else { |
||||
out := net.Out[0] |
||||
if bytes.Compare(out, []byte{34, 64, 8, 145, 0, 0, 0, 1, 0}) != 0 { |
||||
t.Errorf("incorrect packet %v", out) |
||||
} |
||||
} |
||||
conn.Close() |
||||
} |
||||
|
||||
// hello packet with client id ABC: 0x22 40 08 91 00 00 00 08 84 00 00 00 43414243
|
@ -1,75 +1,174 @@ |
||||
package p2p |
||||
|
||||
import ( |
||||
// "fmt"
|
||||
"bytes" |
||||
"encoding/binary" |
||||
"fmt" |
||||
"io" |
||||
"io/ioutil" |
||||
"math/big" |
||||
|
||||
"github.com/ethereum/go-ethereum/ethutil" |
||||
) |
||||
|
||||
type MsgCode uint8 |
||||
type MsgCode uint64 |
||||
|
||||
// Msg defines the structure of a p2p message.
|
||||
//
|
||||
// Note that a Msg can only be sent once since the Payload reader is
|
||||
// consumed during sending. It is not possible to create a Msg and
|
||||
// send it any number of times. If you want to reuse an encoded
|
||||
// structure, encode the payload into a byte array and create a
|
||||
// separate Msg with a bytes.Reader as Payload for each send.
|
||||
type Msg struct { |
||||
code MsgCode // this is the raw code as per adaptive msg code scheme
|
||||
data *ethutil.Value |
||||
encoded []byte |
||||
Code MsgCode |
||||
Size uint32 // size of the paylod
|
||||
Payload io.Reader |
||||
} |
||||
|
||||
func (self *Msg) Code() MsgCode { |
||||
return self.code |
||||
// NewMsg creates an RLP-encoded message with the given code.
|
||||
func NewMsg(code MsgCode, params ...interface{}) Msg { |
||||
buf := new(bytes.Buffer) |
||||
for _, p := range params { |
||||
buf.Write(ethutil.Encode(p)) |
||||
} |
||||
return Msg{Code: code, Size: uint32(buf.Len()), Payload: buf} |
||||
} |
||||
|
||||
func (self *Msg) Data() *ethutil.Value { |
||||
return self.data |
||||
func encodePayload(params ...interface{}) []byte { |
||||
buf := new(bytes.Buffer) |
||||
for _, p := range params { |
||||
buf.Write(ethutil.Encode(p)) |
||||
} |
||||
return buf.Bytes() |
||||
} |
||||
|
||||
func NewMsg(code MsgCode, params ...interface{}) (msg *Msg, err error) { |
||||
// Data returns the decoded RLP payload items in a message.
|
||||
func (msg Msg) Data() (*ethutil.Value, error) { |
||||
// TODO: avoid copying when we have a better RLP decoder
|
||||
buf := new(bytes.Buffer) |
||||
var s []interface{} |
||||
if _, err := buf.ReadFrom(msg.Payload); err != nil { |
||||
return nil, err |
||||
} |
||||
for buf.Len() > 0 { |
||||
s = append(s, ethutil.DecodeWithReader(buf)) |
||||
} |
||||
return ethutil.NewValue(s), nil |
||||
} |
||||
|
||||
// // data := [][]interface{}{}
|
||||
// data := []interface{}{}
|
||||
// for _, value := range params {
|
||||
// if encodable, ok := value.(ethutil.RlpEncodeDecode); ok {
|
||||
// data = append(data, encodable.RlpValue())
|
||||
// } else if raw, ok := value.([]interface{}); ok {
|
||||
// data = append(data, raw)
|
||||
// } else {
|
||||
// // data = append(data, interface{}(raw))
|
||||
// err = fmt.Errorf("Unable to encode object of type %T", value)
|
||||
// return
|
||||
// }
|
||||
// }
|
||||
return &Msg{ |
||||
code: code, |
||||
data: ethutil.NewValue(interface{}(params)), |
||||
}, nil |
||||
// Discard reads any remaining payload data into a black hole.
|
||||
func (msg Msg) Discard() error { |
||||
_, err := io.Copy(ioutil.Discard, msg.Payload) |
||||
return err |
||||
} |
||||
|
||||
var magicToken = []byte{34, 64, 8, 145} |
||||
|
||||
func writeMsg(w io.Writer, msg Msg) error { |
||||
// TODO: handle case when Size + len(code) + len(listhdr) overflows uint32
|
||||
code := ethutil.Encode(uint32(msg.Code)) |
||||
listhdr := makeListHeader(msg.Size + uint32(len(code))) |
||||
payloadLen := uint32(len(listhdr)) + uint32(len(code)) + msg.Size |
||||
|
||||
start := make([]byte, 8) |
||||
copy(start, magicToken) |
||||
binary.BigEndian.PutUint32(start[4:], payloadLen) |
||||
|
||||
for _, b := range [][]byte{start, listhdr, code} { |
||||
if _, err := w.Write(b); err != nil { |
||||
return err |
||||
} |
||||
} |
||||
_, err := io.CopyN(w, msg.Payload, int64(msg.Size)) |
||||
return err |
||||
} |
||||
|
||||
func makeListHeader(length uint32) []byte { |
||||
if length < 56 { |
||||
return []byte{byte(length + 0xc0)} |
||||
} |
||||
enc := big.NewInt(int64(length)).Bytes() |
||||
lenb := byte(len(enc)) + 0xf7 |
||||
return append([]byte{lenb}, enc...) |
||||
} |
||||
|
||||
type byteReader interface { |
||||
io.Reader |
||||
io.ByteReader |
||||
} |
||||
|
||||
func NewMsgFromBytes(encoded []byte) (msg *Msg, err error) { |
||||
value := ethutil.NewValueFromBytes(encoded) |
||||
// Type of message
|
||||
code := value.Get(0).Uint() |
||||
// Actual data
|
||||
data := value.SliceFrom(1) |
||||
// readMsg reads a message header.
|
||||
func readMsg(r byteReader) (msg Msg, err error) { |
||||
// read magic and payload size
|
||||
start := make([]byte, 8) |
||||
if _, err = io.ReadFull(r, start); err != nil { |
||||
return msg, NewPeerError(ReadError, "%v", err) |
||||
} |
||||
if !bytes.HasPrefix(start, magicToken) { |
||||
return msg, NewPeerError(MagicTokenMismatch, "got %x, want %x", start[:4], magicToken) |
||||
} |
||||
size := binary.BigEndian.Uint32(start[4:]) |
||||
|
||||
msg = &Msg{ |
||||
code: MsgCode(code), |
||||
data: data, |
||||
// data: ethutil.NewValue(data),
|
||||
encoded: encoded, |
||||
// decode start of RLP message to get the message code
|
||||
_, hdrlen, err := readListHeader(r) |
||||
if err != nil { |
||||
return msg, err |
||||
} |
||||
return |
||||
code, codelen, err := readMsgCode(r) |
||||
if err != nil { |
||||
return msg, err |
||||
} |
||||
|
||||
func (self *Msg) Decode(offset MsgCode) { |
||||
self.code = self.code - offset |
||||
rlpsize := size - hdrlen - codelen |
||||
return Msg{ |
||||
Code: code, |
||||
Size: rlpsize, |
||||
Payload: io.LimitReader(r, int64(rlpsize)), |
||||
}, nil |
||||
} |
||||
|
||||
// encode takes an offset argument to implement adaptive message coding
|
||||
// the encoded message is memoized to make msgs relayed to several peers more efficient
|
||||
func (self *Msg) Encode(offset MsgCode) (res []byte) { |
||||
if len(self.encoded) == 0 { |
||||
res = ethutil.NewValue(append([]interface{}{byte(self.code + offset)}, self.data.Slice()...)).Encode() |
||||
self.encoded = res |
||||
// readListHeader reads an RLP list header from r.
|
||||
func readListHeader(r byteReader) (len uint64, hdrlen uint32, err error) { |
||||
b, err := r.ReadByte() |
||||
if err != nil { |
||||
return 0, 0, err |
||||
} |
||||
if b < 0xC0 { |
||||
return 0, 0, fmt.Errorf("expected list start byte >= 0xC0, got %x", b) |
||||
} else if b < 0xF7 { |
||||
len = uint64(b - 0xc0) |
||||
hdrlen = 1 |
||||
} else { |
||||
res = self.encoded |
||||
lenlen := b - 0xF7 |
||||
lenbuf := make([]byte, 8) |
||||
if _, err := io.ReadFull(r, lenbuf[8-lenlen:]); err != nil { |
||||
return 0, 0, err |
||||
} |
||||
len = binary.BigEndian.Uint64(lenbuf) |
||||
hdrlen = 1 + uint32(lenlen) |
||||
} |
||||
return len, hdrlen, nil |
||||
} |
||||
|
||||
// readUint reads an RLP-encoded unsigned integer from r.
|
||||
func readMsgCode(r byteReader) (code MsgCode, codelen uint32, err error) { |
||||
b, err := r.ReadByte() |
||||
if err != nil { |
||||
return 0, 0, err |
||||
} |
||||
if b < 0x80 { |
||||
return MsgCode(b), 1, nil |
||||
} else if b < 0x89 { // max length for uint64 is 8 bytes
|
||||
codelen = uint32(b - 0x80) |
||||
if codelen == 0 { |
||||
return 0, 1, nil |
||||
} |
||||
buf := make([]byte, 8) |
||||
if _, err := io.ReadFull(r, buf[8-codelen:]); err != nil { |
||||
return 0, 0, err |
||||
} |
||||
return MsgCode(binary.BigEndian.Uint64(buf)), codelen, nil |
||||
} |
||||
return |
||||
return 0, 0, fmt.Errorf("bad RLP type for message code: %x", b) |
||||
} |
||||
|
@ -1,38 +1,67 @@ |
||||
package p2p |
||||
|
||||
import ( |
||||
"bytes" |
||||
"io/ioutil" |
||||
"testing" |
||||
|
||||
"github.com/ethereum/go-ethereum/ethutil" |
||||
) |
||||
|
||||
func TestNewMsg(t *testing.T) { |
||||
msg, _ := NewMsg(3, 1, "000") |
||||
if msg.Code() != 3 { |
||||
t.Errorf("incorrect code %v", msg.Code()) |
||||
msg := NewMsg(3, 1, "000") |
||||
if msg.Code != 3 { |
||||
t.Errorf("incorrect code %d, want %d", msg.Code) |
||||
} |
||||
data0 := msg.Data().Get(0).Uint() |
||||
data1 := string(msg.Data().Get(1).Bytes()) |
||||
if data0 != 1 { |
||||
t.Errorf("incorrect data %v", data0) |
||||
if msg.Size != 5 { |
||||
t.Errorf("incorrect size %d, want %d", msg.Size, 5) |
||||
} |
||||
if data1 != "000" { |
||||
t.Errorf("incorrect data %v", data1) |
||||
pl, _ := ioutil.ReadAll(msg.Payload) |
||||
expect := []byte{0x01, 0x83, 0x30, 0x30, 0x30} |
||||
if !bytes.Equal(pl, expect) { |
||||
t.Errorf("incorrect payload content, got %x, want %x", pl, expect) |
||||
} |
||||
} |
||||
|
||||
func TestEncodeDecodeMsg(t *testing.T) { |
||||
msg, _ := NewMsg(3, 1, "000") |
||||
encoded := msg.Encode(3) |
||||
msg, _ = NewMsgFromBytes(encoded) |
||||
msg.Decode(3) |
||||
if msg.Code() != 3 { |
||||
t.Errorf("incorrect code %v", msg.Code()) |
||||
} |
||||
data0 := msg.Data().Get(0).Uint() |
||||
data1 := msg.Data().Get(1).Str() |
||||
if data0 != 1 { |
||||
t.Errorf("incorrect data %v", data0) |
||||
} |
||||
if data1 != "000" { |
||||
t.Errorf("incorrect data %v", data1) |
||||
msg := NewMsg(3, 1, "000") |
||||
buf := new(bytes.Buffer) |
||||
if err := writeMsg(buf, msg); err != nil { |
||||
t.Fatalf("encodeMsg error: %v", err) |
||||
} |
||||
|
||||
t.Logf("encoded: %x", buf.Bytes()) |
||||
|
||||
decmsg, err := readMsg(buf) |
||||
if err != nil { |
||||
t.Fatalf("readMsg error: %v", err) |
||||
} |
||||
if decmsg.Code != 3 { |
||||
t.Errorf("incorrect code %d, want %d", decmsg.Code, 3) |
||||
} |
||||
if decmsg.Size != 5 { |
||||
t.Errorf("incorrect size %d, want %d", decmsg.Size, 5) |
||||
} |
||||
data, err := decmsg.Data() |
||||
if err != nil { |
||||
t.Fatalf("first payload item decode error: %v", err) |
||||
} |
||||
if v := data.Get(0).Uint(); v != 1 { |
||||
t.Errorf("incorrect data[0]: got %v, expected %d", v, 1) |
||||
} |
||||
if v := data.Get(1).Str(); v != "000" { |
||||
t.Errorf("incorrect data[1]: got %q, expected %q", v, "000") |
||||
} |
||||
} |
||||
|
||||
func TestDecodeRealMsg(t *testing.T) { |
||||
data := ethutil.Hex2Bytes("2240089100000080f87e8002b5457468657265756d282b2b292f5065657220536572766572204f6e652f76302e372e382f52656c656173652f4c696e75782f672b2bc082765fb84086dd80b7aefd6a6d2e3b93f4f300a86bfb6ef7bdc97cb03f793db6bb") |
||||
msg, err := readMsg(bytes.NewReader(data)) |
||||
if err != nil { |
||||
t.Fatalf("unexpected error: %v", err) |
||||
} |
||||
|
||||
if msg.Code != 0 { |
||||
t.Errorf("incorrect code %d, want %d", msg.Code, 0) |
||||
} |
||||
} |
||||
|
@ -1,220 +1,221 @@ |
||||
package p2p |
||||
|
||||
import ( |
||||
"bufio" |
||||
"bytes" |
||||
"fmt" |
||||
"io" |
||||
"io/ioutil" |
||||
"net" |
||||
"sync" |
||||
"time" |
||||
) |
||||
|
||||
const ( |
||||
handlerTimeout = 1000 |
||||
) |
||||
type Handlers map[string]func() Protocol |
||||
|
||||
type Handlers map[string](func(p *Peer) Protocol) |
||||
type proto struct { |
||||
in chan Msg |
||||
maxcode, offset MsgCode |
||||
messenger *messenger |
||||
} |
||||
|
||||
type Messenger struct { |
||||
conn *Connection |
||||
func (rw *proto) WriteMsg(msg Msg) error { |
||||
if msg.Code >= rw.maxcode { |
||||
return NewPeerError(InvalidMsgCode, "not handled") |
||||
} |
||||
return rw.messenger.writeMsg(msg) |
||||
} |
||||
|
||||
func (rw *proto) ReadMsg() (Msg, error) { |
||||
msg, ok := <-rw.in |
||||
if !ok { |
||||
return msg, io.EOF |
||||
} |
||||
return msg, nil |
||||
} |
||||
|
||||
// eofSignal is used to 'lend' the network connection
|
||||
// to a protocol. when the protocol's read loop has read the
|
||||
// whole payload, the done channel is closed.
|
||||
type eofSignal struct { |
||||
wrapped io.Reader |
||||
eof chan struct{} |
||||
} |
||||
|
||||
func (r *eofSignal) Read(buf []byte) (int, error) { |
||||
n, err := r.wrapped.Read(buf) |
||||
if err != nil { |
||||
close(r.eof) // tell messenger that msg has been consumed
|
||||
} |
||||
return n, err |
||||
} |
||||
|
||||
// messenger represents a message-oriented peer connection.
|
||||
// It keeps track of the set of protocols understood
|
||||
// by the remote peer.
|
||||
type messenger struct { |
||||
peer *Peer |
||||
handlers Handlers |
||||
|
||||
// the mutex protects the connection
|
||||
// so only one protocol can write at a time.
|
||||
writeMu sync.Mutex |
||||
conn net.Conn |
||||
bufconn *bufio.ReadWriter |
||||
|
||||
protocolLock sync.RWMutex |
||||
protocols []Protocol |
||||
offsets []MsgCode // offsets for adaptive message idss
|
||||
protocolTable map[string]int |
||||
quit chan chan bool |
||||
err chan *PeerError |
||||
protocols map[string]*proto |
||||
offsets map[MsgCode]*proto |
||||
protoWG sync.WaitGroup |
||||
|
||||
err chan error |
||||
pulse chan bool |
||||
} |
||||
|
||||
func NewMessenger(peer *Peer, conn *Connection, errchan chan *PeerError, handlers Handlers) *Messenger { |
||||
baseProtocol := NewBaseProtocol(peer) |
||||
return &Messenger{ |
||||
func newMessenger(peer *Peer, conn net.Conn, errchan chan error, handlers Handlers) *messenger { |
||||
return &messenger{ |
||||
conn: conn, |
||||
bufconn: bufio.NewReadWriter(bufio.NewReader(conn), bufio.NewWriter(conn)), |
||||
peer: peer, |
||||
offsets: []MsgCode{baseProtocol.Offset()}, |
||||
handlers: handlers, |
||||
protocols: []Protocol{baseProtocol}, |
||||
protocolTable: make(map[string]int), |
||||
protocols: make(map[string]*proto), |
||||
err: errchan, |
||||
pulse: make(chan bool, 1), |
||||
quit: make(chan chan bool, 1), |
||||
} |
||||
} |
||||
|
||||
func (self *Messenger) Start() { |
||||
self.conn.Open() |
||||
go self.messenger() |
||||
self.protocolLock.RLock() |
||||
defer self.protocolLock.RUnlock() |
||||
self.protocols[0].Start() |
||||
func (m *messenger) Start() { |
||||
m.protocols[""] = m.startProto(0, "", &baseProtocol{}) |
||||
go m.readLoop() |
||||
} |
||||
|
||||
func (self *Messenger) Stop() { |
||||
// close pulse to stop ping pong monitoring
|
||||
close(self.pulse) |
||||
self.protocolLock.RLock() |
||||
defer self.protocolLock.RUnlock() |
||||
for _, protocol := range self.protocols { |
||||
protocol.Stop() // could be parallel
|
||||
} |
||||
q := make(chan bool) |
||||
self.quit <- q |
||||
<-q |
||||
self.conn.Close() |
||||
func (m *messenger) Stop() { |
||||
m.conn.Close() |
||||
m.protoWG.Wait() |
||||
} |
||||
|
||||
func (self *Messenger) messenger() { |
||||
in := self.conn.Read() |
||||
const ( |
||||
// maximum amount of time allowed for reading a message
|
||||
msgReadTimeout = 5 * time.Second |
||||
|
||||
// messages smaller than this many bytes will be read at
|
||||
// once before passing them to a protocol.
|
||||
wholePayloadSize = 64 * 1024 |
||||
) |
||||
|
||||
func (m *messenger) readLoop() { |
||||
defer m.closeProtocols() |
||||
for { |
||||
select { |
||||
case payload, ok := <-in: |
||||
//dispatches message to the protocol asynchronously
|
||||
if ok { |
||||
go self.handle(payload) |
||||
} else { |
||||
return |
||||
} |
||||
case q := <-self.quit: |
||||
q <- true |
||||
m.conn.SetReadDeadline(time.Now().Add(msgReadTimeout)) |
||||
msg, err := readMsg(m.bufconn) |
||||
if err != nil { |
||||
m.err <- err |
||||
return |
||||
} |
||||
} |
||||
} |
||||
|
||||
// handles each message by dispatching to the appropriate protocol
|
||||
// using adaptive message codes
|
||||
// this function is started as a separate go routine for each message
|
||||
// it waits for the protocol response
|
||||
// then encodes and sends outgoing messages to the connection's write channel
|
||||
func (self *Messenger) handle(payload []byte) { |
||||
// send ping to heartbeat channel signalling time of last message
|
||||
// select {
|
||||
// case self.pulse <- true:
|
||||
// default:
|
||||
// }
|
||||
self.pulse <- true |
||||
// initialise message from payload
|
||||
msg, err := NewMsgFromBytes(payload) |
||||
m.pulse <- true |
||||
proto, err := m.getProto(msg.Code) |
||||
if err != nil { |
||||
self.err <- NewPeerError(MiscError, " %v", err) |
||||
m.err <- err |
||||
return |
||||
} |
||||
// retrieves protocol based on message Code
|
||||
protocol, offset, peerErr := self.getProtocol(msg.Code()) |
||||
msg.Code -= proto.offset |
||||
if msg.Size <= wholePayloadSize { |
||||
// optimization: msg is small enough, read all
|
||||
// of it and move on to the next message
|
||||
buf, err := ioutil.ReadAll(msg.Payload) |
||||
if err != nil { |
||||
self.err <- peerErr |
||||
m.err <- err |
||||
return |
||||
} |
||||
// reset message code based on adaptive offset
|
||||
msg.Decode(offset) |
||||
// dispatches
|
||||
response := make(chan *Msg) |
||||
go protocol.HandleIn(msg, response) |
||||
// protocol reponse timeout to prevent leaks
|
||||
timer := time.After(handlerTimeout * time.Millisecond) |
||||
for { |
||||
select { |
||||
case outgoing, ok := <-response: |
||||
// we check if response channel is not closed
|
||||
if ok { |
||||
self.conn.Write() <- outgoing.Encode(offset) |
||||
msg.Payload = bytes.NewReader(buf) |
||||
proto.in <- msg |
||||
} else { |
||||
return |
||||
} |
||||
case <-timer: |
||||
return |
||||
pr := &eofSignal{msg.Payload, make(chan struct{})} |
||||
msg.Payload = pr |
||||
proto.in <- msg |
||||
<-pr.eof |
||||
} |
||||
} |
||||
} |
||||
|
||||
// negotiated protocols
|
||||
// stores offsets needed for adaptive message id scheme
|
||||
|
||||
// based on offsets set at handshake
|
||||
// get the right protocol to handle the message
|
||||
func (self *Messenger) getProtocol(code MsgCode) (Protocol, MsgCode, *PeerError) { |
||||
self.protocolLock.RLock() |
||||
defer self.protocolLock.RUnlock() |
||||
base := MsgCode(0) |
||||
for index, offset := range self.offsets { |
||||
if code < offset { |
||||
return self.protocols[index], base, nil |
||||
} |
||||
base = offset |
||||
func (m *messenger) closeProtocols() { |
||||
m.protocolLock.RLock() |
||||
for _, p := range m.protocols { |
||||
close(p.in) |
||||
} |
||||
return nil, MsgCode(0), NewPeerError(InvalidMsgCode, " %v", code) |
||||
m.protocolLock.RUnlock() |
||||
} |
||||
|
||||
func (self *Messenger) PingPong(timeout time.Duration, gracePeriod time.Duration, pingCallback func(), timeoutCallback func()) { |
||||
fmt.Printf("pingpong keepalive started at %v", time.Now()) |
||||
|
||||
timer := time.After(timeout) |
||||
pinged := false |
||||
for { |
||||
select { |
||||
case _, ok := <-self.pulse: |
||||
if ok { |
||||
pinged = false |
||||
timer = time.After(timeout) |
||||
} else { |
||||
// pulse is closed, stop monitoring
|
||||
return |
||||
func (m *messenger) startProto(offset MsgCode, name string, impl Protocol) *proto { |
||||
proto := &proto{ |
||||
in: make(chan Msg), |
||||
offset: offset, |
||||
maxcode: impl.Offset(), |
||||
messenger: m, |
||||
} |
||||
case <-timer: |
||||
if pinged { |
||||
fmt.Printf("timeout at %v", time.Now()) |
||||
timeoutCallback() |
||||
return |
||||
} else { |
||||
fmt.Printf("pinged at %v", time.Now()) |
||||
pingCallback() |
||||
timer = time.After(gracePeriod) |
||||
pinged = true |
||||
m.protoWG.Add(1) |
||||
go func() { |
||||
if err := impl.Start(m.peer, proto); err != nil && err != io.EOF { |
||||
logger.Errorf("protocol %q error: %v\n", name, err) |
||||
m.err <- err |
||||
} |
||||
m.protoWG.Done() |
||||
}() |
||||
return proto |
||||
} |
||||
|
||||
// getProto finds the protocol responsible for handling
|
||||
// the given message code.
|
||||
func (m *messenger) getProto(code MsgCode) (*proto, error) { |
||||
m.protocolLock.RLock() |
||||
defer m.protocolLock.RUnlock() |
||||
for _, proto := range m.protocols { |
||||
if code >= proto.offset && code < proto.offset+proto.maxcode { |
||||
return proto, nil |
||||
} |
||||
} |
||||
return nil, NewPeerError(InvalidMsgCode, "%d", code) |
||||
} |
||||
|
||||
func (self *Messenger) AddProtocols(protocols []string) { |
||||
self.protocolLock.Lock() |
||||
defer self.protocolLock.Unlock() |
||||
i := len(self.offsets) |
||||
offset := self.offsets[i-1] |
||||
// setProtocols starts all subprotocols shared with the
|
||||
// remote peer. the protocols must be sorted alphabetically.
|
||||
func (m *messenger) setRemoteProtocols(protocols []string) { |
||||
m.protocolLock.Lock() |
||||
defer m.protocolLock.Unlock() |
||||
offset := baseProtocolOffset |
||||
for _, name := range protocols { |
||||
protocolFunc, ok := self.handlers[name] |
||||
if ok { |
||||
protocol := protocolFunc(self.peer) |
||||
self.protocolTable[name] = i |
||||
i++ |
||||
offset += protocol.Offset() |
||||
fmt.Println("offset ", name, offset) |
||||
|
||||
self.offsets = append(self.offsets, offset) |
||||
self.protocols = append(self.protocols, protocol) |
||||
protocol.Start() |
||||
} else { |
||||
fmt.Println("no ", name) |
||||
// protocol not handled
|
||||
protocolFunc, ok := m.handlers[name] |
||||
if !ok { |
||||
continue // not handled
|
||||
} |
||||
inst := protocolFunc() |
||||
m.protocols[name] = m.startProto(offset, name, inst) |
||||
offset += inst.Offset() |
||||
} |
||||
} |
||||
|
||||
func (self *Messenger) Write(protocol string, msg *Msg) error { |
||||
self.protocolLock.RLock() |
||||
defer self.protocolLock.RUnlock() |
||||
i := 0 |
||||
offset := MsgCode(0) |
||||
if len(protocol) > 0 { |
||||
var ok bool |
||||
i, ok = self.protocolTable[protocol] |
||||
// writeProtoMsg sends the given message on behalf of the given named protocol.
|
||||
func (m *messenger) writeProtoMsg(protoName string, msg Msg) error { |
||||
m.protocolLock.RLock() |
||||
proto, ok := m.protocols[protoName] |
||||
m.protocolLock.RUnlock() |
||||
if !ok { |
||||
return fmt.Errorf("protocol %v not handled by peer", protocol) |
||||
return fmt.Errorf("protocol %s not handled by peer", protoName) |
||||
} |
||||
if msg.Code >= proto.maxcode { |
||||
return NewPeerError(InvalidMsgCode, "code %x is out of range for protocol %q", msg.Code, protoName) |
||||
} |
||||
offset = self.offsets[i-1] |
||||
msg.Code += proto.offset |
||||
return m.writeMsg(msg) |
||||
} |
||||
handler := self.protocols[i] |
||||
// checking if protocol status/caps allows the message to be sent out
|
||||
if handler.HandleOut(msg) { |
||||
self.conn.Write() <- msg.Encode(offset) |
||||
|
||||
// writeMsg writes a message to the connection.
|
||||
func (m *messenger) writeMsg(msg Msg) error { |
||||
m.writeMu.Lock() |
||||
defer m.writeMu.Unlock() |
||||
if err := writeMsg(m.bufconn, msg); err != nil { |
||||
return err |
||||
} |
||||
return nil |
||||
return m.bufconn.Flush() |
||||
} |
||||
|
@ -1,147 +1,157 @@ |
||||
package p2p |
||||
|
||||
import ( |
||||
// "fmt"
|
||||
"bytes" |
||||
"bufio" |
||||
"fmt" |
||||
"io" |
||||
"log" |
||||
"net" |
||||
"os" |
||||
"reflect" |
||||
"testing" |
||||
"time" |
||||
|
||||
"github.com/ethereum/go-ethereum/ethutil" |
||||
) |
||||
|
||||
func setupMessenger(handlers Handlers) (*TestNetworkConnection, chan *PeerError, *Messenger) { |
||||
errchan := NewPeerErrorChannel() |
||||
addr := &TestAddr{"test:30303"} |
||||
net := NewTestNetworkConnection(addr) |
||||
conn := NewConnection(net, errchan) |
||||
mess := NewMessenger(nil, conn, errchan, handlers) |
||||
mess.Start() |
||||
return net, errchan, mess |
||||
func init() { |
||||
ethlog.AddLogSystem(ethlog.NewStdLogSystem(os.Stdout, log.LstdFlags, ethlog.DebugLevel)) |
||||
} |
||||
|
||||
type TestProtocol struct { |
||||
Msgs []*Msg |
||||
func setupMessenger(handlers Handlers) (net.Conn, *Peer, *messenger) { |
||||
conn1, conn2 := net.Pipe() |
||||
id := NewSimpleClientIdentity("test", "0", "0", "public key") |
||||
server := New(nil, conn1.LocalAddr(), id, handlers, 10, NewBlacklist()) |
||||
peer := server.addPeer(conn1, conn1.RemoteAddr(), true, 0) |
||||
return conn2, peer, peer.messenger |
||||
} |
||||
|
||||
func (self *TestProtocol) Start() { |
||||
func performTestHandshake(r *bufio.Reader, w io.Writer) error { |
||||
// read remote handshake
|
||||
msg, err := readMsg(r) |
||||
if err != nil { |
||||
return fmt.Errorf("read error: %v", err) |
||||
} |
||||
|
||||
func (self *TestProtocol) Stop() { |
||||
if msg.Code != handshakeMsg { |
||||
return fmt.Errorf("first message should be handshake, got %x", msg.Code) |
||||
} |
||||
|
||||
func (self *TestProtocol) Offset() MsgCode { |
||||
return MsgCode(5) |
||||
if err := msg.Discard(); err != nil { |
||||
return err |
||||
} |
||||
|
||||
func (self *TestProtocol) HandleIn(msg *Msg, response chan *Msg) { |
||||
self.Msgs = append(self.Msgs, msg) |
||||
close(response) |
||||
// send empty handshake
|
||||
pubkey := make([]byte, 64) |
||||
msg = NewMsg(handshakeMsg, p2pVersion, "testid", nil, 9999, pubkey) |
||||
return writeMsg(w, msg) |
||||
} |
||||
|
||||
func (self *TestProtocol) HandleOut(msg *Msg) bool { |
||||
if msg.Code() > 3 { |
||||
return false |
||||
} else { |
||||
return true |
||||
} |
||||
type testMsg struct { |
||||
code MsgCode |
||||
data *ethutil.Value |
||||
} |
||||
|
||||
func (self *TestProtocol) Name() string { |
||||
return "a" |
||||
type testProto struct { |
||||
recv chan testMsg |
||||
} |
||||
|
||||
func Packet(offset MsgCode, code MsgCode, params ...interface{}) []byte { |
||||
msg, _ := NewMsg(code, params...) |
||||
encoded := msg.Encode(offset) |
||||
packet := []byte{34, 64, 8, 145} |
||||
packet = append(packet, ethutil.NumberToBytes(uint32(len(encoded)), 32)...) |
||||
return append(packet, encoded...) |
||||
func (*testProto) Offset() MsgCode { return 5 } |
||||
|
||||
func (tp *testProto) Start(peer *Peer, rw MsgReadWriter) error { |
||||
return MsgLoop(rw, 1024, func(code MsgCode, data *ethutil.Value) error { |
||||
logger.Debugf("testprotocol got msg: %d\n", code) |
||||
tp.recv <- testMsg{code, data} |
||||
return nil |
||||
}) |
||||
} |
||||
|
||||
func TestRead(t *testing.T) { |
||||
handlers := make(Handlers) |
||||
testProtocol := &TestProtocol{Msgs: []*Msg{}} |
||||
handlers["a"] = func(p *Peer) Protocol { return testProtocol } |
||||
net, _, mess := setupMessenger(handlers) |
||||
mess.AddProtocols([]string{"a"}) |
||||
defer mess.Stop() |
||||
wait := 1 * time.Millisecond |
||||
packet := Packet(16, 1, uint32(1), "000") |
||||
go net.In(0, packet) |
||||
time.Sleep(wait) |
||||
if len(testProtocol.Msgs) != 1 { |
||||
t.Errorf("msg not relayed to correct protocol") |
||||
} else { |
||||
if testProtocol.Msgs[0].Code() != 1 { |
||||
t.Errorf("incorrect msg code relayed to protocol") |
||||
testProtocol := &testProto{make(chan testMsg)} |
||||
handlers := Handlers{"a": func() Protocol { return testProtocol }} |
||||
net, peer, mess := setupMessenger(handlers) |
||||
bufr := bufio.NewReader(net) |
||||
defer peer.Stop() |
||||
if err := performTestHandshake(bufr, net); err != nil { |
||||
t.Fatalf("handshake failed: %v", err) |
||||
} |
||||
|
||||
mess.setRemoteProtocols([]string{"a"}) |
||||
writeMsg(net, NewMsg(17, uint32(1), "000")) |
||||
select { |
||||
case msg := <-testProtocol.recv: |
||||
if msg.code != 1 { |
||||
t.Errorf("incorrect msg code %d relayed to protocol", msg.code) |
||||
} |
||||
expdata := []interface{}{1, []byte{0x30, 0x30, 0x30}} |
||||
if !reflect.DeepEqual(msg.data.Slice(), expdata) { |
||||
t.Errorf("incorrect msg data %#v", msg.data.Slice()) |
||||
} |
||||
case <-time.After(2 * time.Second): |
||||
t.Errorf("receive timeout") |
||||
} |
||||
} |
||||
|
||||
func TestWrite(t *testing.T) { |
||||
func TestWriteProtoMsg(t *testing.T) { |
||||
handlers := make(Handlers) |
||||
testProtocol := &TestProtocol{Msgs: []*Msg{}} |
||||
handlers["a"] = func(p *Peer) Protocol { return testProtocol } |
||||
net, _, mess := setupMessenger(handlers) |
||||
mess.AddProtocols([]string{"a"}) |
||||
defer mess.Stop() |
||||
wait := 1 * time.Millisecond |
||||
msg, _ := NewMsg(3, uint32(1), "000") |
||||
err := mess.Write("b", msg) |
||||
if err == nil { |
||||
t.Errorf("expect error for unknown protocol") |
||||
} |
||||
err = mess.Write("a", msg) |
||||
if err != nil { |
||||
t.Errorf("expect no error for known protocol: %v", err) |
||||
} else { |
||||
time.Sleep(wait) |
||||
if len(net.Out) != 1 { |
||||
t.Errorf("msg not written") |
||||
testProtocol := &testProto{recv: make(chan testMsg, 1)} |
||||
handlers["a"] = func() Protocol { return testProtocol } |
||||
net, peer, mess := setupMessenger(handlers) |
||||
defer peer.Stop() |
||||
bufr := bufio.NewReader(net) |
||||
if err := performTestHandshake(bufr, net); err != nil { |
||||
t.Fatalf("handshake failed: %v", err) |
||||
} |
||||
mess.setRemoteProtocols([]string{"a"}) |
||||
|
||||
// test write errors
|
||||
if err := mess.writeProtoMsg("b", NewMsg(3)); err == nil { |
||||
t.Errorf("expected error for unknown protocol, got nil") |
||||
} |
||||
if err := mess.writeProtoMsg("a", NewMsg(8)); err == nil { |
||||
t.Errorf("expected error for out-of-range msg code, got nil") |
||||
} else if perr, ok := err.(*PeerError); !ok || perr.Code != InvalidMsgCode { |
||||
t.Errorf("wrong error for out-of-range msg code, got %#v") |
||||
} |
||||
|
||||
// test succcessful write
|
||||
read, readerr := make(chan Msg), make(chan error) |
||||
go func() { |
||||
if msg, err := readMsg(bufr); err != nil { |
||||
readerr <- err |
||||
} else { |
||||
out := net.Out[0] |
||||
packet := Packet(16, 3, uint32(1), "000") |
||||
if bytes.Compare(out, packet) != 0 { |
||||
t.Errorf("incorrect packet %v", out) |
||||
read <- msg |
||||
} |
||||
}() |
||||
if err := mess.writeProtoMsg("a", NewMsg(3)); err != nil { |
||||
t.Errorf("expect no error for known protocol: %v", err) |
||||
} |
||||
select { |
||||
case msg := <-read: |
||||
if msg.Code != 19 { |
||||
t.Errorf("wrong code, got %d, expected %d", msg.Code, 19) |
||||
} |
||||
msg.Discard() |
||||
case err := <-readerr: |
||||
t.Errorf("read error: %v", err) |
||||
} |
||||
} |
||||
|
||||
func TestPulse(t *testing.T) { |
||||
net, _, mess := setupMessenger(make(Handlers)) |
||||
defer mess.Stop() |
||||
ping := false |
||||
timeout := false |
||||
pingTimeout := 10 * time.Millisecond |
||||
gracePeriod := 200 * time.Millisecond |
||||
go mess.PingPong(pingTimeout, gracePeriod, func() { ping = true }, func() { timeout = true }) |
||||
net.In(0, Packet(0, 1)) |
||||
if ping { |
||||
t.Errorf("ping sent too early") |
||||
} |
||||
time.Sleep(pingTimeout + 100*time.Millisecond) |
||||
if !ping { |
||||
t.Errorf("no ping sent after timeout") |
||||
} |
||||
if timeout { |
||||
t.Errorf("timeout too early") |
||||
} |
||||
ping = false |
||||
net.In(0, Packet(0, 1)) |
||||
time.Sleep(pingTimeout + 100*time.Millisecond) |
||||
if !ping { |
||||
t.Errorf("no ping sent after timeout") |
||||
} |
||||
if timeout { |
||||
t.Errorf("timeout too early") |
||||
} |
||||
ping = false |
||||
time.Sleep(gracePeriod) |
||||
if ping { |
||||
t.Errorf("ping called twice") |
||||
} |
||||
if !timeout { |
||||
t.Errorf("no timeout after grace period") |
||||
net, peer, _ := setupMessenger(nil) |
||||
defer peer.Stop() |
||||
bufr := bufio.NewReader(net) |
||||
if err := performTestHandshake(bufr, net); err != nil { |
||||
t.Fatalf("handshake failed: %v", err) |
||||
} |
||||
|
||||
before := time.Now() |
||||
msg, err := readMsg(bufr) |
||||
if err != nil { |
||||
t.Fatalf("read error: %v", err) |
||||
} |
||||
after := time.Now() |
||||
if msg.Code != pingMsg { |
||||
t.Errorf("expected ping message, got %x", msg.Code) |
||||
} |
||||
if d := after.Sub(before); d < pingTimeout { |
||||
t.Errorf("ping sent too early after %v, expected at least %v", d, pingTimeout) |
||||
} |
||||
} |
||||
|
@ -1,96 +1,90 @@ |
||||
package p2p |
||||
|
||||
import ( |
||||
"bytes" |
||||
"fmt" |
||||
// "net"
|
||||
"testing" |
||||
"time" |
||||
) |
||||
|
||||
func TestPeer(t *testing.T) { |
||||
handlers := make(Handlers) |
||||
testProtocol := &TestProtocol{Msgs: []*Msg{}} |
||||
handlers["aaa"] = func(p *Peer) Protocol { return testProtocol } |
||||
handlers["ccc"] = func(p *Peer) Protocol { return testProtocol } |
||||
addr := &TestAddr{"test:30"} |
||||
conn := NewTestNetworkConnection(addr) |
||||
_, server := SetupTestServer(handlers) |
||||
server.Handshake() |
||||
peer := NewPeer(conn, addr, true, server) |
||||
// peer.Messenger().AddProtocols([]string{"aaa", "ccc"})
|
||||
peer.Start() |
||||
defer peer.Stop() |
||||
time.Sleep(2 * time.Millisecond) |
||||
if len(conn.Out) != 1 { |
||||
t.Errorf("handshake not sent") |
||||
} else { |
||||
out := conn.Out[0] |
||||
packet := Packet(0, HandshakeMsg, P2PVersion, []byte(peer.server.identity.String()), []interface{}{peer.server.protocols}, peer.server.port, peer.server.identity.Pubkey()[1:]) |
||||
if bytes.Compare(out, packet) != 0 { |
||||
t.Errorf("incorrect handshake packet %v != %v", out, packet) |
||||
} |
||||
} |
||||
// func TestPeer(t *testing.T) {
|
||||
// handlers := make(Handlers)
|
||||
// testProtocol := &TestProtocol{recv: make(chan testMsg)}
|
||||
// handlers["aaa"] = func(p *Peer) Protocol { return testProtocol }
|
||||
// handlers["ccc"] = func(p *Peer) Protocol { return testProtocol }
|
||||
// addr := &TestAddr{"test:30"}
|
||||
// conn := NewTestNetworkConnection(addr)
|
||||
// _, server := SetupTestServer(handlers)
|
||||
// server.Handshake()
|
||||
// peer := NewPeer(conn, addr, true, server)
|
||||
// // peer.Messenger().AddProtocols([]string{"aaa", "ccc"})
|
||||
// peer.Start()
|
||||
// defer peer.Stop()
|
||||
// time.Sleep(2 * time.Millisecond)
|
||||
// if len(conn.Out) != 1 {
|
||||
// t.Errorf("handshake not sent")
|
||||
// } else {
|
||||
// out := conn.Out[0]
|
||||
// packet := Packet(0, HandshakeMsg, P2PVersion, []byte(peer.server.identity.String()), []interface{}{peer.server.protocols}, peer.server.port, peer.server.identity.Pubkey()[1:])
|
||||
// if bytes.Compare(out, packet) != 0 {
|
||||
// t.Errorf("incorrect handshake packet %v != %v", out, packet)
|
||||
// }
|
||||
// }
|
||||
|
||||
packet := Packet(0, HandshakeMsg, P2PVersion, []byte("peer"), []interface{}{"bbb", "aaa", "ccc"}, 30, []byte("0000000000000000000000000000000000000000000000000000000000000000")) |
||||
conn.In(0, packet) |
||||
time.Sleep(10 * time.Millisecond) |
||||
// packet := Packet(0, HandshakeMsg, P2PVersion, []byte("peer"), []interface{}{"bbb", "aaa", "ccc"}, 30, []byte("0000000000000000000000000000000000000000000000000000000000000000"))
|
||||
// conn.In(0, packet)
|
||||
// time.Sleep(10 * time.Millisecond)
|
||||
|
||||
pro, _ := peer.Messenger().protocols[0].(*BaseProtocol) |
||||
if pro.state != handshakeReceived { |
||||
t.Errorf("handshake not received") |
||||
} |
||||
if peer.Port != 30 { |
||||
t.Errorf("port incorrectly set") |
||||
} |
||||
if peer.Id != "peer" { |
||||
t.Errorf("id incorrectly set") |
||||
} |
||||
if string(peer.Pubkey) != "0000000000000000000000000000000000000000000000000000000000000000" { |
||||
t.Errorf("pubkey incorrectly set") |
||||
} |
||||
fmt.Println(peer.Caps) |
||||
if len(peer.Caps) != 3 || peer.Caps[0] != "aaa" || peer.Caps[1] != "bbb" || peer.Caps[2] != "ccc" { |
||||
t.Errorf("protocols incorrectly set") |
||||
} |
||||
// pro, _ := peer.Messenger().protocols[0].(*BaseProtocol)
|
||||
// if pro.state != handshakeReceived {
|
||||
// t.Errorf("handshake not received")
|
||||
// }
|
||||
// if peer.Port != 30 {
|
||||
// t.Errorf("port incorrectly set")
|
||||
// }
|
||||
// if peer.Id != "peer" {
|
||||
// t.Errorf("id incorrectly set")
|
||||
// }
|
||||
// if string(peer.Pubkey) != "0000000000000000000000000000000000000000000000000000000000000000" {
|
||||
// t.Errorf("pubkey incorrectly set")
|
||||
// }
|
||||
// fmt.Println(peer.Caps)
|
||||
// if len(peer.Caps) != 3 || peer.Caps[0] != "aaa" || peer.Caps[1] != "bbb" || peer.Caps[2] != "ccc" {
|
||||
// t.Errorf("protocols incorrectly set")
|
||||
// }
|
||||
|
||||
msg, _ := NewMsg(3) |
||||
err := peer.Write("aaa", msg) |
||||
if err != nil { |
||||
t.Errorf("expect no error for known protocol: %v", err) |
||||
} else { |
||||
time.Sleep(1 * time.Millisecond) |
||||
if len(conn.Out) != 2 { |
||||
t.Errorf("msg not written") |
||||
} else { |
||||
out := conn.Out[1] |
||||
packet := Packet(16, 3) |
||||
if bytes.Compare(out, packet) != 0 { |
||||
t.Errorf("incorrect packet %v != %v", out, packet) |
||||
} |
||||
} |
||||
} |
||||
// msg := NewMsg(3)
|
||||
// err := peer.Write("aaa", msg)
|
||||
// if err != nil {
|
||||
// t.Errorf("expect no error for known protocol: %v", err)
|
||||
// } else {
|
||||
// time.Sleep(1 * time.Millisecond)
|
||||
// if len(conn.Out) != 2 {
|
||||
// t.Errorf("msg not written")
|
||||
// } else {
|
||||
// out := conn.Out[1]
|
||||
// packet := Packet(16, 3)
|
||||
// if bytes.Compare(out, packet) != 0 {
|
||||
// t.Errorf("incorrect packet %v != %v", out, packet)
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
|
||||
msg, _ = NewMsg(2) |
||||
err = peer.Write("ccc", msg) |
||||
if err != nil { |
||||
t.Errorf("expect no error for known protocol: %v", err) |
||||
} else { |
||||
time.Sleep(1 * time.Millisecond) |
||||
if len(conn.Out) != 3 { |
||||
t.Errorf("msg not written") |
||||
} else { |
||||
out := conn.Out[2] |
||||
packet := Packet(21, 2) |
||||
if bytes.Compare(out, packet) != 0 { |
||||
t.Errorf("incorrect packet %v != %v", out, packet) |
||||
} |
||||
} |
||||
} |
||||
// msg = NewMsg(2)
|
||||
// err = peer.Write("ccc", msg)
|
||||
// if err != nil {
|
||||
// t.Errorf("expect no error for known protocol: %v", err)
|
||||
// } else {
|
||||
// time.Sleep(1 * time.Millisecond)
|
||||
// if len(conn.Out) != 3 {
|
||||
// t.Errorf("msg not written")
|
||||
// } else {
|
||||
// out := conn.Out[2]
|
||||
// packet := Packet(21, 2)
|
||||
// if bytes.Compare(out, packet) != 0 {
|
||||
// t.Errorf("incorrect packet %v != %v", out, packet)
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
|
||||
err = peer.Write("bbb", msg) |
||||
time.Sleep(1 * time.Millisecond) |
||||
if err == nil { |
||||
t.Errorf("expect error for unknown protocol") |
||||
} |
||||
} |
||||
// err = peer.Write("bbb", msg)
|
||||
// time.Sleep(1 * time.Millisecond)
|
||||
// if err == nil {
|
||||
// t.Errorf("expect error for unknown protocol")
|
||||
// }
|
||||
// }
|
||||
|
Loading…
Reference in new issue