mirror of https://github.com/ethereum/go-ethereum
commit
e64f727529
@ -0,0 +1,6 @@ |
||||
language: go |
||||
go: 1.1 |
||||
|
||||
script: |
||||
- go vet ./... |
||||
- go test -v ./... |
@ -0,0 +1,21 @@ |
||||
Copyright (C) 2013 Jeremy Saenz |
||||
All Rights Reserved. |
||||
|
||||
MIT LICENSE |
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy of |
||||
this software and associated documentation files (the "Software"), to deal in |
||||
the Software without restriction, including without limitation the rights to |
||||
use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of |
||||
the Software, and to permit persons to whom the Software is furnished to do so, |
||||
subject to the following conditions: |
||||
|
||||
The above copyright notice and this permission notice shall be included in all |
||||
copies or substantial portions of the Software. |
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR |
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS |
||||
FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR |
||||
COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER |
||||
IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN |
||||
CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. |
@ -0,0 +1,298 @@ |
||||
[![Build Status](https://travis-ci.org/codegangsta/cli.png?branch=master)](https://travis-ci.org/codegangsta/cli) |
||||
|
||||
# cli.go |
||||
cli.go is simple, fast, and fun package for building command line apps in Go. The goal is to enable developers to write fast and distributable command line applications in an expressive way. |
||||
|
||||
You can view the API docs here: |
||||
http://godoc.org/github.com/codegangsta/cli |
||||
|
||||
## Overview |
||||
Command line apps are usually so tiny that there is absolutely no reason why your code should *not* be self-documenting. Things like generating help text and parsing command flags/options should not hinder productivity when writing a command line app. |
||||
|
||||
**This is where cli.go comes into play.** cli.go makes command line programming fun, organized, and expressive! |
||||
|
||||
## Installation |
||||
Make sure you have a working Go environment (go 1.1 is *required*). [See the install instructions](http://golang.org/doc/install.html). |
||||
|
||||
To install `cli.go`, simply run: |
||||
``` |
||||
$ go get github.com/codegangsta/cli |
||||
``` |
||||
|
||||
Make sure your `PATH` includes to the `$GOPATH/bin` directory so your commands can be easily used: |
||||
``` |
||||
export PATH=$PATH:$GOPATH/bin |
||||
``` |
||||
|
||||
## Getting Started |
||||
One of the philosophies behind cli.go is that an API should be playful and full of discovery. So a cli.go app can be as little as one line of code in `main()`. |
||||
|
||||
``` go |
||||
package main |
||||
|
||||
import ( |
||||
"os" |
||||
"github.com/codegangsta/cli" |
||||
) |
||||
|
||||
func main() { |
||||
cli.NewApp().Run(os.Args) |
||||
} |
||||
``` |
||||
|
||||
This app will run and show help text, but is not very useful. Let's give an action to execute and some help documentation: |
||||
|
||||
``` go |
||||
package main |
||||
|
||||
import ( |
||||
"os" |
||||
"github.com/codegangsta/cli" |
||||
) |
||||
|
||||
func main() { |
||||
app := cli.NewApp() |
||||
app.Name = "boom" |
||||
app.Usage = "make an explosive entrance" |
||||
app.Action = func(c *cli.Context) { |
||||
println("boom! I say!") |
||||
} |
||||
|
||||
app.Run(os.Args) |
||||
} |
||||
``` |
||||
|
||||
Running this already gives you a ton of functionality, plus support for things like subcommands and flags, which are covered below. |
||||
|
||||
## Example |
||||
|
||||
Being a programmer can be a lonely job. Thankfully by the power of automation that is not the case! Let's create a greeter app to fend off our demons of loneliness! |
||||
|
||||
Start by creating a directory named `greet`, and within it, add a file, `greet.go` with the following code in it: |
||||
|
||||
``` go |
||||
package main |
||||
|
||||
import ( |
||||
"os" |
||||
"github.com/codegangsta/cli" |
||||
) |
||||
|
||||
func main() { |
||||
app := cli.NewApp() |
||||
app.Name = "greet" |
||||
app.Usage = "fight the loneliness!" |
||||
app.Action = func(c *cli.Context) { |
||||
println("Hello friend!") |
||||
} |
||||
|
||||
app.Run(os.Args) |
||||
} |
||||
``` |
||||
|
||||
Install our command to the `$GOPATH/bin` directory: |
||||
|
||||
``` |
||||
$ go install |
||||
``` |
||||
|
||||
Finally run our new command: |
||||
|
||||
``` |
||||
$ greet |
||||
Hello friend! |
||||
``` |
||||
|
||||
cli.go also generates some bitchass help text: |
||||
``` |
||||
$ greet help |
||||
NAME: |
||||
greet - fight the loneliness! |
||||
|
||||
USAGE: |
||||
greet [global options] command [command options] [arguments...] |
||||
|
||||
VERSION: |
||||
0.0.0 |
||||
|
||||
COMMANDS: |
||||
help, h Shows a list of commands or help for one command |
||||
|
||||
GLOBAL OPTIONS |
||||
--version Shows version information |
||||
``` |
||||
|
||||
### Arguments |
||||
You can lookup arguments by calling the `Args` function on `cli.Context`. |
||||
|
||||
``` go |
||||
... |
||||
app.Action = func(c *cli.Context) { |
||||
println("Hello", c.Args()[0]) |
||||
} |
||||
... |
||||
``` |
||||
|
||||
### Flags |
||||
Setting and querying flags is simple. |
||||
``` go |
||||
... |
||||
app.Flags = []cli.Flag { |
||||
cli.StringFlag{ |
||||
Name: "lang", |
||||
Value: "english", |
||||
Usage: "language for the greeting", |
||||
}, |
||||
} |
||||
app.Action = func(c *cli.Context) { |
||||
name := "someone" |
||||
if len(c.Args()) > 0 { |
||||
name = c.Args()[0] |
||||
} |
||||
if c.String("lang") == "spanish" { |
||||
println("Hola", name) |
||||
} else { |
||||
println("Hello", name) |
||||
} |
||||
} |
||||
... |
||||
``` |
||||
|
||||
#### Alternate Names |
||||
|
||||
You can set alternate (or short) names for flags by providing a comma-delimited list for the `Name`. e.g. |
||||
|
||||
``` go |
||||
app.Flags = []cli.Flag { |
||||
cli.StringFlag{ |
||||
Name: "lang, l", |
||||
Value: "english", |
||||
Usage: "language for the greeting", |
||||
}, |
||||
} |
||||
``` |
||||
|
||||
That flag can then be set with `--lang spanish` or `-l spanish`. Note that giving two different forms of the same flag in the same command invocation is an error. |
||||
|
||||
#### Values from the Environment |
||||
|
||||
You can also have the default value set from the environment via `EnvVar`. e.g. |
||||
|
||||
``` go |
||||
app.Flags = []cli.Flag { |
||||
cli.StringFlag{ |
||||
Name: "lang, l", |
||||
Value: "english", |
||||
Usage: "language for the greeting", |
||||
EnvVar: "APP_LANG", |
||||
}, |
||||
} |
||||
``` |
||||
|
||||
The `EnvVar` may also be given as a comma-delimited "cascade", where the first environment variable that resolves is used as the default. |
||||
|
||||
``` go |
||||
app.Flags = []cli.Flag { |
||||
cli.StringFlag{ |
||||
Name: "lang, l", |
||||
Value: "english", |
||||
Usage: "language for the greeting", |
||||
EnvVar: "LEGACY_COMPAT_LANG,APP_LANG,LANG", |
||||
}, |
||||
} |
||||
``` |
||||
|
||||
### Subcommands |
||||
|
||||
Subcommands can be defined for a more git-like command line app. |
||||
```go |
||||
... |
||||
app.Commands = []cli.Command{ |
||||
{ |
||||
Name: "add", |
||||
ShortName: "a", |
||||
Usage: "add a task to the list", |
||||
Action: func(c *cli.Context) { |
||||
println("added task: ", c.Args().First()) |
||||
}, |
||||
}, |
||||
{ |
||||
Name: "complete", |
||||
ShortName: "c", |
||||
Usage: "complete a task on the list", |
||||
Action: func(c *cli.Context) { |
||||
println("completed task: ", c.Args().First()) |
||||
}, |
||||
}, |
||||
{ |
||||
Name: "template", |
||||
ShortName: "r", |
||||
Usage: "options for task templates", |
||||
Subcommands: []cli.Command{ |
||||
{ |
||||
Name: "add", |
||||
Usage: "add a new template", |
||||
Action: func(c *cli.Context) { |
||||
println("new task template: ", c.Args().First()) |
||||
}, |
||||
}, |
||||
{ |
||||
Name: "remove", |
||||
Usage: "remove an existing template", |
||||
Action: func(c *cli.Context) { |
||||
println("removed task template: ", c.Args().First()) |
||||
}, |
||||
}, |
||||
}, |
||||
}, |
||||
} |
||||
... |
||||
``` |
||||
|
||||
### Bash Completion |
||||
|
||||
You can enable completion commands by setting the `EnableBashCompletion` |
||||
flag on the `App` object. By default, this setting will only auto-complete to |
||||
show an app's subcommands, but you can write your own completion methods for |
||||
the App or its subcommands. |
||||
```go |
||||
... |
||||
var tasks = []string{"cook", "clean", "laundry", "eat", "sleep", "code"} |
||||
app := cli.NewApp() |
||||
app.EnableBashCompletion = true |
||||
app.Commands = []cli.Command{ |
||||
{ |
||||
Name: "complete", |
||||
ShortName: "c", |
||||
Usage: "complete a task on the list", |
||||
Action: func(c *cli.Context) { |
||||
println("completed task: ", c.Args().First()) |
||||
}, |
||||
BashComplete: func(c *cli.Context) { |
||||
// This will complete if no args are passed |
||||
if len(c.Args()) > 0 { |
||||
return |
||||
} |
||||
for _, t := range tasks { |
||||
fmt.Println(t) |
||||
} |
||||
}, |
||||
} |
||||
} |
||||
... |
||||
``` |
||||
|
||||
#### To Enable |
||||
|
||||
Source the `autocomplete/bash_autocomplete` file in your `.bashrc` file while |
||||
setting the `PROG` variable to the name of your program: |
||||
|
||||
`PROG=myprogram source /.../cli/autocomplete/bash_autocomplete` |
||||
|
||||
|
||||
## Contribution Guidelines |
||||
Feel free to put up a pull request to fix a bug or maybe add a feature. I will give it a code review and make sure that it does not break backwards compatibility. If I or any other collaborators agree that it is in line with the vision of the project, we will work with you to get the code into a mergeable state and merge it into the master branch. |
||||
|
||||
If you have contributed something significant to the project, I will most likely add you as a collaborator. As a collaborator you are given the ability to merge others pull requests. It is very important that new code does not break existing code, so be careful about what code you do choose to merge. If you have any questions feel free to link @codegangsta to the issue in question and we can review it together. |
||||
|
||||
If you feel like you have contributed to the project but have not yet been added as a collaborator, I probably forgot to add you. Hit @codegangsta up over email and we will get it figured out. |
@ -0,0 +1,296 @@ |
||||
package cli |
||||
|
||||
import ( |
||||
"fmt" |
||||
"io" |
||||
"io/ioutil" |
||||
"os" |
||||
"text/tabwriter" |
||||
"text/template" |
||||
"time" |
||||
) |
||||
|
||||
// App is the main structure of a cli application. It is recomended that
|
||||
// and app be created with the cli.NewApp() function
|
||||
type App struct { |
||||
// The name of the program. Defaults to os.Args[0]
|
||||
Name string |
||||
// Description of the program.
|
||||
Usage string |
||||
// Version of the program
|
||||
Version string |
||||
// List of commands to execute
|
||||
Commands []Command |
||||
// List of flags to parse
|
||||
Flags []Flag |
||||
// Boolean to enable bash completion commands
|
||||
EnableBashCompletion bool |
||||
// Boolean to hide built-in help command
|
||||
HideHelp bool |
||||
// Boolean to hide built-in version flag
|
||||
HideVersion bool |
||||
// An action to execute when the bash-completion flag is set
|
||||
BashComplete func(context *Context) |
||||
// An action to execute before any subcommands are run, but after the context is ready
|
||||
// If a non-nil error is returned, no subcommands are run
|
||||
Before func(context *Context) error |
||||
// An action to execute after any subcommands are run, but after the subcommand has finished
|
||||
// It is run even if Action() panics
|
||||
After func(context *Context) error |
||||
// The action to execute when no subcommands are specified
|
||||
Action func(context *Context) |
||||
// Execute this function if the proper command cannot be found
|
||||
CommandNotFound func(context *Context, command string) |
||||
// Compilation date
|
||||
Compiled time.Time |
||||
// Author
|
||||
Author string |
||||
// Author e-mail
|
||||
Email string |
||||
// Writer writer to write output to
|
||||
Writer io.Writer |
||||
} |
||||
|
||||
// Tries to find out when this binary was compiled.
|
||||
// Returns the current time if it fails to find it.
|
||||
func compileTime() time.Time { |
||||
info, err := os.Stat(os.Args[0]) |
||||
if err != nil { |
||||
return time.Now() |
||||
} |
||||
return info.ModTime() |
||||
} |
||||
|
||||
// Creates a new cli Application with some reasonable defaults for Name, Usage, Version and Action.
|
||||
func NewApp() *App { |
||||
return &App{ |
||||
Name: os.Args[0], |
||||
Usage: "A new cli application", |
||||
Version: "0.0.0", |
||||
BashComplete: DefaultAppComplete, |
||||
Action: helpCommand.Action, |
||||
Compiled: compileTime(), |
||||
Author: "Author", |
||||
Email: "unknown@email", |
||||
Writer: os.Stdout, |
||||
} |
||||
} |
||||
|
||||
// Entry point to the cli app. Parses the arguments slice and routes to the proper flag/args combination
|
||||
func (a *App) Run(arguments []string) (err error) { |
||||
if HelpPrinter == nil { |
||||
defer func() { |
||||
HelpPrinter = nil |
||||
}() |
||||
|
||||
HelpPrinter = func(templ string, data interface{}) { |
||||
w := tabwriter.NewWriter(a.Writer, 0, 8, 1, '\t', 0) |
||||
t := template.Must(template.New("help").Parse(templ)) |
||||
err := t.Execute(w, data) |
||||
if err != nil { |
||||
panic(err) |
||||
} |
||||
w.Flush() |
||||
} |
||||
} |
||||
|
||||
// append help to commands
|
||||
if a.Command(helpCommand.Name) == nil && !a.HideHelp { |
||||
a.Commands = append(a.Commands, helpCommand) |
||||
if (HelpFlag != BoolFlag{}) { |
||||
a.appendFlag(HelpFlag) |
||||
} |
||||
} |
||||
|
||||
//append version/help flags
|
||||
if a.EnableBashCompletion { |
||||
a.appendFlag(BashCompletionFlag) |
||||
} |
||||
|
||||
if !a.HideVersion { |
||||
a.appendFlag(VersionFlag) |
||||
} |
||||
|
||||
// parse flags
|
||||
set := flagSet(a.Name, a.Flags) |
||||
set.SetOutput(ioutil.Discard) |
||||
err = set.Parse(arguments[1:]) |
||||
nerr := normalizeFlags(a.Flags, set) |
||||
if nerr != nil { |
||||
fmt.Fprintln(a.Writer, nerr) |
||||
context := NewContext(a, set, set) |
||||
ShowAppHelp(context) |
||||
fmt.Fprintln(a.Writer) |
||||
return nerr |
||||
} |
||||
context := NewContext(a, set, set) |
||||
|
||||
if err != nil { |
||||
fmt.Fprintf(a.Writer, "Incorrect Usage.\n\n") |
||||
ShowAppHelp(context) |
||||
fmt.Fprintln(a.Writer) |
||||
return err |
||||
} |
||||
|
||||
if checkCompletions(context) { |
||||
return nil |
||||
} |
||||
|
||||
if checkHelp(context) { |
||||
return nil |
||||
} |
||||
|
||||
if checkVersion(context) { |
||||
return nil |
||||
} |
||||
|
||||
if a.After != nil { |
||||
defer func() { |
||||
// err is always nil here.
|
||||
// There is a check to see if it is non-nil
|
||||
// just few lines before.
|
||||
err = a.After(context) |
||||
}() |
||||
} |
||||
|
||||
if a.Before != nil { |
||||
err := a.Before(context) |
||||
if err != nil { |
||||
return err |
||||
} |
||||
} |
||||
|
||||
args := context.Args() |
||||
if args.Present() { |
||||
name := args.First() |
||||
c := a.Command(name) |
||||
if c != nil { |
||||
return c.Run(context) |
||||
} |
||||
} |
||||
|
||||
// Run default Action
|
||||
a.Action(context) |
||||
return nil |
||||
} |
||||
|
||||
// Another entry point to the cli app, takes care of passing arguments and error handling
|
||||
func (a *App) RunAndExitOnError() { |
||||
if err := a.Run(os.Args); err != nil { |
||||
fmt.Fprintln(os.Stderr, err) |
||||
os.Exit(1) |
||||
} |
||||
} |
||||
|
||||
// Invokes the subcommand given the context, parses ctx.Args() to generate command-specific flags
|
||||
func (a *App) RunAsSubcommand(ctx *Context) (err error) { |
||||
// append help to commands
|
||||
if len(a.Commands) > 0 { |
||||
if a.Command(helpCommand.Name) == nil && !a.HideHelp { |
||||
a.Commands = append(a.Commands, helpCommand) |
||||
if (HelpFlag != BoolFlag{}) { |
||||
a.appendFlag(HelpFlag) |
||||
} |
||||
} |
||||
} |
||||
|
||||
// append flags
|
||||
if a.EnableBashCompletion { |
||||
a.appendFlag(BashCompletionFlag) |
||||
} |
||||
|
||||
// parse flags
|
||||
set := flagSet(a.Name, a.Flags) |
||||
set.SetOutput(ioutil.Discard) |
||||
err = set.Parse(ctx.Args().Tail()) |
||||
nerr := normalizeFlags(a.Flags, set) |
||||
context := NewContext(a, set, ctx.globalSet) |
||||
|
||||
if nerr != nil { |
||||
fmt.Fprintln(a.Writer, nerr) |
||||
if len(a.Commands) > 0 { |
||||
ShowSubcommandHelp(context) |
||||
} else { |
||||
ShowCommandHelp(ctx, context.Args().First()) |
||||
} |
||||
fmt.Fprintln(a.Writer) |
||||
return nerr |
||||
} |
||||
|
||||
if err != nil { |
||||
fmt.Fprintf(a.Writer, "Incorrect Usage.\n\n") |
||||
ShowSubcommandHelp(context) |
||||
return err |
||||
} |
||||
|
||||
if checkCompletions(context) { |
||||
return nil |
||||
} |
||||
|
||||
if len(a.Commands) > 0 { |
||||
if checkSubcommandHelp(context) { |
||||
return nil |
||||
} |
||||
} else { |
||||
if checkCommandHelp(ctx, context.Args().First()) { |
||||
return nil |
||||
} |
||||
} |
||||
|
||||
if a.After != nil { |
||||
defer func() { |
||||
// err is always nil here.
|
||||
// There is a check to see if it is non-nil
|
||||
// just few lines before.
|
||||
err = a.After(context) |
||||
}() |
||||
} |
||||
|
||||
if a.Before != nil { |
||||
err := a.Before(context) |
||||
if err != nil { |
||||
return err |
||||
} |
||||
} |
||||
|
||||
args := context.Args() |
||||
if args.Present() { |
||||
name := args.First() |
||||
c := a.Command(name) |
||||
if c != nil { |
||||
return c.Run(context) |
||||
} |
||||
} |
||||
|
||||
// Run default Action
|
||||
a.Action(context) |
||||
|
||||
return nil |
||||
} |
||||
|
||||
// Returns the named command on App. Returns nil if the command does not exist
|
||||
func (a *App) Command(name string) *Command { |
||||
for _, c := range a.Commands { |
||||
if c.HasName(name) { |
||||
return &c |
||||
} |
||||
} |
||||
|
||||
return nil |
||||
} |
||||
|
||||
func (a *App) hasFlag(flag Flag) bool { |
||||
for _, f := range a.Flags { |
||||
if flag == f { |
||||
return true |
||||
} |
||||
} |
||||
|
||||
return false |
||||
} |
||||
|
||||
func (a *App) appendFlag(flag Flag) { |
||||
if !a.hasFlag(flag) { |
||||
a.Flags = append(a.Flags, flag) |
||||
} |
||||
} |
@ -0,0 +1,619 @@ |
||||
package cli_test |
||||
|
||||
import ( |
||||
"flag" |
||||
"fmt" |
||||
"os" |
||||
"testing" |
||||
|
||||
"github.com/codegangsta/cli" |
||||
) |
||||
|
||||
func ExampleApp() { |
||||
// set args for examples sake
|
||||
os.Args = []string{"greet", "--name", "Jeremy"} |
||||
|
||||
app := cli.NewApp() |
||||
app.Name = "greet" |
||||
app.Flags = []cli.Flag{ |
||||
cli.StringFlag{Name: "name", Value: "bob", Usage: "a name to say"}, |
||||
} |
||||
app.Action = func(c *cli.Context) { |
||||
fmt.Printf("Hello %v\n", c.String("name")) |
||||
} |
||||
app.Run(os.Args) |
||||
// Output:
|
||||
// Hello Jeremy
|
||||
} |
||||
|
||||
func ExampleAppSubcommand() { |
||||
// set args for examples sake
|
||||
os.Args = []string{"say", "hi", "english", "--name", "Jeremy"} |
||||
app := cli.NewApp() |
||||
app.Name = "say" |
||||
app.Commands = []cli.Command{ |
||||
{ |
||||
Name: "hello", |
||||
ShortName: "hi", |
||||
Usage: "use it to see a description", |
||||
Description: "This is how we describe hello the function", |
||||
Subcommands: []cli.Command{ |
||||
{ |
||||
Name: "english", |
||||
ShortName: "en", |
||||
Usage: "sends a greeting in english", |
||||
Description: "greets someone in english", |
||||
Flags: []cli.Flag{ |
||||
cli.StringFlag{ |
||||
Name: "name", |
||||
Value: "Bob", |
||||
Usage: "Name of the person to greet", |
||||
}, |
||||
}, |
||||
Action: func(c *cli.Context) { |
||||
fmt.Println("Hello,", c.String("name")) |
||||
}, |
||||
}, |
||||
}, |
||||
}, |
||||
} |
||||
|
||||
app.Run(os.Args) |
||||
// Output:
|
||||
// Hello, Jeremy
|
||||
} |
||||
|
||||
func ExampleAppHelp() { |
||||
// set args for examples sake
|
||||
os.Args = []string{"greet", "h", "describeit"} |
||||
|
||||
app := cli.NewApp() |
||||
app.Name = "greet" |
||||
app.Flags = []cli.Flag{ |
||||
cli.StringFlag{Name: "name", Value: "bob", Usage: "a name to say"}, |
||||
} |
||||
app.Commands = []cli.Command{ |
||||
{ |
||||
Name: "describeit", |
||||
ShortName: "d", |
||||
Usage: "use it to see a description", |
||||
Description: "This is how we describe describeit the function", |
||||
Action: func(c *cli.Context) { |
||||
fmt.Printf("i like to describe things") |
||||
}, |
||||
}, |
||||
} |
||||
app.Run(os.Args) |
||||
// Output:
|
||||
// NAME:
|
||||
// describeit - use it to see a description
|
||||
//
|
||||
// USAGE:
|
||||
// command describeit [arguments...]
|
||||
//
|
||||
// DESCRIPTION:
|
||||
// This is how we describe describeit the function
|
||||
} |
||||
|
||||
func ExampleAppBashComplete() { |
||||
// set args for examples sake
|
||||
os.Args = []string{"greet", "--generate-bash-completion"} |
||||
|
||||
app := cli.NewApp() |
||||
app.Name = "greet" |
||||
app.EnableBashCompletion = true |
||||
app.Commands = []cli.Command{ |
||||
{ |
||||
Name: "describeit", |
||||
ShortName: "d", |
||||
Usage: "use it to see a description", |
||||
Description: "This is how we describe describeit the function", |
||||
Action: func(c *cli.Context) { |
||||
fmt.Printf("i like to describe things") |
||||
}, |
||||
}, { |
||||
Name: "next", |
||||
Usage: "next example", |
||||
Description: "more stuff to see when generating bash completion", |
||||
Action: func(c *cli.Context) { |
||||
fmt.Printf("the next example") |
||||
}, |
||||
}, |
||||
} |
||||
|
||||
app.Run(os.Args) |
||||
// Output:
|
||||
// describeit
|
||||
// d
|
||||
// next
|
||||
// help
|
||||
// h
|
||||
} |
||||
|
||||
func TestApp_Run(t *testing.T) { |
||||
s := "" |
||||
|
||||
app := cli.NewApp() |
||||
app.Action = func(c *cli.Context) { |
||||
s = s + c.Args().First() |
||||
} |
||||
|
||||
err := app.Run([]string{"command", "foo"}) |
||||
expect(t, err, nil) |
||||
err = app.Run([]string{"command", "bar"}) |
||||
expect(t, err, nil) |
||||
expect(t, s, "foobar") |
||||
} |
||||
|
||||
var commandAppTests = []struct { |
||||
name string |
||||
expected bool |
||||
}{ |
||||
{"foobar", true}, |
||||
{"batbaz", true}, |
||||
{"b", true}, |
||||
{"f", true}, |
||||
{"bat", false}, |
||||
{"nothing", false}, |
||||
} |
||||
|
||||
func TestApp_Command(t *testing.T) { |
||||
app := cli.NewApp() |
||||
fooCommand := cli.Command{Name: "foobar", ShortName: "f"} |
||||
batCommand := cli.Command{Name: "batbaz", ShortName: "b"} |
||||
app.Commands = []cli.Command{ |
||||
fooCommand, |
||||
batCommand, |
||||
} |
||||
|
||||
for _, test := range commandAppTests { |
||||
expect(t, app.Command(test.name) != nil, test.expected) |
||||
} |
||||
} |
||||
|
||||
func TestApp_CommandWithArgBeforeFlags(t *testing.T) { |
||||
var parsedOption, firstArg string |
||||
|
||||
app := cli.NewApp() |
||||
command := cli.Command{ |
||||
Name: "cmd", |
||||
Flags: []cli.Flag{ |
||||
cli.StringFlag{Name: "option", Value: "", Usage: "some option"}, |
||||
}, |
||||
Action: func(c *cli.Context) { |
||||
parsedOption = c.String("option") |
||||
firstArg = c.Args().First() |
||||
}, |
||||
} |
||||
app.Commands = []cli.Command{command} |
||||
|
||||
app.Run([]string{"", "cmd", "my-arg", "--option", "my-option"}) |
||||
|
||||
expect(t, parsedOption, "my-option") |
||||
expect(t, firstArg, "my-arg") |
||||
} |
||||
|
||||
func TestApp_RunAsSubcommandParseFlags(t *testing.T) { |
||||
var context *cli.Context |
||||
|
||||
a := cli.NewApp() |
||||
a.Commands = []cli.Command{ |
||||
{ |
||||
Name: "foo", |
||||
Action: func(c *cli.Context) { |
||||
context = c |
||||
}, |
||||
Flags: []cli.Flag{ |
||||
cli.StringFlag{ |
||||
Name: "lang", |
||||
Value: "english", |
||||
Usage: "language for the greeting", |
||||
}, |
||||
}, |
||||
Before: func(_ *cli.Context) error { return nil }, |
||||
}, |
||||
} |
||||
a.Run([]string{"", "foo", "--lang", "spanish", "abcd"}) |
||||
|
||||
expect(t, context.Args().Get(0), "abcd") |
||||
expect(t, context.String("lang"), "spanish") |
||||
} |
||||
|
||||
func TestApp_CommandWithFlagBeforeTerminator(t *testing.T) { |
||||
var parsedOption string |
||||
var args []string |
||||
|
||||
app := cli.NewApp() |
||||
command := cli.Command{ |
||||
Name: "cmd", |
||||
Flags: []cli.Flag{ |
||||
cli.StringFlag{Name: "option", Value: "", Usage: "some option"}, |
||||
}, |
||||
Action: func(c *cli.Context) { |
||||
parsedOption = c.String("option") |
||||
args = c.Args() |
||||
}, |
||||
} |
||||
app.Commands = []cli.Command{command} |
||||
|
||||
app.Run([]string{"", "cmd", "my-arg", "--option", "my-option", "--", "--notARealFlag"}) |
||||
|
||||
expect(t, parsedOption, "my-option") |
||||
expect(t, args[0], "my-arg") |
||||
expect(t, args[1], "--") |
||||
expect(t, args[2], "--notARealFlag") |
||||
} |
||||
|
||||
func TestApp_CommandWithNoFlagBeforeTerminator(t *testing.T) { |
||||
var args []string |
||||
|
||||
app := cli.NewApp() |
||||
command := cli.Command{ |
||||
Name: "cmd", |
||||
Action: func(c *cli.Context) { |
||||
args = c.Args() |
||||
}, |
||||
} |
||||
app.Commands = []cli.Command{command} |
||||
|
||||
app.Run([]string{"", "cmd", "my-arg", "--", "notAFlagAtAll"}) |
||||
|
||||
expect(t, args[0], "my-arg") |
||||
expect(t, args[1], "--") |
||||
expect(t, args[2], "notAFlagAtAll") |
||||
} |
||||
|
||||
func TestApp_Float64Flag(t *testing.T) { |
||||
var meters float64 |
||||
|
||||
app := cli.NewApp() |
||||
app.Flags = []cli.Flag{ |
||||
cli.Float64Flag{Name: "height", Value: 1.5, Usage: "Set the height, in meters"}, |
||||
} |
||||
app.Action = func(c *cli.Context) { |
||||
meters = c.Float64("height") |
||||
} |
||||
|
||||
app.Run([]string{"", "--height", "1.93"}) |
||||
expect(t, meters, 1.93) |
||||
} |
||||
|
||||
func TestApp_ParseSliceFlags(t *testing.T) { |
||||
var parsedOption, firstArg string |
||||
var parsedIntSlice []int |
||||
var parsedStringSlice []string |
||||
|
||||
app := cli.NewApp() |
||||
command := cli.Command{ |
||||
Name: "cmd", |
||||
Flags: []cli.Flag{ |
||||
cli.IntSliceFlag{Name: "p", Value: &cli.IntSlice{}, Usage: "set one or more ip addr"}, |
||||
cli.StringSliceFlag{Name: "ip", Value: &cli.StringSlice{}, Usage: "set one or more ports to open"}, |
||||
}, |
||||
Action: func(c *cli.Context) { |
||||
parsedIntSlice = c.IntSlice("p") |
||||
parsedStringSlice = c.StringSlice("ip") |
||||
parsedOption = c.String("option") |
||||
firstArg = c.Args().First() |
||||
}, |
||||
} |
||||
app.Commands = []cli.Command{command} |
||||
|
||||
app.Run([]string{"", "cmd", "my-arg", "-p", "22", "-p", "80", "-ip", "8.8.8.8", "-ip", "8.8.4.4"}) |
||||
|
||||
IntsEquals := func(a, b []int) bool { |
||||
if len(a) != len(b) { |
||||
return false |
||||
} |
||||
for i, v := range a { |
||||
if v != b[i] { |
||||
return false |
||||
} |
||||
} |
||||
return true |
||||
} |
||||
|
||||
StrsEquals := func(a, b []string) bool { |
||||
if len(a) != len(b) { |
||||
return false |
||||
} |
||||
for i, v := range a { |
||||
if v != b[i] { |
||||
return false |
||||
} |
||||
} |
||||
return true |
||||
} |
||||
var expectedIntSlice = []int{22, 80} |
||||
var expectedStringSlice = []string{"8.8.8.8", "8.8.4.4"} |
||||
|
||||
if !IntsEquals(parsedIntSlice, expectedIntSlice) { |
||||
t.Errorf("%v does not match %v", parsedIntSlice, expectedIntSlice) |
||||
} |
||||
|
||||
if !StrsEquals(parsedStringSlice, expectedStringSlice) { |
||||
t.Errorf("%v does not match %v", parsedStringSlice, expectedStringSlice) |
||||
} |
||||
} |
||||
|
||||
func TestApp_DefaultStdout(t *testing.T) { |
||||
app := cli.NewApp() |
||||
|
||||
if app.Writer != os.Stdout { |
||||
t.Error("Default output writer not set.") |
||||
} |
||||
} |
||||
|
||||
type mockWriter struct { |
||||
written []byte |
||||
} |
||||
|
||||
func (fw *mockWriter) Write(p []byte) (n int, err error) { |
||||
if fw.written == nil { |
||||
fw.written = p |
||||
} else { |
||||
fw.written = append(fw.written, p...) |
||||
} |
||||
|
||||
return len(p), nil |
||||
} |
||||
|
||||
func (fw *mockWriter) GetWritten() (b []byte) { |
||||
return fw.written |
||||
} |
||||
|
||||
func TestApp_SetStdout(t *testing.T) { |
||||
w := &mockWriter{} |
||||
|
||||
app := cli.NewApp() |
||||
app.Name = "test" |
||||
app.Writer = w |
||||
|
||||
err := app.Run([]string{"help"}) |
||||
|
||||
if err != nil { |
||||
t.Fatalf("Run error: %s", err) |
||||
} |
||||
|
||||
if len(w.written) == 0 { |
||||
t.Error("App did not write output to desired writer.") |
||||
} |
||||
} |
||||
|
||||
func TestApp_BeforeFunc(t *testing.T) { |
||||
beforeRun, subcommandRun := false, false |
||||
beforeError := fmt.Errorf("fail") |
||||
var err error |
||||
|
||||
app := cli.NewApp() |
||||
|
||||
app.Before = func(c *cli.Context) error { |
||||
beforeRun = true |
||||
s := c.String("opt") |
||||
if s == "fail" { |
||||
return beforeError |
||||
} |
||||
|
||||
return nil |
||||
} |
||||
|
||||
app.Commands = []cli.Command{ |
||||
cli.Command{ |
||||
Name: "sub", |
||||
Action: func(c *cli.Context) { |
||||
subcommandRun = true |
||||
}, |
||||
}, |
||||
} |
||||
|
||||
app.Flags = []cli.Flag{ |
||||
cli.StringFlag{Name: "opt"}, |
||||
} |
||||
|
||||
// run with the Before() func succeeding
|
||||
err = app.Run([]string{"command", "--opt", "succeed", "sub"}) |
||||
|
||||
if err != nil { |
||||
t.Fatalf("Run error: %s", err) |
||||
} |
||||
|
||||
if beforeRun == false { |
||||
t.Errorf("Before() not executed when expected") |
||||
} |
||||
|
||||
if subcommandRun == false { |
||||
t.Errorf("Subcommand not executed when expected") |
||||
} |
||||
|
||||
// reset
|
||||
beforeRun, subcommandRun = false, false |
||||
|
||||
// run with the Before() func failing
|
||||
err = app.Run([]string{"command", "--opt", "fail", "sub"}) |
||||
|
||||
// should be the same error produced by the Before func
|
||||
if err != beforeError { |
||||
t.Errorf("Run error expected, but not received") |
||||
} |
||||
|
||||
if beforeRun == false { |
||||
t.Errorf("Before() not executed when expected") |
||||
} |
||||
|
||||
if subcommandRun == true { |
||||
t.Errorf("Subcommand executed when NOT expected") |
||||
} |
||||
|
||||
} |
||||
|
||||
func TestApp_AfterFunc(t *testing.T) { |
||||
afterRun, subcommandRun := false, false |
||||
afterError := fmt.Errorf("fail") |
||||
var err error |
||||
|
||||
app := cli.NewApp() |
||||
|
||||
app.After = func(c *cli.Context) error { |
||||
afterRun = true |
||||
s := c.String("opt") |
||||
if s == "fail" { |
||||
return afterError |
||||
} |
||||
|
||||
return nil |
||||
} |
||||
|
||||
app.Commands = []cli.Command{ |
||||
cli.Command{ |
||||
Name: "sub", |
||||
Action: func(c *cli.Context) { |
||||
subcommandRun = true |
||||
}, |
||||
}, |
||||
} |
||||
|
||||
app.Flags = []cli.Flag{ |
||||
cli.StringFlag{Name: "opt"}, |
||||
} |
||||
|
||||
// run with the After() func succeeding
|
||||
err = app.Run([]string{"command", "--opt", "succeed", "sub"}) |
||||
|
||||
if err != nil { |
||||
t.Fatalf("Run error: %s", err) |
||||
} |
||||
|
||||
if afterRun == false { |
||||
t.Errorf("After() not executed when expected") |
||||
} |
||||
|
||||
if subcommandRun == false { |
||||
t.Errorf("Subcommand not executed when expected") |
||||
} |
||||
|
||||
// reset
|
||||
afterRun, subcommandRun = false, false |
||||
|
||||
// run with the Before() func failing
|
||||
err = app.Run([]string{"command", "--opt", "fail", "sub"}) |
||||
|
||||
// should be the same error produced by the Before func
|
||||
if err != afterError { |
||||
t.Errorf("Run error expected, but not received") |
||||
} |
||||
|
||||
if afterRun == false { |
||||
t.Errorf("After() not executed when expected") |
||||
} |
||||
|
||||
if subcommandRun == false { |
||||
t.Errorf("Subcommand not executed when expected") |
||||
} |
||||
} |
||||
|
||||
func TestAppNoHelpFlag(t *testing.T) { |
||||
oldFlag := cli.HelpFlag |
||||
defer func() { |
||||
cli.HelpFlag = oldFlag |
||||
}() |
||||
|
||||
cli.HelpFlag = cli.BoolFlag{} |
||||
|
||||
app := cli.NewApp() |
||||
err := app.Run([]string{"test", "-h"}) |
||||
|
||||
if err != flag.ErrHelp { |
||||
t.Errorf("expected error about missing help flag, but got: %s (%T)", err, err) |
||||
} |
||||
} |
||||
|
||||
func TestAppHelpPrinter(t *testing.T) { |
||||
oldPrinter := cli.HelpPrinter |
||||
defer func() { |
||||
cli.HelpPrinter = oldPrinter |
||||
}() |
||||
|
||||
var wasCalled = false |
||||
cli.HelpPrinter = func(template string, data interface{}) { |
||||
wasCalled = true |
||||
} |
||||
|
||||
app := cli.NewApp() |
||||
app.Run([]string{"-h"}) |
||||
|
||||
if wasCalled == false { |
||||
t.Errorf("Help printer expected to be called, but was not") |
||||
} |
||||
} |
||||
|
||||
func TestAppVersionPrinter(t *testing.T) { |
||||
oldPrinter := cli.VersionPrinter |
||||
defer func() { |
||||
cli.VersionPrinter = oldPrinter |
||||
}() |
||||
|
||||
var wasCalled = false |
||||
cli.VersionPrinter = func(c *cli.Context) { |
||||
wasCalled = true |
||||
} |
||||
|
||||
app := cli.NewApp() |
||||
ctx := cli.NewContext(app, nil, nil) |
||||
cli.ShowVersion(ctx) |
||||
|
||||
if wasCalled == false { |
||||
t.Errorf("Version printer expected to be called, but was not") |
||||
} |
||||
} |
||||
|
||||
func TestAppCommandNotFound(t *testing.T) { |
||||
beforeRun, subcommandRun := false, false |
||||
app := cli.NewApp() |
||||
|
||||
app.CommandNotFound = func(c *cli.Context, command string) { |
||||
beforeRun = true |
||||
} |
||||
|
||||
app.Commands = []cli.Command{ |
||||
cli.Command{ |
||||
Name: "bar", |
||||
Action: func(c *cli.Context) { |
||||
subcommandRun = true |
||||
}, |
||||
}, |
||||
} |
||||
|
||||
app.Run([]string{"command", "foo"}) |
||||
|
||||
expect(t, beforeRun, true) |
||||
expect(t, subcommandRun, false) |
||||
} |
||||
|
||||
func TestGlobalFlagsInSubcommands(t *testing.T) { |
||||
subcommandRun := false |
||||
app := cli.NewApp() |
||||
|
||||
app.Flags = []cli.Flag{ |
||||
cli.BoolFlag{Name: "debug, d", Usage: "Enable debugging"}, |
||||
} |
||||
|
||||
app.Commands = []cli.Command{ |
||||
cli.Command{ |
||||
Name: "foo", |
||||
Subcommands: []cli.Command{ |
||||
{ |
||||
Name: "bar", |
||||
Action: func(c *cli.Context) { |
||||
if c.GlobalBool("debug") { |
||||
subcommandRun = true |
||||
} |
||||
}, |
||||
}, |
||||
}, |
||||
}, |
||||
} |
||||
|
||||
app.Run([]string{"command", "-d", "foo", "bar"}) |
||||
|
||||
expect(t, subcommandRun, true) |
||||
} |
13
Godeps/_workspace/src/github.com/codegangsta/cli/autocomplete/bash_autocomplete
generated
vendored
13
Godeps/_workspace/src/github.com/codegangsta/cli/autocomplete/bash_autocomplete
generated
vendored
@ -0,0 +1,13 @@ |
||||
#! /bin/bash |
||||
|
||||
_cli_bash_autocomplete() { |
||||
local cur prev opts base |
||||
COMPREPLY=() |
||||
cur="${COMP_WORDS[COMP_CWORD]}" |
||||
prev="${COMP_WORDS[COMP_CWORD-1]}" |
||||
opts=$( ${COMP_WORDS[@]:0:$COMP_CWORD} --generate-bash-completion ) |
||||
COMPREPLY=( $(compgen -W "${opts}" -- ${cur}) ) |
||||
return 0 |
||||
} |
||||
|
||||
complete -F _cli_bash_autocomplete $PROG |
@ -0,0 +1,5 @@ |
||||
autoload -U compinit && compinit |
||||
autoload -U bashcompinit && bashcompinit |
||||
|
||||
script_dir=$(dirname $0) |
||||
source ${script_dir}/bash_autocomplete |
@ -0,0 +1,19 @@ |
||||
// Package cli provides a minimal framework for creating and organizing command line
|
||||
// Go applications. cli is designed to be easy to understand and write, the most simple
|
||||
// cli application can be written as follows:
|
||||
// func main() {
|
||||
// cli.NewApp().Run(os.Args)
|
||||
// }
|
||||
//
|
||||
// Of course this application does not do much, so let's make this an actual application:
|
||||
// func main() {
|
||||
// app := cli.NewApp()
|
||||
// app.Name = "greet"
|
||||
// app.Usage = "say a greeting"
|
||||
// app.Action = func(c *cli.Context) {
|
||||
// println("Greetings")
|
||||
// }
|
||||
//
|
||||
// app.Run(os.Args)
|
||||
// }
|
||||
package cli |
@ -0,0 +1,100 @@ |
||||
package cli_test |
||||
|
||||
import ( |
||||
"os" |
||||
|
||||
"github.com/codegangsta/cli" |
||||
) |
||||
|
||||
func Example() { |
||||
app := cli.NewApp() |
||||
app.Name = "todo" |
||||
app.Usage = "task list on the command line" |
||||
app.Commands = []cli.Command{ |
||||
{ |
||||
Name: "add", |
||||
ShortName: "a", |
||||
Usage: "add a task to the list", |
||||
Action: func(c *cli.Context) { |
||||
println("added task: ", c.Args().First()) |
||||
}, |
||||
}, |
||||
{ |
||||
Name: "complete", |
||||
ShortName: "c", |
||||
Usage: "complete a task on the list", |
||||
Action: func(c *cli.Context) { |
||||
println("completed task: ", c.Args().First()) |
||||
}, |
||||
}, |
||||
} |
||||
|
||||
app.Run(os.Args) |
||||
} |
||||
|
||||
func ExampleSubcommand() { |
||||
app := cli.NewApp() |
||||
app.Name = "say" |
||||
app.Commands = []cli.Command{ |
||||
{ |
||||
Name: "hello", |
||||
ShortName: "hi", |
||||
Usage: "use it to see a description", |
||||
Description: "This is how we describe hello the function", |
||||
Subcommands: []cli.Command{ |
||||
{ |
||||
Name: "english", |
||||
ShortName: "en", |
||||
Usage: "sends a greeting in english", |
||||
Description: "greets someone in english", |
||||
Flags: []cli.Flag{ |
||||
cli.StringFlag{ |
||||
Name: "name", |
||||
Value: "Bob", |
||||
Usage: "Name of the person to greet", |
||||
}, |
||||
}, |
||||
Action: func(c *cli.Context) { |
||||
println("Hello, ", c.String("name")) |
||||
}, |
||||
}, { |
||||
Name: "spanish", |
||||
ShortName: "sp", |
||||
Usage: "sends a greeting in spanish", |
||||
Flags: []cli.Flag{ |
||||
cli.StringFlag{ |
||||
Name: "surname", |
||||
Value: "Jones", |
||||
Usage: "Surname of the person to greet", |
||||
}, |
||||
}, |
||||
Action: func(c *cli.Context) { |
||||
println("Hola, ", c.String("surname")) |
||||
}, |
||||
}, { |
||||
Name: "french", |
||||
ShortName: "fr", |
||||
Usage: "sends a greeting in french", |
||||
Flags: []cli.Flag{ |
||||
cli.StringFlag{ |
||||
Name: "nickname", |
||||
Value: "Stevie", |
||||
Usage: "Nickname of the person to greet", |
||||
}, |
||||
}, |
||||
Action: func(c *cli.Context) { |
||||
println("Bonjour, ", c.String("nickname")) |
||||
}, |
||||
}, |
||||
}, |
||||
}, { |
||||
Name: "bye", |
||||
Usage: "says goodbye", |
||||
Action: func(c *cli.Context) { |
||||
println("bye") |
||||
}, |
||||
}, |
||||
} |
||||
|
||||
app.Run(os.Args) |
||||
} |
@ -0,0 +1,160 @@ |
||||
package cli |
||||
|
||||
import ( |
||||
"fmt" |
||||
"io/ioutil" |
||||
"strings" |
||||
) |
||||
|
||||
// Command is a subcommand for a cli.App.
|
||||
type Command struct { |
||||
// The name of the command
|
||||
Name string |
||||
// short name of the command. Typically one character
|
||||
ShortName string |
||||
// A short description of the usage of this command
|
||||
Usage string |
||||
// A longer explanation of how the command works
|
||||
Description string |
||||
// The function to call when checking for bash command completions
|
||||
BashComplete func(context *Context) |
||||
// An action to execute before any sub-subcommands are run, but after the context is ready
|
||||
// If a non-nil error is returned, no sub-subcommands are run
|
||||
Before func(context *Context) error |
||||
// An action to execute after any subcommands are run, but after the subcommand has finished
|
||||
// It is run even if Action() panics
|
||||
After func(context *Context) error |
||||
// The function to call when this command is invoked
|
||||
Action func(context *Context) |
||||
// List of child commands
|
||||
Subcommands []Command |
||||
// List of flags to parse
|
||||
Flags []Flag |
||||
// Treat all flags as normal arguments if true
|
||||
SkipFlagParsing bool |
||||
// Boolean to hide built-in help command
|
||||
HideHelp bool |
||||
} |
||||
|
||||
// Invokes the command given the context, parses ctx.Args() to generate command-specific flags
|
||||
func (c Command) Run(ctx *Context) error { |
||||
|
||||
if len(c.Subcommands) > 0 || c.Before != nil || c.After != nil { |
||||
return c.startApp(ctx) |
||||
} |
||||
|
||||
if !c.HideHelp && (HelpFlag != BoolFlag{}) { |
||||
// append help to flags
|
||||
c.Flags = append( |
||||
c.Flags, |
||||
HelpFlag, |
||||
) |
||||
} |
||||
|
||||
if ctx.App.EnableBashCompletion { |
||||
c.Flags = append(c.Flags, BashCompletionFlag) |
||||
} |
||||
|
||||
set := flagSet(c.Name, c.Flags) |
||||
set.SetOutput(ioutil.Discard) |
||||
|
||||
firstFlagIndex := -1 |
||||
terminatorIndex := -1 |
||||
for index, arg := range ctx.Args() { |
||||
if arg == "--" { |
||||
terminatorIndex = index |
||||
break |
||||
} else if strings.HasPrefix(arg, "-") && firstFlagIndex == -1 { |
||||
firstFlagIndex = index |
||||
} |
||||
} |
||||
|
||||
var err error |
||||
if firstFlagIndex > -1 && !c.SkipFlagParsing { |
||||
args := ctx.Args() |
||||
regularArgs := make([]string, len(args[1:firstFlagIndex])) |
||||
copy(regularArgs, args[1:firstFlagIndex]) |
||||
|
||||
var flagArgs []string |
||||
if terminatorIndex > -1 { |
||||
flagArgs = args[firstFlagIndex:terminatorIndex] |
||||
regularArgs = append(regularArgs, args[terminatorIndex:]...) |
||||
} else { |
||||
flagArgs = args[firstFlagIndex:] |
||||
} |
||||
|
||||
err = set.Parse(append(flagArgs, regularArgs...)) |
||||
} else { |
||||
err = set.Parse(ctx.Args().Tail()) |
||||
} |
||||
|
||||
if err != nil { |
||||
fmt.Fprint(ctx.App.Writer, "Incorrect Usage.\n\n") |
||||
ShowCommandHelp(ctx, c.Name) |
||||
fmt.Fprintln(ctx.App.Writer) |
||||
return err |
||||
} |
||||
|
||||
nerr := normalizeFlags(c.Flags, set) |
||||
if nerr != nil { |
||||
fmt.Fprintln(ctx.App.Writer, nerr) |
||||
fmt.Fprintln(ctx.App.Writer) |
||||
ShowCommandHelp(ctx, c.Name) |
||||
fmt.Fprintln(ctx.App.Writer) |
||||
return nerr |
||||
} |
||||
context := NewContext(ctx.App, set, ctx.globalSet) |
||||
|
||||
if checkCommandCompletions(context, c.Name) { |
||||
return nil |
||||
} |
||||
|
||||
if checkCommandHelp(context, c.Name) { |
||||
return nil |
||||
} |
||||
context.Command = c |
||||
c.Action(context) |
||||
return nil |
||||
} |
||||
|
||||
// Returns true if Command.Name or Command.ShortName matches given name
|
||||
func (c Command) HasName(name string) bool { |
||||
return c.Name == name || c.ShortName == name |
||||
} |
||||
|
||||
func (c Command) startApp(ctx *Context) error { |
||||
app := NewApp() |
||||
|
||||
// set the name and usage
|
||||
app.Name = fmt.Sprintf("%s %s", ctx.App.Name, c.Name) |
||||
if c.Description != "" { |
||||
app.Usage = c.Description |
||||
} else { |
||||
app.Usage = c.Usage |
||||
} |
||||
|
||||
// set CommandNotFound
|
||||
app.CommandNotFound = ctx.App.CommandNotFound |
||||
|
||||
// set the flags and commands
|
||||
app.Commands = c.Subcommands |
||||
app.Flags = c.Flags |
||||
app.HideHelp = c.HideHelp |
||||
|
||||
// bash completion
|
||||
app.EnableBashCompletion = ctx.App.EnableBashCompletion |
||||
if c.BashComplete != nil { |
||||
app.BashComplete = c.BashComplete |
||||
} |
||||
|
||||
// set the actions
|
||||
app.Before = c.Before |
||||
app.After = c.After |
||||
if c.Action != nil { |
||||
app.Action = c.Action |
||||
} else { |
||||
app.Action = helpSubcommand.Action |
||||
} |
||||
|
||||
return app.RunAsSubcommand(ctx) |
||||
} |
@ -0,0 +1,49 @@ |
||||
package cli_test |
||||
|
||||
import ( |
||||
"flag" |
||||
"testing" |
||||
|
||||
"github.com/codegangsta/cli" |
||||
) |
||||
|
||||
func TestCommandDoNotIgnoreFlags(t *testing.T) { |
||||
app := cli.NewApp() |
||||
set := flag.NewFlagSet("test", 0) |
||||
test := []string{"blah", "blah", "-break"} |
||||
set.Parse(test) |
||||
|
||||
c := cli.NewContext(app, set, set) |
||||
|
||||
command := cli.Command{ |
||||
Name: "test-cmd", |
||||
ShortName: "tc", |
||||
Usage: "this is for testing", |
||||
Description: "testing", |
||||
Action: func(_ *cli.Context) {}, |
||||
} |
||||
err := command.Run(c) |
||||
|
||||
expect(t, err.Error(), "flag provided but not defined: -break") |
||||
} |
||||
|
||||
func TestCommandIgnoreFlags(t *testing.T) { |
||||
app := cli.NewApp() |
||||
set := flag.NewFlagSet("test", 0) |
||||
test := []string{"blah", "blah"} |
||||
set.Parse(test) |
||||
|
||||
c := cli.NewContext(app, set, set) |
||||
|
||||
command := cli.Command{ |
||||
Name: "test-cmd", |
||||
ShortName: "tc", |
||||
Usage: "this is for testing", |
||||
Description: "testing", |
||||
Action: func(_ *cli.Context) {}, |
||||
SkipFlagParsing: true, |
||||
} |
||||
err := command.Run(c) |
||||
|
||||
expect(t, err, nil) |
||||
} |
@ -0,0 +1,339 @@ |
||||
package cli |
||||
|
||||
import ( |
||||
"errors" |
||||
"flag" |
||||
"strconv" |
||||
"strings" |
||||
"time" |
||||
) |
||||
|
||||
// Context is a type that is passed through to
|
||||
// each Handler action in a cli application. Context
|
||||
// can be used to retrieve context-specific Args and
|
||||
// parsed command-line options.
|
||||
type Context struct { |
||||
App *App |
||||
Command Command |
||||
flagSet *flag.FlagSet |
||||
globalSet *flag.FlagSet |
||||
setFlags map[string]bool |
||||
globalSetFlags map[string]bool |
||||
} |
||||
|
||||
// Creates a new context. For use in when invoking an App or Command action.
|
||||
func NewContext(app *App, set *flag.FlagSet, globalSet *flag.FlagSet) *Context { |
||||
return &Context{App: app, flagSet: set, globalSet: globalSet} |
||||
} |
||||
|
||||
// Looks up the value of a local int flag, returns 0 if no int flag exists
|
||||
func (c *Context) Int(name string) int { |
||||
return lookupInt(name, c.flagSet) |
||||
} |
||||
|
||||
// Looks up the value of a local time.Duration flag, returns 0 if no time.Duration flag exists
|
||||
func (c *Context) Duration(name string) time.Duration { |
||||
return lookupDuration(name, c.flagSet) |
||||
} |
||||
|
||||
// Looks up the value of a local float64 flag, returns 0 if no float64 flag exists
|
||||
func (c *Context) Float64(name string) float64 { |
||||
return lookupFloat64(name, c.flagSet) |
||||
} |
||||
|
||||
// Looks up the value of a local bool flag, returns false if no bool flag exists
|
||||
func (c *Context) Bool(name string) bool { |
||||
return lookupBool(name, c.flagSet) |
||||
} |
||||
|
||||
// Looks up the value of a local boolT flag, returns false if no bool flag exists
|
||||
func (c *Context) BoolT(name string) bool { |
||||
return lookupBoolT(name, c.flagSet) |
||||
} |
||||
|
||||
// Looks up the value of a local string flag, returns "" if no string flag exists
|
||||
func (c *Context) String(name string) string { |
||||
return lookupString(name, c.flagSet) |
||||
} |
||||
|
||||
// Looks up the value of a local string slice flag, returns nil if no string slice flag exists
|
||||
func (c *Context) StringSlice(name string) []string { |
||||
return lookupStringSlice(name, c.flagSet) |
||||
} |
||||
|
||||
// Looks up the value of a local int slice flag, returns nil if no int slice flag exists
|
||||
func (c *Context) IntSlice(name string) []int { |
||||
return lookupIntSlice(name, c.flagSet) |
||||
} |
||||
|
||||
// Looks up the value of a local generic flag, returns nil if no generic flag exists
|
||||
func (c *Context) Generic(name string) interface{} { |
||||
return lookupGeneric(name, c.flagSet) |
||||
} |
||||
|
||||
// Looks up the value of a global int flag, returns 0 if no int flag exists
|
||||
func (c *Context) GlobalInt(name string) int { |
||||
return lookupInt(name, c.globalSet) |
||||
} |
||||
|
||||
// Looks up the value of a global time.Duration flag, returns 0 if no time.Duration flag exists
|
||||
func (c *Context) GlobalDuration(name string) time.Duration { |
||||
return lookupDuration(name, c.globalSet) |
||||
} |
||||
|
||||
// Looks up the value of a global bool flag, returns false if no bool flag exists
|
||||
func (c *Context) GlobalBool(name string) bool { |
||||
return lookupBool(name, c.globalSet) |
||||
} |
||||
|
||||
// Looks up the value of a global string flag, returns "" if no string flag exists
|
||||
func (c *Context) GlobalString(name string) string { |
||||
return lookupString(name, c.globalSet) |
||||
} |
||||
|
||||
// Looks up the value of a global string slice flag, returns nil if no string slice flag exists
|
||||
func (c *Context) GlobalStringSlice(name string) []string { |
||||
return lookupStringSlice(name, c.globalSet) |
||||
} |
||||
|
||||
// Looks up the value of a global int slice flag, returns nil if no int slice flag exists
|
||||
func (c *Context) GlobalIntSlice(name string) []int { |
||||
return lookupIntSlice(name, c.globalSet) |
||||
} |
||||
|
||||
// Looks up the value of a global generic flag, returns nil if no generic flag exists
|
||||
func (c *Context) GlobalGeneric(name string) interface{} { |
||||
return lookupGeneric(name, c.globalSet) |
||||
} |
||||
|
||||
// Determines if the flag was actually set
|
||||
func (c *Context) IsSet(name string) bool { |
||||
if c.setFlags == nil { |
||||
c.setFlags = make(map[string]bool) |
||||
c.flagSet.Visit(func(f *flag.Flag) { |
||||
c.setFlags[f.Name] = true |
||||
}) |
||||
} |
||||
return c.setFlags[name] == true |
||||
} |
||||
|
||||
// Determines if the global flag was actually set
|
||||
func (c *Context) GlobalIsSet(name string) bool { |
||||
if c.globalSetFlags == nil { |
||||
c.globalSetFlags = make(map[string]bool) |
||||
c.globalSet.Visit(func(f *flag.Flag) { |
||||
c.globalSetFlags[f.Name] = true |
||||
}) |
||||
} |
||||
return c.globalSetFlags[name] == true |
||||
} |
||||
|
||||
// Returns a slice of flag names used in this context.
|
||||
func (c *Context) FlagNames() (names []string) { |
||||
for _, flag := range c.Command.Flags { |
||||
name := strings.Split(flag.getName(), ",")[0] |
||||
if name == "help" { |
||||
continue |
||||
} |
||||
names = append(names, name) |
||||
} |
||||
return |
||||
} |
||||
|
||||
// Returns a slice of global flag names used by the app.
|
||||
func (c *Context) GlobalFlagNames() (names []string) { |
||||
for _, flag := range c.App.Flags { |
||||
name := strings.Split(flag.getName(), ",")[0] |
||||
if name == "help" || name == "version" { |
||||
continue |
||||
} |
||||
names = append(names, name) |
||||
} |
||||
return |
||||
} |
||||
|
||||
type Args []string |
||||
|
||||
// Returns the command line arguments associated with the context.
|
||||
func (c *Context) Args() Args { |
||||
args := Args(c.flagSet.Args()) |
||||
return args |
||||
} |
||||
|
||||
// Returns the nth argument, or else a blank string
|
||||
func (a Args) Get(n int) string { |
||||
if len(a) > n { |
||||
return a[n] |
||||
} |
||||
return "" |
||||
} |
||||
|
||||
// Returns the first argument, or else a blank string
|
||||
func (a Args) First() string { |
||||
return a.Get(0) |
||||
} |
||||
|
||||
// Return the rest of the arguments (not the first one)
|
||||
// or else an empty string slice
|
||||
func (a Args) Tail() []string { |
||||
if len(a) >= 2 { |
||||
return []string(a)[1:] |
||||
} |
||||
return []string{} |
||||
} |
||||
|
||||
// Checks if there are any arguments present
|
||||
func (a Args) Present() bool { |
||||
return len(a) != 0 |
||||
} |
||||
|
||||
// Swaps arguments at the given indexes
|
||||
func (a Args) Swap(from, to int) error { |
||||
if from >= len(a) || to >= len(a) { |
||||
return errors.New("index out of range") |
||||
} |
||||
a[from], a[to] = a[to], a[from] |
||||
return nil |
||||
} |
||||
|
||||
func lookupInt(name string, set *flag.FlagSet) int { |
||||
f := set.Lookup(name) |
||||
if f != nil { |
||||
val, err := strconv.Atoi(f.Value.String()) |
||||
if err != nil { |
||||
return 0 |
||||
} |
||||
return val |
||||
} |
||||
|
||||
return 0 |
||||
} |
||||
|
||||
func lookupDuration(name string, set *flag.FlagSet) time.Duration { |
||||
f := set.Lookup(name) |
||||
if f != nil { |
||||
val, err := time.ParseDuration(f.Value.String()) |
||||
if err == nil { |
||||
return val |
||||
} |
||||
} |
||||
|
||||
return 0 |
||||
} |
||||
|
||||
func lookupFloat64(name string, set *flag.FlagSet) float64 { |
||||
f := set.Lookup(name) |
||||
if f != nil { |
||||
val, err := strconv.ParseFloat(f.Value.String(), 64) |
||||
if err != nil { |
||||
return 0 |
||||
} |
||||
return val |
||||
} |
||||
|
||||
return 0 |
||||
} |
||||
|
||||
func lookupString(name string, set *flag.FlagSet) string { |
||||
f := set.Lookup(name) |
||||
if f != nil { |
||||
return f.Value.String() |
||||
} |
||||
|
||||
return "" |
||||
} |
||||
|
||||
func lookupStringSlice(name string, set *flag.FlagSet) []string { |
||||
f := set.Lookup(name) |
||||
if f != nil { |
||||
return (f.Value.(*StringSlice)).Value() |
||||
|
||||
} |
||||
|
||||
return nil |
||||
} |
||||
|
||||
func lookupIntSlice(name string, set *flag.FlagSet) []int { |
||||
f := set.Lookup(name) |
||||
if f != nil { |
||||
return (f.Value.(*IntSlice)).Value() |
||||
|
||||
} |
||||
|
||||
return nil |
||||
} |
||||
|
||||
func lookupGeneric(name string, set *flag.FlagSet) interface{} { |
||||
f := set.Lookup(name) |
||||
if f != nil { |
||||
return f.Value |
||||
} |
||||
return nil |
||||
} |
||||
|
||||
func lookupBool(name string, set *flag.FlagSet) bool { |
||||
f := set.Lookup(name) |
||||
if f != nil { |
||||
val, err := strconv.ParseBool(f.Value.String()) |
||||
if err != nil { |
||||
return false |
||||
} |
||||
return val |
||||
} |
||||
|
||||
return false |
||||
} |
||||
|
||||
func lookupBoolT(name string, set *flag.FlagSet) bool { |
||||
f := set.Lookup(name) |
||||
if f != nil { |
||||
val, err := strconv.ParseBool(f.Value.String()) |
||||
if err != nil { |
||||
return true |
||||
} |
||||
return val |
||||
} |
||||
|
||||
return false |
||||
} |
||||
|
||||
func copyFlag(name string, ff *flag.Flag, set *flag.FlagSet) { |
||||
switch ff.Value.(type) { |
||||
case *StringSlice: |
||||
default: |
||||
set.Set(name, ff.Value.String()) |
||||
} |
||||
} |
||||
|
||||
func normalizeFlags(flags []Flag, set *flag.FlagSet) error { |
||||
visited := make(map[string]bool) |
||||
set.Visit(func(f *flag.Flag) { |
||||
visited[f.Name] = true |
||||
}) |
||||
for _, f := range flags { |
||||
parts := strings.Split(f.getName(), ",") |
||||
if len(parts) == 1 { |
||||
continue |
||||
} |
||||
var ff *flag.Flag |
||||
for _, name := range parts { |
||||
name = strings.Trim(name, " ") |
||||
if visited[name] { |
||||
if ff != nil { |
||||
return errors.New("Cannot use two forms of the same flag: " + name + " " + ff.Name) |
||||
} |
||||
ff = set.Lookup(name) |
||||
} |
||||
} |
||||
if ff == nil { |
||||
continue |
||||
} |
||||
for _, name := range parts { |
||||
name = strings.Trim(name, " ") |
||||
if !visited[name] { |
||||
copyFlag(name, ff, set) |
||||
} |
||||
} |
||||
} |
||||
return nil |
||||
} |
@ -0,0 +1,99 @@ |
||||
package cli_test |
||||
|
||||
import ( |
||||
"flag" |
||||
"testing" |
||||
"time" |
||||
|
||||
"github.com/codegangsta/cli" |
||||
) |
||||
|
||||
func TestNewContext(t *testing.T) { |
||||
set := flag.NewFlagSet("test", 0) |
||||
set.Int("myflag", 12, "doc") |
||||
globalSet := flag.NewFlagSet("test", 0) |
||||
globalSet.Int("myflag", 42, "doc") |
||||
command := cli.Command{Name: "mycommand"} |
||||
c := cli.NewContext(nil, set, globalSet) |
||||
c.Command = command |
||||
expect(t, c.Int("myflag"), 12) |
||||
expect(t, c.GlobalInt("myflag"), 42) |
||||
expect(t, c.Command.Name, "mycommand") |
||||
} |
||||
|
||||
func TestContext_Int(t *testing.T) { |
||||
set := flag.NewFlagSet("test", 0) |
||||
set.Int("myflag", 12, "doc") |
||||
c := cli.NewContext(nil, set, set) |
||||
expect(t, c.Int("myflag"), 12) |
||||
} |
||||
|
||||
func TestContext_Duration(t *testing.T) { |
||||
set := flag.NewFlagSet("test", 0) |
||||
set.Duration("myflag", time.Duration(12*time.Second), "doc") |
||||
c := cli.NewContext(nil, set, set) |
||||
expect(t, c.Duration("myflag"), time.Duration(12*time.Second)) |
||||
} |
||||
|
||||
func TestContext_String(t *testing.T) { |
||||
set := flag.NewFlagSet("test", 0) |
||||
set.String("myflag", "hello world", "doc") |
||||
c := cli.NewContext(nil, set, set) |
||||
expect(t, c.String("myflag"), "hello world") |
||||
} |
||||
|
||||
func TestContext_Bool(t *testing.T) { |
||||
set := flag.NewFlagSet("test", 0) |
||||
set.Bool("myflag", false, "doc") |
||||
c := cli.NewContext(nil, set, set) |
||||
expect(t, c.Bool("myflag"), false) |
||||
} |
||||
|
||||
func TestContext_BoolT(t *testing.T) { |
||||
set := flag.NewFlagSet("test", 0) |
||||
set.Bool("myflag", true, "doc") |
||||
c := cli.NewContext(nil, set, set) |
||||
expect(t, c.BoolT("myflag"), true) |
||||
} |
||||
|
||||
func TestContext_Args(t *testing.T) { |
||||
set := flag.NewFlagSet("test", 0) |
||||
set.Bool("myflag", false, "doc") |
||||
c := cli.NewContext(nil, set, set) |
||||
set.Parse([]string{"--myflag", "bat", "baz"}) |
||||
expect(t, len(c.Args()), 2) |
||||
expect(t, c.Bool("myflag"), true) |
||||
} |
||||
|
||||
func TestContext_IsSet(t *testing.T) { |
||||
set := flag.NewFlagSet("test", 0) |
||||
set.Bool("myflag", false, "doc") |
||||
set.String("otherflag", "hello world", "doc") |
||||
globalSet := flag.NewFlagSet("test", 0) |
||||
globalSet.Bool("myflagGlobal", true, "doc") |
||||
c := cli.NewContext(nil, set, globalSet) |
||||
set.Parse([]string{"--myflag", "bat", "baz"}) |
||||
globalSet.Parse([]string{"--myflagGlobal", "bat", "baz"}) |
||||
expect(t, c.IsSet("myflag"), true) |
||||
expect(t, c.IsSet("otherflag"), false) |
||||
expect(t, c.IsSet("bogusflag"), false) |
||||
expect(t, c.IsSet("myflagGlobal"), false) |
||||
} |
||||
|
||||
func TestContext_GlobalIsSet(t *testing.T) { |
||||
set := flag.NewFlagSet("test", 0) |
||||
set.Bool("myflag", false, "doc") |
||||
set.String("otherflag", "hello world", "doc") |
||||
globalSet := flag.NewFlagSet("test", 0) |
||||
globalSet.Bool("myflagGlobal", true, "doc") |
||||
globalSet.Bool("myflagGlobalUnset", true, "doc") |
||||
c := cli.NewContext(nil, set, globalSet) |
||||
set.Parse([]string{"--myflag", "bat", "baz"}) |
||||
globalSet.Parse([]string{"--myflagGlobal", "bat", "baz"}) |
||||
expect(t, c.GlobalIsSet("myflag"), false) |
||||
expect(t, c.GlobalIsSet("otherflag"), false) |
||||
expect(t, c.GlobalIsSet("bogusflag"), false) |
||||
expect(t, c.GlobalIsSet("myflagGlobal"), true) |
||||
expect(t, c.GlobalIsSet("myflagGlobalUnset"), false) |
||||
expect(t, c.GlobalIsSet("bogusGlobal"), false) |
||||
} |
@ -0,0 +1,454 @@ |
||||
package cli |
||||
|
||||
import ( |
||||
"flag" |
||||
"fmt" |
||||
"os" |
||||
"strconv" |
||||
"strings" |
||||
"time" |
||||
) |
||||
|
||||
// This flag enables bash-completion for all commands and subcommands
|
||||
var BashCompletionFlag = BoolFlag{ |
||||
Name: "generate-bash-completion", |
||||
} |
||||
|
||||
// This flag prints the version for the application
|
||||
var VersionFlag = BoolFlag{ |
||||
Name: "version, v", |
||||
Usage: "print the version", |
||||
} |
||||
|
||||
// This flag prints the help for all commands and subcommands
|
||||
// Set to the zero value (BoolFlag{}) to disable flag -- keeps subcommand
|
||||
// unless HideHelp is set to true)
|
||||
var HelpFlag = BoolFlag{ |
||||
Name: "help, h", |
||||
Usage: "show help", |
||||
} |
||||
|
||||
// Flag is a common interface related to parsing flags in cli.
|
||||
// For more advanced flag parsing techniques, it is recomended that
|
||||
// this interface be implemented.
|
||||
type Flag interface { |
||||
fmt.Stringer |
||||
// Apply Flag settings to the given flag set
|
||||
Apply(*flag.FlagSet) |
||||
getName() string |
||||
} |
||||
|
||||
func flagSet(name string, flags []Flag) *flag.FlagSet { |
||||
set := flag.NewFlagSet(name, flag.ContinueOnError) |
||||
|
||||
for _, f := range flags { |
||||
f.Apply(set) |
||||
} |
||||
return set |
||||
} |
||||
|
||||
func eachName(longName string, fn func(string)) { |
||||
parts := strings.Split(longName, ",") |
||||
for _, name := range parts { |
||||
name = strings.Trim(name, " ") |
||||
fn(name) |
||||
} |
||||
} |
||||
|
||||
// Generic is a generic parseable type identified by a specific flag
|
||||
type Generic interface { |
||||
Set(value string) error |
||||
String() string |
||||
} |
||||
|
||||
// GenericFlag is the flag type for types implementing Generic
|
||||
type GenericFlag struct { |
||||
Name string |
||||
Value Generic |
||||
Usage string |
||||
EnvVar string |
||||
} |
||||
|
||||
// String returns the string representation of the generic flag to display the
|
||||
// help text to the user (uses the String() method of the generic flag to show
|
||||
// the value)
|
||||
func (f GenericFlag) String() string { |
||||
return withEnvHint(f.EnvVar, fmt.Sprintf("%s%s \"%v\"\t%v", prefixFor(f.Name), f.Name, f.Value, f.Usage)) |
||||
} |
||||
|
||||
// Apply takes the flagset and calls Set on the generic flag with the value
|
||||
// provided by the user for parsing by the flag
|
||||
func (f GenericFlag) Apply(set *flag.FlagSet) { |
||||
val := f.Value |
||||
if f.EnvVar != "" { |
||||
for _, envVar := range strings.Split(f.EnvVar, ",") { |
||||
envVar = strings.TrimSpace(envVar) |
||||
if envVal := os.Getenv(envVar); envVal != "" { |
||||
val.Set(envVal) |
||||
break |
||||
} |
||||
} |
||||
} |
||||
|
||||
eachName(f.Name, func(name string) { |
||||
set.Var(f.Value, name, f.Usage) |
||||
}) |
||||
} |
||||
|
||||
func (f GenericFlag) getName() string { |
||||
return f.Name |
||||
} |
||||
|
||||
type StringSlice []string |
||||
|
||||
func (f *StringSlice) Set(value string) error { |
||||
*f = append(*f, value) |
||||
return nil |
||||
} |
||||
|
||||
func (f *StringSlice) String() string { |
||||
return fmt.Sprintf("%s", *f) |
||||
} |
||||
|
||||
func (f *StringSlice) Value() []string { |
||||
return *f |
||||
} |
||||
|
||||
type StringSliceFlag struct { |
||||
Name string |
||||
Value *StringSlice |
||||
Usage string |
||||
EnvVar string |
||||
} |
||||
|
||||
func (f StringSliceFlag) String() string { |
||||
firstName := strings.Trim(strings.Split(f.Name, ",")[0], " ") |
||||
pref := prefixFor(firstName) |
||||
return withEnvHint(f.EnvVar, fmt.Sprintf("%s [%v]\t%v", prefixedNames(f.Name), pref+firstName+" option "+pref+firstName+" option", f.Usage)) |
||||
} |
||||
|
||||
func (f StringSliceFlag) Apply(set *flag.FlagSet) { |
||||
if f.EnvVar != "" { |
||||
for _, envVar := range strings.Split(f.EnvVar, ",") { |
||||
envVar = strings.TrimSpace(envVar) |
||||
if envVal := os.Getenv(envVar); envVal != "" { |
||||
newVal := &StringSlice{} |
||||
for _, s := range strings.Split(envVal, ",") { |
||||
s = strings.TrimSpace(s) |
||||
newVal.Set(s) |
||||
} |
||||
f.Value = newVal |
||||
break |
||||
} |
||||
} |
||||
} |
||||
|
||||
eachName(f.Name, func(name string) { |
||||
set.Var(f.Value, name, f.Usage) |
||||
}) |
||||
} |
||||
|
||||
func (f StringSliceFlag) getName() string { |
||||
return f.Name |
||||
} |
||||
|
||||
type IntSlice []int |
||||
|
||||
func (f *IntSlice) Set(value string) error { |
||||
|
||||
tmp, err := strconv.Atoi(value) |
||||
if err != nil { |
||||
return err |
||||
} else { |
||||
*f = append(*f, tmp) |
||||
} |
||||
return nil |
||||
} |
||||
|
||||
func (f *IntSlice) String() string { |
||||
return fmt.Sprintf("%d", *f) |
||||
} |
||||
|
||||
func (f *IntSlice) Value() []int { |
||||
return *f |
||||
} |
||||
|
||||
type IntSliceFlag struct { |
||||
Name string |
||||
Value *IntSlice |
||||
Usage string |
||||
EnvVar string |
||||
} |
||||
|
||||
func (f IntSliceFlag) String() string { |
||||
firstName := strings.Trim(strings.Split(f.Name, ",")[0], " ") |
||||
pref := prefixFor(firstName) |
||||
return withEnvHint(f.EnvVar, fmt.Sprintf("%s [%v]\t%v", prefixedNames(f.Name), pref+firstName+" option "+pref+firstName+" option", f.Usage)) |
||||
} |
||||
|
||||
func (f IntSliceFlag) Apply(set *flag.FlagSet) { |
||||
if f.EnvVar != "" { |
||||
for _, envVar := range strings.Split(f.EnvVar, ",") { |
||||
envVar = strings.TrimSpace(envVar) |
||||
if envVal := os.Getenv(envVar); envVal != "" { |
||||
newVal := &IntSlice{} |
||||
for _, s := range strings.Split(envVal, ",") { |
||||
s = strings.TrimSpace(s) |
||||
err := newVal.Set(s) |
||||
if err != nil { |
||||
fmt.Fprintf(os.Stderr, err.Error()) |
||||
} |
||||
} |
||||
f.Value = newVal |
||||
break |
||||
} |
||||
} |
||||
} |
||||
|
||||
eachName(f.Name, func(name string) { |
||||
set.Var(f.Value, name, f.Usage) |
||||
}) |
||||
} |
||||
|
||||
func (f IntSliceFlag) getName() string { |
||||
return f.Name |
||||
} |
||||
|
||||
type BoolFlag struct { |
||||
Name string |
||||
Usage string |
||||
EnvVar string |
||||
} |
||||
|
||||
func (f BoolFlag) String() string { |
||||
return withEnvHint(f.EnvVar, fmt.Sprintf("%s\t%v", prefixedNames(f.Name), f.Usage)) |
||||
} |
||||
|
||||
func (f BoolFlag) Apply(set *flag.FlagSet) { |
||||
val := false |
||||
if f.EnvVar != "" { |
||||
for _, envVar := range strings.Split(f.EnvVar, ",") { |
||||
envVar = strings.TrimSpace(envVar) |
||||
if envVal := os.Getenv(envVar); envVal != "" { |
||||
envValBool, err := strconv.ParseBool(envVal) |
||||
if err == nil { |
||||
val = envValBool |
||||
} |
||||
break |
||||
} |
||||
} |
||||
} |
||||
|
||||
eachName(f.Name, func(name string) { |
||||
set.Bool(name, val, f.Usage) |
||||
}) |
||||
} |
||||
|
||||
func (f BoolFlag) getName() string { |
||||
return f.Name |
||||
} |
||||
|
||||
type BoolTFlag struct { |
||||
Name string |
||||
Usage string |
||||
EnvVar string |
||||
} |
||||
|
||||
func (f BoolTFlag) String() string { |
||||
return withEnvHint(f.EnvVar, fmt.Sprintf("%s\t%v", prefixedNames(f.Name), f.Usage)) |
||||
} |
||||
|
||||
func (f BoolTFlag) Apply(set *flag.FlagSet) { |
||||
val := true |
||||
if f.EnvVar != "" { |
||||
for _, envVar := range strings.Split(f.EnvVar, ",") { |
||||
envVar = strings.TrimSpace(envVar) |
||||
if envVal := os.Getenv(envVar); envVal != "" { |
||||
envValBool, err := strconv.ParseBool(envVal) |
||||
if err == nil { |
||||
val = envValBool |
||||
break |
||||
} |
||||
} |
||||
} |
||||
} |
||||
|
||||
eachName(f.Name, func(name string) { |
||||
set.Bool(name, val, f.Usage) |
||||
}) |
||||
} |
||||
|
||||
func (f BoolTFlag) getName() string { |
||||
return f.Name |
||||
} |
||||
|
||||
type StringFlag struct { |
||||
Name string |
||||
Value string |
||||
Usage string |
||||
EnvVar string |
||||
} |
||||
|
||||
func (f StringFlag) String() string { |
||||
var fmtString string |
||||
fmtString = "%s %v\t%v" |
||||
|
||||
if len(f.Value) > 0 { |
||||
fmtString = "%s \"%v\"\t%v" |
||||
} else { |
||||
fmtString = "%s %v\t%v" |
||||
} |
||||
|
||||
return withEnvHint(f.EnvVar, fmt.Sprintf(fmtString, prefixedNames(f.Name), f.Value, f.Usage)) |
||||
} |
||||
|
||||
func (f StringFlag) Apply(set *flag.FlagSet) { |
||||
if f.EnvVar != "" { |
||||
for _, envVar := range strings.Split(f.EnvVar, ",") { |
||||
envVar = strings.TrimSpace(envVar) |
||||
if envVal := os.Getenv(envVar); envVal != "" { |
||||
f.Value = envVal |
||||
break |
||||
} |
||||
} |
||||
} |
||||
|
||||
eachName(f.Name, func(name string) { |
||||
set.String(name, f.Value, f.Usage) |
||||
}) |
||||
} |
||||
|
||||
func (f StringFlag) getName() string { |
||||
return f.Name |
||||
} |
||||
|
||||
type IntFlag struct { |
||||
Name string |
||||
Value int |
||||
Usage string |
||||
EnvVar string |
||||
} |
||||
|
||||
func (f IntFlag) String() string { |
||||
return withEnvHint(f.EnvVar, fmt.Sprintf("%s \"%v\"\t%v", prefixedNames(f.Name), f.Value, f.Usage)) |
||||
} |
||||
|
||||
func (f IntFlag) Apply(set *flag.FlagSet) { |
||||
if f.EnvVar != "" { |
||||
for _, envVar := range strings.Split(f.EnvVar, ",") { |
||||
envVar = strings.TrimSpace(envVar) |
||||
if envVal := os.Getenv(envVar); envVal != "" { |
||||
envValInt, err := strconv.ParseInt(envVal, 0, 64) |
||||
if err == nil { |
||||
f.Value = int(envValInt) |
||||
break |
||||
} |
||||
} |
||||
} |
||||
} |
||||
|
||||
eachName(f.Name, func(name string) { |
||||
set.Int(name, f.Value, f.Usage) |
||||
}) |
||||
} |
||||
|
||||
func (f IntFlag) getName() string { |
||||
return f.Name |
||||
} |
||||
|
||||
type DurationFlag struct { |
||||
Name string |
||||
Value time.Duration |
||||
Usage string |
||||
EnvVar string |
||||
} |
||||
|
||||
func (f DurationFlag) String() string { |
||||
return withEnvHint(f.EnvVar, fmt.Sprintf("%s \"%v\"\t%v", prefixedNames(f.Name), f.Value, f.Usage)) |
||||
} |
||||
|
||||
func (f DurationFlag) Apply(set *flag.FlagSet) { |
||||
if f.EnvVar != "" { |
||||
for _, envVar := range strings.Split(f.EnvVar, ",") { |
||||
envVar = strings.TrimSpace(envVar) |
||||
if envVal := os.Getenv(envVar); envVal != "" { |
||||
envValDuration, err := time.ParseDuration(envVal) |
||||
if err == nil { |
||||
f.Value = envValDuration |
||||
break |
||||
} |
||||
} |
||||
} |
||||
} |
||||
|
||||
eachName(f.Name, func(name string) { |
||||
set.Duration(name, f.Value, f.Usage) |
||||
}) |
||||
} |
||||
|
||||
func (f DurationFlag) getName() string { |
||||
return f.Name |
||||
} |
||||
|
||||
type Float64Flag struct { |
||||
Name string |
||||
Value float64 |
||||
Usage string |
||||
EnvVar string |
||||
} |
||||
|
||||
func (f Float64Flag) String() string { |
||||
return withEnvHint(f.EnvVar, fmt.Sprintf("%s \"%v\"\t%v", prefixedNames(f.Name), f.Value, f.Usage)) |
||||
} |
||||
|
||||
func (f Float64Flag) Apply(set *flag.FlagSet) { |
||||
if f.EnvVar != "" { |
||||
for _, envVar := range strings.Split(f.EnvVar, ",") { |
||||
envVar = strings.TrimSpace(envVar) |
||||
if envVal := os.Getenv(envVar); envVal != "" { |
||||
envValFloat, err := strconv.ParseFloat(envVal, 10) |
||||
if err == nil { |
||||
f.Value = float64(envValFloat) |
||||
} |
||||
} |
||||
} |
||||
} |
||||
|
||||
eachName(f.Name, func(name string) { |
||||
set.Float64(name, f.Value, f.Usage) |
||||
}) |
||||
} |
||||
|
||||
func (f Float64Flag) getName() string { |
||||
return f.Name |
||||
} |
||||
|
||||
func prefixFor(name string) (prefix string) { |
||||
if len(name) == 1 { |
||||
prefix = "-" |
||||
} else { |
||||
prefix = "--" |
||||
} |
||||
|
||||
return |
||||
} |
||||
|
||||
func prefixedNames(fullName string) (prefixed string) { |
||||
parts := strings.Split(fullName, ",") |
||||
for i, name := range parts { |
||||
name = strings.Trim(name, " ") |
||||
prefixed += prefixFor(name) + name |
||||
if i < len(parts)-1 { |
||||
prefixed += ", " |
||||
} |
||||
} |
||||
return |
||||
} |
||||
|
||||
func withEnvHint(envVar, str string) string { |
||||
envText := "" |
||||
if envVar != "" { |
||||
envText = fmt.Sprintf(" [$%s]", strings.Join(strings.Split(envVar, ","), ", $")) |
||||
} |
||||
return str + envText |
||||
} |
@ -0,0 +1,742 @@ |
||||
package cli_test |
||||
|
||||
import ( |
||||
"fmt" |
||||
"os" |
||||
"reflect" |
||||
"strings" |
||||
"testing" |
||||
|
||||
"github.com/codegangsta/cli" |
||||
) |
||||
|
||||
var boolFlagTests = []struct { |
||||
name string |
||||
expected string |
||||
}{ |
||||
{"help", "--help\t"}, |
||||
{"h", "-h\t"}, |
||||
} |
||||
|
||||
func TestBoolFlagHelpOutput(t *testing.T) { |
||||
|
||||
for _, test := range boolFlagTests { |
||||
flag := cli.BoolFlag{Name: test.name} |
||||
output := flag.String() |
||||
|
||||
if output != test.expected { |
||||
t.Errorf("%s does not match %s", output, test.expected) |
||||
} |
||||
} |
||||
} |
||||
|
||||
var stringFlagTests = []struct { |
||||
name string |
||||
value string |
||||
expected string |
||||
}{ |
||||
{"help", "", "--help \t"}, |
||||
{"h", "", "-h \t"}, |
||||
{"h", "", "-h \t"}, |
||||
{"test", "Something", "--test \"Something\"\t"}, |
||||
} |
||||
|
||||
func TestStringFlagHelpOutput(t *testing.T) { |
||||
|
||||
for _, test := range stringFlagTests { |
||||
flag := cli.StringFlag{Name: test.name, Value: test.value} |
||||
output := flag.String() |
||||
|
||||
if output != test.expected { |
||||
t.Errorf("%s does not match %s", output, test.expected) |
||||
} |
||||
} |
||||
} |
||||
|
||||
func TestStringFlagWithEnvVarHelpOutput(t *testing.T) { |
||||
os.Clearenv() |
||||
os.Setenv("APP_FOO", "derp") |
||||
for _, test := range stringFlagTests { |
||||
flag := cli.StringFlag{Name: test.name, Value: test.value, EnvVar: "APP_FOO"} |
||||
output := flag.String() |
||||
|
||||
if !strings.HasSuffix(output, " [$APP_FOO]") { |
||||
t.Errorf("%s does not end with [$APP_FOO]", output) |
||||
} |
||||
} |
||||
} |
||||
|
||||
var stringSliceFlagTests = []struct { |
||||
name string |
||||
value *cli.StringSlice |
||||
expected string |
||||
}{ |
||||
{"help", func() *cli.StringSlice { |
||||
s := &cli.StringSlice{} |
||||
s.Set("") |
||||
return s |
||||
}(), "--help [--help option --help option]\t"}, |
||||
{"h", func() *cli.StringSlice { |
||||
s := &cli.StringSlice{} |
||||
s.Set("") |
||||
return s |
||||
}(), "-h [-h option -h option]\t"}, |
||||
{"h", func() *cli.StringSlice { |
||||
s := &cli.StringSlice{} |
||||
s.Set("") |
||||
return s |
||||
}(), "-h [-h option -h option]\t"}, |
||||
{"test", func() *cli.StringSlice { |
||||
s := &cli.StringSlice{} |
||||
s.Set("Something") |
||||
return s |
||||
}(), "--test [--test option --test option]\t"}, |
||||
} |
||||
|
||||
func TestStringSliceFlagHelpOutput(t *testing.T) { |
||||
|
||||
for _, test := range stringSliceFlagTests { |
||||
flag := cli.StringSliceFlag{Name: test.name, Value: test.value} |
||||
output := flag.String() |
||||
|
||||
if output != test.expected { |
||||
t.Errorf("%q does not match %q", output, test.expected) |
||||
} |
||||
} |
||||
} |
||||
|
||||
func TestStringSliceFlagWithEnvVarHelpOutput(t *testing.T) { |
||||
os.Clearenv() |
||||
os.Setenv("APP_QWWX", "11,4") |
||||
for _, test := range stringSliceFlagTests { |
||||
flag := cli.StringSliceFlag{Name: test.name, Value: test.value, EnvVar: "APP_QWWX"} |
||||
output := flag.String() |
||||
|
||||
if !strings.HasSuffix(output, " [$APP_QWWX]") { |
||||
t.Errorf("%q does not end with [$APP_QWWX]", output) |
||||
} |
||||
} |
||||
} |
||||
|
||||
var intFlagTests = []struct { |
||||
name string |
||||
expected string |
||||
}{ |
||||
{"help", "--help \"0\"\t"}, |
||||
{"h", "-h \"0\"\t"}, |
||||
} |
||||
|
||||
func TestIntFlagHelpOutput(t *testing.T) { |
||||
|
||||
for _, test := range intFlagTests { |
||||
flag := cli.IntFlag{Name: test.name} |
||||
output := flag.String() |
||||
|
||||
if output != test.expected { |
||||
t.Errorf("%s does not match %s", output, test.expected) |
||||
} |
||||
} |
||||
} |
||||
|
||||
func TestIntFlagWithEnvVarHelpOutput(t *testing.T) { |
||||
os.Clearenv() |
||||
os.Setenv("APP_BAR", "2") |
||||
for _, test := range intFlagTests { |
||||
flag := cli.IntFlag{Name: test.name, EnvVar: "APP_BAR"} |
||||
output := flag.String() |
||||
|
||||
if !strings.HasSuffix(output, " [$APP_BAR]") { |
||||
t.Errorf("%s does not end with [$APP_BAR]", output) |
||||
} |
||||
} |
||||
} |
||||
|
||||
var durationFlagTests = []struct { |
||||
name string |
||||
expected string |
||||
}{ |
||||
{"help", "--help \"0\"\t"}, |
||||
{"h", "-h \"0\"\t"}, |
||||
} |
||||
|
||||
func TestDurationFlagHelpOutput(t *testing.T) { |
||||
|
||||
for _, test := range durationFlagTests { |
||||
flag := cli.DurationFlag{Name: test.name} |
||||
output := flag.String() |
||||
|
||||
if output != test.expected { |
||||
t.Errorf("%s does not match %s", output, test.expected) |
||||
} |
||||
} |
||||
} |
||||
|
||||
func TestDurationFlagWithEnvVarHelpOutput(t *testing.T) { |
||||
os.Clearenv() |
||||
os.Setenv("APP_BAR", "2h3m6s") |
||||
for _, test := range durationFlagTests { |
||||
flag := cli.DurationFlag{Name: test.name, EnvVar: "APP_BAR"} |
||||
output := flag.String() |
||||
|
||||
if !strings.HasSuffix(output, " [$APP_BAR]") { |
||||
t.Errorf("%s does not end with [$APP_BAR]", output) |
||||
} |
||||
} |
||||
} |
||||
|
||||
var intSliceFlagTests = []struct { |
||||
name string |
||||
value *cli.IntSlice |
||||
expected string |
||||
}{ |
||||
{"help", &cli.IntSlice{}, "--help [--help option --help option]\t"}, |
||||
{"h", &cli.IntSlice{}, "-h [-h option -h option]\t"}, |
||||
{"h", &cli.IntSlice{}, "-h [-h option -h option]\t"}, |
||||
{"test", func() *cli.IntSlice { |
||||
i := &cli.IntSlice{} |
||||
i.Set("9") |
||||
return i |
||||
}(), "--test [--test option --test option]\t"}, |
||||
} |
||||
|
||||
func TestIntSliceFlagHelpOutput(t *testing.T) { |
||||
|
||||
for _, test := range intSliceFlagTests { |
||||
flag := cli.IntSliceFlag{Name: test.name, Value: test.value} |
||||
output := flag.String() |
||||
|
||||
if output != test.expected { |
||||
t.Errorf("%q does not match %q", output, test.expected) |
||||
} |
||||
} |
||||
} |
||||
|
||||
func TestIntSliceFlagWithEnvVarHelpOutput(t *testing.T) { |
||||
os.Clearenv() |
||||
os.Setenv("APP_SMURF", "42,3") |
||||
for _, test := range intSliceFlagTests { |
||||
flag := cli.IntSliceFlag{Name: test.name, Value: test.value, EnvVar: "APP_SMURF"} |
||||
output := flag.String() |
||||
|
||||
if !strings.HasSuffix(output, " [$APP_SMURF]") { |
||||
t.Errorf("%q does not end with [$APP_SMURF]", output) |
||||
} |
||||
} |
||||
} |
||||
|
||||
var float64FlagTests = []struct { |
||||
name string |
||||
expected string |
||||
}{ |
||||
{"help", "--help \"0\"\t"}, |
||||
{"h", "-h \"0\"\t"}, |
||||
} |
||||
|
||||
func TestFloat64FlagHelpOutput(t *testing.T) { |
||||
|
||||
for _, test := range float64FlagTests { |
||||
flag := cli.Float64Flag{Name: test.name} |
||||
output := flag.String() |
||||
|
||||
if output != test.expected { |
||||
t.Errorf("%s does not match %s", output, test.expected) |
||||
} |
||||
} |
||||
} |
||||
|
||||
func TestFloat64FlagWithEnvVarHelpOutput(t *testing.T) { |
||||
os.Clearenv() |
||||
os.Setenv("APP_BAZ", "99.4") |
||||
for _, test := range float64FlagTests { |
||||
flag := cli.Float64Flag{Name: test.name, EnvVar: "APP_BAZ"} |
||||
output := flag.String() |
||||
|
||||
if !strings.HasSuffix(output, " [$APP_BAZ]") { |
||||
t.Errorf("%s does not end with [$APP_BAZ]", output) |
||||
} |
||||
} |
||||
} |
||||
|
||||
var genericFlagTests = []struct { |
||||
name string |
||||
value cli.Generic |
||||
expected string |
||||
}{ |
||||
{"test", &Parser{"abc", "def"}, "--test \"abc,def\"\ttest flag"}, |
||||
{"t", &Parser{"abc", "def"}, "-t \"abc,def\"\ttest flag"}, |
||||
} |
||||
|
||||
func TestGenericFlagHelpOutput(t *testing.T) { |
||||
|
||||
for _, test := range genericFlagTests { |
||||
flag := cli.GenericFlag{Name: test.name, Value: test.value, Usage: "test flag"} |
||||
output := flag.String() |
||||
|
||||
if output != test.expected { |
||||
t.Errorf("%q does not match %q", output, test.expected) |
||||
} |
||||
} |
||||
} |
||||
|
||||
func TestGenericFlagWithEnvVarHelpOutput(t *testing.T) { |
||||
os.Clearenv() |
||||
os.Setenv("APP_ZAP", "3") |
||||
for _, test := range genericFlagTests { |
||||
flag := cli.GenericFlag{Name: test.name, EnvVar: "APP_ZAP"} |
||||
output := flag.String() |
||||
|
||||
if !strings.HasSuffix(output, " [$APP_ZAP]") { |
||||
t.Errorf("%s does not end with [$APP_ZAP]", output) |
||||
} |
||||
} |
||||
} |
||||
|
||||
func TestParseMultiString(t *testing.T) { |
||||
(&cli.App{ |
||||
Flags: []cli.Flag{ |
||||
cli.StringFlag{Name: "serve, s"}, |
||||
}, |
||||
Action: func(ctx *cli.Context) { |
||||
if ctx.String("serve") != "10" { |
||||
t.Errorf("main name not set") |
||||
} |
||||
if ctx.String("s") != "10" { |
||||
t.Errorf("short name not set") |
||||
} |
||||
}, |
||||
}).Run([]string{"run", "-s", "10"}) |
||||
} |
||||
|
||||
func TestParseMultiStringFromEnv(t *testing.T) { |
||||
os.Clearenv() |
||||
os.Setenv("APP_COUNT", "20") |
||||
(&cli.App{ |
||||
Flags: []cli.Flag{ |
||||
cli.StringFlag{Name: "count, c", EnvVar: "APP_COUNT"}, |
||||
}, |
||||
Action: func(ctx *cli.Context) { |
||||
if ctx.String("count") != "20" { |
||||
t.Errorf("main name not set") |
||||
} |
||||
if ctx.String("c") != "20" { |
||||
t.Errorf("short name not set") |
||||
} |
||||
}, |
||||
}).Run([]string{"run"}) |
||||
} |
||||
|
||||
func TestParseMultiStringFromEnvCascade(t *testing.T) { |
||||
os.Clearenv() |
||||
os.Setenv("APP_COUNT", "20") |
||||
(&cli.App{ |
||||
Flags: []cli.Flag{ |
||||
cli.StringFlag{Name: "count, c", EnvVar: "COMPAT_COUNT,APP_COUNT"}, |
||||
}, |
||||
Action: func(ctx *cli.Context) { |
||||
if ctx.String("count") != "20" { |
||||
t.Errorf("main name not set") |
||||
} |
||||
if ctx.String("c") != "20" { |
||||
t.Errorf("short name not set") |
||||
} |
||||
}, |
||||
}).Run([]string{"run"}) |
||||
} |
||||
|
||||
func TestParseMultiStringSlice(t *testing.T) { |
||||
(&cli.App{ |
||||
Flags: []cli.Flag{ |
||||
cli.StringSliceFlag{Name: "serve, s", Value: &cli.StringSlice{}}, |
||||
}, |
||||
Action: func(ctx *cli.Context) { |
||||
if !reflect.DeepEqual(ctx.StringSlice("serve"), []string{"10", "20"}) { |
||||
t.Errorf("main name not set") |
||||
} |
||||
if !reflect.DeepEqual(ctx.StringSlice("s"), []string{"10", "20"}) { |
||||
t.Errorf("short name not set") |
||||
} |
||||
}, |
||||
}).Run([]string{"run", "-s", "10", "-s", "20"}) |
||||
} |
||||
|
||||
func TestParseMultiStringSliceFromEnv(t *testing.T) { |
||||
os.Clearenv() |
||||
os.Setenv("APP_INTERVALS", "20,30,40") |
||||
|
||||
(&cli.App{ |
||||
Flags: []cli.Flag{ |
||||
cli.StringSliceFlag{Name: "intervals, i", Value: &cli.StringSlice{}, EnvVar: "APP_INTERVALS"}, |
||||
}, |
||||
Action: func(ctx *cli.Context) { |
||||
if !reflect.DeepEqual(ctx.StringSlice("intervals"), []string{"20", "30", "40"}) { |
||||
t.Errorf("main name not set from env") |
||||
} |
||||
if !reflect.DeepEqual(ctx.StringSlice("i"), []string{"20", "30", "40"}) { |
||||
t.Errorf("short name not set from env") |
||||
} |
||||
}, |
||||
}).Run([]string{"run"}) |
||||
} |
||||
|
||||
func TestParseMultiStringSliceFromEnvCascade(t *testing.T) { |
||||
os.Clearenv() |
||||
os.Setenv("APP_INTERVALS", "20,30,40") |
||||
|
||||
(&cli.App{ |
||||
Flags: []cli.Flag{ |
||||
cli.StringSliceFlag{Name: "intervals, i", Value: &cli.StringSlice{}, EnvVar: "COMPAT_INTERVALS,APP_INTERVALS"}, |
||||
}, |
||||
Action: func(ctx *cli.Context) { |
||||
if !reflect.DeepEqual(ctx.StringSlice("intervals"), []string{"20", "30", "40"}) { |
||||
t.Errorf("main name not set from env") |
||||
} |
||||
if !reflect.DeepEqual(ctx.StringSlice("i"), []string{"20", "30", "40"}) { |
||||
t.Errorf("short name not set from env") |
||||
} |
||||
}, |
||||
}).Run([]string{"run"}) |
||||
} |
||||
|
||||
func TestParseMultiInt(t *testing.T) { |
||||
a := cli.App{ |
||||
Flags: []cli.Flag{ |
||||
cli.IntFlag{Name: "serve, s"}, |
||||
}, |
||||
Action: func(ctx *cli.Context) { |
||||
if ctx.Int("serve") != 10 { |
||||
t.Errorf("main name not set") |
||||
} |
||||
if ctx.Int("s") != 10 { |
||||
t.Errorf("short name not set") |
||||
} |
||||
}, |
||||
} |
||||
a.Run([]string{"run", "-s", "10"}) |
||||
} |
||||
|
||||
func TestParseMultiIntFromEnv(t *testing.T) { |
||||
os.Clearenv() |
||||
os.Setenv("APP_TIMEOUT_SECONDS", "10") |
||||
a := cli.App{ |
||||
Flags: []cli.Flag{ |
||||
cli.IntFlag{Name: "timeout, t", EnvVar: "APP_TIMEOUT_SECONDS"}, |
||||
}, |
||||
Action: func(ctx *cli.Context) { |
||||
if ctx.Int("timeout") != 10 { |
||||
t.Errorf("main name not set") |
||||
} |
||||
if ctx.Int("t") != 10 { |
||||
t.Errorf("short name not set") |
||||
} |
||||
}, |
||||
} |
||||
a.Run([]string{"run"}) |
||||
} |
||||
|
||||
func TestParseMultiIntFromEnvCascade(t *testing.T) { |
||||
os.Clearenv() |
||||
os.Setenv("APP_TIMEOUT_SECONDS", "10") |
||||
a := cli.App{ |
||||
Flags: []cli.Flag{ |
||||
cli.IntFlag{Name: "timeout, t", EnvVar: "COMPAT_TIMEOUT_SECONDS,APP_TIMEOUT_SECONDS"}, |
||||
}, |
||||
Action: func(ctx *cli.Context) { |
||||
if ctx.Int("timeout") != 10 { |
||||
t.Errorf("main name not set") |
||||
} |
||||
if ctx.Int("t") != 10 { |
||||
t.Errorf("short name not set") |
||||
} |
||||
}, |
||||
} |
||||
a.Run([]string{"run"}) |
||||
} |
||||
|
||||
func TestParseMultiIntSlice(t *testing.T) { |
||||
(&cli.App{ |
||||
Flags: []cli.Flag{ |
||||
cli.IntSliceFlag{Name: "serve, s", Value: &cli.IntSlice{}}, |
||||
}, |
||||
Action: func(ctx *cli.Context) { |
||||
if !reflect.DeepEqual(ctx.IntSlice("serve"), []int{10, 20}) { |
||||
t.Errorf("main name not set") |
||||
} |
||||
if !reflect.DeepEqual(ctx.IntSlice("s"), []int{10, 20}) { |
||||
t.Errorf("short name not set") |
||||
} |
||||
}, |
||||
}).Run([]string{"run", "-s", "10", "-s", "20"}) |
||||
} |
||||
|
||||
func TestParseMultiIntSliceFromEnv(t *testing.T) { |
||||
os.Clearenv() |
||||
os.Setenv("APP_INTERVALS", "20,30,40") |
||||
|
||||
(&cli.App{ |
||||
Flags: []cli.Flag{ |
||||
cli.IntSliceFlag{Name: "intervals, i", Value: &cli.IntSlice{}, EnvVar: "APP_INTERVALS"}, |
||||
}, |
||||
Action: func(ctx *cli.Context) { |
||||
if !reflect.DeepEqual(ctx.IntSlice("intervals"), []int{20, 30, 40}) { |
||||
t.Errorf("main name not set from env") |
||||
} |
||||
if !reflect.DeepEqual(ctx.IntSlice("i"), []int{20, 30, 40}) { |
||||
t.Errorf("short name not set from env") |
||||
} |
||||
}, |
||||
}).Run([]string{"run"}) |
||||
} |
||||
|
||||
func TestParseMultiIntSliceFromEnvCascade(t *testing.T) { |
||||
os.Clearenv() |
||||
os.Setenv("APP_INTERVALS", "20,30,40") |
||||
|
||||
(&cli.App{ |
||||
Flags: []cli.Flag{ |
||||
cli.IntSliceFlag{Name: "intervals, i", Value: &cli.IntSlice{}, EnvVar: "COMPAT_INTERVALS,APP_INTERVALS"}, |
||||
}, |
||||
Action: func(ctx *cli.Context) { |
||||
if !reflect.DeepEqual(ctx.IntSlice("intervals"), []int{20, 30, 40}) { |
||||
t.Errorf("main name not set from env") |
||||
} |
||||
if !reflect.DeepEqual(ctx.IntSlice("i"), []int{20, 30, 40}) { |
||||
t.Errorf("short name not set from env") |
||||
} |
||||
}, |
||||
}).Run([]string{"run"}) |
||||
} |
||||
|
||||
func TestParseMultiFloat64(t *testing.T) { |
||||
a := cli.App{ |
||||
Flags: []cli.Flag{ |
||||
cli.Float64Flag{Name: "serve, s"}, |
||||
}, |
||||
Action: func(ctx *cli.Context) { |
||||
if ctx.Float64("serve") != 10.2 { |
||||
t.Errorf("main name not set") |
||||
} |
||||
if ctx.Float64("s") != 10.2 { |
||||
t.Errorf("short name not set") |
||||
} |
||||
}, |
||||
} |
||||
a.Run([]string{"run", "-s", "10.2"}) |
||||
} |
||||
|
||||
func TestParseMultiFloat64FromEnv(t *testing.T) { |
||||
os.Clearenv() |
||||
os.Setenv("APP_TIMEOUT_SECONDS", "15.5") |
||||
a := cli.App{ |
||||
Flags: []cli.Flag{ |
||||
cli.Float64Flag{Name: "timeout, t", EnvVar: "APP_TIMEOUT_SECONDS"}, |
||||
}, |
||||
Action: func(ctx *cli.Context) { |
||||
if ctx.Float64("timeout") != 15.5 { |
||||
t.Errorf("main name not set") |
||||
} |
||||
if ctx.Float64("t") != 15.5 { |
||||
t.Errorf("short name not set") |
||||
} |
||||
}, |
||||
} |
||||
a.Run([]string{"run"}) |
||||
} |
||||
|
||||
func TestParseMultiFloat64FromEnvCascade(t *testing.T) { |
||||
os.Clearenv() |
||||
os.Setenv("APP_TIMEOUT_SECONDS", "15.5") |
||||
a := cli.App{ |
||||
Flags: []cli.Flag{ |
||||
cli.Float64Flag{Name: "timeout, t", EnvVar: "COMPAT_TIMEOUT_SECONDS,APP_TIMEOUT_SECONDS"}, |
||||
}, |
||||
Action: func(ctx *cli.Context) { |
||||
if ctx.Float64("timeout") != 15.5 { |
||||
t.Errorf("main name not set") |
||||
} |
||||
if ctx.Float64("t") != 15.5 { |
||||
t.Errorf("short name not set") |
||||
} |
||||
}, |
||||
} |
||||
a.Run([]string{"run"}) |
||||
} |
||||
|
||||
func TestParseMultiBool(t *testing.T) { |
||||
a := cli.App{ |
||||
Flags: []cli.Flag{ |
||||
cli.BoolFlag{Name: "serve, s"}, |
||||
}, |
||||
Action: func(ctx *cli.Context) { |
||||
if ctx.Bool("serve") != true { |
||||
t.Errorf("main name not set") |
||||
} |
||||
if ctx.Bool("s") != true { |
||||
t.Errorf("short name not set") |
||||
} |
||||
}, |
||||
} |
||||
a.Run([]string{"run", "--serve"}) |
||||
} |
||||
|
||||
func TestParseMultiBoolFromEnv(t *testing.T) { |
||||
os.Clearenv() |
||||
os.Setenv("APP_DEBUG", "1") |
||||
a := cli.App{ |
||||
Flags: []cli.Flag{ |
||||
cli.BoolFlag{Name: "debug, d", EnvVar: "APP_DEBUG"}, |
||||
}, |
||||
Action: func(ctx *cli.Context) { |
||||
if ctx.Bool("debug") != true { |
||||
t.Errorf("main name not set from env") |
||||
} |
||||
if ctx.Bool("d") != true { |
||||
t.Errorf("short name not set from env") |
||||
} |
||||
}, |
||||
} |
||||
a.Run([]string{"run"}) |
||||
} |
||||
|
||||
func TestParseMultiBoolFromEnvCascade(t *testing.T) { |
||||
os.Clearenv() |
||||
os.Setenv("APP_DEBUG", "1") |
||||
a := cli.App{ |
||||
Flags: []cli.Flag{ |
||||
cli.BoolFlag{Name: "debug, d", EnvVar: "COMPAT_DEBUG,APP_DEBUG"}, |
||||
}, |
||||
Action: func(ctx *cli.Context) { |
||||
if ctx.Bool("debug") != true { |
||||
t.Errorf("main name not set from env") |
||||
} |
||||
if ctx.Bool("d") != true { |
||||
t.Errorf("short name not set from env") |
||||
} |
||||
}, |
||||
} |
||||
a.Run([]string{"run"}) |
||||
} |
||||
|
||||
func TestParseMultiBoolT(t *testing.T) { |
||||
a := cli.App{ |
||||
Flags: []cli.Flag{ |
||||
cli.BoolTFlag{Name: "serve, s"}, |
||||
}, |
||||
Action: func(ctx *cli.Context) { |
||||
if ctx.BoolT("serve") != true { |
||||
t.Errorf("main name not set") |
||||
} |
||||
if ctx.BoolT("s") != true { |
||||
t.Errorf("short name not set") |
||||
} |
||||
}, |
||||
} |
||||
a.Run([]string{"run", "--serve"}) |
||||
} |
||||
|
||||
func TestParseMultiBoolTFromEnv(t *testing.T) { |
||||
os.Clearenv() |
||||
os.Setenv("APP_DEBUG", "0") |
||||
a := cli.App{ |
||||
Flags: []cli.Flag{ |
||||
cli.BoolTFlag{Name: "debug, d", EnvVar: "APP_DEBUG"}, |
||||
}, |
||||
Action: func(ctx *cli.Context) { |
||||
if ctx.BoolT("debug") != false { |
||||
t.Errorf("main name not set from env") |
||||
} |
||||
if ctx.BoolT("d") != false { |
||||
t.Errorf("short name not set from env") |
||||
} |
||||
}, |
||||
} |
||||
a.Run([]string{"run"}) |
||||
} |
||||
|
||||
func TestParseMultiBoolTFromEnvCascade(t *testing.T) { |
||||
os.Clearenv() |
||||
os.Setenv("APP_DEBUG", "0") |
||||
a := cli.App{ |
||||
Flags: []cli.Flag{ |
||||
cli.BoolTFlag{Name: "debug, d", EnvVar: "COMPAT_DEBUG,APP_DEBUG"}, |
||||
}, |
||||
Action: func(ctx *cli.Context) { |
||||
if ctx.BoolT("debug") != false { |
||||
t.Errorf("main name not set from env") |
||||
} |
||||
if ctx.BoolT("d") != false { |
||||
t.Errorf("short name not set from env") |
||||
} |
||||
}, |
||||
} |
||||
a.Run([]string{"run"}) |
||||
} |
||||
|
||||
type Parser [2]string |
||||
|
||||
func (p *Parser) Set(value string) error { |
||||
parts := strings.Split(value, ",") |
||||
if len(parts) != 2 { |
||||
return fmt.Errorf("invalid format") |
||||
} |
||||
|
||||
(*p)[0] = parts[0] |
||||
(*p)[1] = parts[1] |
||||
|
||||
return nil |
||||
} |
||||
|
||||
func (p *Parser) String() string { |
||||
return fmt.Sprintf("%s,%s", p[0], p[1]) |
||||
} |
||||
|
||||
func TestParseGeneric(t *testing.T) { |
||||
a := cli.App{ |
||||
Flags: []cli.Flag{ |
||||
cli.GenericFlag{Name: "serve, s", Value: &Parser{}}, |
||||
}, |
||||
Action: func(ctx *cli.Context) { |
||||
if !reflect.DeepEqual(ctx.Generic("serve"), &Parser{"10", "20"}) { |
||||
t.Errorf("main name not set") |
||||
} |
||||
if !reflect.DeepEqual(ctx.Generic("s"), &Parser{"10", "20"}) { |
||||
t.Errorf("short name not set") |
||||
} |
||||
}, |
||||
} |
||||
a.Run([]string{"run", "-s", "10,20"}) |
||||
} |
||||
|
||||
func TestParseGenericFromEnv(t *testing.T) { |
||||
os.Clearenv() |
||||
os.Setenv("APP_SERVE", "20,30") |
||||
a := cli.App{ |
||||
Flags: []cli.Flag{ |
||||
cli.GenericFlag{Name: "serve, s", Value: &Parser{}, EnvVar: "APP_SERVE"}, |
||||
}, |
||||
Action: func(ctx *cli.Context) { |
||||
if !reflect.DeepEqual(ctx.Generic("serve"), &Parser{"20", "30"}) { |
||||
t.Errorf("main name not set from env") |
||||
} |
||||
if !reflect.DeepEqual(ctx.Generic("s"), &Parser{"20", "30"}) { |
||||
t.Errorf("short name not set from env") |
||||
} |
||||
}, |
||||
} |
||||
a.Run([]string{"run"}) |
||||
} |
||||
|
||||
func TestParseGenericFromEnvCascade(t *testing.T) { |
||||
os.Clearenv() |
||||
os.Setenv("APP_FOO", "99,2000") |
||||
a := cli.App{ |
||||
Flags: []cli.Flag{ |
||||
cli.GenericFlag{Name: "foos", Value: &Parser{}, EnvVar: "COMPAT_FOO,APP_FOO"}, |
||||
}, |
||||
Action: func(ctx *cli.Context) { |
||||
if !reflect.DeepEqual(ctx.Generic("foos"), &Parser{"99", "2000"}) { |
||||
t.Errorf("value not set from env") |
||||
} |
||||
}, |
||||
} |
||||
a.Run([]string{"run"}) |
||||
} |
@ -0,0 +1,211 @@ |
||||
package cli |
||||
|
||||
import "fmt" |
||||
|
||||
// The text template for the Default help topic.
|
||||
// cli.go uses text/template to render templates. You can
|
||||
// render custom help text by setting this variable.
|
||||
var AppHelpTemplate = `NAME: |
||||
{{.Name}} - {{.Usage}} |
||||
|
||||
USAGE: |
||||
{{.Name}} {{if .Flags}}[global options] {{end}}command{{if .Flags}} [command options]{{end}} [arguments...] |
||||
|
||||
VERSION: |
||||
{{.Version}}{{if or .Author .Email}} |
||||
|
||||
AUTHOR:{{if .Author}} |
||||
{{.Author}}{{if .Email}} - <{{.Email}}>{{end}}{{else}} |
||||
{{.Email}}{{end}}{{end}} |
||||
|
||||
COMMANDS: |
||||
{{range .Commands}}{{.Name}}{{with .ShortName}}, {{.}}{{end}}{{ "\t" }}{{.Usage}} |
||||
{{end}}{{if .Flags}} |
||||
GLOBAL OPTIONS: |
||||
{{range .Flags}}{{.}} |
||||
{{end}}{{end}} |
||||
` |
||||
|
||||
// The text template for the command help topic.
|
||||
// cli.go uses text/template to render templates. You can
|
||||
// render custom help text by setting this variable.
|
||||
var CommandHelpTemplate = `NAME: |
||||
{{.Name}} - {{.Usage}} |
||||
|
||||
USAGE: |
||||
command {{.Name}}{{if .Flags}} [command options]{{end}} [arguments...]{{if .Description}} |
||||
|
||||
DESCRIPTION: |
||||
{{.Description}}{{end}}{{if .Flags}} |
||||
|
||||
OPTIONS: |
||||
{{range .Flags}}{{.}} |
||||
{{end}}{{ end }} |
||||
` |
||||
|
||||
// The text template for the subcommand help topic.
|
||||
// cli.go uses text/template to render templates. You can
|
||||
// render custom help text by setting this variable.
|
||||
var SubcommandHelpTemplate = `NAME: |
||||
{{.Name}} - {{.Usage}} |
||||
|
||||
USAGE: |
||||
{{.Name}} command{{if .Flags}} [command options]{{end}} [arguments...] |
||||
|
||||
COMMANDS: |
||||
{{range .Commands}}{{.Name}}{{with .ShortName}}, {{.}}{{end}}{{ "\t" }}{{.Usage}} |
||||
{{end}}{{if .Flags}} |
||||
OPTIONS: |
||||
{{range .Flags}}{{.}} |
||||
{{end}}{{end}} |
||||
` |
||||
|
||||
var helpCommand = Command{ |
||||
Name: "help", |
||||
ShortName: "h", |
||||
Usage: "Shows a list of commands or help for one command", |
||||
Action: func(c *Context) { |
||||
args := c.Args() |
||||
if args.Present() { |
||||
ShowCommandHelp(c, args.First()) |
||||
} else { |
||||
ShowAppHelp(c) |
||||
} |
||||
}, |
||||
} |
||||
|
||||
var helpSubcommand = Command{ |
||||
Name: "help", |
||||
ShortName: "h", |
||||
Usage: "Shows a list of commands or help for one command", |
||||
Action: func(c *Context) { |
||||
args := c.Args() |
||||
if args.Present() { |
||||
ShowCommandHelp(c, args.First()) |
||||
} else { |
||||
ShowSubcommandHelp(c) |
||||
} |
||||
}, |
||||
} |
||||
|
||||
// Prints help for the App
|
||||
type helpPrinter func(templ string, data interface{}) |
||||
|
||||
var HelpPrinter helpPrinter = nil |
||||
|
||||
// Prints version for the App
|
||||
var VersionPrinter = printVersion |
||||
|
||||
func ShowAppHelp(c *Context) { |
||||
HelpPrinter(AppHelpTemplate, c.App) |
||||
} |
||||
|
||||
// Prints the list of subcommands as the default app completion method
|
||||
func DefaultAppComplete(c *Context) { |
||||
for _, command := range c.App.Commands { |
||||
fmt.Fprintln(c.App.Writer, command.Name) |
||||
if command.ShortName != "" { |
||||
fmt.Fprintln(c.App.Writer, command.ShortName) |
||||
} |
||||
} |
||||
} |
||||
|
||||
// Prints help for the given command
|
||||
func ShowCommandHelp(c *Context, command string) { |
||||
for _, c := range c.App.Commands { |
||||
if c.HasName(command) { |
||||
HelpPrinter(CommandHelpTemplate, c) |
||||
return |
||||
} |
||||
} |
||||
|
||||
if c.App.CommandNotFound != nil { |
||||
c.App.CommandNotFound(c, command) |
||||
} else { |
||||
fmt.Fprintf(c.App.Writer, "No help topic for '%v'\n", command) |
||||
} |
||||
} |
||||
|
||||
// Prints help for the given subcommand
|
||||
func ShowSubcommandHelp(c *Context) { |
||||
ShowCommandHelp(c, c.Command.Name) |
||||
} |
||||
|
||||
// Prints the version number of the App
|
||||
func ShowVersion(c *Context) { |
||||
VersionPrinter(c) |
||||
} |
||||
|
||||
func printVersion(c *Context) { |
||||
fmt.Fprintf(c.App.Writer, "%v version %v\n", c.App.Name, c.App.Version) |
||||
} |
||||
|
||||
// Prints the lists of commands within a given context
|
||||
func ShowCompletions(c *Context) { |
||||
a := c.App |
||||
if a != nil && a.BashComplete != nil { |
||||
a.BashComplete(c) |
||||
} |
||||
} |
||||
|
||||
// Prints the custom completions for a given command
|
||||
func ShowCommandCompletions(ctx *Context, command string) { |
||||
c := ctx.App.Command(command) |
||||
if c != nil && c.BashComplete != nil { |
||||
c.BashComplete(ctx) |
||||
} |
||||
} |
||||
|
||||
func checkVersion(c *Context) bool { |
||||
if c.GlobalBool("version") { |
||||
ShowVersion(c) |
||||
return true |
||||
} |
||||
|
||||
return false |
||||
} |
||||
|
||||
func checkHelp(c *Context) bool { |
||||
if c.GlobalBool("h") || c.GlobalBool("help") { |
||||
ShowAppHelp(c) |
||||
return true |
||||
} |
||||
|
||||
return false |
||||
} |
||||
|
||||
func checkCommandHelp(c *Context, name string) bool { |
||||
if c.Bool("h") || c.Bool("help") { |
||||
ShowCommandHelp(c, name) |
||||
return true |
||||
} |
||||
|
||||
return false |
||||
} |
||||
|
||||
func checkSubcommandHelp(c *Context) bool { |
||||
if c.GlobalBool("h") || c.GlobalBool("help") { |
||||
ShowSubcommandHelp(c) |
||||
return true |
||||
} |
||||
|
||||
return false |
||||
} |
||||
|
||||
func checkCompletions(c *Context) bool { |
||||
if (c.GlobalBool(BashCompletionFlag.Name) || c.Bool(BashCompletionFlag.Name)) && c.App.EnableBashCompletion { |
||||
ShowCompletions(c) |
||||
return true |
||||
} |
||||
|
||||
return false |
||||
} |
||||
|
||||
func checkCommandCompletions(c *Context, name string) bool { |
||||
if c.Bool(BashCompletionFlag.Name) && c.App.EnableBashCompletion { |
||||
ShowCommandCompletions(c, name) |
||||
return true |
||||
} |
||||
|
||||
return false |
||||
} |
@ -0,0 +1,19 @@ |
||||
package cli_test |
||||
|
||||
import ( |
||||
"reflect" |
||||
"testing" |
||||
) |
||||
|
||||
/* Test Helpers */ |
||||
func expect(t *testing.T, a interface{}, b interface{}) { |
||||
if a != b { |
||||
t.Errorf("Expected %v (type %v) - Got %v (type %v)", b, reflect.TypeOf(b), a, reflect.TypeOf(a)) |
||||
} |
||||
} |
||||
|
||||
func refute(t *testing.T, a interface{}, b interface{}) { |
||||
if a == b { |
||||
t.Errorf("Did not expect %v (type %v) - Got %v (type %v)", b, reflect.TypeOf(b), a, reflect.TypeOf(a)) |
||||
} |
||||
} |
@ -0,0 +1,21 @@ |
||||
Copyright © 2012 Peter Harris |
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a |
||||
copy of this software and associated documentation files (the "Software"), |
||||
to deal in the Software without restriction, including without limitation |
||||
the rights to use, copy, modify, merge, publish, distribute, sublicense, |
||||
and/or sell copies of the Software, and to permit persons to whom the |
||||
Software is furnished to do so, subject to the following conditions: |
||||
|
||||
The above copyright notice and this permission notice (including the next |
||||
paragraph) shall be included in all copies or substantial portions of the |
||||
Software. |
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR |
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, |
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL |
||||
THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER |
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING |
||||
FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER |
||||
DEALINGS IN THE SOFTWARE. |
||||
|
@ -0,0 +1,95 @@ |
||||
Liner |
||||
===== |
||||
|
||||
Liner is a command line editor with history. It was inspired by linenoise; |
||||
everything Unix-like is a VT100 (or is trying very hard to be). If your |
||||
terminal is not pretending to be a VT100, change it. Liner also support |
||||
Windows. |
||||
|
||||
Liner is released under the X11 license (which is similar to the new BSD |
||||
license). |
||||
|
||||
Line Editing |
||||
------------ |
||||
|
||||
The following line editing commands are supported on platforms and terminals |
||||
that Liner supports: |
||||
|
||||
Keystroke | Action |
||||
--------- | ------ |
||||
Ctrl-A, Home | Move cursor to beginning of line |
||||
Ctrl-E, End | Move cursor to end of line |
||||
Ctrl-B, Left | Move cursor one character left |
||||
Ctrl-F, Right| Move cursor one character right |
||||
Ctrl-Left | Move cursor to previous word |
||||
Ctrl-Right | Move cursor to next word |
||||
Ctrl-D, Del | (if line is *not* empty) Delete character under cursor |
||||
Ctrl-D | (if line *is* empty) End of File - usually quits application |
||||
Ctrl-C | Reset input (create new empty prompt) |
||||
Ctrl-L | Clear screen (line is unmodified) |
||||
Ctrl-T | Transpose previous character with current character |
||||
Ctrl-H, BackSpace | Delete character before cursor |
||||
Ctrl-W | Delete word leading up to cursor |
||||
Ctrl-K | Delete from cursor to end of line |
||||
Ctrl-U | Delete from start of line to cursor |
||||
Ctrl-P, Up | Previous match from history |
||||
Ctrl-N, Down | Next match from history |
||||
Ctrl-R | Reverse Search history (Ctrl-S forward, Ctrl-G cancel) |
||||
Ctrl-Y | Paste from Yank buffer (Alt-Y to paste next yank instead) |
||||
Tab | Next completion |
||||
Shift-Tab | (after Tab) Previous completion |
||||
|
||||
Getting started |
||||
----------------- |
||||
|
||||
```go |
||||
package main |
||||
|
||||
import ( |
||||
"log" |
||||
"os" |
||||
"strings" |
||||
|
||||
"github.com/peterh/liner" |
||||
) |
||||
|
||||
var ( |
||||
history_fn = "/tmp/.liner_history" |
||||
names = []string{"john", "james", "mary", "nancy"} |
||||
) |
||||
|
||||
func main() { |
||||
line := liner.NewLiner() |
||||
defer line.Close() |
||||
|
||||
line.SetCompleter(func(line string) (c []string) { |
||||
for _, n := range names { |
||||
if strings.HasPrefix(n, strings.ToLower(line)) { |
||||
c = append(c, n) |
||||
} |
||||
} |
||||
return |
||||
}) |
||||
|
||||
if f, err := os.Open(history_fn); err == nil { |
||||
line.ReadHistory(f) |
||||
f.Close() |
||||
} |
||||
|
||||
if name, err := line.Prompt("What is your name? "); err != nil { |
||||
log.Print("Error reading line: ", err) |
||||
} else { |
||||
log.Print("Got: ", name) |
||||
line.AppendHistory(name) |
||||
} |
||||
|
||||
if f, err := os.Create(history_fn); err != nil { |
||||
log.Print("Error writing history file: ", err) |
||||
} else { |
||||
line.WriteHistory(f) |
||||
f.Close() |
||||
} |
||||
} |
||||
``` |
||||
|
||||
For documentation, see http://godoc.org/github.com/peterh/liner |
@ -0,0 +1,39 @@ |
||||
// +build openbsd freebsd netbsd
|
||||
|
||||
package liner |
||||
|
||||
import "syscall" |
||||
|
||||
const ( |
||||
getTermios = syscall.TIOCGETA |
||||
setTermios = syscall.TIOCSETA |
||||
) |
||||
|
||||
const ( |
||||
// Input flags
|
||||
inpck = 0x010 |
||||
istrip = 0x020 |
||||
icrnl = 0x100 |
||||
ixon = 0x200 |
||||
|
||||
// Output flags
|
||||
opost = 0x1 |
||||
|
||||
// Control flags
|
||||
cs8 = 0x300 |
||||
|
||||
// Local flags
|
||||
isig = 0x080 |
||||
icanon = 0x100 |
||||
iexten = 0x400 |
||||
) |
||||
|
||||
type termios struct { |
||||
Iflag uint32 |
||||
Oflag uint32 |
||||
Cflag uint32 |
||||
Lflag uint32 |
||||
Cc [20]byte |
||||
Ispeed int32 |
||||
Ospeed int32 |
||||
} |
@ -0,0 +1,219 @@ |
||||
/* |
||||
Package liner implements a simple command line editor, inspired by linenoise |
||||
(https://github.com/antirez/linenoise/). This package supports WIN32 in
|
||||
addition to the xterm codes supported by everything else. |
||||
*/ |
||||
package liner |
||||
|
||||
import ( |
||||
"bufio" |
||||
"bytes" |
||||
"container/ring" |
||||
"errors" |
||||
"fmt" |
||||
"io" |
||||
"strings" |
||||
"sync" |
||||
"unicode/utf8" |
||||
) |
||||
|
||||
type commonState struct { |
||||
terminalSupported bool |
||||
outputRedirected bool |
||||
inputRedirected bool |
||||
history []string |
||||
historyMutex sync.RWMutex |
||||
completer WordCompleter |
||||
columns int |
||||
killRing *ring.Ring |
||||
ctrlCAborts bool |
||||
r *bufio.Reader |
||||
tabStyle TabStyle |
||||
} |
||||
|
||||
// TabStyle is used to select how tab completions are displayed.
|
||||
type TabStyle int |
||||
|
||||
// Two tab styles are currently available:
|
||||
//
|
||||
// TabCircular cycles through each completion item and displays it directly on
|
||||
// the prompt
|
||||
//
|
||||
// TabPrints prints the list of completion items to the screen after a second
|
||||
// tab key is pressed. This behaves similar to GNU readline and BASH (which
|
||||
// uses readline)
|
||||
const ( |
||||
TabCircular TabStyle = iota |
||||
TabPrints |
||||
) |
||||
|
||||
// ErrPromptAborted is returned from Prompt or PasswordPrompt when the user presses Ctrl-C
|
||||
// if SetCtrlCAborts(true) has been called on the State
|
||||
var ErrPromptAborted = errors.New("prompt aborted") |
||||
|
||||
// ErrNotTerminalOutput is returned from Prompt or PasswordPrompt if the
|
||||
// platform is normally supported, but stdout has been redirected
|
||||
var ErrNotTerminalOutput = errors.New("standard output is not a terminal") |
||||
|
||||
// Max elements to save on the killring
|
||||
const KillRingMax = 60 |
||||
|
||||
// HistoryLimit is the maximum number of entries saved in the scrollback history.
|
||||
const HistoryLimit = 1000 |
||||
|
||||
// ReadHistory reads scrollback history from r. Returns the number of lines
|
||||
// read, and any read error (except io.EOF).
|
||||
func (s *State) ReadHistory(r io.Reader) (num int, err error) { |
||||
s.historyMutex.Lock() |
||||
defer s.historyMutex.Unlock() |
||||
|
||||
in := bufio.NewReader(r) |
||||
num = 0 |
||||
for { |
||||
line, part, err := in.ReadLine() |
||||
if err == io.EOF { |
||||
break |
||||
} |
||||
if err != nil { |
||||
return num, err |
||||
} |
||||
if part { |
||||
return num, fmt.Errorf("line %d is too long", num+1) |
||||
} |
||||
if !utf8.Valid(line) { |
||||
return num, fmt.Errorf("invalid string at line %d", num+1) |
||||
} |
||||
num++ |
||||
s.history = append(s.history, string(line)) |
||||
if len(s.history) > HistoryLimit { |
||||
s.history = s.history[1:] |
||||
} |
||||
} |
||||
return num, nil |
||||
} |
||||
|
||||
// WriteHistory writes scrollback history to w. Returns the number of lines
|
||||
// successfully written, and any write error.
|
||||
//
|
||||
// Unlike the rest of liner's API, WriteHistory is safe to call
|
||||
// from another goroutine while Prompt is in progress.
|
||||
// This exception is to facilitate the saving of the history buffer
|
||||
// during an unexpected exit (for example, due to Ctrl-C being invoked)
|
||||
func (s *State) WriteHistory(w io.Writer) (num int, err error) { |
||||
s.historyMutex.RLock() |
||||
defer s.historyMutex.RUnlock() |
||||
|
||||
for _, item := range s.history { |
||||
_, err := fmt.Fprintln(w, item) |
||||
if err != nil { |
||||
return num, err |
||||
} |
||||
num++ |
||||
} |
||||
return num, nil |
||||
} |
||||
|
||||
// AppendHistory appends an entry to the scrollback history. AppendHistory
|
||||
// should be called iff Prompt returns a valid command.
|
||||
func (s *State) AppendHistory(item string) { |
||||
s.historyMutex.Lock() |
||||
defer s.historyMutex.Unlock() |
||||
|
||||
if len(s.history) > 0 { |
||||
if item == s.history[len(s.history)-1] { |
||||
return |
||||
} |
||||
} |
||||
s.history = append(s.history, item) |
||||
if len(s.history) > HistoryLimit { |
||||
s.history = s.history[1:] |
||||
} |
||||
} |
||||
|
||||
// Returns the history lines starting with prefix
|
||||
func (s *State) getHistoryByPrefix(prefix string) (ph []string) { |
||||
for _, h := range s.history { |
||||
if strings.HasPrefix(h, prefix) { |
||||
ph = append(ph, h) |
||||
} |
||||
} |
||||
return |
||||
} |
||||
|
||||
// Returns the history lines matching the inteligent search
|
||||
func (s *State) getHistoryByPattern(pattern string) (ph []string, pos []int) { |
||||
if pattern == "" { |
||||
return |
||||
} |
||||
for _, h := range s.history { |
||||
if i := strings.Index(h, pattern); i >= 0 { |
||||
ph = append(ph, h) |
||||
pos = append(pos, i) |
||||
} |
||||
} |
||||
return |
||||
} |
||||
|
||||
// Completer takes the currently edited line content at the left of the cursor
|
||||
// and returns a list of completion candidates.
|
||||
// If the line is "Hello, wo!!!" and the cursor is before the first '!', "Hello, wo" is passed
|
||||
// to the completer which may return {"Hello, world", "Hello, Word"} to have "Hello, world!!!".
|
||||
type Completer func(line string) []string |
||||
|
||||
// WordCompleter takes the currently edited line with the cursor position and
|
||||
// returns the completion candidates for the partial word to be completed.
|
||||
// If the line is "Hello, wo!!!" and the cursor is before the first '!', ("Hello, wo!!!", 9) is passed
|
||||
// to the completer which may returns ("Hello, ", {"world", "Word"}, "!!!") to have "Hello, world!!!".
|
||||
type WordCompleter func(line string, pos int) (head string, completions []string, tail string) |
||||
|
||||
// SetCompleter sets the completion function that Liner will call to
|
||||
// fetch completion candidates when the user presses tab.
|
||||
func (s *State) SetCompleter(f Completer) { |
||||
if f == nil { |
||||
s.completer = nil |
||||
return |
||||
} |
||||
s.completer = func(line string, pos int) (string, []string, string) { |
||||
return "", f(line[:pos]), line[pos:] |
||||
} |
||||
} |
||||
|
||||
// SetWordCompleter sets the completion function that Liner will call to
|
||||
// fetch completion candidates when the user presses tab.
|
||||
func (s *State) SetWordCompleter(f WordCompleter) { |
||||
s.completer = f |
||||
} |
||||
|
||||
// SetTabCompletionStyle sets the behvavior when the Tab key is pressed
|
||||
// for auto-completion. TabCircular is the default behavior and cycles
|
||||
// through the list of candidates at the prompt. TabPrints will print
|
||||
// the available completion candidates to the screen similar to BASH
|
||||
// and GNU Readline
|
||||
func (s *State) SetTabCompletionStyle(tabStyle TabStyle) { |
||||
s.tabStyle = tabStyle |
||||
} |
||||
|
||||
// ModeApplier is the interface that wraps a representation of the terminal
|
||||
// mode. ApplyMode sets the terminal to this mode.
|
||||
type ModeApplier interface { |
||||
ApplyMode() error |
||||
} |
||||
|
||||
// SetCtrlCAborts sets whether Prompt on a supported terminal will return an
|
||||
// ErrPromptAborted when Ctrl-C is pressed. The default is false (will not
|
||||
// return when Ctrl-C is pressed). Unsupported terminals typically raise SIGINT
|
||||
// (and Prompt does not return) regardless of the value passed to SetCtrlCAborts.
|
||||
func (s *State) SetCtrlCAborts(aborts bool) { |
||||
s.ctrlCAborts = aborts |
||||
} |
||||
|
||||
func (s *State) promptUnsupported(p string) (string, error) { |
||||
if !s.inputRedirected { |
||||
fmt.Print(p) |
||||
} |
||||
linebuf, _, err := s.r.ReadLine() |
||||
if err != nil { |
||||
return "", err |
||||
} |
||||
return string(bytes.TrimSpace(linebuf)), nil |
||||
} |
@ -0,0 +1,57 @@ |
||||
// +build !windows,!linux,!darwin,!openbsd,!freebsd,!netbsd
|
||||
|
||||
package liner |
||||
|
||||
import ( |
||||
"bufio" |
||||
"errors" |
||||
"os" |
||||
) |
||||
|
||||
// State represents an open terminal
|
||||
type State struct { |
||||
commonState |
||||
} |
||||
|
||||
// Prompt displays p, and then waits for user input. Prompt does not support
|
||||
// line editing on this operating system.
|
||||
func (s *State) Prompt(p string) (string, error) { |
||||
return s.promptUnsupported(p) |
||||
} |
||||
|
||||
// PasswordPrompt is not supported in this OS.
|
||||
func (s *State) PasswordPrompt(p string) (string, error) { |
||||
return "", errors.New("liner: function not supported in this terminal") |
||||
} |
||||
|
||||
// NewLiner initializes a new *State
|
||||
//
|
||||
// Note that this operating system uses a fallback mode without line
|
||||
// editing. Patches welcome.
|
||||
func NewLiner() *State { |
||||
var s State |
||||
s.r = bufio.NewReader(os.Stdin) |
||||
return &s |
||||
} |
||||
|
||||
// Close returns the terminal to its previous mode
|
||||
func (s *State) Close() error { |
||||
return nil |
||||
} |
||||
|
||||
// TerminalSupported returns false because line editing is not
|
||||
// supported on this platform.
|
||||
func TerminalSupported() bool { |
||||
return false |
||||
} |
||||
|
||||
type noopMode struct{} |
||||
|
||||
func (n noopMode) ApplyMode() error { |
||||
return nil |
||||
} |
||||
|
||||
// TerminalMode returns a noop InputModeSetter on this platform.
|
||||
func TerminalMode() (ModeApplier, error) { |
||||
return noopMode{}, nil |
||||
} |
@ -0,0 +1,359 @@ |
||||
// +build linux darwin openbsd freebsd netbsd
|
||||
|
||||
package liner |
||||
|
||||
import ( |
||||
"bufio" |
||||
"errors" |
||||
"os" |
||||
"os/signal" |
||||
"strconv" |
||||
"strings" |
||||
"syscall" |
||||
"time" |
||||
) |
||||
|
||||
type nexter struct { |
||||
r rune |
||||
err error |
||||
} |
||||
|
||||
// State represents an open terminal
|
||||
type State struct { |
||||
commonState |
||||
origMode termios |
||||
defaultMode termios |
||||
next <-chan nexter |
||||
winch chan os.Signal |
||||
pending []rune |
||||
useCHA bool |
||||
} |
||||
|
||||
// NewLiner initializes a new *State, and sets the terminal into raw mode. To
|
||||
// restore the terminal to its previous state, call State.Close().
|
||||
//
|
||||
// Note if you are still using Go 1.0: NewLiner handles SIGWINCH, so it will
|
||||
// leak a channel every time you call it. Therefore, it is recommened that you
|
||||
// upgrade to a newer release of Go, or ensure that NewLiner is only called
|
||||
// once.
|
||||
func NewLiner() *State { |
||||
var s State |
||||
s.r = bufio.NewReader(os.Stdin) |
||||
|
||||
s.terminalSupported = TerminalSupported() |
||||
if m, err := TerminalMode(); err == nil { |
||||
s.origMode = *m.(*termios) |
||||
} else { |
||||
s.terminalSupported = false |
||||
s.inputRedirected = true |
||||
} |
||||
if _, err := getMode(syscall.Stdout); err != 0 { |
||||
s.terminalSupported = false |
||||
s.outputRedirected = true |
||||
} |
||||
if s.terminalSupported { |
||||
mode := s.origMode |
||||
mode.Iflag &^= icrnl | inpck | istrip | ixon |
||||
mode.Cflag |= cs8 |
||||
mode.Lflag &^= syscall.ECHO | icanon | iexten |
||||
mode.ApplyMode() |
||||
|
||||
winch := make(chan os.Signal, 1) |
||||
signal.Notify(winch, syscall.SIGWINCH) |
||||
s.winch = winch |
||||
|
||||
s.checkOutput() |
||||
} |
||||
|
||||
if !s.outputRedirected { |
||||
s.getColumns() |
||||
s.outputRedirected = s.columns <= 0 |
||||
} |
||||
|
||||
return &s |
||||
} |
||||
|
||||
var errTimedOut = errors.New("timeout") |
||||
|
||||
func (s *State) startPrompt() { |
||||
if s.terminalSupported { |
||||
if m, err := TerminalMode(); err == nil { |
||||
s.defaultMode = *m.(*termios) |
||||
mode := s.defaultMode |
||||
mode.Lflag &^= isig |
||||
mode.ApplyMode() |
||||
} |
||||
} |
||||
s.restartPrompt() |
||||
} |
||||
|
||||
func (s *State) restartPrompt() { |
||||
next := make(chan nexter) |
||||
go func() { |
||||
for { |
||||
var n nexter |
||||
n.r, _, n.err = s.r.ReadRune() |
||||
next <- n |
||||
// Shut down nexter loop when an end condition has been reached
|
||||
if n.err != nil || n.r == '\n' || n.r == '\r' || n.r == ctrlC || n.r == ctrlD { |
||||
close(next) |
||||
return |
||||
} |
||||
} |
||||
}() |
||||
s.next = next |
||||
} |
||||
|
||||
func (s *State) stopPrompt() { |
||||
if s.terminalSupported { |
||||
s.defaultMode.ApplyMode() |
||||
} |
||||
} |
||||
|
||||
func (s *State) nextPending(timeout <-chan time.Time) (rune, error) { |
||||
select { |
||||
case thing, ok := <-s.next: |
||||
if !ok { |
||||
return 0, errors.New("liner: internal error") |
||||
} |
||||
if thing.err != nil { |
||||
return 0, thing.err |
||||
} |
||||
s.pending = append(s.pending, thing.r) |
||||
return thing.r, nil |
||||
case <-timeout: |
||||
rv := s.pending[0] |
||||
s.pending = s.pending[1:] |
||||
return rv, errTimedOut |
||||
} |
||||
// not reached
|
||||
return 0, nil |
||||
} |
||||
|
||||
func (s *State) readNext() (interface{}, error) { |
||||
if len(s.pending) > 0 { |
||||
rv := s.pending[0] |
||||
s.pending = s.pending[1:] |
||||
return rv, nil |
||||
} |
||||
var r rune |
||||
select { |
||||
case thing, ok := <-s.next: |
||||
if !ok { |
||||
return 0, errors.New("liner: internal error") |
||||
} |
||||
if thing.err != nil { |
||||
return nil, thing.err |
||||
} |
||||
r = thing.r |
||||
case <-s.winch: |
||||
s.getColumns() |
||||
return winch, nil |
||||
} |
||||
if r != esc { |
||||
return r, nil |
||||
} |
||||
s.pending = append(s.pending, r) |
||||
|
||||
// Wait at most 50 ms for the rest of the escape sequence
|
||||
// If nothing else arrives, it was an actual press of the esc key
|
||||
timeout := time.After(50 * time.Millisecond) |
||||
flag, err := s.nextPending(timeout) |
||||
if err != nil { |
||||
if err == errTimedOut { |
||||
return flag, nil |
||||
} |
||||
return unknown, err |
||||
} |
||||
|
||||
switch flag { |
||||
case '[': |
||||
code, err := s.nextPending(timeout) |
||||
if err != nil { |
||||
if err == errTimedOut { |
||||
return code, nil |
||||
} |
||||
return unknown, err |
||||
} |
||||
switch code { |
||||
case 'A': |
||||
s.pending = s.pending[:0] // escape code complete
|
||||
return up, nil |
||||
case 'B': |
||||
s.pending = s.pending[:0] // escape code complete
|
||||
return down, nil |
||||
case 'C': |
||||
s.pending = s.pending[:0] // escape code complete
|
||||
return right, nil |
||||
case 'D': |
||||
s.pending = s.pending[:0] // escape code complete
|
||||
return left, nil |
||||
case 'F': |
||||
s.pending = s.pending[:0] // escape code complete
|
||||
return end, nil |
||||
case 'H': |
||||
s.pending = s.pending[:0] // escape code complete
|
||||
return home, nil |
||||
case 'Z': |
||||
s.pending = s.pending[:0] // escape code complete
|
||||
return shiftTab, nil |
||||
case '0', '1', '2', '3', '4', '5', '6', '7', '8', '9': |
||||
num := []rune{code} |
||||
for { |
||||
code, err := s.nextPending(timeout) |
||||
if err != nil { |
||||
if err == errTimedOut { |
||||
return code, nil |
||||
} |
||||
return nil, err |
||||
} |
||||
switch code { |
||||
case '0', '1', '2', '3', '4', '5', '6', '7', '8', '9': |
||||
num = append(num, code) |
||||
case ';': |
||||
// Modifier code to follow
|
||||
// This only supports Ctrl-left and Ctrl-right for now
|
||||
x, _ := strconv.ParseInt(string(num), 10, 32) |
||||
if x != 1 { |
||||
// Can't be left or right
|
||||
rv := s.pending[0] |
||||
s.pending = s.pending[1:] |
||||
return rv, nil |
||||
} |
||||
num = num[:0] |
||||
for { |
||||
code, err = s.nextPending(timeout) |
||||
if err != nil { |
||||
if err == errTimedOut { |
||||
rv := s.pending[0] |
||||
s.pending = s.pending[1:] |
||||
return rv, nil |
||||
} |
||||
return nil, err |
||||
} |
||||
switch code { |
||||
case '0', '1', '2', '3', '4', '5', '6', '7', '8', '9': |
||||
num = append(num, code) |
||||
case 'C', 'D': |
||||
// right, left
|
||||
mod, _ := strconv.ParseInt(string(num), 10, 32) |
||||
if mod != 5 { |
||||
// Not bare Ctrl
|
||||
rv := s.pending[0] |
||||
s.pending = s.pending[1:] |
||||
return rv, nil |
||||
} |
||||
s.pending = s.pending[:0] // escape code complete
|
||||
if code == 'C' { |
||||
return wordRight, nil |
||||
} |
||||
return wordLeft, nil |
||||
default: |
||||
// Not left or right
|
||||
rv := s.pending[0] |
||||
s.pending = s.pending[1:] |
||||
return rv, nil |
||||
} |
||||
} |
||||
case '~': |
||||
s.pending = s.pending[:0] // escape code complete
|
||||
x, _ := strconv.ParseInt(string(num), 10, 32) |
||||
switch x { |
||||
case 2: |
||||
return insert, nil |
||||
case 3: |
||||
return del, nil |
||||
case 5: |
||||
return pageUp, nil |
||||
case 6: |
||||
return pageDown, nil |
||||
case 7: |
||||
return home, nil |
||||
case 8: |
||||
return end, nil |
||||
case 15: |
||||
return f5, nil |
||||
case 17: |
||||
return f6, nil |
||||
case 18: |
||||
return f7, nil |
||||
case 19: |
||||
return f8, nil |
||||
case 20: |
||||
return f9, nil |
||||
case 21: |
||||
return f10, nil |
||||
case 23: |
||||
return f11, nil |
||||
case 24: |
||||
return f12, nil |
||||
default: |
||||
return unknown, nil |
||||
} |
||||
default: |
||||
// unrecognized escape code
|
||||
rv := s.pending[0] |
||||
s.pending = s.pending[1:] |
||||
return rv, nil |
||||
} |
||||
} |
||||
} |
||||
|
||||
case 'O': |
||||
code, err := s.nextPending(timeout) |
||||
if err != nil { |
||||
if err == errTimedOut { |
||||
return code, nil |
||||
} |
||||
return nil, err |
||||
} |
||||
s.pending = s.pending[:0] // escape code complete
|
||||
switch code { |
||||
case 'c': |
||||
return wordRight, nil |
||||
case 'd': |
||||
return wordLeft, nil |
||||
case 'H': |
||||
return home, nil |
||||
case 'F': |
||||
return end, nil |
||||
case 'P': |
||||
return f1, nil |
||||
case 'Q': |
||||
return f2, nil |
||||
case 'R': |
||||
return f3, nil |
||||
case 'S': |
||||
return f4, nil |
||||
default: |
||||
return unknown, nil |
||||
} |
||||
case 'y': |
||||
s.pending = s.pending[:0] // escape code complete
|
||||
return altY, nil |
||||
default: |
||||
rv := s.pending[0] |
||||
s.pending = s.pending[1:] |
||||
return rv, nil |
||||
} |
||||
|
||||
// not reached
|
||||
return r, nil |
||||
} |
||||
|
||||
// Close returns the terminal to its previous mode
|
||||
func (s *State) Close() error { |
||||
stopSignal(s.winch) |
||||
if s.terminalSupported { |
||||
s.origMode.ApplyMode() |
||||
} |
||||
return nil |
||||
} |
||||
|
||||
// TerminalSupported returns true if the current terminal supports
|
||||
// line editing features, and false if liner will use the 'dumb'
|
||||
// fallback for input.
|
||||
func TerminalSupported() bool { |
||||
bad := map[string]bool{"": true, "dumb": true, "cons25": true} |
||||
return !bad[strings.ToLower(os.Getenv("TERM"))] |
||||
} |
@ -0,0 +1,39 @@ |
||||
// +build darwin
|
||||
|
||||
package liner |
||||
|
||||
import "syscall" |
||||
|
||||
const ( |
||||
getTermios = syscall.TIOCGETA |
||||
setTermios = syscall.TIOCSETA |
||||
) |
||||
|
||||
const ( |
||||
// Input flags
|
||||
inpck = 0x010 |
||||
istrip = 0x020 |
||||
icrnl = 0x100 |
||||
ixon = 0x200 |
||||
|
||||
// Output flags
|
||||
opost = 0x1 |
||||
|
||||
// Control flags
|
||||
cs8 = 0x300 |
||||
|
||||
// Local flags
|
||||
isig = 0x080 |
||||
icanon = 0x100 |
||||
iexten = 0x400 |
||||
) |
||||
|
||||
type termios struct { |
||||
Iflag uintptr |
||||
Oflag uintptr |
||||
Cflag uintptr |
||||
Lflag uintptr |
||||
Cc [20]byte |
||||
Ispeed uintptr |
||||
Ospeed uintptr |
||||
} |
@ -0,0 +1,26 @@ |
||||
// +build linux
|
||||
|
||||
package liner |
||||
|
||||
import "syscall" |
||||
|
||||
const ( |
||||
getTermios = syscall.TCGETS |
||||
setTermios = syscall.TCSETS |
||||
) |
||||
|
||||
const ( |
||||
icrnl = syscall.ICRNL |
||||
inpck = syscall.INPCK |
||||
istrip = syscall.ISTRIP |
||||
ixon = syscall.IXON |
||||
opost = syscall.OPOST |
||||
cs8 = syscall.CS8 |
||||
isig = syscall.ISIG |
||||
icanon = syscall.ICANON |
||||
iexten = syscall.IEXTEN |
||||
) |
||||
|
||||
type termios struct { |
||||
syscall.Termios |
||||
} |
@ -0,0 +1,61 @@ |
||||
// +build !windows
|
||||
|
||||
package liner |
||||
|
||||
import ( |
||||
"bufio" |
||||
"bytes" |
||||
"testing" |
||||
) |
||||
|
||||
func (s *State) expectRune(t *testing.T, r rune) { |
||||
item, err := s.readNext() |
||||
if err != nil { |
||||
t.Fatalf("Expected rune '%c', got error %s\n", r, err) |
||||
} |
||||
if v, ok := item.(rune); !ok { |
||||
t.Fatalf("Expected rune '%c', got non-rune %v\n", r, v) |
||||
} else { |
||||
if v != r { |
||||
t.Fatalf("Expected rune '%c', got rune '%c'\n", r, v) |
||||
} |
||||
} |
||||
} |
||||
|
||||
func (s *State) expectAction(t *testing.T, a action) { |
||||
item, err := s.readNext() |
||||
if err != nil { |
||||
t.Fatalf("Expected Action %d, got error %s\n", a, err) |
||||
} |
||||
if v, ok := item.(action); !ok { |
||||
t.Fatalf("Expected Action %d, got non-Action %v\n", a, v) |
||||
} else { |
||||
if v != a { |
||||
t.Fatalf("Expected Action %d, got Action %d\n", a, v) |
||||
} |
||||
} |
||||
} |
||||
|
||||
func TestTypes(t *testing.T) { |
||||
input := []byte{'A', 27, 'B', 27, 91, 68, 27, '[', '1', ';', '5', 'D', 'e'} |
||||
var s State |
||||
s.r = bufio.NewReader(bytes.NewBuffer(input)) |
||||
|
||||
next := make(chan nexter) |
||||
go func() { |
||||
for { |
||||
var n nexter |
||||
n.r, _, n.err = s.r.ReadRune() |
||||
next <- n |
||||
} |
||||
}() |
||||
s.next = next |
||||
|
||||
s.expectRune(t, 'A') |
||||
s.expectRune(t, 27) |
||||
s.expectRune(t, 'B') |
||||
s.expectAction(t, left) |
||||
s.expectAction(t, wordLeft) |
||||
|
||||
s.expectRune(t, 'e') |
||||
} |
@ -0,0 +1,313 @@ |
||||
package liner |
||||
|
||||
import ( |
||||
"bufio" |
||||
"os" |
||||
"syscall" |
||||
"unsafe" |
||||
) |
||||
|
||||
var ( |
||||
kernel32 = syscall.NewLazyDLL("kernel32.dll") |
||||
|
||||
procGetStdHandle = kernel32.NewProc("GetStdHandle") |
||||
procReadConsoleInput = kernel32.NewProc("ReadConsoleInputW") |
||||
procGetConsoleMode = kernel32.NewProc("GetConsoleMode") |
||||
procSetConsoleMode = kernel32.NewProc("SetConsoleMode") |
||||
procSetConsoleCursorPosition = kernel32.NewProc("SetConsoleCursorPosition") |
||||
procGetConsoleScreenBufferInfo = kernel32.NewProc("GetConsoleScreenBufferInfo") |
||||
procFillConsoleOutputCharacter = kernel32.NewProc("FillConsoleOutputCharacterW") |
||||
) |
||||
|
||||
// These names are from the Win32 api, so they use underscores (contrary to
|
||||
// what golint suggests)
|
||||
const ( |
||||
std_input_handle = uint32(-10 & 0xFFFFFFFF) |
||||
std_output_handle = uint32(-11 & 0xFFFFFFFF) |
||||
std_error_handle = uint32(-12 & 0xFFFFFFFF) |
||||
invalid_handle_value = ^uintptr(0) |
||||
) |
||||
|
||||
type inputMode uint32 |
||||
|
||||
// State represents an open terminal
|
||||
type State struct { |
||||
commonState |
||||
handle syscall.Handle |
||||
hOut syscall.Handle |
||||
origMode inputMode |
||||
defaultMode inputMode |
||||
key interface{} |
||||
repeat uint16 |
||||
} |
||||
|
||||
const ( |
||||
enableEchoInput = 0x4 |
||||
enableInsertMode = 0x20 |
||||
enableLineInput = 0x2 |
||||
enableMouseInput = 0x10 |
||||
enableProcessedInput = 0x1 |
||||
enableQuickEditMode = 0x40 |
||||
enableWindowInput = 0x8 |
||||
) |
||||
|
||||
// NewLiner initializes a new *State, and sets the terminal into raw mode. To
|
||||
// restore the terminal to its previous state, call State.Close().
|
||||
func NewLiner() *State { |
||||
var s State |
||||
hIn, _, _ := procGetStdHandle.Call(uintptr(std_input_handle)) |
||||
s.handle = syscall.Handle(hIn) |
||||
hOut, _, _ := procGetStdHandle.Call(uintptr(std_output_handle)) |
||||
s.hOut = syscall.Handle(hOut) |
||||
|
||||
s.terminalSupported = true |
||||
if m, err := TerminalMode(); err == nil { |
||||
s.origMode = m.(inputMode) |
||||
mode := s.origMode |
||||
mode &^= enableEchoInput |
||||
mode &^= enableInsertMode |
||||
mode &^= enableLineInput |
||||
mode &^= enableMouseInput |
||||
mode |= enableWindowInput |
||||
mode.ApplyMode() |
||||
} else { |
||||
s.inputRedirected = true |
||||
s.r = bufio.NewReader(os.Stdin) |
||||
} |
||||
|
||||
s.getColumns() |
||||
s.outputRedirected = s.columns <= 0 |
||||
|
||||
return &s |
||||
} |
||||
|
||||
// These names are from the Win32 api, so they use underscores (contrary to
|
||||
// what golint suggests)
|
||||
const ( |
||||
focus_event = 0x0010 |
||||
key_event = 0x0001 |
||||
menu_event = 0x0008 |
||||
mouse_event = 0x0002 |
||||
window_buffer_size_event = 0x0004 |
||||
) |
||||
|
||||
type input_record struct { |
||||
eventType uint16 |
||||
pad uint16 |
||||
blob [16]byte |
||||
} |
||||
|
||||
type key_event_record struct { |
||||
KeyDown int32 |
||||
RepeatCount uint16 |
||||
VirtualKeyCode uint16 |
||||
VirtualScanCode uint16 |
||||
Char int16 |
||||
ControlKeyState uint32 |
||||
} |
||||
|
||||
// These names are from the Win32 api, so they use underscores (contrary to
|
||||
// what golint suggests)
|
||||
const ( |
||||
vk_tab = 0x09 |
||||
vk_prior = 0x21 |
||||
vk_next = 0x22 |
||||
vk_end = 0x23 |
||||
vk_home = 0x24 |
||||
vk_left = 0x25 |
||||
vk_up = 0x26 |
||||
vk_right = 0x27 |
||||
vk_down = 0x28 |
||||
vk_insert = 0x2d |
||||
vk_delete = 0x2e |
||||
vk_f1 = 0x70 |
||||
vk_f2 = 0x71 |
||||
vk_f3 = 0x72 |
||||
vk_f4 = 0x73 |
||||
vk_f5 = 0x74 |
||||
vk_f6 = 0x75 |
||||
vk_f7 = 0x76 |
||||
vk_f8 = 0x77 |
||||
vk_f9 = 0x78 |
||||
vk_f10 = 0x79 |
||||
vk_f11 = 0x7a |
||||
vk_f12 = 0x7b |
||||
yKey = 0x59 |
||||
) |
||||
|
||||
const ( |
||||
shiftPressed = 0x0010 |
||||
leftAltPressed = 0x0002 |
||||
leftCtrlPressed = 0x0008 |
||||
rightAltPressed = 0x0001 |
||||
rightCtrlPressed = 0x0004 |
||||
|
||||
modKeys = shiftPressed | leftAltPressed | rightAltPressed | leftCtrlPressed | rightCtrlPressed |
||||
) |
||||
|
||||
func (s *State) readNext() (interface{}, error) { |
||||
if s.repeat > 0 { |
||||
s.repeat-- |
||||
return s.key, nil |
||||
} |
||||
|
||||
var input input_record |
||||
pbuf := uintptr(unsafe.Pointer(&input)) |
||||
var rv uint32 |
||||
prv := uintptr(unsafe.Pointer(&rv)) |
||||
|
||||
for { |
||||
ok, _, err := procReadConsoleInput.Call(uintptr(s.handle), pbuf, 1, prv) |
||||
|
||||
if ok == 0 { |
||||
return nil, err |
||||
} |
||||
|
||||
if input.eventType == window_buffer_size_event { |
||||
xy := (*coord)(unsafe.Pointer(&input.blob[0])) |
||||
s.columns = int(xy.x) |
||||
return winch, nil |
||||
} |
||||
if input.eventType != key_event { |
||||
continue |
||||
} |
||||
ke := (*key_event_record)(unsafe.Pointer(&input.blob[0])) |
||||
if ke.KeyDown == 0 { |
||||
continue |
||||
} |
||||
|
||||
if ke.VirtualKeyCode == vk_tab && ke.ControlKeyState&modKeys == shiftPressed { |
||||
s.key = shiftTab |
||||
} else if ke.VirtualKeyCode == yKey && (ke.ControlKeyState&modKeys == leftAltPressed || |
||||
ke.ControlKeyState&modKeys == rightAltPressed) { |
||||
s.key = altY |
||||
} else if ke.Char > 0 { |
||||
s.key = rune(ke.Char) |
||||
} else { |
||||
switch ke.VirtualKeyCode { |
||||
case vk_prior: |
||||
s.key = pageUp |
||||
case vk_next: |
||||
s.key = pageDown |
||||
case vk_end: |
||||
s.key = end |
||||
case vk_home: |
||||
s.key = home |
||||
case vk_left: |
||||
s.key = left |
||||
if ke.ControlKeyState&(leftCtrlPressed|rightCtrlPressed) != 0 { |
||||
if ke.ControlKeyState&modKeys == ke.ControlKeyState&(leftCtrlPressed|rightCtrlPressed) { |
||||
s.key = wordLeft |
||||
} |
||||
} |
||||
case vk_right: |
||||
s.key = right |
||||
if ke.ControlKeyState&(leftCtrlPressed|rightCtrlPressed) != 0 { |
||||
if ke.ControlKeyState&modKeys == ke.ControlKeyState&(leftCtrlPressed|rightCtrlPressed) { |
||||
s.key = wordRight |
||||
} |
||||
} |
||||
case vk_up: |
||||
s.key = up |
||||
case vk_down: |
||||
s.key = down |
||||
case vk_insert: |
||||
s.key = insert |
||||
case vk_delete: |
||||
s.key = del |
||||
case vk_f1: |
||||
s.key = f1 |
||||
case vk_f2: |
||||
s.key = f2 |
||||
case vk_f3: |
||||
s.key = f3 |
||||
case vk_f4: |
||||
s.key = f4 |
||||
case vk_f5: |
||||
s.key = f5 |
||||
case vk_f6: |
||||
s.key = f6 |
||||
case vk_f7: |
||||
s.key = f7 |
||||
case vk_f8: |
||||
s.key = f8 |
||||
case vk_f9: |
||||
s.key = f9 |
||||
case vk_f10: |
||||
s.key = f10 |
||||
case vk_f11: |
||||
s.key = f11 |
||||
case vk_f12: |
||||
s.key = f12 |
||||
default: |
||||
// Eat modifier keys
|
||||
// TODO: return Action(Unknown) if the key isn't a
|
||||
// modifier.
|
||||
continue |
||||
} |
||||
} |
||||
|
||||
if ke.RepeatCount > 1 { |
||||
s.repeat = ke.RepeatCount - 1 |
||||
} |
||||
return s.key, nil |
||||
} |
||||
return unknown, nil |
||||
} |
||||
|
||||
// Close returns the terminal to its previous mode
|
||||
func (s *State) Close() error { |
||||
s.origMode.ApplyMode() |
||||
return nil |
||||
} |
||||
|
||||
func (s *State) startPrompt() { |
||||
if m, err := TerminalMode(); err == nil { |
||||
s.defaultMode = m.(inputMode) |
||||
mode := s.defaultMode |
||||
mode &^= enableProcessedInput |
||||
mode.ApplyMode() |
||||
} |
||||
} |
||||
|
||||
func (s *State) restartPrompt() { |
||||
} |
||||
|
||||
func (s *State) stopPrompt() { |
||||
s.defaultMode.ApplyMode() |
||||
} |
||||
|
||||
// TerminalSupported returns true because line editing is always
|
||||
// supported on Windows.
|
||||
func TerminalSupported() bool { |
||||
return true |
||||
} |
||||
|
||||
func (mode inputMode) ApplyMode() error { |
||||
hIn, _, err := procGetStdHandle.Call(uintptr(std_input_handle)) |
||||
if hIn == invalid_handle_value || hIn == 0 { |
||||
return err |
||||
} |
||||
ok, _, err := procSetConsoleMode.Call(hIn, uintptr(mode)) |
||||
if ok != 0 { |
||||
err = nil |
||||
} |
||||
return err |
||||
} |
||||
|
||||
// TerminalMode returns the current terminal input mode as an InputModeSetter.
|
||||
//
|
||||
// This function is provided for convenience, and should
|
||||
// not be necessary for most users of liner.
|
||||
func TerminalMode() (ModeApplier, error) { |
||||
var mode inputMode |
||||
hIn, _, err := procGetStdHandle.Call(uintptr(std_input_handle)) |
||||
if hIn == invalid_handle_value || hIn == 0 { |
||||
return nil, err |
||||
} |
||||
ok, _, err := procGetConsoleMode.Call(hIn, uintptr(unsafe.Pointer(&mode))) |
||||
if ok != 0 { |
||||
err = nil |
||||
} |
||||
return mode, err |
||||
} |
@ -0,0 +1,864 @@ |
||||
// +build windows linux darwin openbsd freebsd netbsd
|
||||
|
||||
package liner |
||||
|
||||
import ( |
||||
"container/ring" |
||||
"errors" |
||||
"fmt" |
||||
"io" |
||||
"strings" |
||||
"unicode" |
||||
"unicode/utf8" |
||||
) |
||||
|
||||
type action int |
||||
|
||||
const ( |
||||
left action = iota |
||||
right |
||||
up |
||||
down |
||||
home |
||||
end |
||||
insert |
||||
del |
||||
pageUp |
||||
pageDown |
||||
f1 |
||||
f2 |
||||
f3 |
||||
f4 |
||||
f5 |
||||
f6 |
||||
f7 |
||||
f8 |
||||
f9 |
||||
f10 |
||||
f11 |
||||
f12 |
||||
altY |
||||
shiftTab |
||||
wordLeft |
||||
wordRight |
||||
winch |
||||
unknown |
||||
) |
||||
|
||||
const ( |
||||
ctrlA = 1 |
||||
ctrlB = 2 |
||||
ctrlC = 3 |
||||
ctrlD = 4 |
||||
ctrlE = 5 |
||||
ctrlF = 6 |
||||
ctrlG = 7 |
||||
ctrlH = 8 |
||||
tab = 9 |
||||
lf = 10 |
||||
ctrlK = 11 |
||||
ctrlL = 12 |
||||
cr = 13 |
||||
ctrlN = 14 |
||||
ctrlO = 15 |
||||
ctrlP = 16 |
||||
ctrlQ = 17 |
||||
ctrlR = 18 |
||||
ctrlS = 19 |
||||
ctrlT = 20 |
||||
ctrlU = 21 |
||||
ctrlV = 22 |
||||
ctrlW = 23 |
||||
ctrlX = 24 |
||||
ctrlY = 25 |
||||
ctrlZ = 26 |
||||
esc = 27 |
||||
bs = 127 |
||||
) |
||||
|
||||
const ( |
||||
beep = "\a" |
||||
) |
||||
|
||||
type tabDirection int |
||||
|
||||
const ( |
||||
tabForward tabDirection = iota |
||||
tabReverse |
||||
) |
||||
|
||||
func (s *State) refresh(prompt []rune, buf []rune, pos int) error { |
||||
s.cursorPos(0) |
||||
_, err := fmt.Print(string(prompt)) |
||||
if err != nil { |
||||
return err |
||||
} |
||||
|
||||
pLen := countGlyphs(prompt) |
||||
bLen := countGlyphs(buf) |
||||
pos = countGlyphs(buf[:pos]) |
||||
if pLen+bLen < s.columns { |
||||
_, err = fmt.Print(string(buf)) |
||||
s.eraseLine() |
||||
s.cursorPos(pLen + pos) |
||||
} else { |
||||
// Find space available
|
||||
space := s.columns - pLen |
||||
space-- // space for cursor
|
||||
start := pos - space/2 |
||||
end := start + space |
||||
if end > bLen { |
||||
end = bLen |
||||
start = end - space |
||||
} |
||||
if start < 0 { |
||||
start = 0 |
||||
end = space |
||||
} |
||||
pos -= start |
||||
|
||||
// Leave space for markers
|
||||
if start > 0 { |
||||
start++ |
||||
} |
||||
if end < bLen { |
||||
end-- |
||||
} |
||||
startRune := len(getPrefixGlyphs(buf, start)) |
||||
line := getPrefixGlyphs(buf[startRune:], end-start) |
||||
|
||||
// Output
|
||||
if start > 0 { |
||||
fmt.Print("{") |
||||
} |
||||
fmt.Print(string(line)) |
||||
if end < bLen { |
||||
fmt.Print("}") |
||||
} |
||||
|
||||
// Set cursor position
|
||||
s.eraseLine() |
||||
s.cursorPos(pLen + pos) |
||||
} |
||||
return err |
||||
} |
||||
|
||||
func longestCommonPrefix(strs []string) string { |
||||
if len(strs) == 0 { |
||||
return "" |
||||
} |
||||
longest := strs[0] |
||||
|
||||
for _, str := range strs[1:] { |
||||
for !strings.HasPrefix(str, longest) { |
||||
longest = longest[:len(longest)-1] |
||||
} |
||||
} |
||||
// Remove trailing partial runes
|
||||
longest = strings.TrimRight(longest, "\uFFFD") |
||||
return longest |
||||
} |
||||
|
||||
func (s *State) circularTabs(items []string) func(tabDirection) (string, error) { |
||||
item := -1 |
||||
return func(direction tabDirection) (string, error) { |
||||
if direction == tabForward { |
||||
if item < len(items)-1 { |
||||
item++ |
||||
} else { |
||||
item = 0 |
||||
} |
||||
} else if direction == tabReverse { |
||||
if item > 0 { |
||||
item-- |
||||
} else { |
||||
item = len(items) - 1 |
||||
} |
||||
} |
||||
return items[item], nil |
||||
} |
||||
} |
||||
|
||||
func (s *State) printedTabs(items []string) func(tabDirection) (string, error) { |
||||
numTabs := 1 |
||||
prefix := longestCommonPrefix(items) |
||||
return func(direction tabDirection) (string, error) { |
||||
if len(items) == 1 { |
||||
return items[0], nil |
||||
} |
||||
|
||||
if numTabs == 2 { |
||||
if len(items) > 100 { |
||||
fmt.Printf("\nDisplay all %d possibilities? (y or n) ", len(items)) |
||||
for { |
||||
next, err := s.readNext() |
||||
if err != nil { |
||||
return prefix, err |
||||
} |
||||
|
||||
if key, ok := next.(rune); ok { |
||||
if unicode.ToLower(key) == 'n' { |
||||
return prefix, nil |
||||
} else if unicode.ToLower(key) == 'y' { |
||||
break |
||||
} |
||||
} |
||||
} |
||||
} |
||||
fmt.Println("") |
||||
maxWidth := 0 |
||||
for _, item := range items { |
||||
if len(item) >= maxWidth { |
||||
maxWidth = len(item) + 1 |
||||
} |
||||
} |
||||
|
||||
numColumns := s.columns / maxWidth |
||||
numRows := len(items) / numColumns |
||||
if len(items)%numColumns > 0 { |
||||
numRows++ |
||||
} |
||||
|
||||
if len(items) <= numColumns { |
||||
maxWidth = 0 |
||||
} |
||||
for i := 0; i < numRows; i++ { |
||||
for j := 0; j < numColumns*numRows; j += numRows { |
||||
if i+j < len(items) { |
||||
if maxWidth > 0 { |
||||
fmt.Printf("%-*s", maxWidth, items[i+j]) |
||||
} else { |
||||
fmt.Printf("%v ", items[i+j]) |
||||
} |
||||
} |
||||
} |
||||
fmt.Println("") |
||||
} |
||||
} else { |
||||
numTabs++ |
||||
} |
||||
return prefix, nil |
||||
} |
||||
} |
||||
|
||||
func (s *State) tabComplete(p []rune, line []rune, pos int) ([]rune, int, interface{}, error) { |
||||
if s.completer == nil { |
||||
return line, pos, rune(esc), nil |
||||
} |
||||
head, list, tail := s.completer(string(line), pos) |
||||
if len(list) <= 0 { |
||||
return line, pos, rune(esc), nil |
||||
} |
||||
hl := utf8.RuneCountInString(head) |
||||
if len(list) == 1 { |
||||
s.refresh(p, []rune(head+list[0]+tail), hl+utf8.RuneCountInString(list[0])) |
||||
return []rune(head + list[0] + tail), hl + utf8.RuneCountInString(list[0]), rune(esc), nil |
||||
} |
||||
|
||||
direction := tabForward |
||||
tabPrinter := s.circularTabs(list) |
||||
if s.tabStyle == TabPrints { |
||||
tabPrinter = s.printedTabs(list) |
||||
} |
||||
|
||||
for { |
||||
pick, err := tabPrinter(direction) |
||||
if err != nil { |
||||
return line, pos, rune(esc), err |
||||
} |
||||
s.refresh(p, []rune(head+pick+tail), hl+utf8.RuneCountInString(pick)) |
||||
|
||||
next, err := s.readNext() |
||||
if err != nil { |
||||
return line, pos, rune(esc), err |
||||
} |
||||
if key, ok := next.(rune); ok { |
||||
if key == tab { |
||||
direction = tabForward |
||||
continue |
||||
} |
||||
if key == esc { |
||||
return line, pos, rune(esc), nil |
||||
} |
||||
} |
||||
if a, ok := next.(action); ok && a == shiftTab { |
||||
direction = tabReverse |
||||
continue |
||||
} |
||||
return []rune(head + pick + tail), hl + utf8.RuneCountInString(pick), next, nil |
||||
} |
||||
// Not reached
|
||||
return line, pos, rune(esc), nil |
||||
} |
||||
|
||||
// reverse intelligent search, implements a bash-like history search.
|
||||
func (s *State) reverseISearch(origLine []rune, origPos int) ([]rune, int, interface{}, error) { |
||||
p := "(reverse-i-search)`': " |
||||
s.refresh([]rune(p), origLine, origPos) |
||||
|
||||
line := []rune{} |
||||
pos := 0 |
||||
foundLine := string(origLine) |
||||
foundPos := origPos |
||||
|
||||
getLine := func() ([]rune, []rune, int) { |
||||
search := string(line) |
||||
prompt := "(reverse-i-search)`%s': " |
||||
return []rune(fmt.Sprintf(prompt, search)), []rune(foundLine), foundPos |
||||
} |
||||
|
||||
history, positions := s.getHistoryByPattern(string(line)) |
||||
historyPos := len(history) - 1 |
||||
|
||||
for { |
||||
next, err := s.readNext() |
||||
if err != nil { |
||||
return []rune(foundLine), foundPos, rune(esc), err |
||||
} |
||||
|
||||
switch v := next.(type) { |
||||
case rune: |
||||
switch v { |
||||
case ctrlR: // Search backwards
|
||||
if historyPos > 0 && historyPos < len(history) { |
||||
historyPos-- |
||||
foundLine = history[historyPos] |
||||
foundPos = positions[historyPos] |
||||
} else { |
||||
fmt.Print(beep) |
||||
} |
||||
case ctrlS: // Search forward
|
||||
if historyPos < len(history)-1 && historyPos >= 0 { |
||||
historyPos++ |
||||
foundLine = history[historyPos] |
||||
foundPos = positions[historyPos] |
||||
} else { |
||||
fmt.Print(beep) |
||||
} |
||||
case ctrlH, bs: // Backspace
|
||||
if pos <= 0 { |
||||
fmt.Print(beep) |
||||
} else { |
||||
n := len(getSuffixGlyphs(line[:pos], 1)) |
||||
line = append(line[:pos-n], line[pos:]...) |
||||
pos -= n |
||||
|
||||
// For each char deleted, display the last matching line of history
|
||||
history, positions := s.getHistoryByPattern(string(line)) |
||||
historyPos = len(history) - 1 |
||||
if len(history) > 0 { |
||||
foundLine = history[historyPos] |
||||
foundPos = positions[historyPos] |
||||
} else { |
||||
foundLine = "" |
||||
foundPos = 0 |
||||
} |
||||
} |
||||
case ctrlG: // Cancel
|
||||
return origLine, origPos, rune(esc), err |
||||
|
||||
case tab, cr, lf, ctrlA, ctrlB, ctrlD, ctrlE, ctrlF, ctrlK, |
||||
ctrlL, ctrlN, ctrlO, ctrlP, ctrlQ, ctrlT, ctrlU, ctrlV, ctrlW, ctrlX, ctrlY, ctrlZ: |
||||
fallthrough |
||||
case 0, ctrlC, esc, 28, 29, 30, 31: |
||||
return []rune(foundLine), foundPos, next, err |
||||
default: |
||||
line = append(line[:pos], append([]rune{v}, line[pos:]...)...) |
||||
pos++ |
||||
|
||||
// For each keystroke typed, display the last matching line of history
|
||||
history, positions = s.getHistoryByPattern(string(line)) |
||||
historyPos = len(history) - 1 |
||||
if len(history) > 0 { |
||||
foundLine = history[historyPos] |
||||
foundPos = positions[historyPos] |
||||
} else { |
||||
foundLine = "" |
||||
foundPos = 0 |
||||
} |
||||
} |
||||
case action: |
||||
return []rune(foundLine), foundPos, next, err |
||||
} |
||||
s.refresh(getLine()) |
||||
} |
||||
} |
||||
|
||||
// addToKillRing adds some text to the kill ring. If mode is 0 it adds it to a
|
||||
// new node in the end of the kill ring, and move the current pointer to the new
|
||||
// node. If mode is 1 or 2 it appends or prepends the text to the current entry
|
||||
// of the killRing.
|
||||
func (s *State) addToKillRing(text []rune, mode int) { |
||||
// Don't use the same underlying array as text
|
||||
killLine := make([]rune, len(text)) |
||||
copy(killLine, text) |
||||
|
||||
// Point killRing to a newNode, procedure depends on the killring state and
|
||||
// append mode.
|
||||
if mode == 0 { // Add new node to killRing
|
||||
if s.killRing == nil { // if killring is empty, create a new one
|
||||
s.killRing = ring.New(1) |
||||
} else if s.killRing.Len() >= KillRingMax { // if killring is "full"
|
||||
s.killRing = s.killRing.Next() |
||||
} else { // Normal case
|
||||
s.killRing.Link(ring.New(1)) |
||||
s.killRing = s.killRing.Next() |
||||
} |
||||
} else { |
||||
if s.killRing == nil { // if killring is empty, create a new one
|
||||
s.killRing = ring.New(1) |
||||
s.killRing.Value = []rune{} |
||||
} |
||||
if mode == 1 { // Append to last entry
|
||||
killLine = append(s.killRing.Value.([]rune), killLine...) |
||||
} else if mode == 2 { // Prepend to last entry
|
||||
killLine = append(killLine, s.killRing.Value.([]rune)...) |
||||
} |
||||
} |
||||
|
||||
// Save text in the current killring node
|
||||
s.killRing.Value = killLine |
||||
} |
||||
|
||||
func (s *State) yank(p []rune, text []rune, pos int) ([]rune, int, interface{}, error) { |
||||
if s.killRing == nil { |
||||
return text, pos, rune(esc), nil |
||||
} |
||||
|
||||
lineStart := text[:pos] |
||||
lineEnd := text[pos:] |
||||
var line []rune |
||||
|
||||
for { |
||||
value := s.killRing.Value.([]rune) |
||||
line = make([]rune, 0) |
||||
line = append(line, lineStart...) |
||||
line = append(line, value...) |
||||
line = append(line, lineEnd...) |
||||
|
||||
pos = len(lineStart) + len(value) |
||||
s.refresh(p, line, pos) |
||||
|
||||
next, err := s.readNext() |
||||
if err != nil { |
||||
return line, pos, next, err |
||||
} |
||||
|
||||
switch v := next.(type) { |
||||
case rune: |
||||
return line, pos, next, nil |
||||
case action: |
||||
switch v { |
||||
case altY: |
||||
s.killRing = s.killRing.Prev() |
||||
default: |
||||
return line, pos, next, nil |
||||
} |
||||
} |
||||
} |
||||
|
||||
return line, pos, esc, nil |
||||
} |
||||
|
||||
// Prompt displays p, and then waits for user input. Prompt allows line editing
|
||||
// if the terminal supports it.
|
||||
func (s *State) Prompt(prompt string) (string, error) { |
||||
if s.inputRedirected { |
||||
return s.promptUnsupported(prompt) |
||||
} |
||||
if s.outputRedirected { |
||||
return "", ErrNotTerminalOutput |
||||
} |
||||
if !s.terminalSupported { |
||||
return s.promptUnsupported(prompt) |
||||
} |
||||
|
||||
s.historyMutex.RLock() |
||||
defer s.historyMutex.RUnlock() |
||||
|
||||
s.startPrompt() |
||||
defer s.stopPrompt() |
||||
s.getColumns() |
||||
|
||||
fmt.Print(prompt) |
||||
p := []rune(prompt) |
||||
var line []rune |
||||
pos := 0 |
||||
historyEnd := "" |
||||
prefixHistory := s.getHistoryByPrefix(string(line)) |
||||
historyPos := len(prefixHistory) |
||||
historyAction := false // used to mark history related actions
|
||||
killAction := 0 // used to mark kill related actions
|
||||
mainLoop: |
||||
for { |
||||
next, err := s.readNext() |
||||
haveNext: |
||||
if err != nil { |
||||
return "", err |
||||
} |
||||
|
||||
historyAction = false |
||||
switch v := next.(type) { |
||||
case rune: |
||||
switch v { |
||||
case cr, lf: |
||||
fmt.Println() |
||||
break mainLoop |
||||
case ctrlA: // Start of line
|
||||
pos = 0 |
||||
s.refresh(p, line, pos) |
||||
case ctrlE: // End of line
|
||||
pos = len(line) |
||||
s.refresh(p, line, pos) |
||||
case ctrlB: // left
|
||||
if pos > 0 { |
||||
pos -= len(getSuffixGlyphs(line[:pos], 1)) |
||||
s.refresh(p, line, pos) |
||||
} else { |
||||
fmt.Print(beep) |
||||
} |
||||
case ctrlF: // right
|
||||
if pos < len(line) { |
||||
pos += len(getPrefixGlyphs(line[pos:], 1)) |
||||
s.refresh(p, line, pos) |
||||
} else { |
||||
fmt.Print(beep) |
||||
} |
||||
case ctrlD: // del
|
||||
if pos == 0 && len(line) == 0 { |
||||
// exit
|
||||
return "", io.EOF |
||||
} |
||||
|
||||
// ctrlD is a potential EOF, so the rune reader shuts down.
|
||||
// Therefore, if it isn't actually an EOF, we must re-startPrompt.
|
||||
s.restartPrompt() |
||||
|
||||
if pos >= len(line) { |
||||
fmt.Print(beep) |
||||
} else { |
||||
n := len(getPrefixGlyphs(line[pos:], 1)) |
||||
line = append(line[:pos], line[pos+n:]...) |
||||
s.refresh(p, line, pos) |
||||
} |
||||
case ctrlK: // delete remainder of line
|
||||
if pos >= len(line) { |
||||
fmt.Print(beep) |
||||
} else { |
||||
if killAction > 0 { |
||||
s.addToKillRing(line[pos:], 1) // Add in apend mode
|
||||
} else { |
||||
s.addToKillRing(line[pos:], 0) // Add in normal mode
|
||||
} |
||||
|
||||
killAction = 2 // Mark that there was a kill action
|
||||
line = line[:pos] |
||||
s.refresh(p, line, pos) |
||||
} |
||||
case ctrlP: // up
|
||||
historyAction = true |
||||
if historyPos > 0 { |
||||
if historyPos == len(prefixHistory) { |
||||
historyEnd = string(line) |
||||
} |
||||
historyPos-- |
||||
line = []rune(prefixHistory[historyPos]) |
||||
pos = len(line) |
||||
s.refresh(p, line, pos) |
||||
} else { |
||||
fmt.Print(beep) |
||||
} |
||||
case ctrlN: // down
|
||||
historyAction = true |
||||
if historyPos < len(prefixHistory) { |
||||
historyPos++ |
||||
if historyPos == len(prefixHistory) { |
||||
line = []rune(historyEnd) |
||||
} else { |
||||
line = []rune(prefixHistory[historyPos]) |
||||
} |
||||
pos = len(line) |
||||
s.refresh(p, line, pos) |
||||
} else { |
||||
fmt.Print(beep) |
||||
} |
||||
case ctrlT: // transpose prev glyph with glyph under cursor
|
||||
if len(line) < 2 || pos < 1 { |
||||
fmt.Print(beep) |
||||
} else { |
||||
if pos == len(line) { |
||||
pos -= len(getSuffixGlyphs(line, 1)) |
||||
} |
||||
prev := getSuffixGlyphs(line[:pos], 1) |
||||
next := getPrefixGlyphs(line[pos:], 1) |
||||
scratch := make([]rune, len(prev)) |
||||
copy(scratch, prev) |
||||
copy(line[pos-len(prev):], next) |
||||
copy(line[pos-len(prev)+len(next):], scratch) |
||||
pos += len(next) |
||||
s.refresh(p, line, pos) |
||||
} |
||||
case ctrlL: // clear screen
|
||||
s.eraseScreen() |
||||
s.refresh(p, line, pos) |
||||
case ctrlC: // reset
|
||||
fmt.Println("^C") |
||||
if s.ctrlCAborts { |
||||
return "", ErrPromptAborted |
||||
} |
||||
line = line[:0] |
||||
pos = 0 |
||||
fmt.Print(prompt) |
||||
s.restartPrompt() |
||||
case ctrlH, bs: // Backspace
|
||||
if pos <= 0 { |
||||
fmt.Print(beep) |
||||
} else { |
||||
n := len(getSuffixGlyphs(line[:pos], 1)) |
||||
line = append(line[:pos-n], line[pos:]...) |
||||
pos -= n |
||||
s.refresh(p, line, pos) |
||||
} |
||||
case ctrlU: // Erase line before cursor
|
||||
if killAction > 0 { |
||||
s.addToKillRing(line[:pos], 2) // Add in prepend mode
|
||||
} else { |
||||
s.addToKillRing(line[:pos], 0) // Add in normal mode
|
||||
} |
||||
|
||||
killAction = 2 // Mark that there was some killing
|
||||
line = line[pos:] |
||||
pos = 0 |
||||
s.refresh(p, line, pos) |
||||
case ctrlW: // Erase word
|
||||
if pos == 0 { |
||||
fmt.Print(beep) |
||||
break |
||||
} |
||||
// Remove whitespace to the left
|
||||
var buf []rune // Store the deleted chars in a buffer
|
||||
for { |
||||
if pos == 0 || !unicode.IsSpace(line[pos-1]) { |
||||
break |
||||
} |
||||
buf = append(buf, line[pos-1]) |
||||
line = append(line[:pos-1], line[pos:]...) |
||||
pos-- |
||||
} |
||||
// Remove non-whitespace to the left
|
||||
for { |
||||
if pos == 0 || unicode.IsSpace(line[pos-1]) { |
||||
break |
||||
} |
||||
buf = append(buf, line[pos-1]) |
||||
line = append(line[:pos-1], line[pos:]...) |
||||
pos-- |
||||
} |
||||
// Invert the buffer and save the result on the killRing
|
||||
var newBuf []rune |
||||
for i := len(buf) - 1; i >= 0; i-- { |
||||
newBuf = append(newBuf, buf[i]) |
||||
} |
||||
if killAction > 0 { |
||||
s.addToKillRing(newBuf, 2) // Add in prepend mode
|
||||
} else { |
||||
s.addToKillRing(newBuf, 0) // Add in normal mode
|
||||
} |
||||
killAction = 2 // Mark that there was some killing
|
||||
|
||||
s.refresh(p, line, pos) |
||||
case ctrlY: // Paste from Yank buffer
|
||||
line, pos, next, err = s.yank(p, line, pos) |
||||
goto haveNext |
||||
case ctrlR: // Reverse Search
|
||||
line, pos, next, err = s.reverseISearch(line, pos) |
||||
s.refresh(p, line, pos) |
||||
goto haveNext |
||||
case tab: // Tab completion
|
||||
line, pos, next, err = s.tabComplete(p, line, pos) |
||||
goto haveNext |
||||
// Catch keys that do nothing, but you don't want them to beep
|
||||
case esc: |
||||
// DO NOTHING
|
||||
// Unused keys
|
||||
case ctrlG, ctrlO, ctrlQ, ctrlS, ctrlV, ctrlX, ctrlZ: |
||||
fallthrough |
||||
// Catch unhandled control codes (anything <= 31)
|
||||
case 0, 28, 29, 30, 31: |
||||
fmt.Print(beep) |
||||
default: |
||||
if pos == len(line) && len(p)+len(line) < s.columns-1 { |
||||
line = append(line, v) |
||||
fmt.Printf("%c", v) |
||||
pos++ |
||||
} else { |
||||
line = append(line[:pos], append([]rune{v}, line[pos:]...)...) |
||||
pos++ |
||||
s.refresh(p, line, pos) |
||||
} |
||||
} |
||||
case action: |
||||
switch v { |
||||
case del: |
||||
if pos >= len(line) { |
||||
fmt.Print(beep) |
||||
} else { |
||||
n := len(getPrefixGlyphs(line[pos:], 1)) |
||||
line = append(line[:pos], line[pos+n:]...) |
||||
} |
||||
case left: |
||||
if pos > 0 { |
||||
pos -= len(getSuffixGlyphs(line[:pos], 1)) |
||||
} else { |
||||
fmt.Print(beep) |
||||
} |
||||
case wordLeft: |
||||
if pos > 0 { |
||||
for { |
||||
pos-- |
||||
if pos == 0 || unicode.IsSpace(line[pos-1]) { |
||||
break |
||||
} |
||||
} |
||||
} else { |
||||
fmt.Print(beep) |
||||
} |
||||
case right: |
||||
if pos < len(line) { |
||||
pos += len(getPrefixGlyphs(line[pos:], 1)) |
||||
} else { |
||||
fmt.Print(beep) |
||||
} |
||||
case wordRight: |
||||
if pos < len(line) { |
||||
for { |
||||
pos++ |
||||
if pos == len(line) || unicode.IsSpace(line[pos]) { |
||||
break |
||||
} |
||||
} |
||||
} else { |
||||
fmt.Print(beep) |
||||
} |
||||
case up: |
||||
historyAction = true |
||||
if historyPos > 0 { |
||||
if historyPos == len(prefixHistory) { |
||||
historyEnd = string(line) |
||||
} |
||||
historyPos-- |
||||
line = []rune(prefixHistory[historyPos]) |
||||
pos = len(line) |
||||
} else { |
||||
fmt.Print(beep) |
||||
} |
||||
case down: |
||||
historyAction = true |
||||
if historyPos < len(prefixHistory) { |
||||
historyPos++ |
||||
if historyPos == len(prefixHistory) { |
||||
line = []rune(historyEnd) |
||||
} else { |
||||
line = []rune(prefixHistory[historyPos]) |
||||
} |
||||
pos = len(line) |
||||
} else { |
||||
fmt.Print(beep) |
||||
} |
||||
case home: // Start of line
|
||||
pos = 0 |
||||
case end: // End of line
|
||||
pos = len(line) |
||||
} |
||||
s.refresh(p, line, pos) |
||||
} |
||||
if !historyAction { |
||||
prefixHistory = s.getHistoryByPrefix(string(line)) |
||||
historyPos = len(prefixHistory) |
||||
} |
||||
if killAction > 0 { |
||||
killAction-- |
||||
} |
||||
} |
||||
return string(line), nil |
||||
} |
||||
|
||||
// PasswordPrompt displays p, and then waits for user input. The input typed by
|
||||
// the user is not displayed in the terminal.
|
||||
func (s *State) PasswordPrompt(prompt string) (string, error) { |
||||
if s.inputRedirected { |
||||
return s.promptUnsupported(prompt) |
||||
} |
||||
if s.outputRedirected { |
||||
return "", ErrNotTerminalOutput |
||||
} |
||||
if !s.terminalSupported { |
||||
return "", errors.New("liner: function not supported in this terminal") |
||||
} |
||||
|
||||
s.startPrompt() |
||||
defer s.stopPrompt() |
||||
s.getColumns() |
||||
|
||||
fmt.Print(prompt) |
||||
p := []rune(prompt) |
||||
var line []rune |
||||
pos := 0 |
||||
|
||||
mainLoop: |
||||
for { |
||||
next, err := s.readNext() |
||||
if err != nil { |
||||
return "", err |
||||
} |
||||
|
||||
switch v := next.(type) { |
||||
case rune: |
||||
switch v { |
||||
case cr, lf: |
||||
fmt.Println() |
||||
break mainLoop |
||||
case ctrlD: // del
|
||||
if pos == 0 && len(line) == 0 { |
||||
// exit
|
||||
return "", io.EOF |
||||
} |
||||
|
||||
// ctrlD is a potential EOF, so the rune reader shuts down.
|
||||
// Therefore, if it isn't actually an EOF, we must re-startPrompt.
|
||||
s.restartPrompt() |
||||
case ctrlL: // clear screen
|
||||
s.eraseScreen() |
||||
s.refresh(p, []rune{}, 0) |
||||
case ctrlH, bs: // Backspace
|
||||
if pos <= 0 { |
||||
fmt.Print(beep) |
||||
} else { |
||||
n := len(getSuffixGlyphs(line[:pos], 1)) |
||||
line = append(line[:pos-n], line[pos:]...) |
||||
pos -= n |
||||
} |
||||
case ctrlC: |
||||
fmt.Println("^C") |
||||
if s.ctrlCAborts { |
||||
return "", ErrPromptAborted |
||||
} |
||||
line = line[:0] |
||||
pos = 0 |
||||
fmt.Print(prompt) |
||||
s.restartPrompt() |
||||
// Unused keys
|
||||
case esc, tab, ctrlA, ctrlB, ctrlE, ctrlF, ctrlG, ctrlK, ctrlN, ctrlO, ctrlP, ctrlQ, ctrlR, ctrlS, |
||||
ctrlT, ctrlU, ctrlV, ctrlW, ctrlX, ctrlY, ctrlZ: |
||||
fallthrough |
||||
// Catch unhandled control codes (anything <= 31)
|
||||
case 0, 28, 29, 30, 31: |
||||
fmt.Print(beep) |
||||
default: |
||||
line = append(line[:pos], append([]rune{v}, line[pos:]...)...) |
||||
pos++ |
||||
} |
||||
} |
||||
} |
||||
return string(line), nil |
||||
} |
@ -0,0 +1,90 @@ |
||||
package liner |
||||
|
||||
import ( |
||||
"bytes" |
||||
"strings" |
||||
"testing" |
||||
) |
||||
|
||||
func TestAppend(t *testing.T) { |
||||
var s State |
||||
s.AppendHistory("foo") |
||||
s.AppendHistory("bar") |
||||
|
||||
var out bytes.Buffer |
||||
num, err := s.WriteHistory(&out) |
||||
if err != nil { |
||||
t.Fatal("Unexpected error writing history", err) |
||||
} |
||||
if num != 2 { |
||||
t.Fatalf("Expected 2 history entries, got %d", num) |
||||
} |
||||
|
||||
s.AppendHistory("baz") |
||||
num, err = s.WriteHistory(&out) |
||||
if err != nil { |
||||
t.Fatal("Unexpected error writing history", err) |
||||
} |
||||
if num != 3 { |
||||
t.Fatalf("Expected 3 history entries, got %d", num) |
||||
} |
||||
|
||||
s.AppendHistory("baz") |
||||
num, err = s.WriteHistory(&out) |
||||
if err != nil { |
||||
t.Fatal("Unexpected error writing history", err) |
||||
} |
||||
if num != 3 { |
||||
t.Fatalf("Expected 3 history entries after duplicate append, got %d", num) |
||||
} |
||||
|
||||
s.AppendHistory("baz") |
||||
|
||||
} |
||||
|
||||
func TestHistory(t *testing.T) { |
||||
input := `foo |
||||
bar |
||||
baz |
||||
quux |
||||
dingle` |
||||
|
||||
var s State |
||||
num, err := s.ReadHistory(strings.NewReader(input)) |
||||
if err != nil { |
||||
t.Fatal("Unexpected error reading history", err) |
||||
} |
||||
if num != 5 { |
||||
t.Fatal("Wrong number of history entries read") |
||||
} |
||||
|
||||
var out bytes.Buffer |
||||
num, err = s.WriteHistory(&out) |
||||
if err != nil { |
||||
t.Fatal("Unexpected error writing history", err) |
||||
} |
||||
if num != 5 { |
||||
t.Fatal("Wrong number of history entries written") |
||||
} |
||||
if strings.TrimSpace(out.String()) != input { |
||||
t.Fatal("Round-trip failure") |
||||
} |
||||
|
||||
// Test reading with a trailing newline present
|
||||
var s2 State |
||||
num, err = s2.ReadHistory(&out) |
||||
if err != nil { |
||||
t.Fatal("Unexpected error reading history the 2nd time", err) |
||||
} |
||||
if num != 5 { |
||||
t.Fatal("Wrong number of history entries read the 2nd time") |
||||
} |
||||
|
||||
num, err = s.ReadHistory(strings.NewReader(input + "\n\xff")) |
||||
if err == nil { |
||||
t.Fatal("Unexpected success reading corrupted history", err) |
||||
} |
||||
if num != 5 { |
||||
t.Fatal("Wrong number of history entries read the 3rd time") |
||||
} |
||||
} |
@ -0,0 +1,63 @@ |
||||
// +build linux darwin openbsd freebsd netbsd
|
||||
|
||||
package liner |
||||
|
||||
import ( |
||||
"fmt" |
||||
"os" |
||||
"strings" |
||||
"syscall" |
||||
"unsafe" |
||||
) |
||||
|
||||
func (s *State) cursorPos(x int) { |
||||
if s.useCHA { |
||||
// 'G' is "Cursor Character Absolute (CHA)"
|
||||
fmt.Printf("\x1b[%dG", x+1) |
||||
} else { |
||||
// 'C' is "Cursor Forward (CUF)"
|
||||
fmt.Print("\r") |
||||
if x > 0 { |
||||
fmt.Printf("\x1b[%dC", x) |
||||
} |
||||
} |
||||
} |
||||
|
||||
func (s *State) eraseLine() { |
||||
fmt.Print("\x1b[0K") |
||||
} |
||||
|
||||
func (s *State) eraseScreen() { |
||||
fmt.Print("\x1b[H\x1b[2J") |
||||
} |
||||
|
||||
type winSize struct { |
||||
row, col uint16 |
||||
xpixel, ypixel uint16 |
||||
} |
||||
|
||||
func (s *State) getColumns() { |
||||
var ws winSize |
||||
ok, _, _ := syscall.Syscall(syscall.SYS_IOCTL, uintptr(syscall.Stdout), |
||||
syscall.TIOCGWINSZ, uintptr(unsafe.Pointer(&ws))) |
||||
if ok < 0 { |
||||
s.columns = 80 |
||||
} |
||||
s.columns = int(ws.col) |
||||
} |
||||
|
||||
func (s *State) checkOutput() { |
||||
// xterm is known to support CHA
|
||||
if strings.Contains(strings.ToLower(os.Getenv("TERM")), "xterm") { |
||||
s.useCHA = true |
||||
return |
||||
} |
||||
|
||||
// The test for functional ANSI CHA is unreliable (eg the Windows
|
||||
// telnet command does not support reading the cursor position with
|
||||
// an ANSI DSR request, despite setting TERM=ansi)
|
||||
|
||||
// Assume CHA isn't supported (which should be safe, although it
|
||||
// does result in occasional visible cursor jitter)
|
||||
s.useCHA = false |
||||
} |
@ -0,0 +1,54 @@ |
||||
package liner |
||||
|
||||
import ( |
||||
"unsafe" |
||||
) |
||||
|
||||
type coord struct { |
||||
x, y int16 |
||||
} |
||||
type smallRect struct { |
||||
left, top, right, bottom int16 |
||||
} |
||||
|
||||
type consoleScreenBufferInfo struct { |
||||
dwSize coord |
||||
dwCursorPosition coord |
||||
wAttributes int16 |
||||
srWindow smallRect |
||||
dwMaximumWindowSize coord |
||||
} |
||||
|
||||
func (s *State) cursorPos(x int) { |
||||
var sbi consoleScreenBufferInfo |
||||
procGetConsoleScreenBufferInfo.Call(uintptr(s.hOut), uintptr(unsafe.Pointer(&sbi))) |
||||
procSetConsoleCursorPosition.Call(uintptr(s.hOut), |
||||
uintptr(int(x)&0xFFFF|int(sbi.dwCursorPosition.y)<<16)) |
||||
} |
||||
|
||||
func (s *State) eraseLine() { |
||||
var sbi consoleScreenBufferInfo |
||||
procGetConsoleScreenBufferInfo.Call(uintptr(s.hOut), uintptr(unsafe.Pointer(&sbi))) |
||||
var numWritten uint32 |
||||
procFillConsoleOutputCharacter.Call(uintptr(s.hOut), uintptr(' '), |
||||
uintptr(sbi.dwSize.x-sbi.dwCursorPosition.x), |
||||
uintptr(int(sbi.dwCursorPosition.x)&0xFFFF|int(sbi.dwCursorPosition.y)<<16), |
||||
uintptr(unsafe.Pointer(&numWritten))) |
||||
} |
||||
|
||||
func (s *State) eraseScreen() { |
||||
var sbi consoleScreenBufferInfo |
||||
procGetConsoleScreenBufferInfo.Call(uintptr(s.hOut), uintptr(unsafe.Pointer(&sbi))) |
||||
var numWritten uint32 |
||||
procFillConsoleOutputCharacter.Call(uintptr(s.hOut), uintptr(' '), |
||||
uintptr(sbi.dwSize.x)*uintptr(sbi.dwSize.y), |
||||
0, |
||||
uintptr(unsafe.Pointer(&numWritten))) |
||||
procSetConsoleCursorPosition.Call(uintptr(s.hOut), 0) |
||||
} |
||||
|
||||
func (s *State) getColumns() { |
||||
var sbi consoleScreenBufferInfo |
||||
procGetConsoleScreenBufferInfo.Call(uintptr(s.hOut), uintptr(unsafe.Pointer(&sbi))) |
||||
s.columns = int(sbi.dwSize.x) |
||||
} |
@ -0,0 +1,37 @@ |
||||
// +build windows linux darwin openbsd freebsd netbsd
|
||||
|
||||
package liner |
||||
|
||||
import "testing" |
||||
|
||||
type testItem struct { |
||||
list []string |
||||
prefix string |
||||
} |
||||
|
||||
func TestPrefix(t *testing.T) { |
||||
list := []testItem{ |
||||
{[]string{"food", "foot"}, "foo"}, |
||||
{[]string{"foo", "foot"}, "foo"}, |
||||
{[]string{"food", "foo"}, "foo"}, |
||||
{[]string{"food", "foe", "foot"}, "fo"}, |
||||
{[]string{"food", "foot", "barbeque"}, ""}, |
||||
{[]string{"cafeteria", "café"}, "caf"}, |
||||
{[]string{"cafe", "café"}, "caf"}, |
||||
{[]string{"cafè", "café"}, "caf"}, |
||||
{[]string{"cafés", "café"}, "café"}, |
||||
{[]string{"áéíóú", "áéíóú"}, "áéíóú"}, |
||||
{[]string{"éclairs", "éclairs"}, "éclairs"}, |
||||
{[]string{"éclairs are the best", "éclairs are great", "éclairs"}, "éclairs"}, |
||||
{[]string{"éclair", "éclairs"}, "éclair"}, |
||||
{[]string{"éclairs", "éclair"}, "éclair"}, |
||||
{[]string{"éclair", "élan"}, "é"}, |
||||
} |
||||
|
||||
for _, test := range list { |
||||
lcp := longestCommonPrefix(test.list) |
||||
if lcp != test.prefix { |
||||
t.Errorf("%s != %s for %+v", lcp, test.prefix, test.list) |
||||
} |
||||
} |
||||
} |
@ -0,0 +1,44 @@ |
||||
// +build race
|
||||
|
||||
package liner |
||||
|
||||
import ( |
||||
"io/ioutil" |
||||
"os" |
||||
"sync" |
||||
"testing" |
||||
) |
||||
|
||||
func TestWriteHistory(t *testing.T) { |
||||
oldout := os.Stdout |
||||
defer func() { os.Stdout = oldout }() |
||||
oldin := os.Stdout |
||||
defer func() { os.Stdin = oldin }() |
||||
|
||||
newinr, newinw, err := os.Pipe() |
||||
if err != nil { |
||||
t.Fatal(err) |
||||
} |
||||
os.Stdin = newinr |
||||
newoutr, newoutw, err := os.Pipe() |
||||
if err != nil { |
||||
t.Fatal(err) |
||||
} |
||||
defer newoutr.Close() |
||||
os.Stdout = newoutw |
||||
|
||||
var wait sync.WaitGroup |
||||
wait.Add(1) |
||||
s := NewLiner() |
||||
go func() { |
||||
s.AppendHistory("foo") |
||||
s.AppendHistory("bar") |
||||
s.Prompt("") |
||||
wait.Done() |
||||
}() |
||||
|
||||
s.WriteHistory(ioutil.Discard) |
||||
|
||||
newinw.Close() |
||||
wait.Wait() |
||||
} |
@ -0,0 +1,12 @@ |
||||
// +build go1.1,!windows
|
||||
|
||||
package liner |
||||
|
||||
import ( |
||||
"os" |
||||
"os/signal" |
||||
) |
||||
|
||||
func stopSignal(c chan<- os.Signal) { |
||||
signal.Stop(c) |
||||
} |
@ -0,0 +1,11 @@ |
||||
// +build !go1.1,!windows
|
||||
|
||||
package liner |
||||
|
||||
import ( |
||||
"os" |
||||
) |
||||
|
||||
func stopSignal(c chan<- os.Signal) { |
||||
// signal.Stop does not exist before Go 1.1
|
||||
} |
@ -0,0 +1,37 @@ |
||||
// +build linux darwin freebsd openbsd netbsd
|
||||
|
||||
package liner |
||||
|
||||
import ( |
||||
"syscall" |
||||
"unsafe" |
||||
) |
||||
|
||||
func (mode *termios) ApplyMode() error { |
||||
_, _, errno := syscall.Syscall(syscall.SYS_IOCTL, uintptr(syscall.Stdin), setTermios, uintptr(unsafe.Pointer(mode))) |
||||
|
||||
if errno != 0 { |
||||
return errno |
||||
} |
||||
return nil |
||||
} |
||||
|
||||
// TerminalMode returns the current terminal input mode as an InputModeSetter.
|
||||
//
|
||||
// This function is provided for convenience, and should
|
||||
// not be necessary for most users of liner.
|
||||
func TerminalMode() (ModeApplier, error) { |
||||
mode, errno := getMode(syscall.Stdin) |
||||
|
||||
if errno != 0 { |
||||
return nil, errno |
||||
} |
||||
return mode, nil |
||||
} |
||||
|
||||
func getMode(handle int) (*termios, syscall.Errno) { |
||||
var mode termios |
||||
_, _, errno := syscall.Syscall(syscall.SYS_IOCTL, uintptr(handle), getTermios, uintptr(unsafe.Pointer(&mode))) |
||||
|
||||
return &mode, errno |
||||
} |
@ -0,0 +1,47 @@ |
||||
package liner |
||||
|
||||
import "unicode" |
||||
|
||||
// These character classes are mostly zero width (when combined).
|
||||
// A few might not be, depending on the user's font. Fixing this
|
||||
// is non-trivial, given that some terminals don't support
|
||||
// ANSI DSR/CPR
|
||||
var zeroWidth = []*unicode.RangeTable{ |
||||
unicode.Mn, |
||||
unicode.Me, |
||||
unicode.Cc, |
||||
unicode.Cf, |
||||
} |
||||
|
||||
func countGlyphs(s []rune) int { |
||||
n := 0 |
||||
for _, r := range s { |
||||
if !unicode.IsOneOf(zeroWidth, r) { |
||||
n++ |
||||
} |
||||
} |
||||
return n |
||||
} |
||||
|
||||
func getPrefixGlyphs(s []rune, num int) []rune { |
||||
p := 0 |
||||
for n := 0; n < num && p < len(s); p++ { |
||||
if !unicode.IsOneOf(zeroWidth, s[p]) { |
||||
n++ |
||||
} |
||||
} |
||||
for p < len(s) && unicode.IsOneOf(zeroWidth, s[p]) { |
||||
p++ |
||||
} |
||||
return s[:p] |
||||
} |
||||
|
||||
func getSuffixGlyphs(s []rune, num int) []rune { |
||||
p := len(s) |
||||
for n := 0; n < num && p > 0; p-- { |
||||
if !unicode.IsOneOf(zeroWidth, s[p-1]) { |
||||
n++ |
||||
} |
||||
} |
||||
return s[p:] |
||||
} |
@ -0,0 +1,87 @@ |
||||
package liner |
||||
|
||||
import ( |
||||
"strconv" |
||||
"testing" |
||||
) |
||||
|
||||
func accent(in []rune) []rune { |
||||
var out []rune |
||||
for _, r := range in { |
||||
out = append(out, r) |
||||
out = append(out, '\u0301') |
||||
} |
||||
return out |
||||
} |
||||
|
||||
var testString = []rune("query") |
||||
|
||||
func TestCountGlyphs(t *testing.T) { |
||||
count := countGlyphs(testString) |
||||
if count != len(testString) { |
||||
t.Errorf("ASCII count incorrect. %d != %d", count, len(testString)) |
||||
} |
||||
count = countGlyphs(accent(testString)) |
||||
if count != len(testString) { |
||||
t.Errorf("Accent count incorrect. %d != %d", count, len(testString)) |
||||
} |
||||
} |
||||
|
||||
func compare(a, b []rune, name string, t *testing.T) { |
||||
if len(a) != len(b) { |
||||
t.Errorf(`"%s" != "%s" in %s"`, string(a), string(b), name) |
||||
return |
||||
} |
||||
for i := range a { |
||||
if a[i] != b[i] { |
||||
t.Errorf(`"%s" != "%s" in %s"`, string(a), string(b), name) |
||||
return |
||||
} |
||||
} |
||||
} |
||||
|
||||
func TestPrefixGlyphs(t *testing.T) { |
||||
for i := 0; i <= len(testString); i++ { |
||||
iter := strconv.Itoa(i) |
||||
out := getPrefixGlyphs(testString, i) |
||||
compare(out, testString[:i], "ascii prefix "+iter, t) |
||||
out = getPrefixGlyphs(accent(testString), i) |
||||
compare(out, accent(testString[:i]), "accent prefix "+iter, t) |
||||
} |
||||
out := getPrefixGlyphs(testString, 999) |
||||
compare(out, testString, "ascii prefix overflow", t) |
||||
out = getPrefixGlyphs(accent(testString), 999) |
||||
compare(out, accent(testString), "accent prefix overflow", t) |
||||
|
||||
out = getPrefixGlyphs(testString, -3) |
||||
if len(out) != 0 { |
||||
t.Error("ascii prefix negative") |
||||
} |
||||
out = getPrefixGlyphs(accent(testString), -3) |
||||
if len(out) != 0 { |
||||
t.Error("accent prefix negative") |
||||
} |
||||
} |
||||
|
||||
func TestSuffixGlyphs(t *testing.T) { |
||||
for i := 0; i <= len(testString); i++ { |
||||
iter := strconv.Itoa(i) |
||||
out := getSuffixGlyphs(testString, i) |
||||
compare(out, testString[len(testString)-i:], "ascii suffix "+iter, t) |
||||
out = getSuffixGlyphs(accent(testString), i) |
||||
compare(out, accent(testString[len(testString)-i:]), "accent suffix "+iter, t) |
||||
} |
||||
out := getSuffixGlyphs(testString, 999) |
||||
compare(out, testString, "ascii suffix overflow", t) |
||||
out = getSuffixGlyphs(accent(testString), 999) |
||||
compare(out, accent(testString), "accent suffix overflow", t) |
||||
|
||||
out = getSuffixGlyphs(testString, -3) |
||||
if len(out) != 0 { |
||||
t.Error("ascii suffix negative") |
||||
} |
||||
out = getSuffixGlyphs(accent(testString), -3) |
||||
if len(out) != 0 { |
||||
t.Error("accent suffix negative") |
||||
} |
||||
} |
@ -1,112 +0,0 @@ |
||||
// Copyright 2009 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package websocket |
||||
|
||||
import ( |
||||
"bufio" |
||||
"crypto/tls" |
||||
"io" |
||||
"net" |
||||
"net/http" |
||||
"net/url" |
||||
) |
||||
|
||||
// DialError is an error that occurs while dialling a websocket server.
|
||||
type DialError struct { |
||||
*Config |
||||
Err error |
||||
} |
||||
|
||||
func (e *DialError) Error() string { |
||||
return "websocket.Dial " + e.Config.Location.String() + ": " + e.Err.Error() |
||||
} |
||||
|
||||
// NewConfig creates a new WebSocket config for client connection.
|
||||
func NewConfig(server, origin string) (config *Config, err error) { |
||||
config = new(Config) |
||||
config.Version = ProtocolVersionHybi13 |
||||
config.Location, err = url.ParseRequestURI(server) |
||||
if err != nil { |
||||
return |
||||
} |
||||
config.Origin, err = url.ParseRequestURI(origin) |
||||
if err != nil { |
||||
return |
||||
} |
||||
config.Header = http.Header(make(map[string][]string)) |
||||
return |
||||
} |
||||
|
||||
// NewClient creates a new WebSocket client connection over rwc.
|
||||
func NewClient(config *Config, rwc io.ReadWriteCloser) (ws *Conn, err error) { |
||||
br := bufio.NewReader(rwc) |
||||
bw := bufio.NewWriter(rwc) |
||||
err = hybiClientHandshake(config, br, bw) |
||||
if err != nil { |
||||
return |
||||
} |
||||
buf := bufio.NewReadWriter(br, bw) |
||||
ws = newHybiClientConn(config, buf, rwc) |
||||
return |
||||
} |
||||
|
||||
// Dial opens a new client connection to a WebSocket.
|
||||
func Dial(url_, protocol, origin string) (ws *Conn, err error) { |
||||
config, err := NewConfig(url_, origin) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
if protocol != "" { |
||||
config.Protocol = []string{protocol} |
||||
} |
||||
return DialConfig(config) |
||||
} |
||||
|
||||
var portMap = map[string]string{ |
||||
"ws": "80", |
||||
"wss": "443", |
||||
} |
||||
|
||||
func parseAuthority(location *url.URL) string { |
||||
if _, ok := portMap[location.Scheme]; ok { |
||||
if _, _, err := net.SplitHostPort(location.Host); err != nil { |
||||
return net.JoinHostPort(location.Host, portMap[location.Scheme]) |
||||
} |
||||
} |
||||
return location.Host |
||||
} |
||||
|
||||
// DialConfig opens a new client connection to a WebSocket with a config.
|
||||
func DialConfig(config *Config) (ws *Conn, err error) { |
||||
var client net.Conn |
||||
if config.Location == nil { |
||||
return nil, &DialError{config, ErrBadWebSocketLocation} |
||||
} |
||||
if config.Origin == nil { |
||||
return nil, &DialError{config, ErrBadWebSocketOrigin} |
||||
} |
||||
switch config.Location.Scheme { |
||||
case "ws": |
||||
client, err = net.Dial("tcp", parseAuthority(config.Location)) |
||||
|
||||
case "wss": |
||||
client, err = tls.Dial("tcp", parseAuthority(config.Location), config.TlsConfig) |
||||
|
||||
default: |
||||
err = ErrBadScheme |
||||
} |
||||
if err != nil { |
||||
goto Error |
||||
} |
||||
|
||||
ws, err = NewClient(config, client) |
||||
if err != nil { |
||||
goto Error |
||||
} |
||||
return |
||||
|
||||
Error: |
||||
return nil, &DialError{config, err} |
||||
} |
@ -1,31 +0,0 @@ |
||||
// Copyright 2012 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package websocket_test |
||||
|
||||
import ( |
||||
"fmt" |
||||
"log" |
||||
|
||||
"golang.org/x/net/websocket" |
||||
) |
||||
|
||||
// This example demonstrates a trivial client.
|
||||
func ExampleDial() { |
||||
origin := "http://localhost/" |
||||
url := "ws://localhost:12345/ws" |
||||
ws, err := websocket.Dial(url, "", origin) |
||||
if err != nil { |
||||
log.Fatal(err) |
||||
} |
||||
if _, err := ws.Write([]byte("hello, world!\n")); err != nil { |
||||
log.Fatal(err) |
||||
} |
||||
var msg = make([]byte, 512) |
||||
var n int |
||||
if n, err = ws.Read(msg); err != nil { |
||||
log.Fatal(err) |
||||
} |
||||
fmt.Printf("Received: %s.\n", msg[:n]) |
||||
} |
@ -1,26 +0,0 @@ |
||||
// Copyright 2012 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package websocket_test |
||||
|
||||
import ( |
||||
"io" |
||||
"net/http" |
||||
|
||||
"golang.org/x/net/websocket" |
||||
) |
||||
|
||||
// Echo the data received on the WebSocket.
|
||||
func EchoServer(ws *websocket.Conn) { |
||||
io.Copy(ws, ws) |
||||
} |
||||
|
||||
// This example demonstrates a trivial echo server.
|
||||
func ExampleHandler() { |
||||
http.Handle("/echo", websocket.Handler(EchoServer)) |
||||
err := http.ListenAndServe(":12345", nil) |
||||
if err != nil { |
||||
panic("ListenAndServe: " + err.Error()) |
||||
} |
||||
} |
@ -1,564 +0,0 @@ |
||||
// Copyright 2011 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package websocket |
||||
|
||||
// This file implements a protocol of hybi draft.
|
||||
// http://tools.ietf.org/html/draft-ietf-hybi-thewebsocketprotocol-17
|
||||
|
||||
import ( |
||||
"bufio" |
||||
"bytes" |
||||
"crypto/rand" |
||||
"crypto/sha1" |
||||
"encoding/base64" |
||||
"encoding/binary" |
||||
"fmt" |
||||
"io" |
||||
"io/ioutil" |
||||
"net/http" |
||||
"net/url" |
||||
"strings" |
||||
) |
||||
|
||||
const ( |
||||
websocketGUID = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11" |
||||
|
||||
closeStatusNormal = 1000 |
||||
closeStatusGoingAway = 1001 |
||||
closeStatusProtocolError = 1002 |
||||
closeStatusUnsupportedData = 1003 |
||||
closeStatusFrameTooLarge = 1004 |
||||
closeStatusNoStatusRcvd = 1005 |
||||
closeStatusAbnormalClosure = 1006 |
||||
closeStatusBadMessageData = 1007 |
||||
closeStatusPolicyViolation = 1008 |
||||
closeStatusTooBigData = 1009 |
||||
closeStatusExtensionMismatch = 1010 |
||||
|
||||
maxControlFramePayloadLength = 125 |
||||
) |
||||
|
||||
var ( |
||||
ErrBadMaskingKey = &ProtocolError{"bad masking key"} |
||||
ErrBadPongMessage = &ProtocolError{"bad pong message"} |
||||
ErrBadClosingStatus = &ProtocolError{"bad closing status"} |
||||
ErrUnsupportedExtensions = &ProtocolError{"unsupported extensions"} |
||||
ErrNotImplemented = &ProtocolError{"not implemented"} |
||||
|
||||
handshakeHeader = map[string]bool{ |
||||
"Host": true, |
||||
"Upgrade": true, |
||||
"Connection": true, |
||||
"Sec-Websocket-Key": true, |
||||
"Sec-Websocket-Origin": true, |
||||
"Sec-Websocket-Version": true, |
||||
"Sec-Websocket-Protocol": true, |
||||
"Sec-Websocket-Accept": true, |
||||
} |
||||
) |
||||
|
||||
// A hybiFrameHeader is a frame header as defined in hybi draft.
|
||||
type hybiFrameHeader struct { |
||||
Fin bool |
||||
Rsv [3]bool |
||||
OpCode byte |
||||
Length int64 |
||||
MaskingKey []byte |
||||
|
||||
data *bytes.Buffer |
||||
} |
||||
|
||||
// A hybiFrameReader is a reader for hybi frame.
|
||||
type hybiFrameReader struct { |
||||
reader io.Reader |
||||
|
||||
header hybiFrameHeader |
||||
pos int64 |
||||
length int |
||||
} |
||||
|
||||
func (frame *hybiFrameReader) Read(msg []byte) (n int, err error) { |
||||
n, err = frame.reader.Read(msg) |
||||
if err != nil { |
||||
return 0, err |
||||
} |
||||
if frame.header.MaskingKey != nil { |
||||
for i := 0; i < n; i++ { |
||||
msg[i] = msg[i] ^ frame.header.MaskingKey[frame.pos%4] |
||||
frame.pos++ |
||||
} |
||||
} |
||||
return n, err |
||||
} |
||||
|
||||
func (frame *hybiFrameReader) PayloadType() byte { return frame.header.OpCode } |
||||
|
||||
func (frame *hybiFrameReader) HeaderReader() io.Reader { |
||||
if frame.header.data == nil { |
||||
return nil |
||||
} |
||||
if frame.header.data.Len() == 0 { |
||||
return nil |
||||
} |
||||
return frame.header.data |
||||
} |
||||
|
||||
func (frame *hybiFrameReader) TrailerReader() io.Reader { return nil } |
||||
|
||||
func (frame *hybiFrameReader) Len() (n int) { return frame.length } |
||||
|
||||
// A hybiFrameReaderFactory creates new frame reader based on its frame type.
|
||||
type hybiFrameReaderFactory struct { |
||||
*bufio.Reader |
||||
} |
||||
|
||||
// NewFrameReader reads a frame header from the connection, and creates new reader for the frame.
|
||||
// See Section 5.2 Base Framing protocol for detail.
|
||||
// http://tools.ietf.org/html/draft-ietf-hybi-thewebsocketprotocol-17#section-5.2
|
||||
func (buf hybiFrameReaderFactory) NewFrameReader() (frame frameReader, err error) { |
||||
hybiFrame := new(hybiFrameReader) |
||||
frame = hybiFrame |
||||
var header []byte |
||||
var b byte |
||||
// First byte. FIN/RSV1/RSV2/RSV3/OpCode(4bits)
|
||||
b, err = buf.ReadByte() |
||||
if err != nil { |
||||
return |
||||
} |
||||
header = append(header, b) |
||||
hybiFrame.header.Fin = ((header[0] >> 7) & 1) != 0 |
||||
for i := 0; i < 3; i++ { |
||||
j := uint(6 - i) |
||||
hybiFrame.header.Rsv[i] = ((header[0] >> j) & 1) != 0 |
||||
} |
||||
hybiFrame.header.OpCode = header[0] & 0x0f |
||||
|
||||
// Second byte. Mask/Payload len(7bits)
|
||||
b, err = buf.ReadByte() |
||||
if err != nil { |
||||
return |
||||
} |
||||
header = append(header, b) |
||||
mask := (b & 0x80) != 0 |
||||
b &= 0x7f |
||||
lengthFields := 0 |
||||
switch { |
||||
case b <= 125: // Payload length 7bits.
|
||||
hybiFrame.header.Length = int64(b) |
||||
case b == 126: // Payload length 7+16bits
|
||||
lengthFields = 2 |
||||
case b == 127: // Payload length 7+64bits
|
||||
lengthFields = 8 |
||||
} |
||||
for i := 0; i < lengthFields; i++ { |
||||
b, err = buf.ReadByte() |
||||
if err != nil { |
||||
return |
||||
} |
||||
header = append(header, b) |
||||
hybiFrame.header.Length = hybiFrame.header.Length*256 + int64(b) |
||||
} |
||||
if mask { |
||||
// Masking key. 4 bytes.
|
||||
for i := 0; i < 4; i++ { |
||||
b, err = buf.ReadByte() |
||||
if err != nil { |
||||
return |
||||
} |
||||
header = append(header, b) |
||||
hybiFrame.header.MaskingKey = append(hybiFrame.header.MaskingKey, b) |
||||
} |
||||
} |
||||
hybiFrame.reader = io.LimitReader(buf.Reader, hybiFrame.header.Length) |
||||
hybiFrame.header.data = bytes.NewBuffer(header) |
||||
hybiFrame.length = len(header) + int(hybiFrame.header.Length) |
||||
return |
||||
} |
||||
|
||||
// A HybiFrameWriter is a writer for hybi frame.
|
||||
type hybiFrameWriter struct { |
||||
writer *bufio.Writer |
||||
|
||||
header *hybiFrameHeader |
||||
} |
||||
|
||||
func (frame *hybiFrameWriter) Write(msg []byte) (n int, err error) { |
||||
var header []byte |
||||
var b byte |
||||
if frame.header.Fin { |
||||
b |= 0x80 |
||||
} |
||||
for i := 0; i < 3; i++ { |
||||
if frame.header.Rsv[i] { |
||||
j := uint(6 - i) |
||||
b |= 1 << j |
||||
} |
||||
} |
||||
b |= frame.header.OpCode |
||||
header = append(header, b) |
||||
if frame.header.MaskingKey != nil { |
||||
b = 0x80 |
||||
} else { |
||||
b = 0 |
||||
} |
||||
lengthFields := 0 |
||||
length := len(msg) |
||||
switch { |
||||
case length <= 125: |
||||
b |= byte(length) |
||||
case length < 65536: |
||||
b |= 126 |
||||
lengthFields = 2 |
||||
default: |
||||
b |= 127 |
||||
lengthFields = 8 |
||||
} |
||||
header = append(header, b) |
||||
for i := 0; i < lengthFields; i++ { |
||||
j := uint((lengthFields - i - 1) * 8) |
||||
b = byte((length >> j) & 0xff) |
||||
header = append(header, b) |
||||
} |
||||
if frame.header.MaskingKey != nil { |
||||
if len(frame.header.MaskingKey) != 4 { |
||||
return 0, ErrBadMaskingKey |
||||
} |
||||
header = append(header, frame.header.MaskingKey...) |
||||
frame.writer.Write(header) |
||||
data := make([]byte, length) |
||||
for i := range data { |
||||
data[i] = msg[i] ^ frame.header.MaskingKey[i%4] |
||||
} |
||||
frame.writer.Write(data) |
||||
err = frame.writer.Flush() |
||||
return length, err |
||||
} |
||||
frame.writer.Write(header) |
||||
frame.writer.Write(msg) |
||||
err = frame.writer.Flush() |
||||
return length, err |
||||
} |
||||
|
||||
func (frame *hybiFrameWriter) Close() error { return nil } |
||||
|
||||
type hybiFrameWriterFactory struct { |
||||
*bufio.Writer |
||||
needMaskingKey bool |
||||
} |
||||
|
||||
func (buf hybiFrameWriterFactory) NewFrameWriter(payloadType byte) (frame frameWriter, err error) { |
||||
frameHeader := &hybiFrameHeader{Fin: true, OpCode: payloadType} |
||||
if buf.needMaskingKey { |
||||
frameHeader.MaskingKey, err = generateMaskingKey() |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
} |
||||
return &hybiFrameWriter{writer: buf.Writer, header: frameHeader}, nil |
||||
} |
||||
|
||||
type hybiFrameHandler struct { |
||||
conn *Conn |
||||
payloadType byte |
||||
} |
||||
|
||||
func (handler *hybiFrameHandler) HandleFrame(frame frameReader) (r frameReader, err error) { |
||||
if handler.conn.IsServerConn() { |
||||
// The client MUST mask all frames sent to the server.
|
||||
if frame.(*hybiFrameReader).header.MaskingKey == nil { |
||||
handler.WriteClose(closeStatusProtocolError) |
||||
return nil, io.EOF |
||||
} |
||||
} else { |
||||
// The server MUST NOT mask all frames.
|
||||
if frame.(*hybiFrameReader).header.MaskingKey != nil { |
||||
handler.WriteClose(closeStatusProtocolError) |
||||
return nil, io.EOF |
||||
} |
||||
} |
||||
if header := frame.HeaderReader(); header != nil { |
||||
io.Copy(ioutil.Discard, header) |
||||
} |
||||
switch frame.PayloadType() { |
||||
case ContinuationFrame: |
||||
frame.(*hybiFrameReader).header.OpCode = handler.payloadType |
||||
case TextFrame, BinaryFrame: |
||||
handler.payloadType = frame.PayloadType() |
||||
case CloseFrame: |
||||
return nil, io.EOF |
||||
case PingFrame: |
||||
pingMsg := make([]byte, maxControlFramePayloadLength) |
||||
n, err := io.ReadFull(frame, pingMsg) |
||||
if err != nil && err != io.ErrUnexpectedEOF { |
||||
return nil, err |
||||
} |
||||
io.Copy(ioutil.Discard, frame) |
||||
n, err = handler.WritePong(pingMsg[:n]) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
return nil, nil |
||||
case PongFrame: |
||||
return nil, ErrNotImplemented |
||||
} |
||||
return frame, nil |
||||
} |
||||
|
||||
func (handler *hybiFrameHandler) WriteClose(status int) (err error) { |
||||
handler.conn.wio.Lock() |
||||
defer handler.conn.wio.Unlock() |
||||
w, err := handler.conn.frameWriterFactory.NewFrameWriter(CloseFrame) |
||||
if err != nil { |
||||
return err |
||||
} |
||||
msg := make([]byte, 2) |
||||
binary.BigEndian.PutUint16(msg, uint16(status)) |
||||
_, err = w.Write(msg) |
||||
w.Close() |
||||
return err |
||||
} |
||||
|
||||
func (handler *hybiFrameHandler) WritePong(msg []byte) (n int, err error) { |
||||
handler.conn.wio.Lock() |
||||
defer handler.conn.wio.Unlock() |
||||
w, err := handler.conn.frameWriterFactory.NewFrameWriter(PongFrame) |
||||
if err != nil { |
||||
return 0, err |
||||
} |
||||
n, err = w.Write(msg) |
||||
w.Close() |
||||
return n, err |
||||
} |
||||
|
||||
// newHybiConn creates a new WebSocket connection speaking hybi draft protocol.
|
||||
func newHybiConn(config *Config, buf *bufio.ReadWriter, rwc io.ReadWriteCloser, request *http.Request) *Conn { |
||||
if buf == nil { |
||||
br := bufio.NewReader(rwc) |
||||
bw := bufio.NewWriter(rwc) |
||||
buf = bufio.NewReadWriter(br, bw) |
||||
} |
||||
ws := &Conn{config: config, request: request, buf: buf, rwc: rwc, |
||||
frameReaderFactory: hybiFrameReaderFactory{buf.Reader}, |
||||
frameWriterFactory: hybiFrameWriterFactory{ |
||||
buf.Writer, request == nil}, |
||||
PayloadType: TextFrame, |
||||
defaultCloseStatus: closeStatusNormal} |
||||
ws.frameHandler = &hybiFrameHandler{conn: ws} |
||||
return ws |
||||
} |
||||
|
||||
// generateMaskingKey generates a masking key for a frame.
|
||||
func generateMaskingKey() (maskingKey []byte, err error) { |
||||
maskingKey = make([]byte, 4) |
||||
if _, err = io.ReadFull(rand.Reader, maskingKey); err != nil { |
||||
return |
||||
} |
||||
return |
||||
} |
||||
|
||||
// generateNonce generates a nonce consisting of a randomly selected 16-byte
|
||||
// value that has been base64-encoded.
|
||||
func generateNonce() (nonce []byte) { |
||||
key := make([]byte, 16) |
||||
if _, err := io.ReadFull(rand.Reader, key); err != nil { |
||||
panic(err) |
||||
} |
||||
nonce = make([]byte, 24) |
||||
base64.StdEncoding.Encode(nonce, key) |
||||
return |
||||
} |
||||
|
||||
// getNonceAccept computes the base64-encoded SHA-1 of the concatenation of
|
||||
// the nonce ("Sec-WebSocket-Key" value) with the websocket GUID string.
|
||||
func getNonceAccept(nonce []byte) (expected []byte, err error) { |
||||
h := sha1.New() |
||||
if _, err = h.Write(nonce); err != nil { |
||||
return |
||||
} |
||||
if _, err = h.Write([]byte(websocketGUID)); err != nil { |
||||
return |
||||
} |
||||
expected = make([]byte, 28) |
||||
base64.StdEncoding.Encode(expected, h.Sum(nil)) |
||||
return |
||||
} |
||||
|
||||
// Client handshake described in draft-ietf-hybi-thewebsocket-protocol-17
|
||||
func hybiClientHandshake(config *Config, br *bufio.Reader, bw *bufio.Writer) (err error) { |
||||
bw.WriteString("GET " + config.Location.RequestURI() + " HTTP/1.1\r\n") |
||||
|
||||
bw.WriteString("Host: " + config.Location.Host + "\r\n") |
||||
bw.WriteString("Upgrade: websocket\r\n") |
||||
bw.WriteString("Connection: Upgrade\r\n") |
||||
nonce := generateNonce() |
||||
if config.handshakeData != nil { |
||||
nonce = []byte(config.handshakeData["key"]) |
||||
} |
||||
bw.WriteString("Sec-WebSocket-Key: " + string(nonce) + "\r\n") |
||||
bw.WriteString("Origin: " + strings.ToLower(config.Origin.String()) + "\r\n") |
||||
|
||||
if config.Version != ProtocolVersionHybi13 { |
||||
return ErrBadProtocolVersion |
||||
} |
||||
|
||||
bw.WriteString("Sec-WebSocket-Version: " + fmt.Sprintf("%d", config.Version) + "\r\n") |
||||
if len(config.Protocol) > 0 { |
||||
bw.WriteString("Sec-WebSocket-Protocol: " + strings.Join(config.Protocol, ", ") + "\r\n") |
||||
} |
||||
// TODO(ukai): send Sec-WebSocket-Extensions.
|
||||
err = config.Header.WriteSubset(bw, handshakeHeader) |
||||
if err != nil { |
||||
return err |
||||
} |
||||
|
||||
bw.WriteString("\r\n") |
||||
if err = bw.Flush(); err != nil { |
||||
return err |
||||
} |
||||
|
||||
resp, err := http.ReadResponse(br, &http.Request{Method: "GET"}) |
||||
if err != nil { |
||||
return err |
||||
} |
||||
if resp.StatusCode != 101 { |
||||
return ErrBadStatus |
||||
} |
||||
if strings.ToLower(resp.Header.Get("Upgrade")) != "websocket" || |
||||
strings.ToLower(resp.Header.Get("Connection")) != "upgrade" { |
||||
return ErrBadUpgrade |
||||
} |
||||
expectedAccept, err := getNonceAccept(nonce) |
||||
if err != nil { |
||||
return err |
||||
} |
||||
if resp.Header.Get("Sec-WebSocket-Accept") != string(expectedAccept) { |
||||
return ErrChallengeResponse |
||||
} |
||||
if resp.Header.Get("Sec-WebSocket-Extensions") != "" { |
||||
return ErrUnsupportedExtensions |
||||
} |
||||
offeredProtocol := resp.Header.Get("Sec-WebSocket-Protocol") |
||||
if offeredProtocol != "" { |
||||
protocolMatched := false |
||||
for i := 0; i < len(config.Protocol); i++ { |
||||
if config.Protocol[i] == offeredProtocol { |
||||
protocolMatched = true |
||||
break |
||||
} |
||||
} |
||||
if !protocolMatched { |
||||
return ErrBadWebSocketProtocol |
||||
} |
||||
config.Protocol = []string{offeredProtocol} |
||||
} |
||||
|
||||
return nil |
||||
} |
||||
|
||||
// newHybiClientConn creates a client WebSocket connection after handshake.
|
||||
func newHybiClientConn(config *Config, buf *bufio.ReadWriter, rwc io.ReadWriteCloser) *Conn { |
||||
return newHybiConn(config, buf, rwc, nil) |
||||
} |
||||
|
||||
// A HybiServerHandshaker performs a server handshake using hybi draft protocol.
|
||||
type hybiServerHandshaker struct { |
||||
*Config |
||||
accept []byte |
||||
} |
||||
|
||||
func (c *hybiServerHandshaker) ReadHandshake(buf *bufio.Reader, req *http.Request) (code int, err error) { |
||||
c.Version = ProtocolVersionHybi13 |
||||
if req.Method != "GET" { |
||||
return http.StatusMethodNotAllowed, ErrBadRequestMethod |
||||
} |
||||
// HTTP version can be safely ignored.
|
||||
|
||||
if strings.ToLower(req.Header.Get("Upgrade")) != "websocket" || |
||||
!strings.Contains(strings.ToLower(req.Header.Get("Connection")), "upgrade") { |
||||
return http.StatusBadRequest, ErrNotWebSocket |
||||
} |
||||
|
||||
key := req.Header.Get("Sec-Websocket-Key") |
||||
if key == "" { |
||||
return http.StatusBadRequest, ErrChallengeResponse |
||||
} |
||||
version := req.Header.Get("Sec-Websocket-Version") |
||||
switch version { |
||||
case "13": |
||||
c.Version = ProtocolVersionHybi13 |
||||
default: |
||||
return http.StatusBadRequest, ErrBadWebSocketVersion |
||||
} |
||||
var scheme string |
||||
if req.TLS != nil { |
||||
scheme = "wss" |
||||
} else { |
||||
scheme = "ws" |
||||
} |
||||
c.Location, err = url.ParseRequestURI(scheme + "://" + req.Host + req.URL.RequestURI()) |
||||
if err != nil { |
||||
return http.StatusBadRequest, err |
||||
} |
||||
protocol := strings.TrimSpace(req.Header.Get("Sec-Websocket-Protocol")) |
||||
if protocol != "" { |
||||
protocols := strings.Split(protocol, ",") |
||||
for i := 0; i < len(protocols); i++ { |
||||
c.Protocol = append(c.Protocol, strings.TrimSpace(protocols[i])) |
||||
} |
||||
} |
||||
c.accept, err = getNonceAccept([]byte(key)) |
||||
if err != nil { |
||||
return http.StatusInternalServerError, err |
||||
} |
||||
return http.StatusSwitchingProtocols, nil |
||||
} |
||||
|
||||
// Origin parses Origin header in "req".
|
||||
// If origin is "null", returns (nil, nil).
|
||||
func Origin(config *Config, req *http.Request) (*url.URL, error) { |
||||
var origin string |
||||
switch config.Version { |
||||
case ProtocolVersionHybi13: |
||||
origin = req.Header.Get("Origin") |
||||
} |
||||
if origin == "null" { |
||||
return nil, nil |
||||
} |
||||
return url.ParseRequestURI(origin) |
||||
} |
||||
|
||||
func (c *hybiServerHandshaker) AcceptHandshake(buf *bufio.Writer) (err error) { |
||||
if len(c.Protocol) > 0 { |
||||
if len(c.Protocol) != 1 { |
||||
// You need choose a Protocol in Handshake func in Server.
|
||||
return ErrBadWebSocketProtocol |
||||
} |
||||
} |
||||
buf.WriteString("HTTP/1.1 101 Switching Protocols\r\n") |
||||
buf.WriteString("Upgrade: websocket\r\n") |
||||
buf.WriteString("Connection: Upgrade\r\n") |
||||
buf.WriteString("Sec-WebSocket-Accept: " + string(c.accept) + "\r\n") |
||||
if len(c.Protocol) > 0 { |
||||
buf.WriteString("Sec-WebSocket-Protocol: " + c.Protocol[0] + "\r\n") |
||||
} |
||||
// TODO(ukai): send Sec-WebSocket-Extensions.
|
||||
if c.Header != nil { |
||||
err := c.Header.WriteSubset(buf, handshakeHeader) |
||||
if err != nil { |
||||
return err |
||||
} |
||||
} |
||||
buf.WriteString("\r\n") |
||||
return buf.Flush() |
||||
} |
||||
|
||||
func (c *hybiServerHandshaker) NewServerConn(buf *bufio.ReadWriter, rwc io.ReadWriteCloser, request *http.Request) *Conn { |
||||
return newHybiServerConn(c.Config, buf, rwc, request) |
||||
} |
||||
|
||||
// newHybiServerConn returns a new WebSocket connection speaking hybi draft protocol.
|
||||
func newHybiServerConn(config *Config, buf *bufio.ReadWriter, rwc io.ReadWriteCloser, request *http.Request) *Conn { |
||||
return newHybiConn(config, buf, rwc, request) |
||||
} |
@ -1,590 +0,0 @@ |
||||
// Copyright 2011 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package websocket |
||||
|
||||
import ( |
||||
"bufio" |
||||
"bytes" |
||||
"fmt" |
||||
"io" |
||||
"net/http" |
||||
"net/url" |
||||
"strings" |
||||
"testing" |
||||
) |
||||
|
||||
// Test the getNonceAccept function with values in
|
||||
// http://tools.ietf.org/html/draft-ietf-hybi-thewebsocketprotocol-17
|
||||
func TestSecWebSocketAccept(t *testing.T) { |
||||
nonce := []byte("dGhlIHNhbXBsZSBub25jZQ==") |
||||
expected := []byte("s3pPLMBiTxaQ9kYGzzhZRbK+xOo=") |
||||
accept, err := getNonceAccept(nonce) |
||||
if err != nil { |
||||
t.Errorf("getNonceAccept: returned error %v", err) |
||||
return |
||||
} |
||||
if !bytes.Equal(expected, accept) { |
||||
t.Errorf("getNonceAccept: expected %q got %q", expected, accept) |
||||
} |
||||
} |
||||
|
||||
func TestHybiClientHandshake(t *testing.T) { |
||||
b := bytes.NewBuffer([]byte{}) |
||||
bw := bufio.NewWriter(b) |
||||
br := bufio.NewReader(strings.NewReader(`HTTP/1.1 101 Switching Protocols |
||||
Upgrade: websocket |
||||
Connection: Upgrade |
||||
Sec-WebSocket-Accept: s3pPLMBiTxaQ9kYGzzhZRbK+xOo= |
||||
Sec-WebSocket-Protocol: chat |
||||
|
||||
`)) |
||||
var err error |
||||
config := new(Config) |
||||
config.Location, err = url.ParseRequestURI("ws://server.example.com/chat") |
||||
if err != nil { |
||||
t.Fatal("location url", err) |
||||
} |
||||
config.Origin, err = url.ParseRequestURI("http://example.com") |
||||
if err != nil { |
||||
t.Fatal("origin url", err) |
||||
} |
||||
config.Protocol = append(config.Protocol, "chat") |
||||
config.Protocol = append(config.Protocol, "superchat") |
||||
config.Version = ProtocolVersionHybi13 |
||||
|
||||
config.handshakeData = map[string]string{ |
||||
"key": "dGhlIHNhbXBsZSBub25jZQ==", |
||||
} |
||||
err = hybiClientHandshake(config, br, bw) |
||||
if err != nil { |
||||
t.Errorf("handshake failed: %v", err) |
||||
} |
||||
req, err := http.ReadRequest(bufio.NewReader(b)) |
||||
if err != nil { |
||||
t.Fatalf("read request: %v", err) |
||||
} |
||||
if req.Method != "GET" { |
||||
t.Errorf("request method expected GET, but got %q", req.Method) |
||||
} |
||||
if req.URL.Path != "/chat" { |
||||
t.Errorf("request path expected /chat, but got %q", req.URL.Path) |
||||
} |
||||
if req.Proto != "HTTP/1.1" { |
||||
t.Errorf("request proto expected HTTP/1.1, but got %q", req.Proto) |
||||
} |
||||
if req.Host != "server.example.com" { |
||||
t.Errorf("request Host expected server.example.com, but got %v", req.Host) |
||||
} |
||||
var expectedHeader = map[string]string{ |
||||
"Connection": "Upgrade", |
||||
"Upgrade": "websocket", |
||||
"Sec-Websocket-Key": config.handshakeData["key"], |
||||
"Origin": config.Origin.String(), |
||||
"Sec-Websocket-Protocol": "chat, superchat", |
||||
"Sec-Websocket-Version": fmt.Sprintf("%d", ProtocolVersionHybi13), |
||||
} |
||||
for k, v := range expectedHeader { |
||||
if req.Header.Get(k) != v { |
||||
t.Errorf(fmt.Sprintf("%s expected %q but got %q", k, v, req.Header.Get(k))) |
||||
} |
||||
} |
||||
} |
||||
|
||||
func TestHybiClientHandshakeWithHeader(t *testing.T) { |
||||
b := bytes.NewBuffer([]byte{}) |
||||
bw := bufio.NewWriter(b) |
||||
br := bufio.NewReader(strings.NewReader(`HTTP/1.1 101 Switching Protocols |
||||
Upgrade: websocket |
||||
Connection: Upgrade |
||||
Sec-WebSocket-Accept: s3pPLMBiTxaQ9kYGzzhZRbK+xOo= |
||||
Sec-WebSocket-Protocol: chat |
||||
|
||||
`)) |
||||
var err error |
||||
config := new(Config) |
||||
config.Location, err = url.ParseRequestURI("ws://server.example.com/chat") |
||||
if err != nil { |
||||
t.Fatal("location url", err) |
||||
} |
||||
config.Origin, err = url.ParseRequestURI("http://example.com") |
||||
if err != nil { |
||||
t.Fatal("origin url", err) |
||||
} |
||||
config.Protocol = append(config.Protocol, "chat") |
||||
config.Protocol = append(config.Protocol, "superchat") |
||||
config.Version = ProtocolVersionHybi13 |
||||
config.Header = http.Header(make(map[string][]string)) |
||||
config.Header.Add("User-Agent", "test") |
||||
|
||||
config.handshakeData = map[string]string{ |
||||
"key": "dGhlIHNhbXBsZSBub25jZQ==", |
||||
} |
||||
err = hybiClientHandshake(config, br, bw) |
||||
if err != nil { |
||||
t.Errorf("handshake failed: %v", err) |
||||
} |
||||
req, err := http.ReadRequest(bufio.NewReader(b)) |
||||
if err != nil { |
||||
t.Fatalf("read request: %v", err) |
||||
} |
||||
if req.Method != "GET" { |
||||
t.Errorf("request method expected GET, but got %q", req.Method) |
||||
} |
||||
if req.URL.Path != "/chat" { |
||||
t.Errorf("request path expected /chat, but got %q", req.URL.Path) |
||||
} |
||||
if req.Proto != "HTTP/1.1" { |
||||
t.Errorf("request proto expected HTTP/1.1, but got %q", req.Proto) |
||||
} |
||||
if req.Host != "server.example.com" { |
||||
t.Errorf("request Host expected server.example.com, but got %v", req.Host) |
||||
} |
||||
var expectedHeader = map[string]string{ |
||||
"Connection": "Upgrade", |
||||
"Upgrade": "websocket", |
||||
"Sec-Websocket-Key": config.handshakeData["key"], |
||||
"Origin": config.Origin.String(), |
||||
"Sec-Websocket-Protocol": "chat, superchat", |
||||
"Sec-Websocket-Version": fmt.Sprintf("%d", ProtocolVersionHybi13), |
||||
"User-Agent": "test", |
||||
} |
||||
for k, v := range expectedHeader { |
||||
if req.Header.Get(k) != v { |
||||
t.Errorf(fmt.Sprintf("%s expected %q but got %q", k, v, req.Header.Get(k))) |
||||
} |
||||
} |
||||
} |
||||
|
||||
func TestHybiServerHandshake(t *testing.T) { |
||||
config := new(Config) |
||||
handshaker := &hybiServerHandshaker{Config: config} |
||||
br := bufio.NewReader(strings.NewReader(`GET /chat HTTP/1.1 |
||||
Host: server.example.com |
||||
Upgrade: websocket |
||||
Connection: Upgrade |
||||
Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ== |
||||
Origin: http://example.com
|
||||
Sec-WebSocket-Protocol: chat, superchat |
||||
Sec-WebSocket-Version: 13 |
||||
|
||||
`)) |
||||
req, err := http.ReadRequest(br) |
||||
if err != nil { |
||||
t.Fatal("request", err) |
||||
} |
||||
code, err := handshaker.ReadHandshake(br, req) |
||||
if err != nil { |
||||
t.Errorf("handshake failed: %v", err) |
||||
} |
||||
if code != http.StatusSwitchingProtocols { |
||||
t.Errorf("status expected %q but got %q", http.StatusSwitchingProtocols, code) |
||||
} |
||||
expectedProtocols := []string{"chat", "superchat"} |
||||
if fmt.Sprintf("%v", config.Protocol) != fmt.Sprintf("%v", expectedProtocols) { |
||||
t.Errorf("protocol expected %q but got %q", expectedProtocols, config.Protocol) |
||||
} |
||||
b := bytes.NewBuffer([]byte{}) |
||||
bw := bufio.NewWriter(b) |
||||
|
||||
config.Protocol = config.Protocol[:1] |
||||
|
||||
err = handshaker.AcceptHandshake(bw) |
||||
if err != nil { |
||||
t.Errorf("handshake response failed: %v", err) |
||||
} |
||||
expectedResponse := strings.Join([]string{ |
||||
"HTTP/1.1 101 Switching Protocols", |
||||
"Upgrade: websocket", |
||||
"Connection: Upgrade", |
||||
"Sec-WebSocket-Accept: s3pPLMBiTxaQ9kYGzzhZRbK+xOo=", |
||||
"Sec-WebSocket-Protocol: chat", |
||||
"", ""}, "\r\n") |
||||
|
||||
if b.String() != expectedResponse { |
||||
t.Errorf("handshake expected %q but got %q", expectedResponse, b.String()) |
||||
} |
||||
} |
||||
|
||||
func TestHybiServerHandshakeNoSubProtocol(t *testing.T) { |
||||
config := new(Config) |
||||
handshaker := &hybiServerHandshaker{Config: config} |
||||
br := bufio.NewReader(strings.NewReader(`GET /chat HTTP/1.1 |
||||
Host: server.example.com |
||||
Upgrade: websocket |
||||
Connection: Upgrade |
||||
Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ== |
||||
Origin: http://example.com
|
||||
Sec-WebSocket-Version: 13 |
||||
|
||||
`)) |
||||
req, err := http.ReadRequest(br) |
||||
if err != nil { |
||||
t.Fatal("request", err) |
||||
} |
||||
code, err := handshaker.ReadHandshake(br, req) |
||||
if err != nil { |
||||
t.Errorf("handshake failed: %v", err) |
||||
} |
||||
if code != http.StatusSwitchingProtocols { |
||||
t.Errorf("status expected %q but got %q", http.StatusSwitchingProtocols, code) |
||||
} |
||||
if len(config.Protocol) != 0 { |
||||
t.Errorf("len(config.Protocol) expected 0, but got %q", len(config.Protocol)) |
||||
} |
||||
b := bytes.NewBuffer([]byte{}) |
||||
bw := bufio.NewWriter(b) |
||||
|
||||
err = handshaker.AcceptHandshake(bw) |
||||
if err != nil { |
||||
t.Errorf("handshake response failed: %v", err) |
||||
} |
||||
expectedResponse := strings.Join([]string{ |
||||
"HTTP/1.1 101 Switching Protocols", |
||||
"Upgrade: websocket", |
||||
"Connection: Upgrade", |
||||
"Sec-WebSocket-Accept: s3pPLMBiTxaQ9kYGzzhZRbK+xOo=", |
||||
"", ""}, "\r\n") |
||||
|
||||
if b.String() != expectedResponse { |
||||
t.Errorf("handshake expected %q but got %q", expectedResponse, b.String()) |
||||
} |
||||
} |
||||
|
||||
func TestHybiServerHandshakeHybiBadVersion(t *testing.T) { |
||||
config := new(Config) |
||||
handshaker := &hybiServerHandshaker{Config: config} |
||||
br := bufio.NewReader(strings.NewReader(`GET /chat HTTP/1.1 |
||||
Host: server.example.com |
||||
Upgrade: websocket |
||||
Connection: Upgrade |
||||
Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ== |
||||
Sec-WebSocket-Origin: http://example.com
|
||||
Sec-WebSocket-Protocol: chat, superchat |
||||
Sec-WebSocket-Version: 9 |
||||
|
||||
`)) |
||||
req, err := http.ReadRequest(br) |
||||
if err != nil { |
||||
t.Fatal("request", err) |
||||
} |
||||
code, err := handshaker.ReadHandshake(br, req) |
||||
if err != ErrBadWebSocketVersion { |
||||
t.Errorf("handshake expected err %q but got %q", ErrBadWebSocketVersion, err) |
||||
} |
||||
if code != http.StatusBadRequest { |
||||
t.Errorf("status expected %q but got %q", http.StatusBadRequest, code) |
||||
} |
||||
} |
||||
|
||||
func testHybiFrame(t *testing.T, testHeader, testPayload, testMaskedPayload []byte, frameHeader *hybiFrameHeader) { |
||||
b := bytes.NewBuffer([]byte{}) |
||||
frameWriterFactory := &hybiFrameWriterFactory{bufio.NewWriter(b), false} |
||||
w, _ := frameWriterFactory.NewFrameWriter(TextFrame) |
||||
w.(*hybiFrameWriter).header = frameHeader |
||||
_, err := w.Write(testPayload) |
||||
w.Close() |
||||
if err != nil { |
||||
t.Errorf("Write error %q", err) |
||||
} |
||||
var expectedFrame []byte |
||||
expectedFrame = append(expectedFrame, testHeader...) |
||||
expectedFrame = append(expectedFrame, testMaskedPayload...) |
||||
if !bytes.Equal(expectedFrame, b.Bytes()) { |
||||
t.Errorf("frame expected %q got %q", expectedFrame, b.Bytes()) |
||||
} |
||||
frameReaderFactory := &hybiFrameReaderFactory{bufio.NewReader(b)} |
||||
r, err := frameReaderFactory.NewFrameReader() |
||||
if err != nil { |
||||
t.Errorf("Read error %q", err) |
||||
} |
||||
if header := r.HeaderReader(); header == nil { |
||||
t.Errorf("no header") |
||||
} else { |
||||
actualHeader := make([]byte, r.Len()) |
||||
n, err := header.Read(actualHeader) |
||||
if err != nil { |
||||
t.Errorf("Read header error %q", err) |
||||
} else { |
||||
if n < len(testHeader) { |
||||
t.Errorf("header too short %q got %q", testHeader, actualHeader[:n]) |
||||
} |
||||
if !bytes.Equal(testHeader, actualHeader[:n]) { |
||||
t.Errorf("header expected %q got %q", testHeader, actualHeader[:n]) |
||||
} |
||||
} |
||||
} |
||||
if trailer := r.TrailerReader(); trailer != nil { |
||||
t.Errorf("unexpected trailer %q", trailer) |
||||
} |
||||
frame := r.(*hybiFrameReader) |
||||
if frameHeader.Fin != frame.header.Fin || |
||||
frameHeader.OpCode != frame.header.OpCode || |
||||
len(testPayload) != int(frame.header.Length) { |
||||
t.Errorf("mismatch %v (%d) vs %v", frameHeader, len(testPayload), frame) |
||||
} |
||||
payload := make([]byte, len(testPayload)) |
||||
_, err = r.Read(payload) |
||||
if err != nil { |
||||
t.Errorf("read %v", err) |
||||
} |
||||
if !bytes.Equal(testPayload, payload) { |
||||
t.Errorf("payload %q vs %q", testPayload, payload) |
||||
} |
||||
} |
||||
|
||||
func TestHybiShortTextFrame(t *testing.T) { |
||||
frameHeader := &hybiFrameHeader{Fin: true, OpCode: TextFrame} |
||||
payload := []byte("hello") |
||||
testHybiFrame(t, []byte{0x81, 0x05}, payload, payload, frameHeader) |
||||
|
||||
payload = make([]byte, 125) |
||||
testHybiFrame(t, []byte{0x81, 125}, payload, payload, frameHeader) |
||||
} |
||||
|
||||
func TestHybiShortMaskedTextFrame(t *testing.T) { |
||||
frameHeader := &hybiFrameHeader{Fin: true, OpCode: TextFrame, |
||||
MaskingKey: []byte{0xcc, 0x55, 0x80, 0x20}} |
||||
payload := []byte("hello") |
||||
maskedPayload := []byte{0xa4, 0x30, 0xec, 0x4c, 0xa3} |
||||
header := []byte{0x81, 0x85} |
||||
header = append(header, frameHeader.MaskingKey...) |
||||
testHybiFrame(t, header, payload, maskedPayload, frameHeader) |
||||
} |
||||
|
||||
func TestHybiShortBinaryFrame(t *testing.T) { |
||||
frameHeader := &hybiFrameHeader{Fin: true, OpCode: BinaryFrame} |
||||
payload := []byte("hello") |
||||
testHybiFrame(t, []byte{0x82, 0x05}, payload, payload, frameHeader) |
||||
|
||||
payload = make([]byte, 125) |
||||
testHybiFrame(t, []byte{0x82, 125}, payload, payload, frameHeader) |
||||
} |
||||
|
||||
func TestHybiControlFrame(t *testing.T) { |
||||
frameHeader := &hybiFrameHeader{Fin: true, OpCode: PingFrame} |
||||
payload := []byte("hello") |
||||
testHybiFrame(t, []byte{0x89, 0x05}, payload, payload, frameHeader) |
||||
|
||||
frameHeader = &hybiFrameHeader{Fin: true, OpCode: PongFrame} |
||||
testHybiFrame(t, []byte{0x8A, 0x05}, payload, payload, frameHeader) |
||||
|
||||
frameHeader = &hybiFrameHeader{Fin: true, OpCode: CloseFrame} |
||||
payload = []byte{0x03, 0xe8} // 1000
|
||||
testHybiFrame(t, []byte{0x88, 0x02}, payload, payload, frameHeader) |
||||
} |
||||
|
||||
func TestHybiLongFrame(t *testing.T) { |
||||
frameHeader := &hybiFrameHeader{Fin: true, OpCode: TextFrame} |
||||
payload := make([]byte, 126) |
||||
testHybiFrame(t, []byte{0x81, 126, 0x00, 126}, payload, payload, frameHeader) |
||||
|
||||
payload = make([]byte, 65535) |
||||
testHybiFrame(t, []byte{0x81, 126, 0xff, 0xff}, payload, payload, frameHeader) |
||||
|
||||
payload = make([]byte, 65536) |
||||
testHybiFrame(t, []byte{0x81, 127, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00}, payload, payload, frameHeader) |
||||
} |
||||
|
||||
func TestHybiClientRead(t *testing.T) { |
||||
wireData := []byte{0x81, 0x05, 'h', 'e', 'l', 'l', 'o', |
||||
0x89, 0x05, 'h', 'e', 'l', 'l', 'o', // ping
|
||||
0x81, 0x05, 'w', 'o', 'r', 'l', 'd'} |
||||
br := bufio.NewReader(bytes.NewBuffer(wireData)) |
||||
bw := bufio.NewWriter(bytes.NewBuffer([]byte{})) |
||||
conn := newHybiConn(newConfig(t, "/"), bufio.NewReadWriter(br, bw), nil, nil) |
||||
|
||||
msg := make([]byte, 512) |
||||
n, err := conn.Read(msg) |
||||
if err != nil { |
||||
t.Errorf("read 1st frame, error %q", err) |
||||
} |
||||
if n != 5 { |
||||
t.Errorf("read 1st frame, expect 5, got %d", n) |
||||
} |
||||
if !bytes.Equal(wireData[2:7], msg[:n]) { |
||||
t.Errorf("read 1st frame %v, got %v", wireData[2:7], msg[:n]) |
||||
} |
||||
n, err = conn.Read(msg) |
||||
if err != nil { |
||||
t.Errorf("read 2nd frame, error %q", err) |
||||
} |
||||
if n != 5 { |
||||
t.Errorf("read 2nd frame, expect 5, got %d", n) |
||||
} |
||||
if !bytes.Equal(wireData[16:21], msg[:n]) { |
||||
t.Errorf("read 2nd frame %v, got %v", wireData[16:21], msg[:n]) |
||||
} |
||||
n, err = conn.Read(msg) |
||||
if err == nil { |
||||
t.Errorf("read not EOF") |
||||
} |
||||
if n != 0 { |
||||
t.Errorf("expect read 0, got %d", n) |
||||
} |
||||
} |
||||
|
||||
func TestHybiShortRead(t *testing.T) { |
||||
wireData := []byte{0x81, 0x05, 'h', 'e', 'l', 'l', 'o', |
||||
0x89, 0x05, 'h', 'e', 'l', 'l', 'o', // ping
|
||||
0x81, 0x05, 'w', 'o', 'r', 'l', 'd'} |
||||
br := bufio.NewReader(bytes.NewBuffer(wireData)) |
||||
bw := bufio.NewWriter(bytes.NewBuffer([]byte{})) |
||||
conn := newHybiConn(newConfig(t, "/"), bufio.NewReadWriter(br, bw), nil, nil) |
||||
|
||||
step := 0 |
||||
pos := 0 |
||||
expectedPos := []int{2, 5, 16, 19} |
||||
expectedLen := []int{3, 2, 3, 2} |
||||
for { |
||||
msg := make([]byte, 3) |
||||
n, err := conn.Read(msg) |
||||
if step >= len(expectedPos) { |
||||
if err == nil { |
||||
t.Errorf("read not EOF") |
||||
} |
||||
if n != 0 { |
||||
t.Errorf("expect read 0, got %d", n) |
||||
} |
||||
return |
||||
} |
||||
pos = expectedPos[step] |
||||
endPos := pos + expectedLen[step] |
||||
if err != nil { |
||||
t.Errorf("read from %d, got error %q", pos, err) |
||||
return |
||||
} |
||||
if n != endPos-pos { |
||||
t.Errorf("read from %d, expect %d, got %d", pos, endPos-pos, n) |
||||
} |
||||
if !bytes.Equal(wireData[pos:endPos], msg[:n]) { |
||||
t.Errorf("read from %d, frame %v, got %v", pos, wireData[pos:endPos], msg[:n]) |
||||
} |
||||
step++ |
||||
} |
||||
} |
||||
|
||||
func TestHybiServerRead(t *testing.T) { |
||||
wireData := []byte{0x81, 0x85, 0xcc, 0x55, 0x80, 0x20, |
||||
0xa4, 0x30, 0xec, 0x4c, 0xa3, // hello
|
||||
0x89, 0x85, 0xcc, 0x55, 0x80, 0x20, |
||||
0xa4, 0x30, 0xec, 0x4c, 0xa3, // ping: hello
|
||||
0x81, 0x85, 0xed, 0x83, 0xb4, 0x24, |
||||
0x9a, 0xec, 0xc6, 0x48, 0x89, // world
|
||||
} |
||||
br := bufio.NewReader(bytes.NewBuffer(wireData)) |
||||
bw := bufio.NewWriter(bytes.NewBuffer([]byte{})) |
||||
conn := newHybiConn(newConfig(t, "/"), bufio.NewReadWriter(br, bw), nil, new(http.Request)) |
||||
|
||||
expected := [][]byte{[]byte("hello"), []byte("world")} |
||||
|
||||
msg := make([]byte, 512) |
||||
n, err := conn.Read(msg) |
||||
if err != nil { |
||||
t.Errorf("read 1st frame, error %q", err) |
||||
} |
||||
if n != 5 { |
||||
t.Errorf("read 1st frame, expect 5, got %d", n) |
||||
} |
||||
if !bytes.Equal(expected[0], msg[:n]) { |
||||
t.Errorf("read 1st frame %q, got %q", expected[0], msg[:n]) |
||||
} |
||||
|
||||
n, err = conn.Read(msg) |
||||
if err != nil { |
||||
t.Errorf("read 2nd frame, error %q", err) |
||||
} |
||||
if n != 5 { |
||||
t.Errorf("read 2nd frame, expect 5, got %d", n) |
||||
} |
||||
if !bytes.Equal(expected[1], msg[:n]) { |
||||
t.Errorf("read 2nd frame %q, got %q", expected[1], msg[:n]) |
||||
} |
||||
|
||||
n, err = conn.Read(msg) |
||||
if err == nil { |
||||
t.Errorf("read not EOF") |
||||
} |
||||
if n != 0 { |
||||
t.Errorf("expect read 0, got %d", n) |
||||
} |
||||
} |
||||
|
||||
func TestHybiServerReadWithoutMasking(t *testing.T) { |
||||
wireData := []byte{0x81, 0x05, 'h', 'e', 'l', 'l', 'o'} |
||||
br := bufio.NewReader(bytes.NewBuffer(wireData)) |
||||
bw := bufio.NewWriter(bytes.NewBuffer([]byte{})) |
||||
conn := newHybiConn(newConfig(t, "/"), bufio.NewReadWriter(br, bw), nil, new(http.Request)) |
||||
// server MUST close the connection upon receiving a non-masked frame.
|
||||
msg := make([]byte, 512) |
||||
_, err := conn.Read(msg) |
||||
if err != io.EOF { |
||||
t.Errorf("read 1st frame, expect %q, but got %q", io.EOF, err) |
||||
} |
||||
} |
||||
|
||||
func TestHybiClientReadWithMasking(t *testing.T) { |
||||
wireData := []byte{0x81, 0x85, 0xcc, 0x55, 0x80, 0x20, |
||||
0xa4, 0x30, 0xec, 0x4c, 0xa3, // hello
|
||||
} |
||||
br := bufio.NewReader(bytes.NewBuffer(wireData)) |
||||
bw := bufio.NewWriter(bytes.NewBuffer([]byte{})) |
||||
conn := newHybiConn(newConfig(t, "/"), bufio.NewReadWriter(br, bw), nil, nil) |
||||
|
||||
// client MUST close the connection upon receiving a masked frame.
|
||||
msg := make([]byte, 512) |
||||
_, err := conn.Read(msg) |
||||
if err != io.EOF { |
||||
t.Errorf("read 1st frame, expect %q, but got %q", io.EOF, err) |
||||
} |
||||
} |
||||
|
||||
// Test the hybiServerHandshaker supports firefox implementation and
|
||||
// checks Connection request header include (but it's not necessary
|
||||
// equal to) "upgrade"
|
||||
func TestHybiServerFirefoxHandshake(t *testing.T) { |
||||
config := new(Config) |
||||
handshaker := &hybiServerHandshaker{Config: config} |
||||
br := bufio.NewReader(strings.NewReader(`GET /chat HTTP/1.1 |
||||
Host: server.example.com |
||||
Upgrade: websocket |
||||
Connection: keep-alive, upgrade |
||||
Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ== |
||||
Origin: http://example.com
|
||||
Sec-WebSocket-Protocol: chat, superchat |
||||
Sec-WebSocket-Version: 13 |
||||
|
||||
`)) |
||||
req, err := http.ReadRequest(br) |
||||
if err != nil { |
||||
t.Fatal("request", err) |
||||
} |
||||
code, err := handshaker.ReadHandshake(br, req) |
||||
if err != nil { |
||||
t.Errorf("handshake failed: %v", err) |
||||
} |
||||
if code != http.StatusSwitchingProtocols { |
||||
t.Errorf("status expected %q but got %q", http.StatusSwitchingProtocols, code) |
||||
} |
||||
b := bytes.NewBuffer([]byte{}) |
||||
bw := bufio.NewWriter(b) |
||||
|
||||
config.Protocol = []string{"chat"} |
||||
|
||||
err = handshaker.AcceptHandshake(bw) |
||||
if err != nil { |
||||
t.Errorf("handshake response failed: %v", err) |
||||
} |
||||
expectedResponse := strings.Join([]string{ |
||||
"HTTP/1.1 101 Switching Protocols", |
||||
"Upgrade: websocket", |
||||
"Connection: Upgrade", |
||||
"Sec-WebSocket-Accept: s3pPLMBiTxaQ9kYGzzhZRbK+xOo=", |
||||
"Sec-WebSocket-Protocol: chat", |
||||
"", ""}, "\r\n") |
||||
|
||||
if b.String() != expectedResponse { |
||||
t.Errorf("handshake expected %q but got %q", expectedResponse, b.String()) |
||||
} |
||||
} |
@ -1,114 +0,0 @@ |
||||
// Copyright 2009 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package websocket |
||||
|
||||
import ( |
||||
"bufio" |
||||
"fmt" |
||||
"io" |
||||
"net/http" |
||||
) |
||||
|
||||
func newServerConn(rwc io.ReadWriteCloser, buf *bufio.ReadWriter, req *http.Request, config *Config, handshake func(*Config, *http.Request) error) (conn *Conn, err error) { |
||||
var hs serverHandshaker = &hybiServerHandshaker{Config: config} |
||||
code, err := hs.ReadHandshake(buf.Reader, req) |
||||
if err == ErrBadWebSocketVersion { |
||||
fmt.Fprintf(buf, "HTTP/1.1 %03d %s\r\n", code, http.StatusText(code)) |
||||
fmt.Fprintf(buf, "Sec-WebSocket-Version: %s\r\n", SupportedProtocolVersion) |
||||
buf.WriteString("\r\n") |
||||
buf.WriteString(err.Error()) |
||||
buf.Flush() |
||||
return |
||||
} |
||||
if err != nil { |
||||
fmt.Fprintf(buf, "HTTP/1.1 %03d %s\r\n", code, http.StatusText(code)) |
||||
buf.WriteString("\r\n") |
||||
buf.WriteString(err.Error()) |
||||
buf.Flush() |
||||
return |
||||
} |
||||
if handshake != nil { |
||||
err = handshake(config, req) |
||||
if err != nil { |
||||
code = http.StatusForbidden |
||||
fmt.Fprintf(buf, "HTTP/1.1 %03d %s\r\n", code, http.StatusText(code)) |
||||
buf.WriteString("\r\n") |
||||
buf.Flush() |
||||
return |
||||
} |
||||
} |
||||
err = hs.AcceptHandshake(buf.Writer) |
||||
if err != nil { |
||||
code = http.StatusBadRequest |
||||
fmt.Fprintf(buf, "HTTP/1.1 %03d %s\r\n", code, http.StatusText(code)) |
||||
buf.WriteString("\r\n") |
||||
buf.Flush() |
||||
return |
||||
} |
||||
conn = hs.NewServerConn(buf, rwc, req) |
||||
return |
||||
} |
||||
|
||||
// Server represents a server of a WebSocket.
|
||||
type Server struct { |
||||
// Config is a WebSocket configuration for new WebSocket connection.
|
||||
Config |
||||
|
||||
// Handshake is an optional function in WebSocket handshake.
|
||||
// For example, you can check, or don't check Origin header.
|
||||
// Another example, you can select config.Protocol.
|
||||
Handshake func(*Config, *http.Request) error |
||||
|
||||
// Handler handles a WebSocket connection.
|
||||
Handler |
||||
} |
||||
|
||||
// ServeHTTP implements the http.Handler interface for a WebSocket
|
||||
func (s Server) ServeHTTP(w http.ResponseWriter, req *http.Request) { |
||||
s.serveWebSocket(w, req) |
||||
} |
||||
|
||||
func (s Server) serveWebSocket(w http.ResponseWriter, req *http.Request) { |
||||
rwc, buf, err := w.(http.Hijacker).Hijack() |
||||
if err != nil { |
||||
panic("Hijack failed: " + err.Error()) |
||||
return |
||||
} |
||||
// The server should abort the WebSocket connection if it finds
|
||||
// the client did not send a handshake that matches with protocol
|
||||
// specification.
|
||||
defer rwc.Close() |
||||
conn, err := newServerConn(rwc, buf, req, &s.Config, s.Handshake) |
||||
if err != nil { |
||||
return |
||||
} |
||||
if conn == nil { |
||||
panic("unexpected nil conn") |
||||
} |
||||
s.Handler(conn) |
||||
} |
||||
|
||||
// Handler is a simple interface to a WebSocket browser client.
|
||||
// It checks if Origin header is valid URL by default.
|
||||
// You might want to verify websocket.Conn.Config().Origin in the func.
|
||||
// If you use Server instead of Handler, you could call websocket.Origin and
|
||||
// check the origin in your Handshake func. So, if you want to accept
|
||||
// non-browser client, which doesn't send Origin header, you could use Server
|
||||
//. that doesn't check origin in its Handshake.
|
||||
type Handler func(*Conn) |
||||
|
||||
func checkOrigin(config *Config, req *http.Request) (err error) { |
||||
config.Origin, err = Origin(config, req) |
||||
if err == nil && config.Origin == nil { |
||||
return fmt.Errorf("null origin") |
||||
} |
||||
return err |
||||
} |
||||
|
||||
// ServeHTTP implements the http.Handler interface for a WebSocket
|
||||
func (h Handler) ServeHTTP(w http.ResponseWriter, req *http.Request) { |
||||
s := Server{Handler: h, Handshake: checkOrigin} |
||||
s.serveWebSocket(w, req) |
||||
} |
@ -1,411 +0,0 @@ |
||||
// Copyright 2009 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
// Package websocket implements a client and server for the WebSocket protocol
|
||||
// as specified in RFC 6455.
|
||||
package websocket |
||||
|
||||
import ( |
||||
"bufio" |
||||
"crypto/tls" |
||||
"encoding/json" |
||||
"errors" |
||||
"io" |
||||
"io/ioutil" |
||||
"net" |
||||
"net/http" |
||||
"net/url" |
||||
"sync" |
||||
"time" |
||||
) |
||||
|
||||
const ( |
||||
ProtocolVersionHybi13 = 13 |
||||
ProtocolVersionHybi = ProtocolVersionHybi13 |
||||
SupportedProtocolVersion = "13" |
||||
|
||||
ContinuationFrame = 0 |
||||
TextFrame = 1 |
||||
BinaryFrame = 2 |
||||
CloseFrame = 8 |
||||
PingFrame = 9 |
||||
PongFrame = 10 |
||||
UnknownFrame = 255 |
||||
) |
||||
|
||||
// ProtocolError represents WebSocket protocol errors.
|
||||
type ProtocolError struct { |
||||
ErrorString string |
||||
} |
||||
|
||||
func (err *ProtocolError) Error() string { return err.ErrorString } |
||||
|
||||
var ( |
||||
ErrBadProtocolVersion = &ProtocolError{"bad protocol version"} |
||||
ErrBadScheme = &ProtocolError{"bad scheme"} |
||||
ErrBadStatus = &ProtocolError{"bad status"} |
||||
ErrBadUpgrade = &ProtocolError{"missing or bad upgrade"} |
||||
ErrBadWebSocketOrigin = &ProtocolError{"missing or bad WebSocket-Origin"} |
||||
ErrBadWebSocketLocation = &ProtocolError{"missing or bad WebSocket-Location"} |
||||
ErrBadWebSocketProtocol = &ProtocolError{"missing or bad WebSocket-Protocol"} |
||||
ErrBadWebSocketVersion = &ProtocolError{"missing or bad WebSocket Version"} |
||||
ErrChallengeResponse = &ProtocolError{"mismatch challenge/response"} |
||||
ErrBadFrame = &ProtocolError{"bad frame"} |
||||
ErrBadFrameBoundary = &ProtocolError{"not on frame boundary"} |
||||
ErrNotWebSocket = &ProtocolError{"not websocket protocol"} |
||||
ErrBadRequestMethod = &ProtocolError{"bad method"} |
||||
ErrNotSupported = &ProtocolError{"not supported"} |
||||
) |
||||
|
||||
// Addr is an implementation of net.Addr for WebSocket.
|
||||
type Addr struct { |
||||
*url.URL |
||||
} |
||||
|
||||
// Network returns the network type for a WebSocket, "websocket".
|
||||
func (addr *Addr) Network() string { return "websocket" } |
||||
|
||||
// Config is a WebSocket configuration
|
||||
type Config struct { |
||||
// A WebSocket server address.
|
||||
Location *url.URL |
||||
|
||||
// A Websocket client origin.
|
||||
Origin *url.URL |
||||
|
||||
// WebSocket subprotocols.
|
||||
Protocol []string |
||||
|
||||
// WebSocket protocol version.
|
||||
Version int |
||||
|
||||
// TLS config for secure WebSocket (wss).
|
||||
TlsConfig *tls.Config |
||||
|
||||
// Additional header fields to be sent in WebSocket opening handshake.
|
||||
Header http.Header |
||||
|
||||
handshakeData map[string]string |
||||
} |
||||
|
||||
// serverHandshaker is an interface to handle WebSocket server side handshake.
|
||||
type serverHandshaker interface { |
||||
// ReadHandshake reads handshake request message from client.
|
||||
// Returns http response code and error if any.
|
||||
ReadHandshake(buf *bufio.Reader, req *http.Request) (code int, err error) |
||||
|
||||
// AcceptHandshake accepts the client handshake request and sends
|
||||
// handshake response back to client.
|
||||
AcceptHandshake(buf *bufio.Writer) (err error) |
||||
|
||||
// NewServerConn creates a new WebSocket connection.
|
||||
NewServerConn(buf *bufio.ReadWriter, rwc io.ReadWriteCloser, request *http.Request) (conn *Conn) |
||||
} |
||||
|
||||
// frameReader is an interface to read a WebSocket frame.
|
||||
type frameReader interface { |
||||
// Reader is to read payload of the frame.
|
||||
io.Reader |
||||
|
||||
// PayloadType returns payload type.
|
||||
PayloadType() byte |
||||
|
||||
// HeaderReader returns a reader to read header of the frame.
|
||||
HeaderReader() io.Reader |
||||
|
||||
// TrailerReader returns a reader to read trailer of the frame.
|
||||
// If it returns nil, there is no trailer in the frame.
|
||||
TrailerReader() io.Reader |
||||
|
||||
// Len returns total length of the frame, including header and trailer.
|
||||
Len() int |
||||
} |
||||
|
||||
// frameReaderFactory is an interface to creates new frame reader.
|
||||
type frameReaderFactory interface { |
||||
NewFrameReader() (r frameReader, err error) |
||||
} |
||||
|
||||
// frameWriter is an interface to write a WebSocket frame.
|
||||
type frameWriter interface { |
||||
// Writer is to write payload of the frame.
|
||||
io.WriteCloser |
||||
} |
||||
|
||||
// frameWriterFactory is an interface to create new frame writer.
|
||||
type frameWriterFactory interface { |
||||
NewFrameWriter(payloadType byte) (w frameWriter, err error) |
||||
} |
||||
|
||||
type frameHandler interface { |
||||
HandleFrame(frame frameReader) (r frameReader, err error) |
||||
WriteClose(status int) (err error) |
||||
} |
||||
|
||||
// Conn represents a WebSocket connection.
|
||||
type Conn struct { |
||||
config *Config |
||||
request *http.Request |
||||
|
||||
buf *bufio.ReadWriter |
||||
rwc io.ReadWriteCloser |
||||
|
||||
rio sync.Mutex |
||||
frameReaderFactory |
||||
frameReader |
||||
|
||||
wio sync.Mutex |
||||
frameWriterFactory |
||||
|
||||
frameHandler |
||||
PayloadType byte |
||||
defaultCloseStatus int |
||||
} |
||||
|
||||
// Read implements the io.Reader interface:
|
||||
// it reads data of a frame from the WebSocket connection.
|
||||
// if msg is not large enough for the frame data, it fills the msg and next Read
|
||||
// will read the rest of the frame data.
|
||||
// it reads Text frame or Binary frame.
|
||||
func (ws *Conn) Read(msg []byte) (n int, err error) { |
||||
ws.rio.Lock() |
||||
defer ws.rio.Unlock() |
||||
again: |
||||
if ws.frameReader == nil { |
||||
frame, err := ws.frameReaderFactory.NewFrameReader() |
||||
if err != nil { |
||||
return 0, err |
||||
} |
||||
ws.frameReader, err = ws.frameHandler.HandleFrame(frame) |
||||
if err != nil { |
||||
return 0, err |
||||
} |
||||
if ws.frameReader == nil { |
||||
goto again |
||||
} |
||||
} |
||||
n, err = ws.frameReader.Read(msg) |
||||
if err == io.EOF { |
||||
if trailer := ws.frameReader.TrailerReader(); trailer != nil { |
||||
io.Copy(ioutil.Discard, trailer) |
||||
} |
||||
ws.frameReader = nil |
||||
goto again |
||||
} |
||||
return n, err |
||||
} |
||||
|
||||
// Write implements the io.Writer interface:
|
||||
// it writes data as a frame to the WebSocket connection.
|
||||
func (ws *Conn) Write(msg []byte) (n int, err error) { |
||||
ws.wio.Lock() |
||||
defer ws.wio.Unlock() |
||||
w, err := ws.frameWriterFactory.NewFrameWriter(ws.PayloadType) |
||||
if err != nil { |
||||
return 0, err |
||||
} |
||||
n, err = w.Write(msg) |
||||
w.Close() |
||||
if err != nil { |
||||
return n, err |
||||
} |
||||
return n, err |
||||
} |
||||
|
||||
// Close implements the io.Closer interface.
|
||||
func (ws *Conn) Close() error { |
||||
err := ws.frameHandler.WriteClose(ws.defaultCloseStatus) |
||||
if err != nil { |
||||
return err |
||||
} |
||||
return ws.rwc.Close() |
||||
} |
||||
|
||||
func (ws *Conn) IsClientConn() bool { return ws.request == nil } |
||||
func (ws *Conn) IsServerConn() bool { return ws.request != nil } |
||||
|
||||
// LocalAddr returns the WebSocket Origin for the connection for client, or
|
||||
// the WebSocket location for server.
|
||||
func (ws *Conn) LocalAddr() net.Addr { |
||||
if ws.IsClientConn() { |
||||
return &Addr{ws.config.Origin} |
||||
} |
||||
return &Addr{ws.config.Location} |
||||
} |
||||
|
||||
// RemoteAddr returns the WebSocket location for the connection for client, or
|
||||
// the Websocket Origin for server.
|
||||
func (ws *Conn) RemoteAddr() net.Addr { |
||||
if ws.IsClientConn() { |
||||
return &Addr{ws.config.Location} |
||||
} |
||||
return &Addr{ws.config.Origin} |
||||
} |
||||
|
||||
var errSetDeadline = errors.New("websocket: cannot set deadline: not using a net.Conn") |
||||
|
||||
// SetDeadline sets the connection's network read & write deadlines.
|
||||
func (ws *Conn) SetDeadline(t time.Time) error { |
||||
if conn, ok := ws.rwc.(net.Conn); ok { |
||||
return conn.SetDeadline(t) |
||||
} |
||||
return errSetDeadline |
||||
} |
||||
|
||||
// SetReadDeadline sets the connection's network read deadline.
|
||||
func (ws *Conn) SetReadDeadline(t time.Time) error { |
||||
if conn, ok := ws.rwc.(net.Conn); ok { |
||||
return conn.SetReadDeadline(t) |
||||
} |
||||
return errSetDeadline |
||||
} |
||||
|
||||
// SetWriteDeadline sets the connection's network write deadline.
|
||||
func (ws *Conn) SetWriteDeadline(t time.Time) error { |
||||
if conn, ok := ws.rwc.(net.Conn); ok { |
||||
return conn.SetWriteDeadline(t) |
||||
} |
||||
return errSetDeadline |
||||
} |
||||
|
||||
// Config returns the WebSocket config.
|
||||
func (ws *Conn) Config() *Config { return ws.config } |
||||
|
||||
// Request returns the http request upgraded to the WebSocket.
|
||||
// It is nil for client side.
|
||||
func (ws *Conn) Request() *http.Request { return ws.request } |
||||
|
||||
// Codec represents a symmetric pair of functions that implement a codec.
|
||||
type Codec struct { |
||||
Marshal func(v interface{}) (data []byte, payloadType byte, err error) |
||||
Unmarshal func(data []byte, payloadType byte, v interface{}) (err error) |
||||
} |
||||
|
||||
// Send sends v marshaled by cd.Marshal as single frame to ws.
|
||||
func (cd Codec) Send(ws *Conn, v interface{}) (err error) { |
||||
data, payloadType, err := cd.Marshal(v) |
||||
if err != nil { |
||||
return err |
||||
} |
||||
ws.wio.Lock() |
||||
defer ws.wio.Unlock() |
||||
w, err := ws.frameWriterFactory.NewFrameWriter(payloadType) |
||||
if err != nil { |
||||
return err |
||||
} |
||||
_, err = w.Write(data) |
||||
w.Close() |
||||
return err |
||||
} |
||||
|
||||
// Receive receives single frame from ws, unmarshaled by cd.Unmarshal and stores in v.
|
||||
func (cd Codec) Receive(ws *Conn, v interface{}) (err error) { |
||||
ws.rio.Lock() |
||||
defer ws.rio.Unlock() |
||||
if ws.frameReader != nil { |
||||
_, err = io.Copy(ioutil.Discard, ws.frameReader) |
||||
if err != nil { |
||||
return err |
||||
} |
||||
ws.frameReader = nil |
||||
} |
||||
again: |
||||
frame, err := ws.frameReaderFactory.NewFrameReader() |
||||
if err != nil { |
||||
return err |
||||
} |
||||
frame, err = ws.frameHandler.HandleFrame(frame) |
||||
if err != nil { |
||||
return err |
||||
} |
||||
if frame == nil { |
||||
goto again |
||||
} |
||||
payloadType := frame.PayloadType() |
||||
data, err := ioutil.ReadAll(frame) |
||||
if err != nil { |
||||
return err |
||||
} |
||||
return cd.Unmarshal(data, payloadType, v) |
||||
} |
||||
|
||||
func marshal(v interface{}) (msg []byte, payloadType byte, err error) { |
||||
switch data := v.(type) { |
||||
case string: |
||||
return []byte(data), TextFrame, nil |
||||
case []byte: |
||||
return data, BinaryFrame, nil |
||||
} |
||||
return nil, UnknownFrame, ErrNotSupported |
||||
} |
||||
|
||||
func unmarshal(msg []byte, payloadType byte, v interface{}) (err error) { |
||||
switch data := v.(type) { |
||||
case *string: |
||||
*data = string(msg) |
||||
return nil |
||||
case *[]byte: |
||||
*data = msg |
||||
return nil |
||||
} |
||||
return ErrNotSupported |
||||
} |
||||
|
||||
/* |
||||
Message is a codec to send/receive text/binary data in a frame on WebSocket connection. |
||||
To send/receive text frame, use string type. |
||||
To send/receive binary frame, use []byte type. |
||||
|
||||
Trivial usage: |
||||
|
||||
import "websocket" |
||||
|
||||
// receive text frame
|
||||
var message string |
||||
websocket.Message.Receive(ws, &message) |
||||
|
||||
// send text frame
|
||||
message = "hello" |
||||
websocket.Message.Send(ws, message) |
||||
|
||||
// receive binary frame
|
||||
var data []byte |
||||
websocket.Message.Receive(ws, &data) |
||||
|
||||
// send binary frame
|
||||
data = []byte{0, 1, 2} |
||||
websocket.Message.Send(ws, data) |
||||
|
||||
*/ |
||||
var Message = Codec{marshal, unmarshal} |
||||
|
||||
func jsonMarshal(v interface{}) (msg []byte, payloadType byte, err error) { |
||||
msg, err = json.Marshal(v) |
||||
return msg, TextFrame, err |
||||
} |
||||
|
||||
func jsonUnmarshal(msg []byte, payloadType byte, v interface{}) (err error) { |
||||
return json.Unmarshal(msg, v) |
||||
} |
||||
|
||||
/* |
||||
JSON is a codec to send/receive JSON data in a frame from a WebSocket connection. |
||||
|
||||
Trivial usage: |
||||
|
||||
import "websocket" |
||||
|
||||
type T struct { |
||||
Msg string |
||||
Count int |
||||
} |
||||
|
||||
// receive JSON type T
|
||||
var data T |
||||
websocket.JSON.Receive(ws, &data) |
||||
|
||||
// send JSON type T
|
||||
websocket.JSON.Send(ws, data) |
||||
*/ |
||||
var JSON = Codec{jsonMarshal, jsonUnmarshal} |
@ -1,414 +0,0 @@ |
||||
// Copyright 2009 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package websocket |
||||
|
||||
import ( |
||||
"bytes" |
||||
"fmt" |
||||
"io" |
||||
"log" |
||||
"net" |
||||
"net/http" |
||||
"net/http/httptest" |
||||
"net/url" |
||||
"strings" |
||||
"sync" |
||||
"testing" |
||||
) |
||||
|
||||
var serverAddr string |
||||
var once sync.Once |
||||
|
||||
func echoServer(ws *Conn) { io.Copy(ws, ws) } |
||||
|
||||
type Count struct { |
||||
S string |
||||
N int |
||||
} |
||||
|
||||
func countServer(ws *Conn) { |
||||
for { |
||||
var count Count |
||||
err := JSON.Receive(ws, &count) |
||||
if err != nil { |
||||
return |
||||
} |
||||
count.N++ |
||||
count.S = strings.Repeat(count.S, count.N) |
||||
err = JSON.Send(ws, count) |
||||
if err != nil { |
||||
return |
||||
} |
||||
} |
||||
} |
||||
|
||||
func subProtocolHandshake(config *Config, req *http.Request) error { |
||||
for _, proto := range config.Protocol { |
||||
if proto == "chat" { |
||||
config.Protocol = []string{proto} |
||||
return nil |
||||
} |
||||
} |
||||
return ErrBadWebSocketProtocol |
||||
} |
||||
|
||||
func subProtoServer(ws *Conn) { |
||||
for _, proto := range ws.Config().Protocol { |
||||
io.WriteString(ws, proto) |
||||
} |
||||
} |
||||
|
||||
func startServer() { |
||||
http.Handle("/echo", Handler(echoServer)) |
||||
http.Handle("/count", Handler(countServer)) |
||||
subproto := Server{ |
||||
Handshake: subProtocolHandshake, |
||||
Handler: Handler(subProtoServer), |
||||
} |
||||
http.Handle("/subproto", subproto) |
||||
server := httptest.NewServer(nil) |
||||
serverAddr = server.Listener.Addr().String() |
||||
log.Print("Test WebSocket server listening on ", serverAddr) |
||||
} |
||||
|
||||
func newConfig(t *testing.T, path string) *Config { |
||||
config, _ := NewConfig(fmt.Sprintf("ws://%s%s", serverAddr, path), "http://localhost") |
||||
return config |
||||
} |
||||
|
||||
func TestEcho(t *testing.T) { |
||||
once.Do(startServer) |
||||
|
||||
// websocket.Dial()
|
||||
client, err := net.Dial("tcp", serverAddr) |
||||
if err != nil { |
||||
t.Fatal("dialing", err) |
||||
} |
||||
conn, err := NewClient(newConfig(t, "/echo"), client) |
||||
if err != nil { |
||||
t.Errorf("WebSocket handshake error: %v", err) |
||||
return |
||||
} |
||||
|
||||
msg := []byte("hello, world\n") |
||||
if _, err := conn.Write(msg); err != nil { |
||||
t.Errorf("Write: %v", err) |
||||
} |
||||
var actual_msg = make([]byte, 512) |
||||
n, err := conn.Read(actual_msg) |
||||
if err != nil { |
||||
t.Errorf("Read: %v", err) |
||||
} |
||||
actual_msg = actual_msg[0:n] |
||||
if !bytes.Equal(msg, actual_msg) { |
||||
t.Errorf("Echo: expected %q got %q", msg, actual_msg) |
||||
} |
||||
conn.Close() |
||||
} |
||||
|
||||
func TestAddr(t *testing.T) { |
||||
once.Do(startServer) |
||||
|
||||
// websocket.Dial()
|
||||
client, err := net.Dial("tcp", serverAddr) |
||||
if err != nil { |
||||
t.Fatal("dialing", err) |
||||
} |
||||
conn, err := NewClient(newConfig(t, "/echo"), client) |
||||
if err != nil { |
||||
t.Errorf("WebSocket handshake error: %v", err) |
||||
return |
||||
} |
||||
|
||||
ra := conn.RemoteAddr().String() |
||||
if !strings.HasPrefix(ra, "ws://") || !strings.HasSuffix(ra, "/echo") { |
||||
t.Errorf("Bad remote addr: %v", ra) |
||||
} |
||||
la := conn.LocalAddr().String() |
||||
if !strings.HasPrefix(la, "http://") { |
||||
t.Errorf("Bad local addr: %v", la) |
||||
} |
||||
conn.Close() |
||||
} |
||||
|
||||
func TestCount(t *testing.T) { |
||||
once.Do(startServer) |
||||
|
||||
// websocket.Dial()
|
||||
client, err := net.Dial("tcp", serverAddr) |
||||
if err != nil { |
||||
t.Fatal("dialing", err) |
||||
} |
||||
conn, err := NewClient(newConfig(t, "/count"), client) |
||||
if err != nil { |
||||
t.Errorf("WebSocket handshake error: %v", err) |
||||
return |
||||
} |
||||
|
||||
var count Count |
||||
count.S = "hello" |
||||
if err := JSON.Send(conn, count); err != nil { |
||||
t.Errorf("Write: %v", err) |
||||
} |
||||
if err := JSON.Receive(conn, &count); err != nil { |
||||
t.Errorf("Read: %v", err) |
||||
} |
||||
if count.N != 1 { |
||||
t.Errorf("count: expected %d got %d", 1, count.N) |
||||
} |
||||
if count.S != "hello" { |
||||
t.Errorf("count: expected %q got %q", "hello", count.S) |
||||
} |
||||
if err := JSON.Send(conn, count); err != nil { |
||||
t.Errorf("Write: %v", err) |
||||
} |
||||
if err := JSON.Receive(conn, &count); err != nil { |
||||
t.Errorf("Read: %v", err) |
||||
} |
||||
if count.N != 2 { |
||||
t.Errorf("count: expected %d got %d", 2, count.N) |
||||
} |
||||
if count.S != "hellohello" { |
||||
t.Errorf("count: expected %q got %q", "hellohello", count.S) |
||||
} |
||||
conn.Close() |
||||
} |
||||
|
||||
func TestWithQuery(t *testing.T) { |
||||
once.Do(startServer) |
||||
|
||||
client, err := net.Dial("tcp", serverAddr) |
||||
if err != nil { |
||||
t.Fatal("dialing", err) |
||||
} |
||||
|
||||
config := newConfig(t, "/echo") |
||||
config.Location, err = url.ParseRequestURI(fmt.Sprintf("ws://%s/echo?q=v", serverAddr)) |
||||
if err != nil { |
||||
t.Fatal("location url", err) |
||||
} |
||||
|
||||
ws, err := NewClient(config, client) |
||||
if err != nil { |
||||
t.Errorf("WebSocket handshake: %v", err) |
||||
return |
||||
} |
||||
ws.Close() |
||||
} |
||||
|
||||
func testWithProtocol(t *testing.T, subproto []string) (string, error) { |
||||
once.Do(startServer) |
||||
|
||||
client, err := net.Dial("tcp", serverAddr) |
||||
if err != nil { |
||||
t.Fatal("dialing", err) |
||||
} |
||||
|
||||
config := newConfig(t, "/subproto") |
||||
config.Protocol = subproto |
||||
|
||||
ws, err := NewClient(config, client) |
||||
if err != nil { |
||||
return "", err |
||||
} |
||||
msg := make([]byte, 16) |
||||
n, err := ws.Read(msg) |
||||
if err != nil { |
||||
return "", err |
||||
} |
||||
ws.Close() |
||||
return string(msg[:n]), nil |
||||
} |
||||
|
||||
func TestWithProtocol(t *testing.T) { |
||||
proto, err := testWithProtocol(t, []string{"chat"}) |
||||
if err != nil { |
||||
t.Errorf("SubProto: unexpected error: %v", err) |
||||
} |
||||
if proto != "chat" { |
||||
t.Errorf("SubProto: expected %q, got %q", "chat", proto) |
||||
} |
||||
} |
||||
|
||||
func TestWithTwoProtocol(t *testing.T) { |
||||
proto, err := testWithProtocol(t, []string{"test", "chat"}) |
||||
if err != nil { |
||||
t.Errorf("SubProto: unexpected error: %v", err) |
||||
} |
||||
if proto != "chat" { |
||||
t.Errorf("SubProto: expected %q, got %q", "chat", proto) |
||||
} |
||||
} |
||||
|
||||
func TestWithBadProtocol(t *testing.T) { |
||||
_, err := testWithProtocol(t, []string{"test"}) |
||||
if err != ErrBadStatus { |
||||
t.Errorf("SubProto: expected %v, got %v", ErrBadStatus, err) |
||||
} |
||||
} |
||||
|
||||
func TestHTTP(t *testing.T) { |
||||
once.Do(startServer) |
||||
|
||||
// If the client did not send a handshake that matches the protocol
|
||||
// specification, the server MUST return an HTTP response with an
|
||||
// appropriate error code (such as 400 Bad Request)
|
||||
resp, err := http.Get(fmt.Sprintf("http://%s/echo", serverAddr)) |
||||
if err != nil { |
||||
t.Errorf("Get: error %#v", err) |
||||
return |
||||
} |
||||
if resp == nil { |
||||
t.Error("Get: resp is null") |
||||
return |
||||
} |
||||
if resp.StatusCode != http.StatusBadRequest { |
||||
t.Errorf("Get: expected %q got %q", http.StatusBadRequest, resp.StatusCode) |
||||
} |
||||
} |
||||
|
||||
func TestTrailingSpaces(t *testing.T) { |
||||
// http://code.google.com/p/go/issues/detail?id=955
|
||||
// The last runs of this create keys with trailing spaces that should not be
|
||||
// generated by the client.
|
||||
once.Do(startServer) |
||||
config := newConfig(t, "/echo") |
||||
for i := 0; i < 30; i++ { |
||||
// body
|
||||
ws, err := DialConfig(config) |
||||
if err != nil { |
||||
t.Errorf("Dial #%d failed: %v", i, err) |
||||
break |
||||
} |
||||
ws.Close() |
||||
} |
||||
} |
||||
|
||||
func TestDialConfigBadVersion(t *testing.T) { |
||||
once.Do(startServer) |
||||
config := newConfig(t, "/echo") |
||||
config.Version = 1234 |
||||
|
||||
_, err := DialConfig(config) |
||||
|
||||
if dialerr, ok := err.(*DialError); ok { |
||||
if dialerr.Err != ErrBadProtocolVersion { |
||||
t.Errorf("dial expected err %q but got %q", ErrBadProtocolVersion, dialerr.Err) |
||||
} |
||||
} |
||||
} |
||||
|
||||
func TestSmallBuffer(t *testing.T) { |
||||
// http://code.google.com/p/go/issues/detail?id=1145
|
||||
// Read should be able to handle reading a fragment of a frame.
|
||||
once.Do(startServer) |
||||
|
||||
// websocket.Dial()
|
||||
client, err := net.Dial("tcp", serverAddr) |
||||
if err != nil { |
||||
t.Fatal("dialing", err) |
||||
} |
||||
conn, err := NewClient(newConfig(t, "/echo"), client) |
||||
if err != nil { |
||||
t.Errorf("WebSocket handshake error: %v", err) |
||||
return |
||||
} |
||||
|
||||
msg := []byte("hello, world\n") |
||||
if _, err := conn.Write(msg); err != nil { |
||||
t.Errorf("Write: %v", err) |
||||
} |
||||
var small_msg = make([]byte, 8) |
||||
n, err := conn.Read(small_msg) |
||||
if err != nil { |
||||
t.Errorf("Read: %v", err) |
||||
} |
||||
if !bytes.Equal(msg[:len(small_msg)], small_msg) { |
||||
t.Errorf("Echo: expected %q got %q", msg[:len(small_msg)], small_msg) |
||||
} |
||||
var second_msg = make([]byte, len(msg)) |
||||
n, err = conn.Read(second_msg) |
||||
if err != nil { |
||||
t.Errorf("Read: %v", err) |
||||
} |
||||
second_msg = second_msg[0:n] |
||||
if !bytes.Equal(msg[len(small_msg):], second_msg) { |
||||
t.Errorf("Echo: expected %q got %q", msg[len(small_msg):], second_msg) |
||||
} |
||||
conn.Close() |
||||
} |
||||
|
||||
var parseAuthorityTests = []struct { |
||||
in *url.URL |
||||
out string |
||||
}{ |
||||
{ |
||||
&url.URL{ |
||||
Scheme: "ws", |
||||
Host: "www.google.com", |
||||
}, |
||||
"www.google.com:80", |
||||
}, |
||||
{ |
||||
&url.URL{ |
||||
Scheme: "wss", |
||||
Host: "www.google.com", |
||||
}, |
||||
"www.google.com:443", |
||||
}, |
||||
{ |
||||
&url.URL{ |
||||
Scheme: "ws", |
||||
Host: "www.google.com:80", |
||||
}, |
||||
"www.google.com:80", |
||||
}, |
||||
{ |
||||
&url.URL{ |
||||
Scheme: "wss", |
||||
Host: "www.google.com:443", |
||||
}, |
||||
"www.google.com:443", |
||||
}, |
||||
// some invalid ones for parseAuthority. parseAuthority doesn't
|
||||
// concern itself with the scheme unless it actually knows about it
|
||||
{ |
||||
&url.URL{ |
||||
Scheme: "http", |
||||
Host: "www.google.com", |
||||
}, |
||||
"www.google.com", |
||||
}, |
||||
{ |
||||
&url.URL{ |
||||
Scheme: "http", |
||||
Host: "www.google.com:80", |
||||
}, |
||||
"www.google.com:80", |
||||
}, |
||||
{ |
||||
&url.URL{ |
||||
Scheme: "asdf", |
||||
Host: "127.0.0.1", |
||||
}, |
||||
"127.0.0.1", |
||||
}, |
||||
{ |
||||
&url.URL{ |
||||
Scheme: "asdf", |
||||
Host: "www.google.com", |
||||
}, |
||||
"www.google.com", |
||||
}, |
||||
} |
||||
|
||||
func TestParseAuthority(t *testing.T) { |
||||
for _, tt := range parseAuthorityTests { |
||||
out := parseAuthority(tt.in) |
||||
if out != tt.out { |
||||
t.Errorf("got %v; want %v", out, tt.out) |
||||
} |
||||
} |
||||
} |
@ -1,40 +0,0 @@ |
||||
// Copyright (c) 2013-2014, Jeffrey Wilcke. All rights reserved.
|
||||
//
|
||||
// This library is free software; you can redistribute it and/or
|
||||
// modify it under the terms of the GNU General Public
|
||||
// License as published by the Free Software Foundation; either
|
||||
// version 2.1 of the License, or (at your option) any later version.
|
||||
//
|
||||
// This library is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
|
||||
// General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU General Public License
|
||||
// along with this library; if not, write to the Free Software
|
||||
// Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston,
|
||||
// MA 02110-1301 USA
|
||||
|
||||
package main |
||||
|
||||
import ( |
||||
"io/ioutil" |
||||
"os" |
||||
|
||||
"github.com/ethereum/go-ethereum/eth" |
||||
"github.com/ethereum/go-ethereum/javascript" |
||||
"github.com/ethereum/go-ethereum/xeth" |
||||
) |
||||
|
||||
func ExecJsFile(ethereum *eth.Ethereum, InputFile string) { |
||||
file, err := os.Open(InputFile) |
||||
if err != nil { |
||||
clilogger.Fatalln(err) |
||||
} |
||||
content, err := ioutil.ReadAll(file) |
||||
if err != nil { |
||||
clilogger.Fatalln(err) |
||||
} |
||||
re := javascript.NewJSRE(xeth.New(ethereum)) |
||||
re.Run(string(content)) |
||||
} |
@ -1,168 +0,0 @@ |
||||
/* |
||||
This file is part of go-ethereum |
||||
|
||||
go-ethereum is free software: you can redistribute it and/or modify |
||||
it under the terms of the GNU General Public License as published by |
||||
the Free Software Foundation, either version 3 of the License, or |
||||
(at your option) any later version. |
||||
|
||||
go-ethereum is distributed in the hope that it will be useful, |
||||
but WITHOUT ANY WARRANTY; without even the implied warranty of |
||||
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the |
||||
GNU General Public License for more details. |
||||
|
||||
You should have received a copy of the GNU General Public License |
||||
along with go-ethereum. If not, see <http://www.gnu.org/licenses/>.
|
||||
*/ |
||||
/** |
||||
* @authors |
||||
* Jeffrey Wilcke <i@jev.io> |
||||
*/ |
||||
package main |
||||
|
||||
import ( |
||||
"crypto/ecdsa" |
||||
"flag" |
||||
"fmt" |
||||
"log" |
||||
"os" |
||||
"path" |
||||
"runtime" |
||||
|
||||
"github.com/ethereum/go-ethereum/crypto" |
||||
"github.com/ethereum/go-ethereum/ethutil" |
||||
"github.com/ethereum/go-ethereum/logger" |
||||
"github.com/ethereum/go-ethereum/p2p/nat" |
||||
"github.com/ethereum/go-ethereum/vm" |
||||
) |
||||
|
||||
var ( |
||||
Identifier string |
||||
KeyRing string |
||||
DiffTool bool |
||||
DiffType string |
||||
KeyStore string |
||||
StartRpc bool |
||||
StartWebSockets bool |
||||
RpcListenAddress string |
||||
RpcPort int |
||||
OutboundPort string |
||||
ShowGenesis bool |
||||
AddPeer string |
||||
MaxPeer int |
||||
GenAddr bool |
||||
BootNodes string |
||||
NodeKey *ecdsa.PrivateKey |
||||
NAT nat.Interface |
||||
SecretFile string |
||||
ExportDir string |
||||
NonInteractive bool |
||||
Datadir string |
||||
LogFile string |
||||
ConfigFile string |
||||
DebugFile string |
||||
LogLevel int |
||||
LogFormat string |
||||
Dump bool |
||||
DumpHash string |
||||
DumpNumber int |
||||
VmType int |
||||
ImportChain string |
||||
SHH bool |
||||
Dial bool |
||||
PrintVersion bool |
||||
MinerThreads int |
||||
) |
||||
|
||||
// flags specific to cli client
|
||||
var ( |
||||
StartMining bool |
||||
StartJsConsole bool |
||||
InputFile string |
||||
) |
||||
|
||||
var defaultConfigFile = path.Join(ethutil.DefaultDataDir(), "conf.ini") |
||||
|
||||
func Init() { |
||||
// TODO: move common flag processing to cmd/util
|
||||
flag.Usage = func() { |
||||
fmt.Fprintf(os.Stderr, "%s [options] [filename]:\noptions precedence: default < config file < environment variables < command line\n", os.Args[0]) |
||||
flag.PrintDefaults() |
||||
} |
||||
|
||||
flag.IntVar(&VmType, "vm", 0, "Virtual Machine type: 0-1: standard, debug") |
||||
flag.StringVar(&Identifier, "id", "", "Custom client identifier") |
||||
flag.StringVar(&KeyRing, "keyring", "", "identifier for keyring to use") |
||||
flag.StringVar(&KeyStore, "keystore", "db", "system to store keyrings: db|file") |
||||
|
||||
flag.StringVar(&RpcListenAddress, "rpcaddr", "127.0.0.1", "address for json-rpc server to listen on") |
||||
flag.IntVar(&RpcPort, "rpcport", 8545, "port to start json-rpc server on") |
||||
flag.BoolVar(&StartRpc, "rpc", false, "start rpc server") |
||||
flag.BoolVar(&NonInteractive, "y", false, "non-interactive mode (say yes to confirmations)") |
||||
flag.BoolVar(&GenAddr, "genaddr", false, "create a new priv/pub key") |
||||
flag.StringVar(&SecretFile, "import", "", "imports the file given (hex or mnemonic formats)") |
||||
flag.StringVar(&ExportDir, "export", "", "exports the session keyring to files in the directory given") |
||||
flag.StringVar(&LogFile, "logfile", "", "log file (defaults to standard output)") |
||||
flag.StringVar(&Datadir, "datadir", ethutil.DefaultDataDir(), "specifies the datadir to use") |
||||
flag.StringVar(&ConfigFile, "conf", defaultConfigFile, "config file") |
||||
flag.StringVar(&DebugFile, "debug", "", "debug file (no debugging if not set)") |
||||
flag.IntVar(&LogLevel, "loglevel", int(logger.InfoLevel), "loglevel: 0-5 (= silent,error,warn,info,debug,debug detail)") |
||||
flag.StringVar(&LogFormat, "logformat", "std", "logformat: std,raw") |
||||
flag.BoolVar(&DiffTool, "difftool", false, "creates output for diff'ing. Sets LogLevel=0") |
||||
flag.StringVar(&DiffType, "diff", "all", "sets the level of diff output [vm, all]. Has no effect if difftool=false") |
||||
flag.BoolVar(&ShowGenesis, "genesis", false, "Dump the genesis block") |
||||
flag.StringVar(&ImportChain, "chain", "", "Imports given chain") |
||||
|
||||
flag.BoolVar(&Dump, "dump", false, "output the ethereum state in JSON format. Sub args [number, hash]") |
||||
flag.StringVar(&DumpHash, "hash", "", "specify arg in hex") |
||||
flag.IntVar(&DumpNumber, "number", -1, "specify arg in number") |
||||
|
||||
flag.BoolVar(&StartMining, "mine", false, "start mining") |
||||
flag.BoolVar(&StartJsConsole, "js", false, "launches javascript console") |
||||
flag.BoolVar(&PrintVersion, "version", false, "prints version number") |
||||
flag.IntVar(&MinerThreads, "minerthreads", runtime.NumCPU(), "number of miner threads") |
||||
|
||||
// Network stuff
|
||||
var ( |
||||
nodeKeyFile = flag.String("nodekey", "", "network private key file") |
||||
nodeKeyHex = flag.String("nodekeyhex", "", "network private key (for testing)") |
||||
natstr = flag.String("nat", "any", "port mapping mechanism (any|none|upnp|pmp|extip:<IP>)") |
||||
) |
||||
flag.BoolVar(&Dial, "dial", true, "dial out connections (default on)") |
||||
//flag.BoolVar(&SHH, "shh", true, "run whisper protocol (default on)")
|
||||
flag.StringVar(&OutboundPort, "port", "30303", "listening port") |
||||
|
||||
flag.StringVar(&BootNodes, "bootnodes", "", "space-separated node URLs for discovery bootstrap") |
||||
flag.IntVar(&MaxPeer, "maxpeer", 30, "maximum desired peers") |
||||
|
||||
flag.Parse() |
||||
|
||||
// When the javascript console is started log to a file instead
|
||||
// of stdout
|
||||
if StartJsConsole { |
||||
LogFile = path.Join(Datadir, "ethereum.log") |
||||
} |
||||
|
||||
var err error |
||||
if NAT, err = nat.Parse(*natstr); err != nil { |
||||
log.Fatalf("-nat: %v", err) |
||||
} |
||||
switch { |
||||
case *nodeKeyFile != "" && *nodeKeyHex != "": |
||||
log.Fatal("Options -nodekey and -nodekeyhex are mutually exclusive") |
||||
case *nodeKeyFile != "": |
||||
if NodeKey, err = crypto.LoadECDSA(*nodeKeyFile); err != nil { |
||||
log.Fatalf("-nodekey: %v", err) |
||||
} |
||||
case *nodeKeyHex != "": |
||||
if NodeKey, err = crypto.HexToECDSA(*nodeKeyHex); err != nil { |
||||
log.Fatalf("-nodekeyhex: %v", err) |
||||
} |
||||
} |
||||
|
||||
if VmType >= int(vm.MaxVmTy) { |
||||
log.Fatal("Invalid VM type ", VmType) |
||||
} |
||||
|
||||
InputFile = flag.Arg(0) |
||||
} |
@ -0,0 +1,280 @@ |
||||
// Copyright (c) 2013-2014, Jeffrey Wilcke. All rights reserved.
|
||||
//
|
||||
// This library is free software; you can redistribute it and/or
|
||||
// modify it under the terms of the GNU General Public
|
||||
// License as published by the Free Software Foundation; either
|
||||
// version 2.1 of the License, or (at your option) any later version.
|
||||
//
|
||||
// This library is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
|
||||
// General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU General Public License
|
||||
// along with this library; if not, write to the Free Software
|
||||
// Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston,
|
||||
// MA 02110-1301 USA
|
||||
|
||||
package main |
||||
|
||||
import ( |
||||
"bufio" |
||||
"fmt" |
||||
"io/ioutil" |
||||
"os" |
||||
"os/signal" |
||||
"path" |
||||
"strings" |
||||
|
||||
"github.com/ethereum/go-ethereum/cmd/utils" |
||||
"github.com/ethereum/go-ethereum/core/types" |
||||
"github.com/ethereum/go-ethereum/eth" |
||||
"github.com/ethereum/go-ethereum/ethutil" |
||||
"github.com/ethereum/go-ethereum/javascript" |
||||
"github.com/ethereum/go-ethereum/state" |
||||
"github.com/ethereum/go-ethereum/xeth" |
||||
"github.com/obscuren/otto" |
||||
"github.com/peterh/liner" |
||||
) |
||||
|
||||
func execJsFile(ethereum *eth.Ethereum, filename string) { |
||||
file, err := os.Open(filename) |
||||
if err != nil { |
||||
utils.Fatalf("%v", err) |
||||
} |
||||
content, err := ioutil.ReadAll(file) |
||||
if err != nil { |
||||
utils.Fatalf("%v", err) |
||||
} |
||||
re := javascript.NewJSRE(xeth.New(ethereum)) |
||||
if _, err := re.Run(string(content)); err != nil { |
||||
utils.Fatalf("Javascript Error: %v", err) |
||||
} |
||||
} |
||||
|
||||
type repl struct { |
||||
re *javascript.JSRE |
||||
ethereum *eth.Ethereum |
||||
xeth *xeth.XEth |
||||
prompt string |
||||
lr *liner.State |
||||
} |
||||
|
||||
func runREPL(ethereum *eth.Ethereum) { |
||||
xeth := xeth.New(ethereum) |
||||
repl := &repl{ |
||||
re: javascript.NewJSRE(xeth), |
||||
xeth: xeth, |
||||
ethereum: ethereum, |
||||
prompt: "> ", |
||||
} |
||||
repl.initStdFuncs() |
||||
if !liner.TerminalSupported() { |
||||
repl.dumbRead() |
||||
} else { |
||||
lr := liner.NewLiner() |
||||
defer lr.Close() |
||||
lr.SetCtrlCAborts(true) |
||||
repl.withHistory(func(hist *os.File) { lr.ReadHistory(hist) }) |
||||
repl.read(lr) |
||||
repl.withHistory(func(hist *os.File) { hist.Truncate(0); lr.WriteHistory(hist) }) |
||||
} |
||||
} |
||||
|
||||
func (self *repl) withHistory(op func(*os.File)) { |
||||
hist, err := os.OpenFile(path.Join(self.ethereum.DataDir, "history"), os.O_RDWR|os.O_CREATE, os.ModePerm) |
||||
if err != nil { |
||||
fmt.Printf("unable to open history file: %v\n", err) |
||||
return |
||||
} |
||||
op(hist) |
||||
hist.Close() |
||||
} |
||||
|
||||
func (self *repl) parseInput(code string) { |
||||
defer func() { |
||||
if r := recover(); r != nil { |
||||
fmt.Println("[native] error", r) |
||||
} |
||||
}() |
||||
value, err := self.re.Run(code) |
||||
if err != nil { |
||||
fmt.Println(err) |
||||
return |
||||
} |
||||
self.printValue(value) |
||||
} |
||||
|
||||
var indentCount = 0 |
||||
var str = "" |
||||
|
||||
func (self *repl) setIndent() { |
||||
open := strings.Count(str, "{") |
||||
open += strings.Count(str, "(") |
||||
closed := strings.Count(str, "}") |
||||
closed += strings.Count(str, ")") |
||||
indentCount = open - closed |
||||
if indentCount <= 0 { |
||||
self.prompt = "> " |
||||
} else { |
||||
self.prompt = strings.Join(make([]string, indentCount*2), "..") |
||||
self.prompt += " " |
||||
} |
||||
} |
||||
|
||||
func (self *repl) read(lr *liner.State) { |
||||
for { |
||||
input, err := lr.Prompt(self.prompt) |
||||
if err != nil { |
||||
return |
||||
} |
||||
if input == "" { |
||||
continue |
||||
} |
||||
str += input + "\n" |
||||
self.setIndent() |
||||
if indentCount <= 0 { |
||||
if input == "exit" { |
||||
return |
||||
} |
||||
hist := str[:len(str)-1] |
||||
lr.AppendHistory(hist) |
||||
self.parseInput(str) |
||||
str = "" |
||||
} |
||||
} |
||||
} |
||||
|
||||
func (self *repl) dumbRead() { |
||||
fmt.Println("Unsupported terminal, line editing will not work.") |
||||
|
||||
// process lines
|
||||
readDone := make(chan struct{}) |
||||
go func() { |
||||
r := bufio.NewReader(os.Stdin) |
||||
loop: |
||||
for { |
||||
fmt.Print(self.prompt) |
||||
line, err := r.ReadString('\n') |
||||
switch { |
||||
case err != nil || line == "exit": |
||||
break loop |
||||
case line == "": |
||||
continue |
||||
default: |
||||
self.parseInput(line + "\n") |
||||
} |
||||
} |
||||
close(readDone) |
||||
}() |
||||
|
||||
// wait for Ctrl-C
|
||||
sigc := make(chan os.Signal, 1) |
||||
signal.Notify(sigc, os.Interrupt, os.Kill) |
||||
defer signal.Stop(sigc) |
||||
|
||||
select { |
||||
case <-readDone: |
||||
case <-sigc: |
||||
os.Stdin.Close() // terminate read
|
||||
} |
||||
} |
||||
|
||||
func (self *repl) printValue(v interface{}) { |
||||
method, _ := self.re.Vm.Get("prettyPrint") |
||||
v, err := self.re.Vm.ToValue(v) |
||||
if err == nil { |
||||
val, err := method.Call(method, v) |
||||
if err == nil { |
||||
fmt.Printf("%v", val) |
||||
} |
||||
} |
||||
} |
||||
|
||||
func (self *repl) initStdFuncs() { |
||||
t, _ := self.re.Vm.Get("eth") |
||||
eth := t.Object() |
||||
eth.Set("connect", self.connect) |
||||
eth.Set("stopMining", self.stopMining) |
||||
eth.Set("startMining", self.startMining) |
||||
eth.Set("dump", self.dump) |
||||
eth.Set("export", self.export) |
||||
} |
||||
|
||||
/* |
||||
* The following methods are natively implemented javascript functions. |
||||
*/ |
||||
|
||||
func (self *repl) dump(call otto.FunctionCall) otto.Value { |
||||
var block *types.Block |
||||
|
||||
if len(call.ArgumentList) > 0 { |
||||
if call.Argument(0).IsNumber() { |
||||
num, _ := call.Argument(0).ToInteger() |
||||
block = self.ethereum.ChainManager().GetBlockByNumber(uint64(num)) |
||||
} else if call.Argument(0).IsString() { |
||||
hash, _ := call.Argument(0).ToString() |
||||
block = self.ethereum.ChainManager().GetBlock(ethutil.Hex2Bytes(hash)) |
||||
} else { |
||||
fmt.Println("invalid argument for dump. Either hex string or number") |
||||
} |
||||
|
||||
if block == nil { |
||||
fmt.Println("block not found") |
||||
|
||||
return otto.UndefinedValue() |
||||
} |
||||
|
||||
} else { |
||||
block = self.ethereum.ChainManager().CurrentBlock() |
||||
} |
||||
|
||||
statedb := state.New(block.Root(), self.ethereum.Db()) |
||||
|
||||
v, _ := self.re.Vm.ToValue(statedb.RawDump()) |
||||
|
||||
return v |
||||
} |
||||
|
||||
func (self *repl) stopMining(call otto.FunctionCall) otto.Value { |
||||
self.xeth.Miner().Stop() |
||||
return otto.TrueValue() |
||||
} |
||||
|
||||
func (self *repl) startMining(call otto.FunctionCall) otto.Value { |
||||
self.xeth.Miner().Start() |
||||
return otto.TrueValue() |
||||
} |
||||
|
||||
func (self *repl) connect(call otto.FunctionCall) otto.Value { |
||||
nodeURL, err := call.Argument(0).ToString() |
||||
if err != nil { |
||||
return otto.FalseValue() |
||||
} |
||||
if err := self.ethereum.SuggestPeer(nodeURL); err != nil { |
||||
return otto.FalseValue() |
||||
} |
||||
return otto.TrueValue() |
||||
} |
||||
|
||||
func (self *repl) export(call otto.FunctionCall) otto.Value { |
||||
if len(call.ArgumentList) == 0 { |
||||
fmt.Println("err: require file name") |
||||
return otto.FalseValue() |
||||
} |
||||
|
||||
fn, err := call.Argument(0).ToString() |
||||
if err != nil { |
||||
fmt.Println(err) |
||||
return otto.FalseValue() |
||||
} |
||||
|
||||
data := self.ethereum.ChainManager().Export() |
||||
|
||||
if err := ethutil.WriteFile(fn, data); err != nil { |
||||
fmt.Println(err) |
||||
return otto.FalseValue() |
||||
} |
||||
|
||||
return otto.TrueValue() |
||||
} |
@ -1,97 +0,0 @@ |
||||
// Copyright (c) 2013-2014, Jeffrey Wilcke. All rights reserved.
|
||||
//
|
||||
// This library is free software; you can redistribute it and/or
|
||||
// modify it under the terms of the GNU General Public
|
||||
// License as published by the Free Software Foundation; either
|
||||
// version 2.1 of the License, or (at your option) any later version.
|
||||
//
|
||||
// This library is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
|
||||
// General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU General Public License
|
||||
// along with this library; if not, write to the Free Software
|
||||
// Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston,
|
||||
// MA 02110-1301 USA
|
||||
|
||||
/* Inspired by https://github.com/xuyu/logging/blob/master/colorful_win.go */ |
||||
|
||||
package ethrepl |
||||
|
||||
import ( |
||||
"syscall" |
||||
"unsafe" |
||||
) |
||||
|
||||
type color uint16 |
||||
|
||||
const ( |
||||
green = color(0x0002) |
||||
red = color(0x0004) |
||||
yellow = color(0x000E) |
||||
) |
||||
|
||||
const ( |
||||
mask = uint16(yellow | green | red) |
||||
) |
||||
|
||||
var ( |
||||
kernel32 = syscall.NewLazyDLL("kernel32.dll") |
||||
procGetStdHandle = kernel32.NewProc("GetStdHandle") |
||||
procSetConsoleTextAttribute = kernel32.NewProc("SetConsoleTextAttribute") |
||||
procGetConsoleScreenBufferInfo = kernel32.NewProc("GetConsoleScreenBufferInfo") |
||||
hStdout uintptr |
||||
initScreenInfo *consoleScreenBufferInfo |
||||
) |
||||
|
||||
func setConsoleTextAttribute(hConsoleOutput uintptr, wAttributes uint16) bool { |
||||
ret, _, _ := procSetConsoleTextAttribute.Call(hConsoleOutput, uintptr(wAttributes)) |
||||
return ret != 0 |
||||
} |
||||
|
||||
type coord struct { |
||||
X, Y int16 |
||||
} |
||||
|
||||
type smallRect struct { |
||||
Left, Top, Right, Bottom int16 |
||||
} |
||||
|
||||
type consoleScreenBufferInfo struct { |
||||
DwSize coord |
||||
DwCursorPosition coord |
||||
WAttributes uint16 |
||||
SrWindow smallRect |
||||
DwMaximumWindowSize coord |
||||
} |
||||
|
||||
func getConsoleScreenBufferInfo(hConsoleOutput uintptr) *consoleScreenBufferInfo { |
||||
var csbi consoleScreenBufferInfo |
||||
ret, _, _ := procGetConsoleScreenBufferInfo.Call(hConsoleOutput, uintptr(unsafe.Pointer(&csbi))) |
||||
if ret == 0 { |
||||
return nil |
||||
} |
||||
return &csbi |
||||
} |
||||
|
||||
const ( |
||||
stdOutputHandle = uint32(-11 & 0xFFFFFFFF) |
||||
) |
||||
|
||||
func init() { |
||||
hStdout, _, _ = procGetStdHandle.Call(uintptr(stdOutputHandle)) |
||||
initScreenInfo = getConsoleScreenBufferInfo(hStdout) |
||||
} |
||||
|
||||
func resetColorful() { |
||||
if initScreenInfo == nil { |
||||
return |
||||
} |
||||
setConsoleTextAttribute(hStdout, initScreenInfo.WAttributes) |
||||
} |
||||
|
||||
func changeColor(c color) { |
||||
attr := uint16(0) & ^mask | uint16(c) |
||||
setConsoleTextAttribute(hStdout, attr) |
||||
} |
@ -1,201 +0,0 @@ |
||||
// Copyright (c) 2013-2014, Jeffrey Wilcke. All rights reserved.
|
||||
//
|
||||
// This library is free software; you can redistribute it and/or
|
||||
// modify it under the terms of the GNU General Public
|
||||
// License as published by the Free Software Foundation; either
|
||||
// version 2.1 of the License, or (at your option) any later version.
|
||||
//
|
||||
// This library is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
|
||||
// General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU General Public License
|
||||
// along with this library; if not, write to the Free Software
|
||||
// Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston,
|
||||
// MA 02110-1301 USA
|
||||
|
||||
package ethrepl |
||||
|
||||
import ( |
||||
"bufio" |
||||
"fmt" |
||||
"io" |
||||
"os" |
||||
"path" |
||||
|
||||
"github.com/ethereum/go-ethereum/core/types" |
||||
"github.com/ethereum/go-ethereum/eth" |
||||
"github.com/ethereum/go-ethereum/ethutil" |
||||
"github.com/ethereum/go-ethereum/javascript" |
||||
"github.com/ethereum/go-ethereum/logger" |
||||
"github.com/ethereum/go-ethereum/state" |
||||
"github.com/ethereum/go-ethereum/xeth" |
||||
"github.com/obscuren/otto" |
||||
) |
||||
|
||||
var repllogger = logger.NewLogger("REPL") |
||||
|
||||
type Repl interface { |
||||
Start() |
||||
Stop() |
||||
} |
||||
|
||||
type JSRepl struct { |
||||
re *javascript.JSRE |
||||
ethereum *eth.Ethereum |
||||
xeth *xeth.XEth |
||||
|
||||
prompt string |
||||
|
||||
history *os.File |
||||
|
||||
running bool |
||||
} |
||||
|
||||
func NewJSRepl(ethereum *eth.Ethereum) *JSRepl { |
||||
hist, err := os.OpenFile(path.Join(ethutil.Config.ExecPath, "history"), os.O_RDWR|os.O_CREATE, os.ModePerm) |
||||
if err != nil { |
||||
panic(err) |
||||
} |
||||
|
||||
xeth := xeth.New(ethereum) |
||||
repl := &JSRepl{re: javascript.NewJSRE(xeth), xeth: xeth, ethereum: ethereum, prompt: "> ", history: hist} |
||||
repl.initStdFuncs() |
||||
|
||||
return repl |
||||
} |
||||
|
||||
func (self *JSRepl) Start() { |
||||
if !self.running { |
||||
self.running = true |
||||
repllogger.Infoln("init JS Console") |
||||
|
||||
reader := bufio.NewReader(self.history) |
||||
for { |
||||
line, err := reader.ReadString('\n') |
||||
if err != nil && err == io.EOF { |
||||
break |
||||
} else if err != nil { |
||||
fmt.Println("error reading history", err) |
||||
break |
||||
} |
||||
|
||||
addHistory(line[:len(line)-1]) |
||||
} |
||||
self.read() |
||||
} |
||||
} |
||||
|
||||
func (self *JSRepl) Stop() { |
||||
if self.running { |
||||
self.running = false |
||||
repllogger.Infoln("exit JS Console") |
||||
self.history.Close() |
||||
} |
||||
} |
||||
|
||||
func (self *JSRepl) parseInput(code string) { |
||||
defer func() { |
||||
if r := recover(); r != nil { |
||||
fmt.Println("[native] error", r) |
||||
} |
||||
}() |
||||
|
||||
value, err := self.re.Run(code) |
||||
if err != nil { |
||||
fmt.Println(err) |
||||
return |
||||
} |
||||
|
||||
self.PrintValue(value) |
||||
} |
||||
|
||||
func (self *JSRepl) initStdFuncs() { |
||||
t, _ := self.re.Vm.Get("eth") |
||||
eth := t.Object() |
||||
eth.Set("connect", self.connect) |
||||
eth.Set("stopMining", self.stopMining) |
||||
eth.Set("startMining", self.startMining) |
||||
eth.Set("dump", self.dump) |
||||
eth.Set("export", self.export) |
||||
} |
||||
|
||||
/* |
||||
* The following methods are natively implemented javascript functions |
||||
*/ |
||||
|
||||
func (self *JSRepl) dump(call otto.FunctionCall) otto.Value { |
||||
var block *types.Block |
||||
|
||||
if len(call.ArgumentList) > 0 { |
||||
if call.Argument(0).IsNumber() { |
||||
num, _ := call.Argument(0).ToInteger() |
||||
block = self.ethereum.ChainManager().GetBlockByNumber(uint64(num)) |
||||
} else if call.Argument(0).IsString() { |
||||
hash, _ := call.Argument(0).ToString() |
||||
block = self.ethereum.ChainManager().GetBlock(ethutil.Hex2Bytes(hash)) |
||||
} else { |
||||
fmt.Println("invalid argument for dump. Either hex string or number") |
||||
} |
||||
|
||||
if block == nil { |
||||
fmt.Println("block not found") |
||||
|
||||
return otto.UndefinedValue() |
||||
} |
||||
|
||||
} else { |
||||
block = self.ethereum.ChainManager().CurrentBlock() |
||||
} |
||||
|
||||
statedb := state.New(block.Root(), self.ethereum.Db()) |
||||
|
||||
v, _ := self.re.Vm.ToValue(statedb.RawDump()) |
||||
|
||||
return v |
||||
} |
||||
|
||||
func (self *JSRepl) stopMining(call otto.FunctionCall) otto.Value { |
||||
self.xeth.Miner().Stop() |
||||
|
||||
return otto.TrueValue() |
||||
} |
||||
|
||||
func (self *JSRepl) startMining(call otto.FunctionCall) otto.Value { |
||||
self.xeth.Miner().Start() |
||||
return otto.TrueValue() |
||||
} |
||||
|
||||
func (self *JSRepl) connect(call otto.FunctionCall) otto.Value { |
||||
nodeURL, err := call.Argument(0).ToString() |
||||
if err != nil { |
||||
return otto.FalseValue() |
||||
} |
||||
if err := self.ethereum.SuggestPeer(nodeURL); err != nil { |
||||
return otto.FalseValue() |
||||
} |
||||
return otto.TrueValue() |
||||
} |
||||
|
||||
func (self *JSRepl) export(call otto.FunctionCall) otto.Value { |
||||
if len(call.ArgumentList) == 0 { |
||||
fmt.Println("err: require file name") |
||||
return otto.FalseValue() |
||||
} |
||||
|
||||
fn, err := call.Argument(0).ToString() |
||||
if err != nil { |
||||
fmt.Println(err) |
||||
return otto.FalseValue() |
||||
} |
||||
|
||||
data := self.ethereum.ChainManager().Export() |
||||
|
||||
if err := ethutil.WriteFile(fn, data); err != nil { |
||||
fmt.Println(err) |
||||
return otto.FalseValue() |
||||
} |
||||
|
||||
return otto.TrueValue() |
||||
} |
@ -1,144 +0,0 @@ |
||||
// Copyright (c) 2013-2014, Jeffrey Wilcke. All rights reserved.
|
||||
//
|
||||
// This library is free software; you can redistribute it and/or
|
||||
// modify it under the terms of the GNU General Public
|
||||
// License as published by the Free Software Foundation; either
|
||||
// version 2.1 of the License, or (at your option) any later version.
|
||||
//
|
||||
// This library is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
|
||||
// General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU General Public License
|
||||
// along with this library; if not, write to the Free Software
|
||||
// Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston,
|
||||
// MA 02110-1301 USA
|
||||
|
||||
package ethrepl |
||||
|
||||
// #cgo darwin CFLAGS: -I/usr/local/opt/readline/include
|
||||
// #cgo darwin LDFLAGS: -L/usr/local/opt/readline/lib
|
||||
// #cgo LDFLAGS: -lreadline
|
||||
// #include <stdio.h>
|
||||
// #include <stdlib.h>
|
||||
// #include <readline/readline.h>
|
||||
// #include <readline/history.h>
|
||||
import "C" |
||||
import ( |
||||
"fmt" |
||||
"os" |
||||
"os/signal" |
||||
"strings" |
||||
"syscall" |
||||
"unsafe" |
||||
) |
||||
|
||||
func initReadLine() { |
||||
C.rl_catch_sigwinch = 0 |
||||
C.rl_catch_signals = 0 |
||||
c := make(chan os.Signal, 1) |
||||
signal.Notify(c, syscall.SIGWINCH) |
||||
signal.Notify(c, os.Interrupt) |
||||
go func() { |
||||
for sig := range c { |
||||
switch sig { |
||||
case syscall.SIGWINCH: |
||||
C.rl_resize_terminal() |
||||
|
||||
case os.Interrupt: |
||||
C.rl_cleanup_after_signal() |
||||
default: |
||||
|
||||
} |
||||
} |
||||
}() |
||||
} |
||||
|
||||
func readLine(prompt *string) *string { |
||||
var p *C.char |
||||
|
||||
//readline allows an empty prompt(NULL)
|
||||
if prompt != nil { |
||||
p = C.CString(*prompt) |
||||
} |
||||
|
||||
ret := C.readline(p) |
||||
|
||||
if p != nil { |
||||
C.free(unsafe.Pointer(p)) |
||||
} |
||||
|
||||
if ret == nil { |
||||
return nil |
||||
} //EOF
|
||||
|
||||
s := C.GoString(ret) |
||||
C.free(unsafe.Pointer(ret)) |
||||
return &s |
||||
} |
||||
|
||||
func addHistory(s string) { |
||||
p := C.CString(s) |
||||
C.add_history(p) |
||||
C.free(unsafe.Pointer(p)) |
||||
} |
||||
|
||||
var indentCount = 0 |
||||
var str = "" |
||||
|
||||
func (self *JSRepl) setIndent() { |
||||
open := strings.Count(str, "{") |
||||
open += strings.Count(str, "(") |
||||
closed := strings.Count(str, "}") |
||||
closed += strings.Count(str, ")") |
||||
indentCount = open - closed |
||||
if indentCount <= 0 { |
||||
self.prompt = "> " |
||||
} else { |
||||
self.prompt = strings.Join(make([]string, indentCount*2), "..") |
||||
self.prompt += " " |
||||
} |
||||
} |
||||
|
||||
func (self *JSRepl) read() { |
||||
initReadLine() |
||||
L: |
||||
for { |
||||
switch result := readLine(&self.prompt); true { |
||||
case result == nil: |
||||
break L |
||||
|
||||
case *result != "": |
||||
str += *result + "\n" |
||||
|
||||
self.setIndent() |
||||
|
||||
if indentCount <= 0 { |
||||
if *result == "exit" { |
||||
self.Stop() |
||||
break L |
||||
} |
||||
|
||||
hist := str[:len(str)-1] |
||||
addHistory(hist) //allow user to recall this line
|
||||
self.history.WriteString(str) |
||||
|
||||
self.parseInput(str) |
||||
|
||||
str = "" |
||||
} |
||||
} |
||||
} |
||||
} |
||||
|
||||
func (self *JSRepl) PrintValue(v interface{}) { |
||||
method, _ := self.re.Vm.Get("prettyPrint") |
||||
v, err := self.re.Vm.ToValue(v) |
||||
if err == nil { |
||||
val, err := method.Call(method, v) |
||||
if err == nil { |
||||
fmt.Printf("%v", val) |
||||
} |
||||
} |
||||
} |
@ -1 +0,0 @@ |
||||
repl_darwin.go |
@ -1,92 +0,0 @@ |
||||
// Copyright (c) 2013-2014, Jeffrey Wilcke. All rights reserved.
|
||||
//
|
||||
// This library is free software; you can redistribute it and/or
|
||||
// modify it under the terms of the GNU General Public
|
||||
// License as published by the Free Software Foundation; either
|
||||
// version 2.1 of the License, or (at your option) any later version.
|
||||
//
|
||||
// This library is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
|
||||
// General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU General Public License
|
||||
// along with this library; if not, write to the Free Software
|
||||
// Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston,
|
||||
// MA 02110-1301 USA
|
||||
|
||||
package ethrepl |
||||
|
||||
import ( |
||||
"bufio" |
||||
"fmt" |
||||
"os" |
||||
"strings" |
||||
) |
||||
|
||||
func (self *JSRepl) read() { |
||||
reader := bufio.NewReader(os.Stdin) |
||||
for { |
||||
fmt.Printf(self.prompt) |
||||
str, _, err := reader.ReadLine() |
||||
if err != nil { |
||||
fmt.Println("Error reading input", err) |
||||
} else { |
||||
if string(str) == "exit" { |
||||
self.Stop() |
||||
break |
||||
} else { |
||||
self.parseInput(string(str)) |
||||
} |
||||
} |
||||
} |
||||
} |
||||
|
||||
func addHistory(s string) { |
||||
} |
||||
|
||||
func printColored(outputVal string) { |
||||
for outputVal != "" { |
||||
codePart := "" |
||||
if strings.HasPrefix(outputVal, "\033[32m") { |
||||
codePart = "\033[32m" |
||||
changeColor(2) |
||||
} |
||||
if strings.HasPrefix(outputVal, "\033[1m\033[30m") { |
||||
codePart = "\033[1m\033[30m" |
||||
changeColor(8) |
||||
} |
||||
if strings.HasPrefix(outputVal, "\033[31m") { |
||||
codePart = "\033[31m" |
||||
changeColor(red) |
||||
} |
||||
if strings.HasPrefix(outputVal, "\033[35m") { |
||||
codePart = "\033[35m" |
||||
changeColor(5) |
||||
} |
||||
if strings.HasPrefix(outputVal, "\033[0m") { |
||||
codePart = "\033[0m" |
||||
resetColorful() |
||||
} |
||||
textPart := outputVal[len(codePart):len(outputVal)] |
||||
index := strings.Index(textPart, "\033") |
||||
if index == -1 { |
||||
outputVal = "" |
||||
} else { |
||||
outputVal = textPart[index:len(textPart)] |
||||
textPart = textPart[0:index] |
||||
} |
||||
fmt.Printf("%v", textPart) |
||||
} |
||||
} |
||||
|
||||
func (self *JSRepl) PrintValue(v interface{}) { |
||||
method, _ := self.re.Vm.Get("prettyPrint") |
||||
v, err := self.re.Vm.ToValue(v) |
||||
if err == nil { |
||||
val, err := method.Call(method, v) |
||||
if err == nil { |
||||
printColored(fmt.Sprintf("%v", val)) |
||||
} |
||||
} |
||||
} |
@ -0,0 +1,178 @@ |
||||
package utils |
||||
|
||||
import ( |
||||
"crypto/ecdsa" |
||||
"path" |
||||
"runtime" |
||||
|
||||
"github.com/codegangsta/cli" |
||||
"github.com/ethereum/go-ethereum/core" |
||||
"github.com/ethereum/go-ethereum/crypto" |
||||
"github.com/ethereum/go-ethereum/eth" |
||||
"github.com/ethereum/go-ethereum/ethdb" |
||||
"github.com/ethereum/go-ethereum/ethutil" |
||||
"github.com/ethereum/go-ethereum/event" |
||||
"github.com/ethereum/go-ethereum/logger" |
||||
"github.com/ethereum/go-ethereum/p2p" |
||||
"github.com/ethereum/go-ethereum/p2p/nat" |
||||
) |
||||
|
||||
// These are all the command line flags we support.
|
||||
// If you add to this list, please remember to include the
|
||||
// flag in the appropriate command definition.
|
||||
//
|
||||
// The flags are defined here so their names and help texts
|
||||
// are the same for all commands.
|
||||
|
||||
var ( |
||||
// General settings
|
||||
VMTypeFlag = cli.IntFlag{ |
||||
Name: "vm", |
||||
Usage: "Virtual Machine type: 0 is standard VM, 1 is debug VM", |
||||
} |
||||
KeyRingFlag = cli.StringFlag{ |
||||
Name: "keyring", |
||||
Usage: "Name of keyring to be used", |
||||
Value: "", |
||||
} |
||||
KeyStoreFlag = cli.StringFlag{ |
||||
Name: "keystore", |
||||
Usage: `Where to store keyrings: "db" or "file"`, |
||||
Value: "db", |
||||
} |
||||
DataDirFlag = cli.StringFlag{ |
||||
Name: "datadir", |
||||
Usage: "Data directory to be used", |
||||
Value: ethutil.DefaultDataDir(), |
||||
} |
||||
MinerThreadsFlag = cli.IntFlag{ |
||||
Name: "minerthreads", |
||||
Usage: "Number of miner threads", |
||||
Value: runtime.NumCPU(), |
||||
} |
||||
MiningEnabledFlag = cli.BoolFlag{ |
||||
Name: "mine", |
||||
Usage: "Enable mining", |
||||
} |
||||
|
||||
LogFileFlag = cli.StringFlag{ |
||||
Name: "logfile", |
||||
Usage: "Send log output to a file", |
||||
} |
||||
LogLevelFlag = cli.IntFlag{ |
||||
Name: "loglevel", |
||||
Usage: "0-5 (silent, error, warn, info, debug, debug detail)", |
||||
Value: int(logger.InfoLevel), |
||||
} |
||||
LogFormatFlag = cli.StringFlag{ |
||||
Name: "logformat", |
||||
Usage: `"std" or "raw"`, |
||||
Value: "std", |
||||
} |
||||
|
||||
// RPC settings
|
||||
RPCEnabledFlag = cli.BoolFlag{ |
||||
Name: "rpc", |
||||
Usage: "Whether RPC server is enabled", |
||||
} |
||||
RPCListenAddrFlag = cli.StringFlag{ |
||||
Name: "rpcaddr", |
||||
Usage: "Listening address for the JSON-RPC server", |
||||
Value: "127.0.0.1", |
||||
} |
||||
RPCPortFlag = cli.IntFlag{ |
||||
Name: "rpcport", |
||||
Usage: "Port on which the JSON-RPC server should listen", |
||||
Value: 8545, |
||||
} |
||||
|
||||
// Network Settings
|
||||
MaxPeersFlag = cli.IntFlag{ |
||||
Name: "maxpeers", |
||||
Usage: "Maximum number of network peers", |
||||
Value: 16, |
||||
} |
||||
ListenPortFlag = cli.IntFlag{ |
||||
Name: "port", |
||||
Usage: "Network listening port", |
||||
Value: 30303, |
||||
} |
||||
BootnodesFlag = cli.StringFlag{ |
||||
Name: "bootnodes", |
||||
Usage: "Space-separated enode URLs for discovery bootstrap", |
||||
Value: "", |
||||
} |
||||
NodeKeyFileFlag = cli.StringFlag{ |
||||
Name: "nodekey", |
||||
Usage: "P2P node key file", |
||||
} |
||||
NodeKeyHexFlag = cli.StringFlag{ |
||||
Name: "nodekeyhex", |
||||
Usage: "P2P node key as hex (for testing)", |
||||
} |
||||
NATFlag = cli.StringFlag{ |
||||
Name: "nat", |
||||
Usage: "Port mapping mechanism (any|none|upnp|pmp|extip:<IP>)", |
||||
Value: "any", |
||||
} |
||||
) |
||||
|
||||
func GetNAT(ctx *cli.Context) nat.Interface { |
||||
natif, err := nat.Parse(ctx.GlobalString(NATFlag.Name)) |
||||
if err != nil { |
||||
Fatalf("Option %s: %v", NATFlag.Name, err) |
||||
} |
||||
return natif |
||||
} |
||||
|
||||
func GetNodeKey(ctx *cli.Context) (key *ecdsa.PrivateKey) { |
||||
hex, file := ctx.GlobalString(NodeKeyHexFlag.Name), ctx.GlobalString(NodeKeyFileFlag.Name) |
||||
var err error |
||||
switch { |
||||
case file != "" && hex != "": |
||||
Fatalf("Options %q and %q are mutually exclusive", NodeKeyFileFlag.Name, NodeKeyHexFlag.Name) |
||||
case file != "": |
||||
if key, err = crypto.LoadECDSA(file); err != nil { |
||||
Fatalf("Option %q: %v", NodeKeyFileFlag.Name, err) |
||||
} |
||||
case hex != "": |
||||
if key, err = crypto.HexToECDSA(hex); err != nil { |
||||
Fatalf("Option %q: %v", NodeKeyHexFlag.Name, err) |
||||
} |
||||
} |
||||
return key |
||||
} |
||||
|
||||
func GetEthereum(clientID, version string, ctx *cli.Context) *eth.Ethereum { |
||||
ethereum, err := eth.New(ð.Config{ |
||||
Name: p2p.MakeName(clientID, version), |
||||
KeyStore: ctx.GlobalString(KeyStoreFlag.Name), |
||||
DataDir: ctx.GlobalString(DataDirFlag.Name), |
||||
LogFile: ctx.GlobalString(LogFileFlag.Name), |
||||
LogLevel: ctx.GlobalInt(LogLevelFlag.Name), |
||||
LogFormat: ctx.GlobalString(LogFormatFlag.Name), |
||||
MinerThreads: ctx.GlobalInt(MinerThreadsFlag.Name), |
||||
|
||||
MaxPeers: ctx.GlobalInt(MaxPeersFlag.Name), |
||||
Port: ctx.GlobalString(ListenPortFlag.Name), |
||||
NAT: GetNAT(ctx), |
||||
NodeKey: GetNodeKey(ctx), |
||||
KeyRing: ctx.GlobalString(KeyRingFlag.Name), |
||||
Shh: true, |
||||
Dial: true, |
||||
BootNodes: ctx.GlobalString(BootnodesFlag.Name), |
||||
}) |
||||
if err != nil { |
||||
exit(err) |
||||
} |
||||
return ethereum |
||||
} |
||||
|
||||
func GetChain(ctx *cli.Context) (*core.ChainManager, ethutil.Database) { |
||||
dataDir := ctx.GlobalString(DataDirFlag.Name) |
||||
db, err := ethdb.NewLDBDatabase(path.Join(dataDir, "blockchain")) |
||||
if err != nil { |
||||
Fatalf("Could not open database: %v", err) |
||||
} |
||||
return core.NewChainManager(db, new(event.TypeMux)), db |
||||
} |
Loading…
Reference in new issue