Skip to content

Commit

Permalink
fix: es_vec test
Browse files Browse the repository at this point in the history
  • Loading branch information
aatmanvaidya committed Sep 12, 2024
1 parent a5d4df8 commit 08b3085
Showing 1 changed file with 18 additions and 14 deletions.
32 changes: 18 additions & 14 deletions src/tests/core/store/test_es_vec.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import requests
from requests.exceptions import ConnectTimeout
from core.store.es_vec import ES
from core.config import StoreConfig, StoreParameters
from core.config import StoreConfig, StoreEntity, StoreESParameters
from core.models.media import MediaType
import pprint
from datetime import datetime
Expand Down Expand Up @@ -33,15 +33,19 @@ def setUpClass(cls) -> None:
"audio_index_name": "test_audio",
}
cls.param = StoreConfig(
label="test",
type="es",
parameters=StoreParameters(
host_name=param_dict["host_name"],
image_index_name=param_dict["image_index_name"],
text_index_name=param_dict["text_index_name"],
video_index_name=param_dict["video_index_name"],
audio_index_name=param_dict["audio_index_name"],
),
entities=[
StoreEntity(
label="test",
type="es",
parameters=StoreESParameters(
host_name=param_dict["host_name"],
image_index_name=param_dict["image_index_name"],
text_index_name=param_dict["text_index_name"],
video_index_name=param_dict["video_index_name"],
audio_index_name=param_dict["audio_index_name"],
),
)
]
)
except ConnectTimeout:
print('Request has timed out')
Expand All @@ -56,8 +60,8 @@ def tearDownClass(cls) -> None:

# @skip
def test_create_indices(self):
print(self.param)
es = ES(self.param)
print(self.param.entities[0])
es = ES(self.param.entities[0])
es.connect()
es.optionally_create_index()
indices = es.get_indices()
Expand Down Expand Up @@ -96,7 +100,7 @@ def test_store_image(self):

# @skip
def test_store_and_search_vectors(self):
es = ES(self.param)
es = ES(self.param.entities[0])
es.connect()
es.optionally_create_index()
vec = np.random.randn(512).tolist()
Expand Down Expand Up @@ -133,7 +137,7 @@ def test_find_by_metadata_field(self):

def delete_indices(self):
print("DELETING INDICES")
es = ES(self.param)
es = ES(self.param.entities[0])
es.connect()
es.delete_indices()
print("INDICES DELETED")
Expand Down

0 comments on commit 08b3085

Please sign in to comment.