From da834fdee214f60b839d425b21cfb562bf1ad759 Mon Sep 17 00:00:00 2001 From: Shunping Huang Date: Sat, 21 Dec 2024 10:02:08 -0500 Subject: [PATCH] Improve the behavior of an unseekable stream input. stage_file_with_retry now calls stage_file first instead of asserting up front that the stream is seekable. If staging fails and the stream is seekable, the stream is rewound with seek(0) and the original exception is re-raised so the call can be retried. If the stream is not seekable it cannot be re-sent, so a retry.PermanentException carrying the formatted original error is raised as a single message string and the call is not retried. --- .../runners/dataflow/internal/apiclient.py | 17 +++-- .../dataflow/internal/apiclient_test.py | 71 +++++++++++++------ 2 files changed, 60 insertions(+), 28 deletions(-) diff --git a/sdks/python/apache_beam/runners/dataflow/internal/apiclient.py b/sdks/python/apache_beam/runners/dataflow/internal/apiclient.py index bc0b6fe14e0..f9b84e16906 100644 --- a/sdks/python/apache_beam/runners/dataflow/internal/apiclient.py +++ b/sdks/python/apache_beam/runners/dataflow/internal/apiclient.py @@ -42,6 +42,7 @@ import re import sys import time +import traceback import warnings from copy import copy from datetime import datetime @@ -707,11 +708,17 @@ def stage_file_with_retry( gcs_or_local_path, file_name, stream, mime_type, total_size) elif isinstance(stream_or_path, io.BufferedIOBase): stream = stream_or_path - if stream.tell() > 0: - assert stream.seekable(), "stream must be seekable" - stream.seek(0) - self.stage_file( - gcs_or_local_path, file_name, stream, mime_type, total_size) + try: + self.stage_file( + gcs_or_local_path, file_name, stream, mime_type, total_size) + except Exception as exn: + if stream.seekable(): + stream.seek(0) + raise + else: + detail = ''.join(traceback.format_exception_only(exn.__class__, exn)) + raise retry.PermanentException( + f"Stream is not seekable, so staging cannot be retried: {detail}") @retry.no_retries # Using no_retries marks this as an integration point. 
def create_job(self, job): diff --git a/sdks/python/apache_beam/runners/dataflow/internal/apiclient_test.py b/sdks/python/apache_beam/runners/dataflow/internal/apiclient_test.py index 4b32dc567b0..9d3804b0ce3 100644 --- a/sdks/python/apache_beam/runners/dataflow/internal/apiclient_test.py +++ b/sdks/python/apache_beam/runners/dataflow/internal/apiclient_test.py @@ -44,6 +44,7 @@ from apache_beam.transforms import DoFn from apache_beam.transforms import ParDo from apache_beam.transforms.environments import DockerEnvironment +from apache_beam.utils import retry # Protect against environments where apitools library is not available. # pylint: disable=wrong-import-order, wrong-import-position, ungrouped-imports @@ -1670,14 +1671,16 @@ def exists_return_value(*args): self.assertEqual(pipeline, pipeline_expected) def test_stage_file_with_retry(self): - count = 0 - def effect(self, *args, **kwargs): nonlocal count count += 1 - if count > 1: - return - raise Exception("This exception is raised for testing purpose.") + # Fail the first two calls and succeed afterward + if count <= 2: + raise Exception("This exception is raised for testing purpose.") + + class Unseekable(io.BufferedIOBase): + def seekable(self): + return False pipeline_options = PipelineOptions([ '--project', @@ -1690,28 +1693,50 @@ def effect(self, *args, **kwargs): pipeline_options.view_as(GoogleCloudOptions).no_auth = True client = apiclient.DataflowApplicationClient(pipeline_options) - with mock.patch.object(time, 'sleep'): - count = 0 - with mock.patch("builtins.open", - mock.mock_open(read_data="data")) as mock_file_open: - with mock.patch.object(client, 'stage_file') as mock_stage_file: - mock_stage_file.side_effect = effect - # call with a file name + with mock.patch.object(client, 'stage_file') as mock_stage_file: + mock_stage_file.side_effect = effect + + with mock.patch.object(time, 'sleep') as mock_sleep: + with mock.patch("builtins.open", + mock.mock_open(read_data="data")) as mock_file_open: + 
count = 0 + # calling with a file name client.stage_file_with_retry( - "/to", "new_name", "/from/old_name", total_size=1024) - self.assertEqual(mock_file_open.call_count, 2) - self.assertEqual(mock_stage_file.call_count, 2) - - count = 0 - with mock.patch("builtins.open", - mock.mock_open(read_data="data")) as mock_file_open: - with mock.patch.object(client, 'stage_file') as mock_stage_file: - mock_stage_file.side_effect = effect - # call with a seekable stream + "/to", "new_name", "/from/old_name", total_size=4) + self.assertEqual(mock_stage_file.call_count, 3) + self.assertEqual(mock_sleep.call_count, 2) + self.assertEqual(mock_file_open.call_count, 3) + + count = 0 + mock_stage_file.reset_mock() + mock_sleep.reset_mock() + mock_file_open.reset_mock() + + # calling with a seekable stream client.stage_file_with_retry( "/to", "new_name", io.BytesIO(b'test'), total_size=4) + self.assertEqual(mock_stage_file.call_count, 3) + self.assertEqual(mock_sleep.call_count, 2) + # no open() is called if a stream is provided + mock_file_open.assert_not_called() + + count = 0 + mock_sleep.reset_mock() + mock_file_open.reset_mock() + mock_stage_file.reset_mock() + + # calling with an unseekable stream + self.assertRaises( + retry.PermanentException, + client.stage_file_with_retry, + "/to", + "new_name", + Unseekable(), + total_size=4) + # Unseekable is only staged once. There won't be any retry if it fails + self.assertEqual(mock_stage_file.call_count, 1) + mock_sleep.assert_not_called() mock_file_open.assert_not_called() - self.assertEqual(mock_stage_file.call_count, 2) if __name__ == '__main__':