diff --git a/google-cloud-bigquerystorage/src/main/java/com/google/cloud/bigquery/storage/v1/ConnectionWorker.java b/google-cloud-bigquerystorage/src/main/java/com/google/cloud/bigquery/storage/v1/ConnectionWorker.java index 19d8911ee8..835d411714 100644 --- a/google-cloud-bigquerystorage/src/main/java/com/google/cloud/bigquery/storage/v1/ConnectionWorker.java +++ b/google-cloud-bigquerystorage/src/main/java/com/google/cloud/bigquery/storage/v1/ConnectionWorker.java @@ -18,6 +18,7 @@ import com.google.api.core.ApiFuture; import com.google.api.core.SettableApiFuture; import com.google.api.gax.batching.FlowController; +import com.google.api.gax.rpc.FixedHeaderProvider; import com.google.auto.value.AutoValue; import com.google.cloud.bigquery.storage.v1.AppendRowsRequest.ProtoData; import com.google.cloud.bigquery.storage.v1.Exceptions.AppendSerializtionError; @@ -77,6 +78,11 @@ class ConnectionWorker implements AutoCloseable { */ private String streamName; + /* + * The location of this connection. + */ + private String location = null; + /* * The proto schema of rows to write. This schema can change during multiplexing. */ @@ -211,6 +217,7 @@ public static long getApiMaxRequestBytes() { public ConnectionWorker( String streamName, + String location, ProtoSchema writerSchema, long maxInflightRequests, long maxInflightBytes, @@ -223,6 +230,9 @@ public ConnectionWorker( this.hasMessageInWaitingQueue = lock.newCondition(); this.inflightReduced = lock.newCondition(); this.streamName = streamName; + if (location != null && !location.isEmpty()) { + this.location = location; + } this.maxRetryDuration = maxRetryDuration; if (writerSchema == null) { throw new StatusRuntimeException( @@ -236,6 +246,18 @@ public ConnectionWorker( this.waitingRequestQueue = new LinkedList(); this.inflightRequestQueue = new LinkedList(); // Always recreate a client for connection worker. + HashMap newHeaders = new HashMap<>(); + newHeaders.putAll(clientSettings.toBuilder().getHeaderProvider().getHeaders()); + if (this.location == null) { + newHeaders.put("x-goog-request-params", "write_stream=" + this.streamName); + } else { + newHeaders.put("x-goog-request-params", "write_location=" + this.location); + } + BigQueryWriteSettings stubSettings = + clientSettings + .toBuilder() + .setHeaderProvider(FixedHeaderProvider.create(newHeaders)) + .build(); this.client = BigQueryWriteClient.create(clientSettings); this.appendThread = @@ -297,6 +319,24 @@ public void run(Throwable finalStatus) { /** Schedules the writing of rows at given offset. */ ApiFuture append(StreamWriter streamWriter, ProtoRows rows, long offset) { + if (this.location != null && this.location != streamWriter.getLocation()) { + throw new StatusRuntimeException( + Status.fromCode(Code.INVALID_ARGUMENT) + .withDescription( + "StreamWriter with location " + + streamWriter.getLocation() + + " is scheduled to use a connection with location " + + this.location)); + } else if (this.location == null && streamWriter.getStreamName() != this.streamName) { + // Location is null implies this is non-multiplexed connection. + throw new StatusRuntimeException( + Status.fromCode(Code.INVALID_ARGUMENT) + .withDescription( + "StreamWriter with stream name " + + streamWriter.getStreamName() + + " is scheduled to use a connection with stream name " + + this.streamName)); + } Preconditions.checkNotNull(streamWriter); AppendRowsRequest.Builder requestBuilder = AppendRowsRequest.newBuilder(); requestBuilder.setProtoRows( @@ -322,6 +362,10 @@ Boolean isUserClosed() { } } + String getWriteLocation() { + return this.location; + } + private ApiFuture appendInternal( StreamWriter streamWriter, AppendRowsRequest message) { AppendRequestAndResponse requestWrapper = new AppendRequestAndResponse(message, streamWriter); diff --git a/google-cloud-bigquerystorage/src/main/java/com/google/cloud/bigquery/storage/v1/ConnectionWorkerPool.java b/google-cloud-bigquerystorage/src/main/java/com/google/cloud/bigquery/storage/v1/ConnectionWorkerPool.java index 8fcb84165e..83be8ce52a 100644 --- a/google-cloud-bigquerystorage/src/main/java/com/google/cloud/bigquery/storage/v1/ConnectionWorkerPool.java +++ b/google-cloud-bigquerystorage/src/main/java/com/google/cloud/bigquery/storage/v1/ConnectionWorkerPool.java @@ -288,7 +288,8 @@ private ConnectionWorker createOrReuseConnectionWorker( String streamReference = streamWriter.getStreamName(); if (connectionWorkerPool.size() < currentMaxConnectionCount) { // Always create a new connection if we haven't reached current maximum. - return createConnectionWorker(streamWriter.getStreamName(), streamWriter.getProtoSchema()); + return createConnectionWorker( + streamWriter.getStreamName(), streamWriter.getLocation(), streamWriter.getProtoSchema()); } else { ConnectionWorker existingBestConnection = pickBestLoadConnection( @@ -304,7 +305,10 @@ private ConnectionWorker createOrReuseConnectionWorker( if (currentMaxConnectionCount > settings.maxConnectionsPerRegion()) { currentMaxConnectionCount = settings.maxConnectionsPerRegion(); } - return createConnectionWorker(streamWriter.getStreamName(), streamWriter.getProtoSchema()); + return createConnectionWorker( + streamWriter.getStreamName(), + streamWriter.getLocation(), + streamWriter.getProtoSchema()); } else { // Stick to the original connection if all the connections are overwhelmed. if (existingConnectionWorker != null) { @@ -359,8 +363,8 @@ static ConnectionWorker pickBestLoadConnection( * a single stream reference. This is because createConnectionWorker(...) is called via * computeIfAbsent(...) which is at most once per key. */ - private ConnectionWorker createConnectionWorker(String streamName, ProtoSchema writeSchema) - throws IOException { + private ConnectionWorker createConnectionWorker( + String streamName, String location, ProtoSchema writeSchema) throws IOException { if (enableTesting) { // Though atomic integer is super lightweight, add extra if check in case adding future logic. testValueCreateConnectionCount.getAndIncrement(); @@ -368,6 +372,7 @@ private ConnectionWorker createConnectionWorker(String streamName, ProtoSchema w ConnectionWorker connectionWorker = new ConnectionWorker( streamName, + location, writeSchema, maxInflightRequests, maxInflightBytes, diff --git a/google-cloud-bigquerystorage/src/main/java/com/google/cloud/bigquery/storage/v1/StreamWriter.java b/google-cloud-bigquerystorage/src/main/java/com/google/cloud/bigquery/storage/v1/StreamWriter.java index ffc1290a78..b21a52a63d 100644 --- a/google-cloud-bigquerystorage/src/main/java/com/google/cloud/bigquery/storage/v1/StreamWriter.java +++ b/google-cloud-bigquerystorage/src/main/java/com/google/cloud/bigquery/storage/v1/StreamWriter.java @@ -208,6 +208,7 @@ private StreamWriter(Builder builder) throws IOException { SingleConnectionOrConnectionPool.ofSingleConnection( new ConnectionWorker( builder.streamName, + builder.location, builder.writerSchema, builder.maxInflightRequest, builder.maxInflightBytes, diff --git a/google-cloud-bigquerystorage/src/test/java/com/google/cloud/bigquery/storage/v1/ConnectionWorkerPoolTest.java b/google-cloud-bigquerystorage/src/test/java/com/google/cloud/bigquery/storage/v1/ConnectionWorkerPoolTest.java index 980772b2ff..e558d567c8 100644 --- a/google-cloud-bigquerystorage/src/test/java/com/google/cloud/bigquery/storage/v1/ConnectionWorkerPoolTest.java +++ b/google-cloud-bigquerystorage/src/test/java/com/google/cloud/bigquery/storage/v1/ConnectionWorkerPoolTest.java @@ -430,6 +430,7 @@ private StreamWriter getTestStreamWriter(String streamName) throws IOException { return StreamWriter.newBuilder(streamName) .setWriterSchema(createProtoSchema()) .setTraceId(TEST_TRACE_ID) + .setLocation("us") .setCredentialsProvider(NoCredentialsProvider.create()) .setChannelProvider(serviceHelper.createChannelProvider()) .build(); diff --git a/google-cloud-bigquerystorage/src/test/java/com/google/cloud/bigquery/storage/v1/ConnectionWorkerTest.java b/google-cloud-bigquerystorage/src/test/java/com/google/cloud/bigquery/storage/v1/ConnectionWorkerTest.java index fbd0850ee0..13711bddd0 100644 --- a/google-cloud-bigquerystorage/src/test/java/com/google/cloud/bigquery/storage/v1/ConnectionWorkerTest.java +++ b/google-cloud-bigquerystorage/src/test/java/com/google/cloud/bigquery/storage/v1/ConnectionWorkerTest.java @@ -39,6 +39,7 @@ import java.util.List; import java.util.UUID; import java.util.concurrent.ExecutionException; +import java.util.logging.Logger; import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; @@ -46,6 +47,7 @@ @RunWith(JUnit4.class) public class ConnectionWorkerTest { + private static final Logger log = Logger.getLogger(StreamWriter.class.getName()); private static final String TEST_STREAM_1 = "projects/p1/datasets/d1/tables/t1/streams/s1"; private static final String TEST_STREAM_2 = "projects/p2/datasets/d2/tables/t2/streams/s2"; private static final String TEST_TRACE_ID = "DATAFLOW:job_id"; @@ -84,10 +86,12 @@ public void testMultiplexedAppendSuccess() throws Exception { StreamWriter sw1 = StreamWriter.newBuilder(TEST_STREAM_1, client) .setWriterSchema(createProtoSchema("foo")) + .setLocation("us") .build(); StreamWriter sw2 = StreamWriter.newBuilder(TEST_STREAM_2, client) .setWriterSchema(createProtoSchema("complicate")) + .setLocation("us") .build(); // We do a pattern of: // send to stream1, string1 @@ -205,11 +209,20 @@ public void testAppendInSameStream_switchSchema() throws Exception { // send to stream1, schema1 // ... StreamWriter sw1 = - StreamWriter.newBuilder(TEST_STREAM_1, client).setWriterSchema(schema1).build(); + StreamWriter.newBuilder(TEST_STREAM_1, client) + .setLocation("us") + .setWriterSchema(schema1) + .build(); StreamWriter sw2 = - StreamWriter.newBuilder(TEST_STREAM_1, client).setWriterSchema(schema2).build(); + StreamWriter.newBuilder(TEST_STREAM_1, client) + .setLocation("us") + .setWriterSchema(schema2) + .build(); StreamWriter sw3 = - StreamWriter.newBuilder(TEST_STREAM_1, client).setWriterSchema(schema3).build(); + StreamWriter.newBuilder(TEST_STREAM_1, client) + .setLocation("us") + .setWriterSchema(schema3) + .build(); for (long i = 0; i < appendCount; i++) { switch ((int) i % 4) { case 0: @@ -305,10 +318,14 @@ public void testAppendInSameStream_switchSchema() throws Exception { public void testAppendButInflightQueueFull() throws Exception { ProtoSchema schema1 = createProtoSchema("foo"); StreamWriter sw1 = - StreamWriter.newBuilder(TEST_STREAM_1, client).setWriterSchema(schema1).build(); + StreamWriter.newBuilder(TEST_STREAM_1, client) + .setLocation("us") + .setWriterSchema(schema1) + .build(); ConnectionWorker connectionWorker = new ConnectionWorker( TEST_STREAM_1, + "us", createProtoSchema("foo"), 6, 100000, @@ -356,10 +373,14 @@ public void testAppendButInflightQueueFull() throws Exception { public void testThrowExceptionWhileWithinAppendLoop() throws Exception { ProtoSchema schema1 = createProtoSchema("foo"); StreamWriter sw1 = - StreamWriter.newBuilder(TEST_STREAM_1, client).setWriterSchema(schema1).build(); + StreamWriter.newBuilder(TEST_STREAM_1, client) + .setLocation("us") + .setWriterSchema(schema1) + .build(); ConnectionWorker connectionWorker = new ConnectionWorker( TEST_STREAM_1, + "us", createProtoSchema("foo"), 100000, 100000, @@ -411,6 +432,69 @@ public void testThrowExceptionWhileWithinAppendLoop() throws Exception { assertThat(ex.getCause()).hasMessageThat().contains("Any exception can happen."); } + @Test + public void testLocationMismatch() throws Exception { + ProtoSchema schema1 = createProtoSchema("foo"); + StreamWriter sw1 = + StreamWriter.newBuilder(TEST_STREAM_1, client) + .setWriterSchema(schema1) + .setLocation("eu") + .build(); + ConnectionWorker connectionWorker = + new ConnectionWorker( + TEST_STREAM_1, + "us", + createProtoSchema("foo"), + 100000, + 100000, + Duration.ofSeconds(100), + FlowController.LimitExceededBehavior.Block, + TEST_TRACE_ID, + client.getSettings()); + StatusRuntimeException ex = + assertThrows( + StatusRuntimeException.class, + () -> + sendTestMessage( + connectionWorker, + sw1, + createFooProtoRows(new String[] {String.valueOf(0)}), + 0)); + assertEquals( + "INVALID_ARGUMENT: StreamWriter with location eu is scheduled to use a connection with location us", + ex.getMessage()); + } + + @Test + public void testStreamNameMismatch() throws Exception { + ProtoSchema schema1 = createProtoSchema("foo"); + StreamWriter sw1 = + StreamWriter.newBuilder(TEST_STREAM_1, client).setWriterSchema(schema1).build(); + ConnectionWorker connectionWorker = + new ConnectionWorker( + TEST_STREAM_2, + null, + createProtoSchema("foo"), + 100000, + 100000, + Duration.ofSeconds(100), + FlowController.LimitExceededBehavior.Block, + TEST_TRACE_ID, + client.getSettings()); + StatusRuntimeException ex = + assertThrows( + StatusRuntimeException.class, + () -> + sendTestMessage( + connectionWorker, + sw1, + createFooProtoRows(new String[] {String.valueOf(0)}), + 0)); + assertEquals( + "INVALID_ARGUMENT: StreamWriter with stream name projects/p1/datasets/d1/tables/t1/streams/s1 is scheduled to use a connection with stream name projects/p2/datasets/d2/tables/t2/streams/s2", + ex.getMessage()); + } + @Test public void testExponentialBackoff() throws Exception { assertThat(ConnectionWorker.calculateSleepTimeMilli(0)).isEqualTo(1); @@ -440,6 +524,7 @@ private ConnectionWorker createConnectionWorker( throws IOException { return new ConnectionWorker( streamName, + "us", createProtoSchema("foo"), maxRequests, maxBytes,