Skip to content

Commit

Permalink
feat(spanner): add ResetForRetry method for stmt-based transactions (#…
Browse files Browse the repository at this point in the history
…10956)

* feat(spanner): add ResetForRetry method for stmt-based transactions

Read/write transactions that are aborted should preferably be retried using the
same session as the original attempt. For this, statement-based transactions
should have a ResetForRetry function. This was missing in the Go client library.

This change adds this method and re-uses the session when possible. If the
Aborted error occurs during the Commit RPC, the session handle has already been
cleaned up by the original implementation. We will not change that now, as
doing so could break existing code that depends on this behavior. When
the Go client is switched to multiplexed sessions for read/write transactions,
this implementation should be revisited to make sure that
ResetForRetry still optimizes the retry attempt for an actual retry.

Updates googleapis/go-sql-spanner#300

* fix: only allow resetting if tx is really aborted

---------

Co-authored-by: Sri Harsha CH <[email protected]>
  • Loading branch information
olavloite and harshachinta authored Nov 12, 2024
1 parent 5b59819 commit 02c191c
Show file tree
Hide file tree
Showing 3 changed files with 180 additions and 14 deletions.
8 changes: 7 additions & 1 deletion spanner/read.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,9 @@ func stream(
rpc,
nil,
nil,
func(err error) error {
return err
},
setTimestamp,
release,
)
Expand All @@ -79,6 +82,7 @@ func streamWithReplaceSessionFunc(
rpc func(ct context.Context, resumeToken []byte) (streamingReceiver, error),
replaceSession func(ctx context.Context) error,
setTransactionID func(transactionID),
updateTxState func(err error) error,
setTimestamp func(time.Time),
release func(error),
) *RowIterator {
Expand All @@ -89,6 +93,7 @@ func streamWithReplaceSessionFunc(
streamd: newResumableStreamDecoder(ctx, logger, rpc, replaceSession),
rowd: &partialResultSetDecoder{},
setTransactionID: setTransactionID,
updateTxState: updateTxState,
setTimestamp: setTimestamp,
release: release,
cancel: cancel,
Expand Down Expand Up @@ -127,6 +132,7 @@ type RowIterator struct {
streamd *resumableStreamDecoder
rowd *partialResultSetDecoder
setTransactionID func(transactionID)
updateTxState func(err error) error
setTimestamp func(time.Time)
release func(error)
cancel func()
Expand Down Expand Up @@ -214,7 +220,7 @@ func (r *RowIterator) Next() (*Row, error) {
return row, nil
}
if err := r.streamd.lastErr(); err != nil {
r.err = ToSpannerError(err)
r.err = r.updateTxState(ToSpannerError(err))
} else if !r.rowd.done() {
r.err = errEarlyReadEnd()
} else {
Expand Down
82 changes: 71 additions & 11 deletions spanner/transaction.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ package spanner

import (
"context"
"fmt"
"sync"
"sync/atomic"
"time"
Expand Down Expand Up @@ -63,6 +64,12 @@ type txReadOnly struct {
// operations.
txReadEnv

// updateTxStateFunc is a function that updates the state of the current
// transaction based on the given error. This function is by default a no-op,
// but is overridden for read/write transactions to set the state to txAborted
// if Spanner aborts the transaction.
updateTxStateFunc func(err error) error

// Atomic. Only needed for DML statements, but used for all.
sequenceNumber int64

Expand Down Expand Up @@ -98,6 +105,13 @@ type txReadOnly struct {
otConfig *openTelemetryConfig
}

// updateTxState routes err through the transaction's updateTxStateFunc hook
// and returns whatever the hook returns. When no hook has been installed the
// error is handed back unchanged, so callers can wrap every error return with
// this method unconditionally.
func (t *txReadOnly) updateTxState(err error) error {
	if hook := t.updateTxStateFunc; hook != nil {
		return hook(err)
	}
	return err
}

// TransactionOptions provides options for a transaction.
type TransactionOptions struct {
CommitOptions CommitOptions
Expand Down Expand Up @@ -323,7 +337,7 @@ func (t *txReadOnly) ReadWithOptions(ctx context.Context, table string, keys Key
t.setTransactionID(nil)
return client, errInlineBeginTransactionFailed()
}
return client, err
return client, t.updateTxState(err)
}
md, err := client.Header()
if getGFELatencyMetricsFlag() && md != nil && t.ct != nil {
Expand All @@ -338,6 +352,9 @@ func (t *txReadOnly) ReadWithOptions(ctx context.Context, table string, keys Key
},
t.replaceSessionFunc,
setTransactionID,
func(err error) error {
return t.updateTxState(err)
},
t.setTimestamp,
t.release,
)
Expand Down Expand Up @@ -607,7 +624,7 @@ func (t *txReadOnly) query(ctx context.Context, statement Statement, options Que
t.setTransactionID(nil)
return client, errInlineBeginTransactionFailed()
}
return client, err
return client, t.updateTxState(err)
}
md, err := client.Header()
if getGFELatencyMetricsFlag() && md != nil && t.ct != nil {
Expand All @@ -622,6 +639,9 @@ func (t *txReadOnly) query(ctx context.Context, statement Statement, options Que
},
t.replaceSessionFunc,
setTransactionID,
func(err error) error {
return t.updateTxState(err)
},
t.setTimestamp,
t.release)
}
Expand Down Expand Up @@ -673,6 +693,8 @@ const (
txActive
// transaction is closed, cannot be used anymore.
txClosed
// transaction was aborted by Spanner and should be retried.
txAborted
)

// errRtsUnavailable returns error for read transaction's read timestamp being
Expand Down Expand Up @@ -1216,7 +1238,7 @@ func (t *ReadWriteTransaction) update(ctx context.Context, stmt Statement, opts
t.setTransactionID(nil)
return 0, errInlineBeginTransactionFailed()
}
return 0, ToSpannerError(err)
return 0, t.txReadOnly.updateTxState(ToSpannerError(err))
}
if hasInlineBeginTransaction {
if resultSet != nil && resultSet.GetMetadata() != nil && resultSet.GetMetadata().GetTransaction() != nil &&
Expand Down Expand Up @@ -1325,7 +1347,7 @@ func (t *ReadWriteTransaction) batchUpdateWithOptions(ctx context.Context, stmts
t.setTransactionID(nil)
return nil, errInlineBeginTransactionFailed()
}
return nil, ToSpannerError(err)
return nil, t.txReadOnly.updateTxState(ToSpannerError(err))
}

haveTransactionID := false
Expand All @@ -1348,7 +1370,7 @@ func (t *ReadWriteTransaction) batchUpdateWithOptions(ctx context.Context, stmts
return counts, errInlineBeginTransactionFailed()
}
if resp.Status != nil && resp.Status.Code != 0 {
return counts, spannerErrorf(codes.Code(uint32(resp.Status.Code)), resp.Status.Message)
return counts, t.txReadOnly.updateTxState(spannerErrorf(codes.Code(uint32(resp.Status.Code)), resp.Status.Message))
}
return counts, nil
}
Expand Down Expand Up @@ -1666,7 +1688,7 @@ func (t *ReadWriteTransaction) commit(ctx context.Context, options CommitOptions
trace.TracePrintf(ctx, nil, "Error in recording GFE Latency through OpenTelemetry. Error: %v", metricErr)
}
if e != nil {
return resp, toSpannerErrorWithCommitInfo(e, true)
return resp, t.txReadOnly.updateTxState(toSpannerErrorWithCommitInfo(e, true))
}
if tstamp := res.GetCommitTimestamp(); tstamp != nil {
resp.CommitTs = time.Unix(tstamp.Seconds, int64(tstamp.Nanos))
Expand Down Expand Up @@ -1758,6 +1780,7 @@ type ReadWriteStmtBasedTransaction struct {
// ReadWriteTransaction contains methods for performing transactional reads.
ReadWriteTransaction

client *Client
options TransactionOptions
}

Expand All @@ -1783,30 +1806,51 @@ func NewReadWriteStmtBasedTransaction(ctx context.Context, c *Client) (*ReadWrit
// used by the transaction will not be returned to the pool and cause a session
// leak.
//
// ResetForRetry resets the transaction before a retry attempt. This function
// returns a new transaction that should be used for the retry attempt. The
// transaction that is returned by this function is assigned a higher priority
// than the previous transaction, making it less probable to be aborted by
// Spanner again during the retry.
//
// NewReadWriteStmtBasedTransactionWithOptions is a configurable version of
// NewReadWriteStmtBasedTransaction.
func NewReadWriteStmtBasedTransactionWithOptions(ctx context.Context, c *Client, options TransactionOptions) (*ReadWriteStmtBasedTransaction, error) {
	// A fresh transaction has no session to re-use. Passing a nil handle makes
	// the shared constructor take one from the client's session pool.
	var noSession *sessionHandle
	return newReadWriteStmtBasedTransactionWithSessionHandle(ctx, c, options, noSession)
}

func newReadWriteStmtBasedTransactionWithSessionHandle(ctx context.Context, c *Client, options TransactionOptions, sh *sessionHandle) (*ReadWriteStmtBasedTransaction, error) {
var (
sh *sessionHandle
err error
t *ReadWriteStmtBasedTransaction
)
sh, err = c.idleSessions.take(ctx)
if err != nil {
// If session retrieval fails, just fail the transaction.
return nil, err
if sh == nil {
sh, err = c.idleSessions.take(ctx)
if err != nil {
// If session retrieval fails, just fail the transaction.
return nil, err
}
}
t = &ReadWriteStmtBasedTransaction{
ReadWriteTransaction: ReadWriteTransaction{
txReadyOrClosed: make(chan struct{}),
},
client: c,
}
t.txReadOnly.sp = c.idleSessions
t.txReadOnly.sh = sh
t.txReadOnly.txReadEnv = t
t.txReadOnly.qo = c.qo
t.txReadOnly.ro = c.ro
t.txReadOnly.disableRouteToLeader = c.disableRouteToLeader
t.txReadOnly.updateTxStateFunc = func(err error) error {
if ErrCode(err) == codes.Aborted {
t.mu.Lock()
t.state = txAborted
t.mu.Unlock()
}
return err
}

t.txOpts = c.txo.merge(options)
t.ct = c.ct
t.otConfig = c.otConfig
Expand Down Expand Up @@ -1838,6 +1882,7 @@ func (t *ReadWriteStmtBasedTransaction) CommitWithReturnResp(ctx context.Context
}
if t.sh != nil {
t.sh.recycle()
t.sh = nil
}
return resp, err
}
Expand All @@ -1848,7 +1893,22 @@ func (t *ReadWriteStmtBasedTransaction) Rollback(ctx context.Context) {
t.rollback(ctx)
if t.sh != nil {
t.sh.recycle()
t.sh = nil
}
}

// ResetForRetry resets the transaction before a retry. This should be
// called if the transaction was aborted by Spanner and the application
// wants to retry the transaction.
// It is recommended to use this method over creating a new transaction,
// as this method will give the transaction a higher priority and thus a
// smaller probability of being aborted again by Spanner.
//
// The returned transaction is a new value that must be used for the retry
// attempt. It re-uses the session of the receiver when that session is still
// held; if the session handle was already released (for example by Commit or
// Rollback), a session is taken from the pool instead.
//
// An error is returned, together with a nil transaction, when the receiver
// has not been marked as aborted, or when the new transaction cannot be
// created.
func (t *ReadWriteStmtBasedTransaction) ResetForRetry(ctx context.Context) (*ReadWriteStmtBasedTransaction, error) {
	// The aborted state is written while holding t.mu, so it must be read
	// under the same lock to avoid an unsynchronized read.
	t.mu.Lock()
	aborted := t.state == txAborted
	t.mu.Unlock()
	if !aborted {
		return nil, fmt.Errorf("ResetForRetry should only be called on an active transaction that was aborted by Spanner")
	}
	// Create a new transaction that re-uses the current session if it is available.
	// NOTE(review): the receiver keeps its reference to the session handle, so
	// the old transaction must not be committed or rolled back after this call.
	// NOTE(review): confirm that t.options is populated by the constructor; the
	// visible part of it only sets the client field.
	return newReadWriteStmtBasedTransactionWithSessionHandle(ctx, t.client, t.options, t.sh)
}

// writeOnlyTransaction provides the most efficient way of doing write-only
Expand Down
104 changes: 102 additions & 2 deletions spanner/transaction_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -470,8 +470,103 @@ func TestReadWriteStmtBasedTransaction_CommitAborted(t *testing.T) {
}
}

func TestReadWriteStmtBasedTransaction_QueryAborted(t *testing.T) {
t.Parallel()
rowCount, attempts, err := testReadWriteStmtBasedTransaction(t, map[string]SimulatedExecutionTime{
MethodExecuteStreamingSql: {Errors: []error{status.Error(codes.Aborted, "Transaction aborted")}},
})
if err != nil {
t.Fatalf("transaction failed to commit: %v", err)
}
if rowCount != SelectSingerIDAlbumIDAlbumTitleFromAlbumsRowCount {
t.Fatalf("Row count mismatch, got %v, expected %v", rowCount, SelectSingerIDAlbumIDAlbumTitleFromAlbumsRowCount)
}
if g, w := attempts, 2; g != w {
t.Fatalf("number of attempts mismatch:\nGot%d\nWant:%d", g, w)
}
}

func TestReadWriteStmtBasedTransaction_UpdateAborted(t *testing.T) {
t.Parallel()
server, client, teardown := setupMockedTestServerWithConfig(t, ClientConfig{
SessionPoolConfig: SessionPoolConfig{
// Use a session pool with size 1 to ensure that there are no session leaks.
MinOpened: 1,
MaxOpened: 1,
},
})
defer teardown()
server.TestSpanner.PutExecutionTime(
MethodExecuteSql,
SimulatedExecutionTime{Errors: []error{status.Error(codes.Aborted, "Transaction aborted")}})

ctx := context.Background()
tx, err := NewReadWriteStmtBasedTransaction(ctx, client)
if err != nil {
t.Fatal(err)
}
_, err = tx.Update(ctx, Statement{SQL: UpdateBarSetFoo})
if g, w := ErrCode(err), codes.Aborted; g != w {
t.Fatalf("error code mismatch\n Got: %v\nWant: %v", g, w)
}
tx, err = tx.ResetForRetry(ctx)
if err != nil {
t.Fatal(err)
}
c, err := tx.Update(ctx, Statement{SQL: UpdateBarSetFoo})
if err != nil {
t.Fatal(err)
}
if g, w := c, int64(UpdateBarSetFooRowCount); g != w {
t.Fatalf("update count mismatch\n Got: %v\nWant: %v", g, w)
}
}

func TestReadWriteStmtBasedTransaction_BatchUpdateAborted(t *testing.T) {
t.Parallel()
server, client, teardown := setupMockedTestServerWithConfig(t, ClientConfig{
SessionPoolConfig: SessionPoolConfig{
// Use a session pool with size 1 to ensure that there are no session leaks.
MinOpened: 1,
MaxOpened: 1,
},
})
defer teardown()
server.TestSpanner.PutExecutionTime(
MethodExecuteBatchDml,
SimulatedExecutionTime{Errors: []error{status.Error(codes.Aborted, "Transaction aborted")}})

ctx := context.Background()
tx, err := NewReadWriteStmtBasedTransaction(ctx, client)
if err != nil {
t.Fatal(err)
}
_, err = tx.BatchUpdate(ctx, []Statement{{SQL: UpdateBarSetFoo}})
if g, w := ErrCode(err), codes.Aborted; g != w {
t.Fatalf("error code mismatch\n Got: %v\nWant: %v", g, w)
}
tx, err = tx.ResetForRetry(ctx)
if err != nil {
t.Fatal(err)
}
c, err := tx.BatchUpdate(ctx, []Statement{{SQL: UpdateBarSetFoo}})
if err != nil {
t.Fatal(err)
}
if g, w := c, []int64{UpdateBarSetFooRowCount}; !reflect.DeepEqual(g, w) {
t.Fatalf("update count mismatch\n Got: %v\nWant: %v", g, w)
}
}

func testReadWriteStmtBasedTransaction(t *testing.T, executionTimes map[string]SimulatedExecutionTime) (rowCount int64, attempts int, err error) {
server, client, teardown := setupMockedTestServer(t)
// server, client, teardown := setupMockedTestServer(t)
server, client, teardown := setupMockedTestServerWithConfig(t, ClientConfig{
SessionPoolConfig: SessionPoolConfig{
// Use a session pool with size 1 to ensure that there are no session leaks.
MinOpened: 1,
MaxOpened: 1,
},
})
defer teardown()
for method, exec := range executionTimes {
server.TestSpanner.PutExecutionTime(method, exec)
Expand Down Expand Up @@ -500,9 +595,14 @@ func testReadWriteStmtBasedTransaction(t *testing.T, executionTimes map[string]S
return rowCount, nil
}

var tx *ReadWriteStmtBasedTransaction
for {
attempts++
tx, err := NewReadWriteStmtBasedTransaction(ctx, client)
if attempts > 1 {
tx, err = tx.ResetForRetry(ctx)
} else {
tx, err = NewReadWriteStmtBasedTransaction(ctx, client)
}
if err != nil {
return 0, attempts, fmt.Errorf("failed to begin a transaction: %v", err)
}
Expand Down

0 comments on commit 02c191c

Please sign in to comment.