Skip to content

Commit

Permalink
Add tests for the API to change addresses being watched
Browse files Browse the repository at this point in the history
  • Loading branch information
prathamesh0 committed Apr 1, 2022
1 parent af4dbed commit 7f38afe
Show file tree
Hide file tree
Showing 6 changed files with 480 additions and 5 deletions.
10 changes: 5 additions & 5 deletions statediff/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -889,7 +889,7 @@ func (sds *Service) WatchAddress(operation types2.OperationType, args []types2.W
}

// get addresses from the filtered args
filteredAddresses, err := mapWatchAddressArgsToAddresses(filteredArgs)
filteredAddresses, err := MapWatchAddressArgsToAddresses(filteredArgs)
if err != nil {
return fmt.Errorf("add: filtered addresses %s", err.Error())
}
Expand All @@ -907,7 +907,7 @@ func (sds *Service) WatchAddress(operation types2.OperationType, args []types2.W
})
case types2.Remove:
// get addresses from args
argAddresses, err := mapWatchAddressArgsToAddresses(args)
argAddresses, err := MapWatchAddressArgsToAddresses(args)
if err != nil {
return fmt.Errorf("remove: mapped addresses %s", err.Error())
}
Expand All @@ -931,7 +931,7 @@ func (sds *Service) WatchAddress(operation types2.OperationType, args []types2.W
})
case types2.Set:
// get addresses from args
argAddresses, err := mapWatchAddressArgsToAddresses(args)
argAddresses, err := MapWatchAddressArgsToAddresses(args)
if err != nil {
return fmt.Errorf("set: mapped addresses %s", err.Error())
}
Expand Down Expand Up @@ -979,8 +979,8 @@ func loadWatchedAddresses(indexer interfaces.StateDiffIndexer) error {
return nil
}

