diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/SpannerAccessor.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/SpannerAccessor.java index eb1860902c6f..351425db67a3 100644 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/SpannerAccessor.java +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/SpannerAccessor.java @@ -247,7 +247,16 @@ static SpannerOptions buildSpannerOptions(SpannerConfig spannerConfig) { } if (plainText != null && Boolean.TRUE.equals(plainText.get())) { builder.setChannelConfigurator(b -> b.usePlaintext()); - builder.setCredentials(NoCredentials.getInstance()); + } + ValueProvider clientCert = spannerConfig.getClientCertPath(); + ValueProvider clientKey = spannerConfig.getClientCertKeyPath(); + if (clientCert != null + && clientKey != null + && clientCert.isAccessible() + && clientKey.isAccessible() + && !Strings.isNullOrEmpty(clientCert.get()) + && !Strings.isNullOrEmpty(clientKey.get())) { + builder.useClientCert(clientCert.get(), clientKey.get()); } } @@ -273,6 +282,8 @@ static SpannerOptions buildSpannerOptions(SpannerConfig spannerConfig) { ValueProvider credentials = spannerConfig.getCredentials(); if (credentials != null && credentials.get() != null) { builder.setCredentials(credentials.get()); + } else if (experimentalHost != null && !Strings.isNullOrEmpty(experimentalHost.get())) { + builder.setCredentials(NoCredentials.getInstance()); } ValueProvider waitForSessionCreationDuration = diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/SpannerConfig.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/SpannerConfig.java index a17c851f38a0..92eac9108283 100644 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/SpannerConfig.java +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/SpannerConfig.java @@ -97,6 +97,10 @@ public String getHostValue() { public abstract @Nullable ValueProvider getPlainText(); + public abstract @Nullable ValueProvider getClientCertPath(); + + public abstract @Nullable ValueProvider getClientCertKeyPath(); + @VisibleForTesting abstract @Nullable ServiceFactory getServiceFactory(); @@ -194,6 +198,10 @@ abstract Builder setExecuteStreamingSqlRetrySettings( abstract Builder setWaitForSessionCreationDuration( ValueProvider waitForSessionCreationDuration); + abstract Builder setClientCertPath(ValueProvider clientCertPath); + + abstract Builder setClientCertKeyPath(ValueProvider clientCertKeyPath); + public abstract SpannerConfig build(); } @@ -414,4 +422,33 @@ public SpannerConfig withWaitForSessionCreationDuration( return withWaitForSessionCreationDuration( ValueProvider.StaticValueProvider.of(waitForSessionCreationDuration)); } + + /** + * Specifies certificate paths to use for mTLS channel. + * + *

Note: These parameters are only valid when using a Spanner Omni instance (set via {@code + * withExperimentalHost}). + * + * @param certPath Path to the client certificate file. + * @param keyPath Path to the client certificate key file. + */ + public SpannerConfig withClientCert( + ValueProvider certPath, ValueProvider keyPath) { + return toBuilder().setClientCertPath(certPath).setClientCertKeyPath(keyPath).build(); + } + + /** + * Specifies certificate paths to use for mTLS channel. + * + *

Note: These parameters are only valid when using a Spanner Omni instance (set via {@code + * withExperimentalHost}). + * + * @param certPath Path to the client certificate file. + * @param keyPath Path to the client certificate key file. + */ + public SpannerConfig withClientCert(String certPath, String keyPath) { + return withClientCert( + ValueProvider.StaticValueProvider.of(certPath), + ValueProvider.StaticValueProvider.of(keyPath)); + } } diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/SpannerIO.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/SpannerIO.java index d271e763aac5..c326541818b3 100644 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/SpannerIO.java +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/SpannerIO.java @@ -37,6 +37,7 @@ import com.google.api.gax.rpc.StatusCode.Code; import com.google.auth.Credentials; import com.google.auto.value.AutoValue; +import com.google.cloud.NoCredentials; import com.google.cloud.ServiceFactory; import com.google.cloud.Timestamp; import com.google.cloud.spanner.AbortedException; @@ -146,6 +147,7 @@ import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.MoreObjects; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Stopwatch; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Strings; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.cache.CacheBuilder; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.cache.CacheLoader; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.cache.LoadingCache; @@ -657,6 +659,35 @@ public ReadAll withUsingPlainTextChannel(boolean plainText) { return withUsingPlainTextChannel(ValueProvider.StaticValueProvider.of(plainText)); } + /** + * Specifies certificate paths to use for mTLS channel. + * + *

Note: These parameters are only valid when using Spanner Omni (set via {@code + * withExperimentalHost}). + * + * @param certPath Path to the client certificate file. + * @param keyPath Path to the client certificate key file. + */ + public ReadAll withClientCert(ValueProvider certPath, ValueProvider keyPath) { + SpannerConfig config = getSpannerConfig(); + return withSpannerConfig(config.withClientCert(certPath, keyPath)); + } + + /** + * Specifies certificate paths to use for mTLS channel. + * + *

Note: These parameters are only valid when using Spanner Omni (set via {@code + * withExperimentalHost}). + * + * @param certPath Path to the client certificate file. + * @param keyPath Path to the client certificate key file. + */ + public ReadAll withClientCert(String certPath, String keyPath) { + return withClientCert( + ValueProvider.StaticValueProvider.of(certPath), + ValueProvider.StaticValueProvider.of(keyPath)); + } + /** Specifies the Cloud Spanner database. */ public ReadAll withDatabaseId(ValueProvider databaseId) { SpannerConfig config = getSpannerConfig(); @@ -917,6 +948,35 @@ public Read withUsingPlainTextChannel(boolean plainText) { return withUsingPlainTextChannel(ValueProvider.StaticValueProvider.of(plainText)); } + /** + * Specifies certificate paths to use for mTLS channel. + * + *

Note: These parameters are only valid when using Spanner Omni (set via {@code + * withExperimentalHost}). + * + * @param certPath Path to the client certificate file. + * @param keyPath Path to the client certificate key file. + */ + public Read withClientCert(ValueProvider certPath, ValueProvider keyPath) { + SpannerConfig config = getSpannerConfig(); + return withSpannerConfig(config.withClientCert(certPath, keyPath)); + } + + /** + * Specifies certificate paths to use for mTLS channel. + * + *

Note: These parameters are only valid when using Spanner Omni (set via {@code + * withExperimentalHost}). + * + * @param certPath Path to the client certificate file. + * @param keyPath Path to the client certificate key file. + */ + public Read withClientCert(String certPath, String keyPath) { + return withClientCert( + ValueProvider.StaticValueProvider.of(certPath), + ValueProvider.StaticValueProvider.of(keyPath)); + } + /** If true the uses Cloud Spanner batch API. */ public Read withBatching(boolean batching) { return toBuilder().setBatching(batching).build(); @@ -1244,6 +1304,36 @@ public CreateTransaction withUsingPlainTextChannel(boolean plainText) { return withUsingPlainTextChannel(ValueProvider.StaticValueProvider.of(plainText)); } + /** + * Specifies certificate paths to use for mTLS channel. + * + *

Note: These parameters are only valid when using Spanner Omni (set via {@code + * withExperimentalHost}). + * + * @param certPath Path to the client certificate file. + * @param keyPath Path to the client certificate key file. + */ + public CreateTransaction withClientCert( + ValueProvider certPath, ValueProvider keyPath) { + SpannerConfig config = getSpannerConfig(); + return withSpannerConfig(config.withClientCert(certPath, keyPath)); + } + + /** + * Specifies certificate paths to use for mTLS channel. + * + *

Note: These parameters are only valid when using Spanner Omni (set via {@code + * withExperimentalHost}). + * + * @param certPath Path to the client certificate file. + * @param keyPath Path to the client certificate key file. + */ + public CreateTransaction withClientCert(String certPath, String keyPath) { + return withClientCert( + ValueProvider.StaticValueProvider.of(certPath), + ValueProvider.StaticValueProvider.of(keyPath)); + } + @VisibleForTesting CreateTransaction withServiceFactory(ServiceFactory serviceFactory) { SpannerConfig config = getSpannerConfig(); @@ -1412,6 +1502,35 @@ public Write withUsingPlainTextChannel(boolean plainText) { return withUsingPlainTextChannel(ValueProvider.StaticValueProvider.of(plainText)); } + /** + * Specifies certificate paths to use for mTLS channel. + * + *

Note: These parameters are only valid when using Spanner Omni (set via {@code + * withExperimentalHost}). + * + * @param certPath Path to the client certificate file. + * @param keyPath Path to the client certificate key file. + */ + public Write withClientCert(ValueProvider certPath, ValueProvider keyPath) { + SpannerConfig config = getSpannerConfig(); + return withSpannerConfig(config.withClientCert(certPath, keyPath)); + } + + /** + * Specifies certificate paths to use for mTLS channel. + * + *

Note: These parameters are only valid when using Spanner Omni (set via {@code + * withExperimentalHost}). + * + * @param certPath Path to the client certificate file. + * @param keyPath Path to the client certificate key file. + */ + public Write withClientCert(String certPath, String keyPath) { + return withClientCert( + ValueProvider.StaticValueProvider.of(certPath), + ValueProvider.StaticValueProvider.of(keyPath)); + } + public Write withDialectView(PCollectionView dialect) { return toBuilder().setDialectView(dialect).build(); } @@ -1770,6 +1889,10 @@ public abstract static class ReadChangeStream abstract @Nullable ValueProvider getPlainText(); + abstract @Nullable ValueProvider getClientCertPath(); + + abstract @Nullable ValueProvider getClientCertKeyPath(); + abstract Duration getRealTimeCheckpointInterval(); abstract int getHeartbeatMillis(); @@ -1807,6 +1930,10 @@ abstract static class Builder { abstract Builder setPlainText(ValueProvider plainText); + abstract Builder setClientCertPath(ValueProvider clientCertPath); + + abstract Builder setClientCertKeyPath(ValueProvider clientCertKeyPath); + /** * When caught up to real-time, checkpoint processing of change stream this often. This sets a * bound on latency of processing if a steady trickle of elements prevents the heartbeat @@ -1946,6 +2073,36 @@ public ReadChangeStream withUsingPlainTextChannel(boolean plainText) { return withUsingPlainTextChannel(ValueProvider.StaticValueProvider.of(plainText)); } + /** + * Specifies certificate paths to use for mTLS channel. + * + *

Note: These parameters are only valid when using Spanner Omni (set via {@code + * withExperimentalHost}). + * + * @param certPath Path to the client certificate file. + * @param keyPath Path to the client certificate key file. + */ + public ReadChangeStream withClientCert( + ValueProvider certPath, ValueProvider keyPath) { + SpannerConfig config = getSpannerConfig(); + return withSpannerConfig(config.withClientCert(certPath, keyPath)); + } + + /** + * Specifies certificate paths to use for mTLS channel. + * + *

Note: These parameters are only valid when using Spanner Omni (set via {@code + * withExperimentalHost}). + * + * @param certPath Path to the client certificate file. + * @param keyPath Path to the client certificate key file. + */ + public ReadChangeStream withClientCert(String certPath, String keyPath) { + return withClientCert( + ValueProvider.StaticValueProvider.of(certPath), + ValueProvider.StaticValueProvider.of(keyPath)); + } + /** * Configures low latency experiment for readChangeStream transform. Example usage: * @@ -2177,9 +2334,17 @@ SpannerConfig buildChangeStreamSpannerConfig() { static SpannerConfig buildSpannerConfigWithCredential( SpannerConfig spannerConfig, PipelineOptions pipelineOptions) { if (spannerConfig.getCredentials() == null && pipelineOptions != null) { - final Credentials credentials = pipelineOptions.as(GcpOptions.class).getGcpCredential(); - if (credentials != null) { - spannerConfig = spannerConfig.withCredentials(credentials); + boolean isExperimentalHostEmpty = + spannerConfig.getExperimentalHost() == null + || (spannerConfig.getExperimentalHost().isAccessible() + && Strings.isNullOrEmpty(spannerConfig.getExperimentalHost().get())); + if (isExperimentalHostEmpty) { + final Credentials credentials = pipelineOptions.as(GcpOptions.class).getGcpCredential(); + if (credentials != null) { + spannerConfig = spannerConfig.withCredentials(credentials); + } + } else { + spannerConfig = spannerConfig.withCredentials(NoCredentials.getInstance()); } } return spannerConfig; diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/SpannerTransformRegistrar.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/SpannerTransformRegistrar.java index 70908f982721..544aa2938d51 100644 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/SpannerTransformRegistrar.java +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/SpannerTransformRegistrar.java @@ -81,6 +81,8 @@ public abstract static class CrossLanguageConfiguration { @Nullable String emulatorHost; @Nullable String experimentalHost; @Nullable Boolean plainText; + @Nullable String clientCertPath; + @Nullable String clientCertKeyPath; public void setInstanceId(String instanceId) { this.instanceId = instanceId; @@ -110,6 +112,14 @@ public void setPlainText(@Nullable Boolean plainText) { this.plainText = plainText; } + public void setClientCertPath(@Nullable String clientCertPath) { + this.clientCertPath = clientCertPath; + } + + public void setClientCertKeyPath(@Nullable String clientCertKeyPath) { + this.clientCertKeyPath = clientCertKeyPath; + } + void checkMandatoryFields() { if (projectId.isEmpty()) { throw new IllegalArgumentException("projectId can't be empty"); @@ -120,6 +130,10 @@ void checkMandatoryFields() { if (instanceId.isEmpty()) { throw new IllegalArgumentException("instanceId can't be empty"); } + if ((clientCertPath != null) != (clientCertKeyPath != null)) { + throw new IllegalArgumentException( + "Both clientCertPath and clientCertKeyPath must be specified together."); + } } } @@ -249,6 +263,11 @@ public PTransform> buildExternal( if (configuration.plainText != null) { readTransform = readTransform.withUsingPlainTextChannel(configuration.plainText); } + if (configuration.clientCertPath != null && configuration.clientCertKeyPath != null) { + readTransform = + readTransform.withClientCert( + configuration.clientCertPath, configuration.clientCertKeyPath); + } @Nullable TimestampBound timestampBound = configuration.getTimestampBound(); if (timestampBound != null) { readTransform = readTransform.withTimestampBound(timestampBound); @@ -393,6 +412,11 @@ public PTransform, PDone> buildExternal( if (configuration.plainText != null) { writeTransform = writeTransform.withUsingPlainTextChannel(configuration.plainText); } + if (configuration.clientCertPath != null && configuration.clientCertKeyPath != null) { + writeTransform = + writeTransform.withClientCert( + configuration.clientCertPath, configuration.clientCertKeyPath); + } if (configuration.commitDeadline != null) { writeTransform = writeTransform.withCommitDeadline(configuration.commitDeadline); } @@ -504,6 +528,12 @@ public PTransform> buildExternal( readChangeStream = readChangeStream.withMetadataTable(configuration.metadataTable); } + if (configuration.clientCertPath != null && configuration.clientCertKeyPath != null) { + readChangeStream = + readChangeStream.withClientCert( + configuration.clientCertPath, configuration.clientCertKeyPath); + } + if (configuration.rpcPriority != null) { readChangeStream = readChangeStream.withRpcPriority(configuration.rpcPriority); diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/changestreams/MetadataSpannerConfigFactory.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/changestreams/MetadataSpannerConfigFactory.java index 959582e9c35f..de81814c2110 100644 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/changestreams/MetadataSpannerConfigFactory.java +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/changestreams/MetadataSpannerConfigFactory.java @@ -87,6 +87,12 @@ public static SpannerConfig create( config = config.withUsingPlainTextChannel(plainText.get()); } + ValueProvider clientCert = primaryConfig.getClientCertPath(); + ValueProvider clientKey = primaryConfig.getClientCertKeyPath(); + if (clientCert != null && clientKey != null) { + config = config.withClientCert(clientCert, clientKey); + } + ValueProvider isLocalChannelProvider = primaryConfig.getIsLocalChannelProvider(); if (isLocalChannelProvider != null) { config = diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/SpannerReadIT.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/SpannerReadIT.java index 34c839d3e1e6..405d1e9b1e15 100644 --- a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/SpannerReadIT.java +++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/SpannerReadIT.java @@ -35,7 +35,6 @@ import java.util.ArrayList; import java.util.Collections; import java.util.List; -import org.apache.beam.sdk.extensions.gcp.options.GcpOptions; import org.apache.beam.sdk.options.Default; import org.apache.beam.sdk.options.Description; import org.apache.beam.sdk.options.PipelineOptionsFactory; @@ -110,21 +109,19 @@ public void setUp() throws Exception { PipelineOptionsFactory.register(SpannerTestPipelineOptions.class); options = TestPipeline.testingPipelineOptions().as(SpannerTestPipelineOptions.class); - project = options.getInstanceProjectId(); - if (project == null) { - project = options.as(GcpOptions.class).getProject(); - } + project = SpannerTestHelper.getProject(options, options.getInstanceProjectId()); + options.setInstanceId(SpannerTestHelper.getInstanceId(options.getInstanceId())); - spanner = + SpannerOptions.Builder spannerBuilder = SpannerOptions.newBuilder() .setProjectId(project) .disableGrpcGcpExtension() .setSessionPoolOption( SessionPoolOptions.newBuilder() .setWaitForMinSessionsDuration(java.time.Duration.ofMinutes(5)) - .build()) - .build() - .getService(); + .build()); + spannerBuilder = SpannerTestHelper.setUpSpannerOptions(spannerBuilder); + spanner = spannerBuilder.build().getService(); databaseName = generateDatabaseName(); pgDatabaseName = "pg-" + databaseName; @@ -485,17 +482,19 @@ private void makeTestData() { } private SpannerConfig createSpannerConfig() { - return SpannerConfig.create() - .withProjectId(project) - .withInstanceId(options.getInstanceId()) - .withDatabaseId(databaseName); + return SpannerTestHelper.setUpSpannerConfig( + SpannerConfig.create() + .withProjectId(project) + .withInstanceId(options.getInstanceId()) + .withDatabaseId(databaseName)); } private SpannerConfig createPgSpannerConfig() { - return SpannerConfig.create() - .withProjectId(project) - .withInstanceId(options.getInstanceId()) - .withDatabaseId(pgDatabaseName); + return SpannerTestHelper.setUpSpannerConfig( + SpannerConfig.create() + .withProjectId(project) + .withInstanceId(options.getInstanceId()) + .withDatabaseId(pgDatabaseName)); } private DatabaseClient getDatabaseClient() { diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/SpannerTestHelper.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/SpannerTestHelper.java new file mode 100644 index 000000000000..55ec74f19eb6 --- /dev/null +++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/SpannerTestHelper.java @@ -0,0 +1,110 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.io.gcp.spanner; + +import com.google.cloud.spanner.SpannerOptions; +import org.apache.beam.sdk.extensions.gcp.options.GcpOptions; +import org.apache.beam.sdk.options.PipelineOptions; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Strings; + +/** Helper methods for Spanner IT tests to support Spanner Omni natively. */ +public class SpannerTestHelper { + + public static String getOmniEndpoint() { + return System.getenv("SPANNER_OMNI_ENDPOINT"); + } + + public static boolean isOmni() { + return !Strings.isNullOrEmpty(getOmniEndpoint()); + } + + public static String getOmniClientCert() { + return System.getenv("SPANNER_OMNI_CLIENT_CERT"); + } + + public static String getOmniClientKey() { + return System.getenv("SPANNER_OMNI_CLIENT_KEY"); + } + + public static boolean isOmniUsePlainText() { + return Boolean.parseBoolean(System.getenv("SPANNER_OMNI_USE_PLAIN_TEXT")); + } + + public static String getProject(PipelineOptions options, String instanceProjectId) { + if (isOmni()) { + return "default"; + } + String project = instanceProjectId; + if (project == null) { + project = options.as(GcpOptions.class).getProject(); + } + return project; + } + + public static String getInstanceId(String instanceId) { + if (isOmni()) { + return "default"; + } + return instanceId; + } + + public static SpannerOptions.Builder setUpSpannerOptions(SpannerOptions.Builder builder) { + if (isOmni()) { + builder.setExperimentalHost(getOmniEndpoint()); + if (isOmniUsePlainText()) { + builder.usePlainText(); + } + String cert = getOmniClientCert(); + String key = getOmniClientKey(); + if (!Strings.isNullOrEmpty(cert) && !Strings.isNullOrEmpty(key)) { + builder.useClientCert(cert, key); + } + } + return builder; + } + + public static SpannerConfig setUpSpannerConfig(SpannerConfig config) { + if (isOmni()) { + config = config.withExperimentalHost(getOmniEndpoint()); + if (isOmniUsePlainText()) { + config = config.withUsingPlainTextChannel(true); + } + String cert = getOmniClientCert(); + String key = getOmniClientKey(); + if (!Strings.isNullOrEmpty(cert) && !Strings.isNullOrEmpty(key)) { + config = config.withClientCert(cert, key); + } + } + return config; + } + + public static SpannerIO.Write setUpSpannerIO(SpannerIO.Write write) { + if (isOmni()) { + write = write.withExperimentalHost(getOmniEndpoint()); + if (isOmniUsePlainText()) { + write = write.withUsingPlainTextChannel(true); + } + String cert = getOmniClientCert(); + String key = getOmniClientKey(); + if (!Strings.isNullOrEmpty(cert) && !Strings.isNullOrEmpty(key)) { + write = write.withClientCert(cert, key); + } + } + return write; + } +} diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/SpannerWriteIT.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/SpannerWriteIT.java index df23435d82ab..e4ad156ef885 100644 --- a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/SpannerWriteIT.java +++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/SpannerWriteIT.java @@ -17,9 +17,11 @@ */ package org.apache.beam.sdk.io.gcp.spanner; +import static org.apache.beam.sdk.io.gcp.spanner.SpannerTestHelper.isOmni; import static org.hamcrest.MatcherAssert.assertThat; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.is; +import static org.junit.Assume.assumeFalse; import com.google.api.gax.longrunning.OperationFuture; import com.google.cloud.spanner.Database; @@ -37,7 +39,6 @@ import java.util.Collections; import java.util.Objects; import org.apache.beam.sdk.PipelineResult; -import org.apache.beam.sdk.extensions.gcp.options.GcpOptions; import org.apache.beam.sdk.io.GenerateSequence; import org.apache.beam.sdk.options.Default; import org.apache.beam.sdk.options.Description; @@ -117,21 +118,20 @@ public void setUp() throws Exception { PipelineOptionsFactory.register(SpannerTestPipelineOptions.class); options = TestPipeline.testingPipelineOptions().as(SpannerTestPipelineOptions.class); - project = options.getInstanceProjectId(); - if (project == null) { - project = options.as(GcpOptions.class).getProject(); - } + project = SpannerTestHelper.getProject(options, options.getInstanceProjectId()); + options.setInstanceId(SpannerTestHelper.getInstanceId(options.getInstanceId())); - spanner = + SpannerOptions.Builder spannerBuilder = SpannerOptions.newBuilder() .setProjectId(project) .disableGrpcGcpExtension() .setSessionPoolOption( SessionPoolOptions.newBuilder() .setWaitForMinSessionsDuration(java.time.Duration.ofMinutes(5)) - .build()) - .build() - .getService(); + .build()); + + spannerBuilder = SpannerTestHelper.setUpSpannerOptions(spannerBuilder); + spanner = spannerBuilder.build().getService(); databaseName = generateDatabaseName(); pgDatabaseName = "pg-" + databaseName; @@ -191,10 +191,11 @@ public void testWrite() throws Exception { .apply("Generate mu", ParDo.of(new GenerateMutations(options.getTable()))) .apply( "Write db", - SpannerIO.write() - .withProjectId(project) - .withInstanceId(options.getInstanceId()) - .withDatabaseId(databaseName)); + SpannerTestHelper.setUpSpannerIO( + SpannerIO.write() + .withProjectId(project) + .withInstanceId(options.getInstanceId()) + .withDatabaseId(databaseName))); PCollectionView dialectView = p.apply("PG Dialect", Create.of(Dialect.POSTGRESQL)) @@ -203,10 +204,11 @@ public void testWrite() throws Exception { .apply("Generate PG mu", ParDo.of(new GenerateMutations(options.getTable()))) .apply( "Write PG db", - SpannerIO.write() - .withProjectId(project) - .withInstanceId(options.getInstanceId()) - .withDatabaseId(pgDatabaseName) + SpannerTestHelper.setUpSpannerIO( + SpannerIO.write() + .withProjectId(project) + .withInstanceId(options.getInstanceId()) + .withDatabaseId(pgDatabaseName)) .withDialectView(dialectView)); PipelineResult result = p.run(); @@ -218,6 +220,8 @@ public void testWrite() throws Exception { @Test public void testWriteViaSchemaTransform() throws Exception { + assumeFalse( + "SchemaTransform tests do not support dynamic SpannerConfig overrides for Omni", isOmni()); int numRecords = 100; final Schema tableSchema = Schema.builder().addInt64Field("Key").addStringField("Value").build(); @@ -258,20 +262,22 @@ public void testSequentialWrite() throws Exception { .apply("Gen mutations1", ParDo.of(new GenerateMutations(options.getTable()))) .apply( "write to table1", - SpannerIO.write() - .withProjectId(project) - .withInstanceId(options.getInstanceId()) - .withDatabaseId(databaseName)); + SpannerTestHelper.setUpSpannerIO( + SpannerIO.write() + .withProjectId(project) + .withInstanceId(options.getInstanceId()) + .withDatabaseId(databaseName))); p.apply("second step", GenerateSequence.from(numRecords).to(2 * numRecords)) .apply("Gen mutations2", ParDo.of(new GenerateMutations(options.getTable()))) .apply("wait", Wait.on(stepOne.getOutput())) .apply( "write to table2", - SpannerIO.write() - .withProjectId(project) - .withInstanceId(options.getInstanceId()) - .withDatabaseId(databaseName)); + SpannerTestHelper.setUpSpannerIO( + SpannerIO.write() + .withProjectId(project) + .withInstanceId(options.getInstanceId()) + .withDatabaseId(databaseName))); PCollectionView dialectView = p.apply("PG Dialect", Create.of(Dialect.POSTGRESQL)) @@ -282,10 +288,11 @@ public void testSequentialWrite() throws Exception { .apply("Gen pg mutations1", ParDo.of(new GenerateMutations(options.getTable()))) .apply( "write to pg table1", - SpannerIO.write() - .withProjectId(project) - .withInstanceId(options.getInstanceId()) - .withDatabaseId(pgDatabaseName) + SpannerTestHelper.setUpSpannerIO( + SpannerIO.write() + .withProjectId(project) + .withInstanceId(options.getInstanceId()) + .withDatabaseId(pgDatabaseName)) .withDialectView(dialectView)); p.apply("pg second step", GenerateSequence.from(numRecords).to(2 * numRecords)) @@ -293,10 +300,11 @@ public void testSequentialWrite() throws Exception { .apply("pg wait", Wait.on(pgStepOne.getOutput())) .apply( "write to pg table2", - SpannerIO.write() - .withProjectId(project) - .withInstanceId(options.getInstanceId()) - .withDatabaseId(pgDatabaseName) + SpannerTestHelper.setUpSpannerIO( + SpannerIO.write() + .withProjectId(project) + .withInstanceId(options.getInstanceId()) + .withDatabaseId(pgDatabaseName)) .withDialectView(dialectView)); PipelineResult result = p.run(); @@ -313,10 +321,11 @@ public void testReportFailures() throws Exception { .apply("Generate mu", ParDo.of(new GenerateMutations(options.getTable(), new DivBy2()))) .apply( "Write db", - SpannerIO.write() - .withProjectId(project) - .withInstanceId(options.getInstanceId()) - .withDatabaseId(databaseName) + SpannerTestHelper.setUpSpannerIO( + SpannerIO.write() + .withProjectId(project) + .withInstanceId(options.getInstanceId()) + .withDatabaseId(databaseName)) .withFailureMode(SpannerIO.FailureMode.REPORT_FAILURES)); PCollectionView dialectView = @@ -326,10 +335,11 @@ public void testReportFailures() throws Exception { .apply("Generate pg mu", ParDo.of(new GenerateMutations(options.getTable(), new DivBy2()))) .apply( "Write pg db", - SpannerIO.write() - .withProjectId(project) - .withInstanceId(options.getInstanceId()) - .withDatabaseId(pgDatabaseName) + SpannerTestHelper.setUpSpannerIO( + SpannerIO.write() + .withProjectId(project) + .withInstanceId(options.getInstanceId()) + .withDatabaseId(pgDatabaseName)) .withFailureMode(SpannerIO.FailureMode.REPORT_FAILURES) .withDialectView(dialectView)); @@ -348,10 +358,11 @@ public void testFailFast() throws Exception { p.apply(GenerateSequence.from(0).to(2 * numRecords)) .apply(ParDo.of(new GenerateMutations(options.getTable(), new DivBy2()))) .apply( - SpannerIO.write() - .withProjectId(project) - .withInstanceId(options.getInstanceId()) - .withDatabaseId(databaseName)); + SpannerTestHelper.setUpSpannerIO( + SpannerIO.write() + .withProjectId(project) + .withInstanceId(options.getInstanceId()) + .withDatabaseId(databaseName))); PipelineResult result = p.run(); result.waitUntilFinish(); @@ -369,10 +380,11 @@ public void testPgFailFast() throws Exception { p.apply(GenerateSequence.from(0).to(2 * numRecords)) .apply(ParDo.of(new GenerateMutations(options.getTable(), new DivBy2()))) .apply( - SpannerIO.write() - .withProjectId(project) - .withInstanceId(options.getInstanceId()) - .withDatabaseId(pgDatabaseName) + SpannerTestHelper.setUpSpannerIO( + SpannerIO.write() + .withProjectId(project) + .withInstanceId(options.getInstanceId()) + .withDatabaseId(pgDatabaseName)) .withDialectView(dialectView)); PipelineResult result = p.run(); diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/changestreams/it/IntegrationTestEnv.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/changestreams/it/IntegrationTestEnv.java index b36004a5cd15..b00568f8d40c 100644 --- a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/changestreams/it/IntegrationTestEnv.java +++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/changestreams/it/IntegrationTestEnv.java @@ -33,8 +33,8 @@ import java.util.concurrent.ExecutionException; import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeoutException; -import org.apache.beam.sdk.extensions.gcp.options.GcpOptions; import org.apache.beam.sdk.io.common.IOITHelper; +import org.apache.beam.sdk.io.gcp.spanner.SpannerTestHelper; import org.apache.commons.lang3.RandomStringUtils; import org.junit.rules.ExternalResource; import org.slf4j.Logger; @@ -73,12 +73,11 @@ protected void before() throws Throwable { final ChangeStreamTestPipelineOptions options = IOITHelper.readIOTestPipelineOptions(ChangeStreamTestPipelineOptions.class); - projectId = - Optional.ofNullable(options.getProjectId()) - .orElseGet(() -> options.as(GcpOptions.class).getProject()); - instanceId = options.getInstanceId(); + projectId = SpannerTestHelper.getProject(options, options.getProjectId()); + instanceId = SpannerTestHelper.getInstanceId(options.getInstanceId()); generateDatabaseIds(options); - spanner = + + SpannerOptions.Builder spannerBuilder = SpannerOptions.newBuilder() .setProjectId(projectId) .setHost(host) @@ -86,9 +85,9 @@ protected void before() throws Throwable { .setSessionPoolOption( SessionPoolOptions.newBuilder() .setWaitForMinSessionsDuration(java.time.Duration.ofMinutes(5)) - .build()) - .build() - .getService(); + .build()); + spannerBuilder = SpannerTestHelper.setUpSpannerOptions(spannerBuilder); + spanner = spannerBuilder.build().getService(); databaseAdminClient = spanner.getDatabaseAdminClient(); metadataTableName = generateTableName(METADATA_TABLE_NAME_PREFIX); diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/changestreams/it/SpannerChangeStreamIT.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/changestreams/it/SpannerChangeStreamIT.java index e6178cbf5402..f66e88ddbeb0 100644 --- a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/changestreams/it/SpannerChangeStreamIT.java +++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/changestreams/it/SpannerChangeStreamIT.java @@ -42,6 +42,7 @@ import org.apache.beam.runners.direct.DirectRunner; import org.apache.beam.sdk.io.gcp.spanner.SpannerConfig; import org.apache.beam.sdk.io.gcp.spanner.SpannerIO; +import org.apache.beam.sdk.io.gcp.spanner.SpannerTestHelper; import org.apache.beam.sdk.io.gcp.spanner.changestreams.model.DataChangeRecord; import org.apache.beam.sdk.io.gcp.spanner.changestreams.model.Mod; import org.apache.beam.sdk.options.ValueProvider.StaticValueProvider; @@ -137,10 +138,11 @@ public void testReadSpannerChangeStreamImpl(TestPipeline testPipeline, String ro final Timestamp endAt = deleteTimestamps.getRight(); SpannerConfig spannerConfig = - SpannerConfig.create() - .withProjectId(projectId) - .withInstanceId(instanceId) - .withDatabaseId(databaseId); + SpannerTestHelper.setUpSpannerConfig( + SpannerConfig.create() + .withProjectId(projectId) + .withInstanceId(instanceId) + .withDatabaseId(databaseId)); if (role != null) { spannerConfig = spannerConfig.withDatabaseRole(StaticValueProvider.of(role)); } @@ -197,10 +199,11 @@ public void testReadSpannerChangeStreamFilteredByTransactionTag() { final Timestamp endAt = deleteTimestamps.getRight(); final SpannerConfig spannerConfig = - SpannerConfig.create() - .withProjectId(projectId) - .withInstanceId(instanceId) - .withDatabaseId(databaseId); + SpannerTestHelper.setUpSpannerConfig( + SpannerConfig.create() + .withProjectId(projectId) + .withInstanceId(instanceId) + .withDatabaseId(databaseId)); // Filter records to only those from transactions with tag "app=beam;action=update" final PCollection tokens = diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/changestreams/it/SpannerChangeStreamOrderedByTimestampAndTransactionIdIT.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/changestreams/it/SpannerChangeStreamOrderedByTimestampAndTransactionIdIT.java index 04c09a2e12ce..c4708e11a618 100644 --- a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/changestreams/it/SpannerChangeStreamOrderedByTimestampAndTransactionIdIT.java +++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/changestreams/it/SpannerChangeStreamOrderedByTimestampAndTransactionIdIT.java @@ -37,6 +37,7 @@ import org.apache.beam.sdk.coders.BooleanCoder; import org.apache.beam.sdk.io.gcp.spanner.SpannerConfig; import org.apache.beam.sdk.io.gcp.spanner.SpannerIO; +import org.apache.beam.sdk.io.gcp.spanner.SpannerTestHelper; import org.apache.beam.sdk.io.gcp.spanner.changestreams.model.DataChangeRecord; import org.apache.beam.sdk.io.gcp.spanner.changestreams.model.Mod; import org.apache.beam.sdk.state.BagState; @@ -102,10 +103,11 @@ public static void setup() throws InterruptedException, ExecutionException, Time @Test public void testTransactionBoundaries() { final SpannerConfig spannerConfig = - SpannerConfig.create() - .withProjectId(projectId) - .withInstanceId(instanceId) - .withDatabaseId(databaseId); + SpannerTestHelper.setUpSpannerConfig( + SpannerConfig.create() + .withProjectId(projectId) + .withInstanceId(instanceId) + .withDatabaseId(databaseId)); // Commit a initial transaction to get the timestamp to start reading from. List mutations = new ArrayList<>(); mutations.add(insertRecordMutation(0, "FirstName0", "LastName0")); diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/changestreams/it/SpannerChangeStreamOrderedWithinKeyGloballyIT.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/changestreams/it/SpannerChangeStreamOrderedWithinKeyGloballyIT.java index 513e5aeb2d76..4ecb721aa74a 100644 --- a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/changestreams/it/SpannerChangeStreamOrderedWithinKeyGloballyIT.java +++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/changestreams/it/SpannerChangeStreamOrderedWithinKeyGloballyIT.java @@ -34,6 +34,7 @@ import org.apache.beam.sdk.coders.StringUtf8Coder; import org.apache.beam.sdk.io.gcp.spanner.SpannerConfig; import org.apache.beam.sdk.io.gcp.spanner.SpannerIO; +import org.apache.beam.sdk.io.gcp.spanner.SpannerTestHelper; import org.apache.beam.sdk.io.gcp.spanner.changestreams.model.DataChangeRecord; import org.apache.beam.sdk.state.BagState; import org.apache.beam.sdk.state.StateSpec; @@ -93,10 +94,11 @@ public static void setup() throws InterruptedException, ExecutionException, Time @Test public void testOrderedWithinKey() { final SpannerConfig spannerConfig = - SpannerConfig.create() - .withProjectId(projectId) - .withInstanceId(instanceId) - .withDatabaseId(databaseId); + SpannerTestHelper.setUpSpannerConfig( + SpannerConfig.create() + .withProjectId(projectId) + .withInstanceId(instanceId) + .withDatabaseId(databaseId)); // Get the time increment interval at which to flush data changes ordered by key. final long timeIncrementInSeconds = 10; diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/changestreams/it/SpannerChangeStreamOrderedWithinKeyIT.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/changestreams/it/SpannerChangeStreamOrderedWithinKeyIT.java index e1731099204d..30d748ab1c79 100644 --- a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/changestreams/it/SpannerChangeStreamOrderedWithinKeyIT.java +++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/changestreams/it/SpannerChangeStreamOrderedWithinKeyIT.java @@ -33,6 +33,7 @@ import java.util.stream.StreamSupport; import org.apache.beam.sdk.io.gcp.spanner.SpannerConfig; import org.apache.beam.sdk.io.gcp.spanner.SpannerIO; +import org.apache.beam.sdk.io.gcp.spanner.SpannerTestHelper; import org.apache.beam.sdk.io.gcp.spanner.changestreams.model.ChangeStreamRecordMetadata; import org.apache.beam.sdk.io.gcp.spanner.changestreams.model.DataChangeRecord; import org.apache.beam.sdk.testing.PAssert; @@ -88,10 +89,11 @@ public static void setup() throws InterruptedException, ExecutionException, Time public void testOrderedWithinKey() { LOG.info("Test pipeline: {}", pipeline); final SpannerConfig spannerConfig = - SpannerConfig.create() - .withProjectId(projectId) - .withInstanceId(instanceId) - .withDatabaseId(databaseId); + SpannerTestHelper.setUpSpannerConfig( + SpannerConfig.create() + .withProjectId(projectId) + .withInstanceId(instanceId) + .withDatabaseId(databaseId)); // Commit a initial transaction to get the timestamp to start reading from. List mutations = new ArrayList<>(); diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/changestreams/it/SpannerChangeStreamPlacementTableIT.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/changestreams/it/SpannerChangeStreamPlacementTableIT.java index 9318dad7ec6d..0b5831f05e1d 100644 --- a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/changestreams/it/SpannerChangeStreamPlacementTableIT.java +++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/changestreams/it/SpannerChangeStreamPlacementTableIT.java @@ -43,6 +43,7 @@ import org.apache.beam.runners.direct.DirectRunner; import org.apache.beam.sdk.io.gcp.spanner.SpannerConfig; import org.apache.beam.sdk.io.gcp.spanner.SpannerIO; +import org.apache.beam.sdk.io.gcp.spanner.SpannerTestHelper; import org.apache.beam.sdk.io.gcp.spanner.changestreams.model.DataChangeRecord; import org.apache.beam.sdk.io.gcp.spanner.changestreams.model.Mod; import org.apache.beam.sdk.options.ValueProvider.StaticValueProvider; @@ -149,10 +150,11 @@ public void testReadSpannerChangeStreamImpl( final Timestamp endAt = deleteTimestamps.getRight(); SpannerConfig spannerConfig = - SpannerConfig.create() - .withProjectId(projectId) - .withInstanceId(instanceId) - .withDatabaseId(databaseId); + SpannerTestHelper.setUpSpannerConfig( + SpannerConfig.create() + .withProjectId(projectId) + .withInstanceId(instanceId) + .withDatabaseId(databaseId)); if (role != null) { spannerConfig = spannerConfig.withDatabaseRole(StaticValueProvider.of(role)); } @@ -213,10 +215,11 @@ public void testReadSpannerChangeStreamFilteredByTransactionTag() { final Timestamp endAt = deleteTimestamps.getRight(); final SpannerConfig spannerConfig = - SpannerConfig.create() - .withProjectId(projectId) - .withInstanceId(instanceId) - .withDatabaseId(databaseId); + SpannerTestHelper.setUpSpannerConfig( + SpannerConfig.create() + .withProjectId(projectId) + .withInstanceId(instanceId) + .withDatabaseId(databaseId)); // Filter records to only those from transactions with tag "app=beam;action=update" final PCollection tokens = diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/changestreams/it/SpannerChangeStreamPlacementTablePostgresIT.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/changestreams/it/SpannerChangeStreamPlacementTablePostgresIT.java index 129a4334d1bb..f017e72eb0ae 100644 --- a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/changestreams/it/SpannerChangeStreamPlacementTablePostgresIT.java +++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/changestreams/it/SpannerChangeStreamPlacementTablePostgresIT.java @@ -37,6 +37,7 @@ import java.util.Optional; import org.apache.beam.sdk.io.gcp.spanner.SpannerConfig; import org.apache.beam.sdk.io.gcp.spanner.SpannerIO; +import org.apache.beam.sdk.io.gcp.spanner.SpannerTestHelper; import org.apache.beam.sdk.io.gcp.spanner.changestreams.model.DataChangeRecord; import org.apache.beam.sdk.io.gcp.spanner.changestreams.model.Mod; import org.apache.beam.sdk.options.ValueProvider; @@ -125,10 +126,11 @@ private void testReadSpannerChangeStreamImpl(List tvfNameList) { final Timestamp endAt = deleteTimestamps.getRight(); final SpannerConfig spannerConfig = - SpannerConfig.create() - .withProjectId(projectId) - .withInstanceId(instanceId) - .withDatabaseId(databaseId) + SpannerTestHelper.setUpSpannerConfig( + SpannerConfig.create() + .withProjectId(projectId) + .withInstanceId(instanceId) + .withDatabaseId(databaseId)) .withHost(ValueProvider.StaticValueProvider.of(host)); SpannerIO.ReadChangeStream readChangeStream = diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/changestreams/it/SpannerChangeStreamPostgresIT.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/changestreams/it/SpannerChangeStreamPostgresIT.java index 5f5f55e46964..99f5269e14fd 100644 --- a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/changestreams/it/SpannerChangeStreamPostgresIT.java +++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/changestreams/it/SpannerChangeStreamPostgresIT.java @@ -36,6 +36,7 @@ import java.util.Optional; import org.apache.beam.sdk.io.gcp.spanner.SpannerConfig; import org.apache.beam.sdk.io.gcp.spanner.SpannerIO; +import org.apache.beam.sdk.io.gcp.spanner.SpannerTestHelper; import org.apache.beam.sdk.io.gcp.spanner.changestreams.model.DataChangeRecord; import org.apache.beam.sdk.io.gcp.spanner.changestreams.model.Mod; import org.apache.beam.sdk.options.ValueProvider; @@ -114,10 +115,11 @@ public void testReadSpannerChangeStream() { final Timestamp endAt = deleteTimestamps.getRight(); final SpannerConfig spannerConfig = - SpannerConfig.create() - .withProjectId(projectId) - .withInstanceId(instanceId) - .withDatabaseId(databaseId) + SpannerTestHelper.setUpSpannerConfig( + SpannerConfig.create() + .withProjectId(projectId) + .withInstanceId(instanceId) + .withDatabaseId(databaseId)) .withHost(ValueProvider.StaticValueProvider.of(host)); final PCollection tokens = diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/changestreams/it/SpannerChangeStreamTransactionBoundariesIT.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/changestreams/it/SpannerChangeStreamTransactionBoundariesIT.java index 12e1caa76428..fe0228193815 100644 --- a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/changestreams/it/SpannerChangeStreamTransactionBoundariesIT.java +++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/changestreams/it/SpannerChangeStreamTransactionBoundariesIT.java @@ -34,6 +34,7 @@ import org.apache.beam.sdk.PipelineResult; import org.apache.beam.sdk.io.gcp.spanner.SpannerConfig; import org.apache.beam.sdk.io.gcp.spanner.SpannerIO; +import org.apache.beam.sdk.io.gcp.spanner.SpannerTestHelper; import org.apache.beam.sdk.io.gcp.spanner.changestreams.model.DataChangeRecord; import org.apache.beam.sdk.io.gcp.spanner.changestreams.model.Mod; import org.apache.beam.sdk.state.BagState; @@ -90,10 +91,11 @@ public static void setup() throws InterruptedException, ExecutionException, Time public void testTransactionBoundaries() { LOG.info("Test pipeline: {}", pipeline); final SpannerConfig spannerConfig = - SpannerConfig.create() - .withProjectId(projectId) - .withInstanceId(instanceId) - .withDatabaseId(databaseId); + SpannerTestHelper.setUpSpannerConfig( + SpannerConfig.create() + .withProjectId(projectId) + .withInstanceId(instanceId) + .withDatabaseId(databaseId)); // Commit a initial transaction to get the timestamp to start reading from. List mutations = new ArrayList<>(); diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/changestreams/it/SpannerChangeStreamsSchemaTransformIT.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/changestreams/it/SpannerChangeStreamsSchemaTransformIT.java index 56d964087128..606af9cfc996 100644 --- a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/changestreams/it/SpannerChangeStreamsSchemaTransformIT.java +++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/changestreams/it/SpannerChangeStreamsSchemaTransformIT.java @@ -17,9 +17,11 @@ */ package org.apache.beam.sdk.io.gcp.spanner.changestreams.it; +import static org.apache.beam.sdk.io.gcp.spanner.SpannerTestHelper.isOmni; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; +import static org.junit.Assume.assumeFalse; import com.google.cloud.Timestamp; import com.google.cloud.spanner.DatabaseClient; @@ -89,6 +91,8 @@ public void before() { @Test public void testReadSpannerChangeStream() { + assumeFalse( + "SchemaTransform tests do not support dynamic SpannerConfig overrides for Omni", isOmni()); // Defines how many rows are going to be inserted / updated / deleted in the test final int numRows = 5; // Inserts numRows rows and uses the first commit timestamp as the startAt for reading the