Skip to content

Commit

Permalink
Merge pull request #1 from Polber/spanner-enrichment-use-case
Browse files Browse the repository at this point in the history
fix examples_test
  • Loading branch information
reeba212 authored Dec 7, 2024
2 parents fad1f7f + 0f32323 commit 5988a23
Show file tree
Hide file tree
Showing 3 changed files with 169 additions and 14 deletions.
166 changes: 159 additions & 7 deletions sdks/python/apache_beam/yaml/examples/testing/examples_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,11 @@
import os
import random
import unittest
from typing import Any
from typing import Callable
from typing import Dict
from typing import List
from typing import Optional
from typing import Union
from unittest import mock

Expand All @@ -34,11 +36,63 @@
from apache_beam.examples.snippets.util import assert_matches_stdout
from apache_beam.options.pipeline_options import PipelineOptions
from apache_beam.testing.test_pipeline import TestPipeline
from apache_beam.yaml import yaml_provider
from apache_beam.yaml import yaml_transform
from apache_beam.yaml.readme_test import TestEnvironment
from apache_beam.yaml.readme_test import replace_recursive


# Used to simulate Enrichment transform during tests
# The GitHub action that invokes these tests does not
# have gcp dependencies installed which is a prerequisite
# to apache_beam.transforms.enrichment.Enrichment as a top-level
# import.
@beam.ptransform.ptransform_fn
def test_enrichment(
pcoll,
enrichment_handler: str,
handler_config: Dict[str, Any],
timeout: Optional[float] = 30):
if enrichment_handler == 'BigTable':
row_key = handler_config['row_key']
bt_data = INPUT_TABLES[(
'BigTable', handler_config['instance_id'], handler_config['table_id'])]
products = {str(data[row_key]): data for data in bt_data}

def _fn(row):
left = row._asdict()
right = products[str(row[row_key])]
left['product'] = left.get('product', None) or right
return beam.Row(**row)
elif enrichment_handler == 'BigQuery':
row_key = handler_config['fields']
dataset, table = handler_config['table_name'].split('.')[-2:]
bq_data = INPUT_TABLES[('BigQuery', str(dataset), str(table))]
products = {
tuple(str(data[key]) for key in row_key): data
for data in bq_data
}

def _fn(row):
left = row._asdict()
right = products[tuple(str(left[k]) for k in row_key)]
row = {
key: left.get(key, None) or right[key]
for key in {*left.keys(), *right.keys()}
}
return beam.Row(**row)

else:
raise ValueError(f'{enrichment_handler} is not a valid enrichment_handler.')

return pcoll | beam.Map(_fn)


TEST_PROVIDERS = {
'TestEnrichment': test_enrichment,
}


def check_output(expected: List[str]):
def _check_inner(actual: List[PCollection[str]]):
formatted_actual = actual | beam.Flatten() | beam.Map(
Expand All @@ -59,7 +113,31 @@ def products_csv():
])


def spanner_data():
def spanner_orders_data():
return [{
'order_id': 1,
'customer_id': 1001,
'product_id': 2001,
'order_date': '24-03-24',
'order_amount': 150,
},
{
'order_id': 2,
'customer_id': 1002,
'product_id': 2002,
'order_date': '19-04-24',
'order_amount': 90,
},
{
'order_id': 3,
'customer_id': 1003,
'product_id': 2003,
'order_date': '7-05-24',
'order_amount': 110,
}]


def spanner_shipments_data():
return [{
'shipment_id': 'S1',
'customer_id': 'C1',
Expand Down Expand Up @@ -110,6 +188,44 @@ def spanner_data():
}]


def bigtable_data():
return [{
'product_id': '1', 'product_name': 'pixel 5', 'product_stock': '2'
}, {
'product_id': '2', 'product_name': 'pixel 6', 'product_stock': '4'
}, {
'product_id': '3', 'product_name': 'pixel 7', 'product_stock': '20'
}, {
'product_id': '4', 'product_name': 'pixel 8', 'product_stock': '10'
}, {
'product_id': '5', 'product_name': 'pixel 11', 'product_stock': '3'
}, {
'product_id': '6', 'product_name': 'pixel 12', 'product_stock': '7'
}, {
'product_id': '7', 'product_name': 'pixel 13', 'product_stock': '8'
}, {
'product_id': '8', 'product_name': 'pixel 14', 'product_stock': '3'
}]


def bigquery_data():
return [{
'customer_id': 1001,
'customer_name': 'Alice',
'customer_email': '[email protected]'
},
{
'customer_id': 1002,
'customer_name': 'Bob',
'customer_email': '[email protected]'
},
{
'customer_id': 1003,
'customer_name': 'Claire',
'customer_email': '[email protected]'
}]


