forked from apache/beam
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #1 from Polber/spanner-enrichment-use-case
fix examples_test
- Loading branch information
Showing
3 changed files
with
169 additions
and
14 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -21,9 +21,11 @@ | |
import os | ||
import random | ||
import unittest | ||
from typing import Any | ||
from typing import Callable | ||
from typing import Dict | ||
from typing import List | ||
from typing import Optional | ||
from typing import Union | ||
from unittest import mock | ||
|
||
|
@@ -34,11 +36,63 @@ | |
from apache_beam.examples.snippets.util import assert_matches_stdout | ||
from apache_beam.options.pipeline_options import PipelineOptions | ||
from apache_beam.testing.test_pipeline import TestPipeline | ||
from apache_beam.yaml import yaml_provider | ||
from apache_beam.yaml import yaml_transform | ||
from apache_beam.yaml.readme_test import TestEnvironment | ||
from apache_beam.yaml.readme_test import replace_recursive | ||
|
||
|
||
# Used to simulate Enrichment transform during tests | ||
# The GitHub action that invokes these tests does not | ||
# have gcp dependencies installed which is a prerequisite | ||
# to apache_beam.transforms.enrichment.Enrichment as a top-level | ||
# import. | ||
@beam.ptransform.ptransform_fn | ||
def test_enrichment( | ||
pcoll, | ||
enrichment_handler: str, | ||
handler_config: Dict[str, Any], | ||
timeout: Optional[float] = 30): | ||
if enrichment_handler == 'BigTable': | ||
row_key = handler_config['row_key'] | ||
bt_data = INPUT_TABLES[( | ||
'BigTable', handler_config['instance_id'], handler_config['table_id'])] | ||
products = {str(data[row_key]): data for data in bt_data} | ||
|
||
def _fn(row): | ||
left = row._asdict() | ||
right = products[str(row[row_key])] | ||
left['product'] = left.get('product', None) or right | ||
return beam.Row(**row) | ||
elif enrichment_handler == 'BigQuery': | ||
row_key = handler_config['fields'] | ||
dataset, table = handler_config['table_name'].split('.')[-2:] | ||
bq_data = INPUT_TABLES[('BigQuery', str(dataset), str(table))] | ||
products = { | ||
tuple(str(data[key]) for key in row_key): data | ||
for data in bq_data | ||
} | ||
|
||
def _fn(row): | ||
left = row._asdict() | ||
right = products[tuple(str(left[k]) for k in row_key)] | ||
row = { | ||
key: left.get(key, None) or right[key] | ||
for key in {*left.keys(), *right.keys()} | ||
} | ||
return beam.Row(**row) | ||
|
||
else: | ||
raise ValueError(f'{enrichment_handler} is not a valid enrichment_handler.') | ||
|
||
return pcoll | beam.Map(_fn) | ||
|
||
|
||
TEST_PROVIDERS = { | ||
'TestEnrichment': test_enrichment, | ||
} | ||
|
||
|
||
def check_output(expected: List[str]): | ||
def _check_inner(actual: List[PCollection[str]]): | ||
formatted_actual = actual | beam.Flatten() | beam.Map( | ||
|
@@ -59,7 +113,31 @@ def products_csv(): | |
]) | ||
|
||
|
||
def spanner_data(): | ||
def spanner_orders_data(): | ||
return [{ | ||
'order_id': 1, | ||
'customer_id': 1001, | ||
'product_id': 2001, | ||
'order_date': '24-03-24', | ||
'order_amount': 150, | ||
}, | ||
{ | ||
'order_id': 2, | ||
'customer_id': 1002, | ||
'product_id': 2002, | ||
'order_date': '19-04-24', | ||
'order_amount': 90, | ||
}, | ||
{ | ||
'order_id': 3, | ||
'customer_id': 1003, | ||
'product_id': 2003, | ||
'order_date': '7-05-24', | ||
'order_amount': 110, | ||
}] | ||
|
||
|
||
def spanner_shipments_data(): | ||
return [{ | ||
'shipment_id': 'S1', | ||
'customer_id': 'C1', | ||
|
@@ -110,6 +188,44 @@ def spanner_data(): | |
}] | ||
|
||
|
||
def bigtable_data(): | ||
return [{ | ||
'product_id': '1', 'product_name': 'pixel 5', 'product_stock': '2' | ||
}, { | ||
'product_id': '2', 'product_name': 'pixel 6', 'product_stock': '4' | ||
}, { | ||
'product_id': '3', 'product_name': 'pixel 7', 'product_stock': '20' | ||
}, { | ||
'product_id': '4', 'product_name': 'pixel 8', 'product_stock': '10' | ||
}, { | ||
'product_id': '5', 'product_name': 'pixel 11', 'product_stock': '3' | ||
}, { | ||
'product_id': '6', 'product_name': 'pixel 12', 'product_stock': '7' | ||
}, { | ||
'product_id': '7', 'product_name': 'pixel 13', 'product_stock': '8' | ||
}, { | ||
'product_id': '8', 'product_name': 'pixel 14', 'product_stock': '3' | ||
}] | ||
|
||
|
||
def bigquery_data(): | ||
return [{ | ||
'customer_id': 1001, | ||
'customer_name': 'Alice', | ||
'customer_email': '[email protected]' | ||
}, | ||
{ | ||
'customer_id': 1002, | ||
'customer_name': 'Bob', | ||
'customer_email': '[email protected]' | ||
}, | ||
{ | ||
'customer_id': 1003, | ||
'customer_name': 'Claire', | ||
'customer_email': '[email protected]' | ||
}] | ||
|
||
|
||
def create_test_method( | ||
pipeline_spec_file: str, | ||
custom_preprocessors: List[Callable[..., Union[Dict, List]]]): | ||
|
@@ -135,7 +251,11 @@ def test_yaml_example(self): | |
pickle_library='cloudpickle', | ||
**yaml_transform.SafeLineLoader.strip_metadata(pipeline_spec.get( | ||
'options', {})))) as p: | ||
actual = [yaml_transform.expand_pipeline(p, pipeline_spec)] | ||
actual = [ | ||
yaml_transform.expand_pipeline( | ||
p, | ||
pipeline_spec, [yaml_provider.InlineProvider(TEST_PROVIDERS)]) | ||
] | ||
if not actual[0]: | ||
actual = list(p.transforms_stack[0].parts[-1].outputs.values()) | ||
for transform in p.transforms_stack[0].parts[:-1]: | ||
|
@@ -214,7 +334,6 @@ def _wordcount_test_preprocessor( | |
'test_simple_filter_and_combine_yaml', | ||
'test_spanner_read_yaml', | ||
'test_spanner_write_yaml', | ||
'test_bigtable_enrichment_yaml', | ||
'test_enrich_spanner_with_bigquery_yaml' | ||
]) | ||
def _io_write_test_preprocessor( | ||
|
@@ -251,7 +370,8 @@ def _file_io_read_test_preprocessor( | |
return test_spec | ||
|
||
|
||
@YamlExamplesTestSuite.register_test_preprocessor(['test_spanner_read_yaml']) | ||
@YamlExamplesTestSuite.register_test_preprocessor( | ||
['test_spanner_read_yaml', 'test_enrich_spanner_with_bigquery_yaml']) | ||
def _spanner_io_read_test_preprocessor( | ||
test_spec: dict, expected: List[str], env: TestEnvironment): | ||
|
||
|
@@ -267,14 +387,42 @@ def _spanner_io_read_test_preprocessor( | |
k: v | ||
for k, v in config.items() if k.startswith('__') | ||
} | ||
transform['config']['elements'] = INPUT_TABLES[( | ||
str(instance), str(database), str(table))] | ||
elements = INPUT_TABLES[(str(instance), str(database), str(table))] | ||
if config.get('query', None): | ||
config['query'].replace('select ', | ||
'SELECT ').replace(' from ', ' FROM ') | ||
columns = set( | ||
''.join(config['query'].split('SELECT ')[1:]).split( | ||
' FROM', maxsplit=1)[0]) | ||
if columns != {'*'}: | ||
elements = [{ | ||
column: element[column] | ||
for column in element if column in columns | ||
} for element in elements] | ||
transform['config']['elements'] = elements | ||
|
||
return test_spec | ||
|
||
|
||
@YamlExamplesTestSuite.register_test_preprocessor( | ||
['test_bigtable_enrichment_yaml', 'test_enrich_spanner_with_bigquery_yaml']) | ||
def _enrichment_test_preprocessor( | ||
test_spec: dict, expected: List[str], env: TestEnvironment): | ||
if pipeline := test_spec.get('pipeline', None): | ||
for transform in pipeline.get('transforms', []): | ||
if transform.get('type', '').startswith('Enrichment'): | ||
transform['type'] = 'TestEnrichment' | ||
|
||
return test_spec | ||
|
||
|
||
INPUT_FILES = {'products.csv': products_csv()} | ||
INPUT_TABLES = {('shipment-test', 'shipment', 'shipments'): spanner_data()} | ||
INPUT_TABLES = { | ||
('shipment-test', 'shipment', 'shipments'): spanner_shipments_data(), | ||
('orders-test', 'order-database', 'orders'): spanner_orders_data(), | ||
('BigTable', 'beam-test', 'bigtable-enrichment-test'): bigtable_data(), | ||
('BigQuery', 'ALL_TEST', 'customers'): bigquery_data() | ||
} | ||
|
||
YAML_DOCS_DIR = os.path.join(os.path.dirname(__file__)) | ||
ExamplesTest = YamlExamplesTestSuite( | ||
|
@@ -292,6 +440,10 @@ def _spanner_io_read_test_preprocessor( | |
'IOExamplesTest', os.path.join(YAML_DOCS_DIR, | ||
'../transforms/io/*.yaml')).run() | ||
|
||
MLTest = YamlExamplesTestSuite( | ||
'MLExamplesTest', os.path.join(YAML_DOCS_DIR, | ||
'../transforms/ml/*.yaml')).run() | ||
|
||
if __name__ == '__main__': | ||
logging.getLogger().setLevel(logging.INFO) | ||
unittest.main() |
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -66,24 +66,25 @@ pipeline: | |
callable: 'lambda x: x.order_amount' | ||
output_type: integer | ||
|
||
# Step 4: Filter orders with amount greater than 100 | ||
# Step 4: Filter orders with amount greater than 110 | ||
- type: Filter | ||
name: FilterHighValueOrders | ||
input: MapEnrichedValues | ||
config: | ||
keep: "order_amount > 100" | ||
keep: "order_amount > 110" | ||
language: "python" | ||
|
||
|
||
# Step 6: Write processed order to another spanner table | ||
# Note: Make sure to replace $VARS with your values. | ||
- type: WriteToSpanner | ||
name: WriteProcessedOrders | ||
input: FilterHighValueOrders | ||
config: | ||
project_id: 'apache-beam-testing' | ||
instance_id: 'orders-test' | ||
database_id: 'order-database' | ||
table_id: 'orders_with_customers' | ||
project_id: '$PROJECT' | ||
instance_id: '$INSTANCE' | ||
database_id: '$DATABASE' | ||
table_id: '$TABLE' | ||
error_handling: | ||
output: my_error_output | ||
|
||
|
@@ -94,6 +95,8 @@ pipeline: | |
config: | ||
path: 'errors.json' | ||
|
||
options: | ||
yaml_experimental_features: Enrichment | ||
|
||
# Expected: | ||
# Row(customer_id=1001, customer_name='Alice', customer_email='[email protected]', product_id=2001, order_date='24-03-24', order_amount=150) | ||
# Row(customer_id=1003, customer_name='Claire', customer_email='[email protected]', product_id=2003, order_date='7-05-24', order_amount=110) |