From aac20bedf9ee7a6a2170f87fa88373b7d364ed9f Mon Sep 17 00:00:00 2001 From: Rajat Bhatta <93644539+rajatbhatta@users.noreply.github.com> Date: Wed, 9 Nov 2022 17:49:08 +0530 Subject: [PATCH] feat: support DML with Returning clause in Connection API (#1978) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * feat: incorporate dml with returning clause * feat: changes * feat: change handling of AsyncResultSet. * fix: lint * doc: add comments * fix: lint * test: add tests for executeBatchUpdate * test: import fix * test: remove redundant import * test: add abort tests for dml returning * test: add pg dml returning tests * feat: change error statement * doc: add doc for dml with returning clause usage * Update google-cloud-spanner/src/main/java/com/google/cloud/spanner/connection/ConnectionImpl.java Co-authored-by: Knut Olav Løite * Update google-cloud-spanner/src/main/java/com/google/cloud/spanner/connection/ConnectionImpl.java Co-authored-by: Knut Olav Løite * fix: incorporate review comments * test: add more test cases * test: add todo * test: add separate abort tests for dml returning * fix: add try-with-resources block around ResultSet * feat: enhancement by adding a pre-check * feat: changes * test: delete unnecessary test * test: add few more tests to PG parser * feat: method doc update * test: nit fixes * feat: handle another corner case * test: add another test * clirr fixes * revert env for integration tests * remove comments * skip returning tests in emulator * fix: linting Co-authored-by: Knut Olav Løite --- .../clirr-ignored-differences.xml | 5 + .../connection/AbstractStatementParser.java | 51 ++- .../cloud/spanner/connection/Connection.java | 53 +-- .../spanner/connection/ConnectionImpl.java | 66 +++- .../connection/PostgreSQLStatementParser.java | 99 ++++- .../connection/ReadWriteTransaction.java | 6 +- .../connection/SingleUseTransaction.java | 36 +- .../connection/SpannerStatementParser.java | 91 +++++ .../cloud/spanner/MockSpannerServiceImpl.java | 25 +- .../cloud/spanner/connection/AbortedTest.java | 194 ++++++++++ .../connection/AbstractMockServerTest.java | 18 + .../connection/StatementParserTest.java | 209 +++++++++++ .../connection/it/ITDmlReturningTest.java | 344 ++++++++++++++++++ 13 files changed, 1150 insertions(+), 47 deletions(-) create mode 100644 google-cloud-spanner/src/test/java/com/google/cloud/spanner/connection/it/ITDmlReturningTest.java diff --git a/google-cloud-spanner/clirr-ignored-differences.xml b/google-cloud-spanner/clirr-ignored-differences.xml index 039278a6436..adda4c3c854 100644 --- a/google-cloud-spanner/clirr-ignored-differences.xml +++ b/google-cloud-spanner/clirr-ignored-differences.xml @@ -202,4 +202,9 @@ com/google/cloud/spanner/DatabaseClient java.lang.String getDatabaseRole() + + 7013 + com/google/cloud/spanner/connection/AbstractStatementParser + boolean checkReturningClauseInternal(java.lang.String) + diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/connection/AbstractStatementParser.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/connection/AbstractStatementParser.java index fb272e913e5..b0ae8863e6f 100644 --- a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/connection/AbstractStatementParser.java +++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/connection/AbstractStatementParser.java @@ -140,6 +140,7 @@ public static class ParsedStatement { private final ClientSideStatementImpl clientSideStatement; private final Statement statement; private final String sqlWithoutComments; + private final boolean returningClause; private static ParsedStatement clientSideStatement( ClientSideStatementImpl clientSideStatement, @@ -155,11 +156,13 @@ private static ParsedStatement ddl(Statement statement, String sqlWithoutComment private static ParsedStatement query( Statement statement, String sqlWithoutComments, QueryOptions defaultQueryOptions) { return new ParsedStatement( - StatementType.QUERY, statement, sqlWithoutComments, defaultQueryOptions); + StatementType.QUERY, statement, sqlWithoutComments, defaultQueryOptions, false); } - private static ParsedStatement update(Statement statement, String sqlWithoutComments) { - return new ParsedStatement(StatementType.UPDATE, statement, sqlWithoutComments); + private static ParsedStatement update( + Statement statement, String sqlWithoutComments, boolean returningClause) { + return new ParsedStatement( + StatementType.UPDATE, statement, sqlWithoutComments, returningClause); } private static ParsedStatement unknown(Statement statement, String sqlWithoutComments) { @@ -176,23 +179,34 @@ private ParsedStatement( this.clientSideStatement = clientSideStatement; this.statement = statement; this.sqlWithoutComments = sqlWithoutComments; + this.returningClause = false; + } + + private ParsedStatement( + StatementType type, + Statement statement, + String sqlWithoutComments, + boolean returningClause) { + this(type, statement, sqlWithoutComments, null, returningClause); } private ParsedStatement(StatementType type, Statement statement, String sqlWithoutComments) { - this(type, statement, sqlWithoutComments, null); + this(type, statement, sqlWithoutComments, null, false); } private ParsedStatement( StatementType type, Statement statement, String sqlWithoutComments, - QueryOptions defaultQueryOptions) { + QueryOptions defaultQueryOptions, + boolean returningClause) { Preconditions.checkNotNull(type); Preconditions.checkNotNull(statement); this.type = type; this.clientSideStatement = null; this.statement = mergeQueryOptions(statement, defaultQueryOptions); this.sqlWithoutComments = sqlWithoutComments; + this.returningClause = returningClause; } @Override @@ -219,6 +233,12 @@ public StatementType getType() { return type; } + /** Returns whether the statement has a returning clause or not. * */ + @InternalApi + public boolean hasReturningClause() { + return this.returningClause; + } + /** * Returns true if the statement is a query that will return a {@link * com.google.cloud.spanner.ResultSet}. @@ -355,7 +375,7 @@ ParsedStatement parse(Statement statement, QueryOptions defaultQueryOptions) { } else if (isQuery(sql)) { return ParsedStatement.query(statement, sql, defaultQueryOptions); } else if (isUpdateStatement(sql)) { - return ParsedStatement.update(statement, sql); + return ParsedStatement.update(statement, sql, checkReturningClause(sql)); } else if (isDdlStatement(sql)) { return ParsedStatement.ddl(statement, sql); } @@ -460,6 +480,10 @@ private boolean statementStartsWith(String sql, Iterable checkStatements static final char SLASH = '/'; static final char ASTERISK = '*'; static final char DOLLAR = '$'; + static final char SPACE = ' '; + static final char CLOSE_PARENTHESIS = ')'; + static final char COMMA = ','; + static final char UNDERSCORE = '_'; /** * Removes comments from and trims the given sql statement using the dialect of this parser. @@ -522,4 +546,19 @@ static int countOccurrencesOf(char c, String string) { } return res; } + + /** + * Checks if the given SQL string contains a Returning clause. This method is used only in case of + * a DML statement. + * + * @param sql The sql string without comments that has to be evaluated. + * @return A boolean indicating whether the sql string has a Returning clause or not. + */ + @InternalApi + protected abstract boolean checkReturningClauseInternal(String sql); + + @InternalApi + public boolean checkReturningClause(String sql) { + return checkReturningClauseInternal(sql); + } } diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/connection/Connection.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/connection/Connection.java index 04ea893cb9c..afb3723c143 100644 --- a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/connection/Connection.java +++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/connection/Connection.java @@ -856,8 +856,8 @@ default RpcPriority getRPCPriority() { * state. The returned value depends on the type of statement: * *
    - *
  • Queries will return a {@link ResultSet} - *
  • DML statements will return an update count + *
  • Queries and DML statements with returning clause will return a {@link ResultSet}. + *
  • Simple DML statements will return an update count *
  • DDL statements will return a {@link ResultType#NO_RESULT} *
  • Connection and transaction statements (SET AUTOCOMMIT=TRUE|FALSE, SHOW AUTOCOMMIT, SET * TRANSACTION READ ONLY, etc) will return either a {@link ResultSet} or {@link @@ -874,9 +874,9 @@ default RpcPriority getRPCPriority() { * state asynchronously. The returned value depends on the type of statement: * *
      - *
    • Queries will return an {@link AsyncResultSet} - *
    • DML statements will return an {@link ApiFuture} with an update count that is done when - * the DML statement has been applied successfully, or that throws an {@link + *
    • Queries and DML statements with returning clause will return an {@link AsyncResultSet}. + *
    • Simple DML statements will return an {@link ApiFuture} with an update count that is done + * when the DML statement has been applied successfully, or that throws an {@link * ExecutionException} if the DML statement failed. *
    • DDL statements will return an {@link ApiFuture} containing a {@link Void} that is done * when the DDL statement has been applied successfully, or that throws an {@link @@ -894,31 +894,33 @@ default RpcPriority getRPCPriority() { AsyncStatementResult executeAsync(Statement statement); /** - * Executes the given statement as a query and returns the result as a {@link ResultSet}. This - * method blocks and waits for a response from Spanner. If the statement does not contain a valid - * query, the method will throw a {@link SpannerException}. + * Executes the given statement (a query or a DML statement with returning clause) and returns the + * result as a {@link ResultSet}. This method blocks and waits for a response from Spanner. If the + * statement does not contain a valid query or a DML statement with returning clause, the method + * will throw a {@link SpannerException}. * - * @param query The query statement to execute + * @param query The query statement or DML statement with returning clause to execute * @param options the options to configure the query - * @return a {@link ResultSet} with the results of the query + * @return a {@link ResultSet} with the results of the statement */ ResultSet executeQuery(Statement query, QueryOption... options); /** - * Executes the given statement asynchronously as a query and returns the result as an {@link - * AsyncResultSet}. This method is guaranteed to be non-blocking. If the statement does not - * contain a valid query, the method will throw a {@link SpannerException}. + * Executes the given statement (a query or a DML statement with returning clause) asynchronously + * and returns the result as an {@link AsyncResultSet}. This method is guaranteed to be + * non-blocking. If the statement does not contain a valid query or a DML statement with returning + * clause, the method will throw a {@link SpannerException}. * *

      See {@link AsyncResultSet#setCallback(java.util.concurrent.Executor, * com.google.cloud.spanner.AsyncResultSet.ReadyCallback)} for more information on how to consume - * the results of the query asynchronously. + * the results of the statement asynchronously. * *

      It is also possible to consume the returned {@link AsyncResultSet} in the same way as a * normal {@link ResultSet}, i.e. in a while-loop calling {@link AsyncResultSet#next()}. * - * @param query The query statement to execute + * @param query The query statement or DML statement with returning clause to execute * @param options the options to configure the query - * @return an {@link AsyncResultSet} with the results of the query + * @return an {@link AsyncResultSet} with the results of the statement */ AsyncResultSet executeQueryAsync(Statement query, QueryOption... options); @@ -951,8 +953,8 @@ default RpcPriority getRPCPriority() { ResultSet analyzeQuery(Statement query, QueryAnalyzeMode queryMode); /** - * Executes the given statement as a DML statement. If the statement does not contain a valid DML - * statement, the method will throw a {@link SpannerException}. + * Executes the given statement as a simple DML statement. If the statement does not contain a + * valid DML statement, the method will throw a {@link SpannerException}. * * @param update The update statement to execute * @return the number of records that were inserted/updated/deleted by this statement @@ -972,8 +974,9 @@ default ResultSetStats analyzeUpdate(Statement update, QueryAnalyzeMode analyzeM } /** - * Executes the given statement asynchronously as a DML statement. If the statement does not - * contain a valid DML statement, the method will throw a {@link SpannerException}. + * Executes the given statement asynchronously as a simple DML statement. If the statement does + * not contain a simple DML statement, the method will throw a {@link SpannerException}. A DML + * statement with returning clause will throw a {@link SpannerException}. * *

      This method is guaranteed to be non-blocking. * @@ -984,8 +987,9 @@ default ResultSetStats analyzeUpdate(Statement update, QueryAnalyzeMode analyzeM ApiFuture executeUpdateAsync(Statement update); /** - * Executes a list of DML statements in a single request. The statements will be executed in order - * and the semantics is the same as if each statement is executed by {@link + * Executes a list of DML statements (can be simple DML statements or DML statements with + * returning clause) in a single request. The statements will be executed in order and the + * semantics is the same as if each statement is executed by {@link * Connection#executeUpdate(Statement)} in a loop. This method returns an array of long integers, * each representing the number of rows modified by each statement. * @@ -1006,8 +1010,9 @@ default ResultSetStats analyzeUpdate(Statement update, QueryAnalyzeMode analyzeM long[] executeBatchUpdate(Iterable updates); /** - * Executes a list of DML statements in a single request. The statements will be executed in order - * and the semantics is the same as if each statement is executed by {@link + * Executes a list of DML statements (can be simple DML statements or DML statements with + * returning clause) in a single request. The statements will be executed in order and the + * semantics is the same as if each statement is executed by {@link * Connection#executeUpdate(Statement)} in a loop. This method returns an {@link ApiFuture} that * contains an array of long integers, each representing the number of rows modified by each * statement. diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/connection/ConnectionImpl.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/connection/ConnectionImpl.java index f0c772acd9e..e935a607a83 100644 --- a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/connection/ConnectionImpl.java +++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/connection/ConnectionImpl.java @@ -853,6 +853,9 @@ public StatementResult execute(Statement statement) { case QUERY: return StatementResultImpl.of(internalExecuteQuery(parsedStatement, AnalyzeMode.NONE)); case UPDATE: + if (parsedStatement.hasReturningClause()) { + return StatementResultImpl.of(internalExecuteQuery(parsedStatement, AnalyzeMode.NONE)); + } return StatementResultImpl.of(get(internalExecuteUpdateAsync(parsedStatement))); case DDL: get(executeDdlAsync(parsedStatement)); @@ -881,6 +884,10 @@ public AsyncStatementResult executeAsync(Statement statement) { return AsyncStatementResultImpl.of( internalExecuteQueryAsync(parsedStatement, AnalyzeMode.NONE)); case UPDATE: + if (parsedStatement.hasReturningClause()) { + return AsyncStatementResultImpl.of( + internalExecuteQueryAsync(parsedStatement, AnalyzeMode.NONE)); + } return AsyncStatementResultImpl.of(internalExecuteUpdateAsync(parsedStatement)); case DDL: return AsyncStatementResultImpl.noResult(executeDdlAsync(parsedStatement)); @@ -918,7 +925,7 @@ private ResultSet parseAndExecuteQuery( Preconditions.checkNotNull(analyzeMode); ConnectionPreconditions.checkState(!isClosed(), CLOSED_ERROR_MSG); ParsedStatement parsedStatement = getStatementParser().parse(query, this.queryOptions); - if (parsedStatement.isQuery()) { + if (parsedStatement.isQuery() || parsedStatement.isUpdate()) { switch (parsedStatement.getType()) { case CLIENT_SIDE: return parsedStatement @@ -928,6 +935,19 @@ private ResultSet parseAndExecuteQuery( case QUERY: return internalExecuteQuery(parsedStatement, analyzeMode, options); case UPDATE: + if (parsedStatement.hasReturningClause()) { + // Cannot execute DML statement with returning clause in read-only mode or in + // READ_ONLY_TRANSACTION transaction mode. + if (this.isReadOnly() + || (this.isInTransaction() + && this.getTransactionMode() == TransactionMode.READ_ONLY_TRANSACTION)) { + throw SpannerExceptionFactory.newSpannerException( + ErrorCode.FAILED_PRECONDITION, + "DML statement with returning clause cannot be executed in read-only mode: " + + parsedStatement.getSqlWithoutComments()); + } + return internalExecuteQuery(parsedStatement, analyzeMode, options); + } case DDL: case UNKNOWN: default: @@ -935,7 +955,8 @@ private ResultSet parseAndExecuteQuery( } throw SpannerExceptionFactory.newSpannerException( ErrorCode.INVALID_ARGUMENT, - "Statement is not a query: " + parsedStatement.getSqlWithoutComments()); + "Statement is not a query or DML with returning clause: " + + parsedStatement.getSqlWithoutComments()); } private AsyncResultSet parseAndExecuteQueryAsync( @@ -943,7 +964,7 @@ private AsyncResultSet parseAndExecuteQueryAsync( Preconditions.checkNotNull(query); ConnectionPreconditions.checkState(!isClosed(), CLOSED_ERROR_MSG); ParsedStatement parsedStatement = getStatementParser().parse(query, this.queryOptions); - if (parsedStatement.isQuery()) { + if (parsedStatement.isQuery() || parsedStatement.isUpdate()) { switch (parsedStatement.getType()) { case CLIENT_SIDE: return ResultSets.toAsyncResultSet( @@ -956,6 +977,19 @@ private AsyncResultSet parseAndExecuteQueryAsync( case QUERY: return internalExecuteQueryAsync(parsedStatement, analyzeMode, options); case UPDATE: + if (parsedStatement.hasReturningClause()) { + // Cannot execute DML statement with returning clause in read-only mode or in + // READ_ONLY_TRANSACTION transaction mode. + if (this.isReadOnly() + || (this.isInTransaction() + && this.getTransactionMode() == TransactionMode.READ_ONLY_TRANSACTION)) { + throw SpannerExceptionFactory.newSpannerException( + ErrorCode.FAILED_PRECONDITION, + "DML statement with returning clause cannot be executed in read-only mode: " + + parsedStatement.getSqlWithoutComments()); + } + return internalExecuteQueryAsync(parsedStatement, analyzeMode, options); + } case DDL: case UNKNOWN: default: @@ -963,7 +997,8 @@ private AsyncResultSet parseAndExecuteQueryAsync( } throw SpannerExceptionFactory.newSpannerException( ErrorCode.INVALID_ARGUMENT, - "Statement is not a query: " + parsedStatement.getSqlWithoutComments()); + "Statement is not a query or DML with returning clause: " + + parsedStatement.getSqlWithoutComments()); } @Override @@ -974,6 +1009,13 @@ public long executeUpdate(Statement update) { if (parsedStatement.isUpdate()) { switch (parsedStatement.getType()) { case UPDATE: + if (parsedStatement.hasReturningClause()) { + throw SpannerExceptionFactory.newSpannerException( + ErrorCode.FAILED_PRECONDITION, + "DML statement with returning clause cannot be executed using executeUpdate: " + + parsedStatement.getSqlWithoutComments() + + ". Please use executeQuery instead."); + } return get(internalExecuteUpdateAsync(parsedStatement)); case CLIENT_SIDE: case QUERY: @@ -995,6 +1037,13 @@ public ApiFuture executeUpdateAsync(Statement update) { if (parsedStatement.isUpdate()) { switch (parsedStatement.getType()) { case UPDATE: + if (parsedStatement.hasReturningClause()) { + throw SpannerExceptionFactory.newSpannerException( + ErrorCode.FAILED_PRECONDITION, + "DML statement with returning clause cannot be executed using executeUpdateAsync: " + + parsedStatement.getSqlWithoutComments() + + ". Please use executeQueryAsync instead."); + } return internalExecuteUpdateAsync(parsedStatement); case CLIENT_SIDE: case QUERY: @@ -1141,8 +1190,9 @@ private ResultSet internalExecuteQuery( final QueryOption... options) { Preconditions.checkArgument( statement.getType() == StatementType.QUERY - || (statement.getType() == StatementType.UPDATE && analyzeMode != AnalyzeMode.NONE), - "Statement must either be a query or a DML mode with analyzeMode!=NONE"); + || (statement.getType() == StatementType.UPDATE + && (analyzeMode != AnalyzeMode.NONE || statement.hasReturningClause())), + "Statement must either be a query or a DML mode with analyzeMode!=NONE or returning clause"); UnitOfWork transaction = getCurrentUnitOfWorkOrStartNewUnitOfWork(); return get( transaction.executeQueryAsync( @@ -1154,7 +1204,9 @@ private AsyncResultSet internalExecuteQueryAsync( final AnalyzeMode analyzeMode, final QueryOption... options) { Preconditions.checkArgument( - statement.getType() == StatementType.QUERY, "Statement must be a query"); + (statement.getType() == StatementType.QUERY) + || (statement.getType() == StatementType.UPDATE && statement.hasReturningClause()), + "Statement must be a query or DML with returning clause."); UnitOfWork transaction = getCurrentUnitOfWorkOrStartNewUnitOfWork(); return ResultSets.toAsyncResultSet( transaction.executeQueryAsync( diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/connection/PostgreSQLStatementParser.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/connection/PostgreSQLStatementParser.java index 6eba0ce1b29..012bfbba875 100644 --- a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/connection/PostgreSQLStatementParser.java +++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/connection/PostgreSQLStatementParser.java @@ -25,10 +25,15 @@ import java.util.Collections; import java.util.HashSet; import java.util.Set; +import java.util.regex.Pattern; import javax.annotation.Nullable; @InternalApi public class PostgreSQLStatementParser extends AbstractStatementParser { + private static final Pattern RETURNING_PATTERN = Pattern.compile("returning[ '(\"*]"); + private static final Pattern AS_RETURNING_PATTERN = Pattern.compile("[ ')\"]as returning[ '(\"]"); + private static final String RETURNING_STRING = "returning"; + PostgreSQLStatementParser() throws CompileException { super( Collections.unmodifiableSet( @@ -65,6 +70,8 @@ String removeCommentsAndTrimInternal(String sql) { Preconditions.checkNotNull(sql); boolean isInSingleLineComment = false; int multiLineCommentLevel = 0; + boolean whitespaceBeforeOrAfterMultiLineComment = false; + int multiLineCommentStartIdx = -1; StringBuilder res = new StringBuilder(sql.length()); int index = 0; while (index < sql.length()) { @@ -78,6 +85,19 @@ String removeCommentsAndTrimInternal(String sql) { } else if (multiLineCommentLevel > 0) { if (sql.length() > index + 1 && c == ASTERISK && sql.charAt(index + 1) == SLASH) { multiLineCommentLevel--; + if (multiLineCommentLevel == 0) { + if (!whitespaceBeforeOrAfterMultiLineComment && (sql.length() > index + 2)) { + whitespaceBeforeOrAfterMultiLineComment = + Character.isWhitespace(sql.charAt(index + 2)); + } + // If the multiline comment does not have any whitespace before or after it, and it is + // neither at the start nor at the end of SQL string, append an extra space. + if (!whitespaceBeforeOrAfterMultiLineComment + && (multiLineCommentStartIdx != 0) + && (index != sql.length() - 2)) { + res.append(' '); + } + } index++; } else if (sql.length() > index + 1 && c == SLASH && sql.charAt(index + 1) == ASTERISK) { multiLineCommentLevel++; @@ -92,6 +112,10 @@ String removeCommentsAndTrimInternal(String sql) { continue; } else if (sql.length() > index + 1 && c == SLASH && sql.charAt(index + 1) == ASTERISK) { multiLineCommentLevel++; + if (index >= 1) { + whitespaceBeforeOrAfterMultiLineComment = Character.isWhitespace(sql.charAt(index - 1)); + } + multiLineCommentStartIdx = index; index += 2; continue; } else { @@ -120,7 +144,7 @@ String parseDollarQuotedString(String sql, int index) { if (c == DOLLAR) { return tag.toString(); } - if (!Character.isJavaIdentifierPart(c)) { + if (!isValidIdentifierChar(c)) { break; } tag.append(c); @@ -267,4 +291,77 @@ private void appendIfNotNull( result.append(prefix).append(tag).append(suffix); } } + + private boolean isValidIdentifierFirstChar(char c) { + return Character.isLetter(c) || c == UNDERSCORE; + } + + private boolean isValidIdentifierChar(char c) { + return isValidIdentifierFirstChar(c) || Character.isDigit(c) || c == DOLLAR; + } + + private boolean checkCharPrecedingReturning(char ch) { + return (ch == SPACE) + || (ch == SINGLE_QUOTE) + || (ch == CLOSE_PARENTHESIS) + || (ch == DOUBLE_QUOTE) + || (ch == DOLLAR); + } + + private boolean checkCharPrecedingSubstrWithReturning(char ch) { + return (ch == SPACE) + || (ch == SINGLE_QUOTE) + || (ch == CLOSE_PARENTHESIS) + || (ch == DOUBLE_QUOTE) + || (ch == COMMA); + } + + private boolean isReturning(String sql, int index) { + // RETURNING is a reserved keyword in PG, but requires a + // leading AS to be used as column label, to avoid ambiguity. + // We thus check for cases which do not have a leading AS. + // (https://www.postgresql.org/docs/current/sql-keywords-appendix.html) + if (index >= 1) { + if (((index + 10 <= sql.length()) + && RETURNING_PATTERN.matcher(sql.substring(index, index + 10)).matches() + && !((index >= 4) + && AS_RETURNING_PATTERN.matcher(sql.substring(index - 4, index + 10)).matches()))) { + if (checkCharPrecedingReturning(sql.charAt(index - 1))) { + return true; + } + // Check for cases where returning clause is part of a substring which starts with an + // invalid first character of an identifier. + // For example, + // insert into t select 2returning *; + int ind = index - 1; + while ((ind >= 0) && !checkCharPrecedingSubstrWithReturning(sql.charAt(ind))) { + ind--; + } + return !isValidIdentifierFirstChar(sql.charAt(ind + 1)); + } + } + return false; + } + + @InternalApi + @Override + protected boolean checkReturningClauseInternal(String rawSql) { + Preconditions.checkNotNull(rawSql); + String sql = rawSql.toLowerCase(); + // Do a pre-check to check if the SQL string definitely does not have a returning clause. + // If this check fails, do a more involved check to check for a returning clause. + if (!sql.contains(RETURNING_STRING)) { + return false; + } + sql = sql.replaceAll("\\s+", " "); + int index = 0; + while (index < sql.length()) { + if (isReturning(sql, index)) { + return true; + } else { + index = skip(sql, index, null); + } + } + return false; + } } diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/connection/ReadWriteTransaction.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/connection/ReadWriteTransaction.java index 968ef357c24..80dc97cde04 100644 --- a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/connection/ReadWriteTransaction.java +++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/connection/ReadWriteTransaction.java @@ -45,6 +45,7 @@ import com.google.cloud.spanner.TransactionContext; import com.google.cloud.spanner.TransactionManager; import com.google.cloud.spanner.connection.AbstractStatementParser.ParsedStatement; +import com.google.cloud.spanner.connection.AbstractStatementParser.StatementType; import com.google.cloud.spanner.connection.TransactionRetryListener.RetryResult; import com.google.common.annotations.VisibleForTesting; import com.google.common.base.Preconditions; @@ -337,7 +338,10 @@ public ApiFuture executeQueryAsync( final ParsedStatement statement, final AnalyzeMode analyzeMode, final QueryOption... options) { - Preconditions.checkArgument(statement.isQuery(), "Statement is not a query"); + Preconditions.checkArgument( + (statement.getType() == StatementType.QUERY) + || (statement.getType() == StatementType.UPDATE && statement.hasReturningClause()), + "Statement must be a query or DML with returning clause"); checkValidTransaction(); ApiFuture res; diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/connection/SingleUseTransaction.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/connection/SingleUseTransaction.java index 9646c7c9c95..cbde67393aa 100644 --- a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/connection/SingleUseTransaction.java +++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/connection/SingleUseTransaction.java @@ -179,12 +179,18 @@ public ApiFuture executeQueryAsync( final QueryOption... options) { Preconditions.checkNotNull(statement); Preconditions.checkArgument( - statement.isQuery() || (statement.isUpdate() && analyzeMode != AnalyzeMode.NONE), + statement.isQuery() + || (statement.isUpdate() + && (analyzeMode != AnalyzeMode.NONE || statement.hasReturningClause())), "The statement must be a query, or the statement must be DML and AnalyzeMode must be PLAN or PROFILE"); checkAndMarkUsed(); - if (statement.isUpdate() && analyzeMode != AnalyzeMode.NONE) { - return analyzeTransactionalUpdateAsync(statement, analyzeMode); + if (statement.isUpdate()) { + if (analyzeMode != AnalyzeMode.NONE) { + return analyzeTransactionalUpdateAsync(statement, analyzeMode); + } + // DML with returning clause. + return executeDmlReturningAsync(statement, options); } final ReadOnlyTransaction currentTransaction = @@ -217,6 +223,30 @@ public ApiFuture executeQueryAsync( return executeStatementAsync(statement, callable, SpannerGrpc.getExecuteStreamingSqlMethod()); } + private ApiFuture executeDmlReturningAsync( + final ParsedStatement update, QueryOption... options) { + Callable callable = + () -> { + try { + writeTransaction = createWriteTransaction(); + ResultSet resultSet = + writeTransaction.run( + transaction -> + DirectExecuteResultSet.ofResultSet( + transaction.executeQuery(update.getStatement(), options))); + state = UnitOfWorkState.COMMITTED; + return resultSet; + } catch (Throwable t) { + state = UnitOfWorkState.COMMIT_FAILED; + throw t; + } + }; + return executeStatementAsync( + update, + callable, + ImmutableList.of(SpannerGrpc.getExecuteSqlMethod(), SpannerGrpc.getCommitMethod())); + } + @Override public Timestamp getReadTimestamp() { ConnectionPreconditions.checkState( diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/connection/SpannerStatementParser.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/connection/SpannerStatementParser.java index f223cbb935a..a9ca5fb9726 100644 --- a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/connection/SpannerStatementParser.java +++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/connection/SpannerStatementParser.java @@ -25,10 +25,16 @@ import com.google.common.collect.Sets; import java.util.Collections; import java.util.Set; +import java.util.regex.Pattern; @InternalApi public class SpannerStatementParser extends AbstractStatementParser { + private static final Pattern THEN_RETURN_PATTERN = + Pattern.compile("[ `')\"]then return[ *`'(\"]"); + private static final String THEN_STRING = "then"; + private static final String RETURN_STRING = "return"; + public SpannerStatementParser() throws CompileException { super( Collections.unmodifiableSet( @@ -69,6 +75,8 @@ String removeCommentsAndTrimInternal(String sql) { char startQuote = 0; boolean lastCharWasEscapeChar = false; boolean isTripleQuoted = false; + boolean whitespaceBeforeOrAfterMultiLineComment = false; + int multiLineCommentStartIdx = -1; StringBuilder res = new StringBuilder(sql.length()); int index = 0; while (index < sql.length()) { @@ -112,6 +120,17 @@ String removeCommentsAndTrimInternal(String sql) { } else if (isInMultiLineComment) { if (sql.length() > index + 1 && c == ASTERISK && sql.charAt(index + 1) == SLASH) { isInMultiLineComment = false; + if (!whitespaceBeforeOrAfterMultiLineComment && (sql.length() > index + 2)) { + whitespaceBeforeOrAfterMultiLineComment = + Character.isWhitespace(sql.charAt(index + 2)); + } + // If the multiline comment does not have any whitespace before or after it, and it is + // neither at the start nor at the end of SQL string, append an extra space. + if (!whitespaceBeforeOrAfterMultiLineComment + && (multiLineCommentStartIdx != 0) + && (index != sql.length() - 2)) { + res.append(' '); + } index++; } } else { @@ -121,6 +140,11 @@ String removeCommentsAndTrimInternal(String sql) { isInSingleLineComment = true; } else if (sql.length() > index + 1 && c == SLASH && sql.charAt(index + 1) == ASTERISK) { isInMultiLineComment = true; + if (index >= 1) { + whitespaceBeforeOrAfterMultiLineComment = + Character.isWhitespace(sql.charAt(index - 1)); + } + multiLineCommentStartIdx = index; index++; } else { if (c == SINGLE_QUOTE || c == DOUBLE_QUOTE || c == BACKTICK_QUOTE) { @@ -250,4 +274,71 @@ ParametersInfo convertPositionalParametersToNamedParametersInternal(char paramCh } return new ParametersInfo(paramIndex - 1, named.toString()); } + + private boolean isReturning(String sql, int index) { + return (index >= 1) + && (index + 12 <= sql.length()) + && THEN_RETURN_PATTERN.matcher(sql.substring(index - 1, index + 12)).matches(); + } + + @InternalApi + @Override + protected boolean checkReturningClauseInternal(String rawSql) { + Preconditions.checkNotNull(rawSql); + String sql = rawSql.toLowerCase(); + // Do a pre-check to check if the SQL string definitely does not have a returning clause. + // If this check fails, do a more involved check to check for a returning clause. + if (!(sql.contains(THEN_STRING) && sql.contains(RETURN_STRING))) { + return false; + } + sql = sql.replaceAll("\\s+", " "); + final char SINGLE_QUOTE = '\''; + final char DOUBLE_QUOTE = '"'; + final char BACKTICK_QUOTE = '`'; + boolean isInQuoted = false; + char startQuote = 0; + boolean lastCharWasEscapeChar = false; + boolean isTripleQuoted = false; + for (int index = 0; index < sql.length(); index++) { + char c = sql.charAt(index); + if (isInQuoted) { + if (c == startQuote) { + if (lastCharWasEscapeChar) { + lastCharWasEscapeChar = false; + } else if (isTripleQuoted) { + if (sql.length() > index + 2 + && sql.charAt(index + 1) == startQuote + && sql.charAt(index + 2) == startQuote) { + isInQuoted = false; + startQuote = 0; + isTripleQuoted = false; + } + } else { + isInQuoted = false; + startQuote = 0; + } + } else if (c == '\\') { + lastCharWasEscapeChar = true; + } else { + lastCharWasEscapeChar = false; + } + } else { + if (isReturning(sql, index)) { + return true; + } else { + if (c == SINGLE_QUOTE || c == DOUBLE_QUOTE || c == BACKTICK_QUOTE) { + isInQuoted = true; + startQuote = c; + // check whether it is a triple-quote + if (sql.length() > index + 2 + && sql.charAt(index + 1) == startQuote + && sql.charAt(index + 2) == startQuote) { + isTripleQuoted = true; + } + } + } + } + } + return false; + } } diff --git a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/MockSpannerServiceImpl.java b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/MockSpannerServiceImpl.java index 44942bb566c..340f999d8ee 100644 --- a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/MockSpannerServiceImpl.java +++ b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/MockSpannerServiceImpl.java @@ -270,6 +270,14 @@ public static StatementResult update(Statement statement, long updateCount) { return new StatementResult(statement, updateCount); } + /** + * Creates a {@link StatementResult} for a DML statement with returning clause that returns a + * ResultSet. + */ + public static StatementResult updateReturning(Statement statement, ResultSet resultSet) { + return new StatementResult(statement, resultSet); + } + /** Creates a {@link StatementResult} for statement that should return an error. */ public static StatementResult exception(Statement statement, StatusRuntimeException exception) { return new StatementResult(statement, exception); @@ -1093,9 +1101,6 @@ public void executeBatchDml( .build(); break resultLoop; case RESULT_SET: - throw Status.INVALID_ARGUMENT - .withDescription("Not a DML statement: " + statement.getSql()) - .asRuntimeException(); case UPDATE_COUNT: results.add(res); break; @@ -1120,10 +1125,20 @@ public void executeBatchDml( } ExecuteBatchDmlResponse.Builder builder = ExecuteBatchDmlResponse.newBuilder(); for (StatementResult res : results) { + Long updateCount; + switch (res.getType()) { + case UPDATE_COUNT: + updateCount = res.getUpdateCount(); + break; + case RESULT_SET: + updateCount = res.getResultSet().getStats().getRowCountExact(); + break; + default: + throw new IllegalStateException("Invalid result type: " + res.getType()); + } builder.addResultSets( ResultSet.newBuilder() - .setStats( - ResultSetStats.newBuilder().setRowCountExact(res.getUpdateCount()).build()) + .setStats(ResultSetStats.newBuilder().setRowCountExact(updateCount).build()) .setMetadata( ResultSetMetadata.newBuilder() .setTransaction( diff --git a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/connection/AbortedTest.java b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/connection/AbortedTest.java index 96416c3e8d5..3a665c4fe7a 100644 --- a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/connection/AbortedTest.java +++ b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/connection/AbortedTest.java @@ -89,6 +89,46 @@ public void testCommitAborted() { } } + @Test + public void testCommitAbortedDuringUpdateWithReturning() { + // Do two iterations to ensure that each iteration gets its own transaction, and that each + // transaction is the most recent transaction of that session. + for (int i = 0; i < 2; i++) { + mockSpanner.putStatementResult( + StatementResult.query(SELECT_COUNT_STATEMENT, SELECT_COUNT_RESULTSET_BEFORE_INSERT)); + mockSpanner.putStatementResult( + StatementResult.updateReturning(INSERT_RETURNING_STATEMENT, UPDATE_RETURNING_RESULTSET)); + AbortInterceptor interceptor = new AbortInterceptor(0); + try (ITConnection connection = + createConnection(interceptor, new CountTransactionRetryListener())) { + // verify that the there is no test record + try (ResultSet rs = + connection.executeQuery(Statement.of("SELECT COUNT(*) AS C FROM TEST WHERE ID=1"))) { + assertThat(rs.next(), is(true)); + assertThat(rs.getLong("C"), is(equalTo(0L))); + assertThat(rs.next(), is(false)); + } + // do an insert with returning + connection.executeQuery( + Statement.of("INSERT INTO TEST (ID, NAME) VALUES (1, 'test aborted') THEN RETURN *")); + // indicate that the next statement should abort + interceptor.setProbability(1.0); + interceptor.setOnlyInjectOnce(true); + // do a commit that will first abort, and then on retry will succeed + connection.commit(); + mockSpanner.putStatementResult( + StatementResult.query(SELECT_COUNT_STATEMENT, SELECT_COUNT_RESULTSET_AFTER_INSERT)); + // verify that the insert succeeded + try (ResultSet rs = + connection.executeQuery(Statement.of("SELECT COUNT(*) AS C FROM TEST WHERE ID=1"))) { + assertThat(rs.next(), is(true)); + assertThat(rs.getLong("C"), is(equalTo(1L))); + assertThat(rs.next(), is(false)); + } + } + } + } + @Test public void testAbortedDuringRetryOfFailedQuery() { final Statement invalidStatement = Statement.of("SELECT * FROM FOO"); @@ -144,6 +184,34 @@ public void testAbortedDuringRetryOfFailedUpdate() { assertThat(mockSpanner.countRequestsOfType(ExecuteSqlRequest.class)).isEqualTo(6); } + @Test + public void testAbortedDuringRetryOfFailedUpdateWithReturning() { + final Statement invalidStatement = Statement.of("INSERT INTO FOO THEN RETURN *"); + StatusRuntimeException notFound = + Status.NOT_FOUND.withDescription("Table not found").asRuntimeException(); + mockSpanner.putStatementResult(StatementResult.exception(invalidStatement, notFound)); + try (ITConnection connection = + createConnection(createAbortFirstRetryListener(invalidStatement, notFound))) { + connection.execute(INSERT_STATEMENT); + try { + connection.execute(invalidStatement); + fail("missing expected exception"); + } catch (SpannerException e) { + assertThat(e.getErrorCode()).isEqualTo(ErrorCode.NOT_FOUND); + } + // Force an abort and retry. + mockSpanner.abortNextStatement(); + connection.commit(); + } + assertThat(mockSpanner.countRequestsOfType(CommitRequest.class)).isEqualTo(2); + // The transaction will be executed 3 times, which means that there will be 6 + // ExecuteSqlRequests: + // 1. The initial attempt. + // 2. The first retry attempt. This will fail on the invalid statement as it is aborted. + // 3. the second retry attempt. This will succeed. + assertThat(mockSpanner.countRequestsOfType(ExecuteSqlRequest.class)).isEqualTo(6); + } + @Test public void testAbortedDuringRetryOfFailedBatchUpdate() { final Statement invalidStatement = Statement.of("INSERT INTO FOO"); @@ -167,6 +235,29 @@ public void testAbortedDuringRetryOfFailedBatchUpdate() { assertThat(mockSpanner.countRequestsOfType(ExecuteBatchDmlRequest.class)).isEqualTo(3); } + @Test + public void testAbortedDuringRetryOfFailedBatchUpdateWithReturning() { + final Statement invalidStatement = Statement.of("INSERT INTO FOO THEN RETURN *"); + StatusRuntimeException notFound = + Status.NOT_FOUND.withDescription("Table not found").asRuntimeException(); + mockSpanner.putStatementResult(StatementResult.exception(invalidStatement, notFound)); + try (ITConnection connection = + createConnection(createAbortFirstRetryListener(invalidStatement, notFound))) { + connection.execute(INSERT_STATEMENT); + try { + connection.executeBatchUpdate(Collections.singletonList(invalidStatement)); + fail("missing expected exception"); + } catch (SpannerException e) { + assertThat(e.getErrorCode()).isEqualTo(ErrorCode.NOT_FOUND); + } + // Force an abort and retry. + mockSpanner.abortNextStatement(); + connection.commit(); + } + assertThat(mockSpanner.countRequestsOfType(CommitRequest.class)).isEqualTo(2); + assertThat(mockSpanner.countRequestsOfType(ExecuteBatchDmlRequest.class)).isEqualTo(3); + } + @Test public void testAbortedDuringRetryOfFailedQueryAsFirstStatement() { final Statement invalidStatement = Statement.of("SELECT * FROM FOO"); @@ -225,6 +316,29 @@ public void testAbortedDuringRetryOfFailedUpdateAsFirstStatement() { assertThat(mockSpanner.countRequestsOfType(ExecuteSqlRequest.class)).isEqualTo(8); } + @Test + public void testAbortedDuringRetryOfFailedUpdateWithReturningAsFirstStatement() { + final Statement invalidStatement = Statement.of("INSERT INTO FOO THEN RETURN *"); + StatusRuntimeException notFound = + Status.NOT_FOUND.withDescription("Table not found").asRuntimeException(); + mockSpanner.putStatementResult(StatementResult.exception(invalidStatement, notFound)); + try (ITConnection connection = + createConnection(createAbortRetryListener(2, invalidStatement, notFound))) { + try { + connection.execute(invalidStatement); + fail("missing expected exception"); + } catch (SpannerException e) { + assertThat(e.getErrorCode()).isEqualTo(ErrorCode.NOT_FOUND); + } + connection.execute(INSERT_STATEMENT); + // Force an abort and retry. + mockSpanner.abortNextStatement(); + connection.commit(); + } + assertThat(mockSpanner.countRequestsOfType(CommitRequest.class)).isEqualTo(2); + assertThat(mockSpanner.countRequestsOfType(ExecuteSqlRequest.class)).isEqualTo(8); + } + @Test public void testAbortedDuringRetryOfFailedBatchUpdateAsFirstStatement() { final Statement invalidStatement = Statement.of("INSERT INTO FOO"); @@ -248,6 +362,29 @@ public void testAbortedDuringRetryOfFailedBatchUpdateAsFirstStatement() { assertThat(mockSpanner.countRequestsOfType(ExecuteBatchDmlRequest.class)).isEqualTo(6); } + @Test + public void testAbortedDuringRetryOfFailedBatchUpdateWithReturningAsFirstStatement() { + final Statement invalidStatement = Statement.of("INSERT INTO FOO THEN RETURN *"); + StatusRuntimeException notFound = + Status.NOT_FOUND.withDescription("Table not found").asRuntimeException(); + mockSpanner.putStatementResult(StatementResult.exception(invalidStatement, notFound)); + try (ITConnection connection = + createConnection(createAbortFirstRetryListener(invalidStatement, notFound))) { + try { + connection.executeBatchUpdate(Collections.singletonList(invalidStatement)); + fail("missing expected exception"); + } catch (SpannerException e) { + assertThat(e.getErrorCode()).isEqualTo(ErrorCode.NOT_FOUND); + } + connection.execute(INSERT_STATEMENT); + // Force an abort and retry. + mockSpanner.abortNextStatement(); + connection.commit(); + } + assertThat(mockSpanner.countRequestsOfType(CommitRequest.class)).isEqualTo(2); + assertThat(mockSpanner.countRequestsOfType(ExecuteBatchDmlRequest.class)).isEqualTo(6); + } + @Test public void testRetryUsesTags() { mockSpanner.putStatementResult( @@ -303,6 +440,63 @@ public void testRetryUsesTags() { } @Test + public void testRetryUsesTagsWithUpdateReturning() { + mockSpanner.putStatementResult( + StatementResult.query(SELECT_COUNT_STATEMENT, SELECT_COUNT_RESULTSET_BEFORE_INSERT)); + mockSpanner.putStatementResult(StatementResult.update(INSERT_STATEMENT, UPDATE_COUNT)); + mockSpanner.putStatementResult( + StatementResult.updateReturning(INSERT_RETURNING_STATEMENT, UPDATE_RETURNING_RESULTSET)); + try (ITConnection connection = createConnection()) { + connection.setTransactionTag("transaction-tag"); + connection.setStatementTag("statement-tag"); + connection.executeUpdate(INSERT_STATEMENT); + connection.setStatementTag("statement-tag"); + connection.executeBatchUpdate(ImmutableList.of(INSERT_STATEMENT, INSERT_RETURNING_STATEMENT)); + connection.setStatementTag("statement-tag"); + connection.executeQuery(SELECT_COUNT_STATEMENT); + connection.setStatementTag("statement-tag"); + connection.executeQuery(INSERT_RETURNING_STATEMENT); + + mockSpanner.abortNextStatement(); + connection.commit(); + } + long executeSqlRequestCount = + mockSpanner.getRequestsOfType(ExecuteSqlRequest.class).stream() + .filter( + request -> + request.getRequestOptions().getRequestTag().equals("statement-tag") + && request + .getRequestOptions() + .getTransactionTag() + .equals("transaction-tag")) + .count(); + assertEquals(6L, executeSqlRequestCount); + + long executeBatchSqlRequestCount = + mockSpanner.getRequestsOfType(ExecuteBatchDmlRequest.class).stream() + .filter( + request -> + request.getRequestOptions().getRequestTag().equals("statement-tag") + && request + .getRequestOptions() + .getTransactionTag() + .equals("transaction-tag")) + .count(); + assertEquals(2L, executeBatchSqlRequestCount); + + long commitRequestCount = + mockSpanner.getRequestsOfType(CommitRequest.class).stream() + .filter( + request -> + request.getRequestOptions().getRequestTag().equals("") + && request + .getRequestOptions() + .getTransactionTag() + .equals("transaction-tag")) + .count(); + assertEquals(2L, commitRequestCount); + } + public void testRetryUsesAnalyzeModeForUpdate() { mockSpanner.putStatementResult( StatementResult.query(SELECT_COUNT_STATEMENT, SELECT_COUNT_RESULTSET_BEFORE_INSERT)); diff --git a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/connection/AbstractMockServerTest.java b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/connection/AbstractMockServerTest.java index 7f9163e8e46..7b65b44f361 100644 --- a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/connection/AbstractMockServerTest.java +++ b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/connection/AbstractMockServerTest.java @@ -37,6 +37,7 @@ import com.google.protobuf.Value; import com.google.spanner.v1.ExecuteSqlRequest; import com.google.spanner.v1.ResultSetMetadata; +import com.google.spanner.v1.ResultSetStats; import com.google.spanner.v1.StructType; import com.google.spanner.v1.StructType.Field; import com.google.spanner.v1.Type; @@ -96,8 +97,23 @@ public abstract class AbstractMockServerTest { .build()) .setMetadata(SELECT_COUNT_METADATA) .build(); + public static final com.google.spanner.v1.ResultSet UPDATE_RETURNING_RESULTSET = + com.google.spanner.v1.ResultSet.newBuilder() + .setStats(ResultSetStats.newBuilder().setRowCountExact(1)) + .setMetadata( + ResultSetMetadata.newBuilder() + .setRowType( + StructType.newBuilder() + .addFields( + Field.newBuilder() + .setName("col") + .setType(Type.newBuilder().setCodeValue(TypeCode.INT64_VALUE)) + .build()))) + .build(); public static final Statement INSERT_STATEMENT = Statement.of("INSERT INTO TEST (ID, NAME) VALUES (1, 'test aborted')"); + public static final Statement INSERT_RETURNING_STATEMENT = + Statement.of("INSERT INTO TEST (ID, NAME) VALUES (1, 'test aborted') THEN RETURN *"); public static final long UPDATE_COUNT = 1L; public static final int RANDOM_RESULT_SET_ROW_COUNT = 100; @@ -149,6 +165,8 @@ public void getOperation( mockSpanner.putStatementResult( StatementResult.query(SELECT_COUNT_STATEMENT, SELECT_COUNT_RESULTSET_BEFORE_INSERT)); mockSpanner.putStatementResult(StatementResult.update(INSERT_STATEMENT, UPDATE_COUNT)); + mockSpanner.putStatementResult( + StatementResult.updateReturning(INSERT_RETURNING_STATEMENT, UPDATE_RETURNING_RESULTSET)); mockSpanner.putStatementResult( StatementResult.query(SELECT_RANDOM_STATEMENT, RANDOM_RESULT_SET)); diff --git a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/connection/StatementParserTest.java b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/connection/StatementParserTest.java index 2e39025bf33..56760d900af 100644 --- a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/connection/StatementParserTest.java +++ b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/connection/StatementParserTest.java @@ -1266,6 +1266,215 @@ public void testPostgreSQLGetQueryParameters() { parser.getQueryParameters("select '$2' from foo where bar=$1 and baz=$foo")); } + @Test + public void testGoogleSQLReturningClause() { + assumeTrue(dialect == Dialect.GOOGLE_STANDARD_SQL); + + SpannerStatementParser parser = (SpannerStatementParser) this.parser; + assertTrue( + parser + .parse(Statement.of("insert into x (a,b) values (1,2) then return *")) + .hasReturningClause()); + assertTrue( + parser + .parse(Statement.of("insert into x (a,b) values (1,2) then\nreturn *")) + .hasReturningClause()); + assertTrue( + parser + .parse(Statement.of("insert into x (a,b) values (1,2)\nthen\n\n\nreturn\n*")) + .hasReturningClause()); + assertTrue( + parser + .parse(Statement.of("insert into x (a,b) values (1,2)then return *")) + .hasReturningClause()); + assertTrue( + parser + .parse(Statement.of("insert into x (a,b) values (1,2) then return(a)")) + .hasReturningClause()); + assertTrue( + parser + .parse(Statement.of("insert into x (a,b) values (1,2)then return(a)")) + .hasReturningClause()); + assertTrue( + parser + .parse(Statement.of("insert into x (a,b) values (1,2) then/*comment*/return *")) + .hasReturningClause()); + assertTrue( + parser + .parse(Statement.of("insert into x (a,b) values (1,2) then return /*then return*/ *")) + .hasReturningClause()); + assertTrue( + parser + .parse(Statement.of("insert into x (a,b) values (1,2)then/*comment*/return *")) + .hasReturningClause()); + assertTrue( + parser + .parse(Statement.of("insert into x (a,b) values (1,2)then/*comment*/return(a)")) + .hasReturningClause()); + assertTrue( + parser + .parse(Statement.of("insert into x (a,b) values (1,2)then /*comment*/return(a)")) + .hasReturningClause()); + assertTrue( + parser + .parse( + Statement.of("insert into x (a,b) values (1,2)/*comment*/then/*comment*/return(a)")) + .hasReturningClause()); + assertTrue( + parser + .parse( + Statement.of( + "insert into x (a,b) values (1,2)/*comment*/then/*comment*/return/*comment*/(a)")) + .hasReturningClause()); + assertTrue( + parser + .parse( + Statement.of( + "insert into x (a,b) values (1,2)/*comment" + + "*/then" + + "/*comment" + + "*/return/*" + + "comment*/(a)")) + .hasReturningClause()); + assertTrue( + parser + .parse(Statement.of("delete from x where y=\"z\"then return *")) + .hasReturningClause()); + assertTrue( + parser + .parse(Statement.of("insert into x select 'then return' as returning then return *")) + .hasReturningClause()); + assertTrue( + parser.parse(Statement.of("delete from x where 10=`z`then return *")).hasReturningClause()); + assertFalse( + parser + .parse(Statement.of("insert into x (a,b) values (1,2) returning (a)")) + .hasReturningClause()); + assertFalse( + parser + .parse(Statement.of("insert into x (a,b) values (1,2) /*then return **/")) + .hasReturningClause()); + assertFalse( + parser.parse(Statement.of("insert into x (a,b) values (1,2)")).hasReturningClause()); + assertFalse( + parser + .parse(Statement.of("insert into x (a,b) values (1,2)thenreturn*")) + .hasReturningClause()); + assertTrue( + parser + .parse(Statement.of("insert into t(a) select \"x\"then return*")) + .hasReturningClause()); + } + + @Test + public void testPostgreSQLReturningClause() { + assumeTrue(dialect == Dialect.POSTGRESQL); + + PostgreSQLStatementParser parser = (PostgreSQLStatementParser) this.parser; + assertTrue( + parser + .parse(Statement.of("insert into x (a,b) values (1,2) returning *")) + .hasReturningClause()); + assertTrue( + parser + .parse(Statement.of("insert into x (a,b) values (1,2)returning *")) + .hasReturningClause()); + assertTrue( + parser + .parse(Statement.of("insert into x (a,b) values (1,2) returning(a)")) + .hasReturningClause()); + assertTrue( + parser + .parse(Statement.of("insert into x (a,b) values (1,2)returning(a)")) + .hasReturningClause()); + assertTrue( + parser + .parse(Statement.of("insert into x (a,b) values (1,2)/*comment*/returning(a)")) + .hasReturningClause()); + assertTrue( + parser + .parse( + Statement.of("insert into x (a,b) values (1,2)/*comment*/returning/*comment*/(a)")) + .hasReturningClause()); + assertTrue( + parser + .parse( + Statement.of( + "insert into x (a,b) values (1,2)/*comment" + "*/returning/*" + "comment*/(a)")) + .hasReturningClause()); + assertTrue( + parser + .parse(Statement.of("insert into x select 1 as returning returning *")) + .hasReturningClause()); + assertTrue( + parser + .parse(Statement.of("insert into x select 'returning' as returning returning *")) + .hasReturningClause()); + assertTrue( + parser + .parse(Statement.of("insert into x select 'returning'as returning returning *")) + .hasReturningClause()); + assertTrue( + parser.parse(Statement.of("delete from x where y=\"z\"returning *")).hasReturningClause()); + assertTrue( + parser.parse(Statement.of("delete from x where y='z'returning *")).hasReturningClause()); + assertTrue( + parser + .parse(Statement.of("insert into t1 select 1/*as /*returning*/ returning*/returning *")) + .hasReturningClause()); + assertFalse( + parser + .parse(Statement.of("insert into x (a,b) values (1,2) then return (a)")) + .hasReturningClause()); + assertFalse( + parser.parse(Statement.of("insert into x (a,b) values (1,2)")).hasReturningClause()); + assertFalse( + parser.parse(Statement.of("insert into t1 select 1 as returning")).hasReturningClause()); + assertFalse( + parser + .parse(Statement.of("insert into t1 select 1\nas\n\nreturning")) + .hasReturningClause()); + assertFalse( + parser.parse(Statement.of("insert into t1 select 1asreturning")).hasReturningClause()); + assertTrue( + parser + .parse(Statement.of("insert into t1 select 1 as/*eomment*/returning returning *")) + .hasReturningClause()); + assertFalse( + parser + .parse(Statement.of("UPDATE x SET y = $$ RETURNING a, b, c$$ WHERE z = 123")) + .hasReturningClause()); + assertFalse( + parser + .parse( + Statement.of("UPDATE x SET y = $foobar$ RETURNING a, b, c$foobar$ WHERE z = 123")) + .hasReturningClause()); + assertFalse( + parser + .parse(Statement.of("UPDATE x SET y = $returning$ returning $returning$ WHERE z = 123")) + .hasReturningClause()); + assertTrue( + parser + .parse( + Statement.of( + "UPDATE x SET y = $returning$returning$returning$ WHERE z = 123 ReTuRnInG *")) + .hasReturningClause()); + assertTrue( + parser.parse(Statement.of("insert into t1 select 1 returning*")).hasReturningClause()); + assertTrue( + parser.parse(Statement.of("insert into t1 select 2returning*")).hasReturningClause()); + assertTrue( + parser.parse(Statement.of("insert into t1 select 10e2returning*")).hasReturningClause()); + assertFalse( + parser + .parse(Statement.of("insert into t1 select 'test''returning *'")) + .hasReturningClause()); + assertTrue( + parser.parse(Statement.of("insert into t select 2,3returning*")).hasReturningClause()); + assertTrue( + parser.parse(Statement.of("insert into t1 select 10.returning*")).hasReturningClause()); + } + private void assertUnclosedLiteral(String sql) { try { parser.convertPositionalParametersToNamedParameters('?', sql); diff --git a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/connection/it/ITDmlReturningTest.java b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/connection/it/ITDmlReturningTest.java new file mode 100644 index 00000000000..5e5ca800402 --- /dev/null +++ b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/connection/it/ITDmlReturningTest.java @@ -0,0 +1,344 @@ +/* + * Copyright 2022 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.cloud.spanner.connection.it; + +import static org.junit.Assert.assertArrayEquals; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertThrows; +import static org.junit.Assert.assertTrue; +import static org.junit.Assume.assumeFalse; + +import com.google.cloud.spanner.AsyncResultSet; +import com.google.cloud.spanner.AsyncResultSet.CallbackResponse; +import com.google.cloud.spanner.Dialect; +import com.google.cloud.spanner.ErrorCode; +import com.google.cloud.spanner.Mutation; +import com.google.cloud.spanner.ParallelIntegrationTest; +import com.google.cloud.spanner.ResultSet; +import com.google.cloud.spanner.SpannerException; +import com.google.cloud.spanner.Statement; +import com.google.cloud.spanner.connection.AsyncStatementResult; +import com.google.cloud.spanner.connection.Connection; +import com.google.cloud.spanner.connection.ITAbstractSpannerTest; +import com.google.cloud.spanner.connection.StatementResult; +import com.google.cloud.spanner.connection.StatementResult.ResultType; +import com.google.cloud.spanner.connection.TransactionMode; +import com.google.cloud.spanner.testing.EmulatorSpannerHelper; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.Executors; +import org.junit.Before; +import org.junit.Test; +import org.junit.experimental.categories.Category; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; +import org.junit.runners.Parameterized.Parameter; +import org.junit.runners.Parameterized.Parameters; + +/** Execute DML Returning statements using the generic connection API. */ +@Category(ParallelIntegrationTest.class) +@RunWith(Parameterized.class) +public class ITDmlReturningTest extends ITAbstractSpannerTest { + private final ImmutableMap UPDATE_RETURNING_MAP = + ImmutableMap.of( + Dialect.GOOGLE_STANDARD_SQL, + Statement.of("UPDATE Singers SET LastName = 'XYZ' WHERE FirstName = 'ABC' THEN RETURN *"), + Dialect.POSTGRESQL, + Statement.of("UPDATE Singers SET LastName = 'XYZ' WHERE FirstName = 'ABC' RETURNING *")); + private final ImmutableMap DDL_MAP = + ImmutableMap.of( + Dialect.GOOGLE_STANDARD_SQL, + "CREATE TABLE Singers (" + + " SingerId INT64," + + " FirstName STRING(1024)," + + " LastName STRING(1024)" + + ") PRIMARY KEY(SingerId)", + Dialect.POSTGRESQL, + "CREATE TABLE Singers (" + + " SingerId BIGINT PRIMARY KEY," + + " FirstName character varying(1024)," + + " LastName character varying(1024))"); + private final Map IS_INITIALIZED = new HashMap<>(); + + public ITDmlReturningTest() { + IS_INITIALIZED.put(Dialect.GOOGLE_STANDARD_SQL, false); + IS_INITIALIZED.put(Dialect.POSTGRESQL, false); + } + + @Parameter public Dialect dialect; + + @Parameters(name = "dialect = {0}") + public static Object[] data() { + return Dialect.values(); + } + + private boolean checkAndSetInitialized() { + if ((dialect == Dialect.GOOGLE_STANDARD_SQL) && !IS_INITIALIZED.get(dialect)) { + IS_INITIALIZED.put(dialect, true); + return true; + } + if ((dialect == Dialect.POSTGRESQL) && !IS_INITIALIZED.get(dialect)) { + IS_INITIALIZED.put(dialect, true); + return true; + } + return false; + } + + @Before + public void setupTable() { + assumeFalse( + "DML Returning is not supported in the emulator", EmulatorSpannerHelper.isUsingEmulator()); + if (checkAndSetInitialized()) { + database = + env.getTestHelper() + .createTestDatabase(dialect, Collections.singleton(DDL_MAP.get(dialect))); + List firstNames = Arrays.asList("ABC", "ABC", "DEF", "PQR", "ABC"); + List lastNames = Arrays.asList("XYZ", "DEF", "XYZ", "ABC", "GHI"); + List mutations = new ArrayList<>(); + for (int id = 1; id <= 5; id++) { + mutations.add( + Mutation.newInsertBuilder("SINGERS") + .set("SINGERID") + .to(id) + .set("FIRSTNAME") + .to(firstNames.get(id - 1)) + .set("LASTNAME") + .to(lastNames.get(id - 1)) + .build()); + } + env.getTestHelper().getDatabaseClient(database).write(mutations); + } + } + + @Test + public void testDmlReturningExecuteQuery() { + try (Connection connection = createConnection()) { + try (ResultSet rs = connection.executeQuery(UPDATE_RETURNING_MAP.get(dialect))) { + assertEquals(rs.getColumnCount(), 3); + assertTrue(rs.next()); + assertEquals(rs.getString(1), "ABC"); + assertTrue(rs.next()); + assertEquals(rs.getString(1), "ABC"); + assertTrue(rs.next()); + assertEquals(rs.getString(1), "ABC"); + assertFalse(rs.next()); + assertNotNull(rs.getStats()); + assertEquals(rs.getStats().getRowCountExact(), 3); + } + } + } + + @Test + public void testDmlReturningExecuteQueryAsync() { + try (Connection connection = createConnection()) { + try (AsyncResultSet rs = connection.executeQueryAsync(UPDATE_RETURNING_MAP.get(dialect))) { + rs.setCallback( + Executors.newSingleThreadExecutor(), + resultSet -> { + try { + while (true) { + switch (resultSet.tryNext()) { + case OK: + assertEquals(resultSet.getColumnCount(), 3); + assertEquals(resultSet.getString(1), "ABC"); + break; + case DONE: + assertNotNull(resultSet.getStats()); + assertEquals(resultSet.getStats().getRowCountExact(), 3); + return CallbackResponse.DONE; + case NOT_READY: + return CallbackResponse.CONTINUE; + default: + throw new IllegalStateException(); + } + } + } catch (SpannerException e) { + return CallbackResponse.DONE; + } + }); + } + } + } + + @Test + public void testDmlReturningExecuteUpdate() { + try (Connection connection = createConnection()) { + connection.setAutocommit(false); + SpannerException e = + assertThrows( + SpannerException.class, + () -> connection.executeUpdate(UPDATE_RETURNING_MAP.get(dialect))); + assertEquals(e.getErrorCode(), ErrorCode.FAILED_PRECONDITION); + } + } + + @Test + public void testDmlReturningExecuteUpdateAsync() { + try (Connection connection = createConnection()) { + connection.setAutocommit(false); + SpannerException e = + assertThrows( + SpannerException.class, + () -> connection.executeUpdateAsync(UPDATE_RETURNING_MAP.get(dialect))); + assertEquals(e.getErrorCode(), ErrorCode.FAILED_PRECONDITION); + } + } + + @Test + public void testDmlReturningExecuteBatchUpdate() { + try (Connection connection = createConnection()) { + connection.setAutocommit(false); + final Statement UPDATE_STMT = UPDATE_RETURNING_MAP.get(dialect); + long[] counts = + connection.executeBatchUpdate(ImmutableList.of(UPDATE_STMT, UPDATE_STMT, UPDATE_STMT)); + assertArrayEquals(counts, new long[] {3, 3, 3}); + } + } + + @Test + public void testDmlReturningExecuteBatchUpdateAsync() { + try (Connection connection = createConnection()) { + connection.setAutocommit(false); + final Statement UPDATE_STMT = UPDATE_RETURNING_MAP.get(dialect); + long[] counts = + connection + .executeBatchUpdateAsync(ImmutableList.of(UPDATE_STMT, UPDATE_STMT, UPDATE_STMT)) + .get(); + assertArrayEquals(counts, new long[] {3, 3, 3}); + } catch (ExecutionException | InterruptedException e) { + // ignore + } + } + + @Test + public void testDmlReturningExecute() { + try (Connection connection = createConnection()) { + connection.setAutocommit(false); + StatementResult res = connection.execute(UPDATE_RETURNING_MAP.get(dialect)); + assertEquals(res.getResultType(), ResultType.RESULT_SET); + try (ResultSet rs = res.getResultSet()) { + assertEquals(rs.getColumnCount(), 3); + assertTrue(rs.next()); + assertEquals(rs.getString(1), "ABC"); + assertTrue(rs.next()); + assertEquals(rs.getString(1), "ABC"); + assertTrue(rs.next()); + assertEquals(rs.getString(1), "ABC"); + assertFalse(rs.next()); + assertNotNull(rs.getStats()); + assertEquals(rs.getStats().getRowCountExact(), 3); + } + } + } + + @Test + public void testDmlReturningExecuteAsync() { + try (Connection connection = createConnection()) { + connection.setAutocommit(false); + AsyncStatementResult res = connection.executeAsync(UPDATE_RETURNING_MAP.get(dialect)); + assertEquals(res.getResultType(), ResultType.RESULT_SET); + try (AsyncResultSet rs = res.getResultSetAsync()) { + rs.setCallback( + Executors.newSingleThreadExecutor(), + resultSet -> { + try { + while (true) { + switch (resultSet.tryNext()) { + case OK: + assertEquals(resultSet.getColumnCount(), 3); + assertEquals(resultSet.getString(1), "ABC"); + break; + case DONE: + assertNotNull(resultSet.getStats()); + assertEquals(resultSet.getStats().getRowCountExact(), 3); + return CallbackResponse.DONE; + case NOT_READY: + return CallbackResponse.CONTINUE; + default: + throw new IllegalStateException(); + } + } + } catch (SpannerException e) { + System.out.printf("Error in callback: %s%n", e.getMessage()); + return CallbackResponse.DONE; + } + }); + } + } + } + + @Test + public void testDmlReturningExecuteQueryReadOnlyMode() { + try (Connection connection = createConnection()) { + connection.setReadOnly(true); + SpannerException e = + assertThrows( + SpannerException.class, + () -> connection.executeQuery(UPDATE_RETURNING_MAP.get(dialect))); + assertEquals(e.getErrorCode(), ErrorCode.FAILED_PRECONDITION); + } + } + + @Test + public void testDmlReturningExecuteQueryReadOnlyTransaction() { + try (Connection connection = createConnection()) { + connection.setReadOnly(false); + connection.setAutocommit(false); + connection.setTransactionMode(TransactionMode.READ_ONLY_TRANSACTION); + SpannerException e = + assertThrows( + SpannerException.class, + () -> connection.executeQuery(UPDATE_RETURNING_MAP.get(dialect))); + assertEquals(e.getErrorCode(), ErrorCode.FAILED_PRECONDITION); + } + } + + @Test + public void testDmlReturningExecuteQueryAsyncReadOnlyMode() { + try (Connection connection = createConnection()) { + connection.setReadOnly(true); + SpannerException e = + assertThrows( + SpannerException.class, + () -> connection.executeQueryAsync(UPDATE_RETURNING_MAP.get(dialect))); + assertEquals(e.getErrorCode(), ErrorCode.FAILED_PRECONDITION); + } + } + + @Test + public void testDmlReturningExecuteQueryAsyncReadOnlyTransaction() { + try (Connection connection = createConnection()) { + connection.setReadOnly(false); + connection.setAutocommit(false); + connection.setTransactionMode(TransactionMode.READ_ONLY_TRANSACTION); + SpannerException e = + assertThrows( + SpannerException.class, + () -> connection.executeQueryAsync(UPDATE_RETURNING_MAP.get(dialect))); + assertEquals(e.getErrorCode(), ErrorCode.FAILED_PRECONDITION); + } + } +}