diff --git a/cmd/geth/config.go b/cmd/geth/config.go index 21edd9a32fd4..6c40f5529867 100644 --- a/cmd/geth/config.go +++ b/cmd/geth/config.go @@ -194,7 +194,8 @@ func makeFullNode(ctx *cli.Context) (*node.Node, ethapi.Backend) { switch dbType { case shared.FILE: indexerConfig = file.Config{ - FilePath: ctx.GlobalString(utils.StateDiffFilePath.Name), + FilePath: ctx.GlobalString(utils.StateDiffFilePath.Name), + WatchedAddressesFilePath: ctx.GlobalString(utils.StateDiffWatchedAddressesFilePath.Name), } case shared.POSTGRES: driverTypeStr := ctx.GlobalString(utils.StateDiffDBDriverTypeFlag.Name) diff --git a/cmd/geth/main.go b/cmd/geth/main.go index f931d3ffa002..1bc38c8b0dc9 100644 --- a/cmd/geth/main.go +++ b/cmd/geth/main.go @@ -178,6 +178,7 @@ var ( utils.StateDiffFilePath, utils.StateDiffKnownGapsFilePath, utils.StateDiffWaitForSync, + utils.StateDiffWatchedAddressesFilePath, configFileFlag, } diff --git a/cmd/geth/usage.go b/cmd/geth/usage.go index c8338ac5fd69..6cd6df9466d2 100644 --- a/cmd/geth/usage.go +++ b/cmd/geth/usage.go @@ -248,6 +248,7 @@ var AppHelpFlagGroups = []flags.FlagGroup{ utils.StateDiffFilePath, utils.StateDiffKnownGapsFilePath, utils.StateDiffWaitForSync, + utils.StateDiffWatchedAddressesFilePath, }, }, { diff --git a/cmd/utils/flags.go b/cmd/utils/flags.go index 92f8cfc93ab5..275f70a9834e 100644 --- a/cmd/utils/flags.go +++ b/cmd/utils/flags.go @@ -872,6 +872,10 @@ var ( Usage: "Full path (including filename) to write knownGaps statements when the DB is unavailable.", Value: "./known_gaps.sql", } + StateDiffWatchedAddressesFilePath = cli.StringFlag{ + Name: "statediff.file.wapath", + Usage: "Full path (including filename) to write statediff watched addresses out to when operating in file mode", + } StateDiffDBClientNameFlag = cli.StringFlag{ Name: "statediff.db.clientname", Usage: "Client name to use when writing state diffs to database", diff --git a/go.mod b/go.mod index 669d85d16c28..cd2a78f43d29 100644 --- a/go.mod +++ b/go.mod @@ -63,6 +63,7 @@ require ( github.com/naoina/toml v0.1.2-0.20170918210437-9fafd6967416 github.com/olekukonko/tablewriter v0.0.5 github.com/peterh/liner v1.1.1-0.20190123174540-a2c9a5303de7 + github.com/pganalyze/pg_query_go/v2 v2.1.0 github.com/prometheus/tsdb v0.7.1 github.com/rjeczalik/notify v0.9.1 github.com/rs/cors v1.7.0 @@ -70,6 +71,7 @@ require ( github.com/status-im/keycard-go v0.0.0-20190316090335-8537d3370df4 github.com/stretchr/testify v1.7.0 github.com/syndtr/goleveldb v1.0.1-0.20210819022825-2ae1ddf74ef7 + github.com/thoas/go-funk v0.9.2 github.com/tklauser/go-sysconf v0.3.5 // indirect github.com/tyler-smith/go-bip39 v1.0.1-0.20181017060643-dbb3b84ba2ef golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97 diff --git a/go.sum b/go.sum index 136d318067e0..26e893ea9edf 100644 --- a/go.sum +++ b/go.sum @@ -222,6 +222,7 @@ github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMyw github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.4.1/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= +github.com/google/go-cmp v0.5.1/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.4 h1:L8R9j+yAqZuZjsqh/z+F1NCffTKKLShY6zXTItVIZ8M= github.com/google/go-cmp v0.5.4/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/gofuzz 
v1.1.1-0.20200604201612-c04b05f3adfa h1:Q75Upo5UN4JbPFURXZ8nLKYUvF85dyFRop/vQ0Rv+64= @@ -516,6 +517,8 @@ github.com/paulbellamy/ratecounter v0.2.0/go.mod h1:Hfx1hDpSGoqxkVVpBi/IlYD7kChl github.com/peterh/liner v1.0.1-0.20180619022028-8c1271fcf47f/go.mod h1:xIteQHvHuaLYG9IFj6mSxM0fCKrs34IrEQUhOYuGPHc= github.com/peterh/liner v1.1.1-0.20190123174540-a2c9a5303de7 h1:oYW+YCJ1pachXTQmzR3rNLYGGz4g/UgFcjb28p/viDM= github.com/peterh/liner v1.1.1-0.20190123174540-a2c9a5303de7/go.mod h1:CRroGNssyjTd/qIG2FyxByd2S8JEAZXBl4qUrZf8GS0= +github.com/pganalyze/pg_query_go/v2 v2.1.0 h1:donwPZ4G/X+kMs7j5eYtKjdziqyOLVp3pkUrzb9lDl8= +github.com/pganalyze/pg_query_go/v2 v2.1.0/go.mod h1:XAxmVqz1tEGqizcQ3YSdN90vCOHBWjJi8URL1er5+cA= github.com/philhofer/fwd v1.0.0/go.mod h1:gk3iGcWd9+svBvR0sR+KPcfE+RNWozjowpeBVG3ZVNU= github.com/pierrec/lz4 v2.0.5+incompatible/go.mod h1:pdkljMzZIN41W+lC3N2tnIh5sFi+IEE17M5jbnwPHcY= github.com/pkg/errors v0.8.0/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= @@ -585,6 +588,8 @@ github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5Cc github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/syndtr/goleveldb v1.0.1-0.20210819022825-2ae1ddf74ef7 h1:epCh84lMvA70Z7CTTCmYQn2CKbY8j86K7/FAIr141uY= github.com/syndtr/goleveldb v1.0.1-0.20210819022825-2ae1ddf74ef7/go.mod h1:q4W45IWZaF22tdD+VEXcAWRA037jwmWEB5VWYORlTpc= +github.com/thoas/go-funk v0.9.2 h1:oKlNYv0AY5nyf9g+/GhMgS/UO2ces0QRdPKwkhY3VCk= +github.com/thoas/go-funk v0.9.2/go.mod h1:+IWnUfUmFO1+WVYQWQtIJHeRRdaIyyYglZN7xzUPe4Q= github.com/tinylib/msgp v1.0.2/go.mod h1:+d+yLhGm8mzTaHzB+wgMYrodPfmZrzkirds8fDWklFE= github.com/tklauser/go-sysconf v0.3.5 h1:uu3Xl4nkLzQfXNsWn15rPc/HQCJKObbt1dKJeWp3vU4= github.com/tklauser/go-sysconf v0.3.5/go.mod h1:MkWzOF4RMCshBAMXuhXJs64Rte09mITnppBXY/rYEFI= diff --git a/statediff/README.md b/statediff/README.md index ef9509a7db61..f262d7a8ecb4 100644 --- a/statediff/README.md +++ b/statediff/README.md @@ -120,6 +120,8 @@ This service introduces a CLI flag namespace `statediff` `--statediff.file.path` full path (including filename) to write statediff data out to when operating in file mode +`--statediff.file.wapath` full path (including filename) to write statediff watched addresses out to when operating in file mode + The service can only operate in full sync mode (`--syncmode=full`), but only the historical RPC endpoints require an archive node (`--gcmode=archive`) e.g. @@ -148,15 +150,13 @@ type Params struct { IncludeTD bool IncludeCode bool WatchedAddresses []common.Address - WatchedStorageSlots []common.Hash } ``` Using these params we can tell the service whether to include state and/or storage intermediate nodes; whether to include the associated block (header, uncles, and transactions); whether to include the associated receipts; whether to include the total difficulty for this block; whether to include the set of code hashes and code for -contracts deployed in this block; whether to limit the diffing process to a list of specific addresses; and/or -whether to limit the diffing process to a list of specific storage slot keys. +contracts deployed in this block; whether to limit the diffing process to a list of specific addresses. 
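As a minimal sketch (not taken from the patch itself), a caller might populate these params along the following lines, using the `ComputeWatchedAddressesLeafKeys` helper added in `statediff/config.go` further down in this diff; the import paths mirror those used elsewhere in this diff, the field values are assumptions, and the address literal is a placeholder borrowed from the new test fixtures:

```go
package main

import (
	"fmt"

	"github.com/ethereum/go-ethereum/common"
	"github.com/ethereum/go-ethereum/statediff"
)

func watchedParams() statediff.Params {
	params := statediff.Params{
		IntermediateStateNodes:   true,
		IntermediateStorageNodes: true,
		// Restrict the diff to these accounts; an empty list means everything
		// is watched (see isWatchedAddress in statediff/builder.go).
		WatchedAddresses: []common.Address{
			// placeholder address, borrowed from the new test fixtures
			common.HexToAddress("0x5d663F5269090bD2A7DC2390c911dF6083D7b28F"),
		},
	}
	// Precompute the Keccak256 leaf keys the builder compares state leaves
	// against; the builder tests in this patch call this before building diffs.
	params.ComputeWatchedAddressesLeafKeys()
	return params
}

func main() {
	p := watchedParams()
	fmt.Printf("watching %d address(es)\n", len(p.WatchedAddresses))
}
```

When running in file mode, the new `--statediff.file.wapath` flag points the file indexer at the SQL file where watched-address statements are persisted, alongside the existing `--statediff.file.path`.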
#### Subscription endpoint diff --git a/statediff/api.go b/statediff/api.go index 5c534cddb589..0a7c5bba8cc7 100644 --- a/statediff/api.go +++ b/statediff/api.go @@ -149,3 +149,8 @@ func (api *PublicStateDiffAPI) WriteStateDiffAt(ctx context.Context, blockNumber func (api *PublicStateDiffAPI) WriteStateDiffFor(ctx context.Context, blockHash common.Hash, params Params) error { return api.sds.WriteStateDiffFor(blockHash, params) } + +// WatchAddress changes the list of watched addresses to which the direct indexing is restricted according to given operation +func (api *PublicStateDiffAPI) WatchAddress(operation types.OperationType, args []types.WatchAddressArg) error { + return api.sds.WatchAddress(operation, args) +} diff --git a/statediff/builder.go b/statediff/builder.go index 7811c3e829c9..3811c4cce6d5 100644 --- a/statediff/builder.go +++ b/statediff/builder.go @@ -123,7 +123,7 @@ func (sdb *builder) buildStateTrie(it trie.NodeIterator) ([]types2.StateNode, [] node.LeafKey = leafKey if !bytes.Equal(account.CodeHash, nullCodeHash) { var storageNodes []types2.StorageNode - err := sdb.buildStorageNodesEventual(account.Root, nil, true, storageNodeAppender(&storageNodes)) + err := sdb.buildStorageNodesEventual(account.Root, true, storageNodeAppender(&storageNodes)) if err != nil { return nil, nil, fmt.Errorf("failed building eventual storage diffs for account %+v\r\nerror: %v", account, err) } @@ -202,7 +202,7 @@ func (sdb *builder) buildStateDiffWithIntermediateStateNodes(args types2.StateRo // a map of their leafkey to all the accounts that were touched and exist at A diffAccountsAtA, err := sdb.deletedOrUpdatedState( oldTrie.NodeIterator([]byte{}), newTrie.NodeIterator([]byte{}), - diffPathsAtB, output) + diffPathsAtB, params.watchedAddressesLeafKeys, output) if err != nil { return fmt.Errorf("error collecting deletedOrUpdatedNodes: %v", err) } @@ -220,12 +220,12 @@ func (sdb *builder) buildStateDiffWithIntermediateStateNodes(args types2.StateRo // build the diff nodes for the updated accounts using the mappings at both A and B as directed by the keys found as the intersection of the two err = sdb.buildAccountUpdates( diffAccountsAtB, diffAccountsAtA, updatedKeys, - params.WatchedStorageSlots, params.IntermediateStorageNodes, output) + params.IntermediateStorageNodes, output) if err != nil { return fmt.Errorf("error building diff for updated accounts: %v", err) } // build the diff nodes for created accounts - err = sdb.buildAccountCreations(diffAccountsAtB, params.WatchedStorageSlots, params.IntermediateStorageNodes, output, codeOutput) + err = sdb.buildAccountCreations(diffAccountsAtB, params.IntermediateStorageNodes, output, codeOutput) if err != nil { return fmt.Errorf("error building diff for created accounts: %v", err) } @@ -247,7 +247,7 @@ func (sdb *builder) buildStateDiffWithoutIntermediateStateNodes(args types2.Stat // and a slice of all the paths for the nodes in both of the above sets diffAccountsAtB, diffPathsAtB, err := sdb.createdAndUpdatedState( oldTrie.NodeIterator([]byte{}), newTrie.NodeIterator([]byte{}), - params.WatchedAddresses) + params.watchedAddressesLeafKeys) if err != nil { return fmt.Errorf("error collecting createdAndUpdatedNodes: %v", err) } @@ -256,7 +256,7 @@ func (sdb *builder) buildStateDiffWithoutIntermediateStateNodes(args types2.Stat // a map of their leafkey to all the accounts that were touched and exist at A diffAccountsAtA, err := sdb.deletedOrUpdatedState( oldTrie.NodeIterator([]byte{}), newTrie.NodeIterator([]byte{}), - diffPathsAtB, output) + 
diffPathsAtB, params.watchedAddressesLeafKeys, output) if err != nil { return fmt.Errorf("error collecting deletedOrUpdatedNodes: %v", err) } @@ -274,12 +274,12 @@ func (sdb *builder) buildStateDiffWithoutIntermediateStateNodes(args types2.Stat // build the diff nodes for the updated accounts using the mappings at both A and B as directed by the keys found as the intersection of the two err = sdb.buildAccountUpdates( diffAccountsAtB, diffAccountsAtA, updatedKeys, - params.WatchedStorageSlots, params.IntermediateStorageNodes, output) + params.IntermediateStorageNodes, output) if err != nil { return fmt.Errorf("error building diff for updated accounts: %v", err) } // build the diff nodes for created accounts - err = sdb.buildAccountCreations(diffAccountsAtB, params.WatchedStorageSlots, params.IntermediateStorageNodes, output, codeOutput) + err = sdb.buildAccountCreations(diffAccountsAtB, params.IntermediateStorageNodes, output, codeOutput) if err != nil { return fmt.Errorf("error building diff for created accounts: %v", err) } @@ -289,7 +289,7 @@ func (sdb *builder) buildStateDiffWithoutIntermediateStateNodes(args types2.Stat // createdAndUpdatedState returns // a mapping of their leafkeys to all the accounts that exist in a different state at B than A // and a slice of the paths for all of the nodes included in both -func (sdb *builder) createdAndUpdatedState(a, b trie.NodeIterator, watchedAddresses []common.Address) (types2.AccountMap, map[string]bool, error) { +func (sdb *builder) createdAndUpdatedState(a, b trie.NodeIterator, watchedAddressesLeafKeys map[common.Hash]struct{}) (types2.AccountMap, map[string]bool, error) { diffPathsAtB := make(map[string]bool) diffAcountsAtB := make(types2.AccountMap) it, _ := trie.NewDifferenceIterator(a, b) @@ -313,7 +313,7 @@ func (sdb *builder) createdAndUpdatedState(a, b trie.NodeIterator, watchedAddres valueNodePath := append(node.Path, partialPath...) encodedPath := trie.HexToCompact(valueNodePath) leafKey := encodedPath[1:] - if isWatchedAddress(watchedAddresses, leafKey) { + if isWatchedAddress(watchedAddressesLeafKeys, leafKey) { diffAcountsAtB[common.Bytes2Hex(leafKey)] = types2.AccountWrapper{ NodeType: node.NodeType, Path: node.Path, @@ -386,7 +386,7 @@ func (sdb *builder) createdAndUpdatedStateWithIntermediateNodes(a, b trie.NodeIt // deletedOrUpdatedState returns a slice of all the pathes that are emptied at B // and a mapping of their leafkeys to all the accounts that exist in a different state at A than B -func (sdb *builder) deletedOrUpdatedState(a, b trie.NodeIterator, diffPathsAtB map[string]bool, output types2.StateNodeSink) (types2.AccountMap, error) { +func (sdb *builder) deletedOrUpdatedState(a, b trie.NodeIterator, diffPathsAtB map[string]bool, watchedAddressesLeafKeys map[common.Hash]struct{}, output types2.StateNodeSink) (types2.AccountMap, error) { diffAccountAtA := make(types2.AccountMap) it, _ := trie.NewDifferenceIterator(b, a) for it.Next(true) { @@ -409,24 +409,26 @@ func (sdb *builder) deletedOrUpdatedState(a, b trie.NodeIterator, diffPathsAtB m valueNodePath := append(node.Path, partialPath...) 
encodedPath := trie.HexToCompact(valueNodePath) leafKey := encodedPath[1:] - diffAccountAtA[common.Bytes2Hex(leafKey)] = types2.AccountWrapper{ - NodeType: node.NodeType, - Path: node.Path, - NodeValue: node.NodeValue, - LeafKey: leafKey, - Account: &account, - } - // if this node's path did not show up in diffPathsAtB - // that means the node at this path was deleted (or moved) in B - // emit an empty "removed" diff to signify as such - if _, ok := diffPathsAtB[common.Bytes2Hex(node.Path)]; !ok { - if err := output(types2.StateNode{ + if isWatchedAddress(watchedAddressesLeafKeys, leafKey) { + diffAccountAtA[common.Bytes2Hex(leafKey)] = types2.AccountWrapper{ + NodeType: node.NodeType, Path: node.Path, - NodeValue: []byte{}, - NodeType: types2.Removed, + NodeValue: node.NodeValue, LeafKey: leafKey, - }); err != nil { - return nil, err + Account: &account, + } + // if this node's path did not show up in diffPathsAtB + // that means the node at this path was deleted (or moved) in B + // emit an empty "removed" diff to signify as such + if _, ok := diffPathsAtB[common.Bytes2Hex(node.Path)]; !ok { + if err := output(types2.StateNode{ + Path: node.Path, + NodeValue: []byte{}, + NodeType: types2.Removed, + LeafKey: leafKey, + }); err != nil { + return nil, err + } } } case types2.Extension, types2.Branch: @@ -454,8 +456,7 @@ func (sdb *builder) deletedOrUpdatedState(a, b trie.NodeIterator, diffPathsAtB m // to generate the statediff node objects for all of the accounts that existed at both A and B but in different states // needs to be called before building account creations and deletions as this mutates // those account maps to remove the accounts which were updated -func (sdb *builder) buildAccountUpdates(creations, deletions types2.AccountMap, updatedKeys []string, - watchedStorageKeys []common.Hash, intermediateStorageNodes bool, output types2.StateNodeSink) error { +func (sdb *builder) buildAccountUpdates(creations, deletions types2.AccountMap, updatedKeys []string, intermediateStorageNodes bool, output types2.StateNodeSink) error { var err error for _, key := range updatedKeys { createdAcc := creations[key] @@ -465,7 +466,7 @@ func (sdb *builder) buildAccountUpdates(creations, deletions types2.AccountMap, oldSR := deletedAcc.Account.Root newSR := createdAcc.Account.Root err = sdb.buildStorageNodesIncremental( - oldSR, newSR, watchedStorageKeys, intermediateStorageNodes, + oldSR, newSR, intermediateStorageNodes, storageNodeAppender(&storageDiffs)) if err != nil { return fmt.Errorf("failed building incremental storage diffs for account with leafkey %s\r\nerror: %v", key, err) @@ -489,7 +490,7 @@ func (sdb *builder) buildAccountUpdates(creations, deletions types2.AccountMap, // buildAccountCreations returns the statediff node objects for all the accounts that exist at B but not at A // it also returns the code and codehash for created contract accounts -func (sdb *builder) buildAccountCreations(accounts types2.AccountMap, watchedStorageKeys []common.Hash, intermediateStorageNodes bool, output types2.StateNodeSink, codeOutput types2.CodeSink) error { +func (sdb *builder) buildAccountCreations(accounts types2.AccountMap, intermediateStorageNodes bool, output types2.StateNodeSink, codeOutput types2.CodeSink) error { for _, val := range accounts { diff := types2.StateNode{ NodeType: val.NodeType, @@ -500,7 +501,7 @@ func (sdb *builder) buildAccountCreations(accounts types2.AccountMap, watchedSto if !bytes.Equal(val.Account.CodeHash, nullCodeHash) { // For contract creations, any storage node 
contained is a diff var storageDiffs []types2.StorageNode - err := sdb.buildStorageNodesEventual(val.Account.Root, watchedStorageKeys, intermediateStorageNodes, storageNodeAppender(&storageDiffs)) + err := sdb.buildStorageNodesEventual(val.Account.Root, intermediateStorageNodes, storageNodeAppender(&storageDiffs)) if err != nil { return fmt.Errorf("failed building eventual storage diffs for node %x\r\nerror: %v", val.Path, err) } @@ -528,7 +529,7 @@ func (sdb *builder) buildAccountCreations(accounts types2.AccountMap, watchedSto // buildStorageNodesEventual builds the storage diff node objects for a created account // i.e. it returns all the storage nodes at this state, since there is no previous state -func (sdb *builder) buildStorageNodesEventual(sr common.Hash, watchedStorageKeys []common.Hash, intermediateNodes bool, output types2.StorageNodeSink) error { +func (sdb *builder) buildStorageNodesEventual(sr common.Hash, intermediateNodes bool, output types2.StorageNodeSink) error { if bytes.Equal(sr.Bytes(), emptyContractRoot.Bytes()) { return nil } @@ -539,7 +540,7 @@ func (sdb *builder) buildStorageNodesEventual(sr common.Hash, watchedStorageKeys return err } it := sTrie.NodeIterator(make([]byte, 0)) - err = sdb.buildStorageNodesFromTrie(it, watchedStorageKeys, intermediateNodes, output) + err = sdb.buildStorageNodesFromTrie(it, intermediateNodes, output) if err != nil { return err } @@ -549,7 +550,7 @@ func (sdb *builder) buildStorageNodesEventual(sr common.Hash, watchedStorageKeys // buildStorageNodesFromTrie returns all the storage diff node objects in the provided node interator // if any storage keys are provided it will only return those leaf nodes // including intermediate nodes can be turned on or off -func (sdb *builder) buildStorageNodesFromTrie(it trie.NodeIterator, watchedStorageKeys []common.Hash, intermediateNodes bool, output types2.StorageNodeSink) error { +func (sdb *builder) buildStorageNodesFromTrie(it trie.NodeIterator, intermediateNodes bool, output types2.StorageNodeSink) error { for it.Next(true) { // skip value nodes if it.Leaf() || bytes.Equal(nullHashBytes, it.Hash().Bytes()) { @@ -565,15 +566,13 @@ func (sdb *builder) buildStorageNodesFromTrie(it trie.NodeIterator, watchedStora valueNodePath := append(node.Path, partialPath...) 
encodedPath := trie.HexToCompact(valueNodePath) leafKey := encodedPath[1:] - if isWatchedStorageKey(watchedStorageKeys, leafKey) { - if err := output(types2.StorageNode{ - NodeType: node.NodeType, - Path: node.Path, - NodeValue: node.NodeValue, - LeafKey: leafKey, - }); err != nil { - return err - } + if err := output(types2.StorageNode{ + NodeType: node.NodeType, + Path: node.Path, + NodeValue: node.NodeValue, + LeafKey: leafKey, + }); err != nil { + return err } case types2.Extension, types2.Branch: if intermediateNodes { @@ -593,7 +592,7 @@ func (sdb *builder) buildStorageNodesFromTrie(it trie.NodeIterator, watchedStora } // buildStorageNodesIncremental builds the storage diff node objects for all nodes that exist in a different state at B than A -func (sdb *builder) buildStorageNodesIncremental(oldSR common.Hash, newSR common.Hash, watchedStorageKeys []common.Hash, intermediateNodes bool, output types2.StorageNodeSink) error { +func (sdb *builder) buildStorageNodesIncremental(oldSR common.Hash, newSR common.Hash, intermediateNodes bool, output types2.StorageNodeSink) error { if bytes.Equal(newSR.Bytes(), oldSR.Bytes()) { return nil } @@ -609,19 +608,19 @@ func (sdb *builder) buildStorageNodesIncremental(oldSR common.Hash, newSR common diffPathsAtB, err := sdb.createdAndUpdatedStorage( oldTrie.NodeIterator([]byte{}), newTrie.NodeIterator([]byte{}), - watchedStorageKeys, intermediateNodes, output) + intermediateNodes, output) if err != nil { return err } err = sdb.deletedOrUpdatedStorage(oldTrie.NodeIterator([]byte{}), newTrie.NodeIterator([]byte{}), - diffPathsAtB, watchedStorageKeys, intermediateNodes, output) + diffPathsAtB, intermediateNodes, output) if err != nil { return err } return nil } -func (sdb *builder) createdAndUpdatedStorage(a, b trie.NodeIterator, watchedKeys []common.Hash, intermediateNodes bool, output types2.StorageNodeSink) (map[string]bool, error) { +func (sdb *builder) createdAndUpdatedStorage(a, b trie.NodeIterator, intermediateNodes bool, output types2.StorageNodeSink) (map[string]bool, error) { diffPathsAtB := make(map[string]bool) it, _ := trie.NewDifferenceIterator(a, b) for it.Next(true) { @@ -639,15 +638,13 @@ func (sdb *builder) createdAndUpdatedStorage(a, b trie.NodeIterator, watchedKeys valueNodePath := append(node.Path, partialPath...) 
encodedPath := trie.HexToCompact(valueNodePath) leafKey := encodedPath[1:] - if isWatchedStorageKey(watchedKeys, leafKey) { - if err := output(types2.StorageNode{ - NodeType: node.NodeType, - Path: node.Path, - NodeValue: node.NodeValue, - LeafKey: leafKey, - }); err != nil { - return nil, err - } + if err := output(types2.StorageNode{ + NodeType: node.NodeType, + Path: node.Path, + NodeValue: node.NodeValue, + LeafKey: leafKey, + }); err != nil { + return nil, err } case types2.Extension, types2.Branch: if intermediateNodes { @@ -667,7 +664,7 @@ func (sdb *builder) createdAndUpdatedStorage(a, b trie.NodeIterator, watchedKeys return diffPathsAtB, it.Error() } -func (sdb *builder) deletedOrUpdatedStorage(a, b trie.NodeIterator, diffPathsAtB map[string]bool, watchedKeys []common.Hash, intermediateNodes bool, output types2.StorageNodeSink) error { +func (sdb *builder) deletedOrUpdatedStorage(a, b trie.NodeIterator, diffPathsAtB map[string]bool, intermediateNodes bool, output types2.StorageNodeSink) error { it, _ := trie.NewDifferenceIterator(b, a) for it.Next(true) { // skip value nodes @@ -690,15 +687,13 @@ func (sdb *builder) deletedOrUpdatedStorage(a, b trie.NodeIterator, diffPathsAtB valueNodePath := append(node.Path, partialPath...) encodedPath := trie.HexToCompact(valueNodePath) leafKey := encodedPath[1:] - if isWatchedStorageKey(watchedKeys, leafKey) { - if err := output(types2.StorageNode{ - NodeType: types2.Removed, - Path: node.Path, - NodeValue: []byte{}, - LeafKey: leafKey, - }); err != nil { - return err - } + if err := output(types2.StorageNode{ + NodeType: types2.Removed, + Path: node.Path, + NodeValue: []byte{}, + LeafKey: leafKey, + }); err != nil { + return err } case types2.Extension, types2.Branch: if intermediateNodes { @@ -718,30 +713,12 @@ func (sdb *builder) deletedOrUpdatedStorage(a, b trie.NodeIterator, diffPathsAtB } // isWatchedAddress is used to check if a state account corresponds to one of the addresses the builder is configured to watch -func isWatchedAddress(watchedAddresses []common.Address, stateLeafKey []byte) bool { +func isWatchedAddress(watchedAddressesLeafKeys map[common.Hash]struct{}, stateLeafKey []byte) bool { // If we aren't watching any specific addresses, we are watching everything - if len(watchedAddresses) == 0 { + if len(watchedAddressesLeafKeys) == 0 { return true } - for _, addr := range watchedAddresses { - addrHashKey := crypto.Keccak256(addr.Bytes()) - if bytes.Equal(addrHashKey, stateLeafKey) { - return true - } - } - return false -} -// isWatchedStorageKey is used to check if a storage leaf corresponds to one of the storage slots the builder is configured to watch -func isWatchedStorageKey(watchedKeys []common.Hash, storageLeafKey []byte) bool { - // If we aren't watching any specific addresses, we are watching everything - if len(watchedKeys) == 0 { - return true - } - for _, hashKey := range watchedKeys { - if bytes.Equal(hashKey.Bytes(), storageLeafKey) { - return true - } - } - return false + _, ok := watchedAddressesLeafKeys[common.BytesToHash(stateLeafKey)] + return ok } diff --git a/statediff/builder_test.go b/statediff/builder_test.go index d4d67940e32f..5edc1152361f 100644 --- a/statediff/builder_test.go +++ b/statediff/builder_test.go @@ -988,6 +988,7 @@ func TestBuilderWithWatchedAddressList(t *testing.T) { params := statediff.Params{ WatchedAddresses: []common.Address{test_helpers.Account1Addr, test_helpers.ContractAddr}, } + params.ComputeWatchedAddressesLeafKeys() builder = statediff.NewBuilder(chain.StateCache()) var tests 
= []struct { @@ -1152,17 +1153,17 @@ func TestBuilderWithWatchedAddressList(t *testing.T) { } } -func TestBuilderWithWatchedAddressAndStorageKeyList(t *testing.T) { - blocks, chain := test_helpers.MakeChain(3, test_helpers.Genesis, test_helpers.TestChainGen) +func TestBuilderWithRemovedAccountAndStorage(t *testing.T) { + blocks, chain := test_helpers.MakeChain(6, test_helpers.Genesis, test_helpers.TestChainGen) contractLeafKey = test_helpers.AddressToLeafKey(test_helpers.ContractAddr) defer chain.Stop() - block0 = test_helpers.Genesis - block1 = blocks[0] - block2 = blocks[1] block3 = blocks[2] + block4 = blocks[3] + block5 = blocks[4] + block6 = blocks[5] params := statediff.Params{ - WatchedAddresses: []common.Address{test_helpers.Account1Addr, test_helpers.ContractAddr}, - WatchedStorageSlots: []common.Hash{slot1StorageKey}, + IntermediateStateNodes: true, + IntermediateStorageNodes: true, } builder = statediff.NewBuilder(chain.StateCache()) @@ -1171,83 +1172,121 @@ func TestBuilderWithWatchedAddressAndStorageKeyList(t *testing.T) { startingArguments statediff.Args expected *types2.StateObject }{ + // blocks 0-3 are the same as in TestBuilderWithIntermediateNodes { - "testEmptyDiff", - statediff.Args{ - OldStateRoot: block0.Root(), - NewStateRoot: block0.Root(), - BlockNumber: block0.Number(), - BlockHash: block0.Hash(), - }, - &types2.StateObject{ - BlockNumber: block0.Number(), - BlockHash: block0.Hash(), - Nodes: emptyDiffs, - }, - }, - { - "testBlock0", - //10000 transferred from testBankAddress to account1Addr - statediff.Args{ - OldStateRoot: test_helpers.NullHash, - NewStateRoot: block0.Root(), - BlockNumber: block0.Number(), - BlockHash: block0.Hash(), - }, - &types2.StateObject{ - BlockNumber: block0.Number(), - BlockHash: block0.Hash(), - Nodes: emptyDiffs, - }, - }, - { - "testBlock1", - //10000 transferred from testBankAddress to account1Addr + "testBlock4", statediff.Args{ - OldStateRoot: block0.Root(), - NewStateRoot: block1.Root(), - BlockNumber: block1.Number(), - BlockHash: block1.Hash(), + OldStateRoot: block3.Root(), + NewStateRoot: block4.Root(), + BlockNumber: block4.Number(), + BlockHash: block4.Hash(), }, &types2.StateObject{ - BlockNumber: block1.Number(), - BlockHash: block1.Hash(), + BlockNumber: block4.Number(), + BlockHash: block4.Hash(), Nodes: []types2.StateNode{ { - Path: []byte{'\x0e'}, + Path: []byte{}, + NodeType: types2.Branch, + NodeValue: block4BranchRootNode, + StorageNodes: emptyStorage, + }, + { + Path: []byte{'\x00'}, NodeType: types2.Leaf, - LeafKey: test_helpers.Account1LeafKey, - NodeValue: account1AtBlock1LeafNode, + LeafKey: test_helpers.BankLeafKey, + NodeValue: bankAccountAtBlock4LeafNode, + StorageNodes: emptyStorage, + }, + { + Path: []byte{'\x06'}, + NodeType: types2.Leaf, + LeafKey: contractLeafKey, + NodeValue: contractAccountAtBlock4LeafNode, + StorageNodes: []types2.StorageNode{ + { + Path: []byte{}, + NodeType: types2.Branch, + NodeValue: block4StorageBranchRootNode, + }, + { + Path: []byte{'\x04'}, + NodeType: types2.Leaf, + LeafKey: slot2StorageKey.Bytes(), + NodeValue: slot2StorageLeafNode, + }, + { + Path: []byte{'\x0b'}, + NodeType: types2.Removed, + LeafKey: slot1StorageKey.Bytes(), + NodeValue: []byte{}, + }, + { + Path: []byte{'\x0c'}, + NodeType: types2.Removed, + LeafKey: slot3StorageKey.Bytes(), + NodeValue: []byte{}, + }, + }, + }, + { + Path: []byte{'\x0c'}, + NodeType: types2.Leaf, + LeafKey: test_helpers.Account2LeafKey, + NodeValue: account2AtBlock4LeafNode, StorageNodes: emptyStorage, }, }, }, }, { - "testBlock2", 
- //1000 transferred from testBankAddress to account1Addr - //1000 transferred from account1Addr to account2Addr + "testBlock5", statediff.Args{ - OldStateRoot: block1.Root(), - NewStateRoot: block2.Root(), - BlockNumber: block2.Number(), - BlockHash: block2.Hash(), + OldStateRoot: block4.Root(), + NewStateRoot: block5.Root(), + BlockNumber: block5.Number(), + BlockHash: block5.Hash(), }, &types2.StateObject{ - BlockNumber: block2.Number(), - BlockHash: block2.Hash(), + BlockNumber: block5.Number(), + BlockHash: block5.Hash(), Nodes: []types2.StateNode{ + { + Path: []byte{}, + NodeType: types2.Branch, + NodeValue: block5BranchRootNode, + StorageNodes: emptyStorage, + }, + { + Path: []byte{'\x00'}, + NodeType: types2.Leaf, + LeafKey: test_helpers.BankLeafKey, + NodeValue: bankAccountAtBlock5LeafNode, + StorageNodes: emptyStorage, + }, { Path: []byte{'\x06'}, NodeType: types2.Leaf, LeafKey: contractLeafKey, - NodeValue: contractAccountAtBlock2LeafNode, + NodeValue: contractAccountAtBlock5LeafNode, StorageNodes: []types2.StorageNode{ { - Path: []byte{'\x0b'}, + Path: []byte{}, NodeType: types2.Leaf, - LeafKey: slot1StorageKey.Bytes(), - NodeValue: slot1StorageLeafNode, + NodeValue: slot0StorageLeafRootNode, + LeafKey: slot0StorageKey.Bytes(), + }, + { + Path: []byte{'\x02'}, + NodeType: types2.Removed, + LeafKey: slot0StorageKey.Bytes(), + NodeValue: []byte{}, + }, + { + Path: []byte{'\x04'}, + NodeType: types2.Removed, + LeafKey: slot2StorageKey.Bytes(), + NodeValue: []byte{}, }, }, }, @@ -1255,37 +1294,49 @@ func TestBuilderWithWatchedAddressAndStorageKeyList(t *testing.T) { Path: []byte{'\x0e'}, NodeType: types2.Leaf, LeafKey: test_helpers.Account1LeafKey, - NodeValue: account1AtBlock2LeafNode, + NodeValue: account1AtBlock5LeafNode, StorageNodes: emptyStorage, }, }, - CodeAndCodeHashes: []types2.CodeAndCodeHash{ - { - Hash: test_helpers.CodeHash, - Code: test_helpers.ByteCodeAfterDeployment, - }, - }, }, }, { - "testBlock3", - //the contract's storage is changed - //and the block is mined by account 2 + "testBlock6", statediff.Args{ - OldStateRoot: block2.Root(), - NewStateRoot: block3.Root(), - BlockNumber: block3.Number(), - BlockHash: block3.Hash(), + OldStateRoot: block5.Root(), + NewStateRoot: block6.Root(), + BlockNumber: block6.Number(), + BlockHash: block6.Hash(), }, &types2.StateObject{ - BlockNumber: block3.Number(), - BlockHash: block3.Hash(), + BlockNumber: block6.Number(), + BlockHash: block6.Hash(), Nodes: []types2.StateNode{ + { + Path: []byte{}, + NodeType: types2.Branch, + NodeValue: block6BranchRootNode, + StorageNodes: emptyStorage, + }, { Path: []byte{'\x06'}, - NodeType: types2.Leaf, + NodeType: types2.Removed, LeafKey: contractLeafKey, - NodeValue: contractAccountAtBlock3LeafNode, + NodeValue: []byte{}, + StorageNodes: emptyStorage, + }, + { + Path: []byte{'\x0c'}, + NodeType: types2.Leaf, + LeafKey: test_helpers.Account2LeafKey, + NodeValue: account2AtBlock6LeafNode, + StorageNodes: emptyStorage, + }, + { + Path: []byte{'\x0e'}, + NodeType: types2.Leaf, + LeafKey: test_helpers.Account1LeafKey, + NodeValue: account1AtBlock6LeafNode, StorageNodes: emptyStorage, }, }, @@ -1310,12 +1361,12 @@ func TestBuilderWithWatchedAddressAndStorageKeyList(t *testing.T) { sort.Slice(expectedStateDiffRlp, func(i, j int) bool { return expectedStateDiffRlp[i] < expectedStateDiffRlp[j] }) if !bytes.Equal(receivedStateDiffRlp, expectedStateDiffRlp) { t.Logf("Test failed: %s", test.name) - t.Errorf("actual state diff: %+v\nexpected state diff: %+v", diff, test.expected) + t.Errorf("actual 
state diff: %+v\r\n\r\n\r\nexpected state diff: %+v", diff, test.expected) } } } -func TestBuilderWithRemovedAccountAndStorage(t *testing.T) { +func TestBuilderWithRemovedAccountAndStorageWithoutIntermediateNodes(t *testing.T) { blocks, chain := test_helpers.MakeChain(6, test_helpers.Genesis, test_helpers.TestChainGen) contractLeafKey = test_helpers.AddressToLeafKey(test_helpers.ContractAddr) defer chain.Stop() @@ -1324,8 +1375,8 @@ func TestBuilderWithRemovedAccountAndStorage(t *testing.T) { block5 = blocks[4] block6 = blocks[5] params := statediff.Params{ - IntermediateStateNodes: true, - IntermediateStorageNodes: true, + IntermediateStateNodes: false, + IntermediateStorageNodes: false, } builder = statediff.NewBuilder(chain.StateCache()) @@ -1347,12 +1398,6 @@ func TestBuilderWithRemovedAccountAndStorage(t *testing.T) { BlockNumber: block4.Number(), BlockHash: block4.Hash(), Nodes: []types2.StateNode{ - { - Path: []byte{}, - NodeType: types2.Branch, - NodeValue: block4BranchRootNode, - StorageNodes: emptyStorage, - }, { Path: []byte{'\x00'}, NodeType: types2.Leaf, @@ -1366,11 +1411,6 @@ func TestBuilderWithRemovedAccountAndStorage(t *testing.T) { LeafKey: contractLeafKey, NodeValue: contractAccountAtBlock4LeafNode, StorageNodes: []types2.StorageNode{ - { - Path: []byte{}, - NodeType: types2.Branch, - NodeValue: block4StorageBranchRootNode, - }, { Path: []byte{'\x04'}, NodeType: types2.Leaf, @@ -1413,12 +1453,6 @@ func TestBuilderWithRemovedAccountAndStorage(t *testing.T) { BlockNumber: block5.Number(), BlockHash: block5.Hash(), Nodes: []types2.StateNode{ - { - Path: []byte{}, - NodeType: types2.Branch, - NodeValue: block5BranchRootNode, - StorageNodes: emptyStorage, - }, { Path: []byte{'\x00'}, NodeType: types2.Leaf, @@ -1435,8 +1469,8 @@ func TestBuilderWithRemovedAccountAndStorage(t *testing.T) { { Path: []byte{}, NodeType: types2.Leaf, - NodeValue: slot0StorageLeafRootNode, LeafKey: slot0StorageKey.Bytes(), + NodeValue: slot0StorageLeafRootNode, }, { Path: []byte{'\x02'}, @@ -1475,17 +1509,10 @@ func TestBuilderWithRemovedAccountAndStorage(t *testing.T) { BlockHash: block6.Hash(), Nodes: []types2.StateNode{ { - Path: []byte{}, - NodeType: types2.Branch, - NodeValue: block6BranchRootNode, - StorageNodes: emptyStorage, - }, - { - Path: []byte{'\x06'}, - NodeType: types2.Removed, - LeafKey: contractLeafKey, - NodeValue: []byte{}, - StorageNodes: emptyStorage, + Path: []byte{'\x06'}, + NodeType: types2.Removed, + LeafKey: contractLeafKey, + NodeValue: []byte{}, }, { Path: []byte{'\x0c'}, @@ -1515,10 +1542,12 @@ func TestBuilderWithRemovedAccountAndStorage(t *testing.T) { if err != nil { t.Error(err) } + expectedStateDiffRlp, err := rlp.EncodeToBytes(test.expected) if err != nil { t.Error(err) } + sort.Slice(receivedStateDiffRlp, func(i, j int) bool { return receivedStateDiffRlp[i] < receivedStateDiffRlp[j] }) sort.Slice(expectedStateDiffRlp, func(i, j int) bool { return expectedStateDiffRlp[i] < expectedStateDiffRlp[j] }) if !bytes.Equal(receivedStateDiffRlp, expectedStateDiffRlp) { @@ -1528,7 +1557,7 @@ func TestBuilderWithRemovedAccountAndStorage(t *testing.T) { } } -func TestBuilderWithRemovedAccountAndStorageWithoutIntermediateNodes(t *testing.T) { +func TestBuilderWithRemovedNonWatchedAccount(t *testing.T) { blocks, chain := test_helpers.MakeChain(6, test_helpers.Genesis, test_helpers.TestChainGen) contractLeafKey = test_helpers.AddressToLeafKey(test_helpers.ContractAddr) defer chain.Stop() @@ -1537,9 +1566,9 @@ func TestBuilderWithRemovedAccountAndStorageWithoutIntermediateNodes(t 
*testing. block5 = blocks[4] block6 = blocks[5] params := statediff.Params{ - IntermediateStateNodes: false, - IntermediateStorageNodes: false, + WatchedAddresses: []common.Address{test_helpers.Account1Addr, test_helpers.Account2Addr}, } + params.ComputeWatchedAddressesLeafKeys() builder = statediff.NewBuilder(chain.StateCache()) var tests = []struct { @@ -1547,7 +1576,6 @@ func TestBuilderWithRemovedAccountAndStorageWithoutIntermediateNodes(t *testing. startingArguments statediff.Args expected *types2.StateObject }{ - // blocks 0-3 are the same as in TestBuilderWithIntermediateNodes { "testBlock4", statediff.Args{ @@ -1561,12 +1589,123 @@ func TestBuilderWithRemovedAccountAndStorageWithoutIntermediateNodes(t *testing. BlockHash: block4.Hash(), Nodes: []types2.StateNode{ { - Path: []byte{'\x00'}, + Path: []byte{'\x0c'}, NodeType: types2.Leaf, - LeafKey: test_helpers.BankLeafKey, - NodeValue: bankAccountAtBlock4LeafNode, + LeafKey: test_helpers.Account2LeafKey, + NodeValue: account2AtBlock4LeafNode, StorageNodes: emptyStorage, }, + }, + }, + }, + { + "testBlock5", + statediff.Args{ + OldStateRoot: block4.Root(), + NewStateRoot: block5.Root(), + BlockNumber: block5.Number(), + BlockHash: block5.Hash(), + }, + &types2.StateObject{ + BlockNumber: block5.Number(), + BlockHash: block5.Hash(), + Nodes: []types2.StateNode{ + { + Path: []byte{'\x0e'}, + NodeType: types2.Leaf, + LeafKey: test_helpers.Account1LeafKey, + NodeValue: account1AtBlock5LeafNode, + StorageNodes: emptyStorage, + }, + }, + }, + }, + { + "testBlock6", + statediff.Args{ + OldStateRoot: block5.Root(), + NewStateRoot: block6.Root(), + BlockNumber: block6.Number(), + BlockHash: block6.Hash(), + }, + &types2.StateObject{ + BlockNumber: block6.Number(), + BlockHash: block6.Hash(), + Nodes: []types2.StateNode{ + { + Path: []byte{'\x0c'}, + NodeType: types2.Leaf, + LeafKey: test_helpers.Account2LeafKey, + NodeValue: account2AtBlock6LeafNode, + StorageNodes: emptyStorage, + }, + { + Path: []byte{'\x0e'}, + NodeType: types2.Leaf, + LeafKey: test_helpers.Account1LeafKey, + NodeValue: account1AtBlock6LeafNode, + StorageNodes: emptyStorage, + }, + }, + }, + }, + } + + for _, test := range tests { + diff, err := builder.BuildStateDiffObject(test.startingArguments, params) + if err != nil { + t.Error(err) + } + receivedStateDiffRlp, err := rlp.EncodeToBytes(diff) + if err != nil { + t.Error(err) + } + + expectedStateDiffRlp, err := rlp.EncodeToBytes(test.expected) + if err != nil { + t.Error(err) + } + + sort.Slice(receivedStateDiffRlp, func(i, j int) bool { return receivedStateDiffRlp[i] < receivedStateDiffRlp[j] }) + sort.Slice(expectedStateDiffRlp, func(i, j int) bool { return expectedStateDiffRlp[i] < expectedStateDiffRlp[j] }) + if !bytes.Equal(receivedStateDiffRlp, expectedStateDiffRlp) { + t.Logf("Test failed: %s", test.name) + t.Errorf("actual state diff: %+v\r\n\r\n\r\nexpected state diff: %+v", diff, test.expected) + } + } +} + +func TestBuilderWithRemovedWatchedAccount(t *testing.T) { + blocks, chain := test_helpers.MakeChain(6, test_helpers.Genesis, test_helpers.TestChainGen) + contractLeafKey = test_helpers.AddressToLeafKey(test_helpers.ContractAddr) + defer chain.Stop() + block3 = blocks[2] + block4 = blocks[3] + block5 = blocks[4] + block6 = blocks[5] + params := statediff.Params{ + WatchedAddresses: []common.Address{test_helpers.Account1Addr, test_helpers.ContractAddr}, + } + params.ComputeWatchedAddressesLeafKeys() + builder = statediff.NewBuilder(chain.StateCache()) + + var tests = []struct { + name string + 
startingArguments statediff.Args + expected *types2.StateObject + }{ + { + "testBlock4", + statediff.Args{ + OldStateRoot: block3.Root(), + NewStateRoot: block4.Root(), + BlockNumber: block4.Number(), + BlockHash: block4.Hash(), + }, + &types2.StateObject{ + BlockNumber: block4.Number(), + BlockHash: block4.Hash(), + Nodes: []types2.StateNode{ { Path: []byte{'\x06'}, NodeType: types2.Leaf, @@ -1593,13 +1732,6 @@ func TestBuilderWithRemovedAccountAndStorageWithoutIntermediateNodes(t *testing. }, }, }, - { - Path: []byte{'\x0c'}, - NodeType: types2.Leaf, - LeafKey: test_helpers.Account2LeafKey, - NodeValue: account2AtBlock4LeafNode, - StorageNodes: emptyStorage, - }, }, }, }, @@ -1615,13 +1747,6 @@ func TestBuilderWithRemovedAccountAndStorageWithoutIntermediateNodes(t *testing. BlockNumber: block5.Number(), BlockHash: block5.Hash(), Nodes: []types2.StateNode{ - { - Path: []byte{'\x00'}, - NodeType: types2.Leaf, - LeafKey: test_helpers.BankLeafKey, - NodeValue: bankAccountAtBlock5LeafNode, - StorageNodes: emptyStorage, - }, { Path: []byte{'\x06'}, NodeType: types2.Leaf, @@ -1676,13 +1801,6 @@ func TestBuilderWithRemovedAccountAndStorageWithoutIntermediateNodes(t *testing. LeafKey: contractLeafKey, NodeValue: []byte{}, }, - { - Path: []byte{'\x0c'}, - NodeType: types2.Leaf, - LeafKey: test_helpers.Account2LeafKey, - NodeValue: account2AtBlock6LeafNode, - StorageNodes: emptyStorage, - }, { Path: []byte{'\x0e'}, NodeType: types2.Leaf, diff --git a/statediff/config.go b/statediff/config.go index f20f3267ee9d..b4905ab5a75b 100644 --- a/statediff/config.go +++ b/statediff/config.go @@ -19,8 +19,10 @@ package statediff import ( "context" "math/big" + "sync" "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/crypto" "github.com/ethereum/go-ethereum/statediff/indexer/interfaces" ) @@ -53,7 +55,21 @@ type Params struct { IncludeTD bool IncludeCode bool WatchedAddresses []common.Address - WatchedStorageSlots []common.Hash + watchedAddressesLeafKeys map[common.Hash]struct{} +} + +// ComputeWatchedAddressesLeafKeys populates a map with keys (Keccak256Hash) of each of the WatchedAddresses +func (p *Params) ComputeWatchedAddressesLeafKeys() { + p.watchedAddressesLeafKeys = make(map[common.Hash]struct{}, len(p.WatchedAddresses)) + for _, address := range p.WatchedAddresses { + p.watchedAddressesLeafKeys[crypto.Keccak256Hash(address.Bytes())] = struct{}{} + } +} + +// ParamsWithMutex allows to lock the parameters while they are being updated | read from +type ParamsWithMutex struct { + Params + sync.RWMutex } // Args bundles the arguments for the state diff builder diff --git a/statediff/indexer/database/dump/indexer.go b/statediff/indexer/database/dump/indexer.go index e450f941ac2f..fb9865f8d684 100644 --- a/statediff/indexer/database/dump/indexer.go +++ b/statediff/indexer/database/dump/indexer.go @@ -496,3 +496,28 @@ func (sdi *StateDiffIndexer) PushCodeAndCodeHash(batch interfaces.Batch, codeAnd func (sdi *StateDiffIndexer) Close() error { return sdi.dump.Close() } + +// LoadWatchedAddresses satisfies the interfaces.StateDiffIndexer interface +func (sdi *StateDiffIndexer) LoadWatchedAddresses() ([]common.Address, error) { + return nil, nil +} + +// InsertWatchedAddresses satisfies the interfaces.StateDiffIndexer interface +func (sdi *StateDiffIndexer) InsertWatchedAddresses(args []sdtypes.WatchAddressArg, currentBlockNumber *big.Int) error { + return nil +} + +// RemoveWatchedAddresses satisfies the interfaces.StateDiffIndexer interface +func (sdi *StateDiffIndexer) 
RemoveWatchedAddresses(args []sdtypes.WatchAddressArg) error { + return nil +} + +// SetWatchedAddresses satisfies the interfaces.StateDiffIndexer interface +func (sdi *StateDiffIndexer) SetWatchedAddresses(args []sdtypes.WatchAddressArg, currentBlockNumber *big.Int) error { + return nil +} + +// ClearWatchedAddresses satisfies the interfaces.StateDiffIndexer interface +func (sdi *StateDiffIndexer) ClearWatchedAddresses() error { + return nil +} diff --git a/statediff/indexer/database/file/config.go b/statediff/indexer/database/file/config.go index c2c6804c0f97..a075e896b3d7 100644 --- a/statediff/indexer/database/file/config.go +++ b/statediff/indexer/database/file/config.go @@ -23,8 +23,9 @@ import ( // Config holds params for writing sql statements out to a file type Config struct { - FilePath string - NodeInfo node.Info + FilePath string + WatchedAddressesFilePath string + NodeInfo node.Info } // Type satisfies interfaces.Config @@ -34,7 +35,8 @@ func (c Config) Type() shared.DBType { // TestConfig config for unit tests var TestConfig = Config{ - FilePath: "./statediffing_test_file.sql", + FilePath: "./statediffing_test_file.sql", + WatchedAddressesFilePath: "./statediffing_watched_addresses_test_file.sql", NodeInfo: node.Info{ GenesisBlock: "0xd4e56740f876aef8c010b86a40d5f56745a118d0906a34e69aec8c0db1cb8fa3", NetworkID: "1", diff --git a/statediff/indexer/database/file/indexer.go b/statediff/indexer/database/file/indexer.go index 870c1f259ce2..c842cef12add 100644 --- a/statediff/indexer/database/file/indexer.go +++ b/statediff/indexer/database/file/indexer.go @@ -17,6 +17,7 @@ package file import ( + "bufio" "context" "errors" "fmt" @@ -28,6 +29,8 @@ import ( "github.com/ipfs/go-cid" node "github.com/ipfs/go-ipld-format" "github.com/multiformats/go-multihash" + pg_query "github.com/pganalyze/pg_query_go/v2" + "github.com/thoas/go-funk" "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/core/types" @@ -44,6 +47,9 @@ import ( ) const defaultFilePath = "./statediff.sql" +const defaultWatchedAddressesFilePath = "./statediff-watched-addresses.sql" + +const watchedAddressesInsert = "INSERT INTO eth_meta.watched_addresses (address, created_at, watched_at) VALUES ('%s', '%d', '%d') ON CONFLICT (address) DO NOTHING;" var _ interfaces.StateDiffIndexer = &StateDiffIndexer{} @@ -57,6 +63,8 @@ type StateDiffIndexer struct { chainConfig *params.ChainConfig nodeID string wg *sync.WaitGroup + + watchedAddressesFilePath string } // NewStateDiffIndexer creates a void implementation of interfaces.StateDiffIndexer @@ -73,16 +81,24 @@ func NewStateDiffIndexer(ctx context.Context, chainConfig *params.ChainConfig, c return nil, fmt.Errorf("unable to create file (%s), err: %v", filePath, err) } log.Info("Writing statediff SQL statements to file", "file", filePath) + + watchedAddressesFilePath := config.WatchedAddressesFilePath + if watchedAddressesFilePath == "" { + watchedAddressesFilePath = defaultWatchedAddressesFilePath + } + log.Info("Writing watched addresses SQL statements to file", "file", watchedAddressesFilePath) + w := NewSQLWriter(file) wg := new(sync.WaitGroup) w.Loop() w.upsertNode(config.NodeInfo) w.upsertIPLDDirect(shared.RemovedNodeMhKey, []byte{}) return &StateDiffIndexer{ - fileWriter: w, - chainConfig: chainConfig, - nodeID: config.NodeInfo.ID, - wg: wg, + fileWriter: w, + chainConfig: chainConfig, + nodeID: config.NodeInfo.ID, + wg: wg, + watchedAddressesFilePath: watchedAddressesFilePath, }, nil } @@ -478,3 +494,165 @@ func (sdi 
*StateDiffIndexer) PushCodeAndCodeHash(batch interfaces.Batch, codeAnd func (sdi *StateDiffIndexer) Close() error { return sdi.fileWriter.Close() } + +// LoadWatchedAddresses loads watched addresses from a file +func (sdi *StateDiffIndexer) LoadWatchedAddresses() ([]common.Address, error) { + // load sql statements from watched addresses file + stmts, err := loadWatchedAddressesStatements(sdi.watchedAddressesFilePath) + if err != nil { + return nil, err + } + + // extract addresses from the sql statements + watchedAddresses := []common.Address{} + for _, stmt := range stmts { + addressString, err := parseWatchedAddressStatement(stmt) + if err != nil { + return nil, err + } + watchedAddresses = append(watchedAddresses, common.HexToAddress(addressString)) + } + + return watchedAddresses, nil +} + +// InsertWatchedAddresses inserts the given addresses in a file +func (sdi *StateDiffIndexer) InsertWatchedAddresses(args []sdtypes.WatchAddressArg, currentBlockNumber *big.Int) error { + // load sql statements from watched addresses file + stmts, err := loadWatchedAddressesStatements(sdi.watchedAddressesFilePath) + if err != nil { + return err + } + + // get already watched addresses + var watchedAddresses []string + for _, stmt := range stmts { + addressString, err := parseWatchedAddressStatement(stmt) + if err != nil { + return err + } + + watchedAddresses = append(watchedAddresses, addressString) + } + + // append statements for new addresses to existing statements + for _, arg := range args { + // ignore if already watched + if funk.Contains(watchedAddresses, arg.Address) { + continue + } + + stmt := fmt.Sprintf(watchedAddressesInsert, arg.Address, arg.CreatedAt, currentBlockNumber.Uint64()) + stmts = append(stmts, stmt) + } + + return dumpWatchedAddressesStatements(sdi.watchedAddressesFilePath, stmts) +} + +// RemoveWatchedAddresses removes the given watched addresses from a file +func (sdi *StateDiffIndexer) RemoveWatchedAddresses(args []sdtypes.WatchAddressArg) error { + // load sql statements from watched addresses file + stmts, err := loadWatchedAddressesStatements(sdi.watchedAddressesFilePath) + if err != nil { + return err + } + + // get rid of statements having addresses to be removed + var filteredStmts []string + for _, stmt := range stmts { + addressString, err := parseWatchedAddressStatement(stmt) + if err != nil { + return err + } + + toRemove := funk.Contains(args, func(arg sdtypes.WatchAddressArg) bool { + return arg.Address == addressString + }) + + if !toRemove { + filteredStmts = append(filteredStmts, stmt) + } + } + + return dumpWatchedAddressesStatements(sdi.watchedAddressesFilePath, filteredStmts) +} + +// SetWatchedAddresses clears and inserts the given addresses in a file +func (sdi *StateDiffIndexer) SetWatchedAddresses(args []sdtypes.WatchAddressArg, currentBlockNumber *big.Int) error { + var stmts []string + for _, arg := range args { + stmt := fmt.Sprintf(watchedAddressesInsert, arg.Address, arg.CreatedAt, currentBlockNumber.Uint64()) + stmts = append(stmts, stmt) + } + + return dumpWatchedAddressesStatements(sdi.watchedAddressesFilePath, stmts) +} + +// ClearWatchedAddresses clears all the watched addresses from a file +func (sdi *StateDiffIndexer) ClearWatchedAddresses() error { + return sdi.SetWatchedAddresses([]sdtypes.WatchAddressArg{}, big.NewInt(0)) +} + +// loadWatchedAddressesStatements loads sql statements from the given file in a string slice +func loadWatchedAddressesStatements(filePath string) ([]string, error) { + file, err := os.Open(filePath) + if err 
!= nil { + if errors.Is(err, os.ErrNotExist) { + return []string{}, nil + } + + return nil, fmt.Errorf("error opening watched addresses file: %v", err) + } + defer file.Close() + + stmts := []string{} + scanner := bufio.NewScanner(file) + for scanner.Scan() { + stmts = append(stmts, scanner.Text()) + } + + if err := scanner.Err(); err != nil { + return nil, fmt.Errorf("error loading watched addresses: %v", err) + } + + return stmts, nil +} + +// dumpWatchedAddressesStatements dumps sql statements to the given file +func dumpWatchedAddressesStatements(filePath string, stmts []string) error { + file, err := os.Create(filePath) + if err != nil { + return fmt.Errorf("error creating watched addresses file: %v", err) + } + defer file.Close() + + for _, stmt := range stmts { + _, err := file.Write([]byte(stmt + "\n")) + if err != nil { + return fmt.Errorf("error inserting watched_addresses entry: %v", err) + } + } + + return nil +} + +// parseWatchedAddressStatement parses given sql insert statement to extract the address argument +func parseWatchedAddressStatement(stmt string) (string, error) { + parseResult, err := pg_query.Parse(stmt) + if err != nil { + return "", fmt.Errorf("error parsing sql stmt: %v", err) + } + + // extract address argument from parse output for a SQL statement of form + // "INSERT INTO eth_meta.watched_addresses (address, created_at, watched_at) + // VALUES ('0xabc', '123', '130') ON CONFLICT (address) DO NOTHING;" + addressString := parseResult.Stmts[0].Stmt.GetInsertStmt(). + SelectStmt.GetSelectStmt(). + ValuesLists[0].GetList(). + Items[0].GetAConst(). + GetVal(). + GetString_(). + Str + + return addressString, nil +} diff --git a/statediff/indexer/database/file/indexer_legacy_test.go b/statediff/indexer/database/file/indexer_legacy_test.go index 56bca2683102..e89d168aaf98 100644 --- a/statediff/indexer/database/file/indexer_legacy_test.go +++ b/statediff/indexer/database/file/indexer_legacy_test.go @@ -81,7 +81,7 @@ func setupLegacy(t *testing.T) { } } -func dumpData(t *testing.T) { +func dumpFileData(t *testing.T) { sqlFileBytes, err := os.ReadFile(file.TestConfig.FilePath) require.NoError(t, err) @@ -89,10 +89,36 @@ func dumpData(t *testing.T) { require.NoError(t, err) } +func resetAndDumpWatchedAddressesFileData(t *testing.T) { + resetDB(t) + + sqlFileBytes, err := os.ReadFile(file.TestConfig.WatchedAddressesFilePath) + require.NoError(t, err) + + _, err = sqlxdb.Exec(string(sqlFileBytes)) + require.NoError(t, err) +} + +func resetDB(t *testing.T) { + file.TearDownDB(t, sqlxdb) + + connStr := postgres.DefaultConfig.DbConnectionString() + sqlxdb, err = sqlx.Connect("postgres", connStr) + if err != nil { + t.Fatalf("failed to connect to db with connection string: %s err: %v", connStr, err) + } +} + func tearDown(t *testing.T) { file.TearDownDB(t, sqlxdb) + err := os.Remove(file.TestConfig.FilePath) require.NoError(t, err) + + if err := os.Remove(file.TestConfig.WatchedAddressesFilePath); !errors.Is(err, os.ErrNotExist) { + require.NoError(t, err) + } + err = sqlxdb.Close() require.NoError(t, err) } @@ -106,7 +132,7 @@ func expectTrue(t *testing.T, value bool) { func TestFileIndexerLegacy(t *testing.T) { t.Run("Publish and index header IPLDs", func(t *testing.T) { setupLegacy(t) - dumpData(t) + dumpFileData(t) defer tearDown(t) pgStr := `SELECT cid, td, reward, block_hash, coinbase FROM eth.header_cids diff --git a/statediff/indexer/database/file/indexer_test.go b/statediff/indexer/database/file/indexer_test.go index ef849e8e86e7..fb5453fe661a 100644 --- 
a/statediff/indexer/database/file/indexer_test.go +++ b/statediff/indexer/database/file/indexer_test.go @@ -21,6 +21,7 @@ import ( "context" "errors" "fmt" + "math/big" "os" "testing" @@ -28,6 +29,7 @@ import ( "github.com/ethereum/go-ethereum/rlp" "github.com/ethereum/go-ethereum/statediff/indexer/models" "github.com/ethereum/go-ethereum/statediff/indexer/shared" + sdtypes "github.com/ethereum/go-ethereum/statediff/types" "github.com/ipfs/go-cid" blockstore "github.com/ipfs/go-ipfs-blockstore" @@ -51,12 +53,15 @@ var ( ind interfaces.StateDiffIndexer ipfsPgGet = `SELECT data FROM public.blocks WHERE key = $1` - tx1, tx2, tx3, tx4, tx5, rct1, rct2, rct3, rct4, rct5 []byte - mockBlock *types.Block - headerCID, trx1CID, trx2CID, trx3CID, trx4CID, trx5CID cid.Cid - rct1CID, rct2CID, rct3CID, rct4CID, rct5CID cid.Cid - rctLeaf1, rctLeaf2, rctLeaf3, rctLeaf4, rctLeaf5 []byte - state1CID, state2CID, storageCID cid.Cid + tx1, tx2, tx3, tx4, tx5, rct1, rct2, rct3, rct4, rct5 []byte + mockBlock *types.Block + headerCID, trx1CID, trx2CID, trx3CID, trx4CID, trx5CID cid.Cid + rct1CID, rct2CID, rct3CID, rct4CID, rct5CID cid.Cid + rctLeaf1, rctLeaf2, rctLeaf3, rctLeaf4, rctLeaf5 []byte + state1CID, state2CID, storageCID cid.Cid + contract1Address, contract2Address, contract3Address, contract4Address string + contract1CreatedAt, contract2CreatedAt, contract3CreatedAt, contract4CreatedAt uint64 + lastFilledAt, watchedAt1, watchedAt2, watchedAt3 uint64 ) func init() { @@ -161,15 +166,45 @@ func init() { rctLeaf3 = orderedRctLeafNodes[2] rctLeaf4 = orderedRctLeafNodes[3] rctLeaf5 = orderedRctLeafNodes[4] + + contract1Address = "0x5d663F5269090bD2A7DC2390c911dF6083D7b28F" + contract2Address = "0x6Eb7e5C66DB8af2E96159AC440cbc8CDB7fbD26B" + contract3Address = "0xcfeB164C328CA13EFd3C77E1980d94975aDfedfc" + contract4Address = "0x0Edf0c4f393a628DE4828B228C48175b3EA297fc" + contract1CreatedAt = uint64(1) + contract2CreatedAt = uint64(2) + contract3CreatedAt = uint64(3) + contract4CreatedAt = uint64(4) + + lastFilledAt = uint64(0) + watchedAt1 = uint64(10) + watchedAt2 = uint64(15) + watchedAt3 = uint64(20) } -func setup(t *testing.T) { +func setupIndexer(t *testing.T) { if _, err := os.Stat(file.TestConfig.FilePath); !errors.Is(err, os.ErrNotExist) { err := os.Remove(file.TestConfig.FilePath) require.NoError(t, err) } + + if _, err := os.Stat(file.TestConfig.WatchedAddressesFilePath); !errors.Is(err, os.ErrNotExist) { + err := os.Remove(file.TestConfig.WatchedAddressesFilePath) + require.NoError(t, err) + } + ind, err = file.NewStateDiffIndexer(context.Background(), mocks.TestConfig, file.TestConfig) require.NoError(t, err) + + connStr := postgres.DefaultConfig.DbConnectionString() + sqlxdb, err = sqlx.Connect("postgres", connStr) + if err != nil { + t.Fatalf("failed to connect to db with connection string: %s err: %v", connStr, err) + } +} + +func setup(t *testing.T) { + setupIndexer(t) var tx interfaces.Batch tx, err = ind.PushBlock( mockBlock, @@ -192,19 +227,12 @@ func setup(t *testing.T) { } test_helpers.ExpectEqual(t, tx.(*file.BatchTx).BlockNumber, mocks.BlockNumber.Uint64()) - - connStr := postgres.DefaultConfig.DbConnectionString() - - sqlxdb, err = sqlx.Connect("postgres", connStr) - if err != nil { - t.Fatalf("failed to connect to db with connection string: %s err: %v", connStr, err) - } } func TestFileIndexer(t *testing.T) { t.Run("Publish and index header IPLDs in a single tx", func(t *testing.T) { setup(t) - dumpData(t) + dumpFileData(t) defer tearDown(t) pgStr := `SELECT cid, 
td, reward, block_hash, coinbase FROM eth.header_cids @@ -242,7 +270,7 @@ func TestFileIndexer(t *testing.T) { }) t.Run("Publish and index transaction IPLDs in a single tx", func(t *testing.T) { setup(t) - dumpData(t) + dumpFileData(t) defer tearDown(t) // check that txs were properly indexed and published @@ -370,7 +398,7 @@ func TestFileIndexer(t *testing.T) { t.Run("Publish and index log IPLDs for multiple receipt of a specific block", func(t *testing.T) { setup(t) - dumpData(t) + dumpFileData(t) defer tearDown(t) rcts := make([]string, 0) @@ -426,7 +454,7 @@ func TestFileIndexer(t *testing.T) { t.Run("Publish and index receipt IPLDs in a single tx", func(t *testing.T) { setup(t) - dumpData(t) + dumpFileData(t) defer tearDown(t) // check receipts were properly indexed and published @@ -527,7 +555,7 @@ func TestFileIndexer(t *testing.T) { t.Run("Publish and index state IPLDs in a single tx", func(t *testing.T) { setup(t) - dumpData(t) + dumpFileData(t) defer tearDown(t) // check that state nodes were properly indexed and published @@ -618,7 +646,7 @@ func TestFileIndexer(t *testing.T) { t.Run("Publish and index storage IPLDs in a single tx", func(t *testing.T) { setup(t) - dumpData(t) + dumpFileData(t) defer tearDown(t) // check that storage nodes were properly indexed @@ -688,3 +716,341 @@ func TestFileIndexer(t *testing.T) { test_helpers.ExpectEqual(t, data, []byte{}) }) } + +func TestFileWatchAddressMethods(t *testing.T) { + setupIndexer(t) + defer tearDown(t) + + type res struct { + Address string `db:"address"` + CreatedAt uint64 `db:"created_at"` + WatchedAt uint64 `db:"watched_at"` + LastFilledAt uint64 `db:"last_filled_at"` + } + pgStr := "SELECT * FROM eth_meta.watched_addresses" + + t.Run("Load watched addresses (empty table)", func(t *testing.T) { + expectedData := []common.Address{} + + rows, err := ind.LoadWatchedAddresses() + require.NoError(t, err) + + expectTrue(t, len(rows) == len(expectedData)) + for idx, row := range rows { + test_helpers.ExpectEqual(t, row, expectedData[idx]) + } + }) + + t.Run("Insert watched addresses", func(t *testing.T) { + args := []sdtypes.WatchAddressArg{ + { + Address: contract1Address, + CreatedAt: contract1CreatedAt, + }, + { + Address: contract2Address, + CreatedAt: contract2CreatedAt, + }, + } + expectedData := []res{ + { + Address: contract1Address, + CreatedAt: contract1CreatedAt, + WatchedAt: watchedAt1, + LastFilledAt: lastFilledAt, + }, + { + Address: contract2Address, + CreatedAt: contract2CreatedAt, + WatchedAt: watchedAt1, + LastFilledAt: lastFilledAt, + }, + } + + err = ind.InsertWatchedAddresses(args, big.NewInt(int64(watchedAt1))) + require.NoError(t, err) + resetAndDumpWatchedAddressesFileData(t) + + rows := []res{} + err = sqlxdb.Select(&rows, pgStr) + if err != nil { + t.Fatal(err) + } + + expectTrue(t, len(rows) == len(expectedData)) + for idx, row := range rows { + test_helpers.ExpectEqual(t, row, expectedData[idx]) + } + }) + + t.Run("Insert watched addresses (some already watched)", func(t *testing.T) { + args := []sdtypes.WatchAddressArg{ + { + Address: contract3Address, + CreatedAt: contract3CreatedAt, + }, + { + Address: contract2Address, + CreatedAt: contract2CreatedAt, + }, + } + expectedData := []res{ + { + Address: contract1Address, + CreatedAt: contract1CreatedAt, + WatchedAt: watchedAt1, + LastFilledAt: lastFilledAt, + }, + { + Address: contract2Address, + CreatedAt: contract2CreatedAt, + WatchedAt: watchedAt1, + LastFilledAt: lastFilledAt, + }, + { + Address: contract3Address, + CreatedAt: contract3CreatedAt, + 
WatchedAt: watchedAt2, + LastFilledAt: lastFilledAt, + }, + } + + err = ind.InsertWatchedAddresses(args, big.NewInt(int64(watchedAt2))) + require.NoError(t, err) + resetAndDumpWatchedAddressesFileData(t) + + rows := []res{} + err = sqlxdb.Select(&rows, pgStr) + if err != nil { + t.Fatal(err) + } + + expectTrue(t, len(rows) == len(expectedData)) + for idx, row := range rows { + test_helpers.ExpectEqual(t, row, expectedData[idx]) + } + }) + + t.Run("Remove watched addresses", func(t *testing.T) { + args := []sdtypes.WatchAddressArg{ + { + Address: contract3Address, + CreatedAt: contract3CreatedAt, + }, + { + Address: contract2Address, + CreatedAt: contract2CreatedAt, + }, + } + expectedData := []res{ + { + Address: contract1Address, + CreatedAt: contract1CreatedAt, + WatchedAt: watchedAt1, + LastFilledAt: lastFilledAt, + }, + } + + err = ind.RemoveWatchedAddresses(args) + require.NoError(t, err) + resetAndDumpWatchedAddressesFileData(t) + + rows := []res{} + err = sqlxdb.Select(&rows, pgStr) + if err != nil { + t.Fatal(err) + } + + expectTrue(t, len(rows) == len(expectedData)) + for idx, row := range rows { + test_helpers.ExpectEqual(t, row, expectedData[idx]) + } + }) + + t.Run("Remove watched addresses (some non-watched)", func(t *testing.T) { + args := []sdtypes.WatchAddressArg{ + { + Address: contract1Address, + CreatedAt: contract1CreatedAt, + }, + { + Address: contract2Address, + CreatedAt: contract2CreatedAt, + }, + } + expectedData := []res{} + + err = ind.RemoveWatchedAddresses(args) + require.NoError(t, err) + resetAndDumpWatchedAddressesFileData(t) + + rows := []res{} + err = sqlxdb.Select(&rows, pgStr) + if err != nil { + t.Fatal(err) + } + + expectTrue(t, len(rows) == len(expectedData)) + for idx, row := range rows { + test_helpers.ExpectEqual(t, row, expectedData[idx]) + } + }) + + t.Run("Set watched addresses", func(t *testing.T) { + args := []sdtypes.WatchAddressArg{ + { + Address: contract1Address, + CreatedAt: contract1CreatedAt, + }, + { + Address: contract2Address, + CreatedAt: contract2CreatedAt, + }, + { + Address: contract3Address, + CreatedAt: contract3CreatedAt, + }, + } + expectedData := []res{ + { + Address: contract1Address, + CreatedAt: contract1CreatedAt, + WatchedAt: watchedAt2, + LastFilledAt: lastFilledAt, + }, + { + Address: contract2Address, + CreatedAt: contract2CreatedAt, + WatchedAt: watchedAt2, + LastFilledAt: lastFilledAt, + }, + { + Address: contract3Address, + CreatedAt: contract3CreatedAt, + WatchedAt: watchedAt2, + LastFilledAt: lastFilledAt, + }, + } + + err = ind.SetWatchedAddresses(args, big.NewInt(int64(watchedAt2))) + require.NoError(t, err) + resetAndDumpWatchedAddressesFileData(t) + + rows := []res{} + err = sqlxdb.Select(&rows, pgStr) + if err != nil { + t.Fatal(err) + } + + expectTrue(t, len(rows) == len(expectedData)) + for idx, row := range rows { + test_helpers.ExpectEqual(t, row, expectedData[idx]) + } + }) + + t.Run("Set watched addresses (some already watched)", func(t *testing.T) { + args := []sdtypes.WatchAddressArg{ + { + Address: contract4Address, + CreatedAt: contract4CreatedAt, + }, + { + Address: contract2Address, + CreatedAt: contract2CreatedAt, + }, + { + Address: contract3Address, + CreatedAt: contract3CreatedAt, + }, + } + expectedData := []res{ + { + Address: contract4Address, + CreatedAt: contract4CreatedAt, + WatchedAt: watchedAt3, + LastFilledAt: lastFilledAt, + }, + { + Address: contract2Address, + CreatedAt: contract2CreatedAt, + WatchedAt: watchedAt3, + LastFilledAt: lastFilledAt, + }, + { + Address: 
contract3Address, + CreatedAt: contract3CreatedAt, + WatchedAt: watchedAt3, + LastFilledAt: lastFilledAt, + }, + } + + err = ind.SetWatchedAddresses(args, big.NewInt(int64(watchedAt3))) + require.NoError(t, err) + resetAndDumpWatchedAddressesFileData(t) + + rows := []res{} + err = sqlxdb.Select(&rows, pgStr) + if err != nil { + t.Fatal(err) + } + + expectTrue(t, len(rows) == len(expectedData)) + for idx, row := range rows { + test_helpers.ExpectEqual(t, row, expectedData[idx]) + } + }) + + t.Run("Load watched addresses", func(t *testing.T) { + expectedData := []common.Address{ + common.HexToAddress(contract4Address), + common.HexToAddress(contract2Address), + common.HexToAddress(contract3Address), + } + + rows, err := ind.LoadWatchedAddresses() + require.NoError(t, err) + + expectTrue(t, len(rows) == len(expectedData)) + for idx, row := range rows { + test_helpers.ExpectEqual(t, row, expectedData[idx]) + } + }) + + t.Run("Clear watched addresses", func(t *testing.T) { + expectedData := []res{} + + err = ind.ClearWatchedAddresses() + require.NoError(t, err) + resetAndDumpWatchedAddressesFileData(t) + + rows := []res{} + err = sqlxdb.Select(&rows, pgStr) + if err != nil { + t.Fatal(err) + } + + expectTrue(t, len(rows) == len(expectedData)) + for idx, row := range rows { + test_helpers.ExpectEqual(t, row, expectedData[idx]) + } + }) + + t.Run("Clear watched addresses (empty table)", func(t *testing.T) { + expectedData := []res{} + + err = ind.ClearWatchedAddresses() + require.NoError(t, err) + resetAndDumpWatchedAddressesFileData(t) + + rows := []res{} + err = sqlxdb.Select(&rows, pgStr) + if err != nil { + t.Fatal(err) + } + + expectTrue(t, len(rows) == len(expectedData)) + for idx, row := range rows { + test_helpers.ExpectEqual(t, row, expectedData[idx]) + } + }) +} diff --git a/statediff/indexer/database/file/test_helpers.go b/statediff/indexer/database/file/test_helpers.go index 27d204d55aea..27a1581a48da 100644 --- a/statediff/indexer/database/file/test_helpers.go +++ b/statediff/indexer/database/file/test_helpers.go @@ -57,6 +57,10 @@ func TearDownDB(t *testing.T, db *sqlx.DB) { if err != nil { t.Fatal(err) } + _, err = tx.Exec(`DELETE FROM eth_meta.watched_addresses`) + if err != nil { + t.Fatal(err) + } err = tx.Commit() if err != nil { t.Fatal(err) diff --git a/statediff/indexer/database/sql/indexer.go b/statediff/indexer/database/sql/indexer.go index 4f2b52434ba1..790766cfe557 100644 --- a/statediff/indexer/database/sql/indexer.go +++ b/statediff/indexer/database/sql/indexer.go @@ -555,3 +555,118 @@ func (sdi *StateDiffIndexer) Close() error { } // Update the known gaps table with the gap information. 
+ +// LoadWatchedAddresses reads watched addresses from the database +func (sdi *StateDiffIndexer) LoadWatchedAddresses() ([]common.Address, error) { + addressStrings := make([]string, 0) + pgStr := "SELECT address FROM eth_meta.watched_addresses" + err := sdi.dbWriter.db.Select(sdi.ctx, &addressStrings, pgStr) + if err != nil { + return nil, fmt.Errorf("error loading watched addresses: %v", err) + } + + watchedAddresses := []common.Address{} + for _, addressString := range addressStrings { + watchedAddresses = append(watchedAddresses, common.HexToAddress(addressString)) + } + + return watchedAddresses, nil +} + +// InsertWatchedAddresses inserts the given addresses in the database +func (sdi *StateDiffIndexer) InsertWatchedAddresses(args []sdtypes.WatchAddressArg, currentBlockNumber *big.Int) error { + tx, err := sdi.dbWriter.db.Begin(sdi.ctx) + if err != nil { + return err + } + defer func() { + if p := recover(); p != nil { + rollback(sdi.ctx, tx) + panic(p) + } else if err != nil { + rollback(sdi.ctx, tx) + } else { + err = tx.Commit(sdi.ctx) + } + }() + + for _, arg := range args { + _, err = tx.Exec(sdi.ctx, `INSERT INTO eth_meta.watched_addresses (address, created_at, watched_at) VALUES ($1, $2, $3) ON CONFLICT (address) DO NOTHING`, + arg.Address, arg.CreatedAt, currentBlockNumber.Uint64()) + if err != nil { + return fmt.Errorf("error inserting watched_addresses entry: %v", err) + } + } + + return err +} + +// RemoveWatchedAddresses removes the given watched addresses from the database +func (sdi *StateDiffIndexer) RemoveWatchedAddresses(args []sdtypes.WatchAddressArg) error { + tx, err := sdi.dbWriter.db.Begin(sdi.ctx) + if err != nil { + return err + } + defer func() { + if p := recover(); p != nil { + rollback(sdi.ctx, tx) + panic(p) + } else if err != nil { + rollback(sdi.ctx, tx) + } else { + err = tx.Commit(sdi.ctx) + } + }() + + for _, arg := range args { + _, err = tx.Exec(sdi.ctx, `DELETE FROM eth_meta.watched_addresses WHERE address = $1`, arg.Address) + if err != nil { + return fmt.Errorf("error removing watched_addresses entry: %v", err) + } + } + + return err +} + +// SetWatchedAddresses clears and inserts the given addresses in the database +func (sdi *StateDiffIndexer) SetWatchedAddresses(args []sdtypes.WatchAddressArg, currentBlockNumber *big.Int) error { + tx, err := sdi.dbWriter.db.Begin(sdi.ctx) + if err != nil { + return err + } + defer func() { + if p := recover(); p != nil { + rollback(sdi.ctx, tx) + panic(p) + } else if err != nil { + rollback(sdi.ctx, tx) + } else { + err = tx.Commit(sdi.ctx) + } + }() + + _, err = tx.Exec(sdi.ctx, `DELETE FROM eth_meta.watched_addresses`) + if err != nil { + return fmt.Errorf("error setting watched_addresses table: %v", err) + } + + for _, arg := range args { + _, err = tx.Exec(sdi.ctx, `INSERT INTO eth_meta.watched_addresses (address, created_at, watched_at) VALUES ($1, $2, $3) ON CONFLICT (address) DO NOTHING`, + arg.Address, arg.CreatedAt, currentBlockNumber.Uint64()) + if err != nil { + return fmt.Errorf("error setting watched_addresses table: %v", err) + } + } + + return err +} + +// ClearWatchedAddresses clears all the watched addresses from the database +func (sdi *StateDiffIndexer) ClearWatchedAddresses() error { + _, err := sdi.dbWriter.db.Exec(sdi.ctx, `DELETE FROM eth_meta.watched_addresses`) + if err != nil { + return fmt.Errorf("error clearing watched_addresses table: %v", err) + } + + return nil +} diff --git a/statediff/indexer/database/sql/indexer_shared_test.go 
b/statediff/indexer/database/sql/indexer_shared_test.go index 0351fb13481d..b621d2e1a2c3 100644 --- a/statediff/indexer/database/sql/indexer_shared_test.go +++ b/statediff/indexer/database/sql/indexer_shared_test.go @@ -24,12 +24,15 @@ var ( ind interfaces.StateDiffIndexer ipfsPgGet = `SELECT data FROM public.blocks WHERE key = $1` - tx1, tx2, tx3, tx4, tx5, rct1, rct2, rct3, rct4, rct5 []byte - mockBlock *types.Block - headerCID, trx1CID, trx2CID, trx3CID, trx4CID, trx5CID cid.Cid - rct1CID, rct2CID, rct3CID, rct4CID, rct5CID cid.Cid - rctLeaf1, rctLeaf2, rctLeaf3, rctLeaf4, rctLeaf5 []byte - state1CID, state2CID, storageCID cid.Cid + tx1, tx2, tx3, tx4, tx5, rct1, rct2, rct3, rct4, rct5 []byte + mockBlock *types.Block + headerCID, trx1CID, trx2CID, trx3CID, trx4CID, trx5CID cid.Cid + rct1CID, rct2CID, rct3CID, rct4CID, rct5CID cid.Cid + rctLeaf1, rctLeaf2, rctLeaf3, rctLeaf4, rctLeaf5 []byte + state1CID, state2CID, storageCID cid.Cid + contract1Address, contract2Address, contract3Address, contract4Address string + contract1CreatedAt, contract2CreatedAt, contract3CreatedAt, contract4CreatedAt uint64 + lastFilledAt, watchedAt1, watchedAt2, watchedAt3 uint64 ) func init() { @@ -134,6 +137,20 @@ func init() { rctLeaf3 = orderedRctLeafNodes[2] rctLeaf4 = orderedRctLeafNodes[3] rctLeaf5 = orderedRctLeafNodes[4] + + contract1Address = "0x5d663F5269090bD2A7DC2390c911dF6083D7b28F" + contract2Address = "0x6Eb7e5C66DB8af2E96159AC440cbc8CDB7fbD26B" + contract3Address = "0xcfeB164C328CA13EFd3C77E1980d94975aDfedfc" + contract4Address = "0x0Edf0c4f393a628DE4828B228C48175b3EA297fc" + contract1CreatedAt = uint64(1) + contract2CreatedAt = uint64(2) + contract3CreatedAt = uint64(3) + contract4CreatedAt = uint64(4) + + lastFilledAt = uint64(0) + watchedAt1 = uint64(10) + watchedAt2 = uint64(15) + watchedAt3 = uint64(20) } func expectTrue(t *testing.T, value bool) { diff --git a/statediff/indexer/database/sql/pgx_indexer_test.go b/statediff/indexer/database/sql/pgx_indexer_test.go index ec5b94fd5eba..fef559486e8b 100644 --- a/statediff/indexer/database/sql/pgx_indexer_test.go +++ b/statediff/indexer/database/sql/pgx_indexer_test.go @@ -18,6 +18,7 @@ package sql_test import ( "context" + "math/big" "testing" "github.com/ipfs/go-cid" @@ -35,15 +36,20 @@ import ( "github.com/ethereum/go-ethereum/statediff/indexer/models" "github.com/ethereum/go-ethereum/statediff/indexer/shared" "github.com/ethereum/go-ethereum/statediff/indexer/test_helpers" + sdtypes "github.com/ethereum/go-ethereum/statediff/types" ) -func setupPGX(t *testing.T) { +func setupPGXIndexer(t *testing.T) { db, err = postgres.SetupPGXDB() if err != nil { t.Fatal(err) } ind, err = sql.NewStateDiffIndexer(context.Background(), mocks.TestConfig, db) require.NoError(t, err) +} + +func setupPGX(t *testing.T) { + setupPGXIndexer(t) var tx interfaces.Batch tx, err = ind.PushBlock( mockBlock, @@ -557,3 +563,334 @@ func TestPGXIndexer(t *testing.T) { test_helpers.ExpectEqual(t, data, []byte{}) }) } + +func TestPGXWatchAddressMethods(t *testing.T) { + setupPGXIndexer(t) + defer tearDown(t) + defer checkTxClosure(t, 1, 0, 1) + + type res struct { + Address string `db:"address"` + CreatedAt uint64 `db:"created_at"` + WatchedAt uint64 `db:"watched_at"` + LastFilledAt uint64 `db:"last_filled_at"` + } + pgStr := "SELECT * FROM eth_meta.watched_addresses" + + t.Run("Load watched addresses (empty table)", func(t *testing.T) { + expectedData := []common.Address{} + + rows, err := ind.LoadWatchedAddresses() + require.NoError(t, err) + + 
expectTrue(t, len(rows) == len(expectedData)) + for idx, row := range rows { + test_helpers.ExpectEqual(t, row, expectedData[idx]) + } + }) + + t.Run("Insert watched addresses", func(t *testing.T) { + args := []sdtypes.WatchAddressArg{ + { + Address: contract1Address, + CreatedAt: contract1CreatedAt, + }, + { + Address: contract2Address, + CreatedAt: contract2CreatedAt, + }, + } + expectedData := []res{ + { + Address: contract1Address, + CreatedAt: contract1CreatedAt, + WatchedAt: watchedAt1, + LastFilledAt: lastFilledAt, + }, + { + Address: contract2Address, + CreatedAt: contract2CreatedAt, + WatchedAt: watchedAt1, + LastFilledAt: lastFilledAt, + }, + } + + err = ind.InsertWatchedAddresses(args, big.NewInt(int64(watchedAt1))) + require.NoError(t, err) + + rows := []res{} + err = db.Select(context.Background(), &rows, pgStr) + if err != nil { + t.Fatal(err) + } + + expectTrue(t, len(rows) == len(expectedData)) + for idx, row := range rows { + test_helpers.ExpectEqual(t, row, expectedData[idx]) + } + }) + + t.Run("Insert watched addresses (some already watched)", func(t *testing.T) { + args := []sdtypes.WatchAddressArg{ + { + Address: contract3Address, + CreatedAt: contract3CreatedAt, + }, + { + Address: contract2Address, + CreatedAt: contract2CreatedAt, + }, + } + expectedData := []res{ + { + Address: contract1Address, + CreatedAt: contract1CreatedAt, + WatchedAt: watchedAt1, + LastFilledAt: lastFilledAt, + }, + { + Address: contract2Address, + CreatedAt: contract2CreatedAt, + WatchedAt: watchedAt1, + LastFilledAt: lastFilledAt, + }, + { + Address: contract3Address, + CreatedAt: contract3CreatedAt, + WatchedAt: watchedAt2, + LastFilledAt: lastFilledAt, + }, + } + + err = ind.InsertWatchedAddresses(args, big.NewInt(int64(watchedAt2))) + require.NoError(t, err) + + rows := []res{} + err = db.Select(context.Background(), &rows, pgStr) + if err != nil { + t.Fatal(err) + } + + expectTrue(t, len(rows) == len(expectedData)) + for idx, row := range rows { + test_helpers.ExpectEqual(t, row, expectedData[idx]) + } + }) + + t.Run("Remove watched addresses", func(t *testing.T) { + args := []sdtypes.WatchAddressArg{ + { + Address: contract3Address, + CreatedAt: contract3CreatedAt, + }, + { + Address: contract2Address, + CreatedAt: contract2CreatedAt, + }, + } + expectedData := []res{ + { + Address: contract1Address, + CreatedAt: contract1CreatedAt, + WatchedAt: watchedAt1, + LastFilledAt: lastFilledAt, + }, + } + + err = ind.RemoveWatchedAddresses(args) + require.NoError(t, err) + + rows := []res{} + err = db.Select(context.Background(), &rows, pgStr) + if err != nil { + t.Fatal(err) + } + + expectTrue(t, len(rows) == len(expectedData)) + for idx, row := range rows { + test_helpers.ExpectEqual(t, row, expectedData[idx]) + } + }) + + t.Run("Remove watched addresses (some non-watched)", func(t *testing.T) { + args := []sdtypes.WatchAddressArg{ + { + Address: contract1Address, + CreatedAt: contract1CreatedAt, + }, + { + Address: contract2Address, + CreatedAt: contract2CreatedAt, + }, + } + expectedData := []res{} + + err = ind.RemoveWatchedAddresses(args) + require.NoError(t, err) + + rows := []res{} + err = db.Select(context.Background(), &rows, pgStr) + if err != nil { + t.Fatal(err) + } + + expectTrue(t, len(rows) == len(expectedData)) + for idx, row := range rows { + test_helpers.ExpectEqual(t, row, expectedData[idx]) + } + }) + + t.Run("Set watched addresses", func(t *testing.T) { + args := []sdtypes.WatchAddressArg{ + { + Address: contract1Address, + CreatedAt: contract1CreatedAt, + }, + { + 
Address: contract2Address, + CreatedAt: contract2CreatedAt, + }, + { + Address: contract3Address, + CreatedAt: contract3CreatedAt, + }, + } + expectedData := []res{ + { + Address: contract1Address, + CreatedAt: contract1CreatedAt, + WatchedAt: watchedAt2, + LastFilledAt: lastFilledAt, + }, + { + Address: contract2Address, + CreatedAt: contract2CreatedAt, + WatchedAt: watchedAt2, + LastFilledAt: lastFilledAt, + }, + { + Address: contract3Address, + CreatedAt: contract3CreatedAt, + WatchedAt: watchedAt2, + LastFilledAt: lastFilledAt, + }, + } + + err = ind.SetWatchedAddresses(args, big.NewInt(int64(watchedAt2))) + require.NoError(t, err) + + rows := []res{} + err = db.Select(context.Background(), &rows, pgStr) + if err != nil { + t.Fatal(err) + } + + expectTrue(t, len(rows) == len(expectedData)) + for idx, row := range rows { + test_helpers.ExpectEqual(t, row, expectedData[idx]) + } + }) + + t.Run("Set watched addresses (some already watched)", func(t *testing.T) { + args := []sdtypes.WatchAddressArg{ + { + Address: contract4Address, + CreatedAt: contract4CreatedAt, + }, + { + Address: contract2Address, + CreatedAt: contract2CreatedAt, + }, + { + Address: contract3Address, + CreatedAt: contract3CreatedAt, + }, + } + expectedData := []res{ + { + Address: contract4Address, + CreatedAt: contract4CreatedAt, + WatchedAt: watchedAt3, + LastFilledAt: lastFilledAt, + }, + { + Address: contract2Address, + CreatedAt: contract2CreatedAt, + WatchedAt: watchedAt3, + LastFilledAt: lastFilledAt, + }, + { + Address: contract3Address, + CreatedAt: contract3CreatedAt, + WatchedAt: watchedAt3, + LastFilledAt: lastFilledAt, + }, + } + + err = ind.SetWatchedAddresses(args, big.NewInt(int64(watchedAt3))) + require.NoError(t, err) + + rows := []res{} + err = db.Select(context.Background(), &rows, pgStr) + if err != nil { + t.Fatal(err) + } + + expectTrue(t, len(rows) == len(expectedData)) + for idx, row := range rows { + test_helpers.ExpectEqual(t, row, expectedData[idx]) + } + }) + + t.Run("Load watched addresses", func(t *testing.T) { + expectedData := []common.Address{ + common.HexToAddress(contract4Address), + common.HexToAddress(contract2Address), + common.HexToAddress(contract3Address), + } + + rows, err := ind.LoadWatchedAddresses() + require.NoError(t, err) + + expectTrue(t, len(rows) == len(expectedData)) + for idx, row := range rows { + test_helpers.ExpectEqual(t, row, expectedData[idx]) + } + }) + + t.Run("Clear watched addresses", func(t *testing.T) { + expectedData := []res{} + + err = ind.ClearWatchedAddresses() + require.NoError(t, err) + + rows := []res{} + err = db.Select(context.Background(), &rows, pgStr) + if err != nil { + t.Fatal(err) + } + + expectTrue(t, len(rows) == len(expectedData)) + for idx, row := range rows { + test_helpers.ExpectEqual(t, row, expectedData[idx]) + } + }) + + t.Run("Clear watched addresses (empty table)", func(t *testing.T) { + expectedData := []res{} + + err = ind.ClearWatchedAddresses() + require.NoError(t, err) + + rows := []res{} + err = db.Select(context.Background(), &rows, pgStr) + if err != nil { + t.Fatal(err) + } + + expectTrue(t, len(rows) == len(expectedData)) + for idx, row := range rows { + test_helpers.ExpectEqual(t, row, expectedData[idx]) + } + }) +} diff --git a/statediff/indexer/database/sql/sqlx_indexer_test.go b/statediff/indexer/database/sql/sqlx_indexer_test.go index 65c0a7615b45..73309ea7f723 100644 --- a/statediff/indexer/database/sql/sqlx_indexer_test.go +++ b/statediff/indexer/database/sql/sqlx_indexer_test.go @@ -18,6 +18,7 @@ package 
sql_test import ( "context" + "math/big" "testing" "github.com/ipfs/go-cid" @@ -36,15 +37,20 @@ import ( "github.com/ethereum/go-ethereum/statediff/indexer/models" "github.com/ethereum/go-ethereum/statediff/indexer/shared" "github.com/ethereum/go-ethereum/statediff/indexer/test_helpers" + sdtypes "github.com/ethereum/go-ethereum/statediff/types" ) -func setupSQLX(t *testing.T) { +func setupSQLXIndexer(t *testing.T) { db, err = postgres.SetupSQLXDB() if err != nil { t.Fatal(err) } ind, err = sql.NewStateDiffIndexer(context.Background(), mocks.TestConfig, db) require.NoError(t, err) +} + +func setupSQLX(t *testing.T) { + setupSQLXIndexer(t) var tx interfaces.Batch tx, err = ind.PushBlock( mockBlock, @@ -550,3 +556,334 @@ func TestSQLXIndexer(t *testing.T) { test_helpers.ExpectEqual(t, data, []byte{}) }) } + +func TestSQLXWatchAddressMethods(t *testing.T) { + setupSQLXIndexer(t) + defer tearDown(t) + defer checkTxClosure(t, 0, 0, 0) + + type res struct { + Address string `db:"address"` + CreatedAt uint64 `db:"created_at"` + WatchedAt uint64 `db:"watched_at"` + LastFilledAt uint64 `db:"last_filled_at"` + } + pgStr := "SELECT * FROM eth_meta.watched_addresses" + + t.Run("Load watched addresses (empty table)", func(t *testing.T) { + expectedData := []common.Address{} + + rows, err := ind.LoadWatchedAddresses() + require.NoError(t, err) + + expectTrue(t, len(rows) == len(expectedData)) + for idx, row := range rows { + test_helpers.ExpectEqual(t, row, expectedData[idx]) + } + }) + + t.Run("Insert watched addresses", func(t *testing.T) { + args := []sdtypes.WatchAddressArg{ + { + Address: contract1Address, + CreatedAt: contract1CreatedAt, + }, + { + Address: contract2Address, + CreatedAt: contract2CreatedAt, + }, + } + expectedData := []res{ + { + Address: contract1Address, + CreatedAt: contract1CreatedAt, + WatchedAt: watchedAt1, + LastFilledAt: lastFilledAt, + }, + { + Address: contract2Address, + CreatedAt: contract2CreatedAt, + WatchedAt: watchedAt1, + LastFilledAt: lastFilledAt, + }, + } + + err = ind.InsertWatchedAddresses(args, big.NewInt(int64(watchedAt1))) + require.NoError(t, err) + + rows := []res{} + err = db.Select(context.Background(), &rows, pgStr) + if err != nil { + t.Fatal(err) + } + + expectTrue(t, len(rows) == len(expectedData)) + for idx, row := range rows { + test_helpers.ExpectEqual(t, row, expectedData[idx]) + } + }) + + t.Run("Insert watched addresses (some already watched)", func(t *testing.T) { + args := []sdtypes.WatchAddressArg{ + { + Address: contract3Address, + CreatedAt: contract3CreatedAt, + }, + { + Address: contract2Address, + CreatedAt: contract2CreatedAt, + }, + } + expectedData := []res{ + { + Address: contract1Address, + CreatedAt: contract1CreatedAt, + WatchedAt: watchedAt1, + LastFilledAt: lastFilledAt, + }, + { + Address: contract2Address, + CreatedAt: contract2CreatedAt, + WatchedAt: watchedAt1, + LastFilledAt: lastFilledAt, + }, + { + Address: contract3Address, + CreatedAt: contract3CreatedAt, + WatchedAt: watchedAt2, + LastFilledAt: lastFilledAt, + }, + } + + err = ind.InsertWatchedAddresses(args, big.NewInt(int64(watchedAt2))) + require.NoError(t, err) + + rows := []res{} + err = db.Select(context.Background(), &rows, pgStr) + if err != nil { + t.Fatal(err) + } + + expectTrue(t, len(rows) == len(expectedData)) + for idx, row := range rows { + test_helpers.ExpectEqual(t, row, expectedData[idx]) + } + }) + + t.Run("Remove watched addresses", func(t *testing.T) { + args := []sdtypes.WatchAddressArg{ + { + Address: contract3Address, + 
CreatedAt: contract3CreatedAt, + }, + { + Address: contract2Address, + CreatedAt: contract2CreatedAt, + }, + } + expectedData := []res{ + { + Address: contract1Address, + CreatedAt: contract1CreatedAt, + WatchedAt: watchedAt1, + LastFilledAt: lastFilledAt, + }, + } + + err = ind.RemoveWatchedAddresses(args) + require.NoError(t, err) + + rows := []res{} + err = db.Select(context.Background(), &rows, pgStr) + if err != nil { + t.Fatal(err) + } + + expectTrue(t, len(rows) == len(expectedData)) + for idx, row := range rows { + test_helpers.ExpectEqual(t, row, expectedData[idx]) + } + }) + + t.Run("Remove watched addresses (some non-watched)", func(t *testing.T) { + args := []sdtypes.WatchAddressArg{ + { + Address: contract1Address, + CreatedAt: contract1CreatedAt, + }, + { + Address: contract2Address, + CreatedAt: contract2CreatedAt, + }, + } + expectedData := []res{} + + err = ind.RemoveWatchedAddresses(args) + require.NoError(t, err) + + rows := []res{} + err = db.Select(context.Background(), &rows, pgStr) + if err != nil { + t.Fatal(err) + } + + expectTrue(t, len(rows) == len(expectedData)) + for idx, row := range rows { + test_helpers.ExpectEqual(t, row, expectedData[idx]) + } + }) + + t.Run("Set watched addresses", func(t *testing.T) { + args := []sdtypes.WatchAddressArg{ + { + Address: contract1Address, + CreatedAt: contract1CreatedAt, + }, + { + Address: contract2Address, + CreatedAt: contract2CreatedAt, + }, + { + Address: contract3Address, + CreatedAt: contract3CreatedAt, + }, + } + expectedData := []res{ + { + Address: contract1Address, + CreatedAt: contract1CreatedAt, + WatchedAt: watchedAt2, + LastFilledAt: lastFilledAt, + }, + { + Address: contract2Address, + CreatedAt: contract2CreatedAt, + WatchedAt: watchedAt2, + LastFilledAt: lastFilledAt, + }, + { + Address: contract3Address, + CreatedAt: contract3CreatedAt, + WatchedAt: watchedAt2, + LastFilledAt: lastFilledAt, + }, + } + + err = ind.SetWatchedAddresses(args, big.NewInt(int64(watchedAt2))) + require.NoError(t, err) + + rows := []res{} + err = db.Select(context.Background(), &rows, pgStr) + if err != nil { + t.Fatal(err) + } + + expectTrue(t, len(rows) == len(expectedData)) + for idx, row := range rows { + test_helpers.ExpectEqual(t, row, expectedData[idx]) + } + }) + + t.Run("Set watched addresses (some already watched)", func(t *testing.T) { + args := []sdtypes.WatchAddressArg{ + { + Address: contract4Address, + CreatedAt: contract4CreatedAt, + }, + { + Address: contract2Address, + CreatedAt: contract2CreatedAt, + }, + { + Address: contract3Address, + CreatedAt: contract3CreatedAt, + }, + } + expectedData := []res{ + { + Address: contract4Address, + CreatedAt: contract4CreatedAt, + WatchedAt: watchedAt3, + LastFilledAt: lastFilledAt, + }, + { + Address: contract2Address, + CreatedAt: contract2CreatedAt, + WatchedAt: watchedAt3, + LastFilledAt: lastFilledAt, + }, + { + Address: contract3Address, + CreatedAt: contract3CreatedAt, + WatchedAt: watchedAt3, + LastFilledAt: lastFilledAt, + }, + } + + err = ind.SetWatchedAddresses(args, big.NewInt(int64(watchedAt3))) + require.NoError(t, err) + + rows := []res{} + err = db.Select(context.Background(), &rows, pgStr) + if err != nil { + t.Fatal(err) + } + + expectTrue(t, len(rows) == len(expectedData)) + for idx, row := range rows { + test_helpers.ExpectEqual(t, row, expectedData[idx]) + } + }) + + t.Run("Load watched addresses", func(t *testing.T) { + expectedData := []common.Address{ + common.HexToAddress(contract4Address), + common.HexToAddress(contract2Address), + 
common.HexToAddress(contract3Address), + } + + rows, err := ind.LoadWatchedAddresses() + require.NoError(t, err) + + expectTrue(t, len(rows) == len(expectedData)) + for idx, row := range rows { + test_helpers.ExpectEqual(t, row, expectedData[idx]) + } + }) + + t.Run("Clear watched addresses", func(t *testing.T) { + expectedData := []res{} + + err = ind.ClearWatchedAddresses() + require.NoError(t, err) + + rows := []res{} + err = db.Select(context.Background(), &rows, pgStr) + if err != nil { + t.Fatal(err) + } + + expectTrue(t, len(rows) == len(expectedData)) + for idx, row := range rows { + test_helpers.ExpectEqual(t, row, expectedData[idx]) + } + }) + + t.Run("Clear watched addresses (empty table)", func(t *testing.T) { + expectedData := []res{} + + err = ind.ClearWatchedAddresses() + require.NoError(t, err) + + rows := []res{} + err = db.Select(context.Background(), &rows, pgStr) + if err != nil { + t.Fatal(err) + } + + expectTrue(t, len(rows) == len(expectedData)) + for idx, row := range rows { + test_helpers.ExpectEqual(t, row, expectedData[idx]) + } + }) +} diff --git a/statediff/indexer/database/sql/test_helpers.go b/statediff/indexer/database/sql/test_helpers.go index b1032f8ffc94..398258a0ee7c 100644 --- a/statediff/indexer/database/sql/test_helpers.go +++ b/statediff/indexer/database/sql/test_helpers.go @@ -73,6 +73,10 @@ func TearDownDB(t *testing.T, db Database) { if err != nil { t.Fatal(err) } + _, err = tx.Exec(ctx, `DELETE FROM eth_meta.watched_addresses`) + if err != nil { + t.Fatal(err) + } err = tx.Commit(ctx) if err != nil { t.Fatal(err) diff --git a/statediff/indexer/interfaces/interfaces.go b/statediff/indexer/interfaces/interfaces.go index 8f951230d7e7..6910e3f4962a 100644 --- a/statediff/indexer/interfaces/interfaces.go +++ b/statediff/indexer/interfaces/interfaces.go @@ -21,6 +21,7 @@ import ( "math/big" "time" + "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/core/types" "github.com/ethereum/go-ethereum/statediff/indexer/shared" sdtypes "github.com/ethereum/go-ethereum/statediff/types" @@ -32,6 +33,14 @@ type StateDiffIndexer interface { PushStateNode(tx Batch, stateNode sdtypes.StateNode, headerID string) error PushCodeAndCodeHash(tx Batch, codeAndCodeHash sdtypes.CodeAndCodeHash) error ReportDBMetrics(delay time.Duration, quit <-chan bool) + + // Methods used by WatchAddress API/functionality + LoadWatchedAddresses() ([]common.Address, error) + InsertWatchedAddresses(addresses []sdtypes.WatchAddressArg, currentBlock *big.Int) error + RemoveWatchedAddresses(addresses []sdtypes.WatchAddressArg) error + SetWatchedAddresses(args []sdtypes.WatchAddressArg, currentBlockNumber *big.Int) error + ClearWatchedAddresses() error + io.Closer } diff --git a/statediff/service.go b/statediff/service.go index 960f776f8ec1..2a6c3d45c124 100644 --- a/statediff/service.go +++ b/statediff/service.go @@ -18,6 +18,7 @@ package statediff import ( "bytes" + "fmt" "math/big" "strconv" "strings" @@ -47,28 +48,34 @@ import ( "github.com/ethereum/go-ethereum/statediff/indexer/shared" types2 "github.com/ethereum/go-ethereum/statediff/types" "github.com/ethereum/go-ethereum/trie" + "github.com/thoas/go-funk" ) const ( - chainEventChanSize = 20000 - genesisBlockNumber = 0 - defaultRetryLimit = 3 // default retry limit once deadlock is detected. 
- deadlockDetected = "deadlock detected" // 40P01 https://www.postgresql.org/docs/current/errcodes-appendix.html + chainEventChanSize = 20000 + genesisBlockNumber = 0 + defaultRetryLimit = 3 // default retry limit once deadlock is detected. + deadlockDetected = "deadlock detected" // 40P01 https://www.postgresql.org/docs/current/errcodes-appendix.html + typeAssertionFailed = "type assertion failed" + unexpectedOperation = "unexpected operation" ) -var writeLoopParams = Params{ - IntermediateStateNodes: true, - IntermediateStorageNodes: true, - IncludeBlock: true, - IncludeReceipts: true, - IncludeTD: true, - IncludeCode: true, +var writeLoopParams = ParamsWithMutex{ + Params: Params{ + IntermediateStateNodes: true, + IntermediateStorageNodes: true, + IncludeBlock: true, + IncludeReceipts: true, + IncludeTD: true, + IncludeCode: true, + }, } var statediffMetrics = RegisterStatediffMetrics(metrics.DefaultRegistry) type blockChain interface { SubscribeChainEvent(ch chan<- core.ChainEvent) event.Subscription + CurrentBlock() *types.Block GetBlockByHash(hash common.Hash) *types.Block GetBlockByNumber(number uint64) *types.Block GetReceiptsByHash(hash common.Hash) types.Receipts @@ -103,6 +110,8 @@ type IService interface { WriteStateDiffFor(blockHash common.Hash, params Params) error // WriteLoop event loop for progressively processing and writing diffs directly to DB WriteLoop(chainEventCh chan core.ChainEvent) + // Method to change the addresses being watched in write loop params + WatchAddress(operation types2.OperationType, args []types2.WatchAddressArg) error } // Service is the underlying struct for the state diffing service @@ -159,6 +168,7 @@ func New(stack *node.Node, ethServ *eth.Ethereum, cfg *ethconfig.Config, params blockChain := ethServ.BlockChain() var indexer interfaces.StateDiffIndexer var db sql.Database + var err error quitCh := make(chan bool) if params.IndexerConfig != nil { info := nodeinfo.Info{ @@ -215,6 +225,12 @@ func New(stack *node.Node, ethServ *eth.Ethereum, cfg *ethconfig.Config, params } stack.RegisterLifecycle(sds) stack.RegisterAPIs(sds.APIs()) + + err = loadWatchedAddresses(indexer) + if err != nil { + return err + } + return nil } @@ -304,7 +320,9 @@ func (sds *Service) WriteLoop(chainEventCh chan core.ChainEvent) { func (sds *Service) writeGenesisStateDiff(currBlock *types.Block, workerId uint) { // For genesis block we need to return the entire state trie hence we diff it with an empty trie. 
log.Info("Writing state diff", "block height", genesisBlockNumber, "worker", workerId) - err := sds.writeStateDiffWithRetry(currBlock, common.Hash{}, writeLoopParams) + writeLoopParams.RLock() + err := sds.writeStateDiffWithRetry(currBlock, common.Hash{}, writeLoopParams.Params) + writeLoopParams.RUnlock() if err != nil { log.Error("statediff.Service.WriteLoop: processing error", "block height", genesisBlockNumber, "error", err.Error(), "worker", workerId) @@ -341,7 +359,9 @@ func (sds *Service) writeLoopWorker(params workerParams) { } log.Info("Writing state diff", "block height", currentBlock.Number().Uint64(), "worker", params.id) - err := sds.writeStateDiffWithRetry(currentBlock, parentBlock.Root(), writeLoopParams) + writeLoopParams.RLock() + err := sds.writeStateDiffWithRetry(currentBlock, parentBlock.Root(), writeLoopParams.Params) + writeLoopParams.RUnlock() if err != nil { log.Error("statediff.Service.WriteLoop: processing error", "block height", currentBlock.Number().Uint64(), "error", err.Error(), "worker", params.id) sds.KnownGaps.errorState = true @@ -456,6 +476,10 @@ func (sds *Service) streamStateDiff(currentBlock *types.Block, parentRoot common func (sds *Service) StateDiffAt(blockNumber uint64, params Params) (*Payload, error) { currentBlock := sds.BlockChain.GetBlockByNumber(blockNumber) log.Info("sending state diff", "block height", blockNumber) + + // compute leaf keys of watched addresses in the params + params.ComputeWatchedAddressesLeafKeys() + if blockNumber == 0 { return sds.processStateDiff(currentBlock, common.Hash{}, params) } @@ -468,6 +492,10 @@ func (sds *Service) StateDiffAt(blockNumber uint64, params Params) (*Payload, er func (sds *Service) StateDiffFor(blockHash common.Hash, params Params) (*Payload, error) { currentBlock := sds.BlockChain.GetBlockByHash(blockHash) log.Info("sending state diff", "block hash", blockHash) + + // compute leaf keys of watched addresses in the params + params.ComputeWatchedAddressesLeafKeys() + if currentBlock.NumberU64() == 0 { return sds.processStateDiff(currentBlock, common.Hash{}, params) } @@ -526,6 +554,10 @@ func (sds *Service) newPayload(stateObject []byte, block *types.Block, params Pa func (sds *Service) StateTrieAt(blockNumber uint64, params Params) (*Payload, error) { currentBlock := sds.BlockChain.GetBlockByNumber(blockNumber) log.Info("sending state trie", "block height", blockNumber) + + // compute leaf keys of watched addresses in the params + params.ComputeWatchedAddressesLeafKeys() + return sds.processStateTrie(currentBlock, params) } @@ -548,6 +580,10 @@ func (sds *Service) Subscribe(id rpc.ID, sub chan<- Payload, quitChan chan<- boo if atomic.CompareAndSwapInt32(&sds.subscribers, 0, 1) { log.Info("State diffing subscription received; beginning statediff processing") } + + // compute leaf keys of watched addresses in the params + params.ComputeWatchedAddressesLeafKeys() + // Subscription type is defined as the hash of the rlp-serialized subscription params by, err := rlp.EncodeToBytes(params) if err != nil { @@ -644,7 +680,7 @@ func (sds *Service) Start() error { go sds.Loop(chainEventCh) if sds.enableWriteLoop { - log.Info("Starting statediff DB write loop", "params", writeLoopParams) + log.Info("Starting statediff DB write loop", "params", writeLoopParams.Params) chainEventCh := make(chan core.ChainEvent, chainEventChanSize) go sds.WriteLoop(chainEventCh) } @@ -741,6 +777,9 @@ func (sds *Service) StreamCodeAndCodeHash(blockNumber uint64, outChan chan<- typ // This operation cannot be performed back past 
the point of db pruning; it requires an archival node // for historical data func (sds *Service) WriteStateDiffAt(blockNumber uint64, params Params) error { + // compute leaf keys of watched addresses in the params + params.ComputeWatchedAddressesLeafKeys() + currentBlock := sds.BlockChain.GetBlockByNumber(blockNumber) parentRoot := common.Hash{} if blockNumber != 0 { @@ -754,6 +793,9 @@ func (sds *Service) WriteStateDiffAt(blockNumber uint64, params Params) error { // This operation cannot be performed back past the point of db pruning; it requires an archival node // for historical data func (sds *Service) WriteStateDiffFor(blockHash common.Hash, params Params) error { + // compute leaf keys of watched addresses in the params + params.ComputeWatchedAddressesLeafKeys() + currentBlock := sds.BlockChain.GetBlockByHash(blockHash) parentRoot := common.Hash{} if currentBlock.NumberU64() != 0 { @@ -821,3 +863,130 @@ func (sds *Service) writeStateDiffWithRetry(block *types.Block, parentRoot commo } return err } + +// WatchAddress performs one of the following operations on the watched addresses in writeLoopParams and the db: +// add | remove | set | clear +func (sds *Service) WatchAddress(operation types2.OperationType, args []types2.WatchAddressArg) error { + // lock writeLoopParams for a write + writeLoopParams.Lock() + defer writeLoopParams.Unlock() + + // get the current block number + currentBlockNumber := sds.BlockChain.CurrentBlock().Number() + + switch operation { + case types2.Add: + // filter out args having an already watched address with a warning + filteredArgs, ok := funk.Filter(args, func(arg types2.WatchAddressArg) bool { + if funk.Contains(writeLoopParams.WatchedAddresses, common.HexToAddress(arg.Address)) { + log.Warn("Address already being watched", "address", arg.Address) + return false + } + return true + }).([]types2.WatchAddressArg) + if !ok { + return fmt.Errorf("add: filtered args %s", typeAssertionFailed) + } + + // get addresses from the filtered args + filteredAddresses, err := MapWatchAddressArgsToAddresses(filteredArgs) + if err != nil { + return fmt.Errorf("add: filtered addresses %s", err.Error()) + } + + // update the db + err = sds.indexer.InsertWatchedAddresses(filteredArgs, currentBlockNumber) + if err != nil { + return err + } + + // update in-memory params + writeLoopParams.WatchedAddresses = append(writeLoopParams.WatchedAddresses, filteredAddresses...)
+ funk.ForEach(filteredAddresses, func(address common.Address) { + writeLoopParams.watchedAddressesLeafKeys[crypto.Keccak256Hash(address.Bytes())] = struct{}{} + }) + case types2.Remove: + // get addresses from args + argAddresses, err := MapWatchAddressArgsToAddresses(args) + if err != nil { + return fmt.Errorf("remove: mapped addresses %s", err.Error()) + } + + // remove the provided addresses from currently watched addresses + addresses, ok := funk.Subtract(writeLoopParams.WatchedAddresses, argAddresses).([]common.Address) + if !ok { + return fmt.Errorf("remove: filtered addresses %s", typeAssertionFailed) + } + + // update the db + err = sds.indexer.RemoveWatchedAddresses(args) + if err != nil { + return err + } + + // update in-memory params + writeLoopParams.WatchedAddresses = addresses + funk.ForEach(argAddresses, func(address common.Address) { + delete(writeLoopParams.watchedAddressesLeafKeys, crypto.Keccak256Hash(address.Bytes())) + }) + case types2.Set: + // get addresses from args + argAddresses, err := MapWatchAddressArgsToAddresses(args) + if err != nil { + return fmt.Errorf("set: mapped addresses %s", err.Error()) + } + + // update the db + err = sds.indexer.SetWatchedAddresses(args, currentBlockNumber) + if err != nil { + return err + } + + // update in-memory params + writeLoopParams.WatchedAddresses = argAddresses + writeLoopParams.ComputeWatchedAddressesLeafKeys() + case types2.Clear: + // update the db + err := sds.indexer.ClearWatchedAddresses() + if err != nil { + return err + } + + // update in-memory params + writeLoopParams.WatchedAddresses = []common.Address{} + writeLoopParams.ComputeWatchedAddressesLeafKeys() + + default: + return fmt.Errorf("%s %s", unexpectedOperation, operation) + } + + return nil +} + +// loadWatchedAddresses loads watched addresses to in-memory write loop params +func loadWatchedAddresses(indexer interfaces.StateDiffIndexer) error { + watchedAddresses, err := indexer.LoadWatchedAddresses() + if err != nil { + return err + } + + writeLoopParams.Lock() + defer writeLoopParams.Unlock() + + writeLoopParams.WatchedAddresses = watchedAddresses + writeLoopParams.ComputeWatchedAddressesLeafKeys() + + return nil +} + +// MapWatchAddressArgsToAddresses maps []WatchAddressArg to corresponding []common.Address +func MapWatchAddressArgsToAddresses(args []types2.WatchAddressArg) ([]common.Address, error) { + addresses, ok := funk.Map(args, func(arg types2.WatchAddressArg) common.Address { + return common.HexToAddress(arg.Address) + }).([]common.Address) + if !ok { + return nil, fmt.Errorf(typeAssertionFailed) + } + + return addresses, nil +} diff --git a/statediff/service_test.go b/statediff/service_test.go index 96be2da1bb5c..987e1b467f65 100644 --- a/statediff/service_test.go +++ b/statediff/service_test.go @@ -146,6 +146,7 @@ func testErrorInChainEventLoop(t *testing.T) { } } + defaultParams.ComputeWatchedAddressesLeafKeys() if !reflect.DeepEqual(builder.Params, defaultParams) { t.Error("Test failure:", t.Name()) t.Logf("Actual params does not equal expected.\nactual:%+v\nexpected: %+v", builder.Params, defaultParams) @@ -197,6 +198,8 @@ func testErrorInBlockLoop(t *testing.T) { } }() service.Loop(eventsChannel) + + defaultParams.ComputeWatchedAddressesLeafKeys() if !reflect.DeepEqual(builder.Params, defaultParams) { t.Error("Test failure:", t.Name()) t.Logf("Actual params does not equal expected.\nactual:%+v\nexpected: %+v", builder.Params, defaultParams) @@ -270,6 +273,8 @@ func testErrorInStateDiffAt(t *testing.T) { if err != nil { t.Error(err) } + 
+ + defaultParams.ComputeWatchedAddressesLeafKeys() + if !reflect.DeepEqual(builder.Params, defaultParams) { t.Error("Test failure:", t.Name()) t.Logf("Actual params does not equal expected.\nactual:%+v\nexpected: %+v", builder.Params, defaultParams) diff --git a/statediff/test_helpers/mocks/blockchain.go b/statediff/test_helpers/mocks/blockchain.go index b4b1f36942a4..f2a77af64ba5 100644 --- a/statediff/test_helpers/mocks/blockchain.go +++ b/statediff/test_helpers/mocks/blockchain.go @@ -39,6 +39,7 @@ type BlockChain struct { Receipts map[common.Hash]types.Receipts TDByHash map[common.Hash]*big.Int TDByNum map[uint64]*big.Int + currentBlock *types.Block } // SetBlocksForHashes mock method @@ -128,6 +129,16 @@ func (bc *BlockChain) GetTd(hash common.Hash, blockNum uint64) *big.Int { return nil } +// SetCurrentBlock test method +func (bc *BlockChain) SetCurrentBlock(block *types.Block) { + bc.currentBlock = block +} + +// CurrentBlock mock method +func (bc *BlockChain) CurrentBlock() *types.Block { + return bc.currentBlock +} + func (bc *BlockChain) SetTd(hash common.Hash, blockNum uint64, td *big.Int) { if bc.TDByHash == nil { bc.TDByHash = make(map[common.Hash]*big.Int) diff --git a/statediff/test_helpers/mocks/indexer.go b/statediff/test_helpers/mocks/indexer.go new file mode 100644 index 000000000000..92005a8b4c4d --- /dev/null +++ b/statediff/test_helpers/mocks/indexer.go @@ -0,0 +1,70 @@ +// Copyright 2022 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see http://www.gnu.org/licenses/.
+ +package mocks + +import ( + "math/big" + "time" + + "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/core/types" + "github.com/ethereum/go-ethereum/statediff/indexer/interfaces" + sdtypes "github.com/ethereum/go-ethereum/statediff/types" +) + +var _ interfaces.StateDiffIndexer = &StateDiffIndexer{} + +// StateDiffIndexer is a mock state diff indexer +type StateDiffIndexer struct{} + +func (sdi *StateDiffIndexer) PushBlock(block *types.Block, receipts types.Receipts, totalDifficulty *big.Int) (interfaces.Batch, error) { + return nil, nil +} + +func (sdi *StateDiffIndexer) PushStateNode(tx interfaces.Batch, stateNode sdtypes.StateNode, headerID string) error { + return nil +} + +func (sdi *StateDiffIndexer) PushCodeAndCodeHash(tx interfaces.Batch, codeAndCodeHash sdtypes.CodeAndCodeHash) error { + return nil +} + +func (sdi *StateDiffIndexer) ReportDBMetrics(delay time.Duration, quit <-chan bool) {} + +func (sdi *StateDiffIndexer) LoadWatchedAddresses() ([]common.Address, error) { + return nil, nil +} + +func (sdi *StateDiffIndexer) InsertWatchedAddresses(addresses []sdtypes.WatchAddressArg, currentBlock *big.Int) error { + return nil +} + +func (sdi *StateDiffIndexer) RemoveWatchedAddresses(addresses []sdtypes.WatchAddressArg) error { + return nil +} + +func (sdi *StateDiffIndexer) SetWatchedAddresses(args []sdtypes.WatchAddressArg, currentBlockNumber *big.Int) error { + return nil +} + +func (sdi *StateDiffIndexer) ClearWatchedAddresses() error { + return nil +} + +func (sdi *StateDiffIndexer) Close() error { + return nil +} diff --git a/statediff/test_helpers/mocks/service.go b/statediff/test_helpers/mocks/service.go index f10017df43f3..1ff6857ddebb 100644 --- a/statediff/test_helpers/mocks/service.go +++ b/statediff/test_helpers/mocks/service.go @@ -25,6 +25,7 @@ import ( "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/crypto" "github.com/ethereum/go-ethereum/rlp" + "github.com/thoas/go-funk" "github.com/ethereum/go-ethereum/core" "github.com/ethereum/go-ethereum/core/types" @@ -32,9 +33,15 @@ import ( "github.com/ethereum/go-ethereum/p2p" "github.com/ethereum/go-ethereum/rpc" "github.com/ethereum/go-ethereum/statediff" + "github.com/ethereum/go-ethereum/statediff/indexer/interfaces" sdtypes "github.com/ethereum/go-ethereum/statediff/types" ) + +var ( + typeAssertionFailed = "type assertion failed" + unexpectedOperation = "unexpected operation" +) + // MockStateDiffService is a mock state diff service type MockStateDiffService struct { sync.Mutex @@ -47,6 +54,8 @@ type MockStateDiffService struct { QuitChan chan bool Subscriptions map[common.Hash]map[rpc.ID]statediff.Subscription SubscriptionTypes map[common.Hash]statediff.Params + Indexer interfaces.StateDiffIndexer + writeLoopParams statediff.ParamsWithMutex } // Protocols mock method @@ -332,3 +341,98 @@ func sendNonBlockingQuit(id rpc.ID, sub statediff.Subscription) { log.Info("unable to close subscription %s; channel has no receiver", id) } } + +// WatchAddress performs one of the following operations on the watched addresses in writeLoopParams and the db: +// add | remove | set | clear +func (sds *MockStateDiffService) WatchAddress(operation sdtypes.OperationType, args []sdtypes.WatchAddressArg) error { + // lock writeLoopParams for a write + sds.writeLoopParams.Lock() + defer sds.writeLoopParams.Unlock() + + // get the current block number + currentBlockNumber := sds.BlockChain.CurrentBlock().Number() + + switch operation { + case
sdtypes.Add: + // filter out args having an already watched address with a warning + filteredArgs, ok := funk.Filter(args, func(arg sdtypes.WatchAddressArg) bool { + if funk.Contains(sds.writeLoopParams.WatchedAddresses, common.HexToAddress(arg.Address)) { + log.Warn("Address already being watched", "address", arg.Address) + return false + } + return true + }).([]sdtypes.WatchAddressArg) + if !ok { + return fmt.Errorf("add: filtered args %s", typeAssertionFailed) + } + + // get addresses from the filtered args + filteredAddresses, err := statediff.MapWatchAddressArgsToAddresses(filteredArgs) + if err != nil { + return fmt.Errorf("add: filtered addresses %s", err.Error()) + } + + // update the db + err = sds.Indexer.InsertWatchedAddresses(filteredArgs, currentBlockNumber) + if err != nil { + return err + } + + // update in-memory params + sds.writeLoopParams.WatchedAddresses = append(sds.writeLoopParams.WatchedAddresses, filteredAddresses...) + sds.writeLoopParams.ComputeWatchedAddressesLeafKeys() + case sdtypes.Remove: + // get addresses from args + argAddresses, err := statediff.MapWatchAddressArgsToAddresses(args) + if err != nil { + return fmt.Errorf("remove: mapped addresses %s", err.Error()) + } + + // remove the provided addresses from currently watched addresses + addresses, ok := funk.Subtract(sds.writeLoopParams.WatchedAddresses, argAddresses).([]common.Address) + if !ok { + return fmt.Errorf("remove: filtered addresses %s", typeAssertionFailed) + } + + // update the db + err = sds.Indexer.RemoveWatchedAddresses(args) + if err != nil { + return err + } + + // update in-memory params + sds.writeLoopParams.WatchedAddresses = addresses + sds.writeLoopParams.ComputeWatchedAddressesLeafKeys() + case sdtypes.Set: + // get addresses from args + argAddresses, err := statediff.MapWatchAddressArgsToAddresses(args) + if err != nil { + return fmt.Errorf("set: mapped addresses %s", err.Error()) + } + + // update the db + err = sds.Indexer.SetWatchedAddresses(args, currentBlockNumber) + if err != nil { + return err + } + + // update in-memory params + sds.writeLoopParams.WatchedAddresses = argAddresses + sds.writeLoopParams.ComputeWatchedAddressesLeafKeys() + case sdtypes.Clear: + // update the db + err := sds.Indexer.ClearWatchedAddresses() + if err != nil { + return err + } + + // update in-memory params + sds.writeLoopParams.WatchedAddresses = []common.Address{} + sds.writeLoopParams.ComputeWatchedAddressesLeafKeys() + + default: + return fmt.Errorf("%s %s", unexpectedOperation, operation) + } + + return nil +} diff --git a/statediff/test_helpers/mocks/service_test.go b/statediff/test_helpers/mocks/service_test.go index c236e1fd1a2f..a638f14242fa 100644 --- a/statediff/test_helpers/mocks/service_test.go +++ b/statediff/test_helpers/mocks/service_test.go @@ -21,6 +21,7 @@ import ( "fmt" "math/big" "os" + "reflect" "sort" "sync" "testing" @@ -88,6 +89,7 @@ func init() { func TestAPI(t *testing.T) { testSubscriptionAPI(t) testHTTPAPI(t) + testWatchAddressAPI(t) } func testSubscriptionAPI(t *testing.T) { @@ -253,3 +255,286 @@ func testHTTPAPI(t *testing.T) { t.Errorf("paylaod does not have the expected total difficulty\r\nactual td: %d\r\nexpected td: %d", payload.TotalDifficulty.Int64(), mockTotalDifficulty.Int64()) } } + +func testWatchAddressAPI(t *testing.T) { + blocks, chain := test_helpers.MakeChain(6, test_helpers.Genesis, test_helpers.TestChainGen) + defer chain.Stop() + block6 := blocks[5] + + mockBlockChain := &BlockChain{} + mockBlockChain.SetCurrentBlock(block6) + mockIndexer := 
StateDiffIndexer{} + mockService := MockStateDiffService{ + BlockChain: mockBlockChain, + Indexer: &mockIndexer, + } + + // test data + var ( + contract1Address = "0x5d663F5269090bD2A7DC2390c911dF6083D7b28F" + contract2Address = "0x6Eb7e5C66DB8af2E96159AC440cbc8CDB7fbD26B" + contract3Address = "0xcfeB164C328CA13EFd3C77E1980d94975aDfedfc" + contract4Address = "0x0Edf0c4f393a628DE4828B228C48175b3EA297fc" + contract1CreatedAt = uint64(1) + contract2CreatedAt = uint64(2) + contract3CreatedAt = uint64(3) + contract4CreatedAt = uint64(4) + + args1 = []sdtypes.WatchAddressArg{ + { + Address: contract1Address, + CreatedAt: contract1CreatedAt, + }, + { + Address: contract2Address, + CreatedAt: contract2CreatedAt, + }, + } + startingParams1 = statediff.Params{ + WatchedAddresses: []common.Address{}, + } + expectedParams1 = statediff.Params{ + WatchedAddresses: []common.Address{ + common.HexToAddress(contract1Address), + common.HexToAddress(contract2Address), + }, + } + + args2 = []sdtypes.WatchAddressArg{ + { + Address: contract3Address, + CreatedAt: contract3CreatedAt, + }, + { + Address: contract2Address, + CreatedAt: contract2CreatedAt, + }, + } + startingParams2 = expectedParams1 + expectedParams2 = statediff.Params{ + WatchedAddresses: []common.Address{ + common.HexToAddress(contract1Address), + common.HexToAddress(contract2Address), + common.HexToAddress(contract3Address), + }, + } + + args3 = []sdtypes.WatchAddressArg{ + { + Address: contract3Address, + CreatedAt: contract3CreatedAt, + }, + { + Address: contract2Address, + CreatedAt: contract2CreatedAt, + }, + } + startingParams3 = expectedParams2 + expectedParams3 = statediff.Params{ + WatchedAddresses: []common.Address{ + common.HexToAddress(contract1Address), + }, + } + + args4 = []sdtypes.WatchAddressArg{ + { + Address: contract1Address, + CreatedAt: contract1CreatedAt, + }, + { + Address: contract2Address, + CreatedAt: contract2CreatedAt, + }, + } + startingParams4 = expectedParams3 + expectedParams4 = statediff.Params{ + WatchedAddresses: []common.Address{}, + } + + args5 = []sdtypes.WatchAddressArg{ + { + Address: contract1Address, + CreatedAt: contract1CreatedAt, + }, + { + Address: contract2Address, + CreatedAt: contract2CreatedAt, + }, + { + Address: contract3Address, + CreatedAt: contract3CreatedAt, + }, + } + startingParams5 = expectedParams4 + expectedParams5 = statediff.Params{ + WatchedAddresses: []common.Address{ + common.HexToAddress(contract1Address), + common.HexToAddress(contract2Address), + common.HexToAddress(contract3Address), + }, + } + + args6 = []sdtypes.WatchAddressArg{ + { + Address: contract4Address, + CreatedAt: contract4CreatedAt, + }, + { + Address: contract2Address, + CreatedAt: contract2CreatedAt, + }, + { + Address: contract3Address, + CreatedAt: contract3CreatedAt, + }, + } + startingParams6 = expectedParams5 + expectedParams6 = statediff.Params{ + WatchedAddresses: []common.Address{ + common.HexToAddress(contract4Address), + common.HexToAddress(contract2Address), + common.HexToAddress(contract3Address), + }, + } + + args7 = []sdtypes.WatchAddressArg{} + startingParams7 = expectedParams6 + expectedParams7 = statediff.Params{ + WatchedAddresses: []common.Address{}, + } + + args8 = []sdtypes.WatchAddressArg{} + startingParams8 = expectedParams6 + expectedParams8 = statediff.Params{ + WatchedAddresses: []common.Address{}, + } + + args9 = []sdtypes.WatchAddressArg{} + startingParams9 = expectedParams8 + expectedParams9 = statediff.Params{ + WatchedAddresses: []common.Address{}, + } + ) + + tests := []struct { + 
name string + operation sdtypes.OperationType + args []sdtypes.WatchAddressArg + startingParams statediff.Params + expectedParams statediff.Params + expectedErr error + }{ + { + "testAddAddresses", + sdtypes.Add, + args1, + startingParams1, + expectedParams1, + nil, + }, + { + "testAddAddressesSomeWatched", + sdtypes.Add, + args2, + startingParams2, + expectedParams2, + nil, + }, + { + "testRemoveAddresses", + sdtypes.Remove, + args3, + startingParams3, + expectedParams3, + nil, + }, + { + "testRemoveAddressesSomeWatched", + sdtypes.Remove, + args4, + startingParams4, + expectedParams4, + nil, + }, + { + "testSetAddresses", + sdtypes.Set, + args5, + startingParams5, + expectedParams5, + nil, + }, + { + "testSetAddressesSomeWatched", + sdtypes.Set, + args6, + startingParams6, + expectedParams6, + nil, + }, + { + "testSetAddressesEmptyArgs", + sdtypes.Set, + args7, + startingParams7, + expectedParams7, + nil, + }, + { + "testClearAddresses", + sdtypes.Clear, + args8, + startingParams8, + expectedParams8, + nil, + }, + { + "testClearAddressesEmpty", + sdtypes.Clear, + args9, + startingParams9, + expectedParams9, + nil, + }, + + // invalid args + { + "testInvalidOperation", + "WrongOp", + args9, + startingParams9, + statediff.Params{}, + fmt.Errorf("%s WrongOp", unexpectedOperation), + }, + } + + for _, test := range tests { + // set indexing params + mockService.writeLoopParams = statediff.ParamsWithMutex{ + Params: test.startingParams, + } + mockService.writeLoopParams.ComputeWatchedAddressesLeafKeys() + + // make the API call to change watched addresses + err := mockService.WatchAddress(test.operation, test.args) + if test.expectedErr != nil { + if err.Error() != test.expectedErr.Error() { + t.Logf("Test failed: %s", test.name) + t.Errorf("actual err: %+v\nexpected err: %+v", err, test.expectedErr) + } + + continue + } + if err != nil { + t.Error(err) + } + + // check updated indexing params + test.expectedParams.ComputeWatchedAddressesLeafKeys() + updatedParams := mockService.writeLoopParams.Params + if !reflect.DeepEqual(updatedParams, test.expectedParams) { + t.Logf("Test failed: %s", test.name) + t.Errorf("actual params: %+v\nexpected params: %+v", updatedParams, test.expectedParams) + } + } +} diff --git a/statediff/types/types.go b/statediff/types/types.go index 36008a784063..0a29adaf892b 100644 --- a/statediff/types/types.go +++ b/statediff/types/types.go @@ -101,3 +101,20 @@ type CodeAndCodeHash struct { type StateNodeSink func(StateNode) error type StorageNodeSink func(StorageNode) error type CodeSink func(CodeAndCodeHash) error + +// OperationType defines the type of a WatchAddress operation +type OperationType string + +const ( + Add OperationType = "add" + Remove OperationType = "remove" + Set OperationType = "set" + Clear OperationType = "clear" +) + +// WatchAddressArg is an arg type for the WatchAddress API +type WatchAddressArg struct { + // Address represents common.Address + Address string + CreatedAt uint64 +}
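
For illustration only (not part of the patch): a minimal sketch of how the watched-address API introduced above might be driven from Go code. It assumes a *statediff.Service instance (here sds) has already been constructed and registered, and it reuses the example contract addresses from the test fixtures; the helper name addExampleWatches is hypothetical.

package example

import (
	"github.com/ethereum/go-ethereum/statediff"
	sdtypes "github.com/ethereum/go-ethereum/statediff/types"
)

// addExampleWatches asks the statediff service to start watching two contracts.
// On success the addresses are persisted via the configured indexer and the
// in-memory write loop params are updated under their mutex.
func addExampleWatches(sds *statediff.Service) error {
	args := []sdtypes.WatchAddressArg{
		{Address: "0x5d663F5269090bD2A7DC2390c911dF6083D7b28F", CreatedAt: 1},
		{Address: "0x6Eb7e5C66DB8af2E96159AC440cbc8CDB7fbD26B", CreatedAt: 2},
	}
	// sdtypes.Add filters out already-watched addresses with a warning, as implemented in Service.WatchAddress.
	return sds.WatchAddress(sdtypes.Add, args)
}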