Skip to content

Commit

Permalink
Enable ssl use for cluster in CassandraIO (#29302)
Browse files Browse the repository at this point in the history
* Add properties / methods to facilitate optional ssl use

Add properties / methods to facilitate optional ssl use by surfacing datastax driver functionality to programmatically configure ssl options when building Cluster.

* Add comments with link to driver documentation for programmatic config

Add comments with link to driver documentation for programmatic config of ssl

* Apply suggested change to address spotless check failure.
  • Loading branch information
niv-lac authored Nov 4, 2023
1 parent 35f19d5 commit 87ca614
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import com.datastax.driver.core.ConsistencyLevel;
import com.datastax.driver.core.PlainTextAuthProvider;
import com.datastax.driver.core.QueryOptions;
import com.datastax.driver.core.SSLOptions;
import com.datastax.driver.core.Session;
import com.datastax.driver.core.SocketOptions;
import com.datastax.driver.core.policies.DCAwareRoundRobinPolicy;
Expand Down Expand Up @@ -192,6 +193,9 @@ public abstract static class Read<T> extends PTransform<PBegin, PCollection<T>>
@Nullable
abstract ValueProvider<Set<RingRange>> ringRanges();

@Nullable
abstract ValueProvider<SSLOptions> sslOptions();

abstract Builder<T> builder();

/** Specify the hosts of the Apache Cassandra instances. */
Expand Down Expand Up @@ -385,6 +389,22 @@ public Read<T> withRingRanges(ValueProvider<Set<RingRange>> ringRange) {
return builder().setRingRanges(ringRange).build();
}

/**
* Optionally, specify {@link SSLOptions} configuration to utilize SSL. See
* https://docs.datastax.com/en/developer/java-driver/3.11/manual/ssl/#jsse-programmatic
*/
public Read<T> withSsl(SSLOptions sslOptions) {
return withSsl(ValueProvider.StaticValueProvider.of(sslOptions));
}

/**
* Optionally, specify {@link SSLOptions} configuration to utilize SSL. See
* https://docs.datastax.com/en/developer/java-driver/3.11/manual/ssl/#jsse-programmatic
*/
public Read<T> withSsl(ValueProvider<SSLOptions> sslOptions) {
return builder().setSslOptions(sslOptions).build();
}

@Override
public PCollection<T> expand(PBegin input) {
checkArgument((hosts() != null && port() != null), "WithHosts() and withPort() are required");
Expand Down Expand Up @@ -422,7 +442,8 @@ private static <T> Set<RingRange> getRingRanges(Read<T> read) {
read.localDc(),
read.consistencyLevel(),
read.connectTimeout(),
read.readTimeout())) {
read.readTimeout(),
read.sslOptions())) {
if (isMurmur3Partitioner(cluster)) {
LOG.info("Murmur3Partitioner detected, splitting");
Integer splitCount;
Expand Down Expand Up @@ -495,6 +516,8 @@ abstract static class Builder<T> {

abstract Builder<T> setRingRanges(ValueProvider<Set<RingRange>> ringRange);

abstract Builder<T> setSslOptions(ValueProvider<SSLOptions> sslOptions);

abstract Read<T> autoBuild();

public Read<T> build() {
Expand Down Expand Up @@ -543,6 +566,8 @@ public abstract static class Write<T> extends PTransform<PCollection<T>, PDone>

abstract @Nullable ValueProvider<Integer> readTimeout();

abstract @Nullable ValueProvider<SSLOptions> sslOptions();

abstract @Nullable SerializableFunction<Session, Mapper> mapperFactoryFn();

abstract Builder<T> builder();
Expand Down Expand Up @@ -725,6 +750,22 @@ public Write<T> withMapperFactoryFn(SerializableFunction<Session, Mapper> mapper
return builder().setMapperFactoryFn(mapperFactoryFn).build();
}

/**
* Optionally, specify {@link SSLOptions} configuration to utilize SSL. See
* https://docs.datastax.com/en/developer/java-driver/3.11/manual/ssl/#jsse-programmatic
*/
public Write<T> withSsl(SSLOptions sslOptions) {
return withSsl(ValueProvider.StaticValueProvider.of(sslOptions));
}

/**
* Optionally, specify {@link SSLOptions} configuration to utilize SSL. See
* https://docs.datastax.com/en/developer/java-driver/3.11/manual/ssl/#jsse-programmatic
*/
public Write<T> withSsl(ValueProvider<SSLOptions> sslOptions) {
return builder().setSslOptions(sslOptions).build();
}

@Override
public void validate(PipelineOptions pipelineOptions) {
checkState(
Expand Down Expand Up @@ -799,6 +840,8 @@ abstract static class Builder<T> {

abstract Optional<SerializableFunction<Session, Mapper>> mapperFactoryFn();

abstract Builder<T> setSslOptions(ValueProvider<SSLOptions> sslOptions);

abstract Write<T> autoBuild(); // not public

public Write<T> build() {
Expand Down Expand Up @@ -880,7 +923,8 @@ static Cluster getCluster(
ValueProvider<String> localDc,
ValueProvider<String> consistencyLevel,
ValueProvider<Integer> connectTimeout,
ValueProvider<Integer> readTimeout) {
ValueProvider<Integer> readTimeout,
ValueProvider<SSLOptions> sslOptions) {

Cluster.Builder builder =
Cluster.builder().addContactPoints(hosts.get().toArray(new String[0])).withPort(port.get());
Expand Down Expand Up @@ -913,6 +957,10 @@ static Cluster getCluster(
socketOptions.setReadTimeoutMillis(readTimeout.get());
}

if (sslOptions != null) {
builder.withSSL(sslOptions.get());
}

return builder.build();
}

Expand Down Expand Up @@ -941,7 +989,8 @@ private static class Mutator<T> {
spec.localDc(),
spec.consistencyLevel(),
spec.connectTimeout(),
spec.readTimeout());
spec.readTimeout(),
spec.sslOptions());
this.session = cluster.connect(spec.keyspace().get());
this.mapperFactoryFn = spec.mapperFactoryFn();
this.mutateFutures = new ArrayList<>();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,8 @@ static Session getSession(Read<?> read) {
read.localDc(),
read.consistencyLevel(),
read.connectTimeout(),
read.readTimeout()));
read.readTimeout(),
read.sslOptions()));
return sessionMap.computeIfAbsent(
readToSessionHash(read),
k -> cluster.connect(Objects.requireNonNull(read.keyspace()).get()));
Expand Down

0 comments on commit 87ca614

Please sign in to comment.