|
|
|
@ -45,33 +45,44 @@ type sshClient struct { |
|
|
|
|
|
|
|
|
|
// dial establishes an SSH connection to a remote node using the current user and
|
|
|
|
|
// the user's configured private RSA key. If that fails, password authentication
|
|
|
|
|
// is fallen back to. The caller may override the login user via user@server:port.
|
|
|
|
|
// is fallen back to. server can be a string like user:identity@server:port.
|
|
|
|
|
func dial(server string, pubkey []byte) (*sshClient, error) { |
|
|
|
|
// Figure out a label for the server and a logger
|
|
|
|
|
label := server |
|
|
|
|
if strings.Contains(label, ":") { |
|
|
|
|
label = label[:strings.Index(label, ":")] |
|
|
|
|
} |
|
|
|
|
login := "" |
|
|
|
|
// Figure out username, identity, hostname and port
|
|
|
|
|
hostname := "" |
|
|
|
|
hostport := server |
|
|
|
|
username := "" |
|
|
|
|
identity := "id_rsa" // default
|
|
|
|
|
|
|
|
|
|
if strings.Contains(server, "@") { |
|
|
|
|
login = label[:strings.Index(label, "@")] |
|
|
|
|
label = label[strings.Index(label, "@")+1:] |
|
|
|
|
server = server[strings.Index(server, "@")+1:] |
|
|
|
|
prefix := server[:strings.Index(server, "@")] |
|
|
|
|
if strings.Contains(prefix, ":") { |
|
|
|
|
username = prefix[:strings.Index(prefix, ":")] |
|
|
|
|
identity = prefix[strings.Index(prefix, ":")+1:] |
|
|
|
|
} else { |
|
|
|
|
username = prefix |
|
|
|
|
} |
|
|
|
|
hostport = server[strings.Index(server, "@")+1:] |
|
|
|
|
} |
|
|
|
|
logger := log.New("server", label) |
|
|
|
|
if strings.Contains(hostport, ":") { |
|
|
|
|
hostname = hostport[:strings.Index(hostport, ":")] |
|
|
|
|
} else { |
|
|
|
|
hostname = hostport |
|
|
|
|
hostport += ":22" |
|
|
|
|
} |
|
|
|
|
logger := log.New("server", server) |
|
|
|
|
logger.Debug("Attempting to establish SSH connection") |
|
|
|
|
|
|
|
|
|
user, err := user.Current() |
|
|
|
|
if err != nil { |
|
|
|
|
return nil, err |
|
|
|
|
} |
|
|
|
|
if login == "" { |
|
|
|
|
login = user.Username |
|
|
|
|
if username == "" { |
|
|
|
|
username = user.Username |
|
|
|
|
} |
|
|
|
|
// Configure the supported authentication methods (private key and password)
|
|
|
|
|
var auths []ssh.AuthMethod |
|
|
|
|
|
|
|
|
|
path := filepath.Join(user.HomeDir, ".ssh", "id_rsa") |
|
|
|
|
path := filepath.Join(user.HomeDir, ".ssh", identity) |
|
|
|
|
if buf, err := ioutil.ReadFile(path); err != nil { |
|
|
|
|
log.Warn("No SSH key, falling back to passwords", "path", path, "err", err) |
|
|
|
|
} else { |
|
|
|
@ -94,14 +105,14 @@ func dial(server string, pubkey []byte) (*sshClient, error) { |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
auths = append(auths, ssh.PasswordCallback(func() (string, error) { |
|
|
|
|
fmt.Printf("What's the login password for %s at %s? (won't be echoed)\n> ", login, server) |
|
|
|
|
fmt.Printf("What's the login password for %s at %s? (won't be echoed)\n> ", username, server) |
|
|
|
|
blob, err := terminal.ReadPassword(int(os.Stdin.Fd())) |
|
|
|
|
|
|
|
|
|
fmt.Println() |
|
|
|
|
return string(blob), err |
|
|
|
|
})) |
|
|
|
|
// Resolve the IP address of the remote server
|
|
|
|
|
addr, err := net.LookupHost(label) |
|
|
|
|
addr, err := net.LookupHost(hostname) |
|
|
|
|
if err != nil { |
|
|
|
|
return nil, err |
|
|
|
|
} |
|
|
|
@ -109,10 +120,7 @@ func dial(server string, pubkey []byte) (*sshClient, error) { |
|
|
|
|
return nil, errors.New("no IPs associated with domain") |
|
|
|
|
} |
|
|
|
|
// Try to dial in to the remote server
|
|
|
|
|
logger.Trace("Dialing remote SSH server", "user", login) |
|
|
|
|
if !strings.Contains(server, ":") { |
|
|
|
|
server += ":22" |
|
|
|
|
} |
|
|
|
|
logger.Trace("Dialing remote SSH server", "user", username) |
|
|
|
|
keycheck := func(hostname string, remote net.Addr, key ssh.PublicKey) error { |
|
|
|
|
// If no public key is known for SSH, ask the user to confirm
|
|
|
|
|
if pubkey == nil { |
|
|
|
@ -139,13 +147,13 @@ func dial(server string, pubkey []byte) (*sshClient, error) { |
|
|
|
|
// We have a mismatch, forbid connecting
|
|
|
|
|
return errors.New("ssh key mismatch, readd the machine to update") |
|
|
|
|
} |
|
|
|
|
client, err := ssh.Dial("tcp", server, &ssh.ClientConfig{User: login, Auth: auths, HostKeyCallback: keycheck}) |
|
|
|
|
client, err := ssh.Dial("tcp", hostport, &ssh.ClientConfig{User: username, Auth: auths, HostKeyCallback: keycheck}) |
|
|
|
|
if err != nil { |
|
|
|
|
return nil, err |
|
|
|
|
} |
|
|
|
|
// Connection established, return our utility wrapper
|
|
|
|
|
c := &sshClient{ |
|
|
|
|
server: label, |
|
|
|
|
server: hostname, |
|
|
|
|
address: addr[0], |
|
|
|
|
pubkey: pubkey, |
|
|
|
|
client: client, |
|
|
|
|