def create_test_method(
pipeline_spec_file: str,
custom_preprocessors: List[Callable[..., Union[Dict, List]]]):
Expand All @@ -135,7 +251,11 @@ def test_yaml_example(self):
pickle_library='cloudpickle',
**yaml_transform.SafeLineLoader.strip_metadata(pipeline_spec.get(
'options', {})))) as p:
actual = [yaml_transform.expand_pipeline(p, pipeline_spec)]
actual = [
yaml_transform.expand_pipeline(
p,
pipeline_spec, [yaml_provider.InlineProvider(TEST_PROVIDERS)])
]
if not actual[0]:
actual = list(p.transforms_stack[0].parts[-1].outputs.values())
for transform in p.transforms_stack[0].parts[:-1]:
Expand Down Expand Up @@ -214,7 +334,6 @@ def _wordcount_test_preprocessor(
'test_simple_filter_and_combine_yaml',
'test_spanner_read_yaml',
'test_spanner_write_yaml',
'test_bigtable_enrichment_yaml',
'test_enrich_spanner_with_bigquery_yaml'
])
def _io_write_test_preprocessor(
Expand Down Expand Up @@ -251,7 +370,8 @@ def _file_io_read_test_preprocessor(
return test_spec


@YamlExamplesTestSuite.register_test_preprocessor(['test_spanner_read_yaml'])
@YamlExamplesTestSuite.register_test_preprocessor(
['test_spanner_read_yaml', 'test_enrich_spanner_with_bigquery_yaml'])
def _spanner_io_read_test_preprocessor(
test_spec: dict, expected: List[str], env: TestEnvironment):

Expand All @@ -267,14 +387,42 @@ def _spanner_io_read_test_preprocessor(
k: v
for k, v in config.items() if k.startswith('__')
}
transform['config']['elements'] = INPUT_TABLES[(
str(instance), str(database), str(table))]
elements = INPUT_TABLES[(str(instance), str(database), str(table))]
if config.get('query', None):
config['query'].replace('select ',
'SELECT ').replace(' from ', ' FROM ')
columns = set(
''.join(config['query'].split('SELECT ')[1:]).split(
' FROM', maxsplit=1)[0])
if columns != {'*'}:
elements = [{
column: element[column]
for column in element if column in columns
} for element in elements]
transform['config']['elements'] = elements

return test_spec


@YamlExamplesTestSuite.register_test_preprocessor(
['test_bigtable_enrichment_yaml', 'test_enrich_spanner_with_bigquery_yaml'])
def _enrichment_test_preprocessor(
test_spec: dict, expected: List[str], env: TestEnvironment):
if pipeline := test_spec.get('pipeline', None):
for transform in pipeline.get('transforms', []):
if transform.get('type', '').startswith('Enrichment'):
transform['type'] = 'TestEnrichment'

return test_spec


INPUT_FILES = {'products.csv': products_csv()}
INPUT_TABLES = {('shipment-test', 'shipment', 'shipments'): spanner_data()}
INPUT_TABLES = {
('shipment-test', 'shipment', 'shipments'): spanner_shipments_data(),
('orders-test', 'order-database', 'orders'): spanner_orders_data(),
('BigTable', 'beam-test', 'bigtable-enrichment-test'): bigtable_data(),
('BigQuery', 'ALL_TEST', 'customers'): bigquery_data()
}

YAML_DOCS_DIR = os.path.join(os.path.dirname(__file__))
ExamplesTest = YamlExamplesTestSuite(
Expand All @@ -292,6 +440,10 @@ def _spanner_io_read_test_preprocessor(
'IOExamplesTest', os.path.join(YAML_DOCS_DIR,
'../transforms/io/*.yaml')).run()

MLTest = YamlExamplesTestSuite(
'MLExamplesTest', os.path.join(YAML_DOCS_DIR,
'../transforms/ml/*.yaml')).run()

if __name__ == '__main__':
logging.getLogger().setLevel(logging.INFO)
unittest.main()
Original file line number Diff line number Diff line change
Expand Up @@ -66,24 +66,25 @@ pipeline:
callable: 'lambda x: x.order_amount'
output_type: integer

# Step 4: Filter orders with amount greater than 100
# Step 4: Filter orders with amount greater than 110
- type: Filter
name: FilterHighValueOrders
input: MapEnrichedValues
config:
keep: "order_amount > 100"
keep: "order_amount > 110"
language: "python"


# Step 6: Write processed order to another spanner table
# Note: Make sure to replace $VARS with your values.
- type: WriteToSpanner
name: WriteProcessedOrders
input: FilterHighValueOrders
config:
project_id: 'apache-beam-testing'
instance_id: 'orders-test'
database_id: 'order-database'
table_id: 'orders_with_customers'
project_id: '$PROJECT'
instance_id: '$INSTANCE'
database_id: '$DATABASE'
table_id: '$TABLE'
error_handling:
output: my_error_output

Expand All @@ -94,6 +95,8 @@ pipeline:
config:
path: 'errors.json'

options:
yaml_experimental_features: Enrichment

# Expected:
# Row(customer_id=1001, customer_name='Alice', customer_email='[email protected]', product_id=2001, order_date='24-03-24', order_amount=150)
# Row(customer_id=1003, customer_name='Claire', customer_email='[email protected]', product_id=2003, order_date='7-05-24', order_amount=110)

0 comments on commit 5988a23

Please sign in to comment.