diff --git a/cmd/devp2p/internal/ethtest/helpers.go b/cmd/devp2p/internal/ethtest/helpers.go index a9a213f337..6f7365483a 100644 --- a/cmd/devp2p/internal/ethtest/helpers.go +++ b/cmd/devp2p/internal/ethtest/helpers.go @@ -24,6 +24,7 @@ import ( "time" "github.com/davecgh/go-spew/spew" + "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/core/types" "github.com/ethereum/go-ethereum/crypto" "github.com/ethereum/go-ethereum/eth/protocols/eth" @@ -649,58 +650,68 @@ func (s *Suite) hashAnnounce(isEth66 bool) error { return fmt.Errorf("peering failed: %v", err) } // create NewBlockHashes announcement - nextBlock := s.fullChain.blocks[s.chain.Len()] - newBlockHash := &NewBlockHashes{ - {Hash: nextBlock.Hash(), Number: nextBlock.Number().Uint64()}, + type anno struct { + Hash common.Hash // Hash of one particular block being announced + Number uint64 // Number of one particular block being announced } - + nextBlock := s.fullChain.blocks[s.chain.Len()] + announcement := anno{Hash: nextBlock.Hash(), Number: nextBlock.Number().Uint64()} + newBlockHash := &NewBlockHashes{announcement} if err := sendConn.Write(newBlockHash); err != nil { return fmt.Errorf("failed to write to connection: %v", err) } + // Announcement sent, now wait for a header request + var ( + id uint64 + msg Message + blockHeaderReq GetBlockHeaders + ) if isEth66 { - // expect GetBlockHeaders request, and respond - id, msg := sendConn.Read66() + id, msg = sendConn.Read66() switch msg := msg.(type) { case GetBlockHeaders: - blockHeaderReq := msg - if blockHeaderReq.Amount != 1 { - return fmt.Errorf("unexpected number of block headers requested: %v", blockHeaderReq.Amount) - } - if blockHeaderReq.Origin.Hash != nextBlock.Hash() { - return fmt.Errorf("unexpected block header requested: %v", pretty.Sdump(blockHeaderReq)) - } - resp := ð.BlockHeadersPacket66{ - RequestId: id, - BlockHeadersPacket: eth.BlockHeadersPacket{ - nextBlock.Header(), - }, - } - if err := sendConn.Write66(resp, BlockHeaders{}.Code()); err != nil { - return fmt.Errorf("failed to write to connection: %v", err) - } + blockHeaderReq = msg default: return fmt.Errorf("unexpected %s", pretty.Sdump(msg)) } + if blockHeaderReq.Amount != 1 { + return fmt.Errorf("unexpected number of block headers requested: %v", blockHeaderReq.Amount) + } + if blockHeaderReq.Origin.Hash != announcement.Hash { + return fmt.Errorf("unexpected block header requested. Announced:\n %v\n Remote request:\n%v", + pretty.Sdump(announcement), + pretty.Sdump(blockHeaderReq)) + } + if err := sendConn.Write66(ð.BlockHeadersPacket66{ + RequestId: id, + BlockHeadersPacket: eth.BlockHeadersPacket{ + nextBlock.Header(), + }, + }, BlockHeaders{}.Code()); err != nil { + return fmt.Errorf("failed to write to connection: %v", err) + } } else { - // expect GetBlockHeaders request, and respond - switch msg := sendConn.Read().(type) { + msg = sendConn.Read() + switch msg := msg.(type) { case *GetBlockHeaders: - blockHeaderReq := *msg - if blockHeaderReq.Amount != 1 { - return fmt.Errorf("unexpected number of block headers requested: %v", blockHeaderReq.Amount) - } - if blockHeaderReq.Origin.Hash != nextBlock.Hash() { - return fmt.Errorf("unexpected block header requested: %v", pretty.Sdump(blockHeaderReq)) - } - if err := sendConn.Write(&BlockHeaders{nextBlock.Header()}); err != nil { - return fmt.Errorf("failed to write to connection: %v", err) - } + blockHeaderReq = *msg default: return fmt.Errorf("unexpected %s", pretty.Sdump(msg)) } + if blockHeaderReq.Amount != 1 { + return fmt.Errorf("unexpected number of block headers requested: %v", blockHeaderReq.Amount) + } + if blockHeaderReq.Origin.Hash != announcement.Hash { + return fmt.Errorf("unexpected block header requested. Announced:\n %v\n Remote request:\n%v", + pretty.Sdump(announcement), + pretty.Sdump(blockHeaderReq)) + } + if err := sendConn.Write(&BlockHeaders{nextBlock.Header()}); err != nil { + return fmt.Errorf("failed to write to connection: %v", err) + } } // wait for block announcement - msg := recvConn.readAndServe(s.chain, timeout) + msg = recvConn.readAndServe(s.chain, timeout) switch msg := msg.(type) { case *NewBlockHashes: hashes := *msg diff --git a/cmd/devp2p/internal/v4test/discv4tests.go b/cmd/devp2p/internal/v4test/discv4tests.go index 140b96bfa5..1b5e5304ed 100644 --- a/cmd/devp2p/internal/v4test/discv4tests.go +++ b/cmd/devp2p/internal/v4test/discv4tests.go @@ -21,7 +21,6 @@ import ( "crypto/rand" "fmt" "net" - "reflect" "time" "github.com/ethereum/go-ethereum/crypto" @@ -89,16 +88,18 @@ func BasicPing(t *utesting.T) { // checkPong verifies that reply is a valid PONG matching the given ping hash. func (te *testenv) checkPong(reply v4wire.Packet, pingHash []byte) error { - if reply == nil || reply.Kind() != v4wire.PongPacket { - return fmt.Errorf("expected PONG reply, got %v", reply) + if reply == nil { + return fmt.Errorf("expected PONG reply, got nil") + } + if reply.Kind() != v4wire.PongPacket { + return fmt.Errorf("expected PONG reply, got %v %v", reply.Name(), reply) } pong := reply.(*v4wire.Pong) if !bytes.Equal(pong.ReplyTok, pingHash) { return fmt.Errorf("PONG reply token mismatch: got %x, want %x", pong.ReplyTok, pingHash) } - wantEndpoint := te.localEndpoint(te.l1) - if !reflect.DeepEqual(pong.To, wantEndpoint) { - return fmt.Errorf("PONG 'to' endpoint mismatch: got %+v, want %+v", pong.To, wantEndpoint) + if want := te.localEndpoint(te.l1); !want.IP.Equal(pong.To.IP) || want.UDP != pong.To.UDP { + return fmt.Errorf("PONG 'to' endpoint mismatch: got %+v, want %+v", pong.To, want) } if v4wire.Expired(pong.Expiration) { return fmt.Errorf("PONG is expired (%v)", pong.Expiration)