diff --git a/accounts/abi/bind/bind.go b/accounts/abi/bind/bind.go index 8b587f1aa..a9f21b21a 100644 --- a/accounts/abi/bind/bind.go +++ b/accounts/abi/bind/bind.go @@ -125,8 +125,8 @@ func bindType(kind abi.Type) string { case stringKind == "address": return "common.Address" - case stringKind == "hash": - return "common.Hash" + case stringKind == "address[]": + return "[]common.Address" case strings.HasPrefix(stringKind, "bytes"): if stringKind == "bytes" { diff --git a/accounts/abi/bind/bind_test.go b/accounts/abi/bind/bind_test.go index abe60b22c..37b8ef5a7 100644 --- a/accounts/abi/bind/bind_test.go +++ b/accounts/abi/bind/bind_test.go @@ -247,41 +247,41 @@ func TestBindings(t *testing.T) { if !strings.Contains(string(linkTestDeps), "go-ethereum") { t.Skip("symlinked environment doesn't support bind (https://github.com/golang/go/issues/14845)") } - // All is well, run the tests - for i, tt := range bindTests { - // Create a temporary workspace for this test - ws, err := ioutil.TempDir("", "") - if err != nil { - t.Fatalf("test %d: failed to create temporary workspace: %v", i, err) - } - defer os.RemoveAll(ws) + // Create a temporary workspace for the test suite + ws, err := ioutil.TempDir("", "") + if err != nil { + t.Fatalf("failed to create temporary workspace: %v", err) + } + defer os.RemoveAll(ws) - // Generate the binding and create a Go package in the workspace + pkg := filepath.Join(ws, "bindtest") + if err = os.MkdirAll(pkg, 0700); err != nil { + t.Fatalf("failed to create package: %v", err) + } + // Generate the test suite for all the contracts + for i, tt := range bindTests { + // Generate the binding and create a Go source file in the workspace bind, err := Bind([]string{tt.name}, []string{tt.abi}, []string{tt.bytecode}, "bindtest") if err != nil { t.Fatalf("test %d: failed to generate binding: %v", i, err) } - pkg := filepath.Join(ws, "bindtest") - if err = os.MkdirAll(pkg, 0700); err != nil { - t.Fatalf("test %d: failed to create package: %v", i, err) - } - if err = ioutil.WriteFile(filepath.Join(pkg, "main.go"), []byte(bind), 0600); err != nil { + if err = ioutil.WriteFile(filepath.Join(pkg, strings.ToLower(tt.name)+".go"), []byte(bind), 0600); err != nil { t.Fatalf("test %d: failed to write binding: %v", i, err) } // Generate the test file with the injected test code - code := fmt.Sprintf("package bindtest\nimport \"testing\"\nfunc TestBinding%d(t *testing.T){\n%s\n}", i, tt.tester) + code := fmt.Sprintf("package bindtest\nimport \"testing\"\nfunc Test%s(t *testing.T){\n%s\n}", tt.name, tt.tester) blob, err := imports.Process("", []byte(code), nil) if err != nil { t.Fatalf("test %d: failed to generate tests: %v", i, err) } - if err := ioutil.WriteFile(filepath.Join(pkg, "main_test.go"), blob, 0600); err != nil { + if err := ioutil.WriteFile(filepath.Join(pkg, strings.ToLower(tt.name)+"_test.go"), blob, 0600); err != nil { t.Fatalf("test %d: failed to write tests: %v", i, err) } - // Test the entire package and report any failures - cmd := exec.Command(gocmd, "test", "-v") - cmd.Dir = pkg - if out, err := cmd.CombinedOutput(); err != nil { - t.Fatalf("test %d: failed to run binding test: %v\n%s\n%s", i, err, out, bind) - } + } + // Test the entire package and report any failures + cmd := exec.Command(gocmd, "test", "-v") + cmd.Dir = pkg + if out, err := cmd.CombinedOutput(); err != nil { + t.Fatalf("failed to run binding test: %v\n%s", err, out) } } diff --git a/accounts/abi/bind/template.go b/accounts/abi/bind/template.go index 30df66365..f1a10137c 100644 --- a/accounts/abi/bind/template.go +++ b/accounts/abi/bind/template.go @@ -147,22 +147,17 @@ package {{.Package}} } {{range .Calls}} - {{if .Structured}} - // {{.Normalized.Name}}Result is the result of the {{.Normalized.Name}} invocation." - type {{.Normalized.Name}}Result struct { - {{range .Normalized.Outputs}}{{.Name}} {{bindtype .Type}} - {{end}} - } - {{end}} - // {{.Normalized.Name}} is a free data retrieval call binding the contract method 0x{{printf "%x" .Original.Id}}. // // Solidity: {{.Original.String}} - func (_{{$contract.Type}} *{{$contract.Type}}Caller) {{.Normalized.Name}}(opts *bind.CallOpts {{range .Normalized.Inputs}}, {{.Name}} {{bindtype .Type}} {{end}}) ({{if .Structured}}{{.Normalized.Name}}Result,{{else}}{{range .Normalized.Outputs}}{{bindtype .Type}},{{end}}{{end}} error) { - var ( - {{if .Structured}}ret = new({{.Normalized.Name}}Result){{else}}{{range $i, $_ := .Normalized.Outputs}}ret{{$i}} = new({{bindtype .Type}}) - {{end}}{{end}} - ) + func (_{{$contract.Type}} *{{$contract.Type}}Caller) {{.Normalized.Name}}(opts *bind.CallOpts {{range .Normalized.Inputs}}, {{.Name}} {{bindtype .Type}} {{end}}) ({{if .Structured}}struct{ {{range .Normalized.Outputs}}{{.Name}} {{bindtype .Type}};{{end}} },{{else}}{{range .Normalized.Outputs}}{{bindtype .Type}},{{end}}{{end}} error) { + {{if .Structured}}ret := new(struct{ + {{range .Normalized.Outputs}}{{.Name}} {{bindtype .Type}} + {{end}} + }){{else}}var ( + {{range $i, $_ := .Normalized.Outputs}}ret{{$i}} = new({{bindtype .Type}}) + {{end}} + ){{end}} out := {{if .Structured}}ret{{else}}{{if eq (len .Normalized.Outputs) 1}}ret0{{else}}[]interface{}{ {{range $i, $_ := .Normalized.Outputs}}ret{{$i}}, {{end}} @@ -174,14 +169,14 @@ package {{.Package}} // {{.Normalized.Name}} is a free data retrieval call binding the contract method 0x{{printf "%x" .Original.Id}}. // // Solidity: {{.Original.String}} - func (_{{$contract.Type}} *{{$contract.Type}}Session) {{.Normalized.Name}}({{range $i, $_ := .Normalized.Inputs}}{{if ne $i 0}},{{end}} {{.Name}} {{bindtype .Type}} {{end}}) ({{if .Structured}}{{.Normalized.Name}}Result, {{else}} {{range .Normalized.Outputs}}{{bindtype .Type}},{{end}} {{end}} error) { + func (_{{$contract.Type}} *{{$contract.Type}}Session) {{.Normalized.Name}}({{range $i, $_ := .Normalized.Inputs}}{{if ne $i 0}},{{end}} {{.Name}} {{bindtype .Type}} {{end}}) ({{if .Structured}}struct{ {{range .Normalized.Outputs}}{{.Name}} {{bindtype .Type}};{{end}} }, {{else}} {{range .Normalized.Outputs}}{{bindtype .Type}},{{end}} {{end}} error) { return _{{$contract.Type}}.Contract.{{.Normalized.Name}}(&_{{$contract.Type}}.CallOpts {{range .Normalized.Inputs}}, {{.Name}}{{end}}) } // {{.Normalized.Name}} is a free data retrieval call binding the contract method 0x{{printf "%x" .Original.Id}}. // // Solidity: {{.Original.String}} - func (_{{$contract.Type}} *{{$contract.Type}}CallerSession) {{.Normalized.Name}}({{range $i, $_ := .Normalized.Inputs}}{{if ne $i 0}},{{end}} {{.Name}} {{bindtype .Type}} {{end}}) ({{if .Structured}}{{.Normalized.Name}}Result, {{else}} {{range .Normalized.Outputs}}{{bindtype .Type}},{{end}} {{end}} error) { + func (_{{$contract.Type}} *{{$contract.Type}}CallerSession) {{.Normalized.Name}}({{range $i, $_ := .Normalized.Inputs}}{{if ne $i 0}},{{end}} {{.Name}} {{bindtype .Type}} {{end}}) ({{if .Structured}}struct{ {{range .Normalized.Outputs}}{{.Name}} {{bindtype .Type}};{{end}} }, {{else}} {{range .Normalized.Outputs}}{{bindtype .Type}},{{end}} {{end}} error) { return _{{$contract.Type}}.Contract.{{.Normalized.Name}}(&_{{$contract.Type}}.CallOpts {{range .Normalized.Inputs}}, {{.Name}}{{end}}) } {{end}} diff --git a/cmd/abigen/main.go b/cmd/abigen/main.go index 329e9b109..88d27e443 100644 --- a/cmd/abigen/main.go +++ b/cmd/abigen/main.go @@ -22,6 +22,7 @@ import ( "fmt" "io/ioutil" "os" + "strings" "github.com/ethereum/go-ethereum/accounts/abi/bind" "github.com/ethereum/go-ethereum/common/compiler" @@ -34,6 +35,7 @@ var ( solFlag = flag.String("sol", "", "Path to the Ethereum contract Solidity source to build and bind") solcFlag = flag.String("solc", "solc", "Solidity compiler to use if source builds are requested") + excFlag = flag.String("exc", "", "Comma separated types to exclude from binding") pkgFlag = flag.String("pkg", "", "Go package name to generate the binding into") outFlag = flag.String("out", "", "Output file for the generated binding (default = stdout)") @@ -61,6 +63,12 @@ func main() { types []string ) if *solFlag != "" { + // Generate the list of types to exclude from binding + exclude := make(map[string]bool) + for _, kind := range strings.Split(*excFlag, ",") { + exclude[strings.ToLower(kind)] = true + } + // Build the Solidity source into bindable components solc, err := compiler.New(*solcFlag) if err != nil { fmt.Printf("Failed to locate Solidity compiler: %v\n", err) @@ -76,7 +84,11 @@ func main() { fmt.Printf("Failed to build Solidity contract: %v\n", err) os.Exit(-1) } + // Gather all non-excluded contract for binding for name, contract := range contracts { + if exclude[strings.ToLower(name)] { + continue + } abi, _ := json.Marshal(contract.Info.AbiDefinition) // Flatten the compiler parse abis = append(abis, string(abi)) bins = append(bins, contract.Code)