Merge branch 'fjl-feature/p2p-protocol-interface' into poc8

pull/206/head
obscuren 10 years ago
commit 195b2d2ebd
  1. 6
      p2p/client_identity.go
  2. 275
      p2p/connection.go
  3. 222
      p2p/connection_test.go
  4. 186
      p2p/message.go
  5. 78
      p2p/message_test.go
  6. 220
      p2p/messenger.go
  7. 147
      p2p/messenger_test.go
  8. 34
      p2p/natpmp.go
  9. 198
      p2p/natupnp.go
  10. 196
      p2p/network.go
  11. 494
      p2p/peer.go
  12. 152
      p2p/peer_error.go
  13. 101
      p2p/peer_error_handler.go
  14. 34
      p2p/peer_error_handler_test.go
  15. 293
      p2p/peer_test.go
  16. 453
      p2p/protocol.go
  17. 723
      p2p/server.go
  18. 309
      p2p/server_test.go
  19. 28
      p2p/testlog_test.go
  20. 40
      p2p/testpoc7.go
  21. 81
      rlp/decode.go
  22. 89
      rlp/decode_test.go

@ -5,10 +5,10 @@ import (
"runtime" "runtime"
) )
// should be used in Peer handleHandshake, incorporate Caps, ProtocolVersion, Pubkey etc. // ClientIdentity represents the identity of a peer.
type ClientIdentity interface { type ClientIdentity interface {
String() string String() string // human readable identity
Pubkey() []byte Pubkey() []byte // 512-bit public key
} }
type SimpleClientIdentity struct { type SimpleClientIdentity struct {

@ -1,275 +0,0 @@
package p2p
import (
"bytes"
// "fmt"
"net"
"time"
"github.com/ethereum/go-ethereum/ethutil"
)
type Connection struct {
conn net.Conn
// conn NetworkConnection
timeout time.Duration
in chan []byte
out chan []byte
err chan *PeerError
closingIn chan chan bool
closingOut chan chan bool
}
// const readBufferLength = 2 //for testing
const readBufferLength = 1440
const partialsQueueSize = 10
const maxPendingQueueSize = 1
const defaultTimeout = 500
var magicToken = []byte{34, 64, 8, 145}
func (self *Connection) Open() {
go self.startRead()
go self.startWrite()
}
func (self *Connection) Close() {
self.closeIn()
self.closeOut()
}
func (self *Connection) closeIn() {
errc := make(chan bool)
self.closingIn <- errc
<-errc
}
func (self *Connection) closeOut() {
errc := make(chan bool)
self.closingOut <- errc
<-errc
}
func NewConnection(conn net.Conn, errchan chan *PeerError) *Connection {
return &Connection{
conn: conn,
timeout: defaultTimeout,
in: make(chan []byte),
out: make(chan []byte),
err: errchan,
closingIn: make(chan chan bool, 1),
closingOut: make(chan chan bool, 1),
}
}
func (self *Connection) Read() <-chan []byte {
return self.in
}
func (self *Connection) Write() chan<- []byte {
return self.out
}
func (self *Connection) Error() <-chan *PeerError {
return self.err
}
func (self *Connection) startRead() {
payloads := make(chan []byte)
done := make(chan *PeerError)
pending := [][]byte{}
var head []byte
var wait time.Duration // initally 0 (no delay)
read := time.After(wait * time.Millisecond)
for {
// if pending empty, nil channel blocks
var in chan []byte
if len(pending) > 0 {
in = self.in // enable send case
head = pending[0]
} else {
in = nil
}
select {
case <-read:
go self.read(payloads, done)
case err := <-done:
if err == nil { // no error but nothing to read
if len(pending) < maxPendingQueueSize {
wait = 100
} else if wait == 0 {
wait = 100
} else {
wait = 2 * wait
}
} else {
self.err <- err // report error
wait = 100
}
read = time.After(wait * time.Millisecond)
case payload := <-payloads:
pending = append(pending, payload)
if len(pending) < maxPendingQueueSize {
wait = 0
} else {
wait = 100
}
read = time.After(wait * time.Millisecond)
case in <- head:
pending = pending[1:]
case errc := <-self.closingIn:
errc <- true
close(self.in)
return
}
}
}
func (self *Connection) startWrite() {
pending := [][]byte{}
done := make(chan *PeerError)
writing := false
for {
if len(pending) > 0 && !writing {
writing = true
go self.write(pending[0], done)
}
select {
case payload := <-self.out:
pending = append(pending, payload)
case err := <-done:
if err == nil {
pending = pending[1:]
writing = false
} else {
self.err <- err // report error
}
case errc := <-self.closingOut:
errc <- true
close(self.out)
return
}
}
}
func pack(payload []byte) (packet []byte) {
length := ethutil.NumberToBytes(uint32(len(payload)), 32)
// return error if too long?
// Write magic token and payload length (first 8 bytes)
packet = append(magicToken, length...)
packet = append(packet, payload...)
return
}
func avoidPanic(done chan *PeerError) {
if rec := recover(); rec != nil {
err := NewPeerError(MiscError, " %v", rec)
logger.Debugln(err)
done <- err
}
}
func (self *Connection) write(payload []byte, done chan *PeerError) {
defer avoidPanic(done)
var err *PeerError
_, ok := self.conn.Write(pack(payload))
if ok != nil {
err = NewPeerError(WriteError, " %v", ok)
logger.Debugln(err)
}
done <- err
}
func (self *Connection) read(payloads chan []byte, done chan *PeerError) {
//defer avoidPanic(done)
partials := make(chan []byte, partialsQueueSize)
errc := make(chan *PeerError)
go self.readPartials(partials, errc)
packet := []byte{}
length := 8
start := true
var err *PeerError
out:
for {
// appends partials read via connection until packet is
// - either parseable (>=8bytes)
// - or complete (payload fully consumed)
for len(packet) < length {
partial, ok := <-partials
if !ok { // partials channel is closed
err = <-errc
if err == nil && len(packet) > 0 {
if start {
err = NewPeerError(PacketTooShort, "%v", packet)
} else {
err = NewPeerError(PayloadTooShort, "%d < %d", len(packet), length)
}
}
break out
}
packet = append(packet, partial...)
}
if start {
// at least 8 bytes read, can validate packet
if bytes.Compare(magicToken, packet[:4]) != 0 {
err = NewPeerError(MagicTokenMismatch, " received %v", packet[:4])
break
}
length = int(ethutil.BytesToNumber(packet[4:8]))
packet = packet[8:]
if length > 0 {
start = false // now consuming payload
} else { //penalize peer but read on
self.err <- NewPeerError(EmptyPayload, "")
length = 8
}
} else {
// packet complete (payload fully consumed)
payloads <- packet[:length]
packet = packet[length:] // resclice packet
start = true
length = 8
}
}
// this stops partials read via the connection, should we?
//if err != nil {
// select {
// case errc <- err
// default:
//}
done <- err
}
func (self *Connection) readPartials(partials chan []byte, errc chan *PeerError) {
defer close(partials)
for {
// Give buffering some time
self.conn.SetReadDeadline(time.Now().Add(self.timeout * time.Millisecond))
buffer := make([]byte, readBufferLength)
// read partial from connection
bytesRead, err := self.conn.Read(buffer)
if err == nil || err.Error() == "EOF" {
if bytesRead > 0 {
partials <- buffer[:bytesRead]
}
if err != nil && err.Error() == "EOF" {
break
}
} else {
// unexpected error, report to errc
err := NewPeerError(ReadError, " %v", err)
logger.Debugln(err)
errc <- err
return // will close partials channel
}
}
close(errc)
}

@ -1,222 +0,0 @@
package p2p
import (
"bytes"
"fmt"
"io"
"net"
"testing"
"time"
)
type TestNetworkConnection struct {
in chan []byte
current []byte
Out [][]byte
addr net.Addr
}
func NewTestNetworkConnection(addr net.Addr) *TestNetworkConnection {
return &TestNetworkConnection{
in: make(chan []byte),
current: []byte{},
Out: [][]byte{},
addr: addr,
}
}
func (self *TestNetworkConnection) In(latency time.Duration, packets ...[]byte) {
time.Sleep(latency)
for _, s := range packets {
self.in <- s
}
}
func (self *TestNetworkConnection) Read(buff []byte) (n int, err error) {
if len(self.current) == 0 {
select {
case self.current = <-self.in:
default:
return 0, io.EOF
}
}
length := len(self.current)
if length > len(buff) {
copy(buff[:], self.current[:len(buff)])
self.current = self.current[len(buff):]
return len(buff), nil
} else {
copy(buff[:length], self.current[:])
self.current = []byte{}
return length, io.EOF
}
}
func (self *TestNetworkConnection) Write(buff []byte) (n int, err error) {
self.Out = append(self.Out, buff)
fmt.Printf("net write %v\n%v\n", len(self.Out), buff)
return len(buff), nil
}
func (self *TestNetworkConnection) Close() (err error) {
return
}
func (self *TestNetworkConnection) LocalAddr() (addr net.Addr) {
return
}
func (self *TestNetworkConnection) RemoteAddr() (addr net.Addr) {
return self.addr
}
func (self *TestNetworkConnection) SetDeadline(t time.Time) (err error) {
return
}
func (self *TestNetworkConnection) SetReadDeadline(t time.Time) (err error) {
return
}
func (self *TestNetworkConnection) SetWriteDeadline(t time.Time) (err error) {
return
}
func setupConnection() (*Connection, *TestNetworkConnection) {
addr := &TestAddr{"test:30303"}
net := NewTestNetworkConnection(addr)
conn := NewConnection(net, NewPeerErrorChannel())
conn.Open()
return conn, net
}
func TestReadingNilPacket(t *testing.T) {
conn, net := setupConnection()
go net.In(0, []byte{})
// time.Sleep(10 * time.Millisecond)
select {
case packet := <-conn.Read():
t.Errorf("read %v", packet)
case err := <-conn.Error():
t.Errorf("incorrect error %v", err)
default:
}
conn.Close()
}
func TestReadingShortPacket(t *testing.T) {
conn, net := setupConnection()
go net.In(0, []byte{0})
select {
case packet := <-conn.Read():
t.Errorf("read %v", packet)
case err := <-conn.Error():
if err.Code != PacketTooShort {
t.Errorf("incorrect error %v, expected %v", err.Code, PacketTooShort)
}
}
conn.Close()
}
func TestReadingInvalidPacket(t *testing.T) {
conn, net := setupConnection()
go net.In(0, []byte{1, 0, 0, 0, 0, 0, 0, 0})
select {
case packet := <-conn.Read():
t.Errorf("read %v", packet)
case err := <-conn.Error():
if err.Code != MagicTokenMismatch {
t.Errorf("incorrect error %v, expected %v", err.Code, MagicTokenMismatch)
}
}
conn.Close()
}
func TestReadingInvalidPayload(t *testing.T) {
conn, net := setupConnection()
go net.In(0, []byte{34, 64, 8, 145, 0, 0, 0, 2, 0})
select {
case packet := <-conn.Read():
t.Errorf("read %v", packet)
case err := <-conn.Error():
if err.Code != PayloadTooShort {
t.Errorf("incorrect error %v, expected %v", err.Code, PayloadTooShort)
}
}
conn.Close()
}
func TestReadingEmptyPayload(t *testing.T) {
conn, net := setupConnection()
go net.In(0, []byte{34, 64, 8, 145, 0, 0, 0, 0})
time.Sleep(10 * time.Millisecond)
select {
case packet := <-conn.Read():
t.Errorf("read %v", packet)
default:
}
select {
case err := <-conn.Error():
code := err.Code
if code != EmptyPayload {
t.Errorf("incorrect error, expected EmptyPayload, got %v", code)
}
default:
t.Errorf("no error, expected EmptyPayload")
}
conn.Close()
}
func TestReadingCompletePacket(t *testing.T) {
conn, net := setupConnection()
go net.In(0, []byte{34, 64, 8, 145, 0, 0, 0, 1, 1})
time.Sleep(10 * time.Millisecond)
select {
case packet := <-conn.Read():
if bytes.Compare(packet, []byte{1}) != 0 {
t.Errorf("incorrect payload read")
}
case err := <-conn.Error():
t.Errorf("incorrect error %v", err)
default:
t.Errorf("nothing read")
}
conn.Close()
}
func TestReadingTwoCompletePackets(t *testing.T) {
conn, net := setupConnection()
go net.In(0, []byte{34, 64, 8, 145, 0, 0, 0, 1, 0, 34, 64, 8, 145, 0, 0, 0, 1, 1})
for i := 0; i < 2; i++ {
time.Sleep(10 * time.Millisecond)
select {
case packet := <-conn.Read():
if bytes.Compare(packet, []byte{byte(i)}) != 0 {
t.Errorf("incorrect payload read")
}
case err := <-conn.Error():
t.Errorf("incorrect error %v", err)
default:
t.Errorf("nothing read")
}
}
conn.Close()
}
func TestWriting(t *testing.T) {
conn, net := setupConnection()
conn.Write() <- []byte{0}
time.Sleep(10 * time.Millisecond)
if len(net.Out) == 0 {
t.Errorf("no output")
} else {
out := net.Out[0]
if bytes.Compare(out, []byte{34, 64, 8, 145, 0, 0, 0, 1, 0}) != 0 {
t.Errorf("incorrect packet %v", out)
}
}
conn.Close()
}
// hello packet with client id ABC: 0x22 40 08 91 00 00 00 08 84 00 00 00 43414243

@ -1,75 +1,155 @@
package p2p package p2p
import ( import (
// "fmt" "bytes"
"encoding/binary"
"io"
"io/ioutil"
"math/big"
"github.com/ethereum/go-ethereum/ethutil" "github.com/ethereum/go-ethereum/ethutil"
"github.com/ethereum/go-ethereum/rlp"
) )
type MsgCode uint8 // Msg defines the structure of a p2p message.
//
// Note that a Msg can only be sent once since the Payload reader is
// consumed during sending. It is not possible to create a Msg and
// send it any number of times. If you want to reuse an encoded
// structure, encode the payload into a byte array and create a
// separate Msg with a bytes.Reader as Payload for each send.
type Msg struct { type Msg struct {
code MsgCode // this is the raw code as per adaptive msg code scheme Code uint64
data *ethutil.Value Size uint32 // size of the paylod
encoded []byte Payload io.Reader
}
// NewMsg creates an RLP-encoded message with the given code.
func NewMsg(code uint64, params ...interface{}) Msg {
buf := new(bytes.Buffer)
for _, p := range params {
buf.Write(ethutil.Encode(p))
}
return Msg{Code: code, Size: uint32(buf.Len()), Payload: buf}
}
func encodePayload(params ...interface{}) []byte {
buf := new(bytes.Buffer)
for _, p := range params {
buf.Write(ethutil.Encode(p))
}
return buf.Bytes()
} }
func (self *Msg) Code() MsgCode { // Decode parse the RLP content of a message into
return self.code // the given value, which must be a pointer.
//
// For the decoding rules, please see package rlp.
func (msg Msg) Decode(val interface{}) error {
s := rlp.NewListStream(msg.Payload, uint64(msg.Size))
return s.Decode(val)
} }
func (self *Msg) Data() *ethutil.Value { // Discard reads any remaining payload data into a black hole.
return self.data func (msg Msg) Discard() error {
_, err := io.Copy(ioutil.Discard, msg.Payload)
return err
} }
func NewMsg(code MsgCode, params ...interface{}) (msg *Msg, err error) { type MsgReader interface {
ReadMsg() (Msg, error)
// // data := [][]interface{}{}
// data := []interface{}{}
// for _, value := range params {
// if encodable, ok := value.(ethutil.RlpEncodeDecode); ok {
// data = append(data, encodable.RlpValue())
// } else if raw, ok := value.([]interface{}); ok {
// data = append(data, raw)
// } else {
// // data = append(data, interface{}(raw))
// err = fmt.Errorf("Unable to encode object of type %T", value)
// return
// }
// }
return &Msg{
code: code,
data: ethutil.NewValue(interface{}(params)),
}, nil
} }
func NewMsgFromBytes(encoded []byte) (msg *Msg, err error) { type MsgWriter interface {
value := ethutil.NewValueFromBytes(encoded) // WriteMsg sends an existing message.
// Type of message // The Payload reader of the message is consumed.
code := value.Get(0).Uint() // Note that messages can be sent only once.
// Actual data WriteMsg(Msg) error
data := value.SliceFrom(1)
// EncodeMsg writes an RLP-encoded message with the given
msg = &Msg{ // code and data elements.
code: MsgCode(code), EncodeMsg(code uint64, data ...interface{}) error
data: data, }
// data: ethutil.NewValue(data),
encoded: encoded, // MsgReadWriter provides reading and writing of encoded messages.
type MsgReadWriter interface {
MsgReader
MsgWriter
}
var magicToken = []byte{34, 64, 8, 145}
func writeMsg(w io.Writer, msg Msg) error {
// TODO: handle case when Size + len(code) + len(listhdr) overflows uint32
code := ethutil.Encode(uint32(msg.Code))
listhdr := makeListHeader(msg.Size + uint32(len(code)))
payloadLen := uint32(len(listhdr)) + uint32(len(code)) + msg.Size
start := make([]byte, 8)
copy(start, magicToken)
binary.BigEndian.PutUint32(start[4:], payloadLen)
for _, b := range [][]byte{start, listhdr, code} {
if _, err := w.Write(b); err != nil {
return err
}
}
_, err := io.CopyN(w, msg.Payload, int64(msg.Size))
return err
}
func makeListHeader(length uint32) []byte {
if length < 56 {
return []byte{byte(length + 0xc0)}
}
enc := big.NewInt(int64(length)).Bytes()
lenb := byte(len(enc)) + 0xf7
return append([]byte{lenb}, enc...)
}
// readMsg reads a message header from r.
// It takes an rlp.ByteReader to ensure that the decoding doesn't buffer.
func readMsg(r rlp.ByteReader) (msg Msg, err error) {
// read magic and payload size
start := make([]byte, 8)
if _, err = io.ReadFull(r, start); err != nil {
return msg, newPeerError(errRead, "%v", err)
}
if !bytes.HasPrefix(start, magicToken) {
return msg, newPeerError(errMagicTokenMismatch, "got %x, want %x", start[:4], magicToken)
}
size := binary.BigEndian.Uint32(start[4:])
// decode start of RLP message to get the message code
posr := &postrack{r, 0}
s := rlp.NewStream(posr)
if _, err := s.List(); err != nil {
return msg, err
} }
return code, err := s.Uint()
if err != nil {
return msg, err
}
payloadsize := size - posr.p
return Msg{code, payloadsize, io.LimitReader(r, int64(payloadsize))}, nil
}
// postrack wraps an rlp.ByteReader with a position counter.
type postrack struct {
r rlp.ByteReader
p uint32
} }
func (self *Msg) Decode(offset MsgCode) { func (r *postrack) Read(buf []byte) (int, error) {
self.code = self.code - offset n, err := r.r.Read(buf)
r.p += uint32(n)
return n, err
} }
// encode takes an offset argument to implement adaptive message coding func (r *postrack) ReadByte() (byte, error) {
// the encoded message is memoized to make msgs relayed to several peers more efficient b, err := r.r.ReadByte()
func (self *Msg) Encode(offset MsgCode) (res []byte) { if err == nil {
if len(self.encoded) == 0 { r.p++
res = ethutil.NewValue(append([]interface{}{byte(self.code + offset)}, self.data.Slice()...)).Encode()
self.encoded = res
} else {
res = self.encoded
} }
return return b, err
} }

@ -1,38 +1,70 @@
package p2p package p2p
import ( import (
"bytes"
"io/ioutil"
"testing" "testing"
"github.com/ethereum/go-ethereum/ethutil"
) )
func TestNewMsg(t *testing.T) { func TestNewMsg(t *testing.T) {
msg, _ := NewMsg(3, 1, "000") msg := NewMsg(3, 1, "000")
if msg.Code() != 3 { if msg.Code != 3 {
t.Errorf("incorrect code %v", msg.Code()) t.Errorf("incorrect code %d, want %d", msg.Code)
} }
data0 := msg.Data().Get(0).Uint() if msg.Size != 5 {
data1 := string(msg.Data().Get(1).Bytes()) t.Errorf("incorrect size %d, want %d", msg.Size, 5)
if data0 != 1 {
t.Errorf("incorrect data %v", data0)
} }
if data1 != "000" { pl, _ := ioutil.ReadAll(msg.Payload)
t.Errorf("incorrect data %v", data1) expect := []byte{0x01, 0x83, 0x30, 0x30, 0x30}
if !bytes.Equal(pl, expect) {
t.Errorf("incorrect payload content, got %x, want %x", pl, expect)
} }
} }
func TestEncodeDecodeMsg(t *testing.T) { func TestEncodeDecodeMsg(t *testing.T) {
msg, _ := NewMsg(3, 1, "000") msg := NewMsg(3, 1, "000")
encoded := msg.Encode(3) buf := new(bytes.Buffer)
msg, _ = NewMsgFromBytes(encoded) if err := writeMsg(buf, msg); err != nil {
msg.Decode(3) t.Fatalf("encodeMsg error: %v", err)
if msg.Code() != 3 { }
t.Errorf("incorrect code %v", msg.Code()) // t.Logf("encoded: %x", buf.Bytes())
}
data0 := msg.Data().Get(0).Uint() decmsg, err := readMsg(buf)
data1 := msg.Data().Get(1).Str() if err != nil {
if data0 != 1 { t.Fatalf("readMsg error: %v", err)
t.Errorf("incorrect data %v", data0) }
} if decmsg.Code != 3 {
if data1 != "000" { t.Errorf("incorrect code %d, want %d", decmsg.Code, 3)
t.Errorf("incorrect data %v", data1) }
if decmsg.Size != 5 {
t.Errorf("incorrect size %d, want %d", decmsg.Size, 5)
}
var data struct {
I int
S string
}
if err := decmsg.Decode(&data); err != nil {
t.Fatalf("Decode error: %v", err)
}
if data.I != 1 {
t.Errorf("incorrect data.I: got %v, expected %d", data.I, 1)
}
if data.S != "000" {
t.Errorf("incorrect data.S: got %q, expected %q", data.S, "000")
}
}
func TestDecodeRealMsg(t *testing.T) {
data := ethutil.Hex2Bytes("2240089100000080f87e8002b5457468657265756d282b2b292f5065657220536572766572204f6e652f76302e372e382f52656c656173652f4c696e75782f672b2bc082765fb84086dd80b7aefd6a6d2e3b93f4f300a86bfb6ef7bdc97cb03f793db6bb")
msg, err := readMsg(bytes.NewReader(data))
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if msg.Code != 0 {
t.Errorf("incorrect code %d, want %d", msg.Code, 0)
} }
} }

@ -1,220 +0,0 @@
package p2p
import (
"fmt"
"sync"
"time"
)
const (
handlerTimeout = 1000
)
type Handlers map[string](func(p *Peer) Protocol)
type Messenger struct {
conn *Connection
peer *Peer
handlers Handlers
protocolLock sync.RWMutex
protocols []Protocol
offsets []MsgCode // offsets for adaptive message idss
protocolTable map[string]int
quit chan chan bool
err chan *PeerError
pulse chan bool
}
func NewMessenger(peer *Peer, conn *Connection, errchan chan *PeerError, handlers Handlers) *Messenger {
baseProtocol := NewBaseProtocol(peer)
return &Messenger{
conn: conn,
peer: peer,
offsets: []MsgCode{baseProtocol.Offset()},
handlers: handlers,
protocols: []Protocol{baseProtocol},
protocolTable: make(map[string]int),
err: errchan,
pulse: make(chan bool, 1),
quit: make(chan chan bool, 1),
}
}
func (self *Messenger) Start() {
self.conn.Open()
go self.messenger()
self.protocolLock.RLock()
defer self.protocolLock.RUnlock()
self.protocols[0].Start()
}
func (self *Messenger) Stop() {
// close pulse to stop ping pong monitoring
close(self.pulse)
self.protocolLock.RLock()
defer self.protocolLock.RUnlock()
for _, protocol := range self.protocols {
protocol.Stop() // could be parallel
}
q := make(chan bool)
self.quit <- q
<-q
self.conn.Close()
}
func (self *Messenger) messenger() {
in := self.conn.Read()
for {
select {
case payload, ok := <-in:
//dispatches message to the protocol asynchronously
if ok {
go self.handle(payload)
} else {
return
}
case q := <-self.quit:
q <- true
return
}
}
}
// handles each message by dispatching to the appropriate protocol
// using adaptive message codes
// this function is started as a separate go routine for each message
// it waits for the protocol response
// then encodes and sends outgoing messages to the connection's write channel
func (self *Messenger) handle(payload []byte) {
// send ping to heartbeat channel signalling time of last message
// select {
// case self.pulse <- true:
// default:
// }
self.pulse <- true
// initialise message from payload
msg, err := NewMsgFromBytes(payload)
if err != nil {
self.err <- NewPeerError(MiscError, " %v", err)
return
}
// retrieves protocol based on message Code
protocol, offset, peerErr := self.getProtocol(msg.Code())
if err != nil {
self.err <- peerErr
return
}
// reset message code based on adaptive offset
msg.Decode(offset)
// dispatches
response := make(chan *Msg)
go protocol.HandleIn(msg, response)
// protocol reponse timeout to prevent leaks
timer := time.After(handlerTimeout * time.Millisecond)
for {
select {
case outgoing, ok := <-response:
// we check if response channel is not closed
if ok {
self.conn.Write() <- outgoing.Encode(offset)
} else {
return
}
case <-timer:
return
}
}
}
// negotiated protocols
// stores offsets needed for adaptive message id scheme
// based on offsets set at handshake
// get the right protocol to handle the message
func (self *Messenger) getProtocol(code MsgCode) (Protocol, MsgCode, *PeerError) {
self.protocolLock.RLock()
defer self.protocolLock.RUnlock()
base := MsgCode(0)
for index, offset := range self.offsets {
if code < offset {
return self.protocols[index], base, nil
}
base = offset
}
return nil, MsgCode(0), NewPeerError(InvalidMsgCode, " %v", code)
}
func (self *Messenger) PingPong(timeout time.Duration, gracePeriod time.Duration, pingCallback func(), timeoutCallback func()) {
fmt.Printf("pingpong keepalive started at %v", time.Now())
timer := time.After(timeout)
pinged := false
for {
select {
case _, ok := <-self.pulse:
if ok {
pinged = false
timer = time.After(timeout)
} else {
// pulse is closed, stop monitoring
return
}
case <-timer:
if pinged {
fmt.Printf("timeout at %v", time.Now())
timeoutCallback()
return
} else {
fmt.Printf("pinged at %v", time.Now())
pingCallback()
timer = time.After(gracePeriod)
pinged = true
}
}
}
}
func (self *Messenger) AddProtocols(protocols []string) {
self.protocolLock.Lock()
defer self.protocolLock.Unlock()
i := len(self.offsets)
offset := self.offsets[i-1]
for _, name := range protocols {
protocolFunc, ok := self.handlers[name]
if ok {
protocol := protocolFunc(self.peer)
self.protocolTable[name] = i
i++
offset += protocol.Offset()
fmt.Println("offset ", name, offset)
self.offsets = append(self.offsets, offset)
self.protocols = append(self.protocols, protocol)
protocol.Start()
} else {
fmt.Println("no ", name)
// protocol not handled
}
}
}
func (self *Messenger) Write(protocol string, msg *Msg) error {
self.protocolLock.RLock()
defer self.protocolLock.RUnlock()
i := 0
offset := MsgCode(0)
if len(protocol) > 0 {
var ok bool
i, ok = self.protocolTable[protocol]
if !ok {
return fmt.Errorf("protocol %v not handled by peer", protocol)
}
offset = self.offsets[i-1]
}
handler := self.protocols[i]
// checking if protocol status/caps allows the message to be sent out
if handler.HandleOut(msg) {
self.conn.Write() <- msg.Encode(offset)
}
return nil
}

@ -1,147 +0,0 @@
package p2p
import (
// "fmt"
"bytes"
"testing"
"time"
"github.com/ethereum/go-ethereum/ethutil"
)
func setupMessenger(handlers Handlers) (*TestNetworkConnection, chan *PeerError, *Messenger) {
errchan := NewPeerErrorChannel()
addr := &TestAddr{"test:30303"}
net := NewTestNetworkConnection(addr)
conn := NewConnection(net, errchan)
mess := NewMessenger(nil, conn, errchan, handlers)
mess.Start()
return net, errchan, mess
}
type TestProtocol struct {
Msgs []*Msg
}
func (self *TestProtocol) Start() {
}
func (self *TestProtocol) Stop() {
}
func (self *TestProtocol) Offset() MsgCode {
return MsgCode(5)
}
func (self *TestProtocol) HandleIn(msg *Msg, response chan *Msg) {
self.Msgs = append(self.Msgs, msg)
close(response)
}
func (self *TestProtocol) HandleOut(msg *Msg) bool {
if msg.Code() > 3 {
return false
} else {
return true
}
}
func (self *TestProtocol) Name() string {
return "a"
}
func Packet(offset MsgCode, code MsgCode, params ...interface{}) []byte {
msg, _ := NewMsg(code, params...)
encoded := msg.Encode(offset)
packet := []byte{34, 64, 8, 145}
packet = append(packet, ethutil.NumberToBytes(uint32(len(encoded)), 32)...)
return append(packet, encoded...)
}
func TestRead(t *testing.T) {
handlers := make(Handlers)
testProtocol := &TestProtocol{Msgs: []*Msg{}}
handlers["a"] = func(p *Peer) Protocol { return testProtocol }
net, _, mess := setupMessenger(handlers)
mess.AddProtocols([]string{"a"})
defer mess.Stop()
wait := 1 * time.Millisecond
packet := Packet(16, 1, uint32(1), "000")
go net.In(0, packet)
time.Sleep(wait)
if len(testProtocol.Msgs) != 1 {
t.Errorf("msg not relayed to correct protocol")
} else {
if testProtocol.Msgs[0].Code() != 1 {
t.Errorf("incorrect msg code relayed to protocol")
}
}
}
func TestWrite(t *testing.T) {
handlers := make(Handlers)
testProtocol := &TestProtocol{Msgs: []*Msg{}}
handlers["a"] = func(p *Peer) Protocol { return testProtocol }
net, _, mess := setupMessenger(handlers)
mess.AddProtocols([]string{"a"})
defer mess.Stop()
wait := 1 * time.Millisecond
msg, _ := NewMsg(3, uint32(1), "000")
err := mess.Write("b", msg)
if err == nil {
t.Errorf("expect error for unknown protocol")
}
err = mess.Write("a", msg)
if err != nil {
t.Errorf("expect no error for known protocol: %v", err)
} else {
time.Sleep(wait)
if len(net.Out) != 1 {
t.Errorf("msg not written")
} else {
out := net.Out[0]
packet := Packet(16, 3, uint32(1), "000")
if bytes.Compare(out, packet) != 0 {
t.Errorf("incorrect packet %v", out)
}
}
}
}
func TestPulse(t *testing.T) {
net, _, mess := setupMessenger(make(Handlers))
defer mess.Stop()
ping := false
timeout := false
pingTimeout := 10 * time.Millisecond
gracePeriod := 200 * time.Millisecond
go mess.PingPong(pingTimeout, gracePeriod, func() { ping = true }, func() { timeout = true })
net.In(0, Packet(0, 1))
if ping {
t.Errorf("ping sent too early")
}
time.Sleep(pingTimeout + 100*time.Millisecond)
if !ping {
t.Errorf("no ping sent after timeout")
}
if timeout {
t.Errorf("timeout too early")
}
ping = false
net.In(0, Packet(0, 1))
time.Sleep(pingTimeout + 100*time.Millisecond)
if !ping {
t.Errorf("no ping sent after timeout")
}
if timeout {
t.Errorf("timeout too early")
}
ping = false
time.Sleep(gracePeriod)
if ping {
t.Errorf("ping called twice")
}
if !timeout {
t.Errorf("no timeout after grace period")
}
}

@ -3,6 +3,7 @@ package p2p
import ( import (
"fmt" "fmt"
"net" "net"
"time"
natpmp "github.com/jackpal/go-nat-pmp" natpmp "github.com/jackpal/go-nat-pmp"
) )
@ -13,38 +14,37 @@ import (
// + Register for changes to the external address. // + Register for changes to the external address.
// + Re-register port mapping when router reboots. // + Re-register port mapping when router reboots.
// + A mechanism for keeping a port mapping registered. // + A mechanism for keeping a port mapping registered.
// + Discover gateway address automatically.
type natPMPClient struct { type natPMPClient struct {
client *natpmp.Client client *natpmp.Client
} }
func NewNatPMP(gateway net.IP) (nat NAT) { // PMP returns a NAT traverser that uses NAT-PMP. The provided gateway
// address should be the IP of your router.
func PMP(gateway net.IP) (nat NAT) {
return &natPMPClient{natpmp.NewClient(gateway)} return &natPMPClient{natpmp.NewClient(gateway)}
} }
func (n *natPMPClient) GetExternalAddress() (addr net.IP, err error) { func (*natPMPClient) String() string {
return "NAT-PMP"
}
func (n *natPMPClient) GetExternalAddress() (net.IP, error) {
response, err := n.client.GetExternalAddress() response, err := n.client.GetExternalAddress()
if err != nil { if err != nil {
return return nil, err
} }
ip := response.ExternalIPAddress return response.ExternalIPAddress[:], nil
addr = net.IPv4(ip[0], ip[1], ip[2], ip[3])
return
} }
func (n *natPMPClient) AddPortMapping(protocol string, externalPort, internalPort int, func (n *natPMPClient) AddPortMapping(protocol string, extport, intport int, name string, lifetime time.Duration) error {
description string, timeout int) (mappedExternalPort int, err error) { if lifetime <= 0 {
if timeout <= 0 { return fmt.Errorf("lifetime must not be <= 0")
err = fmt.Errorf("timeout must not be <= 0")
return
} }
// Note order of port arguments is switched between our AddPortMapping and the client's AddPortMapping. // Note order of port arguments is switched between our AddPortMapping and the client's AddPortMapping.
response, err := n.client.AddPortMapping(protocol, internalPort, externalPort, timeout) _, err := n.client.AddPortMapping(protocol, intport, extport, int(lifetime/time.Second))
if err != nil { return err
return
}
mappedExternalPort = int(response.MappedExternalPort)
return
} }
func (n *natPMPClient) DeletePortMapping(protocol string, externalPort, internalPort int) (err error) { func (n *natPMPClient) DeletePortMapping(protocol string, externalPort, internalPort int) (err error) {

@ -7,6 +7,7 @@ import (
"bytes" "bytes"
"encoding/xml" "encoding/xml"
"errors" "errors"
"fmt"
"net" "net"
"net/http" "net/http"
"os" "os"
@ -15,28 +16,46 @@ import (
"time" "time"
) )
const (
upnpDiscoverAttempts = 3
upnpDiscoverTimeout = 5 * time.Second
)
// UPNP returns a NAT port mapper that uses UPnP. It will attempt to
// discover the address of your router using UDP broadcasts.
func UPNP() NAT {
return &upnpNAT{}
}
type upnpNAT struct { type upnpNAT struct {
serviceURL string serviceURL string
ourIP string ourIP string
} }
func upnpDiscover(attempts int) (nat NAT, err error) { func (n *upnpNAT) String() string {
return "UPNP"
}
func (n *upnpNAT) discover() error {
if n.serviceURL != "" {
// already discovered
return nil
}
ssdp, err := net.ResolveUDPAddr("udp4", "239.255.255.250:1900") ssdp, err := net.ResolveUDPAddr("udp4", "239.255.255.250:1900")
if err != nil { if err != nil {
return return err
} }
// TODO: try on all network interfaces simultaneously.
// Broadcasting on 0.0.0.0 could select a random interface
// to send on (platform specific).
conn, err := net.ListenPacket("udp4", ":0") conn, err := net.ListenPacket("udp4", ":0")
if err != nil { if err != nil {
return return err
}
socket := conn.(*net.UDPConn)
defer socket.Close()
err = socket.SetDeadline(time.Now().Add(10 * time.Second))
if err != nil {
return
} }
defer conn.Close()
conn.SetDeadline(time.Now().Add(10 * time.Second))
st := "ST: urn:schemas-upnp-org:device:InternetGatewayDevice:1\r\n" st := "ST: urn:schemas-upnp-org:device:InternetGatewayDevice:1\r\n"
buf := bytes.NewBufferString( buf := bytes.NewBufferString(
"M-SEARCH * HTTP/1.1\r\n" + "M-SEARCH * HTTP/1.1\r\n" +
@ -46,19 +65,16 @@ func upnpDiscover(attempts int) (nat NAT, err error) {
"MX: 2\r\n\r\n") "MX: 2\r\n\r\n")
message := buf.Bytes() message := buf.Bytes()
answerBytes := make([]byte, 1024) answerBytes := make([]byte, 1024)
for i := 0; i < attempts; i++ { for i := 0; i < upnpDiscoverAttempts; i++ {
_, err = socket.WriteToUDP(message, ssdp) _, err = conn.WriteTo(message, ssdp)
if err != nil { if err != nil {
return return err
} }
var n int nn, _, err := conn.ReadFrom(answerBytes)
n, _, err = socket.ReadFromUDP(answerBytes)
if err != nil { if err != nil {
continue continue
// socket.Close()
// return
} }
answer := string(answerBytes[0:n]) answer := string(answerBytes[0:nn])
if strings.Index(answer, "\r\n"+st) < 0 { if strings.Index(answer, "\r\n"+st) < 0 {
continue continue
} }
@ -79,17 +95,81 @@ func upnpDiscover(attempts int) (nat NAT, err error) {
var serviceURL string var serviceURL string
serviceURL, err = getServiceURL(locURL) serviceURL, err = getServiceURL(locURL)
if err != nil { if err != nil {
return return err
} }
var ourIP string var ourIP string
ourIP, err = getOurIP() ourIP, err = getOurIP()
if err != nil { if err != nil {
return return err
}
n.serviceURL = serviceURL
n.ourIP = ourIP
return nil
}
return errors.New("UPnP port discovery failed.")
}
func (n *upnpNAT) GetExternalAddress() (addr net.IP, err error) {
if err := n.discover(); err != nil {
return nil, err
} }
nat = &upnpNAT{serviceURL: serviceURL, ourIP: ourIP} info, err := n.getStatusInfo()
return net.ParseIP(info.externalIpAddress), err
}
func (n *upnpNAT) AddPortMapping(protocol string, extport, intport int, description string, lifetime time.Duration) error {
if err := n.discover(); err != nil {
return err
}
// A single concatenation would break ARM compilation.
message := "<u:AddPortMapping xmlns:u=\"urn:schemas-upnp-org:service:WANIPConnection:1\">\r\n" +
"<NewRemoteHost></NewRemoteHost><NewExternalPort>" + strconv.Itoa(extport)
message += "</NewExternalPort><NewProtocol>" + protocol + "</NewProtocol>"
message += "<NewInternalPort>" + strconv.Itoa(extport) + "</NewInternalPort>" +
"<NewInternalClient>" + n.ourIP + "</NewInternalClient>" +
"<NewEnabled>1</NewEnabled><NewPortMappingDescription>"
message += description +
"</NewPortMappingDescription><NewLeaseDuration>" + fmt.Sprint(lifetime/time.Second) +
"</NewLeaseDuration></u:AddPortMapping>"
// TODO: check response to see if the port was forwarded
_, err := soapRequest(n.serviceURL, "AddPortMapping", message)
return err
}
func (n *upnpNAT) DeletePortMapping(protocol string, externalPort, internalPort int) error {
if err := n.discover(); err != nil {
return err
}
message := "<u:DeletePortMapping xmlns:u=\"urn:schemas-upnp-org:service:WANIPConnection:1\">\r\n" +
"<NewRemoteHost></NewRemoteHost><NewExternalPort>" + strconv.Itoa(externalPort) +
"</NewExternalPort><NewProtocol>" + protocol + "</NewProtocol>" +
"</u:DeletePortMapping>"
// TODO: check response to see if the port was deleted
_, err := soapRequest(n.serviceURL, "DeletePortMapping", message)
return err
}
type statusInfo struct {
externalIpAddress string
}
func (n *upnpNAT) getStatusInfo() (info statusInfo, err error) {
message := "<u:GetStatusInfo xmlns:u=\"urn:schemas-upnp-org:service:WANIPConnection:1\">\r\n" +
"</u:GetStatusInfo>"
var response *http.Response
response, err = soapRequest(n.serviceURL, "GetStatusInfo", message)
if err != nil {
return return
} }
err = errors.New("UPnP port discovery failed.")
// TODO: Write a soap reply parser. It has to eat the Body and envelope tags...
response.Body.Close()
return return
} }
@ -259,77 +339,3 @@ func soapRequest(url, function, message string) (r *http.Response, err error) {
} }
return return
} }
type statusInfo struct {
externalIpAddress string
}
func (n *upnpNAT) getStatusInfo() (info statusInfo, err error) {
message := "<u:GetStatusInfo xmlns:u=\"urn:schemas-upnp-org:service:WANIPConnection:1\">\r\n" +
"</u:GetStatusInfo>"
var response *http.Response
response, err = soapRequest(n.serviceURL, "GetStatusInfo", message)
if err != nil {
return
}
// TODO: Write a soap reply parser. It has to eat the Body and envelope tags...
response.Body.Close()
return
}
func (n *upnpNAT) GetExternalAddress() (addr net.IP, err error) {
info, err := n.getStatusInfo()
if err != nil {
return
}
addr = net.ParseIP(info.externalIpAddress)
return
}
func (n *upnpNAT) AddPortMapping(protocol string, externalPort, internalPort int, description string, timeout int) (mappedExternalPort int, err error) {
// A single concatenation would break ARM compilation.
message := "<u:AddPortMapping xmlns:u=\"urn:schemas-upnp-org:service:WANIPConnection:1\">\r\n" +
"<NewRemoteHost></NewRemoteHost><NewExternalPort>" + strconv.Itoa(externalPort)
message += "</NewExternalPort><NewProtocol>" + protocol + "</NewProtocol>"
message += "<NewInternalPort>" + strconv.Itoa(internalPort) + "</NewInternalPort>" +
"<NewInternalClient>" + n.ourIP + "</NewInternalClient>" +
"<NewEnabled>1</NewEnabled><NewPortMappingDescription>"
message += description +
"</NewPortMappingDescription><NewLeaseDuration>" + strconv.Itoa(timeout) +
"</NewLeaseDuration></u:AddPortMapping>"
var response *http.Response
response, err = soapRequest(n.serviceURL, "AddPortMapping", message)
if err != nil {
return
}
// TODO: check response to see if the port was forwarded
// log.Println(message, response)
mappedExternalPort = externalPort
_ = response
return
}
func (n *upnpNAT) DeletePortMapping(protocol string, externalPort, internalPort int) (err error) {
message := "<u:DeletePortMapping xmlns:u=\"urn:schemas-upnp-org:service:WANIPConnection:1\">\r\n" +
"<NewRemoteHost></NewRemoteHost><NewExternalPort>" + strconv.Itoa(externalPort) +
"</NewExternalPort><NewProtocol>" + protocol + "</NewProtocol>" +
"</u:DeletePortMapping>"
var response *http.Response
response, err = soapRequest(n.serviceURL, "DeletePortMapping", message)
if err != nil {
return
}
// TODO: check response to see if the port was deleted
// log.Println(message, response)
_ = response
return
}

@ -1,196 +0,0 @@
package p2p
import (
"fmt"
"math/rand"
"net"
"strconv"
"time"
)
const (
DialerTimeout = 180 //seconds
KeepAlivePeriod = 60 //minutes
portMappingUpdateInterval = 900 // seconds = 15 mins
upnpDiscoverAttempts = 3
)
// Dialer is not an interface in net, so we define one
// *net.Dialer conforms to this
type Dialer interface {
Dial(network, address string) (net.Conn, error)
}
type Network interface {
Start() error
Listener(net.Addr) (net.Listener, error)
Dialer(net.Addr) (Dialer, error)
NewAddr(string, int) (addr net.Addr, err error)
ParseAddr(string) (addr net.Addr, err error)
}
type NAT interface {
GetExternalAddress() (addr net.IP, err error)
AddPortMapping(protocol string, externalPort, internalPort int, description string, timeout int) (mappedExternalPort int, err error)
DeletePortMapping(protocol string, externalPort, internalPort int) (err error)
}
type TCPNetwork struct {
nat NAT
natType NATType
quit chan chan bool
ports chan string
}
type NATType int
const (
NONE = iota
UPNP
PMP
)
const (
portMappingTimeout = 1200 // 20 mins
)
func NewTCPNetwork(natType NATType) (net *TCPNetwork) {
return &TCPNetwork{
natType: natType,
ports: make(chan string),
}
}
func (self *TCPNetwork) Dialer(addr net.Addr) (Dialer, error) {
return &net.Dialer{
Timeout: DialerTimeout * time.Second,
// KeepAlive: KeepAlivePeriod * time.Minute,
LocalAddr: addr,
}, nil
}
func (self *TCPNetwork) Listener(addr net.Addr) (net.Listener, error) {
if self.natType == UPNP {
_, port, _ := net.SplitHostPort(addr.String())
if self.quit == nil {
self.quit = make(chan chan bool)
go self.updatePortMappings()
}
self.ports <- port
}
return net.Listen(addr.Network(), addr.String())
}
func (self *TCPNetwork) Start() (err error) {
switch self.natType {
case NONE:
case UPNP:
nat, uerr := upnpDiscover(upnpDiscoverAttempts)
if uerr != nil {
err = fmt.Errorf("UPNP failed: ", uerr)
} else {
self.nat = nat
}
case PMP:
err = fmt.Errorf("PMP not implemented")
default:
err = fmt.Errorf("Invalid NAT type: %v", self.natType)
}
return
}
func (self *TCPNetwork) Stop() {
q := make(chan bool)
self.quit <- q
<-q
}
func (self *TCPNetwork) addPortMapping(lport int) (err error) {
_, err = self.nat.AddPortMapping("TCP", lport, lport, "p2p listen port", portMappingTimeout)
if err != nil {
logger.Errorf("unable to add port mapping on %v: %v", lport, err)
} else {
logger.Debugf("succesfully added port mapping on %v", lport)
}
return
}
func (self *TCPNetwork) updatePortMappings() {
timer := time.NewTimer(portMappingUpdateInterval * time.Second)
lports := []int{}
out:
for {
select {
case port := <-self.ports:
int64lport, _ := strconv.ParseInt(port, 10, 16)
lport := int(int64lport)
if err := self.addPortMapping(lport); err != nil {
lports = append(lports, lport)
}
case <-timer.C:
for lport := range lports {
if err := self.addPortMapping(lport); err != nil {
}
}
case errc := <-self.quit:
errc <- true
break out
}
}
timer.Stop()
for lport := range lports {
if err := self.nat.DeletePortMapping("TCP", lport, lport); err != nil {
logger.Debugf("unable to remove port mapping on %v: %v", lport, err)
} else {
logger.Debugf("succesfully removed port mapping on %v", lport)
}
}
}
func (self *TCPNetwork) NewAddr(host string, port int) (net.Addr, error) {
ip, err := self.lookupIP(host)
if err == nil {
return &net.TCPAddr{
IP: ip,
Port: port,
}, nil
}
return nil, err
}
func (self *TCPNetwork) ParseAddr(address string) (net.Addr, error) {
host, port, err := net.SplitHostPort(address)
if err == nil {
iport, _ := strconv.Atoi(port)
addr, e := self.NewAddr(host, iport)
return addr, e
}
return nil, err
}
func (*TCPNetwork) lookupIP(host string) (ip net.IP, err error) {
if ip = net.ParseIP(host); ip != nil {
return
}
var ips []net.IP
ips, err = net.LookupIP(host)
if err != nil {
logger.Warnln(err)
return
}
if len(ips) == 0 {
err = fmt.Errorf("No IP addresses available for %v", host)
logger.Warnln(err)
return
}
if len(ips) > 1 {
// Pick a random IP address, simulating round-robin DNS.
rand.Seed(time.Now().UTC().UnixNano())
ip = ips[rand.Intn(len(ips))]
} else {
ip = ips[0]
}
return
}

@ -1,83 +1,455 @@
package p2p package p2p
import ( import (
"bufio"
"bytes"
"fmt" "fmt"
"io"
"io/ioutil"
"net" "net"
"strconv" "sort"
"sync"
"time"
"github.com/ethereum/go-ethereum/event"
"github.com/ethereum/go-ethereum/logger"
) )
// peerAddr is the structure of a peer list element.
// It is also a valid net.Addr.
type peerAddr struct {
IP net.IP
Port uint64
Pubkey []byte // optional
}
func newPeerAddr(addr net.Addr, pubkey []byte) *peerAddr {
n := addr.Network()
if n != "tcp" && n != "tcp4" && n != "tcp6" {
// for testing with non-TCP
return &peerAddr{net.ParseIP("127.0.0.1"), 30303, pubkey}
}
ta := addr.(*net.TCPAddr)
return &peerAddr{ta.IP, uint64(ta.Port), pubkey}
}
func (d peerAddr) Network() string {
if d.IP.To4() != nil {
return "tcp4"
} else {
return "tcp6"
}
}
func (d peerAddr) String() string {
return fmt.Sprintf("%v:%d", d.IP, d.Port)
}
func (d peerAddr) RlpData() interface{} {
return []interface{}{d.IP, d.Port, d.Pubkey}
}
// Peer represents a remote peer.
type Peer struct { type Peer struct {
// quit chan chan bool // Peers have all the log methods.
Inbound bool // inbound (via listener) or outbound (via dialout) // Use them to display messages related to the peer.
Address net.Addr *logger.Logger
Host []byte
Port uint16 infolock sync.Mutex
Pubkey []byte identity ClientIdentity
Id string caps []Cap
Caps []string listenAddr *peerAddr // what remote peer is listening on
peerErrorChan chan *PeerError dialAddr *peerAddr // non-nil if dialing
messenger *Messenger
peerErrorHandler *PeerErrorHandler // The mutex protects the connection
server *Server // so only one protocol can write at a time.
} writeMu sync.Mutex
conn net.Conn
func (self *Peer) Messenger() *Messenger { bufconn *bufio.ReadWriter
return self.messenger
} // These fields maintain the running protocols.
protocols []Protocol
func (self *Peer) PeerErrorChan() chan *PeerError { runBaseProtocol bool // for testing
return self.peerErrorChan
} runlock sync.RWMutex // protects running
running map[string]*proto
func (self *Peer) Server() *Server {
return self.server protoWG sync.WaitGroup
} protoErr chan error
closed chan struct{}
func NewPeer(conn net.Conn, address net.Addr, inbound bool, server *Server) *Peer { disc chan DiscReason
peerErrorChan := NewPeerErrorChannel()
host, port, _ := net.SplitHostPort(address.String()) activity event.TypeMux // for activity events
intport, _ := strconv.Atoi(port)
peer := &Peer{ slot int // index into Server peer list
Inbound: inbound,
Address: address, // These fields are kept so base protocol can access them.
Port: uint16(intport), // TODO: this should be one or more interfaces
Host: net.ParseIP(host), ourID ClientIdentity // client id of the Server
peerErrorChan: peerErrorChan, ourListenAddr *peerAddr // listen addr of Server, nil if not listening
server: server, newPeerAddr chan<- *peerAddr // tell server about received peers
} otherPeers func() []*Peer // should return the list of all peers
connection := NewConnection(conn, peerErrorChan) pubkeyHook func(*peerAddr) error // called at end of handshake to validate pubkey
peer.messenger = NewMessenger(peer, connection, peerErrorChan, server.Handlers()) }
peer.peerErrorHandler = NewPeerErrorHandler(address, server.PeerDisconnect(), peerErrorChan, server.Blacklist())
// NewPeer returns a peer for testing purposes.
func NewPeer(id ClientIdentity, caps []Cap) *Peer {
conn, _ := net.Pipe()
peer := newPeer(conn, nil, nil)
peer.setHandshakeInfo(id, nil, caps)
close(peer.closed)
return peer return peer
} }
func (self *Peer) String() string { func newServerPeer(server *Server, conn net.Conn, dialAddr *peerAddr) *Peer {
var kind string p := newPeer(conn, server.Protocols, dialAddr)
if self.Inbound { p.ourID = server.Identity
kind = "inbound" p.newPeerAddr = server.peerConnect
} else { p.otherPeers = server.Peers
p.pubkeyHook = server.verifyPeer
p.runBaseProtocol = true
// laddr can be updated concurrently by NAT traversal.
// newServerPeer must be called with the server lock held.
if server.laddr != nil {
p.ourListenAddr = newPeerAddr(server.laddr, server.Identity.Pubkey())
}
return p
}
func newPeer(conn net.Conn, protocols []Protocol, dialAddr *peerAddr) *Peer {
p := &Peer{
Logger: logger.NewLogger("P2P " + conn.RemoteAddr().String()),
conn: conn,
dialAddr: dialAddr,
bufconn: bufio.NewReadWriter(bufio.NewReader(conn), bufio.NewWriter(conn)),
protocols: protocols,
running: make(map[string]*proto),
disc: make(chan DiscReason),
protoErr: make(chan error),
closed: make(chan struct{}),
}
return p
}
// Identity returns the client identity of the remote peer. The
// identity can be nil if the peer has not yet completed the
// handshake.
func (p *Peer) Identity() ClientIdentity {
p.infolock.Lock()
defer p.infolock.Unlock()
return p.identity
}
// Caps returns the capabilities (supported subprotocols) of the remote peer.
func (p *Peer) Caps() []Cap {
p.infolock.Lock()
defer p.infolock.Unlock()
return p.caps
}
func (p *Peer) setHandshakeInfo(id ClientIdentity, laddr *peerAddr, caps []Cap) {
p.infolock.Lock()
p.identity = id
p.listenAddr = laddr
p.caps = caps
p.infolock.Unlock()
}
// RemoteAddr returns the remote address of the network connection.
func (p *Peer) RemoteAddr() net.Addr {
return p.conn.RemoteAddr()
}
// LocalAddr returns the local address of the network connection.
func (p *Peer) LocalAddr() net.Addr {
return p.conn.LocalAddr()
}
// Disconnect terminates the peer connection with the given reason.
// It returns immediately and does not wait until the connection is closed.
func (p *Peer) Disconnect(reason DiscReason) {
select {
case p.disc <- reason:
case <-p.closed:
}
}
// String implements fmt.Stringer.
func (p *Peer) String() string {
kind := "inbound"
p.infolock.Lock()
if p.dialAddr != nil {
kind = "outbound" kind = "outbound"
} }
return fmt.Sprintf("%v:%v (%s) v%v %v", self.Host, self.Port, kind, self.Id, self.Caps) p.infolock.Unlock()
return fmt.Sprintf("Peer(%p %v %s)", p, p.conn.RemoteAddr(), kind)
}
const (
// maximum amount of time allowed for reading a message
msgReadTimeout = 5 * time.Second
// maximum amount of time allowed for writing a message
msgWriteTimeout = 5 * time.Second
// messages smaller than this many bytes will be read at
// once before passing them to a protocol.
wholePayloadSize = 64 * 1024
)
var (
inactivityTimeout = 2 * time.Second
disconnectGracePeriod = 2 * time.Second
)
func (p *Peer) loop() (reason DiscReason, err error) {
defer p.activity.Stop()
defer p.closeProtocols()
defer close(p.closed)
defer p.conn.Close()
// read loop
readMsg := make(chan Msg)
readErr := make(chan error)
readNext := make(chan bool, 1)
protoDone := make(chan struct{}, 1)
go p.readLoop(readMsg, readErr, readNext)
readNext <- true
if p.runBaseProtocol {
p.startBaseProtocol()
}
loop:
for {
select {
case msg := <-readMsg:
// a new message has arrived.
var wait bool
if wait, err = p.dispatch(msg, protoDone); err != nil {
p.Errorf("msg dispatch error: %v\n", err)
reason = discReasonForError(err)
break loop
}
if !wait {
// Msg has already been read completely, continue with next message.
readNext <- true
}
p.activity.Post(time.Now())
case <-protoDone:
// protocol has consumed the message payload,
// we can continue reading from the socket.
readNext <- true
case err := <-readErr:
// read failed. there is no need to run the
// polite disconnect sequence because the connection
// is probably dead anyway.
// TODO: handle write errors as well
return DiscNetworkError, err
case err = <-p.protoErr:
reason = discReasonForError(err)
break loop
case reason = <-p.disc:
break loop
}
}
// wait for read loop to return.
close(readNext)
<-readErr
// tell the remote end to disconnect
done := make(chan struct{})
go func() {
p.conn.SetDeadline(time.Now().Add(disconnectGracePeriod))
p.writeMsg(NewMsg(discMsg, reason), disconnectGracePeriod)
io.Copy(ioutil.Discard, p.conn)
close(done)
}()
select {
case <-done:
case <-time.After(disconnectGracePeriod):
}
return reason, err
}
func (p *Peer) readLoop(msgc chan<- Msg, errc chan<- error, unblock <-chan bool) {
for _ = range unblock {
p.conn.SetReadDeadline(time.Now().Add(msgReadTimeout))
if msg, err := readMsg(p.bufconn); err != nil {
errc <- err
} else {
msgc <- msg
}
}
close(errc)
}
func (p *Peer) dispatch(msg Msg, protoDone chan struct{}) (wait bool, err error) {
proto, err := p.getProto(msg.Code)
if err != nil {
return false, err
}
if msg.Size <= wholePayloadSize {
// optimization: msg is small enough, read all
// of it and move on to the next message
buf, err := ioutil.ReadAll(msg.Payload)
if err != nil {
return false, err
}
msg.Payload = bytes.NewReader(buf)
proto.in <- msg
} else {
wait = true
pr := &eofSignal{msg.Payload, protoDone}
msg.Payload = pr
proto.in <- msg
}
return wait, nil
}
func (p *Peer) startBaseProtocol() {
p.runlock.Lock()
defer p.runlock.Unlock()
p.running[""] = p.startProto(0, Protocol{
Length: baseProtocolLength,
Run: runBaseProtocol,
})
}
// startProtocols starts matching named subprotocols.
func (p *Peer) startSubprotocols(caps []Cap) {
sort.Sort(capsByName(caps))
p.runlock.Lock()
defer p.runlock.Unlock()
offset := baseProtocolLength
outer:
for _, cap := range caps {
for _, proto := range p.protocols {
if proto.Name == cap.Name &&
proto.Version == cap.Version &&
p.running[cap.Name] == nil {
p.running[cap.Name] = p.startProto(offset, proto)
offset += proto.Length
continue outer
}
}
}
} }
func (self *Peer) Write(protocol string, msg *Msg) error { func (p *Peer) startProto(offset uint64, impl Protocol) *proto {
return self.messenger.Write(protocol, msg) rw := &proto{
in: make(chan Msg),
offset: offset,
maxcode: impl.Length,
peer: p,
}
p.protoWG.Add(1)
go func() {
err := impl.Run(p, rw)
if err == nil {
p.Infof("protocol %q returned", impl.Name)
err = newPeerError(errMisc, "protocol returned")
} else {
p.Errorf("protocol %q error: %v\n", impl.Name, err)
}
select {
case p.protoErr <- err:
case <-p.closed:
}
p.protoWG.Done()
}()
return rw
}
// getProto finds the protocol responsible for handling
// the given message code.
func (p *Peer) getProto(code uint64) (*proto, error) {
p.runlock.RLock()
defer p.runlock.RUnlock()
for _, proto := range p.running {
if code >= proto.offset && code < proto.offset+proto.maxcode {
return proto, nil
}
}
return nil, newPeerError(errInvalidMsgCode, "%d", code)
} }
func (self *Peer) Start() { func (p *Peer) closeProtocols() {
self.peerErrorHandler.Start() p.runlock.RLock()
self.messenger.Start() for _, p := range p.running {
close(p.in)
}
p.runlock.RUnlock()
p.protoWG.Wait()
} }
func (self *Peer) Stop() { // writeProtoMsg sends the given message on behalf of the given named protocol.
self.peerErrorHandler.Stop() func (p *Peer) writeProtoMsg(protoName string, msg Msg) error {
self.messenger.Stop() p.runlock.RLock()
// q := make(chan bool) proto, ok := p.running[protoName]
// self.quit <- q p.runlock.RUnlock()
// <-q if !ok {
return fmt.Errorf("protocol %s not handled by peer", protoName)
}
if msg.Code >= proto.maxcode {
return newPeerError(errInvalidMsgCode, "code %x is out of range for protocol %q", msg.Code, protoName)
}
msg.Code += proto.offset
return p.writeMsg(msg, msgWriteTimeout)
}
// writeMsg writes a message to the connection.
func (p *Peer) writeMsg(msg Msg, timeout time.Duration) error {
p.writeMu.Lock()
defer p.writeMu.Unlock()
p.conn.SetWriteDeadline(time.Now().Add(timeout))
if err := writeMsg(p.bufconn, msg); err != nil {
return newPeerError(errWrite, "%v", err)
}
return p.bufconn.Flush()
} }
func (p *Peer) Encode() []interface{} { type proto struct {
return []interface{}{p.Host, p.Port, p.Pubkey} name string
in chan Msg
maxcode, offset uint64
peer *Peer
}
func (rw *proto) WriteMsg(msg Msg) error {
if msg.Code >= rw.maxcode {
return newPeerError(errInvalidMsgCode, "not handled")
}
msg.Code += rw.offset
return rw.peer.writeMsg(msg, msgWriteTimeout)
}
func (rw *proto) EncodeMsg(code uint64, data ...interface{}) error {
return rw.WriteMsg(NewMsg(code, data))
}
func (rw *proto) ReadMsg() (Msg, error) {
msg, ok := <-rw.in
if !ok {
return msg, io.EOF
}
msg.Code -= rw.offset
return msg, nil
}
// eofSignal wraps a reader with eof signaling.
// the eof channel is closed when the wrapped reader
// reaches EOF.
type eofSignal struct {
wrapped io.Reader
eof chan<- struct{}
}
func (r *eofSignal) Read(buf []byte) (int, error) {
n, err := r.wrapped.Read(buf)
if err != nil {
r.eof <- struct{}{} // tell Peer that msg has been consumed
}
return n, err
} }

@ -4,73 +4,121 @@ import (
"fmt" "fmt"
) )
type ErrorCode int
const errorChanCapacity = 10
const ( const (
PacketTooShort = iota errMagicTokenMismatch = iota
PayloadTooShort errRead
MagicTokenMismatch errWrite
EmptyPayload errMisc
ReadError errInvalidMsgCode
WriteError errInvalidMsg
MiscError errP2PVersionMismatch
InvalidMsgCode errPubkeyMissing
InvalidMsg errPubkeyInvalid
P2PVersionMismatch errPubkeyForbidden
PubkeyMissing errProtocolBreach
PubkeyInvalid errPingTimeout
PubkeyForbidden errInvalidNetworkId
ProtocolBreach errInvalidProtocolVersion
PortMismatch
PingTimeout
InvalidGenesis
InvalidNetworkId
InvalidProtocolVersion
) )
var errorToString = map[ErrorCode]string{ var errorToString = map[int]string{
PacketTooShort: "Packet too short", errMagicTokenMismatch: "Magic token mismatch",
PayloadTooShort: "Payload too short", errRead: "Read error",
MagicTokenMismatch: "Magic token mismatch", errWrite: "Write error",
EmptyPayload: "Empty payload", errMisc: "Misc error",
ReadError: "Read error", errInvalidMsgCode: "Invalid message code",
WriteError: "Write error", errInvalidMsg: "Invalid message",
MiscError: "Misc error", errP2PVersionMismatch: "P2P Version Mismatch",
InvalidMsgCode: "Invalid message code", errPubkeyMissing: "Public key missing",
InvalidMsg: "Invalid message", errPubkeyInvalid: "Public key invalid",
P2PVersionMismatch: "P2P Version Mismatch", errPubkeyForbidden: "Public key forbidden",
PubkeyMissing: "Public key missing", errProtocolBreach: "Protocol Breach",
PubkeyInvalid: "Public key invalid", errPingTimeout: "Ping timeout",
PubkeyForbidden: "Public key forbidden", errInvalidNetworkId: "Invalid network id",
ProtocolBreach: "Protocol Breach", errInvalidProtocolVersion: "Invalid protocol version",
PortMismatch: "Port mismatch",
PingTimeout: "Ping timeout",
InvalidGenesis: "Invalid genesis block",
InvalidNetworkId: "Invalid network id",
InvalidProtocolVersion: "Invalid protocol version",
} }
type PeerError struct { type peerError struct {
Code ErrorCode Code int
message string message string
} }
func NewPeerError(code ErrorCode, format string, v ...interface{}) *PeerError { func newPeerError(code int, format string, v ...interface{}) *peerError {
desc, ok := errorToString[code] desc, ok := errorToString[code]
if !ok { if !ok {
panic("invalid error code") panic("invalid error code")
} }
format = desc + ": " + format err := &peerError{code, desc}
message := fmt.Sprintf(format, v...) if format != "" {
return &PeerError{code, message} err.message += ": " + fmt.Sprintf(format, v...)
}
return err
} }
func (self *PeerError) Error() string { func (self *peerError) Error() string {
return self.message return self.message
} }
func NewPeerErrorChannel() chan *PeerError { type DiscReason byte
return make(chan *PeerError, errorChanCapacity)
const (
DiscRequested DiscReason = 0x00
DiscNetworkError = 0x01
DiscProtocolError = 0x02
DiscUselessPeer = 0x03
DiscTooManyPeers = 0x04
DiscAlreadyConnected = 0x05
DiscIncompatibleVersion = 0x06
DiscInvalidIdentity = 0x07
DiscQuitting = 0x08
DiscUnexpectedIdentity = 0x09
DiscSelf = 0x0a
DiscReadTimeout = 0x0b
DiscSubprotocolError = 0x10
)
var discReasonToString = [DiscSubprotocolError + 1]string{
DiscRequested: "Disconnect requested",
DiscNetworkError: "Network error",
DiscProtocolError: "Breach of protocol",
DiscUselessPeer: "Useless peer",
DiscTooManyPeers: "Too many peers",
DiscAlreadyConnected: "Already connected",
DiscIncompatibleVersion: "Incompatible P2P protocol version",
DiscInvalidIdentity: "Invalid node identity",
DiscQuitting: "Client quitting",
DiscUnexpectedIdentity: "Unexpected identity",
DiscSelf: "Connected to self",
DiscReadTimeout: "Read timeout",
DiscSubprotocolError: "Subprotocol error",
}
func (d DiscReason) String() string {
if len(discReasonToString) < int(d) {
return fmt.Sprintf("Unknown Reason(%d)", d)
}
return discReasonToString[d]
}
func discReasonForError(err error) DiscReason {
peerError, ok := err.(*peerError)
if !ok {
return DiscSubprotocolError
}
switch peerError.Code {
case errP2PVersionMismatch:
return DiscIncompatibleVersion
case errPubkeyMissing, errPubkeyInvalid:
return DiscInvalidIdentity
case errPubkeyForbidden:
return DiscUselessPeer
case errInvalidMsgCode, errMagicTokenMismatch, errProtocolBreach:
return DiscProtocolError
case errPingTimeout:
return DiscReadTimeout
case errRead, errWrite, errMisc:
return DiscNetworkError
default:
return DiscSubprotocolError
}
} }

@ -1,101 +0,0 @@
package p2p
import (
"net"
)
const (
severityThreshold = 10
)
type DisconnectRequest struct {
addr net.Addr
reason DiscReason
}
type PeerErrorHandler struct {
quit chan chan bool
address net.Addr
peerDisconnect chan DisconnectRequest
severity int
peerErrorChan chan *PeerError
blacklist Blacklist
}
func NewPeerErrorHandler(address net.Addr, peerDisconnect chan DisconnectRequest, peerErrorChan chan *PeerError, blacklist Blacklist) *PeerErrorHandler {
return &PeerErrorHandler{
quit: make(chan chan bool),
address: address,
peerDisconnect: peerDisconnect,
peerErrorChan: peerErrorChan,
blacklist: blacklist,
}
}
func (self *PeerErrorHandler) Start() {
go self.listen()
}
func (self *PeerErrorHandler) Stop() {
q := make(chan bool)
self.quit <- q
<-q
}
func (self *PeerErrorHandler) listen() {
for {
select {
case peerError, ok := <-self.peerErrorChan:
if ok {
logger.Debugf("error %v\n", peerError)
go self.handle(peerError)
} else {
return
}
case q := <-self.quit:
q <- true
return
}
}
}
func (self *PeerErrorHandler) handle(peerError *PeerError) {
reason := DiscReason(' ')
switch peerError.Code {
case P2PVersionMismatch:
reason = DiscIncompatibleVersion
case PubkeyMissing, PubkeyInvalid:
reason = DiscInvalidIdentity
case PubkeyForbidden:
reason = DiscUselessPeer
case InvalidMsgCode, PacketTooShort, PayloadTooShort, MagicTokenMismatch, EmptyPayload, ProtocolBreach:
reason = DiscProtocolError
case PingTimeout:
reason = DiscReadTimeout
case WriteError, MiscError:
reason = DiscNetworkError
case InvalidGenesis, InvalidNetworkId, InvalidProtocolVersion:
reason = DiscSubprotocolError
default:
self.severity += self.getSeverity(peerError)
}
if self.severity >= severityThreshold {
reason = DiscSubprotocolError
}
if reason != DiscReason(' ') {
self.peerDisconnect <- DisconnectRequest{
addr: self.address,
reason: reason,
}
}
}
func (self *PeerErrorHandler) getSeverity(peerError *PeerError) int {
switch peerError.Code {
case ReadError:
return 4 //tolerate 3 :)
default:
return 1
}
}

@ -1,34 +0,0 @@
package p2p
import (
// "fmt"
"net"
"testing"
"time"
)
func TestPeerErrorHandler(t *testing.T) {
address := &net.TCPAddr{IP: net.IP([]byte{1, 2, 3, 4}), Port: 30303}
peerDisconnect := make(chan DisconnectRequest)
peerErrorChan := NewPeerErrorChannel()
peh := NewPeerErrorHandler(address, peerDisconnect, peerErrorChan, NewBlacklist())
peh.Start()
defer peh.Stop()
for i := 0; i < 11; i++ {
select {
case <-peerDisconnect:
t.Errorf("expected no disconnect request")
default:
}
peerErrorChan <- NewPeerError(MiscError, "")
}
time.Sleep(1 * time.Millisecond)
select {
case request := <-peerDisconnect:
if request.addr.String() != address.String() {
t.Errorf("incorrect address %v != %v", request.addr, address)
}
default:
t.Errorf("expected disconnect request")
}
}

@ -1,96 +1,239 @@
package p2p package p2p
import ( import (
"bufio"
"bytes" "bytes"
"fmt" "encoding/hex"
// "net" "io/ioutil"
"net"
"reflect"
"testing" "testing"
"time" "time"
) )
func TestPeer(t *testing.T) { var discard = Protocol{
handlers := make(Handlers) Name: "discard",
testProtocol := &TestProtocol{Msgs: []*Msg{}} Length: 1,
handlers["aaa"] = func(p *Peer) Protocol { return testProtocol } Run: func(p *Peer, rw MsgReadWriter) error {
handlers["ccc"] = func(p *Peer) Protocol { return testProtocol } for {
addr := &TestAddr{"test:30"} msg, err := rw.ReadMsg()
conn := NewTestNetworkConnection(addr)
_, server := SetupTestServer(handlers)
server.Handshake()
peer := NewPeer(conn, addr, true, server)
// peer.Messenger().AddProtocols([]string{"aaa", "ccc"})
peer.Start()
defer peer.Stop()
time.Sleep(2 * time.Millisecond)
if len(conn.Out) != 1 {
t.Errorf("handshake not sent")
} else {
out := conn.Out[0]
packet := Packet(0, HandshakeMsg, P2PVersion, []byte(peer.server.identity.String()), []interface{}{peer.server.protocols}, peer.server.port, peer.server.identity.Pubkey()[1:])
if bytes.Compare(out, packet) != 0 {
t.Errorf("incorrect handshake packet %v != %v", out, packet)
}
}
packet := Packet(0, HandshakeMsg, P2PVersion, []byte("peer"), []interface{}{"bbb", "aaa", "ccc"}, 30, []byte("0000000000000000000000000000000000000000000000000000000000000000"))
conn.In(0, packet)
time.Sleep(10 * time.Millisecond)
pro, _ := peer.Messenger().protocols[0].(*BaseProtocol)
if pro.state != handshakeReceived {
t.Errorf("handshake not received")
}
if peer.Port != 30 {
t.Errorf("port incorrectly set")
}
if peer.Id != "peer" {
t.Errorf("id incorrectly set")
}
if string(peer.Pubkey) != "0000000000000000000000000000000000000000000000000000000000000000" {
t.Errorf("pubkey incorrectly set")
}
fmt.Println(peer.Caps)
if len(peer.Caps) != 3 || peer.Caps[0] != "aaa" || peer.Caps[1] != "bbb" || peer.Caps[2] != "ccc" {
t.Errorf("protocols incorrectly set")
}
msg, _ := NewMsg(3)
err := peer.Write("aaa", msg)
if err != nil { if err != nil {
t.Errorf("expect no error for known protocol: %v", err) return err
} else { }
time.Sleep(1 * time.Millisecond) if err = msg.Discard(); err != nil {
if len(conn.Out) != 2 { return err
t.Errorf("msg not written") }
} else { }
out := conn.Out[1] },
packet := Packet(16, 3) }
if bytes.Compare(out, packet) != 0 {
t.Errorf("incorrect packet %v != %v", out, packet) func testPeer(protos []Protocol) (net.Conn, *Peer, <-chan error) {
conn1, conn2 := net.Pipe()
id := NewSimpleClientIdentity("test", "0", "0", "public key")
peer := newPeer(conn1, protos, nil)
peer.ourID = id
peer.pubkeyHook = func(*peerAddr) error { return nil }
errc := make(chan error, 1)
go func() {
_, err := peer.loop()
errc <- err
}()
return conn2, peer, errc
}
func TestPeerProtoReadMsg(t *testing.T) {
defer testlog(t).detach()
done := make(chan struct{})
proto := Protocol{
Name: "a",
Length: 5,
Run: func(peer *Peer, rw MsgReadWriter) error {
msg, err := rw.ReadMsg()
if err != nil {
t.Errorf("read error: %v", err)
}
if msg.Code != 2 {
t.Errorf("incorrect msg code %d relayed to protocol", msg.Code)
}
data, err := ioutil.ReadAll(msg.Payload)
if err != nil {
t.Errorf("payload read error: %v", err)
}
expdata, _ := hex.DecodeString("0183303030")
if !bytes.Equal(expdata, data) {
t.Errorf("incorrect msg data %x", data)
}
close(done)
return nil
},
}
net, peer, errc := testPeer([]Protocol{proto})
defer net.Close()
peer.startSubprotocols([]Cap{proto.cap()})
writeMsg(net, NewMsg(18, 1, "000"))
select {
case <-done:
case err := <-errc:
t.Errorf("peer returned: %v", err)
case <-time.After(2 * time.Second):
t.Errorf("receive timeout")
}
}
func TestPeerProtoReadLargeMsg(t *testing.T) {
defer testlog(t).detach()
msgsize := uint32(10 * 1024 * 1024)
done := make(chan struct{})
proto := Protocol{
Name: "a",
Length: 5,
Run: func(peer *Peer, rw MsgReadWriter) error {
msg, err := rw.ReadMsg()
if err != nil {
t.Errorf("read error: %v", err)
}
if msg.Size != msgsize+4 {
t.Errorf("incorrect msg.Size, got %d, expected %d", msg.Size, msgsize)
}
msg.Discard()
close(done)
return nil
},
}
net, peer, errc := testPeer([]Protocol{proto})
defer net.Close()
peer.startSubprotocols([]Cap{proto.cap()})
writeMsg(net, NewMsg(18, make([]byte, msgsize)))
select {
case <-done:
case err := <-errc:
t.Errorf("peer returned: %v", err)
case <-time.After(2 * time.Second):
t.Errorf("receive timeout")
} }
}
func TestPeerProtoEncodeMsg(t *testing.T) {
defer testlog(t).detach()
proto := Protocol{
Name: "a",
Length: 2,
Run: func(peer *Peer, rw MsgReadWriter) error {
if err := rw.EncodeMsg(2); err == nil {
t.Error("expected error for out-of-range msg code, got nil")
} }
if err := rw.EncodeMsg(1); err != nil {
t.Errorf("write error: %v", err)
} }
return nil
},
}
net, peer, _ := testPeer([]Protocol{proto})
defer net.Close()
peer.startSubprotocols([]Cap{proto.cap()})
msg, _ = NewMsg(2) bufr := bufio.NewReader(net)
err = peer.Write("ccc", msg) msg, err := readMsg(bufr)
if err != nil { if err != nil {
t.Errorf("read error: %v", err)
}
if msg.Code != 17 {
t.Errorf("incorrect message code: got %d, expected %d", msg.Code, 17)
}
}
func TestPeerWrite(t *testing.T) {
defer testlog(t).detach()
net, peer, peerErr := testPeer([]Protocol{discard})
defer net.Close()
peer.startSubprotocols([]Cap{discard.cap()})
// test write errors
if err := peer.writeProtoMsg("b", NewMsg(3)); err == nil {
t.Errorf("expected error for unknown protocol, got nil")
}
if err := peer.writeProtoMsg("discard", NewMsg(8)); err == nil {
t.Errorf("expected error for out-of-range msg code, got nil")
} else if perr, ok := err.(*peerError); !ok || perr.Code != errInvalidMsgCode {
t.Errorf("wrong error for out-of-range msg code, got %#v", err)
}
// setup for reading the message on the other end
read := make(chan struct{})
go func() {
bufr := bufio.NewReader(net)
msg, err := readMsg(bufr)
if err != nil {
t.Errorf("read error: %v", err)
} else if msg.Code != 16 {
t.Errorf("wrong code, got %d, expected %d", msg.Code, 16)
}
msg.Discard()
close(read)
}()
// test succcessful write
if err := peer.writeProtoMsg("discard", NewMsg(0)); err != nil {
t.Errorf("expect no error for known protocol: %v", err) t.Errorf("expect no error for known protocol: %v", err)
} else {
time.Sleep(1 * time.Millisecond)
if len(conn.Out) != 3 {
t.Errorf("msg not written")
} else {
out := conn.Out[2]
packet := Packet(21, 2)
if bytes.Compare(out, packet) != 0 {
t.Errorf("incorrect packet %v != %v", out, packet)
} }
select {
case <-read:
case err := <-peerErr:
t.Fatalf("peer stopped: %v", err)
} }
}
func TestPeerActivity(t *testing.T) {
// shorten inactivityTimeout while this test is running
oldT := inactivityTimeout
defer func() { inactivityTimeout = oldT }()
inactivityTimeout = 20 * time.Millisecond
net, peer, peerErr := testPeer([]Protocol{discard})
defer net.Close()
peer.startSubprotocols([]Cap{discard.cap()})
sub := peer.activity.Subscribe(time.Time{})
defer sub.Unsubscribe()
for i := 0; i < 6; i++ {
writeMsg(net, NewMsg(16))
select {
case <-sub.Chan():
case <-time.After(inactivityTimeout / 2):
t.Fatal("no event within ", inactivityTimeout/2)
case err := <-peerErr:
t.Fatal("peer error", err)
}
}
select {
case <-time.After(inactivityTimeout * 2):
case <-sub.Chan():
t.Fatal("got activity event while connection was inactive")
case err := <-peerErr:
t.Fatal("peer error", err)
} }
}
err = peer.Write("bbb", msg) func TestNewPeer(t *testing.T) {
time.Sleep(1 * time.Millisecond) id := NewSimpleClientIdentity("clientid", "version", "customid", "pubkey")
if err == nil { caps := []Cap{{"foo", 2}, {"bar", 3}}
t.Errorf("expect error for unknown protocol") p := NewPeer(id, caps)
if !reflect.DeepEqual(p.Caps(), caps) {
t.Errorf("Caps mismatch: got %v, expected %v", p.Caps(), caps)
}
if p.Identity() != id {
t.Errorf("Identity mismatch: got %v, expected %v", p.Identity(), id)
} }
// Should not hang.
p.Disconnect(DiscAlreadyConnected)
} }

@ -2,277 +2,294 @@ package p2p
import ( import (
"bytes" "bytes"
"fmt"
"net"
"sort"
"sync"
"time" "time"
"github.com/ethereum/go-ethereum/ethutil"
) )
type Protocol interface { // Protocol represents a P2P subprotocol implementation.
Start() type Protocol struct {
Stop() // Name should contain the official protocol name,
HandleIn(*Msg, chan *Msg) // often a three-letter word.
HandleOut(*Msg) bool Name string
Offset() MsgCode
Name() string // Version should contain the version number of the protocol.
Version uint
// Length should contain the number of message codes used
// by the protocol.
Length uint64
// Run is called in a new groutine when the protocol has been
// negotiated with a peer. It should read and write messages from
// rw. The Payload for each message must be fully consumed.
//
// The peer connection is closed when Start returns. It should return
// any protocol-level error (such as an I/O error) that is
// encountered.
Run func(peer *Peer, rw MsgReadWriter) error
} }
const ( func (p Protocol) cap() Cap {
P2PVersion = 0 return Cap{p.Name, p.Version}
pingTimeout = 2 }
pingGracePeriod = 2
)
const ( const (
HandshakeMsg = iota baseProtocolVersion = 2
DiscMsg baseProtocolLength = uint64(16)
PingMsg baseProtocolMaxMsgSize = 10 * 1024 * 1024
PongMsg
GetPeersMsg
PeersMsg
offset = 16
) )
type ProtocolState uint8
const ( const (
nullState = iota // devp2p message codes
handshakeReceived handshakeMsg = 0x00
discMsg = 0x01
pingMsg = 0x02
pongMsg = 0x03
getPeersMsg = 0x04
peersMsg = 0x05
) )
type DiscReason byte // handshake is the structure of a handshake list.
type handshake struct {
const ( Version uint64
// Values are given explicitly instead of by iota because these values are ID string
// defined by the wire protocol spec; it is easier for humans to ensure Caps []Cap
// correctness when values are explicit. ListenPort uint64
DiscRequested = 0x00 NodeID []byte
DiscNetworkError = 0x01
DiscProtocolError = 0x02
DiscUselessPeer = 0x03
DiscTooManyPeers = 0x04
DiscAlreadyConnected = 0x05
DiscIncompatibleVersion = 0x06
DiscInvalidIdentity = 0x07
DiscQuitting = 0x08
DiscUnexpectedIdentity = 0x09
DiscSelf = 0x0a
DiscReadTimeout = 0x0b
DiscSubprotocolError = 0x10
)
var discReasonToString = map[DiscReason]string{
DiscRequested: "Disconnect requested",
DiscNetworkError: "Network error",
DiscProtocolError: "Breach of protocol",
DiscUselessPeer: "Useless peer",
DiscTooManyPeers: "Too many peers",
DiscAlreadyConnected: "Already connected",
DiscIncompatibleVersion: "Incompatible P2P protocol version",
DiscInvalidIdentity: "Invalid node identity",
DiscQuitting: "Client quitting",
DiscUnexpectedIdentity: "Unexpected identity",
DiscSelf: "Connected to self",
DiscReadTimeout: "Read timeout",
DiscSubprotocolError: "Subprotocol error",
}
func (d DiscReason) String() string {
if len(discReasonToString) < int(d) {
return "Unknown"
}
return discReasonToString[d]
}
type BaseProtocol struct {
peer *Peer
state ProtocolState
stateLock sync.RWMutex
} }
func NewBaseProtocol(peer *Peer) *BaseProtocol { func (h *handshake) String() string {
self := &BaseProtocol{ return h.ID
peer: peer,
}
return self
} }
func (h *handshake) Pubkey() []byte {
func (self *BaseProtocol) Start() { return h.NodeID
if self.peer != nil {
self.peer.Write("", self.peer.Server().Handshake())
go self.peer.Messenger().PingPong(
pingTimeout*time.Second,
pingGracePeriod*time.Second,
self.Ping,
self.Timeout,
)
}
} }
func (self *BaseProtocol) Stop() { // Cap is the structure of a peer capability.
type Cap struct {
Name string
Version uint
} }
func (self *BaseProtocol) Ping() { func (cap Cap) RlpData() interface{} {
msg, _ := NewMsg(PingMsg) return []interface{}{cap.Name, cap.Version}
self.peer.Write("", msg)
} }
func (self *BaseProtocol) Timeout() { type capsByName []Cap
self.peerError(PingTimeout, "")
}
func (self *BaseProtocol) Name() string { func (cs capsByName) Len() int { return len(cs) }
return "" func (cs capsByName) Less(i, j int) bool { return cs[i].Name < cs[j].Name }
} func (cs capsByName) Swap(i, j int) { cs[i], cs[j] = cs[j], cs[i] }
func (self *BaseProtocol) Offset() MsgCode { type baseProtocol struct {
return offset rw MsgReadWriter
peer *Peer
} }
func (self *BaseProtocol) CheckState(state ProtocolState) bool { func runBaseProtocol(peer *Peer, rw MsgReadWriter) error {
self.stateLock.RLock() bp := &baseProtocol{rw, peer}
self.stateLock.RUnlock() if err := bp.doHandshake(rw); err != nil {
if self.state != state { return err
return false }
} else { // run main loop
return true quit := make(chan error, 1)
go func() {
for {
if err := bp.handle(rw); err != nil {
quit <- err
break
} }
}
}()
return bp.loop(quit)
} }
func (self *BaseProtocol) HandleIn(msg *Msg, response chan *Msg) { var pingTimeout = 2 * time.Second
if msg.Code() == HandshakeMsg {
self.handleHandshake(msg) func (bp *baseProtocol) loop(quit <-chan error) error {
} else { ping := time.NewTimer(pingTimeout)
if !self.CheckState(handshakeReceived) { activity := bp.peer.activity.Subscribe(time.Time{})
self.peerError(ProtocolBreach, "message code %v not allowed", msg.Code()) lastActive := time.Time{}
close(response) defer ping.Stop()
return defer activity.Unsubscribe()
getPeersTick := time.NewTicker(10 * time.Second)
defer getPeersTick.Stop()
err := bp.rw.EncodeMsg(getPeersMsg)
for err == nil {
select {
case err = <-quit:
return err
case <-getPeersTick.C:
err = bp.rw.EncodeMsg(getPeersMsg)
case event := <-activity.Chan():
ping.Reset(pingTimeout)
lastActive = event.(time.Time)
case t := <-ping.C:
if lastActive.Add(pingTimeout * 2).Before(t) {
err = newPeerError(errPingTimeout, "")
} else if lastActive.Add(pingTimeout).Before(t) {
err = bp.rw.EncodeMsg(pingMsg)
} }
switch msg.Code() {
case DiscMsg:
logger.Infof("Disconnect requested from peer %v, reason", DiscReason(msg.Data().Get(0).Uint()))
self.peer.Server().PeerDisconnect() <- DisconnectRequest{
addr: self.peer.Address,
reason: DiscRequested,
} }
case PingMsg:
out, _ := NewMsg(PongMsg)
response <- out
case PongMsg:
case GetPeersMsg:
// Peer asked for list of connected peers
if out, err := self.peer.Server().PeersMessage(); err != nil {
response <- out
} }
case PeersMsg: return err
self.handlePeers(msg) }
default:
self.peerError(InvalidMsgCode, "unknown message code %v", msg.Code()) func (bp *baseProtocol) handle(rw MsgReadWriter) error {
msg, err := rw.ReadMsg()
if err != nil {
return err
} }
if msg.Size > baseProtocolMaxMsgSize {
return newPeerError(errMisc, "message too big")
} }
close(response) // make sure that the payload has been fully consumed
} defer msg.Discard()
func (self *BaseProtocol) HandleOut(msg *Msg) (allowed bool) { switch msg.Code {
// somewhat overly paranoid case handshakeMsg:
allowed = msg.Code() == HandshakeMsg || msg.Code() == DiscMsg || msg.Code() < self.Offset() && self.CheckState(handshakeReceived) return newPeerError(errProtocolBreach, "extra handshake received")
return
}
func (self *BaseProtocol) peerError(errorCode ErrorCode, format string, v ...interface{}) { case discMsg:
err := NewPeerError(errorCode, format, v...) var reason DiscReason
logger.Warnln(err) if err := msg.Decode(&reason); err != nil {
fmt.Println(self.peer, err) return err
if self.peer != nil {
self.peer.PeerErrorChan() <- err
} }
} bp.peer.Disconnect(reason)
return nil
func (self *BaseProtocol) handlePeers(msg *Msg) {
it := msg.Data().NewIterator() case pingMsg:
for it.Next() { return bp.rw.EncodeMsg(pongMsg)
ip := net.IP(it.Value().Get(0).Bytes())
port := it.Value().Get(1).Uint() case pongMsg:
address := &net.TCPAddr{IP: ip, Port: int(port)}
go self.peer.Server().PeerConnect(address) case getPeersMsg:
peers := bp.peerList()
// this is dangerous. the spec says that we should _delay_
// sending the response if no new information is available.
// this means that would need to send a response later when
// new peers become available.
//
// TODO: add event mechanism to notify baseProtocol for new peers
if len(peers) > 0 {
return bp.rw.EncodeMsg(peersMsg, peers)
} }
}
func (self *BaseProtocol) handleHandshake(msg *Msg) { case peersMsg:
self.stateLock.Lock() var peers []*peerAddr
defer self.stateLock.Unlock() if err := msg.Decode(&peers); err != nil {
if self.state != nullState { return err
self.peerError(ProtocolBreach, "extra handshake")
return
} }
for _, addr := range peers {
c := msg.Data() bp.peer.Debugf("received peer suggestion: %v", addr)
bp.peer.newPeerAddr <- addr
var (
p2pVersion = c.Get(0).Uint()
id = c.Get(1).Str()
caps = c.Get(2)
port = c.Get(3).Uint()
pubkey = c.Get(4).Bytes()
)
fmt.Printf("handshake received %v, %v, %v, %v, %v ", p2pVersion, id, caps, port, pubkey)
// Check correctness of p2p protocol version
if p2pVersion != P2PVersion {
self.peerError(P2PVersionMismatch, "Require protocol %d, received %d\n", P2PVersion, p2pVersion)
return
} }
// Handle the pub key (validation, uniqueness) default:
if len(pubkey) == 0 { return newPeerError(errInvalidMsgCode, "unknown message code %v", msg.Code)
self.peerError(PubkeyMissing, "not supplied in handshake.")
return
} }
return nil
}
if len(pubkey) != 64 { func (bp *baseProtocol) doHandshake(rw MsgReadWriter) error {
self.peerError(PubkeyInvalid, "require 512 bit, got %v", len(pubkey)*8) // send our handshake
return if err := rw.WriteMsg(bp.handshakeMsg()); err != nil {
return err
} }
// Self connect detection // read and handle remote handshake
if bytes.Compare(self.peer.Server().ClientIdentity().Pubkey()[1:], pubkey) == 0 { msg, err := rw.ReadMsg()
self.peerError(PubkeyForbidden, "not allowed to connect to self") if err != nil {
return return err
}
if msg.Code != handshakeMsg {
return newPeerError(errProtocolBreach, "first message must be handshake, got %x", msg.Code)
}
if msg.Size > baseProtocolMaxMsgSize {
return newPeerError(errMisc, "message too big")
} }
// register pubkey on server. this also sets the pubkey on the peer (need lock) var hs handshake
if err := self.peer.Server().RegisterPubkey(self.peer, pubkey); err != nil { if err := msg.Decode(&hs); err != nil {
self.peerError(PubkeyForbidden, err.Error()) return err
return
} }
// check port // validate handshake info
if self.peer.Inbound { if hs.Version != baseProtocolVersion {
uint16port := uint16(port) return newPeerError(errP2PVersionMismatch, "Require protocol %d, received %d\n",
if self.peer.Port > 0 && self.peer.Port != uint16port { baseProtocolVersion, hs.Version)
self.peerError(PortMismatch, "port mismatch: %v != %v", self.peer.Port, port)
return
} else {
self.peer.Port = uint16port
} }
if len(hs.NodeID) == 0 {
return newPeerError(errPubkeyMissing, "")
} }
if len(hs.NodeID) != 64 {
capsIt := caps.NewIterator() return newPeerError(errPubkeyInvalid, "require 512 bit, got %v", len(hs.NodeID)*8)
for capsIt.Next() { }
cap := capsIt.Value().Str() if da := bp.peer.dialAddr; da != nil {
self.peer.Caps = append(self.peer.Caps, cap) // verify that the peer we wanted to connect to
// actually holds the target public key.
if da.Pubkey != nil && !bytes.Equal(da.Pubkey, hs.NodeID) {
return newPeerError(errPubkeyForbidden, "dial address pubkey mismatch")
}
}
pa := newPeerAddr(bp.peer.conn.RemoteAddr(), hs.NodeID)
if err := bp.peer.pubkeyHook(pa); err != nil {
return newPeerError(errPubkeyForbidden, "%v", err)
} }
sort.Strings(self.peer.Caps)
self.peer.Messenger().AddProtocols(self.peer.Caps)
self.peer.Id = id // TODO: remove Caps with empty name
self.state = handshakeReceived var addr *peerAddr
if hs.ListenPort != 0 {
addr = newPeerAddr(bp.peer.conn.RemoteAddr(), hs.NodeID)
addr.Port = hs.ListenPort
}
bp.peer.setHandshakeInfo(&hs, addr, hs.Caps)
bp.peer.startSubprotocols(hs.Caps)
return nil
}
func (bp *baseProtocol) handshakeMsg() Msg {
var (
port uint64
caps []interface{}
)
if bp.peer.ourListenAddr != nil {
port = bp.peer.ourListenAddr.Port
}
for _, proto := range bp.peer.protocols {
caps = append(caps, proto.cap())
}
return NewMsg(handshakeMsg,
baseProtocolVersion,
bp.peer.ourID.String(),
caps,
port,
bp.peer.ourID.Pubkey()[1:],
)
}
//p.ethereum.PushPeer(p) func (bp *baseProtocol) peerList() []ethutil.RlpEncodable {
// p.ethereum.reactor.Post("peerList", p.ethereum.Peers()) peers := bp.peer.otherPeers()
return ds := make([]ethutil.RlpEncodable, 0, len(peers))
for _, p := range peers {
p.infolock.Lock()
addr := p.listenAddr
p.infolock.Unlock()
// filter out this peer and peers that are not listening or
// have not completed the handshake.
// TODO: track previously sent peers and exclude them as well.
if p == bp.peer || addr == nil {
continue
}
ds = append(ds, addr)
}
ourAddr := bp.peer.ourListenAddr
if ourAddr != nil && !ourAddr.IP.IsLoopback() && !ourAddr.IP.IsUnspecified() {
ds = append(ds, ourAddr)
}
return ds
} }

