diff --git a/docstore/awsdynamodb/dynamo.go b/docstore/awsdynamodb/dynamo.go index 2df0aef081..a378dc40bd 100644 --- a/docstore/awsdynamodb/dynamo.go +++ b/docstore/awsdynamodb/dynamo.go @@ -162,10 +162,12 @@ func (c *collection) RevisionField() string { return c.opts.RevisionField } func (c *collection) RunActions(ctx context.Context, actions []*driver.Action, opts *driver.RunActionsOptions) driver.ActionListError { errs := make([]error, len(actions)) - beforeGets, gets, writes, afterGets := driver.GroupActions(actions) + beforeGets, gets, writes, writesTx, afterGets := driver.GroupActions(actions) c.runGets(ctx, beforeGets, errs, opts) ch := make(chan struct{}) + ch2 := make(chan struct{}) go func() { defer close(ch); c.runWrites(ctx, writes, errs, opts) }() + go func() { defer close(ch2); c.transactWrite(ctx, writesTx, errs, opts) }() c.runGets(ctx, gets, errs, opts) <-ch c.runGets(ctx, afterGets, errs, opts) @@ -613,25 +615,26 @@ func revisionPrecondition(doc driver.Document, revField string) (*expression.Con return &cb, nil } -// TODO(jba): use this if/when we support atomic writes. -func (c *collection) transactWrite(ctx context.Context, actions []*driver.Action, errs []error, opts *driver.RunActionsOptions, start, end int) { +func (c *collection) transactWrite(ctx context.Context, actions []*driver.Action, errs []error, opts *driver.RunActionsOptions) { + if len(actions) == 0 { + return + } setErr := func(err error) { - for i := start; i <= end; i++ { - errs[actions[i].Index] = err + for _, a := range actions { + errs[a.Index] = err } } + tws := make([]*dyn.TransactWriteItem, 0, len(actions)) var ops []*writeOp - tws := make([]*dyn.TransactWriteItem, 0, end-start+1) - for i := start; i <= end; i++ { - a := actions[i] - op, err := c.newWriteOp(a, opts) + for _, w := range actions { + op, err := c.newWriteOp(w, opts) if err != nil { - setErr(err) - return + errs[w.Index] = err + } else { + ops = append(ops, op) + tws = append(tws, op.writeItem) } - ops = append(ops, op) - tws = append(tws, op.writeItem) } in := &dyn.TransactWriteItemsInput{ diff --git a/docstore/docstore.go b/docstore/docstore.go index 388d46ea8a..a4bb0b6508 100644 --- a/docstore/docstore.go +++ b/docstore/docstore.go @@ -135,18 +135,20 @@ func (c *Collection) Actions() *ActionList { // document; a Get after the write will see the new value if the service is strongly // consistent, but may see the old value if the service is eventually consistent. type ActionList struct { - coll *Collection - actions []*Action - beforeDo func(asFunc func(interface{}) bool) error + coll *Collection + actions []*Action + enableAtomicWrites bool + beforeDo func(asFunc func(interface{}) bool) error } // An Action is a read or write on a single document. // Use the methods of ActionList to create and execute Actions. type Action struct { - kind driver.ActionKind - doc Document - fieldpaths []FieldPath // paths to retrieve, for Get - mods Mods // modifications to make, for Update + kind driver.ActionKind + doc Document + fieldpaths []FieldPath // paths to retrieve, for Get + mods Mods // modifications to make, for Update + inAtomicWrite bool // if this action is a part of atomic writes } func (l *ActionList) add(a *Action) *ActionList { @@ -170,7 +172,7 @@ func (l *ActionList) add(a *Action) *ActionList { // Except for setting the revision field and possibly setting the key fields, the doc // argument is not modified. func (l *ActionList) Create(doc Document) *ActionList { - return l.add(&Action{kind: driver.Create, doc: doc}) + return l.add(&Action{kind: driver.Create, doc: doc, inAtomicWrite: l.enableAtomicWrites}) } // Replace adds an action that replaces a document to the given ActionList, and @@ -182,7 +184,7 @@ func (l *ActionList) Create(doc Document) *ActionList { // See the Revisions section of the package documentation for how revisions are // handled. func (l *ActionList) Replace(doc Document) *ActionList { - return l.add(&Action{kind: driver.Replace, doc: doc}) + return l.add(&Action{kind: driver.Replace, doc: doc, inAtomicWrite: l.enableAtomicWrites}) } // Put adds an action that adds or replaces a document to the given ActionList, and returns the ActionList. @@ -195,7 +197,7 @@ func (l *ActionList) Replace(doc Document) *ActionList { // See the Revisions section of the package documentation for how revisions are // handled. func (l *ActionList) Put(doc Document) *ActionList { - return l.add(&Action{kind: driver.Put, doc: doc}) + return l.add(&Action{kind: driver.Put, doc: doc, inAtomicWrite: l.enableAtomicWrites}) } // Delete adds an action that deletes a document to the given ActionList, and returns @@ -210,7 +212,7 @@ func (l *ActionList) Delete(doc Document) *ActionList { // semantics of an action list are to stop at first error, then we might abort a // list of Deletes just because one of the docs was not present, and that seems // wrong, or at least something you'd want to turn off. - return l.add(&Action{kind: driver.Delete, doc: doc}) + return l.add(&Action{kind: driver.Delete, doc: doc, inAtomicWrite: l.enableAtomicWrites}) } // Get adds an action that retrieves a document to the given ActionList, and @@ -252,9 +254,10 @@ func (l *ActionList) Get(doc Document, fps ...FieldPath) *ActionList { // the updated document, call Get after calling Update. func (l *ActionList) Update(doc Document, mods Mods) *ActionList { return l.add(&Action{ - kind: driver.Update, - doc: doc, - mods: mods, + kind: driver.Update, + doc: doc, + mods: mods, + inAtomicWrite: l.enableAtomicWrites, }) } @@ -430,7 +433,7 @@ func (c *Collection) toDriverAction(a *Action) (*driver.Action, error) { // A Put with a revision field is equivalent to a Replace. kind = driver.Replace } - d := &driver.Action{Kind: kind, Doc: ddoc, Key: key} + d := &driver.Action{Kind: kind, Doc: ddoc, Key: key, InAtomicWrite: a.inAtomicWrite} if a.fieldpaths != nil { d.FieldPaths, err = parseFieldPaths(a.fieldpaths) if err != nil { @@ -534,6 +537,12 @@ func (l *ActionList) String() string { return "[" + strings.Join(as, ", ") + "]" } +// AtomicWrites causes all following writes in the list to execute atomically. +func (l *ActionList) AtomicWrites() *ActionList { + l.enableAtomicWrites = true + return l +} + func (a *Action) String() string { buf := &strings.Builder{} fmt.Fprintf(buf, "%s(%v", a.kind, a.doc) diff --git a/docstore/driver/driver.go b/docstore/driver/driver.go index bd2bbccb0e..d786afdb9b 100644 --- a/docstore/driver/driver.go +++ b/docstore/driver/driver.go @@ -125,14 +125,14 @@ const ( //go:generate stringer -type=ActionKind -// An Action describes a single operation on a single document. type Action struct { - Kind ActionKind // the kind of action - Doc Document // the document on which to perform the action - Key interface{} // the document key returned by Collection.Key, to avoid recomputing it - FieldPaths [][]string // field paths to retrieve, for Get only - Mods []Mod // modifications to make, for Update only - Index int // the index of the action in the original action list + Kind ActionKind // the kind of action + Doc Document // the document on which to perform the action + Key interface{} // the document key returned by Collection.Key, to avoid recomputing it + FieldPaths [][]string // field paths to retrieve, for Get only + Mods []Mod // modifications to make, for Update only + Index int // the index of the action in the original action list + InAtomicWrite bool // if this action is a part of transaction } // A Mod is a modification to a field path in a document. diff --git a/docstore/driver/util.go b/docstore/driver/util.go index 5ead3f3fa4..53bb26f89a 100644 --- a/docstore/driver/util.go +++ b/docstore/driver/util.go @@ -55,12 +55,13 @@ func SplitActions(actions []*Action, split func(a, b *Action) bool) [][]*Action // GroupActions separates actions into four sets: writes, gets that must happen before the writes, // gets that must happen after the writes, and gets that can happen concurrently with the writes. -func GroupActions(actions []*Action) (beforeGets, getList, writeList, afterGets []*Action) { +func GroupActions(actions []*Action) (beforeGets, getList, writeList, writesTxList, afterGets []*Action) { // maps from key to action bgets := map[interface{}]*Action{} agets := map[interface{}]*Action{} cgets := map[interface{}]*Action{} writes := map[interface{}]*Action{} + writesTx := map[interface{}]*Action{} var nilkeys []*Action for _, a := range actions { if a.Key == nil { @@ -69,7 +70,7 @@ func GroupActions(actions []*Action) (beforeGets, getList, writeList, afterGets } else if a.Kind == Get { // If there was a prior write with this key, make sure this get // happens after the writes. - if _, ok := writes[a.Key]; ok { + if valueExistsInMaps(a.Key, writes, writesTx) { agets[a.Key] = a } else { cgets[a.Key] = a @@ -81,7 +82,11 @@ func GroupActions(actions []*Action) (beforeGets, getList, writeList, afterGets delete(cgets, a.Key) bgets[a.Key] = g } - writes[a.Key] = a + if a.InAtomicWrite { + writesTx[a.Key] = a + } else { + writes[a.Key] = a + } } } @@ -95,7 +100,16 @@ func GroupActions(actions []*Action) (beforeGets, getList, writeList, afterGets return as } - return vals(bgets), vals(cgets), append(vals(writes), nilkeys...), vals(agets) + return vals(bgets), vals(cgets), append(vals(writes), nilkeys...), vals(writesTx), vals(agets) +} + +func valueExistsInMaps(key interface{}, maps ...map[interface{}]*Action) bool { + for _, m := range maps { + if _, ok := m[key]; ok { + return true + } + } + return false } // AsFunc creates and returns an "as function" that behaves as follows: diff --git a/docstore/driver/util_test.go b/docstore/driver/util_test.go index 0df100c911..0f7493669e 100644 --- a/docstore/driver/util_test.go +++ b/docstore/driver/util_test.go @@ -79,7 +79,7 @@ func TestGroupActions(t *testing.T) { }{ { in: []*Action{{Kind: Get, Key: 1}}, - want: [][]int{nil, {0}, nil, nil}, + want: [][]int{nil, {0}, nil, nil, nil}, }, { in: []*Action{ @@ -89,16 +89,16 @@ func TestGroupActions(t *testing.T) { {Kind: Replace, Key: 2}, {Kind: Get, Key: 2}, }, - want: [][]int{{0}, {1}, {2, 3}, {4}}, + want: [][]int{{0}, {1}, {2, 3}, nil, {4}}, }, { in: []*Action{{Kind: Create}, {Kind: Create}, {Kind: Create}}, - want: [][]int{nil, nil, {0, 1, 2}, nil}, + want: [][]int{nil, nil, {0, 1, 2}, nil, nil}, }, } { - got := make([][]*Action, 4) - got[0], got[1], got[2], got[3] = GroupActions(test.in) - want := make([][]*Action, 4) + got := make([][]*Action, 5) + got[0], got[1], got[2], got[3], got[4] = GroupActions(test.in) + want := make([][]*Action, 5) for i, s := range test.want { for _, x := range s { want[i] = append(want[i], test.in[x]) diff --git a/docstore/drivertest/drivertest.go b/docstore/drivertest/drivertest.go index 940a57ac30..35cda22415 100644 --- a/docstore/drivertest/drivertest.go +++ b/docstore/drivertest/drivertest.go @@ -29,6 +29,7 @@ import ( "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" "github.com/google/uuid" + "github.com/stretchr/testify/assert" "gocloud.dev/docstore" "gocloud.dev/docstore/driver" "gocloud.dev/gcerrors" @@ -1900,6 +1901,81 @@ func testMultipleActions(t *testing.T, coll *docstore.Collection, revField strin } } +func testAtomicWrites(t *testing.T, coll *docstore.Collection, revField string) { + t.Helper() + + ctx := context.Background() + + must := func(err error) { + t.Helper() + if err != nil { + t.Fatal(err) + } + } + + var docs []docmap + for i := 0; i < 9; i++ { + docs = append(docs, docmap{ + KeyField: fmt.Sprintf("testAtomicWrites%d", i), + "s": fmt.Sprint(i), + revField: nil, + }) + } + + compare := func(gots, wants []docmap) { + t.Helper() + for i := 0; i < len(gots); i++ { + got := gots[i] + want := clone(wants[i]) + want[revField] = got[revField] + if !cmp.Equal(got, want, cmpopts.IgnoreUnexported(tspb.Timestamp{})) { + t.Errorf("index #%d:\ngot %v\nwant %v", i, got, want) + } + } + } + + // Put the first six docs. + actions := coll.Actions() + for i := 0; i < 6; i++ { + actions.Create(docs[i]) + } + must(actions.Do(ctx)) + + // Delete the first three, get the second three, and update last three in transaction. + gdocs := []docmap{ + {KeyField: docs[3][KeyField]}, + {KeyField: docs[4][KeyField]}, + {KeyField: docs[5][KeyField]}, + } + actions = coll.Actions() + actions.Get(gdocs[0]) + actions.Delete(docs[0]) + actions.Delete(docs[1]) + actions.Get(gdocs[1]) + actions.Delete(docs[2]) + actions.Get(gdocs[2]) + actions.AtomicWrites() + actions.Update(docs[6], docstore.Mods{"s": "66'"}) + actions.Update(docs[7], docstore.Mods{"s": "77'"}) + actions.Update(docs[8], docstore.Mods{"s": "88"}) + + must(actions.Do(ctx)) + compare(gdocs, docs[3:6]) + + // Get the docs updated as part of atomic writes and verify that got written. + actions = coll.Actions() + + doc := docmap{KeyField: docs[6][KeyField]} + _ = coll.Get(ctx, doc) + assert.Equal(t, "66", doc["s"]) + doc = docmap{KeyField: docs[7][KeyField]} + _ = coll.Get(ctx, doc) + assert.Equal(t, "77", doc["s"]) + doc = docmap{KeyField: docs[8][KeyField]} + _ = coll.Get(ctx, doc) + assert.Equal(t, "88", doc["s"]) +} + func testActionsOnStructNoRev(t *testing.T, _ Harness, coll *docstore.Collection) { t.Helper() diff --git a/docstore/gcpfirestore/fs.go b/docstore/gcpfirestore/fs.go index 076b50f48d..dc9961db33 100644 --- a/docstore/gcpfirestore/fs.go +++ b/docstore/gcpfirestore/fs.go @@ -265,7 +265,7 @@ func (c *collection) RevisionField() string { // RunActions implements driver.RunActions. func (c *collection) RunActions(ctx context.Context, actions []*driver.Action, opts *driver.RunActionsOptions) driver.ActionListError { errs := make([]error, len(actions)) - beforeGets, gets, writes, afterGets := driver.GroupActions(actions) + beforeGets, gets, writes, _, afterGets := driver.GroupActions(actions) calls := c.buildCommitCalls(writes, errs) // runGets does not issue concurrent RPCs, so it doesn't need a throttle. c.runGets(ctx, beforeGets, errs, opts) diff --git a/docstore/memdocstore/mem.go b/docstore/memdocstore/mem.go index a5470261b6..c4ecf83178 100644 --- a/docstore/memdocstore/mem.go +++ b/docstore/memdocstore/mem.go @@ -191,7 +191,7 @@ func (c *collection) RunActions(ctx context.Context, actions []*driver.Action, o } } - beforeGets, gets, writes, afterGets := driver.GroupActions(actions) + beforeGets, gets, writes, _, afterGets := driver.GroupActions(actions) run(beforeGets) run(gets) run(writes) diff --git a/docstore/mongodocstore/mongo.go b/docstore/mongodocstore/mongo.go index 3e75cf9c66..24aa0c72f1 100644 --- a/docstore/mongodocstore/mongo.go +++ b/docstore/mongodocstore/mongo.go @@ -204,7 +204,7 @@ const mongoIDField = "_id" func (c *collection) RunActions(ctx context.Context, actions []*driver.Action, opts *driver.RunActionsOptions) driver.ActionListError { errs := make([]error, len(actions)) - beforeGets, gets, writes, afterGets := driver.GroupActions(actions) + beforeGets, gets, writes, _, afterGets := driver.GroupActions(actions) c.runGets(ctx, beforeGets, errs, opts) ch := make(chan []error) go func() { ch <- c.bulkWrite(ctx, writes, errs, opts) }() diff --git a/go.mod b/go.mod index c143e3a866..448f5dd6f4 100644 --- a/go.mod +++ b/go.mod @@ -57,6 +57,7 @@ require ( github.com/google/wire v0.6.0 github.com/googleapis/gax-go/v2 v2.13.0 github.com/lib/pq v1.10.9 + github.com/stretchr/testify v1.9.0 go.opencensus.io v0.24.0 golang.org/x/crypto v0.26.0 golang.org/x/net v0.28.0 @@ -95,6 +96,7 @@ require ( github.com/aws/aws-sdk-go-v2/service/ssooidc v1.26.4 // indirect github.com/aws/aws-sdk-go-v2/service/sts v1.30.3 // indirect github.com/census-instrumentation/opencensus-proto v0.4.1 // indirect + github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect github.com/felixge/httpsnoop v1.0.4 // indirect github.com/go-logr/logr v1.4.2 // indirect github.com/go-logr/stdr v1.2.2 // indirect @@ -107,6 +109,7 @@ require ( github.com/jmespath/go-jmespath v0.4.0 // indirect github.com/kylelemons/godebug v1.1.0 // indirect github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c // indirect + github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect github.com/prometheus/prometheus v0.54.0 // indirect go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.53.0 // indirect go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.53.0 // indirect @@ -120,4 +123,5 @@ require ( golang.org/x/time v0.6.0 // indirect google.golang.org/genproto/googleapis/api v0.0.0-20240812133136-8ffd90a71988 // indirect google.golang.org/genproto/googleapis/rpc v0.0.0-20240812133136-8ffd90a71988 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index 30afdc2858..4086e50e00 100644 --- a/go.sum +++ b/go.sum @@ -307,8 +307,12 @@ github.com/jstemmer/go-junit-report v0.0.0-20190106144839-af01ea7f8024/go.mod h1 github.com/jstemmer/go-junit-report v0.9.1/go.mod h1:Brl9GWCQeLvo8nXZwPNNblvFj/XSXhF0NWZEnDohbsk= github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck= github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= +github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= +github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= +github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= +github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc= github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw= github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw= @@ -322,6 +326,8 @@ github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1: github.com/prometheus/prometheus v0.54.0 h1:6+VmEkohHcofl3W5LyRlhw1Lfm575w/aX6ZFyVAmzM0= github.com/prometheus/prometheus v0.54.0/go.mod h1:xlLByHhk2g3ycakQGrMaU8K7OySZx98BzeCR99991NY= github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4= +github.com/rogpeppe/go-internal v1.12.0 h1:exVL4IDcn6na9z1rAb56Vxr+CgyK3nn3O+epU5NdKM8= +github.com/rogpeppe/go-internal v1.12.0/go.mod h1:E+RYuTGaKKdloAfM02xzb0FW3Paa99yedzYV+kq4uf4= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= @@ -745,6 +751,8 @@ google.golang.org/protobuf v1.34.2 h1:6xV6lTsCfpGD21XK49h7MhtcApnLqkfYgPcdHftf6h google.golang.org/protobuf v1.34.2/go.mod h1:qYOHts0dSfpeUzUFpOMr/WGzszTmLH+DiWniOlNbLDw= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= gopkg.in/errgo.v2 v2.1.0/go.mod h1:hNsd1EY+bozCKY1Ytp96fpM3vjJbqLJn88ws8XvfDNI= gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=