Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore: refactor client side statements to accept the entire parsed statement #2556

Merged
merged 2 commits into from
Aug 2, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
package com.google.cloud.spanner.connection;

import com.google.cloud.spanner.ResultSet;
import com.google.cloud.spanner.connection.AbstractStatementParser.ParsedStatement;
import com.google.cloud.spanner.connection.StatementResult.ClientSideStatementType;
import java.util.List;

Expand Down Expand Up @@ -63,5 +64,5 @@ interface ClientSideStatement {
* needed for the execution of the {@link ClientSideStatement}.
* @return the result of the execution of the statement.
*/
StatementResult execute(ConnectionStatementExecutor executor, String statement);
StatementResult execute(ConnectionStatementExecutor executor, ParsedStatement statement);
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@

package com.google.cloud.spanner.connection;

import com.google.cloud.spanner.connection.AbstractStatementParser.ParsedStatement;

/**
* A {@link ClientSideStatementExecutor} is used to compile {@link ClientSideStatement}s from the
* json source file, and to execute these against a {@link Connection} (through a {@link
Expand All @@ -29,13 +31,13 @@ interface ClientSideStatementExecutor {
*
* @param connectionExecutor The {@link ConnectionStatementExecutor} to use to execute the
* statement on a {@link Connection}.
* @param sql The sql statement that is executed. This can be used to parse any additional
* @param statement The statement that is executed. This can be used to parse any additional
* arguments that might be needed for the execution of the {@link ClientSideStatementImpl}.
* @return the result of the execution.
* @throws Exception If an error occurs while executing the statement, for example if an invalid
* argument has been specified in the sql statement, or if the statement is invalid for the
* current state of the {@link Connection}.
*/
StatementResult execute(ConnectionStatementExecutor connectionExecutor, String sql)
StatementResult execute(ConnectionStatementExecutor connectionExecutor, ParsedStatement statement)
throws Exception;
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

import com.google.cloud.spanner.ErrorCode;
import com.google.cloud.spanner.SpannerExceptionFactory;
import com.google.cloud.spanner.connection.AbstractStatementParser.ParsedStatement;
import com.google.cloud.spanner.connection.ClientSideStatementImpl.CompileException;
import com.google.cloud.spanner.connection.ClientSideStatementValueConverters.ExplainCommandConverter;
import com.google.common.collect.ImmutableSet;
Expand Down Expand Up @@ -47,9 +48,10 @@ class ClientSideStatementExplainExecutor implements ClientSideStatementExecutor
}

@Override
public StatementResult execute(ConnectionStatementExecutor connection, String sql)
public StatementResult execute(ConnectionStatementExecutor connection, ParsedStatement statement)
throws Exception {
return (StatementResult) method.invoke(connection, getParameterValue(sql));
return (StatementResult)
method.invoke(connection, getParameterValue(statement.getSqlWithoutComments()));
}

String getParameterValue(String sql) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
package com.google.cloud.spanner.connection;

import com.google.cloud.spanner.SpannerException;
import com.google.cloud.spanner.connection.AbstractStatementParser.ParsedStatement;
import com.google.cloud.spanner.connection.StatementResult.ClientSideStatementType;
import com.google.cloud.spanner.connection.StatementResult.ResultType;
import com.google.common.base.Preconditions;
Expand Down Expand Up @@ -160,7 +161,8 @@ ClientSideStatementImpl compile() throws CompileException {
}

