mirror of https://github.com/ethereum/go-ethereum
cmd/puppeth: your Ethereum private network manager (#13854)
parent
18bbe12425
commit
706a1e552c
@ -0,0 +1,456 @@ |
||||
// Copyright 2017 The go-ethereum Authors
|
||||
// This file is part of go-ethereum.
|
||||
//
|
||||
// go-ethereum is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// go-ethereum is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU General Public License
|
||||
// along with go-ethereum. If not, see <http://www.gnu.org/licenses/>.
|
||||
|
||||
// faucet is a Ether faucet backed by a light client.
|
||||
package main |
||||
|
||||
//go:generate go-bindata -nometadata -o website.go faucet.html
|
||||
|
||||
import ( |
||||
"bytes" |
||||
"context" |
||||
"encoding/json" |
||||
"flag" |
||||
"fmt" |
||||
"html/template" |
||||
"io/ioutil" |
||||
"math/big" |
||||
"net/http" |
||||
"os" |
||||
"path/filepath" |
||||
"strings" |
||||
"sync" |
||||
"time" |
||||
|
||||
"github.com/ethereum/go-ethereum/accounts" |
||||
"github.com/ethereum/go-ethereum/accounts/keystore" |
||||
"github.com/ethereum/go-ethereum/common" |
||||
"github.com/ethereum/go-ethereum/core" |
||||
"github.com/ethereum/go-ethereum/core/types" |
||||
"github.com/ethereum/go-ethereum/eth" |
||||
"github.com/ethereum/go-ethereum/ethclient" |
||||
"github.com/ethereum/go-ethereum/ethstats" |
||||
"github.com/ethereum/go-ethereum/les" |
||||
"github.com/ethereum/go-ethereum/log" |
||||
"github.com/ethereum/go-ethereum/node" |
||||
"github.com/ethereum/go-ethereum/p2p/discover" |
||||
"github.com/ethereum/go-ethereum/p2p/discv5" |
||||
"github.com/ethereum/go-ethereum/p2p/nat" |
||||
"github.com/ethereum/go-ethereum/params" |
||||
"golang.org/x/net/websocket" |
||||
) |
||||
|
||||
var ( |
||||
genesisFlag = flag.String("genesis", "", "Genesis json file to seed the chain with") |
||||
apiPortFlag = flag.Int("apiport", 8080, "Listener port for the HTTP API connection") |
||||
ethPortFlag = flag.Int("ethport", 30303, "Listener port for the devp2p connection") |
||||
bootFlag = flag.String("bootnodes", "", "Comma separated bootnode enode URLs to seed with") |
||||
netFlag = flag.Int("network", 0, "Network ID to use for the Ethereum protocol") |
||||
statsFlag = flag.String("ethstats", "", "Ethstats network monitoring auth string") |
||||
|
||||
netnameFlag = flag.String("faucet.name", "", "Network name to assign to the faucet") |
||||
payoutFlag = flag.Int("faucet.amount", 1, "Number of Ethers to pay out per user request") |
||||
minutesFlag = flag.Int("faucet.minutes", 1440, "Number of minutes to wait between funding rounds") |
||||
|
||||
accJSONFlag = flag.String("account.json", "", "Key json file to fund user requests with") |
||||
accPassFlag = flag.String("account.pass", "", "Decryption password to access faucet funds") |
||||
|
||||
githubUser = flag.String("github.user", "", "GitHub user to authenticate with for Gist access") |
||||
githubToken = flag.String("github.token", "", "GitHub personal token to access Gists with") |
||||
|
||||
logFlag = flag.Int("loglevel", 3, "Log level to use for Ethereum and the faucet") |
||||
) |
||||
|
||||
var ( |
||||
ether = new(big.Int).Exp(big.NewInt(10), big.NewInt(18), nil) |
||||
) |
||||
|
||||
func main() { |
||||
// Parse the flags and set up the logger to print everything requested
|
||||
flag.Parse() |
||||
log.Root().SetHandler(log.LvlFilterHandler(log.Lvl(*logFlag), log.StreamHandler(os.Stderr, log.TerminalFormat(true)))) |
||||
|
||||
// Load up and render the faucet website
|
||||
tmpl, err := Asset("faucet.html") |
||||
if err != nil { |
||||
log.Crit("Failed to load the faucet template", "err", err) |
||||
} |
||||
period := fmt.Sprintf("%d minute(s)", *minutesFlag) |
||||
if *minutesFlag%60 == 0 { |
||||
period = fmt.Sprintf("%d hour(s)", *minutesFlag/60) |
||||
} |
||||
website := new(bytes.Buffer) |
||||
template.Must(template.New("").Parse(string(tmpl))).Execute(website, map[string]interface{}{ |
||||
"Network": *netnameFlag, |
||||
"Amount": *payoutFlag, |
||||
"Period": period, |
||||
}) |
||||
// Load and parse the genesis block requested by the user
|
||||
blob, err := ioutil.ReadFile(*genesisFlag) |
||||
if err != nil { |
||||
log.Crit("Failed to read genesis block contents", "genesis", *genesisFlag, "err", err) |
||||
} |
||||
genesis := new(core.Genesis) |
||||
if err = json.Unmarshal(blob, genesis); err != nil { |
||||
log.Crit("Failed to parse genesis block json", "err", err) |
||||
} |
||||
// Convert the bootnodes to internal enode representations
|
||||
var enodes []*discv5.Node |
||||
for _, boot := range strings.Split(*bootFlag, ",") { |
||||
if url, err := discv5.ParseNode(boot); err == nil { |
||||
enodes = append(enodes, url) |
||||
} else { |
||||
log.Error("Failed to parse bootnode URL", "url", boot, "err", err) |
||||
} |
||||
} |
||||
// Load up the account key and decrypt its password
|
||||
if blob, err = ioutil.ReadFile(*accPassFlag); err != nil { |
||||
log.Crit("Failed to read account password contents", "file", *accPassFlag, "err", err) |
||||
} |
||||
pass := string(blob) |
||||
|
||||
ks := keystore.NewKeyStore(filepath.Join(os.Getenv("HOME"), ".faucet", "keys"), keystore.StandardScryptN, keystore.StandardScryptP) |
||||
if blob, err = ioutil.ReadFile(*accJSONFlag); err != nil { |
||||
log.Crit("Failed to read account key contents", "file", *accJSONFlag, "err", err) |
||||
} |
||||
acc, err := ks.Import(blob, pass, pass) |
||||
if err != nil { |
||||
log.Crit("Failed to import faucet signer account", "err", err) |
||||
} |
||||
ks.Unlock(acc, pass) |
||||
|
||||
// Assemble and start the faucet light service
|
||||
faucet, err := newFaucet(genesis, *ethPortFlag, enodes, *netFlag, *statsFlag, ks, website.Bytes()) |
||||
if err != nil { |
||||
log.Crit("Failed to start faucet", "err", err) |
||||
} |
||||
defer faucet.close() |
||||
|
||||
if err := faucet.listenAndServe(*apiPortFlag); err != nil { |
||||
log.Crit("Failed to launch faucet API", "err", err) |
||||
} |
||||
} |
||||
|
||||
// request represents an accepted funding request.
|
||||
type request struct { |
||||
Username string `json:"username"` // GitHub user for displaying an avatar
|
||||
Account common.Address `json:"account"` // Ethereum address being funded
|
||||
Time time.Time `json:"time"` // Timestamp when te request was accepted
|
||||
Tx *types.Transaction `json:"tx"` // Transaction funding the account
|
||||
} |
||||
|
||||
// faucet represents a crypto faucet backed by an Ethereum light client.
|
||||
type faucet struct { |
||||
config *params.ChainConfig // Chain configurations for signing
|
||||
stack *node.Node // Ethereum protocol stack
|
||||
client *ethclient.Client // Client connection to the Ethereum chain
|
||||
index []byte // Index page to serve up on the web
|
||||
|
||||
keystore *keystore.KeyStore // Keystore containing the single signer
|
||||
account accounts.Account // Account funding user faucet requests
|
||||
nonce uint64 // Current pending nonce of the faucet
|
||||
price *big.Int // Current gas price to issue funds with
|
||||
|
||||
conns []*websocket.Conn // Currently live websocket connections
|
||||
history map[string]time.Time // History of users and their funding requests
|
||||
reqs []*request // Currently pending funding requests
|
||||
update chan struct{} // Channel to signal request updates
|
||||
|
||||
lock sync.RWMutex // Lock protecting the faucet's internals
|
||||
} |
||||
|
||||
func newFaucet(genesis *core.Genesis, port int, enodes []*discv5.Node, network int, stats string, ks *keystore.KeyStore, index []byte) (*faucet, error) { |
||||
// Assemble the raw devp2p protocol stack
|
||||
stack, err := node.New(&node.Config{ |
||||
Name: "geth", |
||||
Version: params.Version, |
||||
DataDir: filepath.Join(os.Getenv("HOME"), ".faucet"), |
||||
NAT: nat.Any(), |
||||
DiscoveryV5: true, |
||||
ListenAddr: fmt.Sprintf(":%d", port), |
||||
DiscoveryV5Addr: fmt.Sprintf(":%d", port+1), |
||||
MaxPeers: 25, |
||||
BootstrapNodesV5: enodes, |
||||
}) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
// Assemble the Ethereum light client protocol
|
||||
if err := stack.Register(func(ctx *node.ServiceContext) (node.Service, error) { |
||||
return les.New(ctx, ð.Config{ |
||||
LightMode: true, |
||||
NetworkId: network, |
||||
Genesis: genesis, |
||||
GasPrice: big.NewInt(20 * params.Shannon), |
||||
GpoBlocks: 10, |
||||
GpoPercentile: 50, |
||||
EthashCacheDir: "ethash", |
||||
EthashCachesInMem: 2, |
||||
EthashCachesOnDisk: 3, |
||||
}) |
||||
}); err != nil { |
||||
return nil, err |
||||
} |
||||
// Assemble the ethstats monitoring and reporting service'
|
||||
if stats != "" { |
||||
if err := stack.Register(func(ctx *node.ServiceContext) (node.Service, error) { |
||||
var serv *les.LightEthereum |
||||
ctx.Service(&serv) |
||||
return ethstats.New(stats, nil, serv) |
||||
}); err != nil { |
||||
return nil, err |
||||
} |
||||
} |
||||
// Boot up the client and ensure it connects to bootnodes
|
||||
if err := stack.Start(); err != nil { |
||||
return nil, err |
||||
} |
||||
for _, boot := range enodes { |
||||
old, _ := discover.ParseNode(boot.String()) |
||||
stack.Server().AddPeer(old) |
||||
} |
||||
// Attach to the client and retrieve and interesting metadatas
|
||||
api, err := stack.Attach() |
||||
if err != nil { |
||||
stack.Stop() |
||||
return nil, err |
||||
} |
||||
client := ethclient.NewClient(api) |
||||
|
||||
return &faucet{ |
||||
config: genesis.Config, |
||||
stack: stack, |
||||
client: client, |
||||
index: index, |
||||
keystore: ks, |
||||
account: ks.Accounts()[0], |
||||
history: make(map[string]time.Time), |
||||
update: make(chan struct{}, 1), |
||||
}, nil |
||||
} |
||||
|
||||
// close terminates the Ethereum connection and tears down the faucet.
|
||||
func (f *faucet) close() error { |
||||
return f.stack.Stop() |
||||
} |
||||
|
||||
// listenAndServe registers the HTTP handlers for the faucet and boots it up
|
||||
// for service user funding requests.
|
||||
func (f *faucet) listenAndServe(port int) error { |
||||
go f.loop() |
||||
|
||||
http.HandleFunc("/", f.webHandler) |
||||
http.Handle("/api", websocket.Handler(f.apiHandler)) |
||||
|
||||
return http.ListenAndServe(fmt.Sprintf(":%d", port), nil) |
||||
} |
||||
|
||||
// webHandler handles all non-api requests, simply flattening and returning the
|
||||
// faucet website.
|
||||
func (f *faucet) webHandler(w http.ResponseWriter, r *http.Request) { |
||||
w.Write(f.index) |
||||
} |
||||
|
||||
// apiHandler handles requests for Ether grants and transaction statuses.
|
||||
func (f *faucet) apiHandler(conn *websocket.Conn) { |
||||
// Start tracking the connection and drop at the end
|
||||
f.lock.Lock() |
||||
f.conns = append(f.conns, conn) |
||||
f.lock.Unlock() |
||||
|
||||
defer func() { |
||||
f.lock.Lock() |
||||
for i, c := range f.conns { |
||||
if c == conn { |
||||
f.conns = append(f.conns[:i], f.conns[i+1:]...) |
||||
break |
||||
} |
||||
} |
||||
f.lock.Unlock() |
||||
}() |
||||
// Send a few initial stats to the client
|
||||
balance, _ := f.client.BalanceAt(context.Background(), f.account.Address, nil) |
||||
nonce, _ := f.client.NonceAt(context.Background(), f.account.Address, nil) |
||||
|
||||
websocket.JSON.Send(conn, map[string]interface{}{ |
||||
"funds": balance.Div(balance, ether), |
||||
"funded": nonce, |
||||
"peers": f.stack.Server().PeerCount(), |
||||
"requests": f.reqs, |
||||
}) |
||||
header, _ := f.client.HeaderByNumber(context.Background(), nil) |
||||
websocket.JSON.Send(conn, header) |
||||
|
||||
// Keep reading requests from the websocket until the connection breaks
|
||||
for { |
||||
// Fetch the next funding request and validate against github
|
||||
var msg struct { |
||||
URL string `json:"url"` |
||||
} |
||||
if err := websocket.JSON.Receive(conn, &msg); err != nil { |
||||
return |
||||
} |
||||
if !strings.HasPrefix(msg.URL, "https://gist.github.com/") { |
||||
websocket.JSON.Send(conn, map[string]string{"error": "URL doesn't link to GitHub Gists"}) |
||||
continue |
||||
} |
||||
log.Info("Faucet funds requested", "gist", msg.URL) |
||||
|
||||
// Retrieve the gist from the GitHub Gist APIs
|
||||
parts := strings.Split(msg.URL, "/") |
||||
req, _ := http.NewRequest("GET", "https://api.github.com/gists/"+parts[len(parts)-1], nil) |
||||
if *githubUser != "" { |
||||
req.SetBasicAuth(*githubUser, *githubToken) |
||||
} |
||||
res, err := http.DefaultClient.Do(req) |
||||
if err != nil { |
||||
websocket.JSON.Send(conn, map[string]string{"error": err.Error()}) |
||||
continue |
||||
} |
||||
var gist struct { |
||||
Owner struct { |
||||
Login string `json:"login"` |
||||
} `json:"owner"` |
||||
Files map[string]struct { |
||||
Content string `json:"content"` |
||||
} `json:"files"` |
||||
} |
||||
err = json.NewDecoder(res.Body).Decode(&gist) |
||||
res.Body.Close() |
||||
if err != nil { |
||||
websocket.JSON.Send(conn, map[string]string{"error": err.Error()}) |
||||
continue |
||||
} |
||||
if gist.Owner.Login == "" { |
||||
websocket.JSON.Send(conn, map[string]string{"error": "Nice try ;)"}) |
||||
continue |
||||
} |
||||
// Iterate over all the files and look for Ethereum addresses
|
||||
var address common.Address |
||||
for _, file := range gist.Files { |
||||
if len(file.Content) == 2+common.AddressLength*2 { |
||||
address = common.HexToAddress(file.Content) |
||||
} |
||||
} |
||||
if address == (common.Address{}) { |
||||
websocket.JSON.Send(conn, map[string]string{"error": "No Ethereum address found to fund"}) |
||||
continue |
||||
} |
||||
// Ensure the user didn't request funds too recently
|
||||
f.lock.Lock() |
||||
var ( |
||||
fund bool |
||||
elapsed time.Duration |
||||
) |
||||
if elapsed = time.Since(f.history[gist.Owner.Login]); elapsed > time.Duration(*minutesFlag)*time.Minute { |
||||
// User wasn't funded recently, create the funding transaction
|
||||
tx := types.NewTransaction(f.nonce+uint64(len(f.reqs)), address, new(big.Int).Mul(big.NewInt(int64(*payoutFlag)), ether), big.NewInt(21000), f.price, nil) |
||||
signed, err := f.keystore.SignTx(f.account, tx, f.config.ChainId) |
||||
if err != nil { |
||||
websocket.JSON.Send(conn, map[string]string{"error": err.Error()}) |
||||
f.lock.Unlock() |
||||
continue |
||||
} |
||||
// Submit the transaction and mark as funded if successful
|
||||
if err := f.client.SendTransaction(context.Background(), signed); err != nil { |
||||
websocket.JSON.Send(conn, map[string]string{"error": err.Error()}) |
||||
f.lock.Unlock() |
||||
continue |
||||
} |
||||
f.reqs = append(f.reqs, &request{ |
||||
Username: gist.Owner.Login, |
||||
Account: address, |
||||
Time: time.Now(), |
||||
Tx: signed, |
||||
}) |
||||
f.history[gist.Owner.Login] = time.Now() |
||||
fund = true |
||||
} |
||||
f.lock.Unlock() |
||||
|
||||
// Send an error if too frequent funding, othewise a success
|
||||
if !fund { |
||||
websocket.JSON.Send(conn, map[string]string{"error": fmt.Sprintf("User already funded %s ago", common.PrettyDuration(elapsed))}) |
||||
continue |
||||
} |
||||
websocket.JSON.Send(conn, map[string]string{"success": fmt.Sprintf("Funding request accepted for %s into %s", gist.Owner.Login, address.Hex())}) |
||||
select { |
||||
case f.update <- struct{}{}: |
||||
default: |
||||
} |
||||
} |
||||
} |
||||
|
||||
// loop keeps waiting for interesting events and pushes them out to connected
|
||||
// websockets.
|
||||
func (f *faucet) loop() { |
||||
// Wait for chain events and push them to clients
|
||||
heads := make(chan *types.Header, 16) |
||||
sub, err := f.client.SubscribeNewHead(context.Background(), heads) |
||||
if err != nil { |
||||
log.Crit("Failed to subscribe to head events", "err", err) |
||||
} |
||||
defer sub.Unsubscribe() |
||||
|
||||
for { |
||||
select { |
||||
case head := <-heads: |
||||
// New chain head arrived, query the current stats and stream to clients
|
||||
balance, _ := f.client.BalanceAt(context.Background(), f.account.Address, nil) |
||||
balance = new(big.Int).Div(balance, ether) |
||||
|
||||
price, _ := f.client.SuggestGasPrice(context.Background()) |
||||
nonce, _ := f.client.NonceAt(context.Background(), f.account.Address, nil) |
||||
|
||||
f.lock.Lock() |
||||
f.price, f.nonce = price, nonce |
||||
for len(f.reqs) > 0 && f.reqs[0].Tx.Nonce() < f.nonce { |
||||
f.reqs = f.reqs[1:] |
||||
} |
||||
f.lock.Unlock() |
||||
|
||||
f.lock.RLock() |
||||
for _, conn := range f.conns { |
||||
if err := websocket.JSON.Send(conn, map[string]interface{}{ |
||||
"funds": balance, |
||||
"funded": f.nonce, |
||||
"peers": f.stack.Server().PeerCount(), |
||||
"requests": f.reqs, |
||||
}); err != nil { |
||||
log.Warn("Failed to send stats to client", "err", err) |
||||
conn.Close() |
||||
continue |
||||
} |
||||
if err := websocket.JSON.Send(conn, head); err != nil { |
||||
log.Warn("Failed to send header to client", "err", err) |
||||
conn.Close() |
||||
} |
||||
} |
||||
f.lock.RUnlock() |
||||
|
||||
case <-f.update: |
||||
// Pending requests updated, stream to clients
|
||||
f.lock.RLock() |
||||
for _, conn := range f.conns { |
||||
if err := websocket.JSON.Send(conn, map[string]interface{}{"requests": f.reqs}); err != nil { |
||||
log.Warn("Failed to send requests to client", "err", err) |
||||
conn.Close() |
||||
} |
||||
} |
||||
f.lock.RUnlock() |
||||
} |
||||
} |
||||
} |
@ -0,0 +1,143 @@ |
||||
<!DOCTYPE html> |
||||
<html lang="en"> |
||||
<head> |
||||
<meta charset="utf-8"> |
||||
<meta http-equiv="X-UA-Compatible" content="IE=edge"> |
||||
<meta name="viewport" content="width=device-width, initial-scale=1"> |
||||
|
||||
<title>{{.Network}}: GitHub Faucet</title> |
||||
|
||||
<link href="https://cdnjs.cloudflare.com/ajax/libs/twitter-bootstrap/3.3.7/css/bootstrap.min.css" rel="stylesheet" /> |
||||
<link href="https://cdnjs.cloudflare.com/ajax/libs/font-awesome/4.7.0/css/font-awesome.min.css" rel="stylesheet" /> |
||||
|
||||
<script src="https://cdnjs.cloudflare.com/ajax/libs/jquery/3.1.1/jquery.min.js"></script> |
||||
<script src="https://cdnjs.cloudflare.com/ajax/libs/jquery-noty/2.4.1/packaged/jquery.noty.packaged.min.js"></script> |
||||
<script src="https://cdnjs.cloudflare.com/ajax/libs/twitter-bootstrap/3.3.7/js/bootstrap.min.js"></script> |
||||
<script src="https://cdnjs.cloudflare.com/ajax/libs/moment.js/2.18.0/moment.min.js"></script> |
||||
|
||||
<style> |
||||
.vertical-center { |
||||
min-height: 100%; |
||||
min-height: 100vh; |
||||
display: flex; |
||||
align-items: center; |
||||
} |
||||
.progress { |
||||
position: relative; |
||||
} |
||||
.progress span { |
||||
position: absolute; |
||||
display: block; |
||||
width: 100%; |
||||
color: white; |
||||
} |
||||
pre { |
||||
padding: 6px; |
||||
margin: 0; |
||||
} |
||||
</style> |
||||
</head> |
||||
|
||||
<body> |
||||
<div class="vertical-center"> |
||||
<div class="container"> |
||||
<div class="row" style="margin-bottom: 16px;"> |
||||
<div class="col-lg-12"> |
||||
<h1 style="text-align: center;"><i class="fa fa-bath" aria-hidden="true"></i> {{.Network}} GitHub Authenticated Faucet <i class="fa fa-github-alt" aria-hidden="true"></i></h1> |
||||
</div> |
||||
</div> |
||||
<div class="row"> |
||||
<div class="col-lg-8 col-lg-offset-2"> |
||||
<div class="input-group"> |
||||
<input id="gist" type="text" class="form-control" placeholder="GitHub Gist URL containing your Ethereum address..."> |
||||
<span class="input-group-btn"> |
||||
<button class="btn btn-default" type="button" onclick="submit()">Give me Ether!</button> |
||||
</span> |
||||
</div> |
||||
</div> |
||||
</div> |
||||
<div class="row" style="margin-top: 32px;"> |
||||
<div class="col-lg-6 col-lg-offset-3"> |
||||
<div class="panel panel-small panel-default"> |
||||
<div class="panel-body" style="padding: 0; overflow: auto; max-height: 300px;"> |
||||
<table id="requests" class="table table-condensed" style="margin: 0;"></table> |
||||
</div> |
||||
<div class="panel-footer"> |
||||
<table style="width: 100%"><tr> |
||||
<td style="text-align: center;"><i class="fa fa-rss" aria-hidden="true"></i> <span id="peers"></span> peers</td> |
||||
<td style="text-align: center;"><i class="fa fa-database" aria-hidden="true"></i> <span id="block"></span> blocks</td> |
||||
<td style="text-align: center;"><i class="fa fa-heartbeat" aria-hidden="true"></i> <span id="funds"></span> Ethers</td> |
||||
<td style="text-align: center;"><i class="fa fa-university" aria-hidden="true"></i> <span id="funded"></span> funded</td> |
||||
</tr></table> |
||||
</div> |
||||
</div> |
||||
</div> |
||||
</div> |
||||
<div class="row" style="margin-top: 32px;"> |
||||
<div class="col-lg-12"> |
||||
<h3>How does this work?</h3> |
||||
<p>This Ether faucet is running on the {{.Network}} network. To prevent malicious actors from exhausting all available funds or accumulating enough Ether to mount long running spam attacks, requests are tied to GitHub accounts. Anyone having a GitHub account may request funds within the permitted limit of <strong>{{.Amount}} Ether(s) / {{.Period}}</strong>.</p> |
||||
<p>To request funds, simply create a <a href="https://gist.github.com/" target="_about:blank">GitHub Gist</a> with your Ethereum address pasted into the contents (the file name doesn't matter), copy paste the gists URL into the above input box and fire away! You can track the current pending requests below the input field to see how much you have to wait until your turn comes.</p> |
||||
</div> |
||||
</div> |
||||
</div> |
||||
</div> |
||||
|
||||
<script> |
||||
// Global variables to hold the current status of the faucet |
||||
var attempt = 0; |
||||
var server; |
||||
|
||||
// Define the function that submits a gist url to the server |
||||
var submit = function() { |
||||
server.send(JSON.stringify({url: $("#gist")[0].value})); |
||||
}; |
||||
// Define a method to reconnect upon server loss |
||||
var reconnect = function() { |
||||
if (attempt % 2 == 0) { |
||||
server = new WebSocket("wss://" + location.host + "/api"); |
||||
} else { |
||||
server = new WebSocket("ws://" + location.host + "/api"); |
||||
} |
||||
attempt++; |
||||
|
||||
server.onmessage = function(event) { |
||||
var msg = JSON.parse(event.data); |
||||
if (msg === null) { |
||||
return; |
||||
} |
||||
|
||||
if (msg.funds !== undefined) { |
||||
$("#funds").text(msg.funds); |
||||
} |
||||
if (msg.funded !== undefined) { |
||||
$("#funded").text(msg.funded); |
||||
} |
||||
if (msg.peers !== undefined) { |
||||
$("#peers").text(msg.peers); |
||||
} |
||||
if (msg.number !== undefined) { |
||||
$("#block").text(parseInt(msg.number, 16)); |
||||
} |
||||
if (msg.error !== undefined) { |
||||
noty({layout: 'topCenter', text: msg.error, type: 'error'}); |
||||
} |
||||
if (msg.success !== undefined) { |
||||
noty({layout: 'topCenter', text: msg.success, type: 'success'}); |
||||
} |
||||
if (msg.requests !== undefined && msg.requests !== null) { |
||||
var content = ""; |
||||
for (var i=0; i<msg.requests.length; i++) { |
||||
content += "<tr><td><div style=\"background: url('https://github.com/" + msg.requests[i].username + ".png?size=64'); background-size: cover; width:32px; height: 32px; border-radius: 4px;\"></div></td><td><pre>" + msg.requests[i].account + "</pre></td><td style=\"width: 100%; text-align: center; vertical-align: middle;\">" + moment.duration(moment(msg.requests[i].time).unix()-moment().unix(), 'seconds').humanize(true) + "</td></tr>"; |
||||
} |
||||
$("#requests").html("<tbody>" + content + "</tbody>"); |
||||
} |
||||
} |
||||
server.onclose = function() { setTimeout(reconnect, 3000); }; |
||||
server.onerror = function() { setTimeout(reconnect, 3000); }; |
||||
} |
||||
// Establish a websocket connection to the API server |
||||
reconnect(); |
||||
</script> |
||||
</body> |
||||
</html> |
File diff suppressed because one or more lines are too long
@ -0,0 +1,152 @@ |
||||
// Copyright 2017 The go-ethereum Authors
|
||||
// This file is part of go-ethereum.
|
||||
//
|
||||
// go-ethereum is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// go-ethereum is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU General Public License
|
||||
// along with go-ethereum. If not, see <http://www.gnu.org/licenses/>.
|
||||
|
||||
package main |
||||
|
||||
import ( |
||||
"encoding/json" |
||||
"errors" |
||||
"fmt" |
||||
"net" |
||||
"strconv" |
||||
"strings" |
||||
"time" |
||||
|
||||
"github.com/ethereum/go-ethereum/log" |
||||
) |
||||
|
||||
var ( |
||||
// ErrServiceUnknown is returned when a service container doesn't exist.
|
||||
ErrServiceUnknown = errors.New("service unknown") |
||||
|
||||
// ErrServiceOffline is returned when a service container exists, but it is not
|
||||
// running.
|
||||
ErrServiceOffline = errors.New("service offline") |
||||
|
||||
// ErrServiceUnreachable is returned when a service container is running, but
|
||||
// seems to not respond to communication attempts.
|
||||
ErrServiceUnreachable = errors.New("service unreachable") |
||||
|
||||
// ErrNotExposed is returned if a web-service doesn't have an exposed port, nor
|
||||
// a reverse-proxy in front of it to forward requests.
|
||||
ErrNotExposed = errors.New("service not exposed, nor proxied") |
||||
) |
||||
|
||||
// containerInfos is a heavily reduced version of the huge inspection dataset
|
||||
// returned from docker inspect, parsed into a form easily usable by puppeth.
|
||||
type containerInfos struct { |
||||
running bool // Flag whether the container is running currently
|
||||
envvars map[string]string // Collection of environmental variables set on the container
|
||||
portmap map[string]int // Port mapping from internal port/proto combos to host binds
|
||||
volumes map[string]string // Volume mount points from container to host directories
|
||||
} |
||||
|
||||
// inspectContainer runs docker inspect against a running container
|
||||
func inspectContainer(client *sshClient, container string) (*containerInfos, error) { |
||||
// Check whether there's a container running for the service
|
||||
out, err := client.Run(fmt.Sprintf("docker inspect %s", container)) |
||||
if err != nil { |
||||
return nil, ErrServiceUnknown |
||||
} |
||||
// If yes, extract various configuration options
|
||||
type inspection struct { |
||||
State struct { |
||||
Running bool |
||||
} |
||||
Mounts []struct { |
||||
Source string |
||||
Destination string |
||||
} |
||||
Config struct { |
||||
Env []string |
||||
} |
||||
HostConfig struct { |
||||
PortBindings map[string][]map[string]string |
||||
} |
||||
} |
||||
var inspects []inspection |
||||
if err = json.Unmarshal(out, &inspects); err != nil { |
||||
return nil, err |
||||
} |
||||
inspect := inspects[0] |
||||
|
||||
// Infos retrieved, parse the above into something meaningful
|
||||
infos := &containerInfos{ |
||||
running: inspect.State.Running, |
||||
envvars: make(map[string]string), |
||||
portmap: make(map[string]int), |
||||
volumes: make(map[string]string), |
||||
} |
||||
for _, envvar := range inspect.Config.Env { |
||||
if parts := strings.Split(envvar, "="); len(parts) == 2 { |
||||
infos.envvars[parts[0]] = parts[1] |
||||
} |
||||
} |
||||
for portname, details := range inspect.HostConfig.PortBindings { |
||||
if len(details) > 0 { |
||||
port, _ := strconv.Atoi(details[0]["HostPort"]) |
||||
infos.portmap[portname] = port |
||||
} |
||||
} |
||||
for _, mount := range inspect.Mounts { |
||||
infos.volumes[mount.Destination] = mount.Source |
||||
} |
||||
return infos, err |
||||
} |
||||
|
||||
// tearDown connects to a remote machine via SSH and terminates docker containers
|
||||
// running with the specified name in the specified network.
|
||||
func tearDown(client *sshClient, network string, service string, purge bool) ([]byte, error) { |
||||
// Tear down the running (or paused) container
|
||||
out, err := client.Run(fmt.Sprintf("docker rm -f %s_%s_1", network, service)) |
||||
if err != nil { |
||||
return out, err |
||||
} |
||||
// If requested, purge the associated docker image too
|
||||
if purge { |
||||
return client.Run(fmt.Sprintf("docker rmi %s/%s", network, service)) |
||||
} |
||||
return nil, nil |
||||
} |
||||
|
||||
// resolve retrieves the hostname a service is running on either by returning the
|
||||
// actual server name and port, or preferably an nginx virtual host if available.
|
||||
func resolve(client *sshClient, network string, service string, port int) (string, error) { |
||||
// Inspect the service to get various configurations from it
|
||||
infos, err := inspectContainer(client, fmt.Sprintf("%s_%s_1", network, service)) |
||||
if err != nil { |
||||
return "", err |
||||
} |
||||
if !infos.running { |
||||
return "", ErrServiceOffline |
||||
} |
||||
// Container online, extract any environmental variables
|
||||
if vhost := infos.envvars["VIRTUAL_HOST"]; vhost != "" { |
||||
return vhost, nil |
||||
} |
||||
return fmt.Sprintf("%s:%d", client.server, port), nil |
||||
} |
||||
|
||||
// checkPort tries to connect to a remote host on a given
|
||||
func checkPort(host string, port int) error { |
||||
log.Trace("Verifying remote TCP connectivity", "server", host, "port", port) |
||||
conn, err := net.DialTimeout("tcp", fmt.Sprintf("%s:%d", host, port), time.Second) |
||||
if err != nil { |
||||
return err |
||||
} |
||||
conn.Close() |
||||
return nil |
||||
} |
File diff suppressed because one or more lines are too long
@ -0,0 +1,159 @@ |
||||
// Copyright 2017 The go-ethereum Authors
|
||||
// This file is part of go-ethereum.
|
||||
//
|
||||
// go-ethereum is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// go-ethereum is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU General Public License
|
||||
// along with go-ethereum. If not, see <http://www.gnu.org/licenses/>.
|
||||
|
||||
package main |
||||
|
||||
import ( |
||||
"bytes" |
||||
"fmt" |
||||
"math/rand" |
||||
"path/filepath" |
||||
"strings" |
||||
"text/template" |
||||
|
||||
"github.com/ethereum/go-ethereum/log" |
||||
) |
||||
|
||||
// ethstatsDockerfile is the Dockerfile required to build an ethstats backend
|
||||
// and associated monitoring site.
|
||||
var ethstatsDockerfile = ` |
||||
FROM mhart/alpine-node:latest |
||||
|
||||
RUN \
|
||||
apk add --update git && \
|
||||
git clone --depth=1 https://github.com/karalabe/eth-netstats && \
|
||||
apk del git && rm -rf /var/cache/apk/* && \
|
||||
\
|
||||
cd /eth-netstats && npm install && npm install -g grunt-cli && grunt |
||||
|
||||
WORKDIR /eth-netstats |
||||
EXPOSE 3000 |
||||
|
||||
RUN echo 'module.exports = {trusted: [{{.Trusted}}], banned: []};' > lib/utils/config.js |
||||
|
||||
CMD ["npm", "start"] |
||||
` |
||||
|
||||
// ethstatsComposefile is the docker-compose.yml file required to deploy and
|
||||
// maintain an ethstats monitoring site.
|
||||
var ethstatsComposefile = ` |
||||
version: '2' |
||||
services: |
||||
ethstats: |
||||
build: . |
||||
image: {{.Network}}/ethstats{{if not .VHost}} |
||||
ports: |
||||
- "{{.Port}}:3000"{{end}} |
||||
environment: |
||||
- WS_SECRET={{.Secret}}{{if .VHost}} |
||||
- VIRTUAL_HOST={{.VHost}}{{end}} |
||||
restart: always |
||||
` |
||||
|
||||
// deployEthstats deploys a new ethstats container to a remote machine via SSH,
|
||||
// docker and docker-compose. If an instance with the specified network name
|
||||
// already exists there, it will be overwritten!
|
||||
func deployEthstats(client *sshClient, network string, port int, secret string, vhost string, trusted []string) ([]byte, error) { |
||||
// Generate the content to upload to the server
|
||||
workdir := fmt.Sprintf("%d", rand.Int63()) |
||||
files := make(map[string][]byte) |
||||
|
||||
for i, address := range trusted { |
||||
trusted[i] = fmt.Sprintf("\"%s\"", address) |
||||
} |
||||
|
||||
dockerfile := new(bytes.Buffer) |
||||
template.Must(template.New("").Parse(ethstatsDockerfile)).Execute(dockerfile, map[string]interface{}{ |
||||
"Trusted": strings.Join(trusted, ", "), |
||||
}) |
||||
files[filepath.Join(workdir, "Dockerfile")] = dockerfile.Bytes() |
||||
|
||||
composefile := new(bytes.Buffer) |
||||
template.Must(template.New("").Parse(ethstatsComposefile)).Execute(composefile, map[string]interface{}{ |
||||
"Network": network, |
||||
"Port": port, |
||||
"Secret": secret, |
||||
"VHost": vhost, |
||||
}) |
||||
files[filepath.Join(workdir, "docker-compose.yaml")] = composefile.Bytes() |
||||
|
||||
// Upload the deployment files to the remote server (and clean up afterwards)
|
||||
if out, err := client.Upload(files); err != nil { |
||||
return out, err |
||||
} |
||||
defer client.Run("rm -rf " + workdir) |
||||
|
||||
// Build and deploy the ethstats service
|
||||
return nil, client.Stream(fmt.Sprintf("cd %s && docker-compose -p %s up -d --build", workdir, network)) |
||||
} |
||||
|
||||
// ethstatsInfos is returned from an ethstats status check to allow reporting
|
||||
// various configuration parameters.
|
||||
type ethstatsInfos struct { |
||||
host string |
||||
port int |
||||
secret string |
||||
config string |
||||
} |
||||
|
||||
// String implements the stringer interface.
|
||||
func (info *ethstatsInfos) String() string { |
||||
return fmt.Sprintf("host=%s, port=%d, secret=%s", info.host, info.port, info.secret) |
||||
} |
||||
|
||||
// checkEthstats does a health-check against an ethstats server to verify whether
|
||||
// it's running, and if yes, gathering a collection of useful infos about it.
|
||||
func checkEthstats(client *sshClient, network string) (*ethstatsInfos, error) { |
||||
// Inspect a possible ethstats container on the host
|
||||
infos, err := inspectContainer(client, fmt.Sprintf("%s_ethstats_1", network)) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
if !infos.running { |
||||
return nil, ErrServiceOffline |
||||
} |
||||
// Resolve the port from the host, or the reverse proxy
|
||||
port := infos.portmap["3000/tcp"] |
||||
if port == 0 { |
||||
if proxy, _ := checkNginx(client, network); proxy != nil { |
||||
port = proxy.port |
||||
} |
||||
} |
||||
if port == 0 { |
||||
return nil, ErrNotExposed |
||||
} |
||||
// Resolve the host from the reverse-proxy and configure the connection string
|
||||
host := infos.envvars["VIRTUAL_HOST"] |
||||
if host == "" { |
||||
host = client.server |
||||
} |
||||
secret := infos.envvars["WS_SECRET"] |
||||
config := fmt.Sprintf("%s@%s", secret, host) |
||||
if port != 80 && port != 443 { |
||||
config += fmt.Sprintf(":%d", port) |
||||
} |
||||
// Run a sanity check to see if the port is reachable
|
||||
if err = checkPort(host, port); err != nil { |
||||
log.Warn("Ethstats service seems unreachable", "server", host, "port", port, "err", err) |
||||
} |
||||
// Container available, assemble and return the useful infos
|
||||
return ðstatsInfos{ |
||||
host: host, |
||||
port: port, |
||||
secret: secret, |
||||
config: config, |
||||
}, nil |
||||
} |
@ -0,0 +1,210 @@ |
||||
// Copyright 2017 The go-ethereum Authors
|
||||
// This file is part of go-ethereum.
|
||||
//
|
||||
// go-ethereum is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// go-ethereum is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU General Public License
|
||||
// along with go-ethereum. If not, see <http://www.gnu.org/licenses/>.
|
||||
|
||||
package main |
||||
|
||||
import ( |
||||
"bytes" |
||||
"fmt" |
||||
"html/template" |
||||
"math/rand" |
||||
"path/filepath" |
||||
"strconv" |
||||
"strings" |
||||
|
||||
"github.com/ethereum/go-ethereum/log" |
||||
) |
||||
|
||||
// faucetDockerfile is the Dockerfile required to build an faucet container to
|
||||
// grant crypto tokens based on GitHub authentications.
|
||||
var faucetDockerfile = ` |
||||
FROM alpine:latest |
||||
|
||||
RUN mkdir /go |
||||
ENV GOPATH /go |
||||
|
||||
RUN \
|
||||
apk add --update git go make gcc musl-dev ca-certificates linux-headers && \
|
||||
mkdir -p $GOPATH/src/github.com/ethereum && \
|
||||
(cd $GOPATH/src/github.com/ethereum && git clone --depth=1 https://github.com/ethereum/go-ethereum) && \
|
||||
go build -v github.com/ethereum/go-ethereum/cmd/faucet && \
|
||||
apk del git go make gcc musl-dev linux-headers && \
|
||||
rm -rf $GOPATH && rm -rf /var/cache/apk/* |
||||
|
||||
ADD genesis.json /genesis.json |
||||
ADD account.json /account.json |
||||
ADD account.pass /account.pass |
||||
|
||||
EXPOSE 8080 |
||||
|
||||
CMD [ \
|
||||
"/faucet", "--genesis", "/genesis.json", "--network", "{{.NetworkID}}", "--bootnodes", "{{.Bootnodes}}", "--ethstats", "{{.Ethstats}}", \
|
||||
"--ethport", "{{.EthPort}}", "--faucet.name", "{{.FaucetName}}", "--faucet.amount", "{{.FaucetAmount}}", "--faucet.minutes", "{{.FaucetMinutes}}", \
|
||||
"--github.user", "{{.GitHubUser}}", "--github.token", "{{.GitHubToken}}", "--account.json", "/account.json", "--account.pass", "/account.pass" \
|
||||
]` |
||||
|
||||
// faucetComposefile is the docker-compose.yml file required to deploy and maintain
|
||||
// a crypto faucet.
|
||||
var faucetComposefile = ` |
||||
version: '2' |
||||
services: |
||||
faucet: |
||||
build: . |
||||
image: {{.Network}}/faucet |
||||
ports: |
||||
- "{{.EthPort}}:{{.EthPort}}"{{if not .VHost}} |
||||
- "{{.ApiPort}}:8080"{{end}} |
||||
volumes: |
||||
- {{.Datadir}}:/root/.faucet |
||||
environment: |
||||
- ETH_PORT={{.EthPort}} |
||||
- ETH_NAME={{.EthName}} |
||||
- FAUCET_AMOUNT={{.FaucetAmount}} |
||||
- FAUCET_MINUTES={{.FaucetMinutes}} |
||||
- GITHUB_USER={{.GitHubUser}} |
||||
- GITHUB_TOKEN={{.GitHubToken}}{{if .VHost}} |
||||
- VIRTUAL_HOST={{.VHost}} |
||||
- VIRTUAL_PORT=8080{{end}} |
||||
restart: always |
||||
` |
||||
|
||||
// deployFaucet deploys a new faucet container to a remote machine via SSH,
|
||||
// docker and docker-compose. If an instance with the specified network name
|
||||
// already exists there, it will be overwritten!
|
||||
func deployFaucet(client *sshClient, network string, bootnodes []string, config *faucetInfos) ([]byte, error) { |
||||
// Generate the content to upload to the server
|
||||
workdir := fmt.Sprintf("%d", rand.Int63()) |
||||
files := make(map[string][]byte) |
||||
|
||||
dockerfile := new(bytes.Buffer) |
||||
template.Must(template.New("").Parse(faucetDockerfile)).Execute(dockerfile, map[string]interface{}{ |
||||
"NetworkID": config.node.network, |
||||
"Bootnodes": strings.Join(bootnodes, ","), |
||||
"Ethstats": config.node.ethstats, |
||||
"EthPort": config.node.portFull, |
||||
"GitHubUser": config.githubUser, |
||||
"GitHubToken": config.githubToken, |
||||
"FaucetName": strings.Title(network), |
||||
"FaucetAmount": config.amount, |
||||
"FaucetMinutes": config.minutes, |
||||
}) |
||||
files[filepath.Join(workdir, "Dockerfile")] = dockerfile.Bytes() |
||||
|
||||
composefile := new(bytes.Buffer) |
||||
template.Must(template.New("").Parse(faucetComposefile)).Execute(composefile, map[string]interface{}{ |
||||
"Network": network, |
||||
"Datadir": config.node.datadir, |
||||
"VHost": config.host, |
||||
"ApiPort": config.port, |
||||
"EthPort": config.node.portFull, |
||||
"EthName": config.node.ethstats[:strings.Index(config.node.ethstats, ":")], |
||||
"GitHubUser": config.githubUser, |
||||
"GitHubToken": config.githubToken, |
||||
"FaucetAmount": config.amount, |
||||
"FaucetMinutes": config.minutes, |
||||
}) |
||||
files[filepath.Join(workdir, "docker-compose.yaml")] = composefile.Bytes() |
||||
|
||||
files[filepath.Join(workdir, "genesis.json")] = []byte(config.node.genesis) |
||||
files[filepath.Join(workdir, "account.json")] = []byte(config.node.keyJSON) |
||||
files[filepath.Join(workdir, "account.pass")] = []byte(config.node.keyPass) |
||||
|
||||
// Upload the deployment files to the remote server (and clean up afterwards)
|
||||
if out, err := client.Upload(files); err != nil { |
||||
return out, err |
||||
} |
||||
defer client.Run("rm -rf " + workdir) |
||||
|
||||
// Build and deploy the faucet service
|
||||
return nil, client.Stream(fmt.Sprintf("cd %s && docker-compose -p %s up -d --build", workdir, network)) |
||||
} |
||||
|
||||
// faucetInfos is returned from an faucet status check to allow reporting various
|
||||
// configuration parameters.
|
||||
type faucetInfos struct { |
||||
node *nodeInfos |
||||
host string |
||||
port int |
||||
amount int |
||||
minutes int |
||||
githubUser string |
||||
githubToken string |
||||
} |
||||
|
||||
// String implements the stringer interface.
|
||||
func (info *faucetInfos) String() string { |
||||
return fmt.Sprintf("host=%s, api=%d, eth=%d, amount=%d, minutes=%d, github=%s, ethstats=%s", info.host, info.port, info.node.portFull, info.amount, info.minutes, info.githubUser, info.node.ethstats) |
||||
} |
||||
|
||||
// checkFaucet does a health-check against an faucet server to verify whether
|
||||
// it's running, and if yes, gathering a collection of useful infos about it.
|
||||
func checkFaucet(client *sshClient, network string) (*faucetInfos, error) { |
||||
// Inspect a possible faucet container on the host
|
||||
infos, err := inspectContainer(client, fmt.Sprintf("%s_faucet_1", network)) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
if !infos.running { |
||||
return nil, ErrServiceOffline |
||||
} |
||||
// Resolve the port from the host, or the reverse proxy
|
||||
port := infos.portmap["8080/tcp"] |
||||
if port == 0 { |
||||
if proxy, _ := checkNginx(client, network); proxy != nil { |
||||
port = proxy.port |
||||
} |
||||
} |
||||
if port == 0 { |
||||
return nil, ErrNotExposed |
||||
} |
||||
// Resolve the host from the reverse-proxy and the config values
|
||||
host := infos.envvars["VIRTUAL_HOST"] |
||||
if host == "" { |
||||
host = client.server |
||||
} |
||||
amount, _ := strconv.Atoi(infos.envvars["FAUCET_AMOUNT"]) |
||||
minutes, _ := strconv.Atoi(infos.envvars["FAUCET_MINUTES"]) |
||||
|
||||
// Retrieve the funding account informations
|
||||
var out []byte |
||||
keyJSON, keyPass := "", "" |
||||
if out, err = client.Run(fmt.Sprintf("docker exec %s_faucet_1 cat /account.json", network)); err == nil { |
||||
keyJSON = string(bytes.TrimSpace(out)) |
||||
} |
||||
if out, err = client.Run(fmt.Sprintf("docker exec %s_faucet_1 cat /account.pass", network)); err == nil { |
||||
keyPass = string(bytes.TrimSpace(out)) |
||||
} |
||||
// Run a sanity check to see if the port is reachable
|
||||
if err = checkPort(host, port); err != nil { |
||||
log.Warn("Faucet service seems unreachable", "server", host, "port", port, "err", err) |
||||
} |
||||
// Container available, assemble and return the useful infos
|
||||
return &faucetInfos{ |
||||
node: &nodeInfos{ |
||||
datadir: infos.volumes["/root/.faucet"], |
||||
portFull: infos.portmap[infos.envvars["ETH_PORT"]+"/tcp"], |
||||
ethstats: infos.envvars["ETH_NAME"], |
||||
keyJSON: keyJSON, |
||||
keyPass: keyPass, |
||||
}, |
||||
host: host, |
||||
port: port, |
||||
amount: amount, |
||||
minutes: minutes, |
||||
githubUser: infos.envvars["GITHUB_USER"], |
||||
githubToken: infos.envvars["GITHUB_TOKEN"], |
||||
}, nil |
||||
} |
@ -0,0 +1,106 @@ |
||||
// Copyright 2017 The go-ethereum Authors
|
||||
// This file is part of go-ethereum.
|
||||
//
|
||||
// go-ethereum is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// go-ethereum is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU General Public License
|
||||
// along with go-ethereum. If not, see <http://www.gnu.org/licenses/>.
|
||||
|
||||
package main |
||||
|
||||
import ( |
||||
"bytes" |
||||
"fmt" |
||||
"html/template" |
||||
"math/rand" |
||||
"path/filepath" |
||||
|
||||
"github.com/ethereum/go-ethereum/log" |
||||
) |
||||
|
||||
// nginxDockerfile is theis the Dockerfile required to build an nginx reverse-
|
||||
// proxy.
|
||||
var nginxDockerfile = `FROM jwilder/nginx-proxy` |
||||
|
||||
// nginxComposefile is the docker-compose.yml file required to deploy and maintain
|
||||
// an nginx reverse-proxy. The proxy is responsible for exposing one or more HTTP
|
||||
// services running on a single host.
|
||||
var nginxComposefile = ` |
||||
version: '2' |
||||
services: |
||||
nginx: |
||||
build: . |
||||
image: {{.Network}}/nginx |
||||
ports: |
||||
- "{{.Port}}:80" |
||||
volumes: |
||||
- /var/run/docker.sock:/tmp/docker.sock:ro |
||||
restart: always |
||||
` |
||||
|
||||
// deployNginx deploys a new nginx reverse-proxy container to expose one or more
|
||||
// HTTP services running on a single host. If an instance with the specified
|
||||
// network name already exists there, it will be overwritten!
|
||||
func deployNginx(client *sshClient, network string, port int) ([]byte, error) { |
||||
log.Info("Deploying nginx reverse-proxy", "server", client.server, "port", port) |
||||
|
||||
// Generate the content to upload to the server
|
||||
workdir := fmt.Sprintf("%d", rand.Int63()) |
||||
files := make(map[string][]byte) |
||||
|
||||
dockerfile := new(bytes.Buffer) |
||||
template.Must(template.New("").Parse(nginxDockerfile)).Execute(dockerfile, nil) |
||||
files[filepath.Join(workdir, "Dockerfile")] = dockerfile.Bytes() |
||||
|
||||
composefile := new(bytes.Buffer) |
||||
template.Must(template.New("").Parse(nginxComposefile)).Execute(composefile, map[string]interface{}{ |
||||
"Network": network, |
||||
"Port": port, |
||||
}) |
||||
files[filepath.Join(workdir, "docker-compose.yaml")] = composefile.Bytes() |
||||
|
||||
// Upload the deployment files to the remote server (and clean up afterwards)
|
||||
if out, err := client.Upload(files); err != nil { |
||||
return out, err |
||||
} |
||||
defer client.Run("rm -rf " + workdir) |
||||
|
||||
// Build and deploy the ethstats service
|
||||
return nil, client.Stream(fmt.Sprintf("cd %s && docker-compose -p %s up -d --build", workdir, network)) |
||||
} |
||||
|
||||
// nginxInfos is returned from an nginx reverse-proxy status check to allow
|
||||
// reporting various configuration parameters.
|
||||
type nginxInfos struct { |
||||
port int |
||||
} |
||||
|
||||
// String implements the stringer interface.
|
||||
func (info *nginxInfos) String() string { |
||||
return fmt.Sprintf("port=%d", info.port) |
||||
} |
||||
|
||||
// checkNginx does a health-check against an nginx reverse-proxy to verify whether
|
||||
// it's running, and if yes, gathering a collection of useful infos about it.
|
||||
func checkNginx(client *sshClient, network string) (*nginxInfos, error) { |
||||
// Inspect a possible nginx container on the host
|
||||
infos, err := inspectContainer(client, fmt.Sprintf("%s_nginx_1", network)) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
if !infos.running { |
||||
return nil, ErrServiceOffline |
||||
} |
||||
// Container available, assemble and return the useful infos
|
||||
return &nginxInfos{ |
||||
port: infos.portmap["80/tcp"], |
||||
}, nil |
||||
} |
@ -0,0 +1,222 @@ |
||||
// Copyright 2017 The go-ethereum Authors
|
||||
// This file is part of go-ethereum.
|
||||
//
|
||||
// go-ethereum is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// go-ethereum is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU General Public License
|
||||
// along with go-ethereum. If not, see <http://www.gnu.org/licenses/>.
|
||||
|
||||
package main |
||||
|
||||
import ( |
||||
"bytes" |
||||
"fmt" |
||||
"math/rand" |
||||
"path/filepath" |
||||
"strconv" |
||||
"strings" |
||||
"text/template" |
||||
|
||||
"github.com/ethereum/go-ethereum/log" |
||||
) |
||||
|
||||
// nodeDockerfile is the Dockerfile required to run an Ethereum node.
|
||||
var nodeDockerfile = ` |
||||
FROM ethereum/client-go:alpine-develop |
||||
|
||||
ADD genesis.json /genesis.json |
||||
{{if .Unlock}} |
||||
ADD signer.json /signer.json |
||||
ADD signer.pass /signer.pass |
||||
{{end}} |
||||
RUN \
|
||||
echo '/geth init /genesis.json' > geth.sh && \{{if .Unlock}} |
||||
echo 'mkdir -p /root/.ethereum/keystore/ && cp /signer.json /root/.ethereum/keystore/' >> geth.sh && \{{end}} |
||||
echo $'/geth --networkid {{.NetworkID}} --cache 512 --port {{.Port}} --maxpeers {{.Peers}} {{.LightFlag}} --ethstats \'{{.Ethstats}}\' {{if .Bootnodes}}--bootnodes {{.Bootnodes}}{{end}} {{if .Etherbase}}--etherbase {{.Etherbase}} --mine{{end}}{{if .Unlock}}--unlock 0 --password /signer.pass --mine{{end}}' >> geth.sh |
||||
|
||||
ENTRYPOINT ["/bin/sh", "geth.sh"] |
||||
` |
||||
|
||||
// nodeComposefile is the docker-compose.yml file required to deploy and maintain
|
||||
// an Ethereum node (bootnode or miner for now).
|
||||
var nodeComposefile = ` |
||||
version: '2' |
||||
services: |
||||
{{.Type}}: |
||||
build: . |
||||
image: {{.Network}}/{{.Type}} |
||||
ports: |
||||
- "{{.FullPort}}:{{.FullPort}}" |
||||
- "{{.FullPort}}:{{.FullPort}}/udp"{{if .Light}} |
||||
- "{{.LightPort}}:{{.LightPort}}/udp"{{end}} |
||||
volumes: |
||||
- {{.Datadir}}:/root/.ethereum |
||||
environment: |
||||
- FULL_PORT={{.FullPort}}/tcp |
||||
- LIGHT_PORT={{.LightPort}}/udp |
||||
- TOTAL_PEERS={{.TotalPeers}} |
||||
- LIGHT_PEERS={{.LightPeers}} |
||||
- STATS_NAME={{.Ethstats}} |
||||
- MINER_NAME={{.Etherbase}} |
||||
restart: always |
||||
` |
||||
|
||||
// deployNode deploys a new Ethereum node container to a remote machine via SSH,
|
||||
// docker and docker-compose. If an instance with the specified network name
|
||||
// already exists there, it will be overwritten!
|
||||
func deployNode(client *sshClient, network string, bootnodes []string, config *nodeInfos) ([]byte, error) { |
||||
kind := "sealnode" |
||||
if config.keyJSON == "" && config.etherbase == "" { |
||||
kind = "bootnode" |
||||
bootnodes = make([]string, 0) |
||||
} |
||||
// Generate the content to upload to the server
|
||||
workdir := fmt.Sprintf("%d", rand.Int63()) |
||||
files := make(map[string][]byte) |
||||
|
||||
lightFlag := "" |
||||
if config.peersLight > 0 { |
||||
lightFlag = fmt.Sprintf("--lightpeers=%d --lightserv=50", config.peersLight) |
||||
} |
||||
dockerfile := new(bytes.Buffer) |
||||
template.Must(template.New("").Parse(nodeDockerfile)).Execute(dockerfile, map[string]interface{}{ |
||||
"NetworkID": config.network, |
||||
"Port": config.portFull, |
||||
"Peers": config.peersTotal, |
||||
"LightFlag": lightFlag, |
||||
"Bootnodes": strings.Join(bootnodes, ","), |
||||
"Ethstats": config.ethstats, |
||||
"Etherbase": config.etherbase, |
||||
"Unlock": config.keyJSON != "", |
||||
}) |
||||
files[filepath.Join(workdir, "Dockerfile")] = dockerfile.Bytes() |
||||
|
||||
composefile := new(bytes.Buffer) |
||||
template.Must(template.New("").Parse(nodeComposefile)).Execute(composefile, map[string]interface{}{ |
||||
"Type": kind, |
||||
"Datadir": config.datadir, |
||||
"Network": network, |
||||
"FullPort": config.portFull, |
||||
"TotalPeers": config.peersTotal, |
||||
"Light": config.peersLight > 0, |
||||
"LightPort": config.portFull + 1, |
||||
"LightPeers": config.peersLight, |
||||
"Ethstats": config.ethstats[:strings.Index(config.ethstats, ":")], |
||||
"Etherbase": config.etherbase, |
||||
}) |
||||
files[filepath.Join(workdir, "docker-compose.yaml")] = composefile.Bytes() |
||||
|
||||
//genesisfile, _ := json.MarshalIndent(config.genesis, "", " ")
|
||||
files[filepath.Join(workdir, "genesis.json")] = []byte(config.genesis) |
||||
|
||||
if config.keyJSON != "" { |
||||
files[filepath.Join(workdir, "signer.json")] = []byte(config.keyJSON) |
||||
files[filepath.Join(workdir, "signer.pass")] = []byte(config.keyPass) |
||||
} |
||||
// Upload the deployment files to the remote server (and clean up afterwards)
|
||||
if out, err := client.Upload(files); err != nil { |
||||
return out, err |
||||
} |
||||
defer client.Run("rm -rf " + workdir) |
||||
|
||||
// Build and deploy the bootnode service
|
||||
return nil, client.Stream(fmt.Sprintf("cd %s && docker-compose -p %s up -d --build", workdir, network)) |
||||
} |
||||
|
||||
// nodeInfos is returned from a boot or seal node status check to allow reporting
|
||||
// various configuration parameters.
|
||||
type nodeInfos struct { |
||||
genesis []byte |
||||
network int64 |
||||
datadir string |
||||
ethstats string |
||||
portFull int |
||||
portLight int |
||||
enodeFull string |
||||
enodeLight string |
||||
peersTotal int |
||||
peersLight int |
||||
etherbase string |
||||
keyJSON string |
||||
keyPass string |
||||
} |
||||
|
||||
// String implements the stringer interface.
|
||||
func (info *nodeInfos) String() string { |
||||
discv5 := "" |
||||
if info.peersLight > 0 { |
||||
discv5 = fmt.Sprintf(", portv5=%d", info.portLight) |
||||
} |
||||
return fmt.Sprintf("port=%d%s, datadir=%s, peers=%d, lights=%d, ethstats=%s", info.portFull, discv5, info.datadir, info.peersTotal, info.peersLight, info.ethstats) |
||||
} |
||||
|
||||
// checkNode does a health-check against an boot or seal node server to verify
|
||||
// whether it's running, and if yes, whether it's responsive.
|
||||
func checkNode(client *sshClient, network string, boot bool) (*nodeInfos, error) { |
||||
kind := "bootnode" |
||||
if !boot { |
||||
kind = "sealnode" |
||||
} |
||||
// Inspect a possible bootnode container on the host
|
||||
infos, err := inspectContainer(client, fmt.Sprintf("%s_%s_1", network, kind)) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
if !infos.running { |
||||
return nil, ErrServiceOffline |
||||
} |
||||
// Resolve a few types from the environmental variables
|
||||
totalPeers, _ := strconv.Atoi(infos.envvars["TOTAL_PEERS"]) |
||||
lightPeers, _ := strconv.Atoi(infos.envvars["LIGHT_PEERS"]) |
||||
|
||||
// Container available, retrieve its node ID and its genesis json
|
||||
var out []byte |
||||
if out, err = client.Run(fmt.Sprintf("docker exec %s_%s_1 /geth --exec admin.nodeInfo.id attach", network, kind)); err != nil { |
||||
return nil, ErrServiceUnreachable |
||||
} |
||||
id := bytes.Trim(bytes.TrimSpace(out), "\"") |
||||
|
||||
if out, err = client.Run(fmt.Sprintf("docker exec %s_%s_1 cat /genesis.json", network, kind)); err != nil { |
||||
return nil, ErrServiceUnreachable |
||||
} |
||||
genesis := bytes.TrimSpace(out) |
||||
|
||||
keyJSON, keyPass := "", "" |
||||
if out, err = client.Run(fmt.Sprintf("docker exec %s_%s_1 cat /signer.json", network, kind)); err == nil { |
||||
keyJSON = string(bytes.TrimSpace(out)) |
||||
} |
||||
if out, err = client.Run(fmt.Sprintf("docker exec %s_%s_1 cat /signer.pass", network, kind)); err == nil { |
||||
keyPass = string(bytes.TrimSpace(out)) |
||||
} |
||||
// Run a sanity check to see if the devp2p is reachable
|
||||
port := infos.portmap[infos.envvars["FULL_PORT"]] |
||||
if err = checkPort(client.server, port); err != nil { |
||||
log.Warn(fmt.Sprintf("%s devp2p port seems unreachable", strings.Title(kind)), "server", client.server, "port", port, "err", err) |
||||
} |
||||
// Assemble and return the useful infos
|
||||
stats := &nodeInfos{ |
||||
genesis: genesis, |
||||
datadir: infos.volumes["/root/.ethereum"], |
||||
portFull: infos.portmap[infos.envvars["FULL_PORT"]], |
||||
portLight: infos.portmap[infos.envvars["LIGHT_PORT"]], |
||||
peersTotal: totalPeers, |
||||
peersLight: lightPeers, |
||||
ethstats: infos.envvars["STATS_NAME"], |
||||
etherbase: infos.envvars["MINER_NAME"], |
||||
keyJSON: keyJSON, |
||||
keyPass: keyPass, |
||||
} |
||||
stats.enodeFull = fmt.Sprintf("enode://%s@%s:%d", id, client.address, stats.portFull) |
||||
if stats.portLight != 0 { |
||||
stats.enodeLight = fmt.Sprintf("enode://%s@%s:%d?discport=%d", id, client.address, stats.portFull, stats.portLight) |
||||
} |
||||
return stats, nil |
||||
} |
@ -0,0 +1,55 @@ |
||||
// Copyright 2017 The go-ethereum Authors
|
||||
// This file is part of go-ethereum.
|
||||
//
|
||||
// go-ethereum is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// go-ethereum is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU General Public License
|
||||
// along with go-ethereum. If not, see <http://www.gnu.org/licenses/>.
|
||||
|
||||
// puppeth is a command to assemble and maintain private networks.
|
||||
package main |
||||
|
||||
import ( |
||||
"math/rand" |
||||
"os" |
||||
"time" |
||||
|
||||
"github.com/ethereum/go-ethereum/log" |
||||
"gopkg.in/urfave/cli.v1" |
||||
) |
||||
|
||||
// main is just a boring entry point to set up the CLI app.
|
||||
func main() { |
||||
app := cli.NewApp() |
||||
app.Name = "puppeth" |
||||
app.Usage = "assemble and maintain private Ethereum networks" |
||||
app.Flags = []cli.Flag{ |
||||
cli.StringFlag{ |
||||
Name: "network", |
||||
Usage: "name of the network to administer", |
||||
}, |
||||
cli.IntFlag{ |
||||
Name: "loglevel", |
||||
Value: 4, |
||||
Usage: "log level to emit to the screen", |
||||
}, |
||||
} |
||||
app.Action = func(c *cli.Context) error { |
||||
// Set up the logger to print everything and the random generator
|
||||
log.Root().SetHandler(log.LvlFilterHandler(log.Lvl(c.Int("loglevel")), log.StreamHandler(os.Stdout, log.TerminalFormat(true)))) |
||||
rand.Seed(time.Now().UnixNano()) |
||||
|
||||
// Start the wizard and relinquish control
|
||||
makeWizard(c.String("network")).run() |
||||
return nil |
||||
} |
||||
app.Run(os.Args) |
||||
} |
@ -0,0 +1,195 @@ |
||||
// Copyright 2017 The go-ethereum Authors
|
||||
// This file is part of go-ethereum.
|
||||
//
|
||||
// go-ethereum is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// go-ethereum is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU General Public License
|
||||
// along with go-ethereum. If not, see <http://www.gnu.org/licenses/>.
|
||||
|
||||
package main |
||||
|
||||
import ( |
||||
"errors" |
||||
"fmt" |
||||
"io/ioutil" |
||||
"net" |
||||
"os" |
||||
"os/user" |
||||
"path/filepath" |
||||
"strings" |
||||
"syscall" |
||||
|
||||
"github.com/ethereum/go-ethereum/log" |
||||
"golang.org/x/crypto/ssh" |
||||
"golang.org/x/crypto/ssh/terminal" |
||||
) |
||||
|
||||
// sshClient is a small wrapper around Go's SSH client with a few utility methods
|
||||
// implemented on top.
|
||||
type sshClient struct { |
||||
server string // Server name or IP without port number
|
||||
address string // IP address of the remote server
|
||||
client *ssh.Client |
||||
logger log.Logger |
||||
} |
||||
|
||||
// dial establishes an SSH connection to a remote node using the current user and
|
||||
// the user's configured private RSA key.
|
||||
func dial(server string) (*sshClient, error) { |
||||
// Figure out a label for the server and a logger
|
||||
label := server |
||||
if strings.Contains(label, ":") { |
||||
label = label[:strings.Index(label, ":")] |
||||
} |
||||
logger := log.New("server", label) |
||||
logger.Debug("Attempting to establish SSH connection") |
||||
|
||||
user, err := user.Current() |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
// Configure the supported authentication methods (private key and password)
|
||||
var auths []ssh.AuthMethod |
||||
|
||||
path := filepath.Join(user.HomeDir, ".ssh", "id_rsa") |
||||
if buf, err := ioutil.ReadFile(path); err != nil { |
||||
log.Warn("No SSH key, falling back to passwords", "path", path, "err", err) |
||||
} else { |
||||
key, err := ssh.ParsePrivateKey(buf) |
||||
if err != nil { |
||||
log.Warn("Bad SSH key, falling back to passwords", "path", path, "err", err) |
||||
} else { |
||||
auths = append(auths, ssh.PublicKeys(key)) |
||||
} |
||||
} |
||||
auths = append(auths, ssh.PasswordCallback(func() (string, error) { |
||||
fmt.Printf("What's the login password for %s at %s? (won't be echoed)\n> ", user.Username, server) |
||||
blob, err := terminal.ReadPassword(int(syscall.Stdin)) |
||||
|
||||
fmt.Println() |
||||
return string(blob), err |
||||
})) |
||||
// Resolve the IP address of the remote server
|
||||
addr, err := net.LookupHost(label) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
if len(addr) == 0 { |
||||
return nil, errors.New("no IPs associated with domain") |
||||
} |
||||
// Try to dial in to the remote server
|
||||
logger.Trace("Dialing remote SSH server", "user", user.Username, "key", path) |
||||
if !strings.Contains(server, ":") { |
||||
server += ":22" |
||||
} |
||||
client, err := ssh.Dial("tcp", server, &ssh.ClientConfig{User: user.Username, Auth: auths}) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
// Connection established, return our utility wrapper
|
||||
c := &sshClient{ |
||||
server: label, |
||||
address: addr[0], |
||||
client: client, |
||||
logger: logger, |
||||
} |
||||
if err := c.init(); err != nil { |
||||
client.Close() |
||||
return nil, err |
||||
} |
||||
return c, nil |
||||
} |
||||
|
||||
// init runs some initialization commands on the remote server to ensure it's
|
||||
// capable of acting as puppeth target.
|
||||
func (client *sshClient) init() error { |
||||
client.logger.Debug("Verifying if docker is available") |
||||
if out, err := client.Run("docker version"); err != nil { |
||||
if len(out) == 0 { |
||||
return err |
||||
} |
||||
return fmt.Errorf("docker configured incorrectly: %s", out) |
||||
} |
||||
client.logger.Debug("Verifying if docker-compose is available") |
||||
if out, err := client.Run("docker-compose version"); err != nil { |
||||
if len(out) == 0 { |
||||
return err |
||||
} |
||||
return fmt.Errorf("docker-compose configured incorrectly: %s", out) |
||||
} |
||||
return nil |
||||
} |
||||
|
||||
// Close terminates the connection to an SSH server.
|
||||
func (client *sshClient) Close() error { |
||||
return client.client.Close() |
||||
} |
||||
|
||||
// Run executes a command on the remote server and returns the combined output
|
||||
// along with any error status.
|
||||
func (client *sshClient) Run(cmd string) ([]byte, error) { |
||||
// Establish a single command session
|
||||
session, err := client.client.NewSession() |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
defer session.Close() |
||||
|
||||
// Execute the command and return any output
|
||||
client.logger.Trace("Running command on remote server", "cmd", cmd) |
||||
return session.CombinedOutput(cmd) |
||||
} |
||||
|
||||
// Stream executes a command on the remote server and streams all outputs into
|
||||
// the local stdout and stderr streams.
|
||||
func (client *sshClient) Stream(cmd string) error { |
||||
// Establish a single command session
|
||||
session, err := client.client.NewSession() |
||||
if err != nil { |
||||
return err |
||||
} |
||||
defer session.Close() |
||||
|
||||
session.Stdout = os.Stdout |
||||
session.Stderr = os.Stderr |
||||
|
||||
// Execute the command and return any output
|
||||
client.logger.Trace("Streaming command on remote server", "cmd", cmd) |
||||
return session.Run(cmd) |
||||
} |
||||
|
||||
// Upload copied the set of files to a remote server via SCP, creating any non-
|
||||
// existing folder in te mean time.
|
||||
func (client *sshClient) Upload(files map[string][]byte) ([]byte, error) { |
||||
// Establish a single command session
|
||||
session, err := client.client.NewSession() |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
defer session.Close() |
||||
|
||||
// Create a goroutine that streams the SCP content
|
||||
go func() { |
||||
out, _ := session.StdinPipe() |
||||
defer out.Close() |
||||
|
||||
for file, content := range files { |
||||
client.logger.Trace("Uploading file to server", "file", file, "bytes", len(content)) |
||||
|
||||
fmt.Fprintln(out, "D0755", 0, filepath.Dir(file)) // Ensure the folder exists
|
||||
fmt.Fprintln(out, "C0644", len(content), filepath.Base(file)) // Create the actual file
|
||||
out.Write(content) // Stream the data content
|
||||
fmt.Fprint(out, "\x00") // Transfer end with \x00
|
||||
fmt.Fprintln(out, "E") // Leave directory (simpler)
|
||||
} |
||||
}() |
||||
return session.CombinedOutput("/usr/bin/scp -v -tr ./") |
||||
} |
@ -0,0 +1,229 @@ |
||||
// Copyright 2017 The go-ethereum Authors
|
||||
// This file is part of go-ethereum.
|
||||
//
|
||||
// go-ethereum is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// go-ethereum is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU General Public License
|
||||
// along with go-ethereum. If not, see <http://www.gnu.org/licenses/>.
|
||||
|
||||
package main |
||||
|
||||
import ( |
||||
"bufio" |
||||
"encoding/json" |
||||
"fmt" |
||||
"io/ioutil" |
||||
"math/big" |
||||
"os" |
||||
"path/filepath" |
||||
"sort" |
||||
"strconv" |
||||
"strings" |
||||
"syscall" |
||||
|
||||
"github.com/ethereum/go-ethereum/common" |
||||
"github.com/ethereum/go-ethereum/core" |
||||
"github.com/ethereum/go-ethereum/log" |
||||
"golang.org/x/crypto/ssh/terminal" |
||||
) |
||||
|
||||
// config contains all the configurations needed by puppeth that should be saved
|
||||
// between sessions.
|
||||
type config struct { |
||||
path string // File containing the configuration values
|
||||
genesis *core.Genesis // Genesis block to cache for node deploys
|
||||
bootFull []string // Bootnodes to always connect to by full nodes
|
||||
bootLight []string // Bootnodes to always connect to by light nodes
|
||||
ethstats string // Ethstats settings to cache for node deploys
|
||||
|
||||
Servers []string `json:"servers,omitempty"` |
||||
} |
||||
|
||||
// flush dumps the contents of config to disk.
|
||||
func (c config) flush() { |
||||
os.MkdirAll(filepath.Dir(c.path), 0755) |
||||
|
||||
sort.Strings(c.Servers) |
||||
out, _ := json.MarshalIndent(c, "", " ") |
||||
if err := ioutil.WriteFile(c.path, out, 0644); err != nil { |
||||
log.Warn("Failed to save puppeth configs", "file", c.path, "err", err) |
||||
} |
||||
} |
||||
|
||||
type wizard struct { |
||||
network string // Network name to manage
|
||||
conf config // Configurations from previous runs
|
||||
|
||||
servers map[string]*sshClient // SSH connections to servers to administer
|
||||
services map[string][]string // Ethereum services known to be running on servers
|
||||
|
||||
in *bufio.Reader // Wrapper around stdin to allow reading user input
|
||||
} |
||||
|
||||
// read reads a single line from stdin, trimming if from spaces.
|
||||
func (w *wizard) read() string { |
||||
fmt.Printf("> ") |
||||
text, err := w.in.ReadString('\n') |
||||
if err != nil { |
||||
log.Crit("Failed to read user input", "err", err) |
||||
} |
||||
return strings.TrimSpace(text) |
||||
} |
||||
|
||||
// readString reads a single line from stdin, trimming if from spaces, enforcing
|
||||
// non-emptyness.
|
||||
func (w *wizard) readString() string { |
||||
for { |
||||
fmt.Printf("> ") |
||||
text, err := w.in.ReadString('\n') |
||||
if err != nil { |
||||
log.Crit("Failed to read user input", "err", err) |
||||
} |
||||
if text = strings.TrimSpace(text); text != "" { |
||||
return text |
||||
} |
||||
} |
||||
} |
||||
|
||||
// readDefaultString reads a single line from stdin, trimming if from spaces. If
|
||||
// an empty line is entered, the default value is returned.
|
||||
func (w *wizard) readDefaultString(def string) string { |
||||
for { |
||||
fmt.Printf("> ") |
||||
text, err := w.in.ReadString('\n') |
||||
if err != nil { |
||||
log.Crit("Failed to read user input", "err", err) |
||||
} |
||||
if text = strings.TrimSpace(text); text != "" { |
||||
return text |
||||
} |
||||
return def |
||||
} |
||||
} |
||||
|
||||
// readInt reads a single line from stdin, trimming if from spaces, enforcing it
|
||||
// to parse into an integer.
|
||||
func (w *wizard) readInt() int { |
||||
for { |
||||
fmt.Printf("> ") |
||||
text, err := w.in.ReadString('\n') |
||||
if err != nil { |
||||
log.Crit("Failed to read user input", "err", err) |
||||
} |
||||
if text = strings.TrimSpace(text); text == "" { |
||||
continue |
||||
} |
||||
val, err := strconv.Atoi(strings.TrimSpace(text)) |
||||
if err != nil { |
||||
log.Error("Invalid input, expected integer", "err", err) |
||||
continue |
||||
} |
||||
return val |
||||
} |
||||
} |
||||
|
||||
// readDefaultInt reads a single line from stdin, trimming if from spaces, enforcing
|
||||
// it to parse into an integer. If an empty line is entered, the default value is
|
||||
// returned.
|
||||
func (w *wizard) readDefaultInt(def int) int { |
||||
for { |
||||
fmt.Printf("> ") |
||||
text, err := w.in.ReadString('\n') |
||||
if err != nil { |
||||
log.Crit("Failed to read user input", "err", err) |
||||
} |
||||
if text = strings.TrimSpace(text); text == "" { |
||||
return def |
||||
} |
||||
val, err := strconv.Atoi(strings.TrimSpace(text)) |
||||
if err != nil { |
||||
log.Error("Invalid input, expected integer", "err", err) |
||||
continue |
||||
} |
||||
return val |
||||
} |
||||
} |
||||
|
||||
// readPassword reads a single line from stdin, trimming it from the trailing new
|
||||
// line and returns it. The input will not be echoed.
|
||||
func (w *wizard) readPassword() string { |
||||
for { |
||||
fmt.Printf("> ") |
||||
text, err := terminal.ReadPassword(int(syscall.Stdin)) |
||||
if err != nil { |
||||
log.Crit("Failed to read password", "err", err) |
||||
} |
||||
fmt.Println() |
||||
return string(text) |
||||
} |
||||
} |
||||
|
||||
// readAddress reads a single line from stdin, trimming if from spaces and converts
|
||||
// it to an Ethereum address.
|
||||
func (w *wizard) readAddress() *common.Address { |
||||
for { |
||||
// Read the address from the user
|
||||
fmt.Printf("> 0x") |
||||
text, err := w.in.ReadString('\n') |
||||
if err != nil { |
||||
log.Crit("Failed to read user input", "err", err) |
||||
} |
||||
if text = strings.TrimSpace(text); text == "" { |
||||
return nil |
||||
} |
||||
// Make sure it looks ok and return it if so
|
||||
if len(text) != 40 { |
||||
log.Error("Invalid address length, please retry") |
||||
continue |
||||
} |
||||
bigaddr, _ := new(big.Int).SetString(text, 16) |
||||
address := common.BigToAddress(bigaddr) |
||||
return &address |
||||
} |
||||
} |
||||
|
||||
// readDefaultAddress reads a single line from stdin, trimming if from spaces and
|
||||
// converts it to an Ethereum address. If an empty line is entered, the default
|
||||
// value is returned.
|
||||
func (w *wizard) readDefaultAddress(def common.Address) common.Address { |
||||
for { |
||||
// Read the address from the user
|
||||
fmt.Printf("> 0x") |
||||
text, err := w.in.ReadString('\n') |
||||
if err != nil { |
||||
log.Crit("Failed to read user input", "err", err) |
||||
} |
||||
if text = strings.TrimSpace(text); text == "" { |
||||
return def |
||||
} |
||||
// Make sure it looks ok and return it if so
|
||||
if len(text) != 40 { |
||||
log.Error("Invalid address length, please retry") |
||||
continue |
||||
} |
||||
bigaddr, _ := new(big.Int).SetString(text, 16) |
||||
return common.BigToAddress(bigaddr) |
||||
} |
||||
} |
||||
|
||||
// readJSON reads a raw JSON message and returns it.
|
||||
func (w *wizard) readJSON() string { |
||||
var blob json.RawMessage |
||||
|
||||
for { |
||||
fmt.Printf("> ") |
||||
if err := json.NewDecoder(w.in).Decode(&blob); err != nil { |
||||
log.Error("Invalid JSON, please try again", "err", err) |
||||
continue |
||||
} |
||||
return string(blob) |
||||
} |
||||
} |
@ -0,0 +1,132 @@ |
||||
// Copyright 2017 The go-ethereum Authors
|
||||
// This file is part of go-ethereum.
|
||||
//
|
||||
// go-ethereum is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// go-ethereum is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU General Public License
|
||||
// along with go-ethereum. If not, see <http://www.gnu.org/licenses/>.
|
||||
|
||||
package main |
||||
|
||||
import ( |
||||
"fmt" |
||||
|
||||
"github.com/ethereum/go-ethereum/log" |
||||
) |
||||
|
||||
// deployDashboard queries the user for various input on deploying a web-service
|
||||
// dashboard, after which is pushes the container.
|
||||
func (w *wizard) deployDashboard() { |
||||
// Select the server to interact with
|
||||
server := w.selectServer() |
||||
if server == "" { |
||||
return |
||||
} |
||||
client := w.servers[server] |
||||
|
||||
// Retrieve any active dashboard configurations from the server
|
||||
infos, err := checkDashboard(client, w.network) |
||||
if err != nil { |
||||
infos = &dashboardInfos{ |
||||
port: 80, |
||||
host: client.server, |
||||
} |
||||
} |
||||
// Figure out which port to listen on
|
||||
fmt.Println() |
||||
fmt.Printf("Which port should the dashboard listen on? (default = %d)\n", infos.port) |
||||
infos.port = w.readDefaultInt(infos.port) |
||||
|
||||
// Figure which virtual-host to deploy the dashboard on
|
||||
infos.host, err = w.ensureVirtualHost(client, infos.port, infos.host) |
||||
if err != nil { |
||||
log.Error("Failed to decide on dashboard host", "err", err) |
||||
return |
||||
} |
||||
// Port and proxy settings retrieved, figure out which services are available
|
||||
available := make(map[string][]string) |
||||
for server, services := range w.services { |
||||
for _, service := range services { |
||||
available[service] = append(available[service], server) |
||||
} |
||||
} |
||||
listing := make(map[string]string) |
||||
for _, service := range []string{"ethstats", "explorer", "wallet", "faucet"} { |
||||
// Gather all the locally hosted pages of this type
|
||||
var pages []string |
||||
for _, server := range available[service] { |
||||
client := w.servers[server] |
||||
if client == nil { |
||||
continue |
||||
} |
||||
// If there's a service running on the machine, retrieve it's port number
|
||||
var port int |
||||
switch service { |
||||
case "ethstats": |
||||
if infos, err := checkEthstats(client, w.network); err == nil { |
||||
port = infos.port |
||||
} |
||||
case "faucet": |
||||
if infos, err := checkFaucet(client, w.network); err == nil { |
||||
port = infos.port |
||||
} |
||||
} |
||||
if page, err := resolve(client, w.network, service, port); err == nil && page != "" { |
||||
pages = append(pages, page) |
||||
} |
||||
} |
||||
// Promt the user to chose one, enter manually or simply not list this service
|
||||
defLabel, defChoice := "don't list", len(pages)+2 |
||||
if len(pages) > 0 { |
||||
defLabel, defChoice = pages[0], 1 |
||||
} |
||||
fmt.Println() |
||||
fmt.Printf("Which %s service to list? (default = %s)\n", service, defLabel) |
||||
for i, page := range pages { |
||||
fmt.Printf(" %d. %s\n", i+1, page) |
||||
} |
||||
fmt.Printf(" %d. List external %s service\n", len(pages)+1, service) |
||||
fmt.Printf(" %d. Don't list any %s service\n", len(pages)+2, service) |
||||
|
||||
choice := w.readDefaultInt(defChoice) |
||||
if choice < 0 || choice > len(pages)+2 { |
||||
log.Error("Invalid listing choice, aborting") |
||||
return |
||||
} |
||||
switch { |
||||
case choice <= len(pages): |
||||
listing[service] = pages[choice-1] |
||||
case choice == len(pages)+1: |
||||
fmt.Println() |
||||
fmt.Printf("Which address is the external %s service at?\n", service) |
||||
listing[service] = w.readString() |
||||
default: |
||||
// No service hosting for this
|
||||
} |
||||
} |
||||
// If we have ethstats running, ask whether to make the secret public or not
|
||||
var ethstats bool |
||||
if w.conf.ethstats != "" { |
||||
fmt.Println() |
||||
fmt.Println("Include ethstats secret on dashboard (y/n)? (default = yes)") |
||||
ethstats = w.readDefaultString("y") == "y" |
||||
} |
||||
// Try to deploy the dashboard container on the host
|
||||
if out, err := deployDashboard(client, w.network, infos.port, infos.host, listing, &w.conf, ethstats); err != nil { |
||||
log.Error("Failed to deploy dashboard container", "err", err) |
||||
if len(out) > 0 { |
||||
fmt.Printf("%s\n", out) |
||||
} |
||||
return |
||||
} |
||||
// All ok, run a network scan to pick any changes up
|
||||
w.networkStats(false) |
||||
} |
@ -0,0 +1,79 @@ |
||||
// Copyright 2017 The go-ethereum Authors
|
||||
// This file is part of go-ethereum.
|
||||
//
|
||||
// go-ethereum is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// go-ethereum is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU General Public License
|
||||
// along with go-ethereum. If not, see <http://www.gnu.org/licenses/>.
|
||||
|
||||
package main |
||||
|
||||
import ( |
||||
"fmt" |
||||
|
||||
"github.com/ethereum/go-ethereum/log" |
||||
) |
||||
|
||||
// deployEthstats queries the user for various input on deploying an ethstats
|
||||
// monitoring server, after which it executes it.
|
||||
func (w *wizard) deployEthstats() { |
||||
// Select the server to interact with
|
||||
server := w.selectServer() |
||||
if server == "" { |
||||
return |
||||
} |
||||
client := w.servers[server] |
||||
|
||||
// Retrieve any active ethstats configurations from the server
|
||||
infos, err := checkEthstats(client, w.network) |
||||
if err != nil { |
||||
infos = ðstatsInfos{ |
||||
port: 80, |
||||
host: client.server, |
||||
secret: "", |
||||
} |
||||
} |
||||
// Figure out which port to listen on
|
||||
fmt.Println() |
||||
fmt.Printf("Which port should ethstats listen on? (default = %d)\n", infos.port) |
||||
infos.port = w.readDefaultInt(infos.port) |
||||
|
||||
// Figure which virtual-host to deploy ethstats on
|
||||
if infos.host, err = w.ensureVirtualHost(client, infos.port, infos.host); err != nil { |
||||
log.Error("Failed to decide on ethstats host", "err", err) |
||||
return |
||||
} |
||||
// Port and proxy settings retrieved, figure out the secret and boot ethstats
|
||||
fmt.Println() |
||||
if infos.secret == "" { |
||||
fmt.Printf("What should be the secret password for the API? (must not be empty)\n") |
||||
infos.secret = w.readString() |
||||
} else { |
||||
fmt.Printf("What should be the secret password for the API? (default = %s)\n", infos.secret) |
||||
infos.secret = w.readDefaultString(infos.secret) |
||||
} |
||||
// Try to deploy the ethstats server on the host
|
||||
trusted := make([]string, 0, len(w.servers)) |
||||
for _, client := range w.servers { |
||||
if client != nil { |
||||
trusted = append(trusted, client.address) |
||||
} |
||||
} |
||||
if out, err := deployEthstats(client, w.network, infos.port, infos.secret, infos.host, trusted); err != nil { |
||||
log.Error("Failed to deploy ethstats container", "err", err) |
||||
if len(out) > 0 { |
||||
fmt.Printf("%s\n", out) |
||||
} |
||||
return |
||||
} |
||||
// All ok, run a network scan to pick any changes up
|
||||
w.networkStats(false) |
||||
} |
@ -0,0 +1,172 @@ |
||||
// Copyright 2017 The go-ethereum Authors
|
||||
// This file is part of go-ethereum.
|
||||
//
|
||||
// go-ethereum is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// go-ethereum is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU General Public License
|
||||
// along with go-ethereum. If not, see <http://www.gnu.org/licenses/>.
|
||||
|
||||
package main |
||||
|
||||
import ( |
||||
"encoding/json" |
||||
"fmt" |
||||
"net/http" |
||||
|
||||
"github.com/ethereum/go-ethereum/accounts/keystore" |
||||
"github.com/ethereum/go-ethereum/log" |
||||
) |
||||
|
||||
// deployFaucet queries the user for various input on deploying a faucet, after
|
||||
// which it executes it.
|
||||
func (w *wizard) deployFaucet() { |
||||
// Select the server to interact with
|
||||
server := w.selectServer() |
||||
if server == "" { |
||||
return |
||||
} |
||||
client := w.servers[server] |
||||
|
||||
// Retrieve any active faucet configurations from the server
|
||||
infos, err := checkFaucet(client, w.network) |
||||
if err != nil { |
||||
infos = &faucetInfos{ |
||||
node: &nodeInfos{portFull: 30303, peersTotal: 25}, |
||||
port: 80, |
||||
host: client.server, |
||||
amount: 1, |
||||
minutes: 1440, |
||||
} |
||||
} |
||||
infos.node.genesis, _ = json.MarshalIndent(w.conf.genesis, "", " ") |
||||
infos.node.network = w.conf.genesis.Config.ChainId.Int64() |
||||
|
||||
// Figure out which port to listen on
|
||||
fmt.Println() |
||||
fmt.Printf("Which port should the faucet listen on? (default = %d)\n", infos.port) |
||||
infos.port = w.readDefaultInt(infos.port) |
||||
|
||||
// Figure which virtual-host to deploy ethstats on
|
||||
if infos.host, err = w.ensureVirtualHost(client, infos.port, infos.host); err != nil { |
||||
log.Error("Failed to decide on faucet host", "err", err) |
||||
return |
||||
} |
||||
// Port and proxy settings retrieved, figure out the funcing amount per perdion configurations
|
||||
fmt.Println() |
||||
fmt.Printf("How many Ethers to release per request? (default = %d)\n", infos.amount) |
||||
infos.amount = w.readDefaultInt(infos.amount) |
||||
|
||||
fmt.Println() |
||||
fmt.Printf("How many minutes to enforce between requests? (default = %d)\n", infos.minutes) |
||||
infos.minutes = w.readDefaultInt(infos.minutes) |
||||
|
||||
// Accessing GitHub gists requires API authorization, retrieve it
|
||||
if infos.githubUser != "" { |
||||
fmt.Println() |
||||
fmt.Printf("Reused previous (%s) GitHub API authorization (y/n)? (default = yes)\n", infos.githubUser) |
||||
if w.readDefaultString("y") != "y" { |
||||
infos.githubUser, infos.githubToken = "", "" |
||||
} |
||||
} |
||||
if infos.githubUser == "" { |
||||
// No previous authorization (or new one requested)
|
||||
fmt.Println() |
||||
fmt.Println("Which GitHub user to verify Gists through?") |
||||
infos.githubUser = w.readString() |
||||
|
||||
fmt.Println() |
||||
fmt.Println("What is the GitHub personal access token of the user? (won't be echoed)") |
||||
infos.githubToken = w.readPassword() |
||||
|
||||
// Do a sanity check query against github to ensure it's valid
|
||||
req, _ := http.NewRequest("GET", "https://api.github.com/user", nil) |
||||
req.SetBasicAuth(infos.githubUser, infos.githubToken) |
||||
res, err := http.DefaultClient.Do(req) |
||||
if err != nil { |
||||
log.Error("Failed to verify GitHub authentication", "err", err) |
||||
return |
||||
} |
||||
defer res.Body.Close() |
||||
|
||||
var msg struct { |
||||
Login string `json:"login"` |
||||
Message string `json:"message"` |
||||
} |
||||
if err = json.NewDecoder(res.Body).Decode(&msg); err != nil { |
||||
log.Error("Failed to decode authorization response", "err", err) |
||||
return |
||||
} |
||||
if msg.Login != infos.githubUser { |
||||
log.Error("GitHub authorization failed", "user", infos.githubUser, "message", msg.Message) |
||||
return |
||||
} |
||||
} |
||||
// Figure out where the user wants to store the persistent data
|
||||
fmt.Println() |
||||
if infos.node.datadir == "" { |
||||
fmt.Printf("Where should data be stored on the remote machine?\n") |
||||
infos.node.datadir = w.readString() |
||||
} else { |
||||
fmt.Printf("Where should data be stored on the remote machine? (default = %s)\n", infos.node.datadir) |
||||
infos.node.datadir = w.readDefaultString(infos.node.datadir) |
||||
} |
||||
// Figure out which port to listen on
|
||||
fmt.Println() |
||||
fmt.Printf("Which TCP/UDP port should the light client listen on? (default = %d)\n", infos.node.portFull) |
||||
infos.node.portFull = w.readDefaultInt(infos.node.portFull) |
||||
|
||||
// Set a proper name to report on the stats page
|
||||
fmt.Println() |
||||
if infos.node.ethstats == "" { |
||||
fmt.Printf("What should the node be called on the stats page?\n") |
||||
infos.node.ethstats = w.readString() + ":" + w.conf.ethstats |
||||
} else { |
||||
fmt.Printf("What should the node be called on the stats page? (default = %s)\n", infos.node.ethstats) |
||||
infos.node.ethstats = w.readDefaultString(infos.node.ethstats) + ":" + w.conf.ethstats |
||||
} |
||||
// Load up the credential needed to release funds
|
||||
if infos.node.keyJSON != "" { |
||||
var key keystore.Key |
||||
if err := json.Unmarshal([]byte(infos.node.keyJSON), &key); err != nil { |
||||
infos.node.keyJSON, infos.node.keyPass = "", "" |
||||
} else { |
||||
fmt.Println() |
||||
fmt.Printf("Reuse previous (%s) funding account (y/n)? (default = yes)\n", key.Address.Hex()) |
||||
if w.readDefaultString("y") != "y" { |
||||
infos.node.keyJSON, infos.node.keyPass = "", "" |
||||
} |
||||
} |
||||
} |
||||
if infos.node.keyJSON == "" { |
||||
fmt.Println() |
||||
fmt.Println("Please paste the faucet's funding account key JSON:") |
||||
infos.node.keyJSON = w.readJSON() |
||||
|
||||
fmt.Println() |
||||
fmt.Println("What's the unlock password for the account? (won't be echoed)") |
||||
infos.node.keyPass = w.readPassword() |
||||
|
||||
if _, err := keystore.DecryptKey([]byte(infos.node.keyJSON), infos.node.keyPass); err != nil { |
||||
log.Error("Failed to decrypt key with given passphrase") |
||||
return |
||||
} |
||||
} |
||||
// Try to deploy the faucet server on the host
|
||||
if out, err := deployFaucet(client, w.network, w.conf.bootLight, infos); err != nil { |
||||
log.Error("Failed to deploy faucet container", "err", err) |
||||
if len(out) > 0 { |
||||
fmt.Printf("%s\n", out) |
||||
} |
||||
return |
||||
} |
||||
// All ok, run a network scan to pick any changes up
|
||||
w.networkStats(false) |
||||
} |
@ -0,0 +1,136 @@ |
||||
// Copyright 2017 The go-ethereum Authors
|
||||
// This file is part of go-ethereum.
|
||||
//
|
||||
// go-ethereum is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// go-ethereum is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU General Public License
|
||||
// along with go-ethereum. If not, see <http://www.gnu.org/licenses/>.
|
||||
|
||||
package main |
||||
|
||||
import ( |
||||
"bytes" |
||||
"fmt" |
||||
"math/big" |
||||
"math/rand" |
||||
"time" |
||||
|
||||
"github.com/ethereum/go-ethereum/common" |
||||
"github.com/ethereum/go-ethereum/core" |
||||
"github.com/ethereum/go-ethereum/log" |
||||
"github.com/ethereum/go-ethereum/params" |
||||
) |
||||
|
||||
// makeGenesis creates a new genesis struct based on some user input.
|
||||
func (w *wizard) makeGenesis() { |
||||
// Construct a default genesis block
|
||||
genesis := &core.Genesis{ |
||||
Timestamp: uint64(time.Now().Unix()), |
||||
GasLimit: 4700000, |
||||
Difficulty: big.NewInt(1048576), |
||||
Alloc: make(core.GenesisAlloc), |
||||
Config: ¶ms.ChainConfig{ |
||||
HomesteadBlock: big.NewInt(1), |
||||
EIP150Block: big.NewInt(2), |
||||
EIP155Block: big.NewInt(3), |
||||
EIP158Block: big.NewInt(3), |
||||
}, |
||||
} |
||||
// Figure out which consensus engine to choose
|
||||
fmt.Println() |
||||
fmt.Println("Which consensus engine to use? (default = clique)") |
||||
fmt.Println(" 1. Ethash - proof-of-work") |
||||
fmt.Println(" 2. Clique - proof-of-authority") |
||||
|
||||
choice := w.read() |
||||
switch { |
||||
case choice == "1": |
||||
// In case of ethash, we're pretty much done
|
||||
genesis.Config.Ethash = new(params.EthashConfig) |
||||
genesis.ExtraData = make([]byte, 32) |
||||
|
||||
case choice == "" || choice == "2": |
||||
// In the case of clique, configure the consensus parameters
|
||||
genesis.Difficulty = big.NewInt(1) |
||||
genesis.Config.Clique = ¶ms.CliqueConfig{ |
||||
Period: 15, |
||||
Epoch: 30000, |
||||
} |
||||
fmt.Println() |
||||
fmt.Println("How many seconds should blocks take? (default = 15)") |
||||
genesis.Config.Clique.Period = uint64(w.readDefaultInt(15)) |
||||
|
||||
// We also need the initial list of signers
|
||||
fmt.Println() |
||||
fmt.Println("Which accounts are allowed to seal? (mandatory at least one)") |
||||
|
||||
var signers []common.Address |
||||
for { |
||||
if address := w.readAddress(); address != nil { |
||||
signers = append(signers, *address) |
||||
continue |
||||
} |
||||
if len(signers) > 0 { |
||||
break |
||||
} |
||||
} |
||||
// Sort the signers and embed into the extra-data section
|
||||
for i := 0; i < len(signers); i++ { |
||||
for j := i + 1; j < len(signers); j++ { |
||||
if bytes.Compare(signers[i][:], signers[j][:]) > 0 { |
||||
signers[i], signers[j] = signers[j], signers[i] |
||||
} |
||||
} |
||||
} |
||||
genesis.ExtraData = make([]byte, 32+len(signers)*common.AddressLength+65) |
||||
for i, signer := range signers { |
||||
copy(genesis.ExtraData[32+i*common.AddressLength:], signer[:]) |
||||
} |
||||
|
||||
default: |
||||
log.Crit("Invalid consensus engine choice", "choice", choice) |
||||
} |
||||
// Consensus all set, just ask for initial funds and go
|
||||
fmt.Println() |
||||
fmt.Println("Which accounts should be pre-funded? (advisable at least one)") |
||||
for { |
||||
// Read the address of the account to fund
|
||||
if address := w.readAddress(); address != nil { |
||||
genesis.Alloc[*address] = core.GenesisAccount{ |
||||
Balance: new(big.Int).Lsh(big.NewInt(1), 256-7), // 2^256 / 128 (allow many pre-funds without balance overflows)
|
||||
} |
||||
continue |
||||
} |
||||
break |
||||
} |
||||
// Add a batch of precompile balances to avoid them getting deleted
|
||||
for i := int64(0); i < 256; i++ { |
||||
genesis.Alloc[common.BigToAddress(big.NewInt(i))] = core.GenesisAccount{Balance: big.NewInt(1)} |
||||
} |
||||
fmt.Println() |
||||
|
||||
// Query the user for some custom extras
|
||||
fmt.Println() |
||||
fmt.Println("Specify your chain/network ID if you want an explicit one (default = random)") |
||||
genesis.Config.ChainId = big.NewInt(int64(w.readDefaultInt(rand.Intn(65536)))) |
||||
|
||||
fmt.Println() |
||||
fmt.Println("Anything fun to embed into the genesis block? (max 32 bytes)") |
||||
|
||||
extra := w.read() |
||||
if len(extra) > 32 { |
||||
extra = extra[:32] |
||||
} |
||||
genesis.ExtraData = append([]byte(extra), genesis.ExtraData[len(extra):]...) |
||||
|
||||
// All done, store the genesis and flush to disk
|
||||
w.conf.genesis = genesis |
||||
} |
@ -0,0 +1,153 @@ |
||||
// Copyright 2017 The go-ethereum Authors
|
||||
// This file is part of go-ethereum.
|
||||
//
|
||||
// go-ethereum is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// go-ethereum is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU General Public License
|
||||
// along with go-ethereum. If not, see <http://www.gnu.org/licenses/>.
|
||||
|
||||
package main |
||||
|
||||
import ( |
||||
"bufio" |
||||
"encoding/json" |
||||
"fmt" |
||||
"io/ioutil" |
||||
"os" |
||||
"path/filepath" |
||||
"strings" |
||||
|
||||
"github.com/ethereum/go-ethereum/log" |
||||
) |
||||
|
||||
// makeWizard creates and returns a new puppeth wizard.
|
||||
func makeWizard(network string) *wizard { |
||||
return &wizard{ |
||||
network: network, |
||||
servers: make(map[string]*sshClient), |
||||
services: make(map[string][]string), |
||||
in: bufio.NewReader(os.Stdin), |
||||
} |
||||
} |
||||
|
||||
// run displays some useful infos to the user, starting on the journey of
|
||||
// setting up a new or managing an existing Ethereum private network.
|
||||
func (w *wizard) run() { |
||||
fmt.Println("+-----------------------------------------------------------+") |
||||
fmt.Println("| Welcome to puppeth, your Ethereum private network manager |") |
||||
fmt.Println("| |") |
||||
fmt.Println("| This tool lets you create a new Ethereum network down to |") |
||||
fmt.Println("| the genesis block, bootnodes, miners and ethstats servers |") |
||||
fmt.Println("| without the hassle that it would normally entail. |") |
||||
fmt.Println("| |") |
||||
fmt.Println("| Puppeth uses SSH to dial in to remote servers, and builds |") |
||||
fmt.Println("| its network components out of Docker containers using the |") |
||||
fmt.Println("| docker-compose toolset. |") |
||||
fmt.Println("+-----------------------------------------------------------+") |
||||
fmt.Println() |
||||
|
||||
// Make sure we have a good network name to work with fmt.Println()
|
||||
if w.network == "" { |
||||
fmt.Println("Please specify a network name to administer (no spaces, please)") |
||||
for { |
||||
w.network = w.readString() |
||||
if !strings.Contains(w.network, " ") { |
||||
fmt.Printf("Sweet, you can set this via --network=%s next time!\n\n", w.network) |
||||
break |
||||
} |
||||
log.Error("I also like to live dangerously, still no spaces") |
||||
} |
||||
} |
||||
log.Info("Administering Ethereum network", "name", w.network) |
||||
|
||||
// Load initial configurations and connect to all live servers
|
||||
w.conf.path = filepath.Join(os.Getenv("HOME"), ".puppeth", w.network) |
||||
|
||||
blob, err := ioutil.ReadFile(w.conf.path) |
||||
if err != nil { |
||||
log.Warn("No previous configurations found", "path", w.conf.path) |
||||
} else if err := json.Unmarshal(blob, &w.conf); err != nil { |
||||
log.Crit("Previous configuration corrupted", "path", w.conf.path, "err", err) |
||||
} else { |
||||
for _, server := range w.conf.Servers { |
||||
log.Info("Dialing previously configured server", "server", server) |
||||
client, err := dial(server) |
||||
if err != nil { |
||||
log.Error("Previous server unreachable", "server", server, "err", err) |
||||
} |
||||
w.servers[server] = client |
||||
} |
||||
w.networkStats(false) |
||||
} |
||||
// Basics done, loop ad infinitum about what to do
|
||||
for { |
||||
fmt.Println() |
||||
fmt.Println("What would you like to do? (default = stats)") |
||||
fmt.Println(" 1. Show network stats") |
||||
if w.conf.genesis == nil { |
||||
fmt.Println(" 2. Configure new genesis") |
||||
} else { |
||||
fmt.Println(" 2. Save existing genesis") |
||||
} |
||||
if len(w.servers) == 0 { |
||||
fmt.Println(" 3. Track new remote server") |
||||
} else { |
||||
fmt.Println(" 3. Manage tracked machines") |
||||
} |
||||
if len(w.services) == 0 { |
||||
fmt.Println(" 4. Deploy network components") |
||||
} else { |
||||
fmt.Println(" 4. Manage network components") |
||||
} |
||||
//fmt.Println(" 5. ProTips for common usecases")
|
||||
|
||||
choice := w.read() |
||||
switch { |
||||
case choice == "" || choice == "1": |
||||
w.networkStats(false) |
||||
|
||||
case choice == "2": |
||||
// If we don't have a genesis, make one
|
||||
if w.conf.genesis == nil { |
||||
w.makeGenesis() |
||||
} else { |
||||
// Otherwise just save whatever we currently have
|
||||
fmt.Println() |
||||
fmt.Printf("Which file to save the genesis into? (default = %s.json)\n", w.network) |
||||
out, _ := json.MarshalIndent(w.conf.genesis, "", " ") |
||||
if err := ioutil.WriteFile(w.readDefaultString(fmt.Sprintf("%s.json", w.network)), out, 0644); err != nil { |
||||
log.Error("Failed to save genesis file", "err", err) |
||||
} |
||||
log.Info("Exported existing genesis block") |
||||
} |
||||
case choice == "3": |
||||
if len(w.servers) == 0 { |
||||
if w.makeServer() != "" { |
||||
w.networkStats(false) |
||||
} |
||||
} else { |
||||
w.manageServers() |
||||
} |
||||
case choice == "4": |
||||
if len(w.services) == 0 { |
||||
w.deployComponent() |
||||
} else { |
||||
w.manageComponents() |
||||
} |
||||
|
||||
case choice == "5": |
||||
w.networkStats(true) |
||||
|
||||
default: |
||||
log.Error("That's not something I can do") |
||||
} |
||||
} |
||||
} |
@ -0,0 +1,235 @@ |
||||
// Copyright 2017 The go-ethereum Authors
|
||||
// This file is part of go-ethereum.
|
||||
//
|
||||
// go-ethereum is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// go-ethereum is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU General Public License
|
||||
// along with go-ethereum. If not, see <http://www.gnu.org/licenses/>.
|
||||
|
||||
package main |
||||
|
||||
import ( |
||||
"encoding/json" |
||||
"fmt" |
||||
"os" |
||||
"strings" |
||||
|
||||
"github.com/ethereum/go-ethereum/core" |
||||
"github.com/ethereum/go-ethereum/log" |
||||
"github.com/olekukonko/tablewriter" |
||||
) |
||||
|
||||
// networkStats verifies the status of network components and generates a protip
|
||||
// configuration set to give users hints on how to do various tasks.
|
||||
func (w *wizard) networkStats(tips bool) { |
||||
if len(w.servers) == 0 { |
||||
log.Error("No remote machines to gather stats from") |
||||
return |
||||
} |
||||
protips := new(protips) |
||||
|
||||
// Iterate over all the specified hosts and check their status
|
||||
stats := tablewriter.NewWriter(os.Stdout) |
||||
stats.SetHeader([]string{"Server", "IP", "Status", "Service", "Details"}) |
||||
stats.SetColWidth(128) |
||||
|
||||
for _, server := range w.conf.Servers { |
||||
client := w.servers[server] |
||||
logger := log.New("server", server) |
||||
logger.Info("Starting remote server health-check") |
||||
|
||||
// If the server is not connected, try to connect again
|
||||
if client == nil { |
||||
conn, err := dial(server) |
||||
if err != nil { |
||||
logger.Error("Failed to establish remote connection", "err", err) |
||||
stats.Append([]string{server, "", err.Error(), "", ""}) |
||||
continue |
||||
} |
||||
client = conn |
||||
} |
||||
// Client connected one way or another, run health-checks
|
||||
services := make(map[string]string) |
||||
logger.Debug("Checking for nginx availability") |
||||
if infos, err := checkNginx(client, w.network); err != nil { |
||||
if err != ErrServiceUnknown { |
||||
services["nginx"] = err.Error() |
||||
} |
||||
} else { |
||||
services["nginx"] = infos.String() |
||||
} |
||||
logger.Debug("Checking for ethstats availability") |
||||
if infos, err := checkEthstats(client, w.network); err != nil { |
||||
if err != ErrServiceUnknown { |
||||
services["ethstats"] = err.Error() |
||||
} |
||||
} else { |
||||
services["ethstats"] = infos.String() |
||||
protips.ethstats = infos.config |
||||
} |
||||
logger.Debug("Checking for bootnode availability") |
||||
if infos, err := checkNode(client, w.network, true); err != nil { |
||||
if err != ErrServiceUnknown { |
||||
services["bootnode"] = err.Error() |
||||
} |
||||
} else { |
||||
services["bootnode"] = infos.String() |
||||
|
||||
protips.genesis = string(infos.genesis) |
||||
protips.bootFull = append(protips.bootFull, infos.enodeFull) |
||||
if infos.enodeLight != "" { |
||||
protips.bootLight = append(protips.bootLight, infos.enodeLight) |
||||
} |
||||
} |
||||
logger.Debug("Checking for sealnode availability") |
||||
if infos, err := checkNode(client, w.network, false); err != nil { |
||||
if err != ErrServiceUnknown { |
||||
services["sealnode"] = err.Error() |
||||
} |
||||
} else { |
||||
services["sealnode"] = infos.String() |
||||
protips.genesis = string(infos.genesis) |
||||
} |
||||
logger.Debug("Checking for faucet availability") |
||||
if infos, err := checkFaucet(client, w.network); err != nil { |
||||
if err != ErrServiceUnknown { |
||||
services["faucet"] = err.Error() |
||||
} |
||||
} else { |
||||
services["faucet"] = infos.String() |
||||
} |
||||
logger.Debug("Checking for dashboard availability") |
||||
if infos, err := checkDashboard(client, w.network); err != nil { |
||||
if err != ErrServiceUnknown { |
||||
services["dashboard"] = err.Error() |
||||
} |
||||
} else { |
||||
services["dashboard"] = infos.String() |
||||
} |
||||
// All status checks complete, report and check next server
|
||||
delete(w.services, server) |
||||
for service := range services { |
||||
w.services[server] = append(w.services[server], service) |
||||
} |
||||
server, address := client.server, client.address |
||||
for service, status := range services { |
||||
stats.Append([]string{server, address, "online", service, status}) |
||||
server, address = "", "" |
||||
} |
||||
if len(services) == 0 { |
||||
stats.Append([]string{server, address, "online", "", ""}) |
||||
} |
||||
} |
||||
// If a genesis block was found, load it into our configs
|
||||
if protips.genesis != "" { |
||||
genesis := new(core.Genesis) |
||||
if err := json.Unmarshal([]byte(protips.genesis), genesis); err != nil { |
||||
log.Error("Failed to parse remote genesis", "err", err) |
||||
} else { |
||||
w.conf.genesis = genesis |
||||
protips.network = genesis.Config.ChainId.Int64() |
||||
} |
||||
} |
||||
if protips.ethstats != "" { |
||||
w.conf.ethstats = protips.ethstats |
||||
} |
||||
w.conf.bootFull = protips.bootFull |
||||
w.conf.bootLight = protips.bootLight |
||||
|
||||
// Print any collected stats and return
|
||||
if !tips { |
||||
stats.Render() |
||||
} else { |
||||
protips.print(w.network) |
||||
} |
||||
} |
||||
|
||||
// protips contains a collection of network infos to report pro-tips
|
||||
// based on.
|
||||
type protips struct { |
||||
genesis string |
||||
network int64 |
||||
bootFull []string |
||||
bootLight []string |
||||
ethstats string |
||||
} |
||||
|
||||
// print analyzes the network information available and prints a collection of
|
||||
// pro tips for the user's consideration.
|
||||
func (p *protips) print(network string) { |
||||
// If a known genesis block is available, display it and prepend an init command
|
||||
fullinit, lightinit := "", "" |
||||
if p.genesis != "" { |
||||
fullinit = fmt.Sprintf("geth --datadir=$HOME/.%s init %s.json && ", network, network) |
||||
lightinit = fmt.Sprintf("geth --datadir=$HOME/.%s --light init %s.json && ", network, network) |
||||
} |
||||
// If an ethstats server is available, add the ethstats flag
|
||||
statsflag := "" |
||||
if p.ethstats != "" { |
||||
if strings.Contains(p.ethstats, " ") { |
||||
statsflag = fmt.Sprintf(` --ethstats="yournode:%s"`, p.ethstats) |
||||
} else { |
||||
statsflag = fmt.Sprintf(` --ethstats=yournode:%s`, p.ethstats) |
||||
} |
||||
} |
||||
// If bootnodes have been specified, add the bootnode flag
|
||||
bootflagFull := "" |
||||
if len(p.bootFull) > 0 { |
||||
bootflagFull = fmt.Sprintf(` --bootnodes %s`, strings.Join(p.bootFull, ",")) |
||||
} |
||||
bootflagLight := "" |
||||
if len(p.bootLight) > 0 { |
||||
bootflagLight = fmt.Sprintf(` --bootnodes %s`, strings.Join(p.bootLight, ",")) |
||||
} |
||||
// Assemble all the known pro-tips
|
||||
var tasks, tips []string |
||||
|
||||
tasks = append(tasks, "Run an archive node with historical data") |
||||
tips = append(tips, fmt.Sprintf("%sgeth --networkid=%d --datadir=$HOME/.%s --cache=1024%s%s", fullinit, p.network, network, statsflag, bootflagFull)) |
||||
|
||||
tasks = append(tasks, "Run a full node with recent data only") |
||||
tips = append(tips, fmt.Sprintf("%sgeth --networkid=%d --datadir=$HOME/.%s --cache=512 --fast%s%s", fullinit, p.network, network, statsflag, bootflagFull)) |
||||
|
||||
tasks = append(tasks, "Run a light node with on demand retrievals") |
||||
tips = append(tips, fmt.Sprintf("%sgeth --networkid=%d --datadir=$HOME/.%s --light%s%s", lightinit, p.network, network, statsflag, bootflagLight)) |
||||
|
||||
tasks = append(tasks, "Run an embedded node with constrained memory") |
||||
tips = append(tips, fmt.Sprintf("%sgeth --networkid=%d --datadir=$HOME/.%s --cache=32 --light%s%s", lightinit, p.network, network, statsflag, bootflagLight)) |
||||
|
||||
// If the tips are short, display in a table
|
||||
short := true |
||||
for _, tip := range tips { |
||||
if len(tip) > 100 { |
||||
short = false |
||||
break |
||||
} |
||||
} |
||||
fmt.Println() |
||||
if short { |
||||
howto := tablewriter.NewWriter(os.Stdout) |
||||
howto.SetHeader([]string{"Fun tasks for you", "Tips on how to"}) |
||||
howto.SetColWidth(100) |
||||
|
||||
for i := 0; i < len(tasks); i++ { |
||||
howto.Append([]string{tasks[i], tips[i]}) |
||||
} |
||||
howto.Render() |
||||
return |
||||
} |
||||
// Meh, tips got ugly, split into many lines
|
||||
for i := 0; i < len(tasks); i++ { |
||||
fmt.Println(tasks[i]) |
||||
fmt.Println(strings.Repeat("-", len(tasks[i]))) |
||||
fmt.Println(tips[i]) |
||||
fmt.Println() |
||||
fmt.Println() |
||||
} |
||||
} |
@ -0,0 +1,194 @@ |
||||
// Copyright 2017 The go-ethereum Authors
|
||||
// This file is part of go-ethereum.
|
||||
//
|
||||
// go-ethereum is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// go-ethereum is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU General Public License
|
||||
// along with go-ethereum. If not, see <http://www.gnu.org/licenses/>.
|
||||
|
||||
package main |
||||
|
||||
import ( |
||||
"fmt" |
||||
"strings" |
||||
|
||||
"github.com/ethereum/go-ethereum/log" |
||||
) |
||||
|
||||
// manageServers displays a list of servers the user can disconnect from, and an
|
||||
// option to connect to new servers.
|
||||
func (w *wizard) manageServers() { |
||||
// List all the servers we can disconnect, along with an entry to connect a new one
|
||||
fmt.Println() |
||||
for i, server := range w.conf.Servers { |
||||
fmt.Printf(" %d. Disconnect %s\n", i+1, server) |
||||
} |
||||
fmt.Printf(" %d. Connect another server\n", len(w.conf.Servers)+1) |
||||
|
||||
choice := w.readInt() |
||||
if choice < 0 || choice > len(w.conf.Servers)+1 { |
||||
log.Error("Invalid server choice, aborting") |
||||
return |
||||
} |
||||
// If the user selected an existing server, drop it
|
||||
if choice <= len(w.conf.Servers) { |
||||
server := w.conf.Servers[choice-1] |
||||
client := w.servers[server] |
||||
|
||||
delete(w.servers, server) |
||||
if client != nil { |
||||
client.Close() |
||||
} |
||||
w.conf.Servers = append(w.conf.Servers[:choice-1], w.conf.Servers[choice:]...) |
||||
w.conf.flush() |
||||
|
||||
log.Info("Disconnected existing server", "server", server) |
||||
w.networkStats(false) |
||||
return |
||||
} |
||||
// If the user requested connecting a new server, do it
|
||||
if w.makeServer() != "" { |
||||
w.networkStats(false) |
||||
} |
||||
} |
||||
|
||||
// makeServer reads a single line from stdin and interprets it as a hostname to
|
||||
// connect to. It tries to establish a new SSH session and also executing some
|
||||
// baseline validations.
|
||||
//
|
||||
// If connection succeeds, the server is added to the wizards configs!
|
||||
func (w *wizard) makeServer() string { |
||||
fmt.Println() |
||||
fmt.Println("Please enter remote server's address:") |
||||
|
||||
for { |
||||
// Read and fial the server to ensure docker is present
|
||||
input := w.readString() |
||||
|
||||
client, err := dial(input) |
||||
if err != nil { |
||||
log.Error("Server not ready for puppeth", "err", err) |
||||
return "" |
||||
} |
||||
// All checks passed, start tracking the server
|
||||
w.servers[input] = client |
||||
w.conf.Servers = append(w.conf.Servers, input) |
||||
w.conf.flush() |
||||
|
||||
return input |
||||
} |
||||
} |
||||
|
||||
// selectServer lists the user all the currnetly known servers to choose from,
|
||||
// also granting the option to add a new one.
|
||||
func (w *wizard) selectServer() string { |
||||
// List the available server to the user and wait for a choice
|
||||
fmt.Println() |
||||
fmt.Println("Which server do you want to interact with?") |
||||
for i, server := range w.conf.Servers { |
||||
fmt.Printf(" %d. %s\n", i+1, server) |
||||
} |
||||
fmt.Printf(" %d. Connect another server\n", len(w.conf.Servers)+1) |
||||
|
||||
choice := w.readInt() |
||||
if choice < 0 || choice > len(w.conf.Servers)+1 { |
||||
log.Error("Invalid server choice, aborting") |
||||
return "" |
||||
} |
||||
// If the user requested connecting to a new server, go for it
|
||||
if choice <= len(w.conf.Servers) { |
||||
return w.conf.Servers[choice-1] |
||||
} |
||||
return w.makeServer() |
||||
} |
||||
|
||||
// manageComponents displays a list of network components the user can tear down
|
||||
// and an option
|
||||
func (w *wizard) manageComponents() { |
||||
// List all the componens we can tear down, along with an entry to deploy a new one
|
||||
fmt.Println() |
||||
|
||||
var serviceHosts, serviceNames []string |
||||
for server, services := range w.services { |
||||
for _, service := range services { |
||||
serviceHosts = append(serviceHosts, server) |
||||
serviceNames = append(serviceNames, service) |
||||
|
||||
fmt.Printf(" %d. Tear down %s on %s\n", len(serviceHosts), strings.Title(service), server) |
||||
} |
||||
} |
||||
fmt.Printf(" %d. Deploy new network component\n", len(serviceHosts)+1) |
||||
|
||||
choice := w.readInt() |
||||
if choice < 0 || choice > len(serviceHosts)+1 { |
||||
log.Error("Invalid component choice, aborting") |
||||
return |
||||
} |
||||
// If the user selected an existing service, destroy it
|
||||
if choice <= len(serviceHosts) { |
||||
// Figure out the service to destroy and execute it
|
||||
service := serviceNames[choice-1] |
||||
server := serviceHosts[choice-1] |
||||
client := w.servers[server] |
||||
|
||||
if out, err := tearDown(client, w.network, service, true); err != nil { |
||||
log.Error("Failed to tear down component", "err", err) |
||||
if len(out) > 0 { |
||||
fmt.Printf("%s\n", out) |
||||
} |
||||
return |
||||
} |
||||
// Clean up any references to it from out state
|
||||
services := w.services[server] |
||||
for i, name := range services { |
||||
if name == service { |
||||
w.services[server] = append(services[:i], services[i+1:]...) |
||||
if len(w.services[server]) == 0 { |
||||
delete(w.services, server) |
||||
} |
||||
} |
||||
} |
||||
log.Info("Torn down existing component", "server", server, "service", service) |
||||
return |
||||
} |
||||
// If the user requested deploying a new component, do it
|
||||
w.deployComponent() |
||||
} |
||||
|
||||
// deployComponent displays a list of network components the user can deploy and
|
||||
// guides through the process.
|
||||
func (w *wizard) deployComponent() { |
||||
// Print all the things we can deploy and wait or user choice
|
||||
fmt.Println() |
||||
fmt.Println("What would you like to deploy? (recommended order)") |
||||
fmt.Println(" 1. Ethstats - Network monitoring tool") |
||||
fmt.Println(" 2. Bootnode - Entry point of the network") |
||||
fmt.Println(" 3. Sealer - Full node minting new blocks") |
||||
fmt.Println(" 4. Wallet - Browser wallet for quick sends (todo)") |
||||
fmt.Println(" 5. Faucet - Crypto faucet to give away funds") |
||||
fmt.Println(" 6. Dashboard - Website listing above web-services") |
||||
|
||||
switch w.read() { |
||||
case "1": |
||||
w.deployEthstats() |
||||
case "2": |
||||
w.deployNode(true) |
||||
case "3": |
||||
w.deployNode(false) |
||||
case "4": |
||||
case "5": |
||||
w.deployFaucet() |
||||
case "6": |
||||
w.deployDashboard() |
||||
default: |
||||
log.Error("That's not something I can do") |
||||
} |
||||
} |
@ -0,0 +1,58 @@ |
||||
// Copyright 2017 The go-ethereum Authors
|
||||
// This file is part of go-ethereum.
|
||||
//
|
||||
// go-ethereum is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// go-ethereum is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU General Public License
|
||||
// along with go-ethereum. If not, see <http://www.gnu.org/licenses/>.
|
||||
|
||||
package main |
||||
|
||||
import ( |
||||
"fmt" |
||||
|
||||
"github.com/ethereum/go-ethereum/log" |
||||
) |
||||
|
||||
// ensureVirtualHost checks whether a reverse-proxy is running on the specified
|
||||
// host machine, and if yes requests a virtual host from the user to host a
|
||||
// specific web service on. If no proxy exists, the method will offer to deploy
|
||||
// one.
|
||||
//
|
||||
// If the user elects not to use a reverse proxy, an empty hostname is returned!
|
||||
func (w *wizard) ensureVirtualHost(client *sshClient, port int, def string) (string, error) { |
||||
if proxy, _ := checkNginx(client, w.network); proxy != nil { |
||||
// Reverse proxy is running, if ports match, we need a virtual host
|
||||
if proxy.port == port { |
||||
fmt.Println() |
||||
fmt.Printf("Shared port, which domain to assign? (default = %s)\n", def) |
||||
return w.readDefaultString(def), nil |
||||
} |
||||
} |
||||
// Reverse proxy is not running, offer to deploy a new one
|
||||
fmt.Println() |
||||
fmt.Println("Allow sharing the port with other services (y/n)? (default = yes)") |
||||
if w.readDefaultString("y") == "y" { |
||||
if out, err := deployNginx(client, w.network, port); err != nil { |
||||
log.Error("Failed to deploy reverse-proxy", "err", err) |
||||
if len(out) > 0 { |
||||
fmt.Printf("%s\n", out) |
||||
} |
||||
return "", err |
||||
} |
||||
// Reverse proxy deployed, ask again for the virtual-host
|
||||
fmt.Println() |
||||
fmt.Printf("Proxy deployed, which domain to assign? (default = %s)\n", def) |
||||
return w.readDefaultString(def), nil |
||||
} |
||||
// Reverse proxy not requested, deploy as a standalone service
|
||||
return "", nil |
||||
} |
@ -0,0 +1,153 @@ |
||||
// Copyright 2017 The go-ethereum Authors
|
||||
// This file is part of go-ethereum.
|
||||
//
|
||||
// go-ethereum is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// go-ethereum is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU General Public License
|
||||
// along with go-ethereum. If not, see <http://www.gnu.org/licenses/>.
|
||||
|
||||
package main |
||||
|
||||
import ( |
||||
"encoding/json" |
||||
"fmt" |
||||
"time" |
||||
|
||||
"github.com/ethereum/go-ethereum/accounts/keystore" |
||||
"github.com/ethereum/go-ethereum/common" |
||||
"github.com/ethereum/go-ethereum/log" |
||||
) |
||||
|
||||
// deployNode creates a new node configuration based on some user input.
|
||||
func (w *wizard) deployNode(boot bool) { |
||||
// Do some sanity check before the user wastes time on input
|
||||
if w.conf.genesis == nil { |
||||
log.Error("No genesis block configured") |
||||
return |
||||
} |
||||
if w.conf.ethstats == "" { |
||||
log.Error("No ethstats server configured") |
||||
return |
||||
} |
||||
// Select the server to interact with
|
||||
server := w.selectServer() |
||||
if server == "" { |
||||
return |
||||
} |
||||
client := w.servers[server] |
||||
|
||||
// Retrieve any active ethstats configurations from the server
|
||||
infos, err := checkNode(client, w.network, boot) |
||||
if err != nil { |
||||
if boot { |
||||
infos = &nodeInfos{portFull: 30303, peersTotal: 512, peersLight: 256} |
||||
} else { |
||||
infos = &nodeInfos{portFull: 30303, peersTotal: 50, peersLight: 0} |
||||
} |
||||
} |
||||
infos.genesis, _ = json.MarshalIndent(w.conf.genesis, "", " ") |
||||
infos.network = w.conf.genesis.Config.ChainId.Int64() |
||||
|
||||
// Figure out where the user wants to store the persistent data
|
||||
fmt.Println() |
||||
if infos.datadir == "" { |
||||
fmt.Printf("Where should data be stored on the remote machine?\n") |
||||
infos.datadir = w.readString() |
||||
} else { |
||||
fmt.Printf("Where should data be stored on the remote machine? (default = %s)\n", infos.datadir) |
||||
infos.datadir = w.readDefaultString(infos.datadir) |
||||
} |
||||
// Figure out which port to listen on
|
||||
fmt.Println() |
||||
fmt.Printf("Which TCP/UDP port to listen on? (default = %d)\n", infos.portFull) |
||||
infos.portFull = w.readDefaultInt(infos.portFull) |
||||
|
||||
// Figure out how many peers to allow (different based on node type)
|
||||
fmt.Println() |
||||
fmt.Printf("How many peers to allow connecting? (default = %d)\n", infos.peersTotal) |
||||
infos.peersTotal = w.readDefaultInt(infos.peersTotal) |
||||
|
||||
// Figure out how many light peers to allow (different based on node type)
|
||||
fmt.Println() |
||||
fmt.Printf("How many light peers to allow connecting? (default = %d)\n", infos.peersLight) |
||||
infos.peersLight = w.readDefaultInt(infos.peersLight) |
||||
|
||||
// Set a proper name to report on the stats page
|
||||
fmt.Println() |
||||
if infos.ethstats == "" { |
||||
fmt.Printf("What should the node be called on the stats page?\n") |
||||
infos.ethstats = w.readString() + ":" + w.conf.ethstats |
||||
} else { |
||||
fmt.Printf("What should the node be called on the stats page? (default = %s)\n", infos.ethstats) |
||||
infos.ethstats = w.readDefaultString(infos.ethstats) + ":" + w.conf.ethstats |
||||
} |
||||
// If the node is a miner/signer, load up needed credentials
|
||||
if !boot { |
||||
if w.conf.genesis.Config.Ethash != nil { |
||||
// Ethash based miners only need an etherbase to mine against
|
||||
fmt.Println() |
||||
if infos.etherbase == "" { |
||||
fmt.Printf("What address should the miner user?\n") |
||||
for { |
||||
if address := w.readAddress(); address != nil { |
||||
infos.etherbase = address.Hex() |
||||
break |
||||
} |
||||
} |
||||
} else { |
||||
fmt.Printf("What address should the miner user? (default = %s)\n", infos.etherbase) |
||||
infos.etherbase = w.readDefaultAddress(common.HexToAddress(infos.etherbase)).Hex() |
||||
} |
||||
} else if w.conf.genesis.Config.Clique != nil { |
||||
// If a previous signer was already set, offer to reuse it
|
||||
if infos.keyJSON != "" { |
||||
var key keystore.Key |
||||
if err := json.Unmarshal([]byte(infos.keyJSON), &key); err != nil { |
||||
infos.keyJSON, infos.keyPass = "", "" |
||||
} else { |
||||
fmt.Println() |
||||
fmt.Printf("Reuse previous (%s) signing account (y/n)? (default = yes)\n", key.Address.Hex()) |
||||
if w.readDefaultString("y") != "y" { |
||||
infos.keyJSON, infos.keyPass = "", "" |
||||
} |
||||
} |
||||
} |
||||
// Clique based signers need a keyfile and unlock password, ask if unavailable
|
||||
if infos.keyJSON == "" { |
||||
fmt.Println() |
||||
fmt.Println("Please paste the signer's key JSON:") |
||||
infos.keyJSON = w.readJSON() |
||||
|
||||
fmt.Println() |
||||
fmt.Println("What's the unlock password for the account? (won't be echoed)") |
||||
infos.keyPass = w.readPassword() |
||||
|
||||
if _, err := keystore.DecryptKey([]byte(infos.keyJSON), infos.keyPass); err != nil { |
||||
log.Error("Failed to decrypt key with given passphrase") |
||||
return |
||||
} |
||||
} |
||||
} |
||||
} |
||||
// Try to deploy the full node on the host
|
||||
if out, err := deployNode(client, w.network, w.conf.bootFull, infos); err != nil { |
||||
log.Error("Failed to deploy Ethereum node container", "err", err) |
||||
if len(out) > 0 { |
||||
fmt.Printf("%s\n", out) |
||||
} |
||||
return |
||||
} |
||||
// All ok, run a network scan to pick any changes up
|
||||
log.Info("Waiting for node to finish booting") |
||||
time.Sleep(3 * time.Second) |
||||
|
||||
w.networkStats(false) |
||||
} |
@ -0,0 +1,19 @@ |
||||
Copyright (C) 2014 by Oleku Konko |
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy |
||||
of this software and associated documentation files (the "Software"), to deal |
||||
in the Software without restriction, including without limitation the rights |
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell |
||||
copies of the Software, and to permit persons to whom the Software is |
||||
furnished to do so, subject to the following conditions: |
||||
|
||||
The above copyright notice and this permission notice shall be included in |
||||
all copies or substantial portions of the Software. |
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR |
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, |
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE |
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER |
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, |
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN |
||||
THE SOFTWARE. |
@ -0,0 +1,204 @@ |
||||
ASCII Table Writer |
||||
========= |
||||
|
||||
[![Build Status](https://travis-ci.org/olekukonko/tablewriter.png?branch=master)](https://travis-ci.org/olekukonko/tablewriter) [![Total views](https://sourcegraph.com/api/repos/github.com/olekukonko/tablewriter/counters/views.png)](https://sourcegraph.com/github.com/olekukonko/tablewriter) |
||||
|
||||
Generate ASCII table on the fly ... Installation is simple as |
||||
|
||||
go get github.com/olekukonko/tablewriter |
||||
|
||||
|
||||
#### Features |
||||
- Automatic Padding |
||||
- Support Multiple Lines |
||||
- Supports Alignment |
||||
- Support Custom Separators |
||||
- Automatic Alignment of numbers & percentage |
||||
- Write directly to http , file etc via `io.Writer` |
||||
- Read directly from CSV file |
||||
- Optional row line via `SetRowLine` |
||||
- Normalise table header |
||||
- Make CSV Headers optional |
||||
- Enable or disable table border |
||||
- Set custom footer support |
||||
- Optional identical cells merging |
||||
|
||||
|
||||
#### Example 1 - Basic |
||||
```go |
||||
data := [][]string{ |
||||
[]string{"A", "The Good", "500"}, |
||||
[]string{"B", "The Very very Bad Man", "288"}, |
||||
[]string{"C", "The Ugly", "120"}, |
||||
[]string{"D", "The Gopher", "800"}, |
||||
} |
||||
|
||||
table := tablewriter.NewWriter(os.Stdout) |
||||
table.SetHeader([]string{"Name", "Sign", "Rating"}) |
||||
|
||||
for _, v := range data { |
||||
table.Append(v) |
||||
} |
||||
table.Render() // Send output |
||||
``` |
||||
|
||||
##### Output 1 |
||||
``` |
||||
+------+-----------------------+--------+ |
||||
| NAME | SIGN | RATING | |
||||
+------+-----------------------+--------+ |
||||
| A | The Good | 500 | |
||||
| B | The Very very Bad Man | 288 | |
||||
| C | The Ugly | 120 | |
||||
| D | The Gopher | 800 | |
||||
+------+-----------------------+--------+ |
||||
``` |
||||
|
||||
#### Example 2 - Without Border / Footer / Bulk Append |
||||
```go |
||||
data := [][]string{ |
||||
[]string{"1/1/2014", "Domain name", "2233", "$10.98"}, |
||||
[]string{"1/1/2014", "January Hosting", "2233", "$54.95"}, |
||||
[]string{"1/4/2014", "February Hosting", "2233", "$51.00"}, |
||||
[]string{"1/4/2014", "February Extra Bandwidth", "2233", "$30.00"}, |
||||
} |
||||
|
||||
table := tablewriter.NewWriter(os.Stdout) |
||||
table.SetHeader([]string{"Date", "Description", "CV2", "Amount"}) |
||||
table.SetFooter([]string{"", "", "Total", "$146.93"}) // Add Footer |
||||
table.SetBorder(false) // Set Border to false |
||||
table.AppendBulk(data) // Add Bulk Data |
||||
table.Render() |
||||
``` |
||||
|
||||
##### Output 2 |
||||
``` |
||||
|
||||
DATE | DESCRIPTION | CV2 | AMOUNT |
||||
+----------+--------------------------+-------+---------+ |
||||
1/1/2014 | Domain name | 2233 | $10.98 |
||||
1/1/2014 | January Hosting | 2233 | $54.95 |
||||
1/4/2014 | February Hosting | 2233 | $51.00 |
||||
1/4/2014 | February Extra Bandwidth | 2233 | $30.00 |
||||
+----------+--------------------------+-------+---------+ |
||||
TOTAL | $146 93 |
||||
+-------+---------+ |
||||
|
||||
``` |
||||
|
||||
|
||||
#### Example 3 - CSV |
||||
```go |
||||
table, _ := tablewriter.NewCSV(os.Stdout, "test_info.csv", true) |
||||
table.SetAlignment(tablewriter.ALIGN_LEFT) // Set Alignment |
||||
table.Render() |
||||
``` |
||||
|
||||
##### Output 3 |
||||
``` |
||||
+----------+--------------+------+-----+---------+----------------+ |
||||
| FIELD | TYPE | NULL | KEY | DEFAULT | EXTRA | |
||||
+----------+--------------+------+-----+---------+----------------+ |
||||
| user_id | smallint(5) | NO | PRI | NULL | auto_increment | |
||||
| username | varchar(10) | NO | | NULL | | |
||||
| password | varchar(100) | NO | | NULL | | |
||||
+----------+--------------+------+-----+---------+----------------+ |
||||
``` |
||||
|
||||
#### Example 4 - Custom Separator |
||||
```go |
||||
table, _ := tablewriter.NewCSV(os.Stdout, "test.csv", true) |
||||
table.SetRowLine(true) // Enable row line |
||||
|
||||
// Change table lines |
||||
table.SetCenterSeparator("*") |
||||
table.SetColumnSeparator("‡") |
||||
table.SetRowSeparator("-") |
||||
|
||||
table.SetAlignment(tablewriter.ALIGN_LEFT) |
||||
table.Render() |
||||
``` |
||||
|
||||
##### Output 4 |
||||
``` |
||||
*------------*-----------*---------* |
||||
╪ FIRST NAME ╪ LAST NAME ╪ SSN ╪ |
||||
*------------*-----------*---------* |
||||
╪ John ╪ Barry ╪ 123456 ╪ |
||||
*------------*-----------*---------* |
||||
╪ Kathy ╪ Smith ╪ 687987 ╪ |
||||
*------------*-----------*---------* |
||||
╪ Bob ╪ McCornick ╪ 3979870 ╪ |
||||
*------------*-----------*---------* |
||||
``` |
||||
|
||||
##### Example 5 - Markdown Format |
||||
```go |
||||
data := [][]string{ |
||||
[]string{"1/1/2014", "Domain name", "2233", "$10.98"}, |
||||
[]string{"1/1/2014", "January Hosting", "2233", "$54.95"}, |
||||
[]string{"1/4/2014", "February Hosting", "2233", "$51.00"}, |
||||
[]string{"1/4/2014", "February Extra Bandwidth", "2233", "$30.00"}, |
||||
} |
||||
|
||||
table := tablewriter.NewWriter(os.Stdout) |
||||
table.SetHeader([]string{"Date", "Description", "CV2", "Amount"}) |
||||
table.SetBorders(tablewriter.Border{Left: true, Top: false, Right: true, Bottom: false}) |
||||
table.SetCenterSeparator("|") |
||||
table.AppendBulk(data) // Add Bulk Data |
||||
table.Render() |
||||
``` |
||||
|
||||
##### Output 5 |
||||
``` |
||||
| DATE | DESCRIPTION | CV2 | AMOUNT | |
||||
|----------|--------------------------|------|--------| |
||||
| 1/1/2014 | Domain name | 2233 | $10.98 | |
||||
| 1/1/2014 | January Hosting | 2233 | $54.95 | |
||||
| 1/4/2014 | February Hosting | 2233 | $51.00 | |
||||
| 1/4/2014 | February Extra Bandwidth | 2233 | $30.00 | |
||||
``` |
||||
|
||||
#### Example 6 - Identical cells merging |
||||
```go |
||||
data := [][]string{ |
||||
[]string{"1/1/2014", "Domain name", "1234", "$10.98"}, |
||||
[]string{"1/1/2014", "January Hosting", "2345", "$54.95"}, |
||||
[]string{"1/4/2014", "February Hosting", "3456", "$51.00"}, |
||||
[]string{"1/4/2014", "February Extra Bandwidth", "4567", "$30.00"}, |
||||
} |
||||
|
||||
table := tablewriter.NewWriter(os.Stdout) |
||||
table.SetHeader([]string{"Date", "Description", "CV2", "Amount"}) |
||||
table.SetFooter([]string{"", "", "Total", "$146.93"}) |
||||
table.SetAutoMergeCells(true) |
||||
table.SetRowLine(true) |
||||
table.AppendBulk(data) |
||||
table.Render() |
||||
``` |
||||
|
||||
##### Output 6 |
||||
``` |
||||
+----------+--------------------------+-------+---------+ |
||||
| DATE | DESCRIPTION | CV2 | AMOUNT | |
||||
+----------+--------------------------+-------+---------+ |
||||
| 1/1/2014 | Domain name | 1234 | $10.98 | |
||||
+ +--------------------------+-------+---------+ |
||||
| | January Hosting | 2345 | $54.95 | |
||||
+----------+--------------------------+-------+---------+ |
||||
| 1/4/2014 | February Hosting | 3456 | $51.00 | |
||||
+ +--------------------------+-------+---------+ |
||||
| | February Extra Bandwidth | 4567 | $30.00 | |
||||
+----------+--------------------------+-------+---------+ |
||||
| TOTAL | $146 93 | |
||||
+----------+--------------------------+-------+---------+ |
||||
``` |
||||
|
||||
#### TODO |
||||
- ~~Import Directly from CSV~~ - `done` |
||||
- ~~Support for `SetFooter`~~ - `done` |
||||
- ~~Support for `SetBorder`~~ - `done` |
||||
- ~~Support table with uneven rows~~ - `done` |
||||
- Support custom alignment |
||||
- General Improvement & Optimisation |
||||
- `NewHTML` Parse table from HTML |
@ -0,0 +1,52 @@ |
||||
// Copyright 2014 Oleku Konko All rights reserved.
|
||||
// Use of this source code is governed by a MIT
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
// This module is a Table Writer API for the Go Programming Language.
|
||||
// The protocols were written in pure Go and works on windows and unix systems
|
||||
|
||||
package tablewriter |
||||
|
||||
import ( |
||||
"encoding/csv" |
||||
"io" |
||||
"os" |
||||
) |
||||
|
||||
// Start A new table by importing from a CSV file
|
||||
// Takes io.Writer and csv File name
|
||||
func NewCSV(writer io.Writer, fileName string, hasHeader bool) (*Table, error) { |
||||
file, err := os.Open(fileName) |
||||
if err != nil { |
||||
return &Table{}, err |
||||
} |
||||
defer file.Close() |
||||
csvReader := csv.NewReader(file) |
||||
t, err := NewCSVReader(writer, csvReader, hasHeader) |
||||
return t, err |
||||
} |
||||
|
||||
// Start a New Table Writer with csv.Reader
|
||||
// This enables customisation such as reader.Comma = ';'
|
||||
// See http://golang.org/src/pkg/encoding/csv/reader.go?s=3213:3671#L94
|
||||
func NewCSVReader(writer io.Writer, csvReader *csv.Reader, hasHeader bool) (*Table, error) { |
||||
t := NewWriter(writer) |
||||
if hasHeader { |
||||
// Read the first row
|
||||
headers, err := csvReader.Read() |
||||
if err != nil { |
||||
return &Table{}, err |
||||
} |
||||
t.SetHeader(headers) |
||||
} |
||||
for { |
||||
record, err := csvReader.Read() |
||||
if err == io.EOF { |
||||
break |
||||
} else if err != nil { |
||||
return &Table{}, err |
||||
} |
||||
t.Append(record) |
||||
} |
||||
return t, nil |
||||
} |
@ -0,0 +1,662 @@ |
||||
// Copyright 2014 Oleku Konko All rights reserved.
|
||||
// Use of this source code is governed by a MIT
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
// This module is a Table Writer API for the Go Programming Language.
|
||||
// The protocols were written in pure Go and works on windows and unix systems
|
||||
|
||||
// Create & Generate text based table
|
||||
package tablewriter |
||||
|
||||
import ( |
||||
"bytes" |
||||
"fmt" |
||||
"io" |
||||
"regexp" |
||||
"strings" |
||||
) |
||||
|
||||
const ( |
||||
MAX_ROW_WIDTH = 30 |
||||
) |
||||
|
||||
const ( |
||||
CENTER = "+" |
||||
ROW = "-" |
||||
COLUMN = "|" |
||||
SPACE = " " |
||||
NEWLINE = "\n" |
||||
) |
||||
|
||||
const ( |
||||
ALIGN_DEFAULT = iota |
||||
ALIGN_CENTER |
||||
ALIGN_RIGHT |
||||
ALIGN_LEFT |
||||
) |
||||
|
||||
var ( |
||||
decimal = regexp.MustCompile(`^-*\d*\.?\d*$`) |
||||
percent = regexp.MustCompile(`^-*\d*\.?\d*$%$`) |
||||
) |
||||
|
||||
type Border struct { |
||||
Left bool |
||||
Right bool |
||||
Top bool |
||||
Bottom bool |
||||
} |
||||
|
||||
type Table struct { |
||||
out io.Writer |
||||
rows [][]string |
||||
lines [][][]string |
||||
cs map[int]int |
||||
rs map[int]int |
||||
headers []string |
||||
footers []string |
||||
autoFmt bool |
||||
autoWrap bool |
||||
mW int |
||||
pCenter string |
||||
pRow string |
||||
pColumn string |
||||
tColumn int |
||||
tRow int |
||||
hAlign int |
||||
fAlign int |
||||
align int |
||||
newLine string |
||||
rowLine bool |
||||
autoMergeCells bool |
||||
hdrLine bool |
||||
borders Border |
||||
colSize int |
||||
} |
||||
|
||||
// Start New Table
|
||||
// Take io.Writer Directly
|
||||
func NewWriter(writer io.Writer) *Table { |
||||
t := &Table{ |
||||
out: writer, |
||||
rows: [][]string{}, |
||||
lines: [][][]string{}, |
||||
cs: make(map[int]int), |
||||
rs: make(map[int]int), |
||||
headers: []string{}, |
||||
footers: []string{}, |
||||
autoFmt: true, |
||||
autoWrap: true, |
||||
mW: MAX_ROW_WIDTH, |
||||
pCenter: CENTER, |
||||
pRow: ROW, |
||||
pColumn: COLUMN, |
||||
tColumn: -1, |
||||
tRow: -1, |
||||
hAlign: ALIGN_DEFAULT, |
||||
fAlign: ALIGN_DEFAULT, |
||||
align: ALIGN_DEFAULT, |
||||
newLine: NEWLINE, |
||||
rowLine: false, |
||||
hdrLine: true, |
||||
borders: Border{Left: true, Right: true, Bottom: true, Top: true}, |
||||
colSize: -1} |
||||
return t |
||||
} |
||||
|
||||
// Render table output
|
||||
func (t Table) Render() { |
||||
if t.borders.Top { |
||||
t.printLine(true) |
||||
} |
||||
t.printHeading() |
||||
if t.autoMergeCells { |
||||
t.printRowsMergeCells() |
||||
} else { |
||||
t.printRows() |
||||
} |
||||
|
||||
if !t.rowLine && t.borders.Bottom { |
||||
t.printLine(true) |
||||
} |
||||
t.printFooter() |
||||
|
||||
} |
||||
|
||||
// Set table header
|
||||
func (t *Table) SetHeader(keys []string) { |
||||
t.colSize = len(keys) |
||||
for i, v := range keys { |
||||
t.parseDimension(v, i, -1) |
||||
t.headers = append(t.headers, v) |
||||
} |
||||
} |
||||
|
||||
// Set table Footer
|
||||
func (t *Table) SetFooter(keys []string) { |
||||
//t.colSize = len(keys)
|
||||
for i, v := range keys { |
||||
t.parseDimension(v, i, -1) |
||||
t.footers = append(t.footers, v) |
||||
} |
||||
} |
||||
|
||||
// Turn header autoformatting on/off. Default is on (true).
|
||||
func (t *Table) SetAutoFormatHeaders(auto bool) { |
||||
t.autoFmt = auto |
||||
} |
||||
|
||||
// Turn automatic multiline text adjustment on/off. Default is on (true).
|
||||
func (t *Table) SetAutoWrapText(auto bool) { |
||||
t.autoWrap = auto |
||||
} |
||||
|
||||
// Set the Default column width
|
||||
func (t *Table) SetColWidth(width int) { |
||||
t.mW = width |
||||
} |
||||
|
||||
// Set the Column Separator
|
||||
func (t *Table) SetColumnSeparator(sep string) { |
||||
t.pColumn = sep |
||||
} |
||||
|
||||
// Set the Row Separator
|
||||
func (t *Table) SetRowSeparator(sep string) { |
||||
t.pRow = sep |
||||
} |
||||
|
||||
// Set the center Separator
|
||||
func (t *Table) SetCenterSeparator(sep string) { |
||||
t.pCenter = sep |
||||
} |
||||
|
||||
// Set Header Alignment
|
||||
func (t *Table) SetHeaderAlignment(hAlign int) { |
||||
t.hAlign = hAlign |
||||
} |
||||
|
||||
// Set Footer Alignment
|
||||
func (t *Table) SetFooterAlignment(fAlign int) { |
||||
t.fAlign = fAlign |
||||
} |
||||
|
||||
// Set Table Alignment
|
||||
func (t *Table) SetAlignment(align int) { |
||||
t.align = align |
||||
} |
||||
|
||||
// Set New Line
|
||||
func (t *Table) SetNewLine(nl string) { |
||||
t.newLine = nl |
||||
} |
||||
|
||||
// Set Header Line
|
||||
// This would enable / disable a line after the header
|
||||
func (t *Table) SetHeaderLine(line bool) { |
||||
t.hdrLine = line |
||||
} |
||||
|
||||
// Set Row Line
|
||||
// This would enable / disable a line on each row of the table
|
||||
func (t *Table) SetRowLine(line bool) { |
||||
t.rowLine = line |
||||
} |
||||
|
||||
// Set Auto Merge Cells
|
||||
// This would enable / disable the merge of cells with identical values
|
||||
func (t *Table) SetAutoMergeCells(auto bool) { |
||||
t.autoMergeCells = auto |
||||
} |
||||
|
||||
// Set Table Border
|
||||
// This would enable / disable line around the table
|
||||
func (t *Table) SetBorder(border bool) { |
||||
t.SetBorders(Border{border, border, border, border}) |
||||
} |
||||
|
||||
func (t *Table) SetBorders(border Border) { |
||||
t.borders = border |
||||
} |
||||
|
||||
// Append row to table
|
||||
func (t *Table) Append(row []string) { |
||||
rowSize := len(t.headers) |
||||
if rowSize > t.colSize { |
||||
t.colSize = rowSize |
||||
} |
||||
|
||||
n := len(t.lines) |
||||
line := [][]string{} |
||||
for i, v := range row { |
||||
|
||||
// Detect string width
|
||||
// Detect String height
|
||||
// Break strings into words
|
||||
out := t.parseDimension(v, i, n) |
||||
|
||||
// Append broken words
|
||||
line = append(line, out) |
||||
} |
||||
t.lines = append(t.lines, line) |
||||
} |
||||
|
||||
// Allow Support for Bulk Append
|
||||
// Eliminates repeated for loops
|
||||
func (t *Table) AppendBulk(rows [][]string) { |
||||
for _, row := range rows { |
||||
t.Append(row) |
||||
} |
||||
} |
||||
|
||||
// Print line based on row width
|
||||
func (t Table) printLine(nl bool) { |
||||
fmt.Fprint(t.out, t.pCenter) |
||||
for i := 0; i < len(t.cs); i++ { |
||||
v := t.cs[i] |
||||
fmt.Fprintf(t.out, "%s%s%s%s", |
||||
t.pRow, |
||||
strings.Repeat(string(t.pRow), v), |
||||
t.pRow, |
||||
t.pCenter) |
||||
} |
||||
if nl { |
||||
fmt.Fprint(t.out, t.newLine) |
||||
} |
||||
} |
||||
|
||||
// Print line based on row width with our without cell separator
|
||||
func (t Table) printLineOptionalCellSeparators(nl bool, displayCellSeparator []bool) { |
||||
fmt.Fprint(t.out, t.pCenter) |
||||
for i := 0; i < len(t.cs); i++ { |
||||
v := t.cs[i] |
||||
if i > len(displayCellSeparator) || displayCellSeparator[i] { |
||||
// Display the cell separator
|
||||
fmt.Fprintf(t.out, "%s%s%s%s", |
||||
t.pRow, |
||||
strings.Repeat(string(t.pRow), v), |
||||
t.pRow, |
||||
t.pCenter) |
||||
} else { |
||||
// Don't display the cell separator for this cell
|
||||
fmt.Fprintf(t.out, "%s%s", |
||||
strings.Repeat(" ", v+2), |
||||
t.pCenter) |
||||
} |
||||
} |
||||
if nl { |
||||
fmt.Fprint(t.out, t.newLine) |
||||
} |
||||
} |
||||
|
||||
// Return the PadRight function if align is left, PadLeft if align is right,
|
||||
// and Pad by default
|
||||
func pad(align int) func(string, string, int) string { |
||||
padFunc := Pad |
||||
switch align { |
||||
case ALIGN_LEFT: |
||||
padFunc = PadRight |
||||
case ALIGN_RIGHT: |
||||
padFunc = PadLeft |
||||
} |
||||
return padFunc |
||||
} |
||||
|
||||
// Print heading information
|
||||
func (t Table) printHeading() { |
||||
// Check if headers is available
|
||||
if len(t.headers) < 1 { |
||||
return |
||||
} |
||||
|
||||
// Check if border is set
|
||||
// Replace with space if not set
|
||||
fmt.Fprint(t.out, ConditionString(t.borders.Left, t.pColumn, SPACE)) |
||||
|
||||
// Identify last column
|
||||
end := len(t.cs) - 1 |
||||
|
||||
// Get pad function
|
||||
padFunc := pad(t.hAlign) |
||||
|
||||
// Print Heading column
|
||||
for i := 0; i <= end; i++ { |
||||
v := t.cs[i] |
||||
h := t.headers[i] |
||||
if t.autoFmt { |
||||
h = Title(h) |
||||
} |
||||
pad := ConditionString((i == end && !t.borders.Left), SPACE, t.pColumn) |
||||
fmt.Fprintf(t.out, " %s %s", |
||||
padFunc(h, SPACE, v), |
||||
pad) |
||||
} |
||||
// Next line
|
||||
fmt.Fprint(t.out, t.newLine) |
||||
if t.hdrLine { |
||||
t.printLine(true) |
||||
} |
||||
} |
||||
|
||||
// Print heading information
|
||||
func (t Table) printFooter() { |
||||
// Check if headers is available
|
||||
if len(t.footers) < 1 { |
||||
return |
||||
} |
||||
|
||||
// Only print line if border is not set
|
||||
if !t.borders.Bottom { |
||||
t.printLine(true) |
||||
} |
||||
// Check if border is set
|
||||
// Replace with space if not set
|
||||
fmt.Fprint(t.out, ConditionString(t.borders.Bottom, t.pColumn, SPACE)) |
||||
|
||||
// Identify last column
|
||||
end := len(t.cs) - 1 |
||||
|
||||
// Get pad function
|
||||
padFunc := pad(t.fAlign) |
||||
|
||||
// Print Heading column
|
||||
for i := 0; i <= end; i++ { |
||||
v := t.cs[i] |
||||
f := t.footers[i] |
||||
if t.autoFmt { |
||||
f = Title(f) |
||||
} |
||||
pad := ConditionString((i == end && !t.borders.Top), SPACE, t.pColumn) |
||||
|
||||
if len(t.footers[i]) == 0 { |
||||
pad = SPACE |
||||
} |
||||
fmt.Fprintf(t.out, " %s %s", |
||||
padFunc(f, SPACE, v), |
||||
pad) |
||||
} |
||||
// Next line
|
||||
fmt.Fprint(t.out, t.newLine) |
||||
//t.printLine(true)
|
||||
|
||||
hasPrinted := false |
||||
|
||||
for i := 0; i <= end; i++ { |
||||
v := t.cs[i] |
||||
pad := t.pRow |
||||
center := t.pCenter |
||||
length := len(t.footers[i]) |
||||
|
||||
if length > 0 { |
||||
hasPrinted = true |
||||
} |
||||
|
||||
// Set center to be space if length is 0
|
||||
if length == 0 && !t.borders.Right { |
||||
center = SPACE |
||||
} |
||||
|
||||
// Print first junction
|
||||
if i == 0 { |
||||
fmt.Fprint(t.out, center) |
||||
} |
||||
|
||||
// Pad With space of length is 0
|
||||
if length == 0 { |
||||
pad = SPACE |
||||
} |
||||
// Ignore left space of it has printed before
|
||||
if hasPrinted || t.borders.Left { |
||||
pad = t.pRow |
||||
center = t.pCenter |
||||
} |
||||
|
||||
// Change Center start position
|
||||
if center == SPACE { |
||||
if i < end && len(t.footers[i+1]) != 0 { |
||||
center = t.pCenter |
||||
} |
||||
} |
||||
|
||||
// Print the footer
|
||||
fmt.Fprintf(t.out, "%s%s%s%s", |
||||
pad, |
||||
strings.Repeat(string(pad), v), |
||||
pad, |
||||
center) |
||||
|
||||
} |
||||
|
||||
fmt.Fprint(t.out, t.newLine) |
||||
|
||||
} |
||||
|
||||
func (t Table) printRows() { |
||||
for i, lines := range t.lines { |
||||
t.printRow(lines, i) |
||||
} |
||||
|
||||
} |
||||
|
||||
// Print Row Information
|
||||
// Adjust column alignment based on type
|
||||
|
||||
func (t Table) printRow(columns [][]string, colKey int) { |
||||
// Get Maximum Height
|
||||
max := t.rs[colKey] |
||||
total := len(columns) |
||||
|
||||
// TODO Fix uneven col size
|
||||
// if total < t.colSize {
|
||||
// for n := t.colSize - total; n < t.colSize ; n++ {
|
||||
// columns = append(columns, []string{SPACE})
|
||||
// t.cs[n] = t.mW
|
||||
// }
|
||||
//}
|
||||
|
||||
// Pad Each Height
|
||||
// pads := []int{}
|
||||
pads := []int{} |
||||
|
||||
for i, line := range columns { |
||||
length := len(line) |
||||
pad := max - length |
||||
pads = append(pads, pad) |
||||
for n := 0; n < pad; n++ { |
||||
columns[i] = append(columns[i], " ") |
||||
} |
||||
} |
||||
//fmt.Println(max, "\n")
|
||||
for x := 0; x < max; x++ { |
||||
for y := 0; y < total; y++ { |
||||
|
||||
// Check if border is set
|
||||
fmt.Fprint(t.out, ConditionString((!t.borders.Left && y == 0), SPACE, t.pColumn)) |
||||
|
||||
fmt.Fprintf(t.out, SPACE) |
||||
str := columns[y][x] |
||||
|
||||
// This would print alignment
|
||||
// Default alignment would use multiple configuration
|
||||
switch t.align { |
||||
case ALIGN_CENTER: //
|
||||
fmt.Fprintf(t.out, "%s", Pad(str, SPACE, t.cs[y])) |
||||
case ALIGN_RIGHT: |
||||
fmt.Fprintf(t.out, "%s", PadLeft(str, SPACE, t.cs[y])) |
||||
case ALIGN_LEFT: |
||||
fmt.Fprintf(t.out, "%s", PadRight(str, SPACE, t.cs[y])) |
||||
default: |
||||
if decimal.MatchString(strings.TrimSpace(str)) || percent.MatchString(strings.TrimSpace(str)) { |
||||
fmt.Fprintf(t.out, "%s", PadLeft(str, SPACE, t.cs[y])) |
||||
} else { |
||||
fmt.Fprintf(t.out, "%s", PadRight(str, SPACE, t.cs[y])) |
||||
|
||||
// TODO Custom alignment per column
|
||||
//if max == 1 || pads[y] > 0 {
|
||||
// fmt.Fprintf(t.out, "%s", Pad(str, SPACE, t.cs[y]))
|
||||
//} else {
|
||||
// fmt.Fprintf(t.out, "%s", PadRight(str, SPACE, t.cs[y]))
|
||||
//}
|
||||
|
||||
} |
||||
} |
||||
fmt.Fprintf(t.out, SPACE) |
||||
} |
||||
// Check if border is set
|
||||
// Replace with space if not set
|
||||
fmt.Fprint(t.out, ConditionString(t.borders.Left, t.pColumn, SPACE)) |
||||
fmt.Fprint(t.out, t.newLine) |
||||
} |
||||
|
||||
if t.rowLine { |
||||
t.printLine(true) |
||||
} |
||||
} |
||||
|
||||
// Print the rows of the table and merge the cells that are identical
|
||||
func (t Table) printRowsMergeCells() { |
||||
var previousLine []string |
||||
var displayCellBorder []bool |
||||
var tmpWriter bytes.Buffer |
||||
for i, lines := range t.lines { |
||||
// We store the display of the current line in a tmp writer, as we need to know which border needs to be print above
|
||||
previousLine, displayCellBorder = t.printRowMergeCells(&tmpWriter, lines, i, previousLine) |
||||
if i > 0 { //We don't need to print borders above first line
|
||||
if t.rowLine { |
||||
t.printLineOptionalCellSeparators(true, displayCellBorder) |
||||
} |
||||
} |
||||
tmpWriter.WriteTo(t.out) |
||||
} |
||||
//Print the end of the table
|
||||
if t.rowLine { |
||||
t.printLine(true) |
||||
} |
||||
} |
||||
|
||||
// Print Row Information to a writer and merge identical cells.
|
||||
// Adjust column alignment based on type
|
||||
|
||||
func (t Table) printRowMergeCells(writer io.Writer, columns [][]string, colKey int, previousLine []string) ([]string, []bool) { |
||||
// Get Maximum Height
|
||||
max := t.rs[colKey] |
||||
total := len(columns) |
||||
|
||||
// Pad Each Height
|
||||
pads := []int{} |
||||
|
||||
for i, line := range columns { |
||||
length := len(line) |
||||
pad := max - length |
||||
pads = append(pads, pad) |
||||
for n := 0; n < pad; n++ { |
||||
columns[i] = append(columns[i], " ") |
||||
} |
||||
} |
||||
|
||||
var displayCellBorder []bool |
||||
for x := 0; x < max; x++ { |
||||
for y := 0; y < total; y++ { |
||||
|
||||
// Check if border is set
|
||||
fmt.Fprint(writer, ConditionString((!t.borders.Left && y == 0), SPACE, t.pColumn)) |
||||
|
||||
fmt.Fprintf(writer, SPACE) |
||||
|
||||
str := columns[y][x] |
||||
|
||||
if t.autoMergeCells { |
||||
//Store the full line to merge mutli-lines cells
|
||||
fullLine := strings.Join(columns[y], " ") |
||||
if len(previousLine) > y && fullLine == previousLine[y] && fullLine != "" { |
||||
// If this cell is identical to the one above but not empty, we don't display the border and keep the cell empty.
|
||||
displayCellBorder = append(displayCellBorder, false) |
||||
str = "" |
||||
} else { |
||||
// First line or different content, keep the content and print the cell border
|
||||
displayCellBorder = append(displayCellBorder, true) |
||||
} |
||||
} |
||||
|
||||
// This would print alignment
|
||||
// Default alignment would use multiple configuration
|
||||
switch t.align { |
||||
case ALIGN_CENTER: //
|
||||
fmt.Fprintf(writer, "%s", Pad(str, SPACE, t.cs[y])) |
||||
case ALIGN_RIGHT: |
||||
fmt.Fprintf(writer, "%s", PadLeft(str, SPACE, t.cs[y])) |
||||
case ALIGN_LEFT: |
||||
fmt.Fprintf(writer, "%s", PadRight(str, SPACE, t.cs[y])) |
||||
default: |
||||
if decimal.MatchString(strings.TrimSpace(str)) || percent.MatchString(strings.TrimSpace(str)) { |
||||
fmt.Fprintf(writer, "%s", PadLeft(str, SPACE, t.cs[y])) |
||||
} else { |
||||
fmt.Fprintf(writer, "%s", PadRight(str, SPACE, t.cs[y])) |
||||
} |
||||
} |
||||
fmt.Fprintf(writer, SPACE) |
||||
} |
||||
// Check if border is set
|
||||
// Replace with space if not set
|
||||
fmt.Fprint(writer, ConditionString(t.borders.Left, t.pColumn, SPACE)) |
||||
fmt.Fprint(writer, t.newLine) |
||||
} |
||||
|
||||
//The new previous line is the current one
|
||||
previousLine = make([]string, total) |
||||
for y := 0; y < total; y++ { |
||||
previousLine[y] = strings.Join(columns[y], " ") //Store the full line for multi-lines cells
|
||||
} |
||||
//Returns the newly added line and wether or not a border should be displayed above.
|
||||
return previousLine, displayCellBorder |
||||
} |
||||
|
||||
func (t *Table) parseDimension(str string, colKey, rowKey int) []string { |
||||
var ( |
||||
raw []string |
||||
max int |
||||
) |
||||
w := DisplayWidth(str) |
||||
// Calculate Width
|
||||
// Check if with is grater than maximum width
|
||||
if w > t.mW { |
||||
w = t.mW |
||||
} |
||||
|
||||
// Check if width exists
|
||||
v, ok := t.cs[colKey] |
||||
if !ok || v < w || v == 0 { |
||||
t.cs[colKey] = w |
||||
} |
||||
|
||||
if rowKey == -1 { |
||||
return raw |
||||
} |
||||
// Calculate Height
|
||||
if t.autoWrap { |
||||
raw, _ = WrapString(str, t.cs[colKey]) |
||||
} else { |
||||
raw = getLines(str) |
||||
} |
||||
|
||||
for _, line := range raw { |
||||
if w := DisplayWidth(line); w > max { |
||||
max = w |
||||
} |
||||
} |
||||
|
||||
// Make sure the with is the same length as maximum word
|
||||
// Important for cases where the width is smaller than maxu word
|
||||
if max > t.cs[colKey] { |
||||
t.cs[colKey] = max |
||||
} |
||||
|
||||
h := len(raw) |
||||
v, ok = t.rs[rowKey] |
||||
|
||||
if !ok || v < h || v == 0 { |
||||
t.rs[rowKey] = h |
||||
} |
||||
//fmt.Printf("Raw %+v %d\n", raw, len(raw))
|
||||
return raw |
||||
} |
|
|
@ -0,0 +1,72 @@ |
||||
// Copyright 2014 Oleku Konko All rights reserved.
|
||||
// Use of this source code is governed by a MIT
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
// This module is a Table Writer API for the Go Programming Language.
|
||||
// The protocols were written in pure Go and works on windows and unix systems
|
||||
|
||||
package tablewriter |
||||
|
||||
import ( |
||||
"math" |
||||
"regexp" |
||||
"strings" |
||||
|
||||
"github.com/mattn/go-runewidth" |
||||
) |
||||
|
||||
var ansi = regexp.MustCompile("\033\\[(?:[0-9]{1,3}(?:;[0-9]{1,3})*)?[m|K]") |
||||
|
||||
func DisplayWidth(str string) int { |
||||
return runewidth.StringWidth(ansi.ReplaceAllLiteralString(str, "")) |
||||
} |
||||
|
||||
// Simple Condition for string
|
||||
// Returns value based on condition
|
||||
func ConditionString(cond bool, valid, inValid string) string { |
||||
if cond { |
||||
return valid |
||||
} |
||||
return inValid |
||||
} |
||||
|
||||
// Format Table Header
|
||||
// Replace _ , . and spaces
|
||||
func Title(name string) string { |
||||
name = strings.Replace(name, "_", " ", -1) |
||||
name = strings.Replace(name, ".", " ", -1) |
||||
name = strings.TrimSpace(name) |
||||
return strings.ToUpper(name) |
||||
} |
||||
|
||||
// Pad String
|
||||
// Attempts to play string in the center
|
||||
func Pad(s, pad string, width int) string { |
||||
gap := width - DisplayWidth(s) |
||||
if gap > 0 { |
||||
gapLeft := int(math.Ceil(float64(gap / 2))) |
||||
gapRight := gap - gapLeft |
||||
return strings.Repeat(string(pad), gapLeft) + s + strings.Repeat(string(pad), gapRight) |
||||
} |
||||
return s |
||||
} |
||||
|
||||
// Pad String Right position
|
||||
// This would pace string at the left side fo the screen
|
||||
func PadRight(s, pad string, width int) string { |
||||
gap := width - DisplayWidth(s) |
||||
if gap > 0 { |
||||
return s + strings.Repeat(string(pad), gap) |
||||
} |
||||
return s |
||||
} |
||||
|
||||
// Pad String Left position
|
||||
// This would pace string at the right side fo the screen
|
||||
func PadLeft(s, pad string, width int) string { |
||||
gap := width - DisplayWidth(s) |
||||
if gap > 0 { |
||||
return strings.Repeat(string(pad), gap) + s |
||||
} |
||||
return s |
||||
} |
@ -0,0 +1,103 @@ |
||||
// Copyright 2014 Oleku Konko All rights reserved.
|
||||
// Use of this source code is governed by a MIT
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
// This module is a Table Writer API for the Go Programming Language.
|
||||
// The protocols were written in pure Go and works on windows and unix systems
|
||||
|
||||
package tablewriter |
||||
|
||||
import ( |
||||
"math" |
||||
"strings" |
||||
"unicode/utf8" |
||||
) |
||||
|
||||
var ( |
||||
nl = "\n" |
||||
sp = " " |
||||
) |
||||
|
||||
const defaultPenalty = 1e5 |
||||
|
||||
// Wrap wraps s into a paragraph of lines of length lim, with minimal
|
||||
// raggedness.
|
||||
func WrapString(s string, lim int) ([]string, int) { |
||||
words := strings.Split(strings.Replace(s, nl, sp, -1), sp) |
||||
var lines []string |
||||
max := 0 |
||||
for _, v := range words { |
||||
max = len(v) |
||||
if max > lim { |
||||
lim = max |
||||
} |
||||
} |
||||
for _, line := range WrapWords(words, 1, lim, defaultPenalty) { |
||||
lines = append(lines, strings.Join(line, sp)) |
||||
} |
||||
return lines, lim |
||||
} |
||||
|
||||
// WrapWords is the low-level line-breaking algorithm, useful if you need more
|
||||
// control over the details of the text wrapping process. For most uses,
|
||||
// WrapString will be sufficient and more convenient.
|
||||
//
|
||||
// WrapWords splits a list of words into lines with minimal "raggedness",
|
||||
// treating each rune as one unit, accounting for spc units between adjacent
|
||||
// words on each line, and attempting to limit lines to lim units. Raggedness
|
||||
// is the total error over all lines, where error is the square of the
|
||||
// difference of the length of the line and lim. Too-long lines (which only
|
||||
// happen when a single word is longer than lim units) have pen penalty units
|
||||
// added to the error.
|
||||
func WrapWords(words []string, spc, lim, pen int) [][]string { |
||||
n := len(words) |
||||
|
||||
length := make([][]int, n) |
||||
for i := 0; i < n; i++ { |
||||
length[i] = make([]int, n) |
||||
length[i][i] = utf8.RuneCountInString(words[i]) |
||||
for j := i + 1; j < n; j++ { |
||||
length[i][j] = length[i][j-1] + spc + utf8.RuneCountInString(words[j]) |
||||
} |
||||
} |
||||
nbrk := make([]int, n) |
||||
cost := make([]int, n) |
||||
for i := range cost { |
||||
cost[i] = math.MaxInt32 |
||||
} |
||||
for i := n - 1; i >= 0; i-- { |
||||
if length[i][n-1] <= lim { |
||||
cost[i] = 0 |
||||
nbrk[i] = n |
||||
} else { |
||||
for j := i + 1; j < n; j++ { |
||||
d := lim - length[i][j-1] |
||||
c := d*d + cost[j] |
||||
if length[i][j-1] > lim { |
||||
c += pen // too-long lines get a worse penalty
|
||||
} |
||||
if c < cost[i] { |
||||
cost[i] = c |
||||
nbrk[i] = j |
||||
} |
||||
} |
||||
} |
||||
} |
||||
var lines [][]string |
||||
i := 0 |
||||
for i < n { |
||||
lines = append(lines, words[i:nbrk[i]]) |
||||
i = nbrk[i] |
||||
} |
||||
return lines |
||||
} |
||||
|
||||
// getLines decomposes a multiline string into a slice of strings.
|
||||
func getLines(s string) []string { |
||||
var lines []string |
||||
|
||||
for _, line := range strings.Split(s, nl) { |
||||
lines = append(lines, line) |
||||
} |
||||
return lines |
||||
} |
@ -0,0 +1,8 @@ |
||||
// Copyright 2012 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
// This code was translated into a form compatible with 6a from the public
|
||||
// domain sources in SUPERCOP: http://bench.cr.yp.to/supercop.html
|
||||
|
||||
#define REDMASK51 0x0007FFFFFFFFFFFF |
@ -0,0 +1,20 @@ |
||||
// Copyright 2012 The Go Authors. All rights reserved. |
||||
// Use of this source code is governed by a BSD-style |
||||
// license that can be found in the LICENSE file. |
||||
|
||||
// This code was translated into a form compatible with 6a from the public |
||||
// domain sources in SUPERCOP: http://bench.cr.yp.to/supercop.html |
||||
|
||||
// +build amd64,!gccgo,!appengine |
||||
|
||||
// These constants cannot be encoded in non-MOVQ immediates. |
||||
// We access them directly from memory instead. |
||||
|
||||
DATA ·_121666_213(SB)/8, $996687872 |
||||
GLOBL ·_121666_213(SB), 8, $8 |
||||
|
||||
DATA ·_2P0(SB)/8, $0xFFFFFFFFFFFDA |
||||
GLOBL ·_2P0(SB), 8, $8 |
||||
|
||||
DATA ·_2P1234(SB)/8, $0xFFFFFFFFFFFFE |
||||
GLOBL ·_2P1234(SB), 8, $8 |
@ -0,0 +1,88 @@ |
||||
// Copyright 2012 The Go Authors. All rights reserved. |
||||
// Use of this source code is governed by a BSD-style |
||||
// license that can be found in the LICENSE file. |
||||
|
||||
// This code was translated into a form compatible with 6a from the public |
||||
// domain sources in SUPERCOP: http://bench.cr.yp.to/supercop.html |
||||
|
||||
// +build amd64,!gccgo,!appengine |
||||
|
||||
// func cswap(inout *[5]uint64, v uint64) |
||||
TEXT ·cswap(SB),7,$0 |
||||
MOVQ inout+0(FP),DI |
||||
MOVQ v+8(FP),SI |
||||
|
||||
CMPQ SI,$1 |
||||
MOVQ 0(DI),SI |
||||
MOVQ 80(DI),DX |
||||
MOVQ 8(DI),CX |
||||
MOVQ 88(DI),R8 |
||||
MOVQ SI,R9 |
||||
CMOVQEQ DX,SI |
||||
CMOVQEQ R9,DX |
||||
MOVQ CX,R9 |
||||
CMOVQEQ R8,CX |
||||
CMOVQEQ R9,R8 |
||||
MOVQ SI,0(DI) |
||||
MOVQ DX,80(DI) |
||||
MOVQ CX,8(DI) |
||||
MOVQ R8,88(DI) |
||||
MOVQ 16(DI),SI |
||||
MOVQ 96(DI),DX |
||||
MOVQ 24(DI),CX |
||||
MOVQ 104(DI),R8 |
||||
MOVQ SI,R9 |
||||
CMOVQEQ DX,SI |
||||
CMOVQEQ R9,DX |
||||
MOVQ CX,R9 |
||||
CMOVQEQ R8,CX |
||||
CMOVQEQ R9,R8 |
||||
MOVQ SI,16(DI) |
||||
MOVQ DX,96(DI) |
||||
MOVQ CX,24(DI) |
||||
MOVQ R8,104(DI) |
||||
MOVQ 32(DI),SI |
||||
MOVQ 112(DI),DX |
||||
MOVQ 40(DI),CX |
||||
MOVQ 120(DI),R8 |
||||
MOVQ SI,R9 |
||||
CMOVQEQ DX,SI |
||||
CMOVQEQ R9,DX |
||||
MOVQ CX,R9 |
||||
CMOVQEQ R8,CX |
||||
CMOVQEQ R9,R8 |
||||
MOVQ SI,32(DI) |
||||
MOVQ DX,112(DI) |
||||
MOVQ CX,40(DI) |
||||
MOVQ R8,120(DI) |
||||
MOVQ 48(DI),SI |
||||
MOVQ 128(DI),DX |
||||
MOVQ 56(DI),CX |
||||
MOVQ 136(DI),R8 |
||||
MOVQ SI,R9 |
||||
CMOVQEQ DX,SI |
||||
CMOVQEQ R9,DX |
||||
MOVQ CX,R9 |
||||
CMOVQEQ R8,CX |
||||
CMOVQEQ R9,R8 |
||||
MOVQ SI,48(DI) |
||||
MOVQ DX,128(DI) |
||||
MOVQ CX,56(DI) |
||||
MOVQ R8,136(DI) |
||||
MOVQ 64(DI),SI |
||||
MOVQ 144(DI),DX |
||||
MOVQ 72(DI),CX |
||||
MOVQ 152(DI),R8 |
||||
MOVQ SI,R9 |
||||
CMOVQEQ DX,SI |
||||
CMOVQEQ R9,DX |
||||
MOVQ CX,R9 |
||||
CMOVQEQ R8,CX |
||||
CMOVQEQ R9,R8 |
||||
MOVQ SI,64(DI) |
||||
MOVQ DX,144(DI) |
||||
MOVQ CX,72(DI) |
||||
MOVQ R8,152(DI) |
||||
MOVQ DI,AX |
||||
MOVQ SI,DX |
||||
RET |
@ -0,0 +1,841 @@ |
||||
// Copyright 2013 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
// We have a implementation in amd64 assembly so this code is only run on
|
||||
// non-amd64 platforms. The amd64 assembly does not support gccgo.
|
||||
// +build !amd64 gccgo appengine
|
||||
|
||||
package curve25519 |
||||
|
||||
// This code is a port of the public domain, "ref10" implementation of
|
||||
// curve25519 from SUPERCOP 20130419 by D. J. Bernstein.
|
||||
|
||||
// fieldElement represents an element of the field GF(2^255 - 19). An element
|
||||
// t, entries t[0]...t[9], represents the integer t[0]+2^26 t[1]+2^51 t[2]+2^77
|
||||
// t[3]+2^102 t[4]+...+2^230 t[9]. Bounds on each t[i] vary depending on
|
||||
// context.
|
||||
type fieldElement [10]int32 |
||||
|
||||
func feZero(fe *fieldElement) { |
||||
for i := range fe { |
||||
fe[i] = 0 |
||||
} |
||||
} |
||||
|
||||
func feOne(fe *fieldElement) { |
||||
feZero(fe) |
||||
fe[0] = 1 |
||||
} |
||||
|
||||
func feAdd(dst, a, b *fieldElement) { |
||||
for i := range dst { |
||||
dst[i] = a[i] + b[i] |
||||
} |
||||
} |
||||
|
||||
func feSub(dst, a, b *fieldElement) { |
||||
for i := range dst { |
||||
dst[i] = a[i] - b[i] |
||||
} |
||||
} |
||||
|
||||
func feCopy(dst, src *fieldElement) { |
||||
for i := range dst { |
||||
dst[i] = src[i] |
||||
} |
||||
} |
||||
|
||||
// feCSwap replaces (f,g) with (g,f) if b == 1; replaces (f,g) with (f,g) if b == 0.
|
||||
//
|
||||
// Preconditions: b in {0,1}.
|
||||
func feCSwap(f, g *fieldElement, b int32) { |
||||
var x fieldElement |
||||
b = -b |
||||
for i := range x { |
||||
x[i] = b & (f[i] ^ g[i]) |
||||
} |
||||
|
||||
for i := range f { |
||||
f[i] ^= x[i] |
||||
} |
||||
for i := range g { |
||||
g[i] ^= x[i] |
||||
} |
||||
} |
||||
|
||||
// load3 reads a 24-bit, little-endian value from in.
|
||||
func load3(in []byte) int64 { |
||||
var r int64 |
||||
r = int64(in[0]) |
||||
r |= int64(in[1]) << 8 |
||||
r |= int64(in[2]) << 16 |
||||
return r |
||||
} |
||||
|
||||
// load4 reads a 32-bit, little-endian value from in.
|
||||
func load4(in []byte) int64 { |
||||
var r int64 |
||||
r = int64(in[0]) |
||||
r |= int64(in[1]) << 8 |
||||
r |= int64(in[2]) << 16 |
||||
r |= int64(in[3]) << 24 |
||||
return r |
||||
} |
||||
|
||||
func feFromBytes(dst *fieldElement, src *[32]byte) { |
||||
h0 := load4(src[:]) |
||||
h1 := load3(src[4:]) << 6 |
||||
h2 := load3(src[7:]) << 5 |
||||
h3 := load3(src[10:]) << 3 |
||||
h4 := load3(src[13:]) << 2 |
||||
h5 := load4(src[16:]) |
||||
h6 := load3(src[20:]) << 7 |
||||
h7 := load3(src[23:]) << 5 |
||||
h8 := load3(src[26:]) << 4 |
||||
h9 := load3(src[29:]) << 2 |
||||
|
||||
var carry [10]int64 |
||||
carry[9] = (h9 + 1<<24) >> 25 |
||||
h0 += carry[9] * 19 |
||||
h9 -= carry[9] << 25 |
||||
carry[1] = (h1 + 1<<24) >> 25 |
||||
h2 += carry[1] |
||||
h1 -= carry[1] << 25 |
||||
carry[3] = (h3 + 1<<24) >> 25 |
||||
h4 += carry[3] |
||||
h3 -= carry[3] << 25 |
||||
carry[5] = (h5 + 1<<24) >> 25 |
||||
h6 += carry[5] |
||||
h5 -= carry[5] << 25 |
||||
carry[7] = (h7 + 1<<24) >> 25 |
||||
h8 += carry[7] |
||||
h7 -= carry[7] << 25 |
||||
|
||||
carry[0] = (h0 + 1<<25) >> 26 |
||||
h1 += carry[0] |
||||
h0 -= carry[0] << 26 |
||||
carry[2] = (h2 + 1<<25) >> 26 |
||||
h3 += carry[2] |
||||
h2 -= carry[2] << 26 |
||||
carry[4] = (h4 + 1<<25) >> 26 |
||||
h5 += carry[4] |
||||
h4 -= carry[4] << 26 |
||||
carry[6] = (h6 + 1<<25) >> 26 |
||||
h7 += carry[6] |
||||
h6 -= carry[6] << 26 |
||||
carry[8] = (h8 + 1<<25) >> 26 |
||||
h9 += carry[8] |
||||
h8 -= carry[8] << 26 |
||||
|
||||
dst[0] = int32(h0) |
||||
dst[1] = int32(h1) |
||||
dst[2] = int32(h2) |
||||
dst[3] = int32(h3) |
||||
dst[4] = int32(h4) |
||||
dst[5] = int32(h5) |
||||
dst[6] = int32(h6) |
||||
dst[7] = int32(h7) |
||||
dst[8] = int32(h8) |
||||
dst[9] = int32(h9) |
||||
} |
||||
|
||||
// feToBytes marshals h to s.
|
||||
// Preconditions:
|
||||
// |h| bounded by 1.1*2^25,1.1*2^24,1.1*2^25,1.1*2^24,etc.
|
||||
//
|
||||
// Write p=2^255-19; q=floor(h/p).
|
||||
// Basic claim: q = floor(2^(-255)(h + 19 2^(-25)h9 + 2^(-1))).
|
||||
//
|
||||
// Proof:
|
||||
// Have |h|<=p so |q|<=1 so |19^2 2^(-255) q|<1/4.
|
||||
// Also have |h-2^230 h9|<2^230 so |19 2^(-255)(h-2^230 h9)|<1/4.
|
||||
//
|
||||
// Write y=2^(-1)-19^2 2^(-255)q-19 2^(-255)(h-2^230 h9).
|
||||
// Then 0<y<1.
|
||||
//
|
||||
// Write r=h-pq.
|
||||
// Have 0<=r<=p-1=2^255-20.
|
||||
// Thus 0<=r+19(2^-255)r<r+19(2^-255)2^255<=2^255-1.
|
||||
//
|
||||
// Write x=r+19(2^-255)r+y.
|
||||
// Then 0<x<2^255 so floor(2^(-255)x) = 0 so floor(q+2^(-255)x) = q.
|
||||
//
|
||||
// Have q+2^(-255)x = 2^(-255)(h + 19 2^(-25) h9 + 2^(-1))
|
||||
// so floor(2^(-255)(h + 19 2^(-25) h9 + 2^(-1))) = q.
|
||||
func feToBytes(s *[32]byte, h *fieldElement) { |
||||
var carry [10]int32 |
||||
|
||||
q := (19*h[9] + (1 << 24)) >> 25 |
||||
q = (h[0] + q) >> 26 |
||||
q = (h[1] + q) >> 25 |
||||
q = (h[2] + q) >> 26 |
||||
q = (h[3] + q) >> 25 |
||||
q = (h[4] + q) >> 26 |
||||
q = (h[5] + q) >> 25 |
||||
q = (h[6] + q) >> 26 |
||||
q = (h[7] + q) >> 25 |
||||
q = (h[8] + q) >> 26 |
||||
q = (h[9] + q) >> 25 |
||||
|
||||
// Goal: Output h-(2^255-19)q, which is between 0 and 2^255-20.
|
||||
h[0] += 19 * q |
||||
// Goal: Output h-2^255 q, which is between 0 and 2^255-20.
|
||||
|
||||
carry[0] = h[0] >> 26 |
||||
h[1] += carry[0] |
||||
h[0] -= carry[0] << 26 |
||||
carry[1] = h[1] >> 25 |
||||
h[2] += carry[1] |
||||
h[1] -= carry[1] << 25 |
||||
carry[2] = h[2] >> 26 |
||||
h[3] += carry[2] |
||||
h[2] -= carry[2] << 26 |
||||
carry[3] = h[3] >> 25 |
||||
h[4] += carry[3] |
||||
h[3] -= carry[3] << 25 |
||||
carry[4] = h[4] >> 26 |
||||
h[5] += carry[4] |
||||
h[4] -= carry[4] << 26 |
||||
carry[5] = h[5] >> 25 |
||||
h[6] += carry[5] |
||||
h[5] -= carry[5] << 25 |
||||
carry[6] = h[6] >> 26 |
||||
h[7] += carry[6] |
||||
h[6] -= carry[6] << 26 |
||||
carry[7] = h[7] >> 25 |
||||
h[8] += carry[7] |
||||
h[7] -= carry[7] << 25 |
||||
carry[8] = h[8] >> 26 |
||||
h[9] += carry[8] |
||||
h[8] -= carry[8] << 26 |
||||
carry[9] = h[9] >> 25 |
||||
h[9] -= carry[9] << 25 |
||||
// h10 = carry9
|
||||
|
||||
// Goal: Output h[0]+...+2^255 h10-2^255 q, which is between 0 and 2^255-20.
|
||||
// Have h[0]+...+2^230 h[9] between 0 and 2^255-1;
|
||||
// evidently 2^255 h10-2^255 q = 0.
|
||||
// Goal: Output h[0]+...+2^230 h[9].
|
||||
|
||||
s[0] = byte(h[0] >> 0) |
||||
s[1] = byte(h[0] >> 8) |
||||
s[2] = byte(h[0] >> 16) |
||||
s[3] = byte((h[0] >> 24) | (h[1] << 2)) |
||||
s[4] = byte(h[1] >> 6) |
||||
s[5] = byte(h[1] >> 14) |
||||
s[6] = byte((h[1] >> 22) | (h[2] << 3)) |
||||
s[7] = byte(h[2] >> 5) |
||||
s[8] = byte(h[2] >> 13) |
||||
s[9] = byte((h[2] >> 21) | (h[3] << 5)) |
||||
s[10] = byte(h[3] >> 3) |
||||
s[11] = byte(h[3] >> 11) |
||||
s[12] = byte((h[3] >> 19) | (h[4] << 6)) |
||||
s[13] = byte(h[4] >> 2) |
||||
s[14] = byte(h[4] >> 10) |
||||
s[15] = byte(h[4] >> 18) |
||||
s[16] = byte(h[5] >> 0) |
||||
s[17] = byte(h[5] >> 8) |
||||
s[18] = byte(h[5] >> 16) |
||||
s[19] = byte((h[5] >> 24) | (h[6] << 1)) |
||||
s[20] = byte(h[6] >> 7) |
||||
s[21] = byte(h[6] >> 15) |
||||
s[22] = byte((h[6] >> 23) | (h[7] << 3)) |
||||
s[23] = byte(h[7] >> 5) |
||||
s[24] = byte(h[7] >> 13) |
||||
s[25] = byte((h[7] >> 21) | (h[8] << 4)) |
||||
s[26] = byte(h[8] >> 4) |
||||
s[27] = byte(h[8] >> 12) |
||||
s[28] = byte((h[8] >> 20) | (h[9] << 6)) |
||||
s[29] = byte(h[9] >> 2) |
||||
s[30] = byte(h[9] >> 10) |
||||
s[31] = byte(h[9] >> 18) |
||||
} |
||||
|
||||
// feMul calculates h = f * g
|
||||
// Can overlap h with f or g.
|
||||
//
|
||||
// Preconditions:
|
||||
// |f| bounded by 1.1*2^26,1.1*2^25,1.1*2^26,1.1*2^25,etc.
|
||||
// |g| bounded by 1.1*2^26,1.1*2^25,1.1*2^26,1.1*2^25,etc.
|
||||
//
|
||||
// Postconditions:
|
||||
// |h| bounded by 1.1*2^25,1.1*2^24,1.1*2^25,1.1*2^24,etc.
|
||||
//
|
||||
// Notes on implementation strategy:
|
||||
//
|
||||
// Using schoolbook multiplication.
|
||||
// Karatsuba would save a little in some cost models.
|
||||
//
|
||||
// Most multiplications by 2 and 19 are 32-bit precomputations;
|
||||
// cheaper than 64-bit postcomputations.
|
||||
//
|
||||
// There is one remaining multiplication by 19 in the carry chain;
|
||||
// one *19 precomputation can be merged into this,
|
||||
// but the resulting data flow is considerably less clean.
|
||||
//
|
||||
// There are 12 carries below.
|
||||
// 10 of them are 2-way parallelizable and vectorizable.
|
||||
// Can get away with 11 carries, but then data flow is much deeper.
|
||||
//
|
||||
// With tighter constraints on inputs can squeeze carries into int32.
|
||||
func feMul(h, f, g *fieldElement) { |
||||
f0 := f[0] |
||||
f1 := f[1] |
||||
f2 := f[2] |
||||
f3 := f[3] |
||||
f4 := f[4] |
||||
f5 := f[5] |
||||
f6 := f[6] |
||||
f7 := f[7] |
||||
f8 := f[8] |
||||
f9 := f[9] |
||||
g0 := g[0] |
||||
g1 := g[1] |
||||
g2 := g[2] |
||||
g3 := g[3] |
||||
g4 := g[4] |
||||
g5 := g[5] |
||||
g6 := g[6] |
||||
g7 := g[7] |
||||
g8 := g[8] |
||||
g9 := g[9] |
||||
g1_19 := 19 * g1 // 1.4*2^29
|
||||
g2_19 := 19 * g2 // 1.4*2^30; still ok
|
||||
g3_19 := 19 * g3 |
||||
g4_19 := 19 * g4 |
||||
g5_19 := 19 * g5 |
||||
g6_19 := 19 * g6 |
||||
g7_19 := 19 * g7 |
||||
g8_19 := 19 * g8 |
||||
g9_19 := 19 * g9 |
||||
f1_2 := 2 * f1 |
||||
f3_2 := 2 * f3 |
||||
f5_2 := 2 * f5 |
||||
f7_2 := 2 * f7 |
||||
f9_2 := 2 * f9 |
||||
f0g0 := int64(f0) * int64(g0) |
||||
f0g1 := int64(f0) * int64(g1) |
||||
f0g2 := int64(f0) * int64(g2) |
||||
f0g3 := int64(f0) * int64(g3) |
||||
f0g4 := int64(f0) * int64(g4) |
||||
f0g5 := int64(f0) * int64(g5) |
||||
f0g6 := int64(f0) * int64(g6) |
||||
f0g7 := int64(f0) * int64(g7) |
||||
f0g8 := int64(f0) * int64(g8) |
||||
f0g9 := int64(f0) * int64(g9) |
||||
f1g0 := int64(f1) * int64(g0) |
||||
f1g1_2 := int64(f1_2) * int64(g1) |
||||
f1g2 := int64(f1) * int64(g2) |
||||
f1g3_2 := int64(f1_2) * int64(g3) |
||||
f1g4 := int64(f1) * int64(g4) |
||||
f1g5_2 := int64(f1_2) * int64(g5) |
||||
f1g6 := int64(f1) * int64(g6) |
||||
f1g7_2 := int64(f1_2) * int64(g7) |
||||
f1g8 := int64(f1) * int64(g8) |
||||
f1g9_38 := int64(f1_2) * int64(g9_19) |
||||
f2g0 := int64(f2) * int64(g0) |
||||
f2g1 := int64(f2) * int64(g1) |
||||
f2g2 := int64(f2) * int64(g2) |
||||
f2g3 := int64(f2) * int64(g3) |
||||
f2g4 := int64(f2) * int64(g4) |
||||
f2g5 := int64(f2) * int64(g5) |
||||
f2g6 := int64(f2) * int64(g6) |
||||
f2g7 := int64(f2) * int64(g7) |
||||
f2g8_19 := int64(f2) * int64(g8_19) |
||||
f2g9_19 := int64(f2) * int64(g9_19) |
||||
f3g0 := int64(f3) * int64(g0) |
||||
f3g1_2 := int64(f3_2) * int64(g1) |
||||
f3g2 := int64(f3) * int64(g2) |
||||
f3g3_2 := int64(f3_2) * int64(g3) |
||||
f3g4 := int64(f3) * int64(g4) |
||||
f3g5_2 := int64(f3_2) * int64(g5) |
||||
f3g6 := int64(f3) * int64(g6) |
||||
f3g7_38 := int64(f3_2) * int64(g7_19) |
||||
f3g8_19 := int64(f3) * int64(g8_19) |
||||
f3g9_38 := int64(f3_2) * int64(g9_19) |
||||
f4g0 := int64(f4) * int64(g0) |
||||
f4g1 := int64(f4) * int64(g1) |
||||
f4g2 := int64(f4) * int64(g2) |
||||
f4g3 := int64(f4) * int64(g3) |
||||
f4g4 := int64(f4) * int64(g4) |
||||
f4g5 := int64(f4) * int64(g5) |
||||
f4g6_19 := int64(f4) * int64(g6_19) |
||||
f4g7_19 := int64(f4) * int64(g7_19) |
||||
f4g8_19 := int64(f4) * int64(g8_19) |
||||
f4g9_19 := int64(f4) * int64(g9_19) |
||||
f5g0 := int64(f5) * int64(g0) |
||||
f5g1_2 := int64(f5_2) * int64(g1) |
||||
f5g2 := int64(f5) * int64(g2) |
||||
f5g3_2 := int64(f5_2) * int64(g3) |
||||
f5g4 := int64(f5) * int64(g4) |
||||
f5g5_38 := int64(f5_2) * int64(g5_19) |
||||
f5g6_19 := int64(f5) * int64(g6_19) |
||||
f5g7_38 := int64(f5_2) * int64(g7_19) |
||||
f5g8_19 := int64(f5) * int64(g8_19) |
||||
f5g9_38 := int64(f5_2) * int64(g9_19) |
||||
f6g0 := int64(f6) * int64(g0) |
||||
f6g1 := int64(f6) * int64(g1) |
||||
f6g2 := int64(f6) * int64(g2) |
||||
f6g3 := int64(f6) * int64(g3) |
||||
f6g4_19 := int64(f6) * int64(g4_19) |
||||
f6g5_19 := int64(f6) * int64(g5_19) |
||||
f6g6_19 := int64(f6) * int64(g6_19) |
||||
f6g7_19 := int64(f6) * int64(g7_19) |
||||
f6g8_19 := int64(f6) * int64(g8_19) |
||||
f6g9_19 := int64(f6) * int64(g9_19) |
||||
f7g0 := int64(f7) * int64(g0) |
||||
f7g1_2 := int64(f7_2) * int64(g1) |
||||
f7g2 := int64(f7) * int64(g2) |
||||
f7g3_38 := int64(f7_2) * int64(g3_19) |
||||
f7g4_19 := int64(f7) * int64(g4_19) |
||||
f7g5_38 := int64(f7_2) * int64(g5_19) |
||||
f7g6_19 := int64(f7) * int64(g6_19) |
||||
f7g7_38 := int64(f7_2) * int64(g7_19) |
||||
f7g8_19 := int64(f7) * int64(g8_19) |
||||
f7g9_38 := int64(f7_2) * int64(g9_19) |
||||
f8g0 := int64(f8) * int64(g0) |
||||
f8g1 := int64(f8) * int64(g1) |
||||
f8g2_19 := int64(f8) * int64(g2_19) |
||||
f8g3_19 := int64(f8) * int64(g3_19) |
||||
f8g4_19 := int64(f8) * int64(g4_19) |
||||
f8g5_19 := int64(f8) * int64(g5_19) |
||||
f8g6_19 := int64(f8) * int64(g6_19) |
||||
f8g7_19 := int64(f8) * int64(g7_19) |
||||
f8g8_19 := int64(f8) * int64(g8_19) |
||||
f8g9_19 := int64(f8) * int64(g9_19) |
||||
f9g0 := int64(f9) * int64(g0) |
||||
f9g1_38 := int64(f9_2) * int64(g1_19) |
||||
f9g2_19 := int64(f9) * int64(g2_19) |
||||
f9g3_38 := int64(f9_2) * int64(g3_19) |
||||
f9g4_19 := int64(f9) * int64(g4_19) |
||||
f9g5_38 := int64(f9_2) * int64(g5_19) |
||||
f9g6_19 := int64(f9) * int64(g6_19) |
||||
f9g7_38 := int64(f9_2) * int64(g7_19) |
||||
f9g8_19 := int64(f9) * int64(g8_19) |
||||
f9g9_38 := int64(f9_2) * int64(g9_19) |
||||
h0 := f0g0 + f1g9_38 + f2g8_19 + f3g7_38 + f4g6_19 + f5g5_38 + f6g4_19 + f7g3_38 + f8g2_19 + f9g1_38 |
||||
h1 := f0g1 + f1g0 + f2g9_19 + f3g8_19 + f4g7_19 + f5g6_19 + f6g5_19 + f7g4_19 + f8g3_19 + f9g2_19 |
||||
h2 := f0g2 + f1g1_2 + f2g0 + f3g9_38 + f4g8_19 + f5g7_38 + f6g6_19 + f7g5_38 + f8g4_19 + f9g3_38 |
||||
h3 := f0g3 + f1g2 + f2g1 + f3g0 + f4g9_19 + f5g8_19 + f6g7_19 + f7g6_19 + f8g5_19 + f9g4_19 |
||||
h4 := f0g4 + f1g3_2 + f2g2 + f3g1_2 + f4g0 + f5g9_38 + f6g8_19 + f7g7_38 + f8g6_19 + f9g5_38 |
||||
h5 := f0g5 + f1g4 + f2g3 + f3g2 + f4g1 + f5g0 + f6g9_19 + f7g8_19 + f8g7_19 + f9g6_19 |
||||
h6 := f0g6 + f1g5_2 + f2g4 + f3g3_2 + f4g2 + f5g1_2 + f6g0 + f7g9_38 + f8g8_19 + f9g7_38 |
||||
h7 := f0g7 + f1g6 + f2g5 + f3g4 + f4g3 + f5g2 + f6g1 + f7g0 + f8g9_19 + f9g8_19 |
||||
h8 := f0g8 + f1g7_2 + f2g6 + f3g5_2 + f4g4 + f5g3_2 + f6g2 + f7g1_2 + f8g0 + f9g9_38 |
||||
h9 := f0g9 + f1g8 + f2g7 + f3g6 + f4g5 + f5g4 + f6g3 + f7g2 + f8g1 + f9g0 |
||||
var carry [10]int64 |
||||
|
||||
// |h0| <= (1.1*1.1*2^52*(1+19+19+19+19)+1.1*1.1*2^50*(38+38+38+38+38))
|
||||
// i.e. |h0| <= 1.2*2^59; narrower ranges for h2, h4, h6, h8
|
||||
// |h1| <= (1.1*1.1*2^51*(1+1+19+19+19+19+19+19+19+19))
|
||||
// i.e. |h1| <= 1.5*2^58; narrower ranges for h3, h5, h7, h9
|
||||
|
||||
carry[0] = (h0 + (1 << 25)) >> 26 |
||||
h1 += carry[0] |
||||
h0 -= carry[0] << 26 |
||||
carry[4] = (h4 + (1 << 25)) >> 26 |
||||
h5 += carry[4] |
||||
h4 -= carry[4] << 26 |
||||
// |h0| <= 2^25
|
||||
// |h4| <= 2^25
|
||||
// |h1| <= 1.51*2^58
|
||||
// |h5| <= 1.51*2^58
|
||||
|
||||
carry[1] = (h1 + (1 << 24)) >> 25 |
||||
h2 += carry[1] |
||||
h1 -= carry[1] << 25 |
||||
carry[5] = (h5 + (1 << 24)) >> 25 |
||||
h6 += carry[5] |
||||
h5 -= carry[5] << 25 |
||||
// |h1| <= 2^24; from now on fits into int32
|
||||
// |h5| <= 2^24; from now on fits into int32
|
||||
// |h2| <= 1.21*2^59
|
||||
// |h6| <= 1.21*2^59
|
||||
|
||||
carry[2] = (h2 + (1 << 25)) >> 26 |
||||
h3 += carry[2] |
||||
h2 -= carry[2] << 26 |
||||
carry[6] = (h6 + (1 << 25)) >> 26 |
||||
h7 += carry[6] |
||||
h6 -= carry[6] << 26 |
||||
// |h2| <= 2^25; from now on fits into int32 unchanged
|
||||
// |h6| <= 2^25; from now on fits into int32 unchanged
|
||||
// |h3| <= 1.51*2^58
|
||||
// |h7| <= 1.51*2^58
|
||||
|
||||
carry[3] = (h3 + (1 << 24)) >> 25 |
||||
h4 += carry[3] |
||||
h3 -= carry[3] << 25 |
||||
carry[7] = (h7 + (1 << 24)) >> 25 |
||||
h8 += carry[7] |
||||
h7 -= carry[7] << 25 |
||||
// |h3| <= 2^24; from now on fits into int32 unchanged
|
||||
// |h7| <= 2^24; from now on fits into int32 unchanged
|
||||
// |h4| <= 1.52*2^33
|
||||
// |h8| <= 1.52*2^33
|
||||
|
||||
carry[4] = (h4 + (1 << 25)) >> 26 |
||||
h5 += carry[4] |
||||
h4 -= carry[4] << 26 |
||||
carry[8] = (h8 + (1 << 25)) >> 26 |
||||
h9 += carry[8] |
||||
h8 -= carry[8] << 26 |
||||
// |h4| <= 2^25; from now on fits into int32 unchanged
|
||||
// |h8| <= 2^25; from now on fits into int32 unchanged
|
||||
// |h5| <= 1.01*2^24
|
||||
// |h9| <= 1.51*2^58
|
||||
|
||||
carry[9] = (h9 + (1 << 24)) >> 25 |
||||
h0 += carry[9] * 19 |
||||
h9 -= carry[9] << 25 |
||||
// |h9| <= 2^24; from now on fits into int32 unchanged
|
||||
// |h0| <= 1.8*2^37
|
||||
|
||||
carry[0] = (h0 + (1 << 25)) >> 26 |
||||
h1 += carry[0] |
||||
h0 -= carry[0] << 26 |
||||
// |h0| <= 2^25; from now on fits into int32 unchanged
|
||||
// |h1| <= 1.01*2^24
|
||||
|
||||
h[0] = int32(h0) |
||||
h[1] = int32(h1) |
||||
h[2] = int32(h2) |
||||
h[3] = int32(h3) |
||||
h[4] = int32(h4) |
||||
h[5] = int32(h5) |
||||
h[6] = int32(h6) |
||||
h[7] = int32(h7) |
||||
h[8] = int32(h8) |
||||
h[9] = int32(h9) |
||||
} |
||||
|
||||
// feSquare calculates h = f*f. Can overlap h with f.
|
||||
//
|
||||
// Preconditions:
|
||||
// |f| bounded by 1.1*2^26,1.1*2^25,1.1*2^26,1.1*2^25,etc.
|
||||
//
|
||||
// Postconditions:
|
||||
// |h| bounded by 1.1*2^25,1.1*2^24,1.1*2^25,1.1*2^24,etc.
|
||||
func feSquare(h, f *fieldElement) { |
||||
f0 := f[0] |
||||
f1 := f[1] |
||||
f2 := f[2] |
||||
f3 := f[3] |
||||
f4 := f[4] |
||||
f5 := f[5] |
||||
f6 := f[6] |
||||
f7 := f[7] |
||||
f8 := f[8] |
||||
f9 := f[9] |
||||
f0_2 := 2 * f0 |
||||
f1_2 := 2 * f1 |
||||
f2_2 := 2 * f2 |
||||
f3_2 := 2 * f3 |
||||
f4_2 := 2 * f4 |
||||
f5_2 := 2 * f5 |
||||
f6_2 := 2 * f6 |
||||
f7_2 := 2 * f7 |
||||
f5_38 := 38 * f5 // 1.31*2^30
|
||||
f6_19 := 19 * f6 // 1.31*2^30
|
||||
f7_38 := 38 * f7 // 1.31*2^30
|
||||
f8_19 := 19 * f8 // 1.31*2^30
|
||||
f9_38 := 38 * f9 // 1.31*2^30
|
||||
f0f0 := int64(f0) * int64(f0) |
||||
f0f1_2 := int64(f0_2) * int64(f1) |
||||
f0f2_2 := int64(f0_2) * int64(f2) |
||||
f0f3_2 := int64(f0_2) * int64(f3) |
||||
f0f4_2 := int64(f0_2) * int64(f4) |
||||
f0f5_2 := int64(f0_2) * int64(f5) |
||||
f0f6_2 := int64(f0_2) * int64(f6) |
||||
f0f7_2 := int64(f0_2) * int64(f7) |
||||
f0f8_2 := int64(f0_2) * int64(f8) |
||||
f0f9_2 := int64(f0_2) * int64(f9) |
||||
f1f1_2 := int64(f1_2) * int64(f1) |
||||
f1f2_2 := int64(f1_2) * int64(f2) |
||||
f1f3_4 := int64(f1_2) * int64(f3_2) |
||||
f1f4_2 := int64(f1_2) * int64(f4) |
||||
f1f5_4 := int64(f1_2) * int64(f5_2) |
||||
f1f6_2 := int64(f1_2) * int64(f6) |
||||
f1f7_4 := int64(f1_2) * int64(f7_2) |
||||
f1f8_2 := int64(f1_2) * int64(f8) |
||||
f1f9_76 := int64(f1_2) * int64(f9_38) |
||||
f2f2 := int64(f2) * int64(f2) |
||||
f2f3_2 := int64(f2_2) * int64(f3) |
||||
f2f4_2 := int64(f2_2) * int64(f4) |
||||
f2f5_2 := int64(f2_2) * int64(f5) |
||||
f2f6_2 := int64(f2_2) * int64(f6) |
||||
f2f7_2 := int64(f2_2) * int64(f7) |
||||
f2f8_38 := int64(f2_2) * int64(f8_19) |
||||
f2f9_38 := int64(f2) * int64(f9_38) |
||||
f3f3_2 := int64(f3_2) * int64(f3) |
||||
f3f4_2 := int64(f3_2) * int64(f4) |
||||
f3f5_4 := int64(f3_2) * int64(f5_2) |
||||
f3f6_2 := int64(f3_2) * int64(f6) |
||||
f3f7_76 := int64(f3_2) * int64(f7_38) |
||||
f3f8_38 := int64(f3_2) * int64(f8_19) |
||||
f3f9_76 := int64(f3_2) * int64(f9_38) |
||||
f4f4 := int64(f4) * int64(f4) |
||||
f4f5_2 := int64(f4_2) * int64(f5) |
||||
f4f6_38 := int64(f4_2) * int64(f6_19) |
||||
f4f7_38 := int64(f4) * int64(f7_38) |
||||
f4f8_38 := int64(f4_2) * int64(f8_19) |
||||
f4f9_38 := int64(f4) * int64(f9_38) |
||||
f5f5_38 := int64(f5) * int64(f5_38) |
||||
f5f6_38 := int64(f5_2) * int64(f6_19) |
||||
f5f7_76 := int64(f5_2) * int64(f7_38) |
||||
f5f8_38 := int64(f5_2) * int64(f8_19) |
||||
f5f9_76 := int64(f5_2) * int64(f9_38) |
||||
f6f6_19 := int64(f6) * int64(f6_19) |
||||
f6f7_38 := int64(f6) * int64(f7_38) |
||||
f6f8_38 := int64(f6_2) * int64(f8_19) |
||||
f6f9_38 := int64(f6) * int64(f9_38) |
||||
f7f7_38 := int64(f7) * int64(f7_38) |
||||
f7f8_38 := int64(f7_2) * int64(f8_19) |
||||
f7f9_76 := int64(f7_2) * int64(f9_38) |
||||
f8f8_19 := int64(f8) * int64(f8_19) |
||||
f8f9_38 := int64(f8) * int64(f9_38) |
||||
f9f9_38 := int64(f9) * int64(f9_38) |
||||
h0 := f0f0 + f1f9_76 + f2f8_38 + f3f7_76 + f4f6_38 + f5f5_38 |
||||
h1 := f0f1_2 + f2f9_38 + f3f8_38 + f4f7_38 + f5f6_38 |
||||
h2 := f0f2_2 + f1f1_2 + f3f9_76 + f4f8_38 + f5f7_76 + f6f6_19 |
||||
h3 := f0f3_2 + f1f2_2 + f4f9_38 + f5f8_38 + f6f7_38 |
||||
h4 := f0f4_2 + f1f3_4 + f2f2 + f5f9_76 + f6f8_38 + f7f7_38 |
||||
h5 := f0f5_2 + f1f4_2 + f2f3_2 + f6f9_38 + f7f8_38 |
||||
h6 := f0f6_2 + f1f5_4 + f2f4_2 + f3f3_2 + f7f9_76 + f8f8_19 |
||||
h7 := f0f7_2 + f1f6_2 + f2f5_2 + f3f4_2 + f8f9_38 |
||||
h8 := f0f8_2 + f1f7_4 + f2f6_2 + f3f5_4 + f4f4 + f9f9_38 |
||||
h9 := f0f9_2 + f1f8_2 + f2f7_2 + f3f6_2 + f4f5_2 |
||||
var carry [10]int64 |
||||
|
||||
carry[0] = (h0 + (1 << 25)) >> 26 |
||||
h1 += carry[0] |
||||
h0 -= carry[0] << 26 |
||||
carry[4] = (h4 + (1 << 25)) >> 26 |
||||
h5 += carry[4] |
||||
h4 -= carry[4] << 26 |
||||
|
||||
carry[1] = (h1 + (1 << 24)) >> 25 |
||||
h2 += carry[1] |
||||
h1 -= carry[1] << 25 |
||||
carry[5] = (h5 + (1 << 24)) >> 25 |
||||
h6 += carry[5] |
||||
h5 -= carry[5] << 25 |
||||
|
||||
carry[2] = (h2 + (1 << 25)) >> 26 |
||||
h3 += carry[2] |
||||
h2 -= carry[2] << 26 |
||||
carry[6] = (h6 + (1 << 25)) >> 26 |
||||
h7 += carry[6] |
||||
h6 -= carry[6] << 26 |
||||
|
||||
carry[3] = (h3 + (1 << 24)) >> 25 |
||||
h4 += carry[3] |
||||
h3 -= carry[3] << 25 |
||||
carry[7] = (h7 + (1 << 24)) >> 25 |
||||
h8 += carry[7] |
||||
h7 -= carry[7] << 25 |
||||
|
||||
carry[4] = (h4 + (1 << 25)) >> 26 |
||||
h5 += carry[4] |
||||
h4 -= carry[4] << 26 |
||||
carry[8] = (h8 + (1 << 25)) >> 26 |
||||
h9 += carry[8] |
||||
h8 -= carry[8] << 26 |
||||
|
||||
carry[9] = (h9 + (1 << 24)) >> 25 |
||||
h0 += carry[9] * 19 |
||||
h9 -= carry[9] << 25 |
||||
|
||||
carry[0] = (h0 + (1 << 25)) >> 26 |
||||
h1 += carry[0] |
||||
h0 -= carry[0] << 26 |
||||
|
||||
h[0] = int32(h0) |
||||
h[1] = int32(h1) |
||||
h[2] = int32(h2) |
||||
h[3] = int32(h3) |
||||
h[4] = int32(h4) |
||||
h[5] = int32(h5) |
||||
h[6] = int32(h6) |
||||
h[7] = int32(h7) |
||||
h[8] = int32(h8) |
||||
h[9] = int32(h9) |
||||
} |
||||
|
||||
// feMul121666 calculates h = f * 121666. Can overlap h with f.
|
||||
//
|
||||
// Preconditions:
|
||||
// |f| bounded by 1.1*2^26,1.1*2^25,1.1*2^26,1.1*2^25,etc.
|
||||
//
|
||||
// Postconditions:
|
||||
// |h| bounded by 1.1*2^25,1.1*2^24,1.1*2^25,1.1*2^24,etc.
|
||||
func feMul121666(h, f *fieldElement) { |
||||
h0 := int64(f[0]) * 121666 |
||||
h1 := int64(f[1]) * 121666 |
||||
h2 := int64(f[2]) * 121666 |
||||
h3 := int64(f[3]) * 121666 |
||||
h4 := int64(f[4]) * 121666 |
||||
h5 := int64(f[5]) * 121666 |
||||
h6 := int64(f[6]) * 121666 |
||||
h7 := int64(f[7]) * 121666 |
||||
h8 := int64(f[8]) * 121666 |
||||
h9 := int64(f[9]) * 121666 |
||||
var carry [10]int64 |
||||
|
||||
carry[9] = (h9 + (1 << 24)) >> 25 |
||||
h0 += carry[9] * 19 |
||||
h9 -= carry[9] << 25 |
||||
carry[1] = (h1 + (1 << 24)) >> 25 |
||||
h2 += carry[1] |
||||
h1 -= carry[1] << 25 |
||||
carry[3] = (h3 + (1 << 24)) >> 25 |
||||
h4 += carry[3] |
||||
h3 -= carry[3] << 25 |
||||
carry[5] = (h5 + (1 << 24)) >> 25 |
||||
h6 += carry[5] |
||||
h5 -= carry[5] << 25 |
||||
carry[7] = (h7 + (1 << 24)) >> 25 |
||||
h8 += carry[7] |
||||
h7 -= carry[7] << 25 |
||||
|
||||
carry[0] = (h0 + (1 << 25)) >> 26 |
||||
h1 += carry[0] |
||||
h0 -= carry[0] << 26 |
||||
carry[2] = (h2 + (1 << 25)) >> 26 |
||||
h3 += carry[2] |
||||
h2 -= carry[2] << 26 |
||||
carry[4] = (h4 + (1 << 25)) >> 26 |
||||
h5 += carry[4] |
||||
h4 -= carry[4] << 26 |
||||
carry[6] = (h6 + (1 << 25)) >> 26 |
||||
h7 += carry[6] |
||||
h6 -= carry[6] << 26 |
||||
carry[8] = (h8 + (1 << 25)) >> 26 |
||||
h9 += carry[8] |
||||
h8 -= carry[8] << 26 |
||||
|
||||
h[0] = int32(h0) |
||||
h[1] = int32(h1) |
||||
h[2] = int32(h2) |
||||
h[3] = int32(h3) |
||||
h[4] = int32(h4) |
||||
h[5] = int32(h5) |
||||
h[6] = int32(h6) |
||||
h[7] = int32(h7) |
||||
h[8] = int32(h8) |
||||
h[9] = int32(h9) |
||||
} |
||||
|
||||
// feInvert sets out = z^-1.
|
||||
func feInvert(out, z *fieldElement) { |
||||
var t0, t1, t2, t3 fieldElement |
||||
var i int |
||||
|
||||
feSquare(&t0, z) |
||||
for i = 1; i < 1; i++ { |
||||
feSquare(&t0, &t0) |
||||
} |
||||
feSquare(&t1, &t0) |
||||
for i = 1; i < 2; i++ { |
||||
feSquare(&t1, &t1) |
||||
} |
||||
feMul(&t1, z, &t1) |
||||
feMul(&t0, &t0, &t1) |
||||
feSquare(&t2, &t0) |
||||
for i = 1; i < 1; i++ { |
||||
feSquare(&t2, &t2) |
||||
} |
||||
feMul(&t1, &t1, &t2) |
||||
feSquare(&t2, &t1) |
||||
for i = 1; i < 5; i++ { |
||||
feSquare(&t2, &t2) |
||||
} |
||||
feMul(&t1, &t2, &t1) |
||||
feSquare(&t2, &t1) |
||||
for i = 1; i < 10; i++ { |
||||
feSquare(&t2, &t2) |
||||
} |
||||
feMul(&t2, &t2, &t1) |
||||
feSquare(&t3, &t2) |
||||
for i = 1; i < 20; i++ { |
||||
feSquare(&t3, &t3) |
||||
} |
||||
feMul(&t2, &t3, &t2) |
||||
feSquare(&t2, &t2) |
||||
for i = 1; i < 10; i++ { |
||||
feSquare(&t2, &t2) |
||||
} |
||||
feMul(&t1, &t2, &t1) |
||||
feSquare(&t2, &t1) |
||||
for i = 1; i < 50; i++ { |
||||
feSquare(&t2, &t2) |
||||
} |
||||
feMul(&t2, &t2, &t1) |
||||
feSquare(&t3, &t2) |
||||
for i = 1; i < 100; i++ { |
||||
feSquare(&t3, &t3) |
||||
} |
||||
feMul(&t2, &t3, &t2) |
||||
feSquare(&t2, &t2) |
||||
for i = 1; i < 50; i++ { |
||||
feSquare(&t2, &t2) |
||||
} |
||||
feMul(&t1, &t2, &t1) |
||||
feSquare(&t1, &t1) |
||||
for i = 1; i < 5; i++ { |
||||
feSquare(&t1, &t1) |
||||
} |
||||
feMul(out, &t1, &t0) |
||||
} |
||||
|
||||
func scalarMult(out, in, base *[32]byte) { |
||||
var e [32]byte |
||||
|
||||
copy(e[:], in[:]) |
||||
e[0] &= 248 |
||||
e[31] &= 127 |
||||
e[31] |= 64 |
||||
|
||||
var x1, x2, z2, x3, z3, tmp0, tmp1 fieldElement |
||||
feFromBytes(&x1, base) |
||||
feOne(&x2) |
||||
feCopy(&x3, &x1) |
||||
feOne(&z3) |
||||
|
||||
swap := int32(0) |
||||
for pos := 254; pos >= 0; pos-- { |
||||
b := e[pos/8] >> uint(pos&7) |
||||
b &= 1 |
||||
swap ^= int32(b) |
||||
feCSwap(&x2, &x3, swap) |
||||
feCSwap(&z2, &z3, swap) |
||||
swap = int32(b) |
||||
|
||||
feSub(&tmp0, &x3, &z3) |
||||
feSub(&tmp1, &x2, &z2) |
||||
feAdd(&x2, &x2, &z2) |
||||
feAdd(&z2, &x3, &z3) |
||||
feMul(&z3, &tmp0, &x2) |
||||
feMul(&z2, &z2, &tmp1) |
||||
feSquare(&tmp0, &tmp1) |
||||
feSquare(&tmp1, &x2) |
||||
feAdd(&x3, &z3, &z2) |
||||
feSub(&z2, &z3, &z2) |
||||
feMul(&x2, &tmp1, &tmp0) |
||||
feSub(&tmp1, &tmp1, &tmp0) |
||||
feSquare(&z2, &z2) |
||||
feMul121666(&z3, &tmp1) |
||||
feSquare(&x3, &x3) |
||||
feAdd(&tmp0, &tmp0, &z3) |
||||
feMul(&z3, &x1, &z2) |
||||
feMul(&z2, &tmp1, &tmp0) |
||||
} |
||||
|
||||
feCSwap(&x2, &x3, swap) |
||||
feCSwap(&z2, &z3, swap) |
||||
|
||||
feInvert(&z2, &z2) |
||||
feMul(&x2, &x2, &z2) |
||||
feToBytes(out, &x2) |
||||
} |
@ -0,0 +1,23 @@ |
||||
// Copyright 2012 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
// Package curve25519 provides an implementation of scalar multiplication on
|
||||
// the elliptic curve known as curve25519. See http://cr.yp.to/ecdh.html
|
||||
package curve25519 // import "golang.org/x/crypto/curve25519"
|
||||
|
||||
// basePoint is the x coordinate of the generator of the curve.
|
||||
var basePoint = [32]byte{9, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0} |
||||
|
||||
// ScalarMult sets dst to the product in*base where dst and base are the x
|
||||
// coordinates of group points and all values are in little-endian form.
|
||||
func ScalarMult(dst, in, base *[32]byte) { |
||||
scalarMult(dst, in, base) |
||||
} |
||||
|
||||
// ScalarBaseMult sets dst to the product in*base where dst and base are the x
|
||||
// coordinates of group points, base is the standard generator and all values
|
||||
// are in little-endian form.
|
||||
func ScalarBaseMult(dst, in *[32]byte) { |
||||
ScalarMult(dst, in, &basePoint) |
||||
} |
@ -0,0 +1,73 @@ |
||||
// Copyright 2012 The Go Authors. All rights reserved. |
||||
// Use of this source code is governed by a BSD-style |
||||
// license that can be found in the LICENSE file. |
||||
|
||||
// This code was translated into a form compatible with 6a from the public |
||||
// domain sources in SUPERCOP: http://bench.cr.yp.to/supercop.html |
||||
|
||||
// +build amd64,!gccgo,!appengine |
||||
|
||||
#include "const_amd64.h" |
||||
|
||||
// func freeze(inout *[5]uint64) |
||||
TEXT ·freeze(SB),7,$0-8 |
||||
MOVQ inout+0(FP), DI |
||||
|
||||
MOVQ 0(DI),SI |
||||
MOVQ 8(DI),DX |
||||
MOVQ 16(DI),CX |
||||
MOVQ 24(DI),R8 |
||||
MOVQ 32(DI),R9 |
||||
MOVQ $REDMASK51,AX |
||||
MOVQ AX,R10 |
||||
SUBQ $18,R10 |
||||
MOVQ $3,R11 |
||||
REDUCELOOP: |
||||
MOVQ SI,R12 |
||||
SHRQ $51,R12 |
||||
ANDQ AX,SI |
||||
ADDQ R12,DX |
||||
MOVQ DX,R12 |
||||
SHRQ $51,R12 |
||||
ANDQ AX,DX |
||||
ADDQ R12,CX |
||||
MOVQ CX,R12 |
||||
SHRQ $51,R12 |
||||
ANDQ AX,CX |
||||
ADDQ R12,R8 |
||||
MOVQ R8,R12 |
||||
SHRQ $51,R12 |
||||
ANDQ AX,R8 |
||||
ADDQ R12,R9 |
||||
MOVQ R9,R12 |
||||
SHRQ $51,R12 |
||||
ANDQ AX,R9 |
||||
IMUL3Q $19,R12,R12 |
||||
ADDQ R12,SI |
||||
SUBQ $1,R11 |
||||
JA REDUCELOOP |
||||
MOVQ $1,R12 |
||||
CMPQ R10,SI |
||||
CMOVQLT R11,R12 |
||||
CMPQ AX,DX |
||||
CMOVQNE R11,R12 |
||||
CMPQ AX,CX |
||||
CMOVQNE R11,R12 |
||||
CMPQ AX,R8 |
||||
CMOVQNE R11,R12 |
||||
CMPQ AX,R9 |
||||
CMOVQNE R11,R12 |
||||
NEGQ R12 |
||||
ANDQ R12,AX |
||||
ANDQ R12,R10 |
||||
SUBQ R10,SI |
||||
SUBQ AX,DX |
||||
SUBQ AX,CX |
||||
SUBQ AX,R8 |
||||
SUBQ AX,R9 |
||||
MOVQ SI,0(DI) |
||||
MOVQ DX,8(DI) |
||||
MOVQ CX,16(DI) |
||||
MOVQ R8,24(DI) |
||||
MOVQ R9,32(DI) |
||||
RET |
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,240 @@ |
||||
// Copyright 2012 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
// +build amd64,!gccgo,!appengine
|
||||
|
||||
package curve25519 |
||||
|
||||
// These functions are implemented in the .s files. The names of the functions
|
||||
// in the rest of the file are also taken from the SUPERCOP sources to help
|
||||
// people following along.
|
||||
|
||||
//go:noescape
|
||||
|
||||
func cswap(inout *[5]uint64, v uint64) |
||||
|
||||
//go:noescape
|
||||
|
||||
func ladderstep(inout *[5][5]uint64) |
||||
|
||||
//go:noescape
|
||||
|
||||
func freeze(inout *[5]uint64) |
||||
|
||||
//go:noescape
|
||||
|
||||
func mul(dest, a, b *[5]uint64) |
||||
|
||||
//go:noescape
|
||||
|
||||
func square(out, in *[5]uint64) |
||||
|
||||
// mladder uses a Montgomery ladder to calculate (xr/zr) *= s.
|
||||
func mladder(xr, zr *[5]uint64, s *[32]byte) { |
||||
var work [5][5]uint64 |
||||
|
||||
work[0] = *xr |
||||
setint(&work[1], 1) |
||||
setint(&work[2], 0) |
||||
work[3] = *xr |
||||
setint(&work[4], 1) |
||||
|
||||
j := uint(6) |
||||
var prevbit byte |
||||
|
||||
for i := 31; i >= 0; i-- { |
||||
for j < 8 { |
||||
bit := ((*s)[i] >> j) & 1 |
||||
swap := bit ^ prevbit |
||||
prevbit = bit |
||||
cswap(&work[1], uint64(swap)) |
||||
ladderstep(&work) |
||||
j-- |
||||
} |
||||
j = 7 |
||||
} |
||||
|
||||
*xr = work[1] |
||||
*zr = work[2] |
||||
} |
||||
|
||||
func scalarMult(out, in, base *[32]byte) { |
||||
var e [32]byte |
||||
copy(e[:], (*in)[:]) |
||||
e[0] &= 248 |
||||
e[31] &= 127 |
||||
e[31] |= 64 |
||||
|
||||
var t, z [5]uint64 |
||||
unpack(&t, base) |
||||
mladder(&t, &z, &e) |
||||
invert(&z, &z) |
||||
mul(&t, &t, &z) |
||||
pack(out, &t) |
||||
} |
||||
|
||||
func setint(r *[5]uint64, v uint64) { |
||||
r[0] = v |
||||
r[1] = 0 |
||||
r[2] = 0 |
||||
r[3] = 0 |
||||
r[4] = 0 |
||||
} |
||||
|
||||
// unpack sets r = x where r consists of 5, 51-bit limbs in little-endian
|
||||
// order.
|
||||
func unpack(r *[5]uint64, x *[32]byte) { |
||||
r[0] = uint64(x[0]) | |
||||
uint64(x[1])<<8 | |
||||
uint64(x[2])<<16 | |
||||
uint64(x[3])<<24 | |
||||
uint64(x[4])<<32 | |
||||
uint64(x[5])<<40 | |
||||
uint64(x[6]&7)<<48 |
||||
|
||||
r[1] = uint64(x[6])>>3 | |
||||
uint64(x[7])<<5 | |
||||
uint64(x[8])<<13 | |
||||
uint64(x[9])<<21 | |
||||
uint64(x[10])<<29 | |
||||
uint64(x[11])<<37 | |
||||
uint64(x[12]&63)<<45 |
||||
|
||||
r[2] = uint64(x[12])>>6 | |
||||
uint64(x[13])<<2 | |
||||
uint64(x[14])<<10 | |
||||
uint64(x[15])<<18 | |
||||
uint64(x[16])<<26 | |
||||
uint64(x[17])<<34 | |
||||
uint64(x[18])<<42 | |
||||
uint64(x[19]&1)<<50 |
||||
|
||||
r[3] = uint64(x[19])>>1 | |
||||
uint64(x[20])<<7 | |
||||
uint64(x[21])<<15 | |
||||
uint64(x[22])<<23 | |
||||
uint64(x[23])<<31 | |
||||
uint64(x[24])<<39 | |
||||
uint64(x[25]&15)<<47 |
||||
|
||||
r[4] = uint64(x[25])>>4 | |
||||
uint64(x[26])<<4 | |
||||
uint64(x[27])<<12 | |
||||
uint64(x[28])<<20 | |
||||
uint64(x[29])<<28 | |
||||
uint64(x[30])<<36 | |
||||
uint64(x[31]&127)<<44 |
||||
} |
||||
|
||||
// pack sets out = x where out is the usual, little-endian form of the 5,
|
||||
// 51-bit limbs in x.
|
||||
func pack(out *[32]byte, x *[5]uint64) { |
||||
t := *x |
||||
freeze(&t) |
||||
|
||||
out[0] = byte(t[0]) |
||||
out[1] = byte(t[0] >> 8) |
||||
out[2] = byte(t[0] >> 16) |
||||
out[3] = byte(t[0] >> 24) |
||||
out[4] = byte(t[0] >> 32) |
||||
out[5] = byte(t[0] >> 40) |
||||
out[6] = byte(t[0] >> 48) |
||||
|
||||
out[6] ^= byte(t[1]<<3) & 0xf8 |
||||
out[7] = byte(t[1] >> 5) |
||||
out[8] = byte(t[1] >> 13) |
||||
out[9] = byte(t[1] >> 21) |
||||
out[10] = byte(t[1] >> 29) |
||||
out[11] = byte(t[1] >> 37) |
||||
out[12] = byte(t[1] >> 45) |
||||
|
||||
out[12] ^= byte(t[2]<<6) & 0xc0 |
||||
out[13] = byte(t[2] >> 2) |
||||
out[14] = byte(t[2] >> 10) |
||||
out[15] = byte(t[2] >> 18) |
||||
out[16] = byte(t[2] >> 26) |
||||
out[17] = byte(t[2] >> 34) |
||||
out[18] = byte(t[2] >> 42) |
||||
out[19] = byte(t[2] >> 50) |
||||
|
||||
out[19] ^= byte(t[3]<<1) & 0xfe |
||||
out[20] = byte(t[3] >> 7) |
||||
out[21] = byte(t[3] >> 15) |
||||
out[22] = byte(t[3] >> 23) |
||||
out[23] = byte(t[3] >> 31) |
||||
out[24] = byte(t[3] >> 39) |
||||
out[25] = byte(t[3] >> 47) |
||||
|
||||
out[25] ^= byte(t[4]<<4) & 0xf0 |
||||
out[26] = byte(t[4] >> 4) |
||||
out[27] = byte(t[4] >> 12) |
||||
out[28] = byte(t[4] >> 20) |
||||
out[29] = byte(t[4] >> 28) |
||||
out[30] = byte(t[4] >> 36) |
||||
out[31] = byte(t[4] >> 44) |
||||
} |
||||
|
||||
// invert calculates r = x^-1 mod p using Fermat's little theorem.
|
||||
func invert(r *[5]uint64, x *[5]uint64) { |
||||
var z2, z9, z11, z2_5_0, z2_10_0, z2_20_0, z2_50_0, z2_100_0, t [5]uint64 |
||||
|
||||
square(&z2, x) /* 2 */ |
||||
square(&t, &z2) /* 4 */ |
||||
square(&t, &t) /* 8 */ |
||||
mul(&z9, &t, x) /* 9 */ |
||||
mul(&z11, &z9, &z2) /* 11 */ |
||||
square(&t, &z11) /* 22 */ |
||||
mul(&z2_5_0, &t, &z9) /* 2^5 - 2^0 = 31 */ |
||||
|
||||
square(&t, &z2_5_0) /* 2^6 - 2^1 */ |
||||
for i := 1; i < 5; i++ { /* 2^20 - 2^10 */ |
||||
square(&t, &t) |
||||
} |
||||
mul(&z2_10_0, &t, &z2_5_0) /* 2^10 - 2^0 */ |
||||
|
||||
square(&t, &z2_10_0) /* 2^11 - 2^1 */ |
||||
for i := 1; i < 10; i++ { /* 2^20 - 2^10 */ |
||||
square(&t, &t) |
||||
} |
||||
mul(&z2_20_0, &t, &z2_10_0) /* 2^20 - 2^0 */ |
||||
|
||||
square(&t, &z2_20_0) /* 2^21 - 2^1 */ |
||||
for i := 1; i < 20; i++ { /* 2^40 - 2^20 */ |
||||
square(&t, &t) |
||||
} |
||||
mul(&t, &t, &z2_20_0) /* 2^40 - 2^0 */ |
||||
|
||||
square(&t, &t) /* 2^41 - 2^1 */ |
||||
for i := 1; i < 10; i++ { /* 2^50 - 2^10 */ |
||||
square(&t, &t) |
||||
} |
||||
mul(&z2_50_0, &t, &z2_10_0) /* 2^50 - 2^0 */ |
||||
|
||||
square(&t, &z2_50_0) /* 2^51 - 2^1 */ |
||||
for i := 1; i < 50; i++ { /* 2^100 - 2^50 */ |
||||
square(&t, &t) |
||||
} |
||||
mul(&z2_100_0, &t, &z2_50_0) /* 2^100 - 2^0 */ |
||||
|
||||
square(&t, &z2_100_0) /* 2^101 - 2^1 */ |
||||
for i := 1; i < 100; i++ { /* 2^200 - 2^100 */ |
||||
square(&t, &t) |
||||
} |
||||
mul(&t, &t, &z2_100_0) /* 2^200 - 2^0 */ |
||||
|
||||
square(&t, &t) /* 2^201 - 2^1 */ |
||||
for i := 1; i < 50; i++ { /* 2^250 - 2^50 */ |
||||
square(&t, &t) |
||||
} |
||||
mul(&t, &t, &z2_50_0) /* 2^250 - 2^0 */ |
||||
|
||||
square(&t, &t) /* 2^251 - 2^1 */ |
||||
square(&t, &t) /* 2^252 - 2^2 */ |
||||
square(&t, &t) /* 2^253 - 2^3 */ |
||||
|
||||
square(&t, &t) /* 2^254 - 2^4 */ |
||||
|
||||
square(&t, &t) /* 2^255 - 2^5 */ |
||||
mul(r, &t, &z11) /* 2^255 - 21 */ |
||||
} |
@ -0,0 +1,169 @@ |
||||
// Copyright 2012 The Go Authors. All rights reserved. |
||||
// Use of this source code is governed by a BSD-style |
||||
// license that can be found in the LICENSE file. |
||||
|
||||
// This code was translated into a form compatible with 6a from the public |
||||
// domain sources in SUPERCOP: http://bench.cr.yp.to/supercop.html |
||||
|
||||
// +build amd64,!gccgo,!appengine |
||||
|
||||
#include "const_amd64.h" |
||||
|
||||
// func mul(dest, a, b *[5]uint64) |
||||
TEXT ·mul(SB),0,$16-24 |
||||
MOVQ dest+0(FP), DI |
||||
MOVQ a+8(FP), SI |
||||
MOVQ b+16(FP), DX |
||||
|
||||
MOVQ DX,CX |
||||
MOVQ 24(SI),DX |
||||
IMUL3Q $19,DX,AX |
||||
MOVQ AX,0(SP) |
||||
MULQ 16(CX) |
||||
MOVQ AX,R8 |
||||
MOVQ DX,R9 |
||||
MOVQ 32(SI),DX |
||||
IMUL3Q $19,DX,AX |
||||
MOVQ AX,8(SP) |
||||
MULQ 8(CX) |
||||
ADDQ AX,R8 |
||||
ADCQ DX,R9 |
||||
MOVQ 0(SI),AX |
||||
MULQ 0(CX) |
||||
ADDQ AX,R8 |
||||
ADCQ DX,R9 |
||||
MOVQ 0(SI),AX |
||||
MULQ 8(CX) |
||||
MOVQ AX,R10 |
||||
MOVQ DX,R11 |
||||
MOVQ 0(SI),AX |
||||
MULQ 16(CX) |
||||
MOVQ AX,R12 |
||||
MOVQ DX,R13 |
||||
MOVQ 0(SI),AX |
||||
MULQ 24(CX) |
||||
MOVQ AX,R14 |
||||
MOVQ DX,R15 |
||||
MOVQ 0(SI),AX |
||||
MULQ 32(CX) |
||||
MOVQ AX,BX |
||||
MOVQ DX,BP |
||||
MOVQ 8(SI),AX |
||||
MULQ 0(CX) |
||||
ADDQ AX,R10 |
||||
ADCQ DX,R11 |
||||
MOVQ 8(SI),AX |
||||
MULQ 8(CX) |
||||
ADDQ AX,R12 |
||||
ADCQ DX,R13 |
||||
MOVQ 8(SI),AX |
||||
MULQ 16(CX) |
||||
ADDQ AX,R14 |
||||
ADCQ DX,R15 |
||||
MOVQ 8(SI),AX |
||||
MULQ 24(CX) |
||||
ADDQ AX,BX |
||||
ADCQ DX,BP |
||||
MOVQ 8(SI),DX |
||||
IMUL3Q $19,DX,AX |
||||
MULQ 32(CX) |
||||
ADDQ AX,R8 |
||||
ADCQ DX,R9 |
||||
MOVQ 16(SI),AX |
||||
MULQ 0(CX) |
||||
ADDQ AX,R12 |
||||
ADCQ DX,R13 |
||||
MOVQ 16(SI),AX |
||||
MULQ 8(CX) |
||||
ADDQ AX,R14 |
||||
ADCQ DX,R15 |
||||
MOVQ 16(SI),AX |
||||
MULQ 16(CX) |
||||
ADDQ AX,BX |
||||
ADCQ DX,BP |
||||
MOVQ 16(SI),DX |
||||
IMUL3Q $19,DX,AX |
||||
MULQ 24(CX) |
||||
ADDQ AX,R8 |
||||
ADCQ DX,R9 |
||||
MOVQ 16(SI),DX |
||||
IMUL3Q $19,DX,AX |
||||
MULQ 32(CX) |
||||
ADDQ AX,R10 |
||||
ADCQ DX,R11 |
||||
MOVQ 24(SI),AX |
||||
MULQ 0(CX) |
||||
ADDQ AX,R14 |
||||
ADCQ DX,R15 |
||||
MOVQ 24(SI),AX |
||||
MULQ 8(CX) |
||||
ADDQ AX,BX |
||||
ADCQ DX,BP |
||||
MOVQ 0(SP),AX |
||||
MULQ 24(CX) |
||||
ADDQ AX,R10 |
||||
ADCQ DX,R11 |
||||
MOVQ 0(SP),AX |
||||
MULQ 32(CX) |
||||
ADDQ AX,R12 |
||||
ADCQ DX,R13 |
||||
MOVQ 32(SI),AX |
||||
MULQ 0(CX) |
||||
ADDQ AX,BX |
||||
ADCQ DX,BP |
||||
MOVQ 8(SP),AX |
||||
MULQ 16(CX) |
||||
ADDQ AX,R10 |
||||
ADCQ DX,R11 |
||||
MOVQ 8(SP),AX |
||||
MULQ 24(CX) |
||||
ADDQ AX,R12 |
||||
ADCQ DX,R13 |
||||
MOVQ 8(SP),AX |
||||
MULQ 32(CX) |
||||
ADDQ AX,R14 |
||||
ADCQ DX,R15 |
||||
MOVQ $REDMASK51,SI |
||||
SHLQ $13,R9:R8 |
||||
ANDQ SI,R8 |
||||
SHLQ $13,R11:R10 |
||||
ANDQ SI,R10 |
||||
ADDQ R9,R10 |
||||
SHLQ $13,R13:R12 |
||||
ANDQ SI,R12 |
||||
ADDQ R11,R12 |
||||
SHLQ $13,R15:R14 |
||||
ANDQ SI,R14 |
||||
ADDQ R13,R14 |
||||
SHLQ $13,BP:BX |
||||
ANDQ SI,BX |
||||
ADDQ R15,BX |
||||
IMUL3Q $19,BP,DX |
||||
ADDQ DX,R8 |
||||
MOVQ R8,DX |
||||
SHRQ $51,DX |
||||
ADDQ R10,DX |
||||
MOVQ DX,CX |
||||
SHRQ $51,DX |
||||
ANDQ SI,R8 |
||||
ADDQ R12,DX |
||||
MOVQ DX,R9 |
||||
SHRQ $51,DX |
||||
ANDQ SI,CX |
||||
ADDQ R14,DX |
||||
MOVQ DX,AX |
||||
SHRQ $51,DX |
||||
ANDQ SI,R9 |
||||
ADDQ BX,DX |
||||
MOVQ DX,R10 |
||||
SHRQ $51,DX |
||||
ANDQ SI,AX |
||||
IMUL3Q $19,DX,DX |
||||
ADDQ DX,R8 |
||||
ANDQ SI,R10 |
||||
MOVQ R8,0(DI) |
||||
MOVQ CX,8(DI) |
||||
MOVQ R9,16(DI) |
||||
MOVQ AX,24(DI) |
||||
MOVQ R10,32(DI) |
||||
RET |
@ -0,0 +1,132 @@ |
||||
// Copyright 2012 The Go Authors. All rights reserved. |
||||
// Use of this source code is governed by a BSD-style |
||||
// license that can be found in the LICENSE file. |
||||
|
||||
// This code was translated into a form compatible with 6a from the public |
||||
// domain sources in SUPERCOP: http://bench.cr.yp.to/supercop.html |
||||
|
||||
// +build amd64,!gccgo,!appengine |
||||
|
||||
#include "const_amd64.h" |
||||
|
||||
// func square(out, in *[5]uint64) |
||||
TEXT ·square(SB),7,$0-16 |
||||
MOVQ out+0(FP), DI |
||||
MOVQ in+8(FP), SI |
||||
|
||||
MOVQ 0(SI),AX |
||||
MULQ 0(SI) |
||||
MOVQ AX,CX |
||||
MOVQ DX,R8 |
||||
MOVQ 0(SI),AX |
||||
SHLQ $1,AX |
||||
MULQ 8(SI) |
||||
MOVQ AX,R9 |
||||
MOVQ DX,R10 |
||||
MOVQ 0(SI),AX |
||||
SHLQ $1,AX |
||||
MULQ 16(SI) |
||||
MOVQ AX,R11 |
||||
MOVQ DX,R12 |
||||
MOVQ 0(SI),AX |
||||
SHLQ $1,AX |
||||
MULQ 24(SI) |
||||
MOVQ AX,R13 |
||||
MOVQ DX,R14 |
||||
MOVQ 0(SI),AX |
||||
SHLQ $1,AX |
||||
MULQ 32(SI) |
||||
MOVQ AX,R15 |
||||
MOVQ DX,BX |
||||
MOVQ 8(SI),AX |
||||
MULQ 8(SI) |
||||
ADDQ AX,R11 |
||||
ADCQ DX,R12 |
||||
MOVQ 8(SI),AX |
||||
SHLQ $1,AX |
||||
MULQ 16(SI) |
||||
ADDQ AX,R13 |
||||
ADCQ DX,R14 |
||||
MOVQ 8(SI),AX |
||||
SHLQ $1,AX |
||||
MULQ 24(SI) |
||||
ADDQ AX,R15 |
||||
ADCQ DX,BX |
||||
MOVQ 8(SI),DX |
||||
IMUL3Q $38,DX,AX |
||||
MULQ 32(SI) |
||||
ADDQ AX,CX |
||||
ADCQ DX,R8 |
||||
MOVQ 16(SI),AX |
||||
MULQ 16(SI) |
||||
ADDQ AX,R15 |
||||
ADCQ DX,BX |
||||
MOVQ 16(SI),DX |
||||
IMUL3Q $38,DX,AX |
||||
MULQ 24(SI) |
||||
ADDQ AX,CX |
||||
ADCQ DX,R8 |
||||
MOVQ 16(SI),DX |
||||
IMUL3Q $38,DX,AX |
||||
MULQ 32(SI) |
||||
ADDQ AX,R9 |
||||
ADCQ DX,R10 |
||||
MOVQ 24(SI),DX |
||||
IMUL3Q $19,DX,AX |
||||
MULQ 24(SI) |
||||
ADDQ AX,R9 |
||||
ADCQ DX,R10 |
||||
MOVQ 24(SI),DX |
||||
IMUL3Q $38,DX,AX |
||||
MULQ 32(SI) |
||||
ADDQ AX,R11 |
||||
ADCQ DX,R12 |
||||
MOVQ 32(SI),DX |
||||
IMUL3Q $19,DX,AX |
||||
MULQ 32(SI) |
||||
ADDQ AX,R13 |
||||
ADCQ DX,R14 |
||||
MOVQ $REDMASK51,SI |
||||
SHLQ $13,R8:CX |
||||
ANDQ SI,CX |
||||
SHLQ $13,R10:R9 |
||||
ANDQ SI,R9 |
||||
ADDQ R8,R9 |
||||
SHLQ $13,R12:R11 |
||||
ANDQ SI,R11 |
||||
ADDQ R10,R11 |
||||
SHLQ $13,R14:R13 |
||||
ANDQ SI,R13 |
||||
ADDQ R12,R13 |
||||
SHLQ $13,BX:R15 |
||||
ANDQ SI,R15 |
||||
ADDQ R14,R15 |
||||
IMUL3Q $19,BX,DX |
||||
ADDQ DX,CX |
||||
MOVQ CX,DX |
||||
SHRQ $51,DX |
||||
ADDQ R9,DX |
||||
ANDQ SI,CX |
||||
MOVQ DX,R8 |
||||
SHRQ $51,DX |
||||
ADDQ R11,DX |
||||
ANDQ SI,R8 |
||||
MOVQ DX,R9 |
||||
SHRQ $51,DX |
||||
ADDQ R13,DX |
||||
ANDQ SI,R9 |
||||
MOVQ DX,AX |
||||
SHRQ $51,DX |
||||
ADDQ R15,DX |
||||
ANDQ SI,AX |
||||
MOVQ DX,R10 |
||||
SHRQ $51,DX |
||||
IMUL3Q $19,DX,DX |
||||
ADDQ DX,CX |
||||
ANDQ SI,R10 |
||||
MOVQ CX,0(DI) |
||||
MOVQ R8,8(DI) |
||||
MOVQ R9,16(DI) |
||||
MOVQ AX,24(DI) |
||||
MOVQ R10,32(DI) |
||||
RET |
@ -0,0 +1,181 @@ |
||||
// Copyright 2016 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
// Package ed25519 implements the Ed25519 signature algorithm. See
|
||||
// http://ed25519.cr.yp.to/.
|
||||
//
|
||||
// These functions are also compatible with the “Ed25519” function defined in
|
||||
// https://tools.ietf.org/html/draft-irtf-cfrg-eddsa-05.
|
||||
package ed25519 |
||||
|
||||
// This code is a port of the public domain, “ref10” implementation of ed25519
|
||||
// from SUPERCOP.
|
||||
|
||||
import ( |
||||
"crypto" |
||||
cryptorand "crypto/rand" |
||||
"crypto/sha512" |
||||
"crypto/subtle" |
||||
"errors" |
||||
"io" |
||||
"strconv" |
||||
|
||||
"golang.org/x/crypto/ed25519/internal/edwards25519" |
||||
) |
||||
|
||||
const ( |
||||
// PublicKeySize is the size, in bytes, of public keys as used in this package.
|
||||
PublicKeySize = 32 |
||||
// PrivateKeySize is the size, in bytes, of private keys as used in this package.
|
||||
PrivateKeySize = 64 |
||||
// SignatureSize is the size, in bytes, of signatures generated and verified by this package.
|
||||
SignatureSize = 64 |
||||
) |
||||
|
||||
// PublicKey is the type of Ed25519 public keys.
|
||||
type PublicKey []byte |
||||
|
||||
// PrivateKey is the type of Ed25519 private keys. It implements crypto.Signer.
|
||||
type PrivateKey []byte |
||||
|
||||
// Public returns the PublicKey corresponding to priv.
|
||||
func (priv PrivateKey) Public() crypto.PublicKey { |
||||
publicKey := make([]byte, PublicKeySize) |
||||
copy(publicKey, priv[32:]) |
||||
return PublicKey(publicKey) |
||||
} |
||||
|
||||
// Sign signs the given message with priv.
|
||||
// Ed25519 performs two passes over messages to be signed and therefore cannot
|
||||
// handle pre-hashed messages. Thus opts.HashFunc() must return zero to
|
||||
// indicate the message hasn't been hashed. This can be achieved by passing
|
||||
// crypto.Hash(0) as the value for opts.
|
||||
func (priv PrivateKey) Sign(rand io.Reader, message []byte, opts crypto.SignerOpts) (signature []byte, err error) { |
||||
if opts.HashFunc() != crypto.Hash(0) { |
||||
return nil, errors.New("ed25519: cannot sign hashed message") |
||||
} |
||||
|
||||
return Sign(priv, message), nil |
||||
} |
||||
|
||||
// GenerateKey generates a public/private key pair using entropy from rand.
|
||||
// If rand is nil, crypto/rand.Reader will be used.
|
||||
func GenerateKey(rand io.Reader) (publicKey PublicKey, privateKey PrivateKey, err error) { |
||||
if rand == nil { |
||||
rand = cryptorand.Reader |
||||
} |
||||
|
||||
privateKey = make([]byte, PrivateKeySize) |
||||
publicKey = make([]byte, PublicKeySize) |
||||
_, err = io.ReadFull(rand, privateKey[:32]) |
||||
if err != nil { |
||||
return nil, nil, err |
||||
} |
||||
|
||||
digest := sha512.Sum512(privateKey[:32]) |
||||
digest[0] &= 248 |
||||
digest[31] &= 127 |
||||
digest[31] |= 64 |
||||
|
||||
var A edwards25519.ExtendedGroupElement |
||||
var hBytes [32]byte |
||||
copy(hBytes[:], digest[:]) |
||||
edwards25519.GeScalarMultBase(&A, &hBytes) |
||||
var publicKeyBytes [32]byte |
||||
A.ToBytes(&publicKeyBytes) |
||||
|
||||
copy(privateKey[32:], publicKeyBytes[:]) |
||||
copy(publicKey, publicKeyBytes[:]) |
||||
|
||||
return publicKey, privateKey, nil |
||||
} |
||||
|
||||
// Sign signs the message with privateKey and returns a signature. It will
|
||||
// panic if len(privateKey) is not PrivateKeySize.
|
||||
func Sign(privateKey PrivateKey, message []byte) []byte { |
||||
if l := len(privateKey); l != PrivateKeySize { |
||||
panic("ed25519: bad private key length: " + strconv.Itoa(l)) |
||||
} |
||||
|
||||
h := sha512.New() |
||||
h.Write(privateKey[:32]) |
||||
|
||||
var digest1, messageDigest, hramDigest [64]byte |
||||
var expandedSecretKey [32]byte |
||||
h.Sum(digest1[:0]) |
||||
copy(expandedSecretKey[:], digest1[:]) |
||||
expandedSecretKey[0] &= 248 |
||||
expandedSecretKey[31] &= 63 |
||||
expandedSecretKey[31] |= 64 |
||||
|
||||
h.Reset() |
||||
h.Write(digest1[32:]) |
||||
h.Write(message) |
||||
h.Sum(messageDigest[:0]) |
||||
|
||||
var messageDigestReduced [32]byte |
||||
edwards25519.ScReduce(&messageDigestReduced, &messageDigest) |
||||
var R edwards25519.ExtendedGroupElement |
||||
edwards25519.GeScalarMultBase(&R, &messageDigestReduced) |
||||
|
||||
var encodedR [32]byte |
||||
R.ToBytes(&encodedR) |
||||
|
||||
h.Reset() |
||||
h.Write(encodedR[:]) |
||||
h.Write(privateKey[32:]) |
||||
h.Write(message) |
||||
h.Sum(hramDigest[:0]) |
||||
var hramDigestReduced [32]byte |
||||
edwards25519.ScReduce(&hramDigestReduced, &hramDigest) |
||||
|
||||
var s [32]byte |
||||
edwards25519.ScMulAdd(&s, &hramDigestReduced, &expandedSecretKey, &messageDigestReduced) |
||||
|
||||
signature := make([]byte, SignatureSize) |
||||
copy(signature[:], encodedR[:]) |
||||
copy(signature[32:], s[:]) |
||||
|
||||
return signature |
||||
} |
||||
|
||||
// Verify reports whether sig is a valid signature of message by publicKey. It
|
||||
// will panic if len(publicKey) is not PublicKeySize.
|
||||
func Verify(publicKey PublicKey, message, sig []byte) bool { |
||||
if l := len(publicKey); l != PublicKeySize { |
||||
panic("ed25519: bad public key length: " + strconv.Itoa(l)) |
||||
} |
||||
|
||||
if len(sig) != SignatureSize || sig[63]&224 != 0 { |
||||
return false |
||||
} |
||||
|
||||
var A edwards25519.ExtendedGroupElement |
||||
var publicKeyBytes [32]byte |
||||
copy(publicKeyBytes[:], publicKey) |
||||
if !A.FromBytes(&publicKeyBytes) { |
||||
return false |
||||
} |
||||
edwards25519.FeNeg(&A.X, &A.X) |
||||
edwards25519.FeNeg(&A.T, &A.T) |
||||
|
||||
h := sha512.New() |
||||
h.Write(sig[:32]) |
||||
h.Write(publicKey[:]) |
||||
h.Write(message) |
||||
var digest [64]byte |
||||
h.Sum(digest[:0]) |
||||
|
||||
var hReduced [32]byte |
||||
edwards25519.ScReduce(&hReduced, &digest) |
||||
|
||||
var R edwards25519.ProjectiveGroupElement |
||||
var b [32]byte |
||||
copy(b[:], sig[32:]) |
||||
edwards25519.GeDoubleScalarMultVartime(&R, &hReduced, &A, &b) |
||||
|
||||
var checkR [32]byte |
||||
R.ToBytes(&checkR) |
||||
return subtle.ConstantTimeCompare(sig[:32], checkR[:]) == 1 |
||||
} |
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,98 @@ |
||||
// Copyright 2012 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package ssh |
||||
|
||||
import ( |
||||
"io" |
||||
"sync" |
||||
) |
||||
|
||||
// buffer provides a linked list buffer for data exchange
|
||||
// between producer and consumer. Theoretically the buffer is
|
||||
// of unlimited capacity as it does no allocation of its own.
|
||||
type buffer struct { |
||||
// protects concurrent access to head, tail and closed
|
||||
*sync.Cond |
||||
|
||||
head *element // the buffer that will be read first
|
||||
tail *element // the buffer that will be read last
|
||||
|
||||
closed bool |
||||
} |
||||
|
||||
// An element represents a single link in a linked list.
|
||||
type element struct { |
||||
buf []byte |
||||
next *element |
||||
} |
||||
|
||||
// newBuffer returns an empty buffer that is not closed.
|
||||
func newBuffer() *buffer { |
||||
e := new(element) |
||||
b := &buffer{ |
||||
Cond: newCond(), |
||||
head: e, |
||||
tail: e, |
||||
} |
||||
return b |
||||
} |
||||
|
||||
// write makes buf available for Read to receive.
|
||||
// buf must not be modified after the call to write.
|
||||
func (b *buffer) write(buf []byte) { |
||||
b.Cond.L.Lock() |
||||
e := &element{buf: buf} |
||||
b.tail.next = e |
||||
b.tail = e |
||||
b.Cond.Signal() |
||||
b.Cond.L.Unlock() |
||||
} |
||||
|
||||
// eof closes the buffer. Reads from the buffer once all
|
||||
// the data has been consumed will receive os.EOF.
|
||||
func (b *buffer) eof() error { |
||||
b.Cond.L.Lock() |
||||
b.closed = true |
||||
b.Cond.Signal() |
||||
b.Cond.L.Unlock() |
||||
return nil |
||||
} |
||||
|
||||
// Read reads data from the internal buffer in buf. Reads will block
|
||||
// if no data is available, or until the buffer is closed.
|
||||
func (b *buffer) Read(buf []byte) (n int, err error) { |
||||
b.Cond.L.Lock() |
||||
defer b.Cond.L.Unlock() |
||||
|
||||
for len(buf) > 0 { |
||||
// if there is data in b.head, copy it
|
||||
if len(b.head.buf) > 0 { |
||||
r := copy(buf, b.head.buf) |
||||
buf, b.head.buf = buf[r:], b.head.buf[r:] |
||||
n += r |
||||
continue |
||||
} |
||||
// if there is a next buffer, make it the head
|
||||
if len(b.head.buf) == 0 && b.head != b.tail { |
||||
b.head = b.head.next |
||||
continue |
||||
} |
||||
|
||||
// if at least one byte has been copied, return
|
||||
if n > 0 { |
||||
break |
||||
} |
||||
|
||||
// if nothing was read, and there is nothing outstanding
|
||||
// check to see if the buffer is closed.
|
||||
if b.closed { |
||||
err = io.EOF |
||||
break |
||||
} |
||||
// out of buffers, wait for producer
|
||||
b.Cond.Wait() |
||||
} |
||||
return |
||||
} |
@ -0,0 +1,503 @@ |
||||
// Copyright 2012 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package ssh |
||||
|
||||
import ( |
||||
"bytes" |
||||
"errors" |
||||
"fmt" |
||||
"io" |
||||
"net" |
||||
"sort" |
||||
"time" |
||||
) |
||||
|
||||
// These constants from [PROTOCOL.certkeys] represent the algorithm names
|
||||
// for certificate types supported by this package.
|
||||
const ( |
||||
CertAlgoRSAv01 = "ssh-rsa-cert-v01@openssh.com" |
||||
CertAlgoDSAv01 = "ssh-dss-cert-v01@openssh.com" |
||||
CertAlgoECDSA256v01 = "ecdsa-sha2-nistp256-cert-v01@openssh.com" |
||||
CertAlgoECDSA384v01 = "ecdsa-sha2-nistp384-cert-v01@openssh.com" |
||||
CertAlgoECDSA521v01 = "ecdsa-sha2-nistp521-cert-v01@openssh.com" |
||||
CertAlgoED25519v01 = "ssh-ed25519-cert-v01@openssh.com" |
||||
) |
||||
|
||||
// Certificate types distinguish between host and user
|
||||
// certificates. The values can be set in the CertType field of
|
||||
// Certificate.
|
||||
const ( |
||||
UserCert = 1 |
||||
HostCert = 2 |
||||
) |
||||
|
||||
// Signature represents a cryptographic signature.
|
||||
type Signature struct { |
||||
Format string |
||||
Blob []byte |
||||
} |
||||
|
||||
// CertTimeInfinity can be used for OpenSSHCertV01.ValidBefore to indicate that
|
||||
// a certificate does not expire.
|
||||
const CertTimeInfinity = 1<<64 - 1 |
||||
|
||||
// An Certificate represents an OpenSSH certificate as defined in
|
||||
// [PROTOCOL.certkeys]?rev=1.8.
|
||||
type Certificate struct { |
||||
Nonce []byte |
||||
Key PublicKey |
||||
Serial uint64 |
||||
CertType uint32 |
||||
KeyId string |
||||
ValidPrincipals []string |
||||
ValidAfter uint64 |
||||
ValidBefore uint64 |
||||
Permissions |
||||
Reserved []byte |
||||
SignatureKey PublicKey |
||||
Signature *Signature |
||||
} |
||||
|
||||
// genericCertData holds the key-independent part of the certificate data.
|
||||
// Overall, certificates contain an nonce, public key fields and
|
||||
// key-independent fields.
|
||||
type genericCertData struct { |
||||
Serial uint64 |
||||
CertType uint32 |
||||
KeyId string |
||||
ValidPrincipals []byte |
||||
ValidAfter uint64 |
||||
ValidBefore uint64 |
||||
CriticalOptions []byte |
||||
Extensions []byte |
||||
Reserved []byte |
||||
SignatureKey []byte |
||||
Signature []byte |
||||
} |
||||
|
||||
func marshalStringList(namelist []string) []byte { |
||||
var to []byte |
||||
for _, name := range namelist { |
||||
s := struct{ N string }{name} |
||||
to = append(to, Marshal(&s)...) |
||||
} |
||||
return to |
||||
} |
||||
|
||||
type optionsTuple struct { |
||||
Key string |
||||
Value []byte |
||||
} |
||||
|
||||
type optionsTupleValue struct { |
||||
Value string |
||||
} |
||||
|
||||
// serialize a map of critical options or extensions
|
||||
// issue #10569 - per [PROTOCOL.certkeys] and SSH implementation,
|
||||
// we need two length prefixes for a non-empty string value
|
||||
func marshalTuples(tups map[string]string) []byte { |
||||
keys := make([]string, 0, len(tups)) |
||||
for key := range tups { |
||||
keys = append(keys, key) |
||||
} |
||||
sort.Strings(keys) |
||||
|
||||
var ret []byte |
||||
for _, key := range keys { |
||||
s := optionsTuple{Key: key} |
||||
if value := tups[key]; len(value) > 0 { |
||||
s.Value = Marshal(&optionsTupleValue{value}) |
||||
} |
||||
ret = append(ret, Marshal(&s)...) |
||||
} |
||||
return ret |
||||
} |
||||
|
||||
// issue #10569 - per [PROTOCOL.certkeys] and SSH implementation,
|
||||
// we need two length prefixes for a non-empty option value
|
||||
func parseTuples(in []byte) (map[string]string, error) { |
||||
tups := map[string]string{} |
||||
var lastKey string |
||||
var haveLastKey bool |
||||
|
||||
for len(in) > 0 { |
||||
var key, val, extra []byte |
||||
var ok bool |
||||
|
||||
if key, in, ok = parseString(in); !ok { |
||||
return nil, errShortRead |
||||
} |
||||
keyStr := string(key) |
||||
// according to [PROTOCOL.certkeys], the names must be in
|
||||
// lexical order.
|
||||
if haveLastKey && keyStr <= lastKey { |
||||
return nil, fmt.Errorf("ssh: certificate options are not in lexical order") |
||||
} |
||||
lastKey, haveLastKey = keyStr, true |
||||
// the next field is a data field, which if non-empty has a string embedded
|
||||
if val, in, ok = parseString(in); !ok { |
||||
return nil, errShortRead |
||||
} |
||||
if len(val) > 0 { |
||||
val, extra, ok = parseString(val) |
||||
if !ok { |
||||
return nil, errShortRead |
||||
} |
||||
if len(extra) > 0 { |
||||
return nil, fmt.Errorf("ssh: unexpected trailing data after certificate option value") |
||||
} |
||||
tups[keyStr] = string(val) |
||||
} else { |
||||
tups[keyStr] = "" |
||||
} |
||||
} |
||||
return tups, nil |
||||
} |
||||
|
||||
func parseCert(in []byte, privAlgo string) (*Certificate, error) { |
||||
nonce, rest, ok := parseString(in) |
||||
if !ok { |
||||
return nil, errShortRead |
||||
} |
||||
|
||||
key, rest, err := parsePubKey(rest, privAlgo) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
var g genericCertData |
||||
if err := Unmarshal(rest, &g); err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
c := &Certificate{ |
||||
Nonce: nonce, |
||||
Key: key, |
||||
Serial: g.Serial, |
||||
CertType: g.CertType, |
||||
KeyId: g.KeyId, |
||||
ValidAfter: g.ValidAfter, |
||||
ValidBefore: g.ValidBefore, |
||||
} |
||||
|
||||
for principals := g.ValidPrincipals; len(principals) > 0; { |
||||
principal, rest, ok := parseString(principals) |
||||
if !ok { |
||||
return nil, errShortRead |
||||
} |
||||
c.ValidPrincipals = append(c.ValidPrincipals, string(principal)) |
||||
principals = rest |
||||
} |
||||
|
||||
c.CriticalOptions, err = parseTuples(g.CriticalOptions) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
c.Extensions, err = parseTuples(g.Extensions) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
c.Reserved = g.Reserved |
||||
k, err := ParsePublicKey(g.SignatureKey) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
c.SignatureKey = k |
||||
c.Signature, rest, ok = parseSignatureBody(g.Signature) |
||||
if !ok || len(rest) > 0 { |
||||
return nil, errors.New("ssh: signature parse error") |
||||
} |
||||
|
||||
return c, nil |
||||
} |
||||
|
||||
type openSSHCertSigner struct { |
||||
pub *Certificate |
||||
signer Signer |
||||
} |
||||
|
||||
// NewCertSigner returns a Signer that signs with the given Certificate, whose
|
||||
// private key is held by signer. It returns an error if the public key in cert
|
||||
// doesn't match the key used by signer.
|
||||
func NewCertSigner(cert *Certificate, signer Signer) (Signer, error) { |
||||
if bytes.Compare(cert.Key.Marshal(), signer.PublicKey().Marshal()) != 0 { |
||||
return nil, errors.New("ssh: signer and cert have different public key") |
||||
} |
||||
|
||||
return &openSSHCertSigner{cert, signer}, nil |
||||
} |
||||
|
||||
func (s *openSSHCertSigner) Sign(rand io.Reader, data []byte) (*Signature, error) { |
||||
return s.signer.Sign(rand, data) |
||||
} |
||||
|
||||
func (s *openSSHCertSigner) PublicKey() PublicKey { |
||||
return s.pub |
||||
} |
||||
|
||||
const sourceAddressCriticalOption = "source-address" |
||||
|
||||
// CertChecker does the work of verifying a certificate. Its methods
|
||||
// can be plugged into ClientConfig.HostKeyCallback and
|
||||
// ServerConfig.PublicKeyCallback. For the CertChecker to work,
|
||||
// minimally, the IsAuthority callback should be set.
|
||||
type CertChecker struct { |
||||
// SupportedCriticalOptions lists the CriticalOptions that the
|
||||
// server application layer understands. These are only used
|
||||
// for user certificates.
|
||||
SupportedCriticalOptions []string |
||||
|
||||
// IsAuthority should return true if the key is recognized as
|
||||
// an authority. This allows for certificates to be signed by other
|
||||
// certificates.
|
||||
IsAuthority func(auth PublicKey) bool |
||||
|
||||
// Clock is used for verifying time stamps. If nil, time.Now
|
||||
// is used.
|
||||
Clock func() time.Time |
||||
|
||||
// UserKeyFallback is called when CertChecker.Authenticate encounters a
|
||||
// public key that is not a certificate. It must implement validation
|
||||
// of user keys or else, if nil, all such keys are rejected.
|
||||
UserKeyFallback func(conn ConnMetadata, key PublicKey) (*Permissions, error) |
||||
|
||||
// HostKeyFallback is called when CertChecker.CheckHostKey encounters a
|
||||
// public key that is not a certificate. It must implement host key
|
||||
// validation or else, if nil, all such keys are rejected.
|
||||
HostKeyFallback func(addr string, remote net.Addr, key PublicKey) error |
||||
|
||||
// IsRevoked is called for each certificate so that revocation checking
|
||||
// can be implemented. It should return true if the given certificate
|
||||
// is revoked and false otherwise. If nil, no certificates are
|
||||
// considered to have been revoked.
|
||||
IsRevoked func(cert *Certificate) bool |
||||
} |
||||
|
||||
// CheckHostKey checks a host key certificate. This method can be
|
||||
// plugged into ClientConfig.HostKeyCallback.
|
||||
func (c *CertChecker) CheckHostKey(addr string, remote net.Addr, key PublicKey) error { |
||||
cert, ok := key.(*Certificate) |
||||
if !ok { |
||||
if c.HostKeyFallback != nil { |
||||
return c.HostKeyFallback(addr, remote, key) |
||||
} |
||||
return errors.New("ssh: non-certificate host key") |
||||
} |
||||
if cert.CertType != HostCert { |
||||
return fmt.Errorf("ssh: certificate presented as a host key has type %d", cert.CertType) |
||||
} |
||||
|
||||
return c.CheckCert(addr, cert) |
||||
} |
||||
|
||||
// Authenticate checks a user certificate. Authenticate can be used as
|
||||
// a value for ServerConfig.PublicKeyCallback.
|
||||
func (c *CertChecker) Authenticate(conn ConnMetadata, pubKey PublicKey) (*Permissions, error) { |
||||
cert, ok := pubKey.(*Certificate) |
||||
if !ok { |
||||
if c.UserKeyFallback != nil { |
||||
return c.UserKeyFallback(conn, pubKey) |
||||
} |
||||
return nil, errors.New("ssh: normal key pairs not accepted") |
||||
} |
||||
|
||||
if cert.CertType != UserCert { |
||||
return nil, fmt.Errorf("ssh: cert has type %d", cert.CertType) |
||||
} |
||||
|
||||
if err := c.CheckCert(conn.User(), cert); err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
return &cert.Permissions, nil |
||||
} |
||||
|
||||
// CheckCert checks CriticalOptions, ValidPrincipals, revocation, timestamp and
|
||||
// the signature of the certificate.
|
||||
func (c *CertChecker) CheckCert(principal string, cert *Certificate) error { |
||||
if c.IsRevoked != nil && c.IsRevoked(cert) { |
||||
return fmt.Errorf("ssh: certicate serial %d revoked", cert.Serial) |
||||
} |
||||
|
||||
for opt, _ := range cert.CriticalOptions { |
||||
// sourceAddressCriticalOption will be enforced by
|
||||
// serverAuthenticate
|
||||
if opt == sourceAddressCriticalOption { |
||||
continue |
||||
} |
||||
|
||||
found := false |
||||
for _, supp := range c.SupportedCriticalOptions { |
||||
if supp == opt { |
||||
found = true |
||||
break |
||||
} |
||||
} |
||||
if !found { |
||||
return fmt.Errorf("ssh: unsupported critical option %q in certificate", opt) |
||||
} |
||||
} |
||||
|
||||
if len(cert.ValidPrincipals) > 0 { |
||||
// By default, certs are valid for all users/hosts.
|
||||
found := false |
||||
for _, p := range cert.ValidPrincipals { |
||||
if p == principal { |
||||
found = true |
||||
break |
||||
} |
||||
} |
||||
if !found { |
||||
return fmt.Errorf("ssh: principal %q not in the set of valid principals for given certificate: %q", principal, cert.ValidPrincipals) |
||||
} |
||||
} |
||||
|
||||
if !c.IsAuthority(cert.SignatureKey) { |
||||
return fmt.Errorf("ssh: certificate signed by unrecognized authority") |
||||
} |
||||
|
||||
clock := c.Clock |
||||
if clock == nil { |
||||
clock = time.Now |
||||
} |
||||
|
||||
unixNow := clock().Unix() |
||||
if after := int64(cert.ValidAfter); after < 0 || unixNow < int64(cert.ValidAfter) { |
||||
return fmt.Errorf("ssh: cert is not yet valid") |
||||
} |
||||
if before := int64(cert.ValidBefore); cert.ValidBefore != uint64(CertTimeInfinity) && (unixNow >= before || before < 0) { |
||||
return fmt.Errorf("ssh: cert has expired") |
||||
} |
||||
if err := cert.SignatureKey.Verify(cert.bytesForSigning(), cert.Signature); err != nil { |
||||
return fmt.Errorf("ssh: certificate signature does not verify") |
||||
} |
||||
|
||||
return nil |
||||
} |
||||
|
||||
// SignCert sets c.SignatureKey to the authority's public key and stores a
|
||||
// Signature, by authority, in the certificate.
|
||||
func (c *Certificate) SignCert(rand io.Reader, authority Signer) error { |
||||
c.Nonce = make([]byte, 32) |
||||
if _, err := io.ReadFull(rand, c.Nonce); err != nil { |
||||
return err |
||||
} |
||||
c.SignatureKey = authority.PublicKey() |
||||
|
||||
sig, err := authority.Sign(rand, c.bytesForSigning()) |
||||
if err != nil { |
||||
return err |
||||
} |
||||
c.Signature = sig |
||||
return nil |
||||
} |
||||
|
||||
var certAlgoNames = map[string]string{ |
||||
KeyAlgoRSA: CertAlgoRSAv01, |
||||
KeyAlgoDSA: CertAlgoDSAv01, |
||||
KeyAlgoECDSA256: CertAlgoECDSA256v01, |
||||
KeyAlgoECDSA384: CertAlgoECDSA384v01, |
||||
KeyAlgoECDSA521: CertAlgoECDSA521v01, |
||||
KeyAlgoED25519: CertAlgoED25519v01, |
||||
} |
||||
|
||||
// certToPrivAlgo returns the underlying algorithm for a certificate algorithm.
|
||||
// Panics if a non-certificate algorithm is passed.
|
||||
func certToPrivAlgo(algo string) string { |
||||
for privAlgo, pubAlgo := range certAlgoNames { |
||||
if pubAlgo == algo { |
||||
return privAlgo |
||||
} |
||||
} |
||||
panic("unknown cert algorithm") |
||||
} |
||||
|
||||
func (cert *Certificate) bytesForSigning() []byte { |
||||
c2 := *cert |
||||
c2.Signature = nil |
||||
out := c2.Marshal() |
||||
// Drop trailing signature length.
|
||||
return out[:len(out)-4] |
||||
} |
||||
|
||||
// Marshal serializes c into OpenSSH's wire format. It is part of the
|
||||
// PublicKey interface.
|
||||
func (c *Certificate) Marshal() []byte { |
||||
generic := genericCertData{ |
||||
Serial: c.Serial, |
||||
CertType: c.CertType, |
||||
KeyId: c.KeyId, |
||||
ValidPrincipals: marshalStringList(c.ValidPrincipals), |
||||
ValidAfter: uint64(c.ValidAfter), |
||||
ValidBefore: uint64(c.ValidBefore), |
||||
CriticalOptions: marshalTuples(c.CriticalOptions), |
||||
Extensions: marshalTuples(c.Extensions), |
||||
Reserved: c.Reserved, |
||||
SignatureKey: c.SignatureKey.Marshal(), |
||||
} |
||||
if c.Signature != nil { |
||||
generic.Signature = Marshal(c.Signature) |
||||
} |
||||
genericBytes := Marshal(&generic) |
||||
keyBytes := c.Key.Marshal() |
||||
_, keyBytes, _ = parseString(keyBytes) |
||||
prefix := Marshal(&struct { |
||||
Name string |
||||
Nonce []byte |
||||
Key []byte `ssh:"rest"` |
||||
}{c.Type(), c.Nonce, keyBytes}) |
||||
|
||||
result := make([]byte, 0, len(prefix)+len(genericBytes)) |
||||
result = append(result, prefix...) |
||||
result = append(result, genericBytes...) |
||||
return result |
||||
} |
||||
|
||||
// Type returns the key name. It is part of the PublicKey interface.
|
||||
func (c *Certificate) Type() string { |
||||
algo, ok := certAlgoNames[c.Key.Type()] |
||||
if !ok { |
||||
panic("unknown cert key type " + c.Key.Type()) |
||||
} |
||||
return algo |
||||
} |
||||
|
||||
// Verify verifies a signature against the certificate's public
|
||||
// key. It is part of the PublicKey interface.
|
||||
func (c *Certificate) Verify(data []byte, sig *Signature) error { |
||||
return c.Key.Verify(data, sig) |
||||
} |
||||
|
||||
func parseSignatureBody(in []byte) (out *Signature, rest []byte, ok bool) { |
||||
format, in, ok := parseString(in) |
||||
if !ok { |
||||
return |
||||
} |
||||
|
||||
out = &Signature{ |
||||
Format: string(format), |
||||
} |
||||
|
||||
if out.Blob, in, ok = parseString(in); !ok { |
||||
return |
||||
} |
||||
|
||||
return out, in, ok |
||||
} |
||||
|
||||
func parseSignature(in []byte) (out *Signature, rest []byte, ok bool) { |
||||
sigBytes, rest, ok := parseString(in) |
||||
if !ok { |
||||
return |
||||
} |
||||
|
||||
out, trailing, ok := parseSignatureBody(sigBytes) |
||||
if !ok || len(trailing) > 0 { |
||||
return nil, nil, false |
||||
} |
||||
return |
||||
} |
@ -0,0 +1,633 @@ |
||||
// Copyright 2011 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package ssh |
||||
|
||||
import ( |
||||
"encoding/binary" |
||||
"errors" |
||||
"fmt" |
||||
"io" |
||||
"log" |
||||
"sync" |
||||
) |
||||
|
||||
const ( |
||||
minPacketLength = 9 |
||||
// channelMaxPacket contains the maximum number of bytes that will be
|
||||
// sent in a single packet. As per RFC 4253, section 6.1, 32k is also
|
||||
// the minimum.
|
||||
channelMaxPacket = 1 << 15 |
||||
// We follow OpenSSH here.
|
||||
channelWindowSize = 64 * channelMaxPacket |
||||
) |
||||
|
||||
// NewChannel represents an incoming request to a channel. It must either be
|
||||
// accepted for use by calling Accept, or rejected by calling Reject.
|
||||
type NewChannel interface { |
||||
// Accept accepts the channel creation request. It returns the Channel
|
||||
// and a Go channel containing SSH requests. The Go channel must be
|
||||
// serviced otherwise the Channel will hang.
|
||||
Accept() (Channel, <-chan *Request, error) |
||||
|
||||
// Reject rejects the channel creation request. After calling
|
||||
// this, no other methods on the Channel may be called.
|
||||
Reject(reason RejectionReason, message string) error |
||||
|
||||
// ChannelType returns the type of the channel, as supplied by the
|
||||
// client.
|
||||
ChannelType() string |
||||
|
||||
// ExtraData returns the arbitrary payload for this channel, as supplied
|
||||
// by the client. This data is specific to the channel type.
|
||||
ExtraData() []byte |
||||
} |
||||
|
||||
// A Channel is an ordered, reliable, flow-controlled, duplex stream
|
||||
// that is multiplexed over an SSH connection.
|
||||
type Channel interface { |
||||
// Read reads up to len(data) bytes from the channel.
|
||||
Read(data []byte) (int, error) |
||||
|
||||
// Write writes len(data) bytes to the channel.
|
||||
Write(data []byte) (int, error) |
||||
|
||||
// Close signals end of channel use. No data may be sent after this
|
||||
// call.
|
||||
Close() error |
||||
|
||||
// CloseWrite signals the end of sending in-band
|
||||
// data. Requests may still be sent, and the other side may
|
||||
// still send data
|
||||
CloseWrite() error |
||||
|
||||
// SendRequest sends a channel request. If wantReply is true,
|
||||
// it will wait for a reply and return the result as a
|
||||
// boolean, otherwise the return value will be false. Channel
|
||||
// requests are out-of-band messages so they may be sent even
|
||||
// if the data stream is closed or blocked by flow control.
|
||||
// If the channel is closed before a reply is returned, io.EOF
|
||||
// is returned.
|
||||
SendRequest(name string, wantReply bool, payload []byte) (bool, error) |
||||
|
||||
// Stderr returns an io.ReadWriter that writes to this channel
|
||||
// with the extended data type set to stderr. Stderr may
|
||||
// safely be read and written from a different goroutine than
|
||||
// Read and Write respectively.
|
||||
Stderr() io.ReadWriter |
||||
} |
||||
|
||||
// Request is a request sent outside of the normal stream of
|
||||
// data. Requests can either be specific to an SSH channel, or they
|
||||
// can be global.
|
||||
type Request struct { |
||||
Type string |
||||
WantReply bool |
||||
Payload []byte |
||||
|
||||
ch *channel |
||||
mux *mux |
||||
} |
||||
|
||||
// Reply sends a response to a request. It must be called for all requests
|
||||
// where WantReply is true and is a no-op otherwise. The payload argument is
|
||||
// ignored for replies to channel-specific requests.
|
||||
func (r *Request) Reply(ok bool, payload []byte) error { |
||||
if !r.WantReply { |
||||
return nil |
||||
} |
||||
|
||||
if r.ch == nil { |
||||
return r.mux.ackRequest(ok, payload) |
||||
} |
||||
|
||||
return r.ch.ackRequest(ok) |
||||
} |
||||
|
||||
// RejectionReason is an enumeration used when rejecting channel creation
|
||||
// requests. See RFC 4254, section 5.1.
|
||||
type RejectionReason uint32 |
||||
|
||||
const ( |
||||
Prohibited RejectionReason = iota + 1 |
||||
ConnectionFailed |
||||
UnknownChannelType |
||||
ResourceShortage |
||||
) |
||||
|
||||
// String converts the rejection reason to human readable form.
|
||||
func (r RejectionReason) String() string { |
||||
switch r { |
||||
case Prohibited: |
||||
return "administratively prohibited" |
||||
case ConnectionFailed: |
||||
return "connect failed" |
||||
case UnknownChannelType: |
||||
return "unknown channel type" |
||||
case ResourceShortage: |
||||
return "resource shortage" |
||||
} |
||||
return fmt.Sprintf("unknown reason %d", int(r)) |
||||
} |
||||
|
||||
func min(a uint32, b int) uint32 { |
||||
if a < uint32(b) { |
||||
return a |
||||
} |
||||
return uint32(b) |
||||
} |
||||
|
||||
type channelDirection uint8 |
||||
|
||||
const ( |
||||
channelInbound channelDirection = iota |
||||
channelOutbound |
||||
) |
||||
|
||||
// channel is an implementation of the Channel interface that works
|
||||
// with the mux class.
|
||||
type channel struct { |
||||
// R/O after creation
|
||||
chanType string |
||||
extraData []byte |
||||
localId, remoteId uint32 |
||||
|
||||
// maxIncomingPayload and maxRemotePayload are the maximum
|
||||
// payload sizes of normal and extended data packets for
|
||||
// receiving and sending, respectively. The wire packet will
|
||||
// be 9 or 13 bytes larger (excluding encryption overhead).
|
||||
maxIncomingPayload uint32 |
||||
maxRemotePayload uint32 |
||||
|
||||
mux *mux |
||||
|
||||
// decided is set to true if an accept or reject message has been sent
|
||||
// (for outbound channels) or received (for inbound channels).
|
||||
decided bool |
||||
|
||||
// direction contains either channelOutbound, for channels created
|
||||
// locally, or channelInbound, for channels created by the peer.
|
||||
direction channelDirection |
||||
|
||||
// Pending internal channel messages.
|
||||
msg chan interface{} |
||||
|
||||
// Since requests have no ID, there can be only one request
|
||||
// with WantReply=true outstanding. This lock is held by a
|
||||
// goroutine that has such an outgoing request pending.
|
||||
sentRequestMu sync.Mutex |
||||
|
||||
incomingRequests chan *Request |
||||
|
||||
sentEOF bool |
||||
|
||||
// thread-safe data
|
||||
remoteWin window |
||||
pending *buffer |
||||
extPending *buffer |
||||
|
||||
// windowMu protects myWindow, the flow-control window.
|
||||
windowMu sync.Mutex |
||||
myWindow uint32 |
||||
|
||||
// writeMu serializes calls to mux.conn.writePacket() and
|
||||
// protects sentClose and packetPool. This mutex must be
|
||||
// different from windowMu, as writePacket can block if there
|
||||
// is a key exchange pending.
|
||||
writeMu sync.Mutex |
||||
sentClose bool |
||||
|
||||
// packetPool has a buffer for each extended channel ID to
|
||||
// save allocations during writes.
|
||||
packetPool map[uint32][]byte |
||||
} |
||||
|
||||
// writePacket sends a packet. If the packet is a channel close, it updates
|
||||
// sentClose. This method takes the lock c.writeMu.
|
||||
func (c *channel) writePacket(packet []byte) error { |
||||
c.writeMu.Lock() |
||||
if c.sentClose { |
||||
c.writeMu.Unlock() |
||||
return io.EOF |
||||
} |
||||
c.sentClose = (packet[0] == msgChannelClose) |
||||
err := c.mux.conn.writePacket(packet) |
||||
c.writeMu.Unlock() |
||||
return err |
||||
} |
||||
|
||||
func (c *channel) sendMessage(msg interface{}) error { |
||||
if debugMux { |
||||
log.Printf("send(%d): %#v", c.mux.chanList.offset, msg) |
||||
} |
||||
|
||||
p := Marshal(msg) |
||||
binary.BigEndian.PutUint32(p[1:], c.remoteId) |
||||
return c.writePacket(p) |
||||
} |
||||
|
||||
// WriteExtended writes data to a specific extended stream. These streams are
|
||||
// used, for example, for stderr.
|
||||
func (c *channel) WriteExtended(data []byte, extendedCode uint32) (n int, err error) { |
||||
if c.sentEOF { |
||||
return 0, io.EOF |
||||
} |
||||
// 1 byte message type, 4 bytes remoteId, 4 bytes data length
|
||||
opCode := byte(msgChannelData) |
||||
headerLength := uint32(9) |
||||
if extendedCode > 0 { |
||||
headerLength += 4 |
||||
opCode = msgChannelExtendedData |
||||
} |
||||
|
||||
c.writeMu.Lock() |
||||
packet := c.packetPool[extendedCode] |
||||
// We don't remove the buffer from packetPool, so
|
||||
// WriteExtended calls from different goroutines will be
|
||||
// flagged as errors by the race detector.
|
||||
c.writeMu.Unlock() |
||||
|
||||
for len(data) > 0 { |
||||
space := min(c.maxRemotePayload, len(data)) |
||||
if space, err = c.remoteWin.reserve(space); err != nil { |
||||
return n, err |
||||
} |
||||
if want := headerLength + space; uint32(cap(packet)) < want { |
||||
packet = make([]byte, want) |
||||
} else { |
||||
packet = packet[:want] |
||||
} |
||||
|
||||
todo := data[:space] |
||||
|
||||
packet[0] = opCode |
||||
binary.BigEndian.PutUint32(packet[1:], c.remoteId) |
||||
if extendedCode > 0 { |
||||
binary.BigEndian.PutUint32(packet[5:], uint32(extendedCode)) |
||||
} |
||||
binary.BigEndian.PutUint32(packet[headerLength-4:], uint32(len(todo))) |
||||
copy(packet[headerLength:], todo) |
||||
if err = c.writePacket(packet); err != nil { |
||||
return n, err |
||||
} |
||||
|
||||
n += len(todo) |
||||
data = data[len(todo):] |
||||
} |
||||
|
||||
c.writeMu.Lock() |
||||
c.packetPool[extendedCode] = packet |
||||
c.writeMu.Unlock() |
||||
|
||||
return n, err |
||||
} |
||||
|
||||
func (c *channel) handleData(packet []byte) error { |
||||
headerLen := 9 |
||||
isExtendedData := packet[0] == msgChannelExtendedData |
||||
if isExtendedData { |
||||
headerLen = 13 |
||||
} |
||||
if len(packet) < headerLen { |
||||
// malformed data packet
|
||||
return parseError(packet[0]) |
||||
} |
||||
|
||||
var extended uint32 |
||||
if isExtendedData { |
||||
extended = binary.BigEndian.Uint32(packet[5:]) |
||||
} |
||||
|
||||
length := binary.BigEndian.Uint32(packet[headerLen-4 : headerLen]) |
||||
if length == 0 { |
||||
return nil |
||||
} |
||||
if length > c.maxIncomingPayload { |
||||
// TODO(hanwen): should send Disconnect?
|
||||
return errors.New("ssh: incoming packet exceeds maximum payload size") |
||||
} |
||||
|
||||
data := packet[headerLen:] |
||||
if length != uint32(len(data)) { |
||||
return errors.New("ssh: wrong packet length") |
||||
} |
||||
|
||||
c.windowMu.Lock() |
||||
if c.myWindow < length { |
||||
c.windowMu.Unlock() |
||||
// TODO(hanwen): should send Disconnect with reason?
|
||||
return errors.New("ssh: remote side wrote too much") |
||||
} |
||||
c.myWindow -= length |
||||
c.windowMu.Unlock() |
||||
|
||||
if extended == 1 { |
||||
c.extPending.write(data) |
||||
} else if extended > 0 { |
||||
// discard other extended data.
|
||||
} else { |
||||
c.pending.write(data) |
||||
} |
||||
return nil |
||||
} |
||||
|
||||
func (c *channel) adjustWindow(n uint32) error { |
||||
c.windowMu.Lock() |
||||
// Since myWindow is managed on our side, and can never exceed
|
||||
// the initial window setting, we don't worry about overflow.
|
||||
c.myWindow += uint32(n) |
||||
c.windowMu.Unlock() |
||||
return c.sendMessage(windowAdjustMsg{ |
||||
AdditionalBytes: uint32(n), |
||||
}) |
||||
} |
||||
|
||||
func (c *channel) ReadExtended(data []byte, extended uint32) (n int, err error) { |
||||
switch extended { |
||||
case 1: |
||||
n, err = c.extPending.Read(data) |
||||
case 0: |
||||
n, err = c.pending.Read(data) |
||||
default: |
||||
return 0, fmt.Errorf("ssh: extended code %d unimplemented", extended) |
||||
} |
||||
|
||||
if n > 0 { |
||||
err = c.adjustWindow(uint32(n)) |
||||
// sendWindowAdjust can return io.EOF if the remote
|
||||
// peer has closed the connection, however we want to
|
||||
// defer forwarding io.EOF to the caller of Read until
|
||||
// the buffer has been drained.
|
||||
if n > 0 && err == io.EOF { |
||||
err = nil |
||||
} |
||||
} |
||||
|
||||
return n, err |
||||
} |
||||
|
||||
func (c *channel) close() { |
||||
c.pending.eof() |
||||
c.extPending.eof() |
||||
close(c.msg) |
||||
close(c.incomingRequests) |
||||
c.writeMu.Lock() |
||||
// This is not necessary for a normal channel teardown, but if
|
||||
// there was another error, it is.
|
||||
c.sentClose = true |
||||
c.writeMu.Unlock() |
||||
// Unblock writers.
|
||||
c.remoteWin.close() |
||||
} |
||||
|
||||
// responseMessageReceived is called when a success or failure message is
|
||||
// received on a channel to check that such a message is reasonable for the
|
||||
// given channel.
|
||||
func (c *channel) responseMessageReceived() error { |
||||
if c.direction == channelInbound { |
||||
return errors.New("ssh: channel response message received on inbound channel") |
||||
} |
||||
if c.decided { |
||||
return errors.New("ssh: duplicate response received for channel") |
||||
} |
||||
c.decided = true |
||||
return nil |
||||
} |
||||
|
||||
func (c *channel) handlePacket(packet []byte) error { |
||||
switch packet[0] { |
||||
case msgChannelData, msgChannelExtendedData: |
||||
return c.handleData(packet) |
||||
case msgChannelClose: |
||||
c.sendMessage(channelCloseMsg{PeersId: c.remoteId}) |
||||
c.mux.chanList.remove(c.localId) |
||||
c.close() |
||||
return nil |
||||
case msgChannelEOF: |
||||
// RFC 4254 is mute on how EOF affects dataExt messages but
|
||||
// it is logical to signal EOF at the same time.
|
||||
c.extPending.eof() |
||||
c.pending.eof() |
||||
return nil |
||||
} |
||||
|
||||
decoded, err := decode(packet) |
||||
if err != nil { |
||||
return err |
||||
} |
||||
|
||||
switch msg := decoded.(type) { |
||||
case *channelOpenFailureMsg: |
||||
if err := c.responseMessageReceived(); err != nil { |
||||
return err |
||||
} |
||||
c.mux.chanList.remove(msg.PeersId) |
||||
c.msg <- msg |
||||
case *channelOpenConfirmMsg: |
||||
if err := c.responseMessageReceived(); err != nil { |
||||
return err |
||||
} |
||||
if msg.MaxPacketSize < minPacketLength || msg.MaxPacketSize > 1<<31 { |
||||
return fmt.Errorf("ssh: invalid MaxPacketSize %d from peer", msg.MaxPacketSize) |
||||
} |
||||
c.remoteId = msg.MyId |
||||
c.maxRemotePayload = msg.MaxPacketSize |
||||
c.remoteWin.add(msg.MyWindow) |
||||
c.msg <- msg |
||||
case *windowAdjustMsg: |
||||
if !c.remoteWin.add(msg.AdditionalBytes) { |
||||
return fmt.Errorf("ssh: invalid window update for %d bytes", msg.AdditionalBytes) |
||||
} |
||||
case *channelRequestMsg: |
||||
req := Request{ |
||||
Type: msg.Request, |
||||
WantReply: msg.WantReply, |
||||
Payload: msg.RequestSpecificData, |
||||
ch: c, |
||||
} |
||||
|
||||
c.incomingRequests <- &req |
||||
default: |
||||
c.msg <- msg |
||||
} |
||||
return nil |
||||
} |
||||
|
||||
func (m *mux) newChannel(chanType string, direction channelDirection, extraData []byte) *channel { |
||||
ch := &channel{ |
||||
remoteWin: window{Cond: newCond()}, |
||||
myWindow: channelWindowSize, |
||||
pending: newBuffer(), |
||||
extPending: newBuffer(), |
||||
direction: direction, |
||||
incomingRequests: make(chan *Request, chanSize), |
||||
msg: make(chan interface{}, chanSize), |
||||
chanType: chanType, |
||||
extraData: extraData, |
||||
mux: m, |
||||
packetPool: make(map[uint32][]byte), |
||||
} |
||||
ch.localId = m.chanList.add(ch) |
||||
return ch |
||||
} |
||||
|
||||
var errUndecided = errors.New("ssh: must Accept or Reject channel") |
||||
var errDecidedAlready = errors.New("ssh: can call Accept or Reject only once") |
||||
|
||||
type extChannel struct { |
||||
code uint32 |
||||
ch *channel |
||||
} |
||||
|
||||
func (e *extChannel) Write(data []byte) (n int, err error) { |
||||
return e.ch.WriteExtended(data, e.code) |
||||
} |
||||
|
||||
func (e *extChannel) Read(data []byte) (n int, err error) { |
||||
return e.ch.ReadExtended(data, e.code) |
||||
} |
||||
|
||||
func (c *channel) Accept() (Channel, <-chan *Request, error) { |
||||
if c.decided { |
||||
return nil, nil, errDecidedAlready |
||||
} |
||||
c.maxIncomingPayload = channelMaxPacket |
||||
confirm := channelOpenConfirmMsg{ |
||||
PeersId: c.remoteId, |
||||
MyId: c.localId, |
||||
MyWindow: c.myWindow, |
||||
MaxPacketSize: c.maxIncomingPayload, |
||||
} |
||||
c.decided = true |
||||
if err := c.sendMessage(confirm); err != nil { |
||||
return nil, nil, err |
||||
} |
||||
|
||||
return c, c.incomingRequests, nil |
||||
} |
||||
|
||||
func (ch *channel) Reject(reason RejectionReason, message string) error { |
||||
if ch.decided { |
||||
return errDecidedAlready |
||||
} |
||||
reject := channelOpenFailureMsg{ |
||||
PeersId: ch.remoteId, |
||||
Reason: reason, |
||||
Message: message, |
||||
Language: "en", |
||||
} |
||||
ch.decided = true |
||||
return ch.sendMessage(reject) |
||||
} |
||||
|
||||
func (ch *channel) Read(data []byte) (int, error) { |
||||
if !ch.decided { |
||||
return 0, errUndecided |
||||
} |
||||
return ch.ReadExtended(data, 0) |
||||
} |
||||
|
||||
func (ch *channel) Write(data []byte) (int, error) { |
||||
if !ch.decided { |
||||
return 0, errUndecided |
||||
} |
||||
return ch.WriteExtended(data, 0) |
||||
} |
||||
|
||||
func (ch *channel) CloseWrite() error { |
||||
if !ch.decided { |
||||
return errUndecided |
||||
} |
||||
ch.sentEOF = true |
||||
return ch.sendMessage(channelEOFMsg{ |
||||
PeersId: ch.remoteId}) |
||||
} |
||||
|
||||
func (ch *channel) Close() error { |
||||
if !ch.decided { |
||||
return errUndecided |
||||
} |
||||
|
||||
return ch.sendMessage(channelCloseMsg{ |
||||
PeersId: ch.remoteId}) |
||||
} |
||||
|
||||
// Extended returns an io.ReadWriter that sends and receives data on the given,
|
||||
// SSH extended stream. Such streams are used, for example, for stderr.
|
||||
func (ch *channel) Extended(code uint32) io.ReadWriter { |
||||
if !ch.decided { |
||||
return nil |
||||
} |
||||
return &extChannel{code, ch} |
||||
} |
||||
|
||||
func (ch *channel) Stderr() io.ReadWriter { |
||||
return ch.Extended(1) |
||||
} |
||||
|
||||
func (ch *channel) SendRequest(name string, wantReply bool, payload []byte) (bool, error) { |
||||
if !ch.decided { |
||||
return false, errUndecided |
||||
} |
||||
|
||||
if wantReply { |
||||
ch.sentRequestMu.Lock() |
||||
defer ch.sentRequestMu.Unlock() |
||||
} |
||||
|
||||
msg := channelRequestMsg{ |
||||
PeersId: ch.remoteId, |
||||
Request: name, |
||||
WantReply: wantReply, |
||||
RequestSpecificData: payload, |
||||
} |
||||
|
||||
if err := ch.sendMessage(msg); err != nil { |
||||
return false, err |
||||
} |
||||
|
||||
if wantReply { |
||||
m, ok := (<-ch.msg) |
||||
if !ok { |
||||
return false, io.EOF |
||||
} |
||||
switch m.(type) { |
||||
case *channelRequestFailureMsg: |
||||
return false, nil |
||||
case *channelRequestSuccessMsg: |
||||
return true, nil |
||||
default: |
||||
return false, fmt.Errorf("ssh: unexpected response to channel request: %#v", m) |
||||
} |
||||
} |
||||
|
||||
return false, nil |
||||
} |
||||
|
||||
// ackRequest either sends an ack or nack to the channel request.
|
||||
func (ch *channel) ackRequest(ok bool) error { |
||||
if !ch.decided { |
||||
return errUndecided |
||||
} |
||||
|
||||
var msg interface{} |
||||
if !ok { |
||||
msg = channelRequestFailureMsg{ |
||||
PeersId: ch.remoteId, |
||||
} |
||||
} else { |
||||
msg = channelRequestSuccessMsg{ |
||||
PeersId: ch.remoteId, |
||||
} |
||||
} |
||||
return ch.sendMessage(msg) |
||||
} |
||||
|
||||
func (ch *channel) ChannelType() string { |
||||
return ch.chanType |
||||
} |
||||
|
||||
func (ch *channel) ExtraData() []byte { |
||||
return ch.extraData |
||||
} |
@ -0,0 +1,627 @@ |
||||
// Copyright 2011 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package ssh |
||||
|
||||
import ( |
||||
"crypto/aes" |
||||
"crypto/cipher" |
||||
"crypto/des" |
||||
"crypto/rc4" |
||||
"crypto/subtle" |
||||
"encoding/binary" |
||||
"errors" |
||||
"fmt" |
||||
"hash" |
||||
"io" |
||||
"io/ioutil" |
||||
) |
||||
|
||||
const ( |
||||
packetSizeMultiple = 16 // TODO(huin) this should be determined by the cipher.
|
||||
|
||||
// RFC 4253 section 6.1 defines a minimum packet size of 32768 that implementations
|
||||
// MUST be able to process (plus a few more kilobytes for padding and mac). The RFC
|
||||
// indicates implementations SHOULD be able to handle larger packet sizes, but then
|
||||
// waffles on about reasonable limits.
|
||||
//
|
||||
// OpenSSH caps their maxPacket at 256kB so we choose to do
|
||||
// the same. maxPacket is also used to ensure that uint32
|
||||
// length fields do not overflow, so it should remain well
|
||||
// below 4G.
|
||||
maxPacket = 256 * 1024 |
||||
) |
||||
|
||||
// noneCipher implements cipher.Stream and provides no encryption. It is used
|
||||
// by the transport before the first key-exchange.
|
||||
type noneCipher struct{} |
||||
|
||||
func (c noneCipher) XORKeyStream(dst, src []byte) { |
||||
copy(dst, src) |
||||
} |
||||
|
||||
func newAESCTR(key, iv []byte) (cipher.Stream, error) { |
||||
c, err := aes.NewCipher(key) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
return cipher.NewCTR(c, iv), nil |
||||
} |
||||
|
||||
func newRC4(key, iv []byte) (cipher.Stream, error) { |
||||
return rc4.NewCipher(key) |
||||
} |
||||
|
||||
type streamCipherMode struct { |
||||
keySize int |
||||
ivSize int |
||||
skip int |
||||
createFunc func(key, iv []byte) (cipher.Stream, error) |
||||
} |
||||
|
||||
func (c *streamCipherMode) createStream(key, iv []byte) (cipher.Stream, error) { |
||||
if len(key) < c.keySize { |
||||
panic("ssh: key length too small for cipher") |
||||
} |
||||
if len(iv) < c.ivSize { |
||||
panic("ssh: iv too small for cipher") |
||||
} |
||||
|
||||
stream, err := c.createFunc(key[:c.keySize], iv[:c.ivSize]) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
var streamDump []byte |
||||
if c.skip > 0 { |
||||
streamDump = make([]byte, 512) |
||||
} |
||||
|
||||
for remainingToDump := c.skip; remainingToDump > 0; { |
||||
dumpThisTime := remainingToDump |
||||
if dumpThisTime > len(streamDump) { |
||||
dumpThisTime = len(streamDump) |
||||
} |
||||
stream.XORKeyStream(streamDump[:dumpThisTime], streamDump[:dumpThisTime]) |
||||
remainingToDump -= dumpThisTime |
||||
} |
||||
|
||||
return stream, nil |
||||
} |
||||
|
||||
// cipherModes documents properties of supported ciphers. Ciphers not included
|
||||
// are not supported and will not be negotiated, even if explicitly requested in
|
||||
// ClientConfig.Crypto.Ciphers.
|
||||
var cipherModes = map[string]*streamCipherMode{ |
||||
// Ciphers from RFC4344, which introduced many CTR-based ciphers. Algorithms
|
||||
// are defined in the order specified in the RFC.
|
||||
"aes128-ctr": {16, aes.BlockSize, 0, newAESCTR}, |
||||
"aes192-ctr": {24, aes.BlockSize, 0, newAESCTR}, |
||||
"aes256-ctr": {32, aes.BlockSize, 0, newAESCTR}, |
||||
|
||||
// Ciphers from RFC4345, which introduces security-improved arcfour ciphers.
|
||||
// They are defined in the order specified in the RFC.
|
||||
"arcfour128": {16, 0, 1536, newRC4}, |
||||
"arcfour256": {32, 0, 1536, newRC4}, |
||||
|
||||
// Cipher defined in RFC 4253, which describes SSH Transport Layer Protocol.
|
||||
// Note that this cipher is not safe, as stated in RFC 4253: "Arcfour (and
|
||||
// RC4) has problems with weak keys, and should be used with caution."
|
||||
// RFC4345 introduces improved versions of Arcfour.
|
||||
"arcfour": {16, 0, 0, newRC4}, |
||||
|
||||
// AES-GCM is not a stream cipher, so it is constructed with a
|
||||
// special case. If we add any more non-stream ciphers, we
|
||||
// should invest a cleaner way to do this.
|
||||
gcmCipherID: {16, 12, 0, nil}, |
||||
|
||||
// CBC mode is insecure and so is not included in the default config.
|
||||
// (See http://www.isg.rhul.ac.uk/~kp/SandPfinal.pdf). If absolutely
|
||||
// needed, it's possible to specify a custom Config to enable it.
|
||||
// You should expect that an active attacker can recover plaintext if
|
||||
// you do.
|
||||
aes128cbcID: {16, aes.BlockSize, 0, nil}, |
||||
|
||||
// 3des-cbc is insecure and is disabled by default.
|
||||
tripledescbcID: {24, des.BlockSize, 0, nil}, |
||||
} |
||||
|
||||
// prefixLen is the length of the packet prefix that contains the packet length
|
||||
// and number of padding bytes.
|
||||
const prefixLen = 5 |
||||
|
||||
// streamPacketCipher is a packetCipher using a stream cipher.
|
||||
type streamPacketCipher struct { |
||||
mac hash.Hash |
||||
cipher cipher.Stream |
||||
etm bool |
||||
|
||||
// The following members are to avoid per-packet allocations.
|
||||
prefix [prefixLen]byte |
||||
seqNumBytes [4]byte |
||||
padding [2 * packetSizeMultiple]byte |
||||
packetData []byte |
||||
macResult []byte |
||||
} |
||||
|
||||
// readPacket reads and decrypt a single packet from the reader argument.
|
||||
func (s *streamPacketCipher) readPacket(seqNum uint32, r io.Reader) ([]byte, error) { |
||||
if _, err := io.ReadFull(r, s.prefix[:]); err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
var encryptedPaddingLength [1]byte |
||||
if s.mac != nil && s.etm { |
||||
copy(encryptedPaddingLength[:], s.prefix[4:5]) |
||||
s.cipher.XORKeyStream(s.prefix[4:5], s.prefix[4:5]) |
||||
} else { |
||||
s.cipher.XORKeyStream(s.prefix[:], s.prefix[:]) |
||||
} |
||||
|
||||
length := binary.BigEndian.Uint32(s.prefix[0:4]) |
||||
paddingLength := uint32(s.prefix[4]) |
||||
|
||||
var macSize uint32 |
||||
if s.mac != nil { |
||||
s.mac.Reset() |
||||
binary.BigEndian.PutUint32(s.seqNumBytes[:], seqNum) |
||||
s.mac.Write(s.seqNumBytes[:]) |
||||
if s.etm { |
||||
s.mac.Write(s.prefix[:4]) |
||||
s.mac.Write(encryptedPaddingLength[:]) |
||||
} else { |
||||
s.mac.Write(s.prefix[:]) |
||||
} |
||||
macSize = uint32(s.mac.Size()) |
||||
} |
||||
|
||||
if length <= paddingLength+1 { |
||||
return nil, errors.New("ssh: invalid packet length, packet too small") |
||||
} |
||||
|
||||
if length > maxPacket { |
||||
return nil, errors.New("ssh: invalid packet length, packet too large") |
||||
} |
||||
|
||||
// the maxPacket check above ensures that length-1+macSize
|
||||
// does not overflow.
|
||||
if uint32(cap(s.packetData)) < length-1+macSize { |
||||
s.packetData = make([]byte, length-1+macSize) |
||||
} else { |
||||
s.packetData = s.packetData[:length-1+macSize] |
||||
} |
||||
|
||||
if _, err := io.ReadFull(r, s.packetData); err != nil { |
||||
return nil, err |
||||
} |
||||
mac := s.packetData[length-1:] |
||||
data := s.packetData[:length-1] |
||||
|
||||
if s.mac != nil && s.etm { |
||||
s.mac.Write(data) |
||||
} |
||||
|
||||
s.cipher.XORKeyStream(data, data) |
||||
|
||||
if s.mac != nil { |
||||
if !s.etm { |
||||
s.mac.Write(data) |
||||
} |
||||
s.macResult = s.mac.Sum(s.macResult[:0]) |
||||
if subtle.ConstantTimeCompare(s.macResult, mac) != 1 { |
||||
return nil, errors.New("ssh: MAC failure") |
||||
} |
||||
} |
||||
|
||||
return s.packetData[:length-paddingLength-1], nil |
||||
} |
||||
|
||||
// writePacket encrypts and sends a packet of data to the writer argument
|
||||
func (s *streamPacketCipher) writePacket(seqNum uint32, w io.Writer, rand io.Reader, packet []byte) error { |
||||
if len(packet) > maxPacket { |
||||
return errors.New("ssh: packet too large") |
||||
} |
||||
|
||||
aadlen := 0 |
||||
if s.mac != nil && s.etm { |
||||
// packet length is not encrypted for EtM modes
|
||||
aadlen = 4 |
||||
} |
||||
|
||||
paddingLength := packetSizeMultiple - (prefixLen+len(packet)-aadlen)%packetSizeMultiple |
||||
if paddingLength < 4 { |
||||
paddingLength += packetSizeMultiple |
||||
} |
||||
|
||||
length := len(packet) + 1 + paddingLength |
||||
binary.BigEndian.PutUint32(s.prefix[:], uint32(length)) |
||||
s.prefix[4] = byte(paddingLength) |
||||
padding := s.padding[:paddingLength] |
||||
if _, err := io.ReadFull(rand, padding); err != nil { |
||||
return err |
||||
} |
||||
|
||||
if s.mac != nil { |
||||
s.mac.Reset() |
||||
binary.BigEndian.PutUint32(s.seqNumBytes[:], seqNum) |
||||
s.mac.Write(s.seqNumBytes[:]) |
||||
|
||||
if s.etm { |
||||
// For EtM algorithms, the packet length must stay unencrypted,
|
||||
// but the following data (padding length) must be encrypted
|
||||
s.cipher.XORKeyStream(s.prefix[4:5], s.prefix[4:5]) |
||||
} |
||||
|
||||
s.mac.Write(s.prefix[:]) |
||||
|
||||
if !s.etm { |
||||
// For non-EtM algorithms, the algorithm is applied on unencrypted data
|
||||
s.mac.Write(packet) |
||||
s.mac.Write(padding) |
||||
} |
||||
} |
||||
|
||||
if !(s.mac != nil && s.etm) { |
||||
// For EtM algorithms, the padding length has already been encrypted
|
||||
// and the packet length must remain unencrypted
|
||||
s.cipher.XORKeyStream(s.prefix[:], s.prefix[:]) |
||||
} |
||||
|
||||
s.cipher.XORKeyStream(packet, packet) |
||||
s.cipher.XORKeyStream(padding, padding) |
||||
|
||||
if s.mac != nil && s.etm { |
||||
// For EtM algorithms, packet and padding must be encrypted
|
||||
s.mac.Write(packet) |
||||
s.mac.Write(padding) |
||||
} |
||||
|
||||
if _, err := w.Write(s.prefix[:]); err != nil { |
||||
return err |
||||
} |
||||
if _, err := w.Write(packet); err != nil { |
||||
return err |
||||
} |
||||
if _, err := w.Write(padding); err != nil { |
||||
return err |
||||
} |
||||
|
||||
if s.mac != nil { |
||||
s.macResult = s.mac.Sum(s.macResult[:0]) |
||||
if _, err := w.Write(s.macResult); err != nil { |
||||
return err |
||||
} |
||||
} |
||||
|
||||
return nil |
||||
} |
||||
|
||||
type gcmCipher struct { |
||||
aead cipher.AEAD |
||||
prefix [4]byte |
||||
iv []byte |
||||
buf []byte |
||||
} |
||||
|
||||
func newGCMCipher(iv, key, macKey []byte) (packetCipher, error) { |
||||
c, err := aes.NewCipher(key) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
aead, err := cipher.NewGCM(c) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
return &gcmCipher{ |
||||
aead: aead, |
||||
iv: iv, |
||||
}, nil |
||||
} |
||||
|
||||
const gcmTagSize = 16 |
||||
|
||||
func (c *gcmCipher) writePacket(seqNum uint32, w io.Writer, rand io.Reader, packet []byte) error { |
||||
// Pad out to multiple of 16 bytes. This is different from the
|
||||
// stream cipher because that encrypts the length too.
|
||||
padding := byte(packetSizeMultiple - (1+len(packet))%packetSizeMultiple) |
||||
if padding < 4 { |
||||
padding += packetSizeMultiple |
||||
} |
||||
|
||||
length := uint32(len(packet) + int(padding) + 1) |
||||
binary.BigEndian.PutUint32(c.prefix[:], length) |
||||
if _, err := w.Write(c.prefix[:]); err != nil { |
||||
return err |
||||
} |
||||
|
||||
if cap(c.buf) < int(length) { |
||||
c.buf = make([]byte, length) |
||||
} else { |
||||
c.buf = c.buf[:length] |
||||
} |
||||
|
||||
c.buf[0] = padding |
||||
copy(c.buf[1:], packet) |
||||
if _, err := io.ReadFull(rand, c.buf[1+len(packet):]); err != nil { |
||||
return err |
||||
} |
||||
c.buf = c.aead.Seal(c.buf[:0], c.iv, c.buf, c.prefix[:]) |
||||
if _, err := w.Write(c.buf); err != nil { |
||||
return err |
||||
} |
||||
c.incIV() |
||||
|
||||
return nil |
||||
} |
||||
|
||||
func (c *gcmCipher) incIV() { |
||||
for i := 4 + 7; i >= 4; i-- { |
||||
c.iv[i]++ |
||||
if c.iv[i] != 0 { |
||||
break |
||||
} |
||||
} |
||||
} |
||||
|
||||
func (c *gcmCipher) readPacket(seqNum uint32, r io.Reader) ([]byte, error) { |
||||
if _, err := io.ReadFull(r, c.prefix[:]); err != nil { |
||||
return nil, err |
||||
} |
||||
length := binary.BigEndian.Uint32(c.prefix[:]) |
||||
if length > maxPacket { |
||||
return nil, errors.New("ssh: max packet length exceeded.") |
||||
} |
||||
|
||||
if cap(c.buf) < int(length+gcmTagSize) { |
||||
c.buf = make([]byte, length+gcmTagSize) |
||||
} else { |
||||
c.buf = c.buf[:length+gcmTagSize] |
||||
} |
||||
|
||||
if _, err := io.ReadFull(r, c.buf); err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
plain, err := c.aead.Open(c.buf[:0], c.iv, c.buf, c.prefix[:]) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
c.incIV() |
||||
|
||||
padding := plain[0] |
||||
if padding < 4 || padding >= 20 { |
||||
return nil, fmt.Errorf("ssh: illegal padding %d", padding) |
||||
} |
||||
|
||||
if int(padding+1) >= len(plain) { |
||||
return nil, fmt.Errorf("ssh: padding %d too large", padding) |
||||
} |
||||
plain = plain[1 : length-uint32(padding)] |
||||
return plain, nil |
||||
} |
||||
|
||||
// cbcCipher implements aes128-cbc cipher defined in RFC 4253 section 6.1
|
||||
type cbcCipher struct { |
||||
mac hash.Hash |
||||
macSize uint32 |
||||
decrypter cipher.BlockMode |
||||
encrypter cipher.BlockMode |
||||
|
||||
// The following members are to avoid per-packet allocations.
|
||||
seqNumBytes [4]byte |
||||
packetData []byte |
||||
macResult []byte |
||||
|
||||
// Amount of data we should still read to hide which
|
||||
// verification error triggered.
|
||||
oracleCamouflage uint32 |
||||
} |
||||
|
||||
func newCBCCipher(c cipher.Block, iv, key, macKey []byte, algs directionAlgorithms) (packetCipher, error) { |
||||
cbc := &cbcCipher{ |
||||
mac: macModes[algs.MAC].new(macKey), |
||||
decrypter: cipher.NewCBCDecrypter(c, iv), |
||||
encrypter: cipher.NewCBCEncrypter(c, iv), |
||||
packetData: make([]byte, 1024), |
||||
} |
||||
if cbc.mac != nil { |
||||
cbc.macSize = uint32(cbc.mac.Size()) |
||||
} |
||||
|
||||
return cbc, nil |
||||
} |
||||
|
||||
func newAESCBCCipher(iv, key, macKey []byte, algs directionAlgorithms) (packetCipher, error) { |
||||
c, err := aes.NewCipher(key) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
cbc, err := newCBCCipher(c, iv, key, macKey, algs) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
return cbc, nil |
||||
} |
||||
|
||||
func newTripleDESCBCCipher(iv, key, macKey []byte, algs directionAlgorithms) (packetCipher, error) { |
||||
c, err := des.NewTripleDESCipher(key) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
cbc, err := newCBCCipher(c, iv, key, macKey, algs) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
return cbc, nil |
||||
} |
||||
|
||||
func maxUInt32(a, b int) uint32 { |
||||
if a > b { |
||||
return uint32(a) |
||||
} |
||||
return uint32(b) |
||||
} |
||||
|
||||
const ( |
||||
cbcMinPacketSizeMultiple = 8 |
||||
cbcMinPacketSize = 16 |
||||
cbcMinPaddingSize = 4 |
||||
) |
||||
|
||||
// cbcError represents a verification error that may leak information.
|
||||
type cbcError string |
||||
|
||||
func (e cbcError) Error() string { return string(e) } |
||||
|
||||
func (c *cbcCipher) readPacket(seqNum uint32, r io.Reader) ([]byte, error) { |
||||
p, err := c.readPacketLeaky(seqNum, r) |
||||
if err != nil { |
||||
if _, ok := err.(cbcError); ok { |
||||
// Verification error: read a fixed amount of
|
||||
// data, to make distinguishing between
|
||||
// failing MAC and failing length check more
|
||||
// difficult.
|
||||
io.CopyN(ioutil.Discard, r, int64(c.oracleCamouflage)) |
||||
} |
||||
} |
||||
return p, err |
||||
} |
||||
|
||||
func (c *cbcCipher) readPacketLeaky(seqNum uint32, r io.Reader) ([]byte, error) { |
||||
blockSize := c.decrypter.BlockSize() |
||||
|
||||
// Read the header, which will include some of the subsequent data in the
|
||||
// case of block ciphers - this is copied back to the payload later.
|
||||
// How many bytes of payload/padding will be read with this first read.
|
||||
firstBlockLength := uint32((prefixLen + blockSize - 1) / blockSize * blockSize) |
||||
firstBlock := c.packetData[:firstBlockLength] |
||||
if _, err := io.ReadFull(r, firstBlock); err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
c.oracleCamouflage = maxPacket + 4 + c.macSize - firstBlockLength |
||||
|
||||
c.decrypter.CryptBlocks(firstBlock, firstBlock) |
||||
length := binary.BigEndian.Uint32(firstBlock[:4]) |
||||
if length > maxPacket { |
||||
return nil, cbcError("ssh: packet too large") |
||||
} |
||||
if length+4 < maxUInt32(cbcMinPacketSize, blockSize) { |
||||
// The minimum size of a packet is 16 (or the cipher block size, whichever
|
||||
// is larger) bytes.
|
||||
return nil, cbcError("ssh: packet too small") |
||||
} |
||||
// The length of the packet (including the length field but not the MAC) must
|
||||
// be a multiple of the block size or 8, whichever is larger.
|
||||
if (length+4)%maxUInt32(cbcMinPacketSizeMultiple, blockSize) != 0 { |
||||
return nil, cbcError("ssh: invalid packet length multiple") |
||||
} |
||||
|
||||
paddingLength := uint32(firstBlock[4]) |
||||
if paddingLength < cbcMinPaddingSize || length <= paddingLength+1 { |
||||
return nil, cbcError("ssh: invalid packet length") |
||||
} |
||||
|
||||
// Positions within the c.packetData buffer:
|
||||
macStart := 4 + length |
||||
paddingStart := macStart - paddingLength |
||||
|
||||
// Entire packet size, starting before length, ending at end of mac.
|
||||
entirePacketSize := macStart + c.macSize |
||||
|
||||
// Ensure c.packetData is large enough for the entire packet data.
|
||||
if uint32(cap(c.packetData)) < entirePacketSize { |
||||
// Still need to upsize and copy, but this should be rare at runtime, only
|
||||
// on upsizing the packetData buffer.
|
||||
c.packetData = make([]byte, entirePacketSize) |
||||
copy(c.packetData, firstBlock) |
||||
} else { |
||||
c.packetData = c.packetData[:entirePacketSize] |
||||
} |
||||
|
||||
if n, err := io.ReadFull(r, c.packetData[firstBlockLength:]); err != nil { |
||||
return nil, err |
||||
} else { |
||||
c.oracleCamouflage -= uint32(n) |
||||
} |
||||
|
||||
remainingCrypted := c.packetData[firstBlockLength:macStart] |
||||
c.decrypter.CryptBlocks(remainingCrypted, remainingCrypted) |
||||
|
||||
mac := c.packetData[macStart:] |
||||
if c.mac != nil { |
||||
c.mac.Reset() |
||||
binary.BigEndian.PutUint32(c.seqNumBytes[:], seqNum) |
||||
c.mac.Write(c.seqNumBytes[:]) |
||||
c.mac.Write(c.packetData[:macStart]) |
||||
c.macResult = c.mac.Sum(c.macResult[:0]) |
||||
if subtle.ConstantTimeCompare(c.macResult, mac) != 1 { |
||||
return nil, cbcError("ssh: MAC failure") |
||||
} |
||||
} |
||||
|
||||
return c.packetData[prefixLen:paddingStart], nil |
||||
} |
||||
|
||||
func (c *cbcCipher) writePacket(seqNum uint32, w io.Writer, rand io.Reader, packet []byte) error { |
||||
effectiveBlockSize := maxUInt32(cbcMinPacketSizeMultiple, c.encrypter.BlockSize()) |
||||
|
||||
// Length of encrypted portion of the packet (header, payload, padding).
|
||||
// Enforce minimum padding and packet size.
|
||||
encLength := maxUInt32(prefixLen+len(packet)+cbcMinPaddingSize, cbcMinPaddingSize) |
||||
// Enforce block size.
|
||||
encLength = (encLength + effectiveBlockSize - 1) / effectiveBlockSize * effectiveBlockSize |
||||
|
||||
length := encLength - 4 |
||||
paddingLength := int(length) - (1 + len(packet)) |
||||
|
||||
// Overall buffer contains: header, payload, padding, mac.
|
||||
// Space for the MAC is reserved in the capacity but not the slice length.
|
||||
bufferSize := encLength + c.macSize |
||||
if uint32(cap(c.packetData)) < bufferSize { |
||||
c.packetData = make([]byte, encLength, bufferSize) |
||||
} else { |
||||
c.packetData = c.packetData[:encLength] |
||||
} |
||||
|
||||
p := c.packetData |
||||
|
||||
// Packet header.
|
||||
binary.BigEndian.PutUint32(p, length) |
||||
p = p[4:] |
||||
p[0] = byte(paddingLength) |
||||
|
||||
// Payload.
|
||||
p = p[1:] |
||||
copy(p, packet) |
||||
|
||||
// Padding.
|
||||
p = p[len(packet):] |
||||
if _, err := io.ReadFull(rand, p); err != nil { |
||||
return err |
||||
} |
||||
|
||||
if c.mac != nil { |
||||
c.mac.Reset() |
||||
binary.BigEndian.PutUint32(c.seqNumBytes[:], seqNum) |
||||
c.mac.Write(c.seqNumBytes[:]) |
||||
c.mac.Write(c.packetData) |
||||
// The MAC is now appended into the capacity reserved for it earlier.
|
||||
c.packetData = c.mac.Sum(c.packetData) |
||||
} |
||||
|
||||
c.encrypter.CryptBlocks(c.packetData[:encLength], c.packetData[:encLength]) |
||||
|
||||
if _, err := w.Write(c.packetData); err != nil { |
||||
return err |
||||
} |
||||
|
||||
return nil |
||||
} |
@ -0,0 +1,211 @@ |
||||
// Copyright 2011 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package ssh |
||||
|
||||
import ( |
||||
"errors" |
||||
"fmt" |
||||
"net" |
||||
"sync" |
||||
"time" |
||||
) |
||||
|
||||
// Client implements a traditional SSH client that supports shells,
|
||||
// subprocesses, port forwarding and tunneled dialing.
|
||||
type Client struct { |
||||
Conn |
||||
|
||||
forwards forwardList // forwarded tcpip connections from the remote side
|
||||
mu sync.Mutex |
||||
channelHandlers map[string]chan NewChannel |
||||
} |
||||
|
||||
// HandleChannelOpen returns a channel on which NewChannel requests
|
||||
// for the given type are sent. If the type already is being handled,
|
||||
// nil is returned. The channel is closed when the connection is closed.
|
||||
func (c *Client) HandleChannelOpen(channelType string) <-chan NewChannel { |
||||
c.mu.Lock() |
||||
defer c.mu.Unlock() |
||||
if c.channelHandlers == nil { |
||||
// The SSH channel has been closed.
|
||||
c := make(chan NewChannel) |
||||
close(c) |
||||
return c |
||||
} |
||||
|
||||
ch := c.channelHandlers[channelType] |
||||
if ch != nil { |
||||
return nil |
||||
} |
||||
|
||||
ch = make(chan NewChannel, chanSize) |
||||
c.channelHandlers[channelType] = ch |
||||
return ch |
||||
} |
||||
|
||||
// NewClient creates a Client on top of the given connection.
|
||||
func NewClient(c Conn, chans <-chan NewChannel, reqs <-chan *Request) *Client { |
||||
conn := &Client{ |
||||
Conn: c, |
||||
channelHandlers: make(map[string]chan NewChannel, 1), |
||||
} |
||||
|
||||
go conn.handleGlobalRequests(reqs) |
||||
go conn.handleChannelOpens(chans) |
||||
go func() { |
||||
conn.Wait() |
||||
conn.forwards.closeAll() |
||||
}() |
||||
go conn.forwards.handleChannels(conn.HandleChannelOpen("forwarded-tcpip")) |
||||
return conn |
||||
} |
||||
|
||||
// NewClientConn establishes an authenticated SSH connection using c
|
||||
// as the underlying transport. The Request and NewChannel channels
|
||||
// must be serviced or the connection will hang.
|
||||
func NewClientConn(c net.Conn, addr string, config *ClientConfig) (Conn, <-chan NewChannel, <-chan *Request, error) { |
||||
fullConf := *config |
||||
fullConf.SetDefaults() |
||||
conn := &connection{ |
||||
sshConn: sshConn{conn: c}, |
||||
} |
||||
|
||||
if err := conn.clientHandshake(addr, &fullConf); err != nil { |
||||
c.Close() |
||||
return nil, nil, nil, fmt.Errorf("ssh: handshake failed: %v", err) |
||||
} |
||||
conn.mux = newMux(conn.transport) |
||||
return conn, conn.mux.incomingChannels, conn.mux.incomingRequests, nil |
||||
} |
||||
|
||||
// clientHandshake performs the client side key exchange. See RFC 4253 Section
|
||||
// 7.
|
||||
func (c *connection) clientHandshake(dialAddress string, config *ClientConfig) error { |
||||
if config.ClientVersion != "" { |
||||
c.clientVersion = []byte(config.ClientVersion) |
||||
} else { |
||||
c.clientVersion = []byte(packageVersion) |
||||
} |
||||
var err error |
||||
c.serverVersion, err = exchangeVersions(c.sshConn.conn, c.clientVersion) |
||||
if err != nil { |
||||
return err |
||||
} |
||||
|
||||
c.transport = newClientTransport( |
||||
newTransport(c.sshConn.conn, config.Rand, true /* is client */), |
||||
c.clientVersion, c.serverVersion, config, dialAddress, c.sshConn.RemoteAddr()) |
||||
if err := c.transport.waitSession(); err != nil { |
||||
return err |
||||
} |
||||
|
||||
c.sessionID = c.transport.getSessionID() |
||||
return c.clientAuthenticate(config) |
||||
} |
||||
|
||||
// verifyHostKeySignature verifies the host key obtained in the key
|
||||
// exchange.
|
||||
func verifyHostKeySignature(hostKey PublicKey, result *kexResult) error { |
||||
sig, rest, ok := parseSignatureBody(result.Signature) |
||||
if len(rest) > 0 || !ok { |
||||
return errors.New("ssh: signature parse error") |
||||
} |
||||
|
||||
return hostKey.Verify(result.H, sig) |
||||
} |
||||
|
||||
// NewSession opens a new Session for this client. (A session is a remote
|
||||
// execution of a program.)
|
||||
func (c *Client) NewSession() (*Session, error) { |
||||
ch, in, err := c.OpenChannel("session", nil) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
return newSession(ch, in) |
||||
} |
||||
|
||||
func (c *Client) handleGlobalRequests(incoming <-chan *Request) { |
||||
for r := range incoming { |
||||
// This handles keepalive messages and matches
|
||||
// the behaviour of OpenSSH.
|
||||
r.Reply(false, nil) |
||||
} |
||||
} |
||||
|
||||
// handleChannelOpens channel open messages from the remote side.
|
||||
func (c *Client) handleChannelOpens(in <-chan NewChannel) { |
||||
for ch := range in { |
||||
c.mu.Lock() |
||||
handler := c.channelHandlers[ch.ChannelType()] |
||||
c.mu.Unlock() |
||||
|
||||
if handler != nil { |
||||
handler <- ch |
||||
} else { |
||||
ch.Reject(UnknownChannelType, fmt.Sprintf("unknown channel type: %v", ch.ChannelType())) |
||||
} |
||||
} |
||||
|
||||
c.mu.Lock() |
||||
for _, ch := range c.channelHandlers { |
||||
close(ch) |
||||
} |
||||
c.channelHandlers = nil |
||||
c.mu.Unlock() |
||||
} |
||||
|
||||
// Dial starts a client connection to the given SSH server. It is a
|
||||
// convenience function that connects to the given network address,
|
||||
// initiates the SSH handshake, and then sets up a Client. For access
|
||||
// to incoming channels and requests, use net.Dial with NewClientConn
|
||||
// instead.
|
||||
func Dial(network, addr string, config *ClientConfig) (*Client, error) { |
||||
conn, err := net.DialTimeout(network, addr, config.Timeout) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
c, chans, reqs, err := NewClientConn(conn, addr, config) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
return NewClient(c, chans, reqs), nil |
||||
} |
||||
|
||||
// A ClientConfig structure is used to configure a Client. It must not be
|
||||
// modified after having been passed to an SSH function.
|
||||
type ClientConfig struct { |
||||
// Config contains configuration that is shared between clients and
|
||||
// servers.
|
||||
Config |
||||
|
||||
// User contains the username to authenticate as.
|
||||
User string |
||||
|
||||
// Auth contains possible authentication methods to use with the
|
||||
// server. Only the first instance of a particular RFC 4252 method will
|
||||
// be used during authentication.
|
||||
Auth []AuthMethod |
||||
|
||||
// HostKeyCallback, if not nil, is called during the cryptographic
|
||||
// handshake to validate the server's host key. A nil HostKeyCallback
|
||||
// implies that all host keys are accepted.
|
||||
HostKeyCallback func(hostname string, remote net.Addr, key PublicKey) error |
||||
|
||||
// ClientVersion contains the version identification string that will
|
||||
// be used for the connection. If empty, a reasonable default is used.
|
||||
ClientVersion string |
||||
|
||||
// HostKeyAlgorithms lists the key types that the client will
|
||||
// accept from the server as host key, in order of
|
||||
// preference. If empty, a reasonable default is used. Any
|
||||
// string returned from PublicKey.Type method may be used, or
|
||||
// any of the CertAlgoXxxx and KeyAlgoXxxx constants.
|
||||
HostKeyAlgorithms []string |
||||
|
||||
// Timeout is the maximum amount of time for the TCP connection to establish.
|
||||
//
|
||||
// A Timeout of zero means no timeout.
|
||||
Timeout time.Duration |
||||
} |
@ -0,0 +1,475 @@ |
||||
// Copyright 2011 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package ssh |
||||
|
||||
import ( |
||||
"bytes" |
||||
"errors" |
||||
"fmt" |
||||
"io" |
||||
) |
||||
|
||||
// clientAuthenticate authenticates with the remote server. See RFC 4252.
|
||||
func (c *connection) clientAuthenticate(config *ClientConfig) error { |
||||
// initiate user auth session
|
||||
if err := c.transport.writePacket(Marshal(&serviceRequestMsg{serviceUserAuth})); err != nil { |
||||
return err |
||||
} |
||||
packet, err := c.transport.readPacket() |
||||
if err != nil { |
||||
return err |
||||
} |
||||
var serviceAccept serviceAcceptMsg |
||||
if err := Unmarshal(packet, &serviceAccept); err != nil { |
||||
return err |
||||
} |
||||
|
||||
// during the authentication phase the client first attempts the "none" method
|
||||
// then any untried methods suggested by the server.
|
||||
tried := make(map[string]bool) |
||||
var lastMethods []string |
||||
|
||||
sessionID := c.transport.getSessionID() |
||||
for auth := AuthMethod(new(noneAuth)); auth != nil; { |
||||
ok, methods, err := auth.auth(sessionID, config.User, c.transport, config.Rand) |
||||
if err != nil { |
||||
return err |
||||
} |
||||
if ok { |
||||
// success
|
||||
return nil |
||||
} |
||||
tried[auth.method()] = true |
||||
if methods == nil { |
||||
methods = lastMethods |
||||
} |
||||
lastMethods = methods |
||||
|
||||
auth = nil |
||||
|
||||
findNext: |
||||
for _, a := range config.Auth { |
||||
candidateMethod := a.method() |
||||
if tried[candidateMethod] { |
||||
continue |
||||
} |
||||
for _, meth := range methods { |
||||
if meth == candidateMethod { |
||||
auth = a |
||||
break findNext |
||||
} |
||||
} |
||||
} |
||||
} |
||||
return fmt.Errorf("ssh: unable to authenticate, attempted methods %v, no supported methods remain", keys(tried)) |
||||
} |
||||
|
||||
func keys(m map[string]bool) []string { |
||||
s := make([]string, 0, len(m)) |
||||
|
||||
for key := range m { |
||||
s = append(s, key) |
||||
} |
||||
return s |
||||
} |
||||
|
||||
// An AuthMethod represents an instance of an RFC 4252 authentication method.
|
||||
type AuthMethod interface { |
||||
// auth authenticates user over transport t.
|
||||
// Returns true if authentication is successful.
|
||||
// If authentication is not successful, a []string of alternative
|
||||
// method names is returned. If the slice is nil, it will be ignored
|
||||
// and the previous set of possible methods will be reused.
|
||||
auth(session []byte, user string, p packetConn, rand io.Reader) (bool, []string, error) |
||||
|
||||
// method returns the RFC 4252 method name.
|
||||
method() string |
||||
} |
||||
|
||||
// "none" authentication, RFC 4252 section 5.2.
|
||||
type noneAuth int |
||||
|
||||
func (n *noneAuth) auth(session []byte, user string, c packetConn, rand io.Reader) (bool, []string, error) { |
||||
if err := c.writePacket(Marshal(&userAuthRequestMsg{ |
||||
User: user, |
||||
Service: serviceSSH, |
||||
Method: "none", |
||||
})); err != nil { |
||||
return false, nil, err |
||||
} |
||||
|
||||
return handleAuthResponse(c) |
||||
} |
||||
|
||||
func (n *noneAuth) method() string { |
||||
return "none" |
||||
} |
||||
|
||||
// passwordCallback is an AuthMethod that fetches the password through
|
||||
// a function call, e.g. by prompting the user.
|
||||
type passwordCallback func() (password string, err error) |
||||
|
||||
func (cb passwordCallback) auth(session []byte, user string, c packetConn, rand io.Reader) (bool, []string, error) { |
||||
type passwordAuthMsg struct { |
||||
User string `sshtype:"50"` |
||||
Service string |
||||
Method string |
||||
Reply bool |
||||
Password string |
||||
} |
||||
|
||||
pw, err := cb() |
||||
// REVIEW NOTE: is there a need to support skipping a password attempt?
|
||||
// The program may only find out that the user doesn't have a password
|
||||
// when prompting.
|
||||
if err != nil { |
||||
return false, nil, err |
||||
} |
||||
|
||||
if err := c.writePacket(Marshal(&passwordAuthMsg{ |
||||
User: user, |
||||
Service: serviceSSH, |
||||
Method: cb.method(), |
||||
Reply: false, |
||||
Password: pw, |
||||
})); err != nil { |
||||
return false, nil, err |
||||
} |
||||
|
||||
return handleAuthResponse(c) |
||||
} |
||||
|
||||
func (cb passwordCallback) method() string { |
||||
return "password" |
||||
} |
||||
|
||||
// Password returns an AuthMethod using the given password.
|
||||
func Password(secret string) AuthMethod { |
||||
return passwordCallback(func() (string, error) { return secret, nil }) |
||||
} |
||||
|
||||
// PasswordCallback returns an AuthMethod that uses a callback for
|
||||
// fetching a password.
|
||||
func PasswordCallback(prompt func() (secret string, err error)) AuthMethod { |
||||
return passwordCallback(prompt) |
||||
} |
||||
|
||||
type publickeyAuthMsg struct { |
||||
User string `sshtype:"50"` |
||||
Service string |
||||
Method string |
||||
// HasSig indicates to the receiver packet that the auth request is signed and
|
||||
// should be used for authentication of the request.
|
||||
HasSig bool |
||||
Algoname string |
||||
PubKey []byte |
||||
// Sig is tagged with "rest" so Marshal will exclude it during
|
||||
// validateKey
|
||||
Sig []byte `ssh:"rest"` |
||||
} |
||||
|
||||
// publicKeyCallback is an AuthMethod that uses a set of key
|
||||
// pairs for authentication.
|
||||
type publicKeyCallback func() ([]Signer, error) |
||||
|
||||
func (cb publicKeyCallback) method() string { |
||||
return "publickey" |
||||
} |
||||
|
||||
func (cb publicKeyCallback) auth(session []byte, user string, c packetConn, rand io.Reader) (bool, []string, error) { |
||||
// Authentication is performed in two stages. The first stage sends an
|
||||
// enquiry to test if each key is acceptable to the remote. The second
|
||||
// stage attempts to authenticate with the valid keys obtained in the
|
||||
// first stage.
|
||||
|
||||
signers, err := cb() |
||||
if err != nil { |
||||
return false, nil, err |
||||
} |
||||
var validKeys []Signer |
||||
for _, signer := range signers { |
||||
if ok, err := validateKey(signer.PublicKey(), user, c); ok { |
||||
validKeys = append(validKeys, signer) |
||||
} else { |
||||
if err != nil { |
||||
return false, nil, err |
||||
} |
||||
} |
||||
} |
||||
|
||||
// methods that may continue if this auth is not successful.
|
||||
var methods []string |
||||
for _, signer := range validKeys { |
||||
pub := signer.PublicKey() |
||||
|
||||
pubKey := pub.Marshal() |
||||
sign, err := signer.Sign(rand, buildDataSignedForAuth(session, userAuthRequestMsg{ |
||||
User: user, |
||||
Service: serviceSSH, |
||||
Method: cb.method(), |
||||
}, []byte(pub.Type()), pubKey)) |
||||
if err != nil { |
||||
return false, nil, err |
||||
} |
||||
|
||||
// manually wrap the serialized signature in a string
|
||||
s := Marshal(sign) |
||||
sig := make([]byte, stringLength(len(s))) |
||||
marshalString(sig, s) |
||||
msg := publickeyAuthMsg{ |
||||
User: user, |
||||
Service: serviceSSH, |
||||
Method: cb.method(), |
||||
HasSig: true, |
||||
Algoname: pub.Type(), |
||||
PubKey: pubKey, |
||||
Sig: sig, |
||||
} |
||||
p := Marshal(&msg) |
||||
if err := c.writePacket(p); err != nil { |
||||
return false, nil, err |
||||
} |
||||
var success bool |
||||
success, methods, err = handleAuthResponse(c) |
||||
if err != nil { |
||||
return false, nil, err |
||||
} |
||||
if success { |
||||
return success, methods, err |
||||
} |
||||
} |
||||
return false, methods, nil |
||||
} |
||||
|
||||
// validateKey validates the key provided is acceptable to the server.
|
||||
func validateKey(key PublicKey, user string, c packetConn) (bool, error) { |
||||
pubKey := key.Marshal() |
||||
msg := publickeyAuthMsg{ |
||||
User: user, |
||||
Service: serviceSSH, |
||||
Method: "publickey", |
||||
HasSig: false, |
||||
Algoname: key.Type(), |
||||
PubKey: pubKey, |
||||
} |
||||
if err := c.writePacket(Marshal(&msg)); err != nil { |
||||
return false, err |
||||
} |
||||
|
||||
return confirmKeyAck(key, c) |
||||
} |
||||
|
||||
func confirmKeyAck(key PublicKey, c packetConn) (bool, error) { |
||||
pubKey := key.Marshal() |
||||
algoname := key.Type() |
||||
|
||||
for { |
||||
packet, err := c.readPacket() |
||||
if err != nil { |
||||
return false, err |
||||
} |
||||
switch packet[0] { |
||||
case msgUserAuthBanner: |
||||
// TODO(gpaul): add callback to present the banner to the user
|
||||
case msgUserAuthPubKeyOk: |
||||
var msg userAuthPubKeyOkMsg |
||||
if err := Unmarshal(packet, &msg); err != nil { |
||||
return false, err |
||||
} |
||||
if msg.Algo != algoname || !bytes.Equal(msg.PubKey, pubKey) { |
||||
return false, nil |
||||
} |
||||
return true, nil |
||||
case msgUserAuthFailure: |
||||
return false, nil |
||||
default: |
||||
return false, unexpectedMessageError(msgUserAuthSuccess, packet[0]) |
||||
} |
||||
} |
||||
} |
||||
|
||||
// PublicKeys returns an AuthMethod that uses the given key
|
||||
// pairs.
|
||||
func PublicKeys(signers ...Signer) AuthMethod { |
||||
return publicKeyCallback(func() ([]Signer, error) { return signers, nil }) |
||||
} |
||||
|
||||
// PublicKeysCallback returns an AuthMethod that runs the given
|
||||
// function to obtain a list of key pairs.
|
||||
func PublicKeysCallback(getSigners func() (signers []Signer, err error)) AuthMethod { |
||||
return publicKeyCallback(getSigners) |
||||
} |
||||
|
||||
// handleAuthResponse returns whether the preceding authentication request succeeded
|
||||
// along with a list of remaining authentication methods to try next and
|
||||
// an error if an unexpected response was received.
|
||||
func handleAuthResponse(c packetConn) (bool, []string, error) { |
||||
for { |
||||
packet, err := c.readPacket() |
||||
if err != nil { |
||||
return false, nil, err |
||||
} |
||||
|
||||
switch packet[0] { |
||||
case msgUserAuthBanner: |
||||
// TODO: add callback to present the banner to the user
|
||||
case msgUserAuthFailure: |
||||
var msg userAuthFailureMsg |
||||
if err := Unmarshal(packet, &msg); err != nil { |
||||
return false, nil, err |
||||
} |
||||
return false, msg.Methods, nil |
||||
case msgUserAuthSuccess: |
||||
return true, nil, nil |
||||
default: |
||||
return false, nil, unexpectedMessageError(msgUserAuthSuccess, packet[0]) |
||||
} |
||||
} |
||||
} |
||||
|
||||
// KeyboardInteractiveChallenge should print questions, optionally
|
||||
// disabling echoing (e.g. for passwords), and return all the answers.
|
||||
// Challenge may be called multiple times in a single session. After
|
||||
// successful authentication, the server may send a challenge with no
|
||||
// questions, for which the user and instruction messages should be
|
||||
// printed. RFC 4256 section 3.3 details how the UI should behave for
|
||||
// both CLI and GUI environments.
|
||||
type KeyboardInteractiveChallenge func(user, instruction string, questions []string, echos []bool) (answers []string, err error) |
||||
|
||||
// KeyboardInteractive returns a AuthMethod using a prompt/response
|
||||
// sequence controlled by the server.
|
||||
func KeyboardInteractive(challenge KeyboardInteractiveChallenge) AuthMethod { |
||||
return challenge |
||||
} |
||||
|
||||
func (cb KeyboardInteractiveChallenge) method() string { |
||||
return "keyboard-interactive" |
||||
} |
||||
|
||||
func (cb KeyboardInteractiveChallenge) auth(session []byte, user string, c packetConn, rand io.Reader) (bool, []string, error) { |
||||
type initiateMsg struct { |
||||
User string `sshtype:"50"` |
||||
Service string |
||||
Method string |
||||
Language string |
||||
Submethods string |
||||
} |
||||
|
||||
if err := c.writePacket(Marshal(&initiateMsg{ |
||||
User: user, |
||||
Service: serviceSSH, |
||||
Method: "keyboard-interactive", |
||||
})); err != nil { |
||||
return false, nil, err |
||||
} |
||||
|
||||
for { |
||||
packet, err := c.readPacket() |
||||
if err != nil { |
||||
return false, nil, err |
||||
} |
||||
|
||||
// like handleAuthResponse, but with less options.
|
||||
switch packet[0] { |
||||
case msgUserAuthBanner: |
||||
// TODO: Print banners during userauth.
|
||||
continue |
||||
case msgUserAuthInfoRequest: |
||||
// OK
|
||||
case msgUserAuthFailure: |
||||
var msg userAuthFailureMsg |
||||
if err := Unmarshal(packet, &msg); err != nil { |
||||
return false, nil, err |
||||
} |
||||
return false, msg.Methods, nil |
||||
case msgUserAuthSuccess: |
||||
return true, nil, nil |
||||
default: |
||||
return false, nil, unexpectedMessageError(msgUserAuthInfoRequest, packet[0]) |
||||
} |
||||
|
||||
var msg userAuthInfoRequestMsg |
||||
if err := Unmarshal(packet, &msg); err != nil { |
||||
return false, nil, err |
||||
} |
||||
|
||||
// Manually unpack the prompt/echo pairs.
|
||||
rest := msg.Prompts |
||||
var prompts []string |
||||
var echos []bool |
||||
for i := 0; i < int(msg.NumPrompts); i++ { |
||||
prompt, r, ok := parseString(rest) |
||||
if !ok || len(r) == 0 { |
||||
return false, nil, errors.New("ssh: prompt format error") |
||||
} |
||||
prompts = append(prompts, string(prompt)) |
||||
echos = append(echos, r[0] != 0) |
||||
rest = r[1:] |
||||
} |
||||
|
||||
if len(rest) != 0 { |
||||
return false, nil, errors.New("ssh: extra data following keyboard-interactive pairs") |
||||
} |
||||
|
||||
answers, err := cb(msg.User, msg.Instruction, prompts, echos) |
||||
if err != nil { |
||||
return false, nil, err |
||||
} |
||||
|
||||
if len(answers) != len(prompts) { |
||||
return false, nil, errors.New("ssh: not enough answers from keyboard-interactive callback") |
||||
} |
||||
responseLength := 1 + 4 |
||||
for _, a := range answers { |
||||
responseLength += stringLength(len(a)) |
||||
} |
||||
serialized := make([]byte, responseLength) |
||||
p := serialized |
||||
p[0] = msgUserAuthInfoResponse |
||||
p = p[1:] |
||||
p = marshalUint32(p, uint32(len(answers))) |
||||
for _, a := range answers { |
||||
p = marshalString(p, []byte(a)) |
||||
} |
||||
|
||||
if err := c.writePacket(serialized); err != nil { |
||||
return false, nil, err |
||||
} |
||||
} |
||||
} |
||||
|
||||
type retryableAuthMethod struct { |
||||
authMethod AuthMethod |
||||
maxTries int |
||||
} |
||||
|
||||
func (r *retryableAuthMethod) auth(session []byte, user string, c packetConn, rand io.Reader) (ok bool, methods []string, err error) { |
||||
for i := 0; r.maxTries <= 0 || i < r.maxTries; i++ { |
||||
ok, methods, err = r.authMethod.auth(session, user, c, rand) |
||||
if ok || err != nil { // either success or error terminate
|
||||
return ok, methods, err |
||||
} |
||||
} |
||||
return ok, methods, err |
||||
} |
||||
|
||||
func (r *retryableAuthMethod) method() string { |
||||
return r.authMethod.method() |
||||
} |
||||
|
||||
// RetryableAuthMethod is a decorator for other auth methods enabling them to
|
||||
// be retried up to maxTries before considering that AuthMethod itself failed.
|
||||
// If maxTries is <= 0, will retry indefinitely
|
||||
//
|
||||
// This is useful for interactive clients using challenge/response type
|
||||
// authentication (e.g. Keyboard-Interactive, Password, etc) where the user
|
||||
// could mistype their response resulting in the server issuing a
|
||||
// SSH_MSG_USERAUTH_FAILURE (rfc4252 #8 [password] and rfc4256 #3.4
|
||||
// [keyboard-interactive]); Without this decorator, the non-retryable
|
||||
// AuthMethod would be removed from future consideration, and never tried again
|
||||
// (and so the user would never be able to retry their entry).
|
||||
func RetryableAuthMethod(auth AuthMethod, maxTries int) AuthMethod { |
||||
return &retryableAuthMethod{authMethod: auth, maxTries: maxTries} |
||||
} |
@ -0,0 +1,371 @@ |
||||
// Copyright 2011 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package ssh |
||||
|
||||
import ( |
||||
"crypto" |
||||
"crypto/rand" |
||||
"fmt" |
||||
"io" |
||||
"sync" |
||||
|
||||
_ "crypto/sha1" |
||||
_ "crypto/sha256" |
||||
_ "crypto/sha512" |
||||
) |
||||
|
||||
// These are string constants in the SSH protocol.
|
||||
const ( |
||||
compressionNone = "none" |
||||
serviceUserAuth = "ssh-userauth" |
||||
serviceSSH = "ssh-connection" |
||||
) |
||||
|
||||
// supportedCiphers specifies the supported ciphers in preference order.
|
||||
var supportedCiphers = []string{ |
||||
"aes128-ctr", "aes192-ctr", "aes256-ctr", |
||||
"aes128-gcm@openssh.com", |
||||
"arcfour256", "arcfour128", |
||||
} |
||||
|
||||
// supportedKexAlgos specifies the supported key-exchange algorithms in
|
||||
// preference order.
|
||||
var supportedKexAlgos = []string{ |
||||
kexAlgoCurve25519SHA256, |
||||
// P384 and P521 are not constant-time yet, but since we don't
|
||||
// reuse ephemeral keys, using them for ECDH should be OK.
|
||||
kexAlgoECDH256, kexAlgoECDH384, kexAlgoECDH521, |
||||
kexAlgoDH14SHA1, kexAlgoDH1SHA1, |
||||
} |
||||
|
||||
// supportedKexAlgos specifies the supported host-key algorithms (i.e. methods
|
||||
// of authenticating servers) in preference order.
|
||||
var supportedHostKeyAlgos = []string{ |
||||
CertAlgoRSAv01, CertAlgoDSAv01, CertAlgoECDSA256v01, |
||||
CertAlgoECDSA384v01, CertAlgoECDSA521v01, CertAlgoED25519v01, |
||||
|
||||
KeyAlgoECDSA256, KeyAlgoECDSA384, KeyAlgoECDSA521, |
||||
KeyAlgoRSA, KeyAlgoDSA, |
||||
|
||||
KeyAlgoED25519, |
||||
} |
||||
|
||||
// supportedMACs specifies a default set of MAC algorithms in preference order.
|
||||
// This is based on RFC 4253, section 6.4, but with hmac-md5 variants removed
|
||||
// because they have reached the end of their useful life.
|
||||
var supportedMACs = []string{ |
||||
"hmac-sha2-256-etm@openssh.com", "hmac-sha2-256", "hmac-sha1", "hmac-sha1-96", |
||||
} |
||||
|
||||
var supportedCompressions = []string{compressionNone} |
||||
|
||||
// hashFuncs keeps the mapping of supported algorithms to their respective
|
||||
// hashes needed for signature verification.
|
||||
var hashFuncs = map[string]crypto.Hash{ |
||||
KeyAlgoRSA: crypto.SHA1, |
||||
KeyAlgoDSA: crypto.SHA1, |
||||
KeyAlgoECDSA256: crypto.SHA256, |
||||
KeyAlgoECDSA384: crypto.SHA384, |
||||
KeyAlgoECDSA521: crypto.SHA512, |
||||
CertAlgoRSAv01: crypto.SHA1, |
||||
CertAlgoDSAv01: crypto.SHA1, |
||||
CertAlgoECDSA256v01: crypto.SHA256, |
||||
CertAlgoECDSA384v01: crypto.SHA384, |
||||
CertAlgoECDSA521v01: crypto.SHA512, |
||||
} |
||||
|
||||
// unexpectedMessageError results when the SSH message that we received didn't
|
||||
// match what we wanted.
|
||||
func unexpectedMessageError(expected, got uint8) error { |
||||
return fmt.Errorf("ssh: unexpected message type %d (expected %d)", got, expected) |
||||
} |
||||
|
||||
// parseError results from a malformed SSH message.
|
||||
func parseError(tag uint8) error { |
||||
return fmt.Errorf("ssh: parse error in message type %d", tag) |
||||
} |
||||
|
||||
func findCommon(what string, client []string, server []string) (common string, err error) { |
||||
for _, c := range client { |
||||
for _, s := range server { |
||||
if c == s { |
||||
return c, nil |
||||
} |
||||
} |
||||
} |
||||
return "", fmt.Errorf("ssh: no common algorithm for %s; client offered: %v, server offered: %v", what, client, server) |
||||
} |
||||
|
||||
type directionAlgorithms struct { |
||||
Cipher string |
||||
MAC string |
||||
Compression string |
||||
} |
||||
|
||||
// rekeyBytes returns a rekeying intervals in bytes.
|
||||
func (a *directionAlgorithms) rekeyBytes() int64 { |
||||
// According to RFC4344 block ciphers should rekey after
|
||||
// 2^(BLOCKSIZE/4) blocks. For all AES flavors BLOCKSIZE is
|
||||
// 128.
|
||||
switch a.Cipher { |
||||
case "aes128-ctr", "aes192-ctr", "aes256-ctr", gcmCipherID, aes128cbcID: |
||||
return 16 * (1 << 32) |
||||
|
||||
} |
||||
|
||||
// For others, stick with RFC4253 recommendation to rekey after 1 Gb of data.
|
||||
return 1 << 30 |
||||
} |
||||
|
||||
type algorithms struct { |
||||
kex string |
||||
hostKey string |
||||
w directionAlgorithms |
||||
r directionAlgorithms |
||||
} |
||||
|
||||
func findAgreedAlgorithms(clientKexInit, serverKexInit *kexInitMsg) (algs *algorithms, err error) { |
||||
result := &algorithms{} |
||||
|
||||
result.kex, err = findCommon("key exchange", clientKexInit.KexAlgos, serverKexInit.KexAlgos) |
||||
if err != nil { |
||||
return |
||||
} |
||||
|
||||
result.hostKey, err = findCommon("host key", clientKexInit.ServerHostKeyAlgos, serverKexInit.ServerHostKeyAlgos) |
||||
if err != nil { |
||||
return |
||||
} |
||||
|
||||
result.w.Cipher, err = findCommon("client to server cipher", clientKexInit.CiphersClientServer, serverKexInit.CiphersClientServer) |
||||
if err != nil { |
||||
return |
||||
} |
||||
|
||||
result.r.Cipher, err = findCommon("server to client cipher", clientKexInit.CiphersServerClient, serverKexInit.CiphersServerClient) |
||||
if err != nil { |
||||
return |
||||
} |
||||
|
||||
result.w.MAC, err = findCommon("client to server MAC", clientKexInit.MACsClientServer, serverKexInit.MACsClientServer) |
||||
if err != nil { |
||||
return |
||||
} |
||||
|
||||
result.r.MAC, err = findCommon("server to client MAC", clientKexInit.MACsServerClient, serverKexInit.MACsServerClient) |
||||
if err != nil { |
||||
return |
||||
} |
||||
|
||||
result.w.Compression, err = findCommon("client to server compression", clientKexInit.CompressionClientServer, serverKexInit.CompressionClientServer) |
||||
if err != nil { |
||||
return |
||||
} |
||||
|
||||
result.r.Compression, err = findCommon("server to client compression", clientKexInit.CompressionServerClient, serverKexInit.CompressionServerClient) |
||||
if err != nil { |
||||
return |
||||
} |
||||
|
||||
return result, nil |
||||
} |
||||
|
||||
// If rekeythreshold is too small, we can't make any progress sending
|
||||
// stuff.
|
||||
const minRekeyThreshold uint64 = 256 |
||||
|
||||
// Config contains configuration data common to both ServerConfig and
|
||||
// ClientConfig.
|
||||
type Config struct { |
||||
// Rand provides the source of entropy for cryptographic
|
||||
// primitives. If Rand is nil, the cryptographic random reader
|
||||
// in package crypto/rand will be used.
|
||||
Rand io.Reader |
||||
|
||||
// The maximum number of bytes sent or received after which a
|
||||
// new key is negotiated. It must be at least 256. If
|
||||
// unspecified, 1 gigabyte is used.
|
||||
RekeyThreshold uint64 |
||||
|
||||
// The allowed key exchanges algorithms. If unspecified then a
|
||||
// default set of algorithms is used.
|
||||
KeyExchanges []string |
||||
|
||||
// The allowed cipher algorithms. If unspecified then a sensible
|
||||
// default is used.
|
||||
Ciphers []string |
||||
|
||||
// The allowed MAC algorithms. If unspecified then a sensible default
|
||||
// is used.
|
||||
MACs []string |
||||
} |
||||
|
||||
// SetDefaults sets sensible values for unset fields in config. This is
|
||||
// exported for testing: Configs passed to SSH functions are copied and have
|
||||
// default values set automatically.
|
||||
func (c *Config) SetDefaults() { |
||||
if c.Rand == nil { |
||||
c.Rand = rand.Reader |
||||
} |
||||
if c.Ciphers == nil { |
||||
c.Ciphers = supportedCiphers |
||||
} |
||||
var ciphers []string |
||||
for _, c := range c.Ciphers { |
||||
if cipherModes[c] != nil { |
||||
// reject the cipher if we have no cipherModes definition
|
||||
ciphers = append(ciphers, c) |
||||
} |
||||
} |
||||
c.Ciphers = ciphers |
||||
|
||||
if c.KeyExchanges == nil { |
||||
c.KeyExchanges = supportedKexAlgos |
||||
} |
||||
|
||||
if c.MACs == nil { |
||||
c.MACs = supportedMACs |
||||
} |
||||
|
||||
if c.RekeyThreshold == 0 { |
||||
// RFC 4253, section 9 suggests rekeying after 1G.
|
||||
c.RekeyThreshold = 1 << 30 |
||||
} |
||||
if c.RekeyThreshold < minRekeyThreshold { |
||||
c.RekeyThreshold = minRekeyThreshold |
||||
} |
||||
} |
||||
|
||||
// buildDataSignedForAuth returns the data that is signed in order to prove
|
||||
// possession of a private key. See RFC 4252, section 7.
|
||||
func buildDataSignedForAuth(sessionId []byte, req userAuthRequestMsg, algo, pubKey []byte) []byte { |
||||
data := struct { |
||||
Session []byte |
||||
Type byte |
||||
User string |
||||
Service string |
||||
Method string |
||||
Sign bool |
||||
Algo []byte |
||||
PubKey []byte |
||||
}{ |
||||
sessionId, |
||||
msgUserAuthRequest, |
||||
req.User, |
||||
req.Service, |
||||
req.Method, |
||||
true, |
||||
algo, |
||||
pubKey, |
||||
} |
||||
return Marshal(data) |
||||
} |
||||
|
||||
func appendU16(buf []byte, n uint16) []byte { |
||||
return append(buf, byte(n>>8), byte(n)) |
||||
} |
||||
|
||||
func appendU32(buf []byte, n uint32) []byte { |
||||
return append(buf, byte(n>>24), byte(n>>16), byte(n>>8), byte(n)) |
||||
} |
||||
|
||||
func appendU64(buf []byte, n uint64) []byte { |
||||
return append(buf, |
||||
byte(n>>56), byte(n>>48), byte(n>>40), byte(n>>32), |
||||
byte(n>>24), byte(n>>16), byte(n>>8), byte(n)) |
||||
} |
||||
|
||||
func appendInt(buf []byte, n int) []byte { |
||||
return appendU32(buf, uint32(n)) |
||||
} |
||||
|
||||
func appendString(buf []byte, s string) []byte { |
||||
buf = appendU32(buf, uint32(len(s))) |
||||
buf = append(buf, s...) |
||||
return buf |
||||
} |
||||
|
||||
func appendBool(buf []byte, b bool) []byte { |
||||
if b { |
||||
return append(buf, 1) |
||||
} |
||||
return append(buf, 0) |
||||
} |
||||
|
||||
// newCond is a helper to hide the fact that there is no usable zero
|
||||
// value for sync.Cond.
|
||||
func newCond() *sync.Cond { return sync.NewCond(new(sync.Mutex)) } |
||||
|
||||
// window represents the buffer available to clients
|
||||
// wishing to write to a channel.
|
||||
type window struct { |
||||
*sync.Cond |
||||
win uint32 // RFC 4254 5.2 says the window size can grow to 2^32-1
|
||||
writeWaiters int |
||||
closed bool |
||||
} |
||||
|
||||
// add adds win to the amount of window available
|
||||
// for consumers.
|
||||
func (w *window) add(win uint32) bool { |
||||
// a zero sized window adjust is a noop.
|
||||
if win == 0 { |
||||
return true |
||||
} |
||||
w.L.Lock() |
||||
if w.win+win < win { |
||||
w.L.Unlock() |
||||
return false |
||||
} |
||||
w.win += win |
||||
// It is unusual that multiple goroutines would be attempting to reserve
|
||||
// window space, but not guaranteed. Use broadcast to notify all waiters
|
||||
// that additional window is available.
|
||||
w.Broadcast() |
||||
w.L.Unlock() |
||||
return true |
||||
} |
||||
|
||||
// close sets the window to closed, so all reservations fail
|
||||
// immediately.
|
||||
func (w *window) close() { |
||||
w.L.Lock() |
||||
w.closed = true |
||||
w.Broadcast() |
||||
w.L.Unlock() |
||||
} |
||||
|
||||
// reserve reserves win from the available window capacity.
|
||||
// If no capacity remains, reserve will block. reserve may
|
||||
// return less than requested.
|
||||
func (w *window) reserve(win uint32) (uint32, error) { |
||||
var err error |
||||
w.L.Lock() |
||||
w.writeWaiters++ |
||||
w.Broadcast() |
||||
for w.win == 0 && !w.closed { |
||||
w.Wait() |
||||
} |
||||
w.writeWaiters-- |
||||
if w.win < win { |
||||
win = w.win |
||||
} |
||||
w.win -= win |
||||
if w.closed { |
||||
err = io.EOF |
||||
} |
||||
w.L.Unlock() |
||||
return win, err |
||||
} |
||||
|
||||
// waitWriterBlocked waits until some goroutine is blocked for further
|
||||
// writes. It is used in tests only.
|
||||
func (w *window) waitWriterBlocked() { |
||||
w.Cond.L.Lock() |
||||
for w.writeWaiters == 0 { |
||||
w.Cond.Wait() |
||||
} |
||||
w.Cond.L.Unlock() |
||||
} |
@ -0,0 +1,143 @@ |
||||
// Copyright 2013 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package ssh |
||||
|
||||
import ( |
||||
"fmt" |
||||
"net" |
||||
) |
||||
|
||||
// OpenChannelError is returned if the other side rejects an
|
||||
// OpenChannel request.
|
||||
type OpenChannelError struct { |
||||
Reason RejectionReason |
||||
Message string |
||||
} |
||||
|
||||
func (e *OpenChannelError) Error() string { |
||||
return fmt.Sprintf("ssh: rejected: %s (%s)", e.Reason, e.Message) |
||||
} |
||||
|
||||
// ConnMetadata holds metadata for the connection.
|
||||
type ConnMetadata interface { |
||||
// User returns the user ID for this connection.
|
||||
User() string |
||||
|
||||
// SessionID returns the sesson hash, also denoted by H.
|
||||
SessionID() []byte |
||||
|
||||
// ClientVersion returns the client's version string as hashed
|
||||
// into the session ID.
|
||||
ClientVersion() []byte |
||||
|
||||
// ServerVersion returns the server's version string as hashed
|
||||
// into the session ID.
|
||||
ServerVersion() []byte |
||||
|
||||
// RemoteAddr returns the remote address for this connection.
|
||||
RemoteAddr() net.Addr |
||||
|
||||
// LocalAddr returns the local address for this connection.
|
||||
LocalAddr() net.Addr |
||||
} |
||||
|
||||
// Conn represents an SSH connection for both server and client roles.
|
||||
// Conn is the basis for implementing an application layer, such
|
||||
// as ClientConn, which implements the traditional shell access for
|
||||
// clients.
|
||||
type Conn interface { |
||||
ConnMetadata |
||||
|
||||
// SendRequest sends a global request, and returns the
|
||||
// reply. If wantReply is true, it returns the response status
|
||||
// and payload. See also RFC4254, section 4.
|
||||
SendRequest(name string, wantReply bool, payload []byte) (bool, []byte, error) |
||||
|
||||
// OpenChannel tries to open an channel. If the request is
|
||||
// rejected, it returns *OpenChannelError. On success it returns
|
||||
// the SSH Channel and a Go channel for incoming, out-of-band
|
||||
// requests. The Go channel must be serviced, or the
|
||||
// connection will hang.
|
||||
OpenChannel(name string, data []byte) (Channel, <-chan *Request, error) |
||||
|
||||
// Close closes the underlying network connection
|
||||
Close() error |
||||
|
||||
// Wait blocks until the connection has shut down, and returns the
|
||||
// error causing the shutdown.
|
||||
Wait() error |
||||
|
||||
// TODO(hanwen): consider exposing:
|
||||
// RequestKeyChange
|
||||
// Disconnect
|
||||
} |
||||
|
||||
// DiscardRequests consumes and rejects all requests from the
|
||||
// passed-in channel.
|
||||
func DiscardRequests(in <-chan *Request) { |
||||
for req := range in { |
||||
if req.WantReply { |
||||
req.Reply(false, nil) |
||||
} |
||||
} |
||||
} |
||||
|
||||
// A connection represents an incoming connection.
|
||||
type connection struct { |
||||
transport *handshakeTransport |
||||
sshConn |
||||
|
||||
// The connection protocol.
|
||||
*mux |
||||
} |
||||
|
||||
func (c *connection) Close() error { |
||||
return c.sshConn.conn.Close() |
||||
} |
||||
|
||||
// sshconn provides net.Conn metadata, but disallows direct reads and
|
||||
// writes.
|
||||
type sshConn struct { |
||||
conn net.Conn |
||||
|
||||
user string |
||||
sessionID []byte |
||||
clientVersion []byte |
||||
serverVersion []byte |
||||
} |
||||
|
||||
func dup(src []byte) []byte { |
||||
dst := make([]byte, len(src)) |
||||
copy(dst, src) |
||||
return dst |
||||
} |
||||
|
||||
func (c *sshConn) User() string { |
||||
return c.user |
||||
} |
||||
|
||||
func (c *sshConn) RemoteAddr() net.Addr { |
||||
return c.conn.RemoteAddr() |
||||
} |
||||
|
||||
func (c *sshConn) Close() error { |
||||
return c.conn.Close() |
||||
} |
||||
|
||||
func (c *sshConn) LocalAddr() net.Addr { |
||||
return c.conn.LocalAddr() |
||||
} |
||||
|
||||
func (c *sshConn) SessionID() []byte { |
||||
return dup(c.sessionID) |
||||
} |
||||
|
||||
func (c *sshConn) ClientVersion() []byte { |
||||
return dup(c.clientVersion) |
||||
} |
||||
|
||||
func (c *sshConn) ServerVersion() []byte { |
||||
return dup(c.serverVersion) |
||||
} |
@ -0,0 +1,18 @@ |
||||
// Copyright 2011 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
/* |
||||
Package ssh implements an SSH client and server. |
||||
|
||||
SSH is a transport security protocol, an authentication protocol and a |
||||
family of application protocols. The most typical application level |
||||
protocol is a remote shell and this is specifically implemented. However, |
||||
the multiplexed nature of SSH is exposed to users that wish to support |
||||
others. |
||||
|
||||
References: |
||||
[PROTOCOL.certkeys]: http://cvsweb.openbsd.org/cgi-bin/cvsweb/src/usr.bin/ssh/PROTOCOL.certkeys?rev=HEAD
|
||||
[SSH-PARAMETERS]: http://www.iana.org/assignments/ssh-parameters/ssh-parameters.xml#ssh-parameters-1
|
||||
*/ |
||||
package ssh // import "golang.org/x/crypto/ssh"
|
@ -0,0 +1,625 @@ |
||||
// Copyright 2013 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package ssh |
||||
|
||||
import ( |
||||
"crypto/rand" |
||||
"errors" |
||||
"fmt" |
||||
"io" |
||||
"log" |
||||
"net" |
||||
"sync" |
||||
) |
||||
|
||||
// debugHandshake, if set, prints messages sent and received. Key
|
||||
// exchange messages are printed as if DH were used, so the debug
|
||||
// messages are wrong when using ECDH.
|
||||
const debugHandshake = false |
||||
|
||||
// chanSize sets the amount of buffering SSH connections. This is
|
||||
// primarily for testing: setting chanSize=0 uncovers deadlocks more
|
||||
// quickly.
|
||||
const chanSize = 16 |
||||
|
||||
// keyingTransport is a packet based transport that supports key
|
||||
// changes. It need not be thread-safe. It should pass through
|
||||
// msgNewKeys in both directions.
|
||||
type keyingTransport interface { |
||||
packetConn |
||||
|
||||
// prepareKeyChange sets up a key change. The key change for a
|
||||
// direction will be effected if a msgNewKeys message is sent
|
||||
// or received.
|
||||
prepareKeyChange(*algorithms, *kexResult) error |
||||
} |
||||
|
||||
// handshakeTransport implements rekeying on top of a keyingTransport
|
||||
// and offers a thread-safe writePacket() interface.
|
||||
type handshakeTransport struct { |
||||
conn keyingTransport |
||||
config *Config |
||||
|
||||
serverVersion []byte |
||||
clientVersion []byte |
||||
|
||||
// hostKeys is non-empty if we are the server. In that case,
|
||||
// it contains all host keys that can be used to sign the
|
||||
// connection.
|
||||
hostKeys []Signer |
||||
|
||||
// hostKeyAlgorithms is non-empty if we are the client. In that case,
|
||||
// we accept these key types from the server as host key.
|
||||
hostKeyAlgorithms []string |
||||
|
||||
// On read error, incoming is closed, and readError is set.
|
||||
incoming chan []byte |
||||
readError error |
||||
|
||||
mu sync.Mutex |
||||
writeError error |
||||
sentInitPacket []byte |
||||
sentInitMsg *kexInitMsg |
||||
pendingPackets [][]byte // Used when a key exchange is in progress.
|
||||
|
||||
// If the read loop wants to schedule a kex, it pings this
|
||||
// channel, and the write loop will send out a kex
|
||||
// message.
|
||||
requestKex chan struct{} |
||||
|
||||
// If the other side requests or confirms a kex, its kexInit
|
||||
// packet is sent here for the write loop to find it.
|
||||
startKex chan *pendingKex |
||||
|
||||
// data for host key checking
|
||||
hostKeyCallback func(hostname string, remote net.Addr, key PublicKey) error |
||||
dialAddress string |
||||
remoteAddr net.Addr |
||||
|
||||
// Algorithms agreed in the last key exchange.
|
||||
algorithms *algorithms |
||||
|
||||
readPacketsLeft uint32 |
||||
readBytesLeft int64 |
||||
|
||||
writePacketsLeft uint32 |
||||
writeBytesLeft int64 |
||||
|
||||
// The session ID or nil if first kex did not complete yet.
|
||||
sessionID []byte |
||||
} |
||||
|
||||
type pendingKex struct { |
||||
otherInit []byte |
||||
done chan error |
||||
} |
||||
|
||||
func newHandshakeTransport(conn keyingTransport, config *Config, clientVersion, serverVersion []byte) *handshakeTransport { |
||||
t := &handshakeTransport{ |
||||
conn: conn, |
||||
serverVersion: serverVersion, |
||||
clientVersion: clientVersion, |
||||
incoming: make(chan []byte, chanSize), |
||||
requestKex: make(chan struct{}, 1), |
||||
startKex: make(chan *pendingKex, 1), |
||||
|
||||
config: config, |
||||
} |
||||
|
||||
// We always start with a mandatory key exchange.
|
||||
t.requestKex <- struct{}{} |
||||
return t |
||||
} |
||||
|
||||
func newClientTransport(conn keyingTransport, clientVersion, serverVersion []byte, config *ClientConfig, dialAddr string, addr net.Addr) *handshakeTransport { |
||||
t := newHandshakeTransport(conn, &config.Config, clientVersion, serverVersion) |
||||
t.dialAddress = dialAddr |
||||
t.remoteAddr = addr |
||||
t.hostKeyCallback = config.HostKeyCallback |
||||
if config.HostKeyAlgorithms != nil { |
||||
t.hostKeyAlgorithms = config.HostKeyAlgorithms |
||||
} else { |
||||
t.hostKeyAlgorithms = supportedHostKeyAlgos |
||||
} |
||||
go t.readLoop() |
||||
go t.kexLoop() |
||||
return t |
||||
} |
||||
|
||||
func newServerTransport(conn keyingTransport, clientVersion, serverVersion []byte, config *ServerConfig) *handshakeTransport { |
||||
t := newHandshakeTransport(conn, &config.Config, clientVersion, serverVersion) |
||||
t.hostKeys = config.hostKeys |
||||
go t.readLoop() |
||||
go t.kexLoop() |
||||
return t |
||||
} |
||||
|
||||
func (t *handshakeTransport) getSessionID() []byte { |
||||
return t.sessionID |
||||
} |
||||
|
||||
// waitSession waits for the session to be established. This should be
|
||||
// the first thing to call after instantiating handshakeTransport.
|
||||
func (t *handshakeTransport) waitSession() error { |
||||
p, err := t.readPacket() |
||||
if err != nil { |
||||
return err |
||||
} |
||||
if p[0] != msgNewKeys { |
||||
return fmt.Errorf("ssh: first packet should be msgNewKeys") |
||||
} |
||||
|
||||
return nil |
||||
} |
||||
|
||||
func (t *handshakeTransport) id() string { |
||||
if len(t.hostKeys) > 0 { |
||||
return "server" |
||||
} |
||||
return "client" |
||||
} |
||||
|
||||
func (t *handshakeTransport) printPacket(p []byte, write bool) { |
||||
action := "got" |
||||
if write { |
||||
action = "sent" |
||||
} |
||||
|
||||
if p[0] == msgChannelData || p[0] == msgChannelExtendedData { |
||||
log.Printf("%s %s data (packet %d bytes)", t.id(), action, len(p)) |
||||
} else { |
||||
msg, err := decode(p) |
||||
log.Printf("%s %s %T %v (%v)", t.id(), action, msg, msg, err) |
||||
} |
||||
} |
||||
|
||||
func (t *handshakeTransport) readPacket() ([]byte, error) { |
||||
p, ok := <-t.incoming |
||||
if !ok { |
||||
return nil, t.readError |
||||
} |
||||
return p, nil |
||||
} |
||||
|
||||
func (t *handshakeTransport) readLoop() { |
||||
first := true |
||||
for { |
||||
p, err := t.readOnePacket(first) |
||||
first = false |
||||
if err != nil { |
||||
t.readError = err |
||||
close(t.incoming) |
||||
break |
||||
} |
||||
if p[0] == msgIgnore || p[0] == msgDebug { |
||||
continue |
||||
} |
||||
t.incoming <- p |
||||
} |
||||
|
||||
// Stop writers too.
|
||||
t.recordWriteError(t.readError) |
||||
|
||||
// Unblock the writer should it wait for this.
|
||||
close(t.startKex) |
||||
|
||||
// Don't close t.requestKex; it's also written to from writePacket.
|
||||
} |
||||
|
||||
func (t *handshakeTransport) pushPacket(p []byte) error { |
||||
if debugHandshake { |
||||
t.printPacket(p, true) |
||||
} |
||||
return t.conn.writePacket(p) |
||||
} |
||||
|
||||
func (t *handshakeTransport) getWriteError() error { |
||||
t.mu.Lock() |
||||
defer t.mu.Unlock() |
||||
return t.writeError |
||||
} |
||||
|
||||
func (t *handshakeTransport) recordWriteError(err error) { |
||||
t.mu.Lock() |
||||
defer t.mu.Unlock() |
||||
if t.writeError == nil && err != nil { |
||||
t.writeError = err |
||||
} |
||||
} |
||||
|
||||
func (t *handshakeTransport) requestKeyExchange() { |
||||
select { |
||||
case t.requestKex <- struct{}{}: |
||||
default: |
||||
// something already requested a kex, so do nothing.
|
||||
} |
||||
} |
||||
|
||||
func (t *handshakeTransport) kexLoop() { |
||||
|
||||
write: |
||||
for t.getWriteError() == nil { |
||||
var request *pendingKex |
||||
var sent bool |
||||
|
||||
for request == nil || !sent { |
||||
var ok bool |
||||
select { |
||||
case request, ok = <-t.startKex: |
||||
if !ok { |
||||
break write |
||||
} |
||||
case <-t.requestKex: |
||||
break |
||||
} |
||||
|
||||
if !sent { |
||||
if err := t.sendKexInit(); err != nil { |
||||
t.recordWriteError(err) |
||||
break |
||||
} |
||||
sent = true |
||||
} |
||||
} |
||||
|
||||
if err := t.getWriteError(); err != nil { |
||||
if request != nil { |
||||
request.done <- err |
||||
} |
||||
break |
||||
} |
||||
|
||||
// We're not servicing t.requestKex, but that is OK:
|
||||
// we never block on sending to t.requestKex.
|
||||
|
||||
// We're not servicing t.startKex, but the remote end
|
||||
// has just sent us a kexInitMsg, so it can't send
|
||||
// another key change request, until we close the done
|
||||
// channel on the pendingKex request.
|
||||
|
||||
err := t.enterKeyExchange(request.otherInit) |
||||
|
||||
t.mu.Lock() |
||||
t.writeError = err |
||||
t.sentInitPacket = nil |
||||
t.sentInitMsg = nil |
||||
t.writePacketsLeft = packetRekeyThreshold |
||||
if t.config.RekeyThreshold > 0 { |
||||
t.writeBytesLeft = int64(t.config.RekeyThreshold) |
||||
} else if t.algorithms != nil { |
||||
t.writeBytesLeft = t.algorithms.w.rekeyBytes() |
||||
} |
||||
|
||||
// we have completed the key exchange. Since the
|
||||
// reader is still blocked, it is safe to clear out
|
||||
// the requestKex channel. This avoids the situation
|
||||
// where: 1) we consumed our own request for the
|
||||
// initial kex, and 2) the kex from the remote side
|
||||
// caused another send on the requestKex channel,
|
||||
clear: |
||||
for { |
||||
select { |
||||
case <-t.requestKex: |
||||
//
|
||||
default: |
||||
break clear |
||||
} |
||||
} |
||||
|
||||
request.done <- t.writeError |
||||
|
||||
// kex finished. Push packets that we received while
|
||||
// the kex was in progress. Don't look at t.startKex
|
||||
// and don't increment writtenSinceKex: if we trigger
|
||||
// another kex while we are still busy with the last
|
||||
// one, things will become very confusing.
|
||||
for _, p := range t.pendingPackets { |
||||
t.writeError = t.pushPacket(p) |
||||
if t.writeError != nil { |
||||
break |
||||
} |
||||
} |
||||
t.pendingPackets = t.pendingPackets[:0] |
||||
t.mu.Unlock() |
||||
} |
||||
|
||||
// drain startKex channel. We don't service t.requestKex
|
||||
// because nobody does blocking sends there.
|
||||
go func() { |
||||
for init := range t.startKex { |
||||
init.done <- t.writeError |
||||
} |
||||
}() |
||||
|
||||
// Unblock reader.
|
||||
t.conn.Close() |
||||
} |
||||
|
||||
// The protocol uses uint32 for packet counters, so we can't let them
|
||||
// reach 1<<32. We will actually read and write more packets than
|
||||
// this, though: the other side may send more packets, and after we
|
||||
// hit this limit on writing we will send a few more packets for the
|
||||
// key exchange itself.
|
||||
const packetRekeyThreshold = (1 << 31) |
||||
|
||||
func (t *handshakeTransport) readOnePacket(first bool) ([]byte, error) { |
||||
p, err := t.conn.readPacket() |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
if t.readPacketsLeft > 0 { |
||||
t.readPacketsLeft-- |
||||
} else { |
||||
t.requestKeyExchange() |
||||
} |
||||
|
||||
if t.readBytesLeft > 0 { |
||||
t.readBytesLeft -= int64(len(p)) |
||||
} else { |
||||
t.requestKeyExchange() |
||||
} |
||||
|
||||
if debugHandshake { |
||||
t.printPacket(p, false) |
||||
} |
||||
|
||||
if first && p[0] != msgKexInit { |
||||
return nil, fmt.Errorf("ssh: first packet should be msgKexInit") |
||||
} |
||||
|
||||
if p[0] != msgKexInit { |
||||
return p, nil |
||||
} |
||||
|
||||
firstKex := t.sessionID == nil |
||||
|
||||
kex := pendingKex{ |
||||
done: make(chan error, 1), |
||||
otherInit: p, |
||||
} |
||||
t.startKex <- &kex |
||||
err = <-kex.done |
||||
|
||||
if debugHandshake { |
||||
log.Printf("%s exited key exchange (first %v), err %v", t.id(), firstKex, err) |
||||
} |
||||
|
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
t.readPacketsLeft = packetRekeyThreshold |
||||
if t.config.RekeyThreshold > 0 { |
||||
t.readBytesLeft = int64(t.config.RekeyThreshold) |
||||
} else { |
||||
t.readBytesLeft = t.algorithms.r.rekeyBytes() |
||||
} |
||||
|
||||
// By default, a key exchange is hidden from higher layers by
|
||||
// translating it into msgIgnore.
|
||||
successPacket := []byte{msgIgnore} |
||||
if firstKex { |
||||
// sendKexInit() for the first kex waits for
|
||||
// msgNewKeys so the authentication process is
|
||||
// guaranteed to happen over an encrypted transport.
|
||||
successPacket = []byte{msgNewKeys} |
||||
} |
||||
|
||||
return successPacket, nil |
||||
} |
||||
|
||||
// sendKexInit sends a key change message.
|
||||
func (t *handshakeTransport) sendKexInit() error { |
||||
t.mu.Lock() |
||||
defer t.mu.Unlock() |
||||
if t.sentInitMsg != nil { |
||||
// kexInits may be sent either in response to the other side,
|
||||
// or because our side wants to initiate a key change, so we
|
||||
// may have already sent a kexInit. In that case, don't send a
|
||||
// second kexInit.
|
||||
return nil |
||||
} |
||||
|
||||
msg := &kexInitMsg{ |
||||
KexAlgos: t.config.KeyExchanges, |
||||
CiphersClientServer: t.config.Ciphers, |
||||
CiphersServerClient: t.config.Ciphers, |
||||
MACsClientServer: t.config.MACs, |
||||
MACsServerClient: t.config.MACs, |
||||
CompressionClientServer: supportedCompressions, |
||||
CompressionServerClient: supportedCompressions, |
||||
} |
||||
io.ReadFull(rand.Reader, msg.Cookie[:]) |
||||
|
||||
if len(t.hostKeys) > 0 { |
||||
for _, k := range t.hostKeys { |
||||
msg.ServerHostKeyAlgos = append( |
||||
msg.ServerHostKeyAlgos, k.PublicKey().Type()) |
||||
} |
||||
} else { |
||||
msg.ServerHostKeyAlgos = t.hostKeyAlgorithms |
||||
} |
||||
packet := Marshal(msg) |
||||
|
||||
// writePacket destroys the contents, so save a copy.
|
||||
packetCopy := make([]byte, len(packet)) |
||||
copy(packetCopy, packet) |
||||
|
||||
if err := t.pushPacket(packetCopy); err != nil { |
||||
return err |
||||
} |
||||
|
||||
t.sentInitMsg = msg |
||||
t.sentInitPacket = packet |
||||
|
||||
return nil |
||||
} |
||||
|
||||
func (t *handshakeTransport) writePacket(p []byte) error { |
||||
switch p[0] { |
||||
case msgKexInit: |
||||
return errors.New("ssh: only handshakeTransport can send kexInit") |
||||
case msgNewKeys: |
||||
return errors.New("ssh: only handshakeTransport can send newKeys") |
||||
} |
||||
|
||||
t.mu.Lock() |
||||
defer t.mu.Unlock() |
||||
if t.writeError != nil { |
||||
return t.writeError |
||||
} |
||||
|
||||
if t.sentInitMsg != nil { |
||||
// Copy the packet so the writer can reuse the buffer.
|
||||
cp := make([]byte, len(p)) |
||||
copy(cp, p) |
||||
t.pendingPackets = append(t.pendingPackets, cp) |
||||
return nil |
||||
} |
||||
|
||||
if t.writeBytesLeft > 0 { |
||||
t.writeBytesLeft -= int64(len(p)) |
||||
} else { |
||||
t.requestKeyExchange() |
||||
} |
||||
|
||||
if t.writePacketsLeft > 0 { |
||||
t.writePacketsLeft-- |
||||
} else { |
||||
t.requestKeyExchange() |
||||
} |
||||
|
||||
if err := t.pushPacket(p); err != nil { |
||||
t.writeError = err |
||||
} |
||||
|
||||
return nil |
||||
} |
||||
|
||||
func (t *handshakeTransport) Close() error { |
||||
return t.conn.Close() |
||||
} |
||||
|
||||
func (t *handshakeTransport) enterKeyExchange(otherInitPacket []byte) error { |
||||
if debugHandshake { |
||||
log.Printf("%s entered key exchange", t.id()) |
||||
} |
||||
|
||||
otherInit := &kexInitMsg{} |
||||
if err := Unmarshal(otherInitPacket, otherInit); err != nil { |
||||
return err |
||||
} |
||||
|
||||
magics := handshakeMagics{ |
||||
clientVersion: t.clientVersion, |
||||
serverVersion: t.serverVersion, |
||||
clientKexInit: otherInitPacket, |
||||
serverKexInit: t.sentInitPacket, |
||||
} |
||||
|
||||
clientInit := otherInit |
||||
serverInit := t.sentInitMsg |
||||
if len(t.hostKeys) == 0 { |
||||
clientInit, serverInit = serverInit, clientInit |
||||
|
||||
magics.clientKexInit = t.sentInitPacket |
||||
magics.serverKexInit = otherInitPacket |
||||
} |
||||
|
||||
var err error |
||||
t.algorithms, err = findAgreedAlgorithms(clientInit, serverInit) |
||||
if err != nil { |
||||
return err |
||||
} |
||||
|
||||
// We don't send FirstKexFollows, but we handle receiving it.
|
||||
//
|
||||
// RFC 4253 section 7 defines the kex and the agreement method for
|
||||
// first_kex_packet_follows. It states that the guessed packet
|
||||
// should be ignored if the "kex algorithm and/or the host
|
||||
// key algorithm is guessed wrong (server and client have
|
||||
// different preferred algorithm), or if any of the other
|
||||
// algorithms cannot be agreed upon". The other algorithms have
|
||||
// already been checked above so the kex algorithm and host key
|
||||
// algorithm are checked here.
|
||||
if otherInit.FirstKexFollows && (clientInit.KexAlgos[0] != serverInit.KexAlgos[0] || clientInit.ServerHostKeyAlgos[0] != serverInit.ServerHostKeyAlgos[0]) { |
||||
// other side sent a kex message for the wrong algorithm,
|
||||
// which we have to ignore.
|
||||
if _, err := t.conn.readPacket(); err != nil { |
||||
return err |
||||
} |
||||
} |
||||
|
||||
kex, ok := kexAlgoMap[t.algorithms.kex] |
||||
if !ok { |
||||
return fmt.Errorf("ssh: unexpected key exchange algorithm %v", t.algorithms.kex) |
||||
} |
||||
|
||||
var result *kexResult |
||||
if len(t.hostKeys) > 0 { |
||||
result, err = t.server(kex, t.algorithms, &magics) |
||||
} else { |
||||
result, err = t.client(kex, t.algorithms, &magics) |
||||
} |
||||
|
||||
if err != nil { |
||||
return err |
||||
} |
||||
|
||||
if t.sessionID == nil { |
||||
t.sessionID = result.H |
||||
} |
||||
result.SessionID = t.sessionID |
||||
|
||||
t.conn.prepareKeyChange(t.algorithms, result) |
||||
if err = t.conn.writePacket([]byte{msgNewKeys}); err != nil { |
||||
return err |
||||
} |
||||
if packet, err := t.conn.readPacket(); err != nil { |
||||
return err |
||||
} else if packet[0] != msgNewKeys { |
||||
return unexpectedMessageError(msgNewKeys, packet[0]) |
||||
} |
||||
|
||||
return nil |
||||
} |
||||
|
||||
func (t *handshakeTransport) server(kex kexAlgorithm, algs *algorithms, magics *handshakeMagics) (*kexResult, error) { |
||||
var hostKey Signer |
||||
for _, k := range t.hostKeys { |
||||
if algs.hostKey == k.PublicKey().Type() { |
||||
hostKey = k |
||||
} |
||||
} |
||||
|
||||
r, err := kex.Server(t.conn, t.config.Rand, magics, hostKey) |
||||
return r, err |
||||
} |
||||
|
||||
func (t *handshakeTransport) client(kex kexAlgorithm, algs *algorithms, magics *handshakeMagics) (*kexResult, error) { |
||||
result, err := kex.Client(t.conn, t.config.Rand, magics) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
hostKey, err := ParsePublicKey(result.HostKey) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
if err := verifyHostKeySignature(hostKey, result); err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
if t.hostKeyCallback != nil { |
||||
err = t.hostKeyCallback(t.dialAddress, t.remoteAddr, hostKey) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
} |
||||
|
||||
return result, nil |
||||
} |
@ -0,0 +1,540 @@ |
||||
// Copyright 2013 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package ssh |
||||
|
||||
import ( |
||||
"crypto" |
||||
"crypto/ecdsa" |
||||
"crypto/elliptic" |
||||
"crypto/rand" |
||||
"crypto/subtle" |
||||
"errors" |
||||
"io" |
||||
"math/big" |
||||
|
||||
"golang.org/x/crypto/curve25519" |
||||
) |
||||
|
||||
const ( |
||||
kexAlgoDH1SHA1 = "diffie-hellman-group1-sha1" |
||||
kexAlgoDH14SHA1 = "diffie-hellman-group14-sha1" |
||||
kexAlgoECDH256 = "ecdh-sha2-nistp256" |
||||
kexAlgoECDH384 = "ecdh-sha2-nistp384" |
||||
kexAlgoECDH521 = "ecdh-sha2-nistp521" |
||||
kexAlgoCurve25519SHA256 = "curve25519-sha256@libssh.org" |
||||
) |
||||
|
||||
// kexResult captures the outcome of a key exchange.
|
||||
type kexResult struct { |
||||
// Session hash. See also RFC 4253, section 8.
|
||||
H []byte |
||||
|
||||
// Shared secret. See also RFC 4253, section 8.
|
||||
K []byte |
||||
|
||||
// Host key as hashed into H.
|
||||
HostKey []byte |
||||
|
||||
// Signature of H.
|
||||
Signature []byte |
||||
|
||||
// A cryptographic hash function that matches the security
|
||||
// level of the key exchange algorithm. It is used for
|
||||
// calculating H, and for deriving keys from H and K.
|
||||
Hash crypto.Hash |
||||
|
||||
// The session ID, which is the first H computed. This is used
|
||||
// to derive key material inside the transport.
|
||||
SessionID []byte |
||||
} |
||||
|
||||
// handshakeMagics contains data that is always included in the
|
||||
// session hash.
|
||||
type handshakeMagics struct { |
||||
clientVersion, serverVersion []byte |
||||
clientKexInit, serverKexInit []byte |
||||
} |
||||
|
||||
func (m *handshakeMagics) write(w io.Writer) { |
||||
writeString(w, m.clientVersion) |
||||
writeString(w, m.serverVersion) |
||||
writeString(w, m.clientKexInit) |
||||
writeString(w, m.serverKexInit) |
||||
} |
||||
|
||||
// kexAlgorithm abstracts different key exchange algorithms.
|
||||
type kexAlgorithm interface { |
||||
// Server runs server-side key agreement, signing the result
|
||||
// with a hostkey.
|
||||
Server(p packetConn, rand io.Reader, magics *handshakeMagics, s Signer) (*kexResult, error) |
||||
|
||||
// Client runs the client-side key agreement. Caller is
|
||||
// responsible for verifying the host key signature.
|
||||
Client(p packetConn, rand io.Reader, magics *handshakeMagics) (*kexResult, error) |
||||
} |
||||
|
||||
// dhGroup is a multiplicative group suitable for implementing Diffie-Hellman key agreement.
|
||||
type dhGroup struct { |
||||
g, p, pMinus1 *big.Int |
||||
} |
||||
|
||||
func (group *dhGroup) diffieHellman(theirPublic, myPrivate *big.Int) (*big.Int, error) { |
||||
if theirPublic.Cmp(bigOne) <= 0 || theirPublic.Cmp(group.pMinus1) >= 0 { |
||||
return nil, errors.New("ssh: DH parameter out of bounds") |
||||
} |
||||
return new(big.Int).Exp(theirPublic, myPrivate, group.p), nil |
||||
} |
||||
|
||||
func (group *dhGroup) Client(c packetConn, randSource io.Reader, magics *handshakeMagics) (*kexResult, error) { |
||||
hashFunc := crypto.SHA1 |
||||
|
||||
var x *big.Int |
||||
for { |
||||
var err error |
||||
if x, err = rand.Int(randSource, group.pMinus1); err != nil { |
||||
return nil, err |
||||
} |
||||
if x.Sign() > 0 { |
||||
break |
||||
} |
||||
} |
||||
|
||||
X := new(big.Int).Exp(group.g, x, group.p) |
||||
kexDHInit := kexDHInitMsg{ |
||||
X: X, |
||||
} |
||||
if err := c.writePacket(Marshal(&kexDHInit)); err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
packet, err := c.readPacket() |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
var kexDHReply kexDHReplyMsg |
||||
if err = Unmarshal(packet, &kexDHReply); err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
kInt, err := group.diffieHellman(kexDHReply.Y, x) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
h := hashFunc.New() |
||||
magics.write(h) |
||||
writeString(h, kexDHReply.HostKey) |
||||
writeInt(h, X) |
||||
writeInt(h, kexDHReply.Y) |
||||
K := make([]byte, intLength(kInt)) |
||||
marshalInt(K, kInt) |
||||
h.Write(K) |
||||
|
||||
return &kexResult{ |
||||
H: h.Sum(nil), |
||||
K: K, |
||||
HostKey: kexDHReply.HostKey, |
||||
Signature: kexDHReply.Signature, |
||||
Hash: crypto.SHA1, |
||||
}, nil |
||||
} |
||||
|
||||
func (group *dhGroup) Server(c packetConn, randSource io.Reader, magics *handshakeMagics, priv Signer) (result *kexResult, err error) { |
||||
hashFunc := crypto.SHA1 |
||||
packet, err := c.readPacket() |
||||
if err != nil { |
||||
return |
||||
} |
||||
var kexDHInit kexDHInitMsg |
||||
if err = Unmarshal(packet, &kexDHInit); err != nil { |
||||
return |
||||
} |
||||
|
||||
var y *big.Int |
||||
for { |
||||
if y, err = rand.Int(randSource, group.pMinus1); err != nil { |
||||
return |
||||
} |
||||
if y.Sign() > 0 { |
||||
break |
||||
} |
||||
} |
||||
|
||||
Y := new(big.Int).Exp(group.g, y, group.p) |
||||
kInt, err := group.diffieHellman(kexDHInit.X, y) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
hostKeyBytes := priv.PublicKey().Marshal() |
||||
|
||||
h := hashFunc.New() |
||||
magics.write(h) |
||||
writeString(h, hostKeyBytes) |
||||
writeInt(h, kexDHInit.X) |
||||
writeInt(h, Y) |
||||
|
||||
K := make([]byte, intLength(kInt)) |
||||
marshalInt(K, kInt) |
||||
h.Write(K) |
||||
|
||||
H := h.Sum(nil) |
||||
|
||||
// H is already a hash, but the hostkey signing will apply its
|
||||
// own key-specific hash algorithm.
|
||||
sig, err := signAndMarshal(priv, randSource, H) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
kexDHReply := kexDHReplyMsg{ |
||||
HostKey: hostKeyBytes, |
||||
Y: Y, |
||||
Signature: sig, |
||||
} |
||||
packet = Marshal(&kexDHReply) |
||||
|
||||
err = c.writePacket(packet) |
||||
return &kexResult{ |
||||
H: H, |
||||
K: K, |
||||
HostKey: hostKeyBytes, |
||||
Signature: sig, |
||||
Hash: crypto.SHA1, |
||||
}, nil |
||||
} |
||||
|
||||
// ecdh performs Elliptic Curve Diffie-Hellman key exchange as
|
||||
// described in RFC 5656, section 4.
|
||||
type ecdh struct { |
||||
curve elliptic.Curve |
||||
} |
||||
|
||||
func (kex *ecdh) Client(c packetConn, rand io.Reader, magics *handshakeMagics) (*kexResult, error) { |
||||
ephKey, err := ecdsa.GenerateKey(kex.curve, rand) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
kexInit := kexECDHInitMsg{ |
||||
ClientPubKey: elliptic.Marshal(kex.curve, ephKey.PublicKey.X, ephKey.PublicKey.Y), |
||||
} |
||||
|
||||
serialized := Marshal(&kexInit) |
||||
if err := c.writePacket(serialized); err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
packet, err := c.readPacket() |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
var reply kexECDHReplyMsg |
||||
if err = Unmarshal(packet, &reply); err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
x, y, err := unmarshalECKey(kex.curve, reply.EphemeralPubKey) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
// generate shared secret
|
||||
secret, _ := kex.curve.ScalarMult(x, y, ephKey.D.Bytes()) |
||||
|
||||
h := ecHash(kex.curve).New() |
||||
magics.write(h) |
||||
writeString(h, reply.HostKey) |
||||
writeString(h, kexInit.ClientPubKey) |
||||
writeString(h, reply.EphemeralPubKey) |
||||
K := make([]byte, intLength(secret)) |
||||
marshalInt(K, secret) |
||||
h.Write(K) |
||||
|
||||
return &kexResult{ |
||||
H: h.Sum(nil), |
||||
K: K, |
||||
HostKey: reply.HostKey, |
||||
Signature: reply.Signature, |
||||
Hash: ecHash(kex.curve), |
||||
}, nil |
||||
} |
||||
|
||||
// unmarshalECKey parses and checks an EC key.
|
||||
func unmarshalECKey(curve elliptic.Curve, pubkey []byte) (x, y *big.Int, err error) { |
||||
x, y = elliptic.Unmarshal(curve, pubkey) |
||||
if x == nil { |
||||
return nil, nil, errors.New("ssh: elliptic.Unmarshal failure") |
||||
} |
||||
if !validateECPublicKey(curve, x, y) { |
||||
return nil, nil, errors.New("ssh: public key not on curve") |
||||
} |
||||
return x, y, nil |
||||
} |
||||
|
||||
// validateECPublicKey checks that the point is a valid public key for
|
||||
// the given curve. See [SEC1], 3.2.2
|
||||
func validateECPublicKey(curve elliptic.Curve, x, y *big.Int) bool { |
||||
if x.Sign() == 0 && y.Sign() == 0 { |
||||
return false |
||||
} |
||||
|
||||
if x.Cmp(curve.Params().P) >= 0 { |
||||
return false |
||||
} |
||||
|
||||
if y.Cmp(curve.Params().P) >= 0 { |
||||
return false |
||||
} |
||||
|
||||
if !curve.IsOnCurve(x, y) { |
||||
return false |
||||
} |
||||
|
||||
// We don't check if N * PubKey == 0, since
|
||||
//
|
||||
// - the NIST curves have cofactor = 1, so this is implicit.
|
||||
// (We don't foresee an implementation that supports non NIST
|
||||
// curves)
|
||||
//
|
||||
// - for ephemeral keys, we don't need to worry about small
|
||||
// subgroup attacks.
|
||||
return true |
||||
} |
||||
|
||||
func (kex *ecdh) Server(c packetConn, rand io.Reader, magics *handshakeMagics, priv Signer) (result *kexResult, err error) { |
||||
packet, err := c.readPacket() |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
var kexECDHInit kexECDHInitMsg |
||||
if err = Unmarshal(packet, &kexECDHInit); err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
clientX, clientY, err := unmarshalECKey(kex.curve, kexECDHInit.ClientPubKey) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
// We could cache this key across multiple users/multiple
|
||||
// connection attempts, but the benefit is small. OpenSSH
|
||||
// generates a new key for each incoming connection.
|
||||
ephKey, err := ecdsa.GenerateKey(kex.curve, rand) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
hostKeyBytes := priv.PublicKey().Marshal() |
||||
|
||||
serializedEphKey := elliptic.Marshal(kex.curve, ephKey.PublicKey.X, ephKey.PublicKey.Y) |
||||
|
||||
// generate shared secret
|
||||
secret, _ := kex.curve.ScalarMult(clientX, clientY, ephKey.D.Bytes()) |
||||
|
||||
h := ecHash(kex.curve).New() |
||||
magics.write(h) |
||||
writeString(h, hostKeyBytes) |
||||
writeString(h, kexECDHInit.ClientPubKey) |
||||
writeString(h, serializedEphKey) |
||||
|
||||
K := make([]byte, intLength(secret)) |
||||
marshalInt(K, secret) |
||||
h.Write(K) |
||||
|
||||
H := h.Sum(nil) |
||||
|
||||
// H is already a hash, but the hostkey signing will apply its
|
||||
// own key-specific hash algorithm.
|
||||
sig, err := signAndMarshal(priv, rand, H) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
reply := kexECDHReplyMsg{ |
||||
EphemeralPubKey: serializedEphKey, |
||||
HostKey: hostKeyBytes, |
||||
Signature: sig, |
||||
} |
||||
|
||||
serialized := Marshal(&reply) |
||||
if err := c.writePacket(serialized); err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
return &kexResult{ |
||||
H: H, |
||||
K: K, |
||||
HostKey: reply.HostKey, |
||||
Signature: sig, |
||||
Hash: ecHash(kex.curve), |
||||
}, nil |
||||
} |
||||
|
||||
var kexAlgoMap = map[string]kexAlgorithm{} |
||||
|
||||
func init() { |
||||
// This is the group called diffie-hellman-group1-sha1 in RFC
|
||||
// 4253 and Oakley Group 2 in RFC 2409.
|
||||
p, _ := new(big.Int).SetString("FFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7EDEE386BFB5A899FA5AE9F24117C4B1FE649286651ECE65381FFFFFFFFFFFFFFFF", 16) |
||||
kexAlgoMap[kexAlgoDH1SHA1] = &dhGroup{ |
||||
g: new(big.Int).SetInt64(2), |
||||
p: p, |
||||
pMinus1: new(big.Int).Sub(p, bigOne), |
||||
} |
||||
|
||||
// This is the group called diffie-hellman-group14-sha1 in RFC
|
||||
// 4253 and Oakley Group 14 in RFC 3526.
|
||||
p, _ = new(big.Int).SetString("FFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7EDEE386BFB5A899FA5AE9F24117C4B1FE649286651ECE45B3DC2007CB8A163BF0598DA48361C55D39A69163FA8FD24CF5F83655D23DCA3AD961C62F356208552BB9ED529077096966D670C354E4ABC9804F1746C08CA18217C32905E462E36CE3BE39E772C180E86039B2783A2EC07A28FB5C55DF06F4C52C9DE2BCBF6955817183995497CEA956AE515D2261898FA051015728E5A8AACAA68FFFFFFFFFFFFFFFF", 16) |
||||
|
||||
kexAlgoMap[kexAlgoDH14SHA1] = &dhGroup{ |
||||
g: new(big.Int).SetInt64(2), |
||||
p: p, |
||||
pMinus1: new(big.Int).Sub(p, bigOne), |
||||
} |
||||
|
||||
kexAlgoMap[kexAlgoECDH521] = &ecdh{elliptic.P521()} |
||||
kexAlgoMap[kexAlgoECDH384] = &ecdh{elliptic.P384()} |
||||
kexAlgoMap[kexAlgoECDH256] = &ecdh{elliptic.P256()} |
||||
kexAlgoMap[kexAlgoCurve25519SHA256] = &curve25519sha256{} |
||||
} |
||||
|
||||
// curve25519sha256 implements the curve25519-sha256@libssh.org key
|
||||
// agreement protocol, as described in
|
||||
// https://git.libssh.org/projects/libssh.git/tree/doc/curve25519-sha256@libssh.org.txt
|
||||
type curve25519sha256 struct{} |
||||
|
||||
type curve25519KeyPair struct { |
||||
priv [32]byte |
||||
pub [32]byte |
||||
} |
||||
|
||||
func (kp *curve25519KeyPair) generate(rand io.Reader) error { |
||||
if _, err := io.ReadFull(rand, kp.priv[:]); err != nil { |
||||
return err |
||||
} |
||||
curve25519.ScalarBaseMult(&kp.pub, &kp.priv) |
||||
return nil |
||||
} |
||||
|
||||
// curve25519Zeros is just an array of 32 zero bytes so that we have something
|
||||
// convenient to compare against in order to reject curve25519 points with the
|
||||
// wrong order.
|
||||
var curve25519Zeros [32]byte |
||||
|
||||
func (kex *curve25519sha256) Client(c packetConn, rand io.Reader, magics *handshakeMagics) (*kexResult, error) { |
||||
var kp curve25519KeyPair |
||||
if err := kp.generate(rand); err != nil { |
||||
return nil, err |
||||
} |
||||
if err := c.writePacket(Marshal(&kexECDHInitMsg{kp.pub[:]})); err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
packet, err := c.readPacket() |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
var reply kexECDHReplyMsg |
||||
if err = Unmarshal(packet, &reply); err != nil { |
||||
return nil, err |
||||
} |
||||
if len(reply.EphemeralPubKey) != 32 { |
||||
return nil, errors.New("ssh: peer's curve25519 public value has wrong length") |
||||
} |
||||
|
||||
var servPub, secret [32]byte |
||||
copy(servPub[:], reply.EphemeralPubKey) |
||||
curve25519.ScalarMult(&secret, &kp.priv, &servPub) |
||||
if subtle.ConstantTimeCompare(secret[:], curve25519Zeros[:]) == 1 { |
||||
return nil, errors.New("ssh: peer's curve25519 public value has wrong order") |
||||
} |
||||
|
||||
h := crypto.SHA256.New() |
||||
magics.write(h) |
||||
writeString(h, reply.HostKey) |
||||
writeString(h, kp.pub[:]) |
||||
writeString(h, reply.EphemeralPubKey) |
||||
|
||||
kInt := new(big.Int).SetBytes(secret[:]) |
||||
K := make([]byte, intLength(kInt)) |
||||
marshalInt(K, kInt) |
||||
h.Write(K) |
||||
|
||||
return &kexResult{ |
||||
H: h.Sum(nil), |
||||
K: K, |
||||
HostKey: reply.HostKey, |
||||
Signature: reply.Signature, |
||||
Hash: crypto.SHA256, |
||||
}, nil |
||||
} |
||||
|
||||
func (kex *curve25519sha256) Server(c packetConn, rand io.Reader, magics *handshakeMagics, priv Signer) (result *kexResult, err error) { |
||||
packet, err := c.readPacket() |
||||
if err != nil { |
||||
return |
||||
} |
||||
var kexInit kexECDHInitMsg |
||||
if err = Unmarshal(packet, &kexInit); err != nil { |
||||
return |
||||
} |
||||
|
||||
if len(kexInit.ClientPubKey) != 32 { |
||||
return nil, errors.New("ssh: peer's curve25519 public value has wrong length") |
||||
} |
||||
|
||||
var kp curve25519KeyPair |
||||
if err := kp.generate(rand); err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
var clientPub, secret [32]byte |
||||
copy(clientPub[:], kexInit.ClientPubKey) |
||||
curve25519.ScalarMult(&secret, &kp.priv, &clientPub) |
||||
if subtle.ConstantTimeCompare(secret[:], curve25519Zeros[:]) == 1 { |
||||
return nil, errors.New("ssh: peer's curve25519 public value has wrong order") |
||||
} |
||||
|
||||
hostKeyBytes := priv.PublicKey().Marshal() |
||||
|
||||
h := crypto.SHA256.New() |
||||
magics.write(h) |
||||
writeString(h, hostKeyBytes) |
||||
writeString(h, kexInit.ClientPubKey) |
||||
writeString(h, kp.pub[:]) |
||||
|
||||
kInt := new(big.Int).SetBytes(secret[:]) |
||||
K := make([]byte, intLength(kInt)) |
||||
marshalInt(K, kInt) |
||||
h.Write(K) |
||||
|
||||
H := h.Sum(nil) |
||||
|
||||
sig, err := signAndMarshal(priv, rand, H) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
reply := kexECDHReplyMsg{ |
||||
EphemeralPubKey: kp.pub[:], |
||||
HostKey: hostKeyBytes, |
||||
Signature: sig, |
||||
} |
||||
if err := c.writePacket(Marshal(&reply)); err != nil { |
||||
return nil, err |
||||
} |
||||
return &kexResult{ |
||||
H: H, |
||||
K: K, |
||||
HostKey: hostKeyBytes, |
||||
Signature: sig, |
||||
Hash: crypto.SHA256, |
||||
}, nil |
||||
} |
@ -0,0 +1,905 @@ |
||||
// Copyright 2012 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package ssh |
||||
|
||||
import ( |
||||
"bytes" |
||||
"crypto" |
||||
"crypto/dsa" |
||||
"crypto/ecdsa" |
||||
"crypto/elliptic" |
||||
"crypto/md5" |
||||
"crypto/rsa" |
||||
"crypto/sha256" |
||||
"crypto/x509" |
||||
"encoding/asn1" |
||||
"encoding/base64" |
||||
"encoding/hex" |
||||
"encoding/pem" |
||||
"errors" |
||||
"fmt" |
||||
"io" |
||||
"math/big" |
||||
"strings" |
||||
|
||||
"golang.org/x/crypto/ed25519" |
||||
) |
||||
|
||||
// These constants represent the algorithm names for key types supported by this
|
||||
// package.
|
||||
const ( |
||||
KeyAlgoRSA = "ssh-rsa" |
||||
KeyAlgoDSA = "ssh-dss" |
||||
KeyAlgoECDSA256 = "ecdsa-sha2-nistp256" |
||||
KeyAlgoECDSA384 = "ecdsa-sha2-nistp384" |
||||
KeyAlgoECDSA521 = "ecdsa-sha2-nistp521" |
||||
KeyAlgoED25519 = "ssh-ed25519" |
||||
) |
||||
|
||||
// parsePubKey parses a public key of the given algorithm.
|
||||
// Use ParsePublicKey for keys with prepended algorithm.
|
||||
func parsePubKey(in []byte, algo string) (pubKey PublicKey, rest []byte, err error) { |
||||
switch algo { |
||||
case KeyAlgoRSA: |
||||
return parseRSA(in) |
||||
case KeyAlgoDSA: |
||||
return parseDSA(in) |
||||
case KeyAlgoECDSA256, KeyAlgoECDSA384, KeyAlgoECDSA521: |
||||
return parseECDSA(in) |
||||
case KeyAlgoED25519: |
||||
return parseED25519(in) |
||||
case CertAlgoRSAv01, CertAlgoDSAv01, CertAlgoECDSA256v01, CertAlgoECDSA384v01, CertAlgoECDSA521v01, CertAlgoED25519v01: |
||||
cert, err := parseCert(in, certToPrivAlgo(algo)) |
||||
if err != nil { |
||||
return nil, nil, err |
||||
} |
||||
return cert, nil, nil |
||||
} |
||||
return nil, nil, fmt.Errorf("ssh: unknown key algorithm: %v", algo) |
||||
} |
||||
|
||||
// parseAuthorizedKey parses a public key in OpenSSH authorized_keys format
|
||||
// (see sshd(8) manual page) once the options and key type fields have been
|
||||
// removed.
|
||||
func parseAuthorizedKey(in []byte) (out PublicKey, comment string, err error) { |
||||
in = bytes.TrimSpace(in) |
||||
|
||||
i := bytes.IndexAny(in, " \t") |
||||
if i == -1 { |
||||
i = len(in) |
||||
} |
||||
base64Key := in[:i] |
||||
|
||||
key := make([]byte, base64.StdEncoding.DecodedLen(len(base64Key))) |
||||
n, err := base64.StdEncoding.Decode(key, base64Key) |
||||
if err != nil { |
||||
return nil, "", err |
||||
} |
||||
key = key[:n] |
||||
out, err = ParsePublicKey(key) |
||||
if err != nil { |
||||
return nil, "", err |
||||
} |
||||
comment = string(bytes.TrimSpace(in[i:])) |
||||
return out, comment, nil |
||||
} |
||||
|
||||
// ParseKnownHosts parses an entry in the format of the known_hosts file.
|
||||
//
|
||||
// The known_hosts format is documented in the sshd(8) manual page. This
|
||||
// function will parse a single entry from in. On successful return, marker
|
||||
// will contain the optional marker value (i.e. "cert-authority" or "revoked")
|
||||
// or else be empty, hosts will contain the hosts that this entry matches,
|
||||
// pubKey will contain the public key and comment will contain any trailing
|
||||
// comment at the end of the line. See the sshd(8) manual page for the various
|
||||
// forms that a host string can take.
|
||||
//
|
||||
// The unparsed remainder of the input will be returned in rest. This function
|
||||
// can be called repeatedly to parse multiple entries.
|
||||
//
|
||||
// If no entries were found in the input then err will be io.EOF. Otherwise a
|
||||
// non-nil err value indicates a parse error.
|
||||
func ParseKnownHosts(in []byte) (marker string, hosts []string, pubKey PublicKey, comment string, rest []byte, err error) { |
||||
for len(in) > 0 { |
||||
end := bytes.IndexByte(in, '\n') |
||||
if end != -1 { |
||||
rest = in[end+1:] |
||||
in = in[:end] |
||||
} else { |
||||
rest = nil |
||||
} |
||||
|
||||
end = bytes.IndexByte(in, '\r') |
||||
if end != -1 { |
||||
in = in[:end] |
||||
} |
||||
|
||||
in = bytes.TrimSpace(in) |
||||
if len(in) == 0 || in[0] == '#' { |
||||
in = rest |
||||
continue |
||||
} |
||||
|
||||
i := bytes.IndexAny(in, " \t") |
||||
if i == -1 { |
||||
in = rest |
||||
continue |
||||
} |
||||
|
||||
// Strip out the beginning of the known_host key.
|
||||
// This is either an optional marker or a (set of) hostname(s).
|
||||
keyFields := bytes.Fields(in) |
||||
if len(keyFields) < 3 || len(keyFields) > 5 { |
||||
return "", nil, nil, "", nil, errors.New("ssh: invalid entry in known_hosts data") |
||||
} |
||||
|
||||
// keyFields[0] is either "@cert-authority", "@revoked" or a comma separated
|
||||
// list of hosts
|
||||
marker := "" |
||||
if keyFields[0][0] == '@' { |
||||
marker = string(keyFields[0][1:]) |
||||
keyFields = keyFields[1:] |
||||
} |
||||
|
||||
hosts := string(keyFields[0]) |
||||
// keyFields[1] contains the key type (e.g. “ssh-rsa”).
|
||||
// However, that information is duplicated inside the
|
||||
// base64-encoded key and so is ignored here.
|
||||
|
||||
key := bytes.Join(keyFields[2:], []byte(" ")) |
||||
if pubKey, comment, err = parseAuthorizedKey(key); err != nil { |
||||
return "", nil, nil, "", nil, err |
||||
} |
||||
|
||||
return marker, strings.Split(hosts, ","), pubKey, comment, rest, nil |
||||
} |
||||
|
||||
return "", nil, nil, "", nil, io.EOF |
||||
} |
||||
|
||||
// ParseAuthorizedKeys parses a public key from an authorized_keys
|
||||
// file used in OpenSSH according to the sshd(8) manual page.
|
||||
func ParseAuthorizedKey(in []byte) (out PublicKey, comment string, options []string, rest []byte, err error) { |
||||
for len(in) > 0 { |
||||
end := bytes.IndexByte(in, '\n') |
||||
if end != -1 { |
||||
rest = in[end+1:] |
||||
in = in[:end] |
||||
} else { |
||||
rest = nil |
||||
} |
||||
|
||||
end = bytes.IndexByte(in, '\r') |
||||
if end != -1 { |
||||
in = in[:end] |
||||
} |
||||
|
||||
in = bytes.TrimSpace(in) |
||||
if len(in) == 0 || in[0] == '#' { |
||||
in = rest |
||||
continue |
||||
} |
||||
|
||||
i := bytes.IndexAny(in, " \t") |
||||
if i == -1 { |
||||
in = rest |
||||
continue |
||||
} |
||||
|
||||
if out, comment, err = parseAuthorizedKey(in[i:]); err == nil { |
||||
return out, comment, options, rest, nil |
||||
} |
||||
|
||||
// No key type recognised. Maybe there's an options field at
|
||||
// the beginning.
|
||||
var b byte |
||||
inQuote := false |
||||
var candidateOptions []string |
||||
optionStart := 0 |
||||
for i, b = range in { |
||||
isEnd := !inQuote && (b == ' ' || b == '\t') |
||||
if (b == ',' && !inQuote) || isEnd { |
||||
if i-optionStart > 0 { |
||||
candidateOptions = append(candidateOptions, string(in[optionStart:i])) |
||||
} |
||||
optionStart = i + 1 |
||||
} |
||||
if isEnd { |
||||
break |
||||
} |
||||
if b == '"' && (i == 0 || (i > 0 && in[i-1] != '\\')) { |
||||
inQuote = !inQuote |
||||
} |
||||
} |
||||
for i < len(in) && (in[i] == ' ' || in[i] == '\t') { |
||||
i++ |
||||
} |
||||
if i == len(in) { |
||||
// Invalid line: unmatched quote
|
||||
in = rest |
||||
continue |
||||
} |
||||
|
||||
in = in[i:] |
||||
i = bytes.IndexAny(in, " \t") |
||||
if i == -1 { |
||||
in = rest |
||||
continue |
||||
} |
||||
|
||||
if out, comment, err = parseAuthorizedKey(in[i:]); err == nil { |
||||
options = candidateOptions |
||||
return out, comment, options, rest, nil |
||||
} |
||||
|
||||
in = rest |
||||
continue |
||||
} |
||||
|
||||
return nil, "", nil, nil, errors.New("ssh: no key found") |
||||
} |
||||
|
||||
// ParsePublicKey parses an SSH public key formatted for use in
|
||||
// the SSH wire protocol according to RFC 4253, section 6.6.
|
||||
func ParsePublicKey(in []byte) (out PublicKey, err error) { |
||||
algo, in, ok := parseString(in) |
||||
if !ok { |
||||
return nil, errShortRead |
||||
} |
||||
var rest []byte |
||||
out, rest, err = parsePubKey(in, string(algo)) |
||||
if len(rest) > 0 { |
||||
return nil, errors.New("ssh: trailing junk in public key") |
||||
} |
||||
|
||||
return out, err |
||||
} |
||||
|
||||
// MarshalAuthorizedKey serializes key for inclusion in an OpenSSH
|
||||
// authorized_keys file. The return value ends with newline.
|
||||
func MarshalAuthorizedKey(key PublicKey) []byte { |
||||
b := &bytes.Buffer{} |
||||
b.WriteString(key.Type()) |
||||
b.WriteByte(' ') |
||||
e := base64.NewEncoder(base64.StdEncoding, b) |
||||
e.Write(key.Marshal()) |
||||
e.Close() |
||||
b.WriteByte('\n') |
||||
return b.Bytes() |
||||
} |
||||
|
||||
// PublicKey is an abstraction of different types of public keys.
|
||||
type PublicKey interface { |
||||
// Type returns the key's type, e.g. "ssh-rsa".
|
||||
Type() string |
||||
|
||||
// Marshal returns the serialized key data in SSH wire format,
|
||||
// with the name prefix.
|
||||
Marshal() []byte |
||||
|
||||
// Verify that sig is a signature on the given data using this
|
||||
// key. This function will hash the data appropriately first.
|
||||
Verify(data []byte, sig *Signature) error |
||||
} |
||||
|
||||
// CryptoPublicKey, if implemented by a PublicKey,
|
||||
// returns the underlying crypto.PublicKey form of the key.
|
||||
type CryptoPublicKey interface { |
||||
CryptoPublicKey() crypto.PublicKey |
||||
} |
||||
|
||||
// A Signer can create signatures that verify against a public key.
|
||||
type Signer interface { |
||||
// PublicKey returns an associated PublicKey instance.
|
||||
PublicKey() PublicKey |
||||
|
||||
// Sign returns raw signature for the given data. This method
|
||||
// will apply the hash specified for the keytype to the data.
|
||||
Sign(rand io.Reader, data []byte) (*Signature, error) |
||||
} |
||||
|
||||
type rsaPublicKey rsa.PublicKey |
||||
|
||||
func (r *rsaPublicKey) Type() string { |
||||
return "ssh-rsa" |
||||
} |
||||
|
||||
// parseRSA parses an RSA key according to RFC 4253, section 6.6.
|
||||
func parseRSA(in []byte) (out PublicKey, rest []byte, err error) { |
||||
var w struct { |
||||
E *big.Int |
||||
N *big.Int |
||||
Rest []byte `ssh:"rest"` |
||||
} |
||||
if err := Unmarshal(in, &w); err != nil { |
||||
return nil, nil, err |
||||
} |
||||
|
||||
if w.E.BitLen() > 24 { |
||||
return nil, nil, errors.New("ssh: exponent too large") |
||||
} |
||||
e := w.E.Int64() |
||||
if e < 3 || e&1 == 0 { |
||||
return nil, nil, errors.New("ssh: incorrect exponent") |
||||
} |
||||
|
||||
var key rsa.PublicKey |
||||
key.E = int(e) |
||||
key.N = w.N |
||||
return (*rsaPublicKey)(&key), w.Rest, nil |
||||
} |
||||
|
||||
func (r *rsaPublicKey) Marshal() []byte { |
||||
e := new(big.Int).SetInt64(int64(r.E)) |
||||
// RSA publickey struct layout should match the struct used by
|
||||
// parseRSACert in the x/crypto/ssh/agent package.
|
||||
wirekey := struct { |
||||
Name string |
||||
E *big.Int |
||||
N *big.Int |
||||
}{ |
||||
KeyAlgoRSA, |
||||
e, |
||||
r.N, |
||||
} |
||||
return Marshal(&wirekey) |
||||
} |
||||
|
||||
func (r *rsaPublicKey) Verify(data []byte, sig *Signature) error { |
||||
if sig.Format != r.Type() { |
||||
return fmt.Errorf("ssh: signature type %s for key type %s", sig.Format, r.Type()) |
||||
} |
||||
h := crypto.SHA1.New() |
||||
h.Write(data) |
||||
digest := h.Sum(nil) |
||||
return rsa.VerifyPKCS1v15((*rsa.PublicKey)(r), crypto.SHA1, digest, sig.Blob) |
||||
} |
||||
|
||||
func (r *rsaPublicKey) CryptoPublicKey() crypto.PublicKey { |
||||
return (*rsa.PublicKey)(r) |
||||
} |
||||
|
||||
type dsaPublicKey dsa.PublicKey |
||||
|
||||
func (r *dsaPublicKey) Type() string { |
||||
return "ssh-dss" |
||||
} |
||||
|
||||
// parseDSA parses an DSA key according to RFC 4253, section 6.6.
|
||||
func parseDSA(in []byte) (out PublicKey, rest []byte, err error) { |
||||
var w struct { |
||||
P, Q, G, Y *big.Int |
||||
Rest []byte `ssh:"rest"` |
||||
} |
||||
if err := Unmarshal(in, &w); err != nil { |
||||
return nil, nil, err |
||||
} |
||||
|
||||
key := &dsaPublicKey{ |
||||
Parameters: dsa.Parameters{ |
||||
P: w.P, |
||||
Q: w.Q, |
||||
G: w.G, |
||||
}, |
||||
Y: w.Y, |
||||
} |
||||
return key, w.Rest, nil |
||||
} |
||||
|
||||
func (k *dsaPublicKey) Marshal() []byte { |
||||
// DSA publickey struct layout should match the struct used by
|
||||
// parseDSACert in the x/crypto/ssh/agent package.
|
||||
w := struct { |
||||
Name string |
||||
P, Q, G, Y *big.Int |
||||
}{ |
||||
k.Type(), |
||||
k.P, |
||||
k.Q, |
||||
k.G, |
||||
k.Y, |
||||
} |
||||
|
||||
return Marshal(&w) |
||||
} |
||||
|
||||
func (k *dsaPublicKey) Verify(data []byte, sig *Signature) error { |
||||
if sig.Format != k.Type() { |
||||
return fmt.Errorf("ssh: signature type %s for key type %s", sig.Format, k.Type()) |
||||
} |
||||
h := crypto.SHA1.New() |
||||
h.Write(data) |
||||
digest := h.Sum(nil) |
||||
|
||||
// Per RFC 4253, section 6.6,
|
||||
// The value for 'dss_signature_blob' is encoded as a string containing
|
||||
// r, followed by s (which are 160-bit integers, without lengths or
|
||||
// padding, unsigned, and in network byte order).
|
||||
// For DSS purposes, sig.Blob should be exactly 40 bytes in length.
|
||||
if len(sig.Blob) != 40 { |
||||
return errors.New("ssh: DSA signature parse error") |
||||
} |
||||
r := new(big.Int).SetBytes(sig.Blob[:20]) |
||||
s := new(big.Int).SetBytes(sig.Blob[20:]) |
||||
if dsa.Verify((*dsa.PublicKey)(k), digest, r, s) { |
||||
return nil |
||||
} |
||||
return errors.New("ssh: signature did not verify") |
||||
} |
||||
|
||||
func (k *dsaPublicKey) CryptoPublicKey() crypto.PublicKey { |
||||
return (*dsa.PublicKey)(k) |
||||
} |
||||
|
||||
type dsaPrivateKey struct { |
||||
*dsa.PrivateKey |
||||
} |
||||
|
||||
func (k *dsaPrivateKey) PublicKey() PublicKey { |
||||
return (*dsaPublicKey)(&k.PrivateKey.PublicKey) |
||||
} |
||||
|
||||
func (k *dsaPrivateKey) Sign(rand io.Reader, data []byte) (*Signature, error) { |
||||
h := crypto.SHA1.New() |
||||
h.Write(data) |
||||
digest := h.Sum(nil) |
||||
r, s, err := dsa.Sign(rand, k.PrivateKey, digest) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
sig := make([]byte, 40) |
||||
rb := r.Bytes() |
||||
sb := s.Bytes() |
||||
|
||||
copy(sig[20-len(rb):20], rb) |
||||
copy(sig[40-len(sb):], sb) |
||||
|
||||
return &Signature{ |
||||
Format: k.PublicKey().Type(), |
||||
Blob: sig, |
||||
}, nil |
||||
} |
||||
|
||||
type ecdsaPublicKey ecdsa.PublicKey |
||||
|
||||
func (key *ecdsaPublicKey) Type() string { |
||||
return "ecdsa-sha2-" + key.nistID() |
||||
} |
||||
|
||||
func (key *ecdsaPublicKey) nistID() string { |
||||
switch key.Params().BitSize { |
||||
case 256: |
||||
return "nistp256" |
||||
case 384: |
||||
return "nistp384" |
||||
case 521: |
||||
return "nistp521" |
||||
} |
||||
panic("ssh: unsupported ecdsa key size") |
||||
} |
||||
|
||||
type ed25519PublicKey ed25519.PublicKey |
||||
|
||||
func (key ed25519PublicKey) Type() string { |
||||
return KeyAlgoED25519 |
||||
} |
||||
|
||||
func parseED25519(in []byte) (out PublicKey, rest []byte, err error) { |
||||
var w struct { |
||||
KeyBytes []byte |
||||
Rest []byte `ssh:"rest"` |
||||
} |
||||
|
||||
if err := Unmarshal(in, &w); err != nil { |
||||
return nil, nil, err |
||||
} |
||||
|
||||
key := ed25519.PublicKey(w.KeyBytes) |
||||
|
||||
return (ed25519PublicKey)(key), w.Rest, nil |
||||
} |
||||
|
||||
func (key ed25519PublicKey) Marshal() []byte { |
||||
w := struct { |
||||
Name string |
||||
KeyBytes []byte |
||||
}{ |
||||
KeyAlgoED25519, |
||||
[]byte(key), |
||||
} |
||||
return Marshal(&w) |
||||
} |
||||
|
||||
func (key ed25519PublicKey) Verify(b []byte, sig *Signature) error { |
||||
if sig.Format != key.Type() { |
||||
return fmt.Errorf("ssh: signature type %s for key type %s", sig.Format, key.Type()) |
||||
} |
||||
|
||||
edKey := (ed25519.PublicKey)(key) |
||||
if ok := ed25519.Verify(edKey, b, sig.Blob); !ok { |
||||
return errors.New("ssh: signature did not verify") |
||||
} |
||||
|
||||
return nil |
||||
} |
||||
|
||||
func (k ed25519PublicKey) CryptoPublicKey() crypto.PublicKey { |
||||
return ed25519.PublicKey(k) |
||||
} |
||||
|
||||
func supportedEllipticCurve(curve elliptic.Curve) bool { |
||||
return curve == elliptic.P256() || curve == elliptic.P384() || curve == elliptic.P521() |
||||
} |
||||
|
||||
// ecHash returns the hash to match the given elliptic curve, see RFC
|
||||
// 5656, section 6.2.1
|
||||
func ecHash(curve elliptic.Curve) crypto.Hash { |
||||
bitSize := curve.Params().BitSize |
||||
switch { |
||||
case bitSize <= 256: |
||||
return crypto.SHA256 |
||||
case bitSize <= 384: |
||||
return crypto.SHA384 |
||||
} |
||||
return crypto.SHA512 |
||||
} |
||||
|
||||
// parseECDSA parses an ECDSA key according to RFC 5656, section 3.1.
|
||||
func parseECDSA(in []byte) (out PublicKey, rest []byte, err error) { |
||||
var w struct { |
||||
Curve string |
||||
KeyBytes []byte |
||||
Rest []byte `ssh:"rest"` |
||||
} |
||||
|
||||
if err := Unmarshal(in, &w); err != nil { |
||||
return nil, nil, err |
||||
} |
||||
|
||||
key := new(ecdsa.PublicKey) |
||||
|
||||
switch w.Curve { |
||||
case "nistp256": |
||||
key.Curve = elliptic.P256() |
||||
case "nistp384": |
||||
key.Curve = elliptic.P384() |
||||
case "nistp521": |
||||
key.Curve = elliptic.P521() |
||||
default: |
||||
return nil, nil, errors.New("ssh: unsupported curve") |
||||
} |
||||
|
||||
key.X, key.Y = elliptic.Unmarshal(key.Curve, w.KeyBytes) |
||||
if key.X == nil || key.Y == nil { |
||||
return nil, nil, errors.New("ssh: invalid curve point") |
||||
} |
||||
return (*ecdsaPublicKey)(key), w.Rest, nil |
||||
} |
||||
|
||||
func (key *ecdsaPublicKey) Marshal() []byte { |
||||
// See RFC 5656, section 3.1.
|
||||
keyBytes := elliptic.Marshal(key.Curve, key.X, key.Y) |
||||
// ECDSA publickey struct layout should match the struct used by
|
||||
// parseECDSACert in the x/crypto/ssh/agent package.
|
||||
w := struct { |
||||
Name string |
||||
ID string |
||||
Key []byte |
||||
}{ |
||||
key.Type(), |
||||
key.nistID(), |
||||
keyBytes, |
||||
} |
||||
|
||||
return Marshal(&w) |
||||
} |
||||
|
||||
func (key *ecdsaPublicKey) Verify(data []byte, sig *Signature) error { |
||||
if sig.Format != key.Type() { |
||||
return fmt.Errorf("ssh: signature type %s for key type %s", sig.Format, key.Type()) |
||||
} |
||||
|
||||
h := ecHash(key.Curve).New() |
||||
h.Write(data) |
||||
digest := h.Sum(nil) |
||||
|
||||
// Per RFC 5656, section 3.1.2,
|
||||
// The ecdsa_signature_blob value has the following specific encoding:
|
||||
// mpint r
|
||||
// mpint s
|
||||
var ecSig struct { |
||||
R *big.Int |
||||
S *big.Int |
||||
} |
||||
|
||||
if err := Unmarshal(sig.Blob, &ecSig); err != nil { |
||||
return err |
||||
} |
||||
|
||||
if ecdsa.Verify((*ecdsa.PublicKey)(key), digest, ecSig.R, ecSig.S) { |
||||
return nil |
||||
} |
||||
return errors.New("ssh: signature did not verify") |
||||
} |
||||
|
||||
func (k *ecdsaPublicKey) CryptoPublicKey() crypto.PublicKey { |
||||
return (*ecdsa.PublicKey)(k) |
||||
} |
||||
|
||||
// NewSignerFromKey takes an *rsa.PrivateKey, *dsa.PrivateKey,
|
||||
// *ecdsa.PrivateKey or any other crypto.Signer and returns a corresponding
|
||||
// Signer instance. ECDSA keys must use P-256, P-384 or P-521.
|
||||
func NewSignerFromKey(key interface{}) (Signer, error) { |
||||
switch key := key.(type) { |
||||
case crypto.Signer: |
||||
return NewSignerFromSigner(key) |
||||
case *dsa.PrivateKey: |
||||
return &dsaPrivateKey{key}, nil |
||||
default: |
||||
return nil, fmt.Errorf("ssh: unsupported key type %T", key) |
||||
} |
||||
} |
||||
|
||||
type wrappedSigner struct { |
||||
signer crypto.Signer |
||||
pubKey PublicKey |
||||
} |
||||
|
||||
// NewSignerFromSigner takes any crypto.Signer implementation and
|
||||
// returns a corresponding Signer interface. This can be used, for
|
||||
// example, with keys kept in hardware modules.
|
||||
func NewSignerFromSigner(signer crypto.Signer) (Signer, error) { |
||||
pubKey, err := NewPublicKey(signer.Public()) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
return &wrappedSigner{signer, pubKey}, nil |
||||
} |
||||
|
||||
func (s *wrappedSigner) PublicKey() PublicKey { |
||||
return s.pubKey |
||||
} |
||||
|
||||
func (s *wrappedSigner) Sign(rand io.Reader, data []byte) (*Signature, error) { |
||||
var hashFunc crypto.Hash |
||||
|
||||
switch key := s.pubKey.(type) { |
||||
case *rsaPublicKey, *dsaPublicKey: |
||||
hashFunc = crypto.SHA1 |
||||
case *ecdsaPublicKey: |
||||
hashFunc = ecHash(key.Curve) |
||||
case ed25519PublicKey: |
||||
default: |
||||
return nil, fmt.Errorf("ssh: unsupported key type %T", key) |
||||
} |
||||
|
||||
var digest []byte |
||||
if hashFunc != 0 { |
||||
h := hashFunc.New() |
||||
h.Write(data) |
||||
digest = h.Sum(nil) |
||||
} else { |
||||
digest = data |
||||
} |
||||
|
||||
signature, err := s.signer.Sign(rand, digest, hashFunc) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
// crypto.Signer.Sign is expected to return an ASN.1-encoded signature
|
||||
// for ECDSA and DSA, but that's not the encoding expected by SSH, so
|
||||
// re-encode.
|
||||
switch s.pubKey.(type) { |
||||
case *ecdsaPublicKey, *dsaPublicKey: |
||||
type asn1Signature struct { |
||||
R, S *big.Int |
||||
} |
||||
asn1Sig := new(asn1Signature) |
||||
_, err := asn1.Unmarshal(signature, asn1Sig) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
switch s.pubKey.(type) { |
||||
case *ecdsaPublicKey: |
||||
signature = Marshal(asn1Sig) |
||||
|
||||
case *dsaPublicKey: |
||||
signature = make([]byte, 40) |
||||
r := asn1Sig.R.Bytes() |
||||
s := asn1Sig.S.Bytes() |
||||
copy(signature[20-len(r):20], r) |
||||
copy(signature[40-len(s):40], s) |
||||
} |
||||
} |
||||
|
||||
return &Signature{ |
||||
Format: s.pubKey.Type(), |
||||
Blob: signature, |
||||
}, nil |
||||
} |
||||
|
||||
// NewPublicKey takes an *rsa.PublicKey, *dsa.PublicKey, *ecdsa.PublicKey,
|
||||
// or ed25519.PublicKey returns a corresponding PublicKey instance.
|
||||
// ECDSA keys must use P-256, P-384 or P-521.
|
||||
func NewPublicKey(key interface{}) (PublicKey, error) { |
||||
switch key := key.(type) { |
||||
case *rsa.PublicKey: |
||||
return (*rsaPublicKey)(key), nil |
||||
case *ecdsa.PublicKey: |
||||
if !supportedEllipticCurve(key.Curve) { |
||||
return nil, errors.New("ssh: only P-256, P-384 and P-521 EC keys are supported.") |
||||
} |
||||
return (*ecdsaPublicKey)(key), nil |
||||
case *dsa.PublicKey: |
||||
return (*dsaPublicKey)(key), nil |
||||
case ed25519.PublicKey: |
||||
return (ed25519PublicKey)(key), nil |
||||
default: |
||||
return nil, fmt.Errorf("ssh: unsupported key type %T", key) |
||||
} |
||||
} |
||||
|
||||
// ParsePrivateKey returns a Signer from a PEM encoded private key. It supports
|
||||
// the same keys as ParseRawPrivateKey.
|
||||
func ParsePrivateKey(pemBytes []byte) (Signer, error) { |
||||
key, err := ParseRawPrivateKey(pemBytes) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
return NewSignerFromKey(key) |
||||
} |
||||
|
||||
// encryptedBlock tells whether a private key is
|
||||
// encrypted by examining its Proc-Type header
|
||||
// for a mention of ENCRYPTED
|
||||
// according to RFC 1421 Section 4.6.1.1.
|
||||
func encryptedBlock(block *pem.Block) bool { |
||||
return strings.Contains(block.Headers["Proc-Type"], "ENCRYPTED") |
||||
} |
||||
|
||||
// ParseRawPrivateKey returns a private key from a PEM encoded private key. It
|
||||
// supports RSA (PKCS#1), DSA (OpenSSL), and ECDSA private keys.
|
||||
func ParseRawPrivateKey(pemBytes []byte) (interface{}, error) { |
||||
block, _ := pem.Decode(pemBytes) |
||||
if block == nil { |
||||
return nil, errors.New("ssh: no key found") |
||||
} |
||||
|
||||
if encryptedBlock(block) { |
||||
return nil, errors.New("ssh: cannot decode encrypted private keys") |
||||
} |
||||
|
||||
switch block.Type { |
||||
case "RSA PRIVATE KEY": |
||||
return x509.ParsePKCS1PrivateKey(block.Bytes) |
||||
case "EC PRIVATE KEY": |
||||
return x509.ParseECPrivateKey(block.Bytes) |
||||
case "DSA PRIVATE KEY": |
||||
return ParseDSAPrivateKey(block.Bytes) |
||||
case "OPENSSH PRIVATE KEY": |
||||
return parseOpenSSHPrivateKey(block.Bytes) |
||||
default: |
||||
return nil, fmt.Errorf("ssh: unsupported key type %q", block.Type) |
||||
} |
||||
} |
||||
|
||||
// ParseDSAPrivateKey returns a DSA private key from its ASN.1 DER encoding, as
|
||||
// specified by the OpenSSL DSA man page.
|
||||
func ParseDSAPrivateKey(der []byte) (*dsa.PrivateKey, error) { |
||||
var k struct { |
||||
Version int |
||||
P *big.Int |
||||
Q *big.Int |
||||
G *big.Int |
||||
Pub *big.Int |
||||
Priv *big.Int |
||||
} |
||||
rest, err := asn1.Unmarshal(der, &k) |
||||
if err != nil { |
||||
return nil, errors.New("ssh: failed to parse DSA key: " + err.Error()) |
||||
} |
||||
if len(rest) > 0 { |
||||
return nil, errors.New("ssh: garbage after DSA key") |
||||
} |
||||
|
||||
return &dsa.PrivateKey{ |
||||
PublicKey: dsa.PublicKey{ |
||||
Parameters: dsa.Parameters{ |
||||
P: k.P, |
||||
Q: k.Q, |
||||
G: k.G, |
||||
}, |
||||
Y: k.Pub, |
||||
}, |
||||
X: k.Priv, |
||||
}, nil |
||||
} |
||||
|
||||
// Implemented based on the documentation at
|
||||
// https://github.com/openssh/openssh-portable/blob/master/PROTOCOL.key
|
||||
func parseOpenSSHPrivateKey(key []byte) (*ed25519.PrivateKey, error) { |
||||
magic := append([]byte("openssh-key-v1"), 0) |
||||
if !bytes.Equal(magic, key[0:len(magic)]) { |
||||
return nil, errors.New("ssh: invalid openssh private key format") |
||||
} |
||||
remaining := key[len(magic):] |
||||
|
||||
var w struct { |
||||
CipherName string |
||||
KdfName string |
||||
KdfOpts string |
||||
NumKeys uint32 |
||||
PubKey []byte |
||||
PrivKeyBlock []byte |
||||
} |
||||
|
||||
if err := Unmarshal(remaining, &w); err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
pk1 := struct { |
||||
Check1 uint32 |
||||
Check2 uint32 |
||||
Keytype string |
||||
Pub []byte |
||||
Priv []byte |
||||
Comment string |
||||
Pad []byte `ssh:"rest"` |
||||
}{} |
||||
|
||||
if err := Unmarshal(w.PrivKeyBlock, &pk1); err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
if pk1.Check1 != pk1.Check2 { |
||||
return nil, errors.New("ssh: checkint mismatch") |
||||
} |
||||
|
||||
// we only handle ed25519 keys currently
|
||||
if pk1.Keytype != KeyAlgoED25519 { |
||||
return nil, errors.New("ssh: unhandled key type") |
||||
} |
||||
|
||||
for i, b := range pk1.Pad { |
||||
if int(b) != i+1 { |
||||
return nil, errors.New("ssh: padding not as expected") |
||||
} |
||||
} |
||||
|
||||
if len(pk1.Priv) != ed25519.PrivateKeySize { |
||||
return nil, errors.New("ssh: private key unexpected length") |
||||
} |
||||
|
||||
pk := ed25519.PrivateKey(make([]byte, ed25519.PrivateKeySize)) |
||||
copy(pk, pk1.Priv) |
||||
return &pk, nil |
||||
} |
||||
|
||||
// FingerprintLegacyMD5 returns the user presentation of the key's
|
||||
// fingerprint as described by RFC 4716 section 4.
|
||||
func FingerprintLegacyMD5(pubKey PublicKey) string { |
||||
md5sum := md5.Sum(pubKey.Marshal()) |
||||
hexarray := make([]string, len(md5sum)) |
||||
for i, c := range md5sum { |
||||
hexarray[i] = hex.EncodeToString([]byte{c}) |
||||
} |
||||
return strings.Join(hexarray, ":") |
||||
} |
||||
|
||||
// FingerprintSHA256 returns the user presentation of the key's
|
||||
// fingerprint as unpadded base64 encoded sha256 hash.
|
||||
// This format was introduced from OpenSSH 6.8.
|
||||
// https://www.openssh.com/txt/release-6.8
|
||||
// https://tools.ietf.org/html/rfc4648#section-3.2 (unpadded base64 encoding)
|
||||
func FingerprintSHA256(pubKey PublicKey) string { |
||||
sha256sum := sha256.Sum256(pubKey.Marshal()) |
||||
hash := base64.RawStdEncoding.EncodeToString(sha256sum[:]) |
||||
return "SHA256:" + hash |
||||
} |
@ -0,0 +1,61 @@ |
||||
// Copyright 2012 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package ssh |
||||
|
||||
// Message authentication support
|
||||
|
||||
import ( |
||||
"crypto/hmac" |
||||
"crypto/sha1" |
||||
"crypto/sha256" |
||||
"hash" |
||||
) |
||||
|
||||
type macMode struct { |
||||
keySize int |
||||
etm bool |
||||
new func(key []byte) hash.Hash |
||||
} |
||||
|
||||
// truncatingMAC wraps around a hash.Hash and truncates the output digest to
|
||||
// a given size.
|
||||
type truncatingMAC struct { |
||||
length int |
||||
hmac hash.Hash |
||||
} |
||||
|
||||
func (t truncatingMAC) Write(data []byte) (int, error) { |
||||
return t.hmac.Write(data) |
||||
} |
||||
|
||||
func (t truncatingMAC) Sum(in []byte) []byte { |
||||
out := t.hmac.Sum(in) |
||||
return out[:len(in)+t.length] |
||||
} |
||||
|
||||
func (t truncatingMAC) Reset() { |
||||
t.hmac.Reset() |
||||
} |
||||
|
||||
func (t truncatingMAC) Size() int { |
||||
return t.length |
||||
} |
||||
|
||||
func (t truncatingMAC) BlockSize() int { return t.hmac.BlockSize() } |
||||
|
||||
var macModes = map[string]*macMode{ |
||||
"hmac-sha2-256-etm@openssh.com": {32, true, func(key []byte) hash.Hash { |
||||
return hmac.New(sha256.New, key) |
||||
}}, |
||||
"hmac-sha2-256": {32, false, func(key []byte) hash.Hash { |
||||
return hmac.New(sha256.New, key) |
||||
}}, |
||||
"hmac-sha1": {20, false, func(key []byte) hash.Hash { |
||||
return hmac.New(sha1.New, key) |
||||
}}, |
||||
"hmac-sha1-96": {20, false, func(key []byte) hash.Hash { |
||||
return truncatingMAC{12, hmac.New(sha1.New, key)} |
||||
}}, |
||||
} |
@ -0,0 +1,758 @@ |
||||
// Copyright 2011 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package ssh |
||||
|
||||
import ( |
||||
"bytes" |
||||
"encoding/binary" |
||||
"errors" |
||||
"fmt" |
||||
"io" |
||||
"math/big" |
||||
"reflect" |
||||
"strconv" |
||||
"strings" |
||||
) |
||||
|
||||
// These are SSH message type numbers. They are scattered around several
|
||||
// documents but many were taken from [SSH-PARAMETERS].
|
||||
const ( |
||||
msgIgnore = 2 |
||||
msgUnimplemented = 3 |
||||
msgDebug = 4 |
||||
msgNewKeys = 21 |
||||
|
||||
// Standard authentication messages
|
||||
msgUserAuthSuccess = 52 |
||||
msgUserAuthBanner = 53 |
||||
) |
||||
|
||||
// SSH messages:
|
||||
//
|
||||
// These structures mirror the wire format of the corresponding SSH messages.
|
||||
// They are marshaled using reflection with the marshal and unmarshal functions
|
||||
// in this file. The only wrinkle is that a final member of type []byte with a
|
||||
// ssh tag of "rest" receives the remainder of a packet when unmarshaling.
|
||||
|
||||
// See RFC 4253, section 11.1.
|
||||
const msgDisconnect = 1 |
||||
|
||||
// disconnectMsg is the message that signals a disconnect. It is also
|
||||
// the error type returned from mux.Wait()
|
||||
type disconnectMsg struct { |
||||
Reason uint32 `sshtype:"1"` |
||||
Message string |
||||
Language string |
||||
} |
||||
|
||||
func (d *disconnectMsg) Error() string { |
||||
return fmt.Sprintf("ssh: disconnect, reason %d: %s", d.Reason, d.Message) |
||||
} |
||||
|
||||
// See RFC 4253, section 7.1.
|
||||
const msgKexInit = 20 |
||||
|
||||
type kexInitMsg struct { |
||||
Cookie [16]byte `sshtype:"20"` |
||||
KexAlgos []string |
||||
ServerHostKeyAlgos []string |
||||
CiphersClientServer []string |
||||
CiphersServerClient []string |
||||
MACsClientServer []string |
||||
MACsServerClient []string |
||||
CompressionClientServer []string |
||||
CompressionServerClient []string |
||||
LanguagesClientServer []string |
||||
LanguagesServerClient []string |
||||
FirstKexFollows bool |
||||
Reserved uint32 |
||||
} |
||||
|
||||
// See RFC 4253, section 8.
|
||||
|
||||
// Diffie-Helman
|
||||
const msgKexDHInit = 30 |
||||
|
||||
type kexDHInitMsg struct { |
||||
X *big.Int `sshtype:"30"` |
||||
} |
||||
|
||||
const msgKexECDHInit = 30 |
||||
|
||||
type kexECDHInitMsg struct { |
||||
ClientPubKey []byte `sshtype:"30"` |
||||
} |
||||
|
||||
const msgKexECDHReply = 31 |
||||
|
||||
type kexECDHReplyMsg struct { |
||||
HostKey []byte `sshtype:"31"` |
||||
EphemeralPubKey []byte |
||||
Signature []byte |
||||
} |
||||
|
||||
const msgKexDHReply = 31 |
||||
|
||||
type kexDHReplyMsg struct { |
||||
HostKey []byte `sshtype:"31"` |
||||
Y *big.Int |
||||
Signature []byte |
||||
} |
||||
|
||||
// See RFC 4253, section 10.
|
||||
const msgServiceRequest = 5 |
||||
|
||||
type serviceRequestMsg struct { |
||||
Service string `sshtype:"5"` |
||||
} |
||||
|
||||
// See RFC 4253, section 10.
|
||||
const msgServiceAccept = 6 |
||||
|
||||
type serviceAcceptMsg struct { |
||||
Service string `sshtype:"6"` |
||||
} |
||||
|
||||
// See RFC 4252, section 5.
|
||||
const msgUserAuthRequest = 50 |
||||
|
||||
type userAuthRequestMsg struct { |
||||
User string `sshtype:"50"` |
||||
Service string |
||||
Method string |
||||
Payload []byte `ssh:"rest"` |
||||
} |
||||
|
||||
// Used for debug printouts of packets.
|
||||
type userAuthSuccessMsg struct { |
||||
} |
||||
|
||||
// See RFC 4252, section 5.1
|
||||
const msgUserAuthFailure = 51 |
||||
|
||||
type userAuthFailureMsg struct { |
||||
Methods []string `sshtype:"51"` |
||||
PartialSuccess bool |
||||
} |
||||
|
||||
// See RFC 4256, section 3.2
|
||||
const msgUserAuthInfoRequest = 60 |
||||
const msgUserAuthInfoResponse = 61 |
||||
|
||||
type userAuthInfoRequestMsg struct { |
||||
User string `sshtype:"60"` |
||||
Instruction string |
||||
DeprecatedLanguage string |
||||
NumPrompts uint32 |
||||
Prompts []byte `ssh:"rest"` |
||||
} |
||||
|
||||
// See RFC 4254, section 5.1.
|
||||
const msgChannelOpen = 90 |
||||
|
||||
type channelOpenMsg struct { |
||||
ChanType string `sshtype:"90"` |
||||
PeersId uint32 |
||||
PeersWindow uint32 |
||||
MaxPacketSize uint32 |
||||
TypeSpecificData []byte `ssh:"rest"` |
||||
} |
||||
|
||||
const msgChannelExtendedData = 95 |
||||
const msgChannelData = 94 |
||||
|
||||
// Used for debug print outs of packets.
|
||||
type channelDataMsg struct { |
||||
PeersId uint32 `sshtype:"94"` |
||||
Length uint32 |
||||
Rest []byte `ssh:"rest"` |
||||
} |
||||
|
||||
// See RFC 4254, section 5.1.
|
||||
const msgChannelOpenConfirm = 91 |
||||
|
||||
type channelOpenConfirmMsg struct { |
||||
PeersId uint32 `sshtype:"91"` |
||||
MyId uint32 |
||||
MyWindow uint32 |
||||
MaxPacketSize uint32 |
||||
TypeSpecificData []byte `ssh:"rest"` |
||||
} |
||||
|
||||
// See RFC 4254, section 5.1.
|
||||
const msgChannelOpenFailure = 92 |
||||
|
||||
type channelOpenFailureMsg struct { |
||||
PeersId uint32 `sshtype:"92"` |
||||
Reason RejectionReason |
||||
Message string |
||||
Language string |
||||
} |
||||
|
||||
const msgChannelRequest = 98 |
||||
|
||||
type channelRequestMsg struct { |
||||
PeersId uint32 `sshtype:"98"` |
||||
Request string |
||||
WantReply bool |
||||
RequestSpecificData []byte `ssh:"rest"` |
||||
} |
||||
|
||||
// See RFC 4254, section 5.4.
|
||||
const msgChannelSuccess = 99 |
||||
|
||||
type channelRequestSuccessMsg struct { |
||||
PeersId uint32 `sshtype:"99"` |
||||
} |
||||
|
||||
// See RFC 4254, section 5.4.
|
||||
const msgChannelFailure = 100 |
||||
|
||||
type channelRequestFailureMsg struct { |
||||
PeersId uint32 `sshtype:"100"` |
||||
} |
||||
|
||||
// See RFC 4254, section 5.3
|
||||
const msgChannelClose = 97 |
||||
|
||||
type channelCloseMsg struct { |
||||
PeersId uint32 `sshtype:"97"` |
||||
} |
||||
|
||||
// See RFC 4254, section 5.3
|
||||
const msgChannelEOF = 96 |
||||
|
||||
type channelEOFMsg struct { |
||||
PeersId uint32 `sshtype:"96"` |
||||
} |
||||
|
||||
// See RFC 4254, section 4
|
||||
const msgGlobalRequest = 80 |
||||
|
||||
type globalRequestMsg struct { |
||||
Type string `sshtype:"80"` |
||||
WantReply bool |
||||
Data []byte `ssh:"rest"` |
||||
} |
||||
|
||||
// See RFC 4254, section 4
|
||||
const msgRequestSuccess = 81 |
||||
|
||||
type globalRequestSuccessMsg struct { |
||||
Data []byte `ssh:"rest" sshtype:"81"` |
||||
} |
||||
|
||||
// See RFC 4254, section 4
|
||||
const msgRequestFailure = 82 |
||||
|
||||
type globalRequestFailureMsg struct { |
||||
Data []byte `ssh:"rest" sshtype:"82"` |
||||
} |
||||
|
||||
// See RFC 4254, section 5.2
|
||||
const msgChannelWindowAdjust = 93 |
||||
|
||||
type windowAdjustMsg struct { |
||||
PeersId uint32 `sshtype:"93"` |
||||
AdditionalBytes uint32 |
||||
} |
||||
|
||||
// See RFC 4252, section 7
|
||||
const msgUserAuthPubKeyOk = 60 |
||||
|
||||
type userAuthPubKeyOkMsg struct { |
||||
Algo string `sshtype:"60"` |
||||
PubKey []byte |
||||
} |
||||
|
||||
// typeTags returns the possible type bytes for the given reflect.Type, which
|
||||
// should be a struct. The possible values are separated by a '|' character.
|
||||
func typeTags(structType reflect.Type) (tags []byte) { |
||||
tagStr := structType.Field(0).Tag.Get("sshtype") |
||||
|
||||
for _, tag := range strings.Split(tagStr, "|") { |
||||
i, err := strconv.Atoi(tag) |
||||
if err == nil { |
||||
tags = append(tags, byte(i)) |
||||
} |
||||
} |
||||
|
||||
return tags |
||||
} |
||||
|
||||
func fieldError(t reflect.Type, field int, problem string) error { |
||||
if problem != "" { |
||||
problem = ": " + problem |
||||
} |
||||
return fmt.Errorf("ssh: unmarshal error for field %s of type %s%s", t.Field(field).Name, t.Name(), problem) |
||||
} |
||||
|
||||
var errShortRead = errors.New("ssh: short read") |
||||
|
||||
// Unmarshal parses data in SSH wire format into a structure. The out
|
||||
// argument should be a pointer to struct. If the first member of the
|
||||
// struct has the "sshtype" tag set to a '|'-separated set of numbers
|
||||
// in decimal, the packet must start with one of those numbers. In
|
||||
// case of error, Unmarshal returns a ParseError or
|
||||
// UnexpectedMessageError.
|
||||
func Unmarshal(data []byte, out interface{}) error { |
||||
v := reflect.ValueOf(out).Elem() |
||||
structType := v.Type() |
||||
expectedTypes := typeTags(structType) |
||||
|
||||
var expectedType byte |
||||
if len(expectedTypes) > 0 { |
||||
expectedType = expectedTypes[0] |
||||
} |
||||
|
||||
if len(data) == 0 { |
||||
return parseError(expectedType) |
||||
} |
||||
|
||||
if len(expectedTypes) > 0 { |
||||
goodType := false |
||||
for _, e := range expectedTypes { |
||||
if e > 0 && data[0] == e { |
||||
goodType = true |
||||
break |
||||
} |
||||
} |
||||
if !goodType { |
||||
return fmt.Errorf("ssh: unexpected message type %d (expected one of %v)", data[0], expectedTypes) |
||||
} |
||||
data = data[1:] |
||||
} |
||||
|
||||
var ok bool |
||||
for i := 0; i < v.NumField(); i++ { |
||||
field := v.Field(i) |
||||
t := field.Type() |
||||
switch t.Kind() { |
||||
case reflect.Bool: |
||||
if len(data) < 1 { |
||||
return errShortRead |
||||
} |
||||
field.SetBool(data[0] != 0) |
||||
data = data[1:] |
||||
case reflect.Array: |
||||
if t.Elem().Kind() != reflect.Uint8 { |
||||
return fieldError(structType, i, "array of unsupported type") |
||||
} |
||||
if len(data) < t.Len() { |
||||
return errShortRead |
||||
} |
||||
for j, n := 0, t.Len(); j < n; j++ { |
||||
field.Index(j).Set(reflect.ValueOf(data[j])) |
||||
} |
||||
data = data[t.Len():] |
||||
case reflect.Uint64: |
||||
var u64 uint64 |
||||
if u64, data, ok = parseUint64(data); !ok { |
||||
return errShortRead |
||||
} |
||||
field.SetUint(u64) |
||||
case reflect.Uint32: |
||||
var u32 uint32 |
||||
if u32, data, ok = parseUint32(data); !ok { |
||||
return errShortRead |
||||
} |
||||
field.SetUint(uint64(u32)) |
||||
case reflect.Uint8: |
||||
if len(data) < 1 { |
||||
return errShortRead |
||||
} |
||||
field.SetUint(uint64(data[0])) |
||||
data = data[1:] |
||||
case reflect.String: |
||||
var s []byte |
||||
if s, data, ok = parseString(data); !ok { |
||||
return fieldError(structType, i, "") |
||||
} |
||||
field.SetString(string(s)) |
||||
case reflect.Slice: |
||||
switch t.Elem().Kind() { |
||||
case reflect.Uint8: |
||||
if structType.Field(i).Tag.Get("ssh") == "rest" { |
||||
field.Set(reflect.ValueOf(data)) |
||||
data = nil |
||||
} else { |
||||
var s []byte |
||||
if s, data, ok = parseString(data); !ok { |
||||
return errShortRead |
||||
} |
||||
field.Set(reflect.ValueOf(s)) |
||||
} |
||||
case reflect.String: |
||||
var nl []string |
||||
if nl, data, ok = parseNameList(data); !ok { |
||||
return errShortRead |
||||
} |
||||
field.Set(reflect.ValueOf(nl)) |
||||
default: |
||||
return fieldError(structType, i, "slice of unsupported type") |
||||
} |
||||
case reflect.Ptr: |
||||
if t == bigIntType { |
||||
var n *big.Int |
||||
if n, data, ok = parseInt(data); !ok { |
||||
return errShortRead |
||||
} |
||||
field.Set(reflect.ValueOf(n)) |
||||
} else { |
||||
return fieldError(structType, i, "pointer to unsupported type") |
||||
} |
||||
default: |
||||
return fieldError(structType, i, fmt.Sprintf("unsupported type: %v", t)) |
||||
} |
||||
} |
||||
|
||||
if len(data) != 0 { |
||||
return parseError(expectedType) |
||||
} |
||||
|
||||
return nil |
||||
} |
||||
|
||||
// Marshal serializes the message in msg to SSH wire format. The msg
|
||||
// argument should be a struct or pointer to struct. If the first
|
||||
// member has the "sshtype" tag set to a number in decimal, that
|
||||
// number is prepended to the result. If the last of member has the
|
||||
// "ssh" tag set to "rest", its contents are appended to the output.
|
||||
func Marshal(msg interface{}) []byte { |
||||
out := make([]byte, 0, 64) |
||||
return marshalStruct(out, msg) |
||||
} |
||||
|
||||
func marshalStruct(out []byte, msg interface{}) []byte { |
||||
v := reflect.Indirect(reflect.ValueOf(msg)) |
||||
msgTypes := typeTags(v.Type()) |
||||
if len(msgTypes) > 0 { |
||||
out = append(out, msgTypes[0]) |
||||
} |
||||
|
||||
for i, n := 0, v.NumField(); i < n; i++ { |
||||
field := v.Field(i) |
||||
switch t := field.Type(); t.Kind() { |
||||
case reflect.Bool: |
||||
var v uint8 |
||||
if field.Bool() { |
||||
v = 1 |
||||
} |
||||
out = append(out, v) |
||||
case reflect.Array: |
||||
if t.Elem().Kind() != reflect.Uint8 { |
||||
panic(fmt.Sprintf("array of non-uint8 in field %d: %T", i, field.Interface())) |
||||
} |
||||
for j, l := 0, t.Len(); j < l; j++ { |
||||
out = append(out, uint8(field.Index(j).Uint())) |
||||
} |
||||
case reflect.Uint32: |
||||
out = appendU32(out, uint32(field.Uint())) |
||||
case reflect.Uint64: |
||||
out = appendU64(out, uint64(field.Uint())) |
||||
case reflect.Uint8: |
||||
out = append(out, uint8(field.Uint())) |
||||
case reflect.String: |
||||
s := field.String() |
||||
out = appendInt(out, len(s)) |
||||
out = append(out, s...) |
||||
case reflect.Slice: |
||||
switch t.Elem().Kind() { |
||||
case reflect.Uint8: |
||||
if v.Type().Field(i).Tag.Get("ssh") != "rest" { |
||||
out = appendInt(out, field.Len()) |
||||
} |
||||
out = append(out, field.Bytes()...) |
||||
case reflect.String: |
||||
offset := len(out) |
||||
out = appendU32(out, 0) |
||||
if n := field.Len(); n > 0 { |
||||
for j := 0; j < n; j++ { |
||||
f := field.Index(j) |
||||
if j != 0 { |
||||
out = append(out, ',') |
||||
} |
||||
out = append(out, f.String()...) |
||||
} |
||||
// overwrite length value
|
||||
binary.BigEndian.PutUint32(out[offset:], uint32(len(out)-offset-4)) |
||||
} |
||||
default: |
||||
panic(fmt.Sprintf("slice of unknown type in field %d: %T", i, field.Interface())) |
||||
} |
||||
case reflect.Ptr: |
||||
if t == bigIntType { |
||||
var n *big.Int |
||||
nValue := reflect.ValueOf(&n) |
||||
nValue.Elem().Set(field) |
||||
needed := intLength(n) |
||||
oldLength := len(out) |
||||
|
||||
if cap(out)-len(out) < needed { |
||||
newOut := make([]byte, len(out), 2*(len(out)+needed)) |
||||
copy(newOut, out) |
||||
out = newOut |
||||
} |
||||
out = out[:oldLength+needed] |
||||
marshalInt(out[oldLength:], n) |
||||
} else { |
||||
panic(fmt.Sprintf("pointer to unknown type in field %d: %T", i, field.Interface())) |
||||
} |
||||
} |
||||
} |
||||
|
||||
return out |
||||
} |
||||
|
||||
var bigOne = big.NewInt(1) |
||||
|
||||
func parseString(in []byte) (out, rest []byte, ok bool) { |
||||
if len(in) < 4 { |
||||
return |
||||
} |
||||
length := binary.BigEndian.Uint32(in) |
||||
in = in[4:] |
||||
if uint32(len(in)) < length { |
||||
return |
||||
} |
||||
out = in[:length] |
||||
rest = in[length:] |
||||
ok = true |
||||
return |
||||
} |
||||
|
||||
var ( |
||||
comma = []byte{','} |
||||
emptyNameList = []string{} |
||||
) |
||||
|
||||
func parseNameList(in []byte) (out []string, rest []byte, ok bool) { |
||||
contents, rest, ok := parseString(in) |
||||
if !ok { |
||||
return |
||||
} |
||||
if len(contents) == 0 { |
||||
out = emptyNameList |
||||
return |
||||
} |
||||
parts := bytes.Split(contents, comma) |
||||
out = make([]string, len(parts)) |
||||
for i, part := range parts { |
||||
out[i] = string(part) |
||||
} |
||||
return |
||||
} |
||||
|
||||
func parseInt(in []byte) (out *big.Int, rest []byte, ok bool) { |
||||
contents, rest, ok := parseString(in) |
||||
if !ok { |
||||
return |
||||
} |
||||
out = new(big.Int) |
||||
|
||||
if len(contents) > 0 && contents[0]&0x80 == 0x80 { |
||||
// This is a negative number
|
||||
notBytes := make([]byte, len(contents)) |
||||
for i := range notBytes { |
||||
notBytes[i] = ^contents[i] |
||||
} |
||||
out.SetBytes(notBytes) |
||||
out.Add(out, bigOne) |
||||
out.Neg(out) |
||||
} else { |
||||
// Positive number
|
||||
out.SetBytes(contents) |
||||
} |
||||
ok = true |
||||
return |
||||
} |
||||
|
||||
func parseUint32(in []byte) (uint32, []byte, bool) { |
||||
if len(in) < 4 { |
||||
return 0, nil, false |
||||
} |
||||
return binary.BigEndian.Uint32(in), in[4:], true |
||||
} |
||||
|
||||
func parseUint64(in []byte) (uint64, []byte, bool) { |
||||
if len(in) < 8 { |
||||
return 0, nil, false |
||||
} |
||||
return binary.BigEndian.Uint64(in), in[8:], true |
||||
} |
||||
|
||||
func intLength(n *big.Int) int { |
||||
length := 4 /* length bytes */ |
||||
if n.Sign() < 0 { |
||||
nMinus1 := new(big.Int).Neg(n) |
||||
nMinus1.Sub(nMinus1, bigOne) |
||||
bitLen := nMinus1.BitLen() |
||||
if bitLen%8 == 0 { |
||||
// The number will need 0xff padding
|
||||
length++ |
||||
} |
||||
length += (bitLen + 7) / 8 |
||||
} else if n.Sign() == 0 { |
||||
// A zero is the zero length string
|
||||
} else { |
||||
bitLen := n.BitLen() |
||||
if bitLen%8 == 0 { |
||||
// The number will need 0x00 padding
|
||||
length++ |
||||
} |
||||
length += (bitLen + 7) / 8 |
||||
} |
||||
|
||||
return length |
||||
} |
||||
|
||||
func marshalUint32(to []byte, n uint32) []byte { |
||||
binary.BigEndian.PutUint32(to, n) |
||||
return to[4:] |
||||
} |
||||
|
||||
func marshalUint64(to []byte, n uint64) []byte { |
||||
binary.BigEndian.PutUint64(to, n) |
||||
return to[8:] |
||||
} |
||||
|
||||
func marshalInt(to []byte, n *big.Int) []byte { |
||||
lengthBytes := to |
||||
to = to[4:] |
||||
length := 0 |
||||
|
||||
if n.Sign() < 0 { |
||||
// A negative number has to be converted to two's-complement
|
||||
// form. So we'll subtract 1 and invert. If the
|
||||
// most-significant-bit isn't set then we'll need to pad the
|
||||
// beginning with 0xff in order to keep the number negative.
|
||||
nMinus1 := new(big.Int).Neg(n) |
||||
nMinus1.Sub(nMinus1, bigOne) |
||||
bytes := nMinus1.Bytes() |
||||
for i := range bytes { |
||||
bytes[i] ^= 0xff |
||||
} |
||||
if len(bytes) == 0 || bytes[0]&0x80 == 0 { |
||||
to[0] = 0xff |
||||
to = to[1:] |
||||
length++ |
||||
} |
||||
nBytes := copy(to, bytes) |
||||
to = to[nBytes:] |
||||
length += nBytes |
||||
} else if n.Sign() == 0 { |
||||
// A zero is the zero length string
|
||||
} else { |
||||
bytes := n.Bytes() |
||||
if len(bytes) > 0 && bytes[0]&0x80 != 0 { |
||||
// We'll have to pad this with a 0x00 in order to
|
||||
// stop it looking like a negative number.
|
||||
to[0] = 0 |
||||
to = to[1:] |
||||
length++ |
||||
} |
||||
nBytes := copy(to, bytes) |
||||
to = to[nBytes:] |
||||
length += nBytes |
||||
} |
||||
|
||||
lengthBytes[0] = byte(length >> 24) |
||||
lengthBytes[1] = byte(length >> 16) |
||||
lengthBytes[2] = byte(length >> 8) |
||||
lengthBytes[3] = byte(length) |
||||
return to |
||||
} |
||||
|
||||
func writeInt(w io.Writer, n *big.Int) { |
||||
length := intLength(n) |
||||
buf := make([]byte, length) |
||||
marshalInt(buf, n) |
||||
w.Write(buf) |
||||
} |
||||
|
||||
func writeString(w io.Writer, s []byte) { |
||||
var lengthBytes [4]byte |
||||
lengthBytes[0] = byte(len(s) >> 24) |
||||
lengthBytes[1] = byte(len(s) >> 16) |
||||
lengthBytes[2] = byte(len(s) >> 8) |
||||
lengthBytes[3] = byte(len(s)) |
||||
w.Write(lengthBytes[:]) |
||||
w.Write(s) |
||||
} |
||||
|
||||
func stringLength(n int) int { |
||||
return 4 + n |
||||
} |
||||
|
||||
func marshalString(to []byte, s []byte) []byte { |
||||
to[0] = byte(len(s) >> 24) |
||||
to[1] = byte(len(s) >> 16) |
||||
to[2] = byte(len(s) >> 8) |
||||
to[3] = byte(len(s)) |
||||
to = to[4:] |
||||
copy(to, s) |
||||
return to[len(s):] |
||||
} |
||||
|
||||
var bigIntType = reflect.TypeOf((*big.Int)(nil)) |
||||
|
||||
// Decode a packet into its corresponding message.
|
||||
func decode(packet []byte) (interface{}, error) { |
||||
var msg interface{} |
||||
switch packet[0] { |
||||
case msgDisconnect: |
||||
msg = new(disconnectMsg) |
||||
case msgServiceRequest: |
||||
msg = new(serviceRequestMsg) |
||||
case msgServiceAccept: |
||||
msg = new(serviceAcceptMsg) |
||||
case msgKexInit: |
||||
msg = new(kexInitMsg) |
||||
case msgKexDHInit: |
||||
msg = new(kexDHInitMsg) |
||||
case msgKexDHReply: |
||||
msg = new(kexDHReplyMsg) |
||||
case msgUserAuthRequest: |
||||
msg = new(userAuthRequestMsg) |
||||
case msgUserAuthSuccess: |
||||
return new(userAuthSuccessMsg), nil |
||||
case msgUserAuthFailure: |
||||
msg = new(userAuthFailureMsg) |
||||
case msgUserAuthPubKeyOk: |
||||
msg = new(userAuthPubKeyOkMsg) |
||||
case msgGlobalRequest: |
||||
msg = new(globalRequestMsg) |
||||
case msgRequestSuccess: |
||||
msg = new(globalRequestSuccessMsg) |
||||
case msgRequestFailure: |
||||
msg = new(globalRequestFailureMsg) |
||||
case msgChannelOpen: |
||||
msg = new(channelOpenMsg) |
||||
case msgChannelData: |
||||
msg = new(channelDataMsg) |
||||
case msgChannelOpenConfirm: |
||||
msg = new(channelOpenConfirmMsg) |
||||
case msgChannelOpenFailure: |
||||
msg = new(channelOpenFailureMsg) |
||||
case msgChannelWindowAdjust: |
||||
msg = new(windowAdjustMsg) |
||||
case msgChannelEOF: |
||||
msg = new(channelEOFMsg) |
||||
case msgChannelClose: |
||||
msg = new(channelCloseMsg) |
||||
case msgChannelRequest: |
||||
msg = new(channelRequestMsg) |
||||
case msgChannelSuccess: |
||||
msg = new(channelRequestSuccessMsg) |
||||
case msgChannelFailure: |
||||
msg = new(channelRequestFailureMsg) |
||||
default: |
||||
return nil, unexpectedMessageError(0, packet[0]) |
||||
} |
||||
if err := Unmarshal(packet, msg); err != nil { |
||||
return nil, err |
||||
} |
||||
return msg, nil |
||||
} |
@ -0,0 +1,330 @@ |
||||
// Copyright 2013 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package ssh |
||||
|
||||
import ( |
||||
"encoding/binary" |
||||
"fmt" |
||||
"io" |
||||
"log" |
||||
"sync" |
||||
"sync/atomic" |
||||
) |
||||
|
||||
// debugMux, if set, causes messages in the connection protocol to be
|
||||
// logged.
|
||||
const debugMux = false |
||||
|
||||
// chanList is a thread safe channel list.
|
||||
type chanList struct { |
||||
// protects concurrent access to chans
|
||||
sync.Mutex |
||||
|
||||
// chans are indexed by the local id of the channel, which the
|
||||
// other side should send in the PeersId field.
|
||||
chans []*channel |
||||
|
||||
// This is a debugging aid: it offsets all IDs by this
|
||||
// amount. This helps distinguish otherwise identical
|
||||
// server/client muxes
|
||||
offset uint32 |
||||
} |
||||
|
||||
// Assigns a channel ID to the given channel.
|
||||
func (c *chanList) add(ch *channel) uint32 { |
||||
c.Lock() |
||||
defer c.Unlock() |
||||
for i := range c.chans { |
||||
if c.chans[i] == nil { |
||||
c.chans[i] = ch |
||||
return uint32(i) + c.offset |
||||
} |
||||
} |
||||
c.chans = append(c.chans, ch) |
||||
return uint32(len(c.chans)-1) + c.offset |
||||
} |
||||
|
||||
// getChan returns the channel for the given ID.
|
||||
func (c *chanList) getChan(id uint32) *channel { |
||||
id -= c.offset |
||||
|
||||
c.Lock() |
||||
defer c.Unlock() |
||||
if id < uint32(len(c.chans)) { |
||||
return c.chans[id] |
||||
} |
||||
return nil |
||||
} |
||||
|
||||
func (c *chanList) remove(id uint32) { |
||||
id -= c.offset |
||||
c.Lock() |
||||
if id < uint32(len(c.chans)) { |
||||
c.chans[id] = nil |
||||
} |
||||
c.Unlock() |
||||
} |
||||
|
||||
// dropAll forgets all channels it knows, returning them in a slice.
|
||||
func (c *chanList) dropAll() []*channel { |
||||
c.Lock() |
||||
defer c.Unlock() |
||||
var r []*channel |
||||
|
||||
for _, ch := range c.chans { |
||||
if ch == nil { |
||||
continue |
||||
} |
||||
r = append(r, ch) |
||||
} |
||||
c.chans = nil |
||||
return r |
||||
} |
||||
|
||||
// mux represents the state for the SSH connection protocol, which
|
||||
// multiplexes many channels onto a single packet transport.
|
||||
type mux struct { |
||||
conn packetConn |
||||
chanList chanList |
||||
|
||||
incomingChannels chan NewChannel |
||||
|
||||
globalSentMu sync.Mutex |
||||
globalResponses chan interface{} |
||||
incomingRequests chan *Request |
||||
|
||||
errCond *sync.Cond |
||||
err error |
||||
} |
||||
|
||||
// When debugging, each new chanList instantiation has a different
|
||||
// offset.
|
||||
var globalOff uint32 |
||||
|
||||
func (m *mux) Wait() error { |
||||
m.errCond.L.Lock() |
||||
defer m.errCond.L.Unlock() |
||||
for m.err == nil { |
||||
m.errCond.Wait() |
||||
} |
||||
return m.err |
||||
} |
||||
|
||||
// newMux returns a mux that runs over the given connection.
|
||||
func newMux(p packetConn) *mux { |
||||
m := &mux{ |
||||
conn: p, |
||||
incomingChannels: make(chan NewChannel, chanSize), |
||||
globalResponses: make(chan interface{}, 1), |
||||
incomingRequests: make(chan *Request, chanSize), |
||||
errCond: newCond(), |
||||
} |
||||
if debugMux { |
||||
m.chanList.offset = atomic.AddUint32(&globalOff, 1) |
||||
} |
||||
|
||||
go m.loop() |
||||
return m |
||||
} |
||||
|
||||
func (m *mux) sendMessage(msg interface{}) error { |
||||
p := Marshal(msg) |
||||
if debugMux { |
||||
log.Printf("send global(%d): %#v", m.chanList.offset, msg) |
||||
} |
||||
return m.conn.writePacket(p) |
||||
} |
||||
|
||||
func (m *mux) SendRequest(name string, wantReply bool, payload []byte) (bool, []byte, error) { |
||||
if wantReply { |
||||
m.globalSentMu.Lock() |
||||
defer m.globalSentMu.Unlock() |
||||
} |
||||
|
||||
if err := m.sendMessage(globalRequestMsg{ |
||||
Type: name, |
||||
WantReply: wantReply, |
||||
Data: payload, |
||||
}); err != nil { |
||||
return false, nil, err |
||||
} |
||||
|
||||
if !wantReply { |
||||
return false, nil, nil |
||||
} |
||||
|
||||
msg, ok := <-m.globalResponses |
||||
if !ok { |
||||
return false, nil, io.EOF |
||||
} |
||||
switch msg := msg.(type) { |
||||
case *globalRequestFailureMsg: |
||||
return false, msg.Data, nil |
||||
case *globalRequestSuccessMsg: |
||||
return true, msg.Data, nil |
||||
default: |
||||
return false, nil, fmt.Errorf("ssh: unexpected response to request: %#v", msg) |
||||
} |
||||
} |
||||
|
||||
// ackRequest must be called after processing a global request that
|
||||
// has WantReply set.
|
||||
func (m *mux) ackRequest(ok bool, data []byte) error { |
||||
if ok { |
||||
return m.sendMessage(globalRequestSuccessMsg{Data: data}) |
||||
} |
||||
return m.sendMessage(globalRequestFailureMsg{Data: data}) |
||||
} |
||||
|
||||
func (m *mux) Close() error { |
||||
return m.conn.Close() |
||||
} |
||||
|
||||
// loop runs the connection machine. It will process packets until an
|
||||
// error is encountered. To synchronize on loop exit, use mux.Wait.
|
||||
func (m *mux) loop() { |
||||
var err error |
||||
for err == nil { |
||||
err = m.onePacket() |
||||
} |
||||
|
||||
for _, ch := range m.chanList.dropAll() { |
||||
ch.close() |
||||
} |
||||
|
||||
close(m.incomingChannels) |
||||
close(m.incomingRequests) |
||||
close(m.globalResponses) |
||||
|
||||
m.conn.Close() |
||||
|
||||
m.errCond.L.Lock() |
||||
m.err = err |
||||
m.errCond.Broadcast() |
||||
m.errCond.L.Unlock() |
||||
|
||||
if debugMux { |
||||
log.Println("loop exit", err) |
||||
} |
||||
} |
||||
|
||||
// onePacket reads and processes one packet.
|
||||
func (m *mux) onePacket() error { |
||||
packet, err := m.conn.readPacket() |
||||
if err != nil { |
||||
return err |
||||
} |
||||
|
||||
if debugMux { |
||||
if packet[0] == msgChannelData || packet[0] == msgChannelExtendedData { |
||||
log.Printf("decoding(%d): data packet - %d bytes", m.chanList.offset, len(packet)) |
||||
} else { |
||||
p, _ := decode(packet) |
||||
log.Printf("decoding(%d): %d %#v - %d bytes", m.chanList.offset, packet[0], p, len(packet)) |
||||
} |
||||
} |
||||
|
||||
switch packet[0] { |
||||
case msgChannelOpen: |
||||
return m.handleChannelOpen(packet) |
||||
case msgGlobalRequest, msgRequestSuccess, msgRequestFailure: |
||||
return m.handleGlobalPacket(packet) |
||||
} |
||||
|
||||
// assume a channel packet.
|
||||
if len(packet) < 5 { |
||||
return parseError(packet[0]) |
||||
} |
||||
id := binary.BigEndian.Uint32(packet[1:]) |
||||
ch := m.chanList.getChan(id) |
||||
if ch == nil { |
||||
return fmt.Errorf("ssh: invalid channel %d", id) |
||||
} |
||||
|
||||
return ch.handlePacket(packet) |
||||
} |
||||
|
||||
func (m *mux) handleGlobalPacket(packet []byte) error { |
||||
msg, err := decode(packet) |
||||
if err != nil { |
||||
return err |
||||
} |
||||
|
||||
switch msg := msg.(type) { |
||||
case *globalRequestMsg: |
||||
m.incomingRequests <- &Request{ |
||||
Type: msg.Type, |
||||
WantReply: msg.WantReply, |
||||
Payload: msg.Data, |
||||
mux: m, |
||||
} |
||||
case *globalRequestSuccessMsg, *globalRequestFailureMsg: |
||||
m.globalResponses <- msg |
||||
default: |
||||
panic(fmt.Sprintf("not a global message %#v", msg)) |
||||
} |
||||
|
||||
return nil |
||||
} |
||||
|
||||
// handleChannelOpen schedules a channel to be Accept()ed.
|
||||
func (m *mux) handleChannelOpen(packet []byte) error { |
||||
var msg channelOpenMsg |
||||
if err := Unmarshal(packet, &msg); err != nil { |
||||
return err |
||||
} |
||||
|
||||
if msg.MaxPacketSize < minPacketLength || msg.MaxPacketSize > 1<<31 { |
||||
failMsg := channelOpenFailureMsg{ |
||||
PeersId: msg.PeersId, |
||||
Reason: ConnectionFailed, |
||||
Message: "invalid request", |
||||
Language: "en_US.UTF-8", |
||||
} |
||||
return m.sendMessage(failMsg) |
||||
} |
||||
|
||||
c := m.newChannel(msg.ChanType, channelInbound, msg.TypeSpecificData) |
||||
c.remoteId = msg.PeersId |
||||
c.maxRemotePayload = msg.MaxPacketSize |
||||
c.remoteWin.add(msg.PeersWindow) |
||||
m.incomingChannels <- c |
||||
return nil |
||||
} |
||||
|
||||
func (m *mux) OpenChannel(chanType string, extra []byte) (Channel, <-chan *Request, error) { |
||||
ch, err := m.openChannel(chanType, extra) |
||||
if err != nil { |
||||
return nil, nil, err |
||||
} |
||||
|
||||
return ch, ch.incomingRequests, nil |
||||
} |
||||
|
||||
func (m *mux) openChannel(chanType string, extra []byte) (*channel, error) { |
||||
ch := m.newChannel(chanType, channelOutbound, extra) |
||||
|
||||
ch.maxIncomingPayload = channelMaxPacket |
||||
|
||||
open := channelOpenMsg{ |
||||
ChanType: chanType, |
||||
PeersWindow: ch.myWindow, |
||||
MaxPacketSize: ch.maxIncomingPayload, |
||||
TypeSpecificData: extra, |
||||
PeersId: ch.localId, |
||||
} |
||||
if err := m.sendMessage(open); err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
switch msg := (<-ch.msg).(type) { |
||||
case *channelOpenConfirmMsg: |
||||
return ch, nil |
||||
case *channelOpenFailureMsg: |
||||
return nil, &OpenChannelError{msg.Reason, msg.Message} |
||||
default: |
||||
return nil, fmt.Errorf("ssh: unexpected packet in response to channel open: %T", msg) |
||||
} |
||||
} |
@ -0,0 +1,491 @@ |
||||
// Copyright 2011 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package ssh |
||||
|
||||
import ( |
||||
"bytes" |
||||
"errors" |
||||
"fmt" |
||||
"io" |
||||
"net" |
||||
"strings" |
||||
) |
||||
|
||||
// The Permissions type holds fine-grained permissions that are
|
||||
// specific to a user or a specific authentication method for a
|
||||
// user. Permissions, except for "source-address", must be enforced in
|
||||
// the server application layer, after successful authentication. The
|
||||
// Permissions are passed on in ServerConn so a server implementation
|
||||
// can honor them.
|
||||
type Permissions struct { |
||||
// Critical options restrict default permissions. Common
|
||||
// restrictions are "source-address" and "force-command". If
|
||||
// the server cannot enforce the restriction, or does not
|
||||
// recognize it, the user should not authenticate.
|
||||
CriticalOptions map[string]string |
||||
|
||||
// Extensions are extra functionality that the server may
|
||||
// offer on authenticated connections. Common extensions are
|
||||
// "permit-agent-forwarding", "permit-X11-forwarding". Lack of
|
||||
// support for an extension does not preclude authenticating a
|
||||
// user.
|
||||
Extensions map[string]string |
||||
} |
||||
|
||||
// ServerConfig holds server specific configuration data.
|
||||
type ServerConfig struct { |
||||
// Config contains configuration shared between client and server.
|
||||
Config |
||||
|
||||
hostKeys []Signer |
||||
|
||||
// NoClientAuth is true if clients are allowed to connect without
|
||||
// authenticating.
|
||||
NoClientAuth bool |
||||
|
||||
// PasswordCallback, if non-nil, is called when a user
|
||||
// attempts to authenticate using a password.
|
||||
PasswordCallback func(conn ConnMetadata, password []byte) (*Permissions, error) |
||||
|
||||
// PublicKeyCallback, if non-nil, is called when a client attempts public
|
||||
// key authentication. It must return true if the given public key is
|
||||
// valid for the given user. For example, see CertChecker.Authenticate.
|
||||
PublicKeyCallback func(conn ConnMetadata, key PublicKey) (*Permissions, error) |
||||
|
||||
// KeyboardInteractiveCallback, if non-nil, is called when
|
||||
// keyboard-interactive authentication is selected (RFC
|
||||
// 4256). The client object's Challenge function should be
|
||||
// used to query the user. The callback may offer multiple
|
||||
// Challenge rounds. To avoid information leaks, the client
|
||||
// should be presented a challenge even if the user is
|
||||
// unknown.
|
||||
KeyboardInteractiveCallback func(conn ConnMetadata, client KeyboardInteractiveChallenge) (*Permissions, error) |
||||
|
||||
// AuthLogCallback, if non-nil, is called to log all authentication
|
||||
// attempts.
|
||||
AuthLogCallback func(conn ConnMetadata, method string, err error) |
||||
|
||||
// ServerVersion is the version identification string to announce in
|
||||
// the public handshake.
|
||||
// If empty, a reasonable default is used.
|
||||
// Note that RFC 4253 section 4.2 requires that this string start with
|
||||
// "SSH-2.0-".
|
||||
ServerVersion string |
||||
} |
||||
|
||||
// AddHostKey adds a private key as a host key. If an existing host
|
||||
// key exists with the same algorithm, it is overwritten. Each server
|
||||
// config must have at least one host key.
|
||||
func (s *ServerConfig) AddHostKey(key Signer) { |
||||
for i, k := range s.hostKeys { |
||||
if k.PublicKey().Type() == key.PublicKey().Type() { |
||||
s.hostKeys[i] = key |
||||
return |
||||
} |
||||
} |
||||
|
||||
s.hostKeys = append(s.hostKeys, key) |
||||
} |
||||
|
||||
// cachedPubKey contains the results of querying whether a public key is
|
||||
// acceptable for a user.
|
||||
type cachedPubKey struct { |
||||
user string |
||||
pubKeyData []byte |
||||
result error |
||||
perms *Permissions |
||||
} |
||||
|
||||
const maxCachedPubKeys = 16 |
||||
|
||||
// pubKeyCache caches tests for public keys. Since SSH clients
|
||||
// will query whether a public key is acceptable before attempting to
|
||||
// authenticate with it, we end up with duplicate queries for public
|
||||
// key validity. The cache only applies to a single ServerConn.
|
||||
type pubKeyCache struct { |
||||
keys []cachedPubKey |
||||
} |
||||
|
||||
// get returns the result for a given user/algo/key tuple.
|
||||
func (c *pubKeyCache) get(user string, pubKeyData []byte) (cachedPubKey, bool) { |
||||
for _, k := range c.keys { |
||||
if k.user == user && bytes.Equal(k.pubKeyData, pubKeyData) { |
||||
return k, true |
||||
} |
||||
} |
||||
return cachedPubKey{}, false |
||||
} |
||||
|
||||
// add adds the given tuple to the cache.
|
||||
func (c *pubKeyCache) add(candidate cachedPubKey) { |
||||
if len(c.keys) < maxCachedPubKeys { |
||||
c.keys = append(c.keys, candidate) |
||||
} |
||||
} |
||||
|
||||
// ServerConn is an authenticated SSH connection, as seen from the
|
||||
// server
|
||||
type ServerConn struct { |
||||
Conn |
||||
|
||||
// If the succeeding authentication callback returned a
|
||||
// non-nil Permissions pointer, it is stored here.
|
||||
Permissions *Permissions |
||||
} |
||||
|
||||
// NewServerConn starts a new SSH server with c as the underlying
|
||||
// transport. It starts with a handshake and, if the handshake is
|
||||
// unsuccessful, it closes the connection and returns an error. The
|
||||
// Request and NewChannel channels must be serviced, or the connection
|
||||
// will hang.
|
||||
func NewServerConn(c net.Conn, config *ServerConfig) (*ServerConn, <-chan NewChannel, <-chan *Request, error) { |
||||
fullConf := *config |
||||
fullConf.SetDefaults() |
||||
s := &connection{ |
||||
sshConn: sshConn{conn: c}, |
||||
} |
||||
perms, err := s.serverHandshake(&fullConf) |
||||
if err != nil { |
||||
c.Close() |
||||
return nil, nil, nil, err |
||||
} |
||||
return &ServerConn{s, perms}, s.mux.incomingChannels, s.mux.incomingRequests, nil |
||||
} |
||||
|
||||
// signAndMarshal signs the data with the appropriate algorithm,
|
||||
// and serializes the result in SSH wire format.
|
||||
func signAndMarshal(k Signer, rand io.Reader, data []byte) ([]byte, error) { |
||||
sig, err := k.Sign(rand, data) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
return Marshal(sig), nil |
||||
} |
||||
|
||||
// handshake performs key exchange and user authentication.
|
||||
func (s *connection) serverHandshake(config *ServerConfig) (*Permissions, error) { |
||||
if len(config.hostKeys) == 0 { |
||||
return nil, errors.New("ssh: server has no host keys") |
||||
} |
||||
|
||||
if !config.NoClientAuth && config.PasswordCallback == nil && config.PublicKeyCallback == nil && config.KeyboardInteractiveCallback == nil { |
||||
return nil, errors.New("ssh: no authentication methods configured but NoClientAuth is also false") |
||||
} |
||||
|
||||
if config.ServerVersion != "" { |
||||
s.serverVersion = []byte(config.ServerVersion) |
||||
} else { |
||||
s.serverVersion = []byte(packageVersion) |
||||
} |
||||
var err error |
||||
s.clientVersion, err = exchangeVersions(s.sshConn.conn, s.serverVersion) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
tr := newTransport(s.sshConn.conn, config.Rand, false /* not client */) |
||||
s.transport = newServerTransport(tr, s.clientVersion, s.serverVersion, config) |
||||
|
||||
if err := s.transport.waitSession(); err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
// We just did the key change, so the session ID is established.
|
||||
s.sessionID = s.transport.getSessionID() |
||||
|
||||
var packet []byte |
||||
if packet, err = s.transport.readPacket(); err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
var serviceRequest serviceRequestMsg |
||||
if err = Unmarshal(packet, &serviceRequest); err != nil { |
||||
return nil, err |
||||
} |
||||
if serviceRequest.Service != serviceUserAuth { |
||||
return nil, errors.New("ssh: requested service '" + serviceRequest.Service + "' before authenticating") |
||||
} |
||||
serviceAccept := serviceAcceptMsg{ |
||||
Service: serviceUserAuth, |
||||
} |
||||
if err := s.transport.writePacket(Marshal(&serviceAccept)); err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
perms, err := s.serverAuthenticate(config) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
s.mux = newMux(s.transport) |
||||
return perms, err |
||||
} |
||||
|
||||
func isAcceptableAlgo(algo string) bool { |
||||
switch algo { |
||||
case KeyAlgoRSA, KeyAlgoDSA, KeyAlgoECDSA256, KeyAlgoECDSA384, KeyAlgoECDSA521, KeyAlgoED25519, |
||||
CertAlgoRSAv01, CertAlgoDSAv01, CertAlgoECDSA256v01, CertAlgoECDSA384v01, CertAlgoECDSA521v01: |
||||
return true |
||||
} |
||||
return false |
||||
} |
||||
|
||||
func checkSourceAddress(addr net.Addr, sourceAddrs string) error { |
||||
if addr == nil { |
||||
return errors.New("ssh: no address known for client, but source-address match required") |
||||
} |
||||
|
||||
tcpAddr, ok := addr.(*net.TCPAddr) |
||||
if !ok { |
||||
return fmt.Errorf("ssh: remote address %v is not an TCP address when checking source-address match", addr) |
||||
} |
||||
|
||||
for _, sourceAddr := range strings.Split(sourceAddrs, ",") { |
||||
if allowedIP := net.ParseIP(sourceAddr); allowedIP != nil { |
||||
if allowedIP.Equal(tcpAddr.IP) { |
||||
return nil |
||||
} |
||||
} else { |
||||
_, ipNet, err := net.ParseCIDR(sourceAddr) |
||||
if err != nil { |
||||
return fmt.Errorf("ssh: error parsing source-address restriction %q: %v", sourceAddr, err) |
||||
} |
||||
|
||||
if ipNet.Contains(tcpAddr.IP) { |
||||
return nil |
||||
} |
||||
} |
||||
} |
||||
|
||||
return fmt.Errorf("ssh: remote address %v is not allowed because of source-address restriction", addr) |
||||
} |
||||
|
||||
func (s *connection) serverAuthenticate(config *ServerConfig) (*Permissions, error) { |
||||
sessionID := s.transport.getSessionID() |
||||
var cache pubKeyCache |
||||
var perms *Permissions |
||||
|
||||
userAuthLoop: |
||||
for { |
||||
var userAuthReq userAuthRequestMsg |
||||
if packet, err := s.transport.readPacket(); err != nil { |
||||
return nil, err |
||||
} else if err = Unmarshal(packet, &userAuthReq); err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
if userAuthReq.Service != serviceSSH { |
||||
return nil, errors.New("ssh: client attempted to negotiate for unknown service: " + userAuthReq.Service) |
||||
} |
||||
|
||||
s.user = userAuthReq.User |
||||
perms = nil |
||||
authErr := errors.New("no auth passed yet") |
||||
|
||||
switch userAuthReq.Method { |
||||
case "none": |
||||
if config.NoClientAuth { |
||||
authErr = nil |
||||
} |
||||
case "password": |
||||
if config.PasswordCallback == nil { |
||||
authErr = errors.New("ssh: password auth not configured") |
||||
break |
||||
} |
||||
payload := userAuthReq.Payload |
||||
if len(payload) < 1 || payload[0] != 0 { |
||||
return nil, parseError(msgUserAuthRequest) |
||||
} |
||||
payload = payload[1:] |
||||
password, payload, ok := parseString(payload) |
||||
if !ok || len(payload) > 0 { |
||||
return nil, parseError(msgUserAuthRequest) |
||||
} |
||||
|
||||
perms, authErr = config.PasswordCallback(s, password) |
||||
case "keyboard-interactive": |
||||
if config.KeyboardInteractiveCallback == nil { |
||||
authErr = errors.New("ssh: keyboard-interactive auth not configubred") |
||||
break |
||||
} |
||||
|
||||
prompter := &sshClientKeyboardInteractive{s} |
||||
perms, authErr = config.KeyboardInteractiveCallback(s, prompter.Challenge) |
||||
case "publickey": |
||||
if config.PublicKeyCallback == nil { |
||||
authErr = errors.New("ssh: publickey auth not configured") |
||||
break |
||||
} |
||||
payload := userAuthReq.Payload |
||||
if len(payload) < 1 { |
||||
return nil, parseError(msgUserAuthRequest) |
||||
} |
||||
isQuery := payload[0] == 0 |
||||
payload = payload[1:] |
||||
algoBytes, payload, ok := parseString(payload) |
||||
if !ok { |
||||
return nil, parseError(msgUserAuthRequest) |
||||
} |
||||
algo := string(algoBytes) |
||||
if !isAcceptableAlgo(algo) { |
||||
authErr = fmt.Errorf("ssh: algorithm %q not accepted", algo) |
||||
break |
||||
} |
||||
|
||||
pubKeyData, payload, ok := parseString(payload) |
||||
if !ok { |
||||
return nil, parseError(msgUserAuthRequest) |
||||
} |
||||
|
||||
pubKey, err := ParsePublicKey(pubKeyData) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
candidate, ok := cache.get(s.user, pubKeyData) |
||||
if !ok { |
||||
candidate.user = s.user |
||||
candidate.pubKeyData = pubKeyData |
||||
candidate.perms, candidate.result = config.PublicKeyCallback(s, pubKey) |
||||
if candidate.result == nil && candidate.perms != nil && candidate.perms.CriticalOptions != nil && candidate.perms.CriticalOptions[sourceAddressCriticalOption] != "" { |
||||
candidate.result = checkSourceAddress( |
||||
s.RemoteAddr(), |
||||
candidate.perms.CriticalOptions[sourceAddressCriticalOption]) |
||||
} |
||||
cache.add(candidate) |
||||
} |
||||
|
||||
if isQuery { |
||||
// The client can query if the given public key
|
||||
// would be okay.
|
||||
if len(payload) > 0 { |
||||
return nil, parseError(msgUserAuthRequest) |
||||
} |
||||
|
||||
if candidate.result == nil { |
||||
okMsg := userAuthPubKeyOkMsg{ |
||||
Algo: algo, |
||||
PubKey: pubKeyData, |
||||
} |
||||
if err = s.transport.writePacket(Marshal(&okMsg)); err != nil { |
||||
return nil, err |
||||
} |
||||
continue userAuthLoop |
||||
} |
||||
authErr = candidate.result |
||||
} else { |
||||
sig, payload, ok := parseSignature(payload) |
||||
if !ok || len(payload) > 0 { |
||||
return nil, parseError(msgUserAuthRequest) |
||||
} |
||||
// Ensure the public key algo and signature algo
|
||||
// are supported. Compare the private key
|
||||
// algorithm name that corresponds to algo with
|
||||
// sig.Format. This is usually the same, but
|
||||
// for certs, the names differ.
|
||||
if !isAcceptableAlgo(sig.Format) { |
||||
break |
||||
} |
||||
signedData := buildDataSignedForAuth(sessionID, userAuthReq, algoBytes, pubKeyData) |
||||
|
||||
if err := pubKey.Verify(signedData, sig); err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
authErr = candidate.result |
||||
perms = candidate.perms |
||||
} |
||||
default: |
||||
authErr = fmt.Errorf("ssh: unknown method %q", userAuthReq.Method) |
||||
} |
||||
|
||||
if config.AuthLogCallback != nil { |
||||
config.AuthLogCallback(s, userAuthReq.Method, authErr) |
||||
} |
||||
|
||||
if authErr == nil { |
||||
break userAuthLoop |
||||
} |
||||
|
||||
var failureMsg userAuthFailureMsg |
||||
if config.PasswordCallback != nil { |
||||
failureMsg.Methods = append(failureMsg.Methods, "password") |
||||
} |
||||
if config.PublicKeyCallback != nil { |
||||
failureMsg.Methods = append(failureMsg.Methods, "publickey") |
||||
} |
||||
if config.KeyboardInteractiveCallback != nil { |
||||
failureMsg.Methods = append(failureMsg.Methods, "keyboard-interactive") |
||||
} |
||||
|
||||
if len(failureMsg.Methods) == 0 { |
||||
return nil, errors.New("ssh: no authentication methods configured but NoClientAuth is also false") |
||||
} |
||||
|
||||
if err := s.transport.writePacket(Marshal(&failureMsg)); err != nil { |
||||
return nil, err |
||||
} |
||||
} |
||||
|
||||
if err := s.transport.writePacket([]byte{msgUserAuthSuccess}); err != nil { |
||||
return nil, err |
||||
} |
||||
return perms, nil |
||||
} |
||||
|
||||
// sshClientKeyboardInteractive implements a ClientKeyboardInteractive by
|
||||
// asking the client on the other side of a ServerConn.
|
||||
type sshClientKeyboardInteractive struct { |
||||
*connection |
||||
} |
||||
|
||||
func (c *sshClientKeyboardInteractive) Challenge(user, instruction string, questions []string, echos []bool) (answers []string, err error) { |
||||
if len(questions) != len(echos) { |
||||
return nil, errors.New("ssh: echos and questions must have equal length") |
||||
} |
||||
|
||||
var prompts []byte |
||||
for i := range questions { |
||||
prompts = appendString(prompts, questions[i]) |
||||
prompts = appendBool(prompts, echos[i]) |
||||
} |
||||
|
||||
if err := c.transport.writePacket(Marshal(&userAuthInfoRequestMsg{ |
||||
Instruction: instruction, |
||||
NumPrompts: uint32(len(questions)), |
||||
Prompts: prompts, |
||||
})); err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
packet, err := c.transport.readPacket() |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
if packet[0] != msgUserAuthInfoResponse { |
||||
return nil, unexpectedMessageError(msgUserAuthInfoResponse, packet[0]) |
||||
} |
||||
packet = packet[1:] |
||||
|
||||
n, packet, ok := parseUint32(packet) |
||||
if !ok || int(n) != len(questions) { |
||||
return nil, parseError(msgUserAuthInfoResponse) |
||||
} |
||||
|
||||
for i := uint32(0); i < n; i++ { |
||||
ans, rest, ok := parseString(packet) |
||||
if !ok { |
||||
return nil, parseError(msgUserAuthInfoResponse) |
||||
} |
||||
|
||||
answers = append(answers, string(ans)) |
||||
packet = rest |
||||
} |
||||
if len(packet) != 0 { |
||||
return nil, errors.New("ssh: junk at end of message") |
||||
} |
||||
|
||||
return answers, nil |
||||
} |
@ -0,0 +1,627 @@ |
||||
// Copyright 2011 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package ssh |
||||
|
||||
// Session implements an interactive session described in
|
||||
// "RFC 4254, section 6".
|
||||
|
||||
import ( |
||||
"bytes" |
||||
"encoding/binary" |
||||
"errors" |
||||
"fmt" |
||||
"io" |
||||
"io/ioutil" |
||||
"sync" |
||||
) |
||||
|
||||
type Signal string |
||||
|
||||
// POSIX signals as listed in RFC 4254 Section 6.10.
|
||||
const ( |
||||
SIGABRT Signal = "ABRT" |
||||
SIGALRM Signal = "ALRM" |
||||
SIGFPE Signal = "FPE" |
||||
SIGHUP Signal = "HUP" |
||||
SIGILL Signal = "ILL" |
||||
SIGINT Signal = "INT" |
||||
SIGKILL Signal = "KILL" |
||||
SIGPIPE Signal = "PIPE" |
||||
SIGQUIT Signal = "QUIT" |
||||
SIGSEGV Signal = "SEGV" |
||||
SIGTERM Signal = "TERM" |
||||
SIGUSR1 Signal = "USR1" |
||||
SIGUSR2 Signal = "USR2" |
||||
) |
||||
|
||||
var signals = map[Signal]int{ |
||||
SIGABRT: 6, |
||||
SIGALRM: 14, |
||||
SIGFPE: 8, |
||||
SIGHUP: 1, |
||||
SIGILL: 4, |
||||
SIGINT: 2, |
||||
SIGKILL: 9, |
||||
SIGPIPE: 13, |
||||
SIGQUIT: 3, |
||||
SIGSEGV: 11, |
||||
SIGTERM: 15, |
||||
} |
||||
|
||||
type TerminalModes map[uint8]uint32 |
||||
|
||||
// POSIX terminal mode flags as listed in RFC 4254 Section 8.
|
||||
const ( |
||||
tty_OP_END = 0 |
||||
VINTR = 1 |
||||
VQUIT = 2 |
||||
VERASE = 3 |
||||
VKILL = 4 |
||||
VEOF = 5 |
||||
VEOL = 6 |
||||
VEOL2 = 7 |
||||
VSTART = 8 |
||||
VSTOP = 9 |
||||
VSUSP = 10 |
||||
VDSUSP = 11 |
||||
VREPRINT = 12 |
||||
VWERASE = 13 |
||||
VLNEXT = 14 |
||||
VFLUSH = 15 |
||||
VSWTCH = 16 |
||||
VSTATUS = 17 |
||||
VDISCARD = 18 |
||||
IGNPAR = 30 |
||||
PARMRK = 31 |
||||
INPCK = 32 |
||||
ISTRIP = 33 |
||||
INLCR = 34 |
||||
IGNCR = 35 |
||||
ICRNL = 36 |
||||
IUCLC = 37 |
||||
IXON = 38 |
||||
IXANY = 39 |
||||
IXOFF = 40 |
||||
IMAXBEL = 41 |
||||
ISIG = 50 |
||||
ICANON = 51 |
||||
XCASE = 52 |
||||
ECHO = 53 |
||||
ECHOE = 54 |
||||
ECHOK = 55 |
||||
ECHONL = 56 |
||||
NOFLSH = 57 |
||||
TOSTOP = 58 |
||||
IEXTEN = 59 |
||||
ECHOCTL = 60 |
||||
ECHOKE = 61 |
||||
PENDIN = 62 |
||||
OPOST = 70 |
||||
OLCUC = 71 |
||||
ONLCR = 72 |
||||
OCRNL = 73 |
||||
ONOCR = 74 |
||||
ONLRET = 75 |
||||
CS7 = 90 |
||||
CS8 = 91 |
||||
PARENB = 92 |
||||
PARODD = 93 |
||||
TTY_OP_ISPEED = 128 |
||||
TTY_OP_OSPEED = 129 |
||||
) |
||||
|
||||
// A Session represents a connection to a remote command or shell.
|
||||
type Session struct { |
||||
// Stdin specifies the remote process's standard input.
|
||||
// If Stdin is nil, the remote process reads from an empty
|
||||
// bytes.Buffer.
|
||||
Stdin io.Reader |
||||
|
||||
// Stdout and Stderr specify the remote process's standard
|
||||
// output and error.
|
||||
//
|
||||
// If either is nil, Run connects the corresponding file
|
||||
// descriptor to an instance of ioutil.Discard. There is a
|
||||
// fixed amount of buffering that is shared for the two streams.
|
||||
// If either blocks it may eventually cause the remote
|
||||
// command to block.
|
||||
Stdout io.Writer |
||||
Stderr io.Writer |
||||
|
||||
ch Channel // the channel backing this session
|
||||
started bool // true once Start, Run or Shell is invoked.
|
||||
copyFuncs []func() error |
||||
errors chan error // one send per copyFunc
|
||||
|
||||
// true if pipe method is active
|
||||
stdinpipe, stdoutpipe, stderrpipe bool |
||||
|
||||
// stdinPipeWriter is non-nil if StdinPipe has not been called
|
||||
// and Stdin was specified by the user; it is the write end of
|
||||
// a pipe connecting Session.Stdin to the stdin channel.
|
||||
stdinPipeWriter io.WriteCloser |
||||
|
||||
exitStatus chan error |
||||
} |
||||
|
||||
// SendRequest sends an out-of-band channel request on the SSH channel
|
||||
// underlying the session.
|
||||
func (s *Session) SendRequest(name string, wantReply bool, payload []byte) (bool, error) { |
||||
return s.ch.SendRequest(name, wantReply, payload) |
||||
} |
||||
|
||||
func (s *Session) Close() error { |
||||
return s.ch.Close() |
||||
} |
||||
|
||||
// RFC 4254 Section 6.4.
|
||||
type setenvRequest struct { |
||||
Name string |
||||
Value string |
||||
} |
||||
|
||||
// Setenv sets an environment variable that will be applied to any
|
||||
// command executed by Shell or Run.
|
||||
func (s *Session) Setenv(name, value string) error { |
||||
msg := setenvRequest{ |
||||
Name: name, |
||||
Value: value, |
||||
} |
||||
ok, err := s.ch.SendRequest("env", true, Marshal(&msg)) |
||||
if err == nil && !ok { |
||||
err = errors.New("ssh: setenv failed") |
||||
} |
||||
return err |
||||
} |
||||
|
||||
// RFC 4254 Section 6.2.
|
||||
type ptyRequestMsg struct { |
||||
Term string |
||||
Columns uint32 |
||||
Rows uint32 |
||||
Width uint32 |
||||
Height uint32 |
||||
Modelist string |
||||
} |
||||
|
||||
// RequestPty requests the association of a pty with the session on the remote host.
|
||||
func (s *Session) RequestPty(term string, h, w int, termmodes TerminalModes) error { |
||||
var tm []byte |
||||
for k, v := range termmodes { |
||||
kv := struct { |
||||
Key byte |
||||
Val uint32 |
||||
}{k, v} |
||||
|
||||
tm = append(tm, Marshal(&kv)...) |
||||
} |
||||
tm = append(tm, tty_OP_END) |
||||
req := ptyRequestMsg{ |
||||
Term: term, |
||||
Columns: uint32(w), |
||||
Rows: uint32(h), |
||||
Width: uint32(w * 8), |
||||
Height: uint32(h * 8), |
||||
Modelist: string(tm), |
||||
} |
||||
ok, err := s.ch.SendRequest("pty-req", true, Marshal(&req)) |
||||
if err == nil && !ok { |
||||
err = errors.New("ssh: pty-req failed") |
||||
} |
||||
return err |
||||
} |
||||
|
||||
// RFC 4254 Section 6.5.
|
||||
type subsystemRequestMsg struct { |
||||
Subsystem string |
||||
} |
||||
|
||||
// RequestSubsystem requests the association of a subsystem with the session on the remote host.
|
||||
// A subsystem is a predefined command that runs in the background when the ssh session is initiated
|
||||
func (s *Session) RequestSubsystem(subsystem string) error { |
||||
msg := subsystemRequestMsg{ |
||||
Subsystem: subsystem, |
||||
} |
||||
ok, err := s.ch.SendRequest("subsystem", true, Marshal(&msg)) |
||||
if err == nil && !ok { |
||||
err = errors.New("ssh: subsystem request failed") |
||||
} |
||||
return err |
||||
} |
||||
|
||||
// RFC 4254 Section 6.9.
|
||||
type signalMsg struct { |
||||
Signal string |
||||
} |
||||
|
||||
// Signal sends the given signal to the remote process.
|
||||
// sig is one of the SIG* constants.
|
||||
func (s *Session) Signal(sig Signal) error { |
||||
msg := signalMsg{ |
||||
Signal: string(sig), |
||||
} |
||||
|
||||
_, err := s.ch.SendRequest("signal", false, Marshal(&msg)) |
||||
return err |
||||
} |
||||
|
||||
// RFC 4254 Section 6.5.
|
||||
type execMsg struct { |
||||
Command string |
||||
} |
||||
|
||||
// Start runs cmd on the remote host. Typically, the remote
|
||||
// server passes cmd to the shell for interpretation.
|
||||
// A Session only accepts one call to Run, Start or Shell.
|
||||
func (s *Session) Start(cmd string) error { |
||||
if s.started { |
||||
return errors.New("ssh: session already started") |
||||
} |
||||
req := execMsg{ |
||||
Command: cmd, |
||||
} |
||||
|
||||
ok, err := s.ch.SendRequest("exec", true, Marshal(&req)) |
||||
if err == nil && !ok { |
||||
err = fmt.Errorf("ssh: command %v failed", cmd) |
||||
} |
||||
if err != nil { |
||||
return err |
||||
} |
||||
return s.start() |
||||
} |
||||
|
||||
// Run runs cmd on the remote host. Typically, the remote
|
||||
// server passes cmd to the shell for interpretation.
|
||||
// A Session only accepts one call to Run, Start, Shell, Output,
|
||||
// or CombinedOutput.
|
||||
//
|
||||
// The returned error is nil if the command runs, has no problems
|
||||
// copying stdin, stdout, and stderr, and exits with a zero exit
|
||||
// status.
|
||||
//
|
||||
// If the remote server does not send an exit status, an error of type
|
||||
// *ExitMissingError is returned. If the command completes
|
||||
// unsuccessfully or is interrupted by a signal, the error is of type
|
||||
// *ExitError. Other error types may be returned for I/O problems.
|
||||
func (s *Session) Run(cmd string) error { |
||||
err := s.Start(cmd) |
||||
if err != nil { |
||||
return err |
||||
} |
||||
return s.Wait() |
||||
} |
||||
|
||||
// Output runs cmd on the remote host and returns its standard output.
|
||||
func (s *Session) Output(cmd string) ([]byte, error) { |
||||
if s.Stdout != nil { |
||||
return nil, errors.New("ssh: Stdout already set") |
||||
} |
||||
var b bytes.Buffer |
||||
s.Stdout = &b |
||||
err := s.Run(cmd) |
||||
return b.Bytes(), err |
||||
} |
||||
|
||||
type singleWriter struct { |
||||
b bytes.Buffer |
||||
mu sync.Mutex |
||||
} |
||||
|
||||
func (w *singleWriter) Write(p []byte) (int, error) { |
||||
w.mu.Lock() |
||||
defer w.mu.Unlock() |
||||
return w.b.Write(p) |
||||
} |
||||
|
||||
// CombinedOutput runs cmd on the remote host and returns its combined
|
||||
// standard output and standard error.
|
||||
func (s *Session) CombinedOutput(cmd string) ([]byte, error) { |
||||
if s.Stdout != nil { |
||||
return nil, errors.New("ssh: Stdout already set") |
||||
} |
||||
if s.Stderr != nil { |
||||
return nil, errors.New("ssh: Stderr already set") |
||||
} |
||||
var b singleWriter |
||||
s.Stdout = &b |
||||
s.Stderr = &b |
||||
err := s.Run(cmd) |
||||
return b.b.Bytes(), err |
||||
} |
||||
|
||||
// Shell starts a login shell on the remote host. A Session only
|
||||
// accepts one call to Run, Start, Shell, Output, or CombinedOutput.
|
||||
func (s *Session) Shell() error { |
||||
if s.started { |
||||
return errors.New("ssh: session already started") |
||||
} |
||||
|
||||
ok, err := s.ch.SendRequest("shell", true, nil) |
||||
if err == nil && !ok { |
||||
return errors.New("ssh: could not start shell") |
||||
} |
||||
if err != nil { |
||||
return err |
||||
} |
||||
return s.start() |
||||
} |
||||
|
||||
func (s *Session) start() error { |
||||
s.started = true |
||||
|
||||
type F func(*Session) |
||||
for _, setupFd := range []F{(*Session).stdin, (*Session).stdout, (*Session).stderr} { |
||||
setupFd(s) |
||||
} |
||||
|
||||
s.errors = make(chan error, len(s.copyFuncs)) |
||||
for _, fn := range s.copyFuncs { |
||||
go func(fn func() error) { |
||||
s.errors <- fn() |
||||
}(fn) |
||||
} |
||||
return nil |
||||
} |
||||
|
||||
// Wait waits for the remote command to exit.
|
||||
//
|
||||
// The returned error is nil if the command runs, has no problems
|
||||
// copying stdin, stdout, and stderr, and exits with a zero exit
|
||||
// status.
|
||||
//
|
||||
// If the remote server does not send an exit status, an error of type
|
||||
// *ExitMissingError is returned. If the command completes
|
||||
// unsuccessfully or is interrupted by a signal, the error is of type
|
||||
// *ExitError. Other error types may be returned for I/O problems.
|
||||
func (s *Session) Wait() error { |
||||
if !s.started { |
||||
return errors.New("ssh: session not started") |
||||
} |
||||
waitErr := <-s.exitStatus |
||||
|
||||
if s.stdinPipeWriter != nil { |
||||
s.stdinPipeWriter.Close() |
||||
} |
||||
var copyError error |
||||
for _ = range s.copyFuncs { |
||||
if err := <-s.errors; err != nil && copyError == nil { |
||||
copyError = err |
||||
} |
||||
} |
||||
if waitErr != nil { |
||||
return waitErr |
||||
} |
||||
return copyError |
||||
} |
||||
|
||||
func (s *Session) wait(reqs <-chan *Request) error { |
||||
wm := Waitmsg{status: -1} |
||||
// Wait for msg channel to be closed before returning.
|
||||
for msg := range reqs { |
||||
switch msg.Type { |
||||
case "exit-status": |
||||
wm.status = int(binary.BigEndian.Uint32(msg.Payload)) |
||||
case "exit-signal": |
||||
var sigval struct { |
||||
Signal string |
||||
CoreDumped bool |
||||
Error string |
||||
Lang string |
||||
} |
||||
if err := Unmarshal(msg.Payload, &sigval); err != nil { |
||||
return err |
||||
} |
||||
|
||||
// Must sanitize strings?
|
||||
wm.signal = sigval.Signal |
||||
wm.msg = sigval.Error |
||||
wm.lang = sigval.Lang |
||||
default: |
||||
// This handles keepalives and matches
|
||||
// OpenSSH's behaviour.
|
||||
if msg.WantReply { |
||||
msg.Reply(false, nil) |
||||
} |
||||
} |
||||
} |
||||
if wm.status == 0 { |
||||
return nil |
||||
} |
||||
if wm.status == -1 { |
||||
// exit-status was never sent from server
|
||||
if wm.signal == "" { |
||||
// signal was not sent either. RFC 4254
|
||||
// section 6.10 recommends against this
|
||||
// behavior, but it is allowed, so we let
|
||||
// clients handle it.
|
||||
return &ExitMissingError{} |
||||
} |
||||
wm.status = 128 |
||||
if _, ok := signals[Signal(wm.signal)]; ok { |
||||
wm.status += signals[Signal(wm.signal)] |
||||
} |
||||
} |
||||
|
||||
return &ExitError{wm} |
||||
} |
||||
|
||||
// ExitMissingError is returned if a session is torn down cleanly, but
|
||||
// the server sends no confirmation of the exit status.
|
||||
type ExitMissingError struct{} |
||||
|
||||
func (e *ExitMissingError) Error() string { |
||||
return "wait: remote command exited without exit status or exit signal" |
||||
} |
||||
|
||||
func (s *Session) stdin() { |
||||
if s.stdinpipe { |
||||
return |
||||
} |
||||
var stdin io.Reader |
||||
if s.Stdin == nil { |
||||
stdin = new(bytes.Buffer) |
||||
} else { |
||||
r, w := io.Pipe() |
||||
go func() { |
||||
_, err := io.Copy(w, s.Stdin) |
||||
w.CloseWithError(err) |
||||
}() |
||||
stdin, s.stdinPipeWriter = r, w |
||||
} |
||||
s.copyFuncs = append(s.copyFuncs, func() error { |
||||
_, err := io.Copy(s.ch, stdin) |
||||
if err1 := s.ch.CloseWrite(); err == nil && err1 != io.EOF { |
||||
err = err1 |
||||
} |
||||
return err |
||||
}) |
||||
} |
||||
|
||||
func (s *Session) stdout() { |
||||
if s.stdoutpipe { |
||||
return |
||||
} |
||||
if s.Stdout == nil { |
||||
s.Stdout = ioutil.Discard |
||||
} |
||||
s.copyFuncs = append(s.copyFuncs, func() error { |
||||
_, err := io.Copy(s.Stdout, s.ch) |
||||
return err |
||||
}) |
||||
} |
||||
|
||||
func (s *Session) stderr() { |
||||
if s.stderrpipe { |
||||
return |
||||
} |
||||
if s.Stderr == nil { |
||||
s.Stderr = ioutil.Discard |
||||
} |
||||
s.copyFuncs = append(s.copyFuncs, func() error { |
||||
_, err := io.Copy(s.Stderr, s.ch.Stderr()) |
||||
return err |
||||
}) |
||||
} |
||||
|
||||
// sessionStdin reroutes Close to CloseWrite.
|
||||
type sessionStdin struct { |
||||
io.Writer |
||||
ch Channel |
||||
} |
||||
|
||||
func (s *sessionStdin) Close() error { |
||||
return s.ch.CloseWrite() |
||||
} |
||||
|
||||
// StdinPipe returns a pipe that will be connected to the
|
||||
// remote command's standard input when the command starts.
|
||||
func (s *Session) StdinPipe() (io.WriteCloser, error) { |
||||
if s.Stdin != nil { |
||||
return nil, errors.New("ssh: Stdin already set") |
||||
} |
||||
if s.started { |
||||
return nil, errors.New("ssh: StdinPipe after process started") |
||||
} |
||||
s.stdinpipe = true |
||||
return &sessionStdin{s.ch, s.ch}, nil |
||||
} |
||||
|
||||
// StdoutPipe returns a pipe that will be connected to the
|
||||
// remote command's standard output when the command starts.
|
||||
// There is a fixed amount of buffering that is shared between
|
||||
// stdout and stderr streams. If the StdoutPipe reader is
|
||||
// not serviced fast enough it may eventually cause the
|
||||
// remote command to block.
|
||||
func (s *Session) StdoutPipe() (io.Reader, error) { |
||||
if s.Stdout != nil { |
||||
return nil, errors.New("ssh: Stdout already set") |
||||
} |
||||
if s.started { |
||||
return nil, errors.New("ssh: StdoutPipe after process started") |
||||
} |
||||
s.stdoutpipe = true |
||||
return s.ch, nil |
||||
} |
||||
|
||||
// StderrPipe returns a pipe that will be connected to the
|
||||
// remote command's standard error when the command starts.
|
||||
// There is a fixed amount of buffering that is shared between
|
||||
// stdout and stderr streams. If the StderrPipe reader is
|
||||
// not serviced fast enough it may eventually cause the
|
||||
// remote command to block.
|
||||
func (s *Session) StderrPipe() (io.Reader, error) { |
||||
if s.Stderr != nil { |
||||
return nil, errors.New("ssh: Stderr already set") |
||||
} |
||||
if s.started { |
||||
return nil, errors.New("ssh: StderrPipe after process started") |
||||
} |
||||
s.stderrpipe = true |
||||
return s.ch.Stderr(), nil |
||||
} |
||||
|
||||
// newSession returns a new interactive session on the remote host.
|
||||
func newSession(ch Channel, reqs <-chan *Request) (*Session, error) { |
||||
s := &Session{ |
||||
ch: ch, |
||||
} |
||||
s.exitStatus = make(chan error, 1) |
||||
go func() { |
||||
s.exitStatus <- s.wait(reqs) |
||||
}() |
||||
|
||||
return s, nil |
||||
} |
||||
|
||||
// An ExitError reports unsuccessful completion of a remote command.
|
||||
type ExitError struct { |
||||
Waitmsg |
||||
} |
||||
|
||||
func (e *ExitError) Error() string { |
||||
return e.Waitmsg.String() |
||||
} |
||||
|
||||
// Waitmsg stores the information about an exited remote command
|
||||
// as reported by Wait.
|
||||
type Waitmsg struct { |
||||
status int |
||||
signal string |
||||
msg string |
||||
lang string |
||||
} |
||||
|
||||
// ExitStatus returns the exit status of the remote command.
|
||||
func (w Waitmsg) ExitStatus() int { |
||||
return w.status |
||||
} |
||||
|
||||
// Signal returns the exit signal of the remote command if
|
||||
// it was terminated violently.
|
||||
func (w Waitmsg) Signal() string { |
||||
return w.signal |
||||
} |
||||
|
||||
// Msg returns the exit message given by the remote command
|
||||
func (w Waitmsg) Msg() string { |
||||
return w.msg |
||||
} |
||||
|
||||
// Lang returns the language tag. See RFC 3066
|
||||
func (w Waitmsg) Lang() string { |
||||
return w.lang |
||||
} |
||||
|
||||
func (w Waitmsg) String() string { |
||||
str := fmt.Sprintf("Process exited with status %v", w.status) |
||||
if w.signal != "" { |
||||
str += fmt.Sprintf(" from signal %v", w.signal) |
||||
} |
||||
if w.msg != "" { |
||||
str += fmt.Sprintf(". Reason was: %v", w.msg) |
||||
} |
||||
return str |
||||
} |
@ -0,0 +1,407 @@ |
||||
// Copyright 2011 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package ssh |
||||
|
||||
import ( |
||||
"errors" |
||||
"fmt" |
||||
"io" |
||||
"math/rand" |
||||
"net" |
||||
"strconv" |
||||
"strings" |
||||
"sync" |
||||
"time" |
||||
) |
||||
|
||||
// Listen requests the remote peer open a listening socket on
|
||||
// addr. Incoming connections will be available by calling Accept on
|
||||
// the returned net.Listener. The listener must be serviced, or the
|
||||
// SSH connection may hang.
|
||||
func (c *Client) Listen(n, addr string) (net.Listener, error) { |
||||
laddr, err := net.ResolveTCPAddr(n, addr) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
return c.ListenTCP(laddr) |
||||
} |
||||
|
||||
// Automatic port allocation is broken with OpenSSH before 6.0. See
|
||||
// also https://bugzilla.mindrot.org/show_bug.cgi?id=2017. In
|
||||
// particular, OpenSSH 5.9 sends a channelOpenMsg with port number 0,
|
||||
// rather than the actual port number. This means you can never open
|
||||
// two different listeners with auto allocated ports. We work around
|
||||
// this by trying explicit ports until we succeed.
|
||||
|
||||
const openSSHPrefix = "OpenSSH_" |
||||
|
||||
var portRandomizer = rand.New(rand.NewSource(time.Now().UnixNano())) |
||||
|
||||
// isBrokenOpenSSHVersion returns true if the given version string
|
||||
// specifies a version of OpenSSH that is known to have a bug in port
|
||||
// forwarding.
|
||||
func isBrokenOpenSSHVersion(versionStr string) bool { |
||||
i := strings.Index(versionStr, openSSHPrefix) |
||||
if i < 0 { |
||||
return false |
||||
} |
||||
i += len(openSSHPrefix) |
||||
j := i |
||||
for ; j < len(versionStr); j++ { |
||||
if versionStr[j] < '0' || versionStr[j] > '9' { |
||||
break |
||||
} |
||||
} |
||||
version, _ := strconv.Atoi(versionStr[i:j]) |
||||
return version < 6 |
||||
} |
||||
|
||||
// autoPortListenWorkaround simulates automatic port allocation by
|
||||
// trying random ports repeatedly.
|
||||
func (c *Client) autoPortListenWorkaround(laddr *net.TCPAddr) (net.Listener, error) { |
||||
var sshListener net.Listener |
||||
var err error |
||||
const tries = 10 |
||||
for i := 0; i < tries; i++ { |
||||
addr := *laddr |
||||
addr.Port = 1024 + portRandomizer.Intn(60000) |
||||
sshListener, err = c.ListenTCP(&addr) |
||||
if err == nil { |
||||
laddr.Port = addr.Port |
||||
return sshListener, err |
||||
} |
||||
} |
||||
return nil, fmt.Errorf("ssh: listen on random port failed after %d tries: %v", tries, err) |
||||
} |
||||
|
||||
// RFC 4254 7.1
|
||||
type channelForwardMsg struct { |
||||
addr string |
||||
rport uint32 |
||||
} |
||||
|
||||
// ListenTCP requests the remote peer open a listening socket
|
||||
// on laddr. Incoming connections will be available by calling
|
||||
// Accept on the returned net.Listener.
|
||||
func (c *Client) ListenTCP(laddr *net.TCPAddr) (net.Listener, error) { |
||||
if laddr.Port == 0 && isBrokenOpenSSHVersion(string(c.ServerVersion())) { |
||||
return c.autoPortListenWorkaround(laddr) |
||||
} |
||||
|
||||
m := channelForwardMsg{ |
||||
laddr.IP.String(), |
||||
uint32(laddr.Port), |
||||
} |
||||
// send message
|
||||
ok, resp, err := c.SendRequest("tcpip-forward", true, Marshal(&m)) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
if !ok { |
||||
return nil, errors.New("ssh: tcpip-forward request denied by peer") |
||||
} |
||||
|
||||
// If the original port was 0, then the remote side will
|
||||
// supply a real port number in the response.
|
||||
if laddr.Port == 0 { |
||||
var p struct { |
||||
Port uint32 |
||||
} |
||||
if err := Unmarshal(resp, &p); err != nil { |
||||
return nil, err |
||||
} |
||||
laddr.Port = int(p.Port) |
||||
} |
||||
|
||||
// Register this forward, using the port number we obtained.
|
||||
ch := c.forwards.add(*laddr) |
||||
|
||||
return &tcpListener{laddr, c, ch}, nil |
||||
} |
||||
|
||||
// forwardList stores a mapping between remote
|
||||
// forward requests and the tcpListeners.
|
||||
type forwardList struct { |
||||
sync.Mutex |
||||
entries []forwardEntry |
||||
} |
||||
|
||||
// forwardEntry represents an established mapping of a laddr on a
|
||||
// remote ssh server to a channel connected to a tcpListener.
|
||||
type forwardEntry struct { |
||||
laddr net.TCPAddr |
||||
c chan forward |
||||
} |
||||
|
||||
// forward represents an incoming forwarded tcpip connection. The
|
||||
// arguments to add/remove/lookup should be address as specified in
|
||||
// the original forward-request.
|
||||
type forward struct { |
||||
newCh NewChannel // the ssh client channel underlying this forward
|
||||
raddr *net.TCPAddr // the raddr of the incoming connection
|
||||
} |
||||
|
||||
func (l *forwardList) add(addr net.TCPAddr) chan forward { |
||||
l.Lock() |
||||
defer l.Unlock() |
||||
f := forwardEntry{ |
||||
addr, |
||||
make(chan forward, 1), |
||||
} |
||||
l.entries = append(l.entries, f) |
||||
return f.c |
||||
} |
||||
|
||||
// See RFC 4254, section 7.2
|
||||
type forwardedTCPPayload struct { |
||||
Addr string |
||||
Port uint32 |
||||
OriginAddr string |
||||
OriginPort uint32 |
||||
} |
||||
|
||||
// parseTCPAddr parses the originating address from the remote into a *net.TCPAddr.
|
||||
func parseTCPAddr(addr string, port uint32) (*net.TCPAddr, error) { |
||||
if port == 0 || port > 65535 { |
||||
return nil, fmt.Errorf("ssh: port number out of range: %d", port) |
||||
} |
||||
ip := net.ParseIP(string(addr)) |
||||
if ip == nil { |
||||
return nil, fmt.Errorf("ssh: cannot parse IP address %q", addr) |
||||
} |
||||
return &net.TCPAddr{IP: ip, Port: int(port)}, nil |
||||
} |
||||
|
||||
func (l *forwardList) handleChannels(in <-chan NewChannel) { |
||||
for ch := range in { |
||||
var payload forwardedTCPPayload |
||||
if err := Unmarshal(ch.ExtraData(), &payload); err != nil { |
||||
ch.Reject(ConnectionFailed, "could not parse forwarded-tcpip payload: "+err.Error()) |
||||
continue |
||||
} |
||||
|
||||
// RFC 4254 section 7.2 specifies that incoming
|
||||
// addresses should list the address, in string
|
||||
// format. It is implied that this should be an IP
|
||||
// address, as it would be impossible to connect to it
|
||||
// otherwise.
|
||||
laddr, err := parseTCPAddr(payload.Addr, payload.Port) |
||||
if err != nil { |
||||
ch.Reject(ConnectionFailed, err.Error()) |
||||
continue |
||||
} |
||||
raddr, err := parseTCPAddr(payload.OriginAddr, payload.OriginPort) |
||||
if err != nil { |
||||
ch.Reject(ConnectionFailed, err.Error()) |
||||
continue |
||||
} |
||||
|
||||
if ok := l.forward(*laddr, *raddr, ch); !ok { |
||||
// Section 7.2, implementations MUST reject spurious incoming
|
||||
// connections.
|
||||
ch.Reject(Prohibited, "no forward for address") |
||||
continue |
||||
} |
||||
} |
||||
} |
||||
|
||||
// remove removes the forward entry, and the channel feeding its
|
||||
// listener.
|
||||
func (l *forwardList) remove(addr net.TCPAddr) { |
||||
l.Lock() |
||||
defer l.Unlock() |
||||
for i, f := range l.entries { |
||||
if addr.IP.Equal(f.laddr.IP) && addr.Port == f.laddr.Port { |
||||
l.entries = append(l.entries[:i], l.entries[i+1:]...) |
||||
close(f.c) |
||||
return |
||||
} |
||||
} |
||||
} |
||||
|
||||
// closeAll closes and clears all forwards.
|
||||
func (l *forwardList) closeAll() { |
||||
l.Lock() |
||||
defer l.Unlock() |
||||
for _, f := range l.entries { |
||||
close(f.c) |
||||
} |
||||
l.entries = nil |
||||
} |
||||
|
||||
func (l *forwardList) forward(laddr, raddr net.TCPAddr, ch NewChannel) bool { |
||||
l.Lock() |
||||
defer l.Unlock() |
||||
for _, f := range l.entries { |
||||
if laddr.IP.Equal(f.laddr.IP) && laddr.Port == f.laddr.Port { |
||||
f.c <- forward{ch, &raddr} |
||||
return true |
||||
} |
||||
} |
||||
return false |
||||
} |
||||
|
||||
type tcpListener struct { |
||||
laddr *net.TCPAddr |
||||
|
||||
conn *Client |
||||
in <-chan forward |
||||
} |
||||
|
||||
// Accept waits for and returns the next connection to the listener.
|
||||
func (l *tcpListener) Accept() (net.Conn, error) { |
||||
s, ok := <-l.in |
||||
if !ok { |
||||
return nil, io.EOF |
||||
} |
||||
ch, incoming, err := s.newCh.Accept() |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
go DiscardRequests(incoming) |
||||
|
||||
return &tcpChanConn{ |
||||
Channel: ch, |
||||
laddr: l.laddr, |
||||
raddr: s.raddr, |
||||
}, nil |
||||
} |
||||
|
||||
// Close closes the listener.
|
||||
func (l *tcpListener) Close() error { |
||||
m := channelForwardMsg{ |
||||
l.laddr.IP.String(), |
||||
uint32(l.laddr.Port), |
||||
} |
||||
|
||||
// this also closes the listener.
|
||||
l.conn.forwards.remove(*l.laddr) |
||||
ok, _, err := l.conn.SendRequest("cancel-tcpip-forward", true, Marshal(&m)) |
||||
if err == nil && !ok { |
||||
err = errors.New("ssh: cancel-tcpip-forward failed") |
||||
} |
||||
return err |
||||
} |
||||
|
||||
// Addr returns the listener's network address.
|
||||
func (l *tcpListener) Addr() net.Addr { |
||||
return l.laddr |
||||
} |
||||
|
||||
// Dial initiates a connection to the addr from the remote host.
|
||||
// The resulting connection has a zero LocalAddr() and RemoteAddr().
|
||||
func (c *Client) Dial(n, addr string) (net.Conn, error) { |
||||
// Parse the address into host and numeric port.
|
||||
host, portString, err := net.SplitHostPort(addr) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
port, err := strconv.ParseUint(portString, 10, 16) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
// Use a zero address for local and remote address.
|
||||
zeroAddr := &net.TCPAddr{ |
||||
IP: net.IPv4zero, |
||||
Port: 0, |
||||
} |
||||
ch, err := c.dial(net.IPv4zero.String(), 0, host, int(port)) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
return &tcpChanConn{ |
||||
Channel: ch, |
||||
laddr: zeroAddr, |
||||
raddr: zeroAddr, |
||||
}, nil |
||||
} |
||||
|
||||
// DialTCP connects to the remote address raddr on the network net,
|
||||
// which must be "tcp", "tcp4", or "tcp6". If laddr is not nil, it is used
|
||||
// as the local address for the connection.
|
||||
func (c *Client) DialTCP(n string, laddr, raddr *net.TCPAddr) (net.Conn, error) { |
||||
if laddr == nil { |
||||
laddr = &net.TCPAddr{ |
||||
IP: net.IPv4zero, |
||||
Port: 0, |
||||
} |
||||
} |
||||
ch, err := c.dial(laddr.IP.String(), laddr.Port, raddr.IP.String(), raddr.Port) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
return &tcpChanConn{ |
||||
Channel: ch, |
||||
laddr: laddr, |
||||
raddr: raddr, |
||||
}, nil |
||||
} |
||||
|
||||
// RFC 4254 7.2
|
||||
type channelOpenDirectMsg struct { |
||||
raddr string |
||||
rport uint32 |
||||
laddr string |
||||
lport uint32 |
||||
} |
||||
|
||||
func (c *Client) dial(laddr string, lport int, raddr string, rport int) (Channel, error) { |
||||
msg := channelOpenDirectMsg{ |
||||
raddr: raddr, |
||||
rport: uint32(rport), |
||||
laddr: laddr, |
||||
lport: uint32(lport), |
||||
} |
||||
ch, in, err := c.OpenChannel("direct-tcpip", Marshal(&msg)) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
go DiscardRequests(in) |
||||
return ch, err |
||||
} |
||||
|
||||
type tcpChan struct { |
||||
Channel // the backing channel
|
||||
} |
||||
|
||||
// tcpChanConn fulfills the net.Conn interface without
|
||||
// the tcpChan having to hold laddr or raddr directly.
|
||||
type tcpChanConn struct { |
||||
Channel |
||||
laddr, raddr net.Addr |
||||
} |
||||
|
||||
// LocalAddr returns the local network address.
|
||||
func (t *tcpChanConn) LocalAddr() net.Addr { |
||||
return t.laddr |
||||
} |
||||
|
||||
// RemoteAddr returns the remote network address.
|
||||
func (t *tcpChanConn) RemoteAddr() net.Addr { |
||||
return t.raddr |
||||
} |
||||
|
||||
// SetDeadline sets the read and write deadlines associated
|
||||
// with the connection.
|
||||
func (t *tcpChanConn) SetDeadline(deadline time.Time) error { |
||||
if err := t.SetReadDeadline(deadline); err != nil { |
||||
return err |
||||
} |
||||
return t.SetWriteDeadline(deadline) |
||||
} |
||||
|
||||
// SetReadDeadline sets the read deadline.
|
||||
// A zero value for t means Read will not time out.
|
||||
// After the deadline, the error from Read will implement net.Error
|
||||
// with Timeout() == true.
|
||||
func (t *tcpChanConn) SetReadDeadline(deadline time.Time) error { |
||||
return errors.New("ssh: tcpChan: deadline not supported") |
||||
} |
||||
|
||||
// SetWriteDeadline exists to satisfy the net.Conn interface
|
||||
// but is not implemented by this type. It always returns an error.
|
||||
func (t *tcpChanConn) SetWriteDeadline(deadline time.Time) error { |
||||
return errors.New("ssh: tcpChan: deadline not supported") |
||||
} |
@ -0,0 +1,951 @@ |
||||
// Copyright 2011 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package terminal |
||||
|
||||
import ( |
||||
"bytes" |
||||
"io" |
||||
"sync" |
||||
"unicode/utf8" |
||||
) |
||||
|
||||
// EscapeCodes contains escape sequences that can be written to the terminal in
|
||||
// order to achieve different styles of text.
|
||||
type EscapeCodes struct { |
||||
// Foreground colors
|
||||
Black, Red, Green, Yellow, Blue, Magenta, Cyan, White []byte |
||||
|
||||
// Reset all attributes
|
||||
Reset []byte |
||||
} |
||||
|
||||
var vt100EscapeCodes = EscapeCodes{ |
||||
Black: []byte{keyEscape, '[', '3', '0', 'm'}, |
||||
Red: []byte{keyEscape, '[', '3', '1', 'm'}, |
||||
Green: []byte{keyEscape, '[', '3', '2', 'm'}, |
||||
Yellow: []byte{keyEscape, '[', '3', '3', 'm'}, |
||||
Blue: []byte{keyEscape, '[', '3', '4', 'm'}, |
||||
Magenta: []byte{keyEscape, '[', '3', '5', 'm'}, |
||||
Cyan: []byte{keyEscape, '[', '3', '6', 'm'}, |
||||
White: []byte{keyEscape, '[', '3', '7', 'm'}, |
||||
|
||||
Reset: []byte{keyEscape, '[', '0', 'm'}, |
||||
} |
||||
|
||||
// Terminal contains the state for running a VT100 terminal that is capable of
|
||||
// reading lines of input.
|
||||
type Terminal struct { |
||||
// AutoCompleteCallback, if non-null, is called for each keypress with
|
||||
// the full input line and the current position of the cursor (in
|
||||
// bytes, as an index into |line|). If it returns ok=false, the key
|
||||
// press is processed normally. Otherwise it returns a replacement line
|
||||
// and the new cursor position.
|
||||
AutoCompleteCallback func(line string, pos int, key rune) (newLine string, newPos int, ok bool) |
||||
|
||||
// Escape contains a pointer to the escape codes for this terminal.
|
||||
// It's always a valid pointer, although the escape codes themselves
|
||||
// may be empty if the terminal doesn't support them.
|
||||
Escape *EscapeCodes |
||||
|
||||
// lock protects the terminal and the state in this object from
|
||||
// concurrent processing of a key press and a Write() call.
|
||||
lock sync.Mutex |
||||
|
||||
c io.ReadWriter |
||||
prompt []rune |
||||
|
||||
// line is the current line being entered.
|
||||
line []rune |
||||
// pos is the logical position of the cursor in line
|
||||
pos int |
||||
// echo is true if local echo is enabled
|
||||
echo bool |
||||
// pasteActive is true iff there is a bracketed paste operation in
|
||||
// progress.
|
||||
pasteActive bool |
||||
|
||||
// cursorX contains the current X value of the cursor where the left
|
||||
// edge is 0. cursorY contains the row number where the first row of
|
||||
// the current line is 0.
|
||||
cursorX, cursorY int |
||||
// maxLine is the greatest value of cursorY so far.
|
||||
maxLine int |
||||
|
||||
termWidth, termHeight int |
||||
|
||||
// outBuf contains the terminal data to be sent.
|
||||
outBuf []byte |
||||
// remainder contains the remainder of any partial key sequences after
|
||||
// a read. It aliases into inBuf.
|
||||
remainder []byte |
||||
inBuf [256]byte |
||||
|
||||
// history contains previously entered commands so that they can be
|
||||
// accessed with the up and down keys.
|
||||
history stRingBuffer |
||||
// historyIndex stores the currently accessed history entry, where zero
|
||||
// means the immediately previous entry.
|
||||
historyIndex int |
||||
// When navigating up and down the history it's possible to return to
|
||||
// the incomplete, initial line. That value is stored in
|
||||
// historyPending.
|
||||
historyPending string |
||||
} |
||||
|
||||
// NewTerminal runs a VT100 terminal on the given ReadWriter. If the ReadWriter is
|
||||
// a local terminal, that terminal must first have been put into raw mode.
|
||||
// prompt is a string that is written at the start of each input line (i.e.
|
||||
// "> ").
|
||||
func NewTerminal(c io.ReadWriter, prompt string) *Terminal { |
||||
return &Terminal{ |
||||
Escape: &vt100EscapeCodes, |
||||
c: c, |
||||
prompt: []rune(prompt), |
||||
termWidth: 80, |
||||
termHeight: 24, |
||||
echo: true, |
||||
historyIndex: -1, |
||||
} |
||||
} |
||||
|
||||
const ( |
||||
keyCtrlD = 4 |
||||
keyCtrlU = 21 |
||||
keyEnter = '\r' |
||||
keyEscape = 27 |
||||
keyBackspace = 127 |
||||
keyUnknown = 0xd800 /* UTF-16 surrogate area */ + iota |
||||
keyUp |
||||
keyDown |
||||
keyLeft |
||||
keyRight |
||||
keyAltLeft |
||||
keyAltRight |
||||
keyHome |
||||
keyEnd |
||||
keyDeleteWord |
||||
keyDeleteLine |
||||
keyClearScreen |
||||
keyPasteStart |
||||
keyPasteEnd |
||||
) |
||||
|
||||
var ( |
||||
crlf = []byte{'\r', '\n'} |
||||
pasteStart = []byte{keyEscape, '[', '2', '0', '0', '~'} |
||||
pasteEnd = []byte{keyEscape, '[', '2', '0', '1', '~'} |
||||
) |
||||
|
||||
// bytesToKey tries to parse a key sequence from b. If successful, it returns
|
||||
// the key and the remainder of the input. Otherwise it returns utf8.RuneError.
|
||||
func bytesToKey(b []byte, pasteActive bool) (rune, []byte) { |
||||
if len(b) == 0 { |
||||
return utf8.RuneError, nil |
||||
} |
||||
|
||||
if !pasteActive { |
||||
switch b[0] { |
||||
case 1: // ^A
|
||||
return keyHome, b[1:] |
||||
case 5: // ^E
|
||||
return keyEnd, b[1:] |
||||
case 8: // ^H
|
||||
return keyBackspace, b[1:] |
||||
case 11: // ^K
|
||||
return keyDeleteLine, b[1:] |
||||
case 12: // ^L
|
||||
return keyClearScreen, b[1:] |
||||
case 23: // ^W
|
||||
return keyDeleteWord, b[1:] |
||||
} |
||||
} |
||||
|
||||
if b[0] != keyEscape { |
||||
if !utf8.FullRune(b) { |
||||
return utf8.RuneError, b |
||||
} |
||||
r, l := utf8.DecodeRune(b) |
||||
return r, b[l:] |
||||
} |
||||
|
||||
if !pasteActive && len(b) >= 3 && b[0] == keyEscape && b[1] == '[' { |
||||
switch b[2] { |
||||
case 'A': |
||||
return keyUp, b[3:] |
||||
case 'B': |
||||
return keyDown, b[3:] |
||||
case 'C': |
||||
return keyRight, b[3:] |
||||
case 'D': |
||||
return keyLeft, b[3:] |
||||
case 'H': |
||||
return keyHome, b[3:] |
||||
case 'F': |
||||
return keyEnd, b[3:] |
||||
} |
||||
} |
||||
|
||||
if !pasteActive && len(b) >= 6 && b[0] == keyEscape && b[1] == '[' && b[2] == '1' && b[3] == ';' && b[4] == '3' { |
||||
switch b[5] { |
||||
case 'C': |
||||
return keyAltRight, b[6:] |
||||
case 'D': |
||||
return keyAltLeft, b[6:] |
||||
} |
||||
} |
||||
|
||||
if !pasteActive && len(b) >= 6 && bytes.Equal(b[:6], pasteStart) { |
||||
return keyPasteStart, b[6:] |
||||
} |
||||
|
||||
if pasteActive && len(b) >= 6 && bytes.Equal(b[:6], pasteEnd) { |
||||
return keyPasteEnd, b[6:] |
||||
} |
||||
|
||||
// If we get here then we have a key that we don't recognise, or a
|
||||
// partial sequence. It's not clear how one should find the end of a
|
||||
// sequence without knowing them all, but it seems that [a-zA-Z~] only
|
||||
// appears at the end of a sequence.
|
||||
for i, c := range b[0:] { |
||||
if c >= 'a' && c <= 'z' || c >= 'A' && c <= 'Z' || c == '~' { |
||||
return keyUnknown, b[i+1:] |
||||
} |
||||
} |
||||
|
||||
return utf8.RuneError, b |
||||
} |
||||
|
||||
// queue appends data to the end of t.outBuf
|
||||
func (t *Terminal) queue(data []rune) { |
||||
t.outBuf = append(t.outBuf, []byte(string(data))...) |
||||
} |
||||
|
||||
var eraseUnderCursor = []rune{' ', keyEscape, '[', 'D'} |
||||
var space = []rune{' '} |
||||
|
||||
func isPrintable(key rune) bool { |
||||
isInSurrogateArea := key >= 0xd800 && key <= 0xdbff |
||||
return key >= 32 && !isInSurrogateArea |
||||
} |
||||
|
||||
// moveCursorToPos appends data to t.outBuf which will move the cursor to the
|
||||
// given, logical position in the text.
|
||||
func (t *Terminal) moveCursorToPos(pos int) { |
||||
if !t.echo { |
||||
return |
||||
} |
||||
|
||||
x := visualLength(t.prompt) + pos |
||||
y := x / t.termWidth |
||||
x = x % t.termWidth |
||||
|
||||
up := 0 |
||||
if y < t.cursorY { |
||||
up = t.cursorY - y |
||||
} |
||||
|
||||
down := 0 |
||||
if y > t.cursorY { |
||||
down = y - t.cursorY |
||||
} |
||||
|
||||
left := 0 |
||||
if x < t.cursorX { |
||||
left = t.cursorX - x |
||||
} |
||||
|
||||
right := 0 |
||||
if x > t.cursorX { |
||||
right = x - t.cursorX |
||||
} |
||||
|
||||
t.cursorX = x |
||||
t.cursorY = y |
||||
t.move(up, down, left, right) |
||||
} |
||||
|
||||
func (t *Terminal) move(up, down, left, right int) { |
||||
movement := make([]rune, 3*(up+down+left+right)) |
||||
m := movement |
||||
for i := 0; i < up; i++ { |
||||
m[0] = keyEscape |
||||
m[1] = '[' |
||||
m[2] = 'A' |
||||
m = m[3:] |
||||
} |
||||
for i := 0; i < down; i++ { |
||||
m[0] = keyEscape |
||||
m[1] = '[' |
||||
m[2] = 'B' |
||||
m = m[3:] |
||||
} |
||||
for i := 0; i < left; i++ { |
||||
m[0] = keyEscape |
||||
m[1] = '[' |
||||
m[2] = 'D' |
||||
m = m[3:] |
||||
} |
||||
for i := 0; i < right; i++ { |
||||
m[0] = keyEscape |
||||
m[1] = '[' |
||||
m[2] = 'C' |
||||
m = m[3:] |
||||
} |
||||
|
||||
t.queue(movement) |
||||
} |
||||
|
||||
func (t *Terminal) clearLineToRight() { |
||||
op := []rune{keyEscape, '[', 'K'} |
||||
t.queue(op) |
||||
} |
||||
|
||||
const maxLineLength = 4096 |
||||
|
||||
func (t *Terminal) setLine(newLine []rune, newPos int) { |
||||
if t.echo { |
||||
t.moveCursorToPos(0) |
||||
t.writeLine(newLine) |
||||
for i := len(newLine); i < len(t.line); i++ { |
||||
t.writeLine(space) |
||||
} |
||||
t.moveCursorToPos(newPos) |
||||
} |
||||
t.line = newLine |
||||
t.pos = newPos |
||||
} |
||||
|
||||
func (t *Terminal) advanceCursor(places int) { |
||||
t.cursorX += places |
||||
t.cursorY += t.cursorX / t.termWidth |
||||
if t.cursorY > t.maxLine { |
||||
t.maxLine = t.cursorY |
||||
} |
||||
t.cursorX = t.cursorX % t.termWidth |
||||
|
||||
if places > 0 && t.cursorX == 0 { |
||||
// Normally terminals will advance the current position
|
||||
// when writing a character. But that doesn't happen
|
||||
// for the last character in a line. However, when
|
||||
// writing a character (except a new line) that causes
|
||||
// a line wrap, the position will be advanced two
|
||||
// places.
|
||||
//
|
||||
// So, if we are stopping at the end of a line, we
|
||||
// need to write a newline so that our cursor can be
|
||||
// advanced to the next line.
|
||||
t.outBuf = append(t.outBuf, '\r', '\n') |
||||
} |
||||
} |
||||
|
||||
func (t *Terminal) eraseNPreviousChars(n int) { |
||||
if n == 0 { |
||||
return |
||||
} |
||||
|
||||
if t.pos < n { |
||||
n = t.pos |
||||
} |
||||
t.pos -= n |
||||
t.moveCursorToPos(t.pos) |
||||
|
||||
copy(t.line[t.pos:], t.line[n+t.pos:]) |
||||
t.line = t.line[:len(t.line)-n] |
||||
if t.echo { |
||||
t.writeLine(t.line[t.pos:]) |
||||
for i := 0; i < n; i++ { |
||||
t.queue(space) |
||||
} |
||||
t.advanceCursor(n) |
||||
t.moveCursorToPos(t.pos) |
||||
} |
||||
} |
||||
|
||||
// countToLeftWord returns then number of characters from the cursor to the
|
||||
// start of the previous word.
|
||||
func (t *Terminal) countToLeftWord() int { |
||||
if t.pos == 0 { |
||||
return 0 |
||||
} |
||||
|
||||
pos := t.pos - 1 |
||||
for pos > 0 { |
||||
if t.line[pos] != ' ' { |
||||
break |
||||
} |
||||
pos-- |
||||
} |
||||
for pos > 0 { |
||||
if t.line[pos] == ' ' { |
||||
pos++ |
||||
break |
||||
} |
||||
pos-- |
||||
} |
||||
|
||||
return t.pos - pos |
||||
} |
||||
|
||||
// countToRightWord returns then number of characters from the cursor to the
|
||||
// start of the next word.
|
||||
func (t *Terminal) countToRightWord() int { |
||||
pos := t.pos |
||||
for pos < len(t.line) { |
||||
if t.line[pos] == ' ' { |
||||
break |
||||
} |
||||
pos++ |
||||
} |
||||
for pos < len(t.line) { |
||||
if t.line[pos] != ' ' { |
||||
break |
||||
} |
||||
pos++ |
||||
} |
||||
return pos - t.pos |
||||
} |
||||
|
||||
// visualLength returns the number of visible glyphs in s.
|
||||
func visualLength(runes []rune) int { |
||||
inEscapeSeq := false |
||||
length := 0 |
||||
|
||||
for _, r := range runes { |
||||
switch { |
||||
case inEscapeSeq: |
||||
if (r >= 'a' && r <= 'z') || (r >= 'A' && r <= 'Z') { |
||||
inEscapeSeq = false |
||||
} |
||||
case r == '\x1b': |
||||
inEscapeSeq = true |
||||
default: |
||||
length++ |
||||
} |
||||
} |
||||
|
||||
return length |
||||
} |
||||
|
||||
// handleKey processes the given key and, optionally, returns a line of text
|
||||
// that the user has entered.
|
||||
func (t *Terminal) handleKey(key rune) (line string, ok bool) { |
||||
if t.pasteActive && key != keyEnter { |
||||
t.addKeyToLine(key) |
||||
return |
||||
} |
||||
|
||||
switch key { |
||||
case keyBackspace: |
||||
if t.pos == 0 { |
||||
return |
||||
} |
||||
t.eraseNPreviousChars(1) |
||||
case keyAltLeft: |
||||
// move left by a word.
|
||||
t.pos -= t.countToLeftWord() |
||||
t.moveCursorToPos(t.pos) |
||||
case keyAltRight: |
||||
// move right by a word.
|
||||
t.pos += t.countToRightWord() |
||||
t.moveCursorToPos(t.pos) |
||||
case keyLeft: |
||||
if t.pos == 0 { |
||||
return |
||||
} |
||||
t.pos-- |
||||
t.moveCursorToPos(t.pos) |
||||
case keyRight: |
||||
if t.pos == len(t.line) { |
||||
return |
||||
} |
||||
t.pos++ |
||||
t.moveCursorToPos(t.pos) |
||||
case keyHome: |
||||
if t.pos == 0 { |
||||
return |
||||
} |
||||
t.pos = 0 |
||||
t.moveCursorToPos(t.pos) |
||||
case keyEnd: |
||||
if t.pos == len(t.line) { |
||||
return |
||||
} |
||||
t.pos = len(t.line) |
||||
t.moveCursorToPos(t.pos) |
||||
case keyUp: |
||||
entry, ok := t.history.NthPreviousEntry(t.historyIndex + 1) |
||||
if !ok { |
||||
return "", false |
||||
} |
||||
if t.historyIndex == -1 { |
||||
t.historyPending = string(t.line) |
||||
} |
||||
t.historyIndex++ |
||||
runes := []rune(entry) |
||||
t.setLine(runes, len(runes)) |
||||
case keyDown: |
||||
switch t.historyIndex { |
||||
case -1: |
||||
return |
||||
case 0: |
||||
runes := []rune(t.historyPending) |
||||
t.setLine(runes, len(runes)) |
||||
t.historyIndex-- |
||||
default: |
||||
entry, ok := t.history.NthPreviousEntry(t.historyIndex - 1) |
||||
if ok { |
||||
t.historyIndex-- |
||||
runes := []rune(entry) |
||||
t.setLine(runes, len(runes)) |
||||
} |
||||
} |
||||
case keyEnter: |
||||
t.moveCursorToPos(len(t.line)) |
||||
t.queue([]rune("\r\n")) |
||||
line = string(t.line) |
||||
ok = true |
||||
t.line = t.line[:0] |
||||
t.pos = 0 |
||||
t.cursorX = 0 |
||||
t.cursorY = 0 |
||||
t.maxLine = 0 |
||||
case keyDeleteWord: |
||||
// Delete zero or more spaces and then one or more characters.
|
||||
t.eraseNPreviousChars(t.countToLeftWord()) |
||||
case keyDeleteLine: |
||||
// Delete everything from the current cursor position to the
|
||||
// end of line.
|
||||
for i := t.pos; i < len(t.line); i++ { |
||||
t.queue(space) |
||||
t.advanceCursor(1) |
||||
} |
||||
t.line = t.line[:t.pos] |
||||
t.moveCursorToPos(t.pos) |
||||
case keyCtrlD: |
||||
// Erase the character under the current position.
|
||||
// The EOF case when the line is empty is handled in
|
||||
// readLine().
|
||||
if t.pos < len(t.line) { |
||||
t.pos++ |
||||
t.eraseNPreviousChars(1) |
||||
} |
||||
case keyCtrlU: |
||||
t.eraseNPreviousChars(t.pos) |
||||
case keyClearScreen: |
||||
// Erases the screen and moves the cursor to the home position.
|
||||
t.queue([]rune("\x1b[2J\x1b[H")) |
||||
t.queue(t.prompt) |
||||
t.cursorX, t.cursorY = 0, 0 |
||||
t.advanceCursor(visualLength(t.prompt)) |
||||
t.setLine(t.line, t.pos) |
||||
default: |
||||
if t.AutoCompleteCallback != nil { |
||||
prefix := string(t.line[:t.pos]) |
||||
suffix := string(t.line[t.pos:]) |
||||
|
||||
t.lock.Unlock() |
||||
newLine, newPos, completeOk := t.AutoCompleteCallback(prefix+suffix, len(prefix), key) |
||||
t.lock.Lock() |
||||
|
||||
if completeOk { |
||||
t.setLine([]rune(newLine), utf8.RuneCount([]byte(newLine)[:newPos])) |
||||
return |
||||
} |
||||
} |
||||
if !isPrintable(key) { |
||||
return |
||||
} |
||||
if len(t.line) == maxLineLength { |
||||
return |
||||
} |
||||
t.addKeyToLine(key) |
||||
} |
||||
return |
||||
} |
||||
|
||||
// addKeyToLine inserts the given key at the current position in the current
|
||||
// line.
|
||||
func (t *Terminal) addKeyToLine(key rune) { |
||||
if len(t.line) == cap(t.line) { |
||||
newLine := make([]rune, len(t.line), 2*(1+len(t.line))) |
||||
copy(newLine, t.line) |
||||
t.line = newLine |
||||
} |
||||
t.line = t.line[:len(t.line)+1] |
||||
copy(t.line[t.pos+1:], t.line[t.pos:]) |
||||
t.line[t.pos] = key |
||||
if t.echo { |
||||
t.writeLine(t.line[t.pos:]) |
||||
} |
||||
t.pos++ |
||||
t.moveCursorToPos(t.pos) |
||||
} |
||||
|
||||
func (t *Terminal) writeLine(line []rune) { |
||||
for len(line) != 0 { |
||||
remainingOnLine := t.termWidth - t.cursorX |
||||
todo := len(line) |
||||
if todo > remainingOnLine { |
||||
todo = remainingOnLine |
||||
} |
||||
t.queue(line[:todo]) |
||||
t.advanceCursor(visualLength(line[:todo])) |
||||
line = line[todo:] |
||||
} |
||||
} |
||||
|
||||
// writeWithCRLF writes buf to w but replaces all occurrences of \n with \r\n.
|
||||
func writeWithCRLF(w io.Writer, buf []byte) (n int, err error) { |
||||
for len(buf) > 0 { |
||||
i := bytes.IndexByte(buf, '\n') |
||||
todo := len(buf) |
||||
if i >= 0 { |
||||
todo = i |
||||
} |
||||
|
||||
var nn int |
||||
nn, err = w.Write(buf[:todo]) |
||||
n += nn |
||||
if err != nil { |
||||
return n, err |
||||
} |
||||
buf = buf[todo:] |
||||
|
||||
if i >= 0 { |
||||
if _, err = w.Write(crlf); err != nil { |
||||
return n, err |
||||
} |
||||
n += 1 |
||||
buf = buf[1:] |
||||
} |
||||
} |
||||
|
||||
return n, nil |
||||
} |
||||
|
||||
func (t *Terminal) Write(buf []byte) (n int, err error) { |
||||
t.lock.Lock() |
||||
defer t.lock.Unlock() |
||||
|
||||
if t.cursorX == 0 && t.cursorY == 0 { |
||||
// This is the easy case: there's nothing on the screen that we
|
||||
// have to move out of the way.
|
||||
return writeWithCRLF(t.c, buf) |
||||
} |
||||
|
||||
// We have a prompt and possibly user input on the screen. We
|
||||
// have to clear it first.
|
||||
t.move(0 /* up */, 0 /* down */, t.cursorX /* left */, 0 /* right */) |
||||
t.cursorX = 0 |
||||
t.clearLineToRight() |
||||
|
||||
for t.cursorY > 0 { |
||||
t.move(1 /* up */, 0, 0, 0) |
||||
t.cursorY-- |
||||
t.clearLineToRight() |
||||
} |
||||
|
||||
if _, err = t.c.Write(t.outBuf); err != nil { |
||||
return |
||||
} |
||||
t.outBuf = t.outBuf[:0] |
||||
|
||||
if n, err = writeWithCRLF(t.c, buf); err != nil { |
||||
return |
||||
} |
||||
|
||||
t.writeLine(t.prompt) |
||||
if t.echo { |
||||
t.writeLine(t.line) |
||||
} |
||||
|
||||
t.moveCursorToPos(t.pos) |
||||
|
||||
if _, err = t.c.Write(t.outBuf); err != nil { |
||||
return |
||||
} |
||||
t.outBuf = t.outBuf[:0] |
||||
return |
||||
} |
||||
|
||||
// ReadPassword temporarily changes the prompt and reads a password, without
|
||||
// echo, from the terminal.
|
||||
func (t *Terminal) ReadPassword(prompt string) (line string, err error) { |
||||
t.lock.Lock() |
||||
defer t.lock.Unlock() |
||||
|
||||
oldPrompt := t.prompt |
||||
t.prompt = []rune(prompt) |
||||
t.echo = false |
||||
|
||||
line, err = t.readLine() |
||||
|
||||
t.prompt = oldPrompt |
||||
t.echo = true |
||||
|
||||
return |
||||
} |
||||
|
||||
// ReadLine returns a line of input from the terminal.
|
||||
func (t *Terminal) ReadLine() (line string, err error) { |
||||
t.lock.Lock() |
||||
defer t.lock.Unlock() |
||||
|
||||
return t.readLine() |
||||
} |
||||
|
||||
func (t *Terminal) readLine() (line string, err error) { |
||||
// t.lock must be held at this point
|
||||
|
||||
if t.cursorX == 0 && t.cursorY == 0 { |
||||
t.writeLine(t.prompt) |
||||
t.c.Write(t.outBuf) |
||||
t.outBuf = t.outBuf[:0] |
||||
} |
||||
|
||||
lineIsPasted := t.pasteActive |
||||
|
||||
for { |
||||
rest := t.remainder |
||||
lineOk := false |
||||
for !lineOk { |
||||
var key rune |
||||
key, rest = bytesToKey(rest, t.pasteActive) |
||||
if key == utf8.RuneError { |
||||
break |
||||
} |
||||
if !t.pasteActive { |
||||
if key == keyCtrlD { |
||||
if len(t.line) == 0 { |
||||
return "", io.EOF |
||||
} |
||||
} |
||||
if key == keyPasteStart { |
||||
t.pasteActive = true |
||||
if len(t.line) == 0 { |
||||
lineIsPasted = true |
||||
} |
||||
continue |
||||
} |
||||
} else if key == keyPasteEnd { |
||||
t.pasteActive = false |
||||
continue |
||||
} |
||||
if !t.pasteActive { |
||||
lineIsPasted = false |
||||
} |
||||
line, lineOk = t.handleKey(key) |
||||
} |
||||
if len(rest) > 0 { |
||||
n := copy(t.inBuf[:], rest) |
||||
t.remainder = t.inBuf[:n] |
||||
} else { |
||||
t.remainder = nil |
||||
} |
||||
t.c.Write(t.outBuf) |
||||
t.outBuf = t.outBuf[:0] |
||||
if lineOk { |
||||
if t.echo { |
||||
t.historyIndex = -1 |
||||
t.history.Add(line) |
||||
} |
||||
if lineIsPasted { |
||||
err = ErrPasteIndicator |
||||
} |
||||
return |
||||
} |
||||
|
||||
// t.remainder is a slice at the beginning of t.inBuf
|
||||
// containing a partial key sequence
|
||||
readBuf := t.inBuf[len(t.remainder):] |
||||
var n int |
||||
|
||||
t.lock.Unlock() |
||||
n, err = t.c.Read(readBuf) |
||||
t.lock.Lock() |
||||
|
||||
if err != nil { |
||||
return |
||||
} |
||||
|
||||
t.remainder = t.inBuf[:n+len(t.remainder)] |
||||
} |
||||
} |
||||
|
||||
// SetPrompt sets the prompt to be used when reading subsequent lines.
|
||||
func (t *Terminal) SetPrompt(prompt string) { |
||||
t.lock.Lock() |
||||
defer t.lock.Unlock() |
||||
|
||||
t.prompt = []rune(prompt) |
||||
} |
||||
|
||||
func (t *Terminal) clearAndRepaintLinePlusNPrevious(numPrevLines int) { |
||||
// Move cursor to column zero at the start of the line.
|
||||
t.move(t.cursorY, 0, t.cursorX, 0) |
||||
t.cursorX, t.cursorY = 0, 0 |
||||
t.clearLineToRight() |
||||
for t.cursorY < numPrevLines { |
||||
// Move down a line
|
||||
t.move(0, 1, 0, 0) |
||||
t.cursorY++ |
||||
t.clearLineToRight() |
||||
} |
||||
// Move back to beginning.
|
||||
t.move(t.cursorY, 0, 0, 0) |
||||
t.cursorX, t.cursorY = 0, 0 |
||||
|
||||
t.queue(t.prompt) |
||||
t.advanceCursor(visualLength(t.prompt)) |
||||
t.writeLine(t.line) |
||||
t.moveCursorToPos(t.pos) |
||||
} |
||||
|
||||
func (t *Terminal) SetSize(width, height int) error { |
||||
t.lock.Lock() |
||||
defer t.lock.Unlock() |
||||
|
||||
if width == 0 { |
||||
width = 1 |
||||
} |
||||
|
||||
oldWidth := t.termWidth |
||||
t.termWidth, t.termHeight = width, height |
||||
|
||||
switch { |
||||
case width == oldWidth: |
||||
// If the width didn't change then nothing else needs to be
|
||||
// done.
|
||||
return nil |
||||
case len(t.line) == 0 && t.cursorX == 0 && t.cursorY == 0: |
||||
// If there is nothing on current line and no prompt printed,
|
||||
// just do nothing
|
||||
return nil |
||||
case width < oldWidth: |
||||
// Some terminals (e.g. xterm) will truncate lines that were
|
||||
// too long when shinking. Others, (e.g. gnome-terminal) will
|
||||
// attempt to wrap them. For the former, repainting t.maxLine
|
||||
// works great, but that behaviour goes badly wrong in the case
|
||||
// of the latter because they have doubled every full line.
|
||||
|
||||
// We assume that we are working on a terminal that wraps lines
|
||||
// and adjust the cursor position based on every previous line
|
||||
// wrapping and turning into two. This causes the prompt on
|
||||
// xterms to move upwards, which isn't great, but it avoids a
|
||||
// huge mess with gnome-terminal.
|
||||
if t.cursorX >= t.termWidth { |
||||
t.cursorX = t.termWidth - 1 |
||||
} |
||||
t.cursorY *= 2 |
||||
t.clearAndRepaintLinePlusNPrevious(t.maxLine * 2) |
||||
case width > oldWidth: |
||||
// If the terminal expands then our position calculations will
|
||||
// be wrong in the future because we think the cursor is
|
||||
// |t.pos| chars into the string, but there will be a gap at
|
||||
// the end of any wrapped line.
|
||||
//
|
||||
// But the position will actually be correct until we move, so
|
||||
// we can move back to the beginning and repaint everything.
|
||||
t.clearAndRepaintLinePlusNPrevious(t.maxLine) |
||||
} |
||||
|
||||
_, err := t.c.Write(t.outBuf) |
||||
t.outBuf = t.outBuf[:0] |
||||
return err |
||||
} |
||||
|
||||
type pasteIndicatorError struct{} |
||||
|
||||
func (pasteIndicatorError) Error() string { |
||||
return "terminal: ErrPasteIndicator not correctly handled" |
||||
} |
||||
|
||||
// ErrPasteIndicator may be returned from ReadLine as the error, in addition
|
||||
// to valid line data. It indicates that bracketed paste mode is enabled and
|
||||
// that the returned line consists only of pasted data. Programs may wish to
|
||||
// interpret pasted data more literally than typed data.
|
||||
var ErrPasteIndicator = pasteIndicatorError{} |
||||
|
||||
// SetBracketedPasteMode requests that the terminal bracket paste operations
|
||||
// with markers. Not all terminals support this but, if it is supported, then
|
||||
// enabling this mode will stop any autocomplete callback from running due to
|
||||
// pastes. Additionally, any lines that are completely pasted will be returned
|
||||
// from ReadLine with the error set to ErrPasteIndicator.
|
||||
func (t *Terminal) SetBracketedPasteMode(on bool) { |
||||
if on { |
||||
io.WriteString(t.c, "\x1b[?2004h") |
||||
} else { |
||||
io.WriteString(t.c, "\x1b[?2004l") |
||||
} |
||||
} |
||||
|
||||
// stRingBuffer is a ring buffer of strings.
|
||||
type stRingBuffer struct { |
||||
// entries contains max elements.
|
||||
entries []string |
||||
max int |
||||
// head contains the index of the element most recently added to the ring.
|
||||
head int |
||||
// size contains the number of elements in the ring.
|
||||
size int |
||||
} |
||||
|
||||
func (s *stRingBuffer) Add(a string) { |
||||
if s.entries == nil { |
||||
const defaultNumEntries = 100 |
||||
s.entries = make([]string, defaultNumEntries) |
||||
s.max = defaultNumEntries |
||||
} |
||||
|
||||
s.head = (s.head + 1) % s.max |
||||
s.entries[s.head] = a |
||||
if s.size < s.max { |
||||
s.size++ |
||||
} |
||||
} |
||||
|
||||
// NthPreviousEntry returns the value passed to the nth previous call to Add.
|
||||
// If n is zero then the immediately prior value is returned, if one, then the
|
||||
// next most recent, and so on. If such an element doesn't exist then ok is
|
||||
// false.
|
||||
func (s *stRingBuffer) NthPreviousEntry(n int) (value string, ok bool) { |
||||
if n >= s.size { |
||||
return "", false |
||||
} |
||||
index := s.head - n |
||||
if index < 0 { |
||||
index += s.max |
||||
} |
||||
return s.entries[index], true |
||||
} |
||||
|
||||
// readPasswordLine reads from reader until it finds \n or io.EOF.
|
||||
// The slice returned does not include the \n.
|
||||
// readPasswordLine also ignores any \r it finds.
|
||||
func readPasswordLine(reader io.Reader) ([]byte, error) { |
||||
var buf [1]byte |
||||
var ret []byte |
||||
|
||||
for { |
||||
n, err := reader.Read(buf[:]) |
||||
if n > 0 { |
||||
switch buf[0] { |
||||
case '\n': |
||||
return ret, nil |
||||
case '\r': |
||||
// remove \r from passwords on Windows
|
||||
default: |
||||
ret = append(ret, buf[0]) |
||||
} |
||||
continue |
||||
} |
||||
if err != nil { |
||||
if err == io.EOF && len(ret) > 0 { |
||||
return ret, nil |
||||
} |
||||
return ret, err |
||||
} |
||||
} |
||||
} |
@ -0,0 +1,119 @@ |
||||
// Copyright 2011 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
// +build darwin dragonfly freebsd linux,!appengine netbsd openbsd
|
||||
|
||||
// Package terminal provides support functions for dealing with terminals, as
|
||||
// commonly found on UNIX systems.
|
||||
//
|
||||
// Putting a terminal into raw mode is the most common requirement:
|
||||
//
|
||||
// oldState, err := terminal.MakeRaw(0)
|
||||
// if err != nil {
|
||||
// panic(err)
|
||||
// }
|
||||
// defer terminal.Restore(0, oldState)
|
||||
package terminal // import "golang.org/x/crypto/ssh/terminal"
|
||||
|
||||
import ( |
||||
"syscall" |
||||
"unsafe" |
||||
) |
||||
|
||||
// State contains the state of a terminal.
|
||||
type State struct { |
||||
termios syscall.Termios |
||||
} |
||||
|
||||
// IsTerminal returns true if the given file descriptor is a terminal.
|
||||
func IsTerminal(fd int) bool { |
||||
var termios syscall.Termios |
||||
_, _, err := syscall.Syscall6(syscall.SYS_IOCTL, uintptr(fd), ioctlReadTermios, uintptr(unsafe.Pointer(&termios)), 0, 0, 0) |
||||
return err == 0 |
||||
} |
||||
|
||||
// MakeRaw put the terminal connected to the given file descriptor into raw
|
||||
// mode and returns the previous state of the terminal so that it can be
|
||||
// restored.
|
||||
func MakeRaw(fd int) (*State, error) { |
||||
var oldState State |
||||
if _, _, err := syscall.Syscall6(syscall.SYS_IOCTL, uintptr(fd), ioctlReadTermios, uintptr(unsafe.Pointer(&oldState.termios)), 0, 0, 0); err != 0 { |
||||
return nil, err |
||||
} |
||||
|
||||
newState := oldState.termios |
||||
// This attempts to replicate the behaviour documented for cfmakeraw in
|
||||
// the termios(3) manpage.
|
||||
newState.Iflag &^= syscall.IGNBRK | syscall.BRKINT | syscall.PARMRK | syscall.ISTRIP | syscall.INLCR | syscall.IGNCR | syscall.ICRNL | syscall.IXON |
||||
newState.Oflag &^= syscall.OPOST |
||||
newState.Lflag &^= syscall.ECHO | syscall.ECHONL | syscall.ICANON | syscall.ISIG | syscall.IEXTEN |
||||
newState.Cflag &^= syscall.CSIZE | syscall.PARENB |
||||
newState.Cflag |= syscall.CS8 |
||||
if _, _, err := syscall.Syscall6(syscall.SYS_IOCTL, uintptr(fd), ioctlWriteTermios, uintptr(unsafe.Pointer(&newState)), 0, 0, 0); err != 0 { |
||||
return nil, err |
||||
} |
||||
|
||||
return &oldState, nil |
||||
} |
||||
|
||||
// GetState returns the current state of a terminal which may be useful to
|
||||
// restore the terminal after a signal.
|
||||
func GetState(fd int) (*State, error) { |
||||
var oldState State |
||||
if _, _, err := syscall.Syscall6(syscall.SYS_IOCTL, uintptr(fd), ioctlReadTermios, uintptr(unsafe.Pointer(&oldState.termios)), 0, 0, 0); err != 0 { |
||||
return nil, err |
||||
} |
||||
|
||||
return &oldState, nil |
||||
} |
||||
|
||||
// Restore restores the terminal connected to the given file descriptor to a
|
||||
// previous state.
|
||||
func Restore(fd int, state *State) error { |
||||
if _, _, err := syscall.Syscall6(syscall.SYS_IOCTL, uintptr(fd), ioctlWriteTermios, uintptr(unsafe.Pointer(&state.termios)), 0, 0, 0); err != 0 { |
||||
return err |
||||
} |
||||
return nil |
||||
} |
||||
|
||||
// GetSize returns the dimensions of the given terminal.
|
||||
func GetSize(fd int) (width, height int, err error) { |
||||
var dimensions [4]uint16 |
||||
|
||||
if _, _, err := syscall.Syscall6(syscall.SYS_IOCTL, uintptr(fd), uintptr(syscall.TIOCGWINSZ), uintptr(unsafe.Pointer(&dimensions)), 0, 0, 0); err != 0 { |
||||
return -1, -1, err |
||||
} |
||||
return int(dimensions[1]), int(dimensions[0]), nil |
||||
} |
||||
|
||||
// passwordReader is an io.Reader that reads from a specific file descriptor.
|
||||
type passwordReader int |
||||
|
||||
func (r passwordReader) Read(buf []byte) (int, error) { |
||||
return syscall.Read(int(r), buf) |
||||
} |
||||
|
||||
// ReadPassword reads a line of input from a terminal without local echo. This
|
||||
// is commonly used for inputting passwords and other sensitive data. The slice
|
||||
// returned does not include the \n.
|
||||
func ReadPassword(fd int) ([]byte, error) { |
||||
var oldState syscall.Termios |
||||
if _, _, err := syscall.Syscall6(syscall.SYS_IOCTL, uintptr(fd), ioctlReadTermios, uintptr(unsafe.Pointer(&oldState)), 0, 0, 0); err != 0 { |
||||
return nil, err |
||||
} |
||||
|
||||
newState := oldState |
||||
newState.Lflag &^= syscall.ECHO |
||||
newState.Lflag |= syscall.ICANON | syscall.ISIG |
||||
newState.Iflag |= syscall.ICRNL |
||||
if _, _, err := syscall.Syscall6(syscall.SYS_IOCTL, uintptr(fd), ioctlWriteTermios, uintptr(unsafe.Pointer(&newState)), 0, 0, 0); err != 0 { |
||||
return nil, err |
||||
} |
||||
|
||||
defer func() { |
||||
syscall.Syscall6(syscall.SYS_IOCTL, uintptr(fd), ioctlWriteTermios, uintptr(unsafe.Pointer(&oldState)), 0, 0, 0) |
||||
}() |
||||
|
||||
return readPasswordLine(passwordReader(fd)) |
||||
} |
@ -0,0 +1,12 @@ |
||||
// Copyright 2013 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
// +build darwin dragonfly freebsd netbsd openbsd
|
||||
|
||||
package terminal |
||||
|
||||
import "syscall" |
||||
|
||||
const ioctlReadTermios = syscall.TIOCGETA |
||||
const ioctlWriteTermios = syscall.TIOCSETA |
@ -0,0 +1,11 @@ |
||||
// Copyright 2013 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package terminal |
||||
|
||||
// These constants are declared here, rather than importing
|
||||
// them from the syscall package as some syscall packages, even
|
||||
// on linux, for example gccgo, do not declare them.
|
||||
const ioctlReadTermios = 0x5401 // syscall.TCGETS
|
||||
const ioctlWriteTermios = 0x5402 // syscall.TCSETS
|
@ -0,0 +1,58 @@ |
||||
// Copyright 2016 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
// Package terminal provides support functions for dealing with terminals, as
|
||||
// commonly found on UNIX systems.
|
||||
//
|
||||
// Putting a terminal into raw mode is the most common requirement:
|
||||
//
|
||||
// oldState, err := terminal.MakeRaw(0)
|
||||
// if err != nil {
|
||||
// panic(err)
|
||||
// }
|
||||
// defer terminal.Restore(0, oldState)
|
||||
package terminal |
||||
|
||||
import ( |
||||
"fmt" |
||||
"runtime" |
||||
) |
||||
|
||||
type State struct{} |
||||
|
||||
// IsTerminal returns true if the given file descriptor is a terminal.
|
||||
func IsTerminal(fd int) bool { |
||||
return false |
||||
} |
||||
|
||||
// MakeRaw put the terminal connected to the given file descriptor into raw
|
||||
// mode and returns the previous state of the terminal so that it can be
|
||||
// restored.
|
||||
func MakeRaw(fd int) (*State, error) { |
||||
return nil, fmt.Errorf("terminal: MakeRaw not implemented on %s/%s", runtime.GOOS, runtime.GOARCH) |
||||
} |
||||
|
||||
// GetState returns the current state of a terminal which may be useful to
|
||||
// restore the terminal after a signal.
|
||||
func GetState(fd int) (*State, error) { |
||||
return nil, fmt.Errorf("terminal: GetState not implemented on %s/%s", runtime.GOOS, runtime.GOARCH) |
||||
} |
||||
|
||||
// Restore restores the terminal connected to the given file descriptor to a
|
||||
// previous state.
|
||||
func Restore(fd int, state *State) error { |
||||
return fmt.Errorf("terminal: Restore not implemented on %s/%s", runtime.GOOS, runtime.GOARCH) |
||||
} |
||||
|
||||
// GetSize returns the dimensions of the given terminal.
|
||||
func GetSize(fd int) (width, height int, err error) { |
||||
return 0, 0, fmt.Errorf("terminal: GetSize not implemented on %s/%s", runtime.GOOS, runtime.GOARCH) |
||||
} |
||||
|
||||
// ReadPassword reads a line of input from a terminal without local echo. This
|
||||
// is commonly used for inputting passwords and other sensitive data. The slice
|
||||
// returned does not include the \n.
|
||||
func ReadPassword(fd int) ([]byte, error) { |
||||
return nil, fmt.Errorf("terminal: ReadPassword not implemented on %s/%s", runtime.GOOS, runtime.GOARCH) |
||||
} |
@ -0,0 +1,73 @@ |
||||
// Copyright 2015 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
// +build solaris
|
||||
|
||||
package terminal // import "golang.org/x/crypto/ssh/terminal"
|
||||
|
||||
import ( |
||||
"golang.org/x/sys/unix" |
||||
"io" |
||||
"syscall" |
||||
) |
||||
|
||||
// State contains the state of a terminal.
|
||||
type State struct { |
||||
termios syscall.Termios |
||||
} |
||||
|
||||
// IsTerminal returns true if the given file descriptor is a terminal.
|
||||
func IsTerminal(fd int) bool { |
||||
// see: http://src.illumos.org/source/xref/illumos-gate/usr/src/lib/libbc/libc/gen/common/isatty.c
|
||||
var termio unix.Termio |
||||
err := unix.IoctlSetTermio(fd, unix.TCGETA, &termio) |
||||
return err == nil |
||||
} |
||||
|
||||
// ReadPassword reads a line of input from a terminal without local echo. This
|
||||
// is commonly used for inputting passwords and other sensitive data. The slice
|
||||
// returned does not include the \n.
|
||||
func ReadPassword(fd int) ([]byte, error) { |
||||
// see also: http://src.illumos.org/source/xref/illumos-gate/usr/src/lib/libast/common/uwin/getpass.c
|
||||
val, err := unix.IoctlGetTermios(fd, unix.TCGETS) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
oldState := *val |
||||
|
||||
newState := oldState |
||||
newState.Lflag &^= syscall.ECHO |
||||
newState.Lflag |= syscall.ICANON | syscall.ISIG |
||||
newState.Iflag |= syscall.ICRNL |
||||
err = unix.IoctlSetTermios(fd, unix.TCSETS, &newState) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
defer unix.IoctlSetTermios(fd, unix.TCSETS, &oldState) |
||||
|
||||
var buf [16]byte |
||||
var ret []byte |
||||
for { |
||||
n, err := syscall.Read(fd, buf[:]) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
if n == 0 { |
||||
if len(ret) == 0 { |
||||
return nil, io.EOF |
||||
} |
||||
break |
||||
} |
||||
if buf[n-1] == '\n' { |
||||
n-- |
||||
} |
||||
ret = append(ret, buf[:n]...) |
||||
if n < len(buf) { |
||||
break |
||||
} |
||||
} |
||||
|
||||
return ret, nil |
||||
} |
@ -0,0 +1,155 @@ |
||||
// Copyright 2011 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
// +build windows
|
||||
|
||||
// Package terminal provides support functions for dealing with terminals, as
|
||||
// commonly found on UNIX systems.
|
||||
//
|
||||
// Putting a terminal into raw mode is the most common requirement:
|
||||
//
|
||||
// oldState, err := terminal.MakeRaw(0)
|
||||
// if err != nil {
|
||||
// panic(err)
|
||||
// }
|
||||
// defer terminal.Restore(0, oldState)
|
||||
package terminal |
||||
|
||||
import ( |
||||
"syscall" |
||||
"unsafe" |
||||
) |
||||
|
||||
const ( |
||||
enableLineInput = 2 |
||||
enableEchoInput = 4 |
||||
enableProcessedInput = 1 |
||||
enableWindowInput = 8 |
||||
enableMouseInput = 16 |
||||
enableInsertMode = 32 |
||||
enableQuickEditMode = 64 |
||||
enableExtendedFlags = 128 |
||||
enableAutoPosition = 256 |
||||
enableProcessedOutput = 1 |
||||
enableWrapAtEolOutput = 2 |
||||
) |
||||
|
||||
var kernel32 = syscall.NewLazyDLL("kernel32.dll") |
||||
|
||||
var ( |
||||
procGetConsoleMode = kernel32.NewProc("GetConsoleMode") |
||||
procSetConsoleMode = kernel32.NewProc("SetConsoleMode") |
||||
procGetConsoleScreenBufferInfo = kernel32.NewProc("GetConsoleScreenBufferInfo") |
||||
) |
||||
|
||||
type ( |
||||
short int16 |
||||
word uint16 |
||||
|
||||
coord struct { |
||||
x short |
||||
y short |
||||
} |
||||
smallRect struct { |
||||
left short |
||||
top short |
||||
right short |
||||
bottom short |
||||
} |
||||
consoleScreenBufferInfo struct { |
||||
size coord |
||||
cursorPosition coord |
||||
attributes word |
||||
window smallRect |
||||
maximumWindowSize coord |
||||
} |
||||
) |
||||
|
||||
type State struct { |
||||
mode uint32 |
||||
} |
||||
|
||||
// IsTerminal returns true if the given file descriptor is a terminal.
|
||||
func IsTerminal(fd int) bool { |
||||
var st uint32 |
||||
r, _, e := syscall.Syscall(procGetConsoleMode.Addr(), 2, uintptr(fd), uintptr(unsafe.Pointer(&st)), 0) |
||||
return r != 0 && e == 0 |
||||
} |
||||
|
||||
// MakeRaw put the terminal connected to the given file descriptor into raw
|
||||
// mode and returns the previous state of the terminal so that it can be
|
||||
// restored.
|
||||
func MakeRaw(fd int) (*State, error) { |
||||
var st uint32 |
||||
_, _, e := syscall.Syscall(procGetConsoleMode.Addr(), 2, uintptr(fd), uintptr(unsafe.Pointer(&st)), 0) |
||||
if e != 0 { |
||||
return nil, error(e) |
||||
} |
||||
raw := st &^ (enableEchoInput | enableProcessedInput | enableLineInput | enableProcessedOutput) |
||||
_, _, e = syscall.Syscall(procSetConsoleMode.Addr(), 2, uintptr(fd), uintptr(raw), 0) |
||||
if e != 0 { |
||||
return nil, error(e) |
||||
} |
||||
return &State{st}, nil |
||||
} |
||||
|
||||
// GetState returns the current state of a terminal which may be useful to
|
||||
// restore the terminal after a signal.
|
||||
func GetState(fd int) (*State, error) { |
||||
var st uint32 |
||||
_, _, e := syscall.Syscall(procGetConsoleMode.Addr(), 2, uintptr(fd), uintptr(unsafe.Pointer(&st)), 0) |
||||
if e != 0 { |
||||
return nil, error(e) |
||||
} |
||||
return &State{st}, nil |
||||
} |
||||
|
||||
// Restore restores the terminal connected to the given file descriptor to a
|
||||
// previous state.
|
||||
func Restore(fd int, state *State) error { |
||||
_, _, err := syscall.Syscall(procSetConsoleMode.Addr(), 2, uintptr(fd), uintptr(state.mode), 0) |
||||
return err |
||||
} |
||||
|
||||
// GetSize returns the dimensions of the given terminal.
|
||||
func GetSize(fd int) (width, height int, err error) { |
||||
var info consoleScreenBufferInfo |
||||
_, _, e := syscall.Syscall(procGetConsoleScreenBufferInfo.Addr(), 2, uintptr(fd), uintptr(unsafe.Pointer(&info)), 0) |
||||
if e != 0 { |
||||
return 0, 0, error(e) |
||||
} |
||||
return int(info.size.x), int(info.size.y), nil |
||||
} |
||||
|
||||
// passwordReader is an io.Reader that reads from a specific Windows HANDLE.
|
||||
type passwordReader int |
||||
|
||||
func (r passwordReader) Read(buf []byte) (int, error) { |
||||
return syscall.Read(syscall.Handle(r), buf) |
||||
} |
||||
|
||||
// ReadPassword reads a line of input from a terminal without local echo. This
|
||||
// is commonly used for inputting passwords and other sensitive data. The slice
|
||||
// returned does not include the \n.
|
||||
func ReadPassword(fd int) ([]byte, error) { |
||||
var st uint32 |
||||
_, _, e := syscall.Syscall(procGetConsoleMode.Addr(), 2, uintptr(fd), uintptr(unsafe.Pointer(&st)), 0) |
||||
if e != 0 { |
||||
return nil, error(e) |
||||
} |
||||
old := st |
||||
|
||||
st &^= (enableEchoInput) |
||||
st |= (enableProcessedInput | enableLineInput | enableProcessedOutput) |
||||
_, _, e = syscall.Syscall(procSetConsoleMode.Addr(), 2, uintptr(fd), uintptr(st), 0) |
||||
if e != 0 { |
||||
return nil, error(e) |
||||
} |
||||
|
||||
defer func() { |
||||
syscall.Syscall(procSetConsoleMode.Addr(), 2, uintptr(fd), uintptr(old), 0) |
||||
}() |
||||
|
||||
return readPasswordLine(passwordReader(fd)) |
||||
} |
@ -0,0 +1,375 @@ |
||||
// Copyright 2011 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package ssh |
||||
|
||||
import ( |
||||
"bufio" |
||||
"errors" |
||||
"io" |
||||
"log" |
||||
) |
||||
|
||||
// debugTransport if set, will print packet types as they go over the
|
||||
// wire. No message decoding is done, to minimize the impact on timing.
|
||||
const debugTransport = false |
||||
|
||||
const ( |
||||
gcmCipherID = "aes128-gcm@openssh.com" |
||||
aes128cbcID = "aes128-cbc" |
||||
tripledescbcID = "3des-cbc" |
||||
) |
||||
|
||||
// packetConn represents a transport that implements packet based
|
||||
// operations.
|
||||
type packetConn interface { |
||||
// Encrypt and send a packet of data to the remote peer.
|
||||
writePacket(packet []byte) error |
||||
|
||||
// Read a packet from the connection. The read is blocking,
|
||||
// i.e. if error is nil, then the returned byte slice is
|
||||
// always non-empty.
|
||||
readPacket() ([]byte, error) |
||||
|
||||
// Close closes the write-side of the connection.
|
||||
Close() error |
||||
} |
||||
|
||||
// transport is the keyingTransport that implements the SSH packet
|
||||
// protocol.
|
||||
type transport struct { |
||||
reader connectionState |
||||
writer connectionState |
||||
|
||||
bufReader *bufio.Reader |
||||
bufWriter *bufio.Writer |
||||
rand io.Reader |
||||
isClient bool |
||||
io.Closer |
||||
} |
||||
|
||||
// packetCipher represents a combination of SSH encryption/MAC
|
||||
// protocol. A single instance should be used for one direction only.
|
||||
type packetCipher interface { |
||||
// writePacket encrypts the packet and writes it to w. The
|
||||
// contents of the packet are generally scrambled.
|
||||
writePacket(seqnum uint32, w io.Writer, rand io.Reader, packet []byte) error |
||||
|
||||
// readPacket reads and decrypts a packet of data. The
|
||||
// returned packet may be overwritten by future calls of
|
||||
// readPacket.
|
||||
readPacket(seqnum uint32, r io.Reader) ([]byte, error) |
||||
} |
||||
|
||||
// connectionState represents one side (read or write) of the
|
||||
// connection. This is necessary because each direction has its own
|
||||
// keys, and can even have its own algorithms
|
||||
type connectionState struct { |
||||
packetCipher |
||||
seqNum uint32 |
||||
dir direction |
||||
pendingKeyChange chan packetCipher |
||||
} |
||||
|
||||
// prepareKeyChange sets up key material for a keychange. The key changes in
|
||||
// both directions are triggered by reading and writing a msgNewKey packet
|
||||
// respectively.
|
||||
func (t *transport) prepareKeyChange(algs *algorithms, kexResult *kexResult) error { |
||||
if ciph, err := newPacketCipher(t.reader.dir, algs.r, kexResult); err != nil { |
||||
return err |
||||
} else { |
||||
t.reader.pendingKeyChange <- ciph |
||||
} |
||||
|
||||
if ciph, err := newPacketCipher(t.writer.dir, algs.w, kexResult); err != nil { |
||||
return err |
||||
} else { |
||||
t.writer.pendingKeyChange <- ciph |
||||
} |
||||
|
||||
return nil |
||||
} |
||||
|
||||
func (t *transport) printPacket(p []byte, write bool) { |
||||
if len(p) == 0 { |
||||
return |
||||
} |
||||
who := "server" |
||||
if t.isClient { |
||||
who = "client" |
||||
} |
||||
what := "read" |
||||
if write { |
||||
what = "write" |
||||
} |
||||
|
||||
log.Println(what, who, p[0]) |
||||
} |
||||
|
||||
// Read and decrypt next packet.
|
||||
func (t *transport) readPacket() (p []byte, err error) { |
||||
for { |
||||
p, err = t.reader.readPacket(t.bufReader) |
||||
if err != nil { |
||||
break |
||||
} |
||||
if len(p) == 0 || (p[0] != msgIgnore && p[0] != msgDebug) { |
||||
break |
||||
} |
||||
} |
||||
if debugTransport { |
||||
t.printPacket(p, false) |
||||
} |
||||
|
||||
return p, err |
||||
} |
||||
|
||||
func (s *connectionState) readPacket(r *bufio.Reader) ([]byte, error) { |
||||
packet, err := s.packetCipher.readPacket(s.seqNum, r) |
||||
s.seqNum++ |
||||
if err == nil && len(packet) == 0 { |
||||
err = errors.New("ssh: zero length packet") |
||||
} |
||||
|
||||
if len(packet) > 0 { |
||||
switch packet[0] { |
||||
case msgNewKeys: |
||||
select { |
||||
case cipher := <-s.pendingKeyChange: |
||||
s.packetCipher = cipher |
||||
default: |
||||
return nil, errors.New("ssh: got bogus newkeys message.") |
||||
} |
||||
|
||||
case msgDisconnect: |
||||
// Transform a disconnect message into an
|
||||
// error. Since this is lowest level at which
|
||||
// we interpret message types, doing it here
|
||||
// ensures that we don't have to handle it
|
||||
// elsewhere.
|
||||
var msg disconnectMsg |
||||
if err := Unmarshal(packet, &msg); err != nil { |
||||
return nil, err |
||||
} |
||||
return nil, &msg |
||||
} |
||||
} |
||||
|
||||
// The packet may point to an internal buffer, so copy the
|
||||
// packet out here.
|
||||
fresh := make([]byte, len(packet)) |
||||
copy(fresh, packet) |
||||
|
||||
return fresh, err |
||||
} |
||||
|
||||
func (t *transport) writePacket(packet []byte) error { |
||||
if debugTransport { |
||||
t.printPacket(packet, true) |
||||
} |
||||
return t.writer.writePacket(t.bufWriter, t.rand, packet) |
||||
} |
||||
|
||||
func (s *connectionState) writePacket(w *bufio.Writer, rand io.Reader, packet []byte) error { |
||||
changeKeys := len(packet) > 0 && packet[0] == msgNewKeys |
||||
|
||||
err := s.packetCipher.writePacket(s.seqNum, w, rand, packet) |
||||
if err != nil { |
||||
return err |
||||
} |
||||
if err = w.Flush(); err != nil { |
||||
return err |
||||
} |
||||
s.seqNum++ |
||||
if changeKeys { |
||||
select { |
||||
case cipher := <-s.pendingKeyChange: |
||||
s.packetCipher = cipher |
||||
default: |
||||
panic("ssh: no key material for msgNewKeys") |
||||
} |
||||
} |
||||
return err |
||||
} |
||||
|
||||
func newTransport(rwc io.ReadWriteCloser, rand io.Reader, isClient bool) *transport { |
||||
t := &transport{ |
||||
bufReader: bufio.NewReader(rwc), |
||||
bufWriter: bufio.NewWriter(rwc), |
||||
rand: rand, |
||||
reader: connectionState{ |
||||
packetCipher: &streamPacketCipher{cipher: noneCipher{}}, |
||||
pendingKeyChange: make(chan packetCipher, 1), |
||||
}, |
||||
writer: connectionState{ |
||||
packetCipher: &streamPacketCipher{cipher: noneCipher{}}, |
||||
pendingKeyChange: make(chan packetCipher, 1), |
||||
}, |
||||
Closer: rwc, |
||||
} |
||||
t.isClient = isClient |
||||
|
||||
if isClient { |
||||
t.reader.dir = serverKeys |
||||
t.writer.dir = clientKeys |
||||
} else { |
||||
t.reader.dir = clientKeys |
||||
t.writer.dir = serverKeys |
||||
} |
||||
|
||||
return t |
||||
} |
||||
|
||||
type direction struct { |
||||
ivTag []byte |
||||
keyTag []byte |
||||
macKeyTag []byte |
||||
} |
||||
|
||||
var ( |
||||
serverKeys = direction{[]byte{'B'}, []byte{'D'}, []byte{'F'}} |
||||
clientKeys = direction{[]byte{'A'}, []byte{'C'}, []byte{'E'}} |
||||
) |
||||
|
||||
// generateKeys generates key material for IV, MAC and encryption.
|
||||
func generateKeys(d direction, algs directionAlgorithms, kex *kexResult) (iv, key, macKey []byte) { |
||||
cipherMode := cipherModes[algs.Cipher] |
||||
macMode := macModes[algs.MAC] |
||||
|
||||
iv = make([]byte, cipherMode.ivSize) |
||||
key = make([]byte, cipherMode.keySize) |
||||
macKey = make([]byte, macMode.keySize) |
||||
|
||||
generateKeyMaterial(iv, d.ivTag, kex) |
||||
generateKeyMaterial(key, d.keyTag, kex) |
||||
generateKeyMaterial(macKey, d.macKeyTag, kex) |
||||
return |
||||
} |
||||
|
||||
// setupKeys sets the cipher and MAC keys from kex.K, kex.H and sessionId, as
|
||||
// described in RFC 4253, section 6.4. direction should either be serverKeys
|
||||
// (to setup server->client keys) or clientKeys (for client->server keys).
|
||||
func newPacketCipher(d direction, algs directionAlgorithms, kex *kexResult) (packetCipher, error) { |
||||
iv, key, macKey := generateKeys(d, algs, kex) |
||||
|
||||
if algs.Cipher == gcmCipherID { |
||||
return newGCMCipher(iv, key, macKey) |
||||
} |
||||
|
||||
if algs.Cipher == aes128cbcID { |
||||
return newAESCBCCipher(iv, key, macKey, algs) |
||||
} |
||||
|
||||
if algs.Cipher == tripledescbcID { |
||||
return newTripleDESCBCCipher(iv, key, macKey, algs) |
||||
} |
||||
|
||||
c := &streamPacketCipher{ |
||||
mac: macModes[algs.MAC].new(macKey), |
||||
etm: macModes[algs.MAC].etm, |
||||
} |
||||
c.macResult = make([]byte, c.mac.Size()) |
||||
|
||||
var err error |
||||
c.cipher, err = cipherModes[algs.Cipher].createStream(key, iv) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
return c, nil |
||||
} |
||||
|
||||
// generateKeyMaterial fills out with key material generated from tag, K, H
|
||||
// and sessionId, as specified in RFC 4253, section 7.2.
|
||||
func generateKeyMaterial(out, tag []byte, r *kexResult) { |
||||
var digestsSoFar []byte |
||||
|
||||
h := r.Hash.New() |
||||
for len(out) > 0 { |
||||
h.Reset() |
||||
h.Write(r.K) |
||||
h.Write(r.H) |
||||
|
||||
if len(digestsSoFar) == 0 { |
||||
h.Write(tag) |
||||
h.Write(r.SessionID) |
||||
} else { |
||||
h.Write(digestsSoFar) |
||||
} |
||||
|
||||
digest := h.Sum(nil) |
||||
n := copy(out, digest) |
||||
out = out[n:] |
||||
if len(out) > 0 { |
||||
digestsSoFar = append(digestsSoFar, digest...) |
||||
} |
||||
} |
||||
} |
||||
|
||||
const packageVersion = "SSH-2.0-Go" |
||||
|
||||
// Sends and receives a version line. The versionLine string should
|
||||
// be US ASCII, start with "SSH-2.0-", and should not include a
|
||||
// newline. exchangeVersions returns the other side's version line.
|
||||
func exchangeVersions(rw io.ReadWriter, versionLine []byte) (them []byte, err error) { |
||||
// Contrary to the RFC, we do not ignore lines that don't
|
||||
// start with "SSH-2.0-" to make the library usable with
|
||||
// nonconforming servers.
|
||||
for _, c := range versionLine { |
||||
// The spec disallows non US-ASCII chars, and
|
||||
// specifically forbids null chars.
|
||||
if c < 32 { |
||||
return nil, errors.New("ssh: junk character in version line") |
||||
} |
||||
} |
||||
if _, err = rw.Write(append(versionLine, '\r', '\n')); err != nil { |
||||
return |
||||
} |
||||
|
||||
them, err = readVersion(rw) |
||||
return them, err |
||||
} |
||||
|
||||
// maxVersionStringBytes is the maximum number of bytes that we'll
|
||||
// accept as a version string. RFC 4253 section 4.2 limits this at 255
|
||||
// chars
|
||||
const maxVersionStringBytes = 255 |
||||
|
||||
// Read version string as specified by RFC 4253, section 4.2.
|
||||
func readVersion(r io.Reader) ([]byte, error) { |
||||
versionString := make([]byte, 0, 64) |
||||
var ok bool |
||||
var buf [1]byte |
||||
|
||||
for len(versionString) < maxVersionStringBytes { |
||||
_, err := io.ReadFull(r, buf[:]) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
// The RFC says that the version should be terminated with \r\n
|
||||
// but several SSH servers actually only send a \n.
|
||||
if buf[0] == '\n' { |
||||
ok = true |
||||
break |
||||
} |
||||
|
||||
// non ASCII chars are disallowed, but we are lenient,
|
||||
// since Go doesn't use null-terminated strings.
|
||||
|
||||
// The RFC allows a comment after a space, however,
|
||||
// all of it (version and comments) goes into the
|
||||
// session hash.
|
||||
versionString = append(versionString, buf[0]) |
||||
} |
||||
|
||||
if !ok { |
||||
return nil, errors.New("ssh: overflow reading version string") |
||||
} |
||||
|
||||
// There might be a '\r' on the end which we should remove.
|
||||
if len(versionString) > 0 && versionString[len(versionString)-1] == '\r' { |
||||
versionString = versionString[:len(versionString)-1] |
||||
} |
||||
return versionString, nil |
||||
} |
Loading…
Reference in new issue