@ -2,483 +2,466 @@ package p2p
import ( import (
"bytes" "bytes"
"errors"
"fmt" "fmt"
"net" "net"
"sort"
"strconv"
"sync" "sync"
"time" "time"
logpkg "github.com/ethereum/go-ethereum/logger" "github.com/ethereum/go-ethereum/logger"
) )
const ( const (
outboundAddressPoolSize = 10 outboundAddressPoolSize = 500
disconnectGracePeriod = 2 defaultDialTimeout = 10 * time.Second
portMappingUpdateInterval = 15 * time.Minute
portMappingTimeout = 20 * time.Minute
) )
type Blacklist interface { var srvlog = logger.NewLogger("P2P Server")
Get([]byte) (bool, error)
Put([]byte) error
Delete([]byte) error
Exists(pubkey []byte) (ok bool)
}
type BlacklistMap struct {
blacklist map[string]bool
lock sync.RWMutex
}
func NewBlacklist() *BlacklistMap {
return &BlacklistMap{
blacklist: make(map[string]bool),
}
}
func (self *BlacklistMap) Get(pubkey []byte) (bool, error) {
self.lock.RLock()
defer self.lock.RUnlock()
v, ok := self.blacklist[string(pubkey)]
var err error
if !ok {
err = fmt.Errorf("not found")
}
return v, err
}
func (self *BlacklistMap) Exists(pubkey []byte) (ok bool) {
self.lock.RLock()
defer self.lock.RUnlock()
_, ok = self.blacklist[string(pubkey)]
return
}
func (self *BlacklistMap) Put(pubkey []byte) error {
self.lock.RLock()
defer self.lock.RUnlock()
self.blacklist[string(pubkey)] = true
return nil
}
func (self *BlacklistMap) Delete(pubkey []byte) error {
self.lock.RLock()
defer self.lock.RUnlock()
delete(self.blacklist, string(pubkey))
return nil
}
// Server manages all peer connections.
//
// The fields of Server are used as configuration parameters.
// You should set them before starting the Server. Fields may not be
// modified while the server is running.
type Server struct { type Server struct {
network Network // This field must be set to a valid client identity.
listening bool //needed? Identity ClientIdentity
dialing bool //needed?
closed bool // MaxPeers is the maximum number of peers that can be
identity ClientIdentity // connected. It must be greater than zero.
addr net.Addr MaxPeers int
port uint16
protocols []string // Protocols should contain the protocols supported
// by the server. Matching protocols are launched for
quit chan chan bool // each peer.
peersLock sync.RWMutex Protocols []Protocol
maxPeers int // If Blacklist is set to a non-nil value, the given Blacklist
// is used to verify peer connections.
Blacklist Blacklist
// If ListenAddr is set to a non-nil address, the server
// will listen for incoming connections.
//
// If the port is zero, the operating system will pick a port. The
// ListenAddr field will be updated with the actual address when
// the server is started.
ListenAddr string
// If set to a non-nil value, the given NAT port mapper
// is used to make the listening port available to the
// Internet.
NAT NAT
// If Dialer is set to a non-nil value, the given Dialer
// is used to dial outbound peer connections.
Dialer *net.Dialer
// If NoDial is true, the server will not dial any peers.
NoDial bool
// Hook for testing. This is useful because we can inhibit
// the whole protocol stack.
newPeerFunc peerFunc
lock sync.RWMutex
running bool
listener net.Listener
laddr *net.TCPAddr // real listen addr
peers []*Peer peers []*Peer
peerSlots chan int peerSlots chan int
peersTable map[string]int
peersMsg *Msg
peerCount int peerCount int
peerConnect chan net.Addr quit chan struct{}
peerDisconnect chan DisconnectRequest wg sync.WaitGroup
blacklist Blacklist peerConnect chan *peerAddr
handlers Handlers peerDisconnect chan *Peer
} }
var logger = logpkg.NewLogger("P2P") // NAT is implemented by NAT traversal methods.
type NAT interface {
func New(network Network, addr net.Addr, identity ClientIdentity, handlers Handlers, maxPeers int, blacklist Blacklist) *Server { GetExternalAddress() (net.IP, error)
// get alphabetical list of protocol names from handlers map AddPortMapping(protocol string, extport, intport int, name string, lifetime time.Duration) error
protocols := []string{} DeletePortMapping(protocol string, extport, intport int) error
for protocol := range handlers {
protocols = append(protocols, protocol)
}
sort.Strings(protocols)
_, port, _ := net.SplitHostPort(addr.String())
intport, _ := strconv.Atoi(port)
self := &Server{
// NewSimpleClientIdentity(clientIdentifier, version, customIdentifier)
network: network,
identity: identity,
addr: addr,
port: uint16(intport),
protocols: protocols,
quit: make(chan chan bool),
maxPeers: maxPeers, // Should return name of the method.
peers: make([]*Peer, maxPeers), String() string
peerSlots: make(chan int, maxPeers), }
peersTable: make(map[string]int),
peerConnect: make(chan net.Addr, outboundAddressPoolSize), type peerFunc func(srv *Server, c net.Conn, dialAddr *peerAddr) *Peer
peerDisconnect: make(chan DisconnectRequest),
blacklist: blacklist,
handlers: handlers, // Peers returns all connected peers.
func (srv *Server) Peers() (peers []*Peer) {
srv.lock.RLock()
defer srv.lock.RUnlock()
for _, peer := range srv.peers {
if peer != nil {
peers = append(peers, peer)
} }
for i := 0; i < maxPeers; i++ {
self.peerSlots <- i // fill up with indexes
} }
return self
}
func (self *Server) NewAddr(host string, port int) (addr net.Addr, err error) {
addr, err = self.network.NewAddr(host, port)
return
}
func (self *Server) ParseAddr(address string) (addr net.Addr, err error) {
addr, err = self.network.ParseAddr(address)
return return
} }
func (self *Server) ClientIdentity() ClientIdentity { // PeerCount returns the number of connected peers.
return self.identity func (srv *Server) PeerCount() int {
srv.lock.RLock()
defer srv.lock.RUnlock()
return srv.peerCount
} }
func (self *Server) PeersMessage() (msg *Msg, err error) { // SuggestPeer injects an address into the outbound address pool.
// TODO: memoize and reset when peers change func (srv *Server) SuggestPeer(ip net.IP, port int, nodeID []byte) {
self.peersLock.RLock() select {
defer self.peersLock.RUnlock() case srv.peerConnect <- &peerAddr{ip, uint64(port), nodeID}:
msg = self.peersMsg default: // don't block
if msg == nil {
var peerData []interface{}
for _, i := range self.peersTable {
peer := self.peers[i]
peerData = append(peerData, peer.Encode())
}
if len(peerData) == 0 {
err = fmt.Errorf("no peers")
} else {
msg, err = NewMsg(PeersMsg, peerData...)
self.peersMsg = msg //memoize
}
} }
return
} }
func (self *Server) Peers() (peers []*Peer) { // Broadcast sends an RLP-encoded message to all connected peers.
self.peersLock.RLock() // This method is deprecated and will be removed later.
defer self.peersLock.RUnlock() func (srv *Server) Broadcast(protocol string, code uint64, data ...interface{}) {
for _, peer := range self.peers { var payload []byte
if data != nil {
payload = encodePayload(data...)
}
srv.lock.RLock()
defer srv.lock.RUnlock()
for _, peer := range srv.peers {
if peer != nil { if peer != nil {
peers = append(peers, peer) var msg = Msg{Code: code}
if data != nil {
msg.Payload = bytes.NewReader(payload)
msg.Size = uint32(len(payload))
}
peer.writeProtoMsg(protocol, msg)
} }
} }
return
}
func (self *Server) PeerCount() int {
self.peersLock.RLock()
defer self.peersLock.RUnlock()
return self.peerCount
} }
var getPeersMsg, _ = NewMsg(GetPeersMsg) // Start starts running the server.
// Servers can be re-used and started again after stopping.
func (self *Server) PeerConnect(addr net.Addr) { func (srv *Server) Start() (err error) {
// TODO: should buffer, filter and uniq srv.lock.Lock()
// send GetPeersMsg if not blocking defer srv.lock.Unlock()
select { if srv.running {
case self.peerConnect <- addr: // not enough peers return errors.New("server already running")
self.Broadcast("", getPeersMsg)
default: // we dont care
} }
} srvlog.Infoln("Starting Server")
func (self *Server) PeerDisconnect() chan DisconnectRequest { // initialize fields
return self.peerDisconnect if srv.Identity == nil {
} return fmt.Errorf("Server.Identity must be set to a non-nil identity")
}
if srv.MaxPeers <= 0 {
return fmt.Errorf("Server.MaxPeers must be > 0")
}
srv.quit = make(chan struct{})
srv.peers = make([]*Peer, srv.MaxPeers)
srv.peerSlots = make(chan int, srv.MaxPeers)
srv.peerConnect = make(chan *peerAddr, outboundAddressPoolSize)
srv.peerDisconnect = make(chan *Peer)
if srv.newPeerFunc == nil {
srv.newPeerFunc = newServerPeer
}
if srv.Blacklist == nil {
srv.Blacklist = NewBlacklist()
}
if srv.Dialer == nil {
srv.Dialer = &net.Dialer{Timeout: defaultDialTimeout}
}
func (self *Server) Blacklist() Blacklist { if srv.ListenAddr != "" {
return self.blacklist if err := srv.startListening(); err != nil {
} return err
}
}
if !srv.NoDial {
srv.wg.Add(1)
go srv.dialLoop()
}
if srv.NoDial && srv.ListenAddr == "" {
srvlog.Warnln("I will be kind-of useless, neither dialing nor listening.")
}
func (self *Server) Handlers() Handlers { // make all slots available
return self.handlers for i := range srv.peers {
srv.peerSlots <- i
}
// note: discLoop is not part of WaitGroup
go srv.discLoop()
srv.running = true
return nil
} }
func (self *Server) Broadcast(protocol string, msg *Msg) { func (srv *Server) startListening() error {
self.peersLock.RLock() listener, err := net.Listen("tcp", srv.ListenAddr)
defer self.peersLock.RUnlock() if err != nil {
for _, peer := range self.peers { return err
if peer != nil {
peer.Write(protocol, msg)
} }
srv.ListenAddr = listener.Addr().String()
srv.laddr = listener.Addr().(*net.TCPAddr)
srv.listener = listener
srv.wg.Add(1)
go srv.listenLoop()
if !srv.laddr.IP.IsLoopback() && srv.NAT != nil {
srv.wg.Add(1)
go srv.natLoop(srv.laddr.Port)
} }
return nil
} }
// Start the server // Stop terminates the server and all active peer connections.
func (self *Server) Start(listen bool, dial bool) { // It blocks until all active connections have been closed.
self.network.Start() func (srv *Server) Stop() {
if listen { srv.lock.Lock()
listener, err := self.network.Listener(self.addr) if !srv.running {
if err != nil { srv.lock.Unlock()
logger.Warnf("Error initializing listener: %v", err) return
logger.Warnf("Connection listening disabled")
self.listening = false
} else {
self.listening = true
logger.Infoln("Listen on %v: ready and accepting connections", listener.Addr())
go self.inboundPeerHandler(listener)
} }
srv.running = false
srv.lock.Unlock()
srvlog.Infoln("Stopping server")
if srv.listener != nil {
// this unblocks listener Accept
srv.listener.Close()
} }
if dial { close(srv.quit)
dialer, err := self.network.Dialer(self.addr) for _, peer := range srv.Peers() {
if err != nil { peer.Disconnect(DiscQuitting)
logger.Warnf("Error initializing dialer: %v", err)
logger.Warnf("Connection dialout disabled")
self.dialing = false
} else {
self.dialing = true
logger.Infoln("Dial peers watching outbound address pool")
go self.outboundPeerHandler(dialer)
} }
srv.wg.Wait()
// wait till they actually disconnect
// this is checked by claiming all peerSlots.
// slots become available as the peers disconnect.
for i := 0; i < cap(srv.peerSlots); i++ {
<-srv.peerSlots
} }
logger.Infoln("server started") // terminate discLoop
close(srv.peerDisconnect)
} }
func (self *Server) Stop() { func (srv *Server) discLoop() {
logger.Infoln("server stopping...") for peer := range srv.peerDisconnect {
// // quit one loop if dialing // peer has just disconnected. free up its slot.
if self.dialing { srvlog.Infof("%v is gone", peer)
logger.Infoln("stop dialout...") srv.peerSlots <- peer.slot
dialq := make(chan bool) srv.lock.Lock()
self.quit <- dialq srv.peers[peer.slot] = nil
<-dialq srv.lock.Unlock()
fmt.Println("quit another")
}
// quit the other loop if listening
if self.listening {
logger.Infoln("stop listening...")
listenq := make(chan bool)
self.quit <- listenq
<-listenq
fmt.Println("quit one")
}
fmt.Println("quit waited")
logger.Infoln("stopping peers...")
peers := []net.Addr{}
self.peersLock.RLock()
self.closed = true
for _, peer := range self.peers {
if peer != nil {
peers = append(peers, peer.Address)
} }
} }
self.peersLock.RUnlock()
for _, address := range peers { // main loop for adding connections via listening
go self.removePeer(DisconnectRequest{ func (srv *Server) listenLoop() {
addr: address, defer srv.wg.Done()
reason: DiscQuitting,
})
}
// wait till they actually disconnect
// this is checked by draining the peerSlots (slots are released back if a peer is removed)
i := 0
fmt.Println("draining peers")
FOR: srvlog.Infoln("Listening on", srv.listener.Addr())
for { for {
select { select {
case slot := <-self.peerSlots: case slot := <-srv.peerSlots:
i++ conn, err := srv.listener.Accept()
fmt.Printf("%v: found slot %v", i, slot) if err != nil {
if i == self.maxPeers { srv.peerSlots <- slot
break FOR return
} }
srvlog.Debugf("Accepted conn %v (slot %d)\n", conn.RemoteAddr(), slot)
srv.addPeer(conn, nil, slot)
case <-srv.quit:
return
} }
} }
logger.Infoln("server stopped")
} }
// main loop for adding connections via listening func (srv *Server) natLoop(port int) {
func (self *Server) inboundPeerHandler(listener net.Listener) { defer srv.wg.Done()
for { for {
srv.updatePortMapping(port)
select { select {
case slot := <-self.peerSlots: case <-time.After(portMappingUpdateInterval):
go self.connectInboundPeer(listener, slot) // one more round
case errc := <-self.quit: case <-srv.quit:
listener.Close() srv.removePortMapping(port)
fmt.Println("quit listenloop") return
errc <- true }
}
}
func (srv *Server) updatePortMapping(port int) {
srvlog.Infoln("Attempting to map port", port, "with", srv.NAT)
err := srv.NAT.AddPortMapping("tcp", port, port, "ethereum p2p", portMappingTimeout)
if err != nil {
srvlog.Errorln("Port mapping error:", err)
return return
} }
extip, err := srv.NAT.GetExternalAddress()
if err != nil {
srvlog.Errorln("Error getting external IP:", err)
return
} }
srv.lock.Lock()
extaddr := *(srv.listener.Addr().(*net.TCPAddr))
extaddr.IP = extip
srvlog.Infoln("Mapped port, external addr is", &extaddr)
srv.laddr = &extaddr
srv.lock.Unlock()
} }
// main loop for adding outbound peers based on peerConnect address pool func (srv *Server) removePortMapping(port int) {
// this same loop handles peer disconnect requests as well srvlog.Infoln("Removing port mapping for", port, "with", srv.NAT)
func (self *Server) outboundPeerHandler(dialer Dialer) { srv.NAT.DeletePortMapping("tcp", port, port)
// addressChan initially set to nil (only watches peerConnect if we need more peers) }
var addressChan chan net.Addr
slots := self.peerSlots func (srv *Server) dialLoop() {
var slot *int defer srv.wg.Done()
var (
suggest chan *peerAddr
slot *int
slots = srv.peerSlots
)
for { for {
select { select {
case i := <-slots: case i := <-slots:
// we need a peer in slot i, slot reserved // we need a peer in slot i, slot reserved
slot = &i slot = &i
// now we can watch for candidate peers in the next loop // now we can watch for candidate peers in the next loop
addressChan = self.peerConnect suggest = srv.peerConnect
// do not consume more until candidate peer is found // do not consume more until candidate peer is found
slots = nil slots = nil
case address := <-addressChan:
case desc := <-suggest:
// candidate peer found, will dial out asyncronously // candidate peer found, will dial out asyncronously
// if connection fails slot will be released // if connection fails slot will be released
go self.connectOutboundPeer(dialer, address, *slot) go srv.dialPeer(desc, *slot)
// we can watch if more peers needed in the next loop // we can watch if more peers needed in the next loop
slots = self.peerSlots slots = srv.peerSlots
// until then we dont care about candidate peers // until then we dont care about candidate peers
addressChan = nil suggest = nil
case request := <-self.peerDisconnect:
go self.removePeer(request)
case errc := <-self.quit:
if addressChan != nil && slot != nil {
self.peerSlots <- *slot
}
fmt.Println("quit dialloop")
errc <- true
return
}
}
}
// check if peer address already connected case <-srv.quit:
func (self *Server) connected(address net.Addr) (err error) { // give back the currently reserved slot
self.peersLock.RLock() if slot != nil {
defer self.peersLock.RUnlock() srv.peerSlots <- *slot
// fmt.Printf("address: %v\n", address)
slot, found := self.peersTable[address.String()]
if found {
err = fmt.Errorf("already connected as peer %v (%v)", slot, address)
} }
return return
}
// connect to peer via listener.Accept()
func (self *Server) connectInboundPeer(listener net.Listener, slot int) {
var address net.Addr
conn, err := listener.Accept()
if err == nil {
address = conn.RemoteAddr()
err = self.connected(address)
if err != nil {
conn.Close()
} }
} }
if err != nil {
logger.Debugln(err)
self.peerSlots <- slot
} else {
fmt.Printf("adding %v\n", address)
go self.addPeer(conn, address, true, slot)
}
} }
// connect to peer via dial out // connect to peer via dial out
func (self *Server) connectOutboundPeer(dialer Dialer, address net.Addr, slot int) { func (srv *Server) dialPeer(desc *peerAddr, slot int) {
var conn net.Conn srvlog.Debugf("Dialing %v (slot %d)\n", desc, slot)
err := self.connected(address) conn, err := srv.Dialer.Dial(desc.Network(), desc.String())
if err == nil {
conn, err = dialer.Dial(address.Network(), address.String())
}
if err != nil { if err != nil {
logger.Debugln(err) srvlog.Errorf("Dial error: %v", err)
self.peerSlots <- slot srv.peerSlots <- slot
} else { return
go self.addPeer(conn, address, false, slot)
} }
go srv.addPeer(conn, desc, slot)
} }
// creates the new peer object and inserts it into its slot // creates the new peer object and inserts it into its slot
func (self *Server) addPeer(conn net.Conn, address net.Addr, inbound bool, slot int) { func (srv *Server) addPeer(conn net.Conn, desc *peerAddr, slot int) *Peer {
self.peersLock.Lock() srv.lock.Lock()
defer self.peersLock.Unlock() defer srv.lock.Unlock()
if self.closed { if !srv.running {
fmt.Println("oopsy, not no longer need peer") conn.Close()
conn.Close() //oopsy our bad srv.peerSlots <- slot // release slot
self.peerSlots <- slot // release slot return nil
} else {
peer := NewPeer(conn, address, inbound, self)
self.peers[slot] = peer
self.peersTable[address.String()] = slot
self.peerCount++
// reset peersmsg
self.peersMsg = nil
fmt.Printf("added peer %v %v (slot %v)\n", address, peer, slot)
peer.Start()
} }
peer := srv.newPeerFunc(srv, conn, desc)
peer.slot = slot
srv.peers[slot] = peer
srv.peerCount++
go func() { peer.loop(); srv.peerDisconnect <- peer }()
return peer
} }
// removes peer: sending disconnect msg, stop peer, remove rom list/table, release slot // removes peer: sending disconnect msg, stop peer, remove rom list/table, release slot
func (self *Server) removePeer(request DisconnectRequest) { func (srv *Server) removePeer(peer *Peer) {
self.peersLock.Lock() srv.lock.Lock()
defer srv.lock.Unlock()
address := request.addr srvlog.Debugf("Removing peer %v %v (slot %v)\n", peer, peer.slot)
slot := self.peersTable[address.String()] if srv.peers[peer.slot] != peer {
peer := self.peers[slot] srvlog.Warnln("Invalid peer to remove:", peer)
fmt.Printf("removing peer %v %v (slot %v)\n", address, peer, slot)
if peer == nil {
logger.Debugf("already removed peer on %v", address)
self.peersLock.Unlock()
return return
} }
// remove from list and index // remove from list and index
self.peerCount-- srv.peerCount--
self.peers[slot] = nil srv.peers[peer.slot] = nil
delete(self.peersTable, address.String())
// reset peersmsg
self.peersMsg = nil
fmt.Printf("removed peer %v (slot %v)\n", peer, slot)
self.peersLock.Unlock()
// sending disconnect message
disconnectMsg, _ := NewMsg(DiscMsg, request.reason)
peer.Write("", disconnectMsg)
// be nice and wait
time.Sleep(disconnectGracePeriod * time.Second)
// switch off peer and close connections etc.
fmt.Println("stopping peer")
peer.Stop()
fmt.Println("stopped peer")
// release slot to signal need for a new peer, last! // release slot to signal need for a new peer, last!
self.peerSlots <- slot srv.peerSlots <- peer.slot
} }
// fix handshake message to push to peers func (srv *Server) verifyPeer(addr *peerAddr) error {
func (self *Server) Handshake() *Msg { if srv.Blacklist.Exists(addr.Pubkey) {
fmt.Println(self.identity.Pubkey()[1:]) return errors.New("blacklisted")
msg, _ := NewMsg(HandshakeMsg, P2PVersion, []byte(self.identity.String()), []interface{}{self.protocols}, self.port, self.identity.Pubkey()[1:]) }
return msg if bytes.Equal(srv.Identity.Pubkey()[1:], addr.Pubkey) {
return newPeerError(errPubkeyForbidden, "not allowed to connect to srv")
}
srv.lock.RLock()
defer srv.lock.RUnlock()
for _, peer := range srv.peers {
if peer != nil {
id := peer.Identity()
if id != nil && bytes.Equal(id.Pubkey(), addr.Pubkey) {
return errors.New("already connected")
}
}
}
return nil
} }
func (self *Server) RegisterPubkey(candidate *Peer, pubkey []byte) error { type Blacklist interface {
// Check for blacklisting Get([]byte) (bool, error)
if self.blacklist.Exists(pubkey) { Put([]byte) error
return fmt.Errorf("blacklisted") Delete([]byte) error
} Exists(pubkey []byte) (ok bool)
}
self.peersLock.RLock() type BlacklistMap struct {
defer self.peersLock.RUnlock() blacklist map[string]bool
for _, peer := range self.peers { lock sync.RWMutex
if peer != nil && peer != candidate && bytes.Compare(peer.Pubkey, pubkey) == 0 { }
return fmt.Errorf("already connected")
func NewBlacklist() *BlacklistMap {
return &BlacklistMap{
blacklist: make(map[string]bool),
} }
}
func (self *BlacklistMap) Get(pubkey []byte) (bool, error) {
self.lock.RLock()
defer self.lock.RUnlock()
v, ok := self.blacklist[string(pubkey)]
var err error
if !ok {
err = fmt.Errorf("not found")
} }
candidate.Pubkey = pubkey return v, err
}
func (self *BlacklistMap) Exists(pubkey []byte) (ok bool) {
self.lock.RLock()
defer self.lock.RUnlock()
_, ok = self.blacklist[string(pubkey)]
return
}
func (self *BlacklistMap) Put(pubkey []byte) error {
self.lock.RLock()
defer self.lock.RUnlock()
self.blacklist[string(pubkey)] = true
return nil
}
func (self *BlacklistMap) Delete(pubkey []byte) error {
self.lock.RLock()
defer self.lock.RUnlock()
delete(self.blacklist, string(pubkey))
return nil return nil
} }

