From ef37e9bce9d96ec32a7d620ce6a539cf0b23ff66 Mon Sep 17 00:00:00 2001
From: Andrei Gurau
Date: Wed, 16 Nov 2022 13:52:52 -0500
Subject: [PATCH] add rootCaCertificate option to SplunkIO

fix test

Add error reporting for BatchConverter match failure (#24022)

* add error reporting for BatchConverters
* Test pytorch
* Finish up torch tests
* yapf
* yapf
* Remove else

Update automation to use Go 1.19 (#24175)

Co-authored-by: lostluck <13907733+lostluck@users.noreply.github.com>

Fix broken json for notebook (#24183)

Using Teardown context instead of deprecated finalize (#24180)

* Using Teardown context instead of deprecated finalize
* making function public

Co-authored-by: Scott Strong

[Python] Support pipe operator as Union (PEP 604) (#24106)

Fixes https://github.com/apache/beam/issues/21972

Add custom inference function support to the PyTorch model handler (#24062)

* Initial type def and function signature
* [Draft] Add custom inference fn support to Pytorch Model Handler
* Formatting
* Split out default
* Remove Keyed version for testing
* Move device optimization
* Make default available for import, add to test classes
* Remove incorrect default from keyed test
* Keyed impl
* Fix device arg
* custom inference test
* formatting
* Add helpers to define custom inference functions using model methods
* Trailing whitespace
* Unit tests
* Fix incorrect getattr syntax
* Type typo
* Fix docstring
* Fix keyed helper, add basic generate route
* Modify generate() to be different than forward()
* formatting
* Remove extra generate() def

Strip FGAC database role from changestreams metadata requests (#24177)

Co-authored-by: Doug Judd

Updated README of Interactive Beam

Removed deprecated cache_dir runner param in favor of the cache_root global option.

Minor update

Fix arguments to checkState in BatchViewOverrides

Re-use serializable pipeline options when already available (#24192)

Fix Python PostCommit Example CustomPTransformIT on portable (#24159)

* Fix Python PostCommit Examples on portable
* Fix custom_ptransform pipeline options gets modified
* Specify flinkConfDir

revert upgrade to go 1.19 for action unit tests (#24189)

Use only ValueProviders in SpannerConfig (#24156)

[Tour of Beam] [Frontend] Content tree URLs (#23776)

* Content tree navigation (#23593)
  Unit content navigation (#23593)
  Update URL on node click (#23593)
  Active unit color (#23593)
  removeListener in unit (#23593)
  First unit is opened on group title click (#23593)
  WIP by Alexey Inkin (#23593)
  selectedUnitColor (#23593)
  Unit borderRadius (#23593)
  RegExp todo (#23593)
  added referenced collection package to remove warning (#23593)
  small refinement (#23593)
  expand on group tap, padding, openNode (#23593)
  group expansion bug fix (#23593)
  selected & unselected progress indicators (#23593)
* AnimatedBuilders instead of StatefulWidgets in unit & group (#23593)
* fixed _getNodeAncestors (#23593)
* get sdkId (#23593)
* addressing comments (#23593)
* sdkId getter & StatelessExpansionTile (#23593)
* expand & collapse group (#23593)
* StatelessExpansionTile (#23593)
* license (#23593)
* ValueChanged and ValueKey in StatelessExpansionTile (#23593)

Co-authored-by: darkhan.nausharipov
Co-authored-by: Alexey Inkin

refs: issue-24196, fix broken hyperlink

Add a reference to Java RunInference example

Python TextIO Performance Test (#23951)

* Python TextIO Performance Test
* Add filebasedio_perf_test module for unified test framework for Python file-based IOs
* Fix MetricsReader publishing metrics in duplicate if more than one load test is declared.
  This is because MetricsReader.publishers was a static class variable
* Fix pylint
* Distribute Python performance tests at a random time of day instead of all at 3 PM
* Add information about length conversion

Fix PythonLint (#24219)

Bump loader-utils from 1.4.1 to 1.4.2 in /sdks/python/apache_beam/runners/interactive/extensions/apache-beam-jupyterlab-sidepanel (#24191)

Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>

Temporary update Python RC validation job

updates

updates

Uses _all to follow alias/datastreams when estimating index size

Fixes #24117

Adds test for following aliases when estimating index size

Bump github.com/aws/aws-sdk-go-v2/config from 1.18.0 to 1.18.1 in /sdks (#24222)

Bumps [github.com/aws/aws-sdk-go-v2/config](https://github.com/aws/aws-sdk-go-v2) from 1.18.0 to 1.18.1.
- [Release notes](https://github.com/aws/aws-sdk-go-v2/releases)
- [Changelog](https://github.com/aws/aws-sdk-go-v2/blob/main/CHANGELOG.md)
- [Commits](https://github.com/aws/aws-sdk-go-v2/compare/config/v1.18.0...config/v1.18.1)

---
updated-dependencies:
- dependency-name: github.com/aws/aws-sdk-go-v2/config
  dependency-type: direct:production
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot]
Signed-off-by: dependabot[bot]
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>

Add enableGzipHttpCompression option to SplunkIO (#24197)

add enableBatchLogs as SplunkIO option

spotless

fix issue with setEnableBatchLogs
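For the PEP 604 change above, the effect is that Beam's type-hint machinery accepts `int | str`-style annotations the same way it already accepts `typing.Union[int, str]`. A minimal sketch (requires Python 3.10+ and a Beam build containing this change; it is illustrative and not taken from the PR's tests):

```python
import apache_beam as beam


def parse(value: str) -> int | float:
  """Return an int when the string is a whole number, otherwise a float."""
  number = float(value)
  return int(number) if number.is_integer() else number


with beam.Pipeline() as p:
  _ = (
      p
      | beam.Create(['1', '2.5'])
      # The pipe-operator union is treated like typing.Union[int, float].
      | beam.Map(parse).with_output_types(int | float)
      | beam.Map(print))
```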
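For the PyTorch custom-inference change, the idea is that the model handler can route batches through something other than the model's `forward()`, for example a `generate()` method. The sketch below assumes the handler accepts an `inference_fn` keyword and that `make_tensor_model_fn` (added by this change) builds such a function from a model method name; the model class and the state-dict path are hypothetical placeholders:

```python
import apache_beam as beam
import torch
from apache_beam.ml.inference.base import RunInference
from apache_beam.ml.inference.pytorch_inference import (
    PytorchModelHandlerTensor, make_tensor_model_fn)


class TinyModel(torch.nn.Module):
  """Toy model with an inference method other than forward()."""

  def __init__(self):
    super().__init__()
    self.linear = torch.nn.Linear(1, 1)

  def forward(self, x):
    return self.linear(x)

  def generate(self, x):
    return self.linear(x) + 1


# Route inference through generate() instead of the default forward().
model_handler = PytorchModelHandlerTensor(
    state_dict_path='gs://my-bucket/tiny_model.pt',  # hypothetical path
    model_class=TinyModel,
    inference_fn=make_tensor_model_fn('generate'))

with beam.Pipeline() as p:
  _ = (
      p
      | beam.Create([torch.tensor([1.0])])
      | RunInference(model_handler))
```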
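The MetricsReader fix above comes down to a common Python pitfall: because `publishers` was a class-level attribute, every load test declared in the same process appended to, and published through, one shared list. A generic illustration of the problem and the instance-attribute fix (not Beam's actual code):

```python
class Reader:
  publishers = []  # class attribute: one list shared by every instance

  def __init__(self, publisher):
    # Appending mutates the shared class-level list.
    self.publishers.append(publisher)


class FixedReader:
  def __init__(self, publisher):
    # Instance attribute: each reader gets its own list.
    self.publishers = [publisher]


first = Reader('console')
second = Reader('influxdb')
print(len(first.publishers), len(second.publishers))  # 2 2 -- each now publishes twice
print(len(FixedReader('console').publishers))         # 1
```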
---
 .../test-properties.json | 2 +-
 .github/workflows/build_playground_backend.yml | 2 +-
 .github/workflows/go_tests.yml | 2 +-
 .github/workflows/local_env_tests.yml | 4 +-
 ..._PerformanceTests_BigQueryIO_Python.groovy | 4 +-
 ...PerformanceTests_FileBasedIO_Python.groovy | 81 ++++++++
 ...ob_PerformanceTests_PubsubIO_Python.groovy | 2 +-
 ...b_PerformanceTests_SpannerIO_Python.groovy | 4 +-
 ...ob_PostCommit_Python_Examples_Flink.groovy | 1 +
 .../Python_IO_IT_Tests_Dataflow.json | 122 ++++++++++++
 .../beam/gradle/BeamModulePlugin.groovy | 2 +-
 .../beam-ml/dataframe_api_preprocessing.ipynb | 4 +-
 .../frontend/lib/models/content_tree.dart | 20 +-
 .../frontend/lib/models/group.dart | 29 ++-
 .../frontend/lib/models/module.dart | 29 ++-
 .../frontend/lib/models/node.dart | 20 +-
 .../frontend/lib/models/parent_node.dart | 18 ++
 .../frontend/lib/models/unit.dart | 15 +-
 .../pages/tour/controllers/content_tree.dart | 71 ++++++-
 .../frontend/lib/pages/tour/path.dart | 9 +-
 .../frontend/lib/pages/tour/state.dart | 2 +
 .../lib/pages/tour/widgets/group.dart | 40 ++--
 .../lib/pages/tour/widgets/group_title.dart | 5 +-
 .../lib/pages/tour/widgets/module.dart | 2 +-
 .../widgets/stateless_expansion_tile.dart} | 32 ++-
 .../tour/widgets/tour_progress_indicator.dart | 16 +-
 .../frontend/lib/pages/tour/widgets/unit.dart | 38 +++-
 .../frontend/lib/pages/welcome/screen.dart | 3 -
 learning/tour-of-beam/frontend/pubspec.lock | 4 +-
 learning/tour-of-beam/frontend/pubspec.yaml | 1 +
 .../lib/src/constants/colors.dart | 46 +++--
 .../lib/src/theme/theme.dart | 26 +++
 release/go-licenses/Dockerfile | 2 +-
 .../python_release_automation_utils.sh | 8 +-
 .../types/CoderTypeInformation.java | 6 +-
 .../streaming/KvToByteBufferKeySelector.java | 2 +-
 .../streaming/SdfByteBufferKeySelector.java | 2 +-
 .../streaming/WorkItemKeySelector.java | 2 +-
 .../streaming/io/UnboundedSourceWrapper.java | 3 +-
 .../state/FlinkBroadcastStateInternals.java | 17 +-
 .../runners/dataflow/BatchViewOverrides.java | 1 +
 sdks/go.mod | 14 +-
 sdks/go.sum
| 13 +- sdks/go/run_with_go_version.sh | 2 +- .../io/elasticsearch/ElasticsearchIOTest.java | 8 + .../io/elasticsearch/ElasticsearchIOTest.java | 8 + .../io/elasticsearch/ElasticsearchIOTest.java | 8 + .../io/elasticsearch/ElasticsearchIOTest.java | 8 + .../ElasticsearchIOTestUtils.java | 35 +++- .../sdk/io/elasticsearch/ElasticsearchIO.java | 17 +- .../sdk/io/gcp/spanner/SpannerAccessor.java | 6 +- .../sdk/io/gcp/spanner/SpannerConfig.java | 6 +- .../spanner/changestreams/dao/DaoFactory.java | 2 +- .../io/gcp/spanner/SpannerAccessorTest.java | 4 +- .../org/apache/beam/sdk/io/jdbc/JdbcIO.java | 4 +- .../sdk/io/splunk/CustomX509TrustManager.java | 84 ++++++++ .../sdk/io/splunk/HttpEventPublisher.java | 53 ++++- .../beam/sdk/io/splunk/SplunkEventWriter.java | 66 +++++- .../apache/beam/sdk/io/splunk/SplunkIO.java | 42 +++- .../io/splunk/CustomX509TrustManagerTest.java | 80 ++++++++ .../sdk/io/splunk/HttpEventPublisherTest.java | 111 ++++++++++- .../resources/SplunkTestCerts/PrivateKey.pem | 28 +++ .../SplunkTestCerts/RecognizedCertificate.crt | 26 +++ .../test/resources/SplunkTestCerts/RootCA.crt | 31 +++ .../resources/SplunkTestCerts/RootCA_2.crt | 13 ++ .../SplunkTestCerts/RootCA_PrivateKey.pem | 52 +++++ .../UnrecognizedCertificate.crt | 21 ++ .../examples/cookbook/custom_ptransform.py | 9 +- .../apache_beam/io/filebasedio_perf_test.py | 188 ++++++++++++++++++ .../ml/inference/pytorch_inference.py | 156 ++++++++++++--- .../ml/inference/pytorch_inference_test.py | 131 +++++++++++- .../apache_beam/runners/interactive/README.md | 21 +- .../yarn.lock | 6 +- .../testing/load_tests/load_test.py | 4 +- .../load_tests/load_test_metrics_utils.py | 16 +- .../apache_beam/testing/synthetic_pipeline.py | 2 +- .../typehints/arrow_type_compatibility.py | 33 ++- .../arrow_type_compatibility_test.py | 24 +++ sdks/python/apache_beam/typehints/batch.py | 83 ++++---- .../apache_beam/typehints/batch_test.py | 32 +++ .../typehints/native_type_compatibility.py | 9 + .../typehints/pandas_type_compatibility.py | 15 +- .../pandas_type_compatibility_test.py | 26 ++- .../typehints/pytorch_type_compatibility.py | 30 +-- .../pytorch_type_compatibility_test.py | 28 +++ .../python/apache_beam/typehints/typehints.py | 5 + .../apache_beam/typehints/typehints_test.py | 10 + .../python/test-suites/portable/common.gradle | 8 + .../content/en/blog/splitAtFraction-method.md | 2 +- .../site/content/en/blog/splittable-do-fn.md | 4 +- .../en/documentation/runners/dataflow.md | 2 +- .../sdks/python-machine-learning.md | 2 +- 92 files changed, 1975 insertions(+), 312 deletions(-) create mode 100644 .test-infra/jenkins/job_PerformanceTests_FileBasedIO_Python.groovy rename learning/tour-of-beam/frontend/lib/{components/filler_text.dart => pages/tour/widgets/stateless_expansion_tile.dart} (52%) create mode 100644 sdks/java/io/splunk/src/main/java/org/apache/beam/sdk/io/splunk/CustomX509TrustManager.java create mode 100644 sdks/java/io/splunk/src/test/java/org/apache/beam/sdk/io/splunk/CustomX509TrustManagerTest.java create mode 100644 sdks/java/io/splunk/src/test/resources/SplunkTestCerts/PrivateKey.pem create mode 100644 sdks/java/io/splunk/src/test/resources/SplunkTestCerts/RecognizedCertificate.crt create mode 100644 sdks/java/io/splunk/src/test/resources/SplunkTestCerts/RootCA.crt create mode 100644 sdks/java/io/splunk/src/test/resources/SplunkTestCerts/RootCA_2.crt create mode 100644 sdks/java/io/splunk/src/test/resources/SplunkTestCerts/RootCA_PrivateKey.pem create mode 100644 
sdks/java/io/splunk/src/test/resources/SplunkTestCerts/UnrecognizedCertificate.crt create mode 100644 sdks/python/apache_beam/io/filebasedio_perf_test.py diff --git a/.github/actions/setup-default-test-properties/test-properties.json b/.github/actions/setup-default-test-properties/test-properties.json index 1df5f8dc48a8..f2e1dce9922d 100644 --- a/.github/actions/setup-default-test-properties/test-properties.json +++ b/.github/actions/setup-default-test-properties/test-properties.json @@ -18,6 +18,6 @@ "SPARK_VERSIONS": ["2", "3"] }, "GoTestProperties": { - "SUPPORTED_VERSIONS": ["1.18"] + "SUPPORTED_VERSIONS": ["1.19"] } } diff --git a/.github/workflows/build_playground_backend.yml b/.github/workflows/build_playground_backend.yml index fad16b3cd5fa..f9c161c934df 100644 --- a/.github/workflows/build_playground_backend.yml +++ b/.github/workflows/build_playground_backend.yml @@ -34,7 +34,7 @@ jobs: name: Build Playground Backend App runs-on: ubuntu-latest env: - GO_VERSION: 1.18.0 + GO_VERSION: 1.19.3 BEAM_VERSION: 2.40.0 TERRAFORM_VERSION: 1.0.9 STAND_SUFFIX: '' diff --git a/.github/workflows/go_tests.yml b/.github/workflows/go_tests.yml index 8ee8110ad417..3233be40e401 100644 --- a/.github/workflows/go_tests.yml +++ b/.github/workflows/go_tests.yml @@ -74,4 +74,4 @@ jobs: echo -e "Please address Staticcheck warnings before checking in changes\n" echo -e "Staticcheck Warnings:\n" echo -e "$RESULTS" && exit 1 - fi \ No newline at end of file + fi diff --git a/.github/workflows/local_env_tests.yml b/.github/workflows/local_env_tests.yml index 1c120f809df0..9c0be264819c 100644 --- a/.github/workflows/local_env_tests.yml +++ b/.github/workflows/local_env_tests.yml @@ -42,7 +42,7 @@ jobs: - uses: actions/checkout@v3 - uses: actions/setup-go@v3 with: - go-version: '1.18' + go-version: '1.19' - name: "Installing local env dependencies" run: "sudo ./local-env-setup.sh" id: local_env_install_ubuntu @@ -57,7 +57,7 @@ jobs: - uses: actions/checkout@v3 - uses: actions/setup-go@v3 with: - go-version: '1.18' + go-version: '1.19' - name: "Installing local env dependencies" run: "./local-env-setup.sh" id: local_env_install_mac diff --git a/.test-infra/jenkins/job_PerformanceTests_BigQueryIO_Python.groovy b/.test-infra/jenkins/job_PerformanceTests_BigQueryIO_Python.groovy index 1ccb8238ba87..853347f9ebfb 100644 --- a/.test-infra/jenkins/job_PerformanceTests_BigQueryIO_Python.groovy +++ b/.test-infra/jenkins/job_PerformanceTests_BigQueryIO_Python.groovy @@ -90,7 +90,7 @@ PhraseTriggeringPostCommitBuilder.postCommitJob( executeJob(delegate, bqio_read_test) } -CronJobBuilder.cronJob('beam_PerformanceTests_BiqQueryIO_Read_Python', 'H 15 * * *', this) { +CronJobBuilder.cronJob('beam_PerformanceTests_BiqQueryIO_Read_Python', 'H H * * *', this) { executeJob(delegate, bqio_read_test) } @@ -103,6 +103,6 @@ PhraseTriggeringPostCommitBuilder.postCommitJob( executeJob(delegate, bqio_write_test) } -CronJobBuilder.cronJob('beam_PerformanceTests_BiqQueryIO_Write_Python_Batch', 'H 15 * * *', this) { +CronJobBuilder.cronJob('beam_PerformanceTests_BiqQueryIO_Write_Python_Batch', 'H H * * *', this) { executeJob(delegate, bqio_write_test) } diff --git a/.test-infra/jenkins/job_PerformanceTests_FileBasedIO_Python.groovy b/.test-infra/jenkins/job_PerformanceTests_FileBasedIO_Python.groovy new file mode 100644 index 000000000000..e45beadf321a --- /dev/null +++ b/.test-infra/jenkins/job_PerformanceTests_FileBasedIO_Python.groovy @@ -0,0 +1,81 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more 
contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import CommonJobProperties as common +import LoadTestsBuilder as loadTestsBuilder +import InfluxDBCredentialsHelper + +def now = new Date().format("MMddHHmmss", TimeZone.getTimeZone('UTC')) + +def jobs = [ + [ + name : 'beam_PerformanceTests_TextIOIT_Python', + description : 'Runs performance tests for Python TextIOIT', + test : 'apache_beam.io.filebasedio_perf_test', + githubTitle : 'Python TextIO Performance Test', + githubTriggerPhrase: 'Run Python TextIO Performance Test', + pipelineOptions : [ + publish_to_big_query : true, + metrics_dataset : 'beam_performance', + metrics_table : 'python_textio_1GB_results', + influx_measurement : 'python_textio_1GB_results', + test_class : 'TextIOPerfTest', + input_options : '\'{' + + '"num_records": 25000000,' + + '"key_size": 9,' + + '"value_size": 21}\'', + dataset_size : '1050000000', + num_workers : '5', + autoscaling_algorithm: 'NONE' + ] + ] +] + +jobs.findAll { + it.name in [ + 'beam_PerformanceTests_TextIOIT_Python', + ] +}.forEach { testJob -> createGCSFileBasedIOITTestJob(testJob) } + +private void createGCSFileBasedIOITTestJob(testJob) { + job(testJob.name) { + description(testJob.description) + common.setTopLevelMainJobProperties(delegate) + common.enablePhraseTriggeringFromPullRequest(delegate, testJob.githubTitle, testJob.githubTriggerPhrase) + common.setAutoJob(delegate, 'H H * * *') + InfluxDBCredentialsHelper.useCredentials(delegate) + additionalPipelineArgs = [ + influxDatabase: InfluxDBCredentialsHelper.InfluxDBDatabaseName, + influxHost: InfluxDBCredentialsHelper.InfluxDBHostUrl, + ] + testJob.pipelineOptions.putAll(additionalPipelineArgs) + + def dataflowSpecificOptions = [ + runner : 'DataflowRunner', + project : 'apache-beam-testing', + region : 'us-central1', + temp_location : 'gs://temp-storage-for-perf-tests/', + filename_prefix : "gs://temp-storage-for-perf-tests/${testJob.name}/\${BUILD_ID}/", + ] + + Map allPipelineOptions = dataflowSpecificOptions << testJob.pipelineOptions + + loadTestsBuilder.loadTest( + delegate, testJob.name, CommonTestProperties.Runner.DATAFLOW, CommonTestProperties.SDK.PYTHON, allPipelineOptions, testJob.test) + } +} diff --git a/.test-infra/jenkins/job_PerformanceTests_PubsubIO_Python.groovy b/.test-infra/jenkins/job_PerformanceTests_PubsubIO_Python.groovy index 327e93f392ff..262eda3fd909 100644 --- a/.test-infra/jenkins/job_PerformanceTests_PubsubIO_Python.groovy +++ b/.test-infra/jenkins/job_PerformanceTests_PubsubIO_Python.groovy @@ -70,6 +70,6 @@ PhraseTriggeringPostCommitBuilder.postCommitJob( executeJob(delegate, psio_test) } -CronJobBuilder.cronJob('beam_PerformanceTests_PubsubIOIT_Python_Streaming', 'H 15 * * *', this) { +CronJobBuilder.cronJob('beam_PerformanceTests_PubsubIOIT_Python_Streaming', 'H H * * *', this) { executeJob(delegate, psio_test) } diff --git 
a/.test-infra/jenkins/job_PerformanceTests_SpannerIO_Python.groovy b/.test-infra/jenkins/job_PerformanceTests_SpannerIO_Python.groovy index 489c72ebaa25..416186567075 100644 --- a/.test-infra/jenkins/job_PerformanceTests_SpannerIO_Python.groovy +++ b/.test-infra/jenkins/job_PerformanceTests_SpannerIO_Python.groovy @@ -92,7 +92,7 @@ PhraseTriggeringPostCommitBuilder.postCommitJob( executeJob(delegate, spannerio_read_test_2gb) } -CronJobBuilder.cronJob('beam_PerformanceTests_SpannerIO_Read_2GB_Python', 'H 15 * * *', this) { +CronJobBuilder.cronJob('beam_PerformanceTests_SpannerIO_Read_2GB_Python', 'H H * * *', this) { executeJob(delegate, spannerio_read_test_2gb) } @@ -105,6 +105,6 @@ PhraseTriggeringPostCommitBuilder.postCommitJob( executeJob(delegate, spannerio_write_test_2gb) } -CronJobBuilder.cronJob('beam_PerformanceTests_SpannerIO_Write_2GB_Python_Batch', 'H 15 * * *', this) { +CronJobBuilder.cronJob('beam_PerformanceTests_SpannerIO_Write_2GB_Python_Batch', 'H H * * *', this) { executeJob(delegate, spannerio_write_test_2gb) } diff --git a/.test-infra/jenkins/job_PostCommit_Python_Examples_Flink.groovy b/.test-infra/jenkins/job_PostCommit_Python_Examples_Flink.groovy index 779395bf7093..c1a44b8e9d43 100644 --- a/.test-infra/jenkins/job_PostCommit_Python_Examples_Flink.groovy +++ b/.test-infra/jenkins/job_PostCommit_Python_Examples_Flink.groovy @@ -37,6 +37,7 @@ PostcommitJobBuilder.postCommitJob('beam_PostCommit_Python_Examples_Flink', gradle { rootBuildScriptDir(commonJobProperties.checkoutDir) tasks(":sdks:python:test-suites:portable:flinkExamplesPostCommit") + switches("-PflinkConfDir=$WORKSPACE/src/runners/flink/src/test/resources") commonJobProperties.setGradleSwitches(delegate) } } diff --git a/.test-infra/metrics/grafana/dashboards/perftests_metrics/Python_IO_IT_Tests_Dataflow.json b/.test-infra/metrics/grafana/dashboards/perftests_metrics/Python_IO_IT_Tests_Dataflow.json index 5b1ff2b8103b..6db7a46edb5a 100644 --- a/.test-infra/metrics/grafana/dashboards/perftests_metrics/Python_IO_IT_Tests_Dataflow.json +++ b/.test-infra/metrics/grafana/dashboards/perftests_metrics/Python_IO_IT_Tests_Dataflow.json @@ -482,6 +482,128 @@ "align": false, "alignLevel": null } + }, + { + "aliasColors": {}, + "bars": false, + "cacheTimeout": null, + "dashLength": 10, + "dashes": false, + "datasource": "BeamInfluxDB", + "fill": 1, + "fillGradient": 0, + "gridPos": { + "h": 9, + "w": 12, + "x": 12, + "y": 9 + }, + "hiddenSeries": false, + "id": 6, + "interval": "24h", + "legend": { + "avg": false, + "current": false, + "max": false, + "min": false, + "show": false, + "total": false, + "values": false + }, + "lines": true, + "linewidth": 2, + "links": [], + "nullPointMode": "connected", + "options": { + "dataLinks": [] + }, + "percentage": false, + "pluginVersion": "6.7.2", + "pointradius": 2, + "points": true, + "renderer": "flot", + "seriesOverrides": [], + "spaceLength": 10, + "stack": false, + "steppedLine": false, + "targets": [ + { + "alias": "$tag_metric", + "groupBy": [ + { + "params": [ + "$__interval" + ], + "type": "time" + } + ], + "measurement": "", + "orderByTime": "ASC", + "policy": "default", + "query": "SELECT mean(\"value\") FROM \"python_textio_1GB_results\" WHERE \"metric\" = 'read_runtime' OR \"metric\" = 'write_runtime' AND $timeFilter GROUP BY time($__interval), \"metric\"", + "rawQuery": true, + "refId": "A", + "resultFormat": "time_series", + "select": [ + [ + { + "params": [ + "value" + ], + "type": "field" + }, + { + "params": [], + "type": "mean" + } + ] + ], + "tags": [] + } 
+ ], + "thresholds": [], + "timeFrom": null, + "timeRegions": [], + "timeShift": null, + "title": "TextIO | GCS | 1 GB", + "tooltip": { + "shared": true, + "sort": 0, + "value_type": "individual" + }, + "transparent": true, + "type": "graph", + "xaxis": { + "buckets": null, + "mode": "time", + "name": null, + "show": true, + "values": [] + }, + "yaxes": [ + { + "$$hashKey": "object:403", + "format": "s", + "label": null, + "logBase": 1, + "max": null, + "min": null, + "show": true + }, + { + "$$hashKey": "object:404", + "format": "short", + "label": null, + "logBase": 1, + "max": null, + "min": null, + "show": true + } + ], + "yaxis": { + "align": false, + "alignLevel": null + } } ], "schemaVersion": 22, diff --git a/buildSrc/src/main/groovy/org/apache/beam/gradle/BeamModulePlugin.groovy b/buildSrc/src/main/groovy/org/apache/beam/gradle/BeamModulePlugin.groovy index 6aa2e4859c59..4c510b4dda2a 100644 --- a/buildSrc/src/main/groovy/org/apache/beam/gradle/BeamModulePlugin.groovy +++ b/buildSrc/src/main/groovy/org/apache/beam/gradle/BeamModulePlugin.groovy @@ -1972,7 +1972,7 @@ class BeamModulePlugin implements Plugin { def goRootDir = "${project.rootDir}/sdks/go" // This sets the whole project Go version. - project.ext.goVersion = "go1.18.1" + project.ext.goVersion = "go1.19.3" // Minor TODO: Figure out if we can pull out the GOCMD env variable after goPrepare script // completion, and avoid this GOBIN substitution. diff --git a/examples/notebooks/beam-ml/dataframe_api_preprocessing.ipynb b/examples/notebooks/beam-ml/dataframe_api_preprocessing.ipynb index b77a63569820..645d62d32be3 100644 --- a/examples/notebooks/beam-ml/dataframe_api_preprocessing.ipynb +++ b/examples/notebooks/beam-ml/dataframe_api_preprocessing.ipynb @@ -1535,7 +1535,7 @@ "numerical_cols = beam_df.select_dtypes(include=np.number).columns.tolist()\n", "categorical_cols = list(set(beam_df.columns) - set(numerical_cols))" ] - } + }, { "cell_type": "code", "execution_count": null, @@ -3492,4 +3492,4 @@ }, "nbformat": 4, "nbformat_minor": 0 -} \ No newline at end of file +} diff --git a/learning/tour-of-beam/frontend/lib/models/content_tree.dart b/learning/tour-of-beam/frontend/lib/models/content_tree.dart index 4c3ba29378ab..471bb6734fe6 100644 --- a/learning/tour-of-beam/frontend/lib/models/content_tree.dart +++ b/learning/tour-of-beam/frontend/lib/models/content_tree.dart @@ -18,19 +18,29 @@ import '../repositories/models/get_content_tree_response.dart'; import 'module.dart'; +import 'node.dart'; +import 'parent_node.dart'; -class ContentTreeModel { - final String sdkId; +class ContentTreeModel extends ParentNodeModel { final List modules; + String get sdkId => id; + + @override + List get nodes => modules; + const ContentTreeModel({ - required this.sdkId, + required super.id, required this.modules, - }); + }) : super( + parent: null, + title: '', + nodes: modules, + ); ContentTreeModel.fromResponse(GetContentTreeResponse response) : this( - sdkId: response.sdkId, + id: response.sdkId, modules: response.modules .map(ModuleModel.fromResponse) .toList(growable: false), diff --git a/learning/tour-of-beam/frontend/lib/models/group.dart b/learning/tour-of-beam/frontend/lib/models/group.dart index 22086e6303e8..ba1d4047a574 100644 --- a/learning/tour-of-beam/frontend/lib/models/group.dart +++ b/learning/tour-of-beam/frontend/lib/models/group.dart @@ -23,15 +23,28 @@ import 'parent_node.dart'; class GroupModel extends ParentNodeModel { const GroupModel({ required super.id, - required super.title, required super.nodes, + required 
super.parent, + required super.title, }); - GroupModel.fromResponse(GroupResponseModel group) - : super( - id: group.id, - title: group.title, - nodes: - group.nodes.map(NodeModel.fromResponse).toList(growable: false), - ); + factory GroupModel.fromResponse( + GroupResponseModel groupResponse, + ParentNodeModel parent, + ) { + final group = GroupModel( + id: groupResponse.id, + nodes: [], + parent: parent, + title: groupResponse.title, + ); + + group.nodes.addAll( + groupResponse.nodes.map( + (node) => NodeModel.fromResponse(node, group), + ), + ); + + return group; + } } diff --git a/learning/tour-of-beam/frontend/lib/models/module.dart b/learning/tour-of-beam/frontend/lib/models/module.dart index 81f8c1b6d613..eb1f7e50633c 100644 --- a/learning/tour-of-beam/frontend/lib/models/module.dart +++ b/learning/tour-of-beam/frontend/lib/models/module.dart @@ -27,18 +27,27 @@ class ModuleModel extends ParentNodeModel { const ModuleModel({ required super.id, - required super.title, required super.nodes, + required super.parent, + required super.title, required this.complexity, }); - ModuleModel.fromResponse(ModuleResponseModel module) - : complexity = module.complexity, - super( - id: module.id, - title: module.title, - nodes: module.nodes - .map(NodeModel.fromResponse) - .toList(growable: false), - ); + factory ModuleModel.fromResponse(ModuleResponseModel moduleResponse) { + final module = ModuleModel( + complexity: moduleResponse.complexity, + nodes: [], + id: moduleResponse.id, + parent: null, + title: moduleResponse.title, + ); + + module.nodes.addAll( + moduleResponse.nodes.map( + (node) => NodeModel.fromResponse(node, module), + ), + ); + + return module; + } } diff --git a/learning/tour-of-beam/frontend/lib/models/node.dart b/learning/tour-of-beam/frontend/lib/models/node.dart index d13ceea2d282..7a653de1e3ad 100644 --- a/learning/tour-of-beam/frontend/lib/models/node.dart +++ b/learning/tour-of-beam/frontend/lib/models/node.dart @@ -19,15 +19,18 @@ import '../repositories/models/node.dart'; import '../repositories/models/node_type_enum.dart'; import 'group.dart'; +import 'parent_node.dart'; import 'unit.dart'; abstract class NodeModel { final String id; final String title; + final NodeModel? parent; const NodeModel({ required this.id, required this.title, + required this.parent, }); /// Constructs nodes from the response data. @@ -36,20 +39,27 @@ abstract class NodeModel { /// because they come from a golang backend which does not /// support inheritance, and so they use an extra layer of composition /// which is inconvenient in Flutter. - static List fromMaps(List json) { + static List fromMaps(List json, ParentNodeModel parent) { return json .cast>() .map(NodeResponseModel.fromJson) - .map(fromResponse) + .map((nodeResponse) => fromResponse(nodeResponse, parent)) .toList(); } - static NodeModel fromResponse(NodeResponseModel node) { + static NodeModel fromResponse( + NodeResponseModel node, + ParentNodeModel parent, + ) { switch (node.type) { case NodeType.group: - return GroupModel.fromResponse(node.group!); + return GroupModel.fromResponse(node.group!, parent); case NodeType.unit: - return UnitModel.fromResponse(node.unit!); + return UnitModel.fromResponse(node.unit!, parent); } } + + NodeModel getFirstUnit(); + + NodeModel? 
getNodeByTreeIds(List treeIds); } diff --git a/learning/tour-of-beam/frontend/lib/models/parent_node.dart b/learning/tour-of-beam/frontend/lib/models/parent_node.dart index 0271cfb9508f..53f3c7a17667 100644 --- a/learning/tour-of-beam/frontend/lib/models/parent_node.dart +++ b/learning/tour-of-beam/frontend/lib/models/parent_node.dart @@ -16,6 +16,8 @@ * limitations under the License. */ +import 'package:collection/collection.dart'; + import 'node.dart'; abstract class ParentNodeModel extends NodeModel { @@ -23,7 +25,23 @@ abstract class ParentNodeModel extends NodeModel { const ParentNodeModel({ required super.id, + required super.parent, required super.title, required this.nodes, }); + + @override + NodeModel getFirstUnit() => nodes[0].getFirstUnit(); + + @override + NodeModel? getNodeByTreeIds(List treeIds) { + final firstId = treeIds.firstOrNull; + final child = nodes.firstWhereOrNull((node) => node.id == firstId); + + if (child == null) { + return null; + } + + return child.getNodeByTreeIds(treeIds.sublist(1)); + } } diff --git a/learning/tour-of-beam/frontend/lib/models/unit.dart b/learning/tour-of-beam/frontend/lib/models/unit.dart index 48b55af33d15..eb2e158ddf62 100644 --- a/learning/tour-of-beam/frontend/lib/models/unit.dart +++ b/learning/tour-of-beam/frontend/lib/models/unit.dart @@ -18,8 +18,19 @@ import '../repositories/models/unit.dart'; import 'node.dart'; +import 'parent_node.dart'; class UnitModel extends NodeModel { - UnitModel.fromResponse(UnitResponseModel unit) - : super(id: unit.id, title: unit.title); + UnitModel.fromResponse(UnitResponseModel unit, ParentNodeModel parent) + : super( + id: unit.id, + parent: parent, + title: unit.title, + ); + + @override + NodeModel getFirstUnit() => this; + + @override + NodeModel? getNodeByTreeIds(List treeIds) => this; } diff --git a/learning/tour-of-beam/frontend/lib/pages/tour/controllers/content_tree.dart b/learning/tour-of-beam/frontend/lib/pages/tour/controllers/content_tree.dart index dc5fc5a15ceb..bfa63c94df4f 100644 --- a/learning/tour-of-beam/frontend/lib/pages/tour/controllers/content_tree.dart +++ b/learning/tour-of-beam/frontend/lib/pages/tour/controllers/content_tree.dart @@ -17,33 +17,96 @@ */ import 'package:flutter/widgets.dart'; +import 'package:get_it/get_it.dart'; import 'package:playground_components/playground_components.dart'; +import '../../../cache/content_tree.dart'; +import '../../../models/group.dart'; import '../../../models/node.dart'; +import '../../../models/unit.dart'; class ContentTreeController extends ChangeNotifier { String _sdkId; List _treeIds; NodeModel? _currentNode; + final _contentTreeCache = GetIt.instance.get(); + final _expandedIds = {}; + + Set get expandedIds => _expandedIds; ContentTreeController({ required String initialSdkId, List initialTreeIds = const [], }) : _sdkId = initialSdkId, - _treeIds = initialTreeIds; + _treeIds = initialTreeIds { + _expandedIds.addAll(initialTreeIds); + + _contentTreeCache.addListener(_onContentTreeCacheChange); + _onContentTreeCacheChange(); + } Sdk get sdk => Sdk.parseOrCreate(_sdkId); String get sdkId => _sdkId; List get treeIds => _treeIds; NodeModel? get currentNode => _currentNode; - void onNodeTap(NodeModel node) { + void openNode(NodeModel node) { + if (!_expandedIds.contains(node.id)) { + _expandedIds.add(node.id); + } + if (node == _currentNode) { return; } - _currentNode = node; - // TODO(alexeyinkin): Set _treeIds from node. 
+ if (node is GroupModel) { + openNode(node.nodes.first); + } else if (node is UnitModel) { + _currentNode = node; + } + + if (_currentNode != null) { + _treeIds = _getNodeAncestors(_currentNode!, [_currentNode!.id]); + } + notifyListeners(); + } + + void expandGroup(GroupModel group) { + _expandedIds.add(group.id); + notifyListeners(); + } + + void collapseGroup(GroupModel group) { + _expandedIds.remove(group.id); + notifyListeners(); + } + + List _getNodeAncestors(NodeModel node, List ancestorIds) { + if (node.parent != null) { + return _getNodeAncestors( + node.parent!, + [...ancestorIds, node.parent!.id], + ); + } + return ancestorIds.reversed.toList(); + } + + void _onContentTreeCacheChange() { + final contentTree = _contentTreeCache.getContentTree(_sdkId); + if (contentTree == null) { + return; + } + + openNode( + contentTree.getNodeByTreeIds(_treeIds) ?? contentTree.getFirstUnit(), + ); + notifyListeners(); } + + @override + void dispose() { + _contentTreeCache.removeListener(_onContentTreeCacheChange); + super.dispose(); + } } diff --git a/learning/tour-of-beam/frontend/lib/pages/tour/path.dart b/learning/tour-of-beam/frontend/lib/pages/tour/path.dart index 5f8971852f9f..07dd386bdfcb 100644 --- a/learning/tour-of-beam/frontend/lib/pages/tour/path.dart +++ b/learning/tour-of-beam/frontend/lib/pages/tour/path.dart @@ -26,7 +26,7 @@ class TourPath extends PagePath { final String sdkId; final List treeIds; - static final _regExp = RegExp(r'^/tour/([a-z]+)(/[/-a-zA-Z0-9]+)?$'); + static final _regExp = RegExp(r'^/tour/([a-z]+)((/[-a-zA-Z0-9]+)*)$'); TourPath({ required this.sdkId, @@ -47,7 +47,12 @@ class TourPath extends PagePath { if (matches == null) return null; final sdkId = matches[1] ?? (throw Error()); - final treeIds = matches[2]?.split('/') ?? const []; + final treeIdsString = matches[2]; + + final treeIds = (treeIdsString == null) + ? 
const [] + // TODO(nausharipov): use RegExp to remove the slash + : treeIdsString.substring(1).split('/'); return TourPath( sdkId: sdkId, diff --git a/learning/tour-of-beam/frontend/lib/pages/tour/state.dart b/learning/tour-of-beam/frontend/lib/pages/tour/state.dart index ae8fc0e1e706..e709839e915e 100644 --- a/learning/tour-of-beam/frontend/lib/pages/tour/state.dart +++ b/learning/tour-of-beam/frontend/lib/pages/tour/state.dart @@ -44,6 +44,7 @@ class TourNotifier extends ChangeNotifier with PageStateMixin { playgroundController = _createPlaygroundController(initialSdkId) { contentTreeController.addListener(_onChanged); _unitContentCache.addListener(_onChanged); + _onChanged(); } @override @@ -53,6 +54,7 @@ class TourNotifier extends ChangeNotifier with PageStateMixin { ); void _onChanged() { + emitPathChanged(); final currentNode = contentTreeController.currentNode; if (currentNode is UnitModel) { final content = _unitContentCache.getUnitContent( diff --git a/learning/tour-of-beam/frontend/lib/pages/tour/widgets/group.dart b/learning/tour-of-beam/frontend/lib/pages/tour/widgets/group.dart index bdebcfc507be..fad732b105bb 100644 --- a/learning/tour-of-beam/frontend/lib/pages/tour/widgets/group.dart +++ b/learning/tour-of-beam/frontend/lib/pages/tour/widgets/group.dart @@ -17,13 +17,12 @@ */ import 'package:flutter/material.dart'; -import 'package:playground_components/playground_components.dart'; -import '../../../components/expansion_tile_wrapper.dart'; import '../../../models/group.dart'; import '../controllers/content_tree.dart'; import 'group_nodes.dart'; import 'group_title.dart'; +import 'stateless_expansion_tile.dart'; class GroupWidget extends StatelessWidget { final GroupModel group; @@ -36,23 +35,32 @@ class GroupWidget extends StatelessWidget { @override Widget build(BuildContext context) { - return ExpansionTileWrapper( - ExpansionTile( - tilePadding: EdgeInsets.zero, - title: GroupTitleWidget( - group: group, - onTap: () => contentTreeController.onNodeTap(group), - ), - childrenPadding: const EdgeInsets.only( - left: BeamSizes.size24, - ), - children: [ - GroupNodesWidget( + return AnimatedBuilder( + animation: contentTreeController, + builder: (context, child) { + final isExpanded = contentTreeController.expandedIds.contains(group.id); + + return StatelessExpansionTile( + isExpanded: isExpanded, + onExpansionChanged: (isExpanding) { + if (isExpanding) { + contentTreeController.expandGroup(group); + } else { + contentTreeController.collapseGroup(group); + } + }, + title: GroupTitleWidget( + group: group, + onTap: () { + contentTreeController.openNode(group); + }, + ), + child: GroupNodesWidget( nodes: group.nodes, contentTreeController: contentTreeController, ), - ], - ), + ); + }, ); } } diff --git a/learning/tour-of-beam/frontend/lib/pages/tour/widgets/group_title.dart b/learning/tour-of-beam/frontend/lib/pages/tour/widgets/group_title.dart index a25c5498bd92..974199946cba 100644 --- a/learning/tour-of-beam/frontend/lib/pages/tour/widgets/group_title.dart +++ b/learning/tour-of-beam/frontend/lib/pages/tour/widgets/group_title.dart @@ -38,7 +38,10 @@ class GroupTitleWidget extends StatelessWidget { onTap: onTap, child: Row( children: [ - TourProgressIndicator(assetPath: Assets.svg.unitProgress0), + TourProgressIndicator( + assetPath: Assets.svg.unitProgress0, + isSelected: false, + ), Text( group.title, style: Theme.of(context).textTheme.headlineMedium, diff --git a/learning/tour-of-beam/frontend/lib/pages/tour/widgets/module.dart 
b/learning/tour-of-beam/frontend/lib/pages/tour/widgets/module.dart index 886e9f98d863..b01987bf0a7c 100644 --- a/learning/tour-of-beam/frontend/lib/pages/tour/widgets/module.dart +++ b/learning/tour-of-beam/frontend/lib/pages/tour/widgets/module.dart @@ -39,7 +39,7 @@ class ModuleWidget extends StatelessWidget { children: [ ModuleTitleWidget( module: module, - onTap: () => contentTreeController.onNodeTap(module), + onTap: () => contentTreeController.openNode(module), ), ...module.nodes .map( diff --git a/learning/tour-of-beam/frontend/lib/components/filler_text.dart b/learning/tour-of-beam/frontend/lib/pages/tour/widgets/stateless_expansion_tile.dart similarity index 52% rename from learning/tour-of-beam/frontend/lib/components/filler_text.dart rename to learning/tour-of-beam/frontend/lib/pages/tour/widgets/stateless_expansion_tile.dart index ca6099e6d9de..149bd04a586e 100644 --- a/learning/tour-of-beam/frontend/lib/components/filler_text.dart +++ b/learning/tour-of-beam/frontend/lib/pages/tour/widgets/stateless_expansion_tile.dart @@ -17,13 +17,37 @@ */ import 'package:flutter/material.dart'; +import 'package:playground_components/playground_components.dart'; -class FillerText extends StatelessWidget { - final int width; - const FillerText({required this.width}); +import '../../../components/expansion_tile_wrapper.dart'; + +class StatelessExpansionTile extends StatelessWidget { + final bool isExpanded; + final ValueChanged? onExpansionChanged; + final Widget title; + final Widget child; + + const StatelessExpansionTile({ + required this.isExpanded, + required this.onExpansionChanged, + required this.title, + required this.child, + }); @override Widget build(BuildContext context) { - return Text(''.padRight(width, 'Just a filler text. ')); + return ExpansionTileWrapper( + ExpansionTile( + key: ValueKey(isExpanded), + initiallyExpanded: isExpanded, + tilePadding: EdgeInsets.zero, + onExpansionChanged: onExpansionChanged, + title: title, + childrenPadding: const EdgeInsets.only( + left: BeamSizes.size24, + ), + children: [child], + ), + ); } } diff --git a/learning/tour-of-beam/frontend/lib/pages/tour/widgets/tour_progress_indicator.dart b/learning/tour-of-beam/frontend/lib/pages/tour/widgets/tour_progress_indicator.dart index 6184a22a9d4f..6f3d6ba56087 100644 --- a/learning/tour-of-beam/frontend/lib/pages/tour/widgets/tour_progress_indicator.dart +++ b/learning/tour-of-beam/frontend/lib/pages/tour/widgets/tour_progress_indicator.dart @@ -21,18 +21,30 @@ import 'package:flutter_svg/svg.dart'; import 'package:playground_components/playground_components.dart'; class TourProgressIndicator extends StatelessWidget { + // TODO(nausharipov): replace assetPath with progress enum final String assetPath; + final bool isSelected; - const TourProgressIndicator({required this.assetPath}); + const TourProgressIndicator({ + required this.assetPath, + required this.isSelected, + }); @override Widget build(BuildContext context) { + final ext = Theme.of(context).extension()!; + return Padding( padding: const EdgeInsets.only( left: BeamSizes.size4, right: BeamSizes.size8, ), - child: SvgPicture.asset(assetPath), + child: SvgPicture.asset( + assetPath, + color: isSelected + ? 
ext.selectedProgressColor + : ext.unselectedProgressColor, + ), ); } } diff --git a/learning/tour-of-beam/frontend/lib/pages/tour/widgets/unit.dart b/learning/tour-of-beam/frontend/lib/pages/tour/widgets/unit.dart index 914361a347a4..cfc0e32235a9 100644 --- a/learning/tour-of-beam/frontend/lib/pages/tour/widgets/unit.dart +++ b/learning/tour-of-beam/frontend/lib/pages/tour/widgets/unit.dart @@ -35,17 +35,33 @@ class UnitWidget extends StatelessWidget { @override Widget build(BuildContext context) { - return ClickableWidget( - onTap: () => contentTreeController.onNodeTap(unit), - child: Padding( - padding: const EdgeInsets.symmetric(vertical: BeamSizes.size10), - child: Row( - children: [ - TourProgressIndicator(assetPath: Assets.svg.unitProgress0), - Expanded(child: Text(unit.title)), - ], - ), - ), + return AnimatedBuilder( + animation: contentTreeController, + builder: (context, child) { + final isSelected = contentTreeController.currentNode?.id == unit.id; + + return ClickableWidget( + onTap: () => contentTreeController.openNode(unit), + child: Container( + decoration: BoxDecoration( + color: isSelected ? Theme.of(context).selectedRowColor : null, + borderRadius: BorderRadius.circular(BeamSizes.size3), + ), + padding: const EdgeInsets.symmetric(vertical: BeamSizes.size10), + child: Row( + children: [ + TourProgressIndicator( + assetPath: Assets.svg.unitProgress0, + isSelected: isSelected, + ), + Expanded( + child: Text(unit.title), + ), + ], + ), + ), + ); + }, ); } } diff --git a/learning/tour-of-beam/frontend/lib/pages/welcome/screen.dart b/learning/tour-of-beam/frontend/lib/pages/welcome/screen.dart index 421593562672..af6e91969bd2 100644 --- a/learning/tour-of-beam/frontend/lib/pages/welcome/screen.dart +++ b/learning/tour-of-beam/frontend/lib/pages/welcome/screen.dart @@ -24,7 +24,6 @@ import 'package:playground_components/playground_components.dart'; import '../../components/builders/content_tree.dart'; import '../../components/builders/sdks.dart'; -import '../../components/filler_text.dart'; import '../../components/scaffold.dart'; import '../../constants/sizes.dart'; import '../../generated/assets.gen.dart'; @@ -397,7 +396,6 @@ class _ModuleBody extends StatelessWidget { padding: _modulePadding, child: Column( children: [ - const FillerText(width: 20), const SizedBox(height: BeamSizes.size16), Divider( color: themeData.dividerColor, @@ -416,7 +414,6 @@ class _LastModuleBody extends StatelessWidget { return Container( margin: _moduleLeftMargin, padding: _modulePadding, - child: const FillerText(width: 20), ); } } diff --git a/learning/tour-of-beam/frontend/pubspec.lock b/learning/tour-of-beam/frontend/pubspec.lock index e1ed198ef56a..51fb2a0fd730 100644 --- a/learning/tour-of-beam/frontend/pubspec.lock +++ b/learning/tour-of-beam/frontend/pubspec.lock @@ -156,7 +156,7 @@ packages: source: hosted version: "4.2.0" collection: - dependency: transitive + dependency: "direct main" description: name: collection url: "https://pub.dartlang.org" @@ -278,7 +278,7 @@ packages: name: flutter_code_editor url: "https://pub.dartlang.org" source: hosted - version: "0.1.4" + version: "0.1.8" flutter_driver: dependency: transitive description: flutter diff --git a/learning/tour-of-beam/frontend/pubspec.yaml b/learning/tour-of-beam/frontend/pubspec.yaml index a6e829542e0c..a8f4fd9a4ce7 100644 --- a/learning/tour-of-beam/frontend/pubspec.yaml +++ b/learning/tour-of-beam/frontend/pubspec.yaml @@ -28,6 +28,7 @@ environment: dependencies: app_state: ^0.8.1 + collection: ^1.16.0 easy_localization: 
^3.0.1 easy_localization_ext: ^0.1.0 easy_localization_loader: ^1.0.0 diff --git a/playground/frontend/playground_components/lib/src/constants/colors.dart b/playground/frontend/playground_components/lib/src/constants/colors.dart index 447d564056ed..6db92295c16a 100644 --- a/playground/frontend/playground_components/lib/src/constants/colors.dart +++ b/playground/frontend/playground_components/lib/src/constants/colors.dart @@ -36,45 +36,51 @@ class BeamColors { class BeamGraphColors { static const node = BeamColors.grey3; - static const border = Color(0xFF45454E); + static const border = Color(0xff45454E); static const edge = BeamLightThemeColors.primary; } class BeamNotificationColors { - static const error = Color(0xFFE54545); - static const info = Color(0xFF3E67F6); - static const success = Color(0xFF37AC66); - static const warning = Color(0xFFEEAB00); + static const error = Color(0xffE54545); + static const info = Color(0xff3E67F6); + static const success = Color(0xff37AC66); + static const warning = Color(0xffEEAB00); } class BeamLightThemeColors { - static const border = Color(0xFFE5E5E5); + static const border = Color(0xffE5E5E5); static const primaryBackground = BeamColors.white; static const secondaryBackground = Color(0xffFCFCFC); + static const selectedUnitColor = Color(0xffE6E7E9); + static const selectedProgressColor = BeamColors.grey3; + static const unselectedProgressColor = selectedUnitColor; static const grey = Color(0xffE5E5E5); - static const listBackground = Color(0xFFA0A4AB); + static const listBackground = BeamColors.grey3; static const text = BeamColors.darkBlue; static const primary = Color(0xffE74D1A); - static const icon = Color(0xFFA0A4AB); + static const icon = BeamColors.grey3; - static const code1 = Color(0xFFDA2833); - static const code2 = Color(0xFF5929B4); - static const codeComment = Color(0xFF4C6B60); - static const codeBackground = Color(0xFFFEF6F3); + static const code1 = Color(0xffDA2833); + static const code2 = Color(0xff5929B4); + static const codeComment = Color(0xff4C6B60); + static const codeBackground = Color(0xffFEF6F3); } class BeamDarkThemeColors { - static const border = Color(0xFFA0A4AB); + static const border = BeamColors.grey3; static const primaryBackground = Color(0xff18181B); static const secondaryBackground = BeamColors.darkGrey; + static const selectedUnitColor = Color(0xff626267); + static const selectedProgressColor = BeamColors.grey1; + static const unselectedProgressColor = selectedUnitColor; static const grey = Color(0xff3F3F46); - static const listBackground = Color(0xFF606772); - static const text = Color(0xffFFFFFF); + static const listBackground = Color(0xff606772); + static const text = Color(0xffffffff); static const primary = Color(0xffF26628); - static const icon = Color(0xFF606772); + static const icon = Color(0xff606772); - static const code1 = Color(0xFFDA2833); - static const code2 = Color(0xFF5929B4); - static const codeComment = Color(0xFF4C6B60); - static const codeBackground = Color(0xFF231B1B); + static const code1 = Color(0xffDA2833); + static const code2 = Color(0xff5929B4); + static const codeComment = Color(0xff4C6B60); + static const codeBackground = Color(0xff231B1B); } diff --git a/playground/frontend/playground_components/lib/src/theme/theme.dart b/playground/frontend/playground_components/lib/src/theme/theme.dart index 14c811abe931..287eef0a14f3 100644 --- a/playground/frontend/playground_components/lib/src/theme/theme.dart +++ b/playground/frontend/playground_components/lib/src/theme/theme.dart @@ -32,6 
+32,9 @@ class BeamThemeExtension extends ThemeExtension { final Color primaryBackgroundTextColor; final Color lightGreyBackgroundTextColor; final Color secondaryBackgroundColor; + // TODO(nausharipov): simplify new color addition + final Color selectedProgressColor; + final Color unselectedProgressColor; final Color codeBackgroundColor; final TextStyle codeRootStyle; @@ -50,6 +53,8 @@ class BeamThemeExtension extends ThemeExtension { required this.codeBackgroundColor, required this.codeRootStyle, required this.codeTheme, + required this.selectedProgressColor, + required this.unselectedProgressColor, }); @override @@ -64,6 +69,8 @@ class BeamThemeExtension extends ThemeExtension { Color? codeBackgroundColor, TextStyle? codeRootStyle, CodeThemeData? codeTheme, + Color? selectedProgressColor, + Color? unselectedProgressColor, }) { return BeamThemeExtension( borderColor: borderColor ?? this.borderColor, @@ -79,6 +86,10 @@ class BeamThemeExtension extends ThemeExtension { codeBackgroundColor: codeBackgroundColor ?? this.codeBackgroundColor, codeRootStyle: codeRootStyle ?? this.codeRootStyle, codeTheme: codeTheme ?? this.codeTheme, + selectedProgressColor: + selectedProgressColor ?? this.selectedProgressColor, + unselectedProgressColor: + unselectedProgressColor ?? this.unselectedProgressColor, ); } @@ -104,6 +115,13 @@ class BeamThemeExtension extends ThemeExtension { Color.lerp(codeBackgroundColor, other?.codeBackgroundColor, t)!, codeRootStyle: TextStyle.lerp(codeRootStyle, other?.codeRootStyle, t)!, codeTheme: t == 0.0 ? codeTheme : other?.codeTheme ?? codeTheme, + selectedProgressColor: + Color.lerp(selectedProgressColor, other?.selectedProgressColor, t)!, + unselectedProgressColor: Color.lerp( + unselectedProgressColor, + other?.unselectedProgressColor, + t, + )!, ); } } @@ -121,6 +139,7 @@ final kLightTheme = ThemeData( ), primaryColor: BeamLightThemeColors.primary, scaffoldBackgroundColor: BeamLightThemeColors.secondaryBackground, + selectedRowColor: BeamLightThemeColors.selectedUnitColor, tabBarTheme: _getTabBarTheme( textColor: BeamLightThemeColors.text, indicatorColor: BeamLightThemeColors.primary, @@ -136,6 +155,8 @@ final kLightTheme = ThemeData( lightGreyBackgroundTextColor: BeamColors.black, markdownStyle: _getMarkdownStyle(Brightness.light), secondaryBackgroundColor: BeamLightThemeColors.secondaryBackground, + selectedProgressColor: BeamLightThemeColors.selectedProgressColor, + unselectedProgressColor: BeamLightThemeColors.unselectedProgressColor, codeBackgroundColor: BeamLightThemeColors.codeBackground, codeRootStyle: GoogleFonts.sourceCodePro( color: BeamLightThemeColors.text, @@ -194,6 +215,7 @@ final kDarkTheme = ThemeData( ), primaryColor: BeamDarkThemeColors.primary, scaffoldBackgroundColor: BeamDarkThemeColors.secondaryBackground, + selectedRowColor: BeamDarkThemeColors.selectedUnitColor, tabBarTheme: _getTabBarTheme( textColor: BeamDarkThemeColors.text, indicatorColor: BeamDarkThemeColors.primary, @@ -209,6 +231,8 @@ final kDarkTheme = ThemeData( lightGreyBackgroundTextColor: BeamColors.black, markdownStyle: _getMarkdownStyle(Brightness.dark), secondaryBackgroundColor: BeamDarkThemeColors.secondaryBackground, + selectedProgressColor: BeamDarkThemeColors.selectedProgressColor, + unselectedProgressColor: BeamDarkThemeColors.unselectedProgressColor, codeBackgroundColor: BeamDarkThemeColors.codeBackground, codeRootStyle: GoogleFonts.sourceCodePro( color: BeamDarkThemeColors.text, @@ -396,8 +420,10 @@ MarkdownStyleSheet _getMarkdownStyle(Brightness brightness) { return 
MarkdownStyleSheet( p: textTheme.bodyMedium, + pPadding: EdgeInsets.only(top: BeamSizes.size2), h1: textTheme.headlineLarge, h3: textTheme.headlineMedium, + h3Padding: EdgeInsets.only(top: BeamSizes.size4), code: GoogleFonts.sourceCodePro( color: textColor, backgroundColor: BeamColors.transparent, diff --git a/release/go-licenses/Dockerfile b/release/go-licenses/Dockerfile index 035055a9224e..4773643ca035 100644 --- a/release/go-licenses/Dockerfile +++ b/release/go-licenses/Dockerfile @@ -16,7 +16,7 @@ # limitations under the License. ############################################################################### -FROM golang:1.18-bullseye +FROM golang:1.19-bullseye WORKDIR /usr/src/app COPY go.mod ./ diff --git a/release/src/main/python-release/python_release_automation_utils.sh b/release/src/main/python-release/python_release_automation_utils.sh index b8fccae0eced..2f5a9ac0a5db 100644 --- a/release/src/main/python-release/python_release_automation_utils.sh +++ b/release/src/main/python-release/python_release_automation_utils.sh @@ -83,13 +83,13 @@ function get_version() { function download_files() { if [[ $1 = *"wheel"* ]]; then if [[ $2 == "python3.7" ]]; then - BEAM_PYTHON_SDK_WHL="apache_beam-$VERSION*-cp37-cp37m-manylinux1_x86_64.whl" + BEAM_PYTHON_SDK_WHL="apache_beam-$VERSION*-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl" elif [[ $2 == "python3.8" ]]; then - BEAM_PYTHON_SDK_WHL="apache_beam-$VERSION*-cp38-cp38-manylinux1_x86_64.whl" + BEAM_PYTHON_SDK_WHL="apache_beam-$VERSION*-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl" elif [[ $2 == "python3.9" ]]; then - BEAM_PYTHON_SDK_WHL="apache_beam-$VERSION*-cp39-cp39-manylinux1_x86_64.whl" + BEAM_PYTHON_SDK_WHL="apache_beam-$VERSION*-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl" elif [[ $2 == "python3.10" ]]; then - BEAM_PYTHON_SDK_WHL="apache_beam-$VERSION*-cp310-cp310-manylinux1_x86_64.whl" + BEAM_PYTHON_SDK_WHL="apache_beam-$VERSION*-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl" else echo "Unable to determine a Beam wheel for interpreter version $2." 
exit 1 diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/types/CoderTypeInformation.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/types/CoderTypeInformation.java index e9d22dbf8e62..b64d8bde095c 100644 --- a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/types/CoderTypeInformation.java +++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/types/CoderTypeInformation.java @@ -42,10 +42,14 @@ public class CoderTypeInformation extends TypeInformation implements Atomi private final SerializablePipelineOptions pipelineOptions; public CoderTypeInformation(Coder coder, PipelineOptions pipelineOptions) { + this(coder, new SerializablePipelineOptions(pipelineOptions)); + } + + public CoderTypeInformation(Coder coder, SerializablePipelineOptions pipelineOptions) { checkNotNull(coder); checkNotNull(pipelineOptions); this.coder = coder; - this.pipelineOptions = new SerializablePipelineOptions(pipelineOptions); + this.pipelineOptions = pipelineOptions; } public Coder getCoder() { diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/KvToByteBufferKeySelector.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/KvToByteBufferKeySelector.java index 68c891d4f94d..204247b1d836 100644 --- a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/KvToByteBufferKeySelector.java +++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/KvToByteBufferKeySelector.java @@ -51,6 +51,6 @@ public ByteBuffer getKey(WindowedValue> value) { @Override public TypeInformation getProducedType() { - return new CoderTypeInformation<>(FlinkKeyUtils.ByteBufferCoder.of(), pipelineOptions.get()); + return new CoderTypeInformation<>(FlinkKeyUtils.ByteBufferCoder.of(), pipelineOptions); } } diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/SdfByteBufferKeySelector.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/SdfByteBufferKeySelector.java index 29af81de42f1..8c6f10abf448 100644 --- a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/SdfByteBufferKeySelector.java +++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/SdfByteBufferKeySelector.java @@ -56,6 +56,6 @@ public ByteBuffer getKey(WindowedValue, Double>> value) { @Override public TypeInformation getProducedType() { - return new CoderTypeInformation<>(FlinkKeyUtils.ByteBufferCoder.of(), pipelineOptions.get()); + return new CoderTypeInformation<>(FlinkKeyUtils.ByteBufferCoder.of(), pipelineOptions); } } diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/WorkItemKeySelector.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/WorkItemKeySelector.java index 3cdb0aece9a7..64ea6ca26d4d 100644 --- a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/WorkItemKeySelector.java +++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/WorkItemKeySelector.java @@ -52,6 +52,6 @@ public ByteBuffer getKey(WindowedValue> value) throws Except @Override public TypeInformation getProducedType() { - return new CoderTypeInformation<>(FlinkKeyUtils.ByteBufferCoder.of(), 
pipelineOptions.get()); + return new CoderTypeInformation<>(FlinkKeyUtils.ByteBufferCoder.of(), pipelineOptions); } } diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/io/UnboundedSourceWrapper.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/io/UnboundedSourceWrapper.java index 5f35fd96ce02..92d0652e11f8 100644 --- a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/io/UnboundedSourceWrapper.java +++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/io/UnboundedSourceWrapper.java @@ -434,8 +434,7 @@ public void initializeState(FunctionInitializationContext context) throws Except @SuppressWarnings("unchecked") CoderTypeInformation, CheckpointMarkT>> typeInformation = - (CoderTypeInformation) - new CoderTypeInformation<>(checkpointCoder, serializedOptions.get()); + (CoderTypeInformation) new CoderTypeInformation<>(checkpointCoder, serializedOptions); stateForCheckpoint = stateStore.getListState( new ListStateDescriptor<>( diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/state/FlinkBroadcastStateInternals.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/state/FlinkBroadcastStateInternals.java index 9830133166ff..10d3ea1f7a5a 100644 --- a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/state/FlinkBroadcastStateInternals.java +++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/state/FlinkBroadcastStateInternals.java @@ -33,7 +33,6 @@ import org.apache.beam.sdk.coders.ListCoder; import org.apache.beam.sdk.coders.MapCoder; import org.apache.beam.sdk.coders.StringUtf8Coder; -import org.apache.beam.sdk.options.PipelineOptions; import org.apache.beam.sdk.state.BagState; import org.apache.beam.sdk.state.CombiningState; import org.apache.beam.sdk.state.MapState; @@ -102,14 +101,14 @@ public T state( public ValueState bindValue(StateTag> address, Coder coder) { return new FlinkBroadcastValueState<>( - stateBackend, address, namespace, coder, pipelineOptions.get()); + stateBackend, address, namespace, coder, pipelineOptions); } @Override public BagState bindBag(StateTag> address, Coder elemCoder) { return new FlinkBroadcastBagState<>( - stateBackend, address, namespace, elemCoder, pipelineOptions.get()); + stateBackend, address, namespace, elemCoder, pipelineOptions); } @Override @@ -142,7 +141,7 @@ CombiningState bindCombiningValue( Combine.CombineFn combineFn) { return new FlinkCombiningState<>( - stateBackend, address, combineFn, namespace, accumCoder, pipelineOptions.get()); + stateBackend, address, combineFn, namespace, accumCoder, pipelineOptions); } @Override @@ -187,7 +186,7 @@ private abstract class AbstractBroadcastState { String name, StateNamespace namespace, Coder coder, - PipelineOptions pipelineOptions) { + SerializablePipelineOptions pipelineOptions) { this.name = name; this.namespace = namespace; @@ -303,7 +302,7 @@ private class FlinkBroadcastValueState extends AbstractBroadcastState StateTag> address, StateNamespace namespace, Coder coder, - PipelineOptions pipelineOptions) { + SerializablePipelineOptions pipelineOptions) { super(flinkStateBackend, address.getId(), namespace, coder, pipelineOptions); this.namespace = namespace; @@ -363,7 +362,7 @@ private class FlinkBroadcastBagState extends AbstractBroadcastState> StateTag> 
address, StateNamespace namespace, Coder coder, - PipelineOptions pipelineOptions) { + SerializablePipelineOptions pipelineOptions) { super(flinkStateBackend, address.getId(), namespace, ListCoder.of(coder), pipelineOptions); this.namespace = namespace; @@ -451,7 +450,7 @@ private class FlinkCombiningState extends AbstractBroad Combine.CombineFn combineFn, StateNamespace namespace, Coder accumCoder, - PipelineOptions pipelineOptions) { + SerializablePipelineOptions pipelineOptions) { super(flinkStateBackend, address.getId(), namespace, accumCoder, pipelineOptions); this.namespace = namespace; @@ -568,7 +567,7 @@ private class FlinkCombiningStateWithContext StateNamespace namespace, Coder accumCoder, CombineWithContext.Context context) { - super(flinkStateBackend, address.getId(), namespace, accumCoder, pipelineOptions.get()); + super(flinkStateBackend, address.getId(), namespace, accumCoder, pipelineOptions); this.namespace = namespace; this.address = address; diff --git a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/BatchViewOverrides.java b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/BatchViewOverrides.java index b31306f6fc11..9013ca2b3499 100644 --- a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/BatchViewOverrides.java +++ b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/BatchViewOverrides.java @@ -180,6 +180,7 @@ public void processElement(ProcessContext c) throws Exception { "Multiple values [%s, %s] found for single key [%s] within window [%s].", map.get(kv.getValue().getValue().getKey()), kv.getValue().getValue().getValue(), + kv.getValue().getValue().getKey(), kv.getKey()); map.put( kv.getValue().getValue().getKey(), diff --git a/sdks/go.mod b/sdks/go.mod index 2ce62f733963..d11ad08058b9 100644 --- a/sdks/go.mod +++ b/sdks/go.mod @@ -23,14 +23,14 @@ module github.com/apache/beam/sdks/v2 go 1.18 require ( - cloud.google.com/go/bigquery v1.43.0 + cloud.google.com/go/bigquery v1.42.0 cloud.google.com/go/datastore v1.9.0 cloud.google.com/go/profiler v0.3.0 cloud.google.com/go/pubsub v1.26.0 cloud.google.com/go/storage v1.28.0 github.com/aws/aws-sdk-go-v2 v1.17.1 - github.com/aws/aws-sdk-go-v2/config v1.18.0 - github.com/aws/aws-sdk-go-v2/credentials v1.13.0 + github.com/aws/aws-sdk-go-v2/config v1.18.1 + github.com/aws/aws-sdk-go-v2/credentials v1.13.1 github.com/aws/aws-sdk-go-v2/feature/s3/manager v1.11.39 github.com/aws/aws-sdk-go-v2/service/s3 v1.29.2 github.com/aws/smithy-go v1.13.4 @@ -61,7 +61,10 @@ require ( gopkg.in/yaml.v2 v2.4.0 ) -require cloud.google.com/go/bigtable v1.18.0 +require ( + cloud.google.com/go/bigtable v1.18.0 + github.com/tetratelabs/wazero v1.0.0-pre.3 +) require ( cloud.google.com/go v0.105.0 // indirect @@ -87,7 +90,7 @@ require ( github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.13.19 // indirect github.com/aws/aws-sdk-go-v2/service/sso v1.11.25 // indirect github.com/aws/aws-sdk-go-v2/service/ssooidc v1.13.8 // indirect - github.com/aws/aws-sdk-go-v2/service/sts v1.17.2 // indirect + github.com/aws/aws-sdk-go-v2/service/sts v1.17.3 // indirect github.com/cenkalti/backoff/v4 v4.1.3 // indirect github.com/census-instrumentation/opencensus-proto v0.2.1 // indirect github.com/cespare/xxhash/v2 v2.1.2 // indirect @@ -123,7 +126,6 @@ require ( github.com/shabbyrobe/gocovmerge v0.0.0-20180507124511-f6ea450bfb63 // indirect 
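A minimal sketch of the constructor-overload pattern the Flink hunks above introduce: call sites that already hold a SerializablePipelineOptions (the key selectors, broadcast state internals, and the unbounded source wrapper) now pass it through unchanged, while existing PipelineOptions callers are still wrapped exactly once. The class and field names below are simplified placeholders, not the real CoderTypeInformation, and the import paths are assumed from the surrounding Beam code.

import java.util.Objects;
import org.apache.beam.runners.core.construction.SerializablePipelineOptions;
import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.options.PipelineOptions;

final class SerializedOptionsHolder<T> {
  private final Coder<T> coder;
  private final SerializablePipelineOptions pipelineOptions;

  // Existing callers keep passing raw PipelineOptions; they are wrapped exactly once here.
  SerializedOptionsHolder(Coder<T> coder, PipelineOptions pipelineOptions) {
    this(coder, new SerializablePipelineOptions(pipelineOptions));
  }

  // New overload: callers that already hold the serializable wrapper reuse it as-is,
  // avoiding another serialization round trip per key selector or state cell.
  SerializedOptionsHolder(Coder<T> coder, SerializablePipelineOptions pipelineOptions) {
    this.coder = Objects.requireNonNull(coder);
    this.pipelineOptions = Objects.requireNonNull(pipelineOptions);
  }
}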
github.com/sirupsen/logrus v1.8.1 // indirect github.com/spf13/pflag v1.0.5 // indirect - github.com/tetratelabs/wazero v1.0.0-pre.3 // indirect go.opencensus.io v0.24.0 // indirect golang.org/x/tools v0.1.12 // indirect golang.org/x/xerrors v0.0.0-20220907171357-04be3eba64a2 // indirect diff --git a/sdks/go.sum b/sdks/go.sum index dbf81346f372..63edd0f3386f 100644 --- a/sdks/go.sum +++ b/sdks/go.sum @@ -36,8 +36,8 @@ cloud.google.com/go/bigquery v1.4.0/go.mod h1:S8dzgnTigyfTmLBfrtrhyYhwRxG72rYxvf cloud.google.com/go/bigquery v1.5.0/go.mod h1:snEHRnqQbz117VIFhE8bmtwIDY80NLUZUMb4Nv6dBIg= cloud.google.com/go/bigquery v1.7.0/go.mod h1://okPTzCYNXSlb24MZs83e2Do+h+VXtc4gLoIoXIAPc= cloud.google.com/go/bigquery v1.8.0/go.mod h1:J5hqkt3O0uAFnINi6JXValWIb1v0goeZM77hZzJN/fQ= -cloud.google.com/go/bigquery v1.43.0 h1:u0fvz5ysJBe1jwUPI4LuPwAX+o+6fCUwf3ECeg6eDUQ= -cloud.google.com/go/bigquery v1.43.0/go.mod h1:ZMQcXHsl+xmU1z36G2jNGZmKp9zNY5BUua5wDgmNCfw= +cloud.google.com/go/bigquery v1.42.0 h1:JuTk8po4bCKRwObdT0zLb1K0BGkGHJdtgs2GK3j2Gws= +cloud.google.com/go/bigquery v1.42.0/go.mod h1:8dRTJxhtG+vwBKzE5OseQn/hiydoQN3EedCaOdYmxRA= cloud.google.com/go/bigtable v1.18.0 h1:OzxQqEBRNcUt0u3V9HobUS95hr1GVVPNHtPGrCeXBfU= cloud.google.com/go/bigtable v1.18.0/go.mod h1:TwTdxeNeIwj2lOmtvqISXlRWuIovWkjSZsd03sCLz2U= cloud.google.com/go/compute v0.1.0/go.mod h1:GAesmwr110a34z04OlxYkATPBEfVhkymfTBXtfbBFow= @@ -149,11 +149,13 @@ github.com/aws/aws-sdk-go-v2 v1.17.1/go.mod h1:JLnGeGONAyi2lWXI1p0PCIOIy333JMVK1 github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.4.9 h1:RKci2D7tMwpvGpDNZnGQw9wk6v7o/xSwFcUAuNPoB8k= github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.4.9/go.mod h1:vCmV1q1VK8eoQJ5+aYE7PkK1K6v41qJ5pJdK3ggCDvg= github.com/aws/aws-sdk-go-v2/config v1.5.0/go.mod h1:RWlPOAW3E3tbtNAqTwvSW54Of/yP3oiZXMI0xfUdjyA= -github.com/aws/aws-sdk-go-v2/config v1.18.0 h1:ULASZmfhKR/QE9UeZ7mzYjUzsnIydy/K1YMT6uH1KC0= github.com/aws/aws-sdk-go-v2/config v1.18.0/go.mod h1:H13DRX9Nv5tAcQvPABrE3dm5XnLp1RC7fVSM3OWiLvA= +github.com/aws/aws-sdk-go-v2/config v1.18.1 h1:wMzU9tBq/tEdTUcmB9WsYe5stdP0/EAf84vfeqS5S6A= +github.com/aws/aws-sdk-go-v2/config v1.18.1/go.mod h1:jQIgBmQJa5oPzTUtWMjFryPDCBlVqIgoFmdfFKLx4WE= github.com/aws/aws-sdk-go-v2/credentials v1.3.1/go.mod h1:r0n73xwsIVagq8RsxmZbGSRQFj9As3je72C2WzUIToc= -github.com/aws/aws-sdk-go-v2/credentials v1.13.0 h1:W5f73j1qurASap+jdScUo4aGzSXxaC7wq1i7CiwhvU8= github.com/aws/aws-sdk-go-v2/credentials v1.13.0/go.mod h1:prZpUfBu1KZLBLVX482Sq4DpDXGugAre08TPEc21GUg= +github.com/aws/aws-sdk-go-v2/credentials v1.13.1 h1:HusGjp9C8zwu1SSEh3s501Llqr2xhn+FYKV5XMnOt6M= +github.com/aws/aws-sdk-go-v2/credentials v1.13.1/go.mod h1:C8xoJdzfQq/kl6gGIuJeHpcAaZnraJfTV9FoBgW1QYg= github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.3.0/go.mod h1:2LAuqPx1I6jNfaGDucWfA2zqQCYCOMCDHiCOciALyNw= github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.12.19 h1:E3PXZSI3F2bzyj6XxUXdTIfvp425HHhwKsFvmzBwHgs= github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.12.19/go.mod h1:VihW95zQpeKQWVPGkwT+2+WJNQV8UXFfMTWdU6VErL8= @@ -189,8 +191,9 @@ github.com/aws/aws-sdk-go-v2/service/sso v1.11.25/go.mod h1:IARHuzTXmj1C0KS35vbo github.com/aws/aws-sdk-go-v2/service/ssooidc v1.13.8 h1:jcw6kKZrtNfBPJkaHrscDOZoe5gvi9wjudnxvozYFJo= github.com/aws/aws-sdk-go-v2/service/ssooidc v1.13.8/go.mod h1:er2JHN+kBY6FcMfcBBKNGCT3CarImmdFzishsqBmSRI= github.com/aws/aws-sdk-go-v2/service/sts v1.6.0/go.mod h1:q7o0j7d7HrJk/vr9uUt3BVRASvcU7gYZB9PUgPiByXg= 
-github.com/aws/aws-sdk-go-v2/service/sts v1.17.2 h1:tpwEMRdMf2UsplengAOnmSIRdvAxf75oUFR+blBr92I= github.com/aws/aws-sdk-go-v2/service/sts v1.17.2/go.mod h1:bXcN3koeVYiJcdDU89n3kCYILob7Y34AeLopUbZgLT4= +github.com/aws/aws-sdk-go-v2/service/sts v1.17.3 h1:WMAsVk4yQTHOZ2m7dFnF5Azr/aDecBbpWRwc+M6iFIM= +github.com/aws/aws-sdk-go-v2/service/sts v1.17.3/go.mod h1:bXcN3koeVYiJcdDU89n3kCYILob7Y34AeLopUbZgLT4= github.com/aws/smithy-go v1.6.0/go.mod h1:SObp3lf9smib00L/v3U2eAKG8FyQ7iLrJnQiAmR5n+E= github.com/aws/smithy-go v1.13.4 h1:/RN2z1txIJWeXeOkzX+Hk/4Uuvv7dWtCjbmVJcrskyk= github.com/aws/smithy-go v1.13.4/go.mod h1:Tg+OJXh4MB2R/uN61Ko2f6hTZwB/ZYGOtib8J3gBHzA= diff --git a/sdks/go/run_with_go_version.sh b/sdks/go/run_with_go_version.sh index 1d8589a82cd1..5331a1782093 100755 --- a/sdks/go/run_with_go_version.sh +++ b/sdks/go/run_with_go_version.sh @@ -37,7 +37,7 @@ set -e # # This variable is also used as the execution command downscript. # The list of downloadable versions are at https://go.dev/dl/ -GOVERS=go1.18.1 +GOVERS=go1.19.3 if ! command -v go &> /dev/null then diff --git a/sdks/java/io/elasticsearch-tests/elasticsearch-tests-5/src/test/java/org/apache/beam/sdk/io/elasticsearch/ElasticsearchIOTest.java b/sdks/java/io/elasticsearch-tests/elasticsearch-tests-5/src/test/java/org/apache/beam/sdk/io/elasticsearch/ElasticsearchIOTest.java index b8b9b29d10d2..a68d08c38044 100644 --- a/sdks/java/io/elasticsearch-tests/elasticsearch-tests-5/src/test/java/org/apache/beam/sdk/io/elasticsearch/ElasticsearchIOTest.java +++ b/sdks/java/io/elasticsearch-tests/elasticsearch-tests-5/src/test/java/org/apache/beam/sdk/io/elasticsearch/ElasticsearchIOTest.java @@ -87,6 +87,14 @@ public void testSizes() throws Exception { elasticsearchIOTestCommon.testSizes(); } + @Test + public void testSizesWithAlias() throws Exception { + // need to create the index using the helper method (not create it at first insertion) + // for the indexSettings() to be run + createIndex(elasticsearchIOTestCommon.restClient, getEsIndex(), true); + elasticsearchIOTestCommon.testSizes(); + } + @Test public void testRead() throws Exception { // need to create the index using the helper method (not create it at first insertion) diff --git a/sdks/java/io/elasticsearch-tests/elasticsearch-tests-6/src/test/java/org/apache/beam/sdk/io/elasticsearch/ElasticsearchIOTest.java b/sdks/java/io/elasticsearch-tests/elasticsearch-tests-6/src/test/java/org/apache/beam/sdk/io/elasticsearch/ElasticsearchIOTest.java index 72efa08eb8b8..be98bfe16e81 100644 --- a/sdks/java/io/elasticsearch-tests/elasticsearch-tests-6/src/test/java/org/apache/beam/sdk/io/elasticsearch/ElasticsearchIOTest.java +++ b/sdks/java/io/elasticsearch-tests/elasticsearch-tests-6/src/test/java/org/apache/beam/sdk/io/elasticsearch/ElasticsearchIOTest.java @@ -86,6 +86,14 @@ public void testSizes() throws Exception { elasticsearchIOTestCommon.testSizes(); } + @Test + public void testSizesWithAlias() throws Exception { + // need to create the index using the helper method (not create it at first insertion) + // for the indexSettings() to be run + createIndex(elasticsearchIOTestCommon.restClient, getEsIndex(), true); + elasticsearchIOTestCommon.testSizes(); + } + @Test public void testRead() throws Exception { // need to create the index using the helper method (not create it at first insertion) diff --git a/sdks/java/io/elasticsearch-tests/elasticsearch-tests-7/src/test/java/org/apache/beam/sdk/io/elasticsearch/ElasticsearchIOTest.java 
b/sdks/java/io/elasticsearch-tests/elasticsearch-tests-7/src/test/java/org/apache/beam/sdk/io/elasticsearch/ElasticsearchIOTest.java index 307ff3b6c43f..aab09f1b962a 100644 --- a/sdks/java/io/elasticsearch-tests/elasticsearch-tests-7/src/test/java/org/apache/beam/sdk/io/elasticsearch/ElasticsearchIOTest.java +++ b/sdks/java/io/elasticsearch-tests/elasticsearch-tests-7/src/test/java/org/apache/beam/sdk/io/elasticsearch/ElasticsearchIOTest.java @@ -87,6 +87,14 @@ public void testSizes() throws Exception { elasticsearchIOTestCommon.testSizes(); } + @Test + public void testSizesWithAlias() throws Exception { + // need to create the index using the helper method (not create it at first insertion) + // for the indexSettings() to be run + createIndex(elasticsearchIOTestCommon.restClient, getEsIndex(), true); + elasticsearchIOTestCommon.testSizes(); + } + @Test public void testRead() throws Exception { // need to create the index using the helper method (not create it at first insertion) diff --git a/sdks/java/io/elasticsearch-tests/elasticsearch-tests-8/src/test/java/org/apache/beam/sdk/io/elasticsearch/ElasticsearchIOTest.java b/sdks/java/io/elasticsearch-tests/elasticsearch-tests-8/src/test/java/org/apache/beam/sdk/io/elasticsearch/ElasticsearchIOTest.java index 04cbf26b675d..6bf96360d533 100644 --- a/sdks/java/io/elasticsearch-tests/elasticsearch-tests-8/src/test/java/org/apache/beam/sdk/io/elasticsearch/ElasticsearchIOTest.java +++ b/sdks/java/io/elasticsearch-tests/elasticsearch-tests-8/src/test/java/org/apache/beam/sdk/io/elasticsearch/ElasticsearchIOTest.java @@ -87,6 +87,14 @@ public void testSizes() throws Exception { elasticsearchIOTestCommon.testSizes(); } + @Test + public void testSizesWithAlias() throws Exception { + // need to create the index using the helper method (not create it at first insertion) + // for the indexSettings() to be run + createIndex(elasticsearchIOTestCommon.restClient, getEsIndex(), true); + elasticsearchIOTestCommon.testSizes(); + } + @Test public void testRead() throws Exception { // need to create the index using the helper method (not create it at first insertion) diff --git a/sdks/java/io/elasticsearch-tests/elasticsearch-tests-common/src/test/java/org/apache/beam/sdk/io/elasticsearch/ElasticsearchIOTestUtils.java b/sdks/java/io/elasticsearch-tests/elasticsearch-tests-common/src/test/java/org/apache/beam/sdk/io/elasticsearch/ElasticsearchIOTestUtils.java index 74477716ca8e..4416f6c7ec3e 100644 --- a/sdks/java/io/elasticsearch-tests/elasticsearch-tests-common/src/test/java/org/apache/beam/sdk/io/elasticsearch/ElasticsearchIOTestUtils.java +++ b/sdks/java/io/elasticsearch-tests/elasticsearch-tests-common/src/test/java/org/apache/beam/sdk/io/elasticsearch/ElasticsearchIOTestUtils.java @@ -56,6 +56,7 @@ class ElasticsearchIOTestUtils { static final String ELASTICSEARCH_PASSWORD = "superSecure"; static final String ELASTIC_UNAME = "elastic"; static final Set INVALID_DOCS_IDS = new HashSet<>(Arrays.asList(6, 7)); + static final String ALIAS_SUFFIX = "-aliased"; static final String[] FAMOUS_SCIENTISTS = { "einstein", @@ -87,9 +88,9 @@ static void deleteIndex(ConnectionConfiguration connectionConfiguration, RestCli deleteIndex(restClient, connectionConfiguration.getIndex()); } - private static void closeIndex(RestClient restClient, String index) throws IOException { + private static Response closeIndex(RestClient restClient, String index) throws IOException { Request request = new Request("POST", String.format("/%s/_close", index)); - 
restClient.performRequest(request); + return restClient.performRequest(request); } static void deleteIndex(RestClient restClient, String index) throws IOException { @@ -98,6 +99,10 @@ static void deleteIndex(RestClient restClient, String index) throws IOException Request request = new Request("DELETE", String.format("/%s", index)); restClient.performRequest(request); } catch (IOException e) { + if (e.getMessage().contains("matches an alias")) { + deleteIndex(restClient, index + ALIAS_SUFFIX); + return; + } // it is fine to ignore this expression as deleteIndex occurs in @before, // so when the first tests is run, the index does not exist yet if (!e.getMessage().contains("index_not_found_exception")) { @@ -106,9 +111,31 @@ static void deleteIndex(RestClient restClient, String index) throws IOException } } - public static void createIndex(RestClient restClient, String indexName) throws IOException { + public static void createIndex(RestClient restClient, String indexName, boolean createAsAlias) + throws IOException { deleteIndex(restClient, indexName); - Request request = new Request("PUT", String.format("/%s", indexName)); + + if (createAsAlias) { + // The intent is that an alias by the name of @param indexName points to a newly created + // index. This way, tests can validate that if the index targeted for reads/writes is + // actually an alias, everything continues to work. + String newIndexName = indexName + ALIAS_SUFFIX; + Request request = new Request("PUT", String.format("/%s", newIndexName)); + restClient.performRequest(request); + createIndexAlias(restClient, newIndexName, indexName); + } else { + Request request = new Request("PUT", String.format("/%s", indexName)); + restClient.performRequest(request); + } + } + + public static void createIndex(RestClient restClient, String indexName) throws IOException { + createIndex(restClient, indexName, false); + } + + public static void createIndexAlias(RestClient restClient, String indexName, String aliasName) + throws IOException { + Request request = new Request("PUT", String.format("/%s/_alias/%s", indexName, aliasName)); restClient.performRequest(request); } diff --git a/sdks/java/io/elasticsearch/src/main/java/org/apache/beam/sdk/io/elasticsearch/ElasticsearchIO.java b/sdks/java/io/elasticsearch/src/main/java/org/apache/beam/sdk/io/elasticsearch/ElasticsearchIO.java index 7ac12bc70718..67863218783f 100644 --- a/sdks/java/io/elasticsearch/src/main/java/org/apache/beam/sdk/io/elasticsearch/ElasticsearchIO.java +++ b/sdks/java/io/elasticsearch/src/main/java/org/apache/beam/sdk/io/elasticsearch/ElasticsearchIO.java @@ -873,9 +873,8 @@ public long getEstimatedSizeBytes(PipelineOptions options) throws IOException { return estimatedByteSize; } final ConnectionConfiguration connectionConfiguration = spec.getConnectionConfiguration(); - JsonNode statsJson = getStats(connectionConfiguration, false); - JsonNode indexStats = - statsJson.path("indices").path(connectionConfiguration.getIndex()).path("primaries"); + JsonNode statsJson = getStats(connectionConfiguration); + JsonNode indexStats = statsJson.path("_all").path("primaries"); long indexSize = indexStats.path("store").path("size_in_bytes").asLong(); LOG.debug("estimate source byte size: total index size {}", indexSize); @@ -927,9 +926,8 @@ static long estimateIndexSize(ConnectionConfiguration connectionConfiguration) // NB: Elasticsearch 5.x+ now provides the slice API. 
// (https://www.elastic.co/guide/en/elasticsearch/reference/5.0/search-request-scroll.html // #sliced-scroll) - JsonNode statsJson = getStats(connectionConfiguration, false); - JsonNode indexStats = - statsJson.path("indices").path(connectionConfiguration.getIndex()).path("primaries"); + JsonNode statsJson = getStats(connectionConfiguration); + JsonNode indexStats = statsJson.path("_all").path("primaries"); JsonNode store = indexStats.path("store"); return store.path("size_in_bytes").asLong(); } @@ -956,12 +954,9 @@ public Coder getOutputCoder() { return StringUtf8Coder.of(); } - private static JsonNode getStats( - ConnectionConfiguration connectionConfiguration, boolean shardLevel) throws IOException { + private static JsonNode getStats(ConnectionConfiguration connectionConfiguration) + throws IOException { HashMap params = new HashMap<>(); - if (shardLevel) { - params.put("level", "shards"); - } String endpoint = String.format("/%s/_stats", connectionConfiguration.getIndex()); try (RestClient restClient = connectionConfiguration.createClient()) { Request request = new Request("GET", endpoint); diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/SpannerAccessor.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/SpannerAccessor.java index 2a95791f6c80..2a277722cc1f 100644 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/SpannerAccessor.java +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/SpannerAccessor.java @@ -185,9 +185,9 @@ private static SpannerAccessor createAndConnect(SpannerConfig spannerConfig) { } String userAgentString = USER_AGENT_PREFIX + "/" + ReleaseInfo.getReleaseInfo().getVersion(); builder.setHeaderProvider(FixedHeaderProvider.create("user-agent", userAgentString)); - String databaseRole = spannerConfig.getDatabaseRole(); - if (databaseRole != null && !databaseRole.isEmpty()) { - builder.setDatabaseRole(databaseRole); + ValueProvider databaseRole = spannerConfig.getDatabaseRole(); + if (databaseRole != null && databaseRole.get() != null && !databaseRole.get().isEmpty()) { + builder.setDatabaseRole(databaseRole.get()); } SpannerOptions options = builder.build(); diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/SpannerConfig.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/SpannerConfig.java index c10af8429ece..8bf6cbb61435 100644 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/SpannerConfig.java +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/SpannerConfig.java @@ -73,7 +73,7 @@ public abstract class SpannerConfig implements Serializable { public abstract @Nullable ValueProvider getRpcPriority(); - public abstract @Nullable String getDatabaseRole(); + public abstract @Nullable ValueProvider getDatabaseRole(); @VisibleForTesting abstract @Nullable ServiceFactory getServiceFactory(); @@ -147,7 +147,7 @@ abstract Builder setExecuteStreamingSqlRetrySettings( abstract Builder setRpcPriority(ValueProvider rpcPriority); - abstract Builder setDatabaseRole(String databaseRole); + abstract Builder setDatabaseRole(ValueProvider databaseRole); public abstract SpannerConfig build(); } @@ -262,7 +262,7 @@ public SpannerConfig withRpcPriority(ValueProvider rpcPriority) { } /** Specifies the Cloud Spanner database role. 
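A hedged sketch of the size-estimation change above: instead of indexing into indices.&lt;concrete-index-name&gt;, the estimate now reads the _all rollup of the _stats response, so an alias or data stream that resolves to differently named backing indices still reports a primary store size. The helper below only illustrates the JSON path; the method name and the idea of parsing from a raw response string are illustrative, not the actual ElasticsearchIO code.

import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;
import java.io.IOException;

final class IndexStatsExample {
  // Reads primaries.store.size_in_bytes from the "_all" section of a /<index>/_stats response.
  static long primaryStoreSizeBytes(String statsResponseJson) throws IOException {
    JsonNode statsJson = new ObjectMapper().readTree(statsResponseJson);
    // "_all" aggregates every concrete index matched by the request, including the indices
    // behind an alias, so no lookup by concrete index name is needed.
    JsonNode primaries = statsJson.path("_all").path("primaries");
    return primaries.path("store").path("size_in_bytes").asLong();
  }
}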
*/ - public SpannerConfig withDatabaseRole(String databaseRole) { + public SpannerConfig withDatabaseRole(ValueProvider databaseRole) { return toBuilder().setDatabaseRole(databaseRole).build(); } } diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/changestreams/dao/DaoFactory.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/changestreams/dao/DaoFactory.java index 7c94720c7875..43b581480dc4 100644 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/changestreams/dao/DaoFactory.java +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/changestreams/dao/DaoFactory.java @@ -72,7 +72,7 @@ public DaoFactory( } this.changeStreamSpannerConfig = changeStreamSpannerConfig; this.changeStreamName = changeStreamName; - this.metadataSpannerConfig = metadataSpannerConfig; + this.metadataSpannerConfig = metadataSpannerConfig.withDatabaseRole(null); this.partitionMetadataTableName = partitionMetadataTableName; this.rpcPriority = rpcPriority; this.jobName = jobName; diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/SpannerAccessorTest.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/SpannerAccessorTest.java index ef9f59a2d6e2..df38d22f5c13 100644 --- a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/SpannerAccessorTest.java +++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/SpannerAccessorTest.java @@ -109,7 +109,7 @@ public void testCreateWithValidDatabaseRole() { .setProjectId(StaticValueProvider.of("project")) .setInstanceId(StaticValueProvider.of("test1")) .setDatabaseId(StaticValueProvider.of("test1")) - .setDatabaseRole("test-role") + .setDatabaseRole(StaticValueProvider.of("test-role")) .build(); SpannerAccessor acc1 = SpannerAccessor.getOrCreate(config1); @@ -130,7 +130,7 @@ public void testCreateWithEmptyDatabaseRole() { .setProjectId(StaticValueProvider.of("project")) .setInstanceId(StaticValueProvider.of("test1")) .setDatabaseId(StaticValueProvider.of("test1")) - .setDatabaseRole("") + .setDatabaseRole(StaticValueProvider.of("")) .build(); SpannerAccessor acc1 = SpannerAccessor.getOrCreate(config1); diff --git a/sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/JdbcIO.java b/sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/JdbcIO.java index efb742fb2cba..97fd2a41ed17 100644 --- a/sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/JdbcIO.java +++ b/sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/JdbcIO.java @@ -2369,8 +2369,8 @@ public void finishBundle() throws Exception { cleanUpStatementAndConnection(); } - @Override - protected void finalize() throws Throwable { + @Teardown + public void tearDown() throws Exception { cleanUpStatementAndConnection(); } diff --git a/sdks/java/io/splunk/src/main/java/org/apache/beam/sdk/io/splunk/CustomX509TrustManager.java b/sdks/java/io/splunk/src/main/java/org/apache/beam/sdk/io/splunk/CustomX509TrustManager.java new file mode 100644 index 000000000000..384b20158b71 --- /dev/null +++ b/sdks/java/io/splunk/src/main/java/org/apache/beam/sdk/io/splunk/CustomX509TrustManager.java @@ -0,0 +1,84 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.io.splunk; + +import java.io.IOException; +import java.security.KeyStore; +import java.security.KeyStoreException; +import java.security.NoSuchAlgorithmException; +import java.security.cert.CertificateException; +import java.security.cert.X509Certificate; +import javax.net.ssl.TrustManager; +import javax.net.ssl.TrustManagerFactory; +import javax.net.ssl.X509TrustManager; + +/** A Custom X509TrustManager that trusts a user provided CA and default CA's. */ +public class CustomX509TrustManager implements X509TrustManager { + + private final X509TrustManager defaultTrustManager; + private final X509TrustManager userTrustManager; + + public CustomX509TrustManager(X509Certificate userCertificate) + throws CertificateException, KeyStoreException, NoSuchAlgorithmException, IOException { + // Get Default Trust Manager + TrustManagerFactory trustMgrFactory = + TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm()); + trustMgrFactory.init((KeyStore) null); + defaultTrustManager = getX509TrustManager(trustMgrFactory.getTrustManagers()); + + // Create Trust Manager with user provided certificate + KeyStore trustStore = KeyStore.getInstance(KeyStore.getDefaultType()); + trustStore.load(null, null); + trustStore.setCertificateEntry("User Provided Root CA", userCertificate); + trustMgrFactory = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm()); + trustMgrFactory.init(trustStore); + userTrustManager = getX509TrustManager(trustMgrFactory.getTrustManagers()); + } + + private X509TrustManager getX509TrustManager(TrustManager[] trustManagers) { + for (TrustManager tm : trustManagers) { + if (tm instanceof X509TrustManager) { + return (X509TrustManager) tm; + } + } + return null; + } + + @Override + public void checkClientTrusted(X509Certificate[] chain, String authType) + throws CertificateException { + defaultTrustManager.checkClientTrusted(chain, authType); + } + + @Override + public void checkServerTrusted(X509Certificate[] chain, String authType) + throws CertificateException { + try { + defaultTrustManager.checkServerTrusted(chain, authType); + } catch (CertificateException ce) { + // If the certificate chain couldn't be verified using the default trust manager, + // try verifying the same with the user-provided root CA + userTrustManager.checkServerTrusted(chain, authType); + } + } + + @Override + public X509Certificate[] getAcceptedIssuers() { + return defaultTrustManager.getAcceptedIssuers(); + } +} diff --git a/sdks/java/io/splunk/src/main/java/org/apache/beam/sdk/io/splunk/HttpEventPublisher.java b/sdks/java/io/splunk/src/main/java/org/apache/beam/sdk/io/splunk/HttpEventPublisher.java index 8c4613a6d21b..3f3ebb89cf28 100644 --- a/sdks/java/io/splunk/src/main/java/org/apache/beam/sdk/io/splunk/HttpEventPublisher.java +++ 
b/sdks/java/io/splunk/src/main/java/org/apache/beam/sdk/io/splunk/HttpEventPublisher.java @@ -34,13 +34,20 @@ import com.google.auto.value.AutoValue; import com.google.gson.Gson; import com.google.gson.GsonBuilder; +import java.io.ByteArrayInputStream; import java.io.IOException; +import java.io.InputStream; import java.io.UnsupportedEncodingException; import java.security.KeyManagementException; import java.security.KeyStoreException; import java.security.NoSuchAlgorithmException; +import java.security.cert.CertificateException; +import java.security.cert.CertificateFactory; +import java.security.cert.X509Certificate; import java.util.List; import javax.net.ssl.HostnameVerifier; +import javax.net.ssl.SSLContext; +import javax.net.ssl.TrustManager; import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.annotations.VisibleForTesting; import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Joiner; import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableList; @@ -103,6 +110,9 @@ static Builder newBuilder() { abstract @Nullable Integer maxElapsedMillis(); + @SuppressWarnings("mutable") + abstract byte @Nullable [] rootCaCertificate(); + abstract Boolean disableCertificateValidation(); abstract Boolean enableGzipHttpCompression(); @@ -215,6 +225,10 @@ abstract static class Builder { abstract Boolean disableCertificateValidation(); + abstract Builder setRootCaCertificate(byte[] certificate); + + abstract byte[] rootCaCertificate(); + abstract Builder setEnableGzipHttpCompression(Boolean enableGzipHttpCompression); abstract Builder setMaxElapsedMillis(Integer maxElapsedMillis); @@ -259,6 +273,17 @@ Builder withDisableCertificateValidation(Boolean disableCertificateValidation) { return setDisableCertificateValidation(disableCertificateValidation); } + /** + * Method to set the root CA certificate. + * + * @param certificate User provided root CA certificate + * @return {@link Builder} + */ + public Builder withRootCaCertificate(byte[] certificate) { + checkNotNull(certificate, "withRootCaCertificate(certificate) called with null input."); + return setRootCaCertificate(certificate); + } + /** * Method to specify if HTTP requests sent to Splunk HEC should be GZIP encoded. 
* @@ -291,7 +316,8 @@ Builder withMaxElapsedMillis(Integer maxElapsedMillis) { * @return {@link HttpEventPublisher} */ HttpEventPublisher build() - throws NoSuchAlgorithmException, KeyStoreException, KeyManagementException { + throws NoSuchAlgorithmException, KeyStoreException, KeyManagementException, IOException, + CertificateException { checkNotNull(token(), "Authentication token needs to be specified via withToken(token)."); checkNotNull(genericUrl(), "URL needs to be specified via withUrl(url)."); @@ -309,7 +335,8 @@ HttpEventPublisher build() } CloseableHttpClient httpClient = - getHttpClient(DEFAULT_MAX_CONNECTIONS, disableCertificateValidation()); + getHttpClient( + DEFAULT_MAX_CONNECTIONS, disableCertificateValidation(), rootCaCertificate()); setTransport(new ApacheHttpTransport(httpClient)); setRequestFactory(transport().createRequestFactory()); @@ -334,10 +361,12 @@ GenericUrl getGenericUrl(String baseUrl) { * * @param maxConnections max number of parallel connections * @param disableCertificateValidation should disable certificate validation + * @param rootCaCertificate root CA certificate */ private CloseableHttpClient getHttpClient( - int maxConnections, boolean disableCertificateValidation) - throws NoSuchAlgorithmException, KeyStoreException, KeyManagementException { + int maxConnections, boolean disableCertificateValidation, byte[] rootCaCertificate) + throws NoSuchAlgorithmException, KeyStoreException, KeyManagementException, IOException, + CertificateException { HttpClientBuilder builder = ApacheHttpTransport.newDefaultHttpClientBuilder(); @@ -349,14 +378,24 @@ private CloseableHttpClient getHttpClient( ? NoopHostnameVerifier.INSTANCE : new DefaultHostnameVerifier(); - SSLContextBuilder sslContextBuilder = SSLContextBuilder.create(); + SSLContext sslContext = SSLContextBuilder.create().build(); if (disableCertificateValidation) { LOG.info("Certificate validation is disabled"); - sslContextBuilder.loadTrustMaterial((TrustStrategy) (chain, authType) -> true); + sslContext = + SSLContextBuilder.create() + .loadTrustMaterial((TrustStrategy) (chain, authType) -> true) + .build(); + } else if (rootCaCertificate != null) { + LOG.info("Self-Signed Certificate provided"); + InputStream inStream = new ByteArrayInputStream(rootCaCertificate); + CertificateFactory cf = CertificateFactory.getInstance("X.509"); + X509Certificate cert = (X509Certificate) cf.generateCertificate(inStream); + CustomX509TrustManager customTrustManager = new CustomX509TrustManager(cert); + sslContext.init(null, new TrustManager[] {customTrustManager}, null); } SSLConnectionSocketFactory connectionSocketFactory = - new SSLConnectionSocketFactory(sslContextBuilder.build(), hostnameVerifier); + new SSLConnectionSocketFactory(sslContext, hostnameVerifier); builder.setSSLSocketFactory(connectionSocketFactory); } diff --git a/sdks/java/io/splunk/src/main/java/org/apache/beam/sdk/io/splunk/SplunkEventWriter.java b/sdks/java/io/splunk/src/main/java/org/apache/beam/sdk/io/splunk/SplunkEventWriter.java index 00b0a0ea8f55..638b944aad73 100644 --- a/sdks/java/io/splunk/src/main/java/org/apache/beam/sdk/io/splunk/SplunkEventWriter.java +++ b/sdks/java/io/splunk/src/main/java/org/apache/beam/sdk/io/splunk/SplunkEventWriter.java @@ -17,6 +17,7 @@ */ package org.apache.beam.sdk.io.splunk; +import static java.util.stream.Collectors.toList; import static org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Preconditions.checkArgument; import static 
org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Preconditions.checkNotNull; @@ -26,12 +27,19 @@ import com.google.gson.Gson; import com.google.gson.GsonBuilder; import java.io.IOException; -import java.io.UnsupportedEncodingException; +import java.io.InputStream; +import java.nio.channels.Channels; +import java.nio.channels.ReadableByteChannel; import java.security.KeyManagementException; import java.security.KeyStoreException; import java.security.NoSuchAlgorithmException; +import java.security.cert.CertificateException; import java.time.Instant; import java.util.List; +import org.apache.beam.repackaged.core.org.apache.commons.compress.utils.IOUtils; +import org.apache.beam.sdk.io.FileSystems; +import org.apache.beam.sdk.io.fs.MatchResult; +import org.apache.beam.sdk.io.fs.ResourceId; import org.apache.beam.sdk.metrics.Counter; import org.apache.beam.sdk.metrics.Distribution; import org.apache.beam.sdk.metrics.Metrics; @@ -123,6 +131,8 @@ static Builder newBuilder() { abstract @Nullable ValueProvider inputBatchCount(); + abstract @Nullable ValueProvider rootCaCertificatePath(); + abstract @Nullable ValueProvider enableBatchLogs(); abstract @Nullable ValueProvider enableGzipHttpCompression(); @@ -187,13 +197,18 @@ public void setup() { .withDisableCertificateValidation(disableValidation) .withEnableGzipHttpCompression(enableGzipHttpCompression); + if (rootCaCertificatePath() != null && rootCaCertificatePath().get() != null) { + builder.withRootCaCertificate(getCertFromGcsAsBytes(rootCaCertificatePath().get())); + } + publisher = builder.build(); LOG.info("Successfully created HttpEventPublisher"); } catch (NoSuchAlgorithmException | KeyStoreException | KeyManagementException - | UnsupportedEncodingException e) { + | IOException + | CertificateException e) { LOG.error("Error creating HttpEventPublisher: {}", e.getMessage()); throw new RuntimeException(e); } @@ -396,6 +411,41 @@ private static void flushWriteFailures( } } + /** + * Reads a root CA certificate from GCS and returns it as raw bytes. + * + * @param filePath path to root CA cert in GCS + * @return raw contents of cert + * @throws RuntimeException thrown if not able to read or parse cert + */ + public static byte[] getCertFromGcsAsBytes(String filePath) throws IOException { + ReadableByteChannel channel = getGcsFileByteChannel(filePath); + try (InputStream inputStream = Channels.newInputStream(channel)) { + return IOUtils.toByteArray(inputStream); + } catch (IOException e) { + throw new RuntimeException("Error when reading: " + filePath, e); + } + } + + /** Handles getting the {@link ReadableByteChannel} for {@code filePath}. 
*/ + private static ReadableByteChannel getGcsFileByteChannel(String filePath) throws IOException { + try { + MatchResult result = FileSystems.match(filePath); + checkArgument( + result.status() == MatchResult.Status.OK && !result.metadata().isEmpty(), + "Failed to match any files with the pattern: " + filePath); + + List rId = + result.metadata().stream().map(MatchResult.Metadata::resourceId).collect(toList()); + + checkArgument(rId.size() == 1, "Expected exactly 1 file, but got " + rId.size() + " files."); + + return FileSystems.open(rId.get(0)); + } catch (IOException e) { + throw new RuntimeException("Error when finding: " + filePath, e); + } + } + @AutoValue.Builder abstract static class Builder { @@ -410,6 +460,8 @@ abstract static class Builder { abstract Builder setDisableCertificateValidation( ValueProvider disableCertificateValidation); + abstract Builder setRootCaCertificatePath(ValueProvider rootCaCertificatePath); + abstract Builder setEnableBatchLogs(ValueProvider enableBatchLogs); abstract Builder setEnableGzipHttpCompression(ValueProvider enableGzipHttpCompression); @@ -482,6 +534,16 @@ Builder withDisableCertificateValidation(ValueProvider disableCertifica return setDisableCertificateValidation(disableCertificateValidation); } + /** + * Method to set the self signed certificate path. + * + * @param rootCaCertificatePath Path to self-signed certificate + * @return {@link Builder} + */ + public Builder withRootCaCertificatePath(ValueProvider rootCaCertificatePath) { + return setRootCaCertificatePath(rootCaCertificatePath); + } + /** * Method to enable batch logs. * diff --git a/sdks/java/io/splunk/src/main/java/org/apache/beam/sdk/io/splunk/SplunkIO.java b/sdks/java/io/splunk/src/main/java/org/apache/beam/sdk/io/splunk/SplunkIO.java index bc8e822ce671..f3521583845b 100644 --- a/sdks/java/io/splunk/src/main/java/org/apache/beam/sdk/io/splunk/SplunkIO.java +++ b/sdks/java/io/splunk/src/main/java/org/apache/beam/sdk/io/splunk/SplunkIO.java @@ -142,6 +142,8 @@ public abstract static class Write abstract @Nullable ValueProvider disableCertificateValidation(); + abstract @Nullable ValueProvider rootCaCertificatePath(); + abstract @Nullable ValueProvider enableBatchLogs(); abstract @Nullable ValueProvider enableGzipHttpCompression(); @@ -158,9 +160,10 @@ public PCollection expand(PCollection input) { .withInputBatchCount(batchCount()) .withDisableCertificateValidation(disableCertificateValidation()) .withToken(token()) + .withRootCaCertificatePath(rootCaCertificatePath()) .withEnableBatchLogs(enableBatchLogs()) .withEnableGzipHttpCompression(enableGzipHttpCompression()); - ; + SplunkEventWriter writer = builder.build(); LOG.info("SplunkEventWriter configured"); @@ -186,6 +189,8 @@ abstract static class Builder { abstract Builder setDisableCertificateValidation( ValueProvider disableCertificateValidation); + abstract Builder setRootCaCertificatePath(ValueProvider rootCaCertificatePath); + abstract Builder setEnableBatchLogs(ValueProvider enableBatchLogs); abstract Builder setEnableGzipHttpCompression( @@ -264,6 +269,35 @@ public Write withDisableCertificateValidation(Boolean disableCertificateValidati .build(); } + /** + * Same as {@link Builder#withRootCaCertificatePath(ValueProvider)} but without a {@link + * ValueProvider}. 
+ * + * @param rootCaCertificatePath Path to root CA certificate + * @return {@link Builder} + */ + public Write withRootCaCertificatePath(ValueProvider rootCaCertificatePath) { + checkArgument( + rootCaCertificatePath != null, + "withRootCaCertificatePath(rootCaCertificatePath) called with null input."); + return toBuilder().setRootCaCertificatePath(rootCaCertificatePath).build(); + } + + /** + * Method to set the root CA certificate. + * + * @param rootCaCertificatePath Path to root CA certificate + * @return {@link Builder} + */ + public Write withRootCaCertificatePath(String rootCaCertificatePath) { + checkArgument( + rootCaCertificatePath != null, + "withRootCaCertificatePath(rootCaCertificatePath) called with null input."); + return toBuilder() + .setRootCaCertificatePath(StaticValueProvider.of(rootCaCertificatePath)) + .build(); + } + /** * Same as {@link Builder#withEnableBatchLogs(ValueProvider)} but without a {@link * ValueProvider}. @@ -274,7 +308,7 @@ public Write withDisableCertificateValidation(Boolean disableCertificateValidati public Write withEnableBatchLogs(ValueProvider enableBatchLogs) { checkArgument( enableBatchLogs != null, "withEnableBatchLogs(enableBatchLogs) called with null input."); - return toBuilder().setEnableGzipHttpCompression(enableBatchLogs).build(); + return toBuilder().setEnableBatchLogs(enableBatchLogs).build(); } /** @@ -286,9 +320,7 @@ public Write withEnableBatchLogs(ValueProvider enableBatchLogs) { public Write withEnableBatchLogs(Boolean enableBatchLogs) { checkArgument( enableBatchLogs != null, "withEnableBatchLogs(enableBatchLogs) called with null input."); - return toBuilder() - .setEnableGzipHttpCompression(StaticValueProvider.of(enableBatchLogs)) - .build(); + return toBuilder().setEnableBatchLogs(StaticValueProvider.of(enableBatchLogs)).build(); } /** diff --git a/sdks/java/io/splunk/src/test/java/org/apache/beam/sdk/io/splunk/CustomX509TrustManagerTest.java b/sdks/java/io/splunk/src/test/java/org/apache/beam/sdk/io/splunk/CustomX509TrustManagerTest.java new file mode 100644 index 000000000000..e45dd1e98c1d --- /dev/null +++ b/sdks/java/io/splunk/src/test/java/org/apache/beam/sdk/io/splunk/CustomX509TrustManagerTest.java @@ -0,0 +1,80 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
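A usage sketch for the new SplunkIO option, assuming a PCollection&lt;SplunkEvent&gt; named events; the HEC URL, token, and bucket path are placeholders, and methods not shown in the hunks above (write(url, token), withEnableGzipHttpCompression(boolean)) are assumed from the existing SplunkIO API. The certificate path is read at writer setup via FileSystems, so any registered Beam filesystem scheme should work, not only gs://.

// given: PCollection<SplunkEvent> events
events.apply(
    "WriteToSplunk",
    SplunkIO.write("https://splunk-hec.example.com:8088", "my-hec-token")
        .withDisableCertificateValidation(false)
        // PEM-encoded root CA that signed the HEC endpoint's certificate.
        .withRootCaCertificatePath("gs://my-bucket/certs/splunk_root_ca.crt")
        .withEnableBatchLogs(true)
        .withEnableGzipHttpCompression(true));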
+ */ +package org.apache.beam.sdk.io.splunk; + +import java.io.FileInputStream; +import java.io.FileNotFoundException; +import java.io.IOException; +import java.security.KeyStoreException; +import java.security.NoSuchAlgorithmException; +import java.security.cert.CertificateException; +import java.security.cert.CertificateFactory; +import java.security.cert.X509Certificate; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** Unit tests for {@link CustomX509TrustManager}. */ +@RunWith(JUnit4.class) +public final class CustomX509TrustManagerTest { + + private CustomX509TrustManager customTrustManager; + private X509Certificate rootCa; + private X509Certificate recognizedSelfSignedCertificate; + private X509Certificate unrecognizedSelfSignedCertificate; + + @Before + public void setUp() + throws NoSuchAlgorithmException, CertificateException, FileNotFoundException, + KeyStoreException, IOException { + CertificateFactory cf = CertificateFactory.getInstance("X.509"); + ClassLoader classLoader = this.getClass().getClassLoader(); + FileInputStream rootCaInputStream = + new FileInputStream(classLoader.getResource("SplunkTestCerts/RootCA.crt").getFile()); + FileInputStream recognizedInputStream = + new FileInputStream( + classLoader.getResource("SplunkTestCerts/RecognizedCertificate.crt").getFile()); + FileInputStream unrecognizedInputStream = + new FileInputStream( + classLoader.getResource("SplunkTestCerts/UnrecognizedCertificate.crt").getFile()); + rootCa = (X509Certificate) cf.generateCertificate(rootCaInputStream); + recognizedSelfSignedCertificate = + (X509Certificate) cf.generateCertificate(recognizedInputStream); + unrecognizedSelfSignedCertificate = + (X509Certificate) cf.generateCertificate(unrecognizedInputStream); + + customTrustManager = new CustomX509TrustManager(rootCa); + } + + /** + * Tests whether a recognized (user provided) self-signed certificate is accepted by TrustManager. + */ + @Test + public void testCustomX509TrustManagerWithRecognizedCertificate() throws CertificateException { + customTrustManager.checkServerTrusted( + new X509Certificate[] {recognizedSelfSignedCertificate}, "RSA"); + } + + /** Tests whether a unrecognized self-signed certificate is rejected by TrustManager. 
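A small sketch of how a user-supplied root CA ends up in an SSLContext, mirroring the else-if branch added to HttpEventPublisher.getHttpClient above; here the certificate bytes come from an in-memory PEM rather than GCS, and the wrapper class and method name are illustrative.

import java.io.ByteArrayInputStream;
import java.security.cert.CertificateFactory;
import java.security.cert.X509Certificate;
import javax.net.ssl.SSLContext;
import javax.net.ssl.TrustManager;

final class RootCaSslContextExample {
  static SSLContext sslContextWithRootCa(byte[] rootCaCertificatePem) throws Exception {
    CertificateFactory cf = CertificateFactory.getInstance("X.509");
    X509Certificate rootCa =
        (X509Certificate) cf.generateCertificate(new ByteArrayInputStream(rootCaCertificatePem));
    // checkServerTrusted() consults the JVM's default trust store first and only then the
    // user-provided root CA, so well-known CAs keep working alongside the self-signed one.
    CustomX509TrustManager trustManager = new CustomX509TrustManager(rootCa);
    SSLContext context = SSLContext.getInstance("TLS");
    context.init(null, new TrustManager[] {trustManager}, null);
    return context;
  }
}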
*/ + @Test(expected = Exception.class) + public void testCustomX509TrustManagerWithUnrecognizedCertificate() throws CertificateException { + customTrustManager.checkServerTrusted( + new X509Certificate[] {unrecognizedSelfSignedCertificate}, "RSA"); + } +} diff --git a/sdks/java/io/splunk/src/test/java/org/apache/beam/sdk/io/splunk/HttpEventPublisherTest.java b/sdks/java/io/splunk/src/test/java/org/apache/beam/sdk/io/splunk/HttpEventPublisherTest.java index aa02cc6b08bb..510f12a1e02a 100644 --- a/sdks/java/io/splunk/src/test/java/org/apache/beam/sdk/io/splunk/HttpEventPublisherTest.java +++ b/sdks/java/io/splunk/src/test/java/org/apache/beam/sdk/io/splunk/HttpEventPublisherTest.java @@ -18,20 +18,30 @@ package org.apache.beam.sdk.io.splunk; import static org.junit.Assert.assertEquals; +import static org.mockserver.integration.ClientAndServer.startClientAndServer; import com.google.api.client.http.GenericUrl; import com.google.api.client.http.HttpContent; import com.google.api.client.util.ExponentialBackOff; import java.io.ByteArrayOutputStream; import java.io.IOException; -import java.io.UnsupportedEncodingException; +import java.net.ServerSocket; import java.nio.charset.StandardCharsets; import java.security.KeyManagementException; import java.security.KeyStoreException; import java.security.NoSuchAlgorithmException; +import java.security.cert.CertificateException; +import javax.net.ssl.SSLHandshakeException; import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Joiner; import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableList; +import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.io.Resources; +import org.junit.Before; import org.junit.Test; +import org.mockserver.configuration.ConfigurationProperties; +import org.mockserver.integration.ClientAndServer; +import org.mockserver.model.HttpRequest; +import org.mockserver.model.HttpResponse; +import org.mockserver.verify.VerificationTimes; /** Unit tests for {@link HttpEventPublisher}. 
*/ public class HttpEventPublisherTest { @@ -58,11 +68,33 @@ public class HttpEventPublisherTest { private static final ImmutableList SPLUNK_EVENTS = ImmutableList.of(SPLUNK_TEST_EVENT_1, SPLUNK_TEST_EVENT_2); + private static final String ROOT_CA_PATH = "SplunkTestCerts/RootCA.crt"; + private static final String ROOT_CA_KEY_PATH = + Resources.getResource("SplunkTestCerts/RootCA_PrivateKey.pem").getPath(); + private static final String INCORRECT_ROOT_CA_PATH = "SplunkTestCerts/RootCA_2.crt"; + private static final String CERTIFICATE_PATH = "SplunkTestCerts/RecognizedCertificate.crt"; + private static final String KEY_PATH = + Resources.getResource("SplunkTestCerts/PrivateKey.pem").getPath(); + private static final String EXPECTED_PATH = "/" + HttpEventPublisher.HEC_URL_PATH; + private ClientAndServer mockServer; + + @Before + public void setUp() throws IOException { + ConfigurationProperties.disableSystemOut(true); + ConfigurationProperties.privateKeyPath(KEY_PATH); + ConfigurationProperties.x509CertificatePath(CERTIFICATE_PATH); + ConfigurationProperties.certificateAuthorityCertificate(ROOT_CA_PATH); + ConfigurationProperties.certificateAuthorityPrivateKey(ROOT_CA_KEY_PATH); + ServerSocket socket = new ServerSocket(0); + int port = socket.getLocalPort(); + socket.close(); + mockServer = startClientAndServer("localhost", port); + } @Test public void stringPayloadTest() - throws UnsupportedEncodingException, NoSuchAlgorithmException, KeyStoreException, - KeyManagementException { + throws IOException, NoSuchAlgorithmException, KeyStoreException, KeyManagementException, + CertificateException { HttpEventPublisher publisher = HttpEventPublisher.newBuilder() @@ -86,7 +118,8 @@ public void stringPayloadTest() @Test public void contentTest() - throws NoSuchAlgorithmException, KeyStoreException, KeyManagementException, IOException { + throws NoSuchAlgorithmException, KeyStoreException, KeyManagementException, IOException, + CertificateException { HttpEventPublisher publisher = HttpEventPublisher.newBuilder() @@ -130,7 +163,8 @@ public void genericURLTest() @Test public void configureBackOffDefaultTest() - throws NoSuchAlgorithmException, KeyStoreException, KeyManagementException, IOException { + throws NoSuchAlgorithmException, KeyStoreException, KeyManagementException, IOException, + CertificateException { HttpEventPublisher publisherDefaultBackOff = HttpEventPublisher.newBuilder() @@ -147,7 +181,8 @@ public void configureBackOffDefaultTest() @Test public void configureBackOffCustomTest() - throws NoSuchAlgorithmException, KeyStoreException, KeyManagementException, IOException { + throws NoSuchAlgorithmException, KeyStoreException, KeyManagementException, IOException, + CertificateException { int timeoutInMillis = 600000; // 10 minutes HttpEventPublisher publisherWithBackOff = @@ -162,4 +197,68 @@ public void configureBackOffCustomTest() assertEquals( timeoutInMillis, publisherWithBackOff.getConfiguredBackOff().getMaxElapsedTimeMillis()); } + + @Test(expected = CertificateException.class) + public void invalidRootCaTest() throws Exception { + HttpEventPublisher publisherWithInvalidCert = + HttpEventPublisher.newBuilder() + .withUrl("https://example.com") + .withToken("test-token") + .withDisableCertificateValidation(false) + .withRootCaCertificate("invalid_ca".getBytes(StandardCharsets.UTF_8)) + .withEnableGzipHttpCompression(true) + .build(); + System.out.println(publisherWithInvalidCert); + } + + /** Tests if a self-signed certificate is trusted if its root CA is passed. 
*/ + @Test + public void recognizedSelfSignedCertificateTest() throws Exception { + mockServerListening(200); + byte[] rootCa = + Resources.toString(Resources.getResource(ROOT_CA_PATH), StandardCharsets.UTF_8) + .getBytes(StandardCharsets.UTF_8); + HttpEventPublisher publisher = + HttpEventPublisher.newBuilder() + .withUrl("https://localhost:" + String.valueOf(mockServer.getLocalPort())) + .withToken("test-token") + .withDisableCertificateValidation(false) + .withRootCaCertificate(rootCa) + .withEnableGzipHttpCompression(true) + .build(); + publisher.execute(SPLUNK_EVENTS); + + // Server received exactly one POST request. + mockServer.verify(HttpRequest.request(EXPECTED_PATH), VerificationTimes.once()); + } + + /** + * Tests if a self-signed certificate is not trusted if it is not derived by the root CA which is + * passed. + */ + @Test(expected = SSLHandshakeException.class) + public void unrecognizedSelfSignedCertificateTest() throws Exception { + mockServerListening(200); + byte[] rootCa = + Resources.toString(Resources.getResource(INCORRECT_ROOT_CA_PATH), StandardCharsets.UTF_8) + .getBytes(StandardCharsets.UTF_8); + + int timeoutInMillis = 5000; // 5 seconds + HttpEventPublisher publisher = + HttpEventPublisher.newBuilder() + .withUrl("https://localhost:" + String.valueOf(mockServer.getLocalPort())) + .withToken("test-token") + .withDisableCertificateValidation(false) + .withRootCaCertificate(rootCa) + .withMaxElapsedMillis(timeoutInMillis) + .withEnableGzipHttpCompression(true) + .build(); + publisher.execute(SPLUNK_EVENTS); + } + + private void mockServerListening(int statusCode) { + mockServer + .when(HttpRequest.request(EXPECTED_PATH)) + .respond(HttpResponse.response().withStatusCode(statusCode)); + } } diff --git a/sdks/java/io/splunk/src/test/resources/SplunkTestCerts/PrivateKey.pem b/sdks/java/io/splunk/src/test/resources/SplunkTestCerts/PrivateKey.pem new file mode 100644 index 000000000000..abc5cea317fb --- /dev/null +++ b/sdks/java/io/splunk/src/test/resources/SplunkTestCerts/PrivateKey.pem @@ -0,0 +1,28 @@ +-----BEGIN PRIVATE KEY----- +MIIEvgIBADANBgkqhkiG9w0BAQEFAASCBKgwggSkAgEAAoIBAQDKdvgJWBpStv+7 +cXEQLbZKM9JootILsGmFQ7v3mQUDO/r3P5QB461SFoKOvGhmPJwkqSjQK45QDP+j +4Z/Cd/YdWvC8g6AzKzpgPV0IeyKiFvyBChIFknM1pNoitwpBY4STG3aoiFt/wbEJ +JLdCfaBCVgil+DaQaoBdF2mC2ugMNOwth0gEiEPw14PS0U/tLgGzXk5CJWrmcYl1 +qQ4kNtkap2eJ9HV5m/QmB5+sB5iLEbVv5zgj6a575AKugzQD+yqqzf5/vGUlCr5n +pQ7aBSnDDqT/4/o5lQCXPn325eh8oo1M0mz75qRlupDYrfT0V7lwQ/C6+CkuidOu +aN+xevrjAgMBAAECggEBALFJabps7mfdnKNbG6EKFiR1qlo7sOfRayTpge+2i2Ag +porYnlblMgC+e0ZXjqdvjV7AzV8ztKM+LqAnUoisGNPtrP212JLV2IErWoqxoEsF +C8hGtC8y5TVlDCn308AcT5utIcND27NMPSR/hQVxEeLkiSAj8EuXJp3dgWO3IhiE +sZblM4e5/wS105jMrClx9YWTnF6WsrLSAje5Cb51pJhNc/xxqkfg4HEpfCHmwE/o +2v2YX3if4y292QnQsO/Kb3JYSMpezzpBm4KQ8GqEt9RiL5AVpvLF0NuACaWSPDUZ +FELjofl1ZhbkjNu1Z8wGZ/a+Mk8wrgJU0MhS0/JeqZkCgYEA/+OhJ/2lSCL034HT +Ju/U+bzANrDlN4jajPoyjzsuBDBpycDZsf2/MVU1+xj1+pD3blXNP0dfH3DKQaOt +4bzgNzzqWAV/y5M7v6hb2SfJB95b+4eVD7fDPNP2Rad0cn0YiKSaqVEg5NA6qevG +gKAL62+0DunhPO6oOtxRvvQQmJ0CgYEAyo1qi9Wz0tzv7fxrzSEiBf09z9Ndt6gs +bpuTd6wkFDGVlJHZMqucjsYxevvBFTFRxbS2KajkEF9E/rUYX/mjrp/rfxowoy6+ +xY6jLZayaIFQAi74LWfrk81FR5Bj/EQ9pTFv1H0nONFSi1h5XFBqCDZcOoeeLqDX +L0LEOqFcyX8CgYArigCuvEK2LnR53y4dTutu/sW5yImH1HpTSHL32frvbYlicbTY +yzMP4s7Hhay80JO5K1I79RnjUJ6pYn9AjJGd9HhvN6hR7CBbcPsHzPQwqY3/E0ey +n/LRU8NwgJiYrl5RSaijLJGrPR7uMJba5eCBU8VQUE0pv/XR5hDmq8JzJQKBgEZj +xK4ZsudJflvPB8t+gytfqTZq9ruXRvGdQS8qdFNMM/YwhTF1r+9x8soRaTUrMYaq +WU+68J677OczGehogbhyvM0r1dEvsn5HJm/2WcO1hI9tsTNeVODFShknlYeaU23v 
+8zP91j6Jh80DDxHEpER8V6rDbHY50O4MntLdNriTAoGBAMjvuzeaV2mE77YlKmW3 +RoOW8FyFp3RBRsGAhSF9kDCMWbdOA+2zc/dbOjuzqdmb8181kFgzTvsM5NvsTWxR +mijWOpIojSsbxd+mE49bJf1wm6VY8gkmy3Yq0H19yslSp5Fm5itYs7Xi9+lVO8g2 +9cWIwlfwlDt4OVGOnzqFpbSj +-----END PRIVATE KEY----- \ No newline at end of file diff --git a/sdks/java/io/splunk/src/test/resources/SplunkTestCerts/RecognizedCertificate.crt b/sdks/java/io/splunk/src/test/resources/SplunkTestCerts/RecognizedCertificate.crt new file mode 100644 index 000000000000..6ddc14f29d20 --- /dev/null +++ b/sdks/java/io/splunk/src/test/resources/SplunkTestCerts/RecognizedCertificate.crt @@ -0,0 +1,26 @@ +-----BEGIN CERTIFICATE----- +MIIEVDCCAjwCCQCPd0wQvWpFsTANBgkqhkiG9w0BAQsFADBsMQswCQYDVQQGEwJV +UzEXMBUGA1UECAwOTm9ydGggQ2Fyb2xpbmExDzANBgNVBAcMBkR1cmhhbTEPMA0G +A1UECgwGR29vZ2xlMQ4wDAYDVQQLDAVDbG91ZDESMBAGA1UEAwwJbG9jYWxob3N0 +MB4XDTIyMDcyNjE5MjYxMVoXDTM2MDQwMzE5MjYxMVowbDELMAkGA1UEBhMCVVMx +FzAVBgNVBAgMDk5vcnRoIENhcm9saW5hMQ8wDQYDVQQHDAZEdXJoYW0xDzANBgNV +BAoMBkdvb2dsZTEOMAwGA1UECwwFQ2xvdWQxEjAQBgNVBAMMCWxvY2FsaG9zdDCC +ASIwDQYJKoZIhvcNAQEBBQADggEPADCCAQoCggEBAMp2+AlYGlK2/7txcRAttkoz +0mii0guwaYVDu/eZBQM7+vc/lAHjrVIWgo68aGY8nCSpKNArjlAM/6Phn8J39h1a +8LyDoDMrOmA9XQh7IqIW/IEKEgWSczWk2iK3CkFjhJMbdqiIW3/BsQkkt0J9oEJW +CKX4NpBqgF0XaYLa6Aw07C2HSASIQ/DXg9LRT+0uAbNeTkIlauZxiXWpDiQ22Rqn +Z4n0dXmb9CYHn6wHmIsRtW/nOCPprnvkAq6DNAP7KqrN/n+8ZSUKvmelDtoFKcMO +pP/j+jmVAJc+ffbl6HyijUzSbPvmpGW6kNit9PRXuXBD8Lr4KS6J065o37F6+uMC +AwEAATANBgkqhkiG9w0BAQsFAAOCAgEAC4LZWaxUDn0ON+pxOc7NGF/sfc3ZCMFc +o/MSuT0ZiAw+gkhvj0M3wererhOiX9iexoXlNf4RBAmjcobdlOTn7jFO0MOKkCft +xZrYzelltvOuzaYa+iECQKniqNBkfKqH6hYyiV5ASYNWAndiR4YQ9F/5acrFHIId +JHQh8tN6I6BoMdbYUYnoMSUjsuPBF1pYSyt+T5bkpOatXrUFqFL6R1N3S6Rl6Ter +f2kFg27hek6UWwlIqQi16LSDkbDLSIGCojFoABof/rpBwCPIaG1kXHcKU6akfqTZ +2ypQgzwtQx3ehcxyBY1nBH5AzGlb4gIt9fN+Mb8ht/CD6FgMFJznPzylMZ7U+MHP +fVusK8C3d8YonvZBJpxy9sMbgzgjSWhCNp3BZXgeNhHlFJPRB08bCJ245uJhwpzr +1pulxv/Ou9ZlpslaHoZ8MvFNBA8r2owzq0uHpuOzlAe6v+btZt95dstc9wIDpW9G +k5LP9KaAtAQznBjJ/KI47Wj1TmATEMKbD8ACPlclJMMtHWuoyZVQ0hd6DD8zH2+T +DBCwLRSTDTUdX+RsoUJsUNn2tlFJG0f393TmlFLTl5+iZ2o6x0xQsPhdbrhGV6xQ +molurUhHAAbRuV5Xdjsfe6bDBFd9877wxy2X8ISRWWuHj5JroHhMHtr/jM1yQsIz +fV7s3q6Ub0Q= +-----END CERTIFICATE----- \ No newline at end of file diff --git a/sdks/java/io/splunk/src/test/resources/SplunkTestCerts/RootCA.crt b/sdks/java/io/splunk/src/test/resources/SplunkTestCerts/RootCA.crt new file mode 100644 index 000000000000..c19922361a17 --- /dev/null +++ b/sdks/java/io/splunk/src/test/resources/SplunkTestCerts/RootCA.crt @@ -0,0 +1,31 @@ +-----BEGIN CERTIFICATE----- +MIIFVDCCAzwCCQDyGYGOTz5rCzANBgkqhkiG9w0BAQsFADBsMQswCQYDVQQGEwJV +UzEXMBUGA1UECAwOTm9ydGggQ2Fyb2xpbmExDzANBgNVBAcMBkR1cmhhbTEPMA0G +A1UECgwGR29vZ2xlMQ4wDAYDVQQLDAVDbG91ZDESMBAGA1UEAwwJbG9jYWxob3N0 +MB4XDTIyMDcyNjE5MjQ0OVoXDTI1MDUxNTE5MjQ0OVowbDELMAkGA1UEBhMCVVMx +FzAVBgNVBAgMDk5vcnRoIENhcm9saW5hMQ8wDQYDVQQHDAZEdXJoYW0xDzANBgNV +BAoMBkdvb2dsZTEOMAwGA1UECwwFQ2xvdWQxEjAQBgNVBAMMCWxvY2FsaG9zdDCC +AiIwDQYJKoZIhvcNAQEBBQADggIPADCCAgoCggIBAL/Zdqv3XIskLZgLg2q6QK8D +Aenh9bQH0gRG8VYSAnPRW53epntr86udULvb5Y220V6dQdlCZSWv5eCWT8yuO0dr +qi5d8JT+NjjjZ258n6D4UtvW0UbubY5CHAR3Kf+YmWi/hdoL2K/Yi+rgC29+eBIx +VHUCqHSo6eI868zg4yUH5bFBNPxWcSQt4flVvgz4HKuraEKHopfhSyiw8CSf72Ps +zKceJF/OeZbo2lKieICEB3cbXQ7vl6ifd7VP79qfDcNxkzxF/rvzHV9UMX6ba5jU +SjUBWOP1b+nDgDDrEAF87Dn1qrsYlyifhocdyS9kIN2+TWH0dJdoFpcB1CryaFFw +iX5MlQhSQRmVxWRFYQpTlMxaES32F/pkYQeo0DsgL8K/OP0Zs5Aljen9oZ+qkzXb +aDo1f0FV+NoH1KNJbg7PxqHUFWx82G4XrvS2nVvdEwXO1I2ndMWfOpvtsHGduBMF 
+/aDVHIbAG4HRF/WYz+GBmjq5JP7xR5swZKV1dBxAgbkZEQ1Ah1pujDYvfVMlspJd +66APdiDaWQjKC344b7kNa5mQhbFMn3sZ2RWnibSQJmLmrH/eeMQj7Y6FEgjTW8Jp +qHLsQ1etsEu0IGIZA0pLG8uUEyKN6iSWZEXVorGOASvUtnPzXAjVd5COIWgrAn8G +wWf31h2hJyMa4BLi0xwpAgMBAAEwDQYJKoZIhvcNAQELBQADggIBADtl0gf96fOx +4fekHKWnBDZG5RXs6QgYpvx0oxQUv6eugNsgzVZsTfFYbVRptZvt27GzgrhWGVrA +mYf8xaPmc9Pkd9HW9y+PPpqJbIhky0+NdOEZ4yfGSzz/wg+mJX9hNTbIum6oYkfO +X2CnDTm/DtIwXPrQ8ReZr+HnoulYHEhMqT5l2x2xy8pArGhR6HfRRzT/6ROISJeS +K5b50zHY9EdMvRBztfupuVh+jYWZldhnom3xoZ+WLIcJ/YB2/Hq73xD1k2p59qNM +2z8dv5eYefDgljhHwy7QngUW5kWXoJKi6hf5fNkBO7uwM3SCgL4eU3W7t1F9E3OM +a7N5qYxzAbUo0jsN5NwIl0f6DfRRMexOvogsQUud9KMl/v4n5JOan0GASJUHLuTH +jj/gpGAsmvYYT+WKyNsPgBhWlrsxqzsBhFNeRig0tSAgewAo8mqQ3E8NLxfQUIOV +KmiQiM1fD1biu26tFuBO9z5yzh1bovY/yQFMwrYBVUQFPg3Eyc8uCqRYimOtYpaf +Jw9/G/FC9ulaU1KfMcTl0W8e4cvWXMTIRMXYSPCCRwAMftcPfayjY3RLW2A2hukf +vVP57lmWE9KzjH2ogN6F/Lo5nqvC9Eo92/sWYoPuHMT+u6TB7/bciwJ/qekOE1p4 +2gqt2zYedeYzkuYud3eIhU+TDYtuEX5b +-----END CERTIFICATE----- \ No newline at end of file diff --git a/sdks/java/io/splunk/src/test/resources/SplunkTestCerts/RootCA_2.crt b/sdks/java/io/splunk/src/test/resources/SplunkTestCerts/RootCA_2.crt new file mode 100644 index 000000000000..66f482117ede --- /dev/null +++ b/sdks/java/io/splunk/src/test/resources/SplunkTestCerts/RootCA_2.crt @@ -0,0 +1,13 @@ +-----BEGIN CERTIFICATE----- +MIIB5jCCAW2gAwIBAgITd/Dv+jhoas7ztyVnkyujXhT9UDAKBggqhkjOPQQDAzAj +MREwDwYDVQQKEwhUZXN0IExMQzEOMAwGA1UEAxMFbXktY2EwHhcNMjExMTAzMTgy +NTIwWhcNMzExMTA0MDQzMzAwWjAjMREwDwYDVQQKEwhUZXN0IExMQzEOMAwGA1UE +AxMFbXktY2EwdjAQBgcqhkjOPQIBBgUrgQQAIgNiAATKhcCiEv0KAF0NJdwMobH6 +yrc91Vmu6zCLjpBSnuzVqSQ7scwhsHQLvc9JtPNaFU6cFS5537x4nlEkBmFceUkZ +WWuZ8n/M40H6tRSJSLuYLLrs56ncnaeYoAgkcN93b4ajYzBhMA4GA1UdDwEB/wQE +AwIBBjAPBgNVHRMBAf8EBTADAQH/MB0GA1UdDgQWBBS5oWwcIzvRinKf/kTeFIXX +mJsWITAfBgNVHSMEGDAWgBS5oWwcIzvRinKf/kTeFIXXmJsWITAKBggqhkjOPQQD +AwNnADBkAjAiMPONNFoFy0kIrSk0AwNmjQEg6l4Q3zbcCXczjH/EoLLFyWLNnp25 +g0LTaaGgk1gCMF+HX0fTPVfFTvvHbdvzaTvWwB0OpgCdq0ugrps/KIebaq2kk/WN +xkLaTxN3//1/RQ== +-----END CERTIFICATE----- \ No newline at end of file diff --git a/sdks/java/io/splunk/src/test/resources/SplunkTestCerts/RootCA_PrivateKey.pem b/sdks/java/io/splunk/src/test/resources/SplunkTestCerts/RootCA_PrivateKey.pem new file mode 100644 index 000000000000..a2b5847ddcde --- /dev/null +++ b/sdks/java/io/splunk/src/test/resources/SplunkTestCerts/RootCA_PrivateKey.pem @@ -0,0 +1,52 @@ +-----BEGIN PRIVATE KEY----- +MIIJRAIBADANBgkqhkiG9w0BAQEFAASCCS4wggkqAgEAAoICAQC/2Xar91yLJC2Y +C4NqukCvAwHp4fW0B9IERvFWEgJz0Vud3qZ7a/OrnVC72+WNttFenUHZQmUlr+Xg +lk/MrjtHa6ouXfCU/jY442dufJ+g+FLb1tFG7m2OQhwEdyn/mJlov4XaC9iv2Ivq +4AtvfngSMVR1Aqh0qOniPOvM4OMlB+WxQTT8VnEkLeH5Vb4M+Byrq2hCh6KX4Uso +sPAkn+9j7MynHiRfznmW6NpSoniAhAd3G10O75eon3e1T+/anw3DcZM8Rf678x1f +VDF+m2uY1Eo1AVjj9W/pw4Aw6xABfOw59aq7GJcon4aHHckvZCDdvk1h9HSXaBaX +AdQq8mhRcIl+TJUIUkEZlcVkRWEKU5TMWhEt9hf6ZGEHqNA7IC/Cvzj9GbOQJY3p +/aGfqpM122g6NX9BVfjaB9SjSW4Oz8ah1BVsfNhuF670tp1b3RMFztSNp3TFnzqb +7bBxnbgTBf2g1RyGwBuB0Rf1mM/hgZo6uST+8UebMGSldXQcQIG5GRENQIdabow2 +L31TJbKSXeugD3Yg2lkIygt+OG+5DWuZkIWxTJ97GdkVp4m0kCZi5qx/3njEI+2O +hRII01vCaahy7ENXrbBLtCBiGQNKSxvLlBMijeoklmRF1aKxjgEr1LZz81wI1XeQ +jiFoKwJ/BsFn99YdoScjGuAS4tMcKQIDAQABAoICAB1LQFKbz7azTH716xgl3nCa +vfUPeqwFsazThFBHKbazlhCyCau43RksSUKWHiQYcTnIO2DIQZeSl0BG02KGjCio +qPCxiWXGt1LSbl2xi9JReJ123Le++l2JfKu14mTT0UDsVazouCqJnzu7ACQDJKRq +geHoCP7fN+9CrCK5iBWEci8xrLyHGnmSw/mFfSKP1BjmcGIQQeR9EzPgaJq/DRet +9cXi1V0Hsws2/Pc3Nb0x683lEL2SGg82Yln+Hbq9JKXeNsQyT+Y1BhwjR/d0Febd 
+K4OSdBdCx9bi8jUF/4iqoYtsFqjA5Xvfd/QzuR+SY25Ye4pkgFUKIMDoF5SDNShO +Mtk9zQzcfy0lI+S8QkCHNHRltqXet5GbyPRBe5NOIn9homMjWbLSZ1szJSrxS30L +MBNBYpAvyS465Xk4ZH7+ycJPHbhlTH8kY6pmNT1Z2z6/TWX3gnKs5sX83lRro0qi ++No4fJF1H22ydauR+Olpwu25EN2qaxD/UuBsEbJWDRNrlWB+j/z8kQY6JHsmT9nG +hLEARTEoJwi8OPp405mUVDJYGxu0c6hNIRTzYl1/RiSncZ/wnzN6b7QGu7gSjoMs +xHa4Fy3+oXRtnFViCBG4aAGNHHuvoJSqudyoxGH843xDkHAaFDQ3vr12obxiq1/y +y4QzfwkLtiObmVMPQfsBAoIBAQD36oxGuknSvym0SL8kZrY2V780GHrI91WMtS+4 +iT5fqmM66EstZfVSs/xcIqF/mrrLSXomfypA6n7sZYn5m8tzRpmVkxoYz8sxb5eS +3OOxgsmwUPn9l6wA7cEQTt4jiXjzllaZCsvDTwfnMnV0m4ac3zAEZGIr6q644ZwY +cItApa7AvzHPNAtzoDpIpRVBLE85IbSB+RYFLsnoJWC3bgPs/osnB+vUGcf99gFN +x84/fpsGBrE8d0yp5svNFrhCyBwdS7K+6QITv71H+MHDez4sAsTR/DQZWzhvx7mu +KipxQBCLJscqs4LpddtJu29EpNo98AToqcz4/MTmszFUWZmhAoIBAQDGGuerzEyE +D2pIdC44J9Go0JftR8jpZoae94o1gbquT8Py4w5VEOThlP6hTr4mTby3A7/xyk1O +ruITiWTvlrlDDV7rZGCwzu/7Lbe0GGF59RBiqTOAt5EARKVbL45CZRWthVGEUYiZ +QwCu1/eEQMCkTmHdSPI7w4XY2t3MZTPwMVg5wRyaELrYfO/56SttfbCmtVAxkSoc +WufYd2ddSs+OBoGMoSf/qcbWk0lydi4BXDAB2Jp28oy+ogg2zHQKQXyeLt6bDS02 +STJgXhc69zU4QIAq8x8pKTnjD67AK1PJps3mdr2uLEn1cpIotqHVaTdxNoNX8rdF +7O/tFZwwR8WJAoIBAQCyRiifR6CEetCrgtBoha+rvkeRV7UbxLfxGe15/r8qneUD +XD2LJNFXqnPjcUe+8e429txuyG0DB11D8vRX2Q9hEriolYJjqzELmJpfkpXtdQZB +0skh8apPdKiraHtXBKlESKx3GwtRpAgj2eYadyhCsD/gOKtbt0PzUNElxfBtCXdz +xUk1HdDKUcL3sDZrikhh5fneqNaL3Yx1ckNtRCBwkM5Rv+F2wR9OYVOosfB/OjSS +Dcfvmj43wu50yUyxQSLuchvUKsxClVOwaJI1Vu9rSIZuFbUFMtKPlwjP1CR3EcGT +vsvjyfhul0CccbtemkkR1wJAqLHrriCNYPgtFs1hAoIBAQCH5iETS50Z4vFILtse +DsLXCfGPBvWel4S2PJ4FQq9rsLB9SKGmXWaGEY3z3m63HBDfg6UDG4KY/YN9X31s +lnsUsnFxDXT/FZavOpeQ5kDIRwMsi2IXZNYF1xGQUjlG9s0+MfzPxpbsfHhVeTYE +9d6xWEvuX4I82U5SiyIoeyx8E32wcPdMAToMFPkS+Y+fFuA+HJecyTaYKQxvBMpV +x2JGzdPFQzCLRE5xGK0D5mp86F7OhWbBPnaqt8Dmxq678lyorwJqX0mqud/jF+jb +vIY3xpel3w57UBqz1yhMD+z0oocRGFfayesag7QcVd0C56Du+zRy+sAbKgUVIQP6 +YuHxAoIBAQDODDe8lu9FY+oiFmbZkoRjSfa/FdgXez/nrBCTSonefvy5GcyAmwIE +8+3gCP0oZ5c9na5GlOFE2srbEjuu9zGJVUp445Qln+sEdPvSAvZLTZv7WrvCFATs +Oatf1KmuMtqpbn+7pSr4wtor9IlmQBho8KNcxgORF6kkLSR9zfEVlMBIjl85QoRf +kJSnWyyz0xek45MxdwLu7mmaRnNBplqh52rMRicVxZGxV+IRwpPZ7WO5CfcHSu1C +r0V3H6WrCcFtWmJ8g2EnJQU/KZPnYSMvM2g8gvEiphJcXvnPUjDkW8xdI4JF9sSf +IZLcLY3f6dPCU8P/SQQec+Q+XmZTJS6M +-----END PRIVATE KEY----- \ No newline at end of file diff --git a/sdks/java/io/splunk/src/test/resources/SplunkTestCerts/UnrecognizedCertificate.crt b/sdks/java/io/splunk/src/test/resources/SplunkTestCerts/UnrecognizedCertificate.crt new file mode 100644 index 000000000000..784c5c73111b --- /dev/null +++ b/sdks/java/io/splunk/src/test/resources/SplunkTestCerts/UnrecognizedCertificate.crt @@ -0,0 +1,21 @@ +-----BEGIN CERTIFICATE----- +MIIDeTCCAmGgAwIBAgIEUylz4zANBgkqhkiG9w0BAQsFADBtMQswCQYDVQQGEwJV +UzEXMBUGA1UECBMOTm9ydGggQ2Fyb2xpbmExDzANBgNVBAcTBkR1cmhhbTEPMA0G +A1UEChMGR29vZ2xlMQ8wDQYDVQQLEwZHb29nbGUxEjAQBgNVBAMTCWxvY2FsaG9z +dDAeFw0yMTA5MDgyMzI5MTFaFw0yMjA5MDMyMzI5MTFaMG0xCzAJBgNVBAYTAlVT +MRcwFQYDVQQIEw5Ob3J0aCBDYXJvbGluYTEPMA0GA1UEBxMGRHVyaGFtMQ8wDQYD +VQQKEwZHb29nbGUxDzANBgNVBAsTBkdvb2dsZTESMBAGA1UEAxMJbG9jYWxob3N0 +MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEAqyrHH61bmZzHZii7kRZD +azZyPPgD+3W6+Nz3rb+TLw3epDj/qgGSKNXbemAOM30mpTxotECkIVKtnclDZxuH +eGqc08wYYoOVaj3LBSekA5AryP49d3eVrdYwUaaefSCZFS1JdW0JsjcR1HrMN5Ka +sxFfMfgASKNe6MZO2irNU/sofERVaXzngSCMlDuXpOJEw8Nrmm6LbcJxgweh55aq +t9RMnrSVKH5SUrUiARdZuV543imsagLDUj/6cBFj91hT8/GFxjWtw3BijubNkVFk +9XBo7F6RfqyCh0CMiOrTznWMFgrEwYLqeZRYlvYNa6HGbO++HM3NuwHDpDX+ICJl +3wIDAQABoyEwHzAdBgNVHQ4EFgQUSFJ/wKq2CdpWxyzy7s+/vhLqtdMwDQYJKoZI 
+hvcNAQELBQADggEBAFBS7bj4eTRqWQoom+z3pKKJ6/1FZ1zfx5XIUmu5uG+aKfVu +ZbsJ9Yv7Q6Bvczb50KJLrAitk5bYitobcFof9fMUh7s3e+jUM18vjNDGW2Ckso+i +7s0G6f/FVloOi8rHmtqplwjFEAuNsRJhm2AHywiLAWP5Ww8tRHujZD50tXnaFdYw +LTR7v2A8TARp66W5I3s+IQJw1Y1WYUpnmG63LQcO+783ahRbzOD0sTi1YZ1tBgB0 +M/lePRiW+Efqc9zZr64U8YaYVezR504tQGRPQzb/rT8lQ865NpLtZNujX6Fc19dQ +WNL3V51FtAY+rwiideyKq4hP6AEHdgKjvyp5mt4= +-----END CERTIFICATE----- \ No newline at end of file diff --git a/sdks/python/apache_beam/examples/cookbook/custom_ptransform.py b/sdks/python/apache_beam/examples/cookbook/custom_ptransform.py index 4beab18aab1e..a922216a5220 100644 --- a/sdks/python/apache_beam/examples/cookbook/custom_ptransform.py +++ b/sdks/python/apache_beam/examples/cookbook/custom_ptransform.py @@ -118,11 +118,12 @@ def get_args(argv): def run(argv=None): known_args, pipeline_args = get_args(argv) - options = PipelineOptions(pipeline_args) - run_count1(known_args, options) - run_count2(known_args, options) - run_count3(known_args, options) + # pipeline initialization may modify PipelineOptions object. + # Create instances for each. + run_count1(known_args, PipelineOptions(pipeline_args)) + run_count2(known_args, PipelineOptions(pipeline_args)) + run_count3(known_args, PipelineOptions(pipeline_args)) if __name__ == '__main__': diff --git a/sdks/python/apache_beam/io/filebasedio_perf_test.py b/sdks/python/apache_beam/io/filebasedio_perf_test.py new file mode 100644 index 000000000000..7d5b673098d5 --- /dev/null +++ b/sdks/python/apache_beam/io/filebasedio_perf_test.py @@ -0,0 +1,188 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +"""Performance tests for file based io connectors.""" + +import logging +import sys +import uuid +from typing import Tuple + +import apache_beam as beam +from apache_beam import typehints +from apache_beam.io.filesystems import FileSystems +from apache_beam.io.iobase import Read +from apache_beam.io.textio import ReadFromText +from apache_beam.io.textio import WriteToText +from apache_beam.testing.load_tests.load_test import LoadTest +from apache_beam.testing.load_tests.load_test import LoadTestOptions +from apache_beam.testing.load_tests.load_test_metrics_utils import CountMessages +from apache_beam.testing.load_tests.load_test_metrics_utils import MeasureTime +from apache_beam.testing.synthetic_pipeline import SyntheticSource +from apache_beam.testing.test_pipeline import TestPipeline +from apache_beam.testing.util import assert_that +from apache_beam.testing.util import equal_to + +WRITE_NAMESPACE = 'write' +READ_NAMESPACE = 'read' + +_LOGGER = logging.getLogger(__name__) + + +class FileBasedIOTestOptions(LoadTestOptions): + @classmethod + def _add_argparse_args(cls, parser): + parser.add_argument( + '--test_class', required=True, help='Test class to run.') + parser.add_argument( + '--filename_prefix', + required=True, + help='Destination prefix for files generated by the test.') + parser.add_argument( + '--compression_type', + default='auto', + help='File compression type for writing and reading test files.') + parser.add_argument( + '--number_of_shards', + type=int, + default=0, + help='Number of files this test will create during the write phase.') + parser.add_argument( + '--dataset_size', + type=int, + help='Size of data saved on the target filesystem (bytes).') + + +@typehints.with_output_types(bytes) +@typehints.with_input_types(Tuple[bytes, bytes]) +class SyntheticRecordToStrFn(beam.DoFn): + """ + A DoFn that convert key-value bytes from synthetic source to string record. + + It uses base64 to convert random bytes emitted from the synthetic source. + Therefore, every 3 bytes give 4 bytes long ascii characters. 
+ + Output length = 4(ceil[len(key)/3] + ceil[len(value)/3]) + 1 + """ + def process(self, element): + import base64 + yield base64.b64encode(element[0]) + b' ' + base64.b64encode(element[1]) + + +class CreateFolderFn(beam.DoFn): + """Create folder at pipeline runtime.""" + def __init__(self, folder): + self.folder = folder + + def process(self, element): + from apache_beam.io.filesystems import FileSystems # pylint: disable=reimported + filesystem = FileSystems.get_filesystem(self.folder) + if filesystem.has_dirs() and not filesystem.exists(self.folder): + filesystem.mkdirs(self.folder) + + +class TextIOPerfTest: + def run(self): + write_test = _TextIOWritePerfTest(need_cleanup=False) + read_test = _TextIOReadPerfTest(input_folder=write_test.output_folder) + write_test.run() + read_test.run() + + +class _TextIOWritePerfTest(LoadTest): + def __init__(self, need_cleanup=True): + super().__init__(WRITE_NAMESPACE) + self.need_cleanup = need_cleanup + self.test_options = self.pipeline.get_pipeline_options().view_as( + FileBasedIOTestOptions) + self.output_folder = FileSystems.join( + self.test_options.filename_prefix, str(uuid.uuid4())) + + def test(self): + # first makedir if needed + _ = ( + self.pipeline + | beam.Impulse() + | beam.ParDo(CreateFolderFn(self.output_folder))) + + # write to text + _ = ( + self.pipeline + | 'Produce rows' >> Read( + SyntheticSource(self.parse_synthetic_source_options())) + | 'Count records' >> beam.ParDo(CountMessages(self.metrics_namespace)) + | 'Format' >> beam.ParDo(SyntheticRecordToStrFn()) + | 'Measure time' >> beam.ParDo(MeasureTime(self.metrics_namespace)) + | 'Write Text' >> WriteToText( + file_path_prefix=FileSystems.join(self.output_folder, 'test'), + compression_type=self.test_options.compression_type, + num_shards=self.test_options.number_of_shards)) + + def cleanup(self): + if not self.need_cleanup: + return + try: + FileSystems.delete([self.output_folder]) + except IOError: + # may not have delete permission, just raise a warning + _LOGGER.warning( + 'Unable to delete file %s during cleanup.', self.output_folder) + + +class _TextIOReadPerfTest(LoadTest): + def __init__(self, input_folder): + super().__init__(READ_NAMESPACE) + self.test_options = self.pipeline.get_pipeline_options().view_as( + FileBasedIOTestOptions) + self.input_folder = input_folder + + def test(self): + output = ( + self.pipeline + | 'Read from text' >> + ReadFromText(file_pattern=FileSystems.join(self.input_folder, '*')) + | 'Count records' >> beam.ParDo(CountMessages(self.metrics_namespace)) + | 'Measure time' >> beam.ParDo(MeasureTime(self.metrics_namespace)) + | 'Count' >> beam.combiners.Count.Globally()) + assert_that(output, equal_to([self.input_options['num_records']])) + + def cleanup(self): + try: + #FileSystems.delete([self.input_folder]) + pass + except IOError: + # may not have delete permission, just raise a warning + _LOGGER.warning( + 'Unable to delete file %s during cleanup.', self.input_folder) + + +if __name__ == '__main__': + logging.basicConfig(level=logging.INFO) + + test_options = TestPipeline().get_pipeline_options().view_as( + FileBasedIOTestOptions) + supported_test_classes = list( + filter( + lambda s: s.endswith('PerfTest') and not s.startswith('_'), + dir(sys.modules[__name__]))) + + if test_options.test_class not in supported_test_classes: + raise RuntimeError( + f'Test {test_options.test_class} not found. 
' + f'Supported tests are {supported_test_classes}') + + getattr(sys.modules[__name__], test_options.test_class)().run() diff --git a/sdks/python/apache_beam/ml/inference/pytorch_inference.py b/sdks/python/apache_beam/ml/inference/pytorch_inference.py index 46938ad619d8..5428e8bf4cac 100644 --- a/sdks/python/apache_beam/ml/inference/pytorch_inference.py +++ b/sdks/python/apache_beam/ml/inference/pytorch_inference.py @@ -38,6 +38,18 @@ 'PytorchModelHandlerKeyedTensor', ] +TensorInferenceFn = Callable[ + [Sequence[torch.Tensor], torch.nn.Module, str, Optional[Dict[str, Any]]], + Iterable[PredictionResult]] + +KeyedTensorInferenceFn = Callable[[ + Sequence[Dict[str, torch.Tensor]], + torch.nn.Module, + str, + Optional[Dict[str, Any]] ], + Iterable[PredictionResult]] + def _load_model( model_class: torch.nn.Module, state_dict_path, device, **model_params): @@ -100,6 +112,46 @@ def _convert_to_result( return [PredictionResult(x, y) for x, y in zip(batch, predictions)] +def default_tensor_inference_fn( + batch: Sequence[torch.Tensor], + model: torch.nn.Module, + device: str, + inference_args: Optional[Dict[str, + Any]] = None) -> Iterable[PredictionResult]: + # torch.no_grad() mitigates GPU memory issues + # https://github.com/apache/beam/issues/22811 + with torch.no_grad(): + batched_tensors = torch.stack(batch) + batched_tensors = _convert_to_device(batched_tensors, device) + predictions = model(batched_tensors, **inference_args) + return _convert_to_result(batch, predictions) + + +def make_tensor_model_fn(model_fn: str) -> TensorInferenceFn: + """ + Produces a TensorInferenceFn that uses a method of the model other than + the forward() method. + + Args: + model_fn: A string name of the method to be used. This is accessed through + getattr(model, model_fn) + """ + def attr_fn( + batch: Sequence[torch.Tensor], + model: torch.nn.Module, + device: str, + inference_args: Optional[Dict[str, Any]] = None + ) -> Iterable[PredictionResult]: + with torch.no_grad(): + batched_tensors = torch.stack(batch) + batched_tensors = _convert_to_device(batched_tensors, device) + pred_fn = getattr(model, model_fn) + predictions = pred_fn(batched_tensors, **inference_args) + return _convert_to_result(batch, predictions) + + return attr_fn + + class PytorchModelHandlerTensor(ModelHandler[torch.Tensor, PredictionResult, torch.nn.Module]): @@ -108,7 +160,9 @@ def __init__( state_dict_path: str, model_class: Callable[..., torch.nn.Module], model_params: Dict[str, Any], - device: str = 'CPU'): + device: str = 'CPU', + *, + inference_fn: TensorInferenceFn = default_tensor_inference_fn): """Implementation of the ModelHandler interface for PyTorch. Example Usage:: @@ -127,6 +181,8 @@ def __init__( device: the device on which you wish to run the model. If ``device = GPU`` then a GPU device will be used if it is available. Otherwise, it will be CPU. + inference_fn: the inference function to use during RunInference. + default=default_tensor_inference_fn **Supported Versions:** RunInference APIs in Apache Beam have been tested with PyTorch 1.9 and 1.10.
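As a rough usage sketch of the `inference_fn` hook introduced above (not part of the patch itself): the `TinyRegression` module, the temporary state-dict path, and the example tensors below are placeholders, while `PytorchModelHandlerTensor`, `make_tensor_model_fn`, and `RunInference` are the pieces this change adds or builds on.

```python
import tempfile

import numpy as np
import torch

import apache_beam as beam
from apache_beam.ml.inference.base import RunInference
from apache_beam.ml.inference.pytorch_inference import PytorchModelHandlerTensor
from apache_beam.ml.inference.pytorch_inference import make_tensor_model_fn


class TinyRegression(torch.nn.Module):
  """Placeholder module with a generate() method, modeled on the test model."""
  def __init__(self):
    super().__init__()
    self.linear = torch.nn.Linear(1, 1)

  def forward(self, x):
    return self.linear(x)

  def generate(self, x):
    # Alternative entry point; used instead of forward() when the handler is
    # built with make_tensor_model_fn('generate').
    return self.linear(x) + 0.5


# Save a state dict so the handler has something to load (throwaway path).
state_dict_path = tempfile.NamedTemporaryFile(suffix='.pth', delete=False).name
torch.save(TinyRegression().state_dict(), state_dict_path)

model_handler = PytorchModelHandlerTensor(
    state_dict_path=state_dict_path,
    model_class=TinyRegression,
    model_params={},
    # New keyword-only argument added in this change:
    inference_fn=make_tensor_model_fn('generate'))

examples = [torch.from_numpy(np.array([v], dtype='float32')) for v in (1, 5, -3)]

with beam.Pipeline() as p:
  _ = (
      p
      | beam.Create(examples)
      | RunInference(model_handler)
      | beam.Map(print))
```

Passing a plain callable with the `(batch, model, device, inference_args)` signature, as `custom_tensor_inference_fn` does in the tests further down, works the same way.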
@@ -140,6 +196,7 @@ def __init__( self._device = torch.device('cpu') self._model_class = model_class self._model_params = model_params + self._inference_fn = inference_fn def load_model(self) -> torch.nn.Module: """Loads and initializes a Pytorch model for processing.""" @@ -179,13 +236,7 @@ def run_inference( """ inference_args = {} if not inference_args else inference_args - # torch.no_grad() mitigates GPU memory issues - # https://github.com/apache/beam/issues/22811 - with torch.no_grad(): - batched_tensors = torch.stack(batch) - batched_tensors = _convert_to_device(batched_tensors, self._device) - predictions = model(batched_tensors, **inference_args) - return _convert_to_result(batch, predictions) + return self._inference_fn(batch, model, self._device, inference_args) def get_num_bytes(self, batch: Sequence[torch.Tensor]) -> int: """ @@ -205,6 +256,69 @@ def validate_inference_args(self, inference_args: Optional[Dict[str, Any]]): pass +def default_keyed_tensor_inference_fn( + batch: Sequence[Dict[str, torch.Tensor]], + model: torch.nn.Module, + device: str, + inference_args: Optional[Dict[str, + Any]] = None) -> Iterable[PredictionResult]: + # If elements in `batch` are provided as a dictionaries from key to Tensors, + # then iterate through the batch list, and group Tensors to the same key + key_to_tensor_list = defaultdict(list) + + # torch.no_grad() mitigates GPU memory issues + # https://github.com/apache/beam/issues/22811 + with torch.no_grad(): + for example in batch: + for key, tensor in example.items(): + key_to_tensor_list[key].append(tensor) + key_to_batched_tensors = {} + for key in key_to_tensor_list: + batched_tensors = torch.stack(key_to_tensor_list[key]) + batched_tensors = _convert_to_device(batched_tensors, device) + key_to_batched_tensors[key] = batched_tensors + predictions = model(**key_to_batched_tensors, **inference_args) + + return _convert_to_result(batch, predictions) + + +def make_keyed_tensor_model_fn(model_fn: str) -> KeyedTensorInferenceFn: + """ + Produces a KeyedTensorInferenceFn that uses a method of the model other that + the forward() method. + + Args: + model_fn: A string name of the method to be used. 
This is accessed through + getattr(model, model_fn) + """ + def attr_fn( + batch: Sequence[torch.Tensor], + model: torch.nn.Module, + device: str, + inference_args: Optional[Dict[str, Any]] = None + ) -> Iterable[PredictionResult]: + # If elements in `batch` are provided as a dictionaries from key to Tensors, + # then iterate through the batch list, and group Tensors to the same key + key_to_tensor_list = defaultdict(list) + + # torch.no_grad() mitigates GPU memory issues + # https://github.com/apache/beam/issues/22811 + with torch.no_grad(): + for example in batch: + for key, tensor in example.items(): + key_to_tensor_list[key].append(tensor) + key_to_batched_tensors = {} + for key in key_to_tensor_list: + batched_tensors = torch.stack(key_to_tensor_list[key]) + batched_tensors = _convert_to_device(batched_tensors, device) + key_to_batched_tensors[key] = batched_tensors + pred_fn = getattr(model, model_fn) + predictions = pred_fn(**key_to_batched_tensors, **inference_args) + return _convert_to_result(batch, predictions) + + return attr_fn + + @experimental(extra_message="No backwards-compatibility guarantees.") class PytorchModelHandlerKeyedTensor(ModelHandler[Dict[str, torch.Tensor], PredictionResult, @@ -214,7 +328,9 @@ def __init__( state_dict_path: str, model_class: Callable[..., torch.nn.Module], model_params: Dict[str, Any], - device: str = 'CPU'): + device: str = 'CPU', + *, + inference_fn: KeyedTensorInferenceFn = default_keyed_tensor_inference_fn): """Implementation of the ModelHandler interface for PyTorch. Example Usage:: @@ -237,6 +353,8 @@ def __init__( device: the device on which you wish to run the model. If ``device = GPU`` then a GPU device will be used if it is available. Otherwise, it will be CPU. + inference_fn: the function to invoke on run_inference. + default = default_keyed_tensor_inference_fn **Supported Versions:** RunInference APIs in Apache Beam have been tested with PyTorch 1.9 and 1.10. 
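A comparable sketch for the keyed-tensor path (again, not part of the patch): `TwoInputModel` and its state-dict path are hypothetical stand-ins, loosely following the `PytorchLinearRegressionMultipleArgs` test model later in this change, while `PytorchModelHandlerKeyedTensor` and `make_keyed_tensor_model_fn` come from the code above.

```python
import torch

from apache_beam.ml.inference.pytorch_inference import PytorchModelHandlerKeyedTensor
from apache_beam.ml.inference.pytorch_inference import make_keyed_tensor_model_fn


class TwoInputModel(torch.nn.Module):
  """Placeholder module whose forward()/generate() accept keyword tensors."""
  def __init__(self):
    super().__init__()
    self.linear = torch.nn.Linear(1, 1)

  def forward(self, k1, k2):
    return self.linear(k1) + self.linear(k2)

  def generate(self, k1, k2):
    return self.linear(k1) + self.linear(k2) + 0.5


# Each pipeline element is a dict from keyword name to tensor; the handler
# batches the tensors per key and calls model.generate(k1=..., k2=...) because
# of the custom inference_fn.
keyed_handler = PytorchModelHandlerKeyedTensor(
    state_dict_path='two_input_model.pth',  # hypothetical path
    model_class=TwoInputModel,
    model_params={},
    inference_fn=make_keyed_tensor_model_fn('generate'))

# Shape of a single input element for this handler:
example = {
    'k1': torch.tensor([1.0]),
    'k2': torch.tensor([4.0]),
}
```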
@@ -250,6 +368,7 @@ def __init__( self._device = torch.device('cpu') self._model_class = model_class self._model_params = model_params + self._inference_fn = inference_fn def load_model(self) -> torch.nn.Module: """Loads and initializes a Pytorch model for processing.""" @@ -289,24 +408,7 @@ def run_inference( """ inference_args = {} if not inference_args else inference_args - # If elements in `batch` are provided as a dictionaries from key to Tensors, - # then iterate through the batch list, and group Tensors to the same key - key_to_tensor_list = defaultdict(list) - - # torch.no_grad() mitigates GPU memory issues - # https://github.com/apache/beam/issues/22811 - with torch.no_grad(): - for example in batch: - for key, tensor in example.items(): - key_to_tensor_list[key].append(tensor) - key_to_batched_tensors = {} - for key in key_to_tensor_list: - batched_tensors = torch.stack(key_to_tensor_list[key]) - batched_tensors = _convert_to_device(batched_tensors, self._device) - key_to_batched_tensors[key] = batched_tensors - predictions = model(**key_to_batched_tensors, **inference_args) - - return _convert_to_result(batch, predictions) + return self._inference_fn(batch, model, self._device, inference_args) def get_num_bytes(self, batch: Sequence[torch.Tensor]) -> int: """ diff --git a/sdks/python/apache_beam/ml/inference/pytorch_inference_test.py b/sdks/python/apache_beam/ml/inference/pytorch_inference_test.py index 32036f43de86..d6d3a2934555 100644 --- a/sdks/python/apache_beam/ml/inference/pytorch_inference_test.py +++ b/sdks/python/apache_beam/ml/inference/pytorch_inference_test.py @@ -37,6 +37,10 @@ import torch from apache_beam.ml.inference.base import PredictionResult from apache_beam.ml.inference.base import RunInference + from apache_beam.ml.inference.pytorch_inference import default_keyed_tensor_inference_fn + from apache_beam.ml.inference.pytorch_inference import default_tensor_inference_fn + from apache_beam.ml.inference.pytorch_inference import make_keyed_tensor_model_fn + from apache_beam.ml.inference.pytorch_inference import make_tensor_model_fn from apache_beam.ml.inference.pytorch_inference import PytorchModelHandlerTensor from apache_beam.ml.inference.pytorch_inference import PytorchModelHandlerKeyedTensor except ImportError: @@ -97,6 +101,15 @@ for example in KEYED_TORCH_EXAMPLES]).reshape(-1, 1)) ] +KEYED_TORCH_HELPER_PREDICTIONS = [ + PredictionResult(ex, pred) for ex, + pred in zip( + KEYED_TORCH_EXAMPLES, + torch.Tensor([(example['k1'] * 2.0 + 0.5) + + (example['k2'] * 2.0 + 0.5) + 0.5 + for example in KEYED_TORCH_EXAMPLES]).reshape(-1, 1)) +] + KEYED_TORCH_DICT_OUT_PREDICTIONS = [ PredictionResult( p.example, { @@ -106,14 +119,16 @@ class TestPytorchModelHandlerForInferenceOnly(PytorchModelHandlerTensor): - def __init__(self, device): + def __init__(self, device, *, inference_fn=default_tensor_inference_fn): self._device = device + self._inference_fn = inference_fn class TestPytorchModelHandlerKeyedTensorForInferenceOnly( PytorchModelHandlerKeyedTensor): - def __init__(self, device): + def __init__(self, device, *, inference_fn=default_keyed_tensor_inference_fn): self._device = device + self._inference_fn = inference_fn def _compare_prediction_result(x, y): @@ -134,6 +149,16 @@ def _compare_prediction_result(x, y): return torch.equal(x.inference, y.inference) +def custom_tensor_inference_fn(batch, model, device, inference_args): + predictions = [ + PredictionResult(ex, pred) for ex, + pred in zip( + batch, + torch.Tensor([item * 2.0 + 1.5 for item in 
batch]).reshape(-1, 1)) + ] + return predictions + + class PytorchLinearRegression(torch.nn.Module): def __init__(self, input_dim, output_dim): super().__init__() @@ -143,6 +168,10 @@ def forward(self, x): out = self.linear(x) return out + def generate(self, x): + out = self.linear(x) + 0.5 + return out + class PytorchLinearRegressionDict(torch.nn.Module): def __init__(self, input_dim, output_dim): @@ -231,6 +260,33 @@ def test_run_inference_multiple_tensor_features_dict_output(self): for actual, expected in zip(predictions, TWO_FEATURES_DICT_OUT_PREDICTIONS): self.assertEqual(actual, expected) + def test_run_inference_custom(self): + examples = [ + torch.from_numpy(np.array([1], dtype="float32")), + torch.from_numpy(np.array([5], dtype="float32")), + torch.from_numpy(np.array([-3], dtype="float32")), + torch.from_numpy(np.array([10.0], dtype="float32")), + ] + expected_predictions = [ + PredictionResult(ex, pred) for ex, + pred in zip( + examples, + torch.Tensor([example * 2.0 + 1.5 + for example in examples]).reshape(-1, 1)) + ] + + model = PytorchLinearRegression(input_dim=1, output_dim=1) + model.load_state_dict( + OrderedDict([('linear.weight', torch.Tensor([[2.0]])), + ('linear.bias', torch.Tensor([0.5]))])) + model.eval() + + inference_runner = TestPytorchModelHandlerForInferenceOnly( + torch.device('cpu'), inference_fn=custom_tensor_inference_fn) + predictions = inference_runner.run_inference(examples, model) + for actual, expected in zip(predictions, expected_predictions): + self.assertEqual(actual, expected) + def test_run_inference_keyed(self): """ This tests for inputs that are passed as a dictionary from key to tensor @@ -315,6 +371,77 @@ def test_inference_runner_inference_args(self): for actual, expected in zip(predictions, KEYED_TORCH_PREDICTIONS): self.assertEqual(actual, expected) + def test_run_inference_helper(self): + examples = [ + torch.from_numpy(np.array([1], dtype="float32")), + torch.from_numpy(np.array([5], dtype="float32")), + torch.from_numpy(np.array([-3], dtype="float32")), + torch.from_numpy(np.array([10.0], dtype="float32")), + ] + expected_predictions = [ + PredictionResult(ex, pred) for ex, + pred in zip( + examples, + torch.Tensor([example * 2.0 + 1.0 + for example in examples]).reshape(-1, 1)) + ] + + gen_fn = make_tensor_model_fn('generate') + + model = PytorchLinearRegression(input_dim=1, output_dim=1) + model.load_state_dict( + OrderedDict([('linear.weight', torch.Tensor([[2.0]])), + ('linear.bias', torch.Tensor([0.5]))])) + model.eval() + + inference_runner = TestPytorchModelHandlerForInferenceOnly( + torch.device('cpu'), inference_fn=gen_fn) + predictions = inference_runner.run_inference(examples, model) + for actual, expected in zip(predictions, expected_predictions): + self.assertEqual(actual, expected) + + def test_run_inference_keyed_helper(self): + """ + This tests for inputs that are passed as a dictionary from key to tensor + instead of a standard non-keyed tensor example. 
+ + Example: + Typical input format is + input = torch.tensor([1, 2, 3]) + + But Pytorch syntax allows inputs to have the form + input = { + 'k1' : torch.tensor([1, 2, 3]), + 'k2' : torch.tensor([4, 5, 6]) + } + """ + class PytorchLinearRegressionMultipleArgs(torch.nn.Module): + def __init__(self, input_dim, output_dim): + super().__init__() + self.linear = torch.nn.Linear(input_dim, output_dim) + + def forward(self, k1, k2): + out = self.linear(k1) + self.linear(k2) + return out + + def generate(self, k1, k2): + out = self.linear(k1) + self.linear(k2) + 0.5 + return out + + model = PytorchLinearRegressionMultipleArgs(input_dim=1, output_dim=1) + model.load_state_dict( + OrderedDict([('linear.weight', torch.Tensor([[2.0]])), + ('linear.bias', torch.Tensor([0.5]))])) + model.eval() + + gen_fn = make_keyed_tensor_model_fn('generate') + + inference_runner = TestPytorchModelHandlerKeyedTensorForInferenceOnly( + torch.device('cpu'), inference_fn=gen_fn) + predictions = inference_runner.run_inference(KEYED_TORCH_EXAMPLES, model) + for actual, expected in zip(predictions, KEYED_TORCH_HELPER_PREDICTIONS): + self.assertTrue(_compare_prediction_result(actual, expected)) + def test_num_bytes(self): inference_runner = TestPytorchModelHandlerForInferenceOnly( torch.device('cpu')) diff --git a/sdks/python/apache_beam/runners/interactive/README.md b/sdks/python/apache_beam/runners/interactive/README.md index 15c1d6b3e95c..ff6c57a94e61 100644 --- a/sdks/python/apache_beam/runners/interactive/README.md +++ b/sdks/python/apache_beam/runners/interactive/README.md @@ -312,11 +312,13 @@ By default, the caches are kept on the local file system of the machine in You can specify the caching directory as follows ```python -cache_dir = 'some/path/to/dir' -runner = interactive_runner.InteractiveRunner(cache_dir=cache_dir) -p = beam.Pipeline(runner=runner) +ib.options.cache_root = 'some/path/to/dir' ``` +When using an `InteractiveRunner(underlying_runner=...)` that is running remotely +and distributed, a distributed file system such as Cloud Storage +(`ib.options.cache_root = gs://bucket/obj`) is necessary. + #### Caching PCollection on Google Cloud Storage You can choose to cache PCollections on Google Cloud Storage with a few @@ -342,16 +344,11 @@ credential settings. * Make sure you have **read and write access to that bucket** when you specify to use that directory as caching directory. -* ```python - cache_dir = 'gs://bucket-name/dir' - runner = interactive_runner.InteractiveRunner(cache_dir=cache_dir) - p = beam.Pipeline(runner=runner) - ``` - -* Alternatively, you may configure a cache directory to be used by all interactive pipelines through using the `cache_root` option under interactive_beam. If the cache directory is specified this way, no additional parameters are required to be passed in during pipeline instantiation. +* You may configure a cache directory to be used by all pipelines created afterward with + an `InteractiveRunner`. * ```python - ib.options.cache_root = 'gs://bucket-name/dir' + ib.options.cache_root = 'gs://bucket-name/obj' ``` ### Portability across Execution Platforms @@ -376,7 +373,7 @@ You can choose to run Interactive Beam on Flink with the following settings. * Alternatively, if the runtime environment is configured with a Google Cloud project, you can run Interactive Beam with Flink on Cloud Dataproc. To do so, configure the pipeline with a Google Cloud project. 
If using dev versioned Beam built from source code, it is necessary to specify an `environment_config` option to configure a containerized Beam SDK (you can choose a released container or build one yourself). * ```python - ib.options.cache_root = 'gs://bucket-name/dir' + ib.options.cache_root = 'gs://bucket-name/obj' options = PipelineOptions([ # The project can be attained simply from running the following commands: # import google.auth diff --git a/sdks/python/apache_beam/runners/interactive/extensions/apache-beam-jupyterlab-sidepanel/yarn.lock b/sdks/python/apache_beam/runners/interactive/extensions/apache-beam-jupyterlab-sidepanel/yarn.lock index 12383a9f397b..2e7083232ff4 100644 --- a/sdks/python/apache_beam/runners/interactive/extensions/apache-beam-jupyterlab-sidepanel/yarn.lock +++ b/sdks/python/apache_beam/runners/interactive/extensions/apache-beam-jupyterlab-sidepanel/yarn.lock @@ -6302,9 +6302,9 @@ loader-runner@^4.2.0: integrity sha512-92+huvxMvYlMzMt0iIOukcwYBFpkYJdpl2xsZ7LrlayO7E8SOv+JJUEK17B/dJIHAOLMfh2dZZ/Y18WgmGtYNw== loader-utils@^1.0.0: - version "1.4.1" - resolved "https://registry.yarnpkg.com/loader-utils/-/loader-utils-1.4.1.tgz#278ad7006660bccc4d2c0c1578e17c5c78d5c0e0" - integrity sha512-1Qo97Y2oKaU+Ro2xnDMR26g1BwMT29jNbem1EvcujW2jqt+j5COXyscjM7bLQkM9HaxI7pkWeW7gnI072yMI9Q== + version "1.4.2" + resolved "https://registry.yarnpkg.com/loader-utils/-/loader-utils-1.4.2.tgz#29a957f3a63973883eb684f10ffd3d151fec01a3" + integrity sha512-I5d00Pd/jwMD2QCduo657+YM/6L3KZu++pmX9VFncxaxvHcru9jx1lBaFft+r4Mt2jK0Yhp41XlRAihzPxHNCg== dependencies: big.js "^5.2.2" emojis-list "^3.0.0" diff --git a/sdks/python/apache_beam/testing/load_tests/load_test.py b/sdks/python/apache_beam/testing/load_tests/load_test.py index f5917fbfba27..3112c12ab86c 100644 --- a/sdks/python/apache_beam/testing/load_tests/load_test.py +++ b/sdks/python/apache_beam/testing/load_tests/load_test.py @@ -25,6 +25,7 @@ from apache_beam.metrics import MetricsFilter from apache_beam.options.pipeline_options import GoogleCloudOptions from apache_beam.options.pipeline_options import PipelineOptions +from apache_beam.runners.runner import PipelineState from apache_beam.testing.load_tests.load_test_metrics_utils import InfluxDBMetricsPublisherOptions from apache_beam.testing.load_tests.load_test_metrics_utils import MetricsReader from apache_beam.testing.test_pipeline import TestPipeline @@ -148,7 +149,8 @@ def run(self): if not hasattr(self, 'result'): self.result = self.pipeline.run() # Defaults to waiting forever, unless timeout_ms has been set - self.result.wait_until_finish(duration=self.timeout_ms) + state = self.result.wait_until_finish(duration=self.timeout_ms) + assert state != PipelineState.FAILED self._metrics_monitor.publish_metrics(self.result, self.extra_metrics) finally: self.cleanup() diff --git a/sdks/python/apache_beam/testing/load_tests/load_test_metrics_utils.py b/sdks/python/apache_beam/testing/load_tests/load_test_metrics_utils.py index 9c6ef2a935ec..fbca1cb96e9d 100644 --- a/sdks/python/apache_beam/testing/load_tests/load_test_metrics_utils.py +++ b/sdks/python/apache_beam/testing/load_tests/load_test_metrics_utils.py @@ -33,7 +33,6 @@ import logging import time import uuid -from typing import Any from typing import List from typing import Mapping from typing import Optional @@ -185,8 +184,6 @@ class MetricsReader(object): A :class:`MetricsReader` retrieves metrics from pipeline result, prepares it for publishers and setup publishers. 
""" - publishers = [] # type: List[Any] - def __init__( self, project_name=None, @@ -206,6 +203,7 @@ def __init__( filters: MetricFilter to query only filtered metrics """ self._namespace = namespace + self.publishers: List[MetricsPublisher] = [] self.publishers.append(ConsoleMetricsPublisher()) check = project_name and bq_table and bq_dataset and publish_to_bq @@ -385,7 +383,13 @@ def _prepare_runtime_metrics(self, distributions): return runtime_in_s -class ConsoleMetricsPublisher(object): +class MetricsPublisher: + """Base class for metrics publishers.""" + def publish(self, results): + raise NotImplementedError + + +class ConsoleMetricsPublisher(MetricsPublisher): """A :class:`ConsoleMetricsPublisher` publishes collected metrics to console output.""" def publish(self, results): @@ -401,7 +405,7 @@ def publish(self, results): _LOGGER.info("No test results were collected.") -class BigQueryMetricsPublisher(object): +class BigQueryMetricsPublisher(MetricsPublisher): """A :class:`BigQueryMetricsPublisher` publishes collected metrics to BigQuery output.""" def __init__(self, project_name, table, dataset): @@ -484,7 +488,7 @@ def http_auth_enabled(self): return self.user is not None and self.password is not None -class InfluxDBMetricsPublisher(object): +class InfluxDBMetricsPublisher(MetricsPublisher): """Publishes collected metrics to InfluxDB database.""" def __init__( self, diff --git a/sdks/python/apache_beam/testing/synthetic_pipeline.py b/sdks/python/apache_beam/testing/synthetic_pipeline.py index 305e42294867..a520b31cb9fb 100644 --- a/sdks/python/apache_beam/testing/synthetic_pipeline.py +++ b/sdks/python/apache_beam/testing/synthetic_pipeline.py @@ -22,7 +22,7 @@ controlled through arguments. Please see function 'parse_args()' for more details about the arguments. -Shape of the pipeline is primariy controlled through two arguments. Argument +Shape of the pipeline is primarily controlled through two arguments. Argument 'steps' can be used to define a list of steps as a JSON string. Argument 'barrier' describes how these steps are separated from each other. 
Argument 'barrier' can be use to build a pipeline as a series of steps or a tree of diff --git a/sdks/python/apache_beam/typehints/arrow_type_compatibility.py b/sdks/python/apache_beam/typehints/arrow_type_compatibility.py index cad6ac8751ca..c8e425f0e96a 100644 --- a/sdks/python/apache_beam/typehints/arrow_type_compatibility.py +++ b/sdks/python/apache_beam/typehints/arrow_type_compatibility.py @@ -303,13 +303,19 @@ def __init__(self, element_type: RowTypeConstraint): self._arrow_schema = arrow_schema @staticmethod - @BatchConverter.register def from_typehints(element_type, batch_type) -> Optional['PyarrowBatchConverter']: - if isinstance(element_type, RowTypeConstraint) and batch_type == pa.Table: - return PyarrowBatchConverter(element_type) + assert batch_type == pa.Table - return None + if not isinstance(element_type, RowTypeConstraint): + element_type = RowTypeConstraint.from_user_type(element_type) + if element_type is None: + raise TypeError( + "Element type must be compatible with Beam Schemas (" + "https://beam.apache.org/documentation/programming-guide/#schemas) " + "for batch type pa.Table.") + + return PyarrowBatchConverter(element_type) def produce_batch(self, elements): arrays = [ @@ -358,13 +364,11 @@ def __init__(self, element_type: type): self._arrow_type = _arrow_type_from_beam_fieldtype(beam_fieldtype) @staticmethod - @BatchConverter.register def from_typehints(element_type, batch_type) -> Optional['PyarrowArrayBatchConverter']: - if batch_type == pa.Array: - return PyarrowArrayBatchConverter(element_type) + assert batch_type == pa.Array - return None + return PyarrowArrayBatchConverter(element_type) def produce_batch(self, elements): return pa.array(list(elements), type=self._arrow_type) @@ -382,3 +386,16 @@ def get_length(self, batch: pa.Array): def estimate_byte_size(self, batch: pa.Array): return batch.nbytes + + +@BatchConverter.register(name="pyarrow") +def create_pyarrow_batch_converter( + element_type: type, batch_type: type) -> BatchConverter: + if batch_type == pa.Table: + return PyarrowBatchConverter.from_typehints( + element_type=element_type, batch_type=batch_type) + elif batch_type == pa.Array: + return PyarrowArrayBatchConverter.from_typehints( + element_type=element_type, batch_type=batch_type) + + raise TypeError("batch type must be pa.Table or pa.Array") diff --git a/sdks/python/apache_beam/typehints/arrow_type_compatibility_test.py b/sdks/python/apache_beam/typehints/arrow_type_compatibility_test.py index 6a8649cff1ea..e708b151d905 100644 --- a/sdks/python/apache_beam/typehints/arrow_type_compatibility_test.py +++ b/sdks/python/apache_beam/typehints/arrow_type_compatibility_test.py @@ -19,6 +19,7 @@ import logging import unittest +from typing import Any from typing import Optional import pyarrow as pa @@ -192,6 +193,29 @@ def test_hash(self): self.assertEqual(hash(self.create_batch_converter()), hash(self.converter)) +class ArrowBatchConverterErrorsTest(unittest.TestCase): + @parameterized.expand([ + ( + pa.RecordBatch, + row_type.RowTypeConstraint.from_fields([ + ("bar", Optional[float]), # noqa: F821 + ("baz", Optional[str]), # noqa: F821 + ]), + r'batch type must be pa\.Table or pa\.Array', + ), + ( + pa.Table, + Any, + r'Element type must be compatible with Beam Schemas', + ), + ]) + def test_construction_errors( + self, batch_typehint, element_typehint, error_regex): + with self.assertRaisesRegex(TypeError, error_regex): + BatchConverter.from_typehints( + element_type=element_typehint, batch_type=batch_typehint) + + if __name__ == '__main__': 
logging.getLogger().setLevel(logging.INFO) unittest.main() diff --git a/sdks/python/apache_beam/typehints/batch.py b/sdks/python/apache_beam/typehints/batch.py index de6c7fb71572..35351b147d48 100644 --- a/sdks/python/apache_beam/typehints/batch.py +++ b/sdks/python/apache_beam/typehints/batch.py @@ -29,7 +29,7 @@ from typing import Callable from typing import Generic from typing import Iterator -from typing import List +from typing import Mapping from typing import Optional from typing import Sequence from typing import TypeVar @@ -44,7 +44,8 @@ B = TypeVar('B') E = TypeVar('E') -BATCH_CONVERTER_REGISTRY: List[Callable[[type, type], 'BatchConverter']] = [] +BatchConverterConstructor = Callable[[type, type], 'BatchConverter'] +BATCH_CONVERTER_REGISTRY: Mapping[str, BatchConverterConstructor] = {} __all__ = ['BatchConverter'] @@ -72,26 +73,34 @@ def estimate_byte_size(self, batch): raise NotImplementedError @staticmethod - def register( - batch_converter_constructor: Callable[[type, type], 'BatchConverter']): - BATCH_CONVERTER_REGISTRY.append(batch_converter_constructor) - return batch_converter_constructor + def register(*, name: str): + def do_registration( + batch_converter_constructor: Callable[[type, type], 'BatchConverter']): + if name in BATCH_CONVERTER_REGISTRY: + raise AssertionError( + f"Attempted to register two batch converters with name {name}") + + BATCH_CONVERTER_REGISTRY[name] = batch_converter_constructor + return batch_converter_constructor + + return do_registration @staticmethod def from_typehints(*, element_type, batch_type) -> 'BatchConverter': element_type = typehints.normalize(element_type) batch_type = typehints.normalize(batch_type) - for constructor in BATCH_CONVERTER_REGISTRY: - result = constructor(element_type, batch_type) - if result is not None: - return result - - # TODO(https://github.com/apache/beam/issues/21654): Aggregate error - # information from the failed BatchConverter matches instead of this - # generic error. + errors = {} + for name, constructor in BATCH_CONVERTER_REGISTRY.items(): + try: + return constructor(element_type, batch_type) + except TypeError as e: + errors[name] = e.args[0] + + error_summaries = '\n\n'.join( + f"{name}:\n\t{msg}" for name, msg in errors.items()) raise TypeError( - f"Unable to find BatchConverter for element_type {element_type!r} and " - f"batch_type {batch_type!r}") + f"Unable to find BatchConverter for element_type={element_type!r} and " + f"batch_type={batch_type!r}. 
Error summaries:\n\n{error_summaries}") @property def batch_type(self): @@ -124,13 +133,13 @@ def __init__(self, batch_type, element_type): self.element_coder = coders.registry.get_coder(element_type) @staticmethod - @BatchConverter.register + @BatchConverter.register(name="list") def from_typehints(element_type, batch_type): - if (isinstance(batch_type, typehints.ListConstraint) and - batch_type.inner_type == element_type): - return ListBatchConverter(batch_type, element_type) - else: - return None + if (not isinstance(batch_type, typehints.ListConstraint) or + batch_type.inner_type != element_type): + raise TypeError("batch type must be List[T] for element type T") + + return ListBatchConverter(batch_type, element_type) def produce_batch(self, elements): return list(elements) @@ -173,29 +182,35 @@ def __init__( self.partition_dimension = partition_dimension @staticmethod - @BatchConverter.register + @BatchConverter.register(name="numpy") def from_typehints(element_type, batch_type) -> Optional['NumpyBatchConverter']: if not isinstance(element_type, NumpyTypeHint.NumpyTypeConstraint): try: element_type = NumpyArray[element_type, ()] - except TypeError: - # TODO: Is there a better way to detect if element_type is a dtype? - return None + except TypeError as e: + raise TypeError("Element type is not a dtype") from e if not isinstance(batch_type, NumpyTypeHint.NumpyTypeConstraint): if not batch_type == np.ndarray: - # TODO: Include explanation for mismatch? - return None + raise TypeError( + "batch type must be np.ndarray or " + "beam.typehints.batch.NumpyArray[..]") batch_type = NumpyArray[element_type.dtype, (N, )] if not batch_type.dtype == element_type.dtype: - return None - batch_shape = list(batch_type.shape) - partition_dimension = batch_shape.index(N) - batch_shape.pop(partition_dimension) - if not tuple(batch_shape) == element_type.shape: - return None + raise TypeError( + "batch type and element type must have equivalent dtypes " + f"(batch={batch_type.dtype}, element={element_type.dtype})") + + computed_element_shape = list(batch_type.shape) + partition_dimension = computed_element_shape.index(N) + computed_element_shape.pop(partition_dimension) + if not tuple(computed_element_shape) == element_type.shape: + raise TypeError( + "Failed to align batch type's batch dimension with element type. 
" + f"(batch type dimensions: {batch_type.shape}, element type " + f"dimenstions: {element_type.shape}") return NumpyBatchConverter( batch_type, diff --git a/sdks/python/apache_beam/typehints/batch_test.py b/sdks/python/apache_beam/typehints/batch_test.py index a6ea003dd496..3fbad76fce06 100644 --- a/sdks/python/apache_beam/typehints/batch_test.py +++ b/sdks/python/apache_beam/typehints/batch_test.py @@ -149,6 +149,38 @@ def test_hash(self): self.assertEqual(hash(self.create_batch_converter()), hash(self.converter)) +class BatchConverterErrorsTest(unittest.TestCase): + @parameterized.expand([ + ( + typing.List[int], + str, + r'batch type must be List\[T\] for element type T', + ), + ( + np.ndarray, + typing.Any, + r'Element type is not a dtype', + ), + ( + np.array, + np.int64, + ( + r'batch type must be np\.ndarray or ' + r'beam\.typehints\.batch\.NumpyArray\[\.\.\]'), + ), + ( + NumpyArray[np.int64, (3, N, 2)], + NumpyArray[np.int64, (3, 7)], + r'Failed to align batch type\'s batch dimension', + ), + ]) + def test_construction_errors( + self, batch_typehint, element_typehint, error_regex): + with self.assertRaisesRegex(TypeError, error_regex): + BatchConverter.from_typehints( + element_type=element_typehint, batch_type=batch_typehint) + + @contextlib.contextmanager def temp_seed(seed): state = random.getstate() diff --git a/sdks/python/apache_beam/typehints/native_type_compatibility.py b/sdks/python/apache_beam/typehints/native_type_compatibility.py index 2a54450e212a..153b9d4b4588 100644 --- a/sdks/python/apache_beam/typehints/native_type_compatibility.py +++ b/sdks/python/apache_beam/typehints/native_type_compatibility.py @@ -22,6 +22,7 @@ import collections import logging import sys +import types import typing from apache_beam.typehints import typehints @@ -176,6 +177,14 @@ def convert_to_beam_type(typ): Raises: ValueError: The type was malformed. """ + # Convert `int | float` to typing.Union[int, float] + # pipe operator as Union and types.UnionType are introduced + # in Python 3.10. + # GH issue: https://github.com/apache/beam/issues/21972 + if (sys.version_info.major == 3 and + sys.version_info.minor >= 10) and (isinstance(typ, types.UnionType)): + typ = typing.Union[typ] + if isinstance(typ, typing.TypeVar): # This is a special case, as it's not parameterized by types. # Also, identity must be preserved through conversion (i.e. 
the same diff --git a/sdks/python/apache_beam/typehints/pandas_type_compatibility.py b/sdks/python/apache_beam/typehints/pandas_type_compatibility.py index a143f9c4ef37..ca9523f28349 100644 --- a/sdks/python/apache_beam/typehints/pandas_type_compatibility.py +++ b/sdks/python/apache_beam/typehints/pandas_type_compatibility.py @@ -136,7 +136,7 @@ def dtype_to_fieldtype(dtype): return Any -@BatchConverter.register +@BatchConverter.register(name="pandas") def create_pandas_batch_converter( element_type: type, batch_type: type) -> BatchConverter: if batch_type == pd.DataFrame: @@ -146,7 +146,7 @@ def create_pandas_batch_converter( return SeriesBatchConverter.from_typehints( element_type=element_type, batch_type=batch_type) - return None + raise TypeError("batch type must be pd.Series or pd.DataFrame") class DataFrameBatchConverter(BatchConverter): @@ -160,13 +160,15 @@ def __init__( @staticmethod def from_typehints(element_type, batch_type) -> Optional['DataFrameBatchConverter']: - if not batch_type == pd.DataFrame: - return None + assert batch_type == pd.DataFrame if not isinstance(element_type, RowTypeConstraint): element_type = RowTypeConstraint.from_user_type(element_type) if element_type is None: - return None + raise TypeError( + "Element type must be compatible with Beam Schemas (" + "https://beam.apache.org/documentation/programming-guide/#schemas) " + "for batch type pd.DataFrame") index_columns = [ field_name @@ -275,8 +277,7 @@ def unbatch(series): @staticmethod def from_typehints(element_type, batch_type) -> Optional['SeriesBatchConverter']: - if not batch_type == pd.Series: - return None + assert batch_type == pd.Series dtype = dtype_from_typehint(element_type) diff --git a/sdks/python/apache_beam/typehints/pandas_type_compatibility_test.py b/sdks/python/apache_beam/typehints/pandas_type_compatibility_test.py index 0ee9b1178a9b..5a8dc72dd4b9 100644 --- a/sdks/python/apache_beam/typehints/pandas_type_compatibility_test.py +++ b/sdks/python/apache_beam/typehints/pandas_type_compatibility_test.py @@ -18,6 +18,7 @@ """Unit tests for pandas batched type converters.""" import unittest +from typing import Any from typing import Optional import numpy as np @@ -115,7 +116,7 @@ dtype=pd.StringDtype()), }, ]) -class DataFrameBatchConverterTest(unittest.TestCase): +class PandasBatchConverterTest(unittest.TestCase): def create_batch_converter(self): return BatchConverter.from_typehints( element_type=self.element_typehint, batch_type=self.batch_typehint) @@ -208,5 +209,28 @@ def test_hash(self): self.assertEqual(hash(self.create_batch_converter()), hash(self.converter)) +class PandasBatchConverterErrorsTest(unittest.TestCase): + @parameterized.expand([ + ( + Any, + row_type.RowTypeConstraint.from_fields([ + ("bar", Optional[float]), # noqa: F821 + ("baz", Optional[str]), # noqa: F821 + ]), + r'batch type must be pd\.Series or pd\.DataFrame', + ), + ( + pd.DataFrame, + Any, + r'Element type must be compatible with Beam Schemas', + ), + ]) + def test_construction_errors( + self, batch_typehint, element_typehint, error_regex): + with self.assertRaisesRegex(TypeError, error_regex): + BatchConverter.from_typehints( + element_type=element_typehint, batch_type=batch_typehint) + + if __name__ == '__main__': unittest.main() diff --git a/sdks/python/apache_beam/typehints/pytorch_type_compatibility.py b/sdks/python/apache_beam/typehints/pytorch_type_compatibility.py index fbecb6d5105b..f008174bcc03 100644 --- a/sdks/python/apache_beam/typehints/pytorch_type_compatibility.py +++ 
b/sdks/python/apache_beam/typehints/pytorch_type_compatibility.py @@ -37,29 +37,31 @@ def __init__( self.partition_dimension = partition_dimension @staticmethod - @BatchConverter.register + @BatchConverter.register(name="pytorch") def from_typehints(element_type, batch_type) -> Optional['PytorchBatchConverter']: if not isinstance(element_type, PytorchTypeHint.PytorchTypeConstraint): - try: - element_type = PytorchTensor[element_type, ()] - except TypeError: - # TODO: Is there a better way to detect if element_type is a dtype? - return None + element_type = PytorchTensor[element_type, ()] if not isinstance(batch_type, PytorchTypeHint.PytorchTypeConstraint): if not batch_type == torch.Tensor: - # TODO: Include explanation for mismatch? - return None + raise TypeError( + "batch type must be torch.Tensor or " + "beam.typehints.pytorch_type_compatibility.PytorchTensor[..]") batch_type = PytorchTensor[element_type.dtype, (N, )] if not batch_type.dtype == element_type.dtype: - return None - batch_shape = list(batch_type.shape) - partition_dimension = batch_shape.index(N) - batch_shape.pop(partition_dimension) - if not tuple(batch_shape) == element_type.shape: - return None + raise TypeError( + "batch type and element type must have equivalent dtypes " + f"(batch={batch_type.dtype}, element={element_type.dtype})") + computed_element_shape = list(batch_type.shape) + partition_dimension = computed_element_shape.index(N) + computed_element_shape.pop(partition_dimension) + if not tuple(computed_element_shape) == element_type.shape: + raise TypeError( + "Could not align batch type's batch dimension with element type. " + f"(batch type dimensions: {batch_type.shape}, element type " + f"dimenstions: {element_type.shape}") return PytorchBatchConverter( batch_type, diff --git a/sdks/python/apache_beam/typehints/pytorch_type_compatibility_test.py b/sdks/python/apache_beam/typehints/pytorch_type_compatibility_test.py index e851d4679ccb..d1f5c0d271ee 100644 --- a/sdks/python/apache_beam/typehints/pytorch_type_compatibility_test.py +++ b/sdks/python/apache_beam/typehints/pytorch_type_compatibility_test.py @@ -18,6 +18,7 @@ """Unit tests for pytorch_type_compabitility.""" import unittest +from typing import Any import pytest from parameterized import parameterized @@ -134,5 +135,32 @@ def test_hash(self): self.assertEqual(hash(self.create_batch_converter()), hash(self.converter)) +class PytorchBatchConverterErrorsTest(unittest.TestCase): + @parameterized.expand([ + ( + Any, + PytorchTensor[torch.int64, ()], + ( + r'batch type must be torch\.Tensor or ' + r'beam\.typehints\.pytorch_type_compatibility.PytorchTensor'), + ), + ( + PytorchTensor[torch.int64, (3, N, 2)], + PytorchTensor[torch.int64, (3, 7)], + r'Could not align batch type\'s batch dimension', + ), + ( + PytorchTensor[torch.int64, (N, 10)], + PytorchTensor[torch.float32, (10, )], + r'batch type and element type must have equivalent dtypes', + ), + ]) + def test_construction_errors( + self, batch_typehint, element_typehint, error_regex): + with self.assertRaisesRegex(TypeError, error_regex): + BatchConverter.from_typehints( + element_type=element_typehint, batch_type=batch_typehint) + + if __name__ == '__main__': unittest.main() diff --git a/sdks/python/apache_beam/typehints/typehints.py b/sdks/python/apache_beam/typehints/typehints.py index 1695ac2bf7b2..487fb78d26bb 100644 --- a/sdks/python/apache_beam/typehints/typehints.py +++ b/sdks/python/apache_beam/typehints/typehints.py @@ -67,6 +67,8 @@ import copy import logging +import sys +import types 
import typing from collections import abc @@ -384,6 +386,9 @@ def validate_composite_type_param(type_param, error_msg_prefix): not isinstance(type_param, tuple(possible_classes)) and type_param is not None and getattr(type_param, '__module__', None) != 'typing') + if sys.version_info.major == 3 and sys.version_info.minor >= 10: + if isinstance(type_param, types.UnionType): + is_not_type_constraint = False is_forbidden_type = ( isinstance(type_param, type) and type_param in DISALLOWED_PRIMITIVE_TYPES) diff --git a/sdks/python/apache_beam/typehints/typehints_test.py b/sdks/python/apache_beam/typehints/typehints_test.py index 16130d643a4a..532da8917f8c 100644 --- a/sdks/python/apache_beam/typehints/typehints_test.py +++ b/sdks/python/apache_beam/typehints/typehints_test.py @@ -1599,6 +1599,16 @@ def expand(self, pcoll: typing.Any) -> typehints.Any: self.assertEqual(th.input_types, ((typehints.Any, ), {})) self.assertEqual(th.input_types, ((typehints.Any, ), {})) + def test_pipe_operator_as_union(self): + # union types can be written using pipe operator from Python 3.10. + # https://peps.python.org/pep-0604/ + if sys.version_info.major == 3 and sys.version_info.minor >= 10: + type_a = int | float # pylint: disable=unsupported-binary-operation + type_b = typing.Union[int, float] + self.assertEqual( + native_type_compatibility.convert_to_beam_type(type_a), + native_type_compatibility.convert_to_beam_type(type_b)) + if __name__ == '__main__': unittest.main() diff --git a/sdks/python/test-suites/portable/common.gradle b/sdks/python/test-suites/portable/common.gradle index 0eae96c8bec9..9585d0922046 100644 --- a/sdks/python/test-suites/portable/common.gradle +++ b/sdks/python/test-suites/portable/common.gradle @@ -230,6 +230,11 @@ project.tasks.register("flinkExamples") { "--environment_type=LOOPBACK", "--temp_location=gs://temp-storage-for-end-to-end-tests/temp-it", "--flink_job_server_jar=${project(":runners:flink:${latestFlinkVersion}:job-server").shadowJar.archivePath}", + '--sdk_harness_log_level_overrides=' + + // suppress info level flink.runtime log flood + '{\\"org.apache.flink.runtime\\":\\"WARN\\",' + + // suppress full __metricscontainers log printed in FlinkPipelineRunner.createPortablePipelineResult + '\\"org.apache.beam.runners.flink.FlinkPipelineRunner\\":\\"WARN\\"}' ] def cmdArgs = mapToArgString([ "test_opts": testOpts, @@ -312,6 +317,9 @@ project.tasks.register("postCommitPy${pythonVersionSuffix}IT") { // suppress metric name collision warning logs '\\"org.apache.flink.runtime.metrics.groups\\":\\"ERROR\\"}' ] + if (project.hasProperty('flinkConfDir')) { + pipelineOpts += ["--flink-conf-dir=${project.property('flinkConfDir')}"] + } def cmdArgs = mapToArgString([ "test_opts": testOpts, "suite": "postCommitIT-flink-py${pythonVersionSuffix}", diff --git a/website/www/site/content/en/blog/splitAtFraction-method.md b/website/www/site/content/en/blog/splitAtFraction-method.md index 0ae5b7346933..f2e280aaabd8 100644 --- a/website/www/site/content/en/blog/splitAtFraction-method.md +++ b/website/www/site/content/en/blog/splitAtFraction-method.md @@ -22,7 +22,7 @@ See the License for the specific language governing permissions and limitations under the License. --> -This morning, Eugene and Malo from the Google Cloud Dataflow team posted [*No shard left behind: dynamic work rebalancing in Google Cloud Dataflow*](https://cloud.google.com/blog/big-data/2016/05/no-shard-left-behind-dynamic-work-rebalancing-in-google-cloud-dataflow). 
This article discusses Cloud Dataflow’s solution to the well-known straggler problem. +This morning, Eugene and Malo from the Google Cloud Dataflow team posted [*No shard left behind: dynamic work rebalancing in Google Cloud Dataflow*](https://cloud.google.com/blog/products/gcp/no-shard-left-behind-dynamic-work-rebalancing-in-google-cloud-dataflow). This article discusses Cloud Dataflow’s solution to the well-known straggler problem. diff --git a/website/www/site/content/en/blog/splittable-do-fn.md b/website/www/site/content/en/blog/splittable-do-fn.md index d7c1abfafe85..f38a5d6d4886 100644 --- a/website/www/site/content/en/blog/splittable-do-fn.md +++ b/website/www/site/content/en/blog/splittable-do-fn.md @@ -187,7 +187,7 @@ runner with information such as its estimated size (or its generalization, uses this information to tune the execution and control the breakdown of the `Source` into bundles. For example, a slowly progressing large bundle of a file may be [dynamically -split](https://cloud.google.com/blog/big-data/2016/05/no-shard-left-behind-dynamic-work-rebalancing-in-google-cloud-dataflow) +split](https://cloud.google.com/blog/products/gcp/no-shard-left-behind-dynamic-work-rebalancing-in-google-cloud-dataflow) by a batch-focused runner before it becomes a straggler, and a latency-focused streaming runner may control how many elements it reads from a source in each bundle to optimize for latency vs. per-bundle overhead. @@ -251,7 +251,7 @@ a `@ProcessElement` call is going to take too long and become a straggler, it can split the restriction in some proportion so that the primary is short enough to not be a straggler, and can schedule the residual in parallel on another worker. For details, see [No Shard Left -Behind](https://cloud.google.com/blog/big-data/2016/05/no-shard-left-behind-dynamic-work-rebalancing-in-google-cloud-dataflow). +Behind](https://cloud.google.com/blog/products/gcp/no-shard-left-behind-dynamic-work-rebalancing-in-google-cloud-dataflow). Logically, the execution of an SDF on an element works according to the following diagram, where "magic" stands for the runner-specific ability to split diff --git a/website/www/site/content/en/documentation/runners/dataflow.md b/website/www/site/content/en/documentation/runners/dataflow.md index eb5398d3c258..7b5d3e60f567 100644 --- a/website/www/site/content/en/documentation/runners/dataflow.md +++ b/website/www/site/content/en/documentation/runners/dataflow.md @@ -26,7 +26,7 @@ The Cloud Dataflow Runner and service are suitable for large scale, continuous j * a fully managed service * [autoscaling](https://cloud.google.com/dataflow/service/dataflow-service-desc#autoscaling) of the number of workers throughout the lifetime of the job -* [dynamic work rebalancing](https://cloud.google.com/blog/big-data/2016/05/no-shard-left-behind-dynamic-work-rebalancing-in-google-cloud-dataflow) +* [dynamic work rebalancing](https://cloud.google.com/blog/products/gcp/no-shard-left-behind-dynamic-work-rebalancing-in-google-cloud-dataflow) The [Beam Capability Matrix](/documentation/runners/capability-matrix/) documents the supported capabilities of the Cloud Dataflow Runner. 
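A rough illustration of the BatchConverter change earlier in this patch (not part of the diff itself): typehint pairs that previously just failed to match, with the converter returning None, now raise a descriptive TypeError. A minimal sketch, assuming the import paths below match the modules edited above and that torch is installed:

# Minimal sketch (not part of the patch); import paths are assumed from the
# modules edited above, and torch must be installed.
import torch

from apache_beam.typehints.batch import BatchConverter
from apache_beam.typehints.pytorch_type_compatibility import N
from apache_beam.typehints.pytorch_type_compatibility import PytorchTensor

# Compatible typehints still resolve to a PytorchBatchConverter as before.
converter = BatchConverter.from_typehints(
    element_type=PytorchTensor[torch.int64, (10, )],
    batch_type=PytorchTensor[torch.int64, (N, 10)])

# A dtype mismatch previously produced a silent non-match; it now surfaces
# as a TypeError that names the mismatched dtypes.
try:
  BatchConverter.from_typehints(
      element_type=PytorchTensor[torch.float32, (10, )],
      batch_type=PytorchTensor[torch.int64, (N, 10)])
except TypeError as e:
  print(e)  # includes "batch type and element type must have equivalent dtypes"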
diff --git a/website/www/site/content/en/documentation/sdks/python-machine-learning.md b/website/www/site/content/en/documentation/sdks/python-machine-learning.md index b35ab347a8b9..c235c2d4b059 100644 --- a/website/www/site/content/en/documentation/sdks/python-machine-learning.md +++ b/website/www/site/content/en/documentation/sdks/python-machine-learning.md @@ -219,7 +219,7 @@ For detailed instructions explaining how to build and run a pipeline that uses M ## Beam Java SDK support -The RunInference API is available with the Beam Java SDK versions 2.41.0 and later through Apache Beam's [Multi-language Pipelines framework](https://beam.apache.org/documentation/programming-guide/#multi-language-pipelines). For information about the Java wrapper transform, see [RunInference.java](https://github.com/apache/beam/blob/master/sdks/java/extensions/python/src/main/java/org/apache/beam/sdk/extensions/python/transforms/RunInference.java). For example pipelines, see [RunInferenceTransformTest.java](https://github.com/apache/beam/blob/master/sdks/java/extensions/python/src/test/java/org/apache/beam/sdk/extensions/python/transforms/RunInferenceTransformTest.java). +The RunInference API is available with the Beam Java SDK versions 2.41.0 and later through Apache Beam's [Multi-language Pipelines framework](https://beam.apache.org/documentation/programming-guide/#multi-language-pipelines). For information about the Java wrapper transform, see [RunInference.java](https://github.com/apache/beam/blob/master/sdks/java/extensions/python/src/main/java/org/apache/beam/sdk/extensions/python/transforms/RunInference.java). To try it out, see the [Java Sklearn Mnist Classification example](https://github.com/apache/beam/tree/master/examples/multi-language). ## Troubleshooting
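On the PEP 604 support added to typehints.py and typehints_test.py above: on Python 3.10 and later, a pipe-operator union is meant to convert to the same Beam type as the equivalent typing.Union. A minimal sketch mirroring the new unit test (illustrative, not part of the patch):

# Minimal sketch (not part of the patch); requires Python 3.10 or newer.
import sys
import typing

from apache_beam.typehints import native_type_compatibility

if sys.version_info >= (3, 10):
  piped = int | float  # PEP 604 pipe-operator union
  classic = typing.Union[int, float]
  assert (
      native_type_compatibility.convert_to_beam_type(piped) ==
      native_type_compatibility.convert_to_beam_type(classic))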
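The sdk_harness_log_level_overrides string that the common.gradle change above assembles with escaped quotes is an ordinary pipeline option whose value is a JSON map from logger name to level. A hedged sketch of the unescaped flag as it could be passed from Python (the option name is taken from the gradle change; the exact runner-side plumbing is assumed):

# Minimal sketch (not part of the patch). This is the unescaped form of the
# JSON that the gradle change above builds for the Flink example/IT runs.
from apache_beam.options.pipeline_options import PipelineOptions

options = PipelineOptions([
    '--sdk_harness_log_level_overrides='
    '{"org.apache.flink.runtime":"WARN",'
    '"org.apache.beam.runners.flink.FlinkPipelineRunner":"WARN"}',
])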