mirror of https://github.com/ethereum/go-ethereum
commit
797b93c98c
@ -0,0 +1,63 @@ |
||||
package p2p |
||||
|
||||
import ( |
||||
"fmt" |
||||
"runtime" |
||||
) |
||||
|
||||
// should be used in Peer handleHandshake, incorporate Caps, ProtocolVersion, Pubkey etc.
|
||||
type ClientIdentity interface { |
||||
String() string |
||||
Pubkey() []byte |
||||
} |
||||
|
||||
type SimpleClientIdentity struct { |
||||
clientIdentifier string |
||||
version string |
||||
customIdentifier string |
||||
os string |
||||
implementation string |
||||
pubkey string |
||||
} |
||||
|
||||
func NewSimpleClientIdentity(clientIdentifier string, version string, customIdentifier string, pubkey string) *SimpleClientIdentity { |
||||
clientIdentity := &SimpleClientIdentity{ |
||||
clientIdentifier: clientIdentifier, |
||||
version: version, |
||||
customIdentifier: customIdentifier, |
||||
os: runtime.GOOS, |
||||
implementation: runtime.Version(), |
||||
pubkey: pubkey, |
||||
} |
||||
|
||||
return clientIdentity |
||||
} |
||||
|
||||
func (c *SimpleClientIdentity) init() { |
||||
} |
||||
|
||||
func (c *SimpleClientIdentity) String() string { |
||||
var id string |
||||
if len(c.customIdentifier) > 0 { |
||||
id = "/" + c.customIdentifier |
||||
} |
||||
|
||||
return fmt.Sprintf("%s/v%s%s/%s/%s", |
||||
c.clientIdentifier, |
||||
c.version, |
||||
id, |
||||
c.os, |
||||
c.implementation) |
||||
} |
||||
|
||||
func (c *SimpleClientIdentity) Pubkey() []byte { |
||||
return []byte(c.pubkey) |
||||
} |
||||
|
||||
func (c *SimpleClientIdentity) SetCustomIdentifier(customIdentifier string) { |
||||
c.customIdentifier = customIdentifier |
||||
} |
||||
|
||||
func (c *SimpleClientIdentity) GetCustomIdentifier() string { |
||||
return c.customIdentifier |
||||
} |
@ -0,0 +1,30 @@ |
||||
package p2p |
||||
|
||||
import ( |
||||
"fmt" |
||||
"runtime" |
||||
"testing" |
||||
) |
||||
|
||||
func TestClientIdentity(t *testing.T) { |
||||
clientIdentity := NewSimpleClientIdentity("Ethereum(G)", "0.5.16", "test", "pubkey") |
||||
clientString := clientIdentity.String() |
||||
expected := fmt.Sprintf("Ethereum(G)/v0.5.16/test/%s/%s", runtime.GOOS, runtime.Version()) |
||||
if clientString != expected { |
||||
t.Errorf("Expected clientIdentity to be %v, got %v", expected, clientString) |
||||
} |
||||
customIdentifier := clientIdentity.GetCustomIdentifier() |
||||
if customIdentifier != "test" { |
||||
t.Errorf("Expected clientIdentity.GetCustomIdentifier() to be 'test', got %v", customIdentifier) |
||||
} |
||||
clientIdentity.SetCustomIdentifier("test2") |
||||
customIdentifier = clientIdentity.GetCustomIdentifier() |
||||
if customIdentifier != "test2" { |
||||
t.Errorf("Expected clientIdentity.GetCustomIdentifier() to be 'test2', got %v", customIdentifier) |
||||
} |
||||
clientString = clientIdentity.String() |
||||
expected = fmt.Sprintf("Ethereum(G)/v0.5.16/test2/%s/%s", runtime.GOOS, runtime.Version()) |
||||
if clientString != expected { |
||||
t.Errorf("Expected clientIdentity to be %v, got %v", expected, clientString) |
||||
} |
||||
} |
@ -0,0 +1,275 @@ |
||||
package p2p |
||||
|
||||
import ( |
||||
"bytes" |
||||
// "fmt"
|
||||
"net" |
||||
"time" |
||||
|
||||
"github.com/ethereum/eth-go/ethutil" |
||||
) |
||||
|
||||
type Connection struct { |
||||
conn net.Conn |
||||
// conn NetworkConnection
|
||||
timeout time.Duration |
||||
in chan []byte |
||||
out chan []byte |
||||
err chan *PeerError |
||||
closingIn chan chan bool |
||||
closingOut chan chan bool |
||||
} |
||||
|
||||
// const readBufferLength = 2 //for testing
|
||||
|
||||
const readBufferLength = 1440 |
||||
const partialsQueueSize = 10 |
||||
const maxPendingQueueSize = 1 |
||||
const defaultTimeout = 500 |
||||
|
||||
var magicToken = []byte{34, 64, 8, 145} |
||||
|
||||
func (self *Connection) Open() { |
||||
go self.startRead() |
||||
go self.startWrite() |
||||
} |
||||
|
||||
func (self *Connection) Close() { |
||||
self.closeIn() |
||||
self.closeOut() |
||||
} |
||||
|
||||
func (self *Connection) closeIn() { |
||||
errc := make(chan bool) |
||||
self.closingIn <- errc |
||||
<-errc |
||||
} |
||||
|
||||
func (self *Connection) closeOut() { |
||||
errc := make(chan bool) |
||||
self.closingOut <- errc |
||||
<-errc |
||||
} |
||||
|
||||
func NewConnection(conn net.Conn, errchan chan *PeerError) *Connection { |
||||
return &Connection{ |
||||
conn: conn, |
||||
timeout: defaultTimeout, |
||||
in: make(chan []byte), |
||||
out: make(chan []byte), |
||||
err: errchan, |
||||
closingIn: make(chan chan bool, 1), |
||||
closingOut: make(chan chan bool, 1), |
||||
} |
||||
} |
||||
|
||||
func (self *Connection) Read() <-chan []byte { |
||||
return self.in |
||||
} |
||||
|
||||
func (self *Connection) Write() chan<- []byte { |
||||
return self.out |
||||
} |
||||
|
||||
func (self *Connection) Error() <-chan *PeerError { |
||||
return self.err |
||||
} |
||||
|
||||
func (self *Connection) startRead() { |
||||
payloads := make(chan []byte) |
||||
done := make(chan *PeerError) |
||||
pending := [][]byte{} |
||||
var head []byte |
||||
var wait time.Duration // initally 0 (no delay)
|
||||
read := time.After(wait * time.Millisecond) |
||||
|
||||
for { |
||||
// if pending empty, nil channel blocks
|
||||
var in chan []byte |
||||
if len(pending) > 0 { |
||||
in = self.in // enable send case
|
||||
head = pending[0] |
||||
} else { |
||||
in = nil |
||||
} |
||||
|
||||
select { |
||||
case <-read: |
||||
go self.read(payloads, done) |
||||
case err := <-done: |
||||
if err == nil { // no error but nothing to read
|
||||
if len(pending) < maxPendingQueueSize { |
||||
wait = 100 |
||||
} else if wait == 0 { |
||||
wait = 100 |
||||
} else { |
||||
wait = 2 * wait |
||||
} |
||||
} else { |
||||
self.err <- err // report error
|
||||
wait = 100 |
||||
} |
||||
read = time.After(wait * time.Millisecond) |
||||
case payload := <-payloads: |
||||
pending = append(pending, payload) |
||||
if len(pending) < maxPendingQueueSize { |
||||
wait = 0 |
||||
} else { |
||||
wait = 100 |
||||
} |
||||
read = time.After(wait * time.Millisecond) |
||||
case in <- head: |
||||
pending = pending[1:] |
||||
case errc := <-self.closingIn: |
||||
errc <- true |
||||
close(self.in) |
||||
return |
||||
} |
||||
|
||||
} |
||||
} |
||||
|
||||
func (self *Connection) startWrite() { |
||||
pending := [][]byte{} |
||||
done := make(chan *PeerError) |
||||
writing := false |
||||
for { |
||||
if len(pending) > 0 && !writing { |
||||
writing = true |
||||
go self.write(pending[0], done) |
||||
} |
||||
select { |
||||
case payload := <-self.out: |
||||
pending = append(pending, payload) |
||||
case err := <-done: |
||||
if err == nil { |
||||
pending = pending[1:] |
||||
writing = false |
||||
} else { |
||||
self.err <- err // report error
|
||||
} |
||||
case errc := <-self.closingOut: |
||||
errc <- true |
||||
close(self.out) |
||||
return |
||||
} |
||||
} |
||||
} |
||||
|
||||
func pack(payload []byte) (packet []byte) { |
||||
length := ethutil.NumberToBytes(uint32(len(payload)), 32) |
||||
// return error if too long?
|
||||
// Write magic token and payload length (first 8 bytes)
|
||||
packet = append(magicToken, length...) |
||||
packet = append(packet, payload...) |
||||
return |
||||
} |
||||
|
||||
func avoidPanic(done chan *PeerError) { |
||||
if rec := recover(); rec != nil { |
||||
err := NewPeerError(MiscError, " %v", rec) |
||||
logger.Debugln(err) |
||||
done <- err |
||||
} |
||||
} |
||||
|
||||
func (self *Connection) write(payload []byte, done chan *PeerError) { |
||||
defer avoidPanic(done) |
||||
var err *PeerError |
||||
_, ok := self.conn.Write(pack(payload)) |
||||
if ok != nil { |
||||
err = NewPeerError(WriteError, " %v", ok) |
||||
logger.Debugln(err) |
||||
} |
||||
done <- err |
||||
} |
||||
|
||||
func (self *Connection) read(payloads chan []byte, done chan *PeerError) { |
||||
//defer avoidPanic(done)
|
||||
|
||||
partials := make(chan []byte, partialsQueueSize) |
||||
errc := make(chan *PeerError) |
||||
go self.readPartials(partials, errc) |
||||
|
||||
packet := []byte{} |
||||
length := 8 |
||||
start := true |
||||
var err *PeerError |
||||
out: |
||||
for { |
||||
// appends partials read via connection until packet is
|
||||
// - either parseable (>=8bytes)
|
||||
// - or complete (payload fully consumed)
|
||||
for len(packet) < length { |
||||
partial, ok := <-partials |
||||
if !ok { // partials channel is closed
|
||||
err = <-errc |
||||
if err == nil && len(packet) > 0 { |
||||
if start { |
||||
err = NewPeerError(PacketTooShort, "%v", packet) |
||||
} else { |
||||
err = NewPeerError(PayloadTooShort, "%d < %d", len(packet), length) |
||||
} |
||||
} |
||||
break out |
||||
} |
||||
packet = append(packet, partial...) |
||||
} |
||||
if start { |
||||
// at least 8 bytes read, can validate packet
|
||||
if bytes.Compare(magicToken, packet[:4]) != 0 { |
||||
err = NewPeerError(MagicTokenMismatch, " received %v", packet[:4]) |
||||
break |
||||
} |
||||
length = int(ethutil.BytesToNumber(packet[4:8])) |
||||
packet = packet[8:] |
||||
|
||||
if length > 0 { |
||||
start = false // now consuming payload
|
||||
} else { //penalize peer but read on
|
||||
self.err <- NewPeerError(EmptyPayload, "") |
||||
length = 8 |
||||
} |
||||
} else { |
||||
// packet complete (payload fully consumed)
|
||||
payloads <- packet[:length] |
||||
packet = packet[length:] // resclice packet
|
||||
start = true |
||||
length = 8 |
||||
} |
||||
} |
||||
|
||||
// this stops partials read via the connection, should we?
|
||||
//if err != nil {
|
||||
// select {
|
||||
// case errc <- err
|
||||
// default:
|
||||
//}
|
||||
done <- err |
||||
} |
||||
|
||||
func (self *Connection) readPartials(partials chan []byte, errc chan *PeerError) { |
||||
defer close(partials) |
||||
for { |
||||
// Give buffering some time
|
||||
self.conn.SetReadDeadline(time.Now().Add(self.timeout * time.Millisecond)) |
||||
buffer := make([]byte, readBufferLength) |
||||
// read partial from connection
|
||||
bytesRead, err := self.conn.Read(buffer) |
||||
if err == nil || err.Error() == "EOF" { |
||||
if bytesRead > 0 { |
||||
partials <- buffer[:bytesRead] |
||||
} |
||||
if err != nil && err.Error() == "EOF" { |
||||
break |
||||
} |
||||
} else { |
||||
// unexpected error, report to errc
|
||||
err := NewPeerError(ReadError, " %v", err) |
||||
logger.Debugln(err) |
||||
errc <- err |
||||
return // will close partials channel
|
||||
} |
||||
} |
||||
close(errc) |
||||
} |
@ -0,0 +1,222 @@ |
||||
package p2p |
||||
|
||||
import ( |
||||
"bytes" |
||||
"fmt" |
||||
"io" |
||||
"net" |
||||
"testing" |
||||
"time" |
||||
) |
||||
|
||||
type TestNetworkConnection struct { |
||||
in chan []byte |
||||
current []byte |
||||
Out [][]byte |
||||
addr net.Addr |
||||
} |
||||
|
||||
func NewTestNetworkConnection(addr net.Addr) *TestNetworkConnection { |
||||
return &TestNetworkConnection{ |
||||
in: make(chan []byte), |
||||
current: []byte{}, |
||||
Out: [][]byte{}, |
||||
addr: addr, |
||||
} |
||||
} |
||||
|
||||
func (self *TestNetworkConnection) In(latency time.Duration, packets ...[]byte) { |
||||
time.Sleep(latency) |
||||
for _, s := range packets { |
||||
self.in <- s |
||||
} |
||||
} |
||||
|
||||
func (self *TestNetworkConnection) Read(buff []byte) (n int, err error) { |
||||
if len(self.current) == 0 { |
||||
select { |
||||
case self.current = <-self.in: |
||||
default: |
||||
return 0, io.EOF |
||||
} |
||||
} |
||||
length := len(self.current) |
||||
if length > len(buff) { |
||||
copy(buff[:], self.current[:len(buff)]) |
||||
self.current = self.current[len(buff):] |
||||
return len(buff), nil |
||||
} else { |
||||
copy(buff[:length], self.current[:]) |
||||
self.current = []byte{} |
||||
return length, io.EOF |
||||
} |
||||
} |
||||
|
||||
func (self *TestNetworkConnection) Write(buff []byte) (n int, err error) { |
||||
self.Out = append(self.Out, buff) |
||||
fmt.Printf("net write %v\n%v\n", len(self.Out), buff) |
||||
return len(buff), nil |
||||
} |
||||
|
||||
func (self *TestNetworkConnection) Close() (err error) { |
||||
return |
||||
} |
||||
|
||||
func (self *TestNetworkConnection) LocalAddr() (addr net.Addr) { |
||||
return |
||||
} |
||||
|
||||
func (self *TestNetworkConnection) RemoteAddr() (addr net.Addr) { |
||||
return self.addr |
||||
} |
||||
|
||||
func (self *TestNetworkConnection) SetDeadline(t time.Time) (err error) { |
||||
return |
||||
} |
||||
|
||||
func (self *TestNetworkConnection) SetReadDeadline(t time.Time) (err error) { |
||||
return |
||||
} |
||||
|
||||
func (self *TestNetworkConnection) SetWriteDeadline(t time.Time) (err error) { |
||||
return |
||||
} |
||||
|
||||
func setupConnection() (*Connection, *TestNetworkConnection) { |
||||
addr := &TestAddr{"test:30303"} |
||||
net := NewTestNetworkConnection(addr) |
||||
conn := NewConnection(net, NewPeerErrorChannel()) |
||||
conn.Open() |
||||
return conn, net |
||||
} |
||||
|
||||
func TestReadingNilPacket(t *testing.T) { |
||||
conn, net := setupConnection() |
||||
go net.In(0, []byte{}) |
||||
// time.Sleep(10 * time.Millisecond)
|
||||
select { |
||||
case packet := <-conn.Read(): |
||||
t.Errorf("read %v", packet) |
||||
case err := <-conn.Error(): |
||||
t.Errorf("incorrect error %v", err) |
||||
default: |
||||
} |
||||
conn.Close() |
||||
} |
||||
|
||||
func TestReadingShortPacket(t *testing.T) { |
||||
conn, net := setupConnection() |
||||
go net.In(0, []byte{0}) |
||||
select { |
||||
case packet := <-conn.Read(): |
||||
t.Errorf("read %v", packet) |
||||
case err := <-conn.Error(): |
||||
if err.Code != PacketTooShort { |
||||
t.Errorf("incorrect error %v, expected %v", err.Code, PacketTooShort) |
||||
} |
||||
} |
||||
conn.Close() |
||||
} |
||||
|
||||
func TestReadingInvalidPacket(t *testing.T) { |
||||
conn, net := setupConnection() |
||||
go net.In(0, []byte{1, 0, 0, 0, 0, 0, 0, 0}) |
||||
select { |
||||
case packet := <-conn.Read(): |
||||
t.Errorf("read %v", packet) |
||||
case err := <-conn.Error(): |
||||
if err.Code != MagicTokenMismatch { |
||||
t.Errorf("incorrect error %v, expected %v", err.Code, MagicTokenMismatch) |
||||
} |
||||
} |
||||
conn.Close() |
||||
} |
||||
|
||||
func TestReadingInvalidPayload(t *testing.T) { |
||||
conn, net := setupConnection() |
||||
go net.In(0, []byte{34, 64, 8, 145, 0, 0, 0, 2, 0}) |
||||
select { |
||||
case packet := <-conn.Read(): |
||||
t.Errorf("read %v", packet) |
||||
case err := <-conn.Error(): |
||||
if err.Code != PayloadTooShort { |
||||
t.Errorf("incorrect error %v, expected %v", err.Code, PayloadTooShort) |
||||
} |
||||
} |
||||
conn.Close() |
||||
} |
||||
|
||||
func TestReadingEmptyPayload(t *testing.T) { |
||||
conn, net := setupConnection() |
||||
go net.In(0, []byte{34, 64, 8, 145, 0, 0, 0, 0}) |
||||
time.Sleep(10 * time.Millisecond) |
||||
select { |
||||
case packet := <-conn.Read(): |
||||
t.Errorf("read %v", packet) |
||||
default: |
||||
} |
||||
select { |
||||
case err := <-conn.Error(): |
||||
code := err.Code |
||||
if code != EmptyPayload { |
||||
t.Errorf("incorrect error, expected EmptyPayload, got %v", code) |
||||
} |
||||
default: |
||||
t.Errorf("no error, expected EmptyPayload") |
||||
} |
||||
conn.Close() |
||||
} |
||||
|
||||
func TestReadingCompletePacket(t *testing.T) { |
||||
conn, net := setupConnection() |
||||
go net.In(0, []byte{34, 64, 8, 145, 0, 0, 0, 1, 1}) |
||||
time.Sleep(10 * time.Millisecond) |
||||
select { |
||||
case packet := <-conn.Read(): |
||||
if bytes.Compare(packet, []byte{1}) != 0 { |
||||
t.Errorf("incorrect payload read") |
||||
} |
||||
case err := <-conn.Error(): |
||||
t.Errorf("incorrect error %v", err) |
||||
default: |
||||
t.Errorf("nothing read") |
||||
} |
||||
conn.Close() |
||||
} |
||||
|
||||
func TestReadingTwoCompletePackets(t *testing.T) { |
||||
conn, net := setupConnection() |
||||
go net.In(0, []byte{34, 64, 8, 145, 0, 0, 0, 1, 0, 34, 64, 8, 145, 0, 0, 0, 1, 1}) |
||||
|
||||
for i := 0; i < 2; i++ { |
||||
time.Sleep(10 * time.Millisecond) |
||||
select { |
||||
case packet := <-conn.Read(): |
||||
if bytes.Compare(packet, []byte{byte(i)}) != 0 { |
||||
t.Errorf("incorrect payload read") |
||||
} |
||||
case err := <-conn.Error(): |
||||
t.Errorf("incorrect error %v", err) |
||||
default: |
||||
t.Errorf("nothing read") |
||||
} |
||||
} |
||||
conn.Close() |
||||
} |
||||
|
||||
func TestWriting(t *testing.T) { |
||||
conn, net := setupConnection() |
||||
conn.Write() <- []byte{0} |
||||
time.Sleep(10 * time.Millisecond) |
||||
if len(net.Out) == 0 { |
||||
t.Errorf("no output") |
||||
} else { |
||||
out := net.Out[0] |
||||
if bytes.Compare(out, []byte{34, 64, 8, 145, 0, 0, 0, 1, 0}) != 0 { |
||||
t.Errorf("incorrect packet %v", out) |
||||
} |
||||
} |
||||
conn.Close() |
||||
} |
||||
|
||||
// hello packet with client id ABC: 0x22 40 08 91 00 00 00 08 84 00 00 00 43414243
|
@ -0,0 +1,75 @@ |
||||
package p2p |
||||
|
||||
import ( |
||||
// "fmt"
|
||||
"github.com/ethereum/eth-go/ethutil" |
||||
) |
||||
|
||||
type MsgCode uint8 |
||||
|
||||
type Msg struct { |
||||
code MsgCode // this is the raw code as per adaptive msg code scheme
|
||||
data *ethutil.Value |
||||
encoded []byte |
||||
} |
||||
|
||||
func (self *Msg) Code() MsgCode { |
||||
return self.code |
||||
} |
||||
|
||||
func (self *Msg) Data() *ethutil.Value { |
||||
return self.data |
||||
} |
||||
|
||||
func NewMsg(code MsgCode, params ...interface{}) (msg *Msg, err error) { |
||||
|
||||
// // data := [][]interface{}{}
|
||||
// data := []interface{}{}
|
||||
// for _, value := range params {
|
||||
// if encodable, ok := value.(ethutil.RlpEncodeDecode); ok {
|
||||
// data = append(data, encodable.RlpValue())
|
||||
// } else if raw, ok := value.([]interface{}); ok {
|
||||
// data = append(data, raw)
|
||||
// } else {
|
||||
// // data = append(data, interface{}(raw))
|
||||
// err = fmt.Errorf("Unable to encode object of type %T", value)
|
||||
// return
|
||||
// }
|
||||
// }
|
||||
return &Msg{ |
||||
code: code, |
||||
data: ethutil.NewValue(interface{}(params)), |
||||
}, nil |
||||
} |
||||
|
||||
func NewMsgFromBytes(encoded []byte) (msg *Msg, err error) { |
||||
value := ethutil.NewValueFromBytes(encoded) |
||||
// Type of message
|
||||
code := value.Get(0).Uint() |
||||
// Actual data
|
||||
data := value.SliceFrom(1) |
||||
|
||||
msg = &Msg{ |
||||
code: MsgCode(code), |
||||
data: data, |
||||
// data: ethutil.NewValue(data),
|
||||
encoded: encoded, |
||||
} |
||||
return |
||||
} |
||||
|
||||
func (self *Msg) Decode(offset MsgCode) { |
||||
self.code = self.code - offset |
||||
} |
||||
|
||||
// encode takes an offset argument to implement adaptive message coding
|
||||
// the encoded message is memoized to make msgs relayed to several peers more efficient
|
||||
func (self *Msg) Encode(offset MsgCode) (res []byte) { |
||||
if len(self.encoded) == 0 { |
||||
res = ethutil.NewValue(append([]interface{}{byte(self.code + offset)}, self.data.Slice()...)).Encode() |
||||
self.encoded = res |
||||
} else { |
||||
res = self.encoded |
||||
} |
||||
return |
||||
} |
@ -0,0 +1,38 @@ |
||||
package p2p |
||||
|
||||
import ( |
||||
"testing" |
||||
) |
||||
|
||||
func TestNewMsg(t *testing.T) { |
||||
msg, _ := NewMsg(3, 1, "000") |
||||
if msg.Code() != 3 { |
||||
t.Errorf("incorrect code %v", msg.Code()) |
||||
} |
||||
data0 := msg.Data().Get(0).Uint() |
||||
data1 := string(msg.Data().Get(1).Bytes()) |
||||
if data0 != 1 { |
||||
t.Errorf("incorrect data %v", data0) |
||||
} |
||||
if data1 != "000" { |
||||
t.Errorf("incorrect data %v", data1) |
||||
} |
||||
} |
||||
|
||||
func TestEncodeDecodeMsg(t *testing.T) { |
||||
msg, _ := NewMsg(3, 1, "000") |
||||
encoded := msg.Encode(3) |
||||
msg, _ = NewMsgFromBytes(encoded) |
||||
msg.Decode(3) |
||||
if msg.Code() != 3 { |
||||
t.Errorf("incorrect code %v", msg.Code()) |
||||
} |
||||
data0 := msg.Data().Get(0).Uint() |
||||
data1 := msg.Data().Get(1).Str() |
||||
if data0 != 1 { |
||||
t.Errorf("incorrect data %v", data0) |
||||
} |
||||
if data1 != "000" { |
||||
t.Errorf("incorrect data %v", data1) |
||||
} |
||||
} |
@ -0,0 +1,220 @@ |
||||
package p2p |
||||
|
||||
import ( |
||||
"fmt" |
||||
"sync" |
||||
"time" |
||||
) |
||||
|
||||
const ( |
||||
handlerTimeout = 1000 |
||||
) |
||||
|
||||
type Handlers map[string](func(p *Peer) Protocol) |
||||
|
||||
type Messenger struct { |
||||
conn *Connection |
||||
peer *Peer |
||||
handlers Handlers |
||||
protocolLock sync.RWMutex |
||||
protocols []Protocol |
||||
offsets []MsgCode // offsets for adaptive message idss
|
||||
protocolTable map[string]int |
||||
quit chan chan bool |
||||
err chan *PeerError |
||||
pulse chan bool |
||||
} |
||||
|
||||
func NewMessenger(peer *Peer, conn *Connection, errchan chan *PeerError, handlers Handlers) *Messenger { |
||||
baseProtocol := NewBaseProtocol(peer) |
||||
return &Messenger{ |
||||
conn: conn, |
||||
peer: peer, |
||||
offsets: []MsgCode{baseProtocol.Offset()}, |
||||
handlers: handlers, |
||||
protocols: []Protocol{baseProtocol}, |
||||
protocolTable: make(map[string]int), |
||||
err: errchan, |
||||
pulse: make(chan bool, 1), |
||||
quit: make(chan chan bool, 1), |
||||
} |
||||
} |
||||
|
||||
func (self *Messenger) Start() { |
||||
self.conn.Open() |
||||
go self.messenger() |
||||
self.protocolLock.RLock() |
||||
defer self.protocolLock.RUnlock() |
||||
self.protocols[0].Start() |
||||
} |
||||
|
||||
func (self *Messenger) Stop() { |
||||
// close pulse to stop ping pong monitoring
|
||||
close(self.pulse) |
||||
self.protocolLock.RLock() |
||||
defer self.protocolLock.RUnlock() |
||||
for _, protocol := range self.protocols { |
||||
protocol.Stop() // could be parallel
|
||||
} |
||||
q := make(chan bool) |
||||
self.quit <- q |
||||
<-q |
||||
self.conn.Close() |
||||
} |
||||
|
||||
func (self *Messenger) messenger() { |
||||
in := self.conn.Read() |
||||
for { |
||||
select { |
||||
case payload, ok := <-in: |
||||
//dispatches message to the protocol asynchronously
|
||||
if ok { |
||||
go self.handle(payload) |
||||
} else { |
||||
return |
||||
} |
||||
case q := <-self.quit: |
||||
q <- true |
||||
return |
||||
} |
||||
} |
||||
} |
||||
|
||||
// handles each message by dispatching to the appropriate protocol
|
||||
// using adaptive message codes
|
||||
// this function is started as a separate go routine for each message
|
||||
// it waits for the protocol response
|
||||
// then encodes and sends outgoing messages to the connection's write channel
|
||||
func (self *Messenger) handle(payload []byte) { |
||||
// send ping to heartbeat channel signalling time of last message
|
||||
// select {
|
||||
// case self.pulse <- true:
|
||||
// default:
|
||||
// }
|
||||
self.pulse <- true |
||||
// initialise message from payload
|
||||
msg, err := NewMsgFromBytes(payload) |
||||
if err != nil { |
||||
self.err <- NewPeerError(MiscError, " %v", err) |
||||
return |
||||
} |
||||
// retrieves protocol based on message Code
|
||||
protocol, offset, peerErr := self.getProtocol(msg.Code()) |
||||
if err != nil { |
||||
self.err <- peerErr |
||||
return |
||||
} |
||||
// reset message code based on adaptive offset
|
||||
msg.Decode(offset) |
||||
// dispatches
|
||||
response := make(chan *Msg) |
||||
go protocol.HandleIn(msg, response) |
||||
// protocol reponse timeout to prevent leaks
|
||||
timer := time.After(handlerTimeout * time.Millisecond) |
||||
for { |
||||
select { |
||||
case outgoing, ok := <-response: |
||||
// we check if response channel is not closed
|
||||
if ok { |
||||
self.conn.Write() <- outgoing.Encode(offset) |
||||
} else { |
||||
return |
||||
} |
||||
case <-timer: |
||||
return |
||||
} |
||||
} |
||||
} |
||||
|
||||
// negotiated protocols
|
||||
// stores offsets needed for adaptive message id scheme
|
||||
|
||||
// based on offsets set at handshake
|
||||
// get the right protocol to handle the message
|
||||
func (self *Messenger) getProtocol(code MsgCode) (Protocol, MsgCode, *PeerError) { |
||||
self.protocolLock.RLock() |
||||
defer self.protocolLock.RUnlock() |
||||
base := MsgCode(0) |
||||
for index, offset := range self.offsets { |
||||
if code < offset { |
||||
return self.protocols[index], base, nil |
||||
} |
||||
base = offset |
||||
} |
||||
return nil, MsgCode(0), NewPeerError(InvalidMsgCode, " %v", code) |
||||
} |
||||
|
||||
func (self *Messenger) PingPong(timeout time.Duration, gracePeriod time.Duration, pingCallback func(), timeoutCallback func()) { |
||||
fmt.Printf("pingpong keepalive started at %v", time.Now()) |
||||
|
||||
timer := time.After(timeout) |
||||
pinged := false |
||||
for { |
||||
select { |
||||
case _, ok := <-self.pulse: |
||||
if ok { |
||||
pinged = false |
||||
timer = time.After(timeout) |
||||
} else { |
||||
// pulse is closed, stop monitoring
|
||||
return |
||||
} |
||||
case <-timer: |
||||
if pinged { |
||||
fmt.Printf("timeout at %v", time.Now()) |
||||
timeoutCallback() |
||||
return |
||||
} else { |
||||
fmt.Printf("pinged at %v", time.Now()) |
||||
pingCallback() |
||||
timer = time.After(gracePeriod) |
||||
pinged = true |
||||
} |
||||
} |
||||
} |
||||
} |
||||
|
||||
func (self *Messenger) AddProtocols(protocols []string) { |
||||
self.protocolLock.Lock() |
||||
defer self.protocolLock.Unlock() |
||||
i := len(self.offsets) |
||||
offset := self.offsets[i-1] |
||||
for _, name := range protocols { |
||||
protocolFunc, ok := self.handlers[name] |
||||
if ok { |
||||
protocol := protocolFunc(self.peer) |
||||
self.protocolTable[name] = i |
||||
i++ |
||||
offset += protocol.Offset() |
||||
fmt.Println("offset ", name, offset) |
||||
|
||||
self.offsets = append(self.offsets, offset) |
||||
self.protocols = append(self.protocols, protocol) |
||||
protocol.Start() |
||||
} else { |
||||
fmt.Println("no ", name) |
||||
// protocol not handled
|
||||
} |
||||
} |
||||
} |
||||
|
||||
func (self *Messenger) Write(protocol string, msg *Msg) error { |
||||
self.protocolLock.RLock() |
||||
defer self.protocolLock.RUnlock() |
||||
i := 0 |
||||
offset := MsgCode(0) |
||||
if len(protocol) > 0 { |
||||
var ok bool |
||||
i, ok = self.protocolTable[protocol] |
||||
if !ok { |
||||
return fmt.Errorf("protocol %v not handled by peer", protocol) |
||||
} |
||||
offset = self.offsets[i-1] |
||||
} |
||||
handler := self.protocols[i] |
||||
// checking if protocol status/caps allows the message to be sent out
|
||||
if handler.HandleOut(msg) { |
||||
self.conn.Write() <- msg.Encode(offset) |
||||
} |
||||
return nil |
||||
} |
@ -0,0 +1,146 @@ |
||||
package p2p |
||||
|
||||
import ( |
||||
// "fmt"
|
||||
"bytes" |
||||
"github.com/ethereum/eth-go/ethutil" |
||||
"testing" |
||||
"time" |
||||
) |
||||
|
||||
func setupMessenger(handlers Handlers) (*TestNetworkConnection, chan *PeerError, *Messenger) { |
||||
errchan := NewPeerErrorChannel() |
||||
addr := &TestAddr{"test:30303"} |
||||
net := NewTestNetworkConnection(addr) |
||||
conn := NewConnection(net, errchan) |
||||
mess := NewMessenger(nil, conn, errchan, handlers) |
||||
mess.Start() |
||||
return net, errchan, mess |
||||
} |
||||
|
||||
type TestProtocol struct { |
||||
Msgs []*Msg |
||||
} |
||||
|
||||
func (self *TestProtocol) Start() { |
||||
} |
||||
|
||||
func (self *TestProtocol) Stop() { |
||||
} |
||||
|
||||
func (self *TestProtocol) Offset() MsgCode { |
||||
return MsgCode(5) |
||||
} |
||||
|
||||
func (self *TestProtocol) HandleIn(msg *Msg, response chan *Msg) { |
||||
self.Msgs = append(self.Msgs, msg) |
||||
close(response) |
||||
} |
||||
|
||||
func (self *TestProtocol) HandleOut(msg *Msg) bool { |
||||
if msg.Code() > 3 { |
||||
return false |
||||
} else { |
||||
return true |
||||
} |
||||
} |
||||
|
||||
func (self *TestProtocol) Name() string { |
||||
return "a" |
||||
} |
||||
|
||||
func Packet(offset MsgCode, code MsgCode, params ...interface{}) []byte { |
||||
msg, _ := NewMsg(code, params...) |
||||
encoded := msg.Encode(offset) |
||||
packet := []byte{34, 64, 8, 145} |
||||
packet = append(packet, ethutil.NumberToBytes(uint32(len(encoded)), 32)...) |
||||
return append(packet, encoded...) |
||||
} |
||||
|
||||
func TestRead(t *testing.T) { |
||||
handlers := make(Handlers) |
||||
testProtocol := &TestProtocol{Msgs: []*Msg{}} |
||||
handlers["a"] = func(p *Peer) Protocol { return testProtocol } |
||||
net, _, mess := setupMessenger(handlers) |
||||
mess.AddProtocols([]string{"a"}) |
||||
defer mess.Stop() |
||||
wait := 1 * time.Millisecond |
||||
packet := Packet(16, 1, uint32(1), "000") |
||||
go net.In(0, packet) |
||||
time.Sleep(wait) |
||||
if len(testProtocol.Msgs) != 1 { |
||||
t.Errorf("msg not relayed to correct protocol") |
||||
} else { |
||||
if testProtocol.Msgs[0].Code() != 1 { |
||||
t.Errorf("incorrect msg code relayed to protocol") |
||||
} |
||||
} |
||||
} |
||||
|
||||
func TestWrite(t *testing.T) { |
||||
handlers := make(Handlers) |
||||
testProtocol := &TestProtocol{Msgs: []*Msg{}} |
||||
handlers["a"] = func(p *Peer) Protocol { return testProtocol } |
||||
net, _, mess := setupMessenger(handlers) |
||||
mess.AddProtocols([]string{"a"}) |
||||
defer mess.Stop() |
||||
wait := 1 * time.Millisecond |
||||
msg, _ := NewMsg(3, uint32(1), "000") |
||||
err := mess.Write("b", msg) |
||||
if err == nil { |
||||
t.Errorf("expect error for unknown protocol") |
||||
} |
||||
err = mess.Write("a", msg) |
||||
if err != nil { |
||||
t.Errorf("expect no error for known protocol: %v", err) |
||||
} else { |
||||
time.Sleep(wait) |
||||
if len(net.Out) != 1 { |
||||
t.Errorf("msg not written") |
||||
} else { |
||||
out := net.Out[0] |
||||
packet := Packet(16, 3, uint32(1), "000") |
||||
if bytes.Compare(out, packet) != 0 { |
||||
t.Errorf("incorrect packet %v", out) |
||||
} |
||||
} |
||||
} |
||||
} |
||||
|
||||
func TestPulse(t *testing.T) { |
||||
net, _, mess := setupMessenger(make(Handlers)) |
||||
defer mess.Stop() |
||||
ping := false |
||||
timeout := false |
||||
pingTimeout := 10 * time.Millisecond |
||||
gracePeriod := 200 * time.Millisecond |
||||
go mess.PingPong(pingTimeout, gracePeriod, func() { ping = true }, func() { timeout = true }) |
||||
net.In(0, Packet(0, 1)) |
||||
if ping { |
||||
t.Errorf("ping sent too early") |
||||
} |
||||
time.Sleep(pingTimeout + 100*time.Millisecond) |
||||
if !ping { |
||||
t.Errorf("no ping sent after timeout") |
||||
} |
||||
if timeout { |
||||
t.Errorf("timeout too early") |
||||
} |
||||
ping = false |
||||
net.In(0, Packet(0, 1)) |
||||
time.Sleep(pingTimeout + 100*time.Millisecond) |
||||
if !ping { |
||||
t.Errorf("no ping sent after timeout") |
||||
} |
||||
if timeout { |
||||
t.Errorf("timeout too early") |
||||
} |
||||
ping = false |
||||
time.Sleep(gracePeriod) |
||||
if ping { |
||||
t.Errorf("ping called twice") |
||||
} |
||||
if !timeout { |
||||
t.Errorf("no timeout after grace period") |
||||
} |
||||
} |
@ -0,0 +1,55 @@ |
||||
package p2p |
||||
|
||||
import ( |
||||
"fmt" |
||||
"net" |
||||
|
||||
natpmp "github.com/jackpal/go-nat-pmp" |
||||
) |
||||
|
||||
// Adapt the NAT-PMP protocol to the NAT interface
|
||||
|
||||
// TODO:
|
||||
// + Register for changes to the external address.
|
||||
// + Re-register port mapping when router reboots.
|
||||
// + A mechanism for keeping a port mapping registered.
|
||||
|
||||
type natPMPClient struct { |
||||
client *natpmp.Client |
||||
} |
||||
|
||||
func NewNatPMP(gateway net.IP) (nat NAT) { |
||||
return &natPMPClient{natpmp.NewClient(gateway)} |
||||
} |
||||
|
||||
func (n *natPMPClient) GetExternalAddress() (addr net.IP, err error) { |
||||
response, err := n.client.GetExternalAddress() |
||||
if err != nil { |
||||
return |
||||
} |
||||
ip := response.ExternalIPAddress |
||||
addr = net.IPv4(ip[0], ip[1], ip[2], ip[3]) |
||||
return |
||||
} |
||||
|
||||
func (n *natPMPClient) AddPortMapping(protocol string, externalPort, internalPort int, |
||||
description string, timeout int) (mappedExternalPort int, err error) { |
||||
if timeout <= 0 { |
||||
err = fmt.Errorf("timeout must not be <= 0") |
||||
return |
||||
} |
||||
// Note order of port arguments is switched between our AddPortMapping and the client's AddPortMapping.
|
||||
response, err := n.client.AddPortMapping(protocol, internalPort, externalPort, timeout) |
||||
if err != nil { |
||||
return |
||||
} |
||||
mappedExternalPort = int(response.MappedExternalPort) |
||||
return |
||||
} |
||||
|
||||
func (n *natPMPClient) DeletePortMapping(protocol string, externalPort, internalPort int) (err error) { |
||||
// To destroy a mapping, send an add-port with
|
||||
// an internalPort of the internal port to destroy, an external port of zero and a time of zero.
|
||||
_, err = n.client.AddPortMapping(protocol, internalPort, 0, 0) |
||||
return |
||||
} |
@ -0,0 +1,335 @@ |
||||
package p2p |
||||
|
||||
// Just enough UPnP to be able to forward ports
|
||||
//
|
||||
|
||||
import ( |
||||
"bytes" |
||||
"encoding/xml" |
||||
"errors" |
||||
"net" |
||||
"net/http" |
||||
"os" |
||||
"strconv" |
||||
"strings" |
||||
"time" |
||||
) |
||||
|
||||
type upnpNAT struct { |
||||
serviceURL string |
||||
ourIP string |
||||
} |
||||
|
||||
func upnpDiscover(attempts int) (nat NAT, err error) { |
||||
ssdp, err := net.ResolveUDPAddr("udp4", "239.255.255.250:1900") |
||||
if err != nil { |
||||
return |
||||
} |
||||
conn, err := net.ListenPacket("udp4", ":0") |
||||
if err != nil { |
||||
return |
||||
} |
||||
socket := conn.(*net.UDPConn) |
||||
defer socket.Close() |
||||
|
||||
err = socket.SetDeadline(time.Now().Add(10 * time.Second)) |
||||
if err != nil { |
||||
return |
||||
} |
||||
|
||||
st := "ST: urn:schemas-upnp-org:device:InternetGatewayDevice:1\r\n" |
||||
buf := bytes.NewBufferString( |
||||
"M-SEARCH * HTTP/1.1\r\n" + |
||||
"HOST: 239.255.255.250:1900\r\n" + |
||||
st + |
||||
"MAN: \"ssdp:discover\"\r\n" + |
||||
"MX: 2\r\n\r\n") |
||||
message := buf.Bytes() |
||||
answerBytes := make([]byte, 1024) |
||||
for i := 0; i < attempts; i++ { |
||||
_, err = socket.WriteToUDP(message, ssdp) |
||||
if err != nil { |
||||
return |
||||
} |
||||
var n int |
||||
n, _, err = socket.ReadFromUDP(answerBytes) |
||||
if err != nil { |
||||
continue |
||||
// socket.Close()
|
||||
// return
|
||||
} |
||||
answer := string(answerBytes[0:n]) |
||||
if strings.Index(answer, "\r\n"+st) < 0 { |
||||
continue |
||||
} |
||||
// HTTP header field names are case-insensitive.
|
||||
// http://www.w3.org/Protocols/rfc2616/rfc2616-sec4.html#sec4.2
|
||||
locString := "\r\nlocation: " |
||||
answer = strings.ToLower(answer) |
||||
locIndex := strings.Index(answer, locString) |
||||
if locIndex < 0 { |
||||
continue |
||||
} |
||||
loc := answer[locIndex+len(locString):] |
||||
endIndex := strings.Index(loc, "\r\n") |
||||
if endIndex < 0 { |
||||
continue |
||||
} |
||||
locURL := loc[0:endIndex] |
||||
var serviceURL string |
||||
serviceURL, err = getServiceURL(locURL) |
||||
if err != nil { |
||||
return |
||||
} |
||||
var ourIP string |
||||
ourIP, err = getOurIP() |
||||
if err != nil { |
||||
return |
||||
} |
||||
nat = &upnpNAT{serviceURL: serviceURL, ourIP: ourIP} |
||||
return |
||||
} |
||||
err = errors.New("UPnP port discovery failed.") |
||||
return |
||||
} |
||||
|
||||
// service represents the Service type in an UPnP xml description.
|
||||
// Only the parts we care about are present and thus the xml may have more
|
||||
// fields than present in the structure.
|
||||
type service struct { |
||||
ServiceType string `xml:"serviceType"` |
||||
ControlURL string `xml:"controlURL"` |
||||
} |
||||
|
||||
// deviceList represents the deviceList type in an UPnP xml description.
|
||||
// Only the parts we care about are present and thus the xml may have more
|
||||
// fields than present in the structure.
|
||||
type deviceList struct { |
||||
XMLName xml.Name `xml:"deviceList"` |
||||
Device []device `xml:"device"` |
||||
} |
||||
|
||||
// serviceList represents the serviceList type in an UPnP xml description.
|
||||
// Only the parts we care about are present and thus the xml may have more
|
||||
// fields than present in the structure.
|
||||
type serviceList struct { |
||||
XMLName xml.Name `xml:"serviceList"` |
||||
Service []service `xml:"service"` |
||||
} |
||||
|
||||
// device represents the device type in an UPnP xml description.
|
||||
// Only the parts we care about are present and thus the xml may have more
|
||||
// fields than present in the structure.
|
||||
type device struct { |
||||
XMLName xml.Name `xml:"device"` |
||||
DeviceType string `xml:"deviceType"` |
||||
DeviceList deviceList `xml:"deviceList"` |
||||
ServiceList serviceList `xml:"serviceList"` |
||||
} |
||||
|
||||
// specVersion represents the specVersion in a UPnP xml description.
|
||||
// Only the parts we care about are present and thus the xml may have more
|
||||
// fields than present in the structure.
|
||||
type specVersion struct { |
||||
XMLName xml.Name `xml:"specVersion"` |
||||
Major int `xml:"major"` |
||||
Minor int `xml:"minor"` |
||||
} |
||||
|
||||
// root represents the Root document for a UPnP xml description.
|
||||
// Only the parts we care about are present and thus the xml may have more
|
||||
// fields than present in the structure.
|
||||
type root struct { |
||||
XMLName xml.Name `xml:"root"` |
||||
SpecVersion specVersion |
||||
Device device |
||||
} |
||||
|
||||
func getChildDevice(d *device, deviceType string) *device { |
||||
dl := d.DeviceList.Device |
||||
for i := 0; i < len(dl); i++ { |
||||
if dl[i].DeviceType == deviceType { |
||||
return &dl[i] |
||||
} |
||||
} |
||||
return nil |
||||
} |
||||
|
||||
func getChildService(d *device, serviceType string) *service { |
||||
sl := d.ServiceList.Service |
||||
for i := 0; i < len(sl); i++ { |
||||
if sl[i].ServiceType == serviceType { |
||||
return &sl[i] |
||||
} |
||||
} |
||||
return nil |
||||
} |
||||
|
||||
func getOurIP() (ip string, err error) { |
||||
hostname, err := os.Hostname() |
||||
if err != nil { |
||||
return |
||||
} |
||||
p, err := net.LookupIP(hostname) |
||||
if err != nil && len(p) > 0 { |
||||
return |
||||
} |
||||
return p[0].String(), nil |
||||
} |
||||
|
||||
func getServiceURL(rootURL string) (url string, err error) { |
||||
r, err := http.Get(rootURL) |
||||
if err != nil { |
||||
return |
||||
} |
||||
defer r.Body.Close() |
||||
if r.StatusCode >= 400 { |
||||
err = errors.New(string(r.StatusCode)) |
||||
return |
||||
} |
||||
var root root |
||||
err = xml.NewDecoder(r.Body).Decode(&root) |
||||
|
||||
if err != nil { |
||||
return |
||||
} |
||||
a := &root.Device |
||||
if a.DeviceType != "urn:schemas-upnp-org:device:InternetGatewayDevice:1" { |
||||
err = errors.New("No InternetGatewayDevice") |
||||
return |
||||
} |
||||
b := getChildDevice(a, "urn:schemas-upnp-org:device:WANDevice:1") |
||||
if b == nil { |
||||
err = errors.New("No WANDevice") |
||||
return |
||||
} |
||||
c := getChildDevice(b, "urn:schemas-upnp-org:device:WANConnectionDevice:1") |
||||
if c == nil { |
||||
err = errors.New("No WANConnectionDevice") |
||||
return |
||||
} |
||||
d := getChildService(c, "urn:schemas-upnp-org:service:WANIPConnection:1") |
||||
if d == nil { |
||||
err = errors.New("No WANIPConnection") |
||||
return |
||||
} |
||||
url = combineURL(rootURL, d.ControlURL) |
||||
return |
||||
} |
||||
|
||||
func combineURL(rootURL, subURL string) string { |
||||
protocolEnd := "://" |
||||
protoEndIndex := strings.Index(rootURL, protocolEnd) |
||||
a := rootURL[protoEndIndex+len(protocolEnd):] |
||||
rootIndex := strings.Index(a, "/") |
||||
return rootURL[0:protoEndIndex+len(protocolEnd)+rootIndex] + subURL |
||||
} |
||||
|
||||
func soapRequest(url, function, message string) (r *http.Response, err error) { |
||||
fullMessage := "<?xml version=\"1.0\" ?>" + |
||||
"<s:Envelope xmlns:s=\"http://schemas.xmlsoap.org/soap/envelope/\" s:encodingStyle=\"http://schemas.xmlsoap.org/soap/encoding/\">\r\n" + |
||||
"<s:Body>" + message + "</s:Body></s:Envelope>" |
||||
|
||||
req, err := http.NewRequest("POST", url, strings.NewReader(fullMessage)) |
||||
if err != nil { |
||||
return |
||||
} |
||||
req.Header.Set("Content-Type", "text/xml ; charset=\"utf-8\"") |
||||
req.Header.Set("User-Agent", "Darwin/10.0.0, UPnP/1.0, MiniUPnPc/1.3") |
||||
//req.Header.Set("Transfer-Encoding", "chunked")
|
||||
req.Header.Set("SOAPAction", "\"urn:schemas-upnp-org:service:WANIPConnection:1#"+function+"\"") |
||||
req.Header.Set("Connection", "Close") |
||||
req.Header.Set("Cache-Control", "no-cache") |
||||
req.Header.Set("Pragma", "no-cache") |
||||
|
||||
r, err = http.DefaultClient.Do(req) |
||||
if err != nil { |
||||
return |
||||
} |
||||
|
||||
if r.Body != nil { |
||||
defer r.Body.Close() |
||||
} |
||||
|
||||
if r.StatusCode >= 400 { |
||||
// log.Stderr(function, r.StatusCode)
|
||||
err = errors.New("Error " + strconv.Itoa(r.StatusCode) + " for " + function) |
||||
r = nil |
||||
return |
||||
} |
||||
return |
||||
} |
||||
|
||||
type statusInfo struct { |
||||
externalIpAddress string |
||||
} |
||||
|
||||
func (n *upnpNAT) getStatusInfo() (info statusInfo, err error) { |
||||
|
||||
message := "<u:GetStatusInfo xmlns:u=\"urn:schemas-upnp-org:service:WANIPConnection:1\">\r\n" + |
||||
"</u:GetStatusInfo>" |
||||
|
||||
var response *http.Response |
||||
response, err = soapRequest(n.serviceURL, "GetStatusInfo", message) |
||||
if err != nil { |
||||
return |
||||
} |
||||
|
||||
// TODO: Write a soap reply parser. It has to eat the Body and envelope tags...
|
||||
|
||||
response.Body.Close() |
||||
return |
||||
} |
||||
|
||||
func (n *upnpNAT) GetExternalAddress() (addr net.IP, err error) { |
||||
info, err := n.getStatusInfo() |
||||
if err != nil { |
||||
return |
||||
} |
||||
addr = net.ParseIP(info.externalIpAddress) |
||||
return |
||||
} |
||||
|
||||
func (n *upnpNAT) AddPortMapping(protocol string, externalPort, internalPort int, description string, timeout int) (mappedExternalPort int, err error) { |
||||
// A single concatenation would break ARM compilation.
|
||||
message := "<u:AddPortMapping xmlns:u=\"urn:schemas-upnp-org:service:WANIPConnection:1\">\r\n" + |
||||
"<NewRemoteHost></NewRemoteHost><NewExternalPort>" + strconv.Itoa(externalPort) |
||||
message += "</NewExternalPort><NewProtocol>" + protocol + "</NewProtocol>" |
||||
message += "<NewInternalPort>" + strconv.Itoa(internalPort) + "</NewInternalPort>" + |
||||
"<NewInternalClient>" + n.ourIP + "</NewInternalClient>" + |
||||
"<NewEnabled>1</NewEnabled><NewPortMappingDescription>" |
||||
message += description + |
||||
"</NewPortMappingDescription><NewLeaseDuration>" + strconv.Itoa(timeout) + |
||||
"</NewLeaseDuration></u:AddPortMapping>" |
||||
|
||||
var response *http.Response |
||||
response, err = soapRequest(n.serviceURL, "AddPortMapping", message) |
||||
if err != nil { |
||||
return |
||||
} |
||||
|
||||
// TODO: check response to see if the port was forwarded
|
||||
// log.Println(message, response)
|
||||
mappedExternalPort = externalPort |
||||
_ = response |
||||
return |
||||
} |
||||
|
||||
func (n *upnpNAT) DeletePortMapping(protocol string, externalPort, internalPort int) (err error) { |
||||
|
||||
message := "<u:DeletePortMapping xmlns:u=\"urn:schemas-upnp-org:service:WANIPConnection:1\">\r\n" + |
||||
"<NewRemoteHost></NewRemoteHost><NewExternalPort>" + strconv.Itoa(externalPort) + |
||||
"</NewExternalPort><NewProtocol>" + protocol + "</NewProtocol>" + |
||||
"</u:DeletePortMapping>" |
||||
|
||||
var response *http.Response |
||||
response, err = soapRequest(n.serviceURL, "DeletePortMapping", message) |
||||
if err != nil { |
||||
return |
||||
} |
||||
|
||||
// TODO: check response to see if the port was deleted
|
||||
// log.Println(message, response)
|
||||
_ = response |
||||
return |
||||
} |
@ -0,0 +1,196 @@ |
||||
package p2p |
||||
|
||||
import ( |
||||
"fmt" |
||||
"math/rand" |
||||
"net" |
||||
"strconv" |
||||
"time" |
||||
) |
||||
|
||||
const ( |
||||
DialerTimeout = 180 //seconds
|
||||
KeepAlivePeriod = 60 //minutes
|
||||
portMappingUpdateInterval = 900 // seconds = 15 mins
|
||||
upnpDiscoverAttempts = 3 |
||||
) |
||||
|
||||
// Dialer is not an interface in net, so we define one
|
||||
// *net.Dialer conforms to this
|
||||
type Dialer interface { |
||||
Dial(network, address string) (net.Conn, error) |
||||
} |
||||
|
||||
type Network interface { |
||||
Start() error |
||||
Listener(net.Addr) (net.Listener, error) |
||||
Dialer(net.Addr) (Dialer, error) |
||||
NewAddr(string, int) (addr net.Addr, err error) |
||||
ParseAddr(string) (addr net.Addr, err error) |
||||
} |
||||
|
||||
type NAT interface { |
||||
GetExternalAddress() (addr net.IP, err error) |
||||
AddPortMapping(protocol string, externalPort, internalPort int, description string, timeout int) (mappedExternalPort int, err error) |
||||
DeletePortMapping(protocol string, externalPort, internalPort int) (err error) |
||||
} |
||||
|
||||
type TCPNetwork struct { |
||||
nat NAT |
||||
natType NATType |
||||
quit chan chan bool |
||||
ports chan string |
||||
} |
||||
|
||||
type NATType int |
||||
|
||||
const ( |
||||
NONE = iota |
||||
UPNP |
||||
PMP |
||||
) |
||||
|
||||
const ( |
||||
portMappingTimeout = 1200 // 20 mins
|
||||
) |
||||
|
||||
func NewTCPNetwork(natType NATType) (net *TCPNetwork) { |
||||
return &TCPNetwork{ |
||||
natType: natType, |
||||
ports: make(chan string), |
||||
} |
||||
} |
||||
|
||||
func (self *TCPNetwork) Dialer(addr net.Addr) (Dialer, error) { |
||||
return &net.Dialer{ |
||||
Timeout: DialerTimeout * time.Second, |
||||
// KeepAlive: KeepAlivePeriod * time.Minute,
|
||||
LocalAddr: addr, |
||||
}, nil |
||||
} |
||||
|
||||
func (self *TCPNetwork) Listener(addr net.Addr) (net.Listener, error) { |
||||
if self.natType == UPNP { |
||||
_, port, _ := net.SplitHostPort(addr.String()) |
||||
if self.quit == nil { |
||||
self.quit = make(chan chan bool) |
||||
go self.updatePortMappings() |
||||
} |
||||
self.ports <- port |
||||
} |
||||
return net.Listen(addr.Network(), addr.String()) |
||||
} |
||||
|
||||
func (self *TCPNetwork) Start() (err error) { |
||||
switch self.natType { |
||||
case NONE: |
||||
case UPNP: |
||||
nat, uerr := upnpDiscover(upnpDiscoverAttempts) |
||||
if uerr != nil { |
||||
err = fmt.Errorf("UPNP failed: ", uerr) |
||||
} else { |
||||
self.nat = nat |
||||
} |
||||
case PMP: |
||||
err = fmt.Errorf("PMP not implemented") |
||||
default: |
||||
err = fmt.Errorf("Invalid NAT type: %v", self.natType) |
||||
} |
||||
return |
||||
} |
||||
|
||||
func (self *TCPNetwork) Stop() { |
||||
q := make(chan bool) |
||||
self.quit <- q |
||||
<-q |
||||
} |
||||
|
||||
func (self *TCPNetwork) addPortMapping(lport int) (err error) { |
||||
_, err = self.nat.AddPortMapping("TCP", lport, lport, "p2p listen port", portMappingTimeout) |
||||
if err != nil { |
||||
logger.Errorf("unable to add port mapping on %v: %v", lport, err) |
||||
} else { |
||||
logger.Debugf("succesfully added port mapping on %v", lport) |
||||
} |
||||
return |
||||
} |
||||
|
||||
func (self *TCPNetwork) updatePortMappings() { |
||||
timer := time.NewTimer(portMappingUpdateInterval * time.Second) |
||||
lports := []int{} |
||||
out: |
||||
for { |
||||
select { |
||||
case port := <-self.ports: |
||||
int64lport, _ := strconv.ParseInt(port, 10, 16) |
||||
lport := int(int64lport) |
||||
if err := self.addPortMapping(lport); err != nil { |
||||
lports = append(lports, lport) |
||||
} |
||||
case <-timer.C: |
||||
for lport := range lports { |
||||
if err := self.addPortMapping(lport); err != nil { |
||||
} |
||||
} |
||||
case errc := <-self.quit: |
||||
errc <- true |
||||
break out |
||||
} |
||||
} |
||||
|
||||
timer.Stop() |
||||
for lport := range lports { |
||||
if err := self.nat.DeletePortMapping("TCP", lport, lport); err != nil { |
||||
logger.Debugf("unable to remove port mapping on %v: %v", lport, err) |
||||
} else { |
||||
logger.Debugf("succesfully removed port mapping on %v", lport) |
||||
} |
||||
} |
||||
} |
||||
|
||||
func (self *TCPNetwork) NewAddr(host string, port int) (net.Addr, error) { |
||||
ip, err := self.lookupIP(host) |
||||
if err == nil { |
||||
return &net.TCPAddr{ |
||||
IP: ip, |
||||
Port: port, |
||||
}, nil |
||||
} |
||||
return nil, err |
||||
} |
||||
|
||||
func (self *TCPNetwork) ParseAddr(address string) (net.Addr, error) { |
||||
host, port, err := net.SplitHostPort(address) |
||||
if err == nil { |
||||
iport, _ := strconv.Atoi(port) |
||||
addr, e := self.NewAddr(host, iport) |
||||
return addr, e |
||||
} |
||||
return nil, err |
||||
} |
||||
|
||||
func (*TCPNetwork) lookupIP(host string) (ip net.IP, err error) { |
||||
if ip = net.ParseIP(host); ip != nil { |
||||
return |
||||
} |
||||
|
||||
var ips []net.IP |
||||
ips, err = net.LookupIP(host) |
||||
if err != nil { |
||||
logger.Warnln(err) |
||||
return |
||||
} |
||||
if len(ips) == 0 { |
||||
err = fmt.Errorf("No IP addresses available for %v", host) |
||||
logger.Warnln(err) |
||||
return |
||||
} |
||||
if len(ips) > 1 { |
||||
// Pick a random IP address, simulating round-robin DNS.
|
||||
rand.Seed(time.Now().UTC().UnixNano()) |
||||
ip = ips[rand.Intn(len(ips))] |
||||
} else { |
||||
ip = ips[0] |
||||
} |
||||
return |
||||
} |
@ -0,0 +1,83 @@ |
||||
package p2p |
||||
|
||||
import ( |
||||
"fmt" |
||||
"net" |
||||
"strconv" |
||||
) |
||||
|
||||
type Peer struct { |
||||
// quit chan chan bool
|
||||
Inbound bool // inbound (via listener) or outbound (via dialout)
|
||||
Address net.Addr |
||||
Host []byte |
||||
Port uint16 |
||||
Pubkey []byte |
||||
Id string |
||||
Caps []string |
||||
peerErrorChan chan *PeerError |
||||
messenger *Messenger |
||||
peerErrorHandler *PeerErrorHandler |
||||
server *Server |
||||
} |
||||
|
||||
func (self *Peer) Messenger() *Messenger { |
||||
return self.messenger |
||||
} |
||||
|
||||
func (self *Peer) PeerErrorChan() chan *PeerError { |
||||
return self.peerErrorChan |
||||
} |
||||
|
||||
func (self *Peer) Server() *Server { |
||||
return self.server |
||||
} |
||||
|
||||
func NewPeer(conn net.Conn, address net.Addr, inbound bool, server *Server) *Peer { |
||||
peerErrorChan := NewPeerErrorChannel() |
||||
host, port, _ := net.SplitHostPort(address.String()) |
||||
intport, _ := strconv.Atoi(port) |
||||
peer := &Peer{ |
||||
Inbound: inbound, |
||||
Address: address, |
||||
Port: uint16(intport), |
||||
Host: net.ParseIP(host), |
||||
peerErrorChan: peerErrorChan, |
||||
server: server, |
||||
} |
||||
connection := NewConnection(conn, peerErrorChan) |
||||
peer.messenger = NewMessenger(peer, connection, peerErrorChan, server.Handlers()) |
||||
peer.peerErrorHandler = NewPeerErrorHandler(address, server.PeerDisconnect(), peerErrorChan, server.Blacklist()) |
||||
return peer |
||||
} |
||||
|
||||
func (self *Peer) String() string { |
||||
var kind string |
||||
if self.Inbound { |
||||
kind = "inbound" |
||||
} else { |
||||
kind = "outbound" |
||||
} |
||||
return fmt.Sprintf("%v:%v (%s) v%v %v", self.Host, self.Port, kind, self.Id, self.Caps) |
||||
} |
||||
|
||||
func (self *Peer) Write(protocol string, msg *Msg) error { |
||||
return self.messenger.Write(protocol, msg) |
||||
} |
||||
|
||||
func (self *Peer) Start() { |
||||
self.peerErrorHandler.Start() |
||||
self.messenger.Start() |
||||
} |
||||
|
||||
func (self *Peer) Stop() { |
||||
self.peerErrorHandler.Stop() |
||||
self.messenger.Stop() |
||||
// q := make(chan bool)
|
||||
// self.quit <- q
|
||||
// <-q
|
||||
} |
||||
|
||||
func (p *Peer) Encode() []interface{} { |
||||
return []interface{}{p.Host, p.Port, p.Pubkey} |
||||
} |
@ -0,0 +1,76 @@ |
||||
package p2p |
||||
|
||||
import ( |
||||
"fmt" |
||||
) |
||||
|
||||
type ErrorCode int |
||||
|
||||
const errorChanCapacity = 10 |
||||
|
||||
const ( |
||||
PacketTooShort = iota |
||||
PayloadTooShort |
||||
MagicTokenMismatch |
||||
EmptyPayload |
||||
ReadError |
||||
WriteError |
||||
MiscError |
||||
InvalidMsgCode |
||||
InvalidMsg |
||||
P2PVersionMismatch |
||||
PubkeyMissing |
||||
PubkeyInvalid |
||||
PubkeyForbidden |
||||
ProtocolBreach |
||||
PortMismatch |
||||
PingTimeout |
||||
InvalidGenesis |
||||
InvalidNetworkId |
||||
InvalidProtocolVersion |
||||
) |
||||
|
||||
var errorToString = map[ErrorCode]string{ |
||||
PacketTooShort: "Packet too short", |
||||
PayloadTooShort: "Payload too short", |
||||
MagicTokenMismatch: "Magic token mismatch", |
||||
EmptyPayload: "Empty payload", |
||||
ReadError: "Read error", |
||||
WriteError: "Write error", |
||||
MiscError: "Misc error", |
||||
InvalidMsgCode: "Invalid message code", |
||||
InvalidMsg: "Invalid message", |
||||
P2PVersionMismatch: "P2P Version Mismatch", |
||||
PubkeyMissing: "Public key missing", |
||||
PubkeyInvalid: "Public key invalid", |
||||
PubkeyForbidden: "Public key forbidden", |
||||
ProtocolBreach: "Protocol Breach", |
||||
PortMismatch: "Port mismatch", |
||||
PingTimeout: "Ping timeout", |
||||
InvalidGenesis: "Invalid genesis block", |
||||
InvalidNetworkId: "Invalid network id", |
||||
InvalidProtocolVersion: "Invalid protocol version", |
||||
} |
||||
|
||||
type PeerError struct { |
||||
Code ErrorCode |
||||
message string |
||||
} |
||||
|
||||
func NewPeerError(code ErrorCode, format string, v ...interface{}) *PeerError { |
||||
desc, ok := errorToString[code] |
||||
if !ok { |
||||
panic("invalid error code") |
||||
} |
||||
format = desc + ": " + format |
||||
message := fmt.Sprintf(format, v...) |
||||
return &PeerError{code, message} |
||||
} |
||||
|
||||
func (self *PeerError) Error() string { |
||||
return self.message |
||||
} |
||||
|
||||
func NewPeerErrorChannel() chan *PeerError { |
||||
return make(chan *PeerError, errorChanCapacity) |
||||
} |
@ -0,0 +1,101 @@ |
||||
package p2p |
||||
|
||||
import ( |
||||
"net" |
||||
) |
||||
|
||||
const ( |
||||
severityThreshold = 10 |
||||
) |
||||
|
||||
type DisconnectRequest struct { |
||||
addr net.Addr |
||||
reason DiscReason |
||||
} |
||||
|
||||
type PeerErrorHandler struct { |
||||
quit chan chan bool |
||||
address net.Addr |
||||
peerDisconnect chan DisconnectRequest |
||||
severity int |
||||
peerErrorChan chan *PeerError |
||||
blacklist Blacklist |
||||
} |
||||
|
||||
func NewPeerErrorHandler(address net.Addr, peerDisconnect chan DisconnectRequest, peerErrorChan chan *PeerError, blacklist Blacklist) *PeerErrorHandler { |
||||
return &PeerErrorHandler{ |
||||
quit: make(chan chan bool), |
||||
address: address, |
||||
peerDisconnect: peerDisconnect, |
||||
peerErrorChan: peerErrorChan, |
||||
blacklist: blacklist, |
||||
} |
||||
} |
||||
|
||||
func (self *PeerErrorHandler) Start() { |
||||
go self.listen() |
||||
} |
||||
|
||||
func (self *PeerErrorHandler) Stop() { |
||||
q := make(chan bool) |
||||
self.quit <- q |
||||
<-q |
||||
} |
||||
|
||||
func (self *PeerErrorHandler) listen() { |
||||
for { |
||||
select { |
||||
case peerError, ok := <-self.peerErrorChan: |
||||
if ok { |
||||
logger.Debugf("error %v\n", peerError) |
||||
go self.handle(peerError) |
||||
} else { |
||||
return |
||||
} |
||||
case q := <-self.quit: |
||||
q <- true |
||||
return |
||||
} |
||||
} |
||||
} |
||||
|
||||
func (self *PeerErrorHandler) handle(peerError *PeerError) { |
||||
reason := DiscReason(' ') |
||||
switch peerError.Code { |
||||
case P2PVersionMismatch: |
||||
reason = DiscIncompatibleVersion |
||||
case PubkeyMissing, PubkeyInvalid: |
||||
reason = DiscInvalidIdentity |
||||
case PubkeyForbidden: |
||||
reason = DiscUselessPeer |
||||
case InvalidMsgCode, PacketTooShort, PayloadTooShort, MagicTokenMismatch, EmptyPayload, ProtocolBreach: |
||||
reason = DiscProtocolError |
||||
case PingTimeout: |
||||
reason = DiscReadTimeout |
||||
case WriteError, MiscError: |
||||
reason = DiscNetworkError |
||||
case InvalidGenesis, InvalidNetworkId, InvalidProtocolVersion: |
||||
reason = DiscSubprotocolError |
||||
default: |
||||
self.severity += self.getSeverity(peerError) |
||||
} |
||||
|
||||
if self.severity >= severityThreshold { |
||||
reason = DiscSubprotocolError |
||||
} |
||||
if reason != DiscReason(' ') { |
||||
self.peerDisconnect <- DisconnectRequest{ |
||||
addr: self.address, |
||||
reason: reason, |
||||
} |
||||
} |
||||
} |
||||
|
||||
func (self *PeerErrorHandler) getSeverity(peerError *PeerError) int { |
||||
switch peerError.Code { |
||||
case ReadError: |
||||
return 4 //tolerate 3 :)
|
||||
default: |
||||
return 1 |
||||
} |
||||
} |
@ -0,0 +1,34 @@ |
||||
package p2p |
||||
|
||||
import ( |
||||
// "fmt"
|
||||
"net" |
||||
"testing" |
||||
"time" |
||||
) |
||||
|
||||
func TestPeerErrorHandler(t *testing.T) { |
||||
address := &net.TCPAddr{IP: net.IP([]byte{1, 2, 3, 4}), Port: 30303} |
||||
peerDisconnect := make(chan DisconnectRequest) |
||||
peerErrorChan := NewPeerErrorChannel() |
||||
peh := NewPeerErrorHandler(address, peerDisconnect, peerErrorChan, NewBlacklist()) |
||||
peh.Start() |
||||
defer peh.Stop() |
||||
for i := 0; i < 11; i++ { |
||||
select { |
||||
case <-peerDisconnect: |
||||
t.Errorf("expected no disconnect request") |
||||
default: |
||||
} |
||||
peerErrorChan <- NewPeerError(MiscError, "") |
||||
} |
||||
time.Sleep(1 * time.Millisecond) |
||||
select { |
||||
case request := <-peerDisconnect: |
||||
if request.addr.String() != address.String() { |
||||
t.Errorf("incorrect address %v != %v", request.addr, address) |
||||
} |
||||
default: |
||||
t.Errorf("expected disconnect request") |
||||
} |
||||
} |
@ -0,0 +1,96 @@ |
||||
package p2p |
||||
|
||||
import ( |
||||
"bytes" |
||||
"fmt" |
||||
// "net"
|
||||
"testing" |
||||
"time" |
||||
) |
||||
|
||||
func TestPeer(t *testing.T) { |
||||
handlers := make(Handlers) |
||||
testProtocol := &TestProtocol{Msgs: []*Msg{}} |
||||
handlers["aaa"] = func(p *Peer) Protocol { return testProtocol } |
||||
handlers["ccc"] = func(p *Peer) Protocol { return testProtocol } |
||||
addr := &TestAddr{"test:30"} |
||||
conn := NewTestNetworkConnection(addr) |
||||
_, server := SetupTestServer(handlers) |
||||
server.Handshake() |
||||
peer := NewPeer(conn, addr, true, server) |
||||
// peer.Messenger().AddProtocols([]string{"aaa", "ccc"})
|
||||
peer.Start() |
||||
defer peer.Stop() |
||||
time.Sleep(2 * time.Millisecond) |
||||
if len(conn.Out) != 1 { |
||||
t.Errorf("handshake not sent") |
||||
} else { |
||||
out := conn.Out[0] |
||||
packet := Packet(0, HandshakeMsg, P2PVersion, []byte(peer.server.identity.String()), []interface{}{peer.server.protocols}, peer.server.port, peer.server.identity.Pubkey()[1:]) |
||||
if bytes.Compare(out, packet) != 0 { |
||||
t.Errorf("incorrect handshake packet %v != %v", out, packet) |
||||
} |
||||
} |
||||
|
||||
packet := Packet(0, HandshakeMsg, P2PVersion, []byte("peer"), []interface{}{"bbb", "aaa", "ccc"}, 30, []byte("0000000000000000000000000000000000000000000000000000000000000000")) |
||||
conn.In(0, packet) |
||||
time.Sleep(10 * time.Millisecond) |
||||
|
||||
pro, _ := peer.Messenger().protocols[0].(*BaseProtocol) |
||||
if pro.state != handshakeReceived { |
||||
t.Errorf("handshake not received") |
||||
} |
||||
if peer.Port != 30 { |
||||
t.Errorf("port incorrectly set") |
||||
} |
||||
if peer.Id != "peer" { |
||||
t.Errorf("id incorrectly set") |
||||
} |
||||
if string(peer.Pubkey) != "0000000000000000000000000000000000000000000000000000000000000000" { |
||||
t.Errorf("pubkey incorrectly set") |
||||
} |
||||
fmt.Println(peer.Caps) |
||||
if len(peer.Caps) != 3 || peer.Caps[0] != "aaa" || peer.Caps[1] != "bbb" || peer.Caps[2] != "ccc" { |
||||
t.Errorf("protocols incorrectly set") |
||||
} |
||||
|
||||
msg, _ := NewMsg(3) |
||||
err := peer.Write("aaa", msg) |
||||
if err != nil { |
||||
t.Errorf("expect no error for known protocol: %v", err) |
||||
} else { |
||||
time.Sleep(1 * time.Millisecond) |
||||
if len(conn.Out) != 2 { |
||||
t.Errorf("msg not written") |
||||
} else { |
||||
out := conn.Out[1] |
||||
packet := Packet(16, 3) |
||||
if bytes.Compare(out, packet) != 0 { |
||||
t.Errorf("incorrect packet %v != %v", out, packet) |
||||
} |
||||
} |
||||
} |
||||
|
||||
msg, _ = NewMsg(2) |
||||
err = peer.Write("ccc", msg) |
||||
if err != nil { |
||||
t.Errorf("expect no error for known protocol: %v", err) |
||||
} else { |
||||
time.Sleep(1 * time.Millisecond) |
||||
if len(conn.Out) != 3 { |
||||
t.Errorf("msg not written") |
||||
} else { |
||||
out := conn.Out[2] |
||||
packet := Packet(21, 2) |
||||
if bytes.Compare(out, packet) != 0 { |
||||
t.Errorf("incorrect packet %v != %v", out, packet) |
||||
} |
||||
} |
||||
} |
||||
|
||||
err = peer.Write("bbb", msg) |
||||
time.Sleep(1 * time.Millisecond) |
||||
if err == nil { |
||||
t.Errorf("expect error for unknown protocol") |
||||
} |
||||
} |
@ -0,0 +1,278 @@ |
||||
package p2p |
||||
|
||||
import ( |
||||
"bytes" |
||||
"fmt" |
||||
"net" |
||||
"sort" |
||||
"sync" |
||||
"time" |
||||
) |
||||
|
||||
type Protocol interface { |
||||
Start() |
||||
Stop() |
||||
HandleIn(*Msg, chan *Msg) |
||||
HandleOut(*Msg) bool |
||||
Offset() MsgCode |
||||
Name() string |
||||
} |
||||
|
||||
const ( |
||||
P2PVersion = 0 |
||||
pingTimeout = 2 |
||||
pingGracePeriod = 2 |
||||
) |
||||
|
||||
const ( |
||||
HandshakeMsg = iota |
||||
DiscMsg |
||||
PingMsg |
||||
PongMsg |
||||
GetPeersMsg |
||||
PeersMsg |
||||
offset = 16 |
||||
) |
||||
|
||||
type ProtocolState uint8 |
||||
|
||||
const ( |
||||
nullState = iota |
||||
handshakeReceived |
||||
) |
||||
|
||||
type DiscReason byte |
||||
|
||||
const ( |
||||
// Values are given explicitly instead of by iota because these values are
|
||||
// defined by the wire protocol spec; it is easier for humans to ensure
|
||||
// correctness when values are explicit.
|
||||
DiscRequested = 0x00 |
||||
DiscNetworkError = 0x01 |
||||
DiscProtocolError = 0x02 |
||||
DiscUselessPeer = 0x03 |
||||
DiscTooManyPeers = 0x04 |
||||
DiscAlreadyConnected = 0x05 |
||||
DiscIncompatibleVersion = 0x06 |
||||
DiscInvalidIdentity = 0x07 |
||||
DiscQuitting = 0x08 |
||||
DiscUnexpectedIdentity = 0x09 |
||||
DiscSelf = 0x0a |
||||
DiscReadTimeout = 0x0b |
||||
DiscSubprotocolError = 0x10 |
||||
) |
||||
|
||||
var discReasonToString = map[DiscReason]string{ |
||||
DiscRequested: "Disconnect requested", |
||||
DiscNetworkError: "Network error", |
||||
DiscProtocolError: "Breach of protocol", |
||||
DiscUselessPeer: "Useless peer", |
||||
DiscTooManyPeers: "Too many peers", |
||||
DiscAlreadyConnected: "Already connected", |
||||
DiscIncompatibleVersion: "Incompatible P2P protocol version", |
||||
DiscInvalidIdentity: "Invalid node identity", |
||||
DiscQuitting: "Client quitting", |
||||
DiscUnexpectedIdentity: "Unexpected identity", |
||||
DiscSelf: "Connected to self", |
||||
DiscReadTimeout: "Read timeout", |
||||
DiscSubprotocolError: "Subprotocol error", |
||||
} |
||||
|
||||
func (d DiscReason) String() string { |
||||
if len(discReasonToString) < int(d) { |
||||
return "Unknown" |
||||
} |
||||
|
||||
return discReasonToString[d] |
||||
} |
||||
|
||||
type BaseProtocol struct { |
||||
peer *Peer |
||||
state ProtocolState |
||||
stateLock sync.RWMutex |
||||
} |
||||
|
||||
func NewBaseProtocol(peer *Peer) *BaseProtocol { |
||||
self := &BaseProtocol{ |
||||
peer: peer, |
||||
} |
||||
|
||||
return self |
||||
} |
||||
|
||||
func (self *BaseProtocol) Start() { |
||||
if self.peer != nil { |
||||
self.peer.Write("", self.peer.Server().Handshake()) |
||||
go self.peer.Messenger().PingPong( |
||||
pingTimeout*time.Second, |
||||
pingGracePeriod*time.Second, |
||||
self.Ping, |
||||
self.Timeout, |
||||
) |
||||
} |
||||
} |
||||
|
||||
func (self *BaseProtocol) Stop() { |
||||
} |
||||
|
||||
func (self *BaseProtocol) Ping() { |
||||
msg, _ := NewMsg(PingMsg) |
||||
self.peer.Write("", msg) |
||||
} |
||||
|
||||
func (self *BaseProtocol) Timeout() { |
||||
self.peerError(PingTimeout, "") |
||||
} |
||||
|
||||
func (self *BaseProtocol) Name() string { |
||||
return "" |
||||
} |
||||
|
||||
func (self *BaseProtocol) Offset() MsgCode { |
||||
return offset |
||||
} |
||||
|
||||
func (self *BaseProtocol) CheckState(state ProtocolState) bool { |
||||
self.stateLock.RLock() |
||||
self.stateLock.RUnlock() |
||||
if self.state != state { |
||||
return false |
||||
} else { |
||||
return true |
||||
} |
||||
} |
||||
|
||||
func (self *BaseProtocol) HandleIn(msg *Msg, response chan *Msg) { |
||||
if msg.Code() == HandshakeMsg { |
||||
self.handleHandshake(msg) |
||||
} else { |
||||
if !self.CheckState(handshakeReceived) { |
||||
self.peerError(ProtocolBreach, "message code %v not allowed", msg.Code()) |
||||
close(response) |
||||
return |
||||
} |
||||
switch msg.Code() { |
||||
case DiscMsg: |
||||
logger.Infof("Disconnect requested from peer %v, reason", DiscReason(msg.Data().Get(0).Uint())) |
||||
self.peer.Server().PeerDisconnect() <- DisconnectRequest{ |
||||
addr: self.peer.Address, |
||||
reason: DiscRequested, |
||||
} |
||||
case PingMsg: |
||||
out, _ := NewMsg(PongMsg) |
||||
response <- out |
||||
case PongMsg: |
||||
case GetPeersMsg: |
||||
// Peer asked for list of connected peers
|
||||
if out, err := self.peer.Server().PeersMessage(); err != nil { |
||||
response <- out |
||||
} |
||||
case PeersMsg: |
||||
self.handlePeers(msg) |
||||
default: |
||||
self.peerError(InvalidMsgCode, "unknown message code %v", msg.Code()) |
||||
} |
||||
} |
||||
close(response) |
||||
} |
||||
|
||||
func (self *BaseProtocol) HandleOut(msg *Msg) (allowed bool) { |
||||
// somewhat overly paranoid
|
||||
allowed = msg.Code() == HandshakeMsg || msg.Code() == DiscMsg || msg.Code() < self.Offset() && self.CheckState(handshakeReceived) |
||||
return |
||||
} |
||||
|
||||
func (self *BaseProtocol) peerError(errorCode ErrorCode, format string, v ...interface{}) { |
||||
err := NewPeerError(errorCode, format, v...) |
||||
logger.Warnln(err) |
||||
fmt.Println(self.peer, err) |
||||
if self.peer != nil { |
||||
self.peer.PeerErrorChan() <- err |
||||
} |
||||
} |
||||
|
||||
func (self *BaseProtocol) handlePeers(msg *Msg) { |
||||
it := msg.Data().NewIterator() |
||||
for it.Next() { |
||||
ip := net.IP(it.Value().Get(0).Bytes()) |
||||
port := it.Value().Get(1).Uint() |
||||
address := &net.TCPAddr{IP: ip, Port: int(port)} |
||||
go self.peer.Server().PeerConnect(address) |
||||
} |
||||
} |
||||
|
||||
func (self *BaseProtocol) handleHandshake(msg *Msg) { |
||||
self.stateLock.Lock() |
||||
defer self.stateLock.Unlock() |
||||
if self.state != nullState { |
||||
self.peerError(ProtocolBreach, "extra handshake") |
||||
return |
||||
} |
||||
|
||||
c := msg.Data() |
||||
|
||||
var ( |
||||
p2pVersion = c.Get(0).Uint() |
||||
id = c.Get(1).Str() |
||||
caps = c.Get(2) |
||||
port = c.Get(3).Uint() |
||||
pubkey = c.Get(4).Bytes() |
||||
) |
||||
fmt.Printf("handshake received %v, %v, %v, %v, %v ", p2pVersion, id, caps, port, pubkey) |
||||
|
||||
// Check correctness of p2p protocol version
|
||||
if p2pVersion != P2PVersion { |
||||
self.peerError(P2PVersionMismatch, "Require protocol %d, received %d\n", P2PVersion, p2pVersion) |
||||
return |
||||
} |
||||
|
||||
// Handle the pub key (validation, uniqueness)
|
||||
if len(pubkey) == 0 { |
||||
self.peerError(PubkeyMissing, "not supplied in handshake.") |
||||
return |
||||
} |
||||
|
||||
if len(pubkey) != 64 { |
||||
self.peerError(PubkeyInvalid, "require 512 bit, got %v", len(pubkey)*8) |
||||
return |
||||
} |
||||
|
||||
// Self connect detection
|
||||
if bytes.Compare(self.peer.Server().ClientIdentity().Pubkey()[1:], pubkey) == 0 { |
||||
self.peerError(PubkeyForbidden, "not allowed to connect to self") |
||||
return |
||||
} |
||||
|
||||
// register pubkey on server. this also sets the pubkey on the peer (need lock)
|
||||
if err := self.peer.Server().RegisterPubkey(self.peer, pubkey); err != nil { |
||||
self.peerError(PubkeyForbidden, err.Error()) |
||||
return |
||||
} |
||||
|
||||
// check port
|
||||
if self.peer.Inbound { |
||||
uint16port := uint16(port) |
||||
if self.peer.Port > 0 && self.peer.Port != uint16port { |
||||
self.peerError(PortMismatch, "port mismatch: %v != %v", self.peer.Port, port) |
||||
return |
||||
} else { |
||||
self.peer.Port = uint16port |
||||
} |
||||
} |
||||
|
||||
capsIt := caps.NewIterator() |
||||
for capsIt.Next() { |
||||
cap := capsIt.Value().Str() |
||||
self.peer.Caps = append(self.peer.Caps, cap) |
||||
} |
||||
sort.Strings(self.peer.Caps) |
||||
self.peer.Messenger().AddProtocols(self.peer.Caps) |
||||
|
||||
self.peer.Id = id |
||||
|
||||
self.state = handshakeReceived |
||||
|
||||
//p.ethereum.PushPeer(p)
|
||||
// p.ethereum.reactor.Post("peerList", p.ethereum.Peers())
|
||||
return |
||||
} |
@ -0,0 +1,484 @@ |
||||
package p2p |
||||
|
||||
import ( |
||||
"bytes" |
||||
"fmt" |
||||
"net" |
||||
"sort" |
||||
"strconv" |
||||
"sync" |
||||
"time" |
||||
|
||||
"github.com/ethereum/eth-go/ethlog" |
||||
) |
||||
|
||||
const ( |
||||
outboundAddressPoolSize = 10 |
||||
disconnectGracePeriod = 2 |
||||
) |
||||
|
||||
type Blacklist interface { |
||||
Get([]byte) (bool, error) |
||||
Put([]byte) error |
||||
Delete([]byte) error |
||||
Exists(pubkey []byte) (ok bool) |
||||
} |
||||
|
||||
type BlacklistMap struct { |
||||
blacklist map[string]bool |
||||
lock sync.RWMutex |
||||
} |
||||
|
||||
func NewBlacklist() *BlacklistMap { |
||||
return &BlacklistMap{ |
||||
blacklist: make(map[string]bool), |
||||
} |
||||
} |
||||
|
||||
func (self *BlacklistMap) Get(pubkey []byte) (bool, error) { |
||||
self.lock.RLock() |
||||
defer self.lock.RUnlock() |
||||
v, ok := self.blacklist[string(pubkey)] |
||||
var err error |
||||
if !ok { |
||||
err = fmt.Errorf("not found") |
||||
} |
||||
return v, err |
||||
} |
||||
|
||||
func (self *BlacklistMap) Exists(pubkey []byte) (ok bool) { |
||||
self.lock.RLock() |
||||
defer self.lock.RUnlock() |
||||
_, ok = self.blacklist[string(pubkey)] |
||||
return |
||||
} |
||||
|
||||
func (self *BlacklistMap) Put(pubkey []byte) error { |
||||
self.lock.RLock() |
||||
defer self.lock.RUnlock() |
||||
self.blacklist[string(pubkey)] = true |
||||
return nil |
||||
} |
||||
|
||||
func (self *BlacklistMap) Delete(pubkey []byte) error { |
||||
self.lock.RLock() |
||||
defer self.lock.RUnlock() |
||||
delete(self.blacklist, string(pubkey)) |
||||
return nil |
||||
} |
||||
|
||||
type Server struct { |
||||
network Network |
||||
listening bool //needed?
|
||||
dialing bool //needed?
|
||||
closed bool |
||||
identity ClientIdentity |
||||
addr net.Addr |
||||
port uint16 |
||||
protocols []string |
||||
|
||||
quit chan chan bool |
||||
peersLock sync.RWMutex |
||||
|
||||
maxPeers int |
||||
peers []*Peer |
||||
peerSlots chan int |
||||
peersTable map[string]int |
||||
peersMsg *Msg |
||||
peerCount int |
||||
|
||||
peerConnect chan net.Addr |
||||
peerDisconnect chan DisconnectRequest |
||||
blacklist Blacklist |
||||
handlers Handlers |
||||
} |
||||
|
||||
var logger = ethlog.NewLogger("P2P") |
||||
|
||||
func New(network Network, addr net.Addr, identity ClientIdentity, handlers Handlers, maxPeers int, blacklist Blacklist) *Server { |
||||
// get alphabetical list of protocol names from handlers map
|
||||
protocols := []string{} |
||||
for protocol := range handlers { |
||||
protocols = append(protocols, protocol) |
||||
} |
||||
sort.Strings(protocols) |
||||
|
||||
_, port, _ := net.SplitHostPort(addr.String()) |
||||
intport, _ := strconv.Atoi(port) |
||||
|
||||
self := &Server{ |
||||
// NewSimpleClientIdentity(clientIdentifier, version, customIdentifier)
|
||||
network: network, |
||||
identity: identity, |
||||
addr: addr, |
||||
port: uint16(intport), |
||||
protocols: protocols, |
||||
|
||||
quit: make(chan chan bool), |
||||
|
||||
maxPeers: maxPeers, |
||||
peers: make([]*Peer, maxPeers), |
||||
peerSlots: make(chan int, maxPeers), |
||||
peersTable: make(map[string]int), |
||||
|
||||
peerConnect: make(chan net.Addr, outboundAddressPoolSize), |
||||
peerDisconnect: make(chan DisconnectRequest), |
||||
blacklist: blacklist, |
||||
|
||||
handlers: handlers, |
||||
} |
||||
for i := 0; i < maxPeers; i++ { |
||||
self.peerSlots <- i // fill up with indexes
|
||||
} |
||||
return self |
||||
} |
||||
|
||||
func (self *Server) NewAddr(host string, port int) (addr net.Addr, err error) { |
||||
addr, err = self.network.NewAddr(host, port) |
||||
return |
||||
} |
||||
|
||||
func (self *Server) ParseAddr(address string) (addr net.Addr, err error) { |
||||
addr, err = self.network.ParseAddr(address) |
||||
return |
||||
} |
||||
|
||||
func (self *Server) ClientIdentity() ClientIdentity { |
||||
return self.identity |
||||
} |
||||
|
||||
func (self *Server) PeersMessage() (msg *Msg, err error) { |
||||
// TODO: memoize and reset when peers change
|
||||
self.peersLock.RLock() |
||||
defer self.peersLock.RUnlock() |
||||
msg = self.peersMsg |
||||
if msg == nil { |
||||
var peerData []interface{} |
||||
for _, i := range self.peersTable { |
||||
peer := self.peers[i] |
||||
peerData = append(peerData, peer.Encode()) |
||||
} |
||||
if len(peerData) == 0 { |
||||
err = fmt.Errorf("no peers") |
||||
} else { |
||||
msg, err = NewMsg(PeersMsg, peerData...) |
||||
self.peersMsg = msg //memoize
|
||||
} |
||||
} |
||||
return |
||||
} |
||||
|
||||
func (self *Server) Peers() (peers []*Peer) { |
||||
self.peersLock.RLock() |
||||
defer self.peersLock.RUnlock() |
||||
for _, peer := range self.peers { |
||||
if peer != nil { |
||||
peers = append(peers, peer) |
||||
} |
||||
} |
||||
return |
||||
} |
||||
|
||||
func (self *Server) PeerCount() int { |
||||
self.peersLock.RLock() |
||||
defer self.peersLock.RUnlock() |
||||
return self.peerCount |
||||
} |
||||
|
||||
var getPeersMsg, _ = NewMsg(GetPeersMsg) |
||||
|
||||
func (self *Server) PeerConnect(addr net.Addr) { |
||||
// TODO: should buffer, filter and uniq
|
||||
// send GetPeersMsg if not blocking
|
||||
select { |
||||
case self.peerConnect <- addr: // not enough peers
|
||||
self.Broadcast("", getPeersMsg) |
||||
default: // we dont care
|
||||
} |
||||
} |
||||
|
||||
func (self *Server) PeerDisconnect() chan DisconnectRequest { |
||||
return self.peerDisconnect |
||||
} |
||||
|
||||
func (self *Server) Blacklist() Blacklist { |
||||
return self.blacklist |
||||
} |
||||
|
||||
func (self *Server) Handlers() Handlers { |
||||
return self.handlers |
||||
} |
||||
|
||||
func (self *Server) Broadcast(protocol string, msg *Msg) { |
||||
self.peersLock.RLock() |
||||
defer self.peersLock.RUnlock() |
||||
for _, peer := range self.peers { |
||||
if peer != nil { |
||||
peer.Write(protocol, msg) |
||||
} |
||||
} |
||||
} |
||||
|
||||
// Start the server
|
||||
func (self *Server) Start(listen bool, dial bool) { |
||||
self.network.Start() |
||||
if listen { |
||||
listener, err := self.network.Listener(self.addr) |
||||
if err != nil { |
||||
logger.Warnf("Error initializing listener: %v", err) |
||||
logger.Warnf("Connection listening disabled") |
||||
self.listening = false |
||||
} else { |
||||
self.listening = true |
||||
logger.Infoln("Listen on %v: ready and accepting connections", listener.Addr()) |
||||
go self.inboundPeerHandler(listener) |
||||
} |
||||
} |
||||
if dial { |
||||
dialer, err := self.network.Dialer(self.addr) |
||||
if err != nil { |
||||
logger.Warnf("Error initializing dialer: %v", err) |
||||
logger.Warnf("Connection dialout disabled") |
||||
self.dialing = false |
||||
} else { |
||||
self.dialing = true |
||||
logger.Infoln("Dial peers watching outbound address pool") |
||||
go self.outboundPeerHandler(dialer) |
||||
} |
||||
} |
||||
logger.Infoln("server started") |
||||
} |
||||
|
||||
func (self *Server) Stop() { |
||||
logger.Infoln("server stopping...") |
||||
// // quit one loop if dialing
|
||||
if self.dialing { |
||||
logger.Infoln("stop dialout...") |
||||
dialq := make(chan bool) |
||||
self.quit <- dialq |
||||
<-dialq |
||||
fmt.Println("quit another") |
||||
} |
||||
// quit the other loop if listening
|
||||
if self.listening { |
||||
logger.Infoln("stop listening...") |
||||
listenq := make(chan bool) |
||||
self.quit <- listenq |
||||
<-listenq |
||||
fmt.Println("quit one") |
||||
} |
||||
|
||||
fmt.Println("quit waited") |
||||
|
||||
logger.Infoln("stopping peers...") |
||||
peers := []net.Addr{} |
||||
self.peersLock.RLock() |
||||
self.closed = true |
||||
for _, peer := range self.peers { |
||||
if peer != nil { |
||||
peers = append(peers, peer.Address) |
||||
} |
||||
} |
||||
self.peersLock.RUnlock() |
||||
for _, address := range peers { |
||||
go self.removePeer(DisconnectRequest{ |
||||
addr: address, |
||||
reason: DiscQuitting, |
||||
}) |
||||
} |
||||
// wait till they actually disconnect
|
||||
// this is checked by draining the peerSlots (slots are released back if a peer is removed)
|
||||
i := 0 |
||||
fmt.Println("draining peers") |
||||
|
||||
FOR: |
||||
for { |
||||
select { |
||||
case slot := <-self.peerSlots: |
||||
i++ |
||||
fmt.Printf("%v: found slot %v", i, slot) |
||||
if i == self.maxPeers { |
||||
break FOR |
||||
} |
||||
} |
||||
} |
||||
logger.Infoln("server stopped") |
||||
} |
||||
|
||||
// main loop for adding connections via listening
|
||||
func (self *Server) inboundPeerHandler(listener net.Listener) { |
||||
for { |
||||
select { |
||||
case slot := <-self.peerSlots: |
||||
go self.connectInboundPeer(listener, slot) |
||||
case errc := <-self.quit: |
||||
listener.Close() |
||||
fmt.Println("quit listenloop") |
||||
errc <- true |
||||
return |
||||
} |
||||
} |
||||
} |
||||
|
||||
// main loop for adding outbound peers based on peerConnect address pool
|
||||
// this same loop handles peer disconnect requests as well
|
||||
func (self *Server) outboundPeerHandler(dialer Dialer) { |
||||
// addressChan initially set to nil (only watches peerConnect if we need more peers)
|
||||
var addressChan chan net.Addr |
||||
slots := self.peerSlots |
||||
var slot *int |
||||
for { |
||||
select { |
||||
case i := <-slots: |
||||
// we need a peer in slot i, slot reserved
|
||||
slot = &i |
||||
// now we can watch for candidate peers in the next loop
|
||||
addressChan = self.peerConnect |
||||
// do not consume more until candidate peer is found
|
||||
slots = nil |
||||
case address := <-addressChan: |
||||
// candidate peer found, will dial out asyncronously
|
||||
// if connection fails slot will be released
|
||||
go self.connectOutboundPeer(dialer, address, *slot) |
||||
// we can watch if more peers needed in the next loop
|
||||
slots = self.peerSlots |
||||
// until then we dont care about candidate peers
|
||||
addressChan = nil |
||||
case request := <-self.peerDisconnect: |
||||
go self.removePeer(request) |
||||
case errc := <-self.quit: |
||||
if addressChan != nil && slot != nil { |
||||
self.peerSlots <- *slot |
||||
} |
||||
fmt.Println("quit dialloop") |
||||
errc <- true |
||||
return |
||||
} |
||||
} |
||||
} |
||||
|
||||
// check if peer address already connected
|
||||
func (self *Server) connected(address net.Addr) (err error) { |
||||
self.peersLock.RLock() |
||||
defer self.peersLock.RUnlock() |
||||
// fmt.Printf("address: %v\n", address)
|
||||
slot, found := self.peersTable[address.String()] |
||||
if found { |
||||
err = fmt.Errorf("already connected as peer %v (%v)", slot, address) |
||||
} |
||||
return |
||||
} |
||||
|
||||
// connect to peer via listener.Accept()
|
||||
func (self *Server) connectInboundPeer(listener net.Listener, slot int) { |
||||
var address net.Addr |
||||
conn, err := listener.Accept() |
||||
if err == nil { |
||||
address = conn.RemoteAddr() |
||||
err = self.connected(address) |
||||
if err != nil { |
||||
conn.Close() |
||||
} |
||||
} |
||||
if err != nil { |
||||
logger.Debugln(err) |
||||
self.peerSlots <- slot |
||||
} else { |
||||
fmt.Printf("adding %v\n", address) |
||||
go self.addPeer(conn, address, true, slot) |
||||
} |
||||
} |
||||
|
||||
// connect to peer via dial out
|
||||
func (self *Server) connectOutboundPeer(dialer Dialer, address net.Addr, slot int) { |
||||
var conn net.Conn |
||||
err := self.connected(address) |
||||
if err == nil { |
||||
conn, err = dialer.Dial(address.Network(), address.String()) |
||||
} |
||||
if err != nil { |
||||
logger.Debugln(err) |
||||
self.peerSlots <- slot |
||||
} else { |
||||
go self.addPeer(conn, address, false, slot) |
||||
} |
||||
} |
||||
|
||||
// creates the new peer object and inserts it into its slot
|
||||
func (self *Server) addPeer(conn net.Conn, address net.Addr, inbound bool, slot int) { |
||||
self.peersLock.Lock() |
||||
defer self.peersLock.Unlock() |
||||
if self.closed { |
||||
fmt.Println("oopsy, not no longer need peer") |
||||
conn.Close() //oopsy our bad
|
||||
self.peerSlots <- slot // release slot
|
||||
} else { |
||||
peer := NewPeer(conn, address, inbound, self) |
||||
self.peers[slot] = peer |
||||
self.peersTable[address.String()] = slot |
||||
self.peerCount++ |
||||
// reset peersmsg
|
||||
self.peersMsg = nil |
||||
fmt.Printf("added peer %v %v (slot %v)\n", address, peer, slot) |
||||
peer.Start() |
||||
} |
||||
} |
||||
|
||||
// removes peer: sending disconnect msg, stop peer, remove rom list/table, release slot
|
||||
func (self *Server) removePeer(request DisconnectRequest) { |
||||
self.peersLock.Lock() |
||||
|
||||
address := request.addr |
||||
slot := self.peersTable[address.String()] |
||||
peer := self.peers[slot] |
||||
fmt.Printf("removing peer %v %v (slot %v)\n", address, peer, slot) |
||||
if peer == nil { |
||||
logger.Debugf("already removed peer on %v", address) |
||||
self.peersLock.Unlock() |
||||
return |
||||
} |
||||
// remove from list and index
|
||||
self.peerCount-- |
||||
self.peers[slot] = nil |
||||
delete(self.peersTable, address.String()) |
||||
// reset peersmsg
|
||||
self.peersMsg = nil |
||||
fmt.Printf("removed peer %v (slot %v)\n", peer, slot) |
||||
self.peersLock.Unlock() |
||||
|
||||
// sending disconnect message
|
||||
disconnectMsg, _ := NewMsg(DiscMsg, request.reason) |
||||
peer.Write("", disconnectMsg) |
||||
// be nice and wait
|
||||
time.Sleep(disconnectGracePeriod * time.Second) |
||||
// switch off peer and close connections etc.
|
||||
fmt.Println("stopping peer") |
||||
peer.Stop() |
||||
fmt.Println("stopped peer") |
||||
// release slot to signal need for a new peer, last!
|
||||
self.peerSlots <- slot |
||||
} |
||||
|
||||
// fix handshake message to push to peers
|
||||
func (self *Server) Handshake() *Msg { |
||||
fmt.Println(self.identity.Pubkey()[1:]) |
||||
msg, _ := NewMsg(HandshakeMsg, P2PVersion, []byte(self.identity.String()), []interface{}{self.protocols}, self.port, self.identity.Pubkey()[1:]) |
||||
return msg |
||||
} |
||||
|
||||
func (self *Server) RegisterPubkey(candidate *Peer, pubkey []byte) error { |
||||
// Check for blacklisting
|
||||
if self.blacklist.Exists(pubkey) { |
||||
return fmt.Errorf("blacklisted") |
||||
} |
||||
|
||||
self.peersLock.RLock() |
||||
defer self.peersLock.RUnlock() |
||||
for _, peer := range self.peers { |
||||
if peer != nil && peer != candidate && bytes.Compare(peer.Pubkey, pubkey) == 0 { |
||||
return fmt.Errorf("already connected") |
||||
} |
||||
} |
||||
candidate.Pubkey = pubkey |
||||
return nil |
||||
} |
@ -0,0 +1,208 @@ |
||||
package p2p |
||||
|
||||
import ( |
||||
"bytes" |
||||
"fmt" |
||||
"net" |
||||
"testing" |
||||
"time" |
||||
) |
||||
|
||||
type TestNetwork struct { |
||||
connections map[string]*TestNetworkConnection |
||||
dialer Dialer |
||||
maxinbound int |
||||
} |
||||
|
||||
func NewTestNetwork(maxinbound int) *TestNetwork { |
||||
connections := make(map[string]*TestNetworkConnection) |
||||
return &TestNetwork{ |
||||
connections: connections, |
||||
dialer: &TestDialer{connections}, |
||||
maxinbound: maxinbound, |
||||
} |
||||
} |
||||
|
||||
func (self *TestNetwork) Dialer(addr net.Addr) (Dialer, error) { |
||||
return self.dialer, nil |
||||
} |
||||
|
||||
func (self *TestNetwork) Listener(addr net.Addr) (net.Listener, error) { |
||||
return &TestListener{ |
||||
connections: self.connections, |
||||
addr: addr, |
||||
max: self.maxinbound, |
||||
}, nil |
||||
} |
||||
|
||||
func (self *TestNetwork) Start() error { |
||||
return nil |
||||
} |
||||
|
||||
func (self *TestNetwork) NewAddr(string, int) (addr net.Addr, err error) { |
||||
return |
||||
} |
||||
|
||||
func (self *TestNetwork) ParseAddr(string) (addr net.Addr, err error) { |
||||
return |
||||
} |
||||
|
||||
type TestAddr struct { |
||||
name string |
||||
} |
||||
|
||||
func (self *TestAddr) String() string { |
||||
return self.name |
||||
} |
||||
|
||||
func (*TestAddr) Network() string { |
||||
return "test" |
||||
} |
||||
|
||||
type TestDialer struct { |
||||
connections map[string]*TestNetworkConnection |
||||
} |
||||
|
||||
func (self *TestDialer) Dial(network string, addr string) (conn net.Conn, err error) { |
||||
address := &TestAddr{addr} |
||||
tconn := NewTestNetworkConnection(address) |
||||
self.connections[addr] = tconn |
||||
conn = net.Conn(tconn) |
||||
return |
||||
} |
||||
|
||||
type TestListener struct { |
||||
connections map[string]*TestNetworkConnection |
||||
addr net.Addr |
||||
max int |
||||
i int |
||||
} |
||||
|
||||
func (self *TestListener) Accept() (conn net.Conn, err error) { |
||||
self.i++ |
||||
if self.i > self.max { |
||||
err = fmt.Errorf("no more") |
||||
} else { |
||||
addr := &TestAddr{fmt.Sprintf("inboundpeer-%d", self.i)} |
||||
tconn := NewTestNetworkConnection(addr) |
||||
key := tconn.RemoteAddr().String() |
||||
self.connections[key] = tconn |
||||
conn = net.Conn(tconn) |
||||
fmt.Printf("accepted connection from: %v \n", addr) |
||||
} |
||||
return |
||||
} |
||||
|
||||
func (self *TestListener) Close() error { |
||||
return nil |
||||
} |
||||
|
||||
func (self *TestListener) Addr() net.Addr { |
||||
return self.addr |
||||
} |
||||
|
||||
func SetupTestServer(handlers Handlers) (network *TestNetwork, server *Server) { |
||||
network = NewTestNetwork(1) |
||||
addr := &TestAddr{"test:30303"} |
||||
identity := NewSimpleClientIdentity("clientIdentifier", "version", "customIdentifier", "pubkey") |
||||
maxPeers := 2 |
||||
if handlers == nil { |
||||
handlers = make(Handlers) |
||||
} |
||||
blackist := NewBlacklist() |
||||
server = New(network, addr, identity, handlers, maxPeers, blackist) |
||||
fmt.Println(server.identity.Pubkey()) |
||||
return |
||||
} |
||||
|
||||
func TestServerListener(t *testing.T) { |
||||
network, server := SetupTestServer(nil) |
||||
server.Start(true, false) |
||||
time.Sleep(10 * time.Millisecond) |
||||
server.Stop() |
||||
peer1, ok := network.connections["inboundpeer-1"] |
||||
if !ok { |
||||
t.Error("not found inbound peer 1") |
||||
} else { |
||||
fmt.Printf("out: %v\n", peer1.Out) |
||||
if len(peer1.Out) != 2 { |
||||
t.Errorf("not enough messages sent to peer 1: %v ", len(peer1.Out)) |
||||
} |
||||
} |
||||
|
||||
} |
||||
|
||||
func TestServerDialer(t *testing.T) { |
||||
network, server := SetupTestServer(nil) |
||||
server.Start(false, true) |
||||
server.peerConnect <- &TestAddr{"outboundpeer-1"} |
||||
time.Sleep(10 * time.Millisecond) |
||||
server.Stop() |
||||
peer1, ok := network.connections["outboundpeer-1"] |
||||
if !ok { |
||||
t.Error("not found outbound peer 1") |
||||
} else { |
||||
fmt.Printf("out: %v\n", peer1.Out) |
||||
if len(peer1.Out) != 2 { |
||||
t.Errorf("not enough messages sent to peer 1: %v ", len(peer1.Out)) |
||||
} |
||||
} |
||||
} |
||||
|
||||
func TestServerBroadcast(t *testing.T) { |
||||
handlers := make(Handlers) |
||||
testProtocol := &TestProtocol{Msgs: []*Msg{}} |
||||
handlers["aaa"] = func(p *Peer) Protocol { return testProtocol } |
||||
network, server := SetupTestServer(handlers) |
||||
server.Start(true, true) |
||||
server.peerConnect <- &TestAddr{"outboundpeer-1"} |
||||
time.Sleep(10 * time.Millisecond) |
||||
msg, _ := NewMsg(0) |
||||
server.Broadcast("", msg) |
||||
packet := Packet(0, 0) |
||||
time.Sleep(10 * time.Millisecond) |
||||
server.Stop() |
||||
peer1, ok := network.connections["outboundpeer-1"] |
||||
if !ok { |
||||
t.Error("not found outbound peer 1") |
||||
} else { |
||||
fmt.Printf("out: %v\n", peer1.Out) |
||||
if len(peer1.Out) != 3 { |
||||
t.Errorf("not enough messages sent to peer 1: %v ", len(peer1.Out)) |
||||
} else { |
||||
if bytes.Compare(peer1.Out[1], packet) != 0 { |
||||
t.Errorf("incorrect broadcast packet %v != %v", peer1.Out[1], packet) |
||||
} |
||||
} |
||||
} |
||||
peer2, ok := network.connections["inboundpeer-1"] |
||||
if !ok { |
||||
t.Error("not found inbound peer 2") |
||||
} else { |
||||
fmt.Printf("out: %v\n", peer2.Out) |
||||
if len(peer1.Out) != 3 { |
||||
t.Errorf("not enough messages sent to peer 2: %v ", len(peer2.Out)) |
||||
} else { |
||||
if bytes.Compare(peer2.Out[1], packet) != 0 { |
||||
t.Errorf("incorrect broadcast packet %v != %v", peer2.Out[1], packet) |
||||
} |
||||
} |
||||
} |
||||
} |
||||
|
||||
func TestServerPeersMessage(t *testing.T) { |
||||
handlers := make(Handlers) |
||||
_, server := SetupTestServer(handlers) |
||||
server.Start(true, true) |
||||
defer server.Stop() |
||||
server.peerConnect <- &TestAddr{"outboundpeer-1"} |
||||
time.Sleep(10 * time.Millisecond) |
||||
peersMsg, err := server.PeersMessage() |
||||
fmt.Println(peersMsg) |
||||
if err != nil { |
||||
t.Errorf("expect no error, got %v", err) |
||||
} |
||||
if c := server.PeerCount(); c != 2 { |
||||
t.Errorf("expect 2 peers, got %v", c) |
||||
} |
||||
} |
Loading…
Reference in new issue