Skip to content

Commit

Permalink
polygon/heimdall: use generics to eliminate casts (#10371)
Browse files Browse the repository at this point in the history
  • Loading branch information
battlmonstr authored May 16, 2024
1 parent 5559f23 commit 3e93af6
Show file tree
Hide file tree
Showing 8 changed files with 137 additions and 158 deletions.
35 changes: 18 additions & 17 deletions polygon/heimdall/checkpoint.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,6 @@ import (
"github.com/ledgerwatch/erigon-lib/kv"
)

var _ Waypoint = Checkpoint{}

type CheckpointId uint64

// Checkpoint defines a response object type of bor checkpoint
Expand All @@ -20,50 +18,53 @@ type Checkpoint struct {
Fields WaypointFields
}

func (c Checkpoint) RawId() uint64 {
var _ Entity = &Checkpoint{}
var _ Waypoint = &Checkpoint{}

func (c *Checkpoint) RawId() uint64 {
return uint64(c.Id)
}

func (c Checkpoint) StartBlock() *big.Int {
func (c *Checkpoint) StartBlock() *big.Int {
return c.Fields.StartBlock
}

func (c Checkpoint) EndBlock() *big.Int {
func (c *Checkpoint) EndBlock() *big.Int {
return c.Fields.EndBlock
}

func (c Checkpoint) BlockNumRange() ClosedRange {
func (c *Checkpoint) BlockNumRange() ClosedRange {
return ClosedRange{
Start: c.StartBlock().Uint64(),
End: c.EndBlock().Uint64(),
}
}

func (c Checkpoint) RootHash() libcommon.Hash {
func (c *Checkpoint) RootHash() libcommon.Hash {
return c.Fields.RootHash
}

func (c Checkpoint) Timestamp() uint64 {
func (c *Checkpoint) Timestamp() uint64 {
return c.Fields.Timestamp
}

func (c Checkpoint) Length() uint64 {
func (c *Checkpoint) Length() uint64 {
return c.Fields.Length()
}

func (c Checkpoint) CmpRange(n uint64) int {
func (c *Checkpoint) CmpRange(n uint64) int {
return c.Fields.CmpRange(n)
}

func (m Checkpoint) String() string {
func (c *Checkpoint) String() string {
return fmt.Sprintf(
"Checkpoint {%v (%d:%d) %v %v %v}",
m.Fields.Proposer.String(),
m.Fields.StartBlock,
m.Fields.EndBlock,
m.Fields.RootHash.Hex(),
m.Fields.ChainID,
m.Fields.Timestamp,
c.Fields.Proposer.String(),
c.Fields.StartBlock,
c.Fields.EndBlock,
c.Fields.RootHash.Hex(),
c.Fields.ChainID,
c.Fields.Timestamp,
)
}

Expand Down
34 changes: 17 additions & 17 deletions polygon/heimdall/entity_fetcher.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,29 +10,29 @@ import (
"github.com/ledgerwatch/log/v3"
)

type entityFetcher interface {
type entityFetcher[TEntity Entity] interface {
FetchLastEntityId(ctx context.Context) (uint64, error)
FetchEntitiesRange(ctx context.Context, idRange ClosedRange) ([]Entity, error)
FetchEntitiesRange(ctx context.Context, idRange ClosedRange) ([]TEntity, error)
}

type entityFetcherImpl struct {
type entityFetcherImpl[TEntity Entity] struct {
name string

fetchLastEntityId func(ctx context.Context) (int64, error)
fetchEntity func(ctx context.Context, id int64) (Entity, error)
fetchEntitiesPage func(ctx context.Context, page uint64, limit uint64) ([]Entity, error)
fetchEntity func(ctx context.Context, id int64) (TEntity, error)
fetchEntitiesPage func(ctx context.Context, page uint64, limit uint64) ([]TEntity, error)

logger log.Logger
}

func newEntityFetcher(
func newEntityFetcher[TEntity Entity](
name string,
fetchLastEntityId func(ctx context.Context) (int64, error),
fetchEntity func(ctx context.Context, id int64) (Entity, error),
fetchEntitiesPage func(ctx context.Context, page uint64, limit uint64) ([]Entity, error),
fetchEntity func(ctx context.Context, id int64) (TEntity, error),
fetchEntitiesPage func(ctx context.Context, page uint64, limit uint64) ([]TEntity, error),
logger log.Logger,
) entityFetcher {
return &entityFetcherImpl{
) entityFetcher[TEntity] {
return &entityFetcherImpl[TEntity]{
name: name,
fetchLastEntityId: fetchLastEntityId,
fetchEntity: fetchEntity,
Expand All @@ -41,12 +41,12 @@ func newEntityFetcher(
}
}

func (f *entityFetcherImpl) FetchLastEntityId(ctx context.Context) (uint64, error) {
func (f *entityFetcherImpl[TEntity]) FetchLastEntityId(ctx context.Context) (uint64, error) {
id, err := f.fetchLastEntityId(ctx)
return uint64(id), err
}

func (f *entityFetcherImpl) FetchEntitiesRange(ctx context.Context, idRange ClosedRange) ([]Entity, error) {
func (f *entityFetcherImpl[TEntity]) FetchEntitiesRange(ctx context.Context, idRange ClosedRange) ([]TEntity, error) {
count := idRange.Len()

const batchFetchThreshold = 100
Expand All @@ -62,20 +62,20 @@ func (f *entityFetcherImpl) FetchEntitiesRange(ctx context.Context, idRange Clos
return f.FetchEntitiesRangeSequentially(ctx, idRange)
}

func (f *entityFetcherImpl) FetchEntitiesRangeSequentially(ctx context.Context, idRange ClosedRange) ([]Entity, error) {
return ClosedRangeMap(idRange, func(id uint64) (Entity, error) {
func (f *entityFetcherImpl[TEntity]) FetchEntitiesRangeSequentially(ctx context.Context, idRange ClosedRange) ([]TEntity, error) {
return ClosedRangeMap(idRange, func(id uint64) (TEntity, error) {
return f.fetchEntity(ctx, int64(id))
})
}

func (f *entityFetcherImpl) FetchAllEntities(ctx context.Context) ([]Entity, error) {
func (f *entityFetcherImpl[TEntity]) FetchAllEntities(ctx context.Context) ([]TEntity, error) {
// TODO: once heimdall API is fixed to return sorted items in pages we can only fetch
//
// the new pages after lastStoredCheckpointId using the checkpoints/list paging API
// (for now we have to fetch all of them)
// and also remove sorting we do after fetching

var entities []Entity
var entities []TEntity

fetchStartTime := time.Now()
progressLogTicker := time.NewTicker(30 * time.Second)
Expand Down Expand Up @@ -106,7 +106,7 @@ func (f *entityFetcherImpl) FetchAllEntities(ctx context.Context) ([]Entity, err
}
}

slices.SortFunc(entities, func(e1, e2 Entity) int {
slices.SortFunc(entities, func(e1, e2 TEntity) int {
n1 := e1.BlockNumRange().Start
n2 := e2.BlockNumRange().Start
return cmp.Compare(n1, n2)
Expand Down
74 changes: 41 additions & 33 deletions polygon/heimdall/entity_store.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,23 +10,23 @@ import (
"github.com/ledgerwatch/erigon-lib/kv/iter"
)

type entityStore interface {
type entityStore[TEntity Entity] interface {
Prepare(ctx context.Context) error
Close()
GetLastEntityId(ctx context.Context) (uint64, bool, error)
GetLastEntity(ctx context.Context) (Entity, error)
GetEntity(ctx context.Context, id uint64) (Entity, error)
PutEntity(ctx context.Context, id uint64, entity Entity) error
FindByBlockNum(ctx context.Context, blockNum uint64) (Entity, error)
RangeFromId(ctx context.Context, startId uint64) ([]Entity, error)
RangeFromBlockNum(ctx context.Context, startBlockNum uint64) ([]Entity, error)
GetLastEntity(ctx context.Context) (TEntity, error)
GetEntity(ctx context.Context, id uint64) (TEntity, error)
PutEntity(ctx context.Context, id uint64, entity TEntity) error
FindByBlockNum(ctx context.Context, blockNum uint64) (TEntity, error)
RangeFromId(ctx context.Context, startId uint64) ([]TEntity, error)
RangeFromBlockNum(ctx context.Context, startBlockNum uint64) ([]TEntity, error)
}

type entityStoreImpl struct {
type entityStoreImpl[TEntity Entity] struct {
tx kv.RwTx
table string

makeEntity func() Entity
makeEntity func() TEntity
getLastEntityId func(ctx context.Context, tx kv.Tx) (uint64, bool, error)
loadEntityBytes func(ctx context.Context, tx kv.Getter, id uint64) ([]byte, error)

Expand All @@ -35,15 +35,15 @@ type entityStoreImpl struct {
prepareOnce sync.Once
}

func newEntityStore(
func newEntityStore[TEntity Entity](
tx kv.RwTx,
table string,
makeEntity func() Entity,
makeEntity func() TEntity,
getLastEntityId func(ctx context.Context, tx kv.Tx) (uint64, bool, error),
loadEntityBytes func(ctx context.Context, tx kv.Getter, id uint64) ([]byte, error),
blockNumToIdIndexFactory func(ctx context.Context) (*RangeIndex, error),
) entityStore {
return &entityStoreImpl{
) entityStore[TEntity] {
return &entityStoreImpl[TEntity]{
tx: tx,
table: table,

Expand All @@ -55,7 +55,7 @@ func newEntityStore(
}
}

func (s *entityStoreImpl) Prepare(ctx context.Context) error {
func (s *entityStoreImpl[TEntity]) Prepare(ctx context.Context) error {
var err error
s.prepareOnce.Do(func() {
s.blockNumToIdIndex, err = s.blockNumToIdIndexFactory(ctx)
Expand All @@ -68,22 +68,30 @@ func (s *entityStoreImpl) Prepare(ctx context.Context) error {
return err
}

func (s *entityStoreImpl) Close() {
func (s *entityStoreImpl[TEntity]) Close() {
s.blockNumToIdIndex.Close()
}

func (s *entityStoreImpl) GetLastEntityId(ctx context.Context) (uint64, bool, error) {
func (s *entityStoreImpl[TEntity]) GetLastEntityId(ctx context.Context) (uint64, bool, error) {
return s.getLastEntityId(ctx, s.tx)
}

func (s *entityStoreImpl) GetLastEntity(ctx context.Context) (Entity, error) {
// Zero value of any type T
// https://stackoverflow.com/questions/70585852/return-default-value-for-generic-type)
// https://go.dev/ref/spec#The_zero_value
func Zero[T any]() T {
var value T
return value
}

func (s *entityStoreImpl[TEntity]) GetLastEntity(ctx context.Context) (TEntity, error) {
id, ok, err := s.GetLastEntityId(ctx)
if err != nil {
return nil, err
return Zero[TEntity](), err
}
// not found
if !ok {
return nil, nil
return Zero[TEntity](), nil
}
return s.GetEntity(ctx, id)
}
Expand All @@ -94,28 +102,28 @@ func entityStoreKey(id uint64) [8]byte {
return key
}

func (s *entityStoreImpl) entityUnmarshalJSON(jsonBytes []byte) (Entity, error) {
func (s *entityStoreImpl[TEntity]) entityUnmarshalJSON(jsonBytes []byte) (TEntity, error) {
entity := s.makeEntity()
if err := json.Unmarshal(jsonBytes, entity); err != nil {
return nil, err
return Zero[TEntity](), err
}
return entity, nil
}

func (s *entityStoreImpl) GetEntity(ctx context.Context, id uint64) (Entity, error) {
func (s *entityStoreImpl[TEntity]) GetEntity(ctx context.Context, id uint64) (TEntity, error) {
jsonBytes, err := s.loadEntityBytes(ctx, s.tx, id)
if err != nil {
return nil, err
return Zero[TEntity](), err
}
// not found
if jsonBytes == nil {
return nil, nil
return Zero[TEntity](), nil
}

return s.entityUnmarshalJSON(jsonBytes)
}

func (s *entityStoreImpl) PutEntity(ctx context.Context, id uint64, entity Entity) error {
func (s *entityStoreImpl[TEntity]) PutEntity(ctx context.Context, id uint64, entity TEntity) error {
jsonBytes, err := json.Marshal(entity)
if err != nil {
return err
Expand All @@ -131,27 +139,27 @@ func (s *entityStoreImpl) PutEntity(ctx context.Context, id uint64, entity Entit
return s.blockNumToIdIndex.Put(ctx, entity.BlockNumRange(), id)
}

func (s *entityStoreImpl) FindByBlockNum(ctx context.Context, blockNum uint64) (Entity, error) {
func (s *entityStoreImpl[TEntity]) FindByBlockNum(ctx context.Context, blockNum uint64) (TEntity, error) {
id, err := s.blockNumToIdIndex.Lookup(ctx, blockNum)
if err != nil {
return nil, err
return Zero[TEntity](), err
}
// not found
if id == 0 {
return nil, nil
return Zero[TEntity](), nil
}

return s.GetEntity(ctx, id)
}

func (s *entityStoreImpl) RangeFromId(_ context.Context, startId uint64) ([]Entity, error) {
func (s *entityStoreImpl[TEntity]) RangeFromId(_ context.Context, startId uint64) ([]TEntity, error) {
startKey := entityStoreKey(startId)
it, err := s.tx.Range(s.table, startKey[:], nil)
if err != nil {
return nil, err
}

var entities []Entity
var entities []TEntity
for it.HasNext() {
_, jsonBytes, err := it.Next()
if err != nil {
Expand All @@ -167,7 +175,7 @@ func (s *entityStoreImpl) RangeFromId(_ context.Context, startId uint64) ([]Enti
return entities, nil
}

func (s *entityStoreImpl) RangeFromBlockNum(ctx context.Context, startBlockNum uint64) ([]Entity, error) {
func (s *entityStoreImpl[TEntity]) RangeFromBlockNum(ctx context.Context, startBlockNum uint64) ([]TEntity, error) {
id, err := s.blockNumToIdIndex.Lookup(ctx, startBlockNum)
if err != nil {
return nil, err
Expand All @@ -180,11 +188,11 @@ func (s *entityStoreImpl) RangeFromBlockNum(ctx context.Context, startBlockNum u
return s.RangeFromId(ctx, id)
}

func buildBlockNumToIdIndex(
func buildBlockNumToIdIndex[TEntity Entity](
ctx context.Context,
index *RangeIndex,
iteratorFactory func() (iter.KV, error),
entityUnmarshalJSON func([]byte) (Entity, error),
entityUnmarshalJSON func([]byte) (TEntity, error),
) error {
it, err := iteratorFactory()
if err != nil {
Expand Down
Loading

0 comments on commit 3e93af6

Please sign in to comment.