Skip to content

Commit

Permalink
refactor: move Elasticsearch classes to dedicated package and use bui…
Browse files Browse the repository at this point in the history
…lder pattern
  • Loading branch information
sobychacko committed Dec 9, 2024
1 parent 778752a commit e4844e1
Show file tree
Hide file tree
Showing 11 changed files with 157 additions and 33 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@
import org.springframework.ai.embedding.BatchingStrategy;
import org.springframework.ai.embedding.EmbeddingModel;
import org.springframework.ai.embedding.TokenCountBatchingStrategy;
import org.springframework.ai.vectorstore.ElasticsearchVectorStore;
import org.springframework.ai.vectorstore.ElasticsearchVectorStoreOptions;
import org.springframework.ai.elasticsearch.vectorstore.ElasticsearchVectorStore;
import org.springframework.ai.elasticsearch.vectorstore.ElasticsearchVectorStoreOptions;
import org.springframework.ai.vectorstore.observation.VectorStoreObservationConvention;
import org.springframework.beans.factory.ObjectProvider;
import org.springframework.boot.autoconfigure.AutoConfiguration;
Expand Down Expand Up @@ -73,9 +73,14 @@ ElasticsearchVectorStore vectorStore(ElasticsearchVectorStoreProperties properti
elasticsearchVectorStoreOptions.setSimilarity(properties.getSimilarity());
}

return new ElasticsearchVectorStore(elasticsearchVectorStoreOptions, restClient, embeddingModel,
properties.isInitializeSchema(), observationRegistry.getIfUnique(() -> ObservationRegistry.NOOP),
customObservationConvention.getIfAvailable(() -> null), batchingStrategy);
return ElasticsearchVectorStore.builder(restClient)
.options(elasticsearchVectorStoreOptions)
.embeddingModel(embeddingModel)
.initializeSchema(properties.isInitializeSchema())
.observationRegistry(observationRegistry.getIfUnique(() -> ObservationRegistry.NOOP))
.customObservationConvention(customObservationConvention.getIfAvailable(() -> null))
.batchingStrategy(batchingStrategy)
.build();
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
package org.springframework.ai.autoconfigure.vectorstore.elasticsearch;

import org.springframework.ai.autoconfigure.vectorstore.CommonVectorStoreProperties;
import org.springframework.ai.vectorstore.SimilarityFunction;
import org.springframework.ai.elasticsearch.vectorstore.SimilarityFunction;
import org.springframework.boot.context.properties.ConfigurationProperties;

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,9 @@
import org.springframework.ai.autoconfigure.retry.SpringAiRetryAutoConfiguration;
import org.springframework.ai.document.Document;
import org.springframework.ai.observation.conventions.VectorStoreProvider;
import org.springframework.ai.vectorstore.ElasticsearchVectorStore;
import org.springframework.ai.elasticsearch.vectorstore.ElasticsearchVectorStore;
import org.springframework.ai.vectorstore.SearchRequest;
import org.springframework.ai.vectorstore.SimilarityFunction;
import org.springframework.ai.elasticsearch.vectorstore.SimilarityFunction;
import org.springframework.ai.vectorstore.observation.VectorStoreObservationContext;
import org.springframework.boot.autoconfigure.AutoConfigurations;
import org.springframework.boot.autoconfigure.elasticsearch.ElasticsearchRestClientAutoConfiguration;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
* limitations under the License.
*/

package org.springframework.ai.vectorstore;
package org.springframework.ai.elasticsearch.vectorstore;

import java.text.ParseException;
import java.text.SimpleDateFormat;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
* limitations under the License.
*/

package org.springframework.ai.vectorstore;
package org.springframework.ai.elasticsearch.vectorstore;

