fix: allow getMetadata() calls before calling next() (#3111)
* fix: allow getMetadata() calls before calling next()

Calling getMetadata() on a MergedResultSet should be allowed without
first calling next() in order to be consistent with other ResultSets
that are returned by the Connection API.

* fix: remove unnecessary partitions
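
For context, here is a minimal usage sketch of what this change enables: getMetadata() and getType() can be called on the result of runPartitionedQuery() before the first call to next(). The table name, partition count, and imports are assumptions for illustration and are not part of this commit.

import com.google.cloud.spanner.PartitionOptions;
import com.google.cloud.spanner.Statement;
import com.google.cloud.spanner.connection.Connection;
import com.google.cloud.spanner.connection.PartitionedQueryResultSet;

class GetMetadataBeforeNextSketch {
  // Assumes an already-opened Connection; "my_table" and maxPartitions=4 are
  // illustrative values.
  static void printColumnCount(Connection connection) {
    connection.setAutocommit(true);
    try (PartitionedQueryResultSet resultSet =
        connection.runPartitionedQuery(
            Statement.of("select * from my_table"),
            PartitionOptions.newBuilder().setMaxPartitions(4).build())) {
      // Before this fix, this call threw "next() call required" unless next()
      // had already returned at least once.
      System.out.println(
          "columns: " + resultSet.getMetadata().getRowType().getFieldsCount());
      while (resultSet.next()) {
        // process rows
      }
    }
  }
}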
olavloite authored May 27, 2024
1 parent 7e7c814 commit 39902c3
Showing 2 changed files with 122 additions and 4 deletions.
@@ -18,6 +18,7 @@

import static com.google.common.base.Preconditions.checkState;

import com.google.cloud.spanner.ErrorCode;
import com.google.cloud.spanner.ForwardingStructReader;
import com.google.cloud.spanner.ResultSet;
import com.google.cloud.spanner.SpannerException;
@@ -30,6 +31,7 @@
import com.google.spanner.v1.ResultSetStats;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.LinkedBlockingDeque;
@@ -47,15 +49,18 @@ static class PartitionExecutor implements Runnable {
private final Connection connection;
private final String partitionId;
private final LinkedBlockingDeque<PartitionExecutorResult> queue;
private final CountDownLatch metadataAvailableLatch;
private final AtomicBoolean shouldStop = new AtomicBoolean();

PartitionExecutor(
Connection connection,
String partitionId,
LinkedBlockingDeque<PartitionExecutorResult> queue) {
LinkedBlockingDeque<PartitionExecutorResult> queue,
CountDownLatch metadataAvailableLatch) {
this.connection = Preconditions.checkNotNull(connection);
this.partitionId = Preconditions.checkNotNull(partitionId);
this.queue = queue;
this.metadataAvailableLatch = Preconditions.checkNotNull(metadataAvailableLatch);
}

@Override
@@ -68,6 +73,7 @@ public void run() {
queue.put(
PartitionExecutorResult.dataAndMetadata(
row, resultSet.getType(), resultSet.getMetadata()));
metadataAvailableLatch.countDown();
first = false;
} else {
queue.put(PartitionExecutorResult.data(row));
@@ -82,9 +88,11 @@ public void run() {
queue.put(
PartitionExecutorResult.typeAndMetadata(
resultSet.getType(), resultSet.getMetadata()));
metadataAvailableLatch.countDown();
}
} catch (Throwable exception) {
putWithoutInterruptPropagation(PartitionExecutorResult.exception(exception));
metadataAvailableLatch.countDown();
} finally {
// Emit a special 'finished' result to ensure that the row producer is not blocked on a
// queue that never receives any more results. This ensures that we can safely block on
@@ -215,6 +223,7 @@ private static class RowProducerImpl implements RowProducer {
private final AtomicInteger finishedCounter;
private final LinkedBlockingDeque<PartitionExecutorResult> queue;
private ResultSetMetadata metadata;
private final CountDownLatch metadataAvailableLatch = new CountDownLatch(1);
private Type type;
private Struct currentRow;
private Throwable exception;
@@ -243,7 +252,7 @@ private static class RowProducerImpl implements RowProducer {
this.finishedCounter = new AtomicInteger(partitions.size());
for (String partition : partitions) {
PartitionExecutor partitionExecutor =
new PartitionExecutor(connection, partition, this.queue);
new PartitionExecutor(connection, partition, this.queue, this.metadataAvailableLatch);
this.partitionExecutors.add(partitionExecutor);
this.executor.submit(partitionExecutor);
}
@@ -310,8 +319,27 @@ public Struct get() {
return currentRow;
}

private PartitionExecutorResult getFirstResult() {
try {
metadataAvailableLatch.await();
} catch (InterruptedException interruptedException) {
throw SpannerExceptionFactory.propagateInterrupt(interruptedException);
}
PartitionExecutorResult result = queue.peek();
if (result == null) {
throw SpannerExceptionFactory.newSpannerException(
ErrorCode.FAILED_PRECONDITION, "Thread-unsafe access to ResultSet");
}
if (result.exception != null) {
throw SpannerExceptionFactory.asSpannerException(result.exception);
}
return result;
}

public ResultSetMetadata getMetadata() {
checkState(metadata != null, "next() call required");
if (metadata == null) {
return getFirstResult().metadata;
}
return metadata;
}

@@ -326,7 +354,9 @@ public int getParallelism() {
}

public Type getType() {
checkState(type != null, "next() call required");
if (type == null) {
return getFirstResult().type;
}
return type;
}
}
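
The mechanism in the hunks above: each PartitionExecutor publishes the type and metadata together with its first queue entry (or with an exception) and then counts down a shared CountDownLatch; getFirstResult() awaits that latch and peeks at the head of the queue, so getMetadata() and getType() can answer before any row has been consumed. Below is a self-contained sketch of that latch-plus-peek handoff; the Result class and all names are hypothetical and not part of MergedResultSet. The hunks after it add the corresponding tests.

import java.util.concurrent.CountDownLatch;
import java.util.concurrent.LinkedBlockingDeque;

public class MetadataHandoffSketch {

  // Stand-in for a queue element: the first element carries metadata, later
  // elements only carry data.
  static final class Result {
    final String metadata; // null for plain data elements
    final String row;

    Result(String metadata, String row) {
      this.metadata = metadata;
      this.row = row;
    }
  }

  public static void main(String[] args) throws InterruptedException {
    LinkedBlockingDeque<Result> queue = new LinkedBlockingDeque<>();
    CountDownLatch metadataAvailable = new CountDownLatch(1);

    Thread producer =
        new Thread(
            () -> {
              try {
                // Publish metadata with the first element, then release the
                // latch so a waiting consumer may proceed.
                queue.put(new Result("row type: [id INT64, name STRING]", "row-1"));
                metadataAvailable.countDown();
                queue.put(new Result(null, "row-2"));
              } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
              }
            });
    producer.start();

    // Consumer side: wait for the latch, then peek (without removing) at the
    // head of the queue to read metadata before consuming any rows.
    metadataAvailable.await();
    Result first = queue.peek();
    System.out.println("metadata before the first row: " + first.metadata);

    // Rows are still consumed in order afterwards.
    System.out.println(queue.take().row);
    System.out.println(queue.take().row);
    producer.join();
  }
}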
@@ -417,6 +417,94 @@ public void testRunEmptyPartitionedQuery() {
assertEquals(1, mockSpanner.countRequestsOfType(PartitionQueryRequest.class));
}

@Test
public void testGetMetadataWithoutNextCall() {
int generatedRowCount = 1;
RandomResultSetGenerator generator = new RandomResultSetGenerator(generatedRowCount);
Statement statement =
Statement.newBuilder("select * from random_table where active=@active")
.bind("active")
.to(true)
.build();
mockSpanner.putStatementResult(StatementResult.query(statement, generator.generate()));

int maxPartitions = 1;
try (Connection connection = createConnection()) {
connection.setAutocommit(true);
try (PartitionedQueryResultSet resultSet =
connection.runPartitionedQuery(
statement, PartitionOptions.newBuilder().setMaxPartitions(maxPartitions).build())) {
assertNotNull(resultSet.getMetadata());
assertEquals(24, resultSet.getMetadata().getRowType().getFieldsCount());
assertNotNull(resultSet.getType());
assertEquals(24, resultSet.getType().getStructFields().size());

assertTrue(resultSet.next());
assertNotNull(resultSet.getMetadata());
assertEquals(24, resultSet.getMetadata().getRowType().getFieldsCount());
assertNotNull(resultSet.getType());
assertEquals(24, resultSet.getType().getStructFields().size());

assertFalse(resultSet.next());
}
}
assertEquals(1, mockSpanner.countRequestsOfType(BeginTransactionRequest.class));
assertEquals(1, mockSpanner.countRequestsOfType(PartitionQueryRequest.class));
}

@Test
public void testGetMetadataWithoutNextCallOnEmptyResultSet() {
int generatedRowCount = 0;
RandomResultSetGenerator generator = new RandomResultSetGenerator(generatedRowCount);
Statement statement =
Statement.newBuilder("select * from random_table where active=@active")
.bind("active")
.to(true)
.build();
mockSpanner.putStatementResult(StatementResult.query(statement, generator.generate()));

int maxPartitions = 1;
try (Connection connection = createConnection()) {
connection.setAutocommit(true);
try (PartitionedQueryResultSet resultSet =
connection.runPartitionedQuery(
statement, PartitionOptions.newBuilder().setMaxPartitions(maxPartitions).build())) {
assertNotNull(resultSet.getMetadata());
assertEquals(24, resultSet.getMetadata().getRowType().getFieldsCount());
assertNotNull(resultSet.getType());
assertEquals(24, resultSet.getType().getStructFields().size());

assertFalse(resultSet.next());
}
}
assertEquals(1, mockSpanner.countRequestsOfType(BeginTransactionRequest.class));
assertEquals(1, mockSpanner.countRequestsOfType(PartitionQueryRequest.class));
}

@Test
public void testGetMetadataWithoutNextCallOnResultSetWithError() {
Statement statement =
Statement.newBuilder("select * from random_table where active=@active")
.bind("active")
.to(true)
.build();
mockSpanner.putStatementResult(
StatementResult.exception(statement, Status.NOT_FOUND.asRuntimeException()));

int maxPartitions = 1;
try (Connection connection = createConnection()) {
connection.setAutocommit(true);
try (PartitionedQueryResultSet resultSet =
connection.runPartitionedQuery(
statement, PartitionOptions.newBuilder().setMaxPartitions(maxPartitions).build())) {
assertThrows(SpannerException.class, resultSet::getMetadata);
assertThrows(SpannerException.class, resultSet::getType);
}
}
assertEquals(1, mockSpanner.countRequestsOfType(BeginTransactionRequest.class));
assertEquals(1, mockSpanner.countRequestsOfType(PartitionQueryRequest.class));
}

@Test
public void testRunPartitionedQueryUsingSql() {
int generatedRowCount = 20;