From 18cd8d4293742eb5e79d71f710e3c9ddcfc30340 Mon Sep 17 00:00:00 2001
From: Emmanuel T Odeke
Date: Tue, 29 Oct 2024 02:32:49 -0700
Subject: [PATCH] spanner: implement generation and propagation of
 "x-goog-spanner-request-id" header

This change generates and sends the "x-goog-spanner-request-id" header on
every unary and streaming call, in the form:

    <processId>.<clientId>.<requestCountForClient>.<channelId>.<rpcCountForRequest>

where:

* processId is a randomly generated uint32 singleton for the lifetime of a process
* clientId is the monotonically increasing id/number of gRPC Spanner clients created
* requestCountForClient is the monotonically increasing number of requests made by the client
* channelId is currently always 1 for Go, which multiplexes over a single grpc.ClientConn per client
* rpcCountForRequest is the number of RPCs/retries within a specific request

The header is sent on both unary and streaming calls and will help debug
latencies for customers. After this change, the next phase will provide a
mechanism for customers to consume the requestID and log it, along with
documentation for how to accomplish that.

Updates #11073
---
 spanner/batch.go                              |    6 +
 spanner/client.go                             |   13 +-
 spanner/errors.go                             |    4 +
 spanner/grpc_client.go                        |   94 +-
 .../internal/testutil/inmem_spanner_server.go |    4 +
 spanner/pdml.go                               |   11 +-
 spanner/request_id_header_test.go             | 1047 +++++++++++++++++
 spanner/sessionclient.go                      |    4 +
 spanner/transaction.go                        |   61 +-
 9 files changed, 1220 insertions(+), 24 deletions(-)
 create mode 100644 spanner/request_id_header_test.go

diff --git a/spanner/batch.go b/spanner/batch.go
index 69399d0fecce..e09ae794dab0 100644
--- a/spanner/batch.go
+++ b/spanner/batch.go
@@ -209,7 +209,13 @@ func (t *BatchReadOnlyTransaction) partitionQuery(ctx context.Context, statement
 		ParamTypes: paramTypes,
 	}
 	sh.updateLastUseTime()
+
+	// PartitionQuery does not retry automatically so we don't need to retrieve
+	// the injected requestID to increment the RPC number on retries.
 	resp, err := client.PartitionQuery(contextWithOutgoingMetadata(ctx, sh.getMetadata(), t.disableRouteToLeader), req, gax.WithGRPCOptions(grpc.Header(&md)))
+	if gcl, ok := client.(*grpcSpannerClient); ok {
+		gcl.setOrResetRPCID()
+	}
 	if getGFELatencyMetricsFlag() && md != nil && t.ct != nil {
 		if err := createContextAndCaptureGFELatencyMetrics(ctx, t.ct, md, "partitionQuery"); err != nil {
diff --git a/spanner/client.go b/spanner/client.go
index a795e0f7a32c..346f52112edd 100644
--- a/spanner/client.go
+++ b/spanner/client.go
@@ -1318,10 +1318,21 @@ func (c *Client) BatchWriteWithOptions(ctx context.Context, mgs []*MutationGroup
 		return &BatchWriteResponseIterator{meterTracerFactory: c.metricsTracerFactory, err: err}
 	}
 
+	nRPCs := uint64(0)
 	rpc := func(ct context.Context) (sppb.Spanner_BatchWriteClient, error) {
 		var md metadata.MD
 		sh.updateLastUseTime()
-		stream, rpcErr := sh.getClient().BatchWrite(contextWithOutgoingMetadata(ct, sh.getMetadata(), c.disableRouteToLeader), &sppb.BatchWriteRequest{
+		nRPCs++
+
+		// Record the attempt number as the rpcID for this call.
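+		// Note: setRPCID below stamps this attempt number into the final field of
+		// the outgoing x-goog-spanner-request-id header, and the deferred
+		// setOrResetRPCID resets it to 1 so that subsequent, unrelated calls on
+		// this client start counting attempts afresh.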
+		client := sh.getClient()
+		gcl, ok := client.(*grpcSpannerClient)
+		if ok {
+			gcl.setRPCID(nRPCs)
+			defer gcl.setOrResetRPCID()
+		}
+
+		stream, rpcErr := client.BatchWrite(contextWithOutgoingMetadata(ct, sh.getMetadata(), c.disableRouteToLeader), &sppb.BatchWriteRequest{
 			Session:        sh.getID(),
 			MutationGroups: mgsPb,
 			RequestOptions: createRequestOptions(opts.Priority, "", opts.TransactionTag),
diff --git a/spanner/errors.go b/spanner/errors.go
index edb52d26a47f..871e2a6b4941 100644
--- a/spanner/errors.go
+++ b/spanner/errors.go
@@ -123,6 +123,10 @@ func (e *Error) decorate(info string) {
 // APIError error having given error code as its status.
 func spannerErrorf(code codes.Code, format string, args ...interface{}) error {
 	msg := fmt.Sprintf(format, args...)
+	return spannerError(code, msg)
+}
+
+func spannerError(code codes.Code, msg string) error {
 	wrapped, _ := apierror.FromError(status.Error(code, msg))
 	return &Error{
 		Code: code,
diff --git a/spanner/grpc_client.go b/spanner/grpc_client.go
index 9b7f1bca4ca6..517d3c30f19f 100644
--- a/spanner/grpc_client.go
+++ b/spanner/grpc_client.go
@@ -18,7 +18,11 @@ package spanner
 
 import (
 	"context"
+	"fmt"
+	"math/rand"
 	"strings"
+	"sync/atomic"
+	"time"
 
 	vkit "cloud.google.com/go/spanner/apiv1"
 	"cloud.google.com/go/spanner/apiv1/spannerpb"
@@ -26,6 +30,7 @@ import (
 	"github.com/googleapis/gax-go/v2"
 	"google.golang.org/api/option"
 	"google.golang.org/grpc"
+	"google.golang.org/grpc/metadata"
 	"google.golang.org/grpc/peer"
 	"google.golang.org/grpc/status"
 )
@@ -65,10 +70,43 @@ type spannerClient interface {
 // grpcSpannerClient is the gRPC API implementation of the transport-agnostic
 // spannerClient interface.
 type grpcSpannerClient struct {
+	id                   uint64
 	raw                  *vkit.Client
 	metricsTracerFactory *builtinMetricsTracerFactory
+
+	// These fields are used to uniquely track the x-goog-spanner-request-id.
+	// grpc.ClientConn is presumed to be the channel, hence channelID
+	// is redundant. However, is it correct to presume that raw.Connection()
+	// will always be the same throughout the lifetime of a grpcSpannerClient?
+	channelID uint64
+	// nthRequest shall always be incremented on every fresh request.
+	nthRequest *atomic.Uint32
+	// This id uniquely defines the RPC being issued and in
+	// the case of retries it should be incremented.
+	rpcID *atomic.Uint64
+}
+
+func (g *grpcSpannerClient) setOrResetRPCID() {
+	if g.rpcID == nil {
+		g.rpcID = new(atomic.Uint64)
+	}
+	g.rpcID.Store(1)
+}
+
+func (g *grpcSpannerClient) setRPCID(rpcID uint64) {
+	g.rpcID.Store(rpcID)
+}
+
+func (g *grpcSpannerClient) prepareRequestIDTrackers() {
+	g.id = nGRPCClient.Add(1)
+	g.channelID = 1 // Assuming that .raw.Connection() never changes.
+	g.nthRequest = new(atomic.Uint32)
+	g.setOrResetRPCID()
+}
+
+var nGRPCClient = new(atomic.Uint64)
+
 var (
 	// Ensure that grpcSpannerClient implements spannerClient.
_ spannerClient = (*grpcSpannerClient)(nil) @@ -83,6 +121,8 @@ func newGRPCSpannerClient(ctx context.Context, sc *sessionClient, opts ...option } g := &grpcSpannerClient{raw: raw, metricsTracerFactory: sc.metricsTracerFactory} + g.prepareRequestIDTrackers() + clientInfo := []string{"gccl", internal.Version} if sc.userAgent != "" { agentWithVersion := strings.SplitN(sc.userAgent, "/", 2) @@ -118,7 +158,7 @@ func (g *grpcSpannerClient) CreateSession(ctx context.Context, req *spannerpb.Cr mt := g.newBuiltinMetricsTracer(ctx) defer recordOperationCompletion(mt) ctx = context.WithValue(ctx, metricsTracerKey, mt) - resp, err := g.raw.CreateSession(ctx, req, opts...) + resp, err := g.raw.CreateSession(ctx, req, g.optsWithNextRequestID(opts)...) statusCode, _ := status.FromError(err) mt.currOp.setStatus(statusCode.Code().String()) return resp, err @@ -128,7 +168,7 @@ func (g *grpcSpannerClient) BatchCreateSessions(ctx context.Context, req *spanne mt := g.newBuiltinMetricsTracer(ctx) defer recordOperationCompletion(mt) ctx = context.WithValue(ctx, metricsTracerKey, mt) - resp, err := g.raw.BatchCreateSessions(ctx, req, opts...) + resp, err := g.raw.BatchCreateSessions(ctx, req, g.optsWithNextRequestID(opts)...) statusCode, _ := status.FromError(err) mt.currOp.setStatus(statusCode.Code().String()) return resp, err @@ -138,45 +178,67 @@ func (g *grpcSpannerClient) GetSession(ctx context.Context, req *spannerpb.GetSe mt := g.newBuiltinMetricsTracer(ctx) defer recordOperationCompletion(mt) ctx = context.WithValue(ctx, metricsTracerKey, mt) - resp, err := g.raw.GetSession(ctx, req, opts...) + resp, err := g.raw.GetSession(ctx, req, g.optsWithNextRequestID(opts)...) statusCode, _ := status.FromError(err) mt.currOp.setStatus(statusCode.Code().String()) return resp, err } func (g *grpcSpannerClient) ListSessions(ctx context.Context, req *spannerpb.ListSessionsRequest, opts ...gax.CallOption) *vkit.SessionIterator { - return g.raw.ListSessions(ctx, req, opts...) + return g.raw.ListSessions(ctx, req, g.optsWithNextRequestID(opts)...) } func (g *grpcSpannerClient) DeleteSession(ctx context.Context, req *spannerpb.DeleteSessionRequest, opts ...gax.CallOption) error { mt := g.newBuiltinMetricsTracer(ctx) defer recordOperationCompletion(mt) ctx = context.WithValue(ctx, metricsTracerKey, mt) - err := g.raw.DeleteSession(ctx, req, opts...) + err := g.raw.DeleteSession(ctx, req, g.optsWithNextRequestID(opts)...) statusCode, _ := status.FromError(err) mt.currOp.setStatus(statusCode.Code().String()) return err } +var randIdForProcess uint32 + +func init() { + randIdForProcess = rand.New(rand.NewSource(time.Now().UnixNano())).Uint32() +} + +const xSpannerRequestIDHeader = "x-goog-spanner-request-id" + +// optsWithNextRequestID bundles priors with a new header "x-goog-spanner-request-id" +func (g *grpcSpannerClient) optsWithNextRequestID(priors []gax.CallOption) []gax.CallOption { + // TODO: Decide if each field should be padded and to what width or + // should we just let fields fill up so as to reduce bandwidth? + // Go creates grpc.ClientConn which is presumed to be a channel, so channelID is going to be redundant. 
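+	// For example, the 3rd request made by the 2nd client in this process, on
+	// its 1st attempt, would be stamped as "<randIdForProcess>.2.3.1.1".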
+ requestID := fmt.Sprintf("%d.%d.%d.%d.%d", randIdForProcess, g.id, g.nextNthRequest(), g.channelID, g.rpcID.Load()) + md := metadata.MD{xSpannerRequestIDHeader: []string{requestID}} + return append(priors, gax.WithGRPCOptions(grpc.Header(&md))) +} + +func (g *grpcSpannerClient) nextNthRequest() uint32 { + return g.nthRequest.Add(1) +} + func (g *grpcSpannerClient) ExecuteSql(ctx context.Context, req *spannerpb.ExecuteSqlRequest, opts ...gax.CallOption) (*spannerpb.ResultSet, error) { mt := g.newBuiltinMetricsTracer(ctx) defer recordOperationCompletion(mt) ctx = context.WithValue(ctx, metricsTracerKey, mt) - resp, err := g.raw.ExecuteSql(ctx, req, opts...) + resp, err := g.raw.ExecuteSql(ctx, req, g.optsWithNextRequestID(opts)...) statusCode, _ := status.FromError(err) mt.currOp.setStatus(statusCode.Code().String()) return resp, err } func (g *grpcSpannerClient) ExecuteStreamingSql(ctx context.Context, req *spannerpb.ExecuteSqlRequest, opts ...gax.CallOption) (spannerpb.Spanner_ExecuteStreamingSqlClient, error) { - return g.raw.ExecuteStreamingSql(peer.NewContext(ctx, &peer.Peer{}), req, opts...) + return g.raw.ExecuteStreamingSql(peer.NewContext(ctx, &peer.Peer{}), req, g.optsWithNextRequestID(opts)...) } func (g *grpcSpannerClient) ExecuteBatchDml(ctx context.Context, req *spannerpb.ExecuteBatchDmlRequest, opts ...gax.CallOption) (*spannerpb.ExecuteBatchDmlResponse, error) { mt := g.newBuiltinMetricsTracer(ctx) defer recordOperationCompletion(mt) ctx = context.WithValue(ctx, metricsTracerKey, mt) - resp, err := g.raw.ExecuteBatchDml(ctx, req, opts...) + resp, err := g.raw.ExecuteBatchDml(ctx, req, g.optsWithNextRequestID(opts)...) statusCode, _ := status.FromError(err) mt.currOp.setStatus(statusCode.Code().String()) return resp, err @@ -186,21 +248,21 @@ func (g *grpcSpannerClient) Read(ctx context.Context, req *spannerpb.ReadRequest mt := g.newBuiltinMetricsTracer(ctx) defer recordOperationCompletion(mt) ctx = context.WithValue(ctx, metricsTracerKey, mt) - resp, err := g.raw.Read(ctx, req, opts...) + resp, err := g.raw.Read(ctx, req, g.optsWithNextRequestID(opts)...) statusCode, _ := status.FromError(err) mt.currOp.setStatus(statusCode.Code().String()) return resp, err } func (g *grpcSpannerClient) StreamingRead(ctx context.Context, req *spannerpb.ReadRequest, opts ...gax.CallOption) (spannerpb.Spanner_StreamingReadClient, error) { - return g.raw.StreamingRead(peer.NewContext(ctx, &peer.Peer{}), req, opts...) + return g.raw.StreamingRead(peer.NewContext(ctx, &peer.Peer{}), req, g.optsWithNextRequestID(opts)...) } func (g *grpcSpannerClient) BeginTransaction(ctx context.Context, req *spannerpb.BeginTransactionRequest, opts ...gax.CallOption) (*spannerpb.Transaction, error) { mt := g.newBuiltinMetricsTracer(ctx) defer recordOperationCompletion(mt) ctx = context.WithValue(ctx, metricsTracerKey, mt) - resp, err := g.raw.BeginTransaction(ctx, req, opts...) + resp, err := g.raw.BeginTransaction(ctx, req, g.optsWithNextRequestID(opts)...) statusCode, _ := status.FromError(err) mt.currOp.setStatus(statusCode.Code().String()) return resp, err @@ -210,7 +272,7 @@ func (g *grpcSpannerClient) Commit(ctx context.Context, req *spannerpb.CommitReq mt := g.newBuiltinMetricsTracer(ctx) defer recordOperationCompletion(mt) ctx = context.WithValue(ctx, metricsTracerKey, mt) - resp, err := g.raw.Commit(ctx, req, opts...) + resp, err := g.raw.Commit(ctx, req, g.optsWithNextRequestID(opts)...) 
statusCode, _ := status.FromError(err) mt.currOp.setStatus(statusCode.Code().String()) return resp, err @@ -220,7 +282,7 @@ func (g *grpcSpannerClient) Rollback(ctx context.Context, req *spannerpb.Rollbac mt := g.newBuiltinMetricsTracer(ctx) defer recordOperationCompletion(mt) ctx = context.WithValue(ctx, metricsTracerKey, mt) - err := g.raw.Rollback(ctx, req, opts...) + err := g.raw.Rollback(ctx, req, g.optsWithNextRequestID(opts)...) statusCode, _ := status.FromError(err) mt.currOp.setStatus(statusCode.Code().String()) return err @@ -230,7 +292,7 @@ func (g *grpcSpannerClient) PartitionQuery(ctx context.Context, req *spannerpb.P mt := g.newBuiltinMetricsTracer(ctx) defer recordOperationCompletion(mt) ctx = context.WithValue(ctx, metricsTracerKey, mt) - resp, err := g.raw.PartitionQuery(ctx, req, opts...) + resp, err := g.raw.PartitionQuery(ctx, req, g.optsWithNextRequestID(opts)...) statusCode, _ := status.FromError(err) mt.currOp.setStatus(statusCode.Code().String()) return resp, err @@ -240,12 +302,12 @@ func (g *grpcSpannerClient) PartitionRead(ctx context.Context, req *spannerpb.Pa mt := g.newBuiltinMetricsTracer(ctx) defer recordOperationCompletion(mt) ctx = context.WithValue(ctx, metricsTracerKey, mt) - resp, err := g.raw.PartitionRead(ctx, req, opts...) + resp, err := g.raw.PartitionRead(ctx, req, g.optsWithNextRequestID(opts)...) statusCode, _ := status.FromError(err) mt.currOp.setStatus(statusCode.Code().String()) return resp, err } func (g *grpcSpannerClient) BatchWrite(ctx context.Context, req *spannerpb.BatchWriteRequest, opts ...gax.CallOption) (spannerpb.Spanner_BatchWriteClient, error) { - return g.raw.BatchWrite(peer.NewContext(ctx, &peer.Peer{}), req, opts...) + return g.raw.BatchWrite(peer.NewContext(ctx, &peer.Peer{}), req, g.optsWithNextRequestID(opts)...) } diff --git a/spanner/internal/testutil/inmem_spanner_server.go b/spanner/internal/testutil/inmem_spanner_server.go index ae73b82230a1..08be3b21742c 100644 --- a/spanner/internal/testutil/inmem_spanner_server.go +++ b/spanner/internal/testutil/inmem_spanner_server.go @@ -90,6 +90,7 @@ const ( MethodExecuteBatchDml string = "EXECUTE_BATCH_DML" MethodStreamingRead string = "EXECUTE_STREAMING_READ" MethodBatchWrite string = "BATCH_WRITE" + MethodPartitionQuery string = "PARTITION_QUERY" ) // StatementResult represents a mocked result on the test server. The result is @@ -1107,6 +1108,9 @@ func (s *inMemSpannerServer) Rollback(ctx context.Context, req *spannerpb.Rollba } func (s *inMemSpannerServer) PartitionQuery(ctx context.Context, req *spannerpb.PartitionQueryRequest) (*spannerpb.PartitionResponse, error) { + if err := s.simulateExecutionTime(MethodPartitionQuery, req); err != nil { + return nil, err + } s.mu.Lock() if s.stopped { s.mu.Unlock() diff --git a/spanner/pdml.go b/spanner/pdml.go index bb33ef291c64..05bcb693d26f 100644 --- a/spanner/pdml.go +++ b/spanner/pdml.go @@ -110,13 +110,19 @@ func executePdml(ctx context.Context, sh *sessionHandle, req *sppb.ExecuteSqlReq var md metadata.MD sh.updateLastUseTime() // Begin transaction. 
-	res, err := sh.getClient().BeginTransaction(ctx, &sppb.BeginTransactionRequest{
+	client := sh.getClient()
+	res, err := client.BeginTransaction(ctx, &sppb.BeginTransactionRequest{
 		Session: sh.getID(),
 		Options: &sppb.TransactionOptions{
 			Mode: &sppb.TransactionOptions_PartitionedDml_{PartitionedDml: &sppb.TransactionOptions_PartitionedDml{}},
 			ExcludeTxnFromChangeStreams: options.ExcludeTxnFromChangeStreams,
 		},
 	})
+	// This function is invoked afresh on every retry and it retrieves a fresh
+	// client each time, hence it does not need an extraction and increment of
+	// the injected spanner requestID.
+	if gcl, ok := client.(*grpcSpannerClient); ok {
+		defer gcl.setOrResetRPCID()
+	}
 	if err != nil {
 		return 0, ToSpannerError(err)
 	}
@@ -126,7 +132,8 @@ func executePdml(ctx context.Context, sh *sessionHandle, req *sppb.ExecuteSqlReq
 	}
 	sh.updateLastUseTime()
 
-	resultSet, err := sh.getClient().ExecuteSql(ctx, req, gax.WithGRPCOptions(grpc.Header(&md)))
+
+	resultSet, err := client.ExecuteSql(ctx, req, gax.WithGRPCOptions(grpc.Header(&md)))
 	if getGFELatencyMetricsFlag() && md != nil && sh.session.pool != nil {
 		err := captureGFELatencyStats(tag.NewContext(ctx, sh.session.pool.tagMap), md, "executePdml_ExecuteSql")
 		if err != nil {
diff --git a/spanner/request_id_header_test.go b/spanner/request_id_header_test.go
new file mode 100644
index 000000000000..b843951dbc9f
--- /dev/null
+++ b/spanner/request_id_header_test.go
@@ -0,0 +1,1047 @@
+// Copyright 2024 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//	http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package spanner
+
+import (
+	"context"
+	"encoding/json"
+	"fmt"
+	"regexp"
+	"strings"
+	"sync"
+	"sync/atomic"
+	"testing"
+
+	sppb "cloud.google.com/go/spanner/apiv1/spannerpb"
+	"google.golang.org/api/iterator"
+	"google.golang.org/api/option"
+	"google.golang.org/grpc"
+	"google.golang.org/grpc/codes"
+	"google.golang.org/grpc/status"
+	"google.golang.org/protobuf/types/known/structpb"
+
+	"cloud.google.com/go/spanner/internal/testutil"
+)
+
+var regRequestID = regexp.MustCompile(`^(?P<randProcessId>\d+)\.(?P<clientId>\d+)\.(?P<reqId>\d+)\.(?P<channelId>\d+)\.(?P<rpcId>\d+)$`)
+
+type requestIDSegments struct {
+	ProcessID uint32 `json:"proc_id"`
+	ClientID  uint32 `json:"c_id"`
+	RequestNo uint32 `json:"req_id"`
+	ChannelID uint32 `json:"ch_id"`
+	RPCNo     uint32 `json:"rpc_id"`
+}
+
+func checkForMissingSpannerRequestIDHeader(opts []grpc.CallOption) (*requestIDSegments, error) {
+	requestID := ""
+	for _, opt := range opts {
+		if hdrOpt, ok := opt.(grpc.HeaderCallOption); ok {
+			hdrs := hdrOpt.HeaderAddr.Get(xSpannerRequestIDHeader)
+			gotRequestID := len(hdrs) != 0 && len(hdrs[0]) != 0
+			if gotRequestID {
+				requestID = hdrs[0]
+				break
+			}
+		}
+	}
+
+	if requestID == "" {
+		return nil, status.Errorf(codes.InvalidArgument, "missing %q header", xSpannerRequestIDHeader)
+	}
+	if !regRequestID.MatchString(requestID) {
+		return nil, status.Errorf(codes.InvalidArgument, "requestID does not conform to pattern=%q", regRequestID.String())
+	}
+
+	// Now extract the respective fields and validate that they match our rubric.
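+	// The named capture groups in regRequestID line up with the JSON template
+	// below: ExpandString substitutes, e.g., $clientId with its matched digits,
+	// yielding a JSON object that unmarshals straight into requestIDSegments.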
+	template := `{"proc_id":$randProcessId,"c_id":$clientId,"req_id":$reqId,"ch_id":$channelId,"rpc_id":$rpcId}`
+	asJSONBytes := []byte{}
+	for _, submatch := range regRequestID.FindAllStringSubmatchIndex(requestID, -1) {
+		asJSONBytes = regRequestID.ExpandString(asJSONBytes, template, requestID, submatch)
+	}
+	recv := new(requestIDSegments)
+	if err := json.Unmarshal(asJSONBytes, recv); err != nil {
+		return nil, status.Error(codes.InvalidArgument, "could not correctly parse requestID segments")
+	}
+	if g, w := recv.ProcessID, randIdForProcess; g != w {
+		return nil, status.Errorf(codes.InvalidArgument, "invalid processId, got=%d want=%d", g, w)
+	}
+	if g := recv.ClientID; g < 1 {
+		return nil, status.Errorf(codes.InvalidArgument, "clientID must be >= 1, got=%d", g)
+	}
+	if g := recv.RequestNo; g < 1 {
+		return nil, status.Errorf(codes.InvalidArgument, "requestNumber must be >= 1, got=%d", g)
+	}
+	if g := recv.ChannelID; g < 1 {
+		return nil, status.Errorf(codes.InvalidArgument, "channelID must be >= 1, got=%d", g)
+	}
+	if g := recv.RPCNo; g < 1 {
+		return nil, status.Errorf(codes.InvalidArgument, "rpcID must be >= 1, got=%d", g)
+	}
+	return recv, nil
+}
+
+func TestRequestIDHeader_sentOnEveryClientCall(t *testing.T) {
+	interceptorTracker := newInterceptorTracker()
+
+	clientOpts := []option.ClientOption{
+		option.WithGRPCDialOption(grpc.WithUnaryInterceptor(interceptorTracker.unaryClientInterceptor)),
+		option.WithGRPCDialOption(grpc.WithStreamInterceptor(interceptorTracker.streamClientInterceptor)),
+	}
+	clientConfig := ClientConfig{
+		SessionPoolConfig: SessionPoolConfig{
+			MinOpened:     2,
+			MaxOpened:     10,
+			WriteSessions: 0.2,
+			incStep:       2,
+		},
+	}
+	server, sc, tearDown := setupMockedTestServerWithConfigAndClientOptions(t, clientConfig, clientOpts)
+	t.Cleanup(tearDown)
+	defer sc.Close()
+
+	sqlSELECT1 := "SELECT 1"
+	resultSet := &sppb.ResultSet{
+		Rows: []*structpb.ListValue{
+			{Values: []*structpb.Value{
+				{Kind: &structpb.Value_NumberValue{NumberValue: 1}},
+			}},
+		},
+		Metadata: &sppb.ResultSetMetadata{
+			RowType: &sppb.StructType{
+				Fields: []*sppb.StructType_Field{
+					{Name: "Int", Type: &sppb.Type{Code: sppb.TypeCode_INT64}},
+				},
+			},
+		},
+	}
+	result := &testutil.StatementResult{
+		Type:      testutil.StatementResultResultSet,
+		ResultSet: resultSet,
+	}
+	server.TestSpanner.PutStatementResult(sqlSELECT1, result)
+
+	txn := sc.ReadOnlyTransaction()
+	defer txn.Close()
+
+	ctx := context.Background()
+	stmt := NewStatement(sqlSELECT1)
+	rowIter := txn.Query(ctx, stmt)
+	defer rowIter.Stop()
+	for {
+		rows, err := rowIter.Next()
+		if err == iterator.Done {
+			break
+		}
+		if err != nil {
+			t.Fatal(err)
+		}
+		_ = rows
+	}
+
+	if interceptorTracker.unaryCallCount() < 1 {
+		t.Error("unaryClientCall was not invoked")
+	}
+	if interceptorTracker.streamCallCount() < 1 {
+		t.Error("streamClientCall was not invoked")
+	}
+
+	if err := interceptorTracker.validateRequestIDsMonotonicity(); err != nil {
+		t.Fatal(err)
+	}
+}
+
+type interceptorTracker struct {
+	nUnaryClientCalls  *atomic.Uint64
+	nStreamClientCalls *atomic.Uint64
+
+	mu sync.Mutex // mu protects the fields below.
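+	// Parsed x-goog-spanner-request-id segments, recorded in the order in which
+	// the interceptors observed each call.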
+	unaryClientRequestIDSegments  []*requestIDSegments
+	streamClientRequestIDSegments []*requestIDSegments
+}
+
+func (it *interceptorTracker) unaryCallCount() uint64 {
+	return it.nUnaryClientCalls.Load()
+}
+
+func (it *interceptorTracker) streamCallCount() uint64 {
+	return it.nStreamClientCalls.Load()
+}
+
+func (it *interceptorTracker) validateRequestIDsMonotonicity() error {
+	if err := ensureMonotonicityOfRequestIDs(it.unaryClientRequestIDSegments); err != nil {
+		return fmt.Errorf("unaryClientRequestIDs: %w", err)
+	}
+	if err := ensureMonotonicityOfRequestIDs(it.streamClientRequestIDSegments); err != nil {
+		return fmt.Errorf("streamClientRequestIDs: %w", err)
+	}
+	return nil
+}
+
+func ensureMonotonicityOfRequestIDs(requestIDs []*requestIDSegments) error {
+	// Compare the current against the previous requestID, which requires at least 2 elements.
+	for i := 1; i < len(requestIDs); i++ {
+		rCurr, rPrev := requestIDs[i], requestIDs[i-1]
+		if rPrev.ProcessID != rCurr.ProcessID {
+			return fmt.Errorf("processID mismatch: #[%d].ProcessID=%d, #[%d].ProcessID=%d", i, rCurr.ProcessID, i-1, rPrev.ProcessID)
+		}
+		if rPrev.ClientID == rCurr.ClientID {
+			// In the case of retries, we might have the same request
+			// number, but the rpc id must be monotonically increasing.
+			if rPrev.RequestNo == rCurr.RequestNo {
+				if rPrev.RPCNo >= rCurr.RPCNo {
+					return fmt.Errorf("sameClientID but rpcNo mismatch: #[%d].RPCNo=%d >= #[%d].RPCNo=%d", i-1, rPrev.RPCNo, i, rCurr.RPCNo)
+				}
+			} else if rPrev.RequestNo > rCurr.RequestNo {
+				return fmt.Errorf("sameClientID but requestNo mismatch: #[%d].reqNo=%d < #[%d].reqNo=%d", i, rCurr.RequestNo, i-1, rPrev.RequestNo)
+			}
+		} else if rPrev.ClientID > rCurr.ClientID {
+			// Client ids are handed out monotonically, so a request recorded
+			// later must not carry a smaller clientID than its predecessor.
+			return fmt.Errorf("clientID inconsistency: previous request #[%d].ClientID=%d > #[%d].ClientID=%d", i-1, rPrev.ClientID, i, rCurr.ClientID)
+		}
+	}
+
+	// All checks passed so good to go.
+	return nil
+}
+
+func TestRequestIDHeader_ensureMonotonicityOfRequestIDs(t *testing.T) {
+	tests := []struct {
+		name    string
+		in      []*requestIDSegments
+		wantErr string
+	}{
+		{name: "no values", wantErr: ""},
+		{name: "1 value", in: []*requestIDSegments{{ProcessID: 123}}, wantErr: ""},
+		{name: "Different processIDs", in: []*requestIDSegments{{ProcessID: 1}, {ProcessID: 2}}, wantErr: "processID mismatch"},
+		{
+			name: "Different clientID, prev has higher value",
+			in: []*requestIDSegments{
+				{ProcessID: 1, ClientID: 2},
+				{ProcessID: 1, ClientID: 1},
+			},
+			wantErr: "clientID inconsistency: previous request",
+		},
+		{
+			name: "Different clientID, prev has lower value",
+			in: []*requestIDSegments{
+				{ProcessID: 1, ClientID: 1},
+				{ProcessID: 1, ClientID: 2},
+			},
+			wantErr: "",
+		},
+		{
+			name: "Same clientID, prev has same RPCNo",
+			in: []*requestIDSegments{
+				{ProcessID: 1, ClientID: 1, RPCNo: 1},
+				{ProcessID: 1, ClientID: 1, RPCNo: 1},
+			},
+			wantErr: "sameClientID but rpcNo mismatch",
+		},
+		{
+			name: "Same clientID, prev has higher RPCNo",
+			in: []*requestIDSegments{
+				{ProcessID: 1, ClientID: 1, RPCNo: 2},
+				{ProcessID: 1, ClientID: 1, RPCNo: 1},
+			},
+			wantErr: "sameClientID but rpcNo mismatch",
+		},
+		{
+			name: "Same clientID, prev has lower RPCNo",
+			in: []*requestIDSegments{
+				{ProcessID: 1, ClientID: 1, RPCNo: 1},
+				{ProcessID: 1, ClientID: 1, RPCNo: 2},
+			},
+			wantErr: "",
+		},
+		{
+			name: "Different clientID, prev has higher clientID",
+			in: []*requestIDSegments{
+				{ProcessID: 1, ClientID: 2, RPCNo: 1},
+				{ProcessID: 1, ClientID: 1, RPCNo: 1},
+			},
+			wantErr: "clientID inconsistency: previous request",
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			err := ensureMonotonicityOfRequestIDs(tt.in)
+			if tt.wantErr != "" {
+				if err == nil {
+					t.Fatal("Expected a non-nil error")
+				}
+				if !strings.Contains(err.Error(), tt.wantErr) {
+					t.Fatalf("Error mismatch\n\t%q\ncould not be found in\n\t%q", tt.wantErr, err)
+				}
+				return
+			}
+
+			if err != nil {
+				t.Fatalf("Unexpected error: %v", err)
+			}
+		})
+	}
+}
+
+func (it *interceptorTracker) unaryClientInterceptor(ctx context.Context, method string, req, reply any, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
+	it.nUnaryClientCalls.Add(1)
+	reqID, err := checkForMissingSpannerRequestIDHeader(opts)
+	if err != nil {
+		return err
+	}
+
+	it.mu.Lock()
+	it.unaryClientRequestIDSegments = append(it.unaryClientRequestIDSegments, reqID)
+	it.mu.Unlock()
+
+	// Otherwise proceed with the call.
+	return invoker(ctx, method, req, reply, cc, opts...)
+}
+
+func (it *interceptorTracker) streamClientInterceptor(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) {
+	it.nStreamClientCalls.Add(1)
+	reqID, err := checkForMissingSpannerRequestIDHeader(opts)
+	if err != nil {
+		return nil, err
+	}
+
+	it.mu.Lock()
+	it.streamClientRequestIDSegments = append(it.streamClientRequestIDSegments, reqID)
+	it.mu.Unlock()
+
+	// Otherwise proceed with the call.
+	return streamer(ctx, desc, cc, method, opts...)
+} + +func newInterceptorTracker() *interceptorTracker { + return &interceptorTracker{ + nUnaryClientCalls: new(atomic.Uint64), + nStreamClientCalls: new(atomic.Uint64), + } +} + +func TestRequestIDHeader_onRetriesWithFailedTransactionCommit(t *testing.T) { + interceptorTracker := newInterceptorTracker() + clientOpts := []option.ClientOption{ + option.WithGRPCDialOption(grpc.WithUnaryInterceptor(interceptorTracker.unaryClientInterceptor)), + option.WithGRPCDialOption(grpc.WithStreamInterceptor(interceptorTracker.streamClientInterceptor)), + } + clientConfig := ClientConfig{ + SessionPoolConfig: SessionPoolConfig{ + MinOpened: 2, + MaxOpened: 10, + WriteSessions: 0.2, + incStep: 2, + }, + } + server, sc, tearDown := setupMockedTestServerWithConfigAndClientOptions(t, clientConfig, clientOpts) + t.Cleanup(tearDown) + defer sc.Close() + + // First commit will fail, and the retry will begin a new transaction. + server.TestSpanner.PutExecutionTime(testutil.MethodCommitTransaction, + testutil.SimulatedExecutionTime{ + Errors: []error{newAbortedErrorWithMinimalRetryDelay()}, + }) + + ctx := context.Background() + ms := []*Mutation{ + Insert("Accounts", []string{"AccountId"}, []any{int64(1)}), + } + + if _, err := sc.Apply(ctx, ms); err != nil { + t.Fatalf("ReadWriteTransaction retry on abort, got %v, want nil.", err) + } + + if _, err := shouldHaveReceived(server.TestSpanner, []any{ + &sppb.BatchCreateSessionsRequest{}, + &sppb.BeginTransactionRequest{}, + &sppb.CommitRequest{}, // First commit fails. + &sppb.BeginTransactionRequest{}, + &sppb.CommitRequest{}, // Second commit succeeds. + }); err != nil { + t.Fatal(err) + } + + if g, w := interceptorTracker.unaryCallCount(), uint64(5); g != w { + t.Errorf("unaryClientCall is incorrect; got=%d want=%d", g, w) + } + if g := interceptorTracker.streamCallCount(); g > 0 { + t.Errorf("streamClientCall was unexpectedly invoked %d times", g) + } + + if err := interceptorTracker.validateRequestIDsMonotonicity(); err != nil { + t.Fatal(err) + } +} + +// Tests that SessionNotFound errors are retried. +func TestRequestIDHeader_retriesOnSessionNotFound(t *testing.T) { + t.Parallel() + ctx := context.Background() + interceptorTracker := newInterceptorTracker() + clientOpts := []option.ClientOption{ + option.WithGRPCDialOption(grpc.WithUnaryInterceptor(interceptorTracker.unaryClientInterceptor)), + option.WithGRPCDialOption(grpc.WithStreamInterceptor(interceptorTracker.streamClientInterceptor)), + } + clientConfig := ClientConfig{ + SessionPoolConfig: SessionPoolConfig{ + MinOpened: 2, + MaxOpened: 10, + WriteSessions: 0.2, + incStep: 2, + }, + } + server, sc, tearDown := setupMockedTestServerWithConfigAndClientOptions(t, clientConfig, clientOpts) + t.Cleanup(tearDown) + defer sc.Close() + + serverErr := newSessionNotFoundError("projects/p/instances/i/databases/d/sessions/s") + server.TestSpanner.PutExecutionTime(testutil.MethodBeginTransaction, + testutil.SimulatedExecutionTime{ + Errors: []error{serverErr, serverErr, serverErr}, + }) + server.TestSpanner.PutExecutionTime(testutil.MethodCommitTransaction, + testutil.SimulatedExecutionTime{ + Errors: []error{serverErr}, + }) + + txn := sc.ReadOnlyTransaction() + defer txn.Close() + + var wantErr error + if _, _, got := txn.acquire(ctx); !testEqual(wantErr, got) { + t.Fatalf("Expect acquire to succeed, got %v, want %v.", got, wantErr) + } + + // The server error should lead to a retry of the BeginTransaction call and + // a valid session handle to be returned that will be used by the following + // requests. 
Note that calling txn.Query(...) does not actually send the + // query to the (mock) server. That is done at the first call to + // RowIterator.Next. The following statement only verifies that the + // transaction is in a valid state and received a valid session handle. + if got := txn.Query(ctx, NewStatement("SELECT 1")); !testEqual(wantErr, got.err) { + t.Fatalf("Expect Query to succeed, got %v, want %v.", got.err, wantErr) + } + + if got := txn.Read(ctx, "Users", KeySets(Key{"alice"}, Key{"bob"}), []string{"name", "email"}); !testEqual(wantErr, got.err) { + t.Fatalf("Expect Read to succeed, got %v, want %v.", got.err, wantErr) + } + + wantErr = ToSpannerError(newSessionNotFoundError("projects/p/instances/i/databases/d/sessions/s")) + ms := []*Mutation{ + Insert("Accounts", []string{"AccountId", "Nickname", "Balance"}, []any{int64(1), "Foo", int64(50)}), + Insert("Accounts", []string{"AccountId", "Nickname", "Balance"}, []any{int64(2), "Bar", int64(1)}), + } + _, got := sc.Apply(ctx, ms, ApplyAtLeastOnce()) + if !testEqual(wantErr, got) { + t.Fatalf("Expect Apply to fail\nGot: %v\nWant: %v\n", got, wantErr) + } + + if g, w := interceptorTracker.unaryCallCount(), uint64(8); g != w { + t.Errorf("unaryClientCall is incorrect; got=%d want=%d", g, w) + } + if g := interceptorTracker.streamCallCount(); g > 0 { + t.Errorf("streamClientCall was unexpectedly invoked %d times", g) + } + + if err := interceptorTracker.validateRequestIDsMonotonicity(); err != nil { + t.Fatal(err) + } +} + +func TestRequestIDHeader_BatchDMLWithMultipleDML(t *testing.T) { + t.Parallel() + + interceptorTracker := newInterceptorTracker() + clientOpts := []option.ClientOption{ + option.WithGRPCDialOption(grpc.WithUnaryInterceptor(interceptorTracker.unaryClientInterceptor)), + option.WithGRPCDialOption(grpc.WithStreamInterceptor(interceptorTracker.streamClientInterceptor)), + } + clientConfig := ClientConfig{ + SessionPoolConfig: SessionPoolConfig{ + MinOpened: 2, + MaxOpened: 10, + WriteSessions: 0.2, + incStep: 2, + }, + } + + ctx := context.Background() + server, sc, tearDown := setupMockedTestServerWithConfigAndClientOptions(t, clientConfig, clientOpts) + t.Cleanup(tearDown) + defer sc.Close() + + updateBarSetFoo := testutil.UpdateBarSetFoo + _, err := sc.ReadWriteTransaction(ctx, func(ctx context.Context, tx *ReadWriteTransaction) (err error) { + if _, err = tx.Update(ctx, Statement{SQL: updateBarSetFoo}); err != nil { + return err + } + if _, err = tx.BatchUpdate(ctx, []Statement{{SQL: updateBarSetFoo}, {SQL: updateBarSetFoo}}); err != nil { + return err + } + if _, err = tx.Update(ctx, Statement{SQL: updateBarSetFoo}); err != nil { + return err + } + _, err = tx.BatchUpdate(ctx, []Statement{{SQL: updateBarSetFoo}}) + return err + }) + if err != nil { + t.Fatal(err) + } + + gotReqs, err := shouldHaveReceived(server.TestSpanner, []any{ + &sppb.BatchCreateSessionsRequest{}, + &sppb.ExecuteSqlRequest{}, + &sppb.ExecuteBatchDmlRequest{}, + &sppb.ExecuteSqlRequest{}, + &sppb.ExecuteBatchDmlRequest{}, + &sppb.CommitRequest{}, + }) + if err != nil { + t.Fatal(err) + } + muxCreateBuffer := 0 + if isMultiplexEnabled { + muxCreateBuffer = 1 + } + if got, want := gotReqs[1+muxCreateBuffer].(*sppb.ExecuteSqlRequest).Seqno, int64(1); got != want { + t.Errorf("got %d, want %d", got, want) + } + if got, want := gotReqs[2+muxCreateBuffer].(*sppb.ExecuteBatchDmlRequest).Seqno, int64(2); got != want { + t.Errorf("got %d, want %d", got, want) + } + if got, want := gotReqs[3+muxCreateBuffer].(*sppb.ExecuteSqlRequest).Seqno, int64(3); 
got != want { + t.Errorf("got %d, want %d", got, want) + } + if got, want := gotReqs[4+muxCreateBuffer].(*sppb.ExecuteBatchDmlRequest).Seqno, int64(4); got != want { + t.Errorf("got %d, want %d", got, want) + } + + if g, w := interceptorTracker.unaryCallCount(), uint64(6); g != w { + t.Errorf("unaryClientCall is incorrect; got=%d want=%d", g, w) + } + if g := interceptorTracker.streamCallCount(); g > 0 { + t.Errorf("streamClientCall was unexpectedly invoked %d times", g) + } + + if err := interceptorTracker.validateRequestIDsMonotonicity(); err != nil { + t.Fatal(err) + } +} + +func TestRequestIDHeader_clientBatchWrite(t *testing.T) { + t.Parallel() + + interceptorTracker := newInterceptorTracker() + clientOpts := []option.ClientOption{ + option.WithGRPCDialOption(grpc.WithUnaryInterceptor(interceptorTracker.unaryClientInterceptor)), + option.WithGRPCDialOption(grpc.WithStreamInterceptor(interceptorTracker.streamClientInterceptor)), + } + clientConfig := ClientConfig{ + SessionPoolConfig: SessionPoolConfig{ + MinOpened: 2, + MaxOpened: 10, + WriteSessions: 0.2, + incStep: 2, + }, + } + + server, sc, tearDown := setupMockedTestServerWithConfigAndClientOptions(t, clientConfig, clientOpts) + t.Cleanup(tearDown) + defer sc.Close() + + mutationGroups := []*MutationGroup{ + {[]*Mutation{ + {opInsertOrUpdate, "t_test", nil, []string{"key", "val"}, []any{"foo1", 1}}, + }}, + } + iter := sc.BatchWrite(context.Background(), mutationGroups) + responseCount := 0 + doFunc := func(r *sppb.BatchWriteResponse) error { + responseCount++ + return nil + } + if err := iter.Do(doFunc); err != nil { + t.Fatal(err) + } + if responseCount != len(mutationGroups) { + t.Fatalf("Response count mismatch.\nGot: %v\nWant:%v", responseCount, len(mutationGroups)) + } + requests := drainRequestsFromServer(server.TestSpanner) + if err := compareRequests([]any{ + &sppb.BatchCreateSessionsRequest{}, + &sppb.BatchWriteRequest{}, + }, requests); err != nil { + t.Fatal(err) + } + + if g, w := interceptorTracker.unaryCallCount(), uint64(1); g != w { + t.Errorf("unaryClientCall is incorrect; got=%d want=%d", g, w) + } + if g, w := interceptorTracker.streamCallCount(), uint64(1); g != w { + t.Errorf("streamClientCall is incorrect; got=%d want=%d", g, w) + } + + if err := interceptorTracker.validateRequestIDsMonotonicity(); err != nil { + t.Fatal(err) + } +} + +func TestRequestIDHeader_ClientBatchWriteWithSessionNotFound(t *testing.T) { + t.Parallel() + + interceptorTracker := newInterceptorTracker() + clientOpts := []option.ClientOption{ + option.WithGRPCDialOption(grpc.WithUnaryInterceptor(interceptorTracker.unaryClientInterceptor)), + option.WithGRPCDialOption(grpc.WithStreamInterceptor(interceptorTracker.streamClientInterceptor)), + } + clientConfig := ClientConfig{ + SessionPoolConfig: SessionPoolConfig{ + MinOpened: 2, + MaxOpened: 10, + WriteSessions: 0.2, + incStep: 2, + }, + } + + server, sc, tearDown := setupMockedTestServerWithConfigAndClientOptions(t, clientConfig, clientOpts) + t.Cleanup(tearDown) + defer sc.Close() + + server.TestSpanner.PutExecutionTime( + testutil.MethodBatchWrite, + testutil.SimulatedExecutionTime{Errors: []error{newSessionNotFoundError("projects/p/instances/i/databases/d/sessions/s")}}, + ) + mutationGroups := []*MutationGroup{ + {[]*Mutation{ + {opInsertOrUpdate, "t_test", nil, []string{"key", "val"}, []any{"foo1", 1}}, + }}, + } + iter := sc.BatchWrite(context.Background(), mutationGroups) + responseCount := 0 + doFunc := func(r *sppb.BatchWriteResponse) error { + responseCount++ + return nil + } 
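+	// The injected SessionNotFound error aborts the first BatchWrite stream;
+	// iter.Do retries it on a fresh session, so doFunc still observes every
+	// response exactly once (hence the expectation of 2 stream calls below).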
+ if err := iter.Do(doFunc); err != nil { + t.Fatal(err) + } + if responseCount != len(mutationGroups) { + t.Fatalf("Response count mismatch.\nGot: %v\nWant:%v", responseCount, len(mutationGroups)) + } + + requests := drainRequestsFromServer(server.TestSpanner) + if err := compareRequests([]any{ + &sppb.BatchCreateSessionsRequest{}, + &sppb.BatchWriteRequest{}, + &sppb.BatchWriteRequest{}, + }, requests); err != nil { + t.Fatal(err) + } + + if g, w := interceptorTracker.unaryCallCount(), uint64(1); g != w { + t.Errorf("unaryClientCall is incorrect; got=%d want=%d", g, w) + } + + // We had a retry for BatchWrite after the first SessionNotFound error, hence expecting 2 calls. + if g, w := interceptorTracker.streamCallCount(), uint64(2); g != w { + t.Errorf("streamClientCall is incorrect; got=%d want=%d", g, w) + } + + if err := interceptorTracker.validateRequestIDsMonotonicity(); err != nil { + t.Fatal(err) + } +} + +func TestRequestIDHeader_ClientBatchWriteWithError(t *testing.T) { + t.Parallel() + + interceptorTracker := newInterceptorTracker() + clientOpts := []option.ClientOption{ + option.WithGRPCDialOption(grpc.WithUnaryInterceptor(interceptorTracker.unaryClientInterceptor)), + option.WithGRPCDialOption(grpc.WithStreamInterceptor(interceptorTracker.streamClientInterceptor)), + } + clientConfig := ClientConfig{ + SessionPoolConfig: SessionPoolConfig{ + MinOpened: 2, + MaxOpened: 10, + WriteSessions: 0.2, + incStep: 2, + }, + } + + server, sc, tearDown := setupMockedTestServerWithConfigAndClientOptions(t, clientConfig, clientOpts) + t.Cleanup(tearDown) + defer sc.Close() + + injectedErr := status.Error(codes.InvalidArgument, "Invalid argument") + server.TestSpanner.PutExecutionTime( + testutil.MethodBatchWrite, + testutil.SimulatedExecutionTime{Errors: []error{injectedErr}}, + ) + mutationGroups := []*MutationGroup{ + {[]*Mutation{ + {opInsertOrUpdate, "t_test", nil, []string{"key", "val"}, []any{"foo1", 1}}, + }}, + } + iter := sc.BatchWrite(context.Background(), mutationGroups) + responseCount := 0 + doFunc := func(r *sppb.BatchWriteResponse) error { + responseCount++ + return nil + } + if err := iter.Do(doFunc); err != nil { + t.Fatal(err) + } + if responseCount != 0 { + t.Fatalf("Do unexpectedly called %d times", responseCount) + } + + if g, w := interceptorTracker.unaryCallCount(), uint64(1); g != w { + t.Errorf("unaryClientCall is incorrect; got=%d want=%d", g, w) + } + + // We had a straight-up failure after the first BatchWrite call so only 1 call. 
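+	// (Unlike the SessionNotFound case above, codes.InvalidArgument is not a
+	// retryable error, so no second BatchWrite attempt is made.)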
+ if g, w := interceptorTracker.streamCallCount(), uint64(1); g != w { + t.Errorf("streamClientCall is incorrect; got=%d want=%d", g, w) + } + + if err := interceptorTracker.validateRequestIDsMonotonicity(); err != nil { + t.Fatal(err) + } +} + +func TestRequestIDHeader_PartitionQueryWithoutError(t *testing.T) { + testRequestIDHeaderPartitionQuery(t, false) +} + +func TestRequestIDHeader_PartitionQueryWithError(t *testing.T) { + testRequestIDHeaderPartitionQuery(t, true) +} + +func testRequestIDHeaderPartitionQuery(t *testing.T, mustErrorOnPartitionQuery bool) { + t.Parallel() + + interceptorTracker := newInterceptorTracker() + clientOpts := []option.ClientOption{ + option.WithGRPCDialOption(grpc.WithUnaryInterceptor(interceptorTracker.unaryClientInterceptor)), + option.WithGRPCDialOption(grpc.WithStreamInterceptor(interceptorTracker.streamClientInterceptor)), + } + clientConfig := ClientConfig{ + SessionPoolConfig: SessionPoolConfig{ + MinOpened: 2, + MaxOpened: 10, + WriteSessions: 0.2, + incStep: 2, + }, + } + + server, sc, tearDown := setupMockedTestServerWithConfigAndClientOptions(t, clientConfig, clientOpts) + t.Cleanup(tearDown) + defer sc.Close() + + // The request will initially fail, and be retried. + server.TestSpanner.PutExecutionTime(testutil.MethodExecuteStreamingSql, + testutil.SimulatedExecutionTime{ + Errors: []error{newAbortedErrorWithMinimalRetryDelay()}, + }) + if mustErrorOnPartitionQuery { + server.TestSpanner.PutExecutionTime(testutil.MethodPartitionQuery, + testutil.SimulatedExecutionTime{ + Errors: []error{newAbortedErrorWithMinimalRetryDelay()}, + }) + } + + sqlFromSingers := "SELECT * FROM Singers" + resultSet := &sppb.ResultSet{ + Rows: []*structpb.ListValue{ + { + Values: []*structpb.Value{ + structpb.NewStructValue(&structpb.Struct{ + Fields: map[string]*structpb.Value{ + "SingerId": {Kind: &structpb.Value_NumberValue{NumberValue: 1}}, + "FirstName": {Kind: &structpb.Value_StringValue{StringValue: "Bruce"}}, + "LastName": {Kind: &structpb.Value_StringValue{StringValue: "Wayne"}}, + }, + }), + structpb.NewStructValue(&structpb.Struct{ + Fields: map[string]*structpb.Value{ + "SingerId": {Kind: &structpb.Value_NumberValue{NumberValue: 2}}, + "FirstName": {Kind: &structpb.Value_StringValue{StringValue: "Robin"}}, + "LastName": {Kind: &structpb.Value_StringValue{StringValue: "SideKick"}}, + }, + }), + structpb.NewStructValue(&structpb.Struct{ + Fields: map[string]*structpb.Value{ + "SingerId": {Kind: &structpb.Value_NumberValue{NumberValue: 3}}, + "FirstName": {Kind: &structpb.Value_StringValue{StringValue: "Gordon"}}, + "LastName": {Kind: &structpb.Value_StringValue{StringValue: "Commissioner"}}, + }, + }), + structpb.NewStructValue(&structpb.Struct{ + Fields: map[string]*structpb.Value{ + "SingerId": {Kind: &structpb.Value_NumberValue{NumberValue: 4}}, + "FirstName": {Kind: &structpb.Value_StringValue{StringValue: "Joker"}}, + "LastName": {Kind: &structpb.Value_StringValue{StringValue: "None"}}, + }, + }), + structpb.NewStructValue(&structpb.Struct{ + Fields: map[string]*structpb.Value{ + "SingerId": {Kind: &structpb.Value_NumberValue{NumberValue: 5}}, + "FirstName": {Kind: &structpb.Value_StringValue{StringValue: "Riddler"}}, + "LastName": {Kind: &structpb.Value_StringValue{StringValue: "None"}}, + }, + }), + }}, + }, + Metadata: &sppb.ResultSetMetadata{ + RowType: &sppb.StructType{ + Fields: []*sppb.StructType_Field{ + {Name: "SingerId", Type: &sppb.Type{Code: sppb.TypeCode_INT64}}, + {Name: "FirstName", Type: &sppb.Type{Code: sppb.TypeCode_STRING}}, + {Name: 
"LastName", Type: &sppb.Type{Code: sppb.TypeCode_STRING}}, + }, + }, + }, + } + result := &testutil.StatementResult{ + Type: testutil.StatementResultResultSet, + ResultSet: resultSet, + } + server.TestSpanner.PutStatementResult(sqlFromSingers, result) + + ctx := context.Background() + txn, err := sc.BatchReadOnlyTransaction(ctx, StrongRead()) + + if err != nil { + t.Fatal(err) + } + defer txn.Close() + + // Singer represents the elements in a row from the Singers table. + type Singer struct { + SingerID int64 + FirstName string + LastName string + SingerInfo []byte + } + stmt := Statement{SQL: "SELECT * FROM Singers;"} + partitions, err := txn.PartitionQuery(ctx, stmt, PartitionOptions{}) + + if mustErrorOnPartitionQuery { + // The methods invoked should be: ['/BatchCreateSessions', '/CreateSession', '/BeginTransaction', '/PartitionQuery'] + if g, w := interceptorTracker.unaryCallCount(), uint64(4); g != w { + t.Errorf("unaryClientCall is incorrect; got=%d want=%d", g, w) + } + + // We had a straight-up failure after the first BatchWrite call so only 1 call. + if g, w := interceptorTracker.streamCallCount(), uint64(0); g != w { + t.Errorf("streamClientCall is incorrect; got=%d want=%d", g, w) + } + + if err := interceptorTracker.validateRequestIDsMonotonicity(); err != nil { + t.Fatal(err) + } + return + } + + if err != nil { + t.Fatal(err) + } + + wg := new(sync.WaitGroup) + for i, p := range partitions { + wg.Add(1) + go func(i int, p *Partition) { + defer wg.Done() + iter := txn.Execute(ctx, p) + defer iter.Stop() + for { + row, err := iter.Next() + if err == iterator.Done { + break + } + var s Singer + if err := row.ToStruct(&s); err != nil { + _ = err + } + _ = s + } + }(i, p) + } + wg.Wait() + + // The methods invoked should be: ['/BatchCreateSessions', '/CreateSession', '/BeginTransaction', '/PartitionQuery'] + if g, w := interceptorTracker.unaryCallCount(), uint64(4); g != w { + t.Errorf("unaryClientCall is incorrect; got=%d want=%d", g, w) + } + + // We had a straight-up failure after the first BatchWrite call so only 1 call. 
+	if g, w := interceptorTracker.streamCallCount(), uint64(0); g != w {
+		t.Errorf("streamClientCall is incorrect; got=%d want=%d", g, w)
+	}
+
+	if err := interceptorTracker.validateRequestIDsMonotonicity(); err != nil {
+		t.Fatal(err)
+	}
+}
+
+func TestRequestIDHeader_ReadWriteTransactionUpdate(t *testing.T) {
+	t.Parallel()
+
+	interceptorTracker := newInterceptorTracker()
+	clientOpts := []option.ClientOption{
+		option.WithGRPCDialOption(grpc.WithUnaryInterceptor(interceptorTracker.unaryClientInterceptor)),
+		option.WithGRPCDialOption(grpc.WithStreamInterceptor(interceptorTracker.streamClientInterceptor)),
+	}
+	clientConfig := ClientConfig{
+		SessionPoolConfig: SessionPoolConfig{
+			MinOpened:     2,
+			MaxOpened:     10,
+			WriteSessions: 0.2,
+			incStep:       2,
+		},
+	}
+
+	server, sc, tearDown := setupMockedTestServerWithConfigAndClientOptions(t, clientConfig, clientOpts)
+	t.Cleanup(tearDown)
+	defer sc.Close()
+
+	ctx := context.Background()
+	updateSQL := testutil.UpdateBarSetFoo
+	_, err := sc.ReadWriteTransaction(ctx, func(ctx context.Context, tx *ReadWriteTransaction) (err error) {
+		if _, err = tx.Update(ctx, Statement{SQL: updateSQL}); err != nil {
+			return err
+		}
+		if _, err = tx.BatchUpdate(ctx, []Statement{{SQL: updateSQL}, {SQL: updateSQL}}); err != nil {
+			return err
+		}
+		if _, err = tx.Update(ctx, Statement{SQL: updateSQL}); err != nil {
+			return err
+		}
+		_, err = tx.BatchUpdate(ctx, []Statement{{SQL: updateSQL}})
+		return err
+	})
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	gotReqs, err := shouldHaveReceived(server.TestSpanner, []any{
+		&sppb.BatchCreateSessionsRequest{},
+		&sppb.ExecuteSqlRequest{},
+		&sppb.ExecuteBatchDmlRequest{},
+		&sppb.ExecuteSqlRequest{},
+		&sppb.ExecuteBatchDmlRequest{},
+		&sppb.CommitRequest{},
+	})
+	if err != nil {
+		t.Fatal(err)
+	}
+	muxCreateBuffer := 0
+	if isMultiplexEnabled {
+		muxCreateBuffer = 1
+	}
+	if got, want := gotReqs[1+muxCreateBuffer].(*sppb.ExecuteSqlRequest).Seqno, int64(1); got != want {
+		t.Errorf("got %d, want %d", got, want)
+	}
+	if got, want := gotReqs[2+muxCreateBuffer].(*sppb.ExecuteBatchDmlRequest).Seqno, int64(2); got != want {
+		t.Errorf("got %d, want %d", got, want)
+	}
+	if got, want := gotReqs[3+muxCreateBuffer].(*sppb.ExecuteSqlRequest).Seqno, int64(3); got != want {
+		t.Errorf("got %d, want %d", got, want)
+	}
+	if got, want := gotReqs[4+muxCreateBuffer].(*sppb.ExecuteBatchDmlRequest).Seqno, int64(4); got != want {
+		t.Errorf("got %d, want %d", got, want)
+	}
+
+	// The methods invoked should be: ['/BatchCreateSessions', '/ExecuteSql', '/ExecuteBatchDml', '/ExecuteSql', '/ExecuteBatchDml', '/Commit']
+	if g, w := interceptorTracker.unaryCallCount(), uint64(6); g != w {
+		t.Errorf("unaryClientCall is incorrect; got=%d want=%d", g, w)
+	}
+
+	// All the RPCs in this test are unary, hence no streaming calls.
+ if g, w := interceptorTracker.streamCallCount(), uint64(0); g != w { + t.Errorf("streamClientCall is incorrect; got=%d want=%d", g, w) + } + + if err := interceptorTracker.validateRequestIDsMonotonicity(); err != nil { + t.Fatal(err) + } +} + +func TestRequestIDHeader_ReadWriteTransactionBatchUpdateWithOptions(t *testing.T) { + t.Parallel() + + interceptorTracker := newInterceptorTracker() + clientOpts := []option.ClientOption{ + option.WithGRPCDialOption(grpc.WithUnaryInterceptor(interceptorTracker.unaryClientInterceptor)), + option.WithGRPCDialOption(grpc.WithStreamInterceptor(interceptorTracker.streamClientInterceptor)), + } + clientConfig := ClientConfig{ + SessionPoolConfig: SessionPoolConfig{ + MinOpened: 2, + MaxOpened: 10, + WriteSessions: 0.2, + incStep: 2, + }, + } + + _, sc, tearDown := setupMockedTestServerWithConfigAndClientOptions(t, clientConfig, clientOpts) + t.Cleanup(tearDown) + defer sc.Close() + + ctx := context.Background() + selectSQL := testutil.SelectSingerIDAlbumIDAlbumTitleFromAlbums + updateSQL := testutil.UpdateBarSetFoo + _, err := sc.ReadWriteTransaction(ctx, func(ctx context.Context, tx *ReadWriteTransaction) (err error) { + iter := tx.QueryWithOptions(ctx, NewStatement(selectSQL), QueryOptions{}) + iter.Next() + iter.Stop() + + qo := QueryOptions{} + iter = tx.ReadWithOptions(ctx, "FOO", AllKeys(), []string{"BAR"}, &ReadOptions{Priority: qo.Priority}) + iter.Next() + iter.Stop() + + tx.UpdateWithOptions(ctx, NewStatement(updateSQL), qo) + tx.BatchUpdateWithOptions(ctx, []Statement{ + NewStatement(updateSQL), + }, qo) + return nil + }) + if err != nil { + t.Fatal(err) + } + + // The methods invoked should be: ['/BatchCreateSessions', '/ExecuteSql', '/ExecuteBatchDml', '/Commit'] + if g, w := interceptorTracker.unaryCallCount(), uint64(4); g != w { + t.Errorf("unaryClientCall is incorrect; got=%d want=%d", g, w) + } + + // The methods invoked should be: ['/ExecuteStreamingSql', '/StreamingRead'] + if g, w := interceptorTracker.streamCallCount(), uint64(2); g != w { + t.Errorf("streamClientCall is incorrect; got=%d want=%d", g, w) + } + + if err := interceptorTracker.validateRequestIDsMonotonicity(); err != nil { + t.Fatal(err) + } +} diff --git a/spanner/sessionclient.go b/spanner/sessionclient.go index 7468f21bc722..221d207c8965 100644 --- a/spanner/sessionclient.go +++ b/spanner/sessionclient.go @@ -274,6 +274,8 @@ func (sc *sessionClient) executeBatchCreateSessions(client spannerClient, create break } var mdForGFELatency metadata.MD + // Each invocation of client.BatchCreateSessions is not automatically retried + // hence we don't need to pull out the spanner request-id and increment the RPC number. response, err := client.BatchCreateSessions(contextWithOutgoingMetadata(ctx, sc.md, sc.disableRouteToLeader), &sppb.BatchCreateSessionsRequest{ SessionCount: remainingCreateCount, Database: sc.database, @@ -342,6 +344,8 @@ func (sc *sessionClient) executeCreateMultiplexedSession(ctx context.Context, cl return } var mdForGFELatency metadata.MD + // Each invocation of executeCreateMultiplexedSession is not automatically retried + // hence we don't need to pull out the spanner request-id and increment the RPC number. response, err := client.CreateSession(contextWithOutgoingMetadata(ctx, sc.md, sc.disableRouteToLeader), &sppb.CreateSessionRequest{ Database: sc.database, // Multiplexed sessions do not support labels. 
diff --git a/spanner/transaction.go b/spanner/transaction.go
index 9e3e107d1065..f1fc2fe349f4 100644
--- a/spanner/transaction.go
+++ b/spanner/transaction.go
@@ -785,15 +785,26 @@ func (t *ReadOnlyTransaction) begin(ctx context.Context) error {
 		}
 	}()
 	// Retry the BeginTransaction call if a 'Session not found' is returned.
+	nRPCs := uint64(0)
 	for {
+		nRPCs++
 		sh, err = t.sp.takeMultiplexed(ctx)
 		if err != nil {
 			return err
 		}
+
 		t.setSessionEligibilityForLongRunning(sh)
 		sh.updateLastUseTime()
 		var md metadata.MD
-		res, err = sh.getClient().BeginTransaction(contextWithOutgoingMetadata(ctx, sh.getMetadata(), t.disableRouteToLeader), &sppb.BeginTransactionRequest{
+
+		client := sh.getClient()
+		// Record the attempt number as the rpcID for this call.
+		gcl, ok := client.(*grpcSpannerClient)
+		if ok {
+			gcl.setRPCID(nRPCs)
+		}
+
+		res, err = client.BeginTransaction(contextWithOutgoingMetadata(ctx, sh.getMetadata(), t.disableRouteToLeader), &sppb.BeginTransactionRequest{
 			Session: sh.getID(),
 			Options: &sppb.TransactionOptions{
 				Mode: &sppb.TransactionOptions_ReadOnly_{
@@ -801,6 +812,9 @@ func (t *ReadOnlyTransaction) begin(ctx context.Context) error {
 				},
 			},
 		}, gax.WithGRPCOptions(grpc.Header(&md)))
+		if ok {
+			gcl.setOrResetRPCID()
+		}
 
 		if getGFELatencyMetricsFlag() && md != nil && t.ct != nil {
 			if err := createContextAndCaptureGFELatencyMetrics(ctx, t.ct, md, "begin_BeginTransaction"); err != nil {
@@ -1201,7 +1215,14 @@ func (t *ReadWriteTransaction) update(ctx context.Context, stmt Statement, opts
 	sh.updateLastUseTime()
 	var md metadata.MD
-	resultSet, err := sh.getClient().ExecuteSql(contextWithOutgoingMetadata(ctx, sh.getMetadata(), t.disableRouteToLeader), req, gax.WithGRPCOptions(grpc.Header(&md)))
+
+	client := sh.getClient()
+	// ReadWriteTransaction.update is retried manually by its callers rather than
+	// automatically, hence we do not need to extract and increment the
+	// request-id's RPC count.
+	resultSet, err := client.ExecuteSql(contextWithOutgoingMetadata(ctx, sh.getMetadata(), t.disableRouteToLeader), req, gax.WithGRPCOptions(grpc.Header(&md)))
+	if gcl, ok := client.(*grpcSpannerClient); ok {
+		gcl.setOrResetRPCID()
+	}
 
 	if getGFELatencyMetricsFlag() && md != nil && t.ct != nil {
 		if err := createContextAndCaptureGFELatencyMetrics(ctx, t.ct, md, "update"); err != nil {
@@ -1304,13 +1325,20 @@ func (t *ReadWriteTransaction) batchUpdateWithOptions(ctx context.Context, stmts
 	sh.updateLastUseTime()
 	var md metadata.MD
-	resp, err := sh.getClient().ExecuteBatchDml(contextWithOutgoingMetadata(ctx, sh.getMetadata(), t.disableRouteToLeader), &sppb.ExecuteBatchDmlRequest{
+
+	// ReadWriteTransaction.batchUpdateWithOptions is retried manually by its
+	// callers rather than automatically, hence we do not need to extract and
+	// increment the request-id.
+	client := sh.getClient()
+	resp, err := client.ExecuteBatchDml(contextWithOutgoingMetadata(ctx, sh.getMetadata(), t.disableRouteToLeader), &sppb.ExecuteBatchDmlRequest{
 		Session:        sh.getID(),
 		Transaction:    ts,
 		Statements:     sppbStmts,
 		Seqno:          atomic.AddInt64(&t.sequenceNumber, 1),
 		RequestOptions: createRequestOptions(opts.Priority, opts.RequestTag, t.txOpts.TransactionTag),
 	}, gax.WithGRPCOptions(grpc.Header(&md)))
+	if gcl, ok := client.(*grpcSpannerClient); ok {
+		gcl.setOrResetRPCID()
+	}
 
 	if getGFELatencyMetricsFlag() && md != nil && t.ct != nil {
 		if err := createContextAndCaptureGFELatencyMetrics(ctx, t.ct, md, "batchUpdateWithOptions"); err != nil {
@@ -1348,7 +1376,7 @@ func (t *ReadWriteTransaction) batchUpdateWithOptions(ctx context.Context, stmts
 		return counts, errInlineBeginTransactionFailed()
 	}
 	if resp.Status != nil && resp.Status.Code != 0 {
-		return counts, spannerErrorf(codes.Code(uint32(resp.Status.Code)), resp.Status.Message)
+		return counts, spannerError(codes.Code(uint32(resp.Status.Code)), resp.Status.Message)
 	}
 	return counts, nil
 }
@@ -1499,6 +1527,9 @@ func beginTransaction(ctx context.Context, sid string, client spannerClient, opt
 			ExcludeTxnFromChangeStreams: opts.ExcludeTxnFromChangeStreams,
 		},
 	})
+	if gcl, ok := client.(*grpcSpannerClient); ok {
+		gcl.setOrResetRPCID()
+	}
 	if err != nil {
 		return nil, err
 	}
@@ -1647,6 +1678,8 @@ func (t *ReadWriteTransaction) commit(ctx context.Context, options CommitOptions
 	if options.MaxCommitDelay != nil {
 		maxCommitDelay = durationpb.New(*(options.MaxCommitDelay))
 	}
+	// .commit is not automatically retried, hence there is no need to
+	// extract the spanner requestID to increment the retry count.
 	res, e := client.Commit(contextWithOutgoingMetadata(ctx, t.sh.getMetadata(), t.disableRouteToLeader), &sppb.CommitRequest{
 		Session: sid,
 		Transaction: &sppb.CommitRequest_TransactionId{
@@ -1657,6 +1690,10 @@ func (t *ReadWriteTransaction) commit(ctx context.Context, options CommitOptions
 		ReturnCommitStats: options.ReturnCommitStats,
 		MaxCommitDelay:    maxCommitDelay,
 	}, gax.WithGRPCOptions(grpc.Header(&md)))
+	if gcl, ok := client.(*grpcSpannerClient); ok {
+		gcl.setOrResetRPCID()
+	}
+
 	if getGFELatencyMetricsFlag() && md != nil && t.ct != nil {
 		if err := createContextAndCaptureGFELatencyMetrics(ctx, t.ct, md, "commit"); err != nil {
 			trace.TracePrintf(ctx, nil, "Error in recording GFE Latency. Try disabling and rerunning. Error: %v", err)
@@ -1702,6 +1739,9 @@ func (t *ReadWriteTransaction) rollback(ctx context.Context) {
 		Session:       sid,
 		TransactionId: t.tx,
 	})
+	if gcl, ok := client.(*grpcSpannerClient); ok {
+		gcl.setOrResetRPCID()
+	}
 	if isSessionNotFoundError(err) {
 		t.sh.destroy()
 	}
@@ -1903,6 +1943,7 @@ func (t *writeOnlyTransaction) applyAtLeastOnce(ctx context.Context, ms ...*Muta
 	retryer := onCodes(DefaultRetryBackoff, codes.Aborted, codes.Internal)
 	// Apply the mutation and retry if the commit is aborted.
 	applyMutationWithRetry := func(ctx context.Context) error {
+		nRPCs := uint64(0)
 		for {
 			if sh == nil || sh.getID() == "" || sh.getClient() == nil {
 				// No usable session for doing the commit, take one from pool.
@@ -1913,8 +1954,15 @@ func (t *writeOnlyTransaction) applyAtLeastOnce(ctx context.Context, ms ...*Muta
 					return ToSpannerError(err)
 				}
 			}
+			nRPCs++
 			sh.updateLastUseTime()
-			res, err := sh.getClient().Commit(contextWithOutgoingMetadata(ctx, sh.getMetadata(), t.disableRouteToLeader), &sppb.CommitRequest{
+			client := sh.getClient()
+			// Record the attempt number as the rpcID for this call.
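+			// On each Aborted/Internal retry this loop runs again, so nRPCs (and
+			// hence the trailing rpc field of the request-id header) increments
+			// by one per attempt.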
+ gcl, ok := client.(*grpcSpannerClient) + if ok { + gcl.setRPCID(nRPCs) + } + res, err := client.Commit(contextWithOutgoingMetadata(ctx, sh.getMetadata(), t.disableRouteToLeader), &sppb.CommitRequest{ Session: sh.getID(), Transaction: &sppb.CommitRequest_SingleUseTransaction{ SingleUseTransaction: &sppb.TransactionOptions{ @@ -1928,6 +1976,9 @@ func (t *writeOnlyTransaction) applyAtLeastOnce(ctx context.Context, ms ...*Muta RequestOptions: createRequestOptions(t.commitPriority, "", t.transactionTag), MaxCommitDelay: maxCommitDelay, }) + if ok { + gcl.setOrResetRPCID() + } if err != nil && !isAbortedErr(err) { // should not be the case with multiplexed sessions if isSessionNotFoundError(err) {