DDL statements will return an {@link ApiFuture} containing a {@link Void} that is done
* when the DDL statement has been applied successfully, or that throws an {@link
@@ -894,31 +894,33 @@ default RpcPriority getRPCPriority() {
AsyncStatementResult executeAsync(Statement statement);
/**
- * Executes the given statement as a query and returns the result as a {@link ResultSet}. This
- * method blocks and waits for a response from Spanner. If the statement does not contain a valid
- * query, the method will throw a {@link SpannerException}.
+ * Executes the given statement (a query or a DML statement with returning clause) and returns the
+ * result as a {@link ResultSet}. This method blocks and waits for a response from Spanner. If the
+ * statement does not contain a valid query or a DML statement with returning clause, the method
+ * will throw a {@link SpannerException}.
*
- * @param query The query statement to execute
+ * @param query The query statement or DML statement with returning clause to execute
* @param options the options to configure the query
- * @return a {@link ResultSet} with the results of the query
+ * @return a {@link ResultSet} with the results of the statement
*/
ResultSet executeQuery(Statement query, QueryOption... options);
/**
- * Executes the given statement asynchronously as a query and returns the result as an {@link
- * AsyncResultSet}. This method is guaranteed to be non-blocking. If the statement does not
- * contain a valid query, the method will throw a {@link SpannerException}.
+ * Executes the given statement (a query or a DML statement with returning clause) asynchronously
+ * and returns the result as an {@link AsyncResultSet}. This method is guaranteed to be
+ * non-blocking. If the statement does not contain a valid query or a DML statement with returning
+ * clause, the method will throw a {@link SpannerException}.
*
* See {@link AsyncResultSet#setCallback(java.util.concurrent.Executor,
* com.google.cloud.spanner.AsyncResultSet.ReadyCallback)} for more information on how to consume
- * the results of the query asynchronously.
+ * the results of the statement asynchronously.
*
*
It is also possible to consume the returned {@link AsyncResultSet} in the same way as a
* normal {@link ResultSet}, i.e. in a while-loop calling {@link AsyncResultSet#next()}.
*
- * @param query The query statement to execute
+ * @param query The query statement or DML statement with returning clause to execute
* @param options the options to configure the query
- * @return an {@link AsyncResultSet} with the results of the query
+ * @return an {@link AsyncResultSet} with the results of the statement
*/
AsyncResultSet executeQueryAsync(Statement query, QueryOption... options);
@@ -951,8 +953,8 @@ default RpcPriority getRPCPriority() {
ResultSet analyzeQuery(Statement query, QueryAnalyzeMode queryMode);
/**
- * Executes the given statement as a DML statement. If the statement does not contain a valid DML
- * statement, the method will throw a {@link SpannerException}.
+ * Executes the given statement as a simple DML statement. If the statement does not contain a
+ * valid DML statement, the method will throw a {@link SpannerException}.
*
* @param update The update statement to execute
* @return the number of records that were inserted/updated/deleted by this statement
@@ -972,8 +974,9 @@ default ResultSetStats analyzeUpdate(Statement update, QueryAnalyzeMode analyzeM
}
/**
- * Executes the given statement asynchronously as a DML statement. If the statement does not
- * contain a valid DML statement, the method will throw a {@link SpannerException}.
+ * Executes the given statement asynchronously as a simple DML statement. If the statement does
+ * not contain a simple DML statement, the method will throw a {@link SpannerException}. A DML
+ * statement with returning clause will throw a {@link SpannerException}.
*
*
This method is guaranteed to be non-blocking.
*
@@ -984,8 +987,9 @@ default ResultSetStats analyzeUpdate(Statement update, QueryAnalyzeMode analyzeM
ApiFuture executeUpdateAsync(Statement update);
/**
- * Executes a list of DML statements in a single request. The statements will be executed in order
- * and the semantics is the same as if each statement is executed by {@link
+ * Executes a list of DML statements (can be simple DML statements or DML statements with
+ * returning clause) in a single request. The statements will be executed in order and the
+ * semantics is the same as if each statement is executed by {@link
* Connection#executeUpdate(Statement)} in a loop. This method returns an array of long integers,
* each representing the number of rows modified by each statement.
*
@@ -1006,8 +1010,9 @@ default ResultSetStats analyzeUpdate(Statement update, QueryAnalyzeMode analyzeM
long[] executeBatchUpdate(Iterable updates);
/**
- * Executes a list of DML statements in a single request. The statements will be executed in order
- * and the semantics is the same as if each statement is executed by {@link
+ * Executes a list of DML statements (can be simple DML statements or DML statements with
+ * returning clause) in a single request. The statements will be executed in order and the
+ * semantics is the same as if each statement is executed by {@link
* Connection#executeUpdate(Statement)} in a loop. This method returns an {@link ApiFuture} that
* contains an array of long integers, each representing the number of rows modified by each
* statement.
diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/connection/ConnectionImpl.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/connection/ConnectionImpl.java
index f0c772acd9e..e935a607a83 100644
--- a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/connection/ConnectionImpl.java
+++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/connection/ConnectionImpl.java
@@ -853,6 +853,9 @@ public StatementResult execute(Statement statement) {
case QUERY:
return StatementResultImpl.of(internalExecuteQuery(parsedStatement, AnalyzeMode.NONE));
case UPDATE:
+ if (parsedStatement.hasReturningClause()) {
+ return StatementResultImpl.of(internalExecuteQuery(parsedStatement, AnalyzeMode.NONE));
+ }
return StatementResultImpl.of(get(internalExecuteUpdateAsync(parsedStatement)));
case DDL:
get(executeDdlAsync(parsedStatement));
@@ -881,6 +884,10 @@ public AsyncStatementResult executeAsync(Statement statement) {
return AsyncStatementResultImpl.of(
internalExecuteQueryAsync(parsedStatement, AnalyzeMode.NONE));
case UPDATE:
+ if (parsedStatement.hasReturningClause()) {
+ return AsyncStatementResultImpl.of(
+ internalExecuteQueryAsync(parsedStatement, AnalyzeMode.NONE));
+ }
return AsyncStatementResultImpl.of(internalExecuteUpdateAsync(parsedStatement));
case DDL:
return AsyncStatementResultImpl.noResult(executeDdlAsync(parsedStatement));
@@ -918,7 +925,7 @@ private ResultSet parseAndExecuteQuery(
Preconditions.checkNotNull(analyzeMode);
ConnectionPreconditions.checkState(!isClosed(), CLOSED_ERROR_MSG);
ParsedStatement parsedStatement = getStatementParser().parse(query, this.queryOptions);
- if (parsedStatement.isQuery()) {
+ if (parsedStatement.isQuery() || parsedStatement.isUpdate()) {
switch (parsedStatement.getType()) {
case CLIENT_SIDE:
return parsedStatement
@@ -928,6 +935,19 @@ private ResultSet parseAndExecuteQuery(
case QUERY:
return internalExecuteQuery(parsedStatement, analyzeMode, options);
case UPDATE:
+ if (parsedStatement.hasReturningClause()) {
+ // Cannot execute DML statement with returning clause in read-only mode or in
+ // READ_ONLY_TRANSACTION transaction mode.
+ if (this.isReadOnly()
+ || (this.isInTransaction()
+ && this.getTransactionMode() == TransactionMode.READ_ONLY_TRANSACTION)) {
+ throw SpannerExceptionFactory.newSpannerException(
+ ErrorCode.FAILED_PRECONDITION,
+ "DML statement with returning clause cannot be executed in read-only mode: "
+ + parsedStatement.getSqlWithoutComments());
+ }
+ return internalExecuteQuery(parsedStatement, analyzeMode, options);
+ }
case DDL:
case UNKNOWN:
default:
@@ -935,7 +955,8 @@ private ResultSet parseAndExecuteQuery(
}
throw SpannerExceptionFactory.newSpannerException(
ErrorCode.INVALID_ARGUMENT,
- "Statement is not a query: " + parsedStatement.getSqlWithoutComments());
+ "Statement is not a query or DML with returning clause: "
+ + parsedStatement.getSqlWithoutComments());
}
private AsyncResultSet parseAndExecuteQueryAsync(
@@ -943,7 +964,7 @@ private AsyncResultSet parseAndExecuteQueryAsync(
Preconditions.checkNotNull(query);
ConnectionPreconditions.checkState(!isClosed(), CLOSED_ERROR_MSG);
ParsedStatement parsedStatement = getStatementParser().parse(query, this.queryOptions);
- if (parsedStatement.isQuery()) {
+ if (parsedStatement.isQuery() || parsedStatement.isUpdate()) {
switch (parsedStatement.getType()) {
case CLIENT_SIDE:
return ResultSets.toAsyncResultSet(
@@ -956,6 +977,19 @@ private AsyncResultSet parseAndExecuteQueryAsync(
case QUERY:
return internalExecuteQueryAsync(parsedStatement, analyzeMode, options);
case UPDATE:
+ if (parsedStatement.hasReturningClause()) {
+ // Cannot execute DML statement with returning clause in read-only mode or in
+ // READ_ONLY_TRANSACTION transaction mode.
+ if (this.isReadOnly()
+ || (this.isInTransaction()
+ && this.getTransactionMode() == TransactionMode.READ_ONLY_TRANSACTION)) {
+ throw SpannerExceptionFactory.newSpannerException(
+ ErrorCode.FAILED_PRECONDITION,
+ "DML statement with returning clause cannot be executed in read-only mode: "
+ + parsedStatement.getSqlWithoutComments());
+ }
+ return internalExecuteQueryAsync(parsedStatement, analyzeMode, options);
+ }
case DDL:
case UNKNOWN:
default:
@@ -963,7 +997,8 @@ private AsyncResultSet parseAndExecuteQueryAsync(
}
throw SpannerExceptionFactory.newSpannerException(
ErrorCode.INVALID_ARGUMENT,
- "Statement is not a query: " + parsedStatement.getSqlWithoutComments());
+ "Statement is not a query or DML with returning clause: "
+ + parsedStatement.getSqlWithoutComments());
}
@Override
@@ -974,6 +1009,13 @@ public long executeUpdate(Statement update) {
if (parsedStatement.isUpdate()) {
switch (parsedStatement.getType()) {
case UPDATE:
+ if (parsedStatement.hasReturningClause()) {
+ throw SpannerExceptionFactory.newSpannerException(
+ ErrorCode.FAILED_PRECONDITION,
+ "DML statement with returning clause cannot be executed using executeUpdate: "
+ + parsedStatement.getSqlWithoutComments()
+ + ". Please use executeQuery instead.");
+ }
return get(internalExecuteUpdateAsync(parsedStatement));
case CLIENT_SIDE:
case QUERY:
@@ -995,6 +1037,13 @@ public ApiFuture executeUpdateAsync(Statement update) {
if (parsedStatement.isUpdate()) {
switch (parsedStatement.getType()) {
case UPDATE:
+ if (parsedStatement.hasReturningClause()) {
+ throw SpannerExceptionFactory.newSpannerException(
+ ErrorCode.FAILED_PRECONDITION,
+ "DML statement with returning clause cannot be executed using executeUpdateAsync: "
+ + parsedStatement.getSqlWithoutComments()
+ + ". Please use executeQueryAsync instead.");
+ }
return internalExecuteUpdateAsync(parsedStatement);
case CLIENT_SIDE:
case QUERY:
@@ -1141,8 +1190,9 @@ private ResultSet internalExecuteQuery(
final QueryOption... options) {
Preconditions.checkArgument(
statement.getType() == StatementType.QUERY
- || (statement.getType() == StatementType.UPDATE && analyzeMode != AnalyzeMode.NONE),
- "Statement must either be a query or a DML mode with analyzeMode!=NONE");
+ || (statement.getType() == StatementType.UPDATE
+ && (analyzeMode != AnalyzeMode.NONE || statement.hasReturningClause())),
+ "Statement must either be a query or a DML mode with analyzeMode!=NONE or returning clause");
UnitOfWork transaction = getCurrentUnitOfWorkOrStartNewUnitOfWork();
return get(
transaction.executeQueryAsync(
@@ -1154,7 +1204,9 @@ private AsyncResultSet internalExecuteQueryAsync(
final AnalyzeMode analyzeMode,
final QueryOption... options) {
Preconditions.checkArgument(
- statement.getType() == StatementType.QUERY, "Statement must be a query");
+ (statement.getType() == StatementType.QUERY)
+ || (statement.getType() == StatementType.UPDATE && statement.hasReturningClause()),
+ "Statement must be a query or DML with returning clause.");
UnitOfWork transaction = getCurrentUnitOfWorkOrStartNewUnitOfWork();
return ResultSets.toAsyncResultSet(
transaction.executeQueryAsync(
diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/connection/PostgreSQLStatementParser.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/connection/PostgreSQLStatementParser.java
index 6eba0ce1b29..012bfbba875 100644
--- a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/connection/PostgreSQLStatementParser.java
+++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/connection/PostgreSQLStatementParser.java
@@ -25,10 +25,15 @@
import java.util.Collections;
import java.util.HashSet;
import java.util.Set;
+import java.util.regex.Pattern;
import javax.annotation.Nullable;
@InternalApi
public class PostgreSQLStatementParser extends AbstractStatementParser {
+ private static final Pattern RETURNING_PATTERN = Pattern.compile("returning[ '(\"*]");
+ private static final Pattern AS_RETURNING_PATTERN = Pattern.compile("[ ')\"]as returning[ '(\"]");
+ private static final String RETURNING_STRING = "returning";
+
PostgreSQLStatementParser() throws CompileException {
super(
Collections.unmodifiableSet(
@@ -65,6 +70,8 @@ String removeCommentsAndTrimInternal(String sql) {
Preconditions.checkNotNull(sql);
boolean isInSingleLineComment = false;
int multiLineCommentLevel = 0;
+ boolean whitespaceBeforeOrAfterMultiLineComment = false;
+ int multiLineCommentStartIdx = -1;
StringBuilder res = new StringBuilder(sql.length());
int index = 0;
while (index < sql.length()) {
@@ -78,6 +85,19 @@ String removeCommentsAndTrimInternal(String sql) {
} else if (multiLineCommentLevel > 0) {
if (sql.length() > index + 1 && c == ASTERISK && sql.charAt(index + 1) == SLASH) {
multiLineCommentLevel--;
+ if (multiLineCommentLevel == 0) {
+ if (!whitespaceBeforeOrAfterMultiLineComment && (sql.length() > index + 2)) {
+ whitespaceBeforeOrAfterMultiLineComment =
+ Character.isWhitespace(sql.charAt(index + 2));
+ }
+ // If the multiline comment does not have any whitespace before or after it, and it is
+ // neither at the start nor at the end of SQL string, append an extra space.
+ if (!whitespaceBeforeOrAfterMultiLineComment
+ && (multiLineCommentStartIdx != 0)
+ && (index != sql.length() - 2)) {
+ res.append(' ');
+ }
+ }
index++;
} else if (sql.length() > index + 1 && c == SLASH && sql.charAt(index + 1) == ASTERISK) {
multiLineCommentLevel++;
@@ -92,6 +112,10 @@ String removeCommentsAndTrimInternal(String sql) {
continue;
} else if (sql.length() > index + 1 && c == SLASH && sql.charAt(index + 1) == ASTERISK) {
multiLineCommentLevel++;
+ if (index >= 1) {
+ whitespaceBeforeOrAfterMultiLineComment = Character.isWhitespace(sql.charAt(index - 1));
+ }
+ multiLineCommentStartIdx = index;
index += 2;
continue;
} else {
@@ -120,7 +144,7 @@ String parseDollarQuotedString(String sql, int index) {
if (c == DOLLAR) {
return tag.toString();
}
- if (!Character.isJavaIdentifierPart(c)) {
+ if (!isValidIdentifierChar(c)) {
break;
}
tag.append(c);
@@ -267,4 +291,77 @@ private void appendIfNotNull(
result.append(prefix).append(tag).append(suffix);
}
}
+
+ private boolean isValidIdentifierFirstChar(char c) {
+ return Character.isLetter(c) || c == UNDERSCORE;
+ }
+
+ private boolean isValidIdentifierChar(char c) {
+ return isValidIdentifierFirstChar(c) || Character.isDigit(c) || c == DOLLAR;
+ }
+
+ private boolean checkCharPrecedingReturning(char ch) {
+ return (ch == SPACE)
+ || (ch == SINGLE_QUOTE)
+ || (ch == CLOSE_PARENTHESIS)
+ || (ch == DOUBLE_QUOTE)
+ || (ch == DOLLAR);
+ }
+
+ private boolean checkCharPrecedingSubstrWithReturning(char ch) {
+ return (ch == SPACE)
+ || (ch == SINGLE_QUOTE)
+ || (ch == CLOSE_PARENTHESIS)
+ || (ch == DOUBLE_QUOTE)
+ || (ch == COMMA);
+ }
+
+ private boolean isReturning(String sql, int index) {
+ // RETURNING is a reserved keyword in PG, but requires a
+ // leading AS to be used as column label, to avoid ambiguity.
+ // We thus check for cases which do not have a leading AS.
+ // (https://www.postgresql.org/docs/current/sql-keywords-appendix.html)
+ if (index >= 1) {
+ if (((index + 10 <= sql.length())
+ && RETURNING_PATTERN.matcher(sql.substring(index, index + 10)).matches()
+ && !((index >= 4)
+ && AS_RETURNING_PATTERN.matcher(sql.substring(index - 4, index + 10)).matches()))) {
+ if (checkCharPrecedingReturning(sql.charAt(index - 1))) {
+ return true;
+ }
+ // Check for cases where returning clause is part of a substring which starts with an
+ // invalid first character of an identifier.
+ // For example,
+ // insert into t select 2returning *;
+ int ind = index - 1;
+ while ((ind >= 0) && !checkCharPrecedingSubstrWithReturning(sql.charAt(ind))) {
+ ind--;
+ }
+ return !isValidIdentifierFirstChar(sql.charAt(ind + 1));
+ }
+ }
+ return false;
+ }
+
+ @InternalApi
+ @Override
+ protected boolean checkReturningClauseInternal(String rawSql) {
+ Preconditions.checkNotNull(rawSql);
+ String sql = rawSql.toLowerCase();
+ // Do a pre-check to check if the SQL string definitely does not have a returning clause.
+ // If this check fails, do a more involved check to check for a returning clause.
+ if (!sql.contains(RETURNING_STRING)) {
+ return false;
+ }
+ sql = sql.replaceAll("\\s+", " ");
+ int index = 0;
+ while (index < sql.length()) {
+ if (isReturning(sql, index)) {
+ return true;
+ } else {
+ index = skip(sql, index, null);
+ }
+ }
+ return false;
+ }
}
diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/connection/ReadWriteTransaction.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/connection/ReadWriteTransaction.java
index 968ef357c24..80dc97cde04 100644
--- a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/connection/ReadWriteTransaction.java
+++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/connection/ReadWriteTransaction.java
@@ -45,6 +45,7 @@
import com.google.cloud.spanner.TransactionContext;
import com.google.cloud.spanner.TransactionManager;
import com.google.cloud.spanner.connection.AbstractStatementParser.ParsedStatement;
+import com.google.cloud.spanner.connection.AbstractStatementParser.StatementType;
import com.google.cloud.spanner.connection.TransactionRetryListener.RetryResult;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Preconditions;
@@ -337,7 +338,10 @@ public ApiFuture executeQueryAsync(
final ParsedStatement statement,
final AnalyzeMode analyzeMode,
final QueryOption... options) {
- Preconditions.checkArgument(statement.isQuery(), "Statement is not a query");
+ Preconditions.checkArgument(
+ (statement.getType() == StatementType.QUERY)
+ || (statement.getType() == StatementType.UPDATE && statement.hasReturningClause()),
+ "Statement must be a query or DML with returning clause");
checkValidTransaction();
ApiFuture res;
diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/connection/SingleUseTransaction.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/connection/SingleUseTransaction.java
index 9646c7c9c95..cbde67393aa 100644
--- a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/connection/SingleUseTransaction.java
+++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/connection/SingleUseTransaction.java
@@ -179,12 +179,18 @@ public ApiFuture executeQueryAsync(
final QueryOption... options) {
Preconditions.checkNotNull(statement);
Preconditions.checkArgument(
- statement.isQuery() || (statement.isUpdate() && analyzeMode != AnalyzeMode.NONE),
+ statement.isQuery()
+ || (statement.isUpdate()
+ && (analyzeMode != AnalyzeMode.NONE || statement.hasReturningClause())),
"The statement must be a query, or the statement must be DML and AnalyzeMode must be PLAN or PROFILE");
checkAndMarkUsed();
- if (statement.isUpdate() && analyzeMode != AnalyzeMode.NONE) {
- return analyzeTransactionalUpdateAsync(statement, analyzeMode);
+ if (statement.isUpdate()) {
+ if (analyzeMode != AnalyzeMode.NONE) {
+ return analyzeTransactionalUpdateAsync(statement, analyzeMode);
+ }
+ // DML with returning clause.
+ return executeDmlReturningAsync(statement, options);
}
final ReadOnlyTransaction currentTransaction =
@@ -217,6 +223,30 @@ public ApiFuture executeQueryAsync(
return executeStatementAsync(statement, callable, SpannerGrpc.getExecuteStreamingSqlMethod());
}
+ private ApiFuture executeDmlReturningAsync(
+ final ParsedStatement update, QueryOption... options) {
+ Callable callable =
+ () -> {
+ try {
+ writeTransaction = createWriteTransaction();
+ ResultSet resultSet =
+ writeTransaction.run(
+ transaction ->
+ DirectExecuteResultSet.ofResultSet(
+ transaction.executeQuery(update.getStatement(), options)));
+ state = UnitOfWorkState.COMMITTED;
+ return resultSet;
+ } catch (Throwable t) {
+ state = UnitOfWorkState.COMMIT_FAILED;
+ throw t;
+ }
+ };
+ return executeStatementAsync(
+ update,
+ callable,
+ ImmutableList.of(SpannerGrpc.getExecuteSqlMethod(), SpannerGrpc.getCommitMethod()));
+ }
+
@Override
public Timestamp getReadTimestamp() {
ConnectionPreconditions.checkState(
diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/connection/SpannerStatementParser.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/connection/SpannerStatementParser.java
index f223cbb935a..a9ca5fb9726 100644
--- a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/connection/SpannerStatementParser.java
+++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/connection/SpannerStatementParser.java
@@ -25,10 +25,16 @@
import com.google.common.collect.Sets;
import java.util.Collections;
import java.util.Set;
+import java.util.regex.Pattern;
@InternalApi
public class SpannerStatementParser extends AbstractStatementParser {
+ private static final Pattern THEN_RETURN_PATTERN =
+ Pattern.compile("[ `')\"]then return[ *`'(\"]");
+ private static final String THEN_STRING = "then";
+ private static final String RETURN_STRING = "return";
+
public SpannerStatementParser() throws CompileException {
super(
Collections.unmodifiableSet(
@@ -69,6 +75,8 @@ String removeCommentsAndTrimInternal(String sql) {
char startQuote = 0;
boolean lastCharWasEscapeChar = false;
boolean isTripleQuoted = false;
+ boolean whitespaceBeforeOrAfterMultiLineComment = false;
+ int multiLineCommentStartIdx = -1;
StringBuilder res = new StringBuilder(sql.length());
int index = 0;
while (index < sql.length()) {
@@ -112,6 +120,17 @@ String removeCommentsAndTrimInternal(String sql) {
} else if (isInMultiLineComment) {
if (sql.length() > index + 1 && c == ASTERISK && sql.charAt(index + 1) == SLASH) {
isInMultiLineComment = false;
+ if (!whitespaceBeforeOrAfterMultiLineComment && (sql.length() > index + 2)) {
+ whitespaceBeforeOrAfterMultiLineComment =
+ Character.isWhitespace(sql.charAt(index + 2));
+ }
+ // If the multiline comment does not have any whitespace before or after it, and it is
+ // neither at the start nor at the end of SQL string, append an extra space.
+ if (!whitespaceBeforeOrAfterMultiLineComment
+ && (multiLineCommentStartIdx != 0)
+ && (index != sql.length() - 2)) {
+ res.append(' ');
+ }
index++;
}
} else {
@@ -121,6 +140,11 @@ String removeCommentsAndTrimInternal(String sql) {
isInSingleLineComment = true;
} else if (sql.length() > index + 1 && c == SLASH && sql.charAt(index + 1) == ASTERISK) {
isInMultiLineComment = true;
+ if (index >= 1) {
+ whitespaceBeforeOrAfterMultiLineComment =
+ Character.isWhitespace(sql.charAt(index - 1));
+ }
+ multiLineCommentStartIdx = index;
index++;
} else {
if (c == SINGLE_QUOTE || c == DOUBLE_QUOTE || c == BACKTICK_QUOTE) {
@@ -250,4 +274,71 @@ ParametersInfo convertPositionalParametersToNamedParametersInternal(char paramCh
}
return new ParametersInfo(paramIndex - 1, named.toString());
}
+
+ private boolean isReturning(String sql, int index) {
+ return (index >= 1)
+ && (index + 12 <= sql.length())
+ && THEN_RETURN_PATTERN.matcher(sql.substring(index - 1, index + 12)).matches();
+ }
+
+ @InternalApi
+ @Override
+ protected boolean checkReturningClauseInternal(String rawSql) {
+ Preconditions.checkNotNull(rawSql);
+ String sql = rawSql.toLowerCase();
+ // Do a pre-check to check if the SQL string definitely does not have a returning clause.
+ // If this check fails, do a more involved check to check for a returning clause.
+ if (!(sql.contains(THEN_STRING) && sql.contains(RETURN_STRING))) {
+ return false;
+ }
+ sql = sql.replaceAll("\\s+", " ");
+ final char SINGLE_QUOTE = '\'';
+ final char DOUBLE_QUOTE = '"';
+ final char BACKTICK_QUOTE = '`';
+ boolean isInQuoted = false;
+ char startQuote = 0;
+ boolean lastCharWasEscapeChar = false;
+ boolean isTripleQuoted = false;
+ for (int index = 0; index < sql.length(); index++) {
+ char c = sql.charAt(index);
+ if (isInQuoted) {
+ if (c == startQuote) {
+ if (lastCharWasEscapeChar) {
+ lastCharWasEscapeChar = false;
+ } else if (isTripleQuoted) {
+ if (sql.length() > index + 2
+ && sql.charAt(index + 1) == startQuote
+ && sql.charAt(index + 2) == startQuote) {
+ isInQuoted = false;
+ startQuote = 0;
+ isTripleQuoted = false;
+ }
+ } else {
+ isInQuoted = false;
+ startQuote = 0;
+ }
+ } else if (c == '\\') {
+ lastCharWasEscapeChar = true;
+ } else {
+ lastCharWasEscapeChar = false;
+ }
+ } else {
+ if (isReturning(sql, index)) {
+ return true;
+ } else {
+ if (c == SINGLE_QUOTE || c == DOUBLE_QUOTE || c == BACKTICK_QUOTE) {
+ isInQuoted = true;
+ startQuote = c;
+ // check whether it is a triple-quote
+ if (sql.length() > index + 2
+ && sql.charAt(index + 1) == startQuote
+ && sql.charAt(index + 2) == startQuote) {
+ isTripleQuoted = true;
+ }
+ }
+ }
+ }
+ }
+ return false;
+ }
}
diff --git a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/MockSpannerServiceImpl.java b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/MockSpannerServiceImpl.java
index 44942bb566c..340f999d8ee 100644
--- a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/MockSpannerServiceImpl.java
+++ b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/MockSpannerServiceImpl.java
@@ -270,6 +270,14 @@ public static StatementResult update(Statement statement, long updateCount) {
return new StatementResult(statement, updateCount);
}
+ /**
+ * Creates a {@link StatementResult} for a DML statement with returning clause that returns a
+ * ResultSet.
+ */
+ public static StatementResult updateReturning(Statement statement, ResultSet resultSet) {
+ return new StatementResult(statement, resultSet);
+ }
+
/** Creates a {@link StatementResult} for statement that should return an error. */
public static StatementResult exception(Statement statement, StatusRuntimeException exception) {
return new StatementResult(statement, exception);
@@ -1093,9 +1101,6 @@ public void executeBatchDml(
.build();
break resultLoop;
case RESULT_SET:
- throw Status.INVALID_ARGUMENT
- .withDescription("Not a DML statement: " + statement.getSql())
- .asRuntimeException();
case UPDATE_COUNT:
results.add(res);
break;
@@ -1120,10 +1125,20 @@ public void executeBatchDml(
}
ExecuteBatchDmlResponse.Builder builder = ExecuteBatchDmlResponse.newBuilder();
for (StatementResult res : results) {
+ Long updateCount;
+ switch (res.getType()) {
+ case UPDATE_COUNT:
+ updateCount = res.getUpdateCount();
+ break;
+ case RESULT_SET:
+ updateCount = res.getResultSet().getStats().getRowCountExact();
+ break;
+ default:
+ throw new IllegalStateException("Invalid result type: " + res.getType());
+ }
builder.addResultSets(
ResultSet.newBuilder()
- .setStats(
- ResultSetStats.newBuilder().setRowCountExact(res.getUpdateCount()).build())
+ .setStats(ResultSetStats.newBuilder().setRowCountExact(updateCount).build())
.setMetadata(
ResultSetMetadata.newBuilder()
.setTransaction(
diff --git a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/connection/AbortedTest.java b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/connection/AbortedTest.java
index 96416c3e8d5..3a665c4fe7a 100644
--- a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/connection/AbortedTest.java
+++ b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/connection/AbortedTest.java
@@ -89,6 +89,46 @@ public void testCommitAborted() {
}
}
+ @Test
+ public void testCommitAbortedDuringUpdateWithReturning() {
+ // Do two iterations to ensure that each iteration gets its own transaction, and that each
+ // transaction is the most recent transaction of that session.
+ for (int i = 0; i < 2; i++) {
+ mockSpanner.putStatementResult(
+ StatementResult.query(SELECT_COUNT_STATEMENT, SELECT_COUNT_RESULTSET_BEFORE_INSERT));
+ mockSpanner.putStatementResult(
+ StatementResult.updateReturning(INSERT_RETURNING_STATEMENT, UPDATE_RETURNING_RESULTSET));
+ AbortInterceptor interceptor = new AbortInterceptor(0);
+ try (ITConnection connection =
+ createConnection(interceptor, new CountTransactionRetryListener())) {
+ // verify that the there is no test record
+ try (ResultSet rs =
+ connection.executeQuery(Statement.of("SELECT COUNT(*) AS C FROM TEST WHERE ID=1"))) {
+ assertThat(rs.next(), is(true));
+ assertThat(rs.getLong("C"), is(equalTo(0L)));
+ assertThat(rs.next(), is(false));
+ }
+ // do an insert with returning
+ connection.executeQuery(
+ Statement.of("INSERT INTO TEST (ID, NAME) VALUES (1, 'test aborted') THEN RETURN *"));
+ // indicate that the next statement should abort
+ interceptor.setProbability(1.0);
+ interceptor.setOnlyInjectOnce(true);
+ // do a commit that will first abort, and then on retry will succeed
+ connection.commit();
+ mockSpanner.putStatementResult(
+ StatementResult.query(SELECT_COUNT_STATEMENT, SELECT_COUNT_RESULTSET_AFTER_INSERT));
+ // verify that the insert succeeded
+ try (ResultSet rs =
+ connection.executeQuery(Statement.of("SELECT COUNT(*) AS C FROM TEST WHERE ID=1"))) {
+ assertThat(rs.next(), is(true));
+ assertThat(rs.getLong("C"), is(equalTo(1L)));
+ assertThat(rs.next(), is(false));
+ }
+ }
+ }
+ }
+
@Test
public void testAbortedDuringRetryOfFailedQuery() {
final Statement invalidStatement = Statement.of("SELECT * FROM FOO");
@@ -144,6 +184,34 @@ public void testAbortedDuringRetryOfFailedUpdate() {
assertThat(mockSpanner.countRequestsOfType(ExecuteSqlRequest.class)).isEqualTo(6);
}
+ @Test
+ public void testAbortedDuringRetryOfFailedUpdateWithReturning() {
+ final Statement invalidStatement = Statement.of("INSERT INTO FOO THEN RETURN *");
+ StatusRuntimeException notFound =
+ Status.NOT_FOUND.withDescription("Table not found").asRuntimeException();
+ mockSpanner.putStatementResult(StatementResult.exception(invalidStatement, notFound));
+ try (ITConnection connection =
+ createConnection(createAbortFirstRetryListener(invalidStatement, notFound))) {
+ connection.execute(INSERT_STATEMENT);
+ try {
+ connection.execute(invalidStatement);
+ fail("missing expected exception");
+ } catch (SpannerException e) {
+ assertThat(e.getErrorCode()).isEqualTo(ErrorCode.NOT_FOUND);
+ }
+ // Force an abort and retry.
+ mockSpanner.abortNextStatement();
+ connection.commit();
+ }
+ assertThat(mockSpanner.countRequestsOfType(CommitRequest.class)).isEqualTo(2);
+ // The transaction will be executed 3 times, which means that there will be 6
+ // ExecuteSqlRequests:
+ // 1. The initial attempt.
+ // 2. The first retry attempt. This will fail on the invalid statement as it is aborted.
+ // 3. the second retry attempt. This will succeed.
+ assertThat(mockSpanner.countRequestsOfType(ExecuteSqlRequest.class)).isEqualTo(6);
+ }
+
@Test
public void testAbortedDuringRetryOfFailedBatchUpdate() {
final Statement invalidStatement = Statement.of("INSERT INTO FOO");
@@ -167,6 +235,29 @@ public void testAbortedDuringRetryOfFailedBatchUpdate() {
assertThat(mockSpanner.countRequestsOfType(ExecuteBatchDmlRequest.class)).isEqualTo(3);
}
+ @Test
+ public void testAbortedDuringRetryOfFailedBatchUpdateWithReturning() {
+ final Statement invalidStatement = Statement.of("INSERT INTO FOO THEN RETURN *");
+ StatusRuntimeException notFound =
+ Status.NOT_FOUND.withDescription("Table not found").asRuntimeException();
+ mockSpanner.putStatementResult(StatementResult.exception(invalidStatement, notFound));
+ try (ITConnection connection =
+ createConnection(createAbortFirstRetryListener(invalidStatement, notFound))) {
+ connection.execute(INSERT_STATEMENT);
+ try {
+ connection.executeBatchUpdate(Collections.singletonList(invalidStatement));
+ fail("missing expected exception");
+ } catch (SpannerException e) {
+ assertThat(e.getErrorCode()).isEqualTo(ErrorCode.NOT_FOUND);
+ }
+ // Force an abort and retry.
+ mockSpanner.abortNextStatement();
+ connection.commit();
+ }
+ assertThat(mockSpanner.countRequestsOfType(CommitRequest.class)).isEqualTo(2);
+ assertThat(mockSpanner.countRequestsOfType(ExecuteBatchDmlRequest.class)).isEqualTo(3);
+ }
+
@Test
public void testAbortedDuringRetryOfFailedQueryAsFirstStatement() {
final Statement invalidStatement = Statement.of("SELECT * FROM FOO");
@@ -225,6 +316,29 @@ public void testAbortedDuringRetryOfFailedUpdateAsFirstStatement() {
assertThat(mockSpanner.countRequestsOfType(ExecuteSqlRequest.class)).isEqualTo(8);
}
+ @Test
+ public void testAbortedDuringRetryOfFailedUpdateWithReturningAsFirstStatement() {
+ final Statement invalidStatement = Statement.of("INSERT INTO FOO THEN RETURN *");
+ StatusRuntimeException notFound =
+ Status.NOT_FOUND.withDescription("Table not found").asRuntimeException();
+ mockSpanner.putStatementResult(StatementResult.exception(invalidStatement, notFound));
+ try (ITConnection connection =
+ createConnection(createAbortRetryListener(2, invalidStatement, notFound))) {
+ try {
+ connection.execute(invalidStatement);
+ fail("missing expected exception");
+ } catch (SpannerException e) {
+ assertThat(e.getErrorCode()).isEqualTo(ErrorCode.NOT_FOUND);
+ }
+ connection.execute(INSERT_STATEMENT);
+ // Force an abort and retry.
+ mockSpanner.abortNextStatement();
+ connection.commit();
+ }
+ assertThat(mockSpanner.countRequestsOfType(CommitRequest.class)).isEqualTo(2);
+ assertThat(mockSpanner.countRequestsOfType(ExecuteSqlRequest.class)).isEqualTo(8);
+ }
+
@Test
public void testAbortedDuringRetryOfFailedBatchUpdateAsFirstStatement() {
final Statement invalidStatement = Statement.of("INSERT INTO FOO");
@@ -248,6 +362,29 @@ public void testAbortedDuringRetryOfFailedBatchUpdateAsFirstStatement() {
assertThat(mockSpanner.countRequestsOfType(ExecuteBatchDmlRequest.class)).isEqualTo(6);
}
+ @Test
+ public void testAbortedDuringRetryOfFailedBatchUpdateWithReturningAsFirstStatement() {
+ final Statement invalidStatement = Statement.of("INSERT INTO FOO THEN RETURN *");
+ StatusRuntimeException notFound =
+ Status.NOT_FOUND.withDescription("Table not found").asRuntimeException();
+ mockSpanner.putStatementResult(StatementResult.exception(invalidStatement, notFound));
+ try (ITConnection connection =
+ createConnection(createAbortFirstRetryListener(invalidStatement, notFound))) {
+ try {
+ connection.executeBatchUpdate(Collections.singletonList(invalidStatement));
+ fail("missing expected exception");
+ } catch (SpannerException e) {
+ assertThat(e.getErrorCode()).isEqualTo(ErrorCode.NOT_FOUND);
+ }
+ connection.execute(INSERT_STATEMENT);
+ // Force an abort and retry.
+ mockSpanner.abortNextStatement();
+ connection.commit();
+ }
+ assertThat(mockSpanner.countRequestsOfType(CommitRequest.class)).isEqualTo(2);
+ assertThat(mockSpanner.countRequestsOfType(ExecuteBatchDmlRequest.class)).isEqualTo(6);
+ }
+
@Test
public void testRetryUsesTags() {
mockSpanner.putStatementResult(
@@ -303,6 +440,63 @@ public void testRetryUsesTags() {
}
@Test
+ public void testRetryUsesTagsWithUpdateReturning() {
+ mockSpanner.putStatementResult(
+ StatementResult.query(SELECT_COUNT_STATEMENT, SELECT_COUNT_RESULTSET_BEFORE_INSERT));
+ mockSpanner.putStatementResult(StatementResult.update(INSERT_STATEMENT, UPDATE_COUNT));
+ mockSpanner.putStatementResult(
+ StatementResult.updateReturning(INSERT_RETURNING_STATEMENT, UPDATE_RETURNING_RESULTSET));
+ try (ITConnection connection = createConnection()) {
+ connection.setTransactionTag("transaction-tag");
+ connection.setStatementTag("statement-tag");
+ connection.executeUpdate(INSERT_STATEMENT);
+ connection.setStatementTag("statement-tag");
+ connection.executeBatchUpdate(ImmutableList.of(INSERT_STATEMENT, INSERT_RETURNING_STATEMENT));
+ connection.setStatementTag("statement-tag");
+ connection.executeQuery(SELECT_COUNT_STATEMENT);
+ connection.setStatementTag("statement-tag");
+ connection.executeQuery(INSERT_RETURNING_STATEMENT);
+
+ mockSpanner.abortNextStatement();
+ connection.commit();
+ }
+ long executeSqlRequestCount =
+ mockSpanner.getRequestsOfType(ExecuteSqlRequest.class).stream()
+ .filter(
+ request ->
+ request.getRequestOptions().getRequestTag().equals("statement-tag")
+ && request
+ .getRequestOptions()
+ .getTransactionTag()
+ .equals("transaction-tag"))
+ .count();
+ assertEquals(6L, executeSqlRequestCount);
+
+ long executeBatchSqlRequestCount =
+ mockSpanner.getRequestsOfType(ExecuteBatchDmlRequest.class).stream()
+ .filter(
+ request ->
+ request.getRequestOptions().getRequestTag().equals("statement-tag")
+ && request
+ .getRequestOptions()
+ .getTransactionTag()
+ .equals("transaction-tag"))
+ .count();
+ assertEquals(2L, executeBatchSqlRequestCount);
+
+ long commitRequestCount =
+ mockSpanner.getRequestsOfType(CommitRequest.class).stream()
+ .filter(
+ request ->
+ request.getRequestOptions().getRequestTag().equals("")
+ && request
+ .getRequestOptions()
+ .getTransactionTag()
+ .equals("transaction-tag"))
+ .count();
+ assertEquals(2L, commitRequestCount);
+ }
+
public void testRetryUsesAnalyzeModeForUpdate() {
mockSpanner.putStatementResult(
StatementResult.query(SELECT_COUNT_STATEMENT, SELECT_COUNT_RESULTSET_BEFORE_INSERT));
diff --git a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/connection/AbstractMockServerTest.java b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/connection/AbstractMockServerTest.java
index 7f9163e8e46..7b65b44f361 100644
--- a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/connection/AbstractMockServerTest.java
+++ b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/connection/AbstractMockServerTest.java
@@ -37,6 +37,7 @@
import com.google.protobuf.Value;
import com.google.spanner.v1.ExecuteSqlRequest;
import com.google.spanner.v1.ResultSetMetadata;
+import com.google.spanner.v1.ResultSetStats;
import com.google.spanner.v1.StructType;
import com.google.spanner.v1.StructType.Field;
import com.google.spanner.v1.Type;
@@ -96,8 +97,23 @@ public abstract class AbstractMockServerTest {
.build())
.setMetadata(SELECT_COUNT_METADATA)
.build();
+ public static final com.google.spanner.v1.ResultSet UPDATE_RETURNING_RESULTSET =
+ com.google.spanner.v1.ResultSet.newBuilder()
+ .setStats(ResultSetStats.newBuilder().setRowCountExact(1))
+ .setMetadata(
+ ResultSetMetadata.newBuilder()
+ .setRowType(
+ StructType.newBuilder()
+ .addFields(
+ Field.newBuilder()
+ .setName("col")
+ .setType(Type.newBuilder().setCodeValue(TypeCode.INT64_VALUE))
+ .build())))
+ .build();
public static final Statement INSERT_STATEMENT =
Statement.of("INSERT INTO TEST (ID, NAME) VALUES (1, 'test aborted')");
+ public static final Statement INSERT_RETURNING_STATEMENT =
+ Statement.of("INSERT INTO TEST (ID, NAME) VALUES (1, 'test aborted') THEN RETURN *");
public static final long UPDATE_COUNT = 1L;
public static final int RANDOM_RESULT_SET_ROW_COUNT = 100;
@@ -149,6 +165,8 @@ public void getOperation(
mockSpanner.putStatementResult(
StatementResult.query(SELECT_COUNT_STATEMENT, SELECT_COUNT_RESULTSET_BEFORE_INSERT));
mockSpanner.putStatementResult(StatementResult.update(INSERT_STATEMENT, UPDATE_COUNT));
+ mockSpanner.putStatementResult(
+ StatementResult.updateReturning(INSERT_RETURNING_STATEMENT, UPDATE_RETURNING_RESULTSET));
mockSpanner.putStatementResult(
StatementResult.query(SELECT_RANDOM_STATEMENT, RANDOM_RESULT_SET));
diff --git a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/connection/StatementParserTest.java b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/connection/StatementParserTest.java
index 2e39025bf33..56760d900af 100644
--- a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/connection/StatementParserTest.java
+++ b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/connection/StatementParserTest.java
@@ -1266,6 +1266,215 @@ public void testPostgreSQLGetQueryParameters() {
parser.getQueryParameters("select '$2' from foo where bar=$1 and baz=$foo"));
}
+ @Test
+ public void testGoogleSQLReturningClause() {
+ assumeTrue(dialect == Dialect.GOOGLE_STANDARD_SQL);
+
+ SpannerStatementParser parser = (SpannerStatementParser) this.parser;
+ assertTrue(
+ parser
+ .parse(Statement.of("insert into x (a,b) values (1,2) then return *"))
+ .hasReturningClause());
+ assertTrue(
+ parser
+ .parse(Statement.of("insert into x (a,b) values (1,2) then\nreturn *"))
+ .hasReturningClause());
+ assertTrue(
+ parser
+ .parse(Statement.of("insert into x (a,b) values (1,2)\nthen\n\n\nreturn\n*"))
+ .hasReturningClause());
+ assertTrue(
+ parser
+ .parse(Statement.of("insert into x (a,b) values (1,2)then return *"))
+ .hasReturningClause());
+ assertTrue(
+ parser
+ .parse(Statement.of("insert into x (a,b) values (1,2) then return(a)"))
+ .hasReturningClause());
+ assertTrue(
+ parser
+ .parse(Statement.of("insert into x (a,b) values (1,2)then return(a)"))
+ .hasReturningClause());
+ assertTrue(
+ parser
+ .parse(Statement.of("insert into x (a,b) values (1,2) then/*comment*/return *"))
+ .hasReturningClause());
+ assertTrue(
+ parser
+ .parse(Statement.of("insert into x (a,b) values (1,2) then return /*then return*/ *"))
+ .hasReturningClause());
+ assertTrue(
+ parser
+ .parse(Statement.of("insert into x (a,b) values (1,2)then/*comment*/return *"))
+ .hasReturningClause());
+ assertTrue(
+ parser
+ .parse(Statement.of("insert into x (a,b) values (1,2)then/*comment*/return(a)"))
+ .hasReturningClause());
+ assertTrue(
+ parser
+ .parse(Statement.of("insert into x (a,b) values (1,2)then /*comment*/return(a)"))
+ .hasReturningClause());
+ assertTrue(
+ parser
+ .parse(
+ Statement.of("insert into x (a,b) values (1,2)/*comment*/then/*comment*/return(a)"))
+ .hasReturningClause());
+ assertTrue(
+ parser
+ .parse(
+ Statement.of(
+ "insert into x (a,b) values (1,2)/*comment*/then/*comment*/return/*comment*/(a)"))
+ .hasReturningClause());
+ assertTrue(
+ parser
+ .parse(
+ Statement.of(
+ "insert into x (a,b) values (1,2)/*comment"
+ + "*/then"
+ + "/*comment"
+ + "*/return/*"
+ + "comment*/(a)"))
+ .hasReturningClause());
+ assertTrue(
+ parser
+ .parse(Statement.of("delete from x where y=\"z\"then return *"))
+ .hasReturningClause());
+ assertTrue(
+ parser
+ .parse(Statement.of("insert into x select 'then return' as returning then return *"))
+ .hasReturningClause());
+ assertTrue(
+ parser.parse(Statement.of("delete from x where 10=`z`then return *")).hasReturningClause());
+ assertFalse(
+ parser
+ .parse(Statement.of("insert into x (a,b) values (1,2) returning (a)"))
+ .hasReturningClause());
+ assertFalse(
+ parser
+ .parse(Statement.of("insert into x (a,b) values (1,2) /*then return **/"))
+ .hasReturningClause());
+ assertFalse(
+ parser.parse(Statement.of("insert into x (a,b) values (1,2)")).hasReturningClause());
+ assertFalse(
+ parser
+ .parse(Statement.of("insert into x (a,b) values (1,2)thenreturn*"))
+ .hasReturningClause());
+ assertTrue(
+ parser
+ .parse(Statement.of("insert into t(a) select \"x\"then return*"))
+ .hasReturningClause());
+ }
+
+ @Test
+ public void testPostgreSQLReturningClause() {
+ assumeTrue(dialect == Dialect.POSTGRESQL);
+
+ PostgreSQLStatementParser parser = (PostgreSQLStatementParser) this.parser;
+ assertTrue(
+ parser
+ .parse(Statement.of("insert into x (a,b) values (1,2) returning *"))
+ .hasReturningClause());
+ assertTrue(
+ parser
+ .parse(Statement.of("insert into x (a,b) values (1,2)returning *"))
+ .hasReturningClause());
+ assertTrue(
+ parser
+ .parse(Statement.of("insert into x (a,b) values (1,2) returning(a)"))
+ .hasReturningClause());
+ assertTrue(
+ parser
+ .parse(Statement.of("insert into x (a,b) values (1,2)returning(a)"))
+ .hasReturningClause());
+ assertTrue(
+ parser
+ .parse(Statement.of("insert into x (a,b) values (1,2)/*comment*/returning(a)"))
+ .hasReturningClause());
+ assertTrue(
+ parser
+ .parse(
+ Statement.of("insert into x (a,b) values (1,2)/*comment*/returning/*comment*/(a)"))
+ .hasReturningClause());
+ assertTrue(
+ parser
+ .parse(
+ Statement.of(
+ "insert into x (a,b) values (1,2)/*comment" + "*/returning/*" + "comment*/(a)"))
+ .hasReturningClause());
+ assertTrue(
+ parser
+ .parse(Statement.of("insert into x select 1 as returning returning *"))
+ .hasReturningClause());
+ assertTrue(
+ parser
+ .parse(Statement.of("insert into x select 'returning' as returning returning *"))
+ .hasReturningClause());
+ assertTrue(
+ parser
+ .parse(Statement.of("insert into x select 'returning'as returning returning *"))
+ .hasReturningClause());
+ assertTrue(
+ parser.parse(Statement.of("delete from x where y=\"z\"returning *")).hasReturningClause());
+ assertTrue(
+ parser.parse(Statement.of("delete from x where y='z'returning *")).hasReturningClause());
+ assertTrue(
+ parser
+ .parse(Statement.of("insert into t1 select 1/*as /*returning*/ returning*/returning *"))
+ .hasReturningClause());
+ assertFalse(
+ parser
+ .parse(Statement.of("insert into x (a,b) values (1,2) then return (a)"))
+ .hasReturningClause());
+ assertFalse(
+ parser.parse(Statement.of("insert into x (a,b) values (1,2)")).hasReturningClause());
+ assertFalse(
+ parser.parse(Statement.of("insert into t1 select 1 as returning")).hasReturningClause());
+ assertFalse(
+ parser
+ .parse(Statement.of("insert into t1 select 1\nas\n\nreturning"))
+ .hasReturningClause());
+ assertFalse(
+ parser.parse(Statement.of("insert into t1 select 1asreturning")).hasReturningClause());
+ assertTrue(
+ parser
+ .parse(Statement.of("insert into t1 select 1 as/*eomment*/returning returning *"))
+ .hasReturningClause());
+ assertFalse(
+ parser
+ .parse(Statement.of("UPDATE x SET y = $$ RETURNING a, b, c$$ WHERE z = 123"))
+ .hasReturningClause());
+ assertFalse(
+ parser
+ .parse(
+ Statement.of("UPDATE x SET y = $foobar$ RETURNING a, b, c$foobar$ WHERE z = 123"))
+ .hasReturningClause());
+ assertFalse(
+ parser
+ .parse(Statement.of("UPDATE x SET y = $returning$ returning $returning$ WHERE z = 123"))
+ .hasReturningClause());
+ assertTrue(
+ parser
+ .parse(
+ Statement.of(
+ "UPDATE x SET y = $returning$returning$returning$ WHERE z = 123 ReTuRnInG *"))
+ .hasReturningClause());
+ assertTrue(
+ parser.parse(Statement.of("insert into t1 select 1 returning*")).hasReturningClause());
+ assertTrue(
+ parser.parse(Statement.of("insert into t1 select 2returning*")).hasReturningClause());
+ assertTrue(
+ parser.parse(Statement.of("insert into t1 select 10e2returning*")).hasReturningClause());
+ assertFalse(
+ parser
+ .parse(Statement.of("insert into t1 select 'test''returning *'"))
+ .hasReturningClause());
+ assertTrue(
+ parser.parse(Statement.of("insert into t select 2,3returning*")).hasReturningClause());
+ assertTrue(
+ parser.parse(Statement.of("insert into t1 select 10.returning*")).hasReturningClause());
+ }
+
private void assertUnclosedLiteral(String sql) {
try {
parser.convertPositionalParametersToNamedParameters('?', sql);
diff --git a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/connection/it/ITDmlReturningTest.java b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/connection/it/ITDmlReturningTest.java
new file mode 100644
index 00000000000..5e5ca800402
--- /dev/null
+++ b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/connection/it/ITDmlReturningTest.java
@@ -0,0 +1,344 @@
+/*
+ * Copyright 2022 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.google.cloud.spanner.connection.it;
+
+import static org.junit.Assert.assertArrayEquals;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertNotNull;
+import static org.junit.Assert.assertThrows;
+import static org.junit.Assert.assertTrue;
+import static org.junit.Assume.assumeFalse;
+
+import com.google.cloud.spanner.AsyncResultSet;
+import com.google.cloud.spanner.AsyncResultSet.CallbackResponse;
+import com.google.cloud.spanner.Dialect;
+import com.google.cloud.spanner.ErrorCode;
+import com.google.cloud.spanner.Mutation;
+import com.google.cloud.spanner.ParallelIntegrationTest;
+import com.google.cloud.spanner.ResultSet;
+import com.google.cloud.spanner.SpannerException;
+import com.google.cloud.spanner.Statement;
+import com.google.cloud.spanner.connection.AsyncStatementResult;
+import com.google.cloud.spanner.connection.Connection;
+import com.google.cloud.spanner.connection.ITAbstractSpannerTest;
+import com.google.cloud.spanner.connection.StatementResult;
+import com.google.cloud.spanner.connection.StatementResult.ResultType;
+import com.google.cloud.spanner.connection.TransactionMode;
+import com.google.cloud.spanner.testing.EmulatorSpannerHelper;
+import com.google.common.collect.ImmutableList;
+import com.google.common.collect.ImmutableMap;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.Executors;
+import org.junit.Before;
+import org.junit.Test;
+import org.junit.experimental.categories.Category;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+import org.junit.runners.Parameterized.Parameter;
+import org.junit.runners.Parameterized.Parameters;
+
+/** Execute DML Returning statements using the generic connection API. */
+@Category(ParallelIntegrationTest.class)
+@RunWith(Parameterized.class)
+public class ITDmlReturningTest extends ITAbstractSpannerTest {
+ private final ImmutableMap UPDATE_RETURNING_MAP =
+ ImmutableMap.of(
+ Dialect.GOOGLE_STANDARD_SQL,
+ Statement.of("UPDATE Singers SET LastName = 'XYZ' WHERE FirstName = 'ABC' THEN RETURN *"),
+ Dialect.POSTGRESQL,
+ Statement.of("UPDATE Singers SET LastName = 'XYZ' WHERE FirstName = 'ABC' RETURNING *"));
+ private final ImmutableMap DDL_MAP =
+ ImmutableMap.of(
+ Dialect.GOOGLE_STANDARD_SQL,
+ "CREATE TABLE Singers ("
+ + " SingerId INT64,"
+ + " FirstName STRING(1024),"
+ + " LastName STRING(1024)"
+ + ") PRIMARY KEY(SingerId)",
+ Dialect.POSTGRESQL,
+ "CREATE TABLE Singers ("
+ + " SingerId BIGINT PRIMARY KEY,"
+ + " FirstName character varying(1024),"
+ + " LastName character varying(1024))");
+ private final Map IS_INITIALIZED = new HashMap<>();
+
+ public ITDmlReturningTest() {
+ IS_INITIALIZED.put(Dialect.GOOGLE_STANDARD_SQL, false);
+ IS_INITIALIZED.put(Dialect.POSTGRESQL, false);
+ }
+
+ @Parameter public Dialect dialect;
+
+ @Parameters(name = "dialect = {0}")
+ public static Object[] data() {
+ return Dialect.values();
+ }
+
+ private boolean checkAndSetInitialized() {
+ if ((dialect == Dialect.GOOGLE_STANDARD_SQL) && !IS_INITIALIZED.get(dialect)) {
+ IS_INITIALIZED.put(dialect, true);
+ return true;
+ }
+ if ((dialect == Dialect.POSTGRESQL) && !IS_INITIALIZED.get(dialect)) {
+ IS_INITIALIZED.put(dialect, true);
+ return true;
+ }
+ return false;
+ }
+
+ @Before
+ public void setupTable() {
+ assumeFalse(
+ "DML Returning is not supported in the emulator", EmulatorSpannerHelper.isUsingEmulator());
+ if (checkAndSetInitialized()) {
+ database =
+ env.getTestHelper()
+ .createTestDatabase(dialect, Collections.singleton(DDL_MAP.get(dialect)));
+ List firstNames = Arrays.asList("ABC", "ABC", "DEF", "PQR", "ABC");
+ List lastNames = Arrays.asList("XYZ", "DEF", "XYZ", "ABC", "GHI");
+ List mutations = new ArrayList<>();
+ for (int id = 1; id <= 5; id++) {
+ mutations.add(
+ Mutation.newInsertBuilder("SINGERS")
+ .set("SINGERID")
+ .to(id)
+ .set("FIRSTNAME")
+ .to(firstNames.get(id - 1))
+ .set("LASTNAME")
+ .to(lastNames.get(id - 1))
+ .build());
+ }
+ env.getTestHelper().getDatabaseClient(database).write(mutations);
+ }
+ }
+
+ @Test
+ public void testDmlReturningExecuteQuery() {
+ try (Connection connection = createConnection()) {
+ try (ResultSet rs = connection.executeQuery(UPDATE_RETURNING_MAP.get(dialect))) {
+ assertEquals(rs.getColumnCount(), 3);
+ assertTrue(rs.next());
+ assertEquals(rs.getString(1), "ABC");
+ assertTrue(rs.next());
+ assertEquals(rs.getString(1), "ABC");
+ assertTrue(rs.next());
+ assertEquals(rs.getString(1), "ABC");
+ assertFalse(rs.next());
+ assertNotNull(rs.getStats());
+ assertEquals(rs.getStats().getRowCountExact(), 3);
+ }
+ }
+ }
+
+ @Test
+ public void testDmlReturningExecuteQueryAsync() {
+ try (Connection connection = createConnection()) {
+ try (AsyncResultSet rs = connection.executeQueryAsync(UPDATE_RETURNING_MAP.get(dialect))) {
+ rs.setCallback(
+ Executors.newSingleThreadExecutor(),
+ resultSet -> {
+ try {
+ while (true) {
+ switch (resultSet.tryNext()) {
+ case OK:
+ assertEquals(resultSet.getColumnCount(), 3);
+ assertEquals(resultSet.getString(1), "ABC");
+ break;
+ case DONE:
+ assertNotNull(resultSet.getStats());
+ assertEquals(resultSet.getStats().getRowCountExact(), 3);
+ return CallbackResponse.DONE;
+ case NOT_READY:
+ return CallbackResponse.CONTINUE;
+ default:
+ throw new IllegalStateException();
+ }
+ }
+ } catch (SpannerException e) {
+ return CallbackResponse.DONE;
+ }
+ });
+ }
+ }
+ }
+
+ @Test
+ public void testDmlReturningExecuteUpdate() {
+ try (Connection connection = createConnection()) {
+ connection.setAutocommit(false);
+ SpannerException e =
+ assertThrows(
+ SpannerException.class,
+ () -> connection.executeUpdate(UPDATE_RETURNING_MAP.get(dialect)));
+ assertEquals(e.getErrorCode(), ErrorCode.FAILED_PRECONDITION);
+ }
+ }
+
+ @Test
+ public void testDmlReturningExecuteUpdateAsync() {
+ try (Connection connection = createConnection()) {
+ connection.setAutocommit(false);
+ SpannerException e =
+ assertThrows(
+ SpannerException.class,
+ () -> connection.executeUpdateAsync(UPDATE_RETURNING_MAP.get(dialect)));
+ assertEquals(e.getErrorCode(), ErrorCode.FAILED_PRECONDITION);
+ }
+ }
+
+ @Test
+ public void testDmlReturningExecuteBatchUpdate() {
+ try (Connection connection = createConnection()) {
+ connection.setAutocommit(false);
+ final Statement UPDATE_STMT = UPDATE_RETURNING_MAP.get(dialect);
+ long[] counts =
+ connection.executeBatchUpdate(ImmutableList.of(UPDATE_STMT, UPDATE_STMT, UPDATE_STMT));
+ assertArrayEquals(counts, new long[] {3, 3, 3});
+ }
+ }
+
+ @Test
+ public void testDmlReturningExecuteBatchUpdateAsync() {
+ try (Connection connection = createConnection()) {
+ connection.setAutocommit(false);
+ final Statement UPDATE_STMT = UPDATE_RETURNING_MAP.get(dialect);
+ long[] counts =
+ connection
+ .executeBatchUpdateAsync(ImmutableList.of(UPDATE_STMT, UPDATE_STMT, UPDATE_STMT))
+ .get();
+ assertArrayEquals(counts, new long[] {3, 3, 3});
+ } catch (ExecutionException | InterruptedException e) {
+ // ignore
+ }
+ }
+
+ @Test
+ public void testDmlReturningExecute() {
+ try (Connection connection = createConnection()) {
+ connection.setAutocommit(false);
+ StatementResult res = connection.execute(UPDATE_RETURNING_MAP.get(dialect));
+ assertEquals(res.getResultType(), ResultType.RESULT_SET);
+ try (ResultSet rs = res.getResultSet()) {
+ assertEquals(rs.getColumnCount(), 3);
+ assertTrue(rs.next());
+ assertEquals(rs.getString(1), "ABC");
+ assertTrue(rs.next());
+ assertEquals(rs.getString(1), "ABC");
+ assertTrue(rs.next());
+ assertEquals(rs.getString(1), "ABC");
+ assertFalse(rs.next());
+ assertNotNull(rs.getStats());
+ assertEquals(rs.getStats().getRowCountExact(), 3);
+ }
+ }
+ }
+
+ @Test
+ public void testDmlReturningExecuteAsync() {
+ try (Connection connection = createConnection()) {
+ connection.setAutocommit(false);
+ AsyncStatementResult res = connection.executeAsync(UPDATE_RETURNING_MAP.get(dialect));
+ assertEquals(res.getResultType(), ResultType.RESULT_SET);
+ try (AsyncResultSet rs = res.getResultSetAsync()) {
+ rs.setCallback(
+ Executors.newSingleThreadExecutor(),
+ resultSet -> {
+ try {
+ while (true) {
+ switch (resultSet.tryNext()) {
+ case OK:
+ assertEquals(resultSet.getColumnCount(), 3);
+ assertEquals(resultSet.getString(1), "ABC");
+ break;
+ case DONE:
+ assertNotNull(resultSet.getStats());
+ assertEquals(resultSet.getStats().getRowCountExact(), 3);
+ return CallbackResponse.DONE;
+ case NOT_READY:
+ return CallbackResponse.CONTINUE;
+ default:
+ throw new IllegalStateException();
+ }
+ }
+ } catch (SpannerException e) {
+ System.out.printf("Error in callback: %s%n", e.getMessage());
+ return CallbackResponse.DONE;
+ }
+ });
+ }
+ }
+ }
+
+ @Test
+ public void testDmlReturningExecuteQueryReadOnlyMode() {
+ try (Connection connection = createConnection()) {
+ connection.setReadOnly(true);
+ SpannerException e =
+ assertThrows(
+ SpannerException.class,
+ () -> connection.executeQuery(UPDATE_RETURNING_MAP.get(dialect)));
+ assertEquals(e.getErrorCode(), ErrorCode.FAILED_PRECONDITION);
+ }
+ }
+
+ @Test
+ public void testDmlReturningExecuteQueryReadOnlyTransaction() {
+ try (Connection connection = createConnection()) {
+ connection.setReadOnly(false);
+ connection.setAutocommit(false);
+ connection.setTransactionMode(TransactionMode.READ_ONLY_TRANSACTION);
+ SpannerException e =
+ assertThrows(
+ SpannerException.class,
+ () -> connection.executeQuery(UPDATE_RETURNING_MAP.get(dialect)));
+ assertEquals(e.getErrorCode(), ErrorCode.FAILED_PRECONDITION);
+ }
+ }
+
+ @Test
+ public void testDmlReturningExecuteQueryAsyncReadOnlyMode() {
+ try (Connection connection = createConnection()) {
+ connection.setReadOnly(true);
+ SpannerException e =
+ assertThrows(
+ SpannerException.class,
+ () -> connection.executeQueryAsync(UPDATE_RETURNING_MAP.get(dialect)));
+ assertEquals(e.getErrorCode(), ErrorCode.FAILED_PRECONDITION);
+ }
+ }
+
+ @Test
+ public void testDmlReturningExecuteQueryAsyncReadOnlyTransaction() {
+ try (Connection connection = createConnection()) {
+ connection.setReadOnly(false);
+ connection.setAutocommit(false);
+ connection.setTransactionMode(TransactionMode.READ_ONLY_TRANSACTION);
+ SpannerException e =
+ assertThrows(
+ SpannerException.class,
+ () -> connection.executeQueryAsync(UPDATE_RETURNING_MAP.get(dialect)));
+ assertEquals(e.getErrorCode(), ErrorCode.FAILED_PRECONDITION);
+ }
+ }
+}