@ -2,207 +2,160 @@ package p2p
import ( import (
"bytes" "bytes"
"fmt" "io"
"net" "net"
"sync"
"testing" "testing"
"time" "time"
) )
type TestNetwork struct { func startTestServer(t *testing.T, pf peerFunc) *Server {
connections map[string]*TestNetworkConnection server := &Server{
dialer Dialer Identity: NewSimpleClientIdentity("clientIdentifier", "version", "customIdentifier", "pubkey"),
maxinbound int MaxPeers: 10,
} ListenAddr: "127.0.0.1:0",
newPeerFunc: pf,
func NewTestNetwork(maxinbound int) *TestNetwork {
connections := make(map[string]*TestNetworkConnection)
return &TestNetwork{
connections: connections,
dialer: &TestDialer{connections},
maxinbound: maxinbound,
} }
if err := server.Start(); err != nil {
t.Fatalf("Could not start server: %v", err)
}
return server
} }
func (self *TestNetwork) Dialer(addr net.Addr) (Dialer, error) { func TestServerListen(t *testing.T) {
return self.dialer, nil defer testlog(t).detach()
}
func (self *TestNetwork) Listener(addr net.Addr) (net.Listener, error) {
return &TestListener{
connections: self.connections,
addr: addr,
max: self.maxinbound,
}, nil
}
func (self *TestNetwork) Start() error {
return nil
}
func (self *TestNetwork) NewAddr(string, int) (addr net.Addr, err error) {
return
}
func (self *TestNetwork) ParseAddr(string) (addr net.Addr, err error) {
return
}
type TestAddr struct {
name string
}
func (self *TestAddr) String() string {
return self.name
}
func (*TestAddr) Network() string {
return "test"
}
type TestDialer struct {
connections map[string]*TestNetworkConnection
}
func (self *TestDialer) Dial(network string, addr string) (conn net.Conn, err error) {
address := &TestAddr{addr}
tconn := NewTestNetworkConnection(address)
self.connections[addr] = tconn
conn = net.Conn(tconn)
return
}
type TestListener struct {
connections map[string]*TestNetworkConnection
addr net.Addr
max int
i int
}
func (self *TestListener) Accept() (conn net.Conn, err error) {
self.i++
if self.i > self.max {
err = fmt.Errorf("no more")
} else {
addr := &TestAddr{fmt.Sprintf("inboundpeer-%d", self.i)}
tconn := NewTestNetworkConnection(addr)
key := tconn.RemoteAddr().String()
self.connections[key] = tconn
conn = net.Conn(tconn)
fmt.Printf("accepted connection from: %v \n", addr)
}
return
}
func (self *TestListener) Close() error { // start the test server
return nil connected := make(chan *Peer)
} srv := startTestServer(t, func(srv *Server, conn net.Conn, dialAddr *peerAddr) *Peer {
if conn == nil {
t.Error("peer func called with nil conn")
}
if dialAddr != nil {
t.Error("peer func called with non-nil dialAddr")
}
peer := newPeer(conn, nil, dialAddr)
connected <- peer
return peer
})
defer close(connected)
defer srv.Stop()
// dial the test server
conn, err := net.DialTimeout("tcp", srv.ListenAddr, 5*time.Second)
if err != nil {
t.Fatalf("could not dial: %v", err)
}
defer conn.Close()
func (self *TestListener) Addr() net.Addr { select {
return self.addr case peer := <-connected:
if peer.conn.LocalAddr().String() != conn.RemoteAddr().String() {
t.Errorf("peer started with wrong conn: got %v, want %v",
peer.conn.LocalAddr(), conn.RemoteAddr())
}
case <-time.After(1 * time.Second):
t.Error("server did not accept within one second")
}
} }
func SetupTestServer(handlers Handlers) (network *TestNetwork, server *Server) { func TestServerDial(t *testing.T) {
network = NewTestNetwork(1) defer testlog(t).detach()
addr := &TestAddr{"test:30303"}
identity := NewSimpleClientIdentity("clientIdentifier", "version", "customIdentifier", "pubkey")
maxPeers := 2
if handlers == nil {
handlers = make(Handlers)
}
blackist := NewBlacklist()
server = New(network, addr, identity, handlers, maxPeers, blackist)
fmt.Println(server.identity.Pubkey())
return
}
func TestServerListener(t *testing.T) { // run a fake TCP server to handle the connection.
network, server := SetupTestServer(nil) listener, err := net.Listen("tcp", "127.0.0.1:0")
server.Start(true, false) if err != nil {
time.Sleep(10 * time.Millisecond) t.Fatalf("could not setup listener: %v")
server.Stop()
peer1, ok := network.connections["inboundpeer-1"]
if !ok {
t.Error("not found inbound peer 1")
} else {
fmt.Printf("out: %v\n", peer1.Out)
if len(peer1.Out) != 2 {
t.Errorf("not enough messages sent to peer 1: %v ", len(peer1.Out))
} }
defer listener.Close()
accepted := make(chan net.Conn)
go func() {
conn, err := listener.Accept()
if err != nil {
t.Error("acccept error:", err)
} }
conn.Close()
} accepted <- conn
}()
func TestServerDialer(t *testing.T) {
network, server := SetupTestServer(nil) // start the test server
server.Start(false, true) connected := make(chan *Peer)
server.peerConnect <- &TestAddr{"outboundpeer-1"} srv := startTestServer(t, func(srv *Server, conn net.Conn, dialAddr *peerAddr) *Peer {
time.Sleep(10 * time.Millisecond) if conn == nil {
server.Stop() t.Error("peer func called with nil conn")
peer1, ok := network.connections["outboundpeer-1"]
if !ok {
t.Error("not found outbound peer 1")
} else {
fmt.Printf("out: %v\n", peer1.Out)
if len(peer1.Out) != 2 {
t.Errorf("not enough messages sent to peer 1: %v ", len(peer1.Out))
} }
peer := newPeer(conn, nil, dialAddr)
connected <- peer
return peer
})
defer close(connected)
defer srv.Stop()
// tell the server to connect.
connAddr := newPeerAddr(listener.Addr(), nil)
srv.peerConnect <- connAddr
select {
case conn := <-accepted:
select {
case peer := <-connected:
if peer.conn.RemoteAddr().String() != conn.LocalAddr().String() {
t.Errorf("peer started with wrong conn: got %v, want %v",
peer.conn.RemoteAddr(), conn.LocalAddr())
} }
} if peer.dialAddr != connAddr {
t.Errorf("peer started with wrong dialAddr: got %v, want %v",
func TestServerBroadcast(t *testing.T) { peer.dialAddr, connAddr)
handlers := make(Handlers)
testProtocol := &TestProtocol{Msgs: []*Msg{}}
handlers["aaa"] = func(p *Peer) Protocol { return testProtocol }
network, server := SetupTestServer(handlers)
server.Start(true, true)
server.peerConnect <- &TestAddr{"outboundpeer-1"}
time.Sleep(10 * time.Millisecond)
msg, _ := NewMsg(0)
server.Broadcast("", msg)
packet := Packet(0, 0)
time.Sleep(10 * time.Millisecond)
server.Stop()
peer1, ok := network.connections["outboundpeer-1"]
if !ok {
t.Error("not found outbound peer 1")
} else {
fmt.Printf("out: %v\n", peer1.Out)
if len(peer1.Out) != 3 {
t.Errorf("not enough messages sent to peer 1: %v ", len(peer1.Out))
} else {
if bytes.Compare(peer1.Out[1], packet) != 0 {
t.Errorf("incorrect broadcast packet %v != %v", peer1.Out[1], packet)
}
}
}
peer2, ok := network.connections["inboundpeer-1"]
if !ok {
t.Error("not found inbound peer 2")
} else {
fmt.Printf("out: %v\n", peer2.Out)
if len(peer1.Out) != 3 {
t.Errorf("not enough messages sent to peer 2: %v ", len(peer2.Out))
} else {
if bytes.Compare(peer2.Out[1], packet) != 0 {
t.Errorf("incorrect broadcast packet %v != %v", peer2.Out[1], packet)
} }
case <-time.After(1 * time.Second):
t.Error("server did not launch peer within one second")
} }
case <-time.After(1 * time.Second):
t.Error("server did not connect within one second")
} }
} }
func TestServerPeersMessage(t *testing.T) { func TestServerBroadcast(t *testing.T) {
handlers := make(Handlers) defer testlog(t).detach()
_, server := SetupTestServer(handlers) var connected sync.WaitGroup
server.Start(true, true) srv := startTestServer(t, func(srv *Server, c net.Conn, dialAddr *peerAddr) *Peer {
defer server.Stop() peer := newPeer(c, []Protocol{discard}, dialAddr)
server.peerConnect <- &TestAddr{"outboundpeer-1"} peer.startSubprotocols([]Cap{discard.cap()})
time.Sleep(10 * time.Millisecond) connected.Done()
peersMsg, err := server.PeersMessage() return peer
fmt.Println(peersMsg) })
defer srv.Stop()
// dial a bunch of conns
var conns = make([]net.Conn, 8)
connected.Add(len(conns))
deadline := time.Now().Add(3 * time.Second)
dialer := &net.Dialer{Deadline: deadline}
for i := range conns {
conn, err := dialer.Dial("tcp", srv.ListenAddr)
if err != nil { if err != nil {
t.Errorf("expect no error, got %v", err) t.Fatalf("conn %d: dial error: %v", i, err)
}
defer conn.Close()
conn.SetDeadline(deadline)
conns[i] = conn
}
connected.Wait()
// broadcast one message
srv.Broadcast("discard", 0, "foo")
goldbuf := new(bytes.Buffer)
writeMsg(goldbuf, NewMsg(16, "foo"))
golden := goldbuf.Bytes()
// check that the message has been written everywhere
for i, conn := range conns {
buf := make([]byte, len(golden))
if _, err := io.ReadFull(conn, buf); err != nil {
t.Errorf("conn %d: read error: %v", i, err)
} else if !bytes.Equal(buf, golden) {
t.Errorf("conn %d: msg mismatch\ngot: %x\nwant: %x", i, buf, golden)
} }
if c := server.PeerCount(); c != 2 {
t.Errorf("expect 2 peers, got %v", c)
} }
} }

