From 56b7c8d232b91bc2e18d3d84e2a554862b7ba14d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Knut=20Olav=20L=C3=B8ite?= Date: Fri, 11 Feb 2022 16:46:23 +0100 Subject: [PATCH] fix: untyped null parameters would cause NPE Adding an untyped null value as a parameter to a statement was not possible, as: 1. The parameter collection would allow a null value to be added, but when the statement was built, it would throw a NullPointerException because it used an ImmutableMap internally, which does not support null values. 2. The translation from a hand-written Statement instance to a proto Statement instance would fail, as it did not take into account that the parameter could be null. Fixes #1679 --- .../cloud/spanner/AbstractReadContext.java | 16 ++++--- .../google/cloud/spanner/BatchClientImpl.java | 6 ++- .../spanner/PartitionedDmlTransaction.java | 6 ++- .../com/google/cloud/spanner/Statement.java | 19 +++++--- .../java/com/google/cloud/spanner/Value.java | 10 +++-- .../google/cloud/spanner/StatementTest.java | 5 +++ .../google/cloud/spanner/it/ITDMLTest.java | 43 +++++++++++++++++++ 7 files changed, 86 insertions(+), 19 deletions(-) diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/AbstractReadContext.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/AbstractReadContext.java index 8412fac67e3..ea658a36eab 100644 --- a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/AbstractReadContext.java +++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/AbstractReadContext.java @@ -583,8 +583,10 @@ ExecuteSqlRequest.Builder getExecuteSqlRequestBuilder( if (!stmtParameters.isEmpty()) { com.google.protobuf.Struct.Builder paramsBuilder = builder.getParamsBuilder(); for (Map.Entry param : stmtParameters.entrySet()) { - paramsBuilder.putFields(param.getKey(), param.getValue().toProto()); - builder.putParamTypes(param.getKey(), param.getValue().getType().toProto()); + paramsBuilder.putFields(param.getKey(), Value.toProto(param.getValue())); + if (param.getValue() != null) { + builder.putParamTypes(param.getKey(), param.getValue().getType().toProto()); + } } } if (withTransactionSelector) { @@ -612,10 +614,12 @@ ExecuteBatchDmlRequest.Builder getExecuteBatchDmlRequestBuilder( com.google.protobuf.Struct.Builder paramsBuilder = builder.getStatementsBuilder(idx).getParamsBuilder(); for (Map.Entry param : stmtParameters.entrySet()) { - paramsBuilder.putFields(param.getKey(), param.getValue().toProto()); - builder - .getStatementsBuilder(idx) - .putParamTypes(param.getKey(), param.getValue().getType().toProto()); + paramsBuilder.putFields(param.getKey(), Value.toProto(param.getValue())); + if (param.getValue() != null) { + builder + .getStatementsBuilder(idx) + .putParamTypes(param.getKey(), param.getValue().getType().toProto()); + } } } idx++; diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/BatchClientImpl.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/BatchClientImpl.java index c84bef77cf8..a12aefe7861 100644 --- a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/BatchClientImpl.java +++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/BatchClientImpl.java @@ -169,8 +169,10 @@ public List partitionQuery( if (!stmtParameters.isEmpty()) { Struct.Builder paramsBuilder = builder.getParamsBuilder(); for (Map.Entry param : stmtParameters.entrySet()) { - paramsBuilder.putFields(param.getKey(), param.getValue().toProto()); - builder.putParamTypes(param.getKey(), param.getValue().getType().toProto()); + paramsBuilder.putFields(param.getKey(), Value.toProto(param.getValue())); + if (param.getValue() != null) { + builder.putParamTypes(param.getKey(), param.getValue().getType().toProto()); + } } } TransactionSelector selector = getTransactionSelector(); diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/PartitionedDmlTransaction.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/PartitionedDmlTransaction.java index 2aeceb276da..78c092f1792 100644 --- a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/PartitionedDmlTransaction.java +++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/PartitionedDmlTransaction.java @@ -217,8 +217,10 @@ private void setParameters( if (!statementParameters.isEmpty()) { com.google.protobuf.Struct.Builder paramsBuilder = requestBuilder.getParamsBuilder(); for (Map.Entry param : statementParameters.entrySet()) { - paramsBuilder.putFields(param.getKey(), param.getValue().toProto()); - requestBuilder.putParamTypes(param.getKey(), param.getValue().getType().toProto()); + paramsBuilder.putFields(param.getKey(), Value.toProto(param.getValue())); + if (param.getValue() != null) { + requestBuilder.putParamTypes(param.getKey(), param.getValue().getType().toProto()); + } } } } diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/Statement.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/Statement.java index 401aef92998..3693e2cb587 100644 --- a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/Statement.java +++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/Statement.java @@ -21,9 +21,9 @@ import com.google.cloud.spanner.ReadContext.QueryAnalyzeMode; import com.google.common.base.Preconditions; -import com.google.common.collect.ImmutableMap; import com.google.spanner.v1.ExecuteSqlRequest.QueryOptions; import java.io.Serializable; +import java.util.Collections; import java.util.HashMap; import java.util.Map; import java.util.Objects; @@ -56,11 +56,11 @@ public final class Statement implements Serializable { private static final long serialVersionUID = -1967958247625065259L; - private final ImmutableMap parameters; + private final Map parameters; private final String sql; private final QueryOptions queryOptions; - private Statement(String sql, ImmutableMap parameters, QueryOptions queryOptions) { + private Statement(String sql, Map parameters, QueryOptions queryOptions) { this.sql = sql; this.parameters = parameters; this.queryOptions = queryOptions; @@ -112,14 +112,17 @@ public ValueBinder bind(String parameter) { public Statement build() { checkState( currentBinding == null, "Binding for parameter '%s' is incomplete.", currentBinding); - return new Statement(sqlBuffer.toString(), ImmutableMap.copyOf(parameters), queryOptions); + return new Statement( + sqlBuffer.toString(), + Collections.unmodifiableMap(new HashMap<>(parameters)), + queryOptions); } private class Binder extends ValueBinder { @Override Builder handle(Value value) { Preconditions.checkArgument( - !value.isCommitTimestamp(), + value == null || !value.isCommitTimestamp(), "Mutation.COMMIT_TIMESTAMP cannot be bound as a query parameter"); checkState(currentBinding != null, "No binding in progress"); parameters.put(currentBinding, value); @@ -218,7 +221,11 @@ StringBuilder toString(StringBuilder b) { b.append(", "); } b.append(parameter.getKey()).append(": "); - parameter.getValue().toString(b); + if (parameter.getValue() == null) { + b.append("NULL"); + } else { + parameter.getValue().toString(b); + } } b.append("}"); } diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/Value.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/Value.java index 38b4361d755..db9cee56e92 100644 --- a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/Value.java +++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/Value.java @@ -73,6 +73,9 @@ public abstract class Value implements Serializable { */ public static final Timestamp COMMIT_TIMESTAMP = Timestamp.ofTimeMicroseconds(0L); + static final com.google.protobuf.Value NULL_PROTO = + com.google.protobuf.Value.newBuilder().setNullValue(NullValue.NULL_VALUE).build(); + /** Constant to specify a PG Numeric NaN value. */ public static final String NAN = "NaN"; @@ -622,6 +625,10 @@ public String toString() { // END OF PUBLIC API. + static com.google.protobuf.Value toProto(Value value) { + return value == null ? NULL_PROTO : value.toProto(); + } + abstract void toString(StringBuilder b); abstract com.google.protobuf.Value toProto(); @@ -737,9 +744,6 @@ Value newValue(boolean isNull, BitSet nulls, boolean[] values) { /** Template class for {@code Value} implementations. */ private abstract static class AbstractValue extends Value { - static final com.google.protobuf.Value NULL_PROTO = - com.google.protobuf.Value.newBuilder().setNullValue(NullValue.NULL_VALUE).build(); - private final boolean isNull; private final Type type; diff --git a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/StatementTest.java b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/StatementTest.java index b2c8244f64d..d5b5a3ec619 100644 --- a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/StatementTest.java +++ b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/StatementTest.java @@ -61,6 +61,8 @@ public void serialization() { .append("bytes_field = @bytes_field ") .bind("bytes_field") .to(ByteArray.fromBase64("abcd")) + .bind("untyped_null_field") + .to((Value) null) .build(); reserializeAndAssert(stmt); } @@ -165,6 +167,9 @@ public void equalsAndHashCode() { tester.addEqualityGroup(Statement.newBuilder("SELECT @x, @y").bind("y").to(2).build()); tester.addEqualityGroup( Statement.newBuilder("SELECT @x, @y").bind("x").to(1).bind("y").to(2).build()); + tester.addEqualityGroup( + Statement.newBuilder("SELECT @x, @y").bind("x").to((Value) null).build(), + Statement.newBuilder("SELECT @x, @y").bind("x").to((Value) null).build()); tester.testEquals(); } } diff --git a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/it/ITDMLTest.java b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/it/ITDMLTest.java index 2c7a8da27c3..5ea30912103 100644 --- a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/it/ITDMLTest.java +++ b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/it/ITDMLTest.java @@ -18,7 +18,12 @@ import static com.google.cloud.spanner.testing.EmulatorSpannerHelper.isUsingEmulator; import static com.google.common.truth.Truth.assertThat; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; +import static org.junit.Assume.assumeFalse; import com.google.cloud.spanner.AbortedException; import com.google.cloud.spanner.Database; @@ -38,6 +43,7 @@ import com.google.cloud.spanner.TimestampBound; import com.google.cloud.spanner.TransactionRunner; import com.google.cloud.spanner.TransactionRunner.TransactionCallable; +import com.google.cloud.spanner.Value; import com.google.cloud.spanner.connection.ConnectionOptions; import java.util.ArrayList; import java.util.Arrays; @@ -380,4 +386,41 @@ public void standardDMLWithExecuteSQL() { // checks for multi-stmts within a txn, therefore also verifying seqNo. executeQuery(DML_COUNT * 2, updateDml(), deleteDml()); } + + @Test + public void testUntypedNullValues() { + assumeFalse( + "Spanner PostgreSQL does not yet support untyped null values", + dialect.dialect == Dialect.POSTGRESQL); + + DatabaseClient client = getClient(dialect.dialect); + String sql; + if (dialect.dialect == Dialect.POSTGRESQL) { + sql = "INSERT INTO T (K, V) VALUES ($1, $2)"; + } else { + sql = "INSERT INTO T (K, V) VALUES (@p1, @p2)"; + } + Long updateCount = + client + .readWriteTransaction() + .run( + transaction -> + transaction.executeUpdate( + Statement.newBuilder(sql) + .bind("p1") + .to("k1") + .bind("p2") + .to((Value) null) + .build())); + + assertNotNull(updateCount); + assertEquals(1L, updateCount.longValue()); + + // Read the row back and verify that the value is null. + try (ResultSet resultSet = client.singleUse().executeQuery(Statement.of("SELECT V FROM T"))) { + assertTrue(resultSet.next()); + assertTrue(resultSet.isNull(0)); + assertFalse(resultSet.next()); + } + } }