Skip to content

Commit

Permalink
Close cursor if it's opened by different tx (#12546)
Browse files Browse the repository at this point in the history
So we keep cursor interface opened inside `DomainRoTx`. If we then
create new tx and try to read with it from `DomainRoTx`, value will be
fetched via previously opened cursor which is not expected behaviour.

I set stupid fix for that but i assume there could be better solutions.

---------

Co-authored-by: alex.sharov <[email protected]>
  • Loading branch information
awskii and AskAlexSharov authored Dec 17, 2024
1 parent 7c5dead commit 8c12f51
Show file tree
Hide file tree
Showing 9 changed files with 346 additions and 20 deletions.
6 changes: 3 additions & 3 deletions core/rawdb/rawtemporaldb/accessors_receipt_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,11 @@ func TestAppendReceipt(t *testing.T) {
require.NoError(err)
defer tx.Rollback()

doms, err := state.NewSharedDomains(tx, log.New())
ttx := tx.(kv.TemporalTx)
doms, err := state.NewSharedDomains(ttx, log.New())
require.NoError(err)
defer doms.Close()
doms.SetTx(tx)
doms.SetTx(ttx)

doms.SetTxNum(0) // block1
err = AppendReceipt(doms, &types.Receipt{CumulativeGasUsed: 10, FirstLogIndexWithinBlock: 0}, 0)
Expand All @@ -48,7 +49,6 @@ func TestAppendReceipt(t *testing.T) {
err = doms.Flush(context.Background(), tx)
require.NoError(err)

ttx := tx.(kv.TemporalTx)
v, ok, err := ttx.HistorySeek(kv.ReceiptDomain, FirstLogIndexKey, 0)
require.NoError(err)
require.True(ok)
Expand Down
11 changes: 9 additions & 2 deletions core/vm/gas_table_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,11 @@ package vm_test
import (
"context"
"errors"
"fmt"
"math"
"strconv"
"testing"
"unsafe"

"github.com/holiman/uint256"
"github.com/stretchr/testify/require"
Expand Down Expand Up @@ -195,12 +197,16 @@ func TestCreateGas(t *testing.T) {
var txc wrap.TxContainer
txc.Tx = tx

domains, err := state3.NewSharedDomains(tx, log.New())
eface := *(*[2]uintptr)(unsafe.Pointer(&tx))
fmt.Printf("init tx %x\n", eface[1])

domains, err := state3.NewSharedDomains(txc.Tx, log.New())
require.NoError(t, err)
defer domains.Close()
txc.Doms = domains

stateReader = rpchelper.NewLatestStateReader(tx)
//stateReader = rpchelper.NewLatestStateReader(domains)
stateReader = rpchelper.NewLatestDomainStateReader(domains)
stateWriter = rpchelper.NewLatestStateWriter(txc, nil, 0)

s := state.New(stateReader)
Expand Down Expand Up @@ -230,5 +236,6 @@ func TestCreateGas(t *testing.T) {
t.Errorf("test %d: gas used mismatch: have %v, want %v", i, gasUsed, tt.gasUsed)
}
tx.Rollback()
domains.Close()
}
}
4 changes: 3 additions & 1 deletion erigon-lib/kv/mdbx/kv_mdbx.go
Original file line number Diff line number Diff line change
Expand Up @@ -1119,7 +1119,7 @@ func (tx *MdbxTx) stdCursor(bucket string) (kv.RwCursor, error) {
if tx.toCloseMap == nil {
tx.toCloseMap = make(map[uint64]kv.Closer)
}
tx.toCloseMap[c.id] = c.c
tx.toCloseMap[c.id] = c
return c, nil
}

Expand Down Expand Up @@ -1268,6 +1268,8 @@ func (c *MdbxCursor) Close() {
}
}

func (c *MdbxCursor) IsClosed() bool { return c.c == nil }

type MdbxDupSortCursor struct {
*MdbxCursor
}
Expand Down
207 changes: 207 additions & 0 deletions erigon-lib/state/aggregator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,213 @@ import (
"github.com/stretchr/testify/require"
)

func TestAggregatorV3_Merge(t *testing.T) {
t.Parallel()
db, agg := testDbAndAggregatorv3(t, 10)
rwTx, err := db.BeginRwNosync(context.Background())
require.NoError(t, err)
defer func() {
if rwTx != nil {
rwTx.Rollback()
}
}()

ac := agg.BeginFilesRo()
defer ac.Close()
domains, err := NewSharedDomains(WrapTxWithCtx(rwTx, ac), log.New())
require.NoError(t, err)
defer domains.Close()

txs := uint64(1000)
rnd := rand.New(rand.NewSource(time.Now().UnixNano()))

var (
commKey1 = []byte("someCommKey")
commKey2 = []byte("otherCommKey")
)

// keys are encodings of numbers 1..31
// each key changes value on every txNum which is multiple of the key
var maxWrite, otherMaxWrite uint64
for txNum := uint64(1); txNum <= txs; txNum++ {
domains.SetTxNum(txNum)

addr, loc := make([]byte, length.Addr), make([]byte, length.Hash)

n, err := rnd.Read(addr)
require.NoError(t, err)
require.EqualValues(t, length.Addr, n)

n, err = rnd.Read(loc)
require.NoError(t, err)
require.EqualValues(t, length.Hash, n)

buf := types.EncodeAccountBytesV3(1, uint256.NewInt(0), nil, 0)
err = domains.DomainPut(kv.AccountsDomain, addr, nil, buf, nil, 0)
require.NoError(t, err)

err = domains.DomainPut(kv.StorageDomain, addr, loc, []byte{addr[0], loc[0]}, nil, 0)
require.NoError(t, err)

var v [8]byte
binary.BigEndian.PutUint64(v[:], txNum)
if txNum%135 == 0 {
pv, step, err := domains.GetLatest(kv.CommitmentDomain, commKey2)
require.NoError(t, err)

err = domains.DomainPut(kv.CommitmentDomain, commKey2, nil, v[:], pv, step)
require.NoError(t, err)
otherMaxWrite = txNum
} else {
pv, step, err := domains.GetLatest(kv.CommitmentDomain, commKey1)
require.NoError(t, err)

err = domains.DomainPut(kv.CommitmentDomain, commKey1, nil, v[:], pv, step)
require.NoError(t, err)
maxWrite = txNum
}
require.NoError(t, err)

}

err = domains.Flush(context.Background(), rwTx)
require.NoError(t, err)

require.NoError(t, err)
err = rwTx.Commit()
require.NoError(t, err)
rwTx = nil

err = agg.BuildFiles(txs)
require.NoError(t, err)

rwTx, err = db.BeginRw(context.Background())
require.NoError(t, err)
defer rwTx.Rollback()

logEvery := time.NewTicker(30 * time.Second)
defer logEvery.Stop()
stat, err := ac.Prune(context.Background(), rwTx, 0, logEvery)
require.NoError(t, err)
t.Logf("Prune: %s", stat)

err = rwTx.Commit()
require.NoError(t, err)

err = agg.MergeLoop(context.Background())
require.NoError(t, err)

// Check the history
roTx, err := db.BeginRo(context.Background())
require.NoError(t, err)
defer roTx.Rollback()

dc := agg.BeginFilesRo()

v, _, ex, err := dc.GetLatest(kv.CommitmentDomain, commKey1, roTx)
require.NoError(t, err)
require.Truef(t, ex, "key %x not found", commKey1)

require.EqualValues(t, maxWrite, binary.BigEndian.Uint64(v[:]))

v, _, ex, err = dc.GetLatest(kv.CommitmentDomain, commKey2, roTx)
require.NoError(t, err)
require.Truef(t, ex, "key %x not found", commKey2)
dc.Close()

require.EqualValues(t, otherMaxWrite, binary.BigEndian.Uint64(v[:]))
}

func TestAggregatorV3_MergeValTransform(t *testing.T) {
t.Parallel()
db, agg := testDbAndAggregatorv3(t, 10)
rwTx, err := db.BeginRwNosync(context.Background())
require.NoError(t, err)
defer func() {
if rwTx != nil {
rwTx.Rollback()
}
}()
ac := agg.BeginFilesRo()
defer ac.Close()
domains, err := NewSharedDomains(WrapTxWithCtx(rwTx, ac), log.New())
require.NoError(t, err)
defer domains.Close()

txs := uint64(1000)
rnd := rand.New(rand.NewSource(time.Now().UnixNano()))

agg.commitmentValuesTransform = true

state := make(map[string][]byte)

// keys are encodings of numbers 1..31
// each key changes value on every txNum which is multiple of the key
//var maxWrite, otherMaxWrite uint64
for txNum := uint64(1); txNum <= txs; txNum++ {
domains.SetTxNum(txNum)

addr, loc := make([]byte, length.Addr), make([]byte, length.Hash)

n, err := rnd.Read(addr)
require.NoError(t, err)
require.EqualValues(t, length.Addr, n)

n, err = rnd.Read(loc)
require.NoError(t, err)
require.EqualValues(t, length.Hash, n)

buf := types.EncodeAccountBytesV3(1, uint256.NewInt(txNum*1e6), nil, 0)
err = domains.DomainPut(kv.AccountsDomain, addr, nil, buf, nil, 0)
require.NoError(t, err)

err = domains.DomainPut(kv.StorageDomain, addr, loc, []byte{addr[0], loc[0]}, nil, 0)
require.NoError(t, err)

if (txNum+1)%agg.StepSize() == 0 {
_, err := domains.ComputeCommitment(context.Background(), true, txNum/10, "")
require.NoError(t, err)
}

state[string(addr)] = buf
state[string(addr)+string(loc)] = []byte{addr[0], loc[0]}
}

err = domains.Flush(context.Background(), rwTx)
require.NoError(t, err)

err = rwTx.Commit()
require.NoError(t, err)
rwTx = nil

err = agg.BuildFiles(txs)
require.NoError(t, err)

ac.Close()
ac = agg.BeginFilesRo()
defer ac.Close()

rwTx, err = db.BeginRwNosync(context.Background())
require.NoError(t, err)
defer func() {
if rwTx != nil {
rwTx.Rollback()
}
}()

logEvery := time.NewTicker(30 * time.Second)
defer logEvery.Stop()
stat, err := ac.Prune(context.Background(), rwTx, 0, logEvery)
require.NoError(t, err)
t.Logf("Prune: %s", stat)

err = rwTx.Commit()
require.NoError(t, err)

err = agg.MergeLoop(context.Background())
require.NoError(t, err)
}

func TestAggregatorV3_RestartOnDatadir(t *testing.T) {
t.Parallel()
//t.Skip()
Expand Down
40 changes: 38 additions & 2 deletions erigon-lib/state/domain.go
Original file line number Diff line number Diff line change
Expand Up @@ -654,7 +654,8 @@ type DomainRoTx struct {
keyBuf [60]byte // 52b key and 8b for inverted step
comBuf []byte

valsC kv.Cursor
valsC kv.Cursor
valCViewID uint64 // to make sure that valsC reading from the same view with given kv.Tx

getFromFileCache *DomainGetFromFileCache
}
Expand Down Expand Up @@ -1638,6 +1639,7 @@ func (dt *DomainRoTx) Close() {
if dt.files == nil { // invariant: it's safe to call Close multiple times
return
}
dt.closeValsCursor()
files := dt.files
dt.files = nil
for i := range files {
Expand Down Expand Up @@ -1705,12 +1707,46 @@ func (dt *DomainRoTx) statelessBtree(i int) *BtIndex {
return r
}

func (dt *DomainRoTx) valsCursor(tx kv.Tx) (c kv.Cursor, err error) {
var sdTxImmutabilityInvariant = errors.New("tx passed into ShredDomains is immutable")

func (dt *DomainRoTx) closeValsCursor() {
if dt.valsC != nil {
dt.valsC.Close()
dt.valCViewID = 0
dt.valsC = nil
// dt.vcParentPtr.Store(0)
}
}

type canCheckClosed interface {
IsClosed() bool
}

func (dt *DomainRoTx) valsCursor(tx kv.Tx) (c kv.Cursor, err error) {
if dt.valsC != nil { // run in assert mode only
if asserts {
if tx.ViewID() != dt.valCViewID {
panic(fmt.Errorf("%w: DomainRoTx=%s cursor ViewID=%d; given tx.ViewID=%d", sdTxImmutabilityInvariant, dt.d.filenameBase, dt.valCViewID, tx.ViewID())) // cursor opened by different tx, invariant broken
}
if mc, ok := dt.valsC.(canCheckClosed); !ok && mc.IsClosed() {
panic(fmt.Sprintf("domainRoTx=%s cursor lives longer than Cursor (=> than tx opened that cursor)", dt.d.filenameBase))
}
// if dt.d.largeValues {
// if mc, ok := dt.valsC.(*mdbx.MdbxCursor); ok && mc.IsClosed() {
// panic(fmt.Sprintf("domainRoTx=%s cursor lives longer than Cursor (=> than tx opened that cursor)", dt.d.filenameBase))
// }
// } else {
// if mc, ok := dt.valsC.(*mdbx.MdbxDupSortCursor); ok && mc.IsClosed() {
// panic(fmt.Sprintf("domainRoTx=%s cursor lives longer than DupCursor (=> than tx opened that cursor)", dt.d.filenameBase))
// }
// }
}
return dt.valsC, nil
}

if asserts {
dt.valCViewID = tx.ViewID()
}
if dt.d.largeValues {
dt.valsC, err = tx.Cursor(dt.d.valuesTable)
return dt.valsC, err
Expand Down
3 changes: 2 additions & 1 deletion erigon-lib/state/domain_shared.go
Original file line number Diff line number Diff line change
Expand Up @@ -915,13 +915,14 @@ func (sd *SharedDomains) Flush(ctx context.Context, tx kv.RwTx) error {
_, f, l, _ := runtime.Caller(1)
fmt.Printf("[SD aggTx=%d] FLUSHING at tx %d [%x], caller %s:%d\n", sd.aggTx.id, sd.TxNum(), fh, filepath.Base(f), l)
}
for _, w := range sd.domainWriters {
for di, w := range sd.domainWriters {
if w == nil {
continue
}
if err := w.Flush(ctx, tx); err != nil {
return err
}
sd.aggTx.d[di].closeValsCursor()
}
for _, w := range sd.iiWriters {
if w == nil {
Expand Down
Loading

0 comments on commit 8c12f51

Please sign in to comment.