Skip to content

Commit

Permalink
chore: refactor client side statements to accept the entire parsed st…
Browse files Browse the repository at this point in the history
…atement (#2556)

* chore: refactor client side statements to accept the entire parsed statement

Refactor the internal interface of client-side statements so these receive the
entire parsed statement, including any query parameters in the statement. This
allows us to create client-side statements that actually use the query parameters
that have been specified by the user.

* chore: simplify test
  • Loading branch information
olavloite authored Aug 2, 2023
1 parent 423e1a4 commit c34d51e
Show file tree
Hide file tree
Showing 12 changed files with 86 additions and 88 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
package com.google.cloud.spanner.connection;

import com.google.cloud.spanner.ResultSet;
import com.google.cloud.spanner.connection.AbstractStatementParser.ParsedStatement;
import com.google.cloud.spanner.connection.StatementResult.ClientSideStatementType;
import java.util.List;

Expand Down Expand Up @@ -63,5 +64,5 @@ interface ClientSideStatement {
* needed for the execution of the {@link ClientSideStatement}.
* @return the result of the execution of the statement.
*/
StatementResult execute(ConnectionStatementExecutor executor, String statement);
StatementResult execute(ConnectionStatementExecutor executor, ParsedStatement statement);
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@

package com.google.cloud.spanner.connection;

import com.google.cloud.spanner.connection.AbstractStatementParser.ParsedStatement;

/**
* A {@link ClientSideStatementExecutor} is used to compile {@link ClientSideStatement}s from the
* json source file, and to execute these against a {@link Connection} (through a {@link
Expand All @@ -29,13 +31,13 @@ interface ClientSideStatementExecutor {
*
* @param connectionExecutor The {@link ConnectionStatementExecutor} to use to execute the
* statement on a {@link Connection}.
* @param sql The sql statement that is executed. This can be used to parse any additional
* @param statement The statement that is executed. This can be used to parse any additional
* arguments that might be needed for the execution of the {@link ClientSideStatementImpl}.
* @return the result of the execution.
* @throws Exception If an error occurs while executing the statement, for example if an invalid
* argument has been specified in the sql statement, or if the statement is invalid for the
* current state of the {@link Connection}.
*/
StatementResult execute(ConnectionStatementExecutor connectionExecutor, String sql)
StatementResult execute(ConnectionStatementExecutor connectionExecutor, ParsedStatement statement)
throws Exception;
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

import com.google.cloud.spanner.ErrorCode;
import com.google.cloud.spanner.SpannerExceptionFactory;
import com.google.cloud.spanner.connection.AbstractStatementParser.ParsedStatement;
import com.google.cloud.spanner.connection.ClientSideStatementImpl.CompileException;
import com.google.cloud.spanner.connection.ClientSideStatementValueConverters.ExplainCommandConverter;
import com.google.common.collect.ImmutableSet;
Expand Down Expand Up @@ -47,9 +48,10 @@ class ClientSideStatementExplainExecutor implements ClientSideStatementExecutor
}

@Override
public StatementResult execute(ConnectionStatementExecutor connection, String sql)
public StatementResult execute(ConnectionStatementExecutor connection, ParsedStatement statement)
throws Exception {
return (StatementResult) method.invoke(connection, getParameterValue(sql));
return (StatementResult)
method.invoke(connection, getParameterValue(statement.getSqlWithoutComments()));
}

String getParameterValue(String sql) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
package com.google.cloud.spanner.connection;

import com.google.cloud.spanner.SpannerException;
import com.google.cloud.spanner.connection.AbstractStatementParser.ParsedStatement;
import com.google.cloud.spanner.connection.StatementResult.ClientSideStatementType;
import com.google.cloud.spanner.connection.StatementResult.ResultType;
import com.google.common.base.Preconditions;
Expand Down Expand Up @@ -160,7 +161,8 @@ ClientSideStatementImpl compile() throws CompileException {
}

@Override
public StatementResult execute(ConnectionStatementExecutor connection, String statement) {
public StatementResult execute(
ConnectionStatementExecutor connection, ParsedStatement statement) {
Preconditions.checkState(executor != null, "This statement has not been compiled");
try {
return executor.execute(connection, statement);
Expand All @@ -170,9 +172,9 @@ public StatementResult execute(ConnectionStatementExecutor connection, String st
if (e.getCause() instanceof SpannerException) {
throw (SpannerException) e.getCause();
}
throw new ExecuteException(e.getCause(), this, statement);
throw new ExecuteException(e.getCause(), this, statement.getStatement().getSql());
} catch (Exception e) {
throw new ExecuteException(e, this, statement);
throw new ExecuteException(e, this, statement.getStatement().getSql());
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

package com.google.cloud.spanner.connection;

import com.google.cloud.spanner.connection.AbstractStatementParser.ParsedStatement;
import com.google.cloud.spanner.connection.ClientSideStatementImpl.CompileException;
import java.lang.reflect.Method;

Expand All @@ -42,7 +43,7 @@ class ClientSideStatementNoParamExecutor implements ClientSideStatementExecutor
}

@Override
public StatementResult execute(ConnectionStatementExecutor connection, String statement)
public StatementResult execute(ConnectionStatementExecutor connection, ParsedStatement statement)
throws Exception {
return (StatementResult) method.invoke(connection);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

import com.google.cloud.spanner.ErrorCode;
import com.google.cloud.spanner.SpannerExceptionFactory;
import com.google.cloud.spanner.connection.AbstractStatementParser.ParsedStatement;
import com.google.cloud.spanner.connection.ClientSideStatementImpl.CompileException;
import com.google.cloud.spanner.connection.ClientSideStatementValueConverters.PgTransactionModeConverter;
import java.lang.reflect.Method;
Expand All @@ -42,9 +43,10 @@ class ClientSideStatementPgBeginExecutor implements ClientSideStatementExecutor
}

@Override
public StatementResult execute(ConnectionStatementExecutor connection, String sql)
public StatementResult execute(ConnectionStatementExecutor connection, ParsedStatement statement)
throws Exception {
return (StatementResult) method.invoke(connection, getParameterValue(sql));
return (StatementResult)
method.invoke(connection, getParameterValue(statement.getSqlWithoutComments()));
}

PgTransactionMode getParameterValue(String sql) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

import com.google.cloud.spanner.ErrorCode;
import com.google.cloud.spanner.SpannerExceptionFactory;
import com.google.cloud.spanner.connection.AbstractStatementParser.ParsedStatement;
import com.google.cloud.spanner.connection.ClientSideStatementImpl.CompileException;
import com.google.common.base.Preconditions;
import java.lang.reflect.Constructor;
Expand Down Expand Up @@ -72,9 +73,10 @@ class ClientSideStatementSetExecutor<T> implements ClientSideStatementExecutor {
}

@Override
public StatementResult execute(ConnectionStatementExecutor connection, String sql)
public StatementResult execute(ConnectionStatementExecutor connection, ParsedStatement statement)
throws Exception {
return (StatementResult) method.invoke(connection, getParameterValue(sql));
return (StatementResult)
method.invoke(connection, getParameterValue(statement.getSqlWithoutComments()));
}

T getParameterValue(String sql) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -925,7 +925,7 @@ public StatementResult execute(Statement statement) {
case CLIENT_SIDE:
return parsedStatement
.getClientSideStatement()
.execute(connectionStatementExecutor, parsedStatement.getSqlWithoutComments());
.execute(connectionStatementExecutor, parsedStatement);
case QUERY:
return StatementResultImpl.of(
internalExecuteQuery(CallType.SYNC, parsedStatement, AnalyzeMode.NONE));
Expand Down Expand Up @@ -957,7 +957,7 @@ public AsyncStatementResult executeAsync(Statement statement) {
return AsyncStatementResultImpl.of(
parsedStatement
.getClientSideStatement()
.execute(connectionStatementExecutor, parsedStatement.getSqlWithoutComments()),
.execute(connectionStatementExecutor, parsedStatement),
spanner.getAsyncExecutorProvider());
case QUERY:
return AsyncStatementResultImpl.of(
Expand Down Expand Up @@ -1010,7 +1010,7 @@ private ResultSet parseAndExecuteQuery(
case CLIENT_SIDE:
return parsedStatement
.getClientSideStatement()
.execute(connectionStatementExecutor, parsedStatement.getSqlWithoutComments())
.execute(connectionStatementExecutor, parsedStatement)
.getResultSet();
case QUERY:
return internalExecuteQuery(callType, parsedStatement, analyzeMode, options);
Expand Down Expand Up @@ -1050,7 +1050,7 @@ private AsyncResultSet parseAndExecuteQueryAsync(
return ResultSets.toAsyncResultSet(
parsedStatement
.getClientSideStatement()
.execute(connectionStatementExecutor, parsedStatement.getSqlWithoutComments())
.execute(connectionStatementExecutor, parsedStatement)
.getResultSet(),
spanner.getAsyncExecutorProvider(),
options);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ public void testBeginWithNoOption() {
"start work isolation level serializable")) {
ParsedStatement statement = parser.parse(Statement.of(sql));
assertEquals(sql, StatementType.CLIENT_SIDE, statement.getType());
statement.getClientSideStatement().execute(executor, sql);
statement.getClientSideStatement().execute(executor, statement);

verify(connection, times(index)).beginTransaction();
verify(connection, never()).setTransactionMode(any());
Expand All @@ -89,7 +89,7 @@ public void testBeginReadOnly() {
"start work read only")) {
ParsedStatement statement = parser.parse(Statement.of(sql));
assertEquals(sql, StatementType.CLIENT_SIDE, statement.getType());
statement.getClientSideStatement().execute(executor, sql);
statement.getClientSideStatement().execute(executor, statement);

verify(connection, times(index)).beginTransaction();
verify(connection, times(index)).setTransactionMode(TransactionMode.READ_ONLY_TRANSACTION);
Expand All @@ -114,7 +114,7 @@ public void testBeginReadWrite() {
"start work read write")) {
ParsedStatement statement = parser.parse(Statement.of(sql));
assertEquals(sql, StatementType.CLIENT_SIDE, statement.getType());
statement.getClientSideStatement().execute(executor, sql);
statement.getClientSideStatement().execute(executor, statement);

verify(connection, times(index)).beginTransaction();
verify(connection, times(index)).setTransactionMode(TransactionMode.READ_WRITE_TRANSACTION);
Expand All @@ -140,7 +140,7 @@ public void testBeginReadOnlyWithIsolationLevel() {
"begin read write , \nisolation level default\n\t,read only")) {
ParsedStatement statement = parser.parse(Statement.of(sql));
assertEquals(sql, StatementType.CLIENT_SIDE, statement.getType());
statement.getClientSideStatement().execute(executor, sql);
statement.getClientSideStatement().execute(executor, statement);

verify(connection, times(index)).beginTransaction();
verify(connection, times(index)).setTransactionMode(TransactionMode.READ_ONLY_TRANSACTION);
Expand Down Expand Up @@ -173,7 +173,7 @@ public void testBeginWithNotDeferrable() {
"begin not deferrable read write , \nisolation level default\n\t,read only")) {
ParsedStatement statement = parser.parse(Statement.of(sql));
assertEquals(sql, StatementType.CLIENT_SIDE, statement.getType());
statement.getClientSideStatement().execute(executor, sql);
statement.getClientSideStatement().execute(executor, statement);

verify(connection, times(index)).beginTransaction();
verify(connection, times(index)).setTransactionMode(TransactionMode.READ_ONLY_TRANSACTION);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,14 +51,18 @@ public void setup() {
parser = AbstractStatementParser.getInstance(dialect);
}

ParsedStatement parse(String sql) {
return parser.parse(Statement.of(sql));
}

@Test
public void testExecuteGetAutocommit() {
ParsedStatement statement = parser.parse(Statement.of("show variable autocommit"));
ConnectionImpl connection = mock(ConnectionImpl.class);
ConnectionStatementExecutorImpl executor = mock(ConnectionStatementExecutorImpl.class);
when(executor.getConnection()).thenReturn(connection);
when(executor.statementShowAutocommit()).thenCallRealMethod();
statement.getClientSideStatement().execute(executor, "show variable autocommit");
statement.getClientSideStatement().execute(executor, statement);
verify(connection, times(1)).isAutocommit();
}

Expand All @@ -70,9 +74,7 @@ public void testExecuteGetReadOnly() {
ConnectionImpl connection = mock(ConnectionImpl.class);
when(connection.getDialect()).thenReturn(dialect);
ConnectionStatementExecutorImpl executor = new ConnectionStatementExecutorImpl(connection);
statement
.getClientSideStatement()
.execute(executor, String.format("show variable %sreadonly", getNamespace(dialect)));
statement.getClientSideStatement().execute(executor, statement);
verify(connection, times(1)).isReadOnly();
}

Expand All @@ -86,10 +88,7 @@ public void testExecuteGetAutocommitDmlMode() {
when(connection.getDialect()).thenReturn(dialect);
ConnectionStatementExecutorImpl executor = new ConnectionStatementExecutorImpl(connection);
when(connection.getAutocommitDmlMode()).thenReturn(AutocommitDmlMode.TRANSACTIONAL);
statement
.getClientSideStatement()
.execute(
executor, String.format("show variable %sautocommit_dml_mode", getNamespace(dialect)));
statement.getClientSideStatement().execute(executor, statement);
verify(connection, times(1)).getAutocommitDmlMode();
}

Expand All @@ -102,7 +101,7 @@ public void testExecuteGetStatementTimeout() {
when(executor.statementShowStatementTimeout()).thenCallRealMethod();
when(connection.hasStatementTimeout()).thenReturn(true);
when(connection.getStatementTimeout(TimeUnit.NANOSECONDS)).thenReturn(1L);
statement.getClientSideStatement().execute(executor, "show variable statement_timeout");
statement.getClientSideStatement().execute(executor, statement);
verify(connection, times(2)).getStatementTimeout(TimeUnit.NANOSECONDS);
}

Expand All @@ -115,9 +114,7 @@ public void testExecuteGetReadTimestamp() {
when(connection.getDialect()).thenReturn(dialect);
ConnectionStatementExecutorImpl executor = new ConnectionStatementExecutorImpl(connection);
when(connection.getReadTimestampOrNull()).thenReturn(Timestamp.now());
statement
.getClientSideStatement()
.execute(executor, String.format("show variable %sread_timestamp", getNamespace(dialect)));
statement.getClientSideStatement().execute(executor, statement);
verify(connection, times(1)).getReadTimestampOrNull();
}

Expand All @@ -130,10 +127,7 @@ public void testExecuteGetCommitTimestamp() {
when(connection.getDialect()).thenReturn(dialect);
ConnectionStatementExecutorImpl executor = new ConnectionStatementExecutorImpl(connection);
when(connection.getCommitTimestampOrNull()).thenReturn(Timestamp.now());
statement
.getClientSideStatement()
.execute(
executor, String.format("show variable %scommit_timestamp", getNamespace(dialect)));
statement.getClientSideStatement().execute(executor, statement);
verify(connection, times(1)).getCommitTimestampOrNull();
}

Expand All @@ -147,10 +141,7 @@ public void testExecuteGetReadOnlyStaleness() {
when(connection.getDialect()).thenReturn(dialect);
ConnectionStatementExecutorImpl executor = new ConnectionStatementExecutorImpl(connection);
when(connection.getReadOnlyStaleness()).thenReturn(TimestampBound.strong());
statement
.getClientSideStatement()
.execute(
executor, String.format("show variable %sread_only_staleness", getNamespace(dialect)));
statement.getClientSideStatement().execute(executor, statement);
verify(connection, times(1)).getReadOnlyStaleness();
}

Expand All @@ -164,10 +155,7 @@ public void testExecuteGetOptimizerVersion() {
when(connection.getDialect()).thenReturn(dialect);
ConnectionStatementExecutorImpl executor = new ConnectionStatementExecutorImpl(connection);
when(connection.getOptimizerVersion()).thenReturn("1");
statement
.getClientSideStatement()
.execute(
executor, String.format("show variable %soptimizer_version", getNamespace(dialect)));
statement.getClientSideStatement().execute(executor, statement);
verify(connection, times(1)).getOptimizerVersion();
}

Expand All @@ -182,11 +170,7 @@ public void testExecuteGetOptimizerStatisticsPackage() {
when(connection.getDialect()).thenReturn(dialect);
ConnectionStatementExecutorImpl executor = new ConnectionStatementExecutorImpl(connection);
when(connection.getOptimizerStatisticsPackage()).thenReturn("custom-package");
statement
.getClientSideStatement()
.execute(
executor,
String.format("show variable %soptimizer_statistics_package", getNamespace(dialect)));
statement.getClientSideStatement().execute(executor, statement);
verify(connection, times(1)).getOptimizerStatisticsPackage();
}

Expand All @@ -196,7 +180,7 @@ public void testExecuteBegin() {
for (String statement : subject.getClientSideStatement().getExampleStatements()) {
ConnectionImpl connection = mock(ConnectionImpl.class);
ConnectionStatementExecutorImpl executor = new ConnectionStatementExecutorImpl(connection);
subject.getClientSideStatement().execute(executor, statement);
subject.getClientSideStatement().execute(executor, parse(statement));
verify(connection, times(1)).beginTransaction();
}
}
Expand All @@ -209,7 +193,7 @@ public void testExecuteCommit() {
ConnectionStatementExecutorImpl executor = mock(ConnectionStatementExecutorImpl.class);
when(executor.getConnection()).thenReturn(connection);
when(executor.statementCommit()).thenCallRealMethod();
subject.getClientSideStatement().execute(executor, statement);
subject.getClientSideStatement().execute(executor, parse(statement));
verify(connection, times(1)).commit();
}
}
Expand All @@ -222,7 +206,7 @@ public void testExecuteRollback() {
ConnectionStatementExecutorImpl executor = mock(ConnectionStatementExecutorImpl.class);
when(executor.getConnection()).thenReturn(connection);
when(executor.statementRollback()).thenCallRealMethod();
subject.getClientSideStatement().execute(executor, statement);
subject.getClientSideStatement().execute(executor, parse(statement));
verify(connection, times(1)).rollback();
}
}
Expand Down
Loading

0 comments on commit c34d51e

Please sign in to comment.