obscuren 10 years ago
commit 76fa75b394
  1. 93
      cmd/bootnode/main.go
  2. 48
      cmd/ethereum/flags.go
  3. 31
      cmd/ethereum/main.go
  4. 867
      cmd/mist/assets/qml/main.qml
  5. 12
      cmd/mist/assets/qml/views/info.qml
  6. 9
      cmd/mist/bindings.go
  7. 44
      cmd/mist/flags.go
  8. 25
      cmd/mist/gui.go
  9. 29
      cmd/mist/main.go
  10. 8
      cmd/mist/ui_lib.go
  11. 58
      cmd/peerserver/main.go
  12. 8
      cmd/utils/cmd.go
  13. 13
      core/block_processor.go
  14. 8
      core/helper_test.go
  15. 35
      crypto/crypto.go
  16. 2
      crypto/crypto_test.go
  17. 3
      crypto/key.go
  18. 129
      eth/backend.go
  19. 3
      eth/protocol.go
  20. 24
      eth/protocol_test.go
  21. 7
      javascript/javascript_runtime.go
  22. 63
      p2p/client_identity.go
  23. 30
      p2p/client_identity_test.go
  24. 363
      p2p/crypto.go
  25. 167
      p2p/crypto_test.go
  26. 291
      p2p/discover/node.go
  27. 201
      p2p/discover/node_test.go
  28. 280
      p2p/discover/table.go
  29. 311
      p2p/discover/table_test.go
  30. 431
      p2p/discover/udp.go
  31. 211
      p2p/discover/udp_test.go
  32. 143
      p2p/message.go
  33. 140
      p2p/message_test.go
  34. 23
      p2p/nat.go
  35. 235
      p2p/nat/nat.go
  36. 115
      p2p/nat/natpmp.go
  37. 149
      p2p/nat/natupnp.go
  38. 55
      p2p/natpmp.go
  39. 341
      p2p/natupnp.go
  40. 472
      p2p/peer.go
  41. 58
      p2p/peer_error.go
  42. 297
      p2p/peer_test.go
  43. 244
      p2p/protocol.go
  44. 158
      p2p/protocol_test.go
  45. 402
      p2p/server.go
  46. 87
      p2p/server_test.go
  47. 2
      p2p/testlog_test.go
  48. 40
      p2p/testpoc7.go
  49. 22
      rlp/encode.go
  50. 6
      rlp/encode_test.go
  51. 2
      xeth/types.go
  52. 1
      xeth/xeth.go

@ -0,0 +1,93 @@
/*
This file is part of go-ethereum
go-ethereum is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
go-ethereum is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.
You should have received a copy of the GNU General Public License
along with go-ethereum. If not, see <http://www.gnu.org/licenses/>.
*/
// Command bootnode runs a bootstrap node for the Discovery Protocol.
package main
import (
"crypto/ecdsa"
"encoding/hex"
"flag"
"fmt"
"io/ioutil"
"log"
"os"
"github.com/ethereum/go-ethereum/crypto"
"github.com/ethereum/go-ethereum/logger"
"github.com/ethereum/go-ethereum/p2p/discover"
"github.com/ethereum/go-ethereum/p2p/nat"
)
func main() {
var (
listenAddr = flag.String("addr", ":30301", "listen address")
genKey = flag.String("genkey", "", "generate a node key and quit")
nodeKeyFile = flag.String("nodekey", "", "private key filename")
nodeKeyHex = flag.String("nodekeyhex", "", "private key as hex (for testing)")
natdesc = flag.String("nat", "none", "port mapping mechanism (any|none|upnp|pmp|extip:<IP>)")
nodeKey *ecdsa.PrivateKey
err error
)
flag.Parse()
logger.AddLogSystem(logger.NewStdLogSystem(os.Stdout, log.LstdFlags, logger.DebugLevel))
if *genKey != "" {
writeKey(*genKey)
os.Exit(0)
}
natm, err := nat.Parse(*natdesc)
if err != nil {
log.Fatalf("-nat: %v", err)
}
switch {
case *nodeKeyFile == "" && *nodeKeyHex == "":
log.Fatal("Use -nodekey or -nodekeyhex to specify a private key")
case *nodeKeyFile != "" && *nodeKeyHex != "":
log.Fatal("Options -nodekey and -nodekeyhex are mutually exclusive")
case *nodeKeyFile != "":
if nodeKey, err = crypto.LoadECDSA(*nodeKeyFile); err != nil {
log.Fatalf("-nodekey: %v", err)
}
case *nodeKeyHex != "":
if nodeKey, err = crypto.HexToECDSA(*nodeKeyHex); err != nil {
log.Fatalf("-nodekeyhex: %v", err)
}
}
if _, err := discover.ListenUDP(nodeKey, *listenAddr, natm); err != nil {
log.Fatal(err)
}
select {}
}
func writeKey(target string) {
key, err := crypto.GenerateKey()
if err != nil {
log.Fatal("could not generate key: %v", err)
}
b := crypto.FromECDSA(key)
if target == "-" {
fmt.Println(hex.EncodeToString(b))
} else {
if err := ioutil.WriteFile(target, b, 0600); err != nil {
log.Fatal("write error: ", err)
}
}
}