@Override
public StatementResult execute(ConnectionStatementExecutor connection, String statement) {
public StatementResult execute(
ConnectionStatementExecutor connection, ParsedStatement statement) {
Preconditions.checkState(executor != null, "This statement has not been compiled");
try {
return executor.execute(connection, statement);
Expand All @@ -170,9 +172,9 @@ public StatementResult execute(ConnectionStatementExecutor connection, String st
if (e.getCause() instanceof SpannerException) {
throw (SpannerException) e.getCause();
}
throw new ExecuteException(e.getCause(), this, statement);
throw new ExecuteException(e.getCause(), this, statement.getStatement().getSql());
} catch (Exception e) {
throw new ExecuteException(e, this, statement);
throw new ExecuteException(e, this, statement.getStatement().getSql());
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

package com.google.cloud.spanner.connection;

import com.google.cloud.spanner.connection.AbstractStatementParser.ParsedStatement;
import com.google.cloud.spanner.connection.ClientSideStatementImpl.CompileException;
import java.lang.reflect.Method;

Expand All @@ -42,7 +43,7 @@ class ClientSideStatementNoParamExecutor implements ClientSideStatementExecutor
}

@Override
public StatementResult execute(ConnectionStatementExecutor connection, String statement)
public StatementResult execute(ConnectionStatementExecutor connection, ParsedStatement statement)
throws Exception {
return (StatementResult) method.invoke(connection);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

import com.google.cloud.spanner.ErrorCode;
import com.google.cloud.spanner.SpannerExceptionFactory;
import com.google.cloud.spanner.connection.AbstractStatementParser.ParsedStatement;
import com.google.cloud.spanner.connection.ClientSideStatementImpl.CompileException;
import com.google.cloud.spanner.connection.ClientSideStatementValueConverters.PgTransactionModeConverter;
import java.lang.reflect.Method;
Expand All @@ -42,9 +43,10 @@ class ClientSideStatementPgBeginExecutor implements ClientSideStatementExecutor
}

@Override
public StatementResult execute(ConnectionStatementExecutor connection, String sql)
public StatementResult execute(ConnectionStatementExecutor connection, ParsedStatement statement)
throws Exception {
return (StatementResult) method.invoke(connection, getParameterValue(sql));
return (StatementResult)
method.invoke(connection, getParameterValue(statement.getSqlWithoutComments()));
}

PgTransactionMode getParameterValue(String sql) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

import com.google.cloud.spanner.ErrorCode;
import com.google.cloud.spanner.SpannerExceptionFactory;
import com.google.cloud.spanner.connection.AbstractStatementParser.ParsedStatement;
import com.google.cloud.spanner.connection.ClientSideStatementImpl.CompileException;
import com.google.common.base.Preconditions;
import java.lang.reflect.Constructor;
Expand Down Expand Up @@ -72,9 +73,10 @@ class ClientSideStatementSetExecutor<T> implements ClientSideStatementExecutor {
}

@Override
public StatementResult execute(ConnectionStatementExecutor connection, String sql)
public StatementResult execute(ConnectionStatementExecutor connection, ParsedStatement statement)
throws Exception {
return (StatementResult) method.invoke(connection, getParameterValue(sql));
return (StatementResult)
method.invoke(connection, getParameterValue(statement.getSqlWithoutComments()));
}

T getParameterValue(String sql) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -925,7 +925,7 @@ public StatementResult execute(Statement statement) {
case CLIENT_SIDE:
return parsedStatement
.getClientSideStatement()
.execute(connectionStatementExecutor, parsedStatement.getSqlWithoutComments());
.execute(connectionStatementExecutor, parsedStatement);
case QUERY:
return StatementResultImpl.of(
internalExecuteQuery(CallType.SYNC, parsedStatement, AnalyzeMode.NONE));
Expand Down Expand Up @@ -957,7 +957,7 @@ public AsyncStatementResult executeAsync(Statement statement) {
return AsyncStatementResultImpl.of(
parsedStatement
.getClientSideStatement()
.execute(connectionStatementExecutor, parsedStatement.getSqlWithoutComments()),
.execute(connectionStatementExecutor, parsedStatement),
spanner.getAsyncExecutorProvider());
case QUERY:
return AsyncStatementResultImpl.of(
Expand Down Expand Up @@ -1010,7 +1010,7 @@ private ResultSet parseAndExecuteQuery(
case CLIENT_SIDE:
return parsedStatement
.getClientSideStatement()
.execute(connectionStatementExecutor, parsedStatement.getSqlWithoutComments())
.execute(connectionStatementExecutor, parsedStatement)
.getResultSet();
case QUERY:
return internalExecuteQuery(callType, parsedStatement, analyzeMode, options);
Expand Down Expand Up @@ -1050,7 +1050,7 @@ private AsyncResultSet parseAndExecuteQueryAsync(
return ResultSets.toAsyncResultSet(
parsedStatement
.getClientSideStatement()
.execute(connectionStatementExecutor, parsedStatement.getSqlWithoutComments())
.execute(connectionStatementExecutor, parsedStatement)
.getResultSet(),
spanner.getAsyncExecutorProvider(),
options);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ public void testBeginWithNoOption() {
"start work isolation level serializable")) {
ParsedStatement statement = parser.parse(Statement.of(sql));
assertEquals(sql, StatementType.CLIENT_SIDE, statement.getType());
statement.getClientSideStatement().execute(executor, sql);
statement.getClientSideStatement().execute(executor, statement);

verify(connection, times(index)).beginTransaction();
verify(connection, never()).setTransactionMode(any());
Expand All @@ -89,7 +89,7 @@ public void testBeginReadOnly() {
"start work read only")) {
ParsedStatement statement = parser.parse(Statement.of(sql));
assertEquals(sql, StatementType.CLIENT_SIDE, statement.getType());
statement.getClientSideStatement().execute(executor, sql);
statement.getClientSideStatement().execute(executor, statement);

verify(connection, times(index)).beginTransaction();
verify(connection, times(index)).setTransactionMode(TransactionMode.READ_ONLY_TRANSACTION);
Expand All @@ -114,7 +114,7 @@ public void testBeginReadWrite() {
"start work read write")) {
ParsedStatement statement = parser.parse(Statement.of(sql));
assertEquals(sql, StatementType.CLIENT_SIDE, statement.getType());
statement.getClientSideStatement().execute(executor, sql);
statement.getClientSideStatement().execute(executor, statement);

verify(connection, times(index)).beginTransaction();
verify(connection, times(index)).setTransactionMode(TransactionMode.READ_WRITE_TRANSACTION);
Expand All @@ -140,7 +140,7 @@ public void testBeginReadOnlyWithIsolationLevel() {
"begin read write , \nisolation level default\n\t,read only")) {
ParsedStatement statement = parser.parse(Statement.of(sql));
assertEquals(sql, StatementType.CLIENT_SIDE, statement.getType());
statement.getClientSideStatement().execute(executor, sql);
statement.getClientSideStatement().execute(executor, statement);

verify(connection, times(index)).beginTransaction();
verify(connection, times(index)).setTransactionMode(TransactionMode.READ_ONLY_TRANSACTION);
Expand Down Expand Up @@ -173,7 +173,7 @@ public void testBeginWithNotDeferrable() {
"begin not deferrable read write , \nisolation level default\n\t,read only")) {
ParsedStatement statement = parser.parse(Statement.of(sql));
assertEquals(sql, StatementType.CLIENT_SIDE, statement.getType());
statement.getClientSideStatement().execute(executor, sql);
statement.getClientSideStatement().execute(executor, statement);

verify(connection, times(index)).beginTransaction();
verify(connection, times(index)).setTransactionMode(TransactionMode.READ_ONLY_TRANSACTION);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,14 +51,18 @@ public void setup() {
parser = AbstractStatementParser.getInstance(dialect);
}

ParsedStatement parse(String sql) {
return parser.parse(Statement.of(sql));
}

@Test
public void testExecuteGetAutocommit() {
ParsedStatement statement = parser.parse(Statement.of("show variable autocommit"));
ConnectionImpl connection = mock(ConnectionImpl.class);
ConnectionStatementExecutorImpl executor = mock(ConnectionStatementExecutorImpl.class);
when(executor.getConnection()).thenReturn(connection);
when(executor.statementShowAutocommit()).thenCallRealMethod();
statement.getClientSideStatement().execute(executor, "show variable autocommit");
statement.getClientSideStatement().execute(executor, statement);
verify(connection, times(1)).isAutocommit();
}

Expand All @@ -70,9 +74,7 @@ public void testExecuteGetReadOnly() {
ConnectionImpl connection = mock(ConnectionImpl.class);
when(connection.getDialect()).thenReturn(dialect);
ConnectionStatementExecutorImpl executor = new ConnectionStatementExecutorImpl(connection);
statement
.getClientSideStatement()
.execute(executor, String.format("show variable %sreadonly", getNamespace(dialect)));
statement.getClientSideStatement().execute(executor, statement);
verify(connection, times(1)).isReadOnly();
}

Expand All @@ -86,10 +88,7 @@ public void testExecuteGetAutocommitDmlMode() {
when(connection.getDialect()).thenReturn(dialect);
ConnectionStatementExecutorImpl executor = new ConnectionStatementExecutorImpl(connection);
when(connection.getAutocommitDmlMode()).thenReturn(AutocommitDmlMode.TRANSACTIONAL);
statement
.getClientSideStatement()
.execute(
executor, String.format("show variable %sautocommit_dml_mode", getNamespace(dialect)));
statement.getClientSideStatement().execute(executor, statement);
verify(connection, times(1)).getAutocommitDmlMode();
}

Expand All @@ -102,7 +101,7 @@ public void testExecuteGetStatementTimeout() {
when(executor.statementShowStatementTimeout()).thenCallRealMethod();
when(connection.hasStatementTimeout()).thenReturn(true);
when(connection.getStatementTimeout(TimeUnit.NANOSECONDS)).thenReturn(1L);
statement.getClientSideStatement().execute(executor, "show variable statement_timeout");
statement.getClientSideStatement().execute(executor, statement);
verify(connection, times(2)).getStatementTimeout(TimeUnit.NANOSECONDS);
}

Expand All @@ -115,9 +114,7 @@ public void testExecuteGetReadTimestamp() {
when(connection.getDialect()).thenReturn(dialect);
ConnectionStatementExecutorImpl executor = new ConnectionStatementExecutorImpl(connection);
when(connection.getReadTimestampOrNull()).thenReturn(Timestamp.now());
statement
.getClientSideStatement()
.execute(executor, String.format("show variable %sread_timestamp", getNamespace(dialect)));
statement.getClientSideStatement().execute(executor, statement);
verify(connection, times(1)).getReadTimestampOrNull();
}

Expand All @@ -130,10 +127,7 @@ public void testExecuteGetCommitTimestamp() {
when(connection.getDialect()).thenReturn(dialect);
ConnectionStatementExecutorImpl executor = new ConnectionStatementExecutorImpl(connection);
when(connection.getCommitTimestampOrNull()).thenReturn(Timestamp.now());
statement
.getClientSideStatement()
.execute(
executor, String.format("show variable %scommit_timestamp", getNamespace(dialect)));
statement.getClientSideStatement().execute(executor, statement);
verify(connection, times(1)).getCommitTimestampOrNull();
}

Expand All @@ -147,10 +141,7 @@ public void testExecuteGetReadOnlyStaleness() {
when(connection.getDialect()).thenReturn(dialect);
ConnectionStatementExecutorImpl executor = new ConnectionStatementExecutorImpl(connection);
when(connection.getReadOnlyStaleness()).thenReturn(TimestampBound.strong());
statement
.getClientSideStatement()
.execute(
executor, String.format("show variable %sread_only_staleness", getNamespace(dialect)));
statement.getClientSideStatement().execute(executor, statement);
verify(connection, times(1)).getReadOnlyStaleness();
}

Expand All @@ -164,10 +155,7 @@ public void testExecuteGetOptimizerVersion() {
when(connection.getDialect()).thenReturn(dialect);
ConnectionStatementExecutorImpl executor = new ConnectionStatementExecutorImpl(connection);
when(connection.getOptimizerVersion()).thenReturn("1");
statement
.getClientSideStatement()
.execute(
executor, String.format("show variable %soptimizer_version", getNamespace(dialect)));
statement.getClientSideStatement().execute(executor, statement);
verify(connection, times(1)).getOptimizerVersion();
}

Expand All @@ -182,11 +170,7 @@ public void testExecuteGetOptimizerStatisticsPackage() {
when(connection.getDialect()).thenReturn(dialect);
ConnectionStatementExecutorImpl executor = new ConnectionStatementExecutorImpl(connection);
when(connection.getOptimizerStatisticsPackage()).thenReturn("custom-package");
statement
.getClientSideStatement()
.execute(
executor,
String.format("show variable %soptimizer_statistics_package", getNamespace(dialect)));
statement.getClientSideStatement().execute(executor, statement);
verify(connection, times(1)).getOptimizerStatisticsPackage();
}

Expand All @@ -196,7 +180,7 @@ public void testExecuteBegin() {
for (String statement : subject.getClientSideStatement().getExampleStatements()) {
ConnectionImpl connection = mock(ConnectionImpl.class);
ConnectionStatementExecutorImpl executor = new ConnectionStatementExecutorImpl(connection);
subject.getClientSideStatement().execute(executor, statement);
subject.getClientSideStatement().execute(executor, parse(statement));
verify(connection, times(1)).beginTransaction();
}
}
Expand All @@ -209,7 +193,7 @@ public void testExecuteCommit() {
ConnectionStatementExecutorImpl executor = mock(ConnectionStatementExecutorImpl.class);
when(executor.getConnection()).thenReturn(connection);
when(executor.statementCommit()).thenCallRealMethod();
subject.getClientSideStatement().execute(executor, statement);
subject.getClientSideStatement().execute(executor, parse(statement));
verify(connection, times(1)).commit();
}
}
Expand All @@ -222,7 +206,7 @@ public void testExecuteRollback() {
ConnectionStatementExecutorImpl executor = mock(ConnectionStatementExecutorImpl.class);
when(executor.getConnection()).thenReturn(connection);
when(executor.statementRollback()).thenCallRealMethod();
subject.getClientSideStatement().execute(executor, statement);
subject.getClientSideStatement().execute(executor, parse(statement));
verify(connection, times(1)).rollback();
}
}
Expand Down
Loading