Hugging Face basic LLM endpoint is implemented #598

Open · wants to merge 3 commits into base: main

Changes from all commits
@@ -1,3 +1,33 @@
void main() {
  // TODO
// ignore_for_file: avoid_print, unused_element

import 'package:langchain_core/chat_models.dart';
import 'package:langchain_core/prompts.dart';

import 'package:langchain_huggingface/src/llm/huggingface_inference.dart';

void main() async {
  // Uncomment the example you want to run:
  await _example1();
  await _example2();
}

/// The most basic building block of LangChain is calling an LLM on some input.
Future<void> _example1() async {
  final huggingFace = HuggingfaceInference.call(
    model: 'gpt2',
    apiKey: '....API_KEY...',
  );
  final result = await huggingFace('Who are you?');
  print(result);
}

Future<void> _example2() async {
  final huggingFace = HuggingfaceInference.call(
    model: 'gpt2',
    apiKey: '....API_KEY...',
  );

  final str = huggingFace.stream(PromptValue.string('Who are you?'));

  str.listen(print);
}
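The streaming example hands results to listen; the same stream can also be consumed with await-for, which makes completion explicit and is easier to compose with other awaits. A minimal sketch, assuming the same HuggingfaceInference API shown in this diff (the _example3 name is not part of the PR):

Future<void> _example3() async {
  final huggingFace = HuggingfaceInference.call(
    model: 'gpt2',
    apiKey: '....API_KEY...',
  );

  // await-for suspends until the stream is done, unlike listen,
  // which returns as soon as the subscription is created.
  final stream = huggingFace.stream(PromptValue.string('Who are you?'));
  await for (final chunk in stream) {
    print(chunk.output); // output carries the generated text (see mappers.dart)
  }
}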
2 changes: 2 additions & 0 deletions packages/langchain_huggingface/lib/langchain_huggingface.dart
@@ -1,2 +1,4 @@
/// Hugging Face module for LangChain.dart.
library;

export 'src/llm/llm.dart';
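With llm.dart exported from the package barrel, downstream code can reach the new classes through the public entry point instead of a src/ path. A minimal sketch, assuming only the export chain shown in this diff (langchain_huggingface.dart -> src/llm/llm.dart -> huggingface_inference.dart and types.dart):

import 'package:langchain_huggingface/langchain_huggingface.dart';

void main() {
  // Both symbols resolve through the barrel exports added in this PR.
  final llm = HuggingfaceInference.call(model: 'gpt2', apiKey: '...');
  const options = HuggingFaceOptions(temperature: 0.7);
  print('${llm.model} with temperature ${options.temperature}');
}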
79 changes: 79 additions & 0 deletions packages/langchain_huggingface/lib/src/llm/huggingface_inference.dart
@@ -0,0 +1,79 @@
import 'package:huggingface_client/huggingface_client.dart';
import 'package:langchain_core/llms.dart';
import 'package:langchain_core/prompts.dart';
import 'package:meta/meta.dart';

import 'mappers.dart';
import 'types.dart';

/// Wrapper around the Hugging Face Inference API text-generation endpoint.
@immutable
class HuggingfaceInference extends BaseLLM<HuggingFaceOptions> {
  const HuggingfaceInference._({
    required this.model,
    required this.apiKey,
    required this.apiClient,
    super.defaultOptions = const HuggingFaceOptions(),
  });

  /// Creates an inference client for [model] authenticated with [apiKey].
  factory HuggingfaceInference.call({
    required String apiKey,
    required String model,
  }) {
    final apiClient = InferenceApi(
      HuggingFaceClient.getInferenceClient(
        apiKey,
        HuggingFaceClient.inferenceBasePath,
      ),
    );
    return HuggingfaceInference._(
      model: model,
      apiKey: apiKey,
      apiClient: apiClient,
    );
  }

  final InferenceApi apiClient;
  final String apiKey;
  final String model;

  @override
  Future<LLMResult> invoke(
    PromptValue input, {
    HuggingFaceOptions? options,
  }) async {
    final parameters = ApiQueryNLPTextGeneration(
      inputs: input.toString(),
      temperature: options?.temperature ?? 1.0,
      topK: options?.topK ?? 0,
      topP: options?.topP ?? 0.0,
      maxTime: options?.maxTime ?? -1.0,
      returnFullText: options?.returnFullText ?? true,
      repetitionPenalty: options?.repetitionPenalty ?? -1,
      doSample: options?.doSample ?? true,
      maxNewTokens: options?.maxNewTokens ?? -1,
      options: InferenceOptions(
        useCache: options?.useCache ?? true,
        waitForModel: options?.waitForModel ?? false,
      ),
    );
    final result = await apiClient.queryNLPTextGeneration(
      taskParameters: parameters,
      model: model,
    );

    return result![0]!.toLLMResult();
  }

  @override
  Stream<LLMResult> stream(PromptValue input, {HuggingFaceOptions? options}) {
    final query = ApiQueryNLPTextGeneration(
      inputs: input.toString(),
      temperature: options?.temperature ?? 1.0,
      topK: options?.topK ?? 0,
      topP: options?.topP ?? 0.0,
      maxTime: options?.maxTime ?? -1.0,
      returnFullText: options?.returnFullText ?? true,
      repetitionPenalty: options?.repetitionPenalty ?? -1,
      doSample: options?.doSample ?? true,
      maxNewTokens: options?.maxNewTokens ?? -1,
      options: InferenceOptions(
        useCache: options?.useCache ?? true,
        waitForModel: options?.waitForModel ?? false,
      ),
    );
    final stream = apiClient.textStreamGeneration(query: query, model: model);

    return stream.map((response) => response.toLLMResult());
  }

  @override
  String get modelType => 'llm';

  @override
  Future<List<int>> tokenize(
    PromptValue promptValue, {
    HuggingFaceOptions? options,
  }) async {
    // TODO: implement tokenize
    throw UnimplementedError();
  }
}
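invoke and stream build identical ApiQueryNLPTextGeneration objects, so the parameter mapping could live in a single private builder inside HuggingfaceInference. A sketch only, reusing the field names and defaults from the class above (the _buildQuery name is not part of the PR):

ApiQueryNLPTextGeneration _buildQuery(
  PromptValue input,
  HuggingFaceOptions? options,
) {
  // Same defaults as the duplicated blocks in invoke and stream.
  return ApiQueryNLPTextGeneration(
    inputs: input.toString(),
    temperature: options?.temperature ?? 1.0,
    topK: options?.topK ?? 0,
    topP: options?.topP ?? 0.0,
    maxTime: options?.maxTime ?? -1.0,
    returnFullText: options?.returnFullText ?? true,
    repetitionPenalty: options?.repetitionPenalty ?? -1,
    doSample: options?.doSample ?? true,
    maxNewTokens: options?.maxNewTokens ?? -1,
    options: InferenceOptions(
      useCache: options?.useCache ?? true,
      waitForModel: options?.waitForModel ?? false,
    ),
  );
}

With that helper, invoke and stream reduce to building the query once and forwarding it to queryNLPTextGeneration or textStreamGeneration respectively.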
2 changes: 2 additions & 0 deletions packages/langchain_huggingface/lib/src/llm/llm.dart
@@ -0,0 +1,2 @@
export 'huggingface_inference.dart';
export 'types.dart';
27 changes: 27 additions & 0 deletions packages/langchain_huggingface/lib/src/llm/mappers.dart
@@ -0,0 +1,27 @@
import 'package:huggingface_client/huggingface_client.dart';
import 'package:langchain_core/language_models.dart';
import 'package:langchain_core/llms.dart';

extension HuggingFaceResponseMapper on ApiResponseNLPTextGeneration {
  /// Maps a Hugging Face text-generation response to an [LLMResult].
  LLMResult toLLMResult() {
    return LLMResult(
      id: 'id',
      output: generatedText,
      finishReason: FinishReason.unspecified,
      metadata: {},
      usage: const LanguageModelUsage(),
    );
  }
}

extension HuggingFaceStreamResponseMapper on TextGenerationStreamResponse {
  /// Maps a streamed text-generation chunk to an [LLMResult].
  LLMResult toLLMResult() {
    return LLMResult(
      id: id.toString(),
      output: text,
      finishReason: FinishReason.unspecified,
      metadata: {},
      usage: const LanguageModelUsage(),
    );
  }
}
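Both mappers surface the generated text through LLMResult.output and leave finishReason unspecified with empty usage. Callers of the streaming API can therefore fold the chunks back into a single string; a hedged sketch, not part of this PR:

import 'package:langchain_core/llms.dart';

/// Concatenates the `output` of each streamed [LLMResult], which the mappers
/// above populate from TextGenerationStreamResponse.text.
Future<String> collectStream(Stream<LLMResult> chunks) async {
  final buffer = StringBuffer();
  await for (final chunk in chunks) {
    buffer.write(chunk.output);
  }
  return buffer.toString();
}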
95 changes: 95 additions & 0 deletions packages/langchain_huggingface/lib/src/llm/types.dart
@@ -0,0 +1,95 @@
import 'package:langchain_core/language_models.dart';
import 'package:langchain_core/llms.dart';
import 'package:meta/meta.dart';

/// Generation options for the Hugging Face Inference API.
@immutable
class HuggingFaceOptions extends LLMOptions {
  const HuggingFaceOptions({
    this.topK,
    this.topP,
    super.model,
    this.temperature,
    this.repetitionPenalty,
    this.maxNewTokens,
    this.maxTime,
    this.returnFullText,
    this.numReturnSequences,
    this.useCache,
    this.waitForModel,
    this.doSample,
  });

  /// (Default: true). There is a cache layer on the Inference API to speed up
  /// requests that have already been seen. Most models can use those results
  /// as-is, since models are deterministic (the results would be the same
  /// anyway). However, if you use a non-deterministic model, you can set this
  /// parameter to prevent the caching mechanism from being used, forcing a
  /// real new query.
  final bool? useCache;

  /// (Default: false). If the model is not ready, wait for it instead of
  /// receiving a 503. This limits the number of requests required to get your
  /// inference done. It is advised to only set this flag to true after
  /// receiving a 503 error, as it limits hanging in your application to known
  /// places.
  final bool? waitForModel;

  /// (Default: None). Integer defining the top tokens considered within the
  /// sample operation to create new text.
  final int? topK;

  /// (Default: None). Float defining the tokens that are within the sample
  /// operation of text generation. Tokens are added to the sample, from most
  /// probable to least probable, until the sum of the probabilities is
  /// greater than top_p.
  final double? topP;

  /// (Default: 1.0). Float (0.0-100.0). The temperature of the sampling
  /// operation. 1 means regular sampling, 0 means always take the highest
  /// score, and 100.0 gets closer to uniform probability.
  final double? temperature;

  /// (Default: None). Float (0.0-100.0). The more a token is used within the
  /// generation, the more it is penalized, making it less likely to be picked
  /// in successive generation passes.
  final double? repetitionPenalty;

  /// (Default: None). Int (0-250). The number of new tokens to be generated.
  /// This does not include the input length; it is an estimate of the size of
  /// the generated text you want. Each new token slows down the request, so
  /// look for a balance between response time and length of generated text.
  final int? maxNewTokens;

  /// (Default: None). Float (0-120.0). The maximum amount of time, in
  /// seconds, that the query should take. Network overhead makes this a soft
  /// limit. Use it in combination with [maxNewTokens] for best results.
  final double? maxTime;

  /// (Default: true). If set to false, the returned results will not contain
  /// the original query, making it easier to use for prompting.
  final bool? returnFullText;

  /// (Default: 1). The number of propositions you want to be returned.
  final int? numReturnSequences;

  /// (Optional: true). Whether or not to use sampling; uses greedy decoding
  /// otherwise.
  final bool? doSample;

  @override
  HuggingFaceOptions copyWith({
    final String? model,
    final int? concurrencyLimit,
    final int? topK,
    final double? topP,
    final double? temperature,
    final double? repetitionPenalty,
    final int? maxNewTokens,
    final double? maxTime,
    final bool? returnFullText,
    final int? numReturnSequences,
    final bool? useCache,
    final bool? waitForModel,
    final bool? doSample,
  }) {
    // Note: concurrencyLimit is accepted for signature compatibility with
    // LLMOptions.copyWith but is not stored by this options class.
    return HuggingFaceOptions(
      model: model ?? this.model,
      repetitionPenalty: repetitionPenalty ?? this.repetitionPenalty,
      returnFullText: returnFullText ?? this.returnFullText,
      numReturnSequences: numReturnSequences ?? this.numReturnSequences,
      useCache: useCache ?? this.useCache,
      waitForModel: waitForModel ?? this.waitForModel,
      doSample: doSample ?? this.doSample,
      topK: topK ?? this.topK,
      temperature: temperature ?? this.temperature,
      topP: topP ?? this.topP,
      maxTime: maxTime ?? this.maxTime,
      maxNewTokens: maxNewTokens ?? this.maxNewTokens,
    );
  }
}
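The doc comments above map one-to-one onto constructor arguments. A short usage sketch; the model name, parameter values, and API-key placeholder are illustrative only:

import 'package:langchain_core/prompts.dart';
import 'package:langchain_huggingface/langchain_huggingface.dart';

Future<void> main() async {
  const options = HuggingFaceOptions(
    temperature: 0.7, // softer sampling than the 1.0 default
    topK: 50, // consider only the 50 most likely tokens
    topP: 0.95, // nucleus-sampling threshold
    maxNewTokens: 64, // cap the length of the generated text
    returnFullText: false, // drop the prompt from the returned text
    waitForModel: true, // wait for a cold model instead of getting a 503
  );

  final llm = HuggingfaceInference.call(
    model: 'gpt2',
    apiKey: '....API_KEY...',
  );
  final result = await llm.invoke(
    PromptValue.string('Who are you?'),
    options: options,
  );
  print(result.output);
}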
3 changes: 3 additions & 0 deletions packages/langchain_huggingface/pubspec.yaml
@@ -15,3 +15,6 @@ topics:

environment:
sdk: ">=3.4.0 <4.0.0"
dependencies:
huggingface_client: ^1.6.0
langchain_core: ^0.3.6