From 1959346793bdee469f68841843dd383cf801aba1 Mon Sep 17 00:00:00 2001 From: zelig Date: Fri, 3 Jul 2015 04:56:20 +0100 Subject: [PATCH] account update: migrate or change password * account.Update * KeyStore.Cleanup * fix dir rm for old format deleteKey --- accounts/account_manager.go | 28 ++++++++++++++ cmd/geth/main.go | 68 ++++++++++++++++++++++++++++++---- crypto/key_store_passphrase.go | 33 +++++++++-------- crypto/key_store_plain.go | 45 ++++++++++++++++++++-- 4 files changed, 149 insertions(+), 25 deletions(-) diff --git a/accounts/account_manager.go b/accounts/account_manager.go index eb2672a7de..17b128e9ef 100644 --- a/accounts/account_manager.go +++ b/accounts/account_manager.go @@ -36,6 +36,7 @@ import ( "crypto/ecdsa" crand "crypto/rand" "errors" + "fmt" "os" "sync" "time" @@ -158,6 +159,20 @@ func (am *Manager) NewAccount(auth string) (Account, error) { return Account{Address: key.Address}, nil } +func (am *Manager) AddressByIndex(index int) (addr string, err error) { + var addrs []common.Address + addrs, err = am.keyStore.GetKeyAddresses() + if err != nil { + return + } + if index < 0 || index >= len(addrs) { + err = fmt.Errorf("index out of range: %d (should be 0-%d)", index, len(addrs)-1) + } else { + addr = addrs[index].Hex() + } + return +} + func (am *Manager) Accounts() ([]Account, error) { addresses, err := am.keyStore.GetKeyAddresses() if os.IsNotExist(err) { @@ -204,6 +219,19 @@ func (am *Manager) Import(path string, keyAuth string) (Account, error) { return Account{Address: key.Address}, nil } +func (am *Manager) Update(addr common.Address, authFrom, authTo string) (err error) { + var key *crypto.Key + key, err = am.keyStore.GetKey(addr, authFrom) + + if err == nil { + err = am.keyStore.StoreKey(key, authTo) + if err == nil { + am.keyStore.Cleanup(addr) + } + } + return +} + func (am *Manager) ImportPreSaleKey(keyJSON []byte, password string) (acc Account, err error) { var key *crypto.Key key, err = crypto.ImportPreSaleKey(am.keyStore, keyJSON, password) diff --git a/cmd/geth/main.go b/cmd/geth/main.go index 673a08d452..ffd26a7c20 100644 --- a/cmd/geth/main.go +++ b/cmd/geth/main.go @@ -189,6 +189,33 @@ Note, this is meant to be used for testing only, it is a bad idea to save your password to file or expose in any other way. `, }, + { + Action: accountUpdate, + Name: "update", + Usage: "update an existing account", + Description: ` + + ethereum account update
+ +Update an existing account. + +The account is saved in the newest version in encrypted format, you are prompted +for a passphrase to unlock the account and another to save the updated file. + +This same command can therefore be used to migrate an account of a deprecated +format to the newest format or change the password for an account. + +For non-interactive use the passphrase can be specified with the --password flag: + + ethereum --password account new + +Since only one password can be given, only format update can be performed, +changing your password is only possible interactively. + +Note that account update has the a side effect that the order of your accounts +changes. + `, + }, { Action: accountImport, Name: "import", @@ -433,19 +460,30 @@ func execJSFiles(ctx *cli.Context) { ethereum.WaitForShutdown() } -func unlockAccount(ctx *cli.Context, am *accounts.Manager, account string, i int) { +func unlockAccount(ctx *cli.Context, am *accounts.Manager, addr string, i int) (addrHex, auth string) { var err error // Load startup keys. XXX we are going to need a different format - if !((len(account) == 40) || (len(account) == 42)) { // with or without 0x - utils.Fatalf("Invalid account address '%s'", account) + if !((len(addr) == 40) || (len(addr) == 42)) { // with or without 0x + var index int + index, err = strconv.Atoi(addr) + if err != nil { + utils.Fatalf("Invalid account address '%s'", addr) + } + + addrHex, err = am.AddressByIndex(index) + if err != nil { + utils.Fatalf("%v", err) + } + } else { + addrHex = addr } // Attempt to unlock the account 3 times attempts := 3 for tries := 0; tries < attempts; tries++ { - msg := fmt.Sprintf("Unlocking account %s | Attempt %d/%d", account, tries+1, attempts) - passphrase := getPassPhrase(ctx, msg, false, i) - err = am.Unlock(common.HexToAddress(account), passphrase) + msg := fmt.Sprintf("Unlocking account %s | Attempt %d/%d", addr, tries+1, attempts) + auth = getPassPhrase(ctx, msg, false, i) + err = am.Unlock(common.HexToAddress(addrHex), auth) if err == nil { break } @@ -453,7 +491,8 @@ func unlockAccount(ctx *cli.Context, am *accounts.Manager, account string, i int if err != nil { utils.Fatalf("Unlock account failed '%v'", err) } - fmt.Printf("Account '%s' unlocked.\n", account) + fmt.Printf("Account '%s' unlocked.\n", addr) + return } func blockRecovery(ctx *cli.Context) { @@ -578,6 +617,21 @@ func accountCreate(ctx *cli.Context) { fmt.Printf("Address: %x\n", acct) } +func accountUpdate(ctx *cli.Context) { + am := utils.MakeAccountManager(ctx) + arg := ctx.Args().First() + if len(arg) == 0 { + utils.Fatalf("account address or index must be given as argument") + } + + addr, authFrom := unlockAccount(ctx, am, arg, 0) + authTo := getPassPhrase(ctx, "Please give a new password. Do not forget this password.", true, 0) + err := am.Update(common.HexToAddress(addr), authFrom, authTo) + if err != nil { + utils.Fatalf("Could not update the account: %v", err) + } +} + func importWallet(ctx *cli.Context) { keyfile := ctx.Args().First() if len(keyfile) == 0 { diff --git a/crypto/key_store_passphrase.go b/crypto/key_store_passphrase.go index d26e3407f2..47909bc765 100644 --- a/crypto/key_store_passphrase.go +++ b/crypto/key_store_passphrase.go @@ -72,16 +72,19 @@ func (ks keyStorePassphrase) GenerateNewKey(rand io.Reader, auth string) (key *K } func (ks keyStorePassphrase) GetKey(keyAddr common.Address, auth string) (key *Key, err error) { - keyBytes, keyId, err := decryptKeyFromFile(ks, keyAddr, auth) - if err != nil { - return nil, err - } - key = &Key{ - Id: uuid.UUID(keyId), - Address: keyAddr, - PrivateKey: ToECDSA(keyBytes), + keyBytes, keyId, err := decryptKeyFromFile(ks.keysDirPath, keyAddr, auth) + if err == nil { + key = &Key{ + Id: uuid.UUID(keyId), + Address: keyAddr, + PrivateKey: ToECDSA(keyBytes), + } } - return key, err + return +} + +func (ks keyStorePassphrase) Cleanup(keyAddr common.Address) (err error) { + return cleanup(ks.keysDirPath, keyAddr) } func (ks keyStorePassphrase) GetKeyAddresses() (addresses []common.Address, err error) { @@ -142,7 +145,7 @@ func (ks keyStorePassphrase) StoreKey(key *Key, auth string) (err error) { func (ks keyStorePassphrase) DeleteKey(keyAddr common.Address, auth string) (err error) { // only delete if correct passphrase is given - _, _, err = decryptKeyFromFile(ks, keyAddr, auth) + _, _, err = decryptKeyFromFile(ks.keysDirPath, keyAddr, auth) if err != nil { return err } @@ -150,25 +153,25 @@ func (ks keyStorePassphrase) DeleteKey(keyAddr common.Address, auth string) (err return deleteKey(ks.keysDirPath, keyAddr) } -func decryptKeyFromFile(ks keyStorePassphrase, keyAddr common.Address, auth string) (keyBytes []byte, keyId []byte, err error) { +func decryptKeyFromFile(keysDirPath string, keyAddr common.Address, auth string) (keyBytes []byte, keyId []byte, err error) { + fmt.Printf("%v\n", keyAddr.Hex()) m := make(map[string]interface{}) - err = getKey(ks.keysDirPath, keyAddr, &m) + err = getKey(keysDirPath, keyAddr, &m) if err != nil { - fmt.Printf("get key error: %v\n", err) return } v := reflect.ValueOf(m["version"]) if v.Kind() == reflect.String && v.String() == "1" { k := new(encryptedKeyJSONV1) - getKey(ks.keysDirPath, keyAddr, &k) + err = getKey(keysDirPath, keyAddr, &k) if err != nil { return } return decryptKeyV1(k, auth) } else { k := new(encryptedKeyJSONV3) - getKey(ks.keysDirPath, keyAddr, &k) + err = getKey(keysDirPath, keyAddr, &k) if err != nil { return } diff --git a/crypto/key_store_plain.go b/crypto/key_store_plain.go index d785fdf68c..c13c5e7a48 100644 --- a/crypto/key_store_plain.go +++ b/crypto/key_store_plain.go @@ -43,6 +43,7 @@ type KeyStore interface { GetKeyAddresses() ([]common.Address, error) // get all addresses StoreKey(*Key, string) error // store key optionally using auth string DeleteKey(common.Address, string) error // delete key by addr and auth string + Cleanup(keyAddr common.Address) (err error) } type keyStorePlain struct { @@ -86,6 +87,10 @@ func (ks keyStorePlain) GetKeyAddresses() (addresses []common.Address, err error return getKeyAddresses(ks.keysDirPath) } +func (ks keyStorePlain) Cleanup(keyAddr common.Address) (err error) { + return cleanup(ks.keysDirPath, keyAddr) +} + func (ks keyStorePlain) StoreKey(key *Key, auth string) (err error) { keyJSON, err := json.Marshal(key) if err != nil { @@ -100,10 +105,14 @@ func (ks keyStorePlain) DeleteKey(keyAddr common.Address, auth string) (err erro } func deleteKey(keysDirPath string, keyAddr common.Address) (err error) { - var keyFilePath string - keyFilePath, err = getKeyFilePath(keysDirPath, keyAddr) + var path string + path, err = getKeyFilePath(keysDirPath, keyAddr) if err == nil { - err = os.Remove(keyFilePath) + addrHex := hex.EncodeToString(keyAddr[:]) + if path == filepath.Join(keysDirPath, addrHex, addrHex) { + path = filepath.Join(keysDirPath, addrHex) + } + err = os.RemoveAll(path) } return } @@ -122,6 +131,36 @@ func getKeyFilePath(keysDirPath string, keyAddr common.Address) (keyFilePath str return } +func cleanup(keysDirPath string, keyAddr common.Address) (err error) { + fileInfos, err := ioutil.ReadDir(keysDirPath) + if err != nil { + return + } + var paths []string + account := hex.EncodeToString(keyAddr[:]) + for _, fileInfo := range fileInfos { + path := filepath.Join(keysDirPath, fileInfo.Name()) + if len(path) >= 40 { + addr := path[len(path)-40 : len(path)] + if addr == account { + if path == filepath.Join(keysDirPath, addr, addr) { + path = filepath.Join(keysDirPath, addr) + } + paths = append(paths, path) + } + } + } + if len(paths) > 1 { + for i := 0; err == nil && i < len(paths)-1; i++ { + err = os.RemoveAll(paths[i]) + if err != nil { + break + } + } + } + return +} + func getKeyFile(keysDirPath string, keyAddr common.Address) (fileContent []byte, err error) { var keyFilePath string keyFilePath, err = getKeyFilePath(keysDirPath, keyAddr)