import java.io.IOException;
import java.util.List;
Expand Down Expand Up @@ -48,11 +48,12 @@
import org.springframework.ai.model.EmbeddingUtils;
import org.springframework.ai.observation.conventions.VectorStoreProvider;
import org.springframework.ai.observation.conventions.VectorStoreSimilarityMetric;
import org.springframework.ai.vectorstore.AbstractVectorStoreBuilder;
import org.springframework.ai.vectorstore.SearchRequest;
import org.springframework.ai.vectorstore.filter.Filter;
import org.springframework.ai.vectorstore.filter.FilterExpressionConverter;
import org.springframework.ai.vectorstore.observation.AbstractObservationVectorStore;
import org.springframework.ai.vectorstore.observation.VectorStoreObservationContext;
import org.springframework.ai.vectorstore.observation.VectorStoreObservationContext.Builder;
import org.springframework.ai.vectorstore.observation.VectorStoreObservationConvention;
import org.springframework.beans.factory.InitializingBean;
import org.springframework.util.Assert;
Expand Down Expand Up @@ -83,8 +84,6 @@ public class ElasticsearchVectorStore extends AbstractObservationVectorStore imp
SimilarityFunction.cosine, VectorStoreSimilarityMetric.COSINE, SimilarityFunction.l2_norm,
VectorStoreSimilarityMetric.EUCLIDEAN, SimilarityFunction.dot_product, VectorStoreSimilarityMetric.DOT);

private final EmbeddingModel embeddingModel;

private final ElasticsearchClient elasticsearchClient;

private final ElasticsearchVectorStoreOptions options;
Expand All @@ -95,34 +94,43 @@ public class ElasticsearchVectorStore extends AbstractObservationVectorStore imp

private final BatchingStrategy batchingStrategy;

@Deprecated(since = "1.0.0-M5", forRemoval = true)
public ElasticsearchVectorStore(RestClient restClient, EmbeddingModel embeddingModel, boolean initializeSchema) {
this(new ElasticsearchVectorStoreOptions(), restClient, embeddingModel, initializeSchema);
}

@Deprecated(since = "1.0.0-M5", forRemoval = true)
public ElasticsearchVectorStore(ElasticsearchVectorStoreOptions options, RestClient restClient,
EmbeddingModel embeddingModel, boolean initializeSchema) {
this(options, restClient, embeddingModel, initializeSchema, ObservationRegistry.NOOP, null,
new TokenCountBatchingStrategy());
}

@Deprecated(since = "1.0.0-M5", forRemoval = true)
public ElasticsearchVectorStore(ElasticsearchVectorStoreOptions options, RestClient restClient,
EmbeddingModel embeddingModel, boolean initializeSchema, ObservationRegistry observationRegistry,
VectorStoreObservationConvention customObservationConvention, BatchingStrategy batchingStrategy) {

super(observationRegistry, customObservationConvention);
this(builder(restClient).options(options)
.embeddingModel(embeddingModel)
.initializeSchema(initializeSchema)
.observationRegistry(observationRegistry)
.customObservationConvention(customObservationConvention)
.batchingStrategy(batchingStrategy));
}

