diff --git a/p2p/handshake.go b/p2p/handshake.go index 614711eafc..17f572deab 100644 --- a/p2p/handshake.go +++ b/p2p/handshake.go @@ -5,12 +5,14 @@ import ( "crypto/rand" "errors" "fmt" + "hash" "io" "net" "github.com/ethereum/go-ethereum/crypto" "github.com/ethereum/go-ethereum/crypto/ecies" "github.com/ethereum/go-ethereum/crypto/secp256k1" + "github.com/ethereum/go-ethereum/crypto/sha3" "github.com/ethereum/go-ethereum/p2p/discover" "github.com/ethereum/go-ethereum/rlp" ) @@ -38,13 +40,23 @@ func newConn(fd net.Conn, hs *protoHandshake) *conn { return &conn{newFrameRW(fd, msgWriteTimeout), hs} } -// encHandshake represents information about the remote end -// of a connection that is negotiated during the encryption handshake. +// encHandshake contains the state of the encryption handshake. type encHandshake struct { - ID discover.NodeID - IngressMAC []byte - EgressMAC []byte - Token []byte + remoteID discover.NodeID + initiator bool + initNonce, respNonce []byte + dhSharedSecret []byte + randomPrivKey *ecdsa.PrivateKey + remoteRandomPub *ecdsa.PublicKey +} + +// secrets represents the connection secrets +// which are negotiated during the encryption handshake. +type secrets struct { + RemoteID discover.NodeID + AES, MAC []byte + EgressMAC, IngressMAC hash.Hash + Token []byte } // protoHandshake is the RLP structure of the protocol handshake. @@ -56,6 +68,34 @@ type protoHandshake struct { ID discover.NodeID } +// secrets is called after the handshake is completed. +// It extracts the connection secrets from the handshake values. +func (h *encHandshake) secrets(auth, authResp []byte) secrets { + sharedSecret := crypto.Sha3(h.dhSharedSecret, crypto.Sha3(h.respNonce, h.initNonce)) + aesSecret := crypto.Sha3(h.dhSharedSecret, sharedSecret) + s := secrets{ + RemoteID: h.remoteID, + AES: aesSecret, + MAC: crypto.Sha3(h.dhSharedSecret, aesSecret), + Token: crypto.Sha3(sharedSecret), + } + + // setup sha3 instances for the MACs + mac1 := sha3.NewKeccak256() + mac1.Write(xor(s.MAC, h.respNonce)) + mac1.Write(auth) + mac2 := sha3.NewKeccak256() + mac2.Write(xor(s.MAC, h.initNonce)) + mac2.Write(authResp) + if h.initiator { + s.EgressMAC, s.IngressMAC = mac1, mac2 + } else { + s.EgressMAC, s.IngressMAC = mac2, mac1 + } + + return s +} + // setupConn starts a protocol session on the given connection. // It runs the encryption handshake and the protocol handshake. // If dial is non-nil, the connection the local node is the initiator. @@ -68,36 +108,47 @@ func setupConn(fd net.Conn, prv *ecdsa.PrivateKey, our *protoHandshake, dial *di } func setupInboundConn(fd net.Conn, prv *ecdsa.PrivateKey, our *protoHandshake) (*conn, error) { - // var remotePubkey []byte - // sessionToken, remotePubkey, err = inboundEncHandshake(fd, prv, nil) - // copy(remoteID[:], remotePubkey) + secrets, err := inboundEncHandshake(fd, prv, nil) + if err != nil { + return nil, fmt.Errorf("encryption handshake failed: %v", err) + } - rw := newFrameRW(fd, msgWriteTimeout) - rhs, err := readProtocolHandshake(rw, our) + // Run the protocol handshake using authenticated messages. + // TODO: move buffering setup here (out of newFrameRW) + phsrw := newRlpxFrameRW(fd, secrets) + rhs, err := readProtocolHandshake(phsrw, our) if err != nil { return nil, err } - if err := writeProtocolHandshake(rw, our); err != nil { + if err := writeProtocolHandshake(phsrw, our); err != nil { return nil, fmt.Errorf("protocol write error: %v", err) } + + rw := newFrameRW(fd, msgWriteTimeout) return &conn{rw, rhs}, nil } func setupOutboundConn(fd net.Conn, prv *ecdsa.PrivateKey, our *protoHandshake, dial *discover.Node) (*conn, error) { - // remoteID = dial.ID - // sessionToken, err = outboundEncHandshake(fd, prv, remoteID[:], nil) + secrets, err := outboundEncHandshake(fd, prv, dial.ID[:], nil) + if err != nil { + return nil, fmt.Errorf("encryption handshake failed: %v", err) + } - rw := newFrameRW(fd, msgWriteTimeout) - if err := writeProtocolHandshake(rw, our); err != nil { + // Run the protocol handshake using authenticated messages. + // TODO: move buffering setup here (out of newFrameRW) + phsrw := newRlpxFrameRW(fd, secrets) + if err := writeProtocolHandshake(phsrw, our); err != nil { return nil, fmt.Errorf("protocol write error: %v", err) } - rhs, err := readProtocolHandshake(rw, our) + rhs, err := readProtocolHandshake(phsrw, our) if err != nil { return nil, fmt.Errorf("protocol handshake read error: %v", err) } if rhs.ID != dial.ID { return nil, errors.New("dialed node id mismatch") } + + rw := newFrameRW(fd, msgWriteTimeout) return &conn{rw, rhs}, nil } @@ -107,43 +158,48 @@ func setupOutboundConn(fd net.Conn, prv *ecdsa.PrivateKey, our *protoHandshake, // privateKey is the local client's private key // remotePublicKey is the remote peer's node ID // sessionToken is the token from a previous session with this node. -func outboundEncHandshake(conn io.ReadWriter, prvKey *ecdsa.PrivateKey, remotePublicKey []byte, sessionToken []byte) ( - newSessionToken []byte, - err error, -) { +func outboundEncHandshake(conn io.ReadWriter, prvKey *ecdsa.PrivateKey, remotePublicKey []byte, sessionToken []byte) (s secrets, err error) { auth, initNonce, randomPrivKey, err := authMsg(prvKey, remotePublicKey, sessionToken) if err != nil { - return nil, err + return s, err } if _, err = conn.Write(auth); err != nil { - return nil, err + return s, err } response := make([]byte, rHSLen) if _, err = io.ReadFull(conn, response); err != nil { - return nil, err + return s, err } recNonce, remoteRandomPubKey, _, err := completeHandshake(response, prvKey) if err != nil { - return nil, err + return s, err } - return newSession(initNonce, recNonce, randomPrivKey, remoteRandomPubKey) + h := &encHandshake{ + initiator: true, + initNonce: initNonce, + respNonce: recNonce, + randomPrivKey: randomPrivKey, + remoteRandomPub: remoteRandomPubKey, + } + copy(h.remoteID[:], remotePublicKey) + return h.secrets(auth, response), nil } // authMsg creates the initiator handshake. +// TODO: change all the names func authMsg(prvKey *ecdsa.PrivateKey, remotePubKeyS, sessionToken []byte) ( auth, initNonce []byte, randomPrvKey *ecdsa.PrivateKey, err error, ) { - // session init, common to both parties remotePubKey, err := importPublicKey(remotePubKeyS) if err != nil { return } - var tokenFlag byte // = 0x00 + var tokenFlag byte if sessionToken == nil { // no session token found means we need to generate shared secret. // ecies shared secret is used as initial session token for new peers @@ -151,14 +207,13 @@ func authMsg(prvKey *ecdsa.PrivateKey, remotePubKeyS, sessionToken []byte) ( if sessionToken, err = ecies.ImportECDSA(prvKey).GenerateShared(ecies.ImportECDSAPublic(remotePubKey), sskLen, sskLen); err != nil { return } - // tokenFlag = 0x00 // redundant } else { // for known peers, we use stored token from the previous session tokenFlag = 0x01 } - //E(remote-pubk, S(ecdhe-random, ecdh-shared-secret^nonce) || H(ecdhe-random-pubk) || pubk || nonce || 0x0) - // E(remote-pubk, S(ecdhe-random, token^nonce) || H(ecdhe-random-pubk) || pubk || nonce || 0x1) + //E(remote-pubk, S(ecdhe-random, sha3(ecdh-shared-secret^nonce)) || H(ecdhe-random-pubk) || pubk || nonce || 0x0) + // E(remote-pubk, S(ecdhe-random, sha3(token^nonce)) || H(ecdhe-random-pubk) || pubk || nonce || 0x1) // allocate msgLen long message, var msg []byte = make([]byte, authMsgLen) initNonce = msg[authMsgLen-shaLen-1 : authMsgLen-1] @@ -242,27 +297,32 @@ func completeHandshake(auth []byte, prvKey *ecdsa.PrivateKey) ( // // privateKey is the local client's private key // sessionToken is the token from a previous session with this node. -func inboundEncHandshake(conn io.ReadWriter, prvKey *ecdsa.PrivateKey, sessionToken []byte) ( - token, remotePubKey []byte, - err error, -) { +func inboundEncHandshake(conn io.ReadWriter, prvKey *ecdsa.PrivateKey, sessionToken []byte) (s secrets, err error) { // we are listening connection. we are responders in the // handshake. Extract info from the authentication. The initiator // starts by sending us a handshake that we need to respond to. so // we read auth message first, then respond. auth := make([]byte, iHSLen) if _, err := io.ReadFull(conn, auth); err != nil { - return nil, nil, err + return s, err } response, recNonce, initNonce, remotePubKey, randomPrivKey, remoteRandomPubKey, err := authResp(auth, sessionToken, prvKey) if err != nil { - return nil, nil, err + return s, err } if _, err = conn.Write(response); err != nil { - return nil, nil, err + return s, err } - token, err = newSession(initNonce, recNonce, randomPrivKey, remoteRandomPubKey) - return token, remotePubKey, err + + h := &encHandshake{ + initiator: false, + initNonce: initNonce, + respNonce: recNonce, + randomPrivKey: randomPrivKey, + remoteRandomPub: remoteRandomPubKey, + } + copy(h.remoteID[:], remotePubKey) + return h.secrets(auth, response), err } // authResp is called by peer if it accepted (but not @@ -349,23 +409,6 @@ func authResp(auth, sessionToken []byte, prvKey *ecdsa.PrivateKey) ( return } -// newSession is called after the handshake is completed. The -// arguments are values negotiated in the handshake. The return value -// is a new session Token to be remembered for the next time we -// connect with this peer. -func newSession(initNonce, respNonce []byte, privKey *ecdsa.PrivateKey, remoteRandomPubKey *ecdsa.PublicKey) ([]byte, error) { - // 3) Now we can trust ecdhe-random-pubk to derive new keys - //ecdhe-shared-secret = ecdh.agree(ecdhe-random, remote-ecdhe-random-pubk) - pubKey := ecies.ImportECDSAPublic(remoteRandomPubKey) - dhSharedSecret, err := ecies.ImportECDSA(privKey).GenerateShared(pubKey, sskLen, sskLen) - if err != nil { - return nil, err - } - sharedSecret := crypto.Sha3(dhSharedSecret, crypto.Sha3(respNonce, initNonce)) - sessionToken := crypto.Sha3(sharedSecret) - return sessionToken, nil -} - // importPublicKey unmarshals 512 bit public keys. func importPublicKey(pubKey []byte) (pubKeyEC *ecdsa.PublicKey, err error) { var pubKey65 []byte diff --git a/p2p/handshake_test.go b/p2p/handshake_test.go index 06c6a69324..66e610d171 100644 --- a/p2p/handshake_test.go +++ b/p2p/handshake_test.go @@ -2,8 +2,6 @@ package p2p import ( "bytes" - "crypto/ecdsa" - "crypto/rand" "net" "reflect" "testing" @@ -69,102 +67,46 @@ func TestSharedSecret(t *testing.T) { } } -func TestCryptoHandshake(t *testing.T) { - testCryptoHandshake(newkey(), newkey(), nil, t) -} - -func TestCryptoHandshakeWithToken(t *testing.T) { - sessionToken := make([]byte, shaLen) - rand.Read(sessionToken) - testCryptoHandshake(newkey(), newkey(), sessionToken, t) -} - -func testCryptoHandshake(prv0, prv1 *ecdsa.PrivateKey, sessionToken []byte, t *testing.T) { - var err error - // pub0 := &prv0.PublicKey - pub1 := &prv1.PublicKey - - // pub0s := crypto.FromECDSAPub(pub0) - pub1s := crypto.FromECDSAPub(pub1) - - // simulate handshake by feeding output to input - // initiator sends handshake 'auth' - auth, initNonce, randomPrivKey, err := authMsg(prv0, pub1s, sessionToken) - if err != nil { - t.Errorf("%v", err) - } - // t.Logf("-> %v", hexkey(auth)) - - // receiver reads auth and responds with response - response, remoteRecNonce, remoteInitNonce, _, remoteRandomPrivKey, remoteInitRandomPubKey, err := authResp(auth, sessionToken, prv1) - if err != nil { - t.Errorf("%v", err) - } - // t.Logf("<- %v\n", hexkey(response)) - - // initiator reads receiver's response and the key exchange completes - recNonce, remoteRandomPubKey, _, err := completeHandshake(response, prv0) - if err != nil { - t.Errorf("completeHandshake error: %v", err) - } - - // now both parties should have the same session parameters - initSessionToken, err := newSession(initNonce, recNonce, randomPrivKey, remoteRandomPubKey) - if err != nil { - t.Errorf("newSession error: %v", err) - } - - recSessionToken, err := newSession(remoteInitNonce, remoteRecNonce, remoteRandomPrivKey, remoteInitRandomPubKey) - if err != nil { - t.Errorf("newSession error: %v", err) - } - - // fmt.Printf("\nauth (%v) %x\n\nresp (%v) %x\n\n", len(auth), auth, len(response), response) - - // fmt.Printf("\nauth %x\ninitNonce %x\nresponse%x\nremoteRecNonce %x\nremoteInitNonce %x\nremoteRandomPubKey %x\nrecNonce %x\nremoteInitRandomPubKey %x\ninitSessionToken %x\n\n", auth, initNonce, response, remoteRecNonce, remoteInitNonce, remoteRandomPubKey, recNonce, remoteInitRandomPubKey, initSessionToken) - - if !bytes.Equal(initNonce, remoteInitNonce) { - t.Errorf("nonces do not match") - } - if !bytes.Equal(recNonce, remoteRecNonce) { - t.Errorf("receiver nonces do not match") - } - if !bytes.Equal(initSessionToken, recSessionToken) { - t.Errorf("session tokens do not match") - } -} - func TestEncHandshake(t *testing.T) { defer testlog(t).detach() prv0, _ := crypto.GenerateKey() prv1, _ := crypto.GenerateKey() - pub0s, _ := exportPublicKey(&prv0.PublicKey) - pub1s, _ := exportPublicKey(&prv1.PublicKey) rw0, rw1 := net.Pipe() - tokens := make(chan []byte) + secrets := make(chan secrets) go func() { - token, err := outboundEncHandshake(rw0, prv0, pub1s, nil) + pub1s, _ := exportPublicKey(&prv1.PublicKey) + s, err := outboundEncHandshake(rw0, prv0, pub1s, nil) if err != nil { t.Errorf("outbound side error: %v", err) } - tokens <- token + id1 := discover.PubkeyID(&prv1.PublicKey) + if s.RemoteID != id1 { + t.Errorf("outbound side remote ID mismatch") + } + secrets <- s }() go func() { - token, remotePubkey, err := inboundEncHandshake(rw1, prv1, nil) + s, err := inboundEncHandshake(rw1, prv1, nil) if err != nil { t.Errorf("inbound side error: %v", err) } - if !bytes.Equal(remotePubkey, pub0s) { - t.Errorf("inbound side returned wrong remote pubkey\n got: %x\n want: %x", remotePubkey, pub0s) + id0 := discover.PubkeyID(&prv0.PublicKey) + if s.RemoteID != id0 { + t.Errorf("inbound side remote ID mismatch") } - tokens <- token + secrets <- s }() - t1, t2 := <-tokens, <-tokens - if !bytes.Equal(t1, t2) { - t.Error("session token mismatch") + // get computed secrets from both sides + t1, t2 := <-secrets, <-secrets + // don't compare remote node IDs + t1.RemoteID, t2.RemoteID = discover.NodeID{}, discover.NodeID{} + // flip MACs on one of them so they compare equal + t1.EgressMAC, t1.IngressMAC = t1.IngressMAC, t1.EgressMAC + if !reflect.DeepEqual(t1, t2) { + t.Errorf("secrets mismatch:\n t1: %#v\n t2: %#v", t1, t2) } } diff --git a/p2p/rlpx.go b/p2p/rlpx.go index 9fd1aed1f6..761dc2ed96 100644 --- a/p2p/rlpx.go +++ b/p2p/rlpx.go @@ -13,24 +13,44 @@ import ( ) var ( + // this is used in place of actual frame header data. + // TODO: replace this when Msg contains the protocol type code. zeroHeader = []byte{0xC2, 0x80, 0x80} - zero16 = make([]byte, 16) + + // sixteen zero bytes + zero16 = make([]byte, 16) ) type rlpxFrameRW struct { conn io.ReadWriter + enc cipher.Stream + dec cipher.Stream macCipher cipher.Block egressMAC hash.Hash ingressMAC hash.Hash } -func newRlpxFrameRW(conn io.ReadWriter, macSecret []byte, egressMAC, ingressMAC hash.Hash) *rlpxFrameRW { - cipher, err := aes.NewCipher(macSecret) +func newRlpxFrameRW(conn io.ReadWriter, s secrets) *rlpxFrameRW { + macc, err := aes.NewCipher(s.MAC) + if err != nil { + panic("invalid MAC secret: " + err.Error()) + } + encc, err := aes.NewCipher(s.AES) if err != nil { - panic("invalid macSecret: " + err.Error()) + panic("invalid AES secret: " + err.Error()) + } + // we use an all-zeroes IV for AES because the key used + // for encryption is ephemeral. + iv := make([]byte, encc.BlockSize()) + return &rlpxFrameRW{ + conn: conn, + enc: cipher.NewCTR(encc, iv), + dec: cipher.NewCTR(encc, iv), + macCipher: macc, + egressMAC: s.EgressMAC, + ingressMAC: s.IngressMAC, } - return &rlpxFrameRW{conn: conn, macCipher: cipher, egressMAC: egressMAC, ingressMAC: ingressMAC} } func (rw *rlpxFrameRW) WriteMsg(msg Msg) error { @@ -41,13 +61,14 @@ func (rw *rlpxFrameRW) WriteMsg(msg Msg) error { fsize := uint32(len(ptype)) + msg.Size putInt24(fsize, headbuf) // TODO: check overflow copy(headbuf[3:], zeroHeader) + rw.enc.XORKeyStream(headbuf[:16], headbuf[:16]) // first half is now encrypted copy(headbuf[16:], updateHeaderMAC(rw.egressMAC, rw.macCipher, headbuf[:16])) if _, err := rw.conn.Write(headbuf); err != nil { return err } - // write frame, updating the egress MAC while writing to conn. - tee := io.MultiWriter(rw.conn, rw.egressMAC) + // write encrypted frame, updating the egress MAC while writing to conn. + tee := cipher.StreamWriter{S: rw.enc, W: io.MultiWriter(rw.conn, rw.egressMAC)} if _, err := tee.Write(ptype); err != nil { return err } @@ -62,7 +83,8 @@ func (rw *rlpxFrameRW) WriteMsg(msg Msg) error { // write packet-mac. egress MAC is up to date because // frame content was written to it as well. - _, err := rw.conn.Write(rw.egressMAC.Sum(nil)) + mac := updateHeaderMAC(rw.egressMAC, rw.macCipher, rw.egressMAC.Sum(nil)) + _, err := rw.conn.Write(mac) return err } @@ -72,34 +94,40 @@ func (rw *rlpxFrameRW) ReadMsg() (msg Msg, err error) { if _, err := io.ReadFull(rw.conn, headbuf); err != nil { return msg, err } - fsize := readInt24(headbuf) - // ignore protocol type for now + // verify header mac shouldMAC := updateHeaderMAC(rw.ingressMAC, rw.macCipher, headbuf[:16]) if !hmac.Equal(shouldMAC[:16], headbuf[16:]) { return msg, errors.New("bad header MAC") } + rw.dec.XORKeyStream(headbuf[:16], headbuf[:16]) // first half is now decrypted + fsize := readInt24(headbuf) + // ignore protocol type for now // read the frame content - framebuf := make([]byte, fsize) + var rsize = fsize // frame size rounded up to 16 byte boundary + if padding := fsize % 16; padding > 0 { + rsize += 16 - padding + } + framebuf := make([]byte, rsize) if _, err := io.ReadFull(rw.conn, framebuf); err != nil { return msg, err } - rw.ingressMAC.Write(framebuf) - if padding := fsize % 16; padding > 0 { - if _, err := io.CopyN(rw.ingressMAC, rw.conn, int64(16-padding)); err != nil { - return msg, err - } - } + // read and validate frame MAC. we can re-use headbuf for that. + rw.ingressMAC.Write(framebuf) if _, err := io.ReadFull(rw.conn, headbuf); err != nil { return msg, err } - if !hmac.Equal(rw.ingressMAC.Sum(nil), headbuf) { + shouldMAC = updateHeaderMAC(rw.ingressMAC, rw.macCipher, rw.ingressMAC.Sum(nil)) + if !hmac.Equal(shouldMAC, headbuf) { return msg, errors.New("bad frame MAC") } + // decrypt frame content + rw.dec.XORKeyStream(framebuf, framebuf) + // decode message code - content := bytes.NewReader(framebuf) + content := bytes.NewReader(framebuf[:fsize]) if err := rlp.Decode(content, &msg.Code); err != nil { return msg, err } diff --git a/p2p/rlpx_test.go b/p2p/rlpx_test.go index 380d9aba68..b3c2adf8d2 100644 --- a/p2p/rlpx_test.go +++ b/p2p/rlpx_test.go @@ -16,14 +16,18 @@ import ( func TestRlpxFrameFake(t *testing.T) { buf := new(bytes.Buffer) - secret := crypto.Sha3() hash := fakeHash([]byte{1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}) - rw := newRlpxFrameRW(buf, secret, hash, hash) + rw := newRlpxFrameRW(buf, secrets{ + AES: crypto.Sha3(), + MAC: crypto.Sha3(), + IngressMAC: hash, + EgressMAC: hash, + }) golden := unhex(` -000006C2808000000000000000000000 +00828ddae471818bb0bfa6b551d1cb42 01010101010101010101010101010101 -08C40102030400000000000000000000 +ba628a4ba590cb43f7848f41c4382885 01010101010101010101010101010101 01010101010101010101010101010101 `) @@ -75,27 +79,35 @@ func unhex(str string) []byte { func TestRlpxFrameRW(t *testing.T) { var ( + aesSecret = make([]byte, 16) macSecret = make([]byte, 16) egressMACinit = make([]byte, 32) ingressMACinit = make([]byte, 32) ) - for _, s := range [][]byte{macSecret, egressMACinit, ingressMACinit} { + for _, s := range [][]byte{aesSecret, macSecret, egressMACinit, ingressMACinit} { rand.Read(s) } - conn := new(bytes.Buffer) - em1 := sha3.NewKeccak256() - em1.Write(egressMACinit) - im1 := sha3.NewKeccak256() - im1.Write(ingressMACinit) - rw1 := newRlpxFrameRW(conn, macSecret, em1, im1) - - em2 := sha3.NewKeccak256() - em2.Write(ingressMACinit) - im2 := sha3.NewKeccak256() - im2.Write(egressMACinit) - rw2 := newRlpxFrameRW(conn, macSecret, em2, im2) + s1 := secrets{ + AES: aesSecret, + MAC: macSecret, + EgressMAC: sha3.NewKeccak256(), + IngressMAC: sha3.NewKeccak256(), + } + s1.EgressMAC.Write(egressMACinit) + s1.IngressMAC.Write(ingressMACinit) + rw1 := newRlpxFrameRW(conn, s1) + + s2 := secrets{ + AES: aesSecret, + MAC: macSecret, + EgressMAC: sha3.NewKeccak256(), + IngressMAC: sha3.NewKeccak256(), + } + s2.EgressMAC.Write(ingressMACinit) + s2.IngressMAC.Write(egressMACinit) + rw2 := newRlpxFrameRW(conn, s2) // send some messages for i := 0; i < 10; i++ {