From ef110da071f79113f825e33059cfeb772e5f4303 Mon Sep 17 00:00:00 2001 From: Emmanuel T Odeke Date: Mon, 4 Nov 2024 17:54:08 -0800 Subject: [PATCH] spanner/request-id: complete TODOs and tests --- spanner/batch.go | 8 +- spanner/client.go | 6 - spanner/errors.go | 4 + .../internal/testutil/inmem_spanner_server.go | 4 + spanner/pdml.go | 13 +- spanner/request_id_header_test.go | 324 +++++++++++++++++- spanner/transaction.go | 52 ++- 7 files changed, 387 insertions(+), 24 deletions(-) diff --git a/spanner/batch.go b/spanner/batch.go index 2c38d4a36d59..e09ae794dab0 100644 --- a/spanner/batch.go +++ b/spanner/batch.go @@ -209,9 +209,13 @@ func (t *BatchReadOnlyTransaction) partitionQuery(ctx context.Context, statement ParamTypes: paramTypes, } sh.updateLastUseTime() - // TODO: (@odeke-em) retrieve the requestID and increment the RPC number - // then send it along in every call per retry. + + // PartitionQuery does not retry automatically so we don't need to retrieve + // the injected requestID to increment the RPC number on retries. 
resp, err := client.PartitionQuery(contextWithOutgoingMetadata(ctx, sh.getMetadata(), t.disableRouteToLeader), req, gax.WithGRPCOptions(grpc.Header(&md))) + if gcl, ok := client.(*grpcSpannerClient); ok { + gcl.setOrResetRPCID() + } if getGFELatencyMetricsFlag() && md != nil && t.ct != nil { if err := createContextAndCaptureGFELatencyMetrics(ctx, t.ct, md, "partitionQuery"); err != nil { diff --git a/spanner/client.go b/spanner/client.go index dd50871b269c..346f52112edd 100644 --- a/spanner/client.go +++ b/spanner/client.go @@ -25,7 +25,6 @@ import ( "regexp" "strconv" "strings" - "sync/atomic" "time" "cloud.google.com/go/internal/trace" @@ -358,11 +357,6 @@ type ClientConfig struct { DisableNativeMetrics bool } -type requestIDConfig struct { - processID uint32 - dbClientHandleID *atomic.Uint64 -} - type openTelemetryConfig struct { meterProvider metric.MeterProvider attributeMap []attribute.KeyValue diff --git a/spanner/errors.go b/spanner/errors.go index edb52d26a47f..871e2a6b4941 100644 --- a/spanner/errors.go +++ b/spanner/errors.go @@ -123,6 +123,10 @@ func (e *Error) decorate(info string) { // APIError error having given error code as its status. func spannerErrorf(code codes.Code, format string, args ...interface{}) error { msg := fmt.Sprintf(format, args...) 
+ return spannerError(code, msg) +} + +func spannerError(code codes.Code, msg string) error { wrapped, _ := apierror.FromError(status.Error(code, msg)) return &Error{ Code: code, diff --git a/spanner/internal/testutil/inmem_spanner_server.go b/spanner/internal/testutil/inmem_spanner_server.go index ae73b82230a1..08be3b21742c 100644 --- a/spanner/internal/testutil/inmem_spanner_server.go +++ b/spanner/internal/testutil/inmem_spanner_server.go @@ -90,6 +90,7 @@ const ( MethodExecuteBatchDml string = "EXECUTE_BATCH_DML" MethodStreamingRead string = "EXECUTE_STREAMING_READ" MethodBatchWrite string = "BATCH_WRITE" + MethodPartitionQuery string = "PARTITION_QUERY" ) // StatementResult represents a mocked result on the test server. The result is @@ -1107,6 +1108,9 @@ func (s *inMemSpannerServer) Rollback(ctx context.Context, req *spannerpb.Rollba } func (s *inMemSpannerServer) PartitionQuery(ctx context.Context, req *spannerpb.PartitionQueryRequest) (*spannerpb.PartitionResponse, error) { + if err := s.simulateExecutionTime(MethodPartitionQuery, req); err != nil { + return nil, err + } s.mu.Lock() if s.stopped { s.mu.Unlock() diff --git a/spanner/pdml.go b/spanner/pdml.go index bbdd6f82e603..05bcb693d26f 100644 --- a/spanner/pdml.go +++ b/spanner/pdml.go @@ -110,13 +110,19 @@ func executePdml(ctx context.Context, sh *sessionHandle, req *sppb.ExecuteSqlReq var md metadata.MD sh.updateLastUseTime() // Begin transaction. 
- res, err := sh.getClient().BeginTransaction(ctx, &sppb.BeginTransactionRequest{ + client := sh.getClient() + res, err := client.BeginTransaction(ctx, &sppb.BeginTransactionRequest{ Session: sh.getID(), Options: &sppb.TransactionOptions{ Mode: &sppb.TransactionOptions_PartitionedDml_{PartitionedDml: &sppb.TransactionOptions_PartitionedDml{}}, ExcludeTxnFromChangeStreams: options.ExcludeTxnFromChangeStreams, }, }) + // This function is invoked afresh on every retry and it retrieves a fresh client + // each time hence does not need an extraction and increment of the injected spanner requestId. + if gcl, ok := client.(*grpcSpannerClient); ok { + defer gcl.setOrResetRPCID() + } if err != nil { return 0, ToSpannerError(err) } @@ -126,9 +132,8 @@ func executePdml(ctx context.Context, sh *sessionHandle, req *sppb.ExecuteSqlReq } sh.updateLastUseTime() - // TODO: (@odeke-em) retrieve the requestID and increment the RPC number - // then send it along in every call per retry. - resultSet, err := sh.getClient().ExecuteSql(ctx, req, gax.WithGRPCOptions(grpc.Header(&md))) + + resultSet, err := client.ExecuteSql(ctx, req, gax.WithGRPCOptions(grpc.Header(&md))) if getGFELatencyMetricsFlag() && md != nil && sh.session.pool != nil { err := captureGFELatencyStats(tag.NewContext(ctx, sh.session.pool.tagMap), md, "executePdml_ExecuteSql") if err != nil { diff --git a/spanner/request_id_header_test.go b/spanner/request_id_header_test.go index 01d0d3698b60..b843951dbc9f 100644 --- a/spanner/request_id_header_test.go +++ b/spanner/request_id_header_test.go @@ -208,7 +208,9 @@ func ensureMonotonicityOfRequestIDs(requestIDs []*requestIDSegments) error { return fmt.Errorf("sameClientID but requestNo mismatch: #[%d].reqNo=%d < #[%d].reqNo=%d", i, rCurr.RequestNo, i-1, rPrev.RequestNo) } } else if rPrev.ClientID > rCurr.ClientID { - return fmt.Errorf("clientID inconsistency: previous request %d.ClientID=%d > %d.ClientID=%d", i-1, rPrev.ClientID, i, rCurr.ClientID) + // For requests that 
execute in parallel such as with PartitionQuery, + // we could have requests from previous clients executing slower than + // the newest client, hence this is not an error. } } @@ -306,10 +308,12 @@ func (it *interceptorTracker) unaryClientInterceptor(ctx context.Context, method it.unaryClientRequestIDSegments = append(it.unaryClientRequestIDSegments, reqID) it.mu.Unlock() + // fmt.Printf("unary.method=%q\n", method) // fmt.Printf("method=%q\nReq: %#v\nRes: %#v\n", method, req, reply) // Otherwise proceed with the call. return invoker(ctx, method, req, reply, cc, opts...) } + func (it *interceptorTracker) streamClientInterceptor(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) { it.nStreamClientCalls.Add(1) reqID, err := checkForMissingSpannerRequestIDHeader(opts) @@ -321,6 +325,7 @@ func (it *interceptorTracker) streamClientInterceptor(ctx context.Context, desc it.streamClientRequestIDSegments = append(it.streamClientRequestIDSegments, reqID) it.mu.Unlock() + // fmt.Printf("stream.method=%q\n", method) // Otherwise proceed with the call. return streamer(ctx, desc, cc, method, opts...) 
} @@ -723,3 +728,320 @@ func TestRequestIDHeader_ClientBatchWriteWithError(t *testing.T) { t.Fatal(err) } } + +func TestRequestIDHeader_PartitionQueryWithoutError(t *testing.T) { + testRequestIDHeaderPartitionQuery(t, false) +} + +func TestRequestIDHeader_PartitionQueryWithError(t *testing.T) { + testRequestIDHeaderPartitionQuery(t, true) +} + +func testRequestIDHeaderPartitionQuery(t *testing.T, mustErrorOnPartitionQuery bool) { + t.Parallel() + + interceptorTracker := newInterceptorTracker() + clientOpts := []option.ClientOption{ + option.WithGRPCDialOption(grpc.WithUnaryInterceptor(interceptorTracker.unaryClientInterceptor)), + option.WithGRPCDialOption(grpc.WithStreamInterceptor(interceptorTracker.streamClientInterceptor)), + } + clientConfig := ClientConfig{ + SessionPoolConfig: SessionPoolConfig{ + MinOpened: 2, + MaxOpened: 10, + WriteSessions: 0.2, + incStep: 2, + }, + } + + server, sc, tearDown := setupMockedTestServerWithConfigAndClientOptions(t, clientConfig, clientOpts) + t.Cleanup(tearDown) + defer sc.Close() + + // The request will initially fail, and be retried. 
+ server.TestSpanner.PutExecutionTime(testutil.MethodExecuteStreamingSql, + testutil.SimulatedExecutionTime{ + Errors: []error{newAbortedErrorWithMinimalRetryDelay()}, + }) + if mustErrorOnPartitionQuery { + server.TestSpanner.PutExecutionTime(testutil.MethodPartitionQuery, + testutil.SimulatedExecutionTime{ + Errors: []error{newAbortedErrorWithMinimalRetryDelay()}, + }) + } + + sqlFromSingers := "SELECT * FROM Singers" + resultSet := &sppb.ResultSet{ + Rows: []*structpb.ListValue{ + { + Values: []*structpb.Value{ + structpb.NewStructValue(&structpb.Struct{ + Fields: map[string]*structpb.Value{ + "SingerId": {Kind: &structpb.Value_NumberValue{NumberValue: 1}}, + "FirstName": {Kind: &structpb.Value_StringValue{StringValue: "Bruce"}}, + "LastName": {Kind: &structpb.Value_StringValue{StringValue: "Wayne"}}, + }, + }), + structpb.NewStructValue(&structpb.Struct{ + Fields: map[string]*structpb.Value{ + "SingerId": {Kind: &structpb.Value_NumberValue{NumberValue: 2}}, + "FirstName": {Kind: &structpb.Value_StringValue{StringValue: "Robin"}}, + "LastName": {Kind: &structpb.Value_StringValue{StringValue: "SideKick"}}, + }, + }), + structpb.NewStructValue(&structpb.Struct{ + Fields: map[string]*structpb.Value{ + "SingerId": {Kind: &structpb.Value_NumberValue{NumberValue: 3}}, + "FirstName": {Kind: &structpb.Value_StringValue{StringValue: "Gordon"}}, + "LastName": {Kind: &structpb.Value_StringValue{StringValue: "Commissioner"}}, + }, + }), + structpb.NewStructValue(&structpb.Struct{ + Fields: map[string]*structpb.Value{ + "SingerId": {Kind: &structpb.Value_NumberValue{NumberValue: 4}}, + "FirstName": {Kind: &structpb.Value_StringValue{StringValue: "Joker"}}, + "LastName": {Kind: &structpb.Value_StringValue{StringValue: "None"}}, + }, + }), + structpb.NewStructValue(&structpb.Struct{ + Fields: map[string]*structpb.Value{ + "SingerId": {Kind: &structpb.Value_NumberValue{NumberValue: 5}}, + "FirstName": {Kind: &structpb.Value_StringValue{StringValue: "Riddler"}}, + "LastName": 
{Kind: &structpb.Value_StringValue{StringValue: "None"}}, + }, + }), + }}, + }, + Metadata: &sppb.ResultSetMetadata{ + RowType: &sppb.StructType{ + Fields: []*sppb.StructType_Field{ + {Name: "SingerId", Type: &sppb.Type{Code: sppb.TypeCode_INT64}}, + {Name: "FirstName", Type: &sppb.Type{Code: sppb.TypeCode_STRING}}, + {Name: "LastName", Type: &sppb.Type{Code: sppb.TypeCode_STRING}}, + }, + }, + }, + } + result := &testutil.StatementResult{ + Type: testutil.StatementResultResultSet, + ResultSet: resultSet, + } + server.TestSpanner.PutStatementResult(sqlFromSingers, result) + + ctx := context.Background() + txn, err := sc.BatchReadOnlyTransaction(ctx, StrongRead()) + + if err != nil { + t.Fatal(err) + } + defer txn.Close() + + // Singer represents the elements in a row from the Singers table. + type Singer struct { + SingerID int64 + FirstName string + LastName string + SingerInfo []byte + } + stmt := Statement{SQL: "SELECT * FROM Singers;"} + partitions, err := txn.PartitionQuery(ctx, stmt, PartitionOptions{}) + + if mustErrorOnPartitionQuery { + // The methods invoked should be: ['/BatchCreateSessions', '/CreateSession', '/BeginTransaction', '/PartitionQuery'] + if g, w := interceptorTracker.unaryCallCount(), uint64(4); g != w { + t.Errorf("unaryClientCall is incorrect; got=%d want=%d", g, w) + } + + // PartitionQuery was made to fail, so no streaming RPCs are expected. 
+ if g, w := interceptorTracker.streamCallCount(), uint64(0); g != w { + t.Errorf("streamClientCall is incorrect; got=%d want=%d", g, w) + } + + if err := interceptorTracker.validateRequestIDsMonotonicity(); err != nil { + t.Fatal(err) + } + return + } + + if err != nil { + t.Fatal(err) + } + + wg := new(sync.WaitGroup) + for i, p := range partitions { + wg.Add(1) + go func(i int, p *Partition) { + defer wg.Done() + iter := txn.Execute(ctx, p) + defer iter.Stop() + for { + row, err := iter.Next() + if err == iterator.Done { + break + } + var s Singer + if err := row.ToStruct(&s); err != nil { + _ = err + } + _ = s + } + }(i, p) + } + wg.Wait() + + // The methods invoked should be: ['/BatchCreateSessions', '/CreateSession', '/BeginTransaction', '/PartitionQuery'] + if g, w := interceptorTracker.unaryCallCount(), uint64(4); g != w { + t.Errorf("unaryClientCall is incorrect; got=%d want=%d", g, w) + } + + // No streaming RPCs are expected to have been recorded by the interceptor. + if g, w := interceptorTracker.streamCallCount(), uint64(0); g != w { + t.Errorf("streamClientCall is incorrect; got=%d want=%d", g, w) + } + + if err := interceptorTracker.validateRequestIDsMonotonicity(); err != nil { + t.Fatal(err) + } +} + +func TestRequestIDHeader_ReadWriteTransactionUpdate(t *testing.T) { + t.Parallel() + + interceptorTracker := newInterceptorTracker() + clientOpts := []option.ClientOption{ + option.WithGRPCDialOption(grpc.WithUnaryInterceptor(interceptorTracker.unaryClientInterceptor)), + option.WithGRPCDialOption(grpc.WithStreamInterceptor(interceptorTracker.streamClientInterceptor)), + } + clientConfig := ClientConfig{ + SessionPoolConfig: SessionPoolConfig{ + MinOpened: 2, + MaxOpened: 10, + WriteSessions: 0.2, + incStep: 2, + }, + } + + server, sc, tearDown := setupMockedTestServerWithConfigAndClientOptions(t, clientConfig, clientOpts) + t.Cleanup(tearDown) + defer sc.Close() + + ctx := context.Background() + updateSQL := testutil.UpdateBarSetFoo + _, err := 
sc.ReadWriteTransaction(ctx, func(ctx context.Context, tx *ReadWriteTransaction) (err error) { + if _, err = tx.Update(ctx, Statement{SQL: updateSQL}); err != nil { + return err + } + if _, err = tx.BatchUpdate(ctx, []Statement{{SQL: updateSQL}, {SQL: updateSQL}}); err != nil { + return err + } + if _, err = tx.Update(ctx, Statement{SQL: updateSQL}); err != nil { + return err + } + _, err = tx.BatchUpdate(ctx, []Statement{{SQL: updateSQL}}) + return err + }) + if err != nil { + t.Fatal(err) + } + + gotReqs, err := shouldHaveReceived(server.TestSpanner, []any{ + &sppb.BatchCreateSessionsRequest{}, + &sppb.ExecuteSqlRequest{}, + &sppb.ExecuteBatchDmlRequest{}, + &sppb.ExecuteSqlRequest{}, + &sppb.ExecuteBatchDmlRequest{}, + &sppb.CommitRequest{}, + }) + if err != nil { + t.Fatal(err) + } + muxCreateBuffer := 0 + if isMultiplexEnabled { + muxCreateBuffer = 1 + } + if got, want := gotReqs[1+muxCreateBuffer].(*sppb.ExecuteSqlRequest).Seqno, int64(1); got != want { + t.Errorf("got %d, want %d", got, want) + } + if got, want := gotReqs[2+muxCreateBuffer].(*sppb.ExecuteBatchDmlRequest).Seqno, int64(2); got != want { + t.Errorf("got %d, want %d", got, want) + } + if got, want := gotReqs[3+muxCreateBuffer].(*sppb.ExecuteSqlRequest).Seqno, int64(3); got != want { + t.Errorf("got %d, want %d", got, want) + } + if got, want := gotReqs[4+muxCreateBuffer].(*sppb.ExecuteBatchDmlRequest).Seqno, int64(4); got != want { + t.Errorf("got %d, want %d", got, want) + } + + // The methods invoked should be: ['/BatchCreateSessions', '/ExecuteSql', '/ExecuteBatchDml', '/ExecuteSql', '/ExecuteBatchDml', '/Commit'] + if g, w := interceptorTracker.unaryCallCount(), uint64(6); g != w { + t.Errorf("unaryClientCall is incorrect; got=%d want=%d", g, w) + } + + // All of the methods invoked are unary, so no streaming RPCs are expected. 
+ if g, w := interceptorTracker.streamCallCount(), uint64(0); g != w { + t.Errorf("streamClientCall is incorrect; got=%d want=%d", g, w) + } + + if err := interceptorTracker.validateRequestIDsMonotonicity(); err != nil { + t.Fatal(err) + } +} + +func TestRequestIDHeader_ReadWriteTransactionBatchUpdateWithOptions(t *testing.T) { + t.Parallel() + + interceptorTracker := newInterceptorTracker() + clientOpts := []option.ClientOption{ + option.WithGRPCDialOption(grpc.WithUnaryInterceptor(interceptorTracker.unaryClientInterceptor)), + option.WithGRPCDialOption(grpc.WithStreamInterceptor(interceptorTracker.streamClientInterceptor)), + } + clientConfig := ClientConfig{ + SessionPoolConfig: SessionPoolConfig{ + MinOpened: 2, + MaxOpened: 10, + WriteSessions: 0.2, + incStep: 2, + }, + } + + _, sc, tearDown := setupMockedTestServerWithConfigAndClientOptions(t, clientConfig, clientOpts) + t.Cleanup(tearDown) + defer sc.Close() + + ctx := context.Background() + selectSQL := testutil.SelectSingerIDAlbumIDAlbumTitleFromAlbums + updateSQL := testutil.UpdateBarSetFoo + _, err := sc.ReadWriteTransaction(ctx, func(ctx context.Context, tx *ReadWriteTransaction) (err error) { + iter := tx.QueryWithOptions(ctx, NewStatement(selectSQL), QueryOptions{}) + iter.Next() + iter.Stop() + + qo := QueryOptions{} + iter = tx.ReadWithOptions(ctx, "FOO", AllKeys(), []string{"BAR"}, &ReadOptions{Priority: qo.Priority}) + iter.Next() + iter.Stop() + + tx.UpdateWithOptions(ctx, NewStatement(updateSQL), qo) + tx.BatchUpdateWithOptions(ctx, []Statement{ + NewStatement(updateSQL), + }, qo) + return nil + }) + if err != nil { + t.Fatal(err) + } + + // The methods invoked should be: ['/BatchCreateSessions', '/ExecuteSql', '/ExecuteBatchDml', '/Commit'] + if g, w := interceptorTracker.unaryCallCount(), uint64(4); g != w { + t.Errorf("unaryClientCall is incorrect; got=%d want=%d", g, w) + } + + // The methods invoked should be: ['/ExecuteStreamingSql', '/StreamingRead'] + if g, w := 
interceptorTracker.streamCallCount(), uint64(2); g != w { + t.Errorf("streamClientCall is incorrect; got=%d want=%d", g, w) + } + + if err := interceptorTracker.validateRequestIDsMonotonicity(); err != nil { + t.Fatal(err) + } +} diff --git a/spanner/transaction.go b/spanner/transaction.go index b4f29767a23c..f1fc2fe349f4 100644 --- a/spanner/transaction.go +++ b/spanner/transaction.go @@ -812,7 +812,9 @@ func (t *ReadOnlyTransaction) begin(ctx context.Context) error { }, }, }, gax.WithGRPCOptions(grpc.Header(&md))) - gcl.setOrResetRPCID() + if ok { + gcl.setOrResetRPCID() + } if getGFELatencyMetricsFlag() && md != nil && t.ct != nil { if err := createContextAndCaptureGFELatencyMetrics(ctx, t.ct, md, "begin_BeginTransaction"); err != nil { @@ -1213,10 +1215,14 @@ func (t *ReadWriteTransaction) update(ctx context.Context, stmt Statement, opts sh.updateLastUseTime() var md metadata.MD - // TODO: (@odeke-em) retrieve the requestID and increment the RPC number - // then send it along in every call per retry. - resultSet, err := sh.getClient().ExecuteSql(contextWithOutgoingMetadata(ctx, sh.getMetadata(), t.disableRouteToLeader), req, gax.WithGRPCOptions(grpc.Header(&md))) + client := sh.getClient() + // ReadWriteTransaction.update has to be retried manually and is not retried + // automatically, hence we do not need to extract and increment the request-id's RPC count. 
+ resultSet, err := client.ExecuteSql(contextWithOutgoingMetadata(ctx, sh.getMetadata(), t.disableRouteToLeader), req, gax.WithGRPCOptions(grpc.Header(&md))) + if gcl, ok := client.(*grpcSpannerClient); ok { + gcl.setOrResetRPCID() + } if getGFELatencyMetricsFlag() && md != nil && t.ct != nil { if err := createContextAndCaptureGFELatencyMetrics(ctx, t.ct, md, "update"); err != nil { @@ -1319,16 +1325,20 @@ func (t *ReadWriteTransaction) batchUpdateWithOptions(ctx context.Context, stmts sh.updateLastUseTime() var md metadata.MD - // TODO: (@odeke-em) retrieve the requestID and increment the RPC number - // then send it along in every call per retry. - resp, err := sh.getClient().ExecuteBatchDml(contextWithOutgoingMetadata(ctx, sh.getMetadata(), t.disableRouteToLeader), &sppb.ExecuteBatchDmlRequest{ + // ReadWriteTransaction.batchUpdateWithOptions has to be retried manually and is not + // retried automatically, hence we do not need to extract and increment the request-id. + client := sh.getClient() + resp, err := client.ExecuteBatchDml(contextWithOutgoingMetadata(ctx, sh.getMetadata(), t.disableRouteToLeader), &sppb.ExecuteBatchDmlRequest{ Session: sh.getID(), Transaction: ts, Statements: sppbStmts, Seqno: atomic.AddInt64(&t.sequenceNumber, 1), RequestOptions: createRequestOptions(opts.Priority, opts.RequestTag, t.txOpts.TransactionTag), }, gax.WithGRPCOptions(grpc.Header(&md))) + if gcl, ok := client.(*grpcSpannerClient); ok { + gcl.setOrResetRPCID() + } if getGFELatencyMetricsFlag() && md != nil && t.ct != nil { if err := createContextAndCaptureGFELatencyMetrics(ctx, t.ct, md, "batchUpdateWithOptions"); err != nil { @@ -1366,7 +1376,7 @@ func (t *ReadWriteTransaction) batchUpdateWithOptions(ctx context.Context, stmts return counts, errInlineBeginTransactionFailed() } if resp.Status != nil && resp.Status.Code != 0 { - return counts, 
spannerError(codes.Code(uint32(resp.Status.Code)), resp.Status.Message) } return counts, nil } @@ -1517,6 +1527,9 @@ func beginTransaction(ctx context.Context, sid string, client spannerClient, opt ExcludeTxnFromChangeStreams: opts.ExcludeTxnFromChangeStreams, }, }) + if gcl, ok := client.(*grpcSpannerClient); ok { + gcl.setOrResetRPCID() + } if err != nil { return nil, err } @@ -1665,6 +1678,8 @@ func (t *ReadWriteTransaction) commit(ctx context.Context, options CommitOptions if options.MaxCommitDelay != nil { maxCommitDelay = durationpb.New(*(options.MaxCommitDelay)) } + // .commit is not automatically retried hence no need to + // extract the spanner requestID to increment the retry count. res, e := client.Commit(contextWithOutgoingMetadata(ctx, t.sh.getMetadata(), t.disableRouteToLeader), &sppb.CommitRequest{ Session: sid, Transaction: &sppb.CommitRequest_TransactionId{ @@ -1675,8 +1690,9 @@ func (t *ReadWriteTransaction) commit(ctx context.Context, options CommitOptions ReturnCommitStats: options.ReturnCommitStats, MaxCommitDelay: maxCommitDelay, }, gax.WithGRPCOptions(grpc.Header(&md))) - // TODO: (@odeke-em) retrieve the requestID and increment the RPC number - // then send it along in every call per retry. + if gcl, ok := client.(*grpcSpannerClient); ok { + gcl.setOrResetRPCID() + } if getGFELatencyMetricsFlag() && md != nil && t.ct != nil { if err := createContextAndCaptureGFELatencyMetrics(ctx, t.ct, md, "commit"); err != nil { @@ -1723,6 +1739,9 @@ func (t *ReadWriteTransaction) rollback(ctx context.Context) { Session: sid, TransactionId: t.tx, }) + if gcl, ok := client.(*grpcSpannerClient); ok { + gcl.setOrResetRPCID() + } if isSessionNotFoundError(err) { t.sh.destroy() } @@ -1924,6 +1943,7 @@ func (t *writeOnlyTransaction) applyAtLeastOnce(ctx context.Context, ms ...*Muta retryer := onCodes(DefaultRetryBackoff, codes.Aborted, codes.Internal) // Apply the mutation and retry if the commit is aborted. 
applyMutationWithRetry := func(ctx context.Context) error { + nRPCs := uint64(0) for { if sh == nil || sh.getID() == "" || sh.getClient() == nil { // No usable session for doing the commit, take one from pool. @@ -1934,8 +1954,15 @@ func (t *writeOnlyTransaction) applyAtLeastOnce(ctx context.Context, ms ...*Muta return ToSpannerError(err) } } + nRPCs++ sh.updateLastUseTime() - res, err := sh.getClient().Commit(contextWithOutgoingMetadata(ctx, sh.getMetadata(), t.disableRouteToLeader), &sppb.CommitRequest{ + client := sh.getClient() + // Firstly set the number of retries as the RPCID. + gcl, ok := client.(*grpcSpannerClient) + if ok { + gcl.setRPCID(nRPCs) + } + res, err := client.Commit(contextWithOutgoingMetadata(ctx, sh.getMetadata(), t.disableRouteToLeader), &sppb.CommitRequest{ Session: sh.getID(), Transaction: &sppb.CommitRequest_SingleUseTransaction{ SingleUseTransaction: &sppb.TransactionOptions{ @@ -1949,6 +1976,9 @@ func (t *writeOnlyTransaction) applyAtLeastOnce(ctx context.Context, ms ...*Muta RequestOptions: createRequestOptions(t.commitPriority, "", t.transactionTag), MaxCommitDelay: maxCommitDelay, }) + if ok { + gcl.setOrResetRPCID() + } if err != nil && !isAbortedErr(err) { // should not be the case with multiplexed sessions if isSessionNotFoundError(err) {