@ -0,0 +1,28 @@
package p2p
import (
"testing"
"github.com/ethereum/go-ethereum/logger"
)
type testLogger struct{ t *testing.T }
func testlog(t *testing.T) testLogger {
logger.Reset()
l := testLogger{t}
logger.AddLogSystem(l)
return l
}
func (testLogger) GetLogLevel() logger.LogLevel { return logger.DebugLevel }
func (testLogger) SetLogLevel(logger.LogLevel) {}
func (l testLogger) LogPrint(level logger.LogLevel, msg string) {
l.t.Logf("%s", msg)
}
func (testLogger) detach() {
logger.Flush()
logger.Reset()
}

@ -0,0 +1,40 @@
// +build none
package main
import (
"fmt"
"log"
"net"
"os"
"github.com/ethereum/go-ethereum/logger"
"github.com/ethereum/go-ethereum/p2p"
"github.com/obscuren/secp256k1-go"
)
func main() {
logger.AddLogSystem(logger.NewStdLogSystem(os.Stdout, log.LstdFlags, logger.DebugLevel))
pub, _ := secp256k1.GenerateKeyPair()
srv := p2p.Server{
MaxPeers: 10,
Identity: p2p.NewSimpleClientIdentity("test", "1.0", "", string(pub)),
ListenAddr: ":30303",
NAT: p2p.PMP(net.ParseIP("10.0.0.1")),
}
if err := srv.Start(); err != nil {
fmt.Println("could not start server:", err)
os.Exit(1)
}
// add seed peers
seed, err := net.ResolveTCPAddr("tcp", "poc-7.ethdev.com:30303")
if err != nil {
fmt.Println("couldn't resolve:", err)
os.Exit(1)
}
srv.SuggestPeer(seed.IP, seed.Port, nil)
select {}
}

