diff --git a/cmd/abigen/main.go b/cmd/abigen/main.go index 075e98930e67..83b6c5e4289f 100644 --- a/cmd/abigen/main.go +++ b/cmd/abigen/main.go @@ -154,9 +154,12 @@ func abigen(c *cli.Context) error { types = append(types, kind) } else { // Generate the list of types to exclude from binding - exclude := make(map[string]bool) - for _, kind := range strings.Split(c.String(excFlag.Name), ",") { - exclude[strings.ToLower(kind)] = true + var exclude *nameFilter + if c.IsSet(excFlag.Name) { + var err error + if exclude, err = newNameFilter(strings.Split(c.String(excFlag.Name), ",")...); err != nil { + utils.Fatalf("Failed to parse excludes: %v", err) + } } var contracts map[string]*compiler.Contract @@ -181,7 +184,11 @@ func abigen(c *cli.Context) error { } // Gather all non-excluded contract for binding for name, contract := range contracts { - if exclude[strings.ToLower(name)] { + // fully qualified name is of the form : + nameParts := strings.Split(name, ":") + typeName := nameParts[len(nameParts)-1] + if exclude != nil && exclude.Matches(name) { + fmt.Fprintf(os.Stderr, "excluding: %v\n", name) continue } abi, err := json.Marshal(contract.Info.AbiDefinition) // Flatten the compiler parse @@ -191,15 +198,14 @@ func abigen(c *cli.Context) error { abis = append(abis, string(abi)) bins = append(bins, contract.Code) sigs = append(sigs, contract.Hashes) - nameParts := strings.Split(name, ":") - types = append(types, nameParts[len(nameParts)-1]) + types = append(types, typeName) // Derive the library placeholder which is a 34 character prefix of the // hex encoding of the keccak256 hash of the fully qualified library name. // Note that the fully qualified library name is the path of its source // file and the library name separated by ":". libPattern := crypto.Keccak256Hash([]byte(name)).String()[2:36] // the first 2 chars are 0x - libs[libPattern] = nameParts[len(nameParts)-1] + libs[libPattern] = typeName } } // Extract all aliases from the flags diff --git a/cmd/abigen/namefilter.go b/cmd/abigen/namefilter.go new file mode 100644 index 000000000000..eea5c643c442 --- /dev/null +++ b/cmd/abigen/namefilter.go @@ -0,0 +1,58 @@ +package main + +import ( + "fmt" + "strings" +) + +type nameFilter struct { + fulls map[string]bool // path/to/contract.sol:Type + files map[string]bool // path/to/contract.sol:* + types map[string]bool // *:Type +} + +func newNameFilter(patterns ...string) (*nameFilter, error) { + f := &nameFilter{ + fulls: make(map[string]bool), + files: make(map[string]bool), + types: make(map[string]bool), + } + for _, pattern := range patterns { + if err := f.add(pattern); err != nil { + return nil, err + } + } + return f, nil +} + +func (f *nameFilter) add(pattern string) error { + ft := strings.Split(pattern, ":") + if len(ft) != 2 { + // filenames and types must not include ':' symbol + return fmt.Errorf("invalid pattern: %s", pattern) + } + + file, typ := ft[0], ft[1] + if file == "*" { + f.types[typ] = true + return nil + } else if typ == "*" { + f.files[file] = true + return nil + } + f.fulls[pattern] = true + return nil +} + +func (f *nameFilter) Matches(name string) bool { + ft := strings.Split(name, ":") + if len(ft) != 2 { + // If contract names are always of the fully-qualified form + // :, then this case will never happen. + return false + } + + file, typ := ft[0], ft[1] + // full paths > file paths > types + return f.fulls[name] || f.files[file] || f.types[typ] +} diff --git a/cmd/abigen/namefilter_test.go b/cmd/abigen/namefilter_test.go new file mode 100644 index 000000000000..42ba55be5eb5 --- /dev/null +++ b/cmd/abigen/namefilter_test.go @@ -0,0 +1,38 @@ +package main + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestNameFilter(t *testing.T) { + _, err := newNameFilter("Foo") + require.Error(t, err) + _, err = newNameFilter("too/many:colons:Foo") + require.Error(t, err) + + f, err := newNameFilter("a/path:A", "*:B", "c/path:*") + require.NoError(t, err) + + for _, tt := range []struct { + name string + match bool + }{ + {"a/path:A", true}, + {"unknown/path:A", false}, + {"a/path:X", false}, + {"unknown/path:X", false}, + {"any/path:B", true}, + {"c/path:X", true}, + {"c/path:foo:B", false}, + } { + match := f.Matches(tt.name) + if tt.match { + assert.True(t, match, "expected match") + } else { + assert.False(t, match, "expected no match") + } + } +}