// mapWatchAddressArgsToAddresses maps []WatchAddressArg to corresponding []common.Address
func mapWatchAddressArgsToAddresses(args []types2.WatchAddressArg) ([]common.Address, error) {
// MapWatchAddressArgsToAddresses maps []WatchAddressArg to corresponding []common.Address
func MapWatchAddressArgsToAddresses(args []types2.WatchAddressArg) ([]common.Address, error) {
addresses, ok := funk.Map(args, func(arg types2.WatchAddressArg) common.Address {
return common.HexToAddress(arg.Address)
}).([]common.Address)
Expand Down
5 changes: 5 additions & 0 deletions statediff/service_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,7 @@ func testErrorInChainEventLoop(t *testing.T) {
}
}

defaultParams.ComputeWatchedAddressesLeafKeys()
if !reflect.DeepEqual(builder.Params, defaultParams) {
t.Error("Test failure:", t.Name())
t.Logf("Actual params does not equal expected.\nactual:%+v\nexpected: %+v", builder.Params, defaultParams)
Expand Down Expand Up @@ -197,6 +198,8 @@ func testErrorInBlockLoop(t *testing.T) {
}
}()
service.Loop(eventsChannel)

defaultParams.ComputeWatchedAddressesLeafKeys()
if !reflect.DeepEqual(builder.Params, defaultParams) {
t.Error("Test failure:", t.Name())
t.Logf("Actual params does not equal expected.\nactual:%+v\nexpected: %+v", builder.Params, defaultParams)
Expand Down Expand Up @@ -270,6 +273,8 @@ func testErrorInStateDiffAt(t *testing.T) {
if err != nil {
t.Error(err)
}

defaultParams.ComputeWatchedAddressesLeafKeys()
if !reflect.DeepEqual(builder.Params, defaultParams) {
t.Error("Test failure:", t.Name())
t.Logf("Actual params does not equal expected.\nactual:%+v\nexpected: %+v", builder.Params, defaultParams)
Expand Down
11 changes: 11 additions & 0 deletions statediff/test_helpers/mocks/blockchain.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ type BlockChain struct {
Receipts map[common.Hash]types.Receipts
TDByHash map[common.Hash]*big.Int
TDByNum map[uint64]*big.Int
currentBlock *types.Block
}

// SetBlocksForHashes mock method
Expand Down Expand Up @@ -128,6 +129,16 @@ func (bc *BlockChain) GetTd(hash common.Hash, blockNum uint64) *big.Int {
return nil
}

// SetCurrentBlock test method
func (bc *BlockChain) SetCurrentBlock(block *types.Block) {
bc.currentBlock = block
}

// CurrentBlock mock method
func (bc *BlockChain) CurrentBlock() *types.Block {
return bc.currentBlock
}

func (bc *BlockChain) SetTd(hash common.Hash, blockNum uint64, td *big.Int) {
if bc.TDByHash == nil {
bc.TDByHash = make(map[common.Hash]*big.Int)
Expand Down
70 changes: 70 additions & 0 deletions statediff/test_helpers/mocks/indexer.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
// Copyright 2022 The go-ethereum Authors
// This file is part of the go-ethereum library.
//
// The go-ethereum library is free software: you can redistribute it and/or modify
// it under the terms of the GNU Lesser General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// The go-ethereum library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Lesser General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public License
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.

package mocks

import (
"math/big"
"time"

"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/core/types"
"github.com/ethereum/go-ethereum/statediff/indexer/interfaces"
sdtypes "github.com/ethereum/go-ethereum/statediff/types"
)

var _ interfaces.StateDiffIndexer = &StateDiffIndexer{}

// StateDiffIndexer is a mock state diff indexer
type StateDiffIndexer struct{}

func (sdi *StateDiffIndexer) PushBlock(block *types.Block, receipts types.Receipts, totalDifficulty *big.Int) (interfaces.Batch, error) {
return nil, nil
}

func (sdi *StateDiffIndexer) PushStateNode(tx interfaces.Batch, stateNode sdtypes.StateNode, headerID string) error {
return nil
}

func (sdi *StateDiffIndexer) PushCodeAndCodeHash(tx interfaces.Batch, codeAndCodeHash sdtypes.CodeAndCodeHash) error {
return nil
}

func (sdi *StateDiffIndexer) ReportDBMetrics(delay time.Duration, quit <-chan bool) {}

func (sdi *StateDiffIndexer) LoadWatchedAddresses() ([]common.Address, error) {
return nil, nil
}

func (sdi *StateDiffIndexer) InsertWatchedAddresses(addresses []sdtypes.WatchAddressArg, currentBlock *big.Int) error {
return nil
}

func (sdi *StateDiffIndexer) RemoveWatchedAddresses(addresses []sdtypes.WatchAddressArg) error {
return nil
}

func (sdi *StateDiffIndexer) SetWatchedAddresses(args []sdtypes.WatchAddressArg, currentBlockNumber *big.Int) error {
return nil
}

func (sdi *StateDiffIndexer) ClearWatchedAddresses() error {
return nil
}

func (sdi *StateDiffIndexer) Close() error {
return nil
}
104 changes: 104 additions & 0 deletions statediff/test_helpers/mocks/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,16 +25,23 @@ import (
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/crypto"
"github.com/ethereum/go-ethereum/rlp"
"github.com/thoas/go-funk"

"github.com/ethereum/go-ethereum/core"
"github.com/ethereum/go-ethereum/core/types"
"github.com/ethereum/go-ethereum/log"
"github.com/ethereum/go-ethereum/p2p"
"github.com/ethereum/go-ethereum/rpc"
"github.com/ethereum/go-ethereum/statediff"
"github.com/ethereum/go-ethereum/statediff/indexer/interfaces"
sdtypes "github.com/ethereum/go-ethereum/statediff/types"
)

var (
typeAssertionFailed = "type assertion failed"
unexpectedOperation = "unexpected operation"
)

// MockStateDiffService is a mock state diff service
type MockStateDiffService struct {
sync.Mutex
Expand All @@ -47,6 +54,8 @@ type MockStateDiffService struct {
QuitChan chan bool
Subscriptions map[common.Hash]map[rpc.ID]statediff.Subscription
SubscriptionTypes map[common.Hash]statediff.Params
Indexer interfaces.StateDiffIndexer
writeLoopParams statediff.ParamsWithMutex
}

// Protocols mock method
Expand Down Expand Up @@ -332,3 +341,98 @@ func sendNonBlockingQuit(id rpc.ID, sub statediff.Subscription) {
log.Info("unable to close subscription %s; channel has no receiver", id)
}
}

// Performs one of following operations on the watched addresses in writeLoopParams and the db:
// add | remove | set | clear
func (sds *MockStateDiffService) WatchAddress(operation sdtypes.OperationType, args []sdtypes.WatchAddressArg) error {
// lock writeLoopParams for a write
sds.writeLoopParams.Lock()
defer sds.writeLoopParams.Unlock()

// get the current block number
currentBlockNumber := sds.BlockChain.CurrentBlock().Number()

switch operation {
case sdtypes.Add:
// filter out args having an already watched address with a warning
filteredArgs, ok := funk.Filter(args, func(arg sdtypes.WatchAddressArg) bool {
if funk.Contains(sds.writeLoopParams.WatchedAddresses, common.HexToAddress(arg.Address)) {
log.Warn("Address already being watched", "address", arg.Address)
return false
}
return true
}).([]sdtypes.WatchAddressArg)
if !ok {
return fmt.Errorf("add: filtered args %s", typeAssertionFailed)
}

// get addresses from the filtered args
filteredAddresses, err := statediff.MapWatchAddressArgsToAddresses(filteredArgs)
if err != nil {
return fmt.Errorf("add: filtered addresses %s", err.Error())
}

// update the db
err = sds.Indexer.InsertWatchedAddresses(filteredArgs, currentBlockNumber)
if err != nil {
return err
}

// update in-memory params
sds.writeLoopParams.WatchedAddresses = append(sds.writeLoopParams.WatchedAddresses, filteredAddresses...)
sds.writeLoopParams.ComputeWatchedAddressesLeafKeys()
case sdtypes.Remove:
// get addresses from args
argAddresses, err := statediff.MapWatchAddressArgsToAddresses(args)
if err != nil {
return fmt.Errorf("remove: mapped addresses %s", err.Error())
}

// remove the provided addresses from currently watched addresses
addresses, ok := funk.Subtract(sds.writeLoopParams.WatchedAddresses, argAddresses).([]common.Address)
if !ok {
return fmt.Errorf("remove: filtered addresses %s", typeAssertionFailed)
}

// update the db
err = sds.Indexer.RemoveWatchedAddresses(args)
if err != nil {
return err
}

// update in-memory params
sds.writeLoopParams.WatchedAddresses = addresses
sds.writeLoopParams.ComputeWatchedAddressesLeafKeys()
case sdtypes.Set:
// get addresses from args
argAddresses, err := statediff.MapWatchAddressArgsToAddresses(args)
if err != nil {
return fmt.Errorf("set: mapped addresses %s", err.Error())
}

// update the db
err = sds.Indexer.SetWatchedAddresses(args, currentBlockNumber)
if err != nil {
return err
}

// update in-memory params
sds.writeLoopParams.WatchedAddresses = argAddresses
sds.writeLoopParams.ComputeWatchedAddressesLeafKeys()
case sdtypes.Clear:
// update the db
err := sds.Indexer.ClearWatchedAddresses()
if err != nil {
return err
}

// update in-memory params
sds.writeLoopParams.WatchedAddresses = []common.Address{}
sds.writeLoopParams.ComputeWatchedAddressesLeafKeys()

default:
return fmt.Errorf("%s %s", unexpectedOperation, operation)
}

return nil
}
Loading

0 comments on commit 7f38afe

Please sign in to comment.