Skip to content

Commit

Permalink
Add a working neural search sample (#614)
Browse files Browse the repository at this point in the history
Signed-off-by: Thomas Farr <[email protected]>
  • Loading branch information
Xtansia authored Apr 19, 2024
1 parent 553f7d1 commit c448e8c
Show file tree
Hide file tree
Showing 2 changed files with 286 additions and 1 deletion.
266 changes: 266 additions & 0 deletions samples/Samples/NeuralSearch/NeuralSearchSample.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,266 @@
/* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*/

using System.Diagnostics;
using OpenSearch.Client;
using OpenSearch.Net;

namespace Samples.NeuralSearch;

/// <summary>
/// Sample based off of the <a href="https://opensearch.org/docs/latest/search-plugins/neural-search-tutorial">Neural Search Tutorial</a>
/// </summary>
public class NeuralSearchSample : Sample
{
private const string SampleName = "neural-search";
private const string ResourceNamePrefix = "csharp-" + SampleName;
private const string MlModelGroupName = ResourceNamePrefix + "-model-group";
private const string IngestPipelineName = ResourceNamePrefix + "-ingest-pipeline";
private const string IndexName = ResourceNamePrefix + "-index";

private string? _modelGroupId;
private string? _modelRegistrationTaskId;
private string? _modelId;
private string? _modelDeployTaskId;
private bool _putIngestPipeline;
private bool _createdIndex;

public NeuralSearchSample() : base(SampleName, "A sample demonstrating how to perform a neural search query") { }

public class NeuralSearchDoc
{
[PropertyName("id")] public string? Id { get; set; }
[PropertyName("text")] public string? Text { get; set; }
[PropertyName("passage_embedding")] public float[]? PassageEmbedding { get; set; }
}

protected override async Task Run(IOpenSearchClient client)
{
// Temporarily configure the cluster to allow local running of the ML model
var putSettingsResp = await client.Cluster.PutSettingsAsync(s => s
.Transient(p => p
.Add("plugins.ml_commons.only_run_on_ml_node", false)
.Add("plugins.ml_commons.model_access_control_enabled", true)
.Add("plugins.ml_commons.native_memory_threshold", 99)));
Debug.Assert(putSettingsResp.IsValid, putSettingsResp.DebugInformation);
Console.WriteLine("Configured cluster to allow local execution of the ML model");

// Register an ML model group
var registerModelGroupResp = await client.Http.PostAsync<DynamicResponse>(
"/_plugins/_ml/model_groups/_register",
r => r.SerializableBody(new
{
name = MlModelGroupName,
description = $"A model group for the opensearch-net {SampleName} sample",
access_mode = "public"
}));
Debug.Assert(registerModelGroupResp.Success && (string) registerModelGroupResp.Body.status == "CREATED", registerModelGroupResp.DebugInformation);
Console.WriteLine($"Model group named {MlModelGroupName} {registerModelGroupResp.Body.status}: {registerModelGroupResp.Body.model_group_id}");
_modelGroupId = (string) registerModelGroupResp.Body.model_group_id;

// Register the ML model
var registerModelResp = await client.Http.PostAsync<DynamicResponse>(
"/_plugins/_ml/models/_register",
r => r.SerializableBody(new
{
name = "huggingface/sentence-transformers/msmarco-distilbert-base-tas-b",
version = "1.0.1",
model_group_id = _modelGroupId,
model_format = "TORCH_SCRIPT"
}));
Debug.Assert(registerModelResp.Success && (string) registerModelResp.Body.status == "CREATED", registerModelResp.DebugInformation);
Console.WriteLine($"Model registration task {registerModelResp.Body.status}: {registerModelResp.Body.task_id}");
_modelRegistrationTaskId = (string) registerModelResp.Body.task_id;

// Wait for ML model registration to complete
while (true)
{
var getTaskResp = await client.Http.GetAsync<DynamicResponse>($"/_plugins/_ml/tasks/{_modelRegistrationTaskId}");
Console.WriteLine($"Model registration: {getTaskResp.Body.state}");
Debug.Assert(getTaskResp.Success && (string) getTaskResp.Body.state != "FAILED", getTaskResp.DebugInformation);
if (((string)getTaskResp.Body.state).StartsWith("COMPLETED"))
{
_modelId = getTaskResp.Body.model_id;
break;
}
await Task.Delay(10000);
}
Console.WriteLine($"Model registered: {_modelId}");

// Deploy the ML model
var deployModelResp = await client.Http.PostAsync<DynamicResponse>($"/_plugins/_ml/models/{_modelId}/_deploy");
Debug.Assert(deployModelResp.Success && (string) deployModelResp.Body.status == "CREATED", deployModelResp.DebugInformation);
Console.WriteLine($"Model deployment task {deployModelResp.Body.status}: {deployModelResp.Body.task_id}");
_modelDeployTaskId = (string) deployModelResp.Body.task_id;

// Wait for ML model deployment to complete
while (true)
{
var getTaskResp = await client.Http.GetAsync<DynamicResponse>($"/_plugins/_ml/tasks/{_modelDeployTaskId}");
Console.WriteLine($"Model deployment: {getTaskResp.Body.state}");
Debug.Assert(getTaskResp.Success && (string) getTaskResp.Body.state != "FAILED", getTaskResp.DebugInformation);
if (((string)getTaskResp.Body.state).StartsWith("COMPLETED")) break;
await Task.Delay(10000);
}
Console.WriteLine($"Model deployed: {_modelId}");

// Create the text_embedding ingest pipeline
// TODO: Client does not yet contain typings for the text_embedding processor
var putIngestPipelineResp = await client.Http.PutAsync<PutPipelineResponse>(
$"/_ingest/pipeline/{IngestPipelineName}",
r => r.SerializableBody(new
{
description = $"A text_embedding ingest pipeline for the opensearch-net {SampleName} sample",
processors = new[]
{
new
{
text_embedding = new
{
model_id = _modelId,
field_map = new
{
text = "passage_embedding"
}
}
}
}
}));
Debug.Assert(putIngestPipelineResp.IsValid, putIngestPipelineResp.DebugInformation);
Console.WriteLine($"Put ingest pipeline {IngestPipelineName}: {putIngestPipelineResp.Acknowledged}");
_putIngestPipeline = true;

// Create the index
var createIndexResp = await client.Indices.CreateAsync(
IndexName,
i => i
.Settings(s => s
.Setting("index.knn", true)
.DefaultPipeline(IngestPipelineName))
.Map<NeuralSearchDoc>(m => m
.Properties(p => p
.Text(t => t.Name(d => d.Id))
.Text(t => t.Name(d => d.Text))
.KnnVector(k => k
.Name(d => d.PassageEmbedding)
.Dimension(768)
.Method(km => km
.Engine("lucene")
.SpaceType("l2")
.Name("hnsw"))))));
Debug.Assert(createIndexResp.IsValid, createIndexResp.DebugInformation);
Console.WriteLine($"Created index {IndexName}: {createIndexResp.Acknowledged}");
_createdIndex = true;

// Index some documents
var documents = new NeuralSearchDoc[]
{
new() { Id = "4319130149.jpg", Text = "A West Virginia university women 's basketball team , officials , and a small gathering of fans are in a West Virginia arena ." },
new() { Id = "1775029934.jpg", Text = "A wild animal races across an uncut field with a minimal amount of trees ." },
new() { Id = "2664027527.jpg", Text = "People line the stands which advertise Freemont 's orthopedics , a cowboy rides a light brown bucking bronco ." },
new() { Id = "4427058951.jpg", Text = "A man who is riding a wild horse in the rodeo is very near to falling off ." },
new() { Id = "2691147709.jpg", Text = "A rodeo cowboy , wearing a cowboy hat , is being thrown off of a wild white horse ." }
};
var bulkResp = await client.BulkAsync(b => b
.Index(IndexName)
.IndexMany(documents)
.Refresh(Refresh.WaitFor));
Debug.Assert(bulkResp.IsValid, bulkResp.DebugInformation);
Console.WriteLine($"Indexed {documents.Length} documents");

// Perform the neural search
// TODO: Client does not yet contain typings for neural query type
Console.WriteLine("Performing neural search for text 'wild west'");
var searchResp = await client.Http.PostAsync<SearchResponse<NeuralSearchDoc>>(
$"/{IndexName}/_search",
r => r.SerializableBody(new
{
_source = new { excludes = new[] { "passage_embedding" } },
query = new
{
neural = new
{
passage_embedding = new
{
query_text = "wild west",
model_id = _modelId,
k = 5
}
}
}
}));
Debug.Assert(searchResp.IsValid, searchResp.DebugInformation);
Console.WriteLine($"Found {searchResp.Hits.Count} documents");
foreach (var hit in searchResp.Hits) Console.WriteLine($"- Document id: {hit.Source.Id}, score: {hit.Score}, text: {hit.Source.Text}");
}

protected override async Task Cleanup(IOpenSearchClient client)
{
Console.WriteLine("\n\n-- CLEANING UP --");
if (_createdIndex)
{
// Cleanup the index
var deleteIndexResp = await client.Indices.DeleteAsync(IndexName);
Debug.Assert(deleteIndexResp.IsValid, deleteIndexResp.DebugInformation);
Console.WriteLine($"Deleted index: {deleteIndexResp.Acknowledged}");
}

if (_putIngestPipeline)
{
// Cleanup the ingest pipeline
var deleteIngestPipelineResp = await client.Ingest.DeletePipelineAsync(IngestPipelineName);
Debug.Assert(deleteIngestPipelineResp.IsValid, deleteIngestPipelineResp.DebugInformation);
Console.WriteLine($"Deleted ingest pipeline: {deleteIngestPipelineResp.Acknowledged}");
}

if (_modelDeployTaskId != null)
{
// Cleanup the model deployment task
var deleteModelDeployTaskResp = await client.Http.DeleteAsync<DynamicResponse>($"/_plugins/_ml/tasks/{_modelDeployTaskId}");
Debug.Assert(deleteModelDeployTaskResp.Success && (string) deleteModelDeployTaskResp.Body.result == "deleted", deleteModelDeployTaskResp.DebugInformation);
Console.WriteLine($"Deleted model deployment task: {deleteModelDeployTaskResp.Body.result}");
}

if (_modelId != null)
{
while (true)
{
// Try cleanup the ML model
var deleteModelResp = await client.Http.DeleteAsync<DynamicResponse>($"/_plugins/_ml/models/{_modelId}");
if (deleteModelResp.Success)
{
Console.WriteLine($"Deleted model: {deleteModelResp.Body.result}");
break;
}

Debug.Assert(((string?)deleteModelResp.Body.error?.reason)?.Contains("Try undeploy") ?? false, deleteModelResp.DebugInformation);

// Undeploy the ML model
var undeployModelResp = await client.Http.PostAsync<DynamicResponse>($"/_plugins/_ml/models/{_modelId}/_undeploy");
Debug.Assert(undeployModelResp.Success, undeployModelResp.DebugInformation);
Console.WriteLine("Undeployed model");
await Task.Delay(10000);
}
}

if (_modelRegistrationTaskId != null)
{
// Cleanup the model registration task
var deleteModelRegistrationTaskResp = await client.Http.DeleteAsync<DynamicResponse>($"/_plugins/_ml/tasks/{_modelRegistrationTaskId}");
Debug.Assert(deleteModelRegistrationTaskResp.Success && (string) deleteModelRegistrationTaskResp.Body.result == "deleted", deleteModelRegistrationTaskResp.DebugInformation);
Console.WriteLine($"Deleted model registration task: {deleteModelRegistrationTaskResp.Body.result}");
}

if (_modelGroupId != null)
{
// Cleanup the model group
var deleteModelGroupResp = await client.Http.DeleteAsync<DynamicResponse>($"/_plugins/_ml/model_groups/{_modelGroupId}");
Debug.Assert(deleteModelGroupResp.Success && (string) deleteModelGroupResp.Body.result == "deleted", deleteModelGroupResp.DebugInformation);
Console.WriteLine($"Deleted model group: {deleteModelGroupResp.Body.result}");
}
}
}
21 changes: 20 additions & 1 deletion samples/Samples/Sample.cs
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,29 @@ public Command AsCommand(IValueDescriptor<IOpenSearchClient> clientDescriptor)
{
var command = new Command(_name, _description);

command.SetHandler(Run, clientDescriptor);
command.SetHandler(async client =>
{
try
{
await Run(client);
}
finally
{
try
{
await Cleanup(client);
}
catch (Exception e)
{
await Console.Error.WriteLineAsync($"Cleanup Failed: {e}");
}
}
}, clientDescriptor);

return command;
}

protected abstract Task Run(IOpenSearchClient client);

protected virtual Task Cleanup(IOpenSearchClient client) => Task.CompletedTask;
}

0 comments on commit c448e8c

Please sign in to comment.