Skip to content

Commit

Permalink
Updated tests to use globally=False
Browse files Browse the repository at this point in the history
  • Loading branch information
ankane committed Oct 2, 2024
1 parent b88ebed commit 70ae49a
Show file tree
Hide file tree
Showing 4 changed files with 19 additions and 11 deletions.
12 changes: 8 additions & 4 deletions tests/test_peewee.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,31 +169,35 @@ def test_vector_avg(self):
Item.create(embedding=[1, 2, 3])
Item.create(embedding=[4, 5, 6])
avg = Item.select(fn.avg(Item.embedding)).scalar()
assert np.array_equal(avg, np.array([2.5, 3.5, 4.5]))
# does not type cast
assert avg == '[2.5,3.5,4.5]'

def test_vector_sum(self):
sum = Item.select(fn.sum(Item.embedding)).scalar()
assert sum is None
Item.create(embedding=[1, 2, 3])
Item.create(embedding=[4, 5, 6])
sum = Item.select(fn.sum(Item.embedding)).scalar()
assert np.array_equal(sum, np.array([5, 7, 9]))
# does not type cast
assert sum == '[5,7,9]'

def test_halfvec_avg(self):
avg = Item.select(fn.avg(Item.half_embedding)).scalar()
assert avg is None
Item.create(half_embedding=[1, 2, 3])
Item.create(half_embedding=[4, 5, 6])
avg = Item.select(fn.avg(Item.half_embedding)).scalar()
assert avg.to_list() == [2.5, 3.5, 4.5]
# does not type cast
assert avg == '[2.5,3.5,4.5]'

def test_halfvec_sum(self):
sum = Item.select(fn.sum(Item.half_embedding)).scalar()
assert sum is None
Item.create(half_embedding=[1, 2, 3])
Item.create(half_embedding=[4, 5, 6])
sum = Item.select(fn.sum(Item.half_embedding)).scalar()
assert sum.to_list() == [5, 7, 9]
# does not type cast
assert sum == '[5,7,9]'

def test_get_or_create(self):
Item.get_or_create(id=1, defaults={'embedding': [1, 2, 3]})
Expand Down
6 changes: 3 additions & 3 deletions tests/test_psycopg2.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
cur.execute('DROP TABLE IF EXISTS psycopg2_items')
cur.execute('CREATE TABLE psycopg2_items (id bigserial PRIMARY KEY, embedding vector(3), half_embedding halfvec(3), binary_embedding bit(3), sparse_embedding sparsevec(3))')

register_vector(cur)
register_vector(cur, globally=False)


class TestPsycopg2:
Expand Down Expand Up @@ -59,11 +59,11 @@ def test_cursor_factory(self):
for cursor_factory in [DictCursor, RealDictCursor, NamedTupleCursor]:
conn = psycopg2.connect(dbname='pgvector_python_test')
cur = conn.cursor(cursor_factory=cursor_factory)
register_vector(cur)
register_vector(cur, globally=False)
conn.close()

def test_cursor_factory_connection(self):
for cursor_factory in [DictCursor, RealDictCursor, NamedTupleCursor]:
conn = psycopg2.connect(dbname='pgvector_python_test', cursor_factory=cursor_factory)
register_vector(conn)
register_vector(conn, globally=False)
conn.close()
6 changes: 4 additions & 2 deletions tests/test_sqlalchemy.py
Original file line number Diff line number Diff line change
Expand Up @@ -337,7 +337,8 @@ def test_avg(self):
session.add(Item(embedding=[1, 2, 3]))
session.add(Item(embedding=[4, 5, 6]))
avg = session.query(func.avg(Item.embedding)).first()[0]
assert np.array_equal(avg, np.array([2.5, 3.5, 4.5]))
# does not type cast
assert avg == '[2.5,3.5,4.5]'

def test_avg_orm(self):
with Session(engine) as session:
Expand All @@ -346,7 +347,8 @@ def test_avg_orm(self):
session.add(Item(embedding=[1, 2, 3]))
session.add(Item(embedding=[4, 5, 6]))
avg = session.scalars(select(func.avg(Item.embedding))).first()
assert np.array_equal(avg, np.array([2.5, 3.5, 4.5]))
# does not type cast
assert avg == '[2.5,3.5,4.5]'

def test_sum(self):
with Session(engine) as session:
Expand Down
6 changes: 4 additions & 2 deletions tests/test_sqlmodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,8 @@ def test_vector_avg(self):
session.add(Item(embedding=[1, 2, 3]))
session.add(Item(embedding=[4, 5, 6]))
avg = session.exec(select(func.avg(Item.embedding))).first()
assert np.array_equal(avg, np.array([2.5, 3.5, 4.5]))
# does not type cast
assert avg == '[2.5,3.5,4.5]'

def test_vector_sum(self):
with Session(engine) as session:
Expand All @@ -221,7 +222,8 @@ def test_halfvec_avg(self):
session.add(Item(half_embedding=[1, 2, 3]))
session.add(Item(half_embedding=[4, 5, 6]))
avg = session.exec(select(func.avg(Item.half_embedding))).first()
assert avg.to_list() == [2.5, 3.5, 4.5]
# does not type cast
assert avg == '[2.5,3.5,4.5]'

def test_halfvec_sum(self):
with Session(engine) as session:
Expand Down

0 comments on commit 70ae49a

Please sign in to comment.