Skip to content

Commit

Permalink
Add base VectorDatabaseTransform.
Browse files Browse the repository at this point in the history
  • Loading branch information
claudevdm committed Dec 18, 2024
1 parent d34409e commit 55a3efb
Show file tree
Hide file tree
Showing 3 changed files with 164 additions and 0 deletions.
20 changes: 20 additions & 0 deletions sdks/python/apache_beam/ml/rag/ingestion/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""Vector storage ingestion components for RAG pipelines.
This module provides components for storing vectors in RAG pipelines.
"""
66 changes: 66 additions & 0 deletions sdks/python/apache_beam/ml/rag/ingestion/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from abc import ABC, abstractmethod

import apache_beam as beam
from apache_beam.ml.rag.types import Chunk


class VectorDatabaseConfig(ABC):
  """Interface for a vector database configuration.

  A concrete subclass holds the database-specific settings and is
  responsible for building the matching write transform.
  """
  @abstractmethod
  def create_write_transform(self) -> beam.PTransform:
    """Build the PTransform used to write to the vector database.

    Returns:
      A PTransform that writes embeddings to the configured database.
    """


class VectorDatabaseWriteTransform(beam.PTransform):
  """Generic transform for writing to a vector database.

  The transform is database-agnostic: it asks the supplied
  VectorDatabaseConfig for a write transform and applies that transform
  to its input.
  """
  def __init__(self, database_config: VectorDatabaseConfig):
    """Initialize the transform with a database config.

    Args:
      database_config: Configuration for the target vector database.

    Raises:
      TypeError: If database_config is not a VectorDatabaseConfig instance.
    """
    # Run the PTransform base initializer before setting our own state.
    super().__init__()
    if not isinstance(database_config, VectorDatabaseConfig):
      raise TypeError(
          "database_config must be VectorDatabaseConfig, "
          f"got {type(database_config)}")
    self.database_config = database_config

  def expand(self, pcoll: beam.PCollection[Chunk]):
    """Create and apply the database-specific write transform.

    Args:
      pcoll: PCollection of Chunks with embeddings to write.

    Returns:
      Whatever the write transform obtained from the config returns when
      applied to pcoll.
    """
    # The config decides which concrete transform performs the write.
    write_transform = self.database_config.create_write_transform()
    return pcoll | write_transform
78 changes: 78 additions & 0 deletions sdks/python/apache_beam/ml/rag/ingestion/base_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest
import apache_beam as beam
from apache_beam.ml.rag.types import Chunk, Embedding, Content
from apache_beam.testing.test_pipeline import TestPipeline
from apache_beam.testing.util import assert_that, equal_to

from apache_beam.ml.rag.ingestion.base import (
VectorDatabaseConfig, VectorDatabaseWriteTransform)


class MockWriteTransform(beam.PTransform):
  """Stand-in write transform that emits every input element unchanged."""
  def expand(self, pcoll):
    """Apply an identity Map to pcoll and return the result."""
    passthrough = beam.Map(lambda x: x)
    return pcoll | passthrough


class MockDatabaseConfig(VectorDatabaseConfig):
  """Test double for VectorDatabaseConfig backed by MockWriteTransform."""
  def __init__(self):
    # Built once here so that every create_write_transform() call returns
    # this same instance.
    mock_transform = MockWriteTransform()
    self.write_transform = mock_transform

  def create_write_transform(self) -> beam.PTransform:
    """Return the transform instance created in the constructor."""
    transform = self.write_transform
    return transform


class VectorDatabaseBaseTest(unittest.TestCase):
  """Unit tests for VectorDatabaseWriteTransform using mock collaborators."""
  def test_write_transform_creation(self):
    """The transform keeps the config it was constructed with."""
    mock_config = MockDatabaseConfig()
    write = VectorDatabaseWriteTransform(mock_config)
    self.assertEqual(write.database_config, mock_config)

  def test_pipeline_integration(self):
    """Chunks sent through the transform come out of the mock unchanged."""
    # (id, text, dense embedding) for each fixture chunk.
    rows = (
        ("1", "foo", [0.1, 0.2]),
        ("2", "bar", [0.3, 0.4]),
    )
    test_chunks = [
        Chunk(
            content=Content(text=text),
            id=chunk_id,
            embedding=Embedding(dense_embedding=vector))
        for chunk_id, text, vector in rows
    ]

    with TestPipeline() as pipeline:
      source = pipeline | beam.Create(test_chunks)
      written = source | VectorDatabaseWriteTransform(MockDatabaseConfig())

      # The mock passes elements through, so output must equal input.
      assert_that(written, equal_to(test_chunks))

  def test_invalid_config(self):
    """A non-config argument is rejected with TypeError."""
    with self.assertRaises(TypeError):
      VectorDatabaseWriteTransform(None)


# Run the test suite only when this file is executed directly as a script.
if __name__ == "__main__":
  unittest.main()

0 comments on commit 55a3efb

Please sign in to comment.