Fix Python PostCommit BigQuery JSON #30438

Merged · 3 commits · Feb 28, 2024
70 changes: 54 additions & 16 deletions sdks/python/apache_beam/io/gcp/bigquery_json_it_test.py
@@ -34,6 +34,15 @@
from apache_beam.testing.util import assert_that
from apache_beam.testing.util import equal_to

+# pylint: disable=wrong-import-order, wrong-import-position
+try:
+  from google.api_core import exceptions as gexc
+  from google.cloud import bigquery
+except ImportError:
+  gexc = None
+  bigquery = None
+# pylint: enable=wrong-import-order, wrong-import-position
+
_LOGGER = logging.getLogger(__name__)

PROJECT = 'apache-beam-testing'
@@ -44,13 +53,27 @@

STREAMING_TEST_TABLE = "py_streaming_test" \
    f"{time.time_ns() // 1000}_{randint(0,32)}"
+FILE_LOAD_TABLE = "py_fileload_test" \
+    f"{time.time_ns() // 1000}_{randint(0,32)}"


class BigQueryJsonIT(unittest.TestCase):
+  created_tables = set()
+
  @classmethod
  def setUpClass(cls):
    cls.test_pipeline = TestPipeline(is_integration_test=True)

+  @classmethod
+  def tearDownClass(cls):
+    if cls.created_tables:
+      client = bigquery.Client(project=PROJECT)
+      for ref in cls.created_tables:
+        try:
+          client.delete_table(ref[len(PROJECT) + 1:])  # need dataset:table
+        except gexc.NotFound:
+          pass  # just skip
+
  def run_test_write(self, options):
    json_table_schema = self.generate_schema()
    rows_to_write = []
@@ -71,8 +94,10 @@ def run_test_write(self, options):
    parser = argparse.ArgumentParser()
    parser.add_argument('--write_method')
    parser.add_argument('--output')
+    parser.add_argument('--unescape', required=False)

    known_args, pipeline_args = parser.parse_known_args(options)
+    self.created_tables.add(known_args.output)

    with beam.Pipeline(argv=pipeline_args) as p:
      _ = (
@@ -85,26 +110,46 @@ def run_test_write(self, options):
              create_disposition=beam.io.BigQueryDisposition.CREATE_IF_NEEDED,
          ))

-    extra_opts = {'read_method': "EXPORT", 'input': known_args.output}
+    extra_opts = {
+        'read_method': "EXPORT",
+        'input': known_args.output,
+        'unescape': known_args.unescape
+    }
    read_options = self.test_pipeline.get_full_options_as_args(**extra_opts)
    self.read_and_validate_rows(read_options)

  def read_and_validate_rows(self, options):
    json_data = self.generate_data()

+    parser = argparse.ArgumentParser()
+    parser.add_argument('--read_method')
+    parser.add_argument('--query')
+    parser.add_argument('--input')
+    parser.add_argument('--unescape', required=False)
+
+    known_args, pipeline_args = parser.parse_known_args(options)
+
+    # TODO(yathu) remove this conversion when FILE_LOAD produces unescaped
+    # JSON string
+    def maybe_unescape(value):
+      if known_args.unescape:
+        value = bytes(value, "utf-8").decode("unicode_escape")[1:-1]
+      return json.loads(value)
+
    class CompareJson(beam.DoFn, unittest.TestCase):
      def process(self, row):
        country_code = row["country_code"]
        expected = json_data[country_code]

        # Test country (JSON String)
-        country_actual = json.loads(row["country"])
+        country_actual = maybe_unescape(row["country"])
        country_expected = json.loads(expected["country"])

        self.assertTrue(country_expected == country_actual)

        # Test stats (JSON String in BigQuery struct)
        for stat, value in row["stats"].items():
-          stats_actual = json.loads(value)
+          stats_actual = maybe_unescape(value)
          stats_expected = json.loads(expected["stats"][stat])
          self.assertTrue(stats_expected == stats_actual)
@@ -113,25 +158,18 @@ def process(self, row):
          city = city_row["city"]
          city_name = city_row["city_name"]

-          city_actual = json.loads(city)
+          city_actual = maybe_unescape(city)
          city_expected = json.loads(expected["cities"][city_name])
          self.assertTrue(city_expected == city_actual)

        # Test landmarks (JSON String in BigQuery array)
        landmarks_actual = row["landmarks"]
        landmarks_expected = expected["landmarks"]
        for i in range(len(landmarks_actual)):
-          l_actual = json.loads(landmarks_actual[i])
+          l_actual = maybe_unescape(landmarks_actual[i])
          l_expected = json.loads(landmarks_expected[i])
          self.assertTrue(l_expected == l_actual)

-    parser = argparse.ArgumentParser()
-    parser.add_argument('--read_method')
-    parser.add_argument('--query')
-    parser.add_argument('--input')
-
-    known_args, pipeline_args = parser.parse_known_args(options)
-
    method = ReadFromBigQuery.Method.DIRECT_READ if \
      known_args.read_method == "DIRECT_READ" else \
      ReadFromBigQuery.Method.EXPORT
@@ -197,12 +235,12 @@
  @pytest.mark.it_postcommit
  def test_file_loads_write(self):
    extra_opts = {
-        'output': f"{PROJECT}:{DATASET_ID}.{STREAMING_TEST_TABLE}",
-        'write_method': "FILE_LOADS"
+        'output': f"{PROJECT}:{DATASET_ID}.{FILE_LOAD_TABLE}",
+        'write_method': "FILE_LOADS",
+        "unescape": "True"
    }
    options = self.test_pipeline.get_full_options_as_args(**extra_opts)
-    with self.assertRaises(ValueError):
-      self.run_test_write(options)
+    self.run_test_write(options)

  # Schema for writing to BigQuery
  def generate_schema(self):
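For context on the unescape workaround in the diff: until FILE_LOADS writes produce unescaped JSON (per the TODO), the test decodes the doubly-escaped string that comes back before parsing it. A minimal, self-contained sketch of that conversion — the sample value is hypothetical, chosen to mirror what the TODO describes:

import json

def maybe_unescape(value, unescape=True):
  # Mirrors the helper in the diff: decode backslash escapes, strip the
  # surrounding quotes, then parse the result as JSON.
  if unescape:
    value = bytes(value, "utf-8").decode("unicode_escape")[1:-1]
  return json.loads(value)

# Hypothetical escaped value, as the FILE_LOADS path might return it:
escaped = '"{\\"name\\": \\"Seattle\\", \\"population\\": 737015}"'
print(maybe_unescape(escaped))
# -> {'name': 'Seattle', 'population': 737015}

# With unescape disabled, a plain JSON string parses directly:
print(maybe_unescape('{"name": "Seattle"}', unescape=False))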
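The new tearDownClass keeps integration runs from leaking tables: each --output target is recorded in created_tables and deleted after the suite, with NotFound swallowed so a failed write cannot break teardown. A roughly equivalent sketch (assuming the google-cloud-bigquery client; the helper name is illustrative, not from the PR) leans on delete_table's not_found_ok flag instead of catching the exception:

from google.cloud import bigquery

def delete_created_tables(project, created_tables):
  # created_tables holds refs like "project:dataset.table" (the test's
  # --output format); delete_table accepts the "dataset.table" remainder.
  client = bigquery.Client(project=project)
  for ref in created_tables:
    table_id = ref[len(project) + 1:]
    # not_found_ok=True makes deletion idempotent, standing in for the
    # try/except gexc.NotFound in the diff.
    client.delete_table(table_id, not_found_ok=True)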