Skip to content

Commit

Permalink
Improve the behavior of an unseekable stream input.
Browse files Browse the repository at this point in the history
  • Loading branch information
shunping committed Dec 21, 2024
1 parent e891af8 commit da834fd
Show file tree
Hide file tree
Showing 2 changed files with 60 additions and 28 deletions.
17 changes: 12 additions & 5 deletions sdks/python/apache_beam/runners/dataflow/internal/apiclient.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
import re
import sys
import time
import traceback
import warnings
from copy import copy
from datetime import datetime
Expand Down Expand Up @@ -707,11 +708,17 @@ def stage_file_with_retry(
gcs_or_local_path, file_name, stream, mime_type, total_size)
elif isinstance(stream_or_path, io.BufferedIOBase):
stream = stream_or_path
if stream.tell() > 0:
assert stream.seekable(), "stream must be seekable"
stream.seek(0)
self.stage_file(
gcs_or_local_path, file_name, stream, mime_type, total_size)
try:
self.stage_file(
gcs_or_local_path, file_name, stream, mime_type, total_size)
except Exception as exn:
if stream.seekable():
stream.seek(0)
raise exn
else:
raise retry.PermanentException(
"Failed to tell or seek in stream because we caught exception:",
''.join(traceback.format_exception_only(exn.__class__, exn)))

@retry.no_retries # Using no_retries marks this as an integration point.
def create_job(self, job):
Expand Down
71 changes: 48 additions & 23 deletions sdks/python/apache_beam/runners/dataflow/internal/apiclient_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
from apache_beam.transforms import DoFn
from apache_beam.transforms import ParDo
from apache_beam.transforms.environments import DockerEnvironment
from apache_beam.utils import retry

# Protect against environments where apitools library is not available.
# pylint: disable=wrong-import-order, wrong-import-position, ungrouped-imports
Expand Down Expand Up @@ -1670,14 +1671,16 @@ def exists_return_value(*args):
self.assertEqual(pipeline, pipeline_expected)

def test_stage_file_with_retry(self):
count = 0

def effect(self, *args, **kwargs):
nonlocal count
count += 1
if count > 1:
return
raise Exception("This exception is raised for testing purpose.")
# Fail the first two calls and succeed afterward
if count <= 2:
raise Exception("This exception is raised for testing purpose.")

class Unseekable(io.BufferedIOBase):
def seekable(self):
return False

pipeline_options = PipelineOptions([
'--project',
Expand All @@ -1690,28 +1693,50 @@ def effect(self, *args, **kwargs):
pipeline_options.view_as(GoogleCloudOptions).no_auth = True
client = apiclient.DataflowApplicationClient(pipeline_options)

with mock.patch.object(time, 'sleep'):
count = 0
with mock.patch("builtins.open",
mock.mock_open(read_data="data")) as mock_file_open:
with mock.patch.object(client, 'stage_file') as mock_stage_file:
mock_stage_file.side_effect = effect
# call with a file name
with mock.patch.object(client, 'stage_file') as mock_stage_file:
mock_stage_file.side_effect = effect

with mock.patch.object(time, 'sleep') as mock_sleep:
with mock.patch("builtins.open",
mock.mock_open(read_data="data")) as mock_file_open:
count = 0
# calling with a file name
client.stage_file_with_retry(
"/to", "new_name", "/from/old_name", total_size=1024)
self.assertEqual(mock_file_open.call_count, 2)
self.assertEqual(mock_stage_file.call_count, 2)

count = 0
with mock.patch("builtins.open",
mock.mock_open(read_data="data")) as mock_file_open:
with mock.patch.object(client, 'stage_file') as mock_stage_file:
mock_stage_file.side_effect = effect
# call with a seekable stream
"/to", "new_name", "/from/old_name", total_size=4)
self.assertEqual(mock_stage_file.call_count, 3)
self.assertEqual(mock_sleep.call_count, 2)
self.assertEqual(mock_file_open.call_count, 3)

count = 0
mock_stage_file.reset_mock()
mock_sleep.reset_mock()
mock_file_open.reset_mock()

# calling with a seekable stream
client.stage_file_with_retry(
"/to", "new_name", io.BytesIO(b'test'), total_size=4)
self.assertEqual(mock_stage_file.call_count, 3)
self.assertEqual(mock_sleep.call_count, 2)
# no open() is called if a stream is provided
mock_file_open.assert_not_called()

count = 0
mock_sleep.reset_mock()
mock_file_open.reset_mock()
mock_stage_file.reset_mock()

# calling with an unseekable stream
self.assertRaises(
retry.PermanentException,
client.stage_file_with_retry,
"/to",
"new_name",
Unseekable(),
total_size=4)
# Unseekable is only staged once. There won't be any retry if it fails
self.assertEqual(mock_stage_file.call_count, 1)
mock_sleep.assert_not_called()
mock_file_open.assert_not_called()
self.assertEqual(mock_stage_file.call_count, 2)


if __name__ == '__main__':
Expand Down

0 comments on commit da834fd

Please sign in to comment.