@ -1,6 +1,7 @@
package rlp package rlp
import ( import (
"bufio"
"encoding/binary" "encoding/binary"
"errors" "errors"
"fmt" "fmt"
@ -24,8 +25,9 @@ type Decoder interface {
DecodeRLP(*Stream) error DecodeRLP(*Stream) error
} }
// Decode parses RLP-encoded data from r and stores the result // Decode parses RLP-encoded data from r and stores the result in the
// in the value pointed to by val. Val must be a non-nil pointer. // value pointed to by val. Val must be a non-nil pointer. If r does
// not implement ByteReader, Decode will do its own buffering.
// //
// Decode uses the following type-dependent decoding rules: // Decode uses the following type-dependent decoding rules:
// //
@ -66,10 +68,19 @@ type Decoder interface {
// //
// Non-empty interface types are not supported, nor are bool, float32, // Non-empty interface types are not supported, nor are bool, float32,
// float64, maps, channel types and functions. // float64, maps, channel types and functions.
func Decode(r ByteReader, val interface{}) error { func Decode(r io.Reader, val interface{}) error {
return NewStream(r).Decode(val) return NewStream(r).Decode(val)
} }
type decodeError struct {
msg string
typ reflect.Type
}
func (err decodeError) Error() string {
return fmt.Sprintf("rlp: %s for %v", err.msg, err.typ)
}
func makeNumDecoder(typ reflect.Type) decoder { func makeNumDecoder(typ reflect.Type) decoder {
kind := typ.Kind() kind := typ.Kind()
switch { switch {
@ -83,8 +94,11 @@ func makeNumDecoder(typ reflect.Type) decoder {
} }
func decodeInt(s *Stream, val reflect.Value) error { func decodeInt(s *Stream, val reflect.Value) error {
num, err := s.uint(val.Type().Bits()) typ := val.Type()
if err != nil { num, err := s.uint(typ.Bits())
if err == errUintOverflow {
return decodeError{"input string too long", typ}
} else if err != nil {
return err return err
} }
val.SetInt(int64(num)) val.SetInt(int64(num))
@ -92,8 +106,11 @@ func decodeInt(s *Stream, val reflect.Value) error {
} }
func decodeUint(s *Stream, val reflect.Value) error { func decodeUint(s *Stream, val reflect.Value) error {
num, err := s.uint(val.Type().Bits()) typ := val.Type()
if err != nil { num, err := s.uint(typ.Bits())
if err == errUintOverflow {
return decodeError{"input string too big", typ}
} else if err != nil {
return err return err
} }
val.SetUint(num) val.SetUint(num)
@ -175,7 +192,7 @@ func decodeList(s *Stream, val reflect.Value, elemdec decoder, maxelem int) erro
i := 0 i := 0
for { for {
if i > maxelem { if i > maxelem {
return fmt.Errorf("rlp: input List has more than %d elements", maxelem) return decodeError{"input list has too many elements", val.Type()}
} }
if val.Kind() == reflect.Slice { if val.Kind() == reflect.Slice {
// grow slice if necessary // grow slice if necessary
@ -226,8 +243,6 @@ func decodeByteSlice(s *Stream, val reflect.Value) error {
return err return err
} }
var errStringDoesntFitArray = errors.New("rlp: string value doesn't fit into target array")
func decodeByteArray(s *Stream, val reflect.Value) error { func decodeByteArray(s *Stream, val reflect.Value) error {
kind, size, err := s.Kind() kind, size, err := s.Kind()
if err != nil { if err != nil {
@ -236,14 +251,14 @@ func decodeByteArray(s *Stream, val reflect.Value) error {
switch kind { switch kind {
case Byte: case Byte:
if val.Len() == 0 { if val.Len() == 0 {
return errStringDoesntFitArray return decodeError{"input string too big", val.Type()}
} }
bv, _ := s.Uint() bv, _ := s.Uint()
val.Index(0).SetUint(bv) val.Index(0).SetUint(bv)
zero(val, 1) zero(val, 1)
case String: case String:
if uint64(val.Len()) < size { if uint64(val.Len()) < size {
return errStringDoesntFitArray return decodeError{"input string too big", val.Type()}
} }
slice := val.Slice(0, int(size)).Interface().([]byte) slice := val.Slice(0, int(size)).Interface().([]byte)
if err := s.readFull(slice); err != nil { if err := s.readFull(slice); err != nil {
@ -293,7 +308,7 @@ func makeStructDecoder(typ reflect.Type) (decoder, error) {
} }
} }
if err = s.ListEnd(); err == errNotAtEOL { if err = s.ListEnd(); err == errNotAtEOL {
err = errors.New("rlp: input List has too many elements") err = decodeError{"input list has too many elements", typ}
} }
return err return err
} }
@ -432,8 +447,23 @@ type Stream struct {
type listpos struct{ pos, size uint64 } type listpos struct{ pos, size uint64 }
func NewStream(r ByteReader) *Stream { // NewStream creates a new stream reading from r.
return &Stream{r: r, uintbuf: make([]byte, 8), kind: -1} // If r does not implement ByteReader, the Stream will
// introduce its own buffering.
func NewStream(r io.Reader) *Stream {
s := new(Stream)
s.Reset(r)
return s
}
// NewListStream creates a new stream that pretends to be positioned
// at an encoded list of the given length.
func NewListStream(r io.Reader, len uint64) *Stream {
s := new(Stream)
s.Reset(r)
s.kind = List
s.size = len
return s
} }
// Bytes reads an RLP string and returns its contents as a byte slice. // Bytes reads an RLP string and returns its contents as a byte slice.
@ -459,6 +489,8 @@ func (s *Stream) Bytes() ([]byte, error) {
} }
} }
var errUintOverflow = errors.New("rlp: uint overflow")
// Uint reads an RLP string of up to 8 bytes and returns its contents // Uint reads an RLP string of up to 8 bytes and returns its contents
// as an unsigned integer. If the input does not contain an RLP string, the // as an unsigned integer. If the input does not contain an RLP string, the
// returned error will be ErrExpectedString. // returned error will be ErrExpectedString.
@ -477,7 +509,7 @@ func (s *Stream) uint(maxbits int) (uint64, error) {
return uint64(s.byteval), nil return uint64(s.byteval), nil
case String: case String:
if size > uint64(maxbits/8) { if size > uint64(maxbits/8) {
return 0, fmt.Errorf("rlp: string is larger than %d bits", maxbits) return 0, errUintOverflow
} }
return s.readUint(byte(size)) return s.readUint(byte(size))
default: default:
@ -543,6 +575,23 @@ func (s *Stream) Decode(val interface{}) error {
return info.decoder(s, rval.Elem()) return info.decoder(s, rval.Elem())
} }
// Reset discards any information about the current decoding context
// and starts reading from r. If r does not also implement ByteReader,
// Stream will do its own buffering.
func (s *Stream) Reset(r io.Reader) {
bufr, ok := r.(ByteReader)
if !ok {
bufr = bufio.NewReader(r)
}
s.r = bufr
s.stack = s.stack[:0]
s.size = 0
s.kind = -1
if s.uintbuf == nil {
s.uintbuf = make([]byte, 8)
}
}
// Kind returns the kind and size of the next value in the // Kind returns the kind and size of the next value in the
// input stream. // input stream.
// //

@ -3,7 +3,6 @@ package rlp
import ( import (
"bytes" "bytes"
"encoding/hex" "encoding/hex"
"errors"
"fmt" "fmt"
"io" "io"
"math/big" "math/big"
@ -54,6 +53,24 @@ func TestStreamKind(t *testing.T) {
} }
} }
func TestNewListStream(t *testing.T) {
ls := NewListStream(bytes.NewReader(unhex("0101010101")), 3)
if k, size, err := ls.Kind(); k != List || size != 3 || err != nil {
t.Errorf("Kind() returned (%v, %d, %v), expected (List, 3, nil)", k, size, err)
}
if size, err := ls.List(); size != 3 || err != nil {
t.Errorf("List() returned (%d, %v), expected (3, nil)", size, err)
}
for i := 0; i < 3; i++ {
if val, err := ls.Uint(); val != 1 || err != nil {
t.Errorf("Uint() returned (%d, %v), expected (1, nil)", val, err)
}
}
if err := ls.ListEnd(); err != nil {
t.Errorf("ListEnd() returned %v, expected (3, nil)", err)
}
}
func TestStreamErrors(t *testing.T) { func TestStreamErrors(t *testing.T) {
type calls []string type calls []string
tests := []struct { tests := []struct {
@ -69,7 +86,7 @@ func TestStreamErrors(t *testing.T) {
{"81", calls{"Bytes"}, io.ErrUnexpectedEOF}, {"81", calls{"Bytes"}, io.ErrUnexpectedEOF},
{"81", calls{"Uint"}, io.ErrUnexpectedEOF}, {"81", calls{"Uint"}, io.ErrUnexpectedEOF},
{"BFFFFFFFFFFFFFFF", calls{"Bytes"}, io.ErrUnexpectedEOF}, {"BFFFFFFFFFFFFFFF", calls{"Bytes"}, io.ErrUnexpectedEOF},
{"89000000000000000001", calls{"Uint"}, errors.New("rlp: string is larger than 64 bits")}, {"89000000000000000001", calls{"Uint"}, errUintOverflow},
{"00", calls{"List"}, ErrExpectedList}, {"00", calls{"List"}, ErrExpectedList},
{"80", calls{"List"}, ErrExpectedList}, {"80", calls{"List"}, ErrExpectedList},
{"C0", calls{"List", "Uint"}, EOL}, {"C0", calls{"List", "Uint"}, EOL},
@ -163,7 +180,7 @@ type decodeTest struct {
input string input string
ptr interface{} ptr interface{}
value interface{} value interface{}
error error error string
} }
type simplestruct struct { type simplestruct struct {
@ -196,8 +213,8 @@ var decodeTests = []decodeTest{
{input: "820505", ptr: new(uint32), value: uint32(0x0505)}, {input: "820505", ptr: new(uint32), value: uint32(0x0505)},
{input: "83050505", ptr: new(uint32), value: uint32(0x050505)}, {input: "83050505", ptr: new(uint32), value: uint32(0x050505)},
{input: "8405050505", ptr: new(uint32), value: uint32(0x05050505)}, {input: "8405050505", ptr: new(uint32), value: uint32(0x05050505)},
{input: "850505050505", ptr: new(uint32), error: errors.New("rlp: string is larger than 32 bits")}, {input: "850505050505", ptr: new(uint32), error: "rlp: input string too big for uint32"},
{input: "C0", ptr: new(uint32), error: ErrExpectedString}, {input: "C0", ptr: new(uint32), error: ErrExpectedString.Error()},
// slices // slices
{input: "C0", ptr: new([]int), value: []int{}}, {input: "C0", ptr: new([]int), value: []int{}},
@ -206,7 +223,7 @@ var decodeTests = []decodeTest{
// arrays // arrays
{input: "C0", ptr: new([5]int), value: [5]int{}}, {input: "C0", ptr: new([5]int), value: [5]int{}},
{input: "C50102030405", ptr: new([5]int), value: [5]int{1, 2, 3, 4, 5}}, {input: "C50102030405", ptr: new([5]int), value: [5]int{1, 2, 3, 4, 5}},
{input: "C6010203040506", ptr: new([5]int), error: errors.New("rlp: input List has more than 5 elements")}, {input: "C6010203040506", ptr: new([5]int), error: "rlp: input list has too many elements for [5]int"},
// byte slices // byte slices
{input: "01", ptr: new([]byte), value: []byte{1}}, {input: "01", ptr: new([]byte), value: []byte{1}},
@ -214,7 +231,7 @@ var decodeTests = []decodeTest{
{input: "8D6162636465666768696A6B6C6D", ptr: new([]byte), value: []byte("abcdefghijklm")}, {input: "8D6162636465666768696A6B6C6D", ptr: new([]byte), value: []byte("abcdefghijklm")},
{input: "C0", ptr: new([]byte), value: []byte{}}, {input: "C0", ptr: new([]byte), value: []byte{}},
{input: "C3010203", ptr: new([]byte), value: []byte{1, 2, 3}}, {input: "C3010203", ptr: new([]byte), value: []byte{1, 2, 3}},
{input: "C3820102", ptr: new([]byte), error: errors.New("rlp: string is larger than 8 bits")}, {input: "C3820102", ptr: new([]byte), error: "rlp: input string too big for uint8"},
// byte arrays // byte arrays
{input: "01", ptr: new([5]byte), value: [5]byte{1}}, {input: "01", ptr: new([5]byte), value: [5]byte{1}},
@ -222,9 +239,9 @@ var decodeTests = []decodeTest{
{input: "850102030405", ptr: new([5]byte), value: [5]byte{1, 2, 3, 4, 5}}, {input: "850102030405", ptr: new([5]byte), value: [5]byte{1, 2, 3, 4, 5}},
{input: "C0", ptr: new([5]byte), value: [5]byte{}}, {input: "C0", ptr: new([5]byte), value: [5]byte{}},
{input: "C3010203", ptr: new([5]byte), value: [5]byte{1, 2, 3, 0, 0}}, {input: "C3010203", ptr: new([5]byte), value: [5]byte{1, 2, 3, 0, 0}},
{input: "C3820102", ptr: new([5]byte), error: errors.New("rlp: string is larger than 8 bits")}, {input: "C3820102", ptr: new([5]byte), error: "rlp: input string too big for uint8"},
{input: "86010203040506", ptr: new([5]byte), error: errStringDoesntFitArray}, {input: "86010203040506", ptr: new([5]byte), error: "rlp: input string too big for [5]uint8"},
{input: "850101", ptr: new([5]byte), error: io.ErrUnexpectedEOF}, {input: "850101", ptr: new([5]byte), error: io.ErrUnexpectedEOF.Error()},
// byte array reuse (should be zeroed) // byte array reuse (should be zeroed)
{input: "850102030405", ptr: &sharedByteArray, value: [5]byte{1, 2, 3, 4, 5}}, {input: "850102030405", ptr: &sharedByteArray, value: [5]byte{1, 2, 3, 4, 5}},
@ -237,25 +254,25 @@ var decodeTests = []decodeTest{
// zero sized byte arrays // zero sized byte arrays
{input: "80", ptr: new([0]byte), value: [0]byte{}}, {input: "80", ptr: new([0]byte), value: [0]byte{}},
{input: "C0", ptr: new([0]byte), value: [0]byte{}}, {input: "C0", ptr: new([0]byte), value: [0]byte{}},
{input: "01", ptr: new([0]byte), error: errStringDoesntFitArray}, {input: "01", ptr: new([0]byte), error: "rlp: input string too big for [0]uint8"},
{input: "8101", ptr: new([0]byte), error: errStringDoesntFitArray}, {input: "8101", ptr: new([0]byte), error: "rlp: input string too big for [0]uint8"},
// strings // strings
{input: "00", ptr: new(string), value: "\000"}, {input: "00", ptr: new(string), value: "\000"},
{input: "8D6162636465666768696A6B6C6D", ptr: new(string), value: "abcdefghijklm"}, {input: "8D6162636465666768696A6B6C6D", ptr: new(string), value: "abcdefghijklm"},
{input: "C0", ptr: new(string), error: ErrExpectedString}, {input: "C0", ptr: new(string), error: ErrExpectedString.Error()},
// big ints // big ints
{input: "01", ptr: new(*big.Int), value: big.NewInt(1)}, {input: "01", ptr: new(*big.Int), value: big.NewInt(1)},
{input: "89FFFFFFFFFFFFFFFFFF", ptr: new(*big.Int), value: veryBigInt}, {input: "89FFFFFFFFFFFFFFFFFF", ptr: new(*big.Int), value: veryBigInt},
{input: "10", ptr: new(big.Int), value: *big.NewInt(16)}, // non-pointer also works {input: "10", ptr: new(big.Int), value: *big.NewInt(16)}, // non-pointer also works
{input: "C0", ptr: new(*big.Int), error: ErrExpectedString}, {input: "C0", ptr: new(*big.Int), error: ErrExpectedString.Error()},
// structs // structs
{input: "C0", ptr: new(simplestruct), value: simplestruct{0, ""}}, {input: "C0", ptr: new(simplestruct), value: simplestruct{0, ""}},
{input: "C105", ptr: new(simplestruct), value: simplestruct{5, ""}}, {input: "C105", ptr: new(simplestruct), value: simplestruct{5, ""}},
{input: "C50583343434", ptr: new(simplestruct), value: simplestruct{5, "444"}}, {input: "C50583343434", ptr: new(simplestruct), value: simplestruct{5, "444"}},
{input: "C3010101", ptr: new(simplestruct), error: errors.New("rlp: input List has too many elements")}, {input: "C3010101", ptr: new(simplestruct), error: "rlp: input list has too many elements for rlp.simplestruct"},
{ {
input: "C501C302C103", input: "C501C302C103",
ptr: new(recstruct), ptr: new(recstruct),
@ -286,20 +303,20 @@ var decodeTests = []decodeTest{
func intp(i int) *int { return &i } func intp(i int) *int { return &i }
func TestDecode(t *testing.T) { func runTests(t *testing.T, decode func([]byte, interface{}) error) {
for i, test := range decodeTests { for i, test := range decodeTests {
input, err := hex.DecodeString(test.input) input, err := hex.DecodeString(test.input)
if err != nil { if err != nil {
t.Errorf("test %d: invalid hex input %q", i, test.input) t.Errorf("test %d: invalid hex input %q", i, test.input)
continue continue
} }
err = Decode(bytes.NewReader(input), test.ptr) err = decode(input, test.ptr)
if err != nil && test.error == nil { if err != nil && test.error == "" {
t.Errorf("test %d: unexpected Decode error: %v\ndecoding into %T\ninput %q", t.Errorf("test %d: unexpected Decode error: %v\ndecoding into %T\ninput %q",
i, err, test.ptr, test.input) i, err, test.ptr, test.input)
continue continue
} }
if test.error != nil && fmt.Sprint(err) != fmt.Sprint(test.error) { if test.error != "" && fmt.Sprint(err) != test.error {
t.Errorf("test %d: Decode error mismatch\ngot %v\nwant %v\ndecoding into %T\ninput %q", t.Errorf("test %d: Decode error mismatch\ngot %v\nwant %v\ndecoding into %T\ninput %q",
i, err, test.error, test.ptr, test.input) i, err, test.error, test.ptr, test.input)
continue continue
@ -312,6 +329,40 @@ func TestDecode(t *testing.T) {
} }
} }
func TestDecodeWithByteReader(t *testing.T) {
runTests(t, func(input []byte, into interface{}) error {
return Decode(bytes.NewReader(input), into)
})
}
// dumbReader reads from a byte slice but does not
// implement ReadByte.
type dumbReader []byte
func (r *dumbReader) Read(buf []byte) (n int, err error) {
if len(*r) == 0 {
return 0, io.EOF
}
n = copy(buf, *r)
*r = (*r)[n:]
return n, nil
}
func TestDecodeWithNonByteReader(t *testing.T) {
runTests(t, func(input []byte, into interface{}) error {
r := dumbReader(input)
return Decode(&r, into)
})
}
func TestDecodeStreamReset(t *testing.T) {
s := NewStream(nil)
runTests(t, func(input []byte, into interface{}) error {
s.Reset(bytes.NewReader(input))
return s.Decode(into)
})
}
type testDecoder struct{ called bool } type testDecoder struct{ called bool }
func (t *testDecoder) DecodeRLP(s *Stream) error { func (t *testDecoder) DecodeRLP(s *Stream) error {

Loading…
Cancel
Save