From b2d1c60dc7710213f21fd4bf3fda0c5d375fb759 Mon Sep 17 00:00:00 2001 From: Kerry Donny-Clark Date: Wed, 23 Aug 2023 16:46:22 -0400 Subject: [PATCH] is_service_runner now returns false with dataflow_endpoint=localhost (#28128) --- .../options/pipeline_options_validator.py | 10 ++++++++-- .../options/pipeline_options_validator_test.py | 12 +++++++++++- 2 files changed, 19 insertions(+), 3 deletions(-) diff --git a/sdks/python/apache_beam/options/pipeline_options_validator.py b/sdks/python/apache_beam/options/pipeline_options_validator.py index 640569b6293e..7c07e5a1e6c7 100644 --- a/sdks/python/apache_beam/options/pipeline_options_validator.py +++ b/sdks/python/apache_beam/options/pipeline_options_validator.py @@ -157,8 +157,13 @@ def is_service_runner(self): dataflow_endpoint = ( self.options.view_as(GoogleCloudOptions).dataflow_endpoint) - is_service_endpoint = (dataflow_endpoint is not None) - return is_service_runner and is_service_endpoint + if dataflow_endpoint is None: + return False + else: + endpoint_parts = urlparse(dataflow_endpoint, allow_fragments=False) + if endpoint_parts.netloc.startswith("localhost"): + return False + return is_service_runner def is_full_string_match(self, pattern, string): """Returns True if the pattern matches the whole string.""" @@ -404,6 +409,7 @@ def validate_repeatable_argument_passed_as_list(self, view, arg_name): # Minimally validates the endpoint url. This is not a strict application # of http://www.faqs.org/rfcs/rfc1738.html. + # If the url matches localhost, set def validate_endpoint_url(self, endpoint_url): url_parts = urlparse(endpoint_url, allow_fragments=False) if not url_parts.scheme or not url_parts.netloc: diff --git a/sdks/python/apache_beam/options/pipeline_options_validator_test.py b/sdks/python/apache_beam/options/pipeline_options_validator_test.py index 653ea112c8c3..1f4c8be226dc 100644 --- a/sdks/python/apache_beam/options/pipeline_options_validator_test.py +++ b/sdks/python/apache_beam/options/pipeline_options_validator_test.py @@ -311,6 +311,11 @@ def test_is_service_runner(self): 'options': ['--dataflow_endpoint=https://dataflow.googleapis.com/'], 'expected': False, }, + { + 'runner': MockRunners.DataflowRunner(), + 'options': [], + 'expected': True, + }, { 'runner': MockRunners.DataflowRunner(), 'options': ['--dataflow_endpoint=https://another.service.com'], @@ -321,6 +326,11 @@ def test_is_service_runner(self): 'options': ['--dataflow_endpoint=https://dataflow.googleapis.com'], 'expected': True, }, + { + 'runner': MockRunners.DataflowRunner(), + 'options': ['--dataflow_endpoint=http://localhost:1000'], + 'expected': False, + }, { 'runner': MockRunners.DataflowRunner(), 'options': ['--dataflow_endpoint=foo: //dataflow. googleapis. com'], @@ -336,7 +346,7 @@ def test_is_service_runner(self): for case in test_cases: validator = PipelineOptionsValidator( PipelineOptions(case['options']), case['runner']) - self.assertEqual(validator.is_service_runner(), case['expected']) + self.assertEqual(validator.is_service_runner(), case['expected'], case) def test_dataflow_job_file_and_template_location_mutually_exclusive(self): runner = MockRunners.OtherRunner()