Skip to content

Commit

Permalink
[Python] Add an option to retain timestamp returned from the BigTable row (#30088)
Browse files Browse the repository at this point in the history

* retain timestamp returned from bigtable row

* make it optional to retain timestamp

* fix import order, remove trigger file

* ignore arg type in isinstance
  • Loading branch information
riteshghorse authored Jan 24, 2024
1 parent 2721414 commit 674fe77
Show file tree
Hide file tree
Showing 2 changed files with 70 additions and 17 deletions.
24 changes: 18 additions & 6 deletions sdks/python/apache_beam/transforms/enrichment_handlers/bigtable.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
from apache_beam.transforms.enrichment import EnrichmentSourceHandler

__all__ = [
'EnrichWithBigTable',
'BigTableEnrichmentHandler',
'ExceptionLevel',
]

Expand All @@ -52,8 +52,8 @@ class ExceptionLevel(Enum):
QUIET = 2


class EnrichWithBigTable(EnrichmentSourceHandler[beam.Row, beam.Row]):
"""EnrichWithBigTable is a handler for
class BigTableEnrichmentHandler(EnrichmentSourceHandler[beam.Row, beam.Row]):
"""BigTableEnrichmentHandler is a handler for
:class:`apache_beam.transforms.enrichment.Enrichment` transform to interact
with GCP BigTable.
Expand All @@ -74,6 +74,10 @@ class EnrichWithBigTable(EnrichmentSourceHandler[beam.Row, beam.Row]):
``apache_beam.transforms.enrichment_handlers.bigtable.ExceptionLevel``
to set the level when an empty row is returned from the BigTable query.
Defaults to ``ExceptionLevel.WARN``.
include_timestamp (bool): If enabled, every column maps to a list of
`(value, timestamp)` tuples, one per cell returned for the `row_key`.
Defaults to `False` - only the latest value (the first returned cell),
without the timestamp, is returned for each column.
"""
def __init__(
self,
Expand All @@ -82,9 +86,11 @@ def __init__(
table_id: str,
row_key: str,
row_filter: Optional[RowFilter] = CellsColumnLimitFilter(1),
*,
app_profile_id: str = None, # type: ignore[assignment]
encoding: str = 'utf-8',
exception_level: ExceptionLevel = ExceptionLevel.WARN,
include_timestamp: bool = False,
):
self._project_id = project_id
self._instance_id = instance_id
Expand All @@ -94,6 +100,7 @@ def __init__(
self._app_profile_id = app_profile_id
self._encoding = encoding
self._exception_level = exception_level
self._include_timestamp = include_timestamp

def __enter__(self):
"""connect to the Google BigTable cluster."""
Expand Down Expand Up @@ -122,9 +129,14 @@ def __call__(self, request: beam.Row, *args, **kwargs):
if row:
for cf_id, cf_v in row.cells.items():
response_dict[cf_id] = {}
for k, v in cf_v.items():
response_dict[cf_id][k.decode(self._encoding)] = \
v[0].value.decode(self._encoding)
for col_id, col_v in cf_v.items():
if self._include_timestamp:
response_dict[cf_id][col_id.decode(self._encoding)] = [
(v.value.decode(self._encoding), v.timestamp) for v in col_v
]
else:
response_dict[cf_id][col_id.decode(
self._encoding)] = col_v[0].value.decode(self._encoding)
elif self._exception_level == ExceptionLevel.WARN:
_LOGGER.warning(
'no matching row found for row_key: %s '
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from typing import Dict
from typing import List
from typing import NamedTuple
from typing import Tuple

import pytest

Expand All @@ -33,7 +34,7 @@
from google.cloud.bigtable import Client
from google.cloud.bigtable.row_filters import ColumnRangeFilter
from apache_beam.transforms.enrichment import Enrichment
from apache_beam.transforms.enrichment_handlers.bigtable import EnrichWithBigTable
from apache_beam.transforms.enrichment_handlers.bigtable import BigTableEnrichmentHandler
from apache_beam.transforms.enrichment_handlers.bigtable import ExceptionLevel
except ImportError:
raise unittest.SkipTest('GCP BigTable dependencies are not installed.')
Expand All @@ -46,10 +47,13 @@ def __init__(
self,
n_fields: int,
fields: List[str],
enriched_fields: Dict[str, List[str]]):
enriched_fields: Dict[str, List[str]],
include_timestamp: bool = False,
):
self.n_fields = n_fields
self._fields = fields
self._enriched_fields = enriched_fields
self._include_timestamp = include_timestamp

def process(self, element: beam.Row, *args, **kwargs):
element_dict = element.as_dict()
Expand All @@ -62,12 +66,22 @@ def process(self, element: beam.Row, *args, **kwargs):
raise BeamAssertException(f"Expected a not None field: {field}")

for column_family, columns in self._enriched_fields.items():
if (len(element_dict[column_family]) != len(columns) or
not all(key in element_dict[column_family] for key in columns)):
if len(element_dict[column_family]) != len(columns):
raise BeamAssertException(
"Response from bigtable should contain a %s column_family with "
"%s keys." % (column_family, columns))

for key in columns:
if key not in element_dict[column_family]:
raise BeamAssertException(
"Response from bigtable should contain a %s column_family with "
"%s columns." % (column_family, columns))
if (self._include_timestamp and
not isinstance(element_dict[column_family][key][0], Tuple)): # type: ignore[arg-type]
raise BeamAssertException(
"Response from bigtable should contain timestamp associated with "
"its value.")


class _Currency(NamedTuple):
s_id: int
Expand Down Expand Up @@ -157,7 +171,7 @@ def test_enrichment_with_bigtable(self):
expected_enriched_fields = {
'product': ['product_id', 'product_name', 'product_stock'],
}
bigtable = EnrichWithBigTable(
bigtable = BigTableEnrichmentHandler(
project_id=self.project_id,
instance_id=self.instance_id,
table_id=self.table_id,
Expand All @@ -182,7 +196,7 @@ def test_enrichment_with_bigtable_row_filter(self):
}
start_column = 'product_name'.encode()
column_filter = ColumnRangeFilter(self.column_family_id, start_column)
bigtable = EnrichWithBigTable(
bigtable = BigTableEnrichmentHandler(
project_id=self.project_id,
instance_id=self.instance_id,
table_id=self.table_id,
Expand All @@ -204,7 +218,7 @@ def test_enrichment_with_bigtable_no_enrichment(self):
# won't be added. Hence, the response is same as the request.
expected_fields = ['sale_id', 'customer_id', 'product_id', 'quantity']
expected_enriched_fields = {}
bigtable = EnrichWithBigTable(
bigtable = BigTableEnrichmentHandler(
project_id=self.project_id,
instance_id=self.instance_id,
table_id=self.table_id,
Expand All @@ -227,7 +241,7 @@ def test_enrichment_with_bigtable_bad_row_filter(self):
# column names then all columns in that column_family are returned.
start_column = 'car_name'.encode()
column_filter = ColumnRangeFilter('car_name', start_column)
bigtable = EnrichWithBigTable(
bigtable = BigTableEnrichmentHandler(
project_id=self.project_id,
instance_id=self.instance_id,
table_id=self.table_id,
Expand All @@ -245,7 +259,7 @@ def test_enrichment_with_bigtable_bad_row_filter(self):
def test_enrichment_with_bigtable_raises_key_error(self):
"""raises a `KeyError` when the row_key doesn't exist in
the input PCollection."""
bigtable = EnrichWithBigTable(
bigtable = BigTableEnrichmentHandler(
project_id=self.project_id,
instance_id=self.instance_id,
table_id=self.table_id,
Expand All @@ -262,7 +276,7 @@ def test_enrichment_with_bigtable_raises_key_error(self):
def test_enrichment_with_bigtable_raises_not_found(self):
"""raises a `NotFound` exception when the GCP BigTable Cluster
doesn't exist."""
bigtable = EnrichWithBigTable(
bigtable = BigTableEnrichmentHandler(
project_id=self.project_id,
instance_id=self.instance_id,
table_id='invalid_table',
Expand All @@ -279,7 +293,7 @@ def test_enrichment_with_bigtable_raises_not_found(self):
def test_enrichment_with_bigtable_exception_level(self):
"""raises a `ValueError` exception when the GCP BigTable query returns
an empty row."""
bigtable = EnrichWithBigTable(
bigtable = BigTableEnrichmentHandler(
project_id=self.project_id,
instance_id=self.instance_id,
table_id=self.table_id,
Expand All @@ -295,6 +309,33 @@ def test_enrichment_with_bigtable_exception_level(self):
res = test_pipeline.run()
res.wait_until_finish()

def test_enrichment_with_bigtable_with_timestamp(self):
"""test whether the `(value,timestamp)` is returned when the
`include_timestamp` is enabled."""
expected_fields = [
'sale_id', 'customer_id', 'product_id', 'quantity', 'product'
]
expected_enriched_fields = {
'product': ['product_id', 'product_name', 'product_stock'],
}
bigtable = BigTableEnrichmentHandler(
project_id=self.project_id,
instance_id=self.instance_id,
table_id=self.table_id,
row_key=self.row_key,
include_timestamp=True)
with TestPipeline(is_integration_test=True) as test_pipeline:
_ = (
test_pipeline
| "Create" >> beam.Create(self.req)
| "Enrich W/ BigTable" >> Enrichment(bigtable)
| "Validate Response" >> beam.ParDo(
ValidateResponse(
len(expected_fields),
expected_fields,
expected_enriched_fields,
include_timestamp=True)))


if __name__ == '__main__':
unittest.main()

0 comments on commit 674fe77

Please sign in to comment.