Skip to content

Commit

Permalink
refactor(chromadb): Update generated Chroma API client (#142)
Browse files Browse the repository at this point in the history
  • Loading branch information
walsha2 authored Aug 30, 2023
1 parent 618be2f commit 4f0e737
Show file tree
Hide file tree
Showing 2 changed files with 61 additions and 37 deletions.
96 changes: 60 additions & 36 deletions packages/chromadb/lib/src/generated/client.dart
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import 'dart:io' as io;
import 'dart:convert';
import 'package:http/http.dart' as http;
import 'dart:typed_data';
import 'package:http/retry.dart';
import 'schema/schema.dart';

Expand Down Expand Up @@ -62,25 +63,27 @@ class ChromaApiClientException implements io.HttpException {
/// `client`: Override HTTP client to use for requests
class ChromaApiClient {
ChromaApiClient({
this.host,
String? host,
http.Client? client,
}) {
// Create a retry client
_client = RetryClient(client ?? http.Client());
this.client = RetryClient(client ?? http.Client());
// Ensure trailing slash is removed from host
this.host = host?.replaceAll(RegExp(r'/$'), '');
}

/// User provided override for host URL
late final String? host;

/// Internal HTTP client
late final http.Client _client;
/// HTTP client for requests
late final http.Client client;

// ------------------------------------------
// METHOD: endSession
// ------------------------------------------

/// Close the HTTP client and end session
void endSession() => _client.close();
void endSession() => client.close();

// ------------------------------------------
// METHOD: onRequest
Expand Down Expand Up @@ -130,21 +133,42 @@ class ChromaApiClient {
}

// Ensure query parameters are properly encoded
queryParams = queryParams.map(
(key, value) => MapEntry(key, Uri.encodeComponent(value.toString())));

// Determine the connection type
final hostUri = Uri.parse(host);
secure ??= hostUri.scheme == 'https';
queryParams = queryParams.map((key, value) {
if (value is List) {
return MapEntry(
key,
value.map((v) => Uri.encodeComponent(v.toString())).toList(),
);
} else {
return MapEntry(
key,
Uri.encodeComponent(value.toString()),
);
}
});

// Build the request URI
final uri = Uri(
scheme: secure ? 'https' : 'http',
host: hostUri.host,
port: hostUri.port,
path: path,
queryParameters: queryParams.isEmpty ? null : queryParams,
);
secure ??= Uri.parse(host).scheme == 'https';
Uri uri;
String authority;
if (host.contains('http')) {
authority = Uri.parse(host).authority;
} else {
authority = Uri.parse(Uri.https(host).toString()).authority;
}
if (secure) {
uri = Uri.https(
authority,
path,
queryParams.isEmpty ? null : queryParams,
);
} else {
uri = Uri.http(
authority,
path,
queryParams.isEmpty ? null : queryParams,
);
}

// Build the headers
Map<String, String> headers = {}..addAll(headerParams);
Expand Down Expand Up @@ -198,7 +222,7 @@ class ChromaApiClient {
request = await onRequest(request);

// Submit request
response = await http.Response.fromStream(await _client.send(request));
response = await http.Response.fromStream(await client.send(request));

// Handle user response middleware
response = await onResponse(response);
Expand Down Expand Up @@ -236,7 +260,7 @@ class ChromaApiClient {
/// `GET` `http://localhost:8000/api/v1`
Future<Map<String, int>> root() async {
final r = await _request(
host: 'http://localhost:8000',
host: 'localhost:8000',
path: '/api/v1',
secure: false,
method: HttpMethod.get,
Expand All @@ -256,7 +280,7 @@ class ChromaApiClient {
/// `POST` `http://localhost:8000/api/v1/reset`
Future<bool> reset() async {
final r = await _request(
host: 'http://localhost:8000',
host: 'localhost:8000',
path: '/api/v1/reset',
secure: false,
method: HttpMethod.post,
Expand All @@ -276,15 +300,15 @@ class ChromaApiClient {
/// `GET` `http://localhost:8000/api/v1/version`
Future<String> version() async {
final r = await _request(
host: 'http://localhost:8000',
host: 'localhost:8000',
path: '/api/v1/version',
secure: false,
method: HttpMethod.get,
isMultipart: false,
requestType: '',
responseType: 'application/json',
);
return json.decode(r.body);
return r.body;
}

// ------------------------------------------
Expand All @@ -296,7 +320,7 @@ class ChromaApiClient {
/// `GET` `http://localhost:8000/api/v1/heartbeat`
Future<Map<String, int>> heartbeat() async {
final r = await _request(
host: 'http://localhost:8000',
host: 'localhost:8000',
path: '/api/v1/heartbeat',
secure: false,
method: HttpMethod.get,
Expand All @@ -316,7 +340,7 @@ class ChromaApiClient {
/// `GET` `http://localhost:8000/api/v1/collections`
Future<List<CollectionType>> listCollections() async {
final r = await _request(
host: 'http://localhost:8000',
host: 'localhost:8000',
path: '/api/v1/collections',
secure: false,
method: HttpMethod.get,
Expand All @@ -341,7 +365,7 @@ class ChromaApiClient {
required CreateCollection request,
}) async {
final r = await _request(
host: 'http://localhost:8000',
host: 'localhost:8000',
path: '/api/v1/collections',
secure: false,
method: HttpMethod.post,
Expand Down Expand Up @@ -369,7 +393,7 @@ class ChromaApiClient {
required AddEmbedding request,
}) async {
final r = await _request(
host: 'http://localhost:8000',
host: 'localhost:8000',
path: '/api/v1/collections/$collectionId/add',
secure: false,
method: HttpMethod.post,
Expand Down Expand Up @@ -397,7 +421,7 @@ class ChromaApiClient {
required UpdateEmbedding request,
}) async {
final r = await _request(
host: 'http://localhost:8000',
host: 'localhost:8000',
path: '/api/v1/collections/$collectionId/update',
secure: false,
method: HttpMethod.post,
Expand Down Expand Up @@ -425,7 +449,7 @@ class ChromaApiClient {
required AddEmbedding request,
}) async {
final r = await _request(
host: 'http://localhost:8000',
host: 'localhost:8000',
path: '/api/v1/collections/$collectionId/upsert',
secure: false,
method: HttpMethod.post,
Expand Down Expand Up @@ -453,7 +477,7 @@ class ChromaApiClient {
required GetEmbedding request,
}) async {
final r = await _request(
host: 'http://localhost:8000',
host: 'localhost:8000',
path: '/api/v1/collections/$collectionId/get',
secure: false,
method: HttpMethod.post,
Expand Down Expand Up @@ -481,7 +505,7 @@ class ChromaApiClient {
required DeleteEmbedding request,
}) async {
final r = await _request(
host: 'http://localhost:8000',
host: 'localhost:8000',
path: '/api/v1/collections/$collectionId/delete',
secure: false,
method: HttpMethod.post,
Expand All @@ -506,7 +530,7 @@ class ChromaApiClient {
required String collectionId,
}) async {
final r = await _request(
host: 'http://localhost:8000',
host: 'localhost:8000',
path: '/api/v1/collections/$collectionId/count',
secure: false,
method: HttpMethod.get,
Expand All @@ -533,7 +557,7 @@ class ChromaApiClient {
required QueryEmbedding request,
}) async {
final r = await _request(
host: 'http://localhost:8000',
host: 'localhost:8000',
path: '/api/v1/collections/$collectionId/query',
secure: false,
method: HttpMethod.post,
Expand All @@ -558,7 +582,7 @@ class ChromaApiClient {
required String collectionName,
}) async {
final r = await _request(
host: 'http://localhost:8000',
host: 'localhost:8000',
path: '/api/v1/collections/$collectionName',
secure: false,
method: HttpMethod.get,
Expand All @@ -582,7 +606,7 @@ class ChromaApiClient {
required String collectionName,
}) async {
final r = await _request(
host: 'http://localhost:8000',
host: 'localhost:8000',
path: '/api/v1/collections/$collectionName',
secure: false,
method: HttpMethod.delete,
Expand All @@ -609,7 +633,7 @@ class ChromaApiClient {
required UpdateCollection request,
}) async {
final r = await _request(
host: 'http://localhost:8000',
host: 'localhost:8000',
path: '/api/v1/collections/$collectionId',
secure: false,
method: HttpMethod.put,
Expand Down
2 changes: 1 addition & 1 deletion packages/chromadb/pubspec.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -18,5 +18,5 @@ dev_dependencies:
build_runner: ^2.4.6
freezed: ^2.4.2
json_serializable: ^6.7.1
openapi_spec: ^0.4.10
openapi_spec: ^0.5.0
test: ^1.24.3

0 comments on commit 4f0e737

Please sign in to comment.