diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/DatabaseClient.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/DatabaseClient.java index f6acf04b6ca..06237131458 100644 --- a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/DatabaseClient.java +++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/DatabaseClient.java @@ -16,10 +16,12 @@ package com.google.cloud.spanner; +import com.google.api.gax.rpc.ServerStream; import com.google.cloud.Timestamp; import com.google.cloud.spanner.Options.RpcPriority; import com.google.cloud.spanner.Options.TransactionOption; import com.google.cloud.spanner.Options.UpdateOption; +import com.google.spanner.v1.BatchWriteResponse; /** * Interface for all the APIs that are used to read/write data into a Cloud Spanner database. An @@ -191,6 +193,56 @@ CommitResponse writeWithOptions(Iterable mutations, TransactionOption. CommitResponse writeAtLeastOnceWithOptions( Iterable mutations, TransactionOption... options) throws SpannerException; + /** + * Applies batch of mutation groups in a collection of efficient transactions. The mutation groups + * are applied non-atomically in an unspecified order and thus, they must be independent of each + * other. Partial failure is possible, i.e., some mutation groups may have been applied + * successfully, while some may have failed. The results of individual batches are streamed into + * the response as and when the batches are applied. + * + *

One BatchWriteResponse can contain the results for multiple MutationGroups. Inspect the + * indexes field to determine the MutationGroups that the BatchWriteResponse is for. + * + *

The mutation groups may be applied more than once. This can lead to failures if the mutation + * groups are non-idempotent. For example, an insert that is replayed can return an {@link + * ErrorCode#ALREADY_EXISTS} error. For this reason, users of the library may prefer to use {@link + * #write(Iterable)} instead. However, {@code batchWriteAtLeastOnce()} method may be appropriate + * for non-atomically committing multiple mutation groups in a single RPC with low latency. + * + *

Example of BatchWriteAtLeastOnce + * + *

{@code
+   * Iterable mutationGroups =
+   *     ImmutableList.of(
+   *         MutationGroup.of(
+   *             Mutation.newInsertBuilder("FOO1").set("ID").to(1L).set("NAME").to("Bar1").build(),
+   *             Mutation.newInsertBuilder("FOO2").set("ID").to(2L).set("NAME").to("Bar2").build()),
+   *         MutationGroup.of(
+   *             Mutation.newInsertBuilder("FOO3").set("ID").to(3L).set("NAME").to("Bar3").build(),
+   *             Mutation.newInsertBuilder("FOO4").set("ID").to(4L).set("NAME").to("Bar4").build()),
+   *         MutationGroup.of(
+   *             Mutation.newInsertBuilder("FOO4").set("ID").to(4L).set("NAME").to("Bar4").build(),
+   *             Mutation.newInsertBuilder("FOO5").set("ID").to(5L).set("NAME").to("Bar5").build()),
+   *         MutationGroup.of(
+   *             Mutation.newInsertBuilder("FOO6").set("ID").to(6L).set("NAME").to("Bar6").build()));
+   * ServerStream responses =
+   *     dbClient.batchWriteAtLeastOnce(mutationGroups, Options.tag("batch-write-tag"));
+   * for (BatchWriteResponse response : responses) {
+   *   // Do something when a response is received.
+   * }
+   * }
+ * + * Options for a transaction can include: + * + * + */ + ServerStream batchWriteAtLeastOnce( + Iterable mutationGroups, TransactionOption... options) throws SpannerException; + /** * Returns a context in which a single read can be performed using {@link TimestampBound#strong()} * concurrency. This method will return a {@link ReadContext} that will not return the read diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/DatabaseClientImpl.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/DatabaseClientImpl.java index a2a46b5a198..3835cb1f338 100644 --- a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/DatabaseClientImpl.java +++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/DatabaseClientImpl.java @@ -16,6 +16,7 @@ package com.google.cloud.spanner; +import com.google.api.gax.rpc.ServerStream; import com.google.cloud.Timestamp; import com.google.cloud.spanner.Options.TransactionOption; import com.google.cloud.spanner.Options.UpdateOption; @@ -24,6 +25,7 @@ import com.google.common.annotations.VisibleForTesting; import com.google.common.base.Function; import com.google.common.util.concurrent.ListenableFuture; +import com.google.spanner.v1.BatchWriteResponse; import io.opencensus.common.Scope; import io.opencensus.trace.Span; import io.opencensus.trace.Tracer; @@ -106,6 +108,21 @@ public CommitResponse writeAtLeastOnceWithOptions( } } + @Override + public ServerStream batchWriteAtLeastOnce( + final Iterable mutationGroups, final TransactionOption... options) + throws SpannerException { + Span span = tracer.spanBuilder(READ_WRITE_TRANSACTION).startSpan(); + try (Scope s = tracer.withSpan(span)) { + return runWithSessionRetry(session -> session.batchWriteAtLeastOnce(mutationGroups, options)); + } catch (RuntimeException e) { + TraceUtil.setWithFailure(span, e); + throw e; + } finally { + span.end(TraceUtil.END_SPAN_OPTIONS); + } + } + @Override public ReadContext singleUse() { Span span = tracer.spanBuilder(READ_ONLY_TRANSACTION).startSpan(); diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/MutationGroup.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/MutationGroup.java new file mode 100644 index 00000000000..101ffe48349 --- /dev/null +++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/MutationGroup.java @@ -0,0 +1,63 @@ +/* + * Copyright 2023 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.cloud.spanner; + +import com.google.common.base.Preconditions; +import com.google.common.collect.ImmutableList; +import com.google.spanner.v1.BatchWriteRequest; +import java.util.ArrayList; +import java.util.List; + +/** Represents a group of Cloud Spanner mutations to be committed together. */ +public class MutationGroup { + private final ImmutableList mutations; + + private MutationGroup(ImmutableList mutations) { + this.mutations = mutations; + } + + /** Creates a {@code MutationGroup} given a vararg of mutations. */ + public static MutationGroup of(Mutation... mutations) { + Preconditions.checkArgument(mutations.length > 0, "Should pass in at least one mutation."); + return new MutationGroup(ImmutableList.copyOf(mutations)); + } + + /** Creates a {@code MutationGroup} given an iterable of mutations. */ + public static MutationGroup of(Iterable mutations) { + return new MutationGroup(ImmutableList.copyOf(mutations)); + } + + /** Returns corresponding mutations for this MutationGroup. */ + public ImmutableList getMutations() { + return mutations; + } + + static BatchWriteRequest.MutationGroup toProto(final MutationGroup mutationGroup) { + List mutationsProto = new ArrayList<>(); + Mutation.toProto(mutationGroup.getMutations(), mutationsProto); + return BatchWriteRequest.MutationGroup.newBuilder().addAllMutations(mutationsProto).build(); + } + + static List toListProto( + final Iterable mutationGroups) { + List mutationGroupsProto = new ArrayList<>(); + for (MutationGroup group : mutationGroups) { + mutationGroupsProto.add(toProto(group)); + } + return mutationGroupsProto; + } +} diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/SessionImpl.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/SessionImpl.java index 1f37a849b3b..002a00134f6 100644 --- a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/SessionImpl.java +++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/SessionImpl.java @@ -21,6 +21,7 @@ import com.google.api.core.ApiFuture; import com.google.api.core.SettableApiFuture; +import com.google.api.gax.rpc.ServerStream; import com.google.cloud.Timestamp; import com.google.cloud.spanner.AbstractReadContext.MultiUseReadOnlyTransaction; import com.google.cloud.spanner.AbstractReadContext.SingleReadContext; @@ -35,6 +36,8 @@ import com.google.common.util.concurrent.MoreExecutors; import com.google.protobuf.ByteString; import com.google.protobuf.Empty; +import com.google.spanner.v1.BatchWriteRequest; +import com.google.spanner.v1.BatchWriteResponse; import com.google.spanner.v1.BeginTransactionRequest; import com.google.spanner.v1.CommitRequest; import com.google.spanner.v1.RequestOptions; @@ -160,7 +163,6 @@ public CommitResponse writeAtLeastOnceWithOptions( Iterable mutations, TransactionOption... transactionOptions) throws SpannerException { setActive(null); - Options commitRequestOptions = Options.fromTransactionOptions(transactionOptions); List mutationsProto = new ArrayList<>(); Mutation.toProto(mutations, mutationsProto); final CommitRequest.Builder requestBuilder = @@ -172,15 +174,9 @@ public CommitResponse writeAtLeastOnceWithOptions( .setSingleUseTransaction( TransactionOptions.newBuilder() .setReadWrite(TransactionOptions.ReadWrite.getDefaultInstance())); - if (commitRequestOptions.hasPriority() || commitRequestOptions.hasTag()) { - RequestOptions.Builder requestOptionsBuilder = RequestOptions.newBuilder(); - if (commitRequestOptions.hasPriority()) { - requestOptionsBuilder.setPriority(commitRequestOptions.priority()); - } - if (commitRequestOptions.hasTag()) { - requestOptionsBuilder.setTransactionTag(commitRequestOptions.tag()); - } - requestBuilder.setRequestOptions(requestOptionsBuilder.build()); + RequestOptions commitRequestOptions = getRequestOptions(transactionOptions); + if (commitRequestOptions != null) { + requestBuilder.setRequestOptions(commitRequestOptions); } CommitRequest request = requestBuilder.build(); Span span = tracer.spanBuilder(SpannerImpl.COMMIT).startSpan(); @@ -195,6 +191,45 @@ public CommitResponse writeAtLeastOnceWithOptions( } } + private RequestOptions getRequestOptions(TransactionOption... transactionOptions) { + Options requestOptions = Options.fromTransactionOptions(transactionOptions); + if (requestOptions.hasPriority() || requestOptions.hasTag()) { + RequestOptions.Builder requestOptionsBuilder = RequestOptions.newBuilder(); + if (requestOptions.hasPriority()) { + requestOptionsBuilder.setPriority(requestOptions.priority()); + } + if (requestOptions.hasTag()) { + requestOptionsBuilder.setTransactionTag(requestOptions.tag()); + } + return requestOptionsBuilder.build(); + } + return null; + } + + @Override + public ServerStream batchWriteAtLeastOnce( + Iterable mutationGroups, TransactionOption... transactionOptions) + throws SpannerException { + setActive(null); + List mutationGroupsProto = + MutationGroup.toListProto(mutationGroups); + final BatchWriteRequest.Builder requestBuilder = + BatchWriteRequest.newBuilder().setSession(name).addAllMutationGroups(mutationGroupsProto); + RequestOptions batchWriteRequestOptions = getRequestOptions(transactionOptions); + if (batchWriteRequestOptions != null) { + requestBuilder.setRequestOptions(batchWriteRequestOptions); + } + Span span = tracer.spanBuilder(SpannerImpl.BATCH_WRITE).startSpan(); + try (Scope s = tracer.withSpan(span)) { + return spanner.getRpc().batchWriteAtLeastOnce(requestBuilder.build(), this.options); + } catch (Throwable e) { + TraceUtil.setWithFailure(span, e); + throw SpannerExceptionFactory.newSpannerException(e); + } finally { + span.end(TraceUtil.END_SPAN_OPTIONS); + } + } + @Override public ReadContext singleUse() { return singleUse(TimestampBound.strong()); diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/SessionPool.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/SessionPool.java index c97fbf9b643..eb9b365f345 100644 --- a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/SessionPool.java +++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/SessionPool.java @@ -43,6 +43,7 @@ import com.google.api.core.ApiFutures; import com.google.api.core.SettableApiFuture; import com.google.api.gax.core.ExecutorProvider; +import com.google.api.gax.rpc.ServerStream; import com.google.cloud.Timestamp; import com.google.cloud.grpc.GrpcTransportOptions; import com.google.cloud.grpc.GrpcTransportOptions.ExecutorFactory; @@ -69,6 +70,7 @@ import com.google.common.util.concurrent.MoreExecutors; import com.google.common.util.concurrent.SettableFuture; import com.google.protobuf.Empty; +import com.google.spanner.v1.BatchWriteResponse; import com.google.spanner.v1.ResultSetStats; import io.opencensus.common.Scope; import io.opencensus.metrics.DerivedLongCumulative; @@ -1172,6 +1174,17 @@ public CommitResponse writeAtLeastOnceWithOptions( } } + @Override + public ServerStream batchWriteAtLeastOnce( + Iterable mutationGroups, TransactionOption... options) + throws SpannerException { + try { + return get().batchWriteAtLeastOnce(mutationGroups, options); + } finally { + close(); + } + } + @Override public ReadContext singleUse() { try { @@ -1465,6 +1478,18 @@ public CommitResponse writeAtLeastOnceWithOptions( } } + @Override + public ServerStream batchWriteAtLeastOnce( + Iterable mutationGroups, TransactionOption... options) + throws SpannerException { + try { + markUsed(); + return delegate.batchWriteAtLeastOnce(mutationGroups, options); + } catch (SpannerException e) { + throw lastException = e; + } + } + @Override public long executePartitionedUpdate(Statement stmt, UpdateOption... options) throws SpannerException { diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/SpannerImpl.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/SpannerImpl.java index 721be9cd762..5ff916bbe93 100644 --- a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/SpannerImpl.java +++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/SpannerImpl.java @@ -66,6 +66,7 @@ class SpannerImpl extends BaseService implements Spanner { static final String COMMIT = "CloudSpannerOperation.Commit"; static final String QUERY = "CloudSpannerOperation.ExecuteStreamingQuery"; static final String READ = "CloudSpannerOperation.ExecuteStreamingRead"; + static final String BATCH_WRITE = "CloudSpannerOperation.BatchWrite"; private static final Object CLIENT_ID_LOCK = new Object(); diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/GapicSpannerRpc.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/GapicSpannerRpc.java index 3e6fc5fbcb8..bdf038f0be7 100644 --- a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/GapicSpannerRpc.java +++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/GapicSpannerRpc.java @@ -157,6 +157,8 @@ import com.google.spanner.admin.instance.v1.UpdateInstanceMetadata; import com.google.spanner.admin.instance.v1.UpdateInstanceRequest; import com.google.spanner.v1.BatchCreateSessionsRequest; +import com.google.spanner.v1.BatchWriteRequest; +import com.google.spanner.v1.BatchWriteResponse; import com.google.spanner.v1.BeginTransactionRequest; import com.google.spanner.v1.CommitRequest; import com.google.spanner.v1.CommitResponse; @@ -1684,6 +1686,14 @@ public ServerStream executeStreamingPartitionedDml( return partitionedDmlStub.executeStreamingSqlCallable().call(request, context); } + @Override + public ServerStream batchWriteAtLeastOnce( + BatchWriteRequest request, @Nullable Map options) { + GrpcCallContext context = + newCallContext(options, request.getSession(), request, SpannerGrpc.getBatchWriteMethod()); + return spannerStub.batchWriteCallable().call(request, context); + } + @Override public StreamingCall executeQuery( ExecuteSqlRequest request, diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/SpannerRpc.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/SpannerRpc.java index 53b4b97764f..39f798e0113 100644 --- a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/SpannerRpc.java +++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/SpannerRpc.java @@ -403,6 +403,9 @@ ApiFuture executeQueryAsync( ServerStream executeStreamingPartitionedDml( ExecuteSqlRequest request, @Nullable Map options, Duration timeout); + ServerStream batchWriteAtLeastOnce( + BatchWriteRequest request, @Nullable Map options); + /** * Executes a query with streaming result. * diff --git a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/DatabaseClientImplTest.java b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/DatabaseClientImplTest.java index feb0f4b23c6..686efd3aff4 100644 --- a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/DatabaseClientImplTest.java +++ b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/DatabaseClientImplTest.java @@ -40,6 +40,7 @@ import com.google.api.gax.grpc.testing.LocalChannelProvider; import com.google.api.gax.retrying.RetrySettings; import com.google.api.gax.rpc.ApiCallContext; +import com.google.api.gax.rpc.ServerStream; import com.google.api.gax.rpc.StatusCode; import com.google.cloud.ByteArray; import com.google.cloud.NoCredentials; @@ -71,6 +72,8 @@ import com.google.protobuf.ListValue; import com.google.protobuf.NullValue; import com.google.rpc.RetryInfo; +import com.google.spanner.v1.BatchWriteRequest; +import com.google.spanner.v1.BatchWriteResponse; import com.google.spanner.v1.CommitRequest; import com.google.spanner.v1.DeleteSessionRequest; import com.google.spanner.v1.ExecuteBatchDmlRequest; @@ -142,6 +145,31 @@ public class DatabaseClientImplTest { private static final Statement INVALID_UPDATE_STATEMENT = Statement.of("UPDATE NON_EXISTENT_TABLE SET BAR=1 WHERE BAZ=2"); private static final long UPDATE_COUNT = 1L; + private static final com.google.rpc.Status STATUS_OK = + com.google.rpc.Status.newBuilder().setCode(com.google.rpc.Code.OK_VALUE).build(); + private static final Iterable MUTATION_GROUPS = + ImmutableList.of( + MutationGroup.of( + Mutation.newInsertBuilder("FOO1").set("ID").to(1L).set("NAME").to("Bar1").build(), + Mutation.newInsertBuilder("FOO2").set("ID").to(2L).set("NAME").to("Bar2").build()), + MutationGroup.of( + Mutation.newInsertBuilder("FOO3").set("ID").to(3L).set("NAME").to("Bar3").build(), + Mutation.newInsertBuilder("FOO4").set("ID").to(4L).set("NAME").to("Bar4").build()), + MutationGroup.of( + Mutation.newInsertBuilder("FOO4").set("ID").to(4L).set("NAME").to("Bar4").build(), + Mutation.newInsertBuilder("FOO5").set("ID").to(5L).set("NAME").to("Bar5").build()), + MutationGroup.of( + Mutation.newInsertBuilder("FOO6").set("ID").to(6L).set("NAME").to("Bar6").build())); + private static final Iterable BATCH_WRITE_RESPONSES = + ImmutableList.of( + BatchWriteResponse.newBuilder() + .setStatus(STATUS_OK) + .addAllIndexes(ImmutableList.of(0, 1)) + .build(), + BatchWriteResponse.newBuilder() + .setStatus(STATUS_OK) + .addAllIndexes(ImmutableList.of(2, 3)) + .build()); private Spanner spanner; private Spanner spannerWithEmptySessionPool; private static ExecutorService executor; @@ -159,6 +187,7 @@ public static void startStaticServer() throws IOException { StatementResult.exception( INVALID_UPDATE_STATEMENT, Status.INVALID_ARGUMENT.withDescription("invalid statement").asRuntimeException())); + mockSpanner.setBatchWriteResult(BATCH_WRITE_RESPONSES); executor = Executors.newSingleThreadExecutor(); String uniqueName = InProcessServerBuilder.generateName(); @@ -645,7 +674,7 @@ public void testWriteAtLeastOnceWithOptions() { } @Test - public void writeAtLeastOnceWithTagOptions() { + public void testWriteAtLeastOnceWithTagOptions() { DatabaseClient client = spanner.getDatabaseClient(DatabaseId.of(TEST_PROJECT, TEST_INSTANCE, TEST_DATABASE)); client.writeAtLeastOnceWithOptions( @@ -663,6 +692,62 @@ public void writeAtLeastOnceWithTagOptions() { assertThat(commit.getRequestOptions().getRequestTag()).isEmpty(); } + @Test + public void testBatchWriteAtLeastOnceWithoutOptions() { + DatabaseClient client = + spanner.getDatabaseClient(DatabaseId.of(TEST_PROJECT, TEST_INSTANCE, TEST_DATABASE)); + + ServerStream responseStream = client.batchWriteAtLeastOnce(MUTATION_GROUPS); + int idx = 0; + for (BatchWriteResponse response : responseStream) { + assertEquals( + response.getStatus(), + com.google.rpc.Status.newBuilder().setCode(com.google.rpc.Code.OK_VALUE).build()); + assertEquals(response.getIndexesList(), ImmutableList.of(idx, idx + 1)); + idx += 2; + } + + assertNotNull(responseStream); + List requests = mockSpanner.getRequestsOfType(BatchWriteRequest.class); + assertEquals(requests.size(), 1); + BatchWriteRequest request = requests.get(0); + assertEquals(request.getMutationGroupsCount(), 4); + assertEquals(request.getRequestOptions().getPriority(), Priority.PRIORITY_UNSPECIFIED); + } + + @Test + public void testBatchWriteAtLeastOnceWithOptions() { + DatabaseClient client = + spanner.getDatabaseClient(DatabaseId.of(TEST_PROJECT, TEST_INSTANCE, TEST_DATABASE)); + ServerStream responseStream = + client.batchWriteAtLeastOnce(MUTATION_GROUPS, Options.priority(RpcPriority.LOW)); + for (BatchWriteResponse response : responseStream) {} + + assertNotNull(responseStream); + List requests = mockSpanner.getRequestsOfType(BatchWriteRequest.class); + assertEquals(requests.size(), 1); + BatchWriteRequest request = requests.get(0); + assertEquals(request.getMutationGroupsCount(), 4); + assertEquals(request.getRequestOptions().getPriority(), Priority.PRIORITY_LOW); + } + + @Test + public void testBatchWriteAtLeastOnceWithTagOptions() { + DatabaseClient client = + spanner.getDatabaseClient(DatabaseId.of(TEST_PROJECT, TEST_INSTANCE, TEST_DATABASE)); + ServerStream responseStream = + client.batchWriteAtLeastOnce(MUTATION_GROUPS, Options.tag("app=spanner,env=test")); + for (BatchWriteResponse response : responseStream) {} + + assertNotNull(responseStream); + List requests = mockSpanner.getRequestsOfType(BatchWriteRequest.class); + assertEquals(requests.size(), 1); + BatchWriteRequest request = requests.get(0); + assertEquals(request.getMutationGroupsCount(), 4); + assertEquals(request.getRequestOptions().getTransactionTag(), "app=spanner,env=test"); + assertThat(request.getRequestOptions().getRequestTag()).isEmpty(); + } + @Test public void testExecuteQueryWithTag() { DatabaseClient client = diff --git a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/MockSpannerServiceImpl.java b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/MockSpannerServiceImpl.java index c742b008360..7bf9f51a4ea 100644 --- a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/MockSpannerServiceImpl.java +++ b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/MockSpannerServiceImpl.java @@ -42,6 +42,8 @@ import com.google.rpc.RetryInfo; import com.google.spanner.v1.BatchCreateSessionsRequest; import com.google.spanner.v1.BatchCreateSessionsResponse; +import com.google.spanner.v1.BatchWriteRequest; +import com.google.spanner.v1.BatchWriteResponse; import com.google.spanner.v1.BeginTransactionRequest; import com.google.spanner.v1.CommitRequest; import com.google.spanner.v1.CommitResponse; @@ -598,6 +600,7 @@ private static void checkStreamException( private ConcurrentMap transactionLastUsed = new ConcurrentHashMap<>(); private int maxNumSessionsInOneBatch = 100; private int maxTotalSessions = Integer.MAX_VALUE; + private Iterable batchWriteResult = new ArrayList<>(); private AtomicInteger numSessionsCreated = new AtomicInteger(); private SimulatedExecutionTime beginTransactionExecutionTime = NO_EXECUTION_TIME; private SimulatedExecutionTime commitExecutionTime = NO_EXECUTION_TIME; @@ -678,6 +681,12 @@ public void putPartialStatementResult(StatementResult result) { } } + public void setBatchWriteResult(final Iterable responses) { + synchronized (lock) { + this.batchWriteResult = responses; + } + } + private StatementResult getResult(Statement statement) { StatementResult res; synchronized (lock) { @@ -1962,6 +1971,29 @@ public void commit(CommitRequest request, StreamObserver respons } } + @Override + public void batchWrite( + BatchWriteRequest request, StreamObserver responseObserver) { + requests.add(request); + Preconditions.checkNotNull(request.getSession()); + Session session = sessions.get(request.getSession()); + if (session == null) { + setSessionNotFound(request.getSession(), responseObserver); + return; + } + sessionLastUsed.put(session.getName(), Instant.now()); + try { + for (BatchWriteResponse response : batchWriteResult) { + responseObserver.onNext(response); + } + responseObserver.onCompleted(); + } catch (StatusRuntimeException t) { + responseObserver.onError(t); + } catch (Throwable t) { + responseObserver.onError(Status.INTERNAL.asRuntimeException()); + } + } + private void commitTransaction(ByteString transactionId) { transactions.remove(transactionId); isPartitionedDmlTransaction.remove(transactionId); diff --git a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/MutationGroupTest.java b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/MutationGroupTest.java new file mode 100644 index 00000000000..99fcb8b6943 --- /dev/null +++ b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/MutationGroupTest.java @@ -0,0 +1,114 @@ +/* + * Copyright 2023 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.cloud.spanner; + +import static org.junit.Assert.assertArrayEquals; +import static org.junit.Assert.assertEquals; + +import com.google.common.collect.ImmutableList; +import com.google.spanner.v1.BatchWriteRequest; +import java.util.ArrayList; +import java.util.List; +import java.util.Random; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** Unit tests for {@link MutationGroup}. */ +@RunWith(JUnit4.class) +public class MutationGroupTest { + private final Random random = new Random(); + + private Mutation getRandomMutation() { + return Mutation.newInsertBuilder(String.valueOf(random.nextInt())) + .set("ID") + .to(random.nextInt()) + .set("NAME") + .to(String.valueOf(random.nextInt())) + .build(); + } + + private BatchWriteRequest.MutationGroup getMutationGroupProto(ImmutableList mutations) { + List mutationsProto = new ArrayList<>(); + Mutation.toProto(mutations, mutationsProto); + return BatchWriteRequest.MutationGroup.newBuilder().addAllMutations(mutationsProto).build(); + } + + @Test + public void ofVarargTest() { + Mutation[] mutations = + new Mutation[] { + getRandomMutation(), getRandomMutation(), getRandomMutation(), getRandomMutation() + }; + MutationGroup mutationGroup = MutationGroup.of(mutations); + assertArrayEquals(mutations, mutationGroup.getMutations().toArray()); + assertEquals( + MutationGroup.toProto(mutationGroup), + getMutationGroupProto(ImmutableList.copyOf(mutations))); + } + + @Test + public void ofIterableTest() { + ImmutableList mutations = + ImmutableList.of( + getRandomMutation(), getRandomMutation(), getRandomMutation(), getRandomMutation()); + MutationGroup mutationGroup = MutationGroup.of(mutations); + assertEquals(mutations, mutationGroup.getMutations()); + assertEquals(MutationGroup.toProto(mutationGroup), getMutationGroupProto(mutations)); + } + + @Test + public void toProtoTest() { + Mutation[] mutations = + new Mutation[] { + getRandomMutation(), getRandomMutation(), getRandomMutation(), getRandomMutation() + }; + MutationGroup mutationGroup = MutationGroup.of(mutations); + assertEquals( + MutationGroup.toProto(mutationGroup), + getMutationGroupProto(ImmutableList.copyOf(mutations))); + } + + @Test + public void toListProtoTest() { + Mutation[] mutations1 = + new Mutation[] { + getRandomMutation(), getRandomMutation(), getRandomMutation(), getRandomMutation() + }; + Mutation[] mutations2 = + new Mutation[] { + getRandomMutation(), getRandomMutation(), getRandomMutation(), getRandomMutation() + }; + Mutation[] mutations3 = + new Mutation[] { + getRandomMutation(), getRandomMutation(), getRandomMutation(), getRandomMutation() + }; + List mutationGroups = + ImmutableList.of( + MutationGroup.of(mutations1), + MutationGroup.of(mutations2), + MutationGroup.of(mutations3)); + List mutationGroupsProto = + MutationGroup.toListProto(mutationGroups); + assertEquals( + mutationGroupsProto.get(0), getMutationGroupProto(ImmutableList.copyOf(mutations1))); + assertEquals( + mutationGroupsProto.get(1), getMutationGroupProto(ImmutableList.copyOf(mutations2))); + assertEquals( + mutationGroupsProto.get(2), getMutationGroupProto(ImmutableList.copyOf(mutations3))); + } +} diff --git a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/it/ITWriteTest.java b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/it/ITWriteTest.java index 020beb68ef9..1888bf6b9e5 100644 --- a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/it/ITWriteTest.java +++ b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/it/ITWriteTest.java @@ -25,10 +25,12 @@ import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; import static org.junit.Assume.assumeFalse; import static org.junit.Assume.assumeTrue; +import com.google.api.gax.rpc.ServerStream; import com.google.cloud.ByteArray; import com.google.cloud.Date; import com.google.cloud.Timestamp; @@ -41,6 +43,7 @@ import com.google.cloud.spanner.Key; import com.google.cloud.spanner.KeySet; import com.google.cloud.spanner.Mutation; +import com.google.cloud.spanner.MutationGroup; import com.google.cloud.spanner.Options; import com.google.cloud.spanner.ParallelIntegrationTest; import com.google.cloud.spanner.ResultSet; @@ -53,6 +56,9 @@ import com.google.cloud.spanner.testing.EmulatorSpannerHelper; import com.google.common.collect.ImmutableList; import com.google.protobuf.NullValue; +import com.google.rpc.Code; +import com.google.rpc.Status; +import com.google.spanner.v1.BatchWriteResponse; import io.grpc.Context; import java.math.BigDecimal; import java.util.ArrayList; @@ -60,9 +66,11 @@ import java.util.Base64; import java.util.Collections; import java.util.HashMap; +import java.util.HashSet; import java.util.List; import java.util.Map; import java.util.Random; +import java.util.Set; import java.util.concurrent.ExecutionException; import java.util.concurrent.Executors; import java.util.concurrent.ScheduledExecutorService; @@ -100,51 +108,65 @@ public static List data() { private static DatabaseClient googleStandardSQLClient; private static DatabaseClient postgreSQLClient; - private static final String GOOGLE_STANDARD_SQL_SCHEMA = - "CREATE TABLE T (" - + " K STRING(MAX) NOT NULL," - + " BoolValue BOOL," - + " Int64Value INT64," - + " Float64Value FLOAT64," - + " StringValue STRING(MAX)," - + " JsonValue JSON," - + " BytesValue BYTES(MAX)," - + " TimestampValue TIMESTAMP OPTIONS (allow_commit_timestamp = true)," - + " DateValue DATE," - + " NumericValue NUMERIC," - + " BoolArrayValue ARRAY," - + " Int64ArrayValue ARRAY," - + " Float64ArrayValue ARRAY," - + " StringArrayValue ARRAY," - + " JsonArrayValue ARRAY," - + " BytesArrayValue ARRAY," - + " TimestampArrayValue ARRAY," - + " DateArrayValue ARRAY," - + " NumericArrayValue ARRAY," - + ") PRIMARY KEY (K)"; - - private static final String POSTGRESQL_SCHEMA = - "CREATE TABLE T (" - + " K VARCHAR PRIMARY KEY," - + " BoolValue BOOL," - + " Int64Value BIGINT," - + " Float64Value DOUBLE PRECISION," - + " StringValue VARCHAR," - + " JsonValue JSONB," - + " BytesValue BYTEA," - + " TimestampValue TIMESTAMPTZ," - + " DateValue DATE," - + " NumericValue NUMERIC," - + " BoolArrayValue BOOL[]," - + " Int64ArrayValue BIGINT[]," - + " Float64ArrayValue DOUBLE PRECISION[]," - + " StringArrayValue VARCHAR[]," - + " JsonArrayValue JSONB[]," - + " BytesArrayValue BYTEA[]," - + " TimestampArrayValue TIMESTAMPTZ[]," - + " DateArrayValue DATE[]," - + " NumericArrayValue NUMERIC[]" - + ")"; + private static final String[] GOOGLE_STANDARD_SQL_SCHEMA = + new String[] { + "CREATE TABLE T (" + + " K STRING(MAX) NOT NULL," + + " BoolValue BOOL," + + " Int64Value INT64," + + " Float64Value FLOAT64," + + " StringValue STRING(MAX)," + + " JsonValue JSON," + + " BytesValue BYTES(MAX)," + + " TimestampValue TIMESTAMP OPTIONS (allow_commit_timestamp = true)," + + " DateValue DATE," + + " NumericValue NUMERIC," + + " BoolArrayValue ARRAY," + + " Int64ArrayValue ARRAY," + + " Float64ArrayValue ARRAY," + + " StringArrayValue ARRAY," + + " JsonArrayValue ARRAY," + + " BytesArrayValue ARRAY," + + " TimestampArrayValue ARRAY," + + " DateArrayValue ARRAY," + + " NumericArrayValue ARRAY," + + ") PRIMARY KEY (K)", + "CREATE TABLE T1 (" + + " K1 STRING(MAX) NOT NULL," + + " K STRING(MAX) NOT NULL," + + " CONSTRAINT FK FOREIGN KEY (K) REFERENCES T(K)" + + ") PRIMARY KEY (K1)" + }; + + private static final String[] POSTGRESQL_SCHEMA = + new String[] { + "CREATE TABLE T (" + + " K VARCHAR PRIMARY KEY," + + " BoolValue BOOL," + + " Int64Value BIGINT," + + " Float64Value DOUBLE PRECISION," + + " StringValue VARCHAR," + + " JsonValue JSONB," + + " BytesValue BYTEA," + + " TimestampValue TIMESTAMPTZ," + + " DateValue DATE," + + " NumericValue NUMERIC," + + " BoolArrayValue BOOL[]," + + " Int64ArrayValue BIGINT[]," + + " Float64ArrayValue DOUBLE PRECISION[]," + + " StringArrayValue VARCHAR[]," + + " JsonArrayValue JSONB[]," + + " BytesArrayValue BYTEA[]," + + " TimestampArrayValue TIMESTAMPTZ[]," + + " DateArrayValue DATE[]," + + " NumericArrayValue NUMERIC[]" + + ")", + "CREATE TABLE T1 (" + + " K1 VARCHAR PRIMARY KEY," + + " K VARCHAR," + + " CONSTRAINT FK FOREIGN KEY (K) REFERENCES T(K)" + + ")" + }; /** Sequence used to generate unique keys. */ private static int seq; @@ -161,7 +183,7 @@ public static void setUpDatabase() if (!EmulatorSpannerHelper.isUsingEmulator()) { Database postgreSQLDatabase = env.getTestHelper() - .createTestDatabase(Dialect.POSTGRESQL, Collections.singletonList(POSTGRESQL_SCHEMA)); + .createTestDatabase(Dialect.POSTGRESQL, Arrays.asList(POSTGRESQL_SCHEMA)); postgreSQLClient = env.getTestHelper().getDatabaseClient(postgreSQLDatabase); } } @@ -192,9 +214,13 @@ private Mutation.WriteBuilder baseInsert() { } private Struct readLastRow(String... columns) { + return readRow("T", lastKey, columns); + } + + private Struct readRow(String table, String key, String... columns) { return client .singleUse(TimestampBound.strong()) - .readRow("T", Key.of(lastKey), Arrays.asList(columns)); + .readRow(table, Key.of(key), Arrays.asList(columns)); } @Test @@ -212,6 +238,71 @@ public void writeAtLeastOnce() { assertThat(row.getString(0)).isEqualTo("v1"); } + @Test + public void batchWriteAtLeastOnce() { + assumeFalse("Emulator does not support BatchWriteAtLeastOnce", isUsingEmulator()); + final String k1 = uniqueString(), k2 = uniqueString(), k3 = uniqueString(), k4 = uniqueString(); + lastKey = k3; + final List mutationGroups = + ImmutableList.of( + MutationGroup.of( + Mutation.newInsertOrUpdateBuilder("T") + .set("K") + .to(k1) + .set("StringValue") + .to("v1") + .set("BoolValue") + .to(true) + .build(), + Mutation.newInsertOrUpdateBuilder("T") + .set("K") + .to(k2) + .set("StringValue") + .to("v2") + .build()), + MutationGroup.of( + Mutation.newInsertOrUpdateBuilder("T") + .set("K") + .to(k3) + .set("StringValue") + .to("v1") + .set("BoolValue") + .to(false) + .build(), + Mutation.newInsertOrUpdateBuilder("T1").set("K1").to(k4).set("K").to(k3).build())); + ServerStream responses = client.batchWriteAtLeastOnce(mutationGroups); + Set responseIndexes = new HashSet<>(); + Set appliedMutationIndexes = new HashSet<>(); + for (BatchWriteResponse response : responses) { + responseIndexes.addAll(response.getIndexesList()); + if (response.getStatus().equals(Status.newBuilder().setCode(Code.OK_VALUE).build())) { + appliedMutationIndexes.addAll(response.getIndexesList()); + assertNotNull(response.getCommitTimestamp()); + } + } + assertEquals(responseIndexes, new HashSet<>(Arrays.asList(0, 1))); + + Struct row; + // assert row with key k1 + if (appliedMutationIndexes.contains(0)) { + row = readRow("T", k1, "StringValue", "BoolValue"); + assertEquals(row.getString(0), "v1"); + assertTrue(row.getBoolean(1)); + row = readRow("T", k2, "StringValue", "BoolValue"); + assertEquals(row.getString(0), "v2"); + assertTrue(row.isNull(1)); + } + + // assert row with key k4, and corresponding referencing table. + if (appliedMutationIndexes.contains(1)) { + row = readRow("T", k3, "StringValue", "BoolValue"); + assertEquals(row.getString(0), "v1"); + assertFalse(row.getBoolean(1)); + row = readRow("T1", k4, "K"); + assertEquals(row.getString(0), k3); + } + } + @Test public void testWriteReturnsCommitStats() { assumeFalse("Emulator does not return commit statistics", isUsingEmulator());