(BQ Python) Fix streaming with large loads by performing job waits in finish_bundle #23012

Merged
merged 12 commits on Sep 14, 2022
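The core change: instead of a separate WaitForBQJobs step driven by an impulse and a side input, each job-triggering DoFn now records the jobs it starts during process() and waits for them in finish_bundle(), emitting the (destination, job_reference) pairs only once those jobs are done. Below is a minimal sketch of that pattern, not the actual Beam code: TriggerJobsThenWait, bq_wrapper, and job_request are hypothetical stand-ins (the real code uses bigquery_tools.BigQueryWrapper), while the DoFn lifecycle methods and GlobalWindows.windowed_value are Beam's own API.

import apache_beam as beam
from apache_beam.transforms.window import GlobalWindows

_SLEEP_DURATION_BETWEEN_POLLS = 10  # seconds between job-status polls


class TriggerJobsThenWait(beam.DoFn):
  """Hypothetical DoFn sketching the process()/finish_bundle() split."""

  def __init__(self, bq_wrapper):
    # Stand-in for bigquery_tools.BigQueryWrapper in the real code.
    self._bq_wrapper = bq_wrapper

  def start_bundle(self):
    # Jobs started in this bundle; waited on only when the bundle finishes.
    self.pending_jobs = []

  def process(self, element):
    destination, job_request = element
    # Start the job but do not block here, so the other elements in the
    # bundle can trigger their jobs first.
    job_reference = self._bq_wrapper.perform_load_job(**job_request)
    self.pending_jobs.append(
        GlobalWindows.windowed_value((destination, job_reference)))

  def finish_bundle(self):
    # Block once per bundle: wait for every job, then emit the references.
    # Values emitted from finish_bundle() must be WindowedValue instances,
    # hence GlobalWindows.windowed_value above.
    for windowed_value in self.pending_jobs:
      job_ref = windowed_value.value[1]
      self._bq_wrapper.wait_for_bq_job(
          job_ref, sleep_duration_sec=_SLEEP_DURATION_BETWEEN_POLLS)
    return self.pending_jobs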
158 changes: 75 additions & 83 deletions sdks/python/apache_beam/io/gcp/bigquery_file_loads.py
@@ -77,6 +77,9 @@
# triggering file write to avoid generating too many small files.
_FILE_TRIGGERING_BATCHING_DURATION_SECS = 1

# How many seconds we wait before polling a pending job
_SLEEP_DURATION_BETWEEN_POLLS = 10


def _generate_job_name(job_name, job_type, step_name):
return bigquery_tools.generate_bq_job_name(
@@ -355,9 +358,10 @@ def __init__(
self._step_name = step_name
self._load_job_project_id = load_job_project_id

def setup(self):
self._bq_wrapper = bigquery_tools.BigQueryWrapper(client=self._test_client)
def start_bundle(self):
self.bq_wrapper = bigquery_tools.BigQueryWrapper(client=self._test_client)
self._bq_io_metadata = create_bigquery_io_metadata(self._step_name)
self.pending_jobs = []

def display_data(self):
return {
@@ -392,7 +396,7 @@ def process(self, element, schema_mod_job_name_prefix):

try:
# Check if destination table exists
destination_table = self._bq_wrapper.get_table(
destination_table = self.bq_wrapper.get_table(
project_id=table_reference.projectId,
dataset_id=table_reference.datasetId,
table_id=table_reference.tableId)
@@ -404,7 +408,7 @@ def process(self, element, schema_mod_job_name_prefix):
else:
raise

temp_table_load_job = self._bq_wrapper.get_job(
temp_table_load_job = self.bq_wrapper.get_job(
project=temp_table_load_job_reference.projectId,
job_id=temp_table_load_job_reference.jobId,
location=temp_table_load_job_reference.location)
@@ -432,20 +436,36 @@ def process(self, element, schema_mod_job_name_prefix):
table_reference)
# Trigger potential schema modification by loading zero rows into the
# destination table with the temporary table schema.
schema_update_job_reference = self._bq_wrapper.perform_load_job(
destination=table_reference,
source_stream=io.BytesIO(), # file with zero rows
job_id=job_name,
schema=temp_table_schema,
write_disposition='WRITE_APPEND',
create_disposition='CREATE_NEVER',
additional_load_parameters=additional_parameters,
job_labels=self._bq_io_metadata.add_additional_bq_job_labels(),
# JSON format is hardcoded because a zero-row load (unlike AVRO) and
# a nested schema (unlike CSV, which requires a flat one) are permitted.
source_format="NEWLINE_DELIMITED_JSON",
load_job_project_id=self._load_job_project_id)
yield (destination, schema_update_job_reference)
schema_update_job_reference = self.bq_wrapper.perform_load_job(
destination=table_reference,
source_stream=io.BytesIO(), # file with zero rows
job_id=job_name,
schema=temp_table_schema,
write_disposition='WRITE_APPEND',
create_disposition='CREATE_NEVER',
additional_load_parameters=additional_parameters,
job_labels=self._bq_io_metadata.add_additional_bq_job_labels(),
# JSON format is hardcoded because a zero-row load (unlike AVRO) and
# a nested schema (unlike CSV, which requires a flat one) are permitted.
source_format="NEWLINE_DELIMITED_JSON",
load_job_project_id=self._load_job_project_id)
self.pending_jobs.append(
GlobalWindows.windowed_value(
(destination, schema_update_job_reference)))

def finish_bundle(self):
# Unlike the other steps, schema update is not always necessary.
# In that case, return a None value to avoid blocking in a streaming context.
# Otherwise, the streaming pipeline would get stuck waiting for the
# TriggerCopyJobs side-input.
if not self.pending_jobs:
return [GlobalWindows.windowed_value(None)]

for windowed_value in self.pending_jobs:
job_ref = windowed_value.value[1]
self.bq_wrapper.wait_for_bq_job(
job_ref, sleep_duration_sec=_SLEEP_DURATION_BETWEEN_POLLS)
return self.pending_jobs


class TriggerCopyJobs(beam.DoFn):
@@ -462,6 +482,9 @@ class TriggerCopyJobs(beam.DoFn):
copying from temp_tables to destination_table is not atomic.
See: https://issues.apache.org/jira/browse/BEAM-7822
"""

TRIGGER_DELETE_TEMP_TABLES = 'TriggerDeleteTempTables'

def __init__(
self,
project=None,
@@ -490,6 +513,7 @@ def start_bundle(self):
self.bq_wrapper = bigquery_tools.BigQueryWrapper(client=self.test_client)
if not self.bq_io_metadata:
self.bq_io_metadata = create_bigquery_io_metadata(self._step_name)
self.pending_jobs = []

def process(self, element, job_name_prefix=None, unused_schema_mod_jobs=None):
destination = element[0]
@@ -551,8 +575,19 @@ def process(self, element, job_name_prefix=None, unused_schema_mod_jobs=None):

if wait_for_job:
self.bq_wrapper.wait_for_bq_job(job_reference, sleep_duration_sec=10)
self.pending_jobs.append(
GlobalWindows.windowed_value((destination, job_reference)))

def finish_bundle(self):
for windowed_value in self.pending_jobs:
job_ref = windowed_value.value[1]
self.bq_wrapper.wait_for_bq_job(
job_ref, sleep_duration_sec=_SLEEP_DURATION_BETWEEN_POLLS)
yield windowed_value

yield (destination, job_reference)
yield pvalue.TaggedOutput(
TriggerCopyJobs.TRIGGER_DELETE_TEMP_TABLES,
GlobalWindows.windowed_value(None))


class TriggerLoadJobs(beam.DoFn):
@@ -609,6 +644,7 @@ def start_bundle(self):
self.bq_wrapper = bigquery_tools.BigQueryWrapper(client=self.test_client)
if not self.bq_io_metadata:
self.bq_io_metadata = create_bigquery_io_metadata(self._step_name)
self.pending_jobs = []

def process(self, element, load_job_name_prefix, *schema_side_inputs):
# Each load job is assumed to have files respecting these constraints:
@@ -682,7 +718,15 @@ def process(self, element, load_job_name_prefix, *schema_side_inputs):
source_format=self.source_format,
job_labels=self.bq_io_metadata.add_additional_bq_job_labels(),
load_job_project_id=self.load_job_project_id)
yield (destination, job_reference)
self.pending_jobs.append(
GlobalWindows.windowed_value((destination, job_reference)))

def finish_bundle(self):
for windowed_value in self.pending_jobs:
job_ref = windowed_value.value[1]
self.bq_wrapper.wait_for_bq_job(
job_ref, sleep_duration_sec=_SLEEP_DURATION_BETWEEN_POLLS)
return self.pending_jobs


class PartitionFiles(beam.DoFn):
@@ -739,29 +783,6 @@ def process(self, element):
yield pvalue.TaggedOutput(output_tag, (destination, partition))


class WaitForBQJobs(beam.DoFn):
"""Takes in a series of BQ job names as side input, and waits for all of them.

If any job fails, it will fail. If all jobs succeed, it will succeed.

Experimental; no backwards compatibility guarantees.
"""
def __init__(self, test_client=None):
self.test_client = test_client

def start_bundle(self):
self.bq_wrapper = bigquery_tools.BigQueryWrapper(client=self.test_client)

def process(self, element, dest_ids_list):
job_references = [elm[1] for elm in dest_ids_list]
for ref in job_references:
# We must poll repeatedly until the job finishes or fails, thus setting
# max_retries to 0.
self.bq_wrapper.wait_for_bq_job(ref, sleep_duration_sec=10, max_retries=0)

return dest_ids_list # Pass the list of destination-jobs downstream


class DeleteTablesFn(beam.DoFn):
def __init__(self, test_client=None):
self.test_client = test_client
@@ -1038,15 +1059,8 @@ def _load_data(
temp_tables_load_job_ids_pc = trigger_loads_outputs['main']
temp_tables_pc = trigger_loads_outputs[TriggerLoadJobs.TEMP_TABLES]

finished_temp_tables_load_jobs_pc = (
p
| "ImpulseMonitorLoadJobs" >> beam.Create([None])
| "WaitForTempTableLoadJobs" >> beam.ParDo(
WaitForBQJobs(self.test_client),
pvalue.AsList(temp_tables_load_job_ids_pc)))

schema_mod_job_ids_pc = (
finished_temp_tables_load_jobs_pc
temp_tables_load_job_ids_pc
| beam.ParDo(
UpdateDestinationSchema(
project=self.project,
@@ -1057,15 +1071,8 @@ def _load_data(
load_job_project_id=self.load_job_project_id),
schema_mod_job_name_pcv))

finished_schema_mod_jobs_pc = (
p
| "ImpulseMonitorSchemaModJobs" >> beam.Create([None])
| "WaitForSchemaModJobs" >> beam.ParDo(
WaitForBQJobs(self.test_client),
pvalue.AsList(schema_mod_job_ids_pc)))

destination_copy_job_ids_pc = (
finished_temp_tables_load_jobs_pc
copy_job_outputs = (
temp_tables_load_job_ids_pc
| beam.ParDo(
TriggerCopyJobs(
project=self.project,
@@ -1075,25 +1082,17 @@ def _load_data(
step_name=step_name,
load_job_project_id=self.load_job_project_id),
copy_job_name_pcv,
pvalue.AsIter(finished_schema_mod_jobs_pc)))
pvalue.AsIter(schema_mod_job_ids_pc)).with_outputs(
TriggerCopyJobs.TRIGGER_DELETE_TEMP_TABLES, main='main'))

finished_copy_jobs_pc = (
p
| "ImpulseMonitorCopyJobs" >> beam.Create([None])
| "WaitForCopyJobs" >> beam.ParDo(
WaitForBQJobs(self.test_client),
pvalue.AsList(destination_copy_job_ids_pc)))
destination_copy_job_ids_pc = copy_job_outputs['main']
trigger_delete = copy_job_outputs[
TriggerCopyJobs.TRIGGER_DELETE_TEMP_TABLES]

_ = (
p
| "RemoveTempTables/Impulse" >> beam.Create([None])
| "RemoveTempTables/PassTables" >> beam.FlatMap(
lambda _,
unused_copy_jobs,
deleting_tables: deleting_tables,
pvalue.AsIter(finished_copy_jobs_pc),
pvalue.AsIter(temp_tables_pc))
| "RemoveTempTables/AddUselessValue" >> beam.Map(lambda x: (x, None))
temp_tables_pc
| "RemoveTempTables/AddUselessValue" >> beam.Map(
lambda x, unused_trigger: (x, None), pvalue.AsList(trigger_delete))
| "RemoveTempTables/DeduplicateTables" >> beam.GroupByKey()
| "RemoveTempTables/GetTableNames" >> beam.Keys()
| "RemoveTempTables/Delete" >> beam.ParDo(
@@ -1116,13 +1115,6 @@ def _load_data(
load_job_name_pcv,
*self.schema_side_inputs))

_ = (
p
| "ImpulseMonitorDestinationLoadJobs" >> beam.Create([None])
| "WaitForDestinationLoadJobs" >> beam.ParDo(
WaitForBQJobs(self.test_client),
pvalue.AsList(destination_load_job_ids_pc)))

destination_load_job_ids_pc = (
(temp_tables_load_job_ids_pc, destination_load_job_ids_pc)
| beam.Flatten())
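In the rewired _load_data above, the WaitForBQJobs/Impulse monitors are gone: because finish_bundle only emits after its jobs complete, consuming a step's output as a side input is itself the synchronization point. The sketch below (CopyThenSignal and its tag are hypothetical; TaggedOutput, with_outputs, AsList, and the DoFn lifecycle are Beam's own API) shows how a tagged output fired from finish_bundle can gate a downstream cleanup, mirroring the TRIGGER_DELETE_TEMP_TABLES output and RemoveTempTables/AddUselessValue wiring above.

import apache_beam as beam
from apache_beam import pvalue
from apache_beam.transforms.window import GlobalWindows


class CopyThenSignal(beam.DoFn):
  """Hypothetical DoFn: passes elements through on the main output and emits
  one 'signal' element on a tagged output once the bundle's work is done."""

  SIGNAL = 'signal'

  def start_bundle(self):
    self.pending = []

  def process(self, element):
    # In the real code this is where the copy job would be triggered.
    self.pending.append(GlobalWindows.windowed_value(element))

  def finish_bundle(self):
    for windowed_value in self.pending:
      yield windowed_value
    # Fires only after everything above, i.e. after the bundle's jobs.
    yield pvalue.TaggedOutput(
        CopyThenSignal.SIGNAL, GlobalWindows.windowed_value(None))


with beam.Pipeline() as p:
  outputs = (
      p
      | 'Jobs' >> beam.Create([('tableA', 'job1'), ('tableB', 'job2')])
      | 'Copy' >> beam.ParDo(CopyThenSignal()).with_outputs(
          CopyThenSignal.SIGNAL, main='main'))

  temp_tables = p | 'TempTables' >> beam.Create(['temp_1', 'temp_2'])

  _ = (
      temp_tables
      # The AsList side input is only ready once the signal output exists,
      # so deletion cannot start before the copies have finished.
      | 'Gate' >> beam.Map(
          lambda table, unused_signal: table,
          pvalue.AsList(outputs[CopyThenSignal.SIGNAL]))
      | 'Delete' >> beam.Map(print))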