cmd/devp2p/internal/ethtest: update tests for eth/67 (#25306)

pull/25468/head
Felix Lange 2 years ago committed by GitHub
parent 6fdc619413
commit d804a59ee1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 4
      cmd/devp2p/internal/ethtest/chain.go
  2. 41
      cmd/devp2p/internal/ethtest/chain_test.go
  3. 322
      cmd/devp2p/internal/ethtest/helpers.go
  4. 24
      cmd/devp2p/internal/ethtest/snapTypes.go
  5. 476
      cmd/devp2p/internal/ethtest/suite.go
  6. 2
      cmd/devp2p/internal/ethtest/suite_test.go
  7. 50
      cmd/devp2p/internal/ethtest/transaction.go
  8. 153
      cmd/devp2p/internal/ethtest/types.go
  9. 8
      cmd/devp2p/rlpxcmd.go

@ -96,12 +96,12 @@ func (c *Chain) Head() *types.Block {
return c.blocks[c.Len()-1] return c.blocks[c.Len()-1]
} }
func (c *Chain) GetHeaders(req GetBlockHeaders) (BlockHeaders, error) { func (c *Chain) GetHeaders(req *GetBlockHeaders) ([]*types.Header, error) {
if req.Amount < 1 { if req.Amount < 1 {
return nil, fmt.Errorf("no block headers requested") return nil, fmt.Errorf("no block headers requested")
} }
headers := make(BlockHeaders, req.Amount) headers := make([]*types.Header, req.Amount)
var blockNumber uint64 var blockNumber uint64
// range over blocks to check if our chain has the requested header // range over blocks to check if our chain has the requested header

@ -21,6 +21,7 @@ import (
"strconv" "strconv"
"testing" "testing"
"github.com/ethereum/go-ethereum/core/types"
"github.com/ethereum/go-ethereum/eth/protocols/eth" "github.com/ethereum/go-ethereum/eth/protocols/eth"
"github.com/ethereum/go-ethereum/p2p" "github.com/ethereum/go-ethereum/p2p"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
@ -140,18 +141,18 @@ func TestChain_GetHeaders(t *testing.T) {
var tests = []struct { var tests = []struct {
req GetBlockHeaders req GetBlockHeaders
expected BlockHeaders expected []*types.Header
}{ }{
{ {
req: GetBlockHeaders{ req: GetBlockHeaders{
Origin: eth.HashOrNumber{ GetBlockHeadersPacket: &eth.GetBlockHeadersPacket{
Number: uint64(2), Origin: eth.HashOrNumber{Number: uint64(2)},
Amount: uint64(5),
Skip: 1,
Reverse: false,
}, },
Amount: uint64(5),
Skip: 1,
Reverse: false,
}, },
expected: BlockHeaders{ expected: []*types.Header{
chain.blocks[2].Header(), chain.blocks[2].Header(),
chain.blocks[4].Header(), chain.blocks[4].Header(),
chain.blocks[6].Header(), chain.blocks[6].Header(),
@ -161,14 +162,14 @@ func TestChain_GetHeaders(t *testing.T) {
}, },
{ {
req: GetBlockHeaders{ req: GetBlockHeaders{
Origin: eth.HashOrNumber{ GetBlockHeadersPacket: &eth.GetBlockHeadersPacket{
Number: uint64(chain.Len() - 1), Origin: eth.HashOrNumber{Number: uint64(chain.Len() - 1)},
Amount: uint64(3),
Skip: 0,
Reverse: true,
}, },
Amount: uint64(3),
Skip: 0,
Reverse: true,
}, },
expected: BlockHeaders{ expected: []*types.Header{
chain.blocks[chain.Len()-1].Header(), chain.blocks[chain.Len()-1].Header(),
chain.blocks[chain.Len()-2].Header(), chain.blocks[chain.Len()-2].Header(),
chain.blocks[chain.Len()-3].Header(), chain.blocks[chain.Len()-3].Header(),
@ -176,14 +177,14 @@ func TestChain_GetHeaders(t *testing.T) {
}, },
{ {
req: GetBlockHeaders{ req: GetBlockHeaders{
Origin: eth.HashOrNumber{ GetBlockHeadersPacket: &eth.GetBlockHeadersPacket{
Hash: chain.Head().Hash(), Origin: eth.HashOrNumber{Hash: chain.Head().Hash()},
Amount: uint64(1),
Skip: 0,
Reverse: false,
}, },
Amount: uint64(1),
Skip: 0,
Reverse: false,
}, },
expected: BlockHeaders{ expected: []*types.Header{
chain.Head().Header(), chain.Head().Header(),
}, },
}, },
@ -191,7 +192,7 @@ func TestChain_GetHeaders(t *testing.T) {
for i, tt := range tests { for i, tt := range tests {
t.Run(strconv.Itoa(i), func(t *testing.T) { t.Run(strconv.Itoa(i), func(t *testing.T) {
headers, err := chain.GetHeaders(tt.req) headers, err := chain.GetHeaders(&tt.req)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }

@ -43,21 +43,6 @@ var (
timeout = 20 * time.Second timeout = 20 * time.Second
) )
// Is_66 checks if the node supports the eth66 protocol version,
// and if not, exists the test suite
func (s *Suite) Is_66(t *utesting.T) {
conn, err := s.dial66()
if err != nil {
t.Fatalf("dial failed: %v", err)
}
if err := conn.handshake(); err != nil {
t.Fatalf("handshake failed: %v", err)
}
if conn.negotiatedProtoVersion < 66 {
t.Fail()
}
}
// dial attempts to dial the given node and perform a handshake, // dial attempts to dial the given node and perform a handshake,
// returning the created Conn if successful. // returning the created Conn if successful.
func (s *Suite) dial() (*Conn, error) { func (s *Suite) dial() (*Conn, error) {
@ -76,31 +61,16 @@ func (s *Suite) dial() (*Conn, error) {
} }
// set default p2p capabilities // set default p2p capabilities
conn.caps = []p2p.Cap{ conn.caps = []p2p.Cap{
{Name: "eth", Version: 64}, {Name: "eth", Version: 66},
{Name: "eth", Version: 65}, {Name: "eth", Version: 67},
} }
conn.ourHighestProtoVersion = 65 conn.ourHighestProtoVersion = 67
return &conn, nil return &conn, nil
} }
// dial66 attempts to dial the given node and perform a handshake, // dialSnap creates a connection with snap/1 capability.
// returning the created Conn with additional eth66 capabilities if
// successful
func (s *Suite) dial66() (*Conn, error) {
conn, err := s.dial()
if err != nil {
return nil, fmt.Errorf("dial failed: %v", err)
}
conn.caps = append(conn.caps, p2p.Cap{Name: "eth", Version: 66})
conn.ourHighestProtoVersion = 66
return conn, nil
}
// dial66 attempts to dial the given node and perform a handshake,
// returning the created Conn with additional snap/1 capabilities if
// successful.
func (s *Suite) dialSnap() (*Conn, error) { func (s *Suite) dialSnap() (*Conn, error) {
conn, err := s.dial66() conn, err := s.dial()
if err != nil { if err != nil {
return nil, fmt.Errorf("dial failed: %v", err) return nil, fmt.Errorf("dial failed: %v", err)
} }
@ -235,117 +205,68 @@ loop:
// createSendAndRecvConns creates two connections, one for sending messages to the // createSendAndRecvConns creates two connections, one for sending messages to the
// node, and one for receiving messages from the node. // node, and one for receiving messages from the node.
func (s *Suite) createSendAndRecvConns(isEth66 bool) (*Conn, *Conn, error) { func (s *Suite) createSendAndRecvConns() (*Conn, *Conn, error) {
var ( sendConn, err := s.dial()
sendConn *Conn if err != nil {
recvConn *Conn return nil, nil, fmt.Errorf("dial failed: %v", err)
err error
)
if isEth66 {
sendConn, err = s.dial66()
if err != nil {
return nil, nil, fmt.Errorf("dial failed: %v", err)
}
recvConn, err = s.dial66()
if err != nil {
sendConn.Close()
return nil, nil, fmt.Errorf("dial failed: %v", err)
}
} else {
sendConn, err = s.dial()
if err != nil {
return nil, nil, fmt.Errorf("dial failed: %v", err)
}
recvConn, err = s.dial()
if err != nil {
sendConn.Close()
return nil, nil, fmt.Errorf("dial failed: %v", err)
}
} }
return sendConn, recvConn, nil recvConn, err := s.dial()
} if err != nil {
sendConn.Close()
func (c *Conn) readAndServe(chain *Chain, timeout time.Duration) Message { return nil, nil, fmt.Errorf("dial failed: %v", err)
if c.negotiatedProtoVersion == 66 {
_, msg := c.readAndServe66(chain, timeout)
return msg
} }
return c.readAndServe65(chain, timeout) return sendConn, recvConn, nil
} }
// readAndServe serves GetBlockHeaders requests while waiting // readAndServe serves GetBlockHeaders requests while waiting
// on another message from the node. // on another message from the node.
func (c *Conn) readAndServe65(chain *Chain, timeout time.Duration) Message { func (c *Conn) readAndServe(chain *Chain, timeout time.Duration) Message {
start := time.Now()
for time.Since(start) < timeout {
c.SetReadDeadline(time.Now().Add(5 * time.Second))
switch msg := c.Read().(type) {
case *Ping:
c.Write(&Pong{})
case *GetBlockHeaders:
req := *msg
headers, err := chain.GetHeaders(req)
if err != nil {
return errorf("could not get headers for inbound header request: %v", err)
}
if err := c.Write(headers); err != nil {
return errorf("could not write to connection: %v", err)
}
default:
return msg
}
}
return errorf("no message received within %v", timeout)
}
// readAndServe66 serves eth66 GetBlockHeaders requests while waiting
// on another message from the node.
func (c *Conn) readAndServe66(chain *Chain, timeout time.Duration) (uint64, Message) {
start := time.Now() start := time.Now()
for time.Since(start) < timeout { for time.Since(start) < timeout {
c.SetReadDeadline(time.Now().Add(10 * time.Second)) c.SetReadDeadline(time.Now().Add(10 * time.Second))
reqID, msg := c.Read66() msg := c.Read()
switch msg := msg.(type) { switch msg := msg.(type) {
case *Ping: case *Ping:
c.Write(&Pong{}) c.Write(&Pong{})
case GetBlockHeaders: case *GetBlockHeaders:
headers, err := chain.GetHeaders(msg) headers, err := chain.GetHeaders(msg)
if err != nil { if err != nil {
return 0, errorf("could not get headers for inbound header request: %v", err) return errorf("could not get headers for inbound header request: %v", err)
} }
resp := &eth.BlockHeadersPacket66{ resp := &BlockHeaders{
RequestId: reqID, RequestId: msg.ReqID(),
BlockHeadersPacket: eth.BlockHeadersPacket(headers), BlockHeadersPacket: eth.BlockHeadersPacket(headers),
} }
if err := c.Write66(resp, BlockHeaders{}.Code()); err != nil { if err := c.Write(resp); err != nil {
return 0, errorf("could not write to connection: %v", err) return errorf("could not write to connection: %v", err)
} }
default: default:
return reqID, msg return msg
} }
} }
return 0, errorf("no message received within %v", timeout) return errorf("no message received within %v", timeout)
} }
// headersRequest executes the given `GetBlockHeaders` request. // headersRequest executes the given `GetBlockHeaders` request.
func (c *Conn) headersRequest(request *GetBlockHeaders, chain *Chain, isEth66 bool, reqID uint64) (BlockHeaders, error) { func (c *Conn) headersRequest(request *GetBlockHeaders, chain *Chain, reqID uint64) ([]*types.Header, error) {
defer c.SetReadDeadline(time.Time{}) defer c.SetReadDeadline(time.Time{})
c.SetReadDeadline(time.Now().Add(20 * time.Second)) c.SetReadDeadline(time.Now().Add(20 * time.Second))
// if on eth66 connection, perform eth66 GetBlockHeaders request
if isEth66 { // write request
return getBlockHeaders66(chain, c, request, reqID) request.RequestId = reqID
}
if err := c.Write(request); err != nil { if err := c.Write(request); err != nil {
return nil, err return nil, fmt.Errorf("could not write to connection: %v", err)
} }
switch msg := c.readAndServe(chain, timeout).(type) {
case *BlockHeaders: // wait for response
return *msg, nil msg := c.waitForResponse(chain, timeout, request.RequestId)
default: resp, ok := msg.(*BlockHeaders)
return nil, fmt.Errorf("invalid message: %s", pretty.Sdump(msg)) if !ok {
return nil, fmt.Errorf("unexpected message received: %s", pretty.Sdump(msg))
} }
headers := []*types.Header(resp.BlockHeadersPacket)
return headers, nil
} }
func (c *Conn) snapRequest(msg Message, id uint64, chain *Chain) (Message, error) { func (c *Conn) snapRequest(msg Message, id uint64, chain *Chain) (Message, error) {
@ -357,28 +278,8 @@ func (c *Conn) snapRequest(msg Message, id uint64, chain *Chain) (Message, error
return c.ReadSnap(id) return c.ReadSnap(id)
} }
// getBlockHeaders66 executes the given `GetBlockHeaders` request over the eth66 protocol.
func getBlockHeaders66(chain *Chain, conn *Conn, request *GetBlockHeaders, id uint64) (BlockHeaders, error) {
// write request
packet := eth.GetBlockHeadersPacket(*request)
req := &eth.GetBlockHeadersPacket66{
RequestId: id,
GetBlockHeadersPacket: &packet,
}
if err := conn.Write66(req, GetBlockHeaders{}.Code()); err != nil {
return nil, fmt.Errorf("could not write to connection: %v", err)
}
// wait for response
msg := conn.waitForResponse(chain, timeout, req.RequestId)
headers, ok := msg.(BlockHeaders)
if !ok {
return nil, fmt.Errorf("unexpected message received: %s", pretty.Sdump(msg))
}
return headers, nil
}
// headersMatch returns whether the received headers match the given request // headersMatch returns whether the received headers match the given request
func headersMatch(expected BlockHeaders, headers BlockHeaders) bool { func headersMatch(expected []*types.Header, headers []*types.Header) bool {
return reflect.DeepEqual(expected, headers) return reflect.DeepEqual(expected, headers)
} }
@ -386,8 +287,8 @@ func headersMatch(expected BlockHeaders, headers BlockHeaders) bool {
// request ID is received. // request ID is received.
func (c *Conn) waitForResponse(chain *Chain, timeout time.Duration, requestID uint64) Message { func (c *Conn) waitForResponse(chain *Chain, timeout time.Duration, requestID uint64) Message {
for { for {
id, msg := c.readAndServe66(chain, timeout) msg := c.readAndServe(chain, timeout)
if id == requestID { if msg.ReqID() == requestID {
return msg return msg
} }
} }
@ -395,9 +296,9 @@ func (c *Conn) waitForResponse(chain *Chain, timeout time.Duration, requestID ui
// sendNextBlock broadcasts the next block in the chain and waits // sendNextBlock broadcasts the next block in the chain and waits
// for the node to propagate the block and import it into its chain. // for the node to propagate the block and import it into its chain.
func (s *Suite) sendNextBlock(isEth66 bool) error { func (s *Suite) sendNextBlock() error {
// set up sending and receiving connections // set up sending and receiving connections
sendConn, recvConn, err := s.createSendAndRecvConns(isEth66) sendConn, recvConn, err := s.createSendAndRecvConns()
if err != nil { if err != nil {
return err return err
} }
@ -420,7 +321,7 @@ func (s *Suite) sendNextBlock(isEth66 bool) error {
return fmt.Errorf("failed to announce block: %v", err) return fmt.Errorf("failed to announce block: %v", err)
} }
// wait for client to update its chain // wait for client to update its chain
if err = s.waitForBlockImport(recvConn, nextBlock, isEth66); err != nil { if err = s.waitForBlockImport(recvConn, nextBlock); err != nil {
return fmt.Errorf("failed to receive confirmation of block import: %v", err) return fmt.Errorf("failed to receive confirmation of block import: %v", err)
} }
// update test suite chain // update test suite chain
@ -465,29 +366,22 @@ func (s *Suite) waitAnnounce(conn *Conn, blockAnnouncement *NewBlock) error {
} }
} }
func (s *Suite) waitForBlockImport(conn *Conn, block *types.Block, isEth66 bool) error { func (s *Suite) waitForBlockImport(conn *Conn, block *types.Block) error {
defer conn.SetReadDeadline(time.Time{}) defer conn.SetReadDeadline(time.Time{})
conn.SetReadDeadline(time.Now().Add(20 * time.Second)) conn.SetReadDeadline(time.Now().Add(20 * time.Second))
// create request // create request
req := &GetBlockHeaders{ req := &GetBlockHeaders{
Origin: eth.HashOrNumber{ GetBlockHeadersPacket: &eth.GetBlockHeadersPacket{
Hash: block.Hash(), Origin: eth.HashOrNumber{Hash: block.Hash()},
Amount: 1,
}, },
Amount: 1,
} }
// loop until BlockHeaders response contains desired block, confirming the // loop until BlockHeaders response contains desired block, confirming the
// node imported the block // node imported the block
for { for {
var ( requestID := uint64(54)
headers BlockHeaders headers, err := conn.headersRequest(req, s.chain, requestID)
err error
)
if isEth66 {
requestID := uint64(54)
headers, err = conn.headersRequest(req, s.chain, eth66, requestID)
} else {
headers, err = conn.headersRequest(req, s.chain, eth65, 0)
}
if err != nil { if err != nil {
return fmt.Errorf("GetBlockHeader request failed: %v", err) return fmt.Errorf("GetBlockHeader request failed: %v", err)
} }
@ -503,8 +397,8 @@ func (s *Suite) waitForBlockImport(conn *Conn, block *types.Block, isEth66 bool)
} }
} }
func (s *Suite) oldAnnounce(isEth66 bool) error { func (s *Suite) oldAnnounce() error {
sendConn, receiveConn, err := s.createSendAndRecvConns(isEth66) sendConn, receiveConn, err := s.createSendAndRecvConns()
if err != nil { if err != nil {
return err return err
} }
@ -550,23 +444,13 @@ func (s *Suite) oldAnnounce(isEth66 bool) error {
return nil return nil
} }
func (s *Suite) maliciousHandshakes(t *utesting.T, isEth66 bool) error { func (s *Suite) maliciousHandshakes(t *utesting.T) error {
var ( conn, err := s.dial()
conn *Conn if err != nil {
err error return fmt.Errorf("dial failed: %v", err)
)
if isEth66 {
conn, err = s.dial66()
if err != nil {
return fmt.Errorf("dial failed: %v", err)
}
} else {
conn, err = s.dial()
if err != nil {
return fmt.Errorf("dial failed: %v", err)
}
} }
defer conn.Close() defer conn.Close()
// write hello to client // write hello to client
pub0 := crypto.FromECDSAPub(&conn.ourKey.PublicKey)[1:] pub0 := crypto.FromECDSAPub(&conn.ourKey.PublicKey)[1:]
handshakes := []*Hello{ handshakes := []*Hello{
@ -627,16 +511,9 @@ func (s *Suite) maliciousHandshakes(t *utesting.T, isEth66 bool) error {
} }
} }
// dial for the next round // dial for the next round
if isEth66 { conn, err = s.dial()
conn, err = s.dial66() if err != nil {
if err != nil { return fmt.Errorf("dial failed: %v", err)
return fmt.Errorf("dial failed: %v", err)
}
} else {
conn, err = s.dial()
if err != nil {
return fmt.Errorf("dial failed: %v", err)
}
} }
} }
return nil return nil
@ -654,6 +531,7 @@ func (s *Suite) maliciousStatus(conn *Conn) error {
Genesis: s.chain.blocks[0].Hash(), Genesis: s.chain.blocks[0].Hash(),
ForkID: s.chain.ForkID(), ForkID: s.chain.ForkID(),
} }
// get status // get status
msg, err := conn.statusExchange(s.chain, status) msg, err := conn.statusExchange(s.chain, status)
if err != nil { if err != nil {
@ -664,6 +542,7 @@ func (s *Suite) maliciousStatus(conn *Conn) error {
default: default:
return fmt.Errorf("expected status, got: %#v ", msg) return fmt.Errorf("expected status, got: %#v ", msg)
} }
// wait for disconnect // wait for disconnect
switch msg := conn.readAndServe(s.chain, timeout).(type) { switch msg := conn.readAndServe(s.chain, timeout).(type) {
case *Disconnect: case *Disconnect:
@ -675,9 +554,9 @@ func (s *Suite) maliciousStatus(conn *Conn) error {
} }
} }
func (s *Suite) hashAnnounce(isEth66 bool) error { func (s *Suite) hashAnnounce() error {
// create connections // create connections
sendConn, recvConn, err := s.createSendAndRecvConns(isEth66) sendConn, recvConn, err := s.createSendAndRecvConns()
if err != nil { if err != nil {
return fmt.Errorf("failed to create connections: %v", err) return fmt.Errorf("failed to create connections: %v", err)
} }
@ -689,6 +568,7 @@ func (s *Suite) hashAnnounce(isEth66 bool) error {
if err := recvConn.peer(s.chain, nil); err != nil { if err := recvConn.peer(s.chain, nil); err != nil {
return fmt.Errorf("peering failed: %v", err) return fmt.Errorf("peering failed: %v", err)
} }
// create NewBlockHashes announcement // create NewBlockHashes announcement
type anno struct { type anno struct {
Hash common.Hash // Hash of one particular block being announced Hash common.Hash // Hash of one particular block being announced
@ -700,56 +580,29 @@ func (s *Suite) hashAnnounce(isEth66 bool) error {
if err := sendConn.Write(newBlockHash); err != nil { if err := sendConn.Write(newBlockHash); err != nil {
return fmt.Errorf("failed to write to connection: %v", err) return fmt.Errorf("failed to write to connection: %v", err)
} }
// Announcement sent, now wait for a header request // Announcement sent, now wait for a header request
var ( msg := sendConn.Read()
id uint64 blockHeaderReq, ok := msg.(*GetBlockHeaders)
msg Message if !ok {
blockHeaderReq GetBlockHeaders return fmt.Errorf("unexpected %s", pretty.Sdump(msg))
)
if isEth66 {
id, msg = sendConn.Read66()
switch msg := msg.(type) {
case GetBlockHeaders:
blockHeaderReq = msg
default:
return fmt.Errorf("unexpected %s", pretty.Sdump(msg))
}
if blockHeaderReq.Amount != 1 {
return fmt.Errorf("unexpected number of block headers requested: %v", blockHeaderReq.Amount)
}
if blockHeaderReq.Origin.Hash != announcement.Hash {
return fmt.Errorf("unexpected block header requested. Announced:\n %v\n Remote request:\n%v",
pretty.Sdump(announcement),
pretty.Sdump(blockHeaderReq))
}
if err := sendConn.Write66(&eth.BlockHeadersPacket66{
RequestId: id,
BlockHeadersPacket: eth.BlockHeadersPacket{
nextBlock.Header(),
},
}, BlockHeaders{}.Code()); err != nil {
return fmt.Errorf("failed to write to connection: %v", err)
}
} else {
msg = sendConn.Read()
switch msg := msg.(type) {
case *GetBlockHeaders:
blockHeaderReq = *msg
default:
return fmt.Errorf("unexpected %s", pretty.Sdump(msg))
}
if blockHeaderReq.Amount != 1 {
return fmt.Errorf("unexpected number of block headers requested: %v", blockHeaderReq.Amount)
}
if blockHeaderReq.Origin.Hash != announcement.Hash {
return fmt.Errorf("unexpected block header requested. Announced:\n %v\n Remote request:\n%v",
pretty.Sdump(announcement),
pretty.Sdump(blockHeaderReq))
}
if err := sendConn.Write(&BlockHeaders{nextBlock.Header()}); err != nil {
return fmt.Errorf("failed to write to connection: %v", err)
}
} }
if blockHeaderReq.Amount != 1 {
return fmt.Errorf("unexpected number of block headers requested: %v", blockHeaderReq.Amount)
}
if blockHeaderReq.Origin.Hash != announcement.Hash {
return fmt.Errorf("unexpected block header requested. Announced:\n %v\n Remote request:\n%v",
pretty.Sdump(announcement),
pretty.Sdump(blockHeaderReq))
}
err = sendConn.Write(&BlockHeaders{
RequestId: blockHeaderReq.ReqID(),
BlockHeadersPacket: eth.BlockHeadersPacket{nextBlock.Header()},
})
if err != nil {
return fmt.Errorf("failed to write to connection: %v", err)
}
// wait for block announcement // wait for block announcement
msg = recvConn.readAndServe(s.chain, timeout) msg = recvConn.readAndServe(s.chain, timeout)
switch msg := msg.(type) { switch msg := msg.(type) {
@ -762,6 +615,7 @@ func (s *Suite) hashAnnounce(isEth66 bool) error {
return fmt.Errorf("unexpected block hash announcement, wanted %v, got %v", nextBlock.Hash(), return fmt.Errorf("unexpected block hash announcement, wanted %v, got %v", nextBlock.Hash(),
hashes[0].Hash) hashes[0].Hash)
} }
case *NewBlock: case *NewBlock:
// node should only propagate NewBlock without having requested the body if the body is empty // node should only propagate NewBlock without having requested the body if the body is empty
nextBlockBody := nextBlock.Body() nextBlockBody := nextBlock.Body()
@ -780,7 +634,7 @@ func (s *Suite) hashAnnounce(isEth66 bool) error {
return fmt.Errorf("unexpected: %s", pretty.Sdump(msg)) return fmt.Errorf("unexpected: %s", pretty.Sdump(msg))
} }
// confirm node imported block // confirm node imported block
if err := s.waitForBlockImport(recvConn, nextBlock, isEth66); err != nil { if err := s.waitForBlockImport(recvConn, nextBlock); err != nil {
return fmt.Errorf("error waiting for node to import new block: %v", err) return fmt.Errorf("error waiting for node to import new block: %v", err)
} }
// update the chain // update the chain

@ -21,32 +21,40 @@ import "github.com/ethereum/go-ethereum/eth/protocols/snap"
// GetAccountRange represents an account range query. // GetAccountRange represents an account range query.
type GetAccountRange snap.GetAccountRangePacket type GetAccountRange snap.GetAccountRangePacket
func (g GetAccountRange) Code() int { return 33 } func (msg GetAccountRange) Code() int { return 33 }
func (msg GetAccountRange) ReqID() uint64 { return msg.ID }
type AccountRange snap.AccountRangePacket type AccountRange snap.AccountRangePacket
func (g AccountRange) Code() int { return 34 } func (msg AccountRange) Code() int { return 34 }
func (msg AccountRange) ReqID() uint64 { return msg.ID }
type GetStorageRanges snap.GetStorageRangesPacket type GetStorageRanges snap.GetStorageRangesPacket
func (g GetStorageRanges) Code() int { return 35 } func (msg GetStorageRanges) Code() int { return 35 }
func (msg GetStorageRanges) ReqID() uint64 { return msg.ID }
type StorageRanges snap.StorageRangesPacket type StorageRanges snap.StorageRangesPacket
func (g StorageRanges) Code() int { return 36 } func (msg StorageRanges) Code() int { return 36 }
func (msg StorageRanges) ReqID() uint64 { return msg.ID }
type GetByteCodes snap.GetByteCodesPacket type GetByteCodes snap.GetByteCodesPacket
func (g GetByteCodes) Code() int { return 37 } func (msg GetByteCodes) Code() int { return 37 }
func (msg GetByteCodes) ReqID() uint64 { return msg.ID }
type ByteCodes snap.ByteCodesPacket type ByteCodes snap.ByteCodesPacket
func (g ByteCodes) Code() int { return 38 } func (msg ByteCodes) Code() int { return 38 }
func (msg ByteCodes) ReqID() uint64 { return msg.ID }
type GetTrieNodes snap.GetTrieNodesPacket type GetTrieNodes snap.GetTrieNodesPacket
func (g GetTrieNodes) Code() int { return 39 } func (msg GetTrieNodes) Code() int { return 39 }
func (msg GetTrieNodes) ReqID() uint64 { return msg.ID }
type TrieNodes snap.TrieNodesPacket type TrieNodes snap.TrieNodesPacket
func (g TrieNodes) Code() int { return 40 } func (msg TrieNodes) Code() int { return 40 }
func (msg TrieNodes) ReqID() uint64 { return msg.ID }

@ -49,79 +49,30 @@ func NewSuite(dest *enode.Node, chainfile string, genesisfile string) (*Suite, e
}, nil }, nil
} }
func (s *Suite) AllEthTests() []utesting.Test { func (s *Suite) EthTests() []utesting.Test {
return []utesting.Test{ return []utesting.Test{
// status // status
{Name: "TestStatus65", Fn: s.TestStatus65}, {Name: "TestStatus", Fn: s.TestStatus},
{Name: "TestStatus66", Fn: s.TestStatus66},
// get block headers // get block headers
{Name: "TestGetBlockHeaders65", Fn: s.TestGetBlockHeaders65}, {Name: "TestGetBlockHeaders", Fn: s.TestGetBlockHeaders},
{Name: "TestGetBlockHeaders66", Fn: s.TestGetBlockHeaders66}, {Name: "TestSimultaneousRequests", Fn: s.TestSimultaneousRequests},
{Name: "TestSimultaneousRequests66", Fn: s.TestSimultaneousRequests66}, {Name: "TestSameRequestID", Fn: s.TestSameRequestID},
{Name: "TestSameRequestID66", Fn: s.TestSameRequestID66}, {Name: "TestZeroRequestID", Fn: s.TestZeroRequestID},
{Name: "TestZeroRequestID66", Fn: s.TestZeroRequestID66},
// get block bodies // get block bodies
{Name: "TestGetBlockBodies65", Fn: s.TestGetBlockBodies65}, {Name: "TestGetBlockBodies", Fn: s.TestGetBlockBodies},
{Name: "TestGetBlockBodies66", Fn: s.TestGetBlockBodies66},
// broadcast // broadcast
{Name: "TestBroadcast65", Fn: s.TestBroadcast65}, {Name: "TestBroadcast", Fn: s.TestBroadcast},
{Name: "TestBroadcast66", Fn: s.TestBroadcast66}, {Name: "TestLargeAnnounce", Fn: s.TestLargeAnnounce},
{Name: "TestLargeAnnounce65", Fn: s.TestLargeAnnounce65}, {Name: "TestOldAnnounce", Fn: s.TestOldAnnounce},
{Name: "TestLargeAnnounce66", Fn: s.TestLargeAnnounce66}, {Name: "TestBlockHashAnnounce", Fn: s.TestBlockHashAnnounce},
{Name: "TestOldAnnounce65", Fn: s.TestOldAnnounce65},
{Name: "TestOldAnnounce66", Fn: s.TestOldAnnounce66},
{Name: "TestBlockHashAnnounce65", Fn: s.TestBlockHashAnnounce65},
{Name: "TestBlockHashAnnounce66", Fn: s.TestBlockHashAnnounce66},
// malicious handshakes + status // malicious handshakes + status
{Name: "TestMaliciousHandshake65", Fn: s.TestMaliciousHandshake65}, {Name: "TestMaliciousHandshake", Fn: s.TestMaliciousHandshake},
{Name: "TestMaliciousStatus65", Fn: s.TestMaliciousStatus65}, {Name: "TestMaliciousStatus", Fn: s.TestMaliciousStatus},
{Name: "TestMaliciousHandshake66", Fn: s.TestMaliciousHandshake66},
{Name: "TestMaliciousStatus66", Fn: s.TestMaliciousStatus66},
// test transactions // test transactions
{Name: "TestTransaction65", Fn: s.TestTransaction65}, {Name: "TestTransaction", Fn: s.TestTransaction},
{Name: "TestTransaction66", Fn: s.TestTransaction66}, {Name: "TestMaliciousTx", Fn: s.TestMaliciousTx},
{Name: "TestMaliciousTx65", Fn: s.TestMaliciousTx65}, {Name: "TestLargeTxRequest", Fn: s.TestLargeTxRequest},
{Name: "TestMaliciousTx66", Fn: s.TestMaliciousTx66}, {Name: "TestNewPooledTxs", Fn: s.TestNewPooledTxs},
{Name: "TestLargeTxRequest66", Fn: s.TestLargeTxRequest66},
{Name: "TestNewPooledTxs66", Fn: s.TestNewPooledTxs66},
}
}
func (s *Suite) EthTests() []utesting.Test {
return []utesting.Test{
{Name: "TestStatus65", Fn: s.TestStatus65},
{Name: "TestGetBlockHeaders65", Fn: s.TestGetBlockHeaders65},
{Name: "TestGetBlockBodies65", Fn: s.TestGetBlockBodies65},
{Name: "TestBroadcast65", Fn: s.TestBroadcast65},
{Name: "TestLargeAnnounce65", Fn: s.TestLargeAnnounce65},
{Name: "TestOldAnnounce65", Fn: s.TestOldAnnounce65},
{Name: "TestBlockHashAnnounce65", Fn: s.TestBlockHashAnnounce65},
{Name: "TestMaliciousHandshake65", Fn: s.TestMaliciousHandshake65},
{Name: "TestMaliciousStatus65", Fn: s.TestMaliciousStatus65},
{Name: "TestTransaction65", Fn: s.TestTransaction65},
{Name: "TestMaliciousTx65", Fn: s.TestMaliciousTx65},
}
}
func (s *Suite) Eth66Tests() []utesting.Test {
return []utesting.Test{
// only proceed with eth66 test suite if node supports eth 66 protocol
{Name: "TestStatus66", Fn: s.TestStatus66},
{Name: "TestGetBlockHeaders66", Fn: s.TestGetBlockHeaders66},
{Name: "TestSimultaneousRequests66", Fn: s.TestSimultaneousRequests66},
{Name: "TestSameRequestID66", Fn: s.TestSameRequestID66},
{Name: "TestZeroRequestID66", Fn: s.TestZeroRequestID66},
{Name: "TestGetBlockBodies66", Fn: s.TestGetBlockBodies66},
{Name: "TestBroadcast66", Fn: s.TestBroadcast66},
{Name: "TestLargeAnnounce66", Fn: s.TestLargeAnnounce66},
{Name: "TestOldAnnounce66", Fn: s.TestOldAnnounce66},
{Name: "TestBlockHashAnnounce66", Fn: s.TestBlockHashAnnounce66},
{Name: "TestMaliciousHandshake66", Fn: s.TestMaliciousHandshake66},
{Name: "TestMaliciousStatus66", Fn: s.TestMaliciousStatus66},
{Name: "TestTransaction66", Fn: s.TestTransaction66},
{Name: "TestMaliciousTx66", Fn: s.TestMaliciousTx66},
{Name: "TestLargeTxRequest66", Fn: s.TestLargeTxRequest66},
{Name: "TestNewPooledTxs66", Fn: s.TestNewPooledTxs66},
} }
} }
@ -135,14 +86,9 @@ func (s *Suite) SnapTests() []utesting.Test {
} }
} }
var ( // TestStatus attempts to connect to the given node and exchange
eth66 = true // indicates whether suite should negotiate eth66 connection // a status message with it on the eth protocol.
eth65 = false // indicates whether suite should negotiate eth65 connection or below. func (s *Suite) TestStatus(t *utesting.T) {
)
// TestStatus65 attempts to connect to the given node and exchange
// a status message with it.
func (s *Suite) TestStatus65(t *utesting.T) {
conn, err := s.dial() conn, err := s.dial()
if err != nil { if err != nil {
t.Fatalf("dial failed: %v", err) t.Fatalf("dial failed: %v", err)
@ -153,79 +99,32 @@ func (s *Suite) TestStatus65(t *utesting.T) {
} }
} }
// TestStatus66 attempts to connect to the given node and exchange // TestGetBlockHeaders tests whether the given node can respond to
// a status message with it on the eth66 protocol. // an eth `GetBlockHeaders` request and that the response is accurate.
func (s *Suite) TestStatus66(t *utesting.T) { func (s *Suite) TestGetBlockHeaders(t *utesting.T) {
conn, err := s.dial66()
if err != nil {
t.Fatalf("dial failed: %v", err)
}
defer conn.Close()
if err := conn.peer(s.chain, nil); err != nil {
t.Fatalf("peering failed: %v", err)
}
}
// TestGetBlockHeaders65 tests whether the given node can respond to
// a `GetBlockHeaders` request accurately.
func (s *Suite) TestGetBlockHeaders65(t *utesting.T) {
conn, err := s.dial() conn, err := s.dial()
if err != nil { if err != nil {
t.Fatalf("dial failed: %v", err) t.Fatalf("dial failed: %v", err)
} }
defer conn.Close() defer conn.Close()
if err := conn.peer(s.chain, nil); err != nil {
t.Fatalf("handshake(s) failed: %v", err)
}
// write request
req := &GetBlockHeaders{
Origin: eth.HashOrNumber{
Hash: s.chain.blocks[1].Hash(),
},
Amount: 2,
Skip: 1,
Reverse: false,
}
headers, err := conn.headersRequest(req, s.chain, eth65, 0)
if err != nil {
t.Fatalf("GetBlockHeaders request failed: %v", err)
}
// check for correct headers
expected, err := s.chain.GetHeaders(*req)
if err != nil {
t.Fatalf("failed to get headers for given request: %v", err)
}
if !headersMatch(expected, headers) {
t.Fatalf("header mismatch: \nexpected %v \ngot %v", expected, headers)
}
}
// TestGetBlockHeaders66 tests whether the given node can respond to
// an eth66 `GetBlockHeaders` request and that the response is accurate.
func (s *Suite) TestGetBlockHeaders66(t *utesting.T) {
conn, err := s.dial66()
if err != nil {
t.Fatalf("dial failed: %v", err)
}
defer conn.Close()
if err = conn.peer(s.chain, nil); err != nil { if err = conn.peer(s.chain, nil); err != nil {
t.Fatalf("peering failed: %v", err) t.Fatalf("peering failed: %v", err)
} }
// write request // write request
req := &GetBlockHeaders{ req := &GetBlockHeaders{
Origin: eth.HashOrNumber{ GetBlockHeadersPacket: &eth.GetBlockHeadersPacket{
Hash: s.chain.blocks[1].Hash(), Origin: eth.HashOrNumber{Hash: s.chain.blocks[1].Hash()},
Amount: 2,
Skip: 1,
Reverse: false,
}, },
Amount: 2,
Skip: 1,
Reverse: false,
} }
headers, err := conn.headersRequest(req, s.chain, eth66, 33) headers, err := conn.headersRequest(req, s.chain, 33)
if err != nil { if err != nil {
t.Fatalf("could not get block headers: %v", err) t.Fatalf("could not get block headers: %v", err)
} }
// check for correct headers // check for correct headers
expected, err := s.chain.GetHeaders(*req) expected, err := s.chain.GetHeaders(req)
if err != nil { if err != nil {
t.Fatalf("failed to get headers for given request: %v", err) t.Fatalf("failed to get headers for given request: %v", err)
} }
@ -234,12 +133,12 @@ func (s *Suite) TestGetBlockHeaders66(t *utesting.T) {
} }
} }
// TestSimultaneousRequests66 sends two simultaneous `GetBlockHeader` requests from // TestSimultaneousRequests sends two simultaneous `GetBlockHeader` requests from
// the same connection with different request IDs and checks to make sure the node // the same connection with different request IDs and checks to make sure the node
// responds with the correct headers per request. // responds with the correct headers per request.
func (s *Suite) TestSimultaneousRequests66(t *utesting.T) { func (s *Suite) TestSimultaneousRequests(t *utesting.T) {
// create a connection // create a connection
conn, err := s.dial66() conn, err := s.dial()
if err != nil { if err != nil {
t.Fatalf("dial failed: %v", err) t.Fatalf("dial failed: %v", err)
} }
@ -247,8 +146,9 @@ func (s *Suite) TestSimultaneousRequests66(t *utesting.T) {
if err := conn.peer(s.chain, nil); err != nil { if err := conn.peer(s.chain, nil); err != nil {
t.Fatalf("peering failed: %v", err) t.Fatalf("peering failed: %v", err)
} }
// create two requests // create two requests
req1 := &eth.GetBlockHeadersPacket66{ req1 := &GetBlockHeaders{
RequestId: uint64(111), RequestId: uint64(111),
GetBlockHeadersPacket: &eth.GetBlockHeadersPacket{ GetBlockHeadersPacket: &eth.GetBlockHeadersPacket{
Origin: eth.HashOrNumber{ Origin: eth.HashOrNumber{
@ -259,7 +159,7 @@ func (s *Suite) TestSimultaneousRequests66(t *utesting.T) {
Reverse: false, Reverse: false,
}, },
} }
req2 := &eth.GetBlockHeadersPacket66{ req2 := &GetBlockHeaders{
RequestId: uint64(222), RequestId: uint64(222),
GetBlockHeadersPacket: &eth.GetBlockHeadersPacket{ GetBlockHeadersPacket: &eth.GetBlockHeadersPacket{
Origin: eth.HashOrNumber{ Origin: eth.HashOrNumber{
@ -270,46 +170,49 @@ func (s *Suite) TestSimultaneousRequests66(t *utesting.T) {
Reverse: false, Reverse: false,
}, },
} }
// write the first request // write the first request
if err := conn.Write66(req1, GetBlockHeaders{}.Code()); err != nil { if err := conn.Write(req1); err != nil {
t.Fatalf("failed to write to connection: %v", err) t.Fatalf("failed to write to connection: %v", err)
} }
// write the second request // write the second request
if err := conn.Write66(req2, GetBlockHeaders{}.Code()); err != nil { if err := conn.Write(req2); err != nil {
t.Fatalf("failed to write to connection: %v", err) t.Fatalf("failed to write to connection: %v", err)
} }
// wait for responses // wait for responses
msg := conn.waitForResponse(s.chain, timeout, req1.RequestId) msg := conn.waitForResponse(s.chain, timeout, req1.RequestId)
headers1, ok := msg.(BlockHeaders) headers1, ok := msg.(*BlockHeaders)
if !ok { if !ok {
t.Fatalf("unexpected %s", pretty.Sdump(msg)) t.Fatalf("unexpected %s", pretty.Sdump(msg))
} }
msg = conn.waitForResponse(s.chain, timeout, req2.RequestId) msg = conn.waitForResponse(s.chain, timeout, req2.RequestId)
headers2, ok := msg.(BlockHeaders) headers2, ok := msg.(*BlockHeaders)
if !ok { if !ok {
t.Fatalf("unexpected %s", pretty.Sdump(msg)) t.Fatalf("unexpected %s", pretty.Sdump(msg))
} }
// check received headers for accuracy // check received headers for accuracy
expected1, err := s.chain.GetHeaders(GetBlockHeaders(*req1.GetBlockHeadersPacket)) expected1, err := s.chain.GetHeaders(req1)
if err != nil { if err != nil {
t.Fatalf("failed to get expected headers for request 1: %v", err) t.Fatalf("failed to get expected headers for request 1: %v", err)
} }
expected2, err := s.chain.GetHeaders(GetBlockHeaders(*req2.GetBlockHeadersPacket)) expected2, err := s.chain.GetHeaders(req2)
if err != nil { if err != nil {
t.Fatalf("failed to get expected headers for request 2: %v", err) t.Fatalf("failed to get expected headers for request 2: %v", err)
} }
if !headersMatch(expected1, headers1) { if !headersMatch(expected1, headers1.BlockHeadersPacket) {
t.Fatalf("header mismatch: \nexpected %v \ngot %v", expected1, headers1) t.Fatalf("header mismatch: \nexpected %v \ngot %v", expected1, headers1)
} }
if !headersMatch(expected2, headers2) { if !headersMatch(expected2, headers2.BlockHeadersPacket) {
t.Fatalf("header mismatch: \nexpected %v \ngot %v", expected2, headers2) t.Fatalf("header mismatch: \nexpected %v \ngot %v", expected2, headers2)
} }
} }
// TestSameRequestID66 sends two requests with the same request ID to a // TestSameRequestID sends two requests with the same request ID to a
// single node. // single node.
func (s *Suite) TestSameRequestID66(t *utesting.T) { func (s *Suite) TestSameRequestID(t *utesting.T) {
conn, err := s.dial66() conn, err := s.dial()
if err != nil { if err != nil {
t.Fatalf("dial failed: %v", err) t.Fatalf("dial failed: %v", err)
} }
@ -319,7 +222,7 @@ func (s *Suite) TestSameRequestID66(t *utesting.T) {
} }
// create requests // create requests
reqID := uint64(1234) reqID := uint64(1234)
request1 := &eth.GetBlockHeadersPacket66{ request1 := &GetBlockHeaders{
RequestId: reqID, RequestId: reqID,
GetBlockHeadersPacket: &eth.GetBlockHeadersPacket{ GetBlockHeadersPacket: &eth.GetBlockHeadersPacket{
Origin: eth.HashOrNumber{ Origin: eth.HashOrNumber{
@ -328,7 +231,7 @@ func (s *Suite) TestSameRequestID66(t *utesting.T) {
Amount: 2, Amount: 2,
}, },
} }
request2 := &eth.GetBlockHeadersPacket66{ request2 := &GetBlockHeaders{
RequestId: reqID, RequestId: reqID,
GetBlockHeadersPacket: &eth.GetBlockHeadersPacket{ GetBlockHeadersPacket: &eth.GetBlockHeadersPacket{
Origin: eth.HashOrNumber{ Origin: eth.HashOrNumber{
@ -337,45 +240,48 @@ func (s *Suite) TestSameRequestID66(t *utesting.T) {
Amount: 2, Amount: 2,
}, },
} }
// write the requests // write the requests
if err = conn.Write66(request1, GetBlockHeaders{}.Code()); err != nil { if err = conn.Write(request1); err != nil {
t.Fatalf("failed to write to connection: %v", err) t.Fatalf("failed to write to connection: %v", err)
} }
if err = conn.Write66(request2, GetBlockHeaders{}.Code()); err != nil { if err = conn.Write(request2); err != nil {
t.Fatalf("failed to write to connection: %v", err) t.Fatalf("failed to write to connection: %v", err)
} }
// wait for responses // wait for responses
msg := conn.waitForResponse(s.chain, timeout, reqID) msg := conn.waitForResponse(s.chain, timeout, reqID)
headers1, ok := msg.(BlockHeaders) headers1, ok := msg.(*BlockHeaders)
if !ok { if !ok {
t.Fatalf("unexpected %s", pretty.Sdump(msg)) t.Fatalf("unexpected %s", pretty.Sdump(msg))
} }
msg = conn.waitForResponse(s.chain, timeout, reqID) msg = conn.waitForResponse(s.chain, timeout, reqID)
headers2, ok := msg.(BlockHeaders) headers2, ok := msg.(*BlockHeaders)
if !ok { if !ok {
t.Fatalf("unexpected %s", pretty.Sdump(msg)) t.Fatalf("unexpected %s", pretty.Sdump(msg))
} }
// check if headers match // check if headers match
expected1, err := s.chain.GetHeaders(GetBlockHeaders(*request1.GetBlockHeadersPacket)) expected1, err := s.chain.GetHeaders(request1)
if err != nil { if err != nil {
t.Fatalf("failed to get expected block headers: %v", err) t.Fatalf("failed to get expected block headers: %v", err)
} }
expected2, err := s.chain.GetHeaders(GetBlockHeaders(*request2.GetBlockHeadersPacket)) expected2, err := s.chain.GetHeaders(request2)
if err != nil { if err != nil {
t.Fatalf("failed to get expected block headers: %v", err) t.Fatalf("failed to get expected block headers: %v", err)
} }
if !headersMatch(expected1, headers1) { if !headersMatch(expected1, headers1.BlockHeadersPacket) {
t.Fatalf("header mismatch: \nexpected %v \ngot %v", expected1, headers1) t.Fatalf("header mismatch: \nexpected %v \ngot %v", expected1, headers1)
} }
if !headersMatch(expected2, headers2) { if !headersMatch(expected2, headers2.BlockHeadersPacket) {
t.Fatalf("header mismatch: \nexpected %v \ngot %v", expected2, headers2) t.Fatalf("header mismatch: \nexpected %v \ngot %v", expected2, headers2)
} }
} }
// TestZeroRequestID_66 checks that a message with a request ID of zero is still handled // TestZeroRequestID checks that a message with a request ID of zero is still handled
// by the node. // by the node.
func (s *Suite) TestZeroRequestID66(t *utesting.T) { func (s *Suite) TestZeroRequestID(t *utesting.T) {
conn, err := s.dial66() conn, err := s.dial()
if err != nil { if err != nil {
t.Fatalf("dial failed: %v", err) t.Fatalf("dial failed: %v", err)
} }
@ -384,16 +290,16 @@ func (s *Suite) TestZeroRequestID66(t *utesting.T) {
t.Fatalf("peering failed: %v", err) t.Fatalf("peering failed: %v", err)
} }
req := &GetBlockHeaders{ req := &GetBlockHeaders{
Origin: eth.HashOrNumber{ GetBlockHeadersPacket: &eth.GetBlockHeadersPacket{
Number: 0, Origin: eth.HashOrNumber{Number: 0},
Amount: 2,
}, },
Amount: 2,
} }
headers, err := conn.headersRequest(req, s.chain, eth66, 0) headers, err := conn.headersRequest(req, s.chain, 0)
if err != nil { if err != nil {
t.Fatalf("failed to get block headers: %v", err) t.Fatalf("failed to get block headers: %v", err)
} }
expected, err := s.chain.GetHeaders(*req) expected, err := s.chain.GetHeaders(req)
if err != nil { if err != nil {
t.Fatalf("failed to get expected block headers: %v", err) t.Fatalf("failed to get expected block headers: %v", err)
} }
@ -402,9 +308,9 @@ func (s *Suite) TestZeroRequestID66(t *utesting.T) {
} }
} }
// TestGetBlockBodies65 tests whether the given node can respond to // TestGetBlockBodies tests whether the given node can respond to
// a `GetBlockBodies` request and that the response is accurate. // a `GetBlockBodies` request and that the response is accurate.
func (s *Suite) TestGetBlockBodies65(t *utesting.T) { func (s *Suite) TestGetBlockBodies(t *utesting.T) {
conn, err := s.dial() conn, err := s.dial()
if err != nil { if err != nil {
t.Fatalf("dial failed: %v", err) t.Fatalf("dial failed: %v", err)
@ -415,126 +321,39 @@ func (s *Suite) TestGetBlockBodies65(t *utesting.T) {
} }
// create block bodies request // create block bodies request
req := &GetBlockBodies{ req := &GetBlockBodies{
s.chain.blocks[54].Hash(),
s.chain.blocks[75].Hash(),
}
if err := conn.Write(req); err != nil {
t.Fatalf("could not write to connection: %v", err)
}
// wait for response
switch msg := conn.readAndServe(s.chain, timeout).(type) {
case *BlockBodies:
t.Logf("received %d block bodies", len(*msg))
if len(*msg) != len(*req) {
t.Fatalf("wrong bodies in response: expected %d bodies, "+
"got %d", len(*req), len(*msg))
}
default:
t.Fatalf("unexpected: %s", pretty.Sdump(msg))
}
}
// TestGetBlockBodies66 tests whether the given node can respond to
// a `GetBlockBodies` request and that the response is accurate over
// the eth66 protocol.
func (s *Suite) TestGetBlockBodies66(t *utesting.T) {
conn, err := s.dial66()
if err != nil {
t.Fatalf("dial failed: %v", err)
}
defer conn.Close()
if err := conn.peer(s.chain, nil); err != nil {
t.Fatalf("peering failed: %v", err)
}
// create block bodies request
req := &eth.GetBlockBodiesPacket66{
RequestId: uint64(55), RequestId: uint64(55),
GetBlockBodiesPacket: eth.GetBlockBodiesPacket{ GetBlockBodiesPacket: eth.GetBlockBodiesPacket{
s.chain.blocks[54].Hash(), s.chain.blocks[54].Hash(),
s.chain.blocks[75].Hash(), s.chain.blocks[75].Hash(),
}, },
} }
if err := conn.Write66(req, GetBlockBodies{}.Code()); err != nil { if err := conn.Write(req); err != nil {
t.Fatalf("could not write to connection: %v", err) t.Fatalf("could not write to connection: %v", err)
} }
// wait for block bodies response // wait for block bodies response
msg := conn.waitForResponse(s.chain, timeout, req.RequestId) msg := conn.waitForResponse(s.chain, timeout, req.RequestId)
blockBodies, ok := msg.(BlockBodies) resp, ok := msg.(*BlockBodies)
if !ok { if !ok {
t.Fatalf("unexpected: %s", pretty.Sdump(msg)) t.Fatalf("unexpected: %s", pretty.Sdump(msg))
} }
t.Logf("received %d block bodies", len(blockBodies)) bodies := resp.BlockBodiesPacket
if len(blockBodies) != len(req.GetBlockBodiesPacket) { t.Logf("received %d block bodies", len(bodies))
if len(bodies) != len(req.GetBlockBodiesPacket) {
t.Fatalf("wrong bodies in response: expected %d bodies, "+ t.Fatalf("wrong bodies in response: expected %d bodies, "+
"got %d", len(req.GetBlockBodiesPacket), len(blockBodies)) "got %d", len(req.GetBlockBodiesPacket), len(bodies))
}
}
// TestBroadcast65 tests whether a block announcement is correctly
// propagated to the given node's peer(s).
func (s *Suite) TestBroadcast65(t *utesting.T) {
if err := s.sendNextBlock(eth65); err != nil {
t.Fatalf("block broadcast failed: %v", err)
} }
} }
// TestBroadcast66 tests whether a block announcement is correctly // TestBroadcast tests whether a block announcement is correctly
// propagated to the given node's peer(s) on the eth66 protocol. // propagated to the node's peers.
func (s *Suite) TestBroadcast66(t *utesting.T) { func (s *Suite) TestBroadcast(t *utesting.T) {
if err := s.sendNextBlock(eth66); err != nil { if err := s.sendNextBlock(); err != nil {
t.Fatalf("block broadcast failed: %v", err) t.Fatalf("block broadcast failed: %v", err)
} }
} }
// TestLargeAnnounce65 tests the announcement mechanism with a large block. // TestLargeAnnounce tests the announcement mechanism with a large block.
func (s *Suite) TestLargeAnnounce65(t *utesting.T) { func (s *Suite) TestLargeAnnounce(t *utesting.T) {
nextBlock := len(s.chain.blocks)
blocks := []*NewBlock{
{
Block: largeBlock(),
TD: s.fullChain.TotalDifficultyAt(nextBlock),
},
{
Block: s.fullChain.blocks[nextBlock],
TD: largeNumber(2),
},
{
Block: largeBlock(),
TD: largeNumber(2),
},
}
for i, blockAnnouncement := range blocks {
t.Logf("Testing malicious announcement: %v\n", i)
conn, err := s.dial()
if err != nil {
t.Fatalf("dial failed: %v", err)
}
if err = conn.peer(s.chain, nil); err != nil {
t.Fatalf("peering failed: %v", err)
}
if err = conn.Write(blockAnnouncement); err != nil {
t.Fatalf("could not write to connection: %v", err)
}
// Invalid announcement, check that peer disconnected
switch msg := conn.readAndServe(s.chain, time.Second*8).(type) {
case *Disconnect:
case *Error:
break
default:
t.Fatalf("unexpected: %s wanted disconnect", pretty.Sdump(msg))
}
conn.Close()
}
// Test the last block as a valid block
if err := s.sendNextBlock(eth65); err != nil {
t.Fatalf("failed to broadcast next block: %v", err)
}
}
// TestLargeAnnounce66 tests the announcement mechanism with a large
// block over the eth66 protocol.
func (s *Suite) TestLargeAnnounce66(t *utesting.T) {
nextBlock := len(s.chain.blocks) nextBlock := len(s.chain.blocks)
blocks := []*NewBlock{ blocks := []*NewBlock{
{ {
@ -553,7 +372,7 @@ func (s *Suite) TestLargeAnnounce66(t *utesting.T) {
for i, blockAnnouncement := range blocks[0:3] { for i, blockAnnouncement := range blocks[0:3] {
t.Logf("Testing malicious announcement: %v\n", i) t.Logf("Testing malicious announcement: %v\n", i)
conn, err := s.dial66() conn, err := s.dial()
if err != nil { if err != nil {
t.Fatalf("dial failed: %v", err) t.Fatalf("dial failed: %v", err)
} }
@ -564,7 +383,7 @@ func (s *Suite) TestLargeAnnounce66(t *utesting.T) {
t.Fatalf("could not write to connection: %v", err) t.Fatalf("could not write to connection: %v", err)
} }
// Invalid announcement, check that peer disconnected // Invalid announcement, check that peer disconnected
switch msg := conn.readAndServe(s.chain, time.Second*8).(type) { switch msg := conn.readAndServe(s.chain, 8*time.Second).(type) {
case *Disconnect: case *Disconnect:
case *Error: case *Error:
break break
@ -574,58 +393,35 @@ func (s *Suite) TestLargeAnnounce66(t *utesting.T) {
conn.Close() conn.Close()
} }
// Test the last block as a valid block // Test the last block as a valid block
if err := s.sendNextBlock(eth66); err != nil { if err := s.sendNextBlock(); err != nil {
t.Fatalf("failed to broadcast next block: %v", err) t.Fatalf("failed to broadcast next block: %v", err)
} }
} }
// TestOldAnnounce65 tests the announcement mechanism with an old block. // TestOldAnnounce tests the announcement mechanism with an old block.
func (s *Suite) TestOldAnnounce65(t *utesting.T) { func (s *Suite) TestOldAnnounce(t *utesting.T) {
if err := s.oldAnnounce(eth65); err != nil { if err := s.oldAnnounce(); err != nil {
t.Fatal(err)
}
}
// TestOldAnnounce66 tests the announcement mechanism with an old block,
// over the eth66 protocol.
func (s *Suite) TestOldAnnounce66(t *utesting.T) {
if err := s.oldAnnounce(eth66); err != nil {
t.Fatal(err) t.Fatal(err)
} }
} }
// TestBlockHashAnnounce65 sends a new block hash announcement and expects // TestBlockHashAnnounce sends a new block hash announcement and expects
// the node to perform a `GetBlockHeaders` request.
func (s *Suite) TestBlockHashAnnounce65(t *utesting.T) {
if err := s.hashAnnounce(eth65); err != nil {
t.Fatalf("block hash announcement failed: %v", err)
}
}
// TestBlockHashAnnounce66 sends a new block hash announcement and expects
// the node to perform a `GetBlockHeaders` request. // the node to perform a `GetBlockHeaders` request.
func (s *Suite) TestBlockHashAnnounce66(t *utesting.T) { func (s *Suite) TestBlockHashAnnounce(t *utesting.T) {
if err := s.hashAnnounce(eth66); err != nil { if err := s.hashAnnounce(); err != nil {
t.Fatalf("block hash announcement failed: %v", err) t.Fatalf("block hash announcement failed: %v", err)
} }
} }
// TestMaliciousHandshake65 tries to send malicious data during the handshake. // TestMaliciousHandshake tries to send malicious data during the handshake.
func (s *Suite) TestMaliciousHandshake65(t *utesting.T) { func (s *Suite) TestMaliciousHandshake(t *utesting.T) {
if err := s.maliciousHandshakes(t, eth65); err != nil { if err := s.maliciousHandshakes(t); err != nil {
t.Fatal(err)
}
}
// TestMaliciousHandshake66 tries to send malicious data during the handshake.
func (s *Suite) TestMaliciousHandshake66(t *utesting.T) {
if err := s.maliciousHandshakes(t, eth66); err != nil {
t.Fatal(err) t.Fatal(err)
} }
} }
// TestMaliciousStatus65 sends a status package with a large total difficulty. // TestMaliciousStatus sends a status package with a large total difficulty.
func (s *Suite) TestMaliciousStatus65(t *utesting.T) { func (s *Suite) TestMaliciousStatus(t *utesting.T) {
conn, err := s.dial() conn, err := s.dial()
if err != nil { if err != nil {
t.Fatalf("dial failed: %v", err) t.Fatalf("dial failed: %v", err)
@ -637,58 +433,28 @@ func (s *Suite) TestMaliciousStatus65(t *utesting.T) {
} }
} }
// TestMaliciousStatus66 sends a status package with a large total // TestTransaction sends a valid transaction to the node and
// difficulty over the eth66 protocol.
func (s *Suite) TestMaliciousStatus66(t *utesting.T) {
conn, err := s.dial66()
if err != nil {
t.Fatalf("dial failed: %v", err)
}
defer conn.Close()
if err := s.maliciousStatus(conn); err != nil {
t.Fatal(err)
}
}
// TestTransaction65 sends a valid transaction to the node and
// checks if the transaction gets propagated. // checks if the transaction gets propagated.
func (s *Suite) TestTransaction65(t *utesting.T) { func (s *Suite) TestTransaction(t *utesting.T) {
if err := s.sendSuccessfulTxs(t, eth65); err != nil { if err := s.sendSuccessfulTxs(t); err != nil {
t.Fatal(err) t.Fatal(err)
} }
} }
// TestTransaction66 sends a valid transaction to the node and // TestMaliciousTx sends several invalid transactions and tests whether
// checks if the transaction gets propagated.
func (s *Suite) TestTransaction66(t *utesting.T) {
if err := s.sendSuccessfulTxs(t, eth66); err != nil {
t.Fatal(err)
}
}
// TestMaliciousTx65 sends several invalid transactions and tests whether
// the node will propagate them. // the node will propagate them.
func (s *Suite) TestMaliciousTx65(t *utesting.T) { func (s *Suite) TestMaliciousTx(t *utesting.T) {
if err := s.sendMaliciousTxs(t, eth65); err != nil { if err := s.sendMaliciousTxs(t); err != nil {
t.Fatal(err) t.Fatal(err)
} }
} }
// TestMaliciousTx66 sends several invalid transactions and tests whether // TestLargeTxRequest tests whether a node can fulfill a large GetPooledTransactions
// the node will propagate them.
func (s *Suite) TestMaliciousTx66(t *utesting.T) {
if err := s.sendMaliciousTxs(t, eth66); err != nil {
t.Fatal(err)
}
}
// TestLargeTxRequest66 tests whether a node can fulfill a large GetPooledTransactions
// request. // request.
func (s *Suite) TestLargeTxRequest66(t *utesting.T) { func (s *Suite) TestLargeTxRequest(t *utesting.T) {
// send the next block to ensure the node is no longer syncing and // send the next block to ensure the node is no longer syncing and
// is able to accept txs // is able to accept txs
if err := s.sendNextBlock(eth66); err != nil { if err := s.sendNextBlock(); err != nil {
t.Fatalf("failed to send next block: %v", err) t.Fatalf("failed to send next block: %v", err)
} }
// send 2000 transactions to the node // send 2000 transactions to the node
@ -701,7 +467,7 @@ func (s *Suite) TestLargeTxRequest66(t *utesting.T) {
} }
// set up connection to receive to ensure node is peered with the receiving connection // set up connection to receive to ensure node is peered with the receiving connection
// before tx request is sent // before tx request is sent
conn, err := s.dial66() conn, err := s.dial()
if err != nil { if err != nil {
t.Fatalf("dial failed: %v", err) t.Fatalf("dial failed: %v", err)
} }
@ -714,17 +480,17 @@ func (s *Suite) TestLargeTxRequest66(t *utesting.T) {
for _, hash := range hashMap { for _, hash := range hashMap {
hashes = append(hashes, hash) hashes = append(hashes, hash)
} }
getTxReq := &eth.GetPooledTransactionsPacket66{ getTxReq := &GetPooledTransactions{
RequestId: 1234, RequestId: 1234,
GetPooledTransactionsPacket: hashes, GetPooledTransactionsPacket: hashes,
} }
if err = conn.Write66(getTxReq, GetPooledTransactions{}.Code()); err != nil { if err = conn.Write(getTxReq); err != nil {
t.Fatalf("could not write to conn: %v", err) t.Fatalf("could not write to conn: %v", err)
} }
// check that all received transactions match those that were sent to node // check that all received transactions match those that were sent to node
switch msg := conn.waitForResponse(s.chain, timeout, getTxReq.RequestId).(type) { switch msg := conn.waitForResponse(s.chain, timeout, getTxReq.RequestId).(type) {
case PooledTransactions: case *PooledTransactions:
for _, gotTx := range msg { for _, gotTx := range msg.PooledTransactionsPacket {
if _, exists := hashMap[gotTx.Hash()]; !exists { if _, exists := hashMap[gotTx.Hash()]; !exists {
t.Fatalf("unexpected tx received: %v", gotTx.Hash()) t.Fatalf("unexpected tx received: %v", gotTx.Hash())
} }
@ -734,12 +500,12 @@ func (s *Suite) TestLargeTxRequest66(t *utesting.T) {
} }
} }
// TestNewPooledTxs_66 tests whether a node will do a GetPooledTransactions // TestNewPooledTxs tests whether a node will do a GetPooledTransactions
// request upon receiving a NewPooledTransactionHashes announcement. // request upon receiving a NewPooledTransactionHashes announcement.
func (s *Suite) TestNewPooledTxs66(t *utesting.T) { func (s *Suite) TestNewPooledTxs(t *utesting.T) {
// send the next block to ensure the node is no longer syncing and // send the next block to ensure the node is no longer syncing and
// is able to accept txs // is able to accept txs
if err := s.sendNextBlock(eth66); err != nil { if err := s.sendNextBlock(); err != nil {
t.Fatalf("failed to send next block: %v", err) t.Fatalf("failed to send next block: %v", err)
} }
@ -757,7 +523,7 @@ func (s *Suite) TestNewPooledTxs66(t *utesting.T) {
announce := NewPooledTransactionHashes(hashes) announce := NewPooledTransactionHashes(hashes)
// send announcement // send announcement
conn, err := s.dial66() conn, err := s.dial()
if err != nil { if err != nil {
t.Fatalf("dial failed: %v", err) t.Fatalf("dial failed: %v", err)
} }
@ -771,11 +537,11 @@ func (s *Suite) TestNewPooledTxs66(t *utesting.T) {
// wait for GetPooledTxs request // wait for GetPooledTxs request
for { for {
_, msg := conn.readAndServe66(s.chain, timeout) msg := conn.readAndServe(s.chain, timeout)
switch msg := msg.(type) { switch msg := msg.(type) {
case GetPooledTransactions: case *GetPooledTransactions:
if len(msg) != len(hashes) { if len(msg.GetPooledTransactionsPacket) != len(hashes) {
t.Fatalf("unexpected number of txs requested: wanted %d, got %d", len(hashes), len(msg)) t.Fatalf("unexpected number of txs requested: wanted %d, got %d", len(hashes), len(msg.GetPooledTransactionsPacket))
} }
return return
// ignore propagated txs from previous tests // ignore propagated txs from previous tests

@ -45,7 +45,7 @@ func TestEthSuite(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("could not create new test suite: %v", err) t.Fatalf("could not create new test suite: %v", err)
} }
for _, test := range suite.Eth66Tests() { for _, test := range suite.EthTests() {
t.Run(test.Name, func(t *testing.T) { t.Run(test.Name, func(t *testing.T) {
result := utesting.RunTAP([]utesting.Test{{Name: test.Name, Fn: test.Fn}}, os.Stdout) result := utesting.RunTAP([]utesting.Test{{Name: test.Name, Fn: test.Fn}}, os.Stdout)
if result[0].Failed { if result[0].Failed {

@ -32,7 +32,7 @@ import (
//var faucetAddr = common.HexToAddress("0x71562b71999873DB5b286dF957af199Ec94617F7") //var faucetAddr = common.HexToAddress("0x71562b71999873DB5b286dF957af199Ec94617F7")
var faucetKey, _ = crypto.HexToECDSA("b71c71a67e1177ad4e901695e1b4b9ee17ae16c6668d313eac2f96dbcda3f291") var faucetKey, _ = crypto.HexToECDSA("b71c71a67e1177ad4e901695e1b4b9ee17ae16c6668d313eac2f96dbcda3f291")
func (s *Suite) sendSuccessfulTxs(t *utesting.T, isEth66 bool) error { func (s *Suite) sendSuccessfulTxs(t *utesting.T) error {
tests := []*types.Transaction{ tests := []*types.Transaction{
getNextTxFromChain(s), getNextTxFromChain(s),
unknownTx(s), unknownTx(s),
@ -48,15 +48,15 @@ func (s *Suite) sendSuccessfulTxs(t *utesting.T, isEth66 bool) error {
prevTx = tests[i-1] prevTx = tests[i-1]
} }
// write tx to connection // write tx to connection
if err := sendSuccessfulTx(s, tx, prevTx, isEth66); err != nil { if err := sendSuccessfulTx(s, tx, prevTx); err != nil {
return fmt.Errorf("send successful tx test failed: %v", err) return fmt.Errorf("send successful tx test failed: %v", err)
} }
} }
return nil return nil
} }
func sendSuccessfulTx(s *Suite, tx *types.Transaction, prevTx *types.Transaction, isEth66 bool) error { func sendSuccessfulTx(s *Suite, tx *types.Transaction, prevTx *types.Transaction) error {
sendConn, recvConn, err := s.createSendAndRecvConns(isEth66) sendConn, recvConn, err := s.createSendAndRecvConns()
if err != nil { if err != nil {
return err return err
} }
@ -73,8 +73,10 @@ func sendSuccessfulTx(s *Suite, tx *types.Transaction, prevTx *types.Transaction
if err = recvConn.peer(s.chain, nil); err != nil { if err = recvConn.peer(s.chain, nil); err != nil {
return fmt.Errorf("peering failed: %v", err) return fmt.Errorf("peering failed: %v", err)
} }
// update last nonce seen // update last nonce seen
nonce = tx.Nonce() nonce = tx.Nonce()
// Wait for the transaction announcement // Wait for the transaction announcement
for { for {
switch msg := recvConn.readAndServe(s.chain, timeout).(type) { switch msg := recvConn.readAndServe(s.chain, timeout).(type) {
@ -114,7 +116,7 @@ func sendSuccessfulTx(s *Suite, tx *types.Transaction, prevTx *types.Transaction
} }
} }
func (s *Suite) sendMaliciousTxs(t *utesting.T, isEth66 bool) error { func (s *Suite) sendMaliciousTxs(t *utesting.T) error {
badTxs := []*types.Transaction{ badTxs := []*types.Transaction{
getOldTxFromChain(s), getOldTxFromChain(s),
invalidNonceTx(s), invalidNonceTx(s),
@ -122,16 +124,9 @@ func (s *Suite) sendMaliciousTxs(t *utesting.T, isEth66 bool) error {
hugeGasPrice(s), hugeGasPrice(s),
hugeData(s), hugeData(s),
} }
// setup receiving connection before sending malicious txs // setup receiving connection before sending malicious txs
var ( recvConn, err := s.dial()
recvConn *Conn
err error
)
if isEth66 {
recvConn, err = s.dial66()
} else {
recvConn, err = s.dial()
}
if err != nil { if err != nil {
return fmt.Errorf("dial failed: %v", err) return fmt.Errorf("dial failed: %v", err)
} }
@ -139,9 +134,10 @@ func (s *Suite) sendMaliciousTxs(t *utesting.T, isEth66 bool) error {
if err = recvConn.peer(s.chain, nil); err != nil { if err = recvConn.peer(s.chain, nil); err != nil {
return fmt.Errorf("peering failed: %v", err) return fmt.Errorf("peering failed: %v", err)
} }
for i, tx := range badTxs { for i, tx := range badTxs {
t.Logf("Testing malicious tx propagation: %v\n", i) t.Logf("Testing malicious tx propagation: %v\n", i)
if err = sendMaliciousTx(s, tx, isEth66); err != nil { if err = sendMaliciousTx(s, tx); err != nil {
return fmt.Errorf("malicious tx test failed:\ntx: %v\nerror: %v", tx, err) return fmt.Errorf("malicious tx test failed:\ntx: %v\nerror: %v", tx, err)
} }
} }
@ -149,17 +145,8 @@ func (s *Suite) sendMaliciousTxs(t *utesting.T, isEth66 bool) error {
return checkMaliciousTxPropagation(s, badTxs, recvConn) return checkMaliciousTxPropagation(s, badTxs, recvConn)
} }
func sendMaliciousTx(s *Suite, tx *types.Transaction, isEth66 bool) error { func sendMaliciousTx(s *Suite, tx *types.Transaction) error {
// setup connection conn, err := s.dial()
var (
conn *Conn
err error
)
if isEth66 {
conn, err = s.dial66()
} else {
conn, err = s.dial()
}
if err != nil { if err != nil {
return fmt.Errorf("dial failed: %v", err) return fmt.Errorf("dial failed: %v", err)
} }
@ -167,6 +154,7 @@ func sendMaliciousTx(s *Suite, tx *types.Transaction, isEth66 bool) error {
if err = conn.peer(s.chain, nil); err != nil { if err = conn.peer(s.chain, nil); err != nil {
return fmt.Errorf("peering failed: %v", err) return fmt.Errorf("peering failed: %v", err)
} }
// write malicious tx // write malicious tx
if err = conn.Write(&Transactions{tx}); err != nil { if err = conn.Write(&Transactions{tx}); err != nil {
return fmt.Errorf("failed to write to connection: %v", err) return fmt.Errorf("failed to write to connection: %v", err)
@ -182,7 +170,7 @@ func sendMultipleSuccessfulTxs(t *utesting.T, s *Suite, txs []*types.Transaction
txMsg := Transactions(txs) txMsg := Transactions(txs)
t.Logf("sending %d txs\n", len(txs)) t.Logf("sending %d txs\n", len(txs))
sendConn, recvConn, err := s.createSendAndRecvConns(true) sendConn, recvConn, err := s.createSendAndRecvConns()
if err != nil { if err != nil {
return err return err
} }
@ -194,15 +182,19 @@ func sendMultipleSuccessfulTxs(t *utesting.T, s *Suite, txs []*types.Transaction
if err = recvConn.peer(s.chain, nil); err != nil { if err = recvConn.peer(s.chain, nil); err != nil {
return fmt.Errorf("peering failed: %v", err) return fmt.Errorf("peering failed: %v", err)
} }
// Send the transactions // Send the transactions
if err = sendConn.Write(&txMsg); err != nil { if err = sendConn.Write(&txMsg); err != nil {
return fmt.Errorf("failed to write message to connection: %v", err) return fmt.Errorf("failed to write message to connection: %v", err)
} }
// update nonce // update nonce
nonce = txs[len(txs)-1].Nonce() nonce = txs[len(txs)-1].Nonce()
// Wait for the transaction announcement(s) and make sure all sent txs are being propagated
// Wait for the transaction announcement(s) and make sure all sent txs are being propagated.
// all txs should be announced within 3 announcements.
recvHashes := make([]common.Hash, 0) recvHashes := make([]common.Hash, 0)
// all txs should be announced within 3 announcements
for i := 0; i < 3; i++ { for i := 0; i < 3; i++ {
switch msg := recvConn.readAndServe(s.chain, timeout).(type) { switch msg := recvConn.readAndServe(s.chain, timeout).(type) {
case *Transactions: case *Transactions:

@ -29,6 +29,7 @@ import (
type Message interface { type Message interface {
Code() int Code() int
ReqID() uint64
} }
type Error struct { type Error struct {
@ -37,9 +38,11 @@ type Error struct {
func (e *Error) Unwrap() error { return e.err } func (e *Error) Unwrap() error { return e.err }
func (e *Error) Error() string { return e.err.Error() } func (e *Error) Error() string { return e.err.Error() }
func (e *Error) Code() int { return -1 }
func (e *Error) String() string { return e.Error() } func (e *Error) String() string { return e.Error() }
func (e *Error) Code() int { return -1 }
func (e *Error) ReqID() uint64 { return 0 }
func errorf(format string, args ...interface{}) *Error { func errorf(format string, args ...interface{}) *Error {
return &Error{fmt.Errorf(format, args...)} return &Error{fmt.Errorf(format, args...)}
} }
@ -56,73 +59,88 @@ type Hello struct {
Rest []rlp.RawValue `rlp:"tail"` Rest []rlp.RawValue `rlp:"tail"`
} }
func (h Hello) Code() int { return 0x00 } func (msg Hello) Code() int { return 0x00 }
func (msg Hello) ReqID() uint64 { return 0 }
// Disconnect is the RLP structure for a disconnect message. // Disconnect is the RLP structure for a disconnect message.
type Disconnect struct { type Disconnect struct {
Reason p2p.DiscReason Reason p2p.DiscReason
} }
func (d Disconnect) Code() int { return 0x01 } func (msg Disconnect) Code() int { return 0x01 }
func (msg Disconnect) ReqID() uint64 { return 0 }
type Ping struct{} type Ping struct{}
func (p Ping) Code() int { return 0x02 } func (msg Ping) Code() int { return 0x02 }
func (msg Ping) ReqID() uint64 { return 0 }
type Pong struct{} type Pong struct{}
func (p Pong) Code() int { return 0x03 } func (msg Pong) Code() int { return 0x03 }
func (msg Pong) ReqID() uint64 { return 0 }
// Status is the network packet for the status message for eth/64 and later. // Status is the network packet for the status message for eth/64 and later.
type Status eth.StatusPacket type Status eth.StatusPacket
func (s Status) Code() int { return 16 } func (msg Status) Code() int { return 16 }
func (msg Status) ReqID() uint64 { return 0 }
// NewBlockHashes is the network packet for the block announcements. // NewBlockHashes is the network packet for the block announcements.
type NewBlockHashes eth.NewBlockHashesPacket type NewBlockHashes eth.NewBlockHashesPacket
func (nbh NewBlockHashes) Code() int { return 17 } func (msg NewBlockHashes) Code() int { return 17 }
func (msg NewBlockHashes) ReqID() uint64 { return 0 }
type Transactions eth.TransactionsPacket type Transactions eth.TransactionsPacket
func (t Transactions) Code() int { return 18 } func (msg Transactions) Code() int { return 18 }
func (msg Transactions) ReqID() uint64 { return 18 }
// GetBlockHeaders represents a block header query. // GetBlockHeaders represents a block header query.
type GetBlockHeaders eth.GetBlockHeadersPacket type GetBlockHeaders eth.GetBlockHeadersPacket66
func (g GetBlockHeaders) Code() int { return 19 } func (msg GetBlockHeaders) Code() int { return 19 }
func (msg GetBlockHeaders) ReqID() uint64 { return msg.RequestId }
type BlockHeaders eth.BlockHeadersPacket type BlockHeaders eth.BlockHeadersPacket66
func (bh BlockHeaders) Code() int { return 20 } func (msg BlockHeaders) Code() int { return 20 }
func (msg BlockHeaders) ReqID() uint64 { return msg.RequestId }
// GetBlockBodies represents a GetBlockBodies request // GetBlockBodies represents a GetBlockBodies request
type GetBlockBodies eth.GetBlockBodiesPacket type GetBlockBodies eth.GetBlockBodiesPacket66
func (gbb GetBlockBodies) Code() int { return 21 } func (msg GetBlockBodies) Code() int { return 21 }
func (msg GetBlockBodies) ReqID() uint64 { return msg.RequestId }
// BlockBodies is the network packet for block content distribution. // BlockBodies is the network packet for block content distribution.
type BlockBodies eth.BlockBodiesPacket type BlockBodies eth.BlockBodiesPacket66
func (bb BlockBodies) Code() int { return 22 } func (msg BlockBodies) Code() int { return 22 }
func (msg BlockBodies) ReqID() uint64 { return msg.RequestId }
// NewBlock is the network packet for the block propagation message. // NewBlock is the network packet for the block propagation message.
type NewBlock eth.NewBlockPacket type NewBlock eth.NewBlockPacket
func (nb NewBlock) Code() int { return 23 } func (msg NewBlock) Code() int { return 23 }
func (msg NewBlock) ReqID() uint64 { return 0 }
// NewPooledTransactionHashes is the network packet for the tx hash propagation message. // NewPooledTransactionHashes is the network packet for the tx hash propagation message.
type NewPooledTransactionHashes eth.NewPooledTransactionHashesPacket type NewPooledTransactionHashes eth.NewPooledTransactionHashesPacket
func (nb NewPooledTransactionHashes) Code() int { return 24 } func (msg NewPooledTransactionHashes) Code() int { return 24 }
func (msg NewPooledTransactionHashes) ReqID() uint64 { return 0 }
type GetPooledTransactions eth.GetPooledTransactionsPacket type GetPooledTransactions eth.GetPooledTransactionsPacket66
func (gpt GetPooledTransactions) Code() int { return 25 } func (msg GetPooledTransactions) Code() int { return 25 }
func (msg GetPooledTransactions) ReqID() uint64 { return msg.RequestId }
type PooledTransactions eth.PooledTransactionsPacket type PooledTransactions eth.PooledTransactionsPacket66
func (pt PooledTransactions) Code() int { return 26 } func (msg PooledTransactions) Code() int { return 26 }
func (msg PooledTransactions) ReqID() uint64 { return msg.RequestId }
// Conn represents an individual connection with a peer // Conn represents an individual connection with a peer
type Conn struct { type Conn struct {
@ -135,62 +153,13 @@ type Conn struct {
caps []p2p.Cap caps []p2p.Cap
} }
// Read reads an eth packet from the connection. // Read reads an eth66 packet from the connection.
func (c *Conn) Read() Message { func (c *Conn) Read() Message {
code, rawData, _, err := c.Conn.Read() code, rawData, _, err := c.Conn.Read()
if err != nil { if err != nil {
return errorf("could not read from connection: %v", err) return errorf("could not read from connection: %v", err)
} }
var msg Message
switch int(code) {
case (Hello{}).Code():
msg = new(Hello)
case (Ping{}).Code():
msg = new(Ping)
case (Pong{}).Code():
msg = new(Pong)
case (Disconnect{}).Code():
msg = new(Disconnect)
case (Status{}).Code():
msg = new(Status)
case (GetBlockHeaders{}).Code():
msg = new(GetBlockHeaders)
case (BlockHeaders{}).Code():
msg = new(BlockHeaders)
case (GetBlockBodies{}).Code():
msg = new(GetBlockBodies)
case (BlockBodies{}).Code():
msg = new(BlockBodies)
case (NewBlock{}).Code():
msg = new(NewBlock)
case (NewBlockHashes{}).Code():
msg = new(NewBlockHashes)
case (Transactions{}).Code():
msg = new(Transactions)
case (NewPooledTransactionHashes{}).Code():
msg = new(NewPooledTransactionHashes)
case (GetPooledTransactions{}.Code()):
msg = new(GetPooledTransactions)
case (PooledTransactions{}.Code()):
msg = new(PooledTransactions)
default:
return errorf("invalid message code: %d", code)
}
// if message is devp2p, decode here
if err := rlp.DecodeBytes(rawData, msg); err != nil {
return errorf("could not rlp decode message: %v", err)
}
return msg
}
// Read66 reads an eth66 packet from the connection.
func (c *Conn) Read66() (uint64, Message) {
code, rawData, _, err := c.Conn.Read()
if err != nil {
return 0, errorf("could not read from connection: %v", err)
}
var msg Message var msg Message
switch int(code) { switch int(code) {
case (Hello{}).Code(): case (Hello{}).Code():
@ -206,27 +175,27 @@ func (c *Conn) Read66() (uint64, Message) {
case (GetBlockHeaders{}).Code(): case (GetBlockHeaders{}).Code():
ethMsg := new(eth.GetBlockHeadersPacket66) ethMsg := new(eth.GetBlockHeadersPacket66)
if err := rlp.DecodeBytes(rawData, ethMsg); err != nil { if err := rlp.DecodeBytes(rawData, ethMsg); err != nil {
return 0, errorf("could not rlp decode message: %v", err) return errorf("could not rlp decode message: %v", err)
} }
return ethMsg.RequestId, GetBlockHeaders(*ethMsg.GetBlockHeadersPacket) return (*GetBlockHeaders)(ethMsg)
case (BlockHeaders{}).Code(): case (BlockHeaders{}).Code():
ethMsg := new(eth.BlockHeadersPacket66) ethMsg := new(eth.BlockHeadersPacket66)
if err := rlp.DecodeBytes(rawData, ethMsg); err != nil { if err := rlp.DecodeBytes(rawData, ethMsg); err != nil {
return 0, errorf("could not rlp decode message: %v", err) return errorf("could not rlp decode message: %v", err)
} }
return ethMsg.RequestId, BlockHeaders(ethMsg.BlockHeadersPacket) return (*BlockHeaders)(ethMsg)
case (GetBlockBodies{}).Code(): case (GetBlockBodies{}).Code():
ethMsg := new(eth.GetBlockBodiesPacket66) ethMsg := new(eth.GetBlockBodiesPacket66)
if err := rlp.DecodeBytes(rawData, ethMsg); err != nil { if err := rlp.DecodeBytes(rawData, ethMsg); err != nil {
return 0, errorf("could not rlp decode message: %v", err) return errorf("could not rlp decode message: %v", err)
} }
return ethMsg.RequestId, GetBlockBodies(ethMsg.GetBlockBodiesPacket) return (*GetBlockBodies)(ethMsg)
case (BlockBodies{}).Code(): case (BlockBodies{}).Code():
ethMsg := new(eth.BlockBodiesPacket66) ethMsg := new(eth.BlockBodiesPacket66)
if err := rlp.DecodeBytes(rawData, ethMsg); err != nil { if err := rlp.DecodeBytes(rawData, ethMsg); err != nil {
return 0, errorf("could not rlp decode message: %v", err) return errorf("could not rlp decode message: %v", err)
} }
return ethMsg.RequestId, BlockBodies(ethMsg.BlockBodiesPacket) return (*BlockBodies)(ethMsg)
case (NewBlock{}).Code(): case (NewBlock{}).Code():
msg = new(NewBlock) msg = new(NewBlock)
case (NewBlockHashes{}).Code(): case (NewBlockHashes{}).Code():
@ -238,26 +207,26 @@ func (c *Conn) Read66() (uint64, Message) {
case (GetPooledTransactions{}.Code()): case (GetPooledTransactions{}.Code()):
ethMsg := new(eth.GetPooledTransactionsPacket66) ethMsg := new(eth.GetPooledTransactionsPacket66)
if err := rlp.DecodeBytes(rawData, ethMsg); err != nil { if err := rlp.DecodeBytes(rawData, ethMsg); err != nil {
return 0, errorf("could not rlp decode message: %v", err) return errorf("could not rlp decode message: %v", err)
} }
return ethMsg.RequestId, GetPooledTransactions(ethMsg.GetPooledTransactionsPacket) return (*GetPooledTransactions)(ethMsg)
case (PooledTransactions{}.Code()): case (PooledTransactions{}.Code()):
ethMsg := new(eth.PooledTransactionsPacket66) ethMsg := new(eth.PooledTransactionsPacket66)
if err := rlp.DecodeBytes(rawData, ethMsg); err != nil { if err := rlp.DecodeBytes(rawData, ethMsg); err != nil {
return 0, errorf("could not rlp decode message: %v", err) return errorf("could not rlp decode message: %v", err)
} }
return ethMsg.RequestId, PooledTransactions(ethMsg.PooledTransactionsPacket) return (*PooledTransactions)(ethMsg)
default: default:
msg = errorf("invalid message code: %d", code) msg = errorf("invalid message code: %d", code)
} }
if msg != nil { if msg != nil {
if err := rlp.DecodeBytes(rawData, msg); err != nil { if err := rlp.DecodeBytes(rawData, msg); err != nil {
return 0, errorf("could not rlp decode message: %v", err) return errorf("could not rlp decode message: %v", err)
} }
return 0, msg return msg
} }
return 0, errorf("invalid message: %s", string(rawData)) return errorf("invalid message: %s", string(rawData))
} }
// Write writes a eth packet to the connection. // Write writes a eth packet to the connection.
@ -270,16 +239,6 @@ func (c *Conn) Write(msg Message) error {
return err return err
} }
// Write66 writes an eth66 packet to the connection.
func (c *Conn) Write66(req eth.Packet, code int) error {
payload, err := rlp.EncodeToBytes(req)
if err != nil {
return err
}
_, err = c.Conn.Write(uint64(code), payload)
return err
}
// ReadSnap reads a snap/1 response with the given id from the connection. // ReadSnap reads a snap/1 response with the given id from the connection.
func (c *Conn) ReadSnap(id uint64) (Message, error) { func (c *Conn) ReadSnap(id uint64) (Message, error) {
respId := id + 1 respId := id + 1

@ -22,7 +22,6 @@ import (
"github.com/ethereum/go-ethereum/cmd/devp2p/internal/ethtest" "github.com/ethereum/go-ethereum/cmd/devp2p/internal/ethtest"
"github.com/ethereum/go-ethereum/crypto" "github.com/ethereum/go-ethereum/crypto"
"github.com/ethereum/go-ethereum/internal/utesting"
"github.com/ethereum/go-ethereum/p2p" "github.com/ethereum/go-ethereum/p2p"
"github.com/ethereum/go-ethereum/p2p/rlpx" "github.com/ethereum/go-ethereum/p2p/rlpx"
"github.com/ethereum/go-ethereum/rlp" "github.com/ethereum/go-ethereum/rlp"
@ -110,12 +109,7 @@ func rlpxEthTest(ctx *cli.Context) error {
if err != nil { if err != nil {
exit(err) exit(err)
} }
// check if given node supports eth66, and if so, run eth66 protocol tests as well return runTests(ctx, suite.EthTests())
is66Failed, _ := utesting.Run(utesting.Test{Name: "Is_66", Fn: suite.Is_66})
if is66Failed {
return runTests(ctx, suite.EthTests())
}
return runTests(ctx, suite.AllEthTests())
} }
// rlpxSnapTest runs the snap protocol test suite. // rlpxSnapTest runs the snap protocol test suite.

Loading…
Cancel
Save