diff --git a/.github/workflows/vet.sh b/.github/workflows/vet.sh index 362ec936642b..17dd311b60e7 100755 --- a/.github/workflows/vet.sh +++ b/.github/workflows/vet.sh @@ -61,6 +61,7 @@ golint ./... 2>&1 | ( grep -v "firestore.arrayUnion" | grep -v "firestore.arrayRemove" | grep -v "maxAttempts" | + grep -v "firestore.commitResponse" | grep -v "UptimeCheckIpIterator" | grep -vE "apiv[0-9]+" | grep -v "ALL_CAPS" | diff --git a/firestore/examples_test.go b/firestore/examples_test.go index 70bec366016d..155362076f9a 100644 --- a/firestore/examples_test.go +++ b/firestore/examples_test.go @@ -564,6 +564,9 @@ func ExampleClient_RunTransaction() { } defer client.Close() + // write the CommitResponse here, via firestore.WithCommitResponse (below) + var cr firestore.CommitResponse + nm := client.Doc("States/NewMexico") err = client.RunTransaction(ctx, func(ctx context.Context, tx *firestore.Transaction) error { doc, err := tx.Get(nm) // tx.Get, NOT nm.Get! @@ -575,10 +578,11 @@ func ExampleClient_RunTransaction() { return err } return tx.Update(nm, []firestore.Update{{Path: "pop", Value: pop.(float64) + 0.2}}) - }) + }, firestore.WithCommitResponseTo(&cr)) if err != nil { // TODO: Handle error. } + // CommitResponse can be accessed here } func ExampleArrayUnion_create() { diff --git a/firestore/transaction.go b/firestore/transaction.go index 5cf07eab5564..6fd2ecce3160 100644 --- a/firestore/transaction.go +++ b/firestore/transaction.go @@ -17,6 +17,7 @@ package firestore import ( "context" "errors" + "time" pb "cloud.google.com/go/firestore/apiv1/firestorepb" "cloud.google.com/go/internal/trace" @@ -41,6 +42,7 @@ type Transaction struct { // A TransactionOption is an option passed to Client.Transaction. type TransactionOption interface { config(t *Transaction) + handleCommitResponse(r *pb.CommitResponse) } // MaxAttempts is a TransactionOption that configures the maximum number of times to @@ -49,7 +51,8 @@ func MaxAttempts(n int) maxAttempts { return maxAttempts(n) } type maxAttempts int -func (m maxAttempts) config(t *Transaction) { t.maxAttempts = int(m) } +func (m maxAttempts) config(t *Transaction) { t.maxAttempts = int(m) } +func (m maxAttempts) handleCommitResponse(r *pb.CommitResponse) {} // DefaultTransactionMaxAttempts is the default number of times to attempt a transaction. const DefaultTransactionMaxAttempts = 5 @@ -60,7 +63,35 @@ var ReadOnly = ro{} type ro struct{} -func (ro) config(t *Transaction) { t.readOnly = true } +func (ro) config(t *Transaction) { t.readOnly = true } +func (ro) handleCommitResponse(r *pb.CommitResponse) {} + +// CommitResponse exposes information about a committed transaction. +type CommitResponse struct { + response *pb.CommitResponse +} + +// CommitTime returns the commit time from the commit response. +func (r *CommitResponse) CommitTime() time.Time { + return r.response.CommitTime.AsTime() +} + +// commitResponse is the TransactionOption to record a commit response. +type commitResponse struct { + responseTo *CommitResponse +} + +func (c commitResponse) config(t *Transaction) {} +func (c commitResponse) handleCommitResponse(r *pb.CommitResponse) { + c.responseTo.response = r +} + +// WithCommitResponseTo returns a TransactionOption that specifies where the +// CommitResponse should be written on successful commit. Nothing is written +// on a failed commit. +func WithCommitResponseTo(r *CommitResponse) commitResponse { + return commitResponse{responseTo: r} +} var ( // Defined here for testing. @@ -115,6 +146,7 @@ func (c *Client) RunTransaction(ctx context.Context, f func(context.Context, *Tr } } var backoff gax.Backoff + var commitResponse *pb.CommitResponse // TODO(jba): use other than the standard backoff parameters? // TODO(jba): get backoff time from gRPC trailer metadata? See // extractRetryDelay in https://code.googlesource.com/gocloud/+/master/spanner/retry.go. @@ -142,13 +174,20 @@ func (c *Client) RunTransaction(ctx context.Context, f func(context.Context, *Tr return err } t.ctx = trace.StartSpan(t.ctx, "cloud.google.com/go/firestore.Client.Commit") - _, err = t.c.c.Commit(t.ctx, &pb.CommitRequest{ + commitResponse, err = t.c.c.Commit(t.ctx, &pb.CommitRequest{ Database: t.c.path(), Writes: t.writes, Transaction: t.id, }) trace.EndSpan(t.ctx, err) + // on success, handle the commit response + if err == nil { + for _, opt := range opts { + opt.handleCommitResponse(commitResponse) + } + } + // If a read-write transaction returns Aborted, retry. // On success or other failures, return here. if t.readOnly || status.Code(err) != codes.Aborted { diff --git a/firestore/transaction_test.go b/firestore/transaction_test.go index ecb0d0373122..e6888a83d015 100644 --- a/firestore/transaction_test.go +++ b/firestore/transaction_test.go @@ -84,6 +84,7 @@ func TestRunTransaction(t *testing.T) { }, &pb.CommitResponse{CommitTime: aTimestamp3}, ) + var commitResponse CommitResponse err = c.RunTransaction(ctx, func(_ context.Context, tx *Transaction) error { docref := c.Collection("C").Doc("a") doc, err := tx.Get(docref) @@ -95,11 +96,17 @@ func TestRunTransaction(t *testing.T) { return err } return tx.Update(docref, []Update{{Path: "count", Value: count.(int64) + 1}}) - }) + }, WithCommitResponseTo(&commitResponse)) if err != nil { t.Fatal(err) } + // validate commit time + commitTime := commitResponse.CommitTime() + if commitTime != aTimestamp3.AsTime() { + t.Fatalf("commit time %v should equal %v", commitTime, aTimestamp3) + } + // Query srv.reset() srv.addRPC(beginReq, beginRes)