diff --git a/.github/trigger_files/beam_PostCommit_Python_Xlang_Gcp_Dataflow.json b/.github/trigger_files/beam_PostCommit_Python_Xlang_Gcp_Dataflow.json index b26833333238..c537844dc84a 100644 --- a/.github/trigger_files/beam_PostCommit_Python_Xlang_Gcp_Dataflow.json +++ b/.github/trigger_files/beam_PostCommit_Python_Xlang_Gcp_Dataflow.json @@ -1,4 +1,4 @@ { "comment": "Modify this file in a trivial way to cause this test suite to run", - "modification": 2 + "modification": 3 } diff --git a/.github/trigger_files/beam_PostCommit_Python_Xlang_Gcp_Direct.json b/.github/trigger_files/beam_PostCommit_Python_Xlang_Gcp_Direct.json index b26833333238..c537844dc84a 100644 --- a/.github/trigger_files/beam_PostCommit_Python_Xlang_Gcp_Direct.json +++ b/.github/trigger_files/beam_PostCommit_Python_Xlang_Gcp_Direct.json @@ -1,4 +1,4 @@ { "comment": "Modify this file in a trivial way to cause this test suite to run", - "modification": 2 + "modification": 3 } diff --git a/sdks/python/apache_beam/ml/rag/ingestion/base.py b/sdks/python/apache_beam/ml/rag/ingestion/base.py index fb9b2ac475dd..3187cdd37bc1 100644 --- a/sdks/python/apache_beam/ml/rag/ingestion/base.py +++ b/sdks/python/apache_beam/ml/rag/ingestion/base.py @@ -33,16 +33,14 @@ class VectorDatabaseWriteConfig(ABC): 3. Transform handles converting Chunks to database-specific format Example implementation: - ```python - class BigQueryVectorWriterConfig(VectorDatabaseWriteConfig): - def __init__(self, table: str): - self.embedding_column = embedding_column - - def create_write_transform(self): - return beam.io.WriteToBigQuery( - table=self.table - ) - ``` + >>> class BigQueryVectorWriterConfig(VectorDatabaseWriteConfig): + ... def __init__(self, table: str): + ... self.embedding_column = embedding_column + ... + ... def create_write_transform(self): + ... return beam.io.WriteToBigQuery( + ... table=self.table + ... ) """ @abstractmethod def create_write_transform(self) -> beam.PTransform: @@ -67,16 +65,14 @@ class VectorDatabaseWriteTransform(beam.PTransform): the database-specific write transform. Example usage: - ```python - config = BigQueryVectorConfig( - table='project.dataset.embeddings', - embedding_column='embedding' - ) - - with beam.Pipeline() as p: - chunks = p | beam.Create([...]) # PCollection[Chunk] - chunks | VectorDatabaseWriteTransform(config) - ``` + >>> config = BigQueryVectorConfig( + ... table='project.dataset.embeddings', + ... embedding_column='embedding' + ... ) + >>> + >>> with beam.Pipeline() as p: + ... chunks = p | beam.Create([...]) # PCollection[Chunk] + ... chunks | VectorDatabaseWriteTransform(config) Args: database_config: Configuration for the target vector database. diff --git a/sdks/python/apache_beam/ml/rag/ingestion/base_test.py b/sdks/python/apache_beam/ml/rag/ingestion/base_test.py index 764cc161aaa5..f21b3e644b82 100644 --- a/sdks/python/apache_beam/ml/rag/ingestion/base_test.py +++ b/sdks/python/apache_beam/ml/rag/ingestion/base_test.py @@ -15,13 +15,13 @@ # limitations under the License. 
import unittest -import apache_beam as beam -from apache_beam.ml.rag.types import Chunk, Embedding, Content -from apache_beam.testing.test_pipeline import TestPipeline -from apache_beam.testing.util import assert_that, equal_to +import apache_beam as beam from apache_beam.ml.rag.ingestion.base import ( VectorDatabaseWriteConfig, VectorDatabaseWriteTransform) +from apache_beam.ml.rag.types import Chunk, Content, Embedding +from apache_beam.testing.test_pipeline import TestPipeline +from apache_beam.testing.util import assert_that, equal_to class MockWriteTransform(beam.PTransform): diff --git a/sdks/python/apache_beam/ml/rag/ingestion/bigquery.py b/sdks/python/apache_beam/ml/rag/ingestion/bigquery.py index c2173995e2e5..9e7bb84f114f 100644 --- a/sdks/python/apache_beam/ml/rag/ingestion/bigquery.py +++ b/sdks/python/apache_beam/ml/rag/ingestion/bigquery.py @@ -14,16 +14,16 @@ # See the License for the specific language governing permissions and # limitations under the License. -from dataclasses import dataclass - -from typing import Optional, List, Dict, Any from collections.abc import Callable +from dataclasses import dataclass +from typing import Any, Dict, List, Optional import apache_beam as beam +from apache_beam.io.gcp.bigquery_tools import ( + beam_row_from_dict, get_beam_typehints_from_tableschema) from apache_beam.ml.rag.ingestion.base import VectorDatabaseWriteConfig from apache_beam.ml.rag.types import Chunk from apache_beam.typehints.row_type import RowTypeConstraint -from apache_beam.io.gcp.bigquery_tools import beam_row_from_dict, get_beam_typehints_from_tableschema ChunkToDictFn = Callable[[Chunk], Dict[str, any]] @@ -39,23 +39,23 @@ class SchemaConfig: Attributes: schema: BigQuery TableSchema dict defining the table structure. Example: - { - 'fields': [ - {'name': 'id', 'type': 'STRING'}, - {'name': 'embedding', 'type': 'FLOAT64', 'mode': 'REPEATED'}, - {'name': 'custom_field', 'type': 'STRING'} - ] - } + >>> { + ... 'fields': [ + ... {'name': 'id', 'type': 'STRING'}, + ... {'name': 'embedding', 'type': 'FLOAT64', 'mode': 'REPEATED'}, + ... {'name': 'custom_field', 'type': 'STRING'} + ... ] + ... } chunk_to_dict_fn: Function that converts a Chunk to a dict matching the schema. Takes a Chunk and returns Dict[str, Any] with keys matching schema fields. Example: - def chunk_to_dict(chunk: Chunk) -> Dict[str, Any]: - return { - 'id': chunk.id, - 'embedding': chunk.embedding.dense_embedding, - 'custom_field': chunk.metadata.get('custom_field') - } + >>> def chunk_to_dict(chunk: Chunk) -> Dict[str, Any]: + ... return { + ... 'id': chunk.id, + ... 'embedding': chunk.embedding.dense_embedding, + ... 'custom_field': chunk.metadata.get('custom_field') + ... } """ schema: Dict chunk_to_dict_fn: ChunkToDictFn @@ -66,7 +66,7 @@ def __init__( self, write_config: Dict[str, Any], *, # Force keyword arguments - schema_config: Optional[SchemaConfig] + schema_config: Optional[SchemaConfig] = None ): """Configuration for writing vectors to BigQuery using managed transforms. @@ -74,32 +74,28 @@ def __init__( custom schemas through SchemaConfig. Example with default schema: - ```python - config = BigQueryVectorWriterConfig( - write_config={'table': 'project.dataset.embeddings'}) - ``` + >>> config = BigQueryVectorWriterConfig( + ... 
write_config={'table': 'project.dataset.embeddings'}) Example with custom schema: - ```python - schema_config = SchemaConfig( - schema={ - 'fields': [ - {'name': 'id', 'type': 'STRING'}, - {'name': 'embedding', 'type': 'FLOAT64', 'mode': 'REPEATED'}, - {'name': 'source_url', 'type': 'STRING'} - ] - }, - chunk_to_dict_fn=lambda chunk: { - 'id': chunk.id, - 'embedding': chunk.embedding.dense_embedding, - 'source_url': chunk.metadata.get('url') - } - ) - config = BigQueryVectorWriterConfig( - write_config={'table': 'project.dataset.embeddings'}, - schema_config=schema_config - ) - ``` + >>> schema_config = SchemaConfig( + ... schema={ + ... 'fields': [ + ... {'name': 'id', 'type': 'STRING'}, + ... {'name': 'embedding', 'type': 'FLOAT64', 'mode': 'REPEATED'}, + ... {'name': 'source_url', 'type': 'STRING'} + ... ] + ... }, + ... chunk_to_dict_fn=lambda chunk: { + ... 'id': chunk.id, + ... 'embedding': chunk.embedding.dense_embedding, + ... 'source_url': chunk.metadata.get('url') + ... } + ... ) + >>> config = BigQueryVectorWriterConfig( + ... write_config={'table': 'project.dataset.embeddings'}, + ... schema_config=schema_config + ... ) Args: write_config: BigQuery write configuration dict. Must include 'table'. diff --git a/sdks/python/apache_beam/ml/rag/ingestion/bigquery_it_test.py b/sdks/python/apache_beam/ml/rag/ingestion/bigquery_it_test.py new file mode 100644 index 000000000000..8d94162411db --- /dev/null +++ b/sdks/python/apache_beam/ml/rag/ingestion/bigquery_it_test.py @@ -0,0 +1,212 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +import logging +import os +import secrets +import time +import unittest + +import hamcrest as hc +import pytest + +import apache_beam as beam +from apache_beam.io.gcp.bigquery_tools import BigQueryWrapper +from apache_beam.io.gcp.internal.clients import bigquery +from apache_beam.io.gcp.tests.bigquery_matcher import BigqueryFullResultMatcher +from apache_beam.ml.rag.ingestion.bigquery import ( + BigQueryVectorWriterConfig, SchemaConfig) +from apache_beam.ml.rag.types import Chunk, Content, Embedding +from apache_beam.testing.test_pipeline import TestPipeline +from apache_beam.transforms.periodicsequence import PeriodicImpulse + + +@pytest.mark.uses_gcp_java_expansion_service +@unittest.skipUnless( + os.environ.get('EXPANSION_JARS'), + "EXPANSION_JARS environment var is not provided, " + "indicating that jars have not been built") +class BigQueryVectorWriterConfigTest(unittest.TestCase): + BIG_QUERY_DATASET_ID = 'python_rag_bigquery_' + + def setUp(self): + self.test_pipeline = TestPipeline(is_integration_test=True) + self._runner = type(self.test_pipeline.runner).__name__ + self.project = self.test_pipeline.get_option('project') + + self.bigquery_client = BigQueryWrapper() + self.dataset_id = '%s%d%s' % ( + self.BIG_QUERY_DATASET_ID, int(time.time()), secrets.token_hex(3)) + self.bigquery_client.get_or_create_dataset(self.project, self.dataset_id) + _LOGGER = logging.getLogger(__name__) + _LOGGER.info( + "Created dataset %s in project %s", self.dataset_id, self.project) + + def tearDown(self): + request = bigquery.BigqueryDatasetsDeleteRequest( + projectId=self.project, datasetId=self.dataset_id, deleteContents=True) + try: + _LOGGER = logging.getLogger(__name__) + _LOGGER.info( + "Deleting dataset %s in project %s", self.dataset_id, self.project) + self.bigquery_client.client.datasets.Delete(request) + # Failing to delete a dataset should not cause a test failure. 
+ except Exception: + _LOGGER = logging.getLogger(__name__) + _LOGGER.debug( + 'Failed to clean up dataset %s in project %s', + self.dataset_id, + self.project) + + def test_default_schema(self): + table_name = 'python_default_schema_table' + table_id = '{}.{}.{}'.format(self.project, self.dataset_id, table_name) + + config = BigQueryVectorWriterConfig(write_config={'table': table_id}) + chunks = [ + Chunk( + id="1", + embedding=Embedding(dense_embedding=[0.1, 0.2]), + content=Content(text="foo"), + metadata={"a": "b"}), + Chunk( + id="2", + embedding=Embedding(dense_embedding=[0.3, 0.4]), + content=Content(text="bar"), + metadata={"c": "d"}) + ] + + pipeline_verifiers = [ + BigqueryFullResultMatcher( + project=self.project, + query="SELECT id, content, embedding, metadata FROM %s" % table_id, + data=[("1", "foo", [0.1, 0.2], [{ + "key": "a", "value": "b" + }]), ("2", "bar", [0.3, 0.4], [{ + "key": "c", "value": "d" + }])]) + ] + + args = self.test_pipeline.get_full_options_as_args( + on_success_matcher=hc.all_of(*pipeline_verifiers)) + with beam.Pipeline(argv=args) as p: + _ = (p | beam.Create(chunks) | config.create_write_transform()) + + def test_custom_schema(self): + table_name = 'python_custom_schema_table' + table_id = '{}.{}.{}'.format(self.project, self.dataset_id, table_name) + + schema_config = SchemaConfig( + schema={ + 'fields': [{ + 'name': 'id', 'type': 'STRING' + }, + { + 'name': 'embedding', + 'type': 'FLOAT64', + 'mode': 'REPEATED' + }, { + 'name': 'source', 'type': 'STRING' + }] + }, + chunk_to_dict_fn=lambda chunk: { + 'id': chunk.id, + 'embedding': chunk.embedding.dense_embedding, + 'source': chunk.metadata.get('source') + }) + config = BigQueryVectorWriterConfig( + write_config={'table': table_id}, schema_config=schema_config) + chunks = [ + Chunk( + id="1", + embedding=Embedding(dense_embedding=[0.1, 0.2]), + content=Content(text="foo content"), + metadata={"source": "foo"}), + Chunk( + id="2", + embedding=Embedding(dense_embedding=[0.3, 0.4]), + content=Content(text="bar content"), + metadata={"source": "bar"}) + ] + + pipeline_verifiers = [ + BigqueryFullResultMatcher( + project=self.project, + query="SELECT id, embedding, source FROM %s" % table_id, + data=[("1", [0.1, 0.2], "foo"), ("2", [0.3, 0.4], "bar")]) + ] + + args = self.test_pipeline.get_full_options_as_args( + on_success_matcher=hc.all_of(*pipeline_verifiers)) + + with beam.Pipeline(argv=args) as p: + _ = (p | beam.Create(chunks) | config.create_write_transform()) + + def test_streaming_default_schema(self): + self.skip_if_not_dataflow_runner() + + table_name = 'python_streaming_default_schema_table' + table_id = '{}.{}.{}'.format(self.project, self.dataset_id, table_name) + + config = BigQueryVectorWriterConfig(write_config={'table': table_id}) + chunks = [ + Chunk( + id="1", + embedding=Embedding(dense_embedding=[0.1, 0.2]), + content=Content(text="foo"), + metadata={"a": "b"}), + Chunk( + id="2", + embedding=Embedding(dense_embedding=[0.3, 0.4]), + content=Content(text="bar"), + metadata={"c": "d"}) + ] + + pipeline_verifiers = [ + BigqueryFullResultMatcher( + project=self.project, + query="SELECT id, content, embedding, metadata FROM %s" % table_id, + data=[("0", "foo", [0.1, 0.2], [{ + "key": "a", "value": "b" + }]), ("2", "bar", [0.3, 0.4], [{ + "key": "c", "value": "d" + }])]) + ] + args = self.test_pipeline.get_full_options_as_args( + on_success_matcher=hc.all_of(*pipeline_verifiers), + streaming=True, + allow_unsafe_triggers=True) + + with beam.Pipeline(argv=args) as p: + _ = ( + p + | 
PeriodicImpulse(0, 4, 1) + | beam.Map(lambda t: chunks[t]) + | config.create_write_transform()) + + def skip_if_not_dataflow_runner(self): + # skip if dataflow runner is not specified + if not self._runner or "dataflowrunner" not in self._runner.lower(): + self.skipTest( + "Streaming with exactly-once route has the requirement " + "`beam:requirement:pardo:on_window_expiration:v1`, " + "which is currently only supported by the Dataflow runner") + + +if __name__ == '__main__': + logging.getLogger().setLevel(logging.INFO) + unittest.main() diff --git a/sdks/python/test-suites/direct/common.gradle b/sdks/python/test-suites/direct/common.gradle index e290e8003b13..1dd15ecb09f9 100644 --- a/sdks/python/test-suites/direct/common.gradle +++ b/sdks/python/test-suites/direct/common.gradle @@ -447,6 +447,7 @@ project(":sdks:python:test-suites:xlang").ext.xlangTasks.each { taskMetadata -> pythonPipelineOptions: [ "--runner=TestDirectRunner", "--project=${gcpProject}", + "--temp_location=gs://temp-storage-for-end-to-end-tests/temp-it", ], pytestOptions: [ "--capture=no", // print stdout instantly
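The patch above introduces BigQueryVectorWriterConfig and routes it through VectorDatabaseWriteTransform. The snippet below is a minimal end-to-end sketch assembled from the docstring examples and the integration test in this diff; it is illustrative only and not part of the change itself, and the table ID, chunk contents, and metadata values are placeholders.

import apache_beam as beam

from apache_beam.ml.rag.ingestion.base import VectorDatabaseWriteTransform
from apache_beam.ml.rag.ingestion.bigquery import BigQueryVectorWriterConfig
from apache_beam.ml.rag.types import Chunk, Content, Embedding

# With the default schema, id, content, embedding and metadata columns are
# written to the given table. The table ID below is a placeholder.
config = BigQueryVectorWriterConfig(
    write_config={'table': 'my-project.my_dataset.embeddings'})

chunks = [
    Chunk(
        id="1",
        embedding=Embedding(dense_embedding=[0.1, 0.2]),
        content=Content(text="foo"),
        metadata={"source": "doc-1"}),
    Chunk(
        id="2",
        embedding=Embedding(dense_embedding=[0.3, 0.4]),
        content=Content(text="bar"),
        metadata={"source": "doc-2"}),
]

with beam.Pipeline() as p:
  _ = (
      p
      | beam.Create(chunks)
      # VectorDatabaseWriteTransform calls config.create_write_transform()
      # to obtain the database-specific write, as described in base.py.
      | VectorDatabaseWriteTransform(config))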