private ElasticsearchVectorStore(ElasticsearchBuilder builder) {
super(builder);
this.initializeSchema = builder.initializeSchema;
this.options = builder.options;
this.filterExpressionConverter = builder.filterExpressionConverter;
this.batchingStrategy = builder.batchingStrategy;

this.initializeSchema = initializeSchema;
Objects.requireNonNull(embeddingModel, "RestClient must not be null");
Objects.requireNonNull(embeddingModel, "EmbeddingModel must not be null");
String version = Version.VERSION == null ? "Unknown" : Version.VERSION.toString();
this.elasticsearchClient = new ElasticsearchClient(new RestClientTransport(restClient,
this.elasticsearchClient = new ElasticsearchClient(new RestClientTransport(builder.restClient,
new JacksonJsonpMapper(
new ObjectMapper().configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false))))
.withTransportOptions(t -> t.addHeader("user-agent", "spring-ai elastic-java/" + version));
this.embeddingModel = embeddingModel;
this.options = options;
this.filterExpressionConverter = new ElasticsearchAiSearchFilterExpressionConverter();
this.batchingStrategy = batchingStrategy;
}

@Override
Expand Down Expand Up @@ -297,4 +305,94 @@ private String getSimilarityMetric() {
public record ElasticSearchDocument(String id, String content, Map<String, Object> metadata, float[] embedding) {
}

/**
* Creates a new builder instance for ElasticsearchVectorStore.
* @param restClient the Elasticsearch REST client
* @return a new ElasticsearchBuilder instance
*/
public static ElasticsearchBuilder builder(RestClient restClient) {
return new ElasticsearchBuilder(restClient);
}

public static class ElasticsearchBuilder extends AbstractVectorStoreBuilder<ElasticsearchBuilder> {

private final RestClient restClient;

private ElasticsearchVectorStoreOptions options = new ElasticsearchVectorStoreOptions();

private boolean initializeSchema = false;

private BatchingStrategy batchingStrategy = new TokenCountBatchingStrategy();

private FilterExpressionConverter filterExpressionConverter = new ElasticsearchAiSearchFilterExpressionConverter();

/**
* Creates a new builder instance with the specified REST client.
* @param restClient the Elasticsearch REST client
* @throws IllegalArgumentException if restClient is null
*/
public ElasticsearchBuilder(RestClient restClient) {
Assert.notNull(restClient, "RestClient must not be null");
this.restClient = restClient;
}

/**
* Sets the Elasticsearch vector store options.
* @param options the vector store options to use
* @return the builder instance
* @throws IllegalArgumentException if options is null
*/
public ElasticsearchBuilder options(ElasticsearchVectorStoreOptions options) {
Assert.notNull(options, "options must not be null");
this.options = options;
return this;
}

/**
* Sets whether to initialize the schema.
* @param initializeSchema true to initialize schema, false otherwise
* @return the builder instance
*/
public ElasticsearchBuilder initializeSchema(boolean initializeSchema) {
this.initializeSchema = initializeSchema;
return this;
}

/**
* Sets the batching strategy for vector operations.
* @param batchingStrategy the batching strategy to use
* @return the builder instance
* @throws IllegalArgumentException if batchingStrategy is null
*/
public ElasticsearchBuilder batchingStrategy(BatchingStrategy batchingStrategy) {
Assert.notNull(batchingStrategy, "batchingStrategy must not be null");
this.batchingStrategy = batchingStrategy;
return this;
}

/**
* Sets the filter expression converter.
* @param converter the filter expression converter to use
* @return the builder instance
* @throws IllegalArgumentException if converter is null
*/
public ElasticsearchBuilder filterExpressionConverter(FilterExpressionConverter converter) {
Assert.notNull(converter, "filterExpressionConverter must not be null");
this.filterExpressionConverter = converter;
return this;
}

/**
* Builds the ElasticsearchVectorStore instance.
* @return a new ElasticsearchVectorStore instance
* @throws IllegalStateException if the builder is in an invalid state
*/
@Override
public ElasticsearchVectorStore build() {
validate();
return new ElasticsearchVectorStore(this);
}

}

}
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
* limitations under the License.
*/

package org.springframework.ai.vectorstore;
package org.springframework.ai.elasticsearch.vectorstore;

/**
* Provided Elasticsearch vector option configuration.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
* limitations under the License.
*/

package org.springframework.ai.vectorstore;
package org.springframework.ai.elasticsearch.vectorstore;

/**
* https://www.elastic.co/guide/en/elasticsearch/reference/master/dense-vector.html
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
* limitations under the License.
*/

package org.springframework.ai.vectorstore;
package org.springframework.ai.elasticsearch.vectorstore;

import java.util.Date;
import java.util.List;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
* limitations under the License.
*/

package org.springframework.ai.vectorstore;
package org.springframework.ai.elasticsearch.vectorstore;

import org.testcontainers.utility.DockerImageName;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
* limitations under the License.
*/

package org.springframework.ai.vectorstore;
package org.springframework.ai.elasticsearch.vectorstore;

import java.io.IOException;
import java.nio.charset.StandardCharsets;
Expand Down Expand Up @@ -51,6 +51,7 @@
import org.springframework.ai.embedding.EmbeddingModel;
import org.springframework.ai.openai.OpenAiEmbeddingModel;
import org.springframework.ai.openai.api.OpenAiApi;
import org.springframework.ai.vectorstore.SearchRequest;
import org.springframework.boot.SpringBootConfiguration;
import org.springframework.boot.autoconfigure.EnableAutoConfiguration;
import org.springframework.boot.autoconfigure.jdbc.DataSourceAutoConfiguration;
Expand Down Expand Up @@ -376,23 +377,34 @@ public static class TestApplication {

@Bean("vectorStore_cosine")
public ElasticsearchVectorStore vectorStoreDefault(EmbeddingModel embeddingModel, RestClient restClient) {
return new ElasticsearchVectorStore(restClient, embeddingModel, true);
return ElasticsearchVectorStore.builder(restClient)
.embeddingModel(embeddingModel)
.initializeSchema(true)
.build();
}

@Bean("vectorStore_l2_norm")
public ElasticsearchVectorStore vectorStoreL2(EmbeddingModel embeddingModel, RestClient restClient) {
ElasticsearchVectorStoreOptions options = new ElasticsearchVectorStoreOptions();
options.setIndexName("index_l2");
options.setSimilarity(SimilarityFunction.l2_norm);
return new ElasticsearchVectorStore(options, restClient, embeddingModel, true);
return ElasticsearchVectorStore.builder(restClient)
.embeddingModel(embeddingModel)
.initializeSchema(true)
.options(options)
.build();
}

@Bean("vectorStore_dot_product")
public ElasticsearchVectorStore vectorStoreDotProduct(EmbeddingModel embeddingModel, RestClient restClient) {
ElasticsearchVectorStoreOptions options = new ElasticsearchVectorStoreOptions();
options.setIndexName("index_dot_product");
options.setSimilarity(SimilarityFunction.dot_product);
return new ElasticsearchVectorStore(options, restClient, embeddingModel, true);
return ElasticsearchVectorStore.builder(restClient)
.embeddingModel(embeddingModel)
.initializeSchema(true)
.options(options)
.build();
}

@Bean
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
* limitations under the License.
*/

package org.springframework.ai.vectorstore;
package org.springframework.ai.elasticsearch.vectorstore;

import java.io.IOException;
import java.nio.charset.StandardCharsets;
Expand Down Expand Up @@ -51,6 +51,8 @@
import org.springframework.ai.observation.conventions.VectorStoreSimilarityMetric;
import org.springframework.ai.openai.OpenAiEmbeddingModel;
import org.springframework.ai.openai.api.OpenAiApi;
import org.springframework.ai.vectorstore.SearchRequest;
import org.springframework.ai.vectorstore.VectorStore;
import org.springframework.ai.vectorstore.observation.DefaultVectorStoreObservationConvention;
import org.springframework.ai.vectorstore.observation.VectorStoreObservationDocumentation.HighCardinalityKeyNames;
import org.springframework.ai.vectorstore.observation.VectorStoreObservationDocumentation.LowCardinalityKeyNames;
Expand All @@ -67,6 +69,7 @@
/**
* @author Christian Tzolov
* @author Thomas Vitale
* @author Soby Chacko
*/
@Testcontainers
@EnabledIfEnvironmentVariable(named = "OPENAI_API_KEY", matches = ".+")
Expand Down Expand Up @@ -205,8 +208,14 @@ public TestObservationRegistry observationRegistry() {
@Bean
public ElasticsearchVectorStore vectorStoreDefault(EmbeddingModel embeddingModel, RestClient restClient,
ObservationRegistry observationRegistry) {
return new ElasticsearchVectorStore(new ElasticsearchVectorStoreOptions(), restClient, embeddingModel, true,
observationRegistry, null, new TokenCountBatchingStrategy());
return ElasticsearchVectorStore.builder(restClient)
.embeddingModel(embeddingModel)
.initializeSchema(true)
.options(new ElasticsearchVectorStoreOptions())
.observationRegistry(observationRegistry)
.customObservationConvention(null)
.batchingStrategy(new TokenCountBatchingStrategy())
.build();
}

@Bean
Expand Down

0 comments on commit e4844e1

Please sign in to comment.