feat(spanner): implement generation and propagation of "x-goog-spanner-request-id" Header (#11048)

* spanner: implement generation and propagation of "x-goog-spanner-request-id" Header

In tandem with the specification:

    https://orijtech.notion.site/x-goog-spanner-request-id-always-on-gRPC-header-to-aid-in-quick-debugging-of-errors-14aba6bc91348091a58fca7a505c9827

this change adds the "x-goog-spanner-request-id" header to every unary and
streaming call, in the form:

	<version>.<processId>.<clientId>.<channelId>.<requestCountForClient>.<rpcCountForRequest>

where:
* version is the version of the specification
* processId is a randomly generated uint64 singleton for the lifetime of the process
* clientId is the monotonically increasing ordinal of the gRPC Spanner client created in this process
* channelId is the id of the gRPC channel used by the client (currently always 1 for Go)
* requestCountForClient is the monotonically increasing count of requests made by the client
* rpcCountForRequest is the number of RPC attempts, including retries, within a single request
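
For illustration only, a value in that form could be composed as in the
sketch below; the requestID type and its field names are made up for this
example and are not the identifiers used in this change.

	package main

	import (
		"fmt"
		"math/rand"
	)

	// requestID holds the components of one "x-goog-spanner-request-id" value.
	type requestID struct {
		version    int    // version of the specification
		processID  uint64 // random uint64, fixed for the lifetime of the process
		clientID   int    // ordinal of the Spanner client created in this process
		channelID  uint64 // ordinal of the gRPC channel (currently 1 for Go)
		nthRequest uint32 // ordinal of the request made by this client
		attempt    uint32 // RPC attempt (1 + number of retries) for this request
	}

	// String renders the dotted wire format described above.
	func (r requestID) String() string {
		return fmt.Sprintf("%d.%d.%d.%d.%d.%d",
			r.version, r.processID, r.clientID, r.channelID, r.nthRequest, r.attempt)
	}

	func main() {
		rid := requestID{
			version:    1,
			processID:  rand.Uint64(),
			clientID:   1,
			channelID:  1,
			nthRequest: 7,
			attempt:    1,
		}
		fmt.Println(rid) // e.g. "1.5577006791947779410.1.1.7.1"
	}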

The header is sent on both unary and streaming calls and will help customers
debug latency issues. On an error, customers can assert the returned error
against *spanner.Error, retrieve the associated .RequestID, and log it;
better still, the request id is included whenever the error itself is
printed or logged.
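
As a usage sketch (client construction is omitted, and the table, key and
column names are placeholders), pulling the request id out of a failed call
could look like this:

	package main

	import (
		"context"
		"errors"
		"log"

		"cloud.google.com/go/spanner"
	)

	// logRequestID shows how a caller could surface Error.RequestID when a
	// Spanner call fails.
	func logRequestID(ctx context.Context, client *spanner.Client) {
		_, err := client.Single().ReadRow(ctx, "Users", spanner.Key{"alice"}, []string{"Name"})
		if err == nil {
			return
		}
		var se *spanner.Error
		if errors.As(err, &se) {
			log.Printf("spanner call failed, requestID = %q: %v", se.RequestID, err)
			return
		}
		log.Printf("spanner call failed: %v", err)
	}

	func main() {
		// Construct a client with spanner.NewClient and pass it to logRequestID;
		// omitted here because it needs real credentials and a database.
	}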

Importantly, randIdForProcess is now a uint64 (64 bits) rather than the
uint32 (32 bits) of the prior design, and deliberately not a UUID4 (128 bits).
With only 32 bits, creating roughly 50,000 processes already pushes the
probability of a collision to about 25%. With 64 bits, on the order of
hundreds of millions of random ids (roughly 6e8) are needed before the
collision probability reaches even 1%, so high-QPS applications can accept
bursts of traffic without worry. A UUID4 would push collisions out even
further, but Google Cloud Spanner's backend has to store every one of these
always-on headers for the desired retention period, so 64 bits is a good
balance between collision protection and storage cost.
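
For reference, these figures follow from the standard birthday-problem
approximation p ≈ 1 - exp(-n^2 / (2N)) for n ids drawn from a space of N
values; a quick, illustrative check:

	package main

	import (
		"fmt"
		"math"
	)

	// collisionProbability approximates the chance that at least two of n
	// randomly generated ids collide in an id space of 2^bits values.
	func collisionProbability(n, bits float64) float64 {
		space := math.Pow(2, bits)
		return 1 - math.Exp(-n*n/(2*space))
	}

	func main() {
		fmt.Printf("32-bit ids, 50,000 processes: p ≈ %.2f\n", collisionProbability(5e4, 32))   // ≈ 0.25
		fmt.Printf("64-bit ids, 6.1e8 processes:  p ≈ %.3f\n", collisionProbability(6.1e8, 64)) // ≈ 0.010
	}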

Fixes #11073

* Rebase with main; rename nthRPC to attempt

* Infer channelID from ConnPool directly

* Attach nthRequest to sessionClient instead of to grpcClient given channelID is derived from sessionClient.connPool

* Retain reference to grpc.Header(*metadata.MD)

We have to re-insert the request-id even after gax.Invoke->grpc
internals clear it. Added test to validate retries.

* Fix up Error.Error() to show RequestID for both cases

* spanner: bring in tests contributed by Knut

* spanner: allow errors with grpc.codes: Canceled and DeadlineExceeded to be wrapped with request-id

* spanner: correctly track and increment retry attempts for each ExecuteStreamingSql request

* spanner: propagate RequestID even for DeadlineExceeded

* spanner: assert .RequestID exists

* Address code review nits + feedback

* spanner: account for stream resets and retries

This change accounts for the stream reset and retry cases graciously
raised by Knut, along with his test contribution.

* Address more updates
odeke-em authored Dec 18, 2024
1 parent 300865f commit 10960c1
Showing 13 changed files with 2,323 additions and 78 deletions.
12 changes: 6 additions & 6 deletions spanner/batch.go
@@ -309,7 +309,7 @@ func (t *BatchReadOnlyTransaction) Execute(ctx context.Context, p *Partition) *R
var (
sh *sessionHandle
err error
rpc func(ct context.Context, resumeToken []byte) (streamingReceiver, error)
rpc func(ct context.Context, resumeToken []byte, opts ...gax.CallOption) (streamingReceiver, error)
)
if sh, _, err = t.acquire(ctx); err != nil {
return &RowIterator{err: err}
@@ -322,7 +322,7 @@ func (t *BatchReadOnlyTransaction) Execute(ctx context.Context, p *Partition) *R
sh.updateLastUseTime()
// Read or query partition.
if p.rreq != nil {
rpc = func(ctx context.Context, resumeToken []byte) (streamingReceiver, error) {
rpc = func(ctx context.Context, resumeToken []byte, opts ...gax.CallOption) (streamingReceiver, error) {
client, err := client.StreamingRead(ctx, &sppb.ReadRequest{
Session: p.rreq.Session,
Transaction: p.rreq.Transaction,
@@ -335,7 +335,7 @@ func (t *BatchReadOnlyTransaction) Execute(ctx context.Context, p *Partition) *R
ResumeToken: resumeToken,
DataBoostEnabled: p.rreq.DataBoostEnabled,
DirectedReadOptions: p.rreq.DirectedReadOptions,
})
}, opts...)
if err != nil {
return client, err
}
@@ -351,7 +351,7 @@ func (t *BatchReadOnlyTransaction) Execute(ctx context.Context, p *Partition) *R
return client, err
}
} else {
rpc = func(ctx context.Context, resumeToken []byte) (streamingReceiver, error) {
rpc = func(ctx context.Context, resumeToken []byte, opts ...gax.CallOption) (streamingReceiver, error) {
client, err := client.ExecuteStreamingSql(ctx, &sppb.ExecuteSqlRequest{
Session: p.qreq.Session,
Transaction: p.qreq.Transaction,
@@ -364,7 +364,7 @@ func (t *BatchReadOnlyTransaction) Execute(ctx context.Context, p *Partition) *R
ResumeToken: resumeToken,
DataBoostEnabled: p.qreq.DataBoostEnabled,
DirectedReadOptions: p.qreq.DirectedReadOptions,
})
}, opts...)
if err != nil {
return client, err
}
@@ -387,7 +387,7 @@ func (t *BatchReadOnlyTransaction) Execute(ctx context.Context, p *Partition) *R
t.sp.sc.metricsTracerFactory,
rpc,
t.setTimestamp,
t.release)
t.release, client.(*grpcSpannerClient))
}

// MarshalBinary implements BinaryMarshaler.
8 changes: 8 additions & 0 deletions spanner/client.go
@@ -433,6 +433,14 @@ func newClientWithConfig(ctx context.Context, database string, config ClientConf
} else {
// Create gtransport ConnPool as usual if MultiEndpoint is not used.
// gRPC options.

// Add a unaryClientInterceptor and streamClientInterceptor.
reqIDInjector := new(requestIDHeaderInjector)
opts = append(opts,
option.WithGRPCDialOption(grpc.WithChainStreamInterceptor(reqIDInjector.interceptStream)),
option.WithGRPCDialOption(grpc.WithChainUnaryInterceptor(reqIDInjector.interceptUnary)),
)

allOpts := allClientOpts(config.NumChannels, config.Compression, opts...)
pool, err = gtransport.DialPool(ctx, allOpts...)
if err != nil {
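
For background, a chained unary client interceptor like the one registered
above could look roughly like the sketch below; the method body, the
nextRequestID helper, and the header-presence check are illustrative
assumptions, not the implementation added in this change.

	package spannerexample

	import (
		"context"

		"google.golang.org/grpc"
		"google.golang.org/grpc/metadata"
	)

	const requestIDHeader = "x-goog-spanner-request-id"

	type requestIDHeaderInjector struct{}

	// nextRequestID would normally compose the dotted value described in the
	// commit message; a fixed string stands in for it here.
	func (r *requestIDHeaderInjector) nextRequestID() string { return "1.123.1.1.1.1" }

	// interceptUnary injects the request-id header into the outgoing metadata
	// of every unary call before handing off to the next interceptor/invoker.
	func (r *requestIDHeaderInjector) interceptUnary(
		ctx context.Context, method string, req, reply interface{},
		cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption,
	) error {
		if md, ok := metadata.FromOutgoingContext(ctx); !ok || len(md.Get(requestIDHeader)) == 0 {
			ctx = metadata.AppendToOutgoingContext(ctx, requestIDHeader, r.nextRequestID())
		}
		return invoker(ctx, method, req, reply, cc, opts...)
	}

The streaming side follows the same pattern using the
grpc.StreamClientInterceptor signature.
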
8 changes: 6 additions & 2 deletions spanner/client_test.go
@@ -4187,13 +4187,17 @@ func TestReadWriteTransaction_ContextTimeoutDuringCommit(t *testing.T) {
if se.GRPCStatus().Code() != w.GRPCStatus().Code() {
t.Fatalf("Error status mismatch:\nGot: %v\nWant: %v", se.GRPCStatus(), w.GRPCStatus())
}
if se.Error() != w.Error() {
t.Fatalf("Error message mismatch:\nGot %s\nWant: %s", se.Error(), w.Error())
if !testEqual(se, w) {
t.Fatalf("Error message mismatch:\nGot: %s\nWant: %s", se.Error(), w.Error())
}
var outcome *TransactionOutcomeUnknownError
if !errors.As(err, &outcome) {
t.Fatalf("Missing wrapped TransactionOutcomeUnknownError error")
}

if w.RequestID != "" {
t.Fatal("Missing .RequestID")
}
}

func TestFailedCommit_NoRollback(t *testing.T) {
3 changes: 3 additions & 0 deletions spanner/cmp_test.go
@@ -64,6 +64,9 @@ func testEqual(a, b interface{}) bool {
if strings.Contains(path.GoString(), "{*spanner.Error}.err") {
return true
}
if strings.Contains(path.GoString(), "{*spanner.Error}.RequestID") {
return true
}
return false
}, cmp.Ignore()))
}
25 changes: 20 additions & 5 deletions spanner/errors.go
@@ -58,6 +58,10 @@ type Error struct {
// additionalInformation optionally contains any additional information
// about the error.
additionalInformation string

// RequestID is the associated ID that was sent to Google Cloud Spanner's
// backend, as the value in the "x-goog-spanner-request-id" gRPC header.
RequestID string
}

// TransactionOutcomeUnknownError is wrapped in a Spanner error when the error
@@ -85,10 +89,17 @@ func (e *Error) Error() string {
return "spanner: OK"
}
code := ErrCode(e)

var s string
if e.additionalInformation == "" {
return fmt.Sprintf("spanner: code = %q, desc = %q", code, e.Desc)
s = fmt.Sprintf("spanner: code = %q, desc = %q", code, e.Desc)
} else {
s = fmt.Sprintf("spanner: code = %q, desc = %q, additional information = %s", code, e.Desc, e.additionalInformation)
}
return fmt.Sprintf("spanner: code = %q, desc = %q, additional information = %s", code, e.Desc, e.additionalInformation)
if e.RequestID != "" {
s = fmt.Sprintf("%s, requestID = %q", s, e.RequestID)
}
return s
}

// Unwrap returns the wrapped error (if any).
@@ -123,6 +134,10 @@ func (e *Error) decorate(info string) {
// APIError error having given error code as its status.
func spannerErrorf(code codes.Code, format string, args ...interface{}) error {
msg := fmt.Sprintf(format, args...)
return spannerError(code, msg)
}

func spannerError(code codes.Code, msg string) error {
wrapped, _ := apierror.FromError(status.Error(code, msg))
return &Error{
Code: code,
@@ -172,9 +187,9 @@ func toSpannerErrorWithCommitInfo(err error, errorDuringCommit bool) error {
desc = fmt.Sprintf("%s, %s", desc, transactionOutcomeUnknownMsg)
wrapped = &TransactionOutcomeUnknownError{err: wrapped}
}
return &Error{status.FromContextError(err).Code(), toAPIError(wrapped), desc, ""}
return &Error{status.FromContextError(err).Code(), toAPIError(wrapped), desc, "", ""}
case status.Code(err) == codes.Unknown:
return &Error{codes.Unknown, toAPIError(err), err.Error(), ""}
return &Error{codes.Unknown, toAPIError(err), err.Error(), "", ""}
default:
statusErr := status.Convert(err)
code, desc := statusErr.Code(), statusErr.Message()
@@ -183,7 +198,7 @@ func toSpannerErrorWithCommitInfo(err error, errorDuringCommit bool) error {
desc = fmt.Sprintf("%s, %s", desc, transactionOutcomeUnknownMsg)
wrapped = &TransactionOutcomeUnknownError{err: wrapped}
}
return &Error{code, toAPIError(wrapped), desc, ""}
return &Error{code, toAPIError(wrapped), desc, "", ""}
}
}

47 changes: 32 additions & 15 deletions spanner/grpc_client.go
@@ -19,6 +19,7 @@ package spanner
import (
"context"
"strings"
"sync/atomic"

vkit "cloud.google.com/go/spanner/apiv1"
"cloud.google.com/go/spanner/apiv1/spannerpb"
@@ -67,6 +68,15 @@ type spannerClient interface {
type grpcSpannerClient struct {
raw *vkit.Client
metricsTracerFactory *builtinMetricsTracerFactory

// These fields are used to uniquely track x-goog-spanner-request-id where:
// raw(*vkit.Client) is the channel, and channelID is derived from the ordinal
// count of unique *vkit.Client as retrieved from the session pool.
channelID uint64
// id is derived from the SpannerClient.
id int
// nthRequest is incremented for each new request (but not for retries of requests).
nthRequest *atomic.Uint32
}

var (
@@ -76,13 +86,16 @@ var (

// newGRPCSpannerClient initializes a new spannerClient that uses the gRPC
// Spanner API.
func newGRPCSpannerClient(ctx context.Context, sc *sessionClient, opts ...option.ClientOption) (spannerClient, error) {
func newGRPCSpannerClient(ctx context.Context, sc *sessionClient, channelID uint64, opts ...option.ClientOption) (spannerClient, error) {
raw, err := vkit.NewClient(ctx, opts...)
if err != nil {
return nil, err
}

g := &grpcSpannerClient{raw: raw, metricsTracerFactory: sc.metricsTracerFactory}
clientID := sc.nthClient
g.prepareRequestIDTrackers(clientID, channelID, sc.nthRequest)

clientInfo := []string{"gccl", internal.Version}
if sc.userAgent != "" {
agentWithVersion := strings.SplitN(sc.userAgent, "/", 2)
@@ -118,7 +131,7 @@ func (g *grpcSpannerClient) CreateSession(ctx context.Context, req *spannerpb.Cr
mt := g.newBuiltinMetricsTracer(ctx)
defer recordOperationCompletion(mt)
ctx = context.WithValue(ctx, metricsTracerKey, mt)
resp, err := g.raw.CreateSession(ctx, req, opts...)
resp, err := g.raw.CreateSession(ctx, req, g.optsWithNextRequestID(opts)...)
statusCode, _ := status.FromError(err)
mt.currOp.setStatus(statusCode.Code().String())
return resp, err
@@ -128,7 +141,7 @@ func (g *grpcSpannerClient) BatchCreateSessions(ctx context.Context, req *spanne
mt := g.newBuiltinMetricsTracer(ctx)
defer recordOperationCompletion(mt)
ctx = context.WithValue(ctx, metricsTracerKey, mt)
resp, err := g.raw.BatchCreateSessions(ctx, req, opts...)
resp, err := g.raw.BatchCreateSessions(ctx, req, g.optsWithNextRequestID(opts)...)
statusCode, _ := status.FromError(err)
mt.currOp.setStatus(statusCode.Code().String())
return resp, err
@@ -138,21 +151,21 @@ func (g *grpcSpannerClient) GetSession(ctx context.Context, req *spannerpb.GetSe
mt := g.newBuiltinMetricsTracer(ctx)
defer recordOperationCompletion(mt)
ctx = context.WithValue(ctx, metricsTracerKey, mt)
resp, err := g.raw.GetSession(ctx, req, opts...)
resp, err := g.raw.GetSession(ctx, req, g.optsWithNextRequestID(opts)...)
statusCode, _ := status.FromError(err)
mt.currOp.setStatus(statusCode.Code().String())
return resp, err
}

func (g *grpcSpannerClient) ListSessions(ctx context.Context, req *spannerpb.ListSessionsRequest, opts ...gax.CallOption) *vkit.SessionIterator {
return g.raw.ListSessions(ctx, req, opts...)
return g.raw.ListSessions(ctx, req, g.optsWithNextRequestID(opts)...)
}

func (g *grpcSpannerClient) DeleteSession(ctx context.Context, req *spannerpb.DeleteSessionRequest, opts ...gax.CallOption) error {
mt := g.newBuiltinMetricsTracer(ctx)
defer recordOperationCompletion(mt)
ctx = context.WithValue(ctx, metricsTracerKey, mt)
err := g.raw.DeleteSession(ctx, req, opts...)
err := g.raw.DeleteSession(ctx, req, g.optsWithNextRequestID(opts)...)
statusCode, _ := status.FromError(err)
mt.currOp.setStatus(statusCode.Code().String())
return err
@@ -162,21 +175,23 @@ func (g *grpcSpannerClient) ExecuteSql(ctx context.Context, req *spannerpb.Execu
mt := g.newBuiltinMetricsTracer(ctx)
defer recordOperationCompletion(mt)
ctx = context.WithValue(ctx, metricsTracerKey, mt)
resp, err := g.raw.ExecuteSql(ctx, req, opts...)
resp, err := g.raw.ExecuteSql(ctx, req, g.optsWithNextRequestID(opts)...)
statusCode, _ := status.FromError(err)
mt.currOp.setStatus(statusCode.Code().String())
return resp, err
}

func (g *grpcSpannerClient) ExecuteStreamingSql(ctx context.Context, req *spannerpb.ExecuteSqlRequest, opts ...gax.CallOption) (spannerpb.Spanner_ExecuteStreamingSqlClient, error) {
// Note: This method does not add g.optsWithNextRequestID to inject x-goog-spanner-request-id
// as it is already manually added when creating Stream iterators for ExecuteStreamingSql.
return g.raw.ExecuteStreamingSql(peer.NewContext(ctx, &peer.Peer{}), req, opts...)
}

func (g *grpcSpannerClient) ExecuteBatchDml(ctx context.Context, req *spannerpb.ExecuteBatchDmlRequest, opts ...gax.CallOption) (*spannerpb.ExecuteBatchDmlResponse, error) {
mt := g.newBuiltinMetricsTracer(ctx)
defer recordOperationCompletion(mt)
ctx = context.WithValue(ctx, metricsTracerKey, mt)
resp, err := g.raw.ExecuteBatchDml(ctx, req, opts...)
resp, err := g.raw.ExecuteBatchDml(ctx, req, g.optsWithNextRequestID(opts)...)
statusCode, _ := status.FromError(err)
mt.currOp.setStatus(statusCode.Code().String())
return resp, err
@@ -186,21 +201,23 @@ func (g *grpcSpannerClient) Read(ctx context.Context, req *spannerpb.ReadRequest
mt := g.newBuiltinMetricsTracer(ctx)
defer recordOperationCompletion(mt)
ctx = context.WithValue(ctx, metricsTracerKey, mt)
resp, err := g.raw.Read(ctx, req, opts...)
resp, err := g.raw.Read(ctx, req, g.optsWithNextRequestID(opts)...)
statusCode, _ := status.FromError(err)
mt.currOp.setStatus(statusCode.Code().String())
return resp, err
}

func (g *grpcSpannerClient) StreamingRead(ctx context.Context, req *spannerpb.ReadRequest, opts ...gax.CallOption) (spannerpb.Spanner_StreamingReadClient, error) {
// Note: This method does not add g.optsWithNextRequestID, as it is already
// manually added when creating Stream iterators for StreamingRead.
return g.raw.StreamingRead(peer.NewContext(ctx, &peer.Peer{}), req, opts...)
}

func (g *grpcSpannerClient) BeginTransaction(ctx context.Context, req *spannerpb.BeginTransactionRequest, opts ...gax.CallOption) (*spannerpb.Transaction, error) {
mt := g.newBuiltinMetricsTracer(ctx)
defer recordOperationCompletion(mt)
ctx = context.WithValue(ctx, metricsTracerKey, mt)
resp, err := g.raw.BeginTransaction(ctx, req, opts...)
resp, err := g.raw.BeginTransaction(ctx, req, g.optsWithNextRequestID(opts)...)
statusCode, _ := status.FromError(err)
mt.currOp.setStatus(statusCode.Code().String())
return resp, err
@@ -210,7 +227,7 @@ func (g *grpcSpannerClient) Commit(ctx context.Context, req *spannerpb.CommitReq
mt := g.newBuiltinMetricsTracer(ctx)
defer recordOperationCompletion(mt)
ctx = context.WithValue(ctx, metricsTracerKey, mt)
resp, err := g.raw.Commit(ctx, req, opts...)
resp, err := g.raw.Commit(ctx, req, g.optsWithNextRequestID(opts)...)
statusCode, _ := status.FromError(err)
mt.currOp.setStatus(statusCode.Code().String())
return resp, err
@@ -220,7 +237,7 @@ func (g *grpcSpannerClient) Rollback(ctx context.Context, req *spannerpb.Rollbac
mt := g.newBuiltinMetricsTracer(ctx)
defer recordOperationCompletion(mt)
ctx = context.WithValue(ctx, metricsTracerKey, mt)
err := g.raw.Rollback(ctx, req, opts...)
err := g.raw.Rollback(ctx, req, g.optsWithNextRequestID(opts)...)
statusCode, _ := status.FromError(err)
mt.currOp.setStatus(statusCode.Code().String())
return err
@@ -230,7 +247,7 @@ func (g *grpcSpannerClient) PartitionQuery(ctx context.Context, req *spannerpb.P
mt := g.newBuiltinMetricsTracer(ctx)
defer recordOperationCompletion(mt)
ctx = context.WithValue(ctx, metricsTracerKey, mt)
resp, err := g.raw.PartitionQuery(ctx, req, opts...)
resp, err := g.raw.PartitionQuery(ctx, req, g.optsWithNextRequestID(opts)...)
statusCode, _ := status.FromError(err)
mt.currOp.setStatus(statusCode.Code().String())
return resp, err
@@ -240,12 +257,12 @@ func (g *grpcSpannerClient) PartitionRead(ctx context.Context, req *spannerpb.Pa
mt := g.newBuiltinMetricsTracer(ctx)
defer recordOperationCompletion(mt)
ctx = context.WithValue(ctx, metricsTracerKey, mt)
resp, err := g.raw.PartitionRead(ctx, req, opts...)
resp, err := g.raw.PartitionRead(ctx, req, g.optsWithNextRequestID(opts)...)
statusCode, _ := status.FromError(err)
mt.currOp.setStatus(statusCode.Code().String())
return resp, err
}

func (g *grpcSpannerClient) BatchWrite(ctx context.Context, req *spannerpb.BatchWriteRequest, opts ...gax.CallOption) (spannerpb.Spanner_BatchWriteClient, error) {
return g.raw.BatchWrite(peer.NewContext(ctx, &peer.Peer{}), req, opts...)
return g.raw.BatchWrite(peer.NewContext(ctx, &peer.Peer{}), req, g.optsWithNextRequestID(opts)...)
}
4 changes: 4 additions & 0 deletions spanner/internal/testutil/inmem_spanner_server.go
@@ -90,6 +90,7 @@ const (
MethodExecuteBatchDml string = "EXECUTE_BATCH_DML"
MethodStreamingRead string = "EXECUTE_STREAMING_READ"
MethodBatchWrite string = "BATCH_WRITE"
MethodPartitionQuery string = "PARTITION_QUERY"
)

// StatementResult represents a mocked result on the test server. The result is
Expand Down Expand Up @@ -1107,6 +1108,9 @@ func (s *inMemSpannerServer) Rollback(ctx context.Context, req *spannerpb.Rollba
}

func (s *inMemSpannerServer) PartitionQuery(ctx context.Context, req *spannerpb.PartitionQueryRequest) (*spannerpb.PartitionResponse, error) {
if err := s.simulateExecutionTime(MethodPartitionQuery, req); err != nil {
return nil, err
}
s.mu.Lock()
if s.stopped {
s.mu.Unlock()