@ -21,6 +21,7 @@
package main package main
import ( import (
"crypto/ecdsa"
"flag" "flag"
"fmt" "fmt"
"log" "log"
@ -28,7 +29,9 @@ import (
"os/user" "os/user"
"path" "path"
"github.com/ethereum/go-ethereum/crypto"
"github.com/ethereum/go-ethereum/logger" "github.com/ethereum/go-ethereum/logger"
"github.com/ethereum/go-ethereum/p2p/nat"
"github.com/ethereum/go-ethereum/vm" "github.com/ethereum/go-ethereum/vm"
) )
@ -42,14 +45,14 @@ var (
StartWebSockets bool StartWebSockets bool
RpcPort int RpcPort int
WsPort int WsPort int
NatType string
PMPGateway string
OutboundPort string OutboundPort string
ShowGenesis bool ShowGenesis bool
AddPeer string AddPeer string
MaxPeer int MaxPeer int
GenAddr bool GenAddr bool
SeedNode string BootNodes string
NodeKey *ecdsa.PrivateKey
NAT nat.Interface
SecretFile string SecretFile string
ExportDir string ExportDir string
NonInteractive bool NonInteractive bool
@ -84,6 +87,7 @@ func defaultDataDir() string {
var defaultConfigFile = path.Join(defaultDataDir(), "conf.ini") var defaultConfigFile = path.Join(defaultDataDir(), "conf.ini")
func Init() { func Init() {
// TODO: move common flag processing to cmd/util
flag.Usage = func() { flag.Usage = func() {
fmt.Fprintf(os.Stderr, "%s [options] [filename]:\noptions precedence: default < config file < environment variables < command line\n", os.Args[0]) fmt.Fprintf(os.Stderr, "%s [options] [filename]:\noptions precedence: default < config file < environment variables < command line\n", os.Args[0])
flag.PrintDefaults() flag.PrintDefaults()
@ -93,18 +97,12 @@ func Init() {
flag.StringVar(&Identifier, "id", "", "Custom client identifier") flag.StringVar(&Identifier, "id", "", "Custom client identifier")
flag.StringVar(&KeyRing, "keyring", "", "identifier for keyring to use") flag.StringVar(&KeyRing, "keyring", "", "identifier for keyring to use")
flag.StringVar(&KeyStore, "keystore", "db", "system to store keyrings: db|file (db)") flag.StringVar(&KeyStore, "keystore", "db", "system to store keyrings: db|file (db)")
flag.StringVar(&OutboundPort, "port", "30303", "listening port")
flag.StringVar(&NatType, "nat", "", "NAT support (UPNP|PMP) (none)")
flag.StringVar(&PMPGateway, "pmp", "", "Gateway IP for PMP")
flag.IntVar(&MaxPeer, "maxpeer", 30, "maximum desired peers")
flag.IntVar(&RpcPort, "rpcport", 8545, "port to start json-rpc server on") flag.IntVar(&RpcPort, "rpcport", 8545, "port to start json-rpc server on")
flag.IntVar(&WsPort, "wsport", 40404, "port to start websocket rpc server on") flag.IntVar(&WsPort, "wsport", 40404, "port to start websocket rpc server on")
flag.BoolVar(&StartRpc, "rpc", false, "start rpc server") flag.BoolVar(&StartRpc, "rpc", false, "start rpc server")
flag.BoolVar(&StartWebSockets, "ws", false, "start websocket server") flag.BoolVar(&StartWebSockets, "ws", false, "start websocket server")
flag.BoolVar(&NonInteractive, "y", false, "non-interactive mode (say yes to confirmations)") flag.BoolVar(&NonInteractive, "y", false, "non-interactive mode (say yes to confirmations)")
flag.StringVar(&SeedNode, "seednode", "poc-8.ethdev.com:30303", "ip:port of seed node to connect to. Set to blank for skip")
flag.BoolVar(&SHH, "shh", true, "whisper protocol (on)")
flag.BoolVar(&Dial, "dial", true, "dial out connections (on)")
flag.BoolVar(&GenAddr, "genaddr", false, "create a new priv/pub key") flag.BoolVar(&GenAddr, "genaddr", false, "create a new priv/pub key")
flag.StringVar(&SecretFile, "import", "", "imports the file given (hex or mnemonic formats)") flag.StringVar(&SecretFile, "import", "", "imports the file given (hex or mnemonic formats)")
flag.StringVar(&ExportDir, "export", "", "exports the session keyring to files in the directory given") flag.StringVar(&ExportDir, "export", "", "exports the session keyring to files in the directory given")
@ -127,8 +125,38 @@ func Init() {
flag.BoolVar(&StartJsConsole, "js", false, "launches javascript console") flag.BoolVar(&StartJsConsole, "js", false, "launches javascript console")
flag.BoolVar(&PrintVersion, "version", false, "prints version number") flag.BoolVar(&PrintVersion, "version", false, "prints version number")
// Network stuff
var (
nodeKeyFile = flag.String("nodekey", "", "network private key file")
nodeKeyHex = flag.String("nodekeyhex", "", "network private key (for testing)")
natstr = flag.String("nat", "any", "port mapping mechanism (any|none|upnp|pmp|extip:<IP>)")
)
flag.BoolVar(&Dial, "dial", true, "dial out connections (default on)")
flag.BoolVar(&SHH, "shh", true, "run whisper protocol (default on)")
flag.StringVar(&OutboundPort, "port", "30303", "listening port")
flag.StringVar(&BootNodes, "bootnodes", "", "space-separated node URLs for discovery bootstrap")
flag.IntVar(&MaxPeer, "maxpeer", 30, "maximum desired peers")
flag.Parse() flag.Parse()
var err error
if NAT, err = nat.Parse(*natstr); err != nil {
log.Fatalf("-nat: %v", err)
}
switch {
case *nodeKeyFile != "" && *nodeKeyHex != "":
log.Fatal("Options -nodekey and -nodekeyhex are mutually exclusive")
case *nodeKeyFile != "":
if NodeKey, err = crypto.LoadECDSA(*nodeKeyFile); err != nil {
log.Fatalf("-nodekey: %v", err)
}
case *nodeKeyHex != "":
if NodeKey, err = crypto.HexToECDSA(*nodeKeyHex); err != nil {
log.Fatalf("-nodekeyhex: %v", err)
}
}
if VmType >= int(vm.MaxVmTy) { if VmType >= int(vm.MaxVmTy) {
log.Fatal("Invalid VM type ", VmType) log.Fatal("Invalid VM type ", VmType)
} }

@ -31,6 +31,7 @@ import (
"github.com/ethereum/go-ethereum/eth" "github.com/ethereum/go-ethereum/eth"
"github.com/ethereum/go-ethereum/ethutil" "github.com/ethereum/go-ethereum/ethutil"
"github.com/ethereum/go-ethereum/logger" "github.com/ethereum/go-ethereum/logger"
"github.com/ethereum/go-ethereum/p2p"
"github.com/ethereum/go-ethereum/state" "github.com/ethereum/go-ethereum/state"
) )
@ -61,21 +62,19 @@ func main() {
utils.InitConfig(VmType, ConfigFile, Datadir, "ETH") utils.InitConfig(VmType, ConfigFile, Datadir, "ETH")
ethereum, err := eth.New(&eth.Config{ ethereum, err := eth.New(&eth.Config{
Name: ClientIdentifier, Name: p2p.MakeName(ClientIdentifier, Version),
Version: Version, KeyStore: KeyStore,
KeyStore: KeyStore, DataDir: Datadir,
DataDir: Datadir, LogFile: LogFile,
LogFile: LogFile, LogLevel: LogLevel,
LogLevel: LogLevel, MaxPeers: MaxPeer,
LogFormat: LogFormat, Port: OutboundPort,
Identifier: Identifier, NAT: NAT,
MaxPeers: MaxPeer, KeyRing: KeyRing,
Port: OutboundPort, Shh: SHH,
NATType: PMPGateway, Dial: Dial,
PMPGateway: PMPGateway, BootNodes: BootNodes,
KeyRing: KeyRing, NodeKey: NodeKey,
Shh: SHH,
Dial: Dial,
}) })
if err != nil { if err != nil {
@ -135,7 +134,7 @@ func main() {
utils.StartWebSockets(ethereum, WsPort) utils.StartWebSockets(ethereum, WsPort)
} }
utils.StartEthereum(ethereum, SeedNode) utils.StartEthereum(ethereum)
if StartJsConsole { if StartJsConsole {
InitJsConsole(ethereum) InitJsConsole(ethereum)

@ -11,6 +11,7 @@ import "../ext/http.js" as Http
ApplicationWindow { ApplicationWindow {
<<<<<<< HEAD
id: root id: root
//flags: Qt.FramelessWindowHint //flags: Qt.FramelessWindowHint
@ -1102,3 +1103,869 @@ ApplicationWindow {
} }
} }
} }
=======
id: root
property var ethx : Eth.ethx
width: 1200
height: 820
minimumWidth: 300
title: "Mist"
TextField {
id: copyElementHax
visible: false
}
function copyToClipboard(text) {
copyElementHax.text = text
copyElementHax.selectAll()
copyElementHax.copy()
}
// Takes care of loading all default plugins
Component.onCompleted: {
var wallet = addPlugin("./views/wallet.qml", {noAdd: true, close: false, section: "ethereum", active: true});
addPlugin("./views/miner.qml", {noAdd: true, close: false, section: "ethereum", active: true});
addPlugin("./views/transaction.qml", {noAdd: true, close: false, section: "legacy"});
addPlugin("./views/whisper.qml", {noAdd: true, close: false, section: "legacy"});
addPlugin("./views/chain.qml", {noAdd: true, close: false, section: "legacy"});
addPlugin("./views/pending_tx.qml", {noAdd: true, close: false, section: "legacy"});
addPlugin("./views/info.qml", {noAdd: true, close: false, section: "legacy"});
mainSplit.setView(wallet.view, wallet.menuItem);
newBrowserTab(eth.assetPath("html/home.html"));
// Command setup
gui.sendCommand(0)
}
function activeView(view, menuItem) {
mainSplit.setView(view, menuItem)
if (view.hideUrl) {
urlPane.visible = false;
mainView.anchors.top = rootView.top
} else {
urlPane.visible = true;
mainView.anchors.top = divider.bottom
}
}
function addViews(view, path, options) {
var views = mainSplit.addComponent(view, options)
views.menuItem.path = path
mainSplit.views.push(views);
if(!options.noAdd) {
gui.addPlugin(path)
}
return views
}
function addPlugin(path, options) {
try {
if(typeof(path) === "string" && /^https?/.test(path)) {
console.log('load http')
Http.request(path, function(o) {
if(o.status === 200) {
var view = Qt.createQmlObject(o.responseText, mainView, path)
addViews(view, path, options)
}
})
return
}
var component = Qt.createComponent(path);
if(component.status != Component.Ready) {
if(component.status == Component.Error) {
ethx.note("error: ", component.errorString());
}
return
}
var view = mainView.createView(component, options)
var views = addViews(view, path, options)
return views
} catch(e) {
console.log(e)
}
}
function newBrowserTab(url) {
var window = addPlugin("./views/browser.qml", {noAdd: true, close: true, section: "apps", active: true});
window.view.url = url;
window.menuItem.title = "Mist";
activeView(window.view, window.menuItem);
}
menuBar: MenuBar {
Menu {
title: "File"
MenuItem {
text: "Import App"
shortcut: "Ctrl+o"
onTriggered: {
generalFileDialog.show(true, importApp)
}
}
MenuItem {
text: "Add plugin"
onTriggered: {
generalFileDialog.show(true, function(path) {
addPlugin(path, {close: true, section: "apps"})
})
}
}
MenuItem {
text: "New tab"
shortcut: "Ctrl+t"
onTriggered: {
newBrowserTab("about:blank");
}
}
MenuSeparator {}
MenuItem {
text: "Import key"
shortcut: "Ctrl+i"
onTriggered: {
generalFileDialog.show(true, function(path) {
gui.importKey(path)
})
}
}
MenuItem {
text: "Export keys"
shortcut: "Ctrl+e"
onTriggered: {
generalFileDialog.show(false, function(path) {
})
}
}
}
Menu {
title: "Developer"
MenuItem {
iconSource: "../icecream.png"
text: "Debugger"
shortcut: "Ctrl+d"
onTriggered: eth.startDebugger()
}
MenuItem {
text: "Import Tx"
onTriggered: {
txImportDialog.visible = true
}
}
MenuItem {
text: "Run JS file"
onTriggered: {
generalFileDialog.show(true, function(path) {
eth.evalJavascriptFile(path)
})
}
}
MenuItem {
text: "Dump state"
onTriggered: {
generalFileDialog.show(false, function(path) {
// Empty hash for latest
gui.dumpState("", path)
})
}
}
MenuSeparator {}
}
Menu {
title: "Network"
MenuItem {
text: "Connect to Node"
shortcut: "Ctrl+p"
onTriggered: {
addPeerWin.visible = true
}
}
MenuItem {
text: "Show Peers"
shortcut: "Ctrl+e"
onTriggered: {
peerWindow.visible = true
}
}
}
Menu {
title: "Help"
MenuItem {
text: "About"
onTriggered: {
aboutWin.visible = true
}
}
}
Menu {
title: "GLOBAL SHORTCUTS"
visible: false
MenuItem {
visible: false
shortcut: "Ctrl+l"
onTriggered: {
url.focus = true
}
}
}
}
statusBar: StatusBar {
//height: 32
id: statusBar
Label {
//y: 6
id: walletValueLabel
font.pixelSize: 10
styleColor: "#797979"
}
Label {
//y: 6
objectName: "miningLabel"
visible: true
font.pixelSize: 10
anchors.right: lastBlockLabel.left
anchors.rightMargin: 5
}
Label {
//y: 6
id: lastBlockLabel
objectName: "lastBlockLabel"
visible: true
text: ""
font.pixelSize: 10
anchors.right: peerGroup.left
anchors.rightMargin: 5
}
ProgressBar {
visible: false
id: downloadIndicator
value: 0
objectName: "downloadIndicator"
y: -4
x: statusBar.width / 2 - this.width / 2
width: 160
}
Label {
visible: false
objectName: "downloadLabel"
//y: 7
anchors.left: downloadIndicator.right
anchors.leftMargin: 5
font.pixelSize: 10
text: "0 / 0"
}
RowLayout {
id: peerGroup
//y: 7
anchors.right: parent.right
MouseArea {
onDoubleClicked: peerWindow.visible = true
anchors.fill: parent
}
Label {
id: peerLabel
font.pixelSize: 10
text: "0 / 0"
}
}
}
property var blockModel: ListModel {
id: blockModel
}
SplitView {
property var views: [];
id: mainSplit
anchors.fill: parent
resizing: false
function setView(view, menu) {
for(var i = 0; i < views.length; i++) {
views[i].view.visible = false
views[i].menuItem.setSelection(false)
}
view.visible = true
menu.setSelection(true)
}
function addComponent(view, options) {
view.visible = false
view.anchors.fill = mainView
var menuItem = menu.createMenuItem(view, options);
if( view.hasOwnProperty("menuItem") ) {
view.menuItem = menuItem;
}
if( view.hasOwnProperty("onReady") ) {
view.onReady.call(view)
}
if( options.active ) {
setView(view, menuItem)
}
return {view: view, menuItem: menuItem}
}
/*********************
* Main menu.
********************/
Rectangle {
id: menu
Layout.minimumWidth: 210
Layout.maximumWidth: 210
anchors.top: parent.top
color: "#ececec"
Component {
id: menuItemTemplate
Rectangle {
id: menuItem
property var view;
property var path;
property var closable;
property alias title: label.text
property alias icon: icon.source
property alias secondaryTitle: secondary.text
function setSelection(on) {
sel.visible = on
}
width: 206
height: 28
color: "#00000000"
anchors {
left: parent.left
leftMargin: 4
}
Rectangle {
id: sel
visible: false
anchors.fill: parent
color: "#00000000"
Rectangle {
id: r
anchors.fill: parent
border.color: "#CCCCCC"
border.width: 1
radius: 5
color: "#FFFFFFFF"
}
Rectangle {
anchors {
top: r.top
bottom: r.bottom
right: r.right
}
width: 10
color: "#FFFFFFFF"
Rectangle {
anchors {
left: parent.left
right: parent.right
top: parent.top
}
height: 1
color: "#CCCCCC"
}
Rectangle {
anchors {
left: parent.left
right: parent.right
bottom: parent.bottom
}
height: 1
color: "#CCCCCC"
}
}
}
MouseArea {
anchors.fill: parent
onClicked: {
activeView(view, menuItem);
}
}
Image {
id: icon
height: 20
width: 20
anchors {
left: parent.left
verticalCenter: parent.verticalCenter
leftMargin: 3
}
MouseArea {
anchors.fill: parent
onClicked: {
menuItem.closeApp()
}
}
}
Text {
id: label
anchors {
left: icon.right
verticalCenter: parent.verticalCenter
leftMargin: 3
}
color: "#0D0A01"
font.pixelSize: 12
}
Text {
id: secondary
anchors {
right: parent.right
rightMargin: 8
verticalCenter: parent.verticalCenter
}
color: "#AEADBE"
font.pixelSize: 12
}
function closeApp() {
if(!this.closable) { return; }
if(this.view.hasOwnProperty("onDestroy")) {
this.view.onDestroy.call(this.view)
}
this.view.destroy()
this.destroy()
for (var i = 0; i < mainSplit.views.length; i++) {
var view = mainSplit.views[i];
if (view.menuItem === this) {
mainSplit.views.splice(i, 1);
break;
}
}
gui.removePlugin(this.path)
activeView(mainSplit.views[0].view, mainSplit.views[0].menuItem);
}
}
}
function createMenuItem(view, options) {
if(options === undefined) {
options = {};
}
var section;
switch(options.section) {
case "ethereum":
section = menuDefault;
break;
case "legacy":
section = menuLegacy;
break;
default:
section = menuApps;
break;
}
var comp = menuItemTemplate.createObject(section)
comp.view = view
comp.title = view.title
if(view.hasOwnProperty("iconSource")) {
comp.icon = view.iconSource;
}
comp.closable = options.close;
return comp
}
ColumnLayout {
id: menuColumn
y: 10
width: parent.width
anchors.left: parent.left
anchors.right: parent.right
spacing: 3
Text {
text: "ETHEREUM"
font.bold: true
anchors {
left: parent.left
leftMargin: 5
}
color: "#888888"
}
ColumnLayout {
id: menuDefault
spacing: 3
anchors {
left: parent.left
right: parent.right
}
}
Text {
text: "NET"
font.bold: true
anchors {
left: parent.left
leftMargin: 5
}
color: "#888888"
}
ColumnLayout {
id: menuApps
spacing: 3
anchors {
left: parent.left
right: parent.right
}
}
Text {
text: "DEBUG"
font.bold: true
anchors {
left: parent.left
leftMargin: 5
}
color: "#888888"
}
ColumnLayout {
id: menuLegacy
spacing: 3
anchors {
left: parent.left
right: parent.right
}
}
}
}
/*********************
* Main view
********************/
Rectangle {
id: rootView
anchors.right: parent.right
anchors.left: menu.right
anchors.bottom: parent.bottom
anchors.top: parent.top
color: "#00000000"
Rectangle {
id: urlPane
height: 40
color: "#00000000"
anchors {
left: parent.left
right: parent.right
leftMargin: 5
rightMargin: 5
top: parent.top
topMargin: 5
}
TextField {
id: url
objectName: "url"
placeholderText: "DApp URL"
anchors {
left: parent.left
right: parent.right
top: parent.top
topMargin: 5
rightMargin: 5
leftMargin: 5
}
Keys.onReturnPressed: {
if(/^https?/.test(this.text)) {
newBrowserTab(this.text);
} else {
addPlugin(this.text, {close: true, section: "apps"})
}
}
}
}
// Border
Rectangle {
id: divider
anchors {
left: parent.left
right: parent.right
top: urlPane.bottom
}
z: -1
height: 1
color: "#CCCCCC"
}
Rectangle {
id: mainView
color: "#00000000"
anchors.right: parent.right
anchors.left: parent.left
anchors.bottom: parent.bottom
anchors.top: divider.bottom
function createView(component) {
var view = component.createObject(mainView)
return view;
}
}
}
}
/******************
* Dialogs
*****************/
FileDialog {
id: generalFileDialog
property var callback;
onAccepted: {
var path = this.fileUrl.toString();
callback.call(this, path);
}
function show(selectExisting, callback) {
generalFileDialog.callback = callback;
generalFileDialog.selectExisting = selectExisting;
this.open();
}
}
/******************
* Wallet functions
*****************/
function importApp(path) {
var ext = path.split('.').pop()
if(ext == "html" || ext == "htm") {
eth.openHtml(path)
}else if(ext == "qml"){
addPlugin(path, {close: true, section: "apps"})
}
}
function setWalletValue(value) {
walletValueLabel.text = value
}
function loadPlugin(name) {
console.log("Loading plugin" + name)
var view = mainView.addPlugin(name)
}
function setPeers(text) {
peerLabel.text = text
}
function addPeer(peer) {
// We could just append the whole peer object but it cries if you try to alter them
peerModel.append({ip: peer.ip, port: peer.port, lastResponse:timeAgo(peer.lastSend), latency: peer.latency, version: peer.version, caps: peer.caps})
}
function resetPeers(){
peerModel.clear()
}
function timeAgo(unixTs){
var lapsed = (Date.now() - new Date(unixTs*1000)) / 1000
return (lapsed + " seconds ago")
}
function convertToPretty(unixTs){
var a = new Date(unixTs*1000);
var months = ['Jan','Feb','Mar','Apr','May','Jun','Jul','Aug','Sep','Oct','Nov','Dec'];
var year = a.getFullYear();
var month = months[a.getMonth()];
var date = a.getDate();
var hour = a.getHours();
var min = a.getMinutes();
var sec = a.getSeconds();
var time = date+' '+month+' '+year+' '+hour+':'+min+':'+sec ;
return time;
}
/**********************
* Windows
*********************/
Window {
id: peerWindow
//flags: Qt.CustomizeWindowHint | Qt.Tool | Qt.WindowCloseButtonHint
height: 200
width: 700
Rectangle {
anchors.fill: parent
property var peerModel: ListModel {
id: peerModel
}
TableView {
anchors.fill: parent
id: peerTable
model: peerModel
TableViewColumn{width: 200; role: "ip" ; title: "IP" }
TableViewColumn{width: 260; role: "version" ; title: "Version" }
TableViewColumn{width: 180; role: "caps" ; title: "Capabilities" }
}
}
}
Window {
id: aboutWin
visible: false
title: "About"
minimumWidth: 350
maximumWidth: 350
maximumHeight: 280
minimumHeight: 280
Image {
id: aboutIcon
height: 150
width: 150
fillMode: Image.PreserveAspectFit
smooth: true
source: "../facet.png"
x: 10
y: 30
}
Text {
anchors.left: aboutIcon.right
anchors.leftMargin: 10
anchors.top: parent.top
anchors.topMargin: 30
font.pointSize: 12
text: "<h2>Mist (0.7.10)</h2><br><h3>Development</h3>Jeffrey Wilcke<br>Viktor Trón<br>Felix Lange<br>Taylor Gerring<br>Daniel Nagy<br><h3>UX</h3>Alex van de Sande<br>"
}
}
Window {
id: txImportDialog
minimumWidth: 270
maximumWidth: 270
maximumHeight: 50
minimumHeight: 50
TextField {
id: txImportField
width: 170
anchors.verticalCenter: parent.verticalCenter
anchors.left: parent.left
anchors.leftMargin: 10
onAccepted: {
}
}
Button {
anchors.left: txImportField.right
anchors.verticalCenter: parent.verticalCenter
anchors.leftMargin: 5
text: "Import"
onClicked: {
eth.importTx(txImportField.text)
txImportField.visible = false
}
}
Component.onCompleted: {
addrField.focus = true
}
}
Window {
id: addPeerWin
visible: false
minimumWidth: 400
maximumWidth: 400
maximumHeight: 50
minimumHeight: 50
title: "Connect to Node"
TextField {
id: addrField
placeholderText: "enode://<hex node id>:<IP address>:<port>"
anchors.verticalCenter: parent.verticalCenter
anchors.left: parent.left
anchors.right: addPeerButton.left
anchors.leftMargin: 10
anchors.rightMargin: 10
onAccepted: {
eth.connectToPeer(addrField.text)
addPeerWin.visible = false
}
}
Button {
id: addPeerButton
anchors.right: parent.right
anchors.verticalCenter: parent.verticalCenter
anchors.rightMargin: 10
text: "Connect"
onClicked: {
eth.connectToPeer(addrField.text)
addPeerWin.visible = false
}
}
Component.onCompleted: {
addrField.focus = true
}
}
}
>>>>>>> 32a9c0ca809508c1648b8f44f3e09725af7a80d3

@ -32,18 +32,6 @@ Rectangle {
width: 500 width: 500
} }
Label {
text: "Client ID"
}
TextField {
text: gui.getCustomIdentifier()
width: 500
placeholderText: "Anonymous"
onTextChanged: {
gui.setCustomIdentifier(text)
}
}
TextArea { TextArea {
objectName: "statsPane" objectName: "statsPane"
width: parent.width width: parent.width

@ -64,15 +64,6 @@ func (gui *Gui) Transact(recipient, value, gas, gasPrice, d string) (string, err
return gui.xeth.Transact(recipient, value, gas, gasPrice, data) return gui.xeth.Transact(recipient, value, gas, gasPrice, data)
} }
func (gui *Gui) SetCustomIdentifier(customIdentifier string) {
gui.clientIdentity.SetCustomIdentifier(customIdentifier)
gui.config.Save("id", customIdentifier)
}
func (gui *Gui) GetCustomIdentifier() string {
return gui.clientIdentity.GetCustomIdentifier()
}
// functions that allow Gui to implement interface guilogger.LogSystem // functions that allow Gui to implement interface guilogger.LogSystem
func (gui *Gui) SetLogLevel(level logger.LogLevel) { func (gui *Gui) SetLogLevel(level logger.LogLevel) {
gui.logLevel = level gui.logLevel = level

@ -21,6 +21,7 @@
package main package main
import ( import (
"crypto/ecdsa"
"flag" "flag"
"fmt" "fmt"
"log" "log"
@ -31,7 +32,9 @@ import (
"runtime" "runtime"
"bitbucket.org/kardianos/osext" "bitbucket.org/kardianos/osext"
"github.com/ethereum/go-ethereum/crypto"
"github.com/ethereum/go-ethereum/logger" "github.com/ethereum/go-ethereum/logger"
"github.com/ethereum/go-ethereum/p2p/nat"
"github.com/ethereum/go-ethereum/vm" "github.com/ethereum/go-ethereum/vm"
) )
@ -39,19 +42,18 @@ var (
Identifier string Identifier string
KeyRing string KeyRing string
KeyStore string KeyStore string
PMPGateway string
StartRpc bool StartRpc bool
StartWebSockets bool StartWebSockets bool
RpcPort int RpcPort int
WsPort int WsPort int
UseUPnP bool
NatType string
OutboundPort string OutboundPort string
ShowGenesis bool ShowGenesis bool
AddPeer string AddPeer string
MaxPeer int MaxPeer int
GenAddr bool GenAddr bool
SeedNode string BootNodes string
NodeKey *ecdsa.PrivateKey
NAT nat.Interface
SecretFile string SecretFile string
ExportDir string ExportDir string
NonInteractive bool NonInteractive bool
@ -99,6 +101,7 @@ func defaultDataDir() string {
var defaultConfigFile = path.Join(defaultDataDir(), "conf.ini") var defaultConfigFile = path.Join(defaultDataDir(), "conf.ini")
func Init() { func Init() {
// TODO: move common flag processing to cmd/utils
flag.Usage = func() { flag.Usage = func() {
fmt.Fprintf(os.Stderr, "%s [options] [filename]:\noptions precedence: default < config file < environment variables < command line\n", os.Args[0]) fmt.Fprintf(os.Stderr, "%s [options] [filename]:\noptions precedence: default < config file < environment variables < command line\n", os.Args[0])
flag.PrintDefaults() flag.PrintDefaults()
@ -108,30 +111,51 @@ func Init() {
flag.StringVar(&Identifier, "id", "", "Custom client identifier") flag.StringVar(&Identifier, "id", "", "Custom client identifier")
flag.StringVar(&KeyRing, "keyring", "", "identifier for keyring to use") flag.StringVar(&KeyRing, "keyring", "", "identifier for keyring to use")
flag.StringVar(&KeyStore, "keystore", "db", "system to store keyrings: db|file (db)") flag.StringVar(&KeyStore, "keystore", "db", "system to store keyrings: db|file (db)")
flag.StringVar(&OutboundPort, "port", "30303", "listening port")
flag.BoolVar(&UseUPnP, "upnp", true, "enable UPnP support")
flag.IntVar(&MaxPeer, "maxpeer", 30, "maximum desired peers")
flag.IntVar(&RpcPort, "rpcport", 8545, "port to start json-rpc server on") flag.IntVar(&RpcPort, "rpcport", 8545, "port to start json-rpc server on")
flag.IntVar(&WsPort, "wsport", 40404, "port to start websocket rpc server on") flag.IntVar(&WsPort, "wsport", 40404, "port to start websocket rpc server on")
flag.BoolVar(&StartRpc, "rpc", true, "start rpc server") flag.BoolVar(&StartRpc, "rpc", true, "start rpc server")
flag.BoolVar(&StartWebSockets, "ws", false, "start websocket server") flag.BoolVar(&StartWebSockets, "ws", false, "start websocket server")
flag.BoolVar(&NonInteractive, "y", false, "non-interactive mode (say yes to confirmations)") flag.BoolVar(&NonInteractive, "y", false, "non-interactive mode (say yes to confirmations)")
flag.StringVar(&SeedNode, "seednode", "poc-8.ethdev.com:30303", "ip:port of seed node to connect to. Set to blank for skip")
flag.BoolVar(&GenAddr, "genaddr", false, "create a new priv/pub key") flag.BoolVar(&GenAddr, "genaddr", false, "create a new priv/pub key")
flag.StringVar(&NatType, "nat", "", "NAT support (UPNP|PMP) (none)")
flag.StringVar(&SecretFile, "import", "", "imports the file given (hex or mnemonic formats)") flag.StringVar(&SecretFile, "import", "", "imports the file given (hex or mnemonic formats)")
flag.StringVar(&ExportDir, "export", "", "exports the session keyring to files in the directory given") flag.StringVar(&ExportDir, "export", "", "exports the session keyring to files in the directory given")
flag.StringVar(&LogFile, "logfile", "", "log file (defaults to standard output)") flag.StringVar(&LogFile, "logfile", "", "log file (defaults to standard output)")
flag.StringVar(&Datadir, "datadir", defaultDataDir(), "specifies the datadir to use") flag.StringVar(&Datadir, "datadir", defaultDataDir(), "specifies the datadir to use")
flag.StringVar(&PMPGateway, "pmp", "", "Gateway IP for PMP")
flag.StringVar(&ConfigFile, "conf", defaultConfigFile, "config file") flag.StringVar(&ConfigFile, "conf", defaultConfigFile, "config file")
flag.StringVar(&DebugFile, "debug", "", "debug file (no debugging if not set)") flag.StringVar(&DebugFile, "debug", "", "debug file (no debugging if not set)")
flag.IntVar(&LogLevel, "loglevel", int(logger.InfoLevel), "loglevel: 0-5: silent,error,warn,info,debug,debug detail)") flag.IntVar(&LogLevel, "loglevel", int(logger.InfoLevel), "loglevel: 0-5: silent,error,warn,info,debug,debug detail)")
flag.StringVar(&AssetPath, "asset_path", defaultAssetPath(), "absolute path to GUI assets directory") flag.StringVar(&AssetPath, "asset_path", defaultAssetPath(), "absolute path to GUI assets directory")
// Network stuff
var (
nodeKeyFile = flag.String("nodekey", "", "network private key file")
nodeKeyHex = flag.String("nodekeyhex", "", "network private key (for testing)")
natstr = flag.String("nat", "any", "port mapping mechanism (any|none|upnp|pmp|extip:<IP>)")
)
flag.StringVar(&OutboundPort, "port", "30303", "listening port")
flag.StringVar(&BootNodes, "bootnodes", "", "space-separated node URLs for discovery bootstrap")
flag.IntVar(&MaxPeer, "maxpeer", 30, "maximum desired peers")
flag.Parse() flag.Parse()
var err error
if NAT, err = nat.Parse(*natstr); err != nil {
log.Fatalf("-nat: %v", err)
}
switch {
case *nodeKeyFile != "" && *nodeKeyHex != "":
log.Fatal("Options -nodekey and -nodekeyhex are mutually exclusive")
case *nodeKeyFile != "":
if NodeKey, err = crypto.LoadECDSA(*nodeKeyFile); err != nil {
log.Fatalf("-nodekey: %v", err)
}
case *nodeKeyHex != "":
if NodeKey, err = crypto.HexToECDSA(*nodeKeyHex); err != nil {
log.Fatalf("-nodekeyhex: %v", err)
}
}
if VmType >= int(vm.MaxVmTy) { if VmType >= int(vm.MaxVmTy) {
log.Fatal("Invalid VM type ", VmType) log.Fatal("Invalid VM type ", VmType)
} }

@ -41,7 +41,6 @@ import (
"github.com/ethereum/go-ethereum/ethutil" "github.com/ethereum/go-ethereum/ethutil"
"github.com/ethereum/go-ethereum/logger" "github.com/ethereum/go-ethereum/logger"
"github.com/ethereum/go-ethereum/miner" "github.com/ethereum/go-ethereum/miner"
"github.com/ethereum/go-ethereum/p2p"
"github.com/ethereum/go-ethereum/ui/qt/qwhisper" "github.com/ethereum/go-ethereum/ui/qt/qwhisper"
"github.com/ethereum/go-ethereum/xeth" "github.com/ethereum/go-ethereum/xeth"
"github.com/obscuren/qml" "github.com/obscuren/qml"
@ -77,9 +76,8 @@ type Gui struct {
xeth *xeth.XEth xeth *xeth.XEth
Session string Session string
clientIdentity *p2p.SimpleClientIdentity config *ethutil.ConfigManager
config *ethutil.ConfigManager
plugins map[string]plugin plugins map[string]plugin
@ -87,7 +85,7 @@ type Gui struct {
} }
// Create GUI, but doesn't start it // Create GUI, but doesn't start it
func NewWindow(ethereum *eth.Ethereum, config *ethutil.ConfigManager, clientIdentity *p2p.SimpleClientIdentity, session string, logLevel int) *Gui { func NewWindow(ethereum *eth.Ethereum, config *ethutil.ConfigManager, session string, logLevel int) *Gui {
db, err := ethdb.NewLDBDatabase("tx_database") db, err := ethdb.NewLDBDatabase("tx_database")
if err != nil { if err != nil {
panic(err) panic(err)
@ -95,15 +93,14 @@ func NewWindow(ethereum *eth.Ethereum, config *ethutil.ConfigManager, clientIden
xeth := xeth.New(ethereum) xeth := xeth.New(ethereum)
gui := &Gui{eth: ethereum, gui := &Gui{eth: ethereum,
txDb: db, txDb: db,
xeth: xeth, xeth: xeth,
logLevel: logger.LogLevel(logLevel), logLevel: logger.LogLevel(logLevel),
Session: session, Session: session,
open: false, open: false,
clientIdentity: clientIdentity, config: config,
config: config, plugins: make(map[string]plugin),
plugins: make(map[string]plugin), serviceEvents: make(chan ServEv, 1),
serviceEvents: make(chan ServEv, 1),
} }
data, _ := ethutil.ReadAllFile(path.Join(ethutil.Config.ExecPath, "plugins.json")) data, _ := ethutil.ReadAllFile(path.Join(ethutil.Config.ExecPath, "plugins.json"))
json.Unmarshal([]byte(data), &gui.plugins) json.Unmarshal([]byte(data), &gui.plugins)

@ -52,19 +52,18 @@ func run() error {
config := utils.InitConfig(VmType, ConfigFile, Datadir, "ETH") config := utils.InitConfig(VmType, ConfigFile, Datadir, "ETH")
ethereum, err := eth.New(&eth.Config{ ethereum, err := eth.New(&eth.Config{
Name: ClientIdentifier, Name: p2p.MakeName(ClientIdentifier, Version),
Version: Version, KeyStore: KeyStore,
KeyStore: KeyStore, DataDir: Datadir,
DataDir: Datadir, LogFile: LogFile,
LogFile: LogFile, LogLevel: LogLevel,
LogLevel: LogLevel, MaxPeers: MaxPeer,
Identifier: Identifier, Port: OutboundPort,
MaxPeers: MaxPeer, NAT: NAT,
Port: OutboundPort, BootNodes: BootNodes,
NATType: PMPGateway, NodeKey: NodeKey,
PMPGateway: PMPGateway, KeyRing: KeyRing,
KeyRing: KeyRing, Dial: true,
Dial: true,
}) })
if err != nil { if err != nil {
mainlogger.Fatalln(err) mainlogger.Fatalln(err)
@ -79,12 +78,12 @@ func run() error {
utils.StartWebSockets(ethereum, WsPort) utils.StartWebSockets(ethereum, WsPort)
} }
gui := NewWindow(ethereum, config, ethereum.ClientIdentity().(*p2p.SimpleClientIdentity), KeyRing, LogLevel) gui := NewWindow(ethereum, config, KeyRing, LogLevel)
utils.RegisterInterrupt(func(os.Signal) { utils.RegisterInterrupt(func(os.Signal) {
gui.Stop() gui.Stop()
}) })
go utils.StartEthereum(ethereum, SeedNode) go utils.StartEthereum(ethereum)
fmt.Println("ETH stack took", time.Since(tstart)) fmt.Println("ETH stack took", time.Since(tstart))

@ -136,15 +136,15 @@ func (ui *UiLib) Muted(content string) {
func (ui *UiLib) Connect(button qml.Object) { func (ui *UiLib) Connect(button qml.Object) {
if !ui.connected { if !ui.connected {
ui.eth.Start(SeedNode) ui.eth.Start()
ui.connected = true ui.connected = true
button.Set("enabled", false) button.Set("enabled", false)
} }
} }
func (ui *UiLib) ConnectToPeer(addr string) { func (ui *UiLib) ConnectToPeer(nodeURL string) {
if err := ui.eth.SuggestPeer(addr); err != nil { if err := ui.eth.SuggestPeer(nodeURL); err != nil {
guilogger.Infoln(err) guilogger.Infoln("SuggestPeer error: " + err.Error())
} }
} }

@ -1,58 +0,0 @@
/*
This file is part of go-ethereum
go-ethereum is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
go-ethereum is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.
You should have received a copy of the GNU General Public License
along with go-ethereum. If not, see <http://www.gnu.org/licenses/>.
*/
package main
import (
"crypto/elliptic"
"flag"
"log"
"os"
"github.com/ethereum/go-ethereum/crypto"
"github.com/ethereum/go-ethereum/logger"
"github.com/ethereum/go-ethereum/p2p"
)
var (
natType = flag.String("nat", "", "NAT traversal implementation")
pmpGateway = flag.String("gateway", "", "gateway address for NAT-PMP")
listenAddr = flag.String("addr", ":30301", "listen address")
)
func main() {
flag.Parse()
nat, err := p2p.ParseNAT(*natType, *pmpGateway)
if err != nil {
log.Fatal("invalid nat:", err)
}
logger.AddLogSystem(logger.NewStdLogSystem(os.Stdout, log.LstdFlags, logger.InfoLevel))
key, _ := crypto.GenerateKey()
marshaled := elliptic.Marshal(crypto.S256(), key.PublicKey.X, key.PublicKey.Y)
srv := p2p.Server{
MaxPeers: 100,
Identity: p2p.NewSimpleClientIdentity("Ethereum(G)", "0.1", "Peer Server Two", marshaled),
ListenAddr: *listenAddr,
NAT: nat,
NoDial: true,
}
if err := srv.Start(); err != nil {
log.Fatal("could not start server:", err)
}
select {}
}

@ -121,13 +121,11 @@ func exit(err error) {
os.Exit(status) os.Exit(status)
} }
func StartEthereum(ethereum *eth.Ethereum, SeedNode string) { func StartEthereum(ethereum *eth.Ethereum) {
clilogger.Infof("Starting %s", ethereum.ClientIdentity()) clilogger.Infoln("Starting ", ethereum.Name())
err := ethereum.Start(SeedNode) if err := ethereum.Start(); err != nil {
if err != nil {
exit(err) exit(err)
} }
RegisterInterrupt(func(sig os.Signal) { RegisterInterrupt(func(sig os.Signal) {
ethereum.Stop() ethereum.Stop()
logger.Flush() logger.Flush()

@ -23,6 +23,19 @@ type PendingBlockEvent struct {
var statelogger = logger.NewLogger("BLOCK") var statelogger = logger.NewLogger("BLOCK")
type EthManager interface {
BlockProcessor() *BlockProcessor
ChainManager() *ChainManager
TxPool() *TxPool
PeerCount() int
IsMining() bool
IsListening() bool
Peers() []*p2p.Peer
KeyManager() *crypto.KeyManager
Db() ethutil.Database
EventMux() *event.TypeMux
}
type BlockProcessor struct { type BlockProcessor struct {
db ethutil.Database db ethutil.Database
// Mutex for locking the block processor. Blocks can only be handled one at a time // Mutex for locking the block processor. Blocks can only be handled one at a time

@ -9,7 +9,6 @@ import (
"github.com/ethereum/go-ethereum/ethdb" "github.com/ethereum/go-ethereum/ethdb"
"github.com/ethereum/go-ethereum/ethutil" "github.com/ethereum/go-ethereum/ethutil"
"github.com/ethereum/go-ethereum/event" "github.com/ethereum/go-ethereum/event"
"github.com/ethereum/go-ethereum/p2p"
) )
// Implement our EthTest Manager // Implement our EthTest Manager
@ -54,13 +53,6 @@ func (tm *TestManager) TxPool() *TxPool {
func (tm *TestManager) EventMux() *event.TypeMux { func (tm *TestManager) EventMux() *event.TypeMux {
return tm.eventMux return tm.eventMux
} }
func (tm *TestManager) Broadcast(msgType p2p.Msg, data []interface{}) {
fmt.Println("Broadcast not implemented")
}
func (tm *TestManager) ClientIdentity() p2p.ClientIdentity {
return nil
}
func (tm *TestManager) KeyManager() *crypto.KeyManager { func (tm *TestManager) KeyManager() *crypto.KeyManager {
return nil return nil
} }

@ -8,6 +8,8 @@ import (
"crypto/rand" "crypto/rand"
"crypto/sha256" "crypto/sha256"
"fmt" "fmt"
"io"
"os"
"encoding/hex" "encoding/hex"
"encoding/json" "encoding/json"
@ -27,10 +29,11 @@ func init() {
ecies.AddParamsForCurve(S256(), ecies.ECIES_AES128_SHA256) ecies.AddParamsForCurve(S256(), ecies.ECIES_AES128_SHA256)
} }
func Sha3(data []byte) []byte { func Sha3(data ...[]byte) []byte {
d := sha3.NewKeccak256() d := sha3.NewKeccak256()
d.Write(data) for _, b := range data {
d.Write(b)
}
return d.Sum(nil) return d.Sum(nil)
} }
@ -98,6 +101,32 @@ func FromECDSAPub(pub *ecdsa.PublicKey) []byte {
return elliptic.Marshal(S256(), pub.X, pub.Y) return elliptic.Marshal(S256(), pub.X, pub.Y)
} }
// HexToECDSA parses a secp256k1 private key.
func HexToECDSA(hexkey string) (*ecdsa.PrivateKey, error) {
b, err := hex.DecodeString(hexkey)
if err != nil {
return nil, errors.New("invalid hex string")
}
if len(b) != 32 {
return nil, errors.New("invalid length, need 256 bits")
}
return ToECDSA(b), nil
}
// LoadECDSA loads a secp256k1 private key from the given file.
func LoadECDSA(file string) (*ecdsa.PrivateKey, error) {
buf := make([]byte, 32)
fd, err := os.Open(file)
if err != nil {
return nil, err
}
defer fd.Close()
if _, err := io.ReadFull(fd, buf); err != nil {
return nil, err
}
return ToECDSA(buf), nil
}
func GenerateKey() (*ecdsa.PrivateKey, error) { func GenerateKey() (*ecdsa.PrivateKey, error) {
return ecdsa.GenerateKey(S256(), rand.Reader) return ecdsa.GenerateKey(S256(), rand.Reader)
} }

@ -18,7 +18,7 @@ import (
func TestSha3(t *testing.T) { func TestSha3(t *testing.T) {
msg := []byte("abc") msg := []byte("abc")
exp, _ := hex.DecodeString("4e03657aea45a94fc7d47ba826c8d667c0d1e6e33a64a036ec44f58fa12d6c45") exp, _ := hex.DecodeString("4e03657aea45a94fc7d47ba826c8d667c0d1e6e33a64a036ec44f58fa12d6c45")
checkhash(t, "Sha3-256", Sha3, msg, exp) checkhash(t, "Sha3-256", func(in []byte) []byte { return Sha3(in) }, msg, exp)
} }
func TestSha256(t *testing.T) { func TestSha256(t *testing.T) {

@ -25,11 +25,12 @@ package crypto
import ( import (
"bytes" "bytes"
"code.google.com/p/go-uuid/uuid"
"crypto/ecdsa" "crypto/ecdsa"
"crypto/elliptic" "crypto/elliptic"
"encoding/json" "encoding/json"
"io" "io"
"code.google.com/p/go-uuid/uuid"
) )
type Key struct { type Key struct {

@ -1,9 +1,9 @@
package eth package eth
import ( import (
"crypto/ecdsa"
"fmt" "fmt"
"net" "strings"
"sync"
"github.com/ethereum/go-ethereum/core" "github.com/ethereum/go-ethereum/core"
"github.com/ethereum/go-ethereum/crypto" "github.com/ethereum/go-ethereum/crypto"
@ -12,27 +12,35 @@ import (
"github.com/ethereum/go-ethereum/event" "github.com/ethereum/go-ethereum/event"
ethlogger "github.com/ethereum/go-ethereum/logger" ethlogger "github.com/ethereum/go-ethereum/logger"
"github.com/ethereum/go-ethereum/p2p" "github.com/ethereum/go-ethereum/p2p"
"github.com/ethereum/go-ethereum/p2p/discover"
"github.com/ethereum/go-ethereum/p2p/nat"
"github.com/ethereum/go-ethereum/pow/ezp" "github.com/ethereum/go-ethereum/pow/ezp"
"github.com/ethereum/go-ethereum/rpc" "github.com/ethereum/go-ethereum/rpc"
"github.com/ethereum/go-ethereum/whisper" "github.com/ethereum/go-ethereum/whisper"
) )
var logger = ethlogger.NewLogger("SERV")
type Config struct { type Config struct {
Name string Name string
Version string KeyStore string
Identifier string DataDir string
KeyStore string LogFile string
DataDir string LogLevel int
LogFile string KeyRing string
LogLevel int
LogFormat string MaxPeers int
KeyRing string Port string
MaxPeers int
Port string
NATType string
PMPGateway string
// This should be a space-separated list of
// discovery node URLs.
BootNodes string
// This key is used to identify the node on the network.
// If nil, an ephemeral key is used.
NodeKey *ecdsa.PrivateKey
NAT nat.Interface
Shh bool Shh bool
Dial bool Dial bool
@ -42,6 +50,22 @@ type Config struct {
var logger = ethlogger.NewLogger("SERV") var logger = ethlogger.NewLogger("SERV")
var jsonlogger = ethlogger.NewJsonLogger() var jsonlogger = ethlogger.NewJsonLogger()
func (cfg *Config) parseBootNodes() []*discover.Node {
var ns []*discover.Node
for _, url := range strings.Split(cfg.BootNodes, " ") {
if url == "" {
continue
}
n, err := discover.ParseNode(url)
if err != nil {
logger.Errorf("Bootstrap URL %s: %v\n", url, err)
continue
}
ns = append(ns, n)
}
return ns
}
type Ethereum struct { type Ethereum struct {
// Channel for shutting down the ethereum // Channel for shutting down the ethereum
shutdownChan chan bool shutdownChan chan bool
@ -68,11 +92,7 @@ type Ethereum struct {
WsServer rpc.RpcServer WsServer rpc.RpcServer
keyManager *crypto.KeyManager keyManager *crypto.KeyManager
clientIdentity p2p.ClientIdentity logger ethlogger.LogSystem
logger ethlogger.LogSystem
synclock sync.Mutex
syncGroup sync.WaitGroup
Mining bool Mining bool
} }
@ -105,21 +125,17 @@ func New(config *Config) (*Ethereum, error) {
// Initialise the keyring // Initialise the keyring
keyManager.Init(config.KeyRing, 0, false) keyManager.Init(config.KeyRing, 0, false)
// Create a new client id for this instance. This will help identifying the node on the network
clientId := p2p.NewSimpleClientIdentity(config.Name, config.Version, config.Identifier, keyManager.PublicKey())
saveProtocolVersion(db) saveProtocolVersion(db)
//ethutil.Config.Db = db //ethutil.Config.Db = db
eth := &Ethereum{ eth := &Ethereum{
shutdownChan: make(chan bool), shutdownChan: make(chan bool),
quit: make(chan bool), quit: make(chan bool),
db: db, db: db,
keyManager: keyManager, keyManager: keyManager,
clientIdentity: clientId, blacklist: p2p.NewBlacklist(),
blacklist: p2p.NewBlacklist(), eventMux: &event.TypeMux{},
eventMux: &event.TypeMux{}, logger: logger,
logger: logger,
} }
eth.chainManager = core.NewChainManager(db, eth.EventMux()) eth.chainManager = core.NewChainManager(db, eth.EventMux())
@ -134,21 +150,22 @@ func New(config *Config) (*Ethereum, error) {
ethProto := EthProtocol(eth.txPool, eth.chainManager, eth.blockPool) ethProto := EthProtocol(eth.txPool, eth.chainManager, eth.blockPool)
protocols := []p2p.Protocol{ethProto, eth.whisper.Protocol()} protocols := []p2p.Protocol{ethProto, eth.whisper.Protocol()}
netprv := config.NodeKey
nat, err := p2p.ParseNAT(config.NATType, config.PMPGateway) if netprv == nil {
if err != nil { if netprv, err = crypto.GenerateKey(); err != nil {
return nil, err return nil, fmt.Errorf("could not generate server key: %v", err)
}
} }
eth.net = &p2p.Server{ eth.net = &p2p.Server{
Identity: clientId, PrivateKey: netprv,
MaxPeers: config.MaxPeers, Name: config.Name,
Protocols: protocols, MaxPeers: config.MaxPeers,
Blacklist: eth.blacklist, Protocols: protocols,
NAT: nat, Blacklist: eth.blacklist,
NoDial: !config.Dial, NAT: config.NAT,
NoDial: !config.Dial,
BootstrapNodes: config.parseBootNodes(),
} }
if len(config.Port) > 0 { if len(config.Port) > 0 {
eth.net.ListenAddr = ":" + config.Port eth.net.ListenAddr = ":" + config.Port
} }
@ -164,8 +181,8 @@ func (s *Ethereum) Logger() ethlogger.LogSystem {
return s.logger return s.logger
} }
func (s *Ethereum) ClientIdentity() p2p.ClientIdentity { func (s *Ethereum) Name() string {
return s.clientIdentity return s.net.Name
} }
func (s *Ethereum) ChainManager() *core.ChainManager { func (s *Ethereum) ChainManager() *core.ChainManager {
@ -221,7 +238,7 @@ func (s *Ethereum) Coinbase() []byte {
} }
// Start the ethereum // Start the ethereum
func (s *Ethereum) Start(seedNode string) error { func (s *Ethereum) Start() error {
jsonlogger.LogJson(&ethlogger.LogStarting{ jsonlogger.LogJson(&ethlogger.LogStarting{
ClientString: s.ClientIdentity().String(), ClientString: s.ClientIdentity().String(),
Coinbase: ethutil.Bytes2Hex(s.KeyManager().Address()), Coinbase: ethutil.Bytes2Hex(s.KeyManager().Address()),
@ -250,26 +267,16 @@ func (s *Ethereum) Start(seedNode string) error {
s.blockSub = s.eventMux.Subscribe(core.NewMinedBlockEvent{}) s.blockSub = s.eventMux.Subscribe(core.NewMinedBlockEvent{})
go s.blockBroadcastLoop() go s.blockBroadcastLoop()
// TODO: read peers here
if len(seedNode) > 0 {
logger.Infof("Connect to seed node %v", seedNode)
if err := s.SuggestPeer(seedNode); err != nil {
logger.Infoln(err)
}
}
logger.Infoln("Server started") logger.Infoln("Server started")
return nil return nil
} }
func (self *Ethereum) SuggestPeer(addr string) error { func (self *Ethereum) SuggestPeer(nodeURL string) error {
netaddr, err := net.ResolveTCPAddr("tcp", addr) n, err := discover.ParseNode(nodeURL)
if err != nil { if err != nil {
logger.Errorf("couldn't resolve %s:", addr, err) return fmt.Errorf("invalid node URL: %v", err)
return err
} }
self.net.SuggestPeer(n)
self.net.SuggestPeer(netaddr.IP, netaddr.Port, nil)
return nil return nil
} }

@ -92,13 +92,14 @@ func EthProtocol(txPool txPool, chainManager chainManager, blockPool blockPool)
// the main loop that handles incoming messages // the main loop that handles incoming messages
// note RemovePeer in the post-disconnect hook // note RemovePeer in the post-disconnect hook
func runEthProtocol(txPool txPool, chainManager chainManager, blockPool blockPool, peer *p2p.Peer, rw p2p.MsgReadWriter) (err error) { func runEthProtocol(txPool txPool, chainManager chainManager, blockPool blockPool, peer *p2p.Peer, rw p2p.MsgReadWriter) (err error) {
id := peer.ID()
self := &ethProtocol{ self := &ethProtocol{
txPool: txPool, txPool: txPool,
chainManager: chainManager, chainManager: chainManager,
blockPool: blockPool, blockPool: blockPool,
rw: rw, rw: rw,
peer: peer, peer: peer,
id: fmt.Sprintf("%x", peer.Identity().Pubkey()[:8]), id: fmt.Sprintf("%x", id[:8]),
} }
err = self.handleStatus() err = self.handleStatus()
if err == nil { if err == nil {

@ -14,6 +14,7 @@ import (
"github.com/ethereum/go-ethereum/ethutil" "github.com/ethereum/go-ethereum/ethutil"
ethlogger "github.com/ethereum/go-ethereum/logger" ethlogger "github.com/ethereum/go-ethereum/logger"
"github.com/ethereum/go-ethereum/p2p" "github.com/ethereum/go-ethereum/p2p"
"github.com/ethereum/go-ethereum/p2p/discover"
) )
var sys = ethlogger.NewStdLogSystem(os.Stdout, log.LstdFlags, ethlogger.LogLevel(ethlogger.DebugDetailLevel)) var sys = ethlogger.NewStdLogSystem(os.Stdout, log.LstdFlags, ethlogger.LogLevel(ethlogger.DebugDetailLevel))
@ -128,26 +129,11 @@ func (self *testBlockPool) RemovePeer(peerId string) {
} }
} }
// TODO: refactor this into p2p/client_identity
type peerId struct {
pubkey []byte
}
func (self *peerId) String() string {
return "test peer"
}
func (self *peerId) Pubkey() (pubkey []byte) {
pubkey = self.pubkey
if len(pubkey) == 0 {
pubkey = crypto.GenerateNewKeyPair().PublicKey
self.pubkey = pubkey
}
return
}
func testPeer() *p2p.Peer { func testPeer() *p2p.Peer {
return p2p.NewPeer(&peerId{}, []p2p.Cap{}) var id discover.NodeID
pk := crypto.GenerateNewKeyPair().PublicKey
copy(id[:], pk)
return p2p.NewPeer(id, "test peer", []p2p.Cap{})
} }
type ethProtocolTester struct { type ethProtocolTester struct {

@ -197,12 +197,13 @@ func (self *JSRE) watch(call otto.FunctionCall) otto.Value {
} }
func (self *JSRE) addPeer(call otto.FunctionCall) otto.Value { func (self *JSRE) addPeer(call otto.FunctionCall) otto.Value {
host, err := call.Argument(0).ToString() nodeURL, err := call.Argument(0).ToString()
if err != nil { if err != nil {
return otto.FalseValue() return otto.FalseValue()
} }
self.ethereum.SuggestPeer(host) if err := self.ethereum.SuggestPeer(nodeURL); err != nil {
return otto.FalseValue()
}
return otto.TrueValue() return otto.TrueValue()
} }

@ -1,63 +0,0 @@
package p2p
import (
"fmt"
"runtime"
)
// ClientIdentity represents the identity of a peer.
type ClientIdentity interface {
String() string // human readable identity
Pubkey() []byte // 512-bit public key
}
type SimpleClientIdentity struct {
clientIdentifier string
version string
customIdentifier string
os string
implementation string
pubkey []byte
}
func NewSimpleClientIdentity(clientIdentifier string, version string, customIdentifier string, pubkey []byte) *SimpleClientIdentity {
clientIdentity := &SimpleClientIdentity{
clientIdentifier: clientIdentifier,
version: version,
customIdentifier: customIdentifier,
os: runtime.GOOS,
implementation: runtime.Version(),
pubkey: pubkey,
}
return clientIdentity
}
func (c *SimpleClientIdentity) init() {
}
func (c *SimpleClientIdentity) String() string {
var id string
if len(c.customIdentifier) > 0 {
id = "/" + c.customIdentifier
}
return fmt.Sprintf("%s/v%s%s/%s/%s",
c.clientIdentifier,
c.version,
id,
c.os,
c.implementation)
}
func (c *SimpleClientIdentity) Pubkey() []byte {
return []byte(c.pubkey)
}
func (c *SimpleClientIdentity) SetCustomIdentifier(customIdentifier string) {
c.customIdentifier = customIdentifier
}
func (c *SimpleClientIdentity) GetCustomIdentifier() string {
return c.customIdentifier
}

@ -1,30 +0,0 @@
package p2p
import (
"fmt"
"runtime"
"testing"
)
func TestClientIdentity(t *testing.T) {
clientIdentity := NewSimpleClientIdentity("Ethereum(G)", "0.5.16", "test", []byte("pubkey"))
clientString := clientIdentity.String()
expected := fmt.Sprintf("Ethereum(G)/v0.5.16/test/%s/%s", runtime.GOOS, runtime.Version())
if clientString != expected {
t.Errorf("Expected clientIdentity to be %v, got %v", expected, clientString)
}
customIdentifier := clientIdentity.GetCustomIdentifier()
if customIdentifier != "test" {
t.Errorf("Expected clientIdentity.GetCustomIdentifier() to be 'test', got %v", customIdentifier)
}
clientIdentity.SetCustomIdentifier("test2")
customIdentifier = clientIdentity.GetCustomIdentifier()
if customIdentifier != "test2" {
t.Errorf("Expected clientIdentity.GetCustomIdentifier() to be 'test2', got %v", customIdentifier)
}
clientString = clientIdentity.String()
expected = fmt.Sprintf("Ethereum(G)/v0.5.16/test2/%s/%s", runtime.GOOS, runtime.Version())
if clientString != expected {
t.Errorf("Expected clientIdentity to be %v, got %v", expected, clientString)
}
}

@ -0,0 +1,363 @@
package p2p
import (
// "binary"
"crypto/ecdsa"
"crypto/rand"
"fmt"
"io"
"github.com/ethereum/go-ethereum/crypto"
"github.com/ethereum/go-ethereum/crypto/secp256k1"
ethlogger "github.com/ethereum/go-ethereum/logger"
"github.com/ethereum/go-ethereum/p2p/discover"
"github.com/obscuren/ecies"
)
var clogger = ethlogger.NewLogger("CRYPTOID")
const (
sskLen = 16 // ecies.MaxSharedKeyLength(pubKey) / 2
sigLen = 65 // elliptic S256
pubLen = 64 // 512 bit pubkey in uncompressed representation without format byte
shaLen = 32 // hash length (for nonce etc)
authMsgLen = sigLen + shaLen + pubLen + shaLen + 1
authRespLen = pubLen + shaLen + 1
eciesBytes = 65 + 16 + 32
iHSLen = authMsgLen + eciesBytes // size of the final ECIES payload sent as initiator's handshake
rHSLen = authRespLen + eciesBytes // size of the final ECIES payload sent as receiver's handshake
)
type hexkey []byte
func (self hexkey) String() string {
return fmt.Sprintf("(%d) %x", len(self), []byte(self))
}
func encHandshake(conn io.ReadWriter, prv *ecdsa.PrivateKey, dial *discover.Node) (
remoteID discover.NodeID,
sessionToken []byte,
err error,
) {
if dial == nil {
var remotePubkey []byte
sessionToken, remotePubkey, err = inboundEncHandshake(conn, prv, nil)
copy(remoteID[:], remotePubkey)
} else {
remoteID = dial.ID
sessionToken, err = outboundEncHandshake(conn, prv, remoteID[:], nil)
}
return remoteID, sessionToken, err
}
// outboundEncHandshake negotiates a session token on conn.
// it should be called on the dialing side of the connection.
//
// privateKey is the local client's private key
// remotePublicKey is the remote peer's node ID
// sessionToken is the token from a previous session with this node.
func outboundEncHandshake(conn io.ReadWriter, prvKey *ecdsa.PrivateKey, remotePublicKey []byte, sessionToken []byte) (
newSessionToken []byte,
err error,
) {
auth, initNonce, randomPrivKey, err := authMsg(prvKey, remotePublicKey, sessionToken)
if err != nil {
return nil, err
}
if sessionToken != nil {
clogger.Debugf("session-token: %v", hexkey(sessionToken))
}
clogger.Debugf("initiator-nonce: %v", hexkey(initNonce))
clogger.Debugf("initiator-random-private-key: %v", hexkey(crypto.FromECDSA(randomPrivKey)))
randomPublicKeyS, _ := exportPublicKey(&randomPrivKey.PublicKey)
clogger.Debugf("initiator-random-public-key: %v", hexkey(randomPublicKeyS))
if _, err = conn.Write(auth); err != nil {
return nil, err
}
clogger.Debugf("initiator handshake: %v", hexkey(auth))
response := make([]byte, rHSLen)
if _, err = io.ReadFull(conn, response); err != nil {
return nil, err
}
recNonce, remoteRandomPubKey, _, err := completeHandshake(response, prvKey)
if err != nil {
return nil, err
}
clogger.Debugf("receiver-nonce: %v", hexkey(recNonce))
remoteRandomPubKeyS, _ := exportPublicKey(remoteRandomPubKey)
clogger.Debugf("receiver-random-public-key: %v", hexkey(remoteRandomPubKeyS))
return newSession(initNonce, recNonce, randomPrivKey, remoteRandomPubKey)
}
// authMsg creates the initiator handshake.
func authMsg(prvKey *ecdsa.PrivateKey, remotePubKeyS, sessionToken []byte) (
auth, initNonce []byte,
randomPrvKey *ecdsa.PrivateKey,
err error,
) {
// session init, common to both parties
remotePubKey, err := importPublicKey(remotePubKeyS)
if err != nil {
return
}
var tokenFlag byte // = 0x00
if sessionToken == nil {
// no session token found means we need to generate shared secret.
// ecies shared secret is used as initial session token for new peers
// generate shared key from prv and remote pubkey
if sessionToken, err = ecies.ImportECDSA(prvKey).GenerateShared(ecies.ImportECDSAPublic(remotePubKey), sskLen, sskLen); err != nil {
return
}
// tokenFlag = 0x00 // redundant
} else {
// for known peers, we use stored token from the previous session
tokenFlag = 0x01
}
//E(remote-pubk, S(ecdhe-random, ecdh-shared-secret^nonce) || H(ecdhe-random-pubk) || pubk || nonce || 0x0)
// E(remote-pubk, S(ecdhe-random, token^nonce) || H(ecdhe-random-pubk) || pubk || nonce || 0x1)
// allocate msgLen long message,
var msg []byte = make([]byte, authMsgLen)
initNonce = msg[authMsgLen-shaLen-1 : authMsgLen-1]
if _, err = rand.Read(initNonce); err != nil {
return
}
// create known message
// ecdh-shared-secret^nonce for new peers
// token^nonce for old peers
var sharedSecret = xor(sessionToken, initNonce)
// generate random keypair to use for signing
if randomPrvKey, err = crypto.GenerateKey(); err != nil {
return
}
// sign shared secret (message known to both parties): shared-secret
var signature []byte
// signature = sign(ecdhe-random, shared-secret)
// uses secp256k1.Sign
if signature, err = crypto.Sign(sharedSecret, randomPrvKey); err != nil {
return
}
// message
// signed-shared-secret || H(ecdhe-random-pubk) || pubk || nonce || 0x0
copy(msg, signature) // copy signed-shared-secret
// H(ecdhe-random-pubk)
var randomPubKey64 []byte
if randomPubKey64, err = exportPublicKey(&randomPrvKey.PublicKey); err != nil {
return
}
var pubKey64 []byte
if pubKey64, err = exportPublicKey(&prvKey.PublicKey); err != nil {
return
}
copy(msg[sigLen:sigLen+shaLen], crypto.Sha3(randomPubKey64))
// pubkey copied to the correct segment.
copy(msg[sigLen+shaLen:sigLen+shaLen+pubLen], pubKey64)
// nonce is already in the slice
// stick tokenFlag byte to the end
msg[authMsgLen-1] = tokenFlag
// encrypt using remote-pubk
// auth = eciesEncrypt(remote-pubk, msg)
if auth, err = crypto.Encrypt(remotePubKey, msg); err != nil {
return
}
return
}
// completeHandshake is called when the initiator receives an
// authentication response (aka receiver handshake). It completes the
// handshake by reading off parameters the remote peer provides needed
// to set up the secure session.
func completeHandshake(auth []byte, prvKey *ecdsa.PrivateKey) (
respNonce []byte,
remoteRandomPubKey *ecdsa.PublicKey,
tokenFlag bool,
err error,
) {
var msg []byte
// they prove that msg is meant for me,
// I prove I possess private key if i can read it
if msg, err = crypto.Decrypt(prvKey, auth); err != nil {
return
}
respNonce = msg[pubLen : pubLen+shaLen]
var remoteRandomPubKeyS = msg[:pubLen]
if remoteRandomPubKey, err = importPublicKey(remoteRandomPubKeyS); err != nil {
return
}
if msg[authRespLen-1] == 0x01 {
tokenFlag = true
}
return
}
// inboundEncHandshake negotiates a session token on conn.
// it should be called on the listening side of the connection.
//
// privateKey is the local client's private key
// sessionToken is the token from a previous session with this node.
func inboundEncHandshake(conn io.ReadWriter, prvKey *ecdsa.PrivateKey, sessionToken []byte) (
token, remotePubKey []byte,
err error,
) {
// we are listening connection. we are responders in the
// handshake. Extract info from the authentication. The initiator
// starts by sending us a handshake that we need to respond to. so
// we read auth message first, then respond.
auth := make([]byte, iHSLen)
if _, err := io.ReadFull(conn, auth); err != nil {
return nil, nil, err
}
response, recNonce, initNonce, remotePubKey, randomPrivKey, remoteRandomPubKey, err := authResp(auth, sessionToken, prvKey)
if err != nil {
return nil, nil, err
}
clogger.Debugf("receiver-nonce: %v", hexkey(recNonce))
clogger.Debugf("receiver-random-priv-key: %v", hexkey(crypto.FromECDSA(randomPrivKey)))
if _, err = conn.Write(response); err != nil {
return nil, nil, err
}
clogger.Debugf("receiver handshake:\n%v", hexkey(response))
token, err = newSession(initNonce, recNonce, randomPrivKey, remoteRandomPubKey)
return token, remotePubKey, err
}
// authResp is called by peer if it accepted (but not
// initiated) the connection from the remote. It is passed the initiator
// handshake received and the session token belonging to the
// remote initiator.
//
// The first return value is the authentication response (aka receiver
// handshake) that is to be sent to the remote initiator.
func authResp(auth, sessionToken []byte, prvKey *ecdsa.PrivateKey) (
authResp, respNonce, initNonce, remotePubKeyS []byte,
randomPrivKey *ecdsa.PrivateKey,
remoteRandomPubKey *ecdsa.PublicKey,
err error,
) {
// they prove that msg is meant for me,
// I prove I possess private key if i can read it
msg, err := crypto.Decrypt(prvKey, auth)
if err != nil {
return
}
remotePubKeyS = msg[sigLen+shaLen : sigLen+shaLen+pubLen]
remotePubKey, _ := importPublicKey(remotePubKeyS)
var tokenFlag byte
if sessionToken == nil {
// no session token found means we need to generate shared secret.
// ecies shared secret is used as initial session token for new peers
// generate shared key from prv and remote pubkey
if sessionToken, err = ecies.ImportECDSA(prvKey).GenerateShared(ecies.ImportECDSAPublic(remotePubKey), sskLen, sskLen); err != nil {
return
}
// tokenFlag = 0x00 // redundant
} else {
// for known peers, we use stored token from the previous session
tokenFlag = 0x01
}
// the initiator nonce is read off the end of the message
initNonce = msg[authMsgLen-shaLen-1 : authMsgLen-1]
// I prove that i own prv key (to derive shared secret, and read
// nonce off encrypted msg) and that I own shared secret they
// prove they own the private key belonging to ecdhe-random-pubk
// we can now reconstruct the signed message and recover the peers
// pubkey
var signedMsg = xor(sessionToken, initNonce)
var remoteRandomPubKeyS []byte
if remoteRandomPubKeyS, err = secp256k1.RecoverPubkey(signedMsg, msg[:sigLen]); err != nil {
return
}
// convert to ECDSA standard
if remoteRandomPubKey, err = importPublicKey(remoteRandomPubKeyS); err != nil {
return
}
// now we find ourselves a long task too, fill it random
var resp = make([]byte, authRespLen)
// generate shaLen long nonce
respNonce = resp[pubLen : pubLen+shaLen]
if _, err = rand.Read(respNonce); err != nil {
return
}
// generate random keypair for session
if randomPrivKey, err = crypto.GenerateKey(); err != nil {
return
}
// responder auth message
// E(remote-pubk, ecdhe-random-pubk || nonce || 0x0)
var randomPubKeyS []byte
if randomPubKeyS, err = exportPublicKey(&randomPrivKey.PublicKey); err != nil {
return
}
copy(resp[:pubLen], randomPubKeyS)
// nonce is already in the slice
resp[authRespLen-1] = tokenFlag
// encrypt using remote-pubk
// auth = eciesEncrypt(remote-pubk, msg)
// why not encrypt with ecdhe-random-remote
if authResp, err = crypto.Encrypt(remotePubKey, resp); err != nil {
return
}
return
}
// newSession is called after the handshake is completed. The
// arguments are values negotiated in the handshake. The return value
// is a new session Token to be remembered for the next time we
// connect with this peer.
func newSession(initNonce, respNonce []byte, privKey *ecdsa.PrivateKey, remoteRandomPubKey *ecdsa.PublicKey) ([]byte, error) {
// 3) Now we can trust ecdhe-random-pubk to derive new keys
//ecdhe-shared-secret = ecdh.agree(ecdhe-random, remote-ecdhe-random-pubk)
pubKey := ecies.ImportECDSAPublic(remoteRandomPubKey)
dhSharedSecret, err := ecies.ImportECDSA(privKey).GenerateShared(pubKey, sskLen, sskLen)
if err != nil {
return nil, err
}
sharedSecret := crypto.Sha3(dhSharedSecret, crypto.Sha3(respNonce, initNonce))
sessionToken := crypto.Sha3(sharedSecret)
return sessionToken, nil
}
// importPublicKey unmarshals 512 bit public keys.
func importPublicKey(pubKey []byte) (pubKeyEC *ecdsa.PublicKey, err error) {
var pubKey65 []byte
switch len(pubKey) {
case 64:
// add 'uncompressed key' flag
pubKey65 = append([]byte{0x04}, pubKey...)
case 65:
pubKey65 = pubKey
default:
return nil, fmt.Errorf("invalid public key length %v (expect 64/65)", len(pubKey))
}
return crypto.ToECDSAPub(pubKey65), nil
}
func exportPublicKey(pubKeyEC *ecdsa.PublicKey) (pubKey []byte, err error) {
if pubKeyEC == nil {
return nil, fmt.Errorf("no ECDSA public key given")
}
return crypto.FromECDSAPub(pubKeyEC)[1:], nil
}
func xor(one, other []byte) (xor []byte) {
xor = make([]byte, len(one))
for i := 0; i < len(one); i++ {
xor[i] = one[i] ^ other[i]
}
return xor
}

@ -0,0 +1,167 @@
package p2p
import (
"bytes"
"crypto/ecdsa"
"crypto/rand"
"net"
"testing"
"github.com/ethereum/go-ethereum/crypto"
"github.com/obscuren/ecies"
)
func TestPublicKeyEncoding(t *testing.T) {
prv0, _ := crypto.GenerateKey() // = ecdsa.GenerateKey(crypto.S256(), rand.Reader)
pub0 := &prv0.PublicKey
pub0s := crypto.FromECDSAPub(pub0)
pub1, err := importPublicKey(pub0s)
if err != nil {
t.Errorf("%v", err)
}
eciesPub1 := ecies.ImportECDSAPublic(pub1)
if eciesPub1 == nil {
t.Errorf("invalid ecdsa public key")
}
pub1s, err := exportPublicKey(pub1)
if err != nil {
t.Errorf("%v", err)
}
if len(pub1s) != 64 {
t.Errorf("wrong length expect 64, got", len(pub1s))
}
pub2, err := importPublicKey(pub1s)
if err != nil {
t.Errorf("%v", err)
}
pub2s, err := exportPublicKey(pub2)
if err != nil {
t.Errorf("%v", err)
}
if !bytes.Equal(pub1s, pub2s) {
t.Errorf("exports dont match")
}
pub2sEC := crypto.FromECDSAPub(pub2)
if !bytes.Equal(pub0s, pub2sEC) {
t.Errorf("exports dont match")
}
}
func TestSharedSecret(t *testing.T) {
prv0, _ := crypto.GenerateKey() // = ecdsa.GenerateKey(crypto.S256(), rand.Reader)
pub0 := &prv0.PublicKey
prv1, _ := crypto.GenerateKey()
pub1 := &prv1.PublicKey
ss0, err := ecies.ImportECDSA(prv0).GenerateShared(ecies.ImportECDSAPublic(pub1), sskLen, sskLen)
if err != nil {
return
}
ss1, err := ecies.ImportECDSA(prv1).GenerateShared(ecies.ImportECDSAPublic(pub0), sskLen, sskLen)
if err != nil {
return
}
t.Logf("Secret:\n%v %x\n%v %x", len(ss0), ss0, len(ss0), ss1)
if !bytes.Equal(ss0, ss1) {
t.Errorf("dont match :(")
}
}
func TestCryptoHandshake(t *testing.T) {
testCryptoHandshake(newkey(), newkey(), nil, t)
}
func TestCryptoHandshakeWithToken(t *testing.T) {
sessionToken := make([]byte, shaLen)
rand.Read(sessionToken)
testCryptoHandshake(newkey(), newkey(), sessionToken, t)
}
func testCryptoHandshake(prv0, prv1 *ecdsa.PrivateKey, sessionToken []byte, t *testing.T) {
var err error
// pub0 := &prv0.PublicKey
pub1 := &prv1.PublicKey
// pub0s := crypto.FromECDSAPub(pub0)
pub1s := crypto.FromECDSAPub(pub1)
// simulate handshake by feeding output to input
// initiator sends handshake 'auth'
auth, initNonce, randomPrivKey, err := authMsg(prv0, pub1s, sessionToken)
if err != nil {
t.Errorf("%v", err)
}
t.Logf("-> %v", hexkey(auth))
// receiver reads auth and responds with response
response, remoteRecNonce, remoteInitNonce, _, remoteRandomPrivKey, remoteInitRandomPubKey, err := authResp(auth, sessionToken, prv1)
if err != nil {
t.Errorf("%v", err)
}
t.Logf("<- %v\n", hexkey(response))
// initiator reads receiver's response and the key exchange completes
recNonce, remoteRandomPubKey, _, err := completeHandshake(response, prv0)
if err != nil {
t.Errorf("completeHandshake error: %v", err)
}
// now both parties should have the same session parameters
initSessionToken, err := newSession(initNonce, recNonce, randomPrivKey, remoteRandomPubKey)
if err != nil {
t.Errorf("newSession error: %v", err)
}
recSessionToken, err := newSession(remoteInitNonce, remoteRecNonce, remoteRandomPrivKey, remoteInitRandomPubKey)
if err != nil {
t.Errorf("newSession error: %v", err)
}
// fmt.Printf("\nauth (%v) %x\n\nresp (%v) %x\n\n", len(auth), auth, len(response), response)
// fmt.Printf("\nauth %x\ninitNonce %x\nresponse%x\nremoteRecNonce %x\nremoteInitNonce %x\nremoteRandomPubKey %x\nrecNonce %x\nremoteInitRandomPubKey %x\ninitSessionToken %x\n\n", auth, initNonce, response, remoteRecNonce, remoteInitNonce, remoteRandomPubKey, recNonce, remoteInitRandomPubKey, initSessionToken)
if !bytes.Equal(initNonce, remoteInitNonce) {
t.Errorf("nonces do not match")
}
if !bytes.Equal(recNonce, remoteRecNonce) {
t.Errorf("receiver nonces do not match")
}
if !bytes.Equal(initSessionToken, recSessionToken) {
t.Errorf("session tokens do not match")
}
}
func TestHandshake(t *testing.T) {
defer testlog(t).detach()
prv0, _ := crypto.GenerateKey()
prv1, _ := crypto.GenerateKey()
pub0s, _ := exportPublicKey(&prv0.PublicKey)
pub1s, _ := exportPublicKey(&prv1.PublicKey)
rw0, rw1 := net.Pipe()
tokens := make(chan []byte)
go func() {
token, err := outboundEncHandshake(rw0, prv0, pub1s, nil)
if err != nil {
t.Errorf("outbound side error: %v", err)
}
tokens <- token
}()
go func() {
token, remotePubkey, err := inboundEncHandshake(rw1, prv1, nil)
if err != nil {
t.Errorf("inbound side error: %v", err)
}
if !bytes.Equal(remotePubkey, pub0s) {
t.Errorf("inbound side returned wrong remote pubkey\n got: %x\n want: %x", remotePubkey, pub0s)
}
tokens <- token
}()
t1, t2 := <-tokens, <-tokens
if !bytes.Equal(t1, t2) {
t.Error("session token mismatch")
}
}

@ -0,0 +1,291 @@
package discover
import (
"crypto/ecdsa"
"crypto/elliptic"
"encoding/hex"
"errors"
"fmt"
"io"
"math/rand"
"net"
"net/url"
"strconv"
"strings"
"time"
"github.com/ethereum/go-ethereum/crypto/secp256k1"
"github.com/ethereum/go-ethereum/rlp"
)
const nodeIDBits = 512
// Node represents a host on the network.
type Node struct {
ID NodeID
IP net.IP
DiscPort int // UDP listening port for discovery protocol
TCPPort int // TCP listening port for RLPx
active time.Time
}
func newNode(id NodeID, addr *net.UDPAddr) *Node {
return &Node{
ID: id,
IP: addr.IP,
DiscPort: addr.Port,
TCPPort: addr.Port,
active: time.Now(),
}
}
func (n *Node) isValid() bool {
// TODO: don't accept localhost, LAN addresses from internet hosts
return !n.IP.IsMulticast() && !n.IP.IsUnspecified() && n.TCPPort != 0 && n.DiscPort != 0
}
// The string representation of a Node is a URL.
// Please see ParseNode for a description of the format.
func (n *Node) String() string {
addr := net.TCPAddr{IP: n.IP, Port: n.TCPPort}
u := url.URL{
Scheme: "enode",
User: url.User(fmt.Sprintf("%x", n.ID[:])),
Host: addr.String(),
}
if n.DiscPort != n.TCPPort {
u.RawQuery = "discport=" + strconv.Itoa(n.DiscPort)
}
return u.String()
}
// ParseNode parses a node URL.
//
// A node URL has scheme "enode".
//
// The hexadecimal node ID is encoded in the username portion of the
// URL, separated from the host by an @ sign. The hostname can only be
// given as an IP address, DNS domain names are not allowed. The port
// in the host name section is the TCP listening port. If the TCP and
// UDP (discovery) ports differ, the UDP port is specified as query
// parameter "discport".
//
// In the following example, the node URL describes
// a node with IP address 10.3.58.6, TCP listening port 30303
// and UDP discovery port 30301.
//
// enode://<hex node id>@10.3.58.6:30303?discport=30301
func ParseNode(rawurl string) (*Node, error) {
var n Node
u, err := url.Parse(rawurl)
if u.Scheme != "enode" {
return nil, errors.New("invalid URL scheme, want \"enode\"")
}
if u.User == nil {
return nil, errors.New("does not contain node ID")
}
if n.ID, err = HexID(u.User.String()); err != nil {
return nil, fmt.Errorf("invalid node ID (%v)", err)
}
ip, port, err := net.SplitHostPort(u.Host)
if err != nil {
return nil, fmt.Errorf("invalid host: %v", err)
}
if n.IP = net.ParseIP(ip); n.IP == nil {
return nil, errors.New("invalid IP address")
}
if n.TCPPort, err = strconv.Atoi(port); err != nil {
return nil, errors.New("invalid port")
}
qv := u.Query()
if qv.Get("discport") == "" {
n.DiscPort = n.TCPPort
} else {
if n.DiscPort, err = strconv.Atoi(qv.Get("discport")); err != nil {
return nil, errors.New("invalid discport in query")
}
}
return &n, nil
}
// MustParseNode parses a node URL. It panics if the URL is not valid.
func MustParseNode(rawurl string) *Node {
n, err := ParseNode(rawurl)
if err != nil {
panic("invalid node URL: " + err.Error())
}
return n
}
func (n Node) EncodeRLP(w io.Writer) error {
return rlp.Encode(w, rpcNode{IP: n.IP.String(), Port: uint16(n.TCPPort), ID: n.ID})
}
func (n *Node) DecodeRLP(s *rlp.Stream) (err error) {
var ext rpcNode
if err = s.Decode(&ext); err == nil {
n.TCPPort = int(ext.Port)
n.DiscPort = int(ext.Port)
n.ID = ext.ID
if n.IP = net.ParseIP(ext.IP); n.IP == nil {
return errors.New("invalid IP string")
}
}
return err
}
// NodeID is a unique identifier for each node.
// The node identifier is a marshaled elliptic curve public key.
type NodeID [nodeIDBits / 8]byte
// NodeID prints as a long hexadecimal number.
func (n NodeID) String() string {
return fmt.Sprintf("%#x", n[:])
}
// The Go syntax representation of a NodeID is a call to HexID.
func (n NodeID) GoString() string {
return fmt.Sprintf("discover.HexID(\"%#x\")", n[:])
}
// HexID converts a hex string to a NodeID.
// The string may be prefixed with 0x.
func HexID(in string) (NodeID, error) {
if strings.HasPrefix(in, "0x") {
in = in[2:]
}
var id NodeID
b, err := hex.DecodeString(in)
if err != nil {
return id, err
} else if len(b) != len(id) {
return id, fmt.Errorf("wrong length, need %d hex bytes", len(id))
}
copy(id[:], b)
return id, nil
}
// MustHexID converts a hex string to a NodeID.
// It panics if the string is not a valid NodeID.
func MustHexID(in string) NodeID {
id, err := HexID(in)
if err != nil {
panic(err)
}
return id
}
// PubkeyID returns a marshaled representation of the given public key.
func PubkeyID(pub *ecdsa.PublicKey) NodeID {
var id NodeID
pbytes := elliptic.Marshal(pub.Curve, pub.X, pub.Y)
if len(pbytes)-1 != len(id) {
panic(fmt.Errorf("need %d bit pubkey, got %d bits", (len(id)+1)*8, len(pbytes)))
}
copy(id[:], pbytes[1:])
return id
}
// recoverNodeID computes the public key used to sign the
// given hash from the signature.
func recoverNodeID(hash, sig []byte) (id NodeID, err error) {
pubkey, err := secp256k1.RecoverPubkey(hash, sig)
if err != nil {
return id, err
}
if len(pubkey)-1 != len(id) {
return id, fmt.Errorf("recovered pubkey has %d bits, want %d bits", len(pubkey)*8, (len(id)+1)*8)
}
for i := range id {
id[i] = pubkey[i+1]
}
return id, nil
}
// distcmp compares the distances a->target and b->target.
// Returns -1 if a is closer to target, 1 if b is closer to target
// and 0 if they are equal.
func distcmp(target, a, b NodeID) int {
for i := range target {
da := a[i] ^ target[i]
db := b[i] ^ target[i]
if da > db {
return 1
} else if da < db {
return -1
}
}
return 0
}
// table of leading zero counts for bytes [0..255]
var lzcount = [256]int{
8, 7, 6, 6, 5, 5, 5, 5,
4, 4, 4, 4, 4, 4, 4, 4,
3, 3, 3, 3, 3, 3, 3, 3,
3, 3, 3, 3, 3, 3, 3, 3,
2, 2, 2, 2, 2, 2, 2, 2,
2, 2, 2, 2, 2, 2, 2, 2,
2, 2, 2, 2, 2, 2, 2, 2,
2, 2, 2, 2, 2, 2, 2, 2,
1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1,
0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0,
}
// logdist returns the logarithmic distance between a and b, log2(a ^ b).
func logdist(a, b NodeID) int {
lz := 0
for i := range a {
x := a[i] ^ b[i]
if x == 0 {
lz += 8
} else {
lz += lzcount[x]
break
}
}
return len(a)*8 - lz
}
// randomID returns a random NodeID such that logdist(a, b) == n
func randomID(a NodeID, n int) (b NodeID) {
if n == 0 {
return a
}
// flip bit at position n, fill the rest with random bits
b = a
pos := len(a) - n/8 - 1
bit := byte(0x01) << (byte(n%8) - 1)
if bit == 0 {
pos++
bit = 0x80
}
b[pos] = a[pos]&^bit | ^a[pos]&bit // TODO: randomize end bits
for i := pos + 1; i < len(a); i++ {
b[i] = byte(rand.Intn(255))
}
return b
}

@ -0,0 +1,201 @@
package discover
import (
"math/big"
"math/rand"
"net"
"reflect"
"testing"
"testing/quick"
"time"
"github.com/ethereum/go-ethereum/crypto"
)
var (
quickrand = rand.New(rand.NewSource(time.Now().Unix()))
quickcfg = &quick.Config{MaxCount: 5000, Rand: quickrand}
)
var parseNodeTests = []struct {
rawurl string
wantError string
wantResult *Node
}{
{
rawurl: "http://foobar",
wantError: `invalid URL scheme, want "enode"`,
},
{
rawurl: "enode://foobar",
wantError: `does not contain node ID`,
},
{
rawurl: "enode://01010101@123.124.125.126:3",
wantError: `invalid node ID (wrong length, need 64 hex bytes)`,
},
{
rawurl: "enode://1dd9d65c4552b5eb43d5ad55a2ee3f56c6cbc1c64a5c8d659f51fcd51bace24351232b8d7821617d2b29b54b81cdefb9b3e9c37d7fd5f63270bcc9e1a6f6a439@hostname:3",
wantError: `invalid IP address`,
},
{
rawurl: "enode://1dd9d65c4552b5eb43d5ad55a2ee3f56c6cbc1c64a5c8d659f51fcd51bace24351232b8d7821617d2b29b54b81cdefb9b3e9c37d7fd5f63270bcc9e1a6f6a439@127.0.0.1:foo",
wantError: `invalid port`,
},
{
rawurl: "enode://1dd9d65c4552b5eb43d5ad55a2ee3f56c6cbc1c64a5c8d659f51fcd51bace24351232b8d7821617d2b29b54b81cdefb9b3e9c37d7fd5f63270bcc9e1a6f6a439@127.0.0.1:3?discport=foo",
wantError: `invalid discport in query`,
},
{
rawurl: "enode://1dd9d65c4552b5eb43d5ad55a2ee3f56c6cbc1c64a5c8d659f51fcd51bace24351232b8d7821617d2b29b54b81cdefb9b3e9c37d7fd5f63270bcc9e1a6f6a439@127.0.0.1:52150",
wantResult: &Node{
ID: MustHexID("0x1dd9d65c4552b5eb43d5ad55a2ee3f56c6cbc1c64a5c8d659f51fcd51bace24351232b8d7821617d2b29b54b81cdefb9b3e9c37d7fd5f63270bcc9e1a6f6a439"),
IP: net.ParseIP("127.0.0.1"),
DiscPort: 52150,
TCPPort: 52150,
},
},
{
rawurl: "enode://1dd9d65c4552b5eb43d5ad55a2ee3f56c6cbc1c64a5c8d659f51fcd51bace24351232b8d7821617d2b29b54b81cdefb9b3e9c37d7fd5f63270bcc9e1a6f6a439@[::]:52150",
wantResult: &Node{
ID: MustHexID("0x1dd9d65c4552b5eb43d5ad55a2ee3f56c6cbc1c64a5c8d659f51fcd51bace24351232b8d7821617d2b29b54b81cdefb9b3e9c37d7fd5f63270bcc9e1a6f6a439"),
IP: net.ParseIP("::"),
DiscPort: 52150,
TCPPort: 52150,
},
},
{
rawurl: "enode://1dd9d65c4552b5eb43d5ad55a2ee3f56c6cbc1c64a5c8d659f51fcd51bace24351232b8d7821617d2b29b54b81cdefb9b3e9c37d7fd5f63270bcc9e1a6f6a439@127.0.0.1:52150?discport=223344",
wantResult: &Node{
ID: MustHexID("0x1dd9d65c4552b5eb43d5ad55a2ee3f56c6cbc1c64a5c8d659f51fcd51bace24351232b8d7821617d2b29b54b81cdefb9b3e9c37d7fd5f63270bcc9e1a6f6a439"),
IP: net.ParseIP("127.0.0.1"),
DiscPort: 223344,
TCPPort: 52150,
},
},
}
func TestParseNode(t *testing.T) {
for i, test := range parseNodeTests {
n, err := ParseNode(test.rawurl)
if err == nil && test.wantError != "" {
t.Errorf("test %d: got nil error, expected %#q", i, test.wantError)
continue
}
if err != nil && err.Error() != test.wantError {
t.Errorf("test %d: got error %#q, expected %#q", i, err.Error(), test.wantError)
continue
}
if !reflect.DeepEqual(n, test.wantResult) {
t.Errorf("test %d: result mismatch:\ngot: %#v, want: %#v", i, n, test.wantResult)
}
}
}
func TestNodeString(t *testing.T) {
for i, test := range parseNodeTests {
if test.wantError != "" {
continue
}
str := test.wantResult.String()
if str != test.rawurl {
t.Errorf("test %d: Node.String() mismatch:\ngot: %s\nwant: %s", i, str, test.rawurl)
}
}
}
func TestHexID(t *testing.T) {
ref := NodeID{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 128, 106, 217, 182, 31, 165, 174, 1, 67, 7, 235, 220, 150, 66, 83, 173, 205, 159, 44, 10, 57, 42, 161, 26, 188}
id1 := MustHexID("0x000000000000000000000000000000000000000000000000000000000000000000000000000000806ad9b61fa5ae014307ebdc964253adcd9f2c0a392aa11abc")
id2 := MustHexID("000000000000000000000000000000000000000000000000000000000000000000000000000000806ad9b61fa5ae014307ebdc964253adcd9f2c0a392aa11abc")
if id1 != ref {
t.Errorf("wrong id1\ngot %v\nwant %v", id1[:], ref[:])
}
if id2 != ref {
t.Errorf("wrong id2\ngot %v\nwant %v", id2[:], ref[:])
}
}
func TestNodeID_recover(t *testing.T) {
prv := newkey()
hash := make([]byte, 32)
sig, err := crypto.Sign(hash, prv)
if err != nil {
t.Fatalf("signing error: %v", err)
}
pub := PubkeyID(&prv.PublicKey)
recpub, err := recoverNodeID(hash, sig)
if err != nil {
t.Fatalf("recovery error: %v", err)
}
if pub != recpub {
t.Errorf("recovered wrong pubkey:\ngot: %v\nwant: %v", recpub, pub)
}
}
func TestNodeID_distcmp(t *testing.T) {
distcmpBig := func(target, a, b NodeID) int {
tbig := new(big.Int).SetBytes(target[:])
abig := new(big.Int).SetBytes(a[:])
bbig := new(big.Int).SetBytes(b[:])
return new(big.Int).Xor(tbig, abig).Cmp(new(big.Int).Xor(tbig, bbig))
}
if err := quick.CheckEqual(distcmp, distcmpBig, quickcfg); err != nil {
t.Error(err)
}
}
// the random tests is likely to miss the case where they're equal.
func TestNodeID_distcmpEqual(t *testing.T) {
base := NodeID{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}
x := NodeID{15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0}
if distcmp(base, x, x) != 0 {
t.Errorf("distcmp(base, x, x) != 0")
}
}
func TestNodeID_logdist(t *testing.T) {
logdistBig := func(a, b NodeID) int {
abig, bbig := new(big.Int).SetBytes(a[:]), new(big.Int).SetBytes(b[:])
return new(big.Int).Xor(abig, bbig).BitLen()
}
if err := quick.CheckEqual(logdist, logdistBig, quickcfg); err != nil {
t.Error(err)
}
}
// the random tests is likely to miss the case where they're equal.
func TestNodeID_logdistEqual(t *testing.T) {
x := NodeID{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}
if logdist(x, x) != 0 {
t.Errorf("logdist(x, x) != 0")
}
}
func TestNodeID_randomID(t *testing.T) {
// we don't use quick.Check here because its output isn't
// very helpful when the test fails.
for i := 0; i < quickcfg.MaxCount; i++ {
a := gen(NodeID{}, quickrand).(NodeID)
dist := quickrand.Intn(len(NodeID{}) * 8)
result := randomID(a, dist)
actualdist := logdist(result, a)
if dist != actualdist {
t.Log("a: ", a)
t.Log("result:", result)
t.Fatalf("#%d: distance of result is %d, want %d", i, actualdist, dist)
}
}
}
func (NodeID) Generate(rand *rand.Rand, size int) reflect.Value {
var id NodeID
m := rand.Intn(len(id))
for i := len(id) - 1; i > m; i-- {
id[i] = byte(rand.Uint32())
}
return reflect.ValueOf(id)
}

@ -0,0 +1,280 @@
// Package discover implements the Node Discovery Protocol.
//
// The Node Discovery protocol provides a way to find RLPx nodes that
// can be connected to. It uses a Kademlia-like protocol to maintain a
// distributed database of the IDs and endpoints of all listening
// nodes.
package discover
import (
"net"
"sort"
"sync"
"time"
)
const (
alpha = 3 // Kademlia concurrency factor
bucketSize = 16 // Kademlia bucket size
nBuckets = nodeIDBits + 1 // Number of buckets
)
type Table struct {
mutex sync.Mutex // protects buckets, their content, and nursery
buckets [nBuckets]*bucket // index of known nodes by distance
nursery []*Node // bootstrap nodes
net transport
self *Node // metadata of the local node
}
// transport is implemented by the UDP transport.
// it is an interface so we can test without opening lots of UDP
// sockets and without generating a private key.
type transport interface {
ping(*Node) error
findnode(e *Node, target NodeID) ([]*Node, error)
close()
}
// bucket contains nodes, ordered by their last activity.
type bucket struct {
lastLookup time.Time
entries []*Node
}
func newTable(t transport, ourID NodeID, ourAddr *net.UDPAddr) *Table {
tab := &Table{net: t, self: newNode(ourID, ourAddr)}
for i := range tab.buckets {
tab.buckets[i] = new(bucket)
}
return tab
}
// Self returns the local node ID.
func (tab *Table) Self() NodeID {
return tab.self.ID
}
// Close terminates the network listener.
func (tab *Table) Close() {
tab.net.close()
}
// Bootstrap sets the bootstrap nodes. These nodes are used to connect
// to the network if the table is empty. Bootstrap will also attempt to
// fill the table by performing random lookup operations on the
// network.
func (tab *Table) Bootstrap(nodes []*Node) {
tab.mutex.Lock()
// TODO: maybe filter nodes with bad fields (nil, etc.) to avoid strange crashes
tab.nursery = make([]*Node, 0, len(nodes))
for _, n := range nodes {
cpy := *n
tab.nursery = append(tab.nursery, &cpy)
}
tab.mutex.Unlock()
tab.refresh()
}
// Lookup performs a network search for nodes close
// to the given target. It approaches the target by querying
// nodes that are closer to it on each iteration.
func (tab *Table) Lookup(target NodeID) []*Node {
var (
asked = make(map[NodeID]bool)
seen = make(map[NodeID]bool)
reply = make(chan []*Node, alpha)
pendingQueries = 0
)
// don't query further if we hit the target or ourself.
// unlikely to happen often in practice.
asked[target] = true
asked[tab.self.ID] = true
tab.mutex.Lock()
// update last lookup stamp (for refresh logic)
tab.buckets[logdist(tab.self.ID, target)].lastLookup = time.Now()
// generate initial result set
result := tab.closest(target, bucketSize)
tab.mutex.Unlock()
for {
// ask the alpha closest nodes that we haven't asked yet
for i := 0; i < len(result.entries) && pendingQueries < alpha; i++ {
n := result.entries[i]
if !asked[n.ID] {
asked[n.ID] = true
pendingQueries++
go func() {
result, _ := tab.net.findnode(n, target)
reply <- result
}()
}
}
if pendingQueries == 0 {
// we have asked all closest nodes, stop the search
break
}
// wait for the next reply
for _, n := range <-reply {
cn := n
if !seen[n.ID] {
seen[n.ID] = true
result.push(cn, bucketSize)
}
}
pendingQueries--
}
return result.entries
}
// refresh performs a lookup for a random target to keep buckets full.
func (tab *Table) refresh() {
ld := -1 // logdist of chosen bucket
tab.mutex.Lock()
for i, b := range tab.buckets {
if i > 0 && b.lastLookup.Before(time.Now().Add(-1*time.Hour)) {
ld = i
break
}
}
tab.mutex.Unlock()
result := tab.Lookup(randomID(tab.self.ID, ld))
if len(result) == 0 {
// bootstrap the table with a self lookup
tab.mutex.Lock()
tab.add(tab.nursery)
tab.mutex.Unlock()
tab.Lookup(tab.self.ID)
// TODO: the Kademlia paper says that we're supposed to perform
// random lookups in all buckets further away than our closest neighbor.
}
}
// closest returns the n nodes in the table that are closest to the
// given id. The caller must hold tab.mutex.
func (tab *Table) closest(target NodeID, nresults int) *nodesByDistance {
// This is a very wasteful way to find the closest nodes but
// obviously correct. I believe that tree-based buckets would make
// this easier to implement efficiently.
close := &nodesByDistance{target: target}
for _, b := range tab.buckets {
for _, n := range b.entries {
close.push(n, nresults)
}
}
return close
}
func (tab *Table) len() (n int) {
for _, b := range tab.buckets {
n += len(b.entries)
}
return n
}
// bumpOrAdd updates the activity timestamp for the given node and
// attempts to insert the node into a bucket. The returned Node might
// not be part of the table. The caller must hold tab.mutex.
func (tab *Table) bumpOrAdd(node NodeID, from *net.UDPAddr) (n *Node) {
b := tab.buckets[logdist(tab.self.ID, node)]
if n = b.bump(node); n == nil {
n = newNode(node, from)
if len(b.entries) == bucketSize {
tab.pingReplace(n, b)
} else {
b.entries = append(b.entries, n)
}
}
return n
}
func (tab *Table) pingReplace(n *Node, b *bucket) {
old := b.entries[bucketSize-1]
go func() {
if err := tab.net.ping(old); err == nil {
// it responded, we don't need to replace it.
return
}
// it didn't respond, replace the node if it is still the oldest node.
tab.mutex.Lock()
if len(b.entries) > 0 && b.entries[len(b.entries)-1] == old {
// slide down other entries and put the new one in front.
// TODO: insert in correct position to keep the order
copy(b.entries[1:], b.entries)
b.entries[0] = n
}
tab.mutex.Unlock()
}()
}
// bump updates the activity timestamp for the given node.
// The caller must hold tab.mutex.
func (tab *Table) bump(node NodeID) {
tab.buckets[logdist(tab.self.ID, node)].bump(node)
}
// add puts the entries into the table if their corresponding
// bucket is not full. The caller must hold tab.mutex.
func (tab *Table) add(entries []*Node) {
outer:
for _, n := range entries {
if n == nil || n.ID == tab.self.ID {
// skip bad entries. The RLP decoder returns nil for empty
// input lists.
continue
}
bucket := tab.buckets[logdist(tab.self.ID, n.ID)]
for i := range bucket.entries {
if bucket.entries[i].ID == n.ID {
// already in bucket
continue outer
}
}
if len(bucket.entries) < bucketSize {
bucket.entries = append(bucket.entries, n)
}
}
}
func (b *bucket) bump(id NodeID) *Node {
for i, n := range b.entries {
if n.ID == id {
n.active = time.Now()
// move it to the front
copy(b.entries[1:], b.entries[:i+1])
b.entries[0] = n
return n
}
}
return nil
}
// nodesByDistance is a list of nodes, ordered by
// distance to target.
type nodesByDistance struct {
entries []*Node
target NodeID
}
// push adds the given node to the list, keeping the total size below maxElems.
func (h *nodesByDistance) push(n *Node, maxElems int) {
ix := sort.Search(len(h.entries), func(i int) bool {
return distcmp(h.target, h.entries[i].ID, n.ID) > 0
})
if len(h.entries) < maxElems {
h.entries = append(h.entries, n)
}
if ix == len(h.entries) {
// farther away than all nodes we already have.
// if there was room for it, the node is now the last element.
} else {
// slide existing entries down to make room
// this will overwrite the entry we just appended.
copy(h.entries[ix+1:], h.entries[ix:])
h.entries[ix] = n
}
}

@ -0,0 +1,311 @@
package discover
import (
"crypto/ecdsa"
"errors"
"fmt"
"math/rand"
"net"
"reflect"
"testing"
"testing/quick"
"time"
"github.com/ethereum/go-ethereum/crypto"
)
func TestTable_bumpOrAddBucketAssign(t *testing.T) {
tab := newTable(nil, NodeID{}, &net.UDPAddr{})
for i := 1; i < len(tab.buckets); i++ {
tab.bumpOrAdd(randomID(tab.self.ID, i), &net.UDPAddr{})
}
for i, b := range tab.buckets {
if i > 0 && len(b.entries) != 1 {
t.Errorf("bucket %d has %d entries, want 1", i, len(b.entries))
}
}
}
func TestTable_bumpOrAddPingReplace(t *testing.T) {
pingC := make(pingC)
tab := newTable(pingC, NodeID{}, &net.UDPAddr{})
last := fillBucket(tab, 200)
// this bumpOrAdd should not replace the last node
// because the node replies to ping.
new := tab.bumpOrAdd(randomID(tab.self.ID, 200), &net.UDPAddr{})
pinged := <-pingC
if pinged != last.ID {
t.Fatalf("pinged wrong node: %v\nwant %v", pinged, last.ID)
}
tab.mutex.Lock()
defer tab.mutex.Unlock()
if l := len(tab.buckets[200].entries); l != bucketSize {
t.Errorf("wrong bucket size after bumpOrAdd: got %d, want %d", bucketSize, l)
}
if !contains(tab.buckets[200].entries, last.ID) {
t.Error("last entry was removed")
}
if contains(tab.buckets[200].entries, new.ID) {
t.Error("new entry was added")
}
}
func TestTable_bumpOrAddPingTimeout(t *testing.T) {
tab := newTable(pingC(nil), NodeID{}, &net.UDPAddr{})
last := fillBucket(tab, 200)
// this bumpOrAdd should replace the last node
// because the node does not reply to ping.
new := tab.bumpOrAdd(randomID(tab.self.ID, 200), &net.UDPAddr{})
// wait for async bucket update. damn. this needs to go away.
time.Sleep(2 * time.Millisecond)
tab.mutex.Lock()
defer tab.mutex.Unlock()
if l := len(tab.buckets[200].entries); l != bucketSize {
t.Errorf("wrong bucket size after bumpOrAdd: got %d, want %d", bucketSize, l)
}
if contains(tab.buckets[200].entries, last.ID) {
t.Error("last entry was not removed")
}
if !contains(tab.buckets[200].entries, new.ID) {
t.Error("new entry was not added")
}
}
func fillBucket(tab *Table, ld int) (last *Node) {
b := tab.buckets[ld]
for len(b.entries) < bucketSize {
b.entries = append(b.entries, &Node{ID: randomID(tab.self.ID, ld)})
}
return b.entries[bucketSize-1]
}
type pingC chan NodeID
func (t pingC) findnode(n *Node, target NodeID) ([]*Node, error) {
panic("findnode called on pingRecorder")
}
func (t pingC) close() {
panic("close called on pingRecorder")
}
func (t pingC) ping(n *Node) error {
if t == nil {
return errTimeout
}
t <- n.ID
return nil
}
func TestTable_bump(t *testing.T) {
tab := newTable(nil, NodeID{}, &net.UDPAddr{})
// add an old entry and two recent ones
oldactive := time.Now().Add(-2 * time.Minute)
old := &Node{ID: randomID(tab.self.ID, 200), active: oldactive}
others := []*Node{
&Node{ID: randomID(tab.self.ID, 200), active: time.Now()},
&Node{ID: randomID(tab.self.ID, 200), active: time.Now()},
}
tab.add(append(others, old))
if tab.buckets[200].entries[0] == old {
t.Fatal("old entry is at front of bucket")
}
// bumping the old entry should move it to the front
tab.bump(old.ID)
if old.active == oldactive {
t.Error("activity timestamp not updated")
}
if tab.buckets[200].entries[0] != old {
t.Errorf("bumped entry did not move to the front of bucket")
}
}
func TestTable_closest(t *testing.T) {
t.Parallel()
test := func(test *closeTest) bool {
// for any node table, Target and N
tab := newTable(nil, test.Self, &net.UDPAddr{})
tab.add(test.All)
// check that doClosest(Target, N) returns nodes
result := tab.closest(test.Target, test.N).entries
if hasDuplicates(result) {
t.Errorf("result contains duplicates")
return false
}
if !sortedByDistanceTo(test.Target, result) {
t.Errorf("result is not sorted by distance to target")
return false
}
// check that the number of results is min(N, tablen)
wantN := test.N
if tlen := tab.len(); tlen < test.N {
wantN = tlen
}
if len(result) != wantN {
t.Errorf("wrong number of nodes: got %d, want %d", len(result), wantN)
return false
} else if len(result) == 0 {
return true // no need to check distance
}
// check that the result nodes have minimum distance to target.
for _, b := range tab.buckets {
for _, n := range b.entries {
if contains(result, n.ID) {
continue // don't run the check below for nodes in result
}
farthestResult := result[len(result)-1].ID
if distcmp(test.Target, n.ID, farthestResult) < 0 {
t.Errorf("table contains node that is closer to target but it's not in result")
t.Logf(" Target: %v", test.Target)
t.Logf(" Farthest Result: %v", farthestResult)
t.Logf(" ID: %v", n.ID)
return false
}
}
}
return true
}
if err := quick.Check(test, quickcfg); err != nil {
t.Error(err)
}
}
type closeTest struct {
Self NodeID
Target NodeID
All []*Node
N int
}
func (*closeTest) Generate(rand *rand.Rand, size int) reflect.Value {
t := &closeTest{
Self: gen(NodeID{}, rand).(NodeID),
Target: gen(NodeID{}, rand).(NodeID),
N: rand.Intn(bucketSize),
}
for _, id := range gen([]NodeID{}, rand).([]NodeID) {
t.All = append(t.All, &Node{ID: id})
}
return reflect.ValueOf(t)
}
func TestTable_Lookup(t *testing.T) {
self := gen(NodeID{}, quickrand).(NodeID)
target := randomID(self, 200)
transport := findnodeOracle{t, target}
tab := newTable(transport, self, &net.UDPAddr{})
// lookup on empty table returns no nodes
if results := tab.Lookup(target); len(results) > 0 {
t.Fatalf("lookup on empty table returned %d results: %#v", len(results), results)
}
// seed table with initial node (otherwise lookup will terminate immediately)
tab.bumpOrAdd(randomID(target, 200), &net.UDPAddr{Port: 200})
results := tab.Lookup(target)
t.Logf("results:")
for _, e := range results {
t.Logf(" ld=%d, %v", logdist(target, e.ID), e.ID)
}
if len(results) != bucketSize {
t.Errorf("wrong number of results: got %d, want %d", len(results), bucketSize)
}
if hasDuplicates(results) {
t.Errorf("result set contains duplicate entries")
}
if !sortedByDistanceTo(target, results) {
t.Errorf("result set not sorted by distance to target")
}
if !contains(results, target) {
t.Errorf("result set does not contain target")
}
}
// findnode on this transport always returns at least one node
// that is one bucket closer to the target.
type findnodeOracle struct {
t *testing.T
target NodeID
}
func (t findnodeOracle) findnode(n *Node, target NodeID) ([]*Node, error) {
t.t.Logf("findnode query at dist %d", n.DiscPort)
// current log distance is encoded in port number
var result []*Node
switch n.DiscPort {
case 0:
panic("query to node at distance 0")
default:
// TODO: add more randomness to distances
next := n.DiscPort - 1
for i := 0; i < bucketSize; i++ {
result = append(result, &Node{ID: randomID(t.target, next), DiscPort: next})
}
}
return result, nil
}
func (t findnodeOracle) close() {}
func (t findnodeOracle) ping(n *Node) error {
return errors.New("ping is not supported by this transport")
}
func hasDuplicates(slice []*Node) bool {
seen := make(map[NodeID]bool)
for _, e := range slice {
if seen[e.ID] {
return true
}
seen[e.ID] = true
}
return false
}
func sortedByDistanceTo(distbase NodeID, slice []*Node) bool {
var last NodeID
for i, e := range slice {
if i > 0 && distcmp(distbase, e.ID, last) < 0 {
return false
}
last = e.ID
}
return true
}
func contains(ns []*Node, id NodeID) bool {
for _, n := range ns {
if n.ID == id {
return true
}
}
return false
}
// gen wraps quick.Value so it's easier to use.
// it generates a random value of the given value's type.
func gen(typ interface{}, rand *rand.Rand) interface{} {
v, ok := quick.Value(reflect.TypeOf(typ), rand)
if !ok {
panic(fmt.Sprintf("couldn't generate random value of type %T", typ))
}
return v.Interface()
}
func newkey() *ecdsa.PrivateKey {
key, err := crypto.GenerateKey()
if err != nil {
panic("couldn't generate key: " + err.Error())
}
return key
}

@ -0,0 +1,431 @@
package discover
import (
"bytes"
"crypto/ecdsa"
"errors"
"fmt"
"net"
"time"
"github.com/ethereum/go-ethereum/crypto"
"github.com/ethereum/go-ethereum/logger"
"github.com/ethereum/go-ethereum/p2p/nat"
"github.com/ethereum/go-ethereum/rlp"
)
var log = logger.NewLogger("P2P Discovery")
// Errors
var (
errPacketTooSmall = errors.New("too small")
errBadHash = errors.New("bad hash")
errExpired = errors.New("expired")
errTimeout = errors.New("RPC timeout")
errClosed = errors.New("socket closed")
)
// Timeouts
const (
respTimeout = 300 * time.Millisecond
sendTimeout = 300 * time.Millisecond
expiration = 20 * time.Second
refreshInterval = 1 * time.Hour
)
// RPC packet types
const (
pingPacket = iota + 1 // zero is 'reserved'
pongPacket
findnodePacket
neighborsPacket
)
// RPC request structures
type (
ping struct {
IP string // our IP
Port uint16 // our port
Expiration uint64
}
// reply to Ping
pong struct {
ReplyTok []byte
Expiration uint64
}
findnode struct {
// Id to look up. The responding node will send back nodes
// closest to the target.
Target NodeID
Expiration uint64
}
// reply to findnode
neighbors struct {
Nodes []*Node
Expiration uint64
}
)
type rpcNode struct {
IP string
Port uint16
ID NodeID
}
// udp implements the RPC protocol.
type udp struct {
conn *net.UDPConn
priv *ecdsa.PrivateKey
addpending chan *pending
replies chan reply
closing chan struct{}
nat nat.Interface
*Table
}
// pending represents a pending reply.
//
// some implementations of the protocol wish to send more than one
// reply packet to findnode. in general, any neighbors packet cannot
// be matched up with a specific findnode packet.
//
// our implementation handles this by storing a callback function for
// each pending reply. incoming packets from a node are dispatched
// to all the callback functions for that node.
type pending struct {
// these fields must match in the reply.
from NodeID
ptype byte
// time when the request must complete
deadline time.Time
// callback is called when a matching reply arrives. if it returns
// true, the callback is removed from the pending reply queue.
// if it returns false, the reply is considered incomplete and
// the callback will be invoked again for the next matching reply.
callback func(resp interface{}) (done bool)
// errc receives nil when the callback indicates completion or an
// error if no further reply is received within the timeout.
errc chan<- error
}
type reply struct {
from NodeID
ptype byte
data interface{}
}
// ListenUDP returns a new table that listens for UDP packets on laddr.
func ListenUDP(priv *ecdsa.PrivateKey, laddr string, natm nat.Interface) (*Table, error) {
addr, err := net.ResolveUDPAddr("udp", laddr)
if err != nil {
return nil, err
}
conn, err := net.ListenUDP("udp", addr)
if err != nil {
return nil, err
}
udp := &udp{
conn: conn,
priv: priv,
closing: make(chan struct{}),
addpending: make(chan *pending),
replies: make(chan reply),
}
realaddr := conn.LocalAddr().(*net.UDPAddr)
if natm != nil {
if !realaddr.IP.IsLoopback() {
go nat.Map(natm, udp.closing, "udp", realaddr.Port, realaddr.Port, "ethereum discovery")
}
// TODO: react to external IP changes over time.
if ext, err := natm.ExternalIP(); err == nil {
realaddr = &net.UDPAddr{IP: ext, Port: realaddr.Port}
}
}
udp.Table = newTable(udp, PubkeyID(&priv.PublicKey), realaddr)
go udp.loop()
go udp.readLoop()
log.Infoln("Listening, ", udp.self)
return udp.Table, nil
}
func (t *udp) close() {
close(t.closing)
t.conn.Close()
// TODO: wait for the loops to end.
}
// ping sends a ping message to the given node and waits for a reply.
func (t *udp) ping(e *Node) error {
// TODO: maybe check for ReplyTo field in callback to measure RTT
errc := t.pending(e.ID, pongPacket, func(interface{}) bool { return true })
t.send(e, pingPacket, ping{
IP: t.self.IP.String(),
Port: uint16(t.self.TCPPort),
Expiration: uint64(time.Now().Add(expiration).Unix()),
})
return <-errc
}
// findnode sends a findnode request to the given node and waits until
// the node has sent up to k neighbors.
func (t *udp) findnode(to *Node, target NodeID) ([]*Node, error) {
nodes := make([]*Node, 0, bucketSize)
nreceived := 0
errc := t.pending(to.ID, neighborsPacket, func(r interface{}) bool {
reply := r.(*neighbors)
for _, n := range reply.Nodes {
nreceived++
if n.isValid() {
nodes = append(nodes, n)
}
}
return nreceived >= bucketSize
})
t.send(to, findnodePacket, findnode{
Target: target,
Expiration: uint64(time.Now().Add(expiration).Unix()),
})
err := <-errc
return nodes, err
}
// pending adds a reply callback to the pending reply queue.
// see the documentation of type pending for a detailed explanation.
func (t *udp) pending(id NodeID, ptype byte, callback func(interface{}) bool) <-chan error {
ch := make(chan error, 1)
p := &pending{from: id, ptype: ptype, callback: callback, errc: ch}
select {
case t.addpending <- p:
// loop will handle it
case <-t.closing:
ch <- errClosed
}
return ch
}
// loop runs in its own goroutin. it keeps track of
// the refresh timer and the pending reply queue.
func (t *udp) loop() {
var (
pending []*pending
nextDeadline time.Time
timeout = time.NewTimer(0)
refresh = time.NewTicker(refreshInterval)
)
<-timeout.C // ignore first timeout
defer refresh.Stop()
defer timeout.Stop()
rearmTimeout := func() {
if len(pending) == 0 || nextDeadline == pending[0].deadline {
return
}
nextDeadline = pending[0].deadline
timeout.Reset(nextDeadline.Sub(time.Now()))
}
for {
select {
case <-refresh.C:
go t.refresh()
case <-t.closing:
for _, p := range pending {
p.errc <- errClosed
}
return
case p := <-t.addpending:
p.deadline = time.Now().Add(respTimeout)
pending = append(pending, p)
rearmTimeout()
case reply := <-t.replies:
// run matching callbacks, remove if they return false.
for i, p := range pending {
if reply.from == p.from && reply.ptype == p.ptype && p.callback(reply.data) {
p.errc <- nil
copy(pending[i:], pending[i+1:])
pending = pending[:len(pending)-1]
i--
}
}
rearmTimeout()
case now := <-timeout.C:
// notify and remove callbacks whose deadline is in the past.
i := 0
for ; i < len(pending) && now.After(pending[i].deadline); i++ {
pending[i].errc <- errTimeout
}
if i > 0 {
copy(pending, pending[i:])
pending = pending[:len(pending)-i]
}
rearmTimeout()
}
}
}
const (
macSize = 256 / 8
sigSize = 520 / 8
headSize = macSize + sigSize // space of packet frame data
)
var headSpace = make([]byte, headSize)
func (t *udp) send(to *Node, ptype byte, req interface{}) error {
b := new(bytes.Buffer)
b.Write(headSpace)
b.WriteByte(ptype)
if err := rlp.Encode(b, req); err != nil {
log.Errorln("error encoding packet:", err)
return err
}
packet := b.Bytes()
sig, err := crypto.Sign(crypto.Sha3(packet[headSize:]), t.priv)
if err != nil {
log.Errorln("could not sign packet:", err)
return err
}
copy(packet[macSize:], sig)
// add the hash to the front. Note: this doesn't protect the
// packet in any way. Our public key will be part of this hash in
// the future.
copy(packet, crypto.Sha3(packet[macSize:]))
toaddr := &net.UDPAddr{IP: to.IP, Port: to.DiscPort}
log.DebugDetailf(">>> %v %T %v\n", toaddr, req, req)
if _, err = t.conn.WriteToUDP(packet, toaddr); err != nil {
log.DebugDetailln("UDP send failed:", err)
}
return err
}
// readLoop runs in its own goroutine. it handles incoming UDP packets.
func (t *udp) readLoop() {
defer t.conn.Close()
buf := make([]byte, 4096) // TODO: good buffer size
for {
nbytes, from, err := t.conn.ReadFromUDP(buf)
if err != nil {
return
}
if err := t.packetIn(from, buf[:nbytes]); err != nil {
log.Debugf("Bad packet from %v: %v\n", from, err)
}
}
}
func (t *udp) packetIn(from *net.UDPAddr, buf []byte) error {
if len(buf) < headSize+1 {
return errPacketTooSmall
}
hash, sig, sigdata := buf[:macSize], buf[macSize:headSize], buf[headSize:]
shouldhash := crypto.Sha3(buf[macSize:])
if !bytes.Equal(hash, shouldhash) {
return errBadHash
}
fromID, err := recoverNodeID(crypto.Sha3(buf[headSize:]), sig)
if err != nil {
return err
}
var req interface {
handle(t *udp, from *net.UDPAddr, fromID NodeID, mac []byte) error
}
switch ptype := sigdata[0]; ptype {
case pingPacket:
req = new(ping)
case pongPacket:
req = new(pong)
case findnodePacket:
req = new(findnode)
case neighborsPacket:
req = new(neighbors)
default:
return fmt.Errorf("unknown type: %d", ptype)
}
if err := rlp.Decode(bytes.NewReader(sigdata[1:]), req); err != nil {
return err
}
log.DebugDetailf("<<< %v %T %v\n", from, req, req)
return req.handle(t, from, fromID, hash)
}
func (req *ping) handle(t *udp, from *net.UDPAddr, fromID NodeID, mac []byte) error {
if expired(req.Expiration) {
return errExpired
}
t.mutex.Lock()
// Note: we're ignoring the provided IP address right now
n := t.bumpOrAdd(fromID, from)
if req.Port != 0 {
n.TCPPort = int(req.Port)
}
t.mutex.Unlock()
t.send(n, pongPacket, pong{
ReplyTok: mac,
Expiration: uint64(time.Now().Add(expiration).Unix()),
})
return nil
}
func (req *pong) handle(t *udp, from *net.UDPAddr, fromID NodeID, mac []byte) error {
if expired(req.Expiration) {
return errExpired
}
t.mutex.Lock()
t.bump(fromID)
t.mutex.Unlock()
t.replies <- reply{fromID, pongPacket, req}
return nil
}
func (req *findnode) handle(t *udp, from *net.UDPAddr, fromID NodeID, mac []byte) error {
if expired(req.Expiration) {
return errExpired
}
t.mutex.Lock()
e := t.bumpOrAdd(fromID, from)
closest := t.closest(req.Target, bucketSize).entries
t.mutex.Unlock()
t.send(e, neighborsPacket, neighbors{
Nodes: closest,
Expiration: uint64(time.Now().Add(expiration).Unix()),
})
return nil
}
func (req *neighbors) handle(t *udp, from *net.UDPAddr, fromID NodeID, mac []byte) error {
if expired(req.Expiration) {
return errExpired
}
t.mutex.Lock()
t.bump(fromID)
t.add(req.Nodes)
t.mutex.Unlock()
t.replies <- reply{fromID, neighborsPacket, req}
return nil
}
func expired(ts uint64) bool {
return time.Unix(int64(ts), 0).Before(time.Now())
}

@ -0,0 +1,211 @@
package discover
import (
"fmt"
logpkg "log"
"net"
"os"
"testing"
"time"
"github.com/ethereum/go-ethereum/logger"
)
func init() {
logger.AddLogSystem(logger.NewStdLogSystem(os.Stdout, logpkg.LstdFlags, logger.ErrorLevel))
}
func TestUDP_ping(t *testing.T) {
t.Parallel()
n1, _ := ListenUDP(newkey(), "127.0.0.1:0", nil)
n2, _ := ListenUDP(newkey(), "127.0.0.1:0", nil)
defer n1.Close()
defer n2.Close()
if err := n1.net.ping(n2.self); err != nil {
t.Fatalf("ping error: %v", err)
}
if find(n2, n1.self.ID) == nil {
t.Errorf("node 2 does not contain id of node 1")
}
if e := find(n1, n2.self.ID); e != nil {
t.Errorf("node 1 does contains id of node 2: %v", e)
}
}
func find(tab *Table, id NodeID) *Node {
for _, b := range tab.buckets {
for _, e := range b.entries {
if e.ID == id {
return e
}
}
}
return nil
}
func TestUDP_findnode(t *testing.T) {
t.Parallel()
n1, _ := ListenUDP(newkey(), "127.0.0.1:0", nil)
n2, _ := ListenUDP(newkey(), "127.0.0.1:0", nil)
defer n1.Close()
defer n2.Close()
// put a few nodes into n2. the exact distribution shouldn't
// matter much, altough we need to take care not to overflow
// any bucket.
target := randomID(n1.self.ID, 100)
nodes := &nodesByDistance{target: target}
for i := 0; i < bucketSize; i++ {
n2.add([]*Node{&Node{
IP: net.IP{1, 2, 3, byte(i)},
DiscPort: i + 2,
TCPPort: i + 2,
ID: randomID(n2.self.ID, i+2),
}})
}
n2.add(nodes.entries)
n2.bumpOrAdd(n1.self.ID, &net.UDPAddr{IP: n1.self.IP, Port: n1.self.DiscPort})
expected := n2.closest(target, bucketSize)
err := runUDP(10, func() error {
result, _ := n1.net.findnode(n2.self, target)
if len(result) != bucketSize {
return fmt.Errorf("wrong number of results: got %d, want %d", len(result), bucketSize)
}
for i := range result {
if result[i].ID != expected.entries[i].ID {
return fmt.Errorf("result mismatch at %d:\n got: %v\n want: %v", i, result[i], expected.entries[i])
}
}
return nil
})
if err != nil {
t.Error(err)
}
}
func TestUDP_replytimeout(t *testing.T) {
t.Parallel()
// reserve a port so we don't talk to an existing service by accident
addr, _ := net.ResolveUDPAddr("udp", "127.0.0.1:0")
fd, err := net.ListenUDP("udp", addr)
if err != nil {
t.Fatal(err)
}
defer fd.Close()
n1, _ := ListenUDP(newkey(), "127.0.0.1:0", nil)
defer n1.Close()
n2 := n1.bumpOrAdd(randomID(n1.self.ID, 10), fd.LocalAddr().(*net.UDPAddr))
if err := n1.net.ping(n2); err != errTimeout {
t.Error("expected timeout error, got", err)
}
if result, err := n1.net.findnode(n2, n1.self.ID); err != errTimeout {
t.Error("expected timeout error, got", err)
} else if len(result) > 0 {
t.Error("expected empty result, got", result)
}
}
func TestUDP_findnodeMultiReply(t *testing.T) {
t.Parallel()
n1, _ := ListenUDP(newkey(), "127.0.0.1:0", nil)
n2, _ := ListenUDP(newkey(), "127.0.0.1:0", nil)
udp2 := n2.net.(*udp)
defer n1.Close()
defer n2.Close()
err := runUDP(10, func() error {
nodes := make([]*Node, bucketSize)
for i := range nodes {
nodes[i] = &Node{
IP: net.IP{1, 2, 3, 4},
DiscPort: i + 1,
TCPPort: i + 1,
ID: randomID(n2.self.ID, i+1),
}
}
// ask N2 for neighbors. it will send an empty reply back.
// the request will wait for up to bucketSize replies.
resultc := make(chan []*Node)
errc := make(chan error)
go func() {
ns, err := n1.net.findnode(n2.self, n1.self.ID)
if err != nil {
errc <- err
} else {
resultc <- ns
}
}()
// send a few more neighbors packets to N1.
// it should collect those.
for end := 0; end < len(nodes); {
off := end
if end = end + 5; end > len(nodes) {
end = len(nodes)
}
udp2.send(n1.self, neighborsPacket, neighbors{
Nodes: nodes[off:end],
Expiration: uint64(time.Now().Add(10 * time.Second).Unix()),
})
}
// check that they are all returned. we cannot just check for
// equality because they might not be returned in the order they
// were sent.
var result []*Node
select {
case result = <-resultc:
case err := <-errc:
return err
}
if hasDuplicates(result) {
return fmt.Errorf("result slice contains duplicates")
}
if len(result) != len(nodes) {
return fmt.Errorf("wrong number of nodes returned: got %d, want %d", len(result), len(nodes))
}
matched := make(map[NodeID]bool)
for _, n := range result {
for _, expn := range nodes {
if n.ID == expn.ID { // && bytes.Equal(n.Addr.IP, expn.Addr.IP) && n.Addr.Port == expn.Addr.Port {
matched[n.ID] = true
}
}
}
if len(matched) != len(nodes) {
return fmt.Errorf("wrong number of matching nodes: got %d, want %d", len(matched), len(nodes))
}
return nil
})
if err != nil {
t.Error(err)
}
}
// runUDP runs a test n times and returns an error if the test failed
// in all n runs. This is necessary because UDP is unreliable even for
// connections on the local machine, causing test failures.
func runUDP(n int, test func() error) error {
errcount := 0
errors := ""
for i := 0; i < n; i++ {
if err := test(); err != nil {
errors += fmt.Sprintf("\n#%d: %v", i, err)
errcount++
}
}
if errcount == n {
return fmt.Errorf("failed on all %d iterations:%s", n, errors)
}
return nil
}

@ -1,6 +1,7 @@
package p2p package p2p
import ( import (
"bufio"
"bytes" "bytes"
"encoding/binary" "encoding/binary"
"errors" "errors"
@ -8,12 +9,37 @@ import (
"io" "io"
"io/ioutil" "io/ioutil"
"math/big" "math/big"
"net"
"sync"
"sync/atomic" "sync/atomic"
"time"
"github.com/ethereum/go-ethereum/ethutil" "github.com/ethereum/go-ethereum/ethutil"
"github.com/ethereum/go-ethereum/rlp" "github.com/ethereum/go-ethereum/rlp"
) )
// parameters for frameRW
const (
// maximum time allowed for reading a message header.
// this is effectively the amount of time a connection can be idle.
frameReadTimeout = 1 * time.Minute
// maximum time allowed for reading the payload data of a message.
// this is shorter than (and distinct from) frameReadTimeout because
// the connection is not considered idle while a message is transferred.
// this also limits the payload size of messages to how much the connection
// can transfer within the timeout.
payloadReadTimeout = 5 * time.Second
// maximum amount of time allowed for writing a complete message.
msgWriteTimeout = 5 * time.Second
// messages smaller than this many bytes will be read at
// once before passing them to a protocol. this increases
// concurrency in the processing.
wholePayloadSize = 64 * 1024
)
// Msg defines the structure of a p2p message. // Msg defines the structure of a p2p message.
// //
// Note that a Msg can only be sent once since the Payload reader is // Note that a Msg can only be sent once since the Payload reader is
@ -74,11 +100,14 @@ type MsgWriter interface {
// WriteMsg sends a message. It will block until the message's // WriteMsg sends a message. It will block until the message's
// Payload has been consumed by the other end. // Payload has been consumed by the other end.
// //
// Note that messages can be sent only once. // Note that messages can be sent only once because their
// payload reader is drained.
WriteMsg(Msg) error WriteMsg(Msg) error
} }
// MsgReadWriter provides reading and writing of encoded messages. // MsgReadWriter provides reading and writing of encoded messages.
// Implementations should ensure that ReadMsg and WriteMsg can be
// called simultaneously from multiple goroutines.
type MsgReadWriter interface { type MsgReadWriter interface {
MsgReader MsgReader
MsgWriter MsgWriter
@ -90,8 +119,45 @@ func EncodeMsg(w MsgWriter, code uint64, data ...interface{}) error {
return w.WriteMsg(NewMsg(code, data...)) return w.WriteMsg(NewMsg(code, data...))
} }
// frameRW is a MsgReadWriter that reads and writes devp2p message frames.
// As required by the interface, ReadMsg and WriteMsg can be called from
// multiple goroutines.
type frameRW struct {
net.Conn // make Conn methods available. be careful.
bufconn *bufio.ReadWriter
// this channel is used to 'lend' bufconn to a caller of ReadMsg
// until the message payload has been consumed. the channel
// receives a value when EOF is reached on the payload, unblocking
// a pending call to ReadMsg.
rsync chan struct{}
// this mutex guards writes to bufconn.
writeMu sync.Mutex
}
func newFrameRW(conn net.Conn, timeout time.Duration) *frameRW {
rsync := make(chan struct{}, 1)
rsync <- struct{}{}
return &frameRW{
Conn: conn,
bufconn: bufio.NewReadWriter(bufio.NewReader(conn), bufio.NewWriter(conn)),
rsync: rsync,
}
}
var magicToken = []byte{34, 64, 8, 145} var magicToken = []byte{34, 64, 8, 145}
func (rw *frameRW) WriteMsg(msg Msg) error {
rw.writeMu.Lock()
defer rw.writeMu.Unlock()
rw.SetWriteDeadline(time.Now().Add(msgWriteTimeout))
if err := writeMsg(rw.bufconn, msg); err != nil {
return err
}
return rw.bufconn.Flush()
}
func writeMsg(w io.Writer, msg Msg) error { func writeMsg(w io.Writer, msg Msg) error {
// TODO: handle case when Size + len(code) + len(listhdr) overflows uint32 // TODO: handle case when Size + len(code) + len(listhdr) overflows uint32
code := ethutil.Encode(uint32(msg.Code)) code := ethutil.Encode(uint32(msg.Code))
@ -120,31 +186,51 @@ func makeListHeader(length uint32) []byte {
return append([]byte{lenb}, enc...) return append([]byte{lenb}, enc...)
} }
// readMsg reads a message header from r. func (rw *frameRW) ReadMsg() (msg Msg, err error) {
// It takes an rlp.ByteReader to ensure that the decoding doesn't buffer. <-rw.rsync // wait until bufconn is ours
func readMsg(r rlp.ByteReader) (msg Msg, err error) {
rw.SetReadDeadline(time.Now().Add(frameReadTimeout))
// read magic and payload size // read magic and payload size
start := make([]byte, 8) start := make([]byte, 8)
if _, err = io.ReadFull(r, start); err != nil { if _, err = io.ReadFull(rw.bufconn, start); err != nil {
return msg, newPeerError(errRead, "%v", err) return msg, err
} }
if !bytes.HasPrefix(start, magicToken) { if !bytes.HasPrefix(start, magicToken) {
return msg, newPeerError(errMagicTokenMismatch, "got %x, want %x", start[:4], magicToken) return msg, fmt.Errorf("bad magic token %x", start[:4], magicToken)
} }
size := binary.BigEndian.Uint32(start[4:]) size := binary.BigEndian.Uint32(start[4:])
// decode start of RLP message to get the message code // decode start of RLP message to get the message code
posr := &postrack{r, 0} posr := &postrack{rw.bufconn, 0}
s := rlp.NewStream(posr) s := rlp.NewStream(posr)
if _, err := s.List(); err != nil { if _, err := s.List(); err != nil {
return msg, err return msg, err
} }
code, err := s.Uint() msg.Code, err = s.Uint()
if err != nil { if err != nil {
return msg, err return msg, err
} }
payloadsize := size - posr.p msg.Size = size - posr.p
return Msg{code, payloadsize, io.LimitReader(r, int64(payloadsize))}, nil
rw.SetReadDeadline(time.Now().Add(payloadReadTimeout))
if msg.Size <= wholePayloadSize {
// msg is small, read all of it and move on to the next message.
pbuf := make([]byte, msg.Size)
if _, err := io.ReadFull(rw.bufconn, pbuf); err != nil {
return msg, err
}
rw.rsync <- struct{}{} // bufconn is available again
msg.Payload = bytes.NewReader(pbuf)
} else {
// lend bufconn to the caller until it has
// consumed the payload. eofSignal will send a value
// on rw.rsync when EOF is reached.
pr := &eofSignal{rw.bufconn, msg.Size, rw.rsync}
msg.Payload = pr
}
return msg, nil
} }
// postrack wraps an rlp.ByteReader with a position counter. // postrack wraps an rlp.ByteReader with a position counter.
@ -167,6 +253,39 @@ func (r *postrack) ReadByte() (byte, error) {
return b, err return b, err
} }
// eofSignal wraps a reader with eof signaling. the eof channel is
// closed when the wrapped reader returns an error or when count bytes
// have been read.
type eofSignal struct {
wrapped io.Reader
count uint32 // number of bytes left
eof chan<- struct{}
}
// note: when using eofSignal to detect whether a message payload
// has been read, Read might not be called for zero sized messages.
func (r *eofSignal) Read(buf []byte) (int, error) {
if r.count == 0 {
if r.eof != nil {
r.eof <- struct{}{}
r.eof = nil
}
return 0, io.EOF
}
max := len(buf)
if int(r.count) < len(buf) {
max = int(r.count)
}
n, err := r.wrapped.Read(buf[:max])
r.count -= uint32(n)
if (err != nil || r.count == 0) && r.eof != nil {
r.eof <- struct{}{} // tell Peer that msg has been consumed
r.eof = nil
}
return n, err
}
// MsgPipe creates a message pipe. Reads on one end are matched // MsgPipe creates a message pipe. Reads on one end are matched
// with writes on the other. The pipe is full-duplex, both ends // with writes on the other. The pipe is full-duplex, both ends
// implement MsgReadWriter. // implement MsgReadWriter.
@ -198,7 +317,7 @@ type MsgPipeRW struct {
func (p *MsgPipeRW) WriteMsg(msg Msg) error { func (p *MsgPipeRW) WriteMsg(msg Msg) error {
if atomic.LoadInt32(p.closed) == 0 { if atomic.LoadInt32(p.closed) == 0 {
consumed := make(chan struct{}, 1) consumed := make(chan struct{}, 1)
msg.Payload = &eofSignal{msg.Payload, int64(msg.Size), consumed} msg.Payload = &eofSignal{msg.Payload, msg.Size, consumed}
select { select {
case p.w <- msg: case p.w <- msg:
if msg.Size > 0 { if msg.Size > 0 {

@ -3,12 +3,11 @@ package p2p
import ( import (
"bytes" "bytes"
"fmt" "fmt"
"io"
"io/ioutil" "io/ioutil"
"runtime" "runtime"
"testing" "testing"
"time" "time"
"github.com/ethereum/go-ethereum/ethutil"
) )
func TestNewMsg(t *testing.T) { func TestNewMsg(t *testing.T) {
@ -26,51 +25,51 @@ func TestNewMsg(t *testing.T) {
} }
} }
func TestEncodeDecodeMsg(t *testing.T) { // func TestEncodeDecodeMsg(t *testing.T) {
msg := NewMsg(3, 1, "000") // msg := NewMsg(3, 1, "000")
buf := new(bytes.Buffer) // buf := new(bytes.Buffer)
if err := writeMsg(buf, msg); err != nil { // if err := writeMsg(buf, msg); err != nil {
t.Fatalf("encodeMsg error: %v", err) // t.Fatalf("encodeMsg error: %v", err)
} // }
// t.Logf("encoded: %x", buf.Bytes()) // // t.Logf("encoded: %x", buf.Bytes())
decmsg, err := readMsg(buf) // decmsg, err := readMsg(buf)
if err != nil { // if err != nil {
t.Fatalf("readMsg error: %v", err) // t.Fatalf("readMsg error: %v", err)
} // }
if decmsg.Code != 3 { // if decmsg.Code != 3 {
t.Errorf("incorrect code %d, want %d", decmsg.Code, 3) // t.Errorf("incorrect code %d, want %d", decmsg.Code, 3)
} // }
if decmsg.Size != 5 { // if decmsg.Size != 5 {
t.Errorf("incorrect size %d, want %d", decmsg.Size, 5) // t.Errorf("incorrect size %d, want %d", decmsg.Size, 5)
} // }
var data struct { // var data struct {
I uint // I uint
S string // S string
} // }
if err := decmsg.Decode(&data); err != nil { // if err := decmsg.Decode(&data); err != nil {
t.Fatalf("Decode error: %v", err) // t.Fatalf("Decode error: %v", err)
} // }
if data.I != 1 { // if data.I != 1 {
t.Errorf("incorrect data.I: got %v, expected %d", data.I, 1) // t.Errorf("incorrect data.I: got %v, expected %d", data.I, 1)
} // }
if data.S != "000" { // if data.S != "000" {
t.Errorf("incorrect data.S: got %q, expected %q", data.S, "000") // t.Errorf("incorrect data.S: got %q, expected %q", data.S, "000")
} // }
} // }
func TestDecodeRealMsg(t *testing.T) { // func TestDecodeRealMsg(t *testing.T) {
data := ethutil.Hex2Bytes("2240089100000080f87e8002b5457468657265756d282b2b292f5065657220536572766572204f6e652f76302e372e382f52656c656173652f4c696e75782f672b2bc082765fb84086dd80b7aefd6a6d2e3b93f4f300a86bfb6ef7bdc97cb03f793db6bb") // data := ethutil.Hex2Bytes("2240089100000080f87e8002b5457468657265756d282b2b292f5065657220536572766572204f6e652f76302e372e382f52656c656173652f4c696e75782f672b2bc082765fb84086dd80b7aefd6a6d2e3b93f4f300a86bfb6ef7bdc97cb03f793db6bb")
msg, err := readMsg(bytes.NewReader(data)) // msg, err := readMsg(bytes.NewReader(data))
if err != nil { // if err != nil {
t.Fatalf("unexpected error: %v", err) // t.Fatalf("unexpected error: %v", err)
} // }
if msg.Code != 0 { // if msg.Code != 0 {
t.Errorf("incorrect code %d, want %d", msg.Code, 0) // t.Errorf("incorrect code %d, want %d", msg.Code, 0)
} // }
} // }
func ExampleMsgPipe() { func ExampleMsgPipe() {
rw1, rw2 := MsgPipe() rw1, rw2 := MsgPipe()
@ -131,3 +130,58 @@ func TestMsgPipeConcurrentClose(t *testing.T) {
go rw1.Close() go rw1.Close()
} }
} }
func TestEOFSignal(t *testing.T) {
rb := make([]byte, 10)
// empty reader
eof := make(chan struct{}, 1)
sig := &eofSignal{new(bytes.Buffer), 0, eof}
if n, err := sig.Read(rb); n != 0 || err != io.EOF {
t.Errorf("Read returned unexpected values: (%v, %v)", n, err)
}
select {
case <-eof:
default:
t.Error("EOF chan not signaled")
}
// count before error
eof = make(chan struct{}, 1)
sig = &eofSignal{bytes.NewBufferString("aaaaaaaa"), 4, eof}
if n, err := sig.Read(rb); n != 4 || err != nil {
t.Errorf("Read returned unexpected values: (%v, %v)", n, err)
}
select {
case <-eof:
default:
t.Error("EOF chan not signaled")
}
// error before count
eof = make(chan struct{}, 1)
sig = &eofSignal{bytes.NewBufferString("aaaa"), 999, eof}
if n, err := sig.Read(rb); n != 4 || err != nil {
t.Errorf("Read returned unexpected values: (%v, %v)", n, err)
}
if n, err := sig.Read(rb); n != 0 || err != io.EOF {
t.Errorf("Read returned unexpected values: (%v, %v)", n, err)
}
select {
case <-eof:
default:
t.Error("EOF chan not signaled")
}
// no signal if neither occurs
eof = make(chan struct{}, 1)
sig = &eofSignal{bytes.NewBufferString("aaaaaaaaaaaaaaaaaaaaa"), 999, eof}
if n, err := sig.Read(rb); n != 10 || err != nil {
t.Errorf("Read returned unexpected values: (%v, %v)", n, err)
}
select {
case <-eof:
t.Error("unexpected EOF signal")
default:
}
}

@ -1,23 +0,0 @@
package p2p
import (
"fmt"
"net"
)
func ParseNAT(natType string, gateway string) (nat NAT, err error) {
switch natType {
case "UPNP":
nat = UPNP()
case "PMP":
ip := net.ParseIP(gateway)
if ip == nil {
return nil, fmt.Errorf("cannot resolve PMP gateway IP %s", gateway)
}
nat = PMP(ip)
case "":
default:
return nil, fmt.Errorf("unrecognised NAT type '%s'", natType)
}
return
}

@ -0,0 +1,235 @@
// Package nat provides access to common port mapping protocols.
package nat
import (
"errors"
"fmt"
"net"
"strings"
"sync"
"time"
"github.com/ethereum/go-ethereum/logger"
"github.com/jackpal/go-nat-pmp"
)
var log = logger.NewLogger("P2P NAT")
// An implementation of nat.Interface can map local ports to ports
// accessible from the Internet.
type Interface interface {
// These methods manage a mapping between a port on the local
// machine to a port that can be connected to from the internet.
//
// protocol is "UDP" or "TCP". Some implementations allow setting
// a display name for the mapping. The mapping may be removed by
// the gateway when its lifetime ends.
AddMapping(protocol string, extport, intport int, name string, lifetime time.Duration) error
DeleteMapping(protocol string, extport, intport int) error
// This method should return the external (Internet-facing)
// address of the gateway device.
ExternalIP() (net.IP, error)
// Should return name of the method. This is used for logging.
String() string
}
// Parse parses a NAT interface description.
// The following formats are currently accepted.
// Note that mechanism names are not case-sensitive.
//
// "" or "none" return nil
// "extip:77.12.33.4" will assume the local machine is reachable on the given IP
// "any" uses the first auto-detected mechanism
// "upnp" uses the Universal Plug and Play protocol
// "pmp" uses NAT-PMP with an auto-detected gateway address
// "pmp:192.168.0.1" uses NAT-PMP with the given gateway address
func Parse(spec string) (Interface, error) {
var (
parts = strings.SplitN(spec, ":", 2)
mech = strings.ToLower(parts[0])
ip net.IP
)
if len(parts) > 1 {
ip = net.ParseIP(parts[1])
if ip == nil {
return nil, errors.New("invalid IP address")
}
}
switch mech {
case "", "none", "off":
return nil, nil
case "any", "auto", "on":
return Any(), nil
case "extip", "ip":
if ip == nil {
return nil, errors.New("missing IP address")
}
return ExtIP(ip), nil
case "upnp":
return UPnP(), nil
case "pmp", "natpmp", "nat-pmp":
return PMP(ip), nil
default:
return nil, fmt.Errorf("unknown mechanism %q", parts[0])
}
}
const (
mapTimeout = 20 * time.Minute
mapUpdateInterval = 15 * time.Minute
)
// Map adds a port mapping on m and keeps it alive until c is closed.
// This function is typically invoked in its own goroutine.
func Map(m Interface, c chan struct{}, protocol string, extport, intport int, name string) {
refresh := time.NewTimer(mapUpdateInterval)
defer func() {
refresh.Stop()
log.Debugf("Deleting port mapping: %s %d -> %d (%s) using %s\n", protocol, extport, intport, name, m)
m.DeleteMapping(protocol, extport, intport)
}()
log.Debugf("add mapping: %s %d -> %d (%s) using %s\n", protocol, extport, intport, name, m)
if err := m.AddMapping(protocol, intport, extport, name, mapTimeout); err != nil {
log.Errorf("mapping error: %v\n", err)
}
for {
select {
case _, ok := <-c:
if !ok {
return
}
case <-refresh.C:
log.DebugDetailf("refresh mapping: %s %d -> %d (%s) using %s\n", protocol, extport, intport, name, m)
if err := m.AddMapping(protocol, intport, extport, name, mapTimeout); err != nil {
log.Errorf("mapping error: %v\n", err)
}
refresh.Reset(mapUpdateInterval)
}
}
}
// ExtIP assumes that the local machine is reachable on the given
// external IP address, and that any required ports were mapped manually.
// Mapping operations will not return an error but won't actually do anything.
func ExtIP(ip net.IP) Interface {
if ip == nil {
panic("IP must not be nil")
}
return extIP(ip)
}
type extIP net.IP
func (n extIP) ExternalIP() (net.IP, error) { return net.IP(n), nil }
func (n extIP) String() string { return fmt.Sprintf("ExtIP(%v)", net.IP(n)) }
// These do nothing.
func (extIP) AddMapping(string, int, int, string, time.Duration) error { return nil }
func (extIP) DeleteMapping(string, int, int) error { return nil }
// Any returns a port mapper that tries to discover any supported
// mechanism on the local network.
func Any() Interface {
// TODO: attempt to discover whether the local machine has an
// Internet-class address. Return ExtIP in this case.
return startautodisc("UPnP or NAT-PMP", func() Interface {
found := make(chan Interface, 2)
go func() { found <- discoverUPnP() }()
go func() { found <- discoverPMP() }()
for i := 0; i < cap(found); i++ {
if c := <-found; c != nil {
return c
}
}
return nil
})
}
// UPnP returns a port mapper that uses UPnP. It will attempt to
// discover the address of your router using UDP broadcasts.
func UPnP() Interface {
return startautodisc("UPnP", discoverUPnP)
}
// PMP returns a port mapper that uses NAT-PMP. The provided gateway
// address should be the IP of your router. If the given gateway
// address is nil, PMP will attempt to auto-discover the router.
func PMP(gateway net.IP) Interface {
if gateway != nil {
return &pmp{gw: gateway, c: natpmp.NewClient(gateway)}
}
return startautodisc("NAT-PMP", discoverPMP)
}
// autodisc represents a port mapping mechanism that is still being
// auto-discovered. Calls to the Interface methods on this type will
// wait until the discovery is done and then call the method on the
// discovered mechanism.
//
// This type is useful because discovery can take a while but we
// want return an Interface value from UPnP, PMP and Auto immediately.
type autodisc struct {
what string
done <-chan Interface
mu sync.Mutex
found Interface
}
func startautodisc(what string, doit func() Interface) Interface {
// TODO: monitor network configuration and rerun doit when it changes.
done := make(chan Interface)
ad := &autodisc{what: what, done: done}
go func() { done <- doit(); close(done) }()
return ad
}
func (n *autodisc) AddMapping(protocol string, extport, intport int, name string, lifetime time.Duration) error {
if err := n.wait(); err != nil {
return err
}
return n.found.AddMapping(protocol, extport, intport, name, lifetime)
}
func (n *autodisc) DeleteMapping(protocol string, extport, intport int) error {
if err := n.wait(); err != nil {
return err
}
return n.found.DeleteMapping(protocol, extport, intport)
}
func (n *autodisc) ExternalIP() (net.IP, error) {
if err := n.wait(); err != nil {
return nil, err
}
return n.found.ExternalIP()
}
func (n *autodisc) String() string {
n.mu.Lock()
defer n.mu.Unlock()
if n.found == nil {
return n.what
} else {
return n.found.String()
}
}
func (n *autodisc) wait() error {
n.mu.Lock()
found := n.found
n.mu.Unlock()
if found != nil {
// already discovered
return nil
}
if found = <-n.done; found == nil {
return errors.New("no devices discovered")
}
n.mu.Lock()
n.found = found
n.mu.Unlock()
return nil
}

@ -0,0 +1,115 @@
package nat
import (
"fmt"
"net"
"strings"
"time"
"github.com/jackpal/go-nat-pmp"
)
// natPMPClient adapts the NAT-PMP protocol implementation so it conforms to
// the common interface.
type pmp struct {
gw net.IP
c *natpmp.Client
}
func (n *pmp) String() string {
return fmt.Sprintf("NAT-PMP(%v)", n.gw)
}
func (n *pmp) ExternalIP() (net.IP, error) {
response, err := n.c.GetExternalAddress()
if err != nil {
return nil, err
}
return response.ExternalIPAddress[:], nil
}
func (n *pmp) AddMapping(protocol string, extport, intport int, name string, lifetime time.Duration) error {
if lifetime <= 0 {
return fmt.Errorf("lifetime must not be <= 0")
}
// Note order of port arguments is switched between our
// AddMapping and the client's AddPortMapping.
_, err := n.c.AddPortMapping(strings.ToLower(protocol), intport, extport, int(lifetime/time.Second))
return err
}
func (n *pmp) DeleteMapping(protocol string, extport, intport int) (err error) {
// To destroy a mapping, send an add-port with an internalPort of
// the internal port to destroy, an external port of zero and a
// time of zero.
_, err = n.c.AddPortMapping(strings.ToLower(protocol), intport, 0, 0)
return err
}
func discoverPMP() Interface {
// run external address lookups on all potential gateways
gws := potentialGateways()
found := make(chan *pmp, len(gws))
for i := range gws {
gw := gws[i]
go func() {
c := natpmp.NewClient(gw)
if _, err := c.GetExternalAddress(); err != nil {
found <- nil
} else {
found <- &pmp{gw, c}
}
}()
}
// return the one that responds first.
// discovery needs to be quick, so we stop caring about
// any responses after a very short timeout.
timeout := time.NewTimer(1 * time.Second)
defer timeout.Stop()
for _ = range gws {
select {
case c := <-found:
if c != nil {
return c
}
case <-timeout.C:
return nil
}
}
return nil
}
var (
// LAN IP ranges
_, lan10, _ = net.ParseCIDR("10.0.0.0/8")
_, lan176, _ = net.ParseCIDR("172.16.0.0/12")
_, lan192, _ = net.ParseCIDR("192.168.0.0/16")
)
// TODO: improve this. We currently assume that (on most networks)
// the router is X.X.X.1 in a local LAN range.
func potentialGateways() (gws []net.IP) {
ifaces, err := net.Interfaces()
if err != nil {
return nil
}
for _, iface := range ifaces {
ifaddrs, err := iface.Addrs()
if err != nil {
return gws
}
for _, addr := range ifaddrs {
switch x := addr.(type) {
case *net.IPNet:
if lan10.Contains(x.IP) || lan176.Contains(x.IP) || lan192.Contains(x.IP) {
ip := x.IP.Mask(x.Mask).To4()
if ip != nil {
ip[3] = ip[3] | 0x01
gws = append(gws, ip)
}
}
}
}
}
return gws
}

@ -0,0 +1,149 @@
package nat
import (
"errors"
"fmt"
"net"
"strings"
"time"
"github.com/fjl/goupnp"
"github.com/fjl/goupnp/dcps/internetgateway1"
"github.com/fjl/goupnp/dcps/internetgateway2"
)
type upnp struct {
dev *goupnp.RootDevice
service string
client upnpClient
}
type upnpClient interface {
GetExternalIPAddress() (string, error)
AddPortMapping(string, uint16, string, uint16, string, bool, string, uint32) error
DeletePortMapping(string, uint16, string) error
GetNATRSIPStatus() (sip bool, nat bool, err error)
}
func (n *upnp) ExternalIP() (addr net.IP, err error) {
ipString, err := n.client.GetExternalIPAddress()
if err != nil {
return nil, err
}
ip := net.ParseIP(ipString)
if ip == nil {
return nil, errors.New("bad IP in response")
}
return ip, nil
}
func (n *upnp) AddMapping(protocol string, extport, intport int, desc string, lifetime time.Duration) error {
ip, err := n.internalAddress()
if err != nil {
return nil
}
protocol = strings.ToUpper(protocol)
lifetimeS := uint32(lifetime / time.Second)
return n.client.AddPortMapping("", uint16(extport), protocol, uint16(intport), ip.String(), true, desc, lifetimeS)
}
func (n *upnp) internalAddress() (net.IP, error) {
devaddr, err := net.ResolveUDPAddr("udp4", n.dev.URLBase.Host)
if err != nil {
return nil, err
}
ifaces, err := net.Interfaces()
if err != nil {
return nil, err
}
for _, iface := range ifaces {
addrs, err := iface.Addrs()
if err != nil {
return nil, err
}
for _, addr := range addrs {
switch x := addr.(type) {
case *net.IPNet:
if x.Contains(devaddr.IP) {
return x.IP, nil
}
}
}
}
return nil, fmt.Errorf("could not find local address in same net as %v", devaddr)
}
func (n *upnp) DeleteMapping(protocol string, extport, intport int) error {
return n.client.DeletePortMapping("", uint16(extport), strings.ToUpper(protocol))
}
func (n *upnp) String() string {
return "UPNP " + n.service
}
// discoverUPnP searches for Internet Gateway Devices
// and returns the first one it can find on the local network.
func discoverUPnP() Interface {
found := make(chan *upnp, 2)
// IGDv1
go discover(found, internetgateway1.URN_WANConnectionDevice_1, func(dev *goupnp.RootDevice, sc goupnp.ServiceClient) *upnp {
switch sc.Service.ServiceType {
case internetgateway1.URN_WANIPConnection_1:
return &upnp{dev, "IGDv1-IP1", &internetgateway1.WANIPConnection1{sc}}
case internetgateway1.URN_WANPPPConnection_1:
return &upnp{dev, "IGDv1-PPP1", &internetgateway1.WANPPPConnection1{sc}}
}
return nil
})
// IGDv2
go discover(found, internetgateway2.URN_WANConnectionDevice_2, func(dev *goupnp.RootDevice, sc goupnp.ServiceClient) *upnp {
switch sc.Service.ServiceType {
case internetgateway2.URN_WANIPConnection_1:
return &upnp{dev, "IGDv2-IP1", &internetgateway2.WANIPConnection1{sc}}
case internetgateway2.URN_WANIPConnection_2:
return &upnp{dev, "IGDv2-IP2", &internetgateway2.WANIPConnection2{sc}}
case internetgateway2.URN_WANPPPConnection_1:
return &upnp{dev, "IGDv2-PPP1", &internetgateway2.WANPPPConnection1{sc}}
}
return nil
})
for i := 0; i < cap(found); i++ {
if c := <-found; c != nil {
return c
}
}
return nil
}
func discover(out chan<- *upnp, target string, matcher func(*goupnp.RootDevice, goupnp.ServiceClient) *upnp) {
devs, err := goupnp.DiscoverDevices(target)
if err != nil {
return
}
found := false
for i := 0; i < len(devs) && !found; i++ {
if devs[i].Root == nil {
continue
}
devs[i].Root.Device.VisitServices(func(service *goupnp.Service) {
if found {
return
}
// check for a matching IGD service
sc := goupnp.ServiceClient{service.NewSOAPClient(), devs[i].Root, service}
upnp := matcher(devs[i].Root, sc)
if upnp == nil {
return
}
// check whether port mapping is enabled
if _, nat, err := upnp.client.GetNATRSIPStatus(); err != nil || !nat {
return
}
out <- upnp
found = true
})
}
if !found {
out <- nil
}
}

@ -1,55 +0,0 @@
package p2p
import (
"fmt"
"net"
"time"
natpmp "github.com/jackpal/go-nat-pmp"
)
// Adapt the NAT-PMP protocol to the NAT interface
// TODO:
// + Register for changes to the external address.
// + Re-register port mapping when router reboots.
// + A mechanism for keeping a port mapping registered.
// + Discover gateway address automatically.
type natPMPClient struct {
client *natpmp.Client
}
// PMP returns a NAT traverser that uses NAT-PMP. The provided gateway
// address should be the IP of your router.
func PMP(gateway net.IP) (nat NAT) {
return &natPMPClient{natpmp.NewClient(gateway)}
}
func (*natPMPClient) String() string {
return "NAT-PMP"
}
func (n *natPMPClient) GetExternalAddress() (net.IP, error) {
response, err := n.client.GetExternalAddress()
if err != nil {
return nil, err
}
return response.ExternalIPAddress[:], nil
}
func (n *natPMPClient) AddPortMapping(protocol string, extport, intport int, name string, lifetime time.Duration) error {
if lifetime <= 0 {
return fmt.Errorf("lifetime must not be <= 0")
}
// Note order of port arguments is switched between our AddPortMapping and the client's AddPortMapping.
_, err := n.client.AddPortMapping(protocol, intport, extport, int(lifetime/time.Second))
return err
}
func (n *natPMPClient) DeletePortMapping(protocol string, externalPort, internalPort int) (err error) {
// To destroy a mapping, send an add-port with
// an internalPort of the internal port to destroy, an external port of zero and a time of zero.
_, err = n.client.AddPortMapping(protocol, internalPort, 0, 0)
return
}

@ -1,341 +0,0 @@
package p2p
// Just enough UPnP to be able to forward ports
//
import (
"bytes"
"encoding/xml"
"errors"
"fmt"
"net"
"net/http"
"os"
"strconv"
"strings"
"time"
)
const (
upnpDiscoverAttempts = 3
upnpDiscoverTimeout = 5 * time.Second
)
// UPNP returns a NAT port mapper that uses UPnP. It will attempt to
// discover the address of your router using UDP broadcasts.
func UPNP() NAT {
return &upnpNAT{}
}
type upnpNAT struct {
serviceURL string
ourIP string
}
func (n *upnpNAT) String() string {
return "UPNP"
}
func (n *upnpNAT) discover() error {
if n.serviceURL != "" {
// already discovered
return nil
}
ssdp, err := net.ResolveUDPAddr("udp4", "239.255.255.250:1900")
if err != nil {
return err
}
// TODO: try on all network interfaces simultaneously.
// Broadcasting on 0.0.0.0 could select a random interface
// to send on (platform specific).
conn, err := net.ListenPacket("udp4", ":0")
if err != nil {
return err
}
defer conn.Close()
conn.SetDeadline(time.Now().Add(10 * time.Second))
st := "ST: urn:schemas-upnp-org:device:InternetGatewayDevice:1\r\n"
buf := bytes.NewBufferString(
"M-SEARCH * HTTP/1.1\r\n" +
"HOST: 239.255.255.250:1900\r\n" +
st +
"MAN: \"ssdp:discover\"\r\n" +
"MX: 2\r\n\r\n")
message := buf.Bytes()
answerBytes := make([]byte, 1024)
for i := 0; i < upnpDiscoverAttempts; i++ {
_, err = conn.WriteTo(message, ssdp)
if err != nil {
return err
}
nn, _, err := conn.ReadFrom(answerBytes)
if err != nil {
continue
}
answer := string(answerBytes[0:nn])
if strings.Index(answer, "\r\n"+st) < 0 {
continue
}
// HTTP header field names are case-insensitive.
// http://www.w3.org/Protocols/rfc2616/rfc2616-sec4.html#sec4.2
locString := "\r\nlocation: "
answer = strings.ToLower(answer)
locIndex := strings.Index(answer, locString)
if locIndex < 0 {
continue
}
loc := answer[locIndex+len(locString):]
endIndex := strings.Index(loc, "\r\n")
if endIndex < 0 {
continue
}
locURL := loc[0:endIndex]
var serviceURL string
serviceURL, err = getServiceURL(locURL)
if err != nil {
return err
}
var ourIP string
ourIP, err = getOurIP()
if err != nil {
return err
}
n.serviceURL = serviceURL
n.ourIP = ourIP
return nil
}
return errors.New("UPnP port discovery failed.")
}
func (n *upnpNAT) GetExternalAddress() (addr net.IP, err error) {
if err := n.discover(); err != nil {
return nil, err
}
info, err := n.getStatusInfo()
return net.ParseIP(info.externalIpAddress), err
}
func (n *upnpNAT) AddPortMapping(protocol string, extport, intport int, description string, lifetime time.Duration) error {
if err := n.discover(); err != nil {
return err
}
// A single concatenation would break ARM compilation.
message := "<u:AddPortMapping xmlns:u=\"urn:schemas-upnp-org:service:WANIPConnection:1\">\r\n" +
"<NewRemoteHost></NewRemoteHost><NewExternalPort>" + strconv.Itoa(extport)
message += "</NewExternalPort><NewProtocol>" + protocol + "</NewProtocol>"
message += "<NewInternalPort>" + strconv.Itoa(extport) + "</NewInternalPort>" +
"<NewInternalClient>" + n.ourIP + "</NewInternalClient>" +
"<NewEnabled>1</NewEnabled><NewPortMappingDescription>"
message += description +
"</NewPortMappingDescription><NewLeaseDuration>" + fmt.Sprint(lifetime/time.Second) +
"</NewLeaseDuration></u:AddPortMapping>"
// TODO: check response to see if the port was forwarded
_, err := soapRequest(n.serviceURL, "AddPortMapping", message)
return err
}
func (n *upnpNAT) DeletePortMapping(protocol string, externalPort, internalPort int) error {
if err := n.discover(); err != nil {
return err
}
message := "<u:DeletePortMapping xmlns:u=\"urn:schemas-upnp-org:service:WANIPConnection:1\">\r\n" +
"<NewRemoteHost></NewRemoteHost><NewExternalPort>" + strconv.Itoa(externalPort) +
"</NewExternalPort><NewProtocol>" + protocol + "</NewProtocol>" +
"</u:DeletePortMapping>"
// TODO: check response to see if the port was deleted
_, err := soapRequest(n.serviceURL, "DeletePortMapping", message)
return err
}
type statusInfo struct {
externalIpAddress string
}
func (n *upnpNAT) getStatusInfo() (info statusInfo, err error) {
message := "<u:GetStatusInfo xmlns:u=\"urn:schemas-upnp-org:service:WANIPConnection:1\">\r\n" +
"</u:GetStatusInfo>"
var response *http.Response
response, err = soapRequest(n.serviceURL, "GetStatusInfo", message)
if err != nil {
return
}
// TODO: Write a soap reply parser. It has to eat the Body and envelope tags...
response.Body.Close()
return
}
// service represents the Service type in an UPnP xml description.
// Only the parts we care about are present and thus the xml may have more
// fields than present in the structure.
type service struct {
ServiceType string `xml:"serviceType"`
ControlURL string `xml:"controlURL"`
}
// deviceList represents the deviceList type in an UPnP xml description.
// Only the parts we care about are present and thus the xml may have more
// fields than present in the structure.
type deviceList struct {
XMLName xml.Name `xml:"deviceList"`
Device []device `xml:"device"`
}
// serviceList represents the serviceList type in an UPnP xml description.
// Only the parts we care about are present and thus the xml may have more
// fields than present in the structure.
type serviceList struct {
XMLName xml.Name `xml:"serviceList"`
Service []service `xml:"service"`
}
// device represents the device type in an UPnP xml description.
// Only the parts we care about are present and thus the xml may have more
// fields than present in the structure.
type device struct {
XMLName xml.Name `xml:"device"`
DeviceType string `xml:"deviceType"`
DeviceList deviceList `xml:"deviceList"`
ServiceList serviceList `xml:"serviceList"`
}
// specVersion represents the specVersion in a UPnP xml description.
// Only the parts we care about are present and thus the xml may have more
// fields than present in the structure.
type specVersion struct {
XMLName xml.Name `xml:"specVersion"`
Major int `xml:"major"`
Minor int `xml:"minor"`
}
// root represents the Root document for a UPnP xml description.
// Only the parts we care about are present and thus the xml may have more
// fields than present in the structure.
type root struct {
XMLName xml.Name `xml:"root"`
SpecVersion specVersion
Device device
}
func getChildDevice(d *device, deviceType string) *device {
dl := d.DeviceList.Device
for i := 0; i < len(dl); i++ {
if dl[i].DeviceType == deviceType {
return &dl[i]
}
}
return nil
}
func getChildService(d *device, serviceType string) *service {
sl := d.ServiceList.Service
for i := 0; i < len(sl); i++ {
if sl[i].ServiceType == serviceType {
return &sl[i]
}
}
return nil
}
func getOurIP() (ip string, err error) {
hostname, err := os.Hostname()
if err != nil {
return
}
p, err := net.LookupIP(hostname)
if err != nil && len(p) > 0 {
return
}
return p[0].String(), nil
}
func getServiceURL(rootURL string) (url string, err error) {
r, err := http.Get(rootURL)
if err != nil {
return
}
defer r.Body.Close()
if r.StatusCode >= 400 {
err = errors.New(string(r.StatusCode))
return
}
var root root
err = xml.NewDecoder(r.Body).Decode(&root)
if err != nil {
return
}
a := &root.Device
if a.DeviceType != "urn:schemas-upnp-org:device:InternetGatewayDevice:1" {
err = errors.New("No InternetGatewayDevice")
return
}
b := getChildDevice(a, "urn:schemas-upnp-org:device:WANDevice:1")
if b == nil {
err = errors.New("No WANDevice")
return
}
c := getChildDevice(b, "urn:schemas-upnp-org:device:WANConnectionDevice:1")
if c == nil {
err = errors.New("No WANConnectionDevice")
return
}
d := getChildService(c, "urn:schemas-upnp-org:service:WANIPConnection:1")
if d == nil {
err = errors.New("No WANIPConnection")
return
}
url = combineURL(rootURL, d.ControlURL)
return
}
func combineURL(rootURL, subURL string) string {
protocolEnd := "://"
protoEndIndex := strings.Index(rootURL, protocolEnd)
a := rootURL[protoEndIndex+len(protocolEnd):]
rootIndex := strings.Index(a, "/")
return rootURL[0:protoEndIndex+len(protocolEnd)+rootIndex] + subURL
}
func soapRequest(url, function, message string) (r *http.Response, err error) {
fullMessage := "<?xml version=\"1.0\" ?>" +
"<s:Envelope xmlns:s=\"http://schemas.xmlsoap.org/soap/envelope/\" s:encodingStyle=\"http://schemas.xmlsoap.org/soap/encoding/\">\r\n" +
"<s:Body>" + message + "</s:Body></s:Envelope>"
req, err := http.NewRequest("POST", url, strings.NewReader(fullMessage))
if err != nil {
return
}
req.Header.Set("Content-Type", "text/xml ; charset=\"utf-8\"")
req.Header.Set("User-Agent", "Darwin/10.0.0, UPnP/1.0, MiniUPnPc/1.3")
//req.Header.Set("Transfer-Encoding", "chunked")
req.Header.Set("SOAPAction", "\"urn:schemas-upnp-org:service:WANIPConnection:1#"+function+"\"")
req.Header.Set("Connection", "Close")
req.Header.Set("Cache-Control", "no-cache")
req.Header.Set("Pragma", "no-cache")
r, err = http.DefaultClient.Do(req)
if err != nil {
return
}
if r.Body != nil {
defer r.Body.Close()
}
if r.StatusCode >= 400 {
// log.Stderr(function, r.StatusCode)
err = errors.New("Error " + strconv.Itoa(r.StatusCode) + " for " + function)
r = nil
return
}
return
}

@ -1,8 +1,7 @@
package p2p package p2p
import ( import (
"bufio" "errors"
"bytes"
"fmt" "fmt"
"io" "io"
"io/ioutil" "io/ioutil"
@ -11,159 +10,109 @@ import (
"sync" "sync"
"time" "time"
"github.com/ethereum/go-ethereum/event"
"github.com/ethereum/go-ethereum/logger" "github.com/ethereum/go-ethereum/logger"
"github.com/ethereum/go-ethereum/p2p/discover"
"github.com/ethereum/go-ethereum/rlp"
) )
// peerAddr is the structure of a peer list element. const (
// It is also a valid net.Addr. baseProtocolVersion = 3
type peerAddr struct { baseProtocolLength = uint64(16)
IP net.IP baseProtocolMaxMsgSize = 10 * 1024 * 1024
Port uint64
Pubkey []byte // optional
}
func newPeerAddr(addr net.Addr, pubkey []byte) *peerAddr {
n := addr.Network()
if n != "tcp" && n != "tcp4" && n != "tcp6" {
// for testing with non-TCP
return &peerAddr{net.ParseIP("127.0.0.1"), 30303, pubkey}
}
ta := addr.(*net.TCPAddr)
return &peerAddr{ta.IP, uint64(ta.Port), pubkey}
}
func (d peerAddr) Network() string { disconnectGracePeriod = 2 * time.Second
if d.IP.To4() != nil { )
return "tcp4"
} else {
return "tcp6"
}
}
func (d peerAddr) String() string { const (
return fmt.Sprintf("%v:%d", d.IP, d.Port) // devp2p message codes
} handshakeMsg = 0x00
discMsg = 0x01
pingMsg = 0x02
pongMsg = 0x03
getPeersMsg = 0x04
peersMsg = 0x05
)
func (d *peerAddr) RlpData() interface{} { // handshake is the RLP structure of the protocol handshake.
return []interface{}{string(d.IP), d.Port, d.Pubkey} type handshake struct {
Version uint64
Name string
Caps []Cap
ListenPort uint64
NodeID discover.NodeID
} }
// Peer represents a remote peer. // Peer represents a connected remote node.
type Peer struct { type Peer struct {
// Peers have all the log methods. // Peers have all the log methods.
// Use them to display messages related to the peer. // Use them to display messages related to the peer.
*logger.Logger *logger.Logger
infolock sync.Mutex infoMu sync.Mutex
identity ClientIdentity name string
caps []Cap caps []Cap
listenAddr *peerAddr // what remote peer is listening on
dialAddr *peerAddr // non-nil if dialing ourID, remoteID *discover.NodeID
ourName string
// The mutex protects the connection rw *frameRW
// so only one protocol can write at a time.
writeMu sync.Mutex
conn net.Conn
bufconn *bufio.ReadWriter
// These fields maintain the running protocols. // These fields maintain the running protocols.
protocols []Protocol protocols []Protocol
runBaseProtocol bool // for testing runlock sync.RWMutex // protects running
running map[string]*proto
runlock sync.RWMutex // protects running // disables protocol handshake, for testing
running map[string]*proto noHandshake bool
protoWG sync.WaitGroup protoWG sync.WaitGroup
protoErr chan error protoErr chan error
closed chan struct{} closed chan struct{}
disc chan DiscReason disc chan DiscReason
activity event.TypeMux // for activity events
slot int // index into Server peer list
// These fields are kept so base protocol can access them.
// TODO: this should be one or more interfaces
ourID ClientIdentity // client id of the Server
ourListenAddr *peerAddr // listen addr of Server, nil if not listening
newPeerAddr chan<- *peerAddr // tell server about received peers
otherPeers func() []*Peer // should return the list of all peers
pubkeyHook func(*peerAddr) error // called at end of handshake to validate pubkey
} }
// NewPeer returns a peer for testing purposes. // NewPeer returns a peer for testing purposes.
func NewPeer(id ClientIdentity, caps []Cap) *Peer { func NewPeer(id discover.NodeID, name string, caps []Cap) *Peer {
conn, _ := net.Pipe() conn, _ := net.Pipe()
peer := newPeer(conn, nil, nil) peer := newPeer(conn, nil, "", nil, &id)
peer.setHandshakeInfo(id, nil, caps) peer.setHandshakeInfo(name, caps)
close(peer.closed) close(peer.closed) // ensures Disconnect doesn't block
return peer return peer
} }
func newServerPeer(server *Server, conn net.Conn, dialAddr *peerAddr) *Peer { // ID returns the node's public key.
p := newPeer(conn, server.Protocols, dialAddr) func (p *Peer) ID() discover.NodeID {
p.ourID = server.Identity return *p.remoteID
p.newPeerAddr = server.peerConnect
p.otherPeers = server.Peers
p.pubkeyHook = server.verifyPeer
p.runBaseProtocol = true
// laddr can be updated concurrently by NAT traversal.
// newServerPeer must be called with the server lock held.
if server.laddr != nil {
p.ourListenAddr = newPeerAddr(server.laddr, server.Identity.Pubkey())
}
return p
}
func newPeer(conn net.Conn, protocols []Protocol, dialAddr *peerAddr) *Peer {
p := &Peer{
Logger: logger.NewLogger("P2P " + conn.RemoteAddr().String()),
conn: conn,
dialAddr: dialAddr,
bufconn: bufio.NewReadWriter(bufio.NewReader(conn), bufio.NewWriter(conn)),
protocols: protocols,
running: make(map[string]*proto),
disc: make(chan DiscReason),
protoErr: make(chan error),
closed: make(chan struct{}),
}
return p
} }
// Identity returns the client identity of the remote peer. The // Name returns the node name that the remote node advertised.
// identity can be nil if the peer has not yet completed the func (p *Peer) Name() string {
// handshake. // this needs a lock because the information is part of the
func (p *Peer) Identity() ClientIdentity { // protocol handshake.
p.infolock.Lock() p.infoMu.Lock()
defer p.infolock.Unlock() name := p.name
return p.identity p.infoMu.Unlock()
return name
} }
// Caps returns the capabilities (supported subprotocols) of the remote peer. // Caps returns the capabilities (supported subprotocols) of the remote peer.
func (p *Peer) Caps() []Cap { func (p *Peer) Caps() []Cap {
p.infolock.Lock() // this needs a lock because the information is part of the
defer p.infolock.Unlock() // protocol handshake.
return p.caps p.infoMu.Lock()
} caps := p.caps
p.infoMu.Unlock()
func (p *Peer) setHandshakeInfo(id ClientIdentity, laddr *peerAddr, caps []Cap) { return caps
p.infolock.Lock()
p.identity = id
p.listenAddr = laddr
p.caps = caps
p.infolock.Unlock()
} }
// RemoteAddr returns the remote address of the network connection. // RemoteAddr returns the remote address of the network connection.
func (p *Peer) RemoteAddr() net.Addr { func (p *Peer) RemoteAddr() net.Addr {
return p.conn.RemoteAddr() return p.rw.RemoteAddr()
} }
// LocalAddr returns the local address of the network connection. // LocalAddr returns the local address of the network connection.
func (p *Peer) LocalAddr() net.Addr { func (p *Peer) LocalAddr() net.Addr {
return p.conn.LocalAddr() return p.rw.LocalAddr()
} }
// Disconnect terminates the peer connection with the given reason. // Disconnect terminates the peer connection with the given reason.
@ -177,149 +126,177 @@ func (p *Peer) Disconnect(reason DiscReason) {
// String implements fmt.Stringer. // String implements fmt.Stringer.
func (p *Peer) String() string { func (p *Peer) String() string {
kind := "inbound" return fmt.Sprintf("Peer %.8x %v", p.remoteID[:], p.RemoteAddr())
p.infolock.Lock()
if p.dialAddr != nil {
kind = "outbound"
}
p.infolock.Unlock()
return fmt.Sprintf("Peer(%p %v %s)", p, p.conn.RemoteAddr(), kind)
} }
const ( func newPeer(conn net.Conn, protocols []Protocol, ourName string, ourID, remoteID *discover.NodeID) *Peer {
// maximum amount of time allowed for reading a message logtag := fmt.Sprintf("Peer %.8x %v", remoteID[:], conn.RemoteAddr())
msgReadTimeout = 5 * time.Second return &Peer{
// maximum amount of time allowed for writing a message Logger: logger.NewLogger(logtag),
msgWriteTimeout = 5 * time.Second rw: newFrameRW(conn, msgWriteTimeout),
// messages smaller than this many bytes will be read at ourID: ourID,
// once before passing them to a protocol. ourName: ourName,
wholePayloadSize = 64 * 1024 remoteID: remoteID,
) protocols: protocols,
running: make(map[string]*proto),
disc: make(chan DiscReason),
protoErr: make(chan error),
closed: make(chan struct{}),
}
}
var ( func (p *Peer) setHandshakeInfo(name string, caps []Cap) {
inactivityTimeout = 2 * time.Second p.infoMu.Lock()
disconnectGracePeriod = 2 * time.Second p.name = name
) p.caps = caps
p.infoMu.Unlock()
}
func (p *Peer) loop() (reason DiscReason, err error) { func (p *Peer) run() DiscReason {
defer p.activity.Stop() var readErr = make(chan error, 1)
defer p.closeProtocols() defer p.closeProtocols()
defer close(p.closed) defer close(p.closed)
defer p.conn.Close()
// read loop
readMsg := make(chan Msg)
readErr := make(chan error)
readNext := make(chan bool, 1)
protoDone := make(chan struct{}, 1)
go p.readLoop(readMsg, readErr, readNext)
readNext <- true
if p.runBaseProtocol {
p.startBaseProtocol()
}
loop: go func() { readErr <- p.readLoop() }()
for {
select { if !p.noHandshake {
case msg := <-readMsg: if err := writeProtocolHandshake(p.rw, p.ourName, *p.ourID, p.protocols); err != nil {
// a new message has arrived. p.DebugDetailf("Protocol handshake error: %v\n", err)
var wait bool p.rw.Close()
if wait, err = p.dispatch(msg, protoDone); err != nil { return DiscProtocolError
p.Errorf("msg dispatch error: %v\n", err)
reason = discReasonForError(err)
break loop
}
if !wait {
// Msg has already been read completely, continue with next message.
readNext <- true
}
p.activity.Post(time.Now())
case <-protoDone:
// protocol has consumed the message payload,
// we can continue reading from the socket.
readNext <- true
case err := <-readErr:
// read failed. there is no need to run the
// polite disconnect sequence because the connection
// is probably dead anyway.
// TODO: handle write errors as well
return DiscNetworkError, err
case err = <-p.protoErr:
reason = discReasonForError(err)
break loop
case reason = <-p.disc:
break loop
} }
} }
// wait for read loop to return. // Wait for an error or disconnect.
close(readNext) var reason DiscReason
select {
case err := <-readErr:
// We rely on protocols to abort if there is a write error. It
// might be more robust to handle them here as well.
p.DebugDetailf("Read error: %v\n", err)
p.rw.Close()
return DiscNetworkError
case err := <-p.protoErr:
reason = discReasonForError(err)
case reason = <-p.disc:
}
p.politeDisconnect(reason)
// Wait for readLoop. It will end because conn is now closed.
<-readErr <-readErr
// tell the remote end to disconnect p.Debugf("Disconnected: %v\n", reason)
return reason
}
func (p *Peer) politeDisconnect(reason DiscReason) {
done := make(chan struct{}) done := make(chan struct{})
go func() { go func() {
p.conn.SetDeadline(time.Now().Add(disconnectGracePeriod)) EncodeMsg(p.rw, discMsg, uint(reason))
p.writeMsg(NewMsg(discMsg, reason), disconnectGracePeriod) // Wait for the other side to close the connection.
io.Copy(ioutil.Discard, p.conn) // Discard any data that they send until then.
io.Copy(ioutil.Discard, p.rw)
close(done) close(done)
}() }()
select { select {
case <-done: case <-done:
case <-time.After(disconnectGracePeriod): case <-time.After(disconnectGracePeriod):
} }
return reason, err p.rw.Close()
} }
func (p *Peer) readLoop(msgc chan<- Msg, errc chan<- error, unblock <-chan bool) { func (p *Peer) readLoop() error {
for _ = range unblock { if !p.noHandshake {
p.conn.SetReadDeadline(time.Now().Add(msgReadTimeout)) if err := readProtocolHandshake(p, p.rw); err != nil {
if msg, err := readMsg(p.bufconn); err != nil { return err
errc <- err
} else {
msgc <- msg
} }
} }
close(errc) for {
} msg, err := p.rw.ReadMsg()
if err != nil {
func (p *Peer) dispatch(msg Msg, protoDone chan struct{}) (wait bool, err error) { return err
proto, err := p.getProto(msg.Code) }
if err != nil { if err = p.handle(msg); err != nil {
return false, err return err
}
} }
if msg.Size <= wholePayloadSize { return nil
// optimization: msg is small enough, read all }
// of it and move on to the next message
buf, err := ioutil.ReadAll(msg.Payload) func (p *Peer) handle(msg Msg) error {
switch {
case msg.Code == pingMsg:
msg.Discard()
go EncodeMsg(p.rw, pongMsg)
case msg.Code == discMsg:
var reason DiscReason
// no need to discard or for error checking, we'll close the
// connection after this.
rlp.Decode(msg.Payload, &reason)
p.Disconnect(DiscRequested)
return discRequestedError(reason)
case msg.Code < baseProtocolLength:
// ignore other base protocol messages
return msg.Discard()
default:
// it's a subprotocol message
proto, err := p.getProto(msg.Code)
if err != nil { if err != nil {
return false, err return fmt.Errorf("msg code out of range: %v", msg.Code)
} }
msg.Payload = bytes.NewReader(buf)
proto.in <- msg
} else {
wait = true
pr := &eofSignal{msg.Payload, int64(msg.Size), protoDone}
msg.Payload = pr
proto.in <- msg proto.in <- msg
} }
return wait, nil return nil
} }
func (p *Peer) startBaseProtocol() { func readProtocolHandshake(p *Peer, rw MsgReadWriter) error {
p.runlock.Lock() // read and handle remote handshake
defer p.runlock.Unlock() msg, err := rw.ReadMsg()
p.running[""] = p.startProto(0, Protocol{ if err != nil {
Length: baseProtocolLength, return err
Run: runBaseProtocol, }
}) if msg.Code == discMsg {
// disconnect before protocol handshake is valid according to the
// spec and we send it ourself if Server.addPeer fails.
var reason DiscReason
rlp.Decode(msg.Payload, &reason)
return discRequestedError(reason)
}
if msg.Code != handshakeMsg {
return newPeerError(errProtocolBreach, "expected handshake, got %x", msg.Code)
}
if msg.Size > baseProtocolMaxMsgSize {
return newPeerError(errInvalidMsg, "message too big")
}
var hs handshake
if err := msg.Decode(&hs); err != nil {
return err
}
// validate handshake info
if hs.Version != baseProtocolVersion {
return newPeerError(errP2PVersionMismatch, "required version %d, received %d\n",
baseProtocolVersion, hs.Version)
}
if hs.NodeID == *p.remoteID {
return newPeerError(errPubkeyForbidden, "node ID mismatch")
}
// TODO: remove Caps with empty name
p.setHandshakeInfo(hs.Name, hs.Caps)
p.startSubprotocols(hs.Caps)
return nil
}
func writeProtocolHandshake(w MsgWriter, name string, id discover.NodeID, ps []Protocol) error {
var caps []interface{}
for _, proto := range ps {
caps = append(caps, proto.cap())
}
return EncodeMsg(w, handshakeMsg, baseProtocolVersion, name, caps, 0, id)
} }
// startProtocols starts matching named subprotocols. // startProtocols starts matching named subprotocols.
func (p *Peer) startSubprotocols(caps []Cap) { func (p *Peer) startSubprotocols(caps []Cap) {
sort.Sort(capsByName(caps)) sort.Sort(capsByName(caps))
p.runlock.Lock() p.runlock.Lock()
defer p.runlock.Unlock() defer p.runlock.Unlock()
offset := baseProtocolLength offset := baseProtocolLength
@ -338,20 +315,22 @@ outer:
} }
func (p *Peer) startProto(offset uint64, impl Protocol) *proto { func (p *Peer) startProto(offset uint64, impl Protocol) *proto {
p.DebugDetailf("Starting protocol %s/%d\n", impl.Name, impl.Version)
rw := &proto{ rw := &proto{
name: impl.Name,
in: make(chan Msg), in: make(chan Msg),
offset: offset, offset: offset,
maxcode: impl.Length, maxcode: impl.Length,
peer: p, w: p.rw,
} }
p.protoWG.Add(1) p.protoWG.Add(1)
go func() { go func() {
err := impl.Run(p, rw) err := impl.Run(p, rw)
if err == nil { if err == nil {
p.Infof("protocol %q returned", impl.Name) p.DebugDetailf("Protocol %s/%d returned\n", impl.Name, impl.Version)
err = newPeerError(errMisc, "protocol returned") err = errors.New("protocol returned")
} else { } else {
p.Errorf("protocol %q error: %v\n", impl.Name, err) p.DebugDetailf("Protocol %s/%d error: %v\n", impl.Name, impl.Version, err)
} }
select { select {
case p.protoErr <- err: case p.protoErr <- err:
@ -385,6 +364,7 @@ func (p *Peer) closeProtocols() {
} }
// writeProtoMsg sends the given message on behalf of the given named protocol. // writeProtoMsg sends the given message on behalf of the given named protocol.
// this exists because of Server.Broadcast.
func (p *Peer) writeProtoMsg(protoName string, msg Msg) error { func (p *Peer) writeProtoMsg(protoName string, msg Msg) error {
p.runlock.RLock() p.runlock.RLock()
proto, ok := p.running[protoName] proto, ok := p.running[protoName]
@ -396,25 +376,14 @@ func (p *Peer) writeProtoMsg(protoName string, msg Msg) error {
return newPeerError(errInvalidMsgCode, "code %x is out of range for protocol %q", msg.Code, protoName) return newPeerError(errInvalidMsgCode, "code %x is out of range for protocol %q", msg.Code, protoName)
} }
msg.Code += proto.offset msg.Code += proto.offset
return p.writeMsg(msg, msgWriteTimeout) return p.rw.WriteMsg(msg)
}
// writeMsg writes a message to the connection.
func (p *Peer) writeMsg(msg Msg, timeout time.Duration) error {
p.writeMu.Lock()
defer p.writeMu.Unlock()
p.conn.SetWriteDeadline(time.Now().Add(timeout))
if err := writeMsg(p.bufconn, msg); err != nil {
return newPeerError(errWrite, "%v", err)
}
return p.bufconn.Flush()
} }
type proto struct { type proto struct {
name string name string
in chan Msg in chan Msg
maxcode, offset uint64 maxcode, offset uint64
peer *Peer w MsgWriter
} }
func (rw *proto) WriteMsg(msg Msg) error { func (rw *proto) WriteMsg(msg Msg) error {
@ -422,11 +391,7 @@ func (rw *proto) WriteMsg(msg Msg) error {
return newPeerError(errInvalidMsgCode, "not handled") return newPeerError(errInvalidMsgCode, "not handled")
} }
msg.Code += rw.offset msg.Code += rw.offset
return rw.peer.writeMsg(msg, msgWriteTimeout) return rw.w.WriteMsg(msg)
}
func (rw *proto) EncodeMsg(code uint64, data ...interface{}) error {
return rw.WriteMsg(NewMsg(code, data...))
} }
func (rw *proto) ReadMsg() (Msg, error) { func (rw *proto) ReadMsg() (Msg, error) {
@ -437,26 +402,3 @@ func (rw *proto) ReadMsg() (Msg, error) {
msg.Code -= rw.offset msg.Code -= rw.offset
return msg, nil return msg, nil
} }
// eofSignal wraps a reader with eof signaling. the eof channel is
// closed when the wrapped reader returns an error or when count bytes
// have been read.
//
type eofSignal struct {
wrapped io.Reader
count int64
eof chan<- struct{}
}
// note: when using eofSignal to detect whether a message payload
// has been read, Read might not be called for zero sized messages.
func (r *eofSignal) Read(buf []byte) (int, error) {
n, err := r.wrapped.Read(buf)
r.count -= int64(n)
if (err != nil || r.count <= 0) && r.eof != nil {
r.eof <- struct{}{} // tell Peer that msg has been consumed
r.eof = nil
}
return n, err
}

@ -12,7 +12,6 @@ const (
errInvalidMsgCode errInvalidMsgCode
errInvalidMsg errInvalidMsg
errP2PVersionMismatch errP2PVersionMismatch
errPubkeyMissing
errPubkeyInvalid errPubkeyInvalid
errPubkeyForbidden errPubkeyForbidden
errProtocolBreach errProtocolBreach
@ -22,20 +21,19 @@ const (
) )
var errorToString = map[int]string{ var errorToString = map[int]string{
errMagicTokenMismatch: "Magic token mismatch", errMagicTokenMismatch: "magic token mismatch",
errRead: "Read error", errRead: "read error",
errWrite: "Write error", errWrite: "write error",
errMisc: "Misc error", errMisc: "misc error",
errInvalidMsgCode: "Invalid message code", errInvalidMsgCode: "invalid message code",
errInvalidMsg: "Invalid message", errInvalidMsg: "invalid message",
errP2PVersionMismatch: "P2P Version Mismatch", errP2PVersionMismatch: "P2P Version Mismatch",
errPubkeyMissing: "Public key missing", errPubkeyInvalid: "public key invalid",
errPubkeyInvalid: "Public key invalid", errPubkeyForbidden: "public key forbidden",
errPubkeyForbidden: "Public key forbidden", errProtocolBreach: "protocol Breach",
errProtocolBreach: "Protocol Breach", errPingTimeout: "ping timeout",
errPingTimeout: "Ping timeout", errInvalidNetworkId: "invalid network id",
errInvalidNetworkId: "Invalid network id", errInvalidProtocolVersion: "invalid protocol version",
errInvalidProtocolVersion: "Invalid protocol version",
} }
type peerError struct { type peerError struct {
@ -62,22 +60,22 @@ func (self *peerError) Error() string {
type DiscReason byte type DiscReason byte
const ( const (
DiscRequested DiscReason = 0x00 DiscRequested DiscReason = iota
DiscNetworkError = 0x01 DiscNetworkError
DiscProtocolError = 0x02 DiscProtocolError
DiscUselessPeer = 0x03 DiscUselessPeer
DiscTooManyPeers = 0x04 DiscTooManyPeers
DiscAlreadyConnected = 0x05 DiscAlreadyConnected
DiscIncompatibleVersion = 0x06 DiscIncompatibleVersion
DiscInvalidIdentity = 0x07 DiscInvalidIdentity
DiscQuitting = 0x08 DiscQuitting
DiscUnexpectedIdentity = 0x09 DiscUnexpectedIdentity
DiscSelf = 0x0a DiscSelf
DiscReadTimeout = 0x0b DiscReadTimeout
DiscSubprotocolError = 0x10 DiscSubprotocolError
) )
var discReasonToString = [DiscSubprotocolError + 1]string{ var discReasonToString = [...]string{
DiscRequested: "Disconnect requested", DiscRequested: "Disconnect requested",
DiscNetworkError: "Network error", DiscNetworkError: "Network error",
DiscProtocolError: "Breach of protocol", DiscProtocolError: "Breach of protocol",
@ -117,7 +115,7 @@ func discReasonForError(err error) DiscReason {
switch peerError.Code { switch peerError.Code {
case errP2PVersionMismatch: case errP2PVersionMismatch:
return DiscIncompatibleVersion return DiscIncompatibleVersion
case errPubkeyMissing, errPubkeyInvalid: case errPubkeyInvalid:
return DiscInvalidIdentity return DiscInvalidIdentity
case errPubkeyForbidden: case errPubkeyForbidden:
return DiscUselessPeer return DiscUselessPeer
@ -125,7 +123,7 @@ func discReasonForError(err error) DiscReason {
return DiscProtocolError return DiscProtocolError
case errPingTimeout: case errPingTimeout:
return DiscReadTimeout return DiscReadTimeout
case errRead, errWrite, errMisc: case errRead, errWrite:
return DiscNetworkError return DiscNetworkError
default: default:
return DiscSubprotocolError return DiscSubprotocolError

@ -1,15 +1,17 @@
package p2p package p2p
import ( import (
"bufio"
"bytes" "bytes"
"encoding/hex" "fmt"
"io"
"io/ioutil" "io/ioutil"
"net" "net"
"reflect" "reflect"
"sort"
"testing" "testing"
"time" "time"
"github.com/ethereum/go-ethereum/p2p/discover"
"github.com/ethereum/go-ethereum/rlp"
) )
var discard = Protocol{ var discard = Protocol{
@ -28,17 +30,13 @@ var discard = Protocol{
}, },
} }
func testPeer(protos []Protocol) (net.Conn, *Peer, <-chan error) { func testPeer(noHandshake bool, protos []Protocol) (*frameRW, *Peer, <-chan DiscReason) {
conn1, conn2 := net.Pipe() conn1, conn2 := net.Pipe()
peer := newPeer(conn1, protos, nil) peer := newPeer(conn1, protos, "name", &discover.NodeID{}, &discover.NodeID{})
peer.ourID = &peerId{} peer.noHandshake = noHandshake
peer.pubkeyHook = func(*peerAddr) error { return nil } errc := make(chan DiscReason, 1)
errc := make(chan error, 1) go func() { errc <- peer.run() }()
go func() { return newFrameRW(conn2, msgWriteTimeout), peer, errc
_, err := peer.loop()
errc <- err
}()
return conn2, peer, errc
} }
func TestPeerProtoReadMsg(t *testing.T) { func TestPeerProtoReadMsg(t *testing.T) {
@ -49,31 +47,28 @@ func TestPeerProtoReadMsg(t *testing.T) {
Name: "a", Name: "a",
Length: 5, Length: 5,
Run: func(peer *Peer, rw MsgReadWriter) error { Run: func(peer *Peer, rw MsgReadWriter) error {
msg, err := rw.ReadMsg() if err := expectMsg(rw, 2, []uint{1}); err != nil {
if err != nil { t.Error(err)
t.Errorf("read error: %v", err)
} }
if msg.Code != 2 { if err := expectMsg(rw, 3, []uint{2}); err != nil {
t.Errorf("incorrect msg code %d relayed to protocol", msg.Code) t.Error(err)
}
data, err := ioutil.ReadAll(msg.Payload)
if err != nil {
t.Errorf("payload read error: %v", err)
} }
expdata, _ := hex.DecodeString("0183303030") if err := expectMsg(rw, 4, []uint{3}); err != nil {
if !bytes.Equal(expdata, data) { t.Error(err)
t.Errorf("incorrect msg data %x", data)
} }
close(done) close(done)
return nil return nil
}, },
} }
net, peer, errc := testPeer([]Protocol{proto}) rw, peer, errc := testPeer(true, []Protocol{proto})
defer net.Close() defer rw.Close()
peer.startSubprotocols([]Cap{proto.cap()}) peer.startSubprotocols([]Cap{proto.cap()})
writeMsg(net, NewMsg(18, 1, "000")) EncodeMsg(rw, baseProtocolLength+2, 1)
EncodeMsg(rw, baseProtocolLength+3, 2)
EncodeMsg(rw, baseProtocolLength+4, 3)
select { select {
case <-done: case <-done:
case err := <-errc: case err := <-errc:
@ -105,11 +100,11 @@ func TestPeerProtoReadLargeMsg(t *testing.T) {
}, },
} }
net, peer, errc := testPeer([]Protocol{proto}) rw, peer, errc := testPeer(true, []Protocol{proto})
defer net.Close() defer rw.Close()
peer.startSubprotocols([]Cap{proto.cap()}) peer.startSubprotocols([]Cap{proto.cap()})
writeMsg(net, NewMsg(18, make([]byte, msgsize))) EncodeMsg(rw, 18, make([]byte, msgsize))
select { select {
case <-done: case <-done:
case err := <-errc: case err := <-errc:
@ -135,32 +130,20 @@ func TestPeerProtoEncodeMsg(t *testing.T) {
return nil return nil
}, },
} }
net, peer, _ := testPeer([]Protocol{proto}) rw, peer, _ := testPeer(true, []Protocol{proto})
defer net.Close() defer rw.Close()
peer.startSubprotocols([]Cap{proto.cap()}) peer.startSubprotocols([]Cap{proto.cap()})
bufr := bufio.NewReader(net) if err := expectMsg(rw, 17, []string{"foo", "bar"}); err != nil {
msg, err := readMsg(bufr) t.Error(err)
if err != nil {
t.Errorf("read error: %v", err)
}
if msg.Code != 17 {
t.Errorf("incorrect message code: got %d, expected %d", msg.Code, 17)
}
var data []string
if err := msg.Decode(&data); err != nil {
t.Errorf("payload decode error: %v", err)
}
if !reflect.DeepEqual(data, []string{"foo", "bar"}) {
t.Errorf("payload RLP mismatch, got %#v, want %#v", data, []string{"foo", "bar"})
} }
} }
func TestPeerWrite(t *testing.T) { func TestPeerWriteForBroadcast(t *testing.T) {
defer testlog(t).detach() defer testlog(t).detach()
net, peer, peerErr := testPeer([]Protocol{discard}) rw, peer, peerErr := testPeer(true, []Protocol{discard})
defer net.Close() defer rw.Close()
peer.startSubprotocols([]Cap{discard.cap()}) peer.startSubprotocols([]Cap{discard.cap()})
// test write errors // test write errors
@ -176,18 +159,13 @@ func TestPeerWrite(t *testing.T) {
// setup for reading the message on the other end // setup for reading the message on the other end
read := make(chan struct{}) read := make(chan struct{})
go func() { go func() {
bufr := bufio.NewReader(net) if err := expectMsg(rw, 16, nil); err != nil {
msg, err := readMsg(bufr) t.Error()
if err != nil {
t.Errorf("read error: %v", err)
} else if msg.Code != 16 {
t.Errorf("wrong code, got %d, expected %d", msg.Code, 16)
} }
msg.Discard()
close(read) close(read)
}() }()
// test succcessful write // test successful write
if err := peer.writeProtoMsg("discard", NewMsg(0)); err != nil { if err := peer.writeProtoMsg("discard", NewMsg(0)); err != nil {
t.Errorf("expect no error for known protocol: %v", err) t.Errorf("expect no error for known protocol: %v", err)
} }
@ -198,104 +176,153 @@ func TestPeerWrite(t *testing.T) {
} }
} }
func TestPeerActivity(t *testing.T) { func TestPeerPing(t *testing.T) {
// shorten inactivityTimeout while this test is running defer testlog(t).detach()
oldT := inactivityTimeout
defer func() { inactivityTimeout = oldT }()
inactivityTimeout = 20 * time.Millisecond
net, peer, peerErr := testPeer([]Protocol{discard}) rw, _, _ := testPeer(true, nil)
defer net.Close() defer rw.Close()
peer.startSubprotocols([]Cap{discard.cap()}) if err := EncodeMsg(rw, pingMsg); err != nil {
t.Fatal(err)
}
if err := expectMsg(rw, pongMsg, nil); err != nil {
t.Error(err)
}
}
sub := peer.activity.Subscribe(time.Time{}) func TestPeerDisconnect(t *testing.T) {
defer sub.Unsubscribe() defer testlog(t).detach()
for i := 0; i < 6; i++ { rw, _, disc := testPeer(true, nil)
writeMsg(net, NewMsg(16)) defer rw.Close()
select { if err := EncodeMsg(rw, discMsg, DiscQuitting); err != nil {
case <-sub.Chan(): t.Fatal(err)
case <-time.After(inactivityTimeout / 2):
t.Fatal("no event within ", inactivityTimeout/2)
case err := <-peerErr:
t.Fatal("peer error", err)
}
} }
if err := expectMsg(rw, discMsg, []interface{}{DiscRequested}); err != nil {
select { t.Error(err)
case <-time.After(inactivityTimeout * 2): }
case <-sub.Chan(): rw.Close() // make test end faster
t.Fatal("got activity event while connection was inactive") if reason := <-disc; reason != DiscRequested {
case err := <-peerErr: t.Errorf("run returned wrong reason: got %v, want %v", reason, DiscRequested)
t.Fatal("peer error", err)
} }
} }
func TestNewPeer(t *testing.T) { func TestPeerHandshake(t *testing.T) {
caps := []Cap{{"foo", 2}, {"bar", 3}} defer testlog(t).detach()
id := &peerId{}
p := NewPeer(id, caps) // remote has two matching protocols: a and c
if !reflect.DeepEqual(p.Caps(), caps) { remote := NewPeer(randomID(), "", []Cap{{"a", 1}, {"b", 999}, {"c", 3}})
t.Errorf("Caps mismatch: got %v, expected %v", p.Caps(), caps) remoteID := randomID()
remote.ourID = &remoteID
remote.ourName = "remote peer"
start := make(chan string)
stop := make(chan struct{})
run := func(p *Peer, rw MsgReadWriter) error {
name := rw.(*proto).name
if name != "a" && name != "c" {
t.Errorf("protocol %q should not be started", name)
} else {
start <- name
}
<-stop
return nil
} }
if p.Identity() != id { protocols := []Protocol{
t.Errorf("Identity mismatch: got %v, expected %v", p.Identity(), id) {Name: "a", Version: 1, Length: 1, Run: run},
{Name: "b", Version: 2, Length: 1, Run: run},
{Name: "c", Version: 3, Length: 1, Run: run},
{Name: "d", Version: 4, Length: 1, Run: run},
} }
// Should not hang. rw, p, disc := testPeer(false, protocols)
p.Disconnect(DiscAlreadyConnected) p.remoteID = remote.ourID
} defer rw.Close()
func TestEOFSignal(t *testing.T) { // run the handshake
rb := make([]byte, 10) remoteProtocols := []Protocol{protocols[0], protocols[2]}
if err := writeProtocolHandshake(rw, "remote peer", remoteID, remoteProtocols); err != nil {
t.Fatalf("handshake write error: %v", err)
}
if err := readProtocolHandshake(remote, rw); err != nil {
t.Fatalf("handshake read error: %v", err)
}
// empty reader // check that all protocols have been started
eof := make(chan struct{}, 1) var started []string
sig := &eofSignal{new(bytes.Buffer), 0, eof} for i := 0; i < 2; i++ {
if n, err := sig.Read(rb); n != 0 || err != io.EOF { select {
t.Errorf("Read returned unexpected values: (%v, %v)", n, err) case name := <-start:
started = append(started, name)
case <-time.After(100 * time.Millisecond):
}
} }
select { sort.Strings(started)
case <-eof: if !reflect.DeepEqual(started, []string{"a", "c"}) {
default: t.Errorf("wrong protocols started: %v", started)
t.Error("EOF chan not signaled")
} }
// count before error // check that metadata has been set
eof = make(chan struct{}, 1) if p.ID() != remoteID {
sig = &eofSignal{bytes.NewBufferString("aaaaaaaa"), 4, eof} t.Errorf("peer has wrong node ID: got %v, want %v", p.ID(), remoteID)
if n, err := sig.Read(rb); n != 8 || err != nil {
t.Errorf("Read returned unexpected values: (%v, %v)", n, err)
} }
select { if p.Name() != remote.ourName {
case <-eof: t.Errorf("peer has wrong node name: got %q, want %q", p.Name(), remote.ourName)
default:
t.Error("EOF chan not signaled")
} }
// error before count close(stop)
eof = make(chan struct{}, 1) expectMsg(rw, discMsg, nil)
sig = &eofSignal{bytes.NewBufferString("aaaa"), 999, eof} t.Logf("disc reason: %v", <-disc)
if n, err := sig.Read(rb); n != 4 || err != nil { }
t.Errorf("Read returned unexpected values: (%v, %v)", n, err)
func TestNewPeer(t *testing.T) {
name := "nodename"
caps := []Cap{{"foo", 2}, {"bar", 3}}
id := randomID()
p := NewPeer(id, name, caps)
if p.ID() != id {
t.Errorf("ID mismatch: got %v, expected %v", p.ID(), id)
} }
if n, err := sig.Read(rb); n != 0 || err != io.EOF { if p.Name() != name {
t.Errorf("Read returned unexpected values: (%v, %v)", n, err) t.Errorf("Name mismatch: got %v, expected %v", p.Name(), name)
} }
select { if !reflect.DeepEqual(p.Caps(), caps) {
case <-eof: t.Errorf("Caps mismatch: got %v, expected %v", p.Caps(), caps)
default:
t.Error("EOF chan not signaled")
} }
// no signal if neither occurs p.Disconnect(DiscAlreadyConnected) // Should not hang
eof = make(chan struct{}, 1) }
sig = &eofSignal{bytes.NewBufferString("aaaaaaaaaaaaaaaaaaaaa"), 999, eof}
if n, err := sig.Read(rb); n != 10 || err != nil { // expectMsg reads a message from r and verifies that its
t.Errorf("Read returned unexpected values: (%v, %v)", n, err) // code and encoded RLP content match the provided values.
// If content is nil, the payload is discarded and not verified.
func expectMsg(r MsgReader, code uint64, content interface{}) error {
msg, err := r.ReadMsg()
if err != nil {
return err
} }
select { if msg.Code != code {
case <-eof: return fmt.Errorf("message code mismatch: got %d, expected %d", msg.Code, code)
t.Error("unexpected EOF signal") }
default: if content == nil {
return msg.Discard()
} else {
contentEnc, err := rlp.EncodeToBytes(content)
if err != nil {
panic("content encode error: " + err.Error())
}
// skip over list header in encoded value. this is temporary.
contentEncR := bytes.NewReader(contentEnc)
if k, _, err := rlp.NewStream(contentEncR).Kind(); k != rlp.List || err != nil {
panic("content must encode as RLP list")
}
contentEnc = contentEnc[len(contentEnc)-contentEncR.Len():]
actualContent, err := ioutil.ReadAll(msg.Payload)
if err != nil {
return err
}
if !bytes.Equal(actualContent, contentEnc) {
return fmt.Errorf("message payload mismatch:\ngot: %x\nwant: %x", actualContent, contentEnc)
}
} }
return nil
} }

@ -1,10 +1,5 @@
package p2p package p2p
import (
"bytes"
"time"
)
// Protocol represents a P2P subprotocol implementation. // Protocol represents a P2P subprotocol implementation.
type Protocol struct { type Protocol struct {
// Name should contain the official protocol name, // Name should contain the official protocol name,
@ -32,38 +27,6 @@ func (p Protocol) cap() Cap {
return Cap{p.Name, p.Version} return Cap{p.Name, p.Version}
} }
const (
baseProtocolVersion = 2
baseProtocolLength = uint64(16)
baseProtocolMaxMsgSize = 10 * 1024 * 1024
)
const (
// devp2p message codes
handshakeMsg = 0x00
discMsg = 0x01
pingMsg = 0x02
pongMsg = 0x03
getPeersMsg = 0x04
peersMsg = 0x05
)
// handshake is the structure of a handshake list.
type handshake struct {
Version uint64
ID string
Caps []Cap
ListenPort uint64
NodeID []byte
}
func (h *handshake) String() string {
return h.ID
}
func (h *handshake) Pubkey() []byte {
return h.NodeID
}
// Cap is the structure of a peer capability. // Cap is the structure of a peer capability.
type Cap struct { type Cap struct {
Name string Name string
@ -79,210 +42,3 @@ type capsByName []Cap
func (cs capsByName) Len() int { return len(cs) } func (cs capsByName) Len() int { return len(cs) }
func (cs capsByName) Less(i, j int) bool { return cs[i].Name < cs[j].Name } func (cs capsByName) Less(i, j int) bool { return cs[i].Name < cs[j].Name }
func (cs capsByName) Swap(i, j int) { cs[i], cs[j] = cs[j], cs[i] } func (cs capsByName) Swap(i, j int) { cs[i], cs[j] = cs[j], cs[i] }
type baseProtocol struct {
rw MsgReadWriter
peer *Peer
}
func runBaseProtocol(peer *Peer, rw MsgReadWriter) error {
bp := &baseProtocol{rw, peer}
errc := make(chan error, 1)
go func() { errc <- rw.WriteMsg(bp.handshakeMsg()) }()
if err := bp.readHandshake(); err != nil {
return err
}
// handle write error
if err := <-errc; err != nil {
return err
}
// run main loop
go func() {
for {
if err := bp.handle(rw); err != nil {
errc <- err
break
}
}
}()
return bp.loop(errc)
}
var pingTimeout = 2 * time.Second
func (bp *baseProtocol) loop(quit <-chan error) error {
ping := time.NewTimer(pingTimeout)
activity := bp.peer.activity.Subscribe(time.Time{})
lastActive := time.Time{}
defer ping.Stop()
defer activity.Unsubscribe()
getPeersTick := time.NewTicker(10 * time.Second)
defer getPeersTick.Stop()
err := EncodeMsg(bp.rw, getPeersMsg)
for err == nil {
select {
case err = <-quit:
return err
case <-getPeersTick.C:
err = EncodeMsg(bp.rw, getPeersMsg)
case event := <-activity.Chan():
ping.Reset(pingTimeout)
lastActive = event.(time.Time)
case t := <-ping.C:
if lastActive.Add(pingTimeout * 2).Before(t) {
err = newPeerError(errPingTimeout, "")
} else if lastActive.Add(pingTimeout).Before(t) {
err = EncodeMsg(bp.rw, pingMsg)
}
}
}
return err
}
func (bp *baseProtocol) handle(rw MsgReadWriter) error {
msg, err := rw.ReadMsg()
if err != nil {
return err
}
if msg.Size > baseProtocolMaxMsgSize {
return newPeerError(errMisc, "message too big")
}
// make sure that the payload has been fully consumed
defer msg.Discard()
switch msg.Code {
case handshakeMsg:
return newPeerError(errProtocolBreach, "extra handshake received")
case discMsg:
var reason [1]DiscReason
if err := msg.Decode(&reason); err != nil {
return err
}
return discRequestedError(reason[0])
case pingMsg:
return EncodeMsg(bp.rw, pongMsg)
case pongMsg:
case getPeersMsg:
peers := bp.peerList()
// this is dangerous. the spec says that we should _delay_
// sending the response if no new information is available.
// this means that would need to send a response later when
// new peers become available.
//
// TODO: add event mechanism to notify baseProtocol for new peers
if len(peers) > 0 {
return EncodeMsg(bp.rw, peersMsg, peers...)
}
case peersMsg:
var peers []*peerAddr
if err := msg.Decode(&peers); err != nil {
return err
}
for _, addr := range peers {
bp.peer.Debugf("received peer suggestion: %v", addr)
bp.peer.newPeerAddr <- addr
}
default:
return newPeerError(errInvalidMsgCode, "unknown message code %v", msg.Code)
}
return nil
}
func (bp *baseProtocol) readHandshake() error {
// read and handle remote handshake
msg, err := bp.rw.ReadMsg()
if err != nil {
return err
}
if msg.Code != handshakeMsg {
return newPeerError(errProtocolBreach, "first message must be handshake, got %x", msg.Code)
}
if msg.Size > baseProtocolMaxMsgSize {
return newPeerError(errMisc, "message too big")
}
var hs handshake
if err := msg.Decode(&hs); err != nil {
return err
}
// validate handshake info
if hs.Version != baseProtocolVersion {
return newPeerError(errP2PVersionMismatch, "Require protocol %d, received %d\n",
baseProtocolVersion, hs.Version)
}
if len(hs.NodeID) == 0 {
return newPeerError(errPubkeyMissing, "")
}
if len(hs.NodeID) != 64 {
return newPeerError(errPubkeyInvalid, "require 512 bit, got %v", len(hs.NodeID)*8)
}
if da := bp.peer.dialAddr; da != nil {
// verify that the peer we wanted to connect to
// actually holds the target public key.
if da.Pubkey != nil && !bytes.Equal(da.Pubkey, hs.NodeID) {
return newPeerError(errPubkeyForbidden, "dial address pubkey mismatch")
}
}
pa := newPeerAddr(bp.peer.conn.RemoteAddr(), hs.NodeID)
if err := bp.peer.pubkeyHook(pa); err != nil {
return newPeerError(errPubkeyForbidden, "%v", err)
}
// TODO: remove Caps with empty name
var addr *peerAddr
if hs.ListenPort != 0 {
addr = newPeerAddr(bp.peer.conn.RemoteAddr(), hs.NodeID)
addr.Port = hs.ListenPort
}
bp.peer.setHandshakeInfo(&hs, addr, hs.Caps)
bp.peer.startSubprotocols(hs.Caps)
return nil
}
func (bp *baseProtocol) handshakeMsg() Msg {
var (
port uint64
caps []interface{}
)
if bp.peer.ourListenAddr != nil {
port = bp.peer.ourListenAddr.Port
}
for _, proto := range bp.peer.protocols {
caps = append(caps, proto.cap())
}
return NewMsg(handshakeMsg,
baseProtocolVersion,
bp.peer.ourID.String(),
caps,
port,
bp.peer.ourID.Pubkey()[1:],
)
}
func (bp *baseProtocol) peerList() []interface{} {
peers := bp.peer.otherPeers()
ds := make([]interface{}, 0, len(peers))
for _, p := range peers {
p.infolock.Lock()
addr := p.listenAddr
p.infolock.Unlock()
// filter out this peer and peers that are not listening or
// have not completed the handshake.
// TODO: track previously sent peers and exclude them as well.
if p == bp.peer || addr == nil {
continue
}
ds = append(ds, addr)
}
ourAddr := bp.peer.ourListenAddr
if ourAddr != nil && !ourAddr.IP.IsLoopback() && !ourAddr.IP.IsUnspecified() {
ds = append(ds, ourAddr)
}
return ds
}

@ -1,158 +0,0 @@
package p2p
import (
"fmt"
"net"
"reflect"
"sync"
"testing"
"github.com/ethereum/go-ethereum/crypto"
)
type peerId struct {
pubkey []byte
}
func (self *peerId) String() string {
return fmt.Sprintf("test peer %x", self.Pubkey()[:4])
}
func (self *peerId) Pubkey() (pubkey []byte) {
pubkey = self.pubkey
if len(pubkey) == 0 {
pubkey = crypto.GenerateNewKeyPair().PublicKey
self.pubkey = pubkey
}
return
}
func newTestPeer() (peer *Peer) {
peer = NewPeer(&peerId{}, []Cap{})
peer.pubkeyHook = func(*peerAddr) error { return nil }
peer.ourID = &peerId{}
peer.listenAddr = &peerAddr{}
peer.otherPeers = func() []*Peer { return nil }
return
}
func TestBaseProtocolPeers(t *testing.T) {
peerList := []*peerAddr{
{IP: net.ParseIP("1.2.3.4"), Port: 2222, Pubkey: []byte{}},
{IP: net.ParseIP("5.6.7.8"), Port: 3333, Pubkey: []byte{}},
}
listenAddr := &peerAddr{IP: net.ParseIP("1.3.5.7"), Port: 1111, Pubkey: []byte{}}
rw1, rw2 := MsgPipe()
defer rw1.Close()
wg := new(sync.WaitGroup)
// run matcher, close pipe when addresses have arrived
numPeers := len(peerList) + 1
addrChan := make(chan *peerAddr)
wg.Add(1)
go func() {
i := 0
for got := range addrChan {
var want *peerAddr
switch {
case i < len(peerList):
want = peerList[i]
case i == len(peerList):
want = listenAddr // listenAddr should be the last thing sent
}
t.Logf("got peer %d/%d: %v", i+1, numPeers, got)
if !reflect.DeepEqual(want, got) {
t.Errorf("mismatch: got %+v, want %+v", got, want)
}
i++
if i == numPeers {
break
}
}
if i != numPeers {
t.Errorf("wrong number of peers received: got %d, want %d", i, numPeers)
}
rw1.Close()
wg.Done()
}()
// run first peer (in background)
peer1 := newTestPeer()
peer1.ourListenAddr = listenAddr
peer1.otherPeers = func() []*Peer {
pl := make([]*Peer, len(peerList))
for i, addr := range peerList {
pl[i] = &Peer{listenAddr: addr}
}
return pl
}
wg.Add(1)
go func() {
runBaseProtocol(peer1, rw1)
wg.Done()
}()
// run second peer
peer2 := newTestPeer()
peer2.newPeerAddr = addrChan // feed peer suggestions into matcher
if err := runBaseProtocol(peer2, rw2); err != ErrPipeClosed {
t.Errorf("peer2 terminated with unexpected error: %v", err)
}
// terminate matcher
close(addrChan)
wg.Wait()
}
func TestBaseProtocolDisconnect(t *testing.T) {
peer := NewPeer(&peerId{}, nil)
peer.ourID = &peerId{}
peer.pubkeyHook = func(*peerAddr) error { return nil }
rw1, rw2 := MsgPipe()
done := make(chan struct{})
go func() {
if err := expectMsg(rw2, handshakeMsg); err != nil {
t.Error(err)
}
err := EncodeMsg(rw2, handshakeMsg,
baseProtocolVersion,
"",
[]interface{}{},
0,
make([]byte, 64),
)
if err != nil {
t.Error(err)
}
if err := expectMsg(rw2, getPeersMsg); err != nil {
t.Error(err)
}
if err := EncodeMsg(rw2, discMsg, DiscQuitting); err != nil {
t.Error(err)
}
close(done)
}()
if err := runBaseProtocol(peer, rw1); err == nil {
t.Errorf("base protocol returned without error")
} else if reason, ok := err.(discRequestedError); !ok || reason != DiscQuitting {
t.Errorf("base protocol returned wrong error: %v", err)
}
<-done
}
func expectMsg(r MsgReader, code uint64) error {
msg, err := r.ReadMsg()
if err != nil {
return err
}
if err := msg.Discard(); err != nil {
return err
}
if msg.Code != code {
return fmt.Errorf("wrong message code: got %d, expected %d", msg.Code, code)
}
return nil
}

@ -2,37 +2,56 @@ package p2p
import ( import (
"bytes" "bytes"
"crypto/ecdsa"
"errors" "errors"
"fmt" "fmt"
"io"
"net" "net"
"runtime"
"sync" "sync"
"time" "time"
"github.com/ethereum/go-ethereum/logger" "github.com/ethereum/go-ethereum/logger"
"github.com/ethereum/go-ethereum/p2p/discover"
"github.com/ethereum/go-ethereum/p2p/nat"
) )
const ( const (
outboundAddressPoolSize = 500 handshakeTimeout = 5 * time.Second
defaultDialTimeout = 10 * time.Second defaultDialTimeout = 10 * time.Second
portMappingUpdateInterval = 15 * time.Minute refreshPeersInterval = 30 * time.Second
portMappingTimeout = 20 * time.Minute
) )
var srvlog = logger.NewLogger("P2P Server") var srvlog = logger.NewLogger("P2P Server")
// MakeName creates a node name that follows the ethereum convention
// for such names. It adds the operation system name and Go runtime version
// the name.
func MakeName(name, version string) string {
return fmt.Sprintf("%s/v%s/%s/%s", name, version, runtime.GOOS, runtime.Version())
}
// Server manages all peer connections. // Server manages all peer connections.
// //
// The fields of Server are used as configuration parameters. // The fields of Server are used as configuration parameters.
// You should set them before starting the Server. Fields may not be // You should set them before starting the Server. Fields may not be
// modified while the server is running. // modified while the server is running.
type Server struct { type Server struct {
// This field must be set to a valid client identity. // This field must be set to a valid secp256k1 private key.
Identity ClientIdentity PrivateKey *ecdsa.PrivateKey
// MaxPeers is the maximum number of peers that can be // MaxPeers is the maximum number of peers that can be
// connected. It must be greater than zero. // connected. It must be greater than zero.
MaxPeers int MaxPeers int
// Name sets the node name of this server.
// Use MakeName to create a name that follows existing conventions.
Name string
// Bootstrap nodes are used to establish connectivity
// with the rest of the network.
BootstrapNodes []*discover.Node
// Protocols should contain the protocols supported // Protocols should contain the protocols supported
// by the server. Matching protocols are launched for // by the server. Matching protocols are launched for
// each peer. // each peer.
@ -53,7 +72,7 @@ type Server struct {
// If set to a non-nil value, the given NAT port mapper // If set to a non-nil value, the given NAT port mapper
// is used to make the listening port available to the // is used to make the listening port available to the
// Internet. // Internet.
NAT NAT NAT nat.Interface
// If Dialer is set to a non-nil value, the given Dialer // If Dialer is set to a non-nil value, the given Dialer
// is used to dial outbound peer connections. // is used to dial outbound peer connections.
@ -62,35 +81,26 @@ type Server struct {
// If NoDial is true, the server will not dial any peers. // If NoDial is true, the server will not dial any peers.
NoDial bool NoDial bool
// Hook for testing. This is useful because we can inhibit // Hooks for testing. These are useful because we can inhibit
// the whole protocol stack. // the whole protocol stack.
newPeerFunc peerFunc handshakeFunc
newPeerHook
lock sync.RWMutex lock sync.RWMutex
running bool running bool
listener net.Listener listener net.Listener
laddr *net.TCPAddr // real listen addr peers map[discover.NodeID]*Peer
peers []*Peer
peerSlots chan int
peerCount int
quit chan struct{}
wg sync.WaitGroup
peerConnect chan *peerAddr
peerDisconnect chan *Peer
}
// NAT is implemented by NAT traversal methods. ntab *discover.Table
type NAT interface {
GetExternalAddress() (net.IP, error)
AddPortMapping(protocol string, extport, intport int, name string, lifetime time.Duration) error
DeletePortMapping(protocol string, extport, intport int) error
// Should return name of the method. quit chan struct{}
String() string loopWG sync.WaitGroup // {dial,listen,nat}Loop
peerWG sync.WaitGroup // active peer goroutines
peerConnect chan *discover.Node
} }
type peerFunc func(srv *Server, c net.Conn, dialAddr *peerAddr) *Peer type handshakeFunc func(io.ReadWriter, *ecdsa.PrivateKey, *discover.Node) (discover.NodeID, []byte, error)
type newPeerHook func(*Peer)
// Peers returns all connected peers. // Peers returns all connected peers.
func (srv *Server) Peers() (peers []*Peer) { func (srv *Server) Peers() (peers []*Peer) {
@ -107,18 +117,15 @@ func (srv *Server) Peers() (peers []*Peer) {
// PeerCount returns the number of connected peers. // PeerCount returns the number of connected peers.
func (srv *Server) PeerCount() int { func (srv *Server) PeerCount() int {
srv.lock.RLock() srv.lock.RLock()
defer srv.lock.RUnlock() n := len(srv.peers)
return srv.peerCount srv.lock.RUnlock()
return n
} }
// SuggestPeer injects an address into the outbound address pool. // SuggestPeer creates a connection to the given Node if it
func (srv *Server) SuggestPeer(ip net.IP, port int, nodeID []byte) { // is not already connected.
addr := &peerAddr{ip, uint64(port), nodeID} func (srv *Server) SuggestPeer(n *discover.Node) {
select { srv.peerConnect <- n
case srv.peerConnect <- addr:
default: // don't block
srvlog.Warnf("peer suggestion %v ignored", addr)
}
} }
// Broadcast sends an RLP-encoded message to all connected peers. // Broadcast sends an RLP-encoded message to all connected peers.
@ -152,47 +159,46 @@ func (srv *Server) Start() (err error) {
} }
srvlog.Infoln("Starting Server") srvlog.Infoln("Starting Server")
// initialize fields // initialize all the fields
if srv.Identity == nil { if srv.PrivateKey == nil {
return fmt.Errorf("Server.Identity must be set to a non-nil identity") return fmt.Errorf("Server.PrivateKey must be set to a non-nil key")
} }
if srv.MaxPeers <= 0 { if srv.MaxPeers <= 0 {
return fmt.Errorf("Server.MaxPeers must be > 0") return fmt.Errorf("Server.MaxPeers must be > 0")
} }
srv.quit = make(chan struct{}) srv.quit = make(chan struct{})
srv.peers = make([]*Peer, srv.MaxPeers) srv.peers = make(map[discover.NodeID]*Peer)
srv.peerSlots = make(chan int, srv.MaxPeers) srv.peerConnect = make(chan *discover.Node)
srv.peerConnect = make(chan *peerAddr, outboundAddressPoolSize)
srv.peerDisconnect = make(chan *Peer) if srv.handshakeFunc == nil {
if srv.newPeerFunc == nil { srv.handshakeFunc = encHandshake
srv.newPeerFunc = newServerPeer
} }
if srv.Blacklist == nil { if srv.Blacklist == nil {
srv.Blacklist = NewBlacklist() srv.Blacklist = NewBlacklist()
} }
if srv.Dialer == nil {
srv.Dialer = &net.Dialer{Timeout: defaultDialTimeout}
}
if srv.ListenAddr != "" { if srv.ListenAddr != "" {
if err := srv.startListening(); err != nil { if err := srv.startListening(); err != nil {
return err return err
} }
} }
// dial stuff
dt, err := discover.ListenUDP(srv.PrivateKey, srv.ListenAddr, srv.NAT)
if err != nil {
return err
}
srv.ntab = dt
if srv.Dialer == nil {
srv.Dialer = &net.Dialer{Timeout: defaultDialTimeout}
}
if !srv.NoDial { if !srv.NoDial {
srv.wg.Add(1) srv.loopWG.Add(1)
go srv.dialLoop() go srv.dialLoop()
} }
if srv.NoDial && srv.ListenAddr == "" { if srv.NoDial && srv.ListenAddr == "" {
srvlog.Warnln("I will be kind-of useless, neither dialing nor listening.") srvlog.Warnln("I will be kind-of useless, neither dialing nor listening.")
} }
// make all slots available
for i := range srv.peers {
srv.peerSlots <- i
}
// note: discLoop is not part of WaitGroup
go srv.discLoop()
srv.running = true srv.running = true
return nil return nil
} }
@ -202,14 +208,17 @@ func (srv *Server) startListening() error {
if err != nil { if err != nil {
return err return err
} }
srv.ListenAddr = listener.Addr().String() laddr := listener.Addr().(*net.TCPAddr)
srv.laddr = listener.Addr().(*net.TCPAddr) srv.ListenAddr = laddr.String()
srv.listener = listener srv.listener = listener
srv.wg.Add(1) srv.loopWG.Add(1)
go srv.listenLoop() go srv.listenLoop()
if !srv.laddr.IP.IsLoopback() && srv.NAT != nil { if !laddr.IP.IsLoopback() && srv.NAT != nil {
srv.wg.Add(1) srv.loopWG.Add(1)
go srv.natLoop(srv.laddr.Port) go func() {
nat.Map(srv.NAT, srv.quit, "tcp", laddr.Port, laddr.Port, "ethereum p2p")
srv.loopWG.Done()
}()
} }
return nil return nil
} }
@ -225,200 +234,171 @@ func (srv *Server) Stop() {
srv.running = false srv.running = false
srv.lock.Unlock() srv.lock.Unlock()
srvlog.Infoln("Stopping server") srvlog.Infoln("Stopping Server")
srv.ntab.Close()
if srv.listener != nil { if srv.listener != nil {
// this unblocks listener Accept // this unblocks listener Accept
srv.listener.Close() srv.listener.Close()
} }
close(srv.quit) close(srv.quit)
for _, peer := range srv.Peers() { srv.loopWG.Wait()
peer.Disconnect(DiscQuitting)
}
srv.wg.Wait()
// wait till they actually disconnect
// this is checked by claiming all peerSlots.
// slots become available as the peers disconnect.
for i := 0; i < cap(srv.peerSlots); i++ {
<-srv.peerSlots
}
// terminate discLoop
close(srv.peerDisconnect)
}
func (srv *Server) discLoop() { // No new peers can be added at this point because dialLoop and
for peer := range srv.peerDisconnect { // listenLoop are down. It is safe to call peerWG.Wait because
srv.removePeer(peer) // peerWG.Add is not called outside of those loops.
for _, peer := range srv.peers {
peer.Disconnect(DiscQuitting)
} }
srv.peerWG.Wait()
} }
// main loop for adding connections via listening // main loop for adding connections via listening
func (srv *Server) listenLoop() { func (srv *Server) listenLoop() {
defer srv.wg.Done() defer srv.loopWG.Done()
srvlog.Infoln("Listening on", srv.listener.Addr()) srvlog.Infoln("Listening on", srv.listener.Addr())
for { for {
select { conn, err := srv.listener.Accept()
case slot := <-srv.peerSlots: if err != nil {
srvlog.Debugf("grabbed slot %v for listening", slot)
conn, err := srv.listener.Accept()
if err != nil {
srv.peerSlots <- slot
return
}
srvlog.Debugf("Accepted conn %v (slot %d)\n", conn.RemoteAddr(), slot)
srv.addPeer(conn, nil, slot)
case <-srv.quit:
return return
} }
srvlog.Debugf("Accepted conn %v\n", conn.RemoteAddr())
srv.peerWG.Add(1)
go srv.startPeer(conn, nil)
} }
} }
func (srv *Server) natLoop(port int) { func (srv *Server) dialLoop() {
defer srv.wg.Done() defer srv.loopWG.Done()
refresh := time.NewTicker(refreshPeersInterval)
defer refresh.Stop()
srv.ntab.Bootstrap(srv.BootstrapNodes)
go srv.findPeers()
dialed := make(chan *discover.Node)
dialing := make(map[discover.NodeID]bool)
// TODO: limit number of active dials
// TODO: ensure only one findPeers goroutine is running
// TODO: pause findPeers when we're at capacity
for { for {
srv.updatePortMapping(port)
select { select {
case <-time.After(portMappingUpdateInterval): case <-refresh.C:
// one more round
go srv.findPeers()
case dest := <-srv.peerConnect:
// avoid dialing nodes that are already connected.
// there is another check for this in addPeer,
// which runs after the handshake.
srv.lock.Lock()
_, isconnected := srv.peers[dest.ID]
srv.lock.Unlock()
if isconnected || dialing[dest.ID] || dest.ID == srv.ntab.Self() {
continue
}
dialing[dest.ID] = true
srv.peerWG.Add(1)
go func() {
srv.dialNode(dest)
// at this point, the peer has been added
// or discarded. either way, we're not dialing it anymore.
dialed <- dest
}()
case dest := <-dialed:
delete(dialing, dest.ID)
case <-srv.quit: case <-srv.quit:
srv.removePortMapping(port) // TODO: maybe wait for active dials
return return
} }
} }
} }
func (srv *Server) updatePortMapping(port int) { func (srv *Server) dialNode(dest *discover.Node) {
srvlog.Infoln("Attempting to map port", port, "with", srv.NAT) addr := &net.TCPAddr{IP: dest.IP, Port: dest.TCPPort}
err := srv.NAT.AddPortMapping("tcp", port, port, "ethereum p2p", portMappingTimeout) srvlog.Debugf("Dialing %v\n", dest)
conn, err := srv.Dialer.Dial("tcp", addr.String())
if err != nil { if err != nil {
srvlog.Errorln("Port mapping error:", err) srvlog.DebugDetailf("dial error: %v", err)
return
}
extip, err := srv.NAT.GetExternalAddress()
if err != nil {
srvlog.Errorln("Error getting external IP:", err)
return return
} }
srv.lock.Lock() srv.startPeer(conn, dest)
extaddr := *(srv.listener.Addr().(*net.TCPAddr))
extaddr.IP = extip
srvlog.Infoln("Mapped port, external addr is", &extaddr)
srv.laddr = &extaddr
srv.lock.Unlock()
} }
func (srv *Server) removePortMapping(port int) { func (srv *Server) findPeers() {
srvlog.Infoln("Removing port mapping for", port, "with", srv.NAT) far := srv.ntab.Self()
srv.NAT.DeletePortMapping("tcp", port, port) for i := range far {
} far[i] = ^far[i]
}
func (srv *Server) dialLoop() { closeToSelf := srv.ntab.Lookup(srv.ntab.Self())
defer srv.wg.Done() farFromSelf := srv.ntab.Lookup(far)
var (
suggest chan *peerAddr
slot *int
slots = srv.peerSlots
)
for {
select {
case i := <-slots:
// we need a peer in slot i, slot reserved
slot = &i
// now we can watch for candidate peers in the next loop
suggest = srv.peerConnect
// do not consume more until candidate peer is found
slots = nil
case desc := <-suggest:
// candidate peer found, will dial out asyncronously
// if connection fails slot will be released
srvlog.DebugDetailf("dial %v (%v)", desc, *slot)
go srv.dialPeer(desc, *slot)
// we can watch if more peers needed in the next loop
slots = srv.peerSlots
// until then we dont care about candidate peers
suggest = nil
case <-srv.quit: for i := 0; i < len(closeToSelf) || i < len(farFromSelf); i++ {
// give back the currently reserved slot if i < len(closeToSelf) {
if slot != nil { srv.peerConnect <- closeToSelf[i]
srv.peerSlots <- *slot }
} if i < len(farFromSelf) {
return srv.peerConnect <- farFromSelf[i]
} }
} }
} }
// connect to peer via dial out func (srv *Server) startPeer(conn net.Conn, dest *discover.Node) {
func (srv *Server) dialPeer(desc *peerAddr, slot int) { // TODO: handle/store session token
srvlog.Debugf("Dialing %v (slot %d)\n", desc, slot) conn.SetDeadline(time.Now().Add(handshakeTimeout))
conn, err := srv.Dialer.Dial(desc.Network(), desc.String()) remoteID, _, err := srv.handshakeFunc(conn, srv.PrivateKey, dest)
if err != nil { if err != nil {
srvlog.DebugDetailf("dial error: %v", err) conn.Close()
srv.peerSlots <- slot srvlog.Debugf("Encryption Handshake with %v failed: %v", conn.RemoteAddr(), err)
return return
} }
go srv.addPeer(conn, desc, slot) ourID := srv.ntab.Self()
} p := newPeer(conn, srv.Protocols, srv.Name, &ourID, &remoteID)
if ok, reason := srv.addPeer(remoteID, p); !ok {
srvlog.DebugDetailf("Not adding %v (%v)\n", p, reason)
p.politeDisconnect(reason)
return
}
srvlog.Debugf("Added %v\n", p)
// creates the new peer object and inserts it into its slot if srv.newPeerHook != nil {
func (srv *Server) addPeer(conn net.Conn, desc *peerAddr, slot int) *Peer { srv.newPeerHook(p)
srv.lock.Lock()
defer srv.lock.Unlock()
if !srv.running {
conn.Close()
srv.peerSlots <- slot // release slot
return nil
} }
peer := srv.newPeerFunc(srv, conn, desc) discreason := p.run()
peer.slot = slot srv.removePeer(p)
srv.peers[slot] = peer srvlog.Debugf("Removed %v (%v)\n", p, discreason)
srv.peerCount++
go func() {
peer.loop()
srv.peerDisconnect <- peer
}()
return peer
} }
// removes peer: sending disconnect msg, stop peer, remove rom list/table, release slot func (srv *Server) addPeer(id discover.NodeID, p *Peer) (bool, DiscReason) {
func (srv *Server) removePeer(peer *Peer) {
srv.lock.Lock() srv.lock.Lock()
defer srv.lock.Unlock() defer srv.lock.Unlock()
srvlog.Debugf("Removing %v (slot %v)\n", peer, peer.slot) switch {
if srv.peers[peer.slot] != peer { case !srv.running:
srvlog.Warnln("Invalid peer to remove:", peer) return false, DiscQuitting
return case len(srv.peers) >= srv.MaxPeers:
} return false, DiscTooManyPeers
// remove from list and index case srv.peers[id] != nil:
srv.peerCount-- return false, DiscAlreadyConnected
srv.peers[peer.slot] = nil case srv.Blacklist.Exists(id[:]):
// release slot to signal need for a new peer, last! return false, DiscUselessPeer
srv.peerSlots <- peer.slot case id == srv.ntab.Self():
return false, DiscSelf
}
srv.peers[id] = p
return true, 0
} }
func (srv *Server) verifyPeer(addr *peerAddr) error { func (srv *Server) removePeer(p *Peer) {
if srv.Blacklist.Exists(addr.Pubkey) { srv.lock.Lock()
return errors.New("blacklisted") delete(srv.peers, *p.remoteID)
} srv.lock.Unlock()
if bytes.Equal(srv.Identity.Pubkey()[1:], addr.Pubkey) { srv.peerWG.Done()
return newPeerError(errPubkeyForbidden, "not allowed to connect to srv")
}
srv.lock.RLock()
defer srv.lock.RUnlock()
for _, peer := range srv.peers {
if peer != nil {
id := peer.Identity()
if id != nil && bytes.Equal(id.Pubkey(), addr.Pubkey) {
return errors.New("already connected")
}
}
}
return nil
} }
// TODO replace with "Set"
type Blacklist interface { type Blacklist interface {
Get([]byte) (bool, error) Get([]byte) (bool, error)
Put([]byte) error Put([]byte) error

@ -2,19 +2,28 @@ package p2p
import ( import (
"bytes" "bytes"
"crypto/ecdsa"
"io" "io"
"math/rand"
"net" "net"
"sync" "sync"
"testing" "testing"
"time" "time"
"github.com/ethereum/go-ethereum/crypto"
"github.com/ethereum/go-ethereum/p2p/discover"
) )
func startTestServer(t *testing.T, pf peerFunc) *Server { func startTestServer(t *testing.T, pf newPeerHook) *Server {
server := &Server{ server := &Server{
Identity: &peerId{}, Name: "test",
MaxPeers: 10, MaxPeers: 10,
ListenAddr: "127.0.0.1:0", ListenAddr: "127.0.0.1:0",
newPeerFunc: pf, PrivateKey: newkey(),
newPeerHook: pf,
handshakeFunc: func(io.ReadWriter, *ecdsa.PrivateKey, *discover.Node) (id discover.NodeID, st []byte, err error) {
return randomID(), nil, err
},
} }
if err := server.Start(); err != nil { if err := server.Start(); err != nil {
t.Fatalf("Could not start server: %v", err) t.Fatalf("Could not start server: %v", err)
@ -27,16 +36,11 @@ func TestServerListen(t *testing.T) {
// start the test server // start the test server
connected := make(chan *Peer) connected := make(chan *Peer)
srv := startTestServer(t, func(srv *Server, conn net.Conn, dialAddr *peerAddr) *Peer { srv := startTestServer(t, func(p *Peer) {
if conn == nil { if p == nil {
t.Error("peer func called with nil conn") t.Error("peer func called with nil conn")
} }
if dialAddr != nil { connected <- p
t.Error("peer func called with non-nil dialAddr")
}
peer := newPeer(conn, nil, dialAddr)
connected <- peer
return peer
}) })
defer close(connected) defer close(connected)
defer srv.Stop() defer srv.Stop()
@ -50,9 +54,9 @@ func TestServerListen(t *testing.T) {
select { select {
case peer := <-connected: case peer := <-connected:
if peer.conn.LocalAddr().String() != conn.RemoteAddr().String() { if peer.LocalAddr().String() != conn.RemoteAddr().String() {
t.Errorf("peer started with wrong conn: got %v, want %v", t.Errorf("peer started with wrong conn: got %v, want %v",
peer.conn.LocalAddr(), conn.RemoteAddr()) peer.LocalAddr(), conn.RemoteAddr())
} }
case <-time.After(1 * time.Second): case <-time.After(1 * time.Second):
t.Error("server did not accept within one second") t.Error("server did not accept within one second")
@ -62,7 +66,7 @@ func TestServerListen(t *testing.T) {
func TestServerDial(t *testing.T) { func TestServerDial(t *testing.T) {
defer testlog(t).detach() defer testlog(t).detach()
// run a fake TCP server to handle the connection. // run a one-shot TCP server to handle the connection.
listener, err := net.Listen("tcp", "127.0.0.1:0") listener, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil { if err != nil {
t.Fatalf("could not setup listener: %v") t.Fatalf("could not setup listener: %v")
@ -72,41 +76,32 @@ func TestServerDial(t *testing.T) {
go func() { go func() {
conn, err := listener.Accept() conn, err := listener.Accept()
if err != nil { if err != nil {
t.Error("acccept error:", err) t.Error("accept error:", err)
return
} }
conn.Close() conn.Close()
accepted <- conn accepted <- conn
}() }()
// start the test server // start the server
connected := make(chan *Peer) connected := make(chan *Peer)
srv := startTestServer(t, func(srv *Server, conn net.Conn, dialAddr *peerAddr) *Peer { srv := startTestServer(t, func(p *Peer) { connected <- p })
if conn == nil {
t.Error("peer func called with nil conn")
}
peer := newPeer(conn, nil, dialAddr)
connected <- peer
return peer
})
defer close(connected) defer close(connected)
defer srv.Stop() defer srv.Stop()
// tell the server to connect. // tell the server to connect
connAddr := newPeerAddr(listener.Addr(), nil) tcpAddr := listener.Addr().(*net.TCPAddr)
srv.peerConnect <- connAddr srv.SuggestPeer(&discover.Node{IP: tcpAddr.IP, TCPPort: tcpAddr.Port})
select { select {
case conn := <-accepted: case conn := <-accepted:
select { select {
case peer := <-connected: case peer := <-connected:
if peer.conn.RemoteAddr().String() != conn.LocalAddr().String() { if peer.RemoteAddr().String() != conn.LocalAddr().String() {
t.Errorf("peer started with wrong conn: got %v, want %v", t.Errorf("peer started with wrong conn: got %v, want %v",
peer.conn.RemoteAddr(), conn.LocalAddr()) peer.RemoteAddr(), conn.LocalAddr())
}
if peer.dialAddr != connAddr {
t.Errorf("peer started with wrong dialAddr: got %v, want %v",
peer.dialAddr, connAddr)
} }
// TODO: validate more fields
case <-time.After(1 * time.Second): case <-time.After(1 * time.Second):
t.Error("server did not launch peer within one second") t.Error("server did not launch peer within one second")
} }
@ -118,16 +113,17 @@ func TestServerDial(t *testing.T) {
func TestServerBroadcast(t *testing.T) { func TestServerBroadcast(t *testing.T) {
defer testlog(t).detach() defer testlog(t).detach()
var connected sync.WaitGroup var connected sync.WaitGroup
srv := startTestServer(t, func(srv *Server, c net.Conn, dialAddr *peerAddr) *Peer { srv := startTestServer(t, func(p *Peer) {
peer := newPeer(c, []Protocol{discard}, dialAddr) p.protocols = []Protocol{discard}
peer.startSubprotocols([]Cap{discard.cap()}) p.startSubprotocols([]Cap{discard.cap()})
p.noHandshake = true
connected.Done() connected.Done()
return peer
}) })
defer srv.Stop() defer srv.Stop()
// dial a bunch of conns // create a few peers
var conns = make([]net.Conn, 8) var conns = make([]net.Conn, 8)
connected.Add(len(conns)) connected.Add(len(conns))
deadline := time.Now().Add(3 * time.Second) deadline := time.Now().Add(3 * time.Second)
@ -159,3 +155,18 @@ func TestServerBroadcast(t *testing.T) {
} }
} }
} }
func newkey() *ecdsa.PrivateKey {
key, err := crypto.GenerateKey()
if err != nil {
panic("couldn't generate key: " + err.Error())
}
return key
}
func randomID() (id discover.NodeID) {
for i := range id {
id[i] = byte(rand.Intn(255))
}
return id
}

@ -15,7 +15,7 @@ func testlog(t *testing.T) testLogger {
return l return l
} }
func (testLogger) GetLogLevel() logger.LogLevel { return logger.DebugLevel } func (testLogger) GetLogLevel() logger.LogLevel { return logger.DebugDetailLevel }
func (testLogger) SetLogLevel(logger.LogLevel) {} func (testLogger) SetLogLevel(logger.LogLevel) {}
func (l testLogger) LogPrint(level logger.LogLevel, msg string) { func (l testLogger) LogPrint(level logger.LogLevel, msg string) {

@ -1,40 +0,0 @@
// +build none
package main
import (
"fmt"
"log"
"net"
"os"
"github.com/ethereum/go-ethereum/crypto/secp256k1"
"github.com/ethereum/go-ethereum/logger"
"github.com/ethereum/go-ethereum/p2p"
)
func main() {
logger.AddLogSystem(logger.NewStdLogSystem(os.Stdout, log.LstdFlags, logger.DebugLevel))
pub, _ := secp256k1.GenerateKeyPair()
srv := p2p.Server{
MaxPeers: 10,
Identity: p2p.NewSimpleClientIdentity("test", "1.0", "", string(pub)),
ListenAddr: ":30303",
NAT: p2p.PMP(net.ParseIP("10.0.0.1")),
}
if err := srv.Start(); err != nil {
fmt.Println("could not start server:", err)
os.Exit(1)
}
// add seed peers
seed, err := net.ResolveTCPAddr("tcp", "poc-7.ethdev.com:30303")
if err != nil {
fmt.Println("couldn't resolve:", err)
os.Exit(1)
}
srv.SuggestPeer(seed.IP, seed.Port, nil)
select {}
}

@ -350,8 +350,10 @@ func makeWriter(typ reflect.Type) (writer, error) {
return writeUint, nil return writeUint, nil
case kind == reflect.String: case kind == reflect.String:
return writeString, nil return writeString, nil
case kind == reflect.Slice && typ.Elem().Kind() == reflect.Uint8 && !typ.Elem().Implements(encoderInterface): case kind == reflect.Slice && isByte(typ.Elem()):
return writeBytes, nil return writeBytes, nil
case kind == reflect.Array && isByte(typ.Elem()):
return writeByteArray, nil
case kind == reflect.Slice || kind == reflect.Array: case kind == reflect.Slice || kind == reflect.Array:
return makeSliceWriter(typ) return makeSliceWriter(typ)
case kind == reflect.Struct: case kind == reflect.Struct:
@ -363,6 +365,10 @@ func makeWriter(typ reflect.Type) (writer, error) {
} }
} }
func isByte(typ reflect.Type) bool {
return typ.Kind() == reflect.Uint8 && !typ.Implements(encoderInterface)
}
func writeUint(val reflect.Value, w *encbuf) error { func writeUint(val reflect.Value, w *encbuf) error {
i := val.Uint() i := val.Uint()
if i == 0 { if i == 0 {
@ -407,6 +413,20 @@ func writeBytes(val reflect.Value, w *encbuf) error {
return nil return nil
} }
func writeByteArray(val reflect.Value, w *encbuf) error {
if !val.CanAddr() {
// Slice requires the value to be addressable.
// Make it addressable by copying.
copy := reflect.New(val.Type()).Elem()
copy.Set(val)
val = copy
}
size := val.Len()
slice := val.Slice(0, size).Bytes()
w.encodeString(slice)
return nil
}
func writeString(val reflect.Value, w *encbuf) error { func writeString(val reflect.Value, w *encbuf) error {
s := val.String() s := val.String()
w.encodeStringHeader(len(s)) w.encodeStringHeader(len(s))

@ -40,6 +40,8 @@ func (e *encodableReader) Read(b []byte) (int, error) {
panic("called") panic("called")
} }
type namedByteType byte
var ( var (
_ = Encoder(&testEncoder{}) _ = Encoder(&testEncoder{})
_ = Encoder(byteEncoder(0)) _ = Encoder(byteEncoder(0))
@ -102,6 +104,10 @@ var encTests = []encTest{
// byte slices, strings // byte slices, strings
{val: []byte{}, output: "80"}, {val: []byte{}, output: "80"},
{val: []byte{1, 2, 3}, output: "83010203"}, {val: []byte{1, 2, 3}, output: "83010203"},
{val: []namedByteType{1, 2, 3}, output: "83010203"},
{val: [...]namedByteType{1, 2, 3}, output: "83010203"},
{val: "", output: "80"}, {val: "", output: "80"},
{val: "dog", output: "83646F67"}, {val: "dog", output: "83646F67"},
{ {

@ -215,7 +215,7 @@ func NewPeer(peer *p2p.Peer) *Peer {
return &Peer{ return &Peer{
ref: peer, ref: peer,
Ip: fmt.Sprintf("%v", peer.RemoteAddr()), Ip: fmt.Sprintf("%v", peer.RemoteAddr()),
Version: fmt.Sprintf("%v", peer.Identity()), Version: fmt.Sprintf("%v", peer.ID()),
Caps: fmt.Sprintf("%v", caps), Caps: fmt.Sprintf("%v", caps),
} }
} }

@ -31,7 +31,6 @@ type Backend interface {
IsListening() bool IsListening() bool
Peers() []*p2p.Peer Peers() []*p2p.Peer
KeyManager() *crypto.KeyManager KeyManager() *crypto.KeyManager
ClientIdentity() p2p.ClientIdentity
Db() ethutil.Database Db() ethutil.Database
EventMux() *event.TypeMux EventMux() *event.TypeMux
Whisper() *whisper.Whisper Whisper() *whisper.Whisper

Loading